222 lines
7.7 KiB
Python
222 lines
7.7 KiB
Python
"""
Tests for the ASR routes.
"""
|
||
import pytest
|
||
import json
|
||
from pathlib import Path
|
||
import tempfile
|
||
import os
|
||
import sys
|
||
|
||
# Absolute path of the directory two levels above this file
# (presumably the project root -- confirm against the repository layout).
project_root = Path(__file__).parent.parent.absolute()
# Put that directory first on the import search path; this has to happen
# before the `main` import below.
sys.path.insert(0, str(project_root))
# Also make it the current working directory for the rest of the test run.
os.chdir(project_root)

from main import create_app
|
||
|
||
|
||
class TestRecognizeRoute:
    """Tests for the speech-recognition route (``/api/recognize``)."""

    @staticmethod
    def _make_app(**config):
        """Return an app from ``create_app()`` in testing mode.

        Any keyword arguments are written into ``app.config`` after the
        ``TESTING`` flag has been set.
        """
        app = create_app()
        app.config['TESTING'] = True
        app.config.update(config)
        return app

    def test_recognize_missing_path_parameter(self):
        """A request without the ``path`` parameter is rejected with 400."""
        with self._make_app().test_client() as client:
            response = client.get('/api/recognize')
            assert response.status_code == 400
            body = response.get_json()
            assert body['status'] == 'error'
            assert 'path' in body['message']

    def test_recognize_with_empty_path(self):
        """An empty ``path`` parameter is rejected with 400."""
        with self._make_app().test_client() as client:
            response = client.get('/api/recognize?path=')
            assert response.status_code == 400
            body = response.get_json()
            assert body['status'] == 'error'

    def test_recognize_with_valid_path(self, temp_dirs):
        """A well-formed ``path`` is handled even though the file is absent."""
        app = self._make_app(
            INPUT_DIR=str(temp_dirs['input']),
            OUTPUT_DIR=str(temp_dirs['output']),
        )

        # The referenced file does not actually exist, so an error response is
        # acceptable; this only checks that the route handles the request.
        with app.test_client() as client:
            response = client.get('/api/recognize?path=test.mp4')
            # Accept success (200) as well as client/server errors (400, 500).
            assert response.status_code in [200, 400, 500]

    def test_recognize_response_format(self, temp_dirs):
        """A successful response carries ``status``, ``message`` and ``data``."""
        app = self._make_app(
            INPUT_DIR=str(temp_dirs['input']),
            OUTPUT_DIR=str(temp_dirs['output']),
        )

        with app.test_client() as client:
            response = client.get('/api/recognize?path=test.mp4')
            # The format is only checked when the request succeeded.
            if response.status_code == 200:
                body = response.get_json()
                assert 'status' in body
                assert 'message' in body
                assert 'data' in body
                # The shape of 'data' is not fixed by this test: only look for
                # task_id when it is a dict, so a null or non-mapping value
                # does not blow up with a TypeError on the membership test.
                payload = body['data']
                if isinstance(payload, dict) and 'task_id' in payload:
                    assert isinstance(payload['task_id'], str)
|
||
|
||
|
||
class TestResultRoute:
    """Tests for the result-retrieval route (``/api/result``)."""

    @staticmethod
    def _make_app(**config):
        """Return an app from ``create_app()`` in testing mode.

        Keyword arguments are written into ``app.config`` after ``TESTING``.
        """
        app = create_app()
        app.config['TESTING'] = True
        app.config.update(config)
        return app

    def test_result_missing_path_parameter(self):
        """A request without the ``path`` parameter is rejected with 400."""
        with self._make_app().test_client() as client:
            response = client.get('/api/result')
            assert response.status_code == 400
            body = response.get_json()
            assert body['status'] == 'error'
            assert 'path' in body['message']

    def test_result_with_empty_path(self):
        """An empty ``path`` parameter is rejected with 400."""
        with self._make_app().test_client() as client:
            response = client.get('/api/result?path=')
            assert response.status_code == 400
            body = response.get_json()
            assert body['status'] == 'error'

    def test_result_nonexistent_file(self, temp_dirs):
        """Asking for a result that has no file on disk yields 404."""
        app = self._make_app(OUTPUT_DIR=str(temp_dirs['output']))

        with app.test_client() as client:
            response = client.get('/api/result?path=nonexistent.mp4')
            assert response.status_code == 404
            body = response.get_json()
            assert body['status'] == 'error'

    def test_result_with_valid_file(self, temp_dirs):
        """An existing result file is returned with a success status."""
        app = self._make_app(OUTPUT_DIR=str(temp_dirs['output']))

        # Write a mock result file under <output>/SpeechRecognition.
        sentence = {
            'speaker': 'SPK1',
            'text': '测试文本',
            'begin_time': 0.0,
            'end_time': 1.0,
        }
        payload = {'status': 'success', 'data': {'sentences': [sentence]}}
        result_dir = Path(temp_dirs['output']) / 'SpeechRecognition'
        result_dir.mkdir(parents=True, exist_ok=True)
        (result_dir / 'test_result.json').write_text(
            json.dumps(payload, ensure_ascii=False),
            encoding='utf-8',
        )

        with app.test_client() as client:
            response = client.get('/api/result?path=test.mp4')
            assert response.status_code == 200
            body = response.get_json()
            assert body['status'] == 'success'
            assert 'data' in body
