commit 8ce8c75911
parent dd251baec0
Dockerfile

@@ -12,4 +12,4 @@ COPY . .
 EXPOSE 5000
 
 # Make sure the module name and the Flask instance name are correct (the default is app:app)
-CMD ["gunicorn", "-w", "2", "-k", "gthread", "--threads", "4", "-b", "0.0.0.0:5000", "app:app"]
+CMD ["gunicorn", "-w", "2", "-k", "gthread", "--threads", "4", "-b", "0.0.0.0:5001", "app:app"]
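
Note: with 2 gthread workers at 4 threads each, this server handles up to 2 × 4 = 8 requests concurrently. The same settings could live in a gunicorn config file; a minimal sketch, assuming a hypothetical gunicorn.conf.py that is not part of this commit:

    # gunicorn.conf.py (hypothetical): equivalent to the CMD flags above
    workers = 2               # "-w 2": worker processes
    worker_class = "gthread"  # "-k gthread": threaded workers
    threads = 4               # "--threads 4": threads per worker
    bind = "0.0.0.0:5001"     # "-b 0.0.0.0:5001": listen address

It would be launched with gunicorn -c gunicorn.conf.py app:app. Also note that EXPOSE 5000 is unchanged while gunicorn now binds 5001; EXPOSE is documentation only, so the container still runs, but any published port mapping must now target 5001.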
app.py

@@ -124,7 +124,7 @@ def health_check():
 
 
 if __name__ == '__main__':
-    app.run(host='0.0.0.0', port=5000, threaded=True)
+    app.run(host='0.0.0.0', port=5001, threaded=True)
 else:
     application = app  # compatible with the WSGI standard (e.g. Gunicorn)
 
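Note: the app:app target in the Dockerfile's CMD means "module app, attribute app", and the else branch additionally exposes the conventional application name for WSGI servers that look it up by default. A minimal sketch of the module layout this diff implies; only the lines shown in the hunk are from the commit:

    from flask import Flask

    app = Flask(__name__)  # assumed; the instance name must match the app:app target

    if __name__ == '__main__':
        # Run directly: Flask's built-in server, now on port 5001
        app.run(host='0.0.0.0', port=5001, threaded=True)
    else:
        # Imported by a WSGI server such as gunicorn
        application = app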

@@ -15,7 +15,7 @@ import re  # used for regex text cleaning
 from sklearn.preprocessing import LabelEncoder
 import joblib
 
-# 1. Parameter configuration (centralized management)
+# 1. Parameter configuration
 class Config:
     MODEL_NAME = "bert-base-chinese"
     MAX_LENGTH = 64
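Note: later hunks reference several Config fields that this hunk does not show. A sketch of the whole class consistent with all the hunks, where every value not visible in the diff is an illustrative guess:

    import torch

    class Config:
        MODEL_NAME = "bert-base-chinese"
        MAX_LENGTH = 64
        BATCH_SIZE = 32                     # assumed
        NUM_EPOCHS = 3                      # assumed
        LEARNING_RATE = 2e-5                # assumed
        WARMUP_STEPS = 100                  # assumed
        WEIGHT_DECAY = 0.01                 # assumed
        OUTPUT_DIR = "cate/output"          # assumed
        LOG_DIR = "cate/logs"               # assumed
        SAVE_DIR = "cate/saved_model"       # assumed
        FP16 = torch.cuda.is_available()    # assumed definition
        DEVICE = "cuda" if FP16 else "cpu"  # shown in the next hunk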
@@ -31,7 +31,7 @@ class Config:
     DEVICE = "cuda" if FP16 else "cpu"
 
 
-# 2. Data loading and preprocessing (with exception handling and logging)
+# 2. Data loading and preprocessing
 def load_data(file_path):
     try:
         df = pd.read_csv(file_path)
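Note: the old comment promised exception handling and logging, but only the try and read_csv lines are visible. A plausible completion, with everything after read_csv assumed:

    import pandas as pd

    def load_data(file_path):
        try:
            df = pd.read_csv(file_path)
            print(f"Loaded {len(df)} rows from {file_path}")  # assumed logging
            return df
        except FileNotFoundError as e:
            print(f"Failed to load {file_path}: {e}")  # assumed handling
            raise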
@@ -50,13 +50,11 @@ def clean_chinese_text(text):
     """
     if not isinstance(text, str):
         return ""
-    # Match all Chinese characters with a regex (including Chinese punctuation): [^\u4e00-\u9fa5\u3000-\u303f\uff00-\uffef]
-    # For a stricter Han-characters-only filter, use: [\u4e00-\u9fa5]
     cleaned_text = re.sub(r'[^\u4e00-\u9fa5]', '', text)
     return cleaned_text.strip()
 
 
-# 3. Optimized Dataset (with in-memory caching and batch support)
+# 3. Dataset
 class TextDataset(Dataset):
     def __init__(self, dataframe, tokenizer, text_col="sentence", label_col="label"):
         self.data = dataframe
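Note: the regex [^\u4e00-\u9fa5] strips everything outside the CJK Unified Ideographs block, including ASCII letters, digits, and all punctuation; the deleted comments pointed out that adding \u3000-\u303f and \uff00-\uffef would also keep CJK punctuation and full-width forms. A quick illustration of the function exactly as it stands in the diff:

    import re

    def clean_chinese_text(text):
        if not isinstance(text, str):
            return ""
        return re.sub(r'[^\u4e00-\u9fa5]', '', text).strip()

    print(clean_chinese_text("Hello,世界!2024 你好"))  # -> 世界你好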
@@ -85,30 +83,29 @@ class TextDataset(Dataset):
         }
 
 
-# 4. Model initialization (with device placement)
+# 4. Model initialization
 def init_model(num_labels):
     tokenizer = BertTokenizer.from_pretrained(Config.MODEL_NAME)
     model = BertForSequenceClassification.from_pretrained(
         Config.MODEL_NAME,
         num_labels=num_labels,
-        ignore_mismatched_sizes=True  # optional
+        ignore_mismatched_sizes=True  # silence the size-mismatch warning
     ).to(Config.DEVICE)
     return tokenizer, model
 
 
-# 5. Training configuration (with early stopping and gradient accumulation)
+# 5. Training configuration
 def get_training_args():
     return TrainingArguments(
-        output_dir=Config.OUTPUT_DIR,
+        output_dir=Config.OUTPUT_DIR,  # output directory
-        num_train_epochs=Config.NUM_EPOCHS,
+        num_train_epochs=Config.NUM_EPOCHS,  # number of training epochs; moderate training improves accuracy, but too many epochs can reduce it because of noise (mislabeled samples) in the training data. Remedies: regularization, early stopping, data augmentation; also watch for exploding gradients
-        per_device_train_batch_size=Config.BATCH_SIZE,
+        per_device_train_batch_size=Config.BATCH_SIZE,  # samples handled per forward pass on each device; e.g. with per_device_train_batch_size=32 and 2 GPUs, each GPU independently processes 32 samples. The total batch size is total_batch_size = per_device_train_batch_size × number of GPUs × gradient_accumulation_steps
-        per_device_eval_batch_size=Config.BATCH_SIZE * 2,  # a larger batch is fine at evaluation time
+        per_device_eval_batch_size=Config.BATCH_SIZE * 2,
         learning_rate=Config.LEARNING_RATE,
         warmup_steps=Config.WARMUP_STEPS,
         weight_decay=Config.WEIGHT_DECAY,
         logging_dir=Config.LOG_DIR,
         logging_steps=10,
-        eval_strategy="steps",
         eval_steps=100,
         save_strategy="steps",
         save_steps=200,
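Note: the batch-size comment above gives the effective-batch formula; worked through with illustrative numbers (the per-device size and GPU count are assumptions, gradient_accumulation_steps=2 appears in the next hunk):

    per_device_train_batch_size = 32   # assumed
    num_gpus = 2                       # assumed
    gradient_accumulation_steps = 2    # from get_training_args()

    total = per_device_train_batch_size * num_gpus * gradient_accumulation_steps
    print(total)  # 128: the optimizer steps once per 128 samples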
@@ -116,13 +113,13 @@ def get_training_args():
         metric_for_best_model="eval_loss",
         greater_is_better=False,
         fp16=Config.FP16,
-        gradient_accumulation_steps=2,  # simulates a larger batch
+        gradient_accumulation_steps=2,
         report_to="none",  # disable wandb and other reporting integrations
         seed=42
     )
 
 
-# 6. Optimized inference function (with batch support)
+# 6. Inference and test function
 @torch.no_grad()
 def batch_predict(texts, model, tokenizer, label_map, top_k=1, batch_size=16):
     model.eval()
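Note: only the decorator, signature, and model.eval() are visible here. A sketch of how a batched top-k prediction body might look; everything past model.eval() is an assumption:

    import torch

    @torch.no_grad()
    def batch_predict(texts, model, tokenizer, label_map, top_k=1, batch_size=16):
        model.eval()
        results = []
        for i in range(0, len(texts), batch_size):
            enc = tokenizer(
                texts[i:i + batch_size],
                padding=True,
                truncation=True,
                max_length=Config.MAX_LENGTH,
                return_tensors="pt",
            ).to(Config.DEVICE)
            probs = torch.softmax(model(**enc).logits, dim=-1)
            top = probs.topk(top_k, dim=-1)
            for ids, ps in zip(top.indices.tolist(), top.values.tolist()):
                results.append([(label_map[j], p) for j, p in zip(ids, ps)])
        return results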
@@ -176,23 +173,23 @@ if __name__ == "__main__":
     joblib.dump(label_encoder, "cate/label_encoder.pkl")
     print(f"✅ Label mapping complete | number of classes: {len(label_map)}")
 
-    # 4. Split the dataset (using the label_id column)
+    # 4. Split the dataset
     train_df, test_df = train_test_split(
-        df, test_size=0.2, random_state=42, stratify=df["label_id"]  # note that label_id is used here
+        df, test_size=0.2, random_state=42, stratify=df["label_id"]
     )
 
-    # 5. Initialize the model (with the number of numeric labels)
+    # 5. Initialize the model
     num_labels = len(label_map)
     tokenizer, model = init_model(num_labels)
 
     # 6. Prepare the datasets (using the label_id column)
-    train_dataset = TextDataset(train_df, tokenizer, label_col="label_id")  # specify label_col
+    train_dataset = TextDataset(train_df, tokenizer, label_col="label_id")
     test_dataset = TextDataset(test_df, tokenizer, label_col="label_id")
 
-    # 7. Training configuration (unchanged)
+    # 7. Training configuration
     training_args = get_training_args()
 
-    # 8. Trainer (unchanged)
+    # 8. Trainer
     trainer = Trainer(
         model=model,
         args=training_args,
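Note: the LabelEncoder serialized to cate/label_encoder.pkl above can be reloaded at inference time to rebuild label_map; a common construction, though the exact one is not shown in the diff:

    import joblib

    label_encoder = joblib.load("cate/label_encoder.pkl")
    label_map = dict(enumerate(label_encoder.classes_))  # id -> original label string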
@@ -201,7 +198,7 @@ if __name__ == "__main__":
         callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
     )
 
-    # 9. Train and save (unchanged)
+    # 9. Train and save
     trainer.train()
     model.save_pretrained(Config.SAVE_DIR)
     tokenizer.save_pretrained(Config.SAVE_DIR)
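Note: an earlier hunk drops eval_strategy="steps" while keeping eval_steps; recent transformers releases renamed evaluation_strategy to eval_strategy, and EarlyStoppingCallback still needs periodic evaluation plus load_best_model_at_end=True to take effect. A minimal working combination, sketched with assumed values where the diff shows none:

    from transformers import Trainer, TrainingArguments, EarlyStoppingCallback

    args = TrainingArguments(
        output_dir="out",                  # assumed
        eval_strategy="steps",             # evaluation must actually run
        eval_steps=100,
        save_strategy="steps",
        save_steps=200,                    # must be a multiple of eval_steps here
        load_best_model_at_end=True,       # required by EarlyStoppingCallback
        metric_for_best_model="eval_loss",
        greater_is_better=False,
    )
    trainer = Trainer(
        model=model,                       # names as in the diff
        args=args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,         # assumed; not shown in the hunks
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
    )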