Skip to content

Commit 744f849

Browse files
committed
ipn/proxy: run reachability checks concurrently
1 parent 2a786f2 commit 744f849

1 file changed

Lines changed: 90 additions & 66 deletions

File tree

intra/ipn/proxy.go

Lines changed: 90 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -256,28 +256,25 @@ func Reaches(p Proxy, urlOrHostPortOrIPPortCsv string, protos ...string) bool {
256256
return true
257257
}
258258

259-
if urls := httpURLs(urlOrHostPortOrIPPortCsv); len(urls) > 0 {
260-
// For URLs, only test HTTPS connectivity
259+
pid := idstr(p)
260+
hostportOrIPPort := strings.Split(urlOrHostPortOrIPPortCsv, ",")
261+
if urls, oth := extractHttpURLs(urlOrHostPortOrIPPortCsv); len(urls) > 0 {
262+
log.V("proxy: %s reaches: testing for %v", idstr(p), urls)
263+
264+
hostportOrIPPort = oth
261265
tests := make([]core.WorkCtx[bool], 0)
262266
for _, u := range urls {
263267
tests = append(tests, httpsReachesWorkCtx(p, u))
264268
}
265269

266-
pid := idstr(p)
267-
if len(tests) <= 0 {
268-
log.W("proxy: %s reaches: %v; no HTTPS tests", pid, urlOrHostPortOrIPPortCsv)
269-
return false
270-
}
271-
272-
okays, errs := core.All("reach."+pid, 5*time.Second, tests...)
273-
274-
overall := core.IsAll(errs, func(err error) bool { return err == nil }) &&
275-
core.IsAll(okays, func(ok bool) bool { return ok })
270+
ok, who, errs := core.Race("reach.http."+pid, 5*time.Second, tests...)
276271

277-
logeif(overall)("proxy: %s reaches: %v verdict (https): reachable? %t [oks? %v; errs? %v]",
278-
pid, urlOrHostPortOrIPPortCsv, overall, okays, errs)
272+
logeif(!ok)("proxy: %s #%d reaches: %v verdict (https): reachable? %t; errs? %v",
273+
pid, who, urlOrHostPortOrIPPortCsv, ok, errs)
279274

280-
return overall
275+
if !ok || len(oth) <= 0 {
276+
return ok
277+
}
281278
}
282279

283280
// Original logic for host:port or ip:port
@@ -296,7 +293,7 @@ func Reaches(p Proxy, urlOrHostPortOrIPPortCsv string, protos ...string) bool {
296293
// upstream = pdns
297294
// }
298295
ipps := make([]netip.AddrPort, 0)
299-
for x := range strings.SplitSeq(urlOrHostPortOrIPPortCsv, ",") {
296+
for _, x := range hostportOrIPPort {
300297
host, port, err := net.SplitHostPort(x)
301298
if err != nil {
302299
port = "80"
@@ -315,38 +312,53 @@ func Reaches(p Proxy, urlOrHostPortOrIPPortCsv string, protos ...string) bool {
315312
}
316313
}
317314
}
318-
log.V("proxy: %s reaches: testing for %s", p.ID(), ipps)
319-
tests := make([]core.WorkCtx[bool], 0)
315+
316+
n := 0
317+
log.V("proxy: %s reaches: testing for %s", pid, ipps)
318+
tests := make([][]core.WorkCtx[bool], 0)
320319
for _, ipp := range ipps {
320+
fns := make([]core.WorkCtx[bool], 0)
321321
ippstr := ipp.String()
322322
if hastcp {
323-
tests = append(tests, tcpReachesWorkCtx(p, ippstr))
323+
fns = append(fns, tcpReachesWorkCtx(p, ippstr))
324324
}
325325
if hasudp {
326-
tests = append(tests, udpReachesWorkCtx(p, ippstr))
326+
fns = append(fns, udpReachesWorkCtx(p, ippstr))
327327
}
328328
if hasicmp {
329-
tests = append(tests, icmpReachesWorkCtx(p, ipp))
329+
fns = append(fns, icmpReachesWorkCtx(p, ipp))
330330
}
331+
tests = append(tests, fns)
332+
n += len(fns)
331333
}
332334

333-
pid := idstr(p)
334-
if len(tests) <= 0 {
335+
if n <= 0 {
335336
log.W("proxy: %s reaches: %v / %v; no tests for %s",
336337
pid, urlOrHostPortOrIPPortCsv, ipps, protos)
337338
return false
338339
}
339340

340-
okays, errs := core.All("reach."+pid, getproxytimeout, tests...)
341+
ok, who, errs := core.Race("reach"+"."+pid, getproxytimeout, every(pid, tests)...)
341342

342-
// overall is false if any okays is false, or if all errs are not nil
343-
overall := core.IsAll(errs, func(err error) bool { return err == nil }) &&
344-
core.IsAll(okays, func(ok bool) bool { return ok })
343+
logeif(!ok)("proxy: %s #%d reaches: %v => %v verdict (%s): reachable? %t; errs? %v",
344+
pid, who, urlOrHostPortOrIPPortCsv, ipps, protos, ok, errs)
345345

346-
logeif(overall)("proxy: %s reaches: %v => %v verdict (%s): reachable? %t [oks? %v; errs? %v]",
347-
pid, urlOrHostPortOrIPPortCsv, ipps, protos, overall, okays, errs)
346+
return ok
347+
}
348348

349-
return overall
349+
func every(who string, tests [][]core.WorkCtx[bool]) []core.WorkCtx[bool] {
350+
all := make([]core.WorkCtx[bool], 0, len(tests))
351+
for _, t := range tests {
352+
t := t
353+
all = append(all, func(ctx context.Context) (bool, error) {
354+
okays, errs := core.All("reach.all."+who, getproxytimeout, t...)
355+
// overall is false if any okays is false, or if all errs are not nil
356+
overall := core.IsAll(errs, func(err error) bool { return err == nil }) &&
357+
core.IsAll(okays, func(ok bool) bool { return ok })
358+
return overall, core.JoinErr(errs...)
359+
})
360+
}
361+
return all
350362
}
351363

352364
func AnyAddrForUDP(ipp netip.AddrPort) (proto, anyaddr string) {
@@ -588,64 +600,76 @@ func addElem[T comparable](s []T, add T) []T {
588600
return core.WithElem(s, add)
589601
}
590602

591-
// httpURLs extracts valid URLs from comma-separated input
592-
func httpURLs(input string) (urls []*url.URL) {
593-
for x := range strings.SplitSeq(input, ",") {
603+
// extractHttpURLs extracts valid URLs from comma-separated input
604+
func extractHttpURLs(csv string) (urls []*url.URL, oth []string) {
605+
for x := range strings.SplitSeq(csv, ",") {
594606
x = strings.TrimSpace(x)
595607
if len(x) == 0 {
596608
continue
597609
}
598610
// Check if it's a URL (contains scheme)
599611
if u, err := url.Parse(x); err == nil && strings.Contains(u.Scheme, "http") {
600612
urls = append(urls, u)
613+
} else {
614+
oth = append(oth, x)
601615
}
602616
}
603-
return urls
617+
return
604618
}
605619

606620
func httpsReachesWorkCtx(p Proxy, url *url.URL) core.WorkCtx[bool] {
607621
return func(ctx context.Context) (bool, error) {
608-
return httpsReaches(p, url)
622+
requestednetwork := url.Fragment
623+
switch requestednetwork {
624+
case "tcp", "tcp4", "tcp6":
625+
case "udp", "udp4", "udp6":
626+
case "v4", "ipv4":
627+
requestednetwork = "tcp4" // default to tcp4 for v4
628+
case "v6", "ipv6":
629+
requestednetwork = "tcp6" // default to tcp6 for v6
630+
default:
631+
requestednetwork = "tcp" // default to tcp for any other case
632+
}
633+
// Lightweight transport for one-time use
634+
client := &http.Client{
635+
Timeout: 5 * time.Second,
636+
Transport: &http.Transport{
637+
Dial: func(network, addr string) (net.Conn, error) {
638+
if _, err := netip.ParseAddrPort(addr); err != nil {
639+
// addr is likely host:port
640+
network = requestednetwork
641+
}
642+
log.VV("proxy: %s reaches: dial(%s, %s) for %s",
643+
idstr(p), network, addr, url)
644+
return p.Dial(network, addr)
645+
},
646+
// Disable connection pooling for one-time use
647+
DisableKeepAlives: true,
648+
MaxIdleConns: -1,
649+
MaxIdleConnsPerHost: -1,
650+
// Disable compression to reduce overhead
651+
DisableCompression: true,
652+
// Short timeouts for quick failure detection
653+
ResponseHeaderTimeout: 3 * time.Second,
654+
// Prefer h1 to simplify conn handling
655+
ForceAttemptHTTP2: false,
656+
TLSHandshakeTimeout: 3 * time.Second,
657+
},
658+
}
659+
return httpsReaches(idstr(p), client, url)
609660
}
610661
}
611662

612-
func httpsReaches(p Proxy, url *url.URL) (bool, error) {
663+
func httpsReaches(who string, c *http.Client, url *url.URL) (bool, error) {
613664
start := time.Now()
614665

615-
requestednetwork := url.Fragment
616-
switch requestednetwork {
617-
case "tcp", "tcp4", "tcp6":
618-
case "udp", "udp4", "udp6":
619-
case "v4", "ipv4":
620-
requestednetwork = "tcp4" // default to tcp4 for v4
621-
case "v6", "ipv6":
622-
requestednetwork = "tcp6" // default to tcp6 for v6
623-
default:
624-
requestednetwork = "tcp" // default to tcp for any other case
625-
}
626-
627-
// TODO: share http.Transport across checks
628-
client := &http.Client{
629-
Timeout: 5 * time.Second,
630-
Transport: &http.Transport{
631-
Dial: func(network, addr string) (net.Conn, error) {
632-
if _, err := netip.ParseAddrPort(addr); err != nil {
633-
// addr is likely host:port
634-
network = requestednetwork
635-
}
636-
log.VV("proxy: %s reaches: dial(%s, %s) for %s", idstr(p), network, addr, url)
637-
return p.Dial(network, addr)
638-
},
639-
},
640-
}
641-
642666
req, err := http.NewRequest("HEAD", url.String(), nil)
643667
if err != nil {
644668
return false, fmt.Errorf("proxy: reaches: err creating req: %w", err)
645669
}
646670
req.Header.Set("User-Agent", "intra")
647671

648-
resp, err := client.Do(req)
672+
resp, err := c.Do(req)
649673
if resp != nil {
650674
defer core.Close(resp.Body)
651675
}
@@ -658,8 +682,8 @@ func httpsReaches(p Proxy, url *url.URL) (bool, error) {
658682

659683
ok := err == nil && statuscode > 0 && statuscode < 500
660684

661-
logeif(!ok)("proxy: %s reaches: %v (%s); ok? %t, status: %d, rtt: %s; err: %v",
662-
idstr(p), url, requestednetwork, ok, statuscode, core.FmtPeriod(rtt), err)
685+
logeif(!ok)("proxy: %s reaches: %v; ok? %t, status: %d, rtt: %s; err: %v",
686+
who, url, ok, statuscode, core.FmtPeriod(rtt), err)
663687

664688
if ok {
665689
err = nil // wipe out err as it makes core.Race discard "ok"

0 commit comments

Comments
 (0)