ai_scheduler/internal/tools/konwledge_base.go

265 lines
6.8 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package tools
import (
"ai_scheduler/internal/config"
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg/l_request"
"bufio"
"encoding/json"
"fmt"
"github.com/gofiber/fiber/v2/log"
"github.com/gofiber/websocket/v2"
"net/http"
"strings"
)
// KnowledgeBaseTool is the "knowledgeBase" tool: it sends a user query to a
// knowledge-base chat service and streams the reply onto a response channel.
type KnowledgeBaseTool struct {
config config.ToolConfig // tool settings; BaseURL is prefixed to the knowledge-chat endpoint path
}
// NewKnowledgeBaseTool builds a KnowledgeBaseTool that stores cfg as its
// service settings.
func NewKnowledgeBaseTool(cfg config.ToolConfig) *KnowledgeBaseTool {
	tool := &KnowledgeBaseTool{config: cfg}
	return tool
}
// GetConfig returns the configuration stored on the tool (set by
// NewKnowledgeBaseTool).
func (k *KnowledgeBaseTool) GetConfig() (cfg config.ToolConfig) {
	cfg = k.config
	return cfg
}
// Name reports the identifier this tool is exposed under; Definition uses it
// as the function name.
func (k *KnowledgeBaseTool) Name() string {
	const toolName = "knowledgeBase"
	return toolName
}
// Description returns the human-readable summary ("request the knowledge
// base") that Definition attaches to the function schema.
func (k *KnowledgeBaseTool) Description() string {
	const summary = "请求知识库"
	return summary
}
// Definition describes the tool as a "function" named Name(), described by
// Description(), whose JSON-schema parameters are an object with a single
// required string property "query" (the knowledge-base query condition).
func (k *KnowledgeBaseTool) Definition() entitys.ToolDefinition {
	queryProp := map[string]interface{}{
		"type":        "string",
		"description": "知识库查询条件",
	}
	schema := map[string]interface{}{
		"type":       "object",
		"properties": map[string]interface{}{"query": queryProp},
		"required":   []string{"query"},
	}
	fn := entitys.FunctionDef{
		Name:        k.Name(),
		Description: k.Description(),
		Parameters:  schema,
	}
	return entitys.ToolDefinition{Type: "function", Function: fn}
}
// Execute decodes the tool-call arguments and runs a knowledge-base chat,
// streaming results onto channel.
//
// args is decoded into KnowledgeBaseRequest (encoding/json matches field names
// case-insensitively, so {"query": "..."} fills Query). A malformed args
// payload is returned as a wrapped error; otherwise the error from chat is
// returned unchanged.
func (k *KnowledgeBaseTool) Execute(channel chan entitys.ResponseData, c *websocket.Conn, args json.RawMessage) error {
	var params KnowledgeBaseRequest
	if err := json.Unmarshal(args, &params); err != nil {
		return fmt.Errorf("unmarshal args failed: %w", err)
	}
	// fiber's log.Info is Sprint-style and would print "%v" literally; use
	// Infof. ApiKey is deliberately left out so credentials never hit the logs.
	log.Infof("开始执行知识库 KnowledgeBaseTool Execute, session: %q, query: %q", params.Session, params.Query)
	return k.chat(channel, c, params)
}
// KnowledgeBaseRequest carries the arguments for one knowledge-base chat call:
//   - Session: knowledge-base chat session id, appended to the request path
//   - ApiKey:  knowledge-base API key, sent in the X-API-Key header
//   - Query:   the user's input, sent as the "query" JSON field
type KnowledgeBaseRequest struct {
	Session, ApiKey, Query string
}
// Message is one server-sent event assembled from consecutive field lines:
//   - Event: value of the "event" field (empty when the stream omits it)
//   - Data:  payload accumulated from the "data" field line(s)
//   - ID:    value of the optional "id" field
type Message struct {
	Event, Data, ID string
}
// MsgContent is the JSON object carried in the data field of each SSE event
// from the knowledge-chat stream; msgContentParse decodes it and forwards
// Content and Done to the response channel.
type MsgContent struct {
Id string `json:"id"`
ResponseType string `json:"response_type"`
Content string `json:"content"` // text fragment forwarded as ResponseData.Content
Done bool `json:"done"` // forwarded as ResponseData.Done
KnowledgeReferences interface{} `json:"knowledge_references"` // decoded into generic JSON values; shape not fixed here
}
// 解析知识库响应内容,并把通过channel结果返回
func msgContentParse(input string, channel chan entitys.ResponseData) (msgContent MsgContent, err error) {
err = json.Unmarshal([]byte(input), &msgContent)
if err != nil {
err = fmt.Errorf("unmarshal input failed: %w", err)
}
channel <- entitys.ResponseData{
Done: msgContent.Done,
Content: msgContent.Content,
Type: entitys.ResponseStream,
}
return
}
// chat POSTs param.Query to the configured service's
// /api/v1/knowledge-chat/{session} endpoint, authenticating with param.ApiKey
// in the X-API-Key header, and streams the SSE reply onto channel via
// connectAndReadSSE. The response body is closed before returning.
//
// The session id is spliced into the URL path, so ids that are empty or that
// contain path/query/escape characters are rejected up front rather than
// silently hitting a different endpoint. c is currently unused.
func (k *KnowledgeBaseTool) chat(channel chan entitys.ResponseData, c *websocket.Conn, param KnowledgeBaseRequest) error {
	if param.Session == "" || strings.ContainsAny(param.Session, "/?#%\\") {
		return fmt.Errorf("knowledge base chat: invalid session id %q", param.Session)
	}
	req := l_request.Request{
		Method: "post",
		Url:    k.config.BaseURL + "/api/v1/knowledge-chat/" + param.Session,
		Headers: map[string]string{
			"Content-Type": "application/json",
			"X-API-Key":    param.ApiKey,
		},
		Json: map[string]interface{}{
			"query": param.Query,
		},
	}
	rsp, err := req.SendNoParseResponse()
	if err != nil {
		return fmt.Errorf("knowledge base chat request: %w", err)
	}
	defer rsp.Body.Close()
	return connectAndReadSSE(rsp, channel)
}
// connectAndReadSSE validates an SSE (text/event-stream) response and hands
// each event's payload to msgContentParse, which decodes it and sends it on
// channel.
//
// Framing follows the SSE format:
//   - "field: value" lines accumulate into a Message, and a blank line
//     dispatches it.
//   - Multiple "data" lines are joined with "\n".
//   - Only the single space after the colon is stripped from a value.
//   - Events with no data (event-only or id-only blocks) are skipped instead
//     of being fed to the JSON decoder.
//   - A trailing event not followed by a blank line is still dispatched when
//     the stream ends.
//
// It returns an error in these cases:
//   - a non-200 status;
//   - a non-SSE Content-Type;
//   - a payload that fails to decode;
//   - a read error, including a line longer than maxLineSize.
//
// Closing resp.Body is left to the caller.
func connectAndReadSSE(resp *http.Response, channel chan entitys.ResponseData) error {
	// Upper bound on one SSE line. bufio.Scanner's default 64 KiB limit is
	// easily exceeded by a single JSON chunk carrying knowledge references.
	const maxLineSize = 4 << 20

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("非 200 状态码: %d", resp.StatusCode)
	}
	contentType := resp.Header.Get("Content-Type")
	if !strings.Contains(contentType, "text/event-stream") {
		return fmt.Errorf("不支持的 Content-Type: %s", contentType)
	}

	// dispatch decodes and forwards one completed event, skipping events
	// whose data buffer is empty.
	dispatch := func(msg Message) error {
		payload := strings.TrimSuffix(msg.Data, "\n")
		if strings.TrimSpace(payload) == "" {
			return nil
		}
		if _, err := msgContentParse(payload, channel); err != nil {
			return fmt.Errorf("msgContentParse failed: %w", err)
		}
		return nil
	}

	scanner := bufio.NewScanner(resp.Body)
	scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)

	var current Message // event currently being assembled
	for scanner.Scan() {
		line := scanner.Text()
		if line == "" {
			// A blank line terminates the current event.
			if err := dispatch(current); err != nil {
				return err
			}
			current = Message{}
			continue
		}

		parts := strings.SplitN(line, ":", 2)
		if len(parts) < 2 {
			continue // no colon: not a "field: value" line, skip it
		}
		field := strings.TrimSpace(parts[0])
		value := strings.TrimPrefix(parts[1], " ")

		switch field {
		case "event":
			current.Event = value
		case "data":
			// Each data line contributes one line to the payload; the final
			// newline is removed in dispatch.
			current.Data += value + "\n"
		case "id":
			current.ID = value
		default:
			// Comment lines (":...") yield an empty field name and, like
			// unknown fields, are ignored.
		}
	}
	if err := scanner.Err(); err != nil {
		return fmt.Errorf("读取流失败: %w", err)
	}
	// Flush a final event that was not terminated by a blank line.
	return dispatch(current)
}
// GetKnowledgeBaseSession creates a chat session for knowledge base baseId on
// host and returns the new session id.
//
// It POSTs {"knowledge_base_id": baseId} to host+"/api/v1/sessions" with
// apiKey in the X-API-Key header. It returns an error in these cases:
//   - the request fails;
//   - the body is not valid JSON;
//   - the service reports success=false;
//   - no session id is present.
//
// The returned id is "" whenever err is non-nil.
func GetKnowledgeBaseSession(host, baseId, apiKey string) (string, error) {
	req := l_request.Request{
		Method: "post",
		Url:    host + "/api/v1/sessions",
		Headers: map[string]string{
			"Content-Type": "application/json",
			"X-API-Key":    apiKey,
		},
		Json: map[string]interface{}{
			"knowledge_base_id": baseId,
		},
	}
	rsp, err := req.Send()
	if err != nil {
		return "", fmt.Errorf("create knowledge base session: %w", err)
	}
	var result sessionRsp
	if err := json.Unmarshal(rsp.Content, &result); err != nil {
		return "", fmt.Errorf("decode knowledge base session response: %w", err)
	}
	// Without this check a failed call would surface as ("", nil) and the
	// caller would go on to chat against an empty session id.
	if !result.Success || result.Data.Id == "" {
		return "", fmt.Errorf("knowledge base session not created (success=%t, id=%q)", result.Success, result.Data.Id)
	}
	return result.Data.Id, nil
}
// sessionRsp is the JSON body returned by POST {host}/api/v1/sessions.
// GetKnowledgeBaseSession reads Data.Id from it; the remaining Data fields are
// decoded but not read anywhere in the visible code.
type sessionRsp struct {
Data struct {
Id string `json:"id"`
Title string `json:"title"`
Description string `json:"description"`
TenantId int `json:"tenant_id"`
KnowledgeBaseId string `json:"knowledge_base_id"`
MaxRounds int `json:"max_rounds"`
EnableRewrite bool `json:"enable_rewrite"`
FallbackStrategy string `json:"fallback_strategy"`
FallbackResponse string `json:"fallback_response"`
EmbeddingTopK int `json:"embedding_top_k"`
KeywordThreshold float64 `json:"keyword_threshold"`
VectorThreshold float64 `json:"vector_threshold"`
RerankModelId string `json:"rerank_model_id"`
RerankTopK int `json:"rerank_top_k"`
RerankThreshold float64 `json:"rerank_threshold"`
SummaryModelId string `json:"summary_model_id"`
} `json:"data"`
Success bool `json:"success"` // NOTE(review): presumably true when the session was created — confirm against the service API
}