Files
wireguard-go/ratelimiter/ratelimiter.go
T

173 lines
3.3 KiB
Go
Raw Normal View History

2019-01-02 01:55:51 +01:00
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
 */
2018-02-12 22:29:11 +01:00
package ratelimiter
2017-07-11 18:48:29 +02:00
import (
"net"
"sync"
"time"
)
// Base token-bucket parameters. Token counts are measured in nanoseconds of
// elapsed time (Allow credits one token per nanosecond since the last packet).
const (
	packetsPerSecond   = 20          // sustained packets per second per source address
	packetsBurstable   = 5           // bucket capacity, expressed in packets
	garbageCollectTime = time.Second // cleanup drops entries idle for longer than this
)

// Derived quantities, left untyped so they combine freely with int64 token counts.
const (
	packetCost = 1000000000 / packetsPerSecond // nanoseconds of credit one packet consumes
	maxTokens  = packetCost * packetsBurstable // bucket ceiling
)
type RatelimiterEntry struct {
mu sync.Mutex
2017-07-11 18:48:29 +02:00
lastTime time.Time
tokens int64
}
type Ratelimiter struct {
mu sync.RWMutex
timeNow func() time.Time
stopReset chan struct{} // send to reset, close to stop
2018-02-11 22:53:39 +01:00
tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry
tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry
}
func (rate *Ratelimiter) Close() {
rate.mu.Lock()
defer rate.mu.Unlock()
2018-02-11 22:53:39 +01:00
if rate.stopReset != nil {
close(rate.stopReset)
2018-02-11 22:53:39 +01:00
}
2017-07-11 18:48:29 +02:00
}
func (rate *Ratelimiter) Init() {
rate.mu.Lock()
defer rate.mu.Unlock()
if rate.timeNow == nil {
rate.timeNow = time.Now
}
2018-02-11 22:53:39 +01:00
// stop any ongoing garbage collection routine
if rate.stopReset != nil {
close(rate.stopReset)
2018-02-11 22:53:39 +01:00
}
rate.stopReset = make(chan struct{})
2017-07-11 18:48:29 +02:00
rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry)
rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry)
2018-02-11 22:53:39 +01:00
stopReset := rate.stopReset // store in case Init is called again.
// Start garbage collection routine.
2018-02-11 22:53:39 +01:00
go func() {
2018-05-13 18:42:06 +02:00
ticker := time.NewTicker(time.Second)
ticker.Stop()
2018-02-11 22:53:39 +01:00
for {
select {
case _, ok := <-stopReset:
2018-05-13 18:42:06 +02:00
ticker.Stop()
if !ok {
return
}
ticker = time.NewTicker(time.Second)
2018-05-13 18:42:06 +02:00
case <-ticker.C:
if rate.cleanup() {
ticker.Stop()
}
2018-02-11 22:53:39 +01:00
}
}
}()
2017-07-11 18:48:29 +02:00
}
// cleanup deletes every entry whose lastTime is more than garbageCollectTime
// before the current time (per rate.timeNow). It reports whether both tables
// are empty afterwards.
func (rate *Ratelimiter) cleanup() (empty bool) {
	rate.mu.Lock()
	defer rate.mu.Unlock()

	// Read the clock once per sweep. It is loop-invariant, and the whole sweep
	// runs under the write lock that blocks Allow's read-locked lookups, so
	// per-entry clock reads only lengthen that critical section. A single
	// reading also gives every entry the same cutoff.
	now := rate.timeNow()

	for key, entry := range rate.tableIPv4 {
		entry.mu.Lock()
		if now.Sub(entry.lastTime) > garbageCollectTime {
			delete(rate.tableIPv4, key) // deleting during range is permitted
		}
		entry.mu.Unlock()
	}
	for key, entry := range rate.tableIPv6 {
		entry.mu.Lock()
		if now.Sub(entry.lastTime) > garbageCollectTime {
			delete(rate.tableIPv6, key)
		}
		entry.mu.Unlock()
	}
	return len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0
}
2017-07-11 18:48:29 +02:00
func (rate *Ratelimiter) Allow(ip net.IP) bool {
var entry *RatelimiterEntry
2018-05-13 18:42:06 +02:00
var keyIPv4 [net.IPv4len]byte
var keyIPv6 [net.IPv6len]byte
2017-07-11 18:48:29 +02:00
// lookup entry
IPv4 := ip.To4()
IPv6 := ip.To16()
rate.mu.RLock()
2017-07-11 18:48:29 +02:00
if IPv4 != nil {
2018-05-13 18:42:06 +02:00
copy(keyIPv4[:], IPv4)
entry = rate.tableIPv4[keyIPv4]
2017-07-11 18:48:29 +02:00
} else {
2018-05-13 18:42:06 +02:00
copy(keyIPv6[:], IPv6)
entry = rate.tableIPv6[keyIPv6]
2017-07-11 18:48:29 +02:00
}
rate.mu.RUnlock()
2017-07-11 18:48:29 +02:00
// make new entry if not found
if entry == nil {
entry = new(RatelimiterEntry)
entry.tokens = maxTokens - packetCost
entry.lastTime = rate.timeNow()
rate.mu.Lock()
2017-07-11 18:48:29 +02:00
if IPv4 != nil {
2018-05-13 18:42:06 +02:00
rate.tableIPv4[keyIPv4] = entry
if len(rate.tableIPv4) == 1 && len(rate.tableIPv6) == 0 {
rate.stopReset <- struct{}{}
}
2017-07-11 18:48:29 +02:00
} else {
2018-05-13 18:42:06 +02:00
rate.tableIPv6[keyIPv6] = entry
if len(rate.tableIPv6) == 1 && len(rate.tableIPv4) == 0 {
rate.stopReset <- struct{}{}
}
2017-07-11 18:48:29 +02:00
}
rate.mu.Unlock()
2017-07-11 18:48:29 +02:00
return true
}
// add tokens to entry
entry.mu.Lock()
now := rate.timeNow()
2017-07-11 18:48:29 +02:00
entry.tokens += now.Sub(entry.lastTime).Nanoseconds()
entry.lastTime = now
if entry.tokens > maxTokens {
entry.tokens = maxTokens
2017-07-11 18:48:29 +02:00
}
// subtract cost of packet
if entry.tokens > packetCost {
entry.tokens -= packetCost
entry.mu.Unlock()
2017-07-11 18:48:29 +02:00
return true
}
entry.mu.Unlock()
2017-07-11 18:48:29 +02:00
return false
}