Compare commits
No commits in common. "5ffefe7dc6273dcefc2abb84cf5c9388e0d36f75" and "9ce3153e0a3e2f4506d37e2e67063a7902753e96" have entirely different histories.
5ffefe7dc6
...
9ce3153e0a
106
ai.go
106
ai.go
|
@ -4,54 +4,17 @@ import (
|
|||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"gitea.cdlsxd.cn/self-tools/l_request"
|
||||
"gitea.cdlsxd.cn/self-tools/l_ai_category/doubao"
|
||||
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func GetCategoryWithSelf(ctx context.Context, goodsName string, address string) (cateId int, confidence float64, err error) {
|
||||
cateId, name, confidence, err := getCategoryWithAi(ctx, goodsName, address)
|
||||
fmt.Printf("商品:%s,所属分类:%s,置信度:%f\n", goodsName, name, confidence)
|
||||
return
|
||||
}
|
||||
|
||||
type CatRes struct {
|
||||
Predictions []struct {
|
||||
Category string `json:"category"`
|
||||
Confidence float64 `json:"confidence"`
|
||||
} `json:"predictions"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
func getCategoryWithAi(ctx context.Context, goodsName string, address string) (cateId int, cateName string, confidence float64, err error) {
|
||||
request := map[string]interface{}{
|
||||
"product_names": goodsName,
|
||||
}
|
||||
req := l_request.Request{
|
||||
Method: "POST",
|
||||
Url: address,
|
||||
Json: request,
|
||||
}
|
||||
res, err := req.Send()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var resMap CatRes
|
||||
err = json.Unmarshal(res.Content, &resMap)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
cateInfo := strings.Split(resMap.Predictions[0].Category, "_")
|
||||
cateId, _ = strconv.Atoi(cateInfo[1])
|
||||
return cateId, cateInfo[0], resMap.Predictions[0].Confidence, nil
|
||||
}
|
||||
|
||||
type Category struct {
|
||||
Id int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Level int `json:"level"`
|
||||
Pid int `json:"pid"`
|
||||
Pids map[int]int `json:"pid"`
|
||||
Id int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Level int `json:"level"`
|
||||
Pid int `json:"pid"`
|
||||
}
|
||||
|
||||
type CategoryDic struct {
|
||||
|
@ -62,7 +25,58 @@ type QuesStruct struct {
|
|||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
type Req struct {
|
||||
Text string `json:"text"`
|
||||
Cate []string `json:"cate"`
|
||||
func GetCategory(ctx context.Context, goodsName, key, chatModel string, catePath []map[int][]*CategoryDic) (cate3 int, err error) {
|
||||
var (
|
||||
usage1 int
|
||||
usage2 int
|
||||
usage3 int
|
||||
)
|
||||
defer func() {
|
||||
fmt.Printf("第一次用量:%d,第二次用量:%d,第三次用量:%d,总共花费:%d \n", usage1, usage2, usage3, usage1+usage2+usage3)
|
||||
}()
|
||||
cate1Json, _ := json.Marshal(catePath[0][1])
|
||||
cate1, usage1, err := ques(ctx, goodsName, key, chatModel, cate1Json, "")
|
||||
if err != nil || cate1 == 0 {
|
||||
return
|
||||
}
|
||||
cate2Json, _ := json.Marshal(catePath[1][cate1])
|
||||
cate2, usage2, err := ques(ctx, goodsName, key, chatModel, cate2Json, "")
|
||||
if err != nil || cate2 == 0 {
|
||||
return
|
||||
}
|
||||
cate3Json, _ := json.Marshal(catePath[2][cate2])
|
||||
cate3, usage3, err = ques(ctx, goodsName, key, chatModel, cate3Json, "匹配一个最相近的,或者商品名称里面包含的分类")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func ques(ctx context.Context, goodsName, key, chatModel string, cateJson []byte, exCommand string) (categoryId int, usage int, err error) {
|
||||
var quesList []*QuesStruct
|
||||
err = json.Unmarshal(cateJson, &quesList)
|
||||
quesByte, _ := json.Marshal(quesList)
|
||||
modelObj := doubao.NewDouBao(chatModel, key)
|
||||
text := []string{
|
||||
"根据商品名称,从json中找到该商品对应的分类的数字id[goodsName]" + goodsName + "[/goodsName]",
|
||||
"-只返回数字,不要有其他文字和字符",
|
||||
"-JSON数据:" + string(quesByte),
|
||||
}
|
||||
if exCommand != "" {
|
||||
text = append(text, exCommand)
|
||||
}
|
||||
|
||||
text = append(text, "-如果无法匹配,则返回0")
|
||||
|
||||
category, usage, err := modelObj.GetData(ctx, key, model.ChatMessageRoleUser, func(input string) (string, error) {
|
||||
cleaned := strings.ReplaceAll(input, "\n", "")
|
||||
return cleaned, nil
|
||||
}, text...)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
categoryId, err = strconv.Atoi(category)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("为找到分类")
|
||||
categoryId = 0
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
116
ai_test.go
116
ai_test.go
|
@ -4,43 +4,37 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"gitea.cdlsxd.cn/self-tools/l_ai_category/utils_gorm"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
dns = "developer:Lsxd@2024@tcp(lsxdpolar.rwlb.rds.aliyuncs.com:3306)/physicalgoods?parseTime=True&loc=Local"
|
||||
driver = "mysql"
|
||||
table = "category"
|
||||
modelType = "deepseek-r1-distill-qwen-32b-250120"
|
||||
key = "03320e58-6a0b-4061-a22b-902039f2190d"
|
||||
Address = "http://8.147.107.207:5000/predict"
|
||||
)
|
||||
|
||||
|
||||
func TestCategoryGoodsSet(t *testing.T) {
|
||||
//var wg sync.WaitGroup
|
||||
var wg sync.WaitGroup
|
||||
defer clean()
|
||||
///path := getPath()
|
||||
path := getPath()
|
||||
goods := goodsAll()
|
||||
updateChan := make(chan *UpdateLog, len(goods))
|
||||
start := time.Now().Unix()
|
||||
//wg.Add(len(goods))
|
||||
wg.Add(len(goods))
|
||||
for _, v := range goods {
|
||||
//go func(updateChan chan *UpdateLog) {
|
||||
//defer wg.Done()
|
||||
res, _, err := GetCategoryWithSelf(context.Background(), v.Title, Address)
|
||||
go func(updateChan chan *UpdateLog) {
|
||||
defer wg.Done()
|
||||
res, err := GetCategory(context.Background(), v.Title, key, modelType, path)
|
||||
|
||||
if err != nil || res == 0 {
|
||||
return
|
||||
}
|
||||
updateChan <- &UpdateLog{
|
||||
GoodsId: v.Id,
|
||||
Title: v.Title,
|
||||
CateId: res,
|
||||
}
|
||||
//}(updateChan)
|
||||
if err != nil || res == 0 {
|
||||
return
|
||||
}
|
||||
updateChan <- &UpdateLog{
|
||||
GoodsId: v.Id,
|
||||
Title: v.Title,
|
||||
CateId: res,
|
||||
}
|
||||
}(updateChan)
|
||||
}
|
||||
//wg.Wait()
|
||||
wg.Wait()
|
||||
close(updateChan)
|
||||
end := time.Now().Unix()
|
||||
fmt.Printf("耗时:%d秒", end-start)
|
||||
|
@ -86,7 +80,7 @@ func goodsAll() (data []Goods) {
|
|||
row Goods
|
||||
err error
|
||||
)
|
||||
rows, _ := db.Raw("SELECT id,title FROM `goods` order by id desc limit 1000 offset 3156").Rows() //where brand_category_check=4
|
||||
rows, _ := db.Raw("SELECT id,title FROM `goods` where brand_category_check=4").Rows()
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
err = db.ScanRows(rows, &row)
|
||||
|
@ -120,78 +114,6 @@ func getPath() []map[int][]*CategoryDic {
|
|||
return cateLevelPath
|
||||
}
|
||||
|
||||
func getPath2() []Category {
|
||||
cates := cateAll()
|
||||
var (
|
||||
level1Map = make(map[int]Category)
|
||||
level2Map = make(map[int]Category)
|
||||
level3Slice = make([]Category, 0)
|
||||
)
|
||||
|
||||
for _, v := range cates {
|
||||
switch v.Level {
|
||||
case 1:
|
||||
level1Map[v.Id] = Category{
|
||||
Id: v.Id,
|
||||
Name: v.Name,
|
||||
}
|
||||
case 2:
|
||||
level2Map[v.Id] = Category{
|
||||
Id: v.Id,
|
||||
Name: v.Name,
|
||||
Pid: v.ParentId,
|
||||
}
|
||||
case 3:
|
||||
cateName := fmt.Sprintf("%s/%s/%s", level1Map[level2Map[v.ParentId].Pid].Name, level2Map[v.ParentId].Name, v.Name)
|
||||
level3Slice = append(level3Slice, Category{
|
||||
Id: v.Id,
|
||||
Name: cateName,
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
return level3Slice
|
||||
}
|
||||
|
||||
func getPath3() (map[int]Category, map[string]Category) {
|
||||
cates := cateAll()
|
||||
var (
|
||||
level1Map = make(map[int]Category)
|
||||
level2Map = make(map[int]Category)
|
||||
level3Map = make(map[string]Category)
|
||||
)
|
||||
|
||||
for _, v := range cates {
|
||||
switch v.Level {
|
||||
case 1:
|
||||
|
||||
level2Map[v.Id] = Category{
|
||||
Name: v.Name,
|
||||
}
|
||||
case 2:
|
||||
cate_name := fmt.Sprintf("%s/%s", level1Map[v.ParentId].Name, v.Name)
|
||||
level2Map[v.Id] = Category{
|
||||
Id: v.Id,
|
||||
Name: cate_name,
|
||||
Pid: v.ParentId,
|
||||
}
|
||||
case 3:
|
||||
if _, ex := level3Map[v.Name]; !ex {
|
||||
level3Map[v.Name] = Category{
|
||||
Id: v.Id,
|
||||
Pids: map[int]int{v.Id: v.ParentId},
|
||||
Name: v.Name,
|
||||
}
|
||||
} else {
|
||||
level3Map[v.Name].Pids[v.Id] = v.ParentId
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
return level2Map, level3Map
|
||||
}
|
||||
|
||||
func cateAll() (data []CategoryPhy) {
|
||||
var (
|
||||
row CategoryPhy
|
||||
|
|
5
go.mod
5
go.mod
|
@ -1,8 +1,6 @@
|
|||
module gitea.cdlsxd.cn/self-tools/l_ai_category
|
||||
|
||||
go 1.22.10
|
||||
|
||||
toolchain go1.23.7
|
||||
go 1.22.2
|
||||
|
||||
require (
|
||||
github.com/volcengine/volcengine-go-sdk v1.0.187
|
||||
|
@ -11,7 +9,6 @@ require (
|
|||
)
|
||||
|
||||
require (
|
||||
gitea.cdlsxd.cn/self-tools/l_request v1.0.8 // indirect
|
||||
github.com/go-sql-driver/mysql v1.7.0 // indirect
|
||||
github.com/google/uuid v1.3.0 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
|
|
Loading…
Reference in New Issue