# SpeechRecognition/server.py
#
# 320 lines
# 10 KiB
# Python

"""
Web API Server for ASR and Speaker Diarization
提供语音识别和说话人分离的 REST API 服务
"""
import os
import sys
import gc
from pathlib import Path
from flask import Flask, request, jsonify, send_file
from werkzeug.utils import secure_filename
import threading
import uuid
# Flask application plus the folders used for uploaded audio and result files.
app = Flask(__name__)
app.config.update(
    MAX_CONTENT_LENGTH=500 * 1024 * 1024,  # 500 MiB request-size setting
    UPLOAD_FOLDER='uploads',
    RESULT_FOLDER='results',
)
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
os.makedirs(app.config['RESULT_FOLDER'], exist_ok=True)

# Lazily created service singletons (see get_asr_service / get_diar_service).
GLOBAL_ASR_SERVICE = GLOBAL_DIAR_SERVICE = None

# Flags flipped by the load/unload endpoints and reported by the status routes.
ASR_MODEL_LOADED = DIAR_MODEL_LOADED = False

# One lock per model, taken by the load endpoints while loading.
ASR_MODEL_LOCK, DIAR_MODEL_LOCK = threading.Lock(), threading.Lock()
def get_asr_service():
    """Return the shared ASRService instance, creating it on first use.

    The instance is cached in the module-level ``GLOBAL_ASR_SERVICE`` so
    later calls return the same object until that global is reset to None.
    """
    global GLOBAL_ASR_SERVICE
    if GLOBAL_ASR_SERVICE is not None:
        return GLOBAL_ASR_SERVICE
    # Imported here rather than at module top, so ``asr_service`` is only
    # imported once a service instance is actually requested.
    from asr_service import ASRService
    GLOBAL_ASR_SERVICE = ASRService()
    return GLOBAL_ASR_SERVICE
def get_diar_service():
    """Return the shared DiarizationService instance, creating it on first use.

    The instance is cached in the module-level ``GLOBAL_DIAR_SERVICE`` so
    later calls return the same object until that global is reset to None.
    """
    global GLOBAL_DIAR_SERVICE
    if GLOBAL_DIAR_SERVICE is not None:
        return GLOBAL_DIAR_SERVICE
    # Imported here rather than at module top, so ``diarization_service`` is
    # only imported once a service instance is actually requested.
    from diarization_service import DiarizationService
    GLOBAL_DIAR_SERVICE = DiarizationService()
    return GLOBAL_DIAR_SERVICE
@app.route('/health', methods=['GET'])
def health_check():
    """Health check: report a fixed 'ok' status and both model-loaded flags."""
    return jsonify(
        status='ok',
        asr_loaded=ASR_MODEL_LOADED,
        diar_loaded=DIAR_MODEL_LOADED,
    )
@app.route('/api/asr/load', methods=['GET', 'POST'])
def load_asr_model():
    """Load the ASR model unless it is already loaded.

    POST is accepted to match the startup banner and the diarization loader;
    GET is kept so existing callers of the original route keep working.

    Optional JSON body keys: ``model_name`` (default ``'paraformer-zh'``) and
    ``device`` (default ``'auto'``).

    Returns:
        JSON with ``message`` and ``loaded`` (plus ``model`` and ``device``
        after a fresh load), or ``{'error': ...}`` with HTTP 500 if loading
        raises.
    """
    global ASR_MODEL_LOADED
    with ASR_MODEL_LOCK:
        if ASR_MODEL_LOADED:
            return jsonify({'message': 'ASR 模型已加载', 'loaded': True})
        try:
            # silent=True: a missing or non-JSON body yields None instead of
            # raising, so a body-less request simply uses the defaults.
            data = request.get_json(silent=True)
            if not isinstance(data, dict):
                data = {}
            model_name = data.get('model_name', 'paraformer-zh')
            device = data.get('device', 'auto')
            print(f"正在加载 ASR 模型: {model_name}, 设备: {device}")
            # NOTE(review): model_name and device are only logged and echoed
            # back; they are not passed to the service here. Confirm whether
            # ASRService is expected to receive them.
            service = get_asr_service()
            service._load_model()
            ASR_MODEL_LOADED = True
            return jsonify({
                'message': 'ASR 模型加载成功',
                'loaded': True,
                'model': model_name,
                'device': device
            })
        except Exception as e:
            return jsonify({'error': str(e)}), 500
