320 lines
10 KiB
Python
320 lines
10 KiB
Python
"""
|
|
Web API Server for ASR and Speaker Diarization
|
|
提供语音识别和说话人分离的 REST API 服务
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
import gc
|
|
from pathlib import Path
|
|
from flask import Flask, request, jsonify, send_file
|
|
from werkzeug.utils import secure_filename
|
|
import threading
|
|
import uuid
|
|
|
|
app = Flask(__name__)

# MAX_CONTENT_LENGTH is the request-size cap Flask enforces (500 MiB here);
# the two folder keys are read by the recognize endpoints below.
app.config.update(
    MAX_CONTENT_LENGTH=500 * 1024 * 1024,
    UPLOAD_FOLDER='uploads',
    RESULT_FOLDER='results',
)

# Both folders are relative paths, so they are created under the process's
# current working directory at import time.
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
os.makedirs(app.config['RESULT_FOLDER'], exist_ok=True)
|
|
|
|
# Lazily-constructed service singletons (see get_asr_service/get_diar_service);
# the unload endpoints reset them to None.
GLOBAL_ASR_SERVICE = GLOBAL_DIAR_SERVICE = None

# Flipped to True after the matching service's _load_model() call succeeds,
# and back to False by the unload endpoints.
ASR_MODEL_LOADED = DIAR_MODEL_LOADED = False

# One lock per model, held by the load endpoints. threading.Lock is not
# reentrant, so code holding one must not try to acquire it again.
ASR_MODEL_LOCK, DIAR_MODEL_LOCK = threading.Lock(), threading.Lock()
|
|
|
|
|
|
def get_asr_service():
    """Return the process-wide ASRService, creating it on the first call.

    ``asr_service`` is imported only here, on first use (presumably to keep
    server start-up cheap). This does not load the model or touch
    ``ASR_MODEL_LOADED``; callers call ``_load_model()`` themselves.
    ``load_asr_model`` calls this while holding ``ASR_MODEL_LOCK`` (a
    non-reentrant ``threading.Lock``), so this function must not take it.
    """
    global GLOBAL_ASR_SERVICE

    if GLOBAL_ASR_SERVICE is not None:
        return GLOBAL_ASR_SERVICE

    from asr_service import ASRService
    GLOBAL_ASR_SERVICE = ASRService()
    return GLOBAL_ASR_SERVICE
|
|
|
|
|
|
def get_diar_service():
    """Return the process-wide DiarizationService, creating it on first call.

    ``diarization_service`` is imported only here, on first use. This does
    not load the model or touch ``DIAR_MODEL_LOADED``; callers call
    ``_load_model()`` themselves. ``load_diar_model`` calls this while
    holding ``DIAR_MODEL_LOCK`` (non-reentrant), so this function must not
    take it.
    """
    global GLOBAL_DIAR_SERVICE

    if GLOBAL_DIAR_SERVICE is not None:
        return GLOBAL_DIAR_SERVICE

    from diarization_service import DiarizationService
    GLOBAL_DIAR_SERVICE = DiarizationService()
    return GLOBAL_DIAR_SERVICE
|
|
|
|
|
|
@app.route('/health', methods=['GET'])
def health_check():
    """Health probe: always reports ``status: ok`` plus both model-load flags."""
    payload = dict(status='ok')
    payload['asr_loaded'] = ASR_MODEL_LOADED
    payload['diar_loaded'] = DIAR_MODEL_LOADED
    return jsonify(payload)
|
|
|
|
|
|
@app.route('/api/asr/load', methods=['GET', 'POST'])
def load_asr_model():
    """Load the ASR model into the shared ASRService; idempotent.

    Accepts GET (original behavior) and POST (as advertised in the startup
    banner and matching ``/api/diar/load``). An optional JSON object body may
    carry ``model_name`` (default ``'paraformer-zh'``) and ``device``
    (default ``'auto'``); both are only echoed back (see NOTE below).

    Returns:
        200 JSON with ``loaded: True`` if the model is, or already was,
        loaded; 400 if the body is JSON but not an object; 500 with
        ``error`` if loading raises.
    """
    global ASR_MODEL_LOADED

    with ASR_MODEL_LOCK:
        if ASR_MODEL_LOADED:
            return jsonify({'message': 'ASR 模型已加载', 'loaded': True})

        # silent=True: a bodiless request or non-JSON Content-Type yields None
        # instead of raising (which the broad except below would turn into 500).
        data = request.get_json(silent=True)
        if data is None:
            data = {}
        elif not isinstance(data, dict):
            return jsonify({'error': '请求体必须是 JSON 对象'}), 400

        model_name = data.get('model_name', 'paraformer-zh')
        device = data.get('device', 'auto')

        try:
            print(f"正在加载 ASR 模型: {model_name}, 设备: {device}")
            service = get_asr_service()
            # NOTE(review): model_name/device are not passed to _load_model();
            # confirm ASRService's signature before forwarding them.
            service._load_model()
            ASR_MODEL_LOADED = True
        except Exception as e:
            import traceback
            traceback.print_exc()
            return jsonify({'error': str(e)}), 500

    return jsonify({
        'message': 'ASR 模型加载成功',
        'loaded': True,
        'model': model_name,
        'device': device,
    })
|
|
|
|
|
|
@app.route('/api/diar/load', methods=['POST'])
def load_diar_model():
    """Load the 3D-Speaker diarization model into the shared service; idempotent.

    An optional JSON object body may carry ``embedding_model`` (default
    ``'eres2netv2'``) and ``device`` (default ``'auto'``); both are only
    echoed back. ``cluster_threshold``/``min_cluster_size`` in the body are
    currently ignored (see NOTE below).

    Returns:
        200 JSON with ``loaded: True`` if the model is, or already was,
        loaded; 400 if the body is JSON but not an object; 500 with
        ``error`` if loading raises.
    """
    global DIAR_MODEL_LOADED

    with DIAR_MODEL_LOCK:
        if DIAR_MODEL_LOADED:
            return jsonify({'message': '3D-Speaker 模型已加载', 'loaded': True})

        # silent=True: a bodiless POST or non-JSON Content-Type yields None
        # instead of raising (which the broad except below would turn into 500).
        data = request.get_json(silent=True)
        if data is None:
            data = {}
        elif not isinstance(data, dict):
            return jsonify({'error': '请求体必须是 JSON 对象'}), 400

        embedding_model = data.get('embedding_model', 'eres2netv2')
        device = data.get('device', 'auto')

        try:
            print(f"正在加载 3D-Speaker 模型: {embedding_model}, 设备: {device}")
            service = get_diar_service()
            # NOTE(review): none of the body options reach _load_model();
            # confirm DiarizationService's signature before forwarding them.
            service._load_model()
            DIAR_MODEL_LOADED = True
        except Exception as e:
            import traceback
            traceback.print_exc()
            return jsonify({'error': str(e)}), 500

    return jsonify({
        'message': '3D-Speaker 模型加载成功',
        'loaded': True,
        'embedding_model': embedding_model,
        'device': device,
    })
|
|
|
|
|
|
@app.route('/api/asr/unload', methods=['POST'])
def unload_asr_model():
    """Discard the shared ASRService and mark the ASR model as unloaded.

    ``ASR_MODEL_LOCK`` is held so this cannot interleave with
    ``load_asr_model``, which holds the same lock. Without it, a load
    finishing after the reset would set ``ASR_MODEL_LOADED`` back to True.
    The next ``get_asr_service()`` would then build a fresh service flagged
    as loaded but never ``_load_model()``-ed. Requests already holding a
    local reference to the old service keep using it.
    """
    global GLOBAL_ASR_SERVICE, ASR_MODEL_LOADED

    try:
        with ASR_MODEL_LOCK:
            GLOBAL_ASR_SERVICE = None
            ASR_MODEL_LOADED = False
            # Presumably to release the model's memory promptly once the
            # last reference is gone.
            gc.collect()

        return jsonify({'message': 'ASR 模型已卸载'})
    except Exception as e:
        return jsonify({'error': str(e)}), 500
|
|
|
|
|
|
@app.route('/api/diar/unload', methods=['POST'])
def unload_diar_model():
    """Discard the shared DiarizationService and mark the model as unloaded.

    ``DIAR_MODEL_LOCK`` is held so this cannot interleave with
    ``load_diar_model``, which holds the same lock. Without it, a load
    finishing after the reset would set ``DIAR_MODEL_LOADED`` back to True
    while ``GLOBAL_DIAR_SERVICE`` is None, so the flag would describe a
    service that no longer exists.
    """
    global GLOBAL_DIAR_SERVICE, DIAR_MODEL_LOADED

    try:
        with DIAR_MODEL_LOCK:
            GLOBAL_DIAR_SERVICE = None
            DIAR_MODEL_LOADED = False
            # Presumably to release the model's memory promptly once the
            # last reference is gone.
            gc.collect()

        return jsonify({'message': '3D-Speaker 模型已卸载'})
    except Exception as e:
        return jsonify({'error': str(e)}), 500
|
|
|
|
|
|
@app.route('/api/recognize/single', methods=['POST'])
def recognize_single():
    """Transcribe one uploaded audio file (multipart field ``file``).

    Optional form fields:
        use_3d_speaker: ``'true'`` (case-insensitive) enables speaker diarization.
        embedding_model: default ``'eres2netv2'``.
        cluster_threshold: float, default 0.5.
        min_cluster_size: int, default 10.
        format: ``'json'`` (default) returns JSON; any other value returns
            an SRT attachment.

    The ASR model is loaded on first use. The uploaded audio is deleted after
    processing; the JSON/SRT result is kept in ``RESULT_FOLDER``.

    Returns:
        The JSON result or an SRT file. 400 for a missing file, an empty
        filename or malformed numeric options; 500 with ``error`` for any
        other failure.
    """
    import json
    import traceback

    global ASR_MODEL_LOADED

    try:
        if 'file' not in request.files:
            return jsonify({'error': '请上传音频文件'}), 400

        file = request.files['file']
        if file.filename == '':
            return jsonify({'error': '文件名不能为空'}), 400

        data = request.form or {}
        use_3d_speaker = data.get('use_3d_speaker', 'false').lower() == 'true'
        embedding_model = data.get('embedding_model', 'eres2netv2')
        try:
            cluster_threshold = float(data.get('cluster_threshold', 0.5))
            min_cluster_size = int(data.get('min_cluster_size', 10))
        except (TypeError, ValueError) as e:
            # Malformed client input is a client error, not a server error.
            return jsonify({'error': f'参数格式错误: {e}'}), 400
        output_format = data.get('format', 'json')

        # NOTE(review): secure_filename strips non-ASCII characters, so e.g.
        # "会议.wav" becomes "wav" and loses its suffix; confirm the ASR
        # backend does not rely on the extension.
        filename = secure_filename(file.filename)
        task_id = str(uuid.uuid4())[:8]
        audio_path = os.path.join(app.config['UPLOAD_FOLDER'], f"{task_id}_{filename}")
        file.save(audio_path)

        try:
            # Fetch and lazily load under the same lock /api/asr/load uses, so
            # concurrent first requests cannot run _load_model() twice.
            with ASR_MODEL_LOCK:
                service = get_asr_service()
                if not ASR_MODEL_LOADED:
                    service._load_model()
                    ASR_MODEL_LOADED = True

            sentences = service.recognize(
                audio_path,
                use_3d_speaker=use_3d_speaker,
                embedding_model=embedding_model,
                cluster_threshold=cluster_threshold,
                min_cluster_size=min_cluster_size,
            )

            result = {
                'file': filename,
                'total_sentences': len(sentences),
                'sentences': [s.to_dict() for s in sentences],
            }

            if output_format == 'json':
                result_path = os.path.join(app.config['RESULT_FOLDER'], f"{task_id}_result.json")
                with open(result_path, 'w', encoding='utf-8') as fh:
                    json.dump(result, fh, ensure_ascii=False, indent=2)
                return jsonify(result)

            srt_name = f"{task_id}_result.srt"
            srt_path = os.path.join(app.config['RESULT_FOLDER'], srt_name)
            service.export_to_srt(sentences, srt_path)
            # Flask resolves a relative send_file path against app.root_path,
            # but the SRT was written relative to the CWD; pass it absolute.
            return send_file(os.path.abspath(srt_path), as_attachment=True, download_name=srt_name)

        finally:
            # The upload is only needed for recognition; results are kept.
            if os.path.exists(audio_path):
                os.remove(audio_path)

    except Exception as e:
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500
|
|
|
|
|
|
@app.route('/api/recognize/batch', methods=['POST'])
def recognize_batch():
    """Transcribe several uploaded audio files (multipart field ``files``).

    Takes the same optional form fields as ``/api/recognize/single`` except
    ``format``; the output is always JSON. A failure on one file is recorded
    as an ``error`` entry for that file and does not abort the batch. The
    combined result is also written to ``RESULT_FOLDER`` as
    ``<task_id>_batch_result.json``. Uploaded audio is always deleted.

    Returns:
        JSON with ``task_id``, ``total_files``, per-file ``results`` and
        ``result_file``. 400 if no named files were sent or numeric options
        are malformed; 500 with ``error`` for any other failure.
    """
    import json
    import traceback

    global ASR_MODEL_LOADED

    try:
        if 'files' not in request.files:
            return jsonify({'error': '请上传音频文件'}), 400

        # Nameless parts are skipped, as before; with none left there is
        # nothing to do.
        uploads = [u for u in request.files.getlist('files') if u.filename]
        if not uploads:
            return jsonify({'error': '文件列表为空'}), 400

        data = request.form or {}
        use_3d_speaker = data.get('use_3d_speaker', 'false').lower() == 'true'
        embedding_model = data.get('embedding_model', 'eres2netv2')
        try:
            cluster_threshold = float(data.get('cluster_threshold', 0.5))
            min_cluster_size = int(data.get('min_cluster_size', 10))
        except (TypeError, ValueError) as e:
            return jsonify({'error': f'参数格式错误: {e}'}), 400

        task_id = str(uuid.uuid4())[:8]
        saved = []  # (name reported to the client, path on disk)

        try:
            # Saving happens inside the try so a failing save() still cleans
            # up the files saved before it.
            for index, upload in enumerate(uploads):
                filename = secure_filename(upload.filename)
                # The index keeps uploads with the same name from overwriting
                # each other on disk; the reported name keeps its old format.
                audio_path = os.path.join(
                    app.config['UPLOAD_FOLDER'], f"{task_id}_{index}_{filename}"
                )
                upload.save(audio_path)
                saved.append((f"{task_id}_{filename}", audio_path))

            # Same lock /api/asr/load uses, so the lazy load happens once.
            with ASR_MODEL_LOCK:
                service = get_asr_service()
                if not ASR_MODEL_LOADED:
                    service._load_model()
                    ASR_MODEL_LOADED = True

            results = []
            for display_name, audio_path in saved:
                try:
                    sentences = service.recognize(
                        audio_path,
                        use_3d_speaker=use_3d_speaker,
                        embedding_model=embedding_model,
                        cluster_threshold=cluster_threshold,
                        min_cluster_size=min_cluster_size,
                    )
                    results.append({
                        'file': display_name,
                        'total_sentences': len(sentences),
                        'sentences': [s.to_dict() for s in sentences],
                    })
                except Exception as e:
                    results.append({'file': display_name, 'error': str(e)})

            summary = {
                'task_id': task_id,
                'total_files': len(results),
                'results': results,
            }
            result_path = os.path.join(app.config['RESULT_FOLDER'], f"{task_id}_batch_result.json")
            with open(result_path, 'w', encoding='utf-8') as fh:
                json.dump(summary, fh, ensure_ascii=False, indent=2)

            return jsonify({**summary, 'result_file': result_path})

        finally:
            for _, audio_path in saved:
                if os.path.exists(audio_path):
                    os.remove(audio_path)

    except Exception as e:
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500
|
|
|
|
|
|
@app.route('/api/status', methods=['GET'])
def get_status():
    """Report both model-load flags together with the model names."""
    status = {
        'asr_loaded': ASR_MODEL_LOADED,
        'diar_loaded': DIAR_MODEL_LOADED,
    }
    # The names are fixed strings; they do not reflect any model_name or
    # embedding_model sent to the load endpoints.
    status.update(asr_model='paraformer-zh', diar_model='3D-Speaker')
    return jsonify(status)
|
|
|
|
|
|
if __name__ == '__main__':
    # Print a start-up banner listing the endpoints, then serve on all
    # interfaces, port 5000, with the debugger off.
    _rule = "=" * 60
    _endpoint_lines = (
        " GET /health - 健康检查",
        " GET /api/status - 获取模型状态",
        " POST /api/asr/load - 加载 ASR 模型",
        " POST /api/diar/load - 加载 3D-Speaker 模型",
        " POST /api/asr/unload - 卸载 ASR 模型",
        " POST /api/diar/unload - 卸载 3D-Speaker 模型",
        " POST /api/recognize/single - 单文件推理",
        " POST /api/recognize/batch - 批量推理",
    )

    print(_rule)
    print(" ASR & Speaker Diarization API Server")
    print(_rule)
    print("\nAPI 接口:")
    for _endpoint_line in _endpoint_lines:
        print(_endpoint_line)
    print("\n" + _rule)
    print("启动服务: http://localhost:5000")
    print(_rule)

    app.run(host='0.0.0.0', port=5000, debug=False)
|