更新新流程

This commit is contained in:
yueliuli 2026-04-23 10:41:40 +08:00
parent 63a9cba36c
commit 67b78ec361
9 changed files with 757 additions and 420 deletions

View File

@ -1,5 +1,8 @@
from lib.json_fun import f_detections_to_objects, load_json_data
from lib.qwen_fun import get_annnotated_frame_for_ai_without_xyxy, hazard_inspection, merge_conflict_inspection_data, report_generator, search_knowledge_base from lib.qwen_fun import get_annnotated_frame_for_ai_without_xyxy, hazard_inspection, merge_conflict_inspection_data, report_generator, search_knowledge_base
from encodings.punycode import T from encodings.punycode import T
from lib.qwen_fun_vid import generate_video_to_objects
""" """
测试 在给定标注框与类别原始视频经过转换之后AI能否准确识别物体特征 测试 在给定标注框与类别原始视频经过转换之后AI能否准确识别物体特征
""" """
@ -17,7 +20,6 @@ from pathlib import Path
import gradio as gr import gradio as gr
fps = 0 fps = 0
<<<<<<< HEAD
VIDEO_FOLDER: str = "input" VIDEO_FOLDER: str = "input"
def load_file_list() -> list[str]: def load_file_list() -> list[str]:
@ -37,6 +39,14 @@ def load_file_list() -> list[str]:
return file_names return file_names
def reload_files():
"""
刷新文件列表并设置默认值
"""
file_list = load_file_list()
default_value = file_list[0] if file_list else None
return gr.update(choices=file_list, value=default_value)
def get_full_vid_path(vid_file: str) -> str: def get_full_vid_path(vid_file: str) -> str:
""" """
获取视频绝对路径 获取视频绝对路径
@ -47,10 +57,6 @@ def get_full_vid_path(vid_file: str) -> str:
def update_preview(frame_idx: int, vid_file: str): def update_preview(frame_idx: int, vid_file: str):
vid_name: str = Path(vid_file).stem vid_name: str = Path(vid_file).stem
=======
def update_preview(frame_idx: int, vid_name: str):
>>>>>>> 7562de4 (2 预览图还有点问题)
output_dir: str = f"output/{vid_name}" output_dir: str = f"output/{vid_name}"
global fps global fps
idx = int(frame_idx // fps) idx = int(frame_idx // fps)
@ -64,114 +70,222 @@ def update_preview(frame_idx: int, vid_name: str):
return img_path, class_tag return img_path, class_tag
def run(vid_file: str, run_sam3: bool = True, run_inspection: bool = True, gen_report: bool = True,): # def run(vid_file: str, run_sam3: bool = True, run_inspection: bool = True, gen_report: bool = True,):
#================初始化================= # #================初始化=================
vid_name: str = Path(vid_file).stem # vid_name: str = Path(vid_file).stem
vid_end: str = Path(vid_file).suffix # vid_end: str = Path(vid_file).suffix
vid_path: str = f"./input/{vid_name}{vid_end}" # vid_path: str = f"./input/{vid_name}{vid_end}"
cap = cv2.VideoCapture(vid_path) # cap = cv2.VideoCapture(vid_path)
global fps # global fps
fps = int(cap.get(cv2.CAP_PROP_FPS)) # 获取视频FPS # fps = int(cap.get(cv2.CAP_PROP_FPS)) # 获取视频FPS
# cap.release()
# output_dir: str = f"output/{vid_name}"
# # interval = int(fps / 5)
# interval = fps
# conf = 0.7
# time_data: dict = {}
# annotated_frames: SAM3 = SAM3()
# result: dict = {}
# if output_dir:
# os.makedirs(output_dir, exist_ok=True) # exist_ok=True 防止重复创建报错
# #================获取物体信息=================
# # 保存开始时间字符串
# time_data["start_time"] = str(datetime.now())
# with open(f"{output_dir}/time.json", "w") as f:
# f.write(json.dumps(time_data, ensure_ascii=False, indent=4))
# if run_sam3:
# # 针对厂房防火分区
# annotated_frames.run(vid_path, output_dir, "lib/class_list/1.厂房防火.json", interval, conf)
# else:
# annotated_frames.load_from_json(f"{output_dir}/frame_all.json")
# # print(annotated_frames.data)
# # 提取ai能看到的部分
# ai_frames: dict = get_annnotated_frame_for_ai_without_xyxy(annotated_frames.data(), 1, conf)
# save_json_to_file(ai_frames, f"{output_dir}/frame_all_ai.json")
# #================隐患检查=================
# if run_inspection:
# if vid_path.startswith("oss"):
# video_url = vid_path
# else:
# video_url: str|None = None
# video_url_file = f"{output_dir}/video_url.json"
# # 检查URL是否已存在
# if os.path.exists(video_url_file):
# try:
# with open(video_url_file, "r", encoding="utf-8") as f:
# url_data = json.load(f)
# if vid_name in url_data:
# video_url = url_data[vid_name]
# print(f"使用已存在的URL: {video_url}")
# except Exception as e:
# print(f"读取URL文件失败: {e}")
# # 如果URL不存在上传文件
# if video_url is None:
# print(f"上传视频文件: {vid_path}")
# video_url = upload_files_and_get_urls_concurrently(
# file_path_list=[vid_path],
# max_workers=8
# )[0]
# if video_url:
# # 保存为JSON格式包含文件名和URL的键值对
# url_data = {}
# if os.path.exists(video_url_file):
# try:
# with open(video_url_file, "r", encoding="utf-8") as f:
# url_data = json.load(f)
# except:
# pass
# url_data[vid_name] = video_url
# with open(video_url_file, "w", encoding="utf-8") as f:
# json.dump(url_data, f, ensure_ascii=False, indent=4)
# print(f"URL已保存: {video_url}")
# if video_url is None:
# raise ValueError("视频上传失败,无法获取 URL")
# result_test: str
# reason_test: str
# reason_test, result_test = hazard_inspection(ai_frames, video_url, enable_thinking=True, fps=2)
# result = json.loads(result_test)
# # merged_result = merge_conflict_inspection_data(result)
# result["class"] = ai_frames["class_list"]
# with open(f"{output_dir}/hazard_inspection.json", "w", encoding="utf-8") as f:
# f.write(json.dumps(result, ensure_ascii=False, indent=4))
# with open(f"{output_dir}/hazard_inspection_reason.json", "w", encoding="utf-8") as f:
# f.write(json.dumps(reason_test, ensure_ascii=False, indent=4))
# # with open(f"{output_dir}/hazard_inspection_merged.json", "w", encoding="utf-8") as f:
# # f.write(json.dumps(merged_result, ensure_ascii=False, indent=4))
# #================生成报告=================
# if gen_report:
# if result == {}:
# with open(f"{output_dir}/hazard_inspection.json", "r", encoding="utf-8") as f:
# result = json.load(f)
# with open(f"知识库/rule.json", "r", encoding="utf-8") as f:
# rule_definitions = json.load(f)
# report_generator(
# video_path=vid_path, # 视频文件路径
# detection_data=annotated_frames.data(), # 物体检测数据
# hazard_results=result, # 隐患检查结果
# rule_definitions=rule_definitions, # 规则定义
# output_path=output_dir, # 输出文件夹
# frame_interval=interval, # 帧间隔(可根据实际视频帧率调整)
# )
# # 保存结束时间字符串
# time_data["end_time"] = str(datetime.now())
# with open(f"{output_dir}/time.json", "w") as f:
# f.write(json.dumps(time_data, ensure_ascii=False, indent=4))
# #================更新预览=================
# # 获取总帧数
# cap = cv2.VideoCapture(vid_path)
# total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
# cap.release()
# img_path, class_tag = update_preview(0, vid_name)
# return gr.update(maximum=total_frames-1, value=0, step = interval), img_path, class_tag
def run(vid_file: str, run_sam3: bool = True, run_inspection: bool = True, gen_report: bool = True):
vid_name = vid_file.split(".")[0] # 视频名称(不含后缀)
vid_end = f".{vid_file.split('.')[-1]}" # 视频后缀
use_url_cache = True # 是否使用 URL 缓存,避免重复上传视频
enable_thinking = True # 是否启用思考模式
run_vid_process = True # 是否运行视频处理流程(提取物体视频)
input_video_path = f"input/{vid_name}{vid_end}"
output_dir = f"output/{vid_name}"
frame_detections_path = f"{output_dir}/frame_detections.json"
objects_json_path = f"{output_dir}/objects.json"
vid_dir = f"{output_dir}/obj_vids"
rule_dict: dict = load_json_data('知识库/rule.json')
video_url_file = f"{output_dir}/video_url.json"
time_data: dict = {
# 保存开始时间字符串
"start_time": str(datetime.now())
}
interval = 1
annotated_frames: SAM3 = SAM3()
# 获取总帧数与帧率
cap = cv2.VideoCapture(input_video_path)
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
fps = cap.get(cv2.CAP_PROP_FPS)
cap.release() cap.release()
output_dir: str = f"output/{vid_name}"
# interval = int(fps / 5)
interval = fps
conf = 0.7
time_data: dict = {}
annotated_frames: SAM3 = SAM3()
result: dict = {}
if output_dir: if output_dir:
os.makedirs(output_dir, exist_ok=True) # exist_ok=True 防止重复创建报错 os.makedirs(output_dir, exist_ok=True) # exist_ok=True 防止重复创建报错
#================获取物体信息=================
# 保存开始时间字符串
time_data["start_time"] = str(datetime.now())
with open(f"{output_dir}/time.json", "w") as f: with open(f"{output_dir}/time.json", "w") as f:
f.write(json.dumps(time_data, ensure_ascii=False, indent=4)) f.write(json.dumps(time_data, ensure_ascii=False, indent=4))
if run_sam3: #================获取物体信息=================
#这里可以换成你的yolo运行逻辑输出文件保存在frame_detections_path
if run_sam3: # 开关判定
# 针对厂房防火分区 # 针对厂房防火分区
annotated_frames.run(vid_path, output_dir, "lib/class_list/1.厂房防火.json", interval, conf) annotated_frames.run(input_video_path, output_dir, "lib/class_list/1.厂房防火.json")
else: else: # 不运行sam3直接加载之前的结果
annotated_frames.load_from_json(f"{output_dir}/frame_all.json") annotated_frames.load_from_json(frame_detections_path)
# print(annotated_frames.data) # print(annotated_frames.data)
# 提取ai能看到的部分
ai_frames: dict = get_annnotated_frame_for_ai_without_xyxy(annotated_frames.data(), 1, conf)
save_json_to_file(ai_frames, f"{output_dir}/frame_all_ai.json")
#================隐患检查================= #================隐患检查=================
if run_inspection: # 开关判定
if run_inspection:
if vid_path.startswith("oss"):
video_url = vid_path
else:
video_url: str|None = None
video_url_file = f"{output_dir}/video_url.json"
# 检查URL是否已存在 # 提取物体信息
if os.path.exists(video_url_file): f_detections_to_objects(
try: frame_detections_path,
with open(video_url_file, "r", encoding="utf-8") as f: objects_json_path
url_data = json.load(f) )
if vid_name in url_data: obj = json.load(open(objects_json_path, "r", encoding="utf-8"))
video_url = url_data[vid_name] class_list = obj["class_list"]
print(f"使用已存在的URL: {video_url}") obj_dict = obj["track_id_list"]
except Exception as e:
print(f"读取URL文件失败: {e}")
# 如果URL不存在上传文件 if run_vid_process:
if video_url is None: # 生成物体视频
print(f"上传视频文件: {vid_path}") generate_video_to_objects(
video_url = upload_files_and_get_urls_concurrently( obj_dict,
file_path_list=[vid_path], input_video_path,
max_workers=8 output_dir=vid_dir,
)[0] )
if video_url:
# 保存为JSON格式包含文件名和URL的键值对
url_data = {}
if os.path.exists(video_url_file):
try:
with open(video_url_file, "r", encoding="utf-8") as f:
url_data = json.load(f)
except:
pass
url_data[vid_name] = video_url # 上传视频并获取 URL
with open(video_url_file, "w", encoding="utf-8") as f: vid_dict = get_vid_dict(vid_dir, obj_dict, video_url_file, use_url_cache)
json.dump(url_data, f, ensure_ascii=False, indent=4)
print(f"URL已保存: {video_url}")
if video_url is None: hazard_inspection(
raise ValueError("视频上传失败,无法获取 URL") output_dir,
obj_dict,
<<<<<<< HEAD rule_dict,
result_test: str class_list,
reason_test: str vid_dict,
reason_test, result_test = hazard_inspection(ai_frames, video_url, enable_thinking=True, fps=2) fps=fps,
result = json.loads(result_test) enable_thinking=enable_thinking
======= )
result = json.loads(hazard_inspection(ai_frames, video_url, enable_thinking=False, fps=2)[1])
>>>>>>> 7562de4 (2 预览图还有点问题)
# merged_result = merge_conflict_inspection_data(result)
result["class"] = ai_frames["class_list"]
with open(f"{output_dir}/hazard_inspection.json", "w", encoding="utf-8") as f:
f.write(json.dumps(result, ensure_ascii=False, indent=4))
with open(f"{output_dir}/hazard_inspection_reason.json", "w", encoding="utf-8") as f:
f.write(json.dumps(reason_test, ensure_ascii=False, indent=4))
# with open(f"{output_dir}/hazard_inspection_merged.json", "w", encoding="utf-8") as f:
# f.write(json.dumps(merged_result, ensure_ascii=False, indent=4))
#================生成报告================= #================生成报告=================
if gen_report: if gen_report:
if result == {}: if result == {}:
with open(f"{output_dir}/hazard_inspection.json", "r", encoding="utf-8") as f: with open(f"{output_dir}/hazard_inspection.json", "r", encoding="utf-8") as f:
@ -181,7 +295,7 @@ def run(vid_file: str, run_sam3: bool = True, run_inspection: bool = True, gen_r
rule_definitions = json.load(f) rule_definitions = json.load(f)
report_generator( report_generator(
video_path=vid_path, # 视频文件路径 video_path=input_video_path, # 视频文件路径
detection_data=annotated_frames.data(), # 物体检测数据 detection_data=annotated_frames.data(), # 物体检测数据
hazard_results=result, # 隐患检查结果 hazard_results=result, # 隐患检查结果
rule_definitions=rule_definitions, # 规则定义 rule_definitions=rule_definitions, # 规则定义
@ -196,13 +310,61 @@ def run(vid_file: str, run_sam3: bool = True, run_inspection: bool = True, gen_r
#================更新预览================= #================更新预览=================
# 获取总帧数
cap = cv2.VideoCapture(vid_path)
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
cap.release()
img_path, class_tag = update_preview(0, vid_name) img_path, class_tag = update_preview(0, vid_name)
return gr.update(maximum=total_frames-1, value=0, step = interval), img_path, class_tag return gr.update(maximum=total_frames-1, value=0, step = interval), img_path, class_tag
def get_vid_dict(vid_dir: str, obj_dict: dict, video_url_file: str, use_url_cache: bool) -> dict:
"""获取视频字典
参数:
vid_dir: 物体视频存放目录
obj_dict: 物品字典键为物品TrackID值为物品信息
video_url_file: 视频URL文件路径
use_url_cache: 是否使用缓存的URL
返回:
视频字典键为物品TrackID值为视频本地地址和视频URL
"""
vid_dict = {}
# 检查URL是否已存在
if use_url_cache:
if os.path.exists(video_url_file):
try:
with open(video_url_file, "r", encoding="utf-8") as f:
vid_dict = json.load(f)
except Exception as e:
print(f"读取URL文件失败: {e}")
for track_id in obj_dict.keys():
if track_id in vid_dict:
continue
vid_path = f"{vid_dir}/obj_{int(track_id):03d}.mp4"
vid_dict[track_id] = {"vid_path": vid_path, "vid_url": None}
# 找到未上传的视频
unuploaded_vids = []
for track_id, vid_info in vid_dict.items():
if vid_info["vid_url"] is None:
unuploaded_vids.append(vid_info["vid_path"])
# 上传未上传的视频并获取 URL
uploaded_urls = upload_files_and_get_urls_concurrently(
file_path_list=unuploaded_vids,
max_workers=8
)
# 更新视频URL
for track_id, vid_info in vid_dict.items():
if vid_info["vid_url"] is None:
idx = unuploaded_vids.index(vid_info["vid_path"])
vid_info["vid_url"] = uploaded_urls[idx]
# 保存字典为 JSON 文件
with open(video_url_file, "w", encoding="utf-8") as f:
json.dump(vid_dict, f, ensure_ascii=False, indent=4)
return vid_dict
# 创建 Gradio 页面 # 创建 Gradio 页面
with gr.Blocks() as demo: with gr.Blocks() as demo:
@ -210,7 +372,7 @@ with gr.Blocks() as demo:
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
<<<<<<< HEAD
initial_files: list[str] = load_file_list() initial_files: list[str] = load_file_list()
default_video = initial_files[0] if initial_files else None default_video = initial_files[0] if initial_files else None
vid_file = gr.Dropdown( vid_file = gr.Dropdown(
@ -219,10 +381,8 @@ with gr.Blocks() as demo:
value=default_video, # 默认选中第一个 value=default_video, # 默认选中第一个
interactive=True, interactive=True,
) )
======= reload_file_list_button = gr.Button("刷新视频列表")
vid_name = gr.Textbox(label="视频名称", value="Miehhuoxqih") reload_file_list_button.click(fn=reload_files, inputs=[], outputs=[vid_file])
vid_end = gr.Textbox(label="视频后缀", value=".AVI")
>>>>>>> 7562de4 (2 预览图还有点问题)
run_sam3 = gr.Checkbox(label="1. 运行 SAM3 模型", value=True) run_sam3 = gr.Checkbox(label="1. 运行 SAM3 模型", value=True)
run_inspection = gr.Checkbox(label="2. 运行隐患排查", value=True) run_inspection = gr.Checkbox(label="2. 运行隐患排查", value=True)
@ -251,7 +411,7 @@ def get_class_tag_by_frame(data, idx, fps):
根据给定的帧索引 (idx)返回在该帧范围内的所有对象的 class:tag 信息 根据给定的帧索引 (idx)返回在该帧范围内的所有对象的 class:tag 信息
参数: 参数:
data (dict): 包含 'class', 'tag', 'objects' 的字典数据 data (dict): 包含 'tag', 'objects' 的字典数据hazard_inspection.json格式
idx (int): 需要查询的帧索引 idx (int): 需要查询的帧索引
返回: 返回:
@ -260,10 +420,9 @@ def get_class_tag_by_frame(data, idx, fps):
# 参数检查 # 参数检查
if not isinstance(data, dict): if not isinstance(data, dict):
raise ValueError("数据必须是字典类型") raise ValueError("数据必须是字典类型")
if 'class' not in data or 'tag' not in data or 'objects' not in data: if 'tag' not in data or 'objects' not in data:
raise ValueError("数据必须包含 'class', 'tag''objects'") raise ValueError("数据必须包含 'tag''objects'")
class_list = data['class']
tag_list = data['tag'] tag_list = data['tag']
objects = data['objects'] objects = data['objects']
@ -276,17 +435,21 @@ def get_class_tag_by_frame(data, idx, fps):
# 遍历每个物体 # 遍历每个物体
for obj in objects: for obj in objects:
# 根据 class_id 和 tag_id 获取对应的字符串 # 根据 tag_id 获取对应的隐患标签字符串
class_str = class_list[obj['class_id']] tag_id = obj.get('tag_id', 0)
tag_str = tag_list[obj['tag_id']] class_id = obj.get('class_id', 0)
tag_str = tag_list[tag_id] if tag_id < len(tag_list) else f"未知标签({tag_id})"
location = obj.get('location', '') location = obj.get('location', '')
start_frame = obj.get('start_frame', 0)
level = obj.get('level', '')
# 检查帧范围是否包含 idx # 检查帧范围是否包含 idx
if obj['start_frame'] == idx: if start_frame == idx:
result.append(f"{class_str}:{tag_str} (位置: {location})") result.append(f"{tag_str} | 等级:{level} | 位置: {location}")
all_class_tag.append(f"{class_str}:{tag_str} (开始帧: {obj['start_frame']*interval}, 位置: {location})") all_class_tag.append(f"{tag_str} | class_id:{class_id} | 等级:{level} | 开始帧:{start_frame} | 位置:{location}")
# 使用换行符连接所有结果 # 使用换行符连接所有结果
output = f"当前帧隐患:\n"+"\n".join(result)+"\n\n"+"所有对象的 class:tag 信息:\n"+"\n".join(all_class_tag) output = f"当前帧隐患:\n"+"\n".join(result)+"\n\n"+"所有隐患对象信息:\n"+"\n".join(all_class_tag)
return output return output
@ -296,6 +459,3 @@ if __name__ == "__main__":
debug=True, debug=True,
allowed_paths=[VIDEO_FOLDER] allowed_paths=[VIDEO_FOLDER]
) )

View File

@ -1,4 +1,7 @@
from lib.qwen_fun import hazard_inspection, search_knowledge_base import os
from lib.json_fun import f_detections_to_objects
from lib.qwen_fun import chat, hazard_inspection, load_json_data, search_knowledge_base, upload_files_and_get_urls_concurrently
from encodings.punycode import T from encodings.punycode import T
""" """
测试 在给定标注框与类别原始视频经过转换之后AI能否准确识别物体特征 测试 在给定标注框与类别原始视频经过转换之后AI能否准确识别物体特征
@ -7,59 +10,126 @@ from encodings.punycode import T
from tkinter import N from tkinter import N
from datetime import datetime from datetime import datetime
import json import json
import os from lib.qwen_fun_vid import generate_video_to_objects
from lib.qwen_fun import chat, get_annnotated_frame_for_ai, get_unique_track_id_count, save_json_to_file, search_knowledge_base, upload_files_and_get_urls_concurrently
from lib.qwen_fun_vid import create_mian_vid_for_ai, frame_all_to_obj_vid
from lib.sam3 import SAM3 from lib.sam3 import SAM3
from ultralytics.models.sam import SAM3VideoSemanticPredictor from ultralytics.models.sam import SAM3VideoSemanticPredictor
import cv2
import json import json
import numpy as np import numpy as np
from pathlib import Path from pathlib import Path
def run(output_dir, vid_path, interval: int = 5): def run():
if output_dir: vid_name = "santai5" # 视频名称(不含后缀)
os.makedirs(output_dir, exist_ok=True) # exist_ok=True 防止重复创建报错 vid_end = ".mp4" # 视频后缀
use_url_cache = True # 是否使用 URL 缓存,避免重复上传视频
enable_thinking = False # 是否启用思考模式
annotated_frames: SAM3 = SAM3() run_vid_process = False # 是否运行视频处理流程(提取物体视频)
annotated_frames.run(vid_path, output_dir, "lib/class_list/xiaofangz.json")
# annotated_frames.load_from_json(f"{output_dir}/frame_all.json")
# print(annotated_frames.data)
# 提取ai能看到的部分 input_video_path = f"input/{vid_name}{vid_end}"
ai_frames: dict = get_annnotated_frame_for_ai(annotated_frames.data(), interval) output_dir = f"output/{vid_name}"
save_json_to_file(ai_frames, f"{output_dir}/frame_all_ai.json") frame_detections_path = f"{output_dir}/frame_detections.json"
objects_json_path = f"{output_dir}/objects.json"
vid_dir = f"{output_dir}/obj_vids"
obj = json.load(open(objects_json_path, "r", encoding="utf-8"))
class_list = obj["class_list"]
obj_dict = obj["track_id_list"]
rule_dict: dict = load_json_data('知识库/rule.json')
video_url_file = f"{output_dir}/video_url.json"
# 提取物体信息
f_detections_to_objects(
frame_detections_path,
objects_json_path
)
if vid_path.startswith("oss"): if run_vid_process:
video_url = vid_path # 生成物体视频
else: generate_video_to_objects(
video_url: str|None = upload_files_and_get_urls_concurrently( obj_dict,
file_path_list=[vid_path], input_video_path,
max_workers=8 output_dir=vid_dir,
)[0] )
if video_url:
with open(f"{output_dir}/video_url.json", "w") as f:
f.write(video_url)
if video_url is None: # 上传视频并获取 URL
raise ValueError("视频上传失败,无法获取 URL") vid_dict = get_vid_dict(vid_dir, obj_dict, video_url_file, use_url_cache)
result = hazard_inspection(ai_frames, video_url)[1] hazard_inspection(
output_dir,
obj_dict,
rule_dict,
class_list,
vid_dict,
fps=30,
enable_thinking=enable_thinking
)
def get_vid_dict(vid_dir: str, obj_dict: dict, video_url_file: str, use_url_cache: bool) -> dict:
"""获取视频字典
参数:
vid_dir: 物体视频存放目录
obj_dict: 物品字典键为物品TrackID值为物品信息
video_url_file: 视频URL文件路径
use_url_cache: 是否使用缓存的URL
返回:
视频字典键为物品TrackID值为视频本地地址和视频URL
"""
vid_dict = {}
# 检查URL是否已存在
if use_url_cache:
if os.path.exists(video_url_file):
try:
with open(video_url_file, "r", encoding="utf-8") as f:
vid_dict = json.load(f)
except Exception as e:
print(f"读取URL文件失败: {e}")
for track_id in obj_dict.keys():
if track_id in vid_dict:
continue
vid_path = f"{vid_dir}/obj_{int(track_id):03d}.mp4"
vid_dict[track_id] = {"vid_path": vid_path, "vid_url": None}
# 找到未上传的视频
unuploaded_vids = []
for track_id, vid_info in vid_dict.items():
if vid_info["vid_url"] is None:
unuploaded_vids.append(vid_info["vid_path"])
# 上传未上传的视频并获取 URL
uploaded_urls = upload_files_and_get_urls_concurrently(
file_path_list=unuploaded_vids,
max_workers=8
)
# 更新视频URL
for track_id, vid_info in vid_dict.items():
if vid_info["vid_url"] is None:
idx = unuploaded_vids.index(vid_info["vid_path"])
vid_info["vid_url"] = uploaded_urls[idx]
# 保存字典为 JSON 文件
with open(video_url_file, "w", encoding="utf-8") as f:
json.dump(vid_dict, f, ensure_ascii=False, indent=4)
return vid_dict
with open(f"{output_dir}/hazard_inspection.json", "w", encoding="utf-8") as f:
f.write(result)
# 启动应用 # 启动应用
if __name__ == "__main__": if __name__ == "__main__":
VID_NAME: str = "Peihdianhxiang" # objects_json_path = "output/santai5/objects.json"
VID_END: str = ".mp4" # input_video_path = "input/santai5.mp4"
interval: int = int(30 / 5) # obj_dict = json.load(open(objects_json_path, "r", encoding="utf-8"))["track_id_list"]
output_dir: str = f"output/{VID_NAME}" # # 记录开始时间
vid_path: str = f"./input/{VID_NAME}{VID_END}" # start_time = datetime.now()
# generate_video_to_objects(
# obj_dict,
# input_video_path,
# output_dir="output/santai5/obj_vids",
# )
run(output_dir, vid_path, interval) run()

View File

@ -8,21 +8,53 @@ def f_detections_to_objects(detections_path: str, output_path: str):
:param detections_path: 输入文件路径 :param detections_path: 输入文件路径
:param output_path: 输出文件路径 :param output_path: 输出文件路径
:return: None :return: None
detections 格式
{
"frame_id": [
{
"xyxy": [x1, y1, x2, y2]
"confidence": 0.314,
"track_id": 1,
"class_id": 1,
"class_str": "class_str",
},
...
],
...
}
objects 格式
{
"class_list": [
"插座",
"消火栓",
"配电箱"
],
"track_id_list": {
"0": {
"class_id": 4,
"start_frame": 551,
"end_frame": 597
},
}
}
""" """
with open(detections_path, 'r', encoding='utf-8') as f: with open(detections_path, 'r', encoding='utf-8') as f:
detections_data = json.load(f) detections_data = json.load(f)
class_set = set() class_list = []
track_info = defaultdict(lambda: {"class_id": None, "start_frame": float('inf'), "end_frame": 0}) track_info = defaultdict(lambda: {"class_id": None, "start_frame": float('inf'), "end_frame": 0})
for frame_str, detections in detections_data.items(): for frame_str, detections in detections_data.items():
frame = int(frame_str) frame = int(frame_str)
for det in detections: for det in detections:
track_id = det["track_id"] track_id = det["track_id"]
class_id = det["class_id"]
class_str = det["class_str"] class_str = det["class_str"]
class_set.add(class_str) if class_str not in class_list:
class_list.append(class_str)
class_id = class_list.index(class_str)
if track_info[track_id]["class_id"] is None: if track_info[track_id]["class_id"] is None:
track_info[track_id]["class_id"] = class_id track_info[track_id]["class_id"] = class_id
@ -33,15 +65,13 @@ def f_detections_to_objects(detections_path: str, output_path: str):
track_info[track_id]["end_frame"] = frame track_info[track_id]["end_frame"] = frame
result = { result = {
"class_list": list(class_set), "class_list": list(class_list),
"track_id_list": dict(track_info) "track_id_list": dict(track_info)
} }
with open(output_path, 'w', encoding='utf-8') as f: with open(output_path, 'w', encoding='utf-8') as f:
json.dump(result, f, ensure_ascii=False, indent=2) json.dump(result, f, ensure_ascii=False, indent=2)
if __name__ == "__main__": def load_json_data(json_path: str):
f_detections_to_objects( with open(json_path, 'r', encoding='utf-8') as f:
"output/santai5/frame_detections.json", return json.load(f)
"output/santai5/objects.json"
)

View File

@ -1,6 +1,7 @@
import json import json
import os import os
from pathlib import Path from pathlib import Path
from lib.json_fun import load_json_data
import requests import requests
from openai import BadRequestError, OpenAI from openai import BadRequestError, OpenAI
from concurrent.futures import ThreadPoolExecutor, as_completed from concurrent.futures import ThreadPoolExecutor, as_completed
@ -209,9 +210,7 @@ def save_json_to_file(json_data, file_path: str) -> None:
print(f"已保存 JSON 数据到: {file_path}") print(f"已保存 JSON 数据到: {file_path}")
def load_json_data(json_path: str):
with open(json_path, 'r', encoding='utf-8') as f:
return json.load(f)
def get_unique_track_id_count(json_path: str) -> int: def get_unique_track_id_count(json_path: str) -> int:
""" """
@ -362,47 +361,34 @@ def search_knowledge_base(
return "" return ""
def hazard_inspection( def hazard_inspection(
class_dict: dict, output_dir: str,
video_url: str, obj_dict: dict,
rule_dict: dict,
class_list: list[str],
vid_dict: dict,
fps: int,
enable_thinking: bool = True, enable_thinking: bool = True,
fps: int = 10
) -> tuple[str, str]: ) -> tuple[str, str]:
""" """
检查视频中给定类型物体的隐患 检查视频中给定类型物体的隐患
2026.04.07 2026.04.17
参数: 参数:
class_dict: 类别字典 output_dir: 输出目录
input_video_path: 视频路径 obj_dict: 物体视频字典
class_list: 类别列表
vid_dict: 视频字典
fps: 视频帧率
enable_thinking: 是否开启思考模式 enable_thinking: 是否开启思考模式
fps
class_dict:结构: obj_dict 格式
{
"class_list": ["class_name1", "class_name2", ..."], # 类别名称 vid_dict 格式
"track_id_list": {
"0": { # track_id
"class_id": 0,
"start_frame": 0,
"end_frame": 10
},
...
},
}
返回: 返回:
tuple[str, str]: (reasoning_text, json_result) tuple[str, str]: (reasoning_text, json_result)
json_result: JSON格式的隐患检测结果
""" """
# 筛选检查规则 知识库\rule.json
all_check_rule_dict: dict = load_json_data(r'知识库\rule.json')
# 读取提示词文件 7_检测提示词_视频.md
prompt_str: str = ""
with open(r'prompt\7_检测提示词_视频_new3.md', 'r', encoding='utf-8') as file:
prompt_str = file.read()
# 存储所有类别的结果 # 存储所有类别的结果
all_reasoning_text = "" all_reasoning_text = ""
@ -410,58 +396,23 @@ def hazard_inspection(
all_tags = [] all_tags = []
all_bases = [] all_bases = []
all_objects = [] all_objects = []
hazard_count = 0
# 保存每个类别的原始 tag 和 base 映射,用于后续重新映射
class_tag_mapping = {}
class_base_mapping = {}
# 遍历每个类别,分别进行 chat with open(r'prompt\7_检测提示词_视频_new4.md', 'r', encoding='utf-8') as file:
print("类别列表:", class_dict["class_list"]) prompt_str = file.read()
for class_str in class_dict["class_list"]:
print(f"当前类别: {class_str}")
# 检查该类别是否有检查规则
class_check_rule: str = ""
scene_list: list[str] = list(all_check_rule_dict.keys())
print(f"场景列表: {scene_list}")
for scene in scene_list:
if class_str in all_check_rule_dict[scene].keys():
class_check_rule = all_check_rule_dict[scene][class_str]
break
if class_check_rule == "":
# 该类别没有检查规则
continue
# 筛选出该类别的 track_id_list
class_track_info = []
for track_id, track_data in class_dict["track_id_list"].items():
track_class_name = class_dict["class_list"][track_data["class_id"]]
if track_class_name == class_str:
class_track_info.append(f"Track ID {track_id}: {track_class_name} (帧 {track_data['start_frame']}-{track_data['end_frame']})")
# 如果该类别没有跟踪对象,跳过 # 遍历每个物体视频,进行检查
if not class_track_info: for track_id, track_info in obj_dict.items():
continue class_str = class_list[track_info["class_id"]]
vid_url = vid_dict[track_id]["vid_url"]
reasoning_text, answer_text = single_object_inspection(vid_url, class_str, rule_dict, prompt_str, enable_thinking)
track_data_str = "\n".join(class_track_info)
# 构建该类别的 prompt
prompt: str = f"""
# 类别: {class_str}
# 跟踪对象信息
{track_data_str}
# 检查规则
{class_check_rule}
{prompt_str}
"""
# 调用 chat 函数
reasoning_text, answer_text = chat(prompt, video_url, enable_thinking, fps)
# 整合思考过程 # 整合思考过程
all_reasoning_text += f"\n\n# 类别: {class_str}\n{reasoning_text}" all_reasoning_text += f"\n\n# 类别: {track_id}:{class_str}\n{reasoning_text}"
# 解析 JSON 结果 # 解析 JSON 结果
try: try:
@ -477,81 +428,110 @@ def hazard_inspection(
class_result = json.loads(json_str) class_result = json.loads(json_str)
# 保存class_result到文件便于调试
with open(f"{output_dir}/inspection_result_{int(track_id):03d}.txt", "w", encoding="utf-8") as f:
f.write(json.dumps(class_result, ensure_ascii=False, indent=2))
# 收集该类别的结果 # 收集该类别的结果
if "tag" in class_result: temp_obj = {}
# 保存该类别的 tag 映射 for obj in class_result.get("objects", []):
class_tag_mapping[class_str] = class_result["tag"]
all_tags.extend(class_result["tag"]) # tag_id
if "base" in class_result: if "tag_id" in obj and obj["tag_id"] < len(class_result.get("tag", [])):
# 保存该类别的 base 映射 tag = class_result["tag"][obj["tag_id"]]
class_base_mapping[class_str] = class_result["base"] else:
all_bases.extend(class_result["base"]) tag = None
if "objects" in class_result:
all_objects.extend(class_result["objects"]) if tag not in all_tags:
all_tags.append(tag)
temp_obj["tag_id"] = all_tags.index(tag)
# base_id
if "base_id" in obj and obj["base_id"] < len(class_result.get("base", [])):
base = class_result["base"][obj["base_id"]]
else:
base = None
if base not in all_bases:
all_bases.append(base)
temp_obj["base_id"] = all_bases.index(base)
# 其他字段
temp_obj["track_id"] = track_id
temp_obj["hazard_track_id"] = hazard_count
temp_obj["class_id"] = track_info["class_id"]
temp_obj["level"] = obj["level"]
temp_obj["start_frame"] = track_info["start_frame"]
temp_obj["start_sec"] = round(track_info["start_frame"] / fps, 1) # 处理成x.x秒
temp_obj["location"] = obj.get("location", "")
hazard_count += 1
all_objects.append(temp_obj)
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
print(f"解析类别 {class_str} 的 JSON 结果失败: {e}") print(f"解析类别 {class_str} 的 JSON 结果失败: {e}")
print(f"原始输出: {answer_text}") print(f"原始输出: {answer_text}")
continue continue
# 合并去重 tag 和 base
unique_tags = []
seen_tags = set()
for tag in all_tags:
if tag not in seen_tags:
seen_tags.add(tag)
unique_tags.append(tag)
unique_bases = []
seen_bases = set()
for base in all_bases:
if base not in seen_bases:
seen_bases.add(base)
unique_bases.append(base)
# 重新分配 tag_id 和 base_id
# 由于每个类别是独立处理的,需要根据 class_id 找到对应的类别,然后重新映射
# 但是 object 中的 class_id 是相对于 class_list 的,需要找到对应的类别名称
final_objects = []
for obj in all_objects:
new_obj = obj.copy()
new_obj["hazard_track_id"] = len(final_objects)
# 通过 class_id 找到类别名称
class_id = obj.get("class_id", 0)
if class_id < len(class_dict["class_list"]):
class_name = class_dict["class_list"][class_id]
# 重新映射 tag_id
if "tag_id" in obj and class_name in class_tag_mapping:
original_tags = class_tag_mapping[class_name]
if obj["tag_id"] < len(original_tags):
original_tag = original_tags[obj["tag_id"]]
if original_tag in unique_tags:
new_obj["tag_id"] = unique_tags.index(original_tag)
# 重新映射 base_id
if "base_id" in obj and class_name in class_base_mapping:
original_bases = class_base_mapping[class_name]
if obj["base_id"] < len(original_bases):
original_base = original_bases[obj["base_id"]]
if original_base in unique_bases:
new_obj["base_id"] = unique_bases.index(original_base)
final_objects.append(new_obj)
# 构建最终结果 # 构建最终结果
final_result = { final_result = {
"tag": unique_tags, "tag": all_tags,
"base": unique_bases, "base": all_bases,
"objects": final_objects "objects": all_objects
} }
# 转换为 JSON 字符串 # 转换为 JSON 字符串
json_result = json.dumps(final_result, ensure_ascii=False, indent=2) json_result = json.dumps(final_result, ensure_ascii=False, indent=2)
with open(f"{output_dir}/hazard_inspection.json", "w", encoding="utf-8") as f:
f.write(json_result)
with open(f"{output_dir}/hazard_inspection_reasoning.md", "w", encoding="utf-8") as f:
f.write(all_reasoning_text)
return all_reasoning_text, json_result return all_reasoning_text, json_result
def single_object_inspection(vid_url: str, class_str: str, rule_dict: dict, prompt_str: str, enable_thinking: bool = True) -> tuple[str, str]:
# 获取检查规则
rule = get_rule_by_class_str(class_str, rule_dict)
if rule == "":
print(f"类别 {class_str} 没有检查规则,跳过检查")
return ""
else:
print(f"类别 {class_str} 的检查规则: {rule}")
# 构建该类别的 prompt
prompt: str = f"""
# 检测对象为视频中出现的{class_str}
# 检查规则
{rule}
{prompt_str}
"""
return chat(prompt, vid_url, enable_thinking)
def get_rule_by_class_str(class_str: str, all_check_rule_dict: dict) -> str:
"""根据类别名称获取检查规则
参数:
class_str: 类别名称
all_check_rule_dict: 所有检查规则字典
返回:
检查规则字符串
"""
print(f"当前类别: {class_str}")
# 检查该类别是否有检查规则
class_check_rule: str = ""
scene_list: list[str] = list(all_check_rule_dict.keys())
for scene in scene_list:
if class_str in all_check_rule_dict[scene].keys():
class_check_rule = all_check_rule_dict[scene][class_str]
break
return str(class_check_rule)
def merge_conflict_inspection_data(data: dict) -> dict: def merge_conflict_inspection_data(data: dict) -> dict:
""" """
合并冲突数据相同 tag_id class_id 的重叠时间段 合并冲突数据相同 tag_id class_id 的重叠时间段

View File

@ -5,126 +5,131 @@ import supervision as sv
import cv2 import cv2
from concurrent.futures import ThreadPoolExecutor, as_completed from concurrent.futures import ThreadPoolExecutor, as_completed
def generate_video_with_boxes( import os
boxes_data: list[dict], import cv2
from pathlib import Path
def generate_video_to_objects(
obj_dict: dict[dict],
input_video_path: str, input_video_path: str,
output_video_path: str, output_dir: str,
frame_annotation_interval: int
) -> None: ) -> None:
""" """
将提供的标注数据渲染到视频上支持按帧间隔进行标注 根据 obj_dict 中的物体信息从原视频中截取附近帧生成新视频
:param boxes_data: 包含标注信息的列表每个元素结构为: :param obj_dict: 包含物体信息的字典每个元素结构为:
{ "0": {
"frame_id": int, "class_id": 4,
"boxes": List[Tuple[int, int, int, int, str]] # (x1, y1, x2, y2, label) "start_frame": 551,
} "end_frame": 597
},
:param input_video_path: 输入视频文件路径 :param input_video_path: 输入视频文件路径
:param output_video_path: 输出视频文件路径 :param output_dir: 输出视频目录
:param frame_annotation_interval: 标注间隔单位默认为 1每帧标注
:return:
""" """
# ------------------------------------------------- # 最低秒数
# 1. 视频读取与基本配置 min_seconds = 2
# ------------------------------------------------- # 前后额外帧数
cap = cv2.VideoCapture(input_video_path) extra_frames = 5
if not cap.isOpened():
raise FileNotFoundError(f"无法打开视频文件或流: {input_video_path}") print(f"开始抽取物体视频: {input_video_path}")
print(f"输出目录: {output_dir}")
# 确保输出目录存在
Path(output_dir).mkdir(parents=True, exist_ok=True)
# 1. 打开原始视频
backends = [
(cv2.CAP_FFMPEG, 'FFmpeg'),
(cv2.CAP_DSHOW, 'DirectShow'),
(cv2.CAP_ANY, 'Default')
]
cap = None
for backend, backend_name in backends:
try:
cap = cv2.VideoCapture(input_video_path, backend)
if cap.isOpened():
if backend == cv2.CAP_FFMPEG:
try:
cap.set(cv2.CAP_PROP_HW_ACCELERATION, cv2.VIDEO_ACCELERATION_ANY)
hw_accel = cap.get(cv2.CAP_PROP_HW_ACCELERATION)
print(f"FFmpeg硬件加速: {'已启用' if hw_accel > 0 else '未启用'}")
except Exception as e:
print(f"设置硬件加速失败: {e}")
print(f"使用后端: {backend_name}")
break
except Exception as e:
print(f"尝试{backend_name}后端失败: {e}")
continue
if not cap or not cap.isOpened():
raise Exception(f"无法打开视频文件: {input_video_path}")
# 2. 获取视频参数
fps = cap.get(cv2.CAP_PROP_FPS) fps = cap.get(cv2.CAP_PROP_FPS)
video_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
video_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
min_frames = int(min_seconds * fps)
# ------------------------------------------------- fourcc = cv2.VideoWriter_fourcc(*'avc1')
# 2. 初始化 VideoWriter
# -------------------------------------------------
Path(os.path.dirname(output_video_path)).mkdir(parents=True, exist_ok=True)
fourcc = cv2.VideoWriter_fourcc(*'mp4v') # type: ignore
out = cv2.VideoWriter(output_video_path, fourcc, fps, (video_width, video_height))
# ------------------------------------------------- # 3. 预处理所有物体的帧范围 + 初始化写入器
# 3. 初始化标注工具 total_objects = len(obj_dict)
# ------------------------------------------------- print(f"总共需要处理 {total_objects} 个物体")
box_annotator = sv.BoxAnnotator()
label_annotator = sv.RichLabelAnnotator(
font_path="C:/Windows/Fonts/simhei.ttf",
text_color=sv.Color.WHITE,
text_padding=5,
font_size=20
)
# ------------------------------------------------- writers = {} # 存储所有视频写入器
# 4. 数据预处理:构建 frame_id -> boxes 的映射 obj_ranges = {} # 存储每个物体的起止帧
# -------------------------------------------------
frame_to_boxes: dict[int, list[tuple[int, int, int, int, str]]] = {}
for entry in boxes_data:
frame_id = entry["frame_id"]
boxes = entry["boxes"]
# 确保每个框都有正确的结构
cleaned_boxes = []
for box in boxes:
if len(box) == 5:
cleaned_boxes.append(box) # (x1, y1, x2, y2, label)
frame_to_boxes[frame_id] = cleaned_boxes
# ------------------------------------------------- for track_id, track_data in obj_dict.items():
# 5. 主循环:逐帧读取并渲染 start_idx = max(0, track_data["start_frame"] - extra_frames)
# ------------------------------------------------- end_idx = min(total_frames - 1, track_data["end_frame"] + extra_frames)
# 保证最小长度
if end_idx - start_idx + 1 < min_frames:
end_idx = start_idx + min_frames - 1
output_path = os.path.join(output_dir, f"obj_{int(track_id):03d}.mp4")
out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
if not out.isOpened():
raise Exception(f"无法初始化视频写入器: {output_path}")
writers[track_id] = out
obj_ranges[track_id] = (start_idx, end_idx)
print(f"物体 {track_id}: 帧 {start_idx} ~ {end_idx} -> {output_path}")
# 4. 单轮遍历视频,一次性写入所有需要的帧(最高效)
frame_idx = 0 frame_idx = 0
print_interval = 50
while cap.isOpened(): while cap.isOpened():
# 读取一帧
ret, frame = cap.read() ret, frame = cap.read()
if not ret: if not ret:
break break
# 计算当前帧对应的标注帧索引(考虑间隔) # 遍历所有物体,判断当前帧是否需要写入
annotation_frame_idx = frame_idx // frame_annotation_interval for track_id, (start, end) in obj_ranges.items():
if start <= frame_idx <= end:
writers[track_id].write(frame)
# 如果当前帧没有标注数据,直接写入原始帧 # 进度打印
if annotation_frame_idx not in frame_to_boxes: if frame_idx % print_interval == 0:
out.write(frame) progress = (frame_idx / total_frames) * 100
frame_idx += 1 print(f"处理帧: {frame_idx}/{total_frames} ({progress:.1f}%)")
continue
# 获取当前帧的所有框信息
boxes_info = frame_to_boxes[annotation_frame_idx]
boxes = []
labels = []
for _, (x1, y1, x2, y2, label) in enumerate(boxes_info):
# 坐标
boxes.append([x1, y1, x2, y2])
# 使用索引作为唯一标识符,或直接使用 label
labels.append(label)
# 转换为 NumPy 数组
boxes_np = np.array(boxes, dtype=np.float64)
# 构建 Detections 对象
detections = sv.Detections(
xyxy=boxes_np,
confidence=np.ones(len(boxes_np)), # 默认置信度为 1.0
class_id=np.zeros(len(boxes_np), dtype=int) # 类别 ID 在此场景下不重要
)
detections.tracker_id = np.arange(len(boxes_np), dtype=int) # 使用索引作为 ID
# 绘制边框和标签
annotated_frame = box_annotator.annotate(scene=frame.copy(), detections=detections)
annotated_frame = label_annotator.annotate(
scene=annotated_frame,
detections=detections,
labels=labels
)
# 写入帧
out.write(annotated_frame)
frame_idx += 1 frame_idx += 1
# ------------------------------------------------- # 5. 释放所有写入器
# 6. 资源释放 for track_id, out in writers.items():
# ------------------------------------------------- out.release()
start, end = obj_ranges[track_id]
print(f"物体 {track_id} 完成 | 总帧数: {end - start + 1}")
# 6. 释放资源
cap.release() cap.release()
out.release()
cv2.destroyAllWindows() cv2.destroyAllWindows()
print(f"视频渲染完成,已保存至: {output_video_path}") print(f"\n✅ 所有物体视频生成完成!目录: {output_dir}")
def process_track_id( def process_track_id(
track_id: int, track_id: int,
@ -489,3 +494,4 @@ def create_mian_vid_for_ai(
print(f"已生成视频: {output_video_path}, 共 {current_output_frame}") print(f"已生成视频: {output_video_path}, 共 {current_output_frame}")
return output_video_path return output_video_path

View File

@ -10,7 +10,8 @@ class SAM3:
def __init__(self): def __init__(self):
self._data: dict = {} self._data: dict = {}
def run(self, vid_path: str, output_dir: str, class_file_path: str, interval: int = 1, conf: float = 0.6) -> dict: # def run(self, vid_path: str, output_dir: str, class_file_path: str, interval: int = 1, conf: float = 0.6) -> dict:
def run(self, vid_path: str, output_dir: str, class_file_path: str, conf: float = 0.6) -> dict:
""" """
运行 SAM3 模型返回视频中每个帧的检测结果 运行 SAM3 模型返回视频中每个帧的检测结果
参数: 参数:
@ -35,7 +36,7 @@ class SAM3:
, ,
} }
""" """
new_vid_path = extract_frames_to_video(vid_path, output_dir, interval) # new_vid_path = extract_frames_to_video(vid_path, output_dir, interval)
self._data = {} self._data = {}
text_dict = json.load(open(class_file_path, "r", encoding="utf-8")) text_dict = json.load(open(class_file_path, "r", encoding="utf-8"))
@ -46,7 +47,7 @@ class SAM3:
predictor = SAM3VideoSemanticPredictor(overrides=overrides) predictor = SAM3VideoSemanticPredictor(overrides=overrides)
results = predictor( results = predictor(
source=new_vid_path, source=vid_path,
text=text_list_en, text=text_list_en,
stream=True, stream=True,
) )
@ -96,13 +97,13 @@ class SAM3:
# annotated_frame = r.plot() # annotated_frame = r.plot()
frame_idx += 1 frame_idx += 1
with open(f"{output_dir}/frame_all.json", "w", encoding="utf-8") as f: with open(f"{output_dir}/frame_detections.json", "w", encoding="utf-8") as f:
json.dump(annotated_frames, f, ensure_ascii=False, indent=2) json.dump(annotated_frames, f, ensure_ascii=False, indent=2)
print(f"\n总共处理 {frame_idx}") print(f"\n总共处理 {frame_idx}")
print(f"结果已保存到 {output_dir}/frame_all.json") print(f"结果已保存到 {output_dir}/frame_detections.json")
print(f"标记图片已保存到 {output_dir}/boxes/") print(f"标记图片已保存到 {output_dir}/boxes/")
cv2.destroyAllWindows() cv2.destroyAllWindows()

View File

@ -13,9 +13,7 @@
{ {
"hazard_track_id": 0, "hazard_track_id": 0,
"tag_id": 0, "tag_id": 0,
"class_id": 0,
"level": 0, "level": 0,
"start_frame": 0,
"base_id": 0, "base_id": 0,
"location": "这里是隐患位置描述" "location": "这里是隐患位置描述"
}, },
@ -37,18 +35,14 @@
{ {
"hazard_track_id": 0, "hazard_track_id": 0,
"tag_id": 0, "tag_id": 0,
"class_id": 0,
"level": 0, "level": 0,
"start_frame": 0,
"base_id": 0, "base_id": 0,
"location": "这里是隐患位置描述" "location": "这里是隐患位置描述"
}, },
{ {
"hazard_track_id": 1, "hazard_track_id": 1,
"tag_id": 1, "tag_id": 1,
"class_id": 1,
"level": 1, "level": 1,
"start_frame": 0,
"base_id": 1, "base_id": 1,
"location": "这里是隐患位置描述" "location": "这里是隐患位置描述"
} }
@ -58,10 +52,9 @@
# 输出格式注意事项 # 输出格式注意事项
- 你的输出只能包含tag、base和objects三个键。 - 你的输出只能包含tag、base和objects三个键。
- tag是一个字符串数组每个元素是隐患标签必须为中文不能使用英文。每个隐患点只能有一个标签如果规则中存在多个标签必须选择最符合视频中情况的一个标签。 - tag是一个字符串数组每个元素是隐患标签必须为中文不能使用英文。每个隐患点只能有一个标签如果规则中存在多个标签必须选择最符合视频中情况的一个标签。
- objects是一个字典列表每个字典必须包含hazard_track_id整数、tag_id整数class_id整数level整数、start_frame整数、base_id整数、location字符串 - objects是一个字典列表每个字典必须包含hazard_track_id整数、tag_id整数、level整数、base_id整数、location字符串
- hazard_track_id分配规则根据视频画面每个连续出现的类型+隐患组合应分配单独的hazard_track_id。相同的tag_id+class_id组合在时间上不应相交应确保只存在一份连续的记录。绝对禁止出现多个hazard_track_id具有相同的tag_id和class_id且时间重叠的情况 - hazard_track_id分配规则根据视频画面每个隐患点只能分配一个hazard_track_id不能重复分配
- level必须为0或1或2不能为其他整数。为0表示隐患等级为疑似为1表示隐患等级为低为2表示隐患等级为高。 - level必须为0或1或2不能为其他整数。为0表示隐患等级为疑似为1表示隐患等级为低为2表示隐患等级为高。
- 不要输出bbox_2d、label或任何中文标点。
- 输出格式必须为标准json格式且结构必须与模板一致 - 输出格式必须为标准json格式且结构必须与模板一致
- class_id必须与class_list中的顺序严格一致保持一一映射关系。 - class_id必须与class_list中的顺序严格一致保持一一映射关系。
- 所有hazard_track_id都为独立隐患点不存在误检不得合并或拆分 - 所有hazard_track_id都为独立隐患点不存在误检不得合并或拆分
@ -72,16 +65,15 @@
- location表示该隐患在视频画面中的位置描述必须为中文描述要准确、清晰能够明确指出隐患在画面中的相对位置。 - location表示该隐患在视频画面中的位置描述必须为中文描述要准确、清晰能够明确指出隐患在画面中的相对位置。
# 任务1 # 任务1
- **帧级分析**:根据提供的物体class、出现时间点提供的数据为每帧对应的物体列表与隐患识别规则在视频中同track_id的物体需持续观察对隐患进行识别每个隐患点分配一个hazard_track_id杜绝在同一个物体上重复识别隐患点 - **帧级分析**:根据提供的物体名称与隐患识别规则,在视频中对隐患进行识别每个隐患点分配一个hazard_track_id杜绝在同一个物体上重复识别隐患点
- **汇总处理**在完成所有帧的分析后基于各帧的分析结果为每个hazard_track_id确定最终的隐患标签、等级、位置描述以及开始帧位置 - **汇总处理**在完成所有帧的分析后基于各帧的分析结果为每个hazard_track_id确定最终的隐患标签、等级、位置描述以及开始帧位置
- **基本要求**:只检测由class确定的物体每个hazard_track_id对应的字典中必须包含该hazard_track_id的tag_id、class_id、level、start_frame、base_id、location信息 - **基本要求**:只检测指定物体每个hazard_track_id对应的字典中必须包含该hazard_track_id的tag_id、level、base_id、location信息
- **匹配规则**如果物体与检测条目匹配就将该检测条目添加到objects列表中并设置相应的tag_id - **匹配规则**如果物体与检测条目匹配就将该检测条目添加到objects列表中并设置相应的tag_id
- **关键约束** - **关键约束**
1. **逐帧分析**必须分析每一帧中的每个物体用class匹配相应的检测规则根据隐患识别规则进行隐患识别 1. **语音识别**:必须对视频中的语音进行识别,辅助隐患识别
2. **语音识别**:必须对视频中的语音进行识别,辅助隐患识别 2. **规则参考**:严格参考知识库中的规则结构进行隐患识别,规则结构参考 `知识库/rule.json`
3. **规则参考**:严格参考知识库中的规则结构进行隐患识别,规则结构参考 `知识库/rule.json` 3. **全面识别**:必须对提供的物体进行隐患识别
4. **全面识别**:必须对所有提供的物体进行隐患识别,不得遗漏任何物体 4. **准确匹配**:根据物体名称与隐患识别规则进行准确匹配,确定隐患标签和等级
5. **准确匹配**根据物体的class与隐患识别规则进行准确匹配确定隐患标签和等级 5. **等级判定**根据规则中的匹配条件和依据合理判定匹配等级0-疑似1-确定)
6. **等级判定**根据规则中的匹配条件和依据合理判定匹配等级0-疑似1-确定) 6. **hazard_track_id分配**根据视频画面每个隐患点应分配单独的hazard_track_id。
7. **hazard_track_id分配**根据视频画面每个隐患点应分配单独的hazard_track_id。 7. **位置描述**大模型需在输出时提供隐患点相对于视频画面的位置location字段必须准确描述隐患在画面中的位置例如"画面左上角"、"画面中央偏右"等。
8. **位置描述**大模型需在输出时提供隐患点相对于视频画面的位置location字段必须准确描述隐患在画面中的位置例如"画面左上角"、"画面中央偏右"等。

View File

@ -1,3 +1,3 @@
### qwen试用 ### qwen试用
账号: 18237294717 账号: 18237294717 王凡的
https://bailian.console.aliyun.com/cn-beijing/?spm=a2c4g.11186623.0.0.2c2b2217h6fbio&tab=app#/app-center https://bailian.console.aliyun.com/cn-beijing/?spm=a2c4g.11186623.0.0.2c2b2217h6fbio&tab=app#/app-center

98
requirements.txt Normal file
View File

@ -0,0 +1,98 @@
aiofiles==24.1.0
annotated-doc==0.0.4
annotated-types==0.7.0
anyio==4.12.1
audioop-lts==0.2.2
brotli==1.2.0
certifi==2026.2.25
charset-normalizer==3.4.5
click==8.3.1
clip @ git+https://github.com/ultralytics/CLIP.git@88ade288431a46233f1556d1e141901b3ef0a36b
colorama==0.4.6
contourpy==1.3.3
cycler==0.12.1
defusedxml==0.7.1
distro==1.9.0
fastapi==0.135.1
ffmpy==1.0.0
filelock==3.25.1
fonttools==4.62.0
fsspec==2026.2.0
ftfy==6.3.1
gradio==6.9.0
gradio_client==2.3.0
groovy==0.1.2
h11==0.16.0
hf-xet==1.3.2
httpcore==1.0.9
httpx==0.28.1
huggingface_hub==1.6.0
idna==3.11
Jinja2==3.1.6
jiter==0.13.0
kiwisolver==1.5.0
lap==0.5.13
markdown-it-py==4.0.0
MarkupSafe==3.0.3
matplotlib==3.10.8
mdurl==0.1.2
MouseInfo==0.1.3
mpmath==1.3.0
networkx==3.6.1
numpy==2.4.3
openai==2.29.0
opencv-python==4.13.0.92
orjson==3.11.7
packaging @ file:///C:/miniconda3/conda-bld/packaging_1761049099114/work
pandas==3.0.1
pillow==12.1.1
polars==1.38.1
polars-runtime-32==1.38.1
psutil==7.2.2
PyAutoGUI==0.9.54
pydantic==2.12.5
pydantic_core==2.41.5
pyDeprecate==0.5.0
pydub==0.25.1
PyGetWindow==0.0.9
Pygments==2.19.2
PyMsgBox==2.0.1
pyparsing==3.3.2
pyperclip==1.11.0
PyRect==0.2.0
PyScreeze==1.0.1
python-dateutil==2.9.0.post0
python-multipart==0.0.22
pytweening==1.2.0
pytz==2026.1.post1
PyYAML==6.0.3
regex==2026.2.28
requests==2.32.5
rich==14.3.3
safehttpx==0.1.7
safetensors==0.7.0
scipy==1.17.1
semantic-version==2.10.0
setuptools==80.10.2
shellingham==1.5.4
six==1.17.0
sniffio==1.3.1
starlette==0.52.1
supervision==0.27.0.post2
sympy==1.14.0
timm==1.0.25
tomlkit==0.13.3
torch==2.7.1+cu118
torchaudio==2.7.1+cu118
torchvision==0.22.1+cu118
tqdm==4.67.3
typer==0.24.1
typing-inspection==0.4.2
typing_extensions==4.15.0
tzdata==2025.3
ultralytics==8.4.21
ultralytics-thop==2.0.18
urllib3==2.6.3
uvicorn==0.41.0
wcwidth==0.6.0
wheel==0.46.3