|
||
|
||
|
||
class TestASRGlobalState:
    """Tests for the module-level state in ``app.asr.routes``."""

    def test_task_running_dict_exists(self):
        """``task_running`` is importable and is a dict."""
        from app.asr.routes import task_running as running_tasks

        assert running_tasks is not None
        assert isinstance(running_tasks, dict)

    def test_global_asr_service_variable(self):
        """``GLOBAL_ASR_SERVICE`` is importable; if set, it has ``recognize``."""
        from app.asr.routes import GLOBAL_ASR_SERVICE as asr_service

        # The variable may still be None initially.
        if asr_service is not None:
            assert hasattr(asr_service, 'recognize')

    def test_global_diar_service_variable(self):
        """``GLOBAL_DIAR_SERVICE`` is importable; if set, it has ``diarize``."""
        from app.asr.routes import GLOBAL_DIAR_SERVICE as diar_service

        # The variable may still be None initially.
        if diar_service is not None:
            assert hasattr(diar_service, 'diarize')
|
||
|
||
|
||
class TestASRServiceIntegration:
    """Integration checks for the ASR and diarization service classes."""

    def test_asr_service_import(self):
        """``ASRService`` can be imported."""
        from app.asr.asr_service import ASRService as service_cls

        assert service_cls is not None

    def test_diarization_service_import(self):
        """``DiarizationService`` can be imported."""
        from app.asr.diarization_service import DiarizationService as service_cls

        assert service_cls is not None

    def test_asr_service_initialization(self):
        """A default ``ASRService`` reports the ``paraformer-zh`` model name."""
        from app.asr.asr_service import ASRService

        asr = ASRService()
        assert asr is not None
        assert asr.model_name == 'paraformer-zh'

    def test_diarization_service_initialization(self):
        """A default ``DiarizationService`` reports the ``eres2net`` embedding model."""
        from app.asr.diarization_service import DiarizationService

        diarizer = DiarizationService()
        assert diarizer is not None
        assert diarizer.embedding_model == 'eres2net'
|
||
|
||
|
||
class TestASRCore:
    """Tests for the helpers in ``app.asr.core``."""

    def test_core_module_import(self):
        """The core module can be imported."""
        from app.asr import core as core_module

        assert core_module is not None

    def test_get_video_list_function(self):
        """``get_video_list`` is exported by the core module."""
        from app.asr.core import get_video_list as list_videos

        assert list_videos is not None

    def test_extract_wav_function(self):
        """``extract_wav`` is exported by the core module."""
        from app.asr.core import extract_wav as wav_extractor

        assert wav_extractor is not None

    def test_get_video_list_with_empty_folder(self, temp_dirs):
        """Listing videos for the ``input`` temp directory gives an empty list."""
        from app.asr.core import get_video_list

        found = get_video_list(temp_dirs['input'])
        assert isinstance(found, list)
        assert not found
|
||
|
||
|
||
if __name__ == '__main__':
    # Run this file's tests verbosely and propagate pytest's exit status, so
    # executing the file directly exits non-zero when a test fails.
    raise SystemExit(pytest.main([__file__, '-v']))
|