791 lines
28 KiB
Python
791 lines
28 KiB
Python
import json
|
||
import os
|
||
from pathlib import Path
|
||
from lib.json_fun import load_json_data
|
||
import requests
|
||
from openai import BadRequestError, OpenAI
|
||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||
import requests
|
||
import cv2
|
||
|
||
# Image list: keys are drop-down option names, values are dicts holding the
# network URL and the local path of each file.
# NOTE: use valid image URLs, or download the images and use local paths.
file_dict: dict = {}
# file_info_path: str = "file_dict.json"
file_info_path: str = r"D:\Userfile\Downloads\export_urls.csv"

# Default runtime configuration (replace with your own API keys).
# `setdefault` is used so that values already provided by the real process
# environment are respected instead of being silently overwritten.
# SECURITY NOTE(review): these credentials are committed in source control;
# they should be rotated and supplied only through the environment.
os.environ.setdefault("DASHSCOPE_API_KEY", 'sk-c7ffed1b32794284a5d21b046d7d0008')
os.environ.setdefault("MAXKB_API_KEY", 'agent-300b841bd6a7cb3a0bf603c093c22398')
os.environ.setdefault("MODEL_NAME", 'qwen3.5-27b')
|
||
|
||
def chat(prompt, file_path: str | list, enable_thinking=False, fps=10, thinking_budget: int=81920):
|
||
"""
|
||
调用 DashScope 接口进行图片理解与提示词检测。
|
||
参数 img_info 为包含网络和本地路径的字典。
|
||
"""
|
||
print("开始处理视频...")
|
||
print("当前文件: ", file_path)
|
||
print("============提示词============")
|
||
print(prompt)
|
||
|
||
enable_thinking = enable_thinking
|
||
client = OpenAI(
|
||
api_key=os.getenv("DASHSCOPE_API_KEY"),
|
||
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||
)
|
||
messages = []
|
||
|
||
if type(file_path) == str:
|
||
if '.jpg' in file_path.lower():
|
||
file_path = [["img", file_path]]
|
||
elif '.mp4' or '.avi' in file_path.lower():
|
||
file_path = [["vid", file_path, fps]]
|
||
|
||
|
||
content: list = []
|
||
for item in file_path:
|
||
if item[0] == "img":
|
||
content.append({"type": "image_url", "image_url": {"url": item[1]}})
|
||
elif item[0] == "vid":
|
||
content.append({"type": "video_url", "video_url": {"url": item[1]}, "fps": item[2]})
|
||
else:
|
||
raise ValueError(f"不支持的文件类型:{item[0]}", 99999)
|
||
|
||
content.append({"type": "text", "text": prompt})
|
||
|
||
messages = [{
|
||
"role": "user",
|
||
"content": content
|
||
}]
|
||
|
||
# 开启流式输出
|
||
try:
|
||
completion = client.chat.completions.create(
|
||
model=os.getenv("MODEL_NAME"), # type: ignore
|
||
messages=messages, # type: ignore
|
||
extra_body={"enable_thinking": enable_thinking, "thinking_budget": thinking_budget},
|
||
stream=True,
|
||
# 支持访问图片和视频的OSS协议 URL
|
||
extra_headers={"X-DashScope-OssResourceResolve": "enable"}
|
||
) # type: ignore
|
||
except BadRequestError as e:
|
||
# 这里捕捉到特定的 400 错误
|
||
# 判断是否是图片相关的错误
|
||
if "Failed to download multimodal content" in str(e):
|
||
# 返回空的模型回复,并返回错误信息
|
||
raise ValueError("❌ 图片链接不存在或已过期")
|
||
# return "", "❌ 图片链接不存在或已过期"
|
||
else:
|
||
# 其他错误直接抛出或返回通用错误
|
||
raise ValueError(f"❌ 发生错误: {str(e)}")
|
||
except Exception as e:
|
||
# 捕捉其他未知错误
|
||
raise ValueError(f"❌ 未知错误: {str(e)}")
|
||
|
||
# 捕获流式数据
|
||
reasoning_text = ""
|
||
answer_text = ""
|
||
is_answering = False
|
||
|
||
if enable_thinking:
|
||
print("============思考过程============")
|
||
for chunk in completion:
|
||
delta = chunk.choices[0].delta
|
||
# 思考过程
|
||
if hasattr(delta, "reasoning_content") and delta.reasoning_content:
|
||
reasoning_text += delta.reasoning_content
|
||
print(delta.reasoning_content, end="")
|
||
# 最终答案
|
||
if hasattr(delta, "content") and delta.content:
|
||
if not is_answering:
|
||
is_answering = True
|
||
print("============最终答案============")
|
||
answer_text += delta.content
|
||
print(delta.content, end="")
|
||
print() # 换行
|
||
|
||
# 返回格式化后的结果
|
||
return reasoning_text, answer_text
|
||
|
||
def upload_file_and_get_url(file_path: str, timeout: float | tuple[float, float] = (10, 300)) -> str:
    """Upload one local file to the DashScope temporary OSS space (internal use).

    Args:
        file_path: Path of the local file to upload.
        timeout: ``requests`` timeout applied to both HTTP calls, either a
            single number or a ``(connect, read)`` tuple. Without a timeout
            a stalled connection would block the caller forever.

    Returns:
        The ``oss://`` URL of the uploaded object.

    Raises:
        RuntimeError: If the upload policy cannot be obtained or the upload
            is rejected by the server.
    """
    # 1. Fetch the upload policy.
    policy_resp = requests.get(
        "https://dashscope.aliyuncs.com/api/v1/uploads",
        headers={"Authorization": f"Bearer {os.getenv('DASHSCOPE_API_KEY')}"},
        params={"action": "getPolicy", "model": os.getenv('MODEL_NAME')},
        timeout=timeout,
    )
    if policy_resp.status_code != 200:
        raise RuntimeError(f"获取上传策略失败,状态码: {policy_resp.status_code}, 响应: {policy_resp.text}")

    policy_json = policy_resp.json()
    if 'data' not in policy_json:
        raise RuntimeError(f"获取上传策略失败,返回内容不包含 'data' 字段: {policy_json}")
    policy_data = policy_json['data']

    # 2. Upload the file to OSS as a multipart form using the policy fields.
    file_name = Path(file_path).name
    key = f"{policy_data['upload_dir']}/{file_name}"

    with open(file_path, 'rb') as f:
        files = {
            'OSSAccessKeyId': (None, policy_data['oss_access_key_id']),
            'Signature': (None, policy_data['signature']),
            'policy': (None, policy_data['policy']),
            'x-oss-object-acl': (None, policy_data['x_oss_object_acl']),
            'x-oss-forbid-overwrite': (None, policy_data['x_oss_forbid_overwrite']),
            'key': (None, key),
            'success_action_status': (None, '200'),
            'file': (file_name, f),
        }
        response = requests.post(policy_data['upload_host'], files=files, timeout=timeout)

    if response.status_code != 200:
        raise RuntimeError(f"文件上传失败: {file_path}, 状态码: {response.status_code}, 响应: {response.text}")

    return f"oss://{key}"
|
||
|
||
def upload_files_and_get_urls_concurrently(
|
||
file_path_list: list[str],
|
||
max_workers: int | None = None,
|
||
timeout: int = 300
|
||
) -> list[str | None]:
|
||
"""
|
||
并发上传多个文件到 OSS 并返回对应的 URL 列表(顺序对应)
|
||
"""
|
||
# 初始化一个与文件列表长度相同的结果列表
|
||
results: list[str | None] = [None] * len(file_path_list)
|
||
|
||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||
# 使用字典存储 future 与其对应的索引
|
||
future_to_index = {
|
||
executor.submit(
|
||
upload_file_and_get_url,
|
||
file_path
|
||
): idx
|
||
for idx, file_path in enumerate(file_path_list)
|
||
}
|
||
|
||
for future in as_completed(future_to_index, timeout=timeout):
|
||
idx = future_to_index[future]
|
||
file_path = file_path_list[idx]
|
||
try:
|
||
url = future.result()
|
||
results[idx] = url # 将结果放回对应的索引位置
|
||
print(f"[成功] {file_path} -> {url}")
|
||
except Exception as exc:
|
||
# 捕获异常并打印错误信息
|
||
print(f"[错误] {file_path} 上传失败: {exc}")
|
||
results[idx] = None # 保持为 None 或者设为错误标识
|
||
|
||
return results
|
||
|
||
def save_json_to_file(json_data, file_path: str) -> None:
    """Write ``json_data`` to ``file_path`` as indented UTF-8 JSON.

    The parent directory is created when the path contains one. A failure
    while opening or serialising is printed and swallowed; on success a
    confirmation line is printed.

    Args:
        json_data: JSON-serialisable data to store.
        file_path: Destination file path.
    """
    parent_dir = os.path.dirname(file_path)
    if parent_dir:
        # exist_ok avoids an error when the directory is already present.
        os.makedirs(parent_dir, exist_ok=True)

    try:
        with open(file_path, 'w', encoding='utf-8') as out:
            json.dump(json_data, out, ensure_ascii=False, indent=4)
    except Exception as err:
        print(err)
    else:
        print(f"已保存 JSON 数据到: {file_path}")
|
||
|
||
|
||
|
||
def get_unique_track_id_count(json_path: str) -> int:
    """Count the distinct ``track_id`` values stored in a detection JSON file.

    Args:
        json_path: Path of the JSON file, loaded through ``load_json_data``.

    Returns:
        Number of different ``track_id`` values over all frames; detections
        without a ``track_id`` key are ignored.
    """
    frames = load_json_data(json_path)
    distinct_ids = {
        detection['track_id']
        for detections in frames.values()
        for detection in detections
        if 'track_id' in detection
    }
    return len(distinct_ids)
|
||
|
||
def get_annnotated_frame_for_ai_without_xyxy(json_data: dict, interval: int = 5, conf: float = 0.7) -> dict:
    """Summarise per-frame detections into class names and track lifetimes.

    Every ``interval``-th frame (in key order) is kept and renumbered
    0, 1, 2, ...; the renumbered indices are used for the frame ranges.

    Args:
        json_data: Mapping of frame id to a list of detections. Each detection
            is read for ``class_str``, ``confidence`` and ``track_id``.
        interval: Sampling step over the frame ids.
        conf: Detections with a confidence below this are left out of the
            track list (they still contribute to the class list).

    Returns:
        A dict shaped like::

            {
                "class_list": ["class_name1", "class_name2", ...],
                "track_id_list": {
                    track_id: {"class_id": 0, "start_frame": 0, "end_frame": 10},
                    ...
                },
            }
    """
    sampled_ids: list = list(json_data)[::interval]

    # Sorted, de-duplicated class names; list position is the class id.
    class_list = sorted({
        det["class_str"]
        for frame_id in sampled_ids
        for det in json_data[frame_id]
    })
    class_index = {name: position for position, name in enumerate(class_list)}

    tracks = {}
    for new_frame_no, frame_id in enumerate(sampled_ids):
        for det in json_data[frame_id]:
            if det["confidence"] < conf:
                continue

            track_id = det["track_id"]
            class_id = class_index[det["class_str"]]

            existing = tracks.get(track_id)
            if existing is None:
                # First sighting fixes the class and opens the range.
                tracks[track_id] = {
                    "class_id": class_id,
                    "start_frame": new_frame_no,
                    "end_frame": new_frame_no,
                }
            else:
                existing["end_frame"] = new_frame_no

    return {"class_list": class_list, "track_id_list": tracks}
|
||
|
||
def search_knowledge_base(
    question: str,
    system_prompt: str = "你是科技有限公司 MaxKB 知识问答系统的智能小,你的工作是 MaxKB 用户解答中,用户找你回答问题时,你要把主题放在 MaxKB 知识问答系统身上",
    **kwargs
):
    """Ask the local MaxKB chat-completions endpoint a question.

    Args:
        question: The user question to send.
        system_prompt: Instruction sent as the ``system`` message.
        **kwargs: Accepted for call-site compatibility; currently unused.

    Returns:
        The stripped answer text, or ``""`` when the request fails, the
        response cannot be parsed, or it contains no ``choices``.
    """
    # 1. Request URL.
    url = "http://localhost:8080/chat/api/019d5265-e90e-7663-b202-fc5a47d8601c/chat/completions"

    # 2. Headers. Prefer the configured key; the literal is kept only as a
    # fallback so behaviour is unchanged when the variable is not set.
    api_key = os.getenv("MAXKB_API_KEY", "agent-300b841bd6a7cb3a0bf603c093c22398")
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }

    # 3. Payload. The system prompt belongs in a message's content, not in
    # its ``role``; the user question goes last.
    payload = {
        "model": os.getenv("MODEL_NAME", "qwen3.5-27b"),
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": question},
        ],
    }

    # 4. Send the request with a timeout so a stalled server cannot block
    # the caller indefinitely.
    try:
        response = requests.post(url, headers=headers, json=payload, timeout=120)
    except requests.RequestException as e:
        print("请求失败:", e)
        return ""

    # 5. Handle the response.
    if response.status_code != 200:
        print(f"请求失败,状态码: {response.status_code}")
        print("错误信息:", response.text)
        return ""

    try:
        data = response.json()
        # Full response structure, for debugging.
        print("Full Response:")
        print(json.dumps(data, ensure_ascii=False, indent=2))

        # The answer is expected at data['choices'][0]['message']['content'].
        if 'choices' in data and data['choices']:
            answer = data['choices'][0].get('message', {}).get('content', '') or ""
            print("\n=== MaxKB 回答 ===")
            print(answer.strip())
            return answer.strip()

        print("\n未在响应中找到 'choices' 字段。")
        return ""
    except (ValueError, AttributeError, TypeError, LookupError) as e:
        print("解析响应时出错:", e)
        print("原始响应:", response.text)
        return ""
|
||
def format_check(output_data: dict) -> bool:
    """Validate the structure of a parsed inspection result.

    The result must contain ``tag``, ``base`` and ``objects``, each a list,
    and every entry of ``objects`` must provide all per-object keys. The
    first violation is printed and ends the check.

    Args:
        output_data: Parsed model output to validate.

    Returns:
        ``True`` when the structure is acceptable, otherwise ``False``
        (also when the check itself raises).
    """
    top_level_keys = ('tag', 'base', 'objects')
    per_object_keys = ('tag_id', 'base_id', 'hazard_track_id', 'conf', 'location', 'recommend')

    try:
        # All top-level keys must exist before any type is inspected.
        absent = next((name for name in top_level_keys if name not in output_data), None)
        if absent is not None:
            print(f"Missing required key: {absent}")
            return False

        # Each top-level value has to be a list.
        for name in top_level_keys:
            if not isinstance(output_data[name], list):
                print(f"{name} must be a list")
                return False

        # Every object needs the full key set.
        for entry in output_data['objects']:
            absent = next((name for name in per_object_keys if name not in entry), None)
            if absent is not None:
                print(f"Missing required key in object: {absent}")
                return False
    except Exception as err:
        print(f"Format check failed: {err}")
        return False

    return True
|
||
|
||
def hazard_inspection(
|
||
output_dir: str,
|
||
obj_dict: dict,
|
||
rule_dict: dict,
|
||
class_list: list[str],
|
||
vid_dict: dict,
|
||
fps: int,
|
||
enable_thinking: bool = True,
|
||
) -> tuple[str, str]:
|
||
"""
|
||
检查视频中给定类型物体的隐患
|
||
2026.04.17
|
||
|
||
参数:
|
||
output_dir: 输出目录
|
||
obj_dict: 物体视频字典
|
||
class_list: 类别列表
|
||
vid_dict: 视频字典
|
||
fps: 视频帧率
|
||
enable_thinking: 是否开启思考模式
|
||
|
||
obj_dict 格式:
|
||
|
||
vid_dict 格式:
|
||
|
||
返回:
|
||
tuple[str, str]: (reasoning_text, json_result)
|
||
"""
|
||
|
||
# 存储所有类别的结果
|
||
all_reasoning_text = ""
|
||
|
||
# 存储所有类别的 JSON 结果
|
||
all_tags = []
|
||
all_bases = []
|
||
all_objects = []
|
||
hazard_count = 0
|
||
|
||
|
||
with open(r'prompt\7_检测提示词_视频_new4.md', 'r', encoding='utf-8') as file:
|
||
prompt_str = file.read()
|
||
|
||
|
||
# 遍历每个物体视频,进行检查
|
||
for track_id, track_info in obj_dict.items():
|
||
class_str = class_list[track_info["class_id"]]
|
||
vid_url = vid_dict[track_id]["vid_url"]
|
||
result: tuple[str, str] | str = ""
|
||
reasoning_text: str = ""
|
||
answer_text: str = ""
|
||
|
||
while True:
|
||
result = single_object_inspection(vid_url, class_str, rule_dict, prompt_str, enable_thinking)
|
||
if result == "skip":
|
||
break
|
||
else:
|
||
reasoning_text, answer_text = result
|
||
|
||
# 解析 JSON 结果
|
||
try:
|
||
# 提取 JSON 部分(去除可能的代码块标记)
|
||
json_str = answer_text.strip()
|
||
if json_str.startswith("```json"):
|
||
json_str = json_str[7:]
|
||
if json_str.startswith("```"):
|
||
json_str = json_str[3:]
|
||
if json_str.endswith("```"):
|
||
json_str = json_str[:-3]
|
||
json_str = json_str.strip()
|
||
|
||
class_result = json.loads(json_str)
|
||
|
||
# 检查 JSON 格式是否正确
|
||
if isinstance(class_result, dict) and format_check(class_result):
|
||
class_result = class_result
|
||
# 检查是否为列表且包含有效元素
|
||
elif isinstance(class_result, list) and len(class_result) > 0 and format_check(class_result[0]):
|
||
class_result = class_result[0]
|
||
else:
|
||
print("JSON 格式错误, 重试隐患检查")
|
||
continue # 重试隐患检查,直到格式正确
|
||
|
||
# 保存class_result到文件,便于调试
|
||
with open(f"{output_dir}/inspection_result_{int(track_id):03d}.txt", "w", encoding="utf-8") as f:
|
||
f.write(json.dumps(class_result, ensure_ascii=False, indent=2))
|
||
|
||
# 整合思考过程
|
||
all_reasoning_text += f"\n\n# 类别: {track_id}:{class_str}\n{reasoning_text}"
|
||
|
||
# 收集该类别的结果
|
||
temp_obj = {}
|
||
for obj in class_result.get("objects", []):
|
||
|
||
# tag_id
|
||
if "tag_id" in obj and obj["tag_id"] < len(class_result.get("tag", [])):
|
||
tag: str = class_result["tag"][obj["tag_id"]]
|
||
else:
|
||
continue # 重试隐患检查,直到格式正确
|
||
|
||
if tag not in all_tags:
|
||
all_tags.append(tag)
|
||
temp_obj["tag_id"] = all_tags.index(tag)
|
||
|
||
# base_id
|
||
if "base_id" in obj and obj["base_id"] < len(class_result.get("base", [])):
|
||
base = class_result["base"][obj["base_id"]]
|
||
else:
|
||
continue # 重试隐患检查,直到格式正确
|
||
|
||
if base not in all_bases:
|
||
all_bases.append(base)
|
||
temp_obj["base_id"] = all_bases.index(base)
|
||
|
||
# 隐患等级
|
||
if get_hazard_level_from_rule(rule_dict, class_str, tag) == "重大隐患":
|
||
temp_obj["level"] = 1
|
||
elif get_hazard_level_from_rule(rule_dict, class_str, tag) == "一般隐患":
|
||
temp_obj["level"] = 0
|
||
else:
|
||
temp_obj["level"] = -1,
|
||
|
||
# 其他字段
|
||
temp_obj["track_id"] = track_id
|
||
temp_obj["hazard_track_id"] = hazard_count
|
||
temp_obj["class_id"] = track_info["class_id"]
|
||
temp_obj["conf"] = obj["conf"]
|
||
temp_obj["start_frame"] = track_info["start_frame"]
|
||
temp_obj["end_frame"] = track_info["end_frame"]
|
||
temp_obj["start_sec"] = round(track_info["start_frame"] / fps, 1) # 处理成x.x秒
|
||
temp_obj["location"] = obj.get("location", "")
|
||
temp_obj["recommend"] = obj.get("recommend", "")
|
||
|
||
hazard_count += 1
|
||
|
||
all_objects.append(temp_obj)
|
||
|
||
except json.JSONDecodeError as e:
|
||
print(f"解析类别 {class_str} 的 JSON 结果失败: {e}")
|
||
print(f"原始输出: {answer_text}")
|
||
continue
|
||
break # 跳出当前循环,继续下一个物体
|
||
|
||
# 构建最终结果
|
||
final_result = {
|
||
"class_list": class_list,
|
||
"tag": all_tags,
|
||
"base": all_bases,
|
||
"objects": all_objects
|
||
}
|
||
|
||
# 转换为 JSON 字符串
|
||
json_result = json.dumps(final_result, ensure_ascii=False, indent=2)
|
||
|
||
with open(f"{output_dir}/hazard_inspection.json", "w", encoding="utf-8") as f:
|
||
f.write(json_result)
|
||
with open(f"{output_dir}/hazard_inspection_reasoning.md", "w", encoding="utf-8") as f:
|
||
f.write(all_reasoning_text)
|
||
|
||
return all_reasoning_text, json_result
|
||
|
||
|
||
def single_object_inspection(vid_url: str, class_str: str, rule_dict: dict, prompt_str: str, enable_thinking: bool = True) -> tuple[str, str] | str:
    """Run the hazard check for one object clip.

    Args:
        vid_url: URL of the object's video clip.
        class_str: Class name of the object being inspected.
        rule_dict: Rule dictionary searched by ``get_rule_by_class_str``.
        prompt_str: Shared prompt text appended after the rule.
        enable_thinking: Passed through to ``chat``.

    Returns:
        The string ``"skip"`` when the class has no rule, otherwise the
        value returned by ``chat``.
    """
    rule_text = get_rule_by_class_str(class_str, rule_dict)

    # Nothing to check against: tell the caller to skip this object.
    if rule_text == "":
        print(f"类别 {class_str} 没有检查规则,跳过检查")
        return "skip"

    print(f"类别 {class_str} 的检查规则: {rule_text}")

    # Class-specific prompt: target, rule, then the shared instructions.
    prompt: str = f"""
# 检测对象为视频中出现的{class_str}

# 检查规则
{rule_text}

{prompt_str}
"""

    return chat(prompt, vid_url, enable_thinking)
