123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117 |
- package maxheap
- import (
- "reflect"
- "../array"
- )
- type MaxHeap struct {
- data *array.Array
- }
- func New() *MaxHeap {
- return &MaxHeap{
- data: array.New(20),
- }
- }
- func (a *MaxHeap) GetSize() int {
- return a.data.GetSize()
- }
// parent returns the array index of the parent of the node stored at index,
// using the 0-based layout where node i's children sit at 2i+1 and 2i+2.
//
// The root (index 0) has no parent. Go's integer division truncates toward
// zero, so (0-1)/2 would silently evaluate to 0 and make the root look like
// its own parent. parent therefore panics for index <= 0, because that can
// only come from a caller bug (siftUp guards k > 0 before calling it).
func parent(index int) int {
	if index <= 0 {
		panic("maxheap: index <= 0 has no parent")
	}
	return (index - 1) / 2
}
// leftChild maps a heap node's array index to the index of its left child
// in the 0-based level-order layout (2i+1).
func leftChild(index int) int {
	return 2*index + 1
}
// rightChild maps a heap node's array index to the index of its right child
// in the 0-based level-order layout (2i+2, i.e. one slot after the left child).
func rightChild(index int) int {
	return 2 * (index + 1)
}
- func (h *MaxHeap) Add(e interface{}) {
- h.data.AddLast(e)
- h.siftUp(h.data.GetSize() - 1)
- }
- func (a *MaxHeap) siftUp(k int) {
- for k > 0 && Compare(a.data.Get(k), a.data.Get(parent(k))) > 0 {
- a.data.Swap(k, parent(k))
- k = parent(k)
- }
- }
- func (a *MaxHeap) FindMax() interface{} {
- if a.data.GetSize() == 0 {
- panic("cannot findMax when heap is empty.")
- }
- return a.data.Get(0)
- }
- func (h *MaxHeap) ExtractMax() interface{} {
- ret := h.FindMax()
- h.data.Swap(0, h.data.GetSize()-1)
- h.data.RemoveLast()
- h.siftDown(0)
- return ret
- }
- func (a *MaxHeap) siftDown(k int) {
- for leftChild(k) < a.data.GetSize() {
- j := leftChild(k)
- if j+1 < a.data.GetSize() && Compare(a.data.Get(j+1), a.data.Get(j)) > 0 {
- j++
- }
- if Compare(a.data.Get(k), a.data.Get(j)) > 0 {
- break
- }
- a.data.Swap(k, j)
- k = j
- }
- }
// Compare reports how a orders relative to b: 1 if a > b, -1 if a < b, and
// 0 if they are equal.
//
// a and b must have the same dynamic type, otherwise Compare panics with
// "cannot compare different type params". Only int, string and float64 are
// supported. Any other type, including two nil values, panics with
// "unsupported type params". For float64, a NaN operand is neither greater
// nor less than anything, so it compares as 0 (equal). A heap that holds NaN
// may therefore not be well ordered.
func Compare(a interface{}, b interface{}) int {
	// Compare reflect.Type values directly rather than their String() forms.
	// Two Type values are equal exactly when the types are identical, whereas
	// type strings can collide across packages. In addition,
	// reflect.TypeOf(nil) is a nil Type whose String method would crash with a
	// nil-pointer runtime error instead of the intended panic message.
	if reflect.TypeOf(a) != reflect.TypeOf(b) {
		panic("cannot compare different type params")
	}
	// The type check above guarantees that each b.(T) assertion below succeeds.
	switch av := a.(type) {
	case int:
		bv := b.(int)
		switch {
		case av > bv:
			return 1
		case av < bv:
			return -1
		}
		return 0
	case string:
		bv := b.(string)
		switch {
		case av > bv:
			return 1
		case av < bv:
			return -1
		}
		return 0
	case float64:
		bv := b.(float64)
		switch {
		case av > bv:
			return 1
		case av < bv:
			return -1
		}
		return 0
	default:
		panic("unsupported type params")
	}
}
|