修复跳过检测导致的json解析失败
This commit is contained in:
parent
41573ac5c7
commit
dd973a90b3
|
|
@ -59,9 +59,9 @@ def f_detections_to_objects(detections_path: str, output_path: str):
|
||||||
if track_info[track_id]["class_id"] is None:
|
if track_info[track_id]["class_id"] is None:
|
||||||
track_info[track_id]["class_id"] = class_id
|
track_info[track_id]["class_id"] = class_id
|
||||||
|
|
||||||
if frame < track_info[track_id]["start_frame"]:
|
if frame < track_info[track_id]["start_frame"]: # type: ignore
|
||||||
track_info[track_id]["start_frame"] = frame
|
track_info[track_id]["start_frame"] = frame
|
||||||
if frame > track_info[track_id]["end_frame"]:
|
if frame > track_info[track_id]["end_frame"]: # type: ignore
|
||||||
track_info[track_id]["end_frame"] = frame
|
track_info[track_id]["end_frame"] = frame
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
|
|
|
||||||
|
|
@ -407,9 +407,11 @@ def hazard_inspection(
|
||||||
for track_id, track_info in obj_dict.items():
|
for track_id, track_info in obj_dict.items():
|
||||||
class_str = class_list[track_info["class_id"]]
|
class_str = class_list[track_info["class_id"]]
|
||||||
vid_url = vid_dict[track_id]["vid_url"]
|
vid_url = vid_dict[track_id]["vid_url"]
|
||||||
reasoning_text, answer_text = single_object_inspection(vid_url, class_str, rule_dict, prompt_str, enable_thinking)
|
result = single_object_inspection(vid_url, class_str, rule_dict, prompt_str, enable_thinking)
|
||||||
|
if result == "skip":
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
reasoning_text, answer_text = result
|
||||||
|
|
||||||
# 整合思考过程
|
# 整合思考过程
|
||||||
all_reasoning_text += f"\n\n# 类别: {track_id}:{class_str}\n{reasoning_text}"
|
all_reasoning_text += f"\n\n# 类别: {track_id}:{class_str}\n{reasoning_text}"
|
||||||
|
|
@ -438,9 +440,9 @@ def hazard_inspection(
|
||||||
|
|
||||||
# tag_id
|
# tag_id
|
||||||
if "tag_id" in obj and obj["tag_id"] < len(class_result.get("tag", [])):
|
if "tag_id" in obj and obj["tag_id"] < len(class_result.get("tag", [])):
|
||||||
tag = class_result["tag"][obj["tag_id"]]
|
tag: str = class_result["tag"][obj["tag_id"]]
|
||||||
else:
|
else:
|
||||||
tag = None
|
tag: str = ""
|
||||||
|
|
||||||
if tag not in all_tags:
|
if tag not in all_tags:
|
||||||
all_tags.append(tag)
|
all_tags.append(tag)
|
||||||
|
|
@ -457,9 +459,9 @@ def hazard_inspection(
|
||||||
temp_obj["base_id"] = all_bases.index(base)
|
temp_obj["base_id"] = all_bases.index(base)
|
||||||
|
|
||||||
# 隐患等级
|
# 隐患等级
|
||||||
if get_hazard_level_from_rule(rule_dict, class_list[track_info["class_id"]], tag) == "重大隐患":
|
if get_hazard_level_from_rule(rule_dict, class_str, tag) == "重大隐患":
|
||||||
temp_obj["level"] = 1
|
temp_obj["level"] = 1
|
||||||
elif get_hazard_level_from_rule(rule_dict, class_list[track_info["class_id"]], tag) == "一般隐患":
|
elif get_hazard_level_from_rule(rule_dict, class_str, tag) == "一般隐患":
|
||||||
temp_obj["level"] = 0
|
temp_obj["level"] = 0
|
||||||
else:
|
else:
|
||||||
temp_obj["level"] = -1,
|
temp_obj["level"] = -1,
|
||||||
|
|
@ -501,13 +503,13 @@ def hazard_inspection(
|
||||||
|
|
||||||
return all_reasoning_text, json_result
|
return all_reasoning_text, json_result
|
||||||
|
|
||||||
def single_object_inspection(vid_url: str, class_str: str, rule_dict: dict, prompt_str: str, enable_thinking: bool = True) -> tuple[str, str]:
|
def single_object_inspection(vid_url: str, class_str: str, rule_dict: dict, prompt_str: str, enable_thinking: bool = True) -> tuple[str, str] | str:
|
||||||
# 获取检查规则
|
# 获取检查规则
|
||||||
rule = get_rule_by_class_str(class_str, rule_dict)
|
rule = get_rule_by_class_str(class_str, rule_dict)
|
||||||
|
|
||||||
if rule == "":
|
if rule == "":
|
||||||
print(f"类别 {class_str} 没有检查规则,跳过检查")
|
print(f"类别 {class_str} 没有检查规则,跳过检查")
|
||||||
return ""
|
return "skip"
|
||||||
else:
|
else:
|
||||||
print(f"类别 {class_str} 的检查规则: {rule}")
|
print(f"类别 {class_str} 的检查规则: {rule}")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ import cv2
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
def generate_video_to_objects(
|
def generate_video_to_objects(
|
||||||
obj_dict: dict[dict],
|
obj_dict: dict[str, dict],
|
||||||
input_video_path: str,
|
input_video_path: str,
|
||||||
output_dir: str,
|
output_dir: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
@ -71,7 +71,7 @@ def generate_video_to_objects(
|
||||||
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||||
min_frames = int(min_seconds * fps)
|
min_frames = int(min_seconds * fps)
|
||||||
|
|
||||||
fourcc = cv2.VideoWriter_fourcc(*'avc1')
|
fourcc = cv2.VideoWriter_fourcc(*'avc1') # type: ignore
|
||||||
|
|
||||||
# 3. 预处理所有物体的帧范围 + 初始化写入器
|
# 3. 预处理所有物体的帧范围 + 初始化写入器
|
||||||
total_objects = len(obj_dict)
|
total_objects = len(obj_dict)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue