package ai
|
||
|
||
import (
	"context"
	"fmt"
	"strings"
	"time"
	"unicode/utf8"

	contextpkg "eino-project/internal/context"
	"eino-project/internal/monitor"
	"eino-project/internal/vector"

	"github.com/cloudwego/eino/components/model"
	"github.com/cloudwego/eino/schema"
	"github.com/go-kratos/kratos/v2/log"
)
|
||
|
||
// AIService AI服务接口
|
||
type AIService interface {
|
||
ProcessChat(ctx context.Context, message string, sessionID string) (string, error)
|
||
StreamChat(ctx context.Context, message string, sessionID string) (<-chan string, error)
|
||
// AnalyzeIntent 使用意图模型(qwen3:8b)做意图识别,仅返回固定标签之一
|
||
AnalyzeIntent(ctx context.Context, message string) (string, error)
|
||
}
|
||
|
||
// OllamaRequest Ollama API 请求结构
|
||
type OllamaRequest struct {
|
||
Model string `json:"model"`
|
||
Prompt string `json:"prompt"`
|
||
Stream bool `json:"stream"`
|
||
}
|
||
|
||
// OllamaResponse Ollama API 响应结构
|
||
type OllamaResponse struct {
|
||
Model string `json:"model"`
|
||
Response string `json:"response"`
|
||
Done bool `json:"done"`
|
||
CreatedAt string `json:"created_at"`
|
||
}
|
||
|
||
// aiService AI服务实现
|
||
type aiService struct {
|
||
logger log.Logger
|
||
chatModel model.BaseChatModel
|
||
intentModel model.BaseChatModel
|
||
knowledgeSearcher vector.KnowledgeSearcher
|
||
contextManager contextpkg.ContextManager
|
||
monitor interface {
|
||
RecordLLMUsage(ctx context.Context, usage *monitor.LLMUsage) error
|
||
}
|
||
chatModelName string
|
||
}
|
||
|
||
// NewAIService 创建AI服务实例
|
||
func NewAIService(logger log.Logger, chatModel, intentModel model.BaseChatModel, knowledgeSearcher vector.KnowledgeSearcher, contextManager contextpkg.ContextManager) AIService {
|
||
return &aiService{
|
||
logger: logger,
|
||
chatModel: chatModel,
|
||
intentModel: intentModel,
|
||
knowledgeSearcher: knowledgeSearcher,
|
||
contextManager: contextManager,
|
||
}
|
||
}
|
||
|
||
// ProcessChat 处理聊天消息
|
||
func (s *aiService) ProcessChat(ctx context.Context, message string, sessionID string) (string, error) {
|
||
log.Context(ctx).Infof("Processing chat message: %s for session: %s", message, sessionID)
|
||
|
||
// 1. 添加消息到上下文管理器
|
||
if s.contextManager != nil {
|
||
msg := contextpkg.Message{
|
||
Role: "user",
|
||
Content: message,
|
||
Timestamp: time.Now(),
|
||
}
|
||
s.contextManager.AddMessage(ctx, sessionID, msg)
|
||
}
|
||
|
||
// 2. 搜索相关知识库内容
|
||
var knowledgeContext string
|
||
if s.knowledgeSearcher != nil {
|
||
knowledgeResults, err := s.knowledgeSearcher.SearchKnowledge(ctx, message, 3)
|
||
if err == nil && len(knowledgeResults) > 0 {
|
||
var contextParts []string
|
||
for _, result := range knowledgeResults {
|
||
contextParts = append(contextParts, fmt.Sprintf("相关知识: %s", result.Document.Content))
|
||
}
|
||
knowledgeContext = strings.Join(contextParts, "\n")
|
||
}
|
||
}
|
||
|
||
// 3. 构建增强的聊天消息
|
||
enhancedMessage := message
|
||
if knowledgeContext != "" {
|
||
enhancedMessage = fmt.Sprintf("基于以下知识库内容回答用户问题:\n%s\n\n用户问题: %s", knowledgeContext, message)
|
||
}
|
||
|
||
messages := []*schema.Message{
|
||
{
|
||
Role: schema.User,
|
||
Content: enhancedMessage,
|
||
},
|
||
}
|
||
|
||
// 4. 调用 Eino 聊天模型
|
||
start := time.Now()
|
||
response, err := s.chatModel.Generate(ctx, messages)
|
||
if err != nil {
|
||
log.Context(ctx).Warnf("Eino chat model call failed: %v, falling back to mock response", err)
|
||
// 如果 Eino 调用失败,返回模拟响应
|
||
return s.generateMockResponse(message), nil
|
||
}
|
||
|
||
if response == nil || response.Content == "" {
|
||
log.Context(ctx).Warn("Empty response from Eino chat model, falling back to mock response")
|
||
return s.generateMockResponse(message), nil
|
||
}
|
||
|
||
// 轻量化上报LLM使用:模型、token估算、延迟、知识命中等
|
||
if s.monitor != nil {
|
||
// eino 不直接暴露 tokens,这里用近似估算:中文字数/2,英文按空格词数*1
|
||
promptTokens := estimateTokens(enhancedMessage)
|
||
completionTokens := estimateTokens(response.Content)
|
||
// 使用专用意图模型识别意图(不使用自定义规则)
|
||
detectedIntent := "general_inquiry"
|
||
if intent, err := s.AnalyzeIntent(ctx, message); err == nil && intent != "" {
|
||
detectedIntent = intent
|
||
}
|
||
usage := &monitor.LLMUsage{
|
||
Model: s.chatModelName,
|
||
SessionID: sessionID,
|
||
UserID: "default_user",
|
||
PromptPreview: preview(enhancedMessage, 200),
|
||
PromptTokens: promptTokens,
|
||
CompletionTokens: completionTokens,
|
||
TotalTokens: promptTokens + completionTokens,
|
||
LatencyMS: time.Since(start).Milliseconds(),
|
||
AgentThought: "intent=" + detectedIntent,
|
||
KnowledgeHits: countLines(knowledgeContext),
|
||
Metadata: map[string]string{"source": "ai.ProcessChat"},
|
||
Timestamp: time.Now(),
|
||
}
|
||
_ = s.monitor.RecordLLMUsage(ctx, usage)
|
||
}
|
||
|
||
return response.Content, nil
|
||
}
|
||
|
||
// StreamChat 流式处理聊天消息
|
||
func (s *aiService) StreamChat(ctx context.Context, message string, sessionID string) (<-chan string, error) {
|
||
log.Context(ctx).Infof("Processing stream chat message: %s for session: %s", message, sessionID)
|
||
|
||
// 构建聊天消息
|
||
messages := []*schema.Message{
|
||
{
|
||
Role: schema.User,
|
||
Content: message,
|
||
},
|
||
}
|
||
|
||
// 调用 Eino 流式聊天模型
|
||
start := time.Now()
|
||
streamReader, err := s.chatModel.Stream(ctx, messages)
|
||
if err != nil {
|
||
log.Context(ctx).Warnf("Eino stream chat model call failed: %v, falling back to mock stream response", err)
|
||
// 如果 Eino 流式调用失败,返回模拟流式响应
|
||
return s.generateMockStreamResponse(ctx, message), nil
|
||
}
|
||
|
||
// 创建响应通道
|
||
responseChan := make(chan string, 10)
|
||
|
||
// 启动 goroutine 处理流式响应
|
||
go func() {
|
||
defer close(responseChan)
|
||
defer streamReader.Close()
|
||
|
||
for {
|
||
chunk, err := streamReader.Recv()
|
||
if err != nil {
|
||
// 流结束或出错
|
||
return
|
||
}
|
||
if chunk != nil && chunk.Content != "" {
|
||
select {
|
||
case responseChan <- chunk.Content:
|
||
case <-ctx.Done():
|
||
return
|
||
}
|
||
}
|
||
}
|
||
|
||
// 汇总轻量化上报:以流结束后估算总tokens与延迟
|
||
if s.monitor != nil {
|
||
// 注意:为了避免阻塞,这里不去累计所有chunk;仅使用输入估算 + 总延迟
|
||
detectedIntent := "general_inquiry"
|
||
if intent, err := s.AnalyzeIntent(ctx, message); err == nil && intent != "" {
|
||
detectedIntent = intent
|
||
}
|
||
usage := &monitor.LLMUsage{
|
||
Model: s.chatModelName,
|
||
SessionID: sessionID,
|
||
UserID: "default_user",
|
||
PromptPreview: preview(message, 200),
|
||
PromptTokens: estimateTokens(message),
|
||
// completionTokens 无法准确统计,留空或基于时间估算可选
|
||
TotalTokens: estimateTokens(message),
|
||
LatencyMS: time.Since(start).Milliseconds(),
|
||
AgentThought: "intent=" + detectedIntent,
|
||
Metadata: map[string]string{"source": "ai.StreamChat"},
|
||
Timestamp: time.Now(),
|
||
}
|
||
_ = s.monitor.RecordLLMUsage(ctx, usage)
|
||
}
|
||
}()
|
||
|
||
return responseChan, nil
|
||
}
|
||
|
||
// estimateTokens 简易估算 tokens 数(不准确,但用于监控趋势足够)
|
||
func estimateTokens(text string) int {
|
||
if text == "" {
|
||
return 0
|
||
}
|
||
// 粗略策略:英文按空格分词;中文按 rune 数/2
|
||
ascii := true
|
||
for _, r := range text {
|
||
if r > 127 {
|
||
ascii = false
|
||
break
|
||
}
|
||
}
|
||
if ascii {
|
||
// 英文:词数近似为token数
|
||
return len(strings.Fields(text))
|
||
}
|
||
// 中文:每2个汉字约1个token
|
||
return len([]rune(text)) / 2
|
||
}
|
||
|
||
func preview(text string, max int) string {
|
||
if len(text) <= max {
|
||
return text
|
||
}
|
||
return text[:max]
|
||
}
|
||
|
||
func countLines(text string) int {
|
||
if text == "" {
|
||
return 0
|
||
}
|
||
return len(strings.Split(text, "\n"))
|
||
}
|
||
|
||
// AnalyzeIntent 使用意图模型(qwen3:8b)进行意图识别,输出以下标签之一:
|
||
// order_inquiry、product_inquiry、technical_support、general_inquiry
|
||
func (s *aiService) AnalyzeIntent(ctx context.Context, message string) (string, error) {
|
||
if s.intentModel == nil {
|
||
return "general_inquiry", fmt.Errorf("intent model not configured")
|
||
}
|
||
instruction := "你是一个意图分类器。仅从以下标签中选择并输出一个,且只输出标签本身:order_inquiry、product_inquiry、technical_support、general_inquiry。用户消息:" + message
|
||
messages := []*schema.Message{
|
||
{Role: schema.User, Content: instruction},
|
||
}
|
||
resp, err := s.intentModel.Generate(ctx, messages)
|
||
if err != nil || resp == nil || resp.Content == "" {
|
||
return "general_inquiry", err
|
||
}
|
||
intent := strings.TrimSpace(strings.ToLower(resp.Content))
|
||
switch intent {
|
||
case "order_inquiry", "product_inquiry", "technical_support", "general_inquiry":
|
||
return intent, nil
|
||
}
|
||
// 输出不在预期集合,回退为general
|
||
return "general_inquiry", nil
|
||
}
|
||
|
||
// generateMockStreamResponse 生成模拟流式响应
|
||
func (s *aiService) generateMockStreamResponse(ctx context.Context, message string) <-chan string {
|
||
responseChan := make(chan string, 10)
|
||
|
||
go func() {
|
||
defer close(responseChan)
|
||
|
||
// 获取模拟响应
|
||
mockResponse := s.generateMockResponse(message)
|
||
|
||
// 将响应分割成小块,模拟流式输出
|
||
words := strings.Fields(mockResponse)
|
||
for i, word := range words {
|
||
select {
|
||
case responseChan <- word:
|
||
if i < len(words)-1 {
|
||
// 在单词之间添加空格,除了最后一个单词
|
||
select {
|
||
case responseChan <- " ":
|
||
case <-ctx.Done():
|
||
return
|
||
}
|
||
}
|
||
// 模拟网络延迟
|
||
time.Sleep(50 * time.Millisecond)
|
||
case <-ctx.Done():
|
||
return
|
||
}
|
||
}
|
||
}()
|
||
|
||
return responseChan
|
||
}
|
||
|
||
// generateMockResponse 生成模拟响应
|
||
func (s *aiService) generateMockResponse(message string) string {
|
||
// 简单的模拟响应逻辑
|
||
message = strings.ToLower(strings.TrimSpace(message))
|
||
|
||
switch {
|
||
case strings.Contains(message, "你好") || strings.Contains(message, "hello"):
|
||
return "你好!有什么我可以帮你的吗?😊"
|
||
case strings.Contains(message, "天气"):
|
||
return "抱歉,我无法获取实时天气信息。建议您查看天气预报应用。"
|
||
case strings.Contains(message, "时间"):
|
||
return fmt.Sprintf("当前时间是 %s", time.Now().Format("2006-01-02 15:04:05"))
|
||
case strings.Contains(message, "介绍"):
|
||
return "我是一个AI助手,可以帮助您回答问题、提供信息和进行对话。有什么我可以为您做的吗?"
|
||
default:
|
||
return "感谢您的消息!我正在学习中,暂时无法完全理解您的问题。请尝试用不同的方式表达,或者问一些简单的问题。"
|
||
}
|
||
}
|