diff --git a/go.mod b/go.mod index 49b7751..b6e99c9 100644 --- a/go.mod +++ b/go.mod @@ -14,6 +14,9 @@ require ( github.com/faabiosr/cachego v0.26.0 github.com/fastwego/dingding v1.0.0-beta.4 github.com/go-kratos/kratos/v2 v2.9.1 + github.com/go-playground/locales v0.14.1 + github.com/go-playground/universal-translator v0.18.1 + github.com/go-playground/validator/v10 v10.20.0 github.com/gofiber/fiber/v2 v2.52.9 github.com/gofiber/websocket/v2 v2.2.1 github.com/google/uuid v1.6.0 @@ -53,6 +56,7 @@ require ( github.com/evanphx/json-patch v0.5.2 // indirect github.com/fasthttp/websocket v1.5.3 // indirect github.com/fsnotify/fsnotify v1.6.0 // indirect + github.com/gabriel-vasile/mimetype v1.4.3 // indirect github.com/go-sql-driver/mysql v1.8.1 // indirect github.com/goph/emperror v0.17.2 // indirect github.com/hashicorp/hcl v1.0.0 // indirect @@ -61,6 +65,7 @@ require ( github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/compress v1.17.9 // indirect github.com/klauspost/cpuid/v2 v2.2.9 // indirect + github.com/leodido/go-urn v1.4.0 // indirect github.com/magiconair/properties v1.8.7 // indirect github.com/mailru/easyjson v0.7.7 // indirect github.com/mattn/go-colorable v0.1.13 // indirect diff --git a/go.sum b/go.sum index 5dc58f0..a6689ca 100644 --- a/go.sum +++ b/go.sum @@ -174,6 +174,8 @@ github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMo github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY= github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw= +github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= +github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= github.com/garyburd/redigo v1.6.0/go.mod h1:NR3MbYisc3/PwhQ00EMzDiPmrwpPxAn5GI05/YaO1SY= github.com/getsentry/raven-go v0.2.0/go.mod h1:KungGk8q33+aIAZUIVWZDr2OfAEBsO49PX4NzFV5kcQ= github.com/go-check/check v0.0.0-20180628173108-788fd7840127 h1:0gkP6mzaMqkmpcJYCFOLkIBwI7xFExG03bbkOkCvUPI= @@ -183,6 +185,14 @@ github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2 github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-kratos/kratos/v2 v2.9.1 h1:EGif6/S/aK/RCR5clIbyhioTNyoSrii3FC118jG40Z0= github.com/go-kratos/kratos/v2 v2.9.1/go.mod h1:a1MQLjMhIh7R0kcJS9SzJYR43BRI7EPzzN0J1Ksu2bA= +github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= +github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= +github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= +github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= +github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= +github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= +github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBExVwjEviJTixqxL8= +github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= github.com/gofiber/fiber/v2 v2.52.9 h1:YjKl5DOiyP3j0mO61u3NTmK7or8GzzWzCFzkboyP5cw= @@ -292,6 +302,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= +github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= diff --git a/internal/biz/chat_history.go b/internal/biz/chat_history.go index eeb74e5..58266e8 100644 --- a/internal/biz/chat_history.go +++ b/internal/biz/chat_history.go @@ -1,53 +1,105 @@ package biz import ( + errors "ai_scheduler/internal/data/error" "ai_scheduler/internal/data/impl" "ai_scheduler/internal/data/model" "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg/util" "context" - + "encoding/json" "xorm.io/builder" ) type ChatHistoryBiz struct { chatHiRepo *impl.ChatHisImpl + taskRepo *impl.TaskImpl } -func NewChatHistoryBiz(chatHiRepo *impl.ChatHisImpl) *ChatHistoryBiz { +func NewChatHistoryBiz(chatHiRepo *impl.ChatHisImpl, taskRepo *impl.TaskImpl) *ChatHistoryBiz { s := &ChatHistoryBiz{ chatHiRepo: chatHiRepo, + taskRepo: taskRepo, } - //go s.AsyncProcess(context.Background()) return s } // 查询会话历史 -func (s *ChatHistoryBiz) List(ctx context.Context, query *entitys.ChatHistQuery) ([]model.AiChatHi, error) { +func (s *ChatHistoryBiz) List(ctx context.Context, query *entitys.ChatHistQuery) ([]entitys.ChatHisQueryResponse, error) { chats, err := s.chatHiRepo.FindAll( s.chatHiRepo.WithSessionId(query.SessionID), s.chatHiRepo.PaginateScope(query.Page, query.PageSize), + s.chatHiRepo.OrderByDesc("his_id"), ) if err != nil { return nil, err } - return chats, nil + + taskIds := make([]int32, 0, len(chats)) + for _, chat := range chats { + // 去重任务ID + if !util.Contains(taskIds, chat.TaskID) { + taskIds = append(taskIds, chat.TaskID) + } + } + + // 查询任务名称 + tasks, err := s.taskRepo.FindAll(s.taskRepo.In("task_id", taskIds)) + if err != nil { + return nil, err + } + taskMap := make(map[int32]model.AiTask) + for _, task := range tasks { + taskMap[task.TaskID] = task + } + + // 构建结果 + result := make([]entitys.ChatHisQueryResponse, 0, len(chats)) + for _, chat := range chats { + item := entitys.ChatHisQueryResponse{} + item.FromModel(chat, taskMap[chat.TaskID]) + result = append(result, item) + } + + return result, nil } -// 添加会话历史 -func (s *ChatHistoryBiz) Create(ctx context.Context, chat entitys.ChatHistory) error { - return s.chatHiRepo.Create(&model.AiChatHi{ - SessionID: chat.SessionID, - Ques: chat.Role.String(), - Ans: chat.Content, - }) -} +//// 添加会话历史 +//func (s *ChatHistoryBiz) Create(ctx context.Context, chat entitys.ChatHistory) error { +// return s.chatHiRepo.Create(&model.AiChatHi{ +// SessionID: chat.SessionID, +// Ques: chat.Role.String(), +// Ans: chat.Content, +// }) +//} -// 更新会话历史内容 -func (s *ChatHistoryBiz) UpdateContent(ctx context.Context, chat *entitys.UseFulRequest) error { - cond := builder.NewCond() - cond = cond.And(builder.Eq{"his_id": chat.HisId}) +// 更新会话历史内容, 追加内容, 不覆盖原有内容, content 使用json格式存储 +func (c *ChatHistoryBiz) UpdateContent(ctx context.Context, chat *entitys.UpdateContentRequest) error { + var contents []string + chatHi, has, err := c.chatHiRepo.FindOne(c.chatHiRepo.WithHisId(chat.HisID)) + if err != nil { + return err + } else if !has { + return errors.NewBusinessErr(errors.InvalidParamCode, "chat history not found") + } - return s.chatHiRepo.UpdateByCond(&cond, &model.AiChatHi{HisID: chat.HisId, Useful: chat.Useful}) + if "" != chatHi.Content { + // 解析历史内容 + err = json.Unmarshal([]byte(chatHi.Content), &contents) + if err != nil { + return err + } + } + contents = append(contents, chat.Content) + + b, err := json.Marshal(contents) + if err != nil { + return err + } + chatHi.Content = string(b) + return c.chatHiRepo.Update(&chatHi, + c.chatHiRepo.Select("content"), + c.chatHiRepo.WithHisId(chatHi.HisID)) } // 异步添加会话历史 diff --git a/internal/data/impl/base.go b/internal/data/impl/base.go index 5ab2afa..7aee0b7 100644 --- a/internal/data/impl/base.go +++ b/internal/data/impl/base.go @@ -54,6 +54,8 @@ type BaseRepository[P PO] interface { WithStatus(status int) CondFunc // 查询status GetDb() *gorm.DB // 获取数据库连接 WithLimit(limit int) CondFunc // 限制返回条数 + In(field string, values interface{}) CondFunc // 查询字段是否在列表中 + Select(fields ...string) CondFunc // 选择字段 } // PaginationResult 分页查询结果 @@ -215,3 +217,17 @@ func (this *BaseModel[P]) WithLimit(limit int) CondFunc { return db.Limit(limit) } } + +// 查询字段是否在列表中 +func (this *BaseModel[P]) In(field string, values interface{}) CondFunc { + return func(db *gorm.DB) *gorm.DB { + return db.Where(fmt.Sprintf("%s IN ?", field), values) + } +} + +// select 字段 +func (this *BaseModel[P]) Select(fields ...string) CondFunc { + return func(db *gorm.DB) *gorm.DB { + return db.Select(fields) + } +} diff --git a/internal/data/impl/chat_history.go b/internal/data/impl/chat_history.go index f1c0b42..08a7fa8 100644 --- a/internal/data/impl/chat_history.go +++ b/internal/data/impl/chat_history.go @@ -55,3 +55,10 @@ func (impl *ChatHisImpl) AsyncProcess(ctx context.Context) { } } } + +// his_id 条件:历史ID +func (impl *ChatHisImpl) WithHisId(hisId interface{}) CondFunc { + return func(db *gorm.DB) *gorm.DB { + return db.Where("his_id = ?", hisId) + } +} diff --git a/internal/data/impl/task_impl.go b/internal/data/impl/task_impl.go index 8b3f246..3c76680 100644 --- a/internal/data/impl/task_impl.go +++ b/internal/data/impl/task_impl.go @@ -8,8 +8,12 @@ import ( type TaskImpl struct { dataTemp.DataTemp + BaseRepository[model.AiTask] } func NewTaskImpl(db *utils.Db) *TaskImpl { - return &TaskImpl{*dataTemp.NewDataTemp(db, new(model.AiTask))} + return &TaskImpl{ + DataTemp: *dataTemp.NewDataTemp(db, new(model.AiTask)), + BaseRepository: NewBaseModel[model.AiTask](db.Client), + } } diff --git a/internal/data/model/ai_chat_his.gen.go b/internal/data/model/ai_chat_his.gen.go index e3b5b58..595b4c4 100644 --- a/internal/data/model/ai_chat_his.gen.go +++ b/internal/data/model/ai_chat_his.gen.go @@ -21,6 +21,7 @@ type AiChatHi struct { CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"` UpdatedAt time.Time `gorm:"column:updated_at;default:CURRENT_TIMESTAMP" json:"updated_at"` TaskID int32 `gorm:"column:task_id;not null" json:"task_id"` // 任务ID + Content string `gorm:"column:content" json:"content"` // 前端回传数据 } // TableName AiChatHi's table name diff --git a/internal/entitys/chat_history.go b/internal/entitys/chat_history.go index 34e6c5e..6991a53 100644 --- a/internal/entitys/chat_history.go +++ b/internal/entitys/chat_history.go @@ -2,6 +2,8 @@ package entitys import ( "ai_scheduler/internal/data/constants" + "ai_scheduler/internal/data/model" + "encoding/json" ) type ChatHistory struct { @@ -20,3 +22,43 @@ type ChatHistQuery struct { Page int `json:"page"` PageSize int `json:"page_size"` } + +type ChatHisQueryResponse struct { + HisID int64 `gorm:"column:his_id;primaryKey;autoIncrement:true" json:"his_id"` + SessionID string `gorm:"column:session_id;not null" json:"session_id"` + Ques string `gorm:"column:ques;not null" json:"ques"` + Ans string `gorm:"column:ans;not null" json:"ans"` + Files string `gorm:"column:files;not null" json:"files"` + Useful int32 `gorm:"column:useful;not null;comment:0不评价,1有用,其他为无用" json:"useful"` // 0不评价,1有用,其他为无用 + CreateAt string `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"` + TaskID int32 `gorm:"column:task_id;not null" json:"task_id"` // 任务ID + TaskName string `gorm:"column:task_name;not null" json:"task_name"` // 任务名称 + Content []string `gorm:"column:content" json:"content"` // 前端回传数据 +} + +func (c *ChatHisQueryResponse) FromModel(chat model.AiChatHi, task model.AiTask) { + c.HisID = chat.HisID + c.SessionID = chat.SessionID + c.Ques = chat.Ques + c.Ans = chat.Ans + c.Files = chat.Files + c.Useful = chat.Useful + c.CreateAt = chat.CreateAt.Format("2006-01-02 15:04:05") + c.TaskID = chat.TaskID + c.TaskName = task.Name + c.Content = make([]string, 0) + + // 解析Content + if "" != chat.Content { + var contents []string + if err := json.Unmarshal([]byte(chat.Content), &contents); err != nil { + c.Content = contents + } + c.Content = append(c.Content, chat.Content) + } +} + +type UpdateContentRequest struct { + HisID int64 `json:"his_id" validate:"required"` + Content string `json:"content" validate:"required"` +} diff --git a/internal/pkg/util/string.go b/internal/pkg/util/string.go index 1fc898a..9dd4056 100644 --- a/internal/pkg/util/string.go +++ b/internal/pkg/util/string.go @@ -32,3 +32,13 @@ func StringToFloat64(s string) float64 { i, _ := strconv.ParseFloat(s, 64) return i } + +// 是否包含在数组中 +func Contains[T comparable](strings []T, str T) bool { + for _, s := range strings { + if s == str { + return true + } + } + return false +} diff --git a/internal/pkg/validate/validate.go b/internal/pkg/validate/validate.go new file mode 100644 index 0000000..de58134 --- /dev/null +++ b/internal/pkg/validate/validate.go @@ -0,0 +1,45 @@ +package validate + +import ( + "fmt" + "github.com/go-playground/locales/zh" + ut "github.com/go-playground/universal-translator" + "github.com/go-playground/validator/v10" + zh_translations "github.com/go-playground/validator/v10/translations/zh" + "reflect" +) + +func Struct(s interface{}) (errMsg []string, err error) { + // 创建验证器实例 + validate := validator.New() + + // 创建中文翻译器 + zh_ch := zh.New() + uni := ut.New(zh_ch, zh_ch) + trans, _ := uni.GetTranslator("zh") + + //注册一个函数,获取struct tag里自定义的label作为字段名 + validate.RegisterTagNameFunc(func(fld reflect.StructField) string { + name := fld.Tag.Get("label") + return name + }) + + // 注册中文翻译器到验证器 + _ = zh_translations.RegisterDefaultTranslations(validate, trans) + + // 验证结构体 + err = validate.Struct(s) + if err != nil { + // 处理验证错误 + if _, ok := err.(*validator.InvalidValidationError); ok { + fmt.Println("处理验证错误error:", err) + errMsg = append(errMsg, err.Error()) + } else { + for _, v := range err.(validator.ValidationErrors) { + errMsg = append(errMsg, v.Translate(trans)) + } + } + } + + return +} diff --git a/internal/server/router/router.go b/internal/server/router/router.go index 4f5579c..e2645bb 100644 --- a/internal/server/router/router.go +++ b/internal/server/router/router.go @@ -82,7 +82,8 @@ func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionServi }) // 会话历史 - r.Post("/chat/history", chatHist.GetHistory) + r.Post("/chat/history/list", chatHist.List) + r.Post("/chat/history/update/content", chatHist.UpdateContent) } func routerSocket(app *fiber.App, chatService *services.ChatService) { diff --git a/internal/services/chat_history.go b/internal/services/chat_history.go index f49d238..7bd4d75 100644 --- a/internal/services/chat_history.go +++ b/internal/services/chat_history.go @@ -4,7 +4,10 @@ import ( "ai_scheduler/internal/biz" errors "ai_scheduler/internal/data/error" "ai_scheduler/internal/entitys" + "ai_scheduler/internal/pkg/validate" "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/log" + "strings" ) type HistoryService struct { @@ -18,7 +21,7 @@ func NewHistoryService(chatRepo *biz.ChatHistoryBiz) *HistoryService { } // GetHistoryService 获取会话历史 -func (h *HistoryService) GetHistory(c *fiber.Ctx) error { +func (h *HistoryService) List(c *fiber.Ctx) error { var query entitys.ChatHistQuery if err := c.BodyParser(&query); err != nil { return err @@ -42,3 +45,19 @@ func (h *HistoryService) GetHistory(c *fiber.Ctx) error { return c.JSON(history) } + +func (h *HistoryService) UpdateContent(c *fiber.Ctx) error { + var req entitys.UpdateContentRequest + if err := c.BodyParser(&req); err != nil { + return err + } + // 校验参数 + msg, err := validate.Struct(req) + if err != nil { + log.Error(c.UserContext(), "参数错误 error: ", err) + return errors.NewBusinessErr(errors.InvalidParamCode, strings.Join(msg, ";")) + } + + // 更新历史 + return h.chatRepo.UpdateContent(c.Context(), &req) +}