# Fine-tune Qwen/Qwen3-0.6B on the chinese-kuakua-collection dataset
# using the ModelScope trainer API.
from modelscope.msdatasets import MsDataset
from modelscope.utils.hub import read_config
from modelscope.trainers import build_trainer
class Config:
    """Static hyperparameter settings for the Qwen3 fine-tuning run.

    NOTE(review): only ``MODEL_NAME`` is actually consumed elsewhere in this
    script; the remaining constants (batch size, learning rate, epochs, output
    paths, ...) are never wired into the trainer config — confirm whether they
    should be applied inside ``cfg_modify_fn``.
    """

    # Hub id of the base model to fine-tune.
    MODEL_NAME = "Qwen/Qwen3-0.6B"
    # NOTE(review): likely a typo for PER_DEVICE_TRAIN_BATCH_SIZE; name kept
    # unchanged for backward compatibility with any external users.
    PRE_DEVICE_TRAIN_BATCH_SIZE = 1
    GRADIENT_ACCUMULATION_STEPS = 8
    LEARNING_RATE = 2e-5
    # NOTE(review): 500 epochs here vs. ``cfg.train.max_epochs = 5`` set in
    # cfg_modify_fn — the two are inconsistent; cfg_modify_fn wins at runtime.
    NUMBER_TRAIN_EPOCH = 500
    # Directory for checkpoints / final weights.
    OUTPUT_DIR = "./qwen3_finetune"
    SAVE_STRATEGY = "epoch"
    # Train in half precision.
    FP16 = True
    LOGGING_DIR = "./logs"
# Fetch the model's default training configuration from the hub and show it
# so the run log records the baseline settings.
cfg = read_config(Config.MODEL_NAME)
print(cfg)

# Automatic download from the ModelScope hub:
# train_dataset = MsDataset.load('iic/chinese-kuakua-collection', subset_name='default', split='train')
# test_dataset = MsDataset.load('iic/chinese-kuakua-collection', subset_name='default', split='test')

# Load from local CSV copies instead (assumes the files exist on disk —
# TODO confirm the relative paths are valid from the working directory).
train_dataset = MsDataset.load('./chinese-kuakua-collection/train.csv')
eval_dataset = MsDataset.load('./chinese-kuakua-collection/test.csv')
def cfg_modify_fn(cfg):
    """Adjust the hub-provided trainer configuration before training.

    Args:
        cfg: mutable configuration object (ModelScope ``Config``) exposing
            ``train.max_epochs`` and ``train.dataloader`` attributes.

    Returns:
        The same ``cfg`` object, modified in place. The trainer requires the
        modified config to be returned from this hook.
    """
    # NOTE(review): 5 epochs here overrides Config.NUMBER_TRAIN_EPOCH (500);
    # Config.LEARNING_RATE and Config.PRE_DEVICE_TRAIN_BATCH_SIZE are also
    # never applied — confirm which values are intended.
    cfg.train.max_epochs = 5
    cfg.train.dataloader.batch_size_per_gpu = 32
    # 0 workers: load batches in the main process (avoids dataloader
    # multiprocessing issues on constrained environments).
    cfg.train.dataloader.workers_per_gpu = 0
    return cfg
# Assemble the trainer arguments: base model, datasets loaded above, and the
# config-mutation hook that tunes epochs / dataloader settings.
kwargs = dict(
    model=Config.MODEL_NAME,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    cfg_modify_fn=cfg_modify_fn)

# Build the default ModelScope trainer and start fine-tuning.
trainer = build_trainer(default_args=kwargs)
trainer.train()