154 lines
3.7 KiB
Go
154 lines
3.7 KiB
Go
package llm_service
|
|
|
|
import (
	"context"
	"encoding/base64"
	"errors"
	"fmt"
	"net/http"
	"strings"

	"ai_scheduler/internal/config"
	"ai_scheduler/internal/entitys"
	"ai_scheduler/internal/pkg"
	"ai_scheduler/internal/pkg/utils_vllm"

	"github.com/cloudwego/eino/schema"
	"github.com/ollama/ollama/api"
)
|
|
|
|
type VllmService struct {
|
|
client *utils_vllm.Client
|
|
config *config.Config
|
|
}
|
|
|
|
func NewVllmService(
|
|
client *utils_vllm.Client,
|
|
config *config.Config,
|
|
) *VllmService {
|
|
return &VllmService{
|
|
client: client,
|
|
config: config,
|
|
}
|
|
}
|
|
|
|
func (s *VllmService) IntentRecognize(ctx context.Context, req *entitys.ToolSelect) (msg string, err error) {
|
|
msgs := s.convertMessages(req.Prompt)
|
|
tools := s.convertTools(req.Tools)
|
|
|
|
resp, err := s.client.ToolSelect(ctx, msgs, tools)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
if resp.Content == "" {
|
|
if len(resp.ToolCalls) > 0 {
|
|
call := resp.ToolCalls[0]
|
|
var matchFromTools = &entitys.Match{
|
|
Confidence: 1,
|
|
Index: call.Function.Name,
|
|
Parameters: call.Function.Arguments,
|
|
IsMatch: true,
|
|
}
|
|
msg = pkg.JsonStringIgonErr(matchFromTools)
|
|
} else {
|
|
err = errors.New("不太明白你想表达的意思呢,可以在仔细描述一下您所需要的内容吗,感谢感谢")
|
|
return
|
|
}
|
|
} else {
|
|
msg = resp.Content
|
|
}
|
|
return
|
|
}
|
|
|
|
func (s *VllmService) convertMessages(prompts []api.Message) []*schema.Message {
|
|
msgs := make([]*schema.Message, 0, len(prompts))
|
|
for _, p := range prompts {
|
|
msg := &schema.Message{
|
|
Role: schema.RoleType(p.Role),
|
|
Content: p.Content,
|
|
}
|
|
|
|
// 这里实际应该不会走进来
|
|
if len(p.Images) > 0 {
|
|
parts := []schema.MessageInputPart{
|
|
{Type: schema.ChatMessagePartTypeText, Text: p.Content},
|
|
}
|
|
for _, imgData := range p.Images {
|
|
b64 := base64.StdEncoding.EncodeToString(imgData)
|
|
mimeType := "image/jpeg"
|
|
parts = append(parts, schema.MessageInputPart{
|
|
Type: schema.ChatMessagePartTypeImageURL,
|
|
Image: &schema.MessageInputImage{
|
|
MessagePartCommon: schema.MessagePartCommon{
|
|
MIMEType: mimeType,
|
|
Base64Data: &b64,
|
|
},
|
|
},
|
|
})
|
|
}
|
|
msg.UserInputMultiContent = parts
|
|
}
|
|
msgs = append(msgs, msg)
|
|
}
|
|
return msgs
|
|
}
|
|
|
|
func (s *VllmService) convertTools(tasks []entitys.RegistrationTask) []*schema.ToolInfo {
|
|
tools := make([]*schema.ToolInfo, 0, len(tasks))
|
|
for _, task := range tasks {
|
|
params := make(map[string]*schema.ParameterInfo)
|
|
for k, v := range task.TaskConfigDetail.Param.Properties {
|
|
dt := schema.String
|
|
|
|
// Handle v.Type dynamically to support both string and []string (compiler suggests []string)
|
|
// Using fmt.Sprint handles both cases safely without knowing exact type structure
|
|
typeStr := fmt.Sprintf("%v", v.Type)
|
|
typeStr = strings.Trim(typeStr, "[]") // normalize "[string]" -> "string"
|
|
|
|
switch typeStr {
|
|
case "string":
|
|
dt = schema.String
|
|
case "integer", "int":
|
|
dt = schema.Integer
|
|
case "number", "float":
|
|
dt = schema.Number
|
|
case "boolean", "bool":
|
|
dt = schema.Boolean
|
|
case "object":
|
|
dt = schema.Object
|
|
case "array":
|
|
dt = schema.Array
|
|
}
|
|
|
|
required := false
|
|
for _, r := range task.TaskConfigDetail.Param.Required {
|
|
if r == k {
|
|
required = true
|
|
break
|
|
}
|
|
}
|
|
|
|
desc := v.Description
|
|
if len(v.Enum) > 0 {
|
|
var enumStrs []string
|
|
for _, e := range v.Enum {
|
|
enumStrs = append(enumStrs, fmt.Sprintf("%v", e))
|
|
}
|
|
desc += " Enum: " + strings.Join(enumStrs, ", ")
|
|
}
|
|
|
|
params[k] = &schema.ParameterInfo{
|
|
Type: dt,
|
|
Desc: desc,
|
|
Required: required,
|
|
}
|
|
}
|
|
|
|
tools = append(tools, &schema.ToolInfo{
|
|
Name: task.Name,
|
|
Desc: task.Desc,
|
|
ParamsOneOf: schema.NewParamsOneOfByParams(params),
|
|
})
|
|
}
|
|
return tools
|
|
}
|