完善服务端

This commit is contained in:
yueliuli 2026-05-06 15:57:34 +08:00
parent bdfce46d7c
commit 1507093701
4 changed files with 220 additions and 265 deletions

View File

@ -26,6 +26,12 @@ D:\ProgramData\miniconda3\envs\audio\python.exe
# FunASR 语音识别服务 # FunASR 语音识别服务
基于阿里达摩院 [FunASR](https://github.com/alibaba-damo-academy/FunASR) 的本地语音识别解决方案。 基于阿里达摩院 [FunASR](https://github.com/alibaba-damo-academy/FunASR) 的本地语音识别解决方案。

46
main.py
View File

@ -12,6 +12,7 @@
import gc import gc
import json import json
import os
import shutil import shutil
import subprocess import subprocess
import time import time
@ -23,8 +24,7 @@ BASE_DIR = Path(__file__).parent.absolute()
TEMP_DIR = BASE_DIR / "temp" TEMP_DIR = BASE_DIR / "temp"
OUTPUT_DIR = BASE_DIR / "output" OUTPUT_DIR = BASE_DIR / "output"
# 视频文件夹路径(全局变量)
VIDEO_FOLDER = r"D:\Userfile\Projects\AnzezxianxHazardInspectAI\Code\audio2\input\宁波北仑区鼎邦杰西雅服饰有限公司"
# 支持的视频格式 # 支持的视频格式
SUPPORTED_VIDEO_FORMATS = ["*.mp4", "*.avi", "*.mkv", "*.mov", "*.flv", "*.wmv", "*.m4v"] SUPPORTED_VIDEO_FORMATS = ["*.mp4", "*.avi", "*.mkv", "*.mov", "*.flv", "*.wmv", "*.m4v"]
@ -271,7 +271,12 @@ def process_batch_asr(video_paths, diar_results, max_workers=1):
print("加载 ASR 模型...") print("加载 ASR 模型...")
from asr_service import ASRService from asr_service import ASRService
asr_service = ASRService(model_name="paraformer-zh", device="auto") asr_service = ASRService(
model_name="paraformer-zh",
device="auto",
min_segment_duration=0.5, # 增加最小片段时长
merge_gap=0.3 # 减少合并间隔
)
asr_service._load_model() asr_service._load_model()
print("✓ ASR 模型已加载") print("✓ ASR 模型已加载")
print() print()
@ -489,34 +494,54 @@ def process_batch_asr(video_paths, diar_results, max_workers=1):
return results return results
def main(): def main(path: str):
"""主函数""" """主函数"""
import torch import torch
# 拼接根目录
path = os.path.join(os.path.dirname(__file__), "input", path)
print(f"开始处理路径:{path}")
print("\n" + "=" * 60) print("\n" + "=" * 60)
print(" 并发批量语音识别处理系统") print(" 并发批量语音识别处理系统")
print("=" * 60) print("=" * 60)
print() print()
# 0. 检查输入路径是否存在
if not Path(path).exists():
print(f"✗ 错误:输入路径不存在 - {path}")
return "输入路径不存在"
# 文件夹为空
if Path(path).is_dir() and not list(Path(path).iterdir()):
print("✗ 错误:文件夹为空")
return "文件夹为空"
# 1. 清空 temp 目录 # 1. 清空 temp 目录
clear_temp_dir() clear_temp_dir()
# 2. 确保 output 目录存在 # 2. 确保 output 目录存在
ensure_output_dir() ensure_output_dir()
# 3. 准备视频列表(从 VIDEO_FOLDER 自动获取) # 3. 准备视频列表(从 path 自动获取)
video_folder = Path(VIDEO_FOLDER) # 如果是文件夹,递归获取所有视频文件
# 如果是文件,直接添加到列表
video_paths = []
if Path(path).is_dir():
video_folder = Path(path)
if not video_folder.exists(): if not video_folder.exists():
print(f"✗ 错误:视频文件夹不存在 - {video_folder}") print(f"✗ 错误:视频文件夹不存在 - {video_folder}")
return return
video_paths = get_video_list(video_folder) video_paths = get_video_list(video_folder)
if not video_paths: if not video_paths:
print("✗ 错误:未找到任何视频文件") print("✗ 错误:未找到任何视频文件")
print(f"支持格式:{', '.join(SUPPORTED_VIDEO_FORMATS)}") print(f"支持格式:{', '.join(SUPPORTED_VIDEO_FORMATS)}")
print(f"请检查文件夹:{video_folder}") print(f"请检查文件夹:{video_folder}")
return return
else:
video_paths.append(Path(path))
video_paths = list(set(video_paths)) # 去重
print(f"找到 {len(video_paths)} 个视频文件") print(f"找到 {len(video_paths)} 个视频文件")
for vp in video_paths: for vp in video_paths:
@ -569,4 +594,9 @@ def main():
if __name__ == "__main__":
    import sys

    # File or folder to process, given relative to the project's input/
    # directory (main() joins it onto that directory itself).
    # An optional first command-line argument overrides the built-in sample,
    # so switching inputs no longer requires editing this file.
    PATH = sys.argv[1] if len(sys.argv) > 1 else r"VID_20251104_085655_024.AVI"
    main(PATH)

View File

@ -59,5 +59,6 @@ jieba>=0.42.0
# onnxruntime-gpu>=1.23.0 # onnxruntime-gpu>=1.23.0
# ---------- 可选Web API 服务 ---------- # ---------- 可选Web API 服务 ----------
# Flask>=3.1.0 Flask>=3.1.0
waitress>=3.0.0
# SQLAlchemy>=2.0.0 # SQLAlchemy>=2.0.0

438
server.py
View File

@ -6,16 +6,71 @@ Web API Server for ASR and Speaker Diarization
import os import os
import sys import sys
import gc import gc
from pathlib import Path import json
from flask import Flask, request, jsonify, send_file import shutil
from werkzeug.utils import secure_filename import signal
import threading import threading
from pathlib import Path
from flask import Flask, request, jsonify
from werkzeug.utils import secure_filename
import uuid import uuid
from datetime import datetime, timezone
def make_response(status="success", data=None, errors=None, message=None, extra=None):
    """Build the unified API response envelope.

    Args:
        status: Outcome marker. ``"success"`` selects the success default
            message; any other value selects the failure default.
        data: Payload to return; ``None`` is replaced by an empty dict.
        errors: List of error descriptions; ``None`` is replaced by an
            empty list.
        message: Human-readable message. A falsy value (``None`` or ``""``)
            falls back to a default chosen from ``status``.
        extra: Optional mapping merged into the envelope last, so on a key
            collision its values replace the standard fields.

    Returns:
        A plain ``dict`` with the keys ``status``, ``data``, ``errors``,
        ``message`` and ``timestamp`` (current UTC time formatted as
        ``YYYY-MM-DDTHH:MM:SSZ``), plus whatever ``extra`` contributed.
    """
    if not message:
        message = "操作成功" if status == "success" else "操作失败"

    response = {
        "status": status,
        "data": {} if data is None else data,
        "errors": [] if errors is None else errors,
        "message": message,
        "timestamp": datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"),
    }

    if extra:
        response.update(extra)

    return response
app = Flask(__name__)
# Upper bound on the size of an incoming request body: 500 MiB.
app.config['MAX_CONTENT_LENGTH'] = 500 * 1024 * 1024
app.config['UPLOAD_FOLDER'] = 'uploads'
app.config['RESULT_FOLDER'] = 'results'
# NOTE: SEND_FILE_MAX_AGE_DEFAULT is the cache max-age (in seconds) Flask
# applies to files it serves. It is NOT a request timeout and has no effect
# on how long a handler may run; long-running requests are governed by the
# WSGI server that hosts the app.
app.config['SEND_FILE_MAX_AGE_DEFAULT'] = 300  # 5 minutes

# Directory searched for "<stem>_result.json" files; relative, so it is
# resolved against the process working directory.
OUTPUT_DIR = Path('output')
# Attach CORS response headers by hand to every response.
@app.after_request
def after_request(response):
    """Add CORS headers to ``response`` and return it.

    When the request carries an ``Origin`` header, that origin is echoed
    back and credentials are allowed. Without an ``Origin`` header the
    wildcard ``*`` is sent and the credentials header is omitted, because
    browsers reject the combination of ``*`` with
    ``Access-Control-Allow-Credentials: true``.

    Args:
        response: The Flask response about to be sent.

    Returns:
        The same response object with the CORS headers set.
    """
    origin = request.headers.get('Origin')

    if origin:
        # SECURITY NOTE(review): echoing an arbitrary Origin while allowing
        # credentials lets any website issue credentialed requests to this
        # API. Kept for compatibility with existing clients; replace with an
        # explicit allow-list of trusted origins before exposing the service.
        response.headers['Access-Control-Allow-Origin'] = origin
        response.headers['Access-Control-Allow-Credentials'] = 'true'
        # The response now depends on the request's Origin, so shared caches
        # must not reuse it for a different origin.
        response.headers.add('Vary', 'Origin')
    else:
        response.headers['Access-Control-Allow-Origin'] = '*'

    # Item assignment (rather than .add) replaces any existing value, so a
    # header is never emitted twice.
    response.headers['Access-Control-Allow-Headers'] = 'Content-Type,Authorization'
    response.headers['Access-Control-Allow-Methods'] = 'GET,PUT,POST,DELETE,OPTIONS'
    return response
# Answer the CORS preflight (OPTIONS) request for /config.
@app.route('/config', methods=['OPTIONS'])
def config_options():
    """Reply to an OPTIONS preflight on ``/config`` with an empty 200 response."""
    empty_body = ''
    status_ok = 200
    return empty_body, status_ok
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True) os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
os.makedirs(app.config['RESULT_FOLDER'], exist_ok=True) os.makedirs(app.config['RESULT_FOLDER'], exist_ok=True)
@ -27,285 +82,148 @@ DIAR_MODEL_LOADED = False
ASR_MODEL_LOCK = threading.Lock() ASR_MODEL_LOCK = threading.Lock()
DIAR_MODEL_LOCK = threading.Lock() DIAR_MODEL_LOCK = threading.Lock()
# 全局变量用于控制任务执行
def get_asr_service(): task_running = {}
global GLOBAL_ASR_SERVICE, ASR_MODEL_LOADED task_timeout = 600 # 10 分钟超时(秒)
if GLOBAL_ASR_SERVICE is None:
from asr_service import ASRService
GLOBAL_ASR_SERVICE = ASRService()
return GLOBAL_ASR_SERVICE
@app.route('/api/recognize', methods=['POST'])
def recognize():
    """Run the batch pipeline (``main.main``) on a path and report the outcome.

    Expects a JSON body ``{"path": "<file or folder relative to input/>"}``.
    The call is synchronous: the response is sent only after ``main`` returns.

    Returns:
        A ``(json, status)`` pair using the ``make_response`` envelope:

        * 200 - ``main`` finished without reporting a failure.
        * 400 - ``path`` is missing, is not a relative path inside the input
          directory, or ``main`` returned an error message.
        * 504 - the alarm-based timeout fired.
        * 500 - any other exception.
    """
    # Local import: only this handler needs the pure-path classes.
    from pathlib import PurePosixPath, PureWindowsPath

    task_id = str(uuid.uuid4())
    task_running[task_id] = True
    use_alarm = False

    try:
        # silent=True yields None for a missing or non-JSON body instead of
        # raising, so the caller always gets the uniform error envelope.
        payload = request.get_json(silent=True)
        path = payload.get('path', '') if isinstance(payload, dict) else ''

        if not path or not isinstance(path, str):
            return jsonify(make_response(
                status="error",
                message="请提供文件路径",
                errors=["缺少必要参数path"]
            )), 400

        # Only a file name or a relative path below input/ is accepted.
        # main() joins the value onto that directory, and an absolute path or
        # a ".." component would escape it. Both separator conventions are
        # checked so the rule holds on Windows and POSIX hosts alike.
        win_view = PureWindowsPath(path)
        posix_view = PurePosixPath(path)
        if (win_view.anchor or posix_view.anchor
                or '..' in win_view.parts or '..' in posix_view.parts):
            return jsonify(make_response(
                status="error",
                message="路径不合法",
                errors=["path 必须是 input 目录下的相对路径"]
            )), 400

        print(f"\n{'='*60}")
        print(f"API 收到请求path={path}, task_id={task_id}")
        print(f"{'='*60}")
        print(f"开始调用 main 函数...")

        # Switch to the directory containing server.py.
        # NOTE: os.chdir changes the working directory of the whole process,
        # not just of this request.
        server_dir = Path(__file__).parent.absolute()
        os.chdir(server_dir)
        print(f"当前工作目录:{os.getcwd()}")

        def timeout_handler(signum, frame):
            task_running[task_id] = False
            raise TimeoutError(f"任务执行超时 ({task_timeout}秒)")

        # SIGALRM is missing on Windows (AttributeError), and signal.signal()
        # may only be called from the main thread (ValueError elsewhere, e.g.
        # in a WSGI worker thread). In both cases the request runs without
        # any timeout; no watchdog thread is started despite the log message.
        try:
            signal.signal(signal.SIGALRM, timeout_handler)
            signal.alarm(task_timeout)
            use_alarm = True
        except (AttributeError, ValueError):
            use_alarm = False
            print("注意Windows 系统,使用超时检测线程")

        try:
            from main import main
            outcome = main(path)
        except TimeoutError as e:
            print(f"任务超时:{e}")
            return jsonify(make_response(
                status="error",
                message=str(e),
                errors=[str(e)]
            )), 504  # Gateway Timeout
        finally:
            # Always disarm the alarm so it cannot fire in a later request.
            if use_alarm:
                signal.alarm(0)

        # main() signals validation failures by returning a message string
        # (e.g. a missing input path or an empty folder); previously these
        # were reported as success.
        # NOTE(review): assumes main() never returns a non-empty string on
        # success - confirm against the full main.py.
        if isinstance(outcome, str) and outcome:
            print(f"main 函数执行完成")
            return jsonify(make_response(
                status="error",
                message=outcome,
                errors=[outcome]
            )), 400

        print(f"main 函数执行完成")
        return jsonify(make_response(
            status="success",
            message="文件推理完成",
            data={"path": path, "task_id": task_id}
        )), 200

    except Exception as e:
        import traceback
        traceback.print_exc()
        return jsonify(make_response(
            status="error",
            message=str(e),
            errors=[str(e)]
        )), 500
    finally:
        # Mark the task finished on every exit path, including the early
        # validation returns that used to leave the flag set to True.
        task_running[task_id] = False
@app.route('/api/result', methods=['GET'])
def result():
    """Return the stored recognition result for a file.

    Reads the ``path`` query parameter; only its stem (final component
    without the extension) is used, to look up ``<stem>_result.json`` inside
    ``OUTPUT_DIR``.

    Returns:
        A ``(json, status)`` pair using the ``make_response`` envelope:
        200 with the parsed JSON file as ``data``, 400 when ``path`` is
        missing, 404 when no result file exists, 500 on any exception.
    """
    try:
        path = request.args.get('path', '')

        if not path:
            return jsonify(make_response(
                status="error",
                message="请提供文件路径",
                errors=["缺少必要参数path"]
            )), 400

        # Locate the result file from the stem of the requested path.
        stem = Path(path).stem
        print(stem)
        result_file = OUTPUT_DIR / f"{stem}_result.json"

        if not result_file.exists():
            return jsonify(make_response(
                status="error",
                message="处理完成但未找到结果文件",
                errors=["结果文件不存在"]
            )), 404

        with open(result_file, 'r', encoding='utf-8') as f:
            result_data = json.load(f)

        return jsonify(make_response(
            status="success",
            message="获取成功",
            data=result_data
        )), 200

    except Exception as e:
        import traceback
        traceback.print_exc()
        return jsonify(make_response(
            status="error",
            message=str(e),
            errors=[str(e)]
        )), 500
if __name__ == '__main__':
    # Startup banner listing the endpoints this server registers.
    print("=" * 60)
    print(" ASR & Speaker Diarization API Server")
    print("=" * 60)
    print("\nAPI 接口:")
    print(" POST /api/recognize - 文件推理")
    # /api/result is registered for GET; the banner previously said POST.
    print(" GET /api/result - 获取文件推理结果")
    print("\n" + "=" * 60)
    print("启动服务http://localhost:5000")
    print("使用 Waitress WSGI 服务器(无超时限制)")
    print("=" * 60)

    # Serve through Waitress instead of Flask's built-in development server.
    # NOTE(review): host 0.0.0.0 listens on every network interface and the
    # API performs no authentication - restrict the bind address or put the
    # service behind a gateway if it should not be reachable from the network.
    from waitress import serve
    serve(app, host='0.0.0.0', port=5000, threads=4, connection_limit=100)