bert_cate/app.py

131 lines
3.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from flask import Flask, request, jsonify
import torch
from transformers import BertTokenizer, BertForSequenceClassification
import joblib
import re
from functools import lru_cache
from concurrent.futures import ThreadPoolExecutor
from typing import List, Dict
import threading
# Configuration
MAX_LENGTH = 512  # maximum token length passed to the tokenizer
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # prefer GPU when present
SAVE_DIR = "./cate1"  # directory holding the fine-tuned model and tokenizer
LABEL_ENCODER_PATH = "cate1/label_encoder.pkl"  # path of the pickled label encoder
app = Flask(__name__)
# Return non-ASCII characters (Chinese category names) verbatim instead of
# \uXXXX escapes.  The JSON_AS_ASCII config key was removed in Flask 2.3 in
# favour of the JSON provider attribute, so use the provider when it exists
# and fall back to the legacy key on older Flask releases.
if hasattr(app, "json"):
    app.json.ensure_ascii = False
else:
    app.config['JSON_AS_ASCII'] = False
# Module-wide lock guarding creation/initialisation of the shared predictor.
model_lock = threading.Lock()
class Predictor:
    """Singleton that owns the model, tokenizer, label map and worker pool.

    Every ``Predictor()`` call returns the same instance; the model files are
    loaded only the first time the instance is initialised.
    """

    _instance = None
    # Matches every character outside the range U+4E00..U+9FA5; compiled once.
    _NON_CHINESE_RE = re.compile(r'[^\u4e00-\u9fa5]')

    def __new__(cls):
        """Return the shared instance, creating it under ``model_lock`` if needed."""
        if cls._instance is None:
            with model_lock:
                if cls._instance is None:  # re-check after acquiring the lock
                    instance = super().__new__(cls)
                    # Set the flag before publishing the instance so that
                    # __init__ never sees an object without it.
                    instance._initialized = False
                    cls._instance = instance
        return cls._instance

    def __init__(self):
        """Load tokenizer, model and label encoder once; later calls are no-ops."""
        if self._initialized:
            return
        # __new__ has already released model_lock, so taking it here cannot
        # deadlock; it stops two threads from loading the model concurrently.
        with model_lock:
            if self._initialized:
                return
            # Load the model and tokenizer.
            self.tokenizer = BertTokenizer.from_pretrained(SAVE_DIR)
            self.model = BertForSequenceClassification.from_pretrained(SAVE_DIR).to(DEVICE)
            self.model.eval()
            # Load the label encoder and build an index -> label lookup.
            # NOTE(review): joblib.load unpickles the file; only load trusted artefacts.
            self.label_encoder = joblib.load(LABEL_ENCODER_PATH)
            self.label_map = dict(enumerate(self.label_encoder.classes_))
            # Thread pool used by batch_predict.
            self.executor = ThreadPoolExecutor(max_workers=4)
            self._initialized = True

    @staticmethod
    @lru_cache(maxsize=1000)
    def _strip_non_chinese(text: str) -> str:
        """Return ``text`` with every character outside U+4E00..U+9FA5 removed (cached)."""
        # Whitespace is outside the kept range, so no extra strip() is needed.
        return Predictor._NON_CHINESE_RE.sub('', text)

    def clean_text(self, text: str) -> str:
        """Clean ``text`` for the model, keeping only characters in U+4E00..U+9FA5.

        Args:
            text: Raw input; may be any JSON-decoded value.

        Returns:
            The cleaned string, or ``""`` when ``text`` is not a ``str``.
        """
        # Type-check BEFORE touching the cache: lru_cache hashes its arguments,
        # so unhashable inputs (list/dict) would raise TypeError otherwise.
        if not isinstance(text, str):
            return ""
        return self._strip_non_chinese(text)

    def predict_single(self, text: str) -> Dict:
        """Classify one text.

        Args:
            text: Raw input text.

        Returns:
            ``{"category": <label>, "confidence": <top softmax probability>}``;
            ``{"category": "", "confidence": 0.0}`` when nothing is left after
            cleaning.
        """
        cleaned_text = self.clean_text(text)
        if not cleaned_text:
            return {"category": "", "confidence": 0.0}
        with torch.no_grad():  # inference only, no gradients needed
            inputs = self.tokenizer(
                cleaned_text,
                return_tensors="pt",
                truncation=True,
                padding=True,
                max_length=MAX_LENGTH,
            ).to(DEVICE)
            outputs = self.model(**inputs)
            probs = torch.softmax(outputs.logits, dim=1).cpu()
            top_prob, top_idx = torch.topk(probs, k=1)
            # NOTE (original author): dropping the confidence and using
            # torch.argmax(outputs.logits, dim=1) is reported 10%-30% faster.
        return {
            "category": self.label_map[top_idx.item()],
            "confidence": top_prob.item(),
        }

    def batch_predict(self, texts: List[str]) -> List[Dict]:
        """Classify several texts concurrently on the instance thread pool.

        Args:
            texts: Inputs to classify.

        Returns:
            One result dict per input, in input order; ``[]`` for empty input.
        """
        if not texts:
            return []
        # map() submits every item up front and yields results in input order.
        return list(self.executor.map(self.predict_single, texts))
# Create the shared predictor at import time; constructing it loads the
# tokenizer, model and label encoder, so they are ready before any request.
predictor = Predictor()
@app.route('/predict', methods=['POST'])
def predict():
    """Prediction endpoint.

    Expects a JSON object with a ``product_names`` key holding either a single
    value or a list of values.

    Returns:
        200 with ``{"status": "success", "predictions": [...]}`` on success, or
        400 with an ``error`` message when the body is not a JSON object
        containing ``product_names``.
    """
    # silent=True yields None for a missing/malformed JSON body, so every bad
    # request gets the same JSON error below.
    data = request.get_json(silent=True)
    # Require a JSON object: a list body such as ["product_names"] would pass
    # a bare `in` test and then crash on the key lookup.
    if not isinstance(data, dict) or 'product_names' not in data:
        return jsonify({"error": "Invalid request, 'product_names' is required"}), 400
    product_names = data['product_names']
    if not isinstance(product_names, list):
        product_names = [product_names]  # accept a single value as a batch of one
    # Run the batch prediction.
    results = predictor.batch_predict(product_names)
    return jsonify({
        "status": "success",
        "predictions": results,
    })
@app.route('/health', methods=['GET'])
def health_check():
    """Health-check endpoint: always answers with a fixed "healthy" status."""
    payload = {"status": "healthy"}
    return jsonify(payload)
if __name__ != '__main__':
    # Imported as a module: expose the app under the conventional WSGI name
    # `application` (e.g. for Gunicorn).
    application = app
else:
    # Executed directly: start Flask's built-in threaded server.
    app.run(host='0.0.0.0', port=5000, threaded=True)