@app.route('/api/diar/load', methods=['POST'])
def load_diar_model():
    """Load the 3D-Speaker diarization model unless it is already loaded.

    Optional JSON body keys: ``embedding_model`` (default ``'eres2netv2'``)
    and ``device`` (default ``'auto'``).

    Returns:
        JSON with ``message`` and ``loaded`` (plus ``embedding_model`` and
        ``device`` after a fresh load), or ``{'error': ...}`` with HTTP 500
        if loading raises.
    """
    global DIAR_MODEL_LOADED
    with DIAR_MODEL_LOCK:
        if DIAR_MODEL_LOADED:
            return jsonify({'message': '3D-Speaker 模型已加载', 'loaded': True})
        try:
            # silent=True: a missing or non-JSON body yields None instead of
            # raising, so a body-less POST simply uses the defaults.
            data = request.get_json(silent=True)
            if not isinstance(data, dict):
                data = {}
            embedding_model = data.get('embedding_model', 'eres2netv2')
            device = data.get('device', 'auto')
            print(f"正在加载 3D-Speaker 模型: {embedding_model}, 设备: {device}")
            # NOTE(review): embedding_model and device are only logged and
            # echoed back; they are not passed to the service here. Confirm
            # whether DiarizationService is expected to receive them.
            service = get_diar_service()
            service._load_model()
            DIAR_MODEL_LOADED = True
            return jsonify({
                'message': '3D-Speaker 模型加载成功',
                'loaded': True,
                'embedding_model': embedding_model,
                'device': device
            })
        except Exception as e:
            return jsonify({'error': str(e)}), 500
@app.route('/api/asr/unload', methods=['POST'])
def unload_asr_model():
    """Unload the ASR model by dropping the shared service instance.

    Returns:
        JSON ``{'message': ...}`` on success, or ``{'error': ...}`` with
        HTTP 500 if anything raises.
    """
    global GLOBAL_ASR_SERVICE, ASR_MODEL_LOADED
    try:
        # Take the same lock the load endpoint holds, so the reference and
        # the loaded flag are reset together and never in the middle of a load.
        with ASR_MODEL_LOCK:
            GLOBAL_ASR_SERVICE = None
            ASR_MODEL_LOADED = False
        # Run a collection pass now that this module's reference is dropped.
        gc.collect()
        return jsonify({'message': 'ASR 模型已卸载'})
    except Exception as e:
        return jsonify({'error': str(e)}), 500
@app.route('/api/diar/unload', methods=['POST'])
def unload_diar_model():
    """Unload the 3D-Speaker model by dropping the shared service instance.

    Returns:
        JSON ``{'message': ...}`` on success, or ``{'error': ...}`` with
        HTTP 500 if anything raises.
    """
    global GLOBAL_DIAR_SERVICE, DIAR_MODEL_LOADED
    try:
        # Take the same lock the load endpoint holds, so the reference and
        # the loaded flag are reset together and never in the middle of a load.
        with DIAR_MODEL_LOCK:
            GLOBAL_DIAR_SERVICE = None
            DIAR_MODEL_LOADED = False
        # Run a collection pass now that this module's reference is dropped.
        gc.collect()
        return jsonify({'message': '3D-Speaker 模型已卸载'})
    except Exception as e:
        return jsonify({'error': str(e)}), 500
