This commit is contained in:
renzhiyuan 2025-03-27 22:43:56 +08:00
parent b5bb2fad66
commit 8a6645c33c
3 changed files with 125 additions and 43 deletions

92
ai.go
View File

@ -2,21 +2,59 @@ package l_ai_category
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"gitea.cdlsxd.cn/self-tools/l_ai_category/doubao" "gitea.cdlsxd.cn/self-tools/l_ai_category/doubao"
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
"strconv" "strconv"
) )
func GetCategory(ctx context.Context, goodsName, key, model string, cateJson string) (categoryId int, err error) { type Category struct {
Id int `json:"id"`
Name string `json:"name"`
Level int `json:"level"`
Pid int `json:"pid"`
}
modelObj := doubao.NewDouBao(model, key) type CategoryDic struct {
text := []string{ Category
"根据商品名称,从json中找到该商品对应的第三级分类[QUESTION]" + goodsName + "[/QUESTION]", }
"-只需要返回类型的名称对应的id,不用返回其他任何文字", type QuesStruct struct {
"-如果无法匹配则返回数字0", Id int `json:"id"`
"-以下是 JSON 数据:" + cateJson, Name string `json:"name"`
}
func GetCategory(ctx context.Context, goodsName, key, chatModel string, catePath []map[int][]*CategoryDic) (cate3 int, err error) {
cate1Json, _ := json.Marshal(catePath[0][0])
cate1, err := ques(ctx, goodsName, key, chatModel, cate1Json)
if err != nil {
return
} }
category, err := modelObj.GetData(ctx, key, func(input string) (string, error) { cate2Json, _ := json.Marshal(catePath[1][cate1])
cate2, err := ques(ctx, goodsName, key, chatModel, cate2Json)
if err != nil {
return
}
cate3Json, _ := json.Marshal(catePath[2][cate2])
cate3, err = ques(ctx, goodsName, key, chatModel, cate3Json)
if err != nil {
return
}
return
}
func ques(ctx context.Context, goodsName, key, chatModel string, cateJson []byte) (categoryId int, err error) {
var quesList []*QuesStruct
err = json.Unmarshal(cateJson, &quesList)
quesByte, _ := json.Marshal(quesList)
modelObj := doubao.NewDouBao(chatModel, key)
text := []string{
"根据商品名称,从json中找到该商品对应的分类[goodsName]" + goodsName + "[/goodsName]",
"-只返回分类名称对应的id,返回其他文字",
"-如果无法匹配则返回数字0",
"-JSON数据" + string(quesByte),
}
category, err := modelObj.GetData(ctx, key, model.ChatMessageRoleUser, func(input string) (string, error) {
return input, nil return input, nil
}, text...) }, text...)
if err != nil { if err != nil {
@ -24,8 +62,44 @@ func GetCategory(ctx context.Context, goodsName, key, model string, cateJson str
} }
categoryId, err = strconv.Atoi(category) categoryId, err = strconv.Atoi(category)
if err != nil { if err != nil {
err = fmt.Errorf("未找到商品分类") err = fmt.Errorf("为找到分类")
categoryId = 0 categoryId = 0
} }
return return
} }
func Test(ctx context.Context, key, chatModel string) (category string, err error) {
modelObj := doubao.NewDouBao(chatModel, key)
text := []string{
"给我base_json里休闲食品对应的id",
}
category, err = modelObj.GetData(ctx, key, model.ChatMessageRoleUser, func(input string) (string, error) {
return input, nil
}, text...)
if err != nil {
return
}
return
}
func FlushCate(ctx context.Context, key, chatModel string, cateJson string) (res string, err error) {
modelObj := doubao.NewDouBao(chatModel, key)
text := []string{
"你是一个之智能非关系型数据库",
"你的任务是将提供的 JSON 作为base_json进行存放之后所有提问提到的base_json都是指这个JSON。",
"以下是base_json",
fmt.Sprintf("<base_json>%s</base_json>", cateJson),
"请将上述 JSON 作为base_json 进行存储。之后有提问提到base_json 时要明确知晓是指此JSON。",
}
res, err = modelObj.GetData(ctx, key, model.ChatMessageRoleAssistant, func(input string) (string, error) {
return input, nil
}, text...)
if err != nil {
return
}
return
}

View File

@ -2,43 +2,51 @@ package l_ai_category
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"gitea.cdlsxd.cn/self-tools/l_ai_category/utils_gorm" "gitea.cdlsxd.cn/self-tools/l_ai_category/utils_gorm"
"testing" "testing"
) )
const ( const (
dns = "" dns = "root:lansexiongdi6,@tcp(47.97.27.195:3306)/report?parseTime=True&loc=Local"
driver = "mysql" driver = "mysql"
table = "goods_category" table = "goods_category"
modelType = "deepseek-v3-250324"
) )
func TestCategory(t *testing.T) { func TestCategory(t *testing.T) {
catString := string(CategoryToJsonByte()) path := getPath()
res, err := GetCategory(context.Background(), "思嘉思达 雅格电火锅 3.5L电火锅 SKD-G0020", "03320e58-6a0b-4061-a22b-902039f2190d", "deepseek-v3-250324", catString) res, err := GetCategory(context.Background(), "倍轻松头部按摩器\niDream 3尊享版", "914ccf1d-c002-4fad-a431-f291f5e0d2ad", modelType, path)
t.Log(res, err) t.Log(res, err)
} }
type Category struct { func TestRem(t *testing.T) {
Id int `json:"id"`
Name string `json:"name"` res, err := Test(context.Background(), "914ccf1d-c002-4fad-a431-f291f5e0d2ad", modelType)
Level int `json:"level"` t.Log(res, err)
Pid int `json:"pid"`
} }
type CategoryDic struct { //func TestFlush(t *testing.T) {
Category // catString := string(CategoryToJsonByte())
Children []*CategoryDic `json:"children"` // res, err := FlushCate(context.Background(), "914ccf1d-c002-4fad-a431-f291f5e0d2ad", modelType, catString)
} // t.Log(res, err)
//}
//
//func TestGetCateGoryJson(t *testing.T) {
// cateJsonByte := CategoryToJsonByte()
//
// t.Log(string(cateJsonByte))
//}
func TestGetCateGoryJson(t *testing.T) { //func CategoryToJsonByte() (cateJsonByte []byte) {
cateJsonByte := CategoryToJsonByte() //
// cateLevelPath := getPath()
// cate := cateIteration(cateLevelPath[0], cateIteration(cateLevelPath[1], cateLevelPath[2]))[0]
// cateJsonByte, _ = json.Marshal(cate)
// return
//}
t.Log(string(cateJsonByte)) func getPath() []map[int][]*CategoryDic {
}
func CategoryToJsonByte() (cateJsonByte []byte) {
cates := cateAll() cates := cateAll()
var ( var (
cateLevelPath = make([]map[int][]*CategoryDic, 3) cateLevelPath = make([]map[int][]*CategoryDic, 3)
@ -52,19 +60,17 @@ func CategoryToJsonByte() (cateJsonByte []byte) {
} }
cateLevelPath[v.Level-1][v.Pid] = append(cateLevelPath[v.Level-1][v.Pid], &CategoryDic{Category: v}) cateLevelPath[v.Level-1][v.Pid] = append(cateLevelPath[v.Level-1][v.Pid], &CategoryDic{Category: v})
} }
cate := cateIteration(cateLevelPath[0], cateIteration(cateLevelPath[1], cateLevelPath[2]))[0] return cateLevelPath
cateJsonByte, _ = json.Marshal(cate)
return
} }
func cateIteration(parent map[int][]*CategoryDic, child map[int][]*CategoryDic) map[int][]*CategoryDic { //func cateIteration(parent map[int][]*CategoryDic, child map[int][]*CategoryDic) map[int][]*CategoryDic {
for _, v := range parent { // for _, v := range parent {
for _, vv := range v { // for _, vv := range v {
vv.Children = child[vv.Id] // vv.Children = child[vv.Id]
} // }
} // }
return parent // return parent
} //}
func cateAll() (data []Category) { func cateAll() (data []Category) {
var ( var (

View File

@ -2,6 +2,7 @@ package doubao
import ( import (
"context" "context"
"fmt"
"github.com/volcengine/volcengine-go-sdk/service/arkruntime" "github.com/volcengine/volcengine-go-sdk/service/arkruntime"
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model" "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
"github.com/volcengine/volcengine-go-sdk/volcengine" "github.com/volcengine/volcengine-go-sdk/volcengine"
@ -20,7 +21,7 @@ func NewDouBao(modelName string, key string) *DouBao {
} }
} }
func (o *DouBao) GetData(ctx context.Context, key string, respHandle func(input string) (string, error), text ...string) (string, error) { func (o *DouBao) GetData(ctx context.Context, key string, role string, respHandle func(input string) (string, error), text ...string) (string, error) {
var Message = make([]*model.ChatCompletionMessage, len(text)) var Message = make([]*model.ChatCompletionMessage, len(text))
client := arkruntime.NewClientWithApiKey( client := arkruntime.NewClientWithApiKey(
@ -32,7 +33,7 @@ func (o *DouBao) GetData(ctx context.Context, key string, respHandle func(input
) )
for k, v := range text { for k, v := range text {
Message[k] = &model.ChatCompletionMessage{ Message[k] = &model.ChatCompletionMessage{
Role: model.ChatMessageRoleSystem, Role: role,
Content: &model.ChatCompletionMessageContent{ Content: &model.ChatCompletionMessageContent{
StringValue: volcengine.String(v), StringValue: volcengine.String(v),
}, },
@ -47,6 +48,7 @@ func (o *DouBao) GetData(ctx context.Context, key string, respHandle func(input
if err != nil { if err != nil {
return "", err return "", err
} }
fmt.Printf("用量:%d", resp.Usage.TotalTokens)
result, err := respHandle(*resp.Choices[0].Message.Content.StringValue) result, err := respHandle(*resp.Choices[0].Message.Content.StringValue)
return result, err return result, err