Compare commits
3 Commits
4e301d17cd
...
c7b35f8dcd
| Author | SHA1 | Date |
|---|---|---|
|
|
c7b35f8dcd | |
|
|
e981115cc1 | |
|
|
8d046df04e |
|
|
@ -13,6 +13,7 @@ func main() {
|
||||||
configPath := flag.String("config", "./config/config.yaml", "Path to configuration file")
|
configPath := flag.String("config", "./config/config.yaml", "Path to configuration file")
|
||||||
onBot := flag.String("bot", "", "bot start")
|
onBot := flag.String("bot", "", "bot start")
|
||||||
cron := flag.String("cron", "", "close")
|
cron := flag.String("cron", "", "close")
|
||||||
|
runJob := flag.String("runJob", "", "run single job and exit")
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
bc, err := config.LoadConfig(*configPath)
|
bc, err := config.LoadConfig(*configPath)
|
||||||
|
|
@ -33,6 +34,11 @@ func main() {
|
||||||
if *cron == "start" {
|
if *cron == "start" {
|
||||||
app.Cron.Run(ctx)
|
app.Cron.Run(ctx)
|
||||||
}
|
}
|
||||||
|
// 运行指定任务并退出
|
||||||
|
if *runJob != "" {
|
||||||
|
app.Cron.RunOnce(ctx, *runJob)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
log.Fatal(app.HttpServer.Listen(fmt.Sprintf(":%d", bc.Server.Port)))
|
log.Fatal(app.HttpServer.Listen(fmt.Sprintf(":%d", bc.Server.Port)))
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ package biz
|
||||||
import (
|
import (
|
||||||
"ai_scheduler/internal/biz/do"
|
"ai_scheduler/internal/biz/do"
|
||||||
"ai_scheduler/internal/biz/llm_service"
|
"ai_scheduler/internal/biz/llm_service"
|
||||||
|
"ai_scheduler/internal/biz/support"
|
||||||
|
|
||||||
"github.com/google/wire"
|
"github.com/google/wire"
|
||||||
)
|
)
|
||||||
|
|
@ -22,4 +23,5 @@ var ProviderSetBiz = wire.NewSet(
|
||||||
NewQywxAppBiz,
|
NewQywxAppBiz,
|
||||||
NewGroupConfigBiz,
|
NewGroupConfigBiz,
|
||||||
do.NewMacro,
|
do.NewMacro,
|
||||||
|
support.NewHytAddressIngester,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,81 @@
|
||||||
|
package support
|
||||||
|
|
||||||
|
import (
|
||||||
|
"ai_scheduler/internal/config"
|
||||||
|
"ai_scheduler/internal/data/constants"
|
||||||
|
errorcode "ai_scheduler/internal/data/error"
|
||||||
|
"ai_scheduler/internal/pkg/util"
|
||||||
|
"ai_scheduler/internal/pkg/utils_vllm"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/cloudwego/eino/schema"
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HytAddressIngester 货易通地址提取实现
|
||||||
|
type HytAddressIngester struct {
|
||||||
|
cfg *config.Config
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewHytAddressIngester(cfg *config.Config) *HytAddressIngester {
|
||||||
|
return &HytAddressIngester{cfg: cfg}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Auth 鉴权逻辑
|
||||||
|
func (s *HytAddressIngester) Auth(c *fiber.Ctx) error {
|
||||||
|
// 读取头
|
||||||
|
token := strings.TrimSpace(c.Get("X-Source-Key"))
|
||||||
|
ts := strings.TrimSpace(c.Get("X-Timestamp"))
|
||||||
|
|
||||||
|
// 时间窗口校验
|
||||||
|
if ts != "" && !util.IsInTimeWindow(ts, 5*time.Minute) {
|
||||||
|
return errorcode.AuthNotFound
|
||||||
|
}
|
||||||
|
// token校验
|
||||||
|
if token == "" || token != constants.TokenAddressIngestHyt {
|
||||||
|
return errorcode.KeyNotFound
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ingest 执行提取逻辑
|
||||||
|
func (s *HytAddressIngester) Ingest(ctx context.Context, text string) (*AddressIngestResp, error) {
|
||||||
|
// 模型调用
|
||||||
|
client, cleanup, err := utils_vllm.NewClient(s.cfg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
res, err := client.Chat(ctx, []*schema.Message{
|
||||||
|
{
|
||||||
|
Role: "system",
|
||||||
|
Content: constants.SystemPromptAddressIngestHyt,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: text,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析模型返回结果
|
||||||
|
var addr AddressIngestResp
|
||||||
|
// 尝试直接解析
|
||||||
|
if err := json.Unmarshal([]byte(res.Content), &addr); err != nil {
|
||||||
|
// 修复json字符串
|
||||||
|
repairedContent, err := util.JSONRepair(res.Content)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errorcode.ParamErrf("invalid response body: %v", err)
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal([]byte(repairedContent), &addr); err != nil {
|
||||||
|
return nil, errorcode.ParamErrf("invalid response body: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &addr, nil
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,22 @@
|
||||||
|
package support
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AddressIngestResp 通用地址提取响应
|
||||||
|
type AddressIngestResp struct {
|
||||||
|
Recipient string `json:"recipient"` // 收货人
|
||||||
|
Phone string `json:"phone"` // 联系电话
|
||||||
|
Address string `json:"address"` // 收货地址
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddressIngester 地址提取接口
|
||||||
|
type AddressIngester interface {
|
||||||
|
// Auth 鉴权逻辑
|
||||||
|
Auth(c *fiber.Ctx) error
|
||||||
|
// Ingest 执行提取逻辑
|
||||||
|
Ingest(ctx context.Context, text string) (*AddressIngestResp, error)
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,25 @@
|
||||||
|
package constants
|
||||||
|
|
||||||
|
// Token
|
||||||
|
const (
|
||||||
|
TokenAddressIngestHyt = "E632C7D3E60771B03264F2337CCFA014" // md5("address_ingest_hyt")
|
||||||
|
)
|
||||||
|
|
||||||
|
// 系统提示词
|
||||||
|
const (
|
||||||
|
SystemPromptAddressIngestHyt = `# 你是一个地址信息结构化解析器。
|
||||||
|
你的任务是从用户提供的非结构化文本中,准确抽取并区分以下字段:
|
||||||
|
|
||||||
|
1. 收货人 recipient (真实姓名或带掩码姓名,如“张三”)
|
||||||
|
2. 联系电话 phone (中国大陆手机号,11位数字)
|
||||||
|
3. 收货地址 address
|
||||||
|
|
||||||
|
解析规则:
|
||||||
|
- 电话号码只提取最可能的一个
|
||||||
|
- 不要编造不存在的信息
|
||||||
|
|
||||||
|
输出示例:
|
||||||
|
{\"recipient\": \"张三\",\"phone\": \"13458968095\",\"address\": \"四川省成都市武侯区天府三街88号\"}
|
||||||
|
|
||||||
|
输出格式必须为严格 JSON,不要输出任何解释性文字。`
|
||||||
|
)
|
||||||
|
|
@ -2,6 +2,8 @@ package util
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
@ -42,3 +44,74 @@ func Contains[T comparable](strings []T, str T) bool {
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// json LLM专用字符串修复
|
||||||
|
func JSONRepair(input string) (string, error) {
|
||||||
|
s := strings.TrimSpace(input)
|
||||||
|
|
||||||
|
s = trimToJSONObject(s)
|
||||||
|
s = normalizeQuotes(s)
|
||||||
|
s = removeTrailingCommas(s)
|
||||||
|
s = quoteObjectKeys(s)
|
||||||
|
s = balanceBrackets(s)
|
||||||
|
|
||||||
|
// 最终校验
|
||||||
|
var js any
|
||||||
|
if err := json.Unmarshal([]byte(s), &js); err != nil {
|
||||||
|
return "", fmt.Errorf("json repair failed: %w", err)
|
||||||
|
}
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 裁剪前后垃圾文本
|
||||||
|
func trimToJSONObject(s string) string {
|
||||||
|
start := strings.IndexAny(s, "{[")
|
||||||
|
end := strings.LastIndexAny(s, "}]")
|
||||||
|
if start == -1 || end == -1 || start >= end {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return s[start : end+1]
|
||||||
|
}
|
||||||
|
|
||||||
|
// 引号统一
|
||||||
|
func normalizeQuotes(s string) string {
|
||||||
|
// 只替换“看起来像字符串的单引号”
|
||||||
|
re := regexp.MustCompile(`'([^']*)'`)
|
||||||
|
return re.ReplaceAllString(s, `"$1"`)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 删除尾随逗号
|
||||||
|
func removeTrailingCommas(s string) string {
|
||||||
|
re := regexp.MustCompile(`,(\s*[}\]])`)
|
||||||
|
return re.ReplaceAllString(s, `$1`)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 给 object key 自动补双引号
|
||||||
|
func quoteObjectKeys(s string) string {
|
||||||
|
re := regexp.MustCompile(`([{,]\s*)([a-zA-Z0-9_]+)\s*:`)
|
||||||
|
return re.ReplaceAllString(s, `$1"$2":`)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 括号补齐
|
||||||
|
func balanceBrackets(s string) string {
|
||||||
|
var stack []rune
|
||||||
|
for _, r := range s {
|
||||||
|
switch r {
|
||||||
|
case '{', '[':
|
||||||
|
stack = append(stack, r)
|
||||||
|
case '}', ']':
|
||||||
|
if len(stack) > 0 {
|
||||||
|
stack = stack[:len(stack)-1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for i := len(stack) - 1; i >= 0; i-- {
|
||||||
|
switch stack[i] {
|
||||||
|
case '{':
|
||||||
|
s += "}"
|
||||||
|
case '[':
|
||||||
|
s += "]"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ package server
|
||||||
import (
|
import (
|
||||||
"ai_scheduler/internal/services"
|
"ai_scheduler/internal/services"
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2/log"
|
"github.com/gofiber/fiber/v2/log"
|
||||||
"github.com/robfig/cron/v3"
|
"github.com/robfig/cron/v3"
|
||||||
|
|
@ -20,6 +21,7 @@ type cronJob struct {
|
||||||
EntryId int32
|
EntryId int32
|
||||||
Func func(context.Context) error
|
Func func(context.Context) error
|
||||||
Name string
|
Name string
|
||||||
|
Key string
|
||||||
Schedule string
|
Schedule string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -42,11 +44,13 @@ func (c *CronServer) InitJobs(ctx context.Context) {
|
||||||
{
|
{
|
||||||
Func: c.cronService.CronReportSendDingTalk,
|
Func: c.cronService.CronReportSendDingTalk,
|
||||||
Name: "直连天下报表推送(钉钉)",
|
Name: "直连天下报表推送(钉钉)",
|
||||||
|
Key: "ding_report_dingtalk",
|
||||||
Schedule: "20 12,18,23 * * *",
|
Schedule: "20 12,18,23 * * *",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Func: c.cronService.CronReportSendQywx,
|
Func: c.cronService.CronReportSendQywx,
|
||||||
Name: "直连天下报表推送(微信)",
|
Name: "直连天下报表推送(微信)",
|
||||||
|
Key: "ding_report_qywx",
|
||||||
Schedule: "20 12,18,23 * * *",
|
Schedule: "20 12,18,23 * * *",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
@ -96,3 +100,39 @@ func (c *CronServer) Stop() {
|
||||||
c.log.Info("Cron调度器已停止")
|
c.log.Info("Cron调度器已停止")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *CronServer) RunOnce(ctx context.Context, key string) error {
|
||||||
|
|
||||||
|
if c.jobs == nil {
|
||||||
|
c.InitJobs(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取key对应的任务
|
||||||
|
var job *cronJob
|
||||||
|
for _, j := range c.jobs {
|
||||||
|
if j.Key == key {
|
||||||
|
job = j
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if job == nil {
|
||||||
|
return fmt.Errorf("unknown job key: %s\n", key)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
fmt.Printf("任务[once]:%s执行时发生panic: %v\n", job.Name, r)
|
||||||
|
}
|
||||||
|
fmt.Printf("任务[once]:%s执行结束\n", job.Name)
|
||||||
|
}()
|
||||||
|
|
||||||
|
fmt.Printf("任务[once]:%s开始执行\n", job.Name)
|
||||||
|
|
||||||
|
err := job.Func(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("任务[once]:%s执行失败: %s\n", job.Name, err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("任务[once]:%s执行成功\n", job.Name)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ type HTTPServer struct {
|
||||||
callback *services.CallbackService
|
callback *services.CallbackService
|
||||||
chatHis *services.HistoryService
|
chatHis *services.HistoryService
|
||||||
capabilityService *services.CapabilityService
|
capabilityService *services.CapabilityService
|
||||||
|
supportService *services.SupportService
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHTTPServer(
|
func NewHTTPServer(
|
||||||
|
|
@ -28,10 +29,11 @@ func NewHTTPServer(
|
||||||
callback *services.CallbackService,
|
callback *services.CallbackService,
|
||||||
chatHis *services.HistoryService,
|
chatHis *services.HistoryService,
|
||||||
capabilityService *services.CapabilityService,
|
capabilityService *services.CapabilityService,
|
||||||
|
supportService *services.SupportService,
|
||||||
) *fiber.App {
|
) *fiber.App {
|
||||||
//构建 server
|
//构建 server
|
||||||
app := initRoute()
|
app := initRoute()
|
||||||
router.SetupRoutes(app, service, session, task, gateway, callback, chatHis, capabilityService)
|
router.SetupRoutes(app, service, session, task, gateway, callback, chatHis, capabilityService, supportService)
|
||||||
return app
|
return app
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,7 @@ type RouterServer struct {
|
||||||
// SetupRoutes 设置路由
|
// SetupRoutes 设置路由
|
||||||
func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionService *services.SessionService, task *services.TaskService,
|
func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionService *services.SessionService, task *services.TaskService,
|
||||||
gateway *gateway.Gateway, callbackService *services.CallbackService, chatHist *services.HistoryService,
|
gateway *gateway.Gateway, callbackService *services.CallbackService, chatHist *services.HistoryService,
|
||||||
capabilityService *services.CapabilityService,
|
capabilityService *services.CapabilityService, supportService *services.SupportService,
|
||||||
) {
|
) {
|
||||||
app.Use(func(c *fiber.Ctx) error {
|
app.Use(func(c *fiber.Ctx) error {
|
||||||
// 设置 CORS 头
|
// 设置 CORS 头
|
||||||
|
|
@ -98,6 +98,9 @@ func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionServi
|
||||||
// 能力
|
// 能力
|
||||||
r.Post("/capability/product/ingest", capabilityService.ProductIngest) // 商品数据提取
|
r.Post("/capability/product/ingest", capabilityService.ProductIngest) // 商品数据提取
|
||||||
r.Post("/capability/product/ingest/:thread_id/confirm", capabilityService.ProductIngestConfirm) // 商品数据提取确认
|
r.Post("/capability/product/ingest/:thread_id/confirm", capabilityService.ProductIngestConfirm) // 商品数据提取确认
|
||||||
|
|
||||||
|
// 外部系统支持
|
||||||
|
r.Post("/support/address/ingest/:platform", supportService.AddressIngest) // 通用收获地址提取
|
||||||
}
|
}
|
||||||
|
|
||||||
func routerSocket(app *fiber.App, chatService *services.ChatService) {
|
func routerSocket(app *fiber.App, chatService *services.ChatService) {
|
||||||
|
|
|
||||||
|
|
@ -15,4 +15,5 @@ var ProviderSetServices = wire.NewSet(
|
||||||
NewHistoryService,
|
NewHistoryService,
|
||||||
NewCapabilityService,
|
NewCapabilityService,
|
||||||
NewCronService,
|
NewCronService,
|
||||||
|
NewSupportService,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,61 @@
|
||||||
|
package services
|
||||||
|
|
||||||
|
import (
|
||||||
|
"ai_scheduler/internal/biz/support"
|
||||||
|
"ai_scheduler/internal/config"
|
||||||
|
errorcode "ai_scheduler/internal/data/error"
|
||||||
|
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SupportService struct {
|
||||||
|
cfg *config.Config
|
||||||
|
addressIngester map[string]support.AddressIngester
|
||||||
|
addressIngestHyt *support.HytAddressIngester
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSupportService(cfg *config.Config, addressIngestHyt *support.HytAddressIngester) *SupportService {
|
||||||
|
s := &SupportService{
|
||||||
|
cfg: cfg,
|
||||||
|
addressIngester: map[string]support.AddressIngester{
|
||||||
|
"hyt": addressIngestHyt,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
type AddressIngestReq struct {
|
||||||
|
Text string `json:"text"` // 待提取文本
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddressIngest 收获地址提取
|
||||||
|
func (s *SupportService) AddressIngest(c *fiber.Ctx) error {
|
||||||
|
platform := c.Params("platform")
|
||||||
|
ingester, ok := s.addressIngester[platform]
|
||||||
|
if !ok {
|
||||||
|
return errorcode.ParamErrf("unsupported platform: %s", platform)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 鉴权
|
||||||
|
if err := ingester.Auth(c); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析请求参数 body
|
||||||
|
req := AddressIngestReq{}
|
||||||
|
if err := c.BodyParser(&req); err != nil {
|
||||||
|
return errorcode.ParamErrf("invalid request body: %v", err)
|
||||||
|
}
|
||||||
|
// 必要参数校验
|
||||||
|
if req.Text == "" {
|
||||||
|
return errorcode.ParamErrf("missing required fields")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 执行提取
|
||||||
|
res, err := ingester.Ingest(c.Context(), req.Text)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.JSON(res)
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue