# HazardInspector/Qwen_app.py
#
# (Repository viewer residue: 302 lines · 11 KiB · Python · Raw | Blame | History)
#
# Viewer note: this file contains ambiguous Unicode characters that might be
# confused with other characters. If that is intentional, the warning can be
# safely ignored (the viewer's Escape button reveals them).

from lib.qwen_fun import get_annnotated_frame_for_ai_without_xyxy, hazard_inspection, merge_conflict_inspection_data, report_generator, search_knowledge_base
from encodings.punycode import T
"""
测试 在给定标注框与类别原始视频经过转换之后AI能否准确识别物体特征
"""
from tkinter import N
from datetime import datetime
import json
import os
from lib.qwen_fun import save_json_to_file, upload_files_and_get_urls_concurrently
from lib.sam3 import SAM3
import cv2
import json
import numpy as np
from pathlib import Path
import gradio as gr
# FPS of the most recently processed video. run() sets it and update_preview()
# uses it to map slider frame indices onto the sampled frames; 0 means no video
# has been processed in this session yet.
fps = 0

# Folder (relative to the working directory) that holds the input videos.
VIDEO_FOLDER: str = "input"


def load_file_list() -> list[str]:
    """Return the names of the regular files directly inside VIDEO_FOLDER.

    Sub-directories are skipped. Names are sorted so the dropdown order (and
    therefore its default selection) does not depend on ``os.listdir`` order.

    Returns:
        Sorted file names, or an empty list when VIDEO_FOLDER does not exist.
    """
    if not os.path.isdir(VIDEO_FOLDER):
        return []
    return sorted(
        item
        for item in os.listdir(VIDEO_FOLDER)
        if os.path.isfile(os.path.join(VIDEO_FOLDER, item))
    )


def get_full_vid_path(vid_file: str) -> str:
    """Return the absolute path of ``vid_file`` inside VIDEO_FOLDER.

    The path is also shown to the user as a Gradio info toast.

    Raises:
        gr.Error: if no video is selected.
    """
    if not vid_file:
        raise gr.Error("请先选择视频")
    full_path: str = os.path.join(os.getcwd(), VIDEO_FOLDER, vid_file)
    gr.Info(full_path)
    return full_path


def update_preview(frame_idx: int, vid_file: str):
    """Return the preview image and hazard text for the frame at ``frame_idx``.

    Args:
        frame_idx: Frame index chosen on the slider (original-video frames).
        vid_file: Video file name; only its stem is used, to locate
            ``output/<stem>/``.

    Returns:
        ``(img_path, class_tag)``. ``img_path`` is
        ``output/<stem>/boxes/frame_XXXX.jpg`` for sample ``frame_idx // fps``,
        or ``None`` if that image does not exist. ``class_tag`` is the text from
        get_class_tag_by_frame(), or ``""`` when ``hazard_inspection.json`` has
        not been written (e.g. the inspection step was skipped).
        ``(None, "")`` is returned when no video is selected or no video has
        been processed yet (``fps`` still 0), instead of dividing by zero.
    """
    if not vid_file or fps <= 0:
        return None, ""
    vid_name: str = Path(vid_file).stem
    output_dir: str = f"output/{vid_name}"
    # Frames are sampled once every `fps` frames, so this is the sample index.
    idx = int(frame_idx // fps)
    img_path: str | None = f"{output_dir}/boxes/frame_{idx:04d}.jpg"
    if not os.path.isfile(img_path):
        img_path = None
    hazard_file = f"{output_dir}/hazard_inspection.json"
    if not os.path.isfile(hazard_file):
        return img_path, ""
    with open(hazard_file, "r", encoding="utf-8") as f:
        dic = json.load(f)
    return img_path, get_class_tag_by_frame(dic, frame_idx, fps)
def _dump_json(path: str, data) -> None:
    """Write ``data`` to ``path`` as UTF-8 JSON (non-ASCII kept, 4-space indent)."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=4)


def _probe_video(vid_path: str) -> tuple[int, int]:
    """Return ``(fps, total_frames)`` of the video at ``vid_path``.

    Raises:
        gr.Error: if OpenCV cannot open the file or reports an FPS below 1.
            Every later step uses the FPS as a sampling interval / divisor.
    """
    cap = cv2.VideoCapture(vid_path)
    try:
        if not cap.isOpened():
            raise gr.Error(f"无法打开视频: {vid_path}")
        video_fps = int(cap.get(cv2.CAP_PROP_FPS))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        cap.release()
    if video_fps <= 0:
        raise gr.Error(f"无法读取视频帧率: {vid_path}")
    return video_fps, total_frames


def _load_url_cache(cache_file: str) -> dict:
    """Return the ``{video stem: URL}`` mapping stored in ``cache_file``.

    A missing file yields ``{}``. An unreadable or malformed file is reported
    and also yields ``{}``, so the caller falls back to uploading again.
    """
    if not os.path.exists(cache_file):
        return {}
    try:
        with open(cache_file, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as e:  # ValueError covers JSONDecodeError
        print(f"读取URL文件失败: {e}")
        return {}
    return data if isinstance(data, dict) else {}


def _resolve_video_url(vid_path: str, vid_name: str, output_dir: str) -> str:
    """Return a remote URL for the video, uploading it only when needed.

    Paths starting with "oss" are returned unchanged. Otherwise the cache
    ``<output_dir>/video_url.json`` is consulted first. On a miss the file is
    uploaded via upload_files_and_get_urls_concurrently() and the new URL is
    written back into the cache.

    Raises:
        ValueError: if the upload does not produce a URL.
    """
    if vid_path.startswith("oss"):
        return vid_path

    video_url_file = f"{output_dir}/video_url.json"
    url_data = _load_url_cache(video_url_file)
    cached_url = url_data.get(vid_name)
    if cached_url:
        print(f"使用已存在的URL: {cached_url}")
        return cached_url

    print(f"上传视频文件: {vid_path}")
    urls = upload_files_and_get_urls_concurrently(
        file_path_list=[vid_path],
        max_workers=8,
    )
    video_url = urls[0] if urls else None
    if not video_url:
        raise ValueError("视频上传失败,无法获取 URL")

    # Cache as {video stem: URL} so later runs can skip the upload.
    url_data[vid_name] = video_url
    _dump_json(video_url_file, url_data)
    print(f"URL已保存: {video_url}")
    return video_url


def run(vid_file: str, run_sam3: bool = True, run_inspection: bool = True, gen_report: bool = True,):
    """Run the hazard-inspection pipeline on one video and refresh the preview.

    Each step can be toggled from the UI:
      1. SAM3 detection (or reload of ``frame_all.json`` when skipped). The
         AI-visible frame data is then exported to ``frame_all_ai.json``.
      2. Hazard inspection via hazard_inspection() on the video's remote URL.
         The parsed result goes to ``hazard_inspection.json`` and the other
         returned value (presumably the model's reasoning) to
         ``hazard_inspection_reason.json``.
      3. Report generation via report_generator().
    Start/end timestamps are written to ``time.json``. All outputs live in
    ``output/<video stem>/``.

    Args:
        vid_file: File name inside VIDEO_FOLDER, as chosen in the dropdown.
        run_sam3: Run SAM3 detection (else reuse ``frame_all.json``).
        run_inspection: Run the hazard inspection step.
        gen_report: Generate the report. When inspection was skipped, this
            reads an existing ``hazard_inspection.json``.

    Returns:
        ``(slider update, preview image path, hazard text)`` for the Gradio
        outputs.

    Raises:
        gr.Error: if no video is selected or the video cannot be read.
    """
    global fps
    if not vid_file:
        raise gr.Error("请先选择视频")

    # ================ initialisation ================
    vid_name: str = Path(vid_file).stem
    vid_end: str = Path(vid_file).suffix
    vid_path: str = f"./{VIDEO_FOLDER}/{vid_name}{vid_end}"
    fps, total_frames = _probe_video(vid_path)
    output_dir: str = f"output/{vid_name}"
    interval = fps  # sample one frame per second of video
    conf = 0.7
    time_data: dict = {}
    annotated_frames: SAM3 = SAM3()
    result: dict = {}
    os.makedirs(output_dir, exist_ok=True)

    # ================ object detection ================
    time_data["start_time"] = str(datetime.now())
    _dump_json(f"{output_dir}/time.json", time_data)
    if run_sam3:
        # Class list for factory-building fire-protection checks.
        annotated_frames.run(vid_path, output_dir, "lib/class_list/1.厂房防火.json", interval, conf)
    else:
        annotated_frames.load_from_json(f"{output_dir}/frame_all.json")
    # Keep only the part of the detections the AI model is shown.
    ai_frames: dict = get_annnotated_frame_for_ai_without_xyxy(annotated_frames.data(), 1, conf)
    save_json_to_file(ai_frames, f"{output_dir}/frame_all_ai.json")

    # ================ hazard inspection ================
    if run_inspection:
        video_url = _resolve_video_url(vid_path, vid_name, output_dir)
        reason_test: str
        result_test: str
        reason_test, result_test = hazard_inspection(ai_frames, video_url, enable_thinking=True, fps=2)
        result = json.loads(result_test)
        result["class"] = ai_frames["class_list"]
        _dump_json(f"{output_dir}/hazard_inspection.json", result)
        _dump_json(f"{output_dir}/hazard_inspection_reason.json", reason_test)

    # ================ report ================
    if gen_report:
        if not result:
            # Inspection skipped in this run: reuse the previous result.
            with open(f"{output_dir}/hazard_inspection.json", "r", encoding="utf-8") as f:
                result = json.load(f)
        with open("知识库/rule.json", "r", encoding="utf-8") as f:
            rule_definitions = json.load(f)
        report_generator(
            video_path=vid_path,
            detection_data=annotated_frames.data(),
            hazard_results=result,
            rule_definitions=rule_definitions,
            output_path=output_dir,
            frame_interval=interval,
        )

    time_data["end_time"] = str(datetime.now())
    _dump_json(f"{output_dir}/time.json", time_data)

    # ================ refresh preview ================
    # Pass the file name (not the stem) so update_preview derives the same
    # output directory even for names that contain dots.
    img_path, class_tag = update_preview(0, vid_file)
    return gr.update(maximum=max(total_frames - 1, 0), value=0, step=interval), img_path, class_tag
# Build the Gradio page.
with gr.Blocks() as demo:
    gr.Markdown("# 📸 隐患排查系统 (sam3 + qwen3.5-27b)")
    with gr.Row():
        with gr.Column():
            # Video picker populated from VIDEO_FOLDER when the page is built.
            initial_files: list[str] = load_file_list()
            default_video = initial_files[0] if initial_files else None
            vid_file = gr.Dropdown(
                label="视频",
                choices=initial_files,
                value=default_video,  # preselect the first video, if any
                interactive=True,
            )
            # Step toggles, passed to run().
            run_sam3 = gr.Checkbox(label="1. 运行 SAM3 模型", value=True)
            run_inspection = gr.Checkbox(label="2. 运行隐患排查", value=True)
            gen_report = gr.Checkbox(label="3. 生成报告", value=True)
            # NOTE(review): not wired to any handler below yet.
            audio_recognition = gr.Checkbox(label="4. 运行音频识别", value=False)
            run_button = gr.Button("运行", variant="primary")
            get_vid_path_btn = gr.Button("获取视频路径")
            full_path_text = gr.Textbox(visible=False)
        with gr.Column():
            preview = gr.Image(label="预览", scale=2)
            # run() resets the range/step after processing a video.
            img_slider = gr.Slider(label="帧索引", minimum=0, maximum=0, value=0, step=1)
            textbox = gr.Textbox(label="隐患结果", lines=10)
    run_button.click(fn=run, inputs=[vid_file, run_sam3, run_inspection, gen_report], outputs=[img_slider, preview, textbox])
    img_slider.change(fn=update_preview, inputs=[img_slider, vid_file], outputs=[preview, textbox], show_progress="hidden")
    get_vid_path_btn.click(fn=get_full_vid_path, inputs=vid_file, outputs=full_path_text)
def get_class_tag_by_frame(data, idx, fps):
    """Describe the hazard objects that start at the sample containing ``idx``.

    ``idx`` is a frame index of the original video. It is converted to a
    sample index with ``idx // fps`` (one sample every ``fps`` frames) and
    compared with each object's ``start_frame``.

    NOTE(review): only ``start_frame`` equality is checked. If objects also
    carry an end frame, a range check may have been intended — confirm.

    Args:
        data (dict): Inspection result with ``'class'`` (class names),
            ``'tag'`` (tag names) and ``'objects'``. Each object holds
            ``class_id``, ``tag_id``, ``start_frame`` and an optional
            ``location``.
        idx (int): Frame index to query.
        fps (int): Frames per sample (the video FPS). Must be positive.

    Returns:
        str: A "current frame" section with ``class:tag (位置: ...)`` lines for
        objects starting at this sample. It is followed by an "all objects"
        section listing every object, with its start converted back to
        original-video frames. An empty section has no lines.

    Raises:
        ValueError: if ``data`` is not a dict, lacks a required key, or
            ``fps`` is not positive.
    """
    if not isinstance(data, dict):
        raise ValueError("数据必须是字典类型")
    if not all(key in data for key in ("class", "tag", "objects")):
        raise ValueError("数据必须包含 'class', 'tag' 和 'objects'")
    if fps is None or fps <= 0:
        # Would otherwise surface as ZeroDivisionError / TypeError below.
        raise ValueError("fps 必须为正数")

    class_list = data["class"]
    tag_list = data["tag"]
    interval = fps
    sample_idx = int(idx // interval)  # frame index -> sample index

    current: list[str] = []
    everything: list[str] = []
    for obj in data["objects"]:
        label = f"{class_list[obj['class_id']]}:{tag_list[obj['tag_id']]}"
        location = obj.get("location", "")
        start = obj["start_frame"]
        if start == sample_idx:
            current.append(f"{label} (位置: {location})")
        everything.append(f"{label} (开始帧: {start * interval}, 位置: {location})")

    return (
        "当前帧隐患:\n" + "\n".join(current)
        + "\n\n" + "所有对象的 class:tag 信息:\n" + "\n".join(everything)
    )
if __name__ == "__main__":
    # Launch the Gradio app in debug mode. VIDEO_FOLDER is passed in
    # allowed_paths so Gradio may serve files from the video folder.
    launch_options = {"debug": True, "allowed_paths": [VIDEO_FOLDER]}
    demo.launch(**launch_options)