55 lines
1.8 KiB
Python
55 lines
1.8 KiB
Python
import os
|
|
|
|
|
|
class Config:
|
|
# 基座模型
|
|
MODEL_NAME = "Qwen/Qwen3-0.6B"
|
|
# 数据集
|
|
DATASET_NAME = "krisfu/delicate_medical_r1_data"
|
|
# 数据集主题(子数据集)
|
|
DATASET_SUBJECT = "default"
|
|
# 数据集用途
|
|
DATASET_SPLIT = "train"
|
|
# 是否使用缓存
|
|
DATASET_USE_CACHE = True
|
|
# swanlab项目名称
|
|
SWANLAB_PROJECT = "qweb3-sft-medical-10-11-1"
|
|
# 验证用system的提示词
|
|
PROMPT = "你是一个医学专家,你需要根据用户的问题,给出带有思考的回答。"
|
|
DATA_MAX_LENGTH = 2048
|
|
|
|
|
|
class Default:
|
|
DATASET_PATH = os.getenv("DATASET_PATH", "./dataset") # 支持环境变量覆
|
|
MODEL_DATASET_PATH = os.getenv("MODEL_DATASET_PATH", "./model_dataset") # 支持环境变量覆
|
|
SAVE_DIR = "./saved_model" # 微调后模型存储位置
|
|
TRAIN_DATASET_FILE = "train.jsonl"
|
|
TEST_DATASET_FILE = "val.jsonl"
|
|
TRAIN_JSONL_NEW_FILE = "train_format.jsonl"
|
|
TEST_JSONL_NEW_FILE = "val_format.jsonl"
|
|
|
|
|
|
dataset_short_name = Config.DATASET_NAME.split("/")[-1]
|
|
model_dataset_short_name = Config.MODEL_NAME.split("/")[-1]
|
|
# 确保缓存目录存在
|
|
dataset_dir = os.path.normpath(
|
|
os.path.join(Default.DATASET_PATH, dataset_short_name, Config.DATASET_SUBJECT, Config.DATASET_SPLIT)
|
|
)
|
|
model_dataset_DIR = os.path.normpath(
|
|
os.path.join(Default.MODEL_DATASET_PATH, model_dataset_short_name, Config.DATASET_SUBJECT, Config.DATASET_SPLIT)
|
|
)
|
|
model_dir = os.path.normpath(
|
|
os.path.join(Default.SAVE_DIR,model_dataset_short_name, dataset_short_name, Config.DATASET_SUBJECT, Config.DATASET_SPLIT)
|
|
)
|
|
os.makedirs(dataset_dir, exist_ok=True)
|
|
os.makedirs(model_dataset_DIR, exist_ok=True)
|
|
os.makedirs(model_dir, exist_ok=True)
|
|
|
|
|
|
class Dir:
|
|
DATASET_DIR = dataset_dir
|
|
MODEL_DIR = model_dir
|
|
MODEL_DATASET_DIR = model_dataset_DIR
|
|
|
|
|