77 lines
2.2 KiB
Python
77 lines
2.2 KiB
Python
import json
|
|
from collections import defaultdict
|
|
|
|
|
|
def f_detections_to_objects(detections_path: str, output_path: str):
    """Convert YOLO tracking detections into per-track object records.

    Reads the detections JSON at *detections_path*, groups the detections by
    ``track_id`` and writes a summary JSON to *output_path*.

    :param detections_path: path of the input detections JSON file (UTF-8).
    :param output_path: path of the output objects JSON file (UTF-8).
    :return: None

    Input (detections) format::

        {
            "<frame_id>": [
                {
                    "xyxy": [x1, y1, x2, y2],
                    "confidence": 0.314,
                    "track_id": 1,
                    "class_id": 1,
                    "class_str": "class_str"
                },
                ...
            ],
            ...
        }

    Only ``track_id`` and ``class_str`` of each detection are read, and each
    frame key must be parseable by ``int``. The input ``class_id`` is ignored;
    class ids are re-assigned as described below.

    Output (objects) format::

        {
            "class_list": ["socket", "fire_hydrant", "distribution_box"],
            "track_id_list": {
                "0": {"class_id": 2, "start_frame": 551, "end_frame": 597},
                ...
            }
        }

    - ``class_list`` holds each distinct ``class_str`` in order of first
      appearance while iterating the input. Labels are written verbatim
      (``ensure_ascii=False``), so non-ASCII names stay readable.
    - A track's ``class_id`` is the ``class_list`` index of the class of the
      first detection of that track encountered in input order.
    - ``start_frame`` / ``end_frame`` are the minimum / maximum frame ids in
      which the track appears.
    - Track ids become string keys, because JSON object keys are strings.
    """
    with open(detections_path, 'r', encoding='utf-8') as f:
        detections_data = json.load(f)

    result = _build_objects(detections_data)

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(result, f, ensure_ascii=False, indent=2)


def _build_objects(detections_data: dict) -> dict:
    """Aggregate a loaded detections mapping into the objects structure (no I/O)."""
    class_list = []
    # class_str -> index in class_list: O(1) lookup instead of the O(n)
    # ``in`` / ``list.index`` scans the loop would otherwise do per detection.
    class_ids = {}
    track_info = {}

    for frame_str, detections in detections_data.items():
        frame = int(frame_str)
        for det in detections:
            track_id = det["track_id"]
            class_str = det["class_str"]

            class_id = class_ids.get(class_str)
            if class_id is None:
                class_id = len(class_list)
                class_ids[class_str] = class_id
                class_list.append(class_str)

            info = track_info.get(track_id)
            if info is None:
                # Seed the span from the first observed frame instead of
                # inf/0 sentinels: correct for negative frame ids and no
                # sentinel can ever leak into the JSON output.
                track_info[track_id] = {
                    "class_id": class_id,
                    "start_frame": frame,
                    "end_frame": frame,
                }
            else:
                info["start_frame"] = min(info["start_frame"], frame)
                info["end_frame"] = max(info["end_frame"], frame)

    return {"class_list": class_list, "track_id_list": track_info}
|
|
|
|
def load_json_data(json_path: str):
    """Read the UTF-8 encoded JSON file at *json_path* and return the decoded object.

    :param json_path: path of the JSON file to read.
    :return: whatever top-level value the file contains (dict, list, ...).
    """
    with open(json_path, encoding='utf-8') as handle:
        parsed = json.load(handle)
    return parsed