l_ai_knowledge/internal/application/repository/knowledge.go

188 lines
5.7 KiB
Go

package repository
import (
"context"
"errors"
"gorm.io/gorm"
"knowlege-lsxd/internal/types"
"knowlege-lsxd/internal/types/interfaces"
)
// ErrKnowledgeNotFound is the sentinel returned by GetKnowledgeByID when no
// row matches the requested tenant and ID. Compare against it with errors.Is.
var ErrKnowledgeNotFound = errors.New("knowledge not found")
// knowledgeRepository is the GORM-backed type that NewKnowledgeRepository
// hands out as an interfaces.KnowledgeRepository.
type knowledgeRepository struct {
	// db is the GORM handle every query in this repository is issued through.
	db *gorm.DB
}
// NewKnowledgeRepository returns a KnowledgeRepository whose queries are
// executed through the supplied GORM handle.
func NewKnowledgeRepository(db *gorm.DB) interfaces.KnowledgeRepository {
	repo := new(knowledgeRepository)
	repo.db = db
	return repo
}
// CreateKnowledge inserts the given knowledge record and returns any error
// reported by GORM.
func (r *knowledgeRepository) CreateKnowledge(ctx context.Context, knowledge *types.Knowledge) error {
	return r.db.WithContext(ctx).Create(knowledge).Error
}
// GetKnowledgeByID loads the knowledge record with the given id that belongs
// to tenantID.
//
// It returns ErrKnowledgeNotFound when no row matches, and passes any other
// database error through unchanged.
func (r *knowledgeRepository) GetKnowledgeByID(ctx context.Context, tenantID uint, id string) (*types.Knowledge, error) {
	found := new(types.Knowledge)
	err := r.db.WithContext(ctx).
		Where("tenant_id = ? AND id = ?", tenantID, id).
		First(found).
		Error

	switch {
	case err == nil:
		return found, nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		// Translate GORM's not-found error into the package sentinel.
		return nil, ErrKnowledgeNotFound
	default:
		return nil, err
	}
}
// ListKnowledgeByKnowledgeBaseID returns every knowledge record of the given
// tenant that belongs to knowledge base kbID, newest first (created_at DESC).
func (r *knowledgeRepository) ListKnowledgeByKnowledgeBaseID(
	ctx context.Context, tenantID uint, kbID string,
) ([]*types.Knowledge, error) {
	var items []*types.Knowledge

	result := r.db.WithContext(ctx).
		Where("tenant_id = ? AND knowledge_base_id = ?", tenantID, kbID).
		Order("created_at DESC").
		Find(&items)
	if result.Error != nil {
		return nil, result.Error
	}
	return items, nil
}
// ListPagedKnowledgeByKnowledgeBaseID returns one page of the knowledge
// records in knowledge base kbID for the given tenant, newest first, together
// with the total number of matching records.
//
// The offset and limit are taken from page.Offset() and page.Limit(). On any
// database error it returns (nil, 0, err).
func (r *knowledgeRepository) ListPagedKnowledgeByKnowledgeBaseID(
	ctx context.Context,
	tenantID uint,
	kbID string,
	page *types.Pagination,
) ([]*types.Knowledge, int64, error) {
	const scope = "tenant_id = ? AND knowledge_base_id = ?"

	// Step 1: count every matching row, ignoring pagination.
	var total int64
	countErr := r.db.WithContext(ctx).
		Model(&types.Knowledge{}).
		Where(scope, tenantID, kbID).
		Count(&total).
		Error
	if countErr != nil {
		return nil, 0, countErr
	}

	// Step 2: fetch just the requested page. A fresh chain is started from
	// r.db so the count query's clauses are not carried over.
	var pageItems []*types.Knowledge
	findErr := r.db.WithContext(ctx).
		Where(scope, tenantID, kbID).
		Order("created_at DESC").
		Offset(page.Offset()).
		Limit(page.Limit()).
		Find(&pageItems).
		Error
	if findErr != nil {
		return nil, 0, findErr
	}

	return pageItems, total, nil
}
// UpdateKnowledge persists the given knowledge record via GORM's Save and
// returns any error it reports.
func (r *knowledgeRepository) UpdateKnowledge(ctx context.Context, knowledge *types.Knowledge) error {
	return r.db.WithContext(ctx).Save(knowledge).Error
}
// DeleteKnowledge deletes the knowledge record with the given id, restricted
// to rows owned by tenantID.
func (r *knowledgeRepository) DeleteKnowledge(ctx context.Context, tenantID uint, id string) error {
	scoped := r.db.WithContext(ctx).Where("tenant_id = ? AND id = ?", tenantID, id)
	return scoped.Delete(&types.Knowledge{}).Error
}
// DeleteKnowledgeList deletes, in a single statement, every knowledge record
// whose ID is in ids, restricted to rows owned by tenantID.
func (r *knowledgeRepository) DeleteKnowledgeList(ctx context.Context, tenantID uint, ids []string) error {
	scoped := r.db.WithContext(ctx).Where("tenant_id = ? AND id in ?", tenantID, ids)
	return scoped.Delete(&types.Knowledge{}).Error
}
// GetKnowledgeBatch loads the knowledge records of the given tenant whose IDs
// appear in ids.
//
// No ORDER BY is applied, so the result order is whatever the database
// returns. On a database error it returns (nil, err).
func (r *knowledgeRepository) GetKnowledgeBatch(
	ctx context.Context, tenantID uint, ids []string,
) ([]*types.Knowledge, error) {
	var knowledge []*types.Knowledge
	// The query deliberately does not call Debug(): that would force GORM to
	// log the full SQL (including every bound ID) on each invocation,
	// regardless of the configured logger level.
	if err := r.db.WithContext(ctx).
		Where("tenant_id = ? AND id IN ?", tenantID, ids).
		Find(&knowledge).Error; err != nil {
		return nil, err
	}
	return knowledge, nil
}
// CheckKnowledgeExists reports whether an enabled knowledge record matching
// params already exists in knowledge base kbID for the given tenant.
//
// Exactly one lookup is performed, chosen from params:
//   - Type "file" with a non-empty FileHash: match on file_hash only. A hash
//     miss is reported as "not found"; it does not fall back to name/size.
//   - Type "file" without a hash but with FileName and FileSize > 0: match on
//     file_name and file_size.
//   - Type "url" with a non-empty URL: match on type = 'url' and source.
//
// Any other combination yields (false, nil, nil) without querying. When a row
// is found it returns (true, row, nil); database errors other than
// gorm.ErrRecordNotFound are returned as (false, nil, err).
func (r *knowledgeRepository) CheckKnowledgeExists(
	ctx context.Context,
	tenantID uint,
	kbID string,
	params *types.KnowledgeCheckParams,
) (bool, *types.Knowledge, error) {
	lookup := r.db.WithContext(ctx).Model(&types.Knowledge{}).
		Where("tenant_id = ? AND knowledge_base_id = ? AND enable_status = ?", tenantID, kbID, "enabled")

	// Pick the single extra condition that identifies a duplicate.
	switch params.Type {
	case "file":
		switch {
		case params.FileHash != "":
			// A hash, when present, is the only criterion used.
			lookup = lookup.Where("file_hash = ?", params.FileHash)
		case params.FileName != "" && params.FileSize > 0:
			lookup = lookup.Where(
				"file_name = ? AND file_size = ?",
				params.FileName, params.FileSize,
			)
		default:
			return false, nil, nil
		}
	case "url":
		if params.URL == "" {
			return false, nil, nil
		}
		lookup = lookup.Where("type = 'url' AND source = ?", params.URL)
	default:
		// Unknown type: nothing to compare against, treat as not existing.
		return false, nil, nil
	}

	existing := new(types.Knowledge)
	err := lookup.First(existing).Error
	if err == nil {
		return true, existing, nil
	}
	if errors.Is(err, gorm.ErrRecordNotFound) {
		return false, nil, nil
	}
	return false, nil, err
}
// AminusB returns the IDs of the knowledge records in knowledge base A (owned
// by tenant Atenant) whose file_hash does not occur among the file_hash
// values of knowledge base B (owned by tenant Btenant).
//
// The comparison is done in the database with a NOT IN subquery. When nothing
// qualifies the result is an empty, non-nil slice. On a database error it
// returns (nil, err).
//
// NOTE(review): under standard SQL semantics NOT IN matches no rows at all if
// the subquery yields a NULL file_hash — confirm file_hash is never NULL for
// rows in B, or filter NULLs out of the subquery.
func (r *knowledgeRepository) AminusB(
	ctx context.Context,
	Atenant uint, A string,
	Btenant uint, B string,
) ([]string, error) {
	// Bind the caller's context so cancellation and deadlines reach the
	// query, as in every other method of this repository.
	db := r.db.WithContext(ctx)

	// file_hash values present in B; embedded below as a subquery, not run
	// on its own.
	subQuery := db.Model(&types.Knowledge{}).
		Where("tenant_id = ? AND knowledge_base_id = ?", Btenant, B).
		Select("file_hash")

	knowledgeIDs := []string{}
	err := db.Model(&types.Knowledge{}).
		Where("tenant_id = ? AND knowledge_base_id = ?", Atenant, A).
		Where("file_hash NOT IN (?)", subQuery).
		Pluck("id", &knowledgeIDs).
		Error
	// Pluck does not report gorm.ErrRecordNotFound for an empty result, so
	// any error here is a real failure.
	if err != nil {
		return nil, err
	}
	return knowledgeIDs, nil
}