Merge remote-tracking branch 'origin/master'

This commit is contained in:
wuchao 2025-10-09 16:25:39 +08:00
commit 7ee5904752
12 changed files with 140 additions and 19 deletions

View File

@ -6,6 +6,7 @@ server:
ollama:
base_url: "http://127.0.0.1:11434"
model: "qwen3-coder:480b-cloud"
generate_model: "qwen3-coder:480b-cloud"
timeout: "120s"
level: "info"
format: "json"

View File

@ -1,6 +1,7 @@
package llm_service
import (
"ai_scheduler/internal/config"
"ai_scheduler/internal/data/model"
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg"
@ -16,18 +17,24 @@ import (
type OllamaService struct {
client *utils_ollama.Client
config *config.Config
}
// NewOllamaGenerate constructs an OllamaService wired with the shared
// Ollama client and the application configuration.
func NewOllamaGenerate(
	client *utils_ollama.Client,
	config *config.Config,
) *OllamaService {
	svc := &OllamaService{
		client: client,
		config: config,
	}
	return svc
}
func (r *OllamaService) IntentRecognize(ctx context.Context, requireData *entitys.RequireData) (msg string, err error) {
prompt := r.getPrompt(requireData.Sys, requireData.Histories, requireData.UserInput, requireData.Tasks)
prompt, err := r.getPrompt(ctx, requireData)
if err != nil {
return
}
toolDefinitions := r.registerToolsOllama(requireData.Tasks)
match, err := r.client.ToolSelect(context.TODO(), prompt, toolDefinitions)
if err != nil {
@ -53,21 +60,45 @@ func (r *OllamaService) IntentRecognize(ctx context.Context, requireData *entity
return
}
func (r *OllamaService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []api.Message {
// getPrompt assembles the message list sent to the model: system prompt,
// chat history as assistant context, optional image description, then the
// user's input. Returns an error if image recognition fails.
func (r *OllamaService) getPrompt(ctx context.Context, requireData *entitys.RequireData) ([]api.Message, error) {
	prompt := make([]api.Message, 0, 4)
	prompt = append(prompt, api.Message{
		Role:    "system",
		Content: buildSystemPrompt(requireData.Sys.SysPrompt),
	}, api.Message{
		Role:    "assistant",
		Content: fmt.Sprintf("聊天记录:%s", pkg.JsonStringIgonErr(buildAssistant(requireData.Histories))),
	})
	if len(requireData.ImgByte) > 0 {
		// Fix: the recognition result was previously discarded (`_, err :=`),
		// so the image description never reached the prompt. Capture it and
		// add it as assistant context before the user's message.
		desc, err := r.RecognizeWithImg(ctx, requireData)
		if err != nil {
			return nil, err
		}
		prompt = append(prompt, api.Message{
			Role:    "assistant",
			Content: fmt.Sprintf("图片内容:%s", desc.Response),
		})
	}
	prompt = append(prompt, api.Message{
		Role:    "user",
		Content: requireData.UserInput,
	})
	return prompt, nil
}
// RecognizeWithImg runs a non-streaming generate call against the configured
// GenerateModel to describe the image(s) carried in requireData.ImgByte.
// A loading notice is pushed to the client channel first.
func (r *OllamaService) RecognizeWithImg(ctx context.Context, requireData *entitys.RequireData) (desc api.GenerateResponse, err error) {
	// Tell the front-end we are busy recognizing the image.
	requireData.Ch <- entitys.Response{
		Index:   "",
		Content: "图片识别中。。。",
		Type:    entitys.ResponseLoading,
	}
	stream := false
	desc, err = r.client.Generation(ctx, &api.GenerateRequest{
		Model:  r.config.Ollama.GenerateModel,
		Stream: &stream,
		System: "识别图片内容",
		Prompt: requireData.UserInput,
		// Fix: the downloaded image bytes were never attached to the request,
		// so the model had nothing to recognize. Pass them via Images.
		Images: requireData.ImgByte,
	})
	return
}
func (r *OllamaService) registerToolsOllama(tasks []model.AiTask) []api.Tool {

View File

@ -80,6 +80,13 @@ func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRe
cancel()
}()
//获取图片信息
err = r.getImgData(req.Img, &requireData)
if err != nil {
log.Errorf("GetImgData error: %v", err)
return
}
//获取系统信息
err = r.getRequireData(req.Text, &requireData)
if err != nil {
@ -173,6 +180,32 @@ func sendWithTimeout(c *websocket.Conn, data entitys.Response, timeout time.Dura
}
}
// getImgData validates imgUrl, downloads it, verifies the response really is
// an image via its Content-Type header, and appends the raw bytes to
// requireData.ImgByte. An empty URL is a no-op.
func (r *AiRouterBiz) getImgData(imgUrl string, requireData *entitys.RequireData) (err error) {
	// No image supplied — nothing to fetch.
	if len(imgUrl) == 0 {
		return
	}
	if err = pkg.ValidateImageURL(imgUrl); err != nil {
		return err
	}
	req := l_request.Request{
		Method: "GET",
		Url:    imgUrl,
	}
	res, err := req.Send()
	if err != nil {
		return
	}
	// Single map lookup instead of three separate accesses.
	contentType, ok := res.Headers["Content-Type"]
	if !ok {
		return errors.ParamErr("图片格式错误:Content-Type未获取")
	}
	if !strings.HasPrefix(contentType, "image/") {
		return errors.ParamErr("expected image content, got %s", contentType)
	}
	requireData.ImgByte = append(requireData.ImgByte, res.Content)
	return
}
func (r *AiRouterBiz) recognize(ctx context.Context, requireData *entitys.RequireData) (err error) {
requireData.Ch <- entitys.Response{
Index: "",

View File

@ -39,9 +39,10 @@ type ServerConfig struct {
// OllamaConfig Ollama配置
type OllamaConfig struct {
BaseURL string `mapstructure:"base_url"`
Model string `mapstructure:"model"`
Timeout time.Duration `mapstructure:"timeout"`
BaseURL string `mapstructure:"base_url"`
Model string `mapstructure:"model"`
GenerateModel string `mapstructure:"generate_model"`
Timeout time.Duration `mapstructure:"timeout"`
}
type Redis struct {

View File

@ -46,6 +46,10 @@ func SysErr(message string, arg ...any) *BusinessErr {
return &BusinessErr{code: SystemError.code, message: fmt.Sprintf(message, arg)}
}
// ParamErr builds a parameter-error BusinessErr; message is a fmt format
// string applied to arg.
func ParamErr(message string, arg ...any) *BusinessErr {
	// Fix: arg must be spread with `...`; passing the slice as a single
	// operand would render the whole argument list as "[v1 v2]".
	return &BusinessErr{code: ParamError.code, message: fmt.Sprintf(message, arg...)}
}
// Wrap returns a new BusinessErr that keeps e's code but carries err's message.
func (e *BusinessErr) Wrap(err error) *BusinessErr {
	msg := err.Error()
	return NewBusinessErr(e.code, msg)
}

View File

@ -151,6 +151,7 @@ type RequireData struct {
Auth string
Ch chan Response
KnowledgeConf KnowledgeBaseRequest
ImgByte []api.ImageData
}
type KnowledgeBaseRequest struct {

View File

@ -3,6 +3,10 @@ package pkg
import (
"ai_scheduler/internal/entitys"
"encoding/json"
"errors"
"fmt"
"net/url"
"strings"
)
func JsonStringIgonErr(data interface{}) string {
@ -25,3 +29,29 @@ func IsChannelClosed(ch chan entitys.ResponseData) bool {
return false // channel 未关闭(但可能有数据未读取)
}
}
// ValidateImageURL 验证图片 URL 是否有效
func ValidateImageURL(rawURL string) error {
// 1. 基础格式验证
parsed, err := url.Parse(rawURL)
if err != nil {
return fmt.Errorf("invalid URL format: %v", err)
}
// 2. 检查协议是否为 http/https
if parsed.Scheme != "http" && parsed.Scheme != "https" {
return errors.New("URL must use http or https protocol")
}
// 3. 检查是否有空的主机名
if parsed.Host == "" {
return errors.New("URL missing host")
}
// 4. 检查路径是否为空(可选)
if strings.TrimSpace(parsed.Path) == "" {
return errors.New("URL path is empty")
}
return nil
}

View File

@ -59,10 +59,13 @@ func (c *Client) ToolSelect(ctx context.Context, messages []api.Message, tools [
return
}
func (c *Client) ChatStream(ctx context.Context, ch chan entitys.Response, messages []api.Message, index string) (err error) {
func (c *Client) ChatStream(ctx context.Context, ch chan entitys.Response, messages []api.Message, index string, model string) (err error) {
if len(model) == 0 {
model = c.config.Model
}
// 构建聊天请求
req := &api.ChatRequest{
Model: c.config.Model,
Model: model,
Messages: messages,
Stream: nil,
Think: &api.ThinkValue{Value: true},
@ -90,6 +93,14 @@ func (c *Client) ChatStream(ctx context.Context, ch chan entitys.Response, messa
return
}
// Generation executes a one-shot generate request and returns the final
// response delivered by the Ollama client callback.
func (c *Client) Generation(ctx context.Context, generateRequest *api.GenerateRequest) (result api.GenerateResponse, err error) {
	capture := func(resp api.GenerateResponse) error {
		result = resp
		return nil
	}
	err = c.client.Generate(ctx, generateRequest, capture)
	return
}
// convertResponse 转换响应格式
func (c *Client) convertResponse(resp *api.ChatResponse) *entitys.ChatResponse {
//result := &entitys.ChatResponse{

View File

@ -71,7 +71,7 @@ func NewManager(config *config.Config, llm *utils_ollama.Client) *Manager {
}
// 普通对话
chat := NewNormalChatTool(m.llm)
chat := NewNormalChatTool(m.llm, config)
m.tools[chat.Name()] = chat
return m

View File

@ -1,6 +1,7 @@
package tools
import (
"ai_scheduler/internal/config"
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg"
"ai_scheduler/internal/pkg/utils_ollama"
@ -13,12 +14,13 @@ import (
// NormalChatTool 普通对话
type NormalChatTool struct {
llm *utils_ollama.Client
llm *utils_ollama.Client
config *config.Config
}
// NewNormalChatTool 实例普通对话
func NewNormalChatTool(llm *utils_ollama.Client) *NormalChatTool {
return &NormalChatTool{llm: llm}
// NewNormalChatTool builds a NormalChatTool backed by the given LLM client
// and application configuration.
func NewNormalChatTool(llm *utils_ollama.Client, config *config.Config) *NormalChatTool {
	tool := &NormalChatTool{
		llm:    llm,
		config: config,
	}
	return tool
}
// Name 返回工具名称
@ -56,6 +58,12 @@ func (w *NormalChatTool) Execute(ctx context.Context, requireData *entitys.Requi
// getMockZltxOrderDetail 获取模拟直连天下订单详情数据
func (w *NormalChatTool) chat(requireData *entitys.RequireData, chat *NormalChat) (err error) {
//requireData.Ch <- entitys.Response{
// Index: w.Name(),
// Content: "<think></think>",
// Type: entitys.ResponseStream,
//}
err = w.llm.ChatStream(context.TODO(), requireData.Ch, []api.Message{
{
Role: "system",
@ -67,9 +75,9 @@ func (w *NormalChatTool) chat(requireData *entitys.RequireData, chat *NormalChat
},
{
Role: "user",
Content: requireData.UserInput,
Content: chat.ChatContent,
},
}, w.Name())
}, w.Name(), w.config.Ollama.GenerateModel)
if err != nil {
return fmt.Errorf("%s", err)
}

View File

@ -173,7 +173,7 @@ func (w *ZltxOrderDetailTool) getZltxOrderDetail(requireData *entitys.RequireDat
Role: "user",
Content: requireData.UserInput,
},
}, w.Name())
}, w.Name(), "")
if err != nil {
return fmt.Errorf("订单日志解析失败:%s", err)
}

View File

@ -3,6 +3,7 @@ package tools
import (
"ai_scheduler/internal/config"
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg"
"context"
"encoding/json"
"fmt"
@ -167,7 +168,7 @@ func (z ZltxProductTool) getZltxProduct(body *ZltxProductRequest, requireData *e
return fmt.Errorf("解析商品数据失败:%w", err)
}
if resp.Code != 200 {
return fmt.Errorf("商品查询失败:%s", resp.Error)
return fmt.Errorf("商品查询失败:%s", pkg.JsonStringIgonErr(resp))
}
if resp.Data.List == nil || len(resp.Data.List) == 0 {
var respData ZltxProductDataById