158 lines
4.3 KiB
Go
158 lines
4.3 KiB
Go
package server
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"github.com/go-kratos/kratos/v2/errors"
|
||
"github.com/go-kratos/kratos/v2/log"
|
||
"github.com/go-kratos/kratos/v2/middleware"
|
||
"github.com/go-kratos/kratos/v2/middleware/logging"
|
||
"github.com/go-kratos/kratos/v2/middleware/recovery"
|
||
"github.com/go-kratos/kratos/v2/middleware/validate"
|
||
"github.com/go-kratos/kratos/v2/transport/http"
|
||
"github.com/gorilla/handlers"
|
||
http2 "net/http"
|
||
"voucher/internal/conf"
|
||
log2 "voucher/internal/pkg/log"
|
||
"voucher/internal/service"
|
||
)
|
||
|
||
// NewHTTPServer builds the HTTP server from the bootstrap configuration and
// registers the built-in routes on it. Currently the only route is
// GET /ping, which answers with the plain-text body "pong".
//
// voucherService is accepted (presumably for dependency injection) but is not
// registered on any route yet.
func NewHTTPServer(
	c *conf.Bootstrap,
	log *log.Helper,
	accessLogger *log2.AccessLogger,
	voucherService *service.VoucherService,
) *http.Server {
	server := buildHTTPServer(c, accessLogger, log)

	ping := func(ctx http.Context) error {
		return ctx.String(http2.StatusOK, "pong")
	}
	server.Route("/").GET("/ping", ping)

	return server
}
|
||
|
||
// buildHTTPServer assembles the kratos HTTP server described by c.Server.Http.
//
// It:
//   - applies the optional network, address and timeout settings (empty/nil
//     values leave the kratos defaults in place);
//   - installs a CORS filter, plus a filter that echoes request headers back
//     when IsResponseReqHeaders is set;
//   - wires the recovery, validation and (when accessLogger is non-nil)
//     access-log middlewares;
//   - installs the envelope response/error encoders;
//   - mounts the Swagger UI from ./third_party/swagger_ui under /doc/ when
//     IsOpenSwagger is enabled.
//
// c.Server and c.Server.Http are dereferenced directly and must be non-nil.
func buildHTTPServer(c *conf.Bootstrap, accessLogger *log2.AccessLogger, log *log.Helper) *http.Server {
	httpConf := c.Server.Http

	var opts []http.ServerOption
	if httpConf.Network != "" {
		opts = append(opts, http.Network(httpConf.Network))
	}
	if httpConf.Addr != "" {
		opts = append(opts, http.Address(httpConf.Addr))
	}
	if httpConf.Timeout != nil {
		opts = append(opts, http.Timeout(httpConf.Timeout.AsDuration()))
	}

	// Allow cross-origin requests from any origin.
	filters := []http.FilterFunc{handlers.CORS(
		handlers.AllowedOrigins([]string{"*"}),
		handlers.AllowedHeaders([]string{"Content-Type", "X-Requested-With", "Authorization"}),
		handlers.AllowedMethods([]string{"GET", "POST", "PUT", "HEAD", "OPTIONS", "DELETE"}),
	)}
	if httpConf.IsResponseReqHeaders {
		// Echo the incoming request headers back to the client.
		filters = append(filters, responseReqHeadersHandler)
	}
	opts = append(opts, http.Filter(filters...))

	middlewares := []middleware.Middleware{
		// Turn panics raised by handlers into errors.
		recovery.Recovery(
			recovery.WithHandler(func(ctx context.Context, req, err interface{}) error {
				// A project *errors.Error is passed through unchanged, so handlers
				// may use panic(myErr) as a shortcut for returning it.
				if bizErr, ok := err.(*errors.Error); ok {
					return bizErr
				}
				// Any other panic is unexpected: keep the full details in the
				// server log, but do not echo the request or the panic value back
				// to the client (the error encoder writes the message verbatim).
				log.Errorf("系统错误,req=%+v, error=%+v", req, err)
				return errors.InternalServer("SYSTEM_ERROR", "系统错误")
			}),
		),
		validate.Validator(),
	}

	// NOTE(review): the access logger sits inside the validator, so requests
	// rejected by validation never reach it and are not access-logged.
	if accessLogger != nil {
		middlewares = append(middlewares, logging.Server(*accessLogger))
	}

	opts = append(opts,
		http.Middleware(middlewares...),
		http.ResponseEncoder(responseEncoder),
		http.ErrorEncoder(errorEncoder()),
	)

	// Create the server; services are registered on it by the caller.
	srv := http.NewServer(opts...)

	// Serve the OpenAPI / Swagger UI static files.
	if httpConf.GetIsOpenSwagger() {
		srv.HandlePrefix("/doc/", http2.StripPrefix("/doc/", http2.FileServer(http2.Dir("./third_party/swagger_ui"))))
	}

	return srv
}
|
||
|
||
// responseReqHeadersHandler is an HTTP filter. It JSON-encodes the incoming
// request headers, adds the result to the response as the "Req-Headers"
// header, and then hands the request to h.
//
// NOTE(review): this echoes every request header, including Authorization
// and Cookie, back to the client. It should stay a debug-only option.
func responseReqHeadersHandler(h http2.Handler) http2.Handler {
	echo := func(w http2.ResponseWriter, r *http2.Request) {
		// Marshalling a map[string][]string cannot fail, so the error is ignored.
		encoded, _ := json.Marshal(r.Header)
		w.Header().Add("Req-Headers", string(encoded))
		h.ServeHTTP(w, r)
	}
	return http2.HandlerFunc(echo)
}
|
||
|
||
func responseEncoder(w http.ResponseWriter, r *http.Request, data interface{}) error {
|
||
if data == nil {
|
||
return nil
|
||
}
|
||
codec, _ := http.CodecForRequest(r, "Accept")
|
||
dataByte, err := codec.Marshal(data)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
w.Header().Set("Content-Type", fmt.Sprintf("application/%s", codec.Name()))
|
||
_, err = w.Write([]byte(fmt.Sprintf(`{"code":%d,"data":%s,"message":"%s"}`, http2.StatusOK, string(dataByte), "成功")))
|
||
|
||
return err
|
||
}
|
||
|
||
func errorEncoder() http.EncodeErrorFunc {
|
||
return func(w http2.ResponseWriter, r *http2.Request, err error) {
|
||
var message string
|
||
var code int32
|
||
|
||
se := errors.FromError(err)
|
||
message = se.Message
|
||
code = se.Code
|
||
|
||
w.Header().Set("Content-Type", "application/json")
|
||
switch code {
|
||
case http2.StatusUnauthorized:
|
||
w.WriteHeader(http2.StatusUnauthorized)
|
||
default:
|
||
w.WriteHeader(http2.StatusOK)
|
||
}
|
||
respBody := map[string]interface{}{
|
||
"code": code,
|
||
"data": "",
|
||
"message": message,
|
||
}
|
||
body, _ := json.Marshal(respBody)
|
||
_, _ = w.Write(body)
|
||
}
|
||
}
|