HazardInspector/lib/qwen_fun.py

760 lines
26 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import os
from pathlib import Path
import requests
from openai import BadRequestError, OpenAI
from concurrent.futures import ThreadPoolExecutor, as_completed
import requests
import cv2
# Image registry: keys are the drop-down option names, values are dicts that
# hold a remote URL and a local path for the image.
# Use valid image URLs, or download the images and use their local paths.
file_dict: dict = {}
# file_info_path: str = "file_dict.json"
file_info_path: str = r"D:\Userfile\Downloads\export_urls.csv"

# Default configuration, exported through environment variables because the
# functions in this module read it back with os.getenv().
# setdefault() is used so that values already provided by the real
# environment (shell, service manager, CI secret store) take precedence over
# the defaults baked into this file.
# NOTE(review): credentials are committed in source here. They should be
# rotated and supplied only through the environment — confirm with the owner
# before removing the fallbacks, since callers currently rely on them.
os.environ.setdefault("DASHSCOPE_API_KEY", 'sk-c7ffed1b32794284a5d21b046d7d0008')
os.environ.setdefault("MAXKB_API_KEY", 'agent-300b841bd6a7cb3a0bf603c093c22398')
os.environ.setdefault("MODEL_NAME", 'qwen3.5-27b')
def chat(prompt, file_path: str | list, enable_thinking=False, fps=10, thinking_budget: int=81920):
"""
调用 DashScope 接口进行图片理解与提示词检测。
参数 img_info 为包含网络和本地路径的字典。
"""
print("开始处理视频...")
print("当前文件: ", file_path)
print("============提示词============")
print(prompt)
enable_thinking = enable_thinking
client = OpenAI(
api_key=os.getenv("DASHSCOPE_API_KEY"),
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
)
messages = []
if type(file_path) == str:
if '.jpg' in file_path.lower():
file_path = [["img", file_path]]
elif '.mp4' or '.avi' in file_path.lower():
file_path = [["vid", file_path, fps]]
content: list = []
for item in file_path:
if item[0] == "img":
content.append({"type": "image_url", "image_url": {"url": item[1]}})
elif item[0] == "vid":
content.append({"type": "video_url", "video_url": {"url": item[1]}, "fps": item[2]})
else:
raise ValueError(f"不支持的文件类型:{item[0]}", 99999)
content.append({"type": "text", "text": prompt})
messages = [{
"role": "user",
"content": content
}]
# 开启流式输出
try:
completion = client.chat.completions.create(
model=os.getenv("MODEL_NAME"), # type: ignore
messages=messages, # type: ignore
extra_body={"enable_thinking": enable_thinking, "thinking_budget": thinking_budget},
stream=True,
# 支持访问图片和视频的OSS协议 URL
extra_headers={"X-DashScope-OssResourceResolve": "enable"}
) # type: ignore
except BadRequestError as e:
# 这里捕捉到特定的 400 错误
# 判断是否是图片相关的错误
if "Failed to download multimodal content" in str(e):
# 返回空的模型回复,并返回错误信息
raise ValueError("❌ 图片链接不存在或已过期")
# return "", "❌ 图片链接不存在或已过期"
else:
# 其他错误直接抛出或返回通用错误
raise ValueError(f"❌ 发生错误: {str(e)}")
except Exception as e:
# 捕捉其他未知错误
raise ValueError(f"❌ 未知错误: {str(e)}")
# 捕获流式数据
reasoning_text = ""
answer_text = ""
is_answering = False
if enable_thinking:
print("============思考过程============")
for chunk in completion:
delta = chunk.choices[0].delta
# 思考过程
if hasattr(delta, "reasoning_content") and delta.reasoning_content:
reasoning_text += delta.reasoning_content
print(delta.reasoning_content, end="")
# 最终答案
if hasattr(delta, "content") and delta.content:
if not is_answering:
is_answering = True
print("============最终答案============")
answer_text += delta.content
print(delta.content, end="")
print() # 换行
# 返回格式化后的结果
return reasoning_text, answer_text
def upload_file_and_get_url(file_path: str, timeout: float = 60) -> str:
    """
    Upload one local file through the DashScope upload policy and return its
    ``oss://`` URL (internal helper).

    Args:
        file_path: Path of the local file to upload.
        timeout: Value passed as ``timeout`` to both HTTP requests, so a
            stalled connection cannot block the caller (or a worker thread)
            forever.

    Returns:
        ``"oss://<upload_dir>/<file name>"``.

    Raises:
        RuntimeError: if the policy request or the upload does not return
            HTTP 200, or if the policy response has no ``data`` field.
    """
    # 1. Fetch the upload policy.
    policy_resp = requests.get(
        "https://dashscope.aliyuncs.com/api/v1/uploads",
        headers={"Authorization": f"Bearer {os.getenv('DASHSCOPE_API_KEY')}"},
        params={"action": "getPolicy", "model": os.getenv('MODEL_NAME')},
        timeout=timeout,
    )
    if policy_resp.status_code != 200:
        raise RuntimeError(f"获取上传策略失败,状态码: {policy_resp.status_code}, 响应: {policy_resp.text}")
    policy_json = policy_resp.json()
    if 'data' not in policy_json:
        raise RuntimeError(f"获取上传策略失败,返回内容不包含 'data' 字段: {policy_json}")
    policy_data = policy_json['data']

    # 2. Upload the file as a multipart form to the host named by the policy.
    file_name = Path(file_path).name
    key = f"{policy_data['upload_dir']}/{file_name}"
    with open(file_path, 'rb') as f:
        # (None, value) entries are sent as plain form fields; only 'file'
        # carries a file name and the open handle.
        files = {
            'OSSAccessKeyId': (None, policy_data['oss_access_key_id']),
            'Signature': (None, policy_data['signature']),
            'policy': (None, policy_data['policy']),
            'x-oss-object-acl': (None, policy_data['x_oss_object_acl']),
            'x-oss-forbid-overwrite': (None, policy_data['x_oss_forbid_overwrite']),
            'key': (None, key),
            'success_action_status': (None, '200'),
            'file': (file_name, f)
        }
        response = requests.post(policy_data['upload_host'], files=files, timeout=timeout)
    if response.status_code != 200:
        raise RuntimeError(f"文件上传失败: {file_path}, 状态码: {response.status_code}, 响应: {response.text}")
    return f"oss://{key}"
def upload_files_and_get_urls_concurrently(
file_path_list: list[str],
max_workers: int | None = None,
timeout: int = 300
) -> list[str | None]:
"""
并发上传多个文件到 OSS 并返回对应的 URL 列表(顺序对应)
"""
# 初始化一个与文件列表长度相同的结果列表
results: list[str | None] = [None] * len(file_path_list)
with ThreadPoolExecutor(max_workers=max_workers) as executor:
# 使用字典存储 future 与其对应的索引
future_to_index = {
executor.submit(
upload_file_and_get_url,
file_path
): idx
for idx, file_path in enumerate(file_path_list)
}
for future in as_completed(future_to_index, timeout=timeout):
idx = future_to_index[future]
file_path = file_path_list[idx]
try:
url = future.result()
results[idx] = url # 将结果放回对应的索引位置
print(f"[成功] {file_path} -> {url}")
except Exception as exc:
# 捕获异常并打印错误信息
print(f"[错误] {file_path} 上传失败: {exc}")
results[idx] = None # 保持为 None 或者设为错误标识
return results
def save_json_to_file(json_data, file_path: str) -> None:
    """
    Serialise ``json_data`` to ``file_path`` as UTF-8 JSON (4-space indent,
    non-ASCII characters kept as-is).

    Args:
        json_data: Object handed to ``json.dump``.
        file_path: Destination file; its parent directory is created if the
            path has one.

    Any exception raised while opening or writing the file is printed and
    swallowed; the success message is printed only when writing succeeded.
    """
    parent_dir = os.path.dirname(file_path)
    if parent_dir:
        # exist_ok=True: an already existing directory is not an error.
        os.makedirs(parent_dir, exist_ok=True)

    try:
        with open(file_path, 'w', encoding='utf-8') as out:
            json.dump(json_data, out, ensure_ascii=False, indent=4)
    except Exception as err:
        print(err)
    else:
        print(f"已保存 JSON 数据到: {file_path}")
