From 8a6645c33ce90622a966fbcc797063627e488e96 Mon Sep 17 00:00:00 2001 From: renzhiyuan <465386466@qq.com> Date: Thu, 27 Mar 2025 22:43:56 +0800 Subject: [PATCH] 1 --- ai.go | 92 +++++++++++++++++++++++++++++++++++++++++++----- ai_test.go | 70 +++++++++++++++++++----------------- doubao/doubao.go | 6 ++-- 3 files changed, 125 insertions(+), 43 deletions(-) diff --git a/ai.go b/ai.go index abb7d50..6df5bab 100644 --- a/ai.go +++ b/ai.go @@ -2,21 +2,59 @@ package l_ai_category import ( "context" + "encoding/json" "fmt" "gitea.cdlsxd.cn/self-tools/l_ai_category/doubao" + "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model" "strconv" ) -func GetCategory(ctx context.Context, goodsName, key, model string, cateJson string) (categoryId int, err error) { +type Category struct { + Id int `json:"id"` + Name string `json:"name"` + Level int `json:"level"` + Pid int `json:"pid"` +} - modelObj := doubao.NewDouBao(model, key) - text := []string{ - "根据商品名称,从json中找到该商品对应的第三级分类[QUESTION]" + goodsName + "[/QUESTION]", - "-只需要返回类型的名称对应的id,不用返回其他任何文字", - "-如果无法匹配,则返回数字0", - "-以下是 JSON 数据:" + cateJson, +type CategoryDic struct { + Category +} +type QuesStruct struct { + Id int `json:"id"` + Name string `json:"name"` +} + +func GetCategory(ctx context.Context, goodsName, key, chatModel string, catePath []map[int][]*CategoryDic) (cate3 int, err error) { + cate1Json, _ := json.Marshal(catePath[0][0]) + cate1, err := ques(ctx, goodsName, key, chatModel, cate1Json) + if err != nil { + return } - category, err := modelObj.GetData(ctx, key, func(input string) (string, error) { + cate2Json, _ := json.Marshal(catePath[1][cate1]) + cate2, err := ques(ctx, goodsName, key, chatModel, cate2Json) + if err != nil { + return + } + cate3Json, _ := json.Marshal(catePath[2][cate2]) + cate3, err = ques(ctx, goodsName, key, chatModel, cate3Json) + if err != nil { + return + } + return +} + +func ques(ctx context.Context, goodsName, key, chatModel string, cateJson []byte) (categoryId int, err error) { + var quesList []*QuesStruct + err = json.Unmarshal(cateJson, &quesList) + quesByte, _ := json.Marshal(quesList) + modelObj := doubao.NewDouBao(chatModel, key) + text := []string{ + "根据商品名称,从json中找到该商品对应的分类[goodsName]" + goodsName + "[/goodsName]", + "-只返回分类名称对应的id,返回其他文字", + "-如果无法匹配,则返回数字0", + "-JSON数据:" + string(quesByte), + } + category, err := modelObj.GetData(ctx, key, model.ChatMessageRoleUser, func(input string) (string, error) { return input, nil }, text...) if err != nil { @@ -24,8 +62,44 @@ func GetCategory(ctx context.Context, goodsName, key, model string, cateJson str } categoryId, err = strconv.Atoi(category) if err != nil { - err = fmt.Errorf("未找到商品分类") + err = fmt.Errorf("为找到分类") categoryId = 0 } return } + +func Test(ctx context.Context, key, chatModel string) (category string, err error) { + + modelObj := doubao.NewDouBao(chatModel, key) + text := []string{ + "给我base_json里休闲食品对应的id", + } + category, err = modelObj.GetData(ctx, key, model.ChatMessageRoleUser, func(input string) (string, error) { + return input, nil + }, text...) + if err != nil { + return + } + + return +} + +func FlushCate(ctx context.Context, key, chatModel string, cateJson string) (res string, err error) { + + modelObj := doubao.NewDouBao(chatModel, key) + text := []string{ + "你是一个之智能非关系型数据库", + "你的任务是将提供的 JSON 作为base_json进行存放,之后所有提问提到的base_json都是指这个JSON。", + "以下是base_json:", + fmt.Sprintf("%s", cateJson), + "请将上述 JSON 作为base_json 进行存储。之后有提问提到base_json 时,要明确知晓是指此JSON。", + } + res, err = modelObj.GetData(ctx, key, model.ChatMessageRoleAssistant, func(input string) (string, error) { + return input, nil + }, text...) + if err != nil { + return + } + + return +} diff --git a/ai_test.go b/ai_test.go index 72f29f3..6985068 100644 --- a/ai_test.go +++ b/ai_test.go @@ -2,43 +2,51 @@ package l_ai_category import ( "context" - "encoding/json" "fmt" "gitea.cdlsxd.cn/self-tools/l_ai_category/utils_gorm" "testing" ) const ( - dns = "" - driver = "mysql" - table = "goods_category" + dns = "root:lansexiongdi6,@tcp(47.97.27.195:3306)/report?parseTime=True&loc=Local" + driver = "mysql" + table = "goods_category" + modelType = "deepseek-v3-250324" ) func TestCategory(t *testing.T) { - catString := string(CategoryToJsonByte()) - res, err := GetCategory(context.Background(), "思嘉思达 雅格电火锅 3.5L电火锅 SKD-G0020", "03320e58-6a0b-4061-a22b-902039f2190d", "deepseek-v3-250324", catString) + path := getPath() + res, err := GetCategory(context.Background(), "倍轻松头部按摩器\niDream 3(尊享版", "914ccf1d-c002-4fad-a431-f291f5e0d2ad", modelType, path) t.Log(res, err) } -type Category struct { - Id int `json:"id"` - Name string `json:"name"` - Level int `json:"level"` - Pid int `json:"pid"` +func TestRem(t *testing.T) { + + res, err := Test(context.Background(), "914ccf1d-c002-4fad-a431-f291f5e0d2ad", modelType) + t.Log(res, err) } -type CategoryDic struct { - Category - Children []*CategoryDic `json:"children"` -} +//func TestFlush(t *testing.T) { +// catString := string(CategoryToJsonByte()) +// res, err := FlushCate(context.Background(), "914ccf1d-c002-4fad-a431-f291f5e0d2ad", modelType, catString) +// t.Log(res, err) +//} +// +//func TestGetCateGoryJson(t *testing.T) { +// cateJsonByte := CategoryToJsonByte() +// +// t.Log(string(cateJsonByte)) +//} -func TestGetCateGoryJson(t *testing.T) { - cateJsonByte := CategoryToJsonByte() +//func CategoryToJsonByte() (cateJsonByte []byte) { +// +// cateLevelPath := getPath() +// cate := cateIteration(cateLevelPath[0], cateIteration(cateLevelPath[1], cateLevelPath[2]))[0] +// cateJsonByte, _ = json.Marshal(cate) +// return +//} - t.Log(string(cateJsonByte)) -} - -func CategoryToJsonByte() (cateJsonByte []byte) { +func getPath() []map[int][]*CategoryDic { cates := cateAll() var ( cateLevelPath = make([]map[int][]*CategoryDic, 3) @@ -52,19 +60,17 @@ func CategoryToJsonByte() (cateJsonByte []byte) { } cateLevelPath[v.Level-1][v.Pid] = append(cateLevelPath[v.Level-1][v.Pid], &CategoryDic{Category: v}) } - cate := cateIteration(cateLevelPath[0], cateIteration(cateLevelPath[1], cateLevelPath[2]))[0] - cateJsonByte, _ = json.Marshal(cate) - return + return cateLevelPath } -func cateIteration(parent map[int][]*CategoryDic, child map[int][]*CategoryDic) map[int][]*CategoryDic { - for _, v := range parent { - for _, vv := range v { - vv.Children = child[vv.Id] - } - } - return parent -} +//func cateIteration(parent map[int][]*CategoryDic, child map[int][]*CategoryDic) map[int][]*CategoryDic { +// for _, v := range parent { +// for _, vv := range v { +// vv.Children = child[vv.Id] +// } +// } +// return parent +//} func cateAll() (data []Category) { var ( diff --git a/doubao/doubao.go b/doubao/doubao.go index fb82f25..12002e9 100644 --- a/doubao/doubao.go +++ b/doubao/doubao.go @@ -2,6 +2,7 @@ package doubao import ( "context" + "fmt" "github.com/volcengine/volcengine-go-sdk/service/arkruntime" "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model" "github.com/volcengine/volcengine-go-sdk/volcengine" @@ -20,7 +21,7 @@ func NewDouBao(modelName string, key string) *DouBao { } } -func (o *DouBao) GetData(ctx context.Context, key string, respHandle func(input string) (string, error), text ...string) (string, error) { +func (o *DouBao) GetData(ctx context.Context, key string, role string, respHandle func(input string) (string, error), text ...string) (string, error) { var Message = make([]*model.ChatCompletionMessage, len(text)) client := arkruntime.NewClientWithApiKey( @@ -32,7 +33,7 @@ func (o *DouBao) GetData(ctx context.Context, key string, respHandle func(input ) for k, v := range text { Message[k] = &model.ChatCompletionMessage{ - Role: model.ChatMessageRoleSystem, + Role: role, Content: &model.ChatCompletionMessageContent{ StringValue: volcengine.String(v), }, @@ -47,6 +48,7 @@ func (o *DouBao) GetData(ctx context.Context, key string, respHandle func(input if err != nil { return "", err } + fmt.Printf("用量:%d", resp.Usage.TotalTokens) result, err := respHandle(*resp.Choices[0].Message.Content.StringValue) return result, err