324 lines
10 KiB
Python
324 lines
10 KiB
Python
"""
|
|
测试 ASR 核心功能
|
|
"""
|
|
import pytest
|
|
from pathlib import Path
|
|
import tempfile
|
|
import shutil
|
|
import sys
|
|
import os
|
|
|
|
# Put the repository root (the parent of this test file's directory) first on
# sys.path and make it the working directory, so the `app.*` / `lib.*`
# packages imported inside the tests resolve regardless of where pytest is
# launched from.
# NOTE(review): chdir at import time affects every test collected in the same
# session; presumably code under test relies on cwd-relative paths — confirm.
_tests_dir = Path(__file__).parent
project_root = _tests_dir.parent.absolute()
sys.path.insert(0, os.fspath(project_root))
os.chdir(project_root)
|
|
|
|
|
|
class TestASRCoreFunctions:
|
|
"""测试 ASR 核心函数"""
|
|
|
|
def test_get_video_list_function_exists(self):
|
|
"""测试获取视频列表函数存在"""
|
|
from app.asr.core import get_video_list
|
|
assert callable(get_video_list)
|
|
|
|
def test_get_video_list_empty_folder(self, temp_dirs):
|
|
"""测试空文件夹获取视频列表"""
|
|
from app.asr.core import get_video_list
|
|
videos = get_video_list(temp_dirs['input'])
|
|
assert isinstance(videos, list)
|
|
assert len(videos) == 0
|
|
|
|
def test_get_video_list_with_video_files(self, temp_dirs):
|
|
"""测试有视频文件时获取列表"""
|
|
from app.asr.core import get_video_list
|
|
import tempfile
|
|
from pathlib import Path
|
|
|
|
# 创建独立的临时目录
|
|
test_dir = Path(tempfile.mkdtemp())
|
|
(test_dir / 'test1.mp4').touch()
|
|
(test_dir / 'test2.mp4').touch()
|
|
(test_dir / 'test3.avi').touch()
|
|
|
|
videos = get_video_list(test_dir)
|
|
assert len(videos) == 3
|
|
assert all(isinstance(v, Path) for v in videos)
|
|
|
|
# 清理
|
|
import shutil
|
|
shutil.rmtree(test_dir)
|
|
|
|
def test_get_video_list_sorting(self, temp_dirs):
|
|
"""测试视频列表排序"""
|
|
from app.asr.core import get_video_list
|
|
import tempfile
|
|
from pathlib import Path
|
|
|
|
# 创建独立的临时目录
|
|
test_dir = Path(tempfile.mkdtemp())
|
|
|
|
# 创建带时间戳的文件名
|
|
(test_dir / 'VID_20251031_132320.mp4').touch()
|
|
(test_dir / 'VID_20251031_132330.mp4').touch()
|
|
(test_dir / 'VID_20251031_132340.mp4').touch()
|
|
|
|
videos = get_video_list(test_dir)
|
|
assert len(videos) == 3
|
|
# 验证按文件名排序
|
|
assert '132320' in str(videos[0])
|
|
assert '132330' in str(videos[1])
|
|
assert '132340' in str(videos[2])
|
|
|
|
# 清理
|
|
import shutil
|
|
shutil.rmtree(test_dir)
|
|
|
|
def test_get_video_list_supported_formats(self, temp_dirs):
|
|
"""测试支持的视频格式"""
|
|
from app.asr.core import get_video_list
|
|
|
|
formats = ['mp4', 'avi', 'mkv', 'mov', 'flv', 'wmv', 'm4v']
|
|
for fmt in formats:
|
|
(temp_dirs['input'] / f'test.{fmt}').touch()
|
|
|
|
videos = get_video_list(temp_dirs['input'])
|
|
assert len(videos) == len(formats)
|
|
|
|
|
|
class TestTempDirectory:
    """Tests for the temp/output directory helpers in ``app.asr.core``."""

    def test_clear_temp_dir_function_exists(self):
        """clear_temp_dir can be imported and is callable."""
        from app.asr.core import clear_temp_dir

        assert callable(clear_temp_dir)

    def test_ensure_output_dir_function_exists(self):
        """ensure_output_dir can be imported and is callable."""
        from app.asr.core import ensure_output_dir

        assert callable(ensure_output_dir)

    def test_clear_temp_dir_creates_directory(self, temp_dirs):
        """clear_temp_dir (re)creates the configured TEMP_DIR."""
        from app.asr.core import clear_temp_dir
        import app.asr.core as core_module

        target = temp_dirs['temp']
        # Start from a missing directory: the fixture may already have
        # created it, which would make the existence check vacuous.
        shutil.rmtree(target, ignore_errors=True)

        # The test assumes clear_temp_dir reads the module-level TEMP_DIR at
        # call time, so temporarily point it at the fixture directory. Restore
        # it in `finally` so a failure here cannot leak the patched value into
        # later tests.
        original_temp = core_module.TEMP_DIR
        core_module.TEMP_DIR = target
        try:
            clear_temp_dir()
            assert target.exists()
        finally:
            core_module.TEMP_DIR = original_temp

    def test_ensure_output_dir_creates_directory(self, temp_dirs):
        """ensure_output_dir creates the configured OUTPUT_DIR."""
        from app.asr.core import ensure_output_dir
        import app.asr.core as core_module

        target = temp_dirs['output']
        # Start from a missing directory: the fixture may already have
        # created it, which would make the existence check vacuous.
        shutil.rmtree(target, ignore_errors=True)

        # Same module-global patching as above, always restored.
        original_output = core_module.OUTPUT_DIR
        core_module.OUTPUT_DIR = target
        try:
            ensure_output_dir()
            assert target.exists()
        finally:
            core_module.OUTPUT_DIR = original_output
