heng/Qwen_app.py

463 lines
18 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from lib.json_fun import f_detections_to_objects, load_json_data
from lib.qwen_fun import get_annnotated_frame_for_ai_without_xyxy, hazard_inspection, merge_conflict_inspection_data, report_generator, search_knowledge_base
from encodings.punycode import T
from lib.qwen_fun_vid import generate_video_to_objects
"""
测试 在给定标注框与类别原始视频经过转换之后AI能否准确识别物体特征
"""
from tkinter import N
from datetime import datetime
import json
import os
import re
import sys
import subprocess
import importlib.util
from typing import Any
from lib.qwen_fun import save_json_to_file, upload_files_and_get_urls_concurrently
from lib.sam3 import SAM3
import cv2
import json
import numpy as np
from pathlib import Path
import gradio as gr
# Frame interval used to convert a slider position into a sampled-frame index.
# NOTE(review): initialised to 0 and never reassigned in the visible code —
# confirm where it is meant to be set.
fps = 0
# Folder, relative to the working directory, that is scanned for input videos.
VIDEO_FOLDER: str = "input"
def load_file_list() -> list[str]:
    """
    List the regular files located directly inside VIDEO_FOLDER.

    Sub-directories are neither descended into nor included.

    Returns:
        File names (not paths) in the order reported by os.listdir; an empty
        list when VIDEO_FOLDER is not an existing directory.
    """
    # Guard clause: nothing to list when the folder is missing.
    if not os.path.isdir(VIDEO_FOLDER):
        return []
    return [
        entry
        for entry in os.listdir(VIDEO_FOLDER)
        if os.path.isfile(os.path.join(VIDEO_FOLDER, entry))
    ]
def reload_files():
    """
    Re-scan the video folder and build an update for the video dropdown.

    Returns:
        A gr.update carrying the refreshed choices, with the first file
        selected, or no selection when the folder yields no files.
    """
    names = load_file_list()
    if names:
        selected = names[0]
    else:
        selected = None
    return gr.update(choices=names, value=selected)
def get_full_vid_path(vid_file: str) -> str:
    """
    Build the absolute path of a video inside VIDEO_FOLDER.

    The path is composed from the current working directory, VIDEO_FOLDER and
    the given file name, reported through gr.Info, and then returned.
    """
    path_parts = (os.getcwd(), VIDEO_FOLDER, vid_file)
    absolute_path: str = os.path.join(*path_parts)
    gr.Info(absolute_path)
    return absolute_path
def update_preview(frame_idx: int, vid_file: str):
    """
    Resolve the preview image and hazard text for a slider position.

    Args:
        frame_idx: Value of the frame slider.
        vid_file: Name of the selected video; its stem selects the
            output/<stem> directory that is read.

    Returns:
        A tuple (img_path, class_tag): the path of the boxed frame image for
        the sampled index, and the text built by get_class_tag_by_frame from
        output/<stem>/hazard_inspection.json.

    Raises:
        gr.Error: when no video is selected, or when the hazard inspection
            JSON is missing or cannot be parsed.
    """
    if not vid_file:
        raise gr.Error("请先选择视频。")

    # The module-level fps is only read here, so no `global` statement is
    # needed.  It starts at 0; fall back to an interval of 1 so the division
    # below cannot raise ZeroDivisionError.
    interval = fps if fps and fps > 0 else 1

    vid_name: str = Path(vid_file).stem
    output_dir: str = f"output/{vid_name}"
    sampled_idx = int(frame_idx // interval)
    img_path = f"{output_dir}/boxes/frame_{sampled_idx:04d}.jpg"

    report_path = f"{output_dir}/hazard_inspection.json"
    try:
        with open(report_path, "r", encoding="utf-8") as f:
            report = json.load(f)
    except FileNotFoundError as exc:
        raise gr.Error(f"找不到隐患排查结果:{report_path}") from exc
    except json.JSONDecodeError as exc:
        raise gr.Error(f"隐患排查结果不是合法的 JSON:{report_path}") from exc

    # Pass the sanitised interval so the helper divides by the same value.
    class_tag = get_class_tag_by_frame(report, frame_idx, interval)
    return img_path, class_tag
def run(vid_file: str, run_sam3: bool = True, run_inspection: bool = True, gen_report: bool = True):
    """
    Run the track-then-analyze pipeline for one video and return the Gradio outputs.

    This folds batch_run_videos_then_analyze_images.py into one function:
    1.py generates the track image directory, then
    analyze_track_images_with_llm.py inspects those images.

    Args:
        vid_file: Name of a video file inside VIDEO_FOLDER.  When it does not
            resolve to a file, the first video found through the VIDEO_PATHS /
            VIDEO_DIR configuration below is used instead.
        run_sam3: Whether to run 1.py to generate the track image directory
            (the parameter name is kept from the previous implementation).
            When False, the newest earlier run directory of the same video
            that contains a "tracks" folder is reused.
        run_inspection: Whether to run the image analysis script.
        gen_report: Also triggers the image analysis script; per the original
            note, that script produces its own Word/JSON reports.

    Returns:
        A 4-tuple for the Gradio outputs: a slider update (maximum=0, value=0,
        step=1), the path of one track image or None when none exists, a
        status text, and the first JSON document that could be parsed from
        the report directory ({} when there is none).

    Raises:
        gr.Error: when no video can be found, the YOLO weights are missing,
            or run_sam3 is False and no earlier run exists for the video.
        RuntimeError: when the tracking script cannot be loaded or the
            analysis script exits with a non-zero code.
    """
    # ============================================================================
    # Configuration (taken from batch_run_videos_then_analyze_images.py)
    # ============================================================================
    VIDEO_DIR = r""
    VIDEO_PATHS: list[str] = [r"C:\factory-inspection\videos\MOV00001_20260213_103217_fixed.mp4"]
    YOLO_MODEL_PATH = r"C:\factory-inspection\yolo\best.pt"
    TRACK_SCRIPT_PATH = r"C:\factory-inspection\1.py"
    ANALYZE_SCRIPT_PATH = r"C:\factory-inspection\scripts\analyze_track_images_with_llm.py"
    OUTPUT_ROOT = r"C:\factory-inspection\batch_video_image_reports"
    # Lower-case only: suffixes are lower-cased before the membership test so
    # any casing (".Mp4", ".MOV", ...) is accepted.
    VIDEO_EXTS = {".mp4", ".avi", ".mov", ".mkv"}
    # ============================================================================

    def _safe_name(path: Path) -> str:
        # Replace every run of characters outside the allowed set with "_".
        return re.sub(r"[^0-9A-Za-z\u4e00-\u9fff_.-]+", "_", path.stem or "video")

    def _load_module(script_path: Path) -> Any:
        # Import the tracking script from its file path under a fixed module name.
        spec = importlib.util.spec_from_file_location("factory_video_tracker", script_path)
        if spec is None or spec.loader is None:
            raise RuntimeError(f"无法加载脚本:{script_path}")
        module = importlib.util.module_from_spec(spec)
        sys.modules["factory_video_tracker"] = module
        spec.loader.exec_module(module)
        return module

    def _collect_videos(video_dir: str, video_paths: list[str]) -> list[Path]:
        # Gather explicitly listed files first, then everything under video_dir.
        videos: list[Path] = []
        for raw in video_paths:
            path = Path(raw)
            if path.is_file():
                videos.append(path)
        if video_dir:
            root = Path(video_dir)
            if root.is_dir():
                for path in sorted(root.rglob("*")):
                    if path.suffix.lower() in VIDEO_EXTS and path.is_file():
                        videos.append(path)
        # De-duplicate while keeping the original order.
        result: list[Path] = []
        seen: set[str] = set()
        for video in videos:
            key = str(video.resolve()).lower()
            if key not in seen:
                seen.add(key)
                result.append(video)
        return result

    def _parse_class_id_and_name(value: Any) -> tuple[int | None, str]:
        # Accepts "<id> <name>", "<id>", or a bare name.
        # Test against None explicitly: `value or ""` would discard class id 0.
        text = "" if value is None else str(value).strip()
        match = re.match(r"^\s*(\d+)\s+(.+?)\s*$", text)
        if match:
            return int(match.group(1)), match.group(2).strip()
        match = re.match(r"^\s*(\d+)\s*$", text)
        if match:
            class_id = int(match.group(1))
            return class_id, f"class_{class_id}"
        return None, text or "未知目标"

    def _parse_bbox(value: Any) -> list[float] | None:
        # Accepts a 4-element list or a comma-separated string of 4 numbers.
        if isinstance(value, list) and len(value) == 4:
            return [float(x) for x in value]
        if isinstance(value, str):
            parts = [x.strip() for x in value.split(",") if x.strip()]
            if len(parts) == 4:
                return [float(x) for x in parts]
        return None

    def _track_id_from_path(json_path: Path) -> int | None:
        # Look for "track_<n>" in the file stem first, then in the parent folder name.
        match = re.search(r"track_(\d+)", json_path.stem) or re.search(r"track_(\d+)", json_path.parent.name)
        return int(match.group(1)) if match else None

    def _normalize_1py_jsons(track_dir: Path) -> None:
        """Rewrite single-dict JSON files as one-element lists.

        Per the original note, 1.py may emit a single dict while the analysis
        script usually works with a list.
        """
        for json_path in track_dir.rglob("*.json"):
            try:
                data = json.loads(json_path.read_text(encoding="utf-8"))
            except (OSError, ValueError) as exc:
                # ValueError covers both JSON and UTF-8 decoding failures.
                print(f"跳过异常 JSON{json_path},原因:{exc}")
                continue
            if isinstance(data, list):
                continue
            if not isinstance(data, dict):
                continue
            # Fall back to class_name only when class_id is really absent;
            # an `or` chain would also skip a legitimate class id of 0.
            raw_class = data.get("class_id")
            if raw_class is None or raw_class == "":
                raw_class = data.get("class_name")
            class_id, class_name = _parse_class_id_and_name(raw_class)
            normalized = [{
                "xyxy": _parse_bbox(data.get("xyxy") or data.get("bbox") or data.get("location")),
                "confidence": data.get("confidence", data.get("conf", 0.0)),
                "track_id": data.get("track_id", _track_id_from_path(json_path)),
                "class_id": class_id,
                "class_str": class_name,
                "start_frame": data.get("start_frame"),
                "end_frame": data.get("end_frame"),
                "start_sec": data.get("start_sec"),
            }]
            json_path.write_text(json.dumps(normalized, ensure_ascii=False, indent=2), encoding="utf-8")

    def _video_to_tracks(video_path: Path, yolo_model_path: Path, output_dir: Path) -> Path:
        # Run the tracker from 1.py and return the directory holding its track images.
        module = _load_module(Path(TRACK_SCRIPT_PATH))
        config = module.InferenceConfig()
        # Override the paths hard-coded in 1.py.
        config.VIDEO_PATH = str(video_path)
        config.MODEL_PATH = str(yolo_model_path)
        tracks_dir = output_dir / "tracks"
        config.BEST_FRAME_DIR = str(tracks_dir)
        config.OUTPUT_DIR = str(output_dir / "tracker_output")
        tracker = module.VideoTracker(config, output_path=str(output_dir / "tracked_video.mp4"))
        tracker.track_video()
        _normalize_1py_jsons(tracks_dir)
        return tracks_dir

    def _run_image_analyze(track_dir: Path, report_dir: Path) -> None:
        # Run the analysis script as a child process of the current interpreter.
        cmd = [
            sys.executable,
            str(Path(ANALYZE_SCRIPT_PATH)),
            "--input",
            str(track_dir),
            "--output",
            str(report_dir),
        ]
        env = os.environ.copy()
        env.setdefault("PYTHONUTF8", "1")
        env.setdefault("PYTHONIOENCODING", "utf-8")
        # Per the original note: avoids the odd case where the child process
        # reports is_available=True but device_count=0.
        env.pop("CUDA_VISIBLE_DEVICES", None)
        print("执行图片检测:", " ".join(f'"{x}"' if " " in x else x for x in cmd))
        code = subprocess.call(cmd, env=env)
        if code != 0:
            raise RuntimeError(f"图片检测失败,退出码:{code}")

    def _latest_run_dir_with_tracks(video_path: Path, output_root: Path) -> Path | None:
        # Find the newest "<safe_name>_<YYYYmmdd>_<HHMMSS>" run directory of
        # this video that already contains a "tracks" folder.
        pattern = re.compile(rf"{re.escape(_safe_name(video_path))}_\d{{8}}_\d{{6}}")
        candidates = [
            child
            for child in output_root.iterdir()
            if child.is_dir() and pattern.fullmatch(child.name) and (child / "tracks").is_dir()
        ]
        # The timestamp is zero-padded, so the largest name is the newest run.
        return max(candidates, key=lambda p: p.name, default=None)

    def _process_one_video(video_path: Path, yolo_model_path: Path, output_root: Path) -> tuple[Path, Path]:
        # Execute the enabled pipeline steps; returns (tracks_dir, report_dir).
        print("=" * 60)
        print(f"处理视频:{video_path}")
        if run_sam3:
            run_name = f"{_safe_name(video_path)}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
            run_dir = output_root / run_name
            run_dir.mkdir(parents=True, exist_ok=True)
            print("第 1 步:调用 1.py 生成 track 图片目录")
            tracks_dir = _video_to_tracks(video_path, yolo_model_path, run_dir)
            print(f"track 图片目录:{tracks_dir}")
        else:
            # A freshly timestamped run directory can never contain tracks, so
            # skipping step 1 has to reuse an earlier run of the same video.
            previous_run = _latest_run_dir_with_tracks(video_path, output_root)
            if previous_run is None:
                raise gr.Error(f"未找到该视频已有的 track 图片目录,请先勾选第 1 步运行:{video_path}")
            run_dir = previous_run
            tracks_dir = run_dir / "tracks"
            print(f"跳过 1.py使用已存在的 track 图片目录:{tracks_dir}")
        report_dir = run_dir / "report"
        report_dir.mkdir(parents=True, exist_ok=True)
        if run_inspection or gen_report:
            print("第 2 步:调用 analyze_track_images_with_llm.py 检测图片")
            _run_image_analyze(tracks_dir, report_dir)
            print(f"报告目录:{report_dir}")
        else:
            print("跳过图片检测")
        print("=" * 60)
        return tracks_dir, report_dir

    # ===== Main flow: use the selected video, else fall back to the batch configuration =====
    selected_video: Path = Path(VIDEO_FOLDER) / vid_file if vid_file else Path()
    if not selected_video.is_file():
        # Nothing usable was selected in Gradio; fall back to VIDEO_DIR / VIDEO_PATHS.
        videos = _collect_videos(VIDEO_DIR, VIDEO_PATHS)
        if not videos:
            raise gr.Error(
                f"找不到视频:{selected_video};且 VIDEO_DIR / VIDEO_PATHS 都为空。"
            )
        selected_video = videos[0]
    yolo_model = Path(str(YOLO_MODEL_PATH).strip().strip('"'))
    if not yolo_model.is_file():
        raise gr.Error(f"YOLO 权重不存在:{yolo_model},请设置 YOLO_MODEL_PATH。")
    output_root = Path(OUTPUT_ROOT)
    output_root.mkdir(parents=True, exist_ok=True)
    tracks_dir, report_dir = _process_one_video(selected_video, yolo_model, output_root)
    print(f"视频处理完成:{selected_video}")

    # ===== Build the 4-tuple expected by the Gradio outputs =====
    # Use one track image as the preview: the first .jpg, else the first .png.
    # None (rather than an empty string) leaves the image component empty.
    preview_path: str | None = next(
        (str(img) for pattern in ("*.jpg", "*.png") for img in tracks_dir.rglob(pattern)),
        None,
    )
    # Load the first report JSON that can be read and parsed.
    json_result: dict | list = {}
    for jp in sorted(report_dir.rglob("*.json")):
        try:
            json_result = json.loads(Path(jp).read_text(encoding="utf-8"))
            break
        except (OSError, ValueError):
            continue
    status_text = (
        f"处理完成。\n"
        f"视频:{selected_video}\n"
        f"Track 目录:{tracks_dir}\n"
        f"报告目录:{report_dir}"
    )
    return gr.update(maximum=0, value=0, step=1), preview_path, status_text, json_result
def get_vid_dict(vid_dir: str, obj_dict: dict, video_url_file: str, use_url_cache: bool) -> dict:
    """
    Build the per-object video dictionary, uploading videos that have no URL yet.

    Args:
        vid_dir: Directory holding the per-object videos (obj_<id>.mp4).
        obj_dict: Object dictionary keyed by track id.
        video_url_file: JSON file used to cache and persist the result.
        use_url_cache: Whether to start from the entries in video_url_file.

    Returns:
        A dict keyed by track id whose values are
        {"vid_path": <local path>, "vid_url": <url or None>}.  The same dict
        is written to video_url_file.

    NOTE(review): cached keys come from JSON and are therefore strings; if
    obj_dict uses int keys the two will not match — confirm the key type.
    """
    vid_dict = {}

    # Start from the cached entries, when requested and present.
    if use_url_cache and os.path.exists(video_url_file):
        try:
            with open(video_url_file, "r", encoding="utf-8") as f:
                loaded_data = json.load(f)
            # Keep only well-formed entries: a dict carrying both required keys.
            for track_id, value in loaded_data.items():
                if isinstance(value, dict) and "vid_path" in value and "vid_url" in value:
                    vid_dict[track_id] = value
                else:
                    print(f"跳过无效的视频信息: {track_id} -> {value}")
        except Exception as e:
            # Deliberate best-effort: an unusable cache must not block the run.
            print(f"读取URL文件失败: {e}")

    # Create entries for track ids that the cache did not provide.
    for track_id in obj_dict:
        if track_id in vid_dict:
            continue
        try:
            vid_path = f"{vid_dir}/obj_{int(track_id):03d}.mp4"
            vid_dict[track_id] = {"vid_path": vid_path, "vid_url": None}
        except (ValueError, TypeError):
            print(f"无效的track_id: {track_id},跳过")

    # Entries that still lack a URL, in dictionary order.
    pending = [
        vid_info
        for vid_info in vid_dict.values()
        if isinstance(vid_info, dict) and vid_info.get("vid_url") is None
    ]
    unuploaded_vids = [vid_info["vid_path"] for vid_info in pending]

    # Upload them and collect the URLs (matched to the input by position).
    uploaded_urls = upload_files_and_get_urls_concurrently(
        file_path_list=unuploaded_vids,
        max_workers=8
    )

    # Map each path to the URL at its first position.  This replaces a
    # list.index() call per entry (quadratic) and yields the same result,
    # including when fewer URLs than paths come back.
    url_by_path: dict = {}
    for path, url in zip(unuploaded_vids, uploaded_urls or ()):
        url_by_path.setdefault(path, url)
    for vid_info in pending:
        path = vid_info["vid_path"]
        if path in url_by_path:
            vid_info["vid_url"] = url_by_path[path]

    # Persist the dictionary as JSON.
    with open(video_url_file, "w", encoding="utf-8") as f:
        json.dump(vid_dict, f, ensure_ascii=False, indent=4)
    return vid_dict
# Build the Gradio page
with gr.Blocks() as demo:
    gr.Markdown("# 📸 隐患排查系统 (sam3 + qwen3.5-27b)")
    with gr.Row():
        # First column: video selection and pipeline switches.
        with gr.Column():
            initial_files: list[str] = load_file_list()
            default_video = initial_files[0] if initial_files else None
            vid_file = gr.Dropdown(
                label="视频",
                choices=initial_files,
                value=default_video, # select the first file by default
                interactive=True,
            )
            reload_file_list_button = gr.Button("刷新视频列表")
            reload_file_list_button.click(fn=reload_files, inputs=[], outputs=[vid_file])
            run_sam3 = gr.Checkbox(label="1. 运行 SAM3 模型", value=True)
            run_inspection = gr.Checkbox(label="2. 运行隐患排查", value=True)
            gen_report = gr.Checkbox(label="3. 生成报告", value=True)
            # NOTE(review): this checkbox is not passed to any handler in the visible code.
            audio_recognition = gr.Checkbox(label="4. 运行音频识别", value=False)
            run_button = gr.Button("运行", variant="primary")
            get_vid_path_btn = gr.Button("获取视频路径")
            # Hidden textbox that receives the value returned by get_full_vid_path.
            full_path_text = gr.Textbox(visible=False)
        # Second column: preview image, frame slider and results.
        with gr.Column():
            preview = gr.Image(label="预览", scale=2)
            img_slider = gr.Slider(label="帧索引", minimum=0, maximum=0, value=0, step=1)
            textbox = gr.Textbox(label="隐患结果", lines=10)
            jsonbox = gr.JSON(label="隐患结果json")
    # Event wiring: run the pipeline, refresh the preview when the slider
    # changes, and report the selected video's absolute path.
    run_button.click(fn=run, inputs=[vid_file, run_sam3, run_inspection, gen_report], outputs=[img_slider, preview, textbox, jsonbox])
    img_slider.change(fn=update_preview, inputs=[img_slider, vid_file], outputs=[preview, textbox], show_progress="hidden")
    get_vid_path_btn.click(fn=get_full_vid_path, inputs=vid_file, outputs=full_path_text)
def get_class_tag_by_frame(data, idx, fps):
"""
根据给定的帧索引 (idx),返回在该帧范围内的所有对象的 class:tag 信息。
参数:
data (dict): 包含 'tag', 'objects' 的字典数据hazard_inspection.json格式
idx (int): 需要查询的帧索引
返回:
str: 符合条件的 class:tag 列表,多个结果之间用换行符分隔。如果没有匹配项,返回空字符串。
"""
# 参数检查
if not isinstance(data, dict):
raise ValueError("数据必须是字典类型")
if 'tag' not in data or 'objects' not in data:
raise ValueError("数据必须包含 'tag''objects'")
tag_list = data['tag']
objects = data['objects']
interval = fps
idx = int(idx / interval) #转换
# 用于存储符合条件的 class:tag 字符串
result = []
all_class_tag = []
# 遍历每个物体
for obj in objects:
# 根据 tag_id 获取对应的隐患标签字符串
tag_id = obj.get('tag_id', 0)
class_id = obj.get('class_id', 0)
tag_str = tag_list[tag_id] if tag_id < len(tag_list) else f"未知标签({tag_id})"
location = obj.get('location', '')
start_frame = obj.get('start_frame', 0)
level = obj.get('level', '')
# 检查帧范围是否包含 idx
if start_frame == idx:
result.append(f"{tag_str} | 等级:{level} | 位置: {location}")
all_class_tag.append(f"{tag_str} | class_id:{class_id} | 等级:{level} | 开始帧:{start_frame} | 位置:{location}")
# 使用换行符连接所有结果
output = f"当前帧隐患:\n"+"\n".join(result)+"\n\n"+"所有隐患对象信息:\n"+"\n".join(all_class_tag)
return output
# Launch the app when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(
        debug=True,
        # Passes the input video folder as an allowed path.
        allowed_paths=[VIDEO_FOLDER]
    )