ai_scheduler/internal/data/impl/base.go

217 lines
6.1 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package impl
import (
"ai_scheduler/internal/data/model"
"errors"
"fmt"
"gorm.io/gorm"
)
var (
ErrNoConditions = errors.New("不允许不带条件的操作")
ErrInvalidPage = errors.New("page和pageSize必须为正整数")
)
/*
GORM 基础模型
BaseModel 是一个泛型结构体用于封装GORM数据库通用操作。
支持的PO类型需实现具体的数据库模型如ZxCreditOrder、ZxCreditLog等
*/
// 定义受支持的PO类型集合可根据需要扩展, 只有包含表结构才能使用BaseModel避免使用出现问题
type PO interface {
model.AiChatHi |
model.AiSy | model.AiSession | model.AiTask
}
type BaseModel[P PO] struct {
Db *gorm.DB // 数据库连接
}
func NewBaseModel[P PO](db *gorm.DB) BaseRepository[P] {
return &BaseModel[P]{
Db: db,
}
}
// 查询条件函数类型,支持链式条件组装, 由外部调用者自定义组装
type CondFunc = func(db *gorm.DB) *gorm.DB
// 定义通用数据库操作接口
type BaseRepository[P PO] interface {
FindAll(conditions ...CondFunc) ([]P, error) // 查询所有
Paginate(page, pageSize int, conditions ...CondFunc) (*PaginationResult[P], error) // 分页查询
FindOne(conditions ...CondFunc) (P, bool, error) // 查询单条记录,若未找到则返回 has=false, err=nil
Create(m *P) error // 创建
BatchCreate(m *[]P) (err error) // 批量创建
Update(m *P, conditions ...CondFunc) (err error) // 更新
Delete(conditions ...CondFunc) (err error) // 删除
Count(conditions ...CondFunc) (count int64, err error) // 查询条数
PaginateScope(page, pageSize int) CondFunc // 分页
OrderByDesc(field string) CondFunc // 倒序排序
WithId(id interface{}) CondFunc // 查询id
WithStatus(status int) CondFunc // 查询status
GetDb() *gorm.DB // 获取数据库连接
WithLimit(limit int) CondFunc // 限制返回条数
}
// PaginationResult 分页查询结果
type PaginationResult[P any] struct {
List []P `json:"list"` // 数据列表
Total int64 `json:"total"` // 总记录数
Page int `json:"page"` // 当前页码从1开始
PageSize int `json:"page_size"` // 每页数量
}
// 分页查询实现
func (this *BaseModel[P]) Paginate(page, pageSize int, conditions ...CondFunc) (*PaginationResult[P], error) {
if page < 1 || pageSize < 1 {
return nil, ErrInvalidPage
}
db := this.Db.Model(new(P)).Scopes(conditions...) // 自动绑定泛型类型对应的表
var (
total int64
list []P
)
// 先统计总数
if err := db.Count(&total).Error; err != nil {
return nil, fmt.Errorf("统计总数失败: %w", err)
}
// 再查询分页数据
if err := db.Scopes(this.PaginateScope(page, pageSize)).Find(&list).Error; err != nil {
return nil, fmt.Errorf("分页查询失败: %w", err)
}
return &PaginationResult[P]{
List: list,
Total: total,
Page: page,
PageSize: pageSize,
}, nil
}
// 查询所有
func (this *BaseModel[P]) FindAll(conditions ...CondFunc) ([]P, error) {
var list []P
err := this.Db.Model(new(P)).
Scopes(conditions...).
Find(&list).Error
return list, err
}
// 查询单条记录(优先返回第一条)
func (this *BaseModel[P]) FindOne(conditions ...CondFunc) (P, bool, error) {
var (
result P
)
err := this.Db.Model(new(P)).
Scopes(conditions...).
First(&result).
Error
if errors.Is(err, gorm.ErrRecordNotFound) {
return result, false, nil // 未找到记录
}
if err != nil {
return result, false, fmt.Errorf("查询单条记录失败: %w", err)
}
return result, true, err
}
// 创建
func (this *BaseModel[P]) Create(m *P) error {
err := this.Db.Create(m).Error
return err
}
// 批量创建
func (this *BaseModel[P]) BatchCreate(m *[]P) (err error) {
//使用 CreateInBatches 方法分批插入用户,批次大小为 100
err = this.Db.CreateInBatches(m, 100).Error
return err
}
// 按条件更新记录(必须带条件)
func (this *BaseModel[P]) Update(m *P, conditions ...CondFunc) (err error) {
// 没有更新条件
if len(conditions) == 0 {
return ErrNoConditions
}
return this.Db.Model(new(P)).
Scopes(conditions...).
Updates(m).
Error
}
// 按条件删除记录(必须带条件)
func (this *BaseModel[P]) Delete(conditions ...CondFunc) (err error) {
// 没有更新条件
if len(conditions) == 0 {
return ErrNoConditions
}
return this.Db.Model(new(P)).
Scopes(conditions...).
Delete(new(P)).
Error
}
// 统计符合条件的记录数
func (this *BaseModel[P]) Count(conditions ...CondFunc) (count int64, err error) {
err = this.Db.Model(new(P)).
Scopes(conditions...).
Count(&count).
Error
return count, err
}
// 分页条件生成器
func (this *BaseModel[P]) PaginateScope(page, pageSize int) CondFunc {
return func(db *gorm.DB) *gorm.DB {
if page < 1 || pageSize < 1 {
return db // 由上层方法校验参数有效性
}
offset := (page - 1) * pageSize
return db.Offset(offset).Limit(pageSize)
}
}
// 倒序排序条件生成器
func (this *BaseModel[P]) OrderByDesc(field string) CondFunc {
return func(db *gorm.DB) *gorm.DB {
return db.Order(fmt.Sprintf("%s DESC", field))
}
}
// ID查询条件生成器
func (this *BaseModel[P]) WithId(id interface{}) CondFunc {
return func(db *gorm.DB) *gorm.DB {
return db.Where("id = ?", id)
}
}
// 状态查询条件生成器
func (this *BaseModel[P]) WithStatus(status int) CondFunc {
return func(db *gorm.DB) *gorm.DB {
return db.Where("status =?", status)
}
}
// 获取数据链接
func (this *BaseModel[P]) GetDb() *gorm.DB {
return this.Db
}
func (this *BaseModel[P]) WithLimit(limit int) CondFunc {
return func(db *gorm.DB) *gorm.DB {
return db.Limit(limit)
}
}