252 lines
		
	
	
		
			6.4 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			252 lines
		
	
	
		
			6.4 KiB
		
	
	
	
		
			Go
		
	
	
	
package tools
 | 
						||
 | 
						||
import (
 | 
						||
	"ai_scheduler/internal/config"
 | 
						||
	"ai_scheduler/internal/entitys"
 | 
						||
	"ai_scheduler/internal/pkg/l_request"
 | 
						||
	"bufio"
 | 
						||
	"context"
 | 
						||
	"encoding/json"
 | 
						||
	"fmt"
 | 
						||
	"net/http"
 | 
						||
	"strings"
 | 
						||
)
 | 
						||
 | 
						||
// 知识库工具
 | 
						||
type KnowledgeBaseTool struct {
 | 
						||
	config config.ToolConfig
 | 
						||
}
 | 
						||
 | 
						||
// NewKnowledgeBaseTool 创建知识库工具
 | 
						||
func NewKnowledgeBaseTool(config config.ToolConfig) *KnowledgeBaseTool {
 | 
						||
	return &KnowledgeBaseTool{config: config}
 | 
						||
}
 | 
						||
 | 
						||
func (k *KnowledgeBaseTool) GetConfig() config.ToolConfig {
 | 
						||
	return k.config
 | 
						||
}
 | 
						||
 | 
						||
// Name 返回工具名称
 | 
						||
func (k *KnowledgeBaseTool) Name() string {
 | 
						||
	return "knowledgeBase"
 | 
						||
}
 | 
						||
 | 
						||
// Description 返回工具描述
 | 
						||
func (k *KnowledgeBaseTool) Description() string {
 | 
						||
	return "请求知识库"
 | 
						||
}
 | 
						||
 | 
						||
// Definition 返回工具定义
 | 
						||
func (k *KnowledgeBaseTool) Definition() entitys.ToolDefinition {
 | 
						||
	return entitys.ToolDefinition{
 | 
						||
		Type: "function",
 | 
						||
		Function: entitys.FunctionDef{
 | 
						||
			Name:        k.Name(),
 | 
						||
			Description: k.Description(),
 | 
						||
			Parameters: map[string]interface{}{
 | 
						||
				"type": "object",
 | 
						||
				"properties": map[string]interface{}{
 | 
						||
					"query": map[string]interface{}{
 | 
						||
						"type":        "string",
 | 
						||
						"description": "知识库查询条件",
 | 
						||
					},
 | 
						||
				},
 | 
						||
				"required": []string{"query"},
 | 
						||
			},
 | 
						||
		},
 | 
						||
	}
 | 
						||
}
 | 
						||
 | 
						||
// Execute runs a knowledge-base chat for requireData and streams the reply
// into requireData.Ch.
//
// ctx is checked once, before any network work starts, so a request whose
// context is already cancelled or expired is never sent. chat does not
// receive ctx, so cancellation after that point is not observed. A nil ctx
// skips the check rather than panicking.
//
// It returns ctx.Err() for a finished context, an error for a nil
// requireData, or whatever chat returns.
func (k *KnowledgeBaseTool) Execute(ctx context.Context, requireData *entitys.RequireData) error {
	if ctx != nil {
		if err := ctx.Err(); err != nil {
			return err
		}
	}
	if requireData == nil {
		// chat dereferences requireData immediately; fail cleanly instead of panicking.
		return fmt.Errorf("knowledgeBase: nil requireData")
	}
	return k.chat(requireData)
}
 | 
						||
 | 
						||
// Message is one server-sent event (SSE) message as assembled from the
// response stream.
type Message struct {
	// Event is the event type; the SSE spec treats an absent type as "message".
	Event string
	// Data is the message payload, which may span several "data" lines.
	Data string
	// ID is the optional event ID.
	ID string
}
 | 
						||
 | 
						||
// MsgContent is the JSON payload carried in the data field of each SSE event
// from the knowledge-chat endpoint. Field names follow the JSON tags.
type MsgContent struct {
	Id           string `json:"id"`
	ResponseType string `json:"response_type"`
	// Content is the text chunk that msgContentParse forwards to the stream.
	Content string `json:"content"`
	// Done presumably marks the final chunk; NOTE(review): nothing in this
	// file reads it — confirm intended use.
	Done bool `json:"done"`
	// KnowledgeReferences holds the decoded "knowledge_references" value
	// without assuming its shape.
	KnowledgeReferences interface{} `json:"knowledge_references"`
}
 | 
						||
 | 
						||
