This commit is contained in:
		
							parent
							
								
									3296b47211
								
							
						
					
					
						commit
						0c5ccfa73c
					
				
							
								
								
									
										77
									
								
								ai.go
								
								
								
								
							
							
						
						
									
										77
									
								
								ai.go
								
								
								
								
							|  | @ -7,6 +7,7 @@ import ( | |||
| 	"gitea.cdlsxd.cn/self-tools/l_ai_category/doubao" | ||||
| 	"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| ) | ||||
| 
 | ||||
| type Category struct { | ||||
|  | @ -25,37 +26,49 @@ type QuesStruct struct { | |||
| } | ||||
| 
 | ||||
| func GetCategory(ctx context.Context, goodsName, key, chatModel string, catePath []map[int][]*CategoryDic) (cate3 int, err error) { | ||||
| 	var ( | ||||
| 		usage1 int | ||||
| 		usage2 int | ||||
| 		usage3 int | ||||
| 	) | ||||
| 	defer func() { | ||||
| 		fmt.Printf("第一次用量:%d,第二次用量:%d,第三次用量:%d,总共花费:%d \n", usage1, usage2, usage3, usage1+usage2+usage3) | ||||
| 	}() | ||||
| 	cate1Json, _ := json.Marshal(catePath[0][1]) | ||||
| 	cate1, err := ques(ctx, goodsName, key, chatModel, cate1Json) | ||||
| 	if err != nil { | ||||
| 	cate1, usage1, err := ques(ctx, goodsName, key, chatModel, cate1Json, "") | ||||
| 	if err != nil || cate1 == 0 { | ||||
| 		return | ||||
| 	} | ||||
| 	cate2Json, _ := json.Marshal(catePath[1][cate1]) | ||||
| 	cate2, err := ques(ctx, goodsName, key, chatModel, cate2Json) | ||||
| 	if err != nil { | ||||
| 	cate2, usage2, err := ques(ctx, goodsName, key, chatModel, cate2Json, "") | ||||
| 	if err != nil || cate2 == 0 { | ||||
| 		return | ||||
| 	} | ||||
| 	cate3Json, _ := json.Marshal(catePath[2][cate2]) | ||||
| 	cate3, err = ques(ctx, goodsName, key, chatModel, cate3Json) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 	cate3, usage3, err = ques(ctx, goodsName, key, chatModel, cate3Json, "匹配一个最相近的,或者商品名称里面包含的分类") | ||||
| 
 | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func ques(ctx context.Context, goodsName, key, chatModel string, cateJson []byte) (categoryId int, err error) { | ||||
| func ques(ctx context.Context, goodsName, key, chatModel string, cateJson []byte, exCommand string) (categoryId int, usage int, err error) { | ||||
| 	var quesList []*QuesStruct | ||||
| 	err = json.Unmarshal(cateJson, &quesList) | ||||
| 	quesByte, _ := json.Marshal(quesList) | ||||
| 	modelObj := doubao.NewDouBao(chatModel, key) | ||||
| 	text := []string{ | ||||
| 		"根据商品名称,从json中找到该商品对应的分类[goodsName]" + goodsName + "[/goodsName]", | ||||
| 		"-只返回分类名称对应的id,返回其他文字", | ||||
| 		"-如果无法匹配,则返回数字0", | ||||
| 		"根据商品名称,从json中找到该商品对应的分类的数字id[goodsName]" + goodsName + "[/goodsName]", | ||||
| 		"-只返回数字,不要有其他文字和字符", | ||||
| 		"-JSON数据:" + string(quesByte), | ||||
| 	} | ||||
| 	category, err := modelObj.GetData(ctx, key, model.ChatMessageRoleUser, func(input string) (string, error) { | ||||
| 		return input, nil | ||||
| 	if exCommand != "" { | ||||
| 		text = append(text, exCommand) | ||||
| 	} | ||||
| 
 | ||||
| 	text = append(text, "-如果无法匹配,则返回0") | ||||
| 
 | ||||
| 	category, usage, err := modelObj.GetData(ctx, key, model.ChatMessageRoleUser, func(input string) (string, error) { | ||||
| 		cleaned := strings.ReplaceAll(input, "\n", "") | ||||
| 		return cleaned, nil | ||||
| 	}, text...) | ||||
| 	if err != nil { | ||||
| 		return | ||||
|  | @ -67,39 +80,3 @@ func ques(ctx context.Context, goodsName, key, chatModel string, cateJson []byte | |||
| 	} | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func Test(ctx context.Context, key, chatModel string) (category string, err error) { | ||||
| 
 | ||||
| 	modelObj := doubao.NewDouBao(chatModel, key) | ||||
| 	text := []string{ | ||||
| 		"给我base_json里休闲食品对应的id", | ||||
| 	} | ||||
| 	category, err = modelObj.GetData(ctx, key, model.ChatMessageRoleUser, func(input string) (string, error) { | ||||
| 		return input, nil | ||||
| 	}, text...) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func FlushCate(ctx context.Context, key, chatModel string, cateJson string) (res string, err error) { | ||||
| 
 | ||||
| 	modelObj := doubao.NewDouBao(chatModel, key) | ||||
| 	text := []string{ | ||||
| 		"你是一个之智能非关系型数据库", | ||||
| 		"你的任务是将提供的 JSON 作为base_json进行存放,之后所有提问提到的base_json都是指这个JSON。", | ||||
| 		"以下是base_json:", | ||||
| 		fmt.Sprintf("<base_json>%s</base_json>", cateJson), | ||||
| 		"请将上述 JSON 作为base_json 进行存储。之后有提问提到base_json 时,要明确知晓是指此JSON。", | ||||
| 	} | ||||
| 	res, err = modelObj.GetData(ctx, key, model.ChatMessageRoleAssistant, func(input string) (string, error) { | ||||
| 		return input, nil | ||||
| 	}, text...) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	return | ||||
| } | ||||
|  |  | |||
							
								
								
									
										180
									
								
								ai_test.go
								
								
								
								
							
							
						
						
									
										180
									
								
								ai_test.go
								
								
								
								
							|  | @ -4,84 +4,95 @@ import ( | |||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"gitea.cdlsxd.cn/self-tools/l_ai_category/utils_gorm" | ||||
| 	"sync" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	dns       = "root:lansexiongdi6,@tcp(47.97.27.195:3306)/report?parseTime=True&loc=Local" | ||||
| 	dns       = "developer:Lsxd@2024@tcp(lsxdpolar.rwlb.rds.aliyuncs.com:3306)/physicalgoods?parseTime=True&loc=Local" | ||||
| 	driver    = "mysql" | ||||
| 	table     = "goods_category" | ||||
| 	modelType = "deepseek-v3-250324" | ||||
| 	table     = "category" | ||||
| 	modelType = "deepseek-r1-distill-qwen-32b-250120" | ||||
| 	key       = "03320e58-6a0b-4061-a22b-902039f2190d" | ||||
| ) | ||||
| 
 | ||||
| func TestCategory(t *testing.T) { | ||||
| 	path := getPath() | ||||
| 	res, err := GetCategory(context.Background(), "MM 便携装湿厕纸 40包*10片@默认", "03320e58-6a0b-4061-a22b-902039f2190d", modelType, path) | ||||
| 	res, err := GetCategory(context.Background(), "半亩花田倍润身体乳(失重玫瑰)250ml@默认", "03320e58-6a0b-4061-a22b-902039f2190d", modelType, path) | ||||
| 	t.Log(res, err) | ||||
| } | ||||
| 
 | ||||
| func TestRem(t *testing.T) { | ||||
| 
 | ||||
| 	res, err := Test(context.Background(), "914ccf1d-c002-4fad-a431-f291f5e0d2ad", modelType) | ||||
| 	t.Log(res, err) | ||||
| } | ||||
| 
 | ||||
| //func TestFlush(t *testing.T) {
 | ||||
| //	catString := string(CategoryToJsonByte())
 | ||||
| //	res, err := FlushCate(context.Background(), "914ccf1d-c002-4fad-a431-f291f5e0d2ad", modelType, catString)
 | ||||
| //	t.Log(res, err)
 | ||||
| //}
 | ||||
| //
 | ||||
| //func TestGetCateGoryJson(t *testing.T) {
 | ||||
| //	cateJsonByte := CategoryToJsonByte()
 | ||||
| //
 | ||||
| //	t.Log(string(cateJsonByte))
 | ||||
| //}
 | ||||
| 
 | ||||
| //func CategoryToJsonByte() (cateJsonByte []byte) {
 | ||||
| //
 | ||||
| //	cateLevelPath := getPath()
 | ||||
| //	cate := cateIteration(cateLevelPath[0], cateIteration(cateLevelPath[1], cateLevelPath[2]))[0]
 | ||||
| //	cateJsonByte, _ = json.Marshal(cate)
 | ||||
| //	return
 | ||||
| //}
 | ||||
| 
 | ||||
| func getPath() []map[int][]*CategoryDic { | ||||
| 	cates := cateAll() | ||||
| 	var ( | ||||
| 		cateLevelPath = make([]map[int][]*CategoryDic, 3) | ||||
| 	) | ||||
| 	for _, v := range cates { | ||||
| 		if cateLevelPath[v.Level-1] == nil { | ||||
| 			cateLevelPath[v.Level-1] = make(map[int][]*CategoryDic) | ||||
| 		} | ||||
| 		if _, ex := cateLevelPath[v.Level-1][v.Pid]; !ex { | ||||
| 			cateLevelPath[v.Level-1][v.Pid] = make([]*CategoryDic, 0) | ||||
| 		} | ||||
| 		cateLevelPath[v.Level-1][v.Pid] = append(cateLevelPath[v.Level-1][v.Pid], &CategoryDic{Category: v}) | ||||
| 	} | ||||
| 	return cateLevelPath | ||||
| } | ||||
| 
 | ||||
| //func cateIteration(parent map[int][]*CategoryDic, child map[int][]*CategoryDic) map[int][]*CategoryDic {
 | ||||
| //	for _, v := range parent {
 | ||||
| //		for _, vv := range v {
 | ||||
| //			vv.Children = child[vv.Id]
 | ||||
| //		}
 | ||||
| //	}
 | ||||
| //	return parent
 | ||||
| //}
 | ||||
| 
 | ||||
| func cateAll() (data []Category) { | ||||
| 	var ( | ||||
| 		row Category | ||||
| 	) | ||||
| 	db, err, clean := utils_gorm.DB(driver, dns) | ||||
| func TestCategoryGoodsSet(t *testing.T) { | ||||
| 	var wg sync.WaitGroup | ||||
| 	defer clean() | ||||
| 	if err != nil { | ||||
| 		panic(err) | ||||
| 	path := getPath() | ||||
| 	goods := goodsAll() | ||||
| 	updateChan := make(chan *UpdateLog, len(goods)) | ||||
| 	start := time.Now().Unix() | ||||
| 	wg.Add(len(goods)) | ||||
| 	for _, v := range goods { | ||||
| 		go func(updateChan chan *UpdateLog) { | ||||
| 			defer wg.Done() | ||||
| 			res, err := GetCategory(context.Background(), v.Title, key, modelType, path) | ||||
| 
 | ||||
| 			if err != nil || res == 0 { | ||||
| 				return | ||||
| 			} | ||||
| 			updateChan <- &UpdateLog{ | ||||
| 				GoodsId: v.Id, | ||||
| 				Title:   v.Title, | ||||
| 				CateId:  res, | ||||
| 			} | ||||
| 		}(updateChan) | ||||
| 	} | ||||
| 	rows, _ := db.Raw(fmt.Sprintf("SELECT id,name,level,pid  FROM `%s` where `type`=2", table)).Rows() | ||||
| 	wg.Wait() | ||||
| 	close(updateChan) | ||||
| 	end := time.Now().Unix() | ||||
| 	fmt.Printf("耗时:%d秒", end-start) | ||||
| 	updateSlice := make([]*UpdateLog, len(updateChan)) | ||||
| 	for v := range updateChan { | ||||
| 		updateSlice = append(updateSlice, v) | ||||
| 	} | ||||
| 	for _, v := range updateSlice { | ||||
| 		if v == nil { | ||||
| 			continue | ||||
| 		} | ||||
| 		updateGoodsCate(v) | ||||
| 	} | ||||
| 
 | ||||
| 	t.Log(updateSlice) | ||||
| } | ||||
| 
 | ||||
| func updateGoodsCate(data *UpdateLog) { | ||||
| 	db.Table("goods_category_relation").Where("goods_id = ?", data.GoodsId).Updates(map[string]interface{}{"category_id": data.CateId}) | ||||
| 	db.Table("goods").Where("id = ?", data.GoodsId).Updates(map[string]interface{}{"brand_category_check": 3}) | ||||
| } | ||||
| 
 | ||||
| type UpdateLog struct { | ||||
| 	GoodsId int    `json:"goods_id"` | ||||
| 	Title   string `json:"title"` | ||||
| 	CateId  int    `json:"cate_id"` | ||||
| } | ||||
| type CategoryPhy struct { | ||||
| 	Id       int    `json:"id"` | ||||
| 	Name     string `json:"name"` | ||||
| 	Level    int    `json:"level"` | ||||
| 	ParentId int    `json:"parent_id"` | ||||
| } | ||||
| 
 | ||||
| type Goods struct { | ||||
| 	Id                 int    `json:"id"` | ||||
| 	Title              string `json:"title"` | ||||
| 	BrandCateGoryCheck int    `json:"brand_category_check"` | ||||
| } | ||||
| 
 | ||||
| func goodsAll() (data []Goods) { | ||||
| 	var ( | ||||
| 		row Goods | ||||
| 		err error | ||||
| 	) | ||||
| 	rows, _ := db.Raw("SELECT id,title  FROM `goods` where brand_category_check=4").Rows() | ||||
| 	defer rows.Close() | ||||
| 	for rows.Next() { | ||||
| 		err = db.ScanRows(rows, &row) | ||||
|  | @ -92,3 +103,44 @@ func cateAll() (data []Category) { | |||
| 	} | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func getPath() []map[int][]*CategoryDic { | ||||
| 	cates := cateAll() | ||||
| 	var ( | ||||
| 		cateLevelPath = make([]map[int][]*CategoryDic, 3) | ||||
| 	) | ||||
| 	for _, v := range cates { | ||||
| 		if cateLevelPath[v.Level-1] == nil { | ||||
| 			cateLevelPath[v.Level-1] = make(map[int][]*CategoryDic) | ||||
| 		} | ||||
| 		if _, ex := cateLevelPath[v.Level-1][v.ParentId]; !ex { | ||||
| 			cateLevelPath[v.Level-1][v.ParentId] = make([]*CategoryDic, 0) | ||||
| 		} | ||||
| 		cateLevelPath[v.Level-1][v.ParentId] = append(cateLevelPath[v.Level-1][v.ParentId], &CategoryDic{ | ||||
| 			Category{ | ||||
| 				Id:    v.Id, | ||||
| 				Name:  v.Name, | ||||
| 				Level: v.Level, | ||||
| 			}}) | ||||
| 	} | ||||
| 	return cateLevelPath | ||||
| } | ||||
| 
 | ||||
| func cateAll() (data []CategoryPhy) { | ||||
| 	var ( | ||||
| 		row CategoryPhy | ||||
| 	) | ||||
| 
 | ||||
| 	rows, _ := db.Raw(fmt.Sprintf("SELECT id,name,level,parent_id  FROM `%s`", table)).Rows() | ||||
| 	defer rows.Close() | ||||
| 	for rows.Next() { | ||||
| 		err := db.ScanRows(rows, &row) | ||||
| 		if err != nil { | ||||
| 			panic(err) | ||||
| 		} | ||||
| 		data = append(data, row) | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| var db, _, clean = utils_gorm.DB(driver, dns) | ||||
|  |  | |||
|  | @ -2,7 +2,6 @@ package doubao | |||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"github.com/volcengine/volcengine-go-sdk/service/arkruntime" | ||||
| 	"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model" | ||||
| 	"github.com/volcengine/volcengine-go-sdk/volcengine" | ||||
|  | @ -21,7 +20,7 @@ func NewDouBao(modelName string, key string) *DouBao { | |||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (o *DouBao) GetData(ctx context.Context, key string, role string, respHandle func(input string) (string, error), text ...string) (string, error) { | ||||
| func (o *DouBao) GetData(ctx context.Context, key string, role string, respHandle func(input string) (string, error), text ...string) (string, int, error) { | ||||
| 	var Message = make([]*model.ChatCompletionMessage, len(text)) | ||||
| 
 | ||||
| 	client := arkruntime.NewClientWithApiKey( | ||||
|  | @ -46,10 +45,10 @@ func (o *DouBao) GetData(ctx context.Context, key string, role string, respHandl | |||
| 
 | ||||
| 	resp, err := client.CreateChatCompletion(ctx, req) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 		return "", 0, err | ||||
| 	} | ||||
| 	fmt.Printf("用量:%d", resp.Usage.TotalTokens) | ||||
| 
 | ||||
| 	result, err := respHandle(*resp.Choices[0].Message.Content.StringValue) | ||||
| 
 | ||||
| 	return result, err | ||||
| 	return result, resp.Usage.TotalTokens, err | ||||
| } | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue