# bert_cate/other/app.py
#
# 87 lines
# 2.1 KiB
# Python

from flask import Flask, request, jsonify
from transformers import pipeline
import logging
# Configure logging.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load the zero-shot classification model once at import time so every request
# reuses the same pipeline. A failure is logged (with traceback) and re-raised,
# so the service never starts without a model.
try:
    classifier = pipeline(
        "zero-shot-classification",
        model="./nlp_structbert_zero-shot-classification_chinese-base",
        device="cpu",
        # NOTE(review): these are forwarded by `pipeline` as extra kwargs;
        # confirm they take effect with the installed transformers version.
        max_length=512,
        ignore_mismatched_sizes=True,
    )
except Exception as e:
    logger.exception("模型加载失败: %s", e)
    raise
else:
    logger.info("模型加载成功")

# Create the Flask application.
app = Flask(__name__)

# Emit Chinese characters verbatim in JSON responses instead of \uXXXX escapes.
# Flask < 2.3 reads the JSON_AS_ASCII config key; Flask 2.2 deprecated it and
# 2.3 removed it in favour of the app.json provider's `ensure_ascii` attribute,
# so set both to work across versions.
app.config['JSON_AS_ASCII'] = False
_json_provider = getattr(app, "json", None)
if _json_provider is not None and hasattr(_json_provider, "ensure_ascii"):
    _json_provider.ensure_ascii = False
def _parse_classify_payload(data):
    """Validate a /best request body and extract its ``text`` and ``cate``.

    Returns ``(text, cate, None)`` when the body is usable, otherwise
    ``(None, None, message)`` where ``message`` is the error to return with
    HTTP 400.
    """
    # A JSON list/number/string body is not a usable payload; `in` would not
    # catch that, and indexing it later would crash outside the try/except.
    if not isinstance(data, dict) or 'text' not in data:
        return None, None, "缺少text参数"
    if 'cate' not in data:
        return None, None, "缺少cate参数"

    text = data['text']
    cate = data['cate']

    if not isinstance(text, str) or not text.strip():
        return None, None, "text参数必须是非空字符串"

    # candidate_labels may be a comma-separated string or a list of labels;
    # an empty set of labels makes the pipeline raise, so reject it up front.
    if isinstance(cate, str):
        has_labels = any(part.strip() for part in cate.split(','))
    else:
        has_labels = isinstance(cate, list) and len(cate) > 0
    if not has_labels:
        return None, None, "cate参数必须是非空字符串或字符串列表"

    return text, cate, None


@app.route('/best', methods=['POST'])
def classify_text():
    """Classify ``text`` against the candidate labels in ``cate``.

    Expects a JSON body ``{"text": ..., "cate": ...}``. Responds with
    ``{"cate": <label>, "score": <score>}`` taken from index -1 of the
    pipeline's ``labels``/``scores``; HTTP 400 with ``{"error": ...}`` for a
    missing or invalid body; HTTP 500 with ``{"error": ...}`` if classification
    raises.
    """
    # silent=True: a missing or malformed JSON body yields None, which is then
    # reported through the JSON 400 below instead of Flask's HTML error page.
    data = request.get_json(silent=True)
    text, cate, error = _parse_classify_payload(data)
    if error is not None:
        return jsonify({"error": error}), 400

    try:
        result = classifier(
            text,
            candidate_labels=cate,
            truncation=True,
        )
        logger.debug("分类结果: %s", result)
        # NOTE(review): transformers documents the zero-shot pipeline's
        # `labels` as sorted by likelihood (most likely first), so index -1 is
        # the LEAST likely label. Confirm this is intended for /best -- it may
        # be compensating for this model's entailment-label mapping.
        response = {
            "cate": result['labels'][-1],
            "score": result['scores'][-1],
        }
        return jsonify(response)
    except Exception as e:
        logger.exception("分类过程中出错: %s", e)
        return jsonify({"error": str(e)}), 500
@app.route('/health')
def health_check():
    """Health-check endpoint.

    Runs one small classification through the loaded model. Returns
    ``{"status": "healthy", "model_loaded": true}`` when it succeeds, or
    ``{"status": "unhealthy", "model_loaded": false}`` with HTTP 500 if the
    call raises.
    """
    try:
        # Smoke-test the model with a throwaway input; the result is unused.
        classifier(
            "测试文本",
            candidate_labels=["食品", "电器"],
            truncation=True,
        )
    except Exception:
        # Catch Exception rather than everything so SystemExit /
        # KeyboardInterrupt still propagate, and log the cause so an
        # unhealthy status can be diagnosed.
        logger.exception("健康检查失败")
        return jsonify({
            "status": "unhealthy",
            "model_loaded": False
        }), 500
    return jsonify({
        "status": "healthy",
        "model_loaded": True
    })
if __name__ == '__main__':
    # Run the Flask development server on all interfaces, port 5000.
    # Debug mode is opt-in: with host='0.0.0.0' it would expose the Werkzeug
    # interactive debugger (which can execute arbitrary code) to the network,
    # and its reloader re-runs this module, loading the model a second time.
    # Set FLASK_DEBUG=1 to enable it locally; app.run() reads that variable
    # when `debug` is not passed explicitly.
    app.run(host='0.0.0.0', port=5000)