This commit is contained in:
commit
faeef3050e
|
@ -0,0 +1,22 @@
|
|||
# 使用官方 Python 基础镜像
|
||||
FROM python:3.8-slim
|
||||
|
||||
# 设置工作目录
|
||||
WORKDIR /app
|
||||
|
||||
# 安装系统依赖(如果需要)
|
||||
RUN apt-get update && apt-get install -y \
|
||||
build-essential \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# 复制项目文件
|
||||
COPY . /app
|
||||
|
||||
# 安装 Python 依赖
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# 暴露端口
|
||||
EXPOSE 5000
|
||||
|
||||
# 启动命令(使用 Gunicorn 替代 Flask 开发服务器)
|
||||
#CMD ["gunicorn", "--bind", "0.0.0.0:5000", "--workers", "4", "--timeout", "300", "app:app"]
|
|
@ -0,0 +1,89 @@
|
|||
from flask import Flask, request, jsonify
|
||||
from transformers import pipeline
|
||||
import logging
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 初始化模型
|
||||
try:
|
||||
classifier = pipeline(
|
||||
"zero-shot-classification",
|
||||
model="./nlp_structbert_zero-shot-classification_chinese-large",
|
||||
device="cpu",
|
||||
max_length=512,
|
||||
ignore_mismatched_sizes=True
|
||||
)
|
||||
logger.info("模型加载成功")
|
||||
except Exception as e:
|
||||
logger.error(f"模型加载失败: {str(e)}")
|
||||
raise
|
||||
|
||||
# 定义类别标签
|
||||
level1 = [
|
||||
"食品", "电器", "洗护", "女装", "手机","健康", "男装", "美妆", "电脑", "运动","内衣", "母婴", "数码", "百货", "鞋包","办公", "家装", "饰品", "车品", "图书","生鲜", "家纺", "宠物", "奢品", "其它", "药品"
|
||||
]
|
||||
|
||||
# 创建Flask应用
|
||||
app = Flask(__name__)
|
||||
app.config['JSON_AS_ASCII'] = False # 确保中文能正常显示
|
||||
|
||||
|
||||
@app.route('/best', methods=['POST'])
|
||||
def classify_text():
|
||||
# 获取请求数据
|
||||
data = request.get_json()
|
||||
if not data or 'text' not in data:
|
||||
return jsonify({"error": "缺少text参数"}), 400
|
||||
if not data or 'cate' not in data:
|
||||
return jsonify({"error": "缺少cate参数"}), 400
|
||||
text = data['text']
|
||||
cate = data['cate']
|
||||
|
||||
try:
|
||||
# 执行分类
|
||||
result = classifier(
|
||||
text,
|
||||
candidate_labels=cate,
|
||||
truncation=True
|
||||
)
|
||||
|
||||
labels = result['labels']
|
||||
|
||||
|
||||
# 构建响应
|
||||
response = {
|
||||
"cate": labels[0],
|
||||
}
|
||||
return jsonify(response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"分类过程中出错: {str(e)}")
|
||||
return jsonify({"error": str(e)}), 500
|
||||
|
||||
|
||||
@app.route('/health')
|
||||
def health_check():
|
||||
"""健康检查端点"""
|
||||
try:
|
||||
# 简单测试模型是否可用
|
||||
test_result = classifier(
|
||||
"测试文本",
|
||||
candidate_labels=["食品", "电器"],
|
||||
truncation=True
|
||||
)
|
||||
return jsonify({
|
||||
"status": "healthy",
|
||||
"model_loaded": True
|
||||
})
|
||||
except:
|
||||
return jsonify({
|
||||
"status": "unhealthy",
|
||||
"model_loaded": False
|
||||
}), 500
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 运行应用
|
||||
app.run(host='0.0.0.0', port=5000, debug=True)
|
|
@ -0,0 +1,32 @@
|
|||
from transformers import pipeline
|
||||
|
||||
|
||||
classifier = pipeline(
|
||||
"zero-shot-classification",
|
||||
model="./nlp_structbert_zero-shot-classification_chinese-large",
|
||||
device="cpu",
|
||||
max_length=512,
|
||||
ignore_mismatched_sizes=True # 忽略维度不匹配警告
|
||||
)
|
||||
level1 = [
|
||||
"食品", "电器", "洗护", "女装", "手机",
|
||||
"健康", "男装", "美妆", "电脑", "运动",
|
||||
"内衣", "母婴", "数码", "百货", "鞋包",
|
||||
"办公", "家装", "饰品", "车品", "图书",
|
||||
"生鲜", "家纺", "宠物", "奢品", "其它","药品"
|
||||
]
|
||||
|
||||
|
||||
def theBestAndLow(goods):
|
||||
result = classifier(
|
||||
goods,
|
||||
candidate_labels=level1,
|
||||
truncation=True # 显式指定截断策略
|
||||
)
|
||||
labels = result['labels']
|
||||
scores = result['scores']
|
||||
|
||||
print("最高分标签:", labels[0], "得分:", scores[0])
|
||||
print("最低分标签:", labels[-1], "得分:", scores[-1])
|
||||
|
||||
theBestAndLow( "宇宙超萌儿童辅食有机果蔬蝴蝶面210g 数量:1盒装:")
|
|
@ -0,0 +1,27 @@
|
|||
from transformers import pipeline
|
||||
|
||||
|
||||
classifier = pipeline(
|
||||
"zero-shot-classification",
|
||||
model="./nlp_structbert_zero-shot-classification_chinese-base",
|
||||
device="cpu",
|
||||
|
||||
max_length=512,
|
||||
ignore_mismatched_sizes=True # 忽略维度不匹配警告
|
||||
)
|
||||
level1 = [
|
||||
"食品", "电器", "洗护", "女装", "手机",
|
||||
"健康", "男装", "美妆", "电脑", "运动",
|
||||
"内衣", "母婴", "数码", "百货", "鞋包",
|
||||
"办公", "家装", "饰品", "车品", "图书",
|
||||
"生鲜", "家纺", "宠物", "奢品", "其它"
|
||||
]
|
||||
result = classifier(
|
||||
"CONBA/康恩贝森兰康牌铁皮石斛西洋参颗粒 规格:3g/包*60包:",
|
||||
candidate_labels=level1,
|
||||
truncation=True # 显式指定截断策略
|
||||
)
|
||||
|
||||
|
||||
|
||||
print(result)
|
|
@ -0,0 +1,32 @@
|
|||
# 基础依赖
|
||||
Flask==3.1.0
|
||||
gunicorn==20.0.4 # 生产环境推荐
|
||||
transformers==4.26.0
|
||||
|
||||
# 数据处理
|
||||
numpy==2.0.1
|
||||
pandas==2.3.1
|
||||
|
||||
# 网络与请求
|
||||
requests==2.32.4
|
||||
urllib3==2.5.0
|
||||
|
||||
# 模型相关
|
||||
huggingface-hub==0.34.3
|
||||
tokenizers==0.13.3
|
||||
safetensors==0.5.3
|
||||
|
||||
# 工具库
|
||||
click==8.1.8
|
||||
Jinja2==3.1.6
|
||||
MarkupSafe==3.0.2
|
||||
itsdangerous==2.2.0
|
||||
Werkzeug==3.1.3
|
||||
pyyaml==6.0.2
|
||||
tqdm==4.67.1
|
||||
simplejson==3.20.1
|
||||
|
||||
|
||||
torch==2.7.1 # 如果使用GPU或需要特定版本
|
||||
datasets==3.3.2 # 如果需要数据集处理
|
||||
pillow==11.3.0 # 如果涉及图像处理
|
Loading…
Reference in New Issue