Merge branch 'refs/heads/feature/session' into feature/dev

This commit is contained in:
wolter 2025-09-18 14:52:43 +08:00
commit 11241af84c
9 changed files with 434 additions and 6 deletions

85
internal/biz/session.go Normal file
View File

@ -0,0 +1,85 @@
package biz
import (
"ai_scheduler/internal/data/impl"
"ai_scheduler/internal/data/model"
"ai_scheduler/internal/entitys"
"context"
"fmt"
"github.com/gofiber/fiber/v2/utils"
"time"
"ai_scheduler/internal/config"
)
type SessionBiz struct {
sessionRepo *impl.SessionImpl
sysRepo *impl.SysImpl
conf *config.Config
}
func NewSessionBiz(conf *config.Config, sessionImpl *impl.SessionImpl, sysImpl *impl.SysImpl) *SessionBiz {
return &SessionBiz{
sessionRepo: sessionImpl,
sysRepo: sysImpl,
conf: conf,
}
}
// InitSession 初始化会话 ,当天存在则返回会话,如果不存在则创建一个
func (s *SessionBiz) SessionInit(ctx context.Context, req *entitys.SessionInitRequest) (sessionId string, err error) {
// 获取系统配置
sysConfig, has, err := s.sysRepo.FindOne(s.sysRepo.WithSysId(req.SysId))
if err != nil {
return "", err
} else if !has {
return "", fmt.Errorf("sys not found")
}
// 获取 当天的session
t := time.Now().Truncate(24 * time.Hour)
session, has, err := s.sessionRepo.FindOne(
s.sessionRepo.WithUserId(req.UserId), // 条件用户ID
s.sessionRepo.WithStartTime(t), // 条件:会话开始时间
s.sysRepo.WithSysId(sysConfig.SysID), // 条件系统ID
)
if err != nil {
return "", err
} else if !has {
// 不存在,创建一个
session = model.AiSession{
SysID: sysConfig.SysID,
SessionID: utils.UUID(),
CreateAt: time.Now(),
UpdateAt: time.Now(),
}
err = s.sessionRepo.Create(&session)
if err != nil {
return "", err
}
}
return session.SessionID, nil
}
// SessionList 会话列表
func (s *SessionBiz) SessionList(ctx context.Context, req *entitys.SessionListRequest) (list []model.AiSession, err error) {
if req.Page <= 0 {
req.Page = 1
}
if req.PageSize <= 0 {
req.PageSize = 10
}
list, err = s.sessionRepo.FindAll(
s.sessionRepo.WithUserId(req.UserId), // 条件用户ID
s.sessionRepo.WithSysId(req.SysId), // 条件系统ID
s.sessionRepo.PaginateScope(req.Page, req.PageSize), // 分页
s.sessionRepo.OrderByDesc("create_at"), // 排序:按创建时间降序
)
return
}

208
internal/data/impl/base.go Normal file
View File

