266 lines
6.9 KiB
Go
266 lines
6.9 KiB
Go
package tools
|
||
|
||
import (
|
||
"ai_scheduler/internal/config"
|
||
"ai_scheduler/internal/entitys"
|
||
"ai_scheduler/internal/pkg/l_request"
|
||
"bufio"
|
||
"encoding/json"
|
||
"fmt"
|
||
"net/http"
|
||
"strings"
|
||
|
||
"github.com/gofiber/fiber/v2/log"
|
||
"github.com/gofiber/websocket/v2"
|
||
)
|
||
|
||
// 知识库工具
|
||
type KnowledgeBaseTool struct {
|
||
config config.ToolConfig
|
||
}
|
||
|
||
// NewKnowledgeBaseTool 创建知识库工具
|
||
func NewKnowledgeBaseTool(config config.ToolConfig) *KnowledgeBaseTool {
|
||
return &KnowledgeBaseTool{config: config}
|
||
}
|
||
|
||
func (k *KnowledgeBaseTool) GetConfig() config.ToolConfig {
|
||
return k.config
|
||
}
|
||
|
||
// Name 返回工具名称
|
||
func (k *KnowledgeBaseTool) Name() string {
|
||
return "knowledgeBase"
|
||
}
|
||
|
||
// Description 返回工具描述
|
||
func (k *KnowledgeBaseTool) Description() string {
|
||
return "请求知识库"
|
||
}
|
||
|
||
// Definition 返回工具定义
|
||
func (k *KnowledgeBaseTool) Definition() entitys.ToolDefinition {
|
||
return entitys.ToolDefinition{
|
||
Type: "function",
|
||
Function: entitys.FunctionDef{
|
||
Name: k.Name(),
|
||
Description: k.Description(),
|
||
Parameters: map[string]interface{}{
|
||
"type": "object",
|
||
"properties": map[string]interface{}{
|
||
"query": map[string]interface{}{
|
||
"type": "string",
|
||
"description": "知识库查询条件",
|
||
},
|
||
},
|
||
"required": []string{"query"},
|
||
},
|
||
},
|
||
}
|
||
}
|
||
|
||
// Execute runs one knowledge-base query for a tool call.
//
// args is the JSON tool-call argument object. It is decoded into a
// KnowledgeBaseRequest; that struct has no json tags, so keys are matched to
// field names case-insensitively. The answer is streamed to channel by chat.
// matchJson is not used here. c is passed through to chat, which does not use
// it either.
//
// It returns an error if args cannot be decoded or if the chat request or
// stream fails.
func (k *KnowledgeBaseTool) Execute(channel chan entitys.Response, c *websocket.Conn, args json.RawMessage, matchJson *entitys.Match) error {
	var params KnowledgeBaseRequest
	if err := json.Unmarshal(args, &params); err != nil {
		return fmt.Errorf("unmarshal args failed: %w", err)
	}

	// Log only non-secret fields: params.ApiKey is a credential and must not be
	// written to logs. Infof (not Info) is needed for the format verbs to apply.
	log.Infof("开始执行知识库 KnowledgeBaseTool Execute, session: %q, query: %q", params.Session, params.Query)

	return k.chat(channel, c, params)
}
|
||
|
||
type KnowledgeBaseRequest struct {
|
||
Session string // 知识库会话id
|
||
ApiKey string // 知识库apiKey
|
||
Query string // 用户输入
|
||
}
|
||
|
||
// Message is one Server-Sent Events message assembled from a response stream.
//
//   - Event: the event type. The SSE spec's default is "message"; this type
//     does not fill that in.
//   - Data:  the message payload, which may come from several data lines.
//   - ID:    the optional message id.
type Message struct {
	Event, Data, ID string
}
|
||
|
||
type MsgContent struct {
|
||
Id string `json:"id"`
|
||
ResponseType string `json:"response_type"`
|
||
Content string `json:"content"`
|
||
Done bool `json:"done"`
|
||
KnowledgeReferences interface{} `json:"knowledge_references"`
|
||
}
|
||
|
||
// msgContentParse decodes one SSE data payload (a JSON MsgContent) and forwards
// its Content to channel as a stream chunk tagged with the tool name.
//
// If input is not valid MsgContent JSON, nothing is sent and a zero
// MsgContent is returned with the wrapped decode error. Otherwise the decoded
// value is returned with a nil error after the send completes.
func (k *KnowledgeBaseTool) msgContentParse(input string, channel chan entitys.Response) (msgContent MsgContent, err error) {
	if err := json.Unmarshal([]byte(input), &msgContent); err != nil {
		return MsgContent{}, fmt.Errorf("unmarshal input failed: %w", err)
	}

	channel <- entitys.Response{
		Index:   k.Name(),
		Content: msgContent.Content,
		Type:    entitys.ResponseStream,
	}
	return msgContent, nil
}
|
||
|
||
// 请求知识库聊天
|
||
func (this *KnowledgeBaseTool) chat(channel chan entitys.Response, c *websocket.Conn, param KnowledgeBaseRequest) (err error) {
|
||
|
||
req := l_request.Request{
|
||
Method: "post",
|
||
Url: this.config.BaseURL + "/api/v1/knowledge-chat/" + param.Session,
|
||
Params: nil,
|
||
Headers: map[string]string{
|
||
"Content-Type": "application/json",
|
||
"X-API-Key": param.ApiKey,
|
||
},
|
||
Cookies: nil,
|
||
Data: nil,
|
||
Json: map[string]interface{}{
|
||
"query": param.Query,
|
||
},
|
||
Files: nil,
|
||
Raw: "",
|
||
JsonByte: nil,
|
||
Xml: nil,
|
||
}
|
||
|
||
rsp, err := req.SendNoParseResponse()
|
||
if err != nil {
|
||
return
|
||
}
|
||
defer rsp.Body.Close()
|
||
|
||
err = this.connectAndReadSSE(rsp, channel)
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
return
|
||
}
|
||
|
||
// connectAndReadSSE validates resp as an SSE stream and forwards every event it
// carries to channel via msgContentParse.
//
// resp must have status 200 and a Content-Type containing "text/event-stream";
// otherwise an error is returned without reading the body.
//
// The body is read line by line following the SSE field rules:
//   - the "event", "data" and "id" fields are collected;
//   - lines starting with ':' are comments;
//   - any other field is ignored;
//   - a blank line ends an event.
//
// Multiple data lines of one event are joined with "\n". Events whose data is
// blank are skipped. An event still pending when the stream ends is dispatched
// as well.
//
// Reading stops at the first payload that fails to decode, or at a read error.
// The caller owns resp.Body and must close it.
func (k *KnowledgeBaseTool) connectAndReadSSE(resp *http.Response, channel chan entitys.Response) error {
	// Validate the status and content type before touching the body.
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("非 200 状态码: %d", resp.StatusCode)
	}
	contentType := resp.Header.Get("Content-Type")
	if !strings.Contains(contentType, "text/event-stream") {
		return fmt.Errorf("不支持的 Content-Type: %s", contentType)
	}

	// bufio.Scanner rejects lines longer than 64 KiB by default (bufio.ErrTooLong).
	// A single data line can carry a whole JSON payload, so allow larger lines.
	const maxLineSize = 4 << 20
	scanner := bufio.NewScanner(resp.Body)
	scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)

	var msg Message // event currently being assembled
	for scanner.Scan() {
		line := scanner.Text()
		if line == "" {
			// A blank line terminates the current event.
			if err := k.dispatchSSEMessage(msg, channel); err != nil {
				return err
			}
			msg = Message{}
			continue
		}
		if strings.HasPrefix(line, ":") {
			continue // comment / keep-alive line
		}

		// Lines have the form "field: value". Per the SSE spec, only a single
		// space after the colon is stripped, and a line without a colon is a
		// field with an empty value.
		field, value, _ := strings.Cut(line, ":")
		value = strings.TrimPrefix(value, " ")

		switch field {
		case "event":
			msg.Event = value
		case "data":
			// Each data line contributes value+"\n". The trailing newline is
			// dropped on dispatch, so multi-line data ends up joined with "\n".
			msg.Data += value + "\n"
		case "id":
			msg.ID = value
		default:
			// "retry" and unknown fields are ignored.
		}
	}

	// Check for read errors (e.g. the connection dropped mid-stream).
	if err := scanner.Err(); err != nil {
		return fmt.Errorf("读取流失败: %w", err)
	}

	// The stream may end without a trailing blank line; flush the pending event.
	return k.dispatchSSEMessage(msg, channel)
}

// dispatchSSEMessage hands the data of one assembled SSE event to
// msgContentParse. Events without a non-blank data payload (for example an
// "event:"-only frame) are skipped, because they cannot hold a JSON MsgContent.
func (k *KnowledgeBaseTool) dispatchSSEMessage(msg Message, channel chan entitys.Response) error {
	payload := strings.TrimSuffix(msg.Data, "\n")
	if strings.TrimSpace(payload) == "" {
		return nil
	}
	if _, err := k.msgContentParse(payload, channel); err != nil {
		return fmt.Errorf("msgContentParse failed: %w", err)
	}
	return nil
}
|
||
|
||
// GetKnowledgeBaseSession creates a chat session for knowledge base baseId and
// returns the new session id.
//
// It POSTs {"knowledge_base_id": baseId} to host + "/api/v1/sessions",
// authenticating with apiKey through the X-API-Key header.
//
// An error is returned in any of these cases:
//   - the request fails;
//   - the response body is not valid JSON;
//   - the response does not report success with a non-empty session id.
func GetKnowledgeBaseSession(host, baseId, apiKey string) (string, error) {
	req := l_request.Request{
		Method: "post",
		Url:    host + "/api/v1/sessions",
		Headers: map[string]string{
			"Content-Type": "application/json",
			"X-API-Key":    apiKey,
		},
		Json: map[string]interface{}{
			"knowledge_base_id": baseId,
		},
	}

	rsp, err := req.Send()
	if err != nil {
		return "", err
	}

	var result sessionRsp
	if err := json.Unmarshal(rsp.Content, &result); err != nil {
		return "", fmt.Errorf("unmarshal session response failed: %w", err)
	}
	// A well-formed body can still describe a failure. Without this check the
	// caller would get "" and later build a chat URL with an empty session id.
	if !result.Success || result.Data.Id == "" {
		return "", fmt.Errorf("create knowledge base session failed: %s", rsp.Content)
	}
	return result.Data.Id, nil
}
|
||
|
||
type sessionRsp struct {
|
||
Data struct {
|
||
Id string `json:"id"`
|
||
Title string `json:"title"`
|
||
Description string `json:"description"`
|
||
TenantId int `json:"tenant_id"`
|
||
KnowledgeBaseId string `json:"knowledge_base_id"`
|
||
MaxRounds int `json:"max_rounds"`
|
||
EnableRewrite bool `json:"enable_rewrite"`
|
||
FallbackStrategy string `json:"fallback_strategy"`
|
||
FallbackResponse string `json:"fallback_response"`
|
||
EmbeddingTopK int `json:"embedding_top_k"`
|
||
KeywordThreshold float64 `json:"keyword_threshold"`
|
||
VectorThreshold float64 `json:"vector_threshold"`
|
||
RerankModelId string `json:"rerank_model_id"`
|
||
RerankTopK int `json:"rerank_top_k"`
|
||
RerankThreshold float64 `json:"rerank_threshold"`
|
||
SummaryModelId string `json:"summary_model_id"`
|
||
} `json:"data"`
|
||
Success bool `json:"success"`
|
||
}
|