217 lines
		
	
	
		
			6.1 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			217 lines
		
	
	
		
			6.1 KiB
		
	
	
	
		
			Go
		
	
	
	
package impl
 | 
						||
 | 
						||
import (
 | 
						||
	"ai_scheduler/internal/data/model"
 | 
						||
	"errors"
 | 
						||
	"fmt"
 | 
						||
	"gorm.io/gorm"
 | 
						||
)
 | 
						||
 | 
						||
var (
 | 
						||
	ErrNoConditions = errors.New("不允许不带条件的操作")
 | 
						||
	ErrInvalidPage  = errors.New("page和pageSize必须为正整数")
 | 
						||
)
 | 
						||
 | 
						||
/*
 | 
						||
GORM 基础模型
 | 
						||
BaseModel 是一个泛型结构体,用于封装GORM数据库通用操作。
 | 
						||
支持的PO类型需实现具体的数据库模型(如ZxCreditOrder、ZxCreditLog等)
 | 
						||
*/
 | 
						||
 | 
						||
// 定义受支持的PO类型集合(可根据需要扩展), 只有包含表结构才能使用BaseModel,避免使用出现问题
 | 
						||
type PO interface {
 | 
						||
	model.AiChatHi |
 | 
						||
		model.AiSy | model.AiSession | model.AiTask
 | 
						||
}
 | 
						||
 | 
						||
type BaseModel[P PO] struct {
 | 
						||
	Db *gorm.DB // 数据库连接
 | 
						||
}
 | 
						||
 | 
						||
func NewBaseModel[P PO](db *gorm.DB) BaseRepository[P] {
 | 
						||
	return &BaseModel[P]{
 | 
						||
		Db: db,
 | 
						||
	}
 | 
						||
}
 | 
						||
 | 
						||
// 查询条件函数类型,支持链式条件组装, 由外部调用者自定义组装
 | 
						||
type CondFunc = func(db *gorm.DB) *gorm.DB
 | 
						||
 | 
						||
// 定义通用数据库操作接口
 | 
						||
type BaseRepository[P PO] interface {
 | 
						||
	FindAll(conditions ...CondFunc) ([]P, error)                                       // 查询所有
 | 
						||
	Paginate(page, pageSize int, conditions ...CondFunc) (*PaginationResult[P], error) // 分页查询
 | 
						||
	FindOne(conditions ...CondFunc) (P, bool, error)                                   // 查询单条记录,若未找到则返回 has=false, err=nil
 | 
						||
	Create(m *P) error                                                                 // 创建
 | 
						||
	BatchCreate(m *[]P) (err error)                                                    // 批量创建
 | 
						||
	Update(m *P, conditions ...CondFunc) (err error)                                   // 更新
 | 
						||
	Delete(conditions ...CondFunc) (err error)                                         // 删除
 | 
						||
	Count(conditions ...CondFunc) (count int64, err error)                             // 查询条数
 | 
						||
	PaginateScope(page, pageSize int) CondFunc                                         // 分页
 | 
						||
	OrderByDesc(field string) CondFunc                                                 // 倒序排序
 | 
						||
	WithId(id interface{}) CondFunc                                                    // 查询id
 | 
						||
	WithStatus(status int) CondFunc                                                    // 查询status
 | 
						||
	GetDb() *gorm.DB                                                                   // 获取数据库连接
 | 
						||
	WithLimit(limit int) CondFunc                                                      // 限制返回条数
 | 
						||
}
 | 
						||
 | 
						||
// PaginationResult 分页查询结果
 | 
						||
type PaginationResult[P any] struct {
 | 
						||
	List     []P   `json:"list"`      // 数据列表
 | 
						||
	Total    int64 `json:"total"`     // 总记录数
 | 
						||
	Page     int   `json:"page"`      // 当前页码(从1开始)
 | 
						||
	PageSize int   `json:"page_size"` // 每页数量
 | 
						||
}
 | 
						||
 | 
						||
// 分页查询实现
 | 
						||
func (this *BaseModel[P]) Paginate(page, pageSize int, conditions ...CondFunc) (*PaginationResult[P], error) {
 | 
						||
	if page < 1 || pageSize < 1 {
 | 
						||
		return nil, ErrInvalidPage
 | 
						||
	}
 | 
						||
 | 
						||
	db := this.Db.Model(new(P)).Scopes(conditions...) // 自动绑定泛型类型对应的表
 | 
						||
 | 
						||
	var (
 | 
						||
		total int64
 | 
						||
		list  []P
 | 
						||
	)
 | 
						||
 | 
						||
	// 先统计总数
 | 
						||
	if err := db.Count(&total).Error; err != nil {
 | 
						||
		return nil, fmt.Errorf("统计总数失败: %w", err)
 | 
						||
	}
 | 
						||
 | 
						||
	// 再查询分页数据
 | 
						||
	if err := db.Scopes(this.PaginateScope(page, pageSize)).Find(&list).Error; err != nil {
 | 
						||
		return nil, fmt.Errorf("分页查询失败: %w", err)
 | 
						||
	}
 | 
						||
 | 
						||
	return &PaginationResult[P]{
 | 
						||
		List:     list,
 | 
						||
		Total:    total,
 | 
						||
		Page:     page,
 | 
						||
		PageSize: pageSize,
 | 
						||
	}, nil
 | 
						||
}
 | 
						||
 | 
						||
// 查询所有
 | 
						||
