99 lines
2.9 KiB
Go
99 lines
2.9 KiB
Go
package server
|
|
|
|
import (
|
|
"center-api/api/apierr"
|
|
"center-api/internal/conf"
|
|
log2 "center-api/internal/pkg/log"
|
|
"context"
|
|
"encoding/json"
|
|
"github.com/go-kratos/kratos/v2/errors"
|
|
"github.com/go-kratos/kratos/v2/log"
|
|
"github.com/go-kratos/kratos/v2/middleware"
|
|
"github.com/go-kratos/kratos/v2/middleware/logging"
|
|
"github.com/go-kratos/kratos/v2/middleware/recovery"
|
|
"github.com/go-kratos/kratos/v2/middleware/validate"
|
|
"github.com/go-kratos/kratos/v2/transport/http"
|
|
"github.com/gorilla/handlers"
|
|
http2 "net/http"
|
|
)
|
|
|
|
// NewHTTPServer new an HTTP server.
|
|
func NewHTTPServer(
|
|
c *conf.Bootstrap,
|
|
hLogger *log.Helper,
|
|
accessLogger *log2.AccessLogger,
|
|
) *http.Server {
|
|
//构建 server
|
|
srv := buildHTTPServer(c, accessLogger)
|
|
|
|
return srv
|
|
}
|
|
|
|
// buildHTTPServer 构建 HTTP Server
|
|
func buildHTTPServer(c *conf.Bootstrap, accessLogger *log2.AccessLogger) *http.Server {
|
|
var opts []http.ServerOption
|
|
|
|
if c.Server.Http.Network != "" {
|
|
opts = append(opts, http.Network(c.Server.Http.Network))
|
|
}
|
|
if c.Server.Http.Addr != "" {
|
|
opts = append(opts, http.Address(c.Server.Http.Addr))
|
|
}
|
|
if c.Server.Http.Timeout != nil {
|
|
opts = append(opts, http.Timeout(c.Server.Http.Timeout.AsDuration()))
|
|
}
|
|
|
|
// 允许跨域
|
|
filters := []http.FilterFunc{handlers.CORS(
|
|
// 域名配置
|
|
handlers.AllowedOrigins([]string{"*"}),
|
|
handlers.AllowedHeaders([]string{"Content-Type", "X-Requested-With", "Authorization"}),
|
|
handlers.AllowedMethods([]string{"GET", "POST", "PUT", "HEAD", "OPTIONS", "DELETE"}),
|
|
)}
|
|
if c.Server.Http.IsResponseReqHeaders {
|
|
//将请求的 header 头信息返回给客户端
|
|
filters = append(filters, responseReqHeadersHandler)
|
|
}
|
|
opts = append(opts, http.Filter(filters...))
|
|
|
|
//拦截恢复异常
|
|
middlewares := []middleware.Middleware{
|
|
recovery.Recovery(recovery.WithHandler(func(ctx context.Context, req, err interface{}) error {
|
|
//做一些panic处理
|
|
newErr, isOk := err.(*errors.Error)
|
|
if isOk {
|
|
//如果是自己定义的类型,则直接抛出去,目的是支持 panic(myErr) 类型的写法,提升开发速度
|
|
return newErr
|
|
}
|
|
return apierr.ErrorSystemPanic("系统错误,请重试")
|
|
})),
|
|
}
|
|
//打印访问日志
|
|
if accessLogger != nil {
|
|
middlewares = append(middlewares, logging.Server(*accessLogger))
|
|
}
|
|
//加入 validator
|
|
middlewares = append(middlewares, validate.Validator())
|
|
|
|
opts = append(opts, http.Middleware(middlewares...))
|
|
|
|
//注册 service到 http 上
|
|
srv := http.NewServer(opts...)
|
|
|
|
//注册 openapi
|
|
if c.Server.Http.GetIsOpenSwagger() {
|
|
srv.HandlePrefix("/doc/", http2.StripPrefix("/doc/", http2.FileServer(http2.Dir("./third_party/swagger_ui"))))
|
|
}
|
|
|
|
return srv
|
|
}
|
|
|
|
// responseReqHeadersHandler wraps h so that each response carries a
// "Req-Headers" header containing the request's headers encoded as JSON.
//
// The header is added to the response before h is invoked. All request
// headers are echoed as they are; nothing is filtered out.
func responseReqHeadersHandler(h http2.Handler) http2.Handler {
	echo := func(w http2.ResponseWriter, r *http2.Request) {
		// r.Header is a map[string][]string; the marshal error is
		// deliberately discarded, as no failure is expected for that type.
		encoded, _ := json.Marshal(r.Header)
		w.Header().Add("Req-Headers", string(encoded))
		h.ServeHTTP(w, r)
	}
	return http2.HandlerFunc(echo)
}
|