123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153 |
- package transport
- import (
- "context"
- "encoding/json"
- "errors"
- "github.com/go-kit/kit/log"
- "github.com/go-kit/kit/transport"
- kithttp "github.com/go-kit/kit/transport/http"
- "github.com/gorilla/mux"
- "github.com/longjoy/micro-go-course/section31/endpoint"
- "github.com/longjoy/micro-go-course/section31/service"
- "github.com/prometheus/client_golang/prometheus/promhttp"
- "net/http"
- )
// Sentinel errors produced by the transport layer while decoding or
// authorizing incoming requests.
var (
	// ErrorBadRequest reports an invalid request parameter.
	ErrorBadRequest = errors.New("invalid request parameter")

	// ErrorGrantTypeRequest reports a missing or invalid grant type.
	ErrorGrantTypeRequest = errors.New("invalid request grant type")

	// ErrorTokenRequest reports a missing or invalid token.
	ErrorTokenRequest = errors.New("invalid request token")

	// ErrInvalidClientRequest reports missing or rejected client credentials.
	ErrInvalidClientRequest = errors.New("invalid client message")
)
- // MakeHttpHandler make http handler use mux
- func MakeHttpHandler(ctx context.Context, endpoints endpoint.OAuth2Endpoints, tokenService service.TokenService, clientService service.ClientDetailsService, logger log.Logger) http.Handler {
- r := mux.NewRouter()
- options := []kithttp.ServerOption{
- kithttp.ServerErrorHandler(transport.NewLogErrorHandler(logger)),
- kithttp.ServerErrorEncoder(encodeError),
- }
- r.Path("/metrics").Handler(promhttp.Handler())
- clientAuthorizationOptions := []kithttp.ServerOption{
- kithttp.ServerBefore(makeClientAuthorizationContext(clientService, logger)),
- kithttp.ServerErrorHandler(transport.NewLogErrorHandler(logger)),
- kithttp.ServerErrorEncoder(encodeError),
- }
- r.Methods("POST").Path("/oauth/token").Handler(kithttp.NewServer(
- endpoints.TokenEndpoint,
- decodeTokenRequest,
- encodeJsonResponse,
- clientAuthorizationOptions...,
- ))
- r.Methods("POST").Path("/oauth/check_token").Handler(kithttp.NewServer(
- endpoints.CheckTokenEndpoint,
- decodeCheckTokenRequest,
- encodeJsonResponse,
- clientAuthorizationOptions...,
- ))
- // create health check handler
- r.Methods("GET").Path("/health").Handler(kithttp.NewServer(
- endpoints.HealthCheckEndpoint,
- decodeHealthCheckRequest,
- encodeJsonResponse,
- options...,
- ))
- return r
- }
- func makeOAuth2AuthorizationContext(tokenService service.TokenService, logger log.Logger) kithttp.RequestFunc {
- return func(ctx context.Context, r *http.Request) context.Context {
- // 获取访问令牌
- accessTokenValue := r.Header.Get("Authorization")
- var err error
- if accessTokenValue != ""{
- // 获取令牌对应的用户信息和客户端信息
- oauth2Details, err := tokenService.GetOAuth2DetailsByAccessToken(accessTokenValue)
- if err == nil {
- return context.WithValue(ctx, endpoint.OAuth2DetailsKey, oauth2Details)
- }
- }else {
- err = ErrorTokenRequest
- }
- return context.WithValue(ctx, endpoint.OAuth2ErrorKey, err)
- }
- }
- func makeClientAuthorizationContext(clientDetailsService service.ClientDetailsService, logger log.Logger) kithttp.RequestFunc {
- return func(ctx context.Context, r *http.Request) context.Context {
- if clientId, clientSecret, ok := r.BasicAuth(); ok {
- clientDetails, err := clientDetailsService.GetClientDetailsByClientId(ctx, clientId, clientSecret)
- if err != nil{
- return context.WithValue(ctx, endpoint.OAuth2ErrorKey, ErrInvalidClientRequest)
- }
- return context.WithValue(ctx, endpoint.OAuth2ClientDetailsKey, clientDetails)
- }
- return context.WithValue(ctx, endpoint.OAuth2ErrorKey, ErrInvalidClientRequest)
- }
- }
- func decodeTokenRequest(ctx context.Context, r *http.Request) (interface{}, error) {
- grantType := r.URL.Query().Get("grant_type")
- if grantType == ""{
- return nil, ErrorGrantTypeRequest
- }
- return &endpoint.TokenRequest{
- GrantType:grantType,
- Reader:r,
- }, nil
- }
- func decodeCheckTokenRequest(ctx context.Context, r *http.Request) (interface{}, error) {
- tokenValue := r.URL.Query().Get("token")
- if tokenValue == ""{
- return nil, ErrorTokenRequest
- }
- return &endpoint.CheckTokenRequest{
- Token:tokenValue,
- }, nil
- }
// encodeJsonResponse writes response to w as JSON after setting the
// Content-Type header. It returns whatever error the JSON encoder reports.
func encodeJsonResponse(ctx context.Context, w http.ResponseWriter, response interface{}) error {
	header := w.Header()
	header.Set("Content-Type", "application/json;charset=utf-8")

	encoder := json.NewEncoder(w)
	return encoder.Encode(response)
}
- // decodeHealthCheckRequest decode request
- func decodeHealthCheckRequest(ctx context.Context, r *http.Request) (interface{}, error) {
- return endpoint.HealthRequest{}, nil
- }
// encodeError renders an error as a JSON body of the form
// {"error": "<message>"}.
//
// Every error is reported with HTTP status 500 (Internal Server Error); no
// per-error status mapping exists. err must be non-nil because its message is
// read unconditionally.
func encodeError(_ context.Context, err error, w http.ResponseWriter) {
	w.Header().Set("Content-Type", "application/json; charset=utf-8")
	w.WriteHeader(http.StatusInternalServerError)

	payload := map[string]interface{}{
		"error": err.Error(),
	}
	// The status line has already been sent, so an encoding failure cannot
	// be reported to the client; the result is intentionally discarded.
	_ = json.NewEncoder(w).Encode(payload)
}
|