l_ai_category/ai.go

83 lines
2.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package l_ai_category
import (
"context"
"encoding/json"
"fmt"
"gitea.cdlsxd.cn/self-tools/l_ai_category/doubao"
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
"strconv"
"strings"
)
// Data shapes used by GetCategory: the full category records supplied by the
// caller, and the trimmed view that is embedded in the model prompt.
type (
	// Category is one category record as serialized to JSON.
	// NOTE(review): Pid presumably holds the parent category id — confirm.
	Category struct {
		Id    int    `json:"id"`
		Name  string `json:"name"`
		Level int    `json:"level"`
		Pid   int    `json:"pid"`
	}

	// CategoryDic wraps a Category. The field is embedded without a tag, so its
	// fields are promoted and it encodes to the same JSON object as Category.
	CategoryDic struct {
		Category
	}

	// QuesStruct is the reduced {id, name} view of a category that ques
	// decodes candidates into before writing them into the prompt.
	QuesStruct struct {
		Id   int    `json:"id"`
		Name string `json:"name"`
	}
)
// GetCategory classifies goodsName into a three-level category hierarchy by
// asking the chat model once per level and narrowing the candidates each time.
//
// catePath must hold one map per level:
//   - catePath[0][1] is the list of top-level candidates;
//   - catePath[i][parentID] is the list of candidates at level i under the id
//     chosen at level i-1.
//
// NOTE(review): the root key 1 comes from the original code. Confirm that it
// matches how callers build catePath[0].
//
// It returns the chosen level-3 category id. A zero id with a nil error means
// the model reported no match at some level. The usage returned by each model
// call is printed to stdout when the function returns. An error is returned
// if catePath has fewer than three levels or if any model round fails.
func GetCategory(ctx context.Context, goodsName, key, chatModel string, catePath []map[int][]*CategoryDic) (cate3 int, err error) {
	const levels = 3
	// Guard before indexing: a short catePath used to panic with index out of range.
	if len(catePath) < levels {
		return 0, fmt.Errorf("catePath must have %d levels, got %d", levels, len(catePath))
	}

	var usage [levels]int
	defer func() {
		fmt.Printf("第一次用量:%d,第二次用量:%d,第三次用量:%d,总共花费:%d \n", usage[0], usage[1], usage[2], usage[0]+usage[1]+usage[2])
	}()

	parent := 1 // key under which catePath[0] stores the top-level candidates
	for level := 0; level < levels; level++ {
		candidates, mErr := json.Marshal(catePath[level][parent])
		if mErr != nil {
			return 0, fmt.Errorf("encoding level %d candidates: %w", level+1, mErr)
		}

		exCommand := ""
		if level == levels-1 {
			// At the leaf level, ask the model for the closest match (or a
			// category whose name appears in the goods name) rather than 0.
			exCommand = "匹配一个最相近的,或者商品名称里面包含的分类"
		}

		var id int
		id, usage[level], err = ques(ctx, goodsName, key, chatModel, candidates, exCommand)
		if err != nil || id == 0 {
			// Stop descending: either the call failed or nothing matched.
			return 0, err
		}
		parent = id
	}
	return parent, nil
}
// ques asks the chat model to pick, for goodsName, one category id from the
// candidates in cateJson. cateJson is a JSON array whose elements carry "id"
// and "name"; any other fields are dropped before prompting. If exCommand is
// non-empty, it is appended to the prompt as an extra instruction.
//
// It returns the chosen id and the usage value returned by the model client's
// GetData. A zero id with a nil error means either that the model reported no
// match, or that there were no candidates (in which case the model is not
// called).
//
// An error is returned in any of these cases:
//   - cateJson is malformed;
//   - the model call fails;
//   - the reply is not an integer;
//   - the reply names an id that is not among the candidates.
func ques(ctx context.Context, goodsName, key, chatModel string, cateJson []byte, exCommand string) (categoryId int, usage int, err error) {
	var quesList []*QuesStruct
	if err = json.Unmarshal(cateJson, &quesList); err != nil {
		// Fail fast instead of prompting the model with a broken candidate list.
		return 0, 0, fmt.Errorf("decoding candidate categories: %w", err)
	}
	if len(quesList) == 0 {
		// Nothing to choose from (e.g. cateJson was "null" for an unknown
		// parent). The only valid answer is 0, so skip the model call.
		return 0, 0, nil
	}

	// Re-encode through QuesStruct so that only id and name reach the prompt.
	quesByte, err := json.Marshal(quesList)
	if err != nil {
		return 0, 0, fmt.Errorf("encoding candidate categories: %w", err)
	}

	modelObj := doubao.NewDouBao(chatModel, key)
	// The prompt asks for the numeric id of the matching category, with the
	// reply containing only the number, and 0 when nothing matches.
	text := []string{
		"根据商品名称,从json中找到该商品对应的分类的数字id[goodsName]" + goodsName + "[/goodsName]",
		"-只返回数字,不要有其他文字和字符",
		"-JSON数据" + string(quesByte),
	}
	if exCommand != "" {
		text = append(text, exCommand)
	}
	text = append(text, "-如果无法匹配则返回0")

	reply, usage, err := modelObj.GetData(ctx, key, model.ChatMessageRoleUser, func(input string) (string, error) {
		// Newline-stripping callback handed to GetData.
		// NOTE(review): presumably applied to the model reply — confirm.
		return strings.ReplaceAll(input, "\n", ""), nil
	}, text...)
	if err != nil {
		return 0, usage, err
	}

	// Tolerate stray spaces, tabs or \r around the number.
	reply = strings.TrimSpace(reply)
	id, convErr := strconv.Atoi(reply)
	if convErr != nil {
		return 0, usage, fmt.Errorf("未找到分类: model reply %q is not an id: %w", reply, convErr)
	}
	if id == 0 {
		return 0, usage, nil
	}
	// Reject ids the model invented: callers use the result as a map key or
	// as the final category, so it must be one of the offered candidates.
	for _, q := range quesList {
		if q != nil && q.Id == id {
			return id, usage, nil
		}
	}
	return 0, usage, fmt.Errorf("未找到分类: model reply %d is not one of the candidate ids", id)
}