diff --git a/lib/json_fun.py b/lib/json_fun.py index 507a718..79964e6 100644 --- a/lib/json_fun.py +++ b/lib/json_fun.py @@ -59,9 +59,9 @@ def f_detections_to_objects(detections_path: str, output_path: str): if track_info[track_id]["class_id"] is None: track_info[track_id]["class_id"] = class_id - if frame < track_info[track_id]["start_frame"]: + if frame < track_info[track_id]["start_frame"]: # type: ignore track_info[track_id]["start_frame"] = frame - if frame > track_info[track_id]["end_frame"]: + if frame > track_info[track_id]["end_frame"]: # type: ignore track_info[track_id]["end_frame"] = frame result = { diff --git a/lib/qwen_fun.py b/lib/qwen_fun.py index f6e7ba1..568117b 100644 --- a/lib/qwen_fun.py +++ b/lib/qwen_fun.py @@ -407,9 +407,11 @@ def hazard_inspection( for track_id, track_info in obj_dict.items(): class_str = class_list[track_info["class_id"]] vid_url = vid_dict[track_id]["vid_url"] - reasoning_text, answer_text = single_object_inspection(vid_url, class_str, rule_dict, prompt_str, enable_thinking) - - + result = single_object_inspection(vid_url, class_str, rule_dict, prompt_str, enable_thinking) + if result == "skip": + continue + else: + reasoning_text, answer_text = result # 整合思考过程 all_reasoning_text += f"\n\n# 类别: {track_id}:{class_str}\n{reasoning_text}" @@ -438,9 +440,9 @@ def hazard_inspection( # tag_id if "tag_id" in obj and obj["tag_id"] < len(class_result.get("tag", [])): - tag = class_result["tag"][obj["tag_id"]] + tag: str = class_result["tag"][obj["tag_id"]] else: - tag = None + tag: str = "" if tag not in all_tags: all_tags.append(tag) @@ -457,9 +459,9 @@ def hazard_inspection( temp_obj["base_id"] = all_bases.index(base) # 隐患等级 - if get_hazard_level_from_rule(rule_dict, class_list[track_info["class_id"]], tag) == "重大隐患": + if get_hazard_level_from_rule(rule_dict, class_str, tag) == "重大隐患": temp_obj["level"] = 1 - elif get_hazard_level_from_rule(rule_dict, class_list[track_info["class_id"]], tag) == "一般隐患": + elif get_hazard_level_from_rule(rule_dict, class_str, tag) == "一般隐患": temp_obj["level"] = 0 else: temp_obj["level"] = -1, @@ -501,13 +503,13 @@ def hazard_inspection( return all_reasoning_text, json_result -def single_object_inspection(vid_url: str, class_str: str, rule_dict: dict, prompt_str: str, enable_thinking: bool = True) -> tuple[str, str]: +def single_object_inspection(vid_url: str, class_str: str, rule_dict: dict, prompt_str: str, enable_thinking: bool = True) -> tuple[str, str] | str: # 获取检查规则 rule = get_rule_by_class_str(class_str, rule_dict) if rule == "": print(f"类别 {class_str} 没有检查规则,跳过检查") - return "" + return "skip" else: print(f"类别 {class_str} 的检查规则: {rule}") diff --git a/lib/qwen_fun_vid.py b/lib/qwen_fun_vid.py index c3fd6e2..73a7d6e 100644 --- a/lib/qwen_fun_vid.py +++ b/lib/qwen_fun_vid.py @@ -10,7 +10,7 @@ import cv2 from pathlib import Path def generate_video_to_objects( - obj_dict: dict[dict], + obj_dict: dict[str, dict], input_video_path: str, output_dir: str, ) -> None: @@ -71,7 +71,7 @@ def generate_video_to_objects( total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) min_frames = int(min_seconds * fps) - fourcc = cv2.VideoWriter_fourcc(*'avc1') + fourcc = cv2.VideoWriter_fourcc(*'avc1') # type: ignore # 3. 预处理所有物体的帧范围 + 初始化写入器 total_objects = len(obj_dict)