l_ai_category/ai.go

106 lines
2.9 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package l_ai_category
import (
	"context"
	"encoding/json"
	"fmt"
	"strconv"
	"strings"

	"gitea.cdlsxd.cn/self-tools/l_ai_category/doubao"
	"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
)
type (
	// Category is a single node of the category tree.
	//
	// Level and Pid are not interpreted anywhere in this file; presumably
	// Level is the depth in the tree and Pid the parent's Id (TODO confirm).
	Category struct {
		Id    int    `json:"id"`
		Name  string `json:"name"`
		Level int    `json:"level"`
		Pid   int    `json:"pid"`
	}

	// CategoryDic embeds Category, so all Category fields are promoted and
	// are encoded flat (not nested) by encoding/json.
	CategoryDic struct {
		Category
	}

	// QuesStruct is the reduced id/name view of a category. ques decodes the
	// candidate JSON into it so that only these two fields are put into the
	// model prompt.
	QuesStruct struct {
		Id   int    `json:"id"`
		Name string `json:"name"`
	}
)
// GetCategory classifies goodsName against a three-level category tree by
// asking the chat model once per level, and returns the chosen third-level
// category id.
//
// catePath[level] maps a parent id to the candidate categories at that level.
// The first level's candidates are read from catePath[0][1]; the id chosen at
// each level becomes the parent key for the next level's lookup.
// NOTE(review): the root key 1 is hard-coded (as in the original code);
// confirm it matches how callers build catePath.
//
// Returns (0, nil) when the model reports no match (id 0) at some level or the
// chosen parent has no children at the next level — querying further would
// only send an empty candidate list to the model. Returns an error when
// catePath has fewer than three levels, when a candidate list cannot be
// JSON-encoded, or when ques fails.
func GetCategory(ctx context.Context, goodsName, key, chatModel string, catePath []map[int][]*CategoryDic) (cate3 int, err error) {
	const (
		levels  = 3
		rootKey = 1
	)
	// Guard against a short catePath instead of panicking on the index.
	if len(catePath) < levels {
		return 0, fmt.Errorf("catePath 需要 %d 层, 实际 %d 层", levels, len(catePath))
	}

	parent := rootKey
	for level := 0; level < levels; level++ {
		candidates := catePath[level][parent] // nil map / missing key yields nil
		if len(candidates) == 0 {
			return 0, nil
		}
		cateJSON, marshalErr := json.Marshal(candidates)
		if marshalErr != nil {
			return 0, fmt.Errorf("序列化第 %d 层分类: %w", level+1, marshalErr)
		}
		id, quesErr := ques(ctx, goodsName, key, chatModel, cateJSON)
		if quesErr != nil {
			return 0, quesErr
		}
		if id == 0 {
			// The prompt defines 0 as "no match"; stop rather than descend.
			return 0, nil
		}
		parent = id
	}
	return parent, nil
}
// ques asks the chat model which of the candidate categories in cateJson
// goodsName belongs to, and returns the chosen category id.
//
// cateJson is decoded as a JSON array of objects carrying "id" and "name".
// Any other fields are dropped before the list is embedded in the prompt. A
// JSON `null` decodes to an empty list.
//
// Returns (0, nil) when there are no candidates or the model answers 0, which
// the prompt defines as "no match". Returns an error when cateJson cannot be
// decoded or re-encoded, when the model call fails, when the answer is not an
// integer, or when the answer is not the id of one of the candidates.
func ques(ctx context.Context, goodsName, key, chatModel string, cateJson []byte) (categoryId int, err error) {
	var quesList []*QuesStruct
	if err = json.Unmarshal(cateJson, &quesList); err != nil {
		return 0, fmt.Errorf("解析分类列表: %w", err)
	}
	if len(quesList) == 0 {
		// Nothing to choose from; skip the model round-trip.
		return 0, nil
	}
	quesByte, err := json.Marshal(quesList)
	if err != nil {
		return 0, fmt.Errorf("序列化分类列表: %w", err)
	}

	modelObj := doubao.NewDouBao(chatModel, key)
	text := []string{
		"根据商品名称,从json中找到该商品对应的分类[goodsName]" + goodsName + "[/goodsName]",
		"-只返回分类名称对应的id,不要返回其他文字",
		"-如果无法匹配则返回数字0",
		"-JSON数据" + string(quesByte),
	}
	answer, err := modelObj.GetData(ctx, key, model.ChatMessageRoleUser, func(input string) (string, error) {
		return input, nil
	}, text...)
	if err != nil {
		return 0, err
	}

	// Models commonly wrap the bare number in whitespace/newlines.
	id, convErr := strconv.Atoi(strings.TrimSpace(answer))
	if convErr != nil {
		return 0, fmt.Errorf("未找到分类: 模型返回 %q: %w", answer, convErr)
	}
	if id == 0 {
		return 0, nil
	}
	// Reject ids the model invented that are not among the candidates.
	for _, q := range quesList {
		if q != nil && q.Id == id {
			return id, nil
		}
	}
	return 0, fmt.Errorf("未找到分类: 模型返回的id %d 不在候选列表中", id)
}
// Test sends a fixed probe question about base_json (the data passed to
// FlushCate) to the chat model and returns the raw answer together with any
// error from GetData.
//
// NOTE(review): this only works if the doubao client keeps conversation state
// between calls, which is not visible from this file — confirm.
func Test(ctx context.Context, key, chatModel string) (category string, err error) {
	// Pass-through callback handed to GetData; its role inside GetData is not
	// visible here.
	passThrough := func(input string) (string, error) {
		return input, nil
	}
	client := doubao.NewDouBao(chatModel, key)
	return client.GetData(ctx, key, model.ChatMessageRoleUser, passThrough, []string{"给我base_json里休闲食品对应的id"}...)
}
// FlushCate sends cateJson to the chat model wrapped in <base_json> tags,
// together with instructions to treat it as "base_json" for later questions.
// It returns the model's reply and any error from GetData.
//
// NOTE(review): two points to confirm.
//   - Unlike the other calls in this file, the messages are sent with the
//     assistant role rather than the user role.
//   - Whether anything is "stored" across calls depends on the doubao client
//     keeping conversation state, which is not visible here.
//
// The first prompt line contains a stray "之" (reads "一个之智能"); it is left
// byte-for-byte as is.
func FlushCate(ctx context.Context, key, chatModel string, cateJson string) (res string, err error) {
	wrapped := "<base_json>" + cateJson + "</base_json>"

	prompt := make([]string, 0, 5)
	prompt = append(prompt,
		"你是一个之智能非关系型数据库",
		"你的任务是将提供的 JSON 作为base_json进行存放之后所有提问提到的base_json都是指这个JSON。",
		"以下是base_json",
		wrapped,
		"请将上述 JSON 作为base_json 进行存储。之后有提问提到base_json 时要明确知晓是指此JSON。",
	)

	identity := func(input string) (string, error) {
		return input, nil
	}
	client := doubao.NewDouBao(chatModel, key)
	return client.GetData(ctx, key, model.ChatMessageRoleAssistant, identity, prompt...)
}