This commit is contained in:
parent
3296b47211
commit
0c5ccfa73c
77
ai.go
77
ai.go
|
@ -7,6 +7,7 @@ import (
|
||||||
"gitea.cdlsxd.cn/self-tools/l_ai_category/doubao"
|
"gitea.cdlsxd.cn/self-tools/l_ai_category/doubao"
|
||||||
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
|
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Category struct {
|
type Category struct {
|
||||||
|
@ -25,37 +26,49 @@ type QuesStruct struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetCategory(ctx context.Context, goodsName, key, chatModel string, catePath []map[int][]*CategoryDic) (cate3 int, err error) {
|
func GetCategory(ctx context.Context, goodsName, key, chatModel string, catePath []map[int][]*CategoryDic) (cate3 int, err error) {
|
||||||
|
var (
|
||||||
|
usage1 int
|
||||||
|
usage2 int
|
||||||
|
usage3 int
|
||||||
|
)
|
||||||
|
defer func() {
|
||||||
|
fmt.Printf("第一次用量:%d,第二次用量:%d,第三次用量:%d,总共花费:%d \n", usage1, usage2, usage3, usage1+usage2+usage3)
|
||||||
|
}()
|
||||||
cate1Json, _ := json.Marshal(catePath[0][1])
|
cate1Json, _ := json.Marshal(catePath[0][1])
|
||||||
cate1, err := ques(ctx, goodsName, key, chatModel, cate1Json)
|
cate1, usage1, err := ques(ctx, goodsName, key, chatModel, cate1Json, "")
|
||||||
if err != nil {
|
if err != nil || cate1 == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
cate2Json, _ := json.Marshal(catePath[1][cate1])
|
cate2Json, _ := json.Marshal(catePath[1][cate1])
|
||||||
cate2, err := ques(ctx, goodsName, key, chatModel, cate2Json)
|
cate2, usage2, err := ques(ctx, goodsName, key, chatModel, cate2Json, "")
|
||||||
if err != nil {
|
if err != nil || cate2 == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
cate3Json, _ := json.Marshal(catePath[2][cate2])
|
cate3Json, _ := json.Marshal(catePath[2][cate2])
|
||||||
cate3, err = ques(ctx, goodsName, key, chatModel, cate3Json)
|
cate3, usage3, err = ques(ctx, goodsName, key, chatModel, cate3Json, "匹配一个最相近的,或者商品名称里面包含的分类")
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func ques(ctx context.Context, goodsName, key, chatModel string, cateJson []byte) (categoryId int, err error) {
|
func ques(ctx context.Context, goodsName, key, chatModel string, cateJson []byte, exCommand string) (categoryId int, usage int, err error) {
|
||||||
var quesList []*QuesStruct
|
var quesList []*QuesStruct
|
||||||
err = json.Unmarshal(cateJson, &quesList)
|
err = json.Unmarshal(cateJson, &quesList)
|
||||||
quesByte, _ := json.Marshal(quesList)
|
quesByte, _ := json.Marshal(quesList)
|
||||||
modelObj := doubao.NewDouBao(chatModel, key)
|
modelObj := doubao.NewDouBao(chatModel, key)
|
||||||
text := []string{
|
text := []string{
|
||||||
"根据商品名称,从json中找到该商品对应的分类[goodsName]" + goodsName + "[/goodsName]",
|
"根据商品名称,从json中找到该商品对应的分类的数字id[goodsName]" + goodsName + "[/goodsName]",
|
||||||
"-只返回分类名称对应的id,返回其他文字",
|
"-只返回数字,不要有其他文字和字符",
|
||||||
"-如果无法匹配,则返回数字0",
|
|
||||||
"-JSON数据:" + string(quesByte),
|
"-JSON数据:" + string(quesByte),
|
||||||
}
|
}
|
||||||
category, err := modelObj.GetData(ctx, key, model.ChatMessageRoleUser, func(input string) (string, error) {
|
if exCommand != "" {
|
||||||
return input, nil
|
text = append(text, exCommand)
|
||||||
|
}
|
||||||
|
|
||||||
|
text = append(text, "-如果无法匹配,则返回0")
|
||||||
|
|
||||||
|
category, usage, err := modelObj.GetData(ctx, key, model.ChatMessageRoleUser, func(input string) (string, error) {
|
||||||
|
cleaned := strings.ReplaceAll(input, "\n", "")
|
||||||
|
return cleaned, nil
|
||||||
}, text...)
|
}, text...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
@ -67,39 +80,3 @@ func ques(ctx context.Context, goodsName, key, chatModel string, cateJson []byte
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test(ctx context.Context, key, chatModel string) (category string, err error) {
|
|
||||||
|
|
||||||
modelObj := doubao.NewDouBao(chatModel, key)
|
|
||||||
text := []string{
|
|
||||||
"给我base_json里休闲食品对应的id",
|
|
||||||
}
|
|
||||||
category, err = modelObj.GetData(ctx, key, model.ChatMessageRoleUser, func(input string) (string, error) {
|
|
||||||
return input, nil
|
|
||||||
}, text...)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func FlushCate(ctx context.Context, key, chatModel string, cateJson string) (res string, err error) {
|
|
||||||
|
|
||||||
modelObj := doubao.NewDouBao(chatModel, key)
|
|
||||||
text := []string{
|
|
||||||
"你是一个之智能非关系型数据库",
|
|
||||||
"你的任务是将提供的 JSON 作为base_json进行存放,之后所有提问提到的base_json都是指这个JSON。",
|
|
||||||
"以下是base_json:",
|
|
||||||
fmt.Sprintf("<base_json>%s</base_json>", cateJson),
|
|
||||||
"请将上述 JSON 作为base_json 进行存储。之后有提问提到base_json 时,要明确知晓是指此JSON。",
|
|
||||||
}
|
|
||||||
res, err = modelObj.GetData(ctx, key, model.ChatMessageRoleAssistant, func(input string) (string, error) {
|
|
||||||
return input, nil
|
|
||||||
}, text...)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
180
ai_test.go
180
ai_test.go
|
@ -4,84 +4,95 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"gitea.cdlsxd.cn/self-tools/l_ai_category/utils_gorm"
|
"gitea.cdlsxd.cn/self-tools/l_ai_category/utils_gorm"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
dns = "root:lansexiongdi6,@tcp(47.97.27.195:3306)/report?parseTime=True&loc=Local"
|
dns = "developer:Lsxd@2024@tcp(lsxdpolar.rwlb.rds.aliyuncs.com:3306)/physicalgoods?parseTime=True&loc=Local"
|
||||||
driver = "mysql"
|
driver = "mysql"
|
||||||
table = "goods_category"
|
table = "category"
|
||||||
modelType = "deepseek-v3-250324"
|
modelType = "deepseek-r1-distill-qwen-32b-250120"
|
||||||
|
key = "03320e58-6a0b-4061-a22b-902039f2190d"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestCategory(t *testing.T) {
|
func TestCategory(t *testing.T) {
|
||||||
path := getPath()
|
path := getPath()
|
||||||
res, err := GetCategory(context.Background(), "MM 便携装湿厕纸 40包*10片@默认", "03320e58-6a0b-4061-a22b-902039f2190d", modelType, path)
|
res, err := GetCategory(context.Background(), "半亩花田倍润身体乳(失重玫瑰)250ml@默认", "03320e58-6a0b-4061-a22b-902039f2190d", modelType, path)
|
||||||
t.Log(res, err)
|
t.Log(res, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRem(t *testing.T) {
|
func TestCategoryGoodsSet(t *testing.T) {
|
||||||
|
var wg sync.WaitGroup
|
||||||
res, err := Test(context.Background(), "914ccf1d-c002-4fad-a431-f291f5e0d2ad", modelType)
|
|
||||||
t.Log(res, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
//func TestFlush(t *testing.T) {
|
|
||||||
// catString := string(CategoryToJsonByte())
|
|
||||||
// res, err := FlushCate(context.Background(), "914ccf1d-c002-4fad-a431-f291f5e0d2ad", modelType, catString)
|
|
||||||
// t.Log(res, err)
|
|
||||||
//}
|
|
||||||
//
|
|
||||||
//func TestGetCateGoryJson(t *testing.T) {
|
|
||||||
// cateJsonByte := CategoryToJsonByte()
|
|
||||||
//
|
|
||||||
// t.Log(string(cateJsonByte))
|
|
||||||
//}
|
|
||||||
|
|
||||||
//func CategoryToJsonByte() (cateJsonByte []byte) {
|
|
||||||
//
|
|
||||||
// cateLevelPath := getPath()
|
|
||||||
// cate := cateIteration(cateLevelPath[0], cateIteration(cateLevelPath[1], cateLevelPath[2]))[0]
|
|
||||||
// cateJsonByte, _ = json.Marshal(cate)
|
|
||||||
// return
|
|
||||||
//}
|
|
||||||
|
|
||||||
func getPath() []map[int][]*CategoryDic {
|
|
||||||
cates := cateAll()
|
|
||||||
var (
|
|
||||||
cateLevelPath = make([]map[int][]*CategoryDic, 3)
|
|
||||||
)
|
|
||||||
for _, v := range cates {
|
|
||||||
if cateLevelPath[v.Level-1] == nil {
|
|
||||||
cateLevelPath[v.Level-1] = make(map[int][]*CategoryDic)
|
|
||||||
}
|
|
||||||
if _, ex := cateLevelPath[v.Level-1][v.Pid]; !ex {
|
|
||||||
cateLevelPath[v.Level-1][v.Pid] = make([]*CategoryDic, 0)
|
|
||||||
}
|
|
||||||
cateLevelPath[v.Level-1][v.Pid] = append(cateLevelPath[v.Level-1][v.Pid], &CategoryDic{Category: v})
|
|
||||||
}
|
|
||||||
return cateLevelPath
|
|
||||||
}
|
|
||||||
|
|
||||||
//func cateIteration(parent map[int][]*CategoryDic, child map[int][]*CategoryDic) map[int][]*CategoryDic {
|
|
||||||
// for _, v := range parent {
|
|
||||||
// for _, vv := range v {
|
|
||||||
// vv.Children = child[vv.Id]
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// return parent
|
|
||||||
//}
|
|
||||||
|
|
||||||
func cateAll() (data []Category) {
|
|
||||||
var (
|
|
||||||
row Category
|
|
||||||
)
|
|
||||||
db, err, clean := utils_gorm.DB(driver, dns)
|
|
||||||
defer clean()
|
defer clean()
|
||||||
if err != nil {
|
path := getPath()
|
||||||
panic(err)
|
goods := goodsAll()
|
||||||
|
updateChan := make(chan *UpdateLog, len(goods))
|
||||||
|
start := time.Now().Unix()
|
||||||
|
wg.Add(len(goods))
|
||||||
|
for _, v := range goods {
|
||||||
|
go func(updateChan chan *UpdateLog) {
|
||||||
|
defer wg.Done()
|
||||||
|
res, err := GetCategory(context.Background(), v.Title, key, modelType, path)
|
||||||
|
|
||||||
|
if err != nil || res == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
updateChan <- &UpdateLog{
|
||||||
|
GoodsId: v.Id,
|
||||||
|
Title: v.Title,
|
||||||
|
CateId: res,
|
||||||
|
}
|
||||||
|
}(updateChan)
|
||||||
}
|
}
|
||||||
rows, _ := db.Raw(fmt.Sprintf("SELECT id,name,level,pid FROM `%s` where `type`=2", table)).Rows()
|
wg.Wait()
|
||||||
|
close(updateChan)
|
||||||
|
end := time.Now().Unix()
|
||||||
|
fmt.Printf("耗时:%d秒", end-start)
|
||||||
|
updateSlice := make([]*UpdateLog, len(updateChan))
|
||||||
|
for v := range updateChan {
|
||||||
|
updateSlice = append(updateSlice, v)
|
||||||
|
}
|
||||||
|
for _, v := range updateSlice {
|
||||||
|
if v == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
updateGoodsCate(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Log(updateSlice)
|
||||||
|
}
|
||||||
|
|
||||||
|
func updateGoodsCate(data *UpdateLog) {
|
||||||
|
db.Table("goods_category_relation").Where("goods_id = ?", data.GoodsId).Updates(map[string]interface{}{"category_id": data.CateId})
|
||||||
|
db.Table("goods").Where("id = ?", data.GoodsId).Updates(map[string]interface{}{"brand_category_check": 3})
|
||||||
|
}
|
||||||
|
|
||||||
|
type UpdateLog struct {
|
||||||
|
GoodsId int `json:"goods_id"`
|
||||||
|
Title string `json:"title"`
|
||||||
|
CateId int `json:"cate_id"`
|
||||||
|
}
|
||||||
|
type CategoryPhy struct {
|
||||||
|
Id int `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Level int `json:"level"`
|
||||||
|
ParentId int `json:"parent_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Goods struct {
|
||||||
|
Id int `json:"id"`
|
||||||
|
Title string `json:"title"`
|
||||||
|
BrandCateGoryCheck int `json:"brand_category_check"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func goodsAll() (data []Goods) {
|
||||||
|
var (
|
||||||
|
row Goods
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
rows, _ := db.Raw("SELECT id,title FROM `goods` where brand_category_check=4").Rows()
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
err = db.ScanRows(rows, &row)
|
err = db.ScanRows(rows, &row)
|
||||||
|
@ -92,3 +103,44 @@ func cateAll() (data []Category) {
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getPath() []map[int][]*CategoryDic {
|
||||||
|
cates := cateAll()
|
||||||
|
var (
|
||||||
|
cateLevelPath = make([]map[int][]*CategoryDic, 3)
|
||||||
|
)
|
||||||
|
for _, v := range cates {
|
||||||
|
if cateLevelPath[v.Level-1] == nil {
|
||||||
|
cateLevelPath[v.Level-1] = make(map[int][]*CategoryDic)
|
||||||
|
}
|
||||||
|
if _, ex := cateLevelPath[v.Level-1][v.ParentId]; !ex {
|
||||||
|
cateLevelPath[v.Level-1][v.ParentId] = make([]*CategoryDic, 0)
|
||||||
|
}
|
||||||
|
cateLevelPath[v.Level-1][v.ParentId] = append(cateLevelPath[v.Level-1][v.ParentId], &CategoryDic{
|
||||||
|
Category{
|
||||||
|
Id: v.Id,
|
||||||
|
Name: v.Name,
|
||||||
|
Level: v.Level,
|
||||||
|
}})
|
||||||
|
}
|
||||||
|
return cateLevelPath
|
||||||
|
}
|
||||||
|
|
||||||
|
func cateAll() (data []CategoryPhy) {
|
||||||
|
var (
|
||||||
|
row CategoryPhy
|
||||||
|
)
|
||||||
|
|
||||||
|
rows, _ := db.Raw(fmt.Sprintf("SELECT id,name,level,parent_id FROM `%s`", table)).Rows()
|
||||||
|
defer rows.Close()
|
||||||
|
for rows.Next() {
|
||||||
|
err := db.ScanRows(rows, &row)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
data = append(data, row)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var db, _, clean = utils_gorm.DB(driver, dns)
|
||||||
|
|
|
@ -2,7 +2,6 @@ package doubao
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
"github.com/volcengine/volcengine-go-sdk/service/arkruntime"
|
"github.com/volcengine/volcengine-go-sdk/service/arkruntime"
|
||||||
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
|
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
|
||||||
"github.com/volcengine/volcengine-go-sdk/volcengine"
|
"github.com/volcengine/volcengine-go-sdk/volcengine"
|
||||||
|
@ -21,7 +20,7 @@ func NewDouBao(modelName string, key string) *DouBao {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *DouBao) GetData(ctx context.Context, key string, role string, respHandle func(input string) (string, error), text ...string) (string, error) {
|
func (o *DouBao) GetData(ctx context.Context, key string, role string, respHandle func(input string) (string, error), text ...string) (string, int, error) {
|
||||||
var Message = make([]*model.ChatCompletionMessage, len(text))
|
var Message = make([]*model.ChatCompletionMessage, len(text))
|
||||||
|
|
||||||
client := arkruntime.NewClientWithApiKey(
|
client := arkruntime.NewClientWithApiKey(
|
||||||
|
@ -46,10 +45,10 @@ func (o *DouBao) GetData(ctx context.Context, key string, role string, respHandl
|
||||||
|
|
||||||
resp, err := client.CreateChatCompletion(ctx, req)
|
resp, err := client.CreateChatCompletion(ctx, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", 0, err
|
||||||
}
|
}
|
||||||
fmt.Printf("用量:%d", resp.Usage.TotalTokens)
|
|
||||||
result, err := respHandle(*resp.Choices[0].Message.Content.StringValue)
|
result, err := respHandle(*resp.Choices[0].Message.Content.StringValue)
|
||||||
|
|
||||||
return result, err
|
return result, resp.Usage.TotalTokens, err
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue