This commit is contained in:
parent
dd251baec0
commit
3abfe72e19
4
app.py
4
app.py
|
@ -11,7 +11,7 @@ import threading
|
||||||
# 配置参数
|
# 配置参数
|
||||||
MAX_LENGTH = 512
|
MAX_LENGTH = 512
|
||||||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
SAVE_DIR = "./cate1" # 模型路径
|
SAVE_DIR = "cate1" # 模型路径
|
||||||
LABEL_ENCODER_PATH = "cate1/label_encoder.pkl" # 标签映射器路径
|
LABEL_ENCODER_PATH = "cate1/label_encoder.pkl" # 标签映射器路径
|
||||||
|
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
|
@ -124,7 +124,7 @@ def health_check():
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
app.run(host='0.0.0.0', port=5000, threaded=True)
|
app.run(host='0.0.0.0', port=5002, threaded=True)
|
||||||
else:
|
else:
|
||||||
application = app # 兼容 WSGI 标准(如 Gunicorn)
|
application = app # 兼容 WSGI 标准(如 Gunicorn)
|
||||||
|
|
||||||
|
|
|
@ -18,6 +18,7 @@ import joblib
|
||||||
# 1. 参数配置(集中管理)
|
# 1. 参数配置(集中管理)
|
||||||
class Config:
|
class Config:
|
||||||
MODEL_NAME = "bert-base-chinese"
|
MODEL_NAME = "bert-base-chinese"
|
||||||
|
Train_CSV = "order_address.csv"
|
||||||
MAX_LENGTH = 64
|
MAX_LENGTH = 64
|
||||||
BATCH_SIZE = 32
|
BATCH_SIZE = 32
|
||||||
NUM_EPOCHS = 5
|
NUM_EPOCHS = 5
|
||||||
|
@ -157,7 +158,7 @@ def batch_predict(texts, model, tokenizer, label_map, top_k=1, batch_size=16):
|
||||||
# 主流程
|
# 主流程
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# 1. 加载数据
|
# 1. 加载数据
|
||||||
df = load_data("goods_cate.csv")
|
df = load_data(Config.Train_CSV)
|
||||||
|
|
||||||
# 2. 数据清洗 - 只保留中文
|
# 2. 数据清洗 - 只保留中文
|
||||||
print("🧼 开始清洗文本数据...")
|
print("🧼 开始清洗文本数据...")
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue