"""
Hazard-inspection demo UI (SAM3 + Qwen).

Test whether, given annotation boxes and classes, the AI can accurately
identify object features after the raw video has been converted.
"""
from datetime import datetime
import json
import os
from pathlib import Path

import cv2
import gradio as gr
import numpy as np

# NOTE(review): `T` and `N` look like accidental IDE auto-imports and are unused
# in this file; `tkinter` may also be missing on headless servers. Confirm and drop.
from encodings.punycode import T
from tkinter import N

from lib.json_fun import f_detections_to_objects, load_json_data
from lib.qwen_fun import (
    get_annnotated_frame_for_ai_without_xyxy,
    hazard_inspection,
    merge_conflict_inspection_data,
    report_generator,
    save_json_to_file,
    search_knowledge_base,
    upload_files_and_get_urls_concurrently,
)
from lib.qwen_fun_vid import generate_video_to_objects
from lib.sam3 import SAM3

# Frame rate of the most recently processed video. `run` sets it and
# `update_preview` reads it; 0 means "not known yet".
fps = 0
VIDEO_FOLDER: str = "input"


def load_file_list() -> list[str]:
    """Return the (sorted) names of the regular files directly inside VIDEO_FOLDER.

    Sub-directories are skipped. Returns an empty list when the folder is missing.
    Sorting makes the default (first) selection deterministic.
    """
    folder = Path(VIDEO_FOLDER)
    if not folder.is_dir():
        return []
    return sorted(entry.name for entry in folder.iterdir() if entry.is_file())


def reload_files():
    """Refresh the video dropdown and preselect the first file (if any)."""
    file_list = load_file_list()
    default_value = file_list[0] if file_list else None
    return gr.update(choices=file_list, value=default_value)


def get_full_vid_path(vid_file: str) -> str:
    """Return the absolute path of `vid_file` inside VIDEO_FOLDER and show it as a toast.

    Raises:
        gr.Error: if no video is selected.
    """
    if not vid_file:
        raise gr.Error("请先选择视频文件")
    full_path: str = os.path.join(os.getcwd(), VIDEO_FOLDER, vid_file)
    gr.Info(full_path)
    return full_path


def _probe_video_fps(vid_file: str) -> float:
    """Read the frame rate of VIDEO_FOLDER/vid_file with OpenCV; 0.0 if it cannot be opened."""
    cap = cv2.VideoCapture(os.path.join(VIDEO_FOLDER, vid_file))
    try:
        return float(cap.get(cv2.CAP_PROP_FPS)) if cap.isOpened() else 0.0
    finally:
        cap.release()


def _save_time_data(output_dir: str, time_data: dict) -> None:
    """Write the start/end timestamps of a run to `<output_dir>/time.json`."""
    with open(f"{output_dir}/time.json", "w", encoding="utf-8") as f:
        json.dump(time_data, f, ensure_ascii=False, indent=4)


