HazardInspector/Qwen_app.py

353 lines
12 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from lib.json_fun import f_detections_to_objects, load_json_data
from lib.qwen_fun import get_annnotated_frame_for_ai_without_xyxy, hazard_inspection, merge_conflict_inspection_data, report_generator, search_knowledge_base
from encodings.punycode import T
from lib.qwen_fun_vid import generate_video_to_objects
"""
测试 在给定标注框与类别原始视频经过转换之后AI能否准确识别物体特征
"""
from tkinter import N
from datetime import datetime
import json
import os
from lib.qwen_fun import save_json_to_file, upload_files_and_get_urls_concurrently
from lib.sam3 import SAM3
import cv2
import json
import numpy as np
from pathlib import Path
import gradio as gr
# Frame rate of the most recently processed video. It is assigned by run()
# (via `global fps`) and read by update_preview(); it stays 0 until run() has
# executed in this process.
fps: int = 0
# Folder holding the input videos, relative to the current working directory.
VIDEO_FOLDER: str = "input"
def load_file_list() -> list[str]:
    """List the regular files located directly in VIDEO_FOLDER.

    Sub-directories are ignored and nothing is searched recursively.

    Returns:
        Bare file names in ``os.listdir`` order, or an empty list when
        VIDEO_FOLDER is not an existing directory.
    """
    if not os.path.isdir(VIDEO_FOLDER):
        return []
    return [
        entry
        for entry in os.listdir(VIDEO_FOLDER)
        if os.path.isfile(os.path.join(VIDEO_FOLDER, entry))
    ]
def reload_files():
    """Re-scan VIDEO_FOLDER and refresh the video dropdown.

    Returns:
        A ``gr.update`` with the new choices. The first file is pre-selected,
        or ``None`` is selected when the folder has no files.
    """
    files = load_file_list()
    first_file = None
    if files:
        first_file = files[0]
    return gr.update(choices=files, value=first_file)
def get_full_vid_path(vid_file: str) -> str:
    """Build the absolute path of a video and show it as a Gradio notification.

    Args:
        vid_file: File name selected in the dropdown.

    Returns:
        ``<cwd>/<VIDEO_FOLDER>/<vid_file>`` joined with ``os.path.join``.
    """
    working_dir = os.getcwd()
    absolute_path: str = os.path.join(working_dir, VIDEO_FOLDER, vid_file)
    gr.Info(absolute_path)  # pop-up toast in the UI
    return absolute_path
def update_preview(frame_idx: int, vid_file: str):
    """Return the preview image and hazard text for a slider position.

    Args:
        frame_idx: Frame index selected on the slider.
        vid_file: Video file name. Only its stem is used, to locate
            ``output/<stem>``.

    Returns:
        ``(img_path, class_tag)``:
        - ``img_path`` is ``output/<stem>/boxes/frame_XXXX.jpg``, where XXXX is
          ``frame_idx // fps``.
        - ``class_tag`` is the text from ``get_class_tag_by_frame``, or ``""``
          when ``hazard_inspection.json`` does not exist.
        ``(None, "")`` is returned when no video is selected or the global
        ``fps`` is not positive (run() has not set it).
    """
    if not vid_file or fps <= 0:
        # fps is only assigned by run(); dividing by the initial 0 (or by the
        # 0 that an unreadable video yields) would raise ZeroDivisionError.
        return None, ""
    vid_name: str = Path(vid_file).stem
    output_dir: str = f"output/{vid_name}"
    idx = int(frame_idx // fps)
    img_path = f"{output_dir}/boxes/frame_{idx:04d}.jpg"
    inspection_path = f"{output_dir}/hazard_inspection.json"
    if not os.path.isfile(inspection_path):
        # No inspection result yet: show the frame without hazard text.
        return img_path, ""
    with open(inspection_path, "r", encoding="utf-8") as f:
        dic = json.load(f)
    class_tag = get_class_tag_by_frame(dic, frame_idx, fps)
    return img_path, class_tag
def _write_time_json(output_dir: str, time_data: dict) -> None:
    """Overwrite ``<output_dir>/time.json`` with the run's timing information."""
    with open(f"{output_dir}/time.json", "w", encoding="utf-8") as f:
        json.dump(time_data, f, ensure_ascii=False, indent=4)


def run(vid_file: str, run_sam3: bool = True, run_inspection: bool = True, gen_report: bool = True):
    """Run the detection -> hazard inspection -> report pipeline for one video.

    Args:
        vid_file: File name (with extension) of a video inside VIDEO_FOLDER.
        run_sam3: Run SAM3 detection. When False, the detections saved in
            ``frame_detections.json`` by an earlier run are reloaded.
        run_inspection: Run the hazard inspection. When False,
            ``hazard_inspection.json`` from an earlier run is reloaded.
        gen_report: Generate the report with ``report_generator``.

    Returns:
        A 4-tuple for the Gradio outputs:
        - a slider update sized to the video,
        - the preview image path for frame 0,
        - the hazard text for frame 0,
        - the hazard-inspection result dict.

    Raises:
        gr.Error: If no video is selected or its frame rate cannot be read.
    """
    global fps
    if not vid_file:
        raise gr.Error("请先选择视频")

    # Path.stem keeps names such as "a.b.mp4" intact ("a.b"). split(".") would
    # cut them to "a" and point at the wrong input/output paths. Path.stem is
    # also what update_preview uses to find the output folder.
    vid_name = Path(vid_file).stem
    use_url_cache = True  # reuse URLs stored in video_url.json instead of re-uploading
    enable_thinking = True  # enable the model's thinking mode during inspection
    run_vid_process = True  # regenerate the per-object clips before uploading
    input_video_path = f"{VIDEO_FOLDER}/{vid_file}"
    output_dir = f"output/{vid_name}"
    frame_detections_path = f"{output_dir}/frame_detections.json"
    objects_json_path = f"{output_dir}/objects.json"
    vid_dir = f"{output_dir}/obj_vids"
    video_url_file = f"{output_dir}/video_url.json"
    rule_dict: dict = load_json_data('知识库/rule.json')
    time_data: dict = {"start_time": str(datetime.now())}
    interval = 1  # slider step and report frame interval

    # Read frame count and frame rate. fps is published through the module
    # global so update_preview can map slider positions to frame images.
    cap = cv2.VideoCapture(input_video_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = int(cap.get(cv2.CAP_PROP_FPS))
    cap.release()
    if fps <= 0:
        # An fps of 0 means the video could not be read; every later step
        # would fail, and the preview would divide by zero.
        raise gr.Error(f"无法读取视频帧率: {input_video_path}")

    annotated_frames: SAM3 = SAM3()
    os.makedirs(output_dir, exist_ok=True)
    _write_time_json(output_dir, time_data)

    # ================ Object detection ================
    # Any detector (e.g. YOLO) can replace this step as long as it writes
    # its output to frame_detections_path.
    if run_sam3:
        # Class list for factory-building fire compartments.
        annotated_frames.run(input_video_path, output_dir, "lib/class_list/1.厂房防火.json")
    else:
        annotated_frames.load_from_json(frame_detections_path)

    # ================ Hazard inspection ================
    json_result: dict
    if run_inspection:
        f_detections_to_objects(frame_detections_path, objects_json_path)
        with open(objects_json_path, "r", encoding="utf-8") as f:
            obj = json.load(f)
        class_list = obj["class_list"]
        obj_dict = obj["track_id_list"]
        if run_vid_process:
            # One clip per tracked object.
            generate_video_to_objects(
                obj_dict,
                input_video_path,
                output_dir=vid_dir,
            )
        # Upload the clips (or reuse cached URLs).
        vid_dict = get_vid_dict(vid_dir, obj_dict, video_url_file, use_url_cache)
        _, json_result_str = hazard_inspection(
            output_dir,
            obj_dict,
            rule_dict,
            class_list,
            vid_dict,
            fps=fps,
            enable_thinking=enable_thinking,
        )
        json_result = json.loads(json_result_str)
    else:
        with open(f"{output_dir}/hazard_inspection.json", "r", encoding="utf-8") as f:
            json_result = json.load(f)
    time_data["end_time"] = str(datetime.now())
    _write_time_json(output_dir, time_data)

    # ================ Report ================
    if gen_report:
        with open("知识库/rule.json", "r", encoding="utf-8") as f:
            rule_definitions = json.load(f)
        report_generator(
            video_path=input_video_path,
            detection_data=annotated_frames.data(),
            hazard_results=json_result,
            rule_definitions=rule_definitions,
            output_path=output_dir,
            frame_interval=interval,
        )
    # Record the final end time, including report generation.
    time_data["end_time"] = str(datetime.now())
    _write_time_json(output_dir, time_data)

    # ================ Preview ================
    img_path, class_tag = update_preview(0, vid_file)
    # max() keeps the slider valid when the frame count is reported as 0.
    slider_update = gr.update(maximum=max(total_frames - 1, 0), value=0, step=interval)
    return slider_update, img_path, class_tag, json_result
def _load_cached_vid_urls(video_url_file: str) -> dict:
    """Read the URL cache file and return only its well-formed entries.

    This is best effort: an unreadable or malformed file is reported and
    treated as an empty cache, so the caller simply uploads again.
    """
    try:
        with open(video_url_file, "r", encoding="utf-8") as f:
            loaded_data = json.load(f)
        entries = list(loaded_data.items())
    except (OSError, ValueError, AttributeError) as e:
        # OSError: file unreadable.
        # ValueError: invalid JSON or bad encoding.
        # AttributeError: the top-level JSON value is not an object.
        print(f"读取URL文件失败: {e}")
        return {}
    cached = {}
    for track_id, value in entries:
        # Keep only dict entries that carry both required keys.
        if isinstance(value, dict) and "vid_path" in value and "vid_url" in value:
            cached[track_id] = value
        else:
            print(f"跳过无效的视频信息: {track_id} -> {value}")
    return cached


def get_vid_dict(vid_dir: str, obj_dict: dict, video_url_file: str, use_url_cache: bool) -> dict:
    """Map every object track ID to its clip path and uploaded clip URL.

    Args:
        vid_dir: Directory holding the per-object clips ``obj_XXX.mp4``.
        obj_dict: Object mapping whose keys are track IDs. Keys must be
            convertible with ``int()``; other keys are skipped.
        video_url_file: JSON file used as the URL cache. It is always rewritten
            with the returned mapping.
        use_url_cache: Whether to reuse entries already stored in
            ``video_url_file``.

    Returns:
        ``{track_id: {"vid_path": ..., "vid_url": ...}}``. ``vid_url`` stays
        ``None`` for clips the uploader returned no URL for.
    """
    vid_dict: dict = {}
    if use_url_cache and os.path.exists(video_url_file):
        vid_dict.update(_load_cached_vid_urls(video_url_file))

    # Add a not-yet-uploaded entry for every track ID that is not cached.
    for track_id in obj_dict:
        if track_id in vid_dict:
            continue
        try:
            vid_path = f"{vid_dir}/obj_{int(track_id):03d}.mp4"
        except (ValueError, TypeError):
            print(f"无效的track_id: {track_id},跳过")
            continue
        vid_dict[track_id] = {"vid_path": vid_path, "vid_url": None}

    # Distinct clip paths that still lack a URL, in insertion order, so the
    # same file is never uploaded twice.
    pending_paths = list(dict.fromkeys(
        info["vid_path"] for info in vid_dict.values() if info.get("vid_url") is None
    ))
    if pending_paths:  # skip the upload service entirely when all URLs are known
        uploaded_urls = upload_files_and_get_urls_concurrently(
            file_path_list=pending_paths,
            max_workers=8,
        )
        # The uploader's result is treated as positional (i-th URL belongs to
        # the i-th path). zip() drops paths beyond the returned list, leaving
        # their vid_url as None.
        url_by_path = dict(zip(pending_paths, uploaded_urls))
        for info in vid_dict.values():
            if info.get("vid_url") is None and info["vid_path"] in url_by_path:
                info["vid_url"] = url_by_path[info["vid_path"]]

    # Persist the merged mapping so later runs can reuse the URLs.
    with open(video_url_file, "w", encoding="utf-8") as f:
        json.dump(vid_dict, f, ensure_ascii=False, indent=4)
    return vid_dict
# Build the Gradio page. This runs at import time, so `demo` exists at
# module level.
with gr.Blocks() as demo:
    gr.Markdown("# 📸 隐患排查系统 (sam3 + qwen3.5-27b)")
    with gr.Row():
        # Left column: video selection, pipeline switches and action buttons.
        with gr.Column():
            initial_files: list[str] = load_file_list()
            default_video = initial_files[0] if initial_files else None
            vid_file = gr.Dropdown(
                label="视频",
                choices=initial_files,
                value=default_video,  # select the first file by default
                interactive=True,
            )
            reload_file_list_button = gr.Button("刷新视频列表")
            reload_file_list_button.click(fn=reload_files, inputs=[], outputs=[vid_file])
            # Pipeline stage switches, passed to run() as its boolean arguments.
            run_sam3 = gr.Checkbox(label="1. 运行 SAM3 模型", value=True)
            run_inspection = gr.Checkbox(label="2. 运行隐患排查", value=True)
            gen_report = gr.Checkbox(label="3. 生成报告", value=True)
            # NOTE(review): audio_recognition is not among run_button.click's
            # inputs below, so this checkbox currently has no effect.
            audio_recognition = gr.Checkbox(label="4. 运行音频识别", value=False)
            run_button = gr.Button("运行", variant="primary")
            get_vid_path_btn = gr.Button("获取视频路径")
            # Hidden output sink for get_full_vid_path. The path is shown to
            # the user through gr.Info inside that function.
            full_path_text = gr.Textbox(visible=False)
        # Right column: frame preview and hazard results.
        with gr.Column():
            preview = gr.Image(label="预览", scale=2)
            # The maximum stays 0 until run() returns a slider update sized
            # to the video.
            img_slider = gr.Slider(label="帧索引", minimum=0, maximum=0, value=0, step=1)
            textbox = gr.Textbox(label="隐患结果", lines=10)
            jsonbox = gr.JSON(label="隐患结果json")
    run_button.click(fn=run, inputs=[vid_file, run_sam3, run_inspection, gen_report], outputs=[img_slider, preview, textbox, jsonbox])
    # Redraw the preview image and hazard text when the slider value changes.
    img_slider.change(fn=update_preview, inputs=[img_slider, vid_file], outputs=[preview, textbox], show_progress="hidden")
    get_vid_path_btn.click(fn=get_full_vid_path, inputs=vid_file, outputs=full_path_text)
def get_class_tag_by_frame(data, idx, fps):
    """Describe the hazards matching a frame, plus a summary of all hazard objects.

    Args:
        data (dict): Parsed ``hazard_inspection.json``. It must contain:
            - ``'tag'``: a list of hazard labels indexed by ``tag_id``;
            - ``'objects'``: a list of hazard dicts with the optional keys
              ``tag_id``, ``class_id``, ``location``, ``start_frame`` and
              ``level``.
        idx (int): Frame index to query.
        fps (int | float): Divisor applied to ``idx``. Must be non-zero.
            ``int(idx / fps)`` is compared for exact equality with each
            object's ``start_frame``.

    Returns:
        str: ``"当前帧隐患:"`` followed by one line per object whose
        ``start_frame`` equals the converted index. Then a blank line and
        ``"所有隐患对象信息:"``, followed by one line per object.
        Unknown or invalid ``tag_id`` values are rendered as
        ``"未知标签(<tag_id>)"``.

    Raises:
        ValueError: If ``data`` is not a dict or lacks ``'tag'``/``'objects'``.
    """
    if not isinstance(data, dict):
        raise ValueError("数据必须是字典类型")
    if 'tag' not in data or 'objects' not in data:
        raise ValueError("数据必须包含 'tag' 和 'objects'")
    tag_list = data['tag']
    # The slider index is divided by fps. start_frame is presumably stored in
    # the same per-fps units — confirm against the hazard_inspection output.
    target = int(idx / fps)

    def _tag_name(tag_id):
        # Only in-range, non-negative integer IDs are valid. A negative ID
        # would otherwise silently pick a label from the end of the list.
        if isinstance(tag_id, int) and 0 <= tag_id < len(tag_list):
            return tag_list[tag_id]
        return f"未知标签({tag_id})"

    current = []  # hazards that start exactly at the requested frame
    summary = []  # every hazard object, with full details
    for obj in data['objects']:
        tag_str = _tag_name(obj.get('tag_id', 0))
        class_id = obj.get('class_id', 0)
        location = obj.get('location', '')
        start_frame = obj.get('start_frame', 0)
        level = obj.get('level', '')
        if start_frame == target:
            current.append(f"{tag_str} | 等级:{level} | 位置: {location}")
        summary.append(f"{tag_str} | class_id:{class_id} | 等级:{level} | 开始帧:{start_frame} | 位置:{location}")
    return (
        "当前帧隐患:\n" + "\n".join(current)
        + "\n\n" + "所有隐患对象信息:\n" + "\n".join(summary)
    )
# Start the application when executed as a script.
if __name__ == "__main__":
    # VIDEO_FOLDER is listed in allowed_paths so Gradio may serve files from it.
    launch_options = {
        "debug": True,
        "allowed_paths": [VIDEO_FOLDER],
    }
    demo.launch(**launch_options)