nfold.go 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. package rfc3961
  2. // Implementation of the n-fold algorithm as defined in RFC 3961.
  3. /* Credits
  4. This Go implementation of n-fold used the following project for help with implementation details.
  5. Although their source is in Java, it was helpful as a reference implementation of the RFC.
  6. You can find the source code of their open source project along with license information below.
  7. We acknowledge and are grateful to these developers for their contributions to open source.
  8. Project: Apache Directory (http://directory.apache.org/)
  9. https://svn.apache.org/repos/asf/directory/apacheds/tags/1.5.1/kerberos-shared/src/main/java/org/apache/directory/server/kerberos/shared/crypto/encryption/NFold.java
  10. License: http://www.apache.org/licenses/LICENSE-2.0
  11. */
  12. // Nfold expands the key to ensure it is not smaller than one cipher block.
  13. // Defined in RFC 3961.
  14. //
  15. // m input bytes that will be "stretched" to the least common multiple of n bits and the bit length of m.
  16. func Nfold(m []byte, n int) []byte {
  17. k := len(m) * 8
  18. //Get the lowest common multiple of the two bit sizes
  19. lcm := lcm(n, k)
  20. relicate := lcm / k
  21. var sumBytes []byte
  22. for i := 0; i < relicate; i++ {
  23. rotation := 13 * i
  24. sumBytes = append(sumBytes, rotateRight(m, rotation)...)
  25. }
  26. nfold := make([]byte, n/8)
  27. sum := make([]byte, n/8)
  28. for i := 0; i < lcm/n; i++ {
  29. for j := 0; j < n/8; j++ {
  30. sum[j] = sumBytes[j+(i*len(sum))]
  31. }
  32. nfold = onesComplementAddition(nfold, sum)
  33. }
  34. return nfold
  35. }
  36. func onesComplementAddition(n1, n2 []byte) []byte {
  37. numBits := len(n1) * 8
  38. out := make([]byte, numBits/8)
  39. carry := 0
  40. for i := numBits - 1; i > -1; i-- {
  41. n1b := getBit(&n1, i)
  42. n2b := getBit(&n2, i)
  43. s := n1b + n2b + carry
  44. if s == 0 || s == 1 {
  45. setBit(&out, i, s)
  46. carry = 0
  47. } else if s == 2 {
  48. carry = 1
  49. } else if s == 3 {
  50. setBit(&out, i, 1)
  51. carry = 1
  52. }
  53. }
  54. if carry == 1 {
  55. carryArray := make([]byte, len(n1))
  56. carryArray[len(carryArray)-1] = 1
  57. out = onesComplementAddition(out, carryArray)
  58. }
  59. return out
  60. }
  61. func rotateRight(b []byte, step int) []byte {
  62. out := make([]byte, len(b))
  63. bitLen := len(b) * 8
  64. for i := 0; i < bitLen; i++ {
  65. v := getBit(&b, i)
  66. setBit(&out, (i+step)%bitLen, v)
  67. }
  68. return out
  69. }
  70. func lcm(x, y int) int {
  71. return (x * y) / gcd(x, y)
  72. }
  73. func gcd(x, y int) int {
  74. for y != 0 {
  75. x, y = y, x%y
  76. }
  77. return x
  78. }
  79. func getBit(b *[]byte, p int) int {
  80. pByte := p / 8
  81. pBit := uint(p % 8)
  82. vByte := (*b)[pByte]
  83. vInt := int(vByte >> (8 - (pBit + 1)) & 0x0001)
  84. return vInt
  85. }
  86. func setBit(b *[]byte, p, v int) {
  87. pByte := p / 8
  88. pBit := uint(p % 8)
  89. oldByte := (*b)[pByte]
  90. var newByte byte
  91. newByte = byte(v<<(8-(pBit+1))) | oldByte
  92. (*b)[pByte] = newByte
  93. }