169 lines
		
	
	
		
			4.3 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			169 lines
		
	
	
		
			4.3 KiB
		
	
	
	
		
			Go
		
	
	
	
package util
 | 
						||
 | 
						||
import (
 | 
						||
	"ai_scheduler/internal/pkg/l_request"
 | 
						||
	"bufio"
 | 
						||
	"context"
 | 
						||
	"encoding/json"
 | 
						||
	"fmt"
 | 
						||
 | 
						||
	"net/http"
 | 
						||
	"strings"
 | 
						||
)
 | 
						||
 | 
						||
type KnowledgeBase struct {
 | 
						||
	session string
 | 
						||
	url     string
 | 
						||
	apiKey  string
 | 
						||
}
 | 
						||
 | 
						||
func NewKnowledgeBase(url, apiKey, session string) *KnowledgeBase {
 | 
						||
	return &KnowledgeBase{
 | 
						||
		session: session,
 | 
						||
		url:     url,
 | 
						||
		apiKey:  apiKey,
 | 
						||
	}
 | 
						||
}
 | 
						||
 | 
						||
// 请求知识库聊天
 | 
						||
func (this *KnowledgeBase) Chat(ctx context.Context, query string) (text string, err error) {
 | 
						||
 | 
						||
	req := l_request.Request{
 | 
						||
		Method: "post",
 | 
						||
		Url:    this.url + "/api/v1/knowledge-chat/" + this.session,
 | 
						||
		Params: nil,
 | 
						||
		Headers: map[string]string{
 | 
						||
			"Content-Type": "application/json",
 | 
						||
			"X-API-Key":    this.apiKey,
 | 
						||
		},
 | 
						||
		Cookies: nil,
 | 
						||
		Data:    nil,
 | 
						||
		Json: map[string]interface{}{
 | 
						||
			"query": query,
 | 
						||
		},
 | 
						||
		Files:    nil,
 | 
						||
		Raw:      "",
 | 
						||
		JsonByte: nil,
 | 
						||
		Xml:      nil,
 | 
						||
	}
 | 
						||
 | 
						||
	rsp, err := req.SendNoParseResponse()
 | 
						||
	if err != nil {
 | 
						||
		return
 | 
						||
	}
 | 
						||
	defer rsp.Body.Close()
 | 
						||
 | 
						||
	err = connectAndReadSSE(rsp)
 | 
						||
	if err != nil {
 | 
						||
		return
 | 
						||
	}
 | 
						||
 | 
						||
	return
 | 
						||
}
 | 
						||
 | 
						||
// Message 表示解析后的 SSE 消息
 | 
						||
type Message struct {
 | 
						||
	Event string // 事件类型(默认 "message")
 | 
						||
	Data  string // 消息内容(可能多行)
 | 
						||
	ID    string // 消息 ID(可选)
 | 
						||
}
 | 
						||
 | 
						||
// 连接 SSE 并读取数据
 | 
						||
