|
30 | 30 | package netstack |
31 | 31 |
|
32 | 32 | import ( |
| 33 | + "errors" |
33 | 34 | "fmt" |
34 | 35 | "sync/atomic" |
35 | 36 | "time" |
@@ -59,6 +60,8 @@ const errorOnInvalidFD = false |
59 | 60 | // wrapttl is the time to wait for the dispatcher to wrap up (close a previous FD). |
60 | 61 | const waitttl = wrapttl |
61 | 62 |
|
| 63 | +var errNeedsNewEndpoint = errors.New("ns: needs new endpoint") |
| 64 | + |
62 | 65 | type FdSwapper interface { |
63 | 66 | // Swap closes existing FDs; uses new fd. |
64 | 67 | Swap(fd int) error |
@@ -200,7 +203,7 @@ type Options struct { |
200 | 203 | // Makes fd non-blocking, but does not take ownership of fd, which must remain |
201 | 204 | // open for the lifetime of the returned endpoint (until after the endpoint has |
202 | 205 | // stopped being using and Wait returns). |
203 | | -func NewFdbasedInjectableEndpoint(opts *Options) (SeamlessEndpoint, error) { |
| 206 | +func newFdbasedInjectableEndpoint(opts *Options) (SeamlessEndpoint, error) { |
204 | 207 | caps := stack.LinkEndpointCapabilities(0) |
205 | 208 | if opts.RXChecksumOffload { |
206 | 209 | caps |= stack.CapabilityRXChecksumOffload |
@@ -255,7 +258,7 @@ func NewFdbasedInjectableEndpoint(opts *Options) (SeamlessEndpoint, error) { |
255 | 258 |
|
256 | 259 | e.SetMTU(opts.MTU) |
257 | 260 |
|
258 | | - if err := e.Swap(opts.FDs[0]); err != nil { |
| 261 | + if err := e.swap(opts.FDs[0], true); err != nil { |
259 | 262 | return nil, err |
260 | 263 | } |
261 | 264 |
|
@@ -319,16 +322,25 @@ func (e *endpoint) Dispose() (err error) { |
319 | 322 |
|
320 | 323 | // Implements Swapper. |
321 | 324 | func (e *endpoint) Swap(fd int) (err error) { |
| 325 | + return e.swap(fd, false) |
| 326 | +} |
| 327 | + |
| 328 | +func (e *endpoint) swap(fd int, force bool) (err error) { |
322 | 329 | e.Lock() |
323 | 330 | defer e.Unlock() |
324 | 331 |
|
| 332 | + prevfd := e.fds.Load() |
| 333 | + if !force && !prevfd.ok() { |
| 334 | + return errNeedsNewEndpoint |
| 335 | + } // if prevfd is nil, then we're creating a new endpoint |
| 336 | + |
325 | 337 | f, err := newTun(fd) // fd may be invalid (ex: -1) |
326 | 338 | if err != nil || f == nil { |
327 | 339 | f = invalidFds // nilaway |
328 | 340 | err = log.EE("ns: tun: swap: (%d) err: %v / %v; using invalidfd", fd, err) |
329 | 341 | } |
330 | 342 |
|
331 | | - prevfd := e.fds.Swap(f) // commence WritePackets() on fd |
| 343 | + e.fds.Store(f) // commence WritePackets() on fd |
332 | 344 |
|
333 | 345 | log.D("ns: tun: swap: fd %s => %d; err? %v", prevfd, fd, err) |
334 | 346 |
|
|
0 commit comments