This commit is contained in:
renzhiyuan 2025-10-11 11:11:45 +08:00
parent dd251baec0
commit 3abfe72e19
3 changed files with 187654 additions and 3 deletions

4
app.py
View File

@ -11,7 +11,7 @@ import threading
# 配置参数 # 配置参数
MAX_LENGTH = 512 MAX_LENGTH = 512
DEVICE = "cuda" if torch.cuda.is_available() else "cpu" DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SAVE_DIR = "./cate1" # 模型路径 SAVE_DIR = "cate1" # 模型路径
LABEL_ENCODER_PATH = "cate1/label_encoder.pkl" # 标签映射器路径 LABEL_ENCODER_PATH = "cate1/label_encoder.pkl" # 标签映射器路径
app = Flask(__name__) app = Flask(__name__)
@ -124,7 +124,7 @@ def health_check():
if __name__ == '__main__': if __name__ == '__main__':
app.run(host='0.0.0.0', port=5000, threaded=True) app.run(host='0.0.0.0', port=5002, threaded=True)
else: else:
application = app # 兼容 WSGI 标准(如 Gunicorn application = app # 兼容 WSGI 标准(如 Gunicorn

View File

@ -18,6 +18,7 @@ import joblib
# 1. 参数配置(集中管理) # 1. 参数配置(集中管理)
class Config: class Config:
MODEL_NAME = "bert-base-chinese" MODEL_NAME = "bert-base-chinese"
Train_CSV = "order_address.csv"
MAX_LENGTH = 64 MAX_LENGTH = 64
BATCH_SIZE = 32 BATCH_SIZE = 32
NUM_EPOCHS = 5 NUM_EPOCHS = 5
@ -157,7 +158,7 @@ def batch_predict(texts, model, tokenizer, label_map, top_k=1, batch_size=16):
# 主流程 # 主流程
if __name__ == "__main__": if __name__ == "__main__":
# 1. 加载数据 # 1. 加载数据
df = load_data("goods_cate.csv") df = load_data(Config.Train_CSV)
# 2. 数据清洗 - 只保留中文 # 2. 数据清洗 - 只保留中文
print("🧼 开始清洗文本数据...") print("🧼 开始清洗文本数据...")

187650
order_address.csv Normal file

File diff suppressed because it is too large Load Diff