def load_json_data(json_path: str):
    """Read the UTF-8 file at ``json_path`` and return its decoded JSON content."""
    with open(json_path, mode='r', encoding='utf-8') as handle:
        raw_text = handle.read()
    return json.loads(raw_text)
def get_unique_track_id_count(json_path: str) -> int:
    """
    Count the distinct ``track_id`` values found in a JSON file.

    The file is loaded with ``load_json_data``; every value of the top-level
    mapping is iterated as a sequence of detections, and detections without a
    ``track_id`` key are ignored.

    Args:
        json_path: Path of the JSON file.

    Returns:
        Number of distinct ``track_id`` values.
    """
    frames = load_json_data(json_path)
    distinct_ids = {
        detection['track_id']
        for detections in frames.values()
        for detection in detections
        if 'track_id' in detection
    }
    return len(distinct_ids)
def get_annnotated_frame_for_ai_without_xyxy(json_data: dict, interval: int = 5, conf: float = 0.7) -> dict:
    """
    Sub-sample per-frame detections and summarise them per track.

    Every ``interval``-th key of ``json_data`` (in key order) is kept and the
    kept frames are renumbered 0, 1, 2, ...

    Args:
        json_data: Mapping of frame id to a list of detection dicts; the code
            reads ``class_str``, ``confidence`` and ``track_id`` from each.
        interval: Step used when selecting frames.
        conf: Detections whose ``confidence`` is below this value are left out
            of the track summary (they still contribute to ``class_list``).

    Returns:
        {
            "class_list": [sorted distinct class names of the kept frames],
            "track_id_list": {
                track_id: {
                    "class_id": index into class_list (taken from the first
                                kept detection of the track),
                    "start_frame": first renumbered frame of the track,
                    "end_frame": last renumbered frame of the track
                },
                ...
            },
        }
    """
    kept_frames: list = list(json_data.keys())[::interval]

    # Distinct class names over all kept detections, regardless of confidence.
    class_list = sorted({
        detection["class_str"]
        for frame_key in kept_frames
        for detection in json_data[frame_key]
    })
    class_index = {name: position for position, name in enumerate(class_list)}

    tracks: dict = {}
    for renumbered, frame_key in enumerate(kept_frames):
        for detection in json_data[frame_key]:
            if detection["confidence"] < conf:
                continue
            track_key = detection["track_id"]
            class_id = class_index[detection["class_str"]]
            # First sighting creates the entry; later sightings only move the
            # end frame forward.
            entry = tracks.setdefault(
                track_key,
                {"class_id": class_id, "start_frame": renumbered, "end_frame": renumbered},
            )
            entry["end_frame"] = renumbered

    return {"class_list": class_list, "track_id_list": tracks}
def search_knowledge_base(
    question: str,
    system_prompt: str = "你是科技有限公司 MaxKB 知识问答系统的智能小,你的工作是 MaxKB 用户解答中,用户找你回答问题时,你要把主题放在 MaxKB 知识问答系统身上",
    **kwargs
):
    """
    Ask a question through the chat-completions endpoint of a local MaxKB
    application and return the answer text.

    Args:
        question: The user question.
        system_prompt: Sent as the ``system`` message ahead of the question.
        **kwargs: Accepted for call compatibility; currently unused.

    Returns:
        The stripped answer text, or ``""`` when the request fails, the
        response cannot be parsed, or it contains no ``choices``.
    """
    # 1. Request URL.
    url = "http://localhost:8080/chat/api/019d5265-e90e-7663-b202-fc5a47d8601c/chat/completions"
    # 2. Headers. The token comes from MAXKB_API_KEY, which this module
    #    exports at import time, instead of a second hard-coded copy.
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {os.getenv('MAXKB_API_KEY', '')}"
    }
    # 3. Payload. The original put the system prompt text into the "role"
    #    field; roles must be role names, so the prompt is now a proper
    #    system message and the question is the (last) user message.
    payload = {
        "model": os.getenv("MODEL_NAME", "qwen3.5-27b"),
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": question}
        ]
    }
    # 4. Send the request.
    response = requests.post(url, headers=headers, json=payload)
    # 5. Handle the response.
    if response.status_code != 200:
        print(f"请求失败,状态码: {response.status_code}")
        print("错误信息:", response.text)
        return ""
    try:
        data = response.json()
        # Full response dump, kept for debugging.
        print("Full Response:")
        print(json.dumps(data, ensure_ascii=False, indent=2))
        if 'choices' in data and data['choices']:
            # `or ''` covers a null content value so .strip() cannot fail on None.
            answer = data['choices'][0].get('message', {}).get('content', '') or ''
            print("\n=== MaxKB 回答 ===")
            print(answer.strip())
            return answer.strip()
        print("\n未在响应中找到 'choices' 字段。")
        return ""
    except Exception as e:
        # Best effort by design: any parsing problem yields an empty answer.
        print("解析响应时出错:", e)
        print("原始响应:", response.text)
        return ""
def _find_class_rule(all_rules: dict, class_str: str):
    """Return the rule of ``class_str`` from the first scene that defines it, else ``""``."""
    for scene_rules in all_rules.values():
        if isinstance(scene_rules, dict) and class_str in scene_rules:
            return scene_rules[class_str]
    return ""


def _strip_code_fence(text: str) -> str:
    """Remove a leading ```json / ``` marker and a trailing ``` marker from model output."""
    cleaned = text.strip()
    cleaned = cleaned.removeprefix("```json").removeprefix("```")
    cleaned = cleaned.removesuffix("```")
    return cleaned.strip()


def _remap_index(local_id, local_values: list, global_index: dict):
    """Translate an index into one reply's list to an index into the merged list.

    Returns ``None`` when ``local_id`` is not a valid non-negative index or the
    referenced value is not present in ``global_index``.
    """
    if isinstance(local_id, bool) or not isinstance(local_id, int):
        return None
    if not 0 <= local_id < len(local_values):
        return None
    return global_index.get(local_values[local_id])


