210 lines
4.9 KiB
Go
210 lines
4.9 KiB
Go
package vector
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"log"
|
|
"math"
|
|
"sort"
|
|
"strings"
|
|
"sync"
|
|
)
|
|
|
|
// Config holds the settings for the vector database.
type Config struct {
	// Endpoint is loaded from the "endpoint" YAML key.
	Endpoint string `yaml:"endpoint"`

	// Collection is loaded from the "collection" YAML key. The in-memory
	// service in this file only uses it in log messages.
	Collection string `yaml:"collection"`

	// Timeout is loaded from the "timeout" YAML key and kept as a raw string.
	Timeout string `yaml:"timeout"`
}
|
|
|
|
// Document is the unit of content stored and searched by the vector service.
type Document struct {
	// ID is the key under which the service stores the document.
	ID string `json:"id"`

	// Title is serialized under the "title" JSON key.
	Title string `json:"title"`

	// Content is the text that search queries are compared against.
	Content string `json:"content"`

	// Metadata carries arbitrary key/value data, serialized under "metadata".
	Metadata map[string]interface{} `json:"metadata"`
}
|
|
|
|
// SearchResult 搜索结果
|
|
type SearchResult struct {
|
|
Document Document `json:"document"`
|
|
Distance float64 `json:"distance"`
|
|
Score float64 `json:"score"`
|
|
}
|
|
|
|
// VectorService 向量服务接口
|
|
type VectorService interface {
|
|
AddDocument(ctx context.Context, doc *Document) error
|
|
AddDocuments(ctx context.Context, docs []*Document) error
|
|
SearchSimilar(ctx context.Context, query string, limit int) ([]*SearchResult, error)
|
|
DeleteDocument(ctx context.Context, id string) error
|
|
GetDocument(ctx context.Context, id string) (*Document, error)
|
|
HealthCheck(ctx context.Context) error
|
|
}
|
|
|
|
// vectorService 向量服务实现(内存版本,用于演示)
|
|
type vectorService struct {
|
|
documents map[string]Document
|
|
config *Config
|
|
mutex sync.RWMutex
|
|
}
|
|
|
|
// NewVectorService 创建新的向量服务
|
|
func NewVectorService(config *Config) (VectorService, error) {
|
|
vs := &vectorService{
|
|
documents: make(map[string]Document),
|
|
config: config,
|
|
}
|
|
|
|
log.Printf("Initialized in-memory vector service for collection: %s", config.Collection)
|
|
return vs, nil
|
|
}
|
|
|
|
// AddDocument 添加文档
|
|
func (vs *vectorService) AddDocument(ctx context.Context, doc *Document) error {
|
|
vs.mutex.Lock()
|
|
defer vs.mutex.Unlock()
|
|
|
|
vs.documents[doc.ID] = *doc
|
|
log.Printf("Added document: %s", doc.ID)
|
|
return nil
|
|
}
|
|
|
|
// AddDocuments 批量添加文档
|
|
func (vs *vectorService) AddDocuments(ctx context.Context, docs []*Document) error {
|
|
vs.mutex.Lock()
|
|
defer vs.mutex.Unlock()
|
|
|
|
for _, doc := range docs {
|
|
vs.documents[doc.ID] = *doc
|
|
}
|
|
|
|
log.Printf("Successfully added %d documents to collection %s", len(docs), vs.config.Collection)
|
|
return nil
|
|
}
|
|
|
|
// SearchSimilar 搜索相似文档(基于简单的文本相似度)
|
|
func (vs *vectorService) SearchSimilar(ctx context.Context, query string, limit int) ([]*SearchResult, error) {
|
|
vs.mutex.RLock()
|
|
defer vs.mutex.RUnlock()
|
|
|
|
if limit <= 0 {
|
|
limit = 10
|
|
}
|
|
|
|
var results []*SearchResult
|
|
queryLower := strings.ToLower(query)
|
|
|
|
// 计算每个文档与查询的相似度
|
|
for _, doc := range vs.documents {
|
|
contentLower := strings.ToLower(doc.Content)
|
|
|
|
// 简单的相似度计算:基于共同词汇
|
|
similarity := calculateSimilarity(queryLower, contentLower)
|
|
distance := 1.0 - similarity
|
|
|
|
result := &SearchResult{
|
|
Document: doc,
|
|
Distance: distance,
|
|
Score: similarity,
|
|
}
|
|
results = append(results, result)
|
|
}
|
|
|
|
// 按相似度排序
|
|
sort.Slice(results, func(i, j int) bool {
|
|
return results[i].Score > results[j].Score
|
|
})
|
|
|
|
// 限制结果数量
|
|
if len(results) > limit {
|
|
results = results[:limit]
|
|
}
|
|
|
|
log.Printf("Found %d similar documents for query", len(results))
|
|
return results, nil
|
|
}
|
|
|
|
// DeleteDocument 删除文档
|
|
func (vs *vectorService) DeleteDocument(ctx context.Context, id string) error {
|
|
vs.mutex.Lock()
|
|
defer vs.mutex.Unlock()
|
|
|
|
if _, exists := vs.documents[id]; !exists {
|
|
return fmt.Errorf("document not found: %s", id)
|
|
}
|
|
|
|
delete(vs.documents, id)
|
|
log.Printf("Successfully deleted document: %s", id)
|
|
return nil
|
|
}
|
|
|
|
// GetDocument 获取文档
|
|
func (vs *vectorService) GetDocument(ctx context.Context, id string) (*Document, error) {
|
|
vs.mutex.RLock()
|
|
defer vs.mutex.RUnlock()
|
|
|
|
doc, exists := vs.documents[id]
|
|
if !exists {
|
|
return nil, fmt.Errorf("document not found: %s", id)
|
|
}
|
|
|
|
return &doc, nil
|
|
}
|
|
|
|
// HealthCheck 健康检查
|
|
func (vs *vectorService) HealthCheck(ctx context.Context) error {
|
|
vs.mutex.RLock()
|
|
defer vs.mutex.RUnlock()
|
|
|
|
log.Printf("Vector service health check passed. Documents count: %d", len(vs.documents))
|
|
return nil
|
|
}
|
|
|
|
// calculateSimilarity scores how alike two texts are, in the range [0, 1].
//
// Each text is split on whitespace into a set of distinct words, and the
// Jaccard index |A ∩ B| / |A ∪ B| of the two sets is computed. If either raw
// text contains the other as a substring, the score is multiplied by 1.5 and
// capped at 1.0. A text with no words yields 0.
//
// Words are compared exactly: no case folding or punctuation stripping
// happens here, so callers that want case-insensitive matching must
// normalize the inputs first.
func calculateSimilarity(text1, text2 string) float64 {
	words1 := strings.Fields(text1)
	words2 := strings.Fields(text2)
	if len(words1) == 0 || len(words2) == 0 {
		return 0.0
	}

	// Collapse each word list into a set so repeated words count once.
	set1 := make(map[string]struct{}, len(words1))
	for _, w := range words1 {
		set1[w] = struct{}{}
	}
	set2 := make(map[string]struct{}, len(words2))
	for _, w := range words2 {
		set2[w] = struct{}{}
	}

	intersection := 0
	for w := range set1 {
		if _, ok := set2[w]; ok {
			intersection++
		}
	}

	// Both sets are non-empty at this point, so the union is at least 1 and
	// the division below cannot be by zero.
	union := len(set1) + len(set2) - intersection
	similarity := float64(intersection) / float64(union)

	// Boost the score when one text appears verbatim inside the other.
	if strings.Contains(text2, text1) || strings.Contains(text1, text2) {
		similarity = math.Min(1.0, similarity*1.5)
	}

	return similarity
}
|