749 lines
25 KiB
Python
749 lines
25 KiB
Python
import json
|
||
import os
|
||
from pathlib import Path
|
||
import requests
|
||
from openai import BadRequestError, OpenAI
|
||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||
import requests
|
||
import cv2
|
||
|
||
# Image registry: keys are dropdown option names, values hold the network and
# local address of each image.
# NOTE: use valid image URLs, or download the images and use local paths.
file_dict: dict = {}
# file_info_path: str = "file_dict.json"
file_info_path: str = r"D:\Userfile\Downloads\export_urls.csv"

# API credentials and model selection (replace with your own API keys).
# setdefault() keeps any value already exported in the environment, so a real
# key supplied from outside is no longer silently overwritten by these fallbacks.
# SECURITY NOTE(review): API keys are committed here in plain text; they should
# be rotated and provided only via the environment / a secrets store.
os.environ.setdefault("DASHSCOPE_API_KEY", 'sk-c7ffed1b32794284a5d21b046d7d0008')
os.environ.setdefault("MAXKB_API_KEY", 'agent-300b841bd6a7cb3a0bf603c093c22398')
os.environ.setdefault("MODEL_NAME", 'qwen3.5-27b')
|
||
|
||
def chat(prompt, file_path: str | list, enable_thinking=False, fps=10, thinking_budget: int=81920):
|
||
"""
|
||
调用 DashScope 接口进行图片理解与提示词检测。
|
||
参数 img_info 为包含网络和本地路径的字典。
|
||
"""
|
||
print("开始处理视频...")
|
||
print("当前文件: ", file_path)
|
||
print("============提示词============")
|
||
print(prompt)
|
||
|
||
enable_thinking = enable_thinking
|
||
client = OpenAI(
|
||
api_key=os.getenv("DASHSCOPE_API_KEY"),
|
||
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||
)
|
||
messages = []
|
||
|
||
if type(file_path) == str:
|
||
if '.jpg' in file_path.lower():
|
||
file_path = [["img", file_path]]
|
||
elif '.mp4' or '.avi' in file_path.lower():
|
||
file_path = [["vid", file_path, fps]]
|
||
|
||
|
||
content: list = []
|
||
for item in file_path:
|
||
if item[0] == "img":
|
||
content.append({"type": "image_url", "image_url": {"url": item[1]}})
|
||
elif item[0] == "vid":
|
||
content.append({"type": "video_url", "video_url": {"url": item[1]}, "fps": item[2]})
|
||
else:
|
||
raise ValueError(f"不支持的文件类型:{item[0]}", 99999)
|
||
|
||
content.append({"type": "text", "text": prompt})
|
||
|
||
messages = [{
|
||
"role": "user",
|
||
"content": content
|
||
}]
|
||
|
||
# 开启流式输出
|
||
try:
|
||
completion = client.chat.completions.create(
|
||
model=os.getenv("MODEL_NAME"), # type: ignore
|
||
messages=messages, # type: ignore
|
||
extra_body={"enable_thinking": enable_thinking, "thinking_budget": thinking_budget},
|
||
stream=True,
|
||
# 支持访问图片和视频的OSS协议 URL
|
||
extra_headers={"X-DashScope-OssResourceResolve": "enable"}
|
||
) # type: ignore
|
||
except BadRequestError as e:
|
||
# 这里捕捉到特定的 400 错误
|
||
# 判断是否是图片相关的错误
|
||
if "Failed to download multimodal content" in str(e):
|
||
# 返回空的模型回复,并返回错误信息
|
||
raise ValueError("❌ 图片链接不存在或已过期")
|
||
# return "", "❌ 图片链接不存在或已过期"
|
||
else:
|
||
# 其他错误直接抛出或返回通用错误
|
||
raise ValueError(f"❌ 发生错误: {str(e)}")
|
||
except Exception as e:
|
||
# 捕捉其他未知错误
|
||
raise ValueError(f"❌ 未知错误: {str(e)}")
|
||
|
||
# 捕获流式数据
|
||
reasoning_text = ""
|
||
answer_text = ""
|
||
is_answering = False
|
||
|
||
if enable_thinking:
|
||
print("============思考过程============")
|
||
for chunk in completion:
|
||
delta = chunk.choices[0].delta
|
||
# 思考过程
|
||
if hasattr(delta, "reasoning_content") and delta.reasoning_content:
|
||
reasoning_text += delta.reasoning_content
|
||
print(delta.reasoning_content, end="")
|
||
# 最终答案
|
||
if hasattr(delta, "content") and delta.content:
|
||
if not is_answering:
|
||
is_answering = True
|
||
print("============最终答案============")
|
||
answer_text += delta.content
|
||
print(delta.content, end="")
|
||
print() # 换行
|
||
|
||
# 返回格式化后的结果
|
||
return reasoning_text, answer_text
|
||
|
||
def upload_file_and_get_url(file_path: str) -> str:
    """Upload one local file to DashScope temporary OSS storage and return its ``oss://`` URL.

    First requests an upload policy from the DashScope uploads API (authenticated
    with ``DASHSCOPE_API_KEY``, scoped to ``MODEL_NAME``), then POSTs the file as
    multipart form data to the policy's ``upload_host``.

    Args:
        file_path: Local path of the file to upload.

    Returns:
        str: ``oss://<upload_dir>/<file name>``.

    Raises:
        Exception: If the policy request is not HTTP 200, the policy JSON has no
            ``data`` field, or the upload is not HTTP 200.
    """
    # Step 1: fetch the upload policy.
    policy_resp = requests.get(
        "https://dashscope.aliyuncs.com/api/v1/uploads",
        headers={"Authorization": f"Bearer {os.getenv('DASHSCOPE_API_KEY')}"},
        params={"action": "getPolicy", "model": os.getenv('MODEL_NAME')},
    )
    if policy_resp.status_code != 200:
        raise Exception(f"获取上传策略失败,状态码: {policy_resp.status_code}, 响应: {policy_resp.text}")

    payload = policy_resp.json()
    if 'data' not in payload:
        raise Exception(f"获取上传策略失败,返回内容不包含 'data' 字段: {payload}")
    policy = payload['data']

    # Step 2: multipart upload. `(None, value)` tuples are plain form fields;
    # the file part is added last.
    upload_name = Path(file_path).name
    object_key = f"{policy['upload_dir']}/{upload_name}"

    with open(file_path, 'rb') as handle:
        form = {
            field: (None, value)
            for field, value in (
                ('OSSAccessKeyId', policy['oss_access_key_id']),
                ('Signature', policy['signature']),
                ('policy', policy['policy']),
                ('x-oss-object-acl', policy['x_oss_object_acl']),
                ('x-oss-forbid-overwrite', policy['x_oss_forbid_overwrite']),
                ('key', object_key),
                ('success_action_status', '200'),
            )
        }
        form['file'] = (upload_name, handle)
        upload_resp = requests.post(policy['upload_host'], files=form)
        if upload_resp.status_code != 200:
            raise Exception(f"文件上传失败: {file_path}, 状态码: {upload_resp.status_code}, 响应: {upload_resp.text}")

    return f"oss://{object_key}"
