ai-courseware/eino-project/internal/ai/ai.go

330 lines
10 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package ai
import (
"context"
"fmt"
"strings"
"time"
contextpkg "eino-project/internal/context"
"eino-project/internal/monitor"
"eino-project/internal/vector"
"github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/schema"
"github.com/go-kratos/kratos/v2/log"
)
// AIService is the chat-facing AI service interface.
type AIService interface {
	// ProcessChat handles one chat message and returns the full reply text.
	ProcessChat(ctx context.Context, message string, sessionID string) (string, error)
	// StreamChat handles one chat message and streams the reply in chunks
	// through the returned channel.
	StreamChat(ctx context.Context, message string, sessionID string) (<-chan string, error)
	// AnalyzeIntent uses the dedicated intent model (qwen3:8b) for intent
	// recognition and returns exactly one of the fixed labels.
	AnalyzeIntent(ctx context.Context, message string) (string, error)
}
// OllamaRequest is the request body for the Ollama generate API.
// NOTE(review): not referenced by any code visible in this file — confirm it
// is still used elsewhere before removing.
type OllamaRequest struct {
	Model  string `json:"model"`  // target model name
	Prompt string `json:"prompt"` // prompt text sent to the model
	Stream bool   `json:"stream"` // whether a streaming response is requested
}
// OllamaResponse is the response body returned by the Ollama generate API.
// NOTE(review): not referenced by any code visible in this file — confirm it
// is still used elsewhere before removing.
type OllamaResponse struct {
	Model     string `json:"model"`      // model that produced the response
	Response  string `json:"response"`   // generated text
	Done      bool   `json:"done"`       // true when generation has finished
	CreatedAt string `json:"created_at"` // creation timestamp reported by Ollama
}
// aiService is the default AIService implementation backed by Eino chat models,
// with optional knowledge-base retrieval and session context tracking.
type aiService struct {
	logger            log.Logger
	chatModel         model.BaseChatModel      // main conversation model
	intentModel       model.BaseChatModel      // dedicated intent-classification model
	knowledgeSearcher vector.KnowledgeSearcher // optional; nil disables retrieval
	contextManager    contextpkg.ContextManager // optional; nil disables history
	// monitor receives lightweight LLM usage reports.
	// NOTE(review): never assigned by NewAIService (nor anywhere visible in
	// this file), so all usage reporting below is currently dead — confirm.
	monitor interface {
		RecordLLMUsage(ctx context.Context, usage *monitor.LLMUsage) error
	}
	// chatModelName is stamped onto usage reports.
	// NOTE(review): also never assigned; reports carry an empty model name.
	chatModelName string
}
// NewAIService wires up an AIService from its dependencies. knowledgeSearcher
// and contextManager may be nil; the service then skips those steps.
// NOTE(review): the monitor and chatModelName fields are left at their zero
// values here, so usage reporting in ProcessChat/StreamChat never fires.
func NewAIService(logger log.Logger, chatModel, intentModel model.BaseChatModel, knowledgeSearcher vector.KnowledgeSearcher, contextManager contextpkg.ContextManager) AIService {
	svc := &aiService{
		logger:            logger,
		chatModel:         chatModel,
		intentModel:       intentModel,
		knowledgeSearcher: knowledgeSearcher,
		contextManager:    contextManager,
	}
	return svc
}
// ProcessChat handles one chat turn: it records the user message, enriches the
// prompt with knowledge-base hits, calls the chat model, and (when a monitor is
// configured) reports usage. On model failure it degrades to a canned mock
// response instead of returning an error.
func (s *aiService) ProcessChat(ctx context.Context, message string, sessionID string) (string, error) {
	log.Context(ctx).Infof("Processing chat message: %s for session: %s", message, sessionID)

	// 1. Record the user message in the session context (best effort).
	if s.contextManager != nil {
		msg := contextpkg.Message{
			Role:      "user",
			Content:   message,
			Timestamp: time.Now(),
		}
		// NOTE(review): AddMessage's result (if any) is discarded; consider
		// logging persistence failures here.
		s.contextManager.AddMessage(ctx, sessionID, msg)
	}

	// 2. Retrieve related knowledge-base content (best effort).
	var knowledgeContext string
	if s.knowledgeSearcher != nil {
		knowledgeResults, err := s.knowledgeSearcher.SearchKnowledge(ctx, message, 3)
		if err != nil {
			// Fix: this error used to be swallowed silently; log it so retrieval
			// failures are visible while keeping the best-effort behavior.
			log.Context(ctx).Warnf("knowledge search failed: %v", err)
		} else if len(knowledgeResults) > 0 {
			contextParts := make([]string, 0, len(knowledgeResults))
			for _, result := range knowledgeResults {
				contextParts = append(contextParts, fmt.Sprintf("相关知识: %s", result.Document.Content))
			}
			knowledgeContext = strings.Join(contextParts, "\n")
		}
	}

	// 3. Build the final prompt, prepending knowledge context when available.
	enhancedMessage := message
	if knowledgeContext != "" {
		enhancedMessage = fmt.Sprintf("基于以下知识库内容回答用户问题:\n%s\n\n用户问题: %s", knowledgeContext, message)
	}
	messages := []*schema.Message{
		{
			Role:    schema.User,
			Content: enhancedMessage,
		},
	}

	// 4. Call the Eino chat model; fall back to a mock reply on any failure.
	start := time.Now()
	response, err := s.chatModel.Generate(ctx, messages)
	if err != nil {
		log.Context(ctx).Warnf("Eino chat model call failed: %v, falling back to mock response", err)
		return s.generateMockResponse(message), nil
	}
	if response == nil || response.Content == "" {
		log.Context(ctx).Warn("Empty response from Eino chat model, falling back to mock response")
		return s.generateMockResponse(message), nil
	}

	// Lightweight usage report: model, estimated tokens, latency, knowledge hits.
	if s.monitor != nil {
		// Eino does not expose token counts here; approximate instead
		// (ASCII: ~1 token/word, CJK: ~1 token per 2 runes — see estimateTokens).
		promptTokens := estimateTokens(enhancedMessage)
		completionTokens := estimateTokens(response.Content)
		// Classify intent with the dedicated intent model (no custom rules).
		detectedIntent := "general_inquiry"
		if intent, err := s.AnalyzeIntent(ctx, message); err == nil && intent != "" {
			detectedIntent = intent
		}
		usage := &monitor.LLMUsage{
			Model:            s.chatModelName,
			SessionID:        sessionID,
			UserID:           "default_user",
			PromptPreview:    preview(enhancedMessage, 200),
			PromptTokens:     promptTokens,
			CompletionTokens: completionTokens,
			TotalTokens:      promptTokens + completionTokens,
			LatencyMS:        time.Since(start).Milliseconds(),
			AgentThought:     "intent=" + detectedIntent,
			KnowledgeHits:    countLines(knowledgeContext),
			Metadata:         map[string]string{"source": "ai.ProcessChat"},
			Timestamp:        time.Now(),
		}
		// Best-effort reporting; a failed report must not fail the chat.
		_ = s.monitor.RecordLLMUsage(ctx, usage)
	}
	return response.Content, nil
}
// StreamChat handles one chat turn and streams the model reply through the
// returned channel, which is closed when the stream ends, errors, or ctx is
// cancelled. On model failure it degrades to a mock streamed response.
func (s *aiService) StreamChat(ctx context.Context, message string, sessionID string) (<-chan string, error) {
	log.Context(ctx).Infof("Processing stream chat message: %s for session: %s", message, sessionID)

	messages := []*schema.Message{
		{
			Role:    schema.User,
			Content: message,
		},
	}

	// Call the Eino streaming chat model; fall back to a mock stream on failure.
	start := time.Now()
	streamReader, err := s.chatModel.Stream(ctx, messages)
	if err != nil {
		log.Context(ctx).Warnf("Eino stream chat model call failed: %v, falling back to mock stream response", err)
		return s.generateMockStreamResponse(ctx, message), nil
	}

	responseChan := make(chan string, 10)

	// Pump chunks from the stream into the channel, then report usage once.
	go func() {
		defer close(responseChan)
		defer streamReader.Close()
		for {
			chunk, err := streamReader.Recv()
			if err != nil {
				// Fix: this was `return`, which made the usage-reporting block
				// below unreachable dead code; `break` lets normal stream end
				// (or a stream error) fall through to it.
				break
			}
			if chunk != nil && chunk.Content != "" {
				select {
				case responseChan <- chunk.Content:
				case <-ctx.Done():
					// Caller went away: skip reporting and exit.
					return
				}
			}
		}
		// Lightweight usage report after the stream completes. To avoid extra
		// buffering we do not accumulate chunks: only the input tokens are
		// estimated, plus the total latency.
		if s.monitor != nil {
			detectedIntent := "general_inquiry"
			if intent, err := s.AnalyzeIntent(ctx, message); err == nil && intent != "" {
				detectedIntent = intent
			}
			usage := &monitor.LLMUsage{
				Model:         s.chatModelName,
				SessionID:     sessionID,
				UserID:        "default_user",
				PromptPreview: preview(message, 200),
				PromptTokens:  estimateTokens(message),
				// Completion tokens cannot be counted accurately without
				// buffering the whole stream, so they are omitted.
				TotalTokens:  estimateTokens(message),
				LatencyMS:    time.Since(start).Milliseconds(),
				AgentThought: "intent=" + detectedIntent,
				Metadata:     map[string]string{"source": "ai.StreamChat"},
				Timestamp:    time.Now(),
			}
			_ = s.monitor.RecordLLMUsage(ctx, usage)
		}
	}()
	return responseChan, nil
}
// estimateTokens roughly estimates a token count for monitoring purposes only;
// it is not a real tokenizer, but is good enough to track usage trends.
//
// Heuristic: ASCII-only text counts one token per whitespace-separated word;
// text containing any non-ASCII rune (e.g. Chinese) counts one token per two runes.
func estimateTokens(text string) int {
	if text == "" {
		return 0
	}
	for _, r := range text {
		if r > 127 {
			// Non-ASCII present: approximate two CJK characters per token.
			return len([]rune(text)) / 2
		}
	}
	// Pure ASCII: word count approximates token count.
	return len(strings.Fields(text))
}
// preview returns at most max bytes of text for logging/report truncation,
// without splitting a multi-byte UTF-8 character.
//
// Fix: the original sliced at an arbitrary byte offset (text[:max]), which can
// cut a multi-byte rune (e.g. Chinese) in half and produce invalid UTF-8; we
// back off to the nearest rune boundary instead. Also guards non-positive max,
// which previously panicked for negative values.
func preview(text string, max int) string {
	if max <= 0 {
		return ""
	}
	if len(text) <= max {
		return text
	}
	cut := max
	// 0b10xxxxxx marks a UTF-8 continuation byte; step back until cut sits on
	// the first byte of a rune.
	for cut > 0 && text[cut]&0xC0 == 0x80 {
		cut--
	}
	return text[:cut]
}
// countLines reports how many newline-separated lines text contains; the
// empty string counts as zero lines.
func countLines(text string) int {
	if text == "" {
		return 0
	}
	// Equivalent to len(strings.Split(text, "\n")) without allocating.
	return strings.Count(text, "\n") + 1
}
// AnalyzeIntent classifies message with the dedicated intent model (qwen3:8b)
// and returns exactly one of: order_inquiry, product_inquiry,
// technical_support, general_inquiry. It falls back to "general_inquiry"
// whenever the model is unconfigured, fails, or replies with an unknown label.
func (s *aiService) AnalyzeIntent(ctx context.Context, message string) (string, error) {
	if s.intentModel == nil {
		return "general_inquiry", fmt.Errorf("intent model not configured")
	}
	instruction := "你是一个意图分类器。仅从以下标签中选择并输出一个且只输出标签本身order_inquiry、product_inquiry、technical_support、general_inquiry。用户消息" + message
	resp, err := s.intentModel.Generate(ctx, []*schema.Message{{Role: schema.User, Content: instruction}})
	if err != nil || resp == nil || resp.Content == "" {
		return "general_inquiry", err
	}
	switch intent := strings.ToLower(strings.TrimSpace(resp.Content)); intent {
	case "order_inquiry", "product_inquiry", "technical_support", "general_inquiry":
		return intent, nil
	default:
		// The model produced something outside the allowed label set.
		return "general_inquiry", nil
	}
}
// generateMockStreamResponse simulates a streaming reply by emitting the mock
// response word by word, re-inserting separating spaces and sleeping briefly
// between words. The returned channel is closed when emission finishes or ctx
// is cancelled.
func (s *aiService) generateMockStreamResponse(ctx context.Context, message string) <-chan string {
	out := make(chan string, 10)
	go func() {
		defer close(out)
		words := strings.Fields(s.generateMockResponse(message))
		last := len(words) - 1
		for i, word := range words {
			select {
			case out <- word:
			case <-ctx.Done():
				return
			}
			// strings.Fields stripped the whitespace; put a single space back
			// between words (but not after the final one).
			if i < last {
				select {
				case out <- " ":
				case <-ctx.Done():
					return
				}
			}
			// Simulate network latency between chunks.
			time.Sleep(50 * time.Millisecond)
		}
	}()
	return out
}
// generateMockResponse produces a canned reply used when the real chat model
// is unavailable, chosen by simple keyword matching on the lowercased,
// trimmed message.
func (s *aiService) generateMockResponse(message string) string {
	normalized := strings.ToLower(strings.TrimSpace(message))
	contains := func(keyword string) bool {
		return strings.Contains(normalized, keyword)
	}
	switch {
	case contains("你好"), contains("hello"):
		return "你好!有什么我可以帮你的吗?😊"
	case contains("天气"):
		return "抱歉,我无法获取实时天气信息。建议您查看天气预报应用。"
	case contains("时间"):
		return fmt.Sprintf("当前时间是 %s", time.Now().Format("2006-01-02 15:04:05"))
	case contains("介绍"):
		return "我是一个AI助手可以帮助您回答问题、提供信息和进行对话。有什么我可以为您做的吗"
	default:
		return "感谢您的消息!我正在学习中,暂时无法完全理解您的问题。请尝试用不同的方式表达,或者问一些简单的问题。"
	}
}