@app.route('/api/recognize/single', methods=['POST'])
def recognize_single():
    """Run recognition on a single uploaded audio file.

    Multipart form fields:
        file: the audio file (required).
        use_3d_speaker: ``'true'`` enables it; anything else is false.
        embedding_model: default ``'eres2netv2'``.
        cluster_threshold: float, default 0.5.
        min_cluster_size: int, default 10.
        format: ``'json'`` (default) returns JSON; any other value returns
            an SRT file as an attachment.

    Returns:
        The recognition result (JSON or SRT attachment); HTTP 400 for a
        missing file, empty filename or malformed numeric field; HTTP 500
        with ``{'error': ...}`` for any other failure.
    """
    global ASR_MODEL_LOADED
    import json
    import traceback

    try:
        if 'file' not in request.files:
            return jsonify({'error': '请上传音频文件'}), 400
        upload = request.files['file']
        if upload.filename == '':
            return jsonify({'error': '文件名不能为空'}), 400

        data = request.form or {}
        use_3d_speaker = data.get('use_3d_speaker', 'false').lower() == 'true'
        embedding_model = data.get('embedding_model', 'eres2netv2')
        try:
            cluster_threshold = float(data.get('cluster_threshold', 0.5))
            min_cluster_size = int(data.get('min_cluster_size', 10))
        except (TypeError, ValueError):
            # A malformed number is a client error, not a server failure.
            return jsonify({'error': 'cluster_threshold 或 min_cluster_size 参数格式错误'}), 400
        output_format = data.get('format', 'json')

        filename = secure_filename(upload.filename)
        # Short random prefix keeps the upload and result file names apart.
        task_id = str(uuid.uuid4())[:8]
        audio_path = os.path.join(app.config['UPLOAD_FOLDER'], f"{task_id}_{filename}")
        try:
            # Saved inside the try so a partially written upload is removed too.
            upload.save(audio_path)

            # Same lock as the load endpoint, so concurrent requests cannot
            # both trigger a model load.
            with ASR_MODEL_LOCK:
                service = get_asr_service()
                if not ASR_MODEL_LOADED:
                    service._load_model()
                    ASR_MODEL_LOADED = True

            sentences = service.recognize(
                audio_path,
                use_3d_speaker=use_3d_speaker,
                embedding_model=embedding_model,
                cluster_threshold=cluster_threshold,
                min_cluster_size=min_cluster_size
            )
            result = {
                'file': filename,
                'total_sentences': len(sentences),
                'sentences': [s.to_dict() for s in sentences]
            }
            if output_format == 'json':
                # A copy of the JSON result is also kept in the result folder.
                result_path = os.path.join(app.config['RESULT_FOLDER'], f"{task_id}_result.json")
                with open(result_path, 'w', encoding='utf-8') as f:
                    json.dump(result, f, ensure_ascii=False, indent=2)
                return jsonify(result)

            srt_path = os.path.join(app.config['RESULT_FOLDER'], f"{task_id}_result.srt")
            service.export_to_srt(sentences, srt_path)
            # Absolute path: the file was written relative to the working
            # directory, and send_file may resolve relative paths differently.
            return send_file(
                os.path.abspath(srt_path),
                as_attachment=True,
                download_name=f"{task_id}_result.srt"
            )
        finally:
            # The uploaded audio is temporary; remove it on every outcome.
            if os.path.exists(audio_path):
                os.remove(audio_path)
    except Exception as e:
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500
@app.route('/api/recognize/batch', methods=['POST'])
def recognize_batch():
    """Run recognition on several uploaded audio files in one request.

    Multipart form fields:
        files: one or more audio files (required).
        use_3d_speaker: ``'true'`` enables it; anything else is false.
        embedding_model: default ``'eres2netv2'``.
        cluster_threshold: float, default 0.5.
        min_cluster_size: int, default 10.

    Each file is processed independently: a failure on one file is recorded
    as ``{'file': ..., 'error': ...}`` in its slot and the rest continue.

    Returns:
        JSON with ``task_id``, ``total_files``, ``results`` and
        ``result_file`` (path of the JSON copy written to the result
        folder); HTTP 400 for missing/empty uploads or malformed numeric
        fields; HTTP 500 with ``{'error': ...}`` for any other failure.
    """
    global ASR_MODEL_LOADED
    import json
    import traceback

    try:
        if 'files' not in request.files:
            return jsonify({'error': '请上传音频文件'}), 400
        files = request.files.getlist('files')
        if not files:
            return jsonify({'error': '文件列表为空'}), 400

        data = request.form or {}
        use_3d_speaker = data.get('use_3d_speaker', 'false').lower() == 'true'
        embedding_model = data.get('embedding_model', 'eres2netv2')
        try:
            cluster_threshold = float(data.get('cluster_threshold', 0.5))
            min_cluster_size = int(data.get('min_cluster_size', 10))
        except (TypeError, ValueError):
            # A malformed number is a client error, not a server failure.
            return jsonify({'error': 'cluster_threshold 或 min_cluster_size 参数格式错误'}), 400

        task_id = str(uuid.uuid4())[:8]
        # (path on disk, name reported in the response) for each saved upload.
        uploads = []
        try:
            for index, f in enumerate(files):
                if not f.filename:
                    continue
                filename = secure_filename(f.filename)
                # The index keeps two uploads with the same sanitized name
                # from overwriting each other; the reported name is unchanged.
                audio_path = os.path.join(
                    app.config['UPLOAD_FOLDER'], f"{task_id}_{index}_{filename}"
                )
                # Registered before saving so a partial write is cleaned up.
                uploads.append((audio_path, f"{task_id}_{filename}"))
                f.save(audio_path)

            if not uploads:
                # Every part had an empty filename: nothing to process.
                return jsonify({'error': '文件列表为空'}), 400

            # Same lock as the load endpoint, so concurrent requests cannot
            # both trigger a model load.
            with ASR_MODEL_LOCK:
                service = get_asr_service()
                if not ASR_MODEL_LOADED:
                    service._load_model()
                    ASR_MODEL_LOADED = True

            results = []
            for audio_path, reported_name in uploads:
                try:
                    sentences = service.recognize(
                        audio_path,
                        use_3d_speaker=use_3d_speaker,
                        embedding_model=embedding_model,
                        cluster_threshold=cluster_threshold,
                        min_cluster_size=min_cluster_size
                    )
                    results.append({
                        'file': reported_name,
                        'total_sentences': len(sentences),
                        'sentences': [s.to_dict() for s in sentences]
                    })
                except Exception as e:
                    # Best effort per file: record the error and keep going.
                    results.append({
                        'file': reported_name,
                        'error': str(e)
                    })

            summary = {
                'task_id': task_id,
                'total_files': len(results),
                'results': results
            }
            result_path = os.path.join(app.config['RESULT_FOLDER'], f"{task_id}_batch_result.json")
            with open(result_path, 'w', encoding='utf-8') as out:
                json.dump(summary, out, ensure_ascii=False, indent=2)
            return jsonify({**summary, 'result_file': result_path})
        finally:
            # Uploaded audio is temporary; remove it on every outcome.
            for audio_path, _ in uploads:
                if os.path.exists(audio_path):
                    os.remove(audio_path)
    except Exception as e:
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500
@app.route('/api/status', methods=['GET'])
def get_status():
    """Report both model-loaded flags along with the fixed model names."""
    status = {
        'asr_loaded': ASR_MODEL_LOADED,
        'diar_loaded': DIAR_MODEL_LOADED,
        'asr_model': 'paraformer-zh',
        'diar_model': '3D-Speaker',
    }
    return jsonify(status)
if __name__ == '__main__':
    # Print the startup banner with the endpoint list, then start the server.
    rule = "=" * 60
    banner_lines = [
        rule,
        " ASR & Speaker Diarization API Server",
        rule,
        "\nAPI 接口:",
        " GET /health - 健康检查",
        " GET /api/status - 获取模型状态",
        " POST /api/asr/load - 加载 ASR 模型",
        " POST /api/diar/load - 加载 3D-Speaker 模型",
        " POST /api/asr/unload - 卸载 ASR 模型",
        " POST /api/diar/unload - 卸载 3D-Speaker 模型",
        " POST /api/recognize/single - 单文件推理",
        " POST /api/recognize/batch - 批量推理",
        "\n" + rule,
        "启动服务: http://localhost:5000",
        rule,
    ]
    print("\n".join(banner_lines))
    # Listen on all interfaces, port 5000, with debug mode off.
    app.run(host='0.0.0.0', port=5000, debug=False)