package impl import ( "ai_scheduler/internal/data/model" "errors" "fmt" "gorm.io/gorm" ) var ( ErrNoConditions = errors.New("不允许不带条件的操作") ErrInvalidPage = errors.New("page和pageSize必须为正整数") ) /* GORM 基础模型 BaseModel 是一个泛型结构体,用于封装GORM数据库通用操作。 支持的PO类型需实现具体的数据库模型(如ZxCreditOrder、ZxCreditLog等) */ // 定义受支持的PO类型集合(可根据需要扩展), 只有包含表结构才能使用BaseModel,避免使用出现问题 type PO interface { model.AiSy | model.AiSession | model.AiTask } type BaseModel[P PO] struct { Db *gorm.DB // 数据库连接 } func NewBaseModel[P PO](db *gorm.DB) BaseRepository[P] { return &BaseModel[P]{ Db: db, } } // 查询条件函数类型,支持链式条件组装, 由外部调用者自定义组装 type CondFunc = func(db *gorm.DB) *gorm.DB // 定义通用数据库操作接口 type BaseRepository[P PO] interface { FindAll(conditions ...CondFunc) ([]P, error) // 查询所有 Paginate(page, pageSize int, conditions ...CondFunc) (*PaginationResult[P], error) // 分页查询 FindOne(conditions ...CondFunc) (P, bool, error) // 查询单条记录,若未找到则返回 has=false, err=nil Create(m *P) error // 创建 BatchCreate(m *[]P) (err error) // 批量创建 Update(m *P, conditions ...CondFunc) (err error) // 更新 Delete(conditions ...CondFunc) (err error) // 删除 Count(conditions ...CondFunc) (count int64, err error) // 查询条数 PaginateScope(page, pageSize int) CondFunc // 分页 OrderByDesc(field string) CondFunc // 倒序排序 WithId(id interface{}) CondFunc // 查询id WithStatus(status int) CondFunc // 查询status GetDb() *gorm.DB // 获取数据库连接 } // PaginationResult 分页查询结果 type PaginationResult[P any] struct { List []P `json:"list"` // 数据列表 Total int64 `json:"total"` // 总记录数 Page int `json:"page"` // 当前页码(从1开始) PageSize int `json:"page_size"` // 每页数量 } // 分页查询实现 func (this *BaseModel[P]) Paginate(page, pageSize int, conditions ...CondFunc) (*PaginationResult[P], error) { if page < 1 || pageSize < 1 { return nil, ErrInvalidPage } db := this.Db.Model(new(P)).Scopes(conditions...) // 自动绑定泛型类型对应的表 var ( total int64 list []P ) // 先统计总数 if err := db.Count(&total).Error; err != nil { return nil, fmt.Errorf("统计总数失败: %w", err) } // 再查询分页数据 if err := db.Scopes(this.PaginateScope(page, pageSize)).Find(&list).Error; err != nil { return nil, fmt.Errorf("分页查询失败: %w", err) } return &PaginationResult[P]{ List: list, Total: total, Page: page, PageSize: pageSize, }, nil } // 查询所有 func (this *BaseModel[P]) FindAll(conditions ...CondFunc) ([]P, error) { var list []P err := this.Db.Model(new(P)). Scopes(conditions...). Find(&list).Error return list, err } // 查询单条记录(优先返回第一条) func (this *BaseModel[P]) FindOne(conditions ...CondFunc) (P, bool, error) { var ( result P ) err := this.Db.Model(new(P)). Scopes(conditions...). First(&result). Error if errors.Is(err, gorm.ErrRecordNotFound) { return result, false, nil // 未找到记录 } if err != nil { return result, false, fmt.Errorf("查询单条记录失败: %w", err) } return result, true, err } // 创建 func (this *BaseModel[P]) Create(m *P) error { err := this.Db.Create(m).Error return err } // 批量创建 func (this *BaseModel[P]) BatchCreate(m *[]P) (err error) { //使用 CreateInBatches 方法分批插入用户,批次大小为 100 err = this.Db.CreateInBatches(m, 100).Error return err } // 按条件更新记录(必须带条件) func (this *BaseModel[P]) Update(m *P, conditions ...CondFunc) (err error) { // 没有更新条件 if len(conditions) == 0 { return ErrNoConditions } return this.Db.Model(new(P)). Scopes(conditions...). Updates(m). Error } // 按条件删除记录(必须带条件) func (this *BaseModel[P]) Delete(conditions ...CondFunc) (err error) { // 没有更新条件 if len(conditions) == 0 { return ErrNoConditions } return this.Db.Model(new(P)). Scopes(conditions...). Delete(new(P)). Error } // 统计符合条件的记录数 func (this *BaseModel[P]) Count(conditions ...CondFunc) (count int64, err error) { err = this.Db.Model(new(P)). Scopes(conditions...). Count(&count). Error return count, err } // 分页条件生成器 func (this *BaseModel[P]) PaginateScope(page, pageSize int) CondFunc { return func(db *gorm.DB) *gorm.DB { if page < 1 || pageSize < 1 { return db // 由上层方法校验参数有效性 } offset := (page - 1) * pageSize return db.Offset(offset).Limit(pageSize) } } // 倒序排序条件生成器 func (this *BaseModel[P]) OrderByDesc(field string) CondFunc { return func(db *gorm.DB) *gorm.DB { return db.Order(fmt.Sprintf("%s DESC", field)) } } // ID查询条件生成器 func (this *BaseModel[P]) WithId(id interface{}) CondFunc { return func(db *gorm.DB) *gorm.DB { return db.Where("id = ?", id) } } // 状态查询条件生成器 func (this *BaseModel[P]) WithStatus(status int) CondFunc { return func(db *gorm.DB) *gorm.DB { return db.Where("status =?", status) } } // 获取数据链接 func (this *BaseModel[P]) GetDb() *gorm.DB { return this.Db }