89 lines
2.3 KiB
Python
89 lines
2.3 KiB
Python
from flask import Flask, request, jsonify
|
|
from transformers import pipeline
|
|
import logging
|
|
|
|
# 配置日志
|
|
logging.basicConfig(level=logging.INFO)
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# 初始化模型
|
|
try:
|
|
classifier = pipeline(
|
|
"zero-shot-classification",
|
|
model="./nlp_structbert_zero-shot-classification_chinese-large",
|
|
device="cpu",
|
|
max_length=512,
|
|
ignore_mismatched_sizes=True
|
|
)
|
|
logger.info("模型加载成功")
|
|
except Exception as e:
|
|
logger.error(f"模型加载失败: {str(e)}")
|
|
raise
|
|
|
|
# 定义类别标签
|
|
level1 = [
|
|
"食品", "电器", "洗护", "女装", "手机","健康", "男装", "美妆", "电脑", "运动","内衣", "母婴", "数码", "百货", "鞋包","办公", "家装", "饰品", "车品", "图书","生鲜", "家纺", "宠物", "奢品", "其它", "药品"
|
|
]
|
|
|
|
# 创建Flask应用
|
|
app = Flask(__name__)
|
|
app.config['JSON_AS_ASCII'] = False # 确保中文能正常显示
|
|
|
|
|
|
@app.route('/best', methods=['POST'])
|
|
def classify_text():
|
|
# 获取请求数据
|
|
data = request.get_json()
|
|
if not data or 'text' not in data:
|
|
return jsonify({"error": "缺少text参数"}), 400
|
|
if not data or 'cate' not in data:
|
|
return jsonify({"error": "缺少cate参数"}), 400
|
|
text = data['text']
|
|
cate = data['cate']
|
|
|
|
try:
|
|
# 执行分类
|
|
result = classifier(
|
|
text,
|
|
candidate_labels=cate,
|
|
truncation=True
|
|
)
|
|
|
|
labels = result['labels']
|
|
|
|
|
|
# 构建响应
|
|
response = {
|
|
"cate": labels[0],
|
|
}
|
|
return jsonify(response)
|
|
|
|
except Exception as e:
|
|
logger.error(f"分类过程中出错: {str(e)}")
|
|
return jsonify({"error": str(e)}), 500
|
|
|
|
|
|
@app.route('/health')
|
|
def health_check():
|
|
"""健康检查端点"""
|
|
try:
|
|
# 简单测试模型是否可用
|
|
test_result = classifier(
|
|
"测试文本",
|
|
candidate_labels=["食品", "电器"],
|
|
truncation=True
|
|
)
|
|
return jsonify({
|
|
"status": "healthy",
|
|
"model_loaded": True
|
|
})
|
|
except:
|
|
return jsonify({
|
|
"status": "unhealthy",
|
|
"model_loaded": False
|
|
}), 500
|
|
|
|
|
|
if __name__ == '__main__':
|
|
# 运行应用
|
|
app.run(host='0.0.0.0', port=5000, debug=True) |