package tools

import (
	"bufio"
	"encoding/json"
	"fmt"
	"net/http"
	"strings"

	"ai_scheduler/internal/config"
	"ai_scheduler/internal/entitys"
	"ai_scheduler/internal/pkg/l_request"

	"github.com/gofiber/fiber/v2/log"
	"github.com/gofiber/websocket/v2"
)

// KnowledgeBaseTool is the knowledge-base tool: it sends a user query to a
// knowledge-base chat service and relays the answer.
type KnowledgeBaseTool struct {
	config config.ToolConfig
}

// NewKnowledgeBaseTool creates a KnowledgeBaseTool that uses cfg.
func NewKnowledgeBaseTool(cfg config.ToolConfig) *KnowledgeBaseTool {
	return &KnowledgeBaseTool{config: cfg}
}

// GetConfig returns the tool configuration the tool was created with.
func (k *KnowledgeBaseTool) GetConfig() config.ToolConfig {
	return k.config
}

// Name returns the tool name used in the tool definition.
func (k *KnowledgeBaseTool) Name() string {
	return "knowledgeBase"
}

// Description returns the tool description ("请求知识库", i.e. "query the
// knowledge base").
func (k *KnowledgeBaseTool) Description() string {
	return "请求知识库"
}

// Definition returns the function-call schema for this tool: an object with a
// single required string parameter "query" (the knowledge-base query).
func (k *KnowledgeBaseTool) Definition() entitys.ToolDefinition {
	return entitys.ToolDefinition{
		Type: "function",
		Function: entitys.FunctionDef{
			Name:        k.Name(),
			Description: k.Description(),
			Parameters: map[string]interface{}{
				"type": "object",
				"properties": map[string]interface{}{
					"query": map[string]interface{}{
						"type":        "string",
						"description": "知识库查询条件",
					},
				},
				"required": []string{"query"},
			},
		},
	}
}

// Execute decodes the tool-call arguments in args into a KnowledgeBaseRequest
// and runs a knowledge-base chat with them, forwarding results on channel.
// It returns a wrapped error if args cannot be decoded, otherwise whatever
// the chat returns. c is passed through to chat unchanged.
func (k *KnowledgeBaseTool) Execute(channel chan entitys.ResponseData, c *websocket.Conn, args json.RawMessage) error {
	var params KnowledgeBaseRequest
	if err := json.Unmarshal(args, &params); err != nil {
		return fmt.Errorf("unmarshal args failed: %w", err)
	}
	// Log only non-secret fields: params.ApiKey is a credential and must not
	// end up in the logs. fiber's log.Info does not interpret format verbs,
	// so the formatted variant is required.
	log.Infof("开始执行知识库 KnowledgeBaseTool Execute, session: %s, query: %q", params.Session, params.Query)
	return k.chat(channel, c, params)
}

// KnowledgeBaseRequest holds the decoded tool-call arguments.
//
// The fields carry no json tags, so encoding/json matches incoming keys to the
// field names case-insensitively (e.g. "session", "apikey", "query").
type KnowledgeBaseRequest struct {
	Session string // knowledge-base session id
	ApiKey  string // knowledge-base API key; a secret, never log it
	Query   string // user input
}

// Message is one parsed server-sent event (SSE).
type Message struct {
	Event string // value of the "event:" field (the SSE default type is "message")
	Data  string // payload from the "data:" field(s); may span several lines
	ID    string // value of the "id:" field (optional)
}

// MsgContent is the JSON payload carried in the data field of each event of
// the knowledge-base chat stream.
type MsgContent struct {
	Id                  string      `json:"id"`
	ResponseType        string      `json:"response_type"`
	Content             string      `json:"content"`
	Done                bool        `json:"done"`
	KnowledgeReferences interface{} `json:"knowledge_references"`
}
解析知识库响应内容,并把通过channel结果返回 func msgContentParse(input string, channel chan entitys.ResponseData) (msgContent MsgContent, err error) { err = json.Unmarshal([]byte(input), &msgContent) if err != nil { err = fmt.Errorf("unmarshal input failed: %w", err) } channel <- entitys.ResponseData{ Done: msgContent.Done, Content: msgContent.Content, Type: entitys.ResponseStream, } return } // 请求知识库聊天 func (this *KnowledgeBaseTool) chat(channel chan entitys.ResponseData, c *websocket.Conn, param KnowledgeBaseRequest) (err error) { req := l_request.Request{ Method: "post", Url: this.config.BaseURL + "/api/v1/knowledge-chat/" + param.Session, Params: nil, Headers: map[string]string{ "Content-Type": "application/json", "X-API-Key": param.ApiKey, }, Cookies: nil, Data: nil, Json: map[string]interface{}{ "query": param.Query, }, Files: nil, Raw: "", JsonByte: nil, Xml: nil, } rsp, err := req.SendNoParseResponse() if err != nil { return } defer rsp.Body.Close() err = connectAndReadSSE(rsp, channel) if err != nil { return } return } // 连接 SSE 并读取数据 func connectAndReadSSE(resp *http.Response, channel chan entitys.ResponseData) error { // 验证响应状态和格式 if resp.StatusCode != http.StatusOK { return fmt.Errorf("非 200 状态码: %d", resp.StatusCode) } contentType := resp.Header.Get("Content-Type") if !strings.Contains(contentType, "text/event-stream") { return fmt.Errorf("不支持的 Content-Type: %s", contentType) } // 逐行读取响应流 scanner := bufio.NewScanner(resp.Body) var currentMsg Message // 当前正在组装的消息 for scanner.Scan() { line := scanner.Text() if line == "" { // 空行表示一条消息结束,处理当前消息 if currentMsg.Data != "" || currentMsg.Event != "" || currentMsg.ID != "" { _, err := msgContentParse(currentMsg.Data, channel) if err != nil { return fmt.Errorf("msgContentParse failed: %w", err) } currentMsg = Message{} // 重置消息 } continue } // 解析字段(格式:"field: value") parts := strings.SplitN(line, ":", 2) if len(parts) < 2 { continue // 无效行(无冒号),跳过 } field := strings.TrimSpace(parts[0]) value := strings.TrimSpace(parts[1]) switch 
field { case "event": currentMsg.Event = value case "data": // data 可能多行,用换行符拼接(最后一条消息可能无结尾空行) currentMsg.Data += value + "" default: // 忽略未知字段 } } // 检查扫描错误(如连接断开) if err := scanner.Err(); err != nil { return fmt.Errorf("读取流失败: %w", err) } // 处理最后一条未结束的消息(无结尾空行) if currentMsg.Data != "" || currentMsg.Event != "" || currentMsg.ID != "" { _, err := msgContentParse(currentMsg.Data, channel) if err != nil { return fmt.Errorf("msgContentParse failed: %w", err) } } return nil } // 获取知识库 session func GetKnowledgeBaseSession(host, baseId, apiKey string) (string, error) { req := l_request.Request{ Method: "post", Url: host + "/api/v1/sessions", Params: nil, Headers: map[string]string{ "Content-Type": "application/json", "X-API-Key": apiKey, }, Cookies: nil, Data: nil, Json: map[string]interface{}{ "knowledge_base_id": baseId, }, Files: nil, Raw: "", JsonByte: nil, Xml: nil, } rsp, err := req.Send() if err != nil { return "", err } var result sessionRsp err = json.Unmarshal(rsp.Content, &result) return result.Data.Id, err } type sessionRsp struct { Data struct { Id string `json:"id"` Title string `json:"title"` Description string `json:"description"` TenantId int `json:"tenant_id"` KnowledgeBaseId string `json:"knowledge_base_id"` MaxRounds int `json:"max_rounds"` EnableRewrite bool `json:"enable_rewrite"` FallbackStrategy string `json:"fallback_strategy"` FallbackResponse string `json:"fallback_response"` EmbeddingTopK int `json:"embedding_top_k"` KeywordThreshold float64 `json:"keyword_threshold"` VectorThreshold float64 `json:"vector_threshold"` RerankModelId string `json:"rerank_model_id"` RerankTopK int `json:"rerank_top_k"` RerankThreshold float64 `json:"rerank_threshold"` SummaryModelId string `json:"summary_model_id"` } `json:"data"` Success bool `json:"success"` }