golang系列之-sync.Pool

sync.Pool-临时对象池,是golang一个很关键的数据结构,通过复用历史对象,缓解因频繁创建、删除对象而导致的内存分配压力、GC压力,在社区中被广泛使用,有如go-gin、kubernetes等

当前go版本:1.24

快速上手

下面展示一个简单的使用示例,用于帮助用户快速上手

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
package main

import (
"fmt"
"sync"
)

type JobState int

const (
JobStateFresh JobState = iota
JobStateRunning
JobStateRecycled
)

type Job struct {
state JobState
}

func (j *Job) Run() {
switch j.state {
case JobStateRecycled:
fmt.Println("this job came from the pool")
case JobStateFresh:
fmt.Println("this job just got allocated")
}

j.state = JobStateRunning
}

func main() {
// 创建一个对象池
pool := &sync.Pool{
New: func() any {
return &Job{state: JobStateFresh}
},
}

// 获取一个对象,可以是新建的或者是历史使用过的
job := pool.Get().(*Job)

// 执行业务代码
job.Run()

// reset状态并放回池子里,方便下次使用
job.state = JobStateRecycled
pool.Put(job)
}

数据结构

todo:文章图片待补充

sync.Pool的源代码注释被我删除了,建议自行查看源代码,简单总结如下

  1. sync.Pool用于临时对象服务存储/获取(临时对象可能随时被清理掉)
  2. sync.Pool是线程安全的
  3. sync.Pool使用后不应该也不能被复制
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
// Pool -> local -> poolLocal_0 -> poolChainElt_0(head) -> poolDequeue(2^(3+N)
// ↕
// -> poolChainElt_1 -> poolDequeue(2^(3+N-1)
// -> ...
// -> poolChainElt_N(tail) -> poolDequeue(2^3=8个数据)
// -> ...
// -> poolLocal_P
//
// -> victim
//

// src/sync/pool.go
type Pool struct {
noCopy noCopy
local unsafe.Pointer // 临时对象数组指针,真实结构是[P]poolLocal,每个P一个poolLocal链表
localSize uintptr // local数组的大小,一般情况下与P的数量相同
victim unsafe.Pointer // 前local数组,被GC搬过来的
victimSize uintptr // 前local数组的大小
New func() any // New 用于创建临时对象,如果池子内没有数据的话
}

type poolLocal struct {
poolLocalInternal // 实际数据存储位置-链表

// Prevents false sharing on widespread platforms with
// 128 mod (cache line size) = 0 .
// pad在amd64平台下是96个字节大小
pad [128 - unsafe.Sizeof(poolLocalInternal{})%128]byte
}

type poolLocalInternal struct {
private any // 单个数据,优化读写,该字段只被当前P访问
shared poolChain // 链表,当前P读写都在头部,其他P没数据时从末尾偷
}

// src/sync/poolqueue.go
type poolChain struct {
// 新的poolChainElt在head,容量是上一个poolChainElt的双倍
head *poolChainElt // 头部 数据读写 => 1 write
tail atomic.Pointer[poolChainElt] // 尾部 数据读取 => N read
}

type poolChainElt struct {
poolDequeue // 数据部份-环形数组

// prev指向旧poolChainElt
// next指向新poolChainElt
next, prev atomic.Pointer[poolChainElt]
}

type poolDequeue struct {
// ptr -> | tail | | | | head | |
// idx -> | 0 | 1 | ... | 98 | 99 | ... |
// val -> | 1 | 2 | ... | 98 | nil | nil |
headTail atomic.Uint64 // head(高32bit) + tail(低32bit)
vals []eface // 环形数组 初始容量=8 => 2的乘方,最大不能超过2^30 => 1GB
}

注意:这里有几个全局变量用于纪录所有创建的池子

1
2
3
4
5
6
7
8
var (
// Pool创建或者local扩容时使用
allPoolsMu Mutex
// 所有创建的Pool都纪录到这里
allPools []*Pool
// allPools的上一个历史版本,每次GC都会将allPools移动到oldPools
oldPools []*Pool
)

读写操作

辅助函数

Get和Put操作都依赖的几个公共方法放在这里,可以先看后续的读写代码再回头看这部份

  1. pin - 用于绑定goroutine和P,阻止进入抢占模式,并返回pid
  2. indexLocal - 获取local数组中指定的P的poolLocal索引
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
// 返回poolLocal和id
func (p *Pool) pin() (*poolLocal, int) {
// pool不能为nil
if p == nil {
panic("nil Pool")
}

// 获取pid
pid := runtime_procPin()
// 数据量
s := runtime_LoadAcquintptr(&p.localSize) // load-acquire
// poolLocal
l := p.local // load-consume
// 正常情况或P缩小了
if uintptr(pid) < s {
return indexLocal(l, pid), pid
}
// P扩容了,池子不够大
return p.pinSlow()
}

func (p *Pool) pinSlow() (*poolLocal, int) {
//
runtime_procUnpin()
// 加锁
allPoolsMu.Lock()
defer allPoolsMu.Unlock()
// 重新pin
pid := runtime_procPin()
// 加锁后直接获取size和poolLocal
s := p.localSize
l := p.local
// 其他goroutine完成了扩容?
if uintptr(pid) < s {
return indexLocal(l, pid), pid
}
// 新的池子,注册到allPools
if p.local == nil {
allPools = append(allPools, p)
}
// 如果GOMAXPROCS有改动,生成新的poolLocal替换旧的
size := runtime.GOMAXPROCS(0)
// 一个P一个poolLocal
local := make([]poolLocal, size)
atomic.StorePointer(&p.local, unsafe.Pointer(&local[0])) // store-release
runtime_StoreReluintptr(&p.localSize, uintptr(size)) // store-release
return &local[pid], pid
}

// local数组索引,根据P的值定位poolLocal
func indexLocal(l unsafe.Pointer, i int) *poolLocal {
lp := unsafe.Pointer(uintptr(l) + uintptr(i)*unsafe.Sizeof(poolLocal{}))
return (*poolLocal)(lp)
}

获取一个对象

大概逻辑如下

  1. 根据P的值定位poolLocal链表
  2. 如果private有数值,返回该值
  3. 从shared(head)获取一个数据
  4. 从所有P的poolLocal链表找数据(从下一个P开始)
    • 从local查找 -> private -> shared(tail)
    • 从victim查找(只使用一次) -> shared(tail)
  5. 使用New方法生成一个新对象
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
func (p *Pool) Get() any {
// 获取pid和poolLocal链表
l, pid := p.pin()
// 直接读取private字段,如果有数据的话,这是一个优化,避免去查队列
x := l.private
l.private = nil
// 如果private没有数据
if x == nil {
// 从shared头部获取一个数据
x, _ = l.shared.popHead()
// 还是没有
if x == nil {
// 从别的P偷一个回来
x = p.getSlow(pid)
}
}
// unpin
runtime_procUnpin()
// 也没偷到
if x == nil && p.New != nil {
// 调用New函数
x = p.New()
}
return x
}

func (p *Pool) getSlow(pid int) any {
size := runtime_LoadAcquintptr(&p.localSize)
locals := p.local

// 1. 在local找
// P可能扩容也可能缩容
for i := 0; i < int(size); i++ {
// local数组索引,根据P的值定位poolLocal
// 遍历所有P,从下一个P开始
l := indexLocal(locals, (pid+i+1)%int(size))
// 从shared尾部获取一个数据
if x, _ := l.shared.popTail(); x != nil {
return x
}
}

// 2. 在victim找
size = atomic.LoadUintptr(&p.victimSize)
// 1. gc后P扩容了,没有当前P的数据
// 2. victim被其他P访问过了
// 3. victim为空
if uintptr(pid) >= size {
return nil
}
// 没有扩容或者缩容了
locals = p.victim
// local数组索引,根据P的值定位poolLocal
l := indexLocal(locals, pid)
// 下面同Get方法
// -> private -> shared
if x := l.private; x != nil {
l.private = nil
return x
}
for i := 0; i < int(size); i++ {
l := indexLocal(locals, (pid+i)%int(size))
if x, _ := l.shared.popTail(); x != nil {
return x
}
}

// victimSize设置为0
atomic.StoreUintptr(&p.victimSize, 0)

return nil
}

// src/sync/poolqueue.go
func (c *poolChain) popHead() (any, bool) {
// poolChainElt
d := c.head
// 从head扫描到tail
for d != nil {
// 找到一个数据
if val, ok := d.popHead(); ok {
return val, ok
}
// prev
d = d.prev.Load()
}
return nil, false
}

func (c *poolChain) popTail() (any, bool) {
// poolChainElt
d := c.tail.Load()
// shared链表是空的
if d == nil {
return nil, false
}

for {
// next
d2 := d.next.Load()

// 当前d找到一个数据
if val, ok := d.popTail(); ok {
return val, ok
}

// next为nil
if d2 == nil {
return nil, false
}

// 当前d没数据,next不为nil,修改tail指针以及prev指针
if c.tail.CompareAndSwap(d, d2) {
d2.prev.Store(nil)
}
d = d2
}
}

func (d *poolDequeue) popHead() (any, bool) {
var slot *eface
// vals是一个环形数组
for {
ptrs := d.headTail.Load()
// 从headTail解析
head, tail := d.unpack(ptrs)
// 数组是空的
if tail == head {
return nil, false
}

// 挪动指针
head--
ptrs2 := d.pack(head, tail)
// 回写成功
if d.headTail.CompareAndSwap(ptrs, ptrs2) {
slot = &d.vals[head&uint32(len(d.vals)-1)]
break
}
}

val := *(*any)(unsafe.Pointer(slot))
// nil特殊处理判断
if val == dequeueNil(nil) {
val = nil
}
// 默认值
*slot = eface{}
return val, true
}

func (d *poolDequeue) popTail() (any, bool) {
var slot *eface
// vals是一个环形数组
for {
ptrs := d.headTail.Load()
// 从headTail解析
head, tail := d.unpack(ptrs)
// 数组是空的
if tail == head {
return nil, false
}

//
ptrs2 := d.pack(head, tail+1)
// 回写成功
if d.headTail.CompareAndSwap(ptrs, ptrs2) {
slot = &d.vals[tail&uint32(len(d.vals)-1)]
break
}
}

// We now own slot.
val := *(*any)(unsafe.Pointer(slot))
// nil特殊处理判断
if val == dequeueNil(nil) {
val = nil
}

// 置为nil
slot.val = nil
atomic.StorePointer(&slot.typ, nil) // 读写判断该字段

return val, true
}

保存一个对象

大概逻辑如下

  1. 如果对象为nil,不处理
  2. 根据P的值定位poolLocal链表
  3. 如果private为nil,直接写入
  4. 如果private有数据,写入shared(head)
    • shared链表为空,创建链表(初始数据量为8)
    • 写入成功?返回
    • 写入失败?扩容后再次写入(数据量最大不能超过2^30=1GB)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
// src/sync/pool.go
func (p *Pool) Put(x any) {
// x=nil,不操作
if x == nil {
return
}
// 获取pid和poolLocal链表
l, _ := p.pin()
// private字段是空的?直接写入
if l.private == nil {
l.private = x
} else {
// 放置在shared头部
l.shared.pushHead(x)
}
// unpin
runtime_procUnpin()
}

// src/sync/poolqueue.go
func (c *poolChain) pushHead(val any) {
d := c.head
// shared链表是空的
if d == nil {
// 创建一个poolChainElt,关联head和tail
const initSize = 8 // Must be a power of 2
d = new(poolChainElt)
d.vals = make([]eface, initSize)
c.head = d
c.tail.Store(d)
}

// 写入成功
if d.pushHead(val) {
return
}

// 写入失败
// 当前poolChainElt数据满了,双倍扩容
newSize := len(d.vals) * 2
// 最大不能超过2^30 => 1GB
if newSize >= dequeueLimit {
newSize = dequeueLimit
}

d2 := &poolChainElt{}
d2.prev.Store(d)
d2.vals = make([]eface, newSize)
c.head = d2
d.next.Store(d2)
d2.pushHead(val)
}

