diff --git a/internal/biz/session.go b/internal/biz/session.go index d8b45d8..2d3f419 100644 --- a/internal/biz/session.go +++ b/internal/biz/session.go @@ -45,7 +45,7 @@ func (s *SessionBiz) SessionInit(ctx context.Context, req *entitys.SessionInitRe // 获取 当天的最新session t := time.Now().Truncate(24 * time.Hour) - session, has, err := s.sessionRepo.FindOne( + session, has, err := s.sessionRepo.Take( s.sessionRepo.WithUserId(req.UserId), // 条件:用户ID s.sessionRepo.WithStartTime(t), // 条件:会话开始时间 s.sysRepo.WithSysId(sysConfig.SysID), diff --git a/internal/data/impl/base.go b/internal/data/impl/base.go index 8eb0b7d..501f710 100644 --- a/internal/data/impl/base.go +++ b/internal/data/impl/base.go @@ -43,20 +43,21 @@ type BaseRepository[P PO] interface { FindAll(conditions ...CondFunc) ([]P, error) // 查询所有 Paginate(page, pageSize int, conditions ...CondFunc) (*PaginationResult[P], error) // 分页查询 FindOne(conditions ...CondFunc) (P, bool, error) // 查询单条记录,若未找到则返回 has=false, err=nil - Create(m *P) error // 创建 - BatchCreate(m *[]P) (err error) // 批量创建 - Update(m *P, conditions ...CondFunc) (err error) // 更新 - Delete(conditions ...CondFunc) (err error) // 删除 - Count(conditions ...CondFunc) (count int64, err error) // 查询条数 - PaginateScope(page, pageSize int) CondFunc // 分页 - OrderByDesc(field string) CondFunc // 倒序排序 - OrderByAsc(field string) CondFunc // 正序排序 - WithId(id interface{}) CondFunc // 查询id - WithStatus(status int) CondFunc // 查询status - GetDb() *gorm.DB // 获取数据库连接 - WithLimit(limit int) CondFunc // 限制返回条数 - In(field string, values interface{}) CondFunc // 查询字段是否在列表中 - Select(fields ...string) CondFunc // 选择字段 + Take(conditions ...CondFunc) (P, bool, error) + Create(m *P) error // 创建 + BatchCreate(m *[]P) (err error) // 批量创建 + Update(m *P, conditions ...CondFunc) (err error) // 更新 + Delete(conditions ...CondFunc) (err error) // 删除 + Count(conditions ...CondFunc) (count int64, err error) // 查询条数 + PaginateScope(page, pageSize int) CondFunc // 分页 + OrderByDesc(field string) CondFunc // 倒序排序 + OrderByAsc(field string) CondFunc // 正序排序 + WithId(id interface{}) CondFunc // 查询id + WithStatus(status int) CondFunc // 查询status + GetDb() *gorm.DB // 获取数据库连接 + WithLimit(limit int) CondFunc // 限制返回条数 + In(field string, values interface{}) CondFunc // 查询字段是否在列表中 + Select(fields ...string) CondFunc // 选择字段 } // PaginationResult 分页查询结果 @@ -128,6 +129,26 @@ func (this *BaseModel[P]) FindOne(conditions ...CondFunc) (P, bool, error) { return result, true, err } +func (this *BaseModel[P]) Take(conditions ...CondFunc) (P, bool, error) { + var ( + result P + ) + + err := this.Db.Model(new(P)). + Scopes(conditions...). + Take(&result). + Error + + if errors.Is(err, gorm.ErrRecordNotFound) { + return result, false, nil // 未找到记录 + } + if err != nil { + return result, false, fmt.Errorf("查询单条记录失败: %w", err) + } + + return result, true, err +} + // 创建 func (this *BaseModel[P]) Create(m *P) error { err := this.Db.Create(m).Error