diff --git a/internal/biz/session.go b/internal/biz/session.go new file mode 100644 index 0000000..247c4d9 --- /dev/null +++ b/internal/biz/session.go @@ -0,0 +1,85 @@ +package biz + +import ( + "ai_scheduler/internal/data/impl" + "ai_scheduler/internal/data/model" + "ai_scheduler/internal/entitys" + "context" + "fmt" + "github.com/gofiber/fiber/v2/utils" + "time" + + "ai_scheduler/internal/config" +) + +type SessionBiz struct { + sessionRepo *impl.SessionImpl + sysRepo *impl.SysImpl + + conf *config.Config +} + +func NewSessionBiz(conf *config.Config, sessionImpl *impl.SessionImpl, sysImpl *impl.SysImpl) *SessionBiz { + return &SessionBiz{ + sessionRepo: sessionImpl, + sysRepo: sysImpl, + conf: conf, + } +} + +// InitSession 初始化会话 ,当天存在则返回会话,如果不存在则创建一个 +func (s *SessionBiz) SessionInit(ctx context.Context, req *entitys.SessionInitRequest) (sessionId string, err error) { + + // 获取系统配置 + sysConfig, has, err := s.sysRepo.FindOne(s.sysRepo.WithSysId(req.SysId)) + if err != nil { + return "", err + } else if !has { + return "", fmt.Errorf("sys not found") + } + + // 获取 当天的session + t := time.Now().Truncate(24 * time.Hour) + session, has, err := s.sessionRepo.FindOne( + s.sessionRepo.WithUserId(req.UserId), // 条件:用户ID + s.sessionRepo.WithStartTime(t), // 条件:会话开始时间 + s.sysRepo.WithSysId(sysConfig.SysID), // 条件:系统ID + ) + if err != nil { + return "", err + } else if !has { + // 不存在,创建一个 + session = model.AiSession{ + SysID: sysConfig.SysID, + SessionID: utils.UUID(), + CreateAt: time.Now(), + UpdateAt: time.Now(), + } + err = s.sessionRepo.Create(&session) + if err != nil { + return "", err + } + } + + return session.SessionID, nil +} + +// SessionList 会话列表 +func (s *SessionBiz) SessionList(ctx context.Context, req *entitys.SessionListRequest) (list []model.AiSession, err error) { + + if req.Page <= 0 { + req.Page = 1 + } + if req.PageSize <= 0 { + req.PageSize = 10 + } + + list, err = s.sessionRepo.FindAll( + s.sessionRepo.WithUserId(req.UserId), // 条件:用户ID + s.sessionRepo.WithSysId(req.SysId), // 条件:系统ID + s.sessionRepo.PaginateScope(req.Page, req.PageSize), // 分页 + s.sessionRepo.OrderByDesc("create_at"), // 排序:按创建时间降序 + ) + + return +} diff --git a/internal/data/impl/base.go b/internal/data/impl/base.go new file mode 100644 index 0000000..bed5281 --- /dev/null +++ b/internal/data/impl/base.go @@ -0,0 +1,208 @@ +package impl + +import ( + "ai_scheduler/internal/data/model" + "errors" + "fmt" + "gorm.io/gorm" +) + +var ( + ErrNoConditions = errors.New("不允许不带条件的操作") + ErrInvalidPage = errors.New("page和pageSize必须为正整数") +) + +/* +GORM 基础模型 +BaseModel 是一个泛型结构体,用于封装GORM数据库通用操作。 +支持的PO类型需实现具体的数据库模型(如ZxCreditOrder、ZxCreditLog等) +*/ + +// 定义受支持的PO类型集合(可根据需要扩展), 只有包含表结构才能使用BaseModel,避免使用出现问题 +type PO interface { + model.AiSy | model.AiSession | model.AiTask +} + +type BaseModel[P PO] struct { + Db *gorm.DB // 数据库连接 +} + +func NewBaseModel[P PO](db *gorm.DB) BaseRepository[P] { + return &BaseModel[P]{ + Db: db, + } +} + +// 查询条件函数类型,支持链式条件组装, 由外部调用者自定义组装 +type CondFunc = func(db *gorm.DB) *gorm.DB + +// 定义通用数据库操作接口 +type BaseRepository[P PO] interface { + FindAll(conditions ...CondFunc) ([]P, error) // 查询所有 + Paginate(page, pageSize int, conditions ...CondFunc) (*PaginationResult[P], error) // 分页查询 + FindOne(conditions ...CondFunc) (P, bool, error) // 查询单条记录,若未找到则返回 has=false, err=nil + Create(m *P) error // 创建 + BatchCreate(m *[]P) (err error) // 批量创建 + Update(m *P, conditions ...CondFunc) (err error) // 更新 + Delete(conditions ...CondFunc) (err error) // 删除 + Count(conditions ...CondFunc) (count int64, err error) // 查询条数 + PaginateScope(page, pageSize int) CondFunc // 分页 + OrderByDesc(field string) CondFunc // 倒序排序 + WithId(id interface{}) CondFunc // 查询id + WithStatus(status int) CondFunc // 查询status + GetDb() *gorm.DB // 获取数据库连接 +} + +// PaginationResult 分页查询结果 +type PaginationResult[P any] struct { + List []P `json:"list"` // 数据列表 + Total int64 `json:"total"` // 总记录数 + Page int `json:"page"` // 当前页码(从1开始) + PageSize int `json:"page_size"` // 每页数量 +} + +// 分页查询实现 +func (this *BaseModel[P]) Paginate(page, pageSize int, conditions ...CondFunc) (*PaginationResult[P], error) { + if page < 1 || pageSize < 1 { + return nil, ErrInvalidPage + } + + db := this.Db.Model(new(P)).Scopes(conditions...) // 自动绑定泛型类型对应的表 + + var ( + total int64 + list []P + ) + + // 先统计总数 + if err := db.Count(&total).Error; err != nil { + return nil, fmt.Errorf("统计总数失败: %w", err) + } + + // 再查询分页数据 + if err := db.Scopes(this.PaginateScope(page, pageSize)).Find(&list).Error; err != nil { + return nil, fmt.Errorf("分页查询失败: %w", err) + } + + return &PaginationResult[P]{ + List: list, + Total: total, + Page: page, + PageSize: pageSize, + }, nil +} + +// 查询所有 +func (this *BaseModel[P]) FindAll(conditions ...CondFunc) ([]P, error) { + var list []P + err := this.Db.Model(new(P)). + Scopes(conditions...). + Find(&list).Error + return list, err +} + +// 查询单条记录(优先返回第一条) +func (this *BaseModel[P]) FindOne(conditions ...CondFunc) (P, bool, error) { + var ( + result P + ) + + err := this.Db.Model(new(P)). + Scopes(conditions...). + First(&result). + Error + + if errors.Is(err, gorm.ErrRecordNotFound) { + return result, false, nil // 未找到记录 + } + if err != nil { + return result, false, fmt.Errorf("查询单条记录失败: %w", err) + } + + return result, true, err +} + +// 创建 +func (this *BaseModel[P]) Create(m *P) error { + err := this.Db.Create(m).Error + return err +} + +// 批量创建 +func (this *BaseModel[P]) BatchCreate(m *[]P) (err error) { + //使用 CreateInBatches 方法分批插入用户,批次大小为 100 + err = this.Db.CreateInBatches(m, 100).Error + return err +} + +// 按条件更新记录(必须带条件) +func (this *BaseModel[P]) Update(m *P, conditions ...CondFunc) (err error) { + // 没有更新条件 + if len(conditions) == 0 { + return ErrNoConditions + } + + return this.Db.Model(new(P)). + Scopes(conditions...). + Updates(m). + Error +} + +// 按条件删除记录(必须带条件) +func (this *BaseModel[P]) Delete(conditions ...CondFunc) (err error) { + // 没有更新条件 + if len(conditions) == 0 { + return ErrNoConditions + } + + return this.Db.Model(new(P)). + Scopes(conditions...). + Delete(new(P)). + Error +} + +// 统计符合条件的记录数 +func (this *BaseModel[P]) Count(conditions ...CondFunc) (count int64, err error) { + err = this.Db.Model(new(P)). + Scopes(conditions...). + Count(&count). + Error + return count, err +} + +// 分页条件生成器 +func (this *BaseModel[P]) PaginateScope(page, pageSize int) CondFunc { + return func(db *gorm.DB) *gorm.DB { + if page < 1 || pageSize < 1 { + return db // 由上层方法校验参数有效性 + } + offset := (page - 1) * pageSize + return db.Offset(offset).Limit(pageSize) + } +} + +// 倒序排序条件生成器 +func (this *BaseModel[P]) OrderByDesc(field string) CondFunc { + return func(db *gorm.DB) *gorm.DB { + return db.Order(fmt.Sprintf("%s DESC", field)) + } +} + +// ID查询条件生成器 +func (this *BaseModel[P]) WithId(id interface{}) CondFunc { + return func(db *gorm.DB) *gorm.DB { + return db.Where("id = ?", id) + } +} + +// 状态查询条件生成器 +func (this *BaseModel[P]) WithStatus(status int) CondFunc { + return func(db *gorm.DB) *gorm.DB { + return db.Where("status =?", status) + } +} + +// 获取数据链接 +func (this *BaseModel[P]) GetDb() *gorm.DB { + return this.Db +} diff --git a/internal/data/impl/session_impl.go b/internal/data/impl/session_impl.go index f0ff5e2..4b79c7b 100644 --- a/internal/data/impl/session_impl.go +++ b/internal/data/impl/session_impl.go @@ -4,12 +4,39 @@ import ( "ai_scheduler/internal/data/model" "ai_scheduler/tmpl/dataTemp" "ai_scheduler/utils" + "gorm.io/gorm" + "time" ) type SessionImpl struct { dataTemp.DataTemp + BaseModel[model.AiSession] } func NewSessionImpl(db *utils.Db) *SessionImpl { - return &SessionImpl{*dataTemp.NewDataTemp(db, new(model.AiSession))} + return &SessionImpl{ + DataTemp: *dataTemp.NewDataTemp(db, new(model.AiSession)), + BaseModel: BaseModel[model.AiSession]{}, + } +} + +// WithUserId 条件:用户ID +func (impl *SessionImpl) WithUserId(userId interface{}) CondFunc { + return func(db *gorm.DB) *gorm.DB { + return db.Where("user_id = ?", userId) + } +} + +// WithStartTime 条件:会话开始时间 +func (impl *SessionImpl) WithStartTime(startTime time.Time) CondFunc { + return func(db *gorm.DB) *gorm.DB { + return db.Where("create_at >= ?", startTime) + } +} + +// WithSysId 系统id +func (s *SessionImpl) WithSysId(sysId interface{}) CondFunc { + return func(db *gorm.DB) *gorm.DB { + return db.Where("sys_id = ?", sysId) + } } diff --git a/internal/data/impl/sys_impl.go b/internal/data/impl/sys_impl.go index 4370f33..7252657 100644 --- a/internal/data/impl/sys_impl.go +++ b/internal/data/impl/sys_impl.go @@ -4,12 +4,24 @@ import ( "ai_scheduler/internal/data/model" "ai_scheduler/tmpl/dataTemp" "ai_scheduler/utils" + "gorm.io/gorm" ) type SysImpl struct { dataTemp.DataTemp + BaseModel[model.AiSy] } func NewSysImpl(db *utils.Db) *SysImpl { - return &SysImpl{*dataTemp.NewDataTemp(db, new(model.AiSy))} + return &SysImpl{ + DataTemp: *dataTemp.NewDataTemp(db, new(model.AiSy)), + BaseModel: BaseModel[model.AiSy]{}, + } +} + +// WithSysId 系统id +func (s *SysImpl) WithSysId(sysId interface{}) CondFunc { + return func(db *gorm.DB) *gorm.DB { + return db.Where("sys_id = ?", sysId) + } } diff --git a/internal/entitys/session.go b/internal/entitys/session.go new file mode 100644 index 0000000..1ceb514 --- /dev/null +++ b/internal/entitys/session.go @@ -0,0 +1,13 @@ +package entitys + +type SessionInitRequest struct { + SysId string `json:"sys_id"` + UserId string `json:"user_id"` +} + +type SessionListRequest struct { + SysId string `json:"sys_id"` + UserId string `json:"user_id"` + Page int `json:"page"` + PageSize int `json:"page_size"` +} diff --git a/internal/server/http.go b/internal/server/http.go index 65e73a1..19eb931 100644 --- a/internal/server/http.go +++ b/internal/server/http.go @@ -11,10 +11,11 @@ import ( func NewHTTPServer( service *services.ChatService, + session *services.SessionService, ) *fiber.App { //构建 server app := initRoute() - router.SetupRoutes(app, service) + router.SetupRoutes(app, service, session) return app } diff --git a/internal/server/router/router.go b/internal/server/router/router.go index 13de910..0b3d10a 100644 --- a/internal/server/router/router.go +++ b/internal/server/router/router.go @@ -13,7 +13,7 @@ import ( ) // SetupRoutes 设置路由 -func SetupRoutes(app *fiber.App, ChatService *services.ChatService) { +func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionService *services.SessionService) { app.Use(func(c *fiber.Ctx) error { // 设置 CORS 头 c.Set("Access-Control-Allow-Origin", "*") @@ -28,7 +28,7 @@ func SetupRoutes(app *fiber.App, ChatService *services.ChatService) { // 继续处理后续中间件或路由 return c.Next() }) - routerHttp(app) + routerHttp(app, sessionService) routerSocket(app, ChatService) } @@ -56,7 +56,7 @@ var bufferPool = &sync.Pool{ }, } -func routerHttp(app *fiber.App) { +func routerHttp(app *fiber.App, sessionService *services.SessionService) { r := app.Group("api/v1/") registerResponse(r) // 注册 CORS 中间件 @@ -64,6 +64,10 @@ func routerHttp(app *fiber.App) { c.Response().SetBody([]byte("1")) return nil }) + + r.Post("/session/init", sessionService.SessionInit) + r.Post("/session/list", sessionService.SessionList) + } func registerResponse(router fiber.Router) { diff --git a/internal/services/base.go b/internal/services/base.go new file mode 100644 index 0000000..804d255 --- /dev/null +++ b/internal/services/base.go @@ -0,0 +1,33 @@ +package services + +import ( + errorcode "ai_scheduler/internal/data/error" + "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/log" +) + +// 响应数据 +func handRes(c *fiber.Ctx, _err error, rsp interface{}) error { + var ( + err *errorcode.BusinessErr + ) + if _err == nil { + err = errorcode.Success + } else { + if e, ok := _err.(*errorcode.BusinessErr); ok { + err = e + } else { + log.Error(c.UserContext(), "系统错误 error: ", _err) + err = errorcode.NewBusinessErr("500", _err.Error()) + } + } + + body := fiber.Map{ + "code": err.Code, + "msg": err.Error(), + "data": rsp, + } + + log.Info(c.UserContext(), c.Path(), "请求参数=", c.BodyRaw(), "响应=", body) + return c.JSON(body) +} diff --git a/internal/services/session.go b/internal/services/session.go new file mode 100644 index 0000000..44139e0 --- /dev/null +++ b/internal/services/session.go @@ -0,0 +1,45 @@ +package services + +import ( + "ai_scheduler/internal/biz" + "ai_scheduler/internal/entitys" + "github.com/gofiber/fiber/v2" +) + +type SessionService struct { + sessionBiz *biz.SessionBiz +} + +func NewSession(sessionBiz *biz.SessionBiz) *SessionService { + return &SessionService{ + sessionBiz: sessionBiz, + } +} + +// SessionInit 初始化会话 +func (s *SessionService) SessionInit(c *fiber.Ctx) error { + req := &entitys.SessionInitRequest{} + if err := c.BodyParser(req); err != nil { + return err + } + + sessionId, err := s.sessionBiz.SessionInit(c.Context(), req) + + return handRes(c, err, fiber.Map{ + "session_id": sessionId, + }) +} + +// SessionList 获取会话列表 +func (s *SessionService) SessionList(c *fiber.Ctx) error { + req := &entitys.SessionListRequest{} + if err := c.BodyParser(req); err != nil { + return err + } + + sessionList, err := s.sessionBiz.SessionList(c.Context(), req) + + return handRes(c, err, fiber.Map{ + "session_list": sessionList, + }) +}