83 lines
2.3 KiB
Go
83 lines
2.3 KiB
Go
package l_ai_category
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"gitea.cdlsxd.cn/self-tools/l_ai_category/doubao"
|
||
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
|
||
"strconv"
|
||
"strings"
|
||
)
|
||
|
||
// Types describing the category tree and the trimmed-down view of it that is
// embedded in the model prompt.
type (
	// Category is one node of the category tree. The field order is
	// significant: it fixes the key order of the JSON sent to the model.
	Category struct {
		Id    int    `json:"id"`    // category ID; this is what the model is asked to return
		Name  string `json:"name"`  // category name shown to the model
		Level int    `json:"level"` // NOTE(review): presumably the depth in the tree — confirm
		Pid   int    `json:"pid"`   // NOTE(review): presumably the parent category ID — confirm
	}

	// CategoryDic wraps Category by embedding it. The embedded fields are
	// promoted, so a CategoryDic marshals to exactly the same JSON object as
	// the Category it holds.
	CategoryDic struct {
		Category
	}

	// QuesStruct is the id/name subset of a category that goes into the
	// prompt. Decoding a Category JSON object into it drops the remaining
	// fields (level, pid).
	QuesStruct struct {
		Id   int    `json:"id"`
		Name string `json:"name"`
	}
)
|
||
|
||
func GetCategory(ctx context.Context, goodsName, key, chatModel string, catePath []map[int][]*CategoryDic) (cate3 int, err error) {
|
||
var (
|
||
usage1 int
|
||
usage2 int
|
||
usage3 int
|
||
)
|
||
defer func() {
|
||
fmt.Printf("第一次用量:%d,第二次用量:%d,第三次用量:%d,总共花费:%d \n", usage1, usage2, usage3, usage1+usage2+usage3)
|
||
}()
|
||
cate1Json, _ := json.Marshal(catePath[0][1])
|
||
cate1, usage1, err := ques(ctx, goodsName, key, chatModel, cate1Json, "")
|
||
if err != nil || cate1 == 0 {
|
||
return
|
||
}
|
||
cate2Json, _ := json.Marshal(catePath[1][cate1])
|
||
cate2, usage2, err := ques(ctx, goodsName, key, chatModel, cate2Json, "")
|
||
if err != nil || cate2 == 0 {
|
||
return
|
||
}
|
||
cate3Json, _ := json.Marshal(catePath[2][cate2])
|
||
cate3, usage3, err = ques(ctx, goodsName, key, chatModel, cate3Json, "匹配一个最相近的,或者商品名称里面包含的分类")
|
||
|
||
return
|
||
}
|
||
|
||
func ques(ctx context.Context, goodsName, key, chatModel string, cateJson []byte, exCommand string) (categoryId int, usage int, err error) {
|
||
var quesList []*QuesStruct
|
||
err = json.Unmarshal(cateJson, &quesList)
|
||
quesByte, _ := json.Marshal(quesList)
|
||
modelObj := doubao.NewDouBao(chatModel, key)
|
||
text := []string{
|
||
"根据商品名称,从json中找到该商品对应的分类的数字id[goodsName]" + goodsName + "[/goodsName]",
|
||
"-只返回数字,不要有其他文字和字符",
|
||
"-JSON数据:" + string(quesByte),
|
||
}
|
||
if exCommand != "" {
|
||
text = append(text, exCommand)
|
||
}
|
||
|
||
text = append(text, "-如果无法匹配,则返回0")
|
||
|
||
category, usage, err := modelObj.GetData(ctx, key, model.ChatMessageRoleUser, func(input string) (string, error) {
|
||
cleaned := strings.ReplaceAll(input, "\n", "")
|
||
return cleaned, nil
|
||
}, text...)
|
||
if err != nil {
|
||
return
|
||
}
|
||
categoryId, err = strconv.Atoi(category)
|
||
if err != nil {
|
||
err = fmt.Errorf("为找到分类")
|
||
categoryId = 0
|
||
}
|
||
return
|
||
}
|