217 lines
6.1 KiB
Go
217 lines
6.1 KiB
Go
package impl
|
||
|
||
import (
|
||
"ai_scheduler/internal/data/model"
|
||
"errors"
|
||
"fmt"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
var (
|
||
ErrNoConditions = errors.New("不允许不带条件的操作")
|
||
ErrInvalidPage = errors.New("page和pageSize必须为正整数")
|
||
)
|
||
|
||
/*
|
||
GORM 基础模型
|
||
BaseModel 是一个泛型结构体,用于封装GORM数据库通用操作。
|
||
支持的PO类型需实现具体的数据库模型(如ZxCreditOrder、ZxCreditLog等)
|
||
*/
|
||
|
||
// 定义受支持的PO类型集合(可根据需要扩展), 只有包含表结构才能使用BaseModel,避免使用出现问题
|
||
type PO interface {
|
||
model.AiChatHi |
|
||
model.AiSy | model.AiSession | model.AiTask
|
||
}
|
||
|
||
type BaseModel[P PO] struct {
|
||
Db *gorm.DB // 数据库连接
|
||
}
|
||
|
||
func NewBaseModel[P PO](db *gorm.DB) BaseRepository[P] {
|
||
return &BaseModel[P]{
|
||
Db: db,
|
||
}
|
||
}
|
||
|
||
// 查询条件函数类型,支持链式条件组装, 由外部调用者自定义组装
|
||
type CondFunc = func(db *gorm.DB) *gorm.DB
|
||
|
||
// 定义通用数据库操作接口
|
||
type BaseRepository[P PO] interface {
|
||
FindAll(conditions ...CondFunc) ([]P, error) // 查询所有
|
||
Paginate(page, pageSize int, conditions ...CondFunc) (*PaginationResult[P], error) // 分页查询
|
||
FindOne(conditions ...CondFunc) (P, bool, error) // 查询单条记录,若未找到则返回 has=false, err=nil
|
||
Create(m *P) error // 创建
|
||
BatchCreate(m *[]P) (err error) // 批量创建
|
||
Update(m *P, conditions ...CondFunc) (err error) // 更新
|
||
Delete(conditions ...CondFunc) (err error) // 删除
|
||
Count(conditions ...CondFunc) (count int64, err error) // 查询条数
|
||
PaginateScope(page, pageSize int) CondFunc // 分页
|
||
OrderByDesc(field string) CondFunc // 倒序排序
|
||
WithId(id interface{}) CondFunc // 查询id
|
||
WithStatus(status int) CondFunc // 查询status
|
||
GetDb() *gorm.DB // 获取数据库连接
|
||
WithLimit(limit int) CondFunc // 限制返回条数
|
||
}
|
||
|
||
// PaginationResult 分页查询结果
|
||
type PaginationResult[P any] struct {
|
||
List []P `json:"list"` // 数据列表
|
||
Total int64 `json:"total"` // 总记录数
|
||
Page int `json:"page"` // 当前页码(从1开始)
|
||
PageSize int `json:"page_size"` // 每页数量
|
||
}
|
||
|
||
// 分页查询实现
|
||
func (this *BaseModel[P]) Paginate(page, pageSize int, conditions ...CondFunc) (*PaginationResult[P], error) {
|
||
if page < 1 || pageSize < 1 {
|
||
return nil, ErrInvalidPage
|
||
}
|
||
|
||
db := this.Db.Model(new(P)).Scopes(conditions...) // 自动绑定泛型类型对应的表
|
||
|
||
var (
|
||
total int64
|
||
list []P
|
||
)
|
||
|
||
// 先统计总数
|
||
if err := db.Count(&total).Error; err != nil {
|
||
return nil, fmt.Errorf("统计总数失败: %w", err)
|
||
}
|
||
|
||
// 再查询分页数据
|
||
if err := db.Scopes(this.PaginateScope(page, pageSize)).Find(&list).Error; err != nil {
|
||
return nil, fmt.Errorf("分页查询失败: %w", err)
|
||
}
|
||
|
||
return &PaginationResult[P]{
|
||
List: list,
|
||
Total: total,
|
||
Page: page,
|
||
PageSize: pageSize,
|
||
}, nil
|
||
}
|
||
|
||
// 查询所有
|
||
func (this *BaseModel[P]) FindAll(conditions ...CondFunc) ([]P, error) {
|
||
var list []P
|
||
err := this.Db.Model(new(P)).
|
||
Scopes(conditions...).
|
||
Find(&list).Error
|
||
return list, err
|
||
}
|
||
|
||
// 查询单条记录(优先返回第一条)
|
||
func (this *BaseModel[P]) FindOne(conditions ...CondFunc) (P, bool, error) {
|
||
var (
|
||
result P
|
||
)
|
||
|
||
err := this.Db.Model(new(P)).
|
||
Scopes(conditions...).
|
||
First(&result).
|
||
Error
|
||
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return result, false, nil // 未找到记录
|
||
}
|
||
if err != nil {
|
||
return result, false, fmt.Errorf("查询单条记录失败: %w", err)
|
||
}
|
||
|
||
return result, true, err
|
||
}
|
||
|
||
// 创建
|
||
func (this *BaseModel[P]) Create(m *P) error {
|
||
err := this.Db.Create(m).Error
|
||
return err
|
||
}
|
||
|
||
// 批量创建
|
||
func (this *BaseModel[P]) BatchCreate(m *[]P) (err error) {
|
||
//使用 CreateInBatches 方法分批插入用户,批次大小为 100
|
||
err = this.Db.CreateInBatches(m, 100).Error
|
||
return err
|
||
}
|
||
|
||
// 按条件更新记录(必须带条件)
|
||
func (this *BaseModel[P]) Update(m *P, conditions ...CondFunc) (err error) {
|
||
// 没有更新条件
|
||
if len(conditions) == 0 {
|
||
return ErrNoConditions
|
||
}
|
||
|
||
return this.Db.Model(new(P)).
|
||
Scopes(conditions...).
|
||
Updates(m).
|
||
Error
|
||
}
|
||
|
||
// 按条件删除记录(必须带条件)
|
||
func (this *BaseModel[P]) Delete(conditions ...CondFunc) (err error) {
|
||
// 没有更新条件
|
||
if len(conditions) == 0 {
|
||
return ErrNoConditions
|
||
}
|
||
|
||
return this.Db.Model(new(P)).
|
||
Scopes(conditions...).
|
||
Delete(new(P)).
|
||
Error
|
||
}
|
||
|
||
// 统计符合条件的记录数
|
||
func (this *BaseModel[P]) Count(conditions ...CondFunc) (count int64, err error) {
|
||
err = this.Db.Model(new(P)).
|
||
Scopes(conditions...).
|
||
Count(&count).
|
||
Error
|
||
return count, err
|
||
}
|
||
|
||
// 分页条件生成器
|
||
func (this *BaseModel[P]) PaginateScope(page, pageSize int) CondFunc {
|
||
return func(db *gorm.DB) *gorm.DB {
|
||
if page < 1 || pageSize < 1 {
|
||
return db // 由上层方法校验参数有效性
|
||
}
|
||
offset := (page - 1) * pageSize
|
||
return db.Offset(offset).Limit(pageSize)
|
||
}
|
||
}
|
||
|
||
// 倒序排序条件生成器
|
||
func (this *BaseModel[P]) OrderByDesc(field string) CondFunc {
|
||
return func(db *gorm.DB) *gorm.DB {
|
||
return db.Order(fmt.Sprintf("%s DESC", field))
|
||
}
|
||
}
|
||
|
||
// ID查询条件生成器
|
||
func (this *BaseModel[P]) WithId(id interface{}) CondFunc {
|
||
return func(db *gorm.DB) *gorm.DB {
|
||
return db.Where("id = ?", id)
|
||
}
|
||
}
|
||
|
||
// 状态查询条件生成器
|
||
func (this *BaseModel[P]) WithStatus(status int) CondFunc {
|
||
return func(db *gorm.DB) *gorm.DB {
|
||
return db.Where("status =?", status)
|
||
}
|
||
}
|
||
|
||
// 获取数据链接
|
||
func (this *BaseModel[P]) GetDb() *gorm.DB {
|
||
return this.Db
|
||
}
|
||
|
||
func (this *BaseModel[P]) WithLimit(limit int) CondFunc {
|
||
return func(db *gorm.DB) *gorm.DB {
|
||
return db.Limit(limit)
|
||
}
|
||
}
|