diff --git a/ai.go b/ai.go index 084af66..1c82019 100644 --- a/ai.go +++ b/ai.go @@ -7,6 +7,7 @@ import ( "gitea.cdlsxd.cn/self-tools/l_ai_category/doubao" "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model" "strconv" + "strings" ) type Category struct { @@ -25,37 +26,49 @@ type QuesStruct struct { } func GetCategory(ctx context.Context, goodsName, key, chatModel string, catePath []map[int][]*CategoryDic) (cate3 int, err error) { + var ( + usage1 int + usage2 int + usage3 int + ) + defer func() { + fmt.Printf("第一次用量:%d,第二次用量:%d,第三次用量:%d,总共花费:%d \n", usage1, usage2, usage3, usage1+usage2+usage3) + }() cate1Json, _ := json.Marshal(catePath[0][1]) - cate1, err := ques(ctx, goodsName, key, chatModel, cate1Json) - if err != nil { + cate1, usage1, err := ques(ctx, goodsName, key, chatModel, cate1Json, "") + if err != nil || cate1 == 0 { return } cate2Json, _ := json.Marshal(catePath[1][cate1]) - cate2, err := ques(ctx, goodsName, key, chatModel, cate2Json) - if err != nil { + cate2, usage2, err := ques(ctx, goodsName, key, chatModel, cate2Json, "") + if err != nil || cate2 == 0 { return } cate3Json, _ := json.Marshal(catePath[2][cate2]) - cate3, err = ques(ctx, goodsName, key, chatModel, cate3Json) - if err != nil { - return - } + cate3, usage3, err = ques(ctx, goodsName, key, chatModel, cate3Json, "匹配一个最相近的,或者商品名称里面包含的分类") + return } -func ques(ctx context.Context, goodsName, key, chatModel string, cateJson []byte) (categoryId int, err error) { +func ques(ctx context.Context, goodsName, key, chatModel string, cateJson []byte, exCommand string) (categoryId int, usage int, err error) { var quesList []*QuesStruct err = json.Unmarshal(cateJson, &quesList) quesByte, _ := json.Marshal(quesList) modelObj := doubao.NewDouBao(chatModel, key) text := []string{ - "根据商品名称,从json中找到该商品对应的分类[goodsName]" + goodsName + "[/goodsName]", - "-只返回分类名称对应的id,返回其他文字", - "-如果无法匹配,则返回数字0", + "根据商品名称,从json中找到该商品对应的分类的数字id[goodsName]" + goodsName + "[/goodsName]", + "-只返回数字,不要有其他文字和字符", "-JSON数据:" + string(quesByte), } - category, err := modelObj.GetData(ctx, key, model.ChatMessageRoleUser, func(input string) (string, error) { - return input, nil + if exCommand != "" { + text = append(text, exCommand) + } + + text = append(text, "-如果无法匹配,则返回0") + + category, usage, err := modelObj.GetData(ctx, key, model.ChatMessageRoleUser, func(input string) (string, error) { + cleaned := strings.ReplaceAll(input, "\n", "") + return cleaned, nil }, text...) if err != nil { return @@ -67,39 +80,3 @@ func ques(ctx context.Context, goodsName, key, chatModel string, cateJson []byte } return } - -func Test(ctx context.Context, key, chatModel string) (category string, err error) { - - modelObj := doubao.NewDouBao(chatModel, key) - text := []string{ - "给我base_json里休闲食品对应的id", - } - category, err = modelObj.GetData(ctx, key, model.ChatMessageRoleUser, func(input string) (string, error) { - return input, nil - }, text...) - if err != nil { - return - } - - return -} - -func FlushCate(ctx context.Context, key, chatModel string, cateJson string) (res string, err error) { - - modelObj := doubao.NewDouBao(chatModel, key) - text := []string{ - "你是一个之智能非关系型数据库", - "你的任务是将提供的 JSON 作为base_json进行存放,之后所有提问提到的base_json都是指这个JSON。", - "以下是base_json:", - fmt.Sprintf("%s", cateJson), - "请将上述 JSON 作为base_json 进行存储。之后有提问提到base_json 时,要明确知晓是指此JSON。", - } - res, err = modelObj.GetData(ctx, key, model.ChatMessageRoleAssistant, func(input string) (string, error) { - return input, nil - }, text...) - if err != nil { - return - } - - return -} diff --git a/ai_test.go b/ai_test.go index 01f6b33..25a438f 100644 --- a/ai_test.go +++ b/ai_test.go @@ -4,84 +4,95 @@ import ( "context" "fmt" "gitea.cdlsxd.cn/self-tools/l_ai_category/utils_gorm" + "sync" "testing" + "time" ) const ( - dns = "root:lansexiongdi6,@tcp(47.97.27.195:3306)/report?parseTime=True&loc=Local" + dns = "developer:Lsxd@2024@tcp(lsxdpolar.rwlb.rds.aliyuncs.com:3306)/physicalgoods?parseTime=True&loc=Local" driver = "mysql" - table = "goods_category" - modelType = "deepseek-v3-250324" + table = "category" + modelType = "deepseek-r1-distill-qwen-32b-250120" + key = "03320e58-6a0b-4061-a22b-902039f2190d" ) func TestCategory(t *testing.T) { path := getPath() - res, err := GetCategory(context.Background(), "MM 便携装湿厕纸 40包*10片@默认", "03320e58-6a0b-4061-a22b-902039f2190d", modelType, path) + res, err := GetCategory(context.Background(), "半亩花田倍润身体乳(失重玫瑰)250ml@默认", "03320e58-6a0b-4061-a22b-902039f2190d", modelType, path) t.Log(res, err) } -func TestRem(t *testing.T) { - - res, err := Test(context.Background(), "914ccf1d-c002-4fad-a431-f291f5e0d2ad", modelType) - t.Log(res, err) -} - -//func TestFlush(t *testing.T) { -// catString := string(CategoryToJsonByte()) -// res, err := FlushCate(context.Background(), "914ccf1d-c002-4fad-a431-f291f5e0d2ad", modelType, catString) -// t.Log(res, err) -//} -// -//func TestGetCateGoryJson(t *testing.T) { -// cateJsonByte := CategoryToJsonByte() -// -// t.Log(string(cateJsonByte)) -//} - -//func CategoryToJsonByte() (cateJsonByte []byte) { -// -// cateLevelPath := getPath() -// cate := cateIteration(cateLevelPath[0], cateIteration(cateLevelPath[1], cateLevelPath[2]))[0] -// cateJsonByte, _ = json.Marshal(cate) -// return -//} - -func getPath() []map[int][]*CategoryDic { - cates := cateAll() - var ( - cateLevelPath = make([]map[int][]*CategoryDic, 3) - ) - for _, v := range cates { - if cateLevelPath[v.Level-1] == nil { - cateLevelPath[v.Level-1] = make(map[int][]*CategoryDic) - } - if _, ex := cateLevelPath[v.Level-1][v.Pid]; !ex { - cateLevelPath[v.Level-1][v.Pid] = make([]*CategoryDic, 0) - } - cateLevelPath[v.Level-1][v.Pid] = append(cateLevelPath[v.Level-1][v.Pid], &CategoryDic{Category: v}) - } - return cateLevelPath -} - -//func cateIteration(parent map[int][]*CategoryDic, child map[int][]*CategoryDic) map[int][]*CategoryDic { -// for _, v := range parent { -// for _, vv := range v { -// vv.Children = child[vv.Id] -// } -// } -// return parent -//} - -func cateAll() (data []Category) { - var ( - row Category - ) - db, err, clean := utils_gorm.DB(driver, dns) +func TestCategoryGoodsSet(t *testing.T) { + var wg sync.WaitGroup defer clean() - if err != nil { - panic(err) + path := getPath() + goods := goodsAll() + updateChan := make(chan *UpdateLog, len(goods)) + start := time.Now().Unix() + wg.Add(len(goods)) + for _, v := range goods { + go func(updateChan chan *UpdateLog) { + defer wg.Done() + res, err := GetCategory(context.Background(), v.Title, key, modelType, path) + + if err != nil || res == 0 { + return + } + updateChan <- &UpdateLog{ + GoodsId: v.Id, + Title: v.Title, + CateId: res, + } + }(updateChan) } - rows, _ := db.Raw(fmt.Sprintf("SELECT id,name,level,pid FROM `%s` where `type`=2", table)).Rows() + wg.Wait() + close(updateChan) + end := time.Now().Unix() + fmt.Printf("耗时:%d秒", end-start) + updateSlice := make([]*UpdateLog, len(updateChan)) + for v := range updateChan { + updateSlice = append(updateSlice, v) + } + for _, v := range updateSlice { + if v == nil { + continue + } + updateGoodsCate(v) + } + + t.Log(updateSlice) +} + +func updateGoodsCate(data *UpdateLog) { + db.Table("goods_category_relation").Where("goods_id = ?", data.GoodsId).Updates(map[string]interface{}{"category_id": data.CateId}) + db.Table("goods").Where("id = ?", data.GoodsId).Updates(map[string]interface{}{"brand_category_check": 3}) +} + +type UpdateLog struct { + GoodsId int `json:"goods_id"` + Title string `json:"title"` + CateId int `json:"cate_id"` +} +type CategoryPhy struct { + Id int `json:"id"` + Name string `json:"name"` + Level int `json:"level"` + ParentId int `json:"parent_id"` +} + +type Goods struct { + Id int `json:"id"` + Title string `json:"title"` + BrandCateGoryCheck int `json:"brand_category_check"` +} + +func goodsAll() (data []Goods) { + var ( + row Goods + err error + ) + rows, _ := db.Raw("SELECT id,title FROM `goods` where brand_category_check=4").Rows() defer rows.Close() for rows.Next() { err = db.ScanRows(rows, &row) @@ -92,3 +103,44 @@ func cateAll() (data []Category) { } return } + +func getPath() []map[int][]*CategoryDic { + cates := cateAll() + var ( + cateLevelPath = make([]map[int][]*CategoryDic, 3) + ) + for _, v := range cates { + if cateLevelPath[v.Level-1] == nil { + cateLevelPath[v.Level-1] = make(map[int][]*CategoryDic) + } + if _, ex := cateLevelPath[v.Level-1][v.ParentId]; !ex { + cateLevelPath[v.Level-1][v.ParentId] = make([]*CategoryDic, 0) + } + cateLevelPath[v.Level-1][v.ParentId] = append(cateLevelPath[v.Level-1][v.ParentId], &CategoryDic{ + Category{ + Id: v.Id, + Name: v.Name, + Level: v.Level, + }}) + } + return cateLevelPath +} + +func cateAll() (data []CategoryPhy) { + var ( + row CategoryPhy + ) + + rows, _ := db.Raw(fmt.Sprintf("SELECT id,name,level,parent_id FROM `%s`", table)).Rows() + defer rows.Close() + for rows.Next() { + err := db.ScanRows(rows, &row) + if err != nil { + panic(err) + } + data = append(data, row) + } + return +} + +var db, _, clean = utils_gorm.DB(driver, dns) diff --git a/doubao/doubao.go b/doubao/doubao.go index 12002e9..9700037 100644 --- a/doubao/doubao.go +++ b/doubao/doubao.go @@ -2,7 +2,6 @@ package doubao import ( "context" - "fmt" "github.com/volcengine/volcengine-go-sdk/service/arkruntime" "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model" "github.com/volcengine/volcengine-go-sdk/volcengine" @@ -21,7 +20,7 @@ func NewDouBao(modelName string, key string) *DouBao { } } -func (o *DouBao) GetData(ctx context.Context, key string, role string, respHandle func(input string) (string, error), text ...string) (string, error) { +func (o *DouBao) GetData(ctx context.Context, key string, role string, respHandle func(input string) (string, error), text ...string) (string, int, error) { var Message = make([]*model.ChatCompletionMessage, len(text)) client := arkruntime.NewClientWithApiKey( @@ -46,10 +45,10 @@ func (o *DouBao) GetData(ctx context.Context, key string, role string, respHandl resp, err := client.CreateChatCompletion(ctx, req) if err != nil { - return "", err + return "", 0, err } - fmt.Printf("用量:%d", resp.Usage.TotalTokens) + result, err := respHandle(*resp.Choices[0].Message.Content.StringValue) - return result, err + return result, resp.Usage.TotalTokens, err }