ai-courseware/eino-project/internal/domain/vector/vector.go

210 lines
4.9 KiB
Go

package vector
import (
"context"
"fmt"
"log"
"math"
"sort"
"strings"
"sync"
)
// Config holds the vector database configuration. Each field carries a yaml
// tag naming the key it is mapped to.
type Config struct {
// Endpoint is mapped to the "endpoint" key. It is not read by any code
// visible in this file — presumably used by a real backend; confirm.
Endpoint string `yaml:"endpoint"`
// Collection is mapped to the "collection" key. The in-memory service only
// uses it in log messages.
Collection string `yaml:"collection"`
// Timeout is mapped to the "timeout" key and kept as a raw string —
// presumably a duration such as "5s"; confirm against the config loader.
Timeout string `yaml:"timeout"`
}
// Document is the unit stored in and returned by the vector service.
type Document struct {
// ID identifies the document; the in-memory service uses it as the map key,
// so adding a document with an existing ID overwrites the previous one.
ID string `json:"id"`
// Title is carried along with the document; it is not used for scoring.
Title string `json:"title"`
// Content is the text that search queries are compared against.
Content string `json:"content"`
// Metadata holds arbitrary caller-supplied key/value pairs.
Metadata map[string]interface{} `json:"metadata"`
}
// SearchResult is a single hit returned by a similarity search.
type SearchResult struct {
// Document is a copy of the matched document.
Document Document `json:"document"`
// Distance is set to 1 - Score by the in-memory search (lower is closer).
Distance float64 `json:"distance"`
// Score is the similarity between the query and the document content
// (higher is more similar).
Score float64 `json:"score"`
}
// VectorService is the interface of the vector store: adding, searching,
// fetching and deleting documents, plus a health probe.
type VectorService interface {
// AddDocument stores a single document.
AddDocument(ctx context.Context, doc *Document) error
// AddDocuments stores a batch of documents.
AddDocuments(ctx context.Context, docs []*Document) error
// SearchSimilar returns at most limit results for query. The in-memory
// implementation orders them by descending score and treats limit <= 0 as 10.
SearchSimilar(ctx context.Context, query string, limit int) ([]*SearchResult, error)
// DeleteDocument removes the document with the given id. The in-memory
// implementation returns an error when the id is unknown.
DeleteDocument(ctx context.Context, id string) error
// GetDocument returns the document with the given id. The in-memory
// implementation returns an error when the id is unknown.
GetDocument(ctx context.Context, id string) (*Document, error)
// HealthCheck reports whether the service is usable.
HealthCheck(ctx context.Context) error
}
// vectorService is the VectorService implementation (in-memory version, for
// demonstration purposes).
type vectorService struct {
// documents holds the stored documents keyed by Document.ID.
documents map[string]Document
// config is the configuration passed to NewVectorService; only its
// Collection field is read by the methods visible in this file.
config *Config
// mutex is held by every method while it reads or writes documents.
mutex sync.RWMutex
}
// NewVectorService creates the in-memory VectorService implementation with an
// empty document index.
//
// config must be non-nil: its Collection name is logged here and again by
// batch inserts. A nil config is reported as an error instead of causing a
// nil-pointer panic.
func NewVectorService(config *Config) (VectorService, error) {
	if config == nil {
		return nil, fmt.Errorf("vector service config must not be nil")
	}

	vs := &vectorService{
		documents: make(map[string]Document),
		config:    config,
	}
	log.Printf("Initialized in-memory vector service for collection: %s", config.Collection)
	return vs, nil
}
// AddDocument stores a copy of doc in the in-memory index, keyed by doc.ID.
// An existing entry with the same ID is overwritten.
//
// ctx is accepted to satisfy the VectorService interface and is not consulted.
// A nil doc is rejected with an error rather than panicking on dereference.
func (vs *vectorService) AddDocument(ctx context.Context, doc *Document) error {
	if doc == nil {
		return fmt.Errorf("document must not be nil")
	}

	vs.mutex.Lock()
	defer vs.mutex.Unlock()

	// The struct is copied by value; its Metadata map is still shared with
	// the caller's document.
	vs.documents[doc.ID] = *doc
	log.Printf("Added document: %s", doc.ID)
	return nil
}
// AddDocuments stores a copy of every document in docs, keyed by its ID.
// Later entries overwrite earlier ones that share an ID, as well as any
// previously stored document with that ID.
//
// ctx is accepted to satisfy the VectorService interface and is not consulted.
// If any element of docs is nil an error naming its index is returned and
// nothing is stored, so a bad batch is never partially applied.
func (vs *vectorService) AddDocuments(ctx context.Context, docs []*Document) error {
	// Validate the whole batch up front: dereferencing a nil element inside
	// the insert loop would panic after earlier documents were already written.
	for i, doc := range docs {
		if doc == nil {
			return fmt.Errorf("document at index %d must not be nil", i)
		}
	}

	vs.mutex.Lock()
	defer vs.mutex.Unlock()

	for _, doc := range docs {
		vs.documents[doc.ID] = *doc
	}
	log.Printf("Successfully added %d documents to collection %s", len(docs), vs.config.Collection)
	return nil
}
// SearchSimilar scores every stored document against query using a simple
// word-overlap text similarity and returns the best matches.
//
// Both the query and each document's Content are lower-cased before scoring.
// Results are ordered by descending Score; documents with equal scores are
// ordered by ascending ID so that repeated calls return the same order and
// the same subset when the result is cut to limit. A limit <= 0 is treated
// as 10. ctx is accepted to satisfy the VectorService interface and is not
// consulted. The returned error is always nil.
func (vs *vectorService) SearchSimilar(ctx context.Context, query string, limit int) ([]*SearchResult, error) {
	vs.mutex.RLock()
	defer vs.mutex.RUnlock()

	if limit <= 0 {
		limit = 10
	}

	// Leave results nil when the index is empty; otherwise pre-size it, since
	// exactly one entry is produced per stored document.
	var results []*SearchResult
	if n := len(vs.documents); n > 0 {
		results = make([]*SearchResult, 0, n)
	}

	queryLower := strings.ToLower(query)

	// Score each document against the query.
	for _, doc := range vs.documents {
		similarity := calculateSimilarity(queryLower, strings.ToLower(doc.Content))
		results = append(results, &SearchResult{
			Document: doc,
			Distance: 1.0 - similarity,
			Score:    similarity,
		})
	}

	// Map iteration order is randomised and sort.Slice is not stable, so ties
	// must be broken explicitly to get a reproducible ranking.
	sort.Slice(results, func(i, j int) bool {
		if results[i].Score != results[j].Score {
			return results[i].Score > results[j].Score
		}
		return results[i].Document.ID < results[j].Document.ID
	})

	// Cap the number of results.
	if len(results) > limit {
		results = results[:limit]
	}

	log.Printf("Found %d similar documents for query", len(results))
	return results, nil
}
// DeleteDocument removes the document stored under id.
//
// It returns an error when no document with that id exists. ctx is accepted
// to satisfy the VectorService interface and is not consulted.
func (vs *vectorService) DeleteDocument(ctx context.Context, id string) error {
	vs.mutex.Lock()
	defer vs.mutex.Unlock()

	_, found := vs.documents[id]
	if found {
		delete(vs.documents, id)
		log.Printf("Successfully deleted document: %s", id)
		return nil
	}
	return fmt.Errorf("document not found: %s", id)
}
// GetDocument looks up the document stored under id.
//
// The returned pointer refers to a copy of the stored struct, not to the
// entry inside the index. An error is returned when the id is unknown. ctx is
// accepted to satisfy the VectorService interface and is not consulted.
func (vs *vectorService) GetDocument(ctx context.Context, id string) (*Document, error) {
	vs.mutex.RLock()
	defer vs.mutex.RUnlock()

	if found, ok := vs.documents[id]; ok {
		return &found, nil
	}
	return nil, fmt.Errorf("document not found: %s", id)
}
// HealthCheck logs the number of stored documents and always reports success.
//
// ctx is accepted to satisfy the VectorService interface and is not consulted.
func (vs *vectorService) HealthCheck(ctx context.Context) error {
	vs.mutex.RLock()
	defer vs.mutex.RUnlock()

	stored := len(vs.documents)
	log.Printf("Vector service health check passed. Documents count: %d", stored)
	return nil
}
// calculateSimilarity returns a similarity score in [0, 1] for two texts.
//
// Each text is split on whitespace into a set of distinct words and the
// Jaccard index (shared words / total distinct words) is computed. When one
// text is a substring of the other, the score is multiplied by 1.5 and capped
// at 1. If either text contains no words the result is 0. Comparison is
// case-sensitive; callers lower-case the inputs beforehand if needed.
func calculateSimilarity(text1, text2 string) float64 {
	tokensA := strings.Fields(text1)
	tokensB := strings.Fields(text2)
	if len(tokensA) == 0 || len(tokensB) == 0 {
		return 0.0
	}

	// Distinct words of the first text.
	setA := make(map[string]struct{}, len(tokensA))
	for _, tok := range tokensA {
		setA[tok] = struct{}{}
	}

	// Distinct words of the second text; count the overlap on first sight of
	// each word so duplicates are not counted twice.
	setB := make(map[string]struct{}, len(tokensB))
	shared := 0
	for _, tok := range tokensB {
		if _, dup := setB[tok]; dup {
			continue
		}
		setB[tok] = struct{}{}
		if _, inA := setA[tok]; inA {
			shared++
		}
	}

	// Both sets are non-empty here, so the union size is at least 1.
	total := len(setA) + len(setB) - shared
	score := float64(shared) / float64(total)

	// Boost the score when one text fully contains the other.
	if strings.Contains(text2, text1) || strings.Contains(text1, text2) {
		score = math.Min(1.0, score*1.5)
	}
	return score
}