ai_scheduler/internal/data/impl/chat_history.go

57 lines
1.4 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package impl
import (
"ai_scheduler/internal/data/model"
"ai_scheduler/tmpl/dataTemp"
"ai_scheduler/utils"
"context"
"github.com/gofiber/fiber/v2/log"
"gorm.io/gorm"
"time"
)
// ChatImpl is the data-access implementation for chat history records
// (model.AiChatHi). It embeds the generic DataTemp helper and a typed
// BaseRepository, and owns an in-memory queue used for asynchronous writes.
type ChatImpl struct {
	dataTemp.DataTemp
	BaseRepository[model.AiChatHi]
	// chatChannel buffers records enqueued by AsyncCreate until AsyncProcess
	// persists them; its capacity is chosen in NewChatImpl.
	chatChannel chan model.AiChatHi
}
// NewChatImpl builds a ChatImpl backed by db: the DataTemp helper and the
// typed repository both target model.AiChatHi, and the asynchronous write
// queue is created with room for 100 pending records.
func NewChatImpl(db *utils.Db) *ChatImpl {
	temp := *dataTemp.NewDataTemp(db, new(model.AiChatHi))
	repo := NewBaseModel[model.AiChatHi](db.Client)
	queue := make(chan model.AiChatHi, 100)

	impl := &ChatImpl{
		DataTemp:       temp,
		BaseRepository: repo,
		chatChannel:    queue,
	}
	return impl
}
// WithSessionId returns a query condition that restricts a GORM query to rows
// whose session_id column equals sessionId.
func (impl *ChatImpl) WithSessionId(sessionId interface{}) CondFunc {
	bySession := func(tx *gorm.DB) *gorm.DB {
		return tx.Where("session_id = ?", sessionId)
	}
	return bySession
}
// AsyncCreate queues chat for asynchronous persistence by AsyncProcess.
//
// When the queue has free capacity the record is enqueued immediately,
// regardless of ctx. When the queue is full, the call waits for space but
// gives up once ctx is done, logging that the record was dropped. This keeps
// callers from blocking forever if the consumer has stopped (for example after
// AsyncProcess returned on shutdown). A nil ctx keeps the previous behavior of
// waiting indefinitely.
func (impl *ChatImpl) AsyncCreate(ctx context.Context, chat model.AiChatHi) {
	// Fast path: try a non-blocking send first. A select with both a ready
	// send and a closed Done channel picks at random, so without this step a
	// cancelled ctx could drop records even while the buffer had room.
	select {
	case impl.chatChannel <- chat:
		return
	default:
	}

	// A nil channel is never ready in a select, so a nil ctx degrades to a
	// plain blocking send (the original behavior).
	var done <-chan struct{}
	if ctx != nil {
		done = ctx.Done()
	}

	select {
	case impl.chatChannel <- chat:
	case <-done:
		log.Errorf("ChatHistoryAsyncCreate dropped chat, ctx done: %v", ctx.Err())
	}
}
// AsyncProcess consumes chat-history records queued by AsyncCreate and
// persists each one via Create, logging (not returning) any error. It loops
// until ctx is done, so it is meant to run in its own goroutine. Every 5
// seconds it logs the current queue length.
//
// When ctx is done, the records already buffered at that moment are still
// written before returning, so a graceful shutdown does not silently drop
// them. Records sent after that point are left in the channel.
func (impl *ChatImpl) AsyncProcess(ctx context.Context) {
	// save persists a single record; there is no caller to report errors to,
	// so failures are only logged.
	save := func(chat model.AiChatHi) {
		log.Infof("ChatHistoryAsyncProcess chat: %v", chat)
		if err := impl.Create(&chat); err != nil {
			log.Errorf("ChatHistoryAsyncProcess err: %v", err)
		}
	}

	// Periodic channel-length report. A single ticker is used instead of
	// time.After per iteration: time.After restarted the 5s wait on every
	// received chat, so under steady traffic the report never fired.
	ticker := time.NewTicker(5 * time.Second)
	defer ticker.Stop()

	for {
		select {
		case chat := <-impl.chatChannel:
			save(chat)
		case <-ctx.Done():
			log.Infof("ChatHistoryAsyncProcess ctx done")
			// Flush what is queued right now. The count is fixed up front so
			// producers that keep sending cannot keep this loop alive, and the
			// receive is non-blocking in case another consumer empties it.
			for pending := len(impl.chatChannel); pending > 0; pending-- {
				select {
				case chat := <-impl.chatChannel:
					save(chat)
				default:
					return
				}
			}
			return
		case <-ticker.C:
			log.Infof("ChatHistoryAsyncProcess channel len: %d", len(impl.chatChannel))
		}
	}
}