import json
import os
from pathlib import Path
# NOTE(review): `N` is not referenced anywhere in this file; it looks like an
# accidental auto-import. Kept so module-level names stay unchanged — confirm
# before removing.
from tkinter import N

import cv2
import numpy as np
from ultralytics.models.sam import SAM3VideoSemanticPredictor


class SAM3:
    """Run SAM3 text-prompted detection on a video and keep per-frame results."""

    def __init__(self):
        # Per-frame detections from the most recent run() / load_from_json().
        self._data: dict = {}

    def run(self, vid_path: str, output_dir: str, class_file_path: str, conf: float = 0.6) -> dict:
        """
        Run the SAM3 model on a video and return the detections of every frame.

        Args:
            vid_path: Path of the video to process.
            output_dir: Directory that receives ``source/`` (original frames),
                ``boxes/`` (annotated frames) and ``frame_all.json``.
            class_file_path: JSON file holding one object; its keys are passed
                to the predictor as text prompts and its values are used as
                the ``class_str`` of each detection.
            conf: Confidence value forwarded to the predictor.

        Returns:
            A dict mapping the frame index to that frame's detection list.
            In memory the keys are ``int``; in ``frame_all.json`` (and thus
            after ``load_from_json``) they are strings, because JSON object
            keys are always strings. Structure::

                {
                    0: [                             # frame index
                        {
                            "xyxy": [x1, y1, x2, y2],    # box coordinates
                            "confidence": 0.9,           # confidence
                            "track_id": 123,             # track id (-1 if absent)
                            "class_id": 0,               # class id
                            "class_str": "class_name"    # class name
                        },
                        ...
                    ],
                }
        """
        self._data = {}

        # Context manager so the class file is closed deterministically.
        with open(class_file_path, "r", encoding="utf-8") as f:
            text_dict = json.load(f)
        text_list_en = list(text_dict.keys())
        text_list_cn = list(text_dict.values())

        overrides = dict(
            conf=conf,
            task="segment",
            mode="predict",
            imgsz=640,
            model="./model/sam3.pt",
            half=True,
            save=True,
            device=0,
        )
        predictor = SAM3VideoSemanticPredictor(overrides=overrides)
        results = predictor(
            source=vid_path,
            text=text_list_en,
            stream=True,
        )

        # Create the output folders once, up front. This also guarantees that
        # output_dir exists for frame_all.json when the video yields no frames.
        Path(f"{output_dir}/source").mkdir(parents=True, exist_ok=True)
        Path(f"{output_dir}/boxes").mkdir(parents=True, exist_ok=True)

        annotated_frames: dict = {}
        frame_count = 0
        for frame_idx, r in enumerate(results):
            frame_data = self._boxes_to_records(r.boxes, text_list_cn)
            annotated_frames[frame_idx] = frame_data

            # Save the original frame.
            source_path = f"{output_dir}/source/frame_{frame_idx:04d}.jpg"
            cv2.imwrite(source_path, r.orig_img)

            # Save the annotated frame.
            mask_path = f"{output_dir}/boxes/frame_{frame_idx:04d}.jpg"
            cv2.imwrite(mask_path, r.plot())

            print(f"Frame {frame_idx}: {len(frame_data)} boxes")
            frame_count = frame_idx + 1

        with open(f"{output_dir}/frame_all.json", "w", encoding="utf-8") as f:
            json.dump(annotated_frames, f, ensure_ascii=False, indent=2)

        print(f"\n总共处理 {frame_count} 帧")
        print(f"结果已保存到 {output_dir}/frame_all.json")
        print(f"标记图片已保存到 {output_dir}/boxes/")

        cv2.destroyAllWindows()
        self._data = annotated_frames
        return annotated_frames

    @staticmethod
    def _boxes_to_records(boxes, class_names: list) -> list:
        """Convert one frame's boxes object into a list of plain dict records."""
        if boxes is None:
            return []

        xyxy = boxes.xyxy.cpu().numpy()
        confs = boxes.conf.cpu().numpy()
        cls_ids = boxes.cls.cpu().numpy()
        # `id` can be None when a frame has no track ids; fall back to -1
        # instead of failing on `None.cpu()`.
        if boxes.id is not None:
            track_ids = boxes.id.cpu().numpy()
        else:
            track_ids = np.full(len(xyxy), -1)

        return [
            {
                "xyxy": box.tolist(),
                "confidence": float(score),
                "track_id": int(track_id),
                "class_id": int(cls_id),
                "class_str": class_names[int(cls_id)],
            }
            for box, score, track_id, cls_id in zip(xyxy, confs, track_ids, cls_ids)
        ]

    def load_from_json(self, json_path: str) -> dict:
        """
        Load detection results from a JSON file and keep them on the instance.

        Args:
            json_path: Path of the JSON file.

        Returns:
            The parsed content. For a file written by ``run()`` this is a dict
            mapping the frame index (as a string) to the detection list::

                {
                    "0": [                           # frame index
                        {
                            "xyxy": [x1, y1, x2, y2],    # box coordinates
                            "confidence": 0.9,           # confidence
                            "track_id": 123,             # track id
                            "class_id": 0,               # class id
                            "class_str": "class_name"    # class name
                        },
                        ...
                    ],
                }
        """
        with open(json_path, "r", encoding="utf-8") as f:
            self._data = json.load(f)
        return self._data

    def get_class_dict(self):
        """
        Reduce the stored results to the classes present in each frame.

        The ids in ``frame_list`` are positions in the returned, alphabetically
        sorted ``class_list`` — not the ``class_id`` stored in the detections.

        Returns:
            ::

                {
                    "class_list": ["class_name1", "class_name2", ...],
                    "frame_list": {
                        "0": [0, 1, ...],    # frame index -> sorted class ids
                    },
                }

        Raises:
            ValueError: If the stored data is ``None``.
        """
        if self._data is None:
            raise ValueError("请先调用 run() 方法运行模型")

        # Collect every distinct class name across all frames.
        class_names = set()
        for frame_data in self._data.values():
            for item in frame_data:
                class_names.add(item["class_str"])
        class_list = sorted(class_names)

        # Name -> position lookup, so each detection is resolved in O(1)
        # instead of scanning class_list with .index().
        class_index = {name: idx for idx, name in enumerate(class_list)}

        # Distinct class ids per frame, sorted; keys are always strings.
        frame_list = {
            str(frame_idx): sorted({class_index[item["class_str"]] for item in frame_data})
            for frame_idx, frame_data in self._data.items()
        }

        return {
            "class_list": class_list,
            "frame_list": frame_list,
        }

    def data(self):
        """Return the stored per-frame detection data."""
        return self._data


