ai_scheduler/internal/biz/do/prompt.go

71 lines
1.9 KiB
Go

package do
import (
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg"
"context"
"strings"
"github.com/ollama/ollama/api"
)
type PromptOption interface {
CreatePrompt(ctx context.Context, rec *entitys.Recognize) (mes []api.Message, err error)
}
type WithSys struct {
}
func (f *WithSys) CreatePrompt(ctx context.Context, rec *entitys.Recognize) (mes []api.Message, err error) {
var (
prompt = make([]api.Message, 0) // 初始化一个空的api.Message切片
)
// 获取用户内容,如果出错则直接返回错误
content, err := f.getUserContent(ctx, rec)
if err != nil {
return nil, err
}
// 构建提示消息列表,包含系统提示、助手回复和用户内容
mes = append(prompt, api.Message{
Role: "system", // 系统角色
Content: rec.SystemPrompt, // 系统提示内容
}, api.Message{
Role: "assistant", // 助手角色
Content: "### 聊天记录:" + pkg.JsonStringIgonErr(rec.ChatHis), // 助手回复内容
}, api.Message{
Role: "user", // 用户角色
Content: content.String(), // 用户输入内容
})
return
}
func (f *WithSys) getUserContent(ctx context.Context, rec *entitys.Recognize) (content strings.Builder, err error) {
var hasFile bool
if len(rec.UserContent.FileUrl) > 0 || rec.UserContent.File != nil {
hasFile = true
}
content.WriteString(rec.UserContent.Text)
if hasFile {
content.WriteString("\n")
}
if len(rec.UserContent.Tag) > 0 {
content.WriteString("\n")
content.WriteString("### 工具必须使用:")
content.WriteString(rec.UserContent.Tag)
}
if len(rec.ChatHis.Messages) > 0 {
content.WriteString("### 引用历史聊天记录:\n")
content.WriteString(pkg.JsonStringIgonErr(rec.ChatHis))
}
if hasFile {
content.WriteString("\n")
content.WriteString("### 文件内容:\n")
hand.WriteString(rec.UserContent.FileUrl, rec.UserContent.FileUrl)
}
return
}