session.go 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295
  1. package client
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "sort"
  6. "strings"
  7. "sync"
  8. "time"
  9. "github.com/jcmturner/gokrb5/v8/iana/nametype"
  10. "github.com/jcmturner/gokrb5/v8/krberror"
  11. "github.com/jcmturner/gokrb5/v8/messages"
  12. "github.com/jcmturner/gokrb5/v8/types"
  13. )
  14. // sessions hold TGTs and are keyed on the realm name
  15. type sessions struct {
  16. Entries map[string]*session
  17. mux sync.RWMutex
  18. }
  19. // destroy erases all sessions
  20. func (s *sessions) destroy() {
  21. s.mux.Lock()
  22. defer s.mux.Unlock()
  23. for k, e := range s.Entries {
  24. e.destroy()
  25. delete(s.Entries, k)
  26. }
  27. }
  28. // update replaces a session with the one provided or adds it as a new one
  29. func (s *sessions) update(sess *session) {
  30. s.mux.Lock()
  31. defer s.mux.Unlock()
  32. // if a session already exists for this, cancel its auto renew.
  33. if i, ok := s.Entries[sess.realm]; ok {
  34. if i != sess {
  35. // Session in the sessions cache is not the same as one provided.
  36. // Cancel the one in the cache and add this one.
  37. i.mux.Lock()
  38. defer i.mux.Unlock()
  39. i.cancel <- true
  40. s.Entries[sess.realm] = sess
  41. return
  42. }
  43. }
  44. // No session for this realm was found so just add it
  45. s.Entries[sess.realm] = sess
  46. }
  47. // get returns the session for the realm specified
  48. func (s *sessions) get(realm string) (*session, bool) {
  49. s.mux.RLock()
  50. defer s.mux.RUnlock()
  51. sess, ok := s.Entries[realm]
  52. return sess, ok
  53. }
  54. // session holds the TGT details for a realm
  55. type session struct {
  56. realm string
  57. authTime time.Time
  58. endTime time.Time
  59. renewTill time.Time
  60. tgt messages.Ticket
  61. sessionKey types.EncryptionKey
  62. sessionKeyExpiration time.Time
  63. cancel chan bool
  64. mux sync.RWMutex
  65. }
  // jsonSession is the JSON view of a session produced by sessions.JSON. It
