更新新流程
This commit is contained in:
parent
63a9cba36c
commit
67b78ec361
396
Qwen_app.py
396
Qwen_app.py
|
|
@ -1,5 +1,8 @@
|
||||||
|
from lib.json_fun import f_detections_to_objects, load_json_data
|
||||||
from lib.qwen_fun import get_annnotated_frame_for_ai_without_xyxy, hazard_inspection, merge_conflict_inspection_data, report_generator, search_knowledge_base
|
from lib.qwen_fun import get_annnotated_frame_for_ai_without_xyxy, hazard_inspection, merge_conflict_inspection_data, report_generator, search_knowledge_base
|
||||||
from encodings.punycode import T
|
from encodings.punycode import T
|
||||||
|
|
||||||
|
from lib.qwen_fun_vid import generate_video_to_objects
|
||||||
"""
|
"""
|
||||||
测试 在给定标注框与类别,原始视频经过转换之后,AI能否准确识别物体特征
|
测试 在给定标注框与类别,原始视频经过转换之后,AI能否准确识别物体特征
|
||||||
"""
|
"""
|
||||||
|
|
@ -17,7 +20,6 @@ from pathlib import Path
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
fps = 0
|
fps = 0
|
||||||
<<<<<<< HEAD
|
|
||||||
VIDEO_FOLDER: str = "input"
|
VIDEO_FOLDER: str = "input"
|
||||||
|
|
||||||
def load_file_list() -> list[str]:
|
def load_file_list() -> list[str]:
|
||||||
|
|
@ -37,6 +39,14 @@ def load_file_list() -> list[str]:
|
||||||
|
|
||||||
return file_names
|
return file_names
|
||||||
|
|
||||||
|
def reload_files():
|
||||||
|
"""
|
||||||
|
刷新文件列表并设置默认值
|
||||||
|
"""
|
||||||
|
file_list = load_file_list()
|
||||||
|
default_value = file_list[0] if file_list else None
|
||||||
|
return gr.update(choices=file_list, value=default_value)
|
||||||
|
|
||||||
def get_full_vid_path(vid_file: str) -> str:
|
def get_full_vid_path(vid_file: str) -> str:
|
||||||
"""
|
"""
|
||||||
获取视频绝对路径
|
获取视频绝对路径
|
||||||
|
|
@ -47,10 +57,6 @@ def get_full_vid_path(vid_file: str) -> str:
|
||||||
|
|
||||||
def update_preview(frame_idx: int, vid_file: str):
|
def update_preview(frame_idx: int, vid_file: str):
|
||||||
vid_name: str = Path(vid_file).stem
|
vid_name: str = Path(vid_file).stem
|
||||||
=======
|
|
||||||
|
|
||||||
def update_preview(frame_idx: int, vid_name: str):
|
|
||||||
>>>>>>> 7562de4 (2 预览图还有点问题)
|
|
||||||
output_dir: str = f"output/{vid_name}"
|
output_dir: str = f"output/{vid_name}"
|
||||||
global fps
|
global fps
|
||||||
idx = int(frame_idx // fps)
|
idx = int(frame_idx // fps)
|
||||||
|
|
@ -64,114 +70,222 @@ def update_preview(frame_idx: int, vid_name: str):
|
||||||
|
|
||||||
return img_path, class_tag
|
return img_path, class_tag
|
||||||
|
|
||||||
def run(vid_file: str, run_sam3: bool = True, run_inspection: bool = True, gen_report: bool = True,):
|
# def run(vid_file: str, run_sam3: bool = True, run_inspection: bool = True, gen_report: bool = True,):
|
||||||
#================初始化=================
|
# #================初始化=================
|
||||||
vid_name: str = Path(vid_file).stem
|
# vid_name: str = Path(vid_file).stem
|
||||||
vid_end: str = Path(vid_file).suffix
|
# vid_end: str = Path(vid_file).suffix
|
||||||
vid_path: str = f"./input/{vid_name}{vid_end}"
|
# vid_path: str = f"./input/{vid_name}{vid_end}"
|
||||||
|
|
||||||
cap = cv2.VideoCapture(vid_path)
|
# cap = cv2.VideoCapture(vid_path)
|
||||||
global fps
|
# global fps
|
||||||
fps = int(cap.get(cv2.CAP_PROP_FPS)) # 获取视频FPS
|
# fps = int(cap.get(cv2.CAP_PROP_FPS)) # 获取视频FPS
|
||||||
|
# cap.release()
|
||||||
|
|
||||||
|
# output_dir: str = f"output/{vid_name}"
|
||||||
|
# # interval = int(fps / 5)
|
||||||
|
# interval = fps
|
||||||
|
# conf = 0.7
|
||||||
|
# time_data: dict = {}
|
||||||
|
|
||||||
|
# annotated_frames: SAM3 = SAM3()
|
||||||
|
# result: dict = {}
|
||||||
|
|
||||||
|
# if output_dir:
|
||||||
|
# os.makedirs(output_dir, exist_ok=True) # exist_ok=True 防止重复创建报错
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# #================获取物体信息=================
|
||||||
|
# # 保存开始时间字符串
|
||||||
|
# time_data["start_time"] = str(datetime.now())
|
||||||
|
# with open(f"{output_dir}/time.json", "w") as f:
|
||||||
|
# f.write(json.dumps(time_data, ensure_ascii=False, indent=4))
|
||||||
|
|
||||||
|
# if run_sam3:
|
||||||
|
# # 针对厂房防火分区
|
||||||
|
# annotated_frames.run(vid_path, output_dir, "lib/class_list/1.厂房防火.json", interval, conf)
|
||||||
|
# else:
|
||||||
|
# annotated_frames.load_from_json(f"{output_dir}/frame_all.json")
|
||||||
|
# # print(annotated_frames.data)
|
||||||
|
|
||||||
|
# # 提取ai能看到的部分
|
||||||
|
# ai_frames: dict = get_annnotated_frame_for_ai_without_xyxy(annotated_frames.data(), 1, conf)
|
||||||
|
# save_json_to_file(ai_frames, f"{output_dir}/frame_all_ai.json")
|
||||||
|
|
||||||
|
# #================隐患检查=================
|
||||||
|
|
||||||
|
# if run_inspection:
|
||||||
|
# if vid_path.startswith("oss"):
|
||||||
|
# video_url = vid_path
|
||||||
|
# else:
|
||||||
|
# video_url: str|None = None
|
||||||
|
# video_url_file = f"{output_dir}/video_url.json"
|
||||||
|
|
||||||
|
# # 检查URL是否已存在
|
||||||
|
# if os.path.exists(video_url_file):
|
||||||
|
# try:
|
||||||
|
# with open(video_url_file, "r", encoding="utf-8") as f:
|
||||||
|
# url_data = json.load(f)
|
||||||
|
# if vid_name in url_data:
|
||||||
|
# video_url = url_data[vid_name]
|
||||||
|
# print(f"使用已存在的URL: {video_url}")
|
||||||
|
# except Exception as e:
|
||||||
|
# print(f"读取URL文件失败: {e}")
|
||||||
|
|
||||||
|
# # 如果URL不存在,上传文件
|
||||||
|
# if video_url is None:
|
||||||
|
# print(f"上传视频文件: {vid_path}")
|
||||||
|
# video_url = upload_files_and_get_urls_concurrently(
|
||||||
|
# file_path_list=[vid_path],
|
||||||
|
# max_workers=8
|
||||||
|
# )[0]
|
||||||
|
# if video_url:
|
||||||
|
# # 保存为JSON格式,包含文件名和URL的键值对
|
||||||
|
# url_data = {}
|
||||||
|
# if os.path.exists(video_url_file):
|
||||||
|
# try:
|
||||||
|
# with open(video_url_file, "r", encoding="utf-8") as f:
|
||||||
|
# url_data = json.load(f)
|
||||||
|
# except:
|
||||||
|
# pass
|
||||||
|
|
||||||
|
# url_data[vid_name] = video_url
|
||||||
|
# with open(video_url_file, "w", encoding="utf-8") as f:
|
||||||
|
# json.dump(url_data, f, ensure_ascii=False, indent=4)
|
||||||
|
# print(f"URL已保存: {video_url}")
|
||||||
|
|
||||||
|
# if video_url is None:
|
||||||
|
# raise ValueError("视频上传失败,无法获取 URL")
|
||||||
|
|
||||||
|
# result_test: str
|
||||||
|
# reason_test: str
|
||||||
|
# reason_test, result_test = hazard_inspection(ai_frames, video_url, enable_thinking=True, fps=2)
|
||||||
|
# result = json.loads(result_test)
|
||||||
|
# # merged_result = merge_conflict_inspection_data(result)
|
||||||
|
|
||||||
|
# result["class"] = ai_frames["class_list"]
|
||||||
|
|
||||||
|
# with open(f"{output_dir}/hazard_inspection.json", "w", encoding="utf-8") as f:
|
||||||
|
# f.write(json.dumps(result, ensure_ascii=False, indent=4))
|
||||||
|
# with open(f"{output_dir}/hazard_inspection_reason.json", "w", encoding="utf-8") as f:
|
||||||
|
# f.write(json.dumps(reason_test, ensure_ascii=False, indent=4))
|
||||||
|
|
||||||
|
# # with open(f"{output_dir}/hazard_inspection_merged.json", "w", encoding="utf-8") as f:
|
||||||
|
# # f.write(json.dumps(merged_result, ensure_ascii=False, indent=4))
|
||||||
|
|
||||||
|
# #================生成报告=================
|
||||||
|
# if gen_report:
|
||||||
|
# if result == {}:
|
||||||
|
# with open(f"{output_dir}/hazard_inspection.json", "r", encoding="utf-8") as f:
|
||||||
|
# result = json.load(f)
|
||||||
|
|
||||||
|
# with open(f"知识库/rule.json", "r", encoding="utf-8") as f:
|
||||||
|
# rule_definitions = json.load(f)
|
||||||
|
|
||||||
|
# report_generator(
|
||||||
|
# video_path=vid_path, # 视频文件路径
|
||||||
|
# detection_data=annotated_frames.data(), # 物体检测数据
|
||||||
|
# hazard_results=result, # 隐患检查结果
|
||||||
|
# rule_definitions=rule_definitions, # 规则定义
|
||||||
|
# output_path=output_dir, # 输出文件夹
|
||||||
|
# frame_interval=interval, # 帧间隔(可根据实际视频帧率调整)
|
||||||
|
# )
|
||||||
|
|
||||||
|
# # 保存结束时间字符串
|
||||||
|
# time_data["end_time"] = str(datetime.now())
|
||||||
|
# with open(f"{output_dir}/time.json", "w") as f:
|
||||||
|
# f.write(json.dumps(time_data, ensure_ascii=False, indent=4))
|
||||||
|
|
||||||
|
# #================更新预览=================
|
||||||
|
|
||||||
|
# # 获取总帧数
|
||||||
|
# cap = cv2.VideoCapture(vid_path)
|
||||||
|
# total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||||
|
# cap.release()
|
||||||
|
# img_path, class_tag = update_preview(0, vid_name)
|
||||||
|
# return gr.update(maximum=total_frames-1, value=0, step = interval), img_path, class_tag
|
||||||
|
|
||||||
|
def run(vid_file: str, run_sam3: bool = True, run_inspection: bool = True, gen_report: bool = True):
|
||||||
|
vid_name = vid_file.split(".")[0] # 视频名称(不含后缀)
|
||||||
|
vid_end = f".{vid_file.split('.')[-1]}" # 视频后缀
|
||||||
|
use_url_cache = True # 是否使用 URL 缓存,避免重复上传视频
|
||||||
|
enable_thinking = True # 是否启用思考模式
|
||||||
|
|
||||||
|
run_vid_process = True # 是否运行视频处理流程(提取物体视频)
|
||||||
|
|
||||||
|
input_video_path = f"input/{vid_name}{vid_end}"
|
||||||
|
output_dir = f"output/{vid_name}"
|
||||||
|
frame_detections_path = f"{output_dir}/frame_detections.json"
|
||||||
|
objects_json_path = f"{output_dir}/objects.json"
|
||||||
|
vid_dir = f"{output_dir}/obj_vids"
|
||||||
|
|
||||||
|
rule_dict: dict = load_json_data('知识库/rule.json')
|
||||||
|
video_url_file = f"{output_dir}/video_url.json"
|
||||||
|
time_data: dict = {
|
||||||
|
# 保存开始时间字符串
|
||||||
|
"start_time": str(datetime.now())
|
||||||
|
}
|
||||||
|
interval = 1
|
||||||
|
annotated_frames: SAM3 = SAM3()
|
||||||
|
|
||||||
|
# 获取总帧数与帧率
|
||||||
|
cap = cv2.VideoCapture(input_video_path)
|
||||||
|
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||||
|
fps = cap.get(cv2.CAP_PROP_FPS)
|
||||||
cap.release()
|
cap.release()
|
||||||
|
|
||||||
output_dir: str = f"output/{vid_name}"
|
|
||||||
# interval = int(fps / 5)
|
|
||||||
interval = fps
|
|
||||||
conf = 0.7
|
|
||||||
time_data: dict = {}
|
|
||||||
|
|
||||||
annotated_frames: SAM3 = SAM3()
|
|
||||||
result: dict = {}
|
|
||||||
|
|
||||||
if output_dir:
|
if output_dir:
|
||||||
os.makedirs(output_dir, exist_ok=True) # exist_ok=True 防止重复创建报错
|
os.makedirs(output_dir, exist_ok=True) # exist_ok=True 防止重复创建报错
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#================获取物体信息=================
|
|
||||||
# 保存开始时间字符串
|
|
||||||
time_data["start_time"] = str(datetime.now())
|
|
||||||
with open(f"{output_dir}/time.json", "w") as f:
|
with open(f"{output_dir}/time.json", "w") as f:
|
||||||
f.write(json.dumps(time_data, ensure_ascii=False, indent=4))
|
f.write(json.dumps(time_data, ensure_ascii=False, indent=4))
|
||||||
|
|
||||||
if run_sam3:
|
#================获取物体信息=================
|
||||||
|
#这里可以换成你的yolo运行逻辑,输出文件保存在frame_detections_path
|
||||||
|
|
||||||
|
if run_sam3: # 开关判定
|
||||||
# 针对厂房防火分区
|
# 针对厂房防火分区
|
||||||
annotated_frames.run(vid_path, output_dir, "lib/class_list/1.厂房防火.json", interval, conf)
|
annotated_frames.run(input_video_path, output_dir, "lib/class_list/1.厂房防火.json")
|
||||||
else:
|
else: # 不运行sam3,直接加载之前的结果
|
||||||
annotated_frames.load_from_json(f"{output_dir}/frame_all.json")
|
annotated_frames.load_from_json(frame_detections_path)
|
||||||
# print(annotated_frames.data)
|
# print(annotated_frames.data)
|
||||||
|
|
||||||
# 提取ai能看到的部分
|
|
||||||
ai_frames: dict = get_annnotated_frame_for_ai_without_xyxy(annotated_frames.data(), 1, conf)
|
|
||||||
save_json_to_file(ai_frames, f"{output_dir}/frame_all_ai.json")
|
|
||||||
|
|
||||||
#================隐患检查=================
|
#================隐患检查=================
|
||||||
|
if run_inspection: # 开关判定
|
||||||
|
|
||||||
if run_inspection:
|
|
||||||
if vid_path.startswith("oss"):
|
|
||||||
video_url = vid_path
|
|
||||||
else:
|
|
||||||
video_url: str|None = None
|
|
||||||
video_url_file = f"{output_dir}/video_url.json"
|
|
||||||
|
|
||||||
# 检查URL是否已存在
|
# 提取物体信息
|
||||||
if os.path.exists(video_url_file):
|
f_detections_to_objects(
|
||||||
try:
|
frame_detections_path,
|
||||||
with open(video_url_file, "r", encoding="utf-8") as f:
|
objects_json_path
|
||||||
url_data = json.load(f)
|
)
|
||||||
if vid_name in url_data:
|
obj = json.load(open(objects_json_path, "r", encoding="utf-8"))
|
||||||
video_url = url_data[vid_name]
|
class_list = obj["class_list"]
|
||||||
print(f"使用已存在的URL: {video_url}")
|
obj_dict = obj["track_id_list"]
|
||||||
except Exception as e:
|
|
||||||
print(f"读取URL文件失败: {e}")
|
|
||||||
|
|
||||||
# 如果URL不存在,上传文件
|
if run_vid_process:
|
||||||
if video_url is None:
|
# 生成物体视频
|
||||||
print(f"上传视频文件: {vid_path}")
|
generate_video_to_objects(
|
||||||
video_url = upload_files_and_get_urls_concurrently(
|
obj_dict,
|
||||||
file_path_list=[vid_path],
|
input_video_path,
|
||||||
max_workers=8
|
output_dir=vid_dir,
|
||||||
)[0]
|
)
|
||||||
if video_url:
|
|
||||||
# 保存为JSON格式,包含文件名和URL的键值对
|
|
||||||
url_data = {}
|
|
||||||
if os.path.exists(video_url_file):
|
|
||||||
try:
|
|
||||||
with open(video_url_file, "r", encoding="utf-8") as f:
|
|
||||||
url_data = json.load(f)
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
url_data[vid_name] = video_url
|
# 上传视频并获取 URL
|
||||||
with open(video_url_file, "w", encoding="utf-8") as f:
|
vid_dict = get_vid_dict(vid_dir, obj_dict, video_url_file, use_url_cache)
|
||||||
json.dump(url_data, f, ensure_ascii=False, indent=4)
|
|
||||||
print(f"URL已保存: {video_url}")
|
|
||||||
|
|
||||||
if video_url is None:
|
hazard_inspection(
|
||||||
raise ValueError("视频上传失败,无法获取 URL")
|
output_dir,
|
||||||
|
obj_dict,
|
||||||
<<<<<<< HEAD
|
rule_dict,
|
||||||
result_test: str
|
class_list,
|
||||||
reason_test: str
|
vid_dict,
|
||||||
reason_test, result_test = hazard_inspection(ai_frames, video_url, enable_thinking=True, fps=2)
|
fps=fps,
|
||||||
result = json.loads(result_test)
|
enable_thinking=enable_thinking
|
||||||
=======
|
)
|
||||||
result = json.loads(hazard_inspection(ai_frames, video_url, enable_thinking=False, fps=2)[1])
|
|
||||||
>>>>>>> 7562de4 (2 预览图还有点问题)
|
|
||||||
# merged_result = merge_conflict_inspection_data(result)
|
|
||||||
|
|
||||||
result["class"] = ai_frames["class_list"]
|
|
||||||
|
|
||||||
with open(f"{output_dir}/hazard_inspection.json", "w", encoding="utf-8") as f:
|
|
||||||
f.write(json.dumps(result, ensure_ascii=False, indent=4))
|
|
||||||
with open(f"{output_dir}/hazard_inspection_reason.json", "w", encoding="utf-8") as f:
|
|
||||||
f.write(json.dumps(reason_test, ensure_ascii=False, indent=4))
|
|
||||||
|
|
||||||
# with open(f"{output_dir}/hazard_inspection_merged.json", "w", encoding="utf-8") as f:
|
|
||||||
# f.write(json.dumps(merged_result, ensure_ascii=False, indent=4))
|
|
||||||
|
|
||||||
#================生成报告=================
|
#================生成报告=================
|
||||||
|
|
||||||
if gen_report:
|
if gen_report:
|
||||||
if result == {}:
|
if result == {}:
|
||||||
with open(f"{output_dir}/hazard_inspection.json", "r", encoding="utf-8") as f:
|
with open(f"{output_dir}/hazard_inspection.json", "r", encoding="utf-8") as f:
|
||||||
|
|
@ -181,7 +295,7 @@ def run(vid_file: str, run_sam3: bool = True, run_inspection: bool = True, gen_r
|
||||||
rule_definitions = json.load(f)
|
rule_definitions = json.load(f)
|
||||||
|
|
||||||
report_generator(
|
report_generator(
|
||||||
video_path=vid_path, # 视频文件路径
|
video_path=input_video_path, # 视频文件路径
|
||||||
detection_data=annotated_frames.data(), # 物体检测数据
|
detection_data=annotated_frames.data(), # 物体检测数据
|
||||||
hazard_results=result, # 隐患检查结果
|
hazard_results=result, # 隐患检查结果
|
||||||
rule_definitions=rule_definitions, # 规则定义
|
rule_definitions=rule_definitions, # 规则定义
|
||||||
|
|
@ -196,13 +310,61 @@ def run(vid_file: str, run_sam3: bool = True, run_inspection: bool = True, gen_r
|
||||||
|
|
||||||
#================更新预览=================
|
#================更新预览=================
|
||||||
|
|
||||||
# 获取总帧数
|
|
||||||
cap = cv2.VideoCapture(vid_path)
|
|
||||||
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
|
||||||
cap.release()
|
|
||||||
img_path, class_tag = update_preview(0, vid_name)
|
img_path, class_tag = update_preview(0, vid_name)
|
||||||
return gr.update(maximum=total_frames-1, value=0, step = interval), img_path, class_tag
|
return gr.update(maximum=total_frames-1, value=0, step = interval), img_path, class_tag
|
||||||
|
|
||||||
|
def get_vid_dict(vid_dir: str, obj_dict: dict, video_url_file: str, use_url_cache: bool) -> dict:
|
||||||
|
"""获取视频字典
|
||||||
|
参数:
|
||||||
|
vid_dir: 物体视频存放目录
|
||||||
|
obj_dict: 物品字典,键为物品TrackID,值为物品信息
|
||||||
|
video_url_file: 视频URL文件路径
|
||||||
|
use_url_cache: 是否使用缓存的URL
|
||||||
|
返回:
|
||||||
|
视频字典,键为物品TrackID,值为视频本地地址和视频URL
|
||||||
|
"""
|
||||||
|
|
||||||
|
vid_dict = {}
|
||||||
|
|
||||||
|
# 检查URL是否已存在
|
||||||
|
if use_url_cache:
|
||||||
|
if os.path.exists(video_url_file):
|
||||||
|
try:
|
||||||
|
with open(video_url_file, "r", encoding="utf-8") as f:
|
||||||
|
vid_dict = json.load(f)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"读取URL文件失败: {e}")
|
||||||
|
for track_id in obj_dict.keys():
|
||||||
|
if track_id in vid_dict:
|
||||||
|
continue
|
||||||
|
vid_path = f"{vid_dir}/obj_{int(track_id):03d}.mp4"
|
||||||
|
vid_dict[track_id] = {"vid_path": vid_path, "vid_url": None}
|
||||||
|
|
||||||
|
# 找到未上传的视频
|
||||||
|
unuploaded_vids = []
|
||||||
|
for track_id, vid_info in vid_dict.items():
|
||||||
|
if vid_info["vid_url"] is None:
|
||||||
|
unuploaded_vids.append(vid_info["vid_path"])
|
||||||
|
|
||||||
|
# 上传未上传的视频并获取 URL
|
||||||
|
uploaded_urls = upload_files_and_get_urls_concurrently(
|
||||||
|
file_path_list=unuploaded_vids,
|
||||||
|
max_workers=8
|
||||||
|
)
|
||||||
|
|
||||||
|
# 更新视频URL
|
||||||
|
for track_id, vid_info in vid_dict.items():
|
||||||
|
if vid_info["vid_url"] is None:
|
||||||
|
idx = unuploaded_vids.index(vid_info["vid_path"])
|
||||||
|
vid_info["vid_url"] = uploaded_urls[idx]
|
||||||
|
|
||||||
|
# 保存字典为 JSON 文件
|
||||||
|
with open(video_url_file, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(vid_dict, f, ensure_ascii=False, indent=4)
|
||||||
|
|
||||||
|
return vid_dict
|
||||||
|
|
||||||
|
|
||||||
# 创建 Gradio 页面
|
# 创建 Gradio 页面
|
||||||
with gr.Blocks() as demo:
|
with gr.Blocks() as demo:
|
||||||
|
|
@ -210,7 +372,7 @@ with gr.Blocks() as demo:
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
<<<<<<< HEAD
|
|
||||||
initial_files: list[str] = load_file_list()
|
initial_files: list[str] = load_file_list()
|
||||||
default_video = initial_files[0] if initial_files else None
|
default_video = initial_files[0] if initial_files else None
|
||||||
vid_file = gr.Dropdown(
|
vid_file = gr.Dropdown(
|
||||||
|
|
@ -219,10 +381,8 @@ with gr.Blocks() as demo:
|
||||||
value=default_video, # 默认选中第一个
|
value=default_video, # 默认选中第一个
|
||||||
interactive=True,
|
interactive=True,
|
||||||
)
|
)
|
||||||
=======
|
reload_file_list_button = gr.Button("刷新视频列表")
|
||||||
vid_name = gr.Textbox(label="视频名称", value="Miehhuoxqih")
|
reload_file_list_button.click(fn=reload_files, inputs=[], outputs=[vid_file])
|
||||||
vid_end = gr.Textbox(label="视频后缀", value=".AVI")
|
|
||||||
>>>>>>> 7562de4 (2 预览图还有点问题)
|
|
||||||
|
|
||||||
run_sam3 = gr.Checkbox(label="1. 运行 SAM3 模型", value=True)
|
run_sam3 = gr.Checkbox(label="1. 运行 SAM3 模型", value=True)
|
||||||
run_inspection = gr.Checkbox(label="2. 运行隐患排查", value=True)
|
run_inspection = gr.Checkbox(label="2. 运行隐患排查", value=True)
|
||||||
|
|
@ -251,7 +411,7 @@ def get_class_tag_by_frame(data, idx, fps):
|
||||||
根据给定的帧索引 (idx),返回在该帧范围内的所有对象的 class:tag 信息。
|
根据给定的帧索引 (idx),返回在该帧范围内的所有对象的 class:tag 信息。
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
data (dict): 包含 'class', 'tag', 'objects' 的字典数据
|
data (dict): 包含 'tag', 'objects' 的字典数据(hazard_inspection.json格式)
|
||||||
idx (int): 需要查询的帧索引
|
idx (int): 需要查询的帧索引
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
|
|
@ -260,10 +420,9 @@ def get_class_tag_by_frame(data, idx, fps):
|
||||||
# 参数检查
|
# 参数检查
|
||||||
if not isinstance(data, dict):
|
if not isinstance(data, dict):
|
||||||
raise ValueError("数据必须是字典类型")
|
raise ValueError("数据必须是字典类型")
|
||||||
if 'class' not in data or 'tag' not in data or 'objects' not in data:
|
if 'tag' not in data or 'objects' not in data:
|
||||||
raise ValueError("数据必须包含 'class', 'tag' 和 'objects' 键")
|
raise ValueError("数据必须包含 'tag' 和 'objects' 键")
|
||||||
|
|
||||||
class_list = data['class']
|
|
||||||
tag_list = data['tag']
|
tag_list = data['tag']
|
||||||
objects = data['objects']
|
objects = data['objects']
|
||||||
|
|
||||||
|
|
@ -276,17 +435,21 @@ def get_class_tag_by_frame(data, idx, fps):
|
||||||
|
|
||||||
# 遍历每个物体
|
# 遍历每个物体
|
||||||
for obj in objects:
|
for obj in objects:
|
||||||
# 根据 class_id 和 tag_id 获取对应的字符串
|
# 根据 tag_id 获取对应的隐患标签字符串
|
||||||
class_str = class_list[obj['class_id']]
|
tag_id = obj.get('tag_id', 0)
|
||||||
tag_str = tag_list[obj['tag_id']]
|
class_id = obj.get('class_id', 0)
|
||||||
|
tag_str = tag_list[tag_id] if tag_id < len(tag_list) else f"未知标签({tag_id})"
|
||||||
location = obj.get('location', '')
|
location = obj.get('location', '')
|
||||||
|
start_frame = obj.get('start_frame', 0)
|
||||||
|
level = obj.get('level', '')
|
||||||
|
|
||||||
# 检查帧范围是否包含 idx
|
# 检查帧范围是否包含 idx
|
||||||
if obj['start_frame'] == idx:
|
if start_frame == idx:
|
||||||
result.append(f"{class_str}:{tag_str} (位置: {location})")
|
result.append(f"{tag_str} | 等级:{level} | 位置: {location}")
|
||||||
all_class_tag.append(f"{class_str}:{tag_str} (开始帧: {obj['start_frame']*interval}, 位置: {location})")
|
all_class_tag.append(f"{tag_str} | class_id:{class_id} | 等级:{level} | 开始帧:{start_frame} | 位置:{location}")
|
||||||
|
|
||||||
# 使用换行符连接所有结果
|
# 使用换行符连接所有结果
|
||||||
output = f"当前帧隐患:\n"+"\n".join(result)+"\n\n"+"所有对象的 class:tag 信息:\n"+"\n".join(all_class_tag)
|
output = f"当前帧隐患:\n"+"\n".join(result)+"\n\n"+"所有隐患对象信息:\n"+"\n".join(all_class_tag)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -295,7 +458,4 @@ if __name__ == "__main__":
|
||||||
demo.launch(
|
demo.launch(
|
||||||
debug=True,
|
debug=True,
|
||||||
allowed_paths=[VIDEO_FOLDER]
|
allowed_paths=[VIDEO_FOLDER]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
142
Qwen_cmd test.py
142
Qwen_cmd test.py
|
|
@ -1,4 +1,7 @@
|
||||||
from lib.qwen_fun import hazard_inspection, search_knowledge_base
|
import os
|
||||||
|
|
||||||
|
from lib.json_fun import f_detections_to_objects
|
||||||
|
from lib.qwen_fun import chat, hazard_inspection, load_json_data, search_knowledge_base, upload_files_and_get_urls_concurrently
|
||||||
from encodings.punycode import T
|
from encodings.punycode import T
|
||||||
"""
|
"""
|
||||||
测试 在给定标注框与类别,原始视频经过转换之后,AI能否准确识别物体特征
|
测试 在给定标注框与类别,原始视频经过转换之后,AI能否准确识别物体特征
|
||||||
|
|
@ -7,59 +10,126 @@ from encodings.punycode import T
|
||||||
from tkinter import N
|
from tkinter import N
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import json
|
import json
|
||||||
import os
|
from lib.qwen_fun_vid import generate_video_to_objects
|
||||||
from lib.qwen_fun import chat, get_annnotated_frame_for_ai, get_unique_track_id_count, save_json_to_file, search_knowledge_base, upload_files_and_get_urls_concurrently
|
|
||||||
from lib.qwen_fun_vid import create_mian_vid_for_ai, frame_all_to_obj_vid
|
|
||||||
from lib.sam3 import SAM3
|
from lib.sam3 import SAM3
|
||||||
from ultralytics.models.sam import SAM3VideoSemanticPredictor
|
from ultralytics.models.sam import SAM3VideoSemanticPredictor
|
||||||
import cv2
|
|
||||||
import json
|
import json
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
def run(output_dir, vid_path, interval: int = 5):
|
def run():
|
||||||
if output_dir:
|
vid_name = "santai5" # 视频名称(不含后缀)
|
||||||
os.makedirs(output_dir, exist_ok=True) # exist_ok=True 防止重复创建报错
|
vid_end = ".mp4" # 视频后缀
|
||||||
|
use_url_cache = True # 是否使用 URL 缓存,避免重复上传视频
|
||||||
|
enable_thinking = False # 是否启用思考模式
|
||||||
|
|
||||||
annotated_frames: SAM3 = SAM3()
|
run_vid_process = False # 是否运行视频处理流程(提取物体视频)
|
||||||
annotated_frames.run(vid_path, output_dir, "lib/class_list/xiaofangz.json")
|
|
||||||
# annotated_frames.load_from_json(f"{output_dir}/frame_all.json")
|
|
||||||
# print(annotated_frames.data)
|
|
||||||
|
|
||||||
# 提取ai能看到的部分
|
input_video_path = f"input/{vid_name}{vid_end}"
|
||||||
ai_frames: dict = get_annnotated_frame_for_ai(annotated_frames.data(), interval)
|
output_dir = f"output/{vid_name}"
|
||||||
save_json_to_file(ai_frames, f"{output_dir}/frame_all_ai.json")
|
frame_detections_path = f"{output_dir}/frame_detections.json"
|
||||||
|
objects_json_path = f"{output_dir}/objects.json"
|
||||||
|
vid_dir = f"{output_dir}/obj_vids"
|
||||||
|
obj = json.load(open(objects_json_path, "r", encoding="utf-8"))
|
||||||
|
class_list = obj["class_list"]
|
||||||
|
obj_dict = obj["track_id_list"]
|
||||||
|
rule_dict: dict = load_json_data('知识库/rule.json')
|
||||||
|
video_url_file = f"{output_dir}/video_url.json"
|
||||||
|
|
||||||
|
# 提取物体信息
|
||||||
|
f_detections_to_objects(
|
||||||
|
frame_detections_path,
|
||||||
|
objects_json_path
|
||||||
|
)
|
||||||
|
|
||||||
if vid_path.startswith("oss"):
|
if run_vid_process:
|
||||||
video_url = vid_path
|
# 生成物体视频
|
||||||
else:
|
generate_video_to_objects(
|
||||||
video_url: str|None = upload_files_and_get_urls_concurrently(
|
obj_dict,
|
||||||
file_path_list=[vid_path],
|
input_video_path,
|
||||||
max_workers=8
|
output_dir=vid_dir,
|
||||||
)[0]
|
)
|
||||||
if video_url:
|
|
||||||
with open(f"{output_dir}/video_url.json", "w") as f:
|
|
||||||
f.write(video_url)
|
|
||||||
|
|
||||||
if video_url is None:
|
# 上传视频并获取 URL
|
||||||
raise ValueError("视频上传失败,无法获取 URL")
|
vid_dict = get_vid_dict(vid_dir, obj_dict, video_url_file, use_url_cache)
|
||||||
|
|
||||||
result = hazard_inspection(ai_frames, video_url)[1]
|
hazard_inspection(
|
||||||
|
output_dir,
|
||||||
|
obj_dict,
|
||||||
|
rule_dict,
|
||||||
|
class_list,
|
||||||
|
vid_dict,
|
||||||
|
fps=30,
|
||||||
|
enable_thinking=enable_thinking
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_vid_dict(vid_dir: str, obj_dict: dict, video_url_file: str, use_url_cache: bool) -> dict:
|
||||||
|
"""获取视频字典
|
||||||
|
参数:
|
||||||
|
vid_dir: 物体视频存放目录
|
||||||
|
obj_dict: 物品字典,键为物品TrackID,值为物品信息
|
||||||
|
video_url_file: 视频URL文件路径
|
||||||
|
use_url_cache: 是否使用缓存的URL
|
||||||
|
返回:
|
||||||
|
视频字典,键为物品TrackID,值为视频本地地址和视频URL
|
||||||
|
"""
|
||||||
|
|
||||||
|
vid_dict = {}
|
||||||
|
|
||||||
|
# 检查URL是否已存在
|
||||||
|
if use_url_cache:
|
||||||
|
if os.path.exists(video_url_file):
|
||||||
|
try:
|
||||||
|
with open(video_url_file, "r", encoding="utf-8") as f:
|
||||||
|
vid_dict = json.load(f)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"读取URL文件失败: {e}")
|
||||||
|
for track_id in obj_dict.keys():
|
||||||
|
if track_id in vid_dict:
|
||||||
|
continue
|
||||||
|
vid_path = f"{vid_dir}/obj_{int(track_id):03d}.mp4"
|
||||||
|
vid_dict[track_id] = {"vid_path": vid_path, "vid_url": None}
|
||||||
|
|
||||||
|
# 找到未上传的视频
|
||||||
|
unuploaded_vids = []
|
||||||
|
for track_id, vid_info in vid_dict.items():
|
||||||
|
if vid_info["vid_url"] is None:
|
||||||
|
unuploaded_vids.append(vid_info["vid_path"])
|
||||||
|
|
||||||
|
# 上传未上传的视频并获取 URL
|
||||||
|
uploaded_urls = upload_files_and_get_urls_concurrently(
|
||||||
|
file_path_list=unuploaded_vids,
|
||||||
|
max_workers=8
|
||||||
|
)
|
||||||
|
|
||||||
|
# 更新视频URL
|
||||||
|
for track_id, vid_info in vid_dict.items():
|
||||||
|
if vid_info["vid_url"] is None:
|
||||||
|
idx = unuploaded_vids.index(vid_info["vid_path"])
|
||||||
|
vid_info["vid_url"] = uploaded_urls[idx]
|
||||||
|
|
||||||
|
# 保存字典为 JSON 文件
|
||||||
|
with open(video_url_file, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(vid_dict, f, ensure_ascii=False, indent=4)
|
||||||
|
|
||||||
|
return vid_dict
|
||||||
|
|
||||||
with open(f"{output_dir}/hazard_inspection.json", "w", encoding="utf-8") as f:
|
|
||||||
f.write(result)
|
|
||||||
|
|
||||||
# 启动应用
|
# 启动应用
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
VID_NAME: str = "Peihdianhxiang"
|
# objects_json_path = "output/santai5/objects.json"
|
||||||
VID_END: str = ".mp4"
|
# input_video_path = "input/santai5.mp4"
|
||||||
interval: int = int(30 / 5)
|
# obj_dict = json.load(open(objects_json_path, "r", encoding="utf-8"))["track_id_list"]
|
||||||
|
|
||||||
output_dir: str = f"output/{VID_NAME}"
|
# # 记录开始时间
|
||||||
vid_path: str = f"./input/{VID_NAME}{VID_END}"
|
# start_time = datetime.now()
|
||||||
|
# generate_video_to_objects(
|
||||||
|
# obj_dict,
|
||||||
|
# input_video_path,
|
||||||
|
# output_dir="output/santai5/obj_vids",
|
||||||
|
# )
|
||||||
|
|
||||||
run(output_dir, vid_path, interval)
|
run()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -8,21 +8,53 @@ def f_detections_to_objects(detections_path: str, output_path: str):
|
||||||
:param detections_path: 输入文件路径
|
:param detections_path: 输入文件路径
|
||||||
:param output_path: 输出文件路径
|
:param output_path: 输出文件路径
|
||||||
:return: None
|
:return: None
|
||||||
|
|
||||||
|
detections 格式:
|
||||||
|
{
|
||||||
|
"frame_id": [
|
||||||
|
{
|
||||||
|
"xyxy": [x1, y1, x2, y2]
|
||||||
|
"confidence": 0.314,
|
||||||
|
"track_id": 1,
|
||||||
|
"class_id": 1,
|
||||||
|
"class_str": "class_str",
|
||||||
|
},
|
||||||
|
...
|
||||||
|
],
|
||||||
|
...
|
||||||
|
}
|
||||||
|
|
||||||
|
objects 格式:
|
||||||
|
{
|
||||||
|
"class_list": [
|
||||||
|
"插座",
|
||||||
|
"消火栓",
|
||||||
|
"配电箱"
|
||||||
|
],
|
||||||
|
"track_id_list": {
|
||||||
|
"0": {
|
||||||
|
"class_id": 4,
|
||||||
|
"start_frame": 551,
|
||||||
|
"end_frame": 597
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
"""
|
"""
|
||||||
with open(detections_path, 'r', encoding='utf-8') as f:
|
with open(detections_path, 'r', encoding='utf-8') as f:
|
||||||
detections_data = json.load(f)
|
detections_data = json.load(f)
|
||||||
|
|
||||||
class_set = set()
|
class_list = []
|
||||||
track_info = defaultdict(lambda: {"class_id": None, "start_frame": float('inf'), "end_frame": 0})
|
track_info = defaultdict(lambda: {"class_id": None, "start_frame": float('inf'), "end_frame": 0})
|
||||||
|
|
||||||
for frame_str, detections in detections_data.items():
|
for frame_str, detections in detections_data.items():
|
||||||
frame = int(frame_str)
|
frame = int(frame_str)
|
||||||
for det in detections:
|
for det in detections:
|
||||||
track_id = det["track_id"]
|
track_id = det["track_id"]
|
||||||
class_id = det["class_id"]
|
|
||||||
class_str = det["class_str"]
|
class_str = det["class_str"]
|
||||||
|
|
||||||
class_set.add(class_str)
|
if class_str not in class_list:
|
||||||
|
class_list.append(class_str)
|
||||||
|
class_id = class_list.index(class_str)
|
||||||
|
|
||||||
if track_info[track_id]["class_id"] is None:
|
if track_info[track_id]["class_id"] is None:
|
||||||
track_info[track_id]["class_id"] = class_id
|
track_info[track_id]["class_id"] = class_id
|
||||||
|
|
@ -33,15 +65,13 @@ def f_detections_to_objects(detections_path: str, output_path: str):
|
||||||
track_info[track_id]["end_frame"] = frame
|
track_info[track_id]["end_frame"] = frame
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
"class_list": list(class_set),
|
"class_list": list(class_list),
|
||||||
"track_id_list": dict(track_info)
|
"track_id_list": dict(track_info)
|
||||||
}
|
}
|
||||||
|
|
||||||
with open(output_path, 'w', encoding='utf-8') as f:
|
with open(output_path, 'w', encoding='utf-8') as f:
|
||||||
json.dump(result, f, ensure_ascii=False, indent=2)
|
json.dump(result, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
def load_json_data(json_path: str):
|
||||||
f_detections_to_objects(
|
with open(json_path, 'r', encoding='utf-8') as f:
|
||||||
"output/santai5/frame_detections.json",
|
return json.load(f)
|
||||||
"output/santai5/objects.json"
|
|
||||||
)
|
|
||||||
248
lib/qwen_fun.py
248
lib/qwen_fun.py
|
|
@ -1,6 +1,7 @@
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from lib.json_fun import load_json_data
|
||||||
import requests
|
import requests
|
||||||
from openai import BadRequestError, OpenAI
|
from openai import BadRequestError, OpenAI
|
||||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
|
@ -209,9 +210,7 @@ def save_json_to_file(json_data, file_path: str) -> None:
|
||||||
|
|
||||||
print(f"已保存 JSON 数据到: {file_path}")
|
print(f"已保存 JSON 数据到: {file_path}")
|
||||||
|
|
||||||
def load_json_data(json_path: str):
|
|
||||||
with open(json_path, 'r', encoding='utf-8') as f:
|
|
||||||
return json.load(f)
|
|
||||||
|
|
||||||
def get_unique_track_id_count(json_path: str) -> int:
|
def get_unique_track_id_count(json_path: str) -> int:
|
||||||
"""
|
"""
|
||||||
|
|
@ -362,47 +361,34 @@ def search_knowledge_base(
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
def hazard_inspection(
|
def hazard_inspection(
|
||||||
class_dict: dict,
|
output_dir: str,
|
||||||
video_url: str,
|
obj_dict: dict,
|
||||||
|
rule_dict: dict,
|
||||||
|
class_list: list[str],
|
||||||
|
vid_dict: dict,
|
||||||
|
fps: int,
|
||||||
enable_thinking: bool = True,
|
enable_thinking: bool = True,
|
||||||
fps: int = 10
|
|
||||||
) -> tuple[str, str]:
|
) -> tuple[str, str]:
|
||||||
"""
|
"""
|
||||||
检查视频中给定类型物体的隐患
|
检查视频中给定类型物体的隐患
|
||||||
2026.04.07
|
2026.04.17
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
class_dict: 类别字典
|
output_dir: 输出目录
|
||||||
input_video_path: 视频路径
|
obj_dict: 物体视频字典
|
||||||
|
class_list: 类别列表
|
||||||
|
vid_dict: 视频字典
|
||||||
|
fps: 视频帧率
|
||||||
enable_thinking: 是否开启思考模式
|
enable_thinking: 是否开启思考模式
|
||||||
fps
|
|
||||||
|
|
||||||
class_dict:结构:
|
obj_dict 格式:
|
||||||
{
|
|
||||||
"class_list": ["class_name1", "class_name2", ..."], # 类别名称
|
vid_dict 格式:
|
||||||
"track_id_list": {
|
|
||||||
"0": { # track_id
|
|
||||||
"class_id": 0,
|
|
||||||
"start_frame": 0,
|
|
||||||
"end_frame": 10
|
|
||||||
},
|
|
||||||
...
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
tuple[str, str]: (reasoning_text, json_result)
|
tuple[str, str]: (reasoning_text, json_result)
|
||||||
json_result: JSON格式的隐患检测结果
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# 筛选检查规则 知识库\rule.json
|
|
||||||
all_check_rule_dict: dict = load_json_data(r'知识库\rule.json')
|
|
||||||
|
|
||||||
# 读取提示词文件 7_检测提示词_视频.md
|
|
||||||
prompt_str: str = ""
|
|
||||||
with open(r'prompt\7_检测提示词_视频_new3.md', 'r', encoding='utf-8') as file:
|
|
||||||
prompt_str = file.read()
|
|
||||||
|
|
||||||
# 存储所有类别的结果
|
# 存储所有类别的结果
|
||||||
all_reasoning_text = ""
|
all_reasoning_text = ""
|
||||||
|
|
||||||
|
|
@ -410,58 +396,23 @@ def hazard_inspection(
|
||||||
all_tags = []
|
all_tags = []
|
||||||
all_bases = []
|
all_bases = []
|
||||||
all_objects = []
|
all_objects = []
|
||||||
|
hazard_count = 0
|
||||||
|
|
||||||
# 保存每个类别的原始 tag 和 base 映射,用于后续重新映射
|
|
||||||
class_tag_mapping = {}
|
|
||||||
class_base_mapping = {}
|
|
||||||
|
|
||||||
# 遍历每个类别,分别进行 chat
|
with open(r'prompt\7_检测提示词_视频_new4.md', 'r', encoding='utf-8') as file:
|
||||||
print("类别列表:", class_dict["class_list"])
|
prompt_str = file.read()
|
||||||
for class_str in class_dict["class_list"]:
|
|
||||||
print(f"当前类别: {class_str}")
|
|
||||||
# 检查该类别是否有检查规则
|
|
||||||
class_check_rule: str = ""
|
|
||||||
scene_list: list[str] = list(all_check_rule_dict.keys())
|
|
||||||
print(f"场景列表: {scene_list}")
|
|
||||||
for scene in scene_list:
|
|
||||||
if class_str in all_check_rule_dict[scene].keys():
|
|
||||||
class_check_rule = all_check_rule_dict[scene][class_str]
|
|
||||||
break
|
|
||||||
if class_check_rule == "":
|
|
||||||
# 该类别没有检查规则
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 筛选出该类别的 track_id_list
|
|
||||||
class_track_info = []
|
|
||||||
for track_id, track_data in class_dict["track_id_list"].items():
|
|
||||||
track_class_name = class_dict["class_list"][track_data["class_id"]]
|
|
||||||
if track_class_name == class_str:
|
|
||||||
class_track_info.append(f"Track ID {track_id}: {track_class_name} (帧 {track_data['start_frame']}-{track_data['end_frame']})")
|
|
||||||
|
|
||||||
# 如果该类别没有跟踪对象,跳过
|
# 遍历每个物体视频,进行检查
|
||||||
if not class_track_info:
|
for track_id, track_info in obj_dict.items():
|
||||||
continue
|
class_str = class_list[track_info["class_id"]]
|
||||||
|
vid_url = vid_dict[track_id]["vid_url"]
|
||||||
|
reasoning_text, answer_text = single_object_inspection(vid_url, class_str, rule_dict, prompt_str, enable_thinking)
|
||||||
|
|
||||||
track_data_str = "\n".join(class_track_info)
|
|
||||||
|
|
||||||
# 构建该类别的 prompt
|
|
||||||
prompt: str = f"""
|
|
||||||
# 类别: {class_str}
|
|
||||||
|
|
||||||
# 跟踪对象信息
|
|
||||||
{track_data_str}
|
|
||||||
|
|
||||||
# 检查规则
|
|
||||||
{class_check_rule}
|
|
||||||
|
|
||||||
{prompt_str}
|
|
||||||
"""
|
|
||||||
|
|
||||||
# 调用 chat 函数
|
|
||||||
reasoning_text, answer_text = chat(prompt, video_url, enable_thinking, fps)
|
|
||||||
|
|
||||||
# 整合思考过程
|
# 整合思考过程
|
||||||
all_reasoning_text += f"\n\n# 类别: {class_str}\n{reasoning_text}"
|
all_reasoning_text += f"\n\n# 类别: {track_id}:{class_str}\n{reasoning_text}"
|
||||||
|
|
||||||
# 解析 JSON 结果
|
# 解析 JSON 结果
|
||||||
try:
|
try:
|
||||||
|
|
@ -477,81 +428,110 @@ def hazard_inspection(
|
||||||
|
|
||||||
class_result = json.loads(json_str)
|
class_result = json.loads(json_str)
|
||||||
|
|
||||||
|
# 保存class_result到文件,便于调试
|
||||||
|
with open(f"{output_dir}/inspection_result_{int(track_id):03d}.txt", "w", encoding="utf-8") as f:
|
||||||
|
f.write(json.dumps(class_result, ensure_ascii=False, indent=2))
|
||||||
|
|
||||||
# 收集该类别的结果
|
# 收集该类别的结果
|
||||||
if "tag" in class_result:
|
temp_obj = {}
|
||||||
# 保存该类别的 tag 映射
|
for obj in class_result.get("objects", []):
|
||||||
class_tag_mapping[class_str] = class_result["tag"]
|
|
||||||
all_tags.extend(class_result["tag"])
|
# tag_id
|
||||||
if "base" in class_result:
|
if "tag_id" in obj and obj["tag_id"] < len(class_result.get("tag", [])):
|
||||||
# 保存该类别的 base 映射
|
tag = class_result["tag"][obj["tag_id"]]
|
||||||
class_base_mapping[class_str] = class_result["base"]
|
else:
|
||||||
all_bases.extend(class_result["base"])
|
tag = None
|
||||||
if "objects" in class_result:
|
|
||||||
all_objects.extend(class_result["objects"])
|
if tag not in all_tags:
|
||||||
|
all_tags.append(tag)
|
||||||
|
temp_obj["tag_id"] = all_tags.index(tag)
|
||||||
|
|
||||||
|
# base_id
|
||||||
|
if "base_id" in obj and obj["base_id"] < len(class_result.get("base", [])):
|
||||||
|
base = class_result["base"][obj["base_id"]]
|
||||||
|
else:
|
||||||
|
base = None
|
||||||
|
|
||||||
|
if base not in all_bases:
|
||||||
|
all_bases.append(base)
|
||||||
|
temp_obj["base_id"] = all_bases.index(base)
|
||||||
|
|
||||||
|
# 其他字段
|
||||||
|
temp_obj["track_id"] = track_id
|
||||||
|
temp_obj["hazard_track_id"] = hazard_count
|
||||||
|
temp_obj["class_id"] = track_info["class_id"]
|
||||||
|
temp_obj["level"] = obj["level"]
|
||||||
|
temp_obj["start_frame"] = track_info["start_frame"]
|
||||||
|
temp_obj["start_sec"] = round(track_info["start_frame"] / fps, 1) # 处理成x.x秒
|
||||||
|
temp_obj["location"] = obj.get("location", "")
|
||||||
|
|
||||||
|
hazard_count += 1
|
||||||
|
|
||||||
|
all_objects.append(temp_obj)
|
||||||
|
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
print(f"解析类别 {class_str} 的 JSON 结果失败: {e}")
|
print(f"解析类别 {class_str} 的 JSON 结果失败: {e}")
|
||||||
print(f"原始输出: {answer_text}")
|
print(f"原始输出: {answer_text}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 合并去重 tag 和 base
|
|
||||||
unique_tags = []
|
|
||||||
seen_tags = set()
|
|
||||||
for tag in all_tags:
|
|
||||||
if tag not in seen_tags:
|
|
||||||
seen_tags.add(tag)
|
|
||||||
unique_tags.append(tag)
|
|
||||||
|
|
||||||
unique_bases = []
|
|
||||||
seen_bases = set()
|
|
||||||
for base in all_bases:
|
|
||||||
if base not in seen_bases:
|
|
||||||
seen_bases.add(base)
|
|
||||||
unique_bases.append(base)
|
|
||||||
|
|
||||||
# 重新分配 tag_id 和 base_id
|
|
||||||
# 由于每个类别是独立处理的,需要根据 class_id 找到对应的类别,然后重新映射
|
|
||||||
# 但是 object 中的 class_id 是相对于 class_list 的,需要找到对应的类别名称
|
|
||||||
final_objects = []
|
|
||||||
for obj in all_objects:
|
|
||||||
new_obj = obj.copy()
|
|
||||||
new_obj["hazard_track_id"] = len(final_objects)
|
|
||||||
|
|
||||||
# 通过 class_id 找到类别名称
|
|
||||||
class_id = obj.get("class_id", 0)
|
|
||||||
if class_id < len(class_dict["class_list"]):
|
|
||||||
class_name = class_dict["class_list"][class_id]
|
|
||||||
|
|
||||||
# 重新映射 tag_id
|
|
||||||
if "tag_id" in obj and class_name in class_tag_mapping:
|
|
||||||
original_tags = class_tag_mapping[class_name]
|
|
||||||
if obj["tag_id"] < len(original_tags):
|
|
||||||
original_tag = original_tags[obj["tag_id"]]
|
|
||||||
if original_tag in unique_tags:
|
|
||||||
new_obj["tag_id"] = unique_tags.index(original_tag)
|
|
||||||
|
|
||||||
# 重新映射 base_id
|
|
||||||
if "base_id" in obj and class_name in class_base_mapping:
|
|
||||||
original_bases = class_base_mapping[class_name]
|
|
||||||
if obj["base_id"] < len(original_bases):
|
|
||||||
original_base = original_bases[obj["base_id"]]
|
|
||||||
if original_base in unique_bases:
|
|
||||||
new_obj["base_id"] = unique_bases.index(original_base)
|
|
||||||
|
|
||||||
final_objects.append(new_obj)
|
|
||||||
|
|
||||||
# 构建最终结果
|
# 构建最终结果
|
||||||
final_result = {
|
final_result = {
|
||||||
"tag": unique_tags,
|
"tag": all_tags,
|
||||||
"base": unique_bases,
|
"base": all_bases,
|
||||||
"objects": final_objects
|
"objects": all_objects
|
||||||
}
|
}
|
||||||
|
|
||||||
# 转换为 JSON 字符串
|
# 转换为 JSON 字符串
|
||||||
json_result = json.dumps(final_result, ensure_ascii=False, indent=2)
|
json_result = json.dumps(final_result, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
with open(f"{output_dir}/hazard_inspection.json", "w", encoding="utf-8") as f:
|
||||||
|
f.write(json_result)
|
||||||
|
with open(f"{output_dir}/hazard_inspection_reasoning.md", "w", encoding="utf-8") as f:
|
||||||
|
f.write(all_reasoning_text)
|
||||||
|
|
||||||
return all_reasoning_text, json_result
|
return all_reasoning_text, json_result
|
||||||
|
|
||||||
|
def single_object_inspection(vid_url: str, class_str: str, rule_dict: dict, prompt_str: str, enable_thinking: bool = True) -> tuple[str, str]:
|
||||||
|
# 获取检查规则
|
||||||
|
rule = get_rule_by_class_str(class_str, rule_dict)
|
||||||
|
|
||||||
|
if rule == "":
|
||||||
|
print(f"类别 {class_str} 没有检查规则,跳过检查")
|
||||||
|
return ""
|
||||||
|
else:
|
||||||
|
print(f"类别 {class_str} 的检查规则: {rule}")
|
||||||
|
|
||||||
|
# 构建该类别的 prompt
|
||||||
|
prompt: str = f"""
|
||||||
|
# 检测对象为视频中出现的{class_str}
|
||||||
|
|
||||||
|
# 检查规则
|
||||||
|
{rule}
|
||||||
|
|
||||||
|
{prompt_str}
|
||||||
|
"""
|
||||||
|
|
||||||
|
return chat(prompt, vid_url, enable_thinking)
|
||||||
|
|
||||||
|
|
||||||
|
def get_rule_by_class_str(class_str: str, all_check_rule_dict: dict) -> str:
|
||||||
|
"""根据类别名称获取检查规则
|
||||||
|
参数:
|
||||||
|
class_str: 类别名称
|
||||||
|
all_check_rule_dict: 所有检查规则字典
|
||||||
|
返回:
|
||||||
|
检查规则字符串
|
||||||
|
"""
|
||||||
|
print(f"当前类别: {class_str}")
|
||||||
|
# 检查该类别是否有检查规则
|
||||||
|
class_check_rule: str = ""
|
||||||
|
scene_list: list[str] = list(all_check_rule_dict.keys())
|
||||||
|
for scene in scene_list:
|
||||||
|
if class_str in all_check_rule_dict[scene].keys():
|
||||||
|
class_check_rule = all_check_rule_dict[scene][class_str]
|
||||||
|
break
|
||||||
|
return str(class_check_rule)
|
||||||
|
|
||||||
def merge_conflict_inspection_data(data: dict) -> dict:
|
def merge_conflict_inspection_data(data: dict) -> dict:
|
||||||
"""
|
"""
|
||||||
合并冲突数据(相同 tag_id 和 class_id 的重叠时间段)。
|
合并冲突数据(相同 tag_id 和 class_id 的重叠时间段)。
|
||||||
|
|
|
||||||
|
|
@ -5,126 +5,131 @@ import supervision as sv
|
||||||
import cv2
|
import cv2
|
||||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
|
||||||
def generate_video_with_boxes(
|
import os
|
||||||
boxes_data: list[dict],
|
import cv2
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
def generate_video_to_objects(
|
||||||
|
obj_dict: dict[dict],
|
||||||
input_video_path: str,
|
input_video_path: str,
|
||||||
output_video_path: str,
|
output_dir: str,
|
||||||
frame_annotation_interval: int
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
将提供的标注数据渲染到视频上,支持按帧间隔进行标注。
|
根据 obj_dict 中的物体信息,从原视频中截取附近帧,生成新视频
|
||||||
|
|
||||||
:param boxes_data: 包含标注信息的列表,每个元素结构为:
|
:param obj_dict: 包含物体信息的字典,每个元素结构为:
|
||||||
{
|
"0": {
|
||||||
"frame_id": int,
|
"class_id": 4,
|
||||||
"boxes": List[Tuple[int, int, int, int, str]] # (x1, y1, x2, y2, label)
|
"start_frame": 551,
|
||||||
}
|
"end_frame": 597
|
||||||
|
},
|
||||||
:param input_video_path: 输入视频文件路径
|
:param input_video_path: 输入视频文件路径
|
||||||
:param output_video_path: 输出视频文件路径
|
:param output_dir: 输出视频目录
|
||||||
:param frame_annotation_interval: 标注间隔(单位:帧),默认为 1(每帧标注)
|
|
||||||
:return: 无
|
|
||||||
"""
|
"""
|
||||||
# -------------------------------------------------
|
# 最低秒数
|
||||||
# 1. 视频读取与基本配置
|
min_seconds = 2
|
||||||
# -------------------------------------------------
|
# 前后额外帧数
|
||||||
cap = cv2.VideoCapture(input_video_path)
|
extra_frames = 5
|
||||||
if not cap.isOpened():
|
|
||||||
raise FileNotFoundError(f"无法打开视频文件或流: {input_video_path}")
|
print(f"开始抽取物体视频: {input_video_path}")
|
||||||
|
print(f"输出目录: {output_dir}")
|
||||||
|
# 确保输出目录存在
|
||||||
|
Path(output_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# 1. 打开原始视频
|
||||||
|
backends = [
|
||||||
|
(cv2.CAP_FFMPEG, 'FFmpeg'),
|
||||||
|
(cv2.CAP_DSHOW, 'DirectShow'),
|
||||||
|
(cv2.CAP_ANY, 'Default')
|
||||||
|
]
|
||||||
|
|
||||||
|
cap = None
|
||||||
|
for backend, backend_name in backends:
|
||||||
|
try:
|
||||||
|
cap = cv2.VideoCapture(input_video_path, backend)
|
||||||
|
if cap.isOpened():
|
||||||
|
if backend == cv2.CAP_FFMPEG:
|
||||||
|
try:
|
||||||
|
cap.set(cv2.CAP_PROP_HW_ACCELERATION, cv2.VIDEO_ACCELERATION_ANY)
|
||||||
|
hw_accel = cap.get(cv2.CAP_PROP_HW_ACCELERATION)
|
||||||
|
print(f"FFmpeg硬件加速: {'已启用' if hw_accel > 0 else '未启用'}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"设置硬件加速失败: {e}")
|
||||||
|
print(f"使用后端: {backend_name}")
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
print(f"尝试{backend_name}后端失败: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not cap or not cap.isOpened():
|
||||||
|
raise Exception(f"无法打开视频文件: {input_video_path}")
|
||||||
|
|
||||||
|
# 2. 获取视频参数
|
||||||
fps = cap.get(cv2.CAP_PROP_FPS)
|
fps = cap.get(cv2.CAP_PROP_FPS)
|
||||||
video_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||||
video_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||||
|
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||||
|
min_frames = int(min_seconds * fps)
|
||||||
|
|
||||||
# -------------------------------------------------
|
fourcc = cv2.VideoWriter_fourcc(*'avc1')
|
||||||
# 2. 初始化 VideoWriter
|
|
||||||
# -------------------------------------------------
|
|
||||||
Path(os.path.dirname(output_video_path)).mkdir(parents=True, exist_ok=True)
|
|
||||||
fourcc = cv2.VideoWriter_fourcc(*'mp4v') # type: ignore
|
|
||||||
out = cv2.VideoWriter(output_video_path, fourcc, fps, (video_width, video_height))
|
|
||||||
|
|
||||||
# -------------------------------------------------
|
# 3. 预处理所有物体的帧范围 + 初始化写入器
|
||||||
# 3. 初始化标注工具
|
total_objects = len(obj_dict)
|
||||||
# -------------------------------------------------
|
print(f"总共需要处理 {total_objects} 个物体")
|
||||||
box_annotator = sv.BoxAnnotator()
|
|
||||||
label_annotator = sv.RichLabelAnnotator(
|
|
||||||
font_path="C:/Windows/Fonts/simhei.ttf",
|
|
||||||
text_color=sv.Color.WHITE,
|
|
||||||
text_padding=5,
|
|
||||||
font_size=20
|
|
||||||
)
|
|
||||||
|
|
||||||
# -------------------------------------------------
|
writers = {} # 存储所有视频写入器
|
||||||
# 4. 数据预处理:构建 frame_id -> boxes 的映射
|
obj_ranges = {} # 存储每个物体的起止帧
|
||||||
# -------------------------------------------------
|
|
||||||
frame_to_boxes: dict[int, list[tuple[int, int, int, int, str]]] = {}
|
|
||||||
for entry in boxes_data:
|
|
||||||
frame_id = entry["frame_id"]
|
|
||||||
boxes = entry["boxes"]
|
|
||||||
# 确保每个框都有正确的结构
|
|
||||||
cleaned_boxes = []
|
|
||||||
for box in boxes:
|
|
||||||
if len(box) == 5:
|
|
||||||
cleaned_boxes.append(box) # (x1, y1, x2, y2, label)
|
|
||||||
frame_to_boxes[frame_id] = cleaned_boxes
|
|
||||||
|
|
||||||
# -------------------------------------------------
|
for track_id, track_data in obj_dict.items():
|
||||||
# 5. 主循环:逐帧读取并渲染
|
start_idx = max(0, track_data["start_frame"] - extra_frames)
|
||||||
# -------------------------------------------------
|
end_idx = min(total_frames - 1, track_data["end_frame"] + extra_frames)
|
||||||
|
|
||||||
|
# 保证最小长度
|
||||||
|
if end_idx - start_idx + 1 < min_frames:
|
||||||
|
end_idx = start_idx + min_frames - 1
|
||||||
|
|
||||||
|
output_path = os.path.join(output_dir, f"obj_{int(track_id):03d}.mp4")
|
||||||
|
out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
|
||||||
|
|
||||||
|
if not out.isOpened():
|
||||||
|
raise Exception(f"无法初始化视频写入器: {output_path}")
|
||||||
|
|
||||||
|
writers[track_id] = out
|
||||||
|
obj_ranges[track_id] = (start_idx, end_idx)
|
||||||
|
print(f"物体 {track_id}: 帧 {start_idx} ~ {end_idx} -> {output_path}")
|
||||||
|
|
||||||
|
# 4. 单轮遍历视频,一次性写入所有需要的帧(最高效)
|
||||||
frame_idx = 0
|
frame_idx = 0
|
||||||
|
print_interval = 50
|
||||||
|
|
||||||
while cap.isOpened():
|
while cap.isOpened():
|
||||||
|
# 读取一帧
|
||||||
ret, frame = cap.read()
|
ret, frame = cap.read()
|
||||||
if not ret:
|
if not ret:
|
||||||
break
|
break
|
||||||
|
|
||||||
# 计算当前帧对应的标注帧索引(考虑间隔)
|
# 遍历所有物体,判断当前帧是否需要写入
|
||||||
annotation_frame_idx = frame_idx // frame_annotation_interval
|
for track_id, (start, end) in obj_ranges.items():
|
||||||
|
if start <= frame_idx <= end:
|
||||||
|
writers[track_id].write(frame)
|
||||||
|
|
||||||
# 如果当前帧没有标注数据,直接写入原始帧
|
# 进度打印
|
||||||
if annotation_frame_idx not in frame_to_boxes:
|
if frame_idx % print_interval == 0:
|
||||||
out.write(frame)
|
progress = (frame_idx / total_frames) * 100
|
||||||
frame_idx += 1
|
print(f"处理帧: {frame_idx}/{total_frames} ({progress:.1f}%)")
|
||||||
continue
|
|
||||||
|
|
||||||
# 获取当前帧的所有框信息
|
|
||||||
boxes_info = frame_to_boxes[annotation_frame_idx]
|
|
||||||
|
|
||||||
boxes = []
|
|
||||||
labels = []
|
|
||||||
for _, (x1, y1, x2, y2, label) in enumerate(boxes_info):
|
|
||||||
# 坐标
|
|
||||||
boxes.append([x1, y1, x2, y2])
|
|
||||||
# 使用索引作为唯一标识符,或直接使用 label
|
|
||||||
labels.append(label)
|
|
||||||
|
|
||||||
# 转换为 NumPy 数组
|
|
||||||
boxes_np = np.array(boxes, dtype=np.float64)
|
|
||||||
|
|
||||||
# 构建 Detections 对象
|
|
||||||
detections = sv.Detections(
|
|
||||||
xyxy=boxes_np,
|
|
||||||
confidence=np.ones(len(boxes_np)), # 默认置信度为 1.0
|
|
||||||
class_id=np.zeros(len(boxes_np), dtype=int) # 类别 ID 在此场景下不重要
|
|
||||||
)
|
|
||||||
detections.tracker_id = np.arange(len(boxes_np), dtype=int) # 使用索引作为 ID
|
|
||||||
|
|
||||||
# 绘制边框和标签
|
|
||||||
annotated_frame = box_annotator.annotate(scene=frame.copy(), detections=detections)
|
|
||||||
annotated_frame = label_annotator.annotate(
|
|
||||||
scene=annotated_frame,
|
|
||||||
detections=detections,
|
|
||||||
labels=labels
|
|
||||||
)
|
|
||||||
|
|
||||||
# 写入帧
|
|
||||||
out.write(annotated_frame)
|
|
||||||
frame_idx += 1
|
frame_idx += 1
|
||||||
|
|
||||||
# -------------------------------------------------
|
# 5. 释放所有写入器
|
||||||
# 6. 资源释放
|
for track_id, out in writers.items():
|
||||||
# -------------------------------------------------
|
out.release()
|
||||||
|
start, end = obj_ranges[track_id]
|
||||||
|
print(f"物体 {track_id} 完成 | 总帧数: {end - start + 1}")
|
||||||
|
|
||||||
|
# 6. 释放资源
|
||||||
cap.release()
|
cap.release()
|
||||||
out.release()
|
|
||||||
cv2.destroyAllWindows()
|
cv2.destroyAllWindows()
|
||||||
print(f"视频渲染完成,已保存至: {output_video_path}")
|
print(f"\n✅ 所有物体视频生成完成!目录: {output_dir}")
|
||||||
|
|
||||||
def process_track_id(
|
def process_track_id(
|
||||||
track_id: int,
|
track_id: int,
|
||||||
|
|
@ -488,4 +493,5 @@ def create_mian_vid_for_ai(
|
||||||
out.release()
|
out.release()
|
||||||
print(f"已生成视频: {output_video_path}, 共 {current_output_frame} 帧")
|
print(f"已生成视频: {output_video_path}, 共 {current_output_frame} 帧")
|
||||||
|
|
||||||
return output_video_path
|
return output_video_path
|
||||||
|
|
||||||
|
|
|
||||||
11
lib/sam3.py
11
lib/sam3.py
|
|
@ -10,7 +10,8 @@ class SAM3:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._data: dict = {}
|
self._data: dict = {}
|
||||||
|
|
||||||
def run(self, vid_path: str, output_dir: str, class_file_path: str, interval: int = 1, conf: float = 0.6) -> dict:
|
# def run(self, vid_path: str, output_dir: str, class_file_path: str, interval: int = 1, conf: float = 0.6) -> dict:
|
||||||
|
def run(self, vid_path: str, output_dir: str, class_file_path: str, conf: float = 0.6) -> dict:
|
||||||
"""
|
"""
|
||||||
运行 SAM3 模型,返回视频中每个帧的检测结果。
|
运行 SAM3 模型,返回视频中每个帧的检测结果。
|
||||||
参数:
|
参数:
|
||||||
|
|
@ -35,7 +36,7 @@ class SAM3:
|
||||||
,
|
,
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
new_vid_path = extract_frames_to_video(vid_path, output_dir, interval)
|
# new_vid_path = extract_frames_to_video(vid_path, output_dir, interval)
|
||||||
|
|
||||||
self._data = {}
|
self._data = {}
|
||||||
text_dict = json.load(open(class_file_path, "r", encoding="utf-8"))
|
text_dict = json.load(open(class_file_path, "r", encoding="utf-8"))
|
||||||
|
|
@ -46,7 +47,7 @@ class SAM3:
|
||||||
predictor = SAM3VideoSemanticPredictor(overrides=overrides)
|
predictor = SAM3VideoSemanticPredictor(overrides=overrides)
|
||||||
|
|
||||||
results = predictor(
|
results = predictor(
|
||||||
source=new_vid_path,
|
source=vid_path,
|
||||||
text=text_list_en,
|
text=text_list_en,
|
||||||
stream=True,
|
stream=True,
|
||||||
)
|
)
|
||||||
|
|
@ -96,13 +97,13 @@ class SAM3:
|
||||||
# annotated_frame = r.plot()
|
# annotated_frame = r.plot()
|
||||||
frame_idx += 1
|
frame_idx += 1
|
||||||
|
|
||||||
with open(f"{output_dir}/frame_all.json", "w", encoding="utf-8") as f:
|
with open(f"{output_dir}/frame_detections.json", "w", encoding="utf-8") as f:
|
||||||
json.dump(annotated_frames, f, ensure_ascii=False, indent=2)
|
json.dump(annotated_frames, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
print(f"\n总共处理 {frame_idx} 帧")
|
print(f"\n总共处理 {frame_idx} 帧")
|
||||||
print(f"结果已保存到 {output_dir}/frame_all.json")
|
print(f"结果已保存到 {output_dir}/frame_detections.json")
|
||||||
print(f"标记图片已保存到 {output_dir}/boxes/")
|
print(f"标记图片已保存到 {output_dir}/boxes/")
|
||||||
|
|
||||||
cv2.destroyAllWindows()
|
cv2.destroyAllWindows()
|
||||||
|
|
|
||||||
|
|
@ -13,9 +13,7 @@
|
||||||
{
|
{
|
||||||
"hazard_track_id": 0,
|
"hazard_track_id": 0,
|
||||||
"tag_id": 0,
|
"tag_id": 0,
|
||||||
"class_id": 0,
|
|
||||||
"level": 0,
|
"level": 0,
|
||||||
"start_frame": 0,
|
|
||||||
"base_id": 0,
|
"base_id": 0,
|
||||||
"location": "这里是隐患位置描述"
|
"location": "这里是隐患位置描述"
|
||||||
},
|
},
|
||||||
|
|
@ -37,18 +35,14 @@
|
||||||
{
|
{
|
||||||
"hazard_track_id": 0,
|
"hazard_track_id": 0,
|
||||||
"tag_id": 0,
|
"tag_id": 0,
|
||||||
"class_id": 0,
|
|
||||||
"level": 0,
|
"level": 0,
|
||||||
"start_frame": 0,
|
|
||||||
"base_id": 0,
|
"base_id": 0,
|
||||||
"location": "这里是隐患位置描述"
|
"location": "这里是隐患位置描述"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"hazard_track_id": 1,
|
"hazard_track_id": 1,
|
||||||
"tag_id": 1,
|
"tag_id": 1,
|
||||||
"class_id": 1,
|
|
||||||
"level": 1,
|
"level": 1,
|
||||||
"start_frame": 0,
|
|
||||||
"base_id": 1,
|
"base_id": 1,
|
||||||
"location": "这里是隐患位置描述"
|
"location": "这里是隐患位置描述"
|
||||||
}
|
}
|
||||||
|
|
@ -58,10 +52,9 @@
|
||||||
# 输出格式注意事项
|
# 输出格式注意事项
|
||||||
- 你的输出只能包含tag、base和objects三个键。
|
- 你的输出只能包含tag、base和objects三个键。
|
||||||
- tag是一个字符串数组,每个元素是隐患标签,必须为中文,不能使用英文。每个隐患点只能有一个标签,如果规则中存在多个标签,必须选择最符合视频中情况的一个标签。
|
- tag是一个字符串数组,每个元素是隐患标签,必须为中文,不能使用英文。每个隐患点只能有一个标签,如果规则中存在多个标签,必须选择最符合视频中情况的一个标签。
|
||||||
- objects是一个字典列表,每个字典必须包含hazard_track_id(整数)、tag_id(整数)、class_id(整数)、level(整数)、start_frame(整数)、base_id(整数)、location(字符串)。
|
- objects是一个字典列表,每个字典必须包含hazard_track_id(整数)、tag_id(整数)、level(整数)、base_id(整数)、location(字符串)。
|
||||||
- hazard_track_id分配规则:根据视频画面,每个连续出现的类型+隐患组合应分配单独的hazard_track_id。相同的tag_id+class_id组合在时间上不应相交,应确保只存在一份连续的记录。绝对禁止出现多个hazard_track_id具有相同的tag_id和class_id且时间重叠的情况。
|
- hazard_track_id分配规则:根据视频画面,每个隐患点只能分配一个hazard_track_id,不能重复分配。
|
||||||
- level必须为0或1或2,不能为其他整数。为0表示隐患等级为疑似,为1表示隐患等级为低,为2表示隐患等级为高。
|
- level必须为0或1或2,不能为其他整数。为0表示隐患等级为疑似,为1表示隐患等级为低,为2表示隐患等级为高。
|
||||||
- 不要输出bbox_2d、label或任何中文标点。
|
|
||||||
- 输出格式必须为标准json格式,且结构必须与模板一致
|
- 输出格式必须为标准json格式,且结构必须与模板一致
|
||||||
- class_id必须与class_list中的顺序严格一致,保持一一映射关系。
|
- class_id必须与class_list中的顺序严格一致,保持一一映射关系。
|
||||||
- 所有hazard_track_id都为独立隐患点,不存在误检,不得合并或拆分
|
- 所有hazard_track_id都为独立隐患点,不存在误检,不得合并或拆分
|
||||||
|
|
@ -72,16 +65,15 @@
|
||||||
- location表示该隐患在视频画面中的位置描述,必须为中文,描述要准确、清晰,能够明确指出隐患在画面中的相对位置。
|
- location表示该隐患在视频画面中的位置描述,必须为中文,描述要准确、清晰,能够明确指出隐患在画面中的相对位置。
|
||||||
|
|
||||||
# 任务1
|
# 任务1
|
||||||
- **帧级分析**:根据提供的物体class、出现时间点(提供的数据为每帧对应的物体列表)与隐患识别规则,在视频中,同track_id的物体需持续观察,对隐患进行识别,每个隐患点分配一个hazard_track_id,杜绝在同一个物体上重复识别隐患点
|
- **帧级分析**:根据提供的物体名称与隐患识别规则,在视频中对隐患进行识别,每个隐患点分配一个hazard_track_id,杜绝在同一个物体上重复识别隐患点
|
||||||
- **汇总处理**:在完成所有帧的分析后,基于各帧的分析结果,为每个hazard_track_id确定最终的隐患标签、等级、位置描述以及开始帧位置
|
- **汇总处理**:在完成所有帧的分析后,基于各帧的分析结果,为每个hazard_track_id确定最终的隐患标签、等级、位置描述以及开始帧位置
|
||||||
- **基本要求**:只检测由class确定的物体,每个hazard_track_id对应的字典中必须包含该hazard_track_id的tag_id、class_id、level、start_frame、base_id、location信息
|
- **基本要求**:只检测指定物体,每个hazard_track_id对应的字典中必须包含该hazard_track_id的tag_id、level、base_id、location信息
|
||||||
- **匹配规则**:如果物体与检测条目匹配,就将该检测条目添加到objects列表中,并设置相应的tag_id
|
- **匹配规则**:如果物体与检测条目匹配,就将该检测条目添加到objects列表中,并设置相应的tag_id
|
||||||
- **关键约束**:
|
- **关键约束**:
|
||||||
1. **逐帧分析**:必须分析每一帧中的每个物体,用class匹配相应的检测规则,根据隐患识别规则进行隐患识别
|
1. **语音识别**:必须对视频中的语音进行识别,辅助隐患识别
|
||||||
2. **语音识别**:必须对视频中的语音进行识别,辅助隐患识别
|
2. **规则参考**:严格参考知识库中的规则结构进行隐患识别,规则结构参考 `知识库/rule.json`
|
||||||
3. **规则参考**:严格参考知识库中的规则结构进行隐患识别,规则结构参考 `知识库/rule.json`
|
3. **全面识别**:必须对提供的物体进行隐患识别
|
||||||
4. **全面识别**:必须对所有提供的物体进行隐患识别,不得遗漏任何物体
|
4. **准确匹配**:根据物体名称与隐患识别规则进行准确匹配,确定隐患标签和等级
|
||||||
5. **准确匹配**:根据物体的class与隐患识别规则进行准确匹配,确定隐患标签和等级
|
5. **等级判定**:根据规则中的匹配条件和依据,合理判定匹配等级(0-疑似,1-确定)
|
||||||
6. **等级判定**:根据规则中的匹配条件和依据,合理判定匹配等级(0-疑似,1-确定)
|
6. **hazard_track_id分配**:根据视频画面,每个隐患点应分配单独的hazard_track_id。
|
||||||
7. **hazard_track_id分配**:根据视频画面,每个隐患点应分配单独的hazard_track_id。
|
7. **位置描述**:大模型需在输出时提供隐患点相对于视频画面的位置,location字段必须准确描述隐患在画面中的位置,例如:"画面左上角"、"画面中央偏右"等。
|
||||||
8. **位置描述**:大模型需在输出时提供隐患点相对于视频画面的位置,location字段必须准确描述隐患在画面中的位置,例如:"画面左上角"、"画面中央偏右"等。
|
|
||||||
|
|
@ -1,3 +1,3 @@
|
||||||
### qwen试用
|
### qwen试用
|
||||||
账号: 18237294717
|
账号: 18237294717 王凡的
|
||||||
https://bailian.console.aliyun.com/cn-beijing/?spm=a2c4g.11186623.0.0.2c2b2217h6fbio&tab=app#/app-center
|
https://bailian.console.aliyun.com/cn-beijing/?spm=a2c4g.11186623.0.0.2c2b2217h6fbio&tab=app#/app-center
|
||||||
|
|
@ -0,0 +1,98 @@
|
||||||
|
aiofiles==24.1.0
|
||||||
|
annotated-doc==0.0.4
|
||||||
|
annotated-types==0.7.0
|
||||||
|
anyio==4.12.1
|
||||||
|
audioop-lts==0.2.2
|
||||||
|
brotli==1.2.0
|
||||||
|
certifi==2026.2.25
|
||||||
|
charset-normalizer==3.4.5
|
||||||
|
click==8.3.1
|
||||||
|
clip @ git+https://github.com/ultralytics/CLIP.git@88ade288431a46233f1556d1e141901b3ef0a36b
|
||||||
|
colorama==0.4.6
|
||||||
|
contourpy==1.3.3
|
||||||
|
cycler==0.12.1
|
||||||
|
defusedxml==0.7.1
|
||||||
|
distro==1.9.0
|
||||||
|
fastapi==0.135.1
|
||||||
|
ffmpy==1.0.0
|
||||||
|
filelock==3.25.1
|
||||||
|
fonttools==4.62.0
|
||||||
|
fsspec==2026.2.0
|
||||||
|
ftfy==6.3.1
|
||||||
|
gradio==6.9.0
|
||||||
|
gradio_client==2.3.0
|
||||||
|
groovy==0.1.2
|
||||||
|
h11==0.16.0
|
||||||
|
hf-xet==1.3.2
|
||||||
|
httpcore==1.0.9
|
||||||
|
httpx==0.28.1
|
||||||
|
huggingface_hub==1.6.0
|
||||||
|
idna==3.11
|
||||||
|
Jinja2==3.1.6
|
||||||
|
jiter==0.13.0
|
||||||
|
kiwisolver==1.5.0
|
||||||
|
lap==0.5.13
|
||||||
|
markdown-it-py==4.0.0
|
||||||
|
MarkupSafe==3.0.3
|
||||||
|
matplotlib==3.10.8
|
||||||
|
mdurl==0.1.2
|
||||||
|
MouseInfo==0.1.3
|
||||||
|
mpmath==1.3.0
|
||||||
|
networkx==3.6.1
|
||||||
|
numpy==2.4.3
|
||||||
|
openai==2.29.0
|
||||||
|
opencv-python==4.13.0.92
|
||||||
|
orjson==3.11.7
|
||||||
|
packaging @ file:///C:/miniconda3/conda-bld/packaging_1761049099114/work
|
||||||
|
pandas==3.0.1
|
||||||
|
pillow==12.1.1
|
||||||
|
polars==1.38.1
|
||||||
|
polars-runtime-32==1.38.1
|
||||||
|
psutil==7.2.2
|
||||||
|
PyAutoGUI==0.9.54
|
||||||
|
pydantic==2.12.5
|
||||||
|
pydantic_core==2.41.5
|
||||||
|
pyDeprecate==0.5.0
|
||||||
|
pydub==0.25.1
|
||||||
|
PyGetWindow==0.0.9
|
||||||
|
Pygments==2.19.2
|
||||||
|
PyMsgBox==2.0.1
|
||||||
|
pyparsing==3.3.2
|
||||||
|
pyperclip==1.11.0
|
||||||
|
PyRect==0.2.0
|
||||||
|
PyScreeze==1.0.1
|
||||||
|
python-dateutil==2.9.0.post0
|
||||||
|
python-multipart==0.0.22
|
||||||
|
pytweening==1.2.0
|
||||||
|
pytz==2026.1.post1
|
||||||
|
PyYAML==6.0.3
|
||||||
|
regex==2026.2.28
|
||||||
|
requests==2.32.5
|
||||||
|
rich==14.3.3
|
||||||
|
safehttpx==0.1.7
|
||||||
|
safetensors==0.7.0
|
||||||
|
scipy==1.17.1
|
||||||
|
semantic-version==2.10.0
|
||||||
|
setuptools==80.10.2
|
||||||
|
shellingham==1.5.4
|
||||||
|
six==1.17.0
|
||||||
|
sniffio==1.3.1
|
||||||
|
starlette==0.52.1
|
||||||
|
supervision==0.27.0.post2
|
||||||
|
sympy==1.14.0
|
||||||
|
timm==1.0.25
|
||||||
|
tomlkit==0.13.3
|
||||||
|
torch==2.7.1+cu118
|
||||||
|
torchaudio==2.7.1+cu118
|
||||||
|
torchvision==0.22.1+cu118
|
||||||
|
tqdm==4.67.3
|
||||||
|
typer==0.24.1
|
||||||
|
typing-inspection==0.4.2
|
||||||
|
typing_extensions==4.15.0
|
||||||
|
tzdata==2025.3
|
||||||
|
ultralytics==8.4.21
|
||||||
|
ultralytics-thop==2.0.18
|
||||||
|
urllib3==2.6.3
|
||||||
|
uvicorn==0.41.0
|
||||||
|
wcwidth==0.6.0
|
||||||
|
wheel==0.46.3
|
||||||
Loading…
Reference in New Issue