diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..cb0f56f
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,14 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..9b6e139
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,7 @@
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..63455ab
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/select.iml b/.idea/select.iml
new file mode 100644
index 0000000..244c072
--- /dev/null
+++ b/.idea/select.iml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..94a25f7
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/bert_train.py b/bert_train.py
new file mode 100644
index 0000000..c5ecf6b
--- /dev/null
+++ b/bert_train.py
@@ -0,0 +1,216 @@
+import pandas as pd
+from sklearn.model_selection import train_test_split
+from transformers import (
+ BertTokenizer,
+ BertForSequenceClassification,
+ Trainer,
+ TrainingArguments,
+ EarlyStoppingCallback
+)
+import torch
+from torch.utils.data import Dataset
+from tqdm import tqdm
+import os
+import warnings
+import re # 用于正则表达式清洗
+from sklearn.preprocessing import LabelEncoder
+import joblib
+
+# 1. 参数配置(集中管理)
+class Config:
+ MODEL_NAME = "bert-base-chinese"
+ MAX_LENGTH = 64
+ BATCH_SIZE = 32
+ NUM_EPOCHS = 5
+ LEARNING_RATE = 2e-5
+ WARMUP_STEPS = 500
+ WEIGHT_DECAY = 0.01
+ FP16 = torch.cuda.is_available()
+ OUTPUT_DIR = "./results/single_level"
+ LOG_DIR = "./logs"
+ SAVE_DIR = "./saved_model/single_level"
+ DEVICE = "cuda" if FP16 else "cpu"
+
+
+# 2. 数据加载与预处理(添加异常处理和日志)
+def load_data(file_path):
+ try:
+ df = pd.read_csv(file_path)
+ assert {'sentence', 'label'}.issubset(df.columns), "数据必须包含'sentence'和'label'列"
+ print(f"✅ 数据加载成功 | 样本量: {len(df)} | 分类数: {df['label'].nunique()}")
+ return df
+ except Exception as e:
+ warnings.warn(f"❌ 数据加载失败: {str(e)}")
+ raise
+
+
+# 新增:数据清洗函数 - 只保留中文字符
+def clean_chinese_text(text):
+ """
+ 清洗文本,只保留中文字符
+ """
+ if not isinstance(text, str):
+ return ""
+ # 使用正则表达式匹配所有中文字符(包括中文标点符号)[^\u4e00-\u9fa5\u3000-\u303f\uff00-\uffef]
+ # 如果需要更严格的只保留汉字,可以使用:[\u4e00-\u9fa5]
+ cleaned_text = re.sub(r'[^\u4e00-\u9fa5]', '', text)
+ return cleaned_text.strip()
+
+
+# 3. 优化Dataset(添加内存缓存和批处理支持)
+class TextDataset(Dataset):
+ def __init__(self, dataframe, tokenizer, text_col="sentence", label_col="label"):
+ self.data = dataframe
+ self.tokenizer = tokenizer
+ self.text_col = text_col
+ self.label_col = label_col
+
+ # 预计算编码(空间换时间)
+ self.encodings = tokenizer(
+ dataframe[text_col].tolist(),
+ max_length=Config.MAX_LENGTH,
+ padding="max_length",
+ truncation=True,
+ return_tensors="pt"
+ )
+ self.labels = torch.tensor(dataframe[label_col].values, dtype=torch.long)
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, idx):
+ return {
+ "input_ids": self.encodings["input_ids"][idx],
+ "attention_mask": self.encodings["attention_mask"][idx],
+ "labels": self.labels[idx]
+ }
+
+
+# 4. 模型初始化(添加设备移动)
+def init_model(num_labels):
+ tokenizer = BertTokenizer.from_pretrained(Config.MODEL_NAME)
+ model = BertForSequenceClassification.from_pretrained(
+ Config.MODEL_NAME,
+ num_labels=num_labels,
+ ignore_mismatched_sizes=True # 可选
+ ).to(Config.DEVICE)
+ return tokenizer, model
+
+
+# 5. 训练配置(添加早停和梯度累积)
+def get_training_args():
+ return TrainingArguments(
+ output_dir=Config.OUTPUT_DIR,
+ num_train_epochs=Config.NUM_EPOCHS,
+ per_device_train_batch_size=Config.BATCH_SIZE,
+ per_device_eval_batch_size=Config.BATCH_SIZE * 2, # 评估时可用更大batch
+ learning_rate=Config.LEARNING_RATE,
+ warmup_steps=Config.WARMUP_STEPS,
+ weight_decay=Config.WEIGHT_DECAY,
+ logging_dir=Config.LOG_DIR,
+ logging_steps=10,
+ eval_strategy="steps",
+ eval_steps=100,
+ save_strategy="steps",
+ save_steps=200,
+ load_best_model_at_end=True,
+ metric_for_best_model="eval_loss",
+ greater_is_better=False,
+ fp16=Config.FP16,
+ gradient_accumulation_steps=2, # 模拟更大batch
+ report_to="none", # 禁用wandb等报告
+ seed=42
+ )
+
+
+# 6. 优化推理函数(添加批处理支持)
+@torch.no_grad()
+def batch_predict(texts, model, tokenizer, label_map, top_k=1, batch_size=16):
+ model.eval()
+ all_results = []
+
+ for i in tqdm(range(0, len(texts), batch_size), desc="预测中"):
+ batch = texts[i:i + batch_size]
+ inputs = tokenizer(
+ batch,
+ return_tensors="pt",
+ truncation=True,
+ padding=True,
+ max_length=Config.MAX_LENGTH
+ ).to(Config.DEVICE)
+
+ outputs = model(**inputs)
+ probs = torch.softmax(outputs.logits, dim=1).cpu()
+
+ for prob in probs:
+ top_probs, top_indices = torch.topk(prob, k=top_k)
+ all_results.extend([
+ {
+ "category": label_map[idx.item()],
+ "confidence": prob.item()
+ }
+ for prob, idx in zip(top_probs, top_indices)
+ ])
+
+ return all_results[:len(texts)] # 处理非整除情况
+
+
+# 主流程
+if __name__ == "__main__":
+ # 1. 加载数据
+ df = load_data("goods_cate.csv")
+
+ # 2. 数据清洗 - 只保留中文
+ print("🧼 开始清洗文本数据...")
+ df['sentence'] = df['sentence'].apply(clean_chinese_text)
+ df = df[df['sentence'].str.len() > 0].reset_index(drop=True)
+ print(f"✅ 数据清洗完成 | 剩余样本量: {len(df)}")
+
+ # 3. 处理中文标签:映射为数值ID,并保存映射关系
+ print("🏷️ 处理中文标签...")
+ label_encoder = LabelEncoder()
+ df['label_id'] = label_encoder.fit_transform(df['label']) # 中文标签 → 数值ID
+ label_map = {i: label for i, label in enumerate(label_encoder.classes_)} # 数值ID → 中文标签
+ print(f"标签映射示例: {label_map}")
+
+ # 保存标签映射器(供推理时使用)
+ joblib.dump(label_encoder, "cate/label_encoder.pkl")
+ print(f"✅ 标签映射完成 | 类别数: {len(label_map)}")
+
+ # 4. 划分数据集(使用 label_id 列)
+ train_df, test_df = train_test_split(
+ df, test_size=0.2, random_state=42, stratify=df["label_id"] # 注意这里用 label_id
+ )
+
+ # 5. 初始化模型(使用数值标签的数量)
+ num_labels = len(label_map)
+ tokenizer, model = init_model(num_labels)
+
+ # 6. 准备数据集(使用 label_id 列)
+ train_dataset = TextDataset(train_df, tokenizer, label_col="label_id") # 指定 label_col
+ test_dataset = TextDataset(test_df, tokenizer, label_col="label_id")
+
+ # 7. 训练配置(保持不变)
+ training_args = get_training_args()
+
+ # 8. 训练器(保持不变)
+ trainer = Trainer(
+ model=model,
+ args=training_args,
+ train_dataset=train_dataset,
+ eval_dataset=test_dataset,
+ callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
+ )
+
+ # 9. 训练和保存(保持不变)
+ trainer.train()
+ model.save_pretrained(Config.SAVE_DIR)
+ tokenizer.save_pretrained(Config.SAVE_DIR)
+ # 12. 测试推理
+ test_samples = ["不二家棒棒糖", "iPhone 15", "无线鼠标"]
+ # 先清洗测试样本
+ cleaned_samples = [clean_chinese_text(s) for s in test_samples]
+ predictions = batch_predict(cleaned_samples, model, tokenizer, label_map)
+ for sample, pred in zip(test_samples, predictions):
+ print(
+ f"输入: {sample}\n清洗后: {clean_chinese_text(sample)}\n预测: {pred['category']} (置信度: {pred['confidence']:.2f})\n")
\ No newline at end of file
diff --git a/app.py b/other/app.py
similarity index 82%
rename from app.py
rename to other/app.py
index 022b248..e5ebff3 100644
--- a/app.py
+++ b/other/app.py
@@ -10,7 +10,7 @@ logger = logging.getLogger(__name__)
try:
classifier = pipeline(
"zero-shot-classification",
- model="./nlp_structbert_zero-shot-classification_chinese-large",
+ model="./nlp_structbert_zero-shot-classification_chinese-base",
device="cpu",
max_length=512,
ignore_mismatched_sizes=True
@@ -20,10 +20,6 @@ except Exception as e:
logger.error(f"模型加载失败: {str(e)}")
raise
-# 定义类别标签
-level1 = [
- "食品", "电器", "洗护", "女装", "手机","健康", "男装", "美妆", "电脑", "运动","内衣", "母婴", "数码", "百货", "鞋包","办公", "家装", "饰品", "车品", "图书","生鲜", "家纺", "宠物", "奢品", "其它", "药品"
-]
# 创建Flask应用
app = Flask(__name__)
@@ -49,12 +45,14 @@ def classify_text():
truncation=True
)
- labels = result['labels']
+ print(result)
+ print(result['scores'][-1])
# 构建响应
response = {
- "cate": labels[0],
+ "cate": result['labels'][-1],
+ "score": result['scores'][-1],
}
return jsonify(response)
diff --git a/test_my_mode.py b/test_my_mode.py
new file mode 100644
index 0000000..44bc30d
--- /dev/null
+++ b/test_my_mode.py
@@ -0,0 +1,129 @@
+from flask import Flask, request, jsonify
+import torch
+from transformers import BertTokenizer, BertForSequenceClassification
+import joblib
+import re
+from functools import lru_cache
+from concurrent.futures import ThreadPoolExecutor
+from typing import List, Dict
+import threading
+
+# 配置参数
+MAX_LENGTH = 512
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+SAVE_DIR = "./cate1" # 模型路径
+LABEL_ENCODER_PATH = "cate1/label_encoder.pkl" # 标签映射器路径
+
+app = Flask(__name__)
+
+# 全局变量锁
+model_lock = threading.Lock()
+
+
+class Predictor:
+ """预测器类,用于管理模型和分词器的生命周期"""
+ _instance = None
+
+ def __new__(cls):
+ if cls._instance is None:
+ with model_lock:
+ if cls._instance is None: # 双重检查锁定
+ cls._instance = super().__new__(cls)
+ cls._instance._initialized = False
+ return cls._instance
+
+ def __init__(self):
+ if self._initialized:
+ return
+
+ # 加载模型和分词器
+ self.tokenizer = BertTokenizer.from_pretrained(SAVE_DIR)
+ self.model = BertForSequenceClassification.from_pretrained(SAVE_DIR).to(DEVICE)
+ self.model.eval()
+
+ # 加载标签映射器
+ self.label_encoder = joblib.load(LABEL_ENCODER_PATH)
+ self.label_map = {i: label for i, label in enumerate(self.label_encoder.classes_)}
+
+ # 创建线程池
+ self.executor = ThreadPoolExecutor(max_workers=4)
+ self._initialized = True
+
+ @lru_cache(maxsize=1000)
+ def clean_text(self, text: str) -> str:
+ """文本清洗函数,带缓存"""
+ if not isinstance(text, str):
+ return ""
+ cleaned_text = re.sub(r'[^\u4e00-\u9fa5]', '', text)
+ return cleaned_text.strip()
+
+ def predict_single(self, text: str) -> Dict:
+ """单个文本预测"""
+ with torch.no_grad():
+ cleaned_text = self.clean_text(text)
+ if not cleaned_text:
+ return {"category": "", "confidence": 0.0}
+
+ inputs = self.tokenizer(
+ cleaned_text,
+ return_tensors="pt",
+ truncation=True,
+ padding=True,
+ max_length=MAX_LENGTH
+ ).to(DEVICE)
+
+ outputs = self.model(**inputs)
+ probs = torch.softmax(outputs.logits, dim=1).cpu()
+ top_prob, top_idx = torch.topk(probs, k=1)
+
+ #去掉置信度,速度提升10%-30%
+ # top_idx = torch.argmax(outputs.logits, dim=1)
+ return {
+ "category": self.label_map[top_idx.item()],
+ "confidence": top_prob.item()
+ }
+
+ def batch_predict(self, texts: List[str]) -> List[Dict]:
+ """批量预测(并发处理)"""
+ if not texts:
+ return []
+
+ # 使用线程池并发处理
+ futures = [self.executor.submit(self.predict_single, text) for text in texts]
+ return [future.result() for future in futures]
+
+
+# 初始化预测器
+predictor = Predictor()
+
+
+@app.route('/predict', methods=['POST'])
+def predict():
+ """预测接口"""
+ data = request.get_json()
+ if not data or 'product_names' not in data:
+ return jsonify({"error": "Invalid request, 'product_names' is required"}), 400
+
+ product_names = data['product_names']
+ if not isinstance(product_names, list):
+ product_names = [product_names]
+
+ # 批量预测
+ results = predictor.batch_predict(product_names)
+
+ return jsonify({
+ "status": "success",
+ "predictions": results
+ })
+
+
+@app.route('/health', methods=['GET'])
+def health_check():
+ """健康检查接口"""
+ return jsonify({"status": "healthy"})
+
+
+if __name__ == '__main__':
+ # 生产环境建议使用:
+ # gunicorn -w 4 -b 0.0.0.0:5000 app:app --timeout 120
+ app.run(host='0.0.0.0', port=5002, threaded=True)
\ No newline at end of file