首页 Go 实现一个IntSet
文章
取消

Go 实现一个IntSet

Go 中没有 set 这种数据结构,我们常常需要一种 set 数据结构来实现集合的包含、交集、并集、差集、对称差运算。一种思路是使用内置的 map,它能实现的元素类型较为丰富,可以是 string、int、float 等等可以被比较的类型;另一种思路是使用 bit 数组,它对于 int 类型的元素集合可以极大节约内存使用,一个 int64,即一个64位 bit 数组,可以最多存储64位整数,使用二进制的运算特性可以很方便的得出集合运算结果,接下来我们使用后者来实现一个 IntSet。

定义

结构定义

刚刚我们提到使用 int64 来实现 bit 数组,为了存储更多的整数,需要用 int64 数组来扩充存储。我们定义 word 为一个 int64,word 的数量为size:

1
2
3
4
type IntSet struct {
	words []uint64
	size  uint64
}

这样一个简单的 IntSet 就定义好了。

包含

检查是否包含一个整数,只要检查相应的 bit 位 是否为 1 即可,为 1 的话就是包含,否则不包含。

1
2
3
4
5
// Contains returns whether num exists in IntSet s
func (s *IntSet) Contains(num uint64) bool {
	pos, bit := num/64, num%64
	return pos < s.size && s.words[pos]&(1<<bit) != 0
}

添加元素

添加元素,即找到相应的 bit 位,置 1 即可,

1
2
3
4
5
6
7
8
9
10
11
12
13
// Add add nums into IntSet s
func (s *IntSet) Add(nums ...uint64) {
	for _, num := range nums {
		pos, bit := num/64, num%64
		for pos >= s.size {
			// 扩充 s.words,先用0占位
			s.words = append(s.words, 0)
			s.size++
		}
		// 将 pos 位置的 word 的 bit 位置为 1
		s.words[pos] |= 1 << bit
	}
}

删除元素

删除元素,即找到相应的 bit 位,置 0 即可,

1
2
3
4
5
6
7
8
9
10
// Remove removes nums from IntSet s
func (s *IntSet) Remove(nums ...uint64) {
	for _, num := range nums {
		pos, bit := num/64, num%64
		if pos < s.size && s.words[pos]&(1<<bit) != 0 {
			// 如果存在 num 的话,将 pos 位置的 word 的 bit 为置为 0
			s.words[pos] &= ^(1 << bit)
		}
	}
}

复制

要复制 words 不能直接赋值给一个新变量,会共用底层数组,可以对新 slice 一次次 append,也可以创建同大小的 slice,再使用 copy 把元素复制过去,后者的效率更高些:

1
2
3
4
5
6
7
8
9
// Copy returns a copy of IntSet s
func (s *IntSet) Copy() *IntSet {
	r := IntSet{
		words: make([]uint64, s.size),
		size:  s.size,
	}
	copy(r.words, s.words)
	return &r
}

交集

交集,是指同时存在于两个集合的元素,如 {1,2,3} 与 {2,3,4} 的交集为 {2,3}。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
// Intersection returns the intersection of IntSet s and t
func (s *IntSet) Intersection(t *IntSet) *IntSet {
	r := s.Copy()

	if r.size > t.size {
		// 截断 r.words,使其与 t.words 等长
		r.words = r.words[:t.size]
		r.size = t.size
	}
	var i uint64
	for ; i < r.size; i++ {
		// 逐 word 与运算
		r.words[i] &= t.words[i]
	}
	return r
}

并集

并集,是指包含两个集合的所有不重复元素,如 {1,2,3} 与 {2,3,4} 的并集为 {1,2,3,4}。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
// Union returns the union of IntSet s and t
func (s *IntSet) Union(t *IntSet) *IntSet {
	r := s.Copy()
	for idx, tw := range t.words {
		if uint64(idx) < r.size {
			r.words[idx] |= tw
		} else {
			// t.words 比 s.words 长
			r.words = append(r.words, tw)
			r.size++
		}
	}
	return r
}

差集

对于集合 a 与 b,差集 a-b,意味着存在于 a,但不存在于 b 中的元素,如 {1,2,3}-{2,3,4}={1}。

1
2
3
4
5
6
7
8
9
10
11
12
13
// Difference returns difference of IntSet s and t
func (s *IntSet) Difference(t *IntSet) *IntSet {
	r := s.Copy()
	for idx := range r.words {
		for bit := 0; bit < 64; bit++ {
			num := uint64(idx*64 + bit)
			if t.Contains(num) {
				r.Remove(num)
			}
		}
	}
	return r
}

对称差

对称差,是只存在于一方中的元素,如 {1,2,3} 与 {2,3,4} 的对称差为 {1,4}。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
// SymmetricDifference returns the symmetric difference of s and t
func (s *IntSet) SymmetricDifference(t *IntSet) *IntSet {
	var r IntSet
	if s.size >= t.size {
		for idx := range s.words {
			if uint64(idx) < t.size {
				r.words = append(r.words, s.words[idx]^t.words[idx])
			} else {
				r.words = append(r.words, s.words[idx]^0)
			}
			r.size++
		}
	} else {
		for idx := range t.words {
			if uint64(idx) < s.size {
				r.words = append(r.words, s.words[idx]^t.words[idx])
			} else {
				r.words = append(r.words, t.words[idx]^0)
			}
			r.size++
		}
	}
	return &r
}

是否是子集

1
2
3
4
5
6
7
8
9
10
11
12
13
14
// IsSubset returns whether Intset s is subset of t
func (s *IntSet) IsSubset(t *IntSet) bool {
	if s.size > t.size || s.size == 0 {
		return false
	}
	var idx uint64
	for ; idx < s.size; idx++ {
		if s.words[idx]&t.words[idx] != s.words[idx] {
            // 如果 s.words[idx] 是 t.words[idx] 子集的话,s.words[idx]&t.words[idx] 一定等于 s.words[idx]
			return false
		}
	}
	return true
}

长度

这里的长度是存储的整数的数量,即存储 bit 数组中有多少个 1:

1
2
3
4
5
6
7
8
9
10
11
// Len returns the count of numbers that exist in IntSet
func (s *IntSet) Len() int {
	var length int
	for _, word := range s.words {
		for word > 0 {
			word = word & (word - 1)
			length++
		}
	}
	return length
}

清空

清空最容易,s.words 置空,s.size 置 0 即可。

1
2
3
4
5
// Clear empties itself
func (s *IntSet) Clear() {
	s.words = nil
	s.size = 0
}

输出 slice

需要遍历找出所有值为 1 的bit 位,

1
2
3
4
5
6
7
8
9
10
11
12
// Slice returns the numbers in IntSet
func (s *IntSet) Slice() []uint64 {
	var result []uint64
	for idx, word := range s.words {
		for bit := 0; bit < 64; bit++ {
			if word&(1<<bit) != 0 {
				result = append(result, uint64(64*idx+bit))
			}
		}
	}
	return result
}

输出 string

同上,只是以 string 形式输出,

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
// String returns string represention
func (s *IntSet) String() string {
	var buf bytes.Buffer
	buf.WriteByte('{')
	for idx, word := range s.words {
		if word == 0 {
			continue
		}
		for bit := 0; bit < 64; bit++ {
			if word&(1<<bit) != 0 {
				buf.WriteByte(' ')
				fmt.Fprintf(&buf, "%d", idx*64+bit)
			}
		}
	}
	buf.WriteByte(' ')
	buf.WriteByte('}')
	return buf.String()
}

小结

我们将上面代码整理下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
package ds

import (
	"bytes"
	"fmt"
)

type IntSet struct {
	words []uint64
	size  uint64
}

// Contains returns whether num exists in IntSet s
func (s *IntSet) Contains(num uint64) bool {
	pos, bit := num/64, num%64
	return pos < s.size && s.words[pos]&(1<<bit) != 0
}

// Copy returns a copy of IntSet s
func (s *IntSet) Copy() *IntSet {
	r := IntSet{
		words: make([]uint64, s.size),
		size:  s.size,
	}
	copy(r.words, s.words)
	return &r
}

// Add add nums into IntSet s
func (s *IntSet) Add(nums ...uint64) {
	for _, num := range nums {
		pos, bit := num/64, num%64
		for pos >= s.size {
			// 扩充 s.words,先用0占位
			s.words = append(s.words, 0)
			s.size++
		}
		// 将 pos 位置的 word 的 bit 位置为 1
		s.words[pos] |= 1 << bit
	}
}

// Remove removes nums from IntSet s
func (s *IntSet) Remove(nums ...uint64) {
	for _, num := range nums {
		pos, bit := num/64, num%64
		if pos < s.size && s.words[pos]&(1<<bit) != 0 {
			// 如果存在 num 的话,将 pos 位置的 word 的 bit 为置为 0
			s.words[pos] &= ^(1 << bit)
		}
	}
}

// Intersection returns the intersection of IntSet s and t
func (s *IntSet) Intersection(t *IntSet) *IntSet {
	r := s.Copy()

	if r.size > t.size {
		// 截断 r.words,使其与 t.words 等长
		r.words = r.words[:t.size]
		r.size = t.size
	}
	var i uint64
	for ; i < r.size; i++ {
		// 逐 word 与运算
		r.words[i] &= t.words[i]
	}
	return r
}

// Union returns the union of IntSet s and t
func (s *IntSet) Union(t *IntSet) *IntSet {
	r := s.Copy()
	for idx, tw := range t.words {
		if uint64(idx) < r.size {
			r.words[idx] |= tw
		} else {
			// t.words 比 s.words 长
			r.words = append(r.words, tw)
			r.size++
		}
	}
	return r
}

// Difference returns difference of IntSet s and t
func (s *IntSet) Difference(t *IntSet) *IntSet {
	r := s.Copy()
	for idx := range r.words {
		for bit := 0; bit < 64; bit++ {
			num := uint64(idx*64 + bit)
			if t.Contains(num) {
				r.Remove(num)
			}
		}
	}
	return r
}

