This commit is contained in:
renzhiyuan 2025-08-08 16:48:16 +08:00
parent 0c5ccfa73c
commit 2ec7c81a4d
3 changed files with 140 additions and 85 deletions

106
ai.go
View File

@ -4,17 +4,54 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"gitea.cdlsxd.cn/self-tools/l_ai_category/doubao" "gitea.cdlsxd.cn/self-tools/l_request"
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
"strconv" "strconv"
"strings" "strings"
) )
func GetCategoryWithSelf(ctx context.Context, goodsName string, address string) (cateId int, confidence float64, err error) {
cateId, name, confidence, err := getCategoryWithAi(ctx, goodsName, address)
fmt.Printf("商品:%s,所属分类:%s,置信度:%f\n", goodsName, name, confidence)
return
}
type CatRes struct {
Predictions []struct {
Category string `json:"category"`
Confidence float64 `json:"confidence"`
} `json:"predictions"`
Status string `json:"status"`
}
func getCategoryWithAi(ctx context.Context, goodsName string, address string) (cateId int, cateName string, confidence float64, err error) {
request := map[string]interface{}{
"product_names": goodsName,
}
req := l_request.Request{
Method: "POST",
Url: address,
Json: request,
}
res, err := req.Send()
if err != nil {
return
}
var resMap CatRes
err = json.Unmarshal(res.Content, &resMap)
if err != nil {
return
}
cateInfo := strings.Split(resMap.Predictions[0].Category, "_")
cateId, _ = strconv.Atoi(cateInfo[1])
return cateId, cateInfo[0], resMap.Predictions[0].Confidence, nil
}
type Category struct { type Category struct {
Id int `json:"id"` Id int `json:"id"`
Name string `json:"name"` Name string `json:"name"`
Level int `json:"level"` Level int `json:"level"`
Pid int `json:"pid"` Pid int `json:"pid"`
Pids map[int]int `json:"pid"`
} }
type CategoryDic struct { type CategoryDic struct {
@ -25,58 +62,7 @@ type QuesStruct struct {
Name string `json:"name"` Name string `json:"name"`
} }
func GetCategory(ctx context.Context, goodsName, key, chatModel string, catePath []map[int][]*CategoryDic) (cate3 int, err error) { type Req struct {
var ( Text string `json:"text"`
usage1 int Cate []string `json:"cate"`
usage2 int
usage3 int
)
defer func() {
fmt.Printf("第一次用量:%d,第二次用量:%d,第三次用量:%d,总共花费:%d \n", usage1, usage2, usage3, usage1+usage2+usage3)
}()
cate1Json, _ := json.Marshal(catePath[0][1])
cate1, usage1, err := ques(ctx, goodsName, key, chatModel, cate1Json, "")
if err != nil || cate1 == 0 {
return
}
cate2Json, _ := json.Marshal(catePath[1][cate1])
cate2, usage2, err := ques(ctx, goodsName, key, chatModel, cate2Json, "")
if err != nil || cate2 == 0 {
return
}
cate3Json, _ := json.Marshal(catePath[2][cate2])
cate3, usage3, err = ques(ctx, goodsName, key, chatModel, cate3Json, "匹配一个最相近的,或者商品名称里面包含的分类")
return
}
func ques(ctx context.Context, goodsName, key, chatModel string, cateJson []byte, exCommand string) (categoryId int, usage int, err error) {
var quesList []*QuesStruct
err = json.Unmarshal(cateJson, &quesList)
quesByte, _ := json.Marshal(quesList)
modelObj := doubao.NewDouBao(chatModel, key)
text := []string{
"根据商品名称,从json中找到该商品对应的分类的数字id[goodsName]" + goodsName + "[/goodsName]",
"-只返回数字,不要有其他文字和字符",
"-JSON数据" + string(quesByte),
}
if exCommand != "" {
text = append(text, exCommand)
}
text = append(text, "-如果无法匹配则返回0")
category, usage, err := modelObj.GetData(ctx, key, model.ChatMessageRoleUser, func(input string) (string, error) {
cleaned := strings.ReplaceAll(input, "\n", "")
return cleaned, nil
}, text...)
if err != nil {
return
}
categoryId, err = strconv.Atoi(category)
if err != nil {
err = fmt.Errorf("为找到分类")
categoryId = 0
}
return
} }

View File

@ -4,7 +4,6 @@ import (
"context" "context"
"fmt" "fmt"
"gitea.cdlsxd.cn/self-tools/l_ai_category/utils_gorm" "gitea.cdlsxd.cn/self-tools/l_ai_category/utils_gorm"
"sync"
"testing" "testing"
"time" "time"
) )
@ -15,38 +14,33 @@ const (
table = "category" table = "category"
modelType = "deepseek-r1-distill-qwen-32b-250120" modelType = "deepseek-r1-distill-qwen-32b-250120"
key = "03320e58-6a0b-4061-a22b-902039f2190d" key = "03320e58-6a0b-4061-a22b-902039f2190d"
Address = "http://8.147.107.207:5000/predict"
) )
func TestCategory(t *testing.T) {
path := getPath()
res, err := GetCategory(context.Background(), "半亩花田倍润身体乳失重玫瑰250ml@默认", "03320e58-6a0b-4061-a22b-902039f2190d", modelType, path)
t.Log(res, err)
}
func TestCategoryGoodsSet(t *testing.T) { func TestCategoryGoodsSet(t *testing.T) {
var wg sync.WaitGroup //var wg sync.WaitGroup
defer clean() defer clean()
path := getPath() ///path := getPath()
goods := goodsAll() goods := goodsAll()
updateChan := make(chan *UpdateLog, len(goods)) updateChan := make(chan *UpdateLog, len(goods))
start := time.Now().Unix() start := time.Now().Unix()
wg.Add(len(goods)) //wg.Add(len(goods))
for _, v := range goods { for _, v := range goods {
go func(updateChan chan *UpdateLog) { //go func(updateChan chan *UpdateLog) {
defer wg.Done() //defer wg.Done()
res, err := GetCategory(context.Background(), v.Title, key, modelType, path) res, _, err := GetCategoryWithSelf(context.Background(), v.Title, Address)
if err != nil || res == 0 { if err != nil || res == 0 {
return return
} }
updateChan <- &UpdateLog{ updateChan <- &UpdateLog{
GoodsId: v.Id, GoodsId: v.Id,
Title: v.Title, Title: v.Title,
CateId: res, CateId: res,
} }
}(updateChan) //}(updateChan)
} }
wg.Wait() //wg.Wait()
close(updateChan) close(updateChan)
end := time.Now().Unix() end := time.Now().Unix()
fmt.Printf("耗时:%d秒", end-start) fmt.Printf("耗时:%d秒", end-start)
@ -92,7 +86,7 @@ func goodsAll() (data []Goods) {
row Goods row Goods
err error err error
) )
rows, _ := db.Raw("SELECT id,title FROM `goods` where brand_category_check=4").Rows() rows, _ := db.Raw("SELECT id,title FROM `goods` order by id desc limit 1000 offset 3156").Rows() //where brand_category_check=4
defer rows.Close() defer rows.Close()
for rows.Next() { for rows.Next() {
err = db.ScanRows(rows, &row) err = db.ScanRows(rows, &row)
@ -126,6 +120,78 @@ func getPath() []map[int][]*CategoryDic {
return cateLevelPath return cateLevelPath
} }
func getPath2() []Category {
cates := cateAll()
var (
level1Map = make(map[int]Category)
level2Map = make(map[int]Category)
level3Slice = make([]Category, 0)
)
for _, v := range cates {
switch v.Level {
case 1:
level1Map[v.Id] = Category{
Id: v.Id,
Name: v.Name,
}
case 2:
level2Map[v.Id] = Category{
Id: v.Id,
Name: v.Name,
Pid: v.ParentId,
}
case 3:
cateName := fmt.Sprintf("%s/%s/%s", level1Map[level2Map[v.ParentId].Pid].Name, level2Map[v.ParentId].Name, v.Name)
level3Slice = append(level3Slice, Category{
Id: v.Id,
Name: cateName,
})
}
}
return level3Slice
}
func getPath3() (map[int]Category, map[string]Category) {
cates := cateAll()
var (
level1Map = make(map[int]Category)
level2Map = make(map[int]Category)
level3Map = make(map[string]Category)
)
for _, v := range cates {
switch v.Level {
case 1:
level2Map[v.Id] = Category{
Name: v.Name,
}
case 2:
cate_name := fmt.Sprintf("%s/%s", level1Map[v.ParentId].Name, v.Name)
level2Map[v.Id] = Category{
Id: v.Id,
Name: cate_name,
Pid: v.ParentId,
}
case 3:
if _, ex := level3Map[v.Name]; !ex {
level3Map[v.Name] = Category{
Id: v.Id,
Pids: map[int]int{v.Id: v.ParentId},
Name: v.Name,
}
} else {
level3Map[v.Name].Pids[v.Id] = v.ParentId
}
}
}
return level2Map, level3Map
}
func cateAll() (data []CategoryPhy) { func cateAll() (data []CategoryPhy) {
var ( var (
row CategoryPhy row CategoryPhy

5
go.mod
View File

@ -1,6 +1,8 @@
module gitea.cdlsxd.cn/self-tools/l_ai_category module gitea.cdlsxd.cn/self-tools/l_ai_category
go 1.22.2 go 1.22.10
toolchain go1.23.7
require ( require (
github.com/volcengine/volcengine-go-sdk v1.0.187 github.com/volcengine/volcengine-go-sdk v1.0.187
@ -9,6 +11,7 @@ require (
) )
require ( require (
gitea.cdlsxd.cn/self-tools/l_request v1.0.8 // indirect
github.com/go-sql-driver/mysql v1.7.0 // indirect github.com/go-sql-driver/mysql v1.7.0 // indirect
github.com/google/uuid v1.3.0 // indirect github.com/google/uuid v1.3.0 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect