Compare commits
2 Commits
b344ad77b6
...
ec21193f66
| Author | SHA1 | Date |
|---|---|---|
|
|
ec21193f66 | |
|
|
4d5ddf03f5 |
|
|
@ -17,4 +17,5 @@ var ProviderSetBiz = wire.NewSet(
|
|||
handle.NewHandle,
|
||||
do.NewDo,
|
||||
do.NewHandle,
|
||||
NewTaskBiz,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,35 @@
|
|||
package biz
|
||||
|
||||
import (
|
||||
"ai_scheduler/internal/data/impl"
|
||||
"ai_scheduler/internal/data/model"
|
||||
"ai_scheduler/internal/entitys"
|
||||
"context"
|
||||
|
||||
"xorm.io/builder"
|
||||
|
||||
"ai_scheduler/internal/config"
|
||||
)
|
||||
|
||||
type TaskBiz struct {
|
||||
taskRepo *impl.TaskImpl
|
||||
conf *config.Config
|
||||
}
|
||||
|
||||
func NewTaskBiz(conf *config.Config, taskRepo *impl.TaskImpl) *TaskBiz {
|
||||
return &TaskBiz{
|
||||
taskRepo: taskRepo,
|
||||
conf: conf,
|
||||
}
|
||||
}
|
||||
|
||||
// taskList 功能列表
|
||||
func (t *TaskBiz) TaskList(ctx context.Context, req *entitys.TaskRequest) (list []model.AiTask, err error) {
|
||||
cond := builder.NewCond()
|
||||
cond = cond.And(builder.Eq{"status": 1})
|
||||
|
||||
cond = cond.And(builder.Eq{"sys_id": req.SysId})
|
||||
err = t.taskRepo.GetRangeToMapStruct(&cond, &list)
|
||||
|
||||
return
|
||||
}
|
||||
|
|
@ -18,3 +18,7 @@ type SessionListRequest struct {
|
|||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
}
|
||||
|
||||
type TaskRequest struct {
|
||||
SysId int32 `json:"sys_id"`
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,14 +10,22 @@ import (
|
|||
"github.com/gofiber/fiber/v2/middleware/recover"
|
||||
)
|
||||
|
||||
type HTTPServer struct {
|
||||
app *fiber.App
|
||||
service *services.ChatService
|
||||
session *services.SessionService
|
||||
gateway *gateway.Gateway
|
||||
}
|
||||
|
||||
func NewHTTPServer(
|
||||
service *services.ChatService,
|
||||
session *services.SessionService,
|
||||
task *services.TaskService,
|
||||
gateway *gateway.Gateway,
|
||||
) *fiber.App {
|
||||
//构建 server
|
||||
app := initRoute()
|
||||
router.SetupRoutes(app, service, session, gateway)
|
||||
router.SetupRoutes(app, service, session, task, gateway)
|
||||
return app
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -14,8 +14,15 @@ import (
|
|||
"github.com/gofiber/websocket/v2"
|
||||
)
|
||||
|
||||
type RouterServer struct {
|
||||
app *fiber.App
|
||||
service *services.ChatService
|
||||
session *services.SessionService
|
||||
gateway *gateway.Gateway
|
||||
}
|
||||
|
||||
// SetupRoutes 设置路由
|
||||
func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionService *services.SessionService, gateway *gateway.Gateway) {
|
||||
func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionService *services.SessionService, task *services.TaskService, gateway *gateway.Gateway) {
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
// 设置 CORS 头
|
||||
c.Set("Access-Control-Allow-Origin", "*")
|
||||
|
|
@ -30,9 +37,40 @@ func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionServi
|
|||
// 继续处理后续中间件或路由
|
||||
return c.Next()
|
||||
})
|
||||
routerHttp(app, sessionService, gateway)
|
||||
//socket
|
||||
routerSocket(app, ChatService)
|
||||
|
||||
r := app.Group("api/v1/")
|
||||
registerResponse(r)
|
||||
// 注册 CORS 中间件
|
||||
r.Get("/health", func(c *fiber.Ctx) error {
|
||||
c.Response().SetBody([]byte("1"))
|
||||
return nil
|
||||
})
|
||||
|
||||
r.Post("/session/init", sessionService.SessionInit) // 会话初始化,不存在则创建,存在则返回会话ID和默认条数会话历史
|
||||
r.Post("/session/list", sessionService.SessionList)
|
||||
r.Post("/sys/tasks", task.Tasks)
|
||||
//广播
|
||||
r.Get("/broadcast", func(ctx *fiber.Ctx) error {
|
||||
action := ctx.Query("action")
|
||||
uid := ctx.Query("uid")
|
||||
msg := ctx.Query("msg")
|
||||
|
||||
switch action {
|
||||
case "sendToAll":
|
||||
gateway.SendToAll([]byte(msg))
|
||||
return ctx.SendString("sent to all")
|
||||
case "sendToUid":
|
||||
if uid == "" {
|
||||
return ctx.Status(400).SendString("missing uid")
|
||||
}
|
||||
gateway.SendToUid(uid, []byte(msg))
|
||||
return ctx.SendString(fmt.Sprintf("sent to uid %s", uid))
|
||||
default:
|
||||
return ctx.Status(400).SendString("unknown action")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func routerSocket(app *fiber.App, chatService *services.ChatService) {
|
||||
|
|
@ -58,39 +96,6 @@ var bufferPool = &sync.Pool{
|
|||
},
|
||||
}
|
||||
|
||||
func routerHttp(app *fiber.App, sessionService *services.SessionService, gateway *gateway.Gateway) {
|
||||
r := app.Group("api/v1/")
|
||||
registerResponse(r)
|
||||
// 注册 CORS 中间件
|
||||
r.Get("/health", func(c *fiber.Ctx) error {
|
||||
c.Response().SetBody([]byte("1"))
|
||||
return nil
|
||||
})
|
||||
|
||||
r.Post("/session/init", sessionService.SessionInit) // 会话初始化,不存在则创建,存在则返回会话ID和默认条数会话历史
|
||||
r.Post("/session/list", sessionService.SessionList)
|
||||
//广播
|
||||
r.Get("/broadcast", func(ctx *fiber.Ctx) error {
|
||||
action := ctx.Query("action")
|
||||
uid := ctx.Query("uid")
|
||||
msg := ctx.Query("msg")
|
||||
|
||||
switch action {
|
||||
case "sendToAll":
|
||||
gateway.SendToAll([]byte(msg))
|
||||
return ctx.SendString("sent to all")
|
||||
case "sendToUid":
|
||||
if uid == "" {
|
||||
return ctx.Status(400).SendString("missing uid")
|
||||
}
|
||||
gateway.SendToUid(uid, []byte(msg))
|
||||
return ctx.SendString(fmt.Sprintf("sent to uid %s", uid))
|
||||
default:
|
||||
return ctx.Status(400).SendString("unknown action")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func registerResponse(router fiber.Router) {
|
||||
// 自定义返回
|
||||
router.Use(func(c *fiber.Ctx) error {
|
||||
|
|
|
|||
|
|
@ -2,7 +2,8 @@ package services
|
|||
|
||||
import (
|
||||
"ai_scheduler/internal/gateway"
|
||||
|
||||
"github.com/google/wire"
|
||||
)
|
||||
|
||||
var ProviderSetServices = wire.NewSet(NewChatService, NewSessionService, gateway.NewGateway)
|
||||
var ProviderSetServices = wire.NewSet(NewChatService, NewSessionService, gateway.NewGateway, NewTaskService)
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ package services
|
|||
import (
|
||||
"ai_scheduler/internal/biz"
|
||||
"ai_scheduler/internal/entitys"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,35 @@
|
|||
package services
|
||||
|
||||
import (
|
||||
"ai_scheduler/internal/biz"
|
||||
"ai_scheduler/internal/entitys"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
type TaskService struct {
|
||||
taskBiz *biz.TaskBiz
|
||||
chatBiz *biz.ChatHistoryBiz
|
||||
}
|
||||
|
||||
func NewTaskService(sessionBiz *biz.SessionBiz, taskBiz *biz.TaskBiz) *TaskService {
|
||||
return &TaskService{
|
||||
taskBiz: taskBiz,
|
||||
}
|
||||
}
|
||||
|
||||
// Tasks 任务列表
|
||||
func (s *TaskService) Tasks(c *fiber.Ctx) error {
|
||||
req := &entitys.TaskRequest{}
|
||||
if err := c.BodyParser(req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
result, err := s.taskBiz.TaskList(c.Context(), req)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.JSON(result)
|
||||
}
|
||||
|
|
@ -89,6 +89,15 @@ func (k DataTemp) GetRange(cond *builder.Cond) (list []map[string]interface{}, e
|
|||
return list, err
|
||||
}
|
||||
|
||||
func (k DataTemp) GetRangeToMapStruct(cond *builder.Cond, data interface{}) (err error) {
|
||||
var (
|
||||
query, _ = builder.ToBoundSQL(*cond)
|
||||
model = k.Db.Model(k.Model).Where(query)
|
||||
)
|
||||
err = model.Find(data).Error
|
||||
return err
|
||||
}
|
||||
|
||||
func (k DataTemp) GetOneBySearch(cond *builder.Cond) (data map[string]interface{}, err error) {
|
||||
query, _ := builder.ToBoundSQL(*cond)
|
||||
err = k.Db.Model(k.Model).Where(query).Limit(1).Find(&data).Error
|
||||
|
|
|
|||
Loading…
Reference in New Issue