def hazard_inspection(
    class_dict: dict,
    video_url: str,
    enable_thinking: bool = True,
    fps: int = 10
) -> tuple[str, str]:
    """
    Run one model query per object class on a video and merge the hazard
    findings of all classes into a single JSON document.

    For every class in ``class_dict["class_list"]`` that has a rule in the
    rule file and at least one track, a prompt is built from the class name,
    its tracks, its rule and the prompt template, and sent through ``chat``.
    Each reply is parsed as JSON; its ``tag`` / ``base`` lists are merged with
    duplicates removed, and the ``tag_id`` / ``base_id`` of its ``objects`` are
    rewritten to point into the merged lists.

    Args:
        class_dict: Summary of tracked objects:
            {
                "class_list": ["class_name1", "class_name2", ...],
                "track_id_list": {
                    track_id: {"class_id": 0, "start_frame": 0, "end_frame": 10},
                    ...
                },
            }
        video_url: Video path or URL handed to ``chat``.
        enable_thinking: Forwarded to ``chat``.
        fps: Forwarded to ``chat``.

    Returns:
        ``(reasoning_text, json_result)`` where ``json_result`` is a JSON string
        with the keys ``class`` (the input class list), ``tag``, ``base`` and
        ``objects``.
    """
    # Rule file and prompt template, resolved relative to the working
    # directory. os.path.join keeps the lookup working with either path
    # separator.
    all_check_rule_dict: dict = load_json_data(os.path.join("知识库", "rule.json"))
    with open(os.path.join("prompt", "7_检测提示词_视频_new3.md"), 'r', encoding='utf-8') as file:
        prompt_str: str = file.read()

    class_list = class_dict["class_list"]
    reasoning_sections: list = []
    all_tags: list = []
    all_bases: list = []
    # (source class name, object) pairs: an object's tag_id/base_id index into
    # the lists of the reply it came from, so the source class must be kept.
    collected_objects: list = []
    class_tag_mapping: dict = {}
    class_base_mapping: dict = {}

    # The merge conflict left in this function is resolved in favour of the
    # HEAD side, which looks the class up in every scene of the rule file
    # instead of only in the hard-coded "(一)厂房防火分区" scene.
    print("类别列表:", class_list)
    print(f"场景列表: {list(all_check_rule_dict.keys())}")
    for class_str in class_list:
        print(f"当前类别: {class_str}")
        class_check_rule = _find_class_rule(all_check_rule_dict, class_str)
        if class_check_rule == "":
            # No rule for this class.
            continue

        # Tracks that belong to this class.
        class_track_info = [
            f"Track ID {track_id}: {class_str} (帧 {track_data['start_frame']}-{track_data['end_frame']})"
            for track_id, track_data in class_dict["track_id_list"].items()
            if class_list[track_data["class_id"]] == class_str
        ]
        if not class_track_info:
            continue
        track_data_str = "\n".join(class_track_info)

        prompt: str = (
            f"\n# 类别: {class_str}\n"
            f"# 跟踪对象信息\n{track_data_str}\n"
            f"# 检查规则\n{class_check_rule}\n"
            f"{prompt_str}\n"
        )
        reasoning_text, answer_text = chat(prompt, video_url, enable_thinking, fps)
        reasoning_sections.append(f"\n\n# 类别: {class_str}\n{reasoning_text}")

        try:
            class_result = json.loads(_strip_code_fence(answer_text))
        except json.JSONDecodeError as e:
            print(f"解析类别 {class_str} 的 JSON 结果失败: {e}")
            print(f"原始输出: {answer_text}")
            continue
        if not isinstance(class_result, dict):
            # Model output is untrusted; anything but an object is unusable.
            print(f"解析类别 {class_str} 的 JSON 结果失败: 顶层不是对象")
            print(f"原始输出: {answer_text}")
            continue

        tags = class_result.get("tag")
        if isinstance(tags, list):
            class_tag_mapping[class_str] = tags
            all_tags.extend(tags)
        bases = class_result.get("base")
        if isinstance(bases, list):
            class_base_mapping[class_str] = bases
            all_bases.extend(bases)
        objects = class_result.get("objects")
        if isinstance(objects, list):
            collected_objects.extend(
                (class_str, obj) for obj in objects if isinstance(obj, dict)
            )

    # Order-preserving de-duplication plus value -> merged-index lookup tables.
    unique_tags = list(dict.fromkeys(all_tags))
    unique_bases = list(dict.fromkeys(all_bases))
    tag_index = {tag: position for position, tag in enumerate(unique_tags)}
    base_index = {base: position for position, base in enumerate(unique_bases)}

    final_objects = []
    for source_class, obj in collected_objects:
        new_obj = obj.copy()
        new_obj["hazard_track_id"] = len(final_objects)
        # An id that cannot be translated is left as the model reported it.
        if "tag_id" in obj and source_class in class_tag_mapping:
            merged_id = _remap_index(obj["tag_id"], class_tag_mapping[source_class], tag_index)
            if merged_id is not None:
                new_obj["tag_id"] = merged_id
        if "base_id" in obj and source_class in class_base_mapping:
            merged_id = _remap_index(obj["base_id"], class_base_mapping[source_class], base_index)
            if merged_id is not None:
                new_obj["base_id"] = merged_id
        final_objects.append(new_obj)

    final_result = {
        # Included so consumers that index a "class" list with an object's
        # class_id (as report_generator does) can use this result.
        "class": list(class_list),
        "tag": unique_tags,
        "base": unique_bases,
        "objects": final_objects
    }
    json_result = json.dumps(final_result, ensure_ascii=False, indent=2)
    return "".join(reasoning_sections), json_result
def merge_conflict_inspection_data(data: dict) -> dict:
    """
    Merge hazard objects that share ``(tag_id, class_id)`` and whose frame
    ranges overlap or touch.

    Within each group the objects are ordered by ``start_frame``; an object is
    folded into the current interval when its ``start_frame`` is not greater
    than the interval's ``end_frame``. A merged interval keeps the largest
    ``end_frame`` and the largest ``level`` of its members; any other fields
    are taken from the first member. Merged objects are numbered from 1 in
    ``hazard_track_id``.

    Args:
        data: Dict with an ``"objects"`` list; each object needs ``tag_id``,
            ``class_id``, ``level``, ``start_frame`` and ``end_frame``.

    Returns:
        The same ``data`` dict, with ``data["objects"]`` replaced by the merged
        list (the input dict is modified in place).
    """
    # 1. Group the objects by (tag_id, class_id), keeping first-seen order.
    groups: dict = {}
    for obj in data["objects"]:
        groups.setdefault((obj["tag_id"], obj["class_id"]), []).append(obj)

    # 2. Sweep each group in start order.
    merged_objects: list = []
    for (tag_id, class_id), members in groups.items():
        current = None
        for obj in sorted(members, key=lambda item: item["start_frame"]):
            if current is not None and obj["start_frame"] <= current["end_frame"]:
                current["end_frame"] = max(current["end_frame"], obj["end_frame"])
                current["level"] = max(current["level"], obj["level"])
                continue
            # Open a new interval. It is appended right away and extended in
            # place, so the last interval of a group can no longer be lost,
            # and its level starts from this object rather than carrying over
            # from the previous interval.
            current = {
                **obj,
                "hazard_track_id": len(merged_objects) + 1,
                "tag_id": tag_id,
                "class_id": class_id,
                "level": obj["level"],
                "start_frame": obj["start_frame"],
                "end_frame": obj["end_frame"],
            }
            merged_objects.append(current)

    data["objects"] = merged_objects
    return data
def get_hazard_level_from_rule(rule_definitions: dict, class_name: str, tag_name: str) -> str:
    """
    Look up the hazard level under which a tag is listed for a class.

    Args:
        rule_definitions: Mapping of class name to its rules; the rules of a
            class are scanned as ``{level name: content}`` and only dict
            contents are searched.
        class_name: Hazard class name (e.g. "灭火器").
        tag_name: Hazard tag name (e.g. "灭火器被遮挡").

    Returns:
        The first level name whose content contains ``tag_name``; "一般隐患"
        when the class has no rules or no level lists the tag.
    """
    fallback = "一般隐患"
    rules_for_class = rule_definitions.get(class_name, {})
    if not rules_for_class:
        return fallback
    matches = (
        level
        for level, content in rules_for_class.items()
        if isinstance(content, dict) and tag_name in content
    )
    return next(matches, fallback)
def map_original_level(level: int) -> str:
    """
    Translate a raw hazard level into a similarity label.

    Args:
        level: Raw level value.

    Returns:
        "高度相似" when ``level`` equals 2, otherwise "疑似".
    """
    if level == 2:
        return "高度相似"
    return "疑似"
def _md_cell(value) -> str:
    """Render a value for a Markdown table cell (escape ``|``, join lines with ``<br>``)."""
    escaped = str(value).replace("|", "\\|")
    return "<br>".join(escaped.splitlines())


