641a0fdbcddd4f19c2c3ad05fe9bbf4a6d88f8fa.svn-base 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. package transport
  2. import (
  3. "context"
  4. "encoding/json"
  5. "errors"
  6. "github.com/go-kit/kit/log"
  7. "github.com/go-kit/kit/transport"
  8. kithttp "github.com/go-kit/kit/transport/http"
  9. "github.com/gorilla/mux"
  10. "github.com/longjoy/micro-go-course/section31/endpoint"
  11. "github.com/longjoy/micro-go-course/section31/service"
  12. "github.com/prometheus/client_golang/prometheus/promhttp"
  13. "net/http"
  14. )
  // Sentinel errors produced by the request decoders and request-context
// builders in this transport package.
var (
	// ErrorBadRequest is a generic invalid-parameter error; it is not
	// referenced by the decoders in this part of the file.
	ErrorBadRequest = errors.New("invalid request parameter")

	// ErrorGrantTypeRequest is returned by decodeTokenRequest when the
	// "grant_type" query parameter is missing or empty.
	ErrorGrantTypeRequest = errors.New("invalid request grant type")

	// ErrorTokenRequest signals a missing token: an empty "token" query
	// parameter in decodeCheckTokenRequest, or an empty Authorization header
	// in makeOAuth2AuthorizationContext.
	ErrorTokenRequest = errors.New("invalid request token")

	// ErrInvalidClientRequest is stored in the request context when the
	// calling client cannot be authenticated from its Basic auth credentials.
	ErrInvalidClientRequest = errors.New("invalid client message")
)
  21. // MakeHttpHandler make http handler use mux
  22. func MakeHttpHandler(ctx context.Context, endpoints endpoint.OAuth2Endpoints, tokenService service.TokenService, clientService service.ClientDetailsService, logger log.Logger) http.Handler {
  23. r := mux.NewRouter()
  24. options := []kithttp.ServerOption{
  25. kithttp.ServerErrorHandler(transport.NewLogErrorHandler(logger)),
  26. kithttp.ServerErrorEncoder(encodeError),
  27. }
  28. r.Path("/metrics").Handler(promhttp.Handler())
  29. clientAuthorizationOptions := []kithttp.ServerOption{
  30. kithttp.ServerBefore(makeClientAuthorizationContext(clientService, logger)),
  31. kithttp.ServerErrorHandler(transport.NewLogErrorHandler(logger)),
  32. kithttp.ServerErrorEncoder(encodeError),
  33. }
  34. r.Methods("POST").Path("/oauth/token").Handler(kithttp.NewServer(
  35. endpoints.TokenEndpoint,
  36. decodeTokenRequest,
  37. encodeJsonResponse,
  38. clientAuthorizationOptions...,
  39. ))
  40. r.Methods("POST").Path("/oauth/check_token").Handler(kithttp.NewServer(
  41. endpoints.CheckTokenEndpoint,
  42. decodeCheckTokenRequest,
  43. encodeJsonResponse,
  44. clientAuthorizationOptions...,
  45. ))
  46. // create health check handler
  47. r.Methods("GET").Path("/health").Handler(kithttp.NewServer(
  48. endpoints.HealthCheckEndpoint,
  49. decodeHealthCheckRequest,
  50. encodeJsonResponse,
  51. options...,
  52. ))
  53. return r
  54. }
  55. func makeOAuth2AuthorizationContext(tokenService service.TokenService, logger log.Logger) kithttp.RequestFunc {
  56. return func(ctx context.Context, r *http.Request) context.Context {
  57. // 获取访问令牌
  58. accessTokenValue := r.Header.Get("Authorization")
  59. var err error
  60. if accessTokenValue != ""{
  61. // 获取令牌对应的用户信息和客户端信息
  62. oauth2Details, err := tokenService.GetOAuth2DetailsByAccessToken(accessTokenValue)
  63. if err == nil {
  64. return context.WithValue(ctx, endpoint.OAuth2DetailsKey, oauth2Details)
  65. }
  66. }else {
  67. err = ErrorTokenRequest
  68. }
  69. return context.WithValue(ctx, endpoint.OAuth2ErrorKey, err)
  70. }
  71. }
  72. func makeClientAuthorizationContext(clientDetailsService service.ClientDetailsService, logger log.Logger) kithttp.RequestFunc {
  73. return func(ctx context.Context, r *http.Request) context.Context {
  74. if clientId, clientSecret, ok := r.BasicAuth(); ok {
  75. clientDetails, err := clientDetailsService.GetClientDetailsByClientId(ctx, clientId, clientSecret)
  76. if err != nil{
  77. return context.WithValue(ctx, endpoint.OAuth2ErrorKey, ErrInvalidClientRequest)
  78. }
  79. return context.WithValue(ctx, endpoint.OAuth2ClientDetailsKey, clientDetails)
  80. }
  81. return context.WithValue(ctx, endpoint.OAuth2ErrorKey, ErrInvalidClientRequest)
  82. }
  83. }
  84. func decodeTokenRequest(ctx context.Context, r *http.Request) (interface{}, error) {
  85. grantType := r.URL.Query().Get("grant_type")
  86. if grantType == ""{
  87. return nil, ErrorGrantTypeRequest
  88. }
  89. return &endpoint.TokenRequest{
  90. GrantType:grantType,
  91. Reader:r,
  92. }, nil
  93. }
  94. func decodeCheckTokenRequest(ctx context.Context, r *http.Request) (interface{}, error) {
  95. tokenValue := r.URL.Query().Get("token")
  96. if tokenValue == ""{
  97. return nil, ErrorTokenRequest
  98. }
  99. return &endpoint.CheckTokenRequest{
  100. Token:tokenValue,
  101. }, nil
  102. }
  // encodeJsonResponse is the go-kit response encoder shared by all routes. It
// sets Content-Type to "application/json;charset=utf-8" and writes response
// as JSON; no status code is set explicitly. Any encoding or write error is
// returned to the caller.
func encodeJsonResponse(_ context.Context, w http.ResponseWriter, response interface{}) error {
	w.Header().Set("Content-Type", "application/json;charset=utf-8")
	enc := json.NewEncoder(w)
	return enc.Encode(response)
}
  107. // decodeHealthCheckRequest decode request
  108. func decodeHealthCheckRequest(ctx context.Context, r *http.Request) (interface{}, error) {
  109. return endpoint.HealthRequest{}, nil
  110. }
  // encodeError is the go-kit ServerErrorEncoder shared by every route. It sets
// Content-Type to "application/json; charset=utf-8", responds with
// 500 Internal Server Error regardless of err, and writes the body
// {"error": err.Error()}.
//
// NOTE(review): client-side failures such as ErrorGrantTypeRequest,
// ErrorTokenRequest or ErrInvalidClientRequest would also be reported as 500;
// mapping them to 4xx statuses (RFC 6749 §5.2) is likely intended — confirm.
func encodeError(_ context.Context, err error, w http.ResponseWriter) {
	w.Header().Set("Content-Type", "application/json; charset=utf-8")
	w.WriteHeader(http.StatusInternalServerError)

	payload := map[string]interface{}{"error": err.Error()}
	// The status line has already been sent, so a failed body write cannot be
	// reported to the client; the encoding error is deliberately discarded.
	_ = json.NewEncoder(w).Encode(payload)
}