353 lines
12 KiB
Python
353 lines
12 KiB
Python
from lib.json_fun import f_detections_to_objects, load_json_data
|
||
from lib.qwen_fun import get_annnotated_frame_for_ai_without_xyxy, hazard_inspection, merge_conflict_inspection_data, report_generator, search_knowledge_base
|
||
from encodings.punycode import T
|
||
|
||
from lib.qwen_fun_vid import generate_video_to_objects
|
||
"""
|
||
测试 在给定标注框与类别,原始视频经过转换之后,AI能否准确识别物体特征
|
||
"""
|
||
|
||
from tkinter import N
|
||
from datetime import datetime
|
||
import json
|
||
import os
|
||
from lib.qwen_fun import save_json_to_file, upload_files_and_get_urls_concurrently
|
||
from lib.sam3 import SAM3
|
||
import cv2
|
||
import json
|
||
import numpy as np
|
||
from pathlib import Path
|
||
import gradio as gr
|
||
|
||
fps = 0
|
||
VIDEO_FOLDER: str = "input"
|
||
|
||
def load_file_list() -> list[str]:
|
||
"""
|
||
加载指定文件夹下的所有文件名称(不含子目录)
|
||
返回:文件名列表
|
||
"""
|
||
file_names: list[str] = []
|
||
|
||
if not os.path.isdir(VIDEO_FOLDER):
|
||
return file_names
|
||
|
||
for item in os.listdir(VIDEO_FOLDER):
|
||
item_path: str = os.path.join(VIDEO_FOLDER, item)
|
||
if os.path.isfile(item_path):
|
||
file_names.append(item)
|
||
|
||
return file_names
|
||
|
||
def reload_files():
|
||
"""
|
||
刷新文件列表并设置默认值
|
||
"""
|
||
file_list = load_file_list()
|
||
default_value = file_list[0] if file_list else None
|
||
return gr.update(choices=file_list, value=default_value)
|
||
|
||
def get_full_vid_path(vid_file: str) -> str:
|
||
"""
|
||
获取视频绝对路径
|
||
"""
|
||
full_path: str = os.path.join(os.getcwd(), VIDEO_FOLDER, vid_file)
|
||
gr.Info(full_path)
|
||
return full_path
|
||
|
||
def update_preview(frame_idx: int, vid_file: str):
|
||
vid_name: str = Path(vid_file).stem
|
||
output_dir: str = f"output/{vid_name}"
|
||
global fps
|
||
idx = int(frame_idx // fps)
|
||
img_path = f"{output_dir}/boxes/frame_{idx:04d}.jpg"
|
||
|
||
with open(f"{output_dir}/hazard_inspection.json", "r", encoding="utf-8") as f:
|
||
result = f.read()
|
||
|
||
dic = json.loads(result)
|
||
class_tag = get_class_tag_by_frame(dic, frame_idx, fps)
|
||
|
||
return img_path, class_tag
|
||
|
||
def run(vid_file: str, run_sam3: bool = True, run_inspection: bool = True, gen_report: bool = True):
|
||
vid_name = vid_file.split(".")[0] # 视频名称(不含后缀)
|
||
vid_end = f".{vid_file.split('.')[-1]}" # 视频后缀
|
||
use_url_cache = True # 是否使用 URL 缓存,避免重复上传视频
|
||
enable_thinking = True # 是否启用思考模式
|
||
|
||
run_vid_process = True # 是否运行视频处理流程(提取物体视频)
|
||
|
||
input_video_path = f"input/{vid_name}{vid_end}"
|
||
output_dir = f"output/{vid_name}"
|
||
frame_detections_path = f"{output_dir}/frame_detections.json"
|
||
objects_json_path = f"{output_dir}/objects.json"
|
||
vid_dir = f"{output_dir}/obj_vids"
|
||
json_result: dict = {}
|
||
|
||
rule_dict: dict = load_json_data('知识库/rule.json')
|
||
video_url_file = f"{output_dir}/video_url.json"
|
||
time_data: dict = {
|
||
# 保存开始时间字符串
|
||
"start_time": str(datetime.now())
|
||
}
|
||
interval = 1
|
||
annotated_frames: SAM3 = SAM3()
|
||
|
||
# 获取总帧数与帧率
|
||
cap = cv2.VideoCapture(input_video_path)
|
||
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||
global fps
|
||
fps = int(cap.get(cv2.CAP_PROP_FPS))
|
||
cap.release()
|
||
|
||
|
||
if output_dir:
|
||
os.makedirs(output_dir, exist_ok=True) # exist_ok=True 防止重复创建报错
|
||
with open(f"{output_dir}/time.json", "w") as f:
|
||
f.write(json.dumps(time_data, ensure_ascii=False, indent=4))
|
||
|
||
#================获取物体信息=================
|
||
#这里可以换成你的yolo运行逻辑,输出文件保存在frame_detections_path
|
||
|
||
if run_sam3: # 开关判定
|
||
# 针对厂房防火分区
|
||
annotated_frames.run(input_video_path, output_dir, "lib/class_list/1.厂房防火.json")
|
||
else: # 不运行sam3,直接加载之前的结果
|
||
annotated_frames.load_from_json(frame_detections_path)
|
||
# print(annotated_frames.data)
|
||
|
||
#================隐患检查=================
|
||
if run_inspection: # 开关判定
|
||
|
||
|
||
# 提取物体信息
|
||
f_detections_to_objects(
|
||
frame_detections_path,
|
||
objects_json_path
|
||
)
|
||
obj = json.load(open(objects_json_path, "r", encoding="utf-8"))
|
||
class_list = obj["class_list"]
|
||
obj_dict = obj["track_id_list"]
|
||
|
||
if run_vid_process:
|
||
# 生成物体视频
|
||
generate_video_to_objects(
|
||
obj_dict,
|
||
input_video_path,
|
||
output_dir=vid_dir,
|
||
)
|
||
|
||
# 上传视频并获取 URL
|
||
vid_dict = get_vid_dict(vid_dir, obj_dict, video_url_file, use_url_cache)
|
||
|
||
_,json_result_str = hazard_inspection(
|
||
output_dir,
|
||
obj_dict,
|
||
rule_dict,
|
||
class_list,
|
||
vid_dict,
|
||
fps=fps,
|
||
enable_thinking=enable_thinking
|
||
)
|
||
json_result: dict = json.loads(json_result_str)
|
||
|
||
else: # 不运行隐患检查,直接加载之前的结果
|
||
with open(f"{output_dir}/hazard_inspection.json", "r", encoding="utf-8") as f:
|
||
json_result = json.load(f)
|
||
|
||
# 保存结束时间字符串
|
||
time_data["end_time"] = str(datetime.now())
|
||
with open(f"{output_dir}/time.json", "w") as f:
|
||
f.write(json.dumps(time_data, ensure_ascii=False, indent=4))
|
||
|
||
|
||
|
||
#================生成报告=================
|
||
|
||
if gen_report:
|
||
with open(f"知识库/rule.json", "r", encoding="utf-8") as f:
|
||
rule_definitions = json.load(f)
|
||
|
||
report_generator(
|
||
video_path=input_video_path, # 视频文件路径
|
||
detection_data=annotated_frames.data(), # 物体检测数据
|
||
hazard_results=json_result, # 隐患检查结果
|
||
rule_definitions=rule_definitions, # 规则定义
|
||
output_path=output_dir, # 输出文件夹
|
||
frame_interval=interval, # 帧间隔(可根据实际视频帧率调整)
|
||
)
|
||
|
||
#================保存结束时间=================
|
||
# 保存结束时间字符串
|
||
time_data["end_time"] = str(datetime.now())
|
||
with open(f"{output_dir}/time.json", "w") as f:
|
||
f.write(json.dumps(time_data, ensure_ascii=False, indent=4))
|
||
|
||
#================更新预览=================
|
||
|
||
|
||
img_path, class_tag = update_preview(0, vid_name)
|
||
return gr.update(maximum=total_frames-1, value=0, step = interval), img_path, class_tag, json_result
|
||
|
||
|
||
|
||
def get_vid_dict(vid_dir: str, obj_dict: dict, video_url_file: str, use_url_cache: bool) -> dict:
|
||
"""获取视频字典
|
||
参数:
|
||
vid_dir: 物体视频存放目录
|
||
obj_dict: 物品字典,键为物品TrackID,值为物品信息
|
||
video_url_file: 视频URL文件路径
|
||
use_url_cache: 是否使用缓存的URL
|
||
返回:
|
||
视频字典,键为物品TrackID,值为视频本地地址和视频URL
|
||
"""
|
||
|
||
vid_dict = {}
|
||
|
||
# 检查URL是否已存在
|
||
if use_url_cache:
|
||
if os.path.exists(video_url_file):
|
||
try:
|
||
with open(video_url_file, "r", encoding="utf-8") as f:
|
||
loaded_data = json.load(f)
|
||
# 验证数据结构:确保值是字典且包含必要键
|
||
for track_id, value in loaded_data.items():
|
||
if isinstance(value, dict) and "vid_path" in value and "vid_url" in value:
|
||
vid_dict[track_id] = value
|
||
else:
|
||
print(f"跳过无效的视频信息: {track_id} -> {value}")
|
||
except Exception as e:
|
||
print(f"读取URL文件失败: {e}")
|
||
|
||
# 初始化未存在的track_id
|
||
for track_id in obj_dict.keys():
|
||
if track_id in vid_dict:
|
||
continue
|
||
try:
|
||
vid_path = f"{vid_dir}/obj_{int(track_id):03d}.mp4"
|
||
vid_dict[track_id] = {"vid_path": vid_path, "vid_url": None}
|
||
except (ValueError, TypeError):
|
||
print(f"无效的track_id: {track_id},跳过")
|
||
|
||
# 找到未上传的视频
|
||
unuploaded_vids = []
|
||
for track_id, vid_info in vid_dict.items():
|
||
if isinstance(vid_info, dict) and vid_info.get("vid_url") is None:
|
||
unuploaded_vids.append(vid_info["vid_path"])
|
||
|
||
# 上传未上传的视频并获取 URL
|
||
uploaded_urls = upload_files_and_get_urls_concurrently(
|
||
file_path_list=unuploaded_vids,
|
||
max_workers=8
|
||
)
|
||
|
||
# 更新视频URL
|
||
for track_id, vid_info in vid_dict.items():
|
||
if isinstance(vid_info, dict) and vid_info.get("vid_url") is None:
|
||
if vid_info["vid_path"] in unuploaded_vids:
|
||
idx = unuploaded_vids.index(vid_info["vid_path"])
|
||
if idx < len(uploaded_urls):
|
||
vid_info["vid_url"] = uploaded_urls[idx]
|
||
|
||
# 保存字典为 JSON 文件
|
||
with open(video_url_file, "w", encoding="utf-8") as f:
|
||
json.dump(vid_dict, f, ensure_ascii=False, indent=4)
|
||
|
||
return vid_dict
|
||
|
||
|
||
# 创建 Gradio 页面
|
||
with gr.Blocks() as demo:
|
||
gr.Markdown("# 📸 隐患排查系统 (sam3 + qwen3.5-27b)")
|
||
|
||
with gr.Row():
|
||
with gr.Column():
|
||
|
||
initial_files: list[str] = load_file_list()
|
||
default_video = initial_files[0] if initial_files else None
|
||
vid_file = gr.Dropdown(
|
||
label="视频",
|
||
choices=initial_files,
|
||
value=default_video, # 默认选中第一个
|
||
interactive=True,
|
||
)
|
||
reload_file_list_button = gr.Button("刷新视频列表")
|
||
reload_file_list_button.click(fn=reload_files, inputs=[], outputs=[vid_file])
|
||
|
||
run_sam3 = gr.Checkbox(label="1. 运行 SAM3 模型", value=True)
|
||
run_inspection = gr.Checkbox(label="2. 运行隐患排查", value=True)
|
||
gen_report = gr.Checkbox(label="3. 生成报告", value=True)
|
||
audio_recognition = gr.Checkbox(label="4. 运行音频识别", value=False)
|
||
|
||
run_button = gr.Button("运行", variant="primary")
|
||
|
||
get_vid_path_btn = gr.Button("获取视频路径")
|
||
full_path_text = gr.Textbox(visible=False)
|
||
|
||
with gr.Column():
|
||
preview = gr.Image(label="预览", scale=2)
|
||
img_slider = gr.Slider(label="帧索引", minimum=0, maximum=0, value=0, step=1)
|
||
textbox = gr.Textbox(label="隐患结果", lines=10)
|
||
jsonbox = gr.JSON(label="隐患结果json")
|
||
|
||
|
||
|
||
run_button.click(fn=run, inputs=[vid_file, run_sam3, run_inspection, gen_report], outputs=[img_slider, preview, textbox, jsonbox])
|
||
img_slider.change(fn=update_preview, inputs=[img_slider, vid_file], outputs=[preview, textbox], show_progress="hidden")
|
||
get_vid_path_btn.click(fn=get_full_vid_path, inputs=vid_file, outputs=full_path_text)
|
||
|
||
|
||
def get_class_tag_by_frame(data, idx, fps):
|
||
"""
|
||
根据给定的帧索引 (idx),返回在该帧范围内的所有对象的 class:tag 信息。
|
||
|
||
参数:
|
||
data (dict): 包含 'tag', 'objects' 的字典数据(hazard_inspection.json格式)
|
||
idx (int): 需要查询的帧索引
|
||
|
||
返回:
|
||
str: 符合条件的 class:tag 列表,多个结果之间用换行符分隔。如果没有匹配项,返回空字符串。
|
||
"""
|
||
# 参数检查
|
||
if not isinstance(data, dict):
|
||
raise ValueError("数据必须是字典类型")
|
||
if 'tag' not in data or 'objects' not in data:
|
||
raise ValueError("数据必须包含 'tag' 和 'objects' 键")
|
||
|
||
tag_list = data['tag']
|
||
objects = data['objects']
|
||
|
||
interval = fps
|
||
idx = int(idx / interval) #转换
|
||
|
||
# 用于存储符合条件的 class:tag 字符串
|
||
result = []
|
||
all_class_tag = []
|
||
|
||
# 遍历每个物体
|
||
for obj in objects:
|
||
# 根据 tag_id 获取对应的隐患标签字符串
|
||
tag_id = obj.get('tag_id', 0)
|
||
class_id = obj.get('class_id', 0)
|
||
tag_str = tag_list[tag_id] if tag_id < len(tag_list) else f"未知标签({tag_id})"
|
||
location = obj.get('location', '')
|
||
start_frame = obj.get('start_frame', 0)
|
||
level = obj.get('level', '')
|
||
|
||
# 检查帧范围是否包含 idx
|
||
if start_frame == idx:
|
||
result.append(f"{tag_str} | 等级:{level} | 位置: {location}")
|
||
all_class_tag.append(f"{tag_str} | class_id:{class_id} | 等级:{level} | 开始帧:{start_frame} | 位置:{location}")
|
||
|
||
# 使用换行符连接所有结果
|
||
output = f"当前帧隐患:\n"+"\n".join(result)+"\n\n"+"所有隐患对象信息:\n"+"\n".join(all_class_tag)
|
||
return output
|
||
|
||
|
||
# 启动应用
|
||
if __name__ == "__main__":
|
||
demo.launch(
|
||
debug=True,
|
||
allowed_paths=[VIDEO_FOLDER]
|
||
) |