diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
new file mode 100644
index 0000000..cc2c23a
--- /dev/null
+++ b/.github/CODEOWNERS
@@ -0,0 +1 @@
+* @kolkov
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
new file mode 100644
index 0000000..d40a34d
--- /dev/null
+++ b/.github/workflows/ci.yml
@@ -0,0 +1,207 @@
+name: CI
+
+# CI Strategy:
+# - Tests run on Linux, macOS, and Windows (cross-platform IPC library)
+# - Go 1.25+ required (matches go.mod requirement)
+# - CGO_ENABLED=0: Pure Go library, no C compiler required
+#
+# Branch Strategy (GitHub Flow):
+# - main branch: Production-ready code
+# - Feature branches: Tested via pull_request trigger
+# - Pull requests: Must pass all checks before merge
+
+env:
+ CGO_ENABLED: "0"
+
+on:
+ push:
+ branches:
+ - main
+ pull_request:
+ branches:
+ - main
+
+permissions:
+ contents: read
+ id-token: write # Required for Codecov OIDC token
+
+jobs:
+ # Build verification - Cross-platform
+ build:
+ name: Build - ${{ matrix.os }}
+ runs-on: ${{ matrix.os }}
+ strategy:
+ matrix:
+ os: [ubuntu-latest, macos-latest, windows-latest]
+ go-version: ['1.25']
+
+ steps:
+ - name: Checkout code
+ uses: actions/checkout@v4
+
+ - name: Set up Go
+ uses: actions/setup-go@v5
+ with:
+ go-version: ${{ matrix.go-version }}
+ cache: true
+
+ - name: Download dependencies
+ run: go mod download
+
+ - name: Verify dependencies
+ run: go mod verify
+
+ - name: Build all packages
+ run: go build ./...
+
+ # Unit tests - Cross-platform
+ test:
+ name: Test - ${{ matrix.os }}
+ runs-on: ${{ matrix.os }}
+ strategy:
+ matrix:
+ os: [ubuntu-latest, macos-latest, windows-latest]
+ go-version: ['1.25']
+
+ steps:
+ - name: Checkout code
+ uses: actions/checkout@v4
+
+ - name: Set up Go
+ uses: actions/setup-go@v5
+ with:
+ go-version: ${{ matrix.go-version }}
+ cache: true
+
+ - name: Download dependencies
+ run: go mod download
+
+ - name: Run go vet
+ if: matrix.os == 'ubuntu-latest'
+ run: go vet ./...
+
+ - name: Run tests
+ run: go test -v ./...
+
+ - name: Run tests with coverage
+ if: matrix.os == 'ubuntu-latest'
+ run: go test -coverprofile=coverage.out -covermode=atomic ./...
+
+ - name: Upload coverage to Codecov
+ if: matrix.os == 'ubuntu-latest'
+ uses: codecov/codecov-action@v5
+ with:
+ files: coverage.out
+ use_oidc: true
+
+ # Linting
+ lint:
+ name: Lint
+ runs-on: ubuntu-latest
+
+ steps:
+ - name: Checkout code
+ uses: actions/checkout@v4
+
+ - name: Set up Go
+ uses: actions/setup-go@v5
+ with:
+ go-version: '1.25'
+ cache: true
+
+ - name: Run golangci-lint
+ uses: golangci/golangci-lint-action@v8
+ with:
+ version: latest
+ args: --timeout=5m
+
+ # Code formatting check
+ formatting:
+ name: Formatting
+ runs-on: ubuntu-latest
+
+ steps:
+ - name: Checkout code
+ uses: actions/checkout@v4
+
+ - name: Set up Go
+ uses: actions/setup-go@v5
+ with:
+ go-version: '1.25'
+ cache: true
+
+ - name: Check formatting
+ run: |
+ if [ -n "$(gofmt -l .)" ]; then
+ echo "ERROR: The following files are not formatted:"
+ gofmt -l .
+ echo ""
+ echo "Run 'go fmt ./...' to fix formatting issues."
+ exit 1
+ fi
+ echo "All files are properly formatted"
+
+ # Dependency freshness check
+ # Uses go-mod-outdated (https://github.com/psampaz/go-mod-outdated)
+ # Non-blocking: reports outdated deps as warnings, does not fail CI
+ deps:
+ name: Dependencies
+ runs-on: ubuntu-latest
+
+ steps:
+ - name: Checkout code
+ uses: actions/checkout@v4
+
+ - name: Set up Go
+ uses: actions/setup-go@v5
+ with:
+ go-version: '1.25'
+ cache: true
+
+ - name: Install go-mod-outdated
+ run: go install github.com/psampaz/go-mod-outdated@latest
+
+ - name: Check all dependencies
+ run: |
+ echo "## All Dependencies Status"
+ go list -u -m -json all 2>/dev/null | go-mod-outdated -style markdown || true
+
+ - name: Check direct dependencies for updates
+ run: |
+ echo "## Direct Dependencies with Available Updates"
+ OUTDATED=$(go list -u -m -json all 2>/dev/null | go-mod-outdated -update -direct || true)
+ if [ -n "$OUTDATED" ]; then
+ echo "$OUTDATED"
+ echo ""
+ echo "::warning::Some direct dependencies have updates available"
+ else
+ echo "All direct dependencies are up to date!"
+ fi
+
+ - name: Check ecosystem dependencies
+ env:
+ GH_TOKEN: ${{ github.token }}
+ run: |
+ echo "## Ecosystem Dependencies"
+ WARNINGS=0
+
+ check_ecosystem_dep() {
+ local DEP=$1 REPO=$2
+ LOCAL=$(grep "$DEP" go.mod 2>/dev/null | grep -v "^module" | awk '{print $2}')
+ [ -z "$LOCAL" ] && return 0
+
+ LATEST=$(gh release view --repo "$REPO" --json tagName -q '.tagName' 2>/dev/null || echo "")
+ [ -z "$LATEST" ] && { echo " $DEP: $LOCAL (cannot verify)"; return 0; }
+
+ if [ "$LOCAL" = "$LATEST" ]; then
+ echo " $DEP: $LOCAL"
+ else
+ echo " $DEP: $LOCAL -> $LATEST available"
+ WARNINGS=$((WARNINGS + 1))
+ fi
+ }
+
+ check_ecosystem_dep "github.com/pierrec/lz4/v4" "pierrec/lz4"
+
+ [ $WARNINGS -gt 0 ] && echo "::warning::$WARNINGS ecosystem dep(s) outdated. Run: go get @latest"
+ exit 0 # Non-blocking
diff --git a/.golangci.yml b/.golangci.yml
new file mode 100644
index 0000000..11b9746
--- /dev/null
+++ b/.golangci.yml
@@ -0,0 +1,165 @@
+# GolangCI-Lint v2 Configuration for gogpu/compose
+# Documentation: https://golangci-lint.run/docs/configuration/
+
+version: "2"
+
+run:
+ timeout: 5m
+ tests: true
+
+linters:
+ enable:
+ # Code quality and complexity
+ - gocyclo
+ - gocognit
+ - funlen
+ - maintidx
+ - cyclop
+ - nestif
+
+ # Bug detection
+ - govet
+ - staticcheck
+ - errcheck
+ - errorlint
+ - gosec
+ - nilnil
+ - nilerr
+ - ineffassign
+
+ # Code style and consistency
+ - misspell
+ - whitespace
+ - unconvert
+ - unparam
+
+ # Naming conventions
+ - errname
+ - revive
+
+ # Performance
+ - prealloc
+ - makezero
+
+ # Code practices
+ - goconst
+ - gocritic
+ - goprintffuncname
+ - nolintlint
+ - nakedret
+
+ # Additional quality checkers
+ - dupl
+ - dogsled
+ - durationcheck
+
+ settings:
+ govet:
+ enable:
+ - copylocks
+ disable:
+ - fieldalignment
+
+ gocyclo:
+ min-complexity: 20
+
+ cyclop:
+ max-complexity: 20
+
+ funlen:
+ lines: 120
+ statements: 60
+
+ gocognit:
+ min-complexity: 30
+
+ misspell:
+ locale: US
+
+ nestif:
+ min-complexity: 4
+
+ revive:
+ rules:
+ - name: var-naming
+ - name: error-return
+ - name: error-naming
+ - name: if-return
+ - name: increment-decrement
+ - name: var-declaration
+ - name: range
+ - name: receiver-naming
+ - name: time-naming
+ - name: unexported-return
+ - name: indent-error-flow
+ - name: errorf
+ - name: empty-block
+ - name: superfluous-else
+ - name: unreachable-code
+ - name: redefines-builtin-id
+
+ gocritic:
+ enabled-tags:
+ - diagnostic
+ - style
+ - performance
+
+ disabled-checks:
+ - commentFormatting
+ - whyNoLint
+ - unnamedResult
+ - commentedOutCode
+ - octalLiteral
+ - paramTypeCombine
+
+ settings:
+ hugeParam:
+ sizeThreshold: 256
+
+ exclusions:
+ rules:
+ # Enum String() methods: the string literal IS the constant name, not a
+ # candidate for extraction. Standard Go stringer pattern (false positive).
+ - linters: [goconst]
+ source: 'return "'
+
+ # Test files - allow more flexibility
+ - path: _test\.go
+ linters:
+ - gocyclo
+ - cyclop
+ - funlen
+ - maintidx
+ - errcheck
+ - gosec
+ - goconst
+ - dogsled
+ - dupl
+ - gocognit
+
+ # Example code
+ - path: examples?/.*\.go
+ linters:
+ - errcheck
+ - errorlint
+ - funlen
+ - gocyclo
+ - cyclop
+ - gocognit
+ - revive
+ - gosec
+ - gocritic
+
+formatters:
+ enable:
+ - gofmt
+ - goimports
+
+issues:
+ max-issues-per-linter: 0
+ max-same-issues: 0
+ new: false
+
+output:
+ sort-order:
+ - file
diff --git a/CHANGELOG.md b/CHANGELOG.md
new file mode 100644
index 0000000..680d009
--- /dev/null
+++ b/CHANGELOG.md
@@ -0,0 +1,18 @@
+# Changelog
+
+All notable changes to this project will be documented in this file.
+
+The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
+and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
+
+## [0.1.0] — 2026-05-17
+
+### Added
+
+- **Wire protocol v1** (`internal/protocol/`) — 64-byte fixed header, 128-byte handshake messages, message types, encode/decode with zero allocations (100% coverage)
+- **Codec package** (`internal/codec/`) — Raw pass-through + LZ4 block compression via `pierrec/lz4/v4` (97% coverage, 2.9 GB/s encode, 99.6% compression on GUI pixels)
+- **Connection manager** (`internal/conn/`) — module registry with monotonic ID allocation, lifecycle state machine, hot-plug callbacks (98.9% coverage)
+- **Flow controller** (`internal/flow/`) — pull-based frame pacing (Wayland frame callback pattern), adaptive rate reduction after missed frames (100% coverage)
+- **Unix socket transport** (`internal/transport/socket/`) — framed Conn, Listener, Dialer for Unix domain sockets (95.1% coverage, 4.3 GB/s, 45μs latency)
+- **Public API** — `compose.Listen()`, `compose.Dial()`, `Frame` type, functional options (`WithMaxModules`, `WithCompression`, `WithName`, `WithFrameSize`, `WithFPS`)
+- **CI/CD** — GitHub Actions (build/test/lint on Ubuntu/macOS/Windows), Codecov, golangci-lint v2
diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md
new file mode 100644
index 0000000..037490d
--- /dev/null
+++ b/CODE_OF_CONDUCT.md
@@ -0,0 +1,34 @@
+# Contributor Covenant Code of Conduct
+
+## Our Pledge
+
+We as members, contributors, and leaders pledge to make participation in our
+community a welcoming and respectful experience for everyone.
+
+## Our Standards
+
+Examples of behavior that contributes to a positive environment:
+
+- Using welcoming and inclusive language
+- Being respectful of differing viewpoints and experiences
+- Gracefully accepting constructive criticism
+- Focusing on what is best for the community
+- Showing empathy towards other community members
+
+Examples of unacceptable behavior:
+
+- The use of sexualized language or imagery and unwelcome sexual attention
+- Trolling, insulting/derogatory comments, and personal or political attacks
+- Public or private harassment
+- Publishing others' private information without explicit permission
+- Other conduct which could reasonably be considered inappropriate
+
+## Enforcement
+
+Instances of abusive, harassing, or otherwise unacceptable behavior may be
+reported to the project maintainers at a.kolkov@gmail.com.
+
+## Attribution
+
+This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org/),
+version 2.1.
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 0000000..c776959
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,81 @@
+# Contributing to compose
+
+Thank you for your interest in contributing to gogpu/compose! This document covers how to build, test, and submit changes.
+
+## Prerequisites
+
+- **Go 1.25+** ([download](https://go.dev/dl/))
+- **golangci-lint** (`go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@latest`)
+
+## Building
+
+```bash
+go build ./...
+```
+
+## Running Tests
+
+```bash
+go test ./...
+```
+
+## Running Tests with Coverage
+
+```bash
+go test -coverprofile=tmp/coverage.out ./...
+go tool cover -html=tmp/coverage.out
+```
+
+## Running Linter
+
+```bash
+golangci-lint run --timeout=5m
+```
+
+## Code Standards
+
+- **Pure Go** — zero CGO, zero platform-specific code in public API
+- **Enterprise quality** — 90%+ test coverage, zero-alloc hot paths
+- **Functional options** — use `With*` pattern for configuration
+- **Internal packages** — implementation details live in `internal/`
+
+## Pull Request Process
+
+1. Fork the repository
+2. Create a feature branch (`git checkout -b feat/my-feature`)
+3. Make changes with tests
+4. Run `go fmt ./... && golangci-lint run --timeout=5m && go test ./...`
+5. Commit with conventional format (`feat:`, `fix:`, `docs:`, etc.)
+6. Open a pull request against `main`
+
+## Commit Messages
+
+```
+feat: add shared memory transport
+fix: handle connection timeout on Windows
+docs: update wire protocol documentation
+test: add benchmark for frame compression
+chore: update lz4 dependency
+```
+
+## Architecture
+
+The library follows a strict internal/ pattern:
+
+```
+compose/ # Public API (< 15 exported symbols)
+├── internal/
+│ ├── protocol/ # Wire format (64B header, handshake)
+│ ├── codec/ # Compression (Raw, LZ4)
+│ ├── conn/ # Module lifecycle
+│ ├── flow/ # Pull-based pacing
+│ └── transport/
+│ ├── socket/ # Unix domain socket (Phase 1)
+│ └── shm/ # Shared memory (Phase 2)
+```
+
+Users import only `"github.com/gogpu/compose"`. All implementation is hidden.
+
+## License
+
+By contributing, you agree that your contributions will be licensed under the MIT License.
diff --git a/README.md b/README.md
index 2071e2d..245ab5c 100644
--- a/README.md
+++ b/README.md
@@ -2,7 +2,7 @@
-
+
@@ -16,15 +16,22 @@
+
+
-
---
+## Installation
+
+```bash
+go get github.com/gogpu/compose
+```
+
## What is compose?
The `compose` library lets Go applications combine content from several independent processes onto a single display.
@@ -75,7 +82,7 @@ srv.OnFrame(func(f compose.Frame) {
})
```
-> **Status: design phase.** APIs above are aspirational — they show the intended shape of the library. See the [Roadmap](#roadmap) section for current progress.
+> Unix socket transport, wire protocol v1, LZ4 compression, pull-based flow control. See the [Roadmap](#roadmap) for current progress.
## Why a separate library?
@@ -184,14 +191,12 @@ The protocol is the stable compatibility surface of compose. Versioning is indep
| Phase | Features | Status |
|-------|----------|--------|
-| **Phase 0** | Design, ADRs, reference example sketch | Design phase |
-| **Phase 1** | Reference example: compositor + clock module + notification module | Planned |
-| **Phase 2** | Wire protocol v1, framing, header, dirty rects | Planned |
-| **Phase 3** | Unix socket transport, hot-plug, connection management | Planned |
-| **Phase 4** | Shared memory ring buffer transport | Planned |
-| **Phase 5** | Extract stable APIs from the reference example into this library | Planned |
-| **Phase 6** | Multi-screen layout, layered z-order, fade transitions | Future |
-| **Phase 7** | Cross-language module SDK (C header, Rust crate, Python) | Future |
+| **Phase 1** | Wire protocol v1, Unix socket transport, LZ4 compression, public API | Complete |
+| **Phase 2** | Reference examples: compositor + clock + notification (multi-process) | Next |
+| **Phase 3** | Shared memory ring buffer transport (zero-copy) | Planned |
+| **Phase 4** | Delta frames, compression negotiation | Planned |
+| **Phase 5** | Multi-screen layout, layered z-order, fade transitions | Future |
+| **Phase 6** | Cross-language module SDK (C header, Rust crate, Python) | Future |
## Design principles
@@ -215,11 +220,11 @@ The design of `compose` is informed by:
## Status and contributing
-The `compose` library is in the **design phase**. There is no shippable code yet. This repository exists to host the design discussion, the architecture decision records, and (once they exist) the reference example and the extracted library.
+The `compose` library ships a working Unix socket transport with wire protocol v1, LZ4 compression, pull-based flow control, and hot-plug module lifecycle.
-The first user is a [Go rewrite of MagicMirror²](https://github.com/gogpu/ui/issues/75) targeting Raspberry Pi and (eventually) Redox OS.
+Known users: [KiGo](https://github.com/AgentNemo00/kigo) (modular Go application using offscreen rendering + multi-process composition).
-If you have a use case for multi-process composition in Go and want to influence the design before APIs freeze, please join the [compose RFC discussion](https://github.com/orgs/gogpu/discussions/177) or open an issue describing your scenario. For the related (but distinct) question of in-process multi-window support in `gogpu` itself, see the [multi-window RFC discussion](https://github.com/orgs/gogpu/discussions/167). Real use cases drive both — we deliberately avoid designing against hypotheticals.
+If you have a use case for multi-process composition in Go, please join the [compose RFC discussion](https://github.com/orgs/gogpu/discussions/177) or open an issue describing your scenario.
## Part of the GoGPU Ecosystem
diff --git a/SECURITY.md b/SECURITY.md
new file mode 100644
index 0000000..38b9beb
--- /dev/null
+++ b/SECURITY.md
@@ -0,0 +1,49 @@
+# Security Policy
+
+## Supported Versions
+
+| Version | Supported |
+| ------- | ------------------ |
+| 0.1.x | :white_check_mark: |
+
+## Reporting a Vulnerability
+
+**DO NOT** open a public GitHub issue for security vulnerabilities.
+
+Instead, please report security issues via:
+
+1. **Private Security Advisory** (preferred):
+ https://github.com/gogpu/compose/security/advisories/new
+
+2. **GitHub Discussions** (for less critical issues):
+ https://github.com/gogpu/gogpu/discussions
+
+### What to Include
+
+- Description of the vulnerability
+- Steps to reproduce
+- Affected versions
+- Potential impact
+
+### Response Timeline
+
+- **Initial Response**: Within 72 hours
+- **Fix & Disclosure**: Coordinated with reporter
+
+## Security Considerations
+
+compose uses IPC mechanisms for inter-process communication. Users should be aware of:
+
+1. **Unix Domain Sockets** — socket files are created with default permissions. Use appropriate file permissions in production.
+2. **Shared Memory** (Phase 2) — memory-mapped regions are shared between processes. Ensure only trusted modules connect.
+3. **Wire Protocol** — frame data is not encrypted. For untrusted networks, use TLS or a secure transport.
+4. **Module Identity** — module names are self-declared. The compositor should validate module identity for security-critical deployments.
+
+## Security Contact
+
+- **GitHub Security Advisory**: https://github.com/gogpu/compose/security/advisories/new
+- **Public Issues**: https://github.com/gogpu/compose/issues
+
+---
+
+**Thank you for helping keep gogpu/compose secure!**
diff --git a/client.go b/client.go
new file mode 100644
index 0000000..3879ee0
--- /dev/null
+++ b/client.go
@@ -0,0 +1,320 @@
+package compose
+
+import (
+ "fmt"
+ "image"
+ "math"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/gogpu/compose/internal/codec"
+ "github.com/gogpu/compose/internal/protocol"
+ "github.com/gogpu/compose/internal/transport/socket"
+)
+
+// saturateUint16 converts v to uint16, clamping to math.MaxUint16 on overflow.
+func saturateUint16(v int) uint16 {
+ return uint16(min(max(v, 0), math.MaxUint16)) //nolint:gosec // clamped to [0, MaxUint16]
+}
+
+// saturateUint16from32 converts v to uint16, clamping to math.MaxUint16 on overflow.
+func saturateUint16from32(v uint32) uint16 {
+ return uint16(min(v, math.MaxUint16)) //nolint:gosec // clamped to [0, MaxUint16]
+}
+
+// saturateUint32 converts v to uint32, clamping to math.MaxUint32 on overflow.
+func saturateUint32(v int) uint32 {
+ return uint32(min(max(v, 0), math.MaxUint32)) //nolint:gosec // clamped to [0, MaxUint32]
+}
+
+// Client is the module-side endpoint that connects to a compositor and
+// publishes frames. All methods are safe for concurrent use.
+type Client struct {
+ conn *socket.Conn
+ moduleID uint64
+ name string
+ codec codec.Codec
+
+ mu sync.RWMutex
+ onFrameRequest func()
+
+ closed atomic.Bool
+ done chan struct{}
+ wg sync.WaitGroup
+ seq atomic.Uint64
+}
+
+// Dial creates a Client that connects to a compositor at the given Unix
+// domain socket address. Dial performs the handshake immediately: it sends
+// a HelloMsg and reads the WelcomeMsg. If the compositor rejects the
+// connection, ErrNotAccepted is returned.
+//
+// Use ClientOption functions to configure the module:
+//
+// client, err := compose.Dial("/tmp/compose.sock",
+// compose.WithName("clock"),
+// compose.WithFrameSize(400, 120),
+// compose.WithFPS(1),
+// )
+func Dial(addr string, opts ...ClientOption) (*Client, error) {
+ cfg := defaultClientConfig()
+ for _, o := range opts {
+ o(&cfg)
+ }
+
+ dialer := socket.NewDialer(addr)
+ sc, err := dialer.Dial()
+ if err != nil {
+ return nil, fmt.Errorf("compose: dial: %w", err)
+ }
+
+ // Build and send HelloMsg.
+ hello := &protocol.HelloMsg{
+ Magic: protocol.Magic,
+ Version: protocol.ProtocolVersion,
+ Width: saturateUint16from32(cfg.width),
+ Height: saturateUint16from32(cfg.height),
+ PreferredFPS: cfg.fps,
+ Transport: protocol.TransportSocket,
+ }
+ protocol.SetName(hello, cfg.name)
+
+ if err := sc.WriteHandshakeHello(hello); err != nil {
+ _ = sc.Close()
+ return nil, fmt.Errorf("compose: send hello: %w", err)
+ }
+
+ // Read WelcomeMsg.
+ welcome, err := sc.ReadHandshakeWelcome()
+ if err != nil {
+ _ = sc.Close()
+ return nil, fmt.Errorf("compose: read welcome: %w", err)
+ }
+
+ if welcome.Accepted == 0 {
+ _ = sc.Close()
+ return nil, ErrNotAccepted
+ }
+
+ c := &Client{
+ conn: sc,
+ moduleID: welcome.ModuleID,
+ name: cfg.name,
+ codec: codec.Raw(), // client always sends raw; server handles decompression
+ done: make(chan struct{}),
+ }
+
+ // Start reader goroutine for FrameRequest messages.
+ c.wg.Add(1)
+ go c.readLoop()
+
+ return c, nil
+}
+
+// PublishFrame sends a frame to the compositor.
+// The frame's ModuleID is automatically set to this client's assigned ID.
+//
+// Returns ErrClosed if the client has been shut down.
+func (c *Client) PublishFrame(f Frame) error {
+ if c.closed.Load() {
+ return ErrClosed
+ }
+
+ seq := c.seq.Add(1)
+ pixels := f.Pixels
+ uncompressedSize := saturateUint32(len(pixels))
+
+ // Compress if codec is not raw.
+ var flags protocol.Flag
+ compressionID := protocol.CompressionNone
+
+ if c.codec.ID() != codec.IDRaw {
+ maxSize := c.codec.MaxEncodedSize(len(pixels))
+ dst := make([]byte, maxSize)
+ compressed, err := c.codec.Encode(dst, pixels)
+ if err != nil {
+ return fmt.Errorf("compose: compress frame: %w", err)
+ }
+ pixels = compressed
+ flags = flags.Set(protocol.FlagCompressed)
+ compressionID = protocol.Compression(c.codec.ID())
+ }
+
+ // Set dirty rect flags.
+ if !f.DirtyRect.Empty() {
+ flags = flags.Set(protocol.FlagDirtyValid)
+ } else {
+ flags = flags.Set(protocol.FlagKeyframe)
+ }
+
+ hdr := &protocol.Header{
+ Magic: protocol.Magic,
+ Version: protocol.ProtocolVersion,
+ MsgType: protocol.MsgFrame,
+ Flags: flags,
+ ModuleID: c.moduleID,
+ Sequence: seq,
+ TimestampNs: f.Timestamp,
+ Width: saturateUint16from32(f.Width),
+ Height: saturateUint16from32(f.Height),
+ Stride: f.Width * 4,
+ PixelFormat: protocol.PixelRGBA8,
+ Compression: compressionID,
+ PayloadSize: saturateUint32(len(pixels)),
+ UncompressedSize: uncompressedSize,
+ }
+
+ // Set dirty rect fields if valid.
+ if flags.Has(protocol.FlagDirtyValid) {
+ hdr.DirtyX = saturateUint16(f.DirtyRect.Min.X)
+ hdr.DirtyY = saturateUint16(f.DirtyRect.Min.Y)
+ hdr.DirtyW = saturateUint16(f.DirtyRect.Dx())
+ hdr.DirtyH = saturateUint16(f.DirtyRect.Dy())
+ }
+
+ // Use monotonic timestamp if caller did not set one.
+ if hdr.TimestampNs == 0 {
+ hdr.TimestampNs = time.Now().UnixNano()
+ }
+
+ return c.conn.WriteFrame(hdr, pixels)
+}
+
+// OnFrameRequest registers a callback invoked when the compositor requests
+// a frame. This enables pull-based rendering: the module renders only when
+// asked. The callback is called on an internal goroutine; it must not block
+// for extended periods. Only one callback can be active; subsequent calls
+// replace the previous one. Pass nil to remove.
+func (c *Client) OnFrameRequest(fn func()) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ c.onFrameRequest = fn
+}
+
+// Close disconnects from the compositor and releases resources.
+// After Close returns, no more callbacks will be invoked.
+func (c *Client) Close() error {
+ if !c.closed.CompareAndSwap(false, true) {
+ return ErrClosed
+ }
+
+ close(c.done)
+
+ // Send graceful disconnect message (best-effort).
+ hdr := &protocol.Header{
+ Magic: protocol.Magic,
+ Version: protocol.ProtocolVersion,
+ MsgType: protocol.MsgDisconnect,
+ ModuleID: c.moduleID,
+ }
+ _ = c.conn.WriteFrame(hdr, nil)
+
+ err := c.conn.Close()
+
+ c.wg.Wait()
+
+ if err != nil {
+ return fmt.Errorf("compose: close client: %w", err)
+ }
+ return nil
+}
+
+// ModuleID returns the compositor-assigned module identifier.
+// This is valid after Dial returns successfully.
+func (c *Client) ModuleID() uint64 {
+ return c.moduleID
+}
+
+// readLoop runs in a goroutine, reading control messages from the
+// compositor. Currently handles FrameRequest messages.
+func (c *Client) readLoop() {
+ defer c.wg.Done()
+
+ for {
+ select {
+ case <-c.done:
+ return
+ default:
+ }
+
+ hdr, _, err := c.conn.ReadFrame()
+ if err != nil {
+ // Connection closed or error — stop reading.
+ return
+ }
+
+ switch hdr.MsgType {
+ case protocol.MsgFrameRequest:
+ c.mu.RLock()
+ cb := c.onFrameRequest
+ c.mu.RUnlock()
+ if cb != nil {
+ cb()
+ }
+
+ case protocol.MsgDisconnect:
+ return
+
+ default:
+ // Ignore unknown message types for forward compatibility.
+ }
+ }
+}
+
+// SetCompression sets the codec used for frame payload compression.
+// This allows the compositor to negotiate compression during or after
+// handshake. Supported values: "lz4". Any other value uses raw.
+func (c *Client) SetCompression(algo string) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ c.codec = resolveCodec(algo)
+}
+
+// frameToHeader converts a public Frame to a protocol.Header.
+// Used in PublishFrame; extracted here for testing.
+func frameToHeader(f Frame, moduleID uint64, seq uint64, c codec.Codec) protocol.Header {
+ var flags protocol.Flag
+ if !f.DirtyRect.Empty() {
+ flags = flags.Set(protocol.FlagDirtyValid)
+ } else {
+ flags = flags.Set(protocol.FlagKeyframe)
+ }
+
+ if c.ID() != codec.IDRaw {
+ flags = flags.Set(protocol.FlagCompressed)
+ }
+
+ hdr := protocol.Header{
+ Magic: protocol.Magic,
+ Version: protocol.ProtocolVersion,
+ MsgType: protocol.MsgFrame,
+ Flags: flags,
+ ModuleID: moduleID,
+ Sequence: seq,
+ TimestampNs: f.Timestamp,
+ Width: saturateUint16from32(f.Width),
+ Height: saturateUint16from32(f.Height),
+ Stride: f.Width * 4,
+ PixelFormat: protocol.PixelRGBA8,
+ Compression: protocol.Compression(c.ID()),
+ PayloadSize: saturateUint32(len(f.Pixels)),
+ UncompressedSize: saturateUint32(len(f.Pixels)),
+ }
+
+ if flags.Has(protocol.FlagDirtyValid) {
+ dr := f.DirtyRect.Canon()
+ hdr.DirtyX = saturateUint16(dr.Min.X)
+ hdr.DirtyY = saturateUint16(dr.Min.Y)
+ hdr.DirtyW = saturateUint16(dr.Dx())
+ hdr.DirtyH = saturateUint16(dr.Dy())
+ }
+
+ return hdr
+}
+
+// isDirtyRectValid reports whether the dirty rect represents a valid
+// sub-region (non-empty).
+func isDirtyRectValid(r image.Rectangle) bool {
+ return !r.Empty()
+}
diff --git a/codecov.yml b/codecov.yml
new file mode 100644
index 0000000..8fe38b8
--- /dev/null
+++ b/codecov.yml
@@ -0,0 +1,29 @@
+# Codecov configuration
+# https://docs.codecov.com/docs/codecovyml-reference
+
+coverage:
+ precision: 2
+ round: down
+ range: "80...100"
+ status:
+ project:
+ default:
+ target: 90%
+ threshold: 5%
+ patch:
+ default:
+ target: 80%
+ threshold: 10%
+
+ignore:
+ - "examples/**/*"
+ - "tmp/**/*"
+
+parsers:
+ go:
+ partials_as_hits: false
+
+comment:
+ layout: "header, diff, flags, components"
+ behavior: default
+ require_changes: false
diff --git a/compose.go b/compose.go
new file mode 100644
index 0000000..ae8e88f
--- /dev/null
+++ b/compose.go
@@ -0,0 +1,35 @@
+package compose
+
+import "image"
+
+// Frame is the fundamental data unit: a rectangular pixel buffer from a module.
+// Users construct it to publish (module side) or receive it in callbacks
+// (compositor side).
+//
+// On the module side, ModuleID is ignored in PublishFrame — the server assigns
+// the module's ID automatically. On the compositor side, OnFrame delivers a
+// Frame with ModuleID populated.
+type Frame struct {
+ // ModuleID is the compositor-assigned identifier.
+ // Ignored on publish (server assigns).
+ ModuleID uint64
+
+ // Name is a human-readable module name (e.g., "clock", "weather").
+ // Set during handshake, included in received frames for convenience.
+ Name string
+
+ // Pixels is the RGBA premultiplied pixel buffer.
+ // Stride is always Width * 4.
+ Pixels []byte
+
+ // Width and Height of the frame in pixels.
+ Width uint32
+ Height uint32
+
+ // DirtyRect is the sub-region that changed since the last frame.
+ // Zero value means the entire frame is dirty (keyframe).
+ DirtyRect image.Rectangle
+
+ // Timestamp is a monotonic nanosecond timestamp from the module's clock.
+ Timestamp int64
+}
diff --git a/compose_test.go b/compose_test.go
new file mode 100644
index 0000000..052e59f
--- /dev/null
+++ b/compose_test.go
@@ -0,0 +1,811 @@
+package compose
+
+import (
+ "errors"
+ "image"
+ "os"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/gogpu/compose/internal/protocol"
+)
+
+// tempSocket returns a short temporary Unix socket path.
+// macOS limits Unix socket paths to 104 bytes. t.TempDir() on macOS CI
+// produces paths like /var/folders/.../TestName.../compose.sock which
+// easily exceeds this limit. We use os.CreateTemp with a short prefix
+// to guarantee a short path under /tmp (or %TEMP% on Windows).
+func tempSocket(t *testing.T) string {
+ t.Helper()
+ f, err := os.CreateTemp("", "cs-*.sock")
+ if err != nil {
+ t.Fatal(err)
+ }
+ path := f.Name()
+ f.Close()
+ os.Remove(path) // Remove the file; we need the path for the socket.
+ t.Cleanup(func() { os.Remove(path) })
+ return path
+}
+
+// makePixels creates a simple RGBA pixel buffer of the given dimensions
+// filled with the specified byte value.
+func makePixels(w, h uint32, fill byte) []byte {
+ pixels := make([]byte, w*h*4)
+ for i := range pixels {
+ pixels[i] = fill
+ }
+ return pixels
+}
+
+// waitFor polls a condition function until it returns true or 5 seconds elapse.
+// Returns true if the condition was met. The generous timeout accommodates slow
+// CI runners (macOS shared, Ubuntu containers) where goroutine scheduling may
+// introduce significant delays.
+func waitFor(t *testing.T, condition func() bool) bool {
+ t.Helper()
+ deadline := time.Now().Add(5 * time.Second)
+ for time.Now().Before(deadline) {
+ if condition() {
+ return true
+ }
+ time.Sleep(10 * time.Millisecond)
+ }
+ return false
+}
+
+func TestListenAndDial(t *testing.T) {
+ addr := tempSocket(t)
+
+ srv, err := Listen(addr)
+ if err != nil {
+ t.Fatalf("Listen: %v", err)
+ }
+ t.Cleanup(func() { _ = srv.Close() })
+
+ client, err := Dial(addr, WithName("test-module"))
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ t.Cleanup(func() { _ = client.Close() })
+
+ if client.ModuleID() == 0 {
+ t.Error("ModuleID should be non-zero after successful Dial")
+ }
+}
+
+func TestFrameRoundTrip(t *testing.T) {
+ addr := tempSocket(t)
+
+ srv, err := Listen(addr)
+ if err != nil {
+ t.Fatalf("Listen: %v", err)
+ }
+ t.Cleanup(func() { _ = srv.Close() })
+
+ var received atomic.Value
+
+ srv.OnFrame(func(f Frame) {
+ received.Store(f)
+ })
+
+ client, err := Dial(addr, WithName("painter"), WithFrameSize(10, 10))
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ t.Cleanup(func() { _ = client.Close() })
+
+ // Publish a frame.
+ pixels := makePixels(10, 10, 0xAA)
+ ts := time.Now().UnixNano()
+
+ err = client.PublishFrame(Frame{
+ Pixels: pixels,
+ Width: 10,
+ Height: 10,
+ Timestamp: ts,
+ })
+ if err != nil {
+ t.Fatalf("PublishFrame: %v", err)
+ }
+
+ // Wait for frame to arrive.
+ ok := waitFor(t, func() bool {
+ return received.Load() != nil
+ })
+ if !ok {
+ t.Fatal("timed out waiting for frame")
+ }
+
+ f := received.Load().(Frame)
+
+ if f.ModuleID == 0 {
+ t.Error("received frame ModuleID should be non-zero")
+ }
+ if f.Name != "painter" {
+ t.Errorf("received frame Name = %q, want %q", f.Name, "painter")
+ }
+ if f.Width != 10 {
+ t.Errorf("received frame Width = %d, want %d", f.Width, 10)
+ }
+ if f.Height != 10 {
+ t.Errorf("received frame Height = %d, want %d", f.Height, 10)
+ }
+ if len(f.Pixels) != 400 {
+ t.Errorf("received frame Pixels len = %d, want %d", len(f.Pixels), 400)
+ }
+ if f.Timestamp != ts {
+ t.Errorf("received frame Timestamp = %d, want %d", f.Timestamp, ts)
+ }
+
+ // Verify pixel content.
+ for i, b := range f.Pixels {
+ if b != 0xAA {
+ t.Errorf("pixel[%d] = 0x%02X, want 0xAA", i, b)
+ break
+ }
+ }
+}
+
+func TestFrameWithDirtyRect(t *testing.T) {
+ addr := tempSocket(t)
+
+ srv, err := Listen(addr)
+ if err != nil {
+ t.Fatalf("Listen: %v", err)
+ }
+ t.Cleanup(func() { _ = srv.Close() })
+
+ var received atomic.Value
+
+ srv.OnFrame(func(f Frame) {
+ received.Store(f)
+ })
+
+ client, err := Dial(addr, WithName("dirty"), WithFrameSize(100, 100))
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ t.Cleanup(func() { _ = client.Close() })
+
+ dirty := image.Rect(10, 20, 50, 80)
+
+ err = client.PublishFrame(Frame{
+ Pixels: makePixels(100, 100, 0xFF),
+ Width: 100,
+ Height: 100,
+ DirtyRect: dirty,
+ })
+ if err != nil {
+ t.Fatalf("PublishFrame: %v", err)
+ }
+
+ ok := waitFor(t, func() bool {
+ return received.Load() != nil
+ })
+ if !ok {
+ t.Fatal("timed out waiting for frame")
+ }
+
+ f := received.Load().(Frame)
+
+ if f.DirtyRect != dirty {
+ t.Errorf("DirtyRect = %v, want %v", f.DirtyRect, dirty)
+ }
+}
+
+func TestMultipleClients(t *testing.T) {
+ addr := tempSocket(t)
+
+ srv, err := Listen(addr, WithMaxModules(4))
+ if err != nil {
+ t.Fatalf("Listen: %v", err)
+ }
+ t.Cleanup(func() { _ = srv.Close() })
+
+ var mu sync.Mutex
+ framesByModule := make(map[uint64]Frame)
+
+ srv.OnFrame(func(f Frame) {
+ mu.Lock()
+ framesByModule[f.ModuleID] = f
+ mu.Unlock()
+ })
+
+ // Connect three clients.
+ clients := make([]*Client, 3)
+ names := []string{"alpha", "beta", "gamma"}
+ for i, name := range names {
+ c, dialErr := Dial(addr, WithName(name), WithFrameSize(8, 8))
+ if dialErr != nil {
+ t.Fatalf("Dial(%s): %v", name, dialErr)
+ }
+ clients[i] = c
+ t.Cleanup(func() { _ = c.Close() })
+ }
+
+ // Each client publishes a frame with a unique fill byte.
+ for i, c := range clients {
+ fill := byte(i + 1)
+ pubErr := c.PublishFrame(Frame{
+ Pixels: makePixels(8, 8, fill),
+ Width: 8,
+ Height: 8,
+ })
+ if pubErr != nil {
+ t.Fatalf("PublishFrame(%s): %v", names[i], pubErr)
+ }
+ }
+
+ // Wait for all three frames.
+ ok := waitFor(t, func() bool {
+ mu.Lock()
+ n := len(framesByModule)
+ mu.Unlock()
+ return n >= 3
+ })
+ if !ok {
+ mu.Lock()
+ n := len(framesByModule)
+ mu.Unlock()
+ t.Fatalf("timed out: received %d/3 frames", n)
+ }
+
+ // Verify each client got a unique module ID.
+ ids := make(map[uint64]bool)
+ mu.Lock()
+ for id := range framesByModule {
+ ids[id] = true
+ }
+ mu.Unlock()
+ if len(ids) != 3 {
+ t.Errorf("expected 3 unique module IDs, got %d", len(ids))
+ }
+}
+
+func TestOnConnectCallback(t *testing.T) {
+ addr := tempSocket(t)
+
+ srv, err := Listen(addr)
+ if err != nil {
+ t.Fatalf("Listen: %v", err)
+ }
+ t.Cleanup(func() { _ = srv.Close() })
+
+ var connectedName atomic.Value
+ var connectedID atomic.Uint64
+
+ srv.OnConnect(func(id uint64, name string) {
+ connectedID.Store(id)
+ connectedName.Store(name)
+ })
+
+ client, err := Dial(addr, WithName("connector"))
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ t.Cleanup(func() { _ = client.Close() })
+
+ ok := waitFor(t, func() bool {
+ return connectedName.Load() != nil
+ })
+ if !ok {
+ t.Fatal("timed out waiting for OnConnect")
+ }
+
+ if name := connectedName.Load().(string); name != "connector" {
+ t.Errorf("OnConnect name = %q, want %q", name, "connector")
+ }
+ if id := connectedID.Load(); id == 0 {
+ t.Error("OnConnect ID should be non-zero")
+ }
+}
+
+func TestOnDisconnectCallback(t *testing.T) {
+ addr := tempSocket(t)
+
+ srv, err := Listen(addr)
+ if err != nil {
+ t.Fatalf("Listen: %v", err)
+ }
+ t.Cleanup(func() { _ = srv.Close() })
+
+ var disconnectedName atomic.Value
+ var disconnectedID atomic.Uint64
+
+ srv.OnDisconnect(func(id uint64, name string) {
+ disconnectedID.Store(id)
+ disconnectedName.Store(name)
+ })
+
+ client, err := Dial(addr, WithName("leaver"))
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+
+ // Wait for the server to fully process the connection. On CI runners,
+ // goroutine scheduling can be slow, so 200ms gives ample headroom.
+ time.Sleep(200 * time.Millisecond)
+
+ // Close the client — this triggers disconnect.
+ if err := client.Close(); err != nil {
+ t.Fatalf("client.Close: %v", err)
+ }
+
+ ok := waitFor(t, func() bool {
+ return disconnectedName.Load() != nil
+ })
+ if !ok {
+ t.Fatal("timed out waiting for OnDisconnect")
+ }
+
+ if name := disconnectedName.Load().(string); name != "leaver" {
+ t.Errorf("OnDisconnect name = %q, want %q", name, "leaver")
+ }
+ if id := disconnectedID.Load(); id == 0 {
+ t.Error("OnDisconnect ID should be non-zero")
+ }
+}
+
+func TestServerClose(t *testing.T) {
+ addr := tempSocket(t)
+
+ srv, err := Listen(addr)
+ if err != nil {
+ t.Fatalf("Listen: %v", err)
+ }
+
+ client, err := Dial(addr, WithName("stranded"))
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ t.Cleanup(func() { _ = client.Close() })
+
+ // Close server — should disconnect all clients.
+ if err := srv.Close(); err != nil {
+ t.Fatalf("srv.Close: %v", err)
+ }
+
+ // Double close should return ErrClosed.
+ if err := srv.Close(); !errors.Is(err, ErrClosed) {
+ t.Errorf("second Close = %v, want ErrClosed", err)
+ }
+
+ // RequestFrame after close should return ErrClosed.
+ if err := srv.RequestFrame(1); !errors.Is(err, ErrClosed) {
+ t.Errorf("RequestFrame after close = %v, want ErrClosed", err)
+ }
+
+ // Socket file should be removed.
+ if _, err := os.Stat(addr); !os.IsNotExist(err) {
+ t.Errorf("socket file still exists after Close")
+ }
+}
+
+func TestClientDoubleClose(t *testing.T) {
+ addr := tempSocket(t)
+
+ srv, err := Listen(addr)
+ if err != nil {
+ t.Fatalf("Listen: %v", err)
+ }
+ t.Cleanup(func() { _ = srv.Close() })
+
+ client, err := Dial(addr, WithName("closer"))
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+
+ if err := client.Close(); err != nil {
+ t.Fatalf("first Close: %v", err)
+ }
+
+ if err := client.Close(); !errors.Is(err, ErrClosed) {
+ t.Errorf("second Close = %v, want ErrClosed", err)
+ }
+}
+
+func TestPublishAfterClose(t *testing.T) {
+ addr := tempSocket(t)
+
+ srv, err := Listen(addr)
+ if err != nil {
+ t.Fatalf("Listen: %v", err)
+ }
+ t.Cleanup(func() { _ = srv.Close() })
+
+ client, err := Dial(addr, WithName("early-close"))
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+
+ if err := client.Close(); err != nil {
+ t.Fatalf("Close: %v", err)
+ }
+
+ err = client.PublishFrame(Frame{
+ Pixels: makePixels(1, 1, 0),
+ Width: 1,
+ Height: 1,
+ })
+ if !errors.Is(err, ErrClosed) {
+ t.Errorf("PublishFrame after close = %v, want ErrClosed", err)
+ }
+}
+
+func TestRequestFrameTrigger(t *testing.T) {
+ addr := tempSocket(t)
+
+ srv, err := Listen(addr)
+ if err != nil {
+ t.Fatalf("Listen: %v", err)
+ }
+ t.Cleanup(func() { _ = srv.Close() })
+
+ var requestCount atomic.Int32
+
+ // Track connected module ID.
+ var moduleID atomic.Uint64
+ srv.OnConnect(func(id uint64, _ string) {
+ moduleID.Store(id)
+ })
+
+ client, err := Dial(addr, WithName("puller"))
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ t.Cleanup(func() { _ = client.Close() })
+
+ client.OnFrameRequest(func() {
+ requestCount.Add(1)
+ })
+
+ // Wait for connection to be fully established.
+ ok := waitFor(t, func() bool {
+ return moduleID.Load() != 0
+ })
+ if !ok {
+ t.Fatal("timed out waiting for connection")
+ }
+
+ id := moduleID.Load()
+
+ // Server requests a frame.
+ if err := srv.RequestFrame(id); err != nil {
+ t.Fatalf("RequestFrame: %v", err)
+ }
+
+ ok = waitFor(t, func() bool {
+ return requestCount.Load() >= 1
+ })
+ if !ok {
+ t.Fatal("timed out waiting for OnFrameRequest callback")
+ }
+
+ if n := requestCount.Load(); n != 1 {
+ t.Errorf("request count = %d, want 1", n)
+ }
+}
+
+func TestRequestFrameUnknownModule(t *testing.T) {
+ addr := tempSocket(t)
+
+ srv, err := Listen(addr)
+ if err != nil {
+ t.Fatalf("Listen: %v", err)
+ }
+ t.Cleanup(func() { _ = srv.Close() })
+
+ if err := srv.RequestFrame(999); !errors.Is(err, ErrModuleNotFound) {
+ t.Errorf("RequestFrame(999) = %v, want ErrModuleNotFound", err)
+ }
+}
+
+func TestDialNonExistentServer(t *testing.T) {
+ _, err := Dial("/tmp/compose-nonexistent-test.sock")
+ if err == nil {
+ t.Fatal("Dial to non-existent server should fail")
+ }
+}
+
+func TestWithMaxModulesLimit(t *testing.T) {
+ addr := tempSocket(t)
+
+ srv, err := Listen(addr, WithMaxModules(1))
+ if err != nil {
+ t.Fatalf("Listen: %v", err)
+ }
+ t.Cleanup(func() { _ = srv.Close() })
+
+ // First client should succeed.
+ c1, err := Dial(addr, WithName("first"))
+ if err != nil {
+ t.Fatalf("Dial first: %v", err)
+ }
+ t.Cleanup(func() { _ = c1.Close() })
+
+ // Wait for the first client to be fully registered on the server.
+ // On CI runners, goroutine scheduling can be slow.
+ time.Sleep(200 * time.Millisecond)
+
+ // Second client should be rejected.
+ _, err = Dial(addr, WithName("second"))
+ if err == nil {
+ t.Fatal("Dial second should fail with max modules = 1")
+ }
+}
+
+func TestFrameSequenceNumbers(t *testing.T) {
+ addr := tempSocket(t)
+
+ srv, err := Listen(addr)
+ if err != nil {
+ t.Fatalf("Listen: %v", err)
+ }
+ t.Cleanup(func() { _ = srv.Close() })
+
+ client, err := Dial(addr, WithName("sequencer"))
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ t.Cleanup(func() { _ = client.Close() })
+
+ // Publish three frames and verify the internal sequence counter increments.
+ for i := 0; i < 3; i++ {
+ pubErr := client.PublishFrame(Frame{
+ Pixels: makePixels(2, 2, byte(i)),
+ Width: 2,
+ Height: 2,
+ })
+ if pubErr != nil {
+ t.Fatalf("PublishFrame %d: %v", i, pubErr)
+ }
+ }
+
+ // The sequence counter should be 3 after three frames.
+ if seq := client.seq.Load(); seq != 3 {
+ t.Errorf("seq = %d, want 3", seq)
+ }
+}
+
+func TestFunctionalOptionsDefaults(t *testing.T) {
+ // Verify server config defaults.
+ sCfg := defaultServerConfig()
+ if sCfg.maxModules != 16 {
+ t.Errorf("default maxModules = %d, want 16", sCfg.maxModules)
+ }
+ if sCfg.compression != "" {
+ t.Errorf("default compression = %q, want empty", sCfg.compression)
+ }
+
+ // Verify client config defaults.
+ cCfg := defaultClientConfig()
+ if cCfg.name != "module" {
+ t.Errorf("default name = %q, want %q", cCfg.name, "module")
+ }
+ if cCfg.width != 400 {
+ t.Errorf("default width = %d, want 400", cCfg.width)
+ }
+ if cCfg.height != 300 {
+ t.Errorf("default height = %d, want 300", cCfg.height)
+ }
+ if cCfg.fps != 1 {
+ t.Errorf("default fps = %d, want 1", cCfg.fps)
+ }
+}
+
+func TestFunctionalOptionsApply(t *testing.T) {
+ sCfg := defaultServerConfig()
+ WithMaxModules(32)(&sCfg)
+ WithCompression("lz4")(&sCfg)
+
+ if sCfg.maxModules != 32 {
+ t.Errorf("maxModules = %d, want 32", sCfg.maxModules)
+ }
+ if sCfg.compression != "lz4" {
+ t.Errorf("compression = %q, want %q", sCfg.compression, "lz4")
+ }
+
+ cCfg := defaultClientConfig()
+ WithName("clock")(&cCfg)
+ WithFrameSize(800, 600)(&cCfg)
+ WithFPS(60)(&cCfg)
+
+ if cCfg.name != "clock" {
+ t.Errorf("name = %q, want %q", cCfg.name, "clock")
+ }
+ if cCfg.width != 800 {
+ t.Errorf("width = %d, want 800", cCfg.width)
+ }
+ if cCfg.height != 600 {
+ t.Errorf("height = %d, want 600", cCfg.height)
+ }
+ if cCfg.fps != 60 {
+ t.Errorf("fps = %d, want 60", cCfg.fps)
+ }
+}
+
+func TestWithMaxModulesClamping(t *testing.T) {
+ cfg := defaultServerConfig()
+ WithMaxModules(0)(&cfg)
+ if cfg.maxModules != 1 {
+ t.Errorf("maxModules after clamp(0) = %d, want 1", cfg.maxModules)
+ }
+
+ WithMaxModules(-5)(&cfg)
+ if cfg.maxModules != 1 {
+ t.Errorf("maxModules after clamp(-5) = %d, want 1", cfg.maxModules)
+ }
+}
+
+func TestResolveCodec(t *testing.T) {
+ raw := resolveCodec("")
+ if raw.ID() != 0x00 {
+ t.Errorf("resolveCodec(\"\") ID = 0x%02X, want 0x00", raw.ID())
+ }
+
+ lz4 := resolveCodec("lz4")
+ if lz4.ID() != 0x01 {
+ t.Errorf("resolveCodec(\"lz4\") ID = 0x%02X, want 0x01", lz4.ID())
+ }
+
+ unknown := resolveCodec("zstd")
+ if unknown.ID() != 0x00 {
+ t.Errorf("resolveCodec(\"zstd\") ID = 0x%02X, want 0x00 (raw fallback)", unknown.ID())
+ }
+}
+
+func TestIsDirtyRectValid(t *testing.T) {
+ tests := []struct {
+ name string
+ rect image.Rectangle
+ valid bool
+ }{
+ {"zero", image.Rectangle{}, false},
+ {"valid", image.Rect(0, 0, 10, 10), true},
+ {"empty", image.Rect(5, 5, 5, 5), false},
+ // image.Rect canonicalizes, so Rect(10,10,5,5) becomes Rect(5,5,10,10)
+ // which is valid. Use raw struct to create a truly empty rect.
+ {"point", image.Rect(5, 5, 5, 5), false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := isDirtyRectValid(tt.rect)
+ if got != tt.valid {
+ t.Errorf("isDirtyRectValid(%v) = %v, want %v", tt.rect, got, tt.valid)
+ }
+ })
+ }
+}
+
+func TestFrameToHeader(t *testing.T) {
+ f := Frame{
+ Pixels: makePixels(10, 10, 0xFF),
+ Width: 10,
+ Height: 10,
+ DirtyRect: image.Rect(2, 3, 8, 9),
+ Timestamp: 123456789,
+ }
+
+ hdr := frameToHeader(f, 42, 7, resolveCodec(""))
+
+ if hdr.ModuleID != 42 {
+ t.Errorf("ModuleID = %d, want 42", hdr.ModuleID)
+ }
+ if hdr.Sequence != 7 {
+ t.Errorf("Sequence = %d, want 7", hdr.Sequence)
+ }
+ if hdr.Width != 10 {
+ t.Errorf("Width = %d, want 10", hdr.Width)
+ }
+ if hdr.Height != 10 {
+ t.Errorf("Height = %d, want 10", hdr.Height)
+ }
+ if hdr.Stride != 40 {
+ t.Errorf("Stride = %d, want 40", hdr.Stride)
+ }
+ if hdr.TimestampNs != 123456789 {
+ t.Errorf("TimestampNs = %d, want 123456789", hdr.TimestampNs)
+ }
+ if hdr.DirtyX != 2 || hdr.DirtyY != 3 || hdr.DirtyW != 6 || hdr.DirtyH != 6 {
+ t.Errorf("DirtyRect = (%d,%d,%d,%d), want (2,3,6,6)",
+ hdr.DirtyX, hdr.DirtyY, hdr.DirtyW, hdr.DirtyH)
+ }
+}
+
+func TestSentinelErrors(t *testing.T) {
+ // Verify sentinel errors are distinct and not nil.
+ sentinels := []error{ErrClosed, ErrNotAccepted, ErrModuleNotFound, ErrMaxModules, ErrNameTaken}
+ for i, e := range sentinels {
+ if e == nil {
+ t.Errorf("sentinel error %d is nil", i)
+ }
+ }
+
+ // ErrMaxModules and ErrNameTaken should be the same objects as conn package exports.
+ if ErrMaxModules.Error() != "compose: maximum module count reached" {
+ t.Errorf("ErrMaxModules message = %q", ErrMaxModules.Error())
+ }
+ if ErrNameTaken.Error() != "compose: module name already registered" {
+ t.Errorf("ErrNameTaken message = %q", ErrNameTaken.Error())
+ }
+}
+
+func TestHeaderToFrame(t *testing.T) {
+ pixels := makePixels(8, 8, 0xBB)
+
+ tests := []struct {
+ name string
+ flags uint8
+ dirtyX uint16
+ dirtyY uint16
+ dirtyW uint16
+ dirtyH uint16
+ wantDirty image.Rectangle
+ }{
+ {
+ name: "no dirty rect (keyframe)",
+ flags: 0,
+ wantDirty: image.Rectangle{},
+ },
+ {
+ name: "with dirty rect",
+ flags: 0x01, // FlagDirtyValid
+ dirtyX: 2,
+ dirtyY: 3,
+ dirtyW: 4,
+ dirtyH: 5,
+ wantDirty: image.Rect(2, 3, 6, 8),
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ hdr := buildTestHeader(tt.flags, tt.dirtyX, tt.dirtyY, tt.dirtyW, tt.dirtyH)
+
+ f := headerToFrame(hdr, pixels, "test")
+
+ if f.ModuleID != 1 {
+ t.Errorf("ModuleID = %d, want 1", f.ModuleID)
+ }
+ if f.Name != "test" {
+ t.Errorf("Name = %q, want %q", f.Name, "test")
+ }
+ if f.Width != 8 {
+ t.Errorf("Width = %d, want 8", f.Width)
+ }
+ if f.Height != 8 {
+ t.Errorf("Height = %d, want 8", f.Height)
+ }
+ if f.DirtyRect != tt.wantDirty {
+ t.Errorf("DirtyRect = %v, want %v", f.DirtyRect, tt.wantDirty)
+ }
+ })
+ }
+}
+
+// buildTestHeader creates a protocol.Header with commonly used test values.
+func buildTestHeader(flags uint8, dirtyX, dirtyY, dirtyW, dirtyH uint16) protocol.Header {
+ return protocol.Header{
+ Magic: protocol.Magic,
+ Version: protocol.ProtocolVersion,
+ MsgType: protocol.MsgFrame,
+ Flags: protocol.Flag(flags),
+ ModuleID: 1,
+ Sequence: 1,
+ TimestampNs: time.Now().UnixNano(),
+ Width: 8,
+ Height: 8,
+ Stride: 32,
+ DirtyX: dirtyX,
+ DirtyY: dirtyY,
+ DirtyW: dirtyW,
+ DirtyH: dirtyH,
+ PixelFormat: protocol.PixelRGBA8,
+ Compression: protocol.CompressionNone,
+ PayloadSize: 256,
+ }
+}
diff --git a/doc.go b/doc.go
index d844209..3e34d43 100644
--- a/doc.go
+++ b/doc.go
@@ -8,14 +8,53 @@
// possibility of cross-language modules — anything that can write RGBA to a
// Unix socket or shared memory segment can participate.
//
+// # Quick Start
+//
+// Compositor (server) side:
+//
+// srv, err := compose.Listen("/tmp/compose.sock",
+// compose.WithMaxModules(8),
+// )
+// if err != nil {
+// log.Fatal(err)
+// }
+// defer srv.Close()
+//
+// srv.OnFrame(func(f compose.Frame) {
+// // Blit f.Pixels onto the compositor window.
+// })
+//
+// Module (client) side:
+//
+// client, err := compose.Dial("/tmp/compose.sock",
+// compose.WithName("clock"),
+// compose.WithFrameSize(400, 120),
+// )
+// if err != nil {
+// log.Fatal(err)
+// }
+// defer client.Close()
+//
+// client.OnFrameRequest(func() {
+// // Render and publish a frame.
+// _ = client.PublishFrame(compose.Frame{
+// Pixels: renderClock(),
+// Width: 400,
+// Height: 120,
+// })
+// })
+//
+// # Architecture
+//
// Multi-window in gogpu (ADR-010) and the compose model are different concepts.
// Multi-window shares one GPU Device across N native windows inside a single
// process. The compose model is the opposite: N independent processes, each
// with its own GPU Device, cooperating over IPC to produce a single composed
// display.
//
-// Status: design phase. See the repository README for the roadmap and the
-// linked ui#75 discussion for context.
+// The public API consists of five types (Frame, Server, Client, ServerOption,
+// ClientOption) and three constructor functions (Listen, Dial, and With*
+// options). All implementation details live behind internal/ boundaries.
//
// Part of the GoGPU ecosystem: https://github.com/gogpu
package compose
diff --git a/docs/ARCHITECTURE.md b/docs/ARCHITECTURE.md
new file mode 100644
index 0000000..ba87c77
--- /dev/null
+++ b/docs/ARCHITECTURE.md
@@ -0,0 +1,156 @@
+# Architecture
+
+> **Module:** `github.com/gogpu/compose`
+> **Pattern:** Two-tier transport (Unix socket + shared memory)
+> **Inspiration:** Wayland wl_shm, Android SurfaceFlinger, PipeWire SPA
+
+## Overview
+
+```
+┌─────────────────────────────────────────┐
+│ Display (one physical screen) │
+└───────────────────┬─────────────────────┘
+ │
+┌───────────────────┴─────────────────────┐
+│ Compositor process │
+│ compose.Listen("/tmp/compose.sock") │
+│ │
+│ • accepts module connections │
+│ • assigns module IDs │
+│ • pull-based frame requests │
+│ • composites frames onto display │
+└────┬────────────┬────────────┬──────────┘
+ │ │ │
+ socket socket socket
+ │ │ │
+┌────┴─────┐ ┌───┴──────┐ ┌───┴───────────┐
+│ Clock │ │ Weather │ │ Notification │
+│ module │ │ module │ │ module │
+│ 1 Hz │ │ 0.1 Hz │ │ 60 Hz anim │
+│ own PID │ │ own PID │ │ own PID │
+└──────────┘ └──────────┘ └───────────────┘
+```
+
+## Package Structure
+
+```
+compose/ # Public API (13 exported symbols)
+├── compose.go # Frame type
+├── server.go # Server (compositor side)
+├── client.go # Client (module side)
+├── option.go # Functional options
+├── error.go # Sentinel errors
+│
+└── internal/ # Implementation (hidden)
+ ├── protocol/ # Wire format (64B header, handshake)
+ ├── codec/ # Compression (Raw, LZ4)
+ ├── conn/ # Module lifecycle (registry, hot-plug)
+ ├── flow/ # Pull-based pacing (Wayland pattern)
+ └── transport/
+ └── socket/ # Unix domain socket transport
+```
+
+## Public API
+
+Users import ONE package: `"github.com/gogpu/compose"`.
+
+```go
+// Compositor side
+srv, _ := compose.Listen("/tmp/compose.sock",
+ compose.WithMaxModules(8),
+ compose.WithCompression("lz4"),
+)
+srv.OnFrame(func(f compose.Frame) { /* composite */ })
+srv.OnConnect(func(id uint64, name string) { /* module joined */ })
+srv.OnDisconnect(func(id uint64, name string) { /* module left */ })
+
+// Module side
+client, _ := compose.Dial("/tmp/compose.sock",
+ compose.WithName("clock"),
+ compose.WithFrameSize(400, 120),
+ compose.WithFPS(1),
+)
+client.OnFrameRequest(func() { /* render and publish */ })
+client.PublishFrame(compose.Frame{ Pixels: rgba, Width: 400, Height: 120 })
+```
+
+## Wire Protocol v1
+
+64-byte fixed header (cache-line aligned, little-endian):
+
+| Offset | Field | Size | Description |
+|--------|-------|------|-------------|
+| 0 | Magic | 4B | `0x434F4D50` ("COMP") |
+| 4 | Version | 2B | Protocol version |
+| 6 | MsgType | 1B | Frame, Handshake, Ack, FrameRequest, Resize, Disconnect |
+| 7 | Flags | 1B | DirtyValid, Compressed, Keyframe |
+| 8 | ModuleID | 8B | Compositor-assigned |
+| 16 | Sequence | 8B | Monotonic frame counter |
+| 24 | Timestamp | 8B | Monotonic nanoseconds |
+| 32 | Width | 2B | Frame width |
+| 34 | Height | 2B | Frame height |
+| 36 | Stride | 4B | Bytes per row |
+| 40 | DirtyRect | 8B | x, y, w, h (2B each) |
+| 48 | PixelFormat | 1B | RGBA8, BGRA8 |
+| 49 | Compression | 1B | None, LZ4, Zstd |
+| 56 | PayloadSize | 4B | Compressed payload bytes |
+| 60 | UncompressedSize | 4B | Original pixel bytes |
+
+## Flow Control
+
+Pull-based (Wayland frame callback pattern):
+
+1. Compositor → Module: `FrameRequest`
+2. Module renders → sends `Frame`
+3. Compositor processes → sends next `FrameRequest`
+4. Adaptive rate: 3 consecutive misses → halve request rate
+
+## Connection Lifecycle
+
+```
+Module connects → Handshake (name, size, fps)
+ → Compositor assigns ID, fires OnConnect
+ → Frame loop (pull-based)
+ → Module crashes → EOF → OnDisconnect
+ → Module reconnects → new handshake → same slot
+```
+
+## Design Principles
+
+1. **Minimal public surface** — < 15 exported symbols, one import
+2. **Enterprise internal/** — all implementation hidden behind `internal/`
+3. **Zero allocations on hot path** — pre-allocated buffers, reused headers
+4. **Cross-platform** — Linux, macOS, Windows (AF_UNIX), FreeBSD
+5. **Zero CGO** — Pure Go transports, Pure Go compression
+6. **Independent releases** — protocol versioned separately from API
+
+## Performance
+
+| Metric | Value |
+|--------|-------|
+| Header encode/decode | 6–24 ns, 0 allocs |
+| LZ4 compression | 2.9 GB/s encode |
+| GUI pixel compression ratio | 99.6% (flat color) |
+| Socket throughput | 4.3 GB/s |
+| Frame latency (192KB) | 45 μs |
+| Flow control overhead | 37 ns/decision |
+
+## Dependency Graph
+
+```
+compose (root) ──→ internal/protocol (leaf, no deps)
+ ├──→ internal/transport/socket ──→ internal/protocol
+ ├──→ internal/codec (standalone)
+ ├──→ internal/flow (standalone)
+ └──→ internal/conn (standalone)
+```
+
+No circular dependencies. `internal/protocol` is the leaf — imported by others, imports nothing internal.
+
+## Part of GoGPU Ecosystem
+
+```
+naga (shaders) → wgpu (WebGPU) → gogpu (windowing) → gg (2D) → ui (widgets)
+ ↓
+ compose (IPC)
+```
diff --git a/error.go b/error.go
new file mode 100644
index 0000000..f7380b4
--- /dev/null
+++ b/error.go
@@ -0,0 +1,30 @@
+package compose
+
+import (
+ "errors"
+
+ "github.com/gogpu/compose/internal/conn"
+)
+
+// Sentinel errors returned by Server and Client.
+var (
+ // ErrClosed is returned when an operation is attempted on a closed
+ // Server or Client.
+ ErrClosed = errors.New("compose: server/client closed")
+
+ // ErrNotAccepted is returned by Dial when the compositor rejects the
+ // connection (e.g., due to capacity limits or policy).
+ ErrNotAccepted = errors.New("compose: connection not accepted by compositor")
+
+ // ErrModuleNotFound is returned by RequestFrame when the specified
+ // module ID does not exist in the server's module table.
+ ErrModuleNotFound = errors.New("compose: module not found")
+
+ // ErrMaxModules is returned when the server cannot accept a new module
+ // because the maximum module count has been reached.
+ ErrMaxModules = conn.ErrMaxModules
+
+ // ErrNameTaken is returned when a module with the same name is already
+ // connected to the server.
+ ErrNameTaken = conn.ErrNameTaken
+)
diff --git a/go.mod b/go.mod
index ae69c8d..561b7a7 100644
--- a/go.mod
+++ b/go.mod
@@ -1,3 +1,5 @@
module github.com/gogpu/compose
go 1.25
+
+require github.com/pierrec/lz4/v4 v4.1.26
diff --git a/go.sum b/go.sum
new file mode 100644
index 0000000..dafceee
--- /dev/null
+++ b/go.sum
@@ -0,0 +1,2 @@
+github.com/pierrec/lz4/v4 v4.1.26 h1:GrpZw1gZttORinvzBdXPUXATeqlJjqUG/D87TKMnhjY=
+github.com/pierrec/lz4/v4 v4.1.26/go.mod h1:EoQMVJgeeEOMsCqCzqFm2O0cJvljX2nGZjcRIPL34O4=
diff --git a/internal/codec/codec.go b/internal/codec/codec.go
new file mode 100644
index 0000000..b9c7b32
--- /dev/null
+++ b/internal/codec/codec.go
@@ -0,0 +1,68 @@
+package codec
+
+import "sync"
+
+// Protocol compression identifiers.
+const (
+ IDRaw byte = 0x00
+ IDLZ4 byte = 0x01
+)
+
+// Codec compresses and decompresses frame pixel data.
+// Implementations must be safe for concurrent use.
+// Encode/Decode must not allocate on the hot path when the caller provides
+// a sufficiently sized destination buffer.
+type Codec interface {
+ // Encode compresses src into dst. Returns the compressed slice (sub-slice of dst).
+ // dst must be large enough (use MaxEncodedSize to determine required capacity).
+ // If dst is nil or too small, a new buffer is allocated (slow path).
+ Encode(dst, src []byte) ([]byte, error)
+
+ // Decode decompresses src into dst. Returns the decompressed slice.
+ // dst must be large enough to hold the uncompressed data.
+ // If dst is nil or too small, a new buffer is allocated (slow path).
+ Decode(dst, src []byte) ([]byte, error)
+
+ // ID returns the protocol compression identifier.
+ ID() byte
+
+ // MaxEncodedSize returns the maximum possible compressed size for input of
+ // the given length. Use this to pre-allocate destination buffers.
+ MaxEncodedSize(srcLen int) int
+}
+
+var (
+ registryMu sync.RWMutex
+ registry = make(map[byte]Codec)
+)
+
+// Register adds a codec to the global registry. It is called during init()
+// by each codec implementation. Panics if a codec with the same ID is already
+// registered.
+func Register(c Codec) {
+ registryMu.Lock()
+ defer registryMu.Unlock()
+
+ id := c.ID()
+ if _, exists := registry[id]; exists {
+ panic("codec: duplicate registration for ID " + string(rune('0'+id)))
+ }
+ registry[id] = c
+}
+
+// Get returns the codec for the given protocol ID. Returns nil if no codec
+// is registered for that ID.
+func Get(id byte) Codec {
+ registryMu.RLock()
+ defer registryMu.RUnlock()
+
+ return registry[id]
+}
+
+// resetRegistry clears the global registry. Used only in tests.
+func resetRegistry() {
+ registryMu.Lock()
+ defer registryMu.Unlock()
+
+ registry = make(map[byte]Codec)
+}
diff --git a/internal/codec/codec_test.go b/internal/codec/codec_test.go
new file mode 100644
index 0000000..c8f0df6
--- /dev/null
+++ b/internal/codec/codec_test.go
@@ -0,0 +1,83 @@
+package codec
+
+import "testing"
+
+func TestRegisterAndGet(t *testing.T) {
+ // Reset registry for isolated test.
+ resetRegistry()
+ defer func() {
+ // Re-register defaults after test.
+ resetRegistry()
+ Register(Raw())
+ Register(LZ4())
+ }()
+
+ raw := Raw()
+ Register(raw)
+
+ got := Get(IDRaw)
+ if got == nil {
+ t.Fatal("Get(IDRaw) returned nil after Register")
+ }
+ if got.ID() != IDRaw {
+ t.Errorf("Get(IDRaw).ID() = %d, want %d", got.ID(), IDRaw)
+ }
+}
+
+func TestGetUnknownID(t *testing.T) {
+ got := Get(0xFF)
+ if got != nil {
+ t.Errorf("Get(0xFF) = %v, want nil", got)
+ }
+}
+
+func TestGetRegisteredCodecs(t *testing.T) {
+ tests := []struct {
+ name string
+ id byte
+ }{
+ {"Raw", IDRaw},
+ {"LZ4", IDLZ4},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ c := Get(tt.id)
+ if c == nil {
+ t.Fatalf("Get(0x%02x) returned nil", tt.id)
+ }
+ if c.ID() != tt.id {
+ t.Errorf("ID() = 0x%02x, want 0x%02x", c.ID(), tt.id)
+ }
+ })
+ }
+}
+
+func TestRegisterDuplicatePanics(t *testing.T) {
+ resetRegistry()
+ defer func() {
+ resetRegistry()
+ Register(Raw())
+ Register(LZ4())
+ }()
+
+ Register(Raw())
+
+ defer func() {
+ r := recover()
+ if r == nil {
+ t.Fatal("expected panic on duplicate registration, got none")
+ }
+ }()
+
+ // Second registration with same ID should panic.
+ Register(Raw())
+}
+
+func TestCodecConstants(t *testing.T) {
+ if IDRaw != 0x00 {
+ t.Errorf("IDRaw = 0x%02x, want 0x00", IDRaw)
+ }
+ if IDLZ4 != 0x01 {
+ t.Errorf("IDLZ4 = 0x%02x, want 0x01", IDLZ4)
+ }
+}
diff --git a/internal/codec/doc.go b/internal/codec/doc.go
new file mode 100644
index 0000000..4d47bc2
--- /dev/null
+++ b/internal/codec/doc.go
@@ -0,0 +1,15 @@
+// Package codec provides frame payload compression and decompression for the
+// compose protocol. It defines a Codec interface with pluggable implementations
+// and a global registry for protocol-level codec negotiation.
+//
+// Implementations must be safe for concurrent use. The Encode and Decode methods
+// accept caller-provided destination buffers to avoid allocations on the hot path.
+// When the caller provides a sufficiently sized buffer, zero allocations occur.
+//
+// Available codecs:
+// - Raw (ID 0x00): Pass-through copy, no compression.
+// - LZ4 (ID 0x01): LZ4 block compression via github.com/pierrec/lz4/v4.
+//
+// Registration happens automatically via init() in each codec's source file.
+// Use Get(id) to retrieve a codec by its protocol identifier.
+package codec
diff --git a/internal/codec/lz4.go b/internal/codec/lz4.go
new file mode 100644
index 0000000..5125732
--- /dev/null
+++ b/internal/codec/lz4.go
@@ -0,0 +1,116 @@
+package codec
+
+import (
+ "fmt"
+ "sync"
+
+ "github.com/pierrec/lz4/v4"
+)
+
+func init() {
+ Register(LZ4())
+}
+
+// LZ4 returns a codec using LZ4 block compression. It uses a pooled
+// lz4.Compressor to avoid allocations on the hot path. The compressor
+// hash table is reused across calls via sync.Pool.
+//
+// LZ4 provides fast compression with moderate ratios, making it ideal for
+// frame pixel data that contains large flat color regions (typical in GUIs).
+func LZ4() Codec {
+ return &lz4Codec{
+ pool: sync.Pool{
+ New: func() any {
+ var c lz4.Compressor
+ return &c
+ },
+ },
+ }
+}
+
+type lz4Codec struct {
+ pool sync.Pool
+}
+
+// Encode compresses src using LZ4 block compression. Returns a sub-slice of
+// dst containing the compressed data. If dst is nil or too small, allocates a
+// new buffer.
+//
+// If src is empty, returns an empty slice without error.
+func (c *lz4Codec) Encode(dst, src []byte) ([]byte, error) {
+ if len(src) == 0 {
+ if dst == nil {
+ return nil, nil
+ }
+ return dst[:0], nil
+ }
+
+ maxSize := lz4.CompressBlockBound(len(src))
+ if cap(dst) < maxSize {
+ dst = make([]byte, maxSize)
+ } else {
+ dst = dst[:maxSize]
+ }
+
+ compressor := c.pool.Get().(*lz4.Compressor)
+ n, err := compressor.CompressBlock(src, dst)
+ c.pool.Put(compressor)
+
+ if err != nil {
+ return nil, fmt.Errorf("codec: lz4 encode: %w", err)
+ }
+
+ return dst[:n], nil
+}
+
+// Decode decompresses an LZ4 block-compressed payload. Returns a sub-slice of
+// dst containing the decompressed data. If dst is nil or too small, allocates
+// a new buffer.
+//
+// The caller should provide a dst buffer sized to the expected uncompressed
+// length (e.g., Width * Height * 4 for RGBA frames). When dst is nil or too
+// small, an exponential growth strategy is used as a fallback.
+func (c *lz4Codec) Decode(dst, src []byte) ([]byte, error) {
+ if len(src) == 0 {
+ if dst == nil {
+ return nil, nil
+ }
+ return dst[:0], nil
+ }
+
+ if cap(dst) == 0 {
+ // Caller didn't provide a buffer. Start with 10x compressed size
+ // as initial guess. LZ4 GUI data often compresses 100:1 or better,
+ // but 10x covers most cases in one attempt.
+ dst = make([]byte, len(src)*10)
+ } else {
+ dst = dst[:cap(dst)]
+ }
+
+ // maxDecodeBuf caps the growth strategy to prevent runaway allocation
+ // on corrupt or adversarial input (64 MB covers 4K RGBA frames).
+ const maxDecodeBuf = 64 * 1024 * 1024
+
+ for {
+ n, err := lz4.UncompressBlock(src, dst)
+ if err != nil {
+ // If buffer might be too small, double and retry.
+ if len(dst) < maxDecodeBuf {
+ dst = make([]byte, len(dst)*2)
+ continue
+ }
+ return nil, fmt.Errorf("codec: lz4 decode: %w", err)
+ }
+ return dst[:n], nil
+ }
+}
+
+// ID returns the LZ4 codec protocol identifier (0x01).
+func (c *lz4Codec) ID() byte {
+ return IDLZ4
+}
+
+// MaxEncodedSize returns the worst-case encoded size for LZ4 block compression.
+func (c *lz4Codec) MaxEncodedSize(srcLen int) int {
+ return lz4.CompressBlockBound(srcLen)
+}
diff --git a/internal/codec/lz4_test.go b/internal/codec/lz4_test.go
new file mode 100644
index 0000000..417891f
--- /dev/null
+++ b/internal/codec/lz4_test.go
@@ -0,0 +1,341 @@
+package codec
+
+import (
+ "bytes"
+ "crypto/rand"
+ "testing"
+)
+
+func TestLZ4RoundTrip(t *testing.T) {
+ tests := []struct {
+ name string
+ data []byte
+ }{
+ {"single byte", []byte{0x42}},
+ {"small text", []byte("hello world hello world hello world")},
+ {"zeros 1KB", make([]byte, 1024)},
+ {"gui frame 400x120", makeGUIPixels(400, 120)},
+ {"repeated pattern", bytes.Repeat([]byte{0xAA, 0xBB, 0xCC, 0xDD}, 4096)},
+ }
+
+ c := LZ4()
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ src := tt.data
+
+ // Encode with pre-allocated buffer.
+ encBuf := make([]byte, c.MaxEncodedSize(len(src)))
+ encoded, err := c.Encode(encBuf, src)
+ if err != nil {
+ t.Fatalf("Encode: %v", err)
+ }
+
+ // Decode with pre-allocated buffer.
+ decBuf := make([]byte, len(src))
+ decoded, err := c.Decode(decBuf, encoded)
+ if err != nil {
+ t.Fatalf("Decode: %v", err)
+ }
+
+ if !bytes.Equal(decoded, src) {
+ t.Errorf("round-trip mismatch: decoded %d bytes, want %d", len(decoded), len(src))
+ }
+ })
+ }
+}
+
+func TestLZ4EmptyInput(t *testing.T) {
+ c := LZ4()
+
+ // Encode nil input.
+ encoded, err := c.Encode(nil, nil)
+ if err != nil {
+ t.Fatalf("Encode(nil, nil): %v", err)
+ }
+ if encoded != nil {
+ t.Errorf("Encode(nil, nil) = %v, want nil", encoded)
+ }
+
+ // Encode empty slice with dst.
+ dst := make([]byte, 16)
+ encoded, err = c.Encode(dst, []byte{})
+ if err != nil {
+ t.Fatalf("Encode(dst, []byte{}): %v", err)
+ }
+ if len(encoded) != 0 {
+ t.Errorf("Encode empty: len = %d, want 0", len(encoded))
+ }
+
+ // Decode nil input.
+ decoded, err := c.Decode(nil, nil)
+ if err != nil {
+ t.Fatalf("Decode(nil, nil): %v", err)
+ }
+ if decoded != nil {
+ t.Errorf("Decode(nil, nil) = %v, want nil", decoded)
+ }
+
+ // Decode empty slice with dst.
+ decoded, err = c.Decode(dst, []byte{})
+ if err != nil {
+ t.Fatalf("Decode(dst, []byte{}): %v", err)
+ }
+ if len(decoded) != 0 {
+ t.Errorf("Decode empty: len = %d, want 0", len(decoded))
+ }
+}
+
+func TestLZ4CompressionRatio(t *testing.T) {
+ c := LZ4()
+
+ // GUI pixel data should compress well (large flat color regions).
+ src := makeGUIPixels(400, 120)
+ encBuf := make([]byte, c.MaxEncodedSize(len(src)))
+ encoded, err := c.Encode(encBuf, src)
+ if err != nil {
+ t.Fatalf("Encode: %v", err)
+ }
+
+ ratio := float64(len(encoded)) / float64(len(src))
+ t.Logf("GUI pixels: %d -> %d bytes (ratio: %.3f, savings: %.1f%%)",
+ len(src), len(encoded), ratio, (1-ratio)*100)
+
+ // GUI pixel data with large flat regions should compress to at least 50%.
+ if ratio > 0.50 {
+ t.Errorf("compression ratio %.3f is worse than expected 0.50 for GUI data", ratio)
+ }
+}
+
+func TestLZ4RandomData(t *testing.T) {
+ c := LZ4()
+
+ // Random data compresses poorly but must still round-trip correctly.
+ src := make([]byte, 4096)
+ if _, err := rand.Read(src); err != nil {
+ t.Fatalf("rand.Read: %v", err)
+ }
+
+ encBuf := make([]byte, c.MaxEncodedSize(len(src)))
+ encoded, err := c.Encode(encBuf, src)
+ if err != nil {
+ t.Fatalf("Encode random: %v", err)
+ }
+
+ decBuf := make([]byte, len(src))
+ decoded, err := c.Decode(decBuf, encoded)
+ if err != nil {
+ t.Fatalf("Decode random: %v", err)
+ }
+ if !bytes.Equal(decoded, src) {
+ t.Error("round-trip mismatch for random data")
+ }
+}
+
+func TestLZ4NilDst(t *testing.T) {
+ c := LZ4()
+ src := makeGUIPixels(400, 120)
+
+ // nil dst on Encode forces allocation (slow path).
+ encoded, err := c.Encode(nil, src)
+ if err != nil {
+ t.Fatalf("Encode(nil, src): %v", err)
+ }
+
+ // nil dst on Decode forces allocation with growth strategy.
+ decoded, err := c.Decode(nil, encoded)
+ if err != nil {
+ t.Fatalf("Decode(nil, encoded): %v", err)
+ }
+ if !bytes.Equal(decoded, src) {
+ t.Error("round-trip with nil dst: mismatch")
+ }
+}
+
+func TestLZ4SmallDstEncode(t *testing.T) {
+ c := LZ4()
+ src := makeGUIPixels(400, 120)
+
+ // dst too small for MaxEncodedSize -- forces reallocation.
+ dst := make([]byte, 4)
+ encoded, err := c.Encode(dst, src)
+ if err != nil {
+ t.Fatalf("Encode small dst: %v", err)
+ }
+
+ // Verify round-trip.
+ decBuf := make([]byte, len(src))
+ decoded, err := c.Decode(decBuf, encoded)
+ if err != nil {
+ t.Fatalf("Decode: %v", err)
+ }
+ if !bytes.Equal(decoded, src) {
+ t.Error("round-trip mismatch with small encode dst")
+ }
+}
+
+func TestLZ4SmallDstDecode(t *testing.T) {
+ c := LZ4()
+ src := makeGUIPixels(400, 120)
+
+ // Encode normally.
+ encBuf := make([]byte, c.MaxEncodedSize(len(src)))
+ encoded, err := c.Encode(encBuf, src)
+ if err != nil {
+ t.Fatalf("Encode: %v", err)
+ }
+
+ // Decode with dst too small -- triggers retry growth.
+ smallDst := make([]byte, 64)
+ decoded, err := c.Decode(smallDst, encoded)
+ if err != nil {
+ t.Fatalf("Decode small dst: %v", err)
+ }
+ if !bytes.Equal(decoded, src) {
+ t.Error("round-trip mismatch with small decode dst")
+ }
+}
+
+func TestLZ4ID(t *testing.T) {
+ c := LZ4()
+ if c.ID() != IDLZ4 {
+ t.Errorf("ID() = 0x%02x, want 0x%02x", c.ID(), IDLZ4)
+ }
+}
+
+func TestLZ4MaxEncodedSize(t *testing.T) {
+ c := LZ4()
+
+ tests := []int{0, 1, 100, 1024, 192000}
+ for _, srcLen := range tests {
+ maxSize := c.MaxEncodedSize(srcLen)
+ if maxSize < srcLen {
+ t.Errorf("MaxEncodedSize(%d) = %d, want >= %d", srcLen, maxSize, srcLen)
+ }
+ }
+}
+
+func TestLZ4RoundTripVaryingSizes(t *testing.T) {
+ c := LZ4()
+
+ sizes := []int{1, 2, 3, 4, 8, 12, 15, 16, 17, 64, 128, 256, 1024, 8192}
+ for _, size := range sizes {
+ src := make([]byte, size)
+ for i := range src {
+ src[i] = byte(i*17 + 31)
+ }
+
+ encBuf := make([]byte, c.MaxEncodedSize(len(src)))
+ encoded, err := c.Encode(encBuf, src)
+ if err != nil {
+ t.Fatalf("Encode size=%d: %v", size, err)
+ }
+
+ decBuf := make([]byte, len(src))
+ decoded, err := c.Decode(decBuf, encoded)
+ if err != nil {
+ t.Fatalf("Decode size=%d: %v", size, err)
+ }
+ if !bytes.Equal(decoded, src) {
+ t.Errorf("round-trip mismatch for size=%d", size)
+ }
+ }
+}
+
+func TestLZ4ConcurrentSafety(t *testing.T) {
+ c := LZ4()
+ src := makeGUIPixels(400, 120)
+ done := make(chan struct{})
+
+ for i := 0; i < 8; i++ {
+ go func() {
+ defer func() { done <- struct{}{} }()
+ encBuf := make([]byte, c.MaxEncodedSize(len(src)))
+ decBuf := make([]byte, len(src))
+ for j := 0; j < 50; j++ {
+ encoded, err := c.Encode(encBuf, src)
+ if err != nil {
+ t.Errorf("concurrent Encode: %v", err)
+ return
+ }
+ decoded, err := c.Decode(decBuf, encoded)
+ if err != nil {
+ t.Errorf("concurrent Decode: %v", err)
+ return
+ }
+ if !bytes.Equal(decoded, src) {
+ t.Errorf("concurrent round-trip mismatch")
+ return
+ }
+ }
+ }()
+ }
+ for i := 0; i < 8; i++ {
+ <-done
+ }
+}
+
+// BenchmarkLZ4Encode benchmarks LZ4 encoding with a realistic GUI frame (192KB).
+func BenchmarkLZ4Encode(b *testing.B) {
+ c := LZ4()
+ src := makeGUIPixels(400, 120) // 192,000 bytes
+ dst := make([]byte, c.MaxEncodedSize(len(src)))
+
+ b.SetBytes(int64(len(src)))
+ b.ResetTimer()
+
+ for b.Loop() {
+ _, _ = c.Encode(dst, src)
+ }
+}
+
+// BenchmarkLZ4Decode benchmarks LZ4 decoding with a realistic GUI frame.
+func BenchmarkLZ4Decode(b *testing.B) {
+ c := LZ4()
+ src := makeGUIPixels(400, 120)
+ encBuf := make([]byte, c.MaxEncodedSize(len(src)))
+ encoded, err := c.Encode(encBuf, src)
+ if err != nil {
+ b.Fatalf("setup Encode: %v", err)
+ }
+
+ dst := make([]byte, len(src))
+ b.SetBytes(int64(len(src)))
+ b.ResetTimer()
+
+ for b.Loop() {
+ _, _ = c.Decode(dst, encoded)
+ }
+}
+
+// BenchmarkLZ4EncodeFullHD benchmarks LZ4 encoding with a 1920x1080 frame.
+func BenchmarkLZ4EncodeFullHD(b *testing.B) {
+ c := LZ4()
+ src := makeGUIPixels(1920, 1080) // ~8.3 MB
+ dst := make([]byte, c.MaxEncodedSize(len(src)))
+
+ b.SetBytes(int64(len(src)))
+ b.ResetTimer()
+
+ for b.Loop() {
+ _, _ = c.Encode(dst, src)
+ }
+}
+
+// BenchmarkLZ4DecodeFullHD benchmarks LZ4 decoding with a 1920x1080 frame.
+func BenchmarkLZ4DecodeFullHD(b *testing.B) {
+ c := LZ4()
+ src := makeGUIPixels(1920, 1080)
+ encBuf := make([]byte, c.MaxEncodedSize(len(src)))
+ encoded, err := c.Encode(encBuf, src)
+ if err != nil {
+ b.Fatalf("setup Encode: %v", err)
+ }
+
+ dst := make([]byte, len(src))
+ b.SetBytes(int64(len(src)))
+ b.ResetTimer()
+
+ for b.Loop() {
+ _, _ = c.Decode(dst, encoded)
+ }
+}
diff --git a/internal/codec/raw.go b/internal/codec/raw.go
new file mode 100644
index 0000000..bcdbb69
--- /dev/null
+++ b/internal/codec/raw.go
@@ -0,0 +1,59 @@
+package codec
+
+func init() {
+ Register(Raw())
+}
+
+// Raw returns a codec that performs no compression (pass-through copy).
+// Encode and Decode simply copy src into dst. This is the fastest codec
+// and is used when compression overhead is not justified (small frames,
+// already-compressed data, or LAN with abundant bandwidth).
+func Raw() Codec {
+ return rawCodec{}
+}
+
+type rawCodec struct{}
+
+// Encode copies src into dst unchanged. Returns a sub-slice of dst containing
+// the copied data. If dst is nil or too small, allocates a new buffer.
+func (rawCodec) Encode(dst, src []byte) ([]byte, error) {
+ if len(src) == 0 {
+ return dst[:0], nil
+ }
+
+ if cap(dst) < len(src) {
+ dst = make([]byte, len(src))
+ } else {
+ dst = dst[:len(src)]
+ }
+
+ copy(dst, src)
+ return dst, nil
+}
+
+// Decode copies src into dst unchanged. Returns a sub-slice of dst containing
+// the copied data. If dst is nil or too small, allocates a new buffer.
+func (rawCodec) Decode(dst, src []byte) ([]byte, error) {
+ if len(src) == 0 {
+ return dst[:0], nil
+ }
+
+ if cap(dst) < len(src) {
+ dst = make([]byte, len(src))
+ } else {
+ dst = dst[:len(src)]
+ }
+
+ copy(dst, src)
+ return dst, nil
+}
+
+// ID returns the raw codec protocol identifier (0x00).
+func (rawCodec) ID() byte {
+ return IDRaw
+}
+
+// MaxEncodedSize returns srcLen since raw encoding has no overhead.
+func (rawCodec) MaxEncodedSize(srcLen int) int {
+ return srcLen
+}
diff --git a/internal/codec/raw_test.go b/internal/codec/raw_test.go
new file mode 100644
index 0000000..8d6928e
--- /dev/null
+++ b/internal/codec/raw_test.go
@@ -0,0 +1,258 @@
+package codec
+
+import (
+ "bytes"
+ "testing"
+)
+
+func TestRawRoundTrip(t *testing.T) {
+ tests := []struct {
+ name string
+ data []byte
+ }{
+ {"empty", nil},
+ {"single byte", []byte{0x42}},
+ {"small", []byte("hello world")},
+ {"zeros", make([]byte, 1024)},
+ {"frame 400x120x4", makeGUIPixels(400, 120)},
+ }
+
+ c := Raw()
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ src := tt.data
+
+ // Encode with pre-allocated buffer.
+ encBuf := make([]byte, c.MaxEncodedSize(len(src)))
+ encoded, err := c.Encode(encBuf, src)
+ if err != nil {
+ t.Fatalf("Encode: %v", err)
+ }
+
+ // Raw codec output should equal input.
+ if !bytes.Equal(encoded, src) {
+ t.Errorf("Encode output differs from input: got %d bytes, want %d", len(encoded), len(src))
+ }
+
+ // Decode with pre-allocated buffer.
+ decBuf := make([]byte, len(src))
+ decoded, err := c.Decode(decBuf, encoded)
+ if err != nil {
+ t.Fatalf("Decode: %v", err)
+ }
+
+ if !bytes.Equal(decoded, src) {
+ t.Errorf("round-trip mismatch: got %d bytes, want %d", len(decoded), len(src))
+ }
+ })
+ }
+}
+
+func TestRawEncodeNilDst(t *testing.T) {
+ c := Raw()
+ src := []byte("test data")
+
+ // nil dst forces allocation (slow path).
+ encoded, err := c.Encode(nil, src)
+ if err != nil {
+ t.Fatalf("Encode(nil, src): %v", err)
+ }
+ if !bytes.Equal(encoded, src) {
+ t.Error("Encode with nil dst: output differs from input")
+ }
+}
+
+func TestRawDecodeNilDst(t *testing.T) {
+ c := Raw()
+ src := []byte("test data")
+
+ decoded, err := c.Decode(nil, src)
+ if err != nil {
+ t.Fatalf("Decode(nil, src): %v", err)
+ }
+ if !bytes.Equal(decoded, src) {
+ t.Error("Decode with nil dst: output differs from input")
+ }
+}
+
+func TestRawEncodeEmptyInput(t *testing.T) {
+ c := Raw()
+ dst := make([]byte, 16)
+
+ encoded, err := c.Encode(dst, nil)
+ if err != nil {
+ t.Fatalf("Encode(dst, nil): %v", err)
+ }
+ if len(encoded) != 0 {
+ t.Errorf("Encode(dst, nil) len = %d, want 0", len(encoded))
+ }
+}
+
+func TestRawDecodeEmptyInput(t *testing.T) {
+ c := Raw()
+ dst := make([]byte, 16)
+
+ decoded, err := c.Decode(dst, nil)
+ if err != nil {
+ t.Fatalf("Decode(dst, nil): %v", err)
+ }
+ if len(decoded) != 0 {
+ t.Errorf("Decode(dst, nil) len = %d, want 0", len(decoded))
+ }
+}
+
+func TestRawEncodeSmallDst(t *testing.T) {
+ c := Raw()
+ src := []byte("longer test data that exceeds small buffer")
+ dst := make([]byte, 4) // too small
+
+ encoded, err := c.Encode(dst, src)
+ if err != nil {
+ t.Fatalf("Encode with small dst: %v", err)
+ }
+ if !bytes.Equal(encoded, src) {
+ t.Error("Encode with small dst: output differs from input")
+ }
+}
+
+func TestRawID(t *testing.T) {
+ c := Raw()
+ if c.ID() != IDRaw {
+ t.Errorf("ID() = 0x%02x, want 0x%02x", c.ID(), IDRaw)
+ }
+}
+
+func TestRawMaxEncodedSize(t *testing.T) {
+ c := Raw()
+ tests := []struct {
+ srcLen int
+ want int
+ }{
+ {0, 0},
+ {1, 1},
+ {1024, 1024},
+ {192000, 192000}, // 400*120*4
+ }
+ for _, tt := range tests {
+ got := c.MaxEncodedSize(tt.srcLen)
+ if got != tt.want {
+ t.Errorf("MaxEncodedSize(%d) = %d, want %d", tt.srcLen, got, tt.want)
+ }
+ }
+}
+
+func TestRawEncodeLargeFrame(t *testing.T) {
+ c := Raw()
+ // Simulate a 1920x1080 RGBA frame (8.3 MB).
+ src := make([]byte, 1920*1080*4)
+ for i := range src {
+ src[i] = byte(i % 256)
+ }
+
+ dst := make([]byte, c.MaxEncodedSize(len(src)))
+ encoded, err := c.Encode(dst, src)
+ if err != nil {
+ t.Fatalf("Encode large frame: %v", err)
+ }
+ if !bytes.Equal(encoded, src) {
+ t.Error("large frame encode: output differs from input")
+ }
+}
+
+func TestRawConcurrentSafety(t *testing.T) {
+ c := Raw()
+ src := makeGUIPixels(400, 120)
+ done := make(chan struct{})
+
+ for i := 0; i < 8; i++ {
+ go func() {
+ defer func() { done <- struct{}{} }()
+ dst := make([]byte, c.MaxEncodedSize(len(src)))
+ for j := 0; j < 100; j++ {
+ encoded, err := c.Encode(dst, src)
+ if err != nil {
+ t.Errorf("concurrent Encode: %v", err)
+ return
+ }
+ if !bytes.Equal(encoded, src) {
+ t.Errorf("concurrent Encode: mismatch")
+ return
+ }
+ }
+ }()
+ }
+ for i := 0; i < 8; i++ {
+ <-done
+ }
+}
+
+// BenchmarkRawEncode benchmarks raw codec encoding with a realistic GUI frame.
+func BenchmarkRawEncode(b *testing.B) {
+ c := Raw()
+ src := makeGUIPixels(400, 120)
+ dst := make([]byte, c.MaxEncodedSize(len(src)))
+
+ b.SetBytes(int64(len(src)))
+ b.ResetTimer()
+
+ for b.Loop() {
+ _, _ = c.Encode(dst, src)
+ }
+}
+
+// BenchmarkRawDecode benchmarks raw codec decoding with a realistic GUI frame.
+func BenchmarkRawDecode(b *testing.B) {
+ c := Raw()
+ src := makeGUIPixels(400, 120)
+ dst := make([]byte, len(src))
+
+ b.SetBytes(int64(len(src)))
+ b.ResetTimer()
+
+ for b.Loop() {
+ _, _ = c.Decode(dst, src)
+ }
+}
+
+// makeGUIPixels generates synthetic GUI-like pixel data with large flat color
+// regions, simulating a typical desktop widget frame (good for compression).
+func makeGUIPixels(width, height int) []byte {
+ size := width * height * 4
+ pixels := make([]byte, size)
+
+ // Background: solid light gray (70% of frame).
+ bgEnd := int(float64(height) * 0.7)
+ for y := 0; y < bgEnd; y++ {
+ for x := 0; x < width; x++ {
+ off := (y*width + x) * 4
+ pixels[off+0] = 0xF0 // R
+ pixels[off+1] = 0xF0 // G
+ pixels[off+2] = 0xF0 // B
+ pixels[off+3] = 0xFF // A
+ }
+ }
+
+ // Button region: solid blue.
+ for y := bgEnd; y < bgEnd+30 && y < height; y++ {
+ for x := 20; x < 120 && x < width; x++ {
+ off := (y*width + x) * 4
+ pixels[off+0] = 0x21 // R
+ pixels[off+1] = 0x96 // G
+ pixels[off+2] = 0xF3 // B
+ pixels[off+3] = 0xFF // A
+ }
+ }
+
+ // Footer: solid dark gray.
+ for y := bgEnd + 30; y < height; y++ {
+ for x := 0; x < width; x++ {
+ off := (y*width + x) * 4
+ pixels[off+0] = 0x30 // R
+ pixels[off+1] = 0x30 // G
+ pixels[off+2] = 0x30 // B
+ pixels[off+3] = 0xFF // A
+ }
+ }
+
+ return pixels
+}
diff --git a/internal/conn/doc.go b/internal/conn/doc.go
new file mode 100644
index 0000000..b904b11
--- /dev/null
+++ b/internal/conn/doc.go
@@ -0,0 +1,16 @@
+// Package conn provides connection lifecycle management for the compose library.
+//
+// It handles module ID allocation, name-to-ID mapping, hot-plug detection,
+// graceful disconnect, and reconnection matching. The package is standalone
+// with no internal dependencies.
+//
+// The core types are:
+//
+// - [Registry] manages module ID allocation and lookup.
+// - [Manager] orchestrates the full module lifecycle (connect, handshake,
+// active, disconnect) and fires event callbacks.
+// - [Module] holds metadata about a connected module.
+// - [State] represents the lifecycle state of a module connection.
+//
+// All exported methods are safe for concurrent use from multiple goroutines.
+package conn
diff --git a/internal/conn/errors.go b/internal/conn/errors.go
new file mode 100644
index 0000000..671736c
--- /dev/null
+++ b/internal/conn/errors.go
@@ -0,0 +1,18 @@
+package conn
+
+import "errors"
+
+// Sentinel errors for the conn package.
+var (
+ // ErrMaxModules is returned when attempting to register a module beyond
+ // the configured maximum capacity.
+ ErrMaxModules = errors.New("compose: maximum module count reached")
+
+ // ErrNameTaken is returned when attempting to register a module with a
+ // name that is already in use by an active module.
+ ErrNameTaken = errors.New("compose: module name already registered")
+
+ // ErrNotFound is returned when a lookup or operation references a module
+ // ID that does not exist in the registry.
+ ErrNotFound = errors.New("compose: module not found")
+)
diff --git a/internal/conn/manager.go b/internal/conn/manager.go
new file mode 100644
index 0000000..7b84d84
--- /dev/null
+++ b/internal/conn/manager.go
@@ -0,0 +1,103 @@
+package conn
+
+import "sync"
+
+// Manager orchestrates module lifecycle: connect, handshake, active, disconnect.
+// It wraps a Registry and adds event callbacks plus reconnection matching.
+// All methods are safe for concurrent access from multiple goroutines.
+type Manager struct {
+ registry *Registry
+
+ mu sync.RWMutex
+ onConnect func(id uint64, name string)
+ onDisconnect func(id uint64, name string)
+}
+
+// NewManager creates a Manager with the given max module count.
+// The maxModules parameter must be positive; values less than 1 are clamped to 1.
+func NewManager(maxModules int) *Manager {
+ return &Manager{
+ registry: NewRegistry(maxModules),
+ }
+}
+
+// HandleConnect processes a new module connection.
+// It allocates an ID, transitions the module to Active state, and fires
+// the OnConnect callback if registered.
+//
+// If a module with the same name was previously disconnected and unregistered,
+// it is treated as a reconnection — a new ID is allocated but the name slot
+// is reused seamlessly.
+//
+// Returns ErrMaxModules if the registry is at capacity.
+// Returns ErrNameTaken if a module with the same name is currently active.
+func (m *Manager) HandleConnect(name string, width, height, fps uint16) (uint64, error) {
+ id, err := m.registry.Register(name, width, height, fps)
+ if err != nil {
+ return 0, err
+ }
+
+ // Transition directly to Active (handshake completed at this point).
+ m.registry.SetState(id, StateActive)
+
+ // Fire callback outside the registry lock to avoid deadlocks in user code.
+ m.mu.RLock()
+ cb := m.onConnect
+ m.mu.RUnlock()
+
+ if cb != nil {
+ cb(id, name)
+ }
+
+ return id, nil
+}
+
+// HandleDisconnect processes a module disconnection (graceful or crash).
+// It transitions the module to Disconnected state, fires the OnDisconnect
+// callback, and removes the module from the registry so its name can be reused.
+//
+// If the module ID does not exist, this is a no-op.
+func (m *Manager) HandleDisconnect(id uint64) {
+ // Look up the module before unregistering so we have the name for the callback.
+ mod, exists := m.registry.Lookup(id)
+ if !exists {
+ return
+ }
+
+ m.registry.SetState(id, StateDisconnected)
+ m.registry.Unregister(id)
+
+ // Fire callback after state transition.
+ m.mu.RLock()
+ cb := m.onDisconnect
+ m.mu.RUnlock()
+
+ if cb != nil {
+ cb(id, mod.Name)
+ }
+}
+
+// OnConnect sets the callback fired when a module becomes active.
+// Only one callback can be set; subsequent calls replace the previous one.
+// Passing nil removes the callback.
+func (m *Manager) OnConnect(fn func(id uint64, name string)) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ m.onConnect = fn
+}
+
+// OnDisconnect sets the callback fired when a module disconnects.
+// Only one callback can be set; subsequent calls replace the previous one.
+// Passing nil removes the callback.
+func (m *Manager) OnDisconnect(fn func(id uint64, name string)) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ m.onDisconnect = fn
+}
+
+// Registry returns the underlying registry for lookups.
+func (m *Manager) Registry() *Registry {
+ return m.registry
+}
diff --git a/internal/conn/manager_test.go b/internal/conn/manager_test.go
new file mode 100644
index 0000000..d3c4327
--- /dev/null
+++ b/internal/conn/manager_test.go
@@ -0,0 +1,362 @@
+package conn
+
+import (
+ "errors"
+ "fmt"
+ "sync"
+ "sync/atomic"
+ "testing"
+)
+
+func TestNewManager(t *testing.T) {
+ m := NewManager(16)
+ if m.Registry() == nil {
+ t.Fatal("Registry() returned nil")
+ }
+ if m.Registry().Count() != 0 {
+ t.Errorf("initial count = %d, want 0", m.Registry().Count())
+ }
+}
+
+func TestManager_HandleConnect(t *testing.T) {
+ t.Run("basic connect", func(t *testing.T) {
+ m := NewManager(16)
+ id, err := m.HandleConnect("clock", 400, 120, 1)
+ if err != nil {
+ t.Fatalf("HandleConnect failed: %v", err)
+ }
+ if id == 0 {
+ t.Error("HandleConnect returned zero ID")
+ }
+
+ mod, ok := m.Registry().Lookup(id)
+ if !ok {
+ t.Fatal("module not found after connect")
+ }
+ if mod.State != StateActive {
+ t.Errorf("state = %v, want StateActive", mod.State)
+ }
+ if mod.Name != "clock" {
+ t.Errorf("name = %q, want %q", mod.Name, "clock")
+ }
+ if mod.Width != 400 || mod.Height != 120 {
+ t.Errorf("dimensions = %dx%d, want 400x120", mod.Width, mod.Height)
+ }
+ if mod.FPS != 1 {
+ t.Errorf("fps = %d, want 1", mod.FPS)
+ }
+ })
+
+ t.Run("fires OnConnect callback", func(t *testing.T) {
+ m := NewManager(16)
+
+ var callbackID uint64
+ var callbackName string
+ m.OnConnect(func(id uint64, name string) {
+ callbackID = id
+ callbackName = name
+ })
+
+ id, _ := m.HandleConnect("weather", 320, 240, 30)
+ if callbackID != id {
+ t.Errorf("callback ID = %d, want %d", callbackID, id)
+ }
+ if callbackName != "weather" {
+ t.Errorf("callback name = %q, want %q", callbackName, "weather")
+ }
+ })
+
+ t.Run("no callback when nil", func(t *testing.T) {
+ m := NewManager(16)
+ // No callback set — should not panic.
+ _, err := m.HandleConnect("mod", 100, 100, 30)
+ if err != nil {
+ t.Fatalf("HandleConnect failed: %v", err)
+ }
+ })
+
+ t.Run("max modules error", func(t *testing.T) {
+ m := NewManager(2)
+ _, _ = m.HandleConnect("a", 100, 100, 30)
+ _, _ = m.HandleConnect("b", 100, 100, 30)
+
+ _, err := m.HandleConnect("c", 100, 100, 30)
+ if !errors.Is(err, ErrMaxModules) {
+ t.Errorf("err = %v, want ErrMaxModules", err)
+ }
+ })
+
+ t.Run("name taken error", func(t *testing.T) {
+ m := NewManager(16)
+ _, _ = m.HandleConnect("clock", 400, 120, 1)
+
+ _, err := m.HandleConnect("clock", 400, 120, 1)
+ if !errors.Is(err, ErrNameTaken) {
+ t.Errorf("err = %v, want ErrNameTaken", err)
+ }
+ })
+
+ t.Run("callback not fired on error", func(t *testing.T) {
+ m := NewManager(1)
+ _, _ = m.HandleConnect("a", 100, 100, 30)
+
+ callbackFired := false
+ m.OnConnect(func(_ uint64, _ string) {
+ callbackFired = true
+ })
+
+ _, _ = m.HandleConnect("b", 100, 100, 30) // capacity exceeded
+ if callbackFired {
+ t.Error("OnConnect callback fired on error, should not")
+ }
+ })
+}
+
+func TestManager_HandleDisconnect(t *testing.T) {
+ t.Run("basic disconnect", func(t *testing.T) {
+ m := NewManager(16)
+ id, _ := m.HandleConnect("clock", 400, 120, 1)
+ m.HandleDisconnect(id)
+
+ _, ok := m.Registry().Lookup(id)
+ if ok {
+ t.Error("module still found after disconnect")
+ }
+ if m.Registry().Count() != 0 {
+ t.Errorf("count = %d, want 0", m.Registry().Count())
+ }
+ })
+
+ t.Run("fires OnDisconnect callback", func(t *testing.T) {
+ m := NewManager(16)
+ id, _ := m.HandleConnect("weather", 320, 240, 30)
+
+ var callbackID uint64
+ var callbackName string
+ m.OnDisconnect(func(id uint64, name string) {
+ callbackID = id
+ callbackName = name
+ })
+
+ m.HandleDisconnect(id)
+ if callbackID != id {
+ t.Errorf("callback ID = %d, want %d", callbackID, id)
+ }
+ if callbackName != "weather" {
+ t.Errorf("callback name = %q, want %q", callbackName, "weather")
+ }
+ })
+
+ t.Run("nonexistent ID is no-op", func(t *testing.T) {
+ m := NewManager(16)
+ callbackFired := false
+ m.OnDisconnect(func(_ uint64, _ string) {
+ callbackFired = true
+ })
+
+ m.HandleDisconnect(999) // should not panic
+ if callbackFired {
+ t.Error("OnDisconnect fired for nonexistent ID")
+ }
+ })
+
+ t.Run("double disconnect is safe", func(t *testing.T) {
+ m := NewManager(16)
+ id, _ := m.HandleConnect("mod", 100, 100, 30)
+
+ var count int
+ m.OnDisconnect(func(_ uint64, _ string) {
+ count++
+ })
+
+ m.HandleDisconnect(id)
+ m.HandleDisconnect(id) // second call should be no-op
+
+ if count != 1 {
+ t.Errorf("OnDisconnect fired %d times, want 1", count)
+ }
+ })
+}
+
+func TestManager_Reconnection(t *testing.T) {
+ t.Run("name reusable after disconnect", func(t *testing.T) {
+ m := NewManager(16)
+ id1, _ := m.HandleConnect("clock", 400, 120, 1)
+ m.HandleDisconnect(id1)
+
+ id2, err := m.HandleConnect("clock", 400, 120, 1)
+ if err != nil {
+ t.Fatalf("reconnect failed: %v", err)
+ }
+ if id2 <= id1 {
+ t.Errorf("reconnect ID not monotonic: old=%d, new=%d", id1, id2)
+ }
+ })
+
+ t.Run("reconnection fires callbacks", func(t *testing.T) {
+ m := NewManager(16)
+
+ var connectCount, disconnectCount int
+ m.OnConnect(func(_ uint64, _ string) { connectCount++ })
+ m.OnDisconnect(func(_ uint64, _ string) { disconnectCount++ })
+
+ id, _ := m.HandleConnect("clock", 400, 120, 1)
+ m.HandleDisconnect(id)
+ _, _ = m.HandleConnect("clock", 400, 120, 1)
+
+ if connectCount != 2 {
+ t.Errorf("connect count = %d, want 2", connectCount)
+ }
+ if disconnectCount != 1 {
+ t.Errorf("disconnect count = %d, want 1", disconnectCount)
+ }
+ })
+
+ t.Run("slot freed for capacity", func(t *testing.T) {
+ m := NewManager(2)
+ id1, _ := m.HandleConnect("a", 100, 100, 30)
+ _, _ = m.HandleConnect("b", 100, 100, 30)
+
+ // At capacity.
+ _, err := m.HandleConnect("c", 100, 100, 30)
+ if !errors.Is(err, ErrMaxModules) {
+ t.Fatalf("expected ErrMaxModules, got %v", err)
+ }
+
+ // Disconnect one — slot freed.
+ m.HandleDisconnect(id1)
+ _, err = m.HandleConnect("c", 100, 100, 30)
+ if err != nil {
+ t.Errorf("connect after disconnect failed: %v", err)
+ }
+ })
+}
+
+func TestManager_CallbackReplacement(t *testing.T) {
+ m := NewManager(16)
+
+ var firstCalled, secondCalled bool
+
+ m.OnConnect(func(_ uint64, _ string) { firstCalled = true })
+ m.OnConnect(func(_ uint64, _ string) { secondCalled = true })
+
+ m.HandleConnect("mod", 100, 100, 30)
+
+ if firstCalled {
+ t.Error("first callback called after replacement")
+ }
+ if !secondCalled {
+ t.Error("second callback not called")
+ }
+}
+
+func TestManager_NilCallback(t *testing.T) {
+ m := NewManager(16)
+
+ m.OnConnect(func(_ uint64, _ string) {})
+ m.OnConnect(nil) // remove callback
+
+ // Should not panic.
+ _, _ = m.HandleConnect("mod", 100, 100, 30)
+}
+
+func TestManager_ConcurrentAccess(t *testing.T) {
+ m := NewManager(500)
+
+ var connectCount atomic.Int64
+ var disconnectCount atomic.Int64
+
+ m.OnConnect(func(_ uint64, _ string) {
+ connectCount.Add(1)
+ })
+ m.OnDisconnect(func(_ uint64, _ string) {
+ disconnectCount.Add(1)
+ })
+
+ var wg sync.WaitGroup
+
+ // Concurrent connects.
+ ids := make([]uint64, 200)
+ var idMu sync.Mutex
+ for i := range 200 {
+ wg.Add(1)
+ go func(idx int) {
+ defer wg.Done()
+ name := fmt.Sprintf("module-%d", idx)
+ id, err := m.HandleConnect(name, 100, 100, 30)
+ if err == nil {
+ idMu.Lock()
+ ids[idx] = id
+ idMu.Unlock()
+ }
+ }(i)
+ }
+ wg.Wait()
+
+ // Concurrent disconnects.
+ for i := range 200 {
+ wg.Add(1)
+ go func(idx int) {
+ defer wg.Done()
+ idMu.Lock()
+ id := ids[idx]
+ idMu.Unlock()
+ if id != 0 {
+ m.HandleDisconnect(id)
+ }
+ }(i)
+ }
+ wg.Wait()
+
+ // After all disconnects, registry should be empty.
+ if m.Registry().Count() != 0 {
+ t.Errorf("count after all disconnects = %d, want 0", m.Registry().Count())
+ }
+
+ // Verify callback counts match.
+ connected := connectCount.Load()
+ disconnected := disconnectCount.Load()
+ if connected != disconnected {
+ t.Errorf("connect count (%d) != disconnect count (%d)", connected, disconnected)
+ }
+}
+
+func TestManager_Registry(t *testing.T) {
+ m := NewManager(16)
+ r := m.Registry()
+ if r == nil {
+ t.Fatal("Registry() returned nil")
+ }
+
+ // Verify it is the same instance.
+ id, _ := m.HandleConnect("mod", 100, 100, 30)
+ mod, ok := r.Lookup(id)
+ if !ok {
+ t.Fatal("registry lookup failed for module connected via manager")
+ }
+ if mod.Name != "mod" {
+ t.Errorf("name = %q, want %q", mod.Name, "mod")
+ }
+}
+
+func BenchmarkHandleConnect(b *testing.B) {
+ m := NewManager(b.N + 1)
+ b.ResetTimer()
+ for i := range b.N {
+ name := fmt.Sprintf("module-%d", i)
+ _, _ = m.HandleConnect(name, 100, 100, 30)
+ }
+}
+
+func BenchmarkHandleDisconnect(b *testing.B) {
+ m := NewManager(b.N + 1)
+ ids := make([]uint64, b.N)
+ for i := range b.N {
+ ids[i], _ = m.HandleConnect(fmt.Sprintf("module-%d", i), 100, 100, 30)
+ }
+
+ b.ResetTimer()
+ for i := range b.N {
+ m.HandleDisconnect(ids[i])
+ }
+}
diff --git a/internal/conn/module.go b/internal/conn/module.go
new file mode 100644
index 0000000..3956145
--- /dev/null
+++ b/internal/conn/module.go
@@ -0,0 +1,65 @@
+package conn
+
+import "time"
+
+// State represents the lifecycle state of a module connection.
+type State uint8
+
+const (
+ // StateConnecting indicates a connection has been initiated but the
+ // handshake has not yet started.
+ StateConnecting State = iota
+
+ // StateHandshaking indicates the handshake is in progress.
+ StateHandshaking
+
+ // StateActive indicates the module is fully connected and sending frames.
+ StateActive
+
+ // StateDisconnected indicates the connection was lost or gracefully closed.
+ StateDisconnected
+)
+
+// String returns a human-readable name for the state.
+func (s State) String() string {
+ switch s {
+ case StateConnecting:
+ return "connecting"
+ case StateHandshaking:
+ return "handshaking"
+ case StateActive:
+ return "active"
+ case StateDisconnected:
+ return "disconnected"
+ default:
+ return "unknown"
+ }
+}
+
+// Module holds metadata about a connected module.
+type Module struct {
+ // ID is the compositor-assigned unique identifier for this module.
+ // IDs are monotonically increasing and never reused within a process lifetime.
+ ID uint64
+
+ // Name is the human-readable module name (e.g., "clock", "weather").
+ Name string
+
+ // State is the current lifecycle state of the module connection.
+ State State
+
+ // Width is the frame width in pixels.
+ Width uint16
+
+ // Height is the frame height in pixels.
+ Height uint16
+
+ // FPS is the requested frame rate.
+ FPS uint16
+
+ // ConnectedAt is the time the module first connected.
+ ConnectedAt time.Time
+
+ // LastFrameAt is the time the most recent frame was received.
+ LastFrameAt time.Time
+}
diff --git a/internal/conn/registry.go b/internal/conn/registry.go
new file mode 100644
index 0000000..eb645ef
--- /dev/null
+++ b/internal/conn/registry.go
@@ -0,0 +1,159 @@
+package conn
+
+import (
+ "sync"
+ "time"
+)
+
+// Registry manages module ID allocation and name-to-ID mapping.
+// All methods are safe for concurrent access from multiple goroutines.
+type Registry struct {
+ mu sync.RWMutex
+ modules map[uint64]*Module
+ nameIndex map[string]uint64
+ nextID uint64
+ maxModules int
+}
+
+// NewRegistry creates a registry with the given maximum capacity.
+// The maxModules parameter must be positive; values less than 1 are
+// clamped to 1.
+func NewRegistry(maxModules int) *Registry {
+ if maxModules < 1 {
+ maxModules = 1
+ }
+ return &Registry{
+ modules: make(map[uint64]*Module),
+ nameIndex: make(map[string]uint64),
+ nextID: 1, // IDs start at 1, never 0
+ maxModules: maxModules,
+ }
+}
+
+// Register allocates a new module ID and stores the module.
+// Returns ErrMaxModules if at capacity.
+// Returns ErrNameTaken if a module with the same name is already active.
+func (r *Registry) Register(name string, width, height, fps uint16) (uint64, error) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if len(r.modules) >= r.maxModules {
+ return 0, ErrMaxModules
+ }
+
+ if _, exists := r.nameIndex[name]; exists {
+ return 0, ErrNameTaken
+ }
+
+ id := r.nextID
+ r.nextID++
+
+ mod := &Module{
+ ID: id,
+ Name: name,
+ State: StateConnecting,
+ Width: width,
+ Height: height,
+ FPS: fps,
+ ConnectedAt: time.Now(),
+ }
+
+ r.modules[id] = mod
+ r.nameIndex[name] = id
+
+ return id, nil
+}
+
+// Unregister removes a module from the registry, freeing its slot.
+// The module's name becomes available for reuse. If the module ID does
+// not exist, this is a no-op.
+func (r *Registry) Unregister(id uint64) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ mod, exists := r.modules[id]
+ if !exists {
+ return
+ }
+
+ delete(r.nameIndex, mod.Name)
+ delete(r.modules, id)
+}
+
+// Lookup returns the module by ID. Returns (nil, false) if not found.
+func (r *Registry) Lookup(id uint64) (*Module, bool) {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+
+ mod, exists := r.modules[id]
+ if !exists {
+ return nil, false
+ }
+
+ // Return a copy to prevent data races on the caller's side.
+ cp := *mod
+ return &cp, true
+}
+
+// LookupByName returns the module by name. Returns (nil, false) if not found.
+func (r *Registry) LookupByName(name string) (*Module, bool) {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+
+ id, exists := r.nameIndex[name]
+ if !exists {
+ return nil, false
+ }
+
+ mod, exists := r.modules[id]
+ if !exists {
+ return nil, false
+ }
+
+ cp := *mod
+ return &cp, true
+}
+
+// SetState updates a module's state. If the module ID does not exist,
+// this is a no-op.
+func (r *Registry) SetState(id uint64, state State) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if mod, exists := r.modules[id]; exists {
+ mod.State = state
+ }
+}
+
+// UpdateLastFrame updates the last frame timestamp for a module.
+// If the module ID does not exist, this is a no-op.
+func (r *Registry) UpdateLastFrame(id uint64, t time.Time) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if mod, exists := r.modules[id]; exists {
+ mod.LastFrameAt = t
+ }
+}
+
+// Count returns the number of currently registered modules.
+func (r *Registry) Count() int {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+
+ return len(r.modules)
+}
+
+// All returns a snapshot of all registered modules. The returned slice
+// contains copies; modifying them does not affect the registry.
+func (r *Registry) All() []*Module {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+
+ result := make([]*Module, 0, len(r.modules))
+ for _, mod := range r.modules {
+ cp := *mod
+ result = append(result, &cp)
+ }
+ return result
+}
diff --git a/internal/conn/registry_test.go b/internal/conn/registry_test.go
new file mode 100644
index 0000000..a7150c0
--- /dev/null
+++ b/internal/conn/registry_test.go
@@ -0,0 +1,475 @@
+package conn
+
+import (
+ "errors"
+ "fmt"
+ "sync"
+ "testing"
+ "time"
+)
+
+func TestNewRegistry(t *testing.T) {
+ t.Run("positive capacity", func(t *testing.T) {
+ r := NewRegistry(16)
+ if r.Count() != 0 {
+ t.Errorf("new registry count = %d, want 0", r.Count())
+ }
+ })
+
+ t.Run("zero capacity clamped to 1", func(t *testing.T) {
+ r := NewRegistry(0)
+ _, err := r.Register("a", 100, 100, 30)
+ if err != nil {
+ t.Fatalf("first register failed: %v", err)
+ }
+ _, err = r.Register("b", 100, 100, 30)
+ if !errors.Is(err, ErrMaxModules) {
+ t.Errorf("second register err = %v, want ErrMaxModules", err)
+ }
+ })
+
+ t.Run("negative capacity clamped to 1", func(t *testing.T) {
+ r := NewRegistry(-5)
+ _, err := r.Register("a", 100, 100, 30)
+ if err != nil {
+ t.Fatalf("first register failed: %v", err)
+ }
+ _, err = r.Register("b", 100, 100, 30)
+ if !errors.Is(err, ErrMaxModules) {
+ t.Errorf("second register err = %v, want ErrMaxModules", err)
+ }
+ })
+}
+
+func TestRegistry_Register(t *testing.T) {
+ t.Run("basic registration", func(t *testing.T) {
+ r := NewRegistry(16)
+ id, err := r.Register("clock", 400, 120, 1)
+ if err != nil {
+ t.Fatalf("register failed: %v", err)
+ }
+ if id != 1 {
+ t.Errorf("first ID = %d, want 1", id)
+ }
+ if r.Count() != 1 {
+ t.Errorf("count = %d, want 1", r.Count())
+ }
+ })
+
+ t.Run("monotonic IDs", func(t *testing.T) {
+ r := NewRegistry(16)
+ id1, _ := r.Register("a", 100, 100, 30)
+ id2, _ := r.Register("b", 100, 100, 30)
+ id3, _ := r.Register("c", 100, 100, 30)
+
+ if id1 >= id2 || id2 >= id3 {
+ t.Errorf("IDs not monotonic: %d, %d, %d", id1, id2, id3)
+ }
+ })
+
+ t.Run("IDs never reused after unregister", func(t *testing.T) {
+ r := NewRegistry(16)
+ id1, _ := r.Register("a", 100, 100, 30)
+ r.Unregister(id1)
+ id2, _ := r.Register("b", 100, 100, 30)
+
+ if id2 <= id1 {
+ t.Errorf("ID reused: id1=%d, id2=%d", id1, id2)
+ }
+ })
+
+ t.Run("max capacity", func(t *testing.T) {
+ r := NewRegistry(3)
+ for i := range 3 {
+ _, err := r.Register(fmt.Sprintf("mod%d", i), 100, 100, 30)
+ if err != nil {
+ t.Fatalf("register %d failed: %v", i, err)
+ }
+ }
+
+ _, err := r.Register("overflow", 100, 100, 30)
+ if !errors.Is(err, ErrMaxModules) {
+ t.Errorf("overflow err = %v, want ErrMaxModules", err)
+ }
+ })
+
+ t.Run("name collision", func(t *testing.T) {
+ r := NewRegistry(16)
+ _, err := r.Register("clock", 400, 120, 1)
+ if err != nil {
+ t.Fatalf("first register failed: %v", err)
+ }
+
+ _, err = r.Register("clock", 400, 120, 1)
+ if !errors.Is(err, ErrNameTaken) {
+ t.Errorf("duplicate name err = %v, want ErrNameTaken", err)
+ }
+ })
+
+ t.Run("name reusable after unregister", func(t *testing.T) {
+ r := NewRegistry(16)
+ id, _ := r.Register("clock", 400, 120, 1)
+ r.Unregister(id)
+
+ id2, err := r.Register("clock", 400, 120, 1)
+ if err != nil {
+ t.Fatalf("re-register failed: %v", err)
+ }
+ if id2 <= id {
+ t.Errorf("reused ID: old=%d, new=%d", id, id2)
+ }
+ })
+
+ t.Run("initial state is connecting", func(t *testing.T) {
+ r := NewRegistry(16)
+ id, _ := r.Register("mod", 100, 100, 30)
+ mod, ok := r.Lookup(id)
+ if !ok {
+ t.Fatal("lookup failed")
+ }
+ if mod.State != StateConnecting {
+ t.Errorf("initial state = %v, want StateConnecting", mod.State)
+ }
+ })
+
+ t.Run("metadata stored correctly", func(t *testing.T) {
+ r := NewRegistry(16)
+ before := time.Now()
+ id, _ := r.Register("weather", 320, 240, 60)
+ after := time.Now()
+
+ mod, ok := r.Lookup(id)
+ if !ok {
+ t.Fatal("lookup failed")
+ }
+ if mod.Name != "weather" {
+ t.Errorf("name = %q, want %q", mod.Name, "weather")
+ }
+ if mod.Width != 320 {
+ t.Errorf("width = %d, want 320", mod.Width)
+ }
+ if mod.Height != 240 {
+ t.Errorf("height = %d, want 240", mod.Height)
+ }
+ if mod.FPS != 60 {
+ t.Errorf("fps = %d, want 60", mod.FPS)
+ }
+ if mod.ConnectedAt.Before(before) || mod.ConnectedAt.After(after) {
+ t.Errorf("ConnectedAt = %v, want between %v and %v", mod.ConnectedAt, before, after)
+ }
+ })
+}
+
+func TestRegistry_Unregister(t *testing.T) {
+ t.Run("removes module", func(t *testing.T) {
+ r := NewRegistry(16)
+ id, _ := r.Register("mod", 100, 100, 30)
+ r.Unregister(id)
+
+ if r.Count() != 0 {
+ t.Errorf("count after unregister = %d, want 0", r.Count())
+ }
+ _, ok := r.Lookup(id)
+ if ok {
+ t.Error("lookup succeeded after unregister, want not found")
+ }
+ })
+
+ t.Run("nonexistent ID is no-op", func(t *testing.T) {
+ r := NewRegistry(16)
+ r.Unregister(999) // should not panic
+ })
+
+ t.Run("frees capacity slot", func(t *testing.T) {
+ r := NewRegistry(2)
+ id1, _ := r.Register("a", 100, 100, 30)
+ _, _ = r.Register("b", 100, 100, 30)
+
+ // At capacity.
+ _, err := r.Register("c", 100, 100, 30)
+ if !errors.Is(err, ErrMaxModules) {
+ t.Fatalf("expected ErrMaxModules, got %v", err)
+ }
+
+ // Unregister one — should free a slot.
+ r.Unregister(id1)
+ _, err = r.Register("c", 100, 100, 30)
+ if err != nil {
+ t.Errorf("register after unregister failed: %v", err)
+ }
+ })
+}
+
+func TestRegistry_Lookup(t *testing.T) {
+ t.Run("existing module", func(t *testing.T) {
+ r := NewRegistry(16)
+ id, _ := r.Register("mod", 100, 200, 30)
+ mod, ok := r.Lookup(id)
+ if !ok {
+ t.Fatal("lookup failed")
+ }
+ if mod.ID != id || mod.Name != "mod" {
+ t.Errorf("lookup returned wrong module: %+v", mod)
+ }
+ })
+
+ t.Run("nonexistent module", func(t *testing.T) {
+ r := NewRegistry(16)
+ _, ok := r.Lookup(42)
+ if ok {
+ t.Error("lookup succeeded for nonexistent ID")
+ }
+ })
+
+ t.Run("returns copy", func(t *testing.T) {
+ r := NewRegistry(16)
+ id, _ := r.Register("mod", 100, 100, 30)
+ mod1, _ := r.Lookup(id)
+ mod1.Name = "mutated"
+
+ mod2, _ := r.Lookup(id)
+ if mod2.Name == "mutated" {
+ t.Error("Lookup returned a reference instead of a copy")
+ }
+ })
+}
+
+func TestRegistry_LookupByName(t *testing.T) {
+ t.Run("existing module", func(t *testing.T) {
+ r := NewRegistry(16)
+ id, _ := r.Register("clock", 400, 120, 1)
+ mod, ok := r.LookupByName("clock")
+ if !ok {
+ t.Fatal("lookup by name failed")
+ }
+ if mod.ID != id {
+ t.Errorf("ID = %d, want %d", mod.ID, id)
+ }
+ })
+
+ t.Run("nonexistent name", func(t *testing.T) {
+ r := NewRegistry(16)
+ _, ok := r.LookupByName("nonexistent")
+ if ok {
+ t.Error("lookup by name succeeded for nonexistent name")
+ }
+ })
+}
+
+func TestRegistry_SetState(t *testing.T) {
+ t.Run("transitions state", func(t *testing.T) {
+ r := NewRegistry(16)
+ id, _ := r.Register("mod", 100, 100, 30)
+
+ r.SetState(id, StateHandshaking)
+ mod, _ := r.Lookup(id)
+ if mod.State != StateHandshaking {
+ t.Errorf("state = %v, want StateHandshaking", mod.State)
+ }
+
+ r.SetState(id, StateActive)
+ mod, _ = r.Lookup(id)
+ if mod.State != StateActive {
+ t.Errorf("state = %v, want StateActive", mod.State)
+ }
+
+ r.SetState(id, StateDisconnected)
+ mod, _ = r.Lookup(id)
+ if mod.State != StateDisconnected {
+ t.Errorf("state = %v, want StateDisconnected", mod.State)
+ }
+ })
+
+ t.Run("nonexistent ID is no-op", func(t *testing.T) {
+ r := NewRegistry(16)
+ r.SetState(999, StateActive) // should not panic
+ })
+}
+
+func TestRegistry_UpdateLastFrame(t *testing.T) {
+ t.Run("updates timestamp", func(t *testing.T) {
+ r := NewRegistry(16)
+ id, _ := r.Register("mod", 100, 100, 30)
+
+ ts := time.Now().Add(5 * time.Second)
+ r.UpdateLastFrame(id, ts)
+
+ mod, _ := r.Lookup(id)
+ if !mod.LastFrameAt.Equal(ts) {
+ t.Errorf("LastFrameAt = %v, want %v", mod.LastFrameAt, ts)
+ }
+ })
+
+ t.Run("nonexistent ID is no-op", func(t *testing.T) {
+ r := NewRegistry(16)
+ r.UpdateLastFrame(999, time.Now()) // should not panic
+ })
+}
+
+func TestRegistry_Count(t *testing.T) {
+ r := NewRegistry(16)
+ if r.Count() != 0 {
+ t.Fatalf("initial count = %d, want 0", r.Count())
+ }
+
+ r.Register("a", 100, 100, 30)
+ r.Register("b", 100, 100, 30)
+ if r.Count() != 2 {
+ t.Errorf("count = %d, want 2", r.Count())
+ }
+
+ id, _ := r.Register("c", 100, 100, 30)
+ r.Unregister(id)
+ if r.Count() != 2 {
+ t.Errorf("count after unregister = %d, want 2", r.Count())
+ }
+}
+
+func TestRegistry_All(t *testing.T) {
+ t.Run("empty registry", func(t *testing.T) {
+ r := NewRegistry(16)
+ all := r.All()
+ if len(all) != 0 {
+ t.Errorf("All() len = %d, want 0", len(all))
+ }
+ })
+
+ t.Run("returns all modules", func(t *testing.T) {
+ r := NewRegistry(16)
+ r.Register("a", 100, 100, 30)
+ r.Register("b", 200, 200, 60)
+ r.Register("c", 300, 300, 1)
+
+ all := r.All()
+ if len(all) != 3 {
+ t.Fatalf("All() len = %d, want 3", len(all))
+ }
+
+ names := make(map[string]bool)
+ for _, m := range all {
+ names[m.Name] = true
+ }
+ for _, want := range []string{"a", "b", "c"} {
+ if !names[want] {
+ t.Errorf("missing module %q in All() result", want)
+ }
+ }
+ })
+
+ t.Run("returns copies", func(t *testing.T) {
+ r := NewRegistry(16)
+ r.Register("mod", 100, 100, 30)
+
+ all := r.All()
+ all[0].Name = "mutated"
+
+ mod, _ := r.LookupByName("mod")
+ if mod == nil {
+ t.Error("original module name was mutated via All() result")
+ }
+ })
+}
+
+func TestRegistry_ConcurrentAccess(t *testing.T) {
+ r := NewRegistry(1000)
+ var wg sync.WaitGroup
+
+ // Concurrent writers (register).
+ for i := range 100 {
+ wg.Add(1)
+ go func(idx int) {
+ defer wg.Done()
+ name := fmt.Sprintf("module-%d", idx)
+ _, _ = r.Register(name, 100, 100, 30)
+ }(i)
+ }
+
+ // Concurrent readers (lookup, count, all).
+ for range 50 {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ _ = r.Count()
+ _ = r.All()
+ _, _ = r.Lookup(1)
+ _, _ = r.LookupByName("module-0")
+ }()
+ }
+
+ // Concurrent state updates.
+ for i := range 50 {
+ wg.Add(1)
+ go func(idx int) {
+ defer wg.Done()
+ r.SetState(uint64(idx+1), StateActive)
+ r.UpdateLastFrame(uint64(idx+1), time.Now())
+ }(i)
+ }
+
+ wg.Wait()
+
+ // Verify consistency: count matches actual registered modules.
+ count := r.Count()
+ all := r.All()
+ if count != len(all) {
+ t.Errorf("count=%d != len(All())=%d", count, len(all))
+ }
+}
+
+func TestState_String(t *testing.T) {
+ tests := []struct {
+ state State
+ want string
+ }{
+ {StateConnecting, "connecting"},
+ {StateHandshaking, "handshaking"},
+ {StateActive, "active"},
+ {StateDisconnected, "disconnected"},
+ {State(99), "unknown"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.want, func(t *testing.T) {
+ got := tt.state.String()
+ if got != tt.want {
+ t.Errorf("State(%d).String() = %q, want %q", tt.state, got, tt.want)
+ }
+ })
+ }
+}
+
+func BenchmarkRegister(b *testing.B) {
+ r := NewRegistry(b.N + 1)
+ b.ResetTimer()
+ for i := range b.N {
+ name := fmt.Sprintf("module-%d", i)
+ _, _ = r.Register(name, 100, 100, 30)
+ }
+}
+
+func BenchmarkLookup(b *testing.B) {
+ r := NewRegistry(1000)
+ for i := range 1000 {
+ r.Register(fmt.Sprintf("module-%d", i), 100, 100, 30)
+ }
+
+ b.ResetTimer()
+ for i := range b.N {
+ id := uint64(i%1000) + 1
+ _, _ = r.Lookup(id)
+ }
+}
+
+func BenchmarkLookupByName(b *testing.B) {
+ r := NewRegistry(1000)
+ names := make([]string, 1000)
+ for i := range 1000 {
+ names[i] = fmt.Sprintf("module-%d", i)
+ r.Register(names[i], 100, 100, 30)
+ }
+
+ b.ResetTimer()
+ for i := range b.N {
+ _, _ = r.LookupByName(names[i%1000])
+ }
+}
diff --git a/internal/flow/controller.go b/internal/flow/controller.go
new file mode 100644
index 0000000..c78a2c5
--- /dev/null
+++ b/internal/flow/controller.go
@@ -0,0 +1,198 @@
+package flow
+
+import (
+ "sync"
+ "time"
+)
+
+// defaultFPS is used when a module specifies 0 FPS (static content).
+const defaultFPS = 1
+
+// missedThreshold is the number of consecutive missed frames before the
+// controller halves the effective request rate for a module.
+const missedThreshold = 3
+
+// moduleState holds per-module frame pacing state.
+type moduleState struct {
+ targetFPS uint16 // module's preferred FPS (from handshake)
+ interval time.Duration // base interval: 1/targetFPS
+ lastRequest time.Time // when we last sent FrameRequest
+ lastDelivery time.Time // when we last received a frame
+ pending bool // true if FrameRequest sent but no frame received yet
+ missedCount int // consecutive missed requests (for adaptive rate)
+}
+
+// effectiveInterval returns the current interval accounting for adaptive rate
+// reduction. After missedThreshold consecutive misses, the interval doubles.
+func (ms *moduleState) effectiveInterval() time.Duration {
+ if ms.missedCount >= missedThreshold {
+ return ms.interval * 2
+ }
+ return ms.interval
+}
+
+// Option configures a Controller.
+type Option func(*Controller)
+
+// WithClock sets a custom time source for the controller. Intended for testing
+// to provide deterministic time progression without time.Sleep.
+func WithClock(fn func() time.Time) Option {
+ return func(c *Controller) {
+ c.now = fn
+ }
+}
+
+// Controller manages pull-based frame pacing for connected modules.
+// The compositor uses it to decide when to request frames from each module.
+// All methods are safe for concurrent access from multiple goroutines.
+type Controller struct {
+ mu sync.Mutex
+ modules map[uint64]*moduleState
+ now func() time.Time
+}
+
+// New creates a flow controller. Use WithClock to inject a custom time source
+// for testing.
+func New(opts ...Option) *Controller {
+ c := &Controller{
+ modules: make(map[uint64]*moduleState),
+ now: time.Now,
+ }
+ for _, opt := range opts {
+ opt(c)
+ }
+ return c
+}
+
+// AddModule registers a module with its preferred FPS.
+// If fps is 0, it defaults to 1 FPS (suitable for static content like a clock).
+func (c *Controller) AddModule(moduleID uint64, fps uint16) {
+ if fps == 0 {
+ fps = defaultFPS
+ }
+
+ interval := time.Second / time.Duration(fps)
+
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ c.modules[moduleID] = &moduleState{
+ targetFPS: fps,
+ interval: interval,
+ }
+}
+
+// RemoveModule unregisters a module. If the module ID does not exist, this is
+// a no-op.
+func (c *Controller) RemoveModule(moduleID uint64) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ delete(c.modules, moduleID)
+}
+
+// ShouldRequest returns true if it is time to request a frame from the
+// specified module. The compositor should call this on each render tick.
+//
+// Returns false if:
+// - The module ID is not registered
+// - A request is already pending (module has not responded yet)
+// - Not enough time has elapsed since the last request (respecting target FPS
+// and adaptive rate reduction)
+func (c *Controller) ShouldRequest(moduleID uint64) bool {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ ms, ok := c.modules[moduleID]
+ if !ok {
+ return false
+ }
+
+ if ms.pending {
+ return false
+ }
+
+ now := c.now()
+
+ // First request: always allow if no request has been made yet.
+ if ms.lastRequest.IsZero() {
+ return true
+ }
+
+ elapsed := now.Sub(ms.lastRequest)
+ return elapsed >= ms.effectiveInterval()
+}
+
+// FrameRequested marks that a FrameRequest was sent to the specified module.
+// Records the request time and sets the pending flag. If the module ID does not
+// exist, this is a no-op.
+func (c *Controller) FrameRequested(moduleID uint64) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ ms, ok := c.modules[moduleID]
+ if !ok {
+ return
+ }
+
+ ms.lastRequest = c.now()
+ ms.pending = true
+}
+
+// FrameDelivered marks that a frame was received from the specified module.
+// Resets the pending flag, records the delivery time, and clears the missed
+// count. If the module ID does not exist, this is a no-op.
+func (c *Controller) FrameDelivered(moduleID uint64) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ ms, ok := c.modules[moduleID]
+ if !ok {
+ return
+ }
+
+ ms.lastDelivery = c.now()
+ ms.pending = false
+ ms.missedCount = 0
+}
+
+// FrameMissed marks that the specified module failed to deliver a frame in
+// time. After missedThreshold (3) consecutive misses, the controller halves
+// the effective request rate by doubling the interval. This prevents wasting
+// bandwidth on slow or stuck modules. If the module ID does not exist, this
+// is a no-op.
+func (c *Controller) FrameMissed(moduleID uint64) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ ms, ok := c.modules[moduleID]
+ if !ok {
+ return
+ }
+
+ ms.missedCount++
+ ms.pending = false
+}
+
+// PendingModules returns the count of modules with pending (unanswered)
+// frame requests.
+func (c *Controller) PendingModules() int {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ count := 0
+ for _, ms := range c.modules {
+ if ms.pending {
+ count++
+ }
+ }
+ return count
+}
+
+// ModuleCount returns the total number of registered modules.
+func (c *Controller) ModuleCount() int {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ return len(c.modules)
+}
diff --git a/internal/flow/controller_test.go b/internal/flow/controller_test.go
new file mode 100644
index 0000000..fecc16f
--- /dev/null
+++ b/internal/flow/controller_test.go
@@ -0,0 +1,739 @@
+package flow
+
+import (
+ "sync"
+ "testing"
+ "time"
+)
+
+// testClock returns a clock function and an advance function.
+// The clock starts at a fixed epoch. Advance moves the clock forward
+// by the given duration. All time progression is deterministic.
+func testClock() (now func() time.Time, advance func(d time.Duration)) {
+ t := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
+ return func() time.Time {
+ return t
+ }, func(d time.Duration) {
+ t = t.Add(d)
+ }
+}
+
+func TestNewController(t *testing.T) {
+ c := New()
+ if c == nil {
+ t.Fatal("New() returned nil")
+ }
+ if c.ModuleCount() != 0 {
+ t.Errorf("ModuleCount() = %d, want 0", c.ModuleCount())
+ }
+ if c.PendingModules() != 0 {
+ t.Errorf("PendingModules() = %d, want 0", c.PendingModules())
+ }
+}
+
+func TestNewControllerWithClock(t *testing.T) {
+ clock, _ := testClock()
+ c := New(WithClock(clock))
+ if c == nil {
+ t.Fatal("New(WithClock) returned nil")
+ }
+ if c.now == nil {
+ t.Fatal("clock function not set")
+ }
+}
+
+func TestAddModule(t *testing.T) {
+ c := New()
+
+ c.AddModule(1, 60)
+ if c.ModuleCount() != 1 {
+ t.Errorf("ModuleCount() = %d, want 1", c.ModuleCount())
+ }
+
+ c.AddModule(2, 30)
+ if c.ModuleCount() != 2 {
+ t.Errorf("ModuleCount() = %d, want 2", c.ModuleCount())
+ }
+}
+
+func TestAddModuleZeroFPS(t *testing.T) {
+ clock, _ := testClock()
+ c := New(WithClock(clock))
+
+ c.AddModule(1, 0)
+
+ c.mu.Lock()
+ ms := c.modules[1]
+ c.mu.Unlock()
+
+ if ms.targetFPS != 1 {
+ t.Errorf("targetFPS = %d, want 1 (default for 0 FPS)", ms.targetFPS)
+ }
+ if ms.interval != time.Second {
+ t.Errorf("interval = %v, want %v", ms.interval, time.Second)
+ }
+}
+
+func TestAddModuleOverwrite(t *testing.T) {
+ c := New()
+
+ c.AddModule(1, 60)
+ c.AddModule(1, 30)
+
+ if c.ModuleCount() != 1 {
+ t.Errorf("ModuleCount() = %d, want 1 (overwrite, not duplicate)", c.ModuleCount())
+ }
+
+ c.mu.Lock()
+ ms := c.modules[1]
+ c.mu.Unlock()
+
+ if ms.targetFPS != 30 {
+ t.Errorf("targetFPS = %d, want 30 (overwritten)", ms.targetFPS)
+ }
+}
+
+func TestRemoveModule(t *testing.T) {
+ c := New()
+
+ c.AddModule(1, 60)
+ c.AddModule(2, 30)
+
+ c.RemoveModule(1)
+ if c.ModuleCount() != 1 {
+ t.Errorf("ModuleCount() = %d, want 1 after remove", c.ModuleCount())
+ }
+
+ c.RemoveModule(2)
+ if c.ModuleCount() != 0 {
+ t.Errorf("ModuleCount() = %d, want 0 after remove all", c.ModuleCount())
+ }
+}
+
+func TestRemoveUnknownModule(t *testing.T) {
+ c := New()
+
+ // Should not panic.
+ c.RemoveModule(999)
+
+ if c.ModuleCount() != 0 {
+ t.Errorf("ModuleCount() = %d, want 0", c.ModuleCount())
+ }
+}
+
+func TestShouldRequestFirstTime(t *testing.T) {
+ clock, _ := testClock()
+ c := New(WithClock(clock))
+
+ c.AddModule(1, 60)
+
+ if !c.ShouldRequest(1) {
+ t.Error("ShouldRequest() = false for first request, want true")
+ }
+}
+
+func TestShouldRequestUnknownModule(t *testing.T) {
+ c := New()
+
+ if c.ShouldRequest(999) {
+ t.Error("ShouldRequest(999) = true for unknown module, want false")
+ }
+}
+
+func TestShouldRequestRespectsInterval(t *testing.T) {
+ clock, advance := testClock()
+ c := New(WithClock(clock))
+
+ c.AddModule(1, 10) // 10 FPS = 100ms interval
+
+ // First request: allowed.
+ if !c.ShouldRequest(1) {
+ t.Fatal("ShouldRequest() = false for first request")
+ }
+
+ c.FrameRequested(1)
+ c.FrameDelivered(1)
+
+ // Immediately after delivery: not enough time has passed.
+ if c.ShouldRequest(1) {
+ t.Error("ShouldRequest() = true immediately after request, want false")
+ }
+
+ // Advance 50ms (half the interval): still too early.
+ advance(50 * time.Millisecond)
+ if c.ShouldRequest(1) {
+ t.Error("ShouldRequest() = true at 50ms (half interval), want false")
+ }
+
+ // Advance another 50ms (total 100ms = interval): should be allowed.
+ advance(50 * time.Millisecond)
+ if !c.ShouldRequest(1) {
+ t.Error("ShouldRequest() = false at 100ms (full interval), want true")
+ }
+}
+
+func TestShouldRequestBlocksWhilePending(t *testing.T) {
+ clock, advance := testClock()
+ c := New(WithClock(clock))
+
+ c.AddModule(1, 1) // 1 FPS = 1s interval
+
+ // First request.
+ if !c.ShouldRequest(1) {
+ t.Fatal("ShouldRequest() = false for first request")
+ }
+ c.FrameRequested(1)
+
+ // Even after the interval, pending blocks the request.
+ advance(2 * time.Second)
+ if c.ShouldRequest(1) {
+ t.Error("ShouldRequest() = true while pending, want false")
+ }
+
+ // After delivery, pending is cleared.
+ c.FrameDelivered(1)
+
+ // Now enough time has passed.
+ if !c.ShouldRequest(1) {
+ t.Error("ShouldRequest() = false after delivery + elapsed interval, want true")
+ }
+}
+
+func TestFrameRequestedSetsTime(t *testing.T) {
+ clock, advance := testClock()
+ c := New(WithClock(clock))
+
+ c.AddModule(1, 10) // 100ms interval
+
+ advance(500 * time.Millisecond)
+ c.FrameRequested(1)
+
+ c.mu.Lock()
+ ms := c.modules[1]
+ lr := ms.lastRequest
+ c.mu.Unlock()
+
+ expected := time.Date(2026, 1, 1, 0, 0, 0, 500_000_000, time.UTC)
+ if !lr.Equal(expected) {
+ t.Errorf("lastRequest = %v, want %v", lr, expected)
+ }
+}
+
+func TestFrameRequestedUnknownModule(t *testing.T) {
+ c := New()
+
+ // Should not panic.
+ c.FrameRequested(999)
+}
+
+func TestFrameDeliveredClearsPending(t *testing.T) {
+ clock, _ := testClock()
+ c := New(WithClock(clock))
+
+ c.AddModule(1, 60)
+
+ c.FrameRequested(1)
+ if c.PendingModules() != 1 {
+ t.Errorf("PendingModules() = %d after request, want 1", c.PendingModules())
+ }
+
+ c.FrameDelivered(1)
+ if c.PendingModules() != 0 {
+ t.Errorf("PendingModules() = %d after delivery, want 0", c.PendingModules())
+ }
+}
+
+func TestFrameDeliveredClearsMissedCount(t *testing.T) {
+ clock, _ := testClock()
+ c := New(WithClock(clock))
+
+ c.AddModule(1, 60)
+
+ // Accumulate misses.
+ c.FrameMissed(1)
+ c.FrameMissed(1)
+
+ c.mu.Lock()
+ countBefore := c.modules[1].missedCount
+ c.mu.Unlock()
+
+ if countBefore != 2 {
+ t.Errorf("missedCount = %d before delivery, want 2", countBefore)
+ }
+
+ // Delivery resets missed count.
+ c.FrameDelivered(1)
+
+ c.mu.Lock()
+ countAfter := c.modules[1].missedCount
+ c.mu.Unlock()
+
+ if countAfter != 0 {
+ t.Errorf("missedCount = %d after delivery, want 0", countAfter)
+ }
+}
+
+func TestFrameDeliveredUnknownModule(t *testing.T) {
+ c := New()
+
+ // Should not panic.
+ c.FrameDelivered(999)
+}
+
+func TestFrameMissedIncrements(t *testing.T) {
+ c := New()
+
+ c.AddModule(1, 60)
+
+ for i := 1; i <= 5; i++ {
+ c.FrameMissed(1)
+
+ c.mu.Lock()
+ count := c.modules[1].missedCount
+ c.mu.Unlock()
+
+ if count != i {
+ t.Errorf("missedCount after %d misses = %d, want %d", i, count, i)
+ }
+ }
+}
+
+func TestFrameMissedClearsPending(t *testing.T) {
+ clock, _ := testClock()
+ c := New(WithClock(clock))
+
+ c.AddModule(1, 60)
+
+ c.FrameRequested(1)
+ if c.PendingModules() != 1 {
+ t.Fatal("PendingModules() != 1 after request")
+ }
+
+ c.FrameMissed(1)
+ if c.PendingModules() != 0 {
+ t.Errorf("PendingModules() = %d after miss, want 0", c.PendingModules())
+ }
+}
+
+func TestFrameMissedUnknownModule(t *testing.T) {
+ c := New()
+
+ // Should not panic.
+ c.FrameMissed(999)
+}
+
+func TestAdaptiveRateReduction(t *testing.T) {
+ clock, advance := testClock()
+ c := New(WithClock(clock))
+
+ c.AddModule(1, 10) // 10 FPS = 100ms base interval
+
+ // Request and deliver first frame to establish lastRequest.
+ c.FrameRequested(1)
+ c.FrameDelivered(1)
+
+ // Accumulate 3 misses (threshold).
+ c.FrameRequested(1)
+ c.FrameMissed(1)
+ c.FrameRequested(1)
+ c.FrameMissed(1)
+ c.FrameRequested(1)
+ c.FrameMissed(1)
+
+ // After 3 misses, effective interval should be 200ms (doubled).
+ // Advance 100ms from last request: should NOT be ready.
+ advance(100 * time.Millisecond)
+ if c.ShouldRequest(1) {
+ t.Error("ShouldRequest() = true at 100ms (base interval) after 3 misses, want false (doubled to 200ms)")
+ }
+
+ // Advance another 100ms (total 200ms): now should be ready.
+ advance(100 * time.Millisecond)
+ if !c.ShouldRequest(1) {
+ t.Error("ShouldRequest() = false at 200ms (doubled interval) after 3 misses, want true")
+ }
+}
+
+func TestAdaptiveRateResetOnDelivery(t *testing.T) {
+ clock, advance := testClock()
+ c := New(WithClock(clock))
+
+ c.AddModule(1, 10) // 10 FPS = 100ms base interval
+
+ c.FrameRequested(1)
+ c.FrameDelivered(1)
+
+ // Accumulate 3 misses.
+ c.FrameRequested(1)
+ c.FrameMissed(1)
+ c.FrameRequested(1)
+ c.FrameMissed(1)
+ c.FrameRequested(1)
+ c.FrameMissed(1)
+
+ // Now deliver a successful frame. This resets missedCount.
+ c.FrameRequested(1)
+ c.FrameDelivered(1)
+
+ // After delivery, effective interval should be back to 100ms.
+ advance(100 * time.Millisecond)
+ if !c.ShouldRequest(1) {
+ t.Error("ShouldRequest() = false at 100ms after miss reset, want true (back to base interval)")
+ }
+}
+
+func TestAdaptiveRateBelowThreshold(t *testing.T) {
+ clock, advance := testClock()
+ c := New(WithClock(clock))
+
+ c.AddModule(1, 10) // 10 FPS = 100ms interval
+
+ c.FrameRequested(1)
+ c.FrameDelivered(1)
+
+ // Only 2 misses (below threshold of 3).
+ c.FrameRequested(1)
+ c.FrameMissed(1)
+ c.FrameRequested(1)
+ c.FrameMissed(1)
+
+ // Effective interval should still be 100ms (not doubled).
+ advance(100 * time.Millisecond)
+ if !c.ShouldRequest(1) {
+ t.Error("ShouldRequest() = false at 100ms with only 2 misses, want true (below threshold)")
+ }
+}
+
+func TestPendingModules(t *testing.T) {
+ clock, _ := testClock()
+ c := New(WithClock(clock))
+
+ c.AddModule(1, 60)
+ c.AddModule(2, 30)
+ c.AddModule(3, 10)
+
+ if c.PendingModules() != 0 {
+ t.Errorf("PendingModules() = %d initially, want 0", c.PendingModules())
+ }
+
+ c.FrameRequested(1)
+ c.FrameRequested(2)
+ if c.PendingModules() != 2 {
+ t.Errorf("PendingModules() = %d after 2 requests, want 2", c.PendingModules())
+ }
+
+ c.FrameDelivered(1)
+ if c.PendingModules() != 1 {
+ t.Errorf("PendingModules() = %d after 1 delivery, want 1", c.PendingModules())
+ }
+
+ c.FrameMissed(2)
+ if c.PendingModules() != 0 {
+ t.Errorf("PendingModules() = %d after miss, want 0", c.PendingModules())
+ }
+}
+
+func TestModuleCount(t *testing.T) {
+ c := New()
+
+ tests := []struct {
+ action string
+ fn func()
+ want int
+ }{
+ {"initial", func() {}, 0},
+ {"add 1", func() { c.AddModule(1, 60) }, 1},
+ {"add 2", func() { c.AddModule(2, 30) }, 2},
+ {"add 3", func() { c.AddModule(3, 10) }, 3},
+ {"remove 2", func() { c.RemoveModule(2) }, 2},
+ {"remove unknown", func() { c.RemoveModule(999) }, 2},
+ {"remove 1", func() { c.RemoveModule(1) }, 1},
+ {"remove 3", func() { c.RemoveModule(3) }, 0},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.action, func(t *testing.T) {
+ tt.fn()
+ if got := c.ModuleCount(); got != tt.want {
+ t.Errorf("ModuleCount() = %d, want %d", got, tt.want)
+ }
+ })
+ }
+}
+
+func TestHighFPSInterval(t *testing.T) {
+ clock, advance := testClock()
+ c := New(WithClock(clock))
+
+ c.AddModule(1, 120) // 120 FPS ~ 8.33ms interval
+
+ c.FrameRequested(1)
+ c.FrameDelivered(1)
+
+ // At 8ms: should not be ready yet.
+ advance(8 * time.Millisecond)
+ if c.ShouldRequest(1) {
+ t.Error("ShouldRequest() = true at 8ms for 120 FPS, want false")
+ }
+
+ // At 9ms: should be ready (8.33ms interval).
+ advance(1 * time.Millisecond)
+ if !c.ShouldRequest(1) {
+ t.Error("ShouldRequest() = false at 9ms for 120 FPS, want true")
+ }
+}
+
+func TestMultipleModulesIndependent(t *testing.T) {
+ clock, advance := testClock()
+ c := New(WithClock(clock))
+
+ c.AddModule(1, 10) // 100ms interval
+ c.AddModule(2, 5) // 200ms interval
+
+ // Both should allow first request.
+ if !c.ShouldRequest(1) {
+ t.Error("module 1: ShouldRequest() = false for first request")
+ }
+ if !c.ShouldRequest(2) {
+ t.Error("module 2: ShouldRequest() = false for first request")
+ }
+
+ c.FrameRequested(1)
+ c.FrameRequested(2)
+ c.FrameDelivered(1)
+ c.FrameDelivered(2)
+
+ // At 100ms: module 1 ready, module 2 not.
+ advance(100 * time.Millisecond)
+ if !c.ShouldRequest(1) {
+ t.Error("module 1: ShouldRequest() = false at 100ms, want true")
+ }
+ if c.ShouldRequest(2) {
+ t.Error("module 2: ShouldRequest() = true at 100ms (needs 200ms), want false")
+ }
+
+ // At 200ms: both ready.
+ advance(100 * time.Millisecond)
+ if !c.ShouldRequest(1) {
+ t.Error("module 1: ShouldRequest() = false at 200ms, want true")
+ }
+ if !c.ShouldRequest(2) {
+ t.Error("module 2: ShouldRequest() = false at 200ms, want true")
+ }
+}
+
+func TestEffectiveInterval(t *testing.T) {
+ tests := []struct {
+ name string
+ missedCount int
+ baseMs int
+ wantMs int
+ }{
+ {"no misses", 0, 100, 100},
+ {"1 miss", 1, 100, 100},
+ {"2 misses", 2, 100, 100},
+ {"3 misses (threshold)", 3, 100, 200},
+ {"4 misses", 4, 100, 200},
+ {"10 misses", 10, 100, 200},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ ms := &moduleState{
+ interval: time.Duration(tt.baseMs) * time.Millisecond,
+ missedCount: tt.missedCount,
+ }
+ got := ms.effectiveInterval()
+ want := time.Duration(tt.wantMs) * time.Millisecond
+ if got != want {
+ t.Errorf("effectiveInterval() = %v, want %v", got, want)
+ }
+ })
+ }
+}
+
+func TestFullLifecycle(t *testing.T) {
+ clock, advance := testClock()
+ c := New(WithClock(clock))
+
+ // Register.
+ c.AddModule(1, 10) // 100ms interval
+ if c.ModuleCount() != 1 {
+ t.Fatal("ModuleCount() != 1 after add")
+ }
+
+ // First request cycle.
+ if !c.ShouldRequest(1) {
+ t.Fatal("ShouldRequest() = false for first request")
+ }
+ c.FrameRequested(1)
+ if c.PendingModules() != 1 {
+ t.Fatal("PendingModules() != 1 after request")
+ }
+ c.FrameDelivered(1)
+ if c.PendingModules() != 0 {
+ t.Fatal("PendingModules() != 0 after delivery")
+ }
+
+ // Second request after interval.
+ advance(100 * time.Millisecond)
+ if !c.ShouldRequest(1) {
+ t.Fatal("ShouldRequest() = false after interval")
+ }
+ c.FrameRequested(1)
+ c.FrameDelivered(1)
+
+ // Miss some frames.
+ advance(100 * time.Millisecond)
+ c.FrameRequested(1)
+ c.FrameMissed(1)
+ advance(100 * time.Millisecond)
+ c.FrameRequested(1)
+ c.FrameMissed(1)
+ advance(100 * time.Millisecond)
+ c.FrameRequested(1)
+ c.FrameMissed(1)
+
+ // Now at doubled interval. 100ms should not suffice.
+ advance(100 * time.Millisecond)
+ if c.ShouldRequest(1) {
+ t.Error("ShouldRequest() = true at base interval after 3 misses, want false")
+ }
+
+ // 200ms should work.
+ advance(100 * time.Millisecond)
+ if !c.ShouldRequest(1) {
+ t.Error("ShouldRequest() = false at doubled interval, want true")
+ }
+
+ // Deliver to reset.
+ c.FrameRequested(1)
+ c.FrameDelivered(1)
+
+ // Back to normal interval.
+ advance(100 * time.Millisecond)
+ if !c.ShouldRequest(1) {
+ t.Error("ShouldRequest() = false after reset, want true")
+ }
+
+ // Unregister.
+ c.RemoveModule(1)
+ if c.ModuleCount() != 0 {
+ t.Fatal("ModuleCount() != 0 after remove")
+ }
+ if c.ShouldRequest(1) {
+ t.Error("ShouldRequest() = true for removed module, want false")
+ }
+}
+
+func TestConcurrentAccess(t *testing.T) {
+ c := New()
+
+ const numModules = 100
+ const iterations = 1000
+
+ // Add modules.
+ for i := uint64(0); i < numModules; i++ {
+ c.AddModule(i, uint16(10+i%50))
+ }
+
+ var wg sync.WaitGroup
+
+ // Concurrent ShouldRequest calls.
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ for range iterations {
+ for i := uint64(0); i < numModules; i++ {
+ c.ShouldRequest(i)
+ }
+ }
+ }()
+
+ // Concurrent FrameRequested/Delivered calls.
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ for range iterations {
+ for i := uint64(0); i < numModules; i++ {
+ c.FrameRequested(i)
+ c.FrameDelivered(i)
+ }
+ }
+ }()
+
+ // Concurrent FrameMissed calls.
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ for range iterations {
+ for i := uint64(0); i < numModules; i++ {
+ c.FrameMissed(i)
+ }
+ }
+ }()
+
+ // Concurrent PendingModules/ModuleCount calls.
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ for range iterations {
+ c.PendingModules()
+ c.ModuleCount()
+ }
+ }()
+
+ // Concurrent Add/Remove.
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ for range iterations {
+ c.AddModule(numModules+1, 60)
+ c.RemoveModule(numModules + 1)
+ }
+ }()
+
+ wg.Wait()
+}
+
+func BenchmarkShouldRequest(b *testing.B) {
+ clock, _ := testClock()
+ c := New(WithClock(clock))
+
+ c.AddModule(1, 60)
+
+ b.ResetTimer()
+ for range b.N {
+ c.ShouldRequest(1)
+ }
+}
+
+func BenchmarkShouldRequestMultipleModules(b *testing.B) {
+ clock, _ := testClock()
+ c := New(WithClock(clock))
+
+ const numModules = 16
+ for i := uint64(1); i <= numModules; i++ {
+ c.AddModule(i, uint16(10*i))
+ }
+
+ b.ResetTimer()
+ for range b.N {
+ for i := uint64(1); i <= numModules; i++ {
+ c.ShouldRequest(i)
+ }
+ }
+}
+
+func BenchmarkFrameRequestDeliverCycle(b *testing.B) {
+ clock, _ := testClock()
+ c := New(WithClock(clock))
+
+ c.AddModule(1, 60)
+
+ b.ResetTimer()
+ for range b.N {
+ c.FrameRequested(1)
+ c.FrameDelivered(1)
+ }
+}
diff --git a/internal/flow/doc.go b/internal/flow/doc.go
new file mode 100644
index 0000000..1f1687f
--- /dev/null
+++ b/internal/flow/doc.go
@@ -0,0 +1,17 @@
+// Package flow provides pull-based frame pacing for the compose library.
+//
+// The compositor uses a [Controller] to decide when to request frames from each
+// connected module. This implements the Wayland frame callback pattern: modules
+// render only when the compositor asks, preventing them from flooding the
+// compositor with unsolicited frames.
+//
+// The controller is passive (no goroutines, no timers). The compositor's render
+// loop polls [Controller.ShouldRequest] on each tick and calls
+// [Controller.FrameRequested] after sending a request to a module. When a frame
+// arrives, the compositor calls [Controller.FrameDelivered]. If a module fails
+// to respond in time, [Controller.FrameMissed] adaptively reduces the effective
+// request rate.
+//
+// All exported methods are safe for concurrent use from multiple goroutines.
+// The package is standalone with no internal dependencies.
+package flow
diff --git a/internal/protocol/doc.go b/internal/protocol/doc.go
new file mode 100644
index 0000000..f645ec2
--- /dev/null
+++ b/internal/protocol/doc.go
@@ -0,0 +1,15 @@
+// Package protocol defines the wire format for gogpu/compose inter-process
+// communication. It is the leaf package in the internal dependency graph —
+// imported by transport implementations but importing nothing internal itself.
+//
+// The protocol is built around a fixed 64-byte binary frame header that
+// precedes every message on the wire. Headers use little-endian byte order
+// and are designed to be cache-line aligned on modern hardware.
+//
+// All encode/decode functions operate on caller-provided byte buffers to
+// achieve zero allocations on the hot path. This is critical for sustained
+// 60 FPS frame delivery.
+//
+// Wire format version: 1
+// Magic bytes: 0x43 0x4F 0x4D 0x50 ("COMP")
+package protocol
diff --git a/internal/protocol/handshake.go b/internal/protocol/handshake.go
new file mode 100644
index 0000000..ce2af8d
--- /dev/null
+++ b/internal/protocol/handshake.go
@@ -0,0 +1,326 @@
+package protocol
+
+import (
+ "encoding/binary"
+ "errors"
+ "fmt"
+)
+
+// HandshakeSize is the fixed size in bytes of both HelloMsg and WelcomeMsg.
+const HandshakeSize = 128
+
+// TransportType identifies the preferred or granted transport mechanism.
+type TransportType uint8
+
+const (
+ // TransportSocket indicates Unix domain socket transport.
+ TransportSocket TransportType = 0
+
+ // TransportShm indicates shared memory transport.
+ TransportShm TransportType = 1
+)
+
+// String returns the human-readable name of the transport type.
+func (t TransportType) String() string {
+ switch t {
+ case TransportSocket:
+ return "Socket"
+ case TransportShm:
+ return "SharedMemory"
+ default:
+ return fmt.Sprintf("TransportType(%d)", uint8(t))
+ }
+}
+
+// HelloMsg is sent by the module to the compositor during the handshake phase.
+// Fixed 128 bytes.
+//
+// Layout: Magic(4) + Version(2) + Name(64) + Width(2) + Height(2) +
+// PreferredFPS(2) + Transport(1) + Reserved(51) = 128
+type HelloMsg struct {
+ // Magic must be "COMP" (0x43, 0x4F, 0x4D, 0x50).
+ Magic [4]byte
+
+ // Version is the protocol version the module supports.
+ Version uint16
+
+ // Name is the null-terminated human-readable module name (max 63 chars + NUL).
+ Name [64]byte
+
+ // Width is the initial frame width in pixels.
+ Width uint16
+
+ // Height is the initial frame height in pixels.
+ Height uint16
+
+ // PreferredFPS is the module's preferred frame rate (e.g., 1 for clock, 60 for animation).
+ PreferredFPS uint16
+
+ // Transport is the module's preferred transport mechanism.
+ Transport TransportType
+
+ // Reserved is padding for future use. Must be zero.
+ // Size: 128 - 4 - 2 - 64 - 2 - 2 - 2 - 1 = 51
+ Reserved [51]byte
+}
+
+// HelloMsg field offsets.
+const (
+ helloOffMagic = 0
+ helloOffVersion = 4
+ helloOffName = 6
+ helloOffWidth = 70
+ helloOffHeight = 72
+ helloOffPreferredFPS = 74
+ helloOffTransport = 76
+ helloOffReserved = 77
+)
+
+// WelcomeMsg is sent by the compositor to the module after accepting the handshake.
+// Fixed 128 bytes.
+//
+// Layout: Magic(4) + Version(2) + ModuleID(8) + Accepted(1) + Transport(1) +
+// MinVersion(2) + MaxVersion(2) + Reserved(108) = 128
+type WelcomeMsg struct {
+ // Magic must be "COMP" (0x43, 0x4F, 0x4D, 0x50).
+ Magic [4]byte
+
+ // Version is the protocol version the compositor selected.
+ Version uint16
+
+ // ModuleID is the compositor-assigned unique identifier for this module.
+ ModuleID uint64
+
+ // Accepted is 1 if the connection was accepted, 0 if rejected.
+ Accepted uint8
+
+ // Transport is the transport mechanism the compositor granted.
+ Transport TransportType
+
+ // MinVersion is the minimum protocol version the compositor supports.
+ MinVersion uint16
+
+ // MaxVersion is the maximum protocol version the compositor supports.
+ MaxVersion uint16
+
+ // Reserved is padding for future use. Must be zero.
+ // Size: 128 - 4 - 2 - 8 - 1 - 1 - 2 - 2 = 108
+ Reserved [108]byte
+}
+
+// WelcomeMsg field offsets.
+const (
+ welcomeOffMagic = 0
+ welcomeOffVersion = 4
+ welcomeOffModuleID = 6
+ welcomeOffAccepted = 14
+ welcomeOffTransport = 15
+ welcomeOffMinVersion = 16
+ welcomeOffMaxVersion = 18
+ welcomeOffReserved = 20
+)
+
+// Handshake errors.
+var (
+ // ErrHandshakeBufTooSmall is returned when the buffer is smaller than HandshakeSize.
+ ErrHandshakeBufTooSmall = errors.New("protocol: handshake buffer too small (need 128 bytes)")
+
+ // ErrHandshakeInvalidMagic is returned when handshake magic bytes are wrong.
+ ErrHandshakeInvalidMagic = errors.New("protocol: handshake invalid magic (expected 0x434F4D50)")
+
+ // ErrRejected is returned when the compositor rejected the connection.
+ ErrRejected = errors.New("protocol: connection rejected by compositor")
+)
+
+// EncodeHello writes a HelloMsg into buf using little-endian byte order.
+// buf must be at least HandshakeSize (128) bytes. Never allocates.
+func EncodeHello(msg *HelloMsg, buf []byte) error {
+ if len(buf) < HandshakeSize {
+ return ErrHandshakeBufTooSmall
+ }
+
+ // Zero the buffer to ensure reserved bytes are clean.
+ clear(buf[:HandshakeSize])
+
+ // Magic
+ buf[helloOffMagic] = msg.Magic[0]
+ buf[helloOffMagic+1] = msg.Magic[1]
+ buf[helloOffMagic+2] = msg.Magic[2]
+ buf[helloOffMagic+3] = msg.Magic[3]
+
+ // Version
+ binary.LittleEndian.PutUint16(buf[helloOffVersion:], msg.Version)
+
+ // Name (64 bytes, null-terminated)
+ copy(buf[helloOffName:helloOffName+64], msg.Name[:])
+
+ // Width
+ binary.LittleEndian.PutUint16(buf[helloOffWidth:], msg.Width)
+
+ // Height
+ binary.LittleEndian.PutUint16(buf[helloOffHeight:], msg.Height)
+
+ // PreferredFPS
+ binary.LittleEndian.PutUint16(buf[helloOffPreferredFPS:], msg.PreferredFPS)
+
+ // Transport
+ buf[helloOffTransport] = uint8(msg.Transport)
+
+ // Reserved already zeroed by clear()
+
+ return nil
+}
+
+// DecodeHello reads a HelloMsg from buf and validates the magic bytes.
+// buf must be at least HandshakeSize (128) bytes. Never allocates.
+func DecodeHello(buf []byte) (HelloMsg, error) {
+ if len(buf) < HandshakeSize {
+ return HelloMsg{}, ErrHandshakeBufTooSmall
+ }
+
+ var msg HelloMsg
+
+ // Magic
+ msg.Magic[0] = buf[helloOffMagic]
+ msg.Magic[1] = buf[helloOffMagic+1]
+ msg.Magic[2] = buf[helloOffMagic+2]
+ msg.Magic[3] = buf[helloOffMagic+3]
+
+ if msg.Magic != Magic {
+ return HelloMsg{}, fmt.Errorf("%w: got [0x%02X 0x%02X 0x%02X 0x%02X]",
+ ErrHandshakeInvalidMagic, msg.Magic[0], msg.Magic[1], msg.Magic[2], msg.Magic[3])
+ }
+
+ // Version
+ msg.Version = binary.LittleEndian.Uint16(buf[helloOffVersion:])
+
+ // Name
+ copy(msg.Name[:], buf[helloOffName:helloOffName+64])
+
+ // Width
+ msg.Width = binary.LittleEndian.Uint16(buf[helloOffWidth:])
+
+ // Height
+ msg.Height = binary.LittleEndian.Uint16(buf[helloOffHeight:])
+
+ // PreferredFPS
+ msg.PreferredFPS = binary.LittleEndian.Uint16(buf[helloOffPreferredFPS:])
+
+ // Transport
+ msg.Transport = TransportType(buf[helloOffTransport])
+
+ // Reserved
+ copy(msg.Reserved[:], buf[helloOffReserved:helloOffReserved+51])
+
+ return msg, nil
+}
+
+// EncodeWelcome writes a WelcomeMsg into buf using little-endian byte order.
+// buf must be at least HandshakeSize (128) bytes. Never allocates.
+func EncodeWelcome(msg *WelcomeMsg, buf []byte) error {
+ if len(buf) < HandshakeSize {
+ return ErrHandshakeBufTooSmall
+ }
+
+ // Zero the buffer to ensure reserved bytes are clean.
+ clear(buf[:HandshakeSize])
+
+ // Magic
+ buf[welcomeOffMagic] = msg.Magic[0]
+ buf[welcomeOffMagic+1] = msg.Magic[1]
+ buf[welcomeOffMagic+2] = msg.Magic[2]
+ buf[welcomeOffMagic+3] = msg.Magic[3]
+
+ // Version
+ binary.LittleEndian.PutUint16(buf[welcomeOffVersion:], msg.Version)
+
+ // ModuleID
+ binary.LittleEndian.PutUint64(buf[welcomeOffModuleID:], msg.ModuleID)
+
+ // Accepted
+ buf[welcomeOffAccepted] = msg.Accepted
+
+ // Transport
+ buf[welcomeOffTransport] = uint8(msg.Transport)
+
+ // MinVersion
+ binary.LittleEndian.PutUint16(buf[welcomeOffMinVersion:], msg.MinVersion)
+
+ // MaxVersion
+ binary.LittleEndian.PutUint16(buf[welcomeOffMaxVersion:], msg.MaxVersion)
+
+ // Reserved already zeroed by clear()
+
+ return nil
+}
+
+// DecodeWelcome reads a WelcomeMsg from buf and validates the magic bytes.
+// buf must be at least HandshakeSize (128) bytes. Never allocates.
+func DecodeWelcome(buf []byte) (WelcomeMsg, error) {
+ if len(buf) < HandshakeSize {
+ return WelcomeMsg{}, ErrHandshakeBufTooSmall
+ }
+
+ var msg WelcomeMsg
+
+ // Magic
+ msg.Magic[0] = buf[welcomeOffMagic]
+ msg.Magic[1] = buf[welcomeOffMagic+1]
+ msg.Magic[2] = buf[welcomeOffMagic+2]
+ msg.Magic[3] = buf[welcomeOffMagic+3]
+
+ if msg.Magic != Magic {
+ return WelcomeMsg{}, fmt.Errorf("%w: got [0x%02X 0x%02X 0x%02X 0x%02X]",
+ ErrHandshakeInvalidMagic, msg.Magic[0], msg.Magic[1], msg.Magic[2], msg.Magic[3])
+ }
+
+ // Version
+ msg.Version = binary.LittleEndian.Uint16(buf[welcomeOffVersion:])
+
+ // ModuleID
+ msg.ModuleID = binary.LittleEndian.Uint64(buf[welcomeOffModuleID:])
+
+ // Accepted
+ msg.Accepted = buf[welcomeOffAccepted]
+
+ // Transport
+ msg.Transport = TransportType(buf[welcomeOffTransport])
+
+ // MinVersion
+ msg.MinVersion = binary.LittleEndian.Uint16(buf[welcomeOffMinVersion:])
+
+ // MaxVersion
+ msg.MaxVersion = binary.LittleEndian.Uint16(buf[welcomeOffMaxVersion:])
+
+ // Reserved
+ copy(msg.Reserved[:], buf[welcomeOffReserved:welcomeOffReserved+108])
+
+ return msg, nil
+}
+
+// SetName copies a string name into the HelloMsg's fixed-size Name field.
+// The name is null-terminated. If name is longer than 63 bytes, it is truncated.
+func SetName(msg *HelloMsg, name string) {
+ // Clear the name field first.
+ clear(msg.Name[:])
+
+ // Copy up to 63 bytes (leaving room for null terminator).
+ maxLen := len(msg.Name) - 1
+ n := len(name)
+ if n > maxLen {
+ n = maxLen
+ }
+ copy(msg.Name[:n], name)
+ // The field is already zero-filled, so null terminator is implicit.
+}
+
+// GetName reads the null-terminated string from the HelloMsg's Name field.
+func GetName(msg *HelloMsg) string {
+ for i, b := range msg.Name {
+ if b == 0 {
+ return string(msg.Name[:i])
+ }
+ }
+ // No null terminator found — return all 64 bytes as string.
+ return string(msg.Name[:])
+}
diff --git a/internal/protocol/handshake_test.go b/internal/protocol/handshake_test.go
new file mode 100644
index 0000000..46a08f7
--- /dev/null
+++ b/internal/protocol/handshake_test.go
@@ -0,0 +1,566 @@
+package protocol
+
+import (
+ "errors"
+ "math"
+ "strings"
+ "testing"
+)
+
+func TestHandshakeSize(t *testing.T) {
+ if HandshakeSize != 128 {
+ t.Fatalf("HandshakeSize = %d, want 128", HandshakeSize)
+ }
+}
+
+func TestHelloMsg_RoundTrip(t *testing.T) {
+ tests := []struct {
+ name string
+ msg HelloMsg
+ }{
+ {
+ name: "Typical",
+ msg: func() HelloMsg {
+ var m HelloMsg
+ m.Magic = Magic
+ m.Version = ProtocolVersion
+ SetName(&m, "clock")
+ m.Width = 400
+ m.Height = 120
+ m.PreferredFPS = 1
+ m.Transport = TransportSocket
+ return m
+ }(),
+ },
+ {
+ name: "AnimatedModule",
+ msg: func() HelloMsg {
+ var m HelloMsg
+ m.Magic = Magic
+ m.Version = ProtocolVersion
+ SetName(&m, "notification-popup")
+ m.Width = 1920
+ m.Height = 1080
+ m.PreferredFPS = 60
+ m.Transport = TransportShm
+ return m
+ }(),
+ },
+ {
+ name: "MaxValues",
+ msg: func() HelloMsg {
+ var m HelloMsg
+ m.Magic = Magic
+ m.Version = math.MaxUint16
+ SetName(&m, strings.Repeat("x", 63))
+ m.Width = math.MaxUint16
+ m.Height = math.MaxUint16
+ m.PreferredFPS = math.MaxUint16
+ m.Transport = TransportType(0xFF)
+ return m
+ }(),
+ },
+ {
+ name: "EmptyName",
+ msg: func() HelloMsg {
+ var m HelloMsg
+ m.Magic = Magic
+ m.Version = ProtocolVersion
+ m.Width = 100
+ m.Height = 100
+ m.PreferredFPS = 30
+ return m
+ }(),
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ buf := make([]byte, HandshakeSize)
+ if err := EncodeHello(&tt.msg, buf); err != nil {
+ t.Fatalf("EncodeHello() error: %v", err)
+ }
+
+ got, err := DecodeHello(buf)
+ if err != nil {
+ t.Fatalf("DecodeHello() error: %v", err)
+ }
+
+ if got != tt.msg {
+ t.Errorf("round-trip mismatch:\n got: %+v\n want: %+v", got, tt.msg)
+ }
+ })
+ }
+}
+
+func TestWelcomeMsg_RoundTrip(t *testing.T) {
+ tests := []struct {
+ name string
+ msg WelcomeMsg
+ }{
+ {
+ name: "Accepted",
+ msg: WelcomeMsg{
+ Magic: Magic,
+ Version: ProtocolVersion,
+ ModuleID: 42,
+ Accepted: 1,
+ Transport: TransportSocket,
+ MinVersion: 1,
+ MaxVersion: 1,
+ },
+ },
+ {
+ name: "Rejected",
+ msg: WelcomeMsg{
+ Magic: Magic,
+ Version: ProtocolVersion,
+ ModuleID: 0,
+ Accepted: 0,
+ Transport: TransportSocket,
+ MinVersion: 1,
+ MaxVersion: 3,
+ },
+ },
+ {
+ name: "ShmGranted",
+ msg: WelcomeMsg{
+ Magic: Magic,
+ Version: ProtocolVersion,
+ ModuleID: 99,
+ Accepted: 1,
+ Transport: TransportShm,
+ MinVersion: 1,
+ MaxVersion: 2,
+ },
+ },
+ {
+ name: "MaxValues",
+ msg: WelcomeMsg{
+ Magic: Magic,
+ Version: math.MaxUint16,
+ ModuleID: math.MaxUint64,
+ Accepted: 0xFF,
+ Transport: TransportType(0xFF),
+ MinVersion: math.MaxUint16,
+ MaxVersion: math.MaxUint16,
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ buf := make([]byte, HandshakeSize)
+ if err := EncodeWelcome(&tt.msg, buf); err != nil {
+ t.Fatalf("EncodeWelcome() error: %v", err)
+ }
+
+ got, err := DecodeWelcome(buf)
+ if err != nil {
+ t.Fatalf("DecodeWelcome() error: %v", err)
+ }
+
+ if got != tt.msg {
+ t.Errorf("round-trip mismatch:\n got: %+v\n want: %+v", got, tt.msg)
+ }
+ })
+ }
+}
+
+func TestEncodeHello_BufferTooSmall(t *testing.T) {
+ msg := HelloMsg{Magic: Magic}
+ tests := []struct {
+ name string
+ size int
+ }{
+ {"Zero", 0},
+ {"Half", 64},
+ {"AlmostFull", 127},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ buf := make([]byte, tt.size)
+ err := EncodeHello(&msg, buf)
+ if !errors.Is(err, ErrHandshakeBufTooSmall) {
+ t.Errorf("EncodeHello() with %d-byte buf: got error %v, want ErrHandshakeBufTooSmall", tt.size, err)
+ }
+ })
+ }
+}
+
+func TestDecodeHello_BufferTooSmall(t *testing.T) {
+ tests := []struct {
+ name string
+ size int
+ }{
+ {"Zero", 0},
+ {"Half", 64},
+ {"AlmostFull", 127},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ buf := make([]byte, tt.size)
+ _, err := DecodeHello(buf)
+ if !errors.Is(err, ErrHandshakeBufTooSmall) {
+ t.Errorf("DecodeHello() with %d-byte buf: got error %v, want ErrHandshakeBufTooSmall", tt.size, err)
+ }
+ })
+ }
+}
+
+func TestEncodeWelcome_BufferTooSmall(t *testing.T) {
+ msg := WelcomeMsg{Magic: Magic}
+ tests := []struct {
+ name string
+ size int
+ }{
+ {"Zero", 0},
+ {"Half", 64},
+ {"AlmostFull", 127},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ buf := make([]byte, tt.size)
+ err := EncodeWelcome(&msg, buf)
+ if !errors.Is(err, ErrHandshakeBufTooSmall) {
+ t.Errorf("EncodeWelcome() with %d-byte buf: got error %v, want ErrHandshakeBufTooSmall", tt.size, err)
+ }
+ })
+ }
+}
+
+func TestDecodeWelcome_BufferTooSmall(t *testing.T) {
+ tests := []struct {
+ name string
+ size int
+ }{
+ {"Zero", 0},
+ {"Half", 64},
+ {"AlmostFull", 127},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ buf := make([]byte, tt.size)
+ _, err := DecodeWelcome(buf)
+ if !errors.Is(err, ErrHandshakeBufTooSmall) {
+ t.Errorf("DecodeWelcome() with %d-byte buf: got error %v, want ErrHandshakeBufTooSmall", tt.size, err)
+ }
+ })
+ }
+}
+
+func TestDecodeHello_InvalidMagic(t *testing.T) {
+ tests := []struct {
+ name string
+ magic [4]byte
+ }{
+ {"AllZeros", [4]byte{0, 0, 0, 0}},
+ {"WrongFirst", [4]byte{0x00, 0x4F, 0x4D, 0x50}},
+ {"Reversed", [4]byte{0x50, 0x4D, 0x4F, 0x43}},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ buf := make([]byte, HandshakeSize)
+ buf[0] = tt.magic[0]
+ buf[1] = tt.magic[1]
+ buf[2] = tt.magic[2]
+ buf[3] = tt.magic[3]
+
+ _, err := DecodeHello(buf)
+ if err == nil {
+ t.Error("DecodeHello() with invalid magic: expected error, got nil")
+ }
+ })
+ }
+}
+
+func TestDecodeWelcome_InvalidMagic(t *testing.T) {
+ tests := []struct {
+ name string
+ magic [4]byte
+ }{
+ {"AllZeros", [4]byte{0, 0, 0, 0}},
+ {"WrongLast", [4]byte{0x43, 0x4F, 0x4D, 0x00}},
+ {"AllOnes", [4]byte{0xFF, 0xFF, 0xFF, 0xFF}},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ buf := make([]byte, HandshakeSize)
+ buf[0] = tt.magic[0]
+ buf[1] = tt.magic[1]
+ buf[2] = tt.magic[2]
+ buf[3] = tt.magic[3]
+
+ _, err := DecodeWelcome(buf)
+ if err == nil {
+ t.Error("DecodeWelcome() with invalid magic: expected error, got nil")
+ }
+ })
+ }
+}
+
+func TestSetName_GetName(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ want string
+ }{
+ {"Short", "clock", "clock"},
+ {"Empty", "", ""},
+ {"ExactMax", strings.Repeat("a", 63), strings.Repeat("a", 63)},
+ {"Overflow", strings.Repeat("b", 100), strings.Repeat("b", 63)},
+ {"Unicode", "часы", "часы"},
+ {"WithSpaces", "my module name", "my module name"},
+ {"SingleChar", "x", "x"},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ var msg HelloMsg
+ SetName(&msg, tt.input)
+ got := GetName(&msg)
+ if got != tt.want {
+ t.Errorf("SetName/GetName(%q) = %q, want %q", tt.input, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestSetName_NullTerminated(t *testing.T) {
+ var msg HelloMsg
+ SetName(&msg, "test")
+
+ // Verify null terminator exists at position 4.
+ if msg.Name[4] != 0 {
+ t.Errorf("Name[4] = 0x%02X, want 0x00 (null terminator)", msg.Name[4])
+ }
+ // Verify rest is zeroed.
+ for i := 5; i < len(msg.Name); i++ {
+ if msg.Name[i] != 0 {
+ t.Errorf("Name[%d] = 0x%02X, want 0x00", i, msg.Name[i])
+ }
+ }
+}
+
+func TestSetName_ClearsOldName(t *testing.T) {
+ var msg HelloMsg
+ SetName(&msg, strings.Repeat("x", 63))
+ SetName(&msg, "ab")
+
+ got := GetName(&msg)
+ if got != "ab" {
+ t.Errorf("after overwrite, GetName() = %q, want %q", got, "ab")
+ }
+ // Verify bytes after "ab" are zero.
+ if msg.Name[2] != 0 {
+ t.Errorf("Name[2] = 0x%02X, want 0x00 after overwrite", msg.Name[2])
+ }
+}
+
+func TestGetName_NoNullTerminator(t *testing.T) {
+ // Edge case: all 64 bytes are non-zero (no null terminator).
+ var msg HelloMsg
+ for i := range msg.Name {
+ msg.Name[i] = 'z'
+ }
+
+ got := GetName(&msg)
+ if len(got) != 64 {
+ t.Errorf("GetName() with no NUL: len = %d, want 64", len(got))
+ }
+}
+
+func TestTransportType_String(t *testing.T) {
+ tests := []struct {
+ name string
+ tr TransportType
+ want string
+ }{
+ {"Socket", TransportSocket, "Socket"},
+ {"Shm", TransportShm, "SharedMemory"},
+ {"Unknown", TransportType(99), "TransportType(99)"},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := tt.tr.String()
+ if got != tt.want {
+ t.Errorf("TransportType(%d).String() = %q, want %q", tt.tr, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestEncodeHello_ReservedZeroed(t *testing.T) {
+ msg := HelloMsg{
+ Magic: Magic,
+ Version: ProtocolVersion,
+ }
+ SetName(&msg, "test")
+
+ buf := make([]byte, HandshakeSize)
+ // Fill with non-zero.
+ for i := range buf {
+ buf[i] = 0xFF
+ }
+
+ if err := EncodeHello(&msg, buf); err != nil {
+ t.Fatalf("EncodeHello() error: %v", err)
+ }
+
+ // Verify reserved bytes are zeroed.
+ for i := helloOffReserved; i < helloOffReserved+51; i++ {
+ if buf[i] != 0 {
+ t.Errorf("reserved byte buf[%d] = 0x%02X, want 0x00", i, buf[i])
+ }
+ }
+}
+
+func TestEncodeWelcome_ReservedZeroed(t *testing.T) {
+ msg := WelcomeMsg{
+ Magic: Magic,
+ Version: ProtocolVersion,
+ ModuleID: 1,
+ Accepted: 1,
+ }
+
+ buf := make([]byte, HandshakeSize)
+ // Fill with non-zero.
+ for i := range buf {
+ buf[i] = 0xFF
+ }
+
+ if err := EncodeWelcome(&msg, buf); err != nil {
+ t.Fatalf("EncodeWelcome() error: %v", err)
+ }
+
+ // Verify reserved bytes are zeroed.
+ for i := welcomeOffReserved; i < welcomeOffReserved+108; i++ {
+ if buf[i] != 0 {
+ t.Errorf("reserved byte buf[%d] = 0x%02X, want 0x00", i, buf[i])
+ }
+ }
+}
+
+func TestHelloMsg_LargerBuffer(t *testing.T) {
+ msg := HelloMsg{
+ Magic: Magic,
+ Version: ProtocolVersion,
+ }
+ SetName(&msg, "test")
+
+ buf := make([]byte, 256)
+ for i := range buf {
+ buf[i] = 0xAA
+ }
+
+ if err := EncodeHello(&msg, buf); err != nil {
+ t.Fatalf("EncodeHello() error: %v", err)
+ }
+
+ // Bytes beyond HandshakeSize should be untouched.
+ for i := HandshakeSize; i < len(buf); i++ {
+ if buf[i] != 0xAA {
+ t.Errorf("buf[%d] = 0x%02X, want 0xAA (sentinel should be untouched)", i, buf[i])
+ }
+ }
+}
+
+func TestWelcomeMsg_LargerBuffer(t *testing.T) {
+ msg := WelcomeMsg{
+ Magic: Magic,
+ Version: ProtocolVersion,
+ ModuleID: 7,
+ Accepted: 1,
+ }
+
+ buf := make([]byte, 256)
+ for i := range buf {
+ buf[i] = 0xBB
+ }
+
+ if err := EncodeWelcome(&msg, buf); err != nil {
+ t.Fatalf("EncodeWelcome() error: %v", err)
+ }
+
+ // Bytes beyond HandshakeSize should be untouched.
+ for i := HandshakeSize; i < len(buf); i++ {
+ if buf[i] != 0xBB {
+ t.Errorf("buf[%d] = 0x%02X, want 0xBB (sentinel should be untouched)", i, buf[i])
+ }
+ }
+}
+
+func BenchmarkEncodeHello(b *testing.B) {
+ msg := HelloMsg{
+ Magic: Magic,
+ Version: ProtocolVersion,
+ Width: 400,
+ Height: 120,
+ PreferredFPS: 60,
+ Transport: TransportSocket,
+ }
+ SetName(&msg, "benchmark-module")
+ buf := make([]byte, HandshakeSize)
+
+ b.ReportAllocs()
+ b.ResetTimer()
+ for b.Loop() {
+ _ = EncodeHello(&msg, buf)
+ }
+}
+
+func BenchmarkDecodeHello(b *testing.B) {
+ msg := HelloMsg{
+ Magic: Magic,
+ Version: ProtocolVersion,
+ Width: 400,
+ Height: 120,
+ PreferredFPS: 60,
+ Transport: TransportSocket,
+ }
+ SetName(&msg, "benchmark-module")
+ buf := make([]byte, HandshakeSize)
+ _ = EncodeHello(&msg, buf)
+
+ b.ReportAllocs()
+ b.ResetTimer()
+ for b.Loop() {
+ _, _ = DecodeHello(buf)
+ }
+}
+
+func BenchmarkEncodeWelcome(b *testing.B) {
+ msg := WelcomeMsg{
+ Magic: Magic,
+ Version: ProtocolVersion,
+ ModuleID: 42,
+ Accepted: 1,
+ Transport: TransportSocket,
+ MinVersion: 1,
+ MaxVersion: 1,
+ }
+ buf := make([]byte, HandshakeSize)
+
+ b.ReportAllocs()
+ b.ResetTimer()
+ for b.Loop() {
+ _ = EncodeWelcome(&msg, buf)
+ }
+}
+
+func BenchmarkDecodeWelcome(b *testing.B) {
+ msg := WelcomeMsg{
+ Magic: Magic,
+ Version: ProtocolVersion,
+ ModuleID: 42,
+ Accepted: 1,
+ Transport: TransportSocket,
+ MinVersion: 1,
+ MaxVersion: 1,
+ }
+ buf := make([]byte, HandshakeSize)
+ _ = EncodeWelcome(&msg, buf)
+
+ b.ReportAllocs()
+ b.ResetTimer()
+ for b.Loop() {
+ _, _ = DecodeWelcome(buf)
+ }
+}
diff --git a/internal/protocol/header.go b/internal/protocol/header.go
new file mode 100644
index 0000000..3d9f9f9
--- /dev/null
+++ b/internal/protocol/header.go
@@ -0,0 +1,270 @@
+package protocol
+
+import (
+ "encoding/binary"
+ "errors"
+ "fmt"
+)
+
+// HeaderSize is the fixed size in bytes of the wire frame header.
+// It is designed to be cache-line aligned (64 bytes on modern CPUs).
+const HeaderSize = 64
+
+// Magic is the 4-byte protocol identifier at the start of every header.
+// ASCII: "COMP" (0x43, 0x4F, 0x4D, 0x50).
+var Magic = [4]byte{0x43, 0x4F, 0x4D, 0x50}
+
+// ProtocolVersion is the current wire protocol version.
+const ProtocolVersion uint16 = 1
+
+// Header is the 64-byte frame header that precedes every message on the wire.
+// All multi-byte fields are little-endian encoded.
+type Header struct {
+ // Magic is the protocol identifier (must be "COMP").
+ Magic [4]byte
+
+ // Version is the protocol version number.
+ Version uint16
+
+ // MsgType identifies the message kind (Frame, Handshake, Ack, etc.).
+ MsgType MsgType
+
+ // Flags is a bitfield (DirtyValid, Compressed, Keyframe).
+ Flags Flag
+
+ // ModuleID is the compositor-assigned module identifier.
+ ModuleID uint64
+
+ // Sequence is the monotonically increasing frame counter per module.
+ Sequence uint64
+
+ // TimestampNs is the monotonic clock timestamp in nanoseconds.
+ TimestampNs int64
+
+ // Width is the frame width in pixels.
+ Width uint16
+
+ // Height is the frame height in pixels.
+ Height uint16
+
+ // Stride is the number of bytes per row (typically Width * 4).
+ Stride uint32
+
+ // DirtyX is the X offset of the dirty rectangle.
+ DirtyX uint16
+
+ // DirtyY is the Y offset of the dirty rectangle.
+ DirtyY uint16
+
+ // DirtyW is the width of the dirty rectangle.
+ DirtyW uint16
+
+ // DirtyH is the height of the dirty rectangle.
+ DirtyH uint16
+
+ // PixelFormat identifies the pixel encoding (RGBA8, BGRA8).
+ PixelFormat PixelFormat
+
+ // Compression identifies the payload compression algorithm.
+ Compression Compression
+
+ // Reserved is padding for future use. Must be zero.
+ Reserved [6]byte
+
+ // PayloadSize is the number of payload bytes following this header.
+ PayloadSize uint32
+
+ // UncompressedSize is the original payload size before compression.
+ // When Compression is None, this equals PayloadSize.
+ UncompressedSize uint32
+}
+
+// Field byte offsets within the 64-byte header.
+const (
+ offMagic = 0
+ offVersion = 4
+ offMsgType = 6
+ offFlags = 7
+ offModuleID = 8
+ offSequence = 16
+ offTimestampNs = 24
+ offWidth = 32
+ offHeight = 34
+ offStride = 36
+ offDirtyX = 40
+ offDirtyY = 42
+ offDirtyW = 44
+ offDirtyH = 46
+ offPixelFormat = 48
+ offCompression = 49
+ offReserved = 50
+ offPayloadSize = 56
+ offUncompressedSize = 60
+)
+
+// Errors returned by Encode and Decode.
+var (
+ // ErrBufferTooSmall is returned when the provided buffer is smaller than HeaderSize.
+ ErrBufferTooSmall = errors.New("protocol: buffer too small (need 64 bytes)")
+
+ // ErrInvalidMagic is returned when the decoded magic bytes do not match "COMP".
+ ErrInvalidMagic = errors.New("protocol: invalid magic (expected 0x434F4D50)")
+
+ // ErrUnknownMsgType is returned when the decoded message type is not recognized.
+ ErrUnknownMsgType = errors.New("protocol: unknown message type")
+)
+
+// Encode writes the header h into buf using little-endian byte order.
+// buf must be at least HeaderSize (64) bytes. Encode never allocates.
+func Encode(h *Header, buf []byte) error {
+ if len(buf) < HeaderSize {
+ return ErrBufferTooSmall
+ }
+
+ // Magic (4 bytes)
+ buf[offMagic] = h.Magic[0]
+ buf[offMagic+1] = h.Magic[1]
+ buf[offMagic+2] = h.Magic[2]
+ buf[offMagic+3] = h.Magic[3]
+
+ // Version (2 bytes)
+ binary.LittleEndian.PutUint16(buf[offVersion:], h.Version)
+
+ // MsgType (1 byte)
+ buf[offMsgType] = uint8(h.MsgType)
+
+ // Flags (1 byte)
+ buf[offFlags] = uint8(h.Flags)
+
+ // ModuleID (8 bytes)
+ binary.LittleEndian.PutUint64(buf[offModuleID:], h.ModuleID)
+
+ // Sequence (8 bytes)
+ binary.LittleEndian.PutUint64(buf[offSequence:], h.Sequence)
+
+ // TimestampNs (8 bytes, signed as uint64 bit pattern)
+ binary.LittleEndian.PutUint64(buf[offTimestampNs:], uint64(h.TimestampNs)) //nolint:gosec // bit-cast int64→uint64, no overflow
+
+ // Width (2 bytes)
+ binary.LittleEndian.PutUint16(buf[offWidth:], h.Width)
+
+ // Height (2 bytes)
+ binary.LittleEndian.PutUint16(buf[offHeight:], h.Height)
+
+ // Stride (4 bytes)
+ binary.LittleEndian.PutUint32(buf[offStride:], h.Stride)
+
+ // DirtyX (2 bytes)
+ binary.LittleEndian.PutUint16(buf[offDirtyX:], h.DirtyX)
+
+ // DirtyY (2 bytes)
+ binary.LittleEndian.PutUint16(buf[offDirtyY:], h.DirtyY)
+
+ // DirtyW (2 bytes)
+ binary.LittleEndian.PutUint16(buf[offDirtyW:], h.DirtyW)
+
+ // DirtyH (2 bytes)
+ binary.LittleEndian.PutUint16(buf[offDirtyH:], h.DirtyH)
+
+ // PixelFormat (1 byte)
+ buf[offPixelFormat] = uint8(h.PixelFormat)
+
+ // Compression (1 byte)
+ buf[offCompression] = uint8(h.Compression)
+
+ // Reserved (6 bytes, zero)
+ buf[offReserved] = 0
+ buf[offReserved+1] = 0
+ buf[offReserved+2] = 0
+ buf[offReserved+3] = 0
+ buf[offReserved+4] = 0
+ buf[offReserved+5] = 0
+
+ // PayloadSize (4 bytes)
+ binary.LittleEndian.PutUint32(buf[offPayloadSize:], h.PayloadSize)
+
+ // UncompressedSize (4 bytes)
+ binary.LittleEndian.PutUint32(buf[offUncompressedSize:], h.UncompressedSize)
+
+ return nil
+}
+
+// Decode reads a header from buf and validates the magic bytes and message type.
+// buf must be at least HeaderSize (64) bytes. Decode never allocates.
+func Decode(buf []byte) (Header, error) {
+ if len(buf) < HeaderSize {
+ return Header{}, ErrBufferTooSmall
+ }
+
+ var h Header
+
+ // Magic (4 bytes)
+ h.Magic[0] = buf[offMagic]
+ h.Magic[1] = buf[offMagic+1]
+ h.Magic[2] = buf[offMagic+2]
+ h.Magic[3] = buf[offMagic+3]
+
+ if h.Magic != Magic {
+ return Header{}, fmt.Errorf("%w: got [0x%02X 0x%02X 0x%02X 0x%02X]",
+ ErrInvalidMagic, h.Magic[0], h.Magic[1], h.Magic[2], h.Magic[3])
+ }
+
+ // Version (2 bytes)
+ h.Version = binary.LittleEndian.Uint16(buf[offVersion:])
+
+ // MsgType (1 byte)
+ h.MsgType = MsgType(buf[offMsgType])
+ if !h.MsgType.Valid() {
+ return Header{}, fmt.Errorf("%w: 0x%02X", ErrUnknownMsgType, uint8(h.MsgType))
+ }
+
+ // Flags (1 byte)
+ h.Flags = Flag(buf[offFlags])
+
+ // ModuleID (8 bytes)
+ h.ModuleID = binary.LittleEndian.Uint64(buf[offModuleID:])
+
+ // Sequence (8 bytes)
+ h.Sequence = binary.LittleEndian.Uint64(buf[offSequence:])
+
+ // TimestampNs (8 bytes)
+ h.TimestampNs = int64(binary.LittleEndian.Uint64(buf[offTimestampNs:])) //nolint:gosec // bit-cast uint64→int64, no overflow
+
+ // Width (2 bytes)
+ h.Width = binary.LittleEndian.Uint16(buf[offWidth:])
+
+ // Height (2 bytes)
+ h.Height = binary.LittleEndian.Uint16(buf[offHeight:])
+
+ // Stride (4 bytes)
+ h.Stride = binary.LittleEndian.Uint32(buf[offStride:])
+
+ // DirtyX (2 bytes)
+ h.DirtyX = binary.LittleEndian.Uint16(buf[offDirtyX:])
+
+ // DirtyY (2 bytes)
+ h.DirtyY = binary.LittleEndian.Uint16(buf[offDirtyY:])
+
+ // DirtyW (2 bytes)
+ h.DirtyW = binary.LittleEndian.Uint16(buf[offDirtyW:])
+
+ // DirtyH (2 bytes)
+ h.DirtyH = binary.LittleEndian.Uint16(buf[offDirtyH:])
+
+ // PixelFormat (1 byte)
+ h.PixelFormat = PixelFormat(buf[offPixelFormat])
+
+ // Compression (1 byte)
+ h.Compression = Compression(buf[offCompression])
+
+ // Reserved (6 bytes)
+ copy(h.Reserved[:], buf[offReserved:offReserved+6])
+
+ // PayloadSize (4 bytes)
+ h.PayloadSize = binary.LittleEndian.Uint32(buf[offPayloadSize:])
+
+ // UncompressedSize (4 bytes)
+ h.UncompressedSize = binary.LittleEndian.Uint32(buf[offUncompressedSize:])
+
+ return h, nil
+}
diff --git a/internal/protocol/header_test.go b/internal/protocol/header_test.go
new file mode 100644
index 0000000..5a029e7
--- /dev/null
+++ b/internal/protocol/header_test.go
@@ -0,0 +1,399 @@
+package protocol
+
+import (
+ "errors"
+ "math"
+ "testing"
+)
+
+func TestHeaderSize(t *testing.T) {
+ if HeaderSize != 64 {
+ t.Fatalf("HeaderSize = %d, want 64", HeaderSize)
+ }
+}
+
+func TestHeader_RoundTrip(t *testing.T) {
+ tests := []struct {
+ name string
+ h Header
+ }{
+ {
+ name: "TypicalFrame",
+ h: Header{
+ Magic: Magic,
+ Version: ProtocolVersion,
+ MsgType: MsgFrame,
+ Flags: FlagDirtyValid | FlagKeyframe,
+ ModuleID: 42,
+ Sequence: 1001,
+ TimestampNs: 1_000_000_000,
+ Width: 1920,
+ Height: 1080,
+ Stride: 1920 * 4,
+ DirtyX: 100,
+ DirtyY: 200,
+ DirtyW: 300,
+ DirtyH: 400,
+ PixelFormat: PixelRGBA8,
+ Compression: CompressionNone,
+ PayloadSize: 1920 * 1080 * 4,
+ UncompressedSize: 1920 * 1080 * 4,
+ },
+ },
+ {
+ name: "CompressedFrame",
+ h: Header{
+ Magic: Magic,
+ Version: ProtocolVersion,
+ MsgType: MsgFrame,
+ Flags: FlagCompressed,
+ ModuleID: 7,
+ Sequence: 55,
+ TimestampNs: -12345678, // negative timestamp is valid
+ Width: 800,
+ Height: 600,
+ Stride: 800 * 4,
+ PixelFormat: PixelBGRA8,
+ Compression: CompressionLZ4,
+ PayloadSize: 100000,
+ UncompressedSize: 800 * 600 * 4,
+ },
+ },
+ {
+ name: "Ack",
+ h: Header{
+ Magic: Magic,
+ Version: ProtocolVersion,
+ MsgType: MsgAck,
+ },
+ },
+ {
+ name: "FrameRequest",
+ h: Header{
+ Magic: Magic,
+ Version: ProtocolVersion,
+ MsgType: MsgFrameRequest,
+ ModuleID: 99,
+ Sequence: 12345,
+ },
+ },
+ {
+ name: "Disconnect",
+ h: Header{
+ Magic: Magic,
+ Version: ProtocolVersion,
+ MsgType: MsgDisconnect,
+ ModuleID: 3,
+ },
+ },
+ {
+ name: "Resize",
+ h: Header{
+ Magic: Magic,
+ Version: ProtocolVersion,
+ MsgType: MsgResize,
+ Width: 2560,
+ Height: 1440,
+ Stride: 2560 * 4,
+ },
+ },
+ {
+ name: "MaxValues",
+ h: Header{
+ Magic: Magic,
+ Version: math.MaxUint16,
+ MsgType: MsgFrame,
+ Flags: Flag(0xFF),
+ ModuleID: math.MaxUint64,
+ Sequence: math.MaxUint64,
+ TimestampNs: math.MaxInt64,
+ Width: math.MaxUint16,
+ Height: math.MaxUint16,
+ Stride: math.MaxUint32,
+ DirtyX: math.MaxUint16,
+ DirtyY: math.MaxUint16,
+ DirtyW: math.MaxUint16,
+ DirtyH: math.MaxUint16,
+ PixelFormat: PixelFormat(0xFF),
+ Compression: Compression(0xFF),
+ PayloadSize: math.MaxUint32,
+ UncompressedSize: math.MaxUint32,
+ },
+ },
+ {
+ name: "ZeroValue",
+ h: Header{
+ Magic: Magic,
+ MsgType: MsgFrame,
+ },
+ },
+ {
+ name: "NegativeTimestamp",
+ h: Header{
+ Magic: Magic,
+ Version: ProtocolVersion,
+ MsgType: MsgHandshake,
+ TimestampNs: math.MinInt64,
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ buf := make([]byte, HeaderSize)
+ if err := Encode(&tt.h, buf); err != nil {
+ t.Fatalf("Encode() error: %v", err)
+ }
+
+ got, err := Decode(buf)
+ if err != nil {
+ t.Fatalf("Decode() error: %v", err)
+ }
+
+ if got != tt.h {
+ t.Errorf("round-trip mismatch:\n got: %+v\n want: %+v", got, tt.h)
+ }
+ })
+ }
+}
+
+func TestEncode_BufferTooSmall(t *testing.T) {
+ h := Header{Magic: Magic, MsgType: MsgFrame}
+ tests := []struct {
+ name string
+ size int
+ }{
+ {"Zero", 0},
+ {"One", 1},
+ {"Half", 32},
+ {"AlmostFull", 63},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ buf := make([]byte, tt.size)
+ err := Encode(&h, buf)
+ if !errors.Is(err, ErrBufferTooSmall) {
+ t.Errorf("Encode() with %d-byte buf: got error %v, want ErrBufferTooSmall", tt.size, err)
+ }
+ })
+ }
+}
+
+func TestDecode_BufferTooSmall(t *testing.T) {
+ tests := []struct {
+ name string
+ size int
+ }{
+ {"Zero", 0},
+ {"One", 1},
+ {"Half", 32},
+ {"AlmostFull", 63},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ buf := make([]byte, tt.size)
+ _, err := Decode(buf)
+ if !errors.Is(err, ErrBufferTooSmall) {
+ t.Errorf("Decode() with %d-byte buf: got error %v, want ErrBufferTooSmall", tt.size, err)
+ }
+ })
+ }
+}
+
+func TestDecode_InvalidMagic(t *testing.T) {
+ tests := []struct {
+ name string
+ magic [4]byte
+ }{
+ {"AllZeros", [4]byte{0, 0, 0, 0}},
+ {"WrongFirst", [4]byte{0x00, 0x4F, 0x4D, 0x50}},
+ {"WrongLast", [4]byte{0x43, 0x4F, 0x4D, 0x00}},
+ {"Reversed", [4]byte{0x50, 0x4D, 0x4F, 0x43}},
+ {"AllOnes", [4]byte{0xFF, 0xFF, 0xFF, 0xFF}},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ buf := make([]byte, HeaderSize)
+ // Manually write magic to bypass Encode validation.
+ buf[0] = tt.magic[0]
+ buf[1] = tt.magic[1]
+ buf[2] = tt.magic[2]
+ buf[3] = tt.magic[3]
+ buf[offMsgType] = uint8(MsgFrame)
+
+ _, err := Decode(buf)
+ if err == nil {
+ t.Errorf("Decode() with magic %v: expected error, got nil", tt.magic)
+ }
+ })
+ }
+}
+
+func TestDecode_UnknownMsgType(t *testing.T) {
+ tests := []struct {
+ name string
+ msgType uint8
+ }{
+ {"Zero", 0x00},
+ {"Seven", 0x07},
+ {"Max", 0xFF},
+ {"JustAbove", 0x08},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ buf := make([]byte, HeaderSize)
+ // Valid magic.
+ buf[0] = Magic[0]
+ buf[1] = Magic[1]
+ buf[2] = Magic[2]
+ buf[3] = Magic[3]
+ buf[offMsgType] = tt.msgType
+
+ _, err := Decode(buf)
+ if err == nil {
+ t.Errorf("Decode() with MsgType 0x%02X: expected error, got nil", tt.msgType)
+ }
+ })
+ }
+}
+
+func TestEncode_LargerBuffer(t *testing.T) {
+ // Encode into a buffer larger than 64 bytes should work and not corrupt beyond.
+ h := Header{
+ Magic: Magic,
+ Version: ProtocolVersion,
+ MsgType: MsgFrame,
+ }
+ buf := make([]byte, 128)
+ for i := range buf {
+ buf[i] = 0xAA // sentinel
+ }
+
+ if err := Encode(&h, buf); err != nil {
+ t.Fatalf("Encode() error: %v", err)
+ }
+
+ // Bytes beyond 64 should be untouched.
+ for i := HeaderSize; i < len(buf); i++ {
+ if buf[i] != 0xAA {
+ t.Errorf("buf[%d] = 0x%02X, want 0xAA (sentinel should be untouched)", i, buf[i])
+ }
+ }
+}
+
+func TestDecode_ExactBuffer(t *testing.T) {
+ // Decode from a buffer that is exactly 64 bytes.
+ h := Header{
+ Magic: Magic,
+ Version: ProtocolVersion,
+ MsgType: MsgFrame,
+ ModuleID: 42,
+ PayloadSize: 1024,
+ }
+ buf := make([]byte, HeaderSize)
+ if err := Encode(&h, buf); err != nil {
+ t.Fatalf("Encode() error: %v", err)
+ }
+
+ got, err := Decode(buf)
+ if err != nil {
+ t.Fatalf("Decode() error: %v", err)
+ }
+ if got.ModuleID != 42 {
+ t.Errorf("ModuleID = %d, want 42", got.ModuleID)
+ }
+ if got.PayloadSize != 1024 {
+ t.Errorf("PayloadSize = %d, want 1024", got.PayloadSize)
+ }
+}
+
+func TestHeader_ReservedZeroed(t *testing.T) {
+ // After encoding, reserved bytes should be zero.
+ h := Header{
+ Magic: Magic,
+ Version: ProtocolVersion,
+ MsgType: MsgFrame,
+ }
+ buf := make([]byte, HeaderSize)
+ // Fill with non-zero to prove Encode zeros them.
+ for i := range buf {
+ buf[i] = 0xFF
+ }
+
+ if err := Encode(&h, buf); err != nil {
+ t.Fatalf("Encode() error: %v", err)
+ }
+
+ for i := offReserved; i < offReserved+6; i++ {
+ if buf[i] != 0 {
+ t.Errorf("reserved byte buf[%d] = 0x%02X, want 0x00", i, buf[i])
+ }
+ }
+}
+
+func TestMagicBytes(t *testing.T) {
+ // Verify magic is "COMP" in ASCII.
+ if Magic != [4]byte{'C', 'O', 'M', 'P'} {
+ t.Errorf("Magic = %v, want {0x43, 0x4F, 0x4D, 0x50} ('COMP')", Magic)
+ }
+}
+
+func BenchmarkHeaderEncode(b *testing.B) {
+ h := Header{
+ Magic: Magic,
+ Version: ProtocolVersion,
+ MsgType: MsgFrame,
+ Flags: FlagDirtyValid | FlagCompressed,
+ ModuleID: 42,
+ Sequence: 1001,
+ TimestampNs: 1_000_000_000,
+ Width: 1920,
+ Height: 1080,
+ Stride: 1920 * 4,
+ DirtyX: 100,
+ DirtyY: 200,
+ DirtyW: 300,
+ DirtyH: 400,
+ PixelFormat: PixelRGBA8,
+ Compression: CompressionLZ4,
+ PayloadSize: 100000,
+ UncompressedSize: 1920 * 1080 * 4,
+ }
+ buf := make([]byte, HeaderSize)
+
+ b.ReportAllocs()
+ b.ResetTimer()
+ for b.Loop() {
+ _ = Encode(&h, buf)
+ }
+}
+
+func BenchmarkHeaderDecode(b *testing.B) {
+ h := Header{
+ Magic: Magic,
+ Version: ProtocolVersion,
+ MsgType: MsgFrame,
+ Flags: FlagDirtyValid | FlagCompressed,
+ ModuleID: 42,
+ Sequence: 1001,
+ TimestampNs: 1_000_000_000,
+ Width: 1920,
+ Height: 1080,
+ Stride: 1920 * 4,
+ DirtyX: 100,
+ DirtyY: 200,
+ DirtyW: 300,
+ DirtyH: 400,
+ PixelFormat: PixelRGBA8,
+ Compression: CompressionLZ4,
+ PayloadSize: 100000,
+ UncompressedSize: 1920 * 1080 * 4,
+ }
+ buf := make([]byte, HeaderSize)
+ _ = Encode(&h, buf)
+
+ b.ReportAllocs()
+ b.ResetTimer()
+ for b.Loop() {
+ _, _ = Decode(buf)
+ }
+}
diff --git a/internal/protocol/message.go b/internal/protocol/message.go
new file mode 100644
index 0000000..13368a7
--- /dev/null
+++ b/internal/protocol/message.go
@@ -0,0 +1,174 @@
+package protocol
+
+import "fmt"
+
+// MsgType identifies the kind of message carried by a frame header.
+type MsgType uint8
+
+const (
+ // MsgFrame carries a rendered pixel buffer from module to compositor.
+ MsgFrame MsgType = 0x01
+
+ // MsgHandshake is the initial connection negotiation message.
+ MsgHandshake MsgType = 0x02
+
+ // MsgAck acknowledges receipt of a frame (compositor to module).
+ MsgAck MsgType = 0x03
+
+ // MsgFrameRequest is sent by the compositor to request the next frame
+ // from a module (pull-based flow control, Wayland pattern).
+ MsgFrameRequest MsgType = 0x04
+
+ // MsgResize notifies a module that its frame dimensions have changed.
+ MsgResize MsgType = 0x05
+
+ // MsgDisconnect signals a graceful disconnection.
+ MsgDisconnect MsgType = 0x06
+)
+
+// String returns the human-readable name of the message type.
+func (m MsgType) String() string {
+ switch m {
+ case MsgFrame:
+ return "Frame"
+ case MsgHandshake:
+ return "Handshake"
+ case MsgAck:
+ return "Ack"
+ case MsgFrameRequest:
+ return "FrameRequest"
+ case MsgResize:
+ return "Resize"
+ case MsgDisconnect:
+ return "Disconnect"
+ default:
+ return fmt.Sprintf("MsgType(0x%02X)", uint8(m))
+ }
+}
+
+// Valid reports whether m is a known message type.
+func (m MsgType) Valid() bool {
+ switch m {
+ case MsgFrame, MsgHandshake, MsgAck, MsgFrameRequest, MsgResize, MsgDisconnect:
+ return true
+ default:
+ return false
+ }
+}
+
+// Flag is a bitfield carried in the header's Flags byte.
+type Flag uint8
+
+const (
+ // FlagDirtyValid indicates that the DirtyRect fields contain valid data.
+ // When not set, the entire frame is considered dirty (keyframe).
+ FlagDirtyValid Flag = 1 << iota
+
+ // FlagCompressed indicates that the payload is compressed.
+ // The Compression field specifies the algorithm.
+ FlagCompressed
+
+ // FlagKeyframe marks the frame as a full keyframe (no delta dependency).
+ FlagKeyframe
+)
+
+// Has reports whether the flag f is set in the bitmask flags.
+func (f Flag) Has(flag Flag) bool {
+ return f&flag != 0
+}
+
+// Set returns the bitmask with flag set.
+func (f Flag) Set(flag Flag) Flag {
+ return f | flag
+}
+
+// Clear returns the bitmask with flag cleared.
+func (f Flag) Clear(flag Flag) Flag {
+ return f &^ flag
+}
+
+// String returns the human-readable name of the flag bit.
+func (f Flag) String() string {
+ switch f {
+ case FlagDirtyValid:
+ return "DirtyValid"
+ case FlagCompressed:
+ return "Compressed"
+ case FlagKeyframe:
+ return "Keyframe"
+ default:
+ return fmt.Sprintf("Flag(0x%02X)", uint8(f))
+ }
+}
+
+// PixelFormat identifies the pixel encoding of frame payload data.
+type PixelFormat uint8
+
+const (
+ // PixelRGBA8 is 8-bit RGBA, 4 bytes per pixel, premultiplied alpha.
+ PixelRGBA8 PixelFormat = 0x00
+
+ // PixelBGRA8 is 8-bit BGRA, 4 bytes per pixel, premultiplied alpha.
+ // Common on Windows/DX12 surfaces.
+ PixelBGRA8 PixelFormat = 0x01
+)
+
+// String returns the human-readable name of the pixel format.
+func (p PixelFormat) String() string {
+ switch p {
+ case PixelRGBA8:
+ return "RGBA8"
+ case PixelBGRA8:
+ return "BGRA8"
+ default:
+ return fmt.Sprintf("PixelFormat(0x%02X)", uint8(p))
+ }
+}
+
+// Valid reports whether p is a known pixel format.
+func (p PixelFormat) Valid() bool {
+ switch p {
+ case PixelRGBA8, PixelBGRA8:
+ return true
+ default:
+ return false
+ }
+}
+
+// Compression identifies the payload compression algorithm.
+type Compression uint8
+
+const (
+ // CompressionNone means the payload is uncompressed raw pixels.
+ CompressionNone Compression = 0x00
+
+ // CompressionLZ4 uses the LZ4 block compression algorithm.
+ CompressionLZ4 Compression = 0x01
+
+ // CompressionZstd uses the Zstandard compression algorithm.
+ CompressionZstd Compression = 0x02
+)
+
+// String returns the human-readable name of the compression algorithm.
+func (c Compression) String() string {
+ switch c {
+ case CompressionNone:
+ return "None"
+ case CompressionLZ4:
+ return "LZ4"
+ case CompressionZstd:
+ return "Zstd"
+ default:
+ return fmt.Sprintf("Compression(0x%02X)", uint8(c))
+ }
+}
+
+// Valid reports whether c is a known compression algorithm.
+func (c Compression) Valid() bool {
+ switch c {
+ case CompressionNone, CompressionLZ4, CompressionZstd:
+ return true
+ default:
+ return false
+ }
+}
diff --git a/internal/protocol/message_test.go b/internal/protocol/message_test.go
new file mode 100644
index 0000000..a3c3bab
--- /dev/null
+++ b/internal/protocol/message_test.go
@@ -0,0 +1,272 @@
+package protocol
+
+import (
+ "testing"
+)
+
+func TestMsgType_String(t *testing.T) {
+ tests := []struct {
+ name string
+ m MsgType
+ want string
+ }{
+ {"Frame", MsgFrame, "Frame"},
+ {"Handshake", MsgHandshake, "Handshake"},
+ {"Ack", MsgAck, "Ack"},
+ {"FrameRequest", MsgFrameRequest, "FrameRequest"},
+ {"Resize", MsgResize, "Resize"},
+ {"Disconnect", MsgDisconnect, "Disconnect"},
+ {"Unknown_0xFF", MsgType(0xFF), "MsgType(0xFF)"},
+ {"Unknown_0x00", MsgType(0x00), "MsgType(0x00)"},
+ {"Unknown_0x07", MsgType(0x07), "MsgType(0x07)"},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := tt.m.String()
+ if got != tt.want {
+ t.Errorf("MsgType(%d).String() = %q, want %q", tt.m, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestMsgType_Valid(t *testing.T) {
+ tests := []struct {
+ name string
+ m MsgType
+ want bool
+ }{
+ {"Frame", MsgFrame, true},
+ {"Handshake", MsgHandshake, true},
+ {"Ack", MsgAck, true},
+ {"FrameRequest", MsgFrameRequest, true},
+ {"Resize", MsgResize, true},
+ {"Disconnect", MsgDisconnect, true},
+ {"Zero", MsgType(0x00), false},
+ {"Unknown_0x07", MsgType(0x07), false},
+ {"Unknown_0xFF", MsgType(0xFF), false},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := tt.m.Valid()
+ if got != tt.want {
+ t.Errorf("MsgType(%d).Valid() = %v, want %v", tt.m, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestFlag_Operations(t *testing.T) {
+ t.Run("Has", func(t *testing.T) {
+ flags := FlagDirtyValid | FlagKeyframe
+ if !flags.Has(FlagDirtyValid) {
+ t.Error("expected FlagDirtyValid to be set")
+ }
+ if !flags.Has(FlagKeyframe) {
+ t.Error("expected FlagKeyframe to be set")
+ }
+ if flags.Has(FlagCompressed) {
+ t.Error("expected FlagCompressed to not be set")
+ }
+ })
+
+ t.Run("Set", func(t *testing.T) {
+ var f Flag
+ f = f.Set(FlagCompressed)
+ if !f.Has(FlagCompressed) {
+ t.Error("expected FlagCompressed after Set")
+ }
+ f = f.Set(FlagDirtyValid)
+ if !f.Has(FlagDirtyValid) {
+ t.Error("expected FlagDirtyValid after Set")
+ }
+ // Setting again is idempotent.
+ f = f.Set(FlagCompressed)
+ if f != FlagCompressed|FlagDirtyValid {
+ t.Errorf("expected 0x03, got 0x%02X", f)
+ }
+ })
+
+ t.Run("Clear", func(t *testing.T) {
+ f := FlagDirtyValid | FlagCompressed | FlagKeyframe
+ f = f.Clear(FlagCompressed)
+ if f.Has(FlagCompressed) {
+ t.Error("expected FlagCompressed to be cleared")
+ }
+ if !f.Has(FlagDirtyValid) {
+ t.Error("expected FlagDirtyValid to still be set")
+ }
+ if !f.Has(FlagKeyframe) {
+ t.Error("expected FlagKeyframe to still be set")
+ }
+ })
+
+ t.Run("ZeroHasNothing", func(t *testing.T) {
+ var f Flag
+ if f.Has(FlagDirtyValid) || f.Has(FlagCompressed) || f.Has(FlagKeyframe) {
+ t.Error("zero Flag should have no flags set")
+ }
+ })
+}
+
+func TestFlag_String(t *testing.T) {
+ tests := []struct {
+ name string
+ f Flag
+ want string
+ }{
+ {"DirtyValid", FlagDirtyValid, "DirtyValid"},
+ {"Compressed", FlagCompressed, "Compressed"},
+ {"Keyframe", FlagKeyframe, "Keyframe"},
+ {"Unknown", Flag(0x10), "Flag(0x10)"},
+ {"Combined", FlagDirtyValid | FlagCompressed, "Flag(0x03)"},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := tt.f.String()
+ if got != tt.want {
+ t.Errorf("Flag(0x%02X).String() = %q, want %q", tt.f, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestPixelFormat_String(t *testing.T) {
+ tests := []struct {
+ name string
+ p PixelFormat
+ want string
+ }{
+ {"RGBA8", PixelRGBA8, "RGBA8"},
+ {"BGRA8", PixelBGRA8, "BGRA8"},
+ {"Unknown", PixelFormat(0x99), "PixelFormat(0x99)"},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := tt.p.String()
+ if got != tt.want {
+ t.Errorf("PixelFormat(0x%02X).String() = %q, want %q", tt.p, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestPixelFormat_Valid(t *testing.T) {
+ tests := []struct {
+ name string
+ p PixelFormat
+ want bool
+ }{
+ {"RGBA8", PixelRGBA8, true},
+ {"BGRA8", PixelBGRA8, true},
+ {"Unknown", PixelFormat(0x02), false},
+ {"Max", PixelFormat(0xFF), false},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := tt.p.Valid()
+ if got != tt.want {
+ t.Errorf("PixelFormat(0x%02X).Valid() = %v, want %v", tt.p, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestCompression_String(t *testing.T) {
+ tests := []struct {
+ name string
+ c Compression
+ want string
+ }{
+ {"None", CompressionNone, "None"},
+ {"LZ4", CompressionLZ4, "LZ4"},
+ {"Zstd", CompressionZstd, "Zstd"},
+ {"Unknown", Compression(0xAB), "Compression(0xAB)"},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := tt.c.String()
+ if got != tt.want {
+ t.Errorf("Compression(0x%02X).String() = %q, want %q", tt.c, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestCompression_Valid(t *testing.T) {
+ tests := []struct {
+ name string
+ c Compression
+ want bool
+ }{
+ {"None", CompressionNone, true},
+ {"LZ4", CompressionLZ4, true},
+ {"Zstd", CompressionZstd, true},
+ {"Unknown_0x03", Compression(0x03), false},
+ {"Unknown_0xFF", Compression(0xFF), false},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := tt.c.Valid()
+ if got != tt.want {
+ t.Errorf("Compression(0x%02X).Valid() = %v, want %v", tt.c, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestFlagValues(t *testing.T) {
+ // Verify flag bit positions match the spec.
+ if FlagDirtyValid != 0x01 {
+ t.Errorf("FlagDirtyValid = 0x%02X, want 0x01", FlagDirtyValid)
+ }
+ if FlagCompressed != 0x02 {
+ t.Errorf("FlagCompressed = 0x%02X, want 0x02", FlagCompressed)
+ }
+ if FlagKeyframe != 0x04 {
+ t.Errorf("FlagKeyframe = 0x%02X, want 0x04", FlagKeyframe)
+ }
+}
+
+func TestMsgTypeValues(t *testing.T) {
+ // Verify message type values match the spec.
+ if MsgFrame != 0x01 {
+ t.Errorf("MsgFrame = 0x%02X, want 0x01", MsgFrame)
+ }
+ if MsgHandshake != 0x02 {
+ t.Errorf("MsgHandshake = 0x%02X, want 0x02", MsgHandshake)
+ }
+ if MsgAck != 0x03 {
+ t.Errorf("MsgAck = 0x%02X, want 0x03", MsgAck)
+ }
+ if MsgFrameRequest != 0x04 {
+ t.Errorf("MsgFrameRequest = 0x%02X, want 0x04", MsgFrameRequest)
+ }
+ if MsgResize != 0x05 {
+ t.Errorf("MsgResize = 0x%02X, want 0x05", MsgResize)
+ }
+ if MsgDisconnect != 0x06 {
+ t.Errorf("MsgDisconnect = 0x%02X, want 0x06", MsgDisconnect)
+ }
+}
+
+func TestPixelFormatValues(t *testing.T) {
+ if PixelRGBA8 != 0x00 {
+ t.Errorf("PixelRGBA8 = 0x%02X, want 0x00", PixelRGBA8)
+ }
+ if PixelBGRA8 != 0x01 {
+ t.Errorf("PixelBGRA8 = 0x%02X, want 0x01", PixelBGRA8)
+ }
+}
+
+func TestCompressionValues(t *testing.T) {
+ if CompressionNone != 0x00 {
+ t.Errorf("CompressionNone = 0x%02X, want 0x00", CompressionNone)
+ }
+ if CompressionLZ4 != 0x01 {
+ t.Errorf("CompressionLZ4 = 0x%02X, want 0x01", CompressionLZ4)
+ }
+ if CompressionZstd != 0x02 {
+ t.Errorf("CompressionZstd = 0x%02X, want 0x02", CompressionZstd)
+ }
+}
diff --git a/internal/transport/socket/conn.go b/internal/transport/socket/conn.go
new file mode 100644
index 0000000..f0445e8
--- /dev/null
+++ b/internal/transport/socket/conn.go
@@ -0,0 +1,183 @@
+package socket
+
+import (
+ "bufio"
+ "fmt"
+ "io"
+ "net"
+ "sync"
+
+ "github.com/gogpu/compose/internal/protocol"
+)
+
+// bufferSize is the bufio.Reader buffer size.
+// 256 KB handles most frames (including 192 KB payloads) in one syscall.
+const bufferSize = 256 * 1024
+
+// Conn wraps a net.Conn with framed header+payload read/write.
+// Concurrent reads and writes are safe (separate locks).
+// Concurrent reads from multiple goroutines are serialized by readMu.
+// Concurrent writes from multiple goroutines are serialized by writeMu.
+type Conn struct {
+ raw net.Conn
+ reader *bufio.Reader
+ writeMu sync.Mutex
+ readMu sync.Mutex
+ hdrBuf [protocol.HeaderSize]byte // reused scratch buffer for header encode/decode
+ hsBuf [protocol.HandshakeSize]byte
+}
+
+// NewConn wraps an existing network connection with framed I/O.
+func NewConn(c net.Conn) *Conn {
+ return &Conn{
+ raw: c,
+ reader: bufio.NewReaderSize(c, bufferSize),
+ }
+}
+
+// WriteFrame sends a header + payload atomically.
+// The header is encoded into an internal buffer and written together with
+// payload in a single locked section to prevent interleaving.
+func (c *Conn) WriteFrame(hdr *protocol.Header, payload []byte) error {
+ c.writeMu.Lock()
+ defer c.writeMu.Unlock()
+
+ if err := protocol.Encode(hdr, c.hdrBuf[:]); err != nil {
+ return fmt.Errorf("socket: encode header: %w", err)
+ }
+
+ // Write header bytes.
+ if _, err := c.raw.Write(c.hdrBuf[:]); err != nil {
+ return fmt.Errorf("socket: write header: %w", err)
+ }
+
+ // Write payload if present.
+ if len(payload) > 0 {
+ if _, err := c.raw.Write(payload); err != nil {
+ return fmt.Errorf("socket: write payload: %w", err)
+ }
+ }
+
+ return nil
+}
+
+// ReadFrame reads a header + payload from the connection.
+// The returned payload slice is freshly allocated if PayloadSize > 0.
+func (c *Conn) ReadFrame() (protocol.Header, []byte, error) {
+ return c.ReadFrameInto(nil)
+}
+
+// ReadFrameInto reads a header + payload into a caller-provided buffer.
+// If buf is nil or too small for the header's PayloadSize, a new buffer
+// is allocated. The returned slice may be a sub-slice of buf or a new
+// allocation.
+func (c *Conn) ReadFrameInto(buf []byte) (protocol.Header, []byte, error) {
+ c.readMu.Lock()
+ defer c.readMu.Unlock()
+
+ // Read exactly HeaderSize bytes into scratch buffer.
+ if _, err := io.ReadFull(c.reader, c.hdrBuf[:]); err != nil {
+ return protocol.Header{}, nil, fmt.Errorf("socket: read header: %w", err)
+ }
+
+ hdr, err := protocol.Decode(c.hdrBuf[:])
+ if err != nil {
+ return protocol.Header{}, nil, fmt.Errorf("socket: decode header: %w", err)
+ }
+
+ // Read payload.
+ size := int(hdr.PayloadSize)
+ if size == 0 {
+ return hdr, nil, nil
+ }
+
+ // Reuse or allocate payload buffer.
+ var payload []byte
+ if len(buf) >= size {
+ payload = buf[:size]
+ } else {
+ payload = make([]byte, size)
+ }
+
+ if _, err := io.ReadFull(c.reader, payload); err != nil {
+ return protocol.Header{}, nil, fmt.Errorf("socket: read payload (%d bytes): %w", size, err)
+ }
+
+ return hdr, payload, nil
+}
+
+// WriteHandshakeHello sends a HelloMsg on the connection.
+func (c *Conn) WriteHandshakeHello(msg *protocol.HelloMsg) error {
+ c.writeMu.Lock()
+ defer c.writeMu.Unlock()
+
+ if err := protocol.EncodeHello(msg, c.hsBuf[:]); err != nil {
+ return fmt.Errorf("socket: encode hello: %w", err)
+ }
+
+ if _, err := c.raw.Write(c.hsBuf[:]); err != nil {
+ return fmt.Errorf("socket: write hello: %w", err)
+ }
+
+ return nil
+}
+
+// ReadHandshakeHello reads a HelloMsg from the connection.
+func (c *Conn) ReadHandshakeHello() (protocol.HelloMsg, error) {
+ c.readMu.Lock()
+ defer c.readMu.Unlock()
+
+ if _, err := io.ReadFull(c.reader, c.hsBuf[:]); err != nil {
+ return protocol.HelloMsg{}, fmt.Errorf("socket: read hello: %w", err)
+ }
+
+ msg, err := protocol.DecodeHello(c.hsBuf[:])
+ if err != nil {
+ return protocol.HelloMsg{}, fmt.Errorf("socket: decode hello: %w", err)
+ }
+
+ return msg, nil
+}
+
+// WriteHandshakeWelcome sends a WelcomeMsg on the connection.
+func (c *Conn) WriteHandshakeWelcome(msg *protocol.WelcomeMsg) error {
+ c.writeMu.Lock()
+ defer c.writeMu.Unlock()
+
+ if err := protocol.EncodeWelcome(msg, c.hsBuf[:]); err != nil {
+ return fmt.Errorf("socket: encode welcome: %w", err)
+ }
+
+ if _, err := c.raw.Write(c.hsBuf[:]); err != nil {
+ return fmt.Errorf("socket: write welcome: %w", err)
+ }
+
+ return nil
+}
+
+// ReadHandshakeWelcome reads a WelcomeMsg from the connection.
+func (c *Conn) ReadHandshakeWelcome() (protocol.WelcomeMsg, error) {
+ c.readMu.Lock()
+ defer c.readMu.Unlock()
+
+ if _, err := io.ReadFull(c.reader, c.hsBuf[:]); err != nil {
+ return protocol.WelcomeMsg{}, fmt.Errorf("socket: read welcome: %w", err)
+ }
+
+ msg, err := protocol.DecodeWelcome(c.hsBuf[:])
+ if err != nil {
+ return protocol.WelcomeMsg{}, fmt.Errorf("socket: decode welcome: %w", err)
+ }
+
+ return msg, nil
+}
+
+// Close closes the underlying network connection.
+func (c *Conn) Close() error {
+ return c.raw.Close()
+}
+
+// RemoteAddr returns the remote address of the underlying connection.
+func (c *Conn) RemoteAddr() net.Addr {
+ return c.raw.RemoteAddr()
+}
diff --git a/internal/transport/socket/conn_test.go b/internal/transport/socket/conn_test.go
new file mode 100644
index 0000000..0843b7a
--- /dev/null
+++ b/internal/transport/socket/conn_test.go
@@ -0,0 +1,738 @@
+package socket
+
+import (
+ "bytes"
+ "errors"
+ "io"
+ "net"
+ "sync"
+ "testing"
+
+ "github.com/gogpu/compose/internal/protocol"
+)
+
+// newPipeConns creates a pair of Conn values connected by net.Pipe.
+func newPipeConns(t *testing.T) (client, server *Conn) {
+ t.Helper()
+ c1, c2 := net.Pipe()
+ t.Cleanup(func() {
+ c1.Close()
+ c2.Close()
+ })
+ return NewConn(c1), NewConn(c2)
+}
+
+// makeHeader returns a minimal valid header for testing.
+func makeHeader(msgType protocol.MsgType, payloadSize uint32) protocol.Header {
+ return protocol.Header{
+ Magic: protocol.Magic,
+ Version: protocol.ProtocolVersion,
+ MsgType: msgType,
+ PayloadSize: payloadSize,
+ }
+}
+
+func TestWriteReadFrame_RoundTrip(t *testing.T) {
+ tests := []struct {
+ name string
+ msgType protocol.MsgType
+ payload []byte
+ }{
+ {
+ name: "frame with payload",
+ msgType: protocol.MsgFrame,
+ payload: []byte("hello compose"),
+ },
+ {
+ name: "ack empty payload",
+ msgType: protocol.MsgAck,
+ payload: nil,
+ },
+ {
+ name: "frame request empty",
+ msgType: protocol.MsgFrameRequest,
+ payload: []byte{},
+ },
+ {
+ name: "single byte payload",
+ msgType: protocol.MsgFrame,
+ payload: []byte{0xFF},
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ client, server := newPipeConns(t)
+
+ payloadLen := uint32(len(tt.payload))
+ hdr := makeHeader(tt.msgType, payloadLen)
+
+ // Write in background.
+ errCh := make(chan error, 1)
+ go func() {
+ errCh <- client.WriteFrame(&hdr, tt.payload)
+ }()
+
+ gotHdr, gotPayload, err := server.ReadFrame()
+ if err != nil {
+ t.Fatalf("ReadFrame: %v", err)
+ }
+
+ if werr := <-errCh; werr != nil {
+ t.Fatalf("WriteFrame: %v", werr)
+ }
+
+ if gotHdr.MsgType != tt.msgType {
+ t.Errorf("MsgType = %v, want %v", gotHdr.MsgType, tt.msgType)
+ }
+ if gotHdr.PayloadSize != payloadLen {
+ t.Errorf("PayloadSize = %d, want %d", gotHdr.PayloadSize, payloadLen)
+ }
+
+ if payloadLen == 0 {
+ if gotPayload != nil {
+ t.Errorf("payload = %v, want nil", gotPayload)
+ }
+ } else if !bytes.Equal(gotPayload, tt.payload) {
+ t.Errorf("payload mismatch: got %d bytes, want %d bytes", len(gotPayload), len(tt.payload))
+ }
+ })
+ }
+}
+
+func TestWriteReadFrame_LargePayload(t *testing.T) {
+ client, server := newPipeConns(t)
+
+ // 1 MB payload.
+ payload := make([]byte, 1<<20)
+ for i := range payload {
+ payload[i] = byte(i % 256)
+ }
+
+ hdr := makeHeader(protocol.MsgFrame, uint32(len(payload)))
+
+ errCh := make(chan error, 1)
+ go func() {
+ errCh <- client.WriteFrame(&hdr, payload)
+ }()
+
+ gotHdr, gotPayload, err := server.ReadFrame()
+ if err != nil {
+ t.Fatalf("ReadFrame: %v", err)
+ }
+ if werr := <-errCh; werr != nil {
+ t.Fatalf("WriteFrame: %v", werr)
+ }
+
+ if gotHdr.PayloadSize != uint32(len(payload)) {
+ t.Errorf("PayloadSize = %d, want %d", gotHdr.PayloadSize, len(payload))
+ }
+ if !bytes.Equal(gotPayload, payload) {
+ t.Error("large payload content mismatch")
+ }
+}
+
+func TestReadFrameInto_ReuseBuffer(t *testing.T) {
+ client, server := newPipeConns(t)
+
+ payload := []byte("reuse this buffer please")
+ hdr := makeHeader(protocol.MsgFrame, uint32(len(payload)))
+
+ errCh := make(chan error, 1)
+ go func() {
+ errCh <- client.WriteFrame(&hdr, payload)
+ }()
+
+ // Provide a sufficiently large buffer.
+ buf := make([]byte, 1024)
+ _, gotPayload, err := server.ReadFrameInto(buf)
+ if err != nil {
+ t.Fatalf("ReadFrameInto: %v", err)
+ }
+ if werr := <-errCh; werr != nil {
+ t.Fatalf("WriteFrame: %v", werr)
+ }
+
+ if !bytes.Equal(gotPayload, payload) {
+ t.Error("payload mismatch")
+ }
+
+ // Verify the returned slice shares backing array with buf.
+ if &gotPayload[0] != &buf[0] {
+ t.Error("ReadFrameInto did not reuse provided buffer")
+ }
+}
+
+func TestReadFrameInto_SmallBuffer_Allocates(t *testing.T) {
+ client, server := newPipeConns(t)
+
+ payload := []byte("too big for tiny buffer")
+ hdr := makeHeader(protocol.MsgFrame, uint32(len(payload)))
+
+ errCh := make(chan error, 1)
+ go func() {
+ errCh <- client.WriteFrame(&hdr, payload)
+ }()
+
+ // Buffer smaller than payload.
+ buf := make([]byte, 4)
+ _, gotPayload, err := server.ReadFrameInto(buf)
+ if err != nil {
+ t.Fatalf("ReadFrameInto: %v", err)
+ }
+ if werr := <-errCh; werr != nil {
+ t.Fatalf("WriteFrame: %v", werr)
+ }
+
+ if !bytes.Equal(gotPayload, payload) {
+ t.Error("payload mismatch")
+ }
+}
+
+func TestConcurrentWriteRead(t *testing.T) {
+ client, server := newPipeConns(t)
+
+ const numFrames = 100
+ payload := []byte("concurrent test payload")
+ hdr := makeHeader(protocol.MsgFrame, uint32(len(payload)))
+
+ // Writer goroutine: send numFrames frames.
+ var wg sync.WaitGroup
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ for i := range numFrames {
+ h := hdr
+ h.Sequence = uint64(i)
+ if err := client.WriteFrame(&h, payload); err != nil {
+ t.Errorf("WriteFrame[%d]: %v", i, err)
+ return
+ }
+ }
+ }()
+
+ // Reader: receive numFrames frames.
+ for i := range numFrames {
+ gotHdr, gotPayload, err := server.ReadFrame()
+ if err != nil {
+ t.Fatalf("ReadFrame[%d]: %v", i, err)
+ }
+ if gotHdr.Sequence != uint64(i) {
+ t.Errorf("frame %d: Sequence = %d, want %d", i, gotHdr.Sequence, i)
+ }
+ if !bytes.Equal(gotPayload, payload) {
+ t.Errorf("frame %d: payload mismatch", i)
+ }
+ }
+
+ wg.Wait()
+}
+
+func TestConcurrentWriters(t *testing.T) {
+ client, server := newPipeConns(t)
+
+ const numWriters = 4
+ const framesPerWriter = 25
+ payload := []byte("multi-writer")
+ hdr := makeHeader(protocol.MsgFrame, uint32(len(payload)))
+
+ var wg sync.WaitGroup
+ for w := range numWriters {
+ wg.Add(1)
+ go func(writerID int) {
+ defer wg.Done()
+ for i := range framesPerWriter {
+ h := hdr
+ h.ModuleID = uint64(writerID)
+ h.Sequence = uint64(i)
+ if err := client.WriteFrame(&h, payload); err != nil {
+ t.Errorf("writer %d frame %d: %v", writerID, i, err)
+ return
+ }
+ }
+ }(w)
+ }
+
+ // Read all frames.
+ total := numWriters * framesPerWriter
+ for i := range total {
+ _, gotPayload, err := server.ReadFrame()
+ if err != nil {
+ t.Fatalf("ReadFrame[%d]: %v", i, err)
+ }
+ if !bytes.Equal(gotPayload, payload) {
+ t.Errorf("frame %d: payload mismatch", i)
+ }
+ }
+
+ wg.Wait()
+}
+
+func TestHandshake_RoundTrip(t *testing.T) {
+ client, server := newPipeConns(t)
+
+ // Module sends Hello.
+ hello := protocol.HelloMsg{
+ Magic: protocol.Magic,
+ Version: protocol.ProtocolVersion,
+ Width: 400,
+ Height: 120,
+ PreferredFPS: 60,
+ Transport: protocol.TransportSocket,
+ }
+ protocol.SetName(&hello, "test-module")
+
+ errCh := make(chan error, 1)
+ go func() {
+ errCh <- client.WriteHandshakeHello(&hello)
+ }()
+
+ gotHello, err := server.ReadHandshakeHello()
+ if err != nil {
+ t.Fatalf("ReadHandshakeHello: %v", err)
+ }
+ if werr := <-errCh; werr != nil {
+ t.Fatalf("WriteHandshakeHello: %v", werr)
+ }
+
+ if gotHello.Magic != protocol.Magic {
+ t.Error("hello: magic mismatch")
+ }
+ if protocol.GetName(&gotHello) != "test-module" {
+ t.Errorf("hello: name = %q, want %q", protocol.GetName(&gotHello), "test-module")
+ }
+ if gotHello.Width != 400 {
+ t.Errorf("hello: Width = %d, want 400", gotHello.Width)
+ }
+ if gotHello.Height != 120 {
+ t.Errorf("hello: Height = %d, want 120", gotHello.Height)
+ }
+ if gotHello.PreferredFPS != 60 {
+ t.Errorf("hello: PreferredFPS = %d, want 60", gotHello.PreferredFPS)
+ }
+
+ // Compositor sends Welcome.
+ welcome := protocol.WelcomeMsg{
+ Magic: protocol.Magic,
+ Version: protocol.ProtocolVersion,
+ ModuleID: 42,
+ Accepted: 1,
+ Transport: protocol.TransportSocket,
+ MinVersion: 1,
+ MaxVersion: 1,
+ }
+
+ go func() {
+ errCh <- server.WriteHandshakeWelcome(&welcome)
+ }()
+
+ gotWelcome, err := client.ReadHandshakeWelcome()
+ if err != nil {
+ t.Fatalf("ReadHandshakeWelcome: %v", err)
+ }
+ if werr := <-errCh; werr != nil {
+ t.Fatalf("WriteHandshakeWelcome: %v", werr)
+ }
+
+ if gotWelcome.ModuleID != 42 {
+ t.Errorf("welcome: ModuleID = %d, want 42", gotWelcome.ModuleID)
+ }
+ if gotWelcome.Accepted != 1 {
+ t.Errorf("welcome: Accepted = %d, want 1", gotWelcome.Accepted)
+ }
+}
+
+func TestReadFrame_ConnectionClosed(t *testing.T) {
+ client, server := newPipeConns(t)
+
+ // Close the writing side immediately.
+ client.Close()
+
+ _, _, err := server.ReadFrame()
+ if err == nil {
+ t.Fatal("expected error on closed connection, got nil")
+ }
+
+ // Should wrap io.EOF or io.ErrUnexpectedEOF.
+ if !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrUnexpectedEOF) {
+ // net.Pipe close produces io.EOF via io.ReadFull -> io.ErrUnexpectedEOF,
+ // but the exact wrapping depends on how many bytes were read.
+ // Accept any error as valid — the point is we don't hang.
+ t.Logf("got error (acceptable): %v", err)
+ }
+}
+
+func TestReadHandshakeHello_ConnectionClosed(t *testing.T) {
+ client, server := newPipeConns(t)
+ client.Close()
+
+ _, err := server.ReadHandshakeHello()
+ if err == nil {
+ t.Fatal("expected error on closed connection, got nil")
+ }
+}
+
+func TestReadHandshakeWelcome_ConnectionClosed(t *testing.T) {
+ client, server := newPipeConns(t)
+ client.Close()
+
+ _, err := server.ReadHandshakeWelcome()
+ if err == nil {
+ t.Fatal("expected error on closed connection, got nil")
+ }
+}
+
+func TestConn_RemoteAddr(t *testing.T) {
+ client, _ := newPipeConns(t)
+
+ addr := client.RemoteAddr()
+ if addr == nil {
+ t.Fatal("RemoteAddr returned nil")
+ }
+}
+
+func TestConn_Close(t *testing.T) {
+ c1, c2 := net.Pipe()
+ conn := NewConn(c1)
+
+ // Close should not error on first call.
+ if err := conn.Close(); err != nil {
+ t.Fatalf("Close: %v", err)
+ }
+
+ // The other end should see EOF.
+ buf := make([]byte, 1)
+ _, err := c2.Read(buf)
+ if err == nil {
+ t.Fatal("expected error after close, got nil")
+ }
+
+ c2.Close()
+}
+
+func TestWriteReadFrame_AllHeaderFields(t *testing.T) {
+ client, server := newPipeConns(t)
+
+ payload := []byte{0xDE, 0xAD, 0xBE, 0xEF}
+ hdr := protocol.Header{
+ Magic: protocol.Magic,
+ Version: protocol.ProtocolVersion,
+ MsgType: protocol.MsgFrame,
+ Flags: protocol.FlagDirtyValid | protocol.FlagKeyframe,
+ ModuleID: 12345,
+ Sequence: 99,
+ TimestampNs: 1_000_000_000,
+ Width: 1920,
+ Height: 1080,
+ Stride: 1920 * 4,
+ DirtyX: 10,
+ DirtyY: 20,
+ DirtyW: 100,
+ DirtyH: 200,
+ PixelFormat: protocol.PixelRGBA8,
+ Compression: protocol.CompressionNone,
+ PayloadSize: uint32(len(payload)),
+ UncompressedSize: uint32(len(payload)),
+ }
+
+ errCh := make(chan error, 1)
+ go func() {
+ errCh <- client.WriteFrame(&hdr, payload)
+ }()
+
+ got, gotPayload, err := server.ReadFrame()
+ if err != nil {
+ t.Fatalf("ReadFrame: %v", err)
+ }
+ if werr := <-errCh; werr != nil {
+ t.Fatalf("WriteFrame: %v", werr)
+ }
+
+ // Verify all fields survived the round trip.
+ if got.ModuleID != 12345 {
+ t.Errorf("ModuleID = %d, want 12345", got.ModuleID)
+ }
+ if got.Sequence != 99 {
+ t.Errorf("Sequence = %d, want 99", got.Sequence)
+ }
+ if got.TimestampNs != 1_000_000_000 {
+ t.Errorf("TimestampNs = %d, want 1000000000", got.TimestampNs)
+ }
+ if got.Width != 1920 {
+ t.Errorf("Width = %d, want 1920", got.Width)
+ }
+ if got.Height != 1080 {
+ t.Errorf("Height = %d, want 1080", got.Height)
+ }
+ if got.Stride != 1920*4 {
+ t.Errorf("Stride = %d, want %d", got.Stride, 1920*4)
+ }
+ if got.DirtyX != 10 || got.DirtyY != 20 || got.DirtyW != 100 || got.DirtyH != 200 {
+ t.Errorf("DirtyRect = (%d,%d,%d,%d), want (10,20,100,200)",
+ got.DirtyX, got.DirtyY, got.DirtyW, got.DirtyH)
+ }
+ if got.Flags != (protocol.FlagDirtyValid | protocol.FlagKeyframe) {
+ t.Errorf("Flags = %d, want %d", got.Flags, protocol.FlagDirtyValid|protocol.FlagKeyframe)
+ }
+ if !bytes.Equal(gotPayload, payload) {
+ t.Error("payload mismatch")
+ }
+}
+
+func TestWriteReadFrame_MultipleSequential(t *testing.T) {
+ client, server := newPipeConns(t)
+
+ const count = 10
+ errCh := make(chan error, 1)
+ go func() {
+ for i := range count {
+ p := []byte{byte(i)}
+ h := makeHeader(protocol.MsgFrame, 1)
+ h.Sequence = uint64(i)
+ if err := client.WriteFrame(&h, p); err != nil {
+ errCh <- err
+ return
+ }
+ }
+ errCh <- nil
+ }()
+
+ for i := range count {
+ hdr, payload, err := server.ReadFrame()
+ if err != nil {
+ t.Fatalf("ReadFrame[%d]: %v", i, err)
+ }
+ if hdr.Sequence != uint64(i) {
+ t.Errorf("frame %d: Sequence = %d", i, hdr.Sequence)
+ }
+ if len(payload) != 1 || payload[0] != byte(i) {
+ t.Errorf("frame %d: payload mismatch", i)
+ }
+ }
+
+ if werr := <-errCh; werr != nil {
+ t.Fatalf("WriteFrame: %v", werr)
+ }
+}
+
+func TestWriteFrame_ClosedConn_HeaderWriteError(t *testing.T) {
+ c1, c2 := net.Pipe()
+ conn := NewConn(c1)
+ c2.Close()
+
+ // Close the underlying conn so the header write fails.
+ c1.Close()
+
+ hdr := makeHeader(protocol.MsgFrame, 5)
+ err := conn.WriteFrame(&hdr, []byte("hello"))
+ if err == nil {
+ t.Fatal("expected error writing to closed conn, got nil")
+ }
+}
+
+func TestWriteFrame_ClosedConn_PayloadWriteError(t *testing.T) {
+ // Use a real socket pair so we can close mid-flight.
+ addr := testSocketPath(t)
+ ln, err := Listen(addr)
+ if err != nil {
+ t.Fatalf("Listen: %v", err)
+ }
+ defer ln.Close()
+
+ acceptCh := make(chan *Conn, 1)
+ go func() {
+ c, err := ln.Accept()
+ if err != nil {
+ return
+ }
+ acceptCh <- c
+ }()
+
+ dialer := NewDialer(addr)
+ clientConn, err := dialer.Dial()
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+
+ serverConn := <-acceptCh
+ // Close the server side so writes from client eventually fail.
+ serverConn.Close()
+
+ // Write a large payload — the header write may succeed but the
+ // large payload write should fail because the peer is closed.
+ bigPayload := make([]byte, 1<<20) // 1 MB
+ hdr := makeHeader(protocol.MsgFrame, uint32(len(bigPayload)))
+
+ // Retry a few times — the first write might succeed due to kernel buffering.
+ var lastErr error
+ for range 100 {
+ lastErr = clientConn.WriteFrame(&hdr, bigPayload)
+ if lastErr != nil {
+ break
+ }
+ }
+ if lastErr == nil {
+ t.Log("write never failed (kernel buffered everything) — skipping payload error test")
+ }
+
+ clientConn.Close()
+}
+
+func TestWriteHandshakeHello_ClosedConn(t *testing.T) {
+ c1, c2 := net.Pipe()
+ conn := NewConn(c1)
+ c2.Close()
+ c1.Close()
+
+ hello := protocol.HelloMsg{
+ Magic: protocol.Magic,
+ Version: protocol.ProtocolVersion,
+ }
+ err := conn.WriteHandshakeHello(&hello)
+ if err == nil {
+ t.Fatal("expected error writing hello to closed conn, got nil")
+ }
+}
+
+func TestWriteHandshakeWelcome_ClosedConn(t *testing.T) {
+ c1, c2 := net.Pipe()
+ conn := NewConn(c1)
+ c2.Close()
+ c1.Close()
+
+ welcome := protocol.WelcomeMsg{
+ Magic: protocol.Magic,
+ Version: protocol.ProtocolVersion,
+ }
+ err := conn.WriteHandshakeWelcome(&welcome)
+ if err == nil {
+ t.Fatal("expected error writing welcome to closed conn, got nil")
+ }
+}
+
+func TestReadFrame_InvalidMagic(t *testing.T) {
+ c1, c2 := net.Pipe()
+ defer c1.Close()
+ defer c2.Close()
+
+ // Write 64 bytes of garbage (invalid magic).
+ go func() {
+ garbage := make([]byte, protocol.HeaderSize)
+ garbage[0] = 0xFF // wrong magic
+ garbage[1] = 0xFF
+ garbage[2] = 0xFF
+ garbage[3] = 0xFF
+ // Set a valid MsgType at offset 6 to avoid ErrUnknownMsgType.
+ garbage[6] = uint8(protocol.MsgFrame)
+ c1.Write(garbage)
+ }()
+
+ conn := NewConn(c2)
+ _, _, err := conn.ReadFrame()
+ if err == nil {
+ t.Fatal("expected error for invalid magic, got nil")
+ }
+}
+
+func TestReadFrame_PayloadReadError(t *testing.T) {
+ c1, c2 := net.Pipe()
+ defer c2.Close()
+
+ conn := NewConn(c2)
+
+ // Write a valid header claiming 1000 bytes of payload, then close.
+ go func() {
+ hdr := makeHeader(protocol.MsgFrame, 1000)
+ var buf [protocol.HeaderSize]byte
+ protocol.Encode(&hdr, buf[:])
+ c1.Write(buf[:])
+ // Write only 10 bytes of the claimed 1000, then close.
+ c1.Write([]byte("short data"))
+ c1.Close()
+ }()
+
+ _, _, err := conn.ReadFrame()
+ if err == nil {
+ t.Fatal("expected error for truncated payload, got nil")
+ }
+}
+
+func TestReadHandshakeHello_InvalidMagic(t *testing.T) {
+ c1, c2 := net.Pipe()
+ defer c1.Close()
+ defer c2.Close()
+
+ go func() {
+ garbage := make([]byte, protocol.HandshakeSize)
+ garbage[0] = 0xBA
+ garbage[1] = 0xAD
+ c1.Write(garbage)
+ }()
+
+ conn := NewConn(c2)
+ _, err := conn.ReadHandshakeHello()
+ if err == nil {
+ t.Fatal("expected error for invalid hello magic, got nil")
+ }
+}
+
+func TestReadHandshakeWelcome_InvalidMagic(t *testing.T) {
+ c1, c2 := net.Pipe()
+ defer c1.Close()
+ defer c2.Close()
+
+ go func() {
+ garbage := make([]byte, protocol.HandshakeSize)
+ garbage[0] = 0xDE
+ garbage[1] = 0xAD
+ c1.Write(garbage)
+ }()
+
+ conn := NewConn(c2)
+ _, err := conn.ReadHandshakeWelcome()
+ if err == nil {
+ t.Fatal("expected error for invalid welcome magic, got nil")
+ }
+}
+
+func BenchmarkWriteReadFrame(b *testing.B) {
+ c1, c2 := net.Pipe()
+ defer c1.Close()
+ defer c2.Close()
+
+ client := NewConn(c1)
+ server := NewConn(c2)
+
+ // 192 KB payload (typical frame: 320x150 RGBA).
+ payload := make([]byte, 192*1024)
+ for i := range payload {
+ payload[i] = byte(i)
+ }
+ hdr := makeHeader(protocol.MsgFrame, uint32(len(payload)))
+ buf := make([]byte, len(payload))
+
+ // Writer goroutine.
+ done := make(chan struct{})
+ go func() {
+ defer close(done)
+ for range b.N {
+ if err := client.WriteFrame(&hdr, payload); err != nil {
+ b.Errorf("WriteFrame: %v", err)
+ return
+ }
+ }
+ }()
+
+ b.SetBytes(int64(protocol.HeaderSize + len(payload)))
+ b.ResetTimer()
+
+ for range b.N {
+ _, _, err := server.ReadFrameInto(buf)
+ if err != nil {
+ b.Fatalf("ReadFrameInto: %v", err)
+ }
+ }
+
+ b.StopTimer()
+ <-done
+}
diff --git a/internal/transport/socket/dialer.go b/internal/transport/socket/dialer.go
new file mode 100644
index 0000000..ea7d9d7
--- /dev/null
+++ b/internal/transport/socket/dialer.go
@@ -0,0 +1,44 @@
+package socket
+
+import (
+ "fmt"
+ "net"
+ "time"
+)
+
+// defaultDialTimeout is the default connection timeout.
+const defaultDialTimeout = 5 * time.Second
+
+// Dialer connects to a compositor's Unix domain socket.
+// Reconnection with backoff belongs in higher layers (internal/conn.Manager);
+// the dialer is a simple one-shot connect helper.
+type Dialer struct {
+ addr string
+ timeout time.Duration
+}
+
+// NewDialer creates a dialer for the given Unix domain socket address.
+// The default timeout is 5 seconds.
+func NewDialer(addr string) *Dialer {
+ return &Dialer{
+ addr: addr,
+ timeout: defaultDialTimeout,
+ }
+}
+
+// Dial connects to the compositor using the default timeout.
+// The returned [Conn] is ready for handshake (WriteHandshakeHello /
+// ReadHandshakeWelcome).
+func (d *Dialer) Dial() (*Conn, error) {
+ return d.DialWithTimeout(d.timeout)
+}
+
+// DialWithTimeout connects to the compositor with a custom timeout.
+// A zero or negative timeout means no deadline.
+func (d *Dialer) DialWithTimeout(timeout time.Duration) (*Conn, error) {
+ c, err := net.DialTimeout("unix", d.addr, timeout)
+ if err != nil {
+ return nil, fmt.Errorf("socket: dial %s: %w", d.addr, err)
+ }
+ return NewConn(c), nil
+}
diff --git a/internal/transport/socket/dialer_test.go b/internal/transport/socket/dialer_test.go
new file mode 100644
index 0000000..33a51ed
--- /dev/null
+++ b/internal/transport/socket/dialer_test.go
@@ -0,0 +1,248 @@
+package socket
+
+import (
+ "testing"
+ "time"
+
+ "github.com/gogpu/compose/internal/protocol"
+)
+
+func TestNewDialer(t *testing.T) {
+ d := NewDialer("/tmp/test.sock")
+ if d.addr != "/tmp/test.sock" {
+ t.Errorf("addr = %q, want %q", d.addr, "/tmp/test.sock")
+ }
+ if d.timeout != defaultDialTimeout {
+ t.Errorf("timeout = %v, want %v", d.timeout, defaultDialTimeout)
+ }
+}
+
+func TestDial_Success(t *testing.T) {
+ addr := testSocketPath(t)
+
+ ln, err := Listen(addr)
+ if err != nil {
+ t.Fatalf("Listen: %v", err)
+ }
+ defer ln.Close()
+
+ // Accept in background.
+ type result struct {
+ conn *Conn
+ err error
+ }
+ ch := make(chan result, 1)
+ go func() {
+ c, err := ln.Accept()
+ ch <- result{c, err}
+ }()
+
+ dialer := NewDialer(addr)
+ conn, err := dialer.Dial()
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ defer conn.Close()
+
+ r := <-ch
+ if r.err != nil {
+ t.Fatalf("Accept: %v", r.err)
+ }
+ t.Cleanup(func() { _ = r.conn.Close() })
+}
+
+func TestDial_Timeout(t *testing.T) {
+ // Use a non-existent socket path.
+ addr := testSocketPath(t)
+
+ dialer := NewDialer(addr)
+ _, err := dialer.DialWithTimeout(100 * time.Millisecond)
+ if err == nil {
+ t.Fatal("expected error dialing non-existent socket, got nil")
+ }
+}
+
+func TestDialWithTimeout_CustomTimeout(t *testing.T) {
+ addr := testSocketPath(t)
+
+ ln, err := Listen(addr)
+ if err != nil {
+ t.Fatalf("Listen: %v", err)
+ }
+ defer ln.Close()
+
+ // Accept in background.
+ ch := make(chan error, 1)
+ go func() {
+ c, err := ln.Accept()
+ if err == nil {
+ c.Close()
+ }
+ ch <- err
+ }()
+
+ dialer := NewDialer(addr)
+ conn, err := dialer.DialWithTimeout(2 * time.Second)
+ if err != nil {
+ t.Fatalf("DialWithTimeout: %v", err)
+ }
+ conn.Close()
+
+ if err := <-ch; err != nil {
+ t.Fatalf("Accept: %v", err)
+ }
+}
+
+func TestDial_FullHandshake(t *testing.T) {
+ addr := testSocketPath(t)
+
+ ln, err := Listen(addr)
+ if err != nil {
+ t.Fatalf("Listen: %v", err)
+ }
+ defer ln.Close()
+
+ // Server Accept goroutine.
+ type result struct {
+ conn *Conn
+ err error
+ }
+ acceptCh := make(chan result, 1)
+ go func() {
+ c, err := ln.Accept()
+ acceptCh <- result{c, err}
+ }()
+
+ // Client connects.
+ dialer := NewDialer(addr)
+ clientConn, err := dialer.Dial()
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ defer clientConn.Close()
+
+ r := <-acceptCh
+ if r.err != nil {
+ t.Fatalf("Accept: %v", r.err)
+ }
+ serverConn := r.conn
+ defer serverConn.Close()
+
+ // Full handshake: Hello then Welcome.
+ hello := protocol.HelloMsg{
+ Magic: protocol.Magic,
+ Version: protocol.ProtocolVersion,
+ Width: 320,
+ Height: 240,
+ PreferredFPS: 1,
+ Transport: protocol.TransportSocket,
+ }
+ protocol.SetName(&hello, "clock")
+
+ errCh := make(chan error, 1)
+ go func() {
+ errCh <- clientConn.WriteHandshakeHello(&hello)
+ }()
+
+ gotHello, err := serverConn.ReadHandshakeHello()
+ if err != nil {
+ t.Fatalf("ReadHandshakeHello: %v", err)
+ }
+ if werr := <-errCh; werr != nil {
+ t.Fatalf("WriteHandshakeHello: %v", werr)
+ }
+
+ if protocol.GetName(&gotHello) != "clock" {
+ t.Errorf("name = %q, want %q", protocol.GetName(&gotHello), "clock")
+ }
+
+ welcome := protocol.WelcomeMsg{
+ Magic: protocol.Magic,
+ Version: protocol.ProtocolVersion,
+ ModuleID: 7,
+ Accepted: 1,
+ Transport: protocol.TransportSocket,
+ MinVersion: 1,
+ MaxVersion: 1,
+ }
+
+ go func() {
+ errCh <- serverConn.WriteHandshakeWelcome(&welcome)
+ }()
+
+ gotWelcome, err := clientConn.ReadHandshakeWelcome()
+ if err != nil {
+ t.Fatalf("ReadHandshakeWelcome: %v", err)
+ }
+ if werr := <-errCh; werr != nil {
+ t.Fatalf("WriteHandshakeWelcome: %v", werr)
+ }
+
+ if gotWelcome.ModuleID != 7 {
+ t.Errorf("ModuleID = %d, want 7", gotWelcome.ModuleID)
+ }
+ if gotWelcome.Accepted != 1 {
+ t.Errorf("Accepted = %d, want 1", gotWelcome.Accepted)
+ }
+
+ // Post-handshake: send a frame.
+ payload := []byte("pixel data here")
+ hdr := protocol.Header{
+ Magic: protocol.Magic,
+ Version: protocol.ProtocolVersion,
+ MsgType: protocol.MsgFrame,
+ ModuleID: 7,
+ Sequence: 1,
+ PayloadSize: uint32(len(payload)),
+ }
+
+ go func() {
+ errCh <- clientConn.WriteFrame(&hdr, payload)
+ }()
+
+ gotHdr, gotPayload, err := serverConn.ReadFrame()
+ if err != nil {
+ t.Fatalf("ReadFrame: %v", err)
+ }
+ if werr := <-errCh; werr != nil {
+ t.Fatalf("WriteFrame: %v", werr)
+ }
+
+ if gotHdr.ModuleID != 7 {
+ t.Errorf("frame ModuleID = %d, want 7", gotHdr.ModuleID)
+ }
+ if string(gotPayload) != "pixel data here" {
+ t.Errorf("frame payload = %q, want %q", gotPayload, "pixel data here")
+ }
+}
+
+func TestDial_ZeroTimeout(t *testing.T) {
+ addr := testSocketPath(t)
+
+ ln, err := Listen(addr)
+ if err != nil {
+ t.Fatalf("Listen: %v", err)
+ }
+ defer ln.Close()
+
+ ch := make(chan error, 1)
+ go func() {
+ c, err := ln.Accept()
+ if err == nil {
+ c.Close()
+ }
+ ch <- err
+ }()
+
+ dialer := NewDialer(addr)
+ // Zero timeout = no deadline.
+ conn, err := dialer.DialWithTimeout(0)
+ if err != nil {
+ t.Fatalf("DialWithTimeout(0): %v", err)
+ }
+ conn.Close()
+
+ if err := <-ch; err != nil {
+ t.Fatalf("Accept: %v", err)
+ }
+}
diff --git a/internal/transport/socket/doc.go b/internal/transport/socket/doc.go
new file mode 100644
index 0000000..1b563a6
--- /dev/null
+++ b/internal/transport/socket/doc.go
@@ -0,0 +1,30 @@
+// Package socket provides a Unix domain socket transport for the compose
+// protocol. It implements framed read/write of [protocol.Header] + payload
+// messages over a single [net.Conn], along with server-side listening and
+// client-side dialing.
+//
+// # Conn
+//
+// [Conn] wraps a [net.Conn] with framed I/O. Each frame on the wire is a
+// 64-byte header (see [protocol.HeaderSize]) followed by PayloadSize bytes
+// of payload. Reads and writes are independently locked, so concurrent
+// producers and consumers are safe on the same connection.
+//
+// # Listener
+//
+// [Listener] binds a Unix domain socket (AF_UNIX) and accepts incoming
+// connections. On startup it removes any stale socket file left by a
+// previous crash, preventing "address already in use" errors.
+//
+// # Dialer
+//
+// [Dialer] connects to a compositor's Unix domain socket. It is intentionally
+// simple — reconnection with backoff belongs in higher layers
+// (internal/conn.Manager), keeping the transport layer focused on
+// establishing a single connection.
+//
+// # Platform support
+//
+// Unix domain sockets are supported on Linux, macOS, and Windows 10 1803+
+// (AF_UNIX). No CGO required on any platform.
+package socket
diff --git a/internal/transport/socket/listener.go b/internal/transport/socket/listener.go
new file mode 100644
index 0000000..c9ce23d
--- /dev/null
+++ b/internal/transport/socket/listener.go
@@ -0,0 +1,64 @@
+package socket
+
+import (
+ "fmt"
+ "net"
+ "os"
+)
+
+// Listener accepts module connections on a Unix domain socket (AF_UNIX).
+// On startup it removes any stale socket file left by a previous crash.
+type Listener struct {
+ ln net.Listener
+ addr string
+}
+
+// Listen creates a Unix domain socket listener at the given path.
+//
+// If a socket file already exists at addr, it is removed first. This is the
+// standard Unix pattern to avoid "address already in use" after an unclean
+// shutdown. On Windows, AF_UNIX is supported since Windows 10 version 1803.
+func Listen(addr string) (*Listener, error) {
+ // Remove stale socket file. Ignore errors — the file may not exist,
+ // or the path may be invalid (net.Listen will catch that).
+ _ = os.Remove(addr)
+
+ ln, err := net.Listen("unix", addr)
+ if err != nil {
+ return nil, fmt.Errorf("socket: listen %s: %w", addr, err)
+ }
+
+ return &Listener{
+ ln: ln,
+ addr: addr,
+ }, nil
+}
+
+// Accept waits for and returns the next module connection.
+// The returned [Conn] is ready for handshake (WriteHandshakeWelcome /
+// ReadHandshakeHello).
+func (l *Listener) Accept() (*Conn, error) {
+ c, err := l.ln.Accept()
+ if err != nil {
+ return nil, fmt.Errorf("socket: accept: %w", err)
+ }
+ return NewConn(c), nil
+}
+
+// Close stops accepting connections and removes the socket file.
+func (l *Listener) Close() error {
+ err := l.ln.Close()
+
+ // Best-effort cleanup of socket file.
+ _ = os.Remove(l.addr)
+
+ if err != nil {
+ return fmt.Errorf("socket: close listener: %w", err)
+ }
+ return nil
+}
+
+// Addr returns the listener's network address.
+func (l *Listener) Addr() net.Addr {
+ return l.ln.Addr()
+}
diff --git a/internal/transport/socket/listener_test.go b/internal/transport/socket/listener_test.go
new file mode 100644
index 0000000..6de600f
--- /dev/null
+++ b/internal/transport/socket/listener_test.go
@@ -0,0 +1,327 @@
+package socket
+
+import (
+ "net"
+ "os"
+ "path/filepath"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/gogpu/compose/internal/protocol"
+)
+
+// testSocketPath returns a short temporary Unix socket path for testing.
+// macOS limits Unix socket paths to 104 bytes. t.TempDir() on macOS CI
+// produces paths like /var/folders/.../TestName.../test.sock which easily
+// exceeds this limit. We use os.CreateTemp with a short prefix under /tmp
+// (or %TEMP% on Windows) to guarantee a short path.
+func testSocketPath(t *testing.T) string {
+ t.Helper()
+ f, err := os.CreateTemp("", "cs-*.sock")
+ if err != nil {
+ t.Fatal(err)
+ }
+ path := f.Name()
+ f.Close()
+ os.Remove(path) // Remove the file; we need the path for the socket.
+ t.Cleanup(func() { os.Remove(path) })
+ return path
+}
+
+func TestListen_Accept_Close(t *testing.T) {
+ addr := testSocketPath(t)
+
+ ln, err := Listen(addr)
+ if err != nil {
+ t.Fatalf("Listen: %v", err)
+ }
+
+ // Accept in background.
+ type result struct {
+ conn *Conn
+ err error
+ }
+ ch := make(chan result, 1)
+ go func() {
+ c, err := ln.Accept()
+ ch <- result{c, err}
+ }()
+
+ // Connect from client side.
+ raw, err := net.Dial("unix", addr)
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ defer raw.Close()
+
+ r := <-ch
+ if r.err != nil {
+ t.Fatalf("Accept: %v", r.err)
+ }
+ defer r.conn.Close()
+
+ // Verify the accepted connection is usable.
+ if r.conn.RemoteAddr() == nil {
+ t.Error("accepted conn RemoteAddr is nil")
+ }
+
+ // Close listener.
+ if err := ln.Close(); err != nil {
+ t.Fatalf("Close: %v", err)
+ }
+
+ // Socket file should be removed.
+ if _, err := os.Stat(addr); !os.IsNotExist(err) {
+ t.Errorf("socket file still exists after Close: %v", err)
+ }
+}
+
+func TestListen_StaleSocketCleanup(t *testing.T) {
+ addr := testSocketPath(t)
+
+ // Create a stale socket file.
+ if err := os.WriteFile(addr, []byte("stale"), 0o600); err != nil {
+ t.Fatalf("create stale file: %v", err)
+ }
+
+ // Listen should remove the stale file and succeed.
+ ln, err := Listen(addr)
+ if err != nil {
+ t.Fatalf("Listen with stale file: %v", err)
+ }
+ defer ln.Close()
+
+ // Verify the listener is functional.
+ ch := make(chan error, 1)
+ go func() {
+ c, err := ln.Accept()
+ if err == nil {
+ c.Close()
+ }
+ ch <- err
+ }()
+
+ raw, err := net.Dial("unix", addr)
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ raw.Close()
+
+ if err := <-ch; err != nil {
+ t.Fatalf("Accept: %v", err)
+ }
+}
+
+func TestListen_MultipleConnections(t *testing.T) {
+ addr := testSocketPath(t)
+
+ ln, err := Listen(addr)
+ if err != nil {
+ t.Fatalf("Listen: %v", err)
+ }
+ defer ln.Close()
+
+ const numClients = 5
+
+ // Accept goroutine.
+ accepted := make([]*Conn, 0, numClients)
+ var mu sync.Mutex
+ var wg sync.WaitGroup
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ for range numClients {
+ c, err := ln.Accept()
+ if err != nil {
+ t.Errorf("Accept: %v", err)
+ return
+ }
+ mu.Lock()
+ accepted = append(accepted, c)
+ mu.Unlock()
+ }
+ }()
+
+ // Connect numClients clients.
+ clients := make([]net.Conn, numClients)
+ for i := range numClients {
+ c, err := net.Dial("unix", addr)
+ if err != nil {
+ t.Fatalf("Dial[%d]: %v", i, err)
+ }
+ clients[i] = c
+ }
+
+ wg.Wait()
+
+ mu.Lock()
+ if len(accepted) != numClients {
+ t.Errorf("accepted %d connections, want %d", len(accepted), numClients)
+ }
+ mu.Unlock()
+
+ // Cleanup.
+ for _, c := range clients {
+ c.Close()
+ }
+ mu.Lock()
+ for _, c := range accepted {
+ c.Close()
+ }
+ mu.Unlock()
+}
+
+func TestListen_CloseWhileAcceptBlocking(t *testing.T) {
+ addr := testSocketPath(t)
+
+ ln, err := Listen(addr)
+ if err != nil {
+ t.Fatalf("Listen: %v", err)
+ }
+
+ // Start blocking Accept.
+ ch := make(chan error, 1)
+ go func() {
+ _, err := ln.Accept()
+ ch <- err
+ }()
+
+ // Give Accept time to block. On CI runners, goroutine scheduling
+ // can be slow, so 200ms gives ample headroom.
+ time.Sleep(200 * time.Millisecond)
+
+ // Close should unblock Accept.
+ if err := ln.Close(); err != nil {
+ t.Fatalf("Close: %v", err)
+ }
+
+ select {
+ case err := <-ch:
+ if err == nil {
+ t.Error("Accept should return error after Close")
+ }
+ case <-time.After(5 * time.Second):
+ t.Fatal("Accept did not unblock after Close (timeout)")
+ }
+}
+
+func TestListen_Addr(t *testing.T) {
+ addr := testSocketPath(t)
+
+ ln, err := Listen(addr)
+ if err != nil {
+ t.Fatalf("Listen: %v", err)
+ }
+ defer ln.Close()
+
+ got := ln.Addr()
+ if got == nil {
+ t.Fatal("Addr returned nil")
+ }
+ if got.Network() != "unix" {
+ t.Errorf("Network = %q, want %q", got.Network(), "unix")
+ }
+}
+
+func TestListen_AcceptHandshake(t *testing.T) {
+ addr := testSocketPath(t)
+
+ ln, err := Listen(addr)
+ if err != nil {
+ t.Fatalf("Listen: %v", err)
+ }
+ defer ln.Close()
+
+ // Accept in background.
+ type result struct {
+ conn *Conn
+ err error
+ }
+ ch := make(chan result, 1)
+ go func() {
+ c, err := ln.Accept()
+ ch <- result{c, err}
+ }()
+
+ // Connect and send Hello.
+ dialer := NewDialer(addr)
+ clientConn, err := dialer.Dial()
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ defer clientConn.Close()
+
+ r := <-ch
+ if r.err != nil {
+ t.Fatalf("Accept: %v", r.err)
+ }
+ defer r.conn.Close()
+
+ // Handshake: client Hello -> server Welcome.
+ hello := protocol.HelloMsg{
+ Magic: protocol.Magic,
+ Version: protocol.ProtocolVersion,
+ Width: 800,
+ Height: 600,
+ PreferredFPS: 30,
+ Transport: protocol.TransportSocket,
+ }
+ protocol.SetName(&hello, "gallery")
+
+ errCh := make(chan error, 1)
+ go func() {
+ errCh <- clientConn.WriteHandshakeHello(&hello)
+ }()
+
+ gotHello, err := r.conn.ReadHandshakeHello()
+ if err != nil {
+ t.Fatalf("ReadHandshakeHello: %v", err)
+ }
+ if werr := <-errCh; werr != nil {
+ t.Fatalf("WriteHandshakeHello: %v", werr)
+ }
+
+ if protocol.GetName(&gotHello) != "gallery" {
+ t.Errorf("name = %q, want %q", protocol.GetName(&gotHello), "gallery")
+ }
+
+ // Server responds with Welcome.
+ welcome := protocol.WelcomeMsg{
+ Magic: protocol.Magic,
+ Version: protocol.ProtocolVersion,
+ ModuleID: 1,
+ Accepted: 1,
+ Transport: protocol.TransportSocket,
+ MinVersion: 1,
+ MaxVersion: 1,
+ }
+
+ go func() {
+ errCh <- r.conn.WriteHandshakeWelcome(&welcome)
+ }()
+
+ gotWelcome, err := clientConn.ReadHandshakeWelcome()
+ if err != nil {
+ t.Fatalf("ReadHandshakeWelcome: %v", err)
+ }
+ if werr := <-errCh; werr != nil {
+ t.Fatalf("WriteHandshakeWelcome: %v", werr)
+ }
+
+ if gotWelcome.ModuleID != 1 {
+ t.Errorf("ModuleID = %d, want 1", gotWelcome.ModuleID)
+ }
+ if gotWelcome.Accepted != 1 {
+ t.Errorf("Accepted = %d, want 1", gotWelcome.Accepted)
+ }
+}
+
+func TestListen_InvalidPath(t *testing.T) {
+ // Path that is too long or invalid on most systems.
+ addr := filepath.Join(t.TempDir(), string(make([]byte, 300)))
+ _, err := Listen(addr)
+ if err == nil {
+ t.Fatal("expected error for invalid socket path, got nil")
+ }
+}
diff --git a/option.go b/option.go
new file mode 100644
index 0000000..16b13a7
--- /dev/null
+++ b/option.go
@@ -0,0 +1,94 @@
+package compose
+
+// defaultMaxModules is the default maximum number of concurrent module
+// connections a Server will accept.
+const defaultMaxModules = 16
+
+// defaultFPS is the default frames-per-second for a Client that does not
+// specify WithFPS.
+const defaultFPS = 1
+
+// ServerOption configures a Server. Use With* functions to create options.
+type ServerOption func(*serverConfig)
+
+// ClientOption configures a Client. Use With* functions to create options.
+type ClientOption func(*clientConfig)
+
+type serverConfig struct {
+ maxModules int
+ compression string // "", "raw", "lz4"
+}
+
+type clientConfig struct {
+ name string
+ width uint32
+ height uint32
+ fps uint16
+}
+
+func defaultServerConfig() serverConfig {
+ return serverConfig{
+ maxModules: defaultMaxModules,
+ compression: "",
+ }
+}
+
+func defaultClientConfig() clientConfig {
+ return clientConfig{
+ name: "module",
+ width: 400,
+ height: 300,
+ fps: defaultFPS,
+ }
+}
+
+// WithMaxModules sets the maximum number of concurrent module connections
+// the Server will accept. Values less than 1 are clamped to 1.
+// Default: 16.
+func WithMaxModules(n int) ServerOption {
+ return func(c *serverConfig) {
+ if n < 1 {
+ n = 1
+ }
+ c.maxModules = n
+ }
+}
+
+// WithCompression enables frame payload compression on the Server.
+// Supported values: "lz4". Any other value (including "") uses raw
+// pass-through (no compression).
+// Default: no compression.
+func WithCompression(algo string) ServerOption {
+ return func(c *serverConfig) {
+ c.compression = algo
+ }
+}
+
+// WithName sets the human-readable module name sent during the handshake.
+// The name is used for logging, slot assignment, and module identification.
+// Names longer than 63 bytes are silently truncated.
+// Default: "module".
+func WithName(name string) ClientOption {
+ return func(c *clientConfig) {
+ c.name = name
+ }
+}
+
+// WithFrameSize sets the initial frame dimensions in pixels.
+// The compositor may acknowledge different dimensions during handshake.
+// Default: 400x300.
+func WithFrameSize(width, height uint32) ClientOption {
+ return func(c *clientConfig) {
+ c.width = width
+ c.height = height
+ }
+}
+
+// WithFPS sets the module's preferred frame rate.
+// A value of 0 defaults to 1 FPS (suitable for static content).
+// Default: 1.
+func WithFPS(fps uint16) ClientOption {
+ return func(c *clientConfig) {
+ c.fps = fps
+ }
+}
diff --git a/server.go b/server.go
new file mode 100644
index 0000000..464f09d
--- /dev/null
+++ b/server.go
@@ -0,0 +1,418 @@
+package compose
+
+import (
+ "fmt"
+ "image"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/gogpu/compose/internal/codec"
+ "github.com/gogpu/compose/internal/conn"
+ "github.com/gogpu/compose/internal/flow"
+ "github.com/gogpu/compose/internal/protocol"
+ "github.com/gogpu/compose/internal/transport/socket"
+)
+
+// moduleConn tracks a per-module connection on the server side.
+type moduleConn struct {
+ conn *socket.Conn
+ moduleID uint64
+ name string
+}
+
+// Server is the compositor-side endpoint that accepts module connections
+// and delivers frames. All methods are safe for concurrent use.
+type Server struct {
+ listener *socket.Listener
+ manager *conn.Manager
+ flow *flow.Controller
+ codec codec.Codec
+
+ mu sync.RWMutex
+ onFrame func(Frame)
+ onConnect func(id uint64, name string)
+ onDisconnect func(id uint64, name string)
+
+ modulesMu sync.RWMutex
+ modules map[uint64]*moduleConn
+
+ closed atomic.Bool
+ done chan struct{}
+ wg sync.WaitGroup
+}
+
+// Listen creates a Server that accepts module connections on the given
+// Unix domain socket address. The server immediately begins accepting
+// connections in a background goroutine.
+//
+// Use ServerOption functions to configure behavior:
+//
+// srv, err := compose.Listen("/tmp/compose.sock",
+// compose.WithMaxModules(8),
+// compose.WithCompression("lz4"),
+// )
+func Listen(addr string, opts ...ServerOption) (*Server, error) {
+ cfg := defaultServerConfig()
+ for _, o := range opts {
+ o(&cfg)
+ }
+
+ ln, err := socket.Listen(addr)
+ if err != nil {
+ return nil, fmt.Errorf("compose: listen: %w", err)
+ }
+
+ c := resolveCodec(cfg.compression)
+
+ s := &Server{
+ listener: ln,
+ manager: conn.NewManager(cfg.maxModules),
+ flow: flow.New(),
+ codec: c,
+ modules: make(map[uint64]*moduleConn),
+ done: make(chan struct{}),
+ }
+
+ // Wire manager callbacks to server callbacks.
+ s.manager.OnConnect(func(id uint64, name string) {
+ s.mu.RLock()
+ cb := s.onConnect
+ s.mu.RUnlock()
+ if cb != nil {
+ cb(id, name)
+ }
+ })
+
+ s.manager.OnDisconnect(func(id uint64, name string) {
+ s.mu.RLock()
+ cb := s.onDisconnect
+ s.mu.RUnlock()
+ if cb != nil {
+ cb(id, name)
+ }
+ })
+
+ // Start accept loop.
+ s.wg.Add(1)
+ go s.acceptLoop()
+
+ return s, nil
+}
+
+// OnFrame registers a callback invoked when a module delivers a frame.
+// The callback is called on an internal goroutine; it must not block for
+// extended periods. Only one callback can be active; subsequent calls
+// replace the previous one. Pass nil to remove.
+func (s *Server) OnFrame(fn func(Frame)) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.onFrame = fn
+}
+
+// OnConnect registers a callback invoked when a module completes the
+// handshake and becomes active. Only one callback can be active;
+// subsequent calls replace the previous one. Pass nil to remove.
+func (s *Server) OnConnect(fn func(id uint64, name string)) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.onConnect = fn
+
+ // Also update the manager callback so it fires the new function.
+ s.manager.OnConnect(func(id uint64, name string) {
+ s.mu.RLock()
+ cb := s.onConnect
+ s.mu.RUnlock()
+ if cb != nil {
+ cb(id, name)
+ }
+ })
+}
+
+// OnDisconnect registers a callback invoked when a module disconnects
+// or crashes. Only one callback can be active; subsequent calls replace
+// the previous one. Pass nil to remove.
+func (s *Server) OnDisconnect(fn func(id uint64, name string)) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.onDisconnect = fn
+
+ // Also update the manager callback so it fires the new function.
+ s.manager.OnDisconnect(func(id uint64, name string) {
+ s.mu.RLock()
+ cb := s.onDisconnect
+ s.mu.RUnlock()
+ if cb != nil {
+ cb(id, name)
+ }
+ })
+}
+
+// RequestFrame sends a frame-request signal to the specified module
+// (pull model). The module should respond with its next frame via
+// PublishFrame.
+//
+// Returns ErrModuleNotFound if the module ID is not connected.
+// Returns ErrClosed if the server has been shut down.
+func (s *Server) RequestFrame(moduleID uint64) error {
+ if s.closed.Load() {
+ return ErrClosed
+ }
+
+ s.modulesMu.RLock()
+ mc, ok := s.modules[moduleID]
+ s.modulesMu.RUnlock()
+
+ if !ok {
+ return ErrModuleNotFound
+ }
+
+ hdr := &protocol.Header{
+ Magic: protocol.Magic,
+ Version: protocol.ProtocolVersion,
+ MsgType: protocol.MsgFrameRequest,
+ }
+ hdr.ModuleID = moduleID
+
+ if err := mc.conn.WriteFrame(hdr, nil); err != nil {
+ return fmt.Errorf("compose: request frame: %w", err)
+ }
+
+ s.flow.FrameRequested(moduleID)
+ return nil
+}
+
+// Close shuts down the server, disconnects all modules, and releases
+// resources. After Close returns, no more callbacks will be invoked.
+func (s *Server) Close() error {
+ if !s.closed.CompareAndSwap(false, true) {
+ return ErrClosed
+ }
+
+ close(s.done)
+
+ // Close the listener to unblock Accept().
+ err := s.listener.Close()
+
+ // Close all module connections.
+ s.modulesMu.RLock()
+ for _, mc := range s.modules {
+ _ = mc.conn.Close() // best-effort
+ }
+ s.modulesMu.RUnlock()
+
+ // Wait for all goroutines to finish.
+ s.wg.Wait()
+
+ if err != nil {
+ return fmt.Errorf("compose: close server: %w", err)
+ }
+ return nil
+}
+
+// acceptLoop runs in a goroutine, accepting new module connections.
+func (s *Server) acceptLoop() {
+ defer s.wg.Done()
+
+ for {
+ c, err := s.listener.Accept()
+ if err != nil {
+ // If server is closing, Accept error is expected.
+ if s.closed.Load() {
+ return
+ }
+ // Transient error — continue accepting.
+ continue
+ }
+
+ s.wg.Add(1)
+ go s.handleModule(c)
+ }
+}
+
+// handleModule runs in a goroutine for each connected module.
+// It performs the handshake, then enters a frame read loop.
+func (s *Server) handleModule(c *socket.Conn) {
+ defer s.wg.Done()
+
+ // Read HelloMsg from module.
+ hello, err := c.ReadHandshakeHello()
+ if err != nil {
+ _ = c.Close()
+ return
+ }
+
+ name := protocol.GetName(&hello)
+
+ // Register module via Manager (allocates ID, fires OnConnect callback).
+ moduleID, err := s.manager.HandleConnect(name, hello.Width, hello.Height, hello.PreferredFPS)
+ if err != nil {
+ // Send rejection WelcomeMsg.
+ welcome := &protocol.WelcomeMsg{
+ Magic: protocol.Magic,
+ Version: protocol.ProtocolVersion,
+ Accepted: 0,
+ }
+ _ = c.WriteHandshakeWelcome(welcome) // best-effort
+ _ = c.Close()
+ return
+ }
+
+ // Register in flow controller.
+ s.flow.AddModule(moduleID, hello.PreferredFPS)
+
+ // Track the module connection.
+ mc := &moduleConn{
+ conn: c,
+ moduleID: moduleID,
+ name: name,
+ }
+
+ s.modulesMu.Lock()
+ s.modules[moduleID] = mc
+ s.modulesMu.Unlock()
+
+ // Send acceptance WelcomeMsg.
+ welcome := &protocol.WelcomeMsg{
+ Magic: protocol.Magic,
+ Version: protocol.ProtocolVersion,
+ ModuleID: moduleID,
+ Accepted: 1,
+ Transport: protocol.TransportSocket,
+ MinVersion: protocol.ProtocolVersion,
+ MaxVersion: protocol.ProtocolVersion,
+ }
+
+ if err := c.WriteHandshakeWelcome(welcome); err != nil {
+ s.cleanupModule(moduleID)
+ _ = c.Close()
+ return
+ }
+
+ // Enter frame read loop.
+ s.readFrameLoop(mc)
+}
+
+// readFrameLoop reads frames from a module connection until EOF or error.
+func (s *Server) readFrameLoop(mc *moduleConn) {
+ defer func() {
+ s.cleanupModule(mc.moduleID)
+ _ = mc.conn.Close()
+ }()
+
+ for {
+ // Check if server is shutting down.
+ select {
+ case <-s.done:
+ return
+ default:
+ }
+
+ hdr, payload, err := mc.conn.ReadFrame()
+ if err != nil {
+ // EOF or connection error — module disconnected.
+ return
+ }
+
+ if hdr.MsgType == protocol.MsgDisconnect {
+ return
+ }
+
+ if hdr.MsgType != protocol.MsgFrame {
+ // Ignore non-frame messages in the frame read loop.
+ continue
+ }
+
+ // Decompress payload if needed.
+ pixels, err := s.decodePayload(hdr, payload)
+ if err != nil {
+ // Corrupted frame — skip it but keep the connection.
+ continue
+ }
+
+ // Build Frame from header fields.
+ frame := headerToFrame(hdr, pixels, mc.name)
+
+ // Notify flow controller.
+ s.flow.FrameDelivered(mc.moduleID)
+
+ // Update registry last frame time.
+ s.manager.Registry().UpdateLastFrame(mc.moduleID, time.Now())
+
+ // Fire OnFrame callback.
+ s.mu.RLock()
+ cb := s.onFrame
+ s.mu.RUnlock()
+
+ if cb != nil {
+ cb(frame)
+ }
+ }
+}
+
+// cleanupModule removes a module from all tracking structures and fires
+// the OnDisconnect callback via Manager.
+func (s *Server) cleanupModule(moduleID uint64) {
+ s.modulesMu.Lock()
+ delete(s.modules, moduleID)
+ s.modulesMu.Unlock()
+
+ s.flow.RemoveModule(moduleID)
+ s.manager.HandleDisconnect(moduleID)
+}
+
+// decodePayload decompresses the frame payload based on the header's
+// compression flag. Returns the raw RGBA pixel data.
+func (s *Server) decodePayload(hdr protocol.Header, payload []byte) ([]byte, error) {
+ if !hdr.Flags.Has(protocol.FlagCompressed) || hdr.Compression == protocol.CompressionNone {
+ return payload, nil
+ }
+
+ // Look up the codec by compression ID.
+ c := codec.Get(byte(hdr.Compression))
+ if c == nil {
+ return nil, fmt.Errorf("compose: unknown compression 0x%02X", hdr.Compression)
+ }
+
+ // Allocate destination buffer sized to uncompressed size.
+ dst := make([]byte, hdr.UncompressedSize)
+ decoded, err := c.Decode(dst, payload)
+ if err != nil {
+ return nil, fmt.Errorf("compose: decompress: %w", err)
+ }
+
+ return decoded, nil
+}
+
+// headerToFrame converts a protocol.Header and pixel payload into a Frame.
+func headerToFrame(hdr protocol.Header, pixels []byte, name string) Frame {
+ f := Frame{
+ ModuleID: hdr.ModuleID,
+ Name: name,
+ Pixels: pixels,
+ Width: uint32(hdr.Width),
+ Height: uint32(hdr.Height),
+ Timestamp: hdr.TimestampNs,
+ }
+
+ if hdr.Flags.Has(protocol.FlagDirtyValid) {
+ f.DirtyRect = image.Rect(
+ int(hdr.DirtyX),
+ int(hdr.DirtyY),
+ int(hdr.DirtyX)+int(hdr.DirtyW),
+ int(hdr.DirtyY)+int(hdr.DirtyH),
+ )
+ }
+
+ return f
+}
+
+// resolveCodec returns the appropriate codec for the given compression name.
+func resolveCodec(name string) codec.Codec {
+ switch name {
+ case "lz4":
+ return codec.LZ4()
+ default:
+ return codec.Raw()
+ }
+}