Compare commits
No commits in common. "master" and "feature/rzy/report_combine" have entirely different histories.
master
...
feature/rz
|
|
@ -10,10 +10,9 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
configPath := flag.String("config", "./config/config.yaml", "Path to configuration file")
|
configPath := flag.String("config", "./config/config_test.yaml", "Path to configuration file")
|
||||||
onBot := flag.String("bot", "", "bot start")
|
onBot := flag.String("bot", "", "bot start")
|
||||||
cron := flag.String("cron", "", "close")
|
cron := flag.String("cron", "", "close")
|
||||||
runJob := flag.String("runJob", "", "run single job and exit")
|
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
bc, err := config.LoadConfig(*configPath)
|
bc, err := config.LoadConfig(*configPath)
|
||||||
|
|
@ -34,11 +33,6 @@ func main() {
|
||||||
if *cron == "start" {
|
if *cron == "start" {
|
||||||
app.Cron.Run(ctx)
|
app.Cron.Run(ctx)
|
||||||
}
|
}
|
||||||
// 运行指定任务并退出
|
|
||||||
if *runJob != "" {
|
|
||||||
app.Cron.RunOnce(ctx, *runJob)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Fatal(app.HttpServer.Listen(fmt.Sprintf(":%d", bc.Server.Port)))
|
log.Fatal(app.HttpServer.Listen(fmt.Sprintf(":%d", bc.Server.Port)))
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -4,29 +4,23 @@ server:
|
||||||
host: "0.0.0.0"
|
host: "0.0.0.0"
|
||||||
|
|
||||||
ollama:
|
ollama:
|
||||||
base_url: "http://192.168.6.115:11434"
|
base_url: "http://172.17.0.1:11434"
|
||||||
model: "qwen3:8b"
|
# model: "qwen3:8b"
|
||||||
generate_model: "qwen3:8b"
|
# generate_model: "qwen3:8b"
|
||||||
mapping_model: "qwen3:8b"
|
# mapping_model: "qwen3:8b"
|
||||||
# model: "qwen3-coder:480b-cloud"
|
model: "qwen3-coder:480b-cloud"
|
||||||
# generate_model: "qwen3-coder:480b-cloud"
|
generate_model: "qwen3-coder:480b-cloud"
|
||||||
# mapping_model: "deepseek-v3.2:cloud"
|
mapping_model: "deepseek-v3.2:cloud"
|
||||||
vl_model: "qwen2.5vl:3b"
|
vl_model: "qwen2.5vl:3b"
|
||||||
timeout: "120s"
|
timeout: "120s"
|
||||||
level: "info"
|
level: "info"
|
||||||
format: "json"
|
format: "json"
|
||||||
|
|
||||||
vllm:
|
vllm:
|
||||||
vl_model:
|
base_url: "http://172.17.0.1:8001/v1"
|
||||||
base_url: "http://192.168.6.115:8001/v1"
|
vl_model: "qwen2.5-vl-3b-awq"
|
||||||
model: "qwen2.5-vl-3b-awq"
|
timeout: "120s"
|
||||||
timeout: "120s"
|
level: "info"
|
||||||
level: "info"
|
|
||||||
text_model:
|
|
||||||
base_url: "http://192.168.6.115:8002/v1"
|
|
||||||
model: "qwen3-8b-fp8"
|
|
||||||
timeout: "120s"
|
|
||||||
level: "info"
|
|
||||||
|
|
||||||
coze:
|
coze:
|
||||||
base_url: "https://api.coze.cn"
|
base_url: "https://api.coze.cn"
|
||||||
|
|
|
||||||
|
|
@ -14,16 +14,10 @@ ollama:
|
||||||
format: "json"
|
format: "json"
|
||||||
|
|
||||||
vllm:
|
vllm:
|
||||||
vl_model:
|
base_url: "http://117.175.169.61:16001/v1"
|
||||||
base_url: "http://192.168.6.115:8001/v1"
|
vl_model: "qwen2.5-vl-3b-awq"
|
||||||
model: "qwen2.5-vl-3b-awq"
|
timeout: "120s"
|
||||||
timeout: "120s"
|
level: "info"
|
||||||
level: "info"
|
|
||||||
text_model:
|
|
||||||
base_url: "http://192.168.6.115:8002/v1"
|
|
||||||
model: "qwen3-8b-fp8"
|
|
||||||
timeout: "120s"
|
|
||||||
level: "info"
|
|
||||||
|
|
||||||
coze:
|
coze:
|
||||||
base_url: "https://api.coze.cn"
|
base_url: "https://api.coze.cn"
|
||||||
|
|
|
||||||
|
|
@ -14,16 +14,10 @@ ollama:
|
||||||
format: "json"
|
format: "json"
|
||||||
|
|
||||||
vllm:
|
vllm:
|
||||||
vl_model:
|
base_url: "http://host.docker.internal:8001/v1"
|
||||||
base_url: "http://192.168.6.115:8001/v1"
|
vl_model: "qwen2.5-vl-3b-awq"
|
||||||
model: "qwen2.5-vl-3b-awq"
|
timeout: "120s"
|
||||||
timeout: "120s"
|
level: "info"
|
||||||
level: "info"
|
|
||||||
text_model:
|
|
||||||
base_url: "http://192.168.6.115:8002/v1"
|
|
||||||
model: "qwen3-8b-fp8"
|
|
||||||
timeout: "120s"
|
|
||||||
level: "info"
|
|
||||||
|
|
||||||
coze:
|
coze:
|
||||||
base_url: "https://api.coze.cn"
|
base_url: "https://api.coze.cn"
|
||||||
|
|
|
||||||
|
|
@ -36,7 +36,6 @@ import (
|
||||||
|
|
||||||
type Handle struct {
|
type Handle struct {
|
||||||
Ollama *llm_service.OllamaService
|
Ollama *llm_service.OllamaService
|
||||||
Vllm *llm_service.VllmService
|
|
||||||
toolManager *tools.Manager
|
toolManager *tools.Manager
|
||||||
conf *config.Config
|
conf *config.Config
|
||||||
sessionImpl *impl.SessionImpl
|
sessionImpl *impl.SessionImpl
|
||||||
|
|
@ -48,7 +47,6 @@ type Handle struct {
|
||||||
|
|
||||||
func NewHandle(
|
func NewHandle(
|
||||||
Ollama *llm_service.OllamaService,
|
Ollama *llm_service.OllamaService,
|
||||||
Vllm *llm_service.VllmService,
|
|
||||||
toolManager *tools.Manager,
|
toolManager *tools.Manager,
|
||||||
conf *config.Config,
|
conf *config.Config,
|
||||||
sessionImpl *impl.SessionImpl,
|
sessionImpl *impl.SessionImpl,
|
||||||
|
|
@ -59,7 +57,6 @@ func NewHandle(
|
||||||
) *Handle {
|
) *Handle {
|
||||||
return &Handle{
|
return &Handle{
|
||||||
Ollama: Ollama,
|
Ollama: Ollama,
|
||||||
Vllm: Vllm,
|
|
||||||
toolManager: toolManager,
|
toolManager: toolManager,
|
||||||
conf: conf,
|
conf: conf,
|
||||||
sessionImpl: sessionImpl,
|
sessionImpl: sessionImpl,
|
||||||
|
|
@ -75,8 +72,7 @@ func (r *Handle) Recognize(ctx context.Context, rec *entitys.Recognize, promptPr
|
||||||
|
|
||||||
prompt, err := promptProcessor.CreatePrompt(ctx, rec)
|
prompt, err := promptProcessor.CreatePrompt(ctx, rec)
|
||||||
//意图识别
|
//意图识别
|
||||||
// recognizeMsg, err := r.Ollama.IntentRecognize(ctx, &entitys.ToolSelect{
|
recognizeMsg, err := r.Ollama.IntentRecognize(ctx, &entitys.ToolSelect{
|
||||||
recognizeMsg, err := r.Vllm.IntentRecognize(ctx, &entitys.ToolSelect{
|
|
||||||
Prompt: prompt,
|
Prompt: prompt,
|
||||||
Tools: rec.Tasks,
|
Tools: rec.Tasks,
|
||||||
})
|
})
|
||||||
|
|
@ -510,7 +506,7 @@ func handleCozeWorkflowEvents(ctx context.Context, resp coze.Stream[coze.Workflo
|
||||||
case coze.WorkflowEventTypeMessage:
|
case coze.WorkflowEventTypeMessage:
|
||||||
entitys.ResStream(ch, index, event.Message.Content)
|
entitys.ResStream(ch, index, event.Message.Content)
|
||||||
case coze.WorkflowEventTypeError:
|
case coze.WorkflowEventTypeError:
|
||||||
entitys.ResError(ch, index, fmt.Sprintf("工作流执行错误: %v", event.Error))
|
entitys.ResError(ch, index, fmt.Sprintf("工作流执行错误: %s", event.Error))
|
||||||
case coze.WorkflowEventTypeDone:
|
case coze.WorkflowEventTypeDone:
|
||||||
entitys.ResEnd(ch, index, "工作流执行完成")
|
entitys.ResEnd(ch, index, "工作流执行完成")
|
||||||
case coze.WorkflowEventTypeInterrupt:
|
case coze.WorkflowEventTypeInterrupt:
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,6 @@ import (
|
||||||
|
|
||||||
type Macro struct {
|
type Macro struct {
|
||||||
botGroupImpl *impl.BotGroupImpl
|
botGroupImpl *impl.BotGroupImpl
|
||||||
botGroupConfigImpl *impl.BotGroupConfigImpl
|
|
||||||
reportDailyCacheImpl *impl.ReportDailyCacheImpl
|
reportDailyCacheImpl *impl.ReportDailyCacheImpl
|
||||||
config *config.Config
|
config *config.Config
|
||||||
rdb *utils.Rdb
|
rdb *utils.Rdb
|
||||||
|
|
@ -37,14 +36,12 @@ func NewMacro(
|
||||||
reportDailyCacheImpl *impl.ReportDailyCacheImpl,
|
reportDailyCacheImpl *impl.ReportDailyCacheImpl,
|
||||||
config *config.Config,
|
config *config.Config,
|
||||||
rdb *utils.Rdb,
|
rdb *utils.Rdb,
|
||||||
botGroupConfigImpl *impl.BotGroupConfigImpl,
|
|
||||||
) *Macro {
|
) *Macro {
|
||||||
return &Macro{
|
return &Macro{
|
||||||
botGroupImpl: botGroupImpl,
|
botGroupImpl: botGroupImpl,
|
||||||
reportDailyCacheImpl: reportDailyCacheImpl,
|
reportDailyCacheImpl: reportDailyCacheImpl,
|
||||||
config: config,
|
config: config,
|
||||||
rdb: rdb,
|
rdb: rdb,
|
||||||
botGroupConfigImpl: botGroupConfigImpl,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -196,7 +193,7 @@ func (m *Macro) ProductModify(ctx context.Context, content string, groupConfig *
|
||||||
groupConfig.ProductName = itemInfo
|
groupConfig.ProductName = itemInfo
|
||||||
cond := builder.NewCond()
|
cond := builder.NewCond()
|
||||||
cond = cond.And(builder.Eq{"config_id": groupConfig.ConfigID})
|
cond = cond.And(builder.Eq{"config_id": groupConfig.ConfigID})
|
||||||
err = m.botGroupConfigImpl.UpdateByCond(&cond, groupConfig)
|
err = m.botGroupImpl.UpdateByCond(&cond, groupConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = fmt.Errorf("修改失败:%v", err)
|
err = fmt.Errorf("修改失败:%v", err)
|
||||||
return
|
return
|
||||||
|
|
|
||||||
|
|
@ -1,32 +0,0 @@
|
||||||
package do
|
|
||||||
|
|
||||||
import (
|
|
||||||
"ai_scheduler/internal/config"
|
|
||||||
"ai_scheduler/internal/data/impl"
|
|
||||||
"ai_scheduler/internal/data/model"
|
|
||||||
|
|
||||||
"ai_scheduler/utils"
|
|
||||||
"context"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func Test_report(t *testing.T) {
|
|
||||||
con := "[利润同比报表]商品修改:官方–优酷周卡,官方–优酷月卡,官方–优酷季卡,官方–优酷年卡,,官方–爱奇艺-月卡,官方–爱奇艺-季卡,官方–爱奇艺-年卡,官方–芒果-PC周卡,官方–芒果-PC月卡,官方–芒果-PC季卡,官方–QQ音乐-绿钻月卡,官方–饿了么超级会员月卡,官方–网易云黑胶vip月卡,官方–喜马拉雅巅峰会员月卡,剪映会员7天卡,剪映会员月卡,剪映会员年卡,剪映SVIP会员7天卡,剪映SVIP会员月卡,剪映SVIP会员年卡"
|
|
||||||
run()
|
|
||||||
|
|
||||||
chatId, err, i := ma.ProductModify(context.Background(), con, &model.AiBotGroupConfig{ConfigID: 1, ToolList: "8,9,10,11,12,13,16"})
|
|
||||||
t.Log(chatId, err, i)
|
|
||||||
}
|
|
||||||
|
|
||||||
var ma *Macro
|
|
||||||
|
|
||||||
func run() {
|
|
||||||
configConfig, _ := config.LoadConfigWithTest()
|
|
||||||
db, _ := utils.NewGormDb(configConfig)
|
|
||||||
|
|
||||||
rdb := utils.NewRdb(configConfig)
|
|
||||||
reportDailyCacheImpl := impl.NewReportDailyCacheImpl(db)
|
|
||||||
botGroupImpl := impl.NewBotGroupImpl(db)
|
|
||||||
botGroupConfigImpl := impl.NewBotGroupConfigImpl(db)
|
|
||||||
ma = NewMacro(botGroupImpl, reportDailyCacheImpl, configConfig, rdb, botGroupConfigImpl)
|
|
||||||
}
|
|
||||||
|
|
@ -115,7 +115,6 @@ func (s *SendCardClient) NewCard(ctx context.Context, cardSend *CardSend) error
|
||||||
s.processContentChannel(ctx, cardSend, cardInstanceId.String(), client)
|
s.processContentChannel(ctx, cardSend, cardInstanceId.String(), client)
|
||||||
}()
|
}()
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
log.Info("处理通道结束")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|
@ -164,7 +163,7 @@ func (s *SendCardClient) processContentChannel(ctx context.Context, cardSend *Ca
|
||||||
|
|
||||||
var (
|
var (
|
||||||
contentBuilder strings.Builder
|
contentBuilder strings.Builder
|
||||||
lastUpdate = time.Now()
|
lastUpdate time.Time
|
||||||
)
|
)
|
||||||
for {
|
for {
|
||||||
|
|
||||||
|
|
@ -174,7 +173,6 @@ func (s *SendCardClient) processContentChannel(ctx context.Context, cardSend *Ca
|
||||||
// 通道关闭,发送最终内容
|
// 通道关闭,发送最终内容
|
||||||
if contentBuilder.Len() > 0 {
|
if contentBuilder.Len() > 0 {
|
||||||
if err := s.updateCardContent(ctx, cardSend, cardInstanceId, contentBuilder.String(), client); err != nil {
|
if err := s.updateCardContent(ctx, cardSend, cardInstanceId, contentBuilder.String(), client); err != nil {
|
||||||
log.Info("contentBuilder.Len()修改失败1")
|
|
||||||
s.logger.Errorf("更新卡片失败1:%s", err.Error())
|
s.logger.Errorf("更新卡片失败1:%s", err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -183,7 +181,6 @@ func (s *SendCardClient) processContentChannel(ctx context.Context, cardSend *Ca
|
||||||
contentBuilder.WriteString(content)
|
contentBuilder.WriteString(content)
|
||||||
if contentBuilder.Len() > 0 {
|
if contentBuilder.Len() > 0 {
|
||||||
if err := s.updateCardContent(ctx, cardSend, cardInstanceId, contentBuilder.String(), client); err != nil {
|
if err := s.updateCardContent(ctx, cardSend, cardInstanceId, contentBuilder.String(), client); err != nil {
|
||||||
log.Info("contentBuilder.Len()修改失败2")
|
|
||||||
s.logger.Errorf("更新卡片失败2:%s", err.Error())
|
s.logger.Errorf("更新卡片失败2:%s", err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -191,12 +188,10 @@ func (s *SendCardClient) processContentChannel(ctx context.Context, cardSend *Ca
|
||||||
|
|
||||||
case <-heartbeatTicker.C:
|
case <-heartbeatTicker.C:
|
||||||
if time.Now().Unix()-lastUpdate.Unix() >= HeardBeatX {
|
if time.Now().Unix()-lastUpdate.Unix() >= HeardBeatX {
|
||||||
log.Infof("心跳超时,当前时间:%d,最后时间:%d", time.Now().Unix(), lastUpdate.Unix())
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
log.Info("send_card上下文失效")
|
|
||||||
s.logger.Info("context canceled, stop channel processing")
|
s.logger.Info("context canceled, stop channel processing")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,153 +0,0 @@
|
||||||
package llm_service
|
|
||||||
|
|
||||||
import (
|
|
||||||
"ai_scheduler/internal/config"
|
|
||||||
"ai_scheduler/internal/entitys"
|
|
||||||
"ai_scheduler/internal/pkg"
|
|
||||||
"ai_scheduler/internal/pkg/utils_vllm"
|
|
||||||
"context"
|
|
||||||
"encoding/base64"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/cloudwego/eino/schema"
|
|
||||||
"github.com/ollama/ollama/api"
|
|
||||||
)
|
|
||||||
|
|
||||||
type VllmService struct {
|
|
||||||
client *utils_vllm.Client
|
|
||||||
config *config.Config
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewVllmService(
|
|
||||||
client *utils_vllm.Client,
|
|
||||||
config *config.Config,
|
|
||||||
) *VllmService {
|
|
||||||
return &VllmService{
|
|
||||||
client: client,
|
|
||||||
config: config,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *VllmService) IntentRecognize(ctx context.Context, req *entitys.ToolSelect) (msg string, err error) {
|
|
||||||
msgs := s.convertMessages(req.Prompt)
|
|
||||||
tools := s.convertTools(req.Tools)
|
|
||||||
|
|
||||||
resp, err := s.client.ToolSelect(ctx, msgs, tools)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.Content == "" {
|
|
||||||
if len(resp.ToolCalls) > 0 {
|
|
||||||
call := resp.ToolCalls[0]
|
|
||||||
var matchFromTools = &entitys.Match{
|
|
||||||
Confidence: 1,
|
|
||||||
Index: call.Function.Name,
|
|
||||||
Parameters: call.Function.Arguments,
|
|
||||||
IsMatch: true,
|
|
||||||
}
|
|
||||||
msg = pkg.JsonStringIgonErr(matchFromTools)
|
|
||||||
} else {
|
|
||||||
err = errors.New("不太明白你想表达的意思呢,可以在仔细描述一下您所需要的内容吗,感谢感谢")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
msg = resp.Content
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *VllmService) convertMessages(prompts []api.Message) []*schema.Message {
|
|
||||||
msgs := make([]*schema.Message, 0, len(prompts))
|
|
||||||
for _, p := range prompts {
|
|
||||||
msg := &schema.Message{
|
|
||||||
Role: schema.RoleType(p.Role),
|
|
||||||
Content: p.Content,
|
|
||||||
}
|
|
||||||
|
|
||||||
// 这里实际应该不会走进来
|
|
||||||
if len(p.Images) > 0 {
|
|
||||||
parts := []schema.MessageInputPart{
|
|
||||||
{Type: schema.ChatMessagePartTypeText, Text: p.Content},
|
|
||||||
}
|
|
||||||
for _, imgData := range p.Images {
|
|
||||||
b64 := base64.StdEncoding.EncodeToString(imgData)
|
|
||||||
mimeType := "image/jpeg"
|
|
||||||
parts = append(parts, schema.MessageInputPart{
|
|
||||||
Type: schema.ChatMessagePartTypeImageURL,
|
|
||||||
Image: &schema.MessageInputImage{
|
|
||||||
MessagePartCommon: schema.MessagePartCommon{
|
|
||||||
MIMEType: mimeType,
|
|
||||||
Base64Data: &b64,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
msg.UserInputMultiContent = parts
|
|
||||||
}
|
|
||||||
msgs = append(msgs, msg)
|
|
||||||
}
|
|
||||||
return msgs
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *VllmService) convertTools(tasks []entitys.RegistrationTask) []*schema.ToolInfo {
|
|
||||||
tools := make([]*schema.ToolInfo, 0, len(tasks))
|
|
||||||
for _, task := range tasks {
|
|
||||||
params := make(map[string]*schema.ParameterInfo)
|
|
||||||
for k, v := range task.TaskConfigDetail.Param.Properties {
|
|
||||||
dt := schema.String
|
|
||||||
|
|
||||||
// Handle v.Type dynamically to support both string and []string (compiler suggests []string)
|
|
||||||
// Using fmt.Sprint handles both cases safely without knowing exact type structure
|
|
||||||
typeStr := fmt.Sprintf("%v", v.Type)
|
|
||||||
typeStr = strings.Trim(typeStr, "[]") // normalize "[string]" -> "string"
|
|
||||||
|
|
||||||
switch typeStr {
|
|
||||||
case "string":
|
|
||||||
dt = schema.String
|
|
||||||
case "integer", "int":
|
|
||||||
dt = schema.Integer
|
|
||||||
case "number", "float":
|
|
||||||
dt = schema.Number
|
|
||||||
case "boolean", "bool":
|
|
||||||
dt = schema.Boolean
|
|
||||||
case "object":
|
|
||||||
dt = schema.Object
|
|
||||||
case "array":
|
|
||||||
dt = schema.Array
|
|
||||||
}
|
|
||||||
|
|
||||||
required := false
|
|
||||||
for _, r := range task.TaskConfigDetail.Param.Required {
|
|
||||||
if r == k {
|
|
||||||
required = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
desc := v.Description
|
|
||||||
if len(v.Enum) > 0 {
|
|
||||||
var enumStrs []string
|
|
||||||
for _, e := range v.Enum {
|
|
||||||
enumStrs = append(enumStrs, fmt.Sprintf("%v", e))
|
|
||||||
}
|
|
||||||
desc += " Enum: " + strings.Join(enumStrs, ", ")
|
|
||||||
}
|
|
||||||
|
|
||||||
params[k] = &schema.ParameterInfo{
|
|
||||||
Type: dt,
|
|
||||||
Desc: desc,
|
|
||||||
Required: required,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
tools = append(tools, &schema.ToolInfo{
|
|
||||||
Name: task.Name,
|
|
||||||
Desc: task.Desc,
|
|
||||||
ParamsOneOf: schema.NewParamsOneOfByParams(params),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return tools
|
|
||||||
}
|
|
||||||
|
|
@ -3,7 +3,6 @@ package biz
|
||||||
import (
|
import (
|
||||||
"ai_scheduler/internal/biz/do"
|
"ai_scheduler/internal/biz/do"
|
||||||
"ai_scheduler/internal/biz/llm_service"
|
"ai_scheduler/internal/biz/llm_service"
|
||||||
"ai_scheduler/internal/biz/support"
|
|
||||||
|
|
||||||
"github.com/google/wire"
|
"github.com/google/wire"
|
||||||
)
|
)
|
||||||
|
|
@ -14,7 +13,6 @@ var ProviderSetBiz = wire.NewSet(
|
||||||
NewChatHistoryBiz,
|
NewChatHistoryBiz,
|
||||||
//llm_service.NewLangChainGenerate,
|
//llm_service.NewLangChainGenerate,
|
||||||
llm_service.NewOllamaGenerate,
|
llm_service.NewOllamaGenerate,
|
||||||
llm_service.NewVllmService,
|
|
||||||
//handle.NewHandle,
|
//handle.NewHandle,
|
||||||
do.NewDo,
|
do.NewDo,
|
||||||
do.NewHandle,
|
do.NewHandle,
|
||||||
|
|
@ -23,5 +21,4 @@ var ProviderSetBiz = wire.NewSet(
|
||||||
NewQywxAppBiz,
|
NewQywxAppBiz,
|
||||||
NewGroupConfigBiz,
|
NewGroupConfigBiz,
|
||||||
do.NewMacro,
|
do.NewMacro,
|
||||||
support.NewHytAddressIngester,
|
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,81 +0,0 @@
|
||||||
package support
|
|
||||||
|
|
||||||
import (
|
|
||||||
"ai_scheduler/internal/config"
|
|
||||||
"ai_scheduler/internal/data/constants"
|
|
||||||
errorcode "ai_scheduler/internal/data/error"
|
|
||||||
"ai_scheduler/internal/pkg/util"
|
|
||||||
"ai_scheduler/internal/pkg/utils_vllm"
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/cloudwego/eino/schema"
|
|
||||||
"github.com/gofiber/fiber/v2"
|
|
||||||
)
|
|
||||||
|
|
||||||
// HytAddressIngester 货易通地址提取实现
|
|
||||||
type HytAddressIngester struct {
|
|
||||||
cfg *config.Config
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewHytAddressIngester(cfg *config.Config) *HytAddressIngester {
|
|
||||||
return &HytAddressIngester{cfg: cfg}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Auth 鉴权逻辑
|
|
||||||
func (s *HytAddressIngester) Auth(c *fiber.Ctx) error {
|
|
||||||
// 读取头
|
|
||||||
token := strings.TrimSpace(c.Get("X-Source-Key"))
|
|
||||||
ts := strings.TrimSpace(c.Get("X-Timestamp"))
|
|
||||||
|
|
||||||
// 时间窗口校验
|
|
||||||
if ts != "" && !util.IsInTimeWindow(ts, 5*time.Minute) {
|
|
||||||
return errorcode.AuthNotFound
|
|
||||||
}
|
|
||||||
// token校验
|
|
||||||
if token == "" || token != constants.TokenAddressIngestHyt {
|
|
||||||
return errorcode.KeyNotFound
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ingest 执行提取逻辑
|
|
||||||
func (s *HytAddressIngester) Ingest(ctx context.Context, text string) (*AddressIngestResp, error) {
|
|
||||||
// 模型调用
|
|
||||||
client, cleanup, err := utils_vllm.NewClient(s.cfg)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer cleanup()
|
|
||||||
|
|
||||||
res, err := client.Chat(ctx, []*schema.Message{
|
|
||||||
{
|
|
||||||
Role: "system",
|
|
||||||
Content: constants.SystemPromptAddressIngestHyt,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Role: "user",
|
|
||||||
Content: text,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// 解析模型返回结果
|
|
||||||
var addr AddressIngestResp
|
|
||||||
// 尝试直接解析
|
|
||||||
if err := json.Unmarshal([]byte(res.Content), &addr); err != nil {
|
|
||||||
// 修复json字符串
|
|
||||||
repairedContent, err := util.JSONRepair(res.Content)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errorcode.ParamErrf("invalid response body: %v", err)
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal([]byte(repairedContent), &addr); err != nil {
|
|
||||||
return nil, errorcode.ParamErrf("invalid response body: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return &addr, nil
|
|
||||||
}
|
|
||||||
|
|
@ -1,22 +0,0 @@
|
||||||
package support
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2"
|
|
||||||
)
|
|
||||||
|
|
||||||
// AddressIngestResp 通用地址提取响应
|
|
||||||
type AddressIngestResp struct {
|
|
||||||
Recipient string `json:"recipient"` // 收货人
|
|
||||||
Phone string `json:"phone"` // 联系电话
|
|
||||||
Address string `json:"address"` // 收货地址
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddressIngester 地址提取接口
|
|
||||||
type AddressIngester interface {
|
|
||||||
// Auth 鉴权逻辑
|
|
||||||
Auth(c *fiber.Ctx) error
|
|
||||||
// Ingest 执行提取逻辑
|
|
||||||
Ingest(ctx context.Context, text string) (*AddressIngestResp, error)
|
|
||||||
}
|
|
||||||
|
|
@ -122,13 +122,8 @@ type OllamaConfig struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type VllmConfig struct {
|
type VllmConfig struct {
|
||||||
VLModel VllmModel `mapstructure:"vl_model"`
|
|
||||||
TextModel VllmModel `mapstructure:"text_model"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type VllmModel struct {
|
|
||||||
BaseURL string `mapstructure:"base_url"`
|
BaseURL string `mapstructure:"base_url"`
|
||||||
Model string `mapstructure:"model"`
|
VlModel string `mapstructure:"vl_model"`
|
||||||
Timeout time.Duration `mapstructure:"timeout"`
|
Timeout time.Duration `mapstructure:"timeout"`
|
||||||
Level string `mapstructure:"level"`
|
Level string `mapstructure:"level"`
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,25 +0,0 @@
|
||||||
package constants
|
|
||||||
|
|
||||||
// Token
|
|
||||||
const (
|
|
||||||
TokenAddressIngestHyt = "E632C7D3E60771B03264F2337CCFA014" // md5("address_ingest_hyt")
|
|
||||||
)
|
|
||||||
|
|
||||||
// 系统提示词
|
|
||||||
const (
|
|
||||||
SystemPromptAddressIngestHyt = `# 你是一个地址信息结构化解析器。
|
|
||||||
你的任务是从用户提供的非结构化文本中,准确抽取并区分以下字段:
|
|
||||||
|
|
||||||
1. 收货人 recipient (真实姓名或带掩码姓名,如“张三”)
|
|
||||||
2. 联系电话 phone (中国大陆手机号,11位数字)
|
|
||||||
3. 收货地址 address
|
|
||||||
|
|
||||||
解析规则:
|
|
||||||
- 电话号码只提取最可能的一个
|
|
||||||
- 不要编造不存在的信息
|
|
||||||
|
|
||||||
输出示例:
|
|
||||||
{\"recipient\": \"张三\",\"phone\": \"13458968095\",\"address\": \"四川省成都市武侯区天府三街88号\"}
|
|
||||||
|
|
||||||
输出格式必须为严格 JSON,不要输出任何解释性文字。`
|
|
||||||
)
|
|
||||||
|
|
@ -2,8 +2,6 @@ package util
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
|
||||||
"regexp"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
@ -44,74 +42,3 @@ func Contains[T comparable](strings []T, str T) bool {
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// json LLM专用字符串修复
|
|
||||||
func JSONRepair(input string) (string, error) {
|
|
||||||
s := strings.TrimSpace(input)
|
|
||||||
|
|
||||||
s = trimToJSONObject(s)
|
|
||||||
s = normalizeQuotes(s)
|
|
||||||
s = removeTrailingCommas(s)
|
|
||||||
s = quoteObjectKeys(s)
|
|
||||||
s = balanceBrackets(s)
|
|
||||||
|
|
||||||
// 最终校验
|
|
||||||
var js any
|
|
||||||
if err := json.Unmarshal([]byte(s), &js); err != nil {
|
|
||||||
return "", fmt.Errorf("json repair failed: %w", err)
|
|
||||||
}
|
|
||||||
return s, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// 裁剪前后垃圾文本
|
|
||||||
func trimToJSONObject(s string) string {
|
|
||||||
start := strings.IndexAny(s, "{[")
|
|
||||||
end := strings.LastIndexAny(s, "}]")
|
|
||||||
if start == -1 || end == -1 || start >= end {
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
return s[start : end+1]
|
|
||||||
}
|
|
||||||
|
|
||||||
// 引号统一
|
|
||||||
func normalizeQuotes(s string) string {
|
|
||||||
// 只替换“看起来像字符串的单引号”
|
|
||||||
re := regexp.MustCompile(`'([^']*)'`)
|
|
||||||
return re.ReplaceAllString(s, `"$1"`)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 删除尾随逗号
|
|
||||||
func removeTrailingCommas(s string) string {
|
|
||||||
re := regexp.MustCompile(`,(\s*[}\]])`)
|
|
||||||
return re.ReplaceAllString(s, `$1`)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 给 object key 自动补双引号
|
|
||||||
func quoteObjectKeys(s string) string {
|
|
||||||
re := regexp.MustCompile(`([{,]\s*)([a-zA-Z0-9_]+)\s*:`)
|
|
||||||
return re.ReplaceAllString(s, `$1"$2":`)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 括号补齐
|
|
||||||
func balanceBrackets(s string) string {
|
|
||||||
var stack []rune
|
|
||||||
for _, r := range s {
|
|
||||||
switch r {
|
|
||||||
case '{', '[':
|
|
||||||
stack = append(stack, r)
|
|
||||||
case '}', ']':
|
|
||||||
if len(stack) > 0 {
|
|
||||||
stack = stack[:len(stack)-1]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for i := len(stack) - 1; i >= 0; i-- {
|
|
||||||
switch stack[i] {
|
|
||||||
case '{':
|
|
||||||
s += "}"
|
|
||||||
case '[':
|
|
||||||
s += "]"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -7,63 +7,33 @@ import (
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
|
||||||
"github.com/cloudwego/eino-ext/components/model/openai"
|
"github.com/cloudwego/eino-ext/components/model/openai"
|
||||||
"github.com/cloudwego/eino/components/model"
|
|
||||||
"github.com/cloudwego/eino/schema"
|
"github.com/cloudwego/eino/schema"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Client struct {
|
type Client struct {
|
||||||
vlModel *openai.ChatModel
|
model *openai.ChatModel
|
||||||
generateModel *openai.ChatModel
|
config *config.Config
|
||||||
config *config.Config
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewClient(config *config.Config) (*Client, func(), error) {
|
func NewClient(config *config.Config) (*Client, func(), error) {
|
||||||
// 初始化视觉模型
|
m, err := openai.NewChatModel(context.Background(), &openai.ChatModelConfig{
|
||||||
vl, err := openai.NewChatModel(context.Background(), &openai.ChatModelConfig{
|
BaseURL: config.Vllm.BaseURL,
|
||||||
BaseURL: config.Vllm.VLModel.BaseURL,
|
Model: config.Vllm.VlModel,
|
||||||
Model: config.Vllm.VLModel.Model,
|
Timeout: config.Vllm.Timeout,
|
||||||
Timeout: config.Vllm.VLModel.Timeout,
|
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
c := &Client{model: m, config: config}
|
||||||
// 初始化生成模型
|
|
||||||
gen, err := openai.NewChatModel(context.Background(), &openai.ChatModelConfig{
|
|
||||||
BaseURL: config.Vllm.TextModel.BaseURL,
|
|
||||||
Model: config.Vllm.TextModel.Model,
|
|
||||||
Timeout: config.Vllm.TextModel.Timeout,
|
|
||||||
ExtraFields: map[string]any{
|
|
||||||
"chat_template_kwargs": map[string]any{
|
|
||||||
"enable_thinking": false,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
c := &Client{
|
|
||||||
vlModel: vl,
|
|
||||||
generateModel: gen,
|
|
||||||
config: config,
|
|
||||||
}
|
|
||||||
cleanup := func() {}
|
cleanup := func() {}
|
||||||
return c, cleanup, nil
|
return c, cleanup, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) Chat(ctx context.Context, msgs []*schema.Message) (*schema.Message, error) {
|
func (c *Client) Chat(ctx context.Context, msgs []*schema.Message) (*schema.Message, error) {
|
||||||
// 默认聊天使用生成模型
|
return c.model.Generate(ctx, msgs)
|
||||||
return c.generateModel.Generate(ctx, msgs)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) ToolSelect(ctx context.Context, msgs []*schema.Message, tools []*schema.ToolInfo) (*schema.Message, error) {
|
|
||||||
// 工具选择使用生成模型
|
|
||||||
return c.generateModel.Generate(ctx, msgs, model.WithTools(tools))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) RecognizeWithImg(ctx context.Context, systemPrompt, userPrompt string, imgURLs []string) (*schema.Message, error) {
|
func (c *Client) RecognizeWithImg(ctx context.Context, systemPrompt, userPrompt string, imgURLs []string) (*schema.Message, error) {
|
||||||
// 图片识别使用视觉模型
|
|
||||||
in := []*schema.Message{
|
in := []*schema.Message{
|
||||||
{
|
{
|
||||||
Role: schema.System,
|
Role: schema.System,
|
||||||
|
|
@ -88,12 +58,11 @@ func (c *Client) RecognizeWithImg(ctx context.Context, systemPrompt, userPrompt
|
||||||
}
|
}
|
||||||
|
|
||||||
in[1].UserInputMultiContent = parts
|
in[1].UserInputMultiContent = parts
|
||||||
return c.vlModel.Generate(ctx, in)
|
return c.model.Generate(ctx, in)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 识别图片by二进制文件
|
// 识别图片by二进制文件
|
||||||
func (c *Client) RecognizeWithImgBytes(ctx context.Context, systemPrompt, userPrompt string, imgBytes []byte, imgType string) (*schema.Message, error) {
|
func (c *Client) RecognizeWithImgBytes(ctx context.Context, systemPrompt, userPrompt string, imgBytes []byte, imgType string) (*schema.Message, error) {
|
||||||
// 图片识别使用视觉模型
|
|
||||||
in := []*schema.Message{
|
in := []*schema.Message{
|
||||||
{
|
{
|
||||||
Role: schema.System,
|
Role: schema.System,
|
||||||
|
|
@ -113,10 +82,9 @@ func (c *Client) RecognizeWithImgBytes(ctx context.Context, systemPrompt, userPr
|
||||||
MIMEType: imgType,
|
MIMEType: imgType,
|
||||||
Base64Data: util.AnyToPoint(base64.StdEncoding.EncodeToString(imgBytes)),
|
Base64Data: util.AnyToPoint(base64.StdEncoding.EncodeToString(imgBytes)),
|
||||||
},
|
},
|
||||||
Detail: schema.ImageURLDetailHigh,
|
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
in[1].UserInputMultiContent = parts
|
in[1].UserInputMultiContent = parts
|
||||||
return c.vlModel.Generate(ctx, in)
|
return c.model.Generate(ctx, in)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,6 @@ package server
|
||||||
import (
|
import (
|
||||||
"ai_scheduler/internal/services"
|
"ai_scheduler/internal/services"
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2/log"
|
"github.com/gofiber/fiber/v2/log"
|
||||||
"github.com/robfig/cron/v3"
|
"github.com/robfig/cron/v3"
|
||||||
|
|
@ -21,7 +20,6 @@ type cronJob struct {
|
||||||
EntryId int32
|
EntryId int32
|
||||||
Func func(context.Context) error
|
Func func(context.Context) error
|
||||||
Name string
|
Name string
|
||||||
Key string
|
|
||||||
Schedule string
|
Schedule string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -44,13 +42,11 @@ func (c *CronServer) InitJobs(ctx context.Context) {
|
||||||
{
|
{
|
||||||
Func: c.cronService.CronReportSendDingTalk,
|
Func: c.cronService.CronReportSendDingTalk,
|
||||||
Name: "直连天下报表推送(钉钉)",
|
Name: "直连天下报表推送(钉钉)",
|
||||||
Key: "ding_report_dingtalk",
|
|
||||||
Schedule: "20 12,18,23 * * *",
|
Schedule: "20 12,18,23 * * *",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Func: c.cronService.CronReportSendQywx,
|
Func: c.cronService.CronReportSendQywx,
|
||||||
Name: "直连天下报表推送(微信)",
|
Name: "直连天下报表推送(微信)",
|
||||||
Key: "ding_report_qywx",
|
|
||||||
Schedule: "20 12,18,23 * * *",
|
Schedule: "20 12,18,23 * * *",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
@ -100,39 +96,3 @@ func (c *CronServer) Stop() {
|
||||||
c.log.Info("Cron调度器已停止")
|
c.log.Info("Cron调度器已停止")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *CronServer) RunOnce(ctx context.Context, key string) error {
|
|
||||||
|
|
||||||
if c.jobs == nil {
|
|
||||||
c.InitJobs(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取key对应的任务
|
|
||||||
var job *cronJob
|
|
||||||
for _, j := range c.jobs {
|
|
||||||
if j.Key == key {
|
|
||||||
job = j
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if job == nil {
|
|
||||||
return fmt.Errorf("unknown job key: %s\n", key)
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
if r := recover(); r != nil {
|
|
||||||
fmt.Printf("任务[once]:%s执行时发生panic: %v\n", job.Name, r)
|
|
||||||
}
|
|
||||||
fmt.Printf("任务[once]:%s执行结束\n", job.Name)
|
|
||||||
}()
|
|
||||||
|
|
||||||
fmt.Printf("任务[once]:%s开始执行\n", job.Name)
|
|
||||||
|
|
||||||
err := job.Func(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("任务[once]:%s执行失败: %s\n", job.Name, err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Printf("任务[once]:%s执行成功\n", job.Name)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,6 @@ type HTTPServer struct {
|
||||||
callback *services.CallbackService
|
callback *services.CallbackService
|
||||||
chatHis *services.HistoryService
|
chatHis *services.HistoryService
|
||||||
capabilityService *services.CapabilityService
|
capabilityService *services.CapabilityService
|
||||||
supportService *services.SupportService
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHTTPServer(
|
func NewHTTPServer(
|
||||||
|
|
@ -29,11 +28,10 @@ func NewHTTPServer(
|
||||||
callback *services.CallbackService,
|
callback *services.CallbackService,
|
||||||
chatHis *services.HistoryService,
|
chatHis *services.HistoryService,
|
||||||
capabilityService *services.CapabilityService,
|
capabilityService *services.CapabilityService,
|
||||||
supportService *services.SupportService,
|
|
||||||
) *fiber.App {
|
) *fiber.App {
|
||||||
//构建 server
|
//构建 server
|
||||||
app := initRoute()
|
app := initRoute()
|
||||||
router.SetupRoutes(app, service, session, task, gateway, callback, chatHis, capabilityService, supportService)
|
router.SetupRoutes(app, service, session, task, gateway, callback, chatHis, capabilityService)
|
||||||
return app
|
return app
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,8 +6,6 @@ import (
|
||||||
"ai_scheduler/internal/services"
|
"ai_scheduler/internal/services"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
@ -28,7 +26,7 @@ type RouterServer struct {
|
||||||
// SetupRoutes 设置路由
|
// SetupRoutes 设置路由
|
||||||
func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionService *services.SessionService, task *services.TaskService,
|
func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionService *services.SessionService, task *services.TaskService,
|
||||||
gateway *gateway.Gateway, callbackService *services.CallbackService, chatHist *services.HistoryService,
|
gateway *gateway.Gateway, callbackService *services.CallbackService, chatHist *services.HistoryService,
|
||||||
capabilityService *services.CapabilityService, supportService *services.SupportService,
|
capabilityService *services.CapabilityService,
|
||||||
) {
|
) {
|
||||||
app.Use(func(c *fiber.Ctx) error {
|
app.Use(func(c *fiber.Ctx) error {
|
||||||
// 设置 CORS 头
|
// 设置 CORS 头
|
||||||
|
|
@ -100,54 +98,6 @@ func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionServi
|
||||||
// 能力
|
// 能力
|
||||||
r.Post("/capability/product/ingest", capabilityService.ProductIngest) // 商品数据提取
|
r.Post("/capability/product/ingest", capabilityService.ProductIngest) // 商品数据提取
|
||||||
r.Post("/capability/product/ingest/:thread_id/confirm", capabilityService.ProductIngestConfirm) // 商品数据提取确认
|
r.Post("/capability/product/ingest/:thread_id/confirm", capabilityService.ProductIngestConfirm) // 商品数据提取确认
|
||||||
|
|
||||||
// 外部系统支持
|
|
||||||
r.Post("/support/address/ingest/:platform", supportService.AddressIngest) // 通用收获地址提取
|
|
||||||
|
|
||||||
// 转发Python服务
|
|
||||||
r.All("/proxy/fingerprint", forwardToPythonService)
|
|
||||||
}
|
|
||||||
|
|
||||||
func forwardToPythonService(c *fiber.Ctx) error {
|
|
||||||
targetURL := "http://192.168.6.115:10086/fingerprint"
|
|
||||||
|
|
||||||
req, err := http.NewRequest(c.Method(), targetURL, c.Context().Request.BodyStream())
|
|
||||||
if err != nil {
|
|
||||||
return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Request().Header.VisitAll(func(key, value []byte) {
|
|
||||||
req.Header.Add(string(key), string(value))
|
|
||||||
})
|
|
||||||
|
|
||||||
if c.Context().QueryArgs().Len() > 0 {
|
|
||||||
q := req.URL.Query()
|
|
||||||
c.Context().QueryArgs().VisitAll(func(key, value []byte) {
|
|
||||||
q.Add(string(key), string(value))
|
|
||||||
})
|
|
||||||
req.URL.RawQuery = q.Encode()
|
|
||||||
}
|
|
||||||
|
|
||||||
client := &http.Client{}
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return c.Status(fiber.StatusBadGateway).SendString(err.Error())
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
for key, values := range resp.Header {
|
|
||||||
for _, value := range values {
|
|
||||||
c.Set(key, value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Status(resp.StatusCode)
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
return c.Send(body)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func routerSocket(app *fiber.App, chatService *services.ChatService) {
|
func routerSocket(app *fiber.App, chatService *services.ChatService) {
|
||||||
|
|
|
||||||
|
|
@ -64,15 +64,9 @@ func (d *DingBotService) runBackgroundTasks(ctx context.Context, data *chatbot.B
|
||||||
g.Go(func() error {
|
g.Go(func() error {
|
||||||
defer func() {
|
defer func() {
|
||||||
// 确保通道最终关闭
|
// 确保通道最终关闭
|
||||||
log.Println("流式处理协程关闭")
|
|
||||||
|
|
||||||
close(resChan)
|
close(resChan)
|
||||||
}()
|
}()
|
||||||
err := d.dingTalkBotBiz.HandleStreamRes(ctx, data, resChan)
|
return d.dingTalkBotBiz.HandleStreamRes(ctx, data, resChan)
|
||||||
if err != nil {
|
|
||||||
log.Println("流式回复产生错误,错误:", err.Error())
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
})
|
})
|
||||||
|
|
||||||
// 2. 业务处理协程(负责关闭requireData.Ch)
|
// 2. 业务处理协程(负责关闭requireData.Ch)
|
||||||
|
|
|
||||||
|
|
@ -15,5 +15,4 @@ var ProviderSetServices = wire.NewSet(
|
||||||
NewHistoryService,
|
NewHistoryService,
|
||||||
NewCapabilityService,
|
NewCapabilityService,
|
||||||
NewCronService,
|
NewCronService,
|
||||||
NewSupportService,
|
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,61 +0,0 @@
|
||||||
package services
|
|
||||||
|
|
||||||
import (
|
|
||||||
"ai_scheduler/internal/biz/support"
|
|
||||||
"ai_scheduler/internal/config"
|
|
||||||
errorcode "ai_scheduler/internal/data/error"
|
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2"
|
|
||||||
)
|
|
||||||
|
|
||||||
type SupportService struct {
|
|
||||||
cfg *config.Config
|
|
||||||
addressIngester map[string]support.AddressIngester
|
|
||||||
addressIngestHyt *support.HytAddressIngester
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewSupportService(cfg *config.Config, addressIngestHyt *support.HytAddressIngester) *SupportService {
|
|
||||||
s := &SupportService{
|
|
||||||
cfg: cfg,
|
|
||||||
addressIngester: map[string]support.AddressIngester{
|
|
||||||
"hyt": addressIngestHyt,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
type AddressIngestReq struct {
|
|
||||||
Text string `json:"text"` // 待提取文本
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddressIngest 收获地址提取
|
|
||||||
func (s *SupportService) AddressIngest(c *fiber.Ctx) error {
|
|
||||||
platform := c.Params("platform")
|
|
||||||
ingester, ok := s.addressIngester[platform]
|
|
||||||
if !ok {
|
|
||||||
return errorcode.ParamErrf("unsupported platform: %s", platform)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 鉴权
|
|
||||||
if err := ingester.Auth(c); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// 解析请求参数 body
|
|
||||||
req := AddressIngestReq{}
|
|
||||||
if err := c.BodyParser(&req); err != nil {
|
|
||||||
return errorcode.ParamErrf("invalid request body: %v", err)
|
|
||||||
}
|
|
||||||
// 必要参数校验
|
|
||||||
if req.Text == "" {
|
|
||||||
return errorcode.ParamErrf("missing required fields")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 执行提取
|
|
||||||
res, err := ingester.Ingest(c.Context(), req.Text)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return c.JSON(res)
|
|
||||||
}
|
|
||||||
|
|
@ -368,13 +368,7 @@ func (b *BbxtTools) GetProfitRankingSum(now time.Time) (report *ReportRes, err e
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if len(data.List) == 0 {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
maxLen := 20
|
|
||||||
if len(data.List) < maxLen {
|
|
||||||
maxLen = len(data.List)
|
|
||||||
}
|
|
||||||
//排序
|
//排序
|
||||||
sort.Slice(data.List, func(i, j int) bool {
|
sort.Slice(data.List, func(i, j int) bool {
|
||||||
return data.List[i].HistoryOneDiff > data.List[j].HistoryOneDiff
|
return data.List[i].HistoryOneDiff > data.List[j].HistoryOneDiff
|
||||||
|
|
@ -382,7 +376,7 @@ func (b *BbxtTools) GetProfitRankingSum(now time.Time) (report *ReportRes, err e
|
||||||
//取前20和后20
|
//取前20和后20
|
||||||
var (
|
var (
|
||||||
total [][]string
|
total [][]string
|
||||||
top = data.List[:maxLen]
|
top = data.List[:20]
|
||||||
bottom = data.List[len(data.List)-20:]
|
bottom = data.List[len(data.List)-20:]
|
||||||
)
|
)
|
||||||
//合并前20和后20
|
//合并前20和后20
|
||||||
|
|
|
||||||
|
|
@ -106,7 +106,7 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
func run() {
|
func run() {
|
||||||
configConfig, _ = config.LoadConfigWithEnv()
|
configConfig, _ = config.LoadConfigWithTest()
|
||||||
// 初始化数据库连接
|
// 初始化数据库连接
|
||||||
db, _ := utils.NewGormDb(configConfig)
|
db, _ := utils.NewGormDb(configConfig)
|
||||||
reportDailyCacheImpl = impl.NewReportDailyCacheImpl(db)
|
reportDailyCacheImpl = impl.NewReportDailyCacheImpl(db)
|
||||||
|
|
|
||||||
Binary file not shown.
Loading…
Reference in New Issue