292 lines
10 KiB
Python
292 lines
10 KiB
Python
import os
|
||
from tkinter import N
|
||
from ultralytics.models.sam import SAM3VideoSemanticPredictor
|
||
import cv2
|
||
import json
|
||
import numpy as np
|
||
from pathlib import Path
|
||
|
||
class SAM3:
    """Run the SAM3 video semantic predictor and keep its per-frame detections.

    The detections are held in ``self._data`` as a mapping from frame index to
    a list of detection records. Each record is a dict with the keys
    ``xyxy``, ``confidence``, ``track_id``, ``class_id`` and ``class_str``.
    """

    def __init__(self):
        # Frame index -> list of detection records. Empty until run() or
        # load_from_json() fills it.
        self._data: dict = {}

    def run(self, vid_path: str, output_dir: str, class_file_path: str, conf: float = 0.6) -> dict:
        """Run the SAM3 model on a video and return the detections of every frame.

        Besides returning the detections, this method writes to ``output_dir``:

        * ``source/frame_XXXX.jpg`` - the original image of each frame,
        * ``boxes/frame_XXXX.jpg`` - the annotated image of each frame,
        * ``frame_detections.json`` - all detections as JSON.

        Args:
            vid_path: Path of the video to process.
            output_dir: Directory that receives the images and the JSON file.
            class_file_path: Path of a JSON object file. Its keys are passed
                to the predictor as text prompts; its values are used as the
                ``class_str`` of each detection, looked up by class id.
            conf: Confidence value passed to the predictor overrides.

        Returns:
            A dict mapping the frame index to the list of detections of that
            frame. The returned dict uses ``int`` keys; the JSON file (and
            therefore ``load_from_json``) uses string keys::

                {
                    0: [
                        {
                            "xyxy": [x1, y1, x2, y2],   # box coordinates
                            "confidence": 0.9,          # confidence score
                            "track_id": 123,            # track id, or None
                            "class_id": 0,              # class id
                            "class_str": "class_name"   # class name
                        },
                        ...
                    ],
                    ...
                }
        """
        self._data = {}

        # Close the class file deterministically instead of leaking the handle.
        with open(class_file_path, "r", encoding="utf-8") as class_file:
            text_dict = json.load(class_file)
        text_list_en = list(text_dict.keys())
        text_list_cn = list(text_dict.values())

        overrides = dict(conf=conf, task="segment", mode="predict", imgsz=640, model="./model/sam3.pt", half=True, save=True, device=0)
        predictor = SAM3VideoSemanticPredictor(overrides=overrides)

        results = predictor(
            source=vid_path,
            text=text_list_en,
            stream=True,
        )

        # Create the output directories once, up front. This also guarantees
        # that output_dir exists for the JSON file when no frame is produced.
        Path(f"{output_dir}/source").mkdir(parents=True, exist_ok=True)
        Path(f"{output_dir}/boxes").mkdir(parents=True, exist_ok=True)

        annotated_frames: dict = {}
        for frame_idx, r in enumerate(results):
            frame_data = []

            if r.boxes is not None:
                # boxes.id can be None when the result carries no track ids;
                # calling .cpu() on it unconditionally would raise.
                raw_track_ids = r.boxes.id
                frame_data = self._build_frame_records(
                    r.boxes.xyxy.cpu().numpy(),
                    r.boxes.conf.cpu().numpy(),
                    r.boxes.cls.cpu().numpy(),
                    None if raw_track_ids is None else raw_track_ids.cpu().numpy(),
                    text_list_cn,
                )

            annotated_frames[frame_idx] = frame_data

            # Save the original frame.
            cv2.imwrite(f"{output_dir}/source/frame_{frame_idx:04d}.jpg", r.orig_img)

            # Save the annotated frame.
            cv2.imwrite(f"{output_dir}/boxes/frame_{frame_idx:04d}.jpg", r.plot())

            print(f"Frame {frame_idx}: {len(frame_data)} boxes")

        frame_count = len(annotated_frames)

        with open(f"{output_dir}/frame_detections.json", "w", encoding="utf-8") as f:
            json.dump(annotated_frames, f, ensure_ascii=False, indent=2)

        print(f"\n总共处理 {frame_count} 帧")
        print(f"结果已保存到 {output_dir}/frame_detections.json")
        print(f"标记图片已保存到 {output_dir}/boxes/")

        # Kept from the original implementation; this method itself opens no
        # window.
        cv2.destroyAllWindows()

        self._data = annotated_frames
        return annotated_frames

    @staticmethod
    def _build_frame_records(xyxy, confs, cls_ids, track_ids, class_names) -> list:
        """Convert the per-box arrays of one frame into detection records.

        Args:
            xyxy: Array-like of boxes; each element must support ``tolist()``.
            confs: Confidence score per box.
            cls_ids: Class id per box (converted with ``int``).
            track_ids: Track id per box, or ``None`` when no ids are available.
            class_names: Sequence indexed by class id, giving ``class_str``.

        Returns:
            One dict per box with the keys ``xyxy``, ``confidence``,
            ``track_id`` (``None`` when ``track_ids`` is ``None``),
            ``class_id`` and ``class_str``.
        """
        if track_ids is None:
            track_ids = [None] * len(xyxy)

        records = []
        for box, score, cls_id, track_id in zip(xyxy, confs, cls_ids, track_ids):
            class_id = int(cls_id)
            records.append({
                "xyxy": box.tolist(),
                "confidence": float(score),
                "track_id": None if track_id is None else int(track_id),
                "class_id": class_id,
                "class_str": class_names[class_id],
            })
        return records

    def load_from_json(self, json_path: str) -> dict:
        """Load detections from a JSON file and store them on the instance.

        Args:
            json_path: Path of the JSON file to read.

        Returns:
            The parsed content, which is also stored in ``self._data``. The
            expected layout is a dict mapping the frame index (a string, as
            JSON object keys always are) to the list of detections::

                {
                    "0": [
                        {
                            "xyxy": [x1, y1, x2, y2],   # box coordinates
                            "confidence": 0.9,          # confidence score
                            "track_id": 123,            # track id
                            "class_id": 0,              # class id
                            "class_str": "class_name"   # class name
                        },
                        ...
                    ],
                    ...
                }
        """
        with open(json_path, "r", encoding="utf-8") as f:
            self._data = json.load(f)
        return self._data

    def get_class_dict(self):
        """Reduce the stored detections to the set of classes seen in each frame.

        The ids in ``frame_list`` are indices into the returned, alphabetically
        sorted ``class_list``; they are NOT the ``class_id`` values stored in
        the detection records.

        Returns:
            A dict of the form::

                {
                    "class_list": ["class_name1", "class_name2", ...],
                    "frame_list": {
                        "0": [0, 1, ...],   # sorted, de-duplicated indices
                        ...
                    }
                }

        Raises:
            ValueError: If no detections are stored, i.e. neither ``run`` nor
                ``load_from_json`` has produced any data yet.
        """
        # _data starts out as {} and is never None, so test for emptiness;
        # an "is None" check could never fire.
        if not self._data:
            raise ValueError("请先调用 run() 方法运行模型")

        # Collect the distinct class names over all frames.
        class_names = set()
        for frame_data in self._data.values():
            for item in frame_data:
                class_names.add(item["class_str"])
        class_list = sorted(class_names)

        # Name -> index lookup; avoids a linear list.index() per detection.
        class_index = {name: idx for idx, name in enumerate(class_list)}

        # Build the sorted, de-duplicated index list of every frame.
        frame_list = {}
        for frame_idx, frame_data in self._data.items():
            seen = {class_index[item["class_str"]] for item in frame_data}
            frame_list[str(frame_idx)] = sorted(seen)

        return {
            "class_list": class_list,
            "frame_list": frame_list,
        }

    def data(self):
        """Return the stored per-frame detections (the ``_data`` mapping)."""
        return self._data
