package repository

import (
	"context"
	"errors"

	"gorm.io/gorm"
	"gorm.io/gorm/clause"

	"knowlege-lsxd/internal/logger"
	"knowlege-lsxd/internal/types"
	"knowlege-lsxd/internal/types/interfaces"
)

var (
	// ErrTenantNotFound is returned by GetTenantByID when no tenant row
	// matches the requested ID.
	ErrTenantNotFound = errors.New("tenant not found")
	// ErrTenantHasKnowledgeBase reports that a tenant still has associated
	// knowledge bases. No method in this file returns it.
	// NOTE(review): presumably used by callers that guard tenant deletion — confirm.
	ErrTenantHasKnowledgeBase = errors.New("tenant has associated knowledge bases")
)

// tenantRepository implements interfaces.TenantRepository on top of a GORM
// database handle.
type tenantRepository struct {
	db *gorm.DB
}

// NewTenantRepository creates a tenant repository that runs its queries on db.
func NewTenantRepository(db *gorm.DB) interfaces.TenantRepository {
	return &tenantRepository{db: db}
}

// CreateTenant inserts tenant and returns the error reported by the database
// layer, if any.
func (r *tenantRepository) CreateTenant(ctx context.Context, tenant *types.Tenant) error {
	return r.db.WithContext(ctx).Create(tenant).Error
}

// GetTenantByID loads the tenant whose primary key equals id.
//
// It returns ErrTenantNotFound when no row matches; any other database error
// is returned unchanged.
func (r *tenantRepository) GetTenantByID(ctx context.Context, id uint) (*types.Tenant, error) {
	var tenant types.Tenant
	if err := r.db.WithContext(ctx).Where("id = ?", id).First(&tenant).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, ErrTenantNotFound
		}
		return nil, err
	}
	return &tenant, nil
}

// ListTenants returns every tenant, ordered by created_at descending.
func (r *tenantRepository) ListTenants(ctx context.Context) ([]*types.Tenant, error) {
	var tenants []*types.Tenant
	if err := r.db.WithContext(ctx).Order("created_at DESC").Find(&tenants).Error; err != nil {
		return nil, err
	}
	return tenants, nil
}

// UpdateTenant updates the row whose id equals tenant.ID with the values held
// in tenant.
//
// NOTE(review): Updates is called with a struct; per GORM's documentation
// that form only writes non-zero fields, so a field cannot be reset to its
// zero value through this method — confirm this is intended.
func (r *tenantRepository) UpdateTenant(ctx context.Context, tenant *types.Tenant) error {
	return r.db.WithContext(ctx).Model(&types.Tenant{}).Where("id = ?", tenant.ID).Updates(tenant).Error
}

// DeleteTenant deletes the tenant whose id equals id.
//
// Only the database error is returned; the number of affected rows is not
// inspected, so deleting a missing tenant is not reported as an error here.
func (r *tenantRepository) DeleteTenant(ctx context.Context, id uint) error {
	return r.db.WithContext(ctx).Where("id = ?", id).Delete(&types.Tenant{}).Error
}

// AdjustStorageUsed adds delta (which may be negative) to the tenant's
// StorageUsed inside a single transaction.
//
// The tenant row is read with an UPDATE-strength locking clause before being
// modified. If the adjusted value would be negative, an error is logged and
// the value is clamped to 0 before saving.
//
// Errors from loading or saving the tenant are returned unchanged; in
// particular a missing tenant is not translated to ErrTenantNotFound here.
func (r *tenantRepository) AdjustStorageUsed(ctx context.Context, tenantID uint, delta int64) error {
	return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		var tenant types.Tenant
		// Take a pessimistic lock on the row so concurrent adjustments are
		// applied one after another.
		if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).First(&tenant, tenantID).Error; err != nil {
			return err
		}

		tenant.StorageUsed += delta
		// Enforce the business rule that usage never goes below zero before
		// persisting the change.
		if tenant.StorageUsed < 0 {
			// The ID is an integer key, so it is formatted with %d; the
			// previous %s verb produced a malformed "%!s(...)" fragment.
			// NOTE(review): confirm logger.Error formats printf-style;
			// otherwise switch to the logger's formatting variant.
			logger.Error(ctx, "tenant storage used is negative %d: %d", tenantID, tenant.StorageUsed)
			tenant.StorageUsed = 0
		}

		return tx.Save(&tenant).Error
	})
}