func (d *poolDequeue) pushHead(val any) bool {
ptrs := d.headTail.Load()
// 从headTail解析
head, tail := d.unpack(ptrs)
// (tail_idx+size)&mask == head
if (tail+uint32(len(d.vals)))&(1<<dequeueBits-1) == head {
// Queue is full.
return false
}
slot := &d.vals[head&uint32(len(d.vals)-1)]

// 目标slot还有数据
// Check if the head slot has been released by popTail.
typ := atomic.LoadPointer(&slot.typ)
if typ != nil {
return false
}

// nil转换
// The head slot is free, so we own it.
if val == nil {
val = dequeueNil(nil)
}
*(*any)(unsafe.Pointer(slot)) = val

d.headTail.Add(1 << dequeueBits)
return true
}

数据过期

每当GC进入STW状态时,清理Pool相关数据

  1. Pool内数据过期:local -> victim & victim -> nil
  2. 全局变量数据过期:allPools -> oldPools & oldPools -> nil
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
func poolCleanup() {
// oldPools、allPools存储的都是指针,可以指向同一个p

// 遍历oldPools,清空victim
for _, p := range oldPools {
p.victim = nil
p.victimSize = 0
}

// 遍历allPools,把local迁移到victim
for _, p := range allPools {
p.victim = p.local
p.victimSize = p.localSize
p.local = nil
p.localSize = 0
}

// allPools的数据迁移到oldPools
oldPools, allPools = allPools, nil
}

func init() {
// 纪录poolCleanup函数,每次GC开始前执行
runtime_registerPoolCleanup(poolCleanup)
}

存储切片注意

当使用sync.Pool存储切片时,sync.Pool会如何处理呢?看下面示例代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
package main

import "sync"

var pool sync.Pool

func init() {
pool = sync.Pool{
New: func() any {
return make([]byte, 4<<10)
},
}
}

func main() {
b := pool.Get().([]byte)

// ...do something with b
_ = b

pool.Put(b) // this is line 21
}

从上述代码看,我们创建了一个可以重复利用的切片/缓存区。打开Escape Analysis-逃逸分析运行看看

1
2
3
4
5
6
7
8
9
10
11
go run -gcflags="-m" hello.go  

# 输出如下
# # command-line-arguments
# ./hello.go:9:8: can inline init.0.func1
# ./hello.go:7:6: can inline init.0
# ...
# ./hello.go:10:15: make([]byte, 4096) escapes to heap
# ./hello.go:10:15: make([]byte, 4096) escapes to heap
# ./hello.go:9:8: func literal escapes to heap
# ./hello.go:21:11: b escapes to heap

当创建一个slice-切片时,我们得到的是一个header,系统判断其不仅局限于New函数,使其逃逸至heap上分配,而b也是一个header,同样,系统也会使其逃逸至heap

重新调整代码,改为一个指向slice的指针,如下

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
package main

import "sync"

var pool sync.Pool

func init() {
pool = sync.Pool{
New: func() any {
b := make([]byte, 4<<10)
return &b
},
}
}

func main() {
bPtr := pool.Get().(*[]byte)
b := *bPtr

// ...do something with b
_ = b

pool.Put(bPtr)
}

再一次运行逃逸分析,得到结果如下

1
2
3
4
5
6
7
8
9
10
11
go run -gcflags="-m" hello.go

# 输出如下

# # command-line-arguments
# ./hello.go:9:8: can inline init.0.func1
# ./hello.go:7:6: can inline init.0
# ...
# ./hello.go:10:4: moved to heap: b
# ./hello.go:10:13: make([]byte, 4096) escapes to heap
# ./hello.go:9:8: func literal escapes to heap

这一次,把原始指针放回Pool就不会发生逃逸现象

参考文档

Let’s dive: a tour of sync.Pool internals
深度分析 Golang sync.Pool 底层原理
Go sync.Pool and the Mechanics Behind It