@ -0,0 +1,208 @@
package impl
import (
"ai_scheduler/internal/data/model"
"errors"
"fmt"
"gorm.io/gorm"
)
var (
ErrNoConditions = errors.New("不允许不带条件的操作")
ErrInvalidPage = errors.New("page和pageSize必须为正整数")
)
/*
GORM 基础模型
BaseModel 是一个泛型结构体用于封装GORM数据库通用操作
支持的PO类型需实现具体的数据库模型如ZxCreditOrderZxCreditLog等
*/
// 定义受支持的PO类型集合可根据需要扩展, 只有包含表结构才能使用BaseModel避免使用出现问题
type PO interface {
model.AiSy | model.AiSession | model.AiTask
}
type BaseModel[P PO] struct {
Db *gorm.DB // 数据库连接
}
func NewBaseModel[P PO](db *gorm.DB) BaseRepository[P] {
return &BaseModel[P]{
Db: db,
}
}
// 查询条件函数类型,支持链式条件组装, 由外部调用者自定义组装
type CondFunc = func(db *gorm.DB) *gorm.DB
// 定义通用数据库操作接口
type BaseRepository[P PO] interface {
FindAll(conditions ...CondFunc) ([]P, error) // 查询所有
Paginate(page, pageSize int, conditions ...CondFunc) (*PaginationResult[P], error) // 分页查询
FindOne(conditions ...CondFunc) (P, bool, error) // 查询单条记录,若未找到则返回 has=false, err=nil
Create(m *P) error // 创建
BatchCreate(m *[]P) (err error) // 批量创建
Update(m *P, conditions ...CondFunc) (err error) // 更新
Delete(conditions ...CondFunc) (err error) // 删除
Count(conditions ...CondFunc) (count int64, err error) // 查询条数
PaginateScope(page, pageSize int) CondFunc // 分页
OrderByDesc(field string) CondFunc // 倒序排序
WithId(id interface{}) CondFunc // 查询id
WithStatus(status int) CondFunc // 查询status
GetDb() *gorm.DB // 获取数据库连接
}
// PaginationResult 分页查询结果
type PaginationResult[P any] struct {
List []P `json:"list"` // 数据列表
Total int64 `json:"total"` // 总记录数
Page int `json:"page"` // 当前页码从1开始
PageSize int `json:"page_size"` // 每页数量
}
// 分页查询实现
func (this *BaseModel[P]) Paginate(page, pageSize int, conditions ...CondFunc) (*PaginationResult[P], error) {
if page < 1 || pageSize < 1 {
return nil, ErrInvalidPage
}
db := this.Db.Model(new(P)).Scopes(conditions...) // 自动绑定泛型类型对应的表
var (
total int64
list []P
)
// 先统计总数
if err := db.Count(&total).Error; err != nil {
return nil, fmt.Errorf("统计总数失败: %w", err)
}
// 再查询分页数据
if err := db.Scopes(this.PaginateScope(page, pageSize)).Find(&list).Error; err != nil {
return nil, fmt.Errorf("分页查询失败: %w", err)
}
return &PaginationResult[P]{
List: list,
Total: total,
Page: page,
PageSize: pageSize,
}, nil
}
// 查询所有
func (this *BaseModel[P]) FindAll(conditions ...CondFunc) ([]P, error) {
var list []P
err := this.Db.Model(new(P)).
Scopes(conditions...).
Find(&list).Error
return list, err
}
// 查询单条记录(优先返回第一条)
func (this *BaseModel[P]) FindOne(conditions ...CondFunc) (P, bool, error) {
var (
result P
)
err := this.Db.Model(new(P)).
Scopes(conditions...).
First(&result).
Error
if errors.Is(err, gorm.ErrRecordNotFound) {
return result, false, nil // 未找到记录
}
if err != nil {
return result, false, fmt.Errorf("查询单条记录失败: %w", err)
}
return result, true, err
}
// 创建
func (this *BaseModel[P]) Create(m *P) error {
err := this.Db.Create(m).Error
return err
}
// 批量创建
func (this *BaseModel[P]) BatchCreate(m *[]P) (err error) {
//使用 CreateInBatches 方法分批插入用户,批次大小为 100
err = this.Db.CreateInBatches(m, 100).Error
return err
}
// 按条件更新记录(必须带条件)
func (this *BaseModel[P]) Update(m *P, conditions ...CondFunc) (err error) {
// 没有更新条件
if len(conditions) == 0 {
return ErrNoConditions
}
return this.Db.Model(new(P)).
Scopes(conditions...).
Updates(m).
Error
}
// 按条件删除记录(必须带条件)
func (this *BaseModel[P]) Delete(conditions ...CondFunc) (err error) {
// 没有更新条件
if len(conditions) == 0 {
return ErrNoConditions
}
return this.Db.Model(new(P)).
Scopes(conditions...).
Delete(new(P)).
Error
}
// 统计符合条件的记录数
func (this *BaseModel[P]) Count(conditions ...CondFunc) (count int64, err error) {
err = this.Db.Model(new(P)).
Scopes(conditions...).
Count(&count).
Error
return count, err
}
// 分页条件生成器
func (this *BaseModel[P]) PaginateScope(page, pageSize int) CondFunc {
return func(db *gorm.DB) *gorm.DB {
if page < 1 || pageSize < 1 {
return db // 由上层方法校验参数有效性
}
offset := (page - 1) * pageSize
return db.Offset(offset).Limit(pageSize)
}
}
// 倒序排序条件生成器
func (this *BaseModel[P]) OrderByDesc(field string) CondFunc {
return func(db *gorm.DB) *gorm.DB {
return db.Order(fmt.Sprintf("%s DESC", field))
}
}
// ID查询条件生成器
func (this *BaseModel[P]) WithId(id interface{}) CondFunc {
return func(db *gorm.DB) *gorm.DB {
return db.Where("id = ?", id)
}
}
// 状态查询条件生成器
func (this *BaseModel[P]) WithStatus(status int) CondFunc {
return func(db *gorm.DB) *gorm.DB {
return db.Where("status =?", status)
}
}
// 获取数据链接
func (this *BaseModel[P]) GetDb() *gorm.DB {
return this.Db
}

