Skip to content
17 changes: 15 additions & 2 deletions cmd/gau/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"io"
"os"
"sync"
"time"

"github.com/lc/gau/v2/pkg/output"
"github.com/lc/gau/v2/runner"
Expand All @@ -14,6 +15,8 @@ import (
)

func main() {
startTime := time.Now()

cfg, err := flags.New().ReadInConfig()
if err != nil {
log.Warnf("error reading config: %v", err)
Expand Down Expand Up @@ -43,12 +46,13 @@ func main() {
}

var writeWg sync.WaitGroup
var urlCount int64
writeWg.Add(1)
go func(out io.Writer, JSON bool) {
defer writeWg.Done()
if JSON {
output.WriteURLsJSON(out, results, config.Blacklist, config.RemoveParameters)
} else if err = output.WriteURLs(out, results, config.Blacklist, config.RemoveParameters); err != nil {
output.WriteURLsJSON(out, results, config.Blacklist, config.RemoveParameters, &urlCount)
} else if err = output.WriteURLs(out, results, config.Blacklist, config.RemoveParameters, &urlCount); err != nil {
log.Fatalf("error writing results: %v\n", err)
}
}(out, config.JSON)
Expand Down Expand Up @@ -85,4 +89,13 @@ func main() {

// wait for writer to finish output
writeWg.Wait()

// Calculate duration
duration := time.Since(startTime)

// Log summary
log.Infof("=== Gau Execution Summary ===")
log.Infof("Total URLs: %d", urlCount)
log.Infof("Duration: %v", duration)
log.Infof("=============================")
}
90 changes: 87 additions & 3 deletions pkg/httpclient/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ package httpclient

import (
"errors"
"fmt"
"math"
"math/rand"
"strings"
"time"

"github.com/valyala/fasthttp"
Expand All @@ -12,8 +15,23 @@ var (
ErrNilResponse = errors.New("unexpected nil response")
ErrNon200Response = errors.New("API responded with non-200 status code")
ErrBadRequest = errors.New("API responded with 400 status code")
ErrRateLimited = errors.New("API rate limited")
)

// StatusCodeError is an error type that carries an HTTP status code
type StatusCodeError struct {
Code int
Msg string
}

func (e *StatusCodeError) Error() string {
return fmt.Sprintf("%s (status code: %d)", e.Msg, e.Code)
}

func (e *StatusCodeError) Unwrap() error {
return errors.New(e.Msg)
}

type Header struct {
Key string
Value string
Expand All @@ -39,6 +57,28 @@ func MakeRequest(c *fasthttp.Client, url string, maxRetries uint, timeout uint,
req.Header.Set("Accept", "*/*")
req.SetRequestURI(url)
respBody, err = doReq(c, req, timeout)

// Check if we should retry based on error type
if err != nil {
// Exponential backoff: 1s, 2s, 4s, 8s, 16s... with cap at 30s
backoffDuration := time.Duration(math.Pow(2, float64(retries-i))) * time.Second
if backoffDuration > 30*time.Second {
backoffDuration = 30 * time.Second
}
if i > 0 && shouldRetry(err) {
time.Sleep(backoffDuration)
continue
}
}

// Check for rate limit (429) or bad request (400) from error
if err != nil {
statusCode := getStatusCodeFromError(err)
if statusCode == 429 || statusCode == 400 {
return nil, ErrRateLimited
}
}

if err == nil {
break
}
Expand All @@ -49,6 +89,48 @@ func MakeRequest(c *fasthttp.Client, url string, maxRetries uint, timeout uint,
return respBody, nil
}

// shouldRetry determines if an error should trigger a retry
func shouldRetry(err error) bool {
if err == nil {
return false
}
// Network errors that should trigger retry
errMsg := err.Error()
retryableErrors := []string{
"connection refused",
"connection reset",
"connection timed out",
"no such host",
"timeout",
"server closed connection",
"network is unreachable",
"i/o timeout",
}
for _, pattern := range retryableErrors {
if containsIgnoreCase(errMsg, pattern) {
return true
}
}
return false
}

// containsIgnoreCase checks if s contains substr (case-insensitive)
func containsIgnoreCase(s, substr string) bool {
return strings.Contains(strings.ToLower(s), strings.ToLower(substr))
}

// getStatusCodeFromError attempts to extract status code from error
func getStatusCodeFromError(err error) int {
if err == nil {
return 0
}
var statusErr *StatusCodeError
if errors.As(err, &statusErr) {
return statusErr.Code
}
return 0
}

// doReq handles http requests
func doReq(c *fasthttp.Client, req *fasthttp.Request, timeout uint) ([]byte, error) {
resp := fasthttp.AcquireResponse()
Expand All @@ -58,10 +140,12 @@ func doReq(c *fasthttp.Client, req *fasthttp.Request, timeout uint) ([]byte, err
return nil, err
}
if resp.StatusCode() != 200 {
if resp.StatusCode() == 400 {
return nil, ErrBadRequest
errMsg := fmt.Sprintf("API responded with status code %d", resp.StatusCode())
// Return wrapped error with status code for proper handling
return nil, &StatusCodeError{
Code: resp.StatusCode(),
Msg: errMsg,
}
return nil, ErrNon200Response
}
if resp.Body() == nil {
return nil, ErrNilResponse
Expand Down
16 changes: 14 additions & 2 deletions pkg/output/output.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ package output
import (
"io"
"net/url"
"os"
"path"
"strings"
"sync/atomic"

mapset "github.com/deckarep/golang-set/v2"
jsoniter "github.com/json-iterator/go"
Expand All @@ -15,7 +17,7 @@ type JSONResult struct {
Url string `json:"url"`
}

func WriteURLs(writer io.Writer, results <-chan string, blacklistMap mapset.Set[string], RemoveParameters bool) error {
func WriteURLs(writer io.Writer, results <-chan string, blacklistMap mapset.Set[string], RemoveParameters bool, urlCount *int64) error {
lastURL := mapset.NewThreadUnsafeSet[string]()
for result := range results {
buf := bytebufferpool.Get()
Expand All @@ -38,12 +40,17 @@ func WriteURLs(writer io.Writer, results <-chan string, blacklistMap mapset.Set[
if err != nil {
return err
}
atomic.AddInt64(urlCount, 1)
// Real-time flush: sync stdout after each write to prevent data loss
if writer == os.Stdout {
os.Stdout.Sync()
}
bytebufferpool.Put(buf)
}
return nil
}

func WriteURLsJSON(writer io.Writer, results <-chan string, blacklistMap mapset.Set[string], RemoveParameters bool) {
func WriteURLsJSON(writer io.Writer, results <-chan string, blacklistMap mapset.Set[string], RemoveParameters bool, urlCount *int64) {
var jr JSONResult
enc := jsoniter.NewEncoder(writer)
for result := range results {
Expand All @@ -59,5 +66,10 @@ func WriteURLsJSON(writer io.Writer, results <-chan string, blacklistMap mapset.
// todo: handle this error
continue
}
atomic.AddInt64(urlCount, 1)
// Real-time flush: sync stdout after each write to prevent data loss
if writer == os.Stdout {
os.Stdout.Sync()
}
}
}
126 changes: 105 additions & 21 deletions pkg/providers/commoncrawl/commoncrawl.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"context"
"errors"
"fmt"
"sync"

jsoniter "github.com/json-iterator/go"
"github.com/lc/gau/v2/pkg/httpclient"
Expand Down Expand Up @@ -56,6 +57,11 @@ func (c *Client) Name() string {
func (c *Client) Fetch(ctx context.Context, domain string, results chan string) error {
p, err := c.getPagination(domain)
if err != nil {
logrus.WithFields(logrus.Fields{
"provider": Name,
"domain": domain,
"error": err.Error(),
}).Warn("failed to get pagination for commoncrawl")
return err
}
// 0 pages means no results
Expand All @@ -64,33 +70,111 @@ func (c *Client) Fetch(ctx context.Context, domain string, results chan string)
return nil
}

for page := uint(0); page < p.Pages; page++ {
select {
case <-ctx.Done():
return nil
default:
logrus.WithFields(logrus.Fields{"provider": Name, "page": page}).Infof("fetching %s", domain)
apiURL := c.formatURL(domain, page)
resp, err := httpclient.MakeRequest(c.config.Client, apiURL, c.config.MaxRetries, c.config.Timeout)
if err != nil {
return fmt.Errorf("failed to fetch commoncrawl(%d): %s", page, err)
}
numThreads := c.config.ProviderThreads
if numThreads == 0 {
numThreads = 3
}

sc := bufio.NewScanner(bytes.NewReader(resp))
for sc.Scan() {
var res apiResponse
if err := jsoniter.Unmarshal(sc.Bytes(), &res); err != nil {
return fmt.Errorf("failed to decode commoncrawl result: %s", err)
// Cap threads to actual page count
if numThreads > p.Pages {
numThreads = p.Pages
}

pageChan := make(chan uint, numThreads)
var wg sync.WaitGroup
var fetchErr error
var errMu sync.Mutex

// Page dispatcher: send page numbers
go func() {
defer close(pageChan)
for page := uint(0); page < p.Pages; page++ {
select {
case <-ctx.Done():
return
case pageChan <- page:
}
}
}()

// Workers: fetch pages from pageChan
for i := uint(0); i < numThreads; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for page := range pageChan {
select {
case <-ctx.Done():
return
default:
}
if res.Error != "" {
return fmt.Errorf("received an error from commoncrawl: %s", res.Error)
logrus.WithFields(logrus.Fields{"provider": Name, "page": page}).Infof("fetching %s", domain)
apiURL := c.formatURL(domain, page)
resp, err := httpclient.MakeRequest(c.config.Client, apiURL, c.config.MaxRetries, c.config.Timeout)
if err != nil {
var statusErr *httpclient.StatusCodeError
if errors.As(err, &statusErr) {
logrus.WithFields(logrus.Fields{
"provider": Name,
"domain": domain,
"page": page,
"status": statusErr.Code,
"error": statusErr.Error(),
}).Warn("CommonCrawl HTTP error")
} else {
logrus.WithFields(logrus.Fields{
"provider": Name,
"domain": domain,
"page": page,
"error": err.Error(),
}).Warn("failed to fetch commoncrawl")
}
errMu.Lock()
if fetchErr == nil {
fetchErr = fmt.Errorf("failed to fetch commoncrawl(%d): %s", page, err)
}
errMu.Unlock()
continue
}

results <- res.URL
sc := bufio.NewScanner(bytes.NewReader(resp))
for sc.Scan() {
var res apiResponse
if err := jsoniter.Unmarshal(sc.Bytes(), &res); err != nil {
errMu.Lock()
if fetchErr == nil {
fetchErr = fmt.Errorf("failed to decode commoncrawl result: %s", err)
}
errMu.Unlock()
continue
}
if res.Error != "" {
logrus.WithFields(logrus.Fields{
"provider": Name,
"domain": domain,
"page": page,
"response": res.Error,
}).Warn("CommonCrawl API error")
errMu.Lock()
if fetchErr == nil {
fetchErr = fmt.Errorf("received an error from commoncrawl: %s", res.Error)
}
errMu.Unlock()
continue
}

select {
case <-ctx.Done():
return
case results <- res.URL:
}
}
}
}
}()
}
return nil

wg.Wait()
return fetchErr
}

func (c *Client) formatURL(domain string, page uint) string {
Expand Down
Loading