127 lines
3.2 KiB
Go
127 lines
3.2 KiB
Go
package biz
|
||
|
||
import (
|
||
"ai_scheduler/internal/data/constants"
|
||
"ai_scheduler/internal/data/impl"
|
||
"ai_scheduler/internal/data/model"
|
||
"ai_scheduler/internal/entitys"
|
||
"context"
|
||
"fmt"
|
||
"github.com/gofiber/fiber/v2/utils"
|
||
"time"
|
||
|
||
"ai_scheduler/internal/config"
|
||
)
|
||
|
||
type SessionBiz struct {
|
||
sessionRepo *impl.SessionImpl
|
||
sysRepo *impl.SysImpl
|
||
chatRepo *impl.ChatImpl
|
||
|
||
conf *config.Config
|
||
}
|
||
|
||
func NewSessionBiz(conf *config.Config, sessionImpl *impl.SessionImpl, sysImpl *impl.SysImpl) *SessionBiz {
|
||
return &SessionBiz{
|
||
sessionRepo: sessionImpl,
|
||
sysRepo: sysImpl,
|
||
conf: conf,
|
||
}
|
||
}
|
||
|
||
// InitSession 初始化会话 ,当天存在则返回会话,如果不存在则创建一个
|
||
func (s *SessionBiz) SessionInit(ctx context.Context, req *entitys.SessionInitRequest) (result *entitys.SessionInitResponse, err error) {
|
||
|
||
// 获取系统配置
|
||
sysConfig, has, err := s.sysRepo.FindOne(s.sysRepo.WithSysId(req.SysId))
|
||
if err != nil {
|
||
return
|
||
} else if !has {
|
||
err = fmt.Errorf("sys not found")
|
||
return
|
||
}
|
||
|
||
result = &entitys.SessionInitResponse{
|
||
Chat: make([]entitys.ChatHistory, 0),
|
||
}
|
||
|
||
// 获取 当天的session
|
||
t := time.Now().Truncate(24 * time.Hour)
|
||
session, has, err := s.sessionRepo.FindOne(
|
||
s.sessionRepo.WithUserId(req.UserId), // 条件:用户ID
|
||
s.sessionRepo.WithStartTime(t), // 条件:会话开始时间
|
||
s.sysRepo.WithSysId(sysConfig.SysID), // 条件:系统ID
|
||
)
|
||
if err != nil {
|
||
return
|
||
} else if !has {
|
||
// 不存在,创建一个
|
||
session = model.AiSession{
|
||
SysID: sysConfig.SysID,
|
||
SessionID: utils.UUID(),
|
||
UserID: req.UserId,
|
||
}
|
||
err = s.sessionRepo.Create(&session)
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
chat := entitys.ChatHistory{
|
||
SessionID: session.SessionID,
|
||
Role: constants.RoleSystem,
|
||
Content: sysConfig.Prologue,
|
||
}
|
||
result.Chat = append(result.Chat, chat)
|
||
|
||
// 开场白写入会话历史
|
||
s.chatRepo.AsyncCreate(ctx, model.AiChatHi{
|
||
SessionID: chat.SessionID,
|
||
Role: chat.Role.String(),
|
||
Content: chat.Content,
|
||
})
|
||
|
||
} else {
|
||
// 存在,返回会话历史
|
||
var chatList []model.AiChatHi
|
||
chatList, err = s.chatRepo.FindAll(
|
||
s.chatRepo.WithSessionId(session.SessionID), // 条件:会话ID
|
||
s.chatRepo.OrderByDesc("create_at"), // 排序:按创建时间降序
|
||
s.chatRepo.WithLimit(constants.ChatHistoryLimit), // 限制返回条数
|
||
)
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
// 转换为 entitys.ChatHistory 类型
|
||
for _, chat := range chatList {
|
||
result.Chat = append(result.Chat, entitys.ChatHistory{
|
||
SessionID: chat.SessionID,
|
||
Role: constants.Caller(chat.Role),
|
||
Content: chat.Content,
|
||
})
|
||
}
|
||
}
|
||
|
||
return
|
||
}
|
||
|
||
// SessionList 会话列表
|
||
func (s *SessionBiz) SessionList(ctx context.Context, req *entitys.SessionListRequest) (list []model.AiSession, err error) {
|
||
|
||
if req.Page <= 0 {
|
||
req.Page = 1
|
||
}
|
||
if req.PageSize <= 0 {
|
||
req.PageSize = 10
|
||
}
|
||
|
||
list, err = s.sessionRepo.FindAll(
|
||
s.sessionRepo.WithUserId(req.UserId), // 条件:用户ID
|
||
s.sessionRepo.WithSysId(req.SysId), // 条件:系统ID
|
||
s.sessionRepo.PaginateScope(req.Page, req.PageSize), // 分页
|
||
s.sessionRepo.OrderByDesc("create_at"), // 排序:按创建时间降序
|
||
)
|
||
|
||
return
|
||
}
|