Compare commits
No commits in common. "5ffefe7dc6273dcefc2abb84cf5c9388e0d36f75" and "9ce3153e0a3e2f4506d37e2e67063a7902753e96" have entirely different histories.
5ffefe7dc6
...
9ce3153e0a
106
ai.go
106
ai.go
|
@ -4,54 +4,17 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"gitea.cdlsxd.cn/self-tools/l_request"
|
"gitea.cdlsxd.cn/self-tools/l_ai_category/doubao"
|
||||||
|
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetCategoryWithSelf(ctx context.Context, goodsName string, address string) (cateId int, confidence float64, err error) {
|
|
||||||
cateId, name, confidence, err := getCategoryWithAi(ctx, goodsName, address)
|
|
||||||
fmt.Printf("商品:%s,所属分类:%s,置信度:%f\n", goodsName, name, confidence)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
type CatRes struct {
|
|
||||||
Predictions []struct {
|
|
||||||
Category string `json:"category"`
|
|
||||||
Confidence float64 `json:"confidence"`
|
|
||||||
} `json:"predictions"`
|
|
||||||
Status string `json:"status"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func getCategoryWithAi(ctx context.Context, goodsName string, address string) (cateId int, cateName string, confidence float64, err error) {
|
|
||||||
request := map[string]interface{}{
|
|
||||||
"product_names": goodsName,
|
|
||||||
}
|
|
||||||
req := l_request.Request{
|
|
||||||
Method: "POST",
|
|
||||||
Url: address,
|
|
||||||
Json: request,
|
|
||||||
}
|
|
||||||
res, err := req.Send()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var resMap CatRes
|
|
||||||
err = json.Unmarshal(res.Content, &resMap)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
cateInfo := strings.Split(resMap.Predictions[0].Category, "_")
|
|
||||||
cateId, _ = strconv.Atoi(cateInfo[1])
|
|
||||||
return cateId, cateInfo[0], resMap.Predictions[0].Confidence, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type Category struct {
|
type Category struct {
|
||||||
Id int `json:"id"`
|
Id int `json:"id"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Level int `json:"level"`
|
Level int `json:"level"`
|
||||||
Pid int `json:"pid"`
|
Pid int `json:"pid"`
|
||||||
Pids map[int]int `json:"pid"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type CategoryDic struct {
|
type CategoryDic struct {
|
||||||
|
@ -62,7 +25,58 @@ type QuesStruct struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Req struct {
|
func GetCategory(ctx context.Context, goodsName, key, chatModel string, catePath []map[int][]*CategoryDic) (cate3 int, err error) {
|
||||||
Text string `json:"text"`
|
var (
|
||||||
Cate []string `json:"cate"`
|
usage1 int
|
||||||
|
usage2 int
|
||||||
|
usage3 int
|
||||||
|
)
|
||||||
|
defer func() {
|
||||||
|
fmt.Printf("第一次用量:%d,第二次用量:%d,第三次用量:%d,总共花费:%d \n", usage1, usage2, usage3, usage1+usage2+usage3)
|
||||||
|
}()
|
||||||
|
cate1Json, _ := json.Marshal(catePath[0][1])
|
||||||
|
cate1, usage1, err := ques(ctx, goodsName, key, chatModel, cate1Json, "")
|
||||||
|
if err != nil || cate1 == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cate2Json, _ := json.Marshal(catePath[1][cate1])
|
||||||
|
cate2, usage2, err := ques(ctx, goodsName, key, chatModel, cate2Json, "")
|
||||||
|
if err != nil || cate2 == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cate3Json, _ := json.Marshal(catePath[2][cate2])
|
||||||
|
cate3, usage3, err = ques(ctx, goodsName, key, chatModel, cate3Json, "匹配一个最相近的,或者商品名称里面包含的分类")
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func ques(ctx context.Context, goodsName, key, chatModel string, cateJson []byte, exCommand string) (categoryId int, usage int, err error) {
|
||||||
|
var quesList []*QuesStruct
|
||||||
|
err = json.Unmarshal(cateJson, &quesList)
|
||||||
|
quesByte, _ := json.Marshal(quesList)
|
||||||
|
modelObj := doubao.NewDouBao(chatModel, key)
|
||||||
|
text := []string{
|
||||||
|
"根据商品名称,从json中找到该商品对应的分类的数字id[goodsName]" + goodsName + "[/goodsName]",
|
||||||
|
"-只返回数字,不要有其他文字和字符",
|
||||||
|
"-JSON数据:" + string(quesByte),
|
||||||
|
}
|
||||||
|
if exCommand != "" {
|
||||||
|
text = append(text, exCommand)
|
||||||
|
}
|
||||||
|
|
||||||
|
text = append(text, "-如果无法匹配,则返回0")
|
||||||
|
|
||||||
|
category, usage, err := modelObj.GetData(ctx, key, model.ChatMessageRoleUser, func(input string) (string, error) {
|
||||||
|
cleaned := strings.ReplaceAll(input, "\n", "")
|
||||||
|
return cleaned, nil
|
||||||
|
}, text...)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
categoryId, err = strconv.Atoi(category)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("为找到分类")
|
||||||
|
categoryId = 0
|
||||||
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
116
ai_test.go
116
ai_test.go
|
@ -4,43 +4,37 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"gitea.cdlsxd.cn/self-tools/l_ai_category/utils_gorm"
|
"gitea.cdlsxd.cn/self-tools/l_ai_category/utils_gorm"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
dns = "developer:Lsxd@2024@tcp(lsxdpolar.rwlb.rds.aliyuncs.com:3306)/physicalgoods?parseTime=True&loc=Local"
|
|
||||||
driver = "mysql"
|
|
||||||
table = "category"
|
|
||||||
modelType = "deepseek-r1-distill-qwen-32b-250120"
|
|
||||||
key = "03320e58-6a0b-4061-a22b-902039f2190d"
|
|
||||||
Address = "http://8.147.107.207:5000/predict"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestCategoryGoodsSet(t *testing.T) {
|
func TestCategoryGoodsSet(t *testing.T) {
|
||||||
//var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
defer clean()
|
defer clean()
|
||||||
///path := getPath()
|
path := getPath()
|
||||||
goods := goodsAll()
|
goods := goodsAll()
|
||||||
updateChan := make(chan *UpdateLog, len(goods))
|
updateChan := make(chan *UpdateLog, len(goods))
|
||||||
start := time.Now().Unix()
|
start := time.Now().Unix()
|
||||||
//wg.Add(len(goods))
|
wg.Add(len(goods))
|
||||||
for _, v := range goods {
|
for _, v := range goods {
|
||||||
//go func(updateChan chan *UpdateLog) {
|
go func(updateChan chan *UpdateLog) {
|
||||||
//defer wg.Done()
|
defer wg.Done()
|
||||||
res, _, err := GetCategoryWithSelf(context.Background(), v.Title, Address)
|
res, err := GetCategory(context.Background(), v.Title, key, modelType, path)
|
||||||
|
|
||||||
if err != nil || res == 0 {
|
if err != nil || res == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
updateChan <- &UpdateLog{
|
updateChan <- &UpdateLog{
|
||||||
GoodsId: v.Id,
|
GoodsId: v.Id,
|
||||||
Title: v.Title,
|
Title: v.Title,
|
||||||
CateId: res,
|
CateId: res,
|
||||||
}
|
}
|
||||||
//}(updateChan)
|
}(updateChan)
|
||||||
}
|
}
|
||||||
//wg.Wait()
|
wg.Wait()
|
||||||
close(updateChan)
|
close(updateChan)
|
||||||
end := time.Now().Unix()
|
end := time.Now().Unix()
|
||||||
fmt.Printf("耗时:%d秒", end-start)
|
fmt.Printf("耗时:%d秒", end-start)
|
||||||
|
@ -86,7 +80,7 @@ func goodsAll() (data []Goods) {
|
||||||
row Goods
|
row Goods
|
||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
rows, _ := db.Raw("SELECT id,title FROM `goods` order by id desc limit 1000 offset 3156").Rows() //where brand_category_check=4
|
rows, _ := db.Raw("SELECT id,title FROM `goods` where brand_category_check=4").Rows()
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
err = db.ScanRows(rows, &row)
|
err = db.ScanRows(rows, &row)
|
||||||
|
@ -120,78 +114,6 @@ func getPath() []map[int][]*CategoryDic {
|
||||||
return cateLevelPath
|
return cateLevelPath
|
||||||
}
|
}
|
||||||
|
|
||||||
func getPath2() []Category {
|
|
||||||
cates := cateAll()
|
|
||||||
var (
|
|
||||||
level1Map = make(map[int]Category)
|
|
||||||
level2Map = make(map[int]Category)
|
|
||||||
level3Slice = make([]Category, 0)
|
|
||||||
)
|
|
||||||
|
|
||||||
for _, v := range cates {
|
|
||||||
switch v.Level {
|
|
||||||
case 1:
|
|
||||||
level1Map[v.Id] = Category{
|
|
||||||
Id: v.Id,
|
|
||||||
Name: v.Name,
|
|
||||||
}
|
|
||||||
case 2:
|
|
||||||
level2Map[v.Id] = Category{
|
|
||||||
Id: v.Id,
|
|
||||||
Name: v.Name,
|
|
||||||
Pid: v.ParentId,
|
|
||||||
}
|
|
||||||
case 3:
|
|
||||||
cateName := fmt.Sprintf("%s/%s/%s", level1Map[level2Map[v.ParentId].Pid].Name, level2Map[v.ParentId].Name, v.Name)
|
|
||||||
level3Slice = append(level3Slice, Category{
|
|
||||||
Id: v.Id,
|
|
||||||
Name: cateName,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
return level3Slice
|
|
||||||
}
|
|
||||||
|
|
||||||
func getPath3() (map[int]Category, map[string]Category) {
|
|
||||||
cates := cateAll()
|
|
||||||
var (
|
|
||||||
level1Map = make(map[int]Category)
|
|
||||||
level2Map = make(map[int]Category)
|
|
||||||
level3Map = make(map[string]Category)
|
|
||||||
)
|
|
||||||
|
|
||||||
for _, v := range cates {
|
|
||||||
switch v.Level {
|
|
||||||
case 1:
|
|
||||||
|
|
||||||
level2Map[v.Id] = Category{
|
|
||||||
Name: v.Name,
|
|
||||||
}
|
|
||||||
case 2:
|
|
||||||
cate_name := fmt.Sprintf("%s/%s", level1Map[v.ParentId].Name, v.Name)
|
|
||||||
level2Map[v.Id] = Category{
|
|
||||||
Id: v.Id,
|
|
||||||
Name: cate_name,
|
|
||||||
Pid: v.ParentId,
|
|
||||||
}
|
|
||||||
case 3:
|
|
||||||
if _, ex := level3Map[v.Name]; !ex {
|
|
||||||
level3Map[v.Name] = Category{
|
|
||||||
Id: v.Id,
|
|
||||||
Pids: map[int]int{v.Id: v.ParentId},
|
|
||||||
Name: v.Name,
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
level3Map[v.Name].Pids[v.Id] = v.ParentId
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
return level2Map, level3Map
|
|
||||||
}
|
|
||||||
|
|
||||||
func cateAll() (data []CategoryPhy) {
|
func cateAll() (data []CategoryPhy) {
|
||||||
var (
|
var (
|
||||||
row CategoryPhy
|
row CategoryPhy
|
||||||
|
|
5
go.mod
5
go.mod
|
@ -1,8 +1,6 @@
|
||||||
module gitea.cdlsxd.cn/self-tools/l_ai_category
|
module gitea.cdlsxd.cn/self-tools/l_ai_category
|
||||||
|
|
||||||
go 1.22.10
|
go 1.22.2
|
||||||
|
|
||||||
toolchain go1.23.7
|
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/volcengine/volcengine-go-sdk v1.0.187
|
github.com/volcengine/volcengine-go-sdk v1.0.187
|
||||||
|
@ -11,7 +9,6 @@ require (
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
gitea.cdlsxd.cn/self-tools/l_request v1.0.8 // indirect
|
|
||||||
github.com/go-sql-driver/mysql v1.7.0 // indirect
|
github.com/go-sql-driver/mysql v1.7.0 // indirect
|
||||||
github.com/google/uuid v1.3.0 // indirect
|
github.com/google/uuid v1.3.0 // indirect
|
||||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||||
|
|
Loading…
Reference in New Issue