123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471 |
- package service
- import (
- "context"
- "errors"
- "github.com/dgrijalva/jwt-go"
- "github.com/longjoy/micro-go-course/section31/model"
- uuid "github.com/satori/go.uuid"
- "net/http"
- "strconv"
- "time"
- )
// ErrNotSupportGrantType is returned when a granter is asked to handle a grant
// type other than the one it was built for, or when no granter is registered
// for the requested grant type.
var ErrNotSupportGrantType = errors.New("grant type is not supported")

// ErrNotSupportOperation is returned when the client is not authorized to use
// the requested grant type.
var ErrNotSupportOperation = errors.New("no support operation")

// ErrInvalidUsernameAndPasswordRequest is returned when the username or
// password is missing from the request, or the user lookup fails.
var ErrInvalidUsernameAndPasswordRequest = errors.New("invalid username, password")

// ErrInvalidTokenRequest is returned when a request does not carry the token
// it needs.
var ErrInvalidTokenRequest = errors.New("invalid token")

// ErrExpiredToken is returned when a presented token has already expired.
var ErrExpiredToken = errors.New("token is expired")
- type TokenGranter interface {
- Grant(ctx context.Context, grantType string, client model.ClientDetails, reader *http.Request) (*model.OAuth2Token, error)
- }
- type ComposeTokenGranter struct {
- TokenGrantDict map[string] TokenGranter
- }
- func NewComposeTokenGranter(tokenGrantDict map[string] TokenGranter) TokenGranter {
- return &ComposeTokenGranter{
- TokenGrantDict:tokenGrantDict,
- }
- }
- func (tokenGranter *ComposeTokenGranter) Grant(ctx context.Context, grantType string, client model.ClientDetails, reader *http.Request) (*model.OAuth2Token, error) {
- // 检查客户端是否允许该种授权类型
- var isSupport bool
- if len(client.AuthorizedGrantTypes) > 0 {
- for _, v := range client.AuthorizedGrantTypes {
- if v == grantType {
- isSupport = true
- break
- }
- }
- }
- if !isSupport{
- return nil, ErrNotSupportOperation
- }
- // 查找具体的授权类型实现节点
- dispatchGranter,ok := tokenGranter.TokenGrantDict[grantType]; if ok{
- return dispatchGranter.Grant(ctx, grantType, client, reader)
- } else{
- return nil, ErrNotSupportGrantType
- }
- }
- type UsernamePasswordTokenGranter struct {
- supportGrantType string
- userDetailsService UserDetailsService
- tokenService TokenService
- }
- func NewUsernamePasswordTokenGranter(grantType string, userDetailsService UserDetailsService, tokenService TokenService) TokenGranter {
- return &UsernamePasswordTokenGranter{
- supportGrantType:grantType,
- userDetailsService:userDetailsService,
- tokenService:tokenService,
- }
- }
- func (tokenGranter *UsernamePasswordTokenGranter) Grant(ctx context.Context,
- grantType string, client model.ClientDetails, reader *http.Request) (*model.OAuth2Token, error) {
- if grantType != tokenGranter.supportGrantType{
- return nil, ErrNotSupportGrantType
- }
- // 从请求体中获取用户名密码
- username := reader.FormValue("username")
- password := reader.FormValue("password")
- if username == "" || password == ""{
- return nil, ErrInvalidUsernameAndPasswordRequest
- }
- // 验证用户名密码是否正确
- userDetails, err := tokenGranter.userDetailsService.GetUserDetailByUsername(ctx, username, password)
- if err != nil{
- return nil, ErrInvalidUsernameAndPasswordRequest
- }
- // 根据用户信息和客户端信息生成访问令牌
- return tokenGranter.tokenService.CreateAccessToken(&model.OAuth2Details{
- Client:client,
- User:userDetails,
- })
- }
- type RefreshTokenGranter struct {
- supportGrantType string
- tokenService TokenService
- }
- func NewRefreshGranter(grantType string, userDetailsService UserDetailsService, tokenService TokenService) TokenGranter {
- return &RefreshTokenGranter{
- supportGrantType:grantType,
- tokenService:tokenService,
- }
- }
- func (tokenGranter *RefreshTokenGranter) Grant(ctx context.Context, grantType string, client model.ClientDetails, reader *http.Request) (*model.OAuth2Token, error) {
- if grantType != tokenGranter.supportGrantType{
- return nil, ErrNotSupportGrantType
- }
- // 从请求中获取刷新令牌
- refreshTokenValue := reader.URL.Query().Get("refresh_token")
- if refreshTokenValue == ""{
- return nil, ErrInvalidTokenRequest
- }
- return tokenGranter.tokenService.RefreshAccessToken(refreshTokenValue)
- }
- type TokenService interface {
- // 根据访问令牌获取对应的用户信息和客户端信息
- GetOAuth2DetailsByAccessToken(tokenValue string) (*model.OAuth2Details, error)
- // 根据用户信息和客户端信息生成访问令牌
- CreateAccessToken(oauth2Details *model.OAuth2Details) (*model.OAuth2Token, error)
- // 根据刷新令牌获取访问令牌
- RefreshAccessToken(refreshTokenValue string) (*model.OAuth2Token, error)
- // 根据用户信息和客户端信息获取已生成访问令牌
- GetAccessToken(details *model.OAuth2Details) (*model.OAuth2Token, error)
- // 根据访问令牌值获取访问令牌结构体
- ReadAccessToken(tokenValue string) (*model.OAuth2Token, error)
- }
- type DefaultTokenService struct {
- tokenStore TokenStore
- tokenEnhancer TokenEnhancer
- }
- func NewTokenService(tokenStore TokenStore, tokenEnhancer TokenEnhancer) TokenService {
- return &DefaultTokenService{
- tokenStore:tokenStore,
- tokenEnhancer:tokenEnhancer,
- }
- }
- func (tokenService *DefaultTokenService) CreateAccessToken(oauth2Details *model.OAuth2Details) (*model.OAuth2Token, error) {
- existToken, err := tokenService.tokenStore.GetAccessToken(oauth2Details)
- if err != nil{
- return nil, err
- }
- var refreshToken *model.OAuth2Token
- // 存在未失效访问令牌,直接返回
- if existToken != nil {
- if !existToken.IsExpired() {
- err = tokenService.tokenStore.StoreAccessToken(existToken, oauth2Details)
- return existToken, err
- }
- // 访问令牌已失效,移除
- err = tokenService.tokenStore.RemoveAccessToken(existToken.TokenValue)
- if err != nil {
- return nil, err
- }
- if existToken.RefreshToken != nil {
- refreshToken = existToken.RefreshToken
- err = tokenService.tokenStore.RemoveRefreshToken(refreshToken.TokenType)
- if err != nil {
- return nil, err
- }
- }
- }
- if refreshToken == nil || refreshToken.IsExpired() {
- refreshToken, err = tokenService.createRefreshToken(oauth2Details)
- if err != nil {
- return nil, err
- }
- }
- // 生成新的访问令牌
- accessToken, err := tokenService.createAccessToken(refreshToken, oauth2Details)
- if err != nil{
- return nil, err
- }
- // 保存新生成令牌
- err = tokenService.tokenStore.StoreAccessToken(accessToken, oauth2Details)
- if err != nil{
- return nil, err
- }
- err = tokenService.tokenStore.StoreRefreshToken(refreshToken, oauth2Details)
- if err != nil{
- return nil, err
- }
- return accessToken, err
- }
- func (tokenService *DefaultTokenService) createAccessToken(refreshToken *model.OAuth2Token, oauth2Details *model.OAuth2Details) (*model.OAuth2Token, error) {
- validitySeconds := oauth2Details.Client.AccessTokenValiditySeconds
- s, _ := time.ParseDuration(strconv.Itoa(validitySeconds) + "s")
- expiredTime := time.Now().Add(s)
- accessToken := &model.OAuth2Token{
- RefreshToken:refreshToken,
- ExpiresTime:&expiredTime,
- TokenValue:uuid.NewV4().String(),
- }
- if tokenService.tokenEnhancer != nil{
- return tokenService.tokenEnhancer.Enhance(accessToken, oauth2Details)
- }
- return accessToken, nil
- }
- func (tokenService *DefaultTokenService) createRefreshToken(oauth2Details *model.OAuth2Details) (*model.OAuth2Token, error) {
- validitySeconds := oauth2Details.Client.RefreshTokenValiditySeconds
- s, _ := time.ParseDuration(strconv.Itoa(validitySeconds) + "s")
- expiredTime := time.Now().Add(s)
- refreshToken := &model.OAuth2Token{
- ExpiresTime:&expiredTime,
- TokenValue:uuid.NewV4().String(),
- }
- if tokenService.tokenEnhancer != nil{
- return tokenService.tokenEnhancer.Enhance(refreshToken, oauth2Details)
- }
- return refreshToken, nil
- }
- func (tokenService *DefaultTokenService) RefreshAccessToken(refreshTokenValue string) (*model.OAuth2Token, error){
- refreshToken, err := tokenService.tokenStore.ReadRefreshToken(refreshTokenValue)
- if err == nil{
- if refreshToken.IsExpired(){
- return nil, ErrExpiredToken
- }
- oauth2Details, err := tokenService.tokenStore.ReadOAuth2DetailsForRefreshToken(refreshTokenValue)
- if err == nil{
- oauth2Token, err := tokenService.tokenStore.GetAccessToken(oauth2Details)
- // 移除原有的访问令牌
- if err == nil{
- tokenService.tokenStore.RemoveAccessToken(oauth2Token.TokenValue)
- }
- // 移除已使用的刷新令牌
- tokenService.tokenStore.RemoveRefreshToken(refreshTokenValue)
- refreshToken, err = tokenService.createRefreshToken(oauth2Details)
- if err == nil{
- accessToken, err := tokenService.createAccessToken(refreshToken, oauth2Details)
- if err == nil{
- tokenService.tokenStore.StoreAccessToken(accessToken, oauth2Details)
- tokenService.tokenStore.StoreRefreshToken(refreshToken, oauth2Details)
- }
- return accessToken, err;
- }
- }
- }
- return nil, err
- }
- func (tokenService *DefaultTokenService) GetAccessToken(details *model.OAuth2Details) (*model.OAuth2Token, error) {
- return tokenService.tokenStore.GetAccessToken(details)
- }
- func (tokenService *DefaultTokenService) ReadAccessToken(tokenValue string) (*model.OAuth2Token, error){
- return tokenService.tokenStore.ReadAccessToken(tokenValue)
- }
- func (tokenService *DefaultTokenService) GetOAuth2DetailsByAccessToken(tokenValue string) (*model.OAuth2Details, error) {
- accessToken, err := tokenService.tokenStore.ReadAccessToken(tokenValue)
- if err != nil{
- return nil, err
- }
- if accessToken.IsExpired(){
- return nil, ErrExpiredToken
- }
- return tokenService.tokenStore.ReadOAuth2Details(tokenValue)
- }
- type TokenStore interface {
- // 存储访问令牌
- StoreAccessToken(oauth2Token *model.OAuth2Token, oauth2Details *model.OAuth2Details) error
- // 根据令牌值获取访问令牌结构体
- ReadAccessToken(tokenValue string) (*model.OAuth2Token, error)
- // 根据令牌值获取令牌对应的客户端和用户信息
- ReadOAuth2Details(tokenValue string)(*model.OAuth2Details, error)
- // 根据客户端信息和用户信息获取访问令牌
- GetAccessToken(oauth2Details *model.OAuth2Details)(*model.OAuth2Token, error);
- // 移除存储的访问令牌
- RemoveAccessToken(tokenValue string) error
- // 存储刷新令牌
- StoreRefreshToken(oauth2Token *model.OAuth2Token, oauth2Details *model.OAuth2Details) error
- // 移除存储的刷新令牌
- RemoveRefreshToken(oauth2Token string) error
- // 根据令牌值获取刷新令牌
- ReadRefreshToken(tokenValue string)(*model.OAuth2Token, error)
- // 根据令牌值获取刷新令牌对应的客户端和用户信息
- ReadOAuth2DetailsForRefreshToken(tokenValue string)(*model.OAuth2Details, error)
- }
- func NewJwtTokenStore(jwtTokenEnhancer *JwtTokenEnhancer) TokenStore {
- return &JwtTokenStore{
- jwtTokenEnhancer:jwtTokenEnhancer,
- }
- }
- type JwtTokenStore struct {
- jwtTokenEnhancer *JwtTokenEnhancer
- }
- func (tokenStore *JwtTokenStore) StoreAccessToken(oauth2Token *model.OAuth2Token, oauth2Details *model.OAuth2Details) error{
- return nil
- }
- func (tokenStore *JwtTokenStore)ReadAccessToken(tokenValue string) (*model.OAuth2Token, error){
- oauth2Token, _, err := tokenStore.jwtTokenEnhancer.Extract(tokenValue)
- return oauth2Token, err
- }
- // 根据令牌值获取令牌对应的客户端和用户信息
- func (tokenStore *JwtTokenStore) ReadOAuth2Details(tokenValue string)(*model.OAuth2Details, error){
- _, oauth2Details, err := tokenStore.jwtTokenEnhancer.Extract(tokenValue)
- return oauth2Details, err
- }
- // 根据客户端信息和用户信息获取访问令牌
- func (tokenStore *JwtTokenStore) GetAccessToken(oauth2Details *model.OAuth2Details)(*model.OAuth2Token, error){
- return nil, nil
- }
- // 移除存储的访问令牌
- func (tokenStore *JwtTokenStore) RemoveAccessToken(tokenValue string) error {
- return nil
- }
- // 存储刷新令牌
- func (tokenStore *JwtTokenStore) StoreRefreshToken(oauth2Token *model.OAuth2Token, oauth2Details *model.OAuth2Details) error {
- return nil
- }
- // 移除存储的刷新令牌
- func (tokenStore *JwtTokenStore)RemoveRefreshToken(oauth2Token string) error {
- return nil
- }
- // 根据令牌值获取刷新令牌
- func (tokenStore *JwtTokenStore) ReadRefreshToken(tokenValue string)(*model.OAuth2Token, error){
- oauth2Token, _, err := tokenStore.jwtTokenEnhancer.Extract(tokenValue)
- return oauth2Token, err
- }
- // 根据令牌值获取刷新令牌对应的客户端和用户信息
- func (tokenStore *JwtTokenStore)ReadOAuth2DetailsForRefreshToken(tokenValue string)(*model.OAuth2Details, error){
- _, oauth2Details, err := tokenStore.jwtTokenEnhancer.Extract(tokenValue)
- return oauth2Details, err
- }
- type TokenEnhancer interface {
- // 组装 Token 信息
- Enhance(oauth2Token *model.OAuth2Token, oauth2Details *model.OAuth2Details) (*model.OAuth2Token, error)
- // 从 Token 中还原信息
- Extract(tokenValue string) (*model.OAuth2Token, *model.OAuth2Details, error)
- }
- type OAuth2TokenCustomClaims struct {
- UserDetails model.UserDetails
- ClientDetails model.ClientDetails
- RefreshToken model.OAuth2Token
- jwt.StandardClaims
- }
- type JwtTokenEnhancer struct {
- secretKey []byte
- }
- func NewJwtTokenEnhancer(secretKey string) TokenEnhancer {
- return &JwtTokenEnhancer{
- secretKey:[]byte(secretKey),
- }
- }
- func (enhancer *JwtTokenEnhancer) Enhance(oauth2Token *model.OAuth2Token, oauth2Details *model.OAuth2Details) (*model.OAuth2Token, error) {
- return enhancer.sign(oauth2Token, oauth2Details)
- }
- func (enhancer *JwtTokenEnhancer) sign(oauth2Token *model.OAuth2Token, oauth2Details *model.OAuth2Details) (*model.OAuth2Token, error) {
- expireTime := oauth2Token.ExpiresTime
- clientDetails := oauth2Details.Client
- userDetails := oauth2Details.User
- clientDetails.ClientSecret = ""
- userDetails.Password = ""
- claims := OAuth2TokenCustomClaims{
- UserDetails:userDetails,
- ClientDetails:clientDetails,
- StandardClaims:jwt.StandardClaims{
- ExpiresAt:expireTime.Unix(),
- Issuer:"System",
- },
- }
- if oauth2Token.RefreshToken != nil{
- claims.RefreshToken = *oauth2Token.RefreshToken
- }
- token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
- tokenValue, err := token.SignedString(enhancer.secretKey)
- if err != nil{
- return nil, err
- }
- oauth2Token.TokenValue = tokenValue
- oauth2Token.TokenType = "jwt"
- return oauth2Token, nil;
- }
- func (enhancer *JwtTokenEnhancer) Extract(tokenValue string) (*model.OAuth2Token, *model.OAuth2Details, error) {
- token, err := jwt.ParseWithClaims(tokenValue, &OAuth2TokenCustomClaims{}, func(token *jwt.Token) (i interface{}, e error) {
- return enhancer.secretKey, nil
- })
- if err != nil{
- return nil, nil, err
- }
- claims := token.Claims.(*OAuth2TokenCustomClaims)
- expiresTime := time.Unix(claims.ExpiresAt, 0)
- return &model.OAuth2Token{
- RefreshToken:&claims.RefreshToken,
- TokenValue:tokenValue,
- ExpiresTime: &expiresTime,
- }, &model.OAuth2Details{
- User:claims.UserDetails,
- Client:claims.ClientDetails,
- }, nil
- }
|