View File

@ -4,12 +4,39 @@ import (
"ai_scheduler/internal/data/model" "ai_scheduler/internal/data/model"
"ai_scheduler/tmpl/dataTemp" "ai_scheduler/tmpl/dataTemp"
"ai_scheduler/utils" "ai_scheduler/utils"
"gorm.io/gorm"
"time"
) )
type SessionImpl struct { type SessionImpl struct {
dataTemp.DataTemp dataTemp.DataTemp
BaseModel[model.AiSession]
} }
func NewSessionImpl(db *utils.Db) *SessionImpl { func NewSessionImpl(db *utils.Db) *SessionImpl {
return &SessionImpl{*dataTemp.NewDataTemp(db, new(model.AiSession))} return &SessionImpl{
DataTemp: *dataTemp.NewDataTemp(db, new(model.AiSession)),
BaseModel: BaseModel[model.AiSession]{},
}
}
// WithUserId 条件用户ID
func (impl *SessionImpl) WithUserId(userId interface{}) CondFunc {
return func(db *gorm.DB) *gorm.DB {
return db.Where("user_id = ?", userId)
}
}
// WithStartTime 条件:会话开始时间
func (impl *SessionImpl) WithStartTime(startTime time.Time) CondFunc {
return func(db *gorm.DB) *gorm.DB {
return db.Where("create_at >= ?", startTime)
}
}
// WithSysId 系统id
func (s *SessionImpl) WithSysId(sysId interface{}) CondFunc {
return func(db *gorm.DB) *gorm.DB {
return db.Where("sys_id = ?", sysId)
}
} }

View File

@ -4,12 +4,24 @@ import (
"ai_scheduler/internal/data/model" "ai_scheduler/internal/data/model"
"ai_scheduler/tmpl/dataTemp" "ai_scheduler/tmpl/dataTemp"
"ai_scheduler/utils" "ai_scheduler/utils"
"gorm.io/gorm"
) )
type SysImpl struct { type SysImpl struct {
dataTemp.DataTemp dataTemp.DataTemp
BaseModel[model.AiSy]
} }
func NewSysImpl(db *utils.Db) *SysImpl { func NewSysImpl(db *utils.Db) *SysImpl {
return &SysImpl{*dataTemp.NewDataTemp(db, new(model.AiSy))} return &SysImpl{
DataTemp: *dataTemp.NewDataTemp(db, new(model.AiSy)),
BaseModel: BaseModel[model.AiSy]{},
}
}
// WithSysId 系统id
func (s *SysImpl) WithSysId(sysId interface{}) CondFunc {
return func(db *gorm.DB) *gorm.DB {
return db.Where("sys_id = ?", sysId)
}
} }

View File

@ -0,0 +1,13 @@
package entitys
type SessionInitRequest struct {
SysId string `json:"sys_id"`
UserId string `json:"user_id"`
}
type SessionListRequest struct {
SysId string `json:"sys_id"`
UserId string `json:"user_id"`
Page int `json:"page"`
PageSize int `json:"page_size"`
}

