feat: 会话历史

This commit is contained in:
wolter 2025-12-05 17:59:27 +08:00
parent 0f9ac06f73
commit 057cf707d3
15 changed files with 133 additions and 58 deletions

View File

@ -10,31 +10,45 @@ import (
) )
type ChatHistoryBiz struct { type ChatHistoryBiz struct {
chatRepo *impl.ChatImpl chatHiRepo *impl.ChatHisImpl
} }
func NewChatHistoryBiz(chatRepo *impl.ChatImpl) *ChatHistoryBiz { func NewChatHistoryBiz(chatHiRepo *impl.ChatHisImpl) *ChatHistoryBiz {
s := &ChatHistoryBiz{ s := &ChatHistoryBiz{
chatRepo: chatRepo, chatHiRepo: chatHiRepo,
} }
//go s.AsyncProcess(context.Background()) //go s.AsyncProcess(context.Background())
return s return s
} }
//func (s *ChatHistoryBiz) create(ctx context.Context, sessionID, role, content string) error { // 查询会话历史
// chat := model.AiChatHi{ func (s *ChatHistoryBiz) List(ctx context.Context, query *entitys.ChatHistQuery) ([]model.AiChatHi, error) {
// SessionID: sessionID, chats, err := s.chatHiRepo.FindAll(
// Role: role, s.chatHiRepo.WithSessionId(query.SessionID),
// Content: content, s.chatHiRepo.PaginateScope(query.Page, query.PageSize),
// } )
// if err != nil {
// return s.chatRepo.Create(&chat) return nil, err
//} }
// return chats, nil
//// 添加会话历史 }
//func (s *ChatHistoryBiz) Create(ctx context.Context, chat entitys.ChatHistory) error {
// return s.create(ctx, chat.SessionID, chat.Role.String(), chat.Content) // 添加会话历史
//} func (s *ChatHistoryBiz) Create(ctx context.Context, chat entitys.ChatHistory) error {
return s.chatHiRepo.Create(&model.AiChatHi{
SessionID: chat.SessionID,
Ques: chat.Role.String(),
Ans: chat.Content,
})
}
// 更新会话历史内容
func (s *ChatHistoryBiz) UpdateContent(ctx context.Context, chat *entitys.UseFulRequest) error {
cond := builder.NewCond()
cond = cond.And(builder.Eq{"his_id": chat.HisId})
return s.chatHiRepo.UpdateByCond(&cond, &model.AiChatHi{HisID: chat.HisId, Useful: chat.Useful})
}
// 异步添加会话历史 // 异步添加会话历史
//func (s *ChatHistoryBiz) AsyncCreate(ctx context.Context, chat entitys.ChatHistory) { //func (s *ChatHistoryBiz) AsyncCreate(ctx context.Context, chat entitys.ChatHistory) {
@ -53,5 +67,5 @@ func NewChatHistoryBiz(chatRepo *impl.ChatImpl) *ChatHistoryBiz {
func (s *ChatHistoryBiz) Update(ctx context.Context, chat *entitys.UseFulRequest) error { func (s *ChatHistoryBiz) Update(ctx context.Context, chat *entitys.UseFulRequest) error {
cond := builder.NewCond() cond := builder.NewCond()
cond = cond.And(builder.Eq{"his_id": chat.HisId}) cond = cond.And(builder.Eq{"his_id": chat.HisId})
return s.chatRepo.UpdateByCond(&cond, &model.AiChatHi{HisID: chat.HisId, Useful: chat.Useful}) return s.chatHiRepo.UpdateByCond(&cond, &model.AiChatHi{HisID: chat.HisId, Useful: chat.Useful})
} }

View File

@ -27,14 +27,14 @@ type Do struct {
sessionImpl *impl.SessionImpl sessionImpl *impl.SessionImpl
sysImpl *impl.SysImpl sysImpl *impl.SysImpl
taskImpl *impl.TaskImpl taskImpl *impl.TaskImpl
hisImpl *impl.ChatImpl hisImpl *impl.ChatHisImpl
conf *config.Config conf *config.Config
} }
func NewDo( func NewDo(
sysImpl *impl.SysImpl, sysImpl *impl.SysImpl,
taskImpl *impl.TaskImpl, taskImpl *impl.TaskImpl,
hisImpl *impl.ChatImpl, hisImpl *impl.ChatHisImpl,
conf *config.Config, conf *config.Config,
) *Do { ) *Do {
return &Do{ return &Do{
@ -252,6 +252,7 @@ func (d *Do) startMessageHandler(
Ques: requireData.Req.Text, Ques: requireData.Req.Text,
Ans: strings.Join(chat, ""), Ans: strings.Join(chat, ""),
Files: requireData.Req.Img, Files: requireData.Req.Img,
TaskID: requireData.Task.TaskID,
} }
d.hisImpl.AddWithData(AiRes) d.hisImpl.AddWithData(AiRes)
hisLog.HisId = AiRes.HisID hisLog.HisId = AiRes.HisID

View File

@ -85,6 +85,7 @@ func (r *Handle) HandleMatch(ctx context.Context, client *gateway.Client, requir
for _, task := range requireData.Tasks { for _, task := range requireData.Tasks {
if task.Index == requireData.Match.Index { if task.Index == requireData.Match.Index {
pointTask = &task pointTask = &task
requireData.Task = task
break break
} }
} }

View File

@ -20,13 +20,13 @@ import (
type OllamaService struct { type OllamaService struct {
client *utils_ollama.Client client *utils_ollama.Client
config *config.Config config *config.Config
chatHis *impl.ChatImpl chatHis *impl.ChatHisImpl
} }
func NewOllamaGenerate( func NewOllamaGenerate(
client *utils_ollama.Client, client *utils_ollama.Client,
config *config.Config, config *config.Config,
chatHis *impl.ChatImpl, chatHis *impl.ChatHisImpl,
) *OllamaService { ) *OllamaService {
return &OllamaService{ return &OllamaService{
client: client, client: client,

View File

@ -17,16 +17,16 @@ import (
type SessionBiz struct { type SessionBiz struct {
sessionRepo *impl.SessionImpl sessionRepo *impl.SessionImpl
sysRepo *impl.SysImpl sysRepo *impl.SysImpl
chatRepo *impl.ChatImpl chatHisRepo *impl.ChatHisImpl
conf *config.Config conf *config.Config
} }
func NewSessionBiz(conf *config.Config, sessionImpl *impl.SessionImpl, sysImpl *impl.SysImpl, chatImpl *impl.ChatImpl) *SessionBiz { func NewSessionBiz(conf *config.Config, sessionImpl *impl.SessionImpl, sysImpl *impl.SysImpl, chatImpl *impl.ChatHisImpl) *SessionBiz {
return &SessionBiz{ return &SessionBiz{
sessionRepo: sessionImpl, sessionRepo: sessionImpl,
sysRepo: sysImpl, sysRepo: sysImpl,
chatRepo: chatImpl, chatHisRepo: chatImpl,
conf: conf, conf: conf,
} }
} }
@ -91,10 +91,10 @@ func (s *SessionBiz) SessionInit(ctx context.Context, req *entitys.SessionInitRe
result.Prologue = sysConfig.Prologue result.Prologue = sysConfig.Prologue
// 存在,返回会话历史 // 存在,返回会话历史
var chatList []model.AiChatHi var chatList []model.AiChatHi
chatList, err = s.chatRepo.FindAll( chatList, err = s.chatHisRepo.FindAll(
s.chatRepo.WithSessionId(session.SessionID), // 条件会话ID s.chatHisRepo.WithSessionId(session.SessionID), // 条件会话ID
s.chatRepo.OrderByDesc("create_at"), // 排序:按创建时间降序 s.chatHisRepo.OrderByDesc("create_at"), // 排序:按创建时间降序
s.chatRepo.WithLimit(constants.ChatHistoryLimit), // 限制返回条数 s.chatHisRepo.WithLimit(constants.ChatHistoryLimit), // 限制返回条数
) )
if err != nil { if err != nil {
return return

View File

@ -11,14 +11,14 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
) )
type ChatImpl struct { type ChatHisImpl struct {
dataTemp.DataTemp dataTemp.DataTemp
BaseRepository[model.AiChatHi] BaseRepository[model.AiChatHi]
chatChannel chan model.AiChatHi chatChannel chan model.AiChatHi
} }
func NewChatImpl(db *utils.Db) *ChatImpl { func NewChatHisImpl(db *utils.Db) *ChatHisImpl {
return &ChatImpl{ return &ChatHisImpl{
DataTemp: *dataTemp.NewDataTemp(db, new(model.AiChatHi)), DataTemp: *dataTemp.NewDataTemp(db, new(model.AiChatHi)),
BaseRepository: NewBaseModel[model.AiChatHi](db.Client), BaseRepository: NewBaseModel[model.AiChatHi](db.Client),
chatChannel: make(chan model.AiChatHi, 100), chatChannel: make(chan model.AiChatHi, 100),
@ -26,19 +26,19 @@ func NewChatImpl(db *utils.Db) *ChatImpl {
} }
// WithSessionId 条件会话ID // WithSessionId 条件会话ID
func (impl *ChatImpl) WithSessionId(sessionId interface{}) CondFunc { func (impl *ChatHisImpl) WithSessionId(sessionId interface{}) CondFunc {
return func(db *gorm.DB) *gorm.DB { return func(db *gorm.DB) *gorm.DB {
return db.Where("session_id = ?", sessionId) return db.Where("session_id = ?", sessionId)
} }
} }
// 异步添加会话历史 // 异步添加会话历史
func (impl *ChatImpl) AsyncCreate(ctx context.Context, chat model.AiChatHi) { func (impl *ChatHisImpl) AsyncCreate(ctx context.Context, chat model.AiChatHi) {
impl.chatChannel <- chat impl.chatChannel <- chat
} }
// 异步处理会话历史 // 异步处理会话历史
func (impl *ChatImpl) AsyncProcess(ctx context.Context) { func (impl *ChatHisImpl) AsyncProcess(ctx context.Context) {
for { for {
select { select {
case chat := <-impl.chatChannel: case chat := <-impl.chatChannel:

View File

@ -4,4 +4,4 @@ import (
"github.com/google/wire" "github.com/google/wire"
) )
var ProviderImpl = wire.NewSet(NewSessionImpl, NewSysImpl, NewTaskImpl, NewChatImpl) var ProviderImpl = wire.NewSet(NewSessionImpl, NewSysImpl, NewTaskImpl, NewChatHisImpl)

View File

@ -20,6 +20,7 @@ type AiChatHi struct {
Useful int32 `gorm:"column:useful;not null;comment:0不评价,1有用其他为无用" json:"useful"` // 0不评价,1有用其他为无用 Useful int32 `gorm:"column:useful;not null;comment:0不评价,1有用其他为无用" json:"useful"` // 0不评价,1有用其他为无用
CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"` CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"`
UpdatedAt time.Time `gorm:"column:updated_at;default:CURRENT_TIMESTAMP" json:"updated_at"` UpdatedAt time.Time `gorm:"column:updated_at;default:CURRENT_TIMESTAMP" json:"updated_at"`
TaskID int32 `gorm:"column:task_id;not null" json:"task_id"` // 任务ID
} }
// TableName AiChatHi's table name // TableName AiChatHi's table name

View File

@ -14,3 +14,9 @@ type ChatHistory struct {
type ChatHisLog struct { type ChatHisLog struct {
HisId int64 `json:"his_id"` HisId int64 `json:"his_id"`
} }
type ChatHistQuery struct {
SessionID string `json:"session_id"`
Page int `json:"page"`
PageSize int `json:"page_size"`
}

View File

@ -150,6 +150,7 @@ type RequireData struct {
Histories []model.AiChatHi Histories []model.AiChatHi
SessionInfo model.AiSession SessionInfo model.AiSession
Tasks []model.AiTask Tasks []model.AiTask
Task model.AiTask
Match *Match Match *Match
Req *ChatSockRequest Req *ChatSockRequest
Auth string Auth string

View File

@ -11,24 +11,25 @@ import (
) )
type HTTPServer struct { type HTTPServer struct {
app *fiber.App app *fiber.App
service *services.ChatService service *services.ChatService
session *services.SessionService session *services.SessionService
gateway *gateway.Gateway gateway *gateway.Gateway
callback *services.CallbackService callback *services.CallbackService
} }
func NewHTTPServer( func NewHTTPServer(
service *services.ChatService, service *services.ChatService,
session *services.SessionService, session *services.SessionService,
task *services.TaskService, task *services.TaskService,
gateway *gateway.Gateway, gateway *gateway.Gateway,
callback *services.CallbackService, callback *services.CallbackService,
chatHis *services.HistoryService,
) *fiber.App { ) *fiber.App {
//构建 server //构建 server
app := initRoute() app := initRoute()
router.SetupRoutes(app, service, session, task, gateway, callback) router.SetupRoutes(app, service, session, task, gateway, callback, chatHis)
return app return app
} }
func initRoute() *fiber.App { func initRoute() *fiber.App {

View File

@ -15,14 +15,17 @@ import (
) )
type RouterServer struct { type RouterServer struct {
app *fiber.App app *fiber.App
service *services.ChatService service *services.ChatService
session *services.SessionService session *services.SessionService
gateway *gateway.Gateway gateway *gateway.Gateway
chatHist *services.HistoryService
} }
// SetupRoutes 设置路由 // SetupRoutes 设置路由
func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionService *services.SessionService, task *services.TaskService, gateway *gateway.Gateway, callbackService *services.CallbackService) { func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionService *services.SessionService, task *services.TaskService,
gateway *gateway.Gateway, callbackService *services.CallbackService, chatHist *services.HistoryService,
) {
app.Use(func(c *fiber.Ctx) error { app.Use(func(c *fiber.Ctx) error {
// 设置 CORS 头 // 设置 CORS 头
c.Set("Access-Control-Allow-Origin", "*") c.Set("Access-Control-Allow-Origin", "*")
@ -77,6 +80,9 @@ func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionServi
return ctx.Status(400).SendString("unknown action") return ctx.Status(400).SendString("unknown action")
} }
}) })
// 会话历史
r.Post("/chat/history", chatHist.GetHistory)
} }
func routerSocket(app *fiber.App, chatService *services.ChatService) { func routerSocket(app *fiber.App, chatService *services.ChatService) {

View File

@ -73,9 +73,9 @@ func (h *ChatService) Chat(c *websocket.Conn) {
var wg sync.WaitGroup var wg sync.WaitGroup
// 确保在函数返回时移除客户端并关闭连接 // 确保在函数返回时移除客户端并关闭连接
defer func() { defer func() {
h.Gw.Cleanup(client.GetID())
close(semaphore) // 关闭信号量通道
wg.Wait() // 等待所有消息处理goroutine完成 wg.Wait() // 等待所有消息处理goroutine完成
close(semaphore) // 关闭信号量通道
h.Gw.Cleanup(client.GetID())
}() }()
// 绑定会话ID, sessionId 为空时, 则不绑定 // 绑定会话ID, sessionId 为空时, 则不绑定

View File

@ -0,0 +1,44 @@
package services
import (
"ai_scheduler/internal/biz"
errors "ai_scheduler/internal/data/error"
"ai_scheduler/internal/entitys"
"github.com/gofiber/fiber/v2"
)
type HistoryService struct {
chatRepo *biz.ChatHistoryBiz
}
func NewHistoryService(chatRepo *biz.ChatHistoryBiz) *HistoryService {
return &HistoryService{
chatRepo: chatRepo,
}
}
// GetHistoryService 获取会话历史
func (h *HistoryService) GetHistory(c *fiber.Ctx) error {
var query entitys.ChatHistQuery
if err := c.BodyParser(&query); err != nil {
return err
}
// 校验参数
if query.SessionID == "" {
return errors.SessionNotFound
}
if query.Page <= 0 {
query.Page = 1
}
if query.PageSize <= 0 {
query.PageSize = 10
}
// 查询历史
history, err := h.chatRepo.List(c.Context(), &query)
if err != nil {
return err
}
return c.JSON(history)
}

View File

@ -6,4 +6,4 @@ import (
"github.com/google/wire" "github.com/google/wire"
) )
var ProviderSetServices = wire.NewSet(NewChatService, NewSessionService, gateway.NewGateway, NewTaskService, NewCallbackService) var ProviderSetServices = wire.NewSet(NewChatService, NewSessionService, gateway.NewGateway, NewTaskService, NewCallbackService, NewHistoryService)