feat: 会话历史
This commit is contained in:
parent
0f9ac06f73
commit
057cf707d3
|
|
@ -10,31 +10,45 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type ChatHistoryBiz struct {
|
type ChatHistoryBiz struct {
|
||||||
chatRepo *impl.ChatImpl
|
chatHiRepo *impl.ChatHisImpl
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewChatHistoryBiz(chatRepo *impl.ChatImpl) *ChatHistoryBiz {
|
func NewChatHistoryBiz(chatHiRepo *impl.ChatHisImpl) *ChatHistoryBiz {
|
||||||
s := &ChatHistoryBiz{
|
s := &ChatHistoryBiz{
|
||||||
chatRepo: chatRepo,
|
chatHiRepo: chatHiRepo,
|
||||||
}
|
}
|
||||||
//go s.AsyncProcess(context.Background())
|
//go s.AsyncProcess(context.Background())
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
//func (s *ChatHistoryBiz) create(ctx context.Context, sessionID, role, content string) error {
|
// 查询会话历史
|
||||||
// chat := model.AiChatHi{
|
func (s *ChatHistoryBiz) List(ctx context.Context, query *entitys.ChatHistQuery) ([]model.AiChatHi, error) {
|
||||||
// SessionID: sessionID,
|
chats, err := s.chatHiRepo.FindAll(
|
||||||
// Role: role,
|
s.chatHiRepo.WithSessionId(query.SessionID),
|
||||||
// Content: content,
|
s.chatHiRepo.PaginateScope(query.Page, query.PageSize),
|
||||||
// }
|
)
|
||||||
//
|
if err != nil {
|
||||||
// return s.chatRepo.Create(&chat)
|
return nil, err
|
||||||
//}
|
}
|
||||||
//
|
return chats, nil
|
||||||
//// 添加会话历史
|
}
|
||||||
//func (s *ChatHistoryBiz) Create(ctx context.Context, chat entitys.ChatHistory) error {
|
|
||||||
// return s.create(ctx, chat.SessionID, chat.Role.String(), chat.Content)
|
// 添加会话历史
|
||||||
//}
|
func (s *ChatHistoryBiz) Create(ctx context.Context, chat entitys.ChatHistory) error {
|
||||||
|
return s.chatHiRepo.Create(&model.AiChatHi{
|
||||||
|
SessionID: chat.SessionID,
|
||||||
|
Ques: chat.Role.String(),
|
||||||
|
Ans: chat.Content,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新会话历史内容
|
||||||
|
func (s *ChatHistoryBiz) UpdateContent(ctx context.Context, chat *entitys.UseFulRequest) error {
|
||||||
|
cond := builder.NewCond()
|
||||||
|
cond = cond.And(builder.Eq{"his_id": chat.HisId})
|
||||||
|
|
||||||
|
return s.chatHiRepo.UpdateByCond(&cond, &model.AiChatHi{HisID: chat.HisId, Useful: chat.Useful})
|
||||||
|
}
|
||||||
|
|
||||||
// 异步添加会话历史
|
// 异步添加会话历史
|
||||||
//func (s *ChatHistoryBiz) AsyncCreate(ctx context.Context, chat entitys.ChatHistory) {
|
//func (s *ChatHistoryBiz) AsyncCreate(ctx context.Context, chat entitys.ChatHistory) {
|
||||||
|
|
@ -53,5 +67,5 @@ func NewChatHistoryBiz(chatRepo *impl.ChatImpl) *ChatHistoryBiz {
|
||||||
func (s *ChatHistoryBiz) Update(ctx context.Context, chat *entitys.UseFulRequest) error {
|
func (s *ChatHistoryBiz) Update(ctx context.Context, chat *entitys.UseFulRequest) error {
|
||||||
cond := builder.NewCond()
|
cond := builder.NewCond()
|
||||||
cond = cond.And(builder.Eq{"his_id": chat.HisId})
|
cond = cond.And(builder.Eq{"his_id": chat.HisId})
|
||||||
return s.chatRepo.UpdateByCond(&cond, &model.AiChatHi{HisID: chat.HisId, Useful: chat.Useful})
|
return s.chatHiRepo.UpdateByCond(&cond, &model.AiChatHi{HisID: chat.HisId, Useful: chat.Useful})
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -27,14 +27,14 @@ type Do struct {
|
||||||
sessionImpl *impl.SessionImpl
|
sessionImpl *impl.SessionImpl
|
||||||
sysImpl *impl.SysImpl
|
sysImpl *impl.SysImpl
|
||||||
taskImpl *impl.TaskImpl
|
taskImpl *impl.TaskImpl
|
||||||
hisImpl *impl.ChatImpl
|
hisImpl *impl.ChatHisImpl
|
||||||
conf *config.Config
|
conf *config.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDo(
|
func NewDo(
|
||||||
sysImpl *impl.SysImpl,
|
sysImpl *impl.SysImpl,
|
||||||
taskImpl *impl.TaskImpl,
|
taskImpl *impl.TaskImpl,
|
||||||
hisImpl *impl.ChatImpl,
|
hisImpl *impl.ChatHisImpl,
|
||||||
conf *config.Config,
|
conf *config.Config,
|
||||||
) *Do {
|
) *Do {
|
||||||
return &Do{
|
return &Do{
|
||||||
|
|
@ -252,6 +252,7 @@ func (d *Do) startMessageHandler(
|
||||||
Ques: requireData.Req.Text,
|
Ques: requireData.Req.Text,
|
||||||
Ans: strings.Join(chat, ""),
|
Ans: strings.Join(chat, ""),
|
||||||
Files: requireData.Req.Img,
|
Files: requireData.Req.Img,
|
||||||
|
TaskID: requireData.Task.TaskID,
|
||||||
}
|
}
|
||||||
d.hisImpl.AddWithData(AiRes)
|
d.hisImpl.AddWithData(AiRes)
|
||||||
hisLog.HisId = AiRes.HisID
|
hisLog.HisId = AiRes.HisID
|
||||||
|
|
|
||||||
|
|
@ -85,6 +85,7 @@ func (r *Handle) HandleMatch(ctx context.Context, client *gateway.Client, requir
|
||||||
for _, task := range requireData.Tasks {
|
for _, task := range requireData.Tasks {
|
||||||
if task.Index == requireData.Match.Index {
|
if task.Index == requireData.Match.Index {
|
||||||
pointTask = &task
|
pointTask = &task
|
||||||
|
requireData.Task = task
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -20,13 +20,13 @@ import (
|
||||||
type OllamaService struct {
|
type OllamaService struct {
|
||||||
client *utils_ollama.Client
|
client *utils_ollama.Client
|
||||||
config *config.Config
|
config *config.Config
|
||||||
chatHis *impl.ChatImpl
|
chatHis *impl.ChatHisImpl
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewOllamaGenerate(
|
func NewOllamaGenerate(
|
||||||
client *utils_ollama.Client,
|
client *utils_ollama.Client,
|
||||||
config *config.Config,
|
config *config.Config,
|
||||||
chatHis *impl.ChatImpl,
|
chatHis *impl.ChatHisImpl,
|
||||||
) *OllamaService {
|
) *OllamaService {
|
||||||
return &OllamaService{
|
return &OllamaService{
|
||||||
client: client,
|
client: client,
|
||||||
|
|
|
||||||
|
|
@ -17,16 +17,16 @@ import (
|
||||||
type SessionBiz struct {
|
type SessionBiz struct {
|
||||||
sessionRepo *impl.SessionImpl
|
sessionRepo *impl.SessionImpl
|
||||||
sysRepo *impl.SysImpl
|
sysRepo *impl.SysImpl
|
||||||
chatRepo *impl.ChatImpl
|
chatHisRepo *impl.ChatHisImpl
|
||||||
|
|
||||||
conf *config.Config
|
conf *config.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSessionBiz(conf *config.Config, sessionImpl *impl.SessionImpl, sysImpl *impl.SysImpl, chatImpl *impl.ChatImpl) *SessionBiz {
|
func NewSessionBiz(conf *config.Config, sessionImpl *impl.SessionImpl, sysImpl *impl.SysImpl, chatImpl *impl.ChatHisImpl) *SessionBiz {
|
||||||
return &SessionBiz{
|
return &SessionBiz{
|
||||||
sessionRepo: sessionImpl,
|
sessionRepo: sessionImpl,
|
||||||
sysRepo: sysImpl,
|
sysRepo: sysImpl,
|
||||||
chatRepo: chatImpl,
|
chatHisRepo: chatImpl,
|
||||||
conf: conf,
|
conf: conf,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -91,10 +91,10 @@ func (s *SessionBiz) SessionInit(ctx context.Context, req *entitys.SessionInitRe
|
||||||
result.Prologue = sysConfig.Prologue
|
result.Prologue = sysConfig.Prologue
|
||||||
// 存在,返回会话历史
|
// 存在,返回会话历史
|
||||||
var chatList []model.AiChatHi
|
var chatList []model.AiChatHi
|
||||||
chatList, err = s.chatRepo.FindAll(
|
chatList, err = s.chatHisRepo.FindAll(
|
||||||
s.chatRepo.WithSessionId(session.SessionID), // 条件:会话ID
|
s.chatHisRepo.WithSessionId(session.SessionID), // 条件:会话ID
|
||||||
s.chatRepo.OrderByDesc("create_at"), // 排序:按创建时间降序
|
s.chatHisRepo.OrderByDesc("create_at"), // 排序:按创建时间降序
|
||||||
s.chatRepo.WithLimit(constants.ChatHistoryLimit), // 限制返回条数
|
s.chatHisRepo.WithLimit(constants.ChatHistoryLimit), // 限制返回条数
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
|
|
||||||
|
|
@ -11,14 +11,14 @@ import (
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ChatImpl struct {
|
type ChatHisImpl struct {
|
||||||
dataTemp.DataTemp
|
dataTemp.DataTemp
|
||||||
BaseRepository[model.AiChatHi]
|
BaseRepository[model.AiChatHi]
|
||||||
chatChannel chan model.AiChatHi
|
chatChannel chan model.AiChatHi
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewChatImpl(db *utils.Db) *ChatImpl {
|
func NewChatHisImpl(db *utils.Db) *ChatHisImpl {
|
||||||
return &ChatImpl{
|
return &ChatHisImpl{
|
||||||
DataTemp: *dataTemp.NewDataTemp(db, new(model.AiChatHi)),
|
DataTemp: *dataTemp.NewDataTemp(db, new(model.AiChatHi)),
|
||||||
BaseRepository: NewBaseModel[model.AiChatHi](db.Client),
|
BaseRepository: NewBaseModel[model.AiChatHi](db.Client),
|
||||||
chatChannel: make(chan model.AiChatHi, 100),
|
chatChannel: make(chan model.AiChatHi, 100),
|
||||||
|
|
@ -26,19 +26,19 @@ func NewChatImpl(db *utils.Db) *ChatImpl {
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithSessionId 条件:会话ID
|
// WithSessionId 条件:会话ID
|
||||||
func (impl *ChatImpl) WithSessionId(sessionId interface{}) CondFunc {
|
func (impl *ChatHisImpl) WithSessionId(sessionId interface{}) CondFunc {
|
||||||
return func(db *gorm.DB) *gorm.DB {
|
return func(db *gorm.DB) *gorm.DB {
|
||||||
return db.Where("session_id = ?", sessionId)
|
return db.Where("session_id = ?", sessionId)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 异步添加会话历史
|
// 异步添加会话历史
|
||||||
func (impl *ChatImpl) AsyncCreate(ctx context.Context, chat model.AiChatHi) {
|
func (impl *ChatHisImpl) AsyncCreate(ctx context.Context, chat model.AiChatHi) {
|
||||||
impl.chatChannel <- chat
|
impl.chatChannel <- chat
|
||||||
}
|
}
|
||||||
|
|
||||||
// 异步处理会话历史
|
// 异步处理会话历史
|
||||||
func (impl *ChatImpl) AsyncProcess(ctx context.Context) {
|
func (impl *ChatHisImpl) AsyncProcess(ctx context.Context) {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case chat := <-impl.chatChannel:
|
case chat := <-impl.chatChannel:
|
||||||
|
|
|
||||||
|
|
@ -4,4 +4,4 @@ import (
|
||||||
"github.com/google/wire"
|
"github.com/google/wire"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ProviderImpl = wire.NewSet(NewSessionImpl, NewSysImpl, NewTaskImpl, NewChatImpl)
|
var ProviderImpl = wire.NewSet(NewSessionImpl, NewSysImpl, NewTaskImpl, NewChatHisImpl)
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@ type AiChatHi struct {
|
||||||
Useful int32 `gorm:"column:useful;not null;comment:0不评价,1有用,其他为无用" json:"useful"` // 0不评价,1有用,其他为无用
|
Useful int32 `gorm:"column:useful;not null;comment:0不评价,1有用,其他为无用" json:"useful"` // 0不评价,1有用,其他为无用
|
||||||
CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"`
|
CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"`
|
||||||
UpdatedAt time.Time `gorm:"column:updated_at;default:CURRENT_TIMESTAMP" json:"updated_at"`
|
UpdatedAt time.Time `gorm:"column:updated_at;default:CURRENT_TIMESTAMP" json:"updated_at"`
|
||||||
|
TaskID int32 `gorm:"column:task_id;not null" json:"task_id"` // 任务ID
|
||||||
}
|
}
|
||||||
|
|
||||||
// TableName AiChatHi's table name
|
// TableName AiChatHi's table name
|
||||||
|
|
|
||||||
|
|
@ -14,3 +14,9 @@ type ChatHistory struct {
|
||||||
type ChatHisLog struct {
|
type ChatHisLog struct {
|
||||||
HisId int64 `json:"his_id"`
|
HisId int64 `json:"his_id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ChatHistQuery struct {
|
||||||
|
SessionID string `json:"session_id"`
|
||||||
|
Page int `json:"page"`
|
||||||
|
PageSize int `json:"page_size"`
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -150,6 +150,7 @@ type RequireData struct {
|
||||||
Histories []model.AiChatHi
|
Histories []model.AiChatHi
|
||||||
SessionInfo model.AiSession
|
SessionInfo model.AiSession
|
||||||
Tasks []model.AiTask
|
Tasks []model.AiTask
|
||||||
|
Task model.AiTask
|
||||||
Match *Match
|
Match *Match
|
||||||
Req *ChatSockRequest
|
Req *ChatSockRequest
|
||||||
Auth string
|
Auth string
|
||||||
|
|
|
||||||
|
|
@ -11,24 +11,25 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type HTTPServer struct {
|
type HTTPServer struct {
|
||||||
app *fiber.App
|
app *fiber.App
|
||||||
service *services.ChatService
|
service *services.ChatService
|
||||||
session *services.SessionService
|
session *services.SessionService
|
||||||
gateway *gateway.Gateway
|
gateway *gateway.Gateway
|
||||||
callback *services.CallbackService
|
callback *services.CallbackService
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHTTPServer(
|
func NewHTTPServer(
|
||||||
service *services.ChatService,
|
service *services.ChatService,
|
||||||
session *services.SessionService,
|
session *services.SessionService,
|
||||||
task *services.TaskService,
|
task *services.TaskService,
|
||||||
gateway *gateway.Gateway,
|
gateway *gateway.Gateway,
|
||||||
callback *services.CallbackService,
|
callback *services.CallbackService,
|
||||||
|
chatHis *services.HistoryService,
|
||||||
) *fiber.App {
|
) *fiber.App {
|
||||||
//构建 server
|
//构建 server
|
||||||
app := initRoute()
|
app := initRoute()
|
||||||
router.SetupRoutes(app, service, session, task, gateway, callback)
|
router.SetupRoutes(app, service, session, task, gateway, callback, chatHis)
|
||||||
return app
|
return app
|
||||||
}
|
}
|
||||||
|
|
||||||
func initRoute() *fiber.App {
|
func initRoute() *fiber.App {
|
||||||
|
|
|
||||||
|
|
@ -15,14 +15,17 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type RouterServer struct {
|
type RouterServer struct {
|
||||||
app *fiber.App
|
app *fiber.App
|
||||||
service *services.ChatService
|
service *services.ChatService
|
||||||
session *services.SessionService
|
session *services.SessionService
|
||||||
gateway *gateway.Gateway
|
gateway *gateway.Gateway
|
||||||
|
chatHist *services.HistoryService
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetupRoutes 设置路由
|
// SetupRoutes 设置路由
|
||||||
func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionService *services.SessionService, task *services.TaskService, gateway *gateway.Gateway, callbackService *services.CallbackService) {
|
func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionService *services.SessionService, task *services.TaskService,
|
||||||
|
gateway *gateway.Gateway, callbackService *services.CallbackService, chatHist *services.HistoryService,
|
||||||
|
) {
|
||||||
app.Use(func(c *fiber.Ctx) error {
|
app.Use(func(c *fiber.Ctx) error {
|
||||||
// 设置 CORS 头
|
// 设置 CORS 头
|
||||||
c.Set("Access-Control-Allow-Origin", "*")
|
c.Set("Access-Control-Allow-Origin", "*")
|
||||||
|
|
@ -77,6 +80,9 @@ func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionServi
|
||||||
return ctx.Status(400).SendString("unknown action")
|
return ctx.Status(400).SendString("unknown action")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// 会话历史
|
||||||
|
r.Post("/chat/history", chatHist.GetHistory)
|
||||||
}
|
}
|
||||||
|
|
||||||
func routerSocket(app *fiber.App, chatService *services.ChatService) {
|
func routerSocket(app *fiber.App, chatService *services.ChatService) {
|
||||||
|
|
|
||||||
|
|
@ -73,9 +73,9 @@ func (h *ChatService) Chat(c *websocket.Conn) {
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
// 确保在函数返回时移除客户端并关闭连接
|
// 确保在函数返回时移除客户端并关闭连接
|
||||||
defer func() {
|
defer func() {
|
||||||
h.Gw.Cleanup(client.GetID())
|
|
||||||
close(semaphore) // 关闭信号量通道
|
|
||||||
wg.Wait() // 等待所有消息处理goroutine完成
|
wg.Wait() // 等待所有消息处理goroutine完成
|
||||||
|
close(semaphore) // 关闭信号量通道
|
||||||
|
h.Gw.Cleanup(client.GetID())
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// 绑定会话ID, sessionId 为空时, 则不绑定
|
// 绑定会话ID, sessionId 为空时, 则不绑定
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,44 @@
|
||||||
|
package services
|
||||||
|
|
||||||
|
import (
|
||||||
|
"ai_scheduler/internal/biz"
|
||||||
|
errors "ai_scheduler/internal/data/error"
|
||||||
|
"ai_scheduler/internal/entitys"
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
type HistoryService struct {
|
||||||
|
chatRepo *biz.ChatHistoryBiz
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewHistoryService(chatRepo *biz.ChatHistoryBiz) *HistoryService {
|
||||||
|
return &HistoryService{
|
||||||
|
chatRepo: chatRepo,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetHistoryService 获取会话历史
|
||||||
|
func (h *HistoryService) GetHistory(c *fiber.Ctx) error {
|
||||||
|
var query entitys.ChatHistQuery
|
||||||
|
if err := c.BodyParser(&query); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// 校验参数
|
||||||
|
if query.SessionID == "" {
|
||||||
|
return errors.SessionNotFound
|
||||||
|
}
|
||||||
|
if query.Page <= 0 {
|
||||||
|
query.Page = 1
|
||||||
|
}
|
||||||
|
if query.PageSize <= 0 {
|
||||||
|
query.PageSize = 10
|
||||||
|
}
|
||||||
|
|
||||||
|
// 查询历史
|
||||||
|
history, err := h.chatRepo.List(c.Context(), &query)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.JSON(history)
|
||||||
|
}
|
||||||
|
|
@ -6,4 +6,4 @@ import (
|
||||||
"github.com/google/wire"
|
"github.com/google/wire"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ProviderSetServices = wire.NewSet(NewChatService, NewSessionService, gateway.NewGateway, NewTaskService, NewCallbackService)
|
var ProviderSetServices = wire.NewSet(NewChatService, NewSessionService, gateway.NewGateway, NewTaskService, NewCallbackService, NewHistoryService)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue