l_ai_knowledge/client/session.go

404 lines
13 KiB
Go

// Package client provides a client for interacting with the WeKnora API.
//
// The session-related types and methods in this file manage question-answering
// sessions: sessions can be created, retrieved, listed, updated, and deleted.
// The file also covers session title generation, streaming knowledge Q&A,
// and knowledge search.
package client
import (
"bufio"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
)
// SessionStrategy describes the per-session question-answering configuration:
// conversation length, query rewriting, fallback behaviour, retrieval and
// rerank settings, and the summary model settings. It is sent as the
// "session_strategy" field of CreateSessionRequest.
type SessionStrategy struct {
MaxRounds int `json:"max_rounds"` // Maximum number of rounds to maintain
EnableRewrite bool `json:"enable_rewrite"` // Enable query rewrite
FallbackStrategy string `json:"fallback_strategy"` // Fallback strategy
FallbackResponse string `json:"fallback_response"` // Fixed fallback response content
EmbeddingTopK int `json:"embedding_top_k"` // Top K for vector retrieval
KeywordThreshold float64 `json:"keyword_threshold"` // Keyword retrieval threshold
VectorThreshold float64 `json:"vector_threshold"` // Vector retrieval threshold
RerankModelID string `json:"rerank_model_id"` // Rerank model ID
RerankTopK int `json:"rerank_top_k"` // Top K for reranking
// NOTE: the JSON key is "reranking_threshold", unlike the other "rerank_*" keys.
RerankThreshold float64 `json:"reranking_threshold"` // Reranking threshold
SummaryModelID string `json:"summary_model_id"` // Summary model ID
SummaryParameters *SummaryConfig `json:"summary_parameters"` // Summary model parameters
NoMatchPrefix string `json:"no_match_prefix"` // Fallback response prefix
}
// SummaryConfig holds the parameters passed to the summary model. It is
// referenced by SessionStrategy.SummaryParameters and Session.SummaryParameters.
//
// The fields are plain values serialized under the JSON keys in their tags;
// how each one is interpreted (sampling, penalties, prompt templating) is
// decided by the server/model and is not visible in this file.
type SummaryConfig struct {
MaxTokens int `json:"max_tokens"`
TopP float64 `json:"top_p"`
TopK int `json:"top_k"`
FrequencyPenalty float64 `json:"frequency_penalty"`
PresencePenalty float64 `json:"presence_penalty"`
RepeatPenalty float64 `json:"repeat_penalty"`
Prompt string `json:"prompt"`
ContextTemplate string `json:"context_template"`
// NoMatchPrefix uses the same JSON key as SessionStrategy.NoMatchPrefix.
NoMatchPrefix string `json:"no_match_prefix"`
Temperature float64 `json:"temperature"`
Seed int `json:"seed"`
MaxCompletionTokens int `json:"max_completion_tokens"`
}
// CreateSessionRequest is the JSON request body used by both CreateSession
// and UpdateSession.
type CreateSessionRequest struct {
KnowledgeBaseID string `json:"knowledge_base_id"` // Associated knowledge base ID
SessionStrategy *SessionStrategy `json:"session_strategy"` // Session strategy
}
// Session is a session record as returned by the API. It is the payload of
// SessionResponse and the element type of SessionListResponse.Data.
//
// Most configuration fields mirror those of SessionStrategy, but here they
// are flattened onto the session itself rather than nested.
type Session struct {
ID string `json:"id"`
TenantID uint `json:"tenant_id"`
KnowledgeBaseID string `json:"knowledge_base_id"`
Title string `json:"title"`
MaxRounds int `json:"max_rounds"`
EnableRewrite bool `json:"enable_rewrite"`
FallbackStrategy string `json:"fallback_strategy"`
FallbackResponse string `json:"fallback_response"`
EmbeddingTopK int `json:"embedding_top_k"`
KeywordThreshold float64 `json:"keyword_threshold"`
VectorThreshold float64 `json:"vector_threshold"`
RerankModelID string `json:"rerank_model_id"`
RerankTopK int `json:"rerank_top_k"`
RerankThreshold float64 `json:"reranking_threshold"` // Reranking threshold
SummaryModelID string `json:"summary_model_id"`
SummaryParameters *SummaryConfig `json:"summary_parameters"`
// CreatedAt and UpdatedAt are kept as the raw strings sent by the server;
// this file does not parse them (timestamp format not specified here).
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
}
// SessionResponse is the response envelope for single-session endpoints.
// CreateSession, GetSession, and UpdateSession decode into it and return
// a pointer to its Data field.
type SessionResponse struct {
Success bool `json:"success"`
Data Session `json:"data"`
}
// SessionListResponse is the response envelope for the paginated session
// list endpoint. GetSessionsByTenant decodes into it and returns Data and
// Total; Page and PageSize are decoded but not returned to the caller.
type SessionListResponse struct {
Success bool `json:"success"`
Data []Session `json:"data"`
Total int `json:"total"`
Page int `json:"page"`
PageSize int `json:"page_size"`
}
// CreateSession creates a new session by POSTing request to /api/v1/sessions.
//
// It returns the session decoded from the "data" field of the response, or
// the error reported while sending the request or parsing the response.
func (c *Client) CreateSession(ctx context.Context, request *CreateSessionRequest) (*Session, error) {
	httpResp, err := c.doRequest(ctx, http.MethodPost, "/api/v1/sessions", request, nil)
	if err != nil {
		return nil, err
	}

	envelope := new(SessionResponse)
	if parseErr := parseResponse(httpResp, envelope); parseErr != nil {
		return nil, parseErr
	}
	return &envelope.Data, nil
}
// GetSession fetches the session identified by sessionID with a GET to
// /api/v1/sessions/{sessionID}.
//
// sessionID is path-escaped before being placed in the URL so that characters
// such as '/', '?', '#' or '%' cannot alter the request path or query.
//
// It returns the session decoded from the "data" field of the response, or
// the error reported while sending the request or parsing the response.
func (c *Client) GetSession(ctx context.Context, sessionID string) (*Session, error) {
	path := fmt.Sprintf("/api/v1/sessions/%s", url.PathEscape(sessionID))
	resp, err := c.doRequest(ctx, http.MethodGet, path, nil, nil)
	if err != nil {
		return nil, err
	}

	var response SessionResponse
	if err := parseResponse(resp, &response); err != nil {
		return nil, err
	}
	return &response.Data, nil
}
// GetSessionsByTenant lists sessions one page at a time with a GET to
// /api/v1/sessions, passing page and pageSize as the "page" and "page_size"
// query parameters.
//
// It returns the sessions on the requested page together with the "total"
// value reported by the server, or an error from the request or from
// response parsing.
func (c *Client) GetSessionsByTenant(ctx context.Context, page int, pageSize int) ([]Session, int, error) {
	query := url.Values{
		"page":      []string{strconv.Itoa(page)},
		"page_size": []string{strconv.Itoa(pageSize)},
	}

	httpResp, err := c.doRequest(ctx, http.MethodGet, "/api/v1/sessions", nil, query)
	if err != nil {
		return nil, 0, err
	}

	listing := new(SessionListResponse)
	if parseErr := parseResponse(httpResp, listing); parseErr != nil {
		return nil, 0, parseErr
	}
	return listing.Data, listing.Total, nil
}
// UpdateSession updates the session identified by sessionID by sending
// request with a PUT to /api/v1/sessions/{sessionID}.
//
// sessionID is path-escaped before being placed in the URL so that characters
// such as '/', '?', '#' or '%' cannot alter the request path or query.
//
// It returns the session decoded from the "data" field of the response, or
// the error reported while sending the request or parsing the response.
func (c *Client) UpdateSession(ctx context.Context, sessionID string, request *CreateSessionRequest) (*Session, error) {
	path := fmt.Sprintf("/api/v1/sessions/%s", url.PathEscape(sessionID))
	resp, err := c.doRequest(ctx, http.MethodPut, path, request, nil)
	if err != nil {
		return nil, err
	}

	var response SessionResponse
	if err := parseResponse(resp, &response); err != nil {
		return nil, err
	}
	return &response.Data, nil
}
// DeleteSession deletes the session identified by sessionID with a DELETE to
// /api/v1/sessions/{sessionID}.
//
// sessionID is path-escaped before being placed in the URL so that characters
// such as '/', '?', '#' or '%' cannot alter the request path or query.
//
// The response body is decoded into a success/message envelope only so that
// parseResponse can report failures; the decoded values are discarded.
func (c *Client) DeleteSession(ctx context.Context, sessionID string) error {
	path := fmt.Sprintf("/api/v1/sessions/%s", url.PathEscape(sessionID))
	resp, err := c.doRequest(ctx, http.MethodDelete, path, nil, nil)
	if err != nil {
		return err
	}

	var response struct {
		Success bool   `json:"success"`
		Message string `json:"message,omitempty"`
	}
	return parseResponse(resp, &response)
}
// GenerateTitleRequest is the JSON request body sent by GenerateTitle.
// Message is declared elsewhere in this package.
type GenerateTitleRequest struct {
Messages []Message `json:"messages"`
}
// GenerateTitleResponse is the response envelope decoded by GenerateTitle,
// which returns its Data field as the generated title.
type GenerateTitleResponse struct {
Success bool `json:"success"`
Data string `json:"data"`
}
// GenerateTitle asks the server to generate a title for the session
// identified by sessionID by POSTing request to
// /api/v1/sessions/{sessionID}/generate_title.
//
// sessionID is path-escaped before being placed in the URL so that characters
// such as '/', '?', '#' or '%' cannot alter the request path or query.
//
// It returns the string found in the "data" field of the response, or the
// error reported while sending the request or parsing the response.
func (c *Client) GenerateTitle(ctx context.Context, sessionID string, request *GenerateTitleRequest) (string, error) {
	path := fmt.Sprintf("/api/v1/sessions/%s/generate_title", url.PathEscape(sessionID))
	resp, err := c.doRequest(ctx, http.MethodPost, path, request, nil)
	if err != nil {
		return "", err
	}

	var response GenerateTitleResponse
	if err := parseResponse(resp, &response); err != nil {
		return "", err
	}
	return response.Data, nil
}
// KnowledgeQARequest is the JSON request body sent by KnowledgeQAStream;
// it carries only the user's query text.
type KnowledgeQARequest struct {
Query string `json:"query"`
}
// ResponseType names the kind of payload carried by a StreamResponse; it is
// decoded from the "response_type" JSON field.
type ResponseType string
// Response type values declared by this client.
const (
// ResponseTypeAnswer is the "answer" response type
// (presumably marks answer text fragments — confirm against the server).
ResponseTypeAnswer ResponseType = "answer"
// ResponseTypeReferences is the "references" response type
// (presumably marks events carrying knowledge references — confirm against the server).
ResponseTypeReferences ResponseType = "references"
)
// StreamResponse is one decoded event from a streaming endpoint. Each SSE
// "data:" payload read by KnowledgeQAStream and ContinueStream is unmarshaled
// into a StreamResponse and handed to the caller's callback.
// SearchResult is declared elsewhere in this package.
type StreamResponse struct {
ID string `json:"id"` // Unique identifier
ResponseType ResponseType `json:"response_type"` // Response type
Content string `json:"content"` // Current content fragment
Done bool `json:"done"` // Whether completed
KnowledgeReferences []*SearchResult `json:"knowledge_references"` // Knowledge references
}
// KnowledgeQAStream sends query to the knowledge Q&A endpoint
// (POST /api/v1/knowledge-chat/{sessionID}) and consumes the Server-Sent
// Events response, invoking callback once for every event that carries a
// non-empty data payload, regardless of its event type.
//
// sessionID is path-escaped before being placed in the URL. The function
// returns when the stream ends, when callback returns an error (that error is
// returned unchanged), when an event payload is not valid JSON, or when
// reading the stream fails. A non-2xx status is reported as an error that
// includes the response body.
//
// The function does not write to stdout; errors are reported only through
// the return value.
func (c *Client) KnowledgeQAStream(ctx context.Context, sessionID string, query string, callback func(*StreamResponse) error) error {
	path := fmt.Sprintf("/api/v1/knowledge-chat/%s", url.PathEscape(sessionID))
	request := &KnowledgeQARequest{
		Query: query,
	}

	resp, err := c.doRequest(ctx, http.MethodPost, path, request, nil)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		// Best effort: the body only enriches the error message.
		body, _ := io.ReadAll(resp.Body)
		return fmt.Errorf("HTTP error %d: %s", resp.StatusCode, string(body))
	}

	// An empty event filter means every event type is delivered.
	return consumeSessionEventStream(resp.Body, "", callback)
}