// msgContentParse decodes one SSE data payload (input) into a MsgContent.
// On success it forwards the decoded Content to channel as a stream response
// tagged with the tool name.
//
// If input is not valid JSON, it returns a zero MsgContent and a wrapped
// error, and nothing is sent to channel.
//
// NOTE(review): the send on channel blocks until a receiver takes it; there
// is no cancellation path here.
func (k *KnowledgeBaseTool) msgContentParse(input string, channel chan entitys.Response) (msgContent MsgContent, err error) {
	if err = json.Unmarshal([]byte(input), &msgContent); err != nil {
		return MsgContent{}, fmt.Errorf("unmarshal input failed: %w", err)
	}

	channel <- entitys.Response{
		Index:   k.Name(),
		Content: msgContent.Content,
		Type:    entitys.ResponseStream,
	}
	return msgContent, nil
}
 | 
						||
 | 
						||
// chat POSTs requireData.KnowledgeConf.Query to the knowledge-chat endpoint
// under the configured BaseURL for session requireData.KnowledgeConf.Session,
// authenticating with the X-API-Key header. It then streams the SSE reply
// into requireData.Ch via connectAndReadSSE.
//
// Errors from sending the request and from reading the stream are wrapped
// with context. An empty session is rejected before any request is made,
// because it would otherwise produce a malformed URL.
func (k *KnowledgeBaseTool) chat(requireData *entitys.RequireData) (err error) {
	if requireData.KnowledgeConf.Session == "" {
		return fmt.Errorf("knowledge chat: empty session")
	}

	req := l_request.Request{
		Method: "post",
		Url:    k.config.BaseURL + "/api/v1/knowledge-chat/" + requireData.KnowledgeConf.Session,
		Headers: map[string]string{
			"Content-Type": "application/json",
			"X-API-Key":    requireData.KnowledgeConf.ApiKey,
		},
		Json: map[string]interface{}{
			"query": requireData.KnowledgeConf.Query,
		},
	}

	rsp, err := req.SendNoParseResponse()
	if err != nil {
		return fmt.Errorf("knowledge chat request: %w", err)
	}
	defer rsp.Body.Close()

	if err = k.connectAndReadSSE(rsp, requireData.Ch); err != nil {
		return fmt.Errorf("knowledge chat stream: %w", err)
	}
	return nil
}
 | 
						||
 | 
						||
// connectAndReadSSE validates an SSE response and passes each event's data
// payload to msgContentParse, which forwards the decoded content to channel.
//
// Parsing rules:
//   - Events are separated by blank lines.
//   - The "event", "data" and "id" fields are recognised; comments and
//     unknown fields are ignored.
//   - Multiple "data" lines in one event are joined with "\n", as the SSE
//     spec requires.
//   - Events with no data (for example a bare "event: ..." line) are skipped
//     rather than handed to the JSON decoder, which would fail on empty input.
//   - A final event without a trailing blank line is still dispatched.
//
// It returns an error for:
//   - a non-200 status;
//   - a Content-Type other than text/event-stream;
//   - a payload that is not valid JSON;
//   - a read failure on the stream, including a line longer than maxLineSize.
func (k *KnowledgeBaseTool) connectAndReadSSE(resp *http.Response, channel chan entitys.Response) error {
	// Validate status and content type before reading the body.
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("非 200 状态码: %d", resp.StatusCode)
	}
	contentType := resp.Header.Get("Content-Type")
	if !strings.Contains(contentType, "text/event-stream") {
		return fmt.Errorf("不支持的 Content-Type: %s", contentType)
	}

	// bufio.Scanner fails with ErrTooLong once a line exceeds its buffer
	// (64 KiB by default). A single data line carrying knowledge references
	// can be larger than that, so allow lines up to maxLineSize.
	const maxLineSize = 4 << 20
	scanner := bufio.NewScanner(resp.Body)
	scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)

	var currentMsg Message // event currently being assembled

	// flush dispatches the assembled event if it carries data, then resets it.
	flush := func() error {
		msg := currentMsg
		currentMsg = Message{}
		if msg.Data == "" {
			return nil
		}
		if _, err := k.msgContentParse(msg.Data, channel); err != nil {
			return fmt.Errorf("msgContentParse failed: %w", err)
		}
		return nil
	}

	for scanner.Scan() {
		line := scanner.Text()
		if line == "" {
			// A blank line terminates the current event.
			if err := flush(); err != nil {
				return err
			}
			continue
		}

		// Field lines have the form "field: value".
		parts := strings.SplitN(line, ":", 2)
		if len(parts) < 2 {
			continue // no colon: not a field line we understand, skip
		}
		field := strings.TrimSpace(parts[0])
		// The SSE spec strips exactly one leading space from the value.
		value := strings.TrimPrefix(parts[1], " ")

		switch field {
		case "event":
			currentMsg.Event = value
		case "data":
			if currentMsg.Data != "" {
				currentMsg.Data += "\n"
			}
			currentMsg.Data += value
		case "id":
			currentMsg.ID = value
		default:
			// Comment lines (":...") and unknown fields are ignored.
		}
	}

	// Surface read errors such as a dropped connection or an oversized line.
	if err := scanner.Err(); err != nil {
		return fmt.Errorf("读取流失败: %w", err)
	}

	// The stream may end without a trailing blank line; dispatch what is left.
	return flush()
}
 | 
						||
 | 
						||