|
|
|
|
|
|
class TestExtractWav:
    """Tests for WAV extraction via ``app.asr.core.extract_wav``."""

    def test_extract_wav_function_exists(self):
        """extract_wav can be imported and is callable."""
        from app.asr.core import extract_wav

        assert callable(extract_wav)

    def test_extract_wav_with_nonexistent_video(self, temp_dirs):
        """A video path that does not exist makes extract_wav return None."""
        from app.asr.core import extract_wav

        missing_video = temp_dirs['input'] / 'nonexistent.mp4'
        assert extract_wav(missing_video, temp_dirs['temp']) is None
|
|
|
|
|
|
class TestASRService:
    """Tests for ``app.asr.asr_service``: the ASRService class and Sentence."""

    # Field values used to build the sample Sentence and to check it.
    _SAMPLE_FIELDS = {
        'speaker': 'SPK1',
        'text': '测试文本',
        'begin_time': 0.0,
        'end_time': 1.0,
    }

    def _sample_sentence(self):
        """Return a Sentence built from ``_SAMPLE_FIELDS``."""
        from app.asr.asr_service import Sentence

        return Sentence(**self._SAMPLE_FIELDS)

    def test_asr_service_class_exists(self):
        """ASRService can be imported."""
        from app.asr.asr_service import ASRService

        assert ASRService is not None

    def test_asr_service_initialization(self):
        """A default ASRService uses the 'paraformer-zh' model."""
        from app.asr.asr_service import ASRService

        default_service = ASRService()
        assert default_service is not None
        assert default_service.model_name == 'paraformer-zh'

    def test_asr_service_custom_model(self):
        """An explicit model_name is stored on the service."""
        from app.asr.asr_service import ASRService

        custom = ASRService(model_name='SenseVoice')
        assert custom.model_name == 'SenseVoice'

    def test_asr_service_device_auto(self):
        """device='auto' resolves to either 'cpu' or 'cuda'."""
        from app.asr.asr_service import ASRService

        auto_service = ASRService(device='auto')
        assert auto_service.device in ('cpu', 'cuda')

    def test_asr_service_sentence_class(self):
        """Sentence keeps the values it was constructed with."""
        sentence = self._sample_sentence()
        for field, expected in self._SAMPLE_FIELDS.items():
            assert getattr(sentence, field) == expected

    def test_sentence_to_dict(self):
        """Sentence.to_dict exposes the constructor fields plus 'duration'."""
        as_dict = self._sample_sentence().to_dict()
        for key, expected in self._SAMPLE_FIELDS.items():
            assert as_dict[key] == expected
        assert 'duration' in as_dict
|
|
|
|
|
|
class TestDiarizationService:
    """Tests for ``app.asr.diarization_service``: the service and its segments."""

    # Field values used to build the sample DiarizationSegment and to check it.
    _SEGMENT_FIELDS = {
        'speaker': 'SPK1',
        'begin_time': 0.0,
        'end_time': 1.0,
    }

    def _sample_segment(self):
        """Return a DiarizationSegment built from ``_SEGMENT_FIELDS``."""
        from app.asr.diarization_service import DiarizationSegment

        return DiarizationSegment(**self._SEGMENT_FIELDS)

    def test_diarization_service_class_exists(self):
        """DiarizationService can be imported."""
        from app.asr.diarization_service import DiarizationService

        assert DiarizationService is not None

    def test_diarization_service_initialization(self):
        """A default DiarizationService uses the 'eres2net' embedding model."""
        from app.asr.diarization_service import DiarizationService

        default_service = DiarizationService()
        assert default_service is not None
        assert default_service.embedding_model == 'eres2net'

    def test_diarization_service_custom_model(self):
        """An explicit embedding_model is stored on the service."""
        from app.asr.diarization_service import DiarizationService

        custom = DiarizationService(embedding_model='campplus')
        assert custom.embedding_model == 'campplus'

    def test_diarization_segment_class(self):
        """DiarizationSegment keeps the values it was constructed with."""
        segment = self._sample_segment()
        for field, expected in self._SEGMENT_FIELDS.items():
            assert getattr(segment, field) == expected

    def test_diarization_segment_to_dict(self):
        """DiarizationSegment.to_dict exposes the fields plus 'duration'."""
        as_dict = self._sample_segment().to_dict()
        for key, expected in self._SEGMENT_FIELDS.items():
            assert as_dict[key] == expected
        assert 'duration' in as_dict
|
|
|
|
|
|
class TestMapSpeaker:
    """Tests for the speaker-mapping component of ``app.asr``."""

    def test_map_speaker_module_import(self):
        """``map_speaker`` can be imported from the ``app.asr`` package."""
        from app.asr import map_speaker

        imported = map_speaker
        assert imported is not None
|
|
|
|
|
|
class TestTranscodeCore:
    """Tests for H.264 transcoding via ``app.transcode.core.convert_to_h264``."""

    def test_convert_to_h264_function_exists(self):
        """convert_to_h264 can be imported and is callable."""
        from app.transcode.core import convert_to_h264

        assert callable(convert_to_h264)

    def test_convert_to_h264_file_not_found(self, temp_dirs):
        """A missing input file raises FileNotFoundError."""
        from app.transcode.core import convert_to_h264

        with pytest.raises(FileNotFoundError):
            convert_to_h264(
                input_root=temp_dirs['input'],
                vid_full_name='nonexistent.mp4',
                output_root=temp_dirs['output'],
            )

    def test_convert_to_h264_output_naming(self, temp_dirs):
        """The returned output path contains ``test_video_h264.mp4``.

        The input is an empty placeholder file, so the conversion itself may
        fail (e.g. because of ffmpeg). That case is reported as a skip. Only
        the conversion call is guarded, so a wrong output name still fails the
        test instead of being swallowed together with the conversion error.
        """
        from app.transcode.core import convert_to_h264

        # Create a fake (empty) input file.
        (temp_dirs['input'] / 'test_video.mp4').touch()

        try:
            output = convert_to_h264(
                input_root=temp_dirs['input'],
                vid_full_name='test_video.mp4',
                output_root=temp_dirs['output'],
            )
        except Exception as exc:  # failure type depends on ffmpeg/backend
            pytest.skip(
                f'convert_to_h264 could not run on the placeholder input: {exc!r}')

        # str() so the check works whether a str or a Path is returned.
        assert 'test_video_h264.mp4' in str(output)
|
|
|
|
|
|
class TestCaddyRun:
    """Tests for ``lib.caddy.run.run_caddy``."""

    def test_run_caddy_function_exists(self):
        """run_caddy can be imported and is callable."""
        from lib.caddy.run import run_caddy

        assert callable(run_caddy)

    def test_run_caddy_default_port(self):
        """run_caddy has a ``port`` parameter whose default is 8086."""
        import inspect
        from lib.caddy.run import run_caddy

        # Inspect the signature rather than calling the function.
        params = inspect.signature(run_caddy).parameters
        assert 'port' in params
        assert params['port'].default == 8086
|
|
|
|
|
|
if __name__ == '__main__':
    # Run this file's tests verbosely and propagate pytest's exit status, so a
    # failing run is visible to the calling shell / CI.
    sys.exit(pytest.main([__file__, '-v']))
|