From 057cf707d30f062a010111866d39dadfce71b33d Mon Sep 17 00:00:00 2001 From: wolter <11@gmail> Date: Fri, 5 Dec 2025 17:59:27 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E4=BC=9A=E8=AF=9D=E5=8E=86=E5=8F=B2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/biz/chat_history.go | 50 ++++++++++++++++---------- internal/biz/do/ctx.go | 5 +-- internal/biz/do/handle.go | 1 + internal/biz/llm_service/ollama.go | 4 +-- internal/biz/session.go | 14 ++++---- internal/data/impl/chat_history.go | 12 +++---- internal/data/impl/provider_set.go | 2 +- internal/data/model/ai_chat_his.gen.go | 1 + internal/entitys/chat_history.go | 6 ++++ internal/entitys/types.go | 1 + internal/server/http.go | 29 +++++++-------- internal/server/router/router.go | 16 ++++++--- internal/services/chat.go | 4 +-- internal/services/chat_history.go | 44 +++++++++++++++++++++++ internal/services/provider_set.go | 2 +- 15 files changed, 133 insertions(+), 58 deletions(-) create mode 100644 internal/services/chat_history.go diff --git a/internal/biz/chat_history.go b/internal/biz/chat_history.go index 48dcd0d..eeb74e5 100644 --- a/internal/biz/chat_history.go +++ b/internal/biz/chat_history.go @@ -10,31 +10,45 @@ import ( ) type ChatHistoryBiz struct { - chatRepo *impl.ChatImpl + chatHiRepo *impl.ChatHisImpl } -func NewChatHistoryBiz(chatRepo *impl.ChatImpl) *ChatHistoryBiz { +func NewChatHistoryBiz(chatHiRepo *impl.ChatHisImpl) *ChatHistoryBiz { s := &ChatHistoryBiz{ - chatRepo: chatRepo, + chatHiRepo: chatHiRepo, } //go s.AsyncProcess(context.Background()) return s } -//func (s *ChatHistoryBiz) create(ctx context.Context, sessionID, role, content string) error { -// chat := model.AiChatHi{ -// SessionID: sessionID, -// Role: role, -// Content: content, -// } -// -// return s.chatRepo.Create(&chat) -//} -// -//// 添加会话历史 -//func (s *ChatHistoryBiz) Create(ctx context.Context, chat entitys.ChatHistory) error { -// return s.create(ctx, chat.SessionID, chat.Role.String(), chat.Content) -//} +// 查询会话历史 +func (s *ChatHistoryBiz) List(ctx context.Context, query *entitys.ChatHistQuery) ([]model.AiChatHi, error) { + chats, err := s.chatHiRepo.FindAll( + s.chatHiRepo.WithSessionId(query.SessionID), + s.chatHiRepo.PaginateScope(query.Page, query.PageSize), + ) + if err != nil { + return nil, err + } + return chats, nil +} + +// 添加会话历史 +func (s *ChatHistoryBiz) Create(ctx context.Context, chat entitys.ChatHistory) error { + return s.chatHiRepo.Create(&model.AiChatHi{ + SessionID: chat.SessionID, + Ques: chat.Role.String(), + Ans: chat.Content, + }) +} + +// 更新会话历史内容 +func (s *ChatHistoryBiz) UpdateContent(ctx context.Context, chat *entitys.UseFulRequest) error { + cond := builder.NewCond() + cond = cond.And(builder.Eq{"his_id": chat.HisId}) + + return s.chatHiRepo.UpdateByCond(&cond, &model.AiChatHi{HisID: chat.HisId, Useful: chat.Useful}) +} // 异步添加会话历史 //func (s *ChatHistoryBiz) AsyncCreate(ctx context.Context, chat entitys.ChatHistory) { @@ -53,5 +67,5 @@ func NewChatHistoryBiz(chatRepo *impl.ChatImpl) *ChatHistoryBiz { func (s *ChatHistoryBiz) Update(ctx context.Context, chat *entitys.UseFulRequest) error { cond := builder.NewCond() cond = cond.And(builder.Eq{"his_id": chat.HisId}) - return s.chatRepo.UpdateByCond(&cond, &model.AiChatHi{HisID: chat.HisId, Useful: chat.Useful}) + return s.chatHiRepo.UpdateByCond(&cond, &model.AiChatHi{HisID: chat.HisId, Useful: chat.Useful}) } diff --git a/internal/biz/do/ctx.go b/internal/biz/do/ctx.go index eb2c5a5..43f1e1a 100644 --- a/internal/biz/do/ctx.go +++ b/internal/biz/do/ctx.go @@ -27,14 +27,14 @@ type Do struct { sessionImpl *impl.SessionImpl sysImpl *impl.SysImpl taskImpl *impl.TaskImpl - hisImpl *impl.ChatImpl + hisImpl *impl.ChatHisImpl conf *config.Config } func NewDo( sysImpl *impl.SysImpl, taskImpl *impl.TaskImpl, - hisImpl *impl.ChatImpl, + hisImpl *impl.ChatHisImpl, conf *config.Config, ) *Do { return &Do{ @@ -252,6 +252,7 @@ func (d *Do) startMessageHandler( Ques: requireData.Req.Text, Ans: strings.Join(chat, ""), Files: requireData.Req.Img, + TaskID: requireData.Task.TaskID, } d.hisImpl.AddWithData(AiRes) hisLog.HisId = AiRes.HisID diff --git a/internal/biz/do/handle.go b/internal/biz/do/handle.go index e014d25..14695e6 100644 --- a/internal/biz/do/handle.go +++ b/internal/biz/do/handle.go @@ -85,6 +85,7 @@ func (r *Handle) HandleMatch(ctx context.Context, client *gateway.Client, requir for _, task := range requireData.Tasks { if task.Index == requireData.Match.Index { pointTask = &task + requireData.Task = task break } } diff --git a/internal/biz/llm_service/ollama.go b/internal/biz/llm_service/ollama.go index 6527a3b..56887a2 100644 --- a/internal/biz/llm_service/ollama.go +++ b/internal/biz/llm_service/ollama.go @@ -20,13 +20,13 @@ import ( type OllamaService struct { client *utils_ollama.Client config *config.Config - chatHis *impl.ChatImpl + chatHis *impl.ChatHisImpl } func NewOllamaGenerate( client *utils_ollama.Client, config *config.Config, - chatHis *impl.ChatImpl, + chatHis *impl.ChatHisImpl, ) *OllamaService { return &OllamaService{ client: client, diff --git a/internal/biz/session.go b/internal/biz/session.go index d014148..a42e9c0 100644 --- a/internal/biz/session.go +++ b/internal/biz/session.go @@ -17,16 +17,16 @@ import ( type SessionBiz struct { sessionRepo *impl.SessionImpl sysRepo *impl.SysImpl - chatRepo *impl.ChatImpl + chatHisRepo *impl.ChatHisImpl conf *config.Config } -func NewSessionBiz(conf *config.Config, sessionImpl *impl.SessionImpl, sysImpl *impl.SysImpl, chatImpl *impl.ChatImpl) *SessionBiz { +func NewSessionBiz(conf *config.Config, sessionImpl *impl.SessionImpl, sysImpl *impl.SysImpl, chatImpl *impl.ChatHisImpl) *SessionBiz { return &SessionBiz{ sessionRepo: sessionImpl, sysRepo: sysImpl, - chatRepo: chatImpl, + chatHisRepo: chatImpl, conf: conf, } } @@ -91,10 +91,10 @@ func (s *SessionBiz) SessionInit(ctx context.Context, req *entitys.SessionInitRe result.Prologue = sysConfig.Prologue // 存在,返回会话历史 var chatList []model.AiChatHi - chatList, err = s.chatRepo.FindAll( - s.chatRepo.WithSessionId(session.SessionID), // 条件:会话ID - s.chatRepo.OrderByDesc("create_at"), // 排序:按创建时间降序 - s.chatRepo.WithLimit(constants.ChatHistoryLimit), // 限制返回条数 + chatList, err = s.chatHisRepo.FindAll( + s.chatHisRepo.WithSessionId(session.SessionID), // 条件:会话ID + s.chatHisRepo.OrderByDesc("create_at"), // 排序:按创建时间降序 + s.chatHisRepo.WithLimit(constants.ChatHistoryLimit), // 限制返回条数 ) if err != nil { return diff --git a/internal/data/impl/chat_history.go b/internal/data/impl/chat_history.go index 6f6027d..f1c0b42 100644 --- a/internal/data/impl/chat_history.go +++ b/internal/data/impl/chat_history.go @@ -11,14 +11,14 @@ import ( "gorm.io/gorm" ) -type ChatImpl struct { +type ChatHisImpl struct { dataTemp.DataTemp BaseRepository[model.AiChatHi] chatChannel chan model.AiChatHi } -func NewChatImpl(db *utils.Db) *ChatImpl { - return &ChatImpl{ +func NewChatHisImpl(db *utils.Db) *ChatHisImpl { + return &ChatHisImpl{ DataTemp: *dataTemp.NewDataTemp(db, new(model.AiChatHi)), BaseRepository: NewBaseModel[model.AiChatHi](db.Client), chatChannel: make(chan model.AiChatHi, 100), @@ -26,19 +26,19 @@ func NewChatImpl(db *utils.Db) *ChatImpl { } // WithSessionId 条件:会话ID -func (impl *ChatImpl) WithSessionId(sessionId interface{}) CondFunc { +func (impl *ChatHisImpl) WithSessionId(sessionId interface{}) CondFunc { return func(db *gorm.DB) *gorm.DB { return db.Where("session_id = ?", sessionId) } } // 异步添加会话历史 -func (impl *ChatImpl) AsyncCreate(ctx context.Context, chat model.AiChatHi) { +func (impl *ChatHisImpl) AsyncCreate(ctx context.Context, chat model.AiChatHi) { impl.chatChannel <- chat } // 异步处理会话历史 -func (impl *ChatImpl) AsyncProcess(ctx context.Context) { +func (impl *ChatHisImpl) AsyncProcess(ctx context.Context) { for { select { case chat := <-impl.chatChannel: diff --git a/internal/data/impl/provider_set.go b/internal/data/impl/provider_set.go index f970e84..1284f11 100644 --- a/internal/data/impl/provider_set.go +++ b/internal/data/impl/provider_set.go @@ -4,4 +4,4 @@ import ( "github.com/google/wire" ) -var ProviderImpl = wire.NewSet(NewSessionImpl, NewSysImpl, NewTaskImpl, NewChatImpl) +var ProviderImpl = wire.NewSet(NewSessionImpl, NewSysImpl, NewTaskImpl, NewChatHisImpl) diff --git a/internal/data/model/ai_chat_his.gen.go b/internal/data/model/ai_chat_his.gen.go index d143972..e3b5b58 100644 --- a/internal/data/model/ai_chat_his.gen.go +++ b/internal/data/model/ai_chat_his.gen.go @@ -20,6 +20,7 @@ type AiChatHi struct { Useful int32 `gorm:"column:useful;not null;comment:0不评价,1有用,其他为无用" json:"useful"` // 0不评价,1有用,其他为无用 CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"` UpdatedAt time.Time `gorm:"column:updated_at;default:CURRENT_TIMESTAMP" json:"updated_at"` + TaskID int32 `gorm:"column:task_id;not null" json:"task_id"` // 任务ID } // TableName AiChatHi's table name diff --git a/internal/entitys/chat_history.go b/internal/entitys/chat_history.go index a26fa46..34e6c5e 100644 --- a/internal/entitys/chat_history.go +++ b/internal/entitys/chat_history.go @@ -14,3 +14,9 @@ type ChatHistory struct { type ChatHisLog struct { HisId int64 `json:"his_id"` } + +type ChatHistQuery struct { + SessionID string `json:"session_id"` + Page int `json:"page"` + PageSize int `json:"page_size"` +} diff --git a/internal/entitys/types.go b/internal/entitys/types.go index ece5ceb..a9dc72e 100644 --- a/internal/entitys/types.go +++ b/internal/entitys/types.go @@ -150,6 +150,7 @@ type RequireData struct { Histories []model.AiChatHi SessionInfo model.AiSession Tasks []model.AiTask + Task model.AiTask Match *Match Req *ChatSockRequest Auth string diff --git a/internal/server/http.go b/internal/server/http.go index 4cdc393..fd7e49e 100644 --- a/internal/server/http.go +++ b/internal/server/http.go @@ -11,24 +11,25 @@ import ( ) type HTTPServer struct { - app *fiber.App - service *services.ChatService - session *services.SessionService - gateway *gateway.Gateway - callback *services.CallbackService + app *fiber.App + service *services.ChatService + session *services.SessionService + gateway *gateway.Gateway + callback *services.CallbackService } func NewHTTPServer( - service *services.ChatService, - session *services.SessionService, - task *services.TaskService, - gateway *gateway.Gateway, - callback *services.CallbackService, + service *services.ChatService, + session *services.SessionService, + task *services.TaskService, + gateway *gateway.Gateway, + callback *services.CallbackService, + chatHis *services.HistoryService, ) *fiber.App { - //构建 server - app := initRoute() - router.SetupRoutes(app, service, session, task, gateway, callback) - return app + //构建 server + app := initRoute() + router.SetupRoutes(app, service, session, task, gateway, callback, chatHis) + return app } func initRoute() *fiber.App { diff --git a/internal/server/router/router.go b/internal/server/router/router.go index 1a6fd8b..4f5579c 100644 --- a/internal/server/router/router.go +++ b/internal/server/router/router.go @@ -15,14 +15,17 @@ import ( ) type RouterServer struct { - app *fiber.App - service *services.ChatService - session *services.SessionService - gateway *gateway.Gateway + app *fiber.App + service *services.ChatService + session *services.SessionService + gateway *gateway.Gateway + chatHist *services.HistoryService } // SetupRoutes 设置路由 -func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionService *services.SessionService, task *services.TaskService, gateway *gateway.Gateway, callbackService *services.CallbackService) { +func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionService *services.SessionService, task *services.TaskService, + gateway *gateway.Gateway, callbackService *services.CallbackService, chatHist *services.HistoryService, +) { app.Use(func(c *fiber.Ctx) error { // 设置 CORS 头 c.Set("Access-Control-Allow-Origin", "*") @@ -77,6 +80,9 @@ func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionServi return ctx.Status(400).SendString("unknown action") } }) + + // 会话历史 + r.Post("/chat/history", chatHist.GetHistory) } func routerSocket(app *fiber.App, chatService *services.ChatService) { diff --git a/internal/services/chat.go b/internal/services/chat.go index 11dda38..eb47809 100644 --- a/internal/services/chat.go +++ b/internal/services/chat.go @@ -73,9 +73,9 @@ func (h *ChatService) Chat(c *websocket.Conn) { var wg sync.WaitGroup // 确保在函数返回时移除客户端并关闭连接 defer func() { - h.Gw.Cleanup(client.GetID()) - close(semaphore) // 关闭信号量通道 wg.Wait() // 等待所有消息处理goroutine完成 + close(semaphore) // 关闭信号量通道 + h.Gw.Cleanup(client.GetID()) }() // 绑定会话ID, sessionId 为空时, 则不绑定 diff --git a/internal/services/chat_history.go b/internal/services/chat_history.go new file mode 100644 index 0000000..f49d238 --- /dev/null +++ b/internal/services/chat_history.go @@ -0,0 +1,44 @@ +package services + +import ( + "ai_scheduler/internal/biz" + errors "ai_scheduler/internal/data/error" + "ai_scheduler/internal/entitys" + "github.com/gofiber/fiber/v2" +) + +type HistoryService struct { + chatRepo *biz.ChatHistoryBiz +} + +func NewHistoryService(chatRepo *biz.ChatHistoryBiz) *HistoryService { + return &HistoryService{ + chatRepo: chatRepo, + } +} + +// GetHistoryService 获取会话历史 +func (h *HistoryService) GetHistory(c *fiber.Ctx) error { + var query entitys.ChatHistQuery + if err := c.BodyParser(&query); err != nil { + return err + } + // 校验参数 + if query.SessionID == "" { + return errors.SessionNotFound + } + if query.Page <= 0 { + query.Page = 1 + } + if query.PageSize <= 0 { + query.PageSize = 10 + } + + // 查询历史 + history, err := h.chatRepo.List(c.Context(), &query) + if err != nil { + return err + } + + return c.JSON(history) +} diff --git a/internal/services/provider_set.go b/internal/services/provider_set.go index 0e1284b..668c7fe 100644 --- a/internal/services/provider_set.go +++ b/internal/services/provider_set.go @@ -6,4 +6,4 @@ import ( "github.com/google/wire" ) -var ProviderSetServices = wire.NewSet(NewChatService, NewSessionService, gateway.NewGateway, NewTaskService, NewCallbackService) +var ProviderSetServices = wire.NewSet(NewChatService, NewSessionService, gateway.NewGateway, NewTaskService, NewCallbackService, NewHistoryService)