// maxheap.go (1.8 KB)

  1. package maxheap
  2. import (
  3. "reflect"
  4. "../array"
  5. )
  6. type MaxHeap struct {
  7. data *array.Array
  8. }
  9. func New() *MaxHeap {
  10. return &MaxHeap{
  11. data: array.New(20),
  12. }
  13. }
  14. func (a *MaxHeap) GetSize() int {
  15. return a.data.GetSize()
  16. }
  // parent returns the index of the parent of the node at index in the
  // level-order heap layout. Because Go integer division truncates toward
  // zero, parent(0) evaluates to 0; callers should not ask for the root's
  // parent.
  func parent(index int) int {
  	prev := index - 1
  	return prev / 2
  }

  // leftChild returns the index of the left child of the node at index.
  func leftChild(index int) int {
  	return 2*index + 1
  }

  // rightChild returns the index of the right child of the node at index,
  // which always sits immediately after the left child.
  func rightChild(index int) int {
  	return leftChild(index) + 1
  }
  26. func (h *MaxHeap) Add(e interface{}) {
  27. h.data.AddLast(e)
  28. h.siftUp(h.data.GetSize() - 1)
  29. }
  30. func (a *MaxHeap) siftUp(k int) {
  31. for k > 0 && Compare(a.data.Get(k), a.data.Get(parent(k))) > 0 {
  32. a.data.Swap(k, parent(k))
  33. k = parent(k)
  34. }
  35. }
  36. func (a *MaxHeap) FindMax() interface{} {
  37. if a.data.GetSize() == 0 {
  38. panic("cannot findMax when heap is empty.")
  39. }
  40. return a.data.Get(0)
  41. }
  42. func (h *MaxHeap) ExtractMax() interface{} {
  43. ret := h.FindMax()
  44. h.data.Swap(0, h.data.GetSize()-1)
  45. h.data.RemoveLast()
  46. h.siftDown(0)
  47. return ret
  48. }
  49. func (a *MaxHeap) siftDown(k int) {
  50. for leftChild(k) < a.data.GetSize() {
  51. j := leftChild(k)
  52. if j+1 < a.data.GetSize() && Compare(a.data.Get(j+1), a.data.Get(j)) > 0 {
  53. j++
  54. }
  55. if Compare(a.data.Get(k), a.data.Get(j)) > 0 {
  56. break
  57. }
  58. a.data.Swap(k, j)
  59. k = j
  60. }
  61. }
  // Compare orders two values of the same dynamic type. It returns 1 if a > b,
  // -1 if a < b, and 0 otherwise.
  //
  // Supported dynamic types are int, string and float64. Strings compare
  // byte-wise lexicographically. For float64, a NaN operand makes both < and
  // > false, so the result is 0.
  //
  // Compare panics with "cannot compare different type params" when a and b
  // have different dynamic types, including when exactly one of them is nil.
  // It panics with "unsupported type params" for any other type, including
  // when both are nil.
  func Compare(a interface{}, b interface{}) int {
  	// Compare the reflect.Type values themselves instead of their String()
  	// forms. reflect.TypeOf(nil) is a nil Type, and calling String() on it
  	// crashes with a nil-pointer dereference instead of a descriptive panic.
  	if reflect.TypeOf(a) != reflect.TypeOf(b) {
  		panic("cannot compare different type params")
  	}
  	switch av := a.(type) {
  	case int:
  		bv := b.(int)
  		switch {
  		case av > bv:
  			return 1
  		case av < bv:
  			return -1
  		}
  		return 0
  	case string:
  		bv := b.(string)
  		switch {
  		case av > bv:
  			return 1
  		case av < bv:
  			return -1
  		}
  		return 0
  	case float64:
  		bv := b.(float64)
  		switch {
  		case av > bv:
  			return 1
  		case av < bv:
  			return -1
  		}
  		return 0
  	default:
  		panic("unsupported type params")
  	}
  }