HazardInspector/lib/sam3 copy.py

303 lines
10 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
from tkinter import N
from ultralytics.models.sam import SAM3VideoSemanticPredictor
import cv2
import json
import numpy as np
from pathlib import Path
class SAM3:
    """Run SAM3 text-prompted detection on a video and keep the per-frame results.

    The results of the most recent ``run()`` or ``load_from_json()`` call are
    cached on the instance and can be summarised with ``get_class_dict()``.
    """

    def __init__(self):
        # Per-frame detections from the last run()/load_from_json() call.
        self._data: dict = {}

    def run(self, vid_path: str, output_dir: str, class_file_path: str,
            conf: float = 0.6, model_path: str = "./model/sam3.pt") -> dict:
        """Run the SAM3 model on a video and return the detections of every frame.

        Side effects: writes ``{output_dir}/source/frame_XXXX.jpg`` (original
        frame), ``{output_dir}/boxes/frame_XXXX.jpg`` (annotated frame) and
        ``{output_dir}/frame_all.json`` (all detections).

        Args:
            vid_path: Path of the input video.
            output_dir: Directory that receives the images and the JSON file.
            class_file_path: UTF-8 JSON file holding one object. Its keys are
                passed to the model as text prompts; its values are used as
                ``class_str``, looked up by position via the predicted class id.
            conf: Confidence threshold forwarded to the predictor.
            model_path: Path of the SAM3 weights file.

        Returns:
            A dict mapping the frame index to a list of detections::

                {
                    0: [  # frame index
                        {
                            "xyxy": [x1, y1, x2, y2],   # box coordinates
                            "confidence": 0.9,          # confidence score
                            "track_id": 123,            # tracking id, -1 if unavailable
                            "class_id": 0,              # class id
                            "class_str": "class_name"   # class name
                        },
                        ...
                    ],
                }

            The returned dict uses ``int`` keys; the copy written to
            ``frame_all.json`` has string keys because JSON object keys are
            always strings.
        """
        self._data = {}

        # Use a context manager so the class file is closed deterministically.
        with open(class_file_path, "r", encoding="utf-8") as class_file:
            text_dict = json.load(class_file)
        text_list_en = list(text_dict.keys())
        text_list_cn = list(text_dict.values())

        overrides = dict(conf=conf, task="segment", mode="predict", imgsz=640,
                         model=model_path, half=True, save=True, device=0)
        predictor = SAM3VideoSemanticPredictor(overrides=overrides)
        results = predictor(
            source=vid_path,
            text=text_list_en,
            stream=True,
        )

        # Create the output directories once instead of once per frame. This
        # also guarantees output_dir exists for frame_all.json when the video
        # yields no frames at all.
        Path(f"{output_dir}/source").mkdir(parents=True, exist_ok=True)
        Path(f"{output_dir}/boxes").mkdir(parents=True, exist_ok=True)

        frame_idx = 0
        annotated_frames: dict = {}
        for r in results:
            frame_data = []
            detections = r.boxes
            if detections is not None:
                boxes = detections.xyxy.cpu().numpy()
                confs = detections.conf.cpu().numpy()
                cls_ids = detections.cls.cpu().numpy()
                # The id tensor is absent when no track ids were assigned;
                # fall back to -1 instead of crashing on None.
                if detections.id is not None:
                    track_ids = detections.id.cpu().numpy()
                else:
                    track_ids = [-1] * len(boxes)
                for box, score, cls_id, track_id in zip(boxes, confs, cls_ids, track_ids):
                    class_id = int(cls_id)
                    frame_data.append({
                        "xyxy": box.tolist(),
                        "confidence": float(score),
                        "track_id": int(track_id),
                        "class_id": class_id,
                        "class_str": text_list_cn[class_id],
                    })
            annotated_frames[frame_idx] = frame_data

            # Save the original frame.
            source_path = f"{output_dir}/source/frame_{frame_idx:04d}.jpg"
            cv2.imwrite(source_path, r.orig_img)

            # Save the annotated frame.
            mask_path = f"{output_dir}/boxes/frame_{frame_idx:04d}.jpg"
            cv2.imwrite(mask_path, r.plot())

            print(f"Frame {frame_idx}: {len(frame_data)} boxes")
            frame_idx += 1

        with open(f"{output_dir}/frame_all.json", "w", encoding="utf-8") as f:
            json.dump(annotated_frames, f, ensure_ascii=False, indent=2)
        print(f"\n总共处理 {frame_idx}")
        print(f"结果已保存到 {output_dir}/frame_all.json")
        print(f"标记图片已保存到 {output_dir}/boxes/")
        cv2.destroyAllWindows()
        self._data = annotated_frames
        return annotated_frames

    def load_from_json(self, json_path: str) -> dict:
        """Load detection results from a JSON file and cache them on the instance.

        Args:
            json_path: Path of a UTF-8 JSON file, e.g. the ``frame_all.json``
                written by ``run()``.

        Returns:
            The parsed content. For a file written by ``run()`` this is a dict
            mapping the frame index (as a string, e.g. ``"0"``) to a list of
            detection dicts with the keys ``xyxy``, ``confidence``,
            ``track_id``, ``class_id`` and ``class_str``.
        """
        with open(json_path, "r", encoding="utf-8") as f:
            self._data = json.load(f)
        return self._data

    def get_class_dict(self):
        """Reduce the cached results to the set of classes present in each frame.

        Returns:
            A dict of the form::

                {
                    "class_list": ["class_name1", "class_name2", ...],
                    "frame_list": {
                        "0": [0, 1, ...],   # frame index -> class ids
                    },
                }

            ``class_list`` holds every distinct ``class_str`` in sorted order.
            The ids in ``frame_list`` are positions in ``class_list`` (sorted,
            de-duplicated); they are not the model's ``class_id`` values.

        Raises:
            ValueError: If no results are cached yet (the cache is empty).
        """
        # _data is initialised to {} and never to None, so test for emptiness.
        if not self._data:
            raise ValueError("请先调用 run() 方法运行模型")

        # Collect the distinct class names.
        class_names = set()
        for frame_data in self._data.values():
            for item in frame_data:
                class_names.add(item["class_str"])
        class_list = sorted(class_names)

        # Name -> position lookup, so each detection costs O(1) instead of a
        # linear list.index() scan.
        class_index = {name: idx for idx, name in enumerate(class_list)}

        # Build the per-frame list of class ids.
        frame_list = {}
        for frame_idx, frame_data in self._data.items():
            class_ids = {class_index[item["class_str"]] for item in frame_data}
            frame_list[str(frame_idx)] = sorted(class_ids)

        return {
            "class_list": class_list,
            "frame_list": frame_list
        }

    def data(self):
        """Return the cached per-frame detection dict (empty if nothing was loaded)."""
        return self._data
