diff --git a/ai.go b/ai.go index 1c82019..1b6646b 100644 --- a/ai.go +++ b/ai.go @@ -4,17 +4,54 @@ import ( "context" "encoding/json" "fmt" - "gitea.cdlsxd.cn/self-tools/l_ai_category/doubao" - "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model" + "gitea.cdlsxd.cn/self-tools/l_request" "strconv" "strings" ) +func GetCategoryWithSelf(ctx context.Context, goodsName string, address string) (cateId int, confidence float64, err error) { + cateId, name, confidence, err := getCategoryWithAi(ctx, goodsName, address) + fmt.Printf("商品:%s,所属分类:%s,置信度:%f\n", goodsName, name, confidence) + return +} + +type CatRes struct { + Predictions []struct { + Category string `json:"category"` + Confidence float64 `json:"confidence"` + } `json:"predictions"` + Status string `json:"status"` +} + +func getCategoryWithAi(ctx context.Context, goodsName string, address string) (cateId int, cateName string, confidence float64, err error) { + request := map[string]interface{}{ + "product_names": goodsName, + } + req := l_request.Request{ + Method: "POST", + Url: address, + Json: request, + } + res, err := req.Send() + if err != nil { + return + } + var resMap CatRes + err = json.Unmarshal(res.Content, &resMap) + if err != nil { + return + } + cateInfo := strings.Split(resMap.Predictions[0].Category, "_") + cateId, _ = strconv.Atoi(cateInfo[1]) + return cateId, cateInfo[0], resMap.Predictions[0].Confidence, nil +} + type Category struct { - Id int `json:"id"` - Name string `json:"name"` - Level int `json:"level"` - Pid int `json:"pid"` + Id int `json:"id"` + Name string `json:"name"` + Level int `json:"level"` + Pid int `json:"pid"` + Pids map[int]int `json:"pid"` } type CategoryDic struct { @@ -25,58 +62,7 @@ type QuesStruct struct { Name string `json:"name"` } -func GetCategory(ctx context.Context, goodsName, key, chatModel string, catePath []map[int][]*CategoryDic) (cate3 int, err error) { - var ( - usage1 int - usage2 int - usage3 int - ) - defer func() { - fmt.Printf("第一次用量:%d,第二次用量:%d,第三次用量:%d,总共花费:%d \n", usage1, usage2, usage3, usage1+usage2+usage3) - }() - cate1Json, _ := json.Marshal(catePath[0][1]) - cate1, usage1, err := ques(ctx, goodsName, key, chatModel, cate1Json, "") - if err != nil || cate1 == 0 { - return - } - cate2Json, _ := json.Marshal(catePath[1][cate1]) - cate2, usage2, err := ques(ctx, goodsName, key, chatModel, cate2Json, "") - if err != nil || cate2 == 0 { - return - } - cate3Json, _ := json.Marshal(catePath[2][cate2]) - cate3, usage3, err = ques(ctx, goodsName, key, chatModel, cate3Json, "匹配一个最相近的,或者商品名称里面包含的分类") - - return -} - -func ques(ctx context.Context, goodsName, key, chatModel string, cateJson []byte, exCommand string) (categoryId int, usage int, err error) { - var quesList []*QuesStruct - err = json.Unmarshal(cateJson, &quesList) - quesByte, _ := json.Marshal(quesList) - modelObj := doubao.NewDouBao(chatModel, key) - text := []string{ - "根据商品名称,从json中找到该商品对应的分类的数字id[goodsName]" + goodsName + "[/goodsName]", - "-只返回数字,不要有其他文字和字符", - "-JSON数据:" + string(quesByte), - } - if exCommand != "" { - text = append(text, exCommand) - } - - text = append(text, "-如果无法匹配,则返回0") - - category, usage, err := modelObj.GetData(ctx, key, model.ChatMessageRoleUser, func(input string) (string, error) { - cleaned := strings.ReplaceAll(input, "\n", "") - return cleaned, nil - }, text...) - if err != nil { - return - } - categoryId, err = strconv.Atoi(category) - if err != nil { - err = fmt.Errorf("为找到分类") - categoryId = 0 - } - return +type Req struct { + Text string `json:"text"` + Cate []string `json:"cate"` } diff --git a/ai_test.go b/ai_test.go index 25a438f..9605056 100644 --- a/ai_test.go +++ b/ai_test.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "gitea.cdlsxd.cn/self-tools/l_ai_category/utils_gorm" - "sync" "testing" "time" ) @@ -15,38 +14,33 @@ const ( table = "category" modelType = "deepseek-r1-distill-qwen-32b-250120" key = "03320e58-6a0b-4061-a22b-902039f2190d" + Address = "http://8.147.107.207:5000/predict" ) -func TestCategory(t *testing.T) { - path := getPath() - res, err := GetCategory(context.Background(), "半亩花田倍润身体乳(失重玫瑰)250ml@默认", "03320e58-6a0b-4061-a22b-902039f2190d", modelType, path) - t.Log(res, err) -} - func TestCategoryGoodsSet(t *testing.T) { - var wg sync.WaitGroup + //var wg sync.WaitGroup defer clean() - path := getPath() + ///path := getPath() goods := goodsAll() updateChan := make(chan *UpdateLog, len(goods)) start := time.Now().Unix() - wg.Add(len(goods)) + //wg.Add(len(goods)) for _, v := range goods { - go func(updateChan chan *UpdateLog) { - defer wg.Done() - res, err := GetCategory(context.Background(), v.Title, key, modelType, path) + //go func(updateChan chan *UpdateLog) { + //defer wg.Done() + res, _, err := GetCategoryWithSelf(context.Background(), v.Title, Address) - if err != nil || res == 0 { - return - } - updateChan <- &UpdateLog{ - GoodsId: v.Id, - Title: v.Title, - CateId: res, - } - }(updateChan) + if err != nil || res == 0 { + return + } + updateChan <- &UpdateLog{ + GoodsId: v.Id, + Title: v.Title, + CateId: res, + } + //}(updateChan) } - wg.Wait() + //wg.Wait() close(updateChan) end := time.Now().Unix() fmt.Printf("耗时:%d秒", end-start) @@ -92,7 +86,7 @@ func goodsAll() (data []Goods) { row Goods err error ) - rows, _ := db.Raw("SELECT id,title FROM `goods` where brand_category_check=4").Rows() + rows, _ := db.Raw("SELECT id,title FROM `goods` order by id desc limit 1000 offset 3156").Rows() //where brand_category_check=4 defer rows.Close() for rows.Next() { err = db.ScanRows(rows, &row) @@ -126,6 +120,78 @@ func getPath() []map[int][]*CategoryDic { return cateLevelPath } +func getPath2() []Category { + cates := cateAll() + var ( + level1Map = make(map[int]Category) + level2Map = make(map[int]Category) + level3Slice = make([]Category, 0) + ) + + for _, v := range cates { + switch v.Level { + case 1: + level1Map[v.Id] = Category{ + Id: v.Id, + Name: v.Name, + } + case 2: + level2Map[v.Id] = Category{ + Id: v.Id, + Name: v.Name, + Pid: v.ParentId, + } + case 3: + cateName := fmt.Sprintf("%s/%s/%s", level1Map[level2Map[v.ParentId].Pid].Name, level2Map[v.ParentId].Name, v.Name) + level3Slice = append(level3Slice, Category{ + Id: v.Id, + Name: cateName, + }) + } + + } + return level3Slice +} + +func getPath3() (map[int]Category, map[string]Category) { + cates := cateAll() + var ( + level1Map = make(map[int]Category) + level2Map = make(map[int]Category) + level3Map = make(map[string]Category) + ) + + for _, v := range cates { + switch v.Level { + case 1: + + level2Map[v.Id] = Category{ + Name: v.Name, + } + case 2: + cate_name := fmt.Sprintf("%s/%s", level1Map[v.ParentId].Name, v.Name) + level2Map[v.Id] = Category{ + Id: v.Id, + Name: cate_name, + Pid: v.ParentId, + } + case 3: + if _, ex := level3Map[v.Name]; !ex { + level3Map[v.Name] = Category{ + Id: v.Id, + Pids: map[int]int{v.Id: v.ParentId}, + Name: v.Name, + } + } else { + level3Map[v.Name].Pids[v.Id] = v.ParentId + } + + } + + } + return level2Map, level3Map +} + func cateAll() (data []CategoryPhy) { var ( row CategoryPhy diff --git a/go.mod b/go.mod index e18d524..86bab5f 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,8 @@ module gitea.cdlsxd.cn/self-tools/l_ai_category -go 1.22.2 +go 1.22.10 + +toolchain go1.23.7 require ( github.com/volcengine/volcengine-go-sdk v1.0.187 @@ -9,6 +11,7 @@ require ( ) require ( + gitea.cdlsxd.cn/self-tools/l_request v1.0.8 // indirect github.com/go-sql-driver/mysql v1.7.0 // indirect github.com/google/uuid v1.3.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect