135 lines
		
	
	
		
			3.5 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			135 lines
		
	
	
		
			3.5 KiB
		
	
	
	
		
			Go
		
	
	
	
package biz
 | 
						||
 | 
						||
import (
 | 
						||
	"ai_scheduler/internal/data/constants"
 | 
						||
	errorcode "ai_scheduler/internal/data/error"
 | 
						||
	"ai_scheduler/internal/data/impl"
 | 
						||
	"ai_scheduler/internal/data/model"
 | 
						||
	"ai_scheduler/internal/entitys"
 | 
						||
	"context"
 | 
						||
	"time"
 | 
						||
 | 
						||
	"github.com/gofiber/fiber/v2/utils"
 | 
						||
 | 
						||
	"ai_scheduler/internal/config"
 | 
						||
)
 | 
						||
 | 
						||
type SessionBiz struct {
 | 
						||
	sessionRepo *impl.SessionImpl
 | 
						||
	sysRepo     *impl.SysImpl
 | 
						||
	chatRepo    *impl.ChatImpl
 | 
						||
 | 
						||
	conf *config.Config
 | 
						||
}
 | 
						||
 | 
						||
func NewSessionBiz(conf *config.Config, sessionImpl *impl.SessionImpl, sysImpl *impl.SysImpl, chatImpl *impl.ChatImpl) *SessionBiz {
 | 
						||
	return &SessionBiz{
 | 
						||
		sessionRepo: sessionImpl,
 | 
						||
		sysRepo:     sysImpl,
 | 
						||
		chatRepo:    chatImpl,
 | 
						||
		conf:        conf,
 | 
						||
	}
 | 
						||
}
 | 
						||
 | 
						||
// InitSession 初始化会话 ,当天存在则返回会话,如果不存在则创建一个
 | 
						||
func (s *SessionBiz) SessionInit(ctx context.Context, req *entitys.SessionInitRequest) (result *entitys.SessionInitResponse, err error) {
 | 
						||
 | 
						||
	// 获取系统配置
 | 
						||
	sysConfig, has, err := s.sysRepo.FindOne(s.sysRepo.WithSysId(req.SysId))
 | 
						||
	if err != nil {
 | 
						||
		return
 | 
						||
	} else if !has {
 | 
						||
		err = errorcode.SysNotFound
 | 
						||
		return
 | 
						||
	}
 | 
						||
 | 
						||
	result = &entitys.SessionInitResponse{
 | 
						||
		Chat: make([]entitys.ChatHistory, 0),
 | 
						||
	}
 | 
						||
 | 
						||
	// 获取 当天的session
 | 
						||
	t := time.Now().Truncate(24 * time.Hour)
 | 
						||
	session, has, err := s.sessionRepo.FindOne(
 | 
						||
		s.sessionRepo.WithUserId(req.UserId), // 条件:用户ID
 | 
						||
		s.sessionRepo.WithStartTime(t),       // 条件:会话开始时间
 | 
						||
		s.sysRepo.WithSysId(sysConfig.SysID), // 条件:系统ID
 | 
						||
	)
 | 
						||
	if err != nil {
 | 
						||
		return
 | 
						||
	} else if !has {
 | 
						||
		// 不存在,创建一个
 | 
						||
		session = model.AiSession{
 | 
						||
			SysID:     sysConfig.SysID,
 | 
						||
			SessionID: utils.UUID(),
 | 
						||
			UserID:    req.UserId,
 | 
						||
			UserName:  req.UserName,
 | 
						||
		}
 | 
						||
		err = s.sessionRepo.Create(&session)
 | 
						||
		if err != nil {
 | 
						||
			return
 | 
						||
		}
 | 
						||
 | 
						||
		chat := entitys.ChatHistory{
 | 
						||
			SessionID: session.SessionID,
 | 
						||
			Role:      constants.RoleSystem,
 | 
						||
			Content:   sysConfig.Prologue,
 | 
						||
		}
 | 
						||
		result.Chat = append(result.Chat, chat)
 | 
						||
		result.SessionId = session.SessionID
 | 
						||
		result.Prologue = sysConfig.Prologue
 | 
						||
 | 
						||
		// 开场白写入会话历史
 | 
						||
		s.chatRepo.AsyncCreate(ctx, model.AiChatHi{
 | 
						||
			SessionID: chat.SessionID,
 | 
						||
			Role:      chat.Role.String(),
 | 
						||
			Content:   chat.Content,
 | 
						||
		})
 | 
						||
 | 
						||
	} else {
 | 
						||
		result.SessionId = session.SessionID
 | 
						||
		result.Prologue = sysConfig.Prologue
 | 
						||
		// 存在,返回会话历史
 | 
						||
		var chatList []model.AiChatHi
 | 
						||
		chatList, err = s.chatRepo.FindAll(
 | 
						||
			s.chatRepo.WithSessionId(session.SessionID),      // 条件:会话ID
 | 
						||
			s.chatRepo.OrderByDesc("create_at"),              // 排序:按创建时间降序
 | 
						||
			s.chatRepo.WithLimit(constants.ChatHistoryLimit), // 限制返回条数
 | 
						||
		)
 | 
						||
		if err != nil {
 | 
						||
			return
 | 
						||
		}
 | 
						||
 | 
						||
		// 转换为 entitys.ChatHistory 类型
 | 
						||
		for _, chat := range chatList {
 | 
						||
			result.Chat = append(result.Chat, entitys.ChatHistory{
 | 
						||
				SessionID: chat.SessionID,
 | 
						||
				Role:      constants.Caller(chat.Role),
 | 
						||
				Content:   chat.Content,
 | 
						||
				Prologue:  sysConfig.Prologue,
 | 
						||
			})
 | 
						||
		}
 | 
						||
	}
 | 
						||
 | 
						||
	return
 | 
						||
}
 | 
						||
 | 
						||
// SessionList 会话列表
 | 
						||
func (s *SessionBiz) SessionList(ctx context.Context, req *entitys.SessionListRequest) (list []model.AiSession, err error) {
 | 
						||
 | 
						||
	if req.Page <= 0 {
 | 
						||
		req.Page = 1
 | 
						||
	}
 | 
						||
	if req.PageSize <= 0 {
 | 
						||
		req.PageSize = 10
 | 
						||
	}
 | 
						||
 | 
						||
	list, err = s.sessionRepo.FindAll(
 | 
						||
		s.sessionRepo.WithUserId(req.UserId),                // 条件:用户ID
 | 
						||
		s.sessionRepo.WithSysId(req.SysId),                  // 条件:系统ID
 | 
						||
		s.sessionRepo.PaginateScope(req.Page, req.PageSize), // 分页
 | 
						||
		s.sessionRepo.OrderByDesc("create_at"),              // 排序:按创建时间降序
 | 
						||
	)
 | 
						||
 | 
						||
	return
 | 
						||
}
 |