package llm_service
import (
	"context"
	"encoding/json"
	"errors"

	"github.com/tmc/langchaingo/llms"

	"ai_scheduler/internal/data/model"
	"ai_scheduler/internal/entitys"
	"ai_scheduler/internal/pkg"
	"ai_scheduler/internal/pkg/utils_langchain"
)
type LangChainService struct {
 | 
						|
	client *utils_langchain.UtilLangChain
 | 
						|
}
 | 
						|
 | 
						|
func NewLangChainGenerate(
 | 
						|
	client *utils_langchain.UtilLangChain,
 | 
						|
) *LangChainService {
 | 
						|
 | 
						|
	return &LangChainService{
 | 
						|
		client: client,
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func (r *LangChainService) IntentRecognize(ctx context.Context, sysInfo model.AiSy, history []model.AiChatHi, userInput string, tasks []model.AiTask) (msg string, err error) {
 | 
						|
	prompt := r.getPrompt(sysInfo, history, userInput, tasks)
 | 
						|
	AgentClient := r.client.Get()
 | 
						|
	defer r.client.Put(AgentClient)
 | 
						|
	match, err := AgentClient.Llm.GenerateContent(
 | 
						|
		ctx, // 使用可取消的上下文
 | 
						|
		prompt,
 | 
						|
		llms.WithJSONMode(),
 | 
						|
	)
 | 
						|
	msg = match.Choices[0].Content
 | 
						|
	return
 | 
						|
}
 | 
						|
 | 
						|
func (r *LangChainService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []llms.MessageContent {
 | 
						|
	var (
 | 
						|
		prompt = make([]llms.MessageContent, 0)
 | 
						|
	)
 | 
						|
	prompt = append(prompt, llms.MessageContent{
 | 
						|
		Role: llms.ChatMessageTypeSystem,
 | 
						|
		Parts: []llms.ContentPart{
 | 
						|
			llms.TextPart(buildSystemPrompt(sysInfo.SysPrompt)),
 | 
						|
		},
 | 
						|
	}, llms.MessageContent{
 | 
						|
		Role: llms.ChatMessageTypeTool,
 | 
						|
		Parts: []llms.ContentPart{
 | 
						|
			llms.TextPart(pkg.JsonStringIgonErr(buildAssistant(history))),
 | 
						|
		},
 | 
						|
	}, llms.MessageContent{
 | 
						|
		Role: llms.ChatMessageTypeTool,
 | 
						|
		Parts: []llms.ContentPart{
 | 
						|
			llms.TextPart(pkg.JsonStringIgonErr(r.registerTools(tasks))),
 | 
						|
		},
 | 
						|
	}, llms.MessageContent{
 | 
						|
		Role: llms.ChatMessageTypeHuman,
 | 
						|
		Parts: []llms.ContentPart{
 | 
						|
			llms.TextPart(reqInput),
 | 
						|
		},
 | 
						|
	})
 | 
						|
	return prompt
 | 
						|
}
 | 
						|
 | 
						|
func (r *LangChainService) registerTools(tasks []model.AiTask) []llms.Tool {
 | 
						|
	taskPrompt := make([]llms.Tool, 0)
 | 
						|
	for _, task := range tasks {
 | 
						|
		var taskConfig entitys.TaskConfig
 | 
						|
		err := json.Unmarshal([]byte(task.Config), &taskConfig)
 | 
						|
		if err != nil {
 | 
						|
			continue
 | 
						|
		}
 | 
						|
		taskPrompt = append(taskPrompt, llms.Tool{
 | 
						|
			Type: "function",
 | 
						|
			Function: &llms.FunctionDefinition{
 | 
						|
				Name:        task.Index,
 | 
						|
				Description: task.Desc,
 | 
						|
				Parameters:  taskConfig.Param,
 | 
						|
			},
 | 
						|
		})
 | 
						|
 | 
						|
	}
 | 
						|
	return taskPrompt
 | 
						|
}
 |