def _open_video_capture(video_path):
    """Open *video_path* with the first OpenCV backend that accepts it, else return None."""
    backends = [
        (cv2.CAP_FFMPEG, 'FFmpeg'),
        (cv2.CAP_DSHOW, 'DirectShow'),
        (cv2.CAP_ANY, 'Default'),
    ]
    for backend, backend_name in backends:
        try:
            cap = cv2.VideoCapture(video_path, backend)
        except Exception as e:
            print(f"尝试{backend_name}后端失败: {e}")
            continue
        if not cap.isOpened():
            # Release the failed handle before trying the next backend.
            cap.release()
            continue
        if backend == cv2.CAP_FFMPEG:
            # Best-effort request for hardware decoding.
            # NOTE(review): OpenCV documents CAP_PROP_HW_ACCELERATION as an
            # open-only property, so setting it after open may have no effect
            # -- confirm against the OpenCV version in use.
            try:
                cap.set(cv2.CAP_PROP_HW_ACCELERATION, cv2.VIDEO_ACCELERATION_ANY)
                hw_accel = cap.get(cv2.CAP_PROP_HW_ACCELERATION)
                print(f"FFmpeg硬件加速: {'已启用' if hw_accel > 0 else '未启用'}")
            except Exception as e:
                print(f"设置硬件加速失败: {e}")
        print(f"使用后端: {backend_name}")
        return cap
    return None


