fix: 1. 对话历史增加taskid 2. 增加OrderByAsc方法 3.历史对话正序
This commit is contained in:
parent
b5312e73f0
commit
c13c601475
|
|
@ -31,7 +31,8 @@ func (s *ChatHistoryBiz) List(ctx context.Context, query *entitys.ChatHistQuery)
|
||||||
con := []impl.CondFunc{
|
con := []impl.CondFunc{
|
||||||
s.chatHiRepo.WithSessionId(query.SessionID),
|
s.chatHiRepo.WithSessionId(query.SessionID),
|
||||||
s.chatHiRepo.PaginateScope(query.Page, query.PageSize),
|
s.chatHiRepo.PaginateScope(query.Page, query.PageSize),
|
||||||
s.chatHiRepo.OrderByDesc("his_id"),
|
// s.chatHiRepo.OrderByDesc("his_id"),
|
||||||
|
s.chatHiRepo.OrderByAsc("his_id"),
|
||||||
}
|
}
|
||||||
if query.HisID > 0 {
|
if query.HisID > 0 {
|
||||||
con = append(con, s.chatHiRepo.WithHisId(query.HisID))
|
con = append(con, s.chatHiRepo.WithHisId(query.HisID))
|
||||||
|
|
|
||||||
|
|
@ -267,13 +267,15 @@ func (d *Do) startMessageHandler(
|
||||||
if len(chat) > 0 {
|
if len(chat) > 0 {
|
||||||
// 合并所有回答-转json字符串
|
// 合并所有回答-转json字符串
|
||||||
ans, _ := json.Marshal(chat)
|
ans, _ := json.Marshal(chat)
|
||||||
|
// 通过 chat 获取 task_id
|
||||||
|
taskId := d.getTaskIdByChat(chat)
|
||||||
|
|
||||||
AiRes := &model.AiChatHi{
|
AiRes := &model.AiChatHi{
|
||||||
SessionID: requireData.Session,
|
SessionID: requireData.Session,
|
||||||
Ques: requireData.Req.Text,
|
Ques: requireData.Req.Text,
|
||||||
Ans: string(ans),
|
Ans: string(ans),
|
||||||
Files: requireData.Req.Img,
|
Files: requireData.Req.Img,
|
||||||
TaskID: requireData.Task.TaskID,
|
TaskID: taskId,
|
||||||
}
|
}
|
||||||
d.hisImpl.AddWithData(AiRes)
|
d.hisImpl.AddWithData(AiRes)
|
||||||
hisLog.HisId = AiRes.HisID
|
hisLog.HisId = AiRes.HisID
|
||||||
|
|
@ -410,3 +412,39 @@ func (d *Do) LoadUserPermission(client *gateway.Client, requireData *entitys.Req
|
||||||
|
|
||||||
return respBody.Codes, nil
|
return respBody.Codes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getTaskIdByChat 从 chat 中获取 task_id
|
||||||
|
func (d *Do) getTaskIdByChat(chat []entitys.Response) (taskId int32) {
|
||||||
|
if len(chat) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var irregularTaskToolIndexMap = map[string]string{
|
||||||
|
"zltxProduct": "product_diagnosis",
|
||||||
|
"zltxOrderDetail": "order_diagnosis",
|
||||||
|
"knowledgeBase": "knowledge_qa",
|
||||||
|
"zltxOrderStatistics": "account_statistics",
|
||||||
|
"normalChat": "chat",
|
||||||
|
"zltxOrderAfterSaleSupplier": "after_sale_supplier",
|
||||||
|
"zltxOrderAfterSaleReseller": "after_sale_reseller",
|
||||||
|
"zltxOrderAfterSaleResellerBatch": "after_sale_reseller_batch",
|
||||||
|
"zltxLossRiskSearch": "loss_order_direct",
|
||||||
|
}
|
||||||
|
|
||||||
|
taskIndex := chat[0].Index
|
||||||
|
|
||||||
|
if _, ok := irregularTaskToolIndexMap[taskIndex]; ok {
|
||||||
|
taskIndex = irregularTaskToolIndexMap[taskIndex]
|
||||||
|
}
|
||||||
|
|
||||||
|
// 通过 taskIndex 获取 taskId
|
||||||
|
cond := builder.NewCond().And(builder.Eq{"`index`": taskIndex})
|
||||||
|
taskMap, _ := d.taskImpl.GetOneBySearch(&cond)
|
||||||
|
if taskMap == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
taskId = taskMap["task_id"].(int32)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -50,6 +50,7 @@ type BaseRepository[P PO] interface {
|
||||||
Count(conditions ...CondFunc) (count int64, err error) // 查询条数
|
Count(conditions ...CondFunc) (count int64, err error) // 查询条数
|
||||||
PaginateScope(page, pageSize int) CondFunc // 分页
|
PaginateScope(page, pageSize int) CondFunc // 分页
|
||||||
OrderByDesc(field string) CondFunc // 倒序排序
|
OrderByDesc(field string) CondFunc // 倒序排序
|
||||||
|
OrderByAsc(field string) CondFunc // 正序排序
|
||||||
WithId(id interface{}) CondFunc // 查询id
|
WithId(id interface{}) CondFunc // 查询id
|
||||||
WithStatus(status int) CondFunc // 查询status
|
WithStatus(status int) CondFunc // 查询status
|
||||||
GetDb() *gorm.DB // 获取数据库连接
|
GetDb() *gorm.DB // 获取数据库连接
|
||||||
|
|
@ -193,6 +194,13 @@ func (this *BaseModel[P]) OrderByDesc(field string) CondFunc {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 正序排序条件生成器
|
||||||
|
func (this *BaseModel[P]) OrderByAsc(field string) CondFunc {
|
||||||
|
return func(db *gorm.DB) *gorm.DB {
|
||||||
|
return db.Order(fmt.Sprintf("%s ASC", field))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ID查询条件生成器
|
// ID查询条件生成器
|
||||||
func (this *BaseModel[P]) WithId(id interface{}) CondFunc {
|
func (this *BaseModel[P]) WithId(id interface{}) CondFunc {
|
||||||
return func(db *gorm.DB) *gorm.DB {
|
return func(db *gorm.DB) *gorm.DB {
|
||||||
|
|
|
||||||
|
|
@ -63,7 +63,7 @@ func run() {
|
||||||
botGroupImpl := impl.NewBotGroupImpl(db)
|
botGroupImpl := impl.NewBotGroupImpl(db)
|
||||||
botUserImpl := impl.NewBotUserImpl(db)
|
botUserImpl := impl.NewBotUserImpl(db)
|
||||||
// 初始化Do业务对象
|
// 初始化Do业务对象
|
||||||
doDo := do.NewDo(sysImpl, taskImpl, chatHisImpl, configConfig)
|
doDo := do.NewDo(sessionImpl, sysImpl, taskImpl, chatHisImpl, configConfig)
|
||||||
// 初始化Ollama客户端
|
// 初始化Ollama客户端
|
||||||
client, _, _ := utils_ollama.NewClient(configConfig)
|
client, _, _ := utils_ollama.NewClient(configConfig)
|
||||||
// 初始化vLLM客户端
|
// 初始化vLLM客户端
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue