Merge branch 'refs/heads/feature/session' into feature/dev
This commit is contained in:
commit
11241af84c
|
@ -0,0 +1,85 @@
|
|||
package biz
|
||||
|
||||
import (
|
||||
"ai_scheduler/internal/data/impl"
|
||||
"ai_scheduler/internal/data/model"
|
||||
"ai_scheduler/internal/entitys"
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
"time"
|
||||
|
||||
"ai_scheduler/internal/config"
|
||||
)
|
||||
|
||||
type SessionBiz struct {
|
||||
sessionRepo *impl.SessionImpl
|
||||
sysRepo *impl.SysImpl
|
||||
|
||||
conf *config.Config
|
||||
}
|
||||
|
||||
func NewSessionBiz(conf *config.Config, sessionImpl *impl.SessionImpl, sysImpl *impl.SysImpl) *SessionBiz {
|
||||
return &SessionBiz{
|
||||
sessionRepo: sessionImpl,
|
||||
sysRepo: sysImpl,
|
||||
conf: conf,
|
||||
}
|
||||
}
|
||||
|
||||
// InitSession 初始化会话 ,当天存在则返回会话,如果不存在则创建一个
|
||||
func (s *SessionBiz) SessionInit(ctx context.Context, req *entitys.SessionInitRequest) (sessionId string, err error) {
|
||||
|
||||
// 获取系统配置
|
||||
sysConfig, has, err := s.sysRepo.FindOne(s.sysRepo.WithSysId(req.SysId))
|
||||
if err != nil {
|
||||
return "", err
|
||||
} else if !has {
|
||||
return "", fmt.Errorf("sys not found")
|
||||
}
|
||||
|
||||
// 获取 当天的session
|
||||
t := time.Now().Truncate(24 * time.Hour)
|
||||
session, has, err := s.sessionRepo.FindOne(
|
||||
s.sessionRepo.WithUserId(req.UserId), // 条件:用户ID
|
||||
s.sessionRepo.WithStartTime(t), // 条件:会话开始时间
|
||||
s.sysRepo.WithSysId(sysConfig.SysID), // 条件:系统ID
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
} else if !has {
|
||||
// 不存在,创建一个
|
||||
session = model.AiSession{
|
||||
SysID: sysConfig.SysID,
|
||||
SessionID: utils.UUID(),
|
||||
CreateAt: time.Now(),
|
||||
UpdateAt: time.Now(),
|
||||
}
|
||||
err = s.sessionRepo.Create(&session)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
return session.SessionID, nil
|
||||
}
|
||||
|
||||
// SessionList 会话列表
|
||||
func (s *SessionBiz) SessionList(ctx context.Context, req *entitys.SessionListRequest) (list []model.AiSession, err error) {
|
||||
|
||||
if req.Page <= 0 {
|
||||
req.Page = 1
|
||||
}
|
||||
if req.PageSize <= 0 {
|
||||
req.PageSize = 10
|
||||
}
|
||||
|
||||
list, err = s.sessionRepo.FindAll(
|
||||
s.sessionRepo.WithUserId(req.UserId), // 条件:用户ID
|
||||
s.sessionRepo.WithSysId(req.SysId), // 条件:系统ID
|
||||
s.sessionRepo.PaginateScope(req.Page, req.PageSize), // 分页
|
||||
s.sessionRepo.OrderByDesc("create_at"), // 排序:按创建时间降序
|
||||
)
|
||||
|
||||
return
|
||||
}
|
|
@ -0,0 +1,208 @@
|
|||
package impl
|
||||
|
||||
import (
|
||||
"ai_scheduler/internal/data/model"
|
||||
"errors"
|
||||
"fmt"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNoConditions = errors.New("不允许不带条件的操作")
|
||||
ErrInvalidPage = errors.New("page和pageSize必须为正整数")
|
||||
)
|
||||
|
||||
/*
|
||||
GORM 基础模型
|
||||
BaseModel 是一个泛型结构体,用于封装GORM数据库通用操作。
|
||||
支持的PO类型需实现具体的数据库模型(如ZxCreditOrder、ZxCreditLog等)
|
||||
*/
|
||||
|
||||
// 定义受支持的PO类型集合(可根据需要扩展), 只有包含表结构才能使用BaseModel,避免使用出现问题
|
||||
type PO interface {
|
||||
model.AiSy | model.AiSession | model.AiTask
|
||||
}
|
||||
|
||||
type BaseModel[P PO] struct {
|
||||
Db *gorm.DB // 数据库连接
|
||||
}
|
||||
|
||||
func NewBaseModel[P PO](db *gorm.DB) BaseRepository[P] {
|
||||
return &BaseModel[P]{
|
||||
Db: db,
|
||||
}
|
||||
}
|
||||
|
||||
// 查询条件函数类型,支持链式条件组装, 由外部调用者自定义组装
|
||||
type CondFunc = func(db *gorm.DB) *gorm.DB
|
||||
|
||||
// 定义通用数据库操作接口
|
||||
type BaseRepository[P PO] interface {
|
||||
FindAll(conditions ...CondFunc) ([]P, error) // 查询所有
|
||||
Paginate(page, pageSize int, conditions ...CondFunc) (*PaginationResult[P], error) // 分页查询
|
||||
FindOne(conditions ...CondFunc) (P, bool, error) // 查询单条记录,若未找到则返回 has=false, err=nil
|
||||
Create(m *P) error // 创建
|
||||
BatchCreate(m *[]P) (err error) // 批量创建
|
||||
Update(m *P, conditions ...CondFunc) (err error) // 更新
|
||||
Delete(conditions ...CondFunc) (err error) // 删除
|
||||
Count(conditions ...CondFunc) (count int64, err error) // 查询条数
|
||||
PaginateScope(page, pageSize int) CondFunc // 分页
|
||||
OrderByDesc(field string) CondFunc // 倒序排序
|
||||
WithId(id interface{}) CondFunc // 查询id
|
||||
WithStatus(status int) CondFunc // 查询status
|
||||
GetDb() *gorm.DB // 获取数据库连接
|
||||
}
|
||||
|
||||
// PaginationResult 分页查询结果
|
||||
type PaginationResult[P any] struct {
|
||||
List []P `json:"list"` // 数据列表
|
||||
Total int64 `json:"total"` // 总记录数
|
||||
Page int `json:"page"` // 当前页码(从1开始)
|
||||
PageSize int `json:"page_size"` // 每页数量
|
||||
}
|
||||
|
||||
// 分页查询实现
|
||||
func (this *BaseModel[P]) Paginate(page, pageSize int, conditions ...CondFunc) (*PaginationResult[P], error) {
|
||||
if page < 1 || pageSize < 1 {
|
||||
return nil, ErrInvalidPage
|
||||
}
|
||||
|
||||
db := this.Db.Model(new(P)).Scopes(conditions...) // 自动绑定泛型类型对应的表
|
||||
|
||||
var (
|
||||
total int64
|
||||
list []P
|
||||
)
|
||||
|
||||
// 先统计总数
|
||||
if err := db.Count(&total).Error; err != nil {
|
||||
return nil, fmt.Errorf("统计总数失败: %w", err)
|
||||
}
|
||||
|
||||
// 再查询分页数据
|
||||
if err := db.Scopes(this.PaginateScope(page, pageSize)).Find(&list).Error; err != nil {
|
||||
return nil, fmt.Errorf("分页查询失败: %w", err)
|
||||
}
|
||||
|
||||
return &PaginationResult[P]{
|
||||
List: list,
|
||||
Total: total,
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 查询所有
|
||||
func (this *BaseModel[P]) FindAll(conditions ...CondFunc) ([]P, error) {
|
||||
var list []P
|
||||
err := this.Db.Model(new(P)).
|
||||
Scopes(conditions...).
|
||||
Find(&list).Error
|
||||
return list, err
|
||||
}
|
||||
|
||||
// 查询单条记录(优先返回第一条)
|
||||
func (this *BaseModel[P]) FindOne(conditions ...CondFunc) (P, bool, error) {
|
||||
var (
|
||||
result P
|
||||
)
|
||||
|
||||
err := this.Db.Model(new(P)).
|
||||
Scopes(conditions...).
|
||||
First(&result).
|
||||
Error
|
||||
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return result, false, nil // 未找到记录
|
||||
}
|
||||
if err != nil {
|
||||
return result, false, fmt.Errorf("查询单条记录失败: %w", err)
|
||||
}
|
||||
|
||||
return result, true, err
|
||||
}
|
||||
|
||||
// 创建
|
||||
func (this *BaseModel[P]) Create(m *P) error {
|
||||
err := this.Db.Create(m).Error
|
||||
return err
|
||||
}
|
||||
|
||||
// 批量创建
|
||||
func (this *BaseModel[P]) BatchCreate(m *[]P) (err error) {
|
||||
//使用 CreateInBatches 方法分批插入用户,批次大小为 100
|
||||
err = this.Db.CreateInBatches(m, 100).Error
|
||||
return err
|
||||
}
|
||||
|
||||
// 按条件更新记录(必须带条件)
|
||||
func (this *BaseModel[P]) Update(m *P, conditions ...CondFunc) (err error) {
|
||||
// 没有更新条件
|
||||
if len(conditions) == 0 {
|
||||
return ErrNoConditions
|
||||
}
|
||||
|
||||
return this.Db.Model(new(P)).
|
||||
Scopes(conditions...).
|
||||
Updates(m).
|
||||
Error
|
||||
}
|
||||
|
||||
// 按条件删除记录(必须带条件)
|
||||
func (this *BaseModel[P]) Delete(conditions ...CondFunc) (err error) {
|
||||
// 没有更新条件
|
||||
if len(conditions) == 0 {
|
||||
return ErrNoConditions
|
||||
}
|
||||
|
||||
return this.Db.Model(new(P)).
|
||||
Scopes(conditions...).
|
||||
Delete(new(P)).
|
||||
Error
|
||||
}
|
||||
|
||||
// 统计符合条件的记录数
|
||||
func (this *BaseModel[P]) Count(conditions ...CondFunc) (count int64, err error) {
|
||||
err = this.Db.Model(new(P)).
|
||||
Scopes(conditions...).
|
||||
Count(&count).
|
||||
Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// 分页条件生成器
|
||||
func (this *BaseModel[P]) PaginateScope(page, pageSize int) CondFunc {
|
||||
return func(db *gorm.DB) *gorm.DB {
|
||||
if page < 1 || pageSize < 1 {
|
||||
return db // 由上层方法校验参数有效性
|
||||
}
|
||||
offset := (page - 1) * pageSize
|
||||
return db.Offset(offset).Limit(pageSize)
|
||||
}
|
||||
}
|
||||
|
||||
// 倒序排序条件生成器
|
||||
func (this *BaseModel[P]) OrderByDesc(field string) CondFunc {
|
||||
return func(db *gorm.DB) *gorm.DB {
|
||||
return db.Order(fmt.Sprintf("%s DESC", field))
|
||||
}
|
||||
}
|
||||
|
||||
// ID查询条件生成器
|
||||
func (this *BaseModel[P]) WithId(id interface{}) CondFunc {
|
||||
return func(db *gorm.DB) *gorm.DB {
|
||||
return db.Where("id = ?", id)
|
||||
}
|
||||
}
|
||||
|
||||
// 状态查询条件生成器
|
||||
func (this *BaseModel[P]) WithStatus(status int) CondFunc {
|
||||
return func(db *gorm.DB) *gorm.DB {
|
||||
return db.Where("status =?", status)
|
||||
}
|
||||
}
|
||||
|
||||
// 获取数据链接
|
||||
func (this *BaseModel[P]) GetDb() *gorm.DB {
|
||||
return this.Db
|
||||
}
|
|
@ -4,12 +4,39 @@ import (
|
|||
"ai_scheduler/internal/data/model"
|
||||
"ai_scheduler/tmpl/dataTemp"
|
||||
"ai_scheduler/utils"
|
||||
"gorm.io/gorm"
|
||||
"time"
|
||||
)
|
||||
|
||||
type SessionImpl struct {
|
||||
dataTemp.DataTemp
|
||||
BaseModel[model.AiSession]
|
||||
}
|
||||
|
||||
func NewSessionImpl(db *utils.Db) *SessionImpl {
|
||||
return &SessionImpl{*dataTemp.NewDataTemp(db, new(model.AiSession))}
|
||||
return &SessionImpl{
|
||||
DataTemp: *dataTemp.NewDataTemp(db, new(model.AiSession)),
|
||||
BaseModel: BaseModel[model.AiSession]{},
|
||||
}
|
||||
}
|
||||
|
||||
// WithUserId 条件:用户ID
|
||||
func (impl *SessionImpl) WithUserId(userId interface{}) CondFunc {
|
||||
return func(db *gorm.DB) *gorm.DB {
|
||||
return db.Where("user_id = ?", userId)
|
||||
}
|
||||
}
|
||||
|
||||
// WithStartTime 条件:会话开始时间
|
||||
func (impl *SessionImpl) WithStartTime(startTime time.Time) CondFunc {
|
||||
return func(db *gorm.DB) *gorm.DB {
|
||||
return db.Where("create_at >= ?", startTime)
|
||||
}
|
||||
}
|
||||
|
||||
// WithSysId 系统id
|
||||
func (s *SessionImpl) WithSysId(sysId interface{}) CondFunc {
|
||||
return func(db *gorm.DB) *gorm.DB {
|
||||
return db.Where("sys_id = ?", sysId)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,12 +4,24 @@ import (
|
|||
"ai_scheduler/internal/data/model"
|
||||
"ai_scheduler/tmpl/dataTemp"
|
||||
"ai_scheduler/utils"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type SysImpl struct {
|
||||
dataTemp.DataTemp
|
||||
BaseModel[model.AiSy]
|
||||
}
|
||||
|
||||
func NewSysImpl(db *utils.Db) *SysImpl {
|
||||
return &SysImpl{*dataTemp.NewDataTemp(db, new(model.AiSy))}
|
||||
return &SysImpl{
|
||||
DataTemp: *dataTemp.NewDataTemp(db, new(model.AiSy)),
|
||||
BaseModel: BaseModel[model.AiSy]{},
|
||||
}
|
||||
}
|
||||
|
||||
// WithSysId 系统id
|
||||
func (s *SysImpl) WithSysId(sysId interface{}) CondFunc {
|
||||
return func(db *gorm.DB) *gorm.DB {
|
||||
return db.Where("sys_id = ?", sysId)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,13 @@
|
|||
package entitys
|
||||
|
||||
type SessionInitRequest struct {
|
||||
SysId string `json:"sys_id"`
|
||||
UserId string `json:"user_id"`
|
||||
}
|
||||
|
||||
type SessionListRequest struct {
|
||||
SysId string `json:"sys_id"`
|
||||
UserId string `json:"user_id"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
}
|
|
@ -11,10 +11,11 @@ import (
|
|||
|
||||
func NewHTTPServer(
|
||||
service *services.ChatService,
|
||||
session *services.SessionService,
|
||||
) *fiber.App {
|
||||
//构建 server
|
||||
app := initRoute()
|
||||
router.SetupRoutes(app, service)
|
||||
router.SetupRoutes(app, service, session)
|
||||
return app
|
||||
}
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ import (
|
|||
)
|
||||
|
||||
// SetupRoutes 设置路由
|
||||
func SetupRoutes(app *fiber.App, ChatService *services.ChatService) {
|
||||
func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionService *services.SessionService) {
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
// 设置 CORS 头
|
||||
c.Set("Access-Control-Allow-Origin", "*")
|
||||
|
@ -28,7 +28,7 @@ func SetupRoutes(app *fiber.App, ChatService *services.ChatService) {
|
|||
// 继续处理后续中间件或路由
|
||||
return c.Next()
|
||||
})
|
||||
routerHttp(app)
|
||||
routerHttp(app, sessionService)
|
||||
routerSocket(app, ChatService)
|
||||
|
||||
}
|
||||
|
@ -56,7 +56,7 @@ var bufferPool = &sync.Pool{
|
|||
},
|
||||
}
|
||||
|
||||
func routerHttp(app *fiber.App) {
|
||||
func routerHttp(app *fiber.App, sessionService *services.SessionService) {
|
||||
r := app.Group("api/v1/")
|
||||
registerResponse(r)
|
||||
// 注册 CORS 中间件
|
||||
|
@ -64,6 +64,10 @@ func routerHttp(app *fiber.App) {
|
|||
c.Response().SetBody([]byte("1"))
|
||||
return nil
|
||||
})
|
||||
|
||||
r.Post("/session/init", sessionService.SessionInit)
|
||||
r.Post("/session/list", sessionService.SessionList)
|
||||
|
||||
}
|
||||
|
||||
func registerResponse(router fiber.Router) {
|
||||
|
|
|
@ -0,0 +1,33 @@
|
|||
package services
|
||||
|
||||
import (
|
||||
errorcode "ai_scheduler/internal/data/error"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/log"
|
||||
)
|
||||
|
||||
// 响应数据
|
||||
func handRes(c *fiber.Ctx, _err error, rsp interface{}) error {
|
||||
var (
|
||||
err *errorcode.BusinessErr
|
||||
)
|
||||
if _err == nil {
|
||||
err = errorcode.Success
|
||||
} else {
|
||||
if e, ok := _err.(*errorcode.BusinessErr); ok {
|
||||
err = e
|
||||
} else {
|
||||
log.Error(c.UserContext(), "系统错误 error: ", _err)
|
||||
err = errorcode.NewBusinessErr("500", _err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
body := fiber.Map{
|
||||
"code": err.Code,
|
||||
"msg": err.Error(),
|
||||
"data": rsp,
|
||||
}
|
||||
|
||||
log.Info(c.UserContext(), c.Path(), "请求参数=", c.BodyRaw(), "响应=", body)
|
||||
return c.JSON(body)
|
||||
}
|
|
@ -0,0 +1,45 @@
|
|||
package services
|
||||
|
||||
import (
|
||||
"ai_scheduler/internal/biz"
|
||||
"ai_scheduler/internal/entitys"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
type SessionService struct {
|
||||
sessionBiz *biz.SessionBiz
|
||||
}
|
||||
|
||||
func NewSession(sessionBiz *biz.SessionBiz) *SessionService {
|
||||
return &SessionService{
|
||||
sessionBiz: sessionBiz,
|
||||
}
|
||||
}
|
||||
|
||||
// SessionInit 初始化会话
|
||||
func (s *SessionService) SessionInit(c *fiber.Ctx) error {
|
||||
req := &entitys.SessionInitRequest{}
|
||||
if err := c.BodyParser(req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sessionId, err := s.sessionBiz.SessionInit(c.Context(), req)
|
||||
|
||||
return handRes(c, err, fiber.Map{
|
||||
"session_id": sessionId,
|
||||
})
|
||||
}
|
||||
|
||||
// SessionList 获取会话列表
|
||||
func (s *SessionService) SessionList(c *fiber.Ctx) error {
|
||||
req := &entitys.SessionListRequest{}
|
||||
if err := c.BodyParser(req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sessionList, err := s.sessionBiz.SessionList(c.Context(), req)
|
||||
|
||||
return handRes(c, err, fiber.Map{
|
||||
"session_list": sessionList,
|
||||
})
|
||||
}
|
Loading…
Reference in New Issue