87 lines
		
	
	
		
			2.1 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			87 lines
		
	
	
		
			2.1 KiB
		
	
	
	
		
			Python
		
	
	
	
from flask import Flask, request, jsonify
 | 
						|
from transformers import pipeline
 | 
						|
import logging
 | 
						|
 | 
						|
# Configure logging: INFO threshold on the root logger, plus a logger
# named after this module for the messages emitted below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
 | 
						|
 | 
						|
# Initialise the model once at import time. The zero-shot classification
# pipeline is loaded from a local directory and pinned to the CPU; any
# failure is logged and re-raised so the module does not finish importing
# without a usable classifier.
try:
    classifier = pipeline(
        "zero-shot-classification",
        model="./nlp_structbert_zero-shot-classification_chinese-base",
        device="cpu",
        max_length=512,
        ignore_mismatched_sizes=True,
    )
    logger.info("模型加载成功")
except Exception as load_error:
    logger.error(f"模型加载失败: {load_error!s}")
    raise
 | 
						|
 | 
						|
 | 
						|
# Create the Flask application.
app = Flask(__name__)

# Make sure Chinese text is emitted as-is instead of \uXXXX escapes.
# The JSON_AS_ASCII config key is only honoured by older Flask releases;
# newer releases expose the same switch on the app's JSON provider, so set
# both to keep the behaviour identical across versions.
app.config['JSON_AS_ASCII'] = False
_json_provider = getattr(app, "json", None)
if _json_provider is not None:
    _json_provider.ensure_ascii = False
 | 
						|
 | 
						|
 | 
						|
@app.route('/best', methods=['POST'])
def classify_text():
    """Classify the posted text against caller-supplied candidate labels.

    The request body must be a JSON object with two keys:

    * ``text`` - the text passed to the classifier.
    * ``cate`` - the candidate labels passed to the classifier.

    Returns:
        On success, a JSON object with ``cate`` (the last label of the
        classifier result) and ``score`` (the last score of the result).
        A 400 response with an ``error`` message when ``text`` or ``cate``
        is missing (including when the body is not a JSON object), and a
        500 response carrying the exception text when classification fails.
    """
    # Read the request payload.
    payload = request.get_json()

    # JSON strings and lists also support ``in``, so insist on a dict here;
    # otherwise the key lookups below would raise outside the error handler.
    if not isinstance(payload, dict) or 'text' not in payload:
        return jsonify({"error": "缺少text参数"}), 400
    if 'cate' not in payload:
        return jsonify({"error": "缺少cate参数"}), 400

    text = payload['text']
    cate = payload['cate']

    try:
        # Run the classification.
        result = classifier(
            text,
            candidate_labels=cate,
            truncation=True,
        )
        logger.debug("classification result: %s", result)

        # NOTE(review): the transformers zero-shot pipeline is documented as
        # returning labels sorted by likelihood, so [-1] would select the
        # lowest-scored label even though the route is named '/best'.
        # Left unchanged because it may be deliberate for this model -
        # confirm against the model's label mapping.
        response = {
            "cate": result['labels'][-1],
            "score": result['scores'][-1],
        }
        return jsonify(response)

    except Exception as e:
        # logger.exception keeps the traceback alongside the message.
        logger.exception("分类过程中出错: %s", e)
        return jsonify({"error": str(e)}), 500
 | 
						|
 | 
						|
 | 
						|
@app.route('/health')
def health_check():
    """Health-check endpoint.

    Runs one small classification to verify that the model is usable.

    Returns:
        A JSON object ``{"status": "healthy", "model_loaded": True}`` when
        the probe call succeeds, otherwise
        ``{"status": "unhealthy", "model_loaded": False}`` with HTTP 500.
    """
    try:
        # Quick probe: the result is irrelevant, only whether the call works.
        classifier(
            "测试文本",
            candidate_labels=["食品", "电器"],
            truncation=True,
        )
    except Exception:
        # Catch Exception rather than everything so SystemExit and
        # KeyboardInterrupt still propagate, and record why the probe failed.
        logger.exception("健康检查失败")
        return jsonify({
            "status": "unhealthy",
            "model_loaded": False,
        }), 500

    return jsonify({
        "status": "healthy",
        "model_loaded": True,
    })
 | 
						|
 | 
						|
 | 
						|
if __name__ == '__main__':
    import os

    # Run the application on all interfaces, port 5000.
    #
    # Debug mode turns on the Werkzeug interactive debugger, which lets
    # anyone who can reach the port execute arbitrary code. Because the
    # server binds to 0.0.0.0, debug mode is opt-in through the FLASK_DEBUG
    # environment variable instead of being always on.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in {
        "1", "true", "yes", "on",
    }
    app.run(host='0.0.0.0', port=5000, debug=debug_enabled)