def report_generator(
    video_path: str,
    detection_data: dict,
    hazard_results: dict,
    rule_definitions: dict,
    output_path: str = "output",
    frame_interval: float = 1.0
):
    """
    Write a Markdown hazard report with one screenshot per hazard.

    The report goes to ``<output_path>/report/report.md`` and the screenshots
    to ``<output_path>/report/_assets``.

    Args:
        video_path: Video file the screenshots are taken from.
        detection_data: Object detection data; accepted for call
            compatibility, this function does not read it.
        hazard_results: Dict with the lists ``class``, ``tag``, ``base`` and
            ``objects``. Each object needs ``class_id``, ``tag_id``, ``level``
            and ``start_frame``; ``base_id`` and ``location`` are optional.
        rule_definitions: Rules passed to ``get_hazard_level_from_rule``.
        output_path: Root output folder, default 'output'.
        frame_interval: Factor by which an object's ``start_frame`` is
            multiplied to obtain the frame index in the video, default 1.0.
            NOTE(review): the previous docstring called this a time in
            seconds, but the code uses it as a plain multiplier — confirm.

    Raises:
        IOError: if the video cannot be opened.
    """
    # ----------------------- 1. Preparation -----------------------
    report_dir = os.path.join(output_path, "report")
    assets_dir = os.path.join(report_dir, "_assets")
    os.makedirs(assets_dir, exist_ok=True)

    md_table_rows = []
    bases = hazard_results.get("base", [])
    cap = cv2.VideoCapture(video_path)
    # try/finally: the capture handle is released even if a lookup below fails.
    try:
        if not cap.isOpened():
            raise IOError(f"无法打开视频文件: {video_path}")
        # Total frame count, used to clamp the requested frame index.
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # ----------------------- 2. One row + screenshot per hazard -----------------------
        for idx, obj in enumerate(hazard_results["objects"], start=1):
            # Scale the start frame and clamp it into the video.
            start_frame = int(obj["start_frame"] * frame_interval)
            start_frame = max(0, min(start_frame, total_frames - 1))

            class_name = hazard_results["class"][obj["class_id"]]
            tag_name = hazard_results["tag"][obj["tag_id"]]
            location = obj.get("location", "")
            rule_level = get_hazard_level_from_rule(rule_definitions, class_name, tag_name)
            similarity = map_original_level(obj["level"])

            # 2.1 Screenshot: the whole frame at the hazard's start.
            img_filename = f"{idx}_{start_frame}.jpg"
            cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
            ret, frame = cap.read()
            if ret:
                cv2.imwrite(os.path.join(assets_dir, img_filename), frame)
                img_markdown = f"![](./_assets/{img_filename})"
            else:
                # Do not link to an image that was never written.
                print(f"无法读取第 {start_frame} 帧,隐患 {idx} 未生成截图")
                img_markdown = ""

            # 2.2 Table row. The basis text is the single entry referenced by
            # base_id; objects without a usable base_id get an empty cell.
            base_id = obj.get("base_id")
            if isinstance(base_id, int) and not isinstance(base_id, bool) and 0 <= base_id < len(bases):
                base_text = bases[base_id]
            else:
                base_text = ""
            md_table_rows.append(
                f"|{idx}|{_md_cell(rule_level)}|{similarity}|{_md_cell(base_text)}"
                f"|{start_frame}|{_md_cell(location)}|{img_markdown}|"
            )
    finally:
        cap.release()

    # ----------------------- 3. Write the Markdown report -----------------------
    md_content = (
        "# 隐患报告\n"
        "## 隐患清单\n"
        "| 隐患序号 | 隐患等级 | 相似度 | 依据 | 开始帧 | 位置描述 | 照片 |\n"
        "| :---: | :---: | :---: | :--- | :---: | :--- | :---: |\n"
    )
    md_content += "\n".join(md_table_rows)
    report_path = os.path.join(report_dir, "report.md")
    with open(report_path, "w", encoding="utf-8") as f:
        f.write(md_content)
    print(f"报告已生成: {report_path}")