def _open_video_writer(output_path, fps, frame_size):
    """Return an opened cv2.VideoWriter, preferring H.264 ('avc1') over 'mp4v', else None."""
    for codec in ('avc1', 'mp4v'):
        fourcc = cv2.VideoWriter_fourcc(*codec)  # type: ignore
        writer = cv2.VideoWriter(output_path, fourcc, fps, frame_size, isColor=True)
        # VideoWriter reports an unavailable codec through isOpened(), not by
        # raising, so the result has to be checked explicitly.
        if writer.isOpened():
            print("视频写入器初始化成功")
            return writer
        writer.release()
        print(f"编码器 {codec} 不可用")
    return None


def extract_frames_to_video(video_path, output_dir, interval=1):
    """Build a new video from every ``interval``-th frame (no image processing).

    The new video is written to ``<output_dir>/output_video.mp4`` with a frame
    rate of ``original_fps / interval``.

    Args:
        video_path: Path of the source video file.
        output_dir: Output directory; created if it does not exist.
        interval: Sampling step, e.g. 6 keeps one frame out of every six
            (five frames are skipped in between). Values below 1 are treated
            as 1.

    Returns:
        The path of the written video.

    Raises:
        RuntimeError: If the source video cannot be opened or no video writer
            can be created for the output file.
    """
    # Clamp once so the modulo below cannot divide by zero.
    interval = max(interval, 1)

    print(f"开始抽帧: {video_path}")
    print(f"抽帧间隔: {interval}")
    print(f"输出目录: {output_dir}")

    # Make sure the output directory exists.
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    # 1. Open the source video.
    cap = _open_video_capture(video_path)
    if cap is None:
        raise RuntimeError(f"无法打开视频文件: {video_path}")

    out = None
    try:
        # 2. Read the video parameters.
        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # 3. New frame rate = original frame rate / sampling step.
        new_fps = fps / interval

        # 4. Create the video writer.
        output_path = os.path.join(output_dir, "output_video.mp4")
        out = _open_video_writer(output_path, new_fps, (width, height))
        if out is None:
            raise RuntimeError(f"无法创建视频写入器: {output_path}")

        # 5. Sample frames and write them to the new video.
        frame_idx = 0
        saved_frames = 0
        # Report progress every 100 frames or every 10 %, whichever is larger.
        print_interval = max(100, total_frames // 10)
        while cap.isOpened():
            if frame_idx % interval == 0:
                # Kept frame: decode and write it.
                ret, frame = cap.read()
                if not ret:
                    break
                out.write(frame)
                saved_frames += 1
            elif not cap.grab():
                # Skipped frame: grab() advances without retrieving the image.
                break
            frame_idx += 1
            if frame_idx % print_interval == 0:
                if total_frames > 0:
                    progress = (frame_idx / total_frames) * 100
                    print(f"已处理帧 {frame_idx}/{total_frames} ({progress:.1f}%)")
                else:
                    # The reported frame count can be 0; skip the percentage
                    # rather than divide by zero.
                    print(f"已处理帧 {frame_idx}")
    finally:
        # 6. Release resources, also when an error occurred.
        cap.release()
        if out is not None:
            out.release()

    cv2.destroyAllWindows()
    print(f"抽帧完成: 原视频 {total_frames} 帧, 抽取后 {saved_frames}")
    print(f"新视频已保存至: {output_path}")
    print(f"新视频帧率: {new_fps:.2f} fps (原始帧率: {fps:.2f} fps)")
    return output_path
if __name__ == "__main__":
    # Script entry point: run SAM3 on one fixed input video with the
    # class/prompt file that sits next to the script.
    VID_NAME: str = "Peihdianhxiang"
    VID_END: str = ".mp4"
    detector: SAM3 = SAM3()
    detector.run(
        vid_path=f"./input/{VID_NAME}{VID_END}",
        output_dir=f"output/{VID_NAME}",
        class_file_path="1.厂房防火.json",
        conf=0.6,
    )