bug修复
- 补齐检测结束时间保存代码 - 修复检测结果json格式错误时的处理逻辑 - run函数提添加结果json输出 - 确保fps数据存入全局变量
This commit is contained in:
parent
c678c8b96f
commit
754bfad6e9
165
Qwen_app.py
165
Qwen_app.py
|
|
@ -70,141 +70,6 @@ def update_preview(frame_idx: int, vid_file: str):
|
||||||
|
|
||||||
return img_path, class_tag
|
return img_path, class_tag
|
||||||
|
|
||||||
# def run(vid_file: str, run_sam3: bool = True, run_inspection: bool = True, gen_report: bool = True,):
|
|
||||||
# #================初始化=================
|
|
||||||
# vid_name: str = Path(vid_file).stem
|
|
||||||
# vid_end: str = Path(vid_file).suffix
|
|
||||||
# vid_path: str = f"./input/{vid_name}{vid_end}"
|
|
||||||
|
|
||||||
# cap = cv2.VideoCapture(vid_path)
|
|
||||||
# global fps
|
|
||||||
# fps = int(cap.get(cv2.CAP_PROP_FPS)) # 获取视频FPS
|
|
||||||
# cap.release()
|
|
||||||
|
|
||||||
# output_dir: str = f"output/{vid_name}"
|
|
||||||
# # interval = int(fps / 5)
|
|
||||||
# interval = fps
|
|
||||||
# conf = 0.7
|
|
||||||
# time_data: dict = {}
|
|
||||||
|
|
||||||
# annotated_frames: SAM3 = SAM3()
|
|
||||||
# result: dict = {}
|
|
||||||
|
|
||||||
# if output_dir:
|
|
||||||
# os.makedirs(output_dir, exist_ok=True) # exist_ok=True 防止重复创建报错
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# #================获取物体信息=================
|
|
||||||
# # 保存开始时间字符串
|
|
||||||
# time_data["start_time"] = str(datetime.now())
|
|
||||||
# with open(f"{output_dir}/time.json", "w") as f:
|
|
||||||
# f.write(json.dumps(time_data, ensure_ascii=False, indent=4))
|
|
||||||
|
|
||||||
# if run_sam3:
|
|
||||||
# # 针对厂房防火分区
|
|
||||||
# annotated_frames.run(vid_path, output_dir, "lib/class_list/1.厂房防火.json", interval, conf)
|
|
||||||
# else:
|
|
||||||
# annotated_frames.load_from_json(f"{output_dir}/frame_all.json")
|
|
||||||
# # print(annotated_frames.data)
|
|
||||||
|
|
||||||
# # 提取ai能看到的部分
|
|
||||||
# ai_frames: dict = get_annnotated_frame_for_ai_without_xyxy(annotated_frames.data(), 1, conf)
|
|
||||||
# save_json_to_file(ai_frames, f"{output_dir}/frame_all_ai.json")
|
|
||||||
|
|
||||||
# #================隐患检查=================
|
|
||||||
|
|
||||||
# if run_inspection:
|
|
||||||
# if vid_path.startswith("oss"):
|
|
||||||
# video_url = vid_path
|
|
||||||
# else:
|
|
||||||
# video_url: str|None = None
|
|
||||||
# video_url_file = f"{output_dir}/video_url.json"
|
|
||||||
|
|
||||||
# # 检查URL是否已存在
|
|
||||||
# if os.path.exists(video_url_file):
|
|
||||||
# try:
|
|
||||||
# with open(video_url_file, "r", encoding="utf-8") as f:
|
|
||||||
# url_data = json.load(f)
|
|
||||||
# if vid_name in url_data:
|
|
||||||
# video_url = url_data[vid_name]
|
|
||||||
# print(f"使用已存在的URL: {video_url}")
|
|
||||||
# except Exception as e:
|
|
||||||
# print(f"读取URL文件失败: {e}")
|
|
||||||
|
|
||||||
# # 如果URL不存在,上传文件
|
|
||||||
# if video_url is None:
|
|
||||||
# print(f"上传视频文件: {vid_path}")
|
|
||||||
# video_url = upload_files_and_get_urls_concurrently(
|
|
||||||
# file_path_list=[vid_path],
|
|
||||||
# max_workers=8
|
|
||||||
# )[0]
|
|
||||||
# if video_url:
|
|
||||||
# # 保存为JSON格式,包含文件名和URL的键值对
|
|
||||||
# url_data = {}
|
|
||||||
# if os.path.exists(video_url_file):
|
|
||||||
# try:
|
|
||||||
# with open(video_url_file, "r", encoding="utf-8") as f:
|
|
||||||
# url_data = json.load(f)
|
|
||||||
# except:
|
|
||||||
# pass
|
|
||||||
|
|
||||||
# url_data[vid_name] = video_url
|
|
||||||
# with open(video_url_file, "w", encoding="utf-8") as f:
|
|
||||||
# json.dump(url_data, f, ensure_ascii=False, indent=4)
|
|
||||||
# print(f"URL已保存: {video_url}")
|
|
||||||
|
|
||||||
# if video_url is None:
|
|
||||||
# raise ValueError("视频上传失败,无法获取 URL")
|
|
||||||
|
|
||||||
# result_test: str
|
|
||||||
# reason_test: str
|
|
||||||
# reason_test, result_test = hazard_inspection(ai_frames, video_url, enable_thinking=True, fps=2)
|
|
||||||
# result = json.loads(result_test)
|
|
||||||
# # merged_result = merge_conflict_inspection_data(result)
|
|
||||||
|
|
||||||
# result["class"] = ai_frames["class_list"]
|
|
||||||
|
|
||||||
# with open(f"{output_dir}/hazard_inspection.json", "w", encoding="utf-8") as f:
|
|
||||||
# f.write(json.dumps(result, ensure_ascii=False, indent=4))
|
|
||||||
# with open(f"{output_dir}/hazard_inspection_reason.json", "w", encoding="utf-8") as f:
|
|
||||||
# f.write(json.dumps(reason_test, ensure_ascii=False, indent=4))
|
|
||||||
|
|
||||||
# # with open(f"{output_dir}/hazard_inspection_merged.json", "w", encoding="utf-8") as f:
|
|
||||||
# # f.write(json.dumps(merged_result, ensure_ascii=False, indent=4))
|
|
||||||
|
|
||||||
# #================生成报告=================
|
|
||||||
# if gen_report:
|
|
||||||
# if result == {}:
|
|
||||||
# with open(f"{output_dir}/hazard_inspection.json", "r", encoding="utf-8") as f:
|
|
||||||
# result = json.load(f)
|
|
||||||
|
|
||||||
# with open(f"知识库/rule.json", "r", encoding="utf-8") as f:
|
|
||||||
# rule_definitions = json.load(f)
|
|
||||||
|
|
||||||
# report_generator(
|
|
||||||
# video_path=vid_path, # 视频文件路径
|
|
||||||
# detection_data=annotated_frames.data(), # 物体检测数据
|
|
||||||
# hazard_results=result, # 隐患检查结果
|
|
||||||
# rule_definitions=rule_definitions, # 规则定义
|
|
||||||
# output_path=output_dir, # 输出文件夹
|
|
||||||
# frame_interval=interval, # 帧间隔(可根据实际视频帧率调整)
|
|
||||||
# )
|
|
||||||
|
|
||||||
# # 保存结束时间字符串
|
|
||||||
# time_data["end_time"] = str(datetime.now())
|
|
||||||
# with open(f"{output_dir}/time.json", "w") as f:
|
|
||||||
# f.write(json.dumps(time_data, ensure_ascii=False, indent=4))
|
|
||||||
|
|
||||||
# #================更新预览=================
|
|
||||||
|
|
||||||
# # 获取总帧数
|
|
||||||
# cap = cv2.VideoCapture(vid_path)
|
|
||||||
# total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
|
||||||
# cap.release()
|
|
||||||
# img_path, class_tag = update_preview(0, vid_name)
|
|
||||||
# return gr.update(maximum=total_frames-1, value=0, step = interval), img_path, class_tag
|
|
||||||
|
|
||||||
def run(vid_file: str, run_sam3: bool = True, run_inspection: bool = True, gen_report: bool = True):
|
def run(vid_file: str, run_sam3: bool = True, run_inspection: bool = True, gen_report: bool = True):
|
||||||
vid_name = vid_file.split(".")[0] # 视频名称(不含后缀)
|
vid_name = vid_file.split(".")[0] # 视频名称(不含后缀)
|
||||||
vid_end = f".{vid_file.split('.')[-1]}" # 视频后缀
|
vid_end = f".{vid_file.split('.')[-1]}" # 视频后缀
|
||||||
|
|
@ -218,6 +83,7 @@ def run(vid_file: str, run_sam3: bool = True, run_inspection: bool = True, gen_r
|
||||||
frame_detections_path = f"{output_dir}/frame_detections.json"
|
frame_detections_path = f"{output_dir}/frame_detections.json"
|
||||||
objects_json_path = f"{output_dir}/objects.json"
|
objects_json_path = f"{output_dir}/objects.json"
|
||||||
vid_dir = f"{output_dir}/obj_vids"
|
vid_dir = f"{output_dir}/obj_vids"
|
||||||
|
json_result: dict = {}
|
||||||
|
|
||||||
rule_dict: dict = load_json_data('知识库/rule.json')
|
rule_dict: dict = load_json_data('知识库/rule.json')
|
||||||
video_url_file = f"{output_dir}/video_url.json"
|
video_url_file = f"{output_dir}/video_url.json"
|
||||||
|
|
@ -231,6 +97,7 @@ def run(vid_file: str, run_sam3: bool = True, run_inspection: bool = True, gen_r
|
||||||
# 获取总帧数与帧率
|
# 获取总帧数与帧率
|
||||||
cap = cv2.VideoCapture(input_video_path)
|
cap = cv2.VideoCapture(input_video_path)
|
||||||
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||||
|
global fps
|
||||||
fps = int(cap.get(cv2.CAP_PROP_FPS))
|
fps = int(cap.get(cv2.CAP_PROP_FPS))
|
||||||
cap.release()
|
cap.release()
|
||||||
|
|
||||||
|
|
@ -274,7 +141,7 @@ def run(vid_file: str, run_sam3: bool = True, run_inspection: bool = True, gen_r
|
||||||
# 上传视频并获取 URL
|
# 上传视频并获取 URL
|
||||||
vid_dict = get_vid_dict(vid_dir, obj_dict, video_url_file, use_url_cache)
|
vid_dict = get_vid_dict(vid_dir, obj_dict, video_url_file, use_url_cache)
|
||||||
|
|
||||||
hazard_inspection(
|
_,json_result_str = hazard_inspection(
|
||||||
output_dir,
|
output_dir,
|
||||||
obj_dict,
|
obj_dict,
|
||||||
rule_dict,
|
rule_dict,
|
||||||
|
|
@ -283,26 +150,35 @@ def run(vid_file: str, run_sam3: bool = True, run_inspection: bool = True, gen_r
|
||||||
fps=fps,
|
fps=fps,
|
||||||
enable_thinking=enable_thinking
|
enable_thinking=enable_thinking
|
||||||
)
|
)
|
||||||
|
json_result: dict = json.loads(json_result_str)
|
||||||
|
|
||||||
|
else: # 不运行隐患检查,直接加载之前的结果
|
||||||
|
with open(f"{output_dir}/hazard_inspection.json", "r", encoding="utf-8") as f:
|
||||||
|
json_result = json.load(f)
|
||||||
|
|
||||||
|
# 保存结束时间字符串
|
||||||
|
time_data["end_time"] = str(datetime.now())
|
||||||
|
with open(f"{output_dir}/time.json", "w") as f:
|
||||||
|
f.write(json.dumps(time_data, ensure_ascii=False, indent=4))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#================生成报告=================
|
#================生成报告=================
|
||||||
|
|
||||||
if gen_report:
|
if gen_report:
|
||||||
result: dict = {}
|
|
||||||
with open(f"{output_dir}/hazard_inspection.json", "r", encoding="utf-8") as f:
|
|
||||||
result = json.load(f)
|
|
||||||
|
|
||||||
with open(f"知识库/rule.json", "r", encoding="utf-8") as f:
|
with open(f"知识库/rule.json", "r", encoding="utf-8") as f:
|
||||||
rule_definitions = json.load(f)
|
rule_definitions = json.load(f)
|
||||||
|
|
||||||
report_generator(
|
report_generator(
|
||||||
video_path=input_video_path, # 视频文件路径
|
video_path=input_video_path, # 视频文件路径
|
||||||
detection_data=annotated_frames.data(), # 物体检测数据
|
detection_data=annotated_frames.data(), # 物体检测数据
|
||||||
hazard_results=result, # 隐患检查结果
|
hazard_results=json_result, # 隐患检查结果
|
||||||
rule_definitions=rule_definitions, # 规则定义
|
rule_definitions=rule_definitions, # 规则定义
|
||||||
output_path=output_dir, # 输出文件夹
|
output_path=output_dir, # 输出文件夹
|
||||||
frame_interval=interval, # 帧间隔(可根据实际视频帧率调整)
|
frame_interval=interval, # 帧间隔(可根据实际视频帧率调整)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
#================保存结束时间=================
|
||||||
# 保存结束时间字符串
|
# 保存结束时间字符串
|
||||||
time_data["end_time"] = str(datetime.now())
|
time_data["end_time"] = str(datetime.now())
|
||||||
with open(f"{output_dir}/time.json", "w") as f:
|
with open(f"{output_dir}/time.json", "w") as f:
|
||||||
|
|
@ -312,7 +188,9 @@ def run(vid_file: str, run_sam3: bool = True, run_inspection: bool = True, gen_r
|
||||||
|
|
||||||
|
|
||||||
img_path, class_tag = update_preview(0, vid_name)
|
img_path, class_tag = update_preview(0, vid_name)
|
||||||
return gr.update(maximum=total_frames-1, value=0, step = interval), img_path, class_tag
|
return gr.update(maximum=total_frames-1, value=0, step = interval), img_path, class_tag, json_result
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_vid_dict(vid_dir: str, obj_dict: dict, video_url_file: str, use_url_cache: bool) -> dict:
|
def get_vid_dict(vid_dir: str, obj_dict: dict, video_url_file: str, use_url_cache: bool) -> dict:
|
||||||
"""获取视频字典
|
"""获取视频字典
|
||||||
|
|
@ -411,10 +289,11 @@ with gr.Blocks() as demo:
|
||||||
preview = gr.Image(label="预览", scale=2)
|
preview = gr.Image(label="预览", scale=2)
|
||||||
img_slider = gr.Slider(label="帧索引", minimum=0, maximum=0, value=0, step=1)
|
img_slider = gr.Slider(label="帧索引", minimum=0, maximum=0, value=0, step=1)
|
||||||
textbox = gr.Textbox(label="隐患结果", lines=10)
|
textbox = gr.Textbox(label="隐患结果", lines=10)
|
||||||
|
jsonbox = gr.JSON(label="隐患结果json")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
run_button.click(fn=run, inputs=[vid_file, run_sam3, run_inspection, gen_report], outputs=[img_slider, preview, textbox])
|
run_button.click(fn=run, inputs=[vid_file, run_sam3, run_inspection, gen_report], outputs=[img_slider, preview, textbox, jsonbox])
|
||||||
img_slider.change(fn=update_preview, inputs=[img_slider, vid_file], outputs=[preview, textbox], show_progress="hidden")
|
img_slider.change(fn=update_preview, inputs=[img_slider, vid_file], outputs=[preview, textbox], show_progress="hidden")
|
||||||
get_vid_path_btn.click(fn=get_full_vid_path, inputs=vid_file, outputs=full_path_text)
|
get_vid_path_btn.click(fn=get_full_vid_path, inputs=vid_file, outputs=full_path_text)
|
||||||
|
|
||||||
|
|
|
||||||
177
lib/qwen_fun.py
177
lib/qwen_fun.py
|
|
@ -359,6 +359,38 @@ def search_knowledge_base(
|
||||||
print(f"请求失败,状态码: {response.status_code}")
|
print(f"请求失败,状态码: {response.status_code}")
|
||||||
print("错误信息:", response.text)
|
print("错误信息:", response.text)
|
||||||
return ""
|
return ""
|
||||||
|
def format_check(output_data: dict) -> bool:
|
||||||
|
"""检查 JSON 字符串是否符合要求"""
|
||||||
|
try:
|
||||||
|
# 检查必要的键
|
||||||
|
required_keys = ['tag', 'base', 'objects']
|
||||||
|
for key in required_keys:
|
||||||
|
if key not in output_data:
|
||||||
|
print(f"Missing required key: {key}")
|
||||||
|
return False
|
||||||
|
# 检查tag是否为列表
|
||||||
|
if not isinstance(output_data['tag'], list):
|
||||||
|
print("tag must be a list")
|
||||||
|
return False
|
||||||
|
# 检查base是否为列表
|
||||||
|
if not isinstance(output_data['base'], list):
|
||||||
|
print("base must be a list")
|
||||||
|
return False
|
||||||
|
# 检查objects是否为列表
|
||||||
|
if not isinstance(output_data['objects'], list):
|
||||||
|
print("objects must be a list")
|
||||||
|
return False
|
||||||
|
# 检查objects中的每个元素
|
||||||
|
for obj in output_data['objects']:
|
||||||
|
required_obj_keys = ['tag_id', 'base_id', 'hazard_track_id', 'conf', 'location', 'recommend']
|
||||||
|
for key in required_obj_keys:
|
||||||
|
if key not in obj:
|
||||||
|
print(f"Missing required key in object: {key}")
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Format check failed: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
def hazard_inspection(
|
def hazard_inspection(
|
||||||
output_dir: str,
|
output_dir: str,
|
||||||
|
|
@ -407,83 +439,99 @@ def hazard_inspection(
|
||||||
for track_id, track_info in obj_dict.items():
|
for track_id, track_info in obj_dict.items():
|
||||||
class_str = class_list[track_info["class_id"]]
|
class_str = class_list[track_info["class_id"]]
|
||||||
vid_url = vid_dict[track_id]["vid_url"]
|
vid_url = vid_dict[track_id]["vid_url"]
|
||||||
result = single_object_inspection(vid_url, class_str, rule_dict, prompt_str, enable_thinking)
|
result: tuple[str, str] | str = ""
|
||||||
if result == "skip":
|
reasoning_text: str = ""
|
||||||
continue
|
answer_text: str = ""
|
||||||
else:
|
|
||||||
reasoning_text, answer_text = result
|
|
||||||
|
|
||||||
# 整合思考过程
|
while True:
|
||||||
all_reasoning_text += f"\n\n# 类别: {track_id}:{class_str}\n{reasoning_text}"
|
result = single_object_inspection(vid_url, class_str, rule_dict, prompt_str, enable_thinking)
|
||||||
|
if result == "skip":
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
reasoning_text, answer_text = result
|
||||||
|
|
||||||
# 解析 JSON 结果
|
# 解析 JSON 结果
|
||||||
try:
|
try:
|
||||||
# 提取 JSON 部分(去除可能的代码块标记)
|
# 提取 JSON 部分(去除可能的代码块标记)
|
||||||
json_str = answer_text.strip()
|
json_str = answer_text.strip()
|
||||||
if json_str.startswith("```json"):
|
if json_str.startswith("```json"):
|
||||||
json_str = json_str[7:]
|
json_str = json_str[7:]
|
||||||
if json_str.startswith("```"):
|
if json_str.startswith("```"):
|
||||||
json_str = json_str[3:]
|
json_str = json_str[3:]
|
||||||
if json_str.endswith("```"):
|
if json_str.endswith("```"):
|
||||||
json_str = json_str[:-3]
|
json_str = json_str[:-3]
|
||||||
json_str = json_str.strip()
|
json_str = json_str.strip()
|
||||||
|
|
||||||
class_result = json.loads(json_str)
|
class_result = json.loads(json_str)
|
||||||
|
|
||||||
# 保存class_result到文件,便于调试
|
# 检查 JSON 格式是否正确
|
||||||
with open(f"{output_dir}/inspection_result_{int(track_id):03d}.txt", "w", encoding="utf-8") as f:
|
if isinstance(class_result, dict) and format_check(class_result):
|
||||||
f.write(json.dumps(class_result, ensure_ascii=False, indent=2))
|
class_result = class_result
|
||||||
|
# 检查是否为列表且包含有效元素
|
||||||
|
elif isinstance(class_result, list) and len(class_result) > 0 and format_check(class_result[0]):
|
||||||
|
class_result = class_result[0]
|
||||||
|
else:
|
||||||
|
print("JSON 格式错误, 重试隐患检查")
|
||||||
|
continue # 重试隐患检查,直到格式正确
|
||||||
|
|
||||||
# 收集该类别的结果
|
# 保存class_result到文件,便于调试
|
||||||
temp_obj = {}
|
with open(f"{output_dir}/inspection_result_{int(track_id):03d}.txt", "w", encoding="utf-8") as f:
|
||||||
for obj in class_result.get("objects", []):
|
f.write(json.dumps(class_result, ensure_ascii=False, indent=2))
|
||||||
|
|
||||||
# tag_id
|
# 整合思考过程
|
||||||
if "tag_id" in obj and obj["tag_id"] < len(class_result.get("tag", [])):
|
all_reasoning_text += f"\n\n# 类别: {track_id}:{class_str}\n{reasoning_text}"
|
||||||
tag: str = class_result["tag"][obj["tag_id"]]
|
|
||||||
else:
|
|
||||||
tag: str = ""
|
|
||||||
|
|
||||||
if tag not in all_tags:
|
# 收集该类别的结果
|
||||||
all_tags.append(tag)
|
temp_obj = {}
|
||||||
temp_obj["tag_id"] = all_tags.index(tag)
|
for obj in class_result.get("objects", []):
|
||||||
|
|
||||||
# base_id
|
# tag_id
|
||||||
if "base_id" in obj and obj["base_id"] < len(class_result.get("base", [])):
|
if "tag_id" in obj and obj["tag_id"] < len(class_result.get("tag", [])):
|
||||||
base = class_result["base"][obj["base_id"]]
|
tag: str = class_result["tag"][obj["tag_id"]]
|
||||||
else:
|
else:
|
||||||
base = None
|
tag: str = ""
|
||||||
|
|
||||||
if base not in all_bases:
|
if tag not in all_tags:
|
||||||
all_bases.append(base)
|
all_tags.append(tag)
|
||||||
temp_obj["base_id"] = all_bases.index(base)
|
temp_obj["tag_id"] = all_tags.index(tag)
|
||||||
|
|
||||||
# 隐患等级
|
# base_id
|
||||||
if get_hazard_level_from_rule(rule_dict, class_str, tag) == "重大隐患":
|
if "base_id" in obj and obj["base_id"] < len(class_result.get("base", [])):
|
||||||
temp_obj["level"] = 1
|
base = class_result["base"][obj["base_id"]]
|
||||||
elif get_hazard_level_from_rule(rule_dict, class_str, tag) == "一般隐患":
|
else:
|
||||||
temp_obj["level"] = 0
|
base = None
|
||||||
else:
|
|
||||||
temp_obj["level"] = -1,
|
|
||||||
|
|
||||||
# 其他字段
|
if base not in all_bases:
|
||||||
temp_obj["track_id"] = track_id
|
all_bases.append(base)
|
||||||
temp_obj["hazard_track_id"] = hazard_count
|
temp_obj["base_id"] = all_bases.index(base)
|
||||||
temp_obj["class_id"] = track_info["class_id"]
|
|
||||||
temp_obj["conf"] = obj["conf"]
|
|
||||||
temp_obj["start_frame"] = track_info["start_frame"]
|
|
||||||
temp_obj["end_frame"] = track_info["end_frame"]
|
|
||||||
temp_obj["start_sec"] = round(track_info["start_frame"] / fps, 1) # 处理成x.x秒
|
|
||||||
temp_obj["location"] = obj.get("location", "")
|
|
||||||
|
|
||||||
hazard_count += 1
|
# 隐患等级
|
||||||
|
if get_hazard_level_from_rule(rule_dict, class_str, tag) == "重大隐患":
|
||||||
|
temp_obj["level"] = 1
|
||||||
|
elif get_hazard_level_from_rule(rule_dict, class_str, tag) == "一般隐患":
|
||||||
|
temp_obj["level"] = 0
|
||||||
|
else:
|
||||||
|
temp_obj["level"] = -1,
|
||||||
|
|
||||||
all_objects.append(temp_obj)
|
# 其他字段
|
||||||
|
temp_obj["track_id"] = track_id
|
||||||
|
temp_obj["hazard_track_id"] = hazard_count
|
||||||
|
temp_obj["class_id"] = track_info["class_id"]
|
||||||
|
temp_obj["conf"] = obj["conf"]
|
||||||
|
temp_obj["start_frame"] = track_info["start_frame"]
|
||||||
|
temp_obj["end_frame"] = track_info["end_frame"]
|
||||||
|
temp_obj["start_sec"] = round(track_info["start_frame"] / fps, 1) # 处理成x.x秒
|
||||||
|
temp_obj["location"] = obj.get("location", "")
|
||||||
|
|
||||||
except json.JSONDecodeError as e:
|
hazard_count += 1
|
||||||
print(f"解析类别 {class_str} 的 JSON 结果失败: {e}")
|
|
||||||
print(f"原始输出: {answer_text}")
|
all_objects.append(temp_obj)
|
||||||
continue
|
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
print(f"解析类别 {class_str} 的 JSON 结果失败: {e}")
|
||||||
|
print(f"原始输出: {answer_text}")
|
||||||
|
continue
|
||||||
|
break # 跳出当前循环,继续下一个物体
|
||||||
|
|
||||||
# 构建最终结果
|
# 构建最终结果
|
||||||
final_result = {
|
final_result = {
|
||||||
|
|
@ -503,6 +551,7 @@ def hazard_inspection(
|
||||||
|
|
||||||
return all_reasoning_text, json_result
|
return all_reasoning_text, json_result
|
||||||
|
|
||||||
|
|
||||||
def single_object_inspection(vid_url: str, class_str: str, rule_dict: dict, prompt_str: str, enable_thinking: bool = True) -> tuple[str, str] | str:
|
def single_object_inspection(vid_url: str, class_str: str, rule_dict: dict, prompt_str: str, enable_thinking: bool = True) -> tuple[str, str] | str:
|
||||||
# 获取检查规则
|
# 获取检查规则
|
||||||
rule = get_rule_by_class_str(class_str, rule_dict)
|
rule = get_rule_by_class_str(class_str, rule_dict)
|
||||||
|
|
@ -686,7 +735,7 @@ def report_generator(
|
||||||
start_frame = max(0, min(start_frame, total_frames - 1))
|
start_frame = max(0, min(start_frame, total_frames - 1))
|
||||||
|
|
||||||
# 获取隐患对应的类名和标签
|
# 获取隐患对应的类名和标签
|
||||||
class_name = hazard_results["class"][obj["class_id"]]
|
class_name = hazard_results["class_list"][obj["class_id"]]
|
||||||
tag_name = hazard_results["tag"][obj["tag_id"]]
|
tag_name = hazard_results["tag"][obj["tag_id"]]
|
||||||
original_level = obj["level"]
|
original_level = obj["level"]
|
||||||
location = obj.get("location", "")
|
location = obj.get("location", "")
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue