feat:会话历史和前端回传数据
This commit is contained in:
parent
328d837250
commit
7ac61893f5
5
go.mod
5
go.mod
|
|
@ -14,6 +14,9 @@ require (
|
|||
github.com/faabiosr/cachego v0.26.0
|
||||
github.com/fastwego/dingding v1.0.0-beta.4
|
||||
github.com/go-kratos/kratos/v2 v2.9.1
|
||||
github.com/go-playground/locales v0.14.1
|
||||
github.com/go-playground/universal-translator v0.18.1
|
||||
github.com/go-playground/validator/v10 v10.20.0
|
||||
github.com/gofiber/fiber/v2 v2.52.9
|
||||
github.com/gofiber/websocket/v2 v2.2.1
|
||||
github.com/google/uuid v1.6.0
|
||||
|
|
@ -53,6 +56,7 @@ require (
|
|||
github.com/evanphx/json-patch v0.5.2 // indirect
|
||||
github.com/fasthttp/websocket v1.5.3 // indirect
|
||||
github.com/fsnotify/fsnotify v1.6.0 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
||||
github.com/go-sql-driver/mysql v1.8.1 // indirect
|
||||
github.com/goph/emperror v0.17.2 // indirect
|
||||
github.com/hashicorp/hcl v1.0.0 // indirect
|
||||
|
|
@ -61,6 +65,7 @@ require (
|
|||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/compress v1.17.9 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.9 // indirect
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
github.com/magiconair/properties v1.8.7 // indirect
|
||||
github.com/mailru/easyjson v0.7.7 // indirect
|
||||
github.com/mattn/go-colorable v0.1.13 // indirect
|
||||
|
|
|
|||
12
go.sum
12
go.sum
|
|
@ -174,6 +174,8 @@ github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMo
|
|||
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
|
||||
github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY=
|
||||
github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
|
||||
github.com/garyburd/redigo v1.6.0/go.mod h1:NR3MbYisc3/PwhQ00EMzDiPmrwpPxAn5GI05/YaO1SY=
|
||||
github.com/getsentry/raven-go v0.2.0/go.mod h1:KungGk8q33+aIAZUIVWZDr2OfAEBsO49PX4NzFV5kcQ=
|
||||
github.com/go-check/check v0.0.0-20180628173108-788fd7840127 h1:0gkP6mzaMqkmpcJYCFOLkIBwI7xFExG03bbkOkCvUPI=
|
||||
|
|
@ -183,6 +185,14 @@ github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2
|
|||
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
|
||||
github.com/go-kratos/kratos/v2 v2.9.1 h1:EGif6/S/aK/RCR5clIbyhioTNyoSrii3FC118jG40Z0=
|
||||
github.com/go-kratos/kratos/v2 v2.9.1/go.mod h1:a1MQLjMhIh7R0kcJS9SzJYR43BRI7EPzzN0J1Ksu2bA=
|
||||
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
||||
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
||||
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
|
||||
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
|
||||
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
|
||||
github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBExVwjEviJTixqxL8=
|
||||
github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
|
||||
github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y=
|
||||
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
|
||||
github.com/gofiber/fiber/v2 v2.52.9 h1:YjKl5DOiyP3j0mO61u3NTmK7or8GzzWzCFzkboyP5cw=
|
||||
|
|
@ -292,6 +302,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
|||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||
github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY=
|
||||
github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
|
||||
github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
|
||||
|
|
|
|||
|
|
@ -1,53 +1,105 @@
|
|||
package biz
|
||||
|
||||
import (
|
||||
errors "ai_scheduler/internal/data/error"
|
||||
"ai_scheduler/internal/data/impl"
|
||||
"ai_scheduler/internal/data/model"
|
||||
"ai_scheduler/internal/entitys"
|
||||
"ai_scheduler/internal/pkg/util"
|
||||
"context"
|
||||
|
||||
"encoding/json"
|
||||
"xorm.io/builder"
|
||||
)
|
||||
|
||||
type ChatHistoryBiz struct {
|
||||
chatHiRepo *impl.ChatHisImpl
|
||||
taskRepo *impl.TaskImpl
|
||||
}
|
||||
|
||||
func NewChatHistoryBiz(chatHiRepo *impl.ChatHisImpl) *ChatHistoryBiz {
|
||||
func NewChatHistoryBiz(chatHiRepo *impl.ChatHisImpl, taskRepo *impl.TaskImpl) *ChatHistoryBiz {
|
||||
s := &ChatHistoryBiz{
|
||||
chatHiRepo: chatHiRepo,
|
||||
taskRepo: taskRepo,
|
||||
}
|
||||
//go s.AsyncProcess(context.Background())
|
||||
return s
|
||||
}
|
||||
|
||||
// 查询会话历史
|
||||
func (s *ChatHistoryBiz) List(ctx context.Context, query *entitys.ChatHistQuery) ([]model.AiChatHi, error) {
|
||||
func (s *ChatHistoryBiz) List(ctx context.Context, query *entitys.ChatHistQuery) ([]entitys.ChatHisQueryResponse, error) {
|
||||
chats, err := s.chatHiRepo.FindAll(
|
||||
s.chatHiRepo.WithSessionId(query.SessionID),
|
||||
s.chatHiRepo.PaginateScope(query.Page, query.PageSize),
|
||||
s.chatHiRepo.OrderByDesc("his_id"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return chats, nil
|
||||
|
||||
taskIds := make([]int32, 0, len(chats))
|
||||
for _, chat := range chats {
|
||||
// 去重任务ID
|
||||
if !util.Contains(taskIds, chat.TaskID) {
|
||||
taskIds = append(taskIds, chat.TaskID)
|
||||
}
|
||||
}
|
||||
|
||||
// 添加会话历史
|
||||
func (s *ChatHistoryBiz) Create(ctx context.Context, chat entitys.ChatHistory) error {
|
||||
return s.chatHiRepo.Create(&model.AiChatHi{
|
||||
SessionID: chat.SessionID,
|
||||
Ques: chat.Role.String(),
|
||||
Ans: chat.Content,
|
||||
})
|
||||
// 查询任务名称
|
||||
tasks, err := s.taskRepo.FindAll(s.taskRepo.In("task_id", taskIds))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
taskMap := make(map[int32]model.AiTask)
|
||||
for _, task := range tasks {
|
||||
taskMap[task.TaskID] = task
|
||||
}
|
||||
|
||||
// 更新会话历史内容
|
||||
func (s *ChatHistoryBiz) UpdateContent(ctx context.Context, chat *entitys.UseFulRequest) error {
|
||||
cond := builder.NewCond()
|
||||
cond = cond.And(builder.Eq{"his_id": chat.HisId})
|
||||
// 构建结果
|
||||
result := make([]entitys.ChatHisQueryResponse, 0, len(chats))
|
||||
for _, chat := range chats {
|
||||
item := entitys.ChatHisQueryResponse{}
|
||||
item.FromModel(chat, taskMap[chat.TaskID])
|
||||
result = append(result, item)
|
||||
}
|
||||
|
||||
return s.chatHiRepo.UpdateByCond(&cond, &model.AiChatHi{HisID: chat.HisId, Useful: chat.Useful})
|
||||
return result, nil
|
||||
}
|
||||
|
||||
//// 添加会话历史
|
||||
//func (s *ChatHistoryBiz) Create(ctx context.Context, chat entitys.ChatHistory) error {
|
||||
// return s.chatHiRepo.Create(&model.AiChatHi{
|
||||
// SessionID: chat.SessionID,
|
||||
// Ques: chat.Role.String(),
|
||||
// Ans: chat.Content,
|
||||
// })
|
||||
//}
|
||||
|
||||
// 更新会话历史内容, 追加内容, 不覆盖原有内容, content 使用json格式存储
|
||||
func (c *ChatHistoryBiz) UpdateContent(ctx context.Context, chat *entitys.UpdateContentRequest) error {
|
||||
var contents []string
|
||||
chatHi, has, err := c.chatHiRepo.FindOne(c.chatHiRepo.WithHisId(chat.HisID))
|
||||
if err != nil {
|
||||
return err
|
||||
} else if !has {
|
||||
return errors.NewBusinessErr(errors.InvalidParamCode, "chat history not found")
|
||||
}
|
||||
|
||||
if "" != chatHi.Content {
|
||||
// 解析历史内容
|
||||
err = json.Unmarshal([]byte(chatHi.Content), &contents)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
contents = append(contents, chat.Content)
|
||||
|
||||
b, err := json.Marshal(contents)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
chatHi.Content = string(b)
|
||||
return c.chatHiRepo.Update(&chatHi,
|
||||
c.chatHiRepo.Select("content"),
|
||||
c.chatHiRepo.WithHisId(chatHi.HisID))
|
||||
}
|
||||
|
||||
// 异步添加会话历史
|
||||
|
|
|
|||
|
|
@ -54,6 +54,8 @@ type BaseRepository[P PO] interface {
|
|||
WithStatus(status int) CondFunc // 查询status
|
||||
GetDb() *gorm.DB // 获取数据库连接
|
||||
WithLimit(limit int) CondFunc // 限制返回条数
|
||||
In(field string, values interface{}) CondFunc // 查询字段是否在列表中
|
||||
Select(fields ...string) CondFunc // 选择字段
|
||||
}
|
||||
|
||||
// PaginationResult 分页查询结果
|
||||
|
|
@ -215,3 +217,17 @@ func (this *BaseModel[P]) WithLimit(limit int) CondFunc {
|
|||
return db.Limit(limit)
|
||||
}
|
||||
}
|
||||
|
||||
// 查询字段是否在列表中
|
||||
func (this *BaseModel[P]) In(field string, values interface{}) CondFunc {
|
||||
return func(db *gorm.DB) *gorm.DB {
|
||||
return db.Where(fmt.Sprintf("%s IN ?", field), values)
|
||||
}
|
||||
}
|
||||
|
||||
// select 字段
|
||||
func (this *BaseModel[P]) Select(fields ...string) CondFunc {
|
||||
return func(db *gorm.DB) *gorm.DB {
|
||||
return db.Select(fields)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -55,3 +55,10 @@ func (impl *ChatHisImpl) AsyncProcess(ctx context.Context) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// his_id 条件:历史ID
|
||||
func (impl *ChatHisImpl) WithHisId(hisId interface{}) CondFunc {
|
||||
return func(db *gorm.DB) *gorm.DB {
|
||||
return db.Where("his_id = ?", hisId)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,8 +8,12 @@ import (
|
|||
|
||||
type TaskImpl struct {
|
||||
dataTemp.DataTemp
|
||||
BaseRepository[model.AiTask]
|
||||
}
|
||||
|
||||
func NewTaskImpl(db *utils.Db) *TaskImpl {
|
||||
return &TaskImpl{*dataTemp.NewDataTemp(db, new(model.AiTask))}
|
||||
return &TaskImpl{
|
||||
DataTemp: *dataTemp.NewDataTemp(db, new(model.AiTask)),
|
||||
BaseRepository: NewBaseModel[model.AiTask](db.Client),
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ type AiChatHi struct {
|
|||
CreateAt time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;default:CURRENT_TIMESTAMP" json:"updated_at"`
|
||||
TaskID int32 `gorm:"column:task_id;not null" json:"task_id"` // 任务ID
|
||||
Content string `gorm:"column:content" json:"content"` // 前端回传数据
|
||||
}
|
||||
|
||||
// TableName AiChatHi's table name
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@ package entitys
|
|||
|
||||
import (
|
||||
"ai_scheduler/internal/data/constants"
|
||||
"ai_scheduler/internal/data/model"
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type ChatHistory struct {
|
||||
|
|
@ -20,3 +22,43 @@ type ChatHistQuery struct {
|
|||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
}
|
||||
|
||||
type ChatHisQueryResponse struct {
|
||||
HisID int64 `gorm:"column:his_id;primaryKey;autoIncrement:true" json:"his_id"`
|
||||
SessionID string `gorm:"column:session_id;not null" json:"session_id"`
|
||||
Ques string `gorm:"column:ques;not null" json:"ques"`
|
||||
Ans string `gorm:"column:ans;not null" json:"ans"`
|
||||
Files string `gorm:"column:files;not null" json:"files"`
|
||||
Useful int32 `gorm:"column:useful;not null;comment:0不评价,1有用,其他为无用" json:"useful"` // 0不评价,1有用,其他为无用
|
||||
CreateAt string `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"`
|
||||
TaskID int32 `gorm:"column:task_id;not null" json:"task_id"` // 任务ID
|
||||
TaskName string `gorm:"column:task_name;not null" json:"task_name"` // 任务名称
|
||||
Content []string `gorm:"column:content" json:"content"` // 前端回传数据
|
||||
}
|
||||
|
||||
func (c *ChatHisQueryResponse) FromModel(chat model.AiChatHi, task model.AiTask) {
|
||||
c.HisID = chat.HisID
|
||||
c.SessionID = chat.SessionID
|
||||
c.Ques = chat.Ques
|
||||
c.Ans = chat.Ans
|
||||
c.Files = chat.Files
|
||||
c.Useful = chat.Useful
|
||||
c.CreateAt = chat.CreateAt.Format("2006-01-02 15:04:05")
|
||||
c.TaskID = chat.TaskID
|
||||
c.TaskName = task.Name
|
||||
c.Content = make([]string, 0)
|
||||
|
||||
// 解析Content
|
||||
if "" != chat.Content {
|
||||
var contents []string
|
||||
if err := json.Unmarshal([]byte(chat.Content), &contents); err != nil {
|
||||
c.Content = contents
|
||||
}
|
||||
c.Content = append(c.Content, chat.Content)
|
||||
}
|
||||
}
|
||||
|
||||
type UpdateContentRequest struct {
|
||||
HisID int64 `json:"his_id" validate:"required"`
|
||||
Content string `json:"content" validate:"required"`
|
||||
}
|
||||
|
|
|
|||
|
|
@ -32,3 +32,13 @@ func StringToFloat64(s string) float64 {
|
|||
i, _ := strconv.ParseFloat(s, 64)
|
||||
return i
|
||||
}
|
||||
|
||||
// 是否包含在数组中
|
||||
func Contains[T comparable](strings []T, str T) bool {
|
||||
for _, s := range strings {
|
||||
if s == str {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,45 @@
|
|||
package validate
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/go-playground/locales/zh"
|
||||
ut "github.com/go-playground/universal-translator"
|
||||
"github.com/go-playground/validator/v10"
|
||||
zh_translations "github.com/go-playground/validator/v10/translations/zh"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
func Struct(s interface{}) (errMsg []string, err error) {
|
||||
// 创建验证器实例
|
||||
validate := validator.New()
|
||||
|
||||
// 创建中文翻译器
|
||||
zh_ch := zh.New()
|
||||
uni := ut.New(zh_ch, zh_ch)
|
||||
trans, _ := uni.GetTranslator("zh")
|
||||
|
||||
//注册一个函数,获取struct tag里自定义的label作为字段名
|
||||
validate.RegisterTagNameFunc(func(fld reflect.StructField) string {
|
||||
name := fld.Tag.Get("label")
|
||||
return name
|
||||
})
|
||||
|
||||
// 注册中文翻译器到验证器
|
||||
_ = zh_translations.RegisterDefaultTranslations(validate, trans)
|
||||
|
||||
// 验证结构体
|
||||
err = validate.Struct(s)
|
||||
if err != nil {
|
||||
// 处理验证错误
|
||||
if _, ok := err.(*validator.InvalidValidationError); ok {
|
||||
fmt.Println("处理验证错误error:", err)
|
||||
errMsg = append(errMsg, err.Error())
|
||||
} else {
|
||||
for _, v := range err.(validator.ValidationErrors) {
|
||||
errMsg = append(errMsg, v.Translate(trans))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
|
@ -82,7 +82,8 @@ func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionServi
|
|||
})
|
||||
|
||||
// 会话历史
|
||||
r.Post("/chat/history", chatHist.GetHistory)
|
||||
r.Post("/chat/history/list", chatHist.List)
|
||||
r.Post("/chat/history/update/content", chatHist.UpdateContent)
|
||||
}
|
||||
|
||||
func routerSocket(app *fiber.App, chatService *services.ChatService) {
|
||||
|
|
|
|||
|
|
@ -4,7 +4,10 @@ import (
|
|||
"ai_scheduler/internal/biz"
|
||||
errors "ai_scheduler/internal/data/error"
|
||||
"ai_scheduler/internal/entitys"
|
||||
"ai_scheduler/internal/pkg/validate"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/log"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type HistoryService struct {
|
||||
|
|
@ -18,7 +21,7 @@ func NewHistoryService(chatRepo *biz.ChatHistoryBiz) *HistoryService {
|
|||
}
|
||||
|
||||
// GetHistoryService 获取会话历史
|
||||
func (h *HistoryService) GetHistory(c *fiber.Ctx) error {
|
||||
func (h *HistoryService) List(c *fiber.Ctx) error {
|
||||
var query entitys.ChatHistQuery
|
||||
if err := c.BodyParser(&query); err != nil {
|
||||
return err
|
||||
|
|
@ -42,3 +45,19 @@ func (h *HistoryService) GetHistory(c *fiber.Ctx) error {
|
|||
|
||||
return c.JSON(history)
|
||||
}
|
||||
|
||||
func (h *HistoryService) UpdateContent(c *fiber.Ctx) error {
|
||||
var req entitys.UpdateContentRequest
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return err
|
||||
}
|
||||
// 校验参数
|
||||
msg, err := validate.Struct(req)
|
||||
if err != nil {
|
||||
log.Error(c.UserContext(), "参数错误 error: ", err)
|
||||
return errors.NewBusinessErr(errors.InvalidParamCode, strings.Join(msg, ";"))
|
||||
}
|
||||
|
||||
// 更新历史
|
||||
return h.chatRepo.UpdateContent(c.Context(), &req)
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue