feat: 会话历史
This commit is contained in:
parent
0f9ac06f73
commit
057cf707d3
|
|
@ -10,31 +10,45 @@ import (
|
|||
)
|
||||
|
||||
type ChatHistoryBiz struct {
|
||||
chatRepo *impl.ChatImpl
|
||||
chatHiRepo *impl.ChatHisImpl
|
||||
}
|
||||
|
||||
func NewChatHistoryBiz(chatRepo *impl.ChatImpl) *ChatHistoryBiz {
|
||||
func NewChatHistoryBiz(chatHiRepo *impl.ChatHisImpl) *ChatHistoryBiz {
|
||||
s := &ChatHistoryBiz{
|
||||
chatRepo: chatRepo,
|
||||
chatHiRepo: chatHiRepo,
|
||||
}
|
||||
//go s.AsyncProcess(context.Background())
|
||||
return s
|
||||
}
|
||||
|
||||
//func (s *ChatHistoryBiz) create(ctx context.Context, sessionID, role, content string) error {
|
||||
// chat := model.AiChatHi{
|
||||
// SessionID: sessionID,
|
||||
// Role: role,
|
||||
// Content: content,
|
||||
// }
|
||||
//
|
||||
// return s.chatRepo.Create(&chat)
|
||||
//}
|
||||
//
|
||||
//// 添加会话历史
|
||||
//func (s *ChatHistoryBiz) Create(ctx context.Context, chat entitys.ChatHistory) error {
|
||||
// return s.create(ctx, chat.SessionID, chat.Role.String(), chat.Content)
|
||||
//}
|
||||
// 查询会话历史
|
||||
func (s *ChatHistoryBiz) List(ctx context.Context, query *entitys.ChatHistQuery) ([]model.AiChatHi, error) {
|
||||
chats, err := s.chatHiRepo.FindAll(
|
||||
s.chatHiRepo.WithSessionId(query.SessionID),
|
||||
s.chatHiRepo.PaginateScope(query.Page, query.PageSize),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return chats, nil
|
||||
}
|
||||
|
||||
// 添加会话历史
|
||||
func (s *ChatHistoryBiz) Create(ctx context.Context, chat entitys.ChatHistory) error {
|
||||
return s.chatHiRepo.Create(&model.AiChatHi{
|
||||
SessionID: chat.SessionID,
|
||||
Ques: chat.Role.String(),
|
||||
Ans: chat.Content,
|
||||
})
|
||||
}
|
||||
|
||||
// 更新会话历史内容
|
||||
func (s *ChatHistoryBiz) UpdateContent(ctx context.Context, chat *entitys.UseFulRequest) error {
|
||||
cond := builder.NewCond()
|
||||
cond = cond.And(builder.Eq{"his_id": chat.HisId})
|
||||
|
||||
return s.chatHiRepo.UpdateByCond(&cond, &model.AiChatHi{HisID: chat.HisId, Useful: chat.Useful})
|
||||
}
|
||||
|
||||
// 异步添加会话历史
|
||||
//func (s *ChatHistoryBiz) AsyncCreate(ctx context.Context, chat entitys.ChatHistory) {
|
||||
|
|
@ -53,5 +67,5 @@ func NewChatHistoryBiz(chatRepo *impl.ChatImpl) *ChatHistoryBiz {
|
|||
func (s *ChatHistoryBiz) Update(ctx context.Context, chat *entitys.UseFulRequest) error {
|
||||
cond := builder.NewCond()
|
||||
cond = cond.And(builder.Eq{"his_id": chat.HisId})
|
||||
return s.chatRepo.UpdateByCond(&cond, &model.AiChatHi{HisID: chat.HisId, Useful: chat.Useful})
|
||||
return s.chatHiRepo.UpdateByCond(&cond, &model.AiChatHi{HisID: chat.HisId, Useful: chat.Useful})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -27,14 +27,14 @@ type Do struct {
|
|||
sessionImpl *impl.SessionImpl
|
||||
sysImpl *impl.SysImpl
|
||||
taskImpl *impl.TaskImpl
|
||||
hisImpl *impl.ChatImpl
|
||||
hisImpl *impl.ChatHisImpl
|
||||
conf *config.Config
|
||||
}
|
||||
|
||||
func NewDo(
|
||||
sysImpl *impl.SysImpl,
|
||||
taskImpl *impl.TaskImpl,
|
||||
hisImpl *impl.ChatImpl,
|
||||
hisImpl *impl.ChatHisImpl,
|
||||
conf *config.Config,
|
||||
) *Do {
|
||||
return &Do{
|
||||
|
|
@ -252,6 +252,7 @@ func (d *Do) startMessageHandler(
|
|||
Ques: requireData.Req.Text,
|
||||
Ans: strings.Join(chat, ""),
|
||||
Files: requireData.Req.Img,
|
||||
TaskID: requireData.Task.TaskID,
|
||||
}
|
||||
d.hisImpl.AddWithData(AiRes)
|
||||
hisLog.HisId = AiRes.HisID
|
||||
|
|
|
|||
|
|
@ -85,6 +85,7 @@ func (r *Handle) HandleMatch(ctx context.Context, client *gateway.Client, requir
|
|||
for _, task := range requireData.Tasks {
|
||||
if task.Index == requireData.Match.Index {
|
||||
pointTask = &task
|
||||
requireData.Task = task
|
||||
break
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -20,13 +20,13 @@ import (
|
|||
type OllamaService struct {
|
||||
client *utils_ollama.Client
|
||||
config *config.Config
|
||||
chatHis *impl.ChatImpl
|
||||
chatHis *impl.ChatHisImpl
|
||||
}
|
||||
|
||||
func NewOllamaGenerate(
|
||||
client *utils_ollama.Client,
|
||||
config *config.Config,
|
||||
chatHis *impl.ChatImpl,
|
||||
chatHis *impl.ChatHisImpl,
|
||||
) *OllamaService {
|
||||
return &OllamaService{
|
||||
client: client,
|
||||
|
|
|
|||
|
|
@ -17,16 +17,16 @@ import (
|
|||
type SessionBiz struct {
|
||||
sessionRepo *impl.SessionImpl
|
||||
sysRepo *impl.SysImpl
|
||||
chatRepo *impl.ChatImpl
|
||||
chatHisRepo *impl.ChatHisImpl
|
||||
|
||||
conf *config.Config
|
||||
}
|
||||
|
||||
func NewSessionBiz(conf *config.Config, sessionImpl *impl.SessionImpl, sysImpl *impl.SysImpl, chatImpl *impl.ChatImpl) *SessionBiz {
|
||||
func NewSessionBiz(conf *config.Config, sessionImpl *impl.SessionImpl, sysImpl *impl.SysImpl, chatImpl *impl.ChatHisImpl) *SessionBiz {
|
||||
return &SessionBiz{
|
||||
sessionRepo: sessionImpl,
|
||||
sysRepo: sysImpl,
|
||||
chatRepo: chatImpl,
|
||||
chatHisRepo: chatImpl,
|
||||
conf: conf,
|
||||
}
|
||||
}
|
||||
|
|
@ -91,10 +91,10 @@ func (s *SessionBiz) SessionInit(ctx context.Context, req *entitys.SessionInitRe
|
|||
result.Prologue = sysConfig.Prologue
|
||||
// 存在,返回会话历史
|
||||
var chatList []model.AiChatHi
|
||||
chatList, err = s.chatRepo.FindAll(
|
||||
s.chatRepo.WithSessionId(session.SessionID), // 条件:会话ID
|
||||
s.chatRepo.OrderByDesc("create_at"), // 排序:按创建时间降序
|
||||
s.chatRepo.WithLimit(constants.ChatHistoryLimit), // 限制返回条数
|
||||
chatList, err = s.chatHisRepo.FindAll(
|
||||
s.chatHisRepo.WithSessionId(session.SessionID), // 条件:会话ID
|
||||
s.chatHisRepo.OrderByDesc("create_at"), // 排序:按创建时间降序
|
||||
s.chatHisRepo.WithLimit(constants.ChatHistoryLimit), // 限制返回条数
|
||||
)
|
||||
if err != nil {
|
||||
return
|
||||
|
|
|
|||
|
|
@ -11,14 +11,14 @@ import (
|
|||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type ChatImpl struct {
|
||||
type ChatHisImpl struct {
|
||||
dataTemp.DataTemp
|
||||
BaseRepository[model.AiChatHi]
|
||||
chatChannel chan model.AiChatHi
|
||||
}
|
||||
|
||||
func NewChatImpl(db *utils.Db) *ChatImpl {
|
||||
return &ChatImpl{
|
||||
func NewChatHisImpl(db *utils.Db) *ChatHisImpl {
|
||||
return &ChatHisImpl{
|
||||
DataTemp: *dataTemp.NewDataTemp(db, new(model.AiChatHi)),
|
||||
BaseRepository: NewBaseModel[model.AiChatHi](db.Client),
|
||||
chatChannel: make(chan model.AiChatHi, 100),
|
||||
|
|
@ -26,19 +26,19 @@ func NewChatImpl(db *utils.Db) *ChatImpl {
|
|||
}
|
||||
|
||||
// WithSessionId 条件:会话ID
|
||||
func (impl *ChatImpl) WithSessionId(sessionId interface{}) CondFunc {
|
||||
func (impl *ChatHisImpl) WithSessionId(sessionId interface{}) CondFunc {
|
||||
return func(db *gorm.DB) *gorm.DB {
|
||||
return db.Where("session_id = ?", sessionId)
|
||||
}
|
||||
}
|
||||
|
||||
// 异步添加会话历史
|
||||
func (impl *ChatImpl) AsyncCreate(ctx context.Context, chat model.AiChatHi) {
|
||||
func (impl *ChatHisImpl) AsyncCreate(ctx context.Context, chat model.AiChatHi) {
|
||||
impl.chatChannel <- chat
|
||||
}
|
||||
|
||||
// 异步处理会话历史
|
||||
func (impl *ChatImpl) AsyncProcess(ctx context.Context) {
|
||||
func (impl *ChatHisImpl) AsyncProcess(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
case chat := <-impl.chatChannel:
|
||||
|
|
|
|||
|
|
@ -4,4 +4,4 @@ import (
|
|||
"github.com/google/wire"
|
||||
)
|
||||
|
||||
var ProviderImpl = wire.NewSet(NewSessionImpl, NewSysImpl, NewTaskImpl, NewChatImpl)
|
||||
var ProviderImpl = wire.NewSet(NewSessionImpl, NewSysImpl, NewTaskImpl, NewChatHisImpl)
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ type AiChatHi struct {
|
|||
Useful int32 `gorm:"column:useful;not null;comment:0不评价,1有用,其他为无用" json:"useful"` // 0不评价,1有用,其他为无用
|
||||
CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;default:CURRENT_TIMESTAMP" json:"updated_at"`
|
||||
TaskID int32 `gorm:"column:task_id;not null" json:"task_id"` // 任务ID
|
||||
}
|
||||
|
||||
// TableName AiChatHi's table name
|
||||
|
|
|
|||
|
|
@ -14,3 +14,9 @@ type ChatHistory struct {
|
|||
type ChatHisLog struct {
|
||||
HisId int64 `json:"his_id"`
|
||||
}
|
||||
|
||||
type ChatHistQuery struct {
|
||||
SessionID string `json:"session_id"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
}
|
||||
|
|
|
|||
|
|
@ -150,6 +150,7 @@ type RequireData struct {
|
|||
Histories []model.AiChatHi
|
||||
SessionInfo model.AiSession
|
||||
Tasks []model.AiTask
|
||||
Task model.AiTask
|
||||
Match *Match
|
||||
Req *ChatSockRequest
|
||||
Auth string
|
||||
|
|
|
|||
|
|
@ -11,24 +11,25 @@ import (
|
|||
)
|
||||
|
||||
type HTTPServer struct {
|
||||
app *fiber.App
|
||||
service *services.ChatService
|
||||
session *services.SessionService
|
||||
gateway *gateway.Gateway
|
||||
callback *services.CallbackService
|
||||
app *fiber.App
|
||||
service *services.ChatService
|
||||
session *services.SessionService
|
||||
gateway *gateway.Gateway
|
||||
callback *services.CallbackService
|
||||
}
|
||||
|
||||
func NewHTTPServer(
|
||||
service *services.ChatService,
|
||||
session *services.SessionService,
|
||||
task *services.TaskService,
|
||||
gateway *gateway.Gateway,
|
||||
callback *services.CallbackService,
|
||||
service *services.ChatService,
|
||||
session *services.SessionService,
|
||||
task *services.TaskService,
|
||||
gateway *gateway.Gateway,
|
||||
callback *services.CallbackService,
|
||||
chatHis *services.HistoryService,
|
||||
) *fiber.App {
|
||||
//构建 server
|
||||
app := initRoute()
|
||||
router.SetupRoutes(app, service, session, task, gateway, callback)
|
||||
return app
|
||||
//构建 server
|
||||
app := initRoute()
|
||||
router.SetupRoutes(app, service, session, task, gateway, callback, chatHis)
|
||||
return app
|
||||
}
|
||||
|
||||
func initRoute() *fiber.App {
|
||||
|
|
|
|||
|
|
@ -15,14 +15,17 @@ import (
|
|||
)
|
||||
|
||||
type RouterServer struct {
|
||||
app *fiber.App
|
||||
service *services.ChatService
|
||||
session *services.SessionService
|
||||
gateway *gateway.Gateway
|
||||
app *fiber.App
|
||||
service *services.ChatService
|
||||
session *services.SessionService
|
||||
gateway *gateway.Gateway
|
||||
chatHist *services.HistoryService
|
||||
}
|
||||
|
||||
// SetupRoutes 设置路由
|
||||
func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionService *services.SessionService, task *services.TaskService, gateway *gateway.Gateway, callbackService *services.CallbackService) {
|
||||
func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionService *services.SessionService, task *services.TaskService,
|
||||
gateway *gateway.Gateway, callbackService *services.CallbackService, chatHist *services.HistoryService,
|
||||
) {
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
// 设置 CORS 头
|
||||
c.Set("Access-Control-Allow-Origin", "*")
|
||||
|
|
@ -77,6 +80,9 @@ func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionServi
|
|||
return ctx.Status(400).SendString("unknown action")
|
||||
}
|
||||
})
|
||||
|
||||
// 会话历史
|
||||
r.Post("/chat/history", chatHist.GetHistory)
|
||||
}
|
||||
|
||||
func routerSocket(app *fiber.App, chatService *services.ChatService) {
|
||||
|
|
|
|||
|
|
@ -73,9 +73,9 @@ func (h *ChatService) Chat(c *websocket.Conn) {
|
|||
var wg sync.WaitGroup
|
||||
// 确保在函数返回时移除客户端并关闭连接
|
||||
defer func() {
|
||||
h.Gw.Cleanup(client.GetID())
|
||||
close(semaphore) // 关闭信号量通道
|
||||
wg.Wait() // 等待所有消息处理goroutine完成
|
||||
close(semaphore) // 关闭信号量通道
|
||||
h.Gw.Cleanup(client.GetID())
|
||||
}()
|
||||
|
||||
// 绑定会话ID, sessionId 为空时, 则不绑定
|
||||
|
|
|
|||
|
|
@ -0,0 +1,44 @@
|
|||
package services
|
||||
|
||||
import (
|
||||
"ai_scheduler/internal/biz"
|
||||
errors "ai_scheduler/internal/data/error"
|
||||
"ai_scheduler/internal/entitys"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
type HistoryService struct {
|
||||
chatRepo *biz.ChatHistoryBiz
|
||||
}
|
||||
|
||||
func NewHistoryService(chatRepo *biz.ChatHistoryBiz) *HistoryService {
|
||||
return &HistoryService{
|
||||
chatRepo: chatRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// GetHistoryService 获取会话历史
|
||||
func (h *HistoryService) GetHistory(c *fiber.Ctx) error {
|
||||
var query entitys.ChatHistQuery
|
||||
if err := c.BodyParser(&query); err != nil {
|
||||
return err
|
||||
}
|
||||
// 校验参数
|
||||
if query.SessionID == "" {
|
||||
return errors.SessionNotFound
|
||||
}
|
||||
if query.Page <= 0 {
|
||||
query.Page = 1
|
||||
}
|
||||
if query.PageSize <= 0 {
|
||||
query.PageSize = 10
|
||||
}
|
||||
|
||||
// 查询历史
|
||||
history, err := h.chatRepo.List(c.Context(), &query)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.JSON(history)
|
||||
}
|
||||
|
|
@ -6,4 +6,4 @@ import (
|
|||
"github.com/google/wire"
|
||||
)
|
||||
|
||||
var ProviderSetServices = wire.NewSet(NewChatService, NewSessionService, gateway.NewGateway, NewTaskService, NewCallbackService)
|
||||
var ProviderSetServices = wire.NewSet(NewChatService, NewSessionService, gateway.NewGateway, NewTaskService, NewCallbackService, NewHistoryService)
|
||||
|
|
|
|||
Loading…
Reference in New Issue