diff --git a/pkg/node/chain.go b/pkg/node/chain.go index ef1a63c9881..6121bbb69a0 100644 --- a/pkg/node/chain.go +++ b/pkg/node/chain.go @@ -15,9 +15,11 @@ import ( "strings" "time" + "github.com/ethereum/go-ethereum/accounts/abi" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/ethclient" "github.com/ethereum/go-ethereum/rpc" + "github.com/ethersphere/bee/v2/pkg/accounting" "github.com/ethersphere/bee/v2/pkg/config" "github.com/ethersphere/bee/v2/pkg/crypto" "github.com/ethersphere/bee/v2/pkg/log" @@ -33,6 +35,7 @@ import ( "github.com/ethersphere/bee/v2/pkg/transaction" "github.com/ethersphere/bee/v2/pkg/transaction/backendnoop" "github.com/ethersphere/bee/v2/pkg/transaction/wrapped" + "github.com/ethersphere/bee/v2/pkg/util/abiutil" ) const ( @@ -369,3 +372,275 @@ func (m *noOpChequebookService) LastCheque(common.Address) (*chequebook.SignedCh func (m *noOpChequebookService) LastCheques() (map[common.Address]*chequebook.SignedCheque, error) { return nil, postagecontract.ErrChainDisabled } + +// swapServiceDeps is the injection seam for setupSwapService. +type swapServiceDeps struct { + InitSwap func( + p2ps *libp2p.Service, + logger log.Logger, + stateStore storage.StateStorer, + networkID uint64, + overlayEthAddress common.Address, + chequebookService chequebook.Service, + chequeStore chequebook.ChequeStore, + cashoutService chequebook.CashoutService, + accounting settlement.Accounting, + priceOracleAddress string, + chainID int64, + transactionService transaction.Service, + ) (*swap.Service, priceoracle.Service, error) +} + +// defaultSwapServiceDeps wires the real InitSwap used by NewBee. +var defaultSwapServiceDeps = swapServiceDeps{ + InitSwap: InitSwap, +} + +// swapServiceResult collects what setupSwapService produces. SwapService is +// nil when the gate (SwapEnable && chainEnabled) is closed. PayFunc is set +// only when SwapService is set AND ChequebookEnable is true; the caller wires +// it into the accounting service. +type swapServiceResult struct { + SwapService *swap.Service + PriceOracle priceoracle.Service + PayFunc accounting.PayFunc +} + +// setupSwapService is the swap-service-init block from NewBee: gated on +// SwapEnable && chainEnabled, and inside that gate also exposes the +// ChequebookEnable-only PayFunc wire-up. +func setupSwapService( + o *Options, + chainEnabled bool, + p2ps *libp2p.Service, + logger log.Logger, + stateStore storage.StateStorer, + networkID uint64, + overlayEthAddress common.Address, + chequebookService chequebook.Service, + chequeStore chequebook.ChequeStore, + cashoutService chequebook.CashoutService, + acc settlement.Accounting, + chainID int64, + transactionService transaction.Service, + deps swapServiceDeps, +) (swapServiceResult, error) { + if !o.SwapEnable || !chainEnabled { + return swapServiceResult{}, nil + } + + swapService, priceOracle, err := deps.InitSwap( + p2ps, + logger, + stateStore, + networkID, + overlayEthAddress, + chequebookService, + chequeStore, + cashoutService, + acc, + o.PriceOracleAddress, + chainID, + transactionService, + ) + if err != nil { + return swapServiceResult{}, fmt.Errorf("init swap service: %w", err) + } + + res := swapServiceResult{ + SwapService: swapService, + PriceOracle: priceOracle, + } + if o.ChequebookEnable { + res.PayFunc = swapService.Pay + } + return res, nil +} + +// postageContractDeps is the injection seam for setupPostageContract. +// Only the chain-RPC call (LookupERC20Address) is injectable; the rest is +// pure lookup / address resolution that runs against the real chain config. +type postageContractDeps struct { + LookupERC20 func( + ctx context.Context, + ts transaction.Service, + postageStampContractAddress common.Address, + postageStampContractABI abi.ABI, + chainEnabled bool, + ) (common.Address, error) +} + +// defaultPostageContractDeps wires the real LookupERC20Address used by NewBee. +var defaultPostageContractDeps = postageContractDeps{ + LookupERC20: postagecontract.LookupERC20Address, +} + +// postageContractResult collects the resolved postage configuration so the +// caller can keep working with the same scoped names it used inline before +// the extraction. ChainConfig is propagated because callers in NewBee also +// need StakingAddress / RedistributionAddress / ABIs from it. +type postageContractResult struct { + ChainConfig config.ChainConfig + ContractAddress common.Address + ContractABI abi.ABI + SyncStartBlock uint64 + BzzTokenAddress common.Address +} + +// setupPostageContract resolves the postage stamp contract address and the +// BZZ token address. Validation of the malformed-address / missing-start-block +// / unknown-chain cases is already done by validateChainContractOptions; this +// function trusts the values it receives. +func setupPostageContract( + ctx context.Context, + o *Options, + chainID int64, + chainEnabled bool, + transactionService transaction.Service, + deps postageContractDeps, +) (postageContractResult, error) { + chainCfg, _ := config.GetByChainID(chainID) + addr := chainCfg.PostageStampAddress + syncStart := chainCfg.PostageStampStartBlock + if o.PostageContractAddress != "" { + addr = common.HexToAddress(o.PostageContractAddress) + syncStart = o.PostageContractStartBlock + } + + contractABI := abiutil.MustParseABI(chainCfg.PostageStampABI) + + bzz, err := deps.LookupERC20(ctx, transactionService, addr, contractABI, chainEnabled) + if err != nil { + return postageContractResult{}, fmt.Errorf("lookup erc20 postage address: %w", err) + } + + return postageContractResult{ + ChainConfig: chainCfg, + ContractAddress: addr, + ContractABI: contractABI, + SyncStartBlock: syncStart, + BzzTokenAddress: bzz, + }, nil +} + +// swapDeps is the injection seam for setupSwap. Production wires +// defaultSwapDeps; tests inject fakes to exercise every error path and +// configuration combination without touching a real chain backend. Fields are +// exported so the package_test test file can build a swapDeps directly. +type swapDeps struct { + InitFactory func( + logger log.Logger, + backend transaction.Backend, + chainID int64, + ts transaction.Service, + factoryAddress string, + ) (chequebook.Factory, error) + + InitChequebookService func( + ctx context.Context, + logger log.Logger, + stateStore storage.StateStorer, + signer crypto.Signer, + chainID int64, + backend transaction.Backend, + overlayEthAddress common.Address, + ts transaction.Service, + factory chequebook.Factory, + initialDeposit string, + erc20Service erc20.Service, + ) (chequebook.Service, error) + + InitChequeStoreCashout func( + stateStore storage.StateStorer, + backend transaction.Backend, + factory chequebook.Factory, + chainID int64, + overlayEthAddress common.Address, + ts transaction.Service, + ) (chequebook.ChequeStore, chequebook.CashoutService) +} + +// defaultSwapDeps wires the real chain-dependent constructors that production +// uses inside setupSwap. +var defaultSwapDeps = swapDeps{ + InitFactory: InitChequebookFactory, + InitChequebookService: InitChequebookService, + InitChequeStoreCashout: initChequeStoreCashout, +} + +// swapResult is the set of values setupSwap may produce. A nil +// ChequebookService means "leave the caller's default in place" (i.e. the +// noOpChequebookService). Other zero-value fields mean the corresponding +// subsystem was not initialized for this configuration. +type swapResult struct { + Erc20Service erc20.Service + ChequebookService chequebook.Service + ChequeStore chequebook.ChequeStore + CashoutService chequebook.CashoutService +} + +// setupSwap performs the chequebook / cheque-store / cashout wiring that is +// gated on o.SwapEnable (and, for the chequebook service, on +// o.ChequebookEnable && chainEnabled). It returns error strings exactly the +// way the original inline block in NewBee did. +func setupSwap( + ctx context.Context, + logger log.Logger, + o *Options, + chainEnabled bool, + chainBackend transaction.Backend, + chainID int64, + transactionService transaction.Service, + stateStore storage.StateStorer, + signer crypto.Signer, + overlayEthAddress common.Address, + deps swapDeps, +) (swapResult, error) { + var res swapResult + if !o.SwapEnable { + return res, nil + } + + factory, err := deps.InitFactory(logger, chainBackend, chainID, transactionService, o.SwapFactoryAddress) + if err != nil { + return res, fmt.Errorf("init chequebook factory: %w", err) + } + + erc20Address, err := factory.ERC20Address(ctx) + if err != nil { + return res, fmt.Errorf("factory fail: %w", err) + } + + res.Erc20Service = erc20.New(transactionService, erc20Address) + + if o.ChequebookEnable && chainEnabled { + svc, err := deps.InitChequebookService( + ctx, + logger, + stateStore, + signer, + chainID, + chainBackend, + overlayEthAddress, + transactionService, + factory, + o.SwapInitialDeposit, + res.Erc20Service, + ) + if err != nil { + return swapResult{}, fmt.Errorf("init chequebook service: %w", err) + } + res.ChequebookService = svc + } + + res.ChequeStore, res.CashoutService = deps.InitChequeStoreCashout( + stateStore, + chainBackend, + factory, + chainID, + overlayEthAddress, + transactionService, + ) + + return res, nil +} diff --git a/pkg/node/export_test.go b/pkg/node/export_test.go index 3a142d94f7f..0813a9d0410 100644 --- a/pkg/node/export_test.go +++ b/pkg/node/export_test.go @@ -4,4 +4,110 @@ package node -var ValidatePublicAddress = validatePublicAddress +import ( + "context" + "io" + + "github.com/ethersphere/bee/v2/pkg/log" + "github.com/ethersphere/bee/v2/pkg/util/syncutil" +) + +var ( + ValidatePublicAddress = validatePublicAddress + ValidateOptions = validateOptions + ValidateChainContractOptions = validateChainContractOptions + ParsePaymentThreshold = parsePaymentThreshold + IsChainEnabled = isChainEnabled + BatchStoreExists = batchStoreExists + CheckOverlay = checkOverlay + OverlayNonceExists = overlayNonceExists + SetOverlay = setOverlay + SetupSwap = setupSwap + DefaultSwapDeps = defaultSwapDeps + SetupPostageContract = setupPostageContract + DefaultPostageContractDeps = defaultPostageContractDeps + SetupSwapService = setupSwapService + DefaultSwapServiceDeps = defaultSwapServiceDeps +) + +type ( + SwapDeps = swapDeps + SwapResult = swapResult + PostageContractDeps = postageContractDeps + PostageContractResult = postageContractResult + SwapServiceDeps = swapServiceDeps + SwapServiceResult = swapServiceResult +) + +const ( + OverlayNonceKey = overlayNonce + NoncedOverlayKey = noncedOverlayKey +) + +// ShutdownTestClosers groups the io.Closer fields a Shutdown test wants to +// observe. Any nil entry leaves the corresponding *Bee field unset. +type ShutdownTestClosers struct { + API io.Closer + PSS io.Closer + GSOC io.Closer + Pusher io.Closer + Puller io.Closer + Accounting io.Closer + PullSync io.Closer + Hive io.Closer + Salud io.Closer + P2P io.Closer + PriceOracle io.Closer + TransactionMonitor io.Closer + Transaction io.Closer + Listener io.Closer + PostageService io.Closer + AccessControl io.Closer + Tracer io.Closer + Topology io.Closer + StorageIncentives io.Closer + Stabilization io.Closer + Localstore io.Closer + StateStore io.Closer + StamperStore io.Closer + Resolver io.Closer + EthClient func() +} + +// NewBeeForShutdownTest constructs a minimal *Bee suitable for exercising +// Shutdown without standing up any real subsystems. The returned context is +// the one Shutdown's ctxCancel cancels, so tests can assert it fires. +func NewBeeForShutdownTest(logger log.Logger, c ShutdownTestClosers) (*Bee, context.Context) { + ctx, ctxCancel := context.WithCancel(context.Background()) + b := &Bee{ + logger: logger, + ctxCancel: ctxCancel, + syncingStopped: syncutil.NewSignaler(), + apiCloser: c.API, + pssCloser: c.PSS, + gsocCloser: c.GSOC, + pusherCloser: c.Pusher, + pullerCloser: c.Puller, + accountingCloser: c.Accounting, + pullSyncCloser: c.PullSync, + hiveCloser: c.Hive, + saludCloser: c.Salud, + p2pService: c.P2P, + priceOracleCloser: c.PriceOracle, + transactionMonitorCloser: c.TransactionMonitor, + transactionCloser: c.Transaction, + listenerCloser: c.Listener, + postageServiceCloser: c.PostageService, + accesscontrolCloser: c.AccessControl, + tracerCloser: c.Tracer, + topologyCloser: c.Topology, + storageIncetivesCloser: c.StorageIncentives, + stabilizationDetector: c.Stabilization, + localstoreCloser: c.Localstore, + stateStoreCloser: c.StateStore, + stamperStoreCloser: c.StamperStore, + resolverCloser: c.Resolver, + ethClientCloser: c.EthClient, + } + return b, ctx +} diff --git a/pkg/node/init_chain_test.go b/pkg/node/init_chain_test.go new file mode 100644 index 00000000000..c7ec8f02c10 --- /dev/null +++ b/pkg/node/init_chain_test.go @@ -0,0 +1,161 @@ +// Copyright 2025 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package node_test + +import ( + "context" + "errors" + "strings" + "testing" + "time" + + "github.com/ethereum/go-ethereum/common" + cryptomock "github.com/ethersphere/bee/v2/pkg/crypto/mock" + "github.com/ethersphere/bee/v2/pkg/log" + "github.com/ethersphere/bee/v2/pkg/node" + statestoremock "github.com/ethersphere/bee/v2/pkg/statestore/mock" +) + +// These tests cover the chainEnabled=false path through InitChain. The +// chainEnabled=true path needs a real RPC server to exercise and is covered by +// integration tests. The points pinned here are: +// +// - the no-op backend wires through and the function returns a well-formed +// tuple with the requested chainID; +// - chainID == -1 is a sentinel that skips the mismatch check, so any backend +// chain id is accepted; +// - a signer that fails to expose its Ethereum address surfaces a +// "blockchain address: …" error before the transaction service is built. + +const testChainID = int64(12345) + +func TestInitChain_ChainDisabledHappyPath(t *testing.T) { + t.Parallel() + + ctx := context.Background() + wantEth := common.HexToAddress("0x1111111111111111111111111111111111111111") + signer := cryptomock.New(cryptomock.WithEthereumAddressFunc(func() (common.Address, error) { + return wantEth, nil + })) + + backend, ethAddr, gotChainID, monitor, txService, err := node.InitChain( + ctx, + log.Noop, + statestoremock.NewStateStore(), + testChainID, + signer, + 100*time.Millisecond, + false, // chainEnabled + 0, + 0, + node.BlockchainRPCConfig{}, + 0, + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + t.Cleanup(func() { + if backend != nil { + backend.Close() + } + if monitor != nil { + if err := monitor.Close(); err != nil { + t.Errorf("monitor close: %v", err) + } + } + if txService != nil { + if err := txService.Close(); err != nil { + t.Errorf("txService close: %v", err) + } + } + }) + + if ethAddr != wantEth { + t.Errorf("eth address mismatch: got %s, want %s", ethAddr, wantEth) + } + if gotChainID != testChainID { + t.Errorf("chain id mismatch: got %d, want %d", gotChainID, testChainID) + } + if backend == nil { + t.Error("backend must not be nil") + } + if monitor == nil { + t.Error("transaction monitor must not be nil") + } + if txService == nil { + t.Error("transaction service must not be nil") + } +} + +func TestInitChain_ChainIDSentinelAcceptsAnyBackendChainID(t *testing.T) { + t.Parallel() + + signer := cryptomock.New(cryptomock.WithEthereumAddressFunc(func() (common.Address, error) { + return common.Address{}, nil + })) + + // chainID == -1 tells InitChain to skip the equality check. With + // chainEnabled=false the no-op backend returns -1 too, so any value + // passes — but the documented sentinel behavior is what we pin here. + backend, _, gotChainID, monitor, txService, err := node.InitChain( + context.Background(), + log.Noop, + statestoremock.NewStateStore(), + -1, + signer, + 100*time.Millisecond, + false, + 0, 0, + node.BlockchainRPCConfig{}, + 0, + ) + if err != nil { + t.Fatalf("InitChain with chainID=-1 returned %v", err) + } + t.Cleanup(func() { + if backend != nil { + backend.Close() + } + if monitor != nil { + _ = monitor.Close() + } + if txService != nil { + _ = txService.Close() + } + }) + + if gotChainID != -1 { + t.Errorf("backend chain id: got %d, want -1 (the sentinel passed through)", gotChainID) + } +} + +func TestInitChain_SignerError_IsWrappedAsBlockchainAddress(t *testing.T) { + t.Parallel() + + wantErr := errors.New("signer offline") + signer := cryptomock.New(cryptomock.WithEthereumAddressFunc(func() (common.Address, error) { + return common.Address{}, wantErr + })) + + //nolint:dogsled // InitChain has six returns; the test only needs err. + _, _, _, _, _, err := node.InitChain( + context.Background(), + log.Noop, + statestoremock.NewStateStore(), + testChainID, + signer, + 100*time.Millisecond, + false, + 0, 0, + node.BlockchainRPCConfig{}, + 0, + ) + if !errors.Is(err, wantErr) { + t.Fatalf("expected wrapped signer error %v, got %v", wantErr, err) + } + if !strings.Contains(err.Error(), "blockchain address") { + t.Fatalf("expected error to start with %q, got %q", "blockchain address", err.Error()) + } +} diff --git a/pkg/node/init_stores_test.go b/pkg/node/init_stores_test.go new file mode 100644 index 00000000000..69c272a0f8f --- /dev/null +++ b/pkg/node/init_stores_test.go @@ -0,0 +1,134 @@ +// Copyright 2025 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package node_test + +import ( + "path/filepath" + "testing" + + "github.com/ethersphere/bee/v2/pkg/log" + "github.com/ethersphere/bee/v2/pkg/node" +) + +// These tests cover the on-disk vs in-memory branching of InitStateStore and +// InitStamperStore. They are useful regression nets because an unintended +// switch of the empty-dataDir branch to on-disk behavior would silently +// persist state in the working directory. + +// Note: a zero cacheCapacity makes cache.Wrap fail and InitStateStore leaks +// the already-opened leveldb on that error path. That is a real bug but it is +// independent of the C1 work; not adding a regression test for it here would +// require closing the leveldb on the cache.Wrap error path in +// pkg/node/statestore.go. Left for a follow-up. + +func TestInitStateStore_EmptyDirYieldsInMemoryStore(t *testing.T) { + t.Parallel() + + store, _, err := node.InitStateStore(log.Noop, "", 1) + if err != nil { + t.Fatalf("InitStateStore with empty dir returned %v", err) + } + t.Cleanup(func() { + if err := store.Close(); err != nil { + t.Errorf("close: %v", err) + } + }) + + if err := store.Put("k", []byte("v")); err != nil { + t.Fatalf("put: %v", err) + } + var got []byte + if err := store.Get("k", &got); err != nil { + t.Fatalf("get: %v", err) + } + if string(got) != "v" { + t.Fatalf("got %q, want %q", got, "v") + } +} + +func TestInitStateStore_OnDiskPersistsAcrossInstances(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + + first, _, err := node.InitStateStore(log.Noop, dir, 1) + if err != nil { + t.Fatalf("first InitStateStore: %v", err) + } + if err := first.Put("hello", []byte("world")); err != nil { + t.Fatalf("put: %v", err) + } + if err := first.Close(); err != nil { + t.Fatalf("close first: %v", err) + } + + // A leveldb directory should now exist under /statestore. + if entries, err := filepath.Glob(filepath.Join(dir, "statestore", "*")); err != nil { + t.Fatalf("glob: %v", err) + } else if len(entries) == 0 { + t.Fatal("expected leveldb files under /statestore") + } + + second, _, err := node.InitStateStore(log.Noop, dir, 1) + if err != nil { + t.Fatalf("second InitStateStore: %v", err) + } + t.Cleanup(func() { + if err := second.Close(); err != nil { + t.Errorf("close second: %v", err) + } + }) + + var got []byte + if err := second.Get("hello", &got); err != nil { + t.Fatalf("get on second open: %v", err) + } + if string(got) != "world" { + t.Fatalf("got %q, want %q across reopen", got, "world") + } +} + +func TestInitStamperStore_EmptyDirIsInMemory(t *testing.T) { + t.Parallel() + + // First run with an empty directory: in-memory, dirty must be false. + store, dirty, err := node.InitStamperStore(log.Noop, "", nil) + if err != nil { + t.Fatalf("InitStamperStore: %v", err) + } + if dirty { + t.Fatal("in-memory stamper store must not report dirty") + } + if err := store.Close(); err != nil { + t.Fatalf("close: %v", err) + } +} + +func TestInitStamperStore_OnDisk_CleanThenReopens(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + logger := log.Noop + + first, dirty, err := node.InitStamperStore(logger, dir, nil) + if err != nil { + t.Fatalf("first InitStamperStore: %v", err) + } + if dirty { + t.Fatal("fresh on-disk stamper store must report clean") + } + if err := first.Close(); err != nil { + t.Fatalf("close first: %v", err) + } + + // Reopening must succeed after a clean close. + second, _, err := node.InitStamperStore(logger, dir, nil) + if err != nil { + t.Fatalf("second InitStamperStore: %v", err) + } + if err := second.Close(); err != nil { + t.Fatalf("close second: %v", err) + } +} diff --git a/pkg/node/main_test.go b/pkg/node/main_test.go index 44630fe15f9..d5444084fbd 100644 --- a/pkg/node/main_test.go +++ b/pkg/node/main_test.go @@ -23,5 +23,13 @@ func TestMain(m *testing.M) { goleak.IgnoreTopFunction("github.com/libp2p/go-cidranger/net.Network.LeastCommonBitPosition"), goleak.IgnoreTopFunction("go.opencensus.io/stats/view.(*worker).start"), goleak.IgnoreTopFunction("github.com/libp2p/go-cidranger.(*prefixTrie).insert"), + // goleveldb compaction goroutines settle asynchronously; the same + // ignores are used by pkg/api, pkg/postage/batchstore, pkg/pusher, + // pkg/puller and pkg/statestore/leveldb test packages. + goleak.IgnoreTopFunction("github.com/syndtr/goleveldb/leveldb.(*DB).mpoolDrain"), + goleak.IgnoreTopFunction("github.com/syndtr/goleveldb/leveldb.(*DB).compactionError"), + goleak.IgnoreTopFunction("github.com/syndtr/goleveldb/leveldb.(*DB).tCompaction"), + goleak.IgnoreTopFunction("github.com/syndtr/goleveldb/leveldb.(*DB).mCompaction"), + goleak.IgnoreTopFunction("github.com/syndtr/goleveldb/leveldb.(*session).refLoop"), ) } diff --git a/pkg/node/node.go b/pkg/node/node.go index 53d96ec5419..64309355e4c 100644 --- a/pkg/node/node.go +++ b/pkg/node/node.go @@ -59,7 +59,6 @@ import ( "github.com/ethersphere/bee/v2/pkg/settlement/swap" "github.com/ethersphere/bee/v2/pkg/settlement/swap/chequebook" "github.com/ethersphere/bee/v2/pkg/settlement/swap/erc20" - "github.com/ethersphere/bee/v2/pkg/settlement/swap/priceoracle" "github.com/ethersphere/bee/v2/pkg/stabilization" "github.com/ethersphere/bee/v2/pkg/status" "github.com/ethersphere/bee/v2/pkg/steward" @@ -225,6 +224,10 @@ func NewBee( session accesscontrol.Session, o *Options, ) (b *Bee, err error) { + if err := validateOptions(o); err != nil { + return nil, err + } + // start time for node warmup duration measurement warmupStartTime := time.Now() var pullSyncStartTime time.Time @@ -240,14 +243,6 @@ func NewBee( return nil, fmt.Errorf("tracer: %w", err) } - if err := validatePublicAddress(o.NATAddr); err != nil { - return nil, fmt.Errorf("invalid NAT address %s: %w", o.NATAddr, err) - } - - if err := validatePublicAddress(o.NATWSSAddr); err != nil { - return nil, fmt.Errorf("invalid NAT WSS address %s: %w", o.NATWSSAddr, err) - } - ctx, ctxCancel := context.WithCancel(ctx) defer func() { // if there's been an error on this function @@ -286,13 +281,6 @@ func NewBee( } }(b) - if !o.FullNodeMode && o.ReserveCapacityDoubling != 0 { - return nil, fmt.Errorf("reserve capacity doubling is only allowed for full nodes") - } - - if o.ReserveCapacityDoubling < 0 || o.ReserveCapacityDoubling > maxAllowedDoubling { - return nil, fmt.Errorf("config reserve capacity doubling has to be between default: 0 and maximum: %d", maxAllowedDoubling) - } shallowReceiptTolerance := maxAllowedDoubling - o.ReserveCapacityDoubling reserveCapacity := (1 << o.ReserveCapacityDoubling) * storer.DefaultReserveCapacity @@ -435,6 +423,10 @@ func NewBee( logger.Info("using chain with network", "chain_id", chainID, "network_id", networkID) + if err := validateChainContractOptions(o, chainID); err != nil { + return nil, err + } + b.ethClientCloser = chainBackend.Close b.transactionCloser = tracerCloser b.transactionMonitorCloser = transactionMonitor @@ -535,47 +527,16 @@ func NewBee( } } - if o.SwapEnable { - chequebookFactory, err := InitChequebookFactory(logger, chainBackend, chainID, transactionService, o.SwapFactoryAddress) - if err != nil { - return nil, fmt.Errorf("init chequebook factory: %w", err) - } - - erc20Address, err := chequebookFactory.ERC20Address(ctx) - if err != nil { - return nil, fmt.Errorf("factory fail: %w", err) - } - - erc20Service = erc20.New(transactionService, erc20Address) - - if o.ChequebookEnable && chainEnabled { - chequebookService, err = InitChequebookService( - ctx, - logger, - stateStore, - signer, - chainID, - chainBackend, - overlayEthAddress, - transactionService, - chequebookFactory, - o.SwapInitialDeposit, - erc20Service, - ) - if err != nil { - return nil, fmt.Errorf("init chequebook service: %w", err) - } - } - - chequeStore, cashoutService = initChequeStoreCashout( - stateStore, - chainBackend, - chequebookFactory, - chainID, - overlayEthAddress, - transactionService, - ) + swapRes, err := setupSwap(ctx, logger, o, chainEnabled, chainBackend, chainID, transactionService, stateStore, signer, overlayEthAddress, defaultSwapDeps) + if err != nil { + return nil, err + } + erc20Service = swapRes.Erc20Service + if swapRes.ChequebookService != nil { + chequebookService = swapRes.ChequebookService } + chequeStore = swapRes.ChequeStore + cashoutService = swapRes.CashoutService lightNodes := lightnode.NewContainer(swarmAddress) @@ -592,28 +553,10 @@ func NewBee( bootnodes = append(bootnodes, addr) } - // Perform checks related to payment threshold calculations here to not duplicate - // the checks in bootstrap process - paymentThreshold, ok := new(big.Int).SetString(o.PaymentThreshold, 10) - if !ok { - return nil, fmt.Errorf("invalid payment threshold: %s", paymentThreshold) - } - - if paymentThreshold.Cmp(big.NewInt(minPaymentThreshold)) < 0 { - return nil, fmt.Errorf("payment threshold below minimum generally accepted value, need at least %d", minPaymentThreshold) - } - - if paymentThreshold.Cmp(big.NewInt(maxPaymentThreshold)) > 0 { - return nil, fmt.Errorf("payment threshold above maximum generally accepted value, needs to be reduced to at most %d", maxPaymentThreshold) - } - - if o.PaymentTolerance < 0 { - return nil, fmt.Errorf("invalid payment tolerance: %d", o.PaymentTolerance) - } - - if o.PaymentEarly > 100 || o.PaymentEarly < 0 { - return nil, fmt.Errorf("invalid payment early: %d", o.PaymentEarly) - } + // PaymentThreshold range, PaymentTolerance and PaymentEarly bounds were + // already verified by validateOptions at NewBee entry; re-parsing here is + // just to obtain the bigint value for downstream use and cannot fail. + paymentThreshold, _ := new(big.Int).SetString(o.PaymentThreshold, 10) detector, err := stabilization.NewDetector(stabilization.Config{ PeriodDuration: 2 * time.Second, @@ -692,27 +635,15 @@ func NewBee( eventListener postage.Listener ) - chainCfg, found := config.GetByChainID(chainID) - postageStampContractAddress, postageSyncStart := chainCfg.PostageStampAddress, chainCfg.PostageStampStartBlock - if o.PostageContractAddress != "" { - if !common.IsHexAddress(o.PostageContractAddress) { - return nil, errors.New("malformed postage stamp address") - } - postageStampContractAddress = common.HexToAddress(o.PostageContractAddress) - if o.PostageContractStartBlock == 0 { - return nil, errors.New("postage contract start block option not provided") - } - postageSyncStart = o.PostageContractStartBlock - } else if !found { - return nil, errors.New("no known postage stamp addresses for this network") - } - - postageStampContractABI := abiutil.MustParseABI(chainCfg.PostageStampABI) - - bzzTokenAddress, err := postagecontract.LookupERC20Address(ctx, transactionService, postageStampContractAddress, postageStampContractABI, chainEnabled) + postageRes, err := setupPostageContract(ctx, o, chainID, chainEnabled, transactionService, defaultPostageContractDeps) if err != nil { - return nil, fmt.Errorf("lookup erc20 postage address: %w", err) + return nil, err } + chainCfg := postageRes.ChainConfig + postageStampContractAddress := postageRes.ContractAddress + postageStampContractABI := postageRes.ContractABI + postageSyncStart := postageRes.SyncStartBlock + bzzTokenAddress := postageRes.BzzTokenAddress // Compute gas limit for contract transactions: when TrxDebugMode is enabled, // gas estimation is skipped and DefaultGasLimit is used for all contract calls. @@ -997,29 +928,21 @@ func NewBee( acc.SetRefreshFunc(pseudosettleService.Pay) - if o.SwapEnable && chainEnabled { - var priceOracle priceoracle.Service - swapService, priceOracle, err = InitSwap( - p2ps, - logger, - stateStore, - networkID, - overlayEthAddress, - chequebookService, - chequeStore, - cashoutService, - acc, - o.PriceOracleAddress, - chainID, - transactionService, - ) - if err != nil { - return nil, fmt.Errorf("init swap service: %w", err) - } - b.priceOracleCloser = priceOracle - - if o.ChequebookEnable { - acc.SetPayFunc(swapService.Pay) + swapSvcRes, err := setupSwapService( + o, chainEnabled, + p2ps, logger, stateStore, networkID, overlayEthAddress, + chequebookService, chequeStore, cashoutService, acc, + chainID, transactionService, + defaultSwapServiceDeps, + ) + if err != nil { + return nil, err + } + if swapSvcRes.SwapService != nil { + swapService = swapSvcRes.SwapService + b.priceOracleCloser = swapSvcRes.PriceOracle + if swapSvcRes.PayFunc != nil { + acc.SetPayFunc(swapSvcRes.PayFunc) } } @@ -1144,11 +1067,9 @@ func NewBee( apiService.SetIsWarmingUp(false) }() + // Staking address malformed-hex was validated by validateChainContractOptions. stakingContractAddress := chainCfg.StakingAddress if o.StakingContractAddress != "" { - if !common.IsHexAddress(o.StakingContractAddress) { - return nil, errors.New("malformed staking contract address") - } stakingContractAddress = common.HexToAddress(o.StakingContractAddress) } @@ -1236,11 +1157,9 @@ func NewBee( if o.EnableStorageIncentives { + // Redistribution address malformed-hex was validated by validateChainContractOptions. redistributionContractAddress := chainCfg.RedistributionAddress if o.RedistributionContractAddress != "" { - if !common.IsHexAddress(o.RedistributionContractAddress) { - return nil, errors.New("malformed redistribution contract address") - } redistributionContractAddress = common.HexToAddress(o.RedistributionContractAddress) } @@ -1532,6 +1451,84 @@ func isChainEnabled(o *Options, swapEndpoint string, logger log.Logger) bool { return true // all other modes operate require chain enabled } +// validateOptions checks Options for invalid values that can be detected +// without performing any I/O and without needing chainID. It is the single +// place where these config-shape errors are caught at NewBee entry. +func validateOptions(o *Options) error { + if err := validatePublicAddress(o.NATAddr); err != nil { + return fmt.Errorf("invalid NAT address %s: %w", o.NATAddr, err) + } + if err := validatePublicAddress(o.NATWSSAddr); err != nil { + return fmt.Errorf("invalid NAT WSS address %s: %w", o.NATWSSAddr, err) + } + if !o.FullNodeMode && o.ReserveCapacityDoubling != 0 { + return fmt.Errorf("reserve capacity doubling is only allowed for full nodes") + } + if o.ReserveCapacityDoubling < 0 || o.ReserveCapacityDoubling > maxAllowedDoubling { + return fmt.Errorf("config reserve capacity doubling has to be between default: 0 and maximum: %d", maxAllowedDoubling) + } + if _, err := parsePaymentThreshold(o.PaymentThreshold); err != nil { + return err + } + if o.PaymentTolerance < 0 { + return fmt.Errorf("invalid payment tolerance: %d", o.PaymentTolerance) + } + if o.PaymentEarly > 100 || o.PaymentEarly < 0 { + return fmt.Errorf("invalid payment early: %d", o.PaymentEarly) + } + // The neighborhood may also be supplied at runtime by the suggester URL; + // here we only validate what the user typed into config. + if o.TargetNeighborhood != "" { + if _, err := swarm.ParseBitStrAddress(o.TargetNeighborhood); err != nil { + return fmt.Errorf("invalid neighborhood. %s", o.TargetNeighborhood) + } + } + return nil +} + +// parsePaymentThreshold parses PaymentThreshold and verifies it sits in the +// accepted [minPaymentThreshold, maxPaymentThreshold] range. The parsed value +// is returned so callers that need the bigint don't re-parse. +func parsePaymentThreshold(s string) (*big.Int, error) { + pt, ok := new(big.Int).SetString(s, 10) + if !ok { + return nil, fmt.Errorf("invalid payment threshold: %s", pt) + } + if pt.Cmp(big.NewInt(minPaymentThreshold)) < 0 { + return nil, fmt.Errorf("payment threshold below minimum generally accepted value, need at least %d", minPaymentThreshold) + } + if pt.Cmp(big.NewInt(maxPaymentThreshold)) > 0 { + return nil, fmt.Errorf("payment threshold above maximum generally accepted value, needs to be reduced to at most %d", maxPaymentThreshold) + } + return pt, nil +} + +// validateChainContractOptions runs after chainID is known. It catches +// malformed user-supplied contract addresses, the postage "no known address +// for this network" case, and the missing-start-block case for a custom +// postage contract. The redistribution check is only performed when +// EnableStorageIncentives is set, preserving prior behavior. +func validateChainContractOptions(o *Options, chainID int64) error { + _, found := config.GetByChainID(chainID) + if o.PostageContractAddress != "" { + if !common.IsHexAddress(o.PostageContractAddress) { + return errors.New("malformed postage stamp address") + } + if o.PostageContractStartBlock == 0 { + return errors.New("postage contract start block option not provided") + } + } else if !found { + return errors.New("no known postage stamp addresses for this network") + } + if o.StakingContractAddress != "" && !common.IsHexAddress(o.StakingContractAddress) { + return errors.New("malformed staking contract address") + } + if o.EnableStorageIncentives && o.RedistributionContractAddress != "" && !common.IsHexAddress(o.RedistributionContractAddress) { + return errors.New("malformed redistribution contract address") + } + return nil +} + func validatePublicAddress(addr string) error { if addr == "" { return nil diff --git a/pkg/node/node_test.go b/pkg/node/node_test.go index 5b5ca7ad967..acd29bab05e 100644 --- a/pkg/node/node_test.go +++ b/pkg/node/node_test.go @@ -5,9 +5,15 @@ package node_test import ( + "errors" + "strings" "testing" + "github.com/ethersphere/bee/v2/pkg/config" + "github.com/ethersphere/bee/v2/pkg/log" "github.com/ethersphere/bee/v2/pkg/node" + statestoremock "github.com/ethersphere/bee/v2/pkg/statestore/mock" + "github.com/ethersphere/bee/v2/pkg/storage" ) func TestValidatePublicAddress(t *testing.T) { @@ -107,3 +113,376 @@ func TestValidatePublicAddress(t *testing.T) { }) } } + +// validBaseOptions returns an Options value that passes validateOptions. Tests +// mutate one field at a time to assert exactly which check fires. +func validBaseOptions() node.Options { + return node.Options{ + FullNodeMode: true, + PaymentThreshold: "10000000", // sits inside [minPaymentThreshold, maxPaymentThreshold] + PaymentTolerance: 0, + PaymentEarly: 0, + } +} + +func TestValidateOptions(t *testing.T) { + t.Parallel() + + // maxAllowedDoubling is an unexported constant in pkg/node; mirror its value + // here so a deliberate change to the limit forces a deliberate test update. + const maxAllowedDoubling = 1 + + mutate := func(f func(*node.Options)) node.Options { + o := validBaseOptions() + f(&o) + return o + } + + testCases := []struct { + name string + opts node.Options + wantErr string + }{ + { + name: "all valid full node", + opts: validBaseOptions(), + }, + { + name: "all valid light node", + opts: mutate(func(o *node.Options) { o.FullNodeMode = false }), + }, + { + name: "invalid NAT address", + opts: mutate(func(o *node.Options) { o.NATAddr = "localhost:1635" }), + wantErr: "invalid NAT address", + }, + { + name: "invalid NAT WSS address", + opts: mutate(func(o *node.Options) { o.NATWSSAddr = "127.0.0.1:1635" }), + wantErr: "invalid NAT WSS address", + }, + { + name: "light node with non-zero doubling", + opts: mutate(func(o *node.Options) { + o.FullNodeMode = false + o.ReserveCapacityDoubling = 1 + }), + wantErr: "reserve capacity doubling is only allowed for full nodes", + }, + { + name: "doubling above max", + opts: mutate(func(o *node.Options) { o.ReserveCapacityDoubling = maxAllowedDoubling + 1 }), + wantErr: "reserve capacity doubling has to be between", + }, + { + name: "doubling negative", + opts: mutate(func(o *node.Options) { o.ReserveCapacityDoubling = -1 }), + wantErr: "reserve capacity doubling has to be between", + }, + { + name: "payment threshold empty", + opts: mutate(func(o *node.Options) { o.PaymentThreshold = "" }), + wantErr: "invalid payment threshold", + }, + { + name: "payment threshold non-numeric", + opts: mutate(func(o *node.Options) { o.PaymentThreshold = "abc" }), + wantErr: "invalid payment threshold", + }, + { + name: "payment threshold below minimum", + opts: mutate(func(o *node.Options) { o.PaymentThreshold = "1" }), + wantErr: "payment threshold below minimum", + }, + { + name: "payment threshold above maximum", + opts: mutate(func(o *node.Options) { o.PaymentThreshold = "999999999999" }), + wantErr: "payment threshold above maximum", + }, + { + name: "payment tolerance negative", + opts: mutate(func(o *node.Options) { o.PaymentTolerance = -1 }), + wantErr: "invalid payment tolerance", + }, + { + name: "payment early negative", + opts: mutate(func(o *node.Options) { o.PaymentEarly = -1 }), + wantErr: "invalid payment early", + }, + { + name: "payment early above 100", + opts: mutate(func(o *node.Options) { o.PaymentEarly = 101 }), + wantErr: "invalid payment early", + }, + { + name: "payment early at boundary 100", + opts: mutate(func(o *node.Options) { o.PaymentEarly = 100 }), + }, + { + name: "target neighborhood invalid bitstring", + opts: mutate(func(o *node.Options) { o.TargetNeighborhood = "10X01" }), + wantErr: "invalid neighborhood", + }, + { + name: "target neighborhood valid bitstring", + opts: mutate(func(o *node.Options) { o.TargetNeighborhood = "101010" }), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + err := node.ValidateOptions(&tc.opts) + if tc.wantErr == "" { + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + return + } + if err == nil { + t.Fatalf("expected error containing %q, got nil", tc.wantErr) + } + if !strings.Contains(err.Error(), tc.wantErr) { + t.Fatalf("expected error containing %q, got %q", tc.wantErr, err.Error()) + } + }) + } +} + +func TestParsePaymentThreshold(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + input string + wantErr string + want string // string form of the expected bigint, only checked on success + }{ + {name: "empty parses to zero, fails below minimum", input: "", wantErr: "invalid payment threshold"}, + {name: "non-numeric", input: "nope", wantErr: "invalid payment threshold"}, + {name: "below minimum", input: "1", wantErr: "below minimum"}, + {name: "above maximum", input: "999999999999", wantErr: "above maximum"}, + {name: "at minimum", input: "9000000", want: "9000000"}, + {name: "at maximum", input: "108000000", want: "108000000"}, + {name: "inside range", input: "10000000", want: "10000000"}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + got, err := node.ParsePaymentThreshold(tc.input) + if tc.wantErr != "" { + if err == nil { + t.Fatalf("expected error containing %q, got nil", tc.wantErr) + } + if !strings.Contains(err.Error(), tc.wantErr) { + t.Fatalf("expected error containing %q, got %q", tc.wantErr, err.Error()) + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got.String() != tc.want { + t.Fatalf("parsed %q -> %s, want %s", tc.input, got.String(), tc.want) + } + }) + } +} + +func TestValidateChainContractOptions(t *testing.T) { + t.Parallel() + + const ( + hexAddr = "0x1234567890123456789012345678901234567890" + bogus = "not-a-hex-address" + ) + // 99999999 is not testnet or mainnet, so config.GetByChainID returns + // found=false for it. + const unknownChain int64 = 99999999 + + testCases := []struct { + name string + opts node.Options + chainID int64 + wantErr string + }{ + { + name: "known chain, no custom postage address is OK", + opts: node.Options{}, + chainID: config.Mainnet.ChainID, + }, + { + name: "unknown chain without custom postage address", + opts: node.Options{}, + chainID: unknownChain, + wantErr: "no known postage stamp addresses for this network", + }, + { + name: "custom postage address malformed", + opts: node.Options{PostageContractAddress: bogus, PostageContractStartBlock: 1}, + chainID: unknownChain, + wantErr: "malformed postage stamp address", + }, + { + name: "custom postage address without start block", + opts: node.Options{PostageContractAddress: hexAddr, PostageContractStartBlock: 0}, + chainID: unknownChain, + wantErr: "postage contract start block option not provided", + }, + { + name: "custom postage address with start block on unknown chain is OK", + opts: node.Options{PostageContractAddress: hexAddr, PostageContractStartBlock: 1}, + chainID: unknownChain, + }, + { + name: "staking address malformed", + opts: node.Options{StakingContractAddress: bogus}, + chainID: config.Mainnet.ChainID, + wantErr: "malformed staking contract address", + }, + { + name: "staking address valid", + opts: node.Options{StakingContractAddress: hexAddr}, + chainID: config.Mainnet.ChainID, + }, + { + name: "redistribution address malformed but incentives off — silently accepted", + opts: node.Options{RedistributionContractAddress: bogus, EnableStorageIncentives: false}, + chainID: config.Mainnet.ChainID, + }, + { + name: "redistribution address malformed with incentives on", + opts: node.Options{RedistributionContractAddress: bogus, EnableStorageIncentives: true}, + chainID: config.Mainnet.ChainID, + wantErr: "malformed redistribution contract address", + }, + { + name: "redistribution address valid with incentives on", + opts: node.Options{RedistributionContractAddress: hexAddr, EnableStorageIncentives: true}, + chainID: config.Mainnet.ChainID, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + err := node.ValidateChainContractOptions(&tc.opts, tc.chainID) + if tc.wantErr == "" { + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + return + } + if err == nil { + t.Fatalf("expected error containing %q, got nil", tc.wantErr) + } + if !strings.Contains(err.Error(), tc.wantErr) { + t.Fatalf("expected error containing %q, got %q", tc.wantErr, err.Error()) + } + }) + } +} + +func TestIsChainEnabled(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + fullNode bool + swapEndpoint string + want bool + }{ + {name: "light no endpoint disables chain", fullNode: false, swapEndpoint: "", want: false}, + {name: "light with endpoint enables chain", fullNode: false, swapEndpoint: "http://rpc", want: true}, + {name: "full no endpoint enables chain", fullNode: true, swapEndpoint: "", want: true}, + {name: "full with endpoint enables chain", fullNode: true, swapEndpoint: "http://rpc", want: true}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + got := node.IsChainEnabled(&node.Options{FullNodeMode: tc.fullNode}, tc.swapEndpoint, log.Noop) + if got != tc.want { + t.Fatalf("isChainEnabled(fullNode=%v, swapEndpoint=%q) = %v, want %v", + tc.fullNode, tc.swapEndpoint, got, tc.want) + } + }) + } +} + +func TestBatchStoreExists(t *testing.T) { + t.Parallel() + + t.Run("empty store", func(t *testing.T) { + t.Parallel() + s := statestoremock.NewStateStore() + got, err := node.BatchStoreExists(s) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got { + t.Fatal("expected false on an empty store") + } + }) + + t.Run("unrelated key only", func(t *testing.T) { + t.Parallel() + s := statestoremock.NewStateStore() + if err := s.Put("not_batchstore_key", []byte("v")); err != nil { + t.Fatalf("put: %v", err) + } + got, err := node.BatchStoreExists(s) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got { + t.Fatal("expected false when no batchstore_ prefixed key is present") + } + }) + + t.Run("batchstore key present", func(t *testing.T) { + t.Parallel() + s := statestoremock.NewStateStore() + if err := s.Put("batchstore_foo", []byte("v")); err != nil { + t.Fatalf("put: %v", err) + } + got, err := node.BatchStoreExists(s) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !got { + t.Fatal("expected true when a batchstore_ prefixed key is present") + } + }) + + t.Run("iterate error surfaces", func(t *testing.T) { + t.Parallel() + want := errors.New("boom") + got, err := node.BatchStoreExists(&iterErrStore{err: want}) + if !errors.Is(err, want) { + t.Fatalf("expected %v, got %v", want, err) + } + if got { + t.Fatal("expected false when iterate returns an error") + } + }) +} + +// iterErrStore is a StateStorer that returns a fixed error from Iterate. +// batchStoreExists only ever calls Iterate, so the other methods are stubbed. +type iterErrStore struct { + err error +} + +var _ storage.StateStorer = (*iterErrStore)(nil) + +func (s *iterErrStore) Iterate(_ string, _ storage.StateIterFunc) error { return s.err } +func (s *iterErrStore) Get(string, any) error { return nil } +func (s *iterErrStore) Put(string, any) error { return nil } +func (s *iterErrStore) Delete(string) error { return nil } +func (s *iterErrStore) Close() error { return nil } diff --git a/pkg/node/setuppostagecontract_test.go b/pkg/node/setuppostagecontract_test.go new file mode 100644 index 00000000000..6ae8cdb171e --- /dev/null +++ b/pkg/node/setuppostagecontract_test.go @@ -0,0 +1,168 @@ +// Copyright 2025 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package node_test + +import ( + "context" + "errors" + "strings" + "testing" + + "github.com/ethereum/go-ethereum/accounts/abi" + "github.com/ethereum/go-ethereum/common" + "github.com/ethersphere/bee/v2/pkg/config" + "github.com/ethersphere/bee/v2/pkg/node" + "github.com/ethersphere/bee/v2/pkg/transaction" +) + +// These tests cover setupPostageContract: it resolves the postage stamp +// contract address (default-from-chain-config vs custom override), parses the +// ABI, and calls LookupERC20Address. Validation of malformed addresses / +// missing start blocks already lives in validateChainContractOptions, so the +// happy paths and the LookupERC20 error are the interesting cases here. + +func mockPostageDeps(t *testing.T, want common.Address, wantErr error, lookupCalls *int, observed *common.Address) node.PostageContractDeps { + t.Helper() + return node.PostageContractDeps{ + LookupERC20: func(_ context.Context, _ transaction.Service, postageStampContractAddress common.Address, _ abi.ABI, _ bool) (common.Address, error) { + *lookupCalls++ + if observed != nil { + *observed = postageStampContractAddress + } + return want, wantErr + }, + } +} + +func TestSetupPostageContract_DefaultChainConfig_NoOverride(t *testing.T) { + t.Parallel() + + wantBzz := common.HexToAddress("0xbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb") + var calls int + var lookedUpAddr common.Address + + res, err := node.SetupPostageContract( + context.Background(), + &node.Options{}, // no custom postage address + config.Mainnet.ChainID, + true, + nil, // transactionService — fake never uses it + mockPostageDeps(t, wantBzz, nil, &calls, &lookedUpAddr), + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if calls != 1 { + t.Fatalf("LookupERC20 should run exactly once, got %d", calls) + } + if res.BzzTokenAddress != wantBzz { + t.Fatalf("BzzTokenAddress: got %s, want %s", res.BzzTokenAddress, wantBzz) + } + // Without a custom override, the address fed to LookupERC20 must be the + // chain's default postage stamp address. + if lookedUpAddr != config.Mainnet.PostageStampAddress { + t.Fatalf("looked up addr: got %s, want %s (chain default)", lookedUpAddr, config.Mainnet.PostageStampAddress) + } + if res.ContractAddress != config.Mainnet.PostageStampAddress { + t.Fatalf("ContractAddress: got %s, want %s", res.ContractAddress, config.Mainnet.PostageStampAddress) + } + if res.SyncStartBlock != config.Mainnet.PostageStampStartBlock { + t.Fatalf("SyncStartBlock: got %d, want %d", res.SyncStartBlock, config.Mainnet.PostageStampStartBlock) + } + if res.ChainConfig.ChainID != config.Mainnet.ChainID { + t.Fatalf("ChainConfig.ChainID: got %d, want %d", res.ChainConfig.ChainID, config.Mainnet.ChainID) + } +} + +func TestSetupPostageContract_CustomOverride_UsedInLookup(t *testing.T) { + t.Parallel() + + customAddr := "0x1234567890123456789012345678901234567890" + customStart := uint64(42) + wantBzz := common.HexToAddress("0xcccccccccccccccccccccccccccccccccccccccc") + var calls int + var lookedUpAddr common.Address + + res, err := node.SetupPostageContract( + context.Background(), + &node.Options{ + PostageContractAddress: customAddr, + PostageContractStartBlock: customStart, + }, + config.Mainnet.ChainID, + true, + nil, + mockPostageDeps(t, wantBzz, nil, &calls, &lookedUpAddr), + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if calls != 1 { + t.Fatalf("LookupERC20 should run exactly once, got %d", calls) + } + wantCustom := common.HexToAddress(customAddr) + if lookedUpAddr != wantCustom { + t.Fatalf("LookupERC20 must receive the custom address, got %s, want %s", lookedUpAddr, wantCustom) + } + if res.ContractAddress != wantCustom { + t.Fatalf("ContractAddress: got %s, want %s", res.ContractAddress, wantCustom) + } + if res.SyncStartBlock != customStart { + t.Fatalf("SyncStartBlock: got %d, want %d", res.SyncStartBlock, customStart) + } +} + +func TestSetupPostageContract_LookupError_IsWrapped(t *testing.T) { + t.Parallel() + + want := errors.New("rpc say no") + var calls int + res, err := node.SetupPostageContract( + context.Background(), + &node.Options{}, + config.Mainnet.ChainID, + true, + nil, + mockPostageDeps(t, common.Address{}, want, &calls, nil), + ) + if !errors.Is(err, want) { + t.Fatalf("expected wrapped %v, got %v", want, err) + } + if !strings.Contains(err.Error(), "lookup erc20 postage address") { + t.Fatalf("expected error to start with %q, got %q", "lookup erc20 postage address", err.Error()) + } + // On error the result is a zero-value; nothing should leak through. + if (res.BzzTokenAddress != common.Address{}) || (res.ContractAddress != common.Address{}) { + t.Fatalf("expected zero result on error, got %+v", res) + } +} + +func TestSetupPostageContract_ChainEnabledFlag_PropagatesToLookup(t *testing.T) { + t.Parallel() + + var lastChainEnabled bool + deps := node.PostageContractDeps{ + LookupERC20: func(_ context.Context, _ transaction.Service, _ common.Address, _ abi.ABI, chainEnabled bool) (common.Address, error) { + lastChainEnabled = chainEnabled + return common.Address{}, nil + }, + } + + for _, want := range []bool{true, false} { + if _, err := node.SetupPostageContract( + context.Background(), + &node.Options{}, + config.Mainnet.ChainID, + want, + nil, + deps, + ); err != nil { + t.Fatalf("chainEnabled=%v: unexpected error: %v", want, err) + } + if lastChainEnabled != want { + t.Fatalf("chainEnabled flag must be forwarded to LookupERC20: got %v, want %v", lastChainEnabled, want) + } + } +} diff --git a/pkg/node/setupswap_test.go b/pkg/node/setupswap_test.go new file mode 100644 index 00000000000..b9cba5e6c33 --- /dev/null +++ b/pkg/node/setupswap_test.go @@ -0,0 +1,354 @@ +// Copyright 2025 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package node_test + +import ( + "context" + "errors" + "math/big" + "strings" + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethersphere/bee/v2/pkg/crypto" + "github.com/ethersphere/bee/v2/pkg/log" + "github.com/ethersphere/bee/v2/pkg/node" + "github.com/ethersphere/bee/v2/pkg/settlement/swap/chequebook" + "github.com/ethersphere/bee/v2/pkg/settlement/swap/erc20" + "github.com/ethersphere/bee/v2/pkg/storage" + "github.com/ethersphere/bee/v2/pkg/transaction" +) + +// These tests pin down behavior of the SwapEnable / ChequebookEnable / +// chainEnabled block that used to be inlined in NewBee. They give future +// refactors of that block — e.g. issue #5233, which changes the SwapEnable=false +// path — a regression net for every behavior that is *not* supposed to change. + +// recordedSwapDeps wraps a node.SwapDeps so each test controls success / failure +// of every injected call and observes which ones ran. +type recordedSwapDeps struct { + factory chequebook.Factory + factoryErr error + factoryCalls int + + svc chequebook.Service + svcErr error + svcCalls int + + store chequebook.ChequeStore + cashout chequebook.CashoutService + cashoutCalls int +} + +func (r *recordedSwapDeps) deps() node.SwapDeps { + return node.SwapDeps{ + InitFactory: func(_ log.Logger, _ transaction.Backend, _ int64, _ transaction.Service, _ string) (chequebook.Factory, error) { + r.factoryCalls++ + if r.factoryErr != nil { + return nil, r.factoryErr + } + return r.factory, nil + }, + InitChequebookService: func(_ context.Context, _ log.Logger, _ storage.StateStorer, _ crypto.Signer, _ int64, _ transaction.Backend, _ common.Address, _ transaction.Service, _ chequebook.Factory, _ string, _ erc20.Service) (chequebook.Service, error) { + r.svcCalls++ + if r.svcErr != nil { + return nil, r.svcErr + } + return r.svc, nil + }, + InitChequeStoreCashout: func(_ storage.StateStorer, _ transaction.Backend, _ chequebook.Factory, _ int64, _ common.Address, _ transaction.Service) (chequebook.ChequeStore, chequebook.CashoutService) { + r.cashoutCalls++ + return r.store, r.cashout + }, + } +} + +// callSetupSwap is a convenience wrapper that injects a real-looking transaction +// service (nil here, since the fakes never touch it) and runs setupSwap. +func callSetupSwap(t *testing.T, o *node.Options, chainEnabled bool, deps node.SwapDeps) (node.SwapResult, error) { + t.Helper() + return node.SetupSwap( + context.Background(), + log.Noop, + o, + chainEnabled, + nil, // chainBackend — fakes never invoke it + 0, // chainID + nil, // transactionService — fakes never invoke it + nil, // stateStore + nil, // signer + common.Address{}, + deps, + ) +} + +func TestSetupSwap_DisabledMakesNoCallsAndReturnsZero(t *testing.T) { + t.Parallel() + + r := &recordedSwapDeps{factory: &swapFactoryStub{}} + res, err := callSetupSwap(t, + &node.Options{SwapEnable: false}, + true, // chainEnabled is irrelevant when SwapEnable=false + r.deps(), + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if res.Erc20Service != nil || res.ChequebookService != nil || res.ChequeStore != nil || res.CashoutService != nil { + t.Fatalf("expected zero swapResult, got %+v", res) + } + if r.factoryCalls+r.svcCalls+r.cashoutCalls != 0 { + t.Fatalf("expected no deps to be invoked, got factory=%d svc=%d cashout=%d", + r.factoryCalls, r.svcCalls, r.cashoutCalls) + } +} + +func TestSetupSwap_FactoryErrorReturnsWrapped(t *testing.T) { + t.Parallel() + + want := errors.New("boom") + r := &recordedSwapDeps{factoryErr: want} + res, err := callSetupSwap(t, + &node.Options{SwapEnable: true, ChequebookEnable: true}, + true, + r.deps(), + ) + if !errors.Is(err, want) { + t.Fatalf("expected wrapped %v, got %v", want, err) + } + if !strings.Contains(err.Error(), "init chequebook factory") { + t.Fatalf("expected error to start with %q, got %q", "init chequebook factory", err.Error()) + } + if res.Erc20Service != nil || res.ChequeStore != nil || res.CashoutService != nil { + t.Fatal("no downstream output should be set when factory init fails") + } + if r.svcCalls != 0 || r.cashoutCalls != 0 { + t.Fatal("downstream deps must not be invoked after factory error") + } +} + +func TestSetupSwap_ERC20AddressErrorReturnsWrapped(t *testing.T) { + t.Parallel() + + want := errors.New("rpc dead") + r := &recordedSwapDeps{factory: &swapFactoryStub{erc20Err: want}} + res, err := callSetupSwap(t, + &node.Options{SwapEnable: true, ChequebookEnable: true}, + true, + r.deps(), + ) + if !errors.Is(err, want) { + t.Fatalf("expected wrapped %v, got %v", want, err) + } + if !strings.Contains(err.Error(), "factory fail") { + t.Fatalf("expected error to start with %q, got %q", "factory fail", err.Error()) + } + if r.svcCalls != 0 || r.cashoutCalls != 0 { + t.Fatal("downstream deps must not be invoked after ERC20 error") + } + if res.Erc20Service != nil { + t.Fatal("erc20Service must not be set when ERC20Address fails") + } +} + +func TestSetupSwap_ChequebookDisabled_StillSetsErc20AndCashout(t *testing.T) { + t.Parallel() + + wantStore := &chequeStoreStub{} + wantCashout := &cashoutStub{} + r := &recordedSwapDeps{factory: &swapFactoryStub{}, store: wantStore, cashout: wantCashout} + res, err := callSetupSwap(t, + &node.Options{SwapEnable: true, ChequebookEnable: false}, + true, + r.deps(), + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if r.svcCalls != 0 { + t.Fatalf("chequebook service must not be initialized when ChequebookEnable=false, got %d calls", r.svcCalls) + } + if res.ChequebookService != nil { + t.Fatal("chequebookService must remain nil when ChequebookEnable=false") + } + if res.Erc20Service == nil { + t.Fatal("erc20Service must be set when SwapEnable=true") + } + if res.ChequeStore != wantStore || res.CashoutService != wantCashout { + t.Fatal("chequeStore / cashoutService must be set when SwapEnable=true") + } +} + +func TestSetupSwap_ChainDisabled_SkipsChequebookServiceOnly(t *testing.T) { + t.Parallel() + + wantStore := &chequeStoreStub{} + wantCashout := &cashoutStub{} + r := &recordedSwapDeps{factory: &swapFactoryStub{}, store: wantStore, cashout: wantCashout} + res, err := callSetupSwap(t, + &node.Options{SwapEnable: true, ChequebookEnable: true}, + false, // chainEnabled=false gates the chequebook service even when ChequebookEnable=true + r.deps(), + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if r.svcCalls != 0 { + t.Fatalf("chequebook service must not be initialized when chainEnabled=false, got %d calls", r.svcCalls) + } + if res.ChequebookService != nil { + t.Fatal("chequebookService must remain nil when chainEnabled=false") + } + if res.Erc20Service == nil || res.ChequeStore != wantStore || res.CashoutService != wantCashout { + t.Fatal("erc20Service / chequeStore / cashoutService must still be set") + } +} + +func TestSetupSwap_ChequebookServiceError_ReturnsWrapped(t *testing.T) { + t.Parallel() + + want := errors.New("svc no") + r := &recordedSwapDeps{factory: &swapFactoryStub{}, svcErr: want} + _, err := callSetupSwap(t, + &node.Options{SwapEnable: true, ChequebookEnable: true}, + true, + r.deps(), + ) + if !errors.Is(err, want) { + t.Fatalf("expected wrapped %v, got %v", want, err) + } + if !strings.Contains(err.Error(), "init chequebook service") { + t.Fatalf("expected error to contain %q, got %q", "init chequebook service", err.Error()) + } + // The original block returned before initChequeStoreCashout when chequebook + // service init failed; verify that contract is preserved. + if r.cashoutCalls != 0 { + t.Fatal("initChequeStoreCashout must not run after chequebook service error") + } +} + +func TestSetupSwap_AllEnabledSuccess_SetsEveryOutput(t *testing.T) { + t.Parallel() + + wantSvc := &chequebookSvcStub{} + wantStore := &chequeStoreStub{} + wantCashout := &cashoutStub{} + r := &recordedSwapDeps{ + factory: &swapFactoryStub{}, + svc: wantSvc, + store: wantStore, + cashout: wantCashout, + } + res, err := callSetupSwap(t, + &node.Options{SwapEnable: true, ChequebookEnable: true}, + true, + r.deps(), + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if res.Erc20Service == nil { + t.Fatal("erc20Service not set") + } + if res.ChequebookService != wantSvc { + t.Fatal("chequebookService not set to the value returned by deps.InitChequebookService") + } + if res.ChequeStore != wantStore || res.CashoutService != wantCashout { + t.Fatal("chequeStore / cashoutService not set to the values returned by deps.InitChequeStoreCashout") + } + if r.factoryCalls != 1 || r.svcCalls != 1 || r.cashoutCalls != 1 { + t.Fatalf("each dep should be called exactly once, got factory=%d svc=%d cashout=%d", + r.factoryCalls, r.svcCalls, r.cashoutCalls) + } +} + +// ---- minimal stubs for chequebook interfaces used in these tests ---- + +type swapFactoryStub struct{ erc20Err error } + +func (f *swapFactoryStub) ERC20Address(_ context.Context) (common.Address, error) { + if f.erc20Err != nil { + return common.Address{}, f.erc20Err + } + return common.Address{}, nil +} + +func (f *swapFactoryStub) Deploy(_ context.Context, _ common.Address, _ *big.Int, _ common.Hash) (common.Hash, error) { + panic("Deploy must not be called by setupSwap") +} + +func (f *swapFactoryStub) WaitDeployed(_ context.Context, _ common.Hash) (common.Address, error) { + panic("WaitDeployed must not be called by setupSwap") +} + +func (f *swapFactoryStub) VerifyChequebook(_ context.Context, _ common.Address) error { + panic("VerifyChequebook must not be called by setupSwap") +} + +func (f *swapFactoryStub) VerifyBytecode(_ context.Context) error { + panic("VerifyBytecode must not be called by setupSwap") +} + +type chequebookSvcStub struct{} + +func (chequebookSvcStub) Deposit(context.Context, *big.Int) (common.Hash, error) { + panic("Deposit must not be called") +} + +func (chequebookSvcStub) Withdraw(context.Context, *big.Int) (common.Hash, error) { + panic("Withdraw must not be called") +} + +func (chequebookSvcStub) WaitForDeposit(context.Context, common.Hash) error { + panic("WaitForDeposit must not be called") +} + +func (chequebookSvcStub) Balance(context.Context) (*big.Int, error) { + panic("Balance must not be called") +} + +func (chequebookSvcStub) AvailableBalance(context.Context) (*big.Int, error) { + panic("AvailableBalance must not be called") +} + +func (chequebookSvcStub) Address() common.Address { + panic("Address must not be called") +} + +func (chequebookSvcStub) Issue(context.Context, common.Address, *big.Int, chequebook.SendChequeFunc) (*big.Int, error) { + panic("Issue must not be called") +} + +func (chequebookSvcStub) LastCheque(common.Address) (*chequebook.SignedCheque, error) { + panic("LastCheque must not be called") +} + +func (chequebookSvcStub) LastCheques() (map[common.Address]*chequebook.SignedCheque, error) { + panic("LastCheques must not be called") +} + +type chequeStoreStub struct{} + +func (chequeStoreStub) ReceiveCheque(context.Context, *chequebook.SignedCheque, *big.Int, *big.Int) (*big.Int, error) { + panic("ReceiveCheque must not be called") +} + +func (chequeStoreStub) LastCheque(common.Address) (*chequebook.SignedCheque, error) { + panic("LastCheque must not be called") +} + +func (chequeStoreStub) LastCheques() (map[common.Address]*chequebook.SignedCheque, error) { + panic("LastCheques must not be called") +} + +type cashoutStub struct{} + +func (cashoutStub) CashCheque(context.Context, common.Address, common.Address) (common.Hash, error) { + panic("CashCheque must not be called") +} + +func (cashoutStub) CashoutStatus(context.Context, common.Address) (*chequebook.CashoutStatus, error) { + panic("CashoutStatus must not be called") +} diff --git a/pkg/node/setupswapservice_test.go b/pkg/node/setupswapservice_test.go new file mode 100644 index 00000000000..73a1f2effde --- /dev/null +++ b/pkg/node/setupswapservice_test.go @@ -0,0 +1,170 @@ +// Copyright 2025 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package node_test + +import ( + "context" + "errors" + "math/big" + "strings" + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethersphere/bee/v2/pkg/log" + "github.com/ethersphere/bee/v2/pkg/node" + "github.com/ethersphere/bee/v2/pkg/p2p/libp2p" + "github.com/ethersphere/bee/v2/pkg/settlement" + "github.com/ethersphere/bee/v2/pkg/settlement/swap" + "github.com/ethersphere/bee/v2/pkg/settlement/swap/chequebook" + "github.com/ethersphere/bee/v2/pkg/settlement/swap/priceoracle" + "github.com/ethersphere/bee/v2/pkg/storage" + "github.com/ethersphere/bee/v2/pkg/transaction" +) + +// These tests cover setupSwapService — the second swap block in NewBee, which +// is gated on (SwapEnable && chainEnabled) and inside that gate wires up the +// accounting PayFunc only when ChequebookEnable is true. + +type recordedSwapServiceDeps struct { + calls int + resultSwap *swap.Service + resultPrice priceoracle.Service + resultErr error +} + +func (r *recordedSwapServiceDeps) deps() node.SwapServiceDeps { + return node.SwapServiceDeps{ + InitSwap: func(_ *libp2p.Service, _ log.Logger, _ storage.StateStorer, _ uint64, _ common.Address, _ chequebook.Service, _ chequebook.ChequeStore, _ chequebook.CashoutService, _ settlement.Accounting, _ string, _ int64, _ transaction.Service) (*swap.Service, priceoracle.Service, error) { + r.calls++ + if r.resultErr != nil { + return nil, nil, r.resultErr + } + return r.resultSwap, r.resultPrice, nil + }, + } +} + +// realSwapService builds an actual *swap.Service so that swapService.Pay +// resolves to a callable method value. Most fields are nil — the tests never +// invoke Pay, they just check that PayFunc was assigned when expected. +func realSwapService(t *testing.T) *swap.Service { + t.Helper() + return swap.New(nil, log.Noop, nil, nil, nil, nil, 0, nil, nil, common.Address{}) +} + +func callSetupSwapService(t *testing.T, o *node.Options, chainEnabled bool, deps node.SwapServiceDeps) (node.SwapServiceResult, error) { + t.Helper() + return node.SetupSwapService( + o, chainEnabled, + nil, log.Noop, nil, 0, common.Address{}, + nil, nil, nil, nil, + 0, nil, + deps, + ) +} + +func TestSetupSwapService_SwapDisabled_MakesNoCalls(t *testing.T) { + t.Parallel() + + r := &recordedSwapServiceDeps{} + res, err := callSetupSwapService(t, &node.Options{SwapEnable: false, ChequebookEnable: true}, true, r.deps()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if r.calls != 0 { + t.Fatalf("InitSwap must not run when SwapEnable=false, got %d calls", r.calls) + } + if res.SwapService != nil || res.PriceOracle != nil || res.PayFunc != nil { + t.Fatalf("expected zero result, got %+v", res) + } +} + +func TestSetupSwapService_ChainDisabled_MakesNoCalls(t *testing.T) { + t.Parallel() + + r := &recordedSwapServiceDeps{} + res, err := callSetupSwapService(t, &node.Options{SwapEnable: true, ChequebookEnable: true}, false, r.deps()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if r.calls != 0 { + t.Fatalf("InitSwap must not run when chainEnabled=false, got %d calls", r.calls) + } + if res.SwapService != nil || res.PriceOracle != nil || res.PayFunc != nil { + t.Fatalf("expected zero result, got %+v", res) + } +} + +func TestSetupSwapService_InitSwapError_Wraps(t *testing.T) { + t.Parallel() + + want := errors.New("init swap exploded") + r := &recordedSwapServiceDeps{resultErr: want} + res, err := callSetupSwapService(t, &node.Options{SwapEnable: true, ChequebookEnable: true}, true, r.deps()) + if !errors.Is(err, want) { + t.Fatalf("expected wrapped %v, got %v", want, err) + } + if !strings.Contains(err.Error(), "init swap service") { + t.Fatalf("expected error to start with %q, got %q", "init swap service", err.Error()) + } + if res.SwapService != nil || res.PriceOracle != nil || res.PayFunc != nil { + t.Fatalf("expected zero result on error, got %+v", res) + } +} + +func TestSetupSwapService_Enabled_ChequebookDisabled_NoPayFunc(t *testing.T) { + t.Parallel() + + wantSwap := realSwapService(t) + wantPrice := &priceOracleStub{} + r := &recordedSwapServiceDeps{resultSwap: wantSwap, resultPrice: wantPrice} + + res, err := callSetupSwapService(t, &node.Options{SwapEnable: true, ChequebookEnable: false}, true, r.deps()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if r.calls != 1 { + t.Fatalf("InitSwap should run exactly once, got %d", r.calls) + } + if res.SwapService != wantSwap { + t.Fatal("SwapService not propagated") + } + if res.PriceOracle != wantPrice { + t.Fatal("PriceOracle not propagated") + } + if res.PayFunc != nil { + t.Fatal("PayFunc must be nil when ChequebookEnable=false") + } +} + +func TestSetupSwapService_Enabled_ChequebookEnabled_SetsPayFunc(t *testing.T) { + t.Parallel() + + wantSwap := realSwapService(t) + wantPrice := &priceOracleStub{} + r := &recordedSwapServiceDeps{resultSwap: wantSwap, resultPrice: wantPrice} + + res, err := callSetupSwapService(t, &node.Options{SwapEnable: true, ChequebookEnable: true}, true, r.deps()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if res.SwapService != wantSwap || res.PriceOracle != wantPrice { + t.Fatal("Swap/PriceOracle not propagated") + } + if res.PayFunc == nil { + t.Fatal("PayFunc must be set when SwapEnable && chainEnabled && ChequebookEnable") + } +} + +// priceOracleStub is a no-op priceoracle.Service so we have an identity-comparable +// value to thread through the result struct. +type priceOracleStub struct{} + +func (priceOracleStub) Start() {} +func (priceOracleStub) Close() error { return nil } +func (priceOracleStub) GetPrice(context.Context) (*big.Int, *big.Int, error) { + return nil, nil, nil +} +func (priceOracleStub) CurrentRates() (*big.Int, *big.Int, error) { return nil, nil, nil } diff --git a/pkg/node/shutdown_test.go b/pkg/node/shutdown_test.go new file mode 100644 index 00000000000..9ba84e91437 --- /dev/null +++ b/pkg/node/shutdown_test.go @@ -0,0 +1,195 @@ +// Copyright 2025 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package node_test + +import ( + "errors" + "strings" + "sync" + "sync/atomic" + "testing" + + "github.com/ethersphere/bee/v2/pkg/log" + "github.com/ethersphere/bee/v2/pkg/node" +) + +// trackingCloser records whether Close was called and, optionally, returns a +// fixed error. It is safe for concurrent use because Shutdown closes several +// components in parallel. +type trackingCloser struct { + closed atomic.Bool + err error +} + +func (t *trackingCloser) Close() error { + t.closed.Store(true) + return t.err +} + +func TestShutdown_NilCloserFieldsAreSafe(t *testing.T) { + t.Parallel() + + b, ctx := node.NewBeeForShutdownTest(log.Noop, node.ShutdownTestClosers{}) + + if err := b.Shutdown(); err != nil { + t.Fatalf("Shutdown on a Bee with no closers returned %v, want nil", err) + } + select { + case <-ctx.Done(): + default: + t.Fatal("ctxCancel was not invoked by Shutdown") + } +} + +func TestShutdown_CallsEveryNonNilCloser(t *testing.T) { + t.Parallel() + + // Cover representative closers from each of Shutdown's phases: + // parallel first phase, sequential phase, parallel chain phase, and + // the final sequential tail. + closers := map[string]*trackingCloser{ + "api": {}, + "pss": {}, + "gsoc": {}, + "pusher": {}, + "puller": {}, + "accounting": {}, + "pullSync": {}, + "hive": {}, + "salud": {}, + "p2p": {}, + "priceOracle": {}, + "transactionMonitor": {}, + "transaction": {}, + "listener": {}, + "postageService": {}, + "accesscontrol": {}, + "tracer": {}, + "topology": {}, + "storageIncentives": {}, + "stabilization": {}, + "localstore": {}, + "stateStore": {}, + "stamperStore": {}, + "resolver": {}, + } + + var ethClientCalled atomic.Bool + b, _ := node.NewBeeForShutdownTest(log.Noop, node.ShutdownTestClosers{ + API: closers["api"], + PSS: closers["pss"], + GSOC: closers["gsoc"], + Pusher: closers["pusher"], + Puller: closers["puller"], + Accounting: closers["accounting"], + PullSync: closers["pullSync"], + Hive: closers["hive"], + Salud: closers["salud"], + P2P: closers["p2p"], + PriceOracle: closers["priceOracle"], + TransactionMonitor: closers["transactionMonitor"], + Transaction: closers["transaction"], + Listener: closers["listener"], + PostageService: closers["postageService"], + AccessControl: closers["accesscontrol"], + Tracer: closers["tracer"], + Topology: closers["topology"], + StorageIncentives: closers["storageIncentives"], + Stabilization: closers["stabilization"], + Localstore: closers["localstore"], + StateStore: closers["stateStore"], + StamperStore: closers["stamperStore"], + Resolver: closers["resolver"], + EthClient: func() { ethClientCalled.Store(true) }, + }) + + if err := b.Shutdown(); err != nil { + t.Fatalf("Shutdown returned unexpected error: %v", err) + } + + for name, c := range closers { + if !c.closed.Load() { + t.Errorf("closer %q was not invoked", name) + } + } + if !ethClientCalled.Load() { + t.Error("ethClientCloser was not invoked") + } +} + +func TestShutdown_AggregatesErrors(t *testing.T) { + t.Parallel() + + apiErr := errors.New("api boom") + stateErr := errors.New("statestore boom") + tracerErr := errors.New("tracer boom") + + b, _ := node.NewBeeForShutdownTest(log.Noop, node.ShutdownTestClosers{ + API: &trackingCloser{err: apiErr}, + StateStore: &trackingCloser{err: stateErr}, + Tracer: &trackingCloser{err: tracerErr}, + }) + + err := b.Shutdown() + if err == nil { + t.Fatal("expected aggregated error, got nil") + } + msg := err.Error() + for _, want := range []string{"api", "statestore", "tracer"} { + if !strings.Contains(msg, want) { + t.Errorf("expected aggregated error to mention %q, got %q", want, msg) + } + } +} + +func TestShutdown_SecondCallReturnsErrShutdownInProgress(t *testing.T) { + t.Parallel() + + b, _ := node.NewBeeForShutdownTest(log.Noop, node.ShutdownTestClosers{}) + + if err := b.Shutdown(); err != nil { + t.Fatalf("first Shutdown returned %v, want nil", err) + } + if err := b.Shutdown(); !errors.Is(err, node.ErrShutdownInProgress) { + t.Fatalf("second Shutdown returned %v, want ErrShutdownInProgress", err) + } +} + +func TestShutdown_ConcurrentCallsExactlyOneRuns(t *testing.T) { + t.Parallel() + + // All shutdown work is no-op (no closers set), so the only way to tell + // who got the "first" slot is by the returned error: exactly one caller + // should see nil and the rest should see ErrShutdownInProgress. + b, _ := node.NewBeeForShutdownTest(log.Noop, node.ShutdownTestClosers{}) + + const callers = 8 + var ( + wg sync.WaitGroup + nilCount atomic.Int32 + errCount atomic.Int32 + ) + wg.Add(callers) + for i := 0; i < callers; i++ { + go func() { + defer wg.Done() + if err := b.Shutdown(); err == nil { + nilCount.Add(1) + } else if errors.Is(err, node.ErrShutdownInProgress) { + errCount.Add(1) + } else { + t.Errorf("unexpected error: %v", err) + } + }() + } + wg.Wait() + + if got := nilCount.Load(); got != 1 { + t.Fatalf("expected exactly 1 caller to run shutdown, got %d", got) + } + if got := errCount.Load(); got != callers-1 { + t.Fatalf("expected %d callers to get ErrShutdownInProgress, got %d", callers-1, got) + } +} diff --git a/pkg/node/statestore_helpers_test.go b/pkg/node/statestore_helpers_test.go new file mode 100644 index 00000000000..8140836d008 --- /dev/null +++ b/pkg/node/statestore_helpers_test.go @@ -0,0 +1,145 @@ +// Copyright 2025 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package node_test + +import ( + "strings" + "testing" + + "github.com/ethersphere/bee/v2/pkg/node" + statestoremock "github.com/ethersphere/bee/v2/pkg/statestore/mock" + "github.com/ethersphere/bee/v2/pkg/swarm" +) + +// These tests cover the small persistence helpers in pkg/node/statestore.go. +// They are pure logic over a StateStorer and therefore safe to unit-test with +// the mock store; the actual leveldb backing is exercised by InitStateStore / +// InitStamperStore tests in their own files. + +func TestOverlayNonceExists(t *testing.T) { + t.Parallel() + + t.Run("not present returns false and a zero-filled nonce", func(t *testing.T) { + t.Parallel() + s := statestoremock.NewStateStore() + nonce, exists, err := node.OverlayNonceExists(s) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if exists { + t.Fatal("expected exists=false on a fresh store") + } + if len(nonce) != 32 { + t.Fatalf("expected a 32-byte nonce buffer, got %d", len(nonce)) + } + }) + + t.Run("present returns the stored bytes and true", func(t *testing.T) { + t.Parallel() + s := statestoremock.NewStateStore() + want := make([]byte, 32) + for i := range want { + want[i] = byte(i + 1) + } + if err := s.Put(node.OverlayNonceKey, want); err != nil { + t.Fatalf("put nonce: %v", err) + } + nonce, exists, err := node.OverlayNonceExists(s) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !exists { + t.Fatal("expected exists=true when nonce is stored") + } + if string(nonce) != string(want) { + t.Fatalf("nonce mismatch: got % x, want % x", nonce, want) + } + }) +} + +func TestSetOverlay_RoundtripsBothKeys(t *testing.T) { + t.Parallel() + + s := statestoremock.NewStateStore() + overlay := swarm.RandAddress(t) + nonce := make([]byte, 32) + for i := range nonce { + nonce[i] = byte(i) + } + + if err := node.SetOverlay(s, overlay, nonce); err != nil { + t.Fatalf("SetOverlay: %v", err) + } + + var gotNonce []byte + if err := s.Get(node.OverlayNonceKey, &gotNonce); err != nil { + t.Fatalf("Get nonce: %v", err) + } + if string(gotNonce) != string(nonce) { + t.Fatalf("stored nonce mismatch: got % x, want % x", gotNonce, nonce) + } + + var gotOverlay swarm.Address + if err := s.Get(node.NoncedOverlayKey, &gotOverlay); err != nil { + t.Fatalf("Get overlay: %v", err) + } + if !gotOverlay.Equal(overlay) { + t.Fatalf("stored overlay mismatch: got %s, want %s", gotOverlay, overlay) + } +} + +func TestCheckOverlay(t *testing.T) { + t.Parallel() + + t.Run("empty store writes the overlay and returns nil", func(t *testing.T) { + t.Parallel() + s := statestoremock.NewStateStore() + overlay := swarm.RandAddress(t) + + if err := node.CheckOverlay(s, overlay); err != nil { + t.Fatalf("first call expected nil, got %v", err) + } + + var stored swarm.Address + if err := s.Get(node.NoncedOverlayKey, &stored); err != nil { + t.Fatalf("expected overlay to have been written: %v", err) + } + if !stored.Equal(overlay) { + t.Fatalf("stored overlay mismatch: got %s, want %s", stored, overlay) + } + }) + + t.Run("matching stored overlay returns nil", func(t *testing.T) { + t.Parallel() + s := statestoremock.NewStateStore() + overlay := swarm.RandAddress(t) + if err := s.Put(node.NoncedOverlayKey, overlay); err != nil { + t.Fatalf("put overlay: %v", err) + } + if err := node.CheckOverlay(s, overlay); err != nil { + t.Fatalf("expected nil when stored overlay matches, got %v", err) + } + }) + + t.Run("differing stored overlay returns error", func(t *testing.T) { + t.Parallel() + s := statestoremock.NewStateStore() + stored := swarm.RandAddress(t) + incoming := swarm.RandAddress(t) + if stored.Equal(incoming) { + t.Skip("random addresses collided; rerun") + } + if err := s.Put(node.NoncedOverlayKey, stored); err != nil { + t.Fatalf("put overlay: %v", err) + } + err := node.CheckOverlay(s, incoming) + if err == nil { + t.Fatal("expected error when stored overlay differs") + } + if !strings.Contains(err.Error(), "overlay address changed") { + t.Fatalf("expected error to mention overlay change, got %q", err.Error()) + } + }) +}