106 lines
2.9 KiB
Go
106 lines
2.9 KiB
Go
package l_ai_category
|
||
|
||
import (
	"context"
	"encoding/json"
	"fmt"
	"strconv"
	"strings"

	"gitea.cdlsxd.cn/self-tools/l_ai_category/doubao"
	"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
)
|
||
|
||
type (
	// Category is a single node of the category tree as exchanged in JSON.
	Category struct {
		// Id is the category's identifier.
		Id int `json:"id"`
		// Name is the human-readable category name matched against goods names.
		Name string `json:"name"`
		// Level is presumably the node's depth in the tree — TODO confirm.
		Level int `json:"level"`
		// Pid is presumably the parent category's Id — TODO confirm.
		Pid int `json:"pid"`
	}

	// CategoryDic wraps Category by embedding. Because the embedded struct is
	// untagged, its fields are encoded flat (no nested "Category" object).
	CategoryDic struct {
		Category
	}

	// QuesStruct is the reduced id/name view of a category that ques forwards
	// to the chat model; decoding a CategoryDic into it drops level and pid.
	QuesStruct struct {
		Id   int    `json:"id"`
		Name string `json:"name"`
	}
)
|
||
|
||
func GetCategory(ctx context.Context, goodsName, key, chatModel string, catePath []map[int][]*CategoryDic) (cate3 int, err error) {
|
||
cate1Json, _ := json.Marshal(catePath[0][1])
|
||
cate1, err := ques(ctx, goodsName, key, chatModel, cate1Json)
|
||
if err != nil {
|
||
return
|
||
}
|
||
cate2Json, _ := json.Marshal(catePath[1][cate1])
|
||
cate2, err := ques(ctx, goodsName, key, chatModel, cate2Json)
|
||
if err != nil {
|
||
return
|
||
}
|
||
cate3Json, _ := json.Marshal(catePath[2][cate2])
|
||
cate3, err = ques(ctx, goodsName, key, chatModel, cate3Json)
|
||
if err != nil {
|
||
return
|
||
}
|
||
return
|
||
}
|
||
|
||
func ques(ctx context.Context, goodsName, key, chatModel string, cateJson []byte) (categoryId int, err error) {
|
||
var quesList []*QuesStruct
|
||
err = json.Unmarshal(cateJson, &quesList)
|
||
quesByte, _ := json.Marshal(quesList)
|
||
modelObj := doubao.NewDouBao(chatModel, key)
|
||
text := []string{
|
||
"根据商品名称,从json中找到该商品对应的分类[goodsName]" + goodsName + "[/goodsName]",
|
||
"-只返回分类名称对应的id,返回其他文字",
|
||
"-如果无法匹配,则返回数字0",
|
||
"-JSON数据:" + string(quesByte),
|
||
}
|
||
category, err := modelObj.GetData(ctx, key, model.ChatMessageRoleUser, func(input string) (string, error) {
|
||
return input, nil
|
||
}, text...)
|
||
if err != nil {
|
||
return
|
||
}
|
||
categoryId, err = strconv.Atoi(category)
|
||
if err != nil {
|
||
err = fmt.Errorf("为找到分类")
|
||
categoryId = 0
|
||
}
|
||
return
|
||
}
|
||
|
||
func Test(ctx context.Context, key, chatModel string) (category string, err error) {
|
||
|
||
modelObj := doubao.NewDouBao(chatModel, key)
|
||
text := []string{
|
||
"给我base_json里休闲食品对应的id",
|
||
}
|
||
category, err = modelObj.GetData(ctx, key, model.ChatMessageRoleUser, func(input string) (string, error) {
|
||
return input, nil
|
||
}, text...)
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
return
|
||
}
|
||
|
||
func FlushCate(ctx context.Context, key, chatModel string, cateJson string) (res string, err error) {
|
||
|
||
modelObj := doubao.NewDouBao(chatModel, key)
|
||
text := []string{
|
||
"你是一个之智能非关系型数据库",
|
||
"你的任务是将提供的 JSON 作为base_json进行存放,之后所有提问提到的base_json都是指这个JSON。",
|
||
"以下是base_json:",
|
||
fmt.Sprintf("<base_json>%s</base_json>", cateJson),
|
||
"请将上述 JSON 作为base_json 进行存储。之后有提问提到base_json 时,要明确知晓是指此JSON。",
|
||
}
|
||
res, err = modelObj.GetData(ctx, key, model.ChatMessageRoleAssistant, func(input string) (string, error) {
|
||
return input, nil
|
||
}, text...)
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
return
|
||
}
|