From 754bfad6e9f3e6407337a1c76c114614f980cf19 Mon Sep 17 00:00:00 2001 From: yueliuli <1628111725@qq.com> Date: Thu, 23 Apr 2026 16:10:42 +0800 Subject: [PATCH] =?UTF-8?q?bug=E4=BF=AE=E5=A4=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 补齐检测结束时间保存代码 - 修复检测结果json格式错误时的处理逻辑 - run函数提添加结果json输出 - 确保fps数据存入全局变量 --- Qwen_app.py | 165 ++++++-------------------------------------- lib/qwen_fun.py | 177 +++++++++++++++++++++++++++++++----------------- 2 files changed, 135 insertions(+), 207 deletions(-) diff --git a/Qwen_app.py b/Qwen_app.py index a77cdf4..acf6315 100644 --- a/Qwen_app.py +++ b/Qwen_app.py @@ -70,141 +70,6 @@ def update_preview(frame_idx: int, vid_file: str): return img_path, class_tag -# def run(vid_file: str, run_sam3: bool = True, run_inspection: bool = True, gen_report: bool = True,): -# #================初始化================= -# vid_name: str = Path(vid_file).stem -# vid_end: str = Path(vid_file).suffix -# vid_path: str = f"./input/{vid_name}{vid_end}" - -# cap = cv2.VideoCapture(vid_path) -# global fps -# fps = int(cap.get(cv2.CAP_PROP_FPS)) # 获取视频FPS -# cap.release() - -# output_dir: str = f"output/{vid_name}" -# # interval = int(fps / 5) -# interval = fps -# conf = 0.7 -# time_data: dict = {} - -# annotated_frames: SAM3 = SAM3() -# result: dict = {} - -# if output_dir: -# os.makedirs(output_dir, exist_ok=True) # exist_ok=True 防止重复创建报错 - - - -# #================获取物体信息================= -# # 保存开始时间字符串 -# time_data["start_time"] = str(datetime.now()) -# with open(f"{output_dir}/time.json", "w") as f: -# f.write(json.dumps(time_data, ensure_ascii=False, indent=4)) - -# if run_sam3: -# # 针对厂房防火分区 -# annotated_frames.run(vid_path, output_dir, "lib/class_list/1.厂房防火.json", interval, conf) -# else: -# annotated_frames.load_from_json(f"{output_dir}/frame_all.json") -# # print(annotated_frames.data) - -# # 提取ai能看到的部分 -# ai_frames: dict = get_annnotated_frame_for_ai_without_xyxy(annotated_frames.data(), 1, conf) -# save_json_to_file(ai_frames, f"{output_dir}/frame_all_ai.json") - -# #================隐患检查================= - -# if run_inspection: -# if vid_path.startswith("oss"): -# video_url = vid_path -# else: -# video_url: str|None = None -# video_url_file = f"{output_dir}/video_url.json" - -# # 检查URL是否已存在 -# if os.path.exists(video_url_file): -# try: -# with open(video_url_file, "r", encoding="utf-8") as f: -# url_data = json.load(f) -# if vid_name in url_data: -# video_url = url_data[vid_name] -# print(f"使用已存在的URL: {video_url}") -# except Exception as e: -# print(f"读取URL文件失败: {e}") - -# # 如果URL不存在,上传文件 -# if video_url is None: -# print(f"上传视频文件: {vid_path}") -# video_url = upload_files_and_get_urls_concurrently( -# file_path_list=[vid_path], -# max_workers=8 -# )[0] -# if video_url: -# # 保存为JSON格式,包含文件名和URL的键值对 -# url_data = {} -# if os.path.exists(video_url_file): -# try: -# with open(video_url_file, "r", encoding="utf-8") as f: -# url_data = json.load(f) -# except: -# pass - -# url_data[vid_name] = video_url -# with open(video_url_file, "w", encoding="utf-8") as f: -# json.dump(url_data, f, ensure_ascii=False, indent=4) -# print(f"URL已保存: {video_url}") - -# if video_url is None: -# raise ValueError("视频上传失败,无法获取 URL") - -# result_test: str -# reason_test: str -# reason_test, result_test = hazard_inspection(ai_frames, video_url, enable_thinking=True, fps=2) -# result = json.loads(result_test) -# # merged_result = merge_conflict_inspection_data(result) - -# result["class"] = ai_frames["class_list"] - -# with open(f"{output_dir}/hazard_inspection.json", "w", encoding="utf-8") as f: -# f.write(json.dumps(result, ensure_ascii=False, indent=4)) -# with open(f"{output_dir}/hazard_inspection_reason.json", "w", encoding="utf-8") as f: -# f.write(json.dumps(reason_test, ensure_ascii=False, indent=4)) - -# # with open(f"{output_dir}/hazard_inspection_merged.json", "w", encoding="utf-8") as f: -# # f.write(json.dumps(merged_result, ensure_ascii=False, indent=4)) - -# #================生成报告================= -# if gen_report: -# if result == {}: -# with open(f"{output_dir}/hazard_inspection.json", "r", encoding="utf-8") as f: -# result = json.load(f) - -# with open(f"知识库/rule.json", "r", encoding="utf-8") as f: -# rule_definitions = json.load(f) - -# report_generator( -# video_path=vid_path, # 视频文件路径 -# detection_data=annotated_frames.data(), # 物体检测数据 -# hazard_results=result, # 隐患检查结果 -# rule_definitions=rule_definitions, # 规则定义 -# output_path=output_dir, # 输出文件夹 -# frame_interval=interval, # 帧间隔(可根据实际视频帧率调整) -# ) - -# # 保存结束时间字符串 -# time_data["end_time"] = str(datetime.now()) -# with open(f"{output_dir}/time.json", "w") as f: -# f.write(json.dumps(time_data, ensure_ascii=False, indent=4)) - -# #================更新预览================= - -# # 获取总帧数 -# cap = cv2.VideoCapture(vid_path) -# total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) -# cap.release() -# img_path, class_tag = update_preview(0, vid_name) -# return gr.update(maximum=total_frames-1, value=0, step = interval), img_path, class_tag - def run(vid_file: str, run_sam3: bool = True, run_inspection: bool = True, gen_report: bool = True): vid_name = vid_file.split(".")[0] # 视频名称(不含后缀) vid_end = f".{vid_file.split('.')[-1]}" # 视频后缀 @@ -218,6 +83,7 @@ def run(vid_file: str, run_sam3: bool = True, run_inspection: bool = True, gen_r frame_detections_path = f"{output_dir}/frame_detections.json" objects_json_path = f"{output_dir}/objects.json" vid_dir = f"{output_dir}/obj_vids" + json_result: dict = {} rule_dict: dict = load_json_data('知识库/rule.json') video_url_file = f"{output_dir}/video_url.json" @@ -231,6 +97,7 @@ def run(vid_file: str, run_sam3: bool = True, run_inspection: bool = True, gen_r # 获取总帧数与帧率 cap = cv2.VideoCapture(input_video_path) total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + global fps fps = int(cap.get(cv2.CAP_PROP_FPS)) cap.release() @@ -274,7 +141,7 @@ def run(vid_file: str, run_sam3: bool = True, run_inspection: bool = True, gen_r # 上传视频并获取 URL vid_dict = get_vid_dict(vid_dir, obj_dict, video_url_file, use_url_cache) - hazard_inspection( + _,json_result_str = hazard_inspection( output_dir, obj_dict, rule_dict, @@ -283,26 +150,35 @@ def run(vid_file: str, run_sam3: bool = True, run_inspection: bool = True, gen_r fps=fps, enable_thinking=enable_thinking ) + json_result: dict = json.loads(json_result_str) + + else: # 不运行隐患检查,直接加载之前的结果 + with open(f"{output_dir}/hazard_inspection.json", "r", encoding="utf-8") as f: + json_result = json.load(f) + + # 保存结束时间字符串 + time_data["end_time"] = str(datetime.now()) + with open(f"{output_dir}/time.json", "w") as f: + f.write(json.dumps(time_data, ensure_ascii=False, indent=4)) + + #================生成报告================= if gen_report: - result: dict = {} - with open(f"{output_dir}/hazard_inspection.json", "r", encoding="utf-8") as f: - result = json.load(f) - with open(f"知识库/rule.json", "r", encoding="utf-8") as f: rule_definitions = json.load(f) report_generator( video_path=input_video_path, # 视频文件路径 detection_data=annotated_frames.data(), # 物体检测数据 - hazard_results=result, # 隐患检查结果 + hazard_results=json_result, # 隐患检查结果 rule_definitions=rule_definitions, # 规则定义 output_path=output_dir, # 输出文件夹 frame_interval=interval, # 帧间隔(可根据实际视频帧率调整) ) + #================保存结束时间================= # 保存结束时间字符串 time_data["end_time"] = str(datetime.now()) with open(f"{output_dir}/time.json", "w") as f: @@ -312,7 +188,9 @@ def run(vid_file: str, run_sam3: bool = True, run_inspection: bool = True, gen_r img_path, class_tag = update_preview(0, vid_name) - return gr.update(maximum=total_frames-1, value=0, step = interval), img_path, class_tag + return gr.update(maximum=total_frames-1, value=0, step = interval), img_path, class_tag, json_result + + def get_vid_dict(vid_dir: str, obj_dict: dict, video_url_file: str, use_url_cache: bool) -> dict: """获取视频字典 @@ -411,10 +289,11 @@ with gr.Blocks() as demo: preview = gr.Image(label="预览", scale=2) img_slider = gr.Slider(label="帧索引", minimum=0, maximum=0, value=0, step=1) textbox = gr.Textbox(label="隐患结果", lines=10) + jsonbox = gr.JSON(label="隐患结果json") - run_button.click(fn=run, inputs=[vid_file, run_sam3, run_inspection, gen_report], outputs=[img_slider, preview, textbox]) + run_button.click(fn=run, inputs=[vid_file, run_sam3, run_inspection, gen_report], outputs=[img_slider, preview, textbox, jsonbox]) img_slider.change(fn=update_preview, inputs=[img_slider, vid_file], outputs=[preview, textbox], show_progress="hidden") get_vid_path_btn.click(fn=get_full_vid_path, inputs=vid_file, outputs=full_path_text) diff --git a/lib/qwen_fun.py b/lib/qwen_fun.py index 568117b..2e01e12 100644 --- a/lib/qwen_fun.py +++ b/lib/qwen_fun.py @@ -359,6 +359,38 @@ def search_knowledge_base( print(f"请求失败,状态码: {response.status_code}") print("错误信息:", response.text) return "" +def format_check(output_data: dict) -> bool: + """检查 JSON 字符串是否符合要求""" + try: + # 检查必要的键 + required_keys = ['tag', 'base', 'objects'] + for key in required_keys: + if key not in output_data: + print(f"Missing required key: {key}") + return False + # 检查tag是否为列表 + if not isinstance(output_data['tag'], list): + print("tag must be a list") + return False + # 检查base是否为列表 + if not isinstance(output_data['base'], list): + print("base must be a list") + return False + # 检查objects是否为列表 + if not isinstance(output_data['objects'], list): + print("objects must be a list") + return False + # 检查objects中的每个元素 + for obj in output_data['objects']: + required_obj_keys = ['tag_id', 'base_id', 'hazard_track_id', 'conf', 'location', 'recommend'] + for key in required_obj_keys: + if key not in obj: + print(f"Missing required key in object: {key}") + return False + return True + except Exception as e: + print(f"Format check failed: {e}") + return False def hazard_inspection( output_dir: str, @@ -407,83 +439,99 @@ def hazard_inspection( for track_id, track_info in obj_dict.items(): class_str = class_list[track_info["class_id"]] vid_url = vid_dict[track_id]["vid_url"] - result = single_object_inspection(vid_url, class_str, rule_dict, prompt_str, enable_thinking) - if result == "skip": - continue - else: - reasoning_text, answer_text = result + result: tuple[str, str] | str = "" + reasoning_text: str = "" + answer_text: str = "" - # 整合思考过程 - all_reasoning_text += f"\n\n# 类别: {track_id}:{class_str}\n{reasoning_text}" + while True: + result = single_object_inspection(vid_url, class_str, rule_dict, prompt_str, enable_thinking) + if result == "skip": + break + else: + reasoning_text, answer_text = result - # 解析 JSON 结果 - try: - # 提取 JSON 部分(去除可能的代码块标记) - json_str = answer_text.strip() - if json_str.startswith("```json"): - json_str = json_str[7:] - if json_str.startswith("```"): - json_str = json_str[3:] - if json_str.endswith("```"): - json_str = json_str[:-3] - json_str = json_str.strip() + # 解析 JSON 结果 + try: + # 提取 JSON 部分(去除可能的代码块标记) + json_str = answer_text.strip() + if json_str.startswith("```json"): + json_str = json_str[7:] + if json_str.startswith("```"): + json_str = json_str[3:] + if json_str.endswith("```"): + json_str = json_str[:-3] + json_str = json_str.strip() - class_result = json.loads(json_str) + class_result = json.loads(json_str) - # 保存class_result到文件,便于调试 - with open(f"{output_dir}/inspection_result_{int(track_id):03d}.txt", "w", encoding="utf-8") as f: - f.write(json.dumps(class_result, ensure_ascii=False, indent=2)) + # 检查 JSON 格式是否正确 + if isinstance(class_result, dict) and format_check(class_result): + class_result = class_result + # 检查是否为列表且包含有效元素 + elif isinstance(class_result, list) and len(class_result) > 0 and format_check(class_result[0]): + class_result = class_result[0] + else: + print("JSON 格式错误, 重试隐患检查") + continue # 重试隐患检查,直到格式正确 - # 收集该类别的结果 - temp_obj = {} - for obj in class_result.get("objects", []): + # 保存class_result到文件,便于调试 + with open(f"{output_dir}/inspection_result_{int(track_id):03d}.txt", "w", encoding="utf-8") as f: + f.write(json.dumps(class_result, ensure_ascii=False, indent=2)) - # tag_id - if "tag_id" in obj and obj["tag_id"] < len(class_result.get("tag", [])): - tag: str = class_result["tag"][obj["tag_id"]] - else: - tag: str = "" + # 整合思考过程 + all_reasoning_text += f"\n\n# 类别: {track_id}:{class_str}\n{reasoning_text}" - if tag not in all_tags: - all_tags.append(tag) - temp_obj["tag_id"] = all_tags.index(tag) + # 收集该类别的结果 + temp_obj = {} + for obj in class_result.get("objects", []): - # base_id - if "base_id" in obj and obj["base_id"] < len(class_result.get("base", [])): - base = class_result["base"][obj["base_id"]] - else: - base = None + # tag_id + if "tag_id" in obj and obj["tag_id"] < len(class_result.get("tag", [])): + tag: str = class_result["tag"][obj["tag_id"]] + else: + tag: str = "" - if base not in all_bases: - all_bases.append(base) - temp_obj["base_id"] = all_bases.index(base) + if tag not in all_tags: + all_tags.append(tag) + temp_obj["tag_id"] = all_tags.index(tag) - # 隐患等级 - if get_hazard_level_from_rule(rule_dict, class_str, tag) == "重大隐患": - temp_obj["level"] = 1 - elif get_hazard_level_from_rule(rule_dict, class_str, tag) == "一般隐患": - temp_obj["level"] = 0 - else: - temp_obj["level"] = -1, + # base_id + if "base_id" in obj and obj["base_id"] < len(class_result.get("base", [])): + base = class_result["base"][obj["base_id"]] + else: + base = None - # 其他字段 - temp_obj["track_id"] = track_id - temp_obj["hazard_track_id"] = hazard_count - temp_obj["class_id"] = track_info["class_id"] - temp_obj["conf"] = obj["conf"] - temp_obj["start_frame"] = track_info["start_frame"] - temp_obj["end_frame"] = track_info["end_frame"] - temp_obj["start_sec"] = round(track_info["start_frame"] / fps, 1) # 处理成x.x秒 - temp_obj["location"] = obj.get("location", "") + if base not in all_bases: + all_bases.append(base) + temp_obj["base_id"] = all_bases.index(base) - hazard_count += 1 + # 隐患等级 + if get_hazard_level_from_rule(rule_dict, class_str, tag) == "重大隐患": + temp_obj["level"] = 1 + elif get_hazard_level_from_rule(rule_dict, class_str, tag) == "一般隐患": + temp_obj["level"] = 0 + else: + temp_obj["level"] = -1, - all_objects.append(temp_obj) + # 其他字段 + temp_obj["track_id"] = track_id + temp_obj["hazard_track_id"] = hazard_count + temp_obj["class_id"] = track_info["class_id"] + temp_obj["conf"] = obj["conf"] + temp_obj["start_frame"] = track_info["start_frame"] + temp_obj["end_frame"] = track_info["end_frame"] + temp_obj["start_sec"] = round(track_info["start_frame"] / fps, 1) # 处理成x.x秒 + temp_obj["location"] = obj.get("location", "") - except json.JSONDecodeError as e: - print(f"解析类别 {class_str} 的 JSON 结果失败: {e}") - print(f"原始输出: {answer_text}") - continue + hazard_count += 1 + + all_objects.append(temp_obj) + + except json.JSONDecodeError as e: + print(f"解析类别 {class_str} 的 JSON 结果失败: {e}") + print(f"原始输出: {answer_text}") + continue + break # 跳出当前循环,继续下一个物体 # 构建最终结果 final_result = { @@ -503,6 +551,7 @@ def hazard_inspection( return all_reasoning_text, json_result + def single_object_inspection(vid_url: str, class_str: str, rule_dict: dict, prompt_str: str, enable_thinking: bool = True) -> tuple[str, str] | str: # 获取检查规则 rule = get_rule_by_class_str(class_str, rule_dict) @@ -686,7 +735,7 @@ def report_generator( start_frame = max(0, min(start_frame, total_frames - 1)) # 获取隐患对应的类名和标签 - class_name = hazard_results["class"][obj["class_id"]] + class_name = hazard_results["class_list"][obj["class_id"]] tag_name = hazard_results["tag"][obj["tag_id"]] original_level = obj["level"] location = obj.get("location", "")