前言
学习算法一般都会是从最排序开始,排序算法中堆排序效率相对来说也算是比较高的了。Golang中并没有可以直接使用的堆结构,JAVA中有现成的优先级队列可以使用,那么Golang中如何优雅实现一个比较通用的堆结构呢,需要满足:
- 类似JAVA优先级队列用法,可以指定比较器。
- 通用的数据结构。
使用泛型
对于通用组件接收的数据最好使用泛型来完成,好处就是不用进行数据断言和类型转换。
定义比较器
type CompFn[T any] func(a, b T) int
定义优先级队列对象
type PriorityQueue[T any] struct {
base []T
comp CompFn[T]
}
定义默认比较器
对于基础类型,比如int,我们想开箱即用,不用显示传递比较器就能完成大根堆的构建
func defaultComparator[T constraints.Ordered](a, b T) int {
if a < b {
return -1
} else if a > b {
return 1
}
return 0
}
注意,这里泛型类型使用的是constraints包下的Ordered,如果使用any类型,编译是会报错的,原因是any类型是不可以进行比较的。
所以需要导入constraints包:
import (
"golang.org/x/exp/constraints"
)
工厂函数
func NewPQueue[T any](comp ...CompFn[T]) *PriorityQueue[T] {
pq := &PriorityQueue[T]{
base: make([]T, 0),
}
for _, f := range comp {
if f != nil {
pq.comp = f
break
}
}
if pq.comp == nil {
var zero T
switch any(zero).(type) {
case int:
pq.comp = func(a, b T) int { return defaultComparator(any(a).(int), any(b).(int)) }
case int8:
pq.comp = func(a, b T) int { return defaultComparator(any(a).(int8), any(b).(int8)) }
case int16:
pq.comp = func(a, b T) int { return defaultComparator(any(a).(int16), any(b).(int16)) }
case int32:
pq.comp = func(a, b T) int { return defaultComparator(any(a).(int32), any(b).(int32)) }
case int64:
pq.comp = func(a, b T) int { return defaultComparator(any(a).(int64), any(b).(int64)) }
case uint:
pq.comp = func(a, b T) int { return defaultComparator(any(a).(uint), any(b).(uint)) }
case uint8:
pq.comp = func(a, b T) int { return defaultComparator(any(a).(uint8), any(b).(uint8)) }
case uint16:
pq.comp = func(a, b T) int { return defaultComparator(any(a).(uint16), any(b).(uint16)) }
case uint32:
pq.comp = func(a, b T) int { return defaultComparator(any(a).(uint32), any(b).(uint32)) }
case uint64:
pq.comp = func(a, b T) int { return defaultComparator(any(a).(uint64), any(b).(uint64)) }
case uintptr:
pq.comp = func(a, b T) int { return defaultComparator(any(a).(uintptr), any(b).(uintptr)) }
case float32:
pq.comp = func(a, b T) int { return defaultComparator(any(a).(float32), any(b).(float32)) }
case float64:
pq.comp = func(a, b T) int { return defaultComparator(any(a).(float64), any(b).(float64)) }
case string:
pq.comp = func(a, b T) int { return defaultComparator(any(a).(string), any(b).(string)) }
}
}
if pq.comp == nil {
panic("key is not ordered,so must provide compare function")
}
return pq
}
参数中使用了变长参数,目的是为了实现默认的效果,就是说当没有给定比较器时使用默认比较器,结合上面所说,默认比较器并不是any类型,所以用了比较长的篇幅来实现类型转换,本人水平有限还没有想到更优雅的实现方式。
最后的判断语句:
if pq.comp == nil {
panic("key is not ordered,so must provide compare function")
}
触发的条件时,如果排队的对象不是基础类型,比如自定义的结构体,并且没有指定比较器则会进行panic。
主体代码实现
func (p *PriorityQueue[T]) Size() int {
return len(p.base)
}
func (p *PriorityQueue[T]) IsEmpty() bool {
return p.Size() == 0
}
func (p *PriorityQueue[T]) Push(a T) {
p.base = append(p.base, a)
p.siftUp(len(p.base) - 1)
}
func (p *PriorityQueue[T]) Pop() (T, bool) {
if p.IsEmpty() {
var zero T
return zero, false
}
ans := p.base[0]
p.base[0] = p.base[len(p.base)-1]
p.base = p.base[:len(p.base)-1]
if !p.IsEmpty() {
p.siftDown(0)
}
return ans, true
}
func (p *PriorityQueue[T]) Peek() (T, bool) {
if p.IsEmpty() {
var zero T
return zero, false
}
return p.base[0], true
}
func (p *PriorityQueue[T]) siftUp(index int) {
for p.comp(p.base[index], p.base[(index-1)/2]) > 0 {
p.base[index], p.base[(index-1)/2] = p.base[(index-1)/2], p.base[index]
index = (index - 1) / 2
}
}
func (p *PriorityQueue[T]) siftDown(index int) {
l := index*2 + 1
size := len(p.base)
for l < size {
best := l
if l+1 < size && p.comp(p.base[l+1], p.base[l]) > 0 {
best = l + 1
}
if p.comp(p.base[index], p.base[best]) > 0 {
best = index
}
if best == index {
break
}
p.base[index], p.base[best] = p.base[best], p.base[index]
index = best
l = index*2 + 1
}
}
在Pop方法中,对于泛型需要返回默认值需要注意,这里使用的是定义zero临时变量然后返回的方式,也可以使用具名返回值的形式。
单元测试和用法:
type User struct {
company int
age int
name string
}
// 实现小根堆,所以要手动指定比较器
func Test_priorityQueue(t *testing.T) {
pq := NewPQueue(func(a, b int) int {
if a < b {
return 1
} else if a > b {
return -1
}
return 0
})
pq.Push(5)
pq.Push(3)
pq.Push(7)
t.Error(pq.Peek())
t.Error(pq.Pop())
t.Error(pq.Peek())
}
// 自定义对象使用比较器
func Test_priorityQueue1(t *testing.T) {
pq := NewPQueue(func(a, b User) int {
if a.company != b.company {
return a.company - b.company
} else if a.age != b.age {
return a.age - b.age
} else {
return strings.Compare(a.name, b.name)
}
})
pq.Push(User{company: 1, age: 18, name: "a"})
pq.Push(User{company: 3, age: 19, name: "b"})
pq.Push(User{company: 2, age: 17, name: "a"})
t.Error(pq.Peek())
t.Error(pq.Pop())
t.Error(pq.Peek())
}
版权属于:redhat
本文链接:https://blog.zhangshuocauc.cn/archives/25/
本站采用“知识共享署名 - 非商业性使用 - 相同方式共享 4.0 中国大陆许可协议” 进行许可 您可以转载本站的技术类文章,转载时请以超链接形式标明文章原始出处,Thank you.