diff --git a/pkg/execution/mimicry/config.go b/pkg/execution/mimicry/config.go new file mode 100644 index 0000000..23a6e17 --- /dev/null +++ b/pkg/execution/mimicry/config.go @@ -0,0 +1,84 @@ +package mimicry + +import ( + "bytes" + "crypto/ecdsa" + "time" +) + +// StatusProvider is a callback function that returns a Status response +// for an incoming connection. It receives the remote Hello for context. +type StatusProvider func(remoteHello *Hello) (Status, error) + +// Network defines filter criteria for allowed networks. +// All non-nil/non-zero fields must match for a peer to be accepted. +// If all fields are nil/zero, matches any network. +type Network struct { + NetworkID *uint64 // nil = don't filter on network ID + ForkIDHash []byte // nil = don't filter on fork ID hash + ForkIDNext *uint64 // nil = don't filter on fork ID next + Genesis []byte // nil = don't filter on genesis hash +} + +// Matches returns true if the given Status matches this Network filter. +func (n *Network) Matches(status Status) bool { + if n.NetworkID != nil && *n.NetworkID != status.GetNetworkID() { + return false + } + + if n.ForkIDHash != nil && !bytes.Equal(n.ForkIDHash, status.GetForkIDHash()) { + return false + } + + if n.ForkIDNext != nil && *n.ForkIDNext != status.GetForkIDNext() { + return false + } + + if n.Genesis != nil && !bytes.Equal(n.Genesis, status.GetGenesis()) { + return false + } + + return true +} + +// ServerConfig configures the Server for accepting incoming connections. +type ServerConfig struct { + // Name is the name to advertise in Hello messages. + Name string + + // PrivateKey is the ECDSA private key for the server. + // If nil, a new key will be generated. + PrivateKey *ecdsa.PrivateKey + + // ListenAddr is the address to listen on (e.g., ":30303"). + ListenAddr string + + // MaxPeers is the maximum number of concurrent connections. + // 0 means unlimited. + MaxPeers int + + // HandshakeTimeout is the timeout for RLPx handshake. + // Default: 5s + HandshakeTimeout time.Duration + + // ReadTimeout is the timeout for reading messages. + // Default: 30s + ReadTimeout time.Duration + + // StatusProvider is a required callback that returns the Status + // to send in response to incoming peers. + StatusProvider StatusProvider + + // AllowedNetworks filters incoming peers by network. + // If empty/nil, all networks are allowed. + // If set, peer must match at least ONE of the networks. + AllowedNetworks []Network +} + +// DefaultServerConfig returns a ServerConfig with sensible defaults. +func DefaultServerConfig() *ServerConfig { + return &ServerConfig{ + HandshakeTimeout: 5 * time.Second, + ReadTimeout: 30 * time.Second, + } +} diff --git a/pkg/execution/mimicry/config_test.go b/pkg/execution/mimicry/config_test.go new file mode 100644 index 0000000..a509deb --- /dev/null +++ b/pkg/execution/mimicry/config_test.go @@ -0,0 +1,193 @@ +package mimicry + +import ( + "testing" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockStatus implements the Status interface for testing. +type mockStatus struct { + networkID uint64 + forkIDHash []byte + forkIDNext uint64 + genesis []byte + head []byte +} + +func (m *mockStatus) Code() int { return StatusCode } +func (m *mockStatus) ReqID() uint64 { return 0 } +func (m *mockStatus) GetNetworkID() uint64 { return m.networkID } +func (m *mockStatus) GetForkIDHash() []byte { return m.forkIDHash } +func (m *mockStatus) GetForkIDNext() uint64 { return m.forkIDNext } +func (m *mockStatus) GetGenesis() []byte { return m.genesis } +func (m *mockStatus) GetHead() []byte { return m.head } + +func ptr[T any](v T) *T { + return &v +} + +func TestNetworkMatches(t *testing.T) { + mainnetGenesis := common.HexToHash("0xd4e56740f876aef8c010b86a40d5f56745a118d0906a34e69aec8c0db1cb8fa3") + sepoliaGenesis := common.HexToHash("0x25a5cc106eea7138acab33231d7160d69cb777ee0c2c553fcddf5138993e6dd9") + dencunForkHash := []byte{0x9f, 0x3d, 0x22, 0x54} + + tests := []struct { + name string + network Network + status Status + want bool + }{ + { + name: "empty filter matches any network", + network: Network{}, + status: &mockStatus{ + networkID: 1, + genesis: mainnetGenesis[:], + forkIDHash: dencunForkHash, + forkIDNext: 0, + }, + want: true, + }, + { + name: "network ID filter - match", + network: Network{ + NetworkID: ptr(uint64(1)), + }, + status: &mockStatus{ + networkID: 1, + genesis: mainnetGenesis[:], + }, + want: true, + }, + { + name: "network ID filter - no match", + network: Network{ + NetworkID: ptr(uint64(1)), + }, + status: &mockStatus{ + networkID: 11155111, // sepolia + genesis: sepoliaGenesis[:], + }, + want: false, + }, + { + name: "genesis filter - match", + network: Network{ + Genesis: mainnetGenesis[:], + }, + status: &mockStatus{ + networkID: 1, + genesis: mainnetGenesis[:], + }, + want: true, + }, + { + name: "genesis filter - no match", + network: Network{ + Genesis: mainnetGenesis[:], + }, + status: &mockStatus{ + networkID: 11155111, + genesis: sepoliaGenesis[:], + }, + want: false, + }, + { + name: "fork ID hash filter - match", + network: Network{ + ForkIDHash: dencunForkHash, + }, + status: &mockStatus{ + networkID: 1, + forkIDHash: dencunForkHash, + }, + want: true, + }, + { + name: "fork ID hash filter - no match", + network: Network{ + ForkIDHash: dencunForkHash, + }, + status: &mockStatus{ + networkID: 1, + forkIDHash: []byte{0x00, 0x00, 0x00, 0x00}, + }, + want: false, + }, + { + name: "fork ID next filter - match", + network: Network{ + ForkIDNext: ptr(uint64(1000)), + }, + status: &mockStatus{ + networkID: 1, + forkIDNext: 1000, + }, + want: true, + }, + { + name: "fork ID next filter - no match", + network: Network{ + ForkIDNext: ptr(uint64(1000)), + }, + status: &mockStatus{ + networkID: 1, + forkIDNext: 2000, + }, + want: false, + }, + { + name: "multiple filters - all match", + network: Network{ + NetworkID: ptr(uint64(1)), + Genesis: mainnetGenesis[:], + ForkIDHash: dencunForkHash, + }, + status: &mockStatus{ + networkID: 1, + genesis: mainnetGenesis[:], + forkIDHash: dencunForkHash, + }, + want: true, + }, + { + name: "multiple filters - one doesn't match", + network: Network{ + NetworkID: ptr(uint64(1)), + Genesis: mainnetGenesis[:], + ForkIDHash: dencunForkHash, + }, + status: &mockStatus{ + networkID: 1, + genesis: sepoliaGenesis[:], // doesn't match + forkIDHash: dencunForkHash, + }, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.network.Matches(tt.status) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestDefaultServerConfig(t *testing.T) { + config := DefaultServerConfig() + + require.NotNil(t, config) + assert.Equal(t, 5*time.Second, config.HandshakeTimeout) + assert.Equal(t, 30*time.Second, config.ReadTimeout) + assert.Empty(t, config.ListenAddr) + assert.Empty(t, config.Name) + assert.Nil(t, config.PrivateKey) + assert.Nil(t, config.StatusProvider) + assert.Empty(t, config.AllowedNetworks) + assert.Equal(t, 0, config.MaxPeers) +} diff --git a/pkg/execution/mimicry/message_disconnect.go b/pkg/execution/mimicry/message_disconnect.go index 4877a43..6de7da5 100644 --- a/pkg/execution/mimicry/message_disconnect.go +++ b/pkg/execution/mimicry/message_disconnect.go @@ -20,6 +20,25 @@ func (h *Disconnect) Code() int { return DisconnectCode } func (h *Disconnect) ReqID() uint64 { return 0 } +// decodeDisconnect decodes a Disconnect message from RLP-encoded data. +func decodeDisconnect(data []byte) (*Disconnect, error) { + d := new(p2p.DiscReason) + + if len(data) > 0 { + reason := data[0:1] + // besu sends 2 byte disconnect message + if len(data) > 1 { + reason = data[1:2] + } + + if err := rlp.DecodeBytes(reason, &d); err != nil { + return nil, err + } + } + + return &Disconnect{Reason: *d}, nil +} + func (c *Client) receiveDisconnect(ctx context.Context, data []byte) *Disconnect { d := new(p2p.DiscReason) diff --git a/pkg/execution/mimicry/message_hello.go b/pkg/execution/mimicry/message_hello.go index 5f94d47..29f3f2e 100644 --- a/pkg/execution/mimicry/message_hello.go +++ b/pkg/execution/mimicry/message_hello.go @@ -3,6 +3,7 @@ package mimicry import ( "context" + "crypto/ecdsa" "fmt" "github.com/ethereum/go-ethereum/crypto" @@ -92,7 +93,7 @@ func (h *Hello) ETHProtocolVersion() uint { } func SupportedEthCaps() []p2p.Cap { - caps := []p2p.Cap{} + caps := make([]p2p.Cap, 0, maxETHProtocolVersion-minETHProtocolVersion+1) for i := minETHProtocolVersion; i <= maxETHProtocolVersion; i++ { caps = append(caps, p2p.Cap{ Name: ETHCapName, @@ -103,12 +104,34 @@ func SupportedEthCaps() []p2p.Cap { return caps } -func (c *Client) receiveHello(ctx context.Context, data []byte) (*Hello, error) { +// decodeHello decodes a Hello message from RLP-encoded data. +func decodeHello(data []byte) (*Hello, error) { h := new(Hello) if err := rlp.DecodeBytes(data, &h); err != nil { return nil, fmt.Errorf("error decoding hello: %w", err) } + return h, nil +} + +// encodeHello encodes a Hello message to RLP bytes. +func encodeHello(privateKey *ecdsa.PrivateKey, caps []p2p.Cap) ([]byte, error) { + pub0 := crypto.FromECDSAPub(&privateKey.PublicKey)[1:] + hello := &Hello{ + Version: P2PProtocolVersion, + Caps: caps, + ID: pub0, + } + + return rlp.EncodeToBytes(hello) +} + +func (c *Client) receiveHello(ctx context.Context, data []byte) (*Hello, error) { + h, err := decodeHello(data) + if err != nil { + return nil, err + } + c.log.WithFields(logrus.Fields{ "version": h.Version, "caps": h.Caps, @@ -127,14 +150,7 @@ func (c *Client) sendHello(ctx context.Context) error { "code": HelloCode, }).Debug("sending Hello") - pub0 := crypto.FromECDSAPub(&c.privateKey.PublicKey)[1:] - hello := &Hello{ - Version: P2PProtocolVersion, - Caps: SupportedEthCaps(), - ID: pub0, - } - - encodedData, err := rlp.EncodeToBytes(hello) + encodedData, err := encodeHello(c.privateKey, SupportedEthCaps()) if err != nil { return fmt.Errorf("error encoding hello: %w", err) } diff --git a/pkg/execution/mimicry/message_new_pooled_transaction_hashes.go b/pkg/execution/mimicry/message_new_pooled_transaction_hashes.go index 130aa83..55e14fb 100644 --- a/pkg/execution/mimicry/message_new_pooled_transaction_hashes.go +++ b/pkg/execution/mimicry/message_new_pooled_transaction_hashes.go @@ -19,6 +19,16 @@ func (msg *NewPooledTransactionHashes) Code() int { return NewPooledTransactionH func (msg *NewPooledTransactionHashes) ReqID() uint64 { return 0 } +// decodeNewPooledTransactionHashes decodes a NewPooledTransactionHashes message from RLP-encoded data. +func decodeNewPooledTransactionHashes(data []byte) (*NewPooledTransactionHashes, error) { + s := new(NewPooledTransactionHashes) + if err := rlp.DecodeBytes(data, &s); err != nil { + return nil, fmt.Errorf("error decoding new pooled transaction hashes: %w", err) + } + + return s, nil +} + func (c *Client) receiveNewPooledTransactionHashes(ctx context.Context, data []byte) (*NewPooledTransactionHashes, error) { s := new(NewPooledTransactionHashes) if err := rlp.DecodeBytes(data, &s); err != nil { diff --git a/pkg/execution/mimicry/message_status.go b/pkg/execution/mimicry/message_status.go index 2879e47..d2e614c 100644 --- a/pkg/execution/mimicry/message_status.go +++ b/pkg/execution/mimicry/message_status.go @@ -61,8 +61,9 @@ func (msg *Status69) GetForkIDHash() []byte { return msg.ForkID.Hash[:] } func (msg *Status69) GetForkIDNext() uint64 { return msg.ForkID.Next } -func (c *Client) receiveStatus(ctx context.Context, data []byte) (Status, error) { - if c.ethCapVersion == 68 { +// decodeStatus decodes a Status message from RLP-encoded data. +func decodeStatus(data []byte, ethCapVersion uint) (Status, error) { + if ethCapVersion == 68 { s := new(Status68) if err := rlp.DecodeBytes(data, &s.StatusPacket68); err != nil { return nil, fmt.Errorf("error decoding status68: %w", err) @@ -80,6 +81,22 @@ func (c *Client) receiveStatus(ctx context.Context, data []byte) (Status, error) return s, nil } +// encodeStatus encodes a Status message to RLP bytes. +func encodeStatus(status Status) ([]byte, error) { + switch s := status.(type) { + case *Status68: + return rlp.EncodeToBytes(&s.StatusPacket68) + case *Status69: + return rlp.EncodeToBytes(&s.StatusPacket69) + default: + return nil, fmt.Errorf("unsupported status type: %T", status) + } +} + +func (c *Client) receiveStatus(ctx context.Context, data []byte) (Status, error) { + return decodeStatus(data, c.ethCapVersion) +} + func (c *Client) sendStatus(ctx context.Context, status Status) error { c.log.WithFields(logrus.Fields{ "code": StatusCode, @@ -87,19 +104,7 @@ func (c *Client) sendStatus(ctx context.Context, status Status) error { "ethCapVersion": c.ethCapVersion, }).Debug("sending Status") - var encodedData []byte - - var err error - - switch s := status.(type) { - case *Status68: - encodedData, err = rlp.EncodeToBytes(&s.StatusPacket68) - case *Status69: - encodedData, err = rlp.EncodeToBytes(&s.StatusPacket69) - default: - return fmt.Errorf("unsupported status type: %T", status) - } - + encodedData, err := encodeStatus(status) if err != nil { return fmt.Errorf("error encoding status: %w", err) } diff --git a/pkg/execution/mimicry/message_transactions.go b/pkg/execution/mimicry/message_transactions.go index 14e3204..b49779e 100644 --- a/pkg/execution/mimicry/message_transactions.go +++ b/pkg/execution/mimicry/message_transactions.go @@ -20,6 +20,16 @@ func (msg *Transactions) Code() int { return TransactionsCode } func (msg *Transactions) ReqID() uint64 { return 0 } +// decodeTransactions decodes a Transactions message from RLP-encoded data. +func decodeTransactions(data []byte) (*Transactions, error) { + s := new(Transactions) + if err := rlp.DecodeBytes(data, &s); err != nil { + return nil, fmt.Errorf("error decoding transactions: %w", err) + } + + return s, nil +} + func (c *Client) receiveTransactions(ctx context.Context, data []byte) (*Transactions, error) { s := new(Transactions) if err := rlp.DecodeBytes(data, &s); err != nil { diff --git a/pkg/execution/mimicry/publish.go b/pkg/execution/mimicry/publish.go index 4a029f0..cc60f65 100644 --- a/pkg/execution/mimicry/publish.go +++ b/pkg/execution/mimicry/publish.go @@ -4,6 +4,7 @@ import ( "context" ) +// Client event topics. const ( topicDisconnect = "disconnect" topicHello = "hello" @@ -12,6 +13,17 @@ const ( topicNewPooledTransactionHashes = "new_pooled_transaction_hashes" ) +// Server event topics. +const ( + topicServerPeerConnected = "server:peer:connected" + topicServerPeerDisconnected = "server:peer:disconnected" + topicServerDisconnect = "server:disconnect" + topicServerHello = "server:hello" + topicServerStatus = "server:status" + topicServerTransactions = "server:transactions" + topicServerNewPooledTransactionHashes = "server:new_pooled_transaction_hashes" +) + func (c *Client) publishDisconnect(ctx context.Context, reason *Disconnect) { c.broker.Emit(topicDisconnect, reason) } @@ -67,3 +79,86 @@ func (c *Client) OnNewPooledTransactionHashes(ctx context.Context, handler func( c.handleSubscriberError(handler(ctx, hashes), topicNewPooledTransactionHashes) }) } + +// Server publish methods + +func (s *Server) publishPeerConnected(peer *ServerPeer) { + s.broker.Emit(topicServerPeerConnected, peer) +} + +func (s *Server) publishPeerDisconnected(peer *ServerPeer) { + s.broker.Emit(topicServerPeerDisconnected, peer) +} + +func (s *Server) publishDisconnect(ctx context.Context, peer *ServerPeer, reason *Disconnect) { + s.broker.Emit(topicServerDisconnect, peer, reason) +} + +func (s *Server) publishHello(ctx context.Context, peer *ServerPeer, hello *Hello) { + s.broker.Emit(topicServerHello, peer, hello) +} + +func (s *Server) publishStatus(ctx context.Context, peer *ServerPeer, status Status) { + s.broker.Emit(topicServerStatus, peer, status) +} + +func (s *Server) publishTransactions(ctx context.Context, peer *ServerPeer, transactions *Transactions) { + s.broker.Emit(topicServerTransactions, peer, transactions) +} + +func (s *Server) publishNewPooledTransactionHashes(ctx context.Context, peer *ServerPeer, hashes *NewPooledTransactionHashes) { + s.broker.Emit(topicServerNewPooledTransactionHashes, peer, hashes) +} + +func (s *Server) handleSubscriberError(err error, topic string) { + if err != nil { + s.log.WithError(err).WithField("topic", topic).Error("Subscriber error") + } +} + +// Server subscription methods + +// OnPeerConnected registers a handler for when a peer connects. +func (s *Server) OnPeerConnected(handler func(peer *ServerPeer)) { + s.broker.On(topicServerPeerConnected, handler) +} + +// OnPeerDisconnected registers a handler for when a peer disconnects. +func (s *Server) OnPeerDisconnected(handler func(peer *ServerPeer)) { + s.broker.On(topicServerPeerDisconnected, handler) +} + +// OnDisconnect registers a handler for disconnect messages from peers. +func (s *Server) OnDisconnect(ctx context.Context, handler func(ctx context.Context, peer *ServerPeer, reason *Disconnect) error) { + s.broker.On(topicServerDisconnect, func(peer *ServerPeer, reason *Disconnect) { + s.handleSubscriberError(handler(ctx, peer, reason), topicServerDisconnect) + }) +} + +// OnHello registers a handler for hello messages from peers. +func (s *Server) OnHello(ctx context.Context, handler func(ctx context.Context, peer *ServerPeer, hello *Hello) error) { + s.broker.On(topicServerHello, func(peer *ServerPeer, hello *Hello) { + s.handleSubscriberError(handler(ctx, peer, hello), topicServerHello) + }) +} + +// OnStatus registers a handler for status messages from peers. +func (s *Server) OnStatus(ctx context.Context, handler func(ctx context.Context, peer *ServerPeer, status Status) error) { + s.broker.On(topicServerStatus, func(peer *ServerPeer, status Status) { + s.handleSubscriberError(handler(ctx, peer, status), topicServerStatus) + }) +} + +// OnTransactions registers a handler for transaction messages from peers. +func (s *Server) OnTransactions(ctx context.Context, handler func(ctx context.Context, peer *ServerPeer, transactions *Transactions) error) { + s.broker.On(topicServerTransactions, func(peer *ServerPeer, transactions *Transactions) { + s.handleSubscriberError(handler(ctx, peer, transactions), topicServerTransactions) + }) +} + +// OnNewPooledTransactionHashes registers a handler for new pooled transaction hash messages from peers. +func (s *Server) OnNewPooledTransactionHashes(ctx context.Context, handler func(ctx context.Context, peer *ServerPeer, hashes *NewPooledTransactionHashes) error) { + s.broker.On(topicServerNewPooledTransactionHashes, func(peer *ServerPeer, hashes *NewPooledTransactionHashes) { + s.handleSubscriberError(handler(ctx, peer, hashes), topicServerNewPooledTransactionHashes) + }) +} diff --git a/pkg/execution/mimicry/server.go b/pkg/execution/mimicry/server.go new file mode 100644 index 0000000..62e9ad6 --- /dev/null +++ b/pkg/execution/mimicry/server.go @@ -0,0 +1,623 @@ +package mimicry + +import ( + "context" + "crypto/ecdsa" + "encoding/hex" + "errors" + "fmt" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/chuckpreslar/emission" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/rlpx" + "github.com/sirupsen/logrus" +) + +// ServerPeer represents a single inbound RLPx connection. +type ServerPeer struct { + log logrus.FieldLogger + server *Server + + remoteID enode.ID + remotePubkey *ecdsa.PublicKey + + conn net.Conn + rlpxConn *rlpx.Conn + + ethCapVersion uint + remoteHello *Hello + + pooledTransactionsMap map[uint64]chan *PooledTransactions + + done chan struct{} + closed atomic.Bool +} + +// Server accepts incoming RLPx connections. +type Server struct { + log logrus.FieldLogger + config *ServerConfig + broker *emission.Emitter + + privateKey *ecdsa.PrivateKey + + listener net.Listener + + peers map[enode.ID]*ServerPeer + peersMu sync.RWMutex + + wg sync.WaitGroup + closed atomic.Bool + done chan struct{} +} + +// NewServer creates a new Server with the given configuration. +func NewServer(log logrus.FieldLogger, config *ServerConfig) (*Server, error) { + if config.StatusProvider == nil { + return nil, errors.New("StatusProvider is required") + } + + if config.ListenAddr == "" { + return nil, errors.New("ListenAddr is required") + } + + // Apply defaults + if config.HandshakeTimeout == 0 { + config.HandshakeTimeout = 5 * time.Second + } + + if config.ReadTimeout == 0 { + config.ReadTimeout = 30 * time.Second + } + + privateKey := config.PrivateKey + if privateKey == nil { + var err error + + privateKey, err = crypto.GenerateKey() + if err != nil { + return nil, fmt.Errorf("failed to generate private key: %w", err) + } + } + + return &Server{ + log: log.WithField("component", "execution_mimicry_server"), + config: config, + broker: emission.NewEmitter(), + privateKey: privateKey, + peers: make(map[enode.ID]*ServerPeer, 64), + done: make(chan struct{}), + }, nil +} + +// Start begins listening for incoming connections. +func (s *Server) Start(ctx context.Context) error { + s.log.WithField("addr", s.config.ListenAddr).Info("starting execution mimicry server") + + var err error + + s.listener, err = net.Listen("tcp", s.config.ListenAddr) + if err != nil { + return fmt.Errorf("failed to listen: %w", err) + } + + s.log.WithField("addr", s.listener.Addr().String()).Info("listening for connections") + + s.wg.Add(1) + + go s.acceptLoop(ctx) + + return nil +} + +// Stop stops the server and disconnects all peers. +func (s *Server) Stop(ctx context.Context) error { + if s.closed.Swap(true) { + return nil // Already closed + } + + s.log.Info("stopping execution mimicry server") + + close(s.done) + + if s.listener != nil { + if err := s.listener.Close(); err != nil { + s.log.WithError(err).Warn("error closing listener") + } + } + + // Close all peers + s.peersMu.Lock() + + for _, peer := range s.peers { + peer.Close() + } + + s.peersMu.Unlock() + + // Wait for goroutines with timeout + done := make(chan struct{}) + + go func() { + s.wg.Wait() + close(done) + }() + + select { + case <-done: + s.log.Debug("all goroutines stopped") + case <-ctx.Done(): + return ctx.Err() + case <-time.After(5 * time.Second): + s.log.Warn("shutdown timeout, some goroutines may still be running") + } + + return nil +} + +// acceptLoop accepts incoming connections. +func (s *Server) acceptLoop(ctx context.Context) { + defer s.wg.Done() + + for { + select { + case <-ctx.Done(): + return + case <-s.done: + return + default: + } + + // Check max peers + if s.config.MaxPeers > 0 { + s.peersMu.RLock() + peerCount := len(s.peers) + s.peersMu.RUnlock() + + if peerCount >= s.config.MaxPeers { + time.Sleep(100 * time.Millisecond) + + continue + } + } + + conn, err := s.listener.Accept() + if err != nil { + if s.closed.Load() { + return + } + + s.log.WithError(err).Warn("accept error") + + continue + } + + s.wg.Add(1) + + go s.handleConnection(ctx, conn) + } +} + +// handleConnection handles a single incoming connection. +func (s *Server) handleConnection(ctx context.Context, conn net.Conn) { + defer s.wg.Done() + + remoteAddr := conn.RemoteAddr().String() + log := s.log.WithField("remote_addr", remoteAddr) + + log.Debug("new incoming connection") + + peer := &ServerPeer{ + log: log, + server: s, + conn: conn, + pooledTransactionsMap: make(map[uint64]chan *PooledTransactions, 16), + done: make(chan struct{}), + } + + // Perform RLPx handshake (responder mode) + if err := peer.handshake(ctx, s.config.HandshakeTimeout); err != nil { + log.WithError(err).Debug("handshake failed") + conn.Close() + + return + } + + // Check if peer already exists + s.peersMu.Lock() + + if existing, ok := s.peers[peer.remoteID]; ok { + s.peersMu.Unlock() + log.WithField("peer_id", peer.remoteID.String()).Debug("peer already connected, closing new connection") + existing.Close() + peer.Close() + + return + } + + s.peers[peer.remoteID] = peer + s.peersMu.Unlock() + + // Emit peer connected event + s.publishPeerConnected(peer) + + // Run session + if err := peer.startSession(ctx); err != nil { + log.WithError(err).Debug("peer session ended") + } + + // Unregister peer + s.peersMu.Lock() + delete(s.peers, peer.remoteID) + s.peersMu.Unlock() + + // Emit peer disconnected event + s.publishPeerDisconnected(peer) +} + +// isNetworkAllowed checks if Status matches allowed networks. +func (s *Server) isNetworkAllowed(status Status) bool { + // No filter = allow all + if len(s.config.AllowedNetworks) == 0 { + return true + } + + // Must match at least one allowed network + for _, network := range s.config.AllowedNetworks { + if network.Matches(status) { + return true + } + } + + return false +} + +// Peers returns all connected peers. +func (s *Server) Peers() []*ServerPeer { + s.peersMu.RLock() + defer s.peersMu.RUnlock() + + peers := make([]*ServerPeer, 0, len(s.peers)) + for _, p := range s.peers { + peers = append(peers, p) + } + + return peers +} + +// PeerCount returns the number of connected peers. +func (s *Server) PeerCount() int { + s.peersMu.RLock() + defer s.peersMu.RUnlock() + + return len(s.peers) +} + +// Broker returns the event broker. +func (s *Server) Broker() *emission.Emitter { + return s.broker +} + +// ServerPeer methods + +// handshake performs the RLPx handshake in responder mode. +func (sp *ServerPeer) handshake(ctx context.Context, timeout time.Duration) error { + if err := sp.conn.SetDeadline(time.Now().Add(timeout)); err != nil { + return fmt.Errorf("error setting handshake deadline: %w", err) + } + + // nil pubkey = responder mode + sp.rlpxConn = rlpx.NewConn(sp.conn, nil) + + remotePubkey, err := sp.rlpxConn.Handshake(sp.server.privateKey) + if err != nil { + return fmt.Errorf("rlpx handshake failed: %w", err) + } + + sp.remotePubkey = remotePubkey + sp.remoteID = enode.PubkeyToIDV4(remotePubkey) + + // Clear deadline + if err := sp.conn.SetDeadline(time.Time{}); err != nil { + return fmt.Errorf("error clearing deadline: %w", err) + } + + sp.log = sp.log.WithField("peer_id", sp.remoteID.String()) + sp.log.Debug("rlpx handshake complete") + + return nil +} + +// startSession runs the peer session (Hello exchange + message loop). +func (sp *ServerPeer) startSession(ctx context.Context) error { + defer sp.Close() + + // Send Hello first + if err := sp.sendHello(ctx); err != nil { + return fmt.Errorf("error sending hello: %w", err) + } + + // Enter message loop + return sp.messageLoop(ctx) +} + +// messageLoop handles incoming messages. +func (sp *ServerPeer) messageLoop(ctx context.Context) error { + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-sp.done: + return nil + default: + } + + // Set read deadline + if err := sp.conn.SetReadDeadline(time.Now().Add(sp.server.config.ReadTimeout)); err != nil { + return fmt.Errorf("error setting read deadline: %w", err) + } + + code, data, _, err := sp.rlpxConn.Read() + if err != nil { + return fmt.Errorf("error reading: %w", err) + } + + if err := sp.handleMessage(ctx, code, data); err != nil { + return err + } + } +} + +// handleMessage routes messages to appropriate handlers. +func (sp *ServerPeer) handleMessage(ctx context.Context, code uint64, data []byte) error { + //nolint:gosec // not an overflow issue here. + switch int(code) { + case HelloCode: + return sp.handleHello(ctx, code, data) + case DisconnectCode: + sp.handleDisconnect(ctx, code, data) + + return nil + case PingCode: + return sp.handlePing(ctx, code, data) + case StatusCode: + return sp.handleStatus(ctx, code, data) + case TransactionsCode: + return sp.handleTransactions(ctx, code, data) + case GetBlockHeadersCode: + return sp.handleGetBlockHeaders(ctx, code, data) + case BlockHeadersCode: + return sp.handleBlockHeaders(ctx, code, data) + case GetBlockBodiesCode: + return sp.handleGetBlockBodies(ctx, code, data) + case NewPooledTransactionHashesCode: + return sp.handleNewPooledTransactionHashes(ctx, code, data) + case PooledTransactionsCode: + return sp.handlePooledTransactions(ctx, code, data) + case GetReceiptsCode: + return sp.handleGetReceipts(ctx, code, data) + case BlockRangeUpdateCode: + return sp.handleBlockRangeUpdate(ctx, code, data) + default: + sp.log.WithField("code", code).Debug("received unhandled message code") + } + + return nil +} + +// Close terminates the peer connection. +func (sp *ServerPeer) Close() error { + if sp.closed.Swap(true) { + return nil // Already closed + } + + close(sp.done) + + if sp.rlpxConn != nil { + if err := sp.rlpxConn.Close(); err != nil { + sp.log.WithError(err).Debug("error closing rlpx connection") + } + } + + return nil +} + +// RemoteID returns the remote node ID. +func (sp *ServerPeer) RemoteID() enode.ID { + return sp.remoteID +} + +// RemotePubkey returns the remote public key. +func (sp *ServerPeer) RemotePubkey() *ecdsa.PublicKey { + return sp.remotePubkey +} + +// Message handlers for ServerPeer + +func (sp *ServerPeer) sendHello(ctx context.Context) error { + sp.log.WithField("code", HelloCode).Debug("sending Hello") + + encodedData, err := encodeHello(sp.server.privateKey, SupportedEthCaps()) + if err != nil { + return fmt.Errorf("error encoding hello: %w", err) + } + + if _, err := sp.rlpxConn.Write(HelloCode, encodedData); err != nil { + return fmt.Errorf("error sending hello: %w", err) + } + + return nil +} + +func (sp *ServerPeer) handleHello(ctx context.Context, code uint64, data []byte) error { + sp.log.WithField("code", code).Debug("received Hello") + + hello, err := decodeHello(data) + if err != nil { + return err + } + + sp.log.WithFields(logrus.Fields{ + "version": hello.Version, + "caps": hello.Caps, + "name": hello.Name, + }).Debug("received hello message") + + if err := hello.Validate(); err != nil { + return err + } + + sp.ethCapVersion = hello.ETHProtocolVersion() + sp.remoteHello = hello + + sp.server.publishHello(ctx, sp, hello) + + // Enable snappy compression + sp.rlpxConn.SetSnappy(true) + + return nil +} + +func (sp *ServerPeer) handleStatus(ctx context.Context, code uint64, data []byte) error { + sp.log.WithFields(logrus.Fields{ + "code": code, + "ethCapVersion": sp.ethCapVersion, + }).Debug("received Status") + + status, err := decodeStatus(data, sp.ethCapVersion) + if err != nil { + return err + } + + // Check network filter BEFORE publishing/responding + if !sp.server.isNetworkAllowed(status) { + sp.log.WithFields(logrus.Fields{ + "network_id": status.GetNetworkID(), + "genesis": hex.EncodeToString(status.GetGenesis()[:8]), + }).Debug("peer rejected: network not allowed") + + return fmt.Errorf("network not allowed") + } + + sp.server.publishStatus(ctx, sp, status) + + // Use StatusProvider callback for response + responseStatus, err := sp.server.config.StatusProvider(sp.remoteHello) + if err != nil { + return fmt.Errorf("status provider error: %w", err) + } + + return sp.sendStatus(ctx, responseStatus) +} + +func (sp *ServerPeer) sendStatus(ctx context.Context, status Status) error { + sp.log.WithFields(logrus.Fields{ + "code": StatusCode, + "ethCapVersion": sp.ethCapVersion, + }).Debug("sending Status") + + encodedData, err := encodeStatus(status) + if err != nil { + return fmt.Errorf("error encoding status: %w", err) + } + + if _, err := sp.rlpxConn.Write(StatusCode, encodedData); err != nil { + return fmt.Errorf("error sending status: %w", err) + } + + return nil +} + +func (sp *ServerPeer) handleDisconnect(ctx context.Context, code uint64, data []byte) { + sp.log.WithField("code", code).Debug("received Disconnect") + + disconnect, err := decodeDisconnect(data) + if err != nil { + sp.log.WithError(err).Debug("error decoding disconnect") + + return + } + + sp.server.publishDisconnect(ctx, sp, disconnect) +} + +func (sp *ServerPeer) handlePing(ctx context.Context, code uint64, data []byte) error { + sp.log.WithField("code", code).Debug("received Ping") + + // Respond with Pong + if _, err := sp.rlpxConn.Write(PongCode, []byte{}); err != nil { + return fmt.Errorf("error sending pong: %w", err) + } + + return nil +} + +func (sp *ServerPeer) handleTransactions(ctx context.Context, code uint64, data []byte) error { + sp.log.WithField("code", code).Debug("received Transactions") + + transactions, err := decodeTransactions(data) + if err != nil { + return err + } + + sp.server.publishTransactions(ctx, sp, transactions) + + return nil +} + +func (sp *ServerPeer) handleGetBlockHeaders(ctx context.Context, code uint64, data []byte) error { + sp.log.WithField("code", code).Debug("received GetBlockHeaders") + // Server doesn't respond to block header requests + + return nil +} + +func (sp *ServerPeer) handleBlockHeaders(ctx context.Context, code uint64, data []byte) error { + sp.log.WithField("code", code).Debug("received BlockHeaders") + + return nil +} + +func (sp *ServerPeer) handleGetBlockBodies(ctx context.Context, code uint64, data []byte) error { + sp.log.WithField("code", code).Debug("received GetBlockBodies") + // Server doesn't respond to block body requests + + return nil +} + +func (sp *ServerPeer) handleNewPooledTransactionHashes(ctx context.Context, code uint64, data []byte) error { + sp.log.WithField("code", code).Debug("received NewPooledTransactionHashes") + + hashes, err := decodeNewPooledTransactionHashes(data) + if err != nil { + return err + } + + sp.server.publishNewPooledTransactionHashes(ctx, sp, hashes) + + return nil +} + +func (sp *ServerPeer) handlePooledTransactions(ctx context.Context, code uint64, data []byte) error { + sp.log.WithField("code", code).Debug("received PooledTransactions") + + return nil +} + +func (sp *ServerPeer) handleGetReceipts(ctx context.Context, code uint64, data []byte) error { + sp.log.WithField("code", code).Debug("received GetReceipts") + // Server doesn't respond to receipt requests + + return nil +} + +func (sp *ServerPeer) handleBlockRangeUpdate(ctx context.Context, code uint64, data []byte) error { + sp.log.WithField("code", code).Debug("received BlockRangeUpdate") + + return nil +} diff --git a/pkg/execution/mimicry/server_test.go b/pkg/execution/mimicry/server_test.go new file mode 100644 index 0000000..7d974df --- /dev/null +++ b/pkg/execution/mimicry/server_test.go @@ -0,0 +1,359 @@ +package mimicry + +import ( + "context" + "crypto/ecdsa" + "testing" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/eth/protocols/eth" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func testLogger() *logrus.Logger { + log := logrus.New() + log.SetLevel(logrus.ErrorLevel) // Suppress logs during tests + + return log +} + +func testStatusProvider(hello *Hello) (Status, error) { + return &Status69{ + StatusPacket69: eth.StatusPacket69{ + NetworkID: 1, + Genesis: common.HexToHash("0xd4e56740f876aef8c010b86a40d5f56745a118d0906a34e69aec8c0db1cb8fa3"), + LatestBlockHash: common.HexToHash("0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef"), + }, + }, nil +} + +func TestNewServer(t *testing.T) { + log := testLogger() + + tests := []struct { + name string + config *ServerConfig + wantErr bool + errMsg string + checks func(t *testing.T, server *Server) + }{ + { + name: "valid config creates server", + config: &ServerConfig{ + Name: "test-server", + ListenAddr: ":0", // Use any available port + StatusProvider: testStatusProvider, + }, + wantErr: false, + checks: func(t *testing.T, server *Server) { + t.Helper() + assert.NotNil(t, server.broker) + assert.NotNil(t, server.privateKey) + assert.NotNil(t, server.peers) + assert.NotNil(t, server.done) + assert.Equal(t, 5*time.Second, server.config.HandshakeTimeout) + assert.Equal(t, 30*time.Second, server.config.ReadTimeout) + }, + }, + { + name: "valid config with custom private key", + config: func() *ServerConfig { + key, _ := crypto.GenerateKey() + + return &ServerConfig{ + Name: "test-server", + ListenAddr: ":0", + StatusProvider: testStatusProvider, + PrivateKey: key, + } + }(), + wantErr: false, + checks: func(t *testing.T, server *Server) { + t.Helper() + assert.NotNil(t, server.privateKey) + }, + }, + { + name: "valid config with custom timeouts", + config: &ServerConfig{ + Name: "test-server", + ListenAddr: ":0", + StatusProvider: testStatusProvider, + HandshakeTimeout: 10 * time.Second, + ReadTimeout: 60 * time.Second, + }, + wantErr: false, + checks: func(t *testing.T, server *Server) { + t.Helper() + assert.Equal(t, 10*time.Second, server.config.HandshakeTimeout) + assert.Equal(t, 60*time.Second, server.config.ReadTimeout) + }, + }, + { + name: "valid config with max peers", + config: &ServerConfig{ + Name: "test-server", + ListenAddr: ":0", + StatusProvider: testStatusProvider, + MaxPeers: 100, + }, + wantErr: false, + checks: func(t *testing.T, server *Server) { + t.Helper() + assert.Equal(t, 100, server.config.MaxPeers) + }, + }, + { + name: "valid config with allowed networks", + config: &ServerConfig{ + Name: "test-server", + ListenAddr: ":0", + StatusProvider: testStatusProvider, + AllowedNetworks: []Network{ + {NetworkID: ptr(uint64(1))}, + {NetworkID: ptr(uint64(11155111))}, + }, + }, + wantErr: false, + checks: func(t *testing.T, server *Server) { + t.Helper() + assert.Len(t, server.config.AllowedNetworks, 2) + }, + }, + { + name: "missing StatusProvider returns error", + config: &ServerConfig{ + Name: "test-server", + ListenAddr: ":0", + }, + wantErr: true, + errMsg: "StatusProvider is required", + }, + { + name: "missing ListenAddr returns error", + config: &ServerConfig{ + Name: "test-server", + StatusProvider: testStatusProvider, + }, + wantErr: true, + errMsg: "ListenAddr is required", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server, err := NewServer(log, tt.config) + + if tt.wantErr { + require.Error(t, err) + assert.Nil(t, server) + + if tt.errMsg != "" { + assert.Contains(t, err.Error(), tt.errMsg) + } + + return + } + + require.NoError(t, err) + require.NotNil(t, server) + + if tt.checks != nil { + tt.checks(t, server) + } + }) + } +} + +func TestServerStartStop(t *testing.T) { + log := testLogger() + + config := &ServerConfig{ + Name: "test-server", + ListenAddr: "127.0.0.1:0", // Use any available port on localhost + StatusProvider: testStatusProvider, + } + + server, err := NewServer(log, config) + require.NoError(t, err) + + ctx := context.Background() + + // Start server + err = server.Start(ctx) + require.NoError(t, err) + + // Verify server is listening + assert.NotNil(t, server.listener) + assert.False(t, server.closed.Load()) + + // Get the actual address + addr := server.listener.Addr() + assert.NotNil(t, addr) + + // Stop server + err = server.Stop(ctx) + require.NoError(t, err) + + // Verify server is stopped + assert.True(t, server.closed.Load()) +} + +func TestServerDoubleStop(t *testing.T) { + log := testLogger() + + config := &ServerConfig{ + Name: "test-server", + ListenAddr: "127.0.0.1:0", + StatusProvider: testStatusProvider, + } + + server, err := NewServer(log, config) + require.NoError(t, err) + + ctx := context.Background() + + err = server.Start(ctx) + require.NoError(t, err) + + // Stop twice - should not error + err = server.Stop(ctx) + require.NoError(t, err) + + err = server.Stop(ctx) + require.NoError(t, err) // Second stop should be idempotent +} + +func TestServerPeersAndCount(t *testing.T) { + log := testLogger() + + config := &ServerConfig{ + Name: "test-server", + ListenAddr: "127.0.0.1:0", + StatusProvider: testStatusProvider, + } + + server, err := NewServer(log, config) + require.NoError(t, err) + + ctx := context.Background() + + err = server.Start(ctx) + require.NoError(t, err) + + defer func() { + _ = server.Stop(ctx) + }() + + // Initially no peers + assert.Equal(t, 0, server.PeerCount()) + assert.Empty(t, server.Peers()) + + // Broker should be accessible + assert.NotNil(t, server.Broker()) +} + +func TestServerIsNetworkAllowed(t *testing.T) { + log := testLogger() + mainnetGenesis := common.HexToHash("0xd4e56740f876aef8c010b86a40d5f56745a118d0906a34e69aec8c0db1cb8fa3") + sepoliaGenesis := common.HexToHash("0x25a5cc106eea7138acab33231d7160d69cb777ee0c2c553fcddf5138993e6dd9") + + tests := []struct { + name string + allowedNetworks []Network + status Status + want bool + }{ + { + name: "no filter allows all networks", + allowedNetworks: nil, + status: &mockStatus{ + networkID: 1, + genesis: mainnetGenesis[:], + }, + want: true, + }, + { + name: "filter allows matching network", + allowedNetworks: []Network{ + {NetworkID: ptr(uint64(1))}, + }, + status: &mockStatus{ + networkID: 1, + genesis: mainnetGenesis[:], + }, + want: true, + }, + { + name: "filter rejects non-matching network", + allowedNetworks: []Network{ + {NetworkID: ptr(uint64(1))}, + }, + status: &mockStatus{ + networkID: 11155111, + genesis: sepoliaGenesis[:], + }, + want: false, + }, + { + name: "multiple filters - matches one", + allowedNetworks: []Network{ + {NetworkID: ptr(uint64(1))}, + {NetworkID: ptr(uint64(11155111))}, + }, + status: &mockStatus{ + networkID: 11155111, + genesis: sepoliaGenesis[:], + }, + want: true, + }, + { + name: "multiple filters - matches none", + allowedNetworks: []Network{ + {NetworkID: ptr(uint64(1))}, + {NetworkID: ptr(uint64(11155111))}, + }, + status: &mockStatus{ + networkID: 5, // goerli + genesis: common.Hash{}.Bytes(), + }, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := &ServerConfig{ + Name: "test-server", + ListenAddr: ":0", + StatusProvider: testStatusProvider, + AllowedNetworks: tt.allowedNetworks, + } + + server, err := NewServer(log, config) + require.NoError(t, err) + + got := server.isNetworkAllowed(tt.status) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestServerPeerMethods(t *testing.T) { + // Test ServerPeer accessor methods + peer := &ServerPeer{ + remoteID: [32]byte{1, 2, 3, 4}, + remotePubkey: func() *ecdsa.PublicKey { + key, _ := crypto.GenerateKey() + + return &key.PublicKey + }(), + } + + assert.Equal(t, [32]byte{1, 2, 3, 4}, [32]byte(peer.RemoteID())) + assert.NotNil(t, peer.RemotePubkey()) +}