124 lines
3.3 KiB
Go
124 lines
3.3 KiB
Go
package biz
|
|
|
|
import (
|
|
errors "ai_scheduler/internal/data/error"
|
|
"ai_scheduler/internal/data/impl"
|
|
"ai_scheduler/internal/data/model"
|
|
"ai_scheduler/internal/entitys"
|
|
"ai_scheduler/internal/pkg/util"
|
|
"context"
|
|
"encoding/json"
|
|
"xorm.io/builder"
|
|
)
|
|
|
|
type ChatHistoryBiz struct {
|
|
chatHiRepo *impl.ChatHisImpl
|
|
taskRepo *impl.TaskImpl
|
|
}
|
|
|
|
func NewChatHistoryBiz(chatHiRepo *impl.ChatHisImpl, taskRepo *impl.TaskImpl) *ChatHistoryBiz {
|
|
s := &ChatHistoryBiz{
|
|
chatHiRepo: chatHiRepo,
|
|
taskRepo: taskRepo,
|
|
}
|
|
return s
|
|
}
|
|
|
|
// 查询会话历史
|
|
func (s *ChatHistoryBiz) List(ctx context.Context, query *entitys.ChatHistQuery) ([]entitys.ChatHisQueryResponse, error) {
|
|
chats, err := s.chatHiRepo.FindAll(
|
|
s.chatHiRepo.WithSessionId(query.SessionID),
|
|
s.chatHiRepo.PaginateScope(query.Page, query.PageSize),
|
|
s.chatHiRepo.OrderByDesc("his_id"),
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
taskIds := make([]int32, 0, len(chats))
|
|
for _, chat := range chats {
|
|
// 去重任务ID
|
|
if !util.Contains(taskIds, chat.TaskID) {
|
|
taskIds = append(taskIds, chat.TaskID)
|
|
}
|
|
}
|
|
|
|
// 查询任务名称
|
|
tasks, err := s.taskRepo.FindAll(s.taskRepo.In("task_id", taskIds))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
taskMap := make(map[int32]model.AiTask)
|
|
for _, task := range tasks {
|
|
taskMap[task.TaskID] = task
|
|
}
|
|
|
|
// 构建结果
|
|
result := make([]entitys.ChatHisQueryResponse, 0, len(chats))
|
|
for _, chat := range chats {
|
|
item := entitys.ChatHisQueryResponse{}
|
|
item.FromModel(chat, taskMap[chat.TaskID])
|
|
result = append(result, item)
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
//// 添加会话历史
|
|
//func (s *ChatHistoryBiz) Create(ctx context.Context, chat entitys.ChatHistory) error {
|
|
// return s.chatHiRepo.Create(&model.AiChatHi{
|
|
// SessionID: chat.SessionID,
|
|
// Ques: chat.Role.String(),
|
|
// Ans: chat.Content,
|
|
// })
|
|
//}
|
|
|
|
// 更新会话历史内容, 追加内容, 不覆盖原有内容, content 使用json格式存储
|
|
func (c *ChatHistoryBiz) UpdateContent(ctx context.Context, chat *entitys.UpdateContentRequest) error {
|
|
var contents []string
|
|
chatHi, has, err := c.chatHiRepo.FindOne(c.chatHiRepo.WithHisId(chat.HisID))
|
|
if err != nil {
|
|
return err
|
|
} else if !has {
|
|
return errors.NewBusinessErr(errors.InvalidParamCode, "chat history not found")
|
|
}
|
|
|
|
if "" != chatHi.Content {
|
|
// 解析历史内容
|
|
err = json.Unmarshal([]byte(chatHi.Content), &contents)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
contents = append(contents, chat.Content)
|
|
|
|
b, err := json.Marshal(contents)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
chatHi.Content = string(b)
|
|
return c.chatHiRepo.Update(&chatHi,
|
|
c.chatHiRepo.Select("content"),
|
|
c.chatHiRepo.WithHisId(chatHi.HisID))
|
|
}
|
|
|
|
// 异步添加会话历史
|
|
//func (s *ChatHistoryBiz) AsyncCreate(ctx context.Context, chat entitys.ChatHistory) {
|
|
// s.chatRepo.AsyncCreate(ctx, model.AiChatHi{
|
|
// SessionID: chat.SessionID,
|
|
// Role: chat.Role.String(),
|
|
// Content: chat.Content,
|
|
// })
|
|
//}
|
|
|
|
// 异步处理会话历史
|
|
//func (s *ChatHistoryBiz) AsyncProcess(ctx context.Context) {
|
|
// s.chatRepo.AsyncProcess(ctx)
|
|
//}
|
|
|
|
func (s *ChatHistoryBiz) Update(ctx context.Context, chat *entitys.UseFulRequest) error {
|
|
cond := builder.NewCond()
|
|
cond = cond.And(builder.Eq{"his_id": chat.HisId})
|
|
return s.chatHiRepo.UpdateByCond(&cond, &model.AiChatHi{HisID: chat.HisId, Useful: chat.Useful})
|
|
}
|