diff --git a/internal/biz/provider_set.go b/internal/biz/provider_set.go index c0e1d21..cc3de0a 100644 --- a/internal/biz/provider_set.go +++ b/internal/biz/provider_set.go @@ -17,4 +17,5 @@ var ProviderSetBiz = wire.NewSet( handle.NewHandle, do.NewDo, do.NewHandle, + NewTaskBiz, ) diff --git a/internal/biz/task.go b/internal/biz/task.go new file mode 100644 index 0000000..2c67591 --- /dev/null +++ b/internal/biz/task.go @@ -0,0 +1,35 @@ +package biz + +import ( + "ai_scheduler/internal/data/impl" + "ai_scheduler/internal/data/model" + "ai_scheduler/internal/entitys" + "context" + + "xorm.io/builder" + + "ai_scheduler/internal/config" +) + +type TaskBiz struct { + taskRepo *impl.TaskImpl + conf *config.Config +} + +func NewTaskBiz(conf *config.Config, taskRepo *impl.TaskImpl) *TaskBiz { + return &TaskBiz{ + taskRepo: taskRepo, + conf: conf, + } +} + +// taskList 功能列表 +func (t *TaskBiz) TaskList(ctx context.Context, req *entitys.TaskRequest) (list []model.AiTask, err error) { + cond := builder.NewCond() + cond = cond.And(builder.Eq{"status": 1}) + + cond = cond.And(builder.Eq{"sys_id": req.SysId}) + err = t.taskRepo.GetRangeToMapStruct(&cond, &list) + + return +} diff --git a/internal/entitys/session.go b/internal/entitys/session.go index 1b31ac1..74a517a 100644 --- a/internal/entitys/session.go +++ b/internal/entitys/session.go @@ -18,3 +18,7 @@ type SessionListRequest struct { Page int `json:"page"` PageSize int `json:"page_size"` } + +type TaskRequest struct { + SysId int32 `json:"sys_id"` +} diff --git a/internal/server/http.go b/internal/server/http.go index d874dfa..d4df3a1 100644 --- a/internal/server/http.go +++ b/internal/server/http.go @@ -10,14 +10,22 @@ import ( "github.com/gofiber/fiber/v2/middleware/recover" ) +type HTTPServer struct { + app *fiber.App + service *services.ChatService + session *services.SessionService + gateway *gateway.Gateway +} + func NewHTTPServer( service *services.ChatService, session *services.SessionService, + task *services.TaskService, gateway *gateway.Gateway, ) *fiber.App { //构建 server app := initRoute() - router.SetupRoutes(app, service, session, gateway) + router.SetupRoutes(app, service, session, task, gateway) return app } diff --git a/internal/server/router/router.go b/internal/server/router/router.go index 30a0a6a..04fbc7c 100644 --- a/internal/server/router/router.go +++ b/internal/server/router/router.go @@ -14,8 +14,15 @@ import ( "github.com/gofiber/websocket/v2" ) +type RouterServer struct { + app *fiber.App + service *services.ChatService + session *services.SessionService + gateway *gateway.Gateway +} + // SetupRoutes 设置路由 -func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionService *services.SessionService, gateway *gateway.Gateway) { +func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionService *services.SessionService, task *services.TaskService, gateway *gateway.Gateway) { app.Use(func(c *fiber.Ctx) error { // 设置 CORS 头 c.Set("Access-Control-Allow-Origin", "*") @@ -30,9 +37,40 @@ func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionServi // 继续处理后续中间件或路由 return c.Next() }) - routerHttp(app, sessionService, gateway) + //socket routerSocket(app, ChatService) + r := app.Group("api/v1/") + registerResponse(r) + // 注册 CORS 中间件 + r.Get("/health", func(c *fiber.Ctx) error { + c.Response().SetBody([]byte("1")) + return nil + }) + + r.Post("/session/init", sessionService.SessionInit) // 会话初始化,不存在则创建,存在则返回会话ID和默认条数会话历史 + r.Post("/session/list", sessionService.SessionList) + r.Post("/sys/tasks", task.Tasks) + //广播 + r.Get("/broadcast", func(ctx *fiber.Ctx) error { + action := ctx.Query("action") + uid := ctx.Query("uid") + msg := ctx.Query("msg") + + switch action { + case "sendToAll": + gateway.SendToAll([]byte(msg)) + return ctx.SendString("sent to all") + case "sendToUid": + if uid == "" { + return ctx.Status(400).SendString("missing uid") + } + gateway.SendToUid(uid, []byte(msg)) + return ctx.SendString(fmt.Sprintf("sent to uid %s", uid)) + default: + return ctx.Status(400).SendString("unknown action") + } + }) } func routerSocket(app *fiber.App, chatService *services.ChatService) { @@ -58,39 +96,6 @@ var bufferPool = &sync.Pool{ }, } -func routerHttp(app *fiber.App, sessionService *services.SessionService, gateway *gateway.Gateway) { - r := app.Group("api/v1/") - registerResponse(r) - // 注册 CORS 中间件 - r.Get("/health", func(c *fiber.Ctx) error { - c.Response().SetBody([]byte("1")) - return nil - }) - - r.Post("/session/init", sessionService.SessionInit) // 会话初始化,不存在则创建,存在则返回会话ID和默认条数会话历史 - r.Post("/session/list", sessionService.SessionList) - //广播 - r.Get("/broadcast", func(ctx *fiber.Ctx) error { - action := ctx.Query("action") - uid := ctx.Query("uid") - msg := ctx.Query("msg") - - switch action { - case "sendToAll": - gateway.SendToAll([]byte(msg)) - return ctx.SendString("sent to all") - case "sendToUid": - if uid == "" { - return ctx.Status(400).SendString("missing uid") - } - gateway.SendToUid(uid, []byte(msg)) - return ctx.SendString(fmt.Sprintf("sent to uid %s", uid)) - default: - return ctx.Status(400).SendString("unknown action") - } - }) -} - func registerResponse(router fiber.Router) { // 自定义返回 router.Use(func(c *fiber.Ctx) error { diff --git a/internal/services/provider_set.go b/internal/services/provider_set.go index f5b03d4..70987fa 100644 --- a/internal/services/provider_set.go +++ b/internal/services/provider_set.go @@ -2,7 +2,8 @@ package services import ( "ai_scheduler/internal/gateway" + "github.com/google/wire" ) -var ProviderSetServices = wire.NewSet(NewChatService, NewSessionService, gateway.NewGateway) +var ProviderSetServices = wire.NewSet(NewChatService, NewSessionService, gateway.NewGateway, NewTaskService) diff --git a/internal/services/session.go b/internal/services/session.go index ea4f464..6820fe5 100644 --- a/internal/services/session.go +++ b/internal/services/session.go @@ -3,6 +3,7 @@ package services import ( "ai_scheduler/internal/biz" "ai_scheduler/internal/entitys" + "github.com/gofiber/fiber/v2" ) diff --git a/internal/services/task.go b/internal/services/task.go new file mode 100644 index 0000000..380f027 --- /dev/null +++ b/internal/services/task.go @@ -0,0 +1,35 @@ +package services + +import ( + "ai_scheduler/internal/biz" + "ai_scheduler/internal/entitys" + + "github.com/gofiber/fiber/v2" +) + +type TaskService struct { + taskBiz *biz.TaskBiz + chatBiz *biz.ChatHistoryBiz +} + +func NewTaskService(sessionBiz *biz.SessionBiz, taskBiz *biz.TaskBiz) *TaskService { + return &TaskService{ + taskBiz: taskBiz, + } +} + +// Tasks 任务列表 +func (s *TaskService) Tasks(c *fiber.Ctx) error { + req := &entitys.TaskRequest{} + if err := c.BodyParser(req); err != nil { + return err + } + + result, err := s.taskBiz.TaskList(c.Context(), req) + + if err != nil { + return err + } + + return c.JSON(result) +} diff --git a/tmpl/dataTemp/queryTempl.go b/tmpl/dataTemp/queryTempl.go index e3d98e3..2415755 100644 --- a/tmpl/dataTemp/queryTempl.go +++ b/tmpl/dataTemp/queryTempl.go @@ -89,6 +89,15 @@ func (k DataTemp) GetRange(cond *builder.Cond) (list []map[string]interface{}, e return list, err } +func (k DataTemp) GetRangeToMapStruct(cond *builder.Cond, data interface{}) (err error) { + var ( + query, _ = builder.ToBoundSQL(*cond) + model = k.Db.Model(k.Model).Where(query) + ) + err = model.Find(data).Error + return err +} + func (k DataTemp) GetOneBySearch(cond *builder.Cond) (data map[string]interface{}, err error) { query, _ := builder.ToBoundSQL(*cond) err = k.Db.Model(k.Model).Where(query).Limit(1).Find(&data).Error