|
||
|
||
|
||
def get_rule_by_class_str(class_str: str, all_check_rule_dict: dict) -> str:
    """Look up the inspection rule for a class name.

    Args:
        class_str: Class name to search for.
        all_check_rule_dict: Mapping of scene name to a mapping of class
            name to rule.

    Returns:
        The rule of the first scene containing ``class_str``, converted
        with ``str``; an empty string when no scene has it.
    """
    print(f"当前类别: {class_str}")
    for scene_rules in all_check_rule_dict.values():
        if class_str in scene_rules.keys():
            return str(scene_rules[class_str])
    return ""
|
||
|
||
def merge_conflict_inspection_data(data: dict) -> dict:
    """Merge hazards with the same tag and class whose frame ranges overlap.

    Objects are grouped by ``(tag_id, class_id)``. Inside a group, ranges
    that overlap or touch are merged into one entry with the highest level
    seen. Merged entries get new ``hazard_track_id`` values starting at 1.

    Args:
        data: Dict with an ``"objects"`` list; each object provides
            ``tag_id``, ``class_id``, ``level``, ``start_frame`` and
            ``end_frame``.

    Returns:
        The same ``data`` dict (modified in place) with ``"objects"``
        replaced by the merged list. Merged entries carry only the
        ``hazard_track_id``, ``tag_id``, ``class_id``, ``level``,
        ``start_frame`` and ``end_frame`` fields.
    """
    # 1. Group the objects by (tag_id, class_id).
    groups: dict = {}
    for obj in data["objects"]:
        groups.setdefault((obj["tag_id"], obj["class_id"]), []).append(obj)

    merged_objects = []

    def emit(tag_id, class_id, level, start, end) -> None:
        """Append one merged interval, numbering entries from 1."""
        merged_objects.append({
            "hazard_track_id": len(merged_objects) + 1,
            "tag_id": tag_id,
            "class_id": class_id,
            "level": level,
            "start_frame": start,
            "end_frame": end,
        })

    # 2. Sweep each group in start-frame order.
    for (tag_id, class_id), objs in groups.items():
        ordered = sorted(objs, key=lambda item: item["start_frame"])
        first = ordered[0]
        cur_start, cur_end, cur_level = first["start_frame"], first["end_frame"], first["level"]

        for obj in ordered[1:]:
            if obj["start_frame"] <= cur_end:
                # Overlapping or touching: extend the current interval.
                cur_end = max(cur_end, obj["end_frame"])
                cur_level = max(cur_level, obj["level"])
                continue

            # Gap: close the current interval and start a new one, including
            # its own level.
            emit(tag_id, class_id, cur_level, cur_start, cur_end)
            cur_start, cur_end, cur_level = obj["start_frame"], obj["end_frame"], obj["level"]

        # The interval still open when the group ends must be kept as well.
        emit(tag_id, class_id, cur_level, cur_start, cur_end)

    data["objects"] = merged_objects
    return data
