cache.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. package client
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "sort"
  6. "sync"
  7. "time"
  8. "github.com/jcmturner/gokrb5/v8/messages"
  9. "github.com/jcmturner/gokrb5/v8/types"
  10. )
  11. // Cache for service tickets held by the client.
  12. type Cache struct {
  13. Entries map[string]CacheEntry
  14. mux sync.RWMutex
  15. }
  16. // CacheEntry holds details for a cache entry.
  17. type CacheEntry struct {
  18. SPN string
  19. Ticket messages.Ticket `json:"-"`
  20. AuthTime time.Time
  21. StartTime time.Time
  22. EndTime time.Time
  23. RenewTill time.Time
  24. SessionKey types.EncryptionKey `json:"-"`
  25. }
  26. // NewCache creates a new client ticket cache instance.
  27. func NewCache() *Cache {
  28. return &Cache{
  29. Entries: map[string]CacheEntry{},
  30. }
  31. }
  32. // getEntry returns a cache entry that matches the SPN.
  33. func (c *Cache) getEntry(spn string) (CacheEntry, bool) {
  34. c.mux.RLock()
  35. defer c.mux.RUnlock()
  36. e, ok := (*c).Entries[spn]
  37. return e, ok
  38. }
  39. // JSON returns information about the cached service tickets in a JSON format.
  40. func (c *Cache) JSON() (string, error) {
  41. c.mux.RLock()
  42. defer c.mux.RUnlock()
  43. var es []CacheEntry
  44. keys := make([]string, 0, len(c.Entries))
  45. for k := range c.Entries {
  46. keys = append(keys, k)
  47. }
  48. sort.Strings(keys)
  49. for _, k := range keys {
  50. es = append(es, c.Entries[k])
  51. }
  52. b, err := json.MarshalIndent(&es, "", " ")
  53. if err != nil {
  54. return "", err
  55. }
  56. return string(b), nil
  57. }
  58. // addEntry adds a ticket to the cache.
  59. func (c *Cache) addEntry(tkt messages.Ticket, authTime, startTime, endTime, renewTill time.Time, sessionKey types.EncryptionKey) CacheEntry {
  60. spn := tkt.SName.PrincipalNameString()
  61. c.mux.Lock()
  62. defer c.mux.Unlock()
  63. (*c).Entries[spn] = CacheEntry{
  64. SPN: spn,
  65. Ticket: tkt,
  66. AuthTime: authTime,
  67. StartTime: startTime,
  68. EndTime: endTime,
  69. RenewTill: renewTill,
  70. SessionKey: sessionKey,
  71. }
  72. return c.Entries[spn]
  73. }
  74. // clear deletes all the cache entries
  75. func (c *Cache) clear() {
  76. c.mux.Lock()
  77. defer c.mux.Unlock()
  78. for k := range c.Entries {
  79. delete(c.Entries, k)
  80. }
  81. }
  82. // RemoveEntry removes the cache entry for the defined SPN.
  83. func (c *Cache) RemoveEntry(spn string) {
  84. c.mux.Lock()
  85. defer c.mux.Unlock()
  86. delete(c.Entries, spn)
  87. }
  88. // GetCachedTicket returns a ticket from the cache for the SPN.
  89. // Only a ticket that is currently valid will be returned.
  90. func (cl *Client) GetCachedTicket(spn string) (messages.Ticket, types.EncryptionKey, bool) {
  91. if e, ok := cl.cache.getEntry(spn); ok {
  92. //If within time window of ticket return it
  93. if time.Now().UTC().After(e.StartTime) && time.Now().UTC().Before(e.EndTime) {
  94. cl.Log("ticket received from cache for %s", spn)
  95. return e.Ticket, e.SessionKey, true
  96. } else if time.Now().UTC().Before(e.RenewTill) {
  97. e, err := cl.renewTicket(e)
  98. if err != nil {
  99. return e.Ticket, e.SessionKey, false
  100. }
  101. return e.Ticket, e.SessionKey, true
  102. }
  103. }
  104. var tkt messages.Ticket
  105. var key types.EncryptionKey
  106. return tkt, key, false
  107. }
  108. // renewTicket renews a cache entry ticket.
  109. // To renew from outside the client package use GetCachedTicket
  110. func (cl *Client) renewTicket(e CacheEntry) (CacheEntry, error) {
  111. spn := e.Ticket.SName
  112. _, _, err := cl.TGSREQGenerateAndExchange(spn, e.Ticket.Realm, e.Ticket, e.SessionKey, true)
  113. if err != nil {
  114. return e, err
  115. }
  116. e, ok := cl.cache.getEntry(e.Ticket.SName.PrincipalNameString())
  117. if !ok {
  118. return e, errors.New("ticket was not added to cache")
  119. }
  120. cl.Log("ticket renewed for %s (EndTime: %v)", spn.PrincipalNameString(), e.EndTime)
  121. return e, nil
  122. }