package dataTemp import ( "context" "database/sql" "fmt" "geo/tmpl/errcode" "geo/utils" "reflect" "github.com/go-kratos/kratos/v2/log" "gorm.io/gorm" "xorm.io/builder" ) type PrimaryKey struct { Id int `json:"id"` } type GormDb struct { Client *gorm.DB } type contextTxKey struct{} func (d *Db) DB(ctx context.Context) *gorm.DB { tx, ok := ctx.Value(contextTxKey{}).(*gorm.DB) if ok { return tx } return d.Db.Client } func (t *Db) ExecTx(ctx context.Context, f func(ctx context.Context) error) error { return t.Db.Client.WithContext(ctx).Transaction(func(tx *gorm.DB) error { ctx = context.WithValue(ctx, contextTxKey{}, tx) return f(ctx) }) } type Db struct { Db *GormDb Log *log.Helper } type DataTemp struct { Db *gorm.DB ModelType reflect.Type // 改为存储类型而不是实例 modelName string // 可选的表名缓存 } func NewDataTemp(db *utils.Db, model interface{}) *DataTemp { // 获取模型的类型 t := reflect.TypeOf(model) if t.Kind() == reflect.Ptr { t = t.Elem() } return &DataTemp{ Db: db.Client, ModelType: t, } } func (k DataTemp) modelInstance() interface{} { return reflect.New(k.ModelType).Interface() } func (k DataTemp) GetById(id int32) (data map[string]interface{}, err error) { err = k.Db.Model(k.modelInstance()).Where("id = ?", id).Find(&data).Error if data == nil { err = sql.ErrNoRows } return } func (k DataTemp) GetByStruct(ctx context.Context, search interface{}, data interface{}, orderBy string) (err error) { err = k.Db.Model(k.modelInstance()).WithContext(ctx).Where(search).Find(&data).Error if data == nil { err = sql.ErrNoRows } return } func (k DataTemp) SaveByStruct(search interface{}, data interface{}) (err error) { err = k.Db.Model(k.modelInstance()).Where(search).Save(&data).Error if data == nil { err = sql.ErrNoRows } return } func (k DataTemp) Add(ctx context.Context, data interface{}) (err error) { m := k.modelInstance() if err = k.Db.Model(m).WithContext(ctx).Create(data).Error; err != nil { return errcode.SqlErr(err) } return } func (k DataTemp) AddWithData(data interface{}) (interface{}, error) { result := k.Db.Model(k.modelInstance()).Create(data) if result.Error != nil { return data, result.Error } return data, nil } func (k DataTemp) GetList(cond *builder.Cond, pageBoIn *ReqPageBo) (list []map[string]interface{}, pageBoOut *RespPageBo, err error) { var ( query, _ = builder.ToBoundSQL(*cond) model = k.Db.Model(k.modelInstance()).Where(query) total int64 ) model.Count(&total) pageBoOut = pageBoOut.SetDataByReq(total, pageBoIn) model.Limit(pageBoIn.GetSize()).Offset(pageBoIn.GetOffset()).Order("updated_at desc").Find(&list) return } func (k DataTemp) GetRange(ctx context.Context, cond *builder.Cond) (list []map[string]interface{}, err error) { var ( query, _ = builder.ToBoundSQL(*cond) model = k.Db.Model(k.modelInstance()).Where(query) ) err = model.WithContext(ctx).Find(&list).Error return list, err } func (k DataTemp) GetRangeToMapStruct(ctx context.Context, cond *builder.Cond, data interface{}) (err error) { var ( query, _ = builder.ToBoundSQL(*cond) model = k.Db.Model(k.modelInstance()).Where(query) ) err = model.WithContext(ctx).Find(data).Error return err } func (k DataTemp) GetOneBySearch(cond *builder.Cond) (data map[string]interface{}, err error) { query, _ := builder.ToBoundSQL(*cond) if err = k.Db.Model(k.modelInstance()).Where(query).Limit(1).Find(&data).Error; err != nil { return nil, errcode.SqlErr(err) } return } func (k DataTemp) Exist(ctx context.Context, cond *builder.Cond) (bool, error) { var data map[string]interface{} query, _ := builder.ToBoundSQL(*cond) err := k.Db.WithContext(ctx).Model(k.modelInstance()).Where(query).Limit(1).Find(&data).Error if err != nil || data != nil { return true, err } return false, nil } func (k DataTemp) GetOneBySearchStruct(ctx context.Context, cond *builder.Cond, data interface{}) (err error) { query, _ := builder.ToBoundSQL(*cond) if err = k.Db.Model(k.modelInstance()).WithContext(ctx).Where(query).Limit(1).Find(&data).Error; err != nil { return errcode.SqlErr(err) } return } func (k DataTemp) GetListToStruct(ctx context.Context, cond *builder.Cond, pageBoIn *ReqPageBo, result interface{}, orderBy string) (pageBoOut *RespPageBo, err error) { // 参数验证 if result == nil { return nil, fmt.Errorf("result cannot be nil") } val := reflect.ValueOf(result) if val.Kind() != reflect.Ptr { return nil, fmt.Errorf("result must be a pointer") } elem := val.Elem() if elem.Kind() != reflect.Slice { return nil, fmt.Errorf("result must be a pointer to slice") } // 构建基础查询 query, _ := builder.ToBoundSQL(*cond) // 预编译 SQL 以提高性能 // 使用 Table 指定表名,避免 GORM 的反射开销 db := k.Db.WithContext(ctx).Model(k.modelInstance()).Where(query) // 获取总数(使用单独的计数查询,避免缓存影响) var total int64 countDb := db if pageBoIn != nil { if err = countDb.Count(&total).Error; err != nil { return nil, err } } // 初始化分页响应 pageBoOut = &RespPageBo{} pageBoOut = pageBoOut.SetDataByReq(total, pageBoIn) // 如果没有数据,直接返回空切片 if total == 0 && pageBoIn != nil { elem.Set(reflect.MakeSlice(elem.Type(), 0, 0)) return pageBoOut, nil } // 设置排序(使用索引字段提高性能) if orderBy == "" { orderBy = "updated_at desc" } // 应用分页和排序,执行查询 // 使用 Select 指定字段,避免查询所有字段(如果需要优化) baseQuery := db if pageBoIn != nil { baseQuery = db.Limit(pageBoIn.GetSize()).Offset(pageBoIn.GetOffset()).Order(orderBy) } if err = baseQuery. Order(orderBy). Find(result).Error; err != nil { return nil, err } return pageBoOut, nil } func (k DataTemp) UpdateByKey(ctx context.Context, key string, id interface{}, data interface{}) (err error) { if err = k.Db.WithContext(ctx).Model(k.modelInstance()).Where(fmt.Sprintf("%s = ?", key), id).Updates(data).Error; err != nil { return errcode.SqlErr(err) } return } func (k DataTemp) UpdateByCond(ctx context.Context, cond *builder.Cond, data interface{}) (err error) { var ( query, _ = builder.ToBoundSQL(*cond) model = k.Db.Model(k.modelInstance()).Where(query) ) err = model.WithContext(ctx).Updates(data).Error return err } func (k DataTemp) UpdateColumnByCond(ctx context.Context, cond *builder.Cond, column string, data interface{}) (err error) { var ( query, _ = builder.ToBoundSQL(*cond) model = k.Db.Model(k.modelInstance()).Where(query) ) err = model.WithContext(ctx).Update(column, data).Error return err } func (k DataTemp) GetByKey(ctx context.Context, key string, value interface{}, data interface{}) (err error) { if err = k.Db.WithContext(ctx).Model(k.modelInstance()).Where(fmt.Sprintf("%s = ?", key), value).Find(data).Error; err != nil { return errcode.SqlErr(err) } return } func (k DataTemp) DeleteByKey(ctx context.Context, key string, value int64) error { result := k.Db.WithContext(ctx).Model(k.modelInstance()).Where(fmt.Sprintf("%s = ?", key), value). Update("deleted_at", gorm.Expr("CURRENT_TIMESTAMP")) if result.Error != nil { return result.Error } if result.RowsAffected == 0 { return errcode.NotFound("不存在或已被删除") } return nil } func (k DataTemp) CountByCond(ctx context.Context, cond *builder.Cond) (int64, error) { var ( count int64 query, _ = builder.ToBoundSQL(*cond) model = k.Db.Model(k.modelInstance()).Where(query) ) err := model.WithContext(ctx).Count(&count).Error if err != nil { return 0, err } return count, err }