This commit is contained in:
parent
dd251baec0
commit
3abfe72e19
4
app.py
4
app.py
|
@ -11,7 +11,7 @@ import threading
|
|||
# 配置参数
|
||||
MAX_LENGTH = 512
|
||||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
SAVE_DIR = "./cate1" # 模型路径
|
||||
SAVE_DIR = "cate1" # 模型路径
|
||||
LABEL_ENCODER_PATH = "cate1/label_encoder.pkl" # 标签映射器路径
|
||||
|
||||
app = Flask(__name__)
|
||||
|
@ -124,7 +124,7 @@ def health_check():
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(host='0.0.0.0', port=5000, threaded=True)
|
||||
app.run(host='0.0.0.0', port=5002, threaded=True)
|
||||
else:
|
||||
application = app # 兼容 WSGI 标准(如 Gunicorn)
|
||||
|
||||
|
|
|
@ -18,6 +18,7 @@ import joblib
|
|||
# 1. 参数配置(集中管理)
|
||||
class Config:
|
||||
MODEL_NAME = "bert-base-chinese"
|
||||
Train_CSV = "order_address.csv"
|
||||
MAX_LENGTH = 64
|
||||
BATCH_SIZE = 32
|
||||
NUM_EPOCHS = 5
|
||||
|
@ -157,7 +158,7 @@ def batch_predict(texts, model, tokenizer, label_map, top_k=1, batch_size=16):
|
|||
# 主流程
|
||||
if __name__ == "__main__":
|
||||
# 1. 加载数据
|
||||
df = load_data("goods_cate.csv")
|
||||
df = load_data(Config.Train_CSV)
|
||||
|
||||
# 2. 数据清洗 - 只保留中文
|
||||
print("🧼 开始清洗文本数据...")
|
||||
|
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue