diff --git a/internal/biz/chat_history.go b/internal/biz/chat_history.go index f8e18b8..488bead 100644 --- a/internal/biz/chat_history.go +++ b/internal/biz/chat_history.go @@ -31,7 +31,8 @@ func (s *ChatHistoryBiz) List(ctx context.Context, query *entitys.ChatHistQuery) con := []impl.CondFunc{ s.chatHiRepo.WithSessionId(query.SessionID), s.chatHiRepo.PaginateScope(query.Page, query.PageSize), - s.chatHiRepo.OrderByDesc("his_id"), + // s.chatHiRepo.OrderByDesc("his_id"), + s.chatHiRepo.OrderByAsc("his_id"), } if query.HisID > 0 { con = append(con, s.chatHiRepo.WithHisId(query.HisID)) diff --git a/internal/biz/do/ctx.go b/internal/biz/do/ctx.go index 8d5cf74..18babd3 100644 --- a/internal/biz/do/ctx.go +++ b/internal/biz/do/ctx.go @@ -267,13 +267,15 @@ func (d *Do) startMessageHandler( if len(chat) > 0 { // 合并所有回答-转json字符串 ans, _ := json.Marshal(chat) + // 通过 chat 获取 task_id + taskId := d.getTaskIdByChat(chat) AiRes := &model.AiChatHi{ SessionID: requireData.Session, Ques: requireData.Req.Text, Ans: string(ans), Files: requireData.Req.Img, - TaskID: requireData.Task.TaskID, + TaskID: taskId, } d.hisImpl.AddWithData(AiRes) hisLog.HisId = AiRes.HisID @@ -410,3 +412,39 @@ func (d *Do) LoadUserPermission(client *gateway.Client, requireData *entitys.Req return respBody.Codes, nil } + +// getTaskIdByChat 从 chat 中获取 task_id +func (d *Do) getTaskIdByChat(chat []entitys.Response) (taskId int32) { + if len(chat) == 0 { + return + } + + var irregularTaskToolIndexMap = map[string]string{ + "zltxProduct": "product_diagnosis", + "zltxOrderDetail": "order_diagnosis", + "knowledgeBase": "knowledge_qa", + "zltxOrderStatistics": "account_statistics", + "normalChat": "chat", + "zltxOrderAfterSaleSupplier": "after_sale_supplier", + "zltxOrderAfterSaleReseller": "after_sale_reseller", + "zltxOrderAfterSaleResellerBatch": "after_sale_reseller_batch", + "zltxLossRiskSearch": "loss_order_direct", + } + + taskIndex := chat[0].Index + + if _, ok := irregularTaskToolIndexMap[taskIndex]; ok { + taskIndex = irregularTaskToolIndexMap[taskIndex] + } + + // 通过 taskIndex 获取 taskId + cond := builder.NewCond().And(builder.Eq{"`index`": taskIndex}) + taskMap, _ := d.taskImpl.GetOneBySearch(&cond) + if taskMap == nil { + return + } + + taskId = taskMap["task_id"].(int32) + + return +} diff --git a/internal/data/impl/base.go b/internal/data/impl/base.go index b7abf1a..8eb0b7d 100644 --- a/internal/data/impl/base.go +++ b/internal/data/impl/base.go @@ -50,6 +50,7 @@ type BaseRepository[P PO] interface { Count(conditions ...CondFunc) (count int64, err error) // 查询条数 PaginateScope(page, pageSize int) CondFunc // 分页 OrderByDesc(field string) CondFunc // 倒序排序 + OrderByAsc(field string) CondFunc // 正序排序 WithId(id interface{}) CondFunc // 查询id WithStatus(status int) CondFunc // 查询status GetDb() *gorm.DB // 获取数据库连接 @@ -193,6 +194,13 @@ func (this *BaseModel[P]) OrderByDesc(field string) CondFunc { } } +// 正序排序条件生成器 +func (this *BaseModel[P]) OrderByAsc(field string) CondFunc { + return func(db *gorm.DB) *gorm.DB { + return db.Order(fmt.Sprintf("%s ASC", field)) + } +} + // ID查询条件生成器 func (this *BaseModel[P]) WithId(id interface{}) CondFunc { return func(db *gorm.DB) *gorm.DB { diff --git a/internal/services/dtalk_bot_test.go b/internal/services/dtalk_bot_test.go index 5def7e6..cd4c245 100644 --- a/internal/services/dtalk_bot_test.go +++ b/internal/services/dtalk_bot_test.go @@ -63,7 +63,7 @@ func run() { botGroupImpl := impl.NewBotGroupImpl(db) botUserImpl := impl.NewBotUserImpl(db) // 初始化Do业务对象 - doDo := do.NewDo(sysImpl, taskImpl, chatHisImpl, configConfig) + doDo := do.NewDo(sessionImpl, sysImpl, taskImpl, chatHisImpl, configConfig) // 初始化Ollama客户端 client, _, _ := utils_ollama.NewClient(configConfig) // 初始化vLLM客户端