Compare commits

..

No commits in common. "5ffefe7dc6273dcefc2abb84cf5c9388e0d36f75" and "9ce3153e0a3e2f4506d37e2e67063a7902753e96" have entirely different histories.

3 changed files with 80 additions and 147 deletions

106
ai.go
View File

@ -4,54 +4,17 @@ import (
"context"
"encoding/json"
"fmt"
"gitea.cdlsxd.cn/self-tools/l_request"
"gitea.cdlsxd.cn/self-tools/l_ai_category/doubao"
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
"strconv"
"strings"
)
func GetCategoryWithSelf(ctx context.Context, goodsName string, address string) (cateId int, confidence float64, err error) {
cateId, name, confidence, err := getCategoryWithAi(ctx, goodsName, address)
fmt.Printf("商品:%s,所属分类:%s,置信度:%f\n", goodsName, name, confidence)
return
}
type CatRes struct {
Predictions []struct {
Category string `json:"category"`
Confidence float64 `json:"confidence"`
} `json:"predictions"`
Status string `json:"status"`
}
func getCategoryWithAi(ctx context.Context, goodsName string, address string) (cateId int, cateName string, confidence float64, err error) {
request := map[string]interface{}{
"product_names": goodsName,
}
req := l_request.Request{
Method: "POST",
Url: address,
Json: request,
}
res, err := req.Send()
if err != nil {
return
}
var resMap CatRes
err = json.Unmarshal(res.Content, &resMap)
if err != nil {
return
}
cateInfo := strings.Split(resMap.Predictions[0].Category, "_")
cateId, _ = strconv.Atoi(cateInfo[1])
return cateId, cateInfo[0], resMap.Predictions[0].Confidence, nil
}
type Category struct {
Id int `json:"id"`
Name string `json:"name"`
Level int `json:"level"`
Pid int `json:"pid"`
Pids map[int]int `json:"pid"`
Id int `json:"id"`
Name string `json:"name"`
Level int `json:"level"`
Pid int `json:"pid"`
}
type CategoryDic struct {
@ -62,7 +25,58 @@ type QuesStruct struct {
Name string `json:"name"`
}
type Req struct {
Text string `json:"text"`
Cate []string `json:"cate"`
func GetCategory(ctx context.Context, goodsName, key, chatModel string, catePath []map[int][]*CategoryDic) (cate3 int, err error) {
var (
usage1 int
usage2 int
usage3 int
)
defer func() {
fmt.Printf("第一次用量:%d,第二次用量:%d,第三次用量:%d,总共花费:%d \n", usage1, usage2, usage3, usage1+usage2+usage3)
}()
cate1Json, _ := json.Marshal(catePath[0][1])
cate1, usage1, err := ques(ctx, goodsName, key, chatModel, cate1Json, "")
if err != nil || cate1 == 0 {
return
}
cate2Json, _ := json.Marshal(catePath[1][cate1])
cate2, usage2, err := ques(ctx, goodsName, key, chatModel, cate2Json, "")
if err != nil || cate2 == 0 {
return
}
cate3Json, _ := json.Marshal(catePath[2][cate2])
cate3, usage3, err = ques(ctx, goodsName, key, chatModel, cate3Json, "匹配一个最相近的,或者商品名称里面包含的分类")
return
}
func ques(ctx context.Context, goodsName, key, chatModel string, cateJson []byte, exCommand string) (categoryId int, usage int, err error) {
var quesList []*QuesStruct
err = json.Unmarshal(cateJson, &quesList)
quesByte, _ := json.Marshal(quesList)
modelObj := doubao.NewDouBao(chatModel, key)
text := []string{
"根据商品名称,从json中找到该商品对应的分类的数字id[goodsName]" + goodsName + "[/goodsName]",
"-只返回数字,不要有其他文字和字符",
"-JSON数据" + string(quesByte),
}
if exCommand != "" {
text = append(text, exCommand)
}
text = append(text, "-如果无法匹配则返回0")
category, usage, err := modelObj.GetData(ctx, key, model.ChatMessageRoleUser, func(input string) (string, error) {
cleaned := strings.ReplaceAll(input, "\n", "")
return cleaned, nil
}, text...)
if err != nil {
return
}
categoryId, err = strconv.Atoi(category)
if err != nil {
err = fmt.Errorf("为找到分类")
categoryId = 0
}
return
}

View File

@ -4,43 +4,37 @@ import (
"context"
"fmt"
"gitea.cdlsxd.cn/self-tools/l_ai_category/utils_gorm"
"sync"
"testing"
"time"
)
const (
dns = "developer:Lsxd@2024@tcp(lsxdpolar.rwlb.rds.aliyuncs.com:3306)/physicalgoods?parseTime=True&loc=Local"
driver = "mysql"
table = "category"
modelType = "deepseek-r1-distill-qwen-32b-250120"
key = "03320e58-6a0b-4061-a22b-902039f2190d"
Address = "http://8.147.107.207:5000/predict"
)
func TestCategoryGoodsSet(t *testing.T) {
//var wg sync.WaitGroup
var wg sync.WaitGroup
defer clean()
///path := getPath()
path := getPath()
goods := goodsAll()
updateChan := make(chan *UpdateLog, len(goods))
start := time.Now().Unix()
//wg.Add(len(goods))
wg.Add(len(goods))
for _, v := range goods {
//go func(updateChan chan *UpdateLog) {
//defer wg.Done()
res, _, err := GetCategoryWithSelf(context.Background(), v.Title, Address)
go func(updateChan chan *UpdateLog) {
defer wg.Done()
res, err := GetCategory(context.Background(), v.Title, key, modelType, path)
if err != nil || res == 0 {
return
}
updateChan <- &UpdateLog{
GoodsId: v.Id,
Title: v.Title,
CateId: res,
}
//}(updateChan)
if err != nil || res == 0 {
return
}
updateChan <- &UpdateLog{
GoodsId: v.Id,
Title: v.Title,
CateId: res,
}
}(updateChan)
}
//wg.Wait()
wg.Wait()
close(updateChan)
end := time.Now().Unix()
fmt.Printf("耗时:%d秒", end-start)
@ -86,7 +80,7 @@ func goodsAll() (data []Goods) {
row Goods
err error
)
rows, _ := db.Raw("SELECT id,title FROM `goods` order by id desc limit 1000 offset 3156").Rows() //where brand_category_check=4
rows, _ := db.Raw("SELECT id,title FROM `goods` where brand_category_check=4").Rows()
defer rows.Close()
for rows.Next() {
err = db.ScanRows(rows, &row)
@ -120,78 +114,6 @@ func getPath() []map[int][]*CategoryDic {
return cateLevelPath
}
func getPath2() []Category {
cates := cateAll()
var (
level1Map = make(map[int]Category)
level2Map = make(map[int]Category)
level3Slice = make([]Category, 0)
)
for _, v := range cates {
switch v.Level {
case 1:
level1Map[v.Id] = Category{
Id: v.Id,
Name: v.Name,
}
case 2:
level2Map[v.Id] = Category{
Id: v.Id,
Name: v.Name,
Pid: v.ParentId,
}
case 3:
cateName := fmt.Sprintf("%s/%s/%s", level1Map[level2Map[v.ParentId].Pid].Name, level2Map[v.ParentId].Name, v.Name)
level3Slice = append(level3Slice, Category{
Id: v.Id,
Name: cateName,
})
}
}
return level3Slice
}
func getPath3() (map[int]Category, map[string]Category) {
cates := cateAll()
var (
level1Map = make(map[int]Category)
level2Map = make(map[int]Category)
level3Map = make(map[string]Category)
)
for _, v := range cates {
switch v.Level {
case 1:
level2Map[v.Id] = Category{
Name: v.Name,
}
case 2:
cate_name := fmt.Sprintf("%s/%s", level1Map[v.ParentId].Name, v.Name)
level2Map[v.Id] = Category{
Id: v.Id,
Name: cate_name,
Pid: v.ParentId,
}
case 3:
if _, ex := level3Map[v.Name]; !ex {
level3Map[v.Name] = Category{
Id: v.Id,
Pids: map[int]int{v.Id: v.ParentId},
Name: v.Name,
}
} else {
level3Map[v.Name].Pids[v.Id] = v.ParentId
}
}
}
return level2Map, level3Map
}
func cateAll() (data []CategoryPhy) {
var (
row CategoryPhy

5
go.mod
View File

@ -1,8 +1,6 @@
module gitea.cdlsxd.cn/self-tools/l_ai_category
go 1.22.10
toolchain go1.23.7
go 1.22.2
require (
github.com/volcengine/volcengine-go-sdk v1.0.187
@ -11,7 +9,6 @@ require (
)
require (
gitea.cdlsxd.cn/self-tools/l_request v1.0.8 // indirect
github.com/go-sql-driver/mysql v1.7.0 // indirect
github.com/google/uuid v1.3.0 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect