bert_cate/bert_train.py

import os  # for creating output directories before saving artifacts
import pandas as pd
from sklearn.model_selection import train_test_split
from transformers import (
    BertTokenizer,
    BertForSequenceClassification,
    Trainer,
    TrainingArguments,
    EarlyStoppingCallback
)
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
import warnings
import re  # for regex-based text cleaning
from sklearn.preprocessing import LabelEncoder
import joblib


# 1. Hyperparameters and paths (centralized configuration)
class Config:
    MODEL_NAME = "bert-base-chinese"
    MAX_LENGTH = 64
    BATCH_SIZE = 32
    NUM_EPOCHS = 5
    LEARNING_RATE = 2e-5
    WARMUP_STEPS = 500
    WEIGHT_DECAY = 0.01
    FP16 = torch.cuda.is_available()  # use mixed precision only when a GPU is present
    OUTPUT_DIR = "./results/single_level"
    LOG_DIR = "./logs"
    SAVE_DIR = "./saved_model/single_level"
    DEVICE = "cuda" if FP16 else "cpu"


# 2. Data loading and preprocessing (with error handling and logging)
def load_data(file_path):
    try:
        df = pd.read_csv(file_path)
        assert {'sentence', 'label'}.issubset(df.columns), \
            "Data must contain 'sentence' and 'label' columns"
        print(f"✅ Data loaded | samples: {len(df)} | classes: {df['label'].nunique()}")
        return df
    except Exception as e:
        warnings.warn(f"❌ Failed to load data: {str(e)}")
        raise


# Data-cleaning helper: keep Chinese characters only
def clean_chinese_text(text):
    """Strip everything except Han characters from the text."""
    if not isinstance(text, str):
        return ""
    # [\u4e00-\u9fa5] covers Han characters only; to also keep Chinese
    # punctuation and full-width forms, widen the negated class to
    # [^\u4e00-\u9fa5\u3000-\u303f\uff00-\uffef] instead
    cleaned_text = re.sub(r'[^\u4e00-\u9fa5]', '', text)
    return cleaned_text.strip()
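
# Illustrative examples (assumed inputs, not from the training data): mixed-script
# product titles collapse to their Chinese part, so purely non-Chinese rows
# become empty strings and are dropped later in the main flow:
#   clean_chinese_text("iPhone 15手机壳")  # -> "手机壳"
#   clean_chinese_text("不二家棒棒糖!!")    # -> "不二家棒棒糖"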


# 3. Dataset with pre-computed encodings (trades memory for speed)
class TextDataset(Dataset):
    def __init__(self, dataframe, tokenizer, text_col="sentence", label_col="label"):
        self.data = dataframe
        self.tokenizer = tokenizer
        self.text_col = text_col
        self.label_col = label_col
        # Tokenize everything up front so __getitem__ is a cheap tensor lookup
        self.encodings = tokenizer(
            dataframe[text_col].tolist(),
            max_length=Config.MAX_LENGTH,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        self.labels = torch.tensor(dataframe[label_col].values, dtype=torch.long)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return {
            "input_ids": self.encodings["input_ids"][idx],
            "attention_mask": self.encodings["attention_mask"][idx],
            "labels": self.labels[idx]
        }
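
# Shape sketch (hypothetical toy frame, just to show what __getitem__ yields):
#   toy = pd.DataFrame({"sentence": ["苹果手机", "牛奶"], "label": [0, 1]})
#   ds = TextDataset(toy, tokenizer)
#   ds[0]["input_ids"].shape  # -> torch.Size([64]), padded to Config.MAX_LENGTH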


# 4. Model initialization (moves the model to the configured device)
def init_model(num_labels):
    tokenizer = BertTokenizer.from_pretrained(Config.MODEL_NAME)
    model = BertForSequenceClassification.from_pretrained(
        Config.MODEL_NAME,
        num_labels=num_labels,
        ignore_mismatched_sizes=True  # optional: tolerate a resized classifier head
    ).to(Config.DEVICE)
    return tokenizer, model


# 5. Training configuration (early stopping support and gradient accumulation)
def get_training_args():
    return TrainingArguments(
        output_dir=Config.OUTPUT_DIR,
        num_train_epochs=Config.NUM_EPOCHS,
        per_device_train_batch_size=Config.BATCH_SIZE,
        per_device_eval_batch_size=Config.BATCH_SIZE * 2,  # evaluation can afford a larger batch
        learning_rate=Config.LEARNING_RATE,
        warmup_steps=Config.WARMUP_STEPS,
        weight_decay=Config.WEIGHT_DECAY,
        logging_dir=Config.LOG_DIR,
        logging_steps=10,
        eval_strategy="steps",
        eval_steps=100,
        save_strategy="steps",
        save_steps=200,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        fp16=Config.FP16,
        gradient_accumulation_steps=2,  # effective train batch: 32 * 2 = 64 per device
        report_to="none",  # disable wandb and other reporters
        seed=42
    )
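
# Optional sketch: an accuracy metric that could be passed to the Trainer below
# via compute_metrics=compute_accuracy. This helper is an illustration and is
# not wired in; as written, the script selects checkpoints on eval_loss alone.
def compute_accuracy(eval_pred):
    import numpy as np
    preds = np.argmax(eval_pred.predictions, axis=-1)
    return {"accuracy": float((preds == eval_pred.label_ids).mean())}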


# 6. Batched inference (returns a list of top-k predictions per input text)
@torch.no_grad()
def batch_predict(texts, model, tokenizer, label_map, top_k=1, batch_size=16):
    model.eval()
    all_results = []
    for i in tqdm(range(0, len(texts), batch_size), desc="Predicting"):
        batch = texts[i:i + batch_size]
        inputs = tokenizer(
            batch,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=Config.MAX_LENGTH
        ).to(Config.DEVICE)
        outputs = model(**inputs)
        probs = torch.softmax(outputs.logits, dim=1).cpu()
        for prob_row in probs:
            top_probs, top_indices = torch.topk(prob_row, k=top_k)
            # One list of {category, confidence} dicts per text, so top_k > 1
            # keeps each text's candidates together instead of flattening them
            all_results.append([
                {
                    "category": label_map[idx.item()],
                    "confidence": p.item()
                }
                for p, idx in zip(top_probs, top_indices)
            ])
    return all_results
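
# Usage sketch (hypothetical call): top-3 candidates for a single cleaned query.
#   batch_predict(["手机壳"], model, tokenizer, label_map, top_k=3)
#   # -> [[{"category": ..., "confidence": ...}, ...two more candidates...]]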


# Main flow
if __name__ == "__main__":
    # 1. Load the data
    df = load_data("goods_cate.csv")

    # 2. Clean the text: keep Chinese characters only
    print("🧼 Cleaning text data...")
    df['sentence'] = df['sentence'].apply(clean_chinese_text)
    df = df[df['sentence'].str.len() > 0].reset_index(drop=True)
    print(f"✅ Cleaning done | remaining samples: {len(df)}")

    # 3. Map Chinese labels to numeric IDs and persist the mapping
    print("🏷️ Encoding Chinese labels...")
    label_encoder = LabelEncoder()
    df['label_id'] = label_encoder.fit_transform(df['label'])  # Chinese label -> numeric ID
    label_map = {i: label for i, label in enumerate(label_encoder.classes_)}  # numeric ID -> Chinese label
    print(f"Label mapping sample: {label_map}")
    # Persist the encoder for inference-time use; ensure the target directory exists
    os.makedirs("cate", exist_ok=True)
    joblib.dump(label_encoder, "cate/label_encoder.pkl")
    print(f"✅ Label mapping done | classes: {len(label_map)}")

    # 4. Train/test split, stratified on the numeric label_id column
    train_df, test_df = train_test_split(
        df, test_size=0.2, random_state=42, stratify=df["label_id"]
    )

    # 5. Initialize the model with the number of encoded classes
    num_labels = len(label_map)
    tokenizer, model = init_model(num_labels)

    # 6. Build datasets (note label_col="label_id", not the raw Chinese labels)
    train_dataset = TextDataset(train_df, tokenizer, label_col="label_id")
    test_dataset = TextDataset(test_df, tokenizer, label_col="label_id")

    # 7. Training configuration
    training_args = get_training_args()

    # 8. Trainer with early stopping on the evaluation loss
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
    )

    # 9. Train and save the fine-tuned model and tokenizer
    trainer.train()
    model.save_pretrained(Config.SAVE_DIR)
    tokenizer.save_pretrained(Config.SAVE_DIR)
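
    # Sketch (assumption, not executed here): a separate inference script could
    # reload everything this run saves, roughly as follows:
    #   tokenizer = BertTokenizer.from_pretrained(Config.SAVE_DIR)
    #   model = BertForSequenceClassification.from_pretrained(Config.SAVE_DIR).to(Config.DEVICE)
    #   label_encoder = joblib.load("cate/label_encoder.pkl")
    #   label_map = dict(enumerate(label_encoder.classes_))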

    # 10. Smoke-test inference on a few samples
    test_samples = ["不二家棒棒糖", "iPhone 15", "无线鼠标"]
    # Clean the samples the same way as the training data
    cleaned_samples = [clean_chinese_text(s) for s in test_samples]
    predictions = batch_predict(cleaned_samples, model, tokenizer, label_map)
    for sample, preds in zip(test_samples, predictions):
        top = preds[0]  # batch_predict returns a top-k list per text
        print(
            f"Input: {sample}\nCleaned: {clean_chinese_text(sample)}\n"
            f"Prediction: {top['category']} (confidence: {top['confidence']:.2f})\n")