290 lines
7.6 KiB
Go
290 lines
7.6 KiB
Go
package dataTemp
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"geo/tmpl/errcode"
|
|
"geo/utils"
|
|
"reflect"
|
|
|
|
"github.com/go-kratos/kratos/v2/log"
|
|
"gorm.io/gorm"
|
|
|
|
"xorm.io/builder"
|
|
)
|
|
|
|
// PrimaryKey carries a single integer identifier that is exposed under the
// "id" key when encoded as JSON.
type PrimaryKey struct {
	// Id is the identifier value.
	Id int `json:"id"`
}
|
|
|
|
type GormDb struct {
|
|
Client *gorm.DB
|
|
}
|
|
// contextTxKey is the unexported context key under which ExecTx stores the
// in-flight transaction handle and DB looks it up again.
type contextTxKey struct{}
|
|
|
|
func (d *Db) DB(ctx context.Context) *gorm.DB {
|
|
tx, ok := ctx.Value(contextTxKey{}).(*gorm.DB)
|
|
if ok {
|
|
return tx
|
|
}
|
|
return d.Db.Client
|
|
}
|
|
|
|
func (t *Db) ExecTx(ctx context.Context, f func(ctx context.Context) error) error {
|
|
return t.Db.Client.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
|
ctx = context.WithValue(ctx, contextTxKey{}, tx)
|
|
return f(ctx)
|
|
})
|
|
}
|
|
|
|
type Db struct {
|
|
Db *GormDb
|
|
Log *log.Helper
|
|
}
|
|
|
|
type DataTemp struct {
|
|
Db *gorm.DB
|
|
ModelType reflect.Type // 改为存储类型而不是实例
|
|
modelName string // 可选的表名缓存
|
|
}
|
|
|
|
func NewDataTemp(db *utils.Db, model interface{}) *DataTemp {
|
|
// 获取模型的类型
|
|
t := reflect.TypeOf(model)
|
|
if t.Kind() == reflect.Ptr {
|
|
t = t.Elem()
|
|
}
|
|
|
|
return &DataTemp{
|
|
Db: db.Client,
|
|
ModelType: t,
|
|
}
|
|
}
|
|
|
|
func (k DataTemp) modelInstance() interface{} {
|
|
return reflect.New(k.ModelType).Interface()
|
|
}
|
|
|
|
func (k DataTemp) GetById(id int32) (data map[string]interface{}, err error) {
|
|
err = k.Db.Model(k.modelInstance()).Where("id = ?", id).Find(&data).Error
|
|
if data == nil {
|
|
err = sql.ErrNoRows
|
|
}
|
|
return
|
|
}
|
|
|
|
func (k DataTemp) GetByStruct(ctx context.Context, search interface{}, data interface{}, orderBy string) (err error) {
|
|
|
|
err = k.Db.Model(k.modelInstance()).WithContext(ctx).Where(search).Find(&data).Error
|
|
|
|
if data == nil {
|
|
err = sql.ErrNoRows
|
|
}
|
|
return
|
|
}
|
|
|
|
func (k DataTemp) SaveByStruct(search interface{}, data interface{}) (err error) {
|
|
err = k.Db.Model(k.modelInstance()).Where(search).Save(&data).Error
|
|
if data == nil {
|
|
err = sql.ErrNoRows
|
|
}
|
|
return
|
|
}
|
|
|
|
func (k DataTemp) Add(ctx context.Context, data interface{}) (err error) {
|
|
m := k.modelInstance()
|
|
if err = k.Db.Model(m).WithContext(ctx).Create(data).Error; err != nil {
|
|
return errcode.SqlErr(err)
|
|
}
|
|
return
|
|
}
|
|
|
|
func (k DataTemp) AddWithData(data interface{}) (interface{}, error) {
|
|
result := k.Db.Model(k.modelInstance()).Create(data)
|
|
if result.Error != nil {
|
|
return data, result.Error
|
|
}
|
|
return data, nil
|
|
}
|
|
|
|
func (k DataTemp) GetList(cond *builder.Cond, pageBoIn *ReqPageBo) (list []map[string]interface{}, pageBoOut *RespPageBo, err error) {
|
|
var (
|
|
query, _ = builder.ToBoundSQL(*cond)
|
|
model = k.Db.Model(k.modelInstance()).Where(query)
|
|
total int64
|
|
)
|
|
model.Count(&total)
|
|
pageBoOut = pageBoOut.SetDataByReq(total, pageBoIn)
|
|
model.Limit(pageBoIn.GetSize()).Offset(pageBoIn.GetOffset()).Order("updated_at desc").Find(&list)
|
|
return
|
|
}
|
|
|
|
func (k DataTemp) GetRange(ctx context.Context, cond *builder.Cond) (list []map[string]interface{}, err error) {
|
|
var (
|
|
query, _ = builder.ToBoundSQL(*cond)
|
|
model = k.Db.Model(k.modelInstance()).Where(query)
|
|
)
|
|
err = model.WithContext(ctx).Find(&list).Error
|
|
return list, err
|
|
}
|
|
|
|
func (k DataTemp) GetRangeToMapStruct(ctx context.Context, cond *builder.Cond, data interface{}) (err error) {
|
|
var (
|
|
query, _ = builder.ToBoundSQL(*cond)
|
|
model = k.Db.Model(k.modelInstance()).Where(query)
|
|
)
|
|
err = model.WithContext(ctx).Find(data).Error
|
|
return err
|
|
}
|
|
|
|
func (k DataTemp) GetOneBySearch(cond *builder.Cond) (data map[string]interface{}, err error) {
|
|
query, _ := builder.ToBoundSQL(*cond)
|
|
if err = k.Db.Model(k.modelInstance()).Where(query).Limit(1).Find(&data).Error; err != nil {
|
|
return nil, errcode.SqlErr(err)
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func (k DataTemp) Exist(ctx context.Context, cond *builder.Cond) (bool, error) {
|
|
var data map[string]interface{}
|
|
query, _ := builder.ToBoundSQL(*cond)
|
|
err := k.Db.WithContext(ctx).Model(k.modelInstance()).Where(query).Limit(1).Find(&data).Error
|
|
if err != nil || data != nil {
|
|
return true, err
|
|
}
|
|
return false, nil
|
|
}
|
|
|
|
func (k DataTemp) GetOneBySearchStruct(ctx context.Context, cond *builder.Cond, data interface{}) (err error) {
|
|
query, _ := builder.ToBoundSQL(*cond)
|
|
if err = k.Db.Model(k.modelInstance()).WithContext(ctx).Where(query).Limit(1).Find(&data).Error; err != nil {
|
|
return errcode.SqlErr(err)
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func (k DataTemp) GetListToStruct(ctx context.Context, cond *builder.Cond, pageBoIn *ReqPageBo, result interface{}, orderBy string) (pageBoOut *RespPageBo, err error) {
|
|
// 参数验证
|
|
if result == nil {
|
|
return nil, fmt.Errorf("result cannot be nil")
|
|
}
|
|
|
|
val := reflect.ValueOf(result)
|
|
if val.Kind() != reflect.Ptr {
|
|
return nil, fmt.Errorf("result must be a pointer")
|
|
}
|
|
|
|
elem := val.Elem()
|
|
if elem.Kind() != reflect.Slice {
|
|
return nil, fmt.Errorf("result must be a pointer to slice")
|
|
}
|
|
|
|
// 构建基础查询
|
|
query, _ := builder.ToBoundSQL(*cond)
|
|
|
|
// 预编译 SQL 以提高性能
|
|
// 使用 Table 指定表名,避免 GORM 的反射开销
|
|
|
|
db := k.Db.WithContext(ctx).Model(k.modelInstance()).Where(query)
|
|
|
|
// 获取总数(使用单独的计数查询,避免缓存影响)
|
|
var total int64
|
|
countDb := db
|
|
if pageBoIn != nil {
|
|
if err = countDb.Count(&total).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
// 初始化分页响应
|
|
pageBoOut = &RespPageBo{}
|
|
pageBoOut = pageBoOut.SetDataByReq(total, pageBoIn)
|
|
|
|
// 如果没有数据,直接返回空切片
|
|
if total == 0 && pageBoIn != nil {
|
|
elem.Set(reflect.MakeSlice(elem.Type(), 0, 0))
|
|
return pageBoOut, nil
|
|
}
|
|
|
|
// 设置排序(使用索引字段提高性能)
|
|
if orderBy == "" {
|
|
orderBy = "updated_at desc"
|
|
}
|
|
|
|
// 应用分页和排序,执行查询
|
|
// 使用 Select 指定字段,避免查询所有字段(如果需要优化)
|
|
baseQuery := db
|
|
if pageBoIn != nil {
|
|
baseQuery = db.Limit(pageBoIn.GetSize()).Offset(pageBoIn.GetOffset()).Order(orderBy)
|
|
}
|
|
if err = baseQuery.
|
|
Order(orderBy).
|
|
Find(result).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return pageBoOut, nil
|
|
}
|
|
|
|
func (k DataTemp) UpdateByKey(ctx context.Context, key string, id interface{}, data interface{}) (err error) {
|
|
if err = k.Db.WithContext(ctx).Model(k.modelInstance()).Where(fmt.Sprintf("%s = ?", key), id).Updates(data).Error; err != nil {
|
|
return errcode.SqlErr(err)
|
|
}
|
|
return
|
|
}
|
|
|
|
func (k DataTemp) UpdateByCond(ctx context.Context, cond *builder.Cond, data interface{}) (err error) {
|
|
var (
|
|
query, _ = builder.ToBoundSQL(*cond)
|
|
model = k.Db.Model(k.modelInstance()).Where(query)
|
|
)
|
|
err = model.WithContext(ctx).Updates(data).Error
|
|
return err
|
|
}
|
|
|
|
func (k DataTemp) UpdateColumnByCond(ctx context.Context, cond *builder.Cond, column string, data interface{}) (err error) {
|
|
var (
|
|
query, _ = builder.ToBoundSQL(*cond)
|
|
model = k.Db.Model(k.modelInstance()).Where(query)
|
|
)
|
|
err = model.WithContext(ctx).Update(column, data).Error
|
|
return err
|
|
}
|
|
|
|
func (k DataTemp) GetByKey(ctx context.Context, key string, value interface{}, data interface{}) (err error) {
|
|
if err = k.Db.WithContext(ctx).Model(k.modelInstance()).Where(fmt.Sprintf("%s = ?", key), value).Find(data).Error; err != nil {
|
|
return errcode.SqlErr(err)
|
|
}
|
|
return
|
|
}
|
|
|
|
func (k DataTemp) DeleteByKey(ctx context.Context, key string, value int64) error {
|
|
result := k.Db.WithContext(ctx).Model(k.modelInstance()).Where(fmt.Sprintf("%s = ?", key), value).
|
|
Update("deleted_at", gorm.Expr("CURRENT_TIMESTAMP"))
|
|
|
|
if result.Error != nil {
|
|
return result.Error
|
|
}
|
|
if result.RowsAffected == 0 {
|
|
return errcode.NotFound("不存在或已被删除")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (k DataTemp) CountByCond(ctx context.Context, cond *builder.Cond) (int64, error) {
|
|
var (
|
|
count int64
|
|
query, _ = builder.ToBoundSQL(*cond)
|
|
model = k.Db.Model(k.modelInstance()).Where(query)
|
|
)
|
|
err := model.WithContext(ctx).Count(&count).Error
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return count, err
|
|
}
|