Compare commits

..

2 Commits

Author SHA1 Message Date
renzhiyuan ec21193f66 Merge branch 'v2' 2025-11-13 10:02:38 +08:00
renzhiyuan 4d5ddf03f5 feat: 新增任务管理功能 2025-11-13 10:02:10 +08:00
9 changed files with 136 additions and 37 deletions

View File

@ -17,4 +17,5 @@ var ProviderSetBiz = wire.NewSet(
handle.NewHandle,
do.NewDo,
do.NewHandle,
NewTaskBiz,
)

35
internal/biz/task.go Normal file
View File

@ -0,0 +1,35 @@
package biz
import (
"ai_scheduler/internal/data/impl"
"ai_scheduler/internal/data/model"
"ai_scheduler/internal/entitys"
"context"
"xorm.io/builder"
"ai_scheduler/internal/config"
)
type TaskBiz struct {
taskRepo *impl.TaskImpl
conf *config.Config
}
func NewTaskBiz(conf *config.Config, taskRepo *impl.TaskImpl) *TaskBiz {
return &TaskBiz{
taskRepo: taskRepo,
conf: conf,
}
}
// taskList 功能列表
func (t *TaskBiz) TaskList(ctx context.Context, req *entitys.TaskRequest) (list []model.AiTask, err error) {
cond := builder.NewCond()
cond = cond.And(builder.Eq{"status": 1})
cond = cond.And(builder.Eq{"sys_id": req.SysId})
err = t.taskRepo.GetRangeToMapStruct(&cond, &list)
return
}

View File

@ -18,3 +18,7 @@ type SessionListRequest struct {
Page int `json:"page"`
PageSize int `json:"page_size"`
}
type TaskRequest struct {
SysId int32 `json:"sys_id"`
}

View File

@ -10,14 +10,22 @@ import (
"github.com/gofiber/fiber/v2/middleware/recover"
)
type HTTPServer struct {
app *fiber.App
service *services.ChatService
session *services.SessionService
gateway *gateway.Gateway
}
func NewHTTPServer(
service *services.ChatService,
session *services.SessionService,
task *services.TaskService,
gateway *gateway.Gateway,
) *fiber.App {
//构建 server
app := initRoute()
router.SetupRoutes(app, service, session, gateway)
router.SetupRoutes(app, service, session, task, gateway)
return app
}

View File

@ -14,8 +14,15 @@ import (
"github.com/gofiber/websocket/v2"
)
type RouterServer struct {
app *fiber.App
service *services.ChatService
session *services.SessionService
gateway *gateway.Gateway
}
// SetupRoutes 设置路由
func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionService *services.SessionService, gateway *gateway.Gateway) {
func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionService *services.SessionService, task *services.TaskService, gateway *gateway.Gateway) {
app.Use(func(c *fiber.Ctx) error {
// 设置 CORS 头
c.Set("Access-Control-Allow-Origin", "*")
@ -30,9 +37,40 @@ func SetupRoutes(app *fiber.App, ChatService *services.ChatService, sessionServi
// 继续处理后续中间件或路由
return c.Next()
})
routerHttp(app, sessionService, gateway)
//socket
routerSocket(app, ChatService)
r := app.Group("api/v1/")
registerResponse(r)
// 注册 CORS 中间件
r.Get("/health", func(c *fiber.Ctx) error {
c.Response().SetBody([]byte("1"))
return nil
})
r.Post("/session/init", sessionService.SessionInit) // 会话初始化,不存在则创建,存在则返回会话ID和默认条数会话历史
r.Post("/session/list", sessionService.SessionList)
r.Post("/sys/tasks", task.Tasks)
//广播
r.Get("/broadcast", func(ctx *fiber.Ctx) error {
action := ctx.Query("action")
uid := ctx.Query("uid")
msg := ctx.Query("msg")
switch action {
case "sendToAll":
gateway.SendToAll([]byte(msg))
return ctx.SendString("sent to all")
case "sendToUid":
if uid == "" {
return ctx.Status(400).SendString("missing uid")
}
gateway.SendToUid(uid, []byte(msg))
return ctx.SendString(fmt.Sprintf("sent to uid %s", uid))
default:
return ctx.Status(400).SendString("unknown action")
}
})
}
func routerSocket(app *fiber.App, chatService *services.ChatService) {
@ -58,39 +96,6 @@ var bufferPool = &sync.Pool{
},
}
func routerHttp(app *fiber.App, sessionService *services.SessionService, gateway *gateway.Gateway) {
r := app.Group("api/v1/")
registerResponse(r)
// 注册 CORS 中间件
r.Get("/health", func(c *fiber.Ctx) error {
c.Response().SetBody([]byte("1"))
return nil
})
r.Post("/session/init", sessionService.SessionInit) // 会话初始化,不存在则创建,存在则返回会话ID和默认条数会话历史
r.Post("/session/list", sessionService.SessionList)
//广播
r.Get("/broadcast", func(ctx *fiber.Ctx) error {
action := ctx.Query("action")
uid := ctx.Query("uid")
msg := ctx.Query("msg")
switch action {
case "sendToAll":
gateway.SendToAll([]byte(msg))
return ctx.SendString("sent to all")
case "sendToUid":
if uid == "" {
return ctx.Status(400).SendString("missing uid")
}
gateway.SendToUid(uid, []byte(msg))
return ctx.SendString(fmt.Sprintf("sent to uid %s", uid))
default:
return ctx.Status(400).SendString("unknown action")
}
})
}
func registerResponse(router fiber.Router) {
// 自定义返回
router.Use(func(c *fiber.Ctx) error {

View File

@ -2,7 +2,8 @@ package services
import (
"ai_scheduler/internal/gateway"
"github.com/google/wire"
)
var ProviderSetServices = wire.NewSet(NewChatService, NewSessionService, gateway.NewGateway)
var ProviderSetServices = wire.NewSet(NewChatService, NewSessionService, gateway.NewGateway, NewTaskService)

View File

@ -3,6 +3,7 @@ package services
import (
"ai_scheduler/internal/biz"
"ai_scheduler/internal/entitys"
"github.com/gofiber/fiber/v2"
)

35
internal/services/task.go Normal file
View File

@ -0,0 +1,35 @@
package services
import (
"ai_scheduler/internal/biz"
"ai_scheduler/internal/entitys"
"github.com/gofiber/fiber/v2"
)
type TaskService struct {
taskBiz *biz.TaskBiz
chatBiz *biz.ChatHistoryBiz
}
func NewTaskService(sessionBiz *biz.SessionBiz, taskBiz *biz.TaskBiz) *TaskService {
return &TaskService{
taskBiz: taskBiz,
}
}
// Tasks 任务列表
func (s *TaskService) Tasks(c *fiber.Ctx) error {
req := &entitys.TaskRequest{}
if err := c.BodyParser(req); err != nil {
return err
}
result, err := s.taskBiz.TaskList(c.Context(), req)
if err != nil {
return err
}
return c.JSON(result)
}

View File

@ -89,6 +89,15 @@ func (k DataTemp) GetRange(cond *builder.Cond) (list []map[string]interface{}, e
return list, err
}
func (k DataTemp) GetRangeToMapStruct(cond *builder.Cond, data interface{}) (err error) {
var (
query, _ = builder.ToBoundSQL(*cond)
model = k.Db.Model(k.Model).Where(query)
)
err = model.Find(data).Error
return err
}
func (k DataTemp) GetOneBySearch(cond *builder.Cond) (data map[string]interface{}, err error) {
query, _ := builder.ToBoundSQL(*cond)
err = k.Db.Model(k.Model).Where(query).Limit(1).Find(&data).Error