From a24edb78b8d311e6cbd31f2f321f3a8193dd2567 Mon Sep 17 00:00:00 2001 From: renzhiyuan <465386466@qq.com> Date: Thu, 27 Mar 2025 15:33:53 +0800 Subject: [PATCH] 1 --- ai.go | 29 +++++++++++++ ai_test.go | 88 +++++++++++++++++++++++++++++++++++++++ doubao/constant.go | 17 ++++++++ doubao/doubao.go | 53 ++++++++++++++++++++++++ go.mod | 20 +++++++++ utils_gorm/gorm.go | 22 ++++++++++ utils_gorm/sql_log.go | 96 +++++++++++++++++++++++++++++++++++++++++++ 7 files changed, 325 insertions(+) create mode 100644 ai.go create mode 100644 ai_test.go create mode 100644 doubao/constant.go create mode 100644 doubao/doubao.go create mode 100644 go.mod create mode 100644 utils_gorm/gorm.go create mode 100644 utils_gorm/sql_log.go diff --git a/ai.go b/ai.go new file mode 100644 index 0000000..2b982d7 --- /dev/null +++ b/ai.go @@ -0,0 +1,29 @@ +package l_ai_category + +import ( + "context" + "gitea.cdlsxd.cn/self-tools/l_ai_category/doubao" + "strconv" +) + +func GetCategory(ctx context.Context, goodsName, key, model string, cateJson string) (categoryId int, err error) { + + modelObj := doubao.NewDouBao(model, key) + text := []string{ + "根据商品名称,从json中找到该商品对应的第三级分类[QUESTION]" + goodsName + "[/QUESTION]", + "-只需要返回类型的名称对应的id", + "-如果无法匹配,则返回数字0", + "-以下是 JSON 数据:" + cateJson, + } + category, err := modelObj.GetData(ctx, key, func(input string) (string, error) { + return input, nil + }, text...) + if err != nil { + return + } + categoryId, err = strconv.Atoi(category) + if err != nil { + categoryId = 0 + } + return +} diff --git a/ai_test.go b/ai_test.go new file mode 100644 index 0000000..18f8ce1 --- /dev/null +++ b/ai_test.go @@ -0,0 +1,88 @@ +package l_ai_category + +import ( + "context" + "encoding/json" + "fmt" + "gitea.cdlsxd.cn/self-tools/l_ai_category/utils_gorm" + "testing" +) + +const ( + dns = "root:lansexiongdi6,@tcp(47.97.27.195:3306)/report?parseTime=True&loc=Local" + driver = "mysql" + table = "goods_category" +) + +func TestCategory(t *testing.T) { + catString := string(CategoryToJsonByte()) + res, err := GetCategory(context.Background(), "金龙鱼御品珍珠米", "03320e58-6a0b-4061-a22b-902039f2190d", "deepseek-v3-250324", catString) + t.Log(res, err) +} + +type Category struct { + Id int `json:"id"` + Name string `json:"name"` + Level int `json:"level"` + Pid int `json:"pid"` +} + +type CategoryDic struct { + Category + Children []*CategoryDic `json:"children"` +} + +func TestGetCateGoryJson(t *testing.T) { + cateJsonByte := CategoryToJsonByte() + + t.Log(string(cateJsonByte)) +} + +func CategoryToJsonByte() (cateJsonByte []byte) { + cates := cateAll() + var ( + cateLevelPath = make([]map[int][]*CategoryDic, 3) + ) + for _, v := range cates { + if cateLevelPath[v.Level-1] == nil { + cateLevelPath[v.Level-1] = make(map[int][]*CategoryDic) + } + if _, ex := cateLevelPath[v.Level-1][v.Pid]; !ex { + cateLevelPath[v.Level-1][v.Pid] = make([]*CategoryDic, 0) + } + cateLevelPath[v.Level-1][v.Pid] = append(cateLevelPath[v.Level-1][v.Pid], &CategoryDic{Category: v}) + } + cate := cateIteration(cateLevelPath[0], cateIteration(cateLevelPath[1], cateLevelPath[2]))[0] + cateJsonByte, _ = json.Marshal(cate) + return +} + +func cateIteration(parent map[int][]*CategoryDic, child map[int][]*CategoryDic) map[int][]*CategoryDic { + for _, v := range parent { + for _, vv := range v { + vv.Children = child[vv.Id] + } + } + return parent +} + +func cateAll() (data []Category) { + var ( + row Category + ) + db, err, clean := utils_gorm.DB(driver, dns) + defer clean() + if err != nil { + panic(err) + } + rows, _ := db.Raw(fmt.Sprintf("SELECT id,name,level,pid FROM `%s` where `type`=2", table)).Rows() + defer rows.Close() + for rows.Next() { + err = db.ScanRows(rows, &row) + if err != nil { + panic(err) + } + data = append(data, row) + } + return +} diff --git a/doubao/constant.go b/doubao/constant.go new file mode 100644 index 0000000..1f00ea6 --- /dev/null +++ b/doubao/constant.go @@ -0,0 +1,17 @@ +package doubao + +var UrlMap = map[UrlType]string{ + Text: "https://ark.cn-beijing.volces.com/api/v3/chat/completions", + Video: "https://ark.cn-beijing.volces.com/api/v3/contents/generations/tasks", + Embedding: "https://ark.cn-beijing.volces.com/api/v3/embeddings", + Token: "https://ark.cn-beijing.volces.com/api/v3/tokenization", +} + +type UrlType string + +const ( + Text UrlType = "text" + Video UrlType = "video" + Embedding UrlType = "embedding" + Token UrlType = "token" +) diff --git a/doubao/doubao.go b/doubao/doubao.go new file mode 100644 index 0000000..fb82f25 --- /dev/null +++ b/doubao/doubao.go @@ -0,0 +1,53 @@ +package doubao + +import ( + "context" + "github.com/volcengine/volcengine-go-sdk/service/arkruntime" + "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model" + "github.com/volcengine/volcengine-go-sdk/volcengine" + "time" +) + +type DouBao struct { + Model string + Key string +} + +func NewDouBao(modelName string, key string) *DouBao { + return &DouBao{ + Model: modelName, + Key: key, + } +} + +func (o *DouBao) GetData(ctx context.Context, key string, respHandle func(input string) (string, error), text ...string) (string, error) { + var Message = make([]*model.ChatCompletionMessage, len(text)) + + client := arkruntime.NewClientWithApiKey( + key, + //arkruntime.WithBaseUrl(UrlMap[url]), + arkruntime.WithRegion("cn-beijing"), + arkruntime.WithTimeout(2*time.Minute), + arkruntime.WithRetryTimes(2), + ) + for k, v := range text { + Message[k] = &model.ChatCompletionMessage{ + Role: model.ChatMessageRoleSystem, + Content: &model.ChatCompletionMessageContent{ + StringValue: volcengine.String(v), + }, + } + } + req := model.CreateChatCompletionRequest{ + Model: o.Model, + Messages: Message, + } + + resp, err := client.CreateChatCompletion(ctx, req) + if err != nil { + return "", err + } + result, err := respHandle(*resp.Choices[0].Message.Content.StringValue) + + return result, err +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..1f799ea --- /dev/null +++ b/go.mod @@ -0,0 +1,20 @@ +module gitea.cdlsxd.cn/self-tools/l_ai_category + +go 1.23.7 + +require ( + github.com/volcengine/volcengine-go-sdk v1.0.187 + gorm.io/driver/mysql v1.5.7 + gorm.io/gorm v1.25.12 +) + +require ( + github.com/go-sql-driver/mysql v1.7.0 // indirect + github.com/google/uuid v1.3.0 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect + github.com/jmespath/go-jmespath v0.4.0 // indirect + github.com/volcengine/volc-sdk-golang v1.0.23 // indirect + golang.org/x/text v0.14.0 // indirect + gopkg.in/yaml.v2 v2.2.8 // indirect +) diff --git a/utils_gorm/gorm.go b/utils_gorm/gorm.go new file mode 100644 index 0000000..9d45669 --- /dev/null +++ b/utils_gorm/gorm.go @@ -0,0 +1,22 @@ +package utils_gorm + +import ( + "database/sql" + "fmt" + "gorm.io/driver/mysql" + "gorm.io/gorm" +) + +func DB(driver string, dns string) (*gorm.DB, error, func()) { + mysqlConn, err := sql.Open(driver, dns) + gormDB, err := gorm.Open( + mysql.New(mysql.Config{Conn: mysqlConn}), + ) + return gormDB, err, func() { + if mysqlConn != nil { + if err := mysqlConn.Close(); err != nil { + fmt.Println("关闭 DB 失败:", err) + } + } + } +} diff --git a/utils_gorm/sql_log.go b/utils_gorm/sql_log.go new file mode 100644 index 0000000..b34a29d --- /dev/null +++ b/utils_gorm/sql_log.go @@ -0,0 +1,96 @@ +package utils_gorm + +import ( + "context" + "fmt" + "gorm.io/gorm" + "gorm.io/gorm/logger" + "regexp" + "strings" + "time" +) + +type CustomLogger struct { + gormLogger logger.Interface + db *gorm.DB +} + +func NewCustomLogger(db *gorm.DB) *CustomLogger { + return &CustomLogger{ + gormLogger: logger.Default.LogMode(logger.Info), + db: db, + } +} + +func (l *CustomLogger) LogMode(level logger.LogLevel) logger.Interface { + newlogger := *l + newlogger.gormLogger = l.gormLogger.LogMode(level) + return &newlogger +} + +func (l *CustomLogger) Info(ctx context.Context, msg string, data ...interface{}) { + l.gormLogger.Info(ctx, msg, data...) +} + +func (l *CustomLogger) Warn(ctx context.Context, msg string, data ...interface{}) { + l.gormLogger.Warn(ctx, msg, data...) +} + +func (l *CustomLogger) Error(ctx context.Context, msg string, data ...interface{}) { + l.gormLogger.Error(ctx, msg, data...) +} + +func (l *CustomLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { + elapsed := time.Since(begin) + sql, _ := fc() + l.gormLogger.Trace(ctx, begin, fc, err) + operation := extractOperation(sql) + tableName := extractTableName(sql) + fmt.Println(tableName) + //// 将SQL语句保存到数据库 + if operation == 0 || tableName == "sql_log" { + return + } + //go l.db.Model(&SqlLog{}).Create(&SqlLog{ + // OperatorID: 1, + // OperatorName: "test", + // SqlInfo: sql, + // TableNames: tableName, + // Type: operation, + //}) + + // 如果有需要,也可以根据执行时间(elapsed)等条件过滤或处理日志记录 + if elapsed > time.Second { + //l.gormLogger.Warn(ctx, "Slow SQL (> 1s): %s", sql) + } +} + +// extractTableName extracts the table name from a SQL query, supporting quoted table names. +func extractTableName(sql string) string { + // 使用非捕获组匹配多种SQL操作关键词 + re := regexp.MustCompile(`(?i)\b(?:from|update|into|delete\s+from)\b\s+[\` + "`" + `"]?(\w+)[\` + "`" + `"]?`) + match := re.FindStringSubmatch(sql) + + // 检查是否匹配成功 + if len(match) > 1 { + return match[1] + } + + return "" +} + +// extractOperation extracts the operation type from a SQL query. +func extractOperation(sql string) int32 { + sql = strings.TrimSpace(strings.ToLower(sql)) + var operation int32 + if strings.HasPrefix(sql, "select") { + operation = 0 + } else if strings.HasPrefix(sql, "insert") { + operation = 1 + } else if strings.HasPrefix(sql, "update") { + operation = 3 + } else if strings.HasPrefix(sql, "delete") { + operation = 2 + } + return operation +}