Files
wireguard-go/conn/bind_std.go
T

416 lines
9.0 KiB
Go
Raw Normal View History

2019-01-02 01:55:51 +01:00
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */
package conn
2017-08-25 14:53:23 +02:00
import (
"context"
2021-02-09 19:46:57 +01:00
"errors"
2017-08-25 14:53:23 +02:00
"net"
2022-03-16 16:09:48 -07:00
"net/netip"
2023-03-06 15:58:32 -08:00
"runtime"
"strconv"
"sync"
2018-06-11 19:04:38 +02:00
"syscall"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
2017-08-25 14:53:23 +02:00
)
// Compile-time check that *StdNetBind satisfies the Bind interface.
var _ Bind = (*StdNetBind)(nil)
2023-03-06 15:58:32 -08:00
// StdNetBind implements Bind for all platforms. While Windows has its own Bind
// (see bind_windows.go), it may fall back to StdNetBind.
// TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable
// methods for sending and receiving multiple datagrams per-syscall. See the
// proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
2021-02-22 02:01:50 +01:00
type StdNetBind struct {
2023-03-06 15:58:32 -08:00
mu sync.Mutex // protects following fields
ipv4 *net.UDPConn
ipv6 *net.UDPConn
blackhole4 bool
blackhole6 bool
ipv4PC *ipv4.PacketConn // will be nil on non-Linux
ipv6PC *ipv6.PacketConn // will be nil on non-Linux
udpAddrPool sync.Pool // following fields are not guarded by mu
ipv4MsgsPool sync.Pool
ipv6MsgsPool sync.Pool
2017-11-19 00:21:58 +01:00
}
2017-11-17 17:25:45 +01:00
// NewStdNetBind returns a Bind backed by the standard library's UDP sockets.
// The returned bind owns pools of scratch *net.UDPAddr values and of
// IdealBatchSize-long ipv4/ipv6 message batches, which the send and receive
// paths reuse instead of allocating per call.
func NewStdNetBind() Bind {
	bind := &StdNetBind{}

	bind.udpAddrPool.New = func() any {
		// 16 bytes fits an IPv6 address; IPv4 senders reslice it to 4.
		return &net.UDPAddr{IP: make([]byte, 16)}
	}

	bind.ipv4MsgsPool.New = func() any {
		batch := make([]ipv4.Message, IdealBatchSize)
		for j := range batch {
			// One payload buffer slot plus room for source control data.
			batch[j].Buffers = make(net.Buffers, 1)
			batch[j].OOB = make([]byte, srcControlSize)
		}
		return &batch
	}

	bind.ipv6MsgsPool.New = func() any {
		batch := make([]ipv6.Message, IdealBatchSize)
		for j := range batch {
			batch[j].Buffers = make(net.Buffers, 1)
			batch[j].OOB = make([]byte, srcControlSize)
		}
		return &batch
	}

	return bind
}
// StdNetEndpoint is the Endpoint type produced by StdNetBind: a destination
// address/port plus an optional sticky source.
type StdNetEndpoint struct {
	// AddrPort is the endpoint destination.
	netip.AddrPort

	// src is the current sticky source address and interface index, if
	// supported. Its zero value means no source is pinned.
	src struct {
		netip.Addr
		ifidx int32
	}
}
2017-11-19 00:21:58 +01:00
2021-12-09 17:55:50 +01:00
var (
_ Bind = (*StdNetBind)(nil)
_ Endpoint = &StdNetEndpoint{}
2021-12-09 17:55:50 +01:00
)
2021-02-22 02:01:50 +01:00
func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
2021-11-05 01:52:54 +01:00
e, err := netip.ParseAddrPort(s)
if err != nil {
return nil, err
}
return &StdNetEndpoint{
AddrPort: e,
}, nil
2017-11-19 00:21:58 +01:00
}
func (e *StdNetEndpoint) ClearSrc() {
e.src.ifidx = 0
e.src.Addr = netip.Addr{}
2017-11-19 00:21:58 +01:00
}
func (e *StdNetEndpoint) DstIP() netip.Addr {
return e.AddrPort.Addr()
2017-11-19 00:21:58 +01:00
}
// SrcIP reports the sticky source address. It is the zero netip.Addr when no
// source has been recorded.
func (e *StdNetEndpoint) SrcIP() netip.Addr {
	src := e.src
	return src.Addr
}
// SrcIfidx reports the interface index of the sticky source (0 when unset).
func (e *StdNetEndpoint) SrcIfidx() int32 {
	idx := e.src.ifidx
	return idx
}
func (e *StdNetEndpoint) DstToBytes() []byte {
b, _ := e.AddrPort.MarshalBinary()
2022-03-17 22:23:02 -06:00
return b
2017-11-19 00:21:58 +01:00
}
func (e *StdNetEndpoint) DstToString() string {
return e.AddrPort.String()
2017-11-19 00:21:58 +01:00
}
func (e *StdNetEndpoint) SrcToString() string {
return e.src.Addr.String()
2017-11-19 00:21:58 +01:00
}
func listenNet(network string, port int) (*net.UDPConn, int, error) {
conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port))
2017-11-19 00:21:58 +01:00
if err != nil {
return nil, 0, err
}
// Retrieve port.
2017-11-19 00:21:58 +01:00
laddr := conn.LocalAddr()
uaddr, err := net.ResolveUDPAddr(
2017-11-19 00:21:58 +01:00
laddr.Network(),
laddr.String(),
)
if err != nil {
return nil, 0, err
}
return conn.(*net.UDPConn), uaddr.Port, nil
2017-11-19 00:21:58 +01:00
}
func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
s.mu.Lock()
defer s.mu.Unlock()
var err error
var tries int
2017-11-17 17:25:45 +01:00
if s.ipv4 != nil || s.ipv6 != nil {
return nil, 0, ErrBindAlreadyOpen
2021-02-22 02:01:50 +01:00
}
// Attempt to open ipv4 and ipv6 listeners on the same port.
// If uport is 0, we can retry on failure.
again:
port := int(uport)
var v4conn, v6conn *net.UDPConn
2023-03-06 15:58:32 -08:00
var v4pc *ipv4.PacketConn
var v6pc *ipv6.PacketConn
2017-11-17 17:25:45 +01:00
v4conn, port, err = listenNet("udp4", port)
2021-02-09 19:46:57 +01:00
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
return nil, 0, err
2017-11-17 17:25:45 +01:00
}
// Listen on the same port as we're using for ipv4.
v6conn, port, err = listenNet("udp6", port)
if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
v4conn.Close()
tries++
goto again
}
2021-02-09 19:46:57 +01:00
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
v4conn.Close()
return nil, 0, err
}
var fns []ReceiveFunc
if v4conn != nil {
2023-03-06 15:58:32 -08:00
if runtime.GOOS == "linux" {
v4pc = ipv4.NewPacketConn(v4conn)
s.ipv4PC = v4pc
}
fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn))
s.ipv4 = v4conn
2021-02-22 02:01:50 +01:00
}
if v6conn != nil {
2023-03-06 15:58:32 -08:00
if runtime.GOOS == "linux" {
v6pc = ipv6.NewPacketConn(v6conn)
s.ipv6PC = v6pc
}
fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn))
s.ipv6 = v6conn
}
if len(fns) == 0 {
return nil, 0, syscall.EAFNOSUPPORT
}
return fns, uint16(port), nil
}
2023-03-06 15:58:32 -08:00
func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn) ReceiveFunc {
2023-03-13 17:55:05 +01:00
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
2023-03-06 15:58:32 -08:00
msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
defer s.ipv4MsgsPool.Put(msgs)
2023-03-13 17:55:05 +01:00
for i := range bufs {
(*msgs)[i].Buffers[0] = bufs[i]
2023-03-06 15:58:32 -08:00
}
var numMsgs int
if runtime.GOOS == "linux" {
numMsgs, err = pc.ReadBatch(*msgs, 0)
if err != nil {
return 0, err
}
} else {
msg := &(*msgs)[0]
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
if err != nil {
return 0, err
}
numMsgs = 1
}
for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i]
sizes[i] = msg.N
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
getSrcFromControl(msg.OOB[:msg.NN], ep)
2023-03-06 15:58:32 -08:00
eps[i] = ep
}
return numMsgs, nil
}
}
2023-03-06 15:58:32 -08:00
func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn) ReceiveFunc {
2023-03-13 17:55:05 +01:00
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
2023-03-06 15:58:32 -08:00
msgs := s.ipv4MsgsPool.Get().(*[]ipv6.Message)
defer s.ipv4MsgsPool.Put(msgs)
2023-03-13 17:55:05 +01:00
for i := range bufs {
(*msgs)[i].Buffers[0] = bufs[i]
2023-03-06 15:58:32 -08:00
}
var numMsgs int
if runtime.GOOS == "linux" {
numMsgs, err = pc.ReadBatch(*msgs, 0)
if err != nil {
return 0, err
}
} else {
msg := &(*msgs)[0]
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
if err != nil {
return 0, err
}
numMsgs = 1
}
for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i]
sizes[i] = msg.N
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
getSrcFromControl(msg.OOB[:msg.NN], ep)
2023-03-06 15:58:32 -08:00
eps[i] = ep
}
return numMsgs, nil
}
}
// BatchSize reports how many packets a single ReceiveFunc call or Send may
// handle: IdealBatchSize on Linux, where batched socket I/O is used, and 1
// everywhere else.
//
// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
// rename the IdealBatchSize constant to BatchSize.
func (s *StdNetBind) BatchSize() int {
	if runtime.GOOS != "linux" {
		return 1
	}
	return IdealBatchSize
}
func (s *StdNetBind) Close() error {
s.mu.Lock()
defer s.mu.Unlock()
2018-06-11 19:04:38 +02:00
var err1, err2 error
if s.ipv4 != nil {
err1 = s.ipv4.Close()
s.ipv4 = nil
2023-03-06 15:58:32 -08:00
s.ipv4PC = nil
2018-06-11 19:04:38 +02:00
}
if s.ipv6 != nil {
err2 = s.ipv6.Close()
s.ipv6 = nil
2023-03-06 15:58:32 -08:00
s.ipv6PC = nil
2018-06-11 19:04:38 +02:00
}
s.blackhole4 = false
s.blackhole6 = false
if err1 != nil {
return err1
}
return err2
}
2023-03-13 17:55:05 +01:00
func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
s.mu.Lock()
blackhole := s.blackhole4
conn := s.ipv4
2023-03-06 15:58:32 -08:00
var (
pc4 *ipv4.PacketConn
pc6 *ipv6.PacketConn
)
is6 := false
if endpoint.DstIP().Is6() {
blackhole = s.blackhole6
conn = s.ipv6
2023-03-06 15:58:32 -08:00
pc6 = s.ipv6PC
is6 = true
2023-03-06 15:58:32 -08:00
} else {
pc4 = s.ipv4PC
2018-06-11 19:04:38 +02:00
}
s.mu.Unlock()
2021-03-29 13:11:11 -07:00
if blackhole {
return nil
}
if conn == nil {
return syscall.EAFNOSUPPORT
}
if is6 {
2023-03-13 17:55:05 +01:00
return s.send6(conn, pc6, endpoint, bufs)
} else {
2023-03-13 17:55:05 +01:00
return s.send4(conn, pc4, endpoint, bufs)
}
}
2023-03-13 17:55:05 +01:00
func (s *StdNetBind) send4(conn *net.UDPConn, pc *ipv4.PacketConn, ep Endpoint, bufs [][]byte) error {
ua := s.udpAddrPool.Get().(*net.UDPAddr)
as4 := ep.DstIP().As4()
copy(ua.IP, as4[:])
ua.IP = ua.IP[:4]
ua.Port = int(ep.(*StdNetEndpoint).Port())
msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
2023-03-13 17:55:05 +01:00
for i, buf := range bufs {
(*msgs)[i].Buffers[0] = buf
(*msgs)[i].Addr = ua
setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
}
var (
n int
err error
start int
)
2023-03-06 15:58:32 -08:00
if runtime.GOOS == "linux" {
for {
2023-03-13 17:55:05 +01:00
n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0)
if err != nil || n == len((*msgs)[start:len(bufs)]) {
2023-03-06 15:58:32 -08:00
break
}
start += n
}
} else {
2023-03-13 17:55:05 +01:00
for i, buf := range bufs {
_, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua)
2023-03-06 15:58:32 -08:00
if err != nil {
break
}
}
}
s.udpAddrPool.Put(ua)
s.ipv4MsgsPool.Put(msgs)
return err
}
2023-03-13 17:55:05 +01:00
func (s *StdNetBind) send6(conn *net.UDPConn, pc *ipv6.PacketConn, ep Endpoint, bufs [][]byte) error {
ua := s.udpAddrPool.Get().(*net.UDPAddr)
as16 := ep.DstIP().As16()
copy(ua.IP, as16[:])
ua.IP = ua.IP[:16]
ua.Port = int(ep.(*StdNetEndpoint).Port())
msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
2023-03-13 17:55:05 +01:00
for i, buf := range bufs {
(*msgs)[i].Buffers[0] = buf
(*msgs)[i].Addr = ua
setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
}
var (
n int
err error
start int
)
2023-03-06 15:58:32 -08:00
if runtime.GOOS == "linux" {
for {
2023-03-13 17:55:05 +01:00
n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0)
if err != nil || n == len((*msgs)[start:len(bufs)]) {
2023-03-06 15:58:32 -08:00
break
}
start += n
}
} else {
2023-03-13 17:55:05 +01:00
for i, buf := range bufs {
_, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua)
2023-03-06 15:58:32 -08:00
if err != nil {
break
}
}
}
s.udpAddrPool.Put(ua)
s.ipv6MsgsPool.Put(msgs)
return err
}