// SymmetricDifference returns the symmetric difference of s and t
func (s *IntSet) SymmetricDifference(t *IntSet) *IntSet {
	var r IntSet
	if s.size >= t.size {
		for idx := range s.words {
			if uint64(idx) < t.size {
				r.words = append(r.words, s.words[idx]^t.words[idx])
			} else {
				r.words = append(r.words, s.words[idx]^0)
			}
			r.size++
		}
	} else {
		for idx := range t.words {
			if uint64(idx) < s.size {
				r.words = append(r.words, s.words[idx]^t.words[idx])
			} else {
				r.words = append(r.words, t.words[idx]^0)
			}
			r.size++
		}
	}
	return &r
}

// IsSubset returns whether Intset s is subset of t
func (s *IntSet) IsSubset(t *IntSet) bool {
	if s.size > t.size || s.size == 0 {
		return false
	}
	var idx uint64
	for ; idx < s.size; idx++ {
		if s.words[idx]&t.words[idx] != s.words[idx] {
			// 如果 s.words[idx] 是 t.words[idx] 子集的话,s.words[idx]&t.words[idx] 一定等于 s.words[idx]
			return false
		}
	}
	return true
}

// Len returns the count of numbers that exist in IntSet
func (s *IntSet) Len() int {
	var length int
	for _, word := range s.words {
		for word > 0 {
			word = word & (word - 1)
			length++
		}
	}
	return length
}

// Clear empties itself
func (s *IntSet) Clear() {
	s.words = nil
	s.size = 0
}

// Slice returns the numbers in IntSet
func (s *IntSet) Slice() []uint64 {
	var result []uint64
	for idx, word := range s.words {
		for bit := 0; bit < 64; bit++ {
			if word&(1<<bit) != 0 {
				result = append(result, uint64(64*idx+bit))
			}
		}
	}
	return result
}

// String returns string represention
func (s *IntSet) String() string {
	var buf bytes.Buffer
	buf.WriteByte('{')
	for idx, word := range s.words {
		if word == 0 {
			continue
		}
		for bit := 0; bit < 64; bit++ {
			if word&(1<<bit) != 0 {
				buf.WriteByte(' ')
				fmt.Fprintf(&buf, "%d", idx*64+bit)
			}
		}
	}
	buf.WriteByte(' ')
	buf.WriteByte('}')
	return buf.String()
}

我们来测试下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
func main() {
	var a, b, c ds.IntSet

	a.Add(1, 5, 10, 3, 69, 300)
	b.Add(0, 5, 10, 11)
	c.Add(5, 3, 10)

	fmt.Printf("b是a子集?%t \n", b.IsSubset(&a))
	fmt.Printf("c是a子集?%t \n", c.IsSubset(&a))

	fmt.Printf("a长度:%d \n", a.Len())
	fmt.Printf("a转slice:%v \n", a.Slice())
	fmt.Printf("b转slice:%v \n", b.Slice())

	fmt.Printf("a与b并集:%s \n", a.Union(&b).String())
	fmt.Printf("a与b交集:%s \n", a.Intersection(&b).String())
	fmt.Printf("a与b对称差:%s \n", a.SymmetricDifference(&b).String())
	fmt.Printf("a对b差集:%s \n", a.Difference(&b).String())

	fmt.Printf("a包含50? %t \n", a.Contains(50))
	fmt.Printf("a包含10? %t \n", a.Contains(10))

	a.Remove(5, 6)
	fmt.Printf("a移除5,6后:%s \n", a.String())

	a.Remove(69)
	fmt.Printf("a移除69后:%s \n", a.String())

	a.Clear()
	fmt.Printf("a清空后:%s \n", a.String())
}

输出:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
b是a子集?false
c是a子集?true
a长度:6 
a转slice:[1 3 5 10 69 300] 
b转slice:[0 5 10 11] 
a与b并集:{ 0 1 3 5 10 11 69 300 } 
a与b交集:{ 5 10 } 
a与b对称差:{ 0 1 3 11 69 300 }
a对b差集:{ 1 3 69 300 }
a包含50? false 
a包含10? true 
a移除5,6后:{ 1 3 10 69 300 } 
a移除69后:{ 1 3 10 300 } 
a清空后:{ } 

不过该 IntSet 有其局限性:

  1. 不能存储负数,一个非负整数存储在 bit 数组中只需要1bit,但如果要支持负数存储,在 bit 数组中就至少需要3bit,一个表示数值的绝对值是否存在,一个表示是否是正数,一个表示是否是负数。
  2. 32位系统中使用 uint64 效率较差。
  3. 存储非负整数数量有限,slice 增长时如果容量过大会抛出异常 panic: runtime error: growslice: cap out of range,不过对于一般情况下够用了。
  4. 非线程安全
本文由作者按照 CC BY 4.0 进行授权