llm_train/config.py

55 lines
1.8 KiB
Python

import os
class Config:
# 基座模型
MODEL_NAME = "Qwen/Qwen3-0.6B"
# 数据集
DATASET_NAME = "krisfu/delicate_medical_r1_data"
# 数据集主题(子数据集)
DATASET_SUBJECT = "default"
# 数据集用途
DATASET_SPLIT = "train"
# 是否使用缓存
DATASET_USE_CACHE = True
# swanlab项目名称
SWANLAB_PROJECT = "qweb3-sft-medical-10-11-1"
# 验证用system的提示词
PROMPT = "你是一个医学专家,你需要根据用户的问题,给出带有思考的回答。"
DATA_MAX_LENGTH = 2048
class Default:
DATASET_PATH = os.getenv("DATASET_PATH", "./dataset") # 支持环境变量覆
MODEL_DATASET_PATH = os.getenv("MODEL_DATASET_PATH", "./model_dataset") # 支持环境变量覆
SAVE_DIR = "./saved_model" # 微调后模型存储位置
TRAIN_DATASET_FILE = "train.jsonl"
TEST_DATASET_FILE = "val.jsonl"
TRAIN_JSONL_NEW_FILE = "train_format.jsonl"
TEST_JSONL_NEW_FILE = "val_format.jsonl"
dataset_short_name = Config.DATASET_NAME.split("/")[-1]
model_dataset_short_name = Config.MODEL_NAME.split("/")[-1]
# 确保缓存目录存在
dataset_dir = os.path.normpath(
os.path.join(Default.DATASET_PATH, dataset_short_name, Config.DATASET_SUBJECT, Config.DATASET_SPLIT)
)
model_dataset_DIR = os.path.normpath(
os.path.join(Default.MODEL_DATASET_PATH, model_dataset_short_name, Config.DATASET_SUBJECT, Config.DATASET_SPLIT)
)
model_dir = os.path.normpath(
os.path.join(Default.SAVE_DIR,model_dataset_short_name, dataset_short_name, Config.DATASET_SUBJECT, Config.DATASET_SPLIT)
)
os.makedirs(dataset_dir, exist_ok=True)
os.makedirs(model_dataset_DIR, exist_ok=True)
os.makedirs(model_dir, exist_ok=True)
class Dir:
DATASET_DIR = dataset_dir
MODEL_DIR = model_dir
MODEL_DATASET_DIR = model_dataset_DIR