diff --git a/internal/tools/konwledge_base.go b/internal/tools/konwledge_base.go index bf82b28..b614c11 100644 --- a/internal/tools/konwledge_base.go +++ b/internal/tools/konwledge_base.go @@ -7,6 +7,7 @@ import ( "bufio" "encoding/json" "fmt" + "github.com/gofiber/fiber/v2/log" "github.com/gofiber/websocket/v2" "net/http" "strings" @@ -64,6 +65,7 @@ func (k *KnowledgeBaseTool) Execute(channel chan entitys.ResponseData, c *websoc if err := json.Unmarshal(args, ¶ms); err != nil { return fmt.Errorf("unmarshal args failed: %w", err) } + log.Info("开始执行知识库 KnowledgeBaseTool Execute, params: %v", params) return k.chat(channel, c, params) @@ -82,7 +84,7 @@ type Message struct { ID string // 消息 ID(可选) } -type MegContent struct { +type MsgContent struct { Id string `json:"id"` ResponseType string `json:"response_type"` Content string `json:"content"` @@ -90,6 +92,22 @@ type MegContent struct { KnowledgeReferences interface{} `json:"knowledge_references"` } +// 解析知识库响应内容,并把通过channel结果返回 +func msgContentParse(input string, channel chan entitys.ResponseData) (msgContent MsgContent, err error) { + err = json.Unmarshal([]byte(input), &msgContent) + if err != nil { + err = fmt.Errorf("unmarshal input failed: %w", err) + } + + channel <- entitys.ResponseData{ + Done: msgContent.Done, + Content: msgContent.Content, + Type: entitys.ResponseStream, + } + + return +} + // 请求知识库聊天 func (this *KnowledgeBaseTool) chat(channel chan entitys.ResponseData, c *websocket.Conn, param KnowledgeBaseRequest) (err error) { @@ -147,10 +165,9 @@ func connectAndReadSSE(resp *http.Response, channel chan entitys.ResponseData) e if line == "" { // 空行表示一条消息结束,处理当前消息 if currentMsg.Data != "" || currentMsg.Event != "" || currentMsg.ID != "" { - channel <- entitys.ResponseData{ - Done: false, - Content: currentMsg.Data, - Type: entitys.ResponseJson, + _, err := msgContentParse(currentMsg.Data, channel) + if err != nil { + return fmt.Errorf("msgContentParse failed: %w", err) } currentMsg = Message{} // 重置消息 } @@ -183,10 +200,9 @@ func connectAndReadSSE(resp *http.Response, channel chan entitys.ResponseData) e // 处理最后一条未结束的消息(无结尾空行) if currentMsg.Data != "" || currentMsg.Event != "" || currentMsg.ID != "" { - channel <- entitys.ResponseData{ - Done: false, - Content: currentMsg.Data, - Type: entitys.ResponseJson, + _, err := msgContentParse(currentMsg.Data, channel) + if err != nil { + return fmt.Errorf("msgContentParse failed: %w", err) } }