f79a0adbe94a344e1b3e3941a8c57a61653e6860.svn-base 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471
  1. package service
  2. import (
  3. "context"
  4. "errors"
  5. "github.com/dgrijalva/jwt-go"
  6. "github.com/longjoy/micro-go-course/section31/model"
  7. uuid "github.com/satori/go.uuid"
  8. "net/http"
  9. "strconv"
  10. "time"
  11. )
  12. var (
  13. ErrNotSupportGrantType = errors.New("grant type is not supported")
  14. ErrNotSupportOperation = errors.New("no support operation")
  15. ErrInvalidUsernameAndPasswordRequest = errors.New("invalid username, password")
  16. ErrInvalidTokenRequest = errors.New("invalid token")
  17. ErrExpiredToken = errors.New("token is expired")
  18. )
  19. type TokenGranter interface {
  20. Grant(ctx context.Context, grantType string, client model.ClientDetails, reader *http.Request) (*model.OAuth2Token, error)
  21. }
  22. type ComposeTokenGranter struct {
  23. TokenGrantDict map[string] TokenGranter
  24. }
  25. func NewComposeTokenGranter(tokenGrantDict map[string] TokenGranter) TokenGranter {
  26. return &ComposeTokenGranter{
  27. TokenGrantDict:tokenGrantDict,
  28. }
  29. }
  30. func (tokenGranter *ComposeTokenGranter) Grant(ctx context.Context, grantType string, client model.ClientDetails, reader *http.Request) (*model.OAuth2Token, error) {
  31. // 检查客户端是否允许该种授权类型
  32. var isSupport bool
  33. if len(client.AuthorizedGrantTypes) > 0 {
  34. for _, v := range client.AuthorizedGrantTypes {
  35. if v == grantType {
  36. isSupport = true
  37. break
  38. }
  39. }
  40. }
  41. if !isSupport{
  42. return nil, ErrNotSupportOperation
  43. }
  44. // 查找具体的授权类型实现节点
  45. dispatchGranter,ok := tokenGranter.TokenGrantDict[grantType]; if ok{
  46. return dispatchGranter.Grant(ctx, grantType, client, reader)
  47. } else{
  48. return nil, ErrNotSupportGrantType
  49. }
  50. }
  51. type UsernamePasswordTokenGranter struct {
  52. supportGrantType string
  53. userDetailsService UserDetailsService
  54. tokenService TokenService
  55. }
  56. func NewUsernamePasswordTokenGranter(grantType string, userDetailsService UserDetailsService, tokenService TokenService) TokenGranter {
  57. return &UsernamePasswordTokenGranter{
  58. supportGrantType:grantType,
  59. userDetailsService:userDetailsService,
  60. tokenService:tokenService,
  61. }
  62. }
  63. func (tokenGranter *UsernamePasswordTokenGranter) Grant(ctx context.Context,
  64. grantType string, client model.ClientDetails, reader *http.Request) (*model.OAuth2Token, error) {
  65. if grantType != tokenGranter.supportGrantType{
  66. return nil, ErrNotSupportGrantType
  67. }
  68. // 从请求体中获取用户名密码
  69. username := reader.FormValue("username")
  70. password := reader.FormValue("password")
  71. if username == "" || password == ""{
  72. return nil, ErrInvalidUsernameAndPasswordRequest
  73. }
  74. // 验证用户名密码是否正确
  75. userDetails, err := tokenGranter.userDetailsService.GetUserDetailByUsername(ctx, username, password)
  76. if err != nil{
  77. return nil, ErrInvalidUsernameAndPasswordRequest
  78. }
  79. // 根据用户信息和客户端信息生成访问令牌
  80. return tokenGranter.tokenService.CreateAccessToken(&model.OAuth2Details{
  81. Client:client,
  82. User:userDetails,
  83. })
  84. }
  85. type RefreshTokenGranter struct {
  86. supportGrantType string
  87. tokenService TokenService
  88. }
  89. func NewRefreshGranter(grantType string, userDetailsService UserDetailsService, tokenService TokenService) TokenGranter {
  90. return &RefreshTokenGranter{
  91. supportGrantType:grantType,
  92. tokenService:tokenService,
  93. }
  94. }
  95. func (tokenGranter *RefreshTokenGranter) Grant(ctx context.Context, grantType string, client model.ClientDetails, reader *http.Request) (*model.OAuth2Token, error) {
  96. if grantType != tokenGranter.supportGrantType{
  97. return nil, ErrNotSupportGrantType
  98. }
  99. // 从请求中获取刷新令牌
  100. refreshTokenValue := reader.URL.Query().Get("refresh_token")
  101. if refreshTokenValue == ""{
  102. return nil, ErrInvalidTokenRequest
  103. }
  104. return tokenGranter.tokenService.RefreshAccessToken(refreshTokenValue)
  105. }
  106. type TokenService interface {
  107. // 根据访问令牌获取对应的用户信息和客户端信息
  108. GetOAuth2DetailsByAccessToken(tokenValue string) (*model.OAuth2Details, error)
  109. // 根据用户信息和客户端信息生成访问令牌
  110. CreateAccessToken(oauth2Details *model.OAuth2Details) (*model.OAuth2Token, error)
  111. // 根据刷新令牌获取访问令牌
  112. RefreshAccessToken(refreshTokenValue string) (*model.OAuth2Token, error)
  113. // 根据用户信息和客户端信息获取已生成访问令牌
  114. GetAccessToken(details *model.OAuth2Details) (*model.OAuth2Token, error)
  115. // 根据访问令牌值获取访问令牌结构体
  116. ReadAccessToken(tokenValue string) (*model.OAuth2Token, error)
  117. }
  118. type DefaultTokenService struct {
  119. tokenStore TokenStore
  120. tokenEnhancer TokenEnhancer
  121. }
  122. func NewTokenService(tokenStore TokenStore, tokenEnhancer TokenEnhancer) TokenService {
  123. return &DefaultTokenService{
  124. tokenStore:tokenStore,
  125. tokenEnhancer:tokenEnhancer,
  126. }
  127. }
  128. func (tokenService *DefaultTokenService) CreateAccessToken(oauth2Details *model.OAuth2Details) (*model.OAuth2Token, error) {
  129. existToken, err := tokenService.tokenStore.GetAccessToken(oauth2Details)
  130. if err != nil{
  131. return nil, err
  132. }
  133. var refreshToken *model.OAuth2Token
  134. // 存在未失效访问令牌,直接返回
  135. if existToken != nil {
  136. if !existToken.IsExpired() {
  137. err = tokenService.tokenStore.StoreAccessToken(existToken, oauth2Details)
  138. return existToken, err
  139. }
  140. // 访问令牌已失效,移除
  141. err = tokenService.tokenStore.RemoveAccessToken(existToken.TokenValue)
  142. if err != nil {
  143. return nil, err
  144. }
  145. if existToken.RefreshToken != nil {
  146. refreshToken = existToken.RefreshToken
  147. err = tokenService.tokenStore.RemoveRefreshToken(refreshToken.TokenType)
  148. if err != nil {
  149. return nil, err
  150. }
  151. }
  152. }
  153. if refreshToken == nil || refreshToken.IsExpired() {
  154. refreshToken, err = tokenService.createRefreshToken(oauth2Details)
  155. if err != nil {
  156. return nil, err
  157. }
  158. }
  159. // 生成新的访问令牌
  160. accessToken, err := tokenService.createAccessToken(refreshToken, oauth2Details)
  161. if err != nil{
  162. return nil, err
  163. }
  164. // 保存新生成令牌
  165. err = tokenService.tokenStore.StoreAccessToken(accessToken, oauth2Details)
  166. if err != nil{
  167. return nil, err
  168. }
  169. err = tokenService.tokenStore.StoreRefreshToken(refreshToken, oauth2Details)
  170. if err != nil{
  171. return nil, err
  172. }
  173. return accessToken, err
  174. }
  175. func (tokenService *DefaultTokenService) createAccessToken(refreshToken *model.OAuth2Token, oauth2Details *model.OAuth2Details) (*model.OAuth2Token, error) {
  176. validitySeconds := oauth2Details.Client.AccessTokenValiditySeconds
  177. s, _ := time.ParseDuration(strconv.Itoa(validitySeconds) + "s")
  178. expiredTime := time.Now().Add(s)
  179. accessToken := &model.OAuth2Token{
  180. RefreshToken:refreshToken,
  181. ExpiresTime:&expiredTime,
  182. TokenValue:uuid.NewV4().String(),
  183. }
  184. if tokenService.tokenEnhancer != nil{
  185. return tokenService.tokenEnhancer.Enhance(accessToken, oauth2Details)
  186. }
  187. return accessToken, nil
  188. }
  189. func (tokenService *DefaultTokenService) createRefreshToken(oauth2Details *model.OAuth2Details) (*model.OAuth2Token, error) {
  190. validitySeconds := oauth2Details.Client.RefreshTokenValiditySeconds
  191. s, _ := time.ParseDuration(strconv.Itoa(validitySeconds) + "s")
  192. expiredTime := time.Now().Add(s)
  193. refreshToken := &model.OAuth2Token{
  194. ExpiresTime:&expiredTime,
  195. TokenValue:uuid.NewV4().String(),
  196. }
  197. if tokenService.tokenEnhancer != nil{
  198. return tokenService.tokenEnhancer.Enhance(refreshToken, oauth2Details)
  199. }
  200. return refreshToken, nil
  201. }
  202. func (tokenService *DefaultTokenService) RefreshAccessToken(refreshTokenValue string) (*model.OAuth2Token, error){
  203. refreshToken, err := tokenService.tokenStore.ReadRefreshToken(refreshTokenValue)
  204. if err == nil{
  205. if refreshToken.IsExpired(){
  206. return nil, ErrExpiredToken
  207. }
  208. oauth2Details, err := tokenService.tokenStore.ReadOAuth2DetailsForRefreshToken(refreshTokenValue)
  209. if err == nil{
  210. oauth2Token, err := tokenService.tokenStore.GetAccessToken(oauth2Details)
  211. // 移除原有的访问令牌
  212. if err == nil{
  213. tokenService.tokenStore.RemoveAccessToken(oauth2Token.TokenValue)
  214. }
  215. // 移除已使用的刷新令牌
  216. tokenService.tokenStore.RemoveRefreshToken(refreshTokenValue)
  217. refreshToken, err = tokenService.createRefreshToken(oauth2Details)
  218. if err == nil{
  219. accessToken, err := tokenService.createAccessToken(refreshToken, oauth2Details)
  220. if err == nil{
  221. tokenService.tokenStore.StoreAccessToken(accessToken, oauth2Details)
  222. tokenService.tokenStore.StoreRefreshToken(refreshToken, oauth2Details)
  223. }
  224. return accessToken, err;
  225. }
  226. }
  227. }
  228. return nil, err
  229. }
  230. func (tokenService *DefaultTokenService) GetAccessToken(details *model.OAuth2Details) (*model.OAuth2Token, error) {
  231. return tokenService.tokenStore.GetAccessToken(details)
  232. }
  233. func (tokenService *DefaultTokenService) ReadAccessToken(tokenValue string) (*model.OAuth2Token, error){
  234. return tokenService.tokenStore.ReadAccessToken(tokenValue)
  235. }
  236. func (tokenService *DefaultTokenService) GetOAuth2DetailsByAccessToken(tokenValue string) (*model.OAuth2Details, error) {
  237. accessToken, err := tokenService.tokenStore.ReadAccessToken(tokenValue)
  238. if err != nil{
  239. return nil, err
  240. }
  241. if accessToken.IsExpired(){
  242. return nil, ErrExpiredToken
  243. }
  244. return tokenService.tokenStore.ReadOAuth2Details(tokenValue)
  245. }
  246. type TokenStore interface {
  247. // 存储访问令牌
  248. StoreAccessToken(oauth2Token *model.OAuth2Token, oauth2Details *model.OAuth2Details) error
  249. // 根据令牌值获取访问令牌结构体
  250. ReadAccessToken(tokenValue string) (*model.OAuth2Token, error)
  251. // 根据令牌值获取令牌对应的客户端和用户信息
  252. ReadOAuth2Details(tokenValue string)(*model.OAuth2Details, error)
  253. // 根据客户端信息和用户信息获取访问令牌
  254. GetAccessToken(oauth2Details *model.OAuth2Details)(*model.OAuth2Token, error);
  255. // 移除存储的访问令牌
  256. RemoveAccessToken(tokenValue string) error
  257. // 存储刷新令牌
  258. StoreRefreshToken(oauth2Token *model.OAuth2Token, oauth2Details *model.OAuth2Details) error
  259. // 移除存储的刷新令牌
  260. RemoveRefreshToken(oauth2Token string) error
  261. // 根据令牌值获取刷新令牌
  262. ReadRefreshToken(tokenValue string)(*model.OAuth2Token, error)
  263. // 根据令牌值获取刷新令牌对应的客户端和用户信息
  264. ReadOAuth2DetailsForRefreshToken(tokenValue string)(*model.OAuth2Details, error)
  265. }
  266. func NewJwtTokenStore(jwtTokenEnhancer *JwtTokenEnhancer) TokenStore {
  267. return &JwtTokenStore{
  268. jwtTokenEnhancer:jwtTokenEnhancer,
  269. }
  270. }
  271. type JwtTokenStore struct {
  272. jwtTokenEnhancer *JwtTokenEnhancer
  273. }
  274. func (tokenStore *JwtTokenStore) StoreAccessToken(oauth2Token *model.OAuth2Token, oauth2Details *model.OAuth2Details) error{
  275. return nil
  276. }
  277. func (tokenStore *JwtTokenStore)ReadAccessToken(tokenValue string) (*model.OAuth2Token, error){
  278. oauth2Token, _, err := tokenStore.jwtTokenEnhancer.Extract(tokenValue)
  279. return oauth2Token, err
  280. }
  281. // 根据令牌值获取令牌对应的客户端和用户信息
  282. func (tokenStore *JwtTokenStore) ReadOAuth2Details(tokenValue string)(*model.OAuth2Details, error){
  283. _, oauth2Details, err := tokenStore.jwtTokenEnhancer.Extract(tokenValue)
  284. return oauth2Details, err
  285. }
  286. // 根据客户端信息和用户信息获取访问令牌
  287. func (tokenStore *JwtTokenStore) GetAccessToken(oauth2Details *model.OAuth2Details)(*model.OAuth2Token, error){
  288. return nil, nil
  289. }
  290. // 移除存储的访问令牌
  291. func (tokenStore *JwtTokenStore) RemoveAccessToken(tokenValue string) error {
  292. return nil
  293. }
  294. // 存储刷新令牌
  295. func (tokenStore *JwtTokenStore) StoreRefreshToken(oauth2Token *model.OAuth2Token, oauth2Details *model.OAuth2Details) error {
  296. return nil
  297. }
  298. // 移除存储的刷新令牌
  299. func (tokenStore *JwtTokenStore)RemoveRefreshToken(oauth2Token string) error {
  300. return nil
  301. }
  302. // 根据令牌值获取刷新令牌
  303. func (tokenStore *JwtTokenStore) ReadRefreshToken(tokenValue string)(*model.OAuth2Token, error){
  304. oauth2Token, _, err := tokenStore.jwtTokenEnhancer.Extract(tokenValue)
  305. return oauth2Token, err
  306. }
  307. // 根据令牌值获取刷新令牌对应的客户端和用户信息
  308. func (tokenStore *JwtTokenStore)ReadOAuth2DetailsForRefreshToken(tokenValue string)(*model.OAuth2Details, error){
  309. _, oauth2Details, err := tokenStore.jwtTokenEnhancer.Extract(tokenValue)
  310. return oauth2Details, err
  311. }
  312. type TokenEnhancer interface {
  313. // 组装 Token 信息
  314. Enhance(oauth2Token *model.OAuth2Token, oauth2Details *model.OAuth2Details) (*model.OAuth2Token, error)
  315. // 从 Token 中还原信息
  316. Extract(tokenValue string) (*model.OAuth2Token, *model.OAuth2Details, error)
  317. }
  318. type OAuth2TokenCustomClaims struct {
  319. UserDetails model.UserDetails
  320. ClientDetails model.ClientDetails
  321. RefreshToken model.OAuth2Token
  322. jwt.StandardClaims
  323. }
  324. type JwtTokenEnhancer struct {
  325. secretKey []byte
  326. }
  327. func NewJwtTokenEnhancer(secretKey string) TokenEnhancer {
  328. return &JwtTokenEnhancer{
  329. secretKey:[]byte(secretKey),
  330. }
  331. }
  332. func (enhancer *JwtTokenEnhancer) Enhance(oauth2Token *model.OAuth2Token, oauth2Details *model.OAuth2Details) (*model.OAuth2Token, error) {
  333. return enhancer.sign(oauth2Token, oauth2Details)
  334. }
  335. func (enhancer *JwtTokenEnhancer) sign(oauth2Token *model.OAuth2Token, oauth2Details *model.OAuth2Details) (*model.OAuth2Token, error) {
  336. expireTime := oauth2Token.ExpiresTime
  337. clientDetails := oauth2Details.Client
  338. userDetails := oauth2Details.User
  339. clientDetails.ClientSecret = ""
  340. userDetails.Password = ""
  341. claims := OAuth2TokenCustomClaims{
  342. UserDetails:userDetails,
  343. ClientDetails:clientDetails,
  344. StandardClaims:jwt.StandardClaims{
  345. ExpiresAt:expireTime.Unix(),
  346. Issuer:"System",
  347. },
  348. }
  349. if oauth2Token.RefreshToken != nil{
  350. claims.RefreshToken = *oauth2Token.RefreshToken
  351. }
  352. token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
  353. tokenValue, err := token.SignedString(enhancer.secretKey)
  354. if err != nil{
  355. return nil, err
  356. }
  357. oauth2Token.TokenValue = tokenValue
  358. oauth2Token.TokenType = "jwt"
  359. return oauth2Token, nil;
  360. }
  361. func (enhancer *JwtTokenEnhancer) Extract(tokenValue string) (*model.OAuth2Token, *model.OAuth2Details, error) {
  362. token, err := jwt.ParseWithClaims(tokenValue, &OAuth2TokenCustomClaims{}, func(token *jwt.Token) (i interface{}, e error) {
  363. return enhancer.secretKey, nil
  364. })
  365. if err != nil{
  366. return nil, nil, err
  367. }
  368. claims := token.Claims.(*OAuth2TokenCustomClaims)
  369. expiresTime := time.Unix(claims.ExpiresAt, 0)
  370. return &model.OAuth2Token{
  371. RefreshToken:&claims.RefreshToken,
  372. TokenValue:tokenValue,
  373. ExpiresTime: &expiresTime,
  374. }, &model.OAuth2Details{
  375. User:claims.UserDetails,
  376. Client:claims.ClientDetails,
  377. }, nil
  378. }