This commit is contained in:
parent
9e2e2e02a5
commit
49e6e2d29e
|
@ -0,0 +1,8 @@
|
|||
# Default ignored files
|
||||
/shelf/
|
||||
/workspace.xml
|
||||
# Editor-based HTTP Client requests
|
||||
/httpRequests/
|
||||
# Datasource local storage ignored files
|
||||
/dataSources/
|
||||
/dataSources.local.xml
|
|
@ -0,0 +1,14 @@
|
|||
<component name="InspectionProjectProfileManager">
|
||||
<profile version="1.0">
|
||||
<option name="myName" value="Project Default" />
|
||||
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
|
||||
<option name="ignoredPackages">
|
||||
<value>
|
||||
<list size="1">
|
||||
<item index="0" class="java.lang.String" itemvalue="betterproto" />
|
||||
</list>
|
||||
</value>
|
||||
</option>
|
||||
</inspection_tool>
|
||||
</profile>
|
||||
</component>
|
|
@ -0,0 +1,6 @@
|
|||
<component name="InspectionProjectProfileManager">
|
||||
<settings>
|
||||
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||
<version value="1.0" />
|
||||
</settings>
|
||||
</component>
|
|
@ -0,0 +1,7 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="Black">
|
||||
<option name="sdkName" value="/root/miniconda3" />
|
||||
</component>
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="311" project-jdk-type="Python SDK" />
|
||||
</project>
|
|
@ -0,0 +1,8 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectModuleManager">
|
||||
<modules>
|
||||
<module fileurl="file://$PROJECT_DIR$/.idea/select.iml" filepath="$PROJECT_DIR$/.idea/select.iml" />
|
||||
</modules>
|
||||
</component>
|
||||
</project>
|
|
@ -0,0 +1,8 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$" />
|
||||
<orderEntry type="jdk" jdkName="311" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
</module>
|
|
@ -0,0 +1,6 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="VcsDirectoryMappings">
|
||||
<mapping directory="$PROJECT_DIR$" vcs="Git" />
|
||||
</component>
|
||||
</project>
|
|
@ -0,0 +1,216 @@
|
|||
import pandas as pd
|
||||
from sklearn.model_selection import train_test_split
|
||||
from transformers import (
|
||||
BertTokenizer,
|
||||
BertForSequenceClassification,
|
||||
Trainer,
|
||||
TrainingArguments,
|
||||
EarlyStoppingCallback
|
||||
)
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from tqdm import tqdm
|
||||
import os
|
||||
import warnings
|
||||
import re # 用于正则表达式清洗
|
||||
from sklearn.preprocessing import LabelEncoder
|
||||
import joblib
|
||||
|
||||
# 1. 参数配置(集中管理)
|
||||
class Config:
|
||||
MODEL_NAME = "bert-base-chinese"
|
||||
MAX_LENGTH = 64
|
||||
BATCH_SIZE = 32
|
||||
NUM_EPOCHS = 5
|
||||
LEARNING_RATE = 2e-5
|
||||
WARMUP_STEPS = 500
|
||||
WEIGHT_DECAY = 0.01
|
||||
FP16 = torch.cuda.is_available()
|
||||
OUTPUT_DIR = "./results/single_level"
|
||||
LOG_DIR = "./logs"
|
||||
SAVE_DIR = "./saved_model/single_level"
|
||||
DEVICE = "cuda" if FP16 else "cpu"
|
||||
|
||||
|
||||
# 2. 数据加载与预处理(添加异常处理和日志)
|
||||
def load_data(file_path):
|
||||
try:
|
||||
df = pd.read_csv(file_path)
|
||||
assert {'sentence', 'label'}.issubset(df.columns), "数据必须包含'sentence'和'label'列"
|
||||
print(f"✅ 数据加载成功 | 样本量: {len(df)} | 分类数: {df['label'].nunique()}")
|
||||
return df
|
||||
except Exception as e:
|
||||
warnings.warn(f"❌ 数据加载失败: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
# 新增:数据清洗函数 - 只保留中文字符
|
||||
def clean_chinese_text(text):
|
||||
"""
|
||||
清洗文本,只保留中文字符
|
||||
"""
|
||||
if not isinstance(text, str):
|
||||
return ""
|
||||
# 使用正则表达式匹配所有中文字符(包括中文标点符号)[^\u4e00-\u9fa5\u3000-\u303f\uff00-\uffef]
|
||||
# 如果需要更严格的只保留汉字,可以使用:[\u4e00-\u9fa5]
|
||||
cleaned_text = re.sub(r'[^\u4e00-\u9fa5]', '', text)
|
||||
return cleaned_text.strip()
|
||||
|
||||
|
||||
# 3. 优化Dataset(添加内存缓存和批处理支持)
|
||||
class TextDataset(Dataset):
|
||||
def __init__(self, dataframe, tokenizer, text_col="sentence", label_col="label"):
|
||||
self.data = dataframe
|
||||
self.tokenizer = tokenizer
|
||||
self.text_col = text_col
|
||||
self.label_col = label_col
|
||||
|
||||
# 预计算编码(空间换时间)
|
||||
self.encodings = tokenizer(
|
||||
dataframe[text_col].tolist(),
|
||||
max_length=Config.MAX_LENGTH,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="pt"
|
||||
)
|
||||
self.labels = torch.tensor(dataframe[label_col].values, dtype=torch.long)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return {
|
||||
"input_ids": self.encodings["input_ids"][idx],
|
||||
"attention_mask": self.encodings["attention_mask"][idx],
|
||||
"labels": self.labels[idx]
|
||||
}
|
||||
|
||||
|
||||
# 4. 模型初始化(添加设备移动)
|
||||
def init_model(num_labels):
|
||||
tokenizer = BertTokenizer.from_pretrained(Config.MODEL_NAME)
|
||||
model = BertForSequenceClassification.from_pretrained(
|
||||
Config.MODEL_NAME,
|
||||
num_labels=num_labels,
|
||||
ignore_mismatched_sizes=True # 可选
|
||||
).to(Config.DEVICE)
|
||||
return tokenizer, model
|
||||
|
||||
|
||||
# 5. 训练配置(添加早停和梯度累积)
|
||||
def get_training_args():
|
||||
return TrainingArguments(
|
||||
output_dir=Config.OUTPUT_DIR,
|
||||
num_train_epochs=Config.NUM_EPOCHS,
|
||||
per_device_train_batch_size=Config.BATCH_SIZE,
|
||||
per_device_eval_batch_size=Config.BATCH_SIZE * 2, # 评估时可用更大batch
|
||||
learning_rate=Config.LEARNING_RATE,
|
||||
warmup_steps=Config.WARMUP_STEPS,
|
||||
weight_decay=Config.WEIGHT_DECAY,
|
||||
logging_dir=Config.LOG_DIR,
|
||||
logging_steps=10,
|
||||
eval_strategy="steps",
|
||||
eval_steps=100,
|
||||
save_strategy="steps",
|
||||
save_steps=200,
|
||||
load_best_model_at_end=True,
|
||||
metric_for_best_model="eval_loss",
|
||||
greater_is_better=False,
|
||||
fp16=Config.FP16,
|
||||
gradient_accumulation_steps=2, # 模拟更大batch
|
||||
report_to="none", # 禁用wandb等报告
|
||||
seed=42
|
||||
)
|
||||
|
||||
|
||||
# 6. 优化推理函数(添加批处理支持)
|
||||
@torch.no_grad()
|
||||
def batch_predict(texts, model, tokenizer, label_map, top_k=1, batch_size=16):
|
||||
model.eval()
|
||||
all_results = []
|
||||
|
||||
for i in tqdm(range(0, len(texts), batch_size), desc="预测中"):
|
||||
batch = texts[i:i + batch_size]
|
||||
inputs = tokenizer(
|
||||
batch,
|
||||
return_tensors="pt",
|
||||
truncation=True,
|
||||
padding=True,
|
||||
max_length=Config.MAX_LENGTH
|
||||
).to(Config.DEVICE)
|
||||
|
||||
outputs = model(**inputs)
|
||||
probs = torch.softmax(outputs.logits, dim=1).cpu()
|
||||
|
||||
for prob in probs:
|
||||
top_probs, top_indices = torch.topk(prob, k=top_k)
|
||||
all_results.extend([
|
||||
{
|
||||
"category": label_map[idx.item()],
|
||||
"confidence": prob.item()
|
||||
}
|
||||
for prob, idx in zip(top_probs, top_indices)
|
||||
])
|
||||
|
||||
return all_results[:len(texts)] # 处理非整除情况
|
||||
|
||||
|
||||
# 主流程
|
||||
if __name__ == "__main__":
|
||||
# 1. 加载数据
|
||||
df = load_data("goods_cate.csv")
|
||||
|
||||
# 2. 数据清洗 - 只保留中文
|
||||
print("🧼 开始清洗文本数据...")
|
||||
df['sentence'] = df['sentence'].apply(clean_chinese_text)
|
||||
df = df[df['sentence'].str.len() > 0].reset_index(drop=True)
|
||||
print(f"✅ 数据清洗完成 | 剩余样本量: {len(df)}")
|
||||
|
||||
# 3. 处理中文标签:映射为数值ID,并保存映射关系
|
||||
print("🏷️ 处理中文标签...")
|
||||
label_encoder = LabelEncoder()
|
||||
df['label_id'] = label_encoder.fit_transform(df['label']) # 中文标签 → 数值ID
|
||||
label_map = {i: label for i, label in enumerate(label_encoder.classes_)} # 数值ID → 中文标签
|
||||
print(f"标签映射示例: {label_map}")
|
||||
|
||||
# 保存标签映射器(供推理时使用)
|
||||
joblib.dump(label_encoder, "cate/label_encoder.pkl")
|
||||
print(f"✅ 标签映射完成 | 类别数: {len(label_map)}")
|
||||
|
||||
# 4. 划分数据集(使用 label_id 列)
|
||||
train_df, test_df = train_test_split(
|
||||
df, test_size=0.2, random_state=42, stratify=df["label_id"] # 注意这里用 label_id
|
||||
)
|
||||
|
||||
# 5. 初始化模型(使用数值标签的数量)
|
||||
num_labels = len(label_map)
|
||||
tokenizer, model = init_model(num_labels)
|
||||
|
||||
# 6. 准备数据集(使用 label_id 列)
|
||||
train_dataset = TextDataset(train_df, tokenizer, label_col="label_id") # 指定 label_col
|
||||
test_dataset = TextDataset(test_df, tokenizer, label_col="label_id")
|
||||
|
||||
# 7. 训练配置(保持不变)
|
||||
training_args = get_training_args()
|
||||
|
||||
# 8. 训练器(保持不变)
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=test_dataset,
|
||||
callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
|
||||
)
|
||||
|
||||
# 9. 训练和保存(保持不变)
|
||||
trainer.train()
|
||||
model.save_pretrained(Config.SAVE_DIR)
|
||||
tokenizer.save_pretrained(Config.SAVE_DIR)
|
||||
# 12. 测试推理
|
||||
test_samples = ["不二家棒棒糖", "iPhone 15", "无线鼠标"]
|
||||
# 先清洗测试样本
|
||||
cleaned_samples = [clean_chinese_text(s) for s in test_samples]
|
||||
predictions = batch_predict(cleaned_samples, model, tokenizer, label_map)
|
||||
for sample, pred in zip(test_samples, predictions):
|
||||
print(
|
||||
f"输入: {sample}\n清洗后: {clean_chinese_text(sample)}\n预测: {pred['category']} (置信度: {pred['confidence']:.2f})\n")
|
|
@ -10,7 +10,7 @@ logger = logging.getLogger(__name__)
|
|||
try:
|
||||
classifier = pipeline(
|
||||
"zero-shot-classification",
|
||||
model="./nlp_structbert_zero-shot-classification_chinese-large",
|
||||
model="./nlp_structbert_zero-shot-classification_chinese-base",
|
||||
device="cpu",
|
||||
max_length=512,
|
||||
ignore_mismatched_sizes=True
|
||||
|
@ -20,10 +20,6 @@ except Exception as e:
|
|||
logger.error(f"模型加载失败: {str(e)}")
|
||||
raise
|
||||
|
||||
# 定义类别标签
|
||||
level1 = [
|
||||
"食品", "电器", "洗护", "女装", "手机","健康", "男装", "美妆", "电脑", "运动","内衣", "母婴", "数码", "百货", "鞋包","办公", "家装", "饰品", "车品", "图书","生鲜", "家纺", "宠物", "奢品", "其它", "药品"
|
||||
]
|
||||
|
||||
# 创建Flask应用
|
||||
app = Flask(__name__)
|
||||
|
@ -49,12 +45,14 @@ def classify_text():
|
|||
truncation=True
|
||||
)
|
||||
|
||||
labels = result['labels']
|
||||
|
||||
print(result)
|
||||
print(result['scores'][-1])
|
||||
|
||||
# 构建响应
|
||||
response = {
|
||||
"cate": labels[0],
|
||||
"cate": result['labels'][-1],
|
||||
"score": result['scores'][-1],
|
||||
}
|
||||
return jsonify(response)
|
||||
|
|
@ -0,0 +1,129 @@
|
|||
from flask import Flask, request, jsonify
|
||||
import torch
|
||||
from transformers import BertTokenizer, BertForSequenceClassification
|
||||
import joblib
|
||||
import re
|
||||
from functools import lru_cache
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import List, Dict
|
||||
import threading
|
||||
|
||||
# 配置参数
|
||||
MAX_LENGTH = 512
|
||||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
SAVE_DIR = "./cate1" # 模型路径
|
||||
LABEL_ENCODER_PATH = "cate1/label_encoder.pkl" # 标签映射器路径
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
# 全局变量锁
|
||||
model_lock = threading.Lock()
|
||||
|
||||
|
||||
class Predictor:
|
||||
"""预测器类,用于管理模型和分词器的生命周期"""
|
||||
_instance = None
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
with model_lock:
|
||||
if cls._instance is None: # 双重检查锁定
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
# 加载模型和分词器
|
||||
self.tokenizer = BertTokenizer.from_pretrained(SAVE_DIR)
|
||||
self.model = BertForSequenceClassification.from_pretrained(SAVE_DIR).to(DEVICE)
|
||||
self.model.eval()
|
||||
|
||||
# 加载标签映射器
|
||||
self.label_encoder = joblib.load(LABEL_ENCODER_PATH)
|
||||
self.label_map = {i: label for i, label in enumerate(self.label_encoder.classes_)}
|
||||
|
||||
# 创建线程池
|
||||
self.executor = ThreadPoolExecutor(max_workers=4)
|
||||
self._initialized = True
|
||||
|
||||
@lru_cache(maxsize=1000)
|
||||
def clean_text(self, text: str) -> str:
|
||||
"""文本清洗函数,带缓存"""
|
||||
if not isinstance(text, str):
|
||||
return ""
|
||||
cleaned_text = re.sub(r'[^\u4e00-\u9fa5]', '', text)
|
||||
return cleaned_text.strip()
|
||||
|
||||
def predict_single(self, text: str) -> Dict:
|
||||
"""单个文本预测"""
|
||||
with torch.no_grad():
|
||||
cleaned_text = self.clean_text(text)
|
||||
if not cleaned_text:
|
||||
return {"category": "", "confidence": 0.0}
|
||||
|
||||
inputs = self.tokenizer(
|
||||
cleaned_text,
|
||||
return_tensors="pt",
|
||||
truncation=True,
|
||||
padding=True,
|
||||
max_length=MAX_LENGTH
|
||||
).to(DEVICE)
|
||||
|
||||
outputs = self.model(**inputs)
|
||||
probs = torch.softmax(outputs.logits, dim=1).cpu()
|
||||
top_prob, top_idx = torch.topk(probs, k=1)
|
||||
|
||||
#去掉置信度,速度提升10%-30%
|
||||
# top_idx = torch.argmax(outputs.logits, dim=1)
|
||||
return {
|
||||
"category": self.label_map[top_idx.item()],
|
||||
"confidence": top_prob.item()
|
||||
}
|
||||
|
||||
def batch_predict(self, texts: List[str]) -> List[Dict]:
|
||||
"""批量预测(并发处理)"""
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
# 使用线程池并发处理
|
||||
futures = [self.executor.submit(self.predict_single, text) for text in texts]
|
||||
return [future.result() for future in futures]
|
||||
|
||||
|
||||
# 初始化预测器
|
||||
predictor = Predictor()
|
||||
|
||||
|
||||
@app.route('/predict', methods=['POST'])
|
||||
def predict():
|
||||
"""预测接口"""
|
||||
data = request.get_json()
|
||||
if not data or 'product_names' not in data:
|
||||
return jsonify({"error": "Invalid request, 'product_names' is required"}), 400
|
||||
|
||||
product_names = data['product_names']
|
||||
if not isinstance(product_names, list):
|
||||
product_names = [product_names]
|
||||
|
||||
# 批量预测
|
||||
results = predictor.batch_predict(product_names)
|
||||
|
||||
return jsonify({
|
||||
"status": "success",
|
||||
"predictions": results
|
||||
})
|
||||
|
||||
|
||||
@app.route('/health', methods=['GET'])
|
||||
def health_check():
|
||||
"""健康检查接口"""
|
||||
return jsonify({"status": "healthy"})
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 生产环境建议使用:
|
||||
# gunicorn -w 4 -b 0.0.0.0:5000 app:app --timeout 120
|
||||
app.run(host='0.0.0.0', port=5002, threaded=True)
|
Loading…
Reference in New Issue