View File

@ -11,10 +11,11 @@ import (
func NewHTTPServer( func NewHTTPServer(
service *services.ChatService, service *services.ChatService,
session *services.SessionService,
) *fiber.App { ) *fiber.App {
//构建 server //构建 server
app := initRoute() app := initRoute()
router.SetupRoutes(app, service) router.SetupRoutes(app, service, session)
return app return app
} }

View File

@ -13,7 +13,7 @@ import (
) )
// SetupRoutes 设置路由 // SetupRoutes 设置路由
func SetupRoutes(app *fiber.App, ChatService *services.ChatService) { func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionService *services.SessionService) {
app.Use(func(c *fiber.Ctx) error { app.Use(func(c *fiber.Ctx) error {
// 设置 CORS 头 // 设置 CORS 头
c.Set("Access-Control-Allow-Origin", "*") c.Set("Access-Control-Allow-Origin", "*")
@ -28,7 +28,7 @@ func SetupRoutes(app *fiber.App, ChatService *services.ChatService) {
// 继续处理后续中间件或路由 // 继续处理后续中间件或路由
return c.Next() return c.Next()
}) })
routerHttp(app) routerHttp(app, sessionService)
routerSocket(app, ChatService) routerSocket(app, ChatService)
} }
@ -56,7 +56,7 @@ var bufferPool = &sync.Pool{
}, },
} }
func routerHttp(app *fiber.App) { func routerHttp(app *fiber.App, sessionService *services.SessionService) {
r := app.Group("api/v1/") r := app.Group("api/v1/")
registerResponse(r) registerResponse(r)
// 注册 CORS 中间件 // 注册 CORS 中间件
@ -64,6 +64,10 @@ func routerHttp(app *fiber.App) {
c.Response().SetBody([]byte("1")) c.Response().SetBody([]byte("1"))
return nil return nil
}) })
r.Post("/session/init", sessionService.SessionInit)
r.Post("/session/list", sessionService.SessionList)
} }
func registerResponse(router fiber.Router) { func registerResponse(router fiber.Router) {

33
internal/services/base.go Normal file
View File

@ -0,0 +1,33 @@
package services
import (
errorcode "ai_scheduler/internal/data/error"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/log"
)
// 响应数据
func handRes(c *fiber.Ctx, _err error, rsp interface{}) error {
var (
err *errorcode.BusinessErr
)
if _err == nil {
err = errorcode.Success
} else {
if e, ok := _err.(*errorcode.BusinessErr); ok {
err = e
} else {
log.Error(c.UserContext(), "系统错误 error: ", _err)
err = errorcode.NewBusinessErr("500", _err.Error())
}
}
body := fiber.Map{
"code": err.Code,
"msg": err.Error(),
"data": rsp,
}
log.Info(c.UserContext(), c.Path(), "请求参数=", c.BodyRaw(), "响应=", body)
return c.JSON(body)
}

View File

@ -0,0 +1,45 @@
package services
import (
"ai_scheduler/internal/biz"
"ai_scheduler/internal/entitys"
"github.com/gofiber/fiber/v2"
)
type SessionService struct {
sessionBiz *biz.SessionBiz
}
func NewSession(sessionBiz *biz.SessionBiz) *SessionService {
return &SessionService{
sessionBiz: sessionBiz,
}
}
// SessionInit 初始化会话
func (s *SessionService) SessionInit(c *fiber.Ctx) error {
req := &entitys.SessionInitRequest{}
if err := c.BodyParser(req); err != nil {
return err
}
sessionId, err := s.sessionBiz.SessionInit(c.Context(), req)
return handRes(c, err, fiber.Map{
"session_id": sessionId,
})
}
// SessionList 获取会话列表
func (s *SessionService) SessionList(c *fiber.Ctx) error {
req := &entitys.SessionListRequest{}
if err := c.BodyParser(req); err != nil {
return err
}
sessionList, err := s.sessionBiz.SessionList(c.Context(), req)
return handRes(c, err, fiber.Map{
"session_list": sessionList,
})
}