// GetKnowledgeBaseSession creates a chat session for knowledge base baseId.
//
// It POSTs {"knowledge_base_id": baseId} to host + "/api/v1/sessions",
// authenticating with apiKey in the X-API-Key header, and returns the session
// ID found at data.id in the JSON response.
//
// It returns a wrapped error, and an empty ID, when:
//   - the request fails;
//   - the response is not valid JSON;
//   - the response carries no session ID.
//
// An empty ID is never returned with a nil error.
func GetKnowledgeBaseSession(host, baseId, apiKey string) (string, error) {
	req := l_request.Request{
		Method: "post",
		Url:    host + "/api/v1/sessions",
		Headers: map[string]string{
			"Content-Type": "application/json",
			"X-API-Key":    apiKey,
		},
		Json: map[string]interface{}{
			"knowledge_base_id": baseId,
		},
	}

	rsp, err := req.Send()
	if err != nil {
		return "", fmt.Errorf("create knowledge base session: %w", err)
	}

	var result sessionRsp
	if err = json.Unmarshal(rsp.Content, &result); err != nil {
		return "", fmt.Errorf("decode knowledge base session response: %w", err)
	}
	if result.Data.Id == "" {
		// Without an ID, later chat URLs would be malformed; fail here instead.
		return "", fmt.Errorf("create knowledge base session: response has no session id (success=%t)", result.Success)
	}
	return result.Data.Id, nil
}
 | 
						||
 | 
						||
// sessionRsp is the JSON envelope decoded from the session-creation response
// in GetKnowledgeBaseSession. Field names mirror the JSON tags. Data holds
// the created session's settings; Success is the top-level "success" flag.
type sessionRsp struct {
	Data struct {
		Id               string  `json:"id"`
		Title            string  `json:"title"`
		Description      string  `json:"description"`
		TenantId         int     `json:"tenant_id"`
		KnowledgeBaseId  string  `json:"knowledge_base_id"`
		MaxRounds        int     `json:"max_rounds"`
		EnableRewrite    bool    `json:"enable_rewrite"`
		FallbackStrategy string  `json:"fallback_strategy"`
		FallbackResponse string  `json:"fallback_response"`
		EmbeddingTopK    int     `json:"embedding_top_k"`
		KeywordThreshold float64 `json:"keyword_threshold"`
		VectorThreshold  float64 `json:"vector_threshold"`
		RerankModelId    string  `json:"rerank_model_id"`
		RerankTopK       int     `json:"rerank_top_k"`
		RerankThreshold  float64 `json:"rerank_threshold"`
		SummaryModelId   string  `json:"summary_model_id"`
	} `json:"data"`
	Success bool `json:"success"`
}
 |