前言

学习算法一般都会是从最排序开始,排序算法中堆排序效率相对来说也算是比较高的了。Golang中并没有可以直接使用的堆结构,JAVA中有现成的优先级队列可以使用,那么Golang中如何优雅实现一个比较通用的堆结构呢,需要满足:

  1. 类似JAVA优先级队列用法,可以指定比较器。
  2. 通用的数据结构。

使用泛型

对于通用组件接收的数据最好使用泛型来完成,好处就是不用进行数据断言和类型转换。

定义比较器

type CompFn[T any] func(a, b T) int

定义优先级队列对象

type PriorityQueue[T any] struct {
    base []T
    comp CompFn[T]
}

定义默认比较器

对于基础类型,比如int,我们想开箱即用,不用显示传递比较器就能完成大根堆的构建

func defaultComparator[T constraints.Ordered](a, b T) int {
    if a < b {
        return -1
    } else if a > b {
        return 1
    }

    return 0
}

注意,这里泛型类型使用的是constraints包下的Ordered,如果使用any类型,编译是会报错的,原因是any类型是不可以进行比较的。

所以需要导入constraints包:

import (
    "golang.org/x/exp/constraints"
)

工厂函数

func NewPQueue[T any](comp ...CompFn[T]) *PriorityQueue[T] {
    pq := &PriorityQueue[T]{
        base: make([]T, 0),
    }

    for _, f := range comp {
        if f != nil {
            pq.comp = f
            break
        }
    }

    if pq.comp == nil {
        var zero T
        switch any(zero).(type) {
        case int:
            pq.comp = func(a, b T) int { return defaultComparator(any(a).(int), any(b).(int)) }
        case int8:
            pq.comp = func(a, b T) int { return defaultComparator(any(a).(int8), any(b).(int8)) }
        case int16:
            pq.comp = func(a, b T) int { return defaultComparator(any(a).(int16), any(b).(int16)) }
        case int32:
            pq.comp = func(a, b T) int { return defaultComparator(any(a).(int32), any(b).(int32)) }
        case int64:
            pq.comp = func(a, b T) int { return defaultComparator(any(a).(int64), any(b).(int64)) }
        case uint:
            pq.comp = func(a, b T) int { return defaultComparator(any(a).(uint), any(b).(uint)) }
        case uint8:
            pq.comp = func(a, b T) int { return defaultComparator(any(a).(uint8), any(b).(uint8)) }
        case uint16:
            pq.comp = func(a, b T) int { return defaultComparator(any(a).(uint16), any(b).(uint16)) }
        case uint32:
            pq.comp = func(a, b T) int { return defaultComparator(any(a).(uint32), any(b).(uint32)) }
        case uint64:
            pq.comp = func(a, b T) int { return defaultComparator(any(a).(uint64), any(b).(uint64)) }
        case uintptr:
            pq.comp = func(a, b T) int { return defaultComparator(any(a).(uintptr), any(b).(uintptr)) }
        case float32:
            pq.comp = func(a, b T) int { return defaultComparator(any(a).(float32), any(b).(float32)) }
        case float64:
            pq.comp = func(a, b T) int { return defaultComparator(any(a).(float64), any(b).(float64)) }
        case string:
            pq.comp = func(a, b T) int { return defaultComparator(any(a).(string), any(b).(string)) }
        }
    }

    if pq.comp == nil {
        panic("key is not ordered,so must provide compare function")
    }

    return pq
}

参数中使用了变长参数,目的是为了实现默认的效果,就是说当没有给定比较器时使用默认比较器,结合上面所说,默认比较器并不是any类型,所以用了比较长的篇幅来实现类型转换,本人水平有限还没有想到更优雅的实现方式。

最后的判断语句:

if pq.comp == nil {
    panic("key is not ordered,so must provide compare function")
}

触发的条件时,如果排队的对象不是基础类型,比如自定义的结构体,并且没有指定比较器则会进行panic。

主体代码实现

func (p *PriorityQueue[T]) Size() int {
    return len(p.base)
}

func (p *PriorityQueue[T]) IsEmpty() bool {
    return p.Size() == 0
}

func (p *PriorityQueue[T]) Push(a T) {
    p.base = append(p.base, a)
    p.siftUp(len(p.base) - 1)
}

func (p *PriorityQueue[T]) Pop() (T, bool) {
    if p.IsEmpty() {
        var zero T
        return zero, false
    }

    ans := p.base[0]
    p.base[0] = p.base[len(p.base)-1]
    p.base = p.base[:len(p.base)-1]
    if !p.IsEmpty() {
        p.siftDown(0)
    }

    return ans, true
}

func (p *PriorityQueue[T]) Peek() (T, bool) {
    if p.IsEmpty() {
        var zero T
        return zero, false
    }

    return p.base[0], true
}

func (p *PriorityQueue[T]) siftUp(index int) {
    for p.comp(p.base[index], p.base[(index-1)/2]) > 0 {
        p.base[index], p.base[(index-1)/2] = p.base[(index-1)/2], p.base[index]
        index = (index - 1) / 2
    }
}

func (p *PriorityQueue[T]) siftDown(index int) {
    l := index*2 + 1
    size := len(p.base)
    for l < size {
        best := l
        if l+1 < size && p.comp(p.base[l+1], p.base[l]) > 0 {
            best = l + 1
        }

        if p.comp(p.base[index], p.base[best]) > 0 {
            best = index
        }

        if best == index {
            break
        }

        p.base[index], p.base[best] = p.base[best], p.base[index]
        index = best
        l = index*2 + 1
    }
}

在Pop方法中,对于泛型需要返回默认值需要注意,这里使用的是定义zero临时变量然后返回的方式,也可以使用具名返回值的形式。

单元测试和用法:

type User struct {
    company int
    age     int
    name    string
}

// 实现小根堆,所以要手动指定比较器
func Test_priorityQueue(t *testing.T) {
    pq := NewPQueue(func(a, b int) int {
        if a < b {
            return 1
        } else if a > b {
            return -1
        }

        return 0
    })
    pq.Push(5)
    pq.Push(3)
    pq.Push(7)

    t.Error(pq.Peek())
    t.Error(pq.Pop())
    t.Error(pq.Peek())
}

// 自定义对象使用比较器
func Test_priorityQueue1(t *testing.T) {
    pq := NewPQueue(func(a, b User) int {
        if a.company != b.company {
            return a.company - b.company
        } else if a.age != b.age {
            return a.age - b.age
        } else {
            return strings.Compare(a.name, b.name)
        }
    })
    pq.Push(User{company: 1, age: 18, name: "a"})
    pq.Push(User{company: 3, age: 19, name: "b"})
    pq.Push(User{company: 2, age: 17, name: "a"})

    t.Error(pq.Peek())
    t.Error(pq.Pop())
    t.Error(pq.Peek())
}
最后修改:2025 年 09 月 13 日
如果觉得我的文章对你有用,请随意赞赏