def update_preview(frame_idx: int, vid_file: str):
    """Return `(image_path_or_None, hazard_text)` for the slider position `frame_idx`.

    The preview image is `output/<stem>/boxes/frame_XXXX.jpg`, where XXXX is
    `frame_idx // fps`. The hazard text comes from `output/<stem>/hazard_inspection.json`.

    Raises:
        gr.Error: if the frame rate of the video cannot be determined.
    """
    global fps
    if not vid_file:
        return None, ""

    # `run` normally sets the global fps; fall back to probing the file so the
    # slider does not divide by zero before any run in this session.
    if fps <= 0:
        fps = _probe_video_fps(vid_file)
    if fps <= 0:
        raise gr.Error("无法获取视频帧率,请先运行处理流程")

    output_dir: str = f"output/{Path(vid_file).stem}"
    idx = int(frame_idx // fps)
    img_path = f"{output_dir}/boxes/frame_{idx:04d}.jpg"
    image = img_path if os.path.isfile(img_path) else None

    hazard_path = f"{output_dir}/hazard_inspection.json"
    if not os.path.isfile(hazard_path):
        return image, "尚未生成隐患检查结果 (hazard_inspection.json 不存在)"
    with open(hazard_path, "r", encoding="utf-8") as f:
        dic = json.load(f)
    return image, get_class_tag_by_frame(dic, frame_idx, fps)


# The previous single-request `run` implementation (kept here as commented-out
# code before per-object clips were introduced) was removed; see VCS history.


def run(vid_file: str, run_sam3: bool = True, run_inspection: bool = True, gen_report: bool = True):
    """Run the pipeline for one video and refresh the preview widgets.

    Stages, each gated by its flag:
      1. object detection with SAM3 (or reload `frame_detections.json` from an earlier run),
      2. hazard inspection on per-object clips that are uploaded for the model,
      3. report generation from `hazard_inspection.json`.

    Args:
        vid_file: File name inside VIDEO_FOLDER (e.g. "site.mp4").
        run_sam3: Run SAM3; when False, reuse the cached detections.
        run_inspection: Build/upload per-object clips and run the hazard inspection.
        gen_report: Generate the report from `output/<stem>/hazard_inspection.json`.

    Returns:
        (slider update, preview image path or None, hazard text) for the Gradio outputs.

    Raises:
        gr.Error: if no video is selected or the video cannot be opened.
    """
    global fps  # update_preview() reads the frame rate of the last processed video
    if not vid_file:
        raise gr.Error("请先选择视频文件")

    # Path handles names with several dots ("a.b.mp4") and names without a suffix.
    vid_name = Path(vid_file).stem
    input_video_path = f"{VIDEO_FOLDER}/{vid_file}"

    use_url_cache = True      # reuse URLs stored in video_url.json to avoid re-uploading
    enable_thinking = True    # enable the model's thinking mode during inspection
    run_vid_process = True    # (re)generate the per-object clips before uploading
    interval = 1              # slider step, in frames

    output_dir = f"output/{vid_name}"
    frame_detections_path = f"{output_dir}/frame_detections.json"
    objects_json_path = f"{output_dir}/objects.json"
    vid_dir = f"{output_dir}/obj_vids"
    video_url_file = f"{output_dir}/video_url.json"
    rule_dict: dict = load_json_data('知识库/rule.json')
    time_data: dict = {"start_time": str(datetime.now())}

    # Validate the video before doing any expensive work.
    cap = cv2.VideoCapture(input_video_path)
    try:
        if not cap.isOpened():
            raise gr.Error(f"无法打开视频: {input_video_path}")
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
    finally:
        cap.release()

    annotated_frames: SAM3 = SAM3()
    os.makedirs(output_dir, exist_ok=True)
    _save_time_data(output_dir, time_data)

    # ================ object information ================
    # Another detector (e.g. YOLO) can be swapped in here as long as it writes
    # its output to frame_detections_path.
    if run_sam3:
        # Class list for factory fire-compartment inspection.
        annotated_frames.run(input_video_path, output_dir, "lib/class_list/1.厂房防火.json")
    else:
        annotated_frames.load_from_json(frame_detections_path)

    # ================ hazard inspection ================
    if run_inspection:
        f_detections_to_objects(frame_detections_path, objects_json_path)
        with open(objects_json_path, "r", encoding="utf-8") as f:
            obj = json.load(f)
        class_list = obj["class_list"]
        obj_dict = obj["track_id_list"]

        if run_vid_process:
            generate_video_to_objects(obj_dict, input_video_path, output_dir=vid_dir)

        vid_dict = get_vid_dict(vid_dir, obj_dict, video_url_file, use_url_cache)
        # NOTE(review): hazard_inspection is expected to write
        # <output_dir>/hazard_inspection.json, which the report and preview read back.
        hazard_inspection(
            output_dir, obj_dict, rule_dict, class_list, vid_dict,
            fps=fps, enable_thinking=enable_thinking,
        )

    # ================ report ================
    if gen_report:
        with open(f"{output_dir}/hazard_inspection.json", "r", encoding="utf-8") as f:
            result = json.load(f)
        with open("知识库/rule.json", "r", encoding="utf-8") as f:
            rule_definitions = json.load(f)
        report_generator(
            video_path=input_video_path,            # source video
            detection_data=annotated_frames.data(), # object detection data
            hazard_results=result,                  # hazard inspection results
            rule_definitions=rule_definitions,      # rule definitions
            output_path=output_dir,                 # output folder
            frame_interval=interval,                # frame interval
        )

    time_data["end_time"] = str(datetime.now())
    _save_time_data(output_dir, time_data)

    # ================ refresh preview ================
    img_path, class_tag = update_preview(0, vid_file)
    return gr.update(maximum=max(total_frames - 1, 0), value=0, step=interval), img_path, class_tag


def get_vid_dict(vid_dir: str, obj_dict: dict, video_url_file: str, use_url_cache: bool) -> dict:
    """Map each object track id to its clip path and uploaded URL, uploading missing clips.

    Args:
        vid_dir: Directory holding the per-object clips (`obj_XXX.mp4`).
        obj_dict: Objects keyed by track id; ids must be int-convertible.
        video_url_file: JSON cache of `{track_id: {"vid_path", "vid_url"}}`, rewritten on exit.
        use_url_cache: When True, reuse cached entries instead of re-uploading.

    Returns:
        `{track_id: {"vid_path": str, "vid_url": str | None}}`. Cached entries for ids
        that are not in `obj_dict` are kept. A URL stays None if its upload returned nothing.
    """
    vid_dict: dict = {}

    if use_url_cache and os.path.exists(video_url_file):
        try:
            with open(video_url_file, "r", encoding="utf-8") as f:
                cached = json.load(f)
        except (OSError, ValueError) as e:  # ValueError covers JSON/Unicode decode errors
            print(f"读取URL文件失败: {e}")
        else:
            if isinstance(cached, dict):
                # Older versions stored {vid_name: url} in this same file; ignore
                # anything that is not a well-formed per-object entry.
                vid_dict = {
                    key: value
                    for key, value in cached.items()
                    if isinstance(value, dict) and "vid_path" in value and "vid_url" in value
                }

    for track_id in obj_dict:
        if track_id not in vid_dict:
            vid_dict[track_id] = {
                "vid_path": f"{vid_dir}/obj_{int(track_id):03d}.mp4",
                "vid_url": None,
            }

    pending = [track_id for track_id, info in vid_dict.items() if info["vid_url"] is None]
    if pending:
        uploaded_urls = upload_files_and_get_urls_concurrently(
            file_path_list=[vid_dict[track_id]["vid_path"] for track_id in pending],
            max_workers=8,
        )
        for track_id, url in zip(pending, uploaded_urls):
            vid_dict[track_id]["vid_url"] = url

    with open(video_url_file, "w", encoding="utf-8") as f:
        json.dump(vid_dict, f, ensure_ascii=False, indent=4)
    return vid_dict


def get_class_tag_by_frame(data, idx, fps):
    """Format the hazards that start at the given frame plus a list of all hazards.

    Args:
        data: Parsed `hazard_inspection.json`; must contain 'tag' (list of label strings)
            and 'objects' (list of dicts with optional 'tag_id', 'class_id', 'location',
            'start_frame', 'level').
        idx: Raw video frame index (e.g. the slider value).
        fps: Frame rate used to convert `idx` into the unit of 'start_frame'
            (presumably seconds / one sampled frame per second — confirm with the producer).

    Returns:
        A text block: the hazards whose 'start_frame' equals `int(idx / fps)`, followed by
        all hazards with their class id and start frame.

    Raises:
        ValueError: if `data` is not a dict or lacks 'tag' / 'objects'.
    """
    if not isinstance(data, dict):
        raise ValueError("数据必须是字典类型")
    if 'tag' not in data or 'objects' not in data:
        raise ValueError("数据必须包含 'tag' 和 'objects' 键")

    tag_list = data['tag']
    objects = data['objects']
    idx = int(idx / fps)  # raw frame index -> 'start_frame' unit

    result = []         # hazards starting at this position
    all_class_tag = []  # every hazard, for the summary
    for obj in objects:
        tag_id = obj.get('tag_id', 0)
        class_id = obj.get('class_id', 0)
        # Reject negative ids too; Python would otherwise silently index from the end.
        tag_str = tag_list[tag_id] if 0 <= tag_id < len(tag_list) else f"未知标签({tag_id})"
        location = obj.get('location', '')
        start_frame = obj.get('start_frame', 0)
        level = obj.get('level', '')

        if start_frame == idx:
            result.append(f"{tag_str} | 等级:{level} | 位置: {location}")
        all_class_tag.append(f"{tag_str} | class_id:{class_id} | 等级:{level} | 开始帧:{start_frame} | 位置:{location}")

    return "当前帧隐患:\n" + "\n".join(result) + "\n\n" + "所有隐患对象信息:\n" + "\n".join(all_class_tag)


# ================ Gradio page ================
with gr.Blocks() as demo:
    gr.Markdown("# 📸 隐患排查系统 (sam3 + qwen3.5-27b)")
    with gr.Row():
        with gr.Column():
            initial_files: list[str] = load_file_list()
            default_video = initial_files[0] if initial_files else None
            vid_file = gr.Dropdown(
                label="视频",
                choices=initial_files,
                value=default_video,  # preselect the first video
                interactive=True,
            )
            reload_file_list_button = gr.Button("刷新视频列表")
            reload_file_list_button.click(fn=reload_files, inputs=[], outputs=[vid_file])
            run_sam3 = gr.Checkbox(label="1. 运行 SAM3 模型", value=True)
            run_inspection = gr.Checkbox(label="2. 运行隐患排查", value=True)
            gen_report = gr.Checkbox(label="3. 生成报告", value=True)
            audio_recognition = gr.Checkbox(label="4. 运行音频识别", value=False)
            run_button = gr.Button("运行", variant="primary")
            get_vid_path_btn = gr.Button("获取视频路径")
            full_path_text = gr.Textbox(visible=False)
        with gr.Column():
            preview = gr.Image(label="预览", scale=2)
            img_slider = gr.Slider(label="帧索引", minimum=0, maximum=0, value=0, step=1)
            textbox = gr.Textbox(label="隐患结果", lines=10)

    run_button.click(fn=run, inputs=[vid_file, run_sam3, run_inspection, gen_report], outputs=[img_slider, preview, textbox])
    img_slider.change(fn=update_preview, inputs=[img_slider, vid_file], outputs=[preview, textbox], show_progress="hidden")
    get_vid_path_btn.click(fn=get_full_vid_path, inputs=vid_file, outputs=full_path_text)


if __name__ == "__main__":
    demo.launch(
        debug=True,
        allowed_paths=[VIDEO_FOLDER],
    )