func (this *BaseModel[P]) FindAll(conditions ...CondFunc) ([]P, error) {
 | 
						||
	var list []P
 | 
						||
	err := this.Db.Model(new(P)).
 | 
						||
		Scopes(conditions...).
 | 
						||
		Find(&list).Error
 | 
						||
	return list, err
 | 
						||
}
 | 
						||
 | 
						||
// 查询单条记录(优先返回第一条)
 | 
						||
func (this *BaseModel[P]) FindOne(conditions ...CondFunc) (P, bool, error) {
 | 
						||
	var (
 | 
						||
		result P
 | 
						||
	)
 | 
						||
 | 
						||
	err := this.Db.Model(new(P)).
 | 
						||
		Scopes(conditions...).
 | 
						||
		First(&result).
 | 
						||
		Error
 | 
						||
 | 
						||
	if errors.Is(err, gorm.ErrRecordNotFound) {
 | 
						||
		return result, false, nil // 未找到记录
 | 
						||
	}
 | 
						||
	if err != nil {
 | 
						||
		return result, false, fmt.Errorf("查询单条记录失败: %w", err)
 | 
						||
	}
 | 
						||
 | 
						||
	return result, true, err
 | 
						||
}
 | 
						||
 | 
						||
// 创建
 | 
						||
func (this *BaseModel[P]) Create(m *P) error {
 | 
						||
	err := this.Db.Create(m).Error
 | 
						||
	return err
 | 
						||
}
 | 
						||
 | 
						||
// 批量创建
 | 
						||
func (this *BaseModel[P]) BatchCreate(m *[]P) (err error) {
 | 
						||
	//使用 CreateInBatches 方法分批插入用户,批次大小为 100
 | 
						||
	err = this.Db.CreateInBatches(m, 100).Error
 | 
						||
	return err
 | 
						||
}
 | 
						||
 | 
						||
// 按条件更新记录(必须带条件)
 | 
						||
func (this *BaseModel[P]) Update(m *P, conditions ...CondFunc) (err error) {
 | 
						||
	// 没有更新条件
 | 
						||
	if len(conditions) == 0 {
 | 
						||
		return ErrNoConditions
 | 
						||
	}
 | 
						||
 | 
						||
	return this.Db.Model(new(P)).
 | 
						||
		Scopes(conditions...).
 | 
						||
		Updates(m).
 | 
						||
		Error
 | 
						||
}
 | 
						||
 | 
						||
// 按条件删除记录(必须带条件)
 | 
						||
func (this *BaseModel[P]) Delete(conditions ...CondFunc) (err error) {
 | 
						||
	// 没有更新条件
 | 
						||
	if len(conditions) == 0 {
 | 
						||
		return ErrNoConditions
 | 
						||
	}
 | 
						||
 | 
						||
	return this.Db.Model(new(P)).
 | 
						||
		Scopes(conditions...).
 | 
						||
		Delete(new(P)).
 | 
						||
		Error
 | 
						||
}
 | 
						||
 | 
						||
// 统计符合条件的记录数
 | 
						||
func (this *BaseModel[P]) Count(conditions ...CondFunc) (count int64, err error) {
 | 
						||
	err = this.Db.Model(new(P)).
 | 
						||
		Scopes(conditions...).
 | 
						||
		Count(&count).
 | 
						||
		Error
 | 
						||
	return count, err
 | 
						||
}
 | 
						||
 | 
						||
// 分页条件生成器
 | 
						||
func (this *BaseModel[P]) PaginateScope(page, pageSize int) CondFunc {
 | 
						||
	return func(db *gorm.DB) *gorm.DB {
 | 
						||
		if page < 1 || pageSize < 1 {
 | 
						||
			return db // 由上层方法校验参数有效性
 | 
						||
		}
 | 
						||
		offset := (page - 1) * pageSize
 | 
						||
		return db.Offset(offset).Limit(pageSize)
 | 
						||
	}
 | 
						||
}
 | 
						||
 | 
						||
// 倒序排序条件生成器
 | 
						||
func (this *BaseModel[P]) OrderByDesc(field string) CondFunc {
 | 
						||
	return func(db *gorm.DB) *gorm.DB {
 | 
						||
		return db.Order(fmt.Sprintf("%s DESC", field))
 | 
						||
	}
 | 
						||
}
 | 
						||
 | 
						||
// ID查询条件生成器
 | 
						||
func (this *BaseModel[P]) WithId(id interface{}) CondFunc {
 | 
						||
	return func(db *gorm.DB) *gorm.DB {
 | 
						||
		return db.Where("id = ?", id)
 | 
						||
	}
 | 
						||
}
 | 
						||
 | 
						||
// 状态查询条件生成器
 | 
						||
func (this *BaseModel[P]) WithStatus(status int) CondFunc {
 | 
						||
	return func(db *gorm.DB) *gorm.DB {
 | 
						||
		return db.Where("status =?", status)
 | 
						||
	}
 | 
						||
}
 | 
						||
 | 
						||
// 获取数据链接
 | 
						||
func (this *BaseModel[P]) GetDb() *gorm.DB {
 | 
						||
	return this.Db
 | 
						||
}
 | 
						||
 | 
						||
func (this *BaseModel[P]) WithLimit(limit int) CondFunc {
 | 
						||
	return func(db *gorm.DB) *gorm.DB {
 | 
						||
		return db.Limit(limit)
 | 
						||
	}
 | 
						||
}
 |