|
||
|
||
def extract_frames_to_video(video_path, output_dir, interval=1):
    """Copy every ``interval``-th frame of a video into a new MP4 file.

    Frames are written unmodified (no image processing). The output frame
    rate is the source frame rate divided by the interval, so the new video
    plays at the original speed.

    Args:
        video_path: Path of the source video file.
        output_dir: Directory that receives ``output_video.mp4``; created if
            it does not exist.
        interval: Sampling interval, e.g. 6 keeps one frame and skips the next
            five. Values below 1 are treated as 1.

    Returns:
        The path of the written video file.

    Raises:
        Exception: If the source video cannot be opened with any backend.
        RuntimeError: If no video writer can be created for the output file.
    """
    print(f"开始抽帧: {video_path}")
    print(f"抽帧间隔: {interval}")
    print(f"输出目录: {output_dir}")
    # Make sure the output directory exists.
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    # Clamp once and use the same value for the frame rate and the sampling
    # test below; an interval of 0 would otherwise divide by zero in the
    # modulo.
    step = max(interval, 1)

    # 1. Open the source video, trying the backends in order.
    backends = [
        (cv2.CAP_FFMPEG, 'FFmpeg'),
        (cv2.CAP_DSHOW, 'DirectShow'),
        (cv2.CAP_ANY, 'Default'),
    ]

    cap = None
    for backend, backend_name in backends:
        try:
            candidate = cv2.VideoCapture(video_path, backend)
        except Exception as e:
            print(f"尝试{backend_name}后端失败: {e}")
            continue

        if not candidate.isOpened():
            # Release the failed capture before trying the next backend.
            candidate.release()
            continue

        cap = candidate
        if backend == cv2.CAP_FFMPEG:
            # NOTE(review): OpenCV documents CAP_PROP_HW_ACCELERATION as an
            # open-time parameter, so setting it on an already opened capture
            # may have no effect - confirm against the installed version.
            try:
                cap.set(cv2.CAP_PROP_HW_ACCELERATION, cv2.VIDEO_ACCELERATION_ANY)
                hw_accel = cap.get(cv2.CAP_PROP_HW_ACCELERATION)
                print(f"FFmpeg硬件加速: {'已启用' if hw_accel > 0 else '未启用'}")
            except Exception as e:
                print(f"设置硬件加速失败: {e}")
        print(f"使用后端: {backend_name}")
        break

    if cap is None:
        raise Exception(f"无法打开视频文件: {video_path}")

    output_path = os.path.join(output_dir, "output_video.mp4")
    out = None
    try:
        # 2. Read the properties of the source video.
        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # 3. Output frame rate = source frame rate / sampling interval.
        new_fps = fps / step

        # 4. Create the writer. A VideoWriter that cannot be created reports
        #    it through isOpened() rather than by raising, so check it
        #    explicitly and fall back from H.264 ('avc1') to 'mp4v'.
        for codec in ("avc1", "mp4v"):
            fourcc = cv2.VideoWriter_fourcc(*codec)  # type: ignore
            writer = cv2.VideoWriter(output_path, fourcc, new_fps, (width, height), isColor=True)
            if writer.isOpened():
                out = writer
                print("视频写入器初始化成功")
                break
            print(f"编码器 {codec} 不可用")
            writer.release()
        if out is None:
            raise RuntimeError(f"无法创建视频写入器: {output_path}")

        # 5. Copy the sampled frames.
        frame_idx = 0
        saved_frames = 0
        # Report progress every 10% of the video, but at most every 100 frames.
        print_interval = max(100, total_frames // 10)

        while cap.isOpened():
            if frame_idx % step == 0:
                # Frame to keep: decode it and write it out.
                ret, frame = cap.read()
                if not ret:
                    break
                out.write(frame)
                saved_frames += 1
            elif not cap.grab():
                # Frame to skip: grab() advances without retrieving the image.
                break

            frame_idx += 1

            if frame_idx % print_interval == 0:
                if total_frames > 0:
                    progress = (frame_idx / total_frames) * 100
                    print(f"已处理帧 {frame_idx}/{total_frames} ({progress:.1f}%)")
                else:
                    # The frame count is not known; a percentage would divide
                    # by zero.
                    print(f"已处理帧 {frame_idx}")
    finally:
        # 6. Release the resources even if reading or writing failed.
        cap.release()
        if out is not None:
            out.release()
        cv2.destroyAllWindows()

    print(f"抽帧完成: 原视频 {total_frames} 帧, 抽取后 {saved_frames} 帧")
    print(f"新视频已保存至: {output_path}")
    print(f"新视频帧率: {new_fps:.2f} fps (原始帧率: {fps:.2f} fps)")

    return output_path