|
||
|
||
def get_hazard_level_from_rule(rule_definitions: dict, class_name: str, tag_name: str) -> str:
    """Find the hazard level under which a tag is listed for a class.

    Args:
        rule_definitions: Mapping of class name to a mapping of level name
            to a dict keyed by tag name.
        class_name: Hazard class name (e.g. "灭火器").
        tag_name: Hazard tag name (e.g. "灭火器被遮挡").

    Returns:
        The first level name whose dict contains ``tag_name``; "一般隐患"
        when the class has no rules or no level lists the tag.
    """
    fallback_level = "一般隐患"

    levels = rule_definitions.get(class_name, {})
    if levels:
        for level_name, tagged_rules in levels.items():
            # Only dict-valued levels can list tags.
            if isinstance(tagged_rules, dict) and tag_name in tagged_rules:
                return level_name

    return fallback_level
|
||
|
||
|
||
def map_original_level(level: int) -> str:
    """Translate a raw level value into a similarity label.

    Args:
        level: Raw level value.

    Returns:
        "高度相似" when ``level`` equals 2, otherwise "疑似".
    """
    if level == 2:
        return "高度相似"
    return "疑似"
|
||
|
||
|
||
def report_generator(
    video_path: str,
    detection_data: dict,
    hazard_results: dict,
    rule_definitions: dict,
    output_path: str = "output",
    frame_interval: float = 1.0
):
    """Generate a Markdown hazard report with one screenshot per hazard.

    Creates ``<output_path>/report/_assets``, saves the start frame of each
    hazard as a JPEG there, and writes ``<output_path>/report/report.md``
    containing a table of all hazards.

    Args:
        video_path: Path of the source video.
        detection_data: Object detection data; currently not used here.
        hazard_results: Inspection result with ``class_list``, ``tag``,
            ``base`` and ``objects``.
        rule_definitions: Rule definitions passed to
            ``get_hazard_level_from_rule``.
        output_path: Root output directory.
        frame_interval: Factor applied to each hazard's ``start_frame`` to
            obtain the frame index in the video.

    Raises:
        IOError: If the video cannot be opened.
    """

    def _cell(value) -> str:
        """Make free text safe for a single Markdown table cell."""
        return str(value).replace("|", "\\|").replace("\r\n", "<br>").replace("\n", "<br>")

    # ----------------------- 1. Preparation -----------------------
    report_dir = os.path.join(output_path, "report")
    assets_dir = os.path.join(report_dir, "_assets")
    os.makedirs(assets_dir, exist_ok=True)

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise IOError(f"无法打开视频文件: {video_path}")

    md_table_rows = []
    # try/finally guarantees the capture is released even if a row fails.
    try:
        # Total frame count, used to clamp frame indices.
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # ------------- 2. Screenshot and table row per hazard -------------
        for idx, obj in enumerate(hazard_results["objects"], start=1):
            # Scale to the video's frame index, then clamp into range.
            start_frame = int(obj["start_frame"] * frame_interval)
            start_frame = max(0, min(start_frame, total_frames - 1))

            class_name = hazard_results["class_list"][obj["class_id"]]
            tag_name = hazard_results["tag"][obj["tag_id"]]
            original_level = obj["level"]
            location = obj.get("location", "")

            rule_level = get_hazard_level_from_rule(rule_definitions, class_name, tag_name)
            similarity = map_original_level(original_level)

            # 2.1 Save the start frame as the screenshot.
            target_frame = start_frame
            img_filename = f"{idx}_{target_frame}.jpg"
            img_path = os.path.join(assets_dir, img_filename)

            cap.set(cv2.CAP_PROP_POS_FRAMES, target_frame)
            ret, frame = cap.read()
            saved = bool(ret) and bool(cv2.imwrite(img_path, frame))

            # 2.2 Build the table row. The image is referenced relative to
            # report.md and only when it was actually written.
            img_markdown = f"![{img_filename}](_assets/{img_filename})" if saved else ""
            base_text = hazard_results["base"][obj["base_id"]]

            row = (
                f"|{idx}|{_cell(rule_level)}|{similarity}|{_cell(base_text)}"
                f"|{start_frame}|{_cell(location)}|{img_markdown}|"
            )
            md_table_rows.append(row)
    finally:
        cap.release()

    # ----------------------- 3. Write the Markdown report -----------------------
    md_content = """# 隐患报告

## 隐患清单

| 隐患序号 | 隐患等级 | 相似度 | 依据 | 开始帧 | 位置描述 | 照片 |
| :---: | :---: | :---: | :--- | :---: | :--- | :---: |
"""
    md_content += "\n".join(md_table_rows)

    report_path = os.path.join(report_dir, "report.md")
    with open(report_path, "w", encoding="utf-8") as f:
        f.write(md_content)

    print(f"报告已生成: {report_path}")