fix: 1. 对话历史增加taskid 2. 增加OrderByAsc方法 3.历史对话正序
This commit is contained in:
parent
b5312e73f0
commit
c13c601475
|
|
@ -31,7 +31,8 @@ func (s *ChatHistoryBiz) List(ctx context.Context, query *entitys.ChatHistQuery)
|
|||
con := []impl.CondFunc{
|
||||
s.chatHiRepo.WithSessionId(query.SessionID),
|
||||
s.chatHiRepo.PaginateScope(query.Page, query.PageSize),
|
||||
s.chatHiRepo.OrderByDesc("his_id"),
|
||||
// s.chatHiRepo.OrderByDesc("his_id"),
|
||||
s.chatHiRepo.OrderByAsc("his_id"),
|
||||
}
|
||||
if query.HisID > 0 {
|
||||
con = append(con, s.chatHiRepo.WithHisId(query.HisID))
|
||||
|
|
|
|||
|
|
@ -267,13 +267,15 @@ func (d *Do) startMessageHandler(
|
|||
if len(chat) > 0 {
|
||||
// 合并所有回答-转json字符串
|
||||
ans, _ := json.Marshal(chat)
|
||||
// 通过 chat 获取 task_id
|
||||
taskId := d.getTaskIdByChat(chat)
|
||||
|
||||
AiRes := &model.AiChatHi{
|
||||
SessionID: requireData.Session,
|
||||
Ques: requireData.Req.Text,
|
||||
Ans: string(ans),
|
||||
Files: requireData.Req.Img,
|
||||
TaskID: requireData.Task.TaskID,
|
||||
TaskID: taskId,
|
||||
}
|
||||
d.hisImpl.AddWithData(AiRes)
|
||||
hisLog.HisId = AiRes.HisID
|
||||
|
|
@ -410,3 +412,39 @@ func (d *Do) LoadUserPermission(client *gateway.Client, requireData *entitys.Req
|
|||
|
||||
return respBody.Codes, nil
|
||||
}
|
||||
|
||||
// getTaskIdByChat 从 chat 中获取 task_id
|
||||
func (d *Do) getTaskIdByChat(chat []entitys.Response) (taskId int32) {
|
||||
if len(chat) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
var irregularTaskToolIndexMap = map[string]string{
|
||||
"zltxProduct": "product_diagnosis",
|
||||
"zltxOrderDetail": "order_diagnosis",
|
||||
"knowledgeBase": "knowledge_qa",
|
||||
"zltxOrderStatistics": "account_statistics",
|
||||
"normalChat": "chat",
|
||||
"zltxOrderAfterSaleSupplier": "after_sale_supplier",
|
||||
"zltxOrderAfterSaleReseller": "after_sale_reseller",
|
||||
"zltxOrderAfterSaleResellerBatch": "after_sale_reseller_batch",
|
||||
"zltxLossRiskSearch": "loss_order_direct",
|
||||
}
|
||||
|
||||
taskIndex := chat[0].Index
|
||||
|
||||
if _, ok := irregularTaskToolIndexMap[taskIndex]; ok {
|
||||
taskIndex = irregularTaskToolIndexMap[taskIndex]
|
||||
}
|
||||
|
||||
// 通过 taskIndex 获取 taskId
|
||||
cond := builder.NewCond().And(builder.Eq{"`index`": taskIndex})
|
||||
taskMap, _ := d.taskImpl.GetOneBySearch(&cond)
|
||||
if taskMap == nil {
|
||||
return
|
||||
}
|
||||
|
||||
taskId = taskMap["task_id"].(int32)
|
||||
|
||||
return
|
||||
}
|
||||
|
|
|
|||
|
|
@ -50,6 +50,7 @@ type BaseRepository[P PO] interface {
|
|||
Count(conditions ...CondFunc) (count int64, err error) // 查询条数
|
||||
PaginateScope(page, pageSize int) CondFunc // 分页
|
||||
OrderByDesc(field string) CondFunc // 倒序排序
|
||||
OrderByAsc(field string) CondFunc // 正序排序
|
||||
WithId(id interface{}) CondFunc // 查询id
|
||||
WithStatus(status int) CondFunc // 查询status
|
||||
GetDb() *gorm.DB // 获取数据库连接
|
||||
|
|
@ -193,6 +194,13 @@ func (this *BaseModel[P]) OrderByDesc(field string) CondFunc {
|
|||
}
|
||||
}
|
||||
|
||||
// 正序排序条件生成器
|
||||
func (this *BaseModel[P]) OrderByAsc(field string) CondFunc {
|
||||
return func(db *gorm.DB) *gorm.DB {
|
||||
return db.Order(fmt.Sprintf("%s ASC", field))
|
||||
}
|
||||
}
|
||||
|
||||
// ID查询条件生成器
|
||||
func (this *BaseModel[P]) WithId(id interface{}) CondFunc {
|
||||
return func(db *gorm.DB) *gorm.DB {
|
||||
|
|
|
|||
|
|
@ -63,7 +63,7 @@ func run() {
|
|||
botGroupImpl := impl.NewBotGroupImpl(db)
|
||||
botUserImpl := impl.NewBotUserImpl(db)
|
||||
// 初始化Do业务对象
|
||||
doDo := do.NewDo(sysImpl, taskImpl, chatHisImpl, configConfig)
|
||||
doDo := do.NewDo(sessionImpl, sysImpl, taskImpl, chatHisImpl, configConfig)
|
||||
// 初始化Ollama客户端
|
||||
client, _, _ := utils_ollama.NewClient(configConfig)
|
||||
// 初始化vLLM客户端
|
||||
|
|
|
|||
Loading…
Reference in New Issue