|
||
|
||
def upload_files_and_get_urls_concurrently(
|
||
file_path_list: list[str],
|
||
max_workers: int | None = None,
|
||
timeout: int = 300
|
||
) -> list[str | None]:
|
||
"""
|
||
并发上传多个文件到 OSS 并返回对应的 URL 列表(顺序对应)
|
||
"""
|
||
# 初始化一个与文件列表长度相同的结果列表
|
||
results: list[str | None] = [None] * len(file_path_list)
|
||
|
||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||
# 使用字典存储 future 与其对应的索引
|
||
future_to_index = {
|
||
executor.submit(
|
||
upload_file_and_get_url,
|
||
file_path
|
||
): idx
|
||
for idx, file_path in enumerate(file_path_list)
|
||
}
|
||
|
||
for future in as_completed(future_to_index, timeout=timeout):
|
||
idx = future_to_index[future]
|
||
file_path = file_path_list[idx]
|
||
try:
|
||
url = future.result()
|
||
results[idx] = url # 将结果放回对应的索引位置
|
||
print(f"[成功] {file_path} -> {url}")
|
||
except Exception as exc:
|
||
# 捕获异常并打印错误信息
|
||
print(f"[错误] {file_path} 上传失败: {exc}")
|
||
results[idx] = None # 保持为 None 或者设为错误标识
|
||
|
||
return results
|
||
|
||
def save_json_to_file(json_data, file_path: str) -> None:
    """Write ``json_data`` to ``file_path`` as UTF-8 JSON (4-space indent, non-ASCII kept).

    Missing parent directories are created first. Any exception raised while
    opening or writing the file is printed and swallowed (best-effort). The
    confirmation message is printed only when the write succeeds.

    Args:
        json_data: JSON-serialisable object to save.
        file_path: Destination file path.
    """
    parent = os.path.dirname(file_path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    try:
        with open(file_path, 'w', encoding='utf-8') as out:
            json.dump(json_data, out, ensure_ascii=False, indent=4)
    except Exception as err:
        print(err)
    else:
        print(f"已保存 JSON 数据到: {file_path}")
|
||
|
||
def load_json_data(json_path: str):
    """Parse and return the contents of the UTF-8 encoded JSON file at ``json_path``."""
    with open(json_path, mode='r', encoding='utf-8') as handle:
        parsed = json.load(handle)
    return parsed
|
||
|
||
def get_unique_track_id_count(json_path: str) -> int:
    """Count the distinct ``track_id`` values in a per-frame detection JSON file.

    The file is loaded with ``load_json_data``. It is expected to be a mapping of
    frame -> list of detection dicts. Detections without a ``track_id`` key are ignored.

    Args:
        json_path: Path to the JSON file.

    Returns:
        int: Number of distinct track ids.
    """
    frames = load_json_data(json_path)
    track_ids = {
        detection['track_id']
        for detections in frames.values()
        for detection in detections
        if 'track_id' in detection
    }
    return len(track_ids)
|
||
|
||
def get_annnotated_frame_for_ai_without_xyxy(json_data: dict, interval: int = 5, conf: float = 0.7) -> dict:
    """Summarise per-frame detections into per-track frame ranges for the AI prompt.

    Every ``interval``-th frame (in the dict's key order: 0, interval, 2*interval, ...)
    is kept, and the kept frames are renumbered 0, 1, 2, ... Bounding boxes are
    not included in the output.

    Args:
        json_data: Mapping of frame id -> list of detection dicts. Each detection
            is read for ``class_str``, ``confidence`` and ``track_id``.
        interval: Keep every ``interval``-th frame.
        conf: Detections with ``confidence`` below this are ignored when
            building track ranges.

    Returns:
        dict of the form::

            {
                "class_list": ["class_name1", "class_name2", ...],  # sorted, unique
                "track_id_list": {
                    <track_id>: {"class_id": 0, "start_frame": 0, "end_frame": 10},
                    ...
                },
            }

        ``class_list`` contains every class seen in the kept frames, including
        classes whose detections were all below ``conf``. ``class_id`` indexes
        ``class_list`` and comes from the track's first qualifying detection.
        ``start_frame``/``end_frame`` are renumbered frame indices.
    """
    kept_frames = list(json_data)[::interval]

    class_list = sorted({det["class_str"] for fid in kept_frames for det in json_data[fid]})
    class_index = {name: pos for pos, name in enumerate(class_list)}

    tracks: dict = {}
    for position, fid in enumerate(kept_frames):
        for det in json_data[fid]:
            if det["confidence"] < conf:
                continue
            tid = det["track_id"]
            cid = class_index[det["class_str"]]
            if tid in tracks:
                tracks[tid]["end_frame"] = position
            else:
                tracks[tid] = {"class_id": cid, "start_frame": position, "end_frame": position}

    return {"class_list": class_list, "track_id_list": tracks}
|
||
|
||
def search_knowledge_base(
    question: str,
    system_prompt: str = "你是科技有限公司 MaxKB 知识问答系统的智能小,你的工作是 MaxKB 用户解答中,用户找你回答问题时,你要把主题放在 MaxKB 知识问答系统身上",
    **kwargs
):
    """Ask the local MaxKB application a question through its chat-completions endpoint.

    Args:
        question: The user question, sent as the ``user`` message.
        system_prompt: Sent as a ``system`` message (omitted when empty).
            NOTE(review): the default text looks truncated ("智能小" is presumably
            "智能小助手"); confirm the intended wording.
        **kwargs: Accepted for call-site compatibility; currently ignored.

    Returns:
        str: The stripped answer from ``choices[0].message.content``. Returns ``""``
        when the request fails, the status is not 200, the body cannot be parsed,
        or it has no ``choices``.
    """
    url = "http://localhost:8080/chat/api/019d5265-e90e-7663-b202-fc5a47d8601c/chat/completions"

    headers = {
        "Content-Type": "application/json",
        # Key comes from the environment (MAXKB_API_KEY) instead of a second
        # hardcoded copy of the secret.
        "Authorization": f"Bearer {os.getenv('MAXKB_API_KEY', '')}",
    }

    # The system prompt used to be placed in the "role" field, which is not a
    # valid role; send it as a proper system message followed by the question.
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": question})
    payload = {"model": "qwen3.5-27b", "messages": messages}

    try:
        response = requests.post(url, headers=headers, json=payload)
    except requests.RequestException as e:
        # Match the best-effort contract of the other failure paths.
        print("请求异常:", e)
        return ""

    if response.status_code != 200:
        print(f"请求失败,状态码: {response.status_code}")
        print("错误信息:", response.text)
        return ""

    try:
        data = response.json()
        # Full response structure, for debugging.
        print("Full Response:")
        print(json.dumps(data, ensure_ascii=False, indent=2))

        # Expected OpenAI-style shape: data['choices'][0]['message']['content'].
        if 'choices' in data and data['choices']:
            message = data['choices'][0].get('message') or {}
            answer = (message.get('content') or '').strip()
            print("\n=== MaxKB 回答 ===")
            print(answer)
            return answer

        print("\n未在响应中找到 'choices' 字段。")
        return ""
    except Exception as e:
        print("解析响应时出错:", e)
        print("原始响应:", response.text)
        return ""
|
||
|
||
def _find_class_check_rule(all_check_rule_dict: dict, class_str: str):
    """Return the rule for ``class_str`` from the first scene (dict order) that defines it, else ""."""
    for scene_rules in all_check_rule_dict.values():
        if class_str in scene_rules:
            return scene_rules[class_str]
    return ""


def _strip_code_fence(text: str) -> str:
    """Strip a leading ```json / ``` fence and a trailing ``` fence (and whitespace) from a reply."""
    stripped = text.strip()
    if stripped.startswith("```json"):
        stripped = stripped[len("```json"):]
    if stripped.startswith("```"):
        stripped = stripped[3:]
    if stripped.endswith("```"):
        stripped = stripped[:-3]
    return stripped.strip()


def _remap_local_id(local_id, local_items, global_position: dict):
    """Map an index into one reply's list to an index into the merged list; None if not mappable."""
    if local_items is None or not isinstance(local_id, int):
        return None
    if not 0 <= local_id < len(local_items):
        return None
    return global_position.get(local_items[local_id])


def hazard_inspection(
    class_dict: dict,
    video_url: str,
    enable_thinking: bool = True,
    fps: int = 10
) -> tuple[str, str]:
    """Check a video for hazards on each detected object class, one model call per class.

    2026.04.07

    Per-class flow:
    1. Look up the class's check rule in ``知识库/rule.json``. Skip the class if it has no rule.
    2. Collect the class's tracks. Skip the class if it has none.
    3. Build a prompt from the track ranges, the rule and the template
       ``prompt/7_检测提示词_视频_new3.md``, and send it with ``video_url`` to ``chat``.

    The per-class JSON replies are then merged. ``tag`` and ``base`` lists are
    de-duplicated (first occurrence wins). Each object's ``tag_id``/``base_id``,
    which index its own reply's lists, is re-pointed into the merged lists.

    Args:
        class_dict: Same structure as returned by
            ``get_annnotated_frame_for_ai_without_xyxy``::

                {
                    "class_list": ["class_name1", "class_name2", ...],
                    "track_id_list": {
                        "0": {"class_id": 0, "start_frame": 0, "end_frame": 10},
                        ...
                    },
                }

        video_url: Video URL passed to ``chat`` (e.g. an ``oss://`` URL).
        enable_thinking: Forwarded to ``chat``.
        fps: Frame sampling rate forwarded to ``chat``.

    Returns:
        tuple[str, str]: ``(reasoning_text, json_result)``. ``json_result`` is a JSON
        string ``{"class": [...], "tag": [...], "base": [...], "objects": [...]}``,
        where ``"class"`` is ``class_dict["class_list"]`` (the list that objects'
        ``class_id`` values index into).
    """
    # os.path.join keeps these relative paths working outside Windows.
    all_check_rule_dict: dict = load_json_data(os.path.join("知识库", "rule.json"))
    with open(os.path.join("prompt", "7_检测提示词_视频_new3.md"), "r", encoding="utf-8") as file:
        prompt_str: str = file.read()

    class_list = class_dict["class_list"]
    all_reasoning_text = ""
    all_tags: list = []
    all_bases: list = []
    # Each object's tag_id/base_id index the tag/base lists of the reply it came
    # from, so remember that reply's class instead of trusting the model's class_id.
    collected_objects: list[tuple[str, dict]] = []
    class_tag_mapping: dict = {}
    class_base_mapping: dict = {}

    print("类别列表:", class_list)
    for class_str in class_list:
        print(f"当前类别: {class_str}")
        print(f"场景列表: {list(all_check_rule_dict.keys())}")
        class_check_rule = _find_class_check_rule(all_check_rule_dict, class_str)
        if class_check_rule == "":
            continue  # no check rule for this class

        class_track_info = [
            f"Track ID {track_id}: {class_str} (帧 {track_data['start_frame']}-{track_data['end_frame']})"
            for track_id, track_data in class_dict["track_id_list"].items()
            if class_list[track_data["class_id"]] == class_str
        ]
        if not class_track_info:
            continue  # no tracked objects of this class
        track_data_str = "\n".join(class_track_info)

        prompt = (
            f"\n# 类别: {class_str}\n\n"
            f"# 跟踪对象信息\n{track_data_str}\n\n"
            f"# 检查规则\n{class_check_rule}\n\n"
            f"{prompt_str}\n"
        )

        reasoning_text, answer_text = chat(prompt, video_url, enable_thinking, fps)
        all_reasoning_text += f"\n\n# 类别: {class_str}\n{reasoning_text}"

        try:
            class_result = json.loads(_strip_code_fence(answer_text))
        except json.JSONDecodeError as e:
            print(f"解析类别 {class_str} 的 JSON 结果失败: {e}")
            print(f"原始输出: {answer_text}")
            continue
        if not isinstance(class_result, dict):
            print(f"类别 {class_str} 的 JSON 结果不是对象, 已跳过: {answer_text}")
            continue

        if "tag" in class_result:
            class_tag_mapping[class_str] = class_result["tag"]
            all_tags.extend(class_result["tag"])
        if "base" in class_result:
            class_base_mapping[class_str] = class_result["base"]
            all_bases.extend(class_result["base"])
        if "objects" in class_result:
            collected_objects.extend((class_str, obj) for obj in class_result["objects"])

    # Merge and de-duplicate, keeping first-seen order.
    unique_tags = list(dict.fromkeys(all_tags))
    unique_bases = list(dict.fromkeys(all_bases))
    tag_position = {tag: i for i, tag in enumerate(unique_tags)}
    base_position = {base: i for i, base in enumerate(unique_bases)}

    final_objects = []
    for source_class, obj in collected_objects:
        new_obj = obj.copy()
        new_obj["hazard_track_id"] = len(final_objects)
        # Objects without class_id get the index of the class whose reply they came from.
        new_obj.setdefault("class_id", class_list.index(source_class))

        new_tag_id = _remap_local_id(obj.get("tag_id"), class_tag_mapping.get(source_class), tag_position)
        if new_tag_id is not None:
            new_obj["tag_id"] = new_tag_id
        new_base_id = _remap_local_id(obj.get("base_id"), class_base_mapping.get(source_class), base_position)
        if new_base_id is not None:
            new_obj["base_id"] = new_base_id

        final_objects.append(new_obj)

    final_result = {
        "class": class_list,
        "tag": unique_tags,
        "base": unique_bases,
        "objects": final_objects,
    }
    json_result = json.dumps(final_result, ensure_ascii=False, indent=2)
    return all_reasoning_text, json_result
|
||
|
||
def merge_conflict_inspection_data(data: dict) -> dict:
    """Merge overlapping hazard intervals that share the same ``(tag_id, class_id)``.

    Objects are grouped by ``(tag_id, class_id)``; groups are emitted in order of
    first appearance. Within a group, objects are sorted by ``start_frame``. An
    object whose ``start_frame`` is <= the running ``end_frame`` is merged into
    the current interval, which takes the maximum ``end_frame`` and maximum
    ``level`` of its members. Any other keys (e.g. ``base_id``, ``location``) are
    copied from the interval's first member. Merged objects get fresh
    ``hazard_track_id`` values starting at 1.

    Args:
        data: Dict with an ``"objects"`` list. Each object must have ``tag_id``,
            ``class_id``, ``level``, ``start_frame`` and ``end_frame``. Other
            top-level keys (e.g. ``"class"``, ``"tag"``) are left untouched.

    Returns:
        The same ``data`` dict, with ``data["objects"]`` replaced by the merged list.
    """
    groups: dict[tuple, list[dict]] = {}
    for obj in data["objects"]:
        groups.setdefault((obj["tag_id"], obj["class_id"]), []).append(obj)

    merged_objects: list[dict] = []

    def flush(first: dict, start, end, level) -> None:
        """Append one merged interval, carrying over the extra keys of its first member."""
        merged = {
            "hazard_track_id": len(merged_objects) + 1,
            "tag_id": first["tag_id"],
            "class_id": first["class_id"],
            "level": level,
            "start_frame": start,
            "end_frame": end,
        }
        for key, value in first.items():
            merged.setdefault(key, value)
        merged_objects.append(merged)

    for objs in groups.values():
        ordered = sorted(objs, key=lambda o: o["start_frame"])
        head = ordered[0]
        cur_start, cur_end, cur_level = head["start_frame"], head["end_frame"], head["level"]
        for obj in ordered[1:]:
            if obj["start_frame"] <= cur_end:
                # Overlapping (or sharing a frame): extend and keep the higher level.
                cur_end = max(cur_end, obj["end_frame"])
                cur_level = max(cur_level, obj["level"])
            else:
                flush(head, cur_start, cur_end, cur_level)
                # Start a fresh interval; the level must reset too.
                head = obj
                cur_start, cur_end, cur_level = obj["start_frame"], obj["end_frame"], obj["level"]
        # Emit the last open interval (previously dropped, losing single-object groups).
        flush(head, cur_start, cur_end, cur_level)

    data["objects"] = merged_objects
    return data
|
||
|
||
def get_hazard_level_from_rule(rule_definitions: dict, class_name: str, tag_name: str) -> str:
    """Look up the hazard level whose rule table for ``class_name`` contains ``tag_name``.

    Args:
        rule_definitions (dict): Mapping ``class name -> {level name -> {tag -> ...}}``.
        class_name (str): Hazard class name (e.g. "灭火器").
        tag_name (str): Hazard tag name (e.g. "灭火器被遮挡").

    Returns:
        str: The first level name (in dict order) whose dict value contains
        ``tag_name`` as a key. Returns "一般隐患" when the class is missing or
        empty, or no level matches. Non-dict level values are ignored.
    """
    fallback = "一般隐患"
    levels = rule_definitions.get(class_name, {})
    if not levels:
        return fallback
    matching_levels = (
        name
        for name, tags in levels.items()
        if isinstance(tags, dict) and tag_name in tags
    )
    return next(matching_levels, fallback)
|
||
|
||
|
||
def map_original_level(level: int) -> str:
    """Translate the model's raw hazard level into a similarity label.

    Args:
        level (int): Raw level (2 = obvious; anything else is treated as general).

    Returns:
        str: "高度相似" when ``level == 2``, otherwise "疑似".
    """
    if level == 2:
        return "高度相似"
    return "疑似"
|
||
|
||
|
||
def _md_cell(value) -> str:
    """Render ``value`` as Markdown table-cell text: escape ``|`` and turn line breaks into ``<br>``."""
    text = str(value).replace("|", "\\|")
    return text.replace("\r\n", "<br>").replace("\n", "<br>").replace("\r", "<br>")


def report_generator(
    video_path: str,
    detection_data: dict,
    hazard_results: dict,
    rule_definitions: dict,
    output_path: str = "output",
    frame_interval: float = 1.0
):
    """Generate a Markdown hazard report with one screenshot per hazard.

    Creates ``<output_path>/report/_assets``. For each object in
    ``hazard_results["objects"]`` it saves the frame at the hazard's start as
    ``<idx>_<frame>.jpg`` and adds a table row. It then writes
    ``<output_path>/report/report.md``.

    Args:
        video_path (str): Video file to take screenshots from.
        detection_data (dict): Per-frame detection data; currently unused.
        hazard_results (dict): Hazard check results with ``"class"``, ``"tag"``,
            ``"base"`` and ``"objects"``. Each object's ``class_id``/``tag_id``/
            ``base_id`` index those lists; ``level``, ``start_frame`` and an
            optional ``location`` are also read.
        rule_definitions (dict): Rule table passed to ``get_hazard_level_from_rule``.
        output_path (str): Output folder, default ``"output"``.
        frame_interval (float): Factor converting an object's ``start_frame`` into
            a frame index of ``video_path`` (``int(start_frame * frame_interval)``).
            Default 1.0.

    Raises:
        IOError: If the video cannot be opened.
    """
    # 1. Folder structure and video.
    report_dir = os.path.join(output_path, "report")
    assets_dir = os.path.join(report_dir, "_assets")
    os.makedirs(assets_dir, exist_ok=True)

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise IOError(f"无法打开视频文件: {video_path}")

    # 2. Screenshot each hazard and build the table rows; always release the capture.
    md_table_rows = []
    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        for idx, obj in enumerate(hazard_results["objects"], start=1):
            # Convert to a video frame index and clamp it into range.
            start_frame = int(obj["start_frame"] * frame_interval)
            start_frame = max(0, min(start_frame, total_frames - 1))

            class_name = hazard_results["class"][obj["class_id"]]
            tag_name = hazard_results["tag"][obj["tag_id"]]
            rule_level = get_hazard_level_from_rule(rule_definitions, class_name, tag_name)
            similarity = map_original_level(obj["level"])
            location = obj.get("location", "")
            base_text = hazard_results["base"][obj["base_id"]]  # the object's single basis entry

            img_filename = f"{idx}_{start_frame}.jpg"
            img_path = os.path.join(assets_dir, img_filename)
            img_markdown = ""
            cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
            ret, frame = cap.read()
            if ret:
                # cv2.imwrite may fail on non-ASCII paths on Windows; encode in
                # memory and write the bytes directly instead.
                ok, encoded = cv2.imencode(".jpg", frame)
                if ok:
                    encoded.tofile(img_path)
                    # Path is relative to report.md, which lives in report_dir.
                    img_markdown = f"![{img_filename}](_assets/{img_filename})"

            md_table_rows.append(
                f"|{idx}|{_md_cell(rule_level)}|{_md_cell(similarity)}|{_md_cell(base_text)}"
                f"|{start_frame}|{_md_cell(location)}|{img_markdown}|"
            )
    finally:
        cap.release()

    # 3. Write the Markdown report.
    md_content = (
        "# 隐患报告\n\n"
        "## 隐患清单\n\n"
        "| 隐患序号 | 隐患等级 | 相似度 | 依据 | 开始帧 | 位置描述 | 照片 |\n"
        "| :---: | :---: | :---: | :--- | :---: | :--- | :---: |\n"
    )
    md_content += "\n".join(md_table_rows)

    report_path = os.path.join(report_dir, "report.md")
    with open(report_path, "w", encoding="utf-8") as f:
        f.write(md_content)

    print(f"报告已生成: {report_path}")