Skip to content

Commit 900829a

Browse files
committed
netstack: swap may create a new fdbased endpoint
1 parent 994f58c commit 900829a

2 files changed

Lines changed: 47 additions & 5 deletions

File tree

intra/netstack/fdbased.go

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
package netstack
3131

3232
import (
33+
"errors"
3334
"fmt"
3435
"sync/atomic"
3536
"time"
@@ -59,6 +60,8 @@ const errorOnInvalidFD = false
5960
// wrapttl is the time to wait for the dispatcher to wrap up (close a previous FD).
6061
const waitttl = wrapttl
6162

63+
var errNeedsNewEndpoint = errors.New("ns: needs new endpoint")
64+
6265
type FdSwapper interface {
6366
// Swap closes existing FDs; uses new fd.
6467
Swap(fd int) error
@@ -200,7 +203,7 @@ type Options struct {
200203
// Makes fd non-blocking, but does not take ownership of fd, which must remain
201204
// open for the lifetime of the returned endpoint (until after the endpoint has
202205
// stopped being using and Wait returns).
203-
func NewFdbasedInjectableEndpoint(opts *Options) (SeamlessEndpoint, error) {
206+
func newFdbasedInjectableEndpoint(opts *Options) (SeamlessEndpoint, error) {
204207
caps := stack.LinkEndpointCapabilities(0)
205208
if opts.RXChecksumOffload {
206209
caps |= stack.CapabilityRXChecksumOffload
@@ -255,7 +258,7 @@ func NewFdbasedInjectableEndpoint(opts *Options) (SeamlessEndpoint, error) {
255258

256259
e.SetMTU(opts.MTU)
257260

258-
if err := e.Swap(opts.FDs[0]); err != nil {
261+
if err := e.swap(opts.FDs[0], true); err != nil {
259262
return nil, err
260263
}
261264

@@ -319,16 +322,25 @@ func (e *endpoint) Dispose() (err error) {
319322

320323
// Implements Swapper.
321324
func (e *endpoint) Swap(fd int) (err error) {
325+
return e.swap(fd, false)
326+
}
327+
328+
func (e *endpoint) swap(fd int, force bool) (err error) {
322329
e.Lock()
323330
defer e.Unlock()
324331

332+
prevfd := e.fds.Load()
333+
if !force && !prevfd.ok() {
334+
return errNeedsNewEndpoint
335+
} // if prevfd is nil, then we're creating a new endpoint
336+
325337
f, err := newTun(fd) // fd may be invalid (ex: -1)
326338
if err != nil || f == nil {
327339
f = invalidFds // nilaway
328340
err = log.EE("ns: tun: swap: (%d) err: %v / %v; using invalidfd", fd, err)
329341
}
330342

331-
prevfd := e.fds.Swap(f) // commence WritePackets() on fd
343+
e.fds.Store(f) // commence WritePackets() on fd
332344

333345
log.D("ns: tun: swap: fd %s => %d; err? %v", prevfd, fd, err)
334346

intra/netstack/netstack.go

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@ import (
1111
"fmt"
1212
"io"
1313
"net/netip"
14+
"strconv"
1415
"strings"
16+
"sync"
1517
"syscall"
1618

1719
"github.com/celzero/firestack/intra/core"
@@ -35,11 +37,39 @@ const nicfwd = false
3537
// packets will be truncated to snapLen.
3638
const SnapLen uint32 = 2048 // in bytes; some sufficient value
3739

40+
var (
41+
errNoFdSwapper = errors.New("linkFdSwap: no FdSwapper")
42+
)
43+
3844
type linkFdSwap struct {
45+
sync.Mutex
3946
stack.LinkEndpoint
4047
FdSwapper
4148
}
4249

50+
// Swap implements FdSwapper.
51+
func (l *linkFdSwap) Swap(fd int) error {
52+
l.Lock()
53+
defer l.Unlock()
54+
55+
if l.FdSwapper == nil {
56+
return errNoFdSwapper
57+
}
58+
59+
err := l.FdSwapper.Swap(fd)
60+
if errors.Is(err, errNeedsNewEndpoint) {
61+
umtu := uint32(l.MTU())
62+
opt := Options{
63+
FDs: []int{fd},
64+
MTU: umtu,
65+
}
66+
core.Go("linkFdSwap."+strconv.Itoa(fd), l.LinkEndpoint.Close)
67+
l.LinkEndpoint, err = newFdbasedInjectableEndpoint(&opt)
68+
}
69+
70+
return err
71+
}
72+
4373
// ref: github.com/google/gvisor/blob/91f58d2cc/pkg/tcpip/sample/tun_tcp_echo/main.go#L102
4474
func NewEndpoint(dev, mtu int, sink io.WriteCloser) (ep SeamlessEndpoint, err error) {
4575
defer func() {
@@ -55,7 +85,7 @@ func NewEndpoint(dev, mtu int, sink io.WriteCloser) (ep SeamlessEndpoint, err er
5585
MTU: umtu,
5686
}
5787

58-
if ep, err = NewFdbasedInjectableEndpoint(&opt); err != nil {
88+
if ep, err = newFdbasedInjectableEndpoint(&opt); err != nil {
5989
return nil, err
6090
}
6191
// ref: github.com/google/gvisor/blob/aeabb785278/pkg/tcpip/link/sniffer/sniffer.go#L111-L131
@@ -70,7 +100,7 @@ func snoop(ep SeamlessEndpoint, sink io.WriteCloser) (SeamlessEndpoint, error) {
70100
if link, err := NewSnoopyEndpoint(ep, sink); err != nil {
71101
return nil, err
72102
} else {
73-
return linkFdSwap{link, ep}, nil
103+
return &linkFdSwap{sync.Mutex{}, link, ep}, nil
74104
}
75105
}
76106

0 commit comments

Comments
 (0)