commit e772392594
@@ -0,0 +1,8 @@
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml
@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
  <component name="NewModuleRootManager">
    <content url="file://$MODULE_DIR$" />
    <orderEntry type="jdk" jdkName="311 (2)" jdkType="Python SDK" />
    <orderEntry type="sourceFolder" forTests="false" />
  </component>
</module>
@@ -0,0 +1,14 @@
<component name="InspectionProjectProfileManager">
  <profile version="1.0">
    <option name="myName" value="Project Default" />
    <inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
      <option name="ignoredPackages">
        <value>
          <list size="1">
            <item index="0" class="java.lang.String" itemvalue="betterproto" />
          </list>
        </value>
      </option>
    </inspection_tool>
  </profile>
</component>
|  | @ -0,0 +1,6 @@ | ||||||
|  | <component name="InspectionProjectProfileManager"> | ||||||
|  |   <settings> | ||||||
|  |     <option name="USE_PROJECT_PROFILE" value="false" /> | ||||||
|  |     <version value="1.0" /> | ||||||
|  |   </settings> | ||||||
|  | </component> | ||||||
@@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="Black">
    <option name="sdkName" value="/usr/bin/python3.8 (6)" />
  </component>
  <component name="ProjectRootManager" version="2" project-jdk-name="311 (2)" project-jdk-type="Python SDK" />
</project>
@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="ProjectModuleManager">
    <modules>
      <module fileurl="file://$PROJECT_DIR$/.idea/bert_order_address.iml" filepath="$PROJECT_DIR$/.idea/bert_order_address.iml" />
    </modules>
  </component>
</project>
@@ -0,0 +1,130 @@
from flask import Flask, request, jsonify
import torch
from transformers import BertTokenizer, BertForSequenceClassification
import joblib
import re
from functools import lru_cache
from concurrent.futures import ThreadPoolExecutor
from typing import List, Dict
import threading

# Configuration
MAX_LENGTH = 512
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SAVE_DIR = "address"  # model directory
LABEL_ENCODER_PATH = "address/label_encoder.pkl"  # label-encoder path

app = Flask(__name__)
app.config['JSON_AS_ASCII'] = False  # return Chinese labels verbatim instead of \u escapes
# Lock guarding singleton construction
model_lock = threading.Lock()


class Predictor:
    """Predictor singleton that owns the model's and tokenizer's lifecycle."""
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            with model_lock:
                if cls._instance is None:  # double-checked locking
                    cls._instance = super().__new__(cls)
                    cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        if self._initialized:
            return

        # Load the model and tokenizer
        self.tokenizer = BertTokenizer.from_pretrained(SAVE_DIR)
        self.model = BertForSequenceClassification.from_pretrained(SAVE_DIR).to(DEVICE)
        self.model.eval()

        # Load the label encoder
        self.label_encoder = joblib.load(LABEL_ENCODER_PATH)
        self.label_map = {i: label for i, label in enumerate(self.label_encoder.classes_)}

        # Create the thread pool
        self.executor = ThreadPoolExecutor(max_workers=4)
        self._initialized = True

    @staticmethod
    @lru_cache(maxsize=1000)
    def clean_text(text: str) -> str:
        """Text cleaning with an LRU cache. Declared a staticmethod so the
        cache is keyed on the text alone and never holds a reference to self."""
        if not isinstance(text, str):
            return ""
        cleaned_text = re.sub(r'[^\u4e00-\u9fa5]', '', text)  # keep Chinese characters only
        return cleaned_text.strip()

    def predict_single(self, text: str) -> Dict:
        """Predict a single text."""
        with torch.no_grad():
            cleaned_text = self.clean_text(text)
            if not cleaned_text:
                return {"address": "", "confidence": 0.0}

            inputs = self.tokenizer(
                cleaned_text,
                return_tensors="pt",
                truncation=True,
                padding=True,
                max_length=MAX_LENGTH
            ).to(DEVICE)

            outputs = self.model(**inputs)
            probs = torch.softmax(outputs.logits, dim=1).cpu()
            top_prob, top_idx = torch.topk(probs, k=1)

            # Skipping the confidence computation and taking argmax over the
            # raw logits speeds inference up by roughly 10%-30%:
            # top_idx = torch.argmax(outputs.logits, dim=1)
            return {
                "address": self.label_map[top_idx.item()],
                "confidence": top_prob.item()
            }

    def batch_predict(self, texts: List[str]) -> List[Dict]:
        """Batch prediction (concurrent)."""
        if not texts:
            return []

        # Fan the texts out to the thread pool
        futures = [self.executor.submit(self.predict_single, text) for text in texts]
        return [future.result() for future in futures]
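
# Note on the thread pool above: on a single GPU the threads mostly overlap
# tokenization, while the forward passes themselves serialize on the device,
# so batching all texts into one tokenizer/model call would likely raise
# throughput more than adding workers would.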


# Initialize the predictor
predictor = Predictor()


@app.route('/predict', methods=['POST'])
def predict():
    """Prediction endpoint."""
    data = request.get_json()
    if not data or 'address' not in data:
        return jsonify({"error": "Invalid request, 'address' is required"}), 400

    addresses = data['address']
    if not isinstance(addresses, list):
        addresses = [addresses]

    # Batch prediction
    results = predictor.batch_predict(addresses)

    return jsonify({
        "status": "success",
        "predictions": results
    })


@app.route('/health', methods=['GET'])
def health_check():
    """Health-check endpoint."""
    return jsonify({"status": "healthy"})


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5003, threaded=True)
else:
    application = app  # WSGI-compatible entry point (e.g. for Gunicorn)
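
# Example request (a hypothetical invocation; assumes the service is reachable
# on localhost:5003 as configured above):
#
#   curl -X POST http://localhost:5003/predict \
#        -H 'Content-Type: application/json' \
#        -d '{"address": ["山东省济南市莱芜区碧桂园天樾422502"]}'
#
# which should return JSON shaped like (confidence value illustrative):
#
#   {"status": "success", "predictions": [{"address": "<label>", "confidence": 0.97}]}
#
# Under a WSGI server, a typical (hypothetical) launch would be
#   gunicorn -b 0.0.0.0:5003 app:application
# assuming the module is saved as app.py.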
@@ -0,0 +1,215 @@
import os
import pandas as pd
from sklearn.model_selection import train_test_split
from transformers import (
    BertTokenizer,
    BertForSequenceClassification,
    Trainer,
    TrainingArguments,
    EarlyStoppingCallback
)
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
import warnings
import re  # for regex-based cleaning
from sklearn.preprocessing import LabelEncoder
import joblib

# 1. Configuration (managed in one place)
class Config:
    MODEL_NAME = "bert-base-chinese"
    MAX_LENGTH = 64
    BATCH_SIZE = 32
    NUM_EPOCHS = 5
    LEARNING_RATE = 2e-5
    WARMUP_STEPS = 500
    WEIGHT_DECAY = 0.01
    FP16 = torch.cuda.is_available()  # mixed precision whenever CUDA is available
    OUTPUT_DIR = "./results/single_level"
    LOG_DIR = "./logs"
    SAVE_DIR = "./saved_model/single_level"
    DEVICE = "cuda" if FP16 else "cpu"  # FP16 doubles as the CUDA-availability flag


# 2. Data loading and preprocessing (with error handling and logging)
def load_data(file_path):
    try:
        df = pd.read_csv(file_path)
        assert {'sentence', 'label'}.issubset(df.columns), "data must contain 'sentence' and 'label' columns"
        print(f"✅ Data loaded | samples: {len(df)} | classes: {df['label'].nunique()}")
        return df
    except Exception as e:
        warnings.warn(f"❌ Failed to load data: {str(e)}")
        raise


# Data-cleaning function: keep Chinese characters only
def clean_chinese_text(text):
    """
    Clean the text, keeping only Chinese characters.
    """
    if not isinstance(text, str):
        return ""
    # To also keep CJK punctuation, match [^\u4e00-\u9fa5\u3000-\u303f\uff00-\uffef];
    # the stricter pattern below keeps Han characters only.
    cleaned_text = re.sub(r'[^\u4e00-\u9fa5]', '', text)
    return cleaned_text.strip()
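
# For instance (illustrative input/output pair):
#   clean_chinese_text("山东省济南市莱芜区碧桂园天樾422502")
#   -> "山东省济南市莱芜区碧桂园天樾"
# Digits, Latin letters, whitespace, and punctuation are all stripped.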


# 3. Dataset (pre-tokenizes everything up front)
class TextDataset(Dataset):
    def __init__(self, dataframe, tokenizer, text_col="sentence", label_col="label"):
        self.data = dataframe
        self.tokenizer = tokenizer
        self.text_col = text_col
        self.label_col = label_col

        # Pre-compute the encodings (trade memory for speed)
        self.encodings = tokenizer(
            dataframe[text_col].tolist(),
            max_length=Config.MAX_LENGTH,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        self.labels = torch.tensor(dataframe[label_col].values, dtype=torch.long)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return {
            "input_ids": self.encodings["input_ids"][idx],
            "attention_mask": self.encodings["attention_mask"][idx],
            "labels": self.labels[idx]
        }
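
# Rough cost of the pre-tokenization trade-off: input_ids and attention_mask
# (and token_type_ids, which BertTokenizer also emits) are each int64 tensors of
# shape (N, MAX_LENGTH), i.e. 64 * 8 = 512 bytes per sample per tensor at
# MAX_LENGTH = 64. That is fine at this dataset's scale; for corpora of many
# millions of rows, tokenizing on the fly in __getitem__ would likely be safer.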


# 4. Model initialization (moved to the target device)
def init_model(num_labels):
    tokenizer = BertTokenizer.from_pretrained(Config.MODEL_NAME)
    model = BertForSequenceClassification.from_pretrained(
        Config.MODEL_NAME,
        num_labels=num_labels,
        ignore_mismatched_sizes=True  # optional
    ).to(Config.DEVICE)
    return tokenizer, model


# 5. Training configuration (with early stopping and gradient accumulation)
def get_training_args():
    return TrainingArguments(
        output_dir=Config.OUTPUT_DIR,
        num_train_epochs=Config.NUM_EPOCHS,
        per_device_train_batch_size=Config.BATCH_SIZE,
        per_device_eval_batch_size=Config.BATCH_SIZE * 2,  # evaluation can afford a larger batch
        learning_rate=Config.LEARNING_RATE,
        warmup_steps=Config.WARMUP_STEPS,
        weight_decay=Config.WEIGHT_DECAY,
        logging_dir=Config.LOG_DIR,
        logging_steps=10,
        eval_strategy="steps",
        eval_steps=100,
        save_strategy="steps",
        save_steps=200,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        fp16=Config.FP16,
        gradient_accumulation_steps=2,  # simulate a larger batch
        report_to="none",  # disable wandb and other reporters
        seed=42
    )
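
# The effective optimization batch size is per_device_train_batch_size *
# gradient_accumulation_steps = 32 * 2 = 64 per device. With early stopping
# (patience 3, evaluating every 100 steps, see the Trainer below), the
# NUM_EPOCHS = 5 cap is an upper bound rather than a guaranteed training length.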


# 6. Inference (with batching)
@torch.no_grad()
def batch_predict(texts, model, tokenizer, label_map, top_k=1, batch_size=16):
    model.eval()
    all_results = []

    for i in tqdm(range(0, len(texts), batch_size), desc="Predicting"):
        batch = texts[i:i + batch_size]
        inputs = tokenizer(
            batch,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=Config.MAX_LENGTH
        ).to(Config.DEVICE)

        outputs = model(**inputs)
        probs = torch.softmax(outputs.logits, dim=1).cpu()

        for sample_probs in probs:
            top_probs, top_indices = torch.topk(sample_probs, k=top_k)
            all_results.extend([
                {
                    "category": label_map[idx.item()],
                    "confidence": p.item()
                }
                for p, idx in zip(top_probs, top_indices)
            ])

    return all_results  # one entry per text per top-k candidate


# Main flow
if __name__ == "__main__":
    # 1. Load the data
    df = load_data("order_address.csv")

    # 2. Clean the data: keep Chinese characters only
    print("🧼 Cleaning text data...")
    df['sentence'] = df['sentence'].apply(clean_chinese_text)
    df = df[df['sentence'].str.len() > 0].reset_index(drop=True)
    print(f"✅ Cleaning done | remaining samples: {len(df)}")

    # 3. Handle Chinese labels: map them to numeric IDs and save the mapping
    print("🏷️ Encoding Chinese labels...")
    label_encoder = LabelEncoder()
    df['label_id'] = label_encoder.fit_transform(df['label'])  # Chinese label -> numeric ID
    label_map = {i: label for i, label in enumerate(label_encoder.classes_)}  # numeric ID -> Chinese label
    print(f"Label-mapping sample: {label_map}")

    # Save the label encoder (needed again at inference time)
    os.makedirs("cate", exist_ok=True)  # joblib.dump does not create missing directories
    joblib.dump(label_encoder, "cate/label_encoder.pkl")
    print(f"✅ Label mapping done | classes: {len(label_map)}")

    # 4. Split the dataset (stratified on the label_id column)
    train_df, test_df = train_test_split(
        df, test_size=0.2, random_state=42, stratify=df["label_id"]
    )

    # 5. Initialize the model (with the number of classes)
    num_labels = len(label_map)
    tokenizer, model = init_model(num_labels)

    # 6. Build the datasets (using the label_id column)
    train_dataset = TextDataset(train_df, tokenizer, label_col="label_id")
    test_dataset = TextDataset(test_df, tokenizer, label_col="label_id")

    # 7. Training configuration
    training_args = get_training_args()

    # 8. Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
    )

    # 9. Train and save
    trainer.train()
    model.save_pretrained(Config.SAVE_DIR)
    tokenizer.save_pretrained(Config.SAVE_DIR)

    # 10. Smoke-test inference
    test_samples = ["山东省济南市莱芜区碧桂园天樾422502", "广东省广州市花都区狮岭镇山前旅游大道18号机车检修段", "江苏省苏州市吴中区吴中区木渎镇枫瑞路85号诺德·长枫雅苑北区10栋-303"]
    # Clean the test samples first
    cleaned_samples = [clean_chinese_text(s) for s in test_samples]
    predictions = batch_predict(cleaned_samples, model, tokenizer, label_map)
    for sample, pred in zip(test_samples, predictions):
        print(
            f"Input: {sample}\nCleaned: {clean_chinese_text(sample)}\nPrediction: {pred['category']} (confidence: {pred['confidence']:.2f})\n")
File diff suppressed because it is too large