diff --git a/.gitignore b/.gitignore index 09b97d5..91eefcd 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ vendor dist out/ cache/ +/challenger diff --git a/cmd/challenger/main.go b/cmd/challenger/main.go index 5ece36b..55922b6 100644 --- a/cmd/challenger/main.go +++ b/cmd/challenger/main.go @@ -19,10 +19,10 @@ import ( "context" _ "embed" "fmt" - "math/big" "net/http" "os" "os/signal" + "strings" "sync" challenger "github.com/chronicleprotocol/challenger/core" @@ -53,15 +53,10 @@ type options struct { FromBlock int64 ChainID uint64 TransactionType string + MetricsAddr string + LogLevel string } -var ( - maxGasLimit = uint64(0) - maxGasFee = (*big.Int)(nil) - maxGasPriorityFee = (*big.Int)(nil) - gasFeeMultiplier = float64(1) - gasPriorityFeeMultiplier = float64(1) -) // Checks and return private key based on given options func (o *options) getKey() (*wallet.PrivateKey, error) { @@ -92,7 +87,7 @@ func (o *options) getKey() (*wallet.PrivateKey, error) { if err != nil { return nil, fmt.Errorf("failed to read password file: %v", err) } - password = string(p) + password = strings.TrimRight(string(p), "\n\r") } return wallet.NewKeyFromJSON(o.Key, password) @@ -102,11 +97,14 @@ func main() { var opts options cmd := &cobra.Command{ Use: "run", - Args: cobra.ExactArgs(1), + Args: cobra.NoArgs, Aliases: []string{"agent"}, Run: func(cmd *cobra.Command, args []string) { - // TODO: update after completion - logger.SetLevel(logger.DebugLevel) + lvl, err := logger.ParseLevel(opts.LogLevel) + if err != nil { + logger.Fatalf("Invalid log level %q: %v", opts.LogLevel, err) + } + logger.SetLevel(lvl) logger.Debugf("Hello, Challenger!") @@ -164,19 +162,19 @@ func main() { switch opts.TransactionType { case "legacy": txModifiers = append(txModifiers, txmodifier.NewLegacyGasFeeEstimator(txmodifier.LegacyGasFeeEstimatorOptions{ - Multiplier: gasFeeMultiplier, + Multiplier: 1, MinGasPrice: nil, - MaxGasPrice: maxGasFee, + MaxGasPrice: nil, Replace: false, })) case "eip1559": txModifiers = append(txModifiers, txmodifier.NewEIP1559GasFeeEstimator(txmodifier.EIP1559GasFeeEstimatorOptions{ - GasPriceMultiplier: gasFeeMultiplier, - PriorityFeePerGasMultiplier: gasPriorityFeeMultiplier, + GasPriceMultiplier: 1, + PriorityFeePerGasMultiplier: 1, MinGasPrice: nil, - MaxGasPrice: maxGasFee, + MaxGasPrice: nil, MinPriorityFeePerGas: nil, - MaxPriorityFeePerGas: maxGasPriorityFee, + MaxPriorityFeePerGas: nil, Replace: false, })) case "", "none": @@ -194,7 +192,7 @@ func main() { // Set manual gas limit for flashbots, they might require more gas. //nolint:gocritic baseTxModifiers := append(txModifiers, txmodifier.NewGasLimitEstimator(txmodifier.GasLimitEstimatorOptions{ - MaxGas: maxGasLimit, + MaxGas: 0, Multiplier: defaultGasLimitMultiplier, })) @@ -255,7 +253,6 @@ func main() { challenger.ErrorsCounter.WithLabelValues( addr.String(), p.GetFrom(ctx).String(), - err.Error(), ).Inc() logger.Fatalf("Failed to run challenger: %v", err) @@ -270,11 +267,16 @@ func main() { challenger.LastScannedBlockGauge, ) http.Handle("/metrics", promhttp.Handler()) - // TODO: move `:9090` to config - logger. - WithError(http.ListenAndServe(":9090", nil)). //nolint:gosec - Error("metrics server error") - <-ctx.Done() + srv := &http.Server{Addr: opts.MetricsAddr} //nolint:gosec + go func() { + <-ctx.Done() + if err := srv.Shutdown(context.Background()); err != nil { + logger.WithError(err).Error("metrics server shutdown error") + } + }() + if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { + logger.WithError(err).Error("metrics server error") + } }() wg.Wait() @@ -292,6 +294,10 @@ func main() { Int64Var(&opts.FromBlock, "from-block", 0, "Block number to start from. If not provided, binary will try to get it from given RPC") cmd.PersistentFlags().Uint64Var(&opts.ChainID, "chain-id", 0, "If no chain_id provided binary will try to get chain_id from given RPC") cmd.PersistentFlags().StringVar(&opts.TransactionType, "tx-type", "none", "Transaction type definition, possible values are: `legacy`, `eip1559` or `none`") + cmd.PersistentFlags().StringVar(&opts.MetricsAddr, "metrics-addr", ":9090", "Address for the Prometheus metrics server") + cmd.PersistentFlags().StringVar(&opts.LogLevel, "log-level", "info", "Log level: trace, debug, info, warn, error, fatal, panic") - _ = cmd.Execute() + if err := cmd.Execute(); err != nil { + os.Exit(1) + } } diff --git a/core/challenger.go b/core/challenger.go index 4864a2a..ebe4968 100644 --- a/core/challenger.go +++ b/core/challenger.go @@ -38,6 +38,8 @@ type Challenger struct { provider IScribeOptimisticProvider lastProcessedBlock *big.Int wg *sync.WaitGroup + inFlight map[uint64]struct{} + inFlightMu sync.Mutex } // NewChallenger creates a new instance of Challenger. @@ -58,6 +60,7 @@ func NewChallenger( provider: provider, lastProcessedBlock: latestBlock, wg: wg, + inFlight: make(map[uint64]struct{}), } } @@ -125,8 +128,28 @@ func (c *Challenger) isPokeChallengeable(poke *OpPokedEvent, challengePeriod uin } // SpawnChallenge spawns new goroutine and challenges the `OpPoked` event. +// It skips the challenge if one is already in-flight for the same block number. func (c *Challenger) SpawnChallenge(poke *OpPokedEvent) { + blockNum := poke.BlockNumber.Uint64() + + c.inFlightMu.Lock() + if _, ok := c.inFlight[blockNum]; ok { + c.inFlightMu.Unlock() + logger. + WithField("address", c.address). + Debugf("Skipping duplicate challenge for block %v, already in-flight", poke.BlockNumber) + return + } + c.inFlight[blockNum] = struct{}{} + c.inFlightMu.Unlock() + go func() { + defer func() { + c.inFlightMu.Lock() + delete(c.inFlight, blockNum) + c.inFlightMu.Unlock() + }() + logger. WithField("address", c.address). Warnf("Challenging OpPoked event from block %v", poke.BlockNumber) @@ -212,31 +235,33 @@ func (c *Challenger) executeTick() error { return nil } +func (c *Challenger) handleTickError(err error) { + if err == nil { + return + } + logger. + WithField("address", c.address). + Errorf("Failed to execute tick with error: %v", err) + ErrorsCounter.WithLabelValues( + c.address.String(), + c.provider.GetFrom(c.ctx).String(), + ).Inc() +} + // Run starts the challenger processing loop. -// If you provide `subscriptionURL` - it will listen for events from WS connection otherwise, it will poll for new events every 30 seconds. +// It polls for new events every 30 seconds. func (c *Challenger) Run() error { defer c.wg.Done() // Executing first tick - err := c.executeTick() - if err != nil { - logger. - WithField("address", c.address). - Errorf("Failed to execute tick with error: %v", err) - - // Add error to metrics - ErrorsCounter.WithLabelValues( - c.address.String(), - c.provider.GetFrom(c.ctx).String(), - err.Error(), - ).Inc() - } + c.handleTickError(c.executeTick()) logger. WithField("address", c.address). Infof("Started contract monitoring") ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() for { select { @@ -251,18 +276,7 @@ func (c *Challenger) Run() error { WithField("address", c.address). Debugf("Tick at: %v", t) - err := c.executeTick() - if err != nil { - logger. - WithField("address", c.address). - Errorf("Failed to execute tick with error: %v", err) - // Add error to metrics - ErrorsCounter.WithLabelValues( - c.address.String(), - c.provider.GetFrom(c.ctx).String(), - err.Error(), - ).Inc() - } + c.handleTickError(c.executeTick()) } } } @@ -278,7 +292,7 @@ func PickUnchallengedPokes(pokes []*OpPokedEvent, challenges []*OpPokeChallenged if len(pokes) == 1 { for _, challenge := range challenges { - if challenge.BlockNumber.Cmp(pokes[0].BlockNumber) == -1 { + if challenge.BlockNumber.Cmp(pokes[0].BlockNumber) >= 0 { return result } } @@ -301,7 +315,7 @@ func PickUnchallengedPokes(pokes []*OpPokedEvent, challenges []*OpPokeChallenged result = append(result, ev) continue } - if len(sortable)-1 > i+1 && sortable[i+1].Name() == "OpPokeChallengedSuccessfullyEvent" { + if i+1 < len(sortable) && sortable[i+1].Name() == "OpPokeChallengedSuccessfullyEvent" { continue } result = append(result, ev) diff --git a/core/challenger_test.go b/core/challenger_test.go index 2814ec8..db16539 100644 --- a/core/challenger_test.go +++ b/core/challenger_test.go @@ -19,6 +19,7 @@ import ( "context" "fmt" "math/big" + "sync" "testing" "time" @@ -162,3 +163,474 @@ func TestIsPokeChallengeable(t *testing.T) { isPokeValidCall.Unset() call.Unset() } + +func TestPickUnchallengedPokes(t *testing.T) { + mkPoke := func(block int64) *OpPokedEvent { + return &OpPokedEvent{BlockNumber: big.NewInt(block)} + } + mkChallenge := func(block int64) *OpPokeChallengedSuccessfullyEvent { + return &OpPokeChallengedSuccessfullyEvent{BlockNumber: big.NewInt(block)} + } + + t.Run("no pokes returns empty", func(t *testing.T) { + result := PickUnchallengedPokes(nil, []*OpPokeChallengedSuccessfullyEvent{mkChallenge(100)}) + assert.Nil(t, result) + }) + + t.Run("no challenges returns all pokes", func(t *testing.T) { + pokes := []*OpPokedEvent{mkPoke(100), mkPoke(200)} + result := PickUnchallengedPokes(pokes, nil) + assert.Equal(t, pokes, result) + }) + + t.Run("single poke with challenge AFTER is challenged", func(t *testing.T) { + pokes := []*OpPokedEvent{mkPoke(100)} + challenges := []*OpPokeChallengedSuccessfullyEvent{mkChallenge(105)} + result := PickUnchallengedPokes(pokes, challenges) + assert.Empty(t, result, "poke at block 100 should be filtered out because challenge at block 105 is after it") + }) + + t.Run("single poke with challenge at SAME block is challenged (couldn't happen in real life)", func(t *testing.T) { + pokes := []*OpPokedEvent{mkPoke(100)} + challenges := []*OpPokeChallengedSuccessfullyEvent{mkChallenge(100)} + result := PickUnchallengedPokes(pokes, challenges) + assert.Empty(t, result, "poke at block 100 should be filtered out because challenge is at the same block") + }) + + t.Run("single poke with challenge BEFORE is unchallenged", func(t *testing.T) { + pokes := []*OpPokedEvent{mkPoke(100)} + challenges := []*OpPokeChallengedSuccessfullyEvent{mkChallenge(50)} + result := PickUnchallengedPokes(pokes, challenges) + require.Len(t, result, 1, "poke at block 100 should remain because challenge at block 50 is for a previous poke") + assert.Equal(t, big.NewInt(100), result[0].BlockNumber) + }) + + // Multi-poke cases (issue 1.2) + + t.Run("two pokes, first challenged between them", func(t *testing.T) { + // sorted: [Poke@100, Challenge@105, Poke@200] + pokes := []*OpPokedEvent{mkPoke(100), mkPoke(200)} + challenges := []*OpPokeChallengedSuccessfullyEvent{mkChallenge(105)} + result := PickUnchallengedPokes(pokes, challenges) + require.Len(t, result, 1, "only poke@200 should remain unchallenged") + assert.Equal(t, big.NewInt(200), result[0].BlockNumber) + }) + + t.Run("two pokes, second challenged after it (challenge is last element)", func(t *testing.T) { + // sorted: [Poke@100, Poke@200, Challenge@205] + // This is the off-by-one bug: Poke@200 is second-to-last, Challenge@205 is last + pokes := []*OpPokedEvent{mkPoke(100), mkPoke(200)} + challenges := []*OpPokeChallengedSuccessfullyEvent{mkChallenge(205)} + result := PickUnchallengedPokes(pokes, challenges) + require.Len(t, result, 1, "only poke@100 should remain, poke@200 was challenged at 205") + assert.Equal(t, big.NewInt(100), result[0].BlockNumber) + }) + + t.Run("two pokes, both challenged", func(t *testing.T) { + // sorted: [Poke@100, Challenge@105, Poke@200, Challenge@205] + pokes := []*OpPokedEvent{mkPoke(100), mkPoke(200)} + challenges := []*OpPokeChallengedSuccessfullyEvent{mkChallenge(105), mkChallenge(205)} + result := PickUnchallengedPokes(pokes, challenges) + assert.Empty(t, result, "both pokes were challenged") + }) + + t.Run("three pokes, middle one challenged", func(t *testing.T) { + // sorted: [Poke@100, Poke@200, Challenge@205, Poke@300] + pokes := []*OpPokedEvent{mkPoke(100), mkPoke(200), mkPoke(300)} + challenges := []*OpPokeChallengedSuccessfullyEvent{mkChallenge(205)} + result := PickUnchallengedPokes(pokes, challenges) + require.Len(t, result, 2, "poke@100 and poke@300 should remain") + assert.Equal(t, big.NewInt(100), result[0].BlockNumber) + assert.Equal(t, big.NewInt(300), result[1].BlockNumber) + }) + + t.Run("two pokes, no challenges between or after", func(t *testing.T) { + // sorted: [Challenge@50, Poke@100, Poke@200] + pokes := []*OpPokedEvent{mkPoke(100), mkPoke(200)} + challenges := []*OpPokeChallengedSuccessfullyEvent{mkChallenge(50)} + result := PickUnchallengedPokes(pokes, challenges) + require.Len(t, result, 2, "both pokes should remain, challenge@50 is for a previous poke") + assert.Equal(t, big.NewInt(100), result[0].BlockNumber) + assert.Equal(t, big.NewInt(200), result[1].BlockNumber) + }) +} + +func TestSpawnChallengeDuplicateProtection(t *testing.T) { + address := types.MustAddressFromHex("0x1F7acDa376eF37EC371235a094113dF9Cb4EfEe1") + from := types.MustAddressFromHex("0x0000000000000000000000000000000000000001") + txHash := types.MustHashFromHex("0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", types.PadNone) + + t.Run("first call proceeds, second call with same block is skipped", func(t *testing.T) { + mockedProvider := new(mockScribeOptimisticProvider) + + // Use a channel to block ChallengePoke so the goroutine stays in-flight. + gate := make(chan struct{}) + mockedProvider.On("ChallengePoke", mock.Anything, mock.Anything, mock.Anything). + Run(func(args mock.Arguments) { <-gate }). + Return(&txHash, &types.Transaction{}, nil) + mockedProvider.On("GetFrom", mock.Anything).Return(from) + + c := NewChallenger(context.TODO(), address, mockedProvider, 0, &sync.WaitGroup{}) + poke := &OpPokedEvent{BlockNumber: big.NewInt(1000)} + + // First call should proceed and mark block 1000 as in-flight. + c.SpawnChallenge(poke) + + // Give goroutine time to start and hit the gate. + time.Sleep(50 * time.Millisecond) + + // Second call with same block should be skipped (no additional goroutine). + c.SpawnChallenge(poke) + + // ChallengePoke should have been called exactly once at this point. + mockedProvider.AssertNumberOfCalls(t, "ChallengePoke", 1) + + // Unblock the goroutine. + close(gate) + + // Wait for the goroutine to finish and clean up. + time.Sleep(50 * time.Millisecond) + + // After completion, block 1000 should no longer be in-flight. + c.inFlightMu.Lock() + _, stillInFlight := c.inFlight[1000] + c.inFlightMu.Unlock() + assert.False(t, stillInFlight, "block 1000 should be removed from in-flight after goroutine completes") + }) + + t.Run("block can be re-challenged after goroutine completes", func(t *testing.T) { + mockedProvider := new(mockScribeOptimisticProvider) + mockedProvider.On("ChallengePoke", mock.Anything, mock.Anything, mock.Anything). + Return(&txHash, &types.Transaction{}, nil) + mockedProvider.On("GetFrom", mock.Anything).Return(from) + + c := NewChallenger(context.TODO(), address, mockedProvider, 0, &sync.WaitGroup{}) + poke := &OpPokedEvent{BlockNumber: big.NewInt(2000)} + + // First challenge. + c.SpawnChallenge(poke) + time.Sleep(50 * time.Millisecond) + + // After first goroutine completes, block should be removed from in-flight. + // Second call should proceed normally. + c.SpawnChallenge(poke) + time.Sleep(50 * time.Millisecond) + + // ChallengePoke should have been called twice total. + mockedProvider.AssertNumberOfCalls(t, "ChallengePoke", 2) + }) + + t.Run("different block numbers can be challenged concurrently", func(t *testing.T) { + mockedProvider := new(mockScribeOptimisticProvider) + + gate := make(chan struct{}) + mockedProvider.On("ChallengePoke", mock.Anything, mock.Anything, mock.Anything). + Run(func(args mock.Arguments) { <-gate }). + Return(&txHash, &types.Transaction{}, nil) + mockedProvider.On("GetFrom", mock.Anything).Return(from) + + c := NewChallenger(context.TODO(), address, mockedProvider, 0, &sync.WaitGroup{}) + + poke1 := &OpPokedEvent{BlockNumber: big.NewInt(3000)} + poke2 := &OpPokedEvent{BlockNumber: big.NewInt(4000)} + + c.SpawnChallenge(poke1) + c.SpawnChallenge(poke2) + + time.Sleep(50 * time.Millisecond) + + // Both should have started since they have different block numbers. + mockedProvider.AssertNumberOfCalls(t, "ChallengePoke", 2) + + close(gate) + time.Sleep(50 * time.Millisecond) + }) +} + +func TestExecuteTick(t *testing.T) { + address := types.MustAddressFromHex("0x1F7acDa376eF37EC371235a094113dF9Cb4EfEe1") + from := types.MustAddressFromHex("0x0000000000000000000000000000000000000001") + txHash := types.MustHashFromHex("0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", types.PadNone) + + t.Run("error on BlockNumber failure", func(t *testing.T) { + p := new(mockScribeOptimisticProvider) + p.On("BlockNumber", mock.Anything).Return((*big.Int)(nil), fmt.Errorf("rpc down")) + + c := NewChallenger(context.TODO(), address, p, 100, nil) + err := c.executeTick() + assert.ErrorContains(t, err, "failed to get latest block number") + p.AssertExpectations(t) + }) + + t.Run("error on GetChallengePeriod failure", func(t *testing.T) { + p := new(mockScribeOptimisticProvider) + p.On("BlockNumber", mock.Anything).Return(big.NewInt(1000), nil) + p.On("GetChallengePeriod", mock.Anything, address).Return(0, fmt.Errorf("contract error")) + + c := NewChallenger(context.TODO(), address, p, 100, nil) + err := c.executeTick() + assert.ErrorContains(t, err, "failed to get challenge period") + p.AssertExpectations(t) + }) + + t.Run("error on GetPokes failure", func(t *testing.T) { + p := new(mockScribeOptimisticProvider) + p.On("BlockNumber", mock.Anything).Return(big.NewInt(1000), nil) + p.On("GetChallengePeriod", mock.Anything, address).Return(600, nil) + p.On("GetPokes", mock.Anything, address, big.NewInt(100), big.NewInt(1000)). + Return(([]*OpPokedEvent)(nil), fmt.Errorf("logs error")) + + c := NewChallenger(context.TODO(), address, p, 100, nil) + err := c.executeTick() + assert.ErrorContains(t, err, "failed to get OpPoked events") + p.AssertExpectations(t) + }) + + t.Run("no pokes returns nil and updates lastProcessedBlock", func(t *testing.T) { + p := new(mockScribeOptimisticProvider) + p.On("BlockNumber", mock.Anything).Return(big.NewInt(1000), nil) + p.On("GetChallengePeriod", mock.Anything, address).Return(600, nil) + p.On("GetPokes", mock.Anything, address, big.NewInt(100), big.NewInt(1000)). + Return([]*OpPokedEvent{}, nil) + p.On("GetFrom", mock.Anything).Return(from) + + c := NewChallenger(context.TODO(), address, p, 100, nil) + err := c.executeTick() + assert.NoError(t, err) + assert.Equal(t, big.NewInt(1000), c.lastProcessedBlock) + p.AssertExpectations(t) + // GetSuccessfulChallenges should not be called when there are no pokes. + p.AssertNotCalled(t, "GetSuccessfulChallenges") + }) + + t.Run("error on GetSuccessfulChallenges failure", func(t *testing.T) { + p := new(mockScribeOptimisticProvider) + p.On("BlockNumber", mock.Anything).Return(big.NewInt(1000), nil) + p.On("GetChallengePeriod", mock.Anything, address).Return(600, nil) + p.On("GetPokes", mock.Anything, address, big.NewInt(100), big.NewInt(1000)). + Return([]*OpPokedEvent{{BlockNumber: big.NewInt(500)}}, nil) + p.On("GetFrom", mock.Anything).Return(from) + p.On("GetSuccessfulChallenges", mock.Anything, address, big.NewInt(100), big.NewInt(1000)). + Return(([]*OpPokeChallengedSuccessfullyEvent)(nil), fmt.Errorf("logs error")) + + c := NewChallenger(context.TODO(), address, p, 100, nil) + err := c.executeTick() + assert.ErrorContains(t, err, "failed to get OpPokeChallengedSuccessfully events") + p.AssertExpectations(t) + }) + + t.Run("non-challengeable poke is skipped", func(t *testing.T) { + p := new(mockScribeOptimisticProvider) + p.On("BlockNumber", mock.Anything).Return(big.NewInt(1000), nil) + p.On("GetChallengePeriod", mock.Anything, address).Return(600, nil) + poke := &OpPokedEvent{BlockNumber: big.NewInt(500)} + p.On("GetPokes", mock.Anything, address, big.NewInt(100), big.NewInt(1000)). + Return([]*OpPokedEvent{poke}, nil) + p.On("GetFrom", mock.Anything).Return(from) + p.On("GetSuccessfulChallenges", mock.Anything, address, big.NewInt(100), big.NewInt(1000)). + Return([]*OpPokeChallengedSuccessfullyEvent{}, nil) + // Block is older than challenge period — not challengeable. + ts := time.Now().Add(-time.Second * 700) + p.On("BlockByNumber", mock.Anything, big.NewInt(500)). + Return(&types.Block{Number: big.NewInt(500), Timestamp: ts}, nil) + + c := NewChallenger(context.TODO(), address, p, 100, nil) + err := c.executeTick() + assert.NoError(t, err) + // ChallengePoke should never be called. + p.AssertNotCalled(t, "ChallengePoke") + p.AssertExpectations(t) + }) + + t.Run("challengeable poke triggers SpawnChallenge", func(t *testing.T) { + p := new(mockScribeOptimisticProvider) + p.On("BlockNumber", mock.Anything).Return(big.NewInt(1000), nil) + p.On("GetChallengePeriod", mock.Anything, address).Return(600, nil) + poke := &OpPokedEvent{BlockNumber: big.NewInt(500)} + p.On("GetPokes", mock.Anything, address, big.NewInt(100), big.NewInt(1000)). + Return([]*OpPokedEvent{poke}, nil) + p.On("GetFrom", mock.Anything).Return(from) + p.On("GetSuccessfulChallenges", mock.Anything, address, big.NewInt(100), big.NewInt(1000)). + Return([]*OpPokeChallengedSuccessfullyEvent{}, nil) + // Block is recent — within challenge period. + p.On("BlockByNumber", mock.Anything, big.NewInt(500)). + Return(&types.Block{Number: big.NewInt(500), Timestamp: time.Now()}, nil) + // Signature is invalid — challengeable. + p.On("IsPokeSignatureValid", mock.Anything, address, poke).Return(false, nil) + p.On("ChallengePoke", mock.Anything, address, poke). + Return(&txHash, &types.Transaction{}, nil) + + c := NewChallenger(context.TODO(), address, p, 100, &sync.WaitGroup{}) + err := c.executeTick() + assert.NoError(t, err) + + // Wait for the SpawnChallenge goroutine to complete. + time.Sleep(50 * time.Millisecond) + + p.AssertCalled(t, "ChallengePoke", mock.Anything, address, poke) + p.AssertExpectations(t) + }) + + t.Run("already challenged poke is filtered out", func(t *testing.T) { + p := new(mockScribeOptimisticProvider) + p.On("BlockNumber", mock.Anything).Return(big.NewInt(1000), nil) + p.On("GetChallengePeriod", mock.Anything, address).Return(600, nil) + poke := &OpPokedEvent{BlockNumber: big.NewInt(500)} + p.On("GetPokes", mock.Anything, address, big.NewInt(100), big.NewInt(1000)). + Return([]*OpPokedEvent{poke}, nil) + p.On("GetFrom", mock.Anything).Return(from) + // Challenge exists after the poke — poke is filtered out. + p.On("GetSuccessfulChallenges", mock.Anything, address, big.NewInt(100), big.NewInt(1000)). + Return([]*OpPokeChallengedSuccessfullyEvent{{BlockNumber: big.NewInt(505)}}, nil) + + c := NewChallenger(context.TODO(), address, p, 100, nil) + err := c.executeTick() + assert.NoError(t, err) + // No pokes remain after filtering, so no block lookups or challenges. + p.AssertNotCalled(t, "BlockByNumber") + p.AssertNotCalled(t, "ChallengePoke") + p.AssertExpectations(t) + }) + + t.Run("lastProcessedBlock is used as fromBlock on second tick", func(t *testing.T) { + p := new(mockScribeOptimisticProvider) + // First tick: fromBlock=100, latestBlock=1000. + p.On("BlockNumber", mock.Anything).Return(big.NewInt(1000), nil).Once() + p.On("GetChallengePeriod", mock.Anything, address).Return(600, nil) + p.On("GetPokes", mock.Anything, address, big.NewInt(100), big.NewInt(1000)). + Return([]*OpPokedEvent{}, nil).Once() + p.On("GetFrom", mock.Anything).Return(from) + + c := NewChallenger(context.TODO(), address, p, 100, nil) + err := c.executeTick() + assert.NoError(t, err) + assert.Equal(t, big.NewInt(1000), c.lastProcessedBlock) + + // Second tick: fromBlock should now be 1000 (lastProcessedBlock), latestBlock=2000. + p.On("BlockNumber", mock.Anything).Return(big.NewInt(2000), nil).Once() + p.On("GetPokes", mock.Anything, address, big.NewInt(1000), big.NewInt(2000)). + Return([]*OpPokedEvent{}, nil).Once() + + err = c.executeTick() + assert.NoError(t, err) + assert.Equal(t, big.NewInt(2000), c.lastProcessedBlock) + p.AssertExpectations(t) + }) +} + +func TestRun(t *testing.T) { + address := types.MustAddressFromHex("0x1F7acDa376eF37EC371235a094113dF9Cb4EfEe1") + from := types.MustAddressFromHex("0x0000000000000000000000000000000000000001") + + t.Run("context cancellation exits cleanly and calls wg.Done", func(t *testing.T) { + p := new(mockScribeOptimisticProvider) + // executeTick will run once on startup — provide happy path with no pokes. + p.On("BlockNumber", mock.Anything).Return(big.NewInt(1000), nil) + p.On("GetChallengePeriod", mock.Anything, address).Return(600, nil) + p.On("GetPokes", mock.Anything, address, big.NewInt(100), big.NewInt(1000)). + Return([]*OpPokedEvent{}, nil) + p.On("GetFrom", mock.Anything).Return(from) + + ctx, cancel := context.WithCancel(context.Background()) + var wg sync.WaitGroup + wg.Add(1) + + c := NewChallenger(ctx, address, p, 100, &wg) + + done := make(chan struct{}) + go func() { + err := c.Run() + assert.NoError(t, err) + close(done) + }() + + // Cancel context to stop the loop. + cancel() + + // wg.Wait should return because Run calls wg.Done. + wg.Wait() + <-done + }) + + t.Run("tick error does not stop the loop", func(t *testing.T) { + p := new(mockScribeOptimisticProvider) + // First tick (startup): error. + p.On("BlockNumber", mock.Anything).Return((*big.Int)(nil), fmt.Errorf("rpc down")) + p.On("GetFrom", mock.Anything).Return(from) + + ctx, cancel := context.WithCancel(context.Background()) + var wg sync.WaitGroup + wg.Add(1) + + c := NewChallenger(ctx, address, p, 100, &wg) + + done := make(chan struct{}) + go func() { + err := c.Run() + assert.NoError(t, err) + close(done) + }() + + // Even though tick errored, Run should still be running. + // Cancel to exit cleanly. + time.Sleep(50 * time.Millisecond) + cancel() + wg.Wait() + <-done + }) +} + +func TestGetEarliestBlockNumber(t *testing.T) { + address := types.MustAddressFromHex("0x1F7acDa376eF37EC371235a094113dF9Cb4EfEe1") + c := NewChallenger(context.TODO(), address, nil, 0, nil) + + t.Run("block less than blocksPerPeriod returns zero", func(t *testing.T) { + // period=600, blocksPerPeriod = 600/12 = 50, lastBlock=30 < 50 + result := c.getEarliestBlockNumber(big.NewInt(30), 600) + assert.Equal(t, big.NewInt(0), result) + }) + + t.Run("block equal to blocksPerPeriod returns zero", func(t *testing.T) { + // period=600, blocksPerPeriod = 50, lastBlock=50 is not less than 50 + result := c.getEarliestBlockNumber(big.NewInt(50), 600) + assert.Equal(t, 0, result.Cmp(big.NewInt(0))) + }) + + t.Run("block greater than blocksPerPeriod returns difference", func(t *testing.T) { + // period=600, blocksPerPeriod = 50, lastBlock=1000 -> 1000-50 = 950 + result := c.getEarliestBlockNumber(big.NewInt(1000), 600) + assert.Equal(t, big.NewInt(950), result) + }) + + t.Run("small period", func(t *testing.T) { + // period=12, blocksPerPeriod = 12/12 = 1, lastBlock=100 -> 99 + result := c.getEarliestBlockNumber(big.NewInt(100), 12) + assert.Equal(t, big.NewInt(99), result) + }) +} + +func TestSpawnChallengeErrorPath(t *testing.T) { + address := types.MustAddressFromHex("0x1F7acDa376eF37EC371235a094113dF9Cb4EfEe1") + + t.Run("ChallengePoke error does not record metrics", func(t *testing.T) { + p := new(mockScribeOptimisticProvider) + p.On("ChallengePoke", mock.Anything, mock.Anything, mock.Anything). + Return((*types.Hash)(nil), (*types.Transaction)(nil), fmt.Errorf("tx failed")) + + c := NewChallenger(context.TODO(), address, p, 0, &sync.WaitGroup{}) + poke := &OpPokedEvent{BlockNumber: big.NewInt(5000)} + + c.SpawnChallenge(poke) + time.Sleep(50 * time.Millisecond) + + // ChallengePoke was called but GetFrom should NOT be called (metrics not recorded on error). + p.AssertCalled(t, "ChallengePoke", mock.Anything, mock.Anything, mock.Anything) + p.AssertNotCalled(t, "GetFrom") + + // In-flight entry should be cleaned up. + c.inFlightMu.Lock() + _, stillInFlight := c.inFlight[5000] + c.inFlightMu.Unlock() + assert.False(t, stillInFlight) + }) +} diff --git a/core/metrics.go b/core/metrics.go index cc62499..9619e48 100644 --- a/core/metrics.go +++ b/core/metrics.go @@ -8,7 +8,7 @@ var ErrorsCounter = prometheus.NewCounterVec(prometheus.CounterOpts{ Namespace: prometheusNamespace, Name: "errors_total", Help: "Challenger Errors Counter", -}, []string{"address", "from", "error"}) +}, []string{"address", "from"}) var ChallengeCounter = prometheus.NewCounterVec(prometheus.CounterOpts{ Namespace: prometheusNamespace, diff --git a/core/scribe_optimistic_provider.go b/core/scribe_optimistic_provider.go index dd0d6b4..947ea8e 100644 --- a/core/scribe_optimistic_provider.go +++ b/core/scribe_optimistic_provider.go @@ -20,6 +20,7 @@ import ( _ "embed" "fmt" "math/big" + "sync" "time" "github.com/defiweb/go-eth/abi" @@ -40,6 +41,8 @@ var ScribeOptimisticContractABI = abi.MustParseJSON(scribeOptimisticContractJSON type ScribeOptimisticRpcProvider struct { client RPCClient flashbotClient RPCClient + fromOnce sync.Once + fromAddr types.Address } // NewScribeOptimisticRPCProvider creates a new instance of ScribeOptimisticRpcProvider. @@ -53,16 +56,19 @@ func NewScribeOptimisticRPCProvider(client RPCClient, flashbotClient RPCClient) } func (s *ScribeOptimisticRpcProvider) GetFrom(ctx context.Context) types.Address { - accs, err := s.client.Accounts(ctx) - if err != nil { - logger.Errorf("failed to get accounts with error: %v", err) - return types.ZeroAddress - } - if len(accs) == 0 { - logger.Errorf("no accounts found") - return types.ZeroAddress - } - return accs[0] + s.fromOnce.Do(func() { + accs, err := s.client.Accounts(ctx) + if err != nil { + logger.Errorf("failed to get accounts with error: %v", err) + return + } + if len(accs) == 0 { + logger.Errorf("no accounts found") + return + } + s.fromAddr = accs[0] + }) + return s.fromAddr } func (s *ScribeOptimisticRpcProvider) BlockByNumber(ctx context.Context, blockNumber *big.Int) (*types.Block, error) { @@ -78,7 +84,7 @@ func (s *ScribeOptimisticRpcProvider) GetChallengePeriod(ctx context.Context, ad opChallengePeriod := ScribeOptimisticContractABI.Methods["opChallengePeriod"] calldata, err := opChallengePeriod.EncodeArgs() if err != nil { - panic(err) + return 0, fmt.Errorf("failed to encode opChallengePeriod args: %v", err) } b, _, err := s.client.Call(ctx, &types.Call{ To: &address, diff --git a/core/scribe_optimistic_provider_test.go b/core/scribe_optimistic_provider_test.go index 73475c1..e86dec9 100644 --- a/core/scribe_optimistic_provider_test.go +++ b/core/scribe_optimistic_provider_test.go @@ -10,6 +10,7 @@ import ( "github.com/defiweb/go-eth/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" ) type mockRpcClient struct { @@ -56,29 +57,42 @@ func (m *mockRpcClient) GetTransactionReceipt(ctx context.Context, hash types.Ha } func TestGetFrom(t *testing.T) { - mockRpcClient := new(mockRpcClient) - provider := NewScribeOptimisticRPCProvider(mockRpcClient, nil) - // gets zero address if no accounts - call := mockRpcClient.On("Accounts", mock.Anything).Return([]types.Address{}, nil) - addr := provider.GetFrom(context.TODO()) + mockClient1 := new(mockRpcClient) + provider1 := NewScribeOptimisticRPCProvider(mockClient1, nil) + call := mockClient1.On("Accounts", mock.Anything).Return([]types.Address{}, nil) + addr := provider1.GetFrom(context.TODO()) assert.Equal(t, types.ZeroAddress, addr) - mockRpcClient.AssertExpectations(t) + mockClient1.AssertExpectations(t) call.Unset() // zero address on error - call = mockRpcClient.On("Accounts", mock.Anything).Return([]types.Address{}, fmt.Errorf("error")) - addr = provider.GetFrom(context.TODO()) + mockClient2 := new(mockRpcClient) + provider2 := NewScribeOptimisticRPCProvider(mockClient2, nil) + call = mockClient2.On("Accounts", mock.Anything).Return([]types.Address{}, fmt.Errorf("error")) + addr = provider2.GetFrom(context.TODO()) assert.Equal(t, types.ZeroAddress, addr) - mockRpcClient.AssertExpectations(t) + mockClient2.AssertExpectations(t) call.Unset() // gets first account - call = mockRpcClient.On("Accounts", mock.Anything).Return([]types.Address{{0x1}}, nil) - addr = provider.GetFrom(context.TODO()) + mockClient3 := new(mockRpcClient) + provider3 := NewScribeOptimisticRPCProvider(mockClient3, nil) + call = mockClient3.On("Accounts", mock.Anything).Return([]types.Address{{0x1}}, nil) + addr = provider3.GetFrom(context.TODO()) assert.Equal(t, types.Address{0x1}, addr) - mockRpcClient.AssertExpectations(t) + mockClient3.AssertExpectations(t) call.Unset() + + // cached result is returned on subsequent calls + mockClient4 := new(mockRpcClient) + provider4 := NewScribeOptimisticRPCProvider(mockClient4, nil) + mockClient4.On("Accounts", mock.Anything).Return([]types.Address{{0x2}}, nil).Once() + addr = provider4.GetFrom(context.TODO()) + assert.Equal(t, types.Address{0x2}, addr) + addr = provider4.GetFrom(context.TODO()) + assert.Equal(t, types.Address{0x2}, addr) + mockClient4.AssertExpectations(t) } func TestGetChallengePeriod(t *testing.T) { @@ -108,3 +122,248 @@ func TestGetChallengePeriod(t *testing.T) { mockRpcClient.AssertExpectations(t) call.Unset() } + +func TestGetPokes(t *testing.T) { + address := types.MustAddressFromHex("0x1F7acDa376eF37EC371235a094113dF9Cb4EfEe1") + + t.Run("GetLogs error", func(t *testing.T) { + client := new(mockRpcClient) + provider := NewScribeOptimisticRPCProvider(client, nil) + client.On("GetLogs", mock.Anything, mock.Anything). + Return([]types.Log{}, fmt.Errorf("rpc error")) + + result, err := provider.GetPokes(context.TODO(), address, big.NewInt(0), big.NewInt(100)) + assert.ErrorContains(t, err, "failed to get OpPoked events") + assert.Nil(t, result) + }) + + t.Run("empty logs", func(t *testing.T) { + client := new(mockRpcClient) + provider := NewScribeOptimisticRPCProvider(client, nil) + client.On("GetLogs", mock.Anything, mock.Anything). + Return([]types.Log{}, nil) + + result, err := provider.GetPokes(context.TODO(), address, big.NewInt(0), big.NewInt(100)) + assert.NoError(t, err) + assert.Empty(t, result) + }) + + t.Run("decode error skips bad log", func(t *testing.T) { + client := new(mockRpcClient) + provider := NewScribeOptimisticRPCProvider(client, nil) + // Return a log with invalid data that will fail decoding. + badLog := types.Log{ + BlockNumber: big.NewInt(50), + Topics: []types.Hash{}, + Data: []byte{0x01}, + } + client.On("GetLogs", mock.Anything, mock.Anything). + Return([]types.Log{badLog}, nil) + + result, err := provider.GetPokes(context.TODO(), address, big.NewInt(0), big.NewInt(100)) + assert.NoError(t, err) + assert.Empty(t, result) + }) + + t.Run("successful decode", func(t *testing.T) { + client := new(mockRpcClient) + provider := NewScribeOptimisticRPCProvider(client, nil) + validLog := types.Log{ + BlockNumber: big.NewInt(50), + Topics: []types.Hash{ + types.MustHashFromHex("0xb9dc937c5e394d0c8f76e0e324500b88251b4c909ddc56232df10e2ea42b3c63", types.PadNone), + types.MustHashFromHex("0x0000000000000000000000001f7acda376ef37ec371235a094113df9cb4efee1", types.PadNone), + types.MustHashFromHex("0x0000000000000000000000006813eb9362372eef6200f3b1dbc3f819671cba69", types.PadNone), + }, + } + client.On("GetLogs", mock.Anything, mock.Anything). + Return([]types.Log{validLog}, nil) + + result, err := provider.GetPokes(context.TODO(), address, big.NewInt(0), big.NewInt(100)) + assert.NoError(t, err) + require.Len(t, result, 1) + assert.Equal(t, big.NewInt(50), result[0].BlockNumber) + }) +} + +func TestGetSuccessfulChallenges(t *testing.T) { + address := types.MustAddressFromHex("0x1F7acDa376eF37EC371235a094113dF9Cb4EfEe1") + + t.Run("GetLogs error", func(t *testing.T) { + client := new(mockRpcClient) + provider := NewScribeOptimisticRPCProvider(client, nil) + client.On("GetLogs", mock.Anything, mock.Anything). + Return([]types.Log{}, fmt.Errorf("rpc error")) + + result, err := provider.GetSuccessfulChallenges(context.TODO(), address, big.NewInt(0), big.NewInt(100)) + assert.ErrorContains(t, err, "failed to get OpPokeChallengedSuccessfully events") + assert.Nil(t, result) + }) + + t.Run("empty logs", func(t *testing.T) { + client := new(mockRpcClient) + provider := NewScribeOptimisticRPCProvider(client, nil) + client.On("GetLogs", mock.Anything, mock.Anything). + Return([]types.Log{}, nil) + + result, err := provider.GetSuccessfulChallenges(context.TODO(), address, big.NewInt(0), big.NewInt(100)) + assert.NoError(t, err) + assert.Empty(t, result) + }) + + t.Run("decode error skips bad log", func(t *testing.T) { + client := new(mockRpcClient) + provider := NewScribeOptimisticRPCProvider(client, nil) + badLog := types.Log{ + BlockNumber: big.NewInt(50), + Topics: []types.Hash{}, + Data: []byte{0x01}, + } + client.On("GetLogs", mock.Anything, mock.Anything). + Return([]types.Log{badLog}, nil) + + result, err := provider.GetSuccessfulChallenges(context.TODO(), address, big.NewInt(0), big.NewInt(100)) + assert.NoError(t, err) + assert.Empty(t, result) + }) + + t.Run("successful decode", func(t *testing.T) { + client := new(mockRpcClient) + provider := NewScribeOptimisticRPCProvider(client, nil) + validLog := types.Log{ + BlockNumber: big.NewInt(50), + Topics: []types.Hash{ + types.MustHashFromHex("0xac50cef58b3aef7f7c30349f5e4a342a29d2325a02eafc8dacfdba391e6d5db3", types.PadNone), + types.MustHashFromHex("0x0000000000000000000000001f7acda376ef37ec371235a094113df9cb4efee1", types.PadNone), + }, + Data: types.MustBytesFromHex("0x00000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000004bd2a556b00000000000000000000000000000000000000000000000000000000"), + } + client.On("GetLogs", mock.Anything, mock.Anything). + Return([]types.Log{validLog}, nil) + + result, err := provider.GetSuccessfulChallenges(context.TODO(), address, big.NewInt(0), big.NewInt(100)) + assert.NoError(t, err) + require.Len(t, result, 1) + assert.Equal(t, big.NewInt(50), result[0].BlockNumber) + }) +} + +func TestIsPokeSignatureValid(t *testing.T) { + address := types.MustAddressFromHex("0x1F7acDa376eF37EC371235a094113dF9Cb4EfEe1") + poke := &OpPokedEvent{ + BlockNumber: big.NewInt(100), + PokeData: PokeData{Val: big.NewInt(1000), Age: 123}, + } + + t.Run("constructPokeMessage error", func(t *testing.T) { + client := new(mockRpcClient) + provider := NewScribeOptimisticRPCProvider(client, nil) + client.On("Call", mock.Anything, mock.Anything, types.LatestBlockNumber). + Return([]byte{}, nil, fmt.Errorf("call error")) + + valid, err := provider.IsPokeSignatureValid(context.TODO(), address, poke) + assert.Error(t, err) + assert.False(t, valid) + }) + + t.Run("isSchnorrSignatureAcceptable error", func(t *testing.T) { + client := new(mockRpcClient) + provider := NewScribeOptimisticRPCProvider(client, nil) + // First Call: constructPokeMessage succeeds with a 32-byte message. + msgBytes := hexutil.MustHexToBytes("0x0000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000002000000000000000000000000000000000000000000000000000000000000003e80000000000000000000000000000000000000000000000000000000000000000") + call1 := client.On("Call", mock.Anything, mock.Anything, types.LatestBlockNumber). + Return(msgBytes, &types.Call{}, nil).Once() + // Second Call: isSchnorrSignatureAcceptable fails. + client.On("Call", mock.Anything, mock.Anything, types.LatestBlockNumber). + Return([]byte{}, nil, fmt.Errorf("signature check error")).Once() + + valid, err := provider.IsPokeSignatureValid(context.TODO(), address, poke) + assert.Error(t, err) + assert.False(t, valid) + call1.Unset() + }) +} + +func TestChallengePoke(t *testing.T) { + address := types.MustAddressFromHex("0x1F7acDa376eF37EC371235a094113dF9Cb4EfEe1") + txHash := types.MustHashFromHex("0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", types.PadNone) + blockHash := types.MustHashFromHex("0xbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb", types.PadNone) + status := uint64(1) + poke := &OpPokedEvent{BlockNumber: big.NewInt(100)} + + receipt := &types.TransactionReceipt{ + TransactionHash: txHash, + Status: &status, + BlockHash: blockHash, + BlockNumber: big.NewInt(200), + } + + t.Run("no flashbot client uses mainnet", func(t *testing.T) { + client := new(mockRpcClient) + provider := NewScribeOptimisticRPCProvider(client, nil) + client.On("SendTransaction", mock.Anything, mock.Anything). + Return(&txHash, &types.Transaction{}, nil) + client.On("GetTransactionReceipt", mock.Anything, txHash). + Return(receipt, nil) + + hash, tx, err := provider.ChallengePoke(context.TODO(), address, poke) + require.NoError(t, err) + assert.Equal(t, &txHash, hash) + assert.NotNil(t, tx) + client.AssertExpectations(t) + }) + + t.Run("flashbot success does not fall back", func(t *testing.T) { + client := new(mockRpcClient) + flashbot := new(mockRpcClient) + provider := NewScribeOptimisticRPCProvider(client, flashbot) + flashbot.On("SendTransaction", mock.Anything, mock.Anything). + Return(&txHash, &types.Transaction{}, nil) + flashbot.On("GetTransactionReceipt", mock.Anything, txHash). + Return(receipt, nil) + + hash, tx, err := provider.ChallengePoke(context.TODO(), address, poke) + require.NoError(t, err) + assert.Equal(t, &txHash, hash) + assert.NotNil(t, tx) + // Mainnet client should not be called. + client.AssertNotCalled(t, "SendTransaction") + flashbot.AssertExpectations(t) + }) + + t.Run("flashbot failure falls back to mainnet", func(t *testing.T) { + client := new(mockRpcClient) + flashbot := new(mockRpcClient) + provider := NewScribeOptimisticRPCProvider(client, flashbot) + // Flashbot send fails. + flashbot.On("SendTransaction", mock.Anything, mock.Anything). + Return((*types.Hash)(nil), (*types.Transaction)(nil), fmt.Errorf("flashbot down")) + // Mainnet succeeds. + client.On("SendTransaction", mock.Anything, mock.Anything). + Return(&txHash, &types.Transaction{}, nil) + client.On("GetTransactionReceipt", mock.Anything, txHash). + Return(receipt, nil) + + hash, tx, err := provider.ChallengePoke(context.TODO(), address, poke) + require.NoError(t, err) + assert.Equal(t, &txHash, hash) + assert.NotNil(t, tx) + flashbot.AssertExpectations(t) + client.AssertExpectations(t) + }) + + t.Run("both flashbot and mainnet fail", func(t *testing.T) { + client := new(mockRpcClient) + flashbot := new(mockRpcClient) + provider := NewScribeOptimisticRPCProvider(client, flashbot) + flashbot.On("SendTransaction", mock.Anything, mock.Anything). + Return((*types.Hash)(nil), (*types.Transaction)(nil), fmt.Errorf("flashbot down")) + client.On("SendTransaction", mock.Anything, mock.Anything). + Return((*types.Hash)(nil), (*types.Transaction)(nil), fmt.Errorf("mainnet down")) + + hash, tx, err := provider.ChallengePoke(context.TODO(), address, poke) + assert.Error(t, err) + assert.Nil(t, hash) + assert.Nil(t, tx) + }) +} diff --git a/core/utils.go b/core/utils.go index 7c232b4..4bea05f 100644 --- a/core/utils.go +++ b/core/utils.go @@ -9,6 +9,10 @@ import ( logger "github.com/sirupsen/logrus" ) +// txConfirmationPollInterval is the polling interval for checking transaction confirmations. +// Defaults to ~1 block time. Overridden in tests for fast execution. +var txConfirmationPollInterval = 12 * time.Second + // WaitForTxConfirmation waits for the transaction to be confirmed. func WaitForTxConfirmation( ctx context.Context, @@ -24,7 +28,7 @@ func WaitForTxConfirmation( } // check +- every block - ticker := time.NewTicker(12 * time.Second) + ticker := time.NewTicker(txConfirmationPollInterval) defer ticker.Stop() ctx, cancel := context.WithTimeout(ctx, timeout) diff --git a/core/utils_test.go b/core/utils_test.go new file mode 100644 index 0000000..5b241ae --- /dev/null +++ b/core/utils_test.go @@ -0,0 +1,163 @@ +package core + +import ( + "context" + "fmt" + "math/big" + "testing" + "time" + + "github.com/defiweb/go-eth/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func init() { + // Speed up polling for tests. + txConfirmationPollInterval = 10 * time.Millisecond +} + +func TestWaitForTxConfirmation(t *testing.T) { + hash := types.MustHashFromHex("0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", types.PadNone) + status := uint64(1) + + t.Run("nil client returns error", func(t *testing.T) { + receipt, err := WaitForTxConfirmation(context.TODO(), nil, &hash, time.Second) + assert.Nil(t, receipt) + assert.ErrorContains(t, err, "ethereum client not set") + }) + + t.Run("nil txHash returns error", func(t *testing.T) { + client := new(mockRpcClient) + receipt, err := WaitForTxConfirmation(context.TODO(), client, nil, time.Second) + assert.Nil(t, receipt) + assert.ErrorContains(t, err, "tx hash is nil") + }) + + t.Run("timeout returns error", func(t *testing.T) { + client := new(mockRpcClient) + // Always return nil receipt to keep polling until timeout. + client.On("GetTransactionReceipt", mock.Anything, hash). + Return((*types.TransactionReceipt)(nil), nil) + + receipt, err := WaitForTxConfirmation(context.TODO(), client, &hash, 50*time.Millisecond) + assert.Nil(t, receipt) + assert.ErrorContains(t, err, "failed to wait for transaction confirmation") + }) + + t.Run("successful receipt returned", func(t *testing.T) { + client := new(mockRpcClient) + expected := &types.TransactionReceipt{ + TransactionHash: hash, + Status: &status, + BlockHash: types.MustHashFromHex("0xbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb", types.PadNone), + BlockNumber: big.NewInt(100), + } + client.On("GetTransactionReceipt", mock.Anything, hash).Return(expected, nil) + + receipt, err := WaitForTxConfirmation(context.TODO(), client, &hash, time.Second) + require.NoError(t, err) + assert.Equal(t, expected, receipt) + }) + + t.Run("transient error keeps polling until success", func(t *testing.T) { + client := new(mockRpcClient) + expected := &types.TransactionReceipt{ + TransactionHash: hash, + Status: &status, + BlockHash: types.MustHashFromHex("0xbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb", types.PadNone), + BlockNumber: big.NewInt(100), + } + // First call: error. Second call: success. + client.On("GetTransactionReceipt", mock.Anything, hash). + Return((*types.TransactionReceipt)(nil), fmt.Errorf("network error")).Once() + client.On("GetTransactionReceipt", mock.Anything, hash). + Return(expected, nil).Once() + + receipt, err := WaitForTxConfirmation(context.TODO(), client, &hash, time.Second) + require.NoError(t, err) + assert.Equal(t, expected, receipt) + client.AssertNumberOfCalls(t, "GetTransactionReceipt", 2) + }) + + t.Run("nil receipt keeps polling until success", func(t *testing.T) { + client := new(mockRpcClient) + expected := &types.TransactionReceipt{ + TransactionHash: hash, + Status: &status, + BlockHash: types.MustHashFromHex("0xbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb", types.PadNone), + BlockNumber: big.NewInt(100), + } + // First call: nil receipt. Second call: success. + client.On("GetTransactionReceipt", mock.Anything, hash). + Return((*types.TransactionReceipt)(nil), nil).Once() + client.On("GetTransactionReceipt", mock.Anything, hash). + Return(expected, nil).Once() + + receipt, err := WaitForTxConfirmation(context.TODO(), client, &hash, time.Second) + require.NoError(t, err) + assert.Equal(t, expected, receipt) + client.AssertNumberOfCalls(t, "GetTransactionReceipt", 2) + }) + + t.Run("receipt with nil status keeps polling", func(t *testing.T) { + client := new(mockRpcClient) + pending := &types.TransactionReceipt{ + TransactionHash: hash, + Status: nil, + } + confirmed := &types.TransactionReceipt{ + TransactionHash: hash, + Status: &status, + BlockNumber: big.NewInt(100), + } + client.On("GetTransactionReceipt", mock.Anything, hash). + Return(pending, nil).Once() + client.On("GetTransactionReceipt", mock.Anything, hash). + Return(confirmed, nil).Once() + + receipt, err := WaitForTxConfirmation(context.TODO(), client, &hash, time.Second) + require.NoError(t, err) + assert.Equal(t, confirmed, receipt) + client.AssertNumberOfCalls(t, "GetTransactionReceipt", 2) + }) + + t.Run("receipt with zero hash keeps polling", func(t *testing.T) { + client := new(mockRpcClient) + pending := &types.TransactionReceipt{ + TransactionHash: types.Hash{}, + Status: &status, + } + confirmed := &types.TransactionReceipt{ + TransactionHash: hash, + Status: &status, + BlockNumber: big.NewInt(100), + } + client.On("GetTransactionReceipt", mock.Anything, hash). + Return(pending, nil).Once() + client.On("GetTransactionReceipt", mock.Anything, hash). + Return(confirmed, nil).Once() + + receipt, err := WaitForTxConfirmation(context.TODO(), client, &hash, time.Second) + require.NoError(t, err) + assert.Equal(t, confirmed, receipt) + client.AssertNumberOfCalls(t, "GetTransactionReceipt", 2) + }) + + t.Run("context cancellation returns error", func(t *testing.T) { + client := new(mockRpcClient) + client.On("GetTransactionReceipt", mock.Anything, hash). + Return((*types.TransactionReceipt)(nil), nil) + + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(30 * time.Millisecond) + cancel() + }() + + receipt, err := WaitForTxConfirmation(ctx, client, &hash, 5*time.Second) + assert.Nil(t, receipt) + assert.ErrorContains(t, err, "failed to wait for transaction confirmation") + }) +}