l_ai_knowledge/internal/application/service/graph.go

package service
import (
"context"
"encoding/json"
"fmt"
"math"
"slices"
"strings"
"sync"
"time"
"github.com/google/uuid"
"golang.org/x/sync/errgroup"
"knowlege-lsxd/internal/common"
"knowlege-lsxd/internal/config"
"knowlege-lsxd/internal/logger"
"knowlege-lsxd/internal/models/chat"
"knowlege-lsxd/internal/models/utils"
"knowlege-lsxd/internal/types"
)
const (
// DefaultLLMTemperature is a low temperature used for more deterministic LLM output
DefaultLLMTemperature = 0.1
// PMIWeight is the proportion of PMI in the relationship weight calculation
PMIWeight = 0.6
// StrengthWeight is the proportion of relationship strength in the relationship weight calculation
StrengthWeight = 0.4
// IndirectRelationWeightDecay is the decay coefficient applied to indirect relationship weights
IndirectRelationWeightDecay = 0.5
// MaxConcurrentEntityExtractions is the maximum concurrency for entity extraction
MaxConcurrentEntityExtractions = 4
// MaxConcurrentRelationExtractions is the maximum concurrency for relationship extraction
MaxConcurrentRelationExtractions = 4
// DefaultRelationBatchSize is the default batch size (in chunks) for relationship extraction
DefaultRelationBatchSize = 5
// MinEntitiesForRelation is the minimum number of entities required for relationship extraction
MinEntitiesForRelation = 2
// MinWeightValue is the minimum weight value, used to avoid division by zero
MinWeightValue = 1.0
// WeightScaleFactor scales normalized combined weights into the 1-10 range
WeightScaleFactor = 9.0
)
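
// A worked example of how these constants combine in calculateWeights below
// (illustrative numbers, not from a real run): with a normalized PMI of 0.8
// and a normalized strength of 0.5, the combined weight is
// 0.6*0.8 + 0.4*0.5 = 0.68, which scales to 1.0 + 9.0*0.68 = 7.12 on the
// final 1-10 range.
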
// ChunkRelation represents a relationship between two Chunks
type ChunkRelation struct {
// Weight is the relationship weight, calculated from PMI and strength
Weight float64
// Degree is the combined degree of the related entities
Degree int
}
// graphBuilder implements knowledge graph construction functionality
type graphBuilder struct {
config *config.Config
entityMap map[string]*types.Entity // Entities indexed by ID
entityMapByTitle map[string]*types.Entity // Entities indexed by title
relationshipMap map[string]*types.Relationship // Relationship mapping
chatModel chat.Chat
chunkGraph map[string]map[string]*ChunkRelation // Document chunk relationship graph
mutex sync.RWMutex // Mutex for concurrent operations
}
// NewGraphBuilder creates a new graph builder
func NewGraphBuilder(config *config.Config, chatModel chat.Chat) types.GraphBuilder {
logger.Info(context.Background(), "Creating new graph builder")
return &graphBuilder{
config: config,
chatModel: chatModel,
entityMap: make(map[string]*types.Entity),
entityMapByTitle: make(map[string]*types.Entity),
relationshipMap: make(map[string]*types.Relationship),
chunkGraph: make(map[string]map[string]*ChunkRelation),
}
}
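
// Usage sketch (hedged: cfg, chatModel, ctx, and chunks are assumed to be
// constructed elsewhere in the application; GetRelationChunks is shown on the
// assumption that it is exposed via the types.GraphBuilder interface):
//
//	builder := NewGraphBuilder(cfg, chatModel)
//	if err := builder.BuildGraph(ctx, chunks); err != nil {
//		return err
//	}
//	related := builder.GetRelationChunks(chunks[0].ID, 5) // top 5 related chunk IDs
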
// extractEntities extracts entities from text chunks
// It uses LLM to analyze text content and identify relevant entities
func (b *graphBuilder) extractEntities(ctx context.Context, chunk *types.Chunk) ([]*types.Entity, error) {
log := logger.GetLogger(ctx)
log.Infof("Extracting entities from chunk: %s", chunk.ID)
if chunk.Content == "" {
log.Warn("Empty chunk content, skipping entity extraction")
return []*types.Entity{}, nil
}
// Create prompt for entity extraction
thinking := false
messages := []chat.Message{
{
Role: "system",
Content: b.config.Conversation.ExtractEntitiesPrompt,
},
{
Role: "user",
Content: chunk.Content,
},
}
// Call LLM to extract entities
log.Debug("Calling LLM to extract entities")
resp, err := b.chatModel.Chat(ctx, messages, &chat.ChatOptions{
Temperature: DefaultLLMTemperature,
Thinking: &thinking,
})
if err != nil {
log.WithError(err).Error("Failed to extract entities from chunk")
return nil, fmt.Errorf("LLM entity extraction failed: %w", err)
}
// Parse JSON response
var extractedEntities []*types.Entity
if err := common.ParseLLMJsonResponse(resp.Content, &extractedEntities); err != nil {
log.WithError(err).Errorf("Failed to parse entity extraction response, rsp content: %s", resp.Content)
return nil, fmt.Errorf("failed to parse entity extraction response: %w", err)
}
log.Infof("Extracted %d entities from chunk", len(extractedEntities))
// Print detailed entity information in a clear format
log.Info("=========== EXTRACTED ENTITIES ===========")
for i, entity := range extractedEntities {
log.Infof("[Entity %d] Title: '%s', Description: '%s'", i+1, entity.Title, entity.Description)
}
log.Info("=========================================")
var entities []*types.Entity
// Process entities and update entityMap
b.mutex.Lock()
defer b.mutex.Unlock()
for _, entity := range extractedEntities {
if entity.Title == "" || entity.Description == "" {
log.WithField("entity", entity).Warn("Invalid entity with empty title or description")
continue
}
if existEntity, exists := b.entityMapByTitle[entity.Title]; !exists {
// This is a new entity
entity.ID = uuid.New().String()
entity.ChunkIDs = []string{chunk.ID}
entity.Frequency = 1
b.entityMapByTitle[entity.Title] = entity
b.entityMap[entity.ID] = entity
entities = append(entities, entity)
log.Debugf("New entity added: %s (ID: %s)", entity.Title, entity.ID)
} else {
// Entity already exists, update its ChunkIDs
if !slices.Contains(existEntity.ChunkIDs, chunk.ID) {
existEntity.ChunkIDs = append(existEntity.ChunkIDs, chunk.ID)
log.Debugf("Updated existing entity: %s with chunk: %s", entity.Title, chunk.ID)
}
existEntity.Frequency++
entities = append(entities, existEntity)
}
}
log.Infof("Completed entity extraction for chunk %s: %d entities", chunk.ID, len(entities))
return entities, nil
}
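
// ParseLLMJsonResponse above unmarshals the model output into []*types.Entity,
// so the LLM is expected to return a JSON array of entity objects. An
// illustrative shape (field names inferred from the Title/Description usage in
// this file; the actual JSON tags are defined on types.Entity and the real
// schema comes from ExtractEntitiesPrompt):
//
//	[
//	  {"title": "Knowledge Graph", "description": "A graph of entities and relationships"},
//	  {"title": "PMI", "description": "Pointwise mutual information, used for edge weighting"}
//	]
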
// extractRelationships extracts relationships between entities
// It analyzes semantic connections between multiple entities and establishes relationships
func (b *graphBuilder) extractRelationships(ctx context.Context,
chunks []*types.Chunk, entities []*types.Entity) error {
log := logger.GetLogger(ctx)
log.Infof("Extracting relationships from %d entities across %d chunks", len(entities), len(chunks))
if len(entities) < MinEntitiesForRelation {
log.Info("Not enough entities to form relationships (minimum 2)")
return nil
}
// Serialize entities to build prompt
entitiesJSON, err := json.Marshal(entities)
if err != nil {
log.WithError(err).Error("Failed to serialize entities to JSON")
return fmt.Errorf("failed to serialize entities: %w", err)
}
// Merge chunk contents
content := b.mergeChunkContents(chunks)
if content == "" {
log.Warn("No content to extract relationships from")
return nil
}
// Create relationship extraction prompt
thinking := false
messages := []chat.Message{
{
Role: "system",
Content: b.config.Conversation.ExtractRelationshipsPrompt,
},
{
Role: "user",
Content: fmt.Sprintf("Entities: %s\n\nText: %s", string(entitiesJSON), content),
},
}
// Call LLM to extract relationships
log.Debug("Calling LLM to extract relationships")
resp, err := b.chatModel.Chat(ctx, messages, &chat.ChatOptions{
Temperature: DefaultLLMTemperature,
Thinking: &thinking,
})
if err != nil {
log.WithError(err).Error("Failed to extract relationships")
return fmt.Errorf("LLM relationship extraction failed: %w", err)
}
// Parse JSON response
var extractedRelationships []*types.Relationship
if err := common.ParseLLMJsonResponse(resp.Content, &extractedRelationships); err != nil {
log.WithError(err).Error("Failed to parse relationship extraction response")
return fmt.Errorf("failed to parse relationship extraction response: %w", err)
}
log.Infof("Extracted %d relationships", len(extractedRelationships))
// Print detailed relationship information in a clear format
log.Info("========= EXTRACTED RELATIONSHIPS =========")
for i, rel := range extractedRelationships {
log.Infof("[Relation %d] Source: '%s', Target: '%s', Description: '%s', Strength: %d",
i+1, rel.Source, rel.Target, rel.Description, rel.Strength)
}
log.Info("===========================================")
// Process relationships and update relationshipMap
b.mutex.Lock()
defer b.mutex.Unlock()
relationshipsAdded := 0
relationshipsUpdated := 0
for _, relationship := range extractedRelationships {
key := fmt.Sprintf("%s#%s", relationship.Source, relationship.Target)
relationChunkIDs := b.findRelationChunkIDs(relationship.Source, relationship.Target, entities)
if len(relationChunkIDs) == 0 {
log.Debugf("Skipping relationship %s -> %s: no common chunks", relationship.Source, relationship.Target)
continue
}
if existingRel, exists := b.relationshipMap[key]; !exists {
// This is a new relationship
relationship.ID = uuid.New().String()
relationship.ChunkIDs = relationChunkIDs
b.relationshipMap[key] = relationship
relationshipsAdded++
log.Debugf("New relationship added: %s -> %s (ID: %s)",
relationship.Source, relationship.Target, relationship.ID)
} else {
// This relationship already exists, update its properties
chunkIDsAdded := 0
for _, chunkID := range relationChunkIDs {
if !slices.Contains(existingRel.ChunkIDs, chunkID) {
existingRel.ChunkIDs = append(existingRel.ChunkIDs, chunkID)
chunkIDsAdded++
}
}
// Update strength as an integer weighted average of the existing strength and the new strength; note the chunk count used as the weight already includes any chunk IDs appended above
if len(existingRel.ChunkIDs) > 0 {
existingRel.Strength = (existingRel.Strength*len(existingRel.ChunkIDs) + relationship.Strength) /
(len(existingRel.ChunkIDs) + 1)
}
if chunkIDsAdded > 0 {
relationshipsUpdated++
log.Debugf("Updated relationship: %s -> %s with %d new chunks",
relationship.Source, relationship.Target, chunkIDsAdded)
}
}
}
log.Infof("Relationship extraction completed: added %d, updated %d relationships",
relationshipsAdded, relationshipsUpdated)
return nil
}
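
// As with entity extraction, the parser expects a JSON array. An illustrative
// relationship shape (field names inferred from the Source/Target/Description/
// Strength usage above; the actual JSON tags are defined on types.Relationship
// and the real schema comes from ExtractRelationshipsPrompt):
//
//	[
//	  {"source": "PMI", "target": "Knowledge Graph",
//	   "description": "provides edge weights for", "strength": 8}
//	]
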
// findRelationChunkIDs collects the chunk IDs associated with the source or target entity (their union); an empty result means the relationship has no supporting chunks
func (b *graphBuilder) findRelationChunkIDs(source, target string, entities []*types.Entity) []string {
relationChunkIDs := make(map[string]struct{})
// Collect all document chunk IDs for source and target entities
for _, entity := range entities {
if entity.Title == source || entity.Title == target {
for _, chunkID := range entity.ChunkIDs {
relationChunkIDs[chunkID] = struct{}{}
}
}
}
if len(relationChunkIDs) == 0 {
return []string{}
}
// Convert map keys to slice
result := make([]string, 0, len(relationChunkIDs))
for chunkID := range relationChunkIDs {
result = append(result, chunkID)
}
return result
}
// mergeChunkContents merges content from multiple document chunks
// It accounts for overlapping portions between chunks to ensure coherent content
func (b *graphBuilder) mergeChunkContents(chunks []*types.Chunk) string {
if len(chunks) == 0 {
return ""
}
chunkContents := chunks[0].Content
preChunk := chunks[0]
for i := 1; i < len(chunks); i++ {
// Only append the part of the current chunk that does not overlap the previous one
if preChunk.EndAt > chunks[i].StartAt {
// Overlap length: the number of leading runes of the current chunk already covered by the previous chunk (assuming StartAt/EndAt are rune offsets)
startPos := preChunk.EndAt - chunks[i].StartAt
if startPos >= 0 && startPos < len([]rune(chunks[i].Content)) {
chunkContents = chunkContents + string([]rune(chunks[i].Content)[startPos:])
}
} else {
// If there's no overlap between chunks, add all content
chunkContents = chunkContents + chunks[i].Content
}
preChunk = chunks[i]
}
return chunkContents
}
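
// A worked example of the overlap handling above (illustrative offsets,
// assuming StartAt/EndAt are rune positions in the source document): if the
// previous chunk covers [0, 100) and the next chunk covers [90, 190), then
// startPos = 100 - 90 = 10, so the first 10 runes of the next chunk are
// skipped and only its runes [10:] are appended, avoiding duplicated text in
// the merged content.
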
// BuildGraph constructs the knowledge graph
// It serves as the main entry point for the graph building process, coordinating all components
func (b *graphBuilder) BuildGraph(ctx context.Context, chunks []*types.Chunk) error {
log := logger.GetLogger(ctx)
log.Infof("Building knowledge graph from %d chunks", len(chunks))
startTime := time.Now()
// Concurrently extract entities from each document chunk
var chunkEntities = make([][]*types.Entity, len(chunks))
g, gctx := errgroup.WithContext(ctx)
g.SetLimit(MaxConcurrentEntityExtractions) // Limit concurrency
for i, chunk := range chunks {
i, chunk := i, chunk // Capture loop variables for the goroutine (required before Go 1.22)
g.Go(func() error {
log.Debugf("Processing chunk %d/%d (ID: %s)", i+1, len(chunks), chunk.ID)
entities, err := b.extractEntities(gctx, chunk)
if err != nil {
log.WithError(err).Errorf("Failed to extract entities from chunk %s", chunk.ID)
return fmt.Errorf("entity extraction failed for chunk %s: %w", chunk.ID, err)
}
chunkEntities[i] = entities
return nil
})
}
// Wait for all entity extractions to complete
if err := g.Wait(); err != nil {
log.WithError(err).Error("Entity extraction failed")
return fmt.Errorf("entity extraction process failed: %w", err)
}
// Count total extracted entities
totalEntityCount := 0
for _, entities := range chunkEntities {
totalEntityCount += len(entities)
}
log.Infof("Successfully extracted %d total entities across %d chunks",
totalEntityCount, len(chunks))
// Process relationships in batches concurrently
relationChunkSize := DefaultRelationBatchSize
log.Infof("Processing relationships concurrently in batches of %d chunks", relationChunkSize)
// prepare relationship extraction batches
var relationBatches []struct {
batchChunks []*types.Chunk
relationUseEntities []*types.Entity
batchIndex int
}
for i, batchChunks := range utils.ChunkSlice(chunks, relationChunkSize) {
start := i * relationChunkSize
end := start + relationChunkSize
if end > len(chunkEntities) {
end = len(chunkEntities)
}
// Merge all entities in this batch
relationUseEntities := make([]*types.Entity, 0)
for j := start; j < end; j++ {
if j < len(chunkEntities) {
relationUseEntities = append(relationUseEntities, chunkEntities[j]...)
}
}
if len(relationUseEntities) < MinEntitiesForRelation {
log.Debugf("Skipping batch %d: not enough entities (%d)", i+1, len(relationUseEntities))
continue
}
relationBatches = append(relationBatches, struct {
batchChunks []*types.Chunk
relationUseEntities []*types.Entity
batchIndex int
}{
batchChunks: batchChunks,
relationUseEntities: relationUseEntities,
batchIndex: i,
})
}
// extract relationships concurrently
relG, relGctx := errgroup.WithContext(ctx)
relG.SetLimit(MaxConcurrentRelationExtractions) // use dedicated relationship extraction concurrency limit
for _, batch := range relationBatches {
batch := batch // Capture the loop variable for the goroutine (required before Go 1.22, matching the entity-extraction loop above)
relG.Go(func() error {
log.Debugf("Processing relationship batch %d (%d chunks)", batch.batchIndex+1, len(batch.batchChunks))
err := b.extractRelationships(relGctx, batch.batchChunks, batch.relationUseEntities)
if err != nil {
log.WithError(err).Errorf("Failed to extract relationships for batch %d", batch.batchIndex+1)
}
return nil // continue to process other batches even if the current batch fails
})
}
// wait for all relationship extractions to complete; batch errors are logged inside the goroutines rather than returned, so Wait is not expected to fail here
if err := relG.Wait(); err != nil {
log.WithError(err).Error("Some relationship extraction tasks failed")
// continue with the remaining steps: the relationships that were extracted successfully are still useful
}
// Calculate relationship weights
log.Info("Calculating weights for relationships")
b.calculateWeights(ctx)
// Calculate entity degrees
log.Info("Calculating degrees for entities")
b.calculateDegrees(ctx)
// Build Chunk graph
log.Info("Building chunk relationship graph")
b.buildChunkGraph(ctx)
log.Infof("Graph building completed in %.2f seconds: %d entities, %d relationships",
time.Since(startTime).Seconds(), len(b.entityMap), len(b.relationshipMap))
// generate knowledge graph visualization diagram
mermaidDiagram := b.generateKnowledgeGraphDiagram(ctx)
log.Info("Knowledge graph visualization diagram:")
log.Info(mermaidDiagram)
return nil
}
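
// Batching sketch for the relationship phase above (illustrative sizes,
// assuming utils.ChunkSlice splits the slice in order into fixed-size pieces):
// with 12 chunks and DefaultRelationBatchSize = 5, the batches cover chunks
// [0:5), [5:10), and [10:12), and batch i is paired with the entities
// extracted from those same chunks (chunkEntities[i*5 : min((i+1)*5, len)]),
// so each LLM call sees entities together with their source text.
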
// calculateWeights calculates relationship weights
// It combines Pointwise Mutual Information (PMI) and strength values to compute each relationship weight
func (b *graphBuilder) calculateWeights(ctx context.Context) {
log := logger.GetLogger(ctx)
log.Info("Calculating relationship weights using PMI and strength")
// Calculate total entity occurrences
totalEntityOccurrences := 0
entityFrequency := make(map[string]int)
for _, entity := range b.entityMap {
frequency := len(entity.ChunkIDs)
entityFrequency[entity.Title] = frequency
totalEntityOccurrences += frequency
}
// Calculate total relationship occurrences
totalRelOccurrences := 0
for _, rel := range b.relationshipMap {
totalRelOccurrences += len(rel.ChunkIDs)
}
// Skip calculation if insufficient data
if totalEntityOccurrences == 0 || totalRelOccurrences == 0 {
log.Warn("Insufficient data for weight calculation")
return
}
// Track maximum PMI and Strength values for normalization
maxPMI := 0.0
maxStrength := MinWeightValue // Avoid division by zero
// First calculate PMI and find maximum values
pmiValues := make(map[string]float64)
for _, rel := range b.relationshipMap {
sourceFreq := entityFrequency[rel.Source]
targetFreq := entityFrequency[rel.Target]
relFreq := len(rel.ChunkIDs)
if sourceFreq > 0 && targetFreq > 0 && relFreq > 0 {
sourceProbability := float64(sourceFreq) / float64(totalEntityOccurrences)
targetProbability := float64(targetFreq) / float64(totalEntityOccurrences)
relProbability := float64(relFreq) / float64(totalRelOccurrences)
// Positive PMI: max(log2(P(x,y) / (P(x) * P(y))), 0), clamping negative associations to zero
pmi := math.Max(math.Log2(relProbability/(sourceProbability*targetProbability)), 0)
pmiValues[rel.ID] = pmi
if pmi > maxPMI {
maxPMI = pmi
}
}
// Record maximum Strength value
if float64(rel.Strength) > maxStrength {
maxStrength = float64(rel.Strength)
}
}
// Combine PMI and Strength to calculate final weights
for _, rel := range b.relationshipMap {
pmi := pmiValues[rel.ID]
// Normalize PMI and Strength (0-1 range)
normalizedPMI := 0.0
if maxPMI > 0 {
normalizedPMI = pmi / maxPMI
}
normalizedStrength := float64(rel.Strength) / maxStrength
// Combine PMI and Strength using configured weights
combinedWeight := normalizedPMI*PMIWeight + normalizedStrength*StrengthWeight
// Scale weight to 1-10 range
scaledWeight := 1.0 + WeightScaleFactor*combinedWeight
rel.Weight = scaledWeight
}
log.Infof("Weight calculation completed for %d relationships", len(b.relationshipMap))
}
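
// A numeric sketch of the PMI step above (made-up frequencies): with
// totalEntityOccurrences = 100, totalRelOccurrences = 20, sourceFreq = 10,
// targetFreq = 5, and relFreq = 4, we get P(x) = 0.1, P(y) = 0.05, and
// P(x,y) = 0.2, so PMI = log2(0.2 / (0.1 * 0.05)) = log2(40) ≈ 5.32. PMI is
// then normalized by the maximum PMI across all relationships, blended with
// normalized strength, and scaled into the 1-10 range.
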
// calculateDegrees calculates entity degrees
// Degree represents the number of connections an entity has with other entities, a key metric in graph structures
func (b *graphBuilder) calculateDegrees(ctx context.Context) {
log := logger.GetLogger(ctx)
log.Info("Calculating entity degrees")
// Calculate in-degree and out-degree for each entity
inDegree := make(map[string]int)
outDegree := make(map[string]int)
for _, rel := range b.relationshipMap {
outDegree[rel.Source]++
inDegree[rel.Target]++
}
// Set degree for each entity
for _, entity := range b.entityMap {
entity.Degree = inDegree[entity.Title] + outDegree[entity.Title]
}
// Set combined degree for relationships
for _, rel := range b.relationshipMap {
sourceEntity := b.getEntityByTitle(rel.Source)
targetEntity := b.getEntityByTitle(rel.Target)
if sourceEntity != nil && targetEntity != nil {
rel.CombinedDegree = sourceEntity.Degree + targetEntity.Degree
}
}
log.Info("Entity degree calculation completed")
}
// buildChunkGraph builds relationship graph between Chunks
// It creates a network of relationships between document chunks based on entity relationships
func (b *graphBuilder) buildChunkGraph(ctx context.Context) {
log := logger.GetLogger(ctx)
log.Info("Building chunk relationship graph")
// Create document chunk relationship graph based on entity relationships
for _, rel := range b.relationshipMap {
// Ensure source and target entities exist for the relationship
sourceEntity := b.entityMapByTitle[rel.Source]
targetEntity := b.entityMapByTitle[rel.Target]
if sourceEntity == nil || targetEntity == nil {
log.Warnf("Missing entity for relationship %s -> %s", rel.Source, rel.Target)
continue
}
// Build Chunk graph - connect all related document chunks
for _, sourceChunkID := range sourceEntity.ChunkIDs {
if _, exists := b.chunkGraph[sourceChunkID]; !exists {
b.chunkGraph[sourceChunkID] = make(map[string]*ChunkRelation)
}
for _, targetChunkID := range targetEntity.ChunkIDs {
if sourceChunkID == targetChunkID {
continue // Skip self-loops when both entities appear in the same chunk
}
if _, exists := b.chunkGraph[targetChunkID]; !exists {
b.chunkGraph[targetChunkID] = make(map[string]*ChunkRelation)
}
relation := &ChunkRelation{
Weight: rel.Weight,
Degree: rel.CombinedDegree,
}
b.chunkGraph[sourceChunkID][targetChunkID] = relation
b.chunkGraph[targetChunkID][sourceChunkID] = relation
}
}
}
log.Infof("Chunk graph built with %d nodes", len(b.chunkGraph))
}
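
// Sketch of the resulting structure (illustrative IDs): if entity "PMI"
// appears in chunks {c1, c2} and entity "Graph" appears in chunk {c3}, a
// PMI -> Graph relationship produces symmetric edges c1<->c3 and c2<->c3 in
// chunkGraph, each carrying the relationship's Weight and CombinedDegree.
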
// GetAllEntities returns all entities
func (b *graphBuilder) GetAllEntities() []*types.Entity {
b.mutex.RLock()
defer b.mutex.RUnlock()
entities := make([]*types.Entity, 0, len(b.entityMap))
for _, entity := range b.entityMap {
entities = append(entities, entity)
}
return entities
}
// GetAllRelationships returns all relationships
func (b *graphBuilder) GetAllRelationships() []*types.Relationship {
b.mutex.RLock()
defer b.mutex.RUnlock()
relationships := make([]*types.Relationship, 0, len(b.relationshipMap))
for _, relationship := range b.relationshipMap {
relationships = append(relationships, relationship)
}
return relationships
}
// GetRelationChunks retrieves document chunks directly related to the given chunkID
// It returns a list of related document chunk IDs sorted by weight and degree
func (b *graphBuilder) GetRelationChunks(chunkID string, topK int) []string {
b.mutex.RLock()
defer b.mutex.RUnlock()
log := logger.GetLogger(context.Background())
log.Debugf("Getting related chunks for %s (topK=%d)", chunkID, topK)
// Create weighted chunk structure for sorting
type weightedChunk struct {
id string
weight float64
degree int
}
// Collect related chunks with their weights and degrees
weightedChunks := make([]weightedChunk, 0)
for relationChunkID, relation := range b.chunkGraph[chunkID] {
weightedChunks = append(weightedChunks, weightedChunk{
id: relationChunkID,
weight: relation.Weight,
degree: relation.Degree,
})
}
// Sort by weight and degree in descending order
slices.SortFunc(weightedChunks, func(a, b weightedChunk) int {
// Sort by weight first
if a.weight > b.weight {
return -1 // Descending order
} else if a.weight < b.weight {
return 1
}
// If weights are equal, sort by degree
if a.degree > b.degree {
return -1 // Descending order
} else if a.degree < b.degree {
return 1
}
return 0
})
// Take top K results
resultCount := len(weightedChunks)
if topK > 0 && topK < resultCount {
resultCount = topK
}
// Extract chunk IDs
chunks := make([]string, 0, resultCount)
for i := 0; i < resultCount; i++ {
chunks = append(chunks, weightedChunks[i].id)
}
log.Debugf("Found %d related chunks for %s (limited to %d)",
len(weightedChunks), chunkID, resultCount)
return chunks
}
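
// Example (hypothetical ID): GetRelationChunks("chunk-42", 3) returns at most
// the three neighboring chunk IDs with the highest edge weight, ties broken by
// higher degree; passing topK <= 0 returns all neighbors.
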
// GetIndirectRelationChunks retrieves document chunks indirectly related to the given chunkID
// It returns document chunk IDs found through second-degree connections
func (b *graphBuilder) GetIndirectRelationChunks(chunkID string, topK int) []string {
b.mutex.RLock()
defer b.mutex.RUnlock()
log := logger.GetLogger(context.Background())
log.Debugf("Getting indirectly related chunks for %s (topK=%d)", chunkID, topK)
// Create weighted chunk structure for sorting
type weightedChunk struct {
id string
weight float64
degree int
}
// Get directly related chunks (first-degree connections)
directChunks := make(map[string]struct{})
directChunks[chunkID] = struct{}{} // Add original chunkID
for directChunkID := range b.chunkGraph[chunkID] {
directChunks[directChunkID] = struct{}{}
}
log.Debugf("Found %d directly related chunks to exclude", len(directChunks))
// Use map to deduplicate and store second-degree connections
indirectChunkMap := make(map[string]*ChunkRelation)
// Get first-degree connections
for directChunkID, directRelation := range b.chunkGraph[chunkID] {
// Get second-degree connections
for indirectChunkID, indirectRelation := range b.chunkGraph[directChunkID] {
// Skip self and all direct connections
if _, isDirect := directChunks[indirectChunkID]; isDirect {
continue
}
// Weight decay: second-degree relationship weight is the product of two direct relationship weights
// multiplied by decay coefficient
combinedWeight := directRelation.Weight * indirectRelation.Weight * IndirectRelationWeightDecay
// Degree calculation: take the maximum degree from the two path segments
combinedDegree := max(directRelation.Degree, indirectRelation.Degree)
// If already exists, take the higher weight
if existingRel, exists := indirectChunkMap[indirectChunkID]; !exists ||
combinedWeight > existingRel.Weight {
indirectChunkMap[indirectChunkID] = &ChunkRelation{
Weight: combinedWeight,
Degree: combinedDegree,
}
}
}
}
// Convert to sortable slice
weightedChunks := make([]weightedChunk, 0, len(indirectChunkMap))
for id, relation := range indirectChunkMap {
weightedChunks = append(weightedChunks, weightedChunk{
id: id,
weight: relation.Weight,
degree: relation.Degree,
})
}
// Sort by weight and degree in descending order
slices.SortFunc(weightedChunks, func(a, b weightedChunk) int {
// Sort by weight first
if a.weight > b.weight {
return -1 // Descending order
} else if a.weight < b.weight {
return 1
}
// If weights are equal, sort by degree
if a.degree > b.degree {
return -1 // Descending order
} else if a.degree < b.degree {
return 1
}
return 0
})
// Take top K results
resultCount := len(weightedChunks)
if topK > 0 && topK < resultCount {
resultCount = topK
}
// Extract chunk IDs
chunks := make([]string, 0, resultCount)
for i := 0; i < resultCount; i++ {
chunks = append(chunks, weightedChunks[i].id)
}
log.Debugf("Found %d indirect related chunks for %s (limited to %d)",
len(weightedChunks), chunkID, resultCount)
return chunks
}
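
// A small example of the decay rule above (illustrative weights): if chunk A
// links to chunk B with weight 6.0 and B links to chunk C with weight 4.0,
// then C's indirect weight from A is 6.0 * 4.0 * 0.5 = 12.0, and its degree is
// max(degree(A-B), degree(B-C)). When several two-hop paths reach the same
// chunk, only the one with the highest combined weight is kept.
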
// getEntityByTitle retrieves an entity by its title
func (b *graphBuilder) getEntityByTitle(title string) *types.Entity {
return b.entityMapByTitle[title]
}
// dfs performs a depth-first search to collect one connected component, treating relationships as undirected edges
func dfs(entityTitle string,
adjacencyList map[string]map[string]*types.Relationship,
visited map[string]bool, component *[]string) {
visited[entityTitle] = true
*component = append(*component, entityTitle)
// traverse all relationships of the current entity
for targetEntity := range adjacencyList[entityTitle] {
if !visited[targetEntity] {
dfs(targetEntity, adjacencyList, visited, component)
}
}
// check reverse relationships (check if other entities point to the current entity)
for source, targets := range adjacencyList {
for target := range targets {
if target == entityTitle && !visited[source] {
dfs(source, adjacencyList, visited, component)
}
}
}
}
// generateKnowledgeGraphDiagram generates a Mermaid diagram for the knowledge graph
func (b *graphBuilder) generateKnowledgeGraphDiagram(ctx context.Context) string {
log := logger.GetLogger(ctx)
log.Info("Generating knowledge graph visualization diagram...")
var sb strings.Builder
// Mermaid diagram header
sb.WriteString("```mermaid\ngraph TD\n")
sb.WriteString(" %% entity style definition\n")
sb.WriteString(" classDef entity fill:#f9f,stroke:#333,stroke-width:1px;\n")
sb.WriteString(" classDef highFreq fill:#bbf,stroke:#333,stroke-width:2px;\n\n")
// get all entities and sort by frequency
entities := b.GetAllEntities()
slices.SortFunc(entities, func(a, b *types.Entity) int {
if a.Frequency > b.Frequency {
return -1
} else if a.Frequency < b.Frequency {
return 1
}
return 0
})
// get relationships and sort by weight
relationships := b.GetAllRelationships()
slices.SortFunc(relationships, func(a, b *types.Relationship) int {
if a.Weight > b.Weight {
return -1
} else if a.Weight < b.Weight {
return 1
}
return 0
})
// create entity ID mapping
entityMap := make(map[string]string) // store entity title to node ID mapping
for i, entity := range entities {
nodeID := fmt.Sprintf("E%d", i)
entityMap[entity.Title] = nodeID
}
// create adjacency list to represent graph structure
adjacencyList := make(map[string]map[string]*types.Relationship)
for _, entity := range entities {
adjacencyList[entity.Title] = make(map[string]*types.Relationship)
}
// fill adjacency list
for _, rel := range relationships {
if _, sourceExists := entityMap[rel.Source]; sourceExists {
if _, targetExists := entityMap[rel.Target]; targetExists {
adjacencyList[rel.Source][rel.Target] = rel
}
}
}
// use DFS to find connected components (subgraphs)
visited := make(map[string]bool)
subgraphs := make([][]string, 0) // store entity titles in each subgraph
for _, entity := range entities {
if !visited[entity.Title] {
component := make([]string, 0)
dfs(entity.Title, adjacencyList, visited, &component)
if len(component) > 0 {
subgraphs = append(subgraphs, component)
}
}
}
// generate Mermaid subgraphs
subgraphCount := 0
for _, component := range subgraphs {
// check if this component has relationships
hasRelations := false
nodeCount := len(component)
// if there is only 1 node, check if it has relationships
if nodeCount == 1 {
entityTitle := component[0]
// check if this entity appears as source or target in any relationship
for _, rel := range relationships {
if rel.Source == entityTitle || rel.Target == entityTitle {
hasRelations = true
break
}
}
// if there is only 1 node and no relationships, skip this subgraph
if !hasRelations {
continue
}
} else if nodeCount > 1 {
// a subgraph with more than 1 node must have relationships
hasRelations = true
}
// only draw if there are multiple entities or at least one relationship in the subgraph
if hasRelations {
subgraphCount++
sb.WriteString(fmt.Sprintf("\n subgraph Subgraph%d\n", subgraphCount))
// add all entities in this subgraph
entitiesInComponent := make(map[string]bool)
for _, entityTitle := range component {
nodeID := entityMap[entityTitle]
entitiesInComponent[entityTitle] = true
// add node definition for each entity
entity := b.entityMapByTitle[entityTitle]
if entity != nil {
sb.WriteString(fmt.Sprintf(" %s[\"%s\"]\n", nodeID, entityTitle))
}
}
// add relationships in this subgraph
for _, rel := range relationships {
if entitiesInComponent[rel.Source] && entitiesInComponent[rel.Target] {
sourceID := entityMap[rel.Source]
targetID := entityMap[rel.Target]
linkStyle := "-->"
// adjust link style based on relationship strength
if rel.Strength > 7 {
linkStyle = "==>"
}
sb.WriteString(fmt.Sprintf(" %s %s|%s| %s\n",
sourceID, linkStyle, rel.Description, targetID))
}
}
// subgraph ends
sb.WriteString(" end\n")
// apply style class
for _, entityTitle := range component {
nodeID := entityMap[entityTitle]
entity := b.entityMapByTitle[entityTitle]
if entity != nil {
if entity.Frequency > 5 {
sb.WriteString(fmt.Sprintf(" class %s highFreq;\n", nodeID))
} else {
sb.WriteString(fmt.Sprintf(" class %s entity;\n", nodeID))
}
}
}
}
}
// close Mermaid diagram
sb.WriteString("```\n")
log.Infof("Knowledge graph visualization diagram generated with %d subgraphs", subgraphCount)
return sb.String()
}