// ContinueStream continues to receive an active stream for a session
// (GET /api/v1/sessions/continue-stream/{sessionID}?message_id=...),
// invoking callback once for every event whose type is "message" and that
// carries a non-empty data payload. Events of any other type are skipped.
//
// sessionID is path-escaped before being placed in the URL. The function
// returns when the stream ends, when callback returns an error (that error is
// returned unchanged), when an event payload is not valid JSON, or when
// reading the stream fails. A non-2xx status is reported as an error that
// includes the response body.
func (c *Client) ContinueStream(ctx context.Context, sessionID string, messageID string, callback func(*StreamResponse) error) error {
	path := fmt.Sprintf("/api/v1/sessions/continue-stream/%s", url.PathEscape(sessionID))
	queryParams := url.Values{}
	queryParams.Add("message_id", messageID)

	resp, err := c.doRequest(ctx, http.MethodGet, path, nil, queryParams)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		// Best effort: the body only enriches the error message.
		body, _ := io.ReadAll(resp.Body)
		return fmt.Errorf("HTTP error %d: %s", resp.StatusCode, string(body))
	}

	return consumeSessionEventStream(resp.Body, "message", callback)
}

// consumeSessionEventStream reads a Server-Sent Events stream from body and
// calls callback with each event's data decoded as a StreamResponse.
//
// When wantEvent is non-empty, only events whose "event:" field equals
// wantEvent are delivered; when it is empty, all events are delivered.
// Events without a data payload are always skipped.
//
// Parsing rules:
//   - a blank line terminates the current event;
//   - several "data:" lines in one event are joined with "\n";
//   - a single space directly after "event:" or "data:" is dropped, so
//     "event: message" and "event:message" are treated the same;
//   - any other line (comments, "id:", "retry:") is ignored;
//   - an event left unterminated at end of stream is discarded.
func consumeSessionEventStream(body io.Reader, wantEvent string, callback func(*StreamResponse) error) error {
	const (
		initialLineBytes = 64 * 1024
		// bufio.Scanner rejects lines above 64 KiB by default, which a single
		// JSON payload with many knowledge references can exceed. Allow lines
		// of up to 8 MiB before failing with bufio.ErrTooLong.
		maxLineBytes = 8 * 1024 * 1024
	)
	scanner := bufio.NewScanner(body)
	scanner.Buffer(make([]byte, 0, initialLineBytes), maxLineBytes)

	var (
		dataLines []string
		eventType string
	)
	for scanner.Scan() {
		line := scanner.Text()
		switch {
		case line == "":
			// End of event: capture and reset the per-event state first, so a
			// skipped event can never leak its data or type into the next one.
			payload := strings.Join(dataLines, "\n")
			finished := eventType
			dataLines = dataLines[:0]
			eventType = ""

			if payload == "" || (wantEvent != "" && finished != wantEvent) {
				continue
			}

			var streamResponse StreamResponse
			if err := json.Unmarshal([]byte(payload), &streamResponse); err != nil {
				return fmt.Errorf("failed to parse SSE data: %w", err)
			}
			if err := callback(&streamResponse); err != nil {
				return err
			}
		case strings.HasPrefix(line, "event:"):
			eventType = strings.TrimPrefix(strings.TrimPrefix(line, "event:"), " ")
		case strings.HasPrefix(line, "data:"):
			dataLines = append(dataLines, strings.TrimPrefix(strings.TrimPrefix(line, "data:"), " "))
		}
	}
	if err := scanner.Err(); err != nil {
		return fmt.Errorf("failed to read SSE stream: %w", err)
	}
	return nil
}
// SearchKnowledgeRequest is the JSON request body sent by SearchKnowledge.
type SearchKnowledgeRequest struct {
Query string `json:"query"` // Query content
KnowledgeBaseID string `json:"knowledge_base_id"` // Knowledge base ID
}
// SearchKnowledgeResponse is the response envelope decoded by SearchKnowledge,
// which returns its Data field. SearchResult is declared elsewhere in this
// package.
type SearchKnowledgeResponse struct {
Success bool `json:"success"`
Data []*SearchResult `json:"data"`
}
// SearchKnowledge performs knowledge base search without LLM summarization
// by POSTing request to /api/v1/knowledge-search.
//
// It returns the results found in the "data" field of the response. A nil
// request is rejected with an error before any network call is made. A
// non-2xx status is reported as an error that includes the response body.
//
// The function does not write to stdout; errors are reported only through
// the return value.
func (c *Client) SearchKnowledge(ctx context.Context, request *SearchKnowledgeRequest) ([]*SearchResult, error) {
	if request == nil {
		return nil, fmt.Errorf("search knowledge request must not be nil")
	}

	resp, err := c.doRequest(ctx, http.MethodPost, "/api/v1/knowledge-search", request, nil)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		// Best effort: the body only enriches the error message.
		body, _ := io.ReadAll(resp.Body)
		return nil, fmt.Errorf("HTTP error %d: %s", resp.StatusCode, string(body))
	}

	var response SearchKnowledgeResponse
	if err := parseResponse(resp, &response); err != nil {
		return nil, err
	}
	return response.Data, nil
}