diff --git a/internal/dialer/dialer.go b/internal/dialer/dialer.go index 5a85512..4982980 100644 --- a/internal/dialer/dialer.go +++ b/internal/dialer/dialer.go @@ -16,7 +16,8 @@ var ErrForbiddenRequest = errors.New("forbidden") // Dialer is a wrapper around net.Dialer that uses a dnscache.Resolver to cache DNS lookups. type Dialer struct { net.Dialer - resolver *dnscache.Resolver + resolver *dnscache.Resolver + blockedIPs []net.IP } // New creates a new Dialer. @@ -25,7 +26,8 @@ func New(resolver *dnscache.Resolver, blockedIps []net.IP) *Dialer { Dialer: net.Dialer{ Control: safeControl(blockedIps), }, - resolver: resolver, + resolver: resolver, + blockedIPs: blockedIps, } } @@ -47,6 +49,15 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (conn return nil, err } for _, ip := range ips { + // Check against the blocked list before attempting socket creation. + // safeControl fires after the socket is opened, which can fail first on + // systems that don't support a particular address family (e.g. IPv6). + parsed := net.ParseIP(ip) + for _, blocked := range d.blockedIPs { + if parsed != nil && parsed.Equal(blocked) { + return nil, ErrForbiddenRequest + } + } conn, err = d.Dialer.DialContext(ctx, network, net.JoinHostPort(ip, port)) if err == nil { break