Compare commits
2 Commits
9ce3153e0a
...
5ffefe7dc6
Author | SHA1 | Date |
---|---|---|
|
5ffefe7dc6 | |
|
2ec7c81a4d |
98
ai.go
98
ai.go
|
@ -4,17 +4,54 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"gitea.cdlsxd.cn/self-tools/l_ai_category/doubao"
|
"gitea.cdlsxd.cn/self-tools/l_request"
|
||||||
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func GetCategoryWithSelf(ctx context.Context, goodsName string, address string) (cateId int, confidence float64, err error) {
|
||||||
|
cateId, name, confidence, err := getCategoryWithAi(ctx, goodsName, address)
|
||||||
|
fmt.Printf("商品:%s,所属分类:%s,置信度:%f\n", goodsName, name, confidence)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
type CatRes struct {
|
||||||
|
Predictions []struct {
|
||||||
|
Category string `json:"category"`
|
||||||
|
Confidence float64 `json:"confidence"`
|
||||||
|
} `json:"predictions"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func getCategoryWithAi(ctx context.Context, goodsName string, address string) (cateId int, cateName string, confidence float64, err error) {
|
||||||
|
request := map[string]interface{}{
|
||||||
|
"product_names": goodsName,
|
||||||
|
}
|
||||||
|
req := l_request.Request{
|
||||||
|
Method: "POST",
|
||||||
|
Url: address,
|
||||||
|
Json: request,
|
||||||
|
}
|
||||||
|
res, err := req.Send()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var resMap CatRes
|
||||||
|
err = json.Unmarshal(res.Content, &resMap)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cateInfo := strings.Split(resMap.Predictions[0].Category, "_")
|
||||||
|
cateId, _ = strconv.Atoi(cateInfo[1])
|
||||||
|
return cateId, cateInfo[0], resMap.Predictions[0].Confidence, nil
|
||||||
|
}
|
||||||
|
|
||||||
type Category struct {
|
type Category struct {
|
||||||
Id int `json:"id"`
|
Id int `json:"id"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Level int `json:"level"`
|
Level int `json:"level"`
|
||||||
Pid int `json:"pid"`
|
Pid int `json:"pid"`
|
||||||
|
Pids map[int]int `json:"pid"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type CategoryDic struct {
|
type CategoryDic struct {
|
||||||
|
@ -25,58 +62,7 @@ type QuesStruct struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetCategory(ctx context.Context, goodsName, key, chatModel string, catePath []map[int][]*CategoryDic) (cate3 int, err error) {
|
type Req struct {
|
||||||
var (
|
Text string `json:"text"`
|
||||||
usage1 int
|
Cate []string `json:"cate"`
|
||||||
usage2 int
|
|
||||||
usage3 int
|
|
||||||
)
|
|
||||||
defer func() {
|
|
||||||
fmt.Printf("第一次用量:%d,第二次用量:%d,第三次用量:%d,总共花费:%d \n", usage1, usage2, usage3, usage1+usage2+usage3)
|
|
||||||
}()
|
|
||||||
cate1Json, _ := json.Marshal(catePath[0][1])
|
|
||||||
cate1, usage1, err := ques(ctx, goodsName, key, chatModel, cate1Json, "")
|
|
||||||
if err != nil || cate1 == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
cate2Json, _ := json.Marshal(catePath[1][cate1])
|
|
||||||
cate2, usage2, err := ques(ctx, goodsName, key, chatModel, cate2Json, "")
|
|
||||||
if err != nil || cate2 == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
cate3Json, _ := json.Marshal(catePath[2][cate2])
|
|
||||||
cate3, usage3, err = ques(ctx, goodsName, key, chatModel, cate3Json, "匹配一个最相近的,或者商品名称里面包含的分类")
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func ques(ctx context.Context, goodsName, key, chatModel string, cateJson []byte, exCommand string) (categoryId int, usage int, err error) {
|
|
||||||
var quesList []*QuesStruct
|
|
||||||
err = json.Unmarshal(cateJson, &quesList)
|
|
||||||
quesByte, _ := json.Marshal(quesList)
|
|
||||||
modelObj := doubao.NewDouBao(chatModel, key)
|
|
||||||
text := []string{
|
|
||||||
"根据商品名称,从json中找到该商品对应的分类的数字id[goodsName]" + goodsName + "[/goodsName]",
|
|
||||||
"-只返回数字,不要有其他文字和字符",
|
|
||||||
"-JSON数据:" + string(quesByte),
|
|
||||||
}
|
|
||||||
if exCommand != "" {
|
|
||||||
text = append(text, exCommand)
|
|
||||||
}
|
|
||||||
|
|
||||||
text = append(text, "-如果无法匹配,则返回0")
|
|
||||||
|
|
||||||
category, usage, err := modelObj.GetData(ctx, key, model.ChatMessageRoleUser, func(input string) (string, error) {
|
|
||||||
cleaned := strings.ReplaceAll(input, "\n", "")
|
|
||||||
return cleaned, nil
|
|
||||||
}, text...)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
categoryId, err = strconv.Atoi(category)
|
|
||||||
if err != nil {
|
|
||||||
err = fmt.Errorf("为找到分类")
|
|
||||||
categoryId = 0
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
100
ai_test.go
100
ai_test.go
|
@ -4,25 +4,31 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"gitea.cdlsxd.cn/self-tools/l_ai_category/utils_gorm"
|
"gitea.cdlsxd.cn/self-tools/l_ai_category/utils_gorm"
|
||||||
"sync"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
dns = "developer:Lsxd@2024@tcp(lsxdpolar.rwlb.rds.aliyuncs.com:3306)/physicalgoods?parseTime=True&loc=Local"
|
||||||
|
driver = "mysql"
|
||||||
|
table = "category"
|
||||||
|
modelType = "deepseek-r1-distill-qwen-32b-250120"
|
||||||
|
key = "03320e58-6a0b-4061-a22b-902039f2190d"
|
||||||
|
Address = "http://8.147.107.207:5000/predict"
|
||||||
|
)
|
||||||
|
|
||||||
func TestCategoryGoodsSet(t *testing.T) {
|
func TestCategoryGoodsSet(t *testing.T) {
|
||||||
var wg sync.WaitGroup
|
//var wg sync.WaitGroup
|
||||||
defer clean()
|
defer clean()
|
||||||
path := getPath()
|
///path := getPath()
|
||||||
goods := goodsAll()
|
goods := goodsAll()
|
||||||
updateChan := make(chan *UpdateLog, len(goods))
|
updateChan := make(chan *UpdateLog, len(goods))
|
||||||
start := time.Now().Unix()
|
start := time.Now().Unix()
|
||||||
wg.Add(len(goods))
|
//wg.Add(len(goods))
|
||||||
for _, v := range goods {
|
for _, v := range goods {
|
||||||
go func(updateChan chan *UpdateLog) {
|
//go func(updateChan chan *UpdateLog) {
|
||||||
defer wg.Done()
|
//defer wg.Done()
|
||||||
res, err := GetCategory(context.Background(), v.Title, key, modelType, path)
|
res, _, err := GetCategoryWithSelf(context.Background(), v.Title, Address)
|
||||||
|
|
||||||
if err != nil || res == 0 {
|
if err != nil || res == 0 {
|
||||||
return
|
return
|
||||||
|
@ -32,9 +38,9 @@ func TestCategoryGoodsSet(t *testing.T) {
|
||||||
Title: v.Title,
|
Title: v.Title,
|
||||||
CateId: res,
|
CateId: res,
|
||||||
}
|
}
|
||||||
}(updateChan)
|
//}(updateChan)
|
||||||
}
|
}
|
||||||
wg.Wait()
|
//wg.Wait()
|
||||||
close(updateChan)
|
close(updateChan)
|
||||||
end := time.Now().Unix()
|
end := time.Now().Unix()
|
||||||
fmt.Printf("耗时:%d秒", end-start)
|
fmt.Printf("耗时:%d秒", end-start)
|
||||||
|
@ -80,7 +86,7 @@ func goodsAll() (data []Goods) {
|
||||||
row Goods
|
row Goods
|
||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
rows, _ := db.Raw("SELECT id,title FROM `goods` where brand_category_check=4").Rows()
|
rows, _ := db.Raw("SELECT id,title FROM `goods` order by id desc limit 1000 offset 3156").Rows() //where brand_category_check=4
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
err = db.ScanRows(rows, &row)
|
err = db.ScanRows(rows, &row)
|
||||||
|
@ -114,6 +120,78 @@ func getPath() []map[int][]*CategoryDic {
|
||||||
return cateLevelPath
|
return cateLevelPath
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getPath2() []Category {
|
||||||
|
cates := cateAll()
|
||||||
|
var (
|
||||||
|
level1Map = make(map[int]Category)
|
||||||
|
level2Map = make(map[int]Category)
|
||||||
|
level3Slice = make([]Category, 0)
|
||||||
|
)
|
||||||
|
|
||||||
|
for _, v := range cates {
|
||||||
|
switch v.Level {
|
||||||
|
case 1:
|
||||||
|
level1Map[v.Id] = Category{
|
||||||
|
Id: v.Id,
|
||||||
|
Name: v.Name,
|
||||||
|
}
|
||||||
|
case 2:
|
||||||
|
level2Map[v.Id] = Category{
|
||||||
|
Id: v.Id,
|
||||||
|
Name: v.Name,
|
||||||
|
Pid: v.ParentId,
|
||||||
|
}
|
||||||
|
case 3:
|
||||||
|
cateName := fmt.Sprintf("%s/%s/%s", level1Map[level2Map[v.ParentId].Pid].Name, level2Map[v.ParentId].Name, v.Name)
|
||||||
|
level3Slice = append(level3Slice, Category{
|
||||||
|
Id: v.Id,
|
||||||
|
Name: cateName,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
return level3Slice
|
||||||
|
}
|
||||||
|
|
||||||
|
func getPath3() (map[int]Category, map[string]Category) {
|
||||||
|
cates := cateAll()
|
||||||
|
var (
|
||||||
|
level1Map = make(map[int]Category)
|
||||||
|
level2Map = make(map[int]Category)
|
||||||
|
level3Map = make(map[string]Category)
|
||||||
|
)
|
||||||
|
|
||||||
|
for _, v := range cates {
|
||||||
|
switch v.Level {
|
||||||
|
case 1:
|
||||||
|
|
||||||
|
level2Map[v.Id] = Category{
|
||||||
|
Name: v.Name,
|
||||||
|
}
|
||||||
|
case 2:
|
||||||
|
cate_name := fmt.Sprintf("%s/%s", level1Map[v.ParentId].Name, v.Name)
|
||||||
|
level2Map[v.Id] = Category{
|
||||||
|
Id: v.Id,
|
||||||
|
Name: cate_name,
|
||||||
|
Pid: v.ParentId,
|
||||||
|
}
|
||||||
|
case 3:
|
||||||
|
if _, ex := level3Map[v.Name]; !ex {
|
||||||
|
level3Map[v.Name] = Category{
|
||||||
|
Id: v.Id,
|
||||||
|
Pids: map[int]int{v.Id: v.ParentId},
|
||||||
|
Name: v.Name,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
level3Map[v.Name].Pids[v.Id] = v.ParentId
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
return level2Map, level3Map
|
||||||
|
}
|
||||||
|
|
||||||
func cateAll() (data []CategoryPhy) {
|
func cateAll() (data []CategoryPhy) {
|
||||||
var (
|
var (
|
||||||
row CategoryPhy
|
row CategoryPhy
|
||||||
|
|
5
go.mod
5
go.mod
|
@ -1,6 +1,8 @@
|
||||||
module gitea.cdlsxd.cn/self-tools/l_ai_category
|
module gitea.cdlsxd.cn/self-tools/l_ai_category
|
||||||
|
|
||||||
go 1.22.2
|
go 1.22.10
|
||||||
|
|
||||||
|
toolchain go1.23.7
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/volcengine/volcengine-go-sdk v1.0.187
|
github.com/volcengine/volcengine-go-sdk v1.0.187
|
||||||
|
@ -9,6 +11,7 @@ require (
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
gitea.cdlsxd.cn/self-tools/l_request v1.0.8 // indirect
|
||||||
github.com/go-sql-driver/mysql v1.7.0 // indirect
|
github.com/go-sql-driver/mysql v1.7.0 // indirect
|
||||||
github.com/google/uuid v1.3.0 // indirect
|
github.com/google/uuid v1.3.0 // indirect
|
||||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||||
|
|
Loading…
Reference in New Issue