169 lines
4.3 KiB
Go
169 lines
4.3 KiB
Go
package util
|
||
|
||
import (
|
||
"ai_scheduler/internal/pkg/l_request"
|
||
"bufio"
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
|
||
"net/http"
|
||
"strings"
|
||
)
|
||
|
||
type KnowledgeBase struct {
|
||
session string
|
||
url string
|
||
apiKey string
|
||
}
|
||
|
||
func NewKnowledgeBase(url, apiKey, session string) *KnowledgeBase {
|
||
return &KnowledgeBase{
|
||
session: session,
|
||
url: url,
|
||
apiKey: apiKey,
|
||
}
|
||
}
|
||
|
||
// 请求知识库聊天
|
||
func (this *KnowledgeBase) Chat(ctx context.Context, query string) (text string, err error) {
|
||
|
||
req := l_request.Request{
|
||
Method: "post",
|
||
Url: this.url + "/api/v1/knowledge-chat/" + this.session,
|
||
Params: nil,
|
||
Headers: map[string]string{
|
||
"Content-Type": "application/json",
|
||
"X-API-Key": this.apiKey,
|
||
},
|
||
Cookies: nil,
|
||
Data: nil,
|
||
Json: map[string]interface{}{
|
||
"query": query,
|
||
},
|
||
Files: nil,
|
||
Raw: "",
|
||
JsonByte: nil,
|
||
Xml: nil,
|
||
}
|
||
|
||
rsp, err := req.SendNoParseResponse()
|
||
if err != nil {
|
||
return
|
||
}
|
||
defer rsp.Body.Close()
|
||
|
||
err = connectAndReadSSE(rsp)
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
return
|
||
}
|
||
|
||
// Message 表示解析后的 SSE 消息
|
||
type Message struct {
|
||
Event string // 事件类型(默认 "message")
|
||
Data string // 消息内容(可能多行)
|
||
ID string // 消息 ID(可选)
|
||
}
|
||
|
||
// 连接 SSE 并读取数据
|
||
func connectAndReadSSE(resp *http.Response) error {
|
||
|
||
// 验证响应状态和格式
|
||
if resp.StatusCode != http.StatusOK {
|
||
return fmt.Errorf("非 200 状态码: %d", resp.StatusCode)
|
||
}
|
||
contentType := resp.Header.Get("Content-Type")
|
||
if !strings.Contains(contentType, "text/event-stream") {
|
||
return fmt.Errorf("不支持的 Content-Type: %s", contentType)
|
||
}
|
||
|
||
// 逐行读取响应流
|
||
scanner := bufio.NewScanner(resp.Body)
|
||
var currentMsg Message // 当前正在组装的消息
|
||
|
||
for scanner.Scan() {
|
||
line := scanner.Text()
|
||
if line == "" {
|
||
// 空行表示一条消息结束,处理当前消息
|
||
if currentMsg.Data != "" || currentMsg.Event != "" || currentMsg.ID != "" {
|
||
printMessage(currentMsg)
|
||
currentMsg = Message{} // 重置消息
|
||
}
|
||
continue
|
||
}
|
||
|
||
// 解析字段(格式:"field: value")
|
||
parts := strings.SplitN(line, ":", 2)
|
||
if len(parts) < 2 {
|
||
continue // 无效行(无冒号),跳过
|
||
}
|
||
field := strings.TrimSpace(parts[0])
|
||
value := strings.TrimSpace(parts[1])
|
||
|
||
switch field {
|
||
case "event":
|
||
currentMsg.Event = value
|
||
case "data":
|
||
// data 可能多行,用换行符拼接(最后一条消息可能无结尾空行)
|
||
currentMsg.Data += value + ""
|
||
//case "id":
|
||
// currentMsg.ID = value
|
||
// 可选:处理 "retry" 字段(服务器建议的重连时间,单位秒)
|
||
}
|
||
}
|
||
|
||
// 检查扫描错误(如连接断开)
|
||
if err := scanner.Err(); err != nil {
|
||
return fmt.Errorf("读取流失败: %w", err)
|
||
}
|
||
|
||
// 处理最后一条未结束的消息(无结尾空行)
|
||
if currentMsg.Data != "" || currentMsg.Event != "" || currentMsg.ID != "" {
|
||
printMessage(currentMsg)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
type MegContent struct {
|
||
Id string `json:"id"` // 消息 ID
|
||
ResponseType string `json:"response_type"` // 响应类型,answer 或 references
|
||
Content string `json:"content"` // 消息内容
|
||
Done bool `json:"done"` // 是否完成
|
||
KnowledgeReferences interface{} `json:"knowledge_references"`
|
||
}
|
||
|
||
// printMessage 打印解析后的 SSE 消息
|
||
func printMessage(msg Message) {
|
||
|
||
//fmt.Printf("--- 收到 SSE 消息 ---")
|
||
//fmt.Printf("事件类型: %s,", msg.Event)
|
||
//fmt.Printf("消息 ID: %s,", msg.ID)
|
||
//fmt.Printf("内容:%s,", strings.TrimSpace(msg.Data)) // 去除末尾多余换行
|
||
|
||
var content MegContent
|
||
_ = json.Unmarshal([]byte(msg.Data), &content)
|
||
fmt.Println(msg.Data)
|
||
|
||
//if content.ResponseType == "answer" {
|
||
// //fmt.Printf("%s", content.Content)
|
||
// fmt.Println(content)
|
||
//} else {
|
||
// fmt.Printf("--- 收到 SSE 消息 ---")
|
||
// fmt.Printf("事件类型: %s,", msg.Event)
|
||
// fmt.Printf("消息 ID: %s,", msg.ID)
|
||
// fmt.Printf("内容:%s,", strings.TrimSpace(msg.Data)) // 去除末尾多余换行
|
||
//}
|
||
|
||
}
|
||
|
||
// getRetryAfter 从响应头获取重连时间(示例,需根据实际响应头调整)
|
||
func getRetryAfter(url string) int {
|
||
// 实际需重新请求并获取响应头(此处简化为固定值)
|
||
// 正确做法:在 connectAndReadSSE 中记录响应头的 Retry-After 字段
|
||
return 5 // 示例:等待 5 秒
|
||
}
|