// Package l_ai_category maps a goods (product) name onto a three-level
// category tree by asking a Doubao chat model to pick one category per level.
package l_ai_category

import (
	"context"
	"encoding/json"
	"fmt"
	"strconv"
	"strings"

	"gitea.cdlsxd.cn/self-tools/l_ai_category/doubao"
	"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
)

// Category is one node of the category tree.
type Category struct {
	Id    int    `json:"id"`
	Name  string `json:"name"`
	Level int    `json:"level"`
	Pid   int    `json:"pid"`
}

// CategoryDic wraps a Category. Because Category is embedded without a JSON
// tag, its fields are encoded at the top level of the CategoryDic object.
type CategoryDic struct {
	Category
}

// QuesStruct is the id/name subset of a category that is included in the
// prompt sent to the model.
type QuesStruct struct {
	Id   int    `json:"id"`
	Name string `json:"name"`
}

// GetCategory classifies goodsName by querying the chat model once per level
// and returns the id of the chosen third-level category.
//
// catePath must contain at least three levels of candidates:
//   - catePath[0][1] holds the first-level candidates (the key 1 is
//     hard-coded; presumably the root parent id — confirm with the caller),
//   - catePath[1][id] holds the children of the chosen first-level id,
//   - catePath[2][id] holds the children of the chosen second-level id.
//
// chatModel and key are passed to the doubao client (key is presumably the
// API key). If the model reports no match at some level, or a level has no
// candidates, GetCategory returns 0 and a nil error. Once catePath has been
// validated, the usage reported by each model call is printed to stdout when
// the function returns.
func GetCategory(ctx context.Context, goodsName, key, chatModel string, catePath []map[int][]*CategoryDic) (cate3 int, err error) {
	const levels = 3
	// lastLevelHint is added to the prompt for the last level only; it reads
	// "match the closest one, or a category contained in the goods name".
	const lastLevelHint = "匹配一个最相近的,或者商品名称里面包含的分类"

	if len(catePath) < levels {
		return 0, fmt.Errorf("catePath has %d levels, need %d", len(catePath), levels)
	}

	var usage [levels]int
	defer func() {
		fmt.Printf("第一次用量:%d,第二次用量:%d,第三次用量:%d,总共花费:%d \n", usage[0], usage[1], usage[2], usage[0]+usage[1]+usage[2])
	}()

	var (
		cateJson []byte
		id       int
	)
	parent := 1
	for level := 0; level < levels; level++ {
		// A missing key (or nil map) yields a nil list, which encodes as
		// "null"; ques treats that as "no candidates".
		if cateJson, err = json.Marshal(catePath[level][parent]); err != nil {
			return 0, fmt.Errorf("encoding level %d candidates: %w", level+1, err)
		}
		exCommand := ""
		if level == levels-1 {
			exCommand = lastLevelHint
		}
		if id, usage[level], err = ques(ctx, goodsName, key, chatModel, cateJson, exCommand); err != nil {
			return 0, err
		}
		if id == 0 {
			// The model found no match at this level.
			return 0, nil
		}
		parent = id
	}
	return parent, nil
}

// ques asks the chat model which category in cateJson matches goodsName and
// returns that category's id together with the usage reported by the call.
//
// cateJson is a JSON array of categories; only their id and name are sent to
// the model. exCommand, when non-empty, is appended to the prompt as an extra
// instruction. A returned id of 0 with a nil error means "no match": either
// the model answered 0, or cateJson contained no candidates (in which case the
// model is not called). A reply that is not a number, or a number that is not
// the id of any candidate, is reported as an error.
func ques(ctx context.Context, goodsName, key, chatModel string, cateJson []byte, exCommand string) (categoryId int, usage int, err error) {
	var quesList []*QuesStruct
	if err = json.Unmarshal(cateJson, &quesList); err != nil {
		return 0, 0, fmt.Errorf("decoding candidate categories: %w", err)
	}
	if len(quesList) == 0 {
		// Nothing to choose from; skip the model call.
		return 0, 0, nil
	}
	// Re-encode so that only id and name end up in the prompt.
	quesByte, err := json.Marshal(quesList)
	if err != nil {
		return 0, 0, fmt.Errorf("encoding candidate categories: %w", err)
	}

	modelObj := doubao.NewDouBao(chatModel, key)
	// Prompt (Chinese): find the numeric id of the category in the JSON that
	// matches the goods name; reply with the number only; reply 0 if nothing
	// matches. exCommand, if set, goes before the final "reply 0" line.
	text := []string{
		"根据商品名称,从json中找到该商品对应的分类的数字id[goodsName]" + goodsName + "[/goodsName]",
		"-只返回数字,不要有其他文字和字符",
		"-JSON数据:" + string(quesByte),
	}
	if exCommand != "" {
		text = append(text, exCommand)
	}
	text = append(text, "-如果无法匹配,则返回0")

	// The callback strips newlines; presumably GetData applies it to the
	// model's raw reply (its implementation is not visible here).
	category, usage, err := modelObj.GetData(ctx, key, model.ChatMessageRoleUser, func(input string) (string, error) {
		return strings.ReplaceAll(input, "\n", ""), nil
	}, text...)
	if err != nil {
		return 0, usage, err
	}

	// The model is told to answer with digits only; tolerate stray
	// whitespace (spaces, "\r") around the number.
	categoryId, err = strconv.Atoi(strings.TrimSpace(category))
	if err != nil {
		return 0, usage, fmt.Errorf("未找到分类: %w", err)
	}
	// Reject ids the model invented instead of picking from the candidates.
	if categoryId != 0 && !quesListHasID(quesList, categoryId) {
		return 0, usage, fmt.Errorf("未找到分类: model returned id %d, which is not one of the candidates", categoryId)
	}
	return categoryId, usage, nil
}

// quesListHasID reports whether a non-nil entry of list has the given id.
func quesListHasID(list []*QuesStruct, id int) bool {
	for _, q := range list {
		if q != nil && q.Id == id {
			return true
		}
	}
	return false
}