bert_cate/bert_train.py

import pandas as pd
from sklearn.model_selection import train_test_split
from transformers import (
    BertTokenizer,
    BertForSequenceClassification,
    Trainer,
    TrainingArguments,
    EarlyStoppingCallback
)
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
import warnings
import re  # used for regex-based text cleaning
from sklearn.preprocessing import LabelEncoder
import joblib


# 1. Configuration
class Config:
    MODEL_NAME = "bert-base-chinese"
    MAX_LENGTH = 64
    BATCH_SIZE = 32
    NUM_EPOCHS = 5
    LEARNING_RATE = 2e-5
    WARMUP_STEPS = 500
    WEIGHT_DECAY = 0.01
    FP16 = torch.cuda.is_available()
    OUTPUT_DIR = "./results/single_level"
    LOG_DIR = "./logs"
    SAVE_DIR = "./saved_model/single_level"
    DEVICE = "cuda" if FP16 else "cpu"


# 2. Data loading and preprocessing
def load_data(file_path):
    try:
        df = pd.read_csv(file_path)
        assert {'sentence', 'label'}.issubset(df.columns), "data must contain 'sentence' and 'label' columns"
        print(f"✅ Data loaded | samples: {len(df)} | classes: {df['label'].nunique()}")
        return df
    except Exception as e:
        warnings.warn(f"❌ Failed to load data: {str(e)}")
        raise


# Data cleaning function: keep Chinese characters only
def clean_chinese_text(text):
    """
    Clean the text, keeping only Chinese characters (Unicode range U+4E00 to U+9FA5).
    """
    if not isinstance(text, str):
        return ""
    cleaned_text = re.sub(r'[^\u4e00-\u9fa5]', '', text)
    return cleaned_text.strip()
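
# Illustrative examples of the cleaning behaviour (hedged, not part of the original script):
#   clean_chinese_text("iPhone 15手机壳")  ->  "手机壳"    (ASCII letters, digits and spaces are stripped)
#   clean_chinese_text("无线 鼠标!")        ->  "无线鼠标"  (whitespace and punctuation are stripped)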


# 3. Dataset
class TextDataset(Dataset):
    def __init__(self, dataframe, tokenizer, text_col="sentence", label_col="label"):
        self.data = dataframe
        self.tokenizer = tokenizer
        self.text_col = text_col
        self.label_col = label_col
        # Pre-compute all encodings up front (trade memory for speed)
        self.encodings = tokenizer(
            dataframe[text_col].tolist(),
            max_length=Config.MAX_LENGTH,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        self.labels = torch.tensor(dataframe[label_col].values, dtype=torch.long)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return {
            "input_ids": self.encodings["input_ids"][idx],
            "attention_mask": self.encodings["attention_mask"][idx],
            "labels": self.labels[idx]
        }
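

# If the corpus is too large to tokenize up front, an on-the-fly variant is possible.
# This is a minimal sketch and is not used by the script below; it assumes the same
# column conventions as TextDataset and tokenizes one example per __getitem__ call.
class LazyTextDataset(Dataset):
    def __init__(self, dataframe, tokenizer, text_col="sentence", label_col="label"):
        self.texts = dataframe[text_col].tolist()
        self.labels = dataframe[label_col].tolist()
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        # Tokenize lazily: slower per step, but avoids holding every encoding in memory.
        enc = self.tokenizer(
            self.texts[idx],
            max_length=Config.MAX_LENGTH,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        return {
            "input_ids": enc["input_ids"].squeeze(0),
            "attention_mask": enc["attention_mask"].squeeze(0),
            "labels": torch.tensor(self.labels[idx], dtype=torch.long)
        }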


# 4. Model initialization
def init_model(num_labels):
    tokenizer = BertTokenizer.from_pretrained(Config.MODEL_NAME)
    model = BertForSequenceClassification.from_pretrained(
        Config.MODEL_NAME,
        num_labels=num_labels,
        ignore_mismatched_sizes=True  # suppress the size-mismatch warning for the new classification head
    ).to(Config.DEVICE)
    return tokenizer, model


# 5. Training configuration
def get_training_args():
    return TrainingArguments(
        output_dir=Config.OUTPUT_DIR,            # checkpoint output directory
        # More epochs can raise accuracy up to a point, but training too long on noisy
        # (mislabeled) data degrades it; mitigate with regularization, early stopping,
        # or data augmentation.
        num_train_epochs=Config.NUM_EPOCHS,
        # Samples processed per forward pass on each device. The effective batch size is
        # per_device_train_batch_size x GPU count x gradient_accumulation_steps
        # (see the worked example after this function).
        per_device_train_batch_size=Config.BATCH_SIZE,
        per_device_eval_batch_size=Config.BATCH_SIZE * 2,
        learning_rate=Config.LEARNING_RATE,
        warmup_steps=Config.WARMUP_STEPS,
        weight_decay=Config.WEIGHT_DECAY,
        logging_dir=Config.LOG_DIR,
        logging_steps=10,
        # Evaluation must actually run for load_best_model_at_end and early stopping to
        # work (older transformers versions name this argument evaluation_strategy).
        eval_strategy="steps",
        eval_steps=100,
        save_strategy="steps",
        save_steps=200,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        fp16=Config.FP16,
        gradient_accumulation_steps=2,
        report_to="none",                        # disable wandb and other reporters
        seed=42
    )
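
# Worked example for the effective batch size formula above (assuming a single GPU,
# which is an assumption, not something the script checks):
#   32 (per_device_train_batch_size) x 1 GPU x 2 (gradient_accumulation_steps) = 64 samples per optimizer step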


# 6. Post-training inference helper
@torch.no_grad()
def batch_predict(texts, model, tokenizer, label_map, top_k=1, batch_size=16):
    model.eval()
    all_results = []
    for i in tqdm(range(0, len(texts), batch_size), desc="Predicting"):
        batch = texts[i:i + batch_size]
        inputs = tokenizer(
            batch,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=Config.MAX_LENGTH
        ).to(Config.DEVICE)
        outputs = model(**inputs)
        probs = torch.softmax(outputs.logits, dim=1).cpu()
        for sample_probs in probs:
            top_probs, top_indices = torch.topk(sample_probs, k=top_k)
            all_results.extend([
                {
                    "category": label_map[idx.item()],
                    "confidence": p.item()
                }
                for p, idx in zip(top_probs, top_indices)
            ])
    # The flat list holds top_k entries per text; slicing to len(texts) only makes
    # sense for the default top_k=1 (one prediction per input).
    return all_results[:len(texts)]
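

# Optional: the Trainer below selects checkpoints by eval_loss only. If accuracy should
# also be reported during evaluation, a minimal compute_metrics hook like this could be
# passed as Trainer(..., compute_metrics=compute_accuracy). This is an illustrative
# sketch and is not wired into the main flow below.
def compute_accuracy(eval_pred):
    logits, labels = eval_pred
    preds = logits.argmax(axis=-1)
    return {"accuracy": float((preds == labels).mean())}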


# Main pipeline
if __name__ == "__main__":
    # 1. Load the data
    df = load_data("goods_cate.csv")

    # 2. Clean the data, keeping Chinese characters only
    print("🧼 Cleaning text data...")
    df['sentence'] = df['sentence'].apply(clean_chinese_text)
    df = df[df['sentence'].str.len() > 0].reset_index(drop=True)
    print(f"✅ Cleaning finished | remaining samples: {len(df)}")

    # 3. Map the Chinese labels to numeric IDs and save the mapping
    print("🏷️ Encoding labels...")
    label_encoder = LabelEncoder()
    df['label_id'] = label_encoder.fit_transform(df['label'])                 # Chinese label -> numeric ID
    label_map = {i: label for i, label in enumerate(label_encoder.classes_)}  # numeric ID -> Chinese label
    print(f"Label map sample: {label_map}")
    # Save the label encoder for use at inference time (assumes the cate/ directory exists)
    joblib.dump(label_encoder, "cate/label_encoder.pkl")
    print(f"✅ Label mapping done | classes: {len(label_map)}")

    # 4. Train/test split
    train_df, test_df = train_test_split(
        df, test_size=0.2, random_state=42, stratify=df["label_id"]
    )

    # 5. Initialize the model
    num_labels = len(label_map)
    tokenizer, model = init_model(num_labels)

    # 6. Build the datasets (using the label_id column)
    train_dataset = TextDataset(train_df, tokenizer, label_col="label_id")
    test_dataset = TextDataset(test_df, tokenizer, label_col="label_id")

    # 7. Training configuration
    training_args = get_training_args()

    # 8. Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
    )

    # 9. Train and save
    trainer.train()
    model.save_pretrained(Config.SAVE_DIR)
    tokenizer.save_pretrained(Config.SAVE_DIR)
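
    # A separate inference script could reload the saved artifacts like this (hedged sketch,
    # using the same Config.SAVE_DIR as above):
    #   model = BertForSequenceClassification.from_pretrained(Config.SAVE_DIR).to(Config.DEVICE)
    #   tokenizer = BertTokenizer.from_pretrained(Config.SAVE_DIR)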

    # 10. Test inference
    test_samples = ["不二家棒棒糖", "iPhone 15", "无线鼠标"]
    # Clean the test samples first (note: "iPhone 15" becomes an empty string after cleaning)
    cleaned_samples = [clean_chinese_text(s) for s in test_samples]
    predictions = batch_predict(cleaned_samples, model, tokenizer, label_map)
    for sample, pred in zip(test_samples, predictions):
        print(
            f"Input: {sample}\nCleaned: {clean_chinese_text(sample)}\n"
            f"Prediction: {pred['category']} (confidence: {pred['confidence']:.2f})\n"
        )