def extract_frames_to_video(video_path, output_dir, interval=1):
    """
    Build a new video from every ``interval``-th frame (no image processing).

    Args:
        video_path: Path of the original video file.
        output_dir: Output folder; the new video is written there as
            ``output_video.mp4``.
        interval: Sampling interval (e.g. 6 keeps 1 frame out of every 6).
            Values below 1 are treated as 1.

    Returns:
        Path of the written video.

    Raises:
        Exception: If the input video cannot be opened or no video writer
            could be created.
    """
    # Clamp once so frame selection and the output frame rate use the same
    # value; an interval of 0 would otherwise divide by zero in the loop.
    interval = max(interval, 1)

    print(f"开始抽帧: {video_path}")
    print(f"抽帧间隔: {interval}")
    print(f"输出目录: {output_dir}")

    # Make sure the output directory exists.
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    # 1. Open the original video, trying the backends in order.
    backends = [
        (cv2.CAP_FFMPEG, 'FFmpeg'),
        (cv2.CAP_DSHOW, 'DirectShow'),
        (cv2.CAP_ANY, 'Default'),
    ]

    cap = None
    for backend, backend_name in backends:
        try:
            candidate = cv2.VideoCapture(video_path, backend)
            opened = candidate.isOpened()
        except Exception as e:
            print(f"尝试{backend_name}后端失败: {e}")
            continue

        if not opened:
            # Release the failed handle before trying the next backend.
            candidate.release()
            continue

        cap = candidate
        if backend == cv2.CAP_FFMPEG:
            # Try to request hardware acceleration; failure is not fatal.
            try:
                cap.set(cv2.CAP_PROP_HW_ACCELERATION, cv2.VIDEO_ACCELERATION_ANY)
                hw_accel = cap.get(cv2.CAP_PROP_HW_ACCELERATION)
                print(f"FFmpeg硬件加速: {'已启用' if hw_accel > 0 else '未启用'}")
            except Exception as e:
                print(f"设置硬件加速失败: {e}")
        print(f"使用后端: {backend_name}")
        break

    if cap is None:
        raise Exception(f"无法打开视频文件: {video_path}")

    out = None
    try:
        # 2. Read the video parameters.
        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # 3. Keep the original playback speed: new fps = original fps / interval.
        new_fps = fps / interval

        # 4. Create the writer. A writer that failed to open is detected via
        #    isOpened(), so check it explicitly: try H.264 ('avc1') first,
        #    then fall back to 'mp4v'.
        output_path = os.path.join(output_dir, "output_video.mp4")
        for codec in ("avc1", "mp4v"):
            fourcc = cv2.VideoWriter_fourcc(*codec)  # type: ignore
            out = cv2.VideoWriter(output_path, fourcc, new_fps, (width, height), isColor=True)
            if out.isOpened():
                print("视频写入器初始化成功")
                break
            print(f"硬件加速写入失败,使用默认方式: 无法使用 {codec} 编码")
            out.release()
            out = None
        if out is None:
            raise Exception(f"无法创建视频写入器: {output_path}")

        # 5. Sample frames and write them to the new video.
        frame_idx = 0
        saved_frames = 0
        # Report every 10% of the video, but at most once per 100 frames.
        print_interval = max(100, total_frames // 10)

        while cap.isOpened():
            if frame_idx % interval == 0:
                # Wanted frame: decode it and write it out.
                ret, frame = cap.read()
                if not ret:
                    break
                out.write(frame)
                saved_frames += 1
            elif not cap.grab():
                # Unwanted frame: grab() advances without retrieving the image.
                break

            frame_idx += 1

            if frame_idx % print_interval == 0:
                if total_frames > 0:
                    progress = (frame_idx / total_frames) * 100
                    print(f"已处理帧 {frame_idx}/{total_frames} ({progress:.1f}%)")
                else:
                    # The reported frame count was not positive, so no
                    # percentage can be computed.
                    print(f"已处理帧 {frame_idx}")
    finally:
        # 6. Release resources even when an error interrupts processing.
        cap.release()
        if out is not None:
            out.release()

    cv2.destroyAllWindows()

    print(f"抽帧完成: 原视频 {total_frames} 帧, 抽取后 {saved_frames} 帧")
    print(f"新视频已保存至: {output_path}")
    print(f"新视频帧率: {new_fps:.2f} fps (原始帧率: {fps:.2f} fps)")

    return output_path


if __name__ == "__main__":
    VID_NAME: str = "Peihdianhxiang"
    VID_END: str = ".mp4"
    conf: float = 0.6
    output_dir: str = f"output/{VID_NAME}"
    vid_path: str = f"./input/{VID_NAME}{VID_END}"

    sam3: SAM3 = SAM3()
    sam3.run(vid_path, output_dir, "1.厂房防火.json", conf)