133 lines
		
	
	
		
			3.3 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			133 lines
		
	
	
		
			3.3 KiB
		
	
	
	
		
			Go
		
	
	
	
package llm_service
 | 
						|
 | 
						|
import (
 | 
						|
	"ai_scheduler/internal/config"
 | 
						|
	"ai_scheduler/internal/data/model"
 | 
						|
	"ai_scheduler/internal/entitys"
 | 
						|
	"ai_scheduler/internal/pkg"
 | 
						|
	"ai_scheduler/internal/pkg/utils_ollama"
 | 
						|
	"context"
 | 
						|
	"encoding/json"
 | 
						|
	"errors"
 | 
						|
	"fmt"
 | 
						|
 | 
						|
	"github.com/gofiber/fiber/v2/log"
 | 
						|
	"github.com/ollama/ollama/api"
 | 
						|
)
 | 
						|
 | 
						|
type OllamaService struct {
 | 
						|
	client *utils_ollama.Client
 | 
						|
	config *config.Config
 | 
						|
}
 | 
						|
 | 
						|
func NewOllamaGenerate(
 | 
						|
	client *utils_ollama.Client,
 | 
						|
	config *config.Config,
 | 
						|
) *OllamaService {
 | 
						|
	return &OllamaService{
 | 
						|
		client: client,
 | 
						|
		config: config,
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func (r *OllamaService) IntentRecognize(ctx context.Context, requireData *entitys.RequireData) (msg string, err error) {
 | 
						|
	prompt, err := r.getPrompt(ctx, requireData)
 | 
						|
	if err != nil {
 | 
						|
		return
 | 
						|
	}
 | 
						|
	toolDefinitions := r.registerToolsOllama(requireData.Tasks)
 | 
						|
	match, err := r.client.ToolSelect(context.TODO(), prompt, toolDefinitions)
 | 
						|
	if err != nil {
 | 
						|
		return
 | 
						|
	}
 | 
						|
	log.Info("意图识别结果: %v", pkg.JsonStringIgonErr(match))
 | 
						|
	if len(match.Message.Content) == 0 {
 | 
						|
		if match.Message.ToolCalls != nil {
 | 
						|
			var matchFromTools = &entitys.Match{
 | 
						|
				Confidence: 1,
 | 
						|
				Index:      match.Message.ToolCalls[0].Function.Name,
 | 
						|
				Parameters: pkg.JsonStringIgonErr(match.Message.ToolCalls[0].Function.Arguments),
 | 
						|
				IsMatch:    true,
 | 
						|
			}
 | 
						|
			match.Message.Content = pkg.JsonStringIgonErr(matchFromTools)
 | 
						|
		} else {
 | 
						|
			err = errors.New("不太明白你想表达的意思呢,可以在仔细描述一下您所需要的内容吗,感谢感谢")
 | 
						|
			return
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	msg = match.Message.Content
 | 
						|
	return
 | 
						|
}
 | 
						|
 | 
						|
func (r *OllamaService) getPrompt(ctx context.Context, requireData *entitys.RequireData) ([]api.Message, error) {
 | 
						|
 | 
						|
	var (
 | 
						|
		prompt = make([]api.Message, 0)
 | 
						|
	)
 | 
						|
	prompt = append(prompt, api.Message{
 | 
						|
		Role:    "system",
 | 
						|
		Content: buildSystemPrompt(requireData.Sys.SysPrompt),
 | 
						|
	}, api.Message{
 | 
						|
		Role:    "assistant",
 | 
						|
		Content: fmt.Sprintf("聊天记录:%s", pkg.JsonStringIgonErr(buildAssistant(requireData.Histories))),
 | 
						|
	}, api.Message{
 | 
						|
		Role:    "user",
 | 
						|
		Content: requireData.UserInput,
 | 
						|
	})
 | 
						|
 | 
						|
	if len(requireData.ImgByte) > 0 {
 | 
						|
		_, err := r.RecognizeWithImg(ctx, requireData)
 | 
						|
		if err != nil {
 | 
						|
			return nil, err
 | 
						|
		}
 | 
						|
 | 
						|
	}
 | 
						|
	return prompt, nil
 | 
						|
}
 | 
						|
 | 
						|
func (r *OllamaService) RecognizeWithImg(ctx context.Context, requireData *entitys.RequireData) (desc api.GenerateResponse, err error) {
 | 
						|
	if requireData.ImgByte == nil {
 | 
						|
		return
 | 
						|
	}
 | 
						|
	requireData.Ch <- entitys.Response{
 | 
						|
		Index:   "",
 | 
						|
		Content: "图片识别中。。。",
 | 
						|
		Type:    entitys.ResponseLoading,
 | 
						|
	}
 | 
						|
	desc, err = r.client.Generation(ctx, &api.GenerateRequest{
 | 
						|
		Model:  r.config.Ollama.GenerateModel,
 | 
						|
		Stream: new(bool),
 | 
						|
		System: "识别图片内容",
 | 
						|
		Prompt: requireData.UserInput,
 | 
						|
		Images: requireData.ImgByte,
 | 
						|
	})
 | 
						|
	return
 | 
						|
}
 | 
						|
 | 
						|
func (r *OllamaService) registerToolsOllama(tasks []model.AiTask) []api.Tool {
 | 
						|
	taskPrompt := make([]api.Tool, 0)
 | 
						|
	for _, task := range tasks {
 | 
						|
		var taskConfig entitys.TaskConfigDetail
 | 
						|
		err := json.Unmarshal([]byte(task.Config), &taskConfig)
 | 
						|
		if err != nil {
 | 
						|
			continue
 | 
						|
		}
 | 
						|
 | 
						|
		taskPrompt = append(taskPrompt, api.Tool{
 | 
						|
			Type: "function",
 | 
						|
			Function: api.ToolFunction{
 | 
						|
				Name:        task.Index,
 | 
						|
				Description: task.Desc,
 | 
						|
				Parameters: api.ToolFunctionParameters{
 | 
						|
					Type:       taskConfig.Param.Type,
 | 
						|
					Required:   taskConfig.Param.Required,
 | 
						|
					Properties: taskConfig.Param.Properties,
 | 
						|
				},
 | 
						|
			},
 | 
						|
		})
 | 
						|
 | 
						|
	}
 | 
						|
	return taskPrompt
 | 
						|
}
 |