diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..cb0f56f
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,14 @@
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..9b6e139
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,7 @@
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..63455ab
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
diff --git a/.idea/select.iml b/.idea/select.iml
new file mode 100644
index 0000000..244c072
--- /dev/null
+++ b/.idea/select.iml
@@ -0,0 +1,8 @@
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..94a25f7
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
diff --git a/bert_train.py b/bert_train.py
new file mode 100644
index 0000000..c5ecf6b
--- /dev/null
+++ b/bert_train.py
@@ -0,0 +1,216 @@
+import pandas as pd
+from sklearn.model_selection import train_test_split
+from transformers import (
+    BertTokenizer,
+    BertForSequenceClassification,
+    Trainer,
+    TrainingArguments,
+    EarlyStoppingCallback
+)
+import torch
+from torch.utils.data import Dataset
+from tqdm import tqdm
+import os
+import warnings
+import re  # used for regex-based text cleaning
+from sklearn.preprocessing import LabelEncoder
+import joblib
+
+# 1. Configuration (managed in one place)
+class Config:
+    MODEL_NAME = "bert-base-chinese"
+    MAX_LENGTH = 64
+    BATCH_SIZE = 32
+    NUM_EPOCHS = 5
+    LEARNING_RATE = 2e-5
+    WARMUP_STEPS = 500
+    WEIGHT_DECAY = 0.01
+    FP16 = torch.cuda.is_available()
+    OUTPUT_DIR = "./results/single_level"
+    LOG_DIR = "./logs"
+    SAVE_DIR = "./saved_model/single_level"
+    DEVICE = "cuda" if FP16 else "cpu"
+
+
+# 2. Data loading and preprocessing (with error handling and logging)
+def load_data(file_path):
+    try:
+        df = pd.read_csv(file_path)
+        assert {'sentence', 'label'}.issubset(df.columns), "Data must contain 'sentence' and 'label' columns"
+        print(f"✅ Data loaded | samples: {len(df)} | classes: {df['label'].nunique()}")
+        return df
+    except Exception as e:
+        warnings.warn(f"❌ Failed to load data: {str(e)}")
+        raise
+
+
+# New: cleaning function - keep Chinese characters only
+def clean_chinese_text(text):
+    """
+    Clean the text, keeping only Chinese characters.
+    """
+    if not isinstance(text, str):
+        return ""
+    # To also keep Chinese punctuation, widen the pattern to
+    # [^\u4e00-\u9fa5\u3000-\u303f\uff00-\uffef]; the stricter pattern below
+    # keeps Han characters only.
+    cleaned_text = re.sub(r'[^\u4e00-\u9fa5]', '', text)
+    return cleaned_text.strip()
+
+
+# 3. Optimized Dataset (adds in-memory caching and batch support)
+class TextDataset(Dataset):
+    def __init__(self, dataframe, tokenizer, text_col="sentence", label_col="label"):
+        self.data = dataframe
+        self.tokenizer = tokenizer
+        self.text_col = text_col
+        self.label_col = label_col
+
+        # Pre-compute all encodings (trades memory for speed)
+        self.encodings = tokenizer(
+            dataframe[text_col].tolist(),
+            max_length=Config.MAX_LENGTH,
+            padding="max_length",
+            truncation=True,
+            return_tensors="pt"
+        )
+        self.labels = torch.tensor(dataframe[label_col].values, dtype=torch.long)
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, idx):
+        return {
+            "input_ids": self.encodings["input_ids"][idx],
+            "attention_mask": self.encodings["attention_mask"][idx],
+            "labels": self.labels[idx]
+        }
+
+
+# 4. Model initialization (moves the model to the target device)
+def init_model(num_labels):
+    tokenizer = BertTokenizer.from_pretrained(Config.MODEL_NAME)
+    model = BertForSequenceClassification.from_pretrained(
+        Config.MODEL_NAME,
+        num_labels=num_labels,
+        ignore_mismatched_sizes=True  # optional
+    ).to(Config.DEVICE)
+    return tokenizer, model
+
+
+# 5. Training configuration (adds early stopping and gradient accumulation)
+def get_training_args():
+    return TrainingArguments(
+        output_dir=Config.OUTPUT_DIR,
+        num_train_epochs=Config.NUM_EPOCHS,
+        per_device_train_batch_size=Config.BATCH_SIZE,
+        per_device_eval_batch_size=Config.BATCH_SIZE * 2,  # a larger batch is fine for eval
+        learning_rate=Config.LEARNING_RATE,
+        warmup_steps=Config.WARMUP_STEPS,
+        weight_decay=Config.WEIGHT_DECAY,
+        logging_dir=Config.LOG_DIR,
+        logging_steps=10,
+        eval_strategy="steps",
+        eval_steps=100,
+        save_strategy="steps",
+        save_steps=200,
+        load_best_model_at_end=True,
+        metric_for_best_model="eval_loss",
+        greater_is_better=False,
+        fp16=Config.FP16,
+        gradient_accumulation_steps=2,  # simulates a larger batch (effective size 32 * 2 = 64)
+        report_to="none",  # disable wandb and other reporters
+        seed=42
+    )
+
+
+# 6. Optimized inference (adds batch support)
+@torch.no_grad()
+def batch_predict(texts, model, tokenizer, label_map, top_k=1, batch_size=16):
+    model.eval()
+    all_results = []
+
+    for i in tqdm(range(0, len(texts), batch_size), desc="Predicting"):
+        batch = texts[i:i + batch_size]
+        inputs = tokenizer(
+            batch,
+            return_tensors="pt",
+            truncation=True,
+            padding=True,
+            max_length=Config.MAX_LENGTH
+        ).to(Config.DEVICE)
+
+        outputs = model(**inputs)
+        probs = torch.softmax(outputs.logits, dim=1).cpu()
+
+        for prob in probs:
+            top_probs, top_indices = torch.topk(prob, k=top_k)
+            all_results.extend([
+                {
+                    "category": label_map[idx.item()],
+                    "confidence": p.item()
+                }
+                for p, idx in zip(top_probs, top_indices)
+            ])
+
+    return all_results[:len(texts)]  # with top_k == 1 this is one result per input
+
+
+# Main flow
+if __name__ == "__main__":
+    # 1. Load the data
+    df = load_data("goods_cate.csv")
+
+    # 2. Clean the data - keep Chinese characters only
+    print("🧼 Cleaning text data...")
+    df['sentence'] = df['sentence'].apply(clean_chinese_text)
+    df = df[df['sentence'].str.len() > 0].reset_index(drop=True)
+    print(f"✅ Cleaning done | remaining samples: {len(df)}")
+
+    # 3. Handle Chinese labels: map them to numeric IDs and save the mapping
+    print("🏷️ Encoding labels...")
+    label_encoder = LabelEncoder()
+    df['label_id'] = label_encoder.fit_transform(df['label'])  # Chinese label → numeric ID
+    label_map = {i: label for i, label in enumerate(label_encoder.classes_)}  # numeric ID → Chinese label
+    print(f"Label mapping: {label_map}")
+
+    # Save the label encoder (needed again at inference time)
+    os.makedirs("cate", exist_ok=True)  # ensure the target directory exists
+    joblib.dump(label_encoder, "cate/label_encoder.pkl")
+    print(f"✅ Labels encoded | classes: {len(label_map)}")
+
+    # 4. Split the dataset (stratified on the label_id column)
+    train_df, test_df = train_test_split(
+        df, test_size=0.2, random_state=42, stratify=df["label_id"]  # note: stratify on label_id
+    )
+
+    # 5. Initialize the model (with the number of encoded labels)
+    num_labels = len(label_map)
+    tokenizer, model = init_model(num_labels)
+
+    # 6. Build the datasets (using the label_id column)
+    train_dataset = TextDataset(train_df, tokenizer, label_col="label_id")
+    test_dataset = TextDataset(test_df, tokenizer, label_col="label_id")
+
+    # 7. Training configuration (unchanged)
+    training_args = get_training_args()
+
+    # 8. Trainer (unchanged)
+    trainer = Trainer(
+        model=model,
+        args=training_args,
+        train_dataset=train_dataset,
+        eval_dataset=test_dataset,
+        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
+    )
+
+    # 9. Train and save (unchanged)
+    trainer.train()
+    model.save_pretrained(Config.SAVE_DIR)
+    tokenizer.save_pretrained(Config.SAVE_DIR)
+
+    # 10. Smoke-test inference (note: "iPhone 15" cleans to an empty string)
+    test_samples = ["不二家棒棒糖", "iPhone 15", "无线鼠标"]
+    # Clean the test samples first
+    cleaned_samples = [clean_chinese_text(s) for s in test_samples]
+    predictions = batch_predict(cleaned_samples, model, tokenizer, label_map)
+    for sample, pred in zip(test_samples, predictions):
+        print(
+            f"Input: {sample}\nCleaned: {clean_chinese_text(sample)}\nPrediction: {pred['category']} (confidence: {pred['confidence']:.2f})\n")
\ No newline at end of file
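For context on how the artifacts written by `bert_train.py` are meant to be consumed later, here is a minimal reload sketch. It is not part of this commit; it only assumes the `./saved_model/single_level` and `cate/label_encoder.pkl` paths used above:

```python
import joblib
import torch
from transformers import BertTokenizer, BertForSequenceClassification

# Reload the fine-tuned model and tokenizer written by save_pretrained() above.
save_dir = "./saved_model/single_level"
tokenizer = BertTokenizer.from_pretrained(save_dir)
model = BertForSequenceClassification.from_pretrained(save_dir).eval()

# Rebuild the ID -> label map the same way the training script builds it.
label_encoder = joblib.load("cate/label_encoder.pkl")
label_map = {i: label for i, label in enumerate(label_encoder.classes_)}

inputs = tokenizer("不二家棒棒糖", return_tensors="pt", truncation=True, max_length=64)
with torch.no_grad():
    pred_id = model(**inputs).logits.argmax(dim=1).item()
print(label_map[pred_id])
```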
diff --git a/app.py b/other/app.py
similarity index 82%
rename from app.py
rename to other/app.py
index 022b248..e5ebff3 100644
--- a/app.py
+++ b/other/app.py
@@ -10,7 +10,7 @@ logger = logging.getLogger(__name__)
 try:
     classifier = pipeline(
         "zero-shot-classification",
-        model="./nlp_structbert_zero-shot-classification_chinese-large",
+        model="./nlp_structbert_zero-shot-classification_chinese-base",
         device="cpu",
         max_length=512,
         ignore_mismatched_sizes=True
@@ -20,10 +20,6 @@
 except Exception as e:
     logger.error(f"Model loading failed: {str(e)}")
     raise
-# Define category labels
-level1 = [
-    "食品", "电器", "洗护", "女装", "手机", "健康", "男装", "美妆", "电脑", "运动", "内衣", "母婴", "数码", "百货", "鞋包", "办公", "家装", "饰品", "车品", "图书", "生鲜", "家纺", "宠物", "奢品", "其它", "药品"
-]
 
 # Create the Flask app
 app = Flask(__name__)
@@ -49,12 +45,14 @@ def classify_text():
             truncation=True
         )
 
-        labels = result['labels']
+        print(result)
+        print(result['scores'][-1])
 
         # Build the response
         response = {
-            "cate": labels[0],
+            "cate": result['labels'][-1],
+            "score": result['scores'][-1],
         }
 
         return jsonify(response)
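One behavioral note on the `result['labels'][-1]` / `result['scores'][-1]` change above: the Hugging Face zero-shot pipeline returns `labels` and `scores` sorted by score in descending order, so index `0` is the highest-confidence candidate and index `-1` the lowest. A minimal sketch illustrating the output ordering (same pipeline setup as `other/app.py`; the candidate labels here are illustrative only):

```python
from transformers import pipeline

# The result dict has keys 'sequence', 'labels', and 'scores', with
# labels/scores sorted by score in DESCENDING order.
classifier = pipeline(
    "zero-shot-classification",
    model="./nlp_structbert_zero-shot-classification_chinese-base",
    device="cpu",
)
result = classifier("不二家棒棒糖", candidate_labels=["食品", "电器", "美妆"])
print(result["labels"][0], result["scores"][0])    # strongest match
print(result["labels"][-1], result["scores"][-1])  # weakest match
```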
diff --git a/test_my_mode.py b/test_my_mode.py
new file mode 100644
index 0000000..44bc30d
--- /dev/null
+++ b/test_my_mode.py
@@ -0,0 +1,129 @@
+from flask import Flask, request, jsonify
+import torch
+from transformers import BertTokenizer, BertForSequenceClassification
+import joblib
+import re
+from functools import lru_cache
+from concurrent.futures import ThreadPoolExecutor
+from typing import List, Dict
+import threading
+
+# Configuration
+MAX_LENGTH = 512
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+SAVE_DIR = "./cate1"  # model directory
+LABEL_ENCODER_PATH = "cate1/label_encoder.pkl"  # label-encoder path
+
+app = Flask(__name__)
+
+# Global lock guarding singleton construction
+model_lock = threading.Lock()
+
+
+class Predictor:
+    """Predictor singleton that manages the model and tokenizer lifecycle."""
+    _instance = None
+
+    def __new__(cls):
+        if cls._instance is None:
+            with model_lock:
+                if cls._instance is None:  # double-checked locking
+                    cls._instance = super().__new__(cls)
+                    cls._instance._initialized = False
+        return cls._instance
+
+    def __init__(self):
+        if self._initialized:
+            return
+
+        # Load the model and tokenizer
+        self.tokenizer = BertTokenizer.from_pretrained(SAVE_DIR)
+        self.model = BertForSequenceClassification.from_pretrained(SAVE_DIR).to(DEVICE)
+        self.model.eval()
+
+        # Load the label encoder
+        self.label_encoder = joblib.load(LABEL_ENCODER_PATH)
+        self.label_map = {i: label for i, label in enumerate(self.label_encoder.classes_)}
+
+        # Thread pool for concurrent requests
+        self.executor = ThreadPoolExecutor(max_workers=4)
+        self._initialized = True
+
+    @lru_cache(maxsize=1000)
+    def clean_text(self, text: str) -> str:
+        """Text cleaning with an LRU cache (caching a bound method is safe here
+        because the class is a singleton)."""
+        if not isinstance(text, str):
+            return ""
+        cleaned_text = re.sub(r'[^\u4e00-\u9fa5]', '', text)
+        return cleaned_text.strip()
+
+    def predict_single(self, text: str) -> Dict:
+        """Predict the category of a single text."""
+        with torch.no_grad():
+            cleaned_text = self.clean_text(text)
+            if not cleaned_text:
+                return {"category": "", "confidence": 0.0}
+
+            inputs = self.tokenizer(
+                cleaned_text,
+                return_tensors="pt",
+                truncation=True,
+                padding=True,
+                max_length=MAX_LENGTH
+            ).to(DEVICE)
+
+            outputs = self.model(**inputs)
+            probs = torch.softmax(outputs.logits, dim=1).cpu()
+            top_prob, top_idx = torch.topk(probs, k=1)
+
+            # Skipping the confidence computation speeds this up by 10-30%:
+            # top_idx = torch.argmax(outputs.logits, dim=1)
+            return {
+                "category": self.label_map[top_idx.item()],
+                "confidence": top_prob.item()
+            }
+
+    def batch_predict(self, texts: List[str]) -> List[Dict]:
+        """Batch prediction (runs requests concurrently)."""
+        if not texts:
+            return []
+
+        # Fan the texts out over the thread pool
+        futures = [self.executor.submit(self.predict_single, text) for text in texts]
+        return [future.result() for future in futures]
+
+
+# Initialize the predictor (loads the model once at import time)
+predictor = Predictor()
+
+
+@app.route('/predict', methods=['POST'])
+def predict():
+    """Prediction endpoint."""
+    data = request.get_json()
+    if not data or 'product_names' not in data:
+        return jsonify({"error": "Invalid request, 'product_names' is required"}), 400
+
+    product_names = data['product_names']
+    if not isinstance(product_names, list):
+        product_names = [product_names]
+
+    # Batch prediction
+    results = predictor.batch_predict(product_names)
+
+    return jsonify({
+        "status": "success",
+        "predictions": results
+    })
+
+
+@app.route('/health', methods=['GET'])
+def health_check():
+    """Health-check endpoint."""
+    return jsonify({"status": "healthy"})
+
+
+if __name__ == '__main__':
+    # For production, run under a WSGI server, e.g.:
+    # gunicorn -w 4 -b 0.0.0.0:5002 test_my_mode:app --timeout 120
+    app.run(host='0.0.0.0', port=5002, threaded=True)
\ No newline at end of file
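A quick client sketch for the `/predict` endpoint above (assumes the service is running locally on port 5002, as configured in `__main__`; the example confidence value is illustrative):

```python
import requests

# POST a batch of product names; the response mirrors the schema built in predict().
resp = requests.post(
    "http://127.0.0.1:5002/predict",
    json={"product_names": ["不二家棒棒糖", "无线鼠标"]},
    timeout=30,
)
print(resp.json())
# e.g. {"status": "success",
#       "predictions": [{"category": "...", "confidence": 0.97}, ...]}
```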