func connectAndReadSSE(resp *http.Response) error {
 | 
						||
 | 
						||
	// 验证响应状态和格式
 | 
						||
	if resp.StatusCode != http.StatusOK {
 | 
						||
		return fmt.Errorf("非 200 状态码: %d", resp.StatusCode)
 | 
						||
	}
 | 
						||
	contentType := resp.Header.Get("Content-Type")
 | 
						||
	if !strings.Contains(contentType, "text/event-stream") {
 | 
						||
		return fmt.Errorf("不支持的 Content-Type: %s", contentType)
 | 
						||
	}
 | 
						||
 | 
						||
	// 逐行读取响应流
 | 
						||
	scanner := bufio.NewScanner(resp.Body)
 | 
						||
	var currentMsg Message // 当前正在组装的消息
 | 
						||
 | 
						||
	for scanner.Scan() {
 | 
						||
		line := scanner.Text()
 | 
						||
		if line == "" {
 | 
						||
			// 空行表示一条消息结束,处理当前消息
 | 
						||
			if currentMsg.Data != "" || currentMsg.Event != "" || currentMsg.ID != "" {
 | 
						||
				printMessage(currentMsg)
 | 
						||
				currentMsg = Message{} // 重置消息
 | 
						||
			}
 | 
						||
			continue
 | 
						||
		}
 | 
						||
 | 
						||
		// 解析字段(格式:"field: value")
 | 
						||
		parts := strings.SplitN(line, ":", 2)
 | 
						||
		if len(parts) < 2 {
 | 
						||
			continue // 无效行(无冒号),跳过
 | 
						||
		}
 | 
						||
		field := strings.TrimSpace(parts[0])
 | 
						||
		value := strings.TrimSpace(parts[1])
 | 
						||
 | 
						||
		switch field {
 | 
						||
		case "event":
 | 
						||
			currentMsg.Event = value
 | 
						||
		case "data":
 | 
						||
			// data 可能多行,用换行符拼接(最后一条消息可能无结尾空行)
 | 
						||
			currentMsg.Data += value + ""
 | 
						||
			//case "id":
 | 
						||
			//	currentMsg.ID = value
 | 
						||
			// 可选:处理 "retry" 字段(服务器建议的重连时间,单位秒)
 | 
						||
		}
 | 
						||
	}
 | 
						||
 | 
						||
	// 检查扫描错误(如连接断开)
 | 
						||
	if err := scanner.Err(); err != nil {
 | 
						||
		return fmt.Errorf("读取流失败: %w", err)
 | 
						||
	}
 | 
						||
 | 
						||
	// 处理最后一条未结束的消息(无结尾空行)
 | 
						||
	if currentMsg.Data != "" || currentMsg.Event != "" || currentMsg.ID != "" {
 | 
						||
		printMessage(currentMsg)
 | 
						||
	}
 | 
						||
 | 
						||
	return nil
 | 
						||
}
 | 
						||
 | 
						||
type MegContent struct {
 | 
						||
	Id                  string      `json:"id"`            // 消息 ID
 | 
						||
	ResponseType        string      `json:"response_type"` // 响应类型,answer 或 references
 | 
						||
	Content             string      `json:"content"`       // 消息内容
 | 
						||
	Done                bool        `json:"done"`          // 是否完成
 | 
						||
	KnowledgeReferences interface{} `json:"knowledge_references"`
 | 
						||
}
 | 
						||
 | 
						||
// printMessage 打印解析后的 SSE 消息
 | 
						||
func printMessage(msg Message) {
 | 
						||
 | 
						||
	//fmt.Printf("--- 收到 SSE 消息 ---")
 | 
						||
	//fmt.Printf("事件类型: %s,", msg.Event)
 | 
						||
	//fmt.Printf("消息 ID: %s,", msg.ID)
 | 
						||
	//fmt.Printf("内容:%s,", strings.TrimSpace(msg.Data)) // 去除末尾多余换行
 | 
						||
 | 
						||
	var content MegContent
 | 
						||
	_ = json.Unmarshal([]byte(msg.Data), &content)
 | 
						||
	fmt.Println(msg.Data)
 | 
						||
 | 
						||
	//if content.ResponseType == "answer" {
 | 
						||
	//	//fmt.Printf("%s", content.Content)
 | 
						||
	//	fmt.Println(content)
 | 
						||
	//} else {
 | 
						||
	//	fmt.Printf("--- 收到 SSE 消息 ---")
 | 
						||
	//	fmt.Printf("事件类型: %s,", msg.Event)
 | 
						||
	//	fmt.Printf("消息 ID: %s,", msg.ID)
 | 
						||
	//	fmt.Printf("内容:%s,", strings.TrimSpace(msg.Data)) // 去除末尾多余换行
 | 
						||
	//}
 | 
						||
 | 
						||
}
 | 
						||
 | 
						||
// getRetryAfter 从响应头获取重连时间(示例,需根据实际响应头调整)
 | 
						||
func getRetryAfter(url string) int {
 | 
						||
	// 实际需重新请求并获取响应头(此处简化为固定值)
 | 
						||
	// 正确做法:在 connectAndReadSSE 中记录响应头的 Retry-After 字段
 | 
						||
	return 5 // 示例:等待 5 秒
 | 
						||
}
 |