// carries only the realm and the timing values; the TGT and session key are
// not included.
type jsonSession struct {
	Realm                                              string
	AuthTime, EndTime, RenewTill, SessionKeyExpiration time.Time
}
  74. // AddSession adds a session for a realm with a TGT to the client's session cache.
  75. // A goroutine is started to automatically renew the TGT before expiry.
  76. func (cl *Client) addSession(tgt messages.Ticket, dep messages.EncKDCRepPart) {
  77. if strings.ToLower(tgt.SName.NameString[0]) != "krbtgt" {
  78. // Not a TGT
  79. return
  80. }
  81. realm := tgt.SName.NameString[len(tgt.SName.NameString)-1]
  82. s := &session{
  83. realm: realm,
  84. authTime: dep.AuthTime,
  85. endTime: dep.EndTime,
  86. renewTill: dep.RenewTill,
  87. tgt: tgt,
  88. sessionKey: dep.Key,
  89. sessionKeyExpiration: dep.KeyExpiration,
  90. }
  91. cl.sessions.update(s)
  92. cl.enableAutoSessionRenewal(s)
  93. cl.Log("TGT session added for %s (EndTime: %v)", realm, dep.EndTime)
  94. }
  95. // update overwrites the session details with those from the TGT and decrypted encPart
  96. func (s *session) update(tgt messages.Ticket, dep messages.EncKDCRepPart) {
  97. s.mux.Lock()
  98. defer s.mux.Unlock()
  99. s.authTime = dep.AuthTime
  100. s.endTime = dep.EndTime
  101. s.renewTill = dep.RenewTill
  102. s.tgt = tgt
  103. s.sessionKey = dep.Key
  104. s.sessionKeyExpiration = dep.KeyExpiration
  105. }
  106. // destroy will cancel any auto renewal of the session and set the expiration times to the current time
  107. func (s *session) destroy() {
  108. s.mux.Lock()
  109. defer s.mux.Unlock()
  110. if s.cancel != nil {
  111. s.cancel <- true
  112. }
  113. s.endTime = time.Now().UTC()
  114. s.renewTill = s.endTime
  115. s.sessionKeyExpiration = s.endTime
  116. }
  117. // valid informs if the TGT is still within the valid time window
  118. func (s *session) valid() bool {
  119. s.mux.RLock()
  120. defer s.mux.RUnlock()
  121. t := time.Now().UTC()
  122. if t.Before(s.endTime) && s.authTime.Before(t) {
  123. return true
  124. }
  125. return false
  126. }
  127. // tgtDetails is a thread safe way to get the session's realm, TGT and session key values
  128. func (s *session) tgtDetails() (string, messages.Ticket, types.EncryptionKey) {
  129. s.mux.RLock()
  130. defer s.mux.RUnlock()
  131. return s.realm, s.tgt, s.sessionKey
  132. }
  133. // timeDetails is a thread safe way to get the session's validity time values
  134. func (s *session) timeDetails() (string, time.Time, time.Time, time.Time, time.Time) {
  135. s.mux.RLock()
  136. defer s.mux.RUnlock()
  137. return s.realm, s.authTime, s.endTime, s.renewTill, s.sessionKeyExpiration
  138. }
  139. // JSON return information about the held sessions in a JSON format.
  140. func (s *sessions) JSON() (string, error) {
  141. s.mux.RLock()
  142. defer s.mux.RUnlock()
  143. var js []jsonSession
  144. keys := make([]string, 0, len(s.Entries))
  145. for k := range s.Entries {
  146. keys = append(keys, k)
  147. }
  148. sort.Strings(keys)
  149. for _, k := range keys {
  150. r, at, et, rt, kt := s.Entries[k].timeDetails()
  151. j := jsonSession{
  152. Realm: r,
  153. AuthTime: at,
  154. EndTime: et,
  155. RenewTill: rt,
  156. SessionKeyExpiration: kt,
  157. }
  158. js = append(js, j)
  159. }
  160. b, err := json.MarshalIndent(js, "", " ")
  161. if err != nil {
  162. return "", err
  163. }
  164. return string(b), nil
  165. }
  166. // enableAutoSessionRenewal turns on the automatic renewal for the client's TGT session.
  167. func (cl *Client) enableAutoSessionRenewal(s *session) {
  168. var timer *time.Timer
  169. s.mux.Lock()
  170. s.cancel = make(chan bool, 1)
  171. s.mux.Unlock()
  172. go func(s *session) {
  173. for {
  174. s.mux.RLock()
  175. w := (s.endTime.Sub(time.Now().UTC()) * 5) / 6
  176. s.mux.RUnlock()
  177. if w < 0 {
  178. return
  179. }
  180. timer = time.NewTimer(w)
  181. select {
  182. case <-timer.C:
  183. renewal, err := cl.refreshSession(s)
  184. if err != nil {
  185. cl.Log("error refreshing session: %v", err)
  186. }
  187. if !renewal && err == nil {
  188. // end this goroutine as there will have been a new login and new auto renewal goroutine created.
  189. return
  190. }
  191. case <-s.cancel:
  192. // cancel has been called. Stop the timer and exit.
  193. timer.Stop()
  194. return
  195. }
  196. }
  197. }(s)
  198. }
  199. // renewTGT renews the client's TGT session.
  200. func (cl *Client) renewTGT(s *session) error {
  201. realm, tgt, skey := s.tgtDetails()
  202. spn := types.PrincipalName{
  203. NameType: nametype.KRB_NT_SRV_INST,
  204. NameString: []string{"krbtgt", realm},
  205. }
  206. _, tgsRep, err := cl.TGSREQGenerateAndExchange(spn, cl.Credentials.Domain(), tgt, skey, true)
  207. if err != nil {
  208. return krberror.Errorf(err, krberror.KRBMsgError, "error renewing TGT for %s", realm)
  209. }
  210. s.update(tgsRep.Ticket, tgsRep.DecryptedEncPart)
  211. cl.sessions.update(s)
  212. cl.Log("TGT session renewed for %s (EndTime: %v)", realm, tgsRep.DecryptedEncPart.EndTime)
  213. return nil
  214. }
  215. // refreshSession updates either through renewal or creating a new login.
  216. // The boolean indicates if the update was a renewal.
  217. func (cl *Client) refreshSession(s *session) (bool, error) {
  218. s.mux.RLock()
  219. realm := s.realm
  220. renewTill := s.renewTill
  221. s.mux.RUnlock()
  222. cl.Log("refreshing TGT session for %s", realm)
  223. if time.Now().UTC().Before(renewTill) {
  224. err := cl.renewTGT(s)
  225. return true, err
  226. }
  227. err := cl.realmLogin(realm)
  228. return false, err
  229. }
  230. // ensureValidSession makes sure there is a valid session for the realm
  231. func (cl *Client) ensureValidSession(realm string) error {
  232. s, ok := cl.sessions.get(realm)
  233. if ok {
  234. s.mux.RLock()
  235. d := s.endTime.Sub(s.authTime) / 6
  236. if s.endTime.Sub(time.Now().UTC()) > d {
  237. s.mux.RUnlock()
  238. return nil
  239. }
  240. s.mux.RUnlock()
  241. _, err := cl.refreshSession(s)
  242. return err
  243. }
  244. return cl.realmLogin(realm)
  245. }
  246. // sessionTGTDetails is a thread safe way to get the TGT and session key values for a realm
  247. func (cl *Client) sessionTGT(realm string) (tgt messages.Ticket, sessionKey types.EncryptionKey, err error) {
  248. err = cl.ensureValidSession(realm)
  249. if err != nil {
  250. return
  251. }
  252. s, ok := cl.sessions.get(realm)
  253. if !ok {
  254. err = fmt.Errorf("could not find TGT session for %s", realm)
  255. return
  256. }
  257. _, tgt, sessionKey = s.tgtDetails()
  258. return
  259. }
  260. // sessionTimes provides the timing information with regards to a session for the realm specified.
  261. func (cl *Client) sessionTimes(realm string) (authTime, endTime, renewTime, sessionExp time.Time, err error) {
  262. s, ok := cl.sessions.get(realm)
  263. if !ok {
  264. err = fmt.Errorf("could not find TGT session for %s", realm)
  265. return
  266. }
  267. _, authTime, endTime, renewTime, sessionExp = s.timeDetails()
  268. return
  269. }
  270. // spnRealm resolves the realm name of a service principal name
  271. func (cl *Client) spnRealm(spn types.PrincipalName) string {
  272. return cl.Config.ResolveRealm(spn.NameString[len(spn.NameString)-1])
  273. }