ai_scheduler/internal/biz/chat_history.go

132 lines
3.4 KiB
Go

package biz
import (
errors "ai_scheduler/internal/data/error"
"ai_scheduler/internal/data/impl"
"ai_scheduler/internal/data/model"
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg/util"
"context"
"encoding/json"
"xorm.io/builder"
)
// ChatHistoryBiz holds the business logic for chat history records. It reads
// and writes history rows through chatHiRepo and looks up task records
// through taskRepo.
type ChatHistoryBiz struct {
	chatHiRepo *impl.ChatHisImpl // chat history repository
	taskRepo   *impl.TaskImpl    // task repository; List uses it to fetch task records by task_id
}

// NewChatHistoryBiz builds a ChatHistoryBiz backed by the given repositories.
// The pointers are stored as-is; no nil checks are performed.
func NewChatHistoryBiz(chatHiRepo *impl.ChatHisImpl, taskRepo *impl.TaskImpl) *ChatHistoryBiz {
	return &ChatHistoryBiz{chatHiRepo: chatHiRepo, taskRepo: taskRepo}
}
// List returns one page of chat history for query.SessionID, ordered by
// his_id descending and paginated via PaginateScope(query.Page,
// query.PageSize). When query.HisID > 0 an extra WithHisId(query.HisID)
// condition is applied.
//
// Each returned item is built by FromModel from the chat row and its task
// record, fetched for all distinct task IDs in a single batched query. A chat
// whose task is not found is paired with a zero-value model.AiTask.
//
// An empty page yields a non-nil, empty slice (so it serializes as [] rather
// than null) and does not query the task table at all.
func (s *ChatHistoryBiz) List(ctx context.Context, query *entitys.ChatHistQuery) ([]entitys.ChatHisQueryResponse, error) {
	conds := []impl.CondFunc{
		s.chatHiRepo.WithSessionId(query.SessionID),
		s.chatHiRepo.PaginateScope(query.Page, query.PageSize),
		s.chatHiRepo.OrderByDesc("his_id"),
	}
	if query.HisID > 0 {
		conds = append(conds, s.chatHiRepo.WithHisId(query.HisID))
	}

	chats, err := s.chatHiRepo.FindAll(conds...)
	if err != nil {
		return nil, err
	}

	result := make([]entitys.ChatHisQueryResponse, 0, len(chats))
	if len(chats) == 0 {
		// Nothing to enrich. Skipping the task lookup avoids a pointless round
		// trip with an empty IN list, which some query builders render as
		// invalid SQL or an always-false predicate.
		return result, nil
	}

	// Collect distinct task IDs so each task is fetched once.
	taskIDs := make([]int32, 0, len(chats))
	for _, chat := range chats {
		if !util.Contains(taskIDs, chat.TaskID) {
			taskIDs = append(taskIDs, chat.TaskID)
		}
	}

	tasks, err := s.taskRepo.FindAll(s.taskRepo.In("task_id", taskIDs))
	if err != nil {
		return nil, err
	}
	taskMap := make(map[int32]model.AiTask, len(tasks))
	for _, task := range tasks {
		taskMap[task.TaskID] = task
	}

	for _, chat := range chats {
		var item entitys.ChatHisQueryResponse
		// A missing task yields the zero-value model.AiTask from the map.
		item.FromModel(chat, taskMap[chat.TaskID])
		result = append(result, item)
	}
	return result, nil
}
//// 添加会话历史
//func (s *ChatHistoryBiz) Create(ctx context.Context, chat entitys.ChatHistory) error {
// return s.chatHiRepo.Create(&model.AiChatHi{
// SessionID: chat.SessionID,
// Ques: chat.Role.String(),
// Ans: chat.Content,
// })
//}
// UpdateContent appends chat.Content to the stored content of the chat
// history record identified by chat.HisID. The content column is stored as a
// JSON array of strings; existing entries are kept and the new entry is
// appended rather than overwriting them.
//
// It returns errors.NewBusinessErr(errors.InvalidParamCode, ...) when no
// record matches chat.HisID. Repository and JSON errors are returned as-is.
//
// NOTE(review): this is a read-modify-write with no visible transaction or
// lock, so two concurrent calls for the same HisID can drop one of the
// appended entries. Confirm whether callers serialize updates per record.
func (s *ChatHistoryBiz) UpdateContent(ctx context.Context, chat *entitys.UpdateContentRequest) error {
	chatHi, has, err := s.chatHiRepo.FindOne(s.chatHiRepo.WithHisId(chat.HisID))
	if err != nil {
		return err
	}
	if !has {
		return errors.NewBusinessErr(errors.InvalidParamCode, "chat history not found")
	}

	var contents []string
	if chatHi.Content != "" {
		// Decode the existing JSON array so the new entry is appended to it.
		if err := json.Unmarshal([]byte(chatHi.Content), &contents); err != nil {
			return err
		}
	}
	contents = append(contents, chat.Content)

	b, err := json.Marshal(contents)
	if err != nil {
		return err
	}
	chatHi.Content = string(b)
	// Select("content") presumably limits the UPDATE to the content column.
	return s.chatHiRepo.Update(&chatHi,
		s.chatHiRepo.Select("content"),
		s.chatHiRepo.WithHisId(chatHi.HisID))
}
// 异步添加会话历史
//func (s *ChatHistoryBiz) AsyncCreate(ctx context.Context, chat entitys.ChatHistory) {
// s.chatRepo.AsyncCreate(ctx, model.AiChatHi{
// SessionID: chat.SessionID,
// Role: chat.Role.String(),
// Content: chat.Content,
// })
//}
// 异步处理会话历史
//func (s *ChatHistoryBiz) AsyncProcess(ctx context.Context) {
// s.chatRepo.AsyncProcess(ctx)
//}
// Update writes chat.Useful onto the chat history record whose his_id equals
// chat.HisId, delegating to the repository's UpdateByCond.
//
// NOTE(review): the bean carries only HisID and Useful. If UpdateByCond uses
// xorm's default Update behaviour (zero-valued fields are skipped), setting
// Useful to its zero value would not be persisted. Verify that UpdateByCond
// includes the useful column explicitly.
func (s *ChatHistoryBiz) Update(ctx context.Context, chat *entitys.UseFulRequest) error {
	where := builder.NewCond().And(builder.Eq{"his_id": chat.HisId})
	bean := &model.AiChatHi{
		HisID:  chat.HisId,
		Useful: chat.Useful,
	}
	return s.chatHiRepo.UpdateByCond(&where, bean)
}