diff --git a/examples/bitswap/bitswap.py b/examples/bitswap/bitswap.py index 1a9c31cac..cbd222004 100755 --- a/examples/bitswap/bitswap.py +++ b/examples/bitswap/bitswap.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 import argparse +import hashlib import logging from pathlib import Path import sys @@ -16,6 +17,7 @@ from libp2p.bitswap import BitswapClient from libp2p.bitswap.cid import cid_to_bytes, format_cid_for_display from libp2p.bitswap.dag import MerkleDag +from libp2p.crypto.ed25519 import create_new_key_pair from libp2p.peer.peerinfo import info_from_p2p_addr from libp2p.utils.address_validation import ( find_free_port, @@ -35,6 +37,23 @@ logger = logging.getLogger(__name__) +DEFAULT_LISTEN_PORT = 4013 + + +def select_preferred_listen_addr(addrs: list[Multiaddr], port: int) -> Multiaddr: + """Pick a stable, local-friendly address for copy/paste commands.""" + preferred_v4 = f"/ip4/127.0.0.1/tcp/{port}" + for addr in addrs: + if str(addr) == preferred_v4: + return addr + + preferred_v6 = f"/ip6/::1/tcp/{port}" + for addr in addrs: + if str(addr) == preferred_v6: + return addr + + return addrs[0] + def format_size(size_bytes: int) -> str: """Format size in human-readable form.""" @@ -46,13 +65,14 @@ def format_size(size_bytes: int) -> str: return f"{size:.1f} TB" -async def run_provider(file_path: str, port: int = 0): +async def run_provider(file_path: str, port: int = 0, seed: str | None = None): """ Run the provider node to share a file. 
Args: file_path: Path to the file to share port: TCP port to listen on (0 for auto) + seed: Optional seed string for deterministic peer ID generation """ file_path_obj = Path(file_path) @@ -73,12 +93,19 @@ async def run_provider(file_path: str, port: int = 0): if port <= 0: port = find_free_port() listen_addrs = get_available_interfaces(port) - # Create host - host = new_host() - async with host.run(listen_addrs=listen_addrs): - peer_id = host.get_id() - logger.info(f"Peer ID: {peer_id}") + # Create host with optional seed for deterministic peer ID + key_pair = None + if seed: + # Convert seed string to bytes (must be 32 bytes for Ed25519) + seed_bytes = hashlib.sha256(seed.encode()).digest() + key_pair = create_new_key_pair(seed=seed_bytes) + logger.info("Using deterministic peer ID from seed") + + host = new_host(key_pair=key_pair) + + async with host.run(listen_addrs=listen_addrs), trio.open_nursery() as nursery: + logger.info(f"Peer ID: {host.get_id()}") # Get actual listening addresses addrs = host.get_addrs() @@ -91,7 +118,8 @@ async def run_provider(file_path: str, port: int = 0): await bitswap.start() logger.info("✓ Bitswap started") - # Create Merkle DAG + # Set nursery so bitswap can spawn background tasks + bitswap.set_nursery(nursery) dag = MerkleDag(bitswap) logger.info("") @@ -109,7 +137,7 @@ def progress_callback(current: int, total: int, status: str): # Add file with directory wrapper for filename preservation # Always uses Merkle DAG regardless of file size root_cid = await dag.add_file( - file_path, progress_callback=progress_callback, wrap_with_directory=True + file_path, progress_callback=progress_callback, wrap_with_directory=False ) # Get all blocks that were stored @@ -131,8 +159,10 @@ def progress_callback(current: int, total: int, status: str): logger.info("FILE READY TO SHARE!") logger.info("=" * 70) - # Get the first address (clean multiaddr without duplicate /p2p/) - provider_addr = host.get_addrs()[0] + # Prefer a deterministic local 
address for copy/paste commands. + transport_addrs = host.get_transport_addrs() + provider_addr = select_preferred_listen_addr(transport_addrs, port) + provider_addr = provider_addr.encapsulate(Multiaddr(f"/p2p/{host.get_id()}")) root_cid_text = format_cid_for_display(root_cid) logger.info(f"Root CID: {root_cid_text}") logger.info("") @@ -161,6 +191,7 @@ async def run_client( root_cid_input: str, output_dir: str = "/tmp", port: int = 0, + seed: str | None = None, ): """ Run the client node to fetch a file. @@ -170,6 +201,7 @@ async def run_client( root_cid_input: Root CID (canonical text, /ipfs/... path, or hex string) output_dir: Directory to save the file port: TCP port to listen on (0 for auto) + seed: Optional seed string for deterministic peer ID generation """ output_path = Path(output_dir) @@ -195,16 +227,24 @@ async def run_client( port = find_free_port() listen_addrs = get_available_interfaces(port) - # Create host - host = new_host() + # Create host with optional seed for deterministic peer ID + key_pair = None + if seed: + # Convert seed string to bytes (must be 32 bytes for Ed25519) + seed_bytes = hashlib.sha256(seed.encode()).digest() + key_pair = create_new_key_pair(seed=seed_bytes) + logger.info("Using deterministic peer ID from seed") - async with host.run(listen_addrs=listen_addrs): + host = new_host(key_pair=key_pair) + + async with host.run(listen_addrs=listen_addrs), trio.open_nursery() as nursery: logger.info(f"Client Peer ID: {host.get_id()}") # Start Bitswap bitswap = BitswapClient(host) await bitswap.start() logger.info("✓ Bitswap started") + bitswap.set_nursery(nursery) try: # Connect to provider @@ -214,7 +254,6 @@ async def run_client( await host.connect(peer_info) logger.info("✓ Connected") - # Create Merkle DAG dag = MerkleDag(bitswap) logger.info("") @@ -232,7 +271,7 @@ def progress_callback(current: int, total: int, status: str): # Fetch file with automatic filename extraction try: file_data, filename = await dag.fetch_file( - 
root_cid, progress_callback=progress_callback + root_cid, progress_callback=progress_callback, timeout=120.0 ) # Show fetch statistics @@ -284,18 +323,18 @@ def progress_callback(current: int, total: int, status: str): logger.info("=" * 70) logger.info(f"Size: {format_size(len(file_data))}") - # Determine output filename + # Determine output filename (priority: metadata > generated) if filename: - output_filename = filename - logger.info(f"Filename: {filename} (from metadata)") + final_filename = filename + logger.info(f"Filename: {final_filename} (from metadata)") else: - output_filename = ( + final_filename = ( f"file_{format_cid_for_display(root_cid, max_len=16)}.bin" ) - logger.info(f"Filename: {output_filename} (no metadata)") + logger.info(f"Filename: {final_filename} (generated from CID)") # Handle filename conflicts - output_file = output_path / output_filename + output_file = output_path / final_filename if output_file.exists(): stem = output_file.stem suffix = output_file.suffix @@ -315,7 +354,9 @@ def progress_callback(current: int, total: int, status: str): except Exception as e: logger.error(f"Failed: {e}") logger.exception("Full traceback:") + raise finally: + pass # Nursery will cleanup background tasks await bitswap.stop() @@ -333,8 +374,8 @@ def parse_args(): parser.add_argument( "--port", type=int, - default=0, - help="Port to listen on (0 for random, provider mode only)", + default=DEFAULT_LISTEN_PORT, + help=("Port to listen on (default: 4013). 
Use 0 to auto-select a random port."), ) parser.add_argument( "--file", @@ -365,6 +406,14 @@ def parse_args(): action="store_true", help="Enable verbose logging", ) + parser.add_argument( + "--seed", + type=str, + help=( + "Seed string for deterministic peer ID generation " + "(same seed = same peer ID)" + ), + ) args = parser.parse_args() @@ -395,9 +444,11 @@ def main(): ) if args.mode == "provider": - trio.run(run_provider, args.file, args.port) + trio.run(run_provider, args.file, args.port, args.seed) elif args.mode == "client": - trio.run(run_client, args.provider, args.cid, args.output, args.port) + trio.run( + run_client, args.provider, args.cid, args.output, args.port, args.seed + ) except Exception as e: logger.critical(f"Script failed: {e}", exc_info=True) sys.exit(1) diff --git a/libp2p/bitswap/__init__.py b/libp2p/bitswap/__init__.py index 756ad5793..e057362bc 100644 --- a/libp2p/bitswap/__init__.py +++ b/libp2p/bitswap/__init__.py @@ -31,7 +31,12 @@ New code should prefer the object-returning variants above. 
""" -from .block_store import BlockStore, MemoryBlockStore +from .block_service import BlockService +from .block_store import BlockStore, FilesystemBlockStore, MemoryBlockStore +from .gated_decision_engine import PaymentGatedDecisionEngine +from .payment_ledger import PaymentLedger +from .pricing_engine import BlockPricingEngine +from .payment_client_1_3 import BitswapPaymentClient_1_3 from .cid import ( CID_V0, CID_V1, @@ -65,12 +70,33 @@ MessageTooLargeError, TimeoutError, ) +from .wantlist import ( + BitswapMessage, + BlockPresence, + BlockPresenceType, + Wantlist, + WantlistEntry, + WantType, +) __all__ = [ # Core "BitswapClient", + "BitswapPaymentClient_1_3", + "PaymentGatedDecisionEngine", + "PaymentLedger", + "BlockPricingEngine", + "BlockService", "BlockStore", "MemoryBlockStore", + "FilesystemBlockStore", + # Messages + "BitswapMessage", + "BlockPresence", + "BlockPresenceType", + "Wantlist", + "WantlistEntry", + "WantType", # CID types "CIDInput", "CIDObject", diff --git a/libp2p/bitswap/block_service.py b/libp2p/bitswap/block_service.py new file mode 100644 index 000000000..c4e452d9a --- /dev/null +++ b/libp2p/bitswap/block_service.py @@ -0,0 +1,196 @@ +""" +BlockService: transparent local→network fallback for block retrieval. 
+ +Sits between MerkleDag and BitswapClient, providing: + - Local-first lookup (no network cost if block is already stored) + - Automatic caching of network-fetched blocks into the local store + - Peer announcement when new blocks are stored locally + - A clean abstraction so MerkleDag is not hardwired to BitswapClient +""" + +from __future__ import annotations + +from collections.abc import Sequence +import logging +from typing import TYPE_CHECKING + +from .block_store import BlockStore +from .cid import CIDInput, cid_to_bytes, format_cid_for_display, parse_cid + +if TYPE_CHECKING: + from libp2p.peer.id import ID as PeerID + + from .client import BitswapClient + +logger = logging.getLogger(__name__) + + +class BlockService: + """ + Combines a local BlockStore with a BitswapClient into one unified interface. + + get_block() flow: + 1. Check local BlockStore → return immediately if found (no network) + 2. Fetch via BitswapClient → goes to the network + 3. Auto-cache the result → store locally so next call is free + + put_block() flow: + 1. Write to local BlockStore + 2. Call bitswap.add_block() so peers who have this CID in their + wantlist are notified and can receive it + + This is a drop-in wrapper: MerkleDag can use BlockService instead of + calling bitswap directly, and the behaviour is identical but with the + caching and announcement benefits added transparently. + + Example: + >>> store = FilesystemBlockStore("./blocks") + >>> service = BlockService(store, bitswap) + >>> dag = MerkleDag(bitswap, block_service=service) + + """ + + def __init__(self, store: BlockStore, bitswap: BitswapClient) -> None: + self.store = store + self.bitswap = bitswap + + async def get_block( + self, + cid: CIDInput, + peer_id: PeerID | None = None, + timeout: float = 30.0, + ) -> bytes | None: + """ + Get a block. Checks local store first, then fetches from network. + Any block fetched from the network is automatically cached locally. 
+ + Args: + cid: The CID of the block to retrieve + peer_id: Optional specific peer to fetch from (passed to bitswap) + timeout: Network timeout in seconds + + Returns: + Block data bytes, or None if not found anywhere + + """ + cid_bytes = cid_to_bytes(cid) + cid_obj = parse_cid(cid_bytes) + + # 1. Local lookup — instant, no network cost + data = await self.store.get_block(cid_obj) + if data is not None: + logger.debug( + f"BlockService: local hit {format_cid_for_display(cid_obj, max_len=12)}" + ) + return data + + # 2. Network fetch via Bitswap + logger.debug( + f"BlockService: local miss, fetching from network " + f"{format_cid_for_display(cid_obj, max_len=12)}" + ) + try: + data = await self.bitswap.get_block(cid_bytes, peer_id, timeout) + except Exception as e: + logger.warning(f"BlockService: network fetch failed: {e}") + return None + + if data is not None: + # 3. Auto-cache locally — future requests for this block are free + await self.store.put_block(cid_obj, data) + logger.debug( + f"BlockService: cached fetched block " + f"{format_cid_for_display(cid_obj, max_len=12)}" + ) + + return data + + async def put_block(self, cid: CIDInput, data: bytes) -> None: + """ + Store a block locally and announce it to waiting peers. + + Calling bitswap.add_block() both writes to bitswap's own store AND + notifies any peers who have this CID in their pending wantlist. + We also write to our own store so get_block() local-hits on it. 
+ + Args: + cid: The CID of the block + data: The block data bytes + + """ + cid_obj = parse_cid(cid_to_bytes(cid)) + + # Write to our local store + await self.store.put_block(cid_obj, data) + + # add_block() writes to bitswap's internal store AND calls + # _notify_peers_about_block() for any peers waiting on this CID + await self.bitswap.add_block(cid_obj, data) + + logger.debug( + f"BlockService: stored and announced " + f"{format_cid_for_display(cid_obj, max_len=12)}" + ) + + async def get_blocks_batch( + self, + cids: Sequence[CIDInput], + peer_id: PeerID | None = None, + timeout: float = 30.0, + batch_size: int = 32, + ) -> dict[bytes, bytes]: + """ + Batch-fetch multiple blocks. Local hits are returned immediately; + only missing blocks go to the network. All network-fetched blocks + are auto-cached locally. + + Args: + cids: List of CIDs to fetch + peer_id: Optional specific peer to fetch from + timeout: Network timeout in seconds + batch_size: Wantlist batch size passed to bitswap + + Returns: + Dict mapping cid_bytes -> block_data for all found blocks + + """ + results: dict[bytes, bytes] = {} + missing_cids: list[CIDInput] = [] + + # Local pass first + for cid in cids: + cid_bytes = cid_to_bytes(cid) + cid_obj = parse_cid(cid_bytes) + data = await self.store.get_block(cid_obj) + if data is not None: + results[cid_bytes] = data + else: + missing_cids.append(cid) + + if not missing_cids: + logger.debug(f"BlockService.get_blocks_batch: all {len(cids)} blocks local") + return results + + local_hits = len(cids) - len(missing_cids) + logger.debug( + f"BlockService.get_blocks_batch: {local_hits} local hits, " + f"{len(missing_cids)} fetching from network" + ) + + # Network pass for missing blocks + network_results = await self.bitswap.get_blocks_batch( + missing_cids, peer_id=peer_id, timeout=timeout, batch_size=batch_size + ) + + # Auto-cache all network-fetched blocks + for cid_bytes, data in network_results.items(): + cid_obj = parse_cid(cid_bytes) + await 
self.store.put_block(cid_obj, data) + results[cid_bytes] = data + + return results + + @property + def block_store(self) -> BlockStore: + """Expose the underlying BlockStore (used by MerkleDag internals).""" + return self.store diff --git a/libp2p/bitswap/block_store.py b/libp2p/bitswap/block_store.py index 12eee5aab..bc36269ce 100644 --- a/libp2p/bitswap/block_store.py +++ b/libp2p/bitswap/block_store.py @@ -3,6 +3,9 @@ """ from abc import ABC, abstractmethod +from pathlib import Path + +import trio from .cid import CIDInput, CIDObject, parse_cid @@ -118,3 +121,99 @@ def get_all_cids(self) -> list[bytes]: def size(self) -> int: """Get the number of blocks in the store.""" return len(self._blocks) + + +class FilesystemBlockStore(BlockStore): + """ + Filesystem-based block store. Persists blocks to disk as files. + + Each block is stored as a file at: + <base_path>/<first 2 chars of CID string>/<remaining chars of CID string> + + This two-level directory structure avoids having too many files in a + single directory and matches the layout used by py-ipfs-lite. + + Args: + base_path: Root directory for block storage. Created if it doesn't exist. + + Example: + >>> store = FilesystemBlockStore("/var/lib/myapp/blocks") + >>> bitswap = BitswapClient(host, store) + >>> # Blocks now survive process restarts! + + >>> # Drop-in replacement for MemoryBlockStore: + >>> # store = MemoryBlockStore() # before + >>> store = FilesystemBlockStore("./blocks") # after — persistent + + """ + + def __init__(self, base_path: str | Path) -> None: + """Initialize the filesystem block store.""" + self._path = Path(base_path) + self._path.mkdir(parents=True, exist_ok=True) + + def _cid_to_path(self, cid: CIDInput) -> Path: + """Convert a CID to a filesystem path using 2-char prefix directories.""" + cid_str = str(_normalize_cid(cid)) + # e.g. bafybeiabc... → <base_path>/ba/fybeiabc... + return self._path / cid_str[:2] / cid_str[2:] + + async def get_block(self, cid: CIDInput) -> bytes | None: + """Get a block by CID. 
Returns None if not found on disk.""" + path = self._cid_to_path(cid) + if not path.exists(): + return None + return await trio.to_thread.run_sync(path.read_bytes) + + async def put_block(self, cid: CIDInput, data: bytes) -> None: + """Write a block to disk.""" + path = self._cid_to_path(cid) + await trio.to_thread.run_sync( + lambda: path.parent.mkdir(parents=True, exist_ok=True) + ) + await trio.to_thread.run_sync(path.write_bytes, data) + + async def has_block(self, cid: CIDInput) -> bool: + """Check if a block file exists on disk.""" + return self._cid_to_path(cid).exists() + + async def delete_block(self, cid: CIDInput) -> None: + """Delete a block file from disk.""" + path = self._cid_to_path(cid) + if path.exists(): + await trio.to_thread.run_sync(path.unlink) + + def get_all_cids(self) -> list[bytes]: + """Return all stored CIDs as bytes by scanning the directory tree.""" + cids: list[bytes] = [] + if not self._path.exists(): + return cids + for subdir in self._path.iterdir(): + if not subdir.is_dir(): + continue + for entry in subdir.iterdir(): + if not entry.is_file(): + continue + cid_str = subdir.name + entry.name + try: + cid_obj = _normalize_cid(cid_str) + cids.append(cid_obj.buffer) + except Exception: + pass # skip files that aren't valid CIDs + return cids + + def size(self) -> int: + """Return the number of stored blocks.""" + if not self._path.exists(): + return 0 + return sum( + 1 + for d in self._path.iterdir() + if d.is_dir() + for f in d.iterdir() + if f.is_file() + ) + + def base_path(self) -> Path: + """Return the root directory where blocks are stored.""" + return self._path diff --git a/libp2p/bitswap/chunker.py b/libp2p/bitswap/chunker.py index 10cb869b0..ba3fe9822 100644 --- a/libp2p/bitswap/chunker.py +++ b/libp2p/bitswap/chunker.py @@ -7,10 +7,12 @@ """ from collections.abc import Callable, Iterator +import io from pathlib import Path -# Default chunk size: 63 KB (py-libp2p accepts less than 64 KB) -DEFAULT_CHUNK_SIZE = 63 * 1024 +# 
Default chunk size: 256 KiB — matches Kubo's default chunker (size-262144). +# Raw leaves are stored directly without dag-pb wrapping, so no overhead needed. +DEFAULT_CHUNK_SIZE = 256 * 1024 def chunk_bytes(data: bytes, chunk_size: int = DEFAULT_CHUNK_SIZE) -> list[bytes]: @@ -82,6 +84,49 @@ def chunk_file(file_path: str, chunk_size: int = DEFAULT_CHUNK_SIZE) -> Iterator yield chunk +def chunk_stream( + stream: io.IOBase, chunk_size: int = DEFAULT_CHUNK_SIZE +) -> Iterator[bytes]: + """ + Stream chunks from any readable io.IOBase object. + + Memory efficient — reads one chunk at a time without loading the + entire content into memory. Works with any Python stream: + open() file handles, BytesIO, GzipFile, BZ2File, network sockets, + or any object that implements io.IOBase.read(). + + Args: + stream: Any readable io.IOBase (open(), BytesIO, GzipFile, etc.) + chunk_size: Size of each chunk in bytes + + Yields: + Chunks of up to chunk_size bytes. The final chunk may be smaller. + + Example: + >>> import io + >>> data = b"hello world " * 100000 + >>> chunks = list(chunk_stream(io.BytesIO(data), chunk_size=256*1024)) + >>> print(f"Split into {len(chunks)} chunks") + + >>> # From a real file handle + >>> with open("movie.mp4", "rb") as f: + ... for chunk in chunk_stream(f): + ... process(chunk) + + >>> # From a gzip stream (decompress on-the-fly) + >>> import gzip + >>> with gzip.open("archive.gz", "rb") as f: + ... for chunk in chunk_stream(f): + ... process(chunk) + + """ + while True: + chunk = stream.read(chunk_size) + if not chunk: + break + yield chunk + + def estimate_chunk_count(file_size: int, chunk_size: int = DEFAULT_CHUNK_SIZE) -> int: """ Estimate number of chunks for a given file size. 
diff --git a/libp2p/bitswap/cid.py b/libp2p/bitswap/cid.py index 9f21d90de..b1c89821a 100644 --- a/libp2p/bitswap/cid.py +++ b/libp2p/bitswap/cid.py @@ -209,7 +209,16 @@ def parse_cid(value: CIDInput) -> CIDv0 | CIDv1: return value if isinstance(value, bytes): - return make_cid(value) + try: + return make_cid(value) + except ValueError: + # make_cid(bytes) fails for raw CIDv0 buffers (multihash bytes). + # CIDv0 is simply a bare multihash, so try constructing directly. + try: + return CIDv0(value) + except Exception: + pass + raise if isinstance(value, str): cid_str = value.strip() @@ -234,8 +243,16 @@ def cid_to_bytes(value: CIDInput) -> bytes: def cid_to_text(value: CIDInput) -> str: - """Convert CID input to canonical CID string form.""" - return str(parse_cid(value)) + """ + Convert CID input to canonical CID string form + (base32 for CIDv1, base58btc for CIDv0). + """ + cid_obj = parse_cid(value) + # Use base32 for CIDv1 (matches Kubo's default output) + if cid_obj.version == 1: + return cid_obj.encode("base32").decode() + # Use base58btc for CIDv0 (legacy format) + return str(cid_obj) def format_cid_for_display(cid: CIDInput, max_len: int | None = None) -> str: diff --git a/libp2p/bitswap/client.py b/libp2p/bitswap/client.py index 96913567f..6675ad64d 100644 --- a/libp2p/bitswap/client.py +++ b/libp2p/bitswap/client.py @@ -1,12 +1,12 @@ """ Bitswap client implementation for block exchange. -Supports v1.0.0, v1.1.0, and v1.2.0 protocols. +Supports v1.0.0, v1.1.0, v1.2.0, and v1.3.0 protocols. 
""" from collections.abc import Sequence import hashlib import logging -from typing import Any +from typing import TYPE_CHECKING, Any import trio import varint @@ -15,7 +15,10 @@ from libp2p.custom_types import TProtocol from libp2p.network.stream.exceptions import StreamEOF from libp2p.peer.id import ID as PeerID +from libp2p.peer.peerinfo import PeerInfo # noqa: F401 +if TYPE_CHECKING: + from .extension import IBitswapExtension from .block_store import BlockStore, MemoryBlockStore from .cid import ( CIDInput, @@ -43,6 +46,7 @@ ) from .messages import create_message, create_wantlist_entry from .pb.bitswap_pb2 import Message +from .provider_query import ProviderQueryManager logger = logging.getLogger(__name__) @@ -51,8 +55,10 @@ class BitswapClient: """ Bitswap client for exchanging blocks with other peers. - Supports Bitswap protocol versions 1.0.0, 1.1.0, and 1.2.0 for content - discovery and file sharing in a peer-to-peer network. + Supports Bitswap protocol versions 1.0.0, 1.1.0, 1.2.0, and 1.3.0 for + content discovery and file sharing in a peer-to-peer network. + + For 1.3.0 payment support, register a PaymentExtension. """ def __init__( @@ -60,6 +66,7 @@ def __init__( host: IHost, block_store: BlockStore | None = None, protocol_version: str = BITSWAP_PROTOCOL_V120, + provider_query_manager: ProviderQueryManager | None = None, ): """ Initialize Bitswap client. @@ -68,11 +75,22 @@ def __init__( host: The libp2p host block_store: Block storage backend (defaults to in-memory) protocol_version: Preferred protocol version (defaults to v1.2.0) + provider_query_manager: Optional ProviderQueryManager for automatic + DHT-based provider discovery. When supplied, + ``get_block()`` will query the DHT for providers before + broadcasting to all connected peers. 
""" self.host = host self.block_store = block_store or MemoryBlockStore() self.protocol_version = protocol_version + self.provider_query_manager: ProviderQueryManager | None = ( + provider_query_manager + ) + + self.protocol_handlers: dict[str, "IBitswapExtension"] = {} + self.supported_protocols: list[str] = list(BITSWAP_PROTOCOLS) + self._wantlist: dict[ CIDObject, dict[str, Any] ] = {} # CID -> {priority, want_type, send_dont_have} @@ -89,15 +107,22 @@ def __init__( self._nursery: trio.Nursery | None = None self._started = False + def register_extension(self, protocol: str, extension: "IBitswapExtension") -> None: + """Register an extension for a specific protocol.""" + extension.set_client(self) + self.protocol_handlers[protocol] = extension + if protocol not in self.supported_protocols: + self.supported_protocols.insert(0, protocol) + async def start(self) -> None: """Start the Bitswap client.""" if self._started: return # Set stream handler for all supported Bitswap protocols - for protocol in BITSWAP_PROTOCOLS: + for protocol in self.supported_protocols: self.host.set_stream_handler( - protocol, + TProtocol(protocol), self._handle_stream, ) @@ -111,8 +136,8 @@ async def stop(self) -> None: self._started = False # Unregister stream handlers for all supported Bitswap protocols - for protocol in BITSWAP_PROTOCOLS: - self.host.remove_stream_handler(protocol) + for protocol in self.supported_protocols: + self.host.remove_stream_handler(TProtocol(protocol)) # Clear wantlists and pending requests self._wantlist.clear() self._peer_wantlists.clear() @@ -153,6 +178,103 @@ async def add_block(self, cid: CIDInput, data: bytes) -> None: # Notify peers who wanted this block await self._notify_peers_about_block(cid_obj, data) + async def get_blocks_batch( + self, + cids: list[CIDInput], + peer_id: PeerID | None = None, + timeout: float = DEFAULT_TIMEOUT, + batch_size: int = 32, + ) -> dict[bytes, bytes]: + """ + Fetch multiple blocks in batches using a single wantlist per 
batch. + + Sends all CIDs in one wantlist message, waits for all responses on the + same stream. This avoids opening hundreds of individual streams which + causes Kubo to send GO_AWAY. + + Args: + cids: List of CIDs to fetch + peer_id: Optional specific peer to request from + timeout: Timeout per batch in seconds + batch_size: How many CIDs to request per wantlist message + + Returns: + Dict mapping cid_bytes -> block_data for all successfully fetched blocks + + """ + results: dict[bytes, bytes] = {} + cid_objs = [parse_cid(c) for c in cids] + + # Check local store first + remaining: list[CIDObject] = [] + for cid_obj in cid_objs: + data = await self.block_store.get_block(cid_obj) + if data is not None: + results[cid_obj.buffer] = data + else: + remaining.append(cid_obj) + + if not remaining: + return results + + # Process in batches to avoid overwhelming the peer + for batch_start in range(0, len(remaining), batch_size): + batch = remaining[batch_start : batch_start + batch_size] + + # Register pending events for all CIDs in batch + for cid_obj in batch: + if cid_obj not in self._pending_requests: + self._pending_requests[cid_obj] = trio.Event() + await self.want_block(cid_obj, send_dont_have=True) + + # Send all CIDs in a single wantlist to the peer + if peer_id: + await self._send_wantlist_to_peer(peer_id, batch) + else: + await self._broadcast_wantlist(batch) + + # Wait for all blocks in this batch + try: + with trio.fail_after(timeout): + for cid_obj in batch: + if cid_obj in self._pending_requests: + await self._pending_requests[cid_obj].wait() + except trio.TooSlowError: + msg = f"Batch timeout: {len(batch)} blocks, got partial results" + logger.warning(msg) + + # Collect results and clean up + for cid_obj in batch: + data = await self.block_store.get_block(cid_obj) + if data is not None: + results[cid_obj.buffer] = data + else: + # Block may have arrived late (e.g. after payment round-trip). + # Check if the pending event was set after the timeout fired. 
+ event = self._pending_requests.get(cid_obj) + if event and event.is_set(): + data = await self.block_store.get_block(cid_obj) + if data is not None: + results[cid_obj.buffer] = data + logger.info( + f"Late block received (post-timeout): " + f"{format_cid_for_display(cid_obj)}" + ) + else: + cid_str = format_cid_for_display(cid_obj) + logger.warning(f"Block not received: {cid_str}") + else: + cid_str = format_cid_for_display(cid_obj) + logger.warning(f"Block not received: {cid_str}") + + # Cleanup + if cid_obj in self._pending_requests: + del self._pending_requests[cid_obj] + if cid_obj in self._wantlist: + del self._wantlist[cid_obj] + + return results + async def get_block( self, cid: CIDInput, @@ -162,9 +284,15 @@ async def get_block( """ Get a block, fetching from peers if not available locally. + If a ``ProviderQueryManager`` was supplied at construction time and no + explicit ``peer_id`` is given, the manager is consulted first to + discover which peers have the block via the DHT. The first discovered + provider is used; if none is found the request falls back to + broadcasting to all connected peers. + Args: cid: The CID of the block to fetch - peer_id: Optional specific peer to request from + peer_id: Optional peer to request from; DHT discovery is skipped when set. timeout: Timeout in seconds Returns: @@ -177,12 +305,31 @@ async def get_block( """ cid_obj = parse_cid(cid) - # Check local store first + # 1. Check local store first data = await self.block_store.get_block(cid_obj) if data is not None: return data - # Request from network + # 2. 
If no explicit peer given, try DHT provider discovery + if peer_id is None and self.provider_query_manager is not None: + try: + providers = await self.provider_query_manager.find_providers_single( + cid, timeout=min(5.0, timeout / 2) + ) + if providers: + peer_id = providers[0] + logger.debug( + "DHT discovered provider %s for %s", + peer_id, + format_cid_for_display(cid_obj, max_len=12), + ) + except Exception as exc: + logger.debug( + "Provider query failed, falling back to broadcast: %s", + exc, + ) + + # 3. Request from network (specific peer or broadcast) return await self._request_block(cid_obj, peer_id, timeout) async def want_block( @@ -286,10 +433,8 @@ async def _request_block( # Send wantlist to peers if peer_id: - logger.info(f" → Sending wantlist to peer {peer_id}") await self._send_wantlist_to_peer(peer_id, [cid]) else: - logger.info(" → Broadcasting wantlist") await self._broadcast_wantlist([cid]) # Wait for block to arrive @@ -379,7 +524,7 @@ async def _send_wantlist_to_peer( if peer_id in self._peer_protocols: protocols = [TProtocol(self._peer_protocols[peer_id])] else: - protocols = list(BITSWAP_PROTOCOLS) # Try all + protocols = [TProtocol(p) for p in self.supported_protocols] # Try all # Open stream and send message stream = await self.host.new_stream( @@ -553,14 +698,29 @@ async def _handle_stream(self, stream: INetStream) -> None: peer_id = stream.muxed_conn.peer_id logger.debug(f"Handling Bitswap stream from peer {peer_id}") + # Detect negotiated protocol and store it immediately so that + # _process_message can use the correct protocol for responses. + protocol = stream.get_protocol() + if protocol: + self._peer_protocols[peer_id] = str(protocol) + try: + # Read the first message from this stream + msg = await self._read_message(stream) + if msg is None: + return + + # If the peer sent a WANT_HAVE and we have blocks, reply with + # a proactive HAVE so Kubo's session scores us highly and sends + # WANT_BLOCK immediately on the same stream. 
+ await self._process_message(msg, peer_id, stream) + + # Continue reading further messages on the same stream + # (Kubo sends WANT_BLOCK as a follow-up after receiving HAVE) while True: - # Read message msg = await self._read_message(stream) if msg is None: break - - # Process message await self._process_message(msg, peer_id, stream) except Exception as e: @@ -572,24 +732,60 @@ async def _process_message( self, msg: Message, peer_id: PeerID, stream: INetStream ) -> None: """Process a received Bitswap message.""" + peer_id_str = str(peer_id)[:16] + if msg.HasField("wantlist"): + logger.warning("=" * 70) + logger.warning(f"📥 RECEIVED WANTLIST from peer {peer_id_str}") + logger.warning(f" Entries: {len(msg.wantlist.entries)}") + logger.warning(f" Full: {msg.wantlist.full}") + for _i, _e in enumerate(msg.wantlist.entries): + _cid_hex = bytes(_e.block).hex()[:20] if _e.block else "N/A" + _wt = "WANT_HAVE" if _e.wantType == 1 else "WANT_BLOCK" + logger.warning( + f" [{_i + 1}] cid={_cid_hex}... 
type={_wt} cancel={_e.cancel}" + ) + logger.warning("=" * 70) + print( + f"\n📥 RECEIVED WANTLIST from peer {peer_id_str} with " + f"{len(msg.wantlist.entries)} entries", + flush=True, + ) + # Detect peer protocol version from stream protocol = stream.get_protocol() if protocol: self._peer_protocols[peer_id] = str(protocol) - # Process wantlist + peer_protocol = str(protocol) if protocol else BITSWAP_PROTOCOL_V100 + logger.info( + f"[FLOW] Negotiated protocol for peer {str(peer_id)[:20]}...: " + f"{peer_protocol}" + ) + + # ── Protocol Extension Handling ───────────────────────────────────── + if peer_protocol in self.protocol_handlers: + handled = await self.protocol_handlers[peer_protocol].process_message( + peer_id, msg.SerializeToString(), stream + ) + if handled: + return + + # ── Standard 1.0.0–1.2.0 message handling (always runs) ───────── if msg.HasField("wantlist"): - await self._process_wantlist(msg.wantlist, peer_id, stream) + handled = False + if peer_protocol in self.protocol_handlers: + handled = await self.protocol_handlers[peer_protocol].process_wantlist( + msg.wantlist, peer_id, stream + ) + if not handled: + await self._process_wantlist(msg.wantlist, peer_id, stream) - # Process blocks (v1.0.0 format) if msg.blocks: await self._process_blocks_v100(list(msg.blocks), peer_id) - # Process payload (v1.1.0+ format) if msg.payload: await self._process_blocks_v110(msg.payload) - # Process block presences (v1.2.0 format) if msg.blockPresences: await self._process_block_presences(msg.blockPresences, peer_id) @@ -602,7 +798,6 @@ async def _process_wantlist( self._peer_wantlists[peer_id] = {} peer_wantlist = self._peer_wantlists[peer_id] - # Update based on full or incremental wantlist if wantlist.full: peer_wantlist.clear() @@ -610,13 +805,28 @@ async def _process_wantlist( # Get peer protocol for response format peer_protocol = self._peer_protocols.get(peer_id, BITSWAP_PROTOCOL_V100) + logger.warning("=" * 70) + logger.warning( + f"[STEP 1] SERVER 
PROCESSING WANTLIST from {str(peer_id)[:20]}..." + ) + logger.warning(f" entries={len(wantlist.entries)} protocol={peer_protocol}") + logger.warning("=" * 70) + + # ── Standard 1.0.0–1.2.0 wantlist handling ──────────────────────── # Process entries blocks_to_send_v100 = [] # For v1.0.0 blocks_to_send_v110 = [] # For v1.1.0+ presences_to_send = [] # For v1.2.0 for entry in wantlist.entries: - entry_cid = parse_cid(entry.block) + try: + logger.warning(f" -> Processing entry: {bytes(entry.block).hex()}") + entry_cid = parse_cid(entry.block) + logger.warning(f" -> Parsed CID: {entry_cid}") + except Exception as e: + logger.warning(f" -> EXCEPTION in parse_cid: {e}") + continue + if entry.cancel: # Remove from peer's wantlist if entry_cid in peer_wantlist: @@ -630,44 +840,229 @@ async def _process_wantlist( } # Check if we have this block - has_block = await self.block_store.has_block(entry_cid) + logger.warning(f" -> Checking if we have block {entry_cid}") + try: + has_block = await self.block_store.has_block(entry_cid) + logger.warning(f" -> has_block result: {has_block}") + except Exception as e: + logger.warning(f" -> EXCEPTION in has_block: {e}") + has_block = False + + logger.warning( + f"[WANTLIST ENTRY] " + f"cid={format_cid_for_display(entry_cid, max_len=16)} " + f"wantType={entry.wantType} cancel={entry.cancel} " + f"has_block={has_block}" + ) # Handle based on want type (v1.2.0) - if entry.wantType == 1: # Have request - # Send presence information - if has_block or entry.sendDontHave: - presences_to_send.append((entry_cid, has_block)) - else: # Block request + if entry.wantType == 1: # Have request (WANT_HAVE) + if has_block: + # Send the block directly — do NOT send a separate HAVE + # presence. Sending HAVE causes Go's bitswap session to + # open a NEW outbound WANT_BLOCK stream to Python. That + # stream fails due to Python TLS limitations, so Go never + # receives the block. Sending the block directly (implicit + # HAVE) is the correct interop approach. 
+ data = await self.block_store.get_block(entry_cid) + if data: + print( + f"\n[WANT_HAVE] Sending block directly " + f"({len(data)} bytes) for " + f"{format_cid_for_display(entry_cid, max_len=16)}", + flush=True, + ) + logger.warning( + f"[WANT_HAVE] Sending block directly " + f"({len(data)} bytes) for " + f"{format_cid_for_display(entry_cid, max_len=16)} " + f"(skipping HAVE presence to avoid Go re-request)" + ) + if peer_protocol == BITSWAP_PROTOCOL_V100: + blocks_to_send_v100.append(data) + else: + prefix = get_cid_prefix(entry_cid) + blocks_to_send_v110.append((prefix, data)) + else: + # Don't have the block — send DontHave so requester + # knows to look elsewhere. + print( + f"\n[WANT_HAVE] DontHave for " + f"{format_cid_for_display(entry_cid, max_len=16)}", + flush=True, + ) + logger.warning( + f"[WANT_HAVE] Sending DontHave for " + f"{format_cid_for_display(entry_cid, max_len=16)}" + ) + presences_to_send.append((entry_cid, False)) + else: # Block request (WANT_BLOCK) if has_block: data = await self.block_store.get_block(entry_cid) if data: + print( + f"\n[WANT_BLOCK] Sending block directly " + f"({len(data)} bytes) for " + f"{format_cid_for_display(entry_cid, max_len=16)}", + flush=True, + ) + logger.warning( + f"[WANT_BLOCK] Sending block for " + f"{format_cid_for_display(entry_cid, max_len=16)}" + ) if peer_protocol == BITSWAP_PROTOCOL_V100: blocks_to_send_v100.append(data) else: prefix = get_cid_prefix(entry_cid) blocks_to_send_v110.append((prefix, data)) - elif entry.sendDontHave: - # Send DontHave (v1.2.0) + else: + # Always send DontHave when we don't have the block, + # regardless of sendDontHave flag. This prevents the + # requester from stalling waiting for a response. 
presences_to_send.append((entry_cid, False)) - # Send responses + # Send responses in batches to stay under MAX_MESSAGE_SIZE + # and Noise protocol limit (65535 bytes) if blocks_to_send_v100 or blocks_to_send_v110 or presences_to_send: - response_msg = create_message( - blocks_v100=blocks_to_send_v100 if blocks_to_send_v100 else None, - blocks_v110=blocks_to_send_v110 if blocks_to_send_v110 else None, - block_presences=presences_to_send if presences_to_send else None, - ) - logger.debug(f"Sending response message to {peer_id} on stream {stream}") - await self._write_message(stream, response_msg) - logger.debug(f"Response message sent to {peer_id}") - - if blocks_to_send_v100 or blocks_to_send_v110: - count = len(blocks_to_send_v100) + len(blocks_to_send_v110) - logger.debug(f"Sent {count} blocks to peer {peer_id}") - if presences_to_send: - logger.debug( - f"Sent {len(presences_to_send)} block presences to peer {peer_id}" + if self._nursery is not None: + self._nursery.start_soon( + self._send_wantlist_responses_bg, # type: ignore + peer_id, + str(peer_protocol), + blocks_to_send_v100, + blocks_to_send_v110, + presences_to_send, ) + else: + # Fallback to writing to the inbound stream if nursery is not available. + # This works for Python-to-Python tests, but may fail for + # Go-libp2p interop. + await self._send_wantlist_responses_inline( + stream, + peer_id, + blocks_to_send_v100, + blocks_to_send_v110, + presences_to_send, + ) + + async def _send_wantlist_responses_bg( + self, + peer_id: PeerID, + peer_protocol: str, + blocks_to_send_v100: list[bytes], + blocks_to_send_v110: list[tuple[bytes, bytes]], + presences_to_send: list[tuple[CIDObject, bool]], + ) -> None: + """Background task to send responses over a new outbound stream.""" + # We MUST open a new stream to the client to send the blocks. + # Writing to the inbound stream that the client opened for their WANTLIST + # is often ignored by the client (Kubo), as it expects dial back. 
+ try: + outbound_stream = await self.host.new_stream( + peer_id, [TProtocol(peer_protocol)] + ) + except Exception as e: + logger.error(f"Failed to open outbound stream to send response: {e}") + return + + try: + await self._send_wantlist_responses_inline( + outbound_stream, + peer_id, + blocks_to_send_v100, + blocks_to_send_v110, + presences_to_send, + ) + finally: + await outbound_stream.close() + + async def _send_wantlist_responses_inline( + self, + stream: INetStream, + peer_id: PeerID, + blocks_to_send_v100: list[bytes], + blocks_to_send_v110: list[tuple[bytes, bytes]], + presences_to_send: list[tuple[CIDObject, bool]], + ) -> None: + """Helper to send blocks on a specific stream.""" + # Send blocks in batches + if blocks_to_send_v100: + await self._send_blocks_in_batches_v100( + blocks_to_send_v100, peer_id, stream + ) + if blocks_to_send_v110: + await self._send_blocks_in_batches_v110( + blocks_to_send_v110, peer_id, stream + ) + # Send presences (usually small, can send all at once) + if presences_to_send: + presence_msg = create_message(block_presences=presences_to_send) + await self._write_message(stream, presence_msg) + + async def _send_blocks_in_batches_v100( + self, blocks: list[bytes], peer_id: PeerID, stream: INetStream + ) -> None: + """Send blocks in batches to stay under message size limit.""" + # Noise protocol limit is 65535 bytes per message + # Reserve some space for protobuf overhead + MAX_BATCH_SIZE = 60000 # ~60KB per message for safety + + batch: list[bytes] = [] + batch_size = 0 + + for block_data in blocks: + block_size = len(block_data) + + # If adding this block would exceed limit, send current batch first + if batch and (batch_size + block_size > MAX_BATCH_SIZE): + msg = create_message(blocks_v100=batch) + await self._write_message(stream, msg) + logger.debug(f"Sent batch of {len(batch)} blocks to peer {peer_id}") + batch = [] + batch_size = 0 + + batch.append(block_data) + batch_size += block_size + + # Send remaining blocks + if 
batch: + msg = create_message(blocks_v100=batch) + await self._write_message(stream, msg) + logger.debug(f"Sent final batch of {len(batch)} blocks to peer {peer_id}") + + async def _send_blocks_in_batches_v110( + self, + blocks: list[tuple[bytes, bytes]], + peer_id: PeerID, + stream: INetStream, + ) -> None: + """Send blocks (v1.1.0+ format) in batches to stay under message size limit.""" + # Noise protocol limit is 65535 bytes per message + # Reserve some space for protobuf overhead + MAX_BATCH_SIZE = 60000 # ~60KB per message for safety + + batch: list[tuple[bytes, bytes]] = [] + batch_size = 0 + + for prefix, block_data in blocks: + block_size = len(prefix) + len(block_data) + + # If adding this block would exceed limit, send current batch first + if batch and (batch_size + block_size > MAX_BATCH_SIZE): + msg = create_message(blocks_v110=batch) + await self._write_message(stream, msg) + logger.debug(f"Sent batch of {len(batch)} blocks to peer {peer_id}") + batch = [] + batch_size = 0 + + batch.append((prefix, block_data)) + batch_size += block_size + + # Send remaining blocks + if batch: + msg = create_message(blocks_v110=batch) + await self._write_message(stream, msg) + logger.debug(f"Sent final batch of {len(batch)} blocks to peer {peer_id}") async def _process_blocks_v100(self, blocks: list[bytes], peer_id: PeerID) -> None: """ @@ -905,3 +1300,55 @@ async def _write_message(self, stream: INetStream, msg: Message) -> None: # Write length prefix and message length_prefix = varint.encode(len(msg_bytes)) await stream.write(length_prefix + msg_bytes) + + async def _write_message_bytes(self, stream: INetStream, msg_bytes: bytes) -> None: + """ + Write pre-serialized message bytes (for 1.3.0 Message_1_3 objects). 
+ """ + if len(msg_bytes) > MAX_MESSAGE_SIZE: + raise MessageTooLargeError( + f"Message size {len(msg_bytes)} exceeds maximum {MAX_MESSAGE_SIZE}" + ) + length_prefix = varint.encode(len(msg_bytes)) + await stream.write(length_prefix + msg_bytes) + + async def _process_block_presences_1_3( + self, presences: Any, peer_id: PeerID + ) -> None: + """ + Process block presences from a 1.3.0 message. + Handles PaymentRequired (type=2) in addition to Have/DontHave. + """ + for presence in presences: + cid_bytes = bytes(presence.cid) + try: + cid = parse_cid(cid_bytes) + except Exception: + continue + + presence_type = presence.type + + if presence_type == 0: # Have + if peer_id not in self._expected_blocks: + self._expected_blocks[peer_id] = set() + self._expected_blocks[peer_id].add(cid) + logger.debug( + f"[1.3.0] Peer {peer_id} has block " + f"{format_cid_for_display(cid, max_len=16)}" + ) + elif presence_type == 1: # DontHave + if cid not in self._dont_have_responses: + self._dont_have_responses[cid] = set() + self._dont_have_responses[cid].add(peer_id) + logger.info( + f"[1.3.0] Peer {peer_id} doesn't have block " + f"{format_cid_for_display(cid, max_len=16)}" + ) + elif presence_type == 2: # PaymentRequired + logger.info( + f"[1.3.0] Peer {peer_id} requires payment for block " + f"{format_cid_for_display(cid, max_len=16)} " + f"(PaymentTerms will follow in same message)" + ) + # The payment_client will handle PaymentTerms + # in process_incoming_message diff --git a/libp2p/bitswap/config.py b/libp2p/bitswap/config.py index 87ba26e0e..028103100 100644 --- a/libp2p/bitswap/config.py +++ b/libp2p/bitswap/config.py @@ -8,6 +8,7 @@ BITSWAP_PROTOCOL_V100 = TProtocol("/ipfs/bitswap/1.0.0") BITSWAP_PROTOCOL_V110 = TProtocol("/ipfs/bitswap/1.1.0") BITSWAP_PROTOCOL_V120 = TProtocol("/ipfs/bitswap/1.2.0") +BITSWAP_PROTOCOL_V130 = TProtocol("/ipfs/bitswap/1.3.0") # All supported protocols (ordered from newest to oldest for negotiation) BITSWAP_PROTOCOLS = [ @@ -22,12 +23,13 @@ # 
Maximum message size (4MiB as per spec) MAX_MESSAGE_SIZE = 4 * 1024 * 1024 -# Maximum block size (63 KB - matches DEFAULT_CHUNK_SIZE in chunker.py) +# Maximum block size (63 KB - after DAG-PB/UnixFS encoding) # py-libp2p stream limit is ~64 KB, so we use 63 KB to be safe -MAX_BLOCK_SIZE = 63 * 1024 +# Note: Raw chunk data should be smaller to account for DAG-PB overhead (~14 bytes) +MAX_BLOCK_SIZE = 512 * 1024 # Default timeout for operations (in seconds) -DEFAULT_TIMEOUT = 30 +DEFAULT_TIMEOUT = 90 # Maximum number of concurrent block requests MAX_CONCURRENT_REQUESTS = 100 diff --git a/libp2p/bitswap/dag.py b/libp2p/bitswap/dag.py index 98ce469db..e06df18ce 100644 --- a/libp2p/bitswap/dag.py +++ b/libp2p/bitswap/dag.py @@ -1,631 +1,930 @@ -""" -Merkle DAG manager for file operations. - -This module provides a high-level API for adding and fetching files -using the Bitswap protocol with automatic chunking, linking, and -multi-block resolution. - -""" - -from collections.abc import Awaitable, Callable -import inspect -import logging -from typing import Union - -from libp2p.peer.id import ID as PeerID - -from .block_store import BlockStore -from .chunker import ( - DEFAULT_CHUNK_SIZE, - chunk_bytes, - chunk_file, - estimate_chunk_count, - get_file_size, -) -from .cid import ( - CODEC_DAG_PB, - CODEC_RAW, - CIDInput, - cid_to_bytes, - compute_cid_v1, - format_cid_for_display, - verify_cid, -) -from .client import BitswapClient -from .dag_pb import ( - create_file_node, - decode_dag_pb, - is_directory_node, - is_file_node, -) - -logger = logging.getLogger(__name__) - - -# Type alias for progress callbacks (sync or async) -ProgressCallback = Union[ - Callable[[int, int, str], None], - Callable[[int, int, str], Awaitable[None]], -] - - -async def _call_progress_callback( - callback: ProgressCallback | None, - current: int, - total: int, - status: str, -) -> None: - """Call a progress callback, handling both sync and async callbacks.""" - if callback is None: - return - - 
if inspect.iscoroutinefunction(callback): - await callback(current, total, status) - else: - callback(current, total, status) - - -class MerkleDag: - """ - Merkle DAG manager for file operations. - - Provides high-level API for adding and fetching files with automatic - chunking, link creation, and recursive block fetching. - - Example: - >>> from libp2p import new_host - >>> from libp2p.bitswap import BitswapClient, MemoryBlockStore, MerkleDag - >>> import trio - >>> - >>> async def main(): - ... host = new_host() - ... async with host.run(["/ip4/0.0.0.0/tcp/0"]): - ... store = MemoryBlockStore() - ... bitswap = BitswapClient(host, store) - ... await bitswap.start() - ... - ... dag = MerkleDag(bitswap) - ... - ... # Add a large file (auto-chunked) - ... root_cid = await dag.add_file('movie.mp4') - ... print(f"Share: {cid_to_text(root_cid)}") - ... - ... # Fetch file (auto-resolves all chunks) - ... data = await dag.fetch_file(root_cid) - ... open('downloaded.mp4', 'wb').write(data) - ... - >>> trio.run(main) - - """ - - def __init__(self, bitswap: BitswapClient, block_store: BlockStore | None = None): - """ - Initialize Merkle DAG manager. - - Args: - bitswap: Bitswap client for block exchange - block_store: Optional block store (uses bitswap's store if None) - - """ - self.bitswap = bitswap - self.block_store = block_store or bitswap.block_store - - async def add_file( - self, - file_path: str, - chunk_size: int | None = None, - progress_callback: Callable[[int, int, str], None] | None = None, - wrap_with_directory: bool = True, - ) -> bytes: - """ - Add a file to the DAG. - - Automatically chunks large files and creates link structure. - Small files are stored as single blocks. 
- - Args: - file_path: Path to file - chunk_size: Optional chunk size (auto-selected if None) - progress_callback: Optional callback(current, total, status) - wrap_with_directory: If True, wraps file in a directory node with filename - (IPFS-standard way, enables filename preservation) - - Returns: - Root CID of the file (or wrapping directory if wrap_with_directory=True) - - Raises: - FileNotFoundError: If file doesn't exist - BlockTooLargeError: If a single chunk exceeds MAX_BLOCK_SIZE - - Example: - >>> async def progress(current, total, status): - ... print(f"{status}: {current}/{total}") - >>> root_cid = await dag.add_file('movie.mp4', progress_callback=progress) - >>> print(f"Share this: {cid_to_text(root_cid)}") - - """ - # Get file size - file_size = get_file_size(file_path) - logger.info(f"Adding file: {file_path} ({file_size} bytes)") - - # Determine chunk size - if chunk_size is None: - chunk_size = DEFAULT_CHUNK_SIZE - - logger.debug(f"Using chunk size: {chunk_size} bytes") - - # If file is small enough, store as single RAW block - if file_size <= chunk_size: - logger.debug("File fits in single block") - - with open(file_path, "rb") as f: - data = f.read() - - cid = compute_cid_v1(data, codec=CODEC_RAW) - - await self.bitswap.add_block(cid, data) - - if progress_callback: - await _call_progress_callback( - progress_callback, file_size, file_size, "completed" - ) - - logger.info( - f"Added file as single block: {format_cid_for_display(cid, max_len=16)}" - ) - - # Wrap in directory if requested - if wrap_with_directory: - import os - - from .dag_pb import create_directory_node - - filename = os.path.basename(file_path) - logger.info( - f"Wrapping single-block file in directory with name: {filename}" - ) - - dir_data = create_directory_node([(filename, cid, file_size)]) - dir_cid = compute_cid_v1(dir_data, codec=CODEC_DAG_PB) - await self.bitswap.add_block(dir_cid, dir_data) - - logger.info( - f"Created directory wrapper. 
Directory CID: " - f"{format_cid_for_display(dir_cid, max_len=16)}" - ) - return dir_cid - - return cid - - # Chunk the file - estimated_chunks = estimate_chunk_count(file_size, chunk_size) - logger.debug(f"Chunking file into ~{estimated_chunks} chunks") - logger.info("=== Starting file chunking process ===") - - chunks_data: list[tuple[bytes, int]] = [] - bytes_processed = 0 - - # Process file in chunks (memory efficient) - for i, chunk_data in enumerate(chunk_file(file_path, chunk_size)): - # Compute CID for chunk - chunk_cid = compute_cid_v1(chunk_data, codec=CODEC_RAW) - - # Store chunk - await self.bitswap.add_block(chunk_cid, chunk_data) - - # Track chunk info - chunks_data.append((chunk_cid, len(chunk_data))) - bytes_processed += len(chunk_data) - - # Progress callback - if progress_callback: - await _call_progress_callback( - progress_callback, - bytes_processed, - file_size, - f"chunking ({i + 1} chunks)", - ) - - # Enhanced logging with full CID - logger.info( - f"Chunk {i + 1}: CID={format_cid_for_display(chunk_cid)}, " - f"Size={len(chunk_data)} bytes, " - f"Progress={bytes_processed}/{file_size}" - ) - logger.debug( - f"Stored chunk {i}: {format_cid_for_display(chunk_cid, max_len=16)} " - f"({len(chunk_data)} bytes)" - ) - - # Create root node with links to all chunks - if progress_callback: - await _call_progress_callback( - progress_callback, file_size, file_size, "creating root node" - ) - - root_data = create_file_node(chunks_data) - root_cid = compute_cid_v1(root_data, codec=CODEC_DAG_PB) - await self.bitswap.add_block(root_cid, root_data) - - # Enhanced logging for root CID - logger.info("=== File chunking completed ===") - logger.info( - f"Root CID: {format_cid_for_display(root_cid)} " - f"(Links to {len(chunks_data)} chunks)" - ) - logger.info(f"Total file size: {file_size} bytes") - logger.info("=== Chunk CIDs ===") - for i, (chunk_cid, chunk_size) in enumerate(chunks_data): - logger.info( - f" Chunk {i}: {format_cid_for_display(chunk_cid)} 
({chunk_size} bytes)" - ) - logger.info("=" * 50) - - logger.info( - f"Added file with {len(chunks_data)} chunks. " - f"Root CID: {format_cid_for_display(root_cid, max_len=16)}" - ) - - if progress_callback: - await _call_progress_callback( - progress_callback, file_size, file_size, "completed" - ) - - # Wrap in directory if requested (IPFS-standard way for filename preservation) - if wrap_with_directory: - import os - - from .dag_pb import create_directory_node - - filename = os.path.basename(file_path) - logger.info(f"Wrapping file in directory with name: {filename}") - - # Create directory node with single entry pointing to the file - dir_data = create_directory_node([(filename, root_cid, file_size)]) - dir_cid = compute_cid_v1(dir_data, codec=CODEC_DAG_PB) - await self.bitswap.add_block(dir_cid, dir_data) - - logger.info( - "Created directory wrapper. Directory CID: " - f"{format_cid_for_display(dir_cid, max_len=16)}" - ) - return dir_cid - - return root_cid - - async def add_bytes( - self, - data: bytes, - chunk_size: int | None = None, - progress_callback: Callable[[int, int, str], None] | None = None, - ) -> bytes: - """ - Add bytes to the DAG (similar to add_file but for in-memory data). 
- - Args: - data: Data to add - chunk_size: Optional chunk size (auto-selected if None) - progress_callback: Optional callback(current, total, status) - - Returns: - Root CID - - Example: - >>> data = b"x" * (10 * 1024 * 1024) # 10 MB - >>> root_cid = await dag.add_bytes(data) - - """ - file_size = len(data) - logger.info(f"Adding {file_size} bytes") - - # Determine chunk size - if chunk_size is None: - chunk_size = DEFAULT_CHUNK_SIZE - - # If data is small, store as single block - if file_size <= chunk_size: - cid = compute_cid_v1(data, codec=CODEC_RAW) - await self.bitswap.add_block(cid, data) - - if progress_callback: - await _call_progress_callback( - progress_callback, file_size, file_size, "completed" - ) - - return cid - - # Chunk the data - chunks = chunk_bytes(data, chunk_size) - chunks_data: list[tuple[bytes, int]] = [] - - for i, chunk_data in enumerate(chunks): - chunk_cid = compute_cid_v1(chunk_data, codec=CODEC_RAW) - await self.bitswap.add_block(chunk_cid, chunk_data) - chunks_data.append((chunk_cid, len(chunk_data))) - - if progress_callback: - bytes_processed = sum(size for _, size in chunks_data) - await _call_progress_callback( - progress_callback, - bytes_processed, - file_size, - f"chunking ({i + 1}/{len(chunks)})", - ) - - # Create root node - root_data = create_file_node(chunks_data) - root_cid = compute_cid_v1(root_data, codec=CODEC_DAG_PB) - await self.bitswap.add_block(root_cid, root_data) - - if progress_callback: - await _call_progress_callback( - progress_callback, file_size, file_size, "completed" - ) - - return root_cid - - async def fetch_file( - self, - root_cid: CIDInput, - peer_id: PeerID | None = None, - timeout: float = 30.0, - progress_callback: Callable[[int, int, str], None] | None = None, - ) -> tuple[bytes, str | None]: - """ - Fetch a file from the DAG. - - Automatically resolves links and fetches all chunks. Works with both - single-block files and multi-chunk files. 
Everything is handled - automatically - just provide the root CID! - - The method automatically: - - Detects directory wrappers and extracts filename - - Fetches and decodes the root block - - Determines file size and number of chunks - - Fetches all chunks in sequence - - Verifies integrity of all blocks - - Reconstructs the complete file - - Args: - root_cid: Root CID of the file (or directory wrapper) - peer_id: Optional specific peer to fetch from - timeout: Timeout per block in seconds - progress_callback: Optional callback(current, total, status) - Receives metadata automatically in first call - - Returns: - Tuple of (file_data, filename) where filename is None if not - wrapped in directory - - Raises: - BlockNotFoundError: If any block cannot be found - ValueError: If CID verification fails - - Example: - >>> # Simple usage - just provide root CID - >>> data, filename = await dag.fetch_file(root_cid) - >>> save_path = filename or 'downloaded_file' - >>> open(save_path, 'wb').write(data) - - >>> # With progress tracking - >>> def progress(current, total, status): - ... percent = (current / total) * 100 if total > 0 else 0 - ... print(f"{status}: {percent:.1f}%") - >>> data, filename = await dag.fetch_file( - ... root_cid, progress_callback=progress - ... 
) - - """ - root_cid_bytes = cid_to_bytes(root_cid) - logger.info( - f"Fetching file: {format_cid_for_display(root_cid_bytes, max_len=16)}" - ) - logger.info( - "=== Starting file fetch for CID: " - f"{format_cid_for_display(root_cid_bytes)} ===" - ) - - # Get root block - root_data = await self.bitswap.get_block(root_cid_bytes, peer_id, timeout) - - # Verify root block - if not verify_cid(root_cid_bytes, root_data): - raise ValueError( - "Root block verification failed: " - f"{format_cid_for_display(root_cid_bytes)}" - ) - - # Check if it's a directory wrapper (IPFS-standard way for filename) - filename = None - actual_file_cid = root_cid_bytes - actual_file_data = root_data - - if is_directory_node(root_data): - logger.info("Root is a directory node, extracting file entry...") - links, _ = decode_dag_pb(root_data) - - if links: - # Get the first (and typically only) file entry - first_link = links[0] - filename = first_link.name if first_link.name else None - actual_file_cid = first_link.cid - - logger.info(f"Extracted filename: {filename}") - logger.info( - f"Actual file CID: " - f"{format_cid_for_display(actual_file_cid, max_len=16)}" - ) - - # Fetch the actual file block - actual_file_data = await self.bitswap.get_block( - actual_file_cid, peer_id, timeout - ) - - if not verify_cid(actual_file_cid, actual_file_data): - raise ValueError( - "File block verification failed: " - f"{format_cid_for_display(actual_file_cid)}" - ) - - # Now process the actual file data - # Check if it's a DAG-PB file node - if is_file_node(actual_file_data): - logger.debug("Root is a DAG-PB file node, resolving chunks...") - - # Decode to get links and metadata - links, unixfs_data = decode_dag_pb(actual_file_data) - - if not links: - # File with inline data (small file) - logger.debug("File has inline data") - file_data = ( - unixfs_data.data if unixfs_data and unixfs_data.data else b"" - ) - - # Notify progress callback with metadata - if progress_callback: - await 
_call_progress_callback( - progress_callback, - len(file_data), - len(file_data), - f"metadata: size={len(file_data)}, chunks=0", - ) - - return file_data, filename - - # File with multiple chunks - total_size = unixfs_data.filesize if unixfs_data else 0 - logger.debug(f"File has {len(links)} chunks, total size: {total_size}") - logger.info( - f"Fetching multi-chunk file: {len(links)} chunks, {total_size} bytes" - ) - logger.info("=== Chunk CIDs to fetch ===") - for i, link in enumerate(links): - logger.info( - f" Chunk {i}: {format_cid_for_display(link.cid)} " - f"({link.size} bytes)" - ) - logger.info("=" * 50) - - # Notify progress callback with file metadata at the start - if progress_callback: - await _call_progress_callback( - progress_callback, - 0, - total_size, - f"metadata: size={total_size}, chunks={len(links)}", - ) - - file_data = b"" - bytes_fetched = 0 - - # Fetch each chunk - for i, link in enumerate(links): - if progress_callback: - await _call_progress_callback( - progress_callback, - bytes_fetched, - total_size, - f"fetching chunk {i + 1}/{len(links)}", - ) - - logger.info( - f"Fetching chunk {i + 1}/{len(links)}: " - f"CID={format_cid_for_display(link.cid)}" - ) - - # Fetch chunk - chunk_data = await self.bitswap.get_block(link.cid, peer_id, timeout) - - # Verify chunk - if not verify_cid(link.cid, chunk_data): - raise ValueError( - f"Chunk verification failed: {format_cid_for_display(link.cid)}" - ) - - file_data += chunk_data - bytes_fetched += len(chunk_data) - - logger.info( - f"✓ Chunk {i + 1} fetched and verified: " - f"{len(chunk_data)} bytes (total: {bytes_fetched}/{total_size})" - ) - logger.debug( - f"Fetched chunk {i + 1}/{len(links)}: " - f"{format_cid_for_display(link.cid, max_len=16)} " - f"({len(chunk_data)} bytes)" - ) - - if progress_callback: - await _call_progress_callback( - progress_callback, total_size, total_size, "completed" - ) - - logger.info("=== File fetch completed ===") - logger.info(f"Total bytes fetched: 
{len(file_data)}") - logger.info(f"All {len(links)} chunks verified successfully") - logger.info("=" * 50) - logger.info(f"Fetched file: {len(file_data)} bytes") - return file_data, filename - - # Not a DAG-PB file node - return as raw data - logger.debug("Root is a raw block, returning directly") - return actual_file_data, filename - - async def get_file_info( - self, root_cid: CIDInput, peer_id: PeerID | None = None, timeout: float = 30.0 - ) -> dict[str, int | list[int]]: - """ - Get information about a file without downloading it. - - Args: - root_cid: Root CID of the file - peer_id: Optional specific peer to fetch from - timeout: Timeout in seconds (default: 30.0) - - Returns: - Dictionary with file information: - - size: Total file size in bytes - - chunks: Number of chunks - - chunk_sizes: List of chunk sizes - - Example: - >>> info = await dag.get_file_info(root_cid) - >>> print(f"File size: {info['size']} bytes") - >>> print(f"Chunks: {info['chunks']}") - - """ - # Get root block - root_cid_bytes = cid_to_bytes(root_cid) - root_data = await self.bitswap.get_block(root_cid_bytes, peer_id, timeout) - - # Check if it's a DAG-PB file node - if is_file_node(root_data): - links, unixfs_data = decode_dag_pb(root_data) - - if not links: - # Small file with inline data - data_size = ( - len(unixfs_data.data) if unixfs_data and unixfs_data.data else 0 - ) - return {"size": data_size, "chunks": 0, "chunk_sizes": []} - - # Multi-chunk file - total_size = ( - unixfs_data.filesize - if unixfs_data - else sum(link.size for link in links) - ) - chunk_sizes = [link.size for link in links] - - return { - "size": total_size, - "chunks": len(links), - "chunk_sizes": chunk_sizes, - } - - # Single raw block - return {"size": len(root_data), "chunks": 1, "chunk_sizes": [len(root_data)]} - - -__all__ = ["MerkleDag"] +""" +Merkle DAG manager for file operations. 
+ +This module provides a high-level API for adding and fetching files +using the Bitswap protocol with automatic chunking, linking, and +multi-block resolution. + +""" + +from collections.abc import Awaitable, Callable +import inspect +import io +import logging +from typing import Union + +from libp2p.peer.id import ID as PeerID + +from .block_service import BlockService +from .block_store import BlockStore +from .chunker import ( + DEFAULT_CHUNK_SIZE, + chunk_bytes, + chunk_file, + chunk_stream, + estimate_chunk_count, + get_file_size, +) +from .cid import ( + CODEC_DAG_PB, + CODEC_RAW, + CIDInput, + cid_to_bytes, + compute_cid_v1, + format_cid_for_display, + verify_cid, +) +from .client import BitswapClient +from .dag_pb import ( + balanced_layout, + decode_dag_pb, + is_directory_node, + is_file_node, +) +from .errors import BlockNotFoundError + +logger = logging.getLogger(__name__) + + +# Type alias for progress callbacks (sync or async) +ProgressCallback = Union[ + Callable[[int, int, str], None], + Callable[[int, int, str], Awaitable[None]], +] + + +async def _call_progress_callback( + callback: ProgressCallback | None, + current: int, + total: int, + status: str, +) -> None: + """Call a progress callback, handling both sync and async callbacks.""" + if callback is None: + return + + if inspect.iscoroutinefunction(callback): + await callback(current, total, status) + else: + callback(current, total, status) + + +class MerkleDag: + """ + Merkle DAG manager for file operations. + + Provides high-level API for adding and fetching files with automatic + chunking, link creation, and recursive block fetching. + + Example: + >>> from libp2p import new_host + >>> from libp2p.bitswap import BitswapClient, MemoryBlockStore, MerkleDag + >>> import trio + >>> + >>> async def main(): + ... host = new_host() + ... async with host.run(["/ip4/0.0.0.0/tcp/0"]): + ... store = MemoryBlockStore() + ... bitswap = BitswapClient(host, store) + ... await bitswap.start() + ... + ... 
def __init__(
    self,
    bitswap: BitswapClient,
    block_store: BlockStore | None = None,
    block_service: BlockService | None = None,
):
    """
    Initialize Merkle DAG manager.

    Args:
        bitswap: Bitswap client for block exchange
        block_store: Optional block store (uses bitswap's store if None)
        block_service: Optional BlockService for transparent local→network
            fallback with auto-caching. When provided, all block
            reads/writes go through it instead of bitswap directly.
            Construct with: BlockService(your_store, bitswap)

    """
    self.bitswap = bitswap
    # Compare against None explicitly: a store object may be falsy (for
    # example an empty store that defines __len__), and `or` would then
    # silently replace it with bitswap's own store.
    self.block_store = (
        block_store if block_store is not None else bitswap.block_store
    )
    # If a BlockService is provided use it; otherwise the routing helpers
    # fall back to calling bitswap directly.
    self._service: BlockService | None = block_service
async def _get_blocks_batch(
    self,
    cids: list[CIDInput],
    peer_id: PeerID | None = None,
    timeout: float = 30.0,
    batch_size: int = 32,
) -> dict[bytes, bytes]:
    """
    Batch-fetch blocks. Routes through BlockService when available.

    This is a best-effort operation: blocks that cannot be fetched are simply
    absent from the returned mapping, so callers must check for missing keys.

    Args:
        cids: CIDs of the blocks to fetch
        peer_id: Optional specific peer to fetch from
        timeout: Timeout in seconds passed through to the underlying fetch
        batch_size: Batch size passed through to the underlying batch fetch

    Returns:
        Mapping of CID bytes to block data for every block that was fetched

    """
    if self._service is not None:
        return await self._service.get_blocks_batch(
            cids, peer_id=peer_id, timeout=timeout, batch_size=batch_size
        )

    # Use the client's native batch fetch when it offers one.
    native_batch: Callable[..., Awaitable[dict[bytes, bytes]]] | None = getattr(
        self.bitswap, "get_blocks_batch", None
    )
    if native_batch is not None and callable(native_batch):
        try:
            result = await native_batch(
                cids, peer_id=peer_id, timeout=timeout, batch_size=batch_size
            )
        except Exception:
            # Best effort: fall through to per-block fetching, but leave a
            # trace instead of hiding the failure completely.
            logger.debug(
                "Native batch fetch failed; falling back to per-block fetch",
                exc_info=True,
            )
        else:
            # Only trust a plain dict (guards against non-dict results,
            # e.g. from a mocked client).
            if isinstance(result, dict):
                return result

    # Fall back to individual _get_block calls.
    results: dict[bytes, bytes] = {}
    for cid in cids:
        cid_bytes = cid_to_bytes(cid)
        try:
            results[cid_bytes] = await self._get_block(
                cid_bytes, peer_id=peer_id, timeout=timeout
            )
        except Exception:
            # A block that cannot be fetched is reported by omission.
            logger.debug(
                "Could not fetch block %s",
                format_cid_for_display(cid_bytes),
                exc_info=True,
            )
    return results
async def add_file(
    self,
    file_path: str,
    chunk_size: int | None = None,
    progress_callback: ProgressCallback | None = None,
    wrap_with_directory: bool = False,
) -> bytes:
    """
    Add a file to the DAG.

    Files no larger than the chunk size are stored as a single raw block.
    Larger files are split into raw-leaf chunks that are linked together by
    a balanced tree of DAG-PB nodes.

    Args:
        file_path: Path to file
        chunk_size: Optional chunk size (DEFAULT_CHUNK_SIZE if None)
        progress_callback: Optional sync or async callback(current, total, status)
        wrap_with_directory: If True, wraps file in a directory node with filename
            (IPFS-standard way, enables filename preservation)

    Returns:
        Root CID of the file (or wrapping directory if wrap_with_directory=True)

    Example:
        >>> async def progress(current, total, status):
        ...     print(f"{status}: {current}/{total}")
        >>> root_cid = await dag.add_file('movie.mp4', progress_callback=progress)

    """
    import os

    file_size = get_file_size(file_path)
    logger.info(f"Adding file: {file_path} ({file_size} bytes)")

    if chunk_size is None:
        chunk_size = DEFAULT_CHUNK_SIZE
    logger.debug(f"Using chunk size: {chunk_size} bytes")

    # Small file: store as a single raw leaf block (raw codec CID).
    if file_size <= chunk_size:
        logger.debug("File fits in single block")

        with open(file_path, "rb") as f:
            data = f.read()

        cid = compute_cid_v1(data, codec=CODEC_RAW)
        await self._put_block(cid, data)
        await _call_progress_callback(
            progress_callback, file_size, file_size, "completed"
        )
        logger.info(
            f"Added file as single raw block: "
            f"{format_cid_for_display(cid, max_len=16)}"
        )

        if not wrap_with_directory:
            return cid

        filename = os.path.basename(file_path)
        logger.info(f"Wrapping single-block file in directory with name: {filename}")
        # Tsize for a raw leaf is the raw file size (no block overhead).
        return await self._wrap_in_directory(filename, cid, file_size)

    estimated_chunks = estimate_chunk_count(file_size, chunk_size)
    logger.debug(f"Chunking file into ~{estimated_chunks} chunks")
    logger.info("=== Starting file chunking process ===")

    # leaf_triples: (cid_bytes, leaf_block_bytes, raw_data_size).
    # NOTE: balanced_layout() takes the leaf block bytes, so every chunk
    # stays referenced here until the tree has been built.
    leaf_triples: list[tuple[bytes, bytes, int]] = []
    bytes_processed = 0

    for i, chunk_data in enumerate(chunk_file(file_path, chunk_size)):
        # Raw leaf: the chunk bytes are the block, addressed by a raw codec CID.
        chunk_cid = compute_cid_v1(chunk_data, codec=CODEC_RAW)
        await self._put_block(chunk_cid, chunk_data)
        leaf_triples.append((chunk_cid, chunk_data, len(chunk_data)))
        bytes_processed += len(chunk_data)

        await _call_progress_callback(
            progress_callback,
            bytes_processed,
            file_size,
            f"chunking ({i + 1} chunks)",
        )
        logger.info(
            f"Chunk {i + 1}: CID={format_cid_for_display(chunk_cid)}, "
            f"Size={len(chunk_data)} bytes, "
            f"Progress={bytes_processed}/{file_size}"
        )
        logger.debug(
            f"Stored leaf {i}: {format_cid_for_display(chunk_cid, max_len=16)} "
            f"({len(chunk_data)} bytes)"
        )

    await _call_progress_callback(
        progress_callback, file_size, file_size, "creating root node"
    )

    # balanced_layout() reports nodes through a synchronous callback, while
    # _put_block is async: collect the nodes first and store them afterwards.
    internal_nodes: list[tuple[bytes, bytes]] = []

    def store_internal_node(cid: bytes, data: bytes) -> None:
        """Collect an internal node for storage after the layout is built."""
        internal_nodes.append((cid, data))

    root_cid, root_data, root_tsize = balanced_layout(
        leaf_triples, put_block_callback=store_internal_node
    )

    logger.info(f"Storing {len(internal_nodes)} internal DAG nodes...")
    stored_cids: set[bytes] = set()
    for node_cid, node_data in internal_nodes:
        await self._put_block(node_cid, node_data)
        stored_cids.add(node_cid)

    # The root is normally reported through the callback as well; only
    # store it here when it was not, to avoid writing it twice.
    if root_cid not in stored_cids:
        await self._put_block(root_cid, root_data)

    logger.info("=== File chunking completed ===")
    logger.info(
        f"Root CID: {format_cid_for_display(root_cid)} "
        f"(Balanced DAG over {len(leaf_triples)} leaves)"
    )
    logger.info(f"Total file size: {file_size} bytes")
    logger.info("=" * 50)
    logger.info(
        f"Added file with {len(leaf_triples)} leaves. "
        f"Root CID: {format_cid_for_display(root_cid, max_len=16)}"
    )

    await _call_progress_callback(
        progress_callback, file_size, file_size, "completed"
    )

    if not wrap_with_directory:
        return root_cid

    filename = os.path.basename(file_path)
    logger.info(f"Wrapping file in directory with name: {filename}")
    # Tsize is the cumulative block size returned by balanced_layout().
    return await self._wrap_in_directory(filename, root_cid, root_tsize)

async def _wrap_in_directory(
    self, filename: str, target_cid: bytes, tsize: int
) -> bytes:
    """Store a directory node linking `filename` to `target_cid`; return its CID."""
    from .dag_pb import create_directory_node

    dir_data = create_directory_node([(filename, target_cid, tsize)])
    dir_cid = compute_cid_v1(dir_data, codec=CODEC_DAG_PB)
    await self._put_block(dir_cid, dir_data)

    logger.info(
        f"Created directory wrapper. Directory CID: "
        f"{format_cid_for_display(dir_cid, max_len=16)}"
    )
    return dir_cid
async def add_bytes(
    self,
    data: bytes,
    chunk_size: int | None = None,
    progress_callback: ProgressCallback | None = None,
) -> bytes:
    """
    Add bytes to the DAG (similar to add_file but for in-memory data).

    Args:
        data: Data to add
        chunk_size: Optional chunk size (DEFAULT_CHUNK_SIZE if None)
        progress_callback: Optional sync or async callback(current, total, status)

    Returns:
        Root CID

    Example:
        >>> data = b"x" * (10 * 1024 * 1024)  # 10 MB
        >>> root_cid = await dag.add_bytes(data)

    """
    file_size = len(data)
    logger.info(f"Adding {file_size} bytes")

    if chunk_size is None:
        chunk_size = DEFAULT_CHUNK_SIZE

    # Small payload: store as a single raw leaf block.
    if file_size <= chunk_size:
        cid = compute_cid_v1(data, codec=CODEC_RAW)
        await self._put_block(cid, data)
        await _call_progress_callback(
            progress_callback, file_size, file_size, "completed"
        )
        return cid

    # Chunk the data into raw leaves.
    chunks = chunk_bytes(data, chunk_size)
    total_chunks = len(chunks)
    leaf_triples: list[tuple[bytes, bytes, int]] = []
    bytes_processed = 0  # running total; avoids re-summing every iteration

    for i, chunk_data in enumerate(chunks):
        chunk_cid = compute_cid_v1(chunk_data, codec=CODEC_RAW)
        await self._put_block(chunk_cid, chunk_data)
        leaf_triples.append((chunk_cid, chunk_data, len(chunk_data)))
        bytes_processed += len(chunk_data)

        await _call_progress_callback(
            progress_callback,
            bytes_processed,
            file_size,
            f"chunking ({i + 1}/{total_chunks})",
        )

    # Build the balanced tree. Every internal node must be stored, not just
    # the root: with more leaves than fit in one node the tree has
    # intermediate levels, and without them the DAG cannot be fetched.
    internal_nodes: list[tuple[bytes, bytes]] = []

    def collect_internal_node(cid: bytes, block: bytes) -> None:
        """Collect an internal node; _put_block is async, the callback is not."""
        internal_nodes.append((cid, block))

    root_cid, root_data, _tsize = balanced_layout(
        leaf_triples, put_block_callback=collect_internal_node
    )

    stored_cids: set[bytes] = set()
    for node_cid, node_data in internal_nodes:
        await self._put_block(node_cid, node_data)
        stored_cids.add(node_cid)
    if root_cid not in stored_cids:
        await self._put_block(root_cid, root_data)

    await _call_progress_callback(
        progress_callback, file_size, file_size, "completed"
    )

    return root_cid
async def add_stream(
    self,
    stream: io.IOBase,
    chunk_size: int | None = None,
    progress_callback: ProgressCallback | None = None,
) -> bytes:
    """
    Add data from any io.IOBase stream to the DAG.

    More flexible than add_file() because it accepts any readable stream,
    not just a file path. The stream is read one chunk at a time.

    NOTE: every chunk is kept referenced until the tree is built, because
    balanced_layout() takes the leaf block bytes. Peak memory therefore
    grows with the amount of data read, not just with chunk_size.

    Args:
        stream: Any readable io.IOBase — open() handles, BytesIO,
            GzipFile, BZ2File, network streams, pipes, etc.
        chunk_size: Optional chunk size in bytes (DEFAULT_CHUNK_SIZE if None)
        progress_callback: Optional sync or async callback(current, total, status).
            The total is unknown for streams, so both current and total
            are reported as the bytes processed so far.

    Returns:
        Root CID bytes of the stored DAG

    Example:
        >>> import io
        >>> root_cid = await dag.add_stream(io.BytesIO(b"hello world"))

        >>> with open("movie.mp4", "rb") as f:
        ...     root_cid = await dag.add_stream(f)

    """
    if chunk_size is None:
        chunk_size = DEFAULT_CHUNK_SIZE

    leaf_triples: list[tuple[bytes, bytes, int]] = []
    bytes_processed = 0

    for i, chunk_data in enumerate(chunk_stream(stream, chunk_size)):
        # Raw leaf: the chunk bytes are the block, addressed by a raw codec CID.
        chunk_cid = compute_cid_v1(chunk_data, codec=CODEC_RAW)
        await self._put_block(chunk_cid, chunk_data)
        leaf_triples.append((chunk_cid, chunk_data, len(chunk_data)))
        bytes_processed += len(chunk_data)

        await _call_progress_callback(
            progress_callback,
            bytes_processed,
            bytes_processed,
            f"chunking ({i + 1} chunks, {bytes_processed} bytes)",
        )

    if not leaf_triples:
        # Empty stream — store a single empty raw block.
        root_cid = compute_cid_v1(b"", codec=CODEC_RAW)
        await self._put_block(root_cid, b"")
    elif len(leaf_triples) == 1:
        # Single chunk — the leaf CID is the root (no root node needed).
        root_cid = leaf_triples[0][0]
    else:
        # Multiple chunks — build the balanced tree and store every internal
        # node, not just the root; otherwise large DAGs are unfetchable.
        internal_nodes: list[tuple[bytes, bytes]] = []

        def collect_internal_node(cid: bytes, block: bytes) -> None:
            """Collect an internal node; _put_block is async, the callback is not."""
            internal_nodes.append((cid, block))

        root_cid, root_data, _tsize = balanced_layout(
            leaf_triples, put_block_callback=collect_internal_node
        )

        stored_cids: set[bytes] = set()
        for node_cid, node_data in internal_nodes:
            await self._put_block(node_cid, node_data)
            stored_cids.add(node_cid)
        if root_cid not in stored_cids:
            await self._put_block(root_cid, root_data)

    # Report completion on every path so progress consumers always finish.
    await _call_progress_callback(
        progress_callback, bytes_processed, bytes_processed, "completed"
    )

    return root_cid
async def fetch_file(
    self,
    root_cid: CIDInput,
    peer_id: PeerID | None = None,
    timeout: float = 30.0,
    progress_callback: ProgressCallback | None = None,
) -> tuple[bytes, str | None]:
    """
    Fetch a file from the DAG and return its reassembled bytes.

    The method:
    - Fetches and verifies the root block
    - Detects a directory wrapper and extracts the filename of its first link
    - Batch-fetches the DAG level by level, verifying every block's CID
    - Concatenates the leaf data in link order

    Args:
        root_cid: Root CID of the file (or directory wrapper)
        peer_id: Optional specific peer to fetch from
        timeout: Timeout in seconds passed to each block/batch fetch
        progress_callback: Optional sync or async callback(current, total, status)

    Returns:
        Tuple of (file_data, filename) where filename is None if not
        wrapped in directory

    Raises:
        BlockNotFoundError: If any block of the DAG cannot be fetched
        ValueError: If CID verification fails for any block

    Example:
        >>> data, filename = await dag.fetch_file(root_cid)
        >>> save_path = filename or 'downloaded_file'
        >>> open(save_path, 'wb').write(data)

    """
    root_cid_bytes = cid_to_bytes(root_cid)
    logger.info(f"Fetching file: {format_cid_for_display(root_cid_bytes)}")

    # Step 1: Fetch the root block
    root_data = await self._get_block(root_cid_bytes, peer_id, timeout)
    if not verify_cid(root_cid_bytes, root_data):
        root_cid_str = format_cid_for_display(root_cid_bytes)
        raise ValueError(f"Root block CID verification failed: {root_cid_str}")

    # Step 2: Handle directory wrapper
    filename = None
    actual_file_data = root_data

    if is_directory_node(root_data):
        logger.info("Root is a directory node — extracting filename and file CID")
        dir_links, _ = decode_dag_pb(root_data)
        if dir_links:
            first_link = dir_links[0]
            filename = first_link.name or None
            actual_file_cid = first_link.cid
            logger.info(f"Filename from directory: {filename!r}")
            actual_file_data = await self._get_block(
                actual_file_cid, peer_id, timeout
            )
            if not verify_cid(actual_file_cid, actual_file_data):
                f_cid_str = format_cid_for_display(actual_file_cid)
                err_msg = f"File block CID verification failed: {f_cid_str}"
                raise ValueError(err_msg)

    # Step 3: Handle raw block (not a DAG-PB file node)
    if not is_file_node(actual_file_data):
        logger.info(f"Root is a raw block: {len(actual_file_data)} bytes")
        return actual_file_data, filename

    # Step 4: Parse the file node
    top_links, top_unixfs = decode_dag_pb(actual_file_data)
    filesize = top_unixfs.filesize if top_unixfs else 0
    total_size = filesize or sum(lnk.size for lnk in top_links)
    msg = f"File node: {len(top_links)} top-level links, total size={total_size}"
    logger.info(f"{msg} bytes")

    # Step 5: Small file with inline data (no links)
    if not top_links:
        file_data = top_unixfs.data if top_unixfs and top_unixfs.data else b""
        logger.info(f"Inline file data: {len(file_data)} bytes")
        data_len = len(file_data)
        await _call_progress_callback(
            progress_callback, data_len, data_len, "completed"
        )
        return file_data, filename

    # Step 6: Batch-fetch the whole tree, one level per iteration
    top_len = len(top_links)
    msg1 = f"[DAG] Recursively batch-fetching DAG tree ({top_len} top links)..."
    logger.info(msg1)
    msg2 = f"[FETCH] Recursively batch-fetching DAG tree ({top_len} top links)..."
    print(msg2, flush=True)

    all_blocks_map: dict[bytes, bytes] = {}  # every verified block, by CID
    unresolved: list[bytes] = []  # CIDs that could not be fetched
    pending: list[bytes] = [top_link.cid for top_link in top_links]
    depth = 1

    while pending:
        # Identical chunks share a CID: request each distinct CID only once
        # and skip anything already fetched at a shallower level.
        wanted = [c for c in dict.fromkeys(pending) if c not in all_blocks_map]
        if not wanted:
            break

        c_count = len(wanted)
        logger.info(f"[DAG] Depth {depth}: batch-fetching {c_count} blocks...")
        print(f"[FETCH] Depth {depth}: batch-fetching {c_count} blocks...", flush=True)

        # Copy so a retry can be merged without mutating the provider's dict.
        level_blocks = dict(
            await self._get_blocks_batch(
                list(wanted), peer_id=peer_id, timeout=timeout, batch_size=32
            )
        )
        logger.info(f"[DAG] Depth {depth}: ✓ received {len(level_blocks)} blocks")

        # Give blocks that did not arrive one second chance before failing.
        absent = [c for c in wanted if c not in level_blocks]
        if absent:
            logger.info(f"[DAG] Depth {depth}: retrying {len(absent)} missing blocks")
            level_blocks.update(
                await self._get_blocks_batch(
                    list(absent), peer_id=peer_id, timeout=timeout, batch_size=32
                )
            )

        child_cids: list[bytes] = []
        for cid_bytes in wanted:
            block_data = level_blocks.get(cid_bytes)
            if block_data is None:
                c_str = format_cid_for_display(cid_bytes)
                msg = f"[DAG] Depth {depth}: block {c_str} missing after"
                logger.warning(f"{msg} fetch")
                unresolved.append(cid_bytes)
                continue

            if not verify_cid(cid_bytes, block_data):
                c_str = format_cid_for_display(cid_bytes)
                raise ValueError(f"Block CID verification failed: {c_str}")

            all_blocks_map[cid_bytes] = block_data

            if is_file_node(block_data):
                node_links, _ = decode_dag_pb(block_data)
                cid_str = format_cid_for_display(cid_bytes)
                msg = f"[DAG] Depth {depth}: {cid_str} has {len(node_links)}"
                logger.debug(f"{msg} children")
                child_cids.extend(link.cid for link in node_links)

        if child_cids:
            ch_count = len(child_cids)
            msg = f"[DAG] Depth {depth}: found {ch_count} child CIDs"
            logger.info(f"{msg}, fetching next level...")
        pending = child_cids
        depth += 1

    # A hole anywhere in the tree would silently truncate the file, so an
    # incomplete tree is an error rather than something to skip over.
    if unresolved:
        missing_count = len(unresolved)
        logger.error(f"[DAG] ✗ {missing_count} blocks missing after batch fetch!")
        missing_list = [format_cid_for_display(cid) for cid in unresolved[:5]]
        raise BlockNotFoundError(f"{missing_count} blocks missing: {missing_list}...")

    blocks_count = len(all_blocks_map)
    logger.info(f"[DAG] ✓ Tree fetch complete: {blocks_count} total blocks")
    print(f"[FETCH] ✓ Tree fetch complete: {blocks_count} total blocks", flush=True)

    # Step 7: Walk the fetched tree locally to list leaf CIDs in file order
    ordered_leaf_cids: list[bytes] = []

    def _collect_leaves_local(cid_bytes: bytes, depth: int = 1) -> None:
        """Append the leaf CIDs below `cid_bytes` in link order."""
        block_data = all_blocks_map[cid_bytes]
        if is_file_node(block_data):
            node_links, _ = decode_dag_pb(block_data)
            logger.debug(f"[DAG] Depth {depth}: {len(node_links)} links")
            if node_links:
                # Intermediate node: descend into each child in order.
                for child_link in node_links:
                    _collect_leaves_local(child_link.cid, depth + 1)
                return
            logger.debug(f"[DAG] Depth {depth}: file node with inline data (leaf)")
        else:
            logger.debug(f"[DAG] Depth {depth}: raw block (leaf)")
        ordered_leaf_cids.append(cid_bytes)

    for i, top_link in enumerate(top_links):
        logger.info(f"[DAG] Traversing top-level {i + 1}/{len(top_links)}...")
        _collect_leaves_local(top_link.cid, depth=1)

    total_leaves = len(ordered_leaf_cids)
    logger.info(f"[DAG] ✓ Collected {total_leaves} leaf blocks")

    await _call_progress_callback(
        progress_callback,
        0,
        total_size,
        f"fetching {total_leaves} leaf blocks in batches",
    )

    # Step 8: Reassemble data in order. Chunks are gathered in a list and
    # joined once; repeated `bytes +=` would copy the whole file per chunk.
    parts: list[bytes] = []
    bytes_fetched = 0
    for idx, leaf_cid in enumerate(ordered_leaf_cids):
        leaf_raw = all_blocks_map[leaf_cid]

        if is_file_node(leaf_raw):
            # DAG-PB leaf: the payload is the inline UnixFS data.
            _, leaf_unixfs = decode_dag_pb(leaf_raw)
            if leaf_unixfs is not None and leaf_unixfs.data:
                chunk = leaf_unixfs.data
            else:
                chunk = b""
            msg = f"[DAG] Leaf {idx + 1}: extracted {len(chunk)} bytes"
            logger.debug(f"{msg} from file node")
        else:
            # Raw leaf: the block bytes are the payload.
            chunk = leaf_raw
            logger.debug(f"[DAG] Leaf {idx + 1}: raw block {len(chunk)} bytes")

        parts.append(chunk)
        bytes_fetched += len(chunk)

        if (idx + 1) % 10 == 0 or idx == total_leaves - 1:
            p_str = f"{bytes_fetched}/{total_size} bytes"
            logger.info(f"[DAG] Reassembled {idx + 1}/{total_leaves} leaves: {p_str}")
            print(
                f"[FETCH] Reassembled {idx + 1}/{total_leaves} leaves: {p_str}",
                flush=True,
            )

        await _call_progress_callback(
            progress_callback, bytes_fetched, total_size, "downloading"
        )

    file_data = b"".join(parts)

    await _call_progress_callback(
        progress_callback, total_size, total_size, "completed"
    )

    file_len = len(file_data)
    msg = f"[DAG] ✓ File fetch complete: {file_len} bytes, filename={filename!r}"
    logger.info(msg)
    print(f"[FETCH] ✓ DOWNLOAD COMPLETE: {file_len} bytes", flush=True)
    return file_data, filename
def _encode_varint(value: int) -> bytes:
    """
    Encode an unsigned integer as a protobuf varint.

    The value is emitted seven bits at a time, least-significant group first;
    every byte except the last has its high bit set.

    Args:
        value: Non-negative integer to encode

    Returns:
        The varint encoding of `value`

    Raises:
        ValueError: If `value` is negative

    """
    if value < 0:
        # A negative value would otherwise be masked down to a single,
        # silently wrong byte.
        raise ValueError(f"Cannot encode negative value as varint: {value}")

    encoded = bytearray()
    while value > 0x7F:
        encoded.append((value & 0x7F) | 0x80)
        value >>= 7
    encoded.append(value)
    return bytes(encoded)
+ return cid_obj.buffer @dataclass(init=False) @@ -103,38 +126,42 @@ def encode_dag_pb(links: list[Link], unixfs_data: UnixFSData | None = None) -> b >>> encoded = encode_dag_pb(links, data) """ - # Create PBNode - pb_node = PBNode() + # DAG-PB canonical format requires Links (field 2) BEFORE Data (field 1). + # Standard protobuf SerializeToString() emits fields in field-number order + # (Data=1 first, Links=2 second), producing different bytes and a different + # CID than Kubo for the same logical content. + # We manually construct the wire format to enforce the correct ordering. + + result = b"" - # Add links + # 1. Serialize each Link first — field 2, wire type 2 (length-delimited) = tag 0x12 for link in links: - pb_link = pb_node.Links.add() + pb_link = PBLink() pb_link.Hash = link.cid pb_link.Name = link.name pb_link.Tsize = link.size + link_bytes = pb_link.SerializeToString() + result += b"\x12" + _encode_varint(len(link_bytes)) + link_bytes - # Add UnixFS data if provided - if unixfs_data: - # Create UnixFS data structure + # 2. 
def create_leaf_node(data: bytes) -> bytes:
    """
    Build the DAG-PB block for a single file chunk.

    The chunk is wrapped in a UnixFS file entry carrying the bytes inline
    (filesize set to the chunk length) and encoded as a node without links.
    NOTE(review): this is the non-raw-leaf form; confirm against the Kubo
    settings the caller intends to match.

    Args:
        data: Raw chunk bytes (may be empty for an empty file)

    Returns:
        Encoded DAG-PB bytes, suitable for storage as a dag-pb block

    """
    leaf_payload = UnixFSData(type="file", data=data, filesize=len(data))
    no_links: list = []
    return encode_dag_pb(no_links, leaf_payload)
def balanced_layout(
    leaves: list[tuple[bytes, bytes, int]],
    max_links: int = MAX_LINKS_PER_NODE,
    put_block_callback: Callable[[bytes, bytes], None] | None = None,
) -> tuple[bytes, bytes, int]:
    """
    Build a balanced Merkle DAG from a flat list of leaf blocks.

    Groups the current level into batches of `max_links`, creates an internal
    DAG-PB node for each batch, then repeats level by level until a single
    root remains. A trailing batch with a single entry is carried up to the
    next level unchanged. Intended to mirror Go's balanced.Layout.

    Args:
        leaves: List of (cid_bytes, block_bytes, file_data_size) tuples where
            - cid_bytes: CID of the leaf block as raw bytes
            - block_bytes: The stored leaf block bytes
            - file_data_size: Size of the file data inside this leaf
        max_links: Max links per internal node (default MAX_LINKS_PER_NODE);
            must be at least 2 when there is more than one leaf
        put_block_callback: Optional *synchronous* callback invoked as
            callback(cid_bytes, block_bytes) for every internal node created,
            including the root. It is called directly, never awaited.

    Returns:
        (root_cid_bytes, root_block_bytes, cumulative_tsize)
        where cumulative_tsize = len(root_block) + sum of all descendant
        block sizes.

    Raises:
        ValueError: If leaves is empty, or if max_links < 2 while more than
            one leaf has to be combined

    """
    if not leaves:
        raise ValueError("Cannot build balanced layout from empty leaf list")

    if len(leaves) == 1:
        only_cid, only_block, _ = leaves[0]
        return only_cid, only_block, len(only_block)

    if max_links < 2:
        # With fewer than two links per node a level can never shrink, so
        # the loop below would not terminate.
        raise ValueError(f"max_links must be at least 2, got {max_links}")

    # Each level entry: (cid_bytes, block_bytes, file_data_size, cumulative_size)
    # cumulative_size = len(this block) + sum(children's cumulative sizes)
    level: list[tuple[bytes, bytes, int, int]] = [
        (cid, blk, fsize, len(blk)) for cid, blk, fsize in leaves
    ]

    while len(level) > 1:
        next_level: list[tuple[bytes, bytes, int, int]] = []
        for start in range(0, len(level), max_links):
            batch = level[start : start + max_links]
            if len(batch) == 1:
                # A lone trailing entry needs no wrapper node at this level.
                next_level.append(batch[0])
                continue

            # Tsize of each link = cumulative block size of that child's subtree.
            internal_links = [
                Link(cid=child_cid, name="", size=child_cum)
                for child_cid, _, _, child_cum in batch
            ]
            blocksizes = [fsize for _, _, fsize, _ in batch]
            total_filesize = sum(blocksizes)
            total_cum = sum(child_cum for _, _, _, child_cum in batch)

            unixfs_data = UnixFSData(
                type="file", filesize=total_filesize, blocksizes=blocksizes
            )
            internal_block = encode_dag_pb(internal_links, unixfs_data)
            internal_cid = compute_cid_v1(internal_block, codec=CODEC_DAG_PB)

            if put_block_callback is not None:
                put_block_callback(internal_cid, internal_block)

            # cumulative size = own block + sum of children's cumulative sizes
            cum_size = len(internal_block) + total_cum
            next_level.append((internal_cid, internal_block, total_filesize, cum_size))
        level = next_level

    root_cid, root_block, _, root_cum = level[0]
    return root_cid, root_block, root_cum
TYPE_CHECKING, Any + +from libp2p.abc import INetStream +from libp2p.bitswap.client import BitswapClient +from libp2p.peer.id import ID as PeerID + +if not TYPE_CHECKING: + from libp2p.bitswap.client import BitswapClient + + +class IBitswapExtension(ABC): + """ + Abstract base class for protocol-bound Bitswap extensions. + Extensions are registered for specific protocol versions to handle messages. + """ + + client: "BitswapClient" + + def set_client(self, client: BitswapClient) -> None: + """ + Set the parent BitswapClient instance. + """ + self.client = client + + @abstractmethod + async def process_message( + self, peer_id: PeerID, msg_bytes: bytes, stream: INetStream + ) -> bool: + """ + Process an incoming message. + + Args: + peer_id: The ID of the peer sending the message + msg_bytes: The raw bytes of the incoming message + stream: The network stream to communicate back + + Returns: + True if the extension fully handled the message and no further + processing is required. + False if normal processing should continue. + + """ + pass + + @abstractmethod + async def process_wantlist( + self, wantlist: Any, peer_id: PeerID, stream: INetStream + ) -> bool: + """ + Process a wantlist specifically. + + Args: + wantlist: The Wantlist protobuf object + peer_id: The ID of the peer + stream: The network stream + + Returns: + True if the extension handled the wantlist fully. + False if BitswapClient should process it normally. + + """ + pass diff --git a/libp2p/bitswap/gated_decision_engine.py b/libp2p/bitswap/gated_decision_engine.py new file mode 100644 index 000000000..e71668c72 --- /dev/null +++ b/libp2p/bitswap/gated_decision_engine.py @@ -0,0 +1,526 @@ +""" +Payment-Gated Decision Engine for Bitswap 1.3.0. + +Extends the standard Bitswap block serving logic with payment gating: +- If a block is free, serve it directly. +- If a block requires payment and the peer has NOT paid, respond with + PaymentRequired (type=2) + PaymentTerms in-band (1.3.0 path). 
+- If the peer sends a TxReceipt (on-chain payment proof), verify it + and serve the block. + +Proto alignment: + PaymentTerms → fields: cid, asset, pay_to, amount, network, block_size, description + TxReceipt → fields: cid, tx_hash, from_address, to_address, amount, asset, network + PaymentReceipt → fields: cid, tx_hash, expires + PaymentRejection → fields: cid, reason + +This module lives in py-libp2p so it's importable as libp2p.bitswap. +""" + +import logging +import time +from typing import Any + +from libp2p.bitswap.block_store import BlockStore +from libp2p.bitswap.cid import parse_cid +from libp2p.bitswap.pb.bitswap_1_3_0_pb2 import Message as Message_1_3 +from libp2p.bitswap.pb.bitswap_pb2 import Message as Message_1_2 + +logger = logging.getLogger(__name__) + +BITSWAP_PROTOCOL_V120 = "/ipfs/bitswap/1.2.0" +BITSWAP_PROTOCOL_V130 = "/ipfs/bitswap/1.3.0" + + +class PaymentGatedDecisionEngine: + """ + Decides whether to serve a block or gate it behind payment. + + Integrates with: + - payments.ledger.PaymentLedger — tracks paid (peer, cid) pairs + - payments.pricing.BlockPricingEngine — computes prices + - payments.tx_verifier.TxVerifier — verifies on-chain TxReceipts + + Payment flow (1.3.0): + 1. Client sends WANT_BLOCK + 2. Server → PaymentRequired + PaymentTerms (price offer) + 3. Client pays on-chain, sends TxReceipt with tx_hash + 4. 
Server verifies tx on-chain → PaymentReceipt + block data + """ + + def __init__( + self, + blockstore: BlockStore, + ledger: Any, # payments.ledger.PaymentLedger + pricing: Any, # payments.pricing.BlockPricingEngine + tx_verifier: Any, # payments.tx_verifier.TxVerifier (or None) + server_wallet: str = "", + network: str = "sepolia", + asset: str = "ETH", + ): + self.blockstore = blockstore + self.ledger = ledger + self.pricing = pricing + self.tx_verifier = tx_verifier + self.server_wallet = server_wallet + self.network = network + self.asset = asset + + # Track pending payment offers: cid_hex → (peer_id, terms) + self._pending_offers: dict[str, tuple[str, Any]] = {} + + # Callbacks for sending messages back to peers + self.send_message_callback = None + + # Root CID tracking: cid_hex → {root_cid, total_size, child_count} + # Used to compute total file size for pricing + self._dag_info: dict[str, dict[str, Any]] = {} + + async def register_dag( + self, + root_cid: str | bytes, + child_cids: list[str | bytes], + total_size: int, + ) -> None: + """ + Register a DAG structure for root CID payment tracking. + + Call this after chunking a file to register the relationship between + the root CID and its child blocks, along with the total file size. + + Args: + root_cid: The root CID of the DAG + child_cids: List of child/chunk CIDs + total_size: Total size of all blocks combined (bytes) + + Example: + >>> # After adding a large file to Bitswap + >>> await engine.register_dag( + ... root_cid=root_cid, + ... child_cids=[chunk1, chunk2, ...], + ... total_size=5_000_000, # 5 MB + ... ) + + """ + root_hex = _cid_to_str(root_cid) + + # Store DAG metadata + self._dag_info[root_hex] = { + "root_cid": root_hex, + "total_size": total_size, + "child_count": len(child_cids), + } + + # Register in ledger so child blocks inherit root payment status + await self.ledger.register_dag(root_cid, child_cids) + + logger.info( + f"📋 Registered DAG: root={root_hex[:20]}... 
" + f"size={total_size}B children={len(child_cids)}" + ) + + def mark_free(self, cid: str | bytes) -> None: + """ + Mark a CID as free (no payment required). + + Args: + cid: The CID to mark as free (root or child) + + """ + self.ledger.mark_free(cid) + self.pricing.set_free(cid) + logger.info(f"Marked as FREE: {_cid_to_str(cid)[:20]}...") + + async def handle_want( + self, + peer_id: str, + cid: str | bytes, + want_type: int, # 0 = WANT_BLOCK, 1 = WANT_HAVE + send_dont_have: bool, + peer_protocol: str = BITSWAP_PROTOCOL_V120, + ) -> Message_1_3 | Message_1_2 | None: + """ + Process a WANT request from a peer. + Returns a Message to send back, or None if nothing should be sent. + """ + cid_str = _cid_to_str(cid) + cid_bytes = _cid_to_bytes(cid) + cid_obj = parse_cid(cid_bytes) + + logger.info( + f"🔍 handle_want: peer={peer_id[:20]}... cid={cid_str[:20]}... " + f"want_type={want_type} protocol={peer_protocol}" + ) + + # Check blockstore + logger.info( + "All CIDs in blockstore: " + + ", ".join([c.hex() for c in self.blockstore.get_all_cids()]) + ) + block_data = await self.blockstore.get_block(cid_obj) + + if block_data is None: + logger.warning(f"❌ Block not in store: {cid_str[:20]}...") + if send_dont_have: + return self._make_dont_have(cid_bytes, peer_protocol) + return None + + block_size = len(block_data) + logger.info(f"✅ Block found: {cid_str[:20]}... size={block_size}") + + # Get pricing size (use total DAG size if this is part of a DAG) + pricing_size = self._get_pricing_size(cid_str, block_size) + + # Compute price (at root CID level, not per-block) + price = self.pricing.compute_price(cid_str, pricing_size) + logger.info( + f"💰 Price: {price} units for {cid_str[:20]}... 
" + f"(block={block_size}B, pricing_size={pricing_size}B)" + ) + + # Check if free or already paid (ledger resolves child → root automatically) + is_paid = self.ledger.is_paid(peer_id, cid_str) + + if price == 0: + # Free block — serve it + logger.info(f"✅ Serving block (FREE): {cid_str[:20]}...") + if want_type == 1: # WANT_HAVE + return self._make_have(cid_bytes, peer_protocol) + else: # WANT_BLOCK + return self._make_block_response(cid_bytes, block_data, peer_protocol) + elif is_paid: + # Already paid with sufficient amount — serve it + logger.info( + f"✅ Serving block (ALREADY PAID): {cid_str[:20]}... " + f"price={price} units" + ) + if want_type == 1: # WANT_HAVE + return self._make_have(cid_bytes, peer_protocol) + else: # WANT_BLOCK + return self._make_block_response(cid_bytes, block_data, peer_protocol) + else: + # Payment required + if peer_protocol == BITSWAP_PROTOCOL_V130: + logger.info(f"💳 Payment required: {price} units for {cid_str[:20]}...") + return self._make_payment_required_1_3( + peer_id, cid_bytes, pricing_size, price + ) + else: + logger.warning( + f"⚠️ Payment required but peer on {peer_protocol}, " + f"sending DONT_HAVE" + ) + if send_dont_have: + return self._make_dont_have(cid_bytes, peer_protocol) + return None + + async def handle_payment_authorization( + self, + peer_id: str, + auth: Any, # Message_1_3.PaymentAuthorization + ) -> Message_1_3: + """ + Process a PaymentAuthorization from a client (EIP-3009 signed payment). + Verifies the signature and serves the block if valid. + """ + cid_bytes = bytes(auth.cid) + cid_str = cid_bytes.hex() + from_address = auth.from_address + + logger.warning("=" * 70) + logger.warning( + f"[STEP 6b] SERVER handle_payment_authorization: peer={peer_id[:20]}... " + f"cid={cid_str[:20]}... from={from_address[:12]}... 
value={auth.value}" + ) + logger.warning("=" * 70) + + # Check if already paid (ledger hit — no need to re-verify) + cid_obj = parse_cid(cid_bytes) + block_data = await self.blockstore.get_block(cid_obj) + + if block_data is None: + return self._make_payment_rejection(cid_bytes, "BLOCK_NOT_FOUND") + + block_size = len(block_data) + pricing_size = self._get_pricing_size(cid_str, block_size) + expected_price = self.pricing.compute_price(cid_str, pricing_size) + + # Check if already paid (ledger resolves child → root automatically) + if self.ledger.is_paid(peer_id, cid_str): + # Already in ledger with sufficient payment — serve immediately + logger.info( + f"✅ Already paid (ledger hit): {cid_str[:20]}... " + f"block_size={block_size}B expected_price={expected_price}" + ) + return self._make_receipt_and_block(cid_bytes, "", block_data) + + # Validate payment amount matches expected price + if auth.value < expected_price: + error_msg = ( + f"INSUFFICIENT_PAYMENT: paid={auth.value}, " + f"expected={expected_price} for {block_size}B block" + ) + logger.warning(f"❌ {error_msg}") + return self._make_payment_rejection(cid_bytes, error_msg) + + # Verify EIP-3009 signature + logger.warning("=" * 70) + logger.warning("[STEP 7] SERVER VERIFYING EIP-3009 SIGNATURE") + logger.warning(f" from={from_address[:20]}...") + logger.warning(f" to={auth.to_address[:20]}...") + logger.warning(f" value={auth.value} expected={expected_price}") + verifier_status = ( + "configured" + if self.tx_verifier is not None + else "NOT CONFIGURED (optimistic mode)" + ) + logger.warning(f" verifier={verifier_status}") + logger.warning("=" * 70) + if self.tx_verifier is not None: + try: + # The tx_verifier is actually a FacilitatorClient for EIP-3009 + result = await self.tx_verifier.verify( + from_address=from_address, + to_address=auth.to_address, + value=auth.value, + valid_after=auth.valid_after, + valid_before=auth.valid_before, + nonce=bytes(auth.nonce), + v=auth.v, + r=bytes(auth.r), + 
s=bytes(auth.s), + ) + valid = result.valid + error = result.error + except Exception as e: + logger.error(f"[STEP 7] VERIFICATION EXCEPTION: {e}", exc_info=True) + valid, error = False, str(e) + + if not valid: + logger.warning("=" * 70) + logger.warning(f"[STEP 7] ❌ EIP-3009 VERIFICATION FAILED: {error}") + logger.warning("=" * 70) + return self._make_payment_rejection( + cid_bytes, error or "INVALID_SIGNATURE" + ) + else: + logger.warning("[STEP 7] ✅ EIP-3009 VERIFICATION PASSED") + else: + # No verifier configured — optimistic mode: trust the authorization + logger.warning( + "[STEP 7] ⚠️ No payment verifier configured — accepting " + "PaymentAuthorization optimistically" + ) + + # Record payment in ledger + try: + await self.ledger.record_payment( + peer_id=peer_id, + cid=cid_bytes, + tx_hash="", # No on-chain tx for EIP-3009 + amount=auth.value, + nonce=bytes(auth.nonce), + ) + except ValueError as e: + # Duplicate nonce — already recorded + logger.info(f"Payment already recorded: {e}") + + logger.warning("=" * 70) + logger.warning("[STEP 8b] ✅ SERVER PAYMENT ACCEPTED — SENDING BLOCK TO CLIENT") + logger.warning( + f" cid={cid_str[:20]}... value={auth.value} expected={expected_price} " + f"block_size={block_size}B (EIP-3009)" + ) + logger.warning("=" * 70) + return self._make_receipt_and_block(cid_bytes, "", block_data) + + async def process_incoming_1_3_message( + self, peer_id: str, msg: Message_1_3 + ) -> Message_1_3 | None: + """ + Process an incoming 1.3.0 message that may contain PaymentAuthorizations. + Returns a response message or None. + """ + if msg.payment_authorizations: # type: ignore[attr-defined] + for auth in msg.payment_authorizations: # type: ignore[attr-defined] + return await self.handle_payment_authorization(peer_id, auth) + return None + + # ── Internal helpers ────────────────────────────────────────────────── + + def _get_pricing_size(self, cid_str: str, block_size: int) -> int: + """ + Get the size to use for pricing calculation. 
+ + NEW PAYMENT MODEL: For root CIDs, use total DAG size. + For child CIDs, pricing is N/A (they inherit root payment). + + Args: + cid_str: The CID (hex string) + block_size: The actual block size + + Returns: + Size in bytes to use for pricing + + """ + # Check if this is a registered DAG root + dag_info = self._dag_info.get(cid_str) + if dag_info: + # This is a root CID - use total DAG size for pricing + total_size = dag_info["total_size"] + logger.info( + f"💡 CID {cid_str[:20]}... is DAG root: " + f"block_size={block_size}B, total_size={total_size}B" + ) + return total_size + + # Not a registered root CID - use block size (backward compatibility) + # This handles: old files, single-block files, or child blocks + logger.debug( + f"CID {cid_str[:20]}... not a registered DAG root, " + f"using block_size={block_size}B for pricing" + ) + return block_size + + def _make_payment_required_1_3( + self, + peer_id: str, + cid_bytes: bytes, + block_size: int, + amount: int, + ) -> Message_1_3: + """Build a 1.3.0 PaymentRequired message with embedded PaymentTerms.""" + import secrets + import time + + msg = Message_1_3() + + # BlockPresence with type=2 (PaymentRequired) + presence = msg.blockPresences.add() + presence.cid = cid_bytes + presence.type = Message_1_3.BlockPresenceType.PaymentRequired # = 2 + + # PaymentTerms — all fields including nonce, valid_before, scheme + terms = msg.payment_terms.add() + terms.cid = cid_bytes + terms.asset = self.asset + terms.pay_to = self.server_wallet + terms.amount = amount + terms.network = self.network + terms.nonce = secrets.token_bytes(32) # type: ignore[attr-defined] + terms.valid_before = int(time.time()) + 3600 # type: ignore[attr-defined] + terms.block_size = block_size + terms.description = ( + f"Block {cid_bytes.hex()[:20]}... ({block_size // 1024}KB) — " + f"pay {amount} wei to {self.server_wallet[:10]}..." + ) + terms.scheme = "EIP3009" # type: ignore[attr-defined] + + logger.info( + f"📤 PaymentRequired → {peer_id[:20]}... 
" + f"cid={cid_bytes.hex()[:20]}... amount={amount} asset={self.asset}" + ) + return msg + + def _make_receipt_and_block( + self, cid_bytes: bytes, tx_hash: str, block_data: bytes + ) -> Message_1_3: + """Build a PaymentReceipt + block payload message.""" + msg = Message_1_3() + + receipt = msg.payment_receipts.add() + receipt.cid = cid_bytes + receipt.tx_hash = tx_hash or "" + receipt.expires = int(time.time()) + 86400 * 7 # 7 days + + block_entry = msg.payload.add() + block_entry.prefix = cid_bytes[:4] + block_entry.data = block_data + + return msg + + def _make_payment_rejection(self, cid_bytes: bytes, reason: str) -> Message_1_3: + msg = Message_1_3() + rejection = msg.payment_rejections.add() + rejection.cid = cid_bytes + rejection.reason = reason + return msg + + def _make_have(self, cid_bytes: bytes, protocol: str) -> Message_1_3 | Message_1_2: + MsgClass = Message_1_3 if protocol == BITSWAP_PROTOCOL_V130 else Message_1_2 + msg = MsgClass() + presence = msg.blockPresences.add() + presence.cid = cid_bytes + if protocol == BITSWAP_PROTOCOL_V130: + presence.type = Message_1_3.BlockPresenceType.Have # type: ignore + else: + presence.type = Message_1_2.BlockPresenceType.Have # type: ignore + return msg + + def _make_dont_have( + self, cid_bytes: bytes, protocol: str + ) -> Message_1_3 | Message_1_2: + MsgClass = Message_1_3 if protocol == BITSWAP_PROTOCOL_V130 else Message_1_2 + msg = MsgClass() + presence = msg.blockPresences.add() + presence.cid = cid_bytes + if protocol == BITSWAP_PROTOCOL_V130: + presence.type = Message_1_3.BlockPresenceType.DontHave # type: ignore + else: + presence.type = Message_1_2.BlockPresenceType.DontHave # type: ignore + return msg + + def _make_block_response( + self, cid_bytes: bytes, block_data: bytes, protocol: str + ) -> Message_1_3 | Message_1_2: + MsgClass = Message_1_3 if protocol == BITSWAP_PROTOCOL_V130 else Message_1_2 + msg = MsgClass() + block = msg.payload.add() + block.prefix = cid_bytes[:4] + block.data = block_data + 
return msg + + def _get_pricing_size_fallback(self, cid_str: str, block_size: int) -> int: + """ + Get the size to use for pricing calculations. + + If this CID is part of a registered DAG, return the total DAG size. + Otherwise, return the individual block size. + + Args: + cid_str: The CID being priced + block_size: The individual block size + + Returns: + Size in bytes to use for pricing + + """ + # Check if this is a registered root CID + if cid_str in self._dag_info: + total_size = self._dag_info[cid_str]["total_size"] + logger.debug( + f"Using DAG total size for pricing: {cid_str[:20]}... " + f"total={total_size}B (not block={block_size}B)" + ) + return total_size + + # Not a registered DAG, use individual block size + return block_size + + +# ── CID helpers ─────────────────────────────────────────────────────────────── + + +def _cid_to_str(cid: str | bytes) -> str: + if isinstance(cid, bytes): + return cid.hex() + return cid + + +def _cid_to_bytes(cid: str | bytes) -> bytes: + if isinstance(cid, str): + try: + return bytes.fromhex(cid) + except ValueError: + return cid.encode() + return cid diff --git a/libp2p/bitswap/messages.py b/libp2p/bitswap/messages.py index 8eea6535d..0c4264bce 100644 --- a/libp2p/bitswap/messages.py +++ b/libp2p/bitswap/messages.py @@ -4,16 +4,20 @@ """ from collections.abc import Sequence +from typing import TYPE_CHECKING, Union from .cid import CIDInput, cid_to_bytes from .pb.bitswap_pb2 import Message +if TYPE_CHECKING: + from .wantlist import WantType + def create_wantlist_entry( block_cid: CIDInput, priority: int = 1, cancel: bool = False, - want_type: int = 0, # 0 = Block, 1 = Have (v1.2.0) + want_type: Union[int, "WantType"] = 0, # 0 = Block, 1 = Have (v1.2.0) send_dont_have: bool = False, # v1.2.0 ) -> Message.Wantlist.Entry: """ @@ -36,8 +40,12 @@ def create_wantlist_entry( entry.block = cid_to_bytes(block_cid) entry.priority = priority entry.cancel = cancel - # Type checkers don't like int assignment to enum, but protobuf 
accepts it - entry.wantType = want_type # type: ignore[assignment] # v1.2.0 field + # Handle both int and WantType enum + if isinstance(want_type, int): + entry.wantType = want_type # type: ignore[assignment] + else: + # Extract .value from WantType enum + entry.wantType = want_type.value # type: ignore[assignment] entry.sendDontHave = send_dont_have # v1.2.0 field return entry diff --git a/libp2p/bitswap/payment_client_1_3.py b/libp2p/bitswap/payment_client_1_3.py new file mode 100644 index 000000000..8e4c052b8 --- /dev/null +++ b/libp2p/bitswap/payment_client_1_3.py @@ -0,0 +1,362 @@ +""" +Bitswap 1.3.0 Payment Client. + +Client-side handler for in-band payment messages. When the server sends +a PAYMENT_REQUIRED response with PaymentTerms, this client: +1. Validates the price is acceptable +2. Signs an EIP-3009 authorization +3. Sends back a PaymentAuthorization in the same Bitswap stream +4. On receipt of PaymentReceipt, triggers a WANT_BLOCK retry + +This module lives in py-libp2p so it's importable as libp2p.bitswap. +""" + +from collections.abc import Callable +import logging +import time +from typing import Any + +from libp2p.bitswap.pb.bitswap_1_3_0_pb2 import Message as Message_1_3 + +logger = logging.getLogger(__name__) + +# Default maximum auto-pay threshold: $0.001 USDC = 1000 micro-units +DEFAULT_MAX_AUTO_PAY_UNITS = 1000000 + + +class BitswapPaymentClient_1_3: + """ + Client-side handler for Bitswap 1.3.0 payment messages. + + Processes PaymentTerms from incoming messages and auto-pays if the + amount is within the configured threshold. 
+ + Args: + signer: An EIP3009Signer instance (gooseswarm.payments.eip3009_signer) + want_manager: Object with retry_want_block(peer_id, cid) async method + max_auto_pay_usdc: Maximum amount to auto-pay in USDC (default $1.00) + send_callback: Async function(peer_id, msg_bytes) to send responses + + """ + + def __init__( + self, + signer: Any, # gooseswarm.payments.eip3009_signer.EIP3009Signer + want_manager: Any, # has retry_want_block(peer_id, cid) method + max_auto_pay_usdc: float = 1.0, + send_callback: Callable[..., Any] | None = None, + ledger: Any = None, # gooseswarm.payments.ledger.PaymentLedger (optional) + ): + self.signer = signer + self.want_manager = want_manager + self.max_auto_pay_units = int(max_auto_pay_usdc * 1000000) + self.send_callback = send_callback + self.ledger = ledger + + # Pending payments: nonce_hex → {peer_id, cid, amount} + self._pending_payments: dict[str, dict[str, Any]] = {} + + # Server pricing config: peer_id → {units_per_kb, last_updated} + # This is learned from PaymentTerms messages + self._server_pricing: dict[str, dict[str, Any]] = {} + + async def process_incoming_message( + self, peer_id: str, msg: Message_1_3 + ) -> Message_1_3 | None: + """ + Called by the Bitswap dispatcher for every incoming 1.3.0 message. + + Handles: + - PaymentTerms → sign and send PaymentAuthorization + - PaymentReceipts → retry WANT_BLOCK + - PaymentRejections → log and surface to application + + Returns a response Message to send back, or None. 
+ """ + # Handle payment terms (server telling us what a block costs) + if msg.payment_terms: + for terms in msg.payment_terms: + response = await self._handle_payment_terms(peer_id, terms) + if response: + return response + + # Handle receipts (server confirming our payment) + for receipt in msg.payment_receipts: + await self._handle_payment_receipt(peer_id, receipt) + + # Handle rejections + for rejection in msg.payment_rejections: + self._handle_payment_rejection(peer_id, rejection) + + return None + + async def build_payment_auth_msg( + self, + terms: Any, # Message_1_3.PaymentTerms + ) -> Message_1_3: + """ + Build a PaymentAuthorization message for the given PaymentTerms. + Used by tests and demo scripts. + """ + v, r, s = self.signer.sign_transfer_authorization( + to=terms.pay_to, + value=terms.amount, + nonce=bytes(terms.nonce), + valid_before=terms.valid_before, + ) + + msg = Message_1_3() + auth = msg.payment_authorizations.add() # type: ignore[attr-defined] + auth.cid = bytes(terms.cid) + auth.from_address = self.signer.address + auth.to_address = terms.pay_to + auth.value = terms.amount + auth.valid_after = 0 + auth.valid_before = terms.valid_before + auth.nonce = bytes(terms.nonce) + auth.v = v + auth.r = r + auth.s = s + auth.scheme = terms.scheme + return msg + + # ── Internal handlers ───────────────────────────────────────────────── + + async def _handle_payment_terms( + self, peer_id: str, terms: Any + ) -> Message_1_3 | None: + """ + Server sent us PaymentTerms alongside a PaymentRequired BlockPresence. + Decide whether to pay and send back a PaymentAuthorization. + """ + amount = terms.amount + block_size = terms.block_size + + logger.warning("=" * 70) + logger.warning( + f"[STEP 3b] CLIENT EVALUATING PAYMENT TERMS from {peer_id[:20]}..." 
+ ) + logger.warning( + f" amount={amount} units max_auto_pay={self.max_auto_pay_units} units" + ) + logger.warning( + f" block_size={block_size}B asset={terms.asset} scheme={terms.scheme}" + ) + logger.warning("=" * 70) + + # Learn server's pricing from the PaymentTerms + # The server includes its units_per_kb in the pricing calculation + self._update_server_pricing(peer_id, amount, block_size) + + # Reject if too expensive + if amount > self.max_auto_pay_units: + logger.warning( + f"[STEP 3b] ❌ PAYMENT REJECTED (too expensive): {amount} units > " + f"max {self.max_auto_pay_units} units. " + f"Skipping — will seek block elsewhere." + ) + return None + + if not self._validate_pricing(peer_id, amount, block_size): + logger.warning( + f"[STEP 3b] ❌ PAYMENT REJECTED (pricing validation failed) for " + f"{block_size}B block from {peer_id[:20]}... " + f"Server asked {amount} units. Skipping payment." + ) + return None + + logger.warning( + "[STEP 3b] ✅ Payment terms accepted — proceeding to sign EIP-3009" + ) + + # Sign EIP-3009 authorization + logger.warning("=" * 70) + logger.warning("[STEP 4] CLIENT SIGNING EIP-3009 AUTHORIZATION") + logger.warning(f" to={terms.pay_to[:20]}...") + logger.warning(f" value={amount} units") + logger.warning(f" nonce={bytes(terms.nonce).hex()[:20]}...") + logger.warning(f" valid_before={terms.valid_before}") + logger.warning(f" signer_address={getattr(self.signer, 'address', 'N/A')}") + logger.warning("=" * 70) + try: + v, r, s = self.signer.sign_transfer_authorization( + to=terms.pay_to, + value=amount, + nonce=bytes(terms.nonce), + valid_before=terms.valid_before, + ) + logger.warning( + f"[STEP 4] EIP-3009 SIGNATURE CREATED: v={v} r_len={len(r)} " + f"s_len={len(s)}" + ) + except Exception as e: + logger.error( + f"[STEP 4] FAILED TO SIGN EIP-3009 AUTHORIZATION: {e}", exc_info=True + ) + return None + + # Build PaymentAuthorization message + response = Message_1_3() + auth = response.payment_authorizations.add() # type: 
ignore[attr-defined] + auth.cid = bytes(terms.cid) + auth.from_address = self.signer.address + auth.to_address = terms.pay_to + auth.value = amount + auth.valid_after = 0 + auth.valid_before = terms.valid_before + auth.nonce = bytes(terms.nonce) + auth.v = v + auth.r = r + auth.s = s + auth.scheme = terms.scheme + + # Track pending payment + nonce_hex = bytes(terms.nonce).hex() + self._pending_payments[nonce_hex] = { + "peer_id": peer_id, + "cid": bytes(terms.cid).hex(), + "amount": amount, + } + + # Persist spent payment to ledger + if self.ledger is not None: + try: + self.ledger.record_spent_payment( + peer_id=peer_id, + cid=bytes(terms.cid), + amount=amount, + nonce=bytes(terms.nonce), + ) + except Exception as _e: + logger.warning(f"Failed to persist spent payment: {_e}") + + logger.info( + f"Sending PaymentAuthorization to {peer_id[:20]}... " + f"cid={bytes(terms.cid).hex()[:20]}... " + f"amount={amount} units (${amount / 1_000_000:.6f} USDC) " + f"for {terms.block_size}B block" + ) + return response + + async def _handle_payment_receipt(self, peer_id: str, receipt: Any) -> None: + """Server confirmed payment. Retry the WANT_BLOCK immediately.""" + cid_hex = ( + bytes(receipt.cid).hex() if isinstance(receipt.cid, bytes) else receipt.cid + ) + logger.info( + f"Payment receipt received from {peer_id[:20]}... " + f"cid={cid_hex[:20]}... " + f"tx={receipt.tx_hash[:20] if receipt.tx_hash else 'optimistic'}..." + ) + # Trigger want manager to retry + if self.want_manager: + try: + await self.want_manager.retry_want_block(peer_id, cid_hex) + except Exception as e: + logger.error(f"Failed to retry want block: {e}") + + def _handle_payment_rejection(self, peer_id: str, rejection: Any) -> None: + """Log and surface payment rejection.""" + cid_hex = ( + bytes(rejection.cid).hex() + if isinstance(rejection.cid, bytes) + else rejection.cid + ) + logger.warning( + f"Payment rejected by {peer_id[:20]}... " + f"cid={cid_hex[:20]}... 
reason={rejection.reason}" + ) + + def _update_server_pricing( + self, peer_id: str, amount: int, block_size: int + ) -> None: + """ + Learn the server's pricing configuration from PaymentTerms. + + The server calculates: price = max(1, int(block_size_kb * units_per_kb)) + We can reverse-engineer units_per_kb from the amount and block_size. + """ + if amount == 0 or block_size == 0: + return # Free block, no pricing info to learn + + # Calculate implied units_per_kb from this payment request + kb = block_size / 1024 + if kb > 0: + implied_units_per_kb = amount / kb + + # Store or update the pricing config for this peer + if peer_id not in self._server_pricing: + self._server_pricing[peer_id] = { + "units_per_kb": implied_units_per_kb, + "last_updated": time.time(), + "sample_count": 1, + } + logger.info( + f"Learned pricing from {peer_id[:20]}...: " + f"{implied_units_per_kb:.2f} units/KB" + ) + else: + # Average with existing samples for stability + config = self._server_pricing[peer_id] + old_rate = config["units_per_kb"] + sample_count = config["sample_count"] + new_rate = (old_rate * sample_count + implied_units_per_kb) / ( + sample_count + 1 + ) + config["units_per_kb"] = new_rate + config["sample_count"] = sample_count + 1 + config["last_updated"] = time.time() + + # Warn if pricing changed significantly (>20%) + if abs(new_rate - old_rate) / old_rate > 0.2: + logger.warning( + f"Server {peer_id[:20]}... pricing changed: " + f"{old_rate:.2f} → {new_rate:.2f} units/KB" + ) + + def _validate_pricing(self, peer_id: str, amount: int, block_size: int) -> bool: + """ + Validate that the server's price request is consistent with its learned pricing. + + Returns True if pricing is acceptable, False if suspicious. 
+ """ + if amount == 0: + return True # Free blocks are always acceptable + + # If we haven't learned pricing yet, accept this first payment + if peer_id not in self._server_pricing: + return True + + config = self._server_pricing[peer_id] + units_per_kb = config["units_per_kb"] + + # Calculate expected price using learned pricing + kb = block_size / 1024 + expected = max(1, int(kb * units_per_kb)) + + # Allow 20% tolerance for rounding and small variations + tolerance = 0.2 + min_acceptable = expected * (1 - tolerance) + max_acceptable = expected * (1 + tolerance) + + if amount < min_acceptable or amount > max_acceptable: + logger.warning( + f"Pricing inconsistency detected: " + f"expected {expected} units (±{tolerance * 100}%), got {amount} units " + f"for {block_size}B block ({kb:.3f} KB) " + f"using learned rate {units_per_kb:.2f} units/KB" + ) + return False + + return True + + def get_server_pricing(self, peer_id: str) -> dict[str, Any] | None: + """ + Get the learned pricing configuration for a peer. + + Returns: + Dict with units_per_kb, last_updated, sample_count, + or None if not learned yet. + + """ + return self._server_pricing.get(peer_id) diff --git a/libp2p/bitswap/payment_extension.py b/libp2p/bitswap/payment_extension.py new file mode 100644 index 000000000..1b68d1561 --- /dev/null +++ b/libp2p/bitswap/payment_extension.py @@ -0,0 +1,247 @@ +import logging +from typing import Any + +from libp2p.abc import INetStream +from libp2p.peer.id import ID as PeerID + +from .cid import parse_cid +from .extension import IBitswapExtension +from .pb.bitswap_1_3_0_pb2 import Message as Message_1_3 + +logger = logging.getLogger(__name__) + + +class PaymentExtension(IBitswapExtension): + """ + Bitswap 1.3.0 Payment Extension. + Intercepts and processes payment-related protobuf fields and wantlists. 
+ """ + + def __init__(self, payment_client: Any = None, payment_engine: Any = None): + self.payment_client = payment_client + self.payment_engine = payment_engine + + async def process_message( + self, peer_id: PeerID, msg_bytes: bytes, stream: INetStream + ) -> bool: + """ + Process the 1.3.0 specific fields: payment terms, receipts, auths. + Returns False so that standard wantlist and block processing can + continue if needed. + """ + msg_1_3: Message_1_3 | None = None + try: + _tmp = Message_1_3() + _tmp.ParseFromString(msg_bytes) + msg_1_3 = _tmp + except Exception: + return False + + if msg_1_3 is None: + return False + + # Client-side: handle PaymentTerms / PaymentReceipts / PaymentRejections + if self.payment_client and ( + msg_1_3.payment_terms + or msg_1_3.payment_receipts + or msg_1_3.payment_rejections + ): + if msg_1_3.payment_terms: + logger.warning("=" * 70) + logger.warning( + f"[STEP 3] CLIENT RECEIVED PAYMENT TERMS from " + f"{str(peer_id)[:20]}..." + ) + for _t in msg_1_3.payment_terms: + logger.warning(f" cid={bytes(_t.cid).hex()[:20]}...") + logger.warning(f" amount={_t.amount} units") + logger.warning(f" asset={_t.asset} scheme={_t.scheme}") # type: ignore[attr-defined] + logger.warning(f" pay_to={_t.pay_to[:20]}...") + logger.warning(f" block_size={_t.block_size}B") + logger.warning(f" valid_before={_t.valid_before}") # type: ignore[attr-defined] + logger.warning("=" * 70) + if msg_1_3.payment_receipts: + logger.warning("=" * 70) + logger.warning( + f"[STEP 8a] CLIENT RECEIVED PAYMENT RECEIPT from " + f"{str(peer_id)[:20]}..." + ) + for _r in msg_1_3.payment_receipts: + logger.warning(f" cid={bytes(_r.cid).hex()[:20]}...") + logger.warning( + f" tx_hash={_r.tx_hash[:20] if _r.tx_hash else 'optimistic'}" + ) + logger.warning(f" expires={_r.expires}") + logger.warning("=" * 70) + if msg_1_3.payment_rejections: + logger.warning("=" * 70) + logger.warning( + f"[STEP 8a] CLIENT RECEIVED PAYMENT REJECTION from " + f"{str(peer_id)[:20]}..." 
+ ) + for _rj in msg_1_3.payment_rejections: + logger.warning(f" cid={bytes(_rj.cid).hex()[:20]}...") + logger.warning(f" reason={_rj.reason}") + logger.warning("=" * 70) + + response = await self.payment_client.process_incoming_message( + str(peer_id), msg_1_3 + ) + if response is not None: + logger.warning("=" * 70) + logger.warning( + f"[STEP 5] CLIENT SENDING PAYMENT AUTHORIZATION to " + f"{str(peer_id)[:20]}..." + ) + if response.payment_authorizations: + for _a in response.payment_authorizations: + logger.warning(f" cid={bytes(_a.cid).hex()[:20]}...") + logger.warning(f" from={_a.from_address[:20]}...") + logger.warning(f" to={_a.to_address[:20]}...") + logger.warning(f" value={_a.value}") + logger.warning(f" scheme={_a.scheme}") + logger.warning( + f" v={_a.v} r_len={len(bytes(_a.r))} " + f"s_len={len(bytes(_a.s))}" + ) + logger.warning("=" * 70) + await self.client._write_message_bytes( + stream, response.SerializeToString() + ) + + # Server-side: handle PaymentAuthorizations (EIP-3009 signed payments) + if self.payment_engine and msg_1_3.payment_authorizations: # type: ignore[attr-defined] + try: + logger.warning("=" * 70) + logger.warning( + f"[STEP 6] SERVER RECEIVED PAYMENT AUTHORIZATION from " + f"{str(peer_id)[:20]}..." 
+ ) + for _a in msg_1_3.payment_authorizations: # type: ignore[attr-defined] + logger.warning(f" cid={bytes(_a.cid).hex()[:20]}...") + logger.warning(f" from={_a.from_address[:20]}...") + logger.warning(f" to={_a.to_address[:20]}...") + logger.warning(f" value={_a.value}") + logger.warning(f" scheme={_a.scheme}") + logger.warning( + f" v={_a.v} r_len={len(bytes(_a.r))} s_len={len(bytes(_a.s))}" + ) + logger.warning("=" * 70) + + response = await self.payment_engine.process_incoming_1_3_message( + str(peer_id), msg_1_3 + ) + if response is not None: + _has_receipt = bool(response.payment_receipts) + _has_rejection = bool(response.payment_rejections) + _has_blocks = bool(response.payload) or bool(response.blocks) + logger.warning("=" * 70) + logger.warning( + "[STEP 8] SERVER SENDING RESPONSE after PaymentAuthorization:" + ) + logger.warning( + f" has_receipt={_has_receipt} " + f"has_rejection={_has_rejection} has_blocks={_has_blocks}" + ) + if _has_rejection: + for _rj in response.payment_rejections: + logger.warning(f" ❌ REJECTION reason={_rj.reason}") + if _has_blocks: + _nb = len(response.payload) + len(response.blocks) + logger.warning( + f" ✅ SENDING {_nb} block(s) to client " # type: ignore + f"— FILE TRANSFER STARTING" + ) + logger.warning("=" * 70) + await self.client._write_message_bytes( + stream, response.SerializeToString() + ) + + # Payment authorization handled — we intercept this completely. + return True + except Exception as e: + logger.error(f"Error handling PaymentAuthorization: {e}", exc_info=True) + + # Handle PaymentRequired block presences (1.3.0 type=2) + if msg_1_3.blockPresences: + await self.client._process_block_presences_1_3( + msg_1_3.blockPresences, peer_id + ) + + return False + + async def process_wantlist( + self, wantlist: Any, peer_id: PeerID, stream: INetStream + ) -> bool: + """ + Gated wantlist processing. + If we have a payment_engine, we MUST gate block sharing behind payment terms. 
+ """ + if not self.payment_engine: + return False + + if peer_id not in self.client._peer_wantlists: + self.client._peer_wantlists[peer_id] = {} + peer_wantlist = self.client._peer_wantlists[peer_id] + + if wantlist.full: + peer_wantlist.clear() + + for entry in wantlist.entries: + entry_cid = parse_cid(entry.block) + if entry.cancel: + if entry_cid in peer_wantlist: + del peer_wantlist[entry_cid] + continue + + peer_wantlist[entry_cid] = { + "priority": entry.priority, + "want_type": entry.wantType, + "send_dont_have": entry.sendDontHave, + } + + peer_protocol = self.client._peer_protocols.get(peer_id, "") + response_msg = await self.payment_engine.handle_want( + peer_id=str(peer_id), + cid=entry.block, + want_type=entry.wantType, + send_dont_have=entry.sendDontHave, + peer_protocol=str(peer_protocol), + ) + + if response_msg is not None: + _has_pr = bool(getattr(response_msg, "blockPresences", [])) + _has_terms = bool(getattr(response_msg, "payment_terms", [])) + _has_blocks = bool(getattr(response_msg, "payload", [])) or bool( + getattr(response_msg, "blocks", []) + ) + logger.warning("=" * 70) + logger.warning( + f"[STEP 2] SERVER SENDING RESPONSE for cid=" + f"{bytes(entry.block).hex()[:20]}..." + ) + logger.warning( + f" payment_required={_has_pr} payment_terms={_has_terms} " + f"has_blocks={_has_blocks}" + ) + if _has_pr: + for _bp in response_msg.blockPresences: + logger.warning( + f" BlockPresence type={_bp.type} (2=PaymentRequired)" + ) + if _has_terms: + for _t in response_msg.payment_terms: + logger.warning( + f" PaymentTerms: amount={_t.amount} asset={_t.asset} " + f"pay_to={_t.pay_to[:20]}... 
scheme={_t.scheme}" + ) + if _has_blocks: + logger.warning( + " ✅ Sending block(s) directly (free/already paid)" + ) + logger.warning("=" * 70) + await self.client._write_message_bytes( + stream, response_msg.SerializeToString() + ) + + return True diff --git a/libp2p/bitswap/payment_ledger.py b/libp2p/bitswap/payment_ledger.py new file mode 100644 index 000000000..97ccaec1c --- /dev/null +++ b/libp2p/bitswap/payment_ledger.py @@ -0,0 +1,289 @@ +""" +Payment Ledger for Bitswap 1.3.0 - Root CID Payment Tracking. + +Tracks payments at the root CID level, not per-block. When a peer pays for +a root CID, all child blocks (chunks) in the DAG are automatically accessible. + +Design: +- Payment records: (peer_id, root_cid) → {amount, nonce, timestamp, tx_hash} +- Root CID mapping: (child_cid) → root_cid (for chunk → root resolution) +- Nonce deduplication: Prevents replay attacks +""" + +import logging +import time +from typing import Any + +logger = logging.getLogger(__name__) + + +class PaymentLedger: + r""" + Tracks root CID payments for Bitswap 1.3.0. + + When a peer pays for a root CID, they gain access to all blocks in that DAG. + This prevents charging separately for each chunk of a multi-block file. + + Example: + >>> ledger = PaymentLedger() + >>> + >>> # Register a DAG structure (root → children mapping) + >>> await ledger.register_dag( + ... root_cid="bafyroot123...", + ... child_cids=["bafychild1...", "bafychild2...", ...] + ... ) + >>> + >>> # Record payment for root CID + >>> await ledger.record_payment( + ... peer_id="12D3Koo...", + ... cid=b"\\x01\\x55...", # Can be root or child CID + ... amount=1000000, # 1 USDC in micro-units + ... nonce=b"\\x12\\x34...", + ... 
) + >>> + >>> # Check if peer has paid (works for root OR child CIDs) + >>> ledger.is_paid("12D3Koo...", "bafychild1...") # True (child of paid root) + >>> ledger.is_paid("12D3Koo...", "bafyroot123...") # True (root itself) + + """ + + def __init__(self) -> None: + # Payment records: (peer_id, root_cid_hex) → payment_info + self._payments: dict[tuple[str, str], dict[str, Any]] = {} + + # Child → Root mapping: child_cid_hex → root_cid_hex + # Used to resolve chunk CIDs to their root CID + self._cid_to_root: dict[str, str] = {} + + # Nonce registry: nonce_hex → (peer_id, cid_hex, timestamp) + # Prevents replay attacks (same nonce can't be used twice) + self._used_nonces: dict[str, tuple[str, str, float]] = {} + + # Free CIDs: Set of CID hashes that are always free (no payment required) + self._free_cids: set[str] = set() + + async def register_dag( + self, + root_cid: str | bytes, + child_cids: list[str | bytes], + ) -> None: + """ + Register a DAG structure so child blocks inherit root payment status. + + Args: + root_cid: The root CID of the DAG (hex string or bytes) + child_cids: List of child/chunk CIDs in the DAG + + Example: + >>> # After chunking a file into blocks + >>> await ledger.register_dag( + ... root_cid=root_cid, + ... child_cids=[chunk1_cid, chunk2_cid, ...] + ... ) + + """ + root_hex = _cid_to_hex(root_cid) + + for child_cid in child_cids: + child_hex = _cid_to_hex(child_cid) + self._cid_to_root[child_hex] = root_hex + logger.debug( + f"Registered child {child_hex[:20]}... → root {root_hex[:20]}..." + ) + + logger.info( + f"Registered DAG: root={root_hex[:20]}... with {len(child_cids)} children" + ) + + def mark_free(self, cid: str | bytes) -> None: + """ + Mark a CID as free (no payment required). 
+ + Args: + cid: The CID to mark as free (hex string or bytes) + + """ + cid_hex = _cid_to_hex(cid) + self._free_cids.add(cid_hex) + logger.info(f"Marked CID as FREE: {cid_hex[:20]}...") + + def is_free(self, cid: str | bytes) -> bool: + """ + Check if a CID is marked as free. + + Args: + cid: The CID to check (hex string or bytes) + + Returns: + True if the CID is free, False otherwise + + """ + cid_hex = _cid_to_hex(cid) + root_hex = self._cid_to_root.get(cid_hex, cid_hex) + return cid_hex in self._free_cids or root_hex in self._free_cids + + def is_paid( + self, + peer_id: str, + cid: str | bytes, + block_size: int = 0, # Ignored (kept for backward compatibility) + ) -> bool: + """ + Check if a peer has paid for a CID (root or child). + + Resolves child CIDs to their root CID automatically. + + Args: + peer_id: The peer ID to check + cid: The CID to check (can be root or child CID) + block_size: Ignored (kept for backward compatibility with old API) + + Returns: + True if the peer has paid for this CID (or its root), False otherwise + + """ + cid_hex = _cid_to_hex(cid) + + # Check if it's a free CID + if self.is_free(cid_hex): + return True + + # Resolve to root CID if this is a child + root_hex = self._cid_to_root.get(cid_hex, cid_hex) + + # Check if payment exists for (peer, root) + key = (peer_id, root_hex) + paid = key in self._payments + + if paid: + payment = self._payments[key] + logger.debug( + f"✅ Payment found: peer={peer_id[:20]}... " + f"cid={cid_hex[:20]}... root={root_hex[:20]}... " + f"amount={payment['amount']}" + ) + else: + logger.debug( + f"❌ No payment: peer={peer_id[:20]}... " + f"cid={cid_hex[:20]}... root={root_hex[:20]}..." + ) + + return paid + + async def record_payment( + self, + peer_id: str, + cid: str | bytes, + amount: int, + nonce: bytes, + tx_hash: str = "", + ) -> None: + """ + Record a payment for a root CID. 
+ + Args: + peer_id: The peer who paid + cid: The CID being paid for (root or child - will resolve to root) + amount: Payment amount in micro-units (e.g., USDC micro-units) + nonce: Unique nonce for this payment (prevents replay attacks) + tx_hash: Optional transaction hash (empty for EIP-3009) + + Raises: + ValueError: If the nonce has already been used + + """ + cid_hex = _cid_to_hex(cid) + nonce_hex = nonce.hex() + + # Check for nonce reuse (replay attack prevention) + if nonce_hex in self._used_nonces: + existing = self._used_nonces[nonce_hex] + raise ValueError( + f"Nonce already used: {nonce_hex[:20]}... " + f"by peer={existing[0][:20]}... for cid={existing[1][:20]}..." + ) + + # Resolve to root CID + root_hex = self._cid_to_root.get(cid_hex, cid_hex) + + # Record payment + key = (peer_id, root_hex) + self._payments[key] = { + "amount": amount, + "nonce": nonce_hex, + "tx_hash": tx_hash, + "timestamp": time.time(), + } + + # Mark nonce as used + self._used_nonces[nonce_hex] = (peer_id, root_hex, time.time()) + + logger.info( + f"💰 Payment recorded: peer={peer_id[:20]}... " + f"root={root_hex[:20]}... amount={amount} " + f"nonce={nonce_hex[:16]}..." + ) + + def get_payment( + self, + peer_id: str, + cid: str | bytes, + ) -> dict[str, Any] | None: + """ + Get payment details for a peer and CID. + + Args: + peer_id: The peer ID + cid: The CID (root or child) + + Returns: + Payment info dict with keys: amount, nonce, tx_hash, timestamp + or None if no payment found + + """ + cid_hex = _cid_to_hex(cid) + root_hex = self._cid_to_root.get(cid_hex, cid_hex) + key = (peer_id, root_hex) + return self._payments.get(key) + + def clear_old_nonces(self, max_age_seconds: float = 86400) -> int: + """ + Clear nonces older than max_age_seconds (default: 24 hours). 
+ + Returns: + Number of nonces cleared + + """ + now = time.time() + old_nonces = [ + nonce_hex + for nonce_hex, (_, _, timestamp) in self._used_nonces.items() + if now - timestamp > max_age_seconds + ] + + for nonce_hex in old_nonces: + del self._used_nonces[nonce_hex] + + if old_nonces: + logger.info(f"Cleared {len(old_nonces)} old nonces (>{max_age_seconds}s)") + + return len(old_nonces) + + +# ── Helper functions ────────────────────────────────────────────────────────── + + +def _cid_to_hex(cid: str | bytes) -> str: + """Convert CID to hex string for consistent storage.""" + if isinstance(cid, bytes): + return cid.hex() + elif isinstance(cid, str): + # If already hex, return as-is; otherwise try to decode + try: + bytes.fromhex(cid) + return cid + except ValueError: + # Assume it's a base58/base32 encoded CID string + return cid.encode().hex() + else: + raise TypeError(f"CID must be str or bytes, got {type(cid)}") diff --git a/libp2p/bitswap/pb/bitswap_1_3_0.proto b/libp2p/bitswap/pb/bitswap_1_3_0.proto new file mode 100644 index 000000000..bd3196efb --- /dev/null +++ b/libp2p/bitswap/pb/bitswap_1_3_0.proto @@ -0,0 +1,104 @@ +// bitswap_1_3_0.proto +// Bitswap 1.3.0 — adds PAYMENT_REQUIRED block presence and in-band payment flow +// Backward compatible with 1.2.0: new fields use field numbers 6, 7, 8, 9 +// New enum value PaymentRequired = 2 (proto3 open enums — safe for old parsers) + +syntax = "proto3"; + +package bitswap.pb.v130; + +message Message { + + // ─── EXISTING: Wantlist (unchanged from 1.2.0) ───────────────────────── + message Wantlist { + enum WantType { + Block = 0; // default: client wants the full block + Have = 1; // client only wants to know if server has it + } + message Entry { + bytes block = 1; // CID bytes (CIDv1 binary) + int32 priority = 2; // higher = serve first; default 1 + bool cancel = 3; // true = remove from wantlist + WantType wantType = 4; // Block or Have + bool sendDontHave = 5; // server MUST respond DONT_HAVE if missing 
+ } + repeated Entry entries = 1; + bool full = 2; // true = authoritative wantlist replacement + } + + // ─── EXISTING: Block payload (unchanged from 1.1.0) ───────────────────── + message Block { + bytes prefix = 1; // CID prefix: version + codec varint + bytes data = 2; // raw block bytes + } + + // ─── EXTENDED: BlockPresenceType — NEW value PaymentRequired = 2 ──────── + enum BlockPresenceType { + Have = 0; // server has the block + DontHave = 1; // server genuinely does not have the block + PaymentRequired = 2; // [NEW 1.3.0] server has the block but requires payment + // Old parsers: see integer 2, no matching case → skip entry + } + + // ─── EXISTING: BlockPresence (unchanged structure, extended enum) ──────── + message BlockPresence { + bytes cid = 1; + BlockPresenceType type = 2; // Now can be 0, 1, or 2 + } + + // ─── NEW 1.3.0: PaymentTerms — embedded in Message when type=PaymentRequired + message PaymentTerms { + bytes cid = 1; // CID of the gated block + string asset = 2; // Token contract address + string pay_to = 3; // Server's wallet address + uint64 amount = 4; // Amount in token's smallest unit + string network = 5; // Chain identifier: "base-sepolia" | "base-mainnet" + bytes nonce = 6; // Random 32 bytes — per-offer, prevents replay attacks + uint64 valid_before = 7; // Unix timestamp: offer expires after this + uint64 block_size = 8; // Actual block size in bytes + string description = 9; // Human-readable description + string scheme = 10; // Payment scheme: "exact" (EIP-3009) + } + + // ─── NEW 1.3.0: PaymentAuthorization — client's signed proof of payment + message PaymentAuthorization { + bytes cid = 1; // CID being paid for + string from_address = 2; // Client's wallet address + string to_address = 3; // Must match PaymentTerms.pay_to + uint64 value = 4; // Must be >= PaymentTerms.amount + uint64 valid_after = 5; // EIP-3009 validAfter (typically 0) + uint64 valid_before = 6; // EIP-3009 validBefore + bytes nonce = 7; // Must match 
PaymentTerms.nonce exactly + uint32 v = 8; // ECDSA signature recovery id (27 or 28) + bytes r = 9; // ECDSA r component (32 bytes) + bytes s = 10; // ECDSA s component (32 bytes) + string scheme = 11; // Must match PaymentTerms.scheme + } + + // ─── NEW 1.3.0: PaymentReceipt — server confirms payment accepted + message PaymentReceipt { + bytes cid = 1; // CID now authorized to be served + string tx_hash = 2; // On-chain tx hash (empty in OPTIMISTIC mode) + uint64 expires = 3; // Unix ts: this authorization is valid until this time + } + + // ─── NEW 1.3.0: PaymentRejection — server rejects a PaymentAuthorization + message PaymentRejection { + bytes cid = 1; + string reason = 2; // "INVALID_SIGNATURE" | "WRONG_AMOUNT" | "NONCE_USED" | "EXPIRED" + } + + // ─── TOP-LEVEL MESSAGE FIELDS ────────────────────────────────────────── + // Fields 1-5: identical to Bitswap 1.2.0 (never modified) + Wantlist wantlist = 1; + repeated bytes blocks = 2; // deprecated since 1.1.0 + repeated Block payload = 3; + repeated BlockPresence blockPresences = 4; // type=2 means PAYMENT_REQUIRED + int32 pendingBytes = 5; + + // Fields 6-9: NEW in 1.3.0 (safe unknown fields for old parsers) + repeated PaymentTerms payment_terms = 6; // server → client + repeated PaymentAuthorization payment_authorizations = 7; // client → server + repeated PaymentReceipt payment_receipts = 8; // server → client + repeated PaymentRejection payment_rejections = 9; // server → client +} diff --git a/libp2p/bitswap/pb/bitswap_1_3_0_pb2.py b/libp2p/bitswap/pb/bitswap_1_3_0_pb2.py new file mode 100644 index 000000000..24b07bf75 --- /dev/null +++ b/libp2p/bitswap/pb/bitswap_1_3_0_pb2.py @@ -0,0 +1,46 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# source: bitswap_1_3_0.proto +# Protobuf Python Version: 4.25.1 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x13\x62itswap_1_3_0.proto\x12\x0f\x62itswap.pb.v130\"\xaa\x0b\n\x07Message\x12\x33\n\x08wantlist\x18\x01 \x01(\x0b\x32!.bitswap.pb.v130.Message.Wantlist\x12\x0e\n\x06\x62locks\x18\x02 \x03(\x0c\x12/\n\x07payload\x18\x03 \x03(\x0b\x32\x1e.bitswap.pb.v130.Message.Block\x12>\n\x0e\x62lockPresences\x18\x04 \x03(\x0b\x32&.bitswap.pb.v130.Message.BlockPresence\x12\x14\n\x0cpendingBytes\x18\x05 \x01(\x05\x12<\n\rpayment_terms\x18\x06 \x03(\x0b\x32%.bitswap.pb.v130.Message.PaymentTerms\x12M\n\x16payment_authorizations\x18\x07 \x03(\x0b\x32-.bitswap.pb.v130.Message.PaymentAuthorization\x12\x41\n\x10payment_receipts\x18\x08 \x03(\x0b\x32\'.bitswap.pb.v130.Message.PaymentReceipt\x12\x45\n\x12payment_rejections\x18\t \x03(\x0b\x32).bitswap.pb.v130.Message.PaymentRejection\x1a\x82\x02\n\x08Wantlist\x12\x38\n\x07\x65ntries\x18\x01 \x03(\x0b\x32\'.bitswap.pb.v130.Message.Wantlist.Entry\x12\x0c\n\x04\x66ull\x18\x02 \x01(\x08\x1a\x8c\x01\n\x05\x45ntry\x12\r\n\x05\x62lock\x18\x01 \x01(\x0c\x12\x10\n\x08priority\x18\x02 \x01(\x05\x12\x0e\n\x06\x63\x61ncel\x18\x03 \x01(\x08\x12<\n\x08wantType\x18\x04 \x01(\x0e\x32*.bitswap.pb.v130.Message.Wantlist.WantType\x12\x14\n\x0csendDontHave\x18\x05 \x01(\x08\"\x1f\n\x08WantType\x12\t\n\x05\x42lock\x10\x00\x12\x08\n\x04Have\x10\x01\x1a%\n\x05\x42lock\x12\x0e\n\x06prefix\x18\x01 \x01(\x0c\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\x1aV\n\rBlockPresence\x12\x0b\n\x03\x63id\x18\x01 \x01(\x0c\x12\x38\n\x04type\x18\x02 
\x01(\x0e\x32*.bitswap.pb.v130.Message.BlockPresenceType\x1a\xb9\x01\n\x0cPaymentTerms\x12\x0b\n\x03\x63id\x18\x01 \x01(\x0c\x12\r\n\x05\x61sset\x18\x02 \x01(\t\x12\x0e\n\x06pay_to\x18\x03 \x01(\t\x12\x0e\n\x06\x61mount\x18\x04 \x01(\x04\x12\x0f\n\x07network\x18\x05 \x01(\t\x12\r\n\x05nonce\x18\x06 \x01(\x0c\x12\x14\n\x0cvalid_before\x18\x07 \x01(\x04\x12\x12\n\nblock_size\x18\x08 \x01(\x04\x12\x13\n\x0b\x64\x65scription\x18\t \x01(\t\x12\x0e\n\x06scheme\x18\n \x01(\t\x1a\xc7\x01\n\x14PaymentAuthorization\x12\x0b\n\x03\x63id\x18\x01 \x01(\x0c\x12\x14\n\x0c\x66rom_address\x18\x02 \x01(\t\x12\x12\n\nto_address\x18\x03 \x01(\t\x12\r\n\x05value\x18\x04 \x01(\x04\x12\x13\n\x0bvalid_after\x18\x05 \x01(\x04\x12\x14\n\x0cvalid_before\x18\x06 \x01(\x04\x12\r\n\x05nonce\x18\x07 \x01(\x0c\x12\t\n\x01v\x18\x08 \x01(\r\x12\t\n\x01r\x18\t \x01(\x0c\x12\t\n\x01s\x18\n \x01(\x0c\x12\x0e\n\x06scheme\x18\x0b \x01(\t\x1a?\n\x0ePaymentReceipt\x12\x0b\n\x03\x63id\x18\x01 \x01(\x0c\x12\x0f\n\x07tx_hash\x18\x02 \x01(\t\x12\x0f\n\x07\x65xpires\x18\x03 \x01(\x04\x1a/\n\x10PaymentRejection\x12\x0b\n\x03\x63id\x18\x01 \x01(\x0c\x12\x0e\n\x06reason\x18\x02 \x01(\t\"@\n\x11\x42lockPresenceType\x12\x08\n\x04Have\x10\x00\x12\x0c\n\x08\x44ontHave\x10\x01\x12\x13\n\x0fPaymentRequired\x10\x02\x62\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'bitswap_1_3_0_pb2', _globals) +if _descriptor._USE_C_DESCRIPTORS == False: + DESCRIPTOR._options = None + _globals['_MESSAGE']._serialized_start=41 + _globals['_MESSAGE']._serialized_end=1491 + _globals['_MESSAGE_WANTLIST']._serialized_start=536 + _globals['_MESSAGE_WANTLIST']._serialized_end=794 + _globals['_MESSAGE_WANTLIST_ENTRY']._serialized_start=621 + _globals['_MESSAGE_WANTLIST_ENTRY']._serialized_end=761 + _globals['_MESSAGE_WANTLIST_WANTTYPE']._serialized_start=763 + _globals['_MESSAGE_WANTLIST_WANTTYPE']._serialized_end=794 + 
_globals['_MESSAGE_BLOCK']._serialized_start=796 + _globals['_MESSAGE_BLOCK']._serialized_end=833 + _globals['_MESSAGE_BLOCKPRESENCE']._serialized_start=835 + _globals['_MESSAGE_BLOCKPRESENCE']._serialized_end=921 + _globals['_MESSAGE_PAYMENTTERMS']._serialized_start=924 + _globals['_MESSAGE_PAYMENTTERMS']._serialized_end=1109 + _globals['_MESSAGE_PAYMENTAUTHORIZATION']._serialized_start=1112 + _globals['_MESSAGE_PAYMENTAUTHORIZATION']._serialized_end=1311 + _globals['_MESSAGE_PAYMENTRECEIPT']._serialized_start=1313 + _globals['_MESSAGE_PAYMENTRECEIPT']._serialized_end=1376 + _globals['_MESSAGE_PAYMENTREJECTION']._serialized_start=1378 + _globals['_MESSAGE_PAYMENTREJECTION']._serialized_end=1425 + _globals['_MESSAGE_BLOCKPRESENCETYPE']._serialized_start=1427 + _globals['_MESSAGE_BLOCKPRESENCETYPE']._serialized_end=1491 +# @@protoc_insertion_point(module_scope) diff --git a/libp2p/bitswap/pb/bitswap_1_3_0_pb2.pyi b/libp2p/bitswap/pb/bitswap_1_3_0_pb2.pyi new file mode 100644 index 000000000..75bcd0b01 --- /dev/null +++ b/libp2p/bitswap/pb/bitswap_1_3_0_pb2.pyi @@ -0,0 +1,128 @@ +from google.protobuf.internal import containers as _containers +from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from collections.abc import Iterable as _Iterable, Mapping as _Mapping +from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union, Any as _Any + +DESCRIPTOR: _descriptor.FileDescriptor + +class Message(_message.Message): + __slots__ = ("wantlist", "blocks", "payload", "blockPresences", "pendingBytes", "payment_terms", "tx_receipts", "payment_receipts", "payment_rejections") + class BlockPresenceType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = () + Have: _ClassVar[Message.BlockPresenceType] + DontHave: _ClassVar[Message.BlockPresenceType] + PaymentRequired: _ClassVar[Message.BlockPresenceType] + Have: 
Message.BlockPresenceType + DontHave: Message.BlockPresenceType + PaymentRequired: Message.BlockPresenceType + class Wantlist(_message.Message): + __slots__ = ("entries", "full") + class WantType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = () + Block: _ClassVar[Message.Wantlist.WantType] + Have: _ClassVar[Message.Wantlist.WantType] + Block: Message.Wantlist.WantType + Have: Message.Wantlist.WantType + class Entry(_message.Message): + __slots__ = ("block", "priority", "cancel", "wantType", "sendDontHave") + BLOCK_FIELD_NUMBER: _ClassVar[int] + PRIORITY_FIELD_NUMBER: _ClassVar[int] + CANCEL_FIELD_NUMBER: _ClassVar[int] + WANTTYPE_FIELD_NUMBER: _ClassVar[int] + SENDDONTHAVE_FIELD_NUMBER: _ClassVar[int] + block: bytes + priority: int + cancel: bool + wantType: Message.Wantlist.WantType + sendDontHave: bool + def __init__(self, block: _Optional[bytes] = ..., priority: _Optional[int] = ..., cancel: bool = ..., wantType: _Optional[_Union[Message.Wantlist.WantType, str]] = ..., sendDontHave: bool = ...) -> None: ... + ENTRIES_FIELD_NUMBER: _ClassVar[int] + FULL_FIELD_NUMBER: _ClassVar[int] + entries: _containers.RepeatedCompositeFieldContainer[Message.Wantlist.Entry] + full: bool + def __init__(self, entries: _Optional[_Iterable[_Union[Message.Wantlist.Entry, _Mapping[str, _Any]]]] = ..., full: bool = ...) -> None: ... + class Block(_message.Message): + __slots__ = ("prefix", "data") + PREFIX_FIELD_NUMBER: _ClassVar[int] + DATA_FIELD_NUMBER: _ClassVar[int] + prefix: bytes + data: bytes + def __init__(self, prefix: _Optional[bytes] = ..., data: _Optional[bytes] = ...) -> None: ... + class BlockPresence(_message.Message): + __slots__ = ("cid", "type") + CID_FIELD_NUMBER: _ClassVar[int] + TYPE_FIELD_NUMBER: _ClassVar[int] + cid: bytes + type: Message.BlockPresenceType + def __init__(self, cid: _Optional[bytes] = ..., type: _Optional[_Union[Message.BlockPresenceType, str]] = ...) -> None: ... 
+ class PaymentTerms(_message.Message): + __slots__ = ("cid", "asset", "pay_to", "amount", "network", "block_size", "description") + CID_FIELD_NUMBER: _ClassVar[int] + ASSET_FIELD_NUMBER: _ClassVar[int] + PAY_TO_FIELD_NUMBER: _ClassVar[int] + AMOUNT_FIELD_NUMBER: _ClassVar[int] + NETWORK_FIELD_NUMBER: _ClassVar[int] + BLOCK_SIZE_FIELD_NUMBER: _ClassVar[int] + DESCRIPTION_FIELD_NUMBER: _ClassVar[int] + cid: bytes + asset: str + pay_to: str + amount: int + network: str + block_size: int + description: str + def __init__(self, cid: _Optional[bytes] = ..., asset: _Optional[str] = ..., pay_to: _Optional[str] = ..., amount: _Optional[int] = ..., network: _Optional[str] = ..., block_size: _Optional[int] = ..., description: _Optional[str] = ...) -> None: ... + class TxReceipt(_message.Message): + __slots__ = ("cid", "tx_hash", "from_address", "to_address", "amount", "asset", "network") + CID_FIELD_NUMBER: _ClassVar[int] + TX_HASH_FIELD_NUMBER: _ClassVar[int] + FROM_ADDRESS_FIELD_NUMBER: _ClassVar[int] + TO_ADDRESS_FIELD_NUMBER: _ClassVar[int] + AMOUNT_FIELD_NUMBER: _ClassVar[int] + ASSET_FIELD_NUMBER: _ClassVar[int] + NETWORK_FIELD_NUMBER: _ClassVar[int] + cid: bytes + tx_hash: str + from_address: str + to_address: str + amount: int + asset: str + network: str + def __init__(self, cid: _Optional[bytes] = ..., tx_hash: _Optional[str] = ..., from_address: _Optional[str] = ..., to_address: _Optional[str] = ..., amount: _Optional[int] = ..., asset: _Optional[str] = ..., network: _Optional[str] = ...) -> None: ... + class PaymentReceipt(_message.Message): + __slots__ = ("cid", "tx_hash", "expires") + CID_FIELD_NUMBER: _ClassVar[int] + TX_HASH_FIELD_NUMBER: _ClassVar[int] + EXPIRES_FIELD_NUMBER: _ClassVar[int] + cid: bytes + tx_hash: str + expires: int + def __init__(self, cid: _Optional[bytes] = ..., tx_hash: _Optional[str] = ..., expires: _Optional[int] = ...) -> None: ... 
+ class PaymentRejection(_message.Message): + __slots__ = ("cid", "reason") + CID_FIELD_NUMBER: _ClassVar[int] + REASON_FIELD_NUMBER: _ClassVar[int] + cid: bytes + reason: str + def __init__(self, cid: _Optional[bytes] = ..., reason: _Optional[str] = ...) -> None: ... + WANTLIST_FIELD_NUMBER: _ClassVar[int] + BLOCKS_FIELD_NUMBER: _ClassVar[int] + PAYLOAD_FIELD_NUMBER: _ClassVar[int] + BLOCKPRESENCES_FIELD_NUMBER: _ClassVar[int] + PENDINGBYTES_FIELD_NUMBER: _ClassVar[int] + PAYMENT_TERMS_FIELD_NUMBER: _ClassVar[int] + TX_RECEIPTS_FIELD_NUMBER: _ClassVar[int] + PAYMENT_RECEIPTS_FIELD_NUMBER: _ClassVar[int] + PAYMENT_REJECTIONS_FIELD_NUMBER: _ClassVar[int] + wantlist: Message.Wantlist + blocks: _containers.RepeatedScalarFieldContainer[bytes] + payload: _containers.RepeatedCompositeFieldContainer[Message.Block] + blockPresences: _containers.RepeatedCompositeFieldContainer[Message.BlockPresence] + pendingBytes: int + payment_terms: _containers.RepeatedCompositeFieldContainer[Message.PaymentTerms] + tx_receipts: _containers.RepeatedCompositeFieldContainer[Message.TxReceipt] + payment_receipts: _containers.RepeatedCompositeFieldContainer[Message.PaymentReceipt] + payment_rejections: _containers.RepeatedCompositeFieldContainer[Message.PaymentRejection] + def __init__(self, wantlist: _Optional[_Union[Message.Wantlist, _Mapping[str, _Any]]] = ..., blocks: _Optional[_Iterable[bytes]] = ..., payload: _Optional[_Iterable[_Union[Message.Block, _Mapping[str, _Any]]]] = ..., blockPresences: _Optional[_Iterable[_Union[Message.BlockPresence, _Mapping[str, _Any]]]] = ..., pendingBytes: _Optional[int] = ..., payment_terms: _Optional[_Iterable[_Union[Message.PaymentTerms, _Mapping[str, _Any]]]] = ..., tx_receipts: _Optional[_Iterable[_Union[Message.TxReceipt, _Mapping[str, _Any]]]] = ..., payment_receipts: _Optional[_Iterable[_Union[Message.PaymentReceipt, _Mapping[str, _Any]]]] = ..., payment_rejections: _Optional[_Iterable[_Union[Message.PaymentRejection, _Mapping[str, _Any]]]] = 
logger = logging.getLogger(__name__)


class BlockPricingEngine:
    """
    Computes prices for Bitswap blocks based on configurable strategies.

    Pricing is typically done at the root CID level (total file size),
    not per-block, to avoid charging for each chunk separately.

    Resolution order used by compute_price():
        1. per-CID override registered via set_price() / set_free()
        2. ``default_free`` (everything without an override is free)
        3. the configured ``strategy``

    Example:
        >>> # Size-based pricing: 100 micro-USDC per KB (1 KB = 1024 bytes)
        >>> pricing = BlockPricingEngine(
        ...     strategy="size_based",
        ...     units_per_kb=100,
        ... )
        >>>
        >>> # 5,000,000 B = 4882.8 KB × 100 = 488,281 micro-units ≈ $0.49
        >>> price = pricing.compute_price("bafyroot...", block_size=5_000_000)
        >>> print(f"${price / 1_000_000:.2f}")  # $0.49
        >>>
        >>> # Mark specific CIDs as free
        >>> pricing.set_free("bafyfree123...")
        >>> pricing.compute_price("bafyfree123...", 1_000_000)  # 0 (free)

    """

    def __init__(
        self,
        strategy: str = "size_based",
        units_per_kb: float = 100.0,
        fixed_price: int = 0,
        custom_pricing_fn: Callable[[str, int], int] | None = None,
        default_free: bool = False,
    ):
        """
        Initialize pricing engine.

        Args:
            strategy: Pricing strategy - "free", "fixed", "size_based", or "custom"
            units_per_kb: Price per KB for size_based strategy (micro-units)
            fixed_price: Fixed price for "fixed" strategy (micro-units)
            custom_pricing_fn: Custom function(cid_str, size) → price
                               for "custom" strategy
            default_free: If True, all CIDs are free by default

        Strategies:
            - "free": Always return 0 (all blocks free)
            - "fixed": Return fixed_price for all blocks
            - "size_based": price = max(1, int(size_kb * units_per_kb))
            - "custom": Use custom_pricing_fn(cid_str, block_size)

        Note:
            The strategy name is not validated here; an unknown name (or a
            "custom" strategy without a function) raises ValueError on the
            first compute_price() call that reaches the strategy step.

        """
        self.strategy = strategy
        self.units_per_kb = units_per_kb
        self.fixed_price = fixed_price
        self.custom_pricing_fn = custom_pricing_fn
        self.default_free = default_free

        # Per-CID overrides: normalised cid_hex → price (0 = free, >0 = price).
        # Keys are always produced by _cid_to_hex().
        self._cid_prices: dict[str, int] = {}

        logger.info(
            f"Pricing engine initialized: strategy={strategy} "
            f"units_per_kb={units_per_kb} default_free={default_free}"
        )

    def set_price(self, cid: str | bytes, price: int) -> None:
        """
        Set a specific price for a CID (overrides strategy).

        Args:
            cid: The CID (hex string, textual CID, or raw bytes)
            price: Price in micro-units (0 = free)

        Raises:
            TypeError: If ``cid`` is neither str nor bytes.

        """
        cid_hex = _cid_to_hex(cid)
        self._cid_prices[cid_hex] = price
        logger.info(f"Set price for {cid_hex[:20]}... = {price} units")

    def set_free(self, cid: str | bytes) -> None:
        """
        Mark a CID as free (price = 0).

        Args:
            cid: The CID to mark as free

        """
        self.set_price(cid, 0)

    def compute_price(self, cid_str: str, block_size: int) -> int:
        """
        Compute the price for a block/file.

        Args:
            cid_str: The CID as a string (hex or textual form; it is
                normalised the same way set_price() normalises its key)
            block_size: Size in bytes (for root CID, this is total file size)

        Returns:
            Price in micro-units (0 = free, >0 = paid)

        Raises:
            ValueError: If the strategy is unknown, or is "custom" without
                a ``custom_pricing_fn``.
            TypeError: If ``cid_str`` is neither str nor bytes.

        Note:
            For multi-block files, call this ONCE with the root CID and total size,
            not for each individual chunk.

        """
        # Overrides are stored under the normalised key, so the lookup has to
        # normalise too; otherwise set_free("bafy...") would never match.
        override = self._cid_prices.get(_cid_to_hex(cid_str))
        if override is not None:
            logger.debug(f"Using override price for {cid_str[:20]}... = {override}")
            return override

        # Apply default free policy
        if self.default_free:
            return 0

        # Apply strategy
        if self.strategy == "free":
            return 0

        if self.strategy == "fixed":
            return self.fixed_price

        if self.strategy == "size_based":
            # Price = units_per_kb × size_in_kb, truncated, minimum 1 unit
            kb = block_size / 1024
            price = max(1, int(kb * self.units_per_kb))
            logger.debug(
                f"Size-based pricing: {block_size}B = {kb:.2f}KB × "
                f"{self.units_per_kb} = {price} units"
            )
            return price

        if self.strategy == "custom":
            if self.custom_pricing_fn is None:
                raise ValueError("Custom strategy requires custom_pricing_fn")
            # The callback receives the caller's original string, not the
            # normalised lookup key.
            return self.custom_pricing_fn(cid_str, block_size)

        raise ValueError(f"Unknown pricing strategy: {self.strategy}")

    def get_units_per_kb(self) -> float:
        """
        Get the current units_per_kb rate (for size_based strategy).

        Returns:
            Units per KB, or 0.0 if not using size_based strategy

        """
        if self.strategy == "size_based":
            return self.units_per_kb
        return 0.0


# ── Helper functions ──────────────────────────────────────────────────────────


def _cid_to_hex(cid: str | bytes) -> str:
    """
    Convert a CID to a canonical lowercase hex string for use as a dict key.

    - bytes: hex-encoded directly.
    - str that parses as hex: canonicalised (lowercased, whitespace removed)
      so it yields the same key as the equivalent bytes.
    - any other str (e.g. a base58/base32 CID): the hex of its UTF-8 encoding.

    Raises:
        TypeError: If ``cid`` is neither str nor bytes.

    """
    if isinstance(cid, bytes):
        return cid.hex()
    if isinstance(cid, str):
        try:
            # Round-trip through bytes so "ABCD" and "abcd" share one key
            return bytes.fromhex(cid).hex()
        except ValueError:
            # Not hex: assume it's a base58/base32 encoded CID string
            return cid.encode().hex()
    raise TypeError(f"CID must be str or bytes, got {type(cid)}")
logger = logging.getLogger(__name__)


@dataclass
class ProviderCacheEntry:
    """
    Cached provider information for a CID.

    Attributes:
        providers: List of peer IDs that provide this content
        timestamp: When this entry was cached (``time.time()`` seconds)
        ttl: Time-to-live in seconds (how long the cache is valid)

    """

    providers: list[PeerID]
    timestamp: float = field(default_factory=time.time)
    ttl: float = 300  # 5 minutes default

    def is_expired(self) -> bool:
        """Check if this cache entry has expired (age strictly above ttl)."""
        return self.age() > self.ttl

    def age(self) -> float:
        """Get the age of this cache entry in seconds."""
        return time.time() - self.timestamp


class ProviderCache:
    """
    LRU cache for provider records with TTL support.

    Caches DHT provider query results to reduce network load and improve
    performance for repeated queries.

    Recency is tracked through the insertion order of the backing dict:
    the first key is the least recently used, the last key the most
    recently used. Touching an entry re-inserts it at the end.
    """

    def __init__(self, max_size: int = 1000, default_ttl: float = 300):
        """
        Initialize provider cache.

        Args:
            max_size: Maximum number of entries to cache
            default_ttl: Default time-to-live in seconds

        """
        self.max_size = max_size
        self.default_ttl: float = default_ttl
        # Ordered oldest → newest access; see class docstring
        self._cache: dict[bytes, ProviderCacheEntry] = {}

    def get(self, cid_bytes: bytes) -> list[PeerID] | None:
        """
        Get cached providers for a CID.

        An expired entry is removed as a side effect of the lookup; a hit
        marks the entry as most recently used.

        Args:
            cid_bytes: CID as bytes

        Returns:
            List of provider peer IDs if cached and not expired, None otherwise

        """
        entry = self._cache.get(cid_bytes)
        if entry is None:
            return None

        if entry.is_expired():
            self._remove(cid_bytes)
            return None

        self._mark_accessed(cid_bytes)
        return entry.providers

    def put(
        self,
        cid_bytes: bytes,
        providers: list[PeerID],
        ttl: float | None = None,
    ) -> None:
        """
        Cache providers for a CID.

        Args:
            cid_bytes: CID as bytes
            providers: List of provider peer IDs
            ttl: Optional custom TTL; ``None`` means "use the default".
                 An explicit ``0`` is honoured rather than replaced.

        """
        # Only a brand-new key can grow the cache, so only then evict
        if cid_bytes not in self._cache and len(self._cache) >= self.max_size:
            self._evict_oldest()

        # Drop any previous entry first so the new one lands at the MRU end
        self._cache.pop(cid_bytes, None)
        self._cache[cid_bytes] = ProviderCacheEntry(
            providers=providers,
            timestamp=time.time(),
            ttl=self.default_ttl if ttl is None else ttl,
        )

    def _mark_accessed(self, cid_bytes: bytes) -> None:
        """Mark a cache entry as recently accessed (for LRU)."""
        entry = self._cache.pop(cid_bytes, None)
        if entry is not None:
            # Re-insert to move the key to the most-recently-used position
            self._cache[cid_bytes] = entry

    def _evict_oldest(self) -> None:
        """Evict the least recently used cache entry (no-op when empty)."""
        if not self._cache:
            return
        oldest = next(iter(self._cache))
        del self._cache[oldest]

    def _remove(self, cid_bytes: bytes) -> None:
        """Remove an entry from the cache if present."""
        self._cache.pop(cid_bytes, None)

    def clear(self) -> None:
        """Clear all cache entries."""
        self._cache.clear()

    def cleanup_expired(self) -> int:
        """
        Remove all expired entries from the cache.

        Returns:
            Number of entries removed

        """
        # Collect first: the dict must not change size while being iterated
        expired = [
            cid_bytes for cid_bytes, entry in self._cache.items() if entry.is_expired()
        ]

        for cid_bytes in expired:
            self._remove(cid_bytes)

        return len(expired)

    def size(self) -> int:
        """Get current cache size (expired-but-not-yet-purged entries count)."""
        return len(self._cache)

    def stats(self) -> dict[str, int]:
        """
        Get cache statistics.

        Returns:
            Dictionary with keys ``size``, ``max_size`` and ``expired``
            (entries still stored but past their TTL)

        """
        return {
            "size": len(self._cache),
            "max_size": self.max_size,
            "expired": sum(1 for e in self._cache.values() if e.is_expired()),
        }


class ProviderQueryManager:
    """
    Manages DHT provider queries with caching and parallelization.

    This component integrates Bitswap with the Kademlia DHT to automatically
    discover which peers have specific content. It provides:

    - Automatic provider discovery via DHT
    - Parallel queries for multiple CIDs
    - Provider caching to reduce DHT load
    - Configurable limits and timeouts
    - Error handling (failed or timed-out queries are logged and counted,
      not raised; there is no automatic retry)

    Example:
        >>> dht = KadDHT(host)
        >>> manager = ProviderQueryManager(dht)
        >>> providers = await manager.find_providers([cid1, cid2])
        >>> print(f"Found {len(providers)} provider mappings")

    """

    def __init__(
        self,
        dht: KadDHT,
        max_providers: int = 10,
        cache_ttl: float = 300,  # 5 minutes
        cache_size: int = 1000,
        max_concurrent_queries: int = 20,
    ):
        """
        Initialize Provider Query Manager.

        Args:
            dht: Kademlia DHT instance for provider queries
            max_providers: Maximum number of providers to return per CID
            cache_ttl: Cache time-to-live in seconds
            cache_size: Maximum number of CIDs to cache
            max_concurrent_queries: Maximum parallel DHT queries

        """
        self.dht = dht
        self.max_providers = max_providers
        self.cache = ProviderCache(max_size=cache_size, default_ttl=cache_ttl)
        self.query_semaphore = trio.Semaphore(max_concurrent_queries)

        # Statistics
        self._stats = {
            "queries": 0,
            "cache_hits": 0,
            "cache_misses": 0,
            "errors": 0,
            "providers_found": 0,
        }

    async def find_providers(
        self,
        cids: Sequence[CIDInput],
        timeout: float = 5.0,
        use_cache: bool = True,
    ) -> dict[bytes, list[PeerID]]:
        """
        Find providers for multiple CIDs in parallel.

        This is the main entry point for provider discovery. It:
        1. Checks cache for each CID
        2. Queries DHT in parallel for cache misses
        3. Updates cache with results
        4. Returns combined results

        A CID that appears more than once in ``cids`` is looked up once.
        CIDs for which no provider was found (or whose query failed) are
        absent from the returned dictionary.

        Args:
            cids: List of CIDs to find providers for
            timeout: Timeout per DHT query in seconds
            use_cache: Whether to use cached results

        Returns:
            Dictionary mapping CID bytes to list of provider peer IDs

        Example:
            >>> cids = [cid1, cid2, cid3]
            >>> results = await manager.find_providers(cids)
            >>> for cid_bytes, providers in results.items():
            ...     n = len(providers)
            ...     print(f"CID {cid_bytes.hex()[:8]}... has {n} providers")

        """
        results: dict[bytes, list[PeerID]] = {}
        missing: list[tuple[CIDInput, bytes]] = []
        queued: set[bytes] = set()

        # Phase 1: Check cache
        for cid in cids:
            cid_bytes = cid_to_bytes(cid)

            # Repeated input CID: already answered or already scheduled
            if cid_bytes in results or cid_bytes in queued:
                continue

            if use_cache:
                cached = self.cache.get(cid_bytes)
                if cached is not None:
                    results[cid_bytes] = cached
                    self._stats["cache_hits"] += 1
                    logger.debug(
                        f"Cache hit for {format_cid_for_display(cid, max_len=12)}: "
                        f"{len(cached)} providers"
                    )
                    continue

            # Not in cache or cache disabled
            missing.append((cid, cid_bytes))
            queued.add(cid_bytes)
            self._stats["cache_misses"] += 1

        if not missing:
            logger.debug(f"All {len(cids)} CIDs found in cache")
            return results

        logger.info(
            f"Querying DHT for {len(missing)} CIDs (cache hits: {len(results)})"
        )

        # Phase 2: Query DHT in parallel for missing CIDs. The nursery block
        # only exits once every spawned query has finished.
        async with trio.open_nursery() as nursery:
            for cid, cid_bytes in missing:
                nursery.start_soon(
                    self._query_single,
                    cid,
                    cid_bytes,
                    results,
                    timeout,
                )

        logger.info(
            f"Provider discovery complete: {len(results)}/{len(cids)} CIDs resolved"
        )

        return results

    async def _query_single(
        self,
        cid: CIDInput,
        cid_bytes: bytes,
        results: dict[bytes, list[PeerID]],
        timeout: float,
    ) -> None:
        """
        Query DHT for providers of a single CID.

        This method is called concurrently for each CID. It uses a semaphore
        to limit parallelism. Timeouts and other exceptions are logged and
        counted in the ``errors`` statistic instead of propagating.

        Args:
            cid: CID to query (for display)
            cid_bytes: CID as bytes (for DHT query)
            results: Shared results dictionary to update
            timeout: Query timeout in seconds (starts once the semaphore
                slot has been acquired)

        """
        async with self.query_semaphore:
            self._stats["queries"] += 1

            try:
                with trio.fail_after(timeout):
                    # NOTE(review): the original comment states this is a
                    # network lookup rather than a local-store read; that
                    # depends on provider_store.find_providers — confirm.
                    provider_infos = await self.dht.provider_store.find_providers(
                        cid_bytes, self.max_providers
                    )

                # Extract peer IDs and cap at max_providers
                providers = [info.peer_id for info in provider_infos]
                providers = providers[: self.max_providers]

                if providers:
                    # Update results
                    results[cid_bytes] = providers

                    # Update cache with remote results
                    self.cache.put(cid_bytes, providers)

                    # Update stats
                    self._stats["providers_found"] += len(providers)

                    logger.debug(
                        f"Found {len(providers)} providers for "
                        f"{format_cid_for_display(cid, max_len=12)}"
                    )
                else:
                    # Empty results are deliberately not cached
                    logger.debug(
                        f"No providers found for "
                        f"{format_cid_for_display(cid, max_len=12)}"
                    )

            except trio.TooSlowError:
                self._stats["errors"] += 1
                logger.warning(
                    f"DHT query timeout for {format_cid_for_display(cid, max_len=12)}"
                )
            except Exception as e:
                self._stats["errors"] += 1
                cid_disp = format_cid_for_display(cid, max_len=12)
                logger.error(f"DHT query error for {cid_disp}: {e}")

    async def find_providers_single(
        self,
        cid: CIDInput,
        timeout: float = 5.0,
        use_cache: bool = True,
    ) -> list[PeerID]:
        """
        Find providers for a single CID (convenience method).

        Args:
            cid: CID to find providers for
            timeout: Query timeout in seconds
            use_cache: Whether to use cached results

        Returns:
            List of provider peer IDs (empty if none were found)

        Example:
            >>> providers = await manager.find_providers_single(cid)
            >>> for peer_id in providers:
            ...     print(f"Provider: {peer_id}")

        """
        results = await self.find_providers([cid], timeout, use_cache)
        cid_bytes = cid_to_bytes(cid)
        return results.get(cid_bytes, [])

    def get_stats(self) -> dict[str, int]:
        """
        Get provider query statistics.

        Returns:
            Dictionary with statistics:
            - queries: Total DHT queries made
            - cache_hits: Number of cache hits
            - cache_misses: Number of cache misses
            - errors: Number of query errors
            - providers_found: Total providers discovered
            - size: Current cache size
            - max_size: Configured cache capacity
            - expired: Cached entries that are past their TTL

        Example:
            >>> stats = manager.get_stats()
            >>> lookups = stats["cache_hits"] + stats["cache_misses"]
            >>> if lookups:
            ...     print(f"Cache hit rate: {stats['cache_hits'] / lookups:.1%}")

        """
        stats = self._stats.copy()
        stats.update(self.cache.stats())
        return stats

    def clear_cache(self) -> None:
        """Clear the provider cache."""
        self.cache.clear()
        logger.info("Provider cache cleared")

    async def cleanup_expired_cache(self) -> int:
        """
        Remove expired entries from cache.

        Returns:
            Number of entries removed

        """
        removed = self.cache.cleanup_expired()
        if removed > 0:
            logger.debug(f"Removed {removed} expired cache entries")
        return removed
# ── enums ─────────────────────────────────────────────────────────────────────


class WantType(Enum):
    """
    Type of want request (Bitswap 1.2.0 wantType field).

    Block = 0  →  "Send me the full block bytes."
    Have  = 1  →  "Just tell me if you have it (HAVE/DONT_HAVE response)."
                   Cheaper than Block — useful for presence checks before
                   committing to a full block transfer.
    """

    Block = 0
    Have = 1


class BlockPresenceType(Enum):
    """
    Type of block presence response (Bitswap 1.2.0 BlockPresence.type field).

    Have     = 0  →  Peer has the block and can send it.
    DontHave = 1  →  Peer does not have the block.
    """

    Have = 0
    DontHave = 1


# ── wantlist dataclasses ──────────────────────────────────────────────────────


@dataclass
class WantlistEntry:
    """
    A single entry in a Bitswap wantlist.

    Prefer constructing via WantlistEntry.from_cid() which normalises
    any CIDInput form to raw bytes.

    Attributes:
        cid:            CID of the requested block as raw bytes.
        priority:       Request urgency. Higher = more urgent. Default 1.
        cancel:         True to cancel a previously sent want for this CID.
        want_type:      WantType.Block (full data) or WantType.Have (presence).
        send_dont_have: If True, ask the peer to send an explicit DontHave
                        response when it doesn't have the block.

    """

    cid: bytes
    priority: int = 1
    cancel: bool = False
    want_type: WantType = WantType.Block
    send_dont_have: bool = False

    @classmethod
    def from_cid(
        cls,
        cid: CIDInput,
        priority: int = 1,
        cancel: bool = False,
        want_type: WantType = WantType.Block,
        send_dont_have: bool = False,
    ) -> WantlistEntry:
        """Create a WantlistEntry from any CIDInput form."""
        return cls(
            cid=cid_to_bytes(cid),
            priority=priority,
            cancel=cancel,
            want_type=want_type,
            send_dont_have=send_dont_have,
        )


@dataclass
class Wantlist:
    """
    A collection of wantlist entries.

    Attributes:
        entries: List of WantlistEntry items, in the order they were added.
        full:    True  = this replaces the peer's entire wantlist.
                 False (default) = delta update, adds/cancels entries.

    Example:
        >>> wl = Wantlist()
        >>> wl.add(cid1, want_type=WantType.Block, send_dont_have=True)
        >>> wl.add(cid2, want_type=WantType.Have)
        >>> wl.cancel(cid3)
        >>> print(len(wl))  # 3

    """

    entries: list[WantlistEntry] = field(default_factory=list)
    full: bool = False

    def add(
        self,
        cid: CIDInput,
        priority: int = 1,
        want_type: WantType = WantType.Block,
        send_dont_have: bool = False,
    ) -> None:
        """Add a want entry for the given CID."""
        self.entries.append(
            WantlistEntry.from_cid(
                cid,
                priority=priority,
                want_type=want_type,
                send_dont_have=send_dont_have,
            )
        )

    def cancel(self, cid: CIDInput) -> None:
        """Add a cancel entry for a previously wanted CID."""
        self.entries.append(WantlistEntry.from_cid(cid, cancel=True))

    def contains(self, cid: CIDInput) -> bool:
        """
        Return True if any non-cancel entry exists for this CID.

        Entries are not reconciled: a want followed by a cancel for the
        same CID still reports True.
        """
        cid_bytes = cid_to_bytes(cid)
        return any(e.cid == cid_bytes and not e.cancel for e in self.entries)

    def __len__(self) -> int:
        return len(self.entries)

    def __bool__(self) -> bool:
        return bool(self.entries)


# ── message dataclasses ───────────────────────────────────────────────────────


@dataclass
class BlockPresence:
    """
    A HAVE or DONT_HAVE response for a specific CID (Bitswap 1.2.0).

    Use the class-method constructors for convenience:
        BlockPresence.have(cid)
        BlockPresence.dont_have(cid)
    """

    cid: bytes
    type: BlockPresenceType

    @classmethod
    def have(cls, cid: CIDInput) -> BlockPresence:
        """Create a HAVE response."""
        return cls(cid=cid_to_bytes(cid), type=BlockPresenceType.Have)

    @classmethod
    def dont_have(cls, cid: CIDInput) -> BlockPresence:
        """Create a DONT_HAVE response."""
        return cls(cid=cid_to_bytes(cid), type=BlockPresenceType.DontHave)


@dataclass
class BitswapMessage:
    """
    High-level typed representation of a Bitswap protocol message.

    Wraps the main message components with typed fields and convenience
    builder methods. Protobuf objects are only touched by to_proto() and
    from_proto(); every other member works on plain Python values.

    Attributes:
        wantlist:        Optional wantlist (want/cancel entries).
        blocks:          List of (cid_bytes, block_data) block payloads.
        block_presences: List of HAVE/DONT_HAVE presence responses.
        pending_bytes:   Bytes queued to send (v1.2.0 flow-control hint).

    Properties:
        is_want          True if the message contains want entries.
        has_blocks       True if the message contains block payloads.
        has_presences    True if the message contains HAVE/DONT_HAVE entries.

    Example:
        >>> msg = BitswapMessage()
        >>> msg.add_want(cid1, want_type=WantType.Block, send_dont_have=True)
        >>> msg.add_want(cid2, want_type=WantType.Have)
        >>> msg.add_block(root_cid, data)
        >>> msg.add_have(cid3)
        >>> msg.add_dont_have(cid4)
        >>> assert msg.is_want and msg.has_blocks and msg.has_presences

    """

    wantlist: Wantlist | None = None
    blocks: list[tuple[bytes, bytes]] = field(default_factory=list)  # (cid, data)
    block_presences: list[BlockPresence] = field(default_factory=list)
    pending_bytes: int = 0

    # ── read-only properties ──────────────────────────────────────────────────

    @property
    def is_want(self) -> bool:
        """True if this message contains wantlist entries."""
        return self.wantlist is not None and bool(self.wantlist)

    @property
    def has_blocks(self) -> bool:
        """True if this message carries block payloads."""
        return bool(self.blocks)

    @property
    def has_presences(self) -> bool:
        """True if this message carries HAVE/DONT_HAVE responses."""
        return bool(self.block_presences)

    # ── builder methods ───────────────────────────────────────────────────────

    def add_want(
        self,
        cid: CIDInput,
        priority: int = 1,
        want_type: WantType = WantType.Block,
        send_dont_have: bool = False,
    ) -> None:
        """Add a want entry. Creates the wantlist if not yet present."""
        if self.wantlist is None:
            self.wantlist = Wantlist()
        self.wantlist.add(
            cid,
            priority=priority,
            want_type=want_type,
            send_dont_have=send_dont_have,
        )

    def cancel_want(self, cid: CIDInput) -> None:
        """Add a cancel entry. Creates the wantlist if not yet present."""
        if self.wantlist is None:
            self.wantlist = Wantlist()
        self.wantlist.cancel(cid)

    def add_block(self, cid: CIDInput, data: bytes) -> None:
        """Add a block payload to this message."""
        self.blocks.append((cid_to_bytes(cid), data))

    def add_have(self, cid: CIDInput) -> None:
        """Add a HAVE presence response."""
        self.block_presences.append(BlockPresence.have(cid))

    def add_dont_have(self, cid: CIDInput) -> None:
        """Add a DONT_HAVE presence response."""
        self.block_presences.append(BlockPresence.dont_have(cid))

    # ── protobuf conversion ───────────────────────────────────────────────────

    def to_proto(self) -> PBMessage:
        """
        Convert to a raw protobuf Message object (pb.bitswap_pb2.Message).

        Blocks are written to the ``payload`` field as (CID prefix, data)
        pairs; the legacy ``blocks`` bytes field is not populated.

        Returns:
            A populated protobuf Message ready for serialisation.

        """
        proto = PBMessage()

        if self.wantlist is not None:
            for entry in self.wantlist.entries:
                pb_entry = proto.wantlist.entries.add()
                pb_entry.block = entry.cid
                pb_entry.priority = entry.priority
                pb_entry.cancel = entry.cancel
                pb_entry.wantType = entry.want_type.value  # type: ignore[assignment]
                pb_entry.sendDontHave = entry.send_dont_have
            proto.wantlist.full = self.wantlist.full

        if self.blocks:
            # Imported once, and only when there is a payload to encode
            from .cid import get_cid_prefix

            for cid_bytes, data in self.blocks:
                pb_block = proto.payload.add()
                pb_block.prefix = get_cid_prefix(cid_bytes)
                pb_block.data = data

        for presence in self.block_presences:
            pb_presence = proto.blockPresences.add()
            pb_presence.cid = presence.cid
            pb_presence.type = presence.type.value  # type: ignore[assignment]

        if self.pending_bytes:
            proto.pendingBytes = self.pending_bytes

        return proto

    @classmethod
    def from_proto(cls, proto: PBMessage) -> BitswapMessage:
        """
        Build a BitswapMessage from a raw protobuf Message object.

        A wantlist is materialised when the protobuf wantlist carries
        entries OR has ``full`` set. A full wantlist with no entries means
        "replace my wantlist with nothing" and must survive the conversion.
        A present-but-empty, non-full wantlist carries no information and
        yields ``wantlist=None``.

        Args:
            proto: A pb.bitswap_pb2.Message instance.

        Returns:
            A populated BitswapMessage dataclass.

        Raises:
            ValueError: If an entry's wantType or a presence's type is not
                a known enum value.

        """
        msg = cls()

        if proto.HasField("wantlist"):
            pb_wantlist = proto.wantlist
            if pb_wantlist.entries or pb_wantlist.full:
                msg.wantlist = Wantlist(
                    entries=[
                        WantlistEntry(
                            cid=bytes(e.block),
                            priority=e.priority,
                            cancel=e.cancel,
                            want_type=WantType(e.wantType),
                            send_dont_have=e.sendDontHave,
                        )
                        for e in pb_wantlist.entries
                    ],
                    full=pb_wantlist.full,
                )

        if proto.payload:
            # Imported only when there is a payload to decode
            from .cid import reconstruct_cid_from_prefix_and_data

            for pb_block in proto.payload:
                data = bytes(pb_block.data)
                cid_bytes = reconstruct_cid_from_prefix_and_data(
                    bytes(pb_block.prefix), data
                )
                msg.blocks.append((cid_bytes, data))

        for pb_presence in proto.blockPresences:
            msg.block_presences.append(
                BlockPresence(
                    cid=bytes(pb_presence.cid),
                    type=BlockPresenceType(pb_presence.type),
                )
            )

        msg.pending_bytes = proto.pendingBytes
        return msg
a/libp2p/kad_dht/pb/kademlia.proto b/libp2p/kad_dht/pb/kademlia.proto index 8d66cca5c..93fe526c3 100644 --- a/libp2p/kad_dht/pb/kademlia.proto +++ b/libp2p/kad_dht/pb/kademlia.proto @@ -4,6 +4,11 @@ message Record { bytes key = 1; bytes value = 2; string timeReceived = 5; + // author is the serialized public key of the record author (for unsigned records) + optional bytes author = 3; + // signature is the Ed25519/Secp256k1 signature over the record + // signing payload: "libp2p-record:" + key + value + optional bytes signature = 4; }; message Message { @@ -39,4 +44,3 @@ message Message { optional bytes senderRecord = 11; // Envelope(PeerRecord) encoded } -` diff --git a/libp2p/kad_dht/pb/kademlia_pb2.py b/libp2p/kad_dht/pb/kademlia_pb2.py index e41bb5292..19b4c2ca2 100644 --- a/libp2p/kad_dht/pb/kademlia_pb2.py +++ b/libp2p/kad_dht/pb/kademlia_pb2.py @@ -1,22 +1,12 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! -# NO CHECKED-IN PROTOBUF GENCODE # source: libp2p/kad_dht/pb/kademlia.proto -# Protobuf Python Version: 5.29.3 +# Protobuf Python Version: 4.25.1 """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import runtime_version as _runtime_version from google.protobuf import symbol_database as _symbol_database from google.protobuf.internal import builder as _builder -_runtime_version.ValidateProtobufRuntimeVersion( - _runtime_version.Domain.PUBLIC, - 5, - 29, - 3, - '', - 'libp2p/kad_dht/pb/kademlia.proto' -) # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() @@ -24,21 +14,21 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n libp2p/kad_dht/pb/kademlia.proto\":\n\x06Record\x12\x0b\n\x03key\x18\x01 \x01(\x0c\x12\r\n\x05value\x18\x02 \x01(\x0c\x12\x14\n\x0ctimeReceived\x18\x05 \x01(\t\"\xa2\x04\n\x07Message\x12\"\n\x04type\x18\x01 
\x01(\x0e\x32\x14.Message.MessageType\x12\x17\n\x0f\x63lusterLevelRaw\x18\n \x01(\x05\x12\x0b\n\x03key\x18\x02 \x01(\x0c\x12\x17\n\x06record\x18\x03 \x01(\x0b\x32\x07.Record\x12\"\n\x0b\x63loserPeers\x18\x08 \x03(\x0b\x32\r.Message.Peer\x12$\n\rproviderPeers\x18\t \x03(\x0b\x32\r.Message.Peer\x12\x19\n\x0csenderRecord\x18\x0b \x01(\x0cH\x00\x88\x01\x01\x1az\n\x04Peer\x12\n\n\x02id\x18\x01 \x01(\x0c\x12\r\n\x05\x61\x64\x64rs\x18\x02 \x03(\x0c\x12+\n\nconnection\x18\x03 \x01(\x0e\x32\x17.Message.ConnectionType\x12\x19\n\x0csignedRecord\x18\x04 \x01(\x0cH\x00\x88\x01\x01\x42\x0f\n\r_signedRecord\"i\n\x0bMessageType\x12\r\n\tPUT_VALUE\x10\x00\x12\r\n\tGET_VALUE\x10\x01\x12\x10\n\x0c\x41\x44\x44_PROVIDER\x10\x02\x12\x11\n\rGET_PROVIDERS\x10\x03\x12\r\n\tFIND_NODE\x10\x04\x12\x08\n\x04PING\x10\x05\"W\n\x0e\x43onnectionType\x12\x11\n\rNOT_CONNECTED\x10\x00\x12\r\n\tCONNECTED\x10\x01\x12\x0f\n\x0b\x43\x41N_CONNECT\x10\x02\x12\x12\n\x0e\x43\x41NNOT_CONNECT\x10\x03\x42\x0f\n\r_senderRecordb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n libp2p/kad_dht/pb/kademlia.proto\"\x80\x01\n\x06Record\x12\x0b\n\x03key\x18\x01 \x01(\x0c\x12\r\n\x05value\x18\x02 \x01(\x0c\x12\x14\n\x0ctimeReceived\x18\x05 \x01(\t\x12\x13\n\x06\x61uthor\x18\x03 \x01(\x0cH\x00\x88\x01\x01\x12\x16\n\tsignature\x18\x04 \x01(\x0cH\x01\x88\x01\x01\x42\t\n\x07_authorB\x0c\n\n_signature\"\xa2\x04\n\x07Message\x12\"\n\x04type\x18\x01 \x01(\x0e\x32\x14.Message.MessageType\x12\x17\n\x0f\x63lusterLevelRaw\x18\n \x01(\x05\x12\x0b\n\x03key\x18\x02 \x01(\x0c\x12\x17\n\x06record\x18\x03 \x01(\x0b\x32\x07.Record\x12\"\n\x0b\x63loserPeers\x18\x08 \x03(\x0b\x32\r.Message.Peer\x12$\n\rproviderPeers\x18\t \x03(\x0b\x32\r.Message.Peer\x12\x19\n\x0csenderRecord\x18\x0b \x01(\x0cH\x00\x88\x01\x01\x1az\n\x04Peer\x12\n\n\x02id\x18\x01 \x01(\x0c\x12\r\n\x05\x61\x64\x64rs\x18\x02 \x03(\x0c\x12+\n\nconnection\x18\x03 \x01(\x0e\x32\x17.Message.ConnectionType\x12\x19\n\x0csignedRecord\x18\x04 
\x01(\x0cH\x00\x88\x01\x01\x42\x0f\n\r_signedRecord\"i\n\x0bMessageType\x12\r\n\tPUT_VALUE\x10\x00\x12\r\n\tGET_VALUE\x10\x01\x12\x10\n\x0c\x41\x44\x44_PROVIDER\x10\x02\x12\x11\n\rGET_PROVIDERS\x10\x03\x12\r\n\tFIND_NODE\x10\x04\x12\x08\n\x04PING\x10\x05\"W\n\x0e\x43onnectionType\x12\x11\n\rNOT_CONNECTED\x10\x00\x12\r\n\tCONNECTED\x10\x01\x12\x0f\n\x0b\x43\x41N_CONNECT\x10\x02\x12\x12\n\x0e\x43\x41NNOT_CONNECT\x10\x03\x42\x0f\n\r_senderRecordb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.kad_dht.pb.kademlia_pb2', _globals) -if not _descriptor._USE_C_DESCRIPTORS: - DESCRIPTOR._loaded_options = None - _globals['_RECORD']._serialized_start=36 - _globals['_RECORD']._serialized_end=94 - _globals['_MESSAGE']._serialized_start=97 - _globals['_MESSAGE']._serialized_end=643 - _globals['_MESSAGE_PEER']._serialized_start=308 - _globals['_MESSAGE_PEER']._serialized_end=430 - _globals['_MESSAGE_MESSAGETYPE']._serialized_start=432 - _globals['_MESSAGE_MESSAGETYPE']._serialized_end=537 - _globals['_MESSAGE_CONNECTIONTYPE']._serialized_start=539 - _globals['_MESSAGE_CONNECTIONTYPE']._serialized_end=626 +if _descriptor._USE_C_DESCRIPTORS == False: + DESCRIPTOR._options = None + _globals['_RECORD']._serialized_start=37 + _globals['_RECORD']._serialized_end=165 + _globals['_MESSAGE']._serialized_start=168 + _globals['_MESSAGE']._serialized_end=714 + _globals['_MESSAGE_PEER']._serialized_start=379 + _globals['_MESSAGE_PEER']._serialized_end=501 + _globals['_MESSAGE_MESSAGETYPE']._serialized_start=503 + _globals['_MESSAGE_MESSAGETYPE']._serialized_end=608 + _globals['_MESSAGE_CONNECTIONTYPE']._serialized_start=610 + _globals['_MESSAGE_CONNECTIONTYPE']._serialized_end=697 # @@protoc_insertion_point(module_scope) diff --git a/libp2p/kad_dht/pb/kademlia_pb2.pyi b/libp2p/kad_dht/pb/kademlia_pb2.pyi index 641ae66ae..ae32c2361 100644 --- a/libp2p/kad_dht/pb/kademlia_pb2.pyi +++ 
b/libp2p/kad_dht/pb/kademlia_pb2.pyi @@ -1,144 +1,74 @@ -""" -@generated by mypy-protobuf. Do not edit manually! -isort:skip_file -""" - -import builtins -import collections.abc -import google.protobuf.descriptor -import google.protobuf.internal.containers -import google.protobuf.internal.enum_type_wrapper -import google.protobuf.message -import sys -import typing - -if sys.version_info >= (3, 10): - import typing as typing_extensions -else: - import typing_extensions - -DESCRIPTOR: google.protobuf.descriptor.FileDescriptor - -@typing.final -class Record(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - KEY_FIELD_NUMBER: builtins.int - VALUE_FIELD_NUMBER: builtins.int - TIMERECEIVED_FIELD_NUMBER: builtins.int - key: builtins.bytes - value: builtins.bytes - timeReceived: builtins.str - def __init__( - self, - *, - key: builtins.bytes = ..., - value: builtins.bytes = ..., - timeReceived: builtins.str = ..., - ) -> None: ... - def ClearField(self, field_name: typing.Literal["key", b"key", "timeReceived", b"timeReceived", "value", b"value"]) -> None: ... - -global___Record = Record - -@typing.final -class Message(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - class _MessageType: - ValueType = typing.NewType("ValueType", builtins.int) - V: typing_extensions.TypeAlias = ValueType - - class _MessageTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[Message._MessageType.ValueType], builtins.type): - DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor - PUT_VALUE: Message._MessageType.ValueType # 0 - GET_VALUE: Message._MessageType.ValueType # 1 - ADD_PROVIDER: Message._MessageType.ValueType # 2 - GET_PROVIDERS: Message._MessageType.ValueType # 3 - FIND_NODE: Message._MessageType.ValueType # 4 - PING: Message._MessageType.ValueType # 5 - - class MessageType(_MessageType, metaclass=_MessageTypeEnumTypeWrapper): ... 
- PUT_VALUE: Message.MessageType.ValueType # 0 - GET_VALUE: Message.MessageType.ValueType # 1 - ADD_PROVIDER: Message.MessageType.ValueType # 2 - GET_PROVIDERS: Message.MessageType.ValueType # 3 - FIND_NODE: Message.MessageType.ValueType # 4 - PING: Message.MessageType.ValueType # 5 - - class _ConnectionType: - ValueType = typing.NewType("ValueType", builtins.int) - V: typing_extensions.TypeAlias = ValueType - - class _ConnectionTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[Message._ConnectionType.ValueType], builtins.type): - DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor - NOT_CONNECTED: Message._ConnectionType.ValueType # 0 - CONNECTED: Message._ConnectionType.ValueType # 1 - CAN_CONNECT: Message._ConnectionType.ValueType # 2 - CANNOT_CONNECT: Message._ConnectionType.ValueType # 3 - - class ConnectionType(_ConnectionType, metaclass=_ConnectionTypeEnumTypeWrapper): ... - NOT_CONNECTED: Message.ConnectionType.ValueType # 0 - CONNECTED: Message.ConnectionType.ValueType # 1 - CAN_CONNECT: Message.ConnectionType.ValueType # 2 - CANNOT_CONNECT: Message.ConnectionType.ValueType # 3 - - @typing.final - class Peer(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - ID_FIELD_NUMBER: builtins.int - ADDRS_FIELD_NUMBER: builtins.int - CONNECTION_FIELD_NUMBER: builtins.int - SIGNEDRECORD_FIELD_NUMBER: builtins.int - id: builtins.bytes - connection: global___Message.ConnectionType.ValueType - signedRecord: builtins.bytes - """Envelope(PeerRecord) encoded""" - @property - def addrs(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ... - def __init__( - self, - *, - id: builtins.bytes = ..., - addrs: collections.abc.Iterable[builtins.bytes] | None = ..., - connection: global___Message.ConnectionType.ValueType = ..., - signedRecord: builtins.bytes | None = ..., - ) -> None: ... 
- def HasField(self, field_name: typing.Literal["_signedRecord", b"_signedRecord", "signedRecord", b"signedRecord"]) -> builtins.bool: ... - def ClearField(self, field_name: typing.Literal["_signedRecord", b"_signedRecord", "addrs", b"addrs", "connection", b"connection", "id", b"id", "signedRecord", b"signedRecord"]) -> None: ... - def WhichOneof(self, oneof_group: typing.Literal["_signedRecord", b"_signedRecord"]) -> typing.Literal["signedRecord"] | None: ... - - TYPE_FIELD_NUMBER: builtins.int - CLUSTERLEVELRAW_FIELD_NUMBER: builtins.int - KEY_FIELD_NUMBER: builtins.int - RECORD_FIELD_NUMBER: builtins.int - CLOSERPEERS_FIELD_NUMBER: builtins.int - PROVIDERPEERS_FIELD_NUMBER: builtins.int - SENDERRECORD_FIELD_NUMBER: builtins.int - type: global___Message.MessageType.ValueType - clusterLevelRaw: builtins.int - key: builtins.bytes - senderRecord: builtins.bytes - """Envelope(PeerRecord) encoded""" - @property - def record(self) -> global___Record: ... - @property - def closerPeers(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Message.Peer]: ... - @property - def providerPeers(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Message.Peer]: ... - def __init__( - self, - *, - type: global___Message.MessageType.ValueType = ..., - clusterLevelRaw: builtins.int = ..., - key: builtins.bytes = ..., - record: global___Record | None = ..., - closerPeers: collections.abc.Iterable[global___Message.Peer] | None = ..., - providerPeers: collections.abc.Iterable[global___Message.Peer] | None = ..., - senderRecord: builtins.bytes | None = ..., - ) -> None: ... - def HasField(self, field_name: typing.Literal["_senderRecord", b"_senderRecord", "record", b"record", "senderRecord", b"senderRecord"]) -> builtins.bool: ... 
- def ClearField(self, field_name: typing.Literal["_senderRecord", b"_senderRecord", "closerPeers", b"closerPeers", "clusterLevelRaw", b"clusterLevelRaw", "key", b"key", "providerPeers", b"providerPeers", "record", b"record", "senderRecord", b"senderRecord", "type", b"type"]) -> None: ... - def WhichOneof(self, oneof_group: typing.Literal["_senderRecord", b"_senderRecord"]) -> typing.Literal["senderRecord"] | None: ... - -global___Message = Message +from google.protobuf.internal import containers as _containers +from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from typing import Any, ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union + +DESCRIPTOR: _descriptor.FileDescriptor + +class Record(_message.Message): + __slots__ = ("key", "value", "timeReceived", "author", "signature") + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + TIMERECEIVED_FIELD_NUMBER: _ClassVar[int] + AUTHOR_FIELD_NUMBER: _ClassVar[int] + SIGNATURE_FIELD_NUMBER: _ClassVar[int] + key: bytes + value: bytes + timeReceived: str + author: bytes + signature: bytes + def __init__(self, key: _Optional[bytes] = ..., value: _Optional[bytes] = ..., timeReceived: _Optional[str] = ..., author: _Optional[bytes] = ..., signature: _Optional[bytes] = ...) -> None: ... 
+ +class Message(_message.Message): + __slots__ = ("type", "clusterLevelRaw", "key", "record", "closerPeers", "providerPeers", "senderRecord") + class MessageType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = () + PUT_VALUE: _ClassVar[Message.MessageType] + GET_VALUE: _ClassVar[Message.MessageType] + ADD_PROVIDER: _ClassVar[Message.MessageType] + GET_PROVIDERS: _ClassVar[Message.MessageType] + FIND_NODE: _ClassVar[Message.MessageType] + PING: _ClassVar[Message.MessageType] + PUT_VALUE: Message.MessageType + GET_VALUE: Message.MessageType + ADD_PROVIDER: Message.MessageType + GET_PROVIDERS: Message.MessageType + FIND_NODE: Message.MessageType + PING: Message.MessageType + class ConnectionType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = () + NOT_CONNECTED: _ClassVar[Message.ConnectionType] + CONNECTED: _ClassVar[Message.ConnectionType] + CAN_CONNECT: _ClassVar[Message.ConnectionType] + CANNOT_CONNECT: _ClassVar[Message.ConnectionType] + NOT_CONNECTED: Message.ConnectionType + CONNECTED: Message.ConnectionType + CAN_CONNECT: Message.ConnectionType + CANNOT_CONNECT: Message.ConnectionType + class Peer(_message.Message): + __slots__ = ("id", "addrs", "connection", "signedRecord") + ID_FIELD_NUMBER: _ClassVar[int] + ADDRS_FIELD_NUMBER: _ClassVar[int] + CONNECTION_FIELD_NUMBER: _ClassVar[int] + SIGNEDRECORD_FIELD_NUMBER: _ClassVar[int] + id: bytes + addrs: _containers.RepeatedScalarFieldContainer[bytes] + connection: Message.ConnectionType + signedRecord: bytes + def __init__(self, id: _Optional[bytes] = ..., addrs: _Optional[_Iterable[bytes]] = ..., connection: _Optional[_Union[Message.ConnectionType, str]] = ..., signedRecord: _Optional[bytes] = ...) -> None: ... 
+ TYPE_FIELD_NUMBER: _ClassVar[int] + CLUSTERLEVELRAW_FIELD_NUMBER: _ClassVar[int] + KEY_FIELD_NUMBER: _ClassVar[int] + RECORD_FIELD_NUMBER: _ClassVar[int] + CLOSERPEERS_FIELD_NUMBER: _ClassVar[int] + PROVIDERPEERS_FIELD_NUMBER: _ClassVar[int] + SENDERRECORD_FIELD_NUMBER: _ClassVar[int] + type: Message.MessageType + clusterLevelRaw: int + key: bytes + record: Record + closerPeers: _containers.RepeatedCompositeFieldContainer[Message.Peer] + providerPeers: _containers.RepeatedCompositeFieldContainer[Message.Peer] + senderRecord: bytes + def __init__(self, type: _Optional[_Union[Message.MessageType, str]] = ..., clusterLevelRaw: _Optional[int] = ..., key: _Optional[bytes] = ..., record: _Optional[_Union[Record, _Mapping[str, Any]]] = ..., closerPeers: _Optional[_Iterable[_Union[Message.Peer, _Mapping[str, Any]]]] = ..., providerPeers: _Optional[_Iterable[_Union[Message.Peer, _Mapping[str, Any]]]] = ..., senderRecord: _Optional[bytes] = ...) -> None: ... diff --git a/libp2p/kad_dht/value_store.py b/libp2p/kad_dht/value_store.py index 90cd77ae4..459e7487e 100644 --- a/libp2p/kad_dht/value_store.py +++ b/libp2p/kad_dht/value_store.py @@ -20,6 +20,7 @@ ID, ) from libp2p.peer.peerstore import env_to_send_in_RPC +from libp2p.records.record import make_signed_put_record from .common import ( DEFAULT_TTL, @@ -65,14 +66,17 @@ def put(self, key: bytes, value: bytes, validity: float = 0.0) -> None: None """ - from libp2p.records.record import make_put_record - if validity == 0.0: validity = time.time() + DEFAULT_TTL logger.debug( "Storing value for key %s... 
with validity %s", key.hex(), validity ) - record = make_put_record(key, value) + + # Create a signed record using the host's private key + private_key = self.host.get_private_key() + record = make_signed_put_record(key, value, private_key) + + # Set timeReceived when storing locally record.timeReceived = str(time.time()) self.store[key] = (record, validity) @@ -123,11 +127,20 @@ async def _store_at_peer(self, peer_id: ID, key: bytes, value: bytes) -> bool: envelope_bytes, _ = env_to_send_in_RPC(self.host) message.senderRecord = envelope_bytes - # Set message fields + # Build the outbound record from the locally-stored signed record when + # available (normal put() path), otherwise sign the record now so the + # outbound message always carries signature and author fields. + local_entry = self.store.get(key) + if local_entry is not None: + signed_record, _ = local_entry + message.record.CopyFrom(signed_record) + else: + private_key = self.host.get_private_key() + signed_record = make_signed_put_record(key, value, private_key) + message.record.CopyFrom(signed_record) message.key = key - message.record.key = key - message.record.value = value - message.record.timeReceived = str(time.time()) + # Note: timeReceived will be set by the receiving peer when storing + message.record.ClearField("timeReceived") # Serialize and send the protobuf message with length prefix proto_bytes = message.SerializeToString() @@ -320,6 +333,10 @@ async def _get_from_peer( logger.debug( f"Received value for key {key.hex()} from peer {peer_id}" ) + + # Update timeReceived to current time (when we received it locally) + response.record.timeReceived = str(time.time()) + return response.record if return_record else response.record.value # Handle case where value is not found but peer infos are returned diff --git a/libp2p/peer/envelope.py b/libp2p/peer/envelope.py index 1fcbb1c75..9a7f6466f 100644 --- a/libp2p/peer/envelope.py +++ b/libp2p/peer/envelope.py @@ -1,7 +1,7 @@ from typing import Any, 
cast import multiaddr -from multicodec import Code, get_codec, get_prefix +from multicodec import Code, get_prefix from multicodec.code_table import LIBP2P_PEER_RECORD from libp2p.crypto.ed25519 import Ed25519PublicKey @@ -12,6 +12,7 @@ import libp2p.peer.pb.envelope_pb2 as pb import libp2p.peer.pb.peer_record_pb2 as record_pb from libp2p.peer.peer_record import ( + PEER_RECORD_ENVELOPE_PAYLOAD_TYPE, PeerRecord, peer_record_from_protobuf, unmarshal_record, @@ -19,9 +20,10 @@ from libp2p.utils.varint import encode_uvarint ENVELOPE_DOMAIN = "libp2p-peer-record" -# Multicodec-based codec for peer records +# Multicodec Code object (for internal use / comparison only) PEER_RECORD_CODE: Code = LIBP2P_PEER_RECORD -PEER_RECORD_CODEC: bytes = get_prefix(str(PEER_RECORD_CODE)) +# Wire-format payload type bytes — matches go-libp2p: []byte{0x03, 0x01} +PEER_RECORD_CODEC: bytes = PEER_RECORD_ENVELOPE_PAYLOAD_TYPE class Envelope: @@ -40,7 +42,9 @@ class Envelope: """ public_key: PublicKey - payload_type_code: Code + # payload_type is stored as raw bytes (wire format), matching go-libp2p. + # For PeerRecord envelopes this is bytes([0x03, 0x01]), NOT varint-encoded. 
+ _payload_type: bytes raw_payload: bytes signature: bytes @@ -56,28 +60,42 @@ def __init__( ): self.public_key = public_key - # Normalise payload_type to a Code instance + # Normalise payload_type to raw bytes if isinstance(payload_type, bytes): - try: - codec_name = get_codec(payload_type) - self.payload_type_code = Code.from_string(codec_name) - except Exception as e: - raise ValueError(f"Invalid codec: {e}") + # Already raw bytes — use as-is (this is the go-libp2p wire format) + self._payload_type = payload_type elif isinstance(payload_type, str): - try: - self.payload_type_code = Code.from_string(payload_type) - except Exception as e: - raise ValueError(f"Invalid codec: {e}") + # Treat as codec name, encode to raw prefix bytes + self._payload_type = get_prefix(payload_type) + elif isinstance(payload_type, Code): + if payload_type == PEER_RECORD_CODE: + # Use the go-libp2p compatible raw bytes, not varint + self._payload_type = PEER_RECORD_ENVELOPE_PAYLOAD_TYPE + else: + self._payload_type = get_prefix(str(payload_type)) else: - self.payload_type_code = payload_type + self._payload_type = bytes(payload_type) self.raw_payload = raw_payload self.signature = signature @property def payload_type(self) -> bytes: - """Return the multicodec-prefixed payload type.""" - return get_prefix(str(self.payload_type_code)) + """Return the raw payload type bytes (wire format).""" + return self._payload_type + + @property + def payload_type_code(self) -> Code: + """Return the multicodec Code for this payload type (best-effort).""" + return PEER_RECORD_CODE + + @payload_type_code.setter + def payload_type_code(self, value: Code) -> None: + """Update the raw payload_type bytes from a Code value.""" + if value == PEER_RECORD_CODE: + self._payload_type = PEER_RECORD_ENVELOPE_PAYLOAD_TYPE + else: + self._payload_type = get_prefix(str(value)) def marshal_envelope(self) -> bytes: """ @@ -125,10 +143,9 @@ def record(self) -> PeerRecord: return self._cached_record try: - if 
self.payload_type_code != PEER_RECORD_CODE: + if self._payload_type != PEER_RECORD_ENVELOPE_PAYLOAD_TYPE: raise ValueError( - f"Unsupported payload type in envelope: " - f"{self.payload_type_code.name}" + f"Unsupported payload type in envelope: {self._payload_type.hex()}" ) msg = record_pb.PeerRecord() msg.ParseFromString(self.raw_payload) @@ -154,7 +171,7 @@ def equal(self, other: Any) -> bool: if isinstance(other, Envelope): return ( self.public_key.__eq__(other.public_key) - and self.payload_type_code == other.payload_type_code + and self._payload_type == other._payload_type and self.signature == other.signature and self.raw_payload == other.raw_payload ) @@ -217,7 +234,7 @@ def seal_record(record: PeerRecord, private_key: PrivateKey) -> Envelope: return Envelope( public_key=private_key.get_public_key(), - payload_type=PEER_RECORD_CODE, + payload_type=PEER_RECORD_ENVELOPE_PAYLOAD_TYPE, raw_payload=payload, signature=signature, ) diff --git a/libp2p/peer/peer_record.py b/libp2p/peer/peer_record.py index 0fff196f0..26676f983 100644 --- a/libp2p/peer/peer_record.py +++ b/libp2p/peer/peer_record.py @@ -4,7 +4,7 @@ from typing import Any from multiaddr import Multiaddr -from multicodec import Code, get_prefix +from multicodec import Code from multicodec.code_table import LIBP2P_PEER_RECORD from libp2p.abc import IPeerRecord @@ -14,7 +14,10 @@ PEER_RECORD_ENVELOPE_DOMAIN = "libp2p-peer-record" PEER_RECORD_ENVELOPE_CODE: Code = LIBP2P_PEER_RECORD -PEER_RECORD_ENVELOPE_PAYLOAD_TYPE = get_prefix(str(PEER_RECORD_ENVELOPE_CODE)) +# go-libp2p uses raw bytes [0x03, 0x01] for the peer-record payload type +# (NOT varint-encoded). 
See: https://github.com/libp2p/go-libp2p/blob/master/core/peer/record.go +# PeerRecordEnvelopePayloadType = []byte{0x03, 0x01} +PEER_RECORD_ENVELOPE_PAYLOAD_TYPE = bytes([0x03, 0x01]) _last_timestamp_lock = threading.Lock() _last_timestamp: int = 0 diff --git a/libp2p/records/record.py b/libp2p/records/record.py index 8644e3c09..87dd96b1c 100644 --- a/libp2p/records/record.py +++ b/libp2p/records/record.py @@ -1,4 +1,6 @@ +from libp2p.crypto.keys import PrivateKey from libp2p.kad_dht.pb import kademlia_pb2 as record_pb2 +from libp2p.records.utils import sign_record def make_put_record(key: bytes, value: bytes) -> record_pb2.Record: @@ -17,3 +19,35 @@ def make_put_record(key: bytes, value: bytes) -> record_pb2.Record: record.key = key record.value = value return record + + +def make_signed_put_record( + key: bytes, value: bytes, private_key: PrivateKey +) -> record_pb2.Record: + """ + Create a signed Record object with the specified key, value, and signature. + + The record is signed using the libp2p record signing convention: + signature = sign("libp2p-record:" + key + value) + + This matches go-libp2p's record signing behavior for DHT PUT_VALUE. + + Args: + key (bytes): The key for the record. + value (bytes): The value to associate with the key in the record. + private_key (PrivateKey): The private key to sign the record with. + + Returns: + record_pb2.Record: A signed Record object. 
+ + """ + record = record_pb2.Record() + record.key = key + record.value = value + + # Sign the record + signature, author_public_key = sign_record(private_key, key, value) + record.signature = signature + record.author = author_public_key + + return record diff --git a/libp2p/records/utils.py b/libp2p/records/utils.py index 82161beb3..4657510a6 100644 --- a/libp2p/records/utils.py +++ b/libp2p/records/utils.py @@ -1,7 +1,93 @@ +from libp2p.crypto.ed25519 import Ed25519PublicKey +from libp2p.crypto.keys import PrivateKey, PublicKey +from libp2p.crypto.pb import crypto_pb2 +from libp2p.crypto.rsa import RSAPublicKey +from libp2p.crypto.secp256k1 import Secp256k1PublicKey + + class InvalidRecordType(Exception): pass +def _unmarshal_public_key(data: bytes) -> PublicKey: + """ + Deserialize a ``crypto_pb2.PublicKey`` protobuf into a concrete + ``PublicKey`` instance. + + Kept private to this module to avoid the circular import that arises + when importing from ``libp2p.records.pubkey`` (which itself imports + from this module). + """ + proto_key = crypto_pb2.PublicKey.FromString(data) + key_type = proto_key.key_type + key_data = proto_key.data + + if key_type == crypto_pb2.KeyType.RSA: + return RSAPublicKey.from_bytes(key_data) + elif key_type == crypto_pb2.KeyType.Ed25519: + return Ed25519PublicKey.from_bytes(key_data) + elif key_type == crypto_pb2.KeyType.Secp256k1: + return Secp256k1PublicKey.from_bytes(key_data) + else: + raise ValueError(f"Unsupported key type: {key_type}") + + +def sign_record( + private_key: PrivateKey, key: bytes, value: bytes +) -> tuple[bytes, bytes]: + """ + Sign a DHT record using the given private key. + + The signature is computed over "libp2p-record:" + key + value. 
+ + Args: + private_key: The private key to sign with + key: The record key + value: The record value + + Returns: + tuple[bytes, bytes]: A tuple of (signature, author_public_key_bytes) + + """ + signing_payload = b"libp2p-record:" + key + value + signature = private_key.sign(signing_payload) + public_key = private_key.get_public_key() + # Serialize as a protobuf-wrapped PublicKey so that verify_record (and + # remote peers) can reconstruct the key without knowing its type in advance. + author_bytes = public_key.serialize() + return signature, author_bytes + + +def verify_record( + signature: bytes, author_public_key: bytes, key: bytes, value: bytes +) -> bool: + """ + Verify a signed DHT record. + + Supports all key types that libp2p serialises in a protobuf PublicKey + envelope (Ed25519, RSA, Secp256k1). The author field is treated as a + serialised ``crypto_pb2.PublicKey`` message and dispatched through + ``unmarshal_public_key`` so that non-Ed25519 peers are not silently + rejected. + + Args: + signature: The record signature + author_public_key: The serialized public key of the author + key: The record key + value: The record value + + Returns: + bool: True if the signature is valid, False otherwise + + """ + try: + public_key = _unmarshal_public_key(author_public_key) + signing_payload = b"libp2p-record:" + key + value + return public_key.verify(signing_payload, signature) + except Exception: + return False + + def split_key(key: str) -> tuple[str, str]: """ Split a record key into its type and the rest. 
The key must start with diff --git a/libp2p/stream_muxer/yamux/yamux.py b/libp2p/stream_muxer/yamux/yamux.py index cc2a85f8c..73d3b6aec 100644 --- a/libp2p/stream_muxer/yamux/yamux.py +++ b/libp2p/stream_muxer/yamux/yamux.py @@ -1158,16 +1158,20 @@ async def handle_incoming(self) -> None: stream.closed = True stream.reset_received = True self.stream_events[stream_id].set() - - ack_header = struct.pack( - YAMUX_HEADER_FORMAT, - 0, - TYPE_WINDOW_UPDATE, - FLAG_ACK, - stream_id, - 0, - ) - new_stream_notify = stream + # Deliver the reset stream to accept_stream() so + # callers can observe the reset state, but do NOT + # send an ACK back — the stream is already dead. + new_stream_notify = stream + else: + ack_header = struct.pack( + YAMUX_HEADER_FORMAT, + 0, + TYPE_WINDOW_UPDATE, + FLAG_ACK, + stream_id, + 0, + ) + new_stream_notify = stream else: rst_header = struct.pack( YAMUX_HEADER_FORMAT, @@ -1188,6 +1192,14 @@ async def handle_incoming(self) -> None: ) if new_stream_notify is not None: await self.new_stream_send_channel.send(new_stream_notify) + elif new_stream_notify is not None: + # SYN+RST: stream is reset on arrival — deliver to + # accept_stream() without sending an ACK back. 
+ logger.debug( + f"Delivering reset stream {stream_id} " + f"to channel (no ACK) for peer {self.peer_id}" + ) + await self.new_stream_send_channel.send(new_stream_notify) elif ( typ == TYPE_DATA or typ == TYPE_WINDOW_UPDATE ) and flags & FLAG_ACK: diff --git a/libp2p/transport/quic/transport.py b/libp2p/transport/quic/transport.py index 0572fcfb9..8c6167f3b 100644 --- a/libp2p/transport/quic/transport.py +++ b/libp2p/transport/quic/transport.py @@ -268,7 +268,7 @@ async def dial( # Get appropriate QUIC client configuration config_key = TProtocol(f"{quic_version}_client") - logger.debug("config_key", config_key, self._quic_configs.keys()) + logger.debug("config_key %s %s", config_key, self._quic_configs.keys()) config = self._quic_configs.get(config_key) if not config: raise QUICDialError(f"Unsupported QUIC version: {quic_version}") @@ -286,7 +286,7 @@ async def dial( # Debug log to verify certificate is present logger.info( - f"Dialing QUIC connection to {host}:{port} (version: {{quic_version}})" + f"Dialing QUIC connection to {host}:{port} (version: {quic_version})" ) logger.debug("Starting QUIC Connection") diff --git a/newsfragments/1347.feature.rst b/newsfragments/1347.feature.rst new file mode 100644 index 000000000..f19cc7629 --- /dev/null +++ b/newsfragments/1347.feature.rst @@ -0,0 +1 @@ +Implement comprehensive Bitswap interoperability with IPFS Kubo, including UnixFS DAG-PB encoding and balanced layout support. Introduces ``FilesystemBlockStore`` and ``BlockService`` for robust block caching, Bitswap batch fetching, and streaming inputs (``chunk_stream``). diff --git a/tests/core/bitswap/test_block_service.py b/tests/core/bitswap/test_block_service.py new file mode 100644 index 000000000..f4754dd7c --- /dev/null +++ b/tests/core/bitswap/test_block_service.py @@ -0,0 +1,236 @@ +""" +Test BlockService — transparent local→network fallback with auto-caching. 
+ +Run with: + python test_block_service.py +""" + +from unittest.mock import AsyncMock, MagicMock + +import trio + +from libp2p.bitswap.block_service import BlockService +from libp2p.bitswap.block_store import MemoryBlockStore +from libp2p.bitswap.cid import CODEC_RAW, compute_cid_v1 +from libp2p.bitswap.client import BitswapClient + + +def make_block(content: bytes): + cid = compute_cid_v1(content, codec=CODEC_RAW) + return cid, content + + +def ok(label): + print(f" OK {label}") + + +# ── helpers ─────────────────────────────────────────────────────────────────── + + +def make_service(network_blocks: dict | None = None): + """ + Build a BlockService with a real MemoryBlockStore and a mock BitswapClient. + network_blocks: cid_bytes -> data that the mock 'network' can return. + """ + store = MemoryBlockStore() + mock_bitswap = MagicMock(spec=BitswapClient) + mock_bitswap.block_store = store + network_blocks = network_blocks or {} + + async def fake_get_block(cid, peer_id=None, timeout=30.0): + return network_blocks.get(bytes(cid)) + + async def fake_add_block(cid, data): + pass # just accept it + + async def fake_get_blocks_batch(cids, peer_id=None, timeout=30.0, batch_size=32): + return { + bytes(c): network_blocks[bytes(c)] + for c in cids + if bytes(c) in network_blocks + } + + mock_bitswap.get_block = AsyncMock(side_effect=fake_get_block) + mock_bitswap.add_block = AsyncMock(side_effect=fake_add_block) + mock_bitswap.get_blocks_batch = AsyncMock(side_effect=fake_get_blocks_batch) + + service = BlockService(store, mock_bitswap) + return service, store, mock_bitswap + + +# ── tests ───────────────────────────────────────────────────────────────────── + + +async def test_local_hit_no_network(): + print("\n[1] Local hit — network is never called") + cid, data = make_block(b"already stored locally") + service, store, mock_bitswap = make_service() + + # Pre-populate local store + await store.put_block(cid, data) + + result = await service.get_block(cid) + assert 
result == data + ok("get_block returns local data") + + mock_bitswap.get_block.assert_not_called() + ok("network (bitswap.get_block) was NOT called") + + +async def test_local_miss_goes_to_network(): + print("\n[2] Local miss — fetches from network") + cid, data = make_block(b"only on the network") + service, store, mock_bitswap = make_service(network_blocks={bytes(cid): data}) + + result = await service.get_block(cid) + assert result == data + ok("get_block returns network data") + + mock_bitswap.get_block.assert_called_once() + ok("network (bitswap.get_block) was called exactly once") + + +async def test_auto_cache_after_network_fetch(): + print("\n[3] Auto-cache — network-fetched block stored locally") + cid, data = make_block(b"fetch and cache me") + service, store, mock_bitswap = make_service(network_blocks={bytes(cid): data}) + + # First call: local miss → network fetch → auto-cache + result1 = await service.get_block(cid) + assert result1 == data + + # Verify it's now in the local store + cached = await store.get_block(cid) + assert cached == data + ok("block is in local store after first network fetch") + + # Second call: must be a local hit, no second network call + result2 = await service.get_block(cid) + assert result2 == data + assert mock_bitswap.get_block.call_count == 1 # still only 1 network call + ok("second get_block is a local hit (network called only once total)") + + +async def test_put_block_stores_and_announces(): + print("\n[4] put_block — stores locally AND calls bitswap.add_block") + cid, data = make_block(b"new block to store") + service, store, mock_bitswap = make_service() + + await service.put_block(cid, data) + + # Must be in local store + cached = await store.get_block(cid) + assert cached == data + ok("block is in local store after put_block") + + # Must have called bitswap.add_block (announces to waiting peers) + mock_bitswap.add_block.assert_called_once() + ok("bitswap.add_block was called (peers notified)") + + +async def 
test_get_blocks_batch_local_hits_skip_network(): + print("\n[5] get_blocks_batch — local hits skip network") + blocks = [make_block(f"block {i}".encode()) for i in range(5)] + service, store, mock_bitswap = make_service() + + # Store all 5 locally + for cid, data in blocks: + await store.put_block(cid, data) + + cids: list[bytes] = [cid for cid, _ in blocks] + results = await service.get_blocks_batch(cids) + + assert len(results) == 5 + ok("all 5 blocks returned from local store") + mock_bitswap.get_blocks_batch.assert_not_called() + ok("network batch fetch was NOT called") + + +async def test_get_blocks_batch_partial_local(): + print("\n[6] get_blocks_batch — partial local, rest from network") + local_blocks = [make_block(f"local {i}".encode()) for i in range(3)] + net_blocks = [make_block(f"remote {i}".encode()) for i in range(2)] + network_dict = {bytes(cid): data for cid, data in net_blocks} + + service, store, mock_bitswap = make_service(network_blocks=network_dict) + + # Store only local blocks + for cid, data in local_blocks: + await store.put_block(cid, data) + + all_cids: list[bytes] = [cid for cid, _ in local_blocks + net_blocks] + results = await service.get_blocks_batch(all_cids) + + assert len(results) == 5 + ok("all 5 blocks returned (3 local + 2 network)") + mock_bitswap.get_blocks_batch.assert_called_once() + ok("network batch fetch called exactly once (only for 2 missing blocks)") + + # Network blocks must now be cached locally + for cid, data in net_blocks: + cached = await store.get_block(cid) + assert cached == data + ok("network-fetched blocks are now cached locally") + + +async def test_missing_block_returns_none(): + print("\n[7] get_block returns None when block not found anywhere") + cid, _ = make_block(b"this block does not exist") + service, store, mock_bitswap = make_service(network_blocks={}) # empty network + + result = await service.get_block(cid) + assert result is None + ok("get_block returns None for unknown block") + + +async def 
test_merkledag_uses_block_service(): + print("\n[8] MerkleDag.add_bytes routes through BlockService") + from libp2p.bitswap.dag import MerkleDag + + service, store, mock_bitswap = make_service() + dag = MerkleDag(mock_bitswap, block_service=service) + + data = b"hello block service" * 100 + root_cid = await dag.add_bytes(data) + + # All blocks must be in the local store via BlockService + cached = await store.get_block(root_cid) + assert cached is not None + ok("root block is in local store via BlockService") + + # bitswap.add_block was called (for peer announcement) + assert mock_bitswap.add_block.called + ok("bitswap.add_block was called for peer announcement") + + # MerkleDag without BlockService still works (no regression) + service2, store2, mock_bitswap2 = make_service() + dag2 = MerkleDag(mock_bitswap2) # no block_service + root_cid2 = await dag2.add_bytes(data) + assert root_cid2 is not None + ok("MerkleDag without BlockService still works (no regression)") + + +# ── main ────────────────────────────────────────────────────────────────────── + + +async def main(): + print("=" * 60) + print("BlockService — Test Suite") + print("=" * 60) + + await test_local_hit_no_network() + await test_local_miss_goes_to_network() + await test_auto_cache_after_network_fetch() + await test_put_block_stores_and_announces() + await test_get_blocks_batch_local_hits_skip_network() + await test_get_blocks_batch_partial_local() + await test_missing_block_returns_none() + await test_merkledag_uses_block_service() + + print("\n" + "=" * 60) + print("All tests passed!") + print("=" * 60) + + +if __name__ == "__main__": + trio.run(main) diff --git a/tests/core/bitswap/test_cid.py b/tests/core/bitswap/test_cid.py index 741c2d769..54d121a5a 100644 --- a/tests/core/bitswap/test_cid.py +++ b/tests/core/bitswap/test_cid.py @@ -352,7 +352,9 @@ def test_cid_to_bytes_and_text_roundtrip(): roundtrip_bytes = cid_to_bytes(cid_text) assert roundtrip_bytes == cid_bytes - assert cid_text == 
str(make_cid(cid_bytes)) + # String representations might differ by base (e.g. base32 vs base58btc) + # but they should parse to the same bytes. + assert cid_to_bytes(cid_text) == cid_to_bytes(str(make_cid(cid_bytes))) def test_object_wrappers_for_v0_and_v1(): diff --git a/tests/core/bitswap/test_dag.py b/tests/core/bitswap/test_dag.py index e94fb0f1a..c883cf3f2 100644 --- a/tests/core/bitswap/test_dag.py +++ b/tests/core/bitswap/test_dag.py @@ -66,13 +66,17 @@ async def test_add_small_bytes(self): # Verify assert root_cid is not None assert len(root_cid) > 0 + + # Small data is stored as a raw leaf node (RawLeaves=True default) + expected_cid = compute_cid_v1(data, codec=CODEC_RAW) + assert root_cid == expected_cid assert verify_cid(root_cid, data) - # Should be single block (RAW codec) + # Should be single block (raw codec) mock_client.add_block.assert_called_once() call_args = mock_client.add_block.call_args assert call_args[0][0] == root_cid # CID - assert call_args[0][1] == data # Data + assert call_args[0][1] == data # raw data @pytest.mark.trio async def test_add_large_bytes(self): @@ -161,9 +165,12 @@ async def test_add_small_file(self): assert root_cid is not None mock_client.add_block.assert_called_once() - # Should be single RAW block + # Small file is stored as a raw leaf node call_args = mock_client.add_block.call_args - assert verify_cid(call_args[0][0], data) + stored_cid = call_args[0][0] + stored_block = call_args[0][1] + assert stored_block == data + assert verify_cid(stored_cid, data) finally: Path(temp_path).unlink() @@ -243,9 +250,8 @@ async def test_add_file_with_custom_chunk_size(self): temp_path, chunk_size=chunk_size, wrap_with_directory=False ) - # Should have many chunks - # (3.2MB / 16KB = 200 chunks) + 1 root = 201 calls - assert mock_client.add_block.call_count == 201 + # (3.2MB / 16KB = 200 chunks) + intermediate nodes + 1 root + assert mock_client.add_block.call_count > 200 finally: Path(temp_path).unlink() @@ -285,16 +291,22 @@ 
async def test_fetch_small_file(self, cid_input_kind: str): @pytest.mark.trio async def test_fetch_chunked_file(self): """Test fetching multi-chunk file.""" - # Create chunks + from libp2p.bitswap.dag_pb import create_leaf_node + + # Create dag-pb leaf blocks (matching what add_bytes/add_file produces) chunk1 = b"chunk1" * 1000 chunk2 = b"chunk2" * 1000 chunk3 = b"chunk3" * 1000 - cid1 = compute_cid_v1(chunk1, codec=CODEC_RAW) - cid2 = compute_cid_v1(chunk2, codec=CODEC_RAW) - cid3 = compute_cid_v1(chunk3, codec=CODEC_RAW) + leaf1 = create_leaf_node(chunk1) + leaf2 = create_leaf_node(chunk2) + leaf3 = create_leaf_node(chunk3) + + cid1 = compute_cid_v1(leaf1, codec=CODEC_DAG_PB) + cid2 = compute_cid_v1(leaf2, codec=CODEC_DAG_PB) + cid3 = compute_cid_v1(leaf3, codec=CODEC_DAG_PB) - # Create DAG-PB root node + # Create DAG-PB root node linking to the leaves chunks_data = [ (cid1, len(chunk1)), (cid2, len(chunk2)), @@ -308,11 +320,11 @@ def get_block_side_effect(cid, peer_id, timeout): if cid == root_cid: return root_data elif cid == cid1: - return chunk1 + return leaf1 elif cid == cid2: - return chunk2 + return leaf2 elif cid == cid3: - return chunk3 + return leaf3 raise ValueError(f"Unknown CID: {cid.hex()}") mock_client = MagicMock(spec=BitswapClient) @@ -324,23 +336,30 @@ def get_block_side_effect(cid, peer_id, timeout): # Fetch fetched_data, filename = await dag.fetch_file(root_cid, timeout=30.0) - # Verify + # Verify reconstructed data expected_data = chunk1 + chunk2 + chunk3 assert fetched_data == expected_data assert filename is None # File node without directory wrapper - # Should have fetched root + 3 chunks + # root fetch (1) + tree-level batch fallback (3) = 4 + # Leaves are already fetched during tree traversal, + # no separate leaf fetch needed assert mock_client.get_block.call_count == 4 @pytest.mark.trio async def test_fetch_file_with_progress(self): """Test fetching with progress callback.""" - # Create chunked file + from libp2p.bitswap.dag_pb import 
create_leaf_node + + # Create dag-pb leaf blocks (matching what add_bytes/add_file produces) chunk1 = b"x" * 1000 chunk2 = b"y" * 1000 - cid1 = compute_cid_v1(chunk1, codec=CODEC_RAW) - cid2 = compute_cid_v1(chunk2, codec=CODEC_RAW) + leaf1 = create_leaf_node(chunk1) + leaf2 = create_leaf_node(chunk2) + + cid1 = compute_cid_v1(leaf1, codec=CODEC_DAG_PB) + cid2 = compute_cid_v1(leaf2, codec=CODEC_DAG_PB) root_data = create_file_node([(cid1, len(chunk1)), (cid2, len(chunk2))]) root_cid = compute_cid_v1(root_data, codec=CODEC_DAG_PB) @@ -350,9 +369,9 @@ def get_block_side_effect(cid, peer_id, timeout): if cid == root_cid: return root_data elif cid == cid1: - return chunk1 + return leaf1 elif cid == cid2: - return chunk2 + return leaf2 mock_client = MagicMock(spec=BitswapClient) mock_client.block_store = MemoryBlockStore() @@ -370,8 +389,8 @@ def progress_callback(current, total, status): # Verify progress assert len(progress_calls) > 0 - # Should report progress for each chunk - assert any("fetching chunk" in call[2] for call in progress_calls) + # Implementation emits "downloading" per leaf and "completed" at end + assert any(call[2] in ("downloading", "completed") for call in progress_calls) # Last call should be completion assert progress_calls[-1][2] == "completed" diff --git a/tests/core/bitswap/test_filesystem_blockstore.py b/tests/core/bitswap/test_filesystem_blockstore.py new file mode 100644 index 000000000..edf691170 --- /dev/null +++ b/tests/core/bitswap/test_filesystem_blockstore.py @@ -0,0 +1,209 @@ +""" +Manual test for FilesystemBlockStore. + +Tests: + 1. Basic put/get/has/delete round-trip + 2. Persistence: blocks survive store re-creation (simulates process restart) + 3. get_all_cids: scans the directory tree and returns all stored CIDs + 4. 
Drop-in replacement: swapping MemoryBlockStore → FilesystemBlockStore + +Run with: + python test_filesystem_blockstore.py + or + pytest test_filesystem_blockstore.py +""" + +from pathlib import Path +import shutil +import tempfile + +import pytest +import trio + +from libp2p.bitswap.block_store import FilesystemBlockStore, MemoryBlockStore +from libp2p.bitswap.cid import CODEC_RAW, cid_to_text, compute_cid_v1 + +# ── helpers ────────────────────────────────────────────────────────────────── + + +def make_block(content: bytes) -> tuple[bytes, bytes]: + """Return (cid_bytes, data) for a raw block.""" + cid = compute_cid_v1(content, codec=CODEC_RAW) + return cid, content + + +def pass_fail(label: str, ok: bool) -> None: + icon = "✅" if ok else "❌" + print(f" {icon} {label}") + if not ok: + raise AssertionError(f"FAILED: {label}") + + +# ── pytest fixtures ─────────────────────────────────────────────────────────── + + +@pytest.fixture +def store_path(tmp_path): + """Provide a fresh temporary directory path for each test.""" + return str(tmp_path) + + +# ── tests ───────────────────────────────────────────────────────────────────── + + +@pytest.mark.trio +async def test_basic_round_trip(store_path: str) -> None: + print("\n[1] Basic put / get / has / delete") + store = FilesystemBlockStore(store_path) + + cid, data = make_block(b"hello filesystem blockstore") + + # has_block → False before put + pass_fail("has_block returns False before put", not await store.has_block(cid)) + + # put_block + await store.put_block(cid, data) + pass_fail("block file exists on disk after put", store._cid_to_path(cid).exists()) + + # get_block + fetched = await store.get_block(cid) + pass_fail("get_block returns correct data", fetched == data) + + # has_block → True after put + pass_fail("has_block returns True after put", await store.has_block(cid)) + + # delete_block + await store.delete_block(cid) + pass_fail("block file gone after delete", not store._cid_to_path(cid).exists()) + 
pass_fail("get_block returns None after delete", await store.get_block(cid) is None) + + +@pytest.mark.trio +async def test_persistence(store_path: str) -> None: + print("\n[2] Persistence across store re-creation (simulates process restart)") + + # Write with first instance + store1 = FilesystemBlockStore(store_path) + cid1, data1 = make_block(b"block that should survive restart") + cid2, data2 = make_block(b"another persistent block") + await store1.put_block(cid1, data1) + await store1.put_block(cid2, data2) + pass_fail("2 blocks written by store1", store1.size() == 2) + + # Create a brand-new store object pointing to the same path + # (simulates a process restart) + store2 = FilesystemBlockStore(store_path) + pass_fail( + "store2 sees block1 written by store1", await store2.get_block(cid1) == data1 + ) + pass_fail( + "store2 sees block2 written by store1", await store2.get_block(cid2) == data2 + ) + pass_fail("store2.size() == 2", store2.size() == 2) + + print(f" Block directory: {store2.base_path()}") + print(f" CID1: {cid_to_text(cid1)}") + print(f" CID2: {cid_to_text(cid2)}") + + +@pytest.mark.trio +async def test_get_all_cids(store_path: str) -> None: + print("\n[3] get_all_cids scans directory tree") + store = FilesystemBlockStore(store_path) + + blocks = [make_block(f"block {i}".encode()) for i in range(5)] + for cid, data in blocks: + await store.put_block(cid, data) + + all_cids = store.get_all_cids() + pass_fail(f"get_all_cids returns {len(blocks)} CIDs", len(all_cids) == len(blocks)) + + stored_set = {bytes(c) for c in all_cids} + for cid, _ in blocks: + pass_fail( + f"CID {cid_to_text(cid)[:20]}... 
is in get_all_cids", + bytes(cid) in stored_set, + ) + + +@pytest.mark.trio +async def test_get_missing_returns_none(store_path: str) -> None: + print("\n[4] get_block returns None for missing CID") + store = FilesystemBlockStore(store_path) + cid, _ = make_block(b"this block was never stored") + result = await store.get_block(cid) + pass_fail("get_block returns None for unknown CID", result is None) + + +@pytest.mark.trio +async def test_drop_in_for_memory_store(store_path: str) -> None: + print("\n[5] Drop-in replacement for MemoryBlockStore") + + async def use_store(store) -> bytes: + """Same code works for both store types.""" + cid, data = make_block(b"drop-in replacement test") + await store.put_block(cid, data) + return await store.get_block(cid) + + mem_result = await use_store(MemoryBlockStore()) + fs_result = await use_store(FilesystemBlockStore(store_path)) + + pass_fail( + "MemoryBlockStore and FilesystemBlockStore return same data", + mem_result == fs_result, + ) + + +@pytest.mark.trio +async def test_directory_structure(store_path: str) -> None: + print("\n[6] 2-char prefix directory structure") + store = FilesystemBlockStore(store_path) + cid, data = make_block(b"check directory layout") + await store.put_block(cid, data) + + from cid import make_cid + + cid_str = str(make_cid(cid)) + expected_dir = Path(store_path) / cid_str[:2] + expected_file = expected_dir / cid_str[2:] + + pass_fail(f"2-char prefix dir '{cid_str[:2]}' exists", expected_dir.is_dir()) + pass_fail( + f"block file '{cid_str[2:8]}...' 
exists inside prefix dir", + expected_file.exists(), + ) + pass_fail("file contents match original data", expected_file.read_bytes() == data) + + print(f" Path: {expected_file}") + + +# ── main ────────────────────────────────────────────────────────────────────── + + +async def main() -> None: + print("=" * 60) + print("FilesystemBlockStore — Manual Test Suite") + print("=" * 60) + + # Each test gets its own temp directory so they don't interfere + dirs = [tempfile.mkdtemp(prefix="fs_blockstore_test_") for _ in range(6)] + + try: + await test_basic_round_trip(dirs[0]) + await test_persistence(dirs[1]) + await test_get_all_cids(dirs[2]) + await test_get_missing_returns_none(dirs[3]) + await test_drop_in_for_memory_store(dirs[4]) + await test_directory_structure(dirs[5]) + + print("\n" + "=" * 60) + print("✅ All tests passed!") + print("=" * 60) + + finally: + for d in dirs: + shutil.rmtree(d, ignore_errors=True) + + +if __name__ == "__main__": + trio.run(main) diff --git a/tests/core/bitswap/test_io_stream.py b/tests/core/bitswap/test_io_stream.py new file mode 100644 index 000000000..a8949036d --- /dev/null +++ b/tests/core/bitswap/test_io_stream.py @@ -0,0 +1,283 @@ +""" +Test io.IOBase input support — chunk_stream() and MerkleDag.add_stream(). + +Run with: + python test_io_stream.py +""" + +import gzip +import io +import os +import tempfile + +import trio + +from libp2p.bitswap.block_store import MemoryBlockStore +from libp2p.bitswap.chunker import DEFAULT_CHUNK_SIZE, chunk_stream +from libp2p.bitswap.cid import cid_to_text + + +def ok(label): + print(f" OK {label}") + + +# ── 1. 
chunk_stream basics ──────────────────────────────────────────────────── + + +def test_chunk_stream_bytesio(): + print("\n[1] chunk_stream — BytesIO") + data = b"x" * (DEFAULT_CHUNK_SIZE * 3 + 100) # 3 full + 1 partial chunk + chunks = list(chunk_stream(io.BytesIO(data), DEFAULT_CHUNK_SIZE)) + assert len(chunks) == 4 + assert b"".join(chunks) == data + assert len(chunks[0]) == DEFAULT_CHUNK_SIZE + assert len(chunks[-1]) == 100 + ok(f"4 chunks, sizes: {[len(c) for c in chunks]}") + + +def test_chunk_stream_empty(): + print("\n[2] chunk_stream — empty stream yields nothing") + chunks = list(chunk_stream(io.BytesIO(b""))) + assert chunks == [] + ok("empty stream yields no chunks") + + +def test_chunk_stream_file_handle(): + print("\n[3] chunk_stream — real file handle") + data = b"file handle test " * 5000 + with tempfile.NamedTemporaryFile(delete=False) as f: + f.write(data) + tmp = f.name + try: + with open(tmp, "rb") as fh: + chunks = list(chunk_stream(fh)) + assert b"".join(chunks) == data + ok(f"file handle: {len(chunks)} chunks, {len(data)} bytes total") + finally: + os.unlink(tmp) + + +def test_chunk_stream_gzip(): + print("\n[4] chunk_stream — gzip stream (decompress on-the-fly)") + original = b"compressed data " * 10000 + buf = io.BytesIO() + with gzip.GzipFile(fileobj=buf, mode="wb") as gz: + gz.write(original) + buf.seek(0) + + with gzip.GzipFile(fileobj=buf, mode="rb") as gz: + chunks = list(chunk_stream(gz)) + + assert b"".join(chunks) == original + ok(f"gzip stream: {len(chunks)} chunks, {len(original)} bytes decompressed") + + +def test_chunk_stream_matches_chunk_bytes(): + print("\n[5] chunk_stream produces same chunks as chunk_bytes") + from libp2p.bitswap.chunker import chunk_bytes + + data = os.urandom(DEFAULT_CHUNK_SIZE * 5 + 777) + stream_chunks = list(chunk_stream(io.BytesIO(data))) + bytes_chunks = chunk_bytes(data) + assert stream_chunks == bytes_chunks + ok(f"chunk_stream == chunk_bytes for {len(data)} bytes of random data") + + +# ── 2. 
MerkleDag.add_stream ─────────────────────────────────────────────────── + + +async def test_add_stream_bytesio(): + print("\n[6] add_stream — BytesIO produces same CID as add_bytes") + from unittest.mock import AsyncMock, MagicMock + + from libp2p.bitswap.client import BitswapClient + from libp2p.bitswap.dag import MerkleDag + + store = MemoryBlockStore() + mock = MagicMock(spec=BitswapClient) + mock.block_store = store + stored: dict[bytes, bytes] = {} + + async def add_block(cid, data): + stored[bytes(cid)] = data + + mock.add_block = AsyncMock(side_effect=add_block) + + dag = MerkleDag(mock) + data = b"same content " * 5000 + + cid_bytes = await dag.add_bytes(data) + stored.clear() + cid_stream = await dag.add_stream(io.BytesIO(data)) + + assert bytes(cid_bytes) == bytes(cid_stream), ( + f"CIDs differ:\n add_bytes: {cid_to_text(cid_bytes)}\n" + f" add_stream: {cid_to_text(cid_stream)}" + ) + ok(f"add_stream CID == add_bytes CID: {cid_to_text(cid_stream)[:30]}...") + + +async def test_add_stream_empty(): + print("\n[7] add_stream — empty stream stores single empty leaf") + from unittest.mock import AsyncMock, MagicMock + + from libp2p.bitswap.client import BitswapClient + from libp2p.bitswap.dag import MerkleDag + + store = MemoryBlockStore() + mock = MagicMock(spec=BitswapClient) + mock.block_store = store + stored: dict[bytes, bytes] = {} + + async def add_block(cid, data): + stored[bytes(cid)] = data + + mock.add_block = AsyncMock(side_effect=add_block) + + dag = MerkleDag(mock) + await dag.add_stream(io.BytesIO(b"")) + + assert len(stored) == 1 + block = list(stored.values())[0] + assert block == b"" + ok("empty stream → 1 empty raw leaf block stored") + + +async def test_add_stream_single_chunk(): + print("\n[8] add_stream — single chunk returns leaf CID directly (no root node)") + from unittest.mock import AsyncMock, MagicMock + + from libp2p.bitswap.client import BitswapClient + from libp2p.bitswap.dag import MerkleDag + + store = MemoryBlockStore() + 
mock = MagicMock(spec=BitswapClient) + mock.block_store = store + stored: dict[bytes, bytes] = {} + + async def add_block(cid, data): + stored[bytes(cid)] = data + + mock.add_block = AsyncMock(side_effect=add_block) + + dag = MerkleDag(mock) + data = b"small enough to be one chunk" + root_cid = await dag.add_stream(io.BytesIO(data)) + + assert len(stored) == 1, f"expected 1 block, got {len(stored)}" + block = stored[bytes(root_cid)] + assert block == data + ok("single chunk: leaf CID returned directly, inline data correct") + + +async def test_add_stream_gzip(): + print("\n[9] add_stream — gzip stream decompresses and adds correctly") + from unittest.mock import AsyncMock, MagicMock + + from libp2p.bitswap.client import BitswapClient + from libp2p.bitswap.dag import MerkleDag + + original = b"gzip content " * 20000 # ~260 KB — 2 chunks after decompress + + buf = io.BytesIO() + with gzip.GzipFile(fileobj=buf, mode="wb") as gz: + gz.write(original) + compressed_size = buf.tell() + buf.seek(0) + + store = MemoryBlockStore() + mock = MagicMock(spec=BitswapClient) + mock.block_store = store + stored: dict[bytes, bytes] = {} + + async def add_block(cid, data): + stored[bytes(cid)] = data + + mock.add_block = AsyncMock(side_effect=add_block) + + dag = MerkleDag(mock) + + with gzip.GzipFile(fileobj=buf, mode="rb") as gz: + root_cid = await dag.add_stream(gz) + + # Since it's < 256KB, it's a single raw chunk + root_block = stored[bytes(root_cid)] + assert root_block == original + ok( + f"gzip stream: {compressed_size} compressed → {len(original)} bytes added " + f"as a single chunk" + ) + + +async def test_add_stream_vs_add_file_same_cid(): + print("\n[10] add_stream(open(f)) produces same CID as add_file(path)") + from unittest.mock import AsyncMock, MagicMock + + from libp2p.bitswap.client import BitswapClient + from libp2p.bitswap.dag import MerkleDag + + data = b"compare stream vs file " * 8000 # ~176 KB, 3 chunks + + with tempfile.NamedTemporaryFile(delete=False) as f: 
+ f.write(data) + tmp = f.name + + try: + + def make_dag(): + store = MemoryBlockStore() + mock = MagicMock(spec=BitswapClient) + mock.block_store = store + stored = {} + + async def add_block(cid, d): + stored[bytes(cid)] = d + + mock.add_block = AsyncMock(side_effect=add_block) + return MerkleDag(mock) + + dag1 = make_dag() + cid_file = await dag1.add_file(tmp, wrap_with_directory=False) + + dag2 = make_dag() + with open(tmp, "rb") as fh: + cid_stream = await dag2.add_stream(fh) + + assert bytes(cid_file) == bytes(cid_stream), ( + f"CIDs differ:\n add_file: {cid_to_text(cid_file)}\n" + f" add_stream: {cid_to_text(cid_stream)}" + ) + ok(f"add_file == add_stream CID: {cid_to_text(cid_file)[:30]}...") + finally: + os.unlink(tmp) + + +# ── main ────────────────────────────────────────────────────────────────────── + + +async def main(): + print("=" * 60) + print("io.IOBase Input Support — Test Suite") + print("=" * 60) + + # sync tests + test_chunk_stream_bytesio() + test_chunk_stream_empty() + test_chunk_stream_file_handle() + test_chunk_stream_gzip() + test_chunk_stream_matches_chunk_bytes() + + # async tests + await test_add_stream_bytesio() + await test_add_stream_empty() + await test_add_stream_single_chunk() + await test_add_stream_gzip() + await test_add_stream_vs_add_file_same_cid() + + print("\n" + "=" * 60) + print("All tests passed!") + print("=" * 60) + + +if __name__ == "__main__": + trio.run(main) diff --git a/tests/core/bitswap/test_payment.py b/tests/core/bitswap/test_payment.py new file mode 100644 index 000000000..8e944e566 --- /dev/null +++ b/tests/core/bitswap/test_payment.py @@ -0,0 +1,110 @@ +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from libp2p.bitswap.gated_decision_engine import PaymentGatedDecisionEngine +from libp2p.bitswap.payment_ledger import PaymentLedger +from libp2p.bitswap.pricing_engine import BlockPricingEngine + + +def test_block_pricing_engine_size_based(): + pricing = 
BlockPricingEngine(strategy="size_based", units_per_kb=10.0) + # 500 KB = 512000 bytes + price = pricing.compute_price("cid1", 512000) + assert price == 5000 + + +def test_block_pricing_engine_fixed(): + pricing = BlockPricingEngine(strategy="fixed", fixed_price=123) + price = pricing.compute_price("cid1", 512000) + assert price == 123 + + +def test_block_pricing_engine_free(): + pricing = BlockPricingEngine(strategy="free") + price = pricing.compute_price("cid1", 512000) + assert price == 0 + + +def test_block_pricing_engine_overrides(): + pricing = BlockPricingEngine(strategy="fixed", fixed_price=100) + pricing.set_price(b"cid2", 50) + pricing.set_free(b"cid3") + + assert pricing.compute_price(b"cid1".hex(), 10) == 100 + assert pricing.compute_price(b"cid2".hex(), 10) == 50 + assert pricing.compute_price(b"cid3".hex(), 10) == 0 + + +@pytest.mark.trio +async def test_payment_ledger_registration_and_payment(): + ledger = PaymentLedger() + + root_cid = b"root" + child_cids: list[bytes | str] = [b"child1", b"child2"] + + await ledger.register_dag(root_cid, child_cids) + + assert not ledger.is_paid("peer1", b"child1") + + await ledger.record_payment("peer1", b"root", amount=1000, nonce=b"nonce1") + + # After payment for root, root and children should be considered paid + assert ledger.is_paid("peer1", b"root") + assert ledger.is_paid("peer1", b"child1") + assert ledger.is_paid("peer1", b"child2") + + # Peer 2 has not paid + assert not ledger.is_paid("peer2", b"root") + + +@pytest.mark.trio +async def test_payment_ledger_nonce_replay(): + ledger = PaymentLedger() + + await ledger.record_payment("peer1", b"root", amount=1000, nonce=b"nonce1") + + with pytest.raises(ValueError, match="Nonce already used"): + await ledger.record_payment("peer1", b"root", amount=1000, nonce=b"nonce1") + + # Different nonce should work + await ledger.record_payment("peer1", b"root", amount=1000, nonce=b"nonce2") + + +@pytest.mark.trio +async def test_gated_decision_engine_auth(): + # Setup 
mocks + blockstore = AsyncMock() + blockstore.get_block.return_value = b"block data" + + ledger = PaymentLedger() + pricing = BlockPricingEngine(strategy="fixed", fixed_price=100) + + engine = PaymentGatedDecisionEngine( + blockstore=blockstore, ledger=ledger, pricing=pricing, tx_verifier=None + ) + + auth = MagicMock() + auth.cid = b"cid1" + auth.value = 50 + auth.from_address = "0x..." + auth.nonce = b"nonce1" + + # Payment less than expected + response = await engine.handle_payment_authorization("peer1", auth) + assert len(response.payment_rejections) == 1 + assert "INSUFFICIENT_PAYMENT" in response.payment_rejections[0].reason + + # Payment sufficient + auth.value = 100 + response = await engine.handle_payment_authorization("peer1", auth) + assert len(response.payment_receipts) == 1 + assert len(response.payload) == 1 + assert response.payload[0].data == b"block data" + + # Nonce reused + auth.nonce = b"nonce1" + # Should still succeed because ledger is_paid will return true + # and it skips re-verification + response = await engine.handle_payment_authorization("peer1", auth) + assert len(response.payment_receipts) == 1 diff --git a/tests/core/bitswap/test_provider_query.py b/tests/core/bitswap/test_provider_query.py new file mode 100644 index 000000000..8617cc6eb --- /dev/null +++ b/tests/core/bitswap/test_provider_query.py @@ -0,0 +1,450 @@ +""" +Tests for ProviderQueryManager and its integration with BitswapClient. 
+ +Covers: +- ProviderCacheEntry – TTL, expiry +- ProviderCache – LRU eviction, TTL, cleanup, stats +- ProviderQueryManager – single/batch queries, cache hit/miss, + max_providers cap, error handling, stats +- BitswapClient integration – provider_query_manager wired at construction, + get_block() uses DHT discovery +""" + +from __future__ import annotations + +import time +from unittest.mock import Mock + +import pytest +import trio + +from libp2p.bitswap.block_store import MemoryBlockStore +from libp2p.bitswap.cid import cid_to_bytes, compute_cid_v0, parse_cid +from libp2p.bitswap.client import BitswapClient +from libp2p.bitswap.provider_query import ( + ProviderCache, + ProviderCacheEntry, + ProviderQueryManager, +) +from libp2p.peer.id import ID as PeerID +from libp2p.peer.peerinfo import PeerInfo + +# ── helpers ─────────────────────────────────────────────────────────────────── + +PEER_A = PeerID.from_base58("QmNnooDu7bfjPFoTZYxMNLWUQJyrVwtbZg5gBMjTezGAJN") +PEER_B = PeerID.from_base58("QmaCpDMGvV2BGHeYERUEnRQAwe3N8SzbUtfsmvsqQLuvuJ") +PEER_C = PeerID.from_base58("QmSoLV4Bbm51jM9C4gDYZQ9Cy3U6aXMJDAbzgu2fzaDs64") + +SAMPLE_PEERS = [PEER_A, PEER_B, PEER_C] + +CID_1 = parse_cid(compute_cid_v0(b"block-one")) +CID_2 = parse_cid(compute_cid_v0(b"block-two")) +CID_3 = parse_cid(compute_cid_v0(b"block-three")) + +SAMPLE_CIDS = [CID_1, CID_2, CID_3] + + +def _mock_dht(return_peers: list[PeerID] | None = None) -> Mock: + """ + Return a mock DHT whose provider_store.find_providers returns *return_peers*. + + find_providers is the async network lookup path; get_providers is the + local-store read that ProviderQueryManager no longer calls directly. 
+ """ + dht = Mock() + dht.provider_store = Mock() + peer_infos = [PeerInfo(p, []) for p in (return_peers or [])] + + async def _async_find_providers(key: bytes, count: int = 20) -> list[PeerInfo]: + return peer_infos[:count] + + dht.provider_store.find_providers = Mock(side_effect=_async_find_providers) + return dht + + +# ═════════════════════════════════════════════════════════════════════════════ +# ProviderCacheEntry +# ═════════════════════════════════════════════════════════════════════════════ + + +class TestProviderCacheEntry: + def test_fresh_entry_not_expired(self) -> None: + entry = ProviderCacheEntry(providers=SAMPLE_PEERS, ttl=300) + assert not entry.is_expired() + assert entry.age() < 1.0 + + def test_entry_with_past_timestamp_is_expired(self) -> None: + entry = ProviderCacheEntry( + providers=SAMPLE_PEERS, + timestamp=time.time() - 10, + ttl=5, + ) + assert entry.is_expired() + + def test_default_ttl_applied(self) -> None: + entry = ProviderCacheEntry(providers=[PEER_A]) + assert entry.ttl == 300 + + +# ═════════════════════════════════════════════════════════════════════════════ +# ProviderCache +# ═════════════════════════════════════════════════════════════════════════════ + + +class TestProviderCache: + def test_put_and_get(self) -> None: + cache = ProviderCache(max_size=10, default_ttl=60) + cache.put(b"k1", SAMPLE_PEERS) + assert cache.get(b"k1") == SAMPLE_PEERS + + def test_miss_returns_none(self) -> None: + cache = ProviderCache() + assert cache.get(b"no-such-key") is None + + def test_expired_entry_returns_none(self) -> None: + cache = ProviderCache(max_size=10, default_ttl=300) + cache.put(b"k1", SAMPLE_PEERS, ttl=0.01) + time.sleep(0.05) + assert cache.get(b"k1") is None + + def test_lru_evicts_oldest(self) -> None: + cache = ProviderCache(max_size=3, default_ttl=300) + cache.put(b"a", [PEER_A]) + cache.put(b"b", [PEER_B]) + cache.put(b"c", [PEER_C]) + cache.get(b"a") # mark 'a' recently used + cache.put(b"d", [PEER_A]) # 'b' should be 
evicted + assert cache.get(b"b") is None + assert cache.get(b"a") is not None + assert cache.get(b"d") is not None + + def test_clear_empties_cache(self) -> None: + cache = ProviderCache(max_size=10, default_ttl=300) + cache.put(b"k1", [PEER_A]) + cache.put(b"k2", [PEER_B]) + cache.clear() + assert cache.size() == 0 + + def test_cleanup_expired_removes_stale(self) -> None: + cache = ProviderCache(max_size=10, default_ttl=300) + cache.put(b"stale", [PEER_A], ttl=0.01) + cache.put(b"fresh", [PEER_B], ttl=300) + time.sleep(0.05) + removed = cache.cleanup_expired() + assert removed == 1 + assert cache.size() == 1 + + def test_stats_keys_present(self) -> None: + cache = ProviderCache(max_size=5, default_ttl=300) + cache.put(b"k", [PEER_A]) + stats = cache.stats() + assert {"size", "max_size", "expired"} <= stats.keys() + assert stats["size"] == 1 + assert stats["max_size"] == 5 + + +# ═════════════════════════════════════════════════════════════════════════════ +# ProviderQueryManager +# ═════════════════════════════════════════════════════════════════════════════ + + +class TestProviderQueryManager: + @pytest.mark.trio + async def test_cache_miss_queries_dht(self) -> None: + dht = _mock_dht(return_peers=[PEER_A]) + mgr = ProviderQueryManager(dht) + + providers = await mgr.find_providers_single(CID_1, timeout=5.0) + + assert providers == [PEER_A] + stats = mgr.get_stats() + assert stats["queries"] == 1 + assert stats["cache_misses"] == 1 + assert stats["cache_hits"] == 0 + assert stats["providers_found"] == 1 + # Verify the async network path was used, not the local store read + dht.provider_store.find_providers.assert_called_once() + + @pytest.mark.trio + async def test_cache_hit_skips_dht(self) -> None: + dht = _mock_dht() + mgr = ProviderQueryManager(dht) + mgr.cache.put(cid_to_bytes(CID_1), [PEER_B]) + + providers = await mgr.find_providers_single(CID_1) + + assert providers == [PEER_B] + dht.provider_store.find_providers.assert_not_called() + assert 
mgr.get_stats()["cache_hits"] == 1 + + @pytest.mark.trio + async def test_second_call_uses_cache(self) -> None: + dht = _mock_dht(return_peers=[PEER_A]) + mgr = ProviderQueryManager(dht) + + await mgr.find_providers_single(CID_1) # miss + await mgr.find_providers_single(CID_1) # hit + + stats = mgr.get_stats() + assert stats["queries"] == 1 # no extra DHT call + assert stats["cache_hits"] == 1 + + @pytest.mark.trio + async def test_max_providers_cap(self) -> None: + dht = _mock_dht(return_peers=SAMPLE_PEERS) + mgr = ProviderQueryManager(dht, max_providers=1) + + providers = await mgr.find_providers_single(CID_1) + assert len(providers) == 1 + + @pytest.mark.trio + async def test_no_providers_returns_empty(self) -> None: + dht = _mock_dht(return_peers=[]) + mgr = ProviderQueryManager(dht) + providers = await mgr.find_providers_single(CID_1) + assert providers == [] + + @pytest.mark.trio + async def test_dht_error_increments_errors(self) -> None: + dht = _mock_dht() + + async def _raise(*_args: object, **_kwargs: object) -> None: + raise RuntimeError("dht down") + + dht.provider_store.find_providers = Mock(side_effect=_raise) + mgr = ProviderQueryManager(dht) + + providers = await mgr.find_providers_single(CID_1, timeout=5.0) + + assert providers == [] + assert mgr.get_stats()["errors"] == 1 + + @pytest.mark.trio + async def test_batch_all_cache_hits(self) -> None: + dht = _mock_dht() + mgr = ProviderQueryManager(dht) + for cid in SAMPLE_CIDS: + mgr.cache.put(cid_to_bytes(cid), [PEER_A]) + + results = await mgr.find_providers(SAMPLE_CIDS) + + assert len(results) == 3 + dht.provider_store.find_providers.assert_not_called() + + @pytest.mark.trio + async def test_batch_partial_cache(self) -> None: + dht = _mock_dht(return_peers=[PEER_B]) + mgr = ProviderQueryManager(dht) + # Pre-cache only first CID + mgr.cache.put(cid_to_bytes(CID_1), [PEER_A]) + + results = await mgr.find_providers(SAMPLE_CIDS) + + assert len(results) == 3 + # Only 2 DHT calls (CID_2 and CID_3 are 
cache misses) + assert dht.provider_store.find_providers.call_count == 2 + + @pytest.mark.trio + async def test_use_cache_false_always_queries_dht(self) -> None: + dht = _mock_dht(return_peers=[PEER_A]) + mgr = ProviderQueryManager(dht) + mgr.cache.put(cid_to_bytes(CID_1), [PEER_B]) # pre-populated + + providers = await mgr.find_providers_single(CID_1, use_cache=False) + + # DHT was queried despite cache having an entry + dht.provider_store.find_providers.assert_called_once() + assert providers == [PEER_A] + + @pytest.mark.trio + async def test_clear_cache_forces_new_query(self) -> None: + dht = _mock_dht(return_peers=[PEER_A]) + mgr = ProviderQueryManager(dht) + + await mgr.find_providers_single(CID_1) # miss → cached + await mgr.find_providers_single(CID_1) # hit + mgr.clear_cache() + await mgr.find_providers_single(CID_1) # miss again + + assert mgr.get_stats()["cache_misses"] == 2 + assert dht.provider_store.find_providers.call_count == 2 + + @pytest.mark.trio + async def test_cleanup_expired_cache(self) -> None: + dht = _mock_dht() + mgr = ProviderQueryManager(dht) + mgr.cache.put(cid_to_bytes(CID_1), [PEER_A], ttl=0.01) + mgr.cache.put(cid_to_bytes(CID_2), [PEER_B], ttl=300) + await trio.sleep(0.05) + + removed = await mgr.cleanup_expired_cache() + + assert removed == 1 + assert mgr.cache.size() == 1 + + def test_get_stats_initial_values(self) -> None: + mgr = ProviderQueryManager(_mock_dht()) + stats = mgr.get_stats() + assert stats["queries"] == 0 + assert stats["cache_hits"] == 0 + assert stats["cache_misses"] == 0 + assert stats["errors"] == 0 + assert stats["providers_found"] == 0 + + @pytest.mark.trio + async def test_empty_cid_list(self) -> None: + mgr = ProviderQueryManager(_mock_dht()) + assert await mgr.find_providers([]) == {} + + +# ═════════════════════════════════════════════════════════════════════════════ +# BitswapClient integration +# ═════════════════════════════════════════════════════════════════════════════ + + +class 
TestBitswapClientProviderQueryIntegration: + """Verify that BitswapClient wires ProviderQueryManager into get_block().""" + + def _make_client( + self, + mock_host: Mock, + pqm: ProviderQueryManager | None = None, + ) -> BitswapClient: + store = MemoryBlockStore() + return BitswapClient(mock_host, block_store=store, provider_query_manager=pqm) + + def test_provider_query_manager_stored_on_client(self, mock_host: Mock) -> None: + dht = _mock_dht() + pqm = ProviderQueryManager(dht) + client = self._make_client(mock_host, pqm) + assert client.provider_query_manager is pqm + + def test_no_pqm_by_default(self, mock_host: Mock) -> None: + client = self._make_client(mock_host) + assert client.provider_query_manager is None + + @pytest.mark.trio + async def test_get_block_returns_local_without_dht(self, mock_host: Mock) -> None: + """Local cache hit must never touch the DHT.""" + dht = _mock_dht(return_peers=[PEER_A]) + pqm = ProviderQueryManager(dht) + client = self._make_client(mock_host, pqm) + + block_data = b"local block" + cid = parse_cid(compute_cid_v0(block_data)) + await client.block_store.put_block(cid, block_data) + + result = await client.block_store.get_block(cid) + assert result == block_data + # DHT must not have been consulted + dht.provider_store.find_providers.assert_not_called() + + @pytest.mark.trio + async def test_get_block_uses_pqm_to_pick_peer(self, mock_host: Mock) -> None: + """ + When the block is not local, get_block() should call + provider_query_manager.find_providers_single() and use the + returned peer_id. 
+ """ + discovered_peer = PEER_A + block_data = b"remote block" + cid = parse_cid(compute_cid_v0(block_data)) + + dht = _mock_dht(return_peers=[discovered_peer]) + pqm = ProviderQueryManager(dht) + client = self._make_client(mock_host, pqm) + + # Patch _request_block so we can inspect the peer_id it receives + captured: dict[str, object] = {} + + async def _fake_request(cid_obj, peer_id, timeout): # noqa: ANN001 + captured["peer_id"] = peer_id + return block_data + + client._request_block = _fake_request # type: ignore[method-assign] + + result = await client.get_block(cid) + + assert result == block_data + assert captured["peer_id"] == discovered_peer + + @pytest.mark.trio + async def test_get_block_falls_back_to_broadcast_when_no_providers( + self, mock_host: Mock + ) -> None: + """ + When the DHT returns no providers, get_block() must still call + _request_block with peer_id=None (broadcast fallback). + """ + dht = _mock_dht(return_peers=[]) + pqm = ProviderQueryManager(dht) + client = self._make_client(mock_host, pqm) + + block_data = b"broadcast block" + cid = parse_cid(compute_cid_v0(block_data)) + + captured: dict[str, object] = {} + + async def _fake_request(cid_obj, peer_id, timeout): # noqa: ANN001 + captured["peer_id"] = peer_id + return block_data + + client._request_block = _fake_request # type: ignore[method-assign] + + result = await client.get_block(cid) + + assert result == block_data + assert captured["peer_id"] is None # broadcast + + @pytest.mark.trio + async def test_explicit_peer_id_skips_pqm(self, mock_host: Mock) -> None: + """An explicit peer_id argument must bypass DHT discovery.""" + dht = _mock_dht(return_peers=[PEER_B]) + pqm = ProviderQueryManager(dht) + client = self._make_client(mock_host, pqm) + + block_data = b"explicit peer block" + cid = parse_cid(compute_cid_v0(block_data)) + + captured: dict[str, object] = {} + + async def _fake_request(cid_obj, peer_id, timeout): # noqa: ANN001 + captured["peer_id"] = peer_id + return 
block_data + + client._request_block = _fake_request # type: ignore[method-assign] + + await client.get_block(cid, peer_id=PEER_A) + + # DHT must NOT have been called + dht.provider_store.get_providers.assert_not_called() + # The explicit peer_id must be passed through unchanged + assert captured["peer_id"] == PEER_A + + @pytest.mark.trio + async def test_pqm_error_falls_back_gracefully(self, mock_host: Mock) -> None: + """A crashing PQM must not prevent the block fetch from proceeding.""" + dht = _mock_dht() + + async def _raise(*_args: object, **_kwargs: object) -> None: + raise RuntimeError("dht exploded") + + dht.provider_store.find_providers = Mock(side_effect=_raise) + pqm = ProviderQueryManager(dht) + client = self._make_client(mock_host, pqm) + + block_data = b"fallback block" + cid = parse_cid(compute_cid_v0(block_data)) + + captured: dict[str, object] = {} + + async def _fake_request(cid_obj, peer_id, timeout): # noqa: ANN001 + captured["peer_id"] = peer_id + return block_data + + client._request_block = _fake_request # type: ignore[method-assign] + + result = await client.get_block(cid) + + assert result == block_data + assert captured["peer_id"] is None # graceful broadcast fallback diff --git a/tests/core/bitswap/test_unixfs_encoding.py b/tests/core/bitswap/test_unixfs_encoding.py new file mode 100644 index 000000000..2bcecd0ac --- /dev/null +++ b/tests/core/bitswap/test_unixfs_encoding.py @@ -0,0 +1,266 @@ +""" +Test that add_file / add_bytes now produce dag-pb leaf blocks (UnixFS-wrapped) +and that balanced_layout builds the correct tree structure. 
+ +Run with: + python test_unixfs_encoding.py +""" + +import os +import tempfile + +import trio + +from libp2p.bitswap.block_store import MemoryBlockStore +from libp2p.bitswap.cid import CODEC_DAG_PB, CODEC_RAW, cid_to_text, compute_cid_v1 +from libp2p.bitswap.dag_pb import ( + MAX_LINKS_PER_NODE, + balanced_layout, + create_leaf_node, + decode_dag_pb, + is_file_node, +) + + +def ok(label): + print(f" OK {label}") + + +def fail(label, detail=""): + raise AssertionError(f"FAIL {label} {detail}") + + +# ── 1. create_leaf_node wraps data in dag-pb + UnixFS ──────────────────────── +def test_create_leaf_node(): + print("\n[1] create_leaf_node") + data = b"hello leaf" + leaf = create_leaf_node(data) + + # Must be a valid dag-pb file node + assert is_file_node(leaf), "leaf must be a dag-pb file node" + ok("create_leaf_node produces a dag-pb file node") + + # Decode and check inline data + links, unixfs = decode_dag_pb(leaf) + assert links == [], "leaf must have no links" + assert unixfs is not None + assert unixfs.data == data, f"inline data mismatch: {unixfs.data!r} != {data!r}" + assert unixfs.filesize == len(data) + ok(f"leaf contains inline data ({len(data)} bytes), filesize={unixfs.filesize}") + + # CID must be dag-pb, not raw + cid = compute_cid_v1(leaf, codec=CODEC_DAG_PB) + raw_cid = compute_cid_v1(data, codec=CODEC_RAW) + assert bytes(cid) != bytes(raw_cid), "dag-pb leaf CID must differ from raw CID" + ok(f"leaf CID is dag-pb (not raw): {cid_to_text(cid)[:30]}...") + + # Empty leaf + empty_leaf = create_leaf_node(b"") + _, empty_unixfs = decode_dag_pb(empty_leaf) + assert empty_unixfs is not None + assert empty_unixfs.filesize == 0 + ok("empty leaf node is valid") + + +# ── 2. 
balanced_layout single leaf ─────────────────────────────────────────── +def test_balanced_layout_single(): + print("\n[2] balanced_layout — single leaf returns leaf unchanged") + data = b"only chunk" + leaf = create_leaf_node(data) + cid = compute_cid_v1(leaf, codec=CODEC_DAG_PB) + + root_cid, root_block, _ = balanced_layout([(cid, leaf, len(data))]) + assert bytes(root_cid) == bytes(cid) + assert root_block == leaf + ok("single leaf: root_cid == leaf_cid") + + +# ── 3. balanced_layout two leaves ──────────────────────────────────────────── +def test_balanced_layout_two_leaves(): + print("\n[3] balanced_layout — two leaves builds one root") + leaves = [] + for i in range(2): + data = f"chunk {i}".encode() * 100 + leaf = create_leaf_node(data) + cid = compute_cid_v1(leaf, codec=CODEC_DAG_PB) + leaves.append((cid, leaf, len(data))) + + root_cid, root_block, _ = balanced_layout(leaves) + + # Root must be a dag-pb file node with 2 links + assert is_file_node(root_block) + links, unixfs = decode_dag_pb(root_block) + assert len(links) == 2, f"expected 2 links, got {len(links)}" + assert unixfs is not None + assert unixfs.filesize == sum(s for _, _, s in leaves) + assert len(unixfs.blocksizes) == 2 + ok(f"root has 2 links, filesize={unixfs.filesize}, blocksizes={unixfs.blocksizes}") + + +# ── 4. 
balanced_layout 175 leaves builds 2-level tree ──────────────────────── +def test_balanced_layout_two_levels(): + print("\n[4] balanced_layout — 175 leaves builds 2-level tree (174 + 1)") + n = MAX_LINKS_PER_NODE + 1 # 175 + chunk_size = 100 + leaves = [] + for i in range(n): + data = bytes([i % 256]) * chunk_size + leaf = create_leaf_node(data) + cid = compute_cid_v1(leaf, codec=CODEC_DAG_PB) + leaves.append((cid, leaf, chunk_size)) + + root_cid, root_block, _ = balanced_layout(leaves) + links, unixfs = decode_dag_pb(root_block) + + # Root should link to 2 internal nodes (174 + 1) + assert len(links) == 2, f"expected 2 top-level links, got {len(links)}" + assert unixfs is not None + assert unixfs.filesize == n * chunk_size + ok("175 leaves → root has 2 links (174-leaf node + 1-leaf node)") + ok(f"root filesize = {unixfs.filesize} = 175 * {chunk_size}") + + +# ── 5. balanced_layout 174 leaves stays flat ───────────────────────────────── +def test_balanced_layout_flat(): + print("\n[5] balanced_layout — exactly 174 leaves stays flat (1 level)") + n = MAX_LINKS_PER_NODE # 174 + leaves = [] + for i in range(n): + data = bytes([i % 256]) * 50 + leaf = create_leaf_node(data) + cid = compute_cid_v1(leaf, codec=CODEC_DAG_PB) + leaves.append((cid, leaf, 50)) + + root_cid, root_block, _ = balanced_layout(leaves) + links, unixfs = decode_dag_pb(root_block) + + assert len(links) == 174, f"expected 174 direct links, got {len(links)}" + ok("174 leaves → flat root with 174 direct links") + + +# ── 6. 
add_file produces dag-pb leaves (not raw) via MerkleDag ─────────────── +async def test_add_file_produces_dag_pb_leaves(): + print("\n[6] MerkleDag.add_file produces dag-pb leaf blocks") + from unittest.mock import AsyncMock, MagicMock + + from libp2p.bitswap.client import BitswapClient + from libp2p.bitswap.dag import MerkleDag + + store = MemoryBlockStore() + mock_client = MagicMock(spec=BitswapClient) + mock_client.block_store = store + stored: dict[bytes, bytes] = {} + + async def add_block_impl(cid, data): + stored[bytes(cid)] = data + + mock_client.add_block = AsyncMock(side_effect=add_block_impl) + + dag = MerkleDag(mock_client) + + # Write a 3-chunk file + chunk_size = 63 * 1024 + content = b"x" * (chunk_size * 3 - 7) # 3 chunks + with tempfile.NamedTemporaryFile(delete=False) as f: + f.write(content) + tmp = f.name + + try: + root_cid = await dag.add_file( + tmp, chunk_size=chunk_size, wrap_with_directory=False + ) + finally: + os.unlink(tmp) + + # Root block must be dag-pb, but leaves must be raw blocks + raw_blocks = [] + dag_pb_blocks = [] + for cid_bytes, block_data in stored.items(): + if is_file_node(block_data): + dag_pb_blocks.append(cid_bytes) + else: + raw_blocks.append(cid_bytes) + + assert len(dag_pb_blocks) == 1, f"Expected 1 root node, got {len(dag_pb_blocks)}" + assert len(raw_blocks) > 0, f"Expected raw leaves, got {len(raw_blocks)}" + ok("Root is dag-pb, and all leaves are raw blocks") + + # Root must link to 3 leaves + root_block = stored[bytes(root_cid)] + links, unixfs = decode_dag_pb(root_block) + assert len(links) == 3, f"expected 3 links on root, got {len(links)}" + assert unixfs is not None + assert unixfs.filesize == len(content) + ok(f"root has 3 links, filesize={unixfs.filesize}") + + # Each leaf must be raw data + for link in links: + leaf_block = stored[bytes(link.cid)] + assert not is_file_node(leaf_block), "leaf must be raw data" + assert len(leaf_block) > 0 + ok("each leaf contains raw data") + + +# ── 7. 
add_bytes produces raw leaves ────────────────────────────────────── +async def test_add_bytes_produces_dag_pb_leaves(): + print("\n[7] MerkleDag.add_bytes produces raw leaf blocks") + from unittest.mock import AsyncMock, MagicMock + + from libp2p.bitswap.client import BitswapClient + from libp2p.bitswap.dag import MerkleDag + + store = MemoryBlockStore() + mock_client = MagicMock(spec=BitswapClient) + mock_client.block_store = store + stored: dict[bytes, bytes] = {} + + async def add_block_impl(cid, data): + stored[bytes(cid)] = data + + mock_client.add_block = AsyncMock(side_effect=add_block_impl) + + dag = MerkleDag(mock_client) + content = b"y" * (256 * 1024 * 3 + 500) # > 3 default chunks + await dag.add_bytes(content) + + raw_blocks = [] + dag_pb_blocks = [] + for c, d in stored.items(): + if is_file_node(d): + dag_pb_blocks.append(c) + else: + raw_blocks.append(c) + + assert len(dag_pb_blocks) == 1 + assert len(raw_blocks) > 0 + ok("Root is dag-pb, and all leaves are raw blocks") + + root_block = stored[dag_pb_blocks[0]] + links, unixfs = decode_dag_pb(root_block) + assert len(links) == 4 + assert unixfs is not None + assert unixfs.filesize == len(content) + ok(f"root has 3 links, filesize={unixfs.filesize}") + + +# ── main ────────────────────────────────────────────────────────────────────── +async def main(): + print("=" * 60) + print("UnixFSFile / Balanced DAG — Test Suite") + print("=" * 60) + + test_create_leaf_node() + test_balanced_layout_single() + test_balanced_layout_two_leaves() + test_balanced_layout_two_levels() + test_balanced_layout_flat() + await test_add_file_produces_dag_pb_leaves() + await test_add_bytes_produces_dag_pb_leaves() + + print("\n" + "=" * 60) + print("All tests passed!") + print("=" * 60) + + +if __name__ == "__main__": + trio.run(main) diff --git a/tests/core/bitswap/test_wantlist.py b/tests/core/bitswap/test_wantlist.py new file mode 100644 index 000000000..a632fc80b --- /dev/null +++ b/tests/core/bitswap/test_wantlist.py 
@@ -0,0 +1,281 @@ +""" +Test Wantlist / Message dataclasses. + +Run with: + python test_wantlist.py +""" + +from libp2p.bitswap.cid import CODEC_RAW, cid_to_bytes, compute_cid_v1 +from libp2p.bitswap.messages import create_wantlist_entry +from libp2p.bitswap.wantlist import ( + BitswapMessage, + BlockPresence, + BlockPresenceType, + Wantlist, + WantlistEntry, + WantType, +) + + +def make_cid(content: bytes) -> bytes: + return cid_to_bytes(compute_cid_v1(content, codec=CODEC_RAW)) + + +def ok(label): + print(f" OK {label}") + + +# ── WantType enum ───────────────────────────────────────────────────────────── + + +def test_want_type_values(): + print("\n[1] WantType enum values match protobuf") + assert WantType.Block.value == 0 + assert WantType.Have.value == 1 + ok("WantType.Block == 0, WantType.Have == 1") + + +# ── WantlistEntry ───────────────────────────────────────────────────────────── + + +def test_wantlist_entry_from_cid(): + print("\n[2] WantlistEntry.from_cid normalises any CIDInput") + cid = compute_cid_v1(b"entry test", codec=CODEC_RAW) + cid_bytes = cid_to_bytes(cid) + + # from bytes + e1 = WantlistEntry.from_cid(cid_bytes) + assert e1.cid == cid_bytes + assert e1.want_type == WantType.Block + assert e1.priority == 1 + assert not e1.cancel + ok("from bytes — defaults correct") + + # from CIDObject + e2 = WantlistEntry.from_cid(cid, want_type=WantType.Have, send_dont_have=True) + assert e2.want_type == WantType.Have + assert e2.send_dont_have + ok("from CIDObject — WantType.Have, send_dont_have=True") + + # cancel entry + e3 = WantlistEntry.from_cid(cid_bytes, cancel=True) + assert e3.cancel + ok("cancel entry") + + +# ── Wantlist ────────────────────────────────────────────────────────────────── + + +def test_wantlist_add_cancel_contains(): + print("\n[3] Wantlist.add / cancel / contains") + cid1 = make_cid(b"block 1") + cid2 = make_cid(b"block 2") + cid3 = make_cid(b"block 3") + + wl = Wantlist() + assert len(wl) == 0 + assert not wl + + wl.add(cid1, 
want_type=WantType.Block, send_dont_have=True) + wl.add(cid2, want_type=WantType.Have) + wl.cancel(cid3) + + assert len(wl) == 3 + assert bool(wl) + ok("len(wl) == 3 after 2 adds + 1 cancel") + + assert wl.contains(cid1) + assert wl.contains(cid2) + assert not wl.contains(cid3) # cancel entry → not "contained" + ok("contains() returns True for non-cancel entries only") + + # Check entry fields + e1 = wl.entries[0] + assert e1.want_type == WantType.Block + assert e1.send_dont_have + e2 = wl.entries[1] + assert e2.want_type == WantType.Have + e3 = wl.entries[2] + assert e3.cancel + ok("entry fields correct (want_type, send_dont_have, cancel)") + + +def test_wantlist_full_flag(): + print("\n[4] Wantlist.full flag") + wl = Wantlist(full=True) + assert wl.full + ok("full=True preserved") + + +# ── BlockPresence ───────────────────────────────────────────────────────────── + + +def test_block_presence(): + print("\n[5] BlockPresence constructors") + cid = make_cid(b"presence test") + + have = BlockPresence.have(cid) + assert have.cid == cid + assert have.type == BlockPresenceType.Have + ok("BlockPresence.have()") + + dont = BlockPresence.dont_have(cid) + assert dont.cid == cid + assert dont.type == BlockPresenceType.DontHave + ok("BlockPresence.dont_have()") + + assert BlockPresenceType.Have.value == 0 + assert BlockPresenceType.DontHave.value == 1 + ok("BlockPresenceType values match protobuf (Have=0, DontHave=1)") + + +# ── BitswapMessage ──────────────────────────────────────────────────────────── + + +def test_bitswap_message_properties(): + print("\n[6] BitswapMessage builder + properties") + cid1 = make_cid(b"want me") + cid2 = make_cid(b"block data") + cid3 = make_cid(b"i have this") + cid4 = make_cid(b"i dont have this") + data = b"actual block content" + + msg = BitswapMessage() + assert not msg.is_want + assert not msg.has_blocks + assert not msg.has_presences + + msg.add_want(cid1, want_type=WantType.Block, send_dont_have=True) + assert msg.is_want + 
ok("is_want True after add_want()") + + msg.add_block(cid2, data) + assert msg.has_blocks + assert msg.blocks[0] == (cid2, data) + ok("has_blocks True after add_block()") + + msg.add_have(cid3) + msg.add_dont_have(cid4) + assert msg.has_presences + assert len(msg.block_presences) == 2 + assert msg.block_presences[0].type == BlockPresenceType.Have + assert msg.block_presences[1].type == BlockPresenceType.DontHave + ok("has_presences True, HAVE and DONT_HAVE entries correct") + + +def test_bitswap_message_cancel_want(): + print("\n[7] BitswapMessage.cancel_want()") + cid = make_cid(b"cancel me") + msg = BitswapMessage() + msg.cancel_want(cid) + assert msg.is_want + assert msg.wantlist is not None + assert msg.wantlist.entries[0].cancel + ok("cancel_want() adds cancel entry") + + +# ── to_proto / from_proto round-trip ───────────────────────────────────────── + + +def test_to_proto_from_proto_roundtrip(): + print("\n[8] BitswapMessage to_proto() / from_proto() round-trip") + cid1 = make_cid(b"want block") + cid2 = make_cid(b"block payload") + cid3 = make_cid(b"have this") + data = b"block payload data" + + original = BitswapMessage() + original.add_want(cid1, want_type=WantType.Block, send_dont_have=True) + original.add_block(cid2, data) + original.add_have(cid3) + original.add_dont_have(make_cid(b"dont have")) + + proto = original.to_proto() + restored = BitswapMessage.from_proto(proto) + + # Wantlist + assert restored.wantlist is not None + assert len(restored.wantlist.entries) == 1 + e = restored.wantlist.entries[0] + assert e.cid == cid1 + assert e.want_type == WantType.Block + assert e.send_dont_have + ok("wantlist entry round-trips correctly") + + # Block payload + assert len(restored.blocks) == 1 + restored_cid, restored_data = restored.blocks[0] + assert restored_data == data + ok("block payload round-trips correctly") + + # Block presences + assert len(restored.block_presences) == 2 + assert restored.block_presences[0].type == BlockPresenceType.Have + assert 
restored.block_presences[1].type == BlockPresenceType.DontHave + ok("block presences round-trip correctly") + + +# ── backward compat: create_wantlist_entry accepts int OR WantType ──────────── + + +def test_create_wantlist_entry_backward_compat(): + print("\n[9] create_wantlist_entry — backward compat (int OR WantType)") + cid = make_cid(b"compat test") + + # Old style: raw int + e_int = create_wantlist_entry(cid, want_type=0) + assert e_int.wantType == 0 + ok("want_type=0 (int) still works") + + e_int2 = create_wantlist_entry(cid, want_type=1) + assert e_int2.wantType == 1 + ok("want_type=1 (int) still works") + + # New style: WantType enum + e_enum = create_wantlist_entry(cid, want_type=WantType.Block) + assert e_enum.wantType == 0 + ok("want_type=WantType.Block works") + + e_enum2 = create_wantlist_entry(cid, want_type=WantType.Have) + assert e_enum2.wantType == 1 + ok("want_type=WantType.Have works") + + +# ── public API exports ──────────────────────────────────────────────────────── + + +def test_public_exports(): + print("\n[10] All types exported from libp2p.bitswap") + from libp2p.bitswap import ( + WantType, + ) + + assert WantType.Block.value == 0 + assert WantType.Have.value == 1 + ok( + "WantType, WantlistEntry, Wantlist, BlockPresence, BlockPresenceType, " + "BitswapMessage all importable from libp2p.bitswap" + ) + + +# ── main ────────────────────────────────────────────────────────────────────── + +if __name__ == "__main__": + print("=" * 60) + print("Wantlist / Message Dataclasses — Test Suite") + print("=" * 60) + + test_want_type_values() + test_wantlist_entry_from_cid() + test_wantlist_add_cancel_contains() + test_wantlist_full_flag() + test_block_presence() + test_bitswap_message_properties() + test_bitswap_message_cancel_want() + test_to_proto_from_proto_roundtrip() + test_create_wantlist_entry_backward_compat() + test_public_exports() + + print("\n" + "=" * 60) + print("All tests passed!") + print("=" * 60) diff --git 
a/tests/core/kad_dht/test_kad_dht_quorum_sliding_window.py b/tests/core/kad_dht/test_kad_dht_quorum_sliding_window.py index 87e669cc0..b1be7cdb8 100644 --- a/tests/core/kad_dht/test_kad_dht_quorum_sliding_window.py +++ b/tests/core/kad_dht/test_kad_dht_quorum_sliding_window.py @@ -41,6 +41,7 @@ def _make_dht() -> KadDHT: host = MagicMock() key_pair = create_new_key_pair() host.get_id.return_value = ID.from_pubkey(key_pair.public_key) + host.get_private_key.return_value = key_pair.private_key host.get_addrs.return_value = [Multiaddr("/ip4/127.0.0.1/tcp/8000")] host.get_peerstore.return_value = MagicMock() host.new_stream = AsyncMock() diff --git a/tests/core/kad_dht/test_unit_value_store.py b/tests/core/kad_dht/test_unit_value_store.py index bdaaacd9c..6a5d7d4a7 100644 --- a/tests/core/kad_dht/test_unit_value_store.py +++ b/tests/core/kad_dht/test_unit_value_store.py @@ -15,6 +15,7 @@ import pytest +from libp2p.crypto.secp256k1 import create_new_key_pair from libp2p.kad_dht.value_store import ( DEFAULT_TTL, ValueStore, @@ -24,8 +25,11 @@ ) from libp2p.records.record import make_put_record +# Create a real key pair for signing +key_pair = create_new_key_pair() mock_host = Mock() -peer_id = ID.from_base58("QmTest123") +mock_host.get_private_key.return_value = key_pair.private_key +peer_id = ID.from_pubkey(key_pair.public_key) class TestValueStore: @@ -445,6 +449,178 @@ async def test_store_at_peer_local_peer(self): assert result is True + @pytest.mark.trio + async def test_store_at_peer_propagates_signature_and_author(self): + """ + _store_at_peer must include signature and author from the locally-stored + signed record in the outbound PUT_VALUE message. + + This ensures signed-record authenticity is preserved when replicating + values to remote peers, matching go-libp2p interoperability requirements. 
+ """ + import varint + + from libp2p.kad_dht.pb.kademlia_pb2 import Message + + # Build a host with a real key pair so put() creates a genuine signed record + kp = create_new_key_pair() + remote_peer_id = ID.from_base58("QmRemote123456789") + local_peer_id = ID.from_pubkey(kp.public_key) + + # Capture the bytes written to the mock stream + written: list[bytes] = [] + + mock_stream = Mock() + + async def _write(data: bytes) -> None: + written.append(data) + + async def _read(n: int) -> bytes: + # Simulate a minimal valid PUT_VALUE acknowledgement + resp = Message() + resp.type = Message.MessageType.PUT_VALUE + resp.key = b"test_key" + raw = resp.SerializeToString() + length = varint.encode(len(raw)) + # Return one byte at a time for the varint reader, then the body + full = length + raw + if not hasattr(_read, "_buf"): + _read._buf = iter(full) # type: ignore[attr-defined] + byte_val = next(_read._buf, b"") # type: ignore[attr-defined] + return bytes([byte_val]) if isinstance(byte_val, int) else byte_val + + mock_stream.write = Mock(side_effect=_write) + mock_stream.read = Mock(side_effect=_read) + mock_stream.close = Mock(return_value=None) + + # Patch close to be awaitable + async def _close() -> None: + pass + + mock_stream.close = _close + + h = Mock() + h.get_private_key.return_value = kp.private_key + h.get_peerstore.return_value = Mock() + + # env_to_send_in_RPC is called; return empty bytes to keep test simple + from libp2p.peer.peerstore import env_to_send_in_RPC + + original_env = env_to_send_in_RPC + + import libp2p.kad_dht.value_store as vs_module + + vs_module.env_to_send_in_RPC = Mock(return_value=(b"", None)) # type: ignore[attr-defined] + + async def _new_stream(*_args: object, **_kwargs: object) -> object: + return mock_stream + + h.new_stream = _new_stream + + try: + store = ValueStore(host=h, local_peer_id=local_peer_id) + key = b"test_key" + value = b"test_value" + + # Store locally first (creates signed record) + store.put(key, value) + + # 
Confirm the local record has signature and author set + local_record, _ = store.store[key] + assert local_record.signature, "put() must produce a non-empty signature" + assert local_record.author, "put() must populate the author field" + + # Now replicate to a remote peer + await store._store_at_peer(remote_peer_id, key, value) + + # Reconstruct the serialized message from what was written + # written[0] is the varint length prefix, written[1] is the proto body + assert len(written) >= 2, "Expected varint + proto body to be written" + sent_msg = Message() + sent_msg.ParseFromString(written[1]) + + assert sent_msg.HasField("record"), "Outbound message must contain a record" + assert sent_msg.record.signature == local_record.signature, ( + "Outbound record must carry the signature from the signed record" + ) + assert sent_msg.record.author == local_record.author, ( + "Outbound record must carry the author from the signed record" + ) + finally: + vs_module.env_to_send_in_RPC = original_env # type: ignore[attr-defined] + + @pytest.mark.trio + async def test_store_at_peer_signs_record_without_prior_put(self): + """ + When _store_at_peer is called without a prior put() (e.g. the get_value + propagation path), it must still produce a signed outbound record — + never a bare unsigned one. 
+ """ + import varint + + from libp2p.kad_dht.pb.kademlia_pb2 import Message + + kp = create_new_key_pair() + remote_peer_id = ID.from_base58("QmRemote999") + local_peer_id = ID.from_pubkey(kp.public_key) + + written: list[bytes] = [] + + async def _write(data: bytes) -> None: + written.append(data) + + mock_stream = Mock() + resp = Message() + resp.type = Message.MessageType.PUT_VALUE + resp.key = b"bare_key" + raw = resp.SerializeToString() + resp_bytes = varint.encode(len(raw)) + raw + resp_iter = iter(resp_bytes) + + async def _read(n: int) -> bytes: + byte_val = next(resp_iter, b"") + return bytes([byte_val]) if isinstance(byte_val, int) else byte_val + + mock_stream.write = Mock(side_effect=_write) + mock_stream.read = Mock(side_effect=_read) + + async def _close() -> None: + pass + + mock_stream.close = _close + + h = Mock() + h.get_private_key.return_value = kp.private_key + + import libp2p.kad_dht.value_store as vs_module + + original_env = vs_module.env_to_send_in_RPC + vs_module.env_to_send_in_RPC = Mock(return_value=(b"", None)) # type: ignore[attr-defined] + + async def _new_stream(*_args: object, **_kwargs: object) -> object: + return mock_stream + + h.new_stream = _new_stream + + try: + store = ValueStore(host=h, local_peer_id=local_peer_id) + key = b"bare_key" + value = b"bare_value" + + # Do NOT call store.put() — _store_at_peer must sign the record itself + await store._store_at_peer(remote_peer_id, key, value) + + assert len(written) >= 2 + sent_msg = Message() + sent_msg.ParseFromString(written[1]) + assert sent_msg.record.key == key + assert sent_msg.record.value == value + # The record must be signed even without a prior put() + assert sent_msg.record.signature, "record must be signed inline" + assert sent_msg.record.author, "record must carry author field" + finally: + vs_module.env_to_send_in_RPC = original_env # type: ignore[attr-defined] + @pytest.mark.trio async def test_get_from_peer_local_peer(self): """Test _get_from_peer returns None 
when querying local peer.""" diff --git a/tests/core/records/test_validator.py b/tests/core/records/test_validator.py index 4a0efc0f7..9faf3bb6c 100644 --- a/tests/core/records/test_validator.py +++ b/tests/core/records/test_validator.py @@ -5,7 +5,12 @@ from libp2p.peer.id import ID from libp2p.records.pubkey import PublicKeyValidator, unmarshal_public_key from libp2p.records.record import make_put_record -from libp2p.records.utils import InvalidRecordType, split_key +from libp2p.records.utils import ( + InvalidRecordType, + sign_record, + split_key, + verify_record, +) from libp2p.records.validator import NamespacedValidator, Validator bad_paths = [ @@ -243,3 +248,85 @@ def select(self, key: str, values: list[bytes]) -> int: # Non-namespaced key uses custom fallback that rejects with pytest.raises(ValueError, match="Rejected by fallback"): validators.validate("plain-key", b"value") + + +# ───────────────────────────────────────────────────────────────────────────── +# verify_record — multi-key-type coverage +# ───────────────────────────────────────────────────────────────────────────── + + +class TestVerifyRecord: + """ + verify_record must accept signatures from every key type that libp2p + serialises via crypto_pb2.PublicKey (Ed25519, Secp256k1, RSA). + + Previously the implementation hard-coded Ed25519PublicKey.from_bytes, + causing it to silently return False for RSA and Secp256k1 peers and + breaking DHT interoperability with non-Ed25519 nodes. 
+ """ + + def _round_trip(self, key_pair) -> None: # noqa: ANN001 + """Sign with *key_pair* and assert verify_record returns True.""" + key = b"/test/mykey" + value = b"hello world" + sig, author = sign_record(key_pair.private_key, key, value) + assert verify_record(sig, author, key, value), ( + f"verify_record returned False for key type " + f"{key_pair.private_key.get_type()}" + ) + + def _tampered_fails(self, key_pair) -> None: # noqa: ANN001 + """Tampered payload must make verify_record return False.""" + key = b"/test/mykey" + value = b"hello world" + sig, author = sign_record(key_pair.private_key, key, value) + assert not verify_record(sig, author, key, b"tampered"), ( + f"verify_record accepted tampered value for key type " + f"{key_pair.private_key.get_type()}" + ) + + def test_ed25519_valid_signature(self) -> None: + from libp2p.crypto.ed25519 import create_new_key_pair as ed_kp + + self._round_trip(ed_kp()) + + def test_ed25519_tampered_value_rejected(self) -> None: + from libp2p.crypto.ed25519 import create_new_key_pair as ed_kp + + self._tampered_fails(ed_kp()) + + def test_secp256k1_valid_signature(self) -> None: + from libp2p.crypto.secp256k1 import create_new_key_pair as secp_kp + + self._round_trip(secp_kp()) + + def test_secp256k1_tampered_value_rejected(self) -> None: + from libp2p.crypto.secp256k1 import create_new_key_pair as secp_kp + + self._tampered_fails(secp_kp()) + + def test_rsa_valid_signature(self) -> None: + from libp2p.crypto.rsa import create_new_key_pair as rsa_kp + + self._round_trip(rsa_kp()) + + def test_rsa_tampered_value_rejected(self) -> None: + from libp2p.crypto.rsa import create_new_key_pair as rsa_kp + + self._tampered_fails(rsa_kp()) + + def test_garbage_author_bytes_returns_false(self) -> None: + """Completely invalid author bytes must return False, not raise.""" + assert not verify_record(b"sig", b"not-a-valid-protobuf", b"key", b"value") + + def test_wrong_key_returns_false(self) -> None: + """Signature verified 
against a different key must return False.""" + from libp2p.crypto.ed25519 import create_new_key_pair as ed_kp + + kp1 = ed_kp() + kp2 = ed_kp() + key = b"/test/k" + value = b"v" + sig, _ = sign_record(kp1.private_key, key, value) + _, author2 = sign_record(kp2.private_key, key, value) + assert not verify_record(sig, author2, key, value)