diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 0000000..cc2c23a --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1 @@ +* @kolkov diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..d40a34d --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,207 @@ +name: CI + +# CI Strategy: +# - Tests run on Linux, macOS, and Windows (cross-platform IPC library) +# - Go 1.25+ required (matches go.mod requirement) +# - CGO_ENABLED=0: Pure Go library, no C compiler required +# +# Branch Strategy (GitHub Flow): +# - main branch: Production-ready code +# - Feature branches: Tested via pull_request trigger +# - Pull requests: Must pass all checks before merge + +env: + CGO_ENABLED: "0" + +on: + push: + branches: + - main + pull_request: + branches: + - main + +permissions: + contents: read + id-token: write # Required for Codecov OIDC token + +jobs: + # Build verification - Cross-platform + build: + name: Build - ${{ matrix.os }} + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + go-version: ['1.25'] + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: ${{ matrix.go-version }} + cache: true + + - name: Download dependencies + run: go mod download + + - name: Verify dependencies + run: go mod verify + + - name: Build all packages + run: go build ./... + + # Unit tests - Cross-platform + test: + name: Test - ${{ matrix.os }} + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + go-version: ['1.25'] + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: ${{ matrix.go-version }} + cache: true + + - name: Download dependencies + run: go mod download + + - name: Run go vet + if: matrix.os == 'ubuntu-latest' + run: go vet ./... + + - name: Run tests + run: go test -v ./... + + - name: Run tests with coverage + if: matrix.os == 'ubuntu-latest' + run: go test -coverprofile=coverage.out -covermode=atomic ./... + + - name: Upload coverage to Codecov + if: matrix.os == 'ubuntu-latest' + uses: codecov/codecov-action@v5 + with: + files: coverage.out + use_oidc: true + + # Linting + lint: + name: Lint + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: '1.25' + cache: true + + - name: Run golangci-lint + uses: golangci/golangci-lint-action@v8 + with: + version: latest + args: --timeout=5m + + # Code formatting check + formatting: + name: Formatting + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: '1.25' + cache: true + + - name: Check formatting + run: | + if [ -n "$(gofmt -l .)" ]; then + echo "ERROR: The following files are not formatted:" + gofmt -l . + echo "" + echo "Run 'go fmt ./...' to fix formatting issues." + exit 1 + fi + echo "All files are properly formatted" + + # Dependency freshness check + # Uses go-mod-outdated (https://github.com/psampaz/go-mod-outdated) + # Non-blocking: reports outdated deps as warnings, does not fail CI + deps: + name: Dependencies + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: '1.25' + cache: true + + - name: Install go-mod-outdated + run: go install github.com/psampaz/go-mod-outdated@latest + + - name: Check all dependencies + run: | + echo "## All Dependencies Status" + go list -u -m -json all 2>/dev/null | go-mod-outdated -style markdown || true + + - name: Check direct dependencies for updates + run: | + echo "## Direct Dependencies with Available Updates" + OUTDATED=$(go list -u -m -json all 2>/dev/null | go-mod-outdated -update -direct || true) + if [ -n "$OUTDATED" ]; then + echo "$OUTDATED" + echo "" + echo "::warning::Some direct dependencies have updates available" + else + echo "All direct dependencies are up to date!" + fi + + - name: Check ecosystem dependencies + env: + GH_TOKEN: ${{ github.token }} + run: | + echo "## Ecosystem Dependencies" + WARNINGS=0 + + check_ecosystem_dep() { + local DEP=$1 REPO=$2 + LOCAL=$(grep "$DEP" go.mod 2>/dev/null | grep -v "^module" | awk '{print $2}') + [ -z "$LOCAL" ] && return 0 + + LATEST=$(gh release view --repo "$REPO" --json tagName -q '.tagName' 2>/dev/null || echo "") + [ -z "$LATEST" ] && { echo " $DEP: $LOCAL (cannot verify)"; return 0; } + + if [ "$LOCAL" = "$LATEST" ]; then + echo " $DEP: $LOCAL" + else + echo " $DEP: $LOCAL -> $LATEST available" + WARNINGS=$((WARNINGS + 1)) + fi + } + + check_ecosystem_dep "github.com/pierrec/lz4/v4" "pierrec/lz4" + + [ $WARNINGS -gt 0 ] && echo "::warning::$WARNINGS ecosystem dep(s) outdated. Run: go get @latest" + exit 0 # Non-blocking diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 0000000..11b9746 --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,165 @@ +# GolangCI-Lint v2 Configuration for gogpu/compose +# Documentation: https://golangci-lint.run/docs/configuration/ + +version: "2" + +run: + timeout: 5m + tests: true + +linters: + enable: + # Code quality and complexity + - gocyclo + - gocognit + - funlen + - maintidx + - cyclop + - nestif + + # Bug detection + - govet + - staticcheck + - errcheck + - errorlint + - gosec + - nilnil + - nilerr + - ineffassign + + # Code style and consistency + - misspell + - whitespace + - unconvert + - unparam + + # Naming conventions + - errname + - revive + + # Performance + - prealloc + - makezero + + # Code practices + - goconst + - gocritic + - goprintffuncname + - nolintlint + - nakedret + + # Additional quality checkers + - dupl + - dogsled + - durationcheck + + settings: + govet: + enable: + - copylocks + disable: + - fieldalignment + + gocyclo: + min-complexity: 20 + + cyclop: + max-complexity: 20 + + funlen: + lines: 120 + statements: 60 + + gocognit: + min-complexity: 30 + + misspell: + locale: US + + nestif: + min-complexity: 4 + + revive: + rules: + - name: var-naming + - name: error-return + - name: error-naming + - name: if-return + - name: increment-decrement + - name: var-declaration + - name: range + - name: receiver-naming + - name: time-naming + - name: unexported-return + - name: indent-error-flow + - name: errorf + - name: empty-block + - name: superfluous-else + - name: unreachable-code + - name: redefines-builtin-id + + gocritic: + enabled-tags: + - diagnostic + - style + - performance + + disabled-checks: + - commentFormatting + - whyNoLint + - unnamedResult + - commentedOutCode + - octalLiteral + - paramTypeCombine + + settings: + hugeParam: + sizeThreshold: 256 + + exclusions: + rules: + # Enum String() methods: the string literal IS the constant name, not a + # candidate for extraction. Standard Go stringer pattern (false positive). + - linters: [goconst] + source: 'return "' + + # Test files - allow more flexibility + - path: _test\.go + linters: + - gocyclo + - cyclop + - funlen + - maintidx + - errcheck + - gosec + - goconst + - dogsled + - dupl + - gocognit + + # Example code + - path: examples?/.*\.go + linters: + - errcheck + - errorlint + - funlen + - gocyclo + - cyclop + - gocognit + - revive + - gosec + - gocritic + +formatters: + enable: + - gofmt + - goimports + +issues: + max-issues-per-linter: 0 + max-same-issues: 0 + new: false + +output: + sort-order: + - file diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..680d009 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,18 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [0.1.0] — 2026-05-17 + +### Added + +- **Wire protocol v1** (`internal/protocol/`) — 64-byte fixed header, 128-byte handshake messages, message types, encode/decode with zero allocations (100% coverage) +- **Codec package** (`internal/codec/`) — Raw pass-through + LZ4 block compression via `pierrec/lz4/v4` (97% coverage, 2.9 GB/s encode, 99.6% compression on GUI pixels) +- **Connection manager** (`internal/conn/`) — module registry with monotonic ID allocation, lifecycle state machine, hot-plug callbacks (98.9% coverage) +- **Flow controller** (`internal/flow/`) — pull-based frame pacing (Wayland frame callback pattern), adaptive rate reduction after missed frames (100% coverage) +- **Unix socket transport** (`internal/transport/socket/`) — framed Conn, Listener, Dialer for Unix domain sockets (95.1% coverage, 4.3 GB/s, 45μs latency) +- **Public API** — `compose.Listen()`, `compose.Dial()`, `Frame` type, functional options (`WithMaxModules`, `WithCompression`, `WithName`, `WithFrameSize`, `WithFPS`) +- **CI/CD** — GitHub Actions (build/test/lint on Ubuntu/macOS/Windows), Codecov, golangci-lint v2 diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000..037490d --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,34 @@ +# Contributor Covenant Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation in our +community a welcoming and respectful experience for everyone. + +## Our Standards + +Examples of behavior that contributes to a positive environment: + +- Using welcoming and inclusive language +- Being respectful of differing viewpoints and experiences +- Gracefully accepting constructive criticism +- Focusing on what is best for the community +- Showing empathy towards other community members + +Examples of unacceptable behavior: + +- The use of sexualized language or imagery and unwelcome sexual attention +- Trolling, insulting/derogatory comments, and personal or political attacks +- Public or private harassment +- Publishing others' private information without explicit permission +- Other conduct which could reasonably be considered inappropriate + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported to the project maintainers at a.kolkov@gmail.com. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org/), +version 2.1. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..c776959 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,81 @@ +# Contributing to compose + +Thank you for your interest in contributing to gogpu/compose! This document covers how to build, test, and submit changes. + +## Prerequisites + +- **Go 1.25+** ([download](https://go.dev/dl/)) +- **golangci-lint** (`go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@latest`) + +## Building + +```bash +go build ./... +``` + +## Running Tests + +```bash +go test ./... +``` + +## Running Tests with Coverage + +```bash +go test -coverprofile=tmp/coverage.out ./... +go tool cover -html=tmp/coverage.out +``` + +## Running Linter + +```bash +golangci-lint run --timeout=5m +``` + +## Code Standards + +- **Pure Go** — zero CGO, zero platform-specific code in public API +- **Enterprise quality** — 90%+ test coverage, zero-alloc hot paths +- **Functional options** — use `With*` pattern for configuration +- **Internal packages** — implementation details live in `internal/` + +## Pull Request Process + +1. Fork the repository +2. Create a feature branch (`git checkout -b feat/my-feature`) +3. Make changes with tests +4. Run `go fmt ./... && golangci-lint run --timeout=5m && go test ./...` +5. Commit with conventional format (`feat:`, `fix:`, `docs:`, etc.) +6. Open a pull request against `main` + +## Commit Messages + +``` +feat: add shared memory transport +fix: handle connection timeout on Windows +docs: update wire protocol documentation +test: add benchmark for frame compression +chore: update lz4 dependency +``` + +## Architecture + +The library follows a strict internal/ pattern: + +``` +compose/ # Public API (< 15 exported symbols) +├── internal/ +│ ├── protocol/ # Wire format (64B header, handshake) +│ ├── codec/ # Compression (Raw, LZ4) +│ ├── conn/ # Module lifecycle +│ ├── flow/ # Pull-based pacing +│ └── transport/ +│ ├── socket/ # Unix domain socket (Phase 1) +│ └── shm/ # Shared memory (Phase 2) +``` + +Users import only `"github.com/gogpu/compose"`. All implementation is hidden. + +## License + +By contributing, you agree that your contributions will be licensed under the MIT License. diff --git a/README.md b/README.md index 2071e2d..245ab5c 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ - GoGPU Logo + GoGPU Logo

@@ -16,15 +16,22 @@

+ CI + Coverage Go Reference Go Report Card License Zero CGO - Status

--- +## Installation + +```bash +go get github.com/gogpu/compose +``` + ## What is compose? The `compose` library lets Go applications combine content from several independent processes onto a single display. @@ -75,7 +82,7 @@ srv.OnFrame(func(f compose.Frame) { }) ``` -> **Status: design phase.** APIs above are aspirational — they show the intended shape of the library. See the [Roadmap](#roadmap) section for current progress. +> Unix socket transport, wire protocol v1, LZ4 compression, pull-based flow control. See the [Roadmap](#roadmap) for current progress. ## Why a separate library? @@ -184,14 +191,12 @@ The protocol is the stable compatibility surface of compose. Versioning is indep | Phase | Features | Status | |-------|----------|--------| -| **Phase 0** | Design, ADRs, reference example sketch | Design phase | -| **Phase 1** | Reference example: compositor + clock module + notification module | Planned | -| **Phase 2** | Wire protocol v1, framing, header, dirty rects | Planned | -| **Phase 3** | Unix socket transport, hot-plug, connection management | Planned | -| **Phase 4** | Shared memory ring buffer transport | Planned | -| **Phase 5** | Extract stable APIs from the reference example into this library | Planned | -| **Phase 6** | Multi-screen layout, layered z-order, fade transitions | Future | -| **Phase 7** | Cross-language module SDK (C header, Rust crate, Python) | Future | +| **Phase 1** | Wire protocol v1, Unix socket transport, LZ4 compression, public API | Complete | +| **Phase 2** | Reference examples: compositor + clock + notification (multi-process) | Next | +| **Phase 3** | Shared memory ring buffer transport (zero-copy) | Planned | +| **Phase 4** | Delta frames, compression negotiation | Planned | +| **Phase 5** | Multi-screen layout, layered z-order, fade transitions | Future | +| **Phase 6** | Cross-language module SDK (C header, Rust crate, Python) | Future | ## Design principles @@ -215,11 +220,11 @@ The design of `compose` is informed by: ## Status and contributing -The `compose` library is in the **design phase**. There is no shippable code yet. This repository exists to host the design discussion, the architecture decision records, and (once they exist) the reference example and the extracted library. +The `compose` library ships a working Unix socket transport with wire protocol v1, LZ4 compression, pull-based flow control, and hot-plug module lifecycle. -The first user is a [Go rewrite of MagicMirror²](https://github.com/gogpu/ui/issues/75) targeting Raspberry Pi and (eventually) Redox OS. +Known users: [KiGo](https://github.com/AgentNemo00/kigo) (modular Go application using offscreen rendering + multi-process composition). -If you have a use case for multi-process composition in Go and want to influence the design before APIs freeze, please join the [compose RFC discussion](https://github.com/orgs/gogpu/discussions/177) or open an issue describing your scenario. For the related (but distinct) question of in-process multi-window support in `gogpu` itself, see the [multi-window RFC discussion](https://github.com/orgs/gogpu/discussions/167). Real use cases drive both — we deliberately avoid designing against hypotheticals. +If you have a use case for multi-process composition in Go, please join the [compose RFC discussion](https://github.com/orgs/gogpu/discussions/177) or open an issue describing your scenario. ## Part of the GoGPU Ecosystem diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 0000000..38b9beb --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,49 @@ +# Security Policy + +## Supported Versions + +| Version | Supported | +| ------- | ------------------ | +| 0.1.x | :white_check_mark: | + +## Reporting a Vulnerability + +**DO NOT** open a public GitHub issue for security vulnerabilities. + +Instead, please report security issues via: + +1. **Private Security Advisory** (preferred): + https://github.com/gogpu/compose/security/advisories/new + +2. **GitHub Discussions** (for less critical issues): + https://github.com/gogpu/gogpu/discussions + +### What to Include + +- Description of the vulnerability +- Steps to reproduce +- Affected versions +- Potential impact + +### Response Timeline + +- **Initial Response**: Within 72 hours +- **Fix & Disclosure**: Coordinated with reporter + +## Security Considerations + +compose uses IPC mechanisms for inter-process communication. Users should be aware of: + +1. **Unix Domain Sockets** — socket files are created with default permissions. Use appropriate file permissions in production. +2. **Shared Memory** (Phase 2) — memory-mapped regions are shared between processes. Ensure only trusted modules connect. +3. **Wire Protocol** — frame data is not encrypted. For untrusted networks, use TLS or a secure transport. +4. **Module Identity** — module names are self-declared. The compositor should validate module identity for security-critical deployments. + +## Security Contact + +- **GitHub Security Advisory**: https://github.com/gogpu/compose/security/advisories/new +- **Public Issues**: https://github.com/gogpu/compose/issues + +--- + +**Thank you for helping keep gogpu/compose secure!** diff --git a/client.go b/client.go new file mode 100644 index 0000000..3879ee0 --- /dev/null +++ b/client.go @@ -0,0 +1,320 @@ +package compose + +import ( + "fmt" + "image" + "math" + "sync" + "sync/atomic" + "time" + + "github.com/gogpu/compose/internal/codec" + "github.com/gogpu/compose/internal/protocol" + "github.com/gogpu/compose/internal/transport/socket" +) + +// saturateUint16 converts v to uint16, clamping to math.MaxUint16 on overflow. +func saturateUint16(v int) uint16 { + return uint16(min(max(v, 0), math.MaxUint16)) //nolint:gosec // clamped to [0, MaxUint16] +} + +// saturateUint16from32 converts v to uint16, clamping to math.MaxUint16 on overflow. +func saturateUint16from32(v uint32) uint16 { + return uint16(min(v, math.MaxUint16)) //nolint:gosec // clamped to [0, MaxUint16] +} + +// saturateUint32 converts v to uint32, clamping to math.MaxUint32 on overflow. +func saturateUint32(v int) uint32 { + return uint32(min(max(v, 0), math.MaxUint32)) //nolint:gosec // clamped to [0, MaxUint32] +} + +// Client is the module-side endpoint that connects to a compositor and +// publishes frames. All methods are safe for concurrent use. +type Client struct { + conn *socket.Conn + moduleID uint64 + name string + codec codec.Codec + + mu sync.RWMutex + onFrameRequest func() + + closed atomic.Bool + done chan struct{} + wg sync.WaitGroup + seq atomic.Uint64 +} + +// Dial creates a Client that connects to a compositor at the given Unix +// domain socket address. Dial performs the handshake immediately: it sends +// a HelloMsg and reads the WelcomeMsg. If the compositor rejects the +// connection, ErrNotAccepted is returned. +// +// Use ClientOption functions to configure the module: +// +// client, err := compose.Dial("/tmp/compose.sock", +// compose.WithName("clock"), +// compose.WithFrameSize(400, 120), +// compose.WithFPS(1), +// ) +func Dial(addr string, opts ...ClientOption) (*Client, error) { + cfg := defaultClientConfig() + for _, o := range opts { + o(&cfg) + } + + dialer := socket.NewDialer(addr) + sc, err := dialer.Dial() + if err != nil { + return nil, fmt.Errorf("compose: dial: %w", err) + } + + // Build and send HelloMsg. + hello := &protocol.HelloMsg{ + Magic: protocol.Magic, + Version: protocol.ProtocolVersion, + Width: saturateUint16from32(cfg.width), + Height: saturateUint16from32(cfg.height), + PreferredFPS: cfg.fps, + Transport: protocol.TransportSocket, + } + protocol.SetName(hello, cfg.name) + + if err := sc.WriteHandshakeHello(hello); err != nil { + _ = sc.Close() + return nil, fmt.Errorf("compose: send hello: %w", err) + } + + // Read WelcomeMsg. + welcome, err := sc.ReadHandshakeWelcome() + if err != nil { + _ = sc.Close() + return nil, fmt.Errorf("compose: read welcome: %w", err) + } + + if welcome.Accepted == 0 { + _ = sc.Close() + return nil, ErrNotAccepted + } + + c := &Client{ + conn: sc, + moduleID: welcome.ModuleID, + name: cfg.name, + codec: codec.Raw(), // client always sends raw; server handles decompression + done: make(chan struct{}), + } + + // Start reader goroutine for FrameRequest messages. + c.wg.Add(1) + go c.readLoop() + + return c, nil +} + +// PublishFrame sends a frame to the compositor. +// The frame's ModuleID is automatically set to this client's assigned ID. +// +// Returns ErrClosed if the client has been shut down. +func (c *Client) PublishFrame(f Frame) error { + if c.closed.Load() { + return ErrClosed + } + + seq := c.seq.Add(1) + pixels := f.Pixels + uncompressedSize := saturateUint32(len(pixels)) + + // Compress if codec is not raw. + var flags protocol.Flag + compressionID := protocol.CompressionNone + + if c.codec.ID() != codec.IDRaw { + maxSize := c.codec.MaxEncodedSize(len(pixels)) + dst := make([]byte, maxSize) + compressed, err := c.codec.Encode(dst, pixels) + if err != nil { + return fmt.Errorf("compose: compress frame: %w", err) + } + pixels = compressed + flags = flags.Set(protocol.FlagCompressed) + compressionID = protocol.Compression(c.codec.ID()) + } + + // Set dirty rect flags. + if !f.DirtyRect.Empty() { + flags = flags.Set(protocol.FlagDirtyValid) + } else { + flags = flags.Set(protocol.FlagKeyframe) + } + + hdr := &protocol.Header{ + Magic: protocol.Magic, + Version: protocol.ProtocolVersion, + MsgType: protocol.MsgFrame, + Flags: flags, + ModuleID: c.moduleID, + Sequence: seq, + TimestampNs: f.Timestamp, + Width: saturateUint16from32(f.Width), + Height: saturateUint16from32(f.Height), + Stride: f.Width * 4, + PixelFormat: protocol.PixelRGBA8, + Compression: compressionID, + PayloadSize: saturateUint32(len(pixels)), + UncompressedSize: uncompressedSize, + } + + // Set dirty rect fields if valid. + if flags.Has(protocol.FlagDirtyValid) { + hdr.DirtyX = saturateUint16(f.DirtyRect.Min.X) + hdr.DirtyY = saturateUint16(f.DirtyRect.Min.Y) + hdr.DirtyW = saturateUint16(f.DirtyRect.Dx()) + hdr.DirtyH = saturateUint16(f.DirtyRect.Dy()) + } + + // Use monotonic timestamp if caller did not set one. + if hdr.TimestampNs == 0 { + hdr.TimestampNs = time.Now().UnixNano() + } + + return c.conn.WriteFrame(hdr, pixels) +} + +// OnFrameRequest registers a callback invoked when the compositor requests +// a frame. This enables pull-based rendering: the module renders only when +// asked. The callback is called on an internal goroutine; it must not block +// for extended periods. Only one callback can be active; subsequent calls +// replace the previous one. Pass nil to remove. +func (c *Client) OnFrameRequest(fn func()) { + c.mu.Lock() + defer c.mu.Unlock() + c.onFrameRequest = fn +} + +// Close disconnects from the compositor and releases resources. +// After Close returns, no more callbacks will be invoked. +func (c *Client) Close() error { + if !c.closed.CompareAndSwap(false, true) { + return ErrClosed + } + + close(c.done) + + // Send graceful disconnect message (best-effort). + hdr := &protocol.Header{ + Magic: protocol.Magic, + Version: protocol.ProtocolVersion, + MsgType: protocol.MsgDisconnect, + ModuleID: c.moduleID, + } + _ = c.conn.WriteFrame(hdr, nil) + + err := c.conn.Close() + + c.wg.Wait() + + if err != nil { + return fmt.Errorf("compose: close client: %w", err) + } + return nil +} + +// ModuleID returns the compositor-assigned module identifier. +// This is valid after Dial returns successfully. +func (c *Client) ModuleID() uint64 { + return c.moduleID +} + +// readLoop runs in a goroutine, reading control messages from the +// compositor. Currently handles FrameRequest messages. +func (c *Client) readLoop() { + defer c.wg.Done() + + for { + select { + case <-c.done: + return + default: + } + + hdr, _, err := c.conn.ReadFrame() + if err != nil { + // Connection closed or error — stop reading. + return + } + + switch hdr.MsgType { + case protocol.MsgFrameRequest: + c.mu.RLock() + cb := c.onFrameRequest + c.mu.RUnlock() + if cb != nil { + cb() + } + + case protocol.MsgDisconnect: + return + + default: + // Ignore unknown message types for forward compatibility. + } + } +} + +// SetCompression sets the codec used for frame payload compression. +// This allows the compositor to negotiate compression during or after +// handshake. Supported values: "lz4". Any other value uses raw. +func (c *Client) SetCompression(algo string) { + c.mu.Lock() + defer c.mu.Unlock() + c.codec = resolveCodec(algo) +} + +// frameToHeader converts a public Frame to a protocol.Header. +// Used in PublishFrame; extracted here for testing. +func frameToHeader(f Frame, moduleID uint64, seq uint64, c codec.Codec) protocol.Header { + var flags protocol.Flag + if !f.DirtyRect.Empty() { + flags = flags.Set(protocol.FlagDirtyValid) + } else { + flags = flags.Set(protocol.FlagKeyframe) + } + + if c.ID() != codec.IDRaw { + flags = flags.Set(protocol.FlagCompressed) + } + + hdr := protocol.Header{ + Magic: protocol.Magic, + Version: protocol.ProtocolVersion, + MsgType: protocol.MsgFrame, + Flags: flags, + ModuleID: moduleID, + Sequence: seq, + TimestampNs: f.Timestamp, + Width: saturateUint16from32(f.Width), + Height: saturateUint16from32(f.Height), + Stride: f.Width * 4, + PixelFormat: protocol.PixelRGBA8, + Compression: protocol.Compression(c.ID()), + PayloadSize: saturateUint32(len(f.Pixels)), + UncompressedSize: saturateUint32(len(f.Pixels)), + } + + if flags.Has(protocol.FlagDirtyValid) { + dr := f.DirtyRect.Canon() + hdr.DirtyX = saturateUint16(dr.Min.X) + hdr.DirtyY = saturateUint16(dr.Min.Y) + hdr.DirtyW = saturateUint16(dr.Dx()) + hdr.DirtyH = saturateUint16(dr.Dy()) + } + + return hdr +} + +// isDirtyRectValid reports whether the dirty rect represents a valid +// sub-region (non-empty). +func isDirtyRectValid(r image.Rectangle) bool { + return !r.Empty() +} diff --git a/codecov.yml b/codecov.yml new file mode 100644 index 0000000..8fe38b8 --- /dev/null +++ b/codecov.yml @@ -0,0 +1,29 @@ +# Codecov configuration +# https://docs.codecov.com/docs/codecovyml-reference + +coverage: + precision: 2 + round: down + range: "80...100" + status: + project: + default: + target: 90% + threshold: 5% + patch: + default: + target: 80% + threshold: 10% + +ignore: + - "examples/**/*" + - "tmp/**/*" + +parsers: + go: + partials_as_hits: false + +comment: + layout: "header, diff, flags, components" + behavior: default + require_changes: false diff --git a/compose.go b/compose.go new file mode 100644 index 0000000..ae8e88f --- /dev/null +++ b/compose.go @@ -0,0 +1,35 @@ +package compose + +import "image" + +// Frame is the fundamental data unit: a rectangular pixel buffer from a module. +// Users construct it to publish (module side) or receive it in callbacks +// (compositor side). +// +// On the module side, ModuleID is ignored in PublishFrame — the server assigns +// the module's ID automatically. On the compositor side, OnFrame delivers a +// Frame with ModuleID populated. +type Frame struct { + // ModuleID is the compositor-assigned identifier. + // Ignored on publish (server assigns). + ModuleID uint64 + + // Name is a human-readable module name (e.g., "clock", "weather"). + // Set during handshake, included in received frames for convenience. + Name string + + // Pixels is the RGBA premultiplied pixel buffer. + // Stride is always Width * 4. + Pixels []byte + + // Width and Height of the frame in pixels. + Width uint32 + Height uint32 + + // DirtyRect is the sub-region that changed since the last frame. + // Zero value means the entire frame is dirty (keyframe). + DirtyRect image.Rectangle + + // Timestamp is a monotonic nanosecond timestamp from the module's clock. + Timestamp int64 +} diff --git a/compose_test.go b/compose_test.go new file mode 100644 index 0000000..052e59f --- /dev/null +++ b/compose_test.go @@ -0,0 +1,811 @@ +package compose + +import ( + "errors" + "image" + "os" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/gogpu/compose/internal/protocol" +) + +// tempSocket returns a short temporary Unix socket path. +// macOS limits Unix socket paths to 104 bytes. t.TempDir() on macOS CI +// produces paths like /var/folders/.../TestName.../compose.sock which +// easily exceeds this limit. We use os.CreateTemp with a short prefix +// to guarantee a short path under /tmp (or %TEMP% on Windows). +func tempSocket(t *testing.T) string { + t.Helper() + f, err := os.CreateTemp("", "cs-*.sock") + if err != nil { + t.Fatal(err) + } + path := f.Name() + f.Close() + os.Remove(path) // Remove the file; we need the path for the socket. + t.Cleanup(func() { os.Remove(path) }) + return path +} + +// makePixels creates a simple RGBA pixel buffer of the given dimensions +// filled with the specified byte value. +func makePixels(w, h uint32, fill byte) []byte { + pixels := make([]byte, w*h*4) + for i := range pixels { + pixels[i] = fill + } + return pixels +} + +// waitFor polls a condition function until it returns true or 5 seconds elapse. +// Returns true if the condition was met. The generous timeout accommodates slow +// CI runners (macOS shared, Ubuntu containers) where goroutine scheduling may +// introduce significant delays. +func waitFor(t *testing.T, condition func() bool) bool { + t.Helper() + deadline := time.Now().Add(5 * time.Second) + for time.Now().Before(deadline) { + if condition() { + return true + } + time.Sleep(10 * time.Millisecond) + } + return false +} + +func TestListenAndDial(t *testing.T) { + addr := tempSocket(t) + + srv, err := Listen(addr) + if err != nil { + t.Fatalf("Listen: %v", err) + } + t.Cleanup(func() { _ = srv.Close() }) + + client, err := Dial(addr, WithName("test-module")) + if err != nil { + t.Fatalf("Dial: %v", err) + } + t.Cleanup(func() { _ = client.Close() }) + + if client.ModuleID() == 0 { + t.Error("ModuleID should be non-zero after successful Dial") + } +} + +func TestFrameRoundTrip(t *testing.T) { + addr := tempSocket(t) + + srv, err := Listen(addr) + if err != nil { + t.Fatalf("Listen: %v", err) + } + t.Cleanup(func() { _ = srv.Close() }) + + var received atomic.Value + + srv.OnFrame(func(f Frame) { + received.Store(f) + }) + + client, err := Dial(addr, WithName("painter"), WithFrameSize(10, 10)) + if err != nil { + t.Fatalf("Dial: %v", err) + } + t.Cleanup(func() { _ = client.Close() }) + + // Publish a frame. + pixels := makePixels(10, 10, 0xAA) + ts := time.Now().UnixNano() + + err = client.PublishFrame(Frame{ + Pixels: pixels, + Width: 10, + Height: 10, + Timestamp: ts, + }) + if err != nil { + t.Fatalf("PublishFrame: %v", err) + } + + // Wait for frame to arrive. + ok := waitFor(t, func() bool { + return received.Load() != nil + }) + if !ok { + t.Fatal("timed out waiting for frame") + } + + f := received.Load().(Frame) + + if f.ModuleID == 0 { + t.Error("received frame ModuleID should be non-zero") + } + if f.Name != "painter" { + t.Errorf("received frame Name = %q, want %q", f.Name, "painter") + } + if f.Width != 10 { + t.Errorf("received frame Width = %d, want %d", f.Width, 10) + } + if f.Height != 10 { + t.Errorf("received frame Height = %d, want %d", f.Height, 10) + } + if len(f.Pixels) != 400 { + t.Errorf("received frame Pixels len = %d, want %d", len(f.Pixels), 400) + } + if f.Timestamp != ts { + t.Errorf("received frame Timestamp = %d, want %d", f.Timestamp, ts) + } + + // Verify pixel content. + for i, b := range f.Pixels { + if b != 0xAA { + t.Errorf("pixel[%d] = 0x%02X, want 0xAA", i, b) + break + } + } +} + +func TestFrameWithDirtyRect(t *testing.T) { + addr := tempSocket(t) + + srv, err := Listen(addr) + if err != nil { + t.Fatalf("Listen: %v", err) + } + t.Cleanup(func() { _ = srv.Close() }) + + var received atomic.Value + + srv.OnFrame(func(f Frame) { + received.Store(f) + }) + + client, err := Dial(addr, WithName("dirty"), WithFrameSize(100, 100)) + if err != nil { + t.Fatalf("Dial: %v", err) + } + t.Cleanup(func() { _ = client.Close() }) + + dirty := image.Rect(10, 20, 50, 80) + + err = client.PublishFrame(Frame{ + Pixels: makePixels(100, 100, 0xFF), + Width: 100, + Height: 100, + DirtyRect: dirty, + }) + if err != nil { + t.Fatalf("PublishFrame: %v", err) + } + + ok := waitFor(t, func() bool { + return received.Load() != nil + }) + if !ok { + t.Fatal("timed out waiting for frame") + } + + f := received.Load().(Frame) + + if f.DirtyRect != dirty { + t.Errorf("DirtyRect = %v, want %v", f.DirtyRect, dirty) + } +} + +func TestMultipleClients(t *testing.T) { + addr := tempSocket(t) + + srv, err := Listen(addr, WithMaxModules(4)) + if err != nil { + t.Fatalf("Listen: %v", err) + } + t.Cleanup(func() { _ = srv.Close() }) + + var mu sync.Mutex + framesByModule := make(map[uint64]Frame) + + srv.OnFrame(func(f Frame) { + mu.Lock() + framesByModule[f.ModuleID] = f + mu.Unlock() + }) + + // Connect three clients. + clients := make([]*Client, 3) + names := []string{"alpha", "beta", "gamma"} + for i, name := range names { + c, dialErr := Dial(addr, WithName(name), WithFrameSize(8, 8)) + if dialErr != nil { + t.Fatalf("Dial(%s): %v", name, dialErr) + } + clients[i] = c + t.Cleanup(func() { _ = c.Close() }) + } + + // Each client publishes a frame with a unique fill byte. + for i, c := range clients { + fill := byte(i + 1) + pubErr := c.PublishFrame(Frame{ + Pixels: makePixels(8, 8, fill), + Width: 8, + Height: 8, + }) + if pubErr != nil { + t.Fatalf("PublishFrame(%s): %v", names[i], pubErr) + } + } + + // Wait for all three frames. + ok := waitFor(t, func() bool { + mu.Lock() + n := len(framesByModule) + mu.Unlock() + return n >= 3 + }) + if !ok { + mu.Lock() + n := len(framesByModule) + mu.Unlock() + t.Fatalf("timed out: received %d/3 frames", n) + } + + // Verify each client got a unique module ID. + ids := make(map[uint64]bool) + mu.Lock() + for id := range framesByModule { + ids[id] = true + } + mu.Unlock() + if len(ids) != 3 { + t.Errorf("expected 3 unique module IDs, got %d", len(ids)) + } +} + +func TestOnConnectCallback(t *testing.T) { + addr := tempSocket(t) + + srv, err := Listen(addr) + if err != nil { + t.Fatalf("Listen: %v", err) + } + t.Cleanup(func() { _ = srv.Close() }) + + var connectedName atomic.Value + var connectedID atomic.Uint64 + + srv.OnConnect(func(id uint64, name string) { + connectedID.Store(id) + connectedName.Store(name) + }) + + client, err := Dial(addr, WithName("connector")) + if err != nil { + t.Fatalf("Dial: %v", err) + } + t.Cleanup(func() { _ = client.Close() }) + + ok := waitFor(t, func() bool { + return connectedName.Load() != nil + }) + if !ok { + t.Fatal("timed out waiting for OnConnect") + } + + if name := connectedName.Load().(string); name != "connector" { + t.Errorf("OnConnect name = %q, want %q", name, "connector") + } + if id := connectedID.Load(); id == 0 { + t.Error("OnConnect ID should be non-zero") + } +} + +func TestOnDisconnectCallback(t *testing.T) { + addr := tempSocket(t) + + srv, err := Listen(addr) + if err != nil { + t.Fatalf("Listen: %v", err) + } + t.Cleanup(func() { _ = srv.Close() }) + + var disconnectedName atomic.Value + var disconnectedID atomic.Uint64 + + srv.OnDisconnect(func(id uint64, name string) { + disconnectedID.Store(id) + disconnectedName.Store(name) + }) + + client, err := Dial(addr, WithName("leaver")) + if err != nil { + t.Fatalf("Dial: %v", err) + } + + // Wait for the server to fully process the connection. On CI runners, + // goroutine scheduling can be slow, so 200ms gives ample headroom. + time.Sleep(200 * time.Millisecond) + + // Close the client — this triggers disconnect. + if err := client.Close(); err != nil { + t.Fatalf("client.Close: %v", err) + } + + ok := waitFor(t, func() bool { + return disconnectedName.Load() != nil + }) + if !ok { + t.Fatal("timed out waiting for OnDisconnect") + } + + if name := disconnectedName.Load().(string); name != "leaver" { + t.Errorf("OnDisconnect name = %q, want %q", name, "leaver") + } + if id := disconnectedID.Load(); id == 0 { + t.Error("OnDisconnect ID should be non-zero") + } +} + +func TestServerClose(t *testing.T) { + addr := tempSocket(t) + + srv, err := Listen(addr) + if err != nil { + t.Fatalf("Listen: %v", err) + } + + client, err := Dial(addr, WithName("stranded")) + if err != nil { + t.Fatalf("Dial: %v", err) + } + t.Cleanup(func() { _ = client.Close() }) + + // Close server — should disconnect all clients. + if err := srv.Close(); err != nil { + t.Fatalf("srv.Close: %v", err) + } + + // Double close should return ErrClosed. + if err := srv.Close(); !errors.Is(err, ErrClosed) { + t.Errorf("second Close = %v, want ErrClosed", err) + } + + // RequestFrame after close should return ErrClosed. + if err := srv.RequestFrame(1); !errors.Is(err, ErrClosed) { + t.Errorf("RequestFrame after close = %v, want ErrClosed", err) + } + + // Socket file should be removed. + if _, err := os.Stat(addr); !os.IsNotExist(err) { + t.Errorf("socket file still exists after Close") + } +} + +func TestClientDoubleClose(t *testing.T) { + addr := tempSocket(t) + + srv, err := Listen(addr) + if err != nil { + t.Fatalf("Listen: %v", err) + } + t.Cleanup(func() { _ = srv.Close() }) + + client, err := Dial(addr, WithName("closer")) + if err != nil { + t.Fatalf("Dial: %v", err) + } + + if err := client.Close(); err != nil { + t.Fatalf("first Close: %v", err) + } + + if err := client.Close(); !errors.Is(err, ErrClosed) { + t.Errorf("second Close = %v, want ErrClosed", err) + } +} + +func TestPublishAfterClose(t *testing.T) { + addr := tempSocket(t) + + srv, err := Listen(addr) + if err != nil { + t.Fatalf("Listen: %v", err) + } + t.Cleanup(func() { _ = srv.Close() }) + + client, err := Dial(addr, WithName("early-close")) + if err != nil { + t.Fatalf("Dial: %v", err) + } + + if err := client.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + + err = client.PublishFrame(Frame{ + Pixels: makePixels(1, 1, 0), + Width: 1, + Height: 1, + }) + if !errors.Is(err, ErrClosed) { + t.Errorf("PublishFrame after close = %v, want ErrClosed", err) + } +} + +func TestRequestFrameTrigger(t *testing.T) { + addr := tempSocket(t) + + srv, err := Listen(addr) + if err != nil { + t.Fatalf("Listen: %v", err) + } + t.Cleanup(func() { _ = srv.Close() }) + + var requestCount atomic.Int32 + + // Track connected module ID. + var moduleID atomic.Uint64 + srv.OnConnect(func(id uint64, _ string) { + moduleID.Store(id) + }) + + client, err := Dial(addr, WithName("puller")) + if err != nil { + t.Fatalf("Dial: %v", err) + } + t.Cleanup(func() { _ = client.Close() }) + + client.OnFrameRequest(func() { + requestCount.Add(1) + }) + + // Wait for connection to be fully established. + ok := waitFor(t, func() bool { + return moduleID.Load() != 0 + }) + if !ok { + t.Fatal("timed out waiting for connection") + } + + id := moduleID.Load() + + // Server requests a frame. + if err := srv.RequestFrame(id); err != nil { + t.Fatalf("RequestFrame: %v", err) + } + + ok = waitFor(t, func() bool { + return requestCount.Load() >= 1 + }) + if !ok { + t.Fatal("timed out waiting for OnFrameRequest callback") + } + + if n := requestCount.Load(); n != 1 { + t.Errorf("request count = %d, want 1", n) + } +} + +func TestRequestFrameUnknownModule(t *testing.T) { + addr := tempSocket(t) + + srv, err := Listen(addr) + if err != nil { + t.Fatalf("Listen: %v", err) + } + t.Cleanup(func() { _ = srv.Close() }) + + if err := srv.RequestFrame(999); !errors.Is(err, ErrModuleNotFound) { + t.Errorf("RequestFrame(999) = %v, want ErrModuleNotFound", err) + } +} + +func TestDialNonExistentServer(t *testing.T) { + _, err := Dial("/tmp/compose-nonexistent-test.sock") + if err == nil { + t.Fatal("Dial to non-existent server should fail") + } +} + +func TestWithMaxModulesLimit(t *testing.T) { + addr := tempSocket(t) + + srv, err := Listen(addr, WithMaxModules(1)) + if err != nil { + t.Fatalf("Listen: %v", err) + } + t.Cleanup(func() { _ = srv.Close() }) + + // First client should succeed. + c1, err := Dial(addr, WithName("first")) + if err != nil { + t.Fatalf("Dial first: %v", err) + } + t.Cleanup(func() { _ = c1.Close() }) + + // Wait for the first client to be fully registered on the server. + // On CI runners, goroutine scheduling can be slow. + time.Sleep(200 * time.Millisecond) + + // Second client should be rejected. + _, err = Dial(addr, WithName("second")) + if err == nil { + t.Fatal("Dial second should fail with max modules = 1") + } +} + +func TestFrameSequenceNumbers(t *testing.T) { + addr := tempSocket(t) + + srv, err := Listen(addr) + if err != nil { + t.Fatalf("Listen: %v", err) + } + t.Cleanup(func() { _ = srv.Close() }) + + client, err := Dial(addr, WithName("sequencer")) + if err != nil { + t.Fatalf("Dial: %v", err) + } + t.Cleanup(func() { _ = client.Close() }) + + // Publish three frames and verify the internal sequence counter increments. + for i := 0; i < 3; i++ { + pubErr := client.PublishFrame(Frame{ + Pixels: makePixels(2, 2, byte(i)), + Width: 2, + Height: 2, + }) + if pubErr != nil { + t.Fatalf("PublishFrame %d: %v", i, pubErr) + } + } + + // The sequence counter should be 3 after three frames. + if seq := client.seq.Load(); seq != 3 { + t.Errorf("seq = %d, want 3", seq) + } +} + +func TestFunctionalOptionsDefaults(t *testing.T) { + // Verify server config defaults. + sCfg := defaultServerConfig() + if sCfg.maxModules != 16 { + t.Errorf("default maxModules = %d, want 16", sCfg.maxModules) + } + if sCfg.compression != "" { + t.Errorf("default compression = %q, want empty", sCfg.compression) + } + + // Verify client config defaults. + cCfg := defaultClientConfig() + if cCfg.name != "module" { + t.Errorf("default name = %q, want %q", cCfg.name, "module") + } + if cCfg.width != 400 { + t.Errorf("default width = %d, want 400", cCfg.width) + } + if cCfg.height != 300 { + t.Errorf("default height = %d, want 300", cCfg.height) + } + if cCfg.fps != 1 { + t.Errorf("default fps = %d, want 1", cCfg.fps) + } +} + +func TestFunctionalOptionsApply(t *testing.T) { + sCfg := defaultServerConfig() + WithMaxModules(32)(&sCfg) + WithCompression("lz4")(&sCfg) + + if sCfg.maxModules != 32 { + t.Errorf("maxModules = %d, want 32", sCfg.maxModules) + } + if sCfg.compression != "lz4" { + t.Errorf("compression = %q, want %q", sCfg.compression, "lz4") + } + + cCfg := defaultClientConfig() + WithName("clock")(&cCfg) + WithFrameSize(800, 600)(&cCfg) + WithFPS(60)(&cCfg) + + if cCfg.name != "clock" { + t.Errorf("name = %q, want %q", cCfg.name, "clock") + } + if cCfg.width != 800 { + t.Errorf("width = %d, want 800", cCfg.width) + } + if cCfg.height != 600 { + t.Errorf("height = %d, want 600", cCfg.height) + } + if cCfg.fps != 60 { + t.Errorf("fps = %d, want 60", cCfg.fps) + } +} + +func TestWithMaxModulesClamping(t *testing.T) { + cfg := defaultServerConfig() + WithMaxModules(0)(&cfg) + if cfg.maxModules != 1 { + t.Errorf("maxModules after clamp(0) = %d, want 1", cfg.maxModules) + } + + WithMaxModules(-5)(&cfg) + if cfg.maxModules != 1 { + t.Errorf("maxModules after clamp(-5) = %d, want 1", cfg.maxModules) + } +} + +func TestResolveCodec(t *testing.T) { + raw := resolveCodec("") + if raw.ID() != 0x00 { + t.Errorf("resolveCodec(\"\") ID = 0x%02X, want 0x00", raw.ID()) + } + + lz4 := resolveCodec("lz4") + if lz4.ID() != 0x01 { + t.Errorf("resolveCodec(\"lz4\") ID = 0x%02X, want 0x01", lz4.ID()) + } + + unknown := resolveCodec("zstd") + if unknown.ID() != 0x00 { + t.Errorf("resolveCodec(\"zstd\") ID = 0x%02X, want 0x00 (raw fallback)", unknown.ID()) + } +} + +func TestIsDirtyRectValid(t *testing.T) { + tests := []struct { + name string + rect image.Rectangle + valid bool + }{ + {"zero", image.Rectangle{}, false}, + {"valid", image.Rect(0, 0, 10, 10), true}, + {"empty", image.Rect(5, 5, 5, 5), false}, + // image.Rect canonicalizes, so Rect(10,10,5,5) becomes Rect(5,5,10,10) + // which is valid. Use raw struct to create a truly empty rect. + {"point", image.Rect(5, 5, 5, 5), false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isDirtyRectValid(tt.rect) + if got != tt.valid { + t.Errorf("isDirtyRectValid(%v) = %v, want %v", tt.rect, got, tt.valid) + } + }) + } +} + +func TestFrameToHeader(t *testing.T) { + f := Frame{ + Pixels: makePixels(10, 10, 0xFF), + Width: 10, + Height: 10, + DirtyRect: image.Rect(2, 3, 8, 9), + Timestamp: 123456789, + } + + hdr := frameToHeader(f, 42, 7, resolveCodec("")) + + if hdr.ModuleID != 42 { + t.Errorf("ModuleID = %d, want 42", hdr.ModuleID) + } + if hdr.Sequence != 7 { + t.Errorf("Sequence = %d, want 7", hdr.Sequence) + } + if hdr.Width != 10 { + t.Errorf("Width = %d, want 10", hdr.Width) + } + if hdr.Height != 10 { + t.Errorf("Height = %d, want 10", hdr.Height) + } + if hdr.Stride != 40 { + t.Errorf("Stride = %d, want 40", hdr.Stride) + } + if hdr.TimestampNs != 123456789 { + t.Errorf("TimestampNs = %d, want 123456789", hdr.TimestampNs) + } + if hdr.DirtyX != 2 || hdr.DirtyY != 3 || hdr.DirtyW != 6 || hdr.DirtyH != 6 { + t.Errorf("DirtyRect = (%d,%d,%d,%d), want (2,3,6,6)", + hdr.DirtyX, hdr.DirtyY, hdr.DirtyW, hdr.DirtyH) + } +} + +func TestSentinelErrors(t *testing.T) { + // Verify sentinel errors are distinct and not nil. + sentinels := []error{ErrClosed, ErrNotAccepted, ErrModuleNotFound, ErrMaxModules, ErrNameTaken} + for i, e := range sentinels { + if e == nil { + t.Errorf("sentinel error %d is nil", i) + } + } + + // ErrMaxModules and ErrNameTaken should be the same objects as conn package exports. + if ErrMaxModules.Error() != "compose: maximum module count reached" { + t.Errorf("ErrMaxModules message = %q", ErrMaxModules.Error()) + } + if ErrNameTaken.Error() != "compose: module name already registered" { + t.Errorf("ErrNameTaken message = %q", ErrNameTaken.Error()) + } +} + +func TestHeaderToFrame(t *testing.T) { + pixels := makePixels(8, 8, 0xBB) + + tests := []struct { + name string + flags uint8 + dirtyX uint16 + dirtyY uint16 + dirtyW uint16 + dirtyH uint16 + wantDirty image.Rectangle + }{ + { + name: "no dirty rect (keyframe)", + flags: 0, + wantDirty: image.Rectangle{}, + }, + { + name: "with dirty rect", + flags: 0x01, // FlagDirtyValid + dirtyX: 2, + dirtyY: 3, + dirtyW: 4, + dirtyH: 5, + wantDirty: image.Rect(2, 3, 6, 8), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + hdr := buildTestHeader(tt.flags, tt.dirtyX, tt.dirtyY, tt.dirtyW, tt.dirtyH) + + f := headerToFrame(hdr, pixels, "test") + + if f.ModuleID != 1 { + t.Errorf("ModuleID = %d, want 1", f.ModuleID) + } + if f.Name != "test" { + t.Errorf("Name = %q, want %q", f.Name, "test") + } + if f.Width != 8 { + t.Errorf("Width = %d, want 8", f.Width) + } + if f.Height != 8 { + t.Errorf("Height = %d, want 8", f.Height) + } + if f.DirtyRect != tt.wantDirty { + t.Errorf("DirtyRect = %v, want %v", f.DirtyRect, tt.wantDirty) + } + }) + } +} + +// buildTestHeader creates a protocol.Header with commonly used test values. +func buildTestHeader(flags uint8, dirtyX, dirtyY, dirtyW, dirtyH uint16) protocol.Header { + return protocol.Header{ + Magic: protocol.Magic, + Version: protocol.ProtocolVersion, + MsgType: protocol.MsgFrame, + Flags: protocol.Flag(flags), + ModuleID: 1, + Sequence: 1, + TimestampNs: time.Now().UnixNano(), + Width: 8, + Height: 8, + Stride: 32, + DirtyX: dirtyX, + DirtyY: dirtyY, + DirtyW: dirtyW, + DirtyH: dirtyH, + PixelFormat: protocol.PixelRGBA8, + Compression: protocol.CompressionNone, + PayloadSize: 256, + } +} diff --git a/doc.go b/doc.go index d844209..3e34d43 100644 --- a/doc.go +++ b/doc.go @@ -8,14 +8,53 @@ // possibility of cross-language modules — anything that can write RGBA to a // Unix socket or shared memory segment can participate. // +// # Quick Start +// +// Compositor (server) side: +// +// srv, err := compose.Listen("/tmp/compose.sock", +// compose.WithMaxModules(8), +// ) +// if err != nil { +// log.Fatal(err) +// } +// defer srv.Close() +// +// srv.OnFrame(func(f compose.Frame) { +// // Blit f.Pixels onto the compositor window. +// }) +// +// Module (client) side: +// +// client, err := compose.Dial("/tmp/compose.sock", +// compose.WithName("clock"), +// compose.WithFrameSize(400, 120), +// ) +// if err != nil { +// log.Fatal(err) +// } +// defer client.Close() +// +// client.OnFrameRequest(func() { +// // Render and publish a frame. +// _ = client.PublishFrame(compose.Frame{ +// Pixels: renderClock(), +// Width: 400, +// Height: 120, +// }) +// }) +// +// # Architecture +// // Multi-window in gogpu (ADR-010) and the compose model are different concepts. // Multi-window shares one GPU Device across N native windows inside a single // process. The compose model is the opposite: N independent processes, each // with its own GPU Device, cooperating over IPC to produce a single composed // display. // -// Status: design phase. See the repository README for the roadmap and the -// linked ui#75 discussion for context. +// The public API consists of five types (Frame, Server, Client, ServerOption, +// ClientOption) and three constructor functions (Listen, Dial, and With* +// options). All implementation details live behind internal/ boundaries. // // Part of the GoGPU ecosystem: https://github.com/gogpu package compose diff --git a/docs/ARCHITECTURE.md b/docs/ARCHITECTURE.md new file mode 100644 index 0000000..ba87c77 --- /dev/null +++ b/docs/ARCHITECTURE.md @@ -0,0 +1,156 @@ +# Architecture + +> **Module:** `github.com/gogpu/compose` +> **Pattern:** Two-tier transport (Unix socket + shared memory) +> **Inspiration:** Wayland wl_shm, Android SurfaceFlinger, PipeWire SPA + +## Overview + +``` +┌─────────────────────────────────────────┐ +│ Display (one physical screen) │ +└───────────────────┬─────────────────────┘ + │ +┌───────────────────┴─────────────────────┐ +│ Compositor process │ +│ compose.Listen("/tmp/compose.sock") │ +│ │ +│ • accepts module connections │ +│ • assigns module IDs │ +│ • pull-based frame requests │ +│ • composites frames onto display │ +└────┬────────────┬────────────┬──────────┘ + │ │ │ + socket socket socket + │ │ │ +┌────┴─────┐ ┌───┴──────┐ ┌───┴───────────┐ +│ Clock │ │ Weather │ │ Notification │ +│ module │ │ module │ │ module │ +│ 1 Hz │ │ 0.1 Hz │ │ 60 Hz anim │ +│ own PID │ │ own PID │ │ own PID │ +└──────────┘ └──────────┘ └───────────────┘ +``` + +## Package Structure + +``` +compose/ # Public API (13 exported symbols) +├── compose.go # Frame type +├── server.go # Server (compositor side) +├── client.go # Client (module side) +├── option.go # Functional options +├── error.go # Sentinel errors +│ +└── internal/ # Implementation (hidden) + ├── protocol/ # Wire format (64B header, handshake) + ├── codec/ # Compression (Raw, LZ4) + ├── conn/ # Module lifecycle (registry, hot-plug) + ├── flow/ # Pull-based pacing (Wayland pattern) + └── transport/ + └── socket/ # Unix domain socket transport +``` + +## Public API + +Users import ONE package: `"github.com/gogpu/compose"`. + +```go +// Compositor side +srv, _ := compose.Listen("/tmp/compose.sock", + compose.WithMaxModules(8), + compose.WithCompression("lz4"), +) +srv.OnFrame(func(f compose.Frame) { /* composite */ }) +srv.OnConnect(func(id uint64, name string) { /* module joined */ }) +srv.OnDisconnect(func(id uint64, name string) { /* module left */ }) + +// Module side +client, _ := compose.Dial("/tmp/compose.sock", + compose.WithName("clock"), + compose.WithFrameSize(400, 120), + compose.WithFPS(1), +) +client.OnFrameRequest(func() { /* render and publish */ }) +client.PublishFrame(compose.Frame{ Pixels: rgba, Width: 400, Height: 120 }) +``` + +## Wire Protocol v1 + +64-byte fixed header (cache-line aligned, little-endian): + +| Offset | Field | Size | Description | +|--------|-------|------|-------------| +| 0 | Magic | 4B | `0x434F4D50` ("COMP") | +| 4 | Version | 2B | Protocol version | +| 6 | MsgType | 1B | Frame, Handshake, Ack, FrameRequest, Resize, Disconnect | +| 7 | Flags | 1B | DirtyValid, Compressed, Keyframe | +| 8 | ModuleID | 8B | Compositor-assigned | +| 16 | Sequence | 8B | Monotonic frame counter | +| 24 | Timestamp | 8B | Monotonic nanoseconds | +| 32 | Width | 2B | Frame width | +| 34 | Height | 2B | Frame height | +| 36 | Stride | 4B | Bytes per row | +| 40 | DirtyRect | 8B | x, y, w, h (2B each) | +| 48 | PixelFormat | 1B | RGBA8, BGRA8 | +| 49 | Compression | 1B | None, LZ4, Zstd | +| 56 | PayloadSize | 4B | Compressed payload bytes | +| 60 | UncompressedSize | 4B | Original pixel bytes | + +## Flow Control + +Pull-based (Wayland frame callback pattern): + +1. Compositor → Module: `FrameRequest` +2. Module renders → sends `Frame` +3. Compositor processes → sends next `FrameRequest` +4. Adaptive rate: 3 consecutive misses → halve request rate + +## Connection Lifecycle + +``` +Module connects → Handshake (name, size, fps) + → Compositor assigns ID, fires OnConnect + → Frame loop (pull-based) + → Module crashes → EOF → OnDisconnect + → Module reconnects → new handshake → same slot +``` + +## Design Principles + +1. **Minimal public surface** — < 15 exported symbols, one import +2. **Enterprise internal/** — all implementation hidden behind `internal/` +3. **Zero allocations on hot path** — pre-allocated buffers, reused headers +4. **Cross-platform** — Linux, macOS, Windows (AF_UNIX), FreeBSD +5. **Zero CGO** — Pure Go transports, Pure Go compression +6. **Independent releases** — protocol versioned separately from API + +## Performance + +| Metric | Value | +|--------|-------| +| Header encode/decode | 6–24 ns, 0 allocs | +| LZ4 compression | 2.9 GB/s encode | +| GUI pixel compression ratio | 99.6% (flat color) | +| Socket throughput | 4.3 GB/s | +| Frame latency (192KB) | 45 μs | +| Flow control overhead | 37 ns/decision | + +## Dependency Graph + +``` +compose (root) ──→ internal/protocol (leaf, no deps) + ├──→ internal/transport/socket ──→ internal/protocol + ├──→ internal/codec (standalone) + ├──→ internal/flow (standalone) + └──→ internal/conn (standalone) +``` + +No circular dependencies. `internal/protocol` is the leaf — imported by others, imports nothing internal. + +## Part of GoGPU Ecosystem + +``` +naga (shaders) → wgpu (WebGPU) → gogpu (windowing) → gg (2D) → ui (widgets) + ↓ + compose (IPC) +``` diff --git a/error.go b/error.go new file mode 100644 index 0000000..f7380b4 --- /dev/null +++ b/error.go @@ -0,0 +1,30 @@ +package compose + +import ( + "errors" + + "github.com/gogpu/compose/internal/conn" +) + +// Sentinel errors returned by Server and Client. +var ( + // ErrClosed is returned when an operation is attempted on a closed + // Server or Client. + ErrClosed = errors.New("compose: server/client closed") + + // ErrNotAccepted is returned by Dial when the compositor rejects the + // connection (e.g., due to capacity limits or policy). + ErrNotAccepted = errors.New("compose: connection not accepted by compositor") + + // ErrModuleNotFound is returned by RequestFrame when the specified + // module ID does not exist in the server's module table. + ErrModuleNotFound = errors.New("compose: module not found") + + // ErrMaxModules is returned when the server cannot accept a new module + // because the maximum module count has been reached. + ErrMaxModules = conn.ErrMaxModules + + // ErrNameTaken is returned when a module with the same name is already + // connected to the server. + ErrNameTaken = conn.ErrNameTaken +) diff --git a/go.mod b/go.mod index ae69c8d..561b7a7 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,5 @@ module github.com/gogpu/compose go 1.25 + +require github.com/pierrec/lz4/v4 v4.1.26 diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..dafceee --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +github.com/pierrec/lz4/v4 v4.1.26 h1:GrpZw1gZttORinvzBdXPUXATeqlJjqUG/D87TKMnhjY= +github.com/pierrec/lz4/v4 v4.1.26/go.mod h1:EoQMVJgeeEOMsCqCzqFm2O0cJvljX2nGZjcRIPL34O4= diff --git a/internal/codec/codec.go b/internal/codec/codec.go new file mode 100644 index 0000000..b9c7b32 --- /dev/null +++ b/internal/codec/codec.go @@ -0,0 +1,68 @@ +package codec + +import "sync" + +// Protocol compression identifiers. +const ( + IDRaw byte = 0x00 + IDLZ4 byte = 0x01 +) + +// Codec compresses and decompresses frame pixel data. +// Implementations must be safe for concurrent use. +// Encode/Decode must not allocate on the hot path when the caller provides +// a sufficiently sized destination buffer. +type Codec interface { + // Encode compresses src into dst. Returns the compressed slice (sub-slice of dst). + // dst must be large enough (use MaxEncodedSize to determine required capacity). + // If dst is nil or too small, a new buffer is allocated (slow path). + Encode(dst, src []byte) ([]byte, error) + + // Decode decompresses src into dst. Returns the decompressed slice. + // dst must be large enough to hold the uncompressed data. + // If dst is nil or too small, a new buffer is allocated (slow path). + Decode(dst, src []byte) ([]byte, error) + + // ID returns the protocol compression identifier. + ID() byte + + // MaxEncodedSize returns the maximum possible compressed size for input of + // the given length. Use this to pre-allocate destination buffers. + MaxEncodedSize(srcLen int) int +} + +var ( + registryMu sync.RWMutex + registry = make(map[byte]Codec) +) + +// Register adds a codec to the global registry. It is called during init() +// by each codec implementation. Panics if a codec with the same ID is already +// registered. +func Register(c Codec) { + registryMu.Lock() + defer registryMu.Unlock() + + id := c.ID() + if _, exists := registry[id]; exists { + panic("codec: duplicate registration for ID " + string(rune('0'+id))) + } + registry[id] = c +} + +// Get returns the codec for the given protocol ID. Returns nil if no codec +// is registered for that ID. +func Get(id byte) Codec { + registryMu.RLock() + defer registryMu.RUnlock() + + return registry[id] +} + +// resetRegistry clears the global registry. Used only in tests. +func resetRegistry() { + registryMu.Lock() + defer registryMu.Unlock() + + registry = make(map[byte]Codec) +} diff --git a/internal/codec/codec_test.go b/internal/codec/codec_test.go new file mode 100644 index 0000000..c8f0df6 --- /dev/null +++ b/internal/codec/codec_test.go @@ -0,0 +1,83 @@ +package codec + +import "testing" + +func TestRegisterAndGet(t *testing.T) { + // Reset registry for isolated test. + resetRegistry() + defer func() { + // Re-register defaults after test. + resetRegistry() + Register(Raw()) + Register(LZ4()) + }() + + raw := Raw() + Register(raw) + + got := Get(IDRaw) + if got == nil { + t.Fatal("Get(IDRaw) returned nil after Register") + } + if got.ID() != IDRaw { + t.Errorf("Get(IDRaw).ID() = %d, want %d", got.ID(), IDRaw) + } +} + +func TestGetUnknownID(t *testing.T) { + got := Get(0xFF) + if got != nil { + t.Errorf("Get(0xFF) = %v, want nil", got) + } +} + +func TestGetRegisteredCodecs(t *testing.T) { + tests := []struct { + name string + id byte + }{ + {"Raw", IDRaw}, + {"LZ4", IDLZ4}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := Get(tt.id) + if c == nil { + t.Fatalf("Get(0x%02x) returned nil", tt.id) + } + if c.ID() != tt.id { + t.Errorf("ID() = 0x%02x, want 0x%02x", c.ID(), tt.id) + } + }) + } +} + +func TestRegisterDuplicatePanics(t *testing.T) { + resetRegistry() + defer func() { + resetRegistry() + Register(Raw()) + Register(LZ4()) + }() + + Register(Raw()) + + defer func() { + r := recover() + if r == nil { + t.Fatal("expected panic on duplicate registration, got none") + } + }() + + // Second registration with same ID should panic. + Register(Raw()) +} + +func TestCodecConstants(t *testing.T) { + if IDRaw != 0x00 { + t.Errorf("IDRaw = 0x%02x, want 0x00", IDRaw) + } + if IDLZ4 != 0x01 { + t.Errorf("IDLZ4 = 0x%02x, want 0x01", IDLZ4) + } +} diff --git a/internal/codec/doc.go b/internal/codec/doc.go new file mode 100644 index 0000000..4d47bc2 --- /dev/null +++ b/internal/codec/doc.go @@ -0,0 +1,15 @@ +// Package codec provides frame payload compression and decompression for the +// compose protocol. It defines a Codec interface with pluggable implementations +// and a global registry for protocol-level codec negotiation. +// +// Implementations must be safe for concurrent use. The Encode and Decode methods +// accept caller-provided destination buffers to avoid allocations on the hot path. +// When the caller provides a sufficiently sized buffer, zero allocations occur. +// +// Available codecs: +// - Raw (ID 0x00): Pass-through copy, no compression. +// - LZ4 (ID 0x01): LZ4 block compression via github.com/pierrec/lz4/v4. +// +// Registration happens automatically via init() in each codec's source file. +// Use Get(id) to retrieve a codec by its protocol identifier. +package codec diff --git a/internal/codec/lz4.go b/internal/codec/lz4.go new file mode 100644 index 0000000..5125732 --- /dev/null +++ b/internal/codec/lz4.go @@ -0,0 +1,116 @@ +package codec + +import ( + "fmt" + "sync" + + "github.com/pierrec/lz4/v4" +) + +func init() { + Register(LZ4()) +} + +// LZ4 returns a codec using LZ4 block compression. It uses a pooled +// lz4.Compressor to avoid allocations on the hot path. The compressor +// hash table is reused across calls via sync.Pool. +// +// LZ4 provides fast compression with moderate ratios, making it ideal for +// frame pixel data that contains large flat color regions (typical in GUIs). +func LZ4() Codec { + return &lz4Codec{ + pool: sync.Pool{ + New: func() any { + var c lz4.Compressor + return &c + }, + }, + } +} + +type lz4Codec struct { + pool sync.Pool +} + +// Encode compresses src using LZ4 block compression. Returns a sub-slice of +// dst containing the compressed data. If dst is nil or too small, allocates a +// new buffer. +// +// If src is empty, returns an empty slice without error. +func (c *lz4Codec) Encode(dst, src []byte) ([]byte, error) { + if len(src) == 0 { + if dst == nil { + return nil, nil + } + return dst[:0], nil + } + + maxSize := lz4.CompressBlockBound(len(src)) + if cap(dst) < maxSize { + dst = make([]byte, maxSize) + } else { + dst = dst[:maxSize] + } + + compressor := c.pool.Get().(*lz4.Compressor) + n, err := compressor.CompressBlock(src, dst) + c.pool.Put(compressor) + + if err != nil { + return nil, fmt.Errorf("codec: lz4 encode: %w", err) + } + + return dst[:n], nil +} + +// Decode decompresses an LZ4 block-compressed payload. Returns a sub-slice of +// dst containing the decompressed data. If dst is nil or too small, allocates +// a new buffer. +// +// The caller should provide a dst buffer sized to the expected uncompressed +// length (e.g., Width * Height * 4 for RGBA frames). When dst is nil or too +// small, an exponential growth strategy is used as a fallback. +func (c *lz4Codec) Decode(dst, src []byte) ([]byte, error) { + if len(src) == 0 { + if dst == nil { + return nil, nil + } + return dst[:0], nil + } + + if cap(dst) == 0 { + // Caller didn't provide a buffer. Start with 10x compressed size + // as initial guess. LZ4 GUI data often compresses 100:1 or better, + // but 10x covers most cases in one attempt. + dst = make([]byte, len(src)*10) + } else { + dst = dst[:cap(dst)] + } + + // maxDecodeBuf caps the growth strategy to prevent runaway allocation + // on corrupt or adversarial input (64 MB covers 4K RGBA frames). + const maxDecodeBuf = 64 * 1024 * 1024 + + for { + n, err := lz4.UncompressBlock(src, dst) + if err != nil { + // If buffer might be too small, double and retry. + if len(dst) < maxDecodeBuf { + dst = make([]byte, len(dst)*2) + continue + } + return nil, fmt.Errorf("codec: lz4 decode: %w", err) + } + return dst[:n], nil + } +} + +// ID returns the LZ4 codec protocol identifier (0x01). +func (c *lz4Codec) ID() byte { + return IDLZ4 +} + +// MaxEncodedSize returns the worst-case encoded size for LZ4 block compression. +func (c *lz4Codec) MaxEncodedSize(srcLen int) int { + return lz4.CompressBlockBound(srcLen) +} diff --git a/internal/codec/lz4_test.go b/internal/codec/lz4_test.go new file mode 100644 index 0000000..417891f --- /dev/null +++ b/internal/codec/lz4_test.go @@ -0,0 +1,341 @@ +package codec + +import ( + "bytes" + "crypto/rand" + "testing" +) + +func TestLZ4RoundTrip(t *testing.T) { + tests := []struct { + name string + data []byte + }{ + {"single byte", []byte{0x42}}, + {"small text", []byte("hello world hello world hello world")}, + {"zeros 1KB", make([]byte, 1024)}, + {"gui frame 400x120", makeGUIPixels(400, 120)}, + {"repeated pattern", bytes.Repeat([]byte{0xAA, 0xBB, 0xCC, 0xDD}, 4096)}, + } + + c := LZ4() + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + src := tt.data + + // Encode with pre-allocated buffer. + encBuf := make([]byte, c.MaxEncodedSize(len(src))) + encoded, err := c.Encode(encBuf, src) + if err != nil { + t.Fatalf("Encode: %v", err) + } + + // Decode with pre-allocated buffer. + decBuf := make([]byte, len(src)) + decoded, err := c.Decode(decBuf, encoded) + if err != nil { + t.Fatalf("Decode: %v", err) + } + + if !bytes.Equal(decoded, src) { + t.Errorf("round-trip mismatch: decoded %d bytes, want %d", len(decoded), len(src)) + } + }) + } +} + +func TestLZ4EmptyInput(t *testing.T) { + c := LZ4() + + // Encode nil input. + encoded, err := c.Encode(nil, nil) + if err != nil { + t.Fatalf("Encode(nil, nil): %v", err) + } + if encoded != nil { + t.Errorf("Encode(nil, nil) = %v, want nil", encoded) + } + + // Encode empty slice with dst. + dst := make([]byte, 16) + encoded, err = c.Encode(dst, []byte{}) + if err != nil { + t.Fatalf("Encode(dst, []byte{}): %v", err) + } + if len(encoded) != 0 { + t.Errorf("Encode empty: len = %d, want 0", len(encoded)) + } + + // Decode nil input. + decoded, err := c.Decode(nil, nil) + if err != nil { + t.Fatalf("Decode(nil, nil): %v", err) + } + if decoded != nil { + t.Errorf("Decode(nil, nil) = %v, want nil", decoded) + } + + // Decode empty slice with dst. + decoded, err = c.Decode(dst, []byte{}) + if err != nil { + t.Fatalf("Decode(dst, []byte{}): %v", err) + } + if len(decoded) != 0 { + t.Errorf("Decode empty: len = %d, want 0", len(decoded)) + } +} + +func TestLZ4CompressionRatio(t *testing.T) { + c := LZ4() + + // GUI pixel data should compress well (large flat color regions). + src := makeGUIPixels(400, 120) + encBuf := make([]byte, c.MaxEncodedSize(len(src))) + encoded, err := c.Encode(encBuf, src) + if err != nil { + t.Fatalf("Encode: %v", err) + } + + ratio := float64(len(encoded)) / float64(len(src)) + t.Logf("GUI pixels: %d -> %d bytes (ratio: %.3f, savings: %.1f%%)", + len(src), len(encoded), ratio, (1-ratio)*100) + + // GUI pixel data with large flat regions should compress to at least 50%. + if ratio > 0.50 { + t.Errorf("compression ratio %.3f is worse than expected 0.50 for GUI data", ratio) + } +} + +func TestLZ4RandomData(t *testing.T) { + c := LZ4() + + // Random data compresses poorly but must still round-trip correctly. + src := make([]byte, 4096) + if _, err := rand.Read(src); err != nil { + t.Fatalf("rand.Read: %v", err) + } + + encBuf := make([]byte, c.MaxEncodedSize(len(src))) + encoded, err := c.Encode(encBuf, src) + if err != nil { + t.Fatalf("Encode random: %v", err) + } + + decBuf := make([]byte, len(src)) + decoded, err := c.Decode(decBuf, encoded) + if err != nil { + t.Fatalf("Decode random: %v", err) + } + if !bytes.Equal(decoded, src) { + t.Error("round-trip mismatch for random data") + } +} + +func TestLZ4NilDst(t *testing.T) { + c := LZ4() + src := makeGUIPixels(400, 120) + + // nil dst on Encode forces allocation (slow path). + encoded, err := c.Encode(nil, src) + if err != nil { + t.Fatalf("Encode(nil, src): %v", err) + } + + // nil dst on Decode forces allocation with growth strategy. + decoded, err := c.Decode(nil, encoded) + if err != nil { + t.Fatalf("Decode(nil, encoded): %v", err) + } + if !bytes.Equal(decoded, src) { + t.Error("round-trip with nil dst: mismatch") + } +} + +func TestLZ4SmallDstEncode(t *testing.T) { + c := LZ4() + src := makeGUIPixels(400, 120) + + // dst too small for MaxEncodedSize -- forces reallocation. + dst := make([]byte, 4) + encoded, err := c.Encode(dst, src) + if err != nil { + t.Fatalf("Encode small dst: %v", err) + } + + // Verify round-trip. + decBuf := make([]byte, len(src)) + decoded, err := c.Decode(decBuf, encoded) + if err != nil { + t.Fatalf("Decode: %v", err) + } + if !bytes.Equal(decoded, src) { + t.Error("round-trip mismatch with small encode dst") + } +} + +func TestLZ4SmallDstDecode(t *testing.T) { + c := LZ4() + src := makeGUIPixels(400, 120) + + // Encode normally. + encBuf := make([]byte, c.MaxEncodedSize(len(src))) + encoded, err := c.Encode(encBuf, src) + if err != nil { + t.Fatalf("Encode: %v", err) + } + + // Decode with dst too small -- triggers retry growth. + smallDst := make([]byte, 64) + decoded, err := c.Decode(smallDst, encoded) + if err != nil { + t.Fatalf("Decode small dst: %v", err) + } + if !bytes.Equal(decoded, src) { + t.Error("round-trip mismatch with small decode dst") + } +} + +func TestLZ4ID(t *testing.T) { + c := LZ4() + if c.ID() != IDLZ4 { + t.Errorf("ID() = 0x%02x, want 0x%02x", c.ID(), IDLZ4) + } +} + +func TestLZ4MaxEncodedSize(t *testing.T) { + c := LZ4() + + tests := []int{0, 1, 100, 1024, 192000} + for _, srcLen := range tests { + maxSize := c.MaxEncodedSize(srcLen) + if maxSize < srcLen { + t.Errorf("MaxEncodedSize(%d) = %d, want >= %d", srcLen, maxSize, srcLen) + } + } +} + +func TestLZ4RoundTripVaryingSizes(t *testing.T) { + c := LZ4() + + sizes := []int{1, 2, 3, 4, 8, 12, 15, 16, 17, 64, 128, 256, 1024, 8192} + for _, size := range sizes { + src := make([]byte, size) + for i := range src { + src[i] = byte(i*17 + 31) + } + + encBuf := make([]byte, c.MaxEncodedSize(len(src))) + encoded, err := c.Encode(encBuf, src) + if err != nil { + t.Fatalf("Encode size=%d: %v", size, err) + } + + decBuf := make([]byte, len(src)) + decoded, err := c.Decode(decBuf, encoded) + if err != nil { + t.Fatalf("Decode size=%d: %v", size, err) + } + if !bytes.Equal(decoded, src) { + t.Errorf("round-trip mismatch for size=%d", size) + } + } +} + +func TestLZ4ConcurrentSafety(t *testing.T) { + c := LZ4() + src := makeGUIPixels(400, 120) + done := make(chan struct{}) + + for i := 0; i < 8; i++ { + go func() { + defer func() { done <- struct{}{} }() + encBuf := make([]byte, c.MaxEncodedSize(len(src))) + decBuf := make([]byte, len(src)) + for j := 0; j < 50; j++ { + encoded, err := c.Encode(encBuf, src) + if err != nil { + t.Errorf("concurrent Encode: %v", err) + return + } + decoded, err := c.Decode(decBuf, encoded) + if err != nil { + t.Errorf("concurrent Decode: %v", err) + return + } + if !bytes.Equal(decoded, src) { + t.Errorf("concurrent round-trip mismatch") + return + } + } + }() + } + for i := 0; i < 8; i++ { + <-done + } +} + +// BenchmarkLZ4Encode benchmarks LZ4 encoding with a realistic GUI frame (192KB). +func BenchmarkLZ4Encode(b *testing.B) { + c := LZ4() + src := makeGUIPixels(400, 120) // 192,000 bytes + dst := make([]byte, c.MaxEncodedSize(len(src))) + + b.SetBytes(int64(len(src))) + b.ResetTimer() + + for b.Loop() { + _, _ = c.Encode(dst, src) + } +} + +// BenchmarkLZ4Decode benchmarks LZ4 decoding with a realistic GUI frame. +func BenchmarkLZ4Decode(b *testing.B) { + c := LZ4() + src := makeGUIPixels(400, 120) + encBuf := make([]byte, c.MaxEncodedSize(len(src))) + encoded, err := c.Encode(encBuf, src) + if err != nil { + b.Fatalf("setup Encode: %v", err) + } + + dst := make([]byte, len(src)) + b.SetBytes(int64(len(src))) + b.ResetTimer() + + for b.Loop() { + _, _ = c.Decode(dst, encoded) + } +} + +// BenchmarkLZ4EncodeFullHD benchmarks LZ4 encoding with a 1920x1080 frame. +func BenchmarkLZ4EncodeFullHD(b *testing.B) { + c := LZ4() + src := makeGUIPixels(1920, 1080) // ~8.3 MB + dst := make([]byte, c.MaxEncodedSize(len(src))) + + b.SetBytes(int64(len(src))) + b.ResetTimer() + + for b.Loop() { + _, _ = c.Encode(dst, src) + } +} + +// BenchmarkLZ4DecodeFullHD benchmarks LZ4 decoding with a 1920x1080 frame. +func BenchmarkLZ4DecodeFullHD(b *testing.B) { + c := LZ4() + src := makeGUIPixels(1920, 1080) + encBuf := make([]byte, c.MaxEncodedSize(len(src))) + encoded, err := c.Encode(encBuf, src) + if err != nil { + b.Fatalf("setup Encode: %v", err) + } + + dst := make([]byte, len(src)) + b.SetBytes(int64(len(src))) + b.ResetTimer() + + for b.Loop() { + _, _ = c.Decode(dst, encoded) + } +} diff --git a/internal/codec/raw.go b/internal/codec/raw.go new file mode 100644 index 0000000..bcdbb69 --- /dev/null +++ b/internal/codec/raw.go @@ -0,0 +1,59 @@ +package codec + +func init() { + Register(Raw()) +} + +// Raw returns a codec that performs no compression (pass-through copy). +// Encode and Decode simply copy src into dst. This is the fastest codec +// and is used when compression overhead is not justified (small frames, +// already-compressed data, or LAN with abundant bandwidth). +func Raw() Codec { + return rawCodec{} +} + +type rawCodec struct{} + +// Encode copies src into dst unchanged. Returns a sub-slice of dst containing +// the copied data. If dst is nil or too small, allocates a new buffer. +func (rawCodec) Encode(dst, src []byte) ([]byte, error) { + if len(src) == 0 { + return dst[:0], nil + } + + if cap(dst) < len(src) { + dst = make([]byte, len(src)) + } else { + dst = dst[:len(src)] + } + + copy(dst, src) + return dst, nil +} + +// Decode copies src into dst unchanged. Returns a sub-slice of dst containing +// the copied data. If dst is nil or too small, allocates a new buffer. +func (rawCodec) Decode(dst, src []byte) ([]byte, error) { + if len(src) == 0 { + return dst[:0], nil + } + + if cap(dst) < len(src) { + dst = make([]byte, len(src)) + } else { + dst = dst[:len(src)] + } + + copy(dst, src) + return dst, nil +} + +// ID returns the raw codec protocol identifier (0x00). +func (rawCodec) ID() byte { + return IDRaw +} + +// MaxEncodedSize returns srcLen since raw encoding has no overhead. +func (rawCodec) MaxEncodedSize(srcLen int) int { + return srcLen +} diff --git a/internal/codec/raw_test.go b/internal/codec/raw_test.go new file mode 100644 index 0000000..8d6928e --- /dev/null +++ b/internal/codec/raw_test.go @@ -0,0 +1,258 @@ +package codec + +import ( + "bytes" + "testing" +) + +func TestRawRoundTrip(t *testing.T) { + tests := []struct { + name string + data []byte + }{ + {"empty", nil}, + {"single byte", []byte{0x42}}, + {"small", []byte("hello world")}, + {"zeros", make([]byte, 1024)}, + {"frame 400x120x4", makeGUIPixels(400, 120)}, + } + + c := Raw() + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + src := tt.data + + // Encode with pre-allocated buffer. + encBuf := make([]byte, c.MaxEncodedSize(len(src))) + encoded, err := c.Encode(encBuf, src) + if err != nil { + t.Fatalf("Encode: %v", err) + } + + // Raw codec output should equal input. + if !bytes.Equal(encoded, src) { + t.Errorf("Encode output differs from input: got %d bytes, want %d", len(encoded), len(src)) + } + + // Decode with pre-allocated buffer. + decBuf := make([]byte, len(src)) + decoded, err := c.Decode(decBuf, encoded) + if err != nil { + t.Fatalf("Decode: %v", err) + } + + if !bytes.Equal(decoded, src) { + t.Errorf("round-trip mismatch: got %d bytes, want %d", len(decoded), len(src)) + } + }) + } +} + +func TestRawEncodeNilDst(t *testing.T) { + c := Raw() + src := []byte("test data") + + // nil dst forces allocation (slow path). + encoded, err := c.Encode(nil, src) + if err != nil { + t.Fatalf("Encode(nil, src): %v", err) + } + if !bytes.Equal(encoded, src) { + t.Error("Encode with nil dst: output differs from input") + } +} + +func TestRawDecodeNilDst(t *testing.T) { + c := Raw() + src := []byte("test data") + + decoded, err := c.Decode(nil, src) + if err != nil { + t.Fatalf("Decode(nil, src): %v", err) + } + if !bytes.Equal(decoded, src) { + t.Error("Decode with nil dst: output differs from input") + } +} + +func TestRawEncodeEmptyInput(t *testing.T) { + c := Raw() + dst := make([]byte, 16) + + encoded, err := c.Encode(dst, nil) + if err != nil { + t.Fatalf("Encode(dst, nil): %v", err) + } + if len(encoded) != 0 { + t.Errorf("Encode(dst, nil) len = %d, want 0", len(encoded)) + } +} + +func TestRawDecodeEmptyInput(t *testing.T) { + c := Raw() + dst := make([]byte, 16) + + decoded, err := c.Decode(dst, nil) + if err != nil { + t.Fatalf("Decode(dst, nil): %v", err) + } + if len(decoded) != 0 { + t.Errorf("Decode(dst, nil) len = %d, want 0", len(decoded)) + } +} + +func TestRawEncodeSmallDst(t *testing.T) { + c := Raw() + src := []byte("longer test data that exceeds small buffer") + dst := make([]byte, 4) // too small + + encoded, err := c.Encode(dst, src) + if err != nil { + t.Fatalf("Encode with small dst: %v", err) + } + if !bytes.Equal(encoded, src) { + t.Error("Encode with small dst: output differs from input") + } +} + +func TestRawID(t *testing.T) { + c := Raw() + if c.ID() != IDRaw { + t.Errorf("ID() = 0x%02x, want 0x%02x", c.ID(), IDRaw) + } +} + +func TestRawMaxEncodedSize(t *testing.T) { + c := Raw() + tests := []struct { + srcLen int + want int + }{ + {0, 0}, + {1, 1}, + {1024, 1024}, + {192000, 192000}, // 400*120*4 + } + for _, tt := range tests { + got := c.MaxEncodedSize(tt.srcLen) + if got != tt.want { + t.Errorf("MaxEncodedSize(%d) = %d, want %d", tt.srcLen, got, tt.want) + } + } +} + +func TestRawEncodeLargeFrame(t *testing.T) { + c := Raw() + // Simulate a 1920x1080 RGBA frame (8.3 MB). + src := make([]byte, 1920*1080*4) + for i := range src { + src[i] = byte(i % 256) + } + + dst := make([]byte, c.MaxEncodedSize(len(src))) + encoded, err := c.Encode(dst, src) + if err != nil { + t.Fatalf("Encode large frame: %v", err) + } + if !bytes.Equal(encoded, src) { + t.Error("large frame encode: output differs from input") + } +} + +func TestRawConcurrentSafety(t *testing.T) { + c := Raw() + src := makeGUIPixels(400, 120) + done := make(chan struct{}) + + for i := 0; i < 8; i++ { + go func() { + defer func() { done <- struct{}{} }() + dst := make([]byte, c.MaxEncodedSize(len(src))) + for j := 0; j < 100; j++ { + encoded, err := c.Encode(dst, src) + if err != nil { + t.Errorf("concurrent Encode: %v", err) + return + } + if !bytes.Equal(encoded, src) { + t.Errorf("concurrent Encode: mismatch") + return + } + } + }() + } + for i := 0; i < 8; i++ { + <-done + } +} + +// BenchmarkRawEncode benchmarks raw codec encoding with a realistic GUI frame. +func BenchmarkRawEncode(b *testing.B) { + c := Raw() + src := makeGUIPixels(400, 120) + dst := make([]byte, c.MaxEncodedSize(len(src))) + + b.SetBytes(int64(len(src))) + b.ResetTimer() + + for b.Loop() { + _, _ = c.Encode(dst, src) + } +} + +// BenchmarkRawDecode benchmarks raw codec decoding with a realistic GUI frame. +func BenchmarkRawDecode(b *testing.B) { + c := Raw() + src := makeGUIPixels(400, 120) + dst := make([]byte, len(src)) + + b.SetBytes(int64(len(src))) + b.ResetTimer() + + for b.Loop() { + _, _ = c.Decode(dst, src) + } +} + +// makeGUIPixels generates synthetic GUI-like pixel data with large flat color +// regions, simulating a typical desktop widget frame (good for compression). +func makeGUIPixels(width, height int) []byte { + size := width * height * 4 + pixels := make([]byte, size) + + // Background: solid light gray (70% of frame). + bgEnd := int(float64(height) * 0.7) + for y := 0; y < bgEnd; y++ { + for x := 0; x < width; x++ { + off := (y*width + x) * 4 + pixels[off+0] = 0xF0 // R + pixels[off+1] = 0xF0 // G + pixels[off+2] = 0xF0 // B + pixels[off+3] = 0xFF // A + } + } + + // Button region: solid blue. + for y := bgEnd; y < bgEnd+30 && y < height; y++ { + for x := 20; x < 120 && x < width; x++ { + off := (y*width + x) * 4 + pixels[off+0] = 0x21 // R + pixels[off+1] = 0x96 // G + pixels[off+2] = 0xF3 // B + pixels[off+3] = 0xFF // A + } + } + + // Footer: solid dark gray. + for y := bgEnd + 30; y < height; y++ { + for x := 0; x < width; x++ { + off := (y*width + x) * 4 + pixels[off+0] = 0x30 // R + pixels[off+1] = 0x30 // G + pixels[off+2] = 0x30 // B + pixels[off+3] = 0xFF // A + } + } + + return pixels +} diff --git a/internal/conn/doc.go b/internal/conn/doc.go new file mode 100644 index 0000000..b904b11 --- /dev/null +++ b/internal/conn/doc.go @@ -0,0 +1,16 @@ +// Package conn provides connection lifecycle management for the compose library. +// +// It handles module ID allocation, name-to-ID mapping, hot-plug detection, +// graceful disconnect, and reconnection matching. The package is standalone +// with no internal dependencies. +// +// The core types are: +// +// - [Registry] manages module ID allocation and lookup. +// - [Manager] orchestrates the full module lifecycle (connect, handshake, +// active, disconnect) and fires event callbacks. +// - [Module] holds metadata about a connected module. +// - [State] represents the lifecycle state of a module connection. +// +// All exported methods are safe for concurrent use from multiple goroutines. +package conn diff --git a/internal/conn/errors.go b/internal/conn/errors.go new file mode 100644 index 0000000..671736c --- /dev/null +++ b/internal/conn/errors.go @@ -0,0 +1,18 @@ +package conn + +import "errors" + +// Sentinel errors for the conn package. +var ( + // ErrMaxModules is returned when attempting to register a module beyond + // the configured maximum capacity. + ErrMaxModules = errors.New("compose: maximum module count reached") + + // ErrNameTaken is returned when attempting to register a module with a + // name that is already in use by an active module. + ErrNameTaken = errors.New("compose: module name already registered") + + // ErrNotFound is returned when a lookup or operation references a module + // ID that does not exist in the registry. + ErrNotFound = errors.New("compose: module not found") +) diff --git a/internal/conn/manager.go b/internal/conn/manager.go new file mode 100644 index 0000000..7b84d84 --- /dev/null +++ b/internal/conn/manager.go @@ -0,0 +1,103 @@ +package conn + +import "sync" + +// Manager orchestrates module lifecycle: connect, handshake, active, disconnect. +// It wraps a Registry and adds event callbacks plus reconnection matching. +// All methods are safe for concurrent access from multiple goroutines. +type Manager struct { + registry *Registry + + mu sync.RWMutex + onConnect func(id uint64, name string) + onDisconnect func(id uint64, name string) +} + +// NewManager creates a Manager with the given max module count. +// The maxModules parameter must be positive; values less than 1 are clamped to 1. +func NewManager(maxModules int) *Manager { + return &Manager{ + registry: NewRegistry(maxModules), + } +} + +// HandleConnect processes a new module connection. +// It allocates an ID, transitions the module to Active state, and fires +// the OnConnect callback if registered. +// +// If a module with the same name was previously disconnected and unregistered, +// it is treated as a reconnection — a new ID is allocated but the name slot +// is reused seamlessly. +// +// Returns ErrMaxModules if the registry is at capacity. +// Returns ErrNameTaken if a module with the same name is currently active. +func (m *Manager) HandleConnect(name string, width, height, fps uint16) (uint64, error) { + id, err := m.registry.Register(name, width, height, fps) + if err != nil { + return 0, err + } + + // Transition directly to Active (handshake completed at this point). + m.registry.SetState(id, StateActive) + + // Fire callback outside the registry lock to avoid deadlocks in user code. + m.mu.RLock() + cb := m.onConnect + m.mu.RUnlock() + + if cb != nil { + cb(id, name) + } + + return id, nil +} + +// HandleDisconnect processes a module disconnection (graceful or crash). +// It transitions the module to Disconnected state, fires the OnDisconnect +// callback, and removes the module from the registry so its name can be reused. +// +// If the module ID does not exist, this is a no-op. +func (m *Manager) HandleDisconnect(id uint64) { + // Look up the module before unregistering so we have the name for the callback. + mod, exists := m.registry.Lookup(id) + if !exists { + return + } + + m.registry.SetState(id, StateDisconnected) + m.registry.Unregister(id) + + // Fire callback after state transition. + m.mu.RLock() + cb := m.onDisconnect + m.mu.RUnlock() + + if cb != nil { + cb(id, mod.Name) + } +} + +// OnConnect sets the callback fired when a module becomes active. +// Only one callback can be set; subsequent calls replace the previous one. +// Passing nil removes the callback. +func (m *Manager) OnConnect(fn func(id uint64, name string)) { + m.mu.Lock() + defer m.mu.Unlock() + + m.onConnect = fn +} + +// OnDisconnect sets the callback fired when a module disconnects. +// Only one callback can be set; subsequent calls replace the previous one. +// Passing nil removes the callback. +func (m *Manager) OnDisconnect(fn func(id uint64, name string)) { + m.mu.Lock() + defer m.mu.Unlock() + + m.onDisconnect = fn +} + +// Registry returns the underlying registry for lookups. +func (m *Manager) Registry() *Registry { + return m.registry +} diff --git a/internal/conn/manager_test.go b/internal/conn/manager_test.go new file mode 100644 index 0000000..d3c4327 --- /dev/null +++ b/internal/conn/manager_test.go @@ -0,0 +1,362 @@ +package conn + +import ( + "errors" + "fmt" + "sync" + "sync/atomic" + "testing" +) + +func TestNewManager(t *testing.T) { + m := NewManager(16) + if m.Registry() == nil { + t.Fatal("Registry() returned nil") + } + if m.Registry().Count() != 0 { + t.Errorf("initial count = %d, want 0", m.Registry().Count()) + } +} + +func TestManager_HandleConnect(t *testing.T) { + t.Run("basic connect", func(t *testing.T) { + m := NewManager(16) + id, err := m.HandleConnect("clock", 400, 120, 1) + if err != nil { + t.Fatalf("HandleConnect failed: %v", err) + } + if id == 0 { + t.Error("HandleConnect returned zero ID") + } + + mod, ok := m.Registry().Lookup(id) + if !ok { + t.Fatal("module not found after connect") + } + if mod.State != StateActive { + t.Errorf("state = %v, want StateActive", mod.State) + } + if mod.Name != "clock" { + t.Errorf("name = %q, want %q", mod.Name, "clock") + } + if mod.Width != 400 || mod.Height != 120 { + t.Errorf("dimensions = %dx%d, want 400x120", mod.Width, mod.Height) + } + if mod.FPS != 1 { + t.Errorf("fps = %d, want 1", mod.FPS) + } + }) + + t.Run("fires OnConnect callback", func(t *testing.T) { + m := NewManager(16) + + var callbackID uint64 + var callbackName string + m.OnConnect(func(id uint64, name string) { + callbackID = id + callbackName = name + }) + + id, _ := m.HandleConnect("weather", 320, 240, 30) + if callbackID != id { + t.Errorf("callback ID = %d, want %d", callbackID, id) + } + if callbackName != "weather" { + t.Errorf("callback name = %q, want %q", callbackName, "weather") + } + }) + + t.Run("no callback when nil", func(t *testing.T) { + m := NewManager(16) + // No callback set — should not panic. + _, err := m.HandleConnect("mod", 100, 100, 30) + if err != nil { + t.Fatalf("HandleConnect failed: %v", err) + } + }) + + t.Run("max modules error", func(t *testing.T) { + m := NewManager(2) + _, _ = m.HandleConnect("a", 100, 100, 30) + _, _ = m.HandleConnect("b", 100, 100, 30) + + _, err := m.HandleConnect("c", 100, 100, 30) + if !errors.Is(err, ErrMaxModules) { + t.Errorf("err = %v, want ErrMaxModules", err) + } + }) + + t.Run("name taken error", func(t *testing.T) { + m := NewManager(16) + _, _ = m.HandleConnect("clock", 400, 120, 1) + + _, err := m.HandleConnect("clock", 400, 120, 1) + if !errors.Is(err, ErrNameTaken) { + t.Errorf("err = %v, want ErrNameTaken", err) + } + }) + + t.Run("callback not fired on error", func(t *testing.T) { + m := NewManager(1) + _, _ = m.HandleConnect("a", 100, 100, 30) + + callbackFired := false + m.OnConnect(func(_ uint64, _ string) { + callbackFired = true + }) + + _, _ = m.HandleConnect("b", 100, 100, 30) // capacity exceeded + if callbackFired { + t.Error("OnConnect callback fired on error, should not") + } + }) +} + +func TestManager_HandleDisconnect(t *testing.T) { + t.Run("basic disconnect", func(t *testing.T) { + m := NewManager(16) + id, _ := m.HandleConnect("clock", 400, 120, 1) + m.HandleDisconnect(id) + + _, ok := m.Registry().Lookup(id) + if ok { + t.Error("module still found after disconnect") + } + if m.Registry().Count() != 0 { + t.Errorf("count = %d, want 0", m.Registry().Count()) + } + }) + + t.Run("fires OnDisconnect callback", func(t *testing.T) { + m := NewManager(16) + id, _ := m.HandleConnect("weather", 320, 240, 30) + + var callbackID uint64 + var callbackName string + m.OnDisconnect(func(id uint64, name string) { + callbackID = id + callbackName = name + }) + + m.HandleDisconnect(id) + if callbackID != id { + t.Errorf("callback ID = %d, want %d", callbackID, id) + } + if callbackName != "weather" { + t.Errorf("callback name = %q, want %q", callbackName, "weather") + } + }) + + t.Run("nonexistent ID is no-op", func(t *testing.T) { + m := NewManager(16) + callbackFired := false + m.OnDisconnect(func(_ uint64, _ string) { + callbackFired = true + }) + + m.HandleDisconnect(999) // should not panic + if callbackFired { + t.Error("OnDisconnect fired for nonexistent ID") + } + }) + + t.Run("double disconnect is safe", func(t *testing.T) { + m := NewManager(16) + id, _ := m.HandleConnect("mod", 100, 100, 30) + + var count int + m.OnDisconnect(func(_ uint64, _ string) { + count++ + }) + + m.HandleDisconnect(id) + m.HandleDisconnect(id) // second call should be no-op + + if count != 1 { + t.Errorf("OnDisconnect fired %d times, want 1", count) + } + }) +} + +func TestManager_Reconnection(t *testing.T) { + t.Run("name reusable after disconnect", func(t *testing.T) { + m := NewManager(16) + id1, _ := m.HandleConnect("clock", 400, 120, 1) + m.HandleDisconnect(id1) + + id2, err := m.HandleConnect("clock", 400, 120, 1) + if err != nil { + t.Fatalf("reconnect failed: %v", err) + } + if id2 <= id1 { + t.Errorf("reconnect ID not monotonic: old=%d, new=%d", id1, id2) + } + }) + + t.Run("reconnection fires callbacks", func(t *testing.T) { + m := NewManager(16) + + var connectCount, disconnectCount int + m.OnConnect(func(_ uint64, _ string) { connectCount++ }) + m.OnDisconnect(func(_ uint64, _ string) { disconnectCount++ }) + + id, _ := m.HandleConnect("clock", 400, 120, 1) + m.HandleDisconnect(id) + _, _ = m.HandleConnect("clock", 400, 120, 1) + + if connectCount != 2 { + t.Errorf("connect count = %d, want 2", connectCount) + } + if disconnectCount != 1 { + t.Errorf("disconnect count = %d, want 1", disconnectCount) + } + }) + + t.Run("slot freed for capacity", func(t *testing.T) { + m := NewManager(2) + id1, _ := m.HandleConnect("a", 100, 100, 30) + _, _ = m.HandleConnect("b", 100, 100, 30) + + // At capacity. + _, err := m.HandleConnect("c", 100, 100, 30) + if !errors.Is(err, ErrMaxModules) { + t.Fatalf("expected ErrMaxModules, got %v", err) + } + + // Disconnect one — slot freed. + m.HandleDisconnect(id1) + _, err = m.HandleConnect("c", 100, 100, 30) + if err != nil { + t.Errorf("connect after disconnect failed: %v", err) + } + }) +} + +func TestManager_CallbackReplacement(t *testing.T) { + m := NewManager(16) + + var firstCalled, secondCalled bool + + m.OnConnect(func(_ uint64, _ string) { firstCalled = true }) + m.OnConnect(func(_ uint64, _ string) { secondCalled = true }) + + m.HandleConnect("mod", 100, 100, 30) + + if firstCalled { + t.Error("first callback called after replacement") + } + if !secondCalled { + t.Error("second callback not called") + } +} + +func TestManager_NilCallback(t *testing.T) { + m := NewManager(16) + + m.OnConnect(func(_ uint64, _ string) {}) + m.OnConnect(nil) // remove callback + + // Should not panic. + _, _ = m.HandleConnect("mod", 100, 100, 30) +} + +func TestManager_ConcurrentAccess(t *testing.T) { + m := NewManager(500) + + var connectCount atomic.Int64 + var disconnectCount atomic.Int64 + + m.OnConnect(func(_ uint64, _ string) { + connectCount.Add(1) + }) + m.OnDisconnect(func(_ uint64, _ string) { + disconnectCount.Add(1) + }) + + var wg sync.WaitGroup + + // Concurrent connects. + ids := make([]uint64, 200) + var idMu sync.Mutex + for i := range 200 { + wg.Add(1) + go func(idx int) { + defer wg.Done() + name := fmt.Sprintf("module-%d", idx) + id, err := m.HandleConnect(name, 100, 100, 30) + if err == nil { + idMu.Lock() + ids[idx] = id + idMu.Unlock() + } + }(i) + } + wg.Wait() + + // Concurrent disconnects. + for i := range 200 { + wg.Add(1) + go func(idx int) { + defer wg.Done() + idMu.Lock() + id := ids[idx] + idMu.Unlock() + if id != 0 { + m.HandleDisconnect(id) + } + }(i) + } + wg.Wait() + + // After all disconnects, registry should be empty. + if m.Registry().Count() != 0 { + t.Errorf("count after all disconnects = %d, want 0", m.Registry().Count()) + } + + // Verify callback counts match. + connected := connectCount.Load() + disconnected := disconnectCount.Load() + if connected != disconnected { + t.Errorf("connect count (%d) != disconnect count (%d)", connected, disconnected) + } +} + +func TestManager_Registry(t *testing.T) { + m := NewManager(16) + r := m.Registry() + if r == nil { + t.Fatal("Registry() returned nil") + } + + // Verify it is the same instance. + id, _ := m.HandleConnect("mod", 100, 100, 30) + mod, ok := r.Lookup(id) + if !ok { + t.Fatal("registry lookup failed for module connected via manager") + } + if mod.Name != "mod" { + t.Errorf("name = %q, want %q", mod.Name, "mod") + } +} + +func BenchmarkHandleConnect(b *testing.B) { + m := NewManager(b.N + 1) + b.ResetTimer() + for i := range b.N { + name := fmt.Sprintf("module-%d", i) + _, _ = m.HandleConnect(name, 100, 100, 30) + } +} + +func BenchmarkHandleDisconnect(b *testing.B) { + m := NewManager(b.N + 1) + ids := make([]uint64, b.N) + for i := range b.N { + ids[i], _ = m.HandleConnect(fmt.Sprintf("module-%d", i), 100, 100, 30) + } + + b.ResetTimer() + for i := range b.N { + m.HandleDisconnect(ids[i]) + } +} diff --git a/internal/conn/module.go b/internal/conn/module.go new file mode 100644 index 0000000..3956145 --- /dev/null +++ b/internal/conn/module.go @@ -0,0 +1,65 @@ +package conn + +import "time" + +// State represents the lifecycle state of a module connection. +type State uint8 + +const ( + // StateConnecting indicates a connection has been initiated but the + // handshake has not yet started. + StateConnecting State = iota + + // StateHandshaking indicates the handshake is in progress. + StateHandshaking + + // StateActive indicates the module is fully connected and sending frames. + StateActive + + // StateDisconnected indicates the connection was lost or gracefully closed. + StateDisconnected +) + +// String returns a human-readable name for the state. +func (s State) String() string { + switch s { + case StateConnecting: + return "connecting" + case StateHandshaking: + return "handshaking" + case StateActive: + return "active" + case StateDisconnected: + return "disconnected" + default: + return "unknown" + } +} + +// Module holds metadata about a connected module. +type Module struct { + // ID is the compositor-assigned unique identifier for this module. + // IDs are monotonically increasing and never reused within a process lifetime. + ID uint64 + + // Name is the human-readable module name (e.g., "clock", "weather"). + Name string + + // State is the current lifecycle state of the module connection. + State State + + // Width is the frame width in pixels. + Width uint16 + + // Height is the frame height in pixels. + Height uint16 + + // FPS is the requested frame rate. + FPS uint16 + + // ConnectedAt is the time the module first connected. + ConnectedAt time.Time + + // LastFrameAt is the time the most recent frame was received. + LastFrameAt time.Time +} diff --git a/internal/conn/registry.go b/internal/conn/registry.go new file mode 100644 index 0000000..eb645ef --- /dev/null +++ b/internal/conn/registry.go @@ -0,0 +1,159 @@ +package conn + +import ( + "sync" + "time" +) + +// Registry manages module ID allocation and name-to-ID mapping. +// All methods are safe for concurrent access from multiple goroutines. +type Registry struct { + mu sync.RWMutex + modules map[uint64]*Module + nameIndex map[string]uint64 + nextID uint64 + maxModules int +} + +// NewRegistry creates a registry with the given maximum capacity. +// The maxModules parameter must be positive; values less than 1 are +// clamped to 1. +func NewRegistry(maxModules int) *Registry { + if maxModules < 1 { + maxModules = 1 + } + return &Registry{ + modules: make(map[uint64]*Module), + nameIndex: make(map[string]uint64), + nextID: 1, // IDs start at 1, never 0 + maxModules: maxModules, + } +} + +// Register allocates a new module ID and stores the module. +// Returns ErrMaxModules if at capacity. +// Returns ErrNameTaken if a module with the same name is already active. +func (r *Registry) Register(name string, width, height, fps uint16) (uint64, error) { + r.mu.Lock() + defer r.mu.Unlock() + + if len(r.modules) >= r.maxModules { + return 0, ErrMaxModules + } + + if _, exists := r.nameIndex[name]; exists { + return 0, ErrNameTaken + } + + id := r.nextID + r.nextID++ + + mod := &Module{ + ID: id, + Name: name, + State: StateConnecting, + Width: width, + Height: height, + FPS: fps, + ConnectedAt: time.Now(), + } + + r.modules[id] = mod + r.nameIndex[name] = id + + return id, nil +} + +// Unregister removes a module from the registry, freeing its slot. +// The module's name becomes available for reuse. If the module ID does +// not exist, this is a no-op. +func (r *Registry) Unregister(id uint64) { + r.mu.Lock() + defer r.mu.Unlock() + + mod, exists := r.modules[id] + if !exists { + return + } + + delete(r.nameIndex, mod.Name) + delete(r.modules, id) +} + +// Lookup returns the module by ID. Returns (nil, false) if not found. +func (r *Registry) Lookup(id uint64) (*Module, bool) { + r.mu.RLock() + defer r.mu.RUnlock() + + mod, exists := r.modules[id] + if !exists { + return nil, false + } + + // Return a copy to prevent data races on the caller's side. + cp := *mod + return &cp, true +} + +// LookupByName returns the module by name. Returns (nil, false) if not found. +func (r *Registry) LookupByName(name string) (*Module, bool) { + r.mu.RLock() + defer r.mu.RUnlock() + + id, exists := r.nameIndex[name] + if !exists { + return nil, false + } + + mod, exists := r.modules[id] + if !exists { + return nil, false + } + + cp := *mod + return &cp, true +} + +// SetState updates a module's state. If the module ID does not exist, +// this is a no-op. +func (r *Registry) SetState(id uint64, state State) { + r.mu.Lock() + defer r.mu.Unlock() + + if mod, exists := r.modules[id]; exists { + mod.State = state + } +} + +// UpdateLastFrame updates the last frame timestamp for a module. +// If the module ID does not exist, this is a no-op. +func (r *Registry) UpdateLastFrame(id uint64, t time.Time) { + r.mu.Lock() + defer r.mu.Unlock() + + if mod, exists := r.modules[id]; exists { + mod.LastFrameAt = t + } +} + +// Count returns the number of currently registered modules. +func (r *Registry) Count() int { + r.mu.RLock() + defer r.mu.RUnlock() + + return len(r.modules) +} + +// All returns a snapshot of all registered modules. The returned slice +// contains copies; modifying them does not affect the registry. +func (r *Registry) All() []*Module { + r.mu.RLock() + defer r.mu.RUnlock() + + result := make([]*Module, 0, len(r.modules)) + for _, mod := range r.modules { + cp := *mod + result = append(result, &cp) + } + return result +} diff --git a/internal/conn/registry_test.go b/internal/conn/registry_test.go new file mode 100644 index 0000000..a7150c0 --- /dev/null +++ b/internal/conn/registry_test.go @@ -0,0 +1,475 @@ +package conn + +import ( + "errors" + "fmt" + "sync" + "testing" + "time" +) + +func TestNewRegistry(t *testing.T) { + t.Run("positive capacity", func(t *testing.T) { + r := NewRegistry(16) + if r.Count() != 0 { + t.Errorf("new registry count = %d, want 0", r.Count()) + } + }) + + t.Run("zero capacity clamped to 1", func(t *testing.T) { + r := NewRegistry(0) + _, err := r.Register("a", 100, 100, 30) + if err != nil { + t.Fatalf("first register failed: %v", err) + } + _, err = r.Register("b", 100, 100, 30) + if !errors.Is(err, ErrMaxModules) { + t.Errorf("second register err = %v, want ErrMaxModules", err) + } + }) + + t.Run("negative capacity clamped to 1", func(t *testing.T) { + r := NewRegistry(-5) + _, err := r.Register("a", 100, 100, 30) + if err != nil { + t.Fatalf("first register failed: %v", err) + } + _, err = r.Register("b", 100, 100, 30) + if !errors.Is(err, ErrMaxModules) { + t.Errorf("second register err = %v, want ErrMaxModules", err) + } + }) +} + +func TestRegistry_Register(t *testing.T) { + t.Run("basic registration", func(t *testing.T) { + r := NewRegistry(16) + id, err := r.Register("clock", 400, 120, 1) + if err != nil { + t.Fatalf("register failed: %v", err) + } + if id != 1 { + t.Errorf("first ID = %d, want 1", id) + } + if r.Count() != 1 { + t.Errorf("count = %d, want 1", r.Count()) + } + }) + + t.Run("monotonic IDs", func(t *testing.T) { + r := NewRegistry(16) + id1, _ := r.Register("a", 100, 100, 30) + id2, _ := r.Register("b", 100, 100, 30) + id3, _ := r.Register("c", 100, 100, 30) + + if id1 >= id2 || id2 >= id3 { + t.Errorf("IDs not monotonic: %d, %d, %d", id1, id2, id3) + } + }) + + t.Run("IDs never reused after unregister", func(t *testing.T) { + r := NewRegistry(16) + id1, _ := r.Register("a", 100, 100, 30) + r.Unregister(id1) + id2, _ := r.Register("b", 100, 100, 30) + + if id2 <= id1 { + t.Errorf("ID reused: id1=%d, id2=%d", id1, id2) + } + }) + + t.Run("max capacity", func(t *testing.T) { + r := NewRegistry(3) + for i := range 3 { + _, err := r.Register(fmt.Sprintf("mod%d", i), 100, 100, 30) + if err != nil { + t.Fatalf("register %d failed: %v", i, err) + } + } + + _, err := r.Register("overflow", 100, 100, 30) + if !errors.Is(err, ErrMaxModules) { + t.Errorf("overflow err = %v, want ErrMaxModules", err) + } + }) + + t.Run("name collision", func(t *testing.T) { + r := NewRegistry(16) + _, err := r.Register("clock", 400, 120, 1) + if err != nil { + t.Fatalf("first register failed: %v", err) + } + + _, err = r.Register("clock", 400, 120, 1) + if !errors.Is(err, ErrNameTaken) { + t.Errorf("duplicate name err = %v, want ErrNameTaken", err) + } + }) + + t.Run("name reusable after unregister", func(t *testing.T) { + r := NewRegistry(16) + id, _ := r.Register("clock", 400, 120, 1) + r.Unregister(id) + + id2, err := r.Register("clock", 400, 120, 1) + if err != nil { + t.Fatalf("re-register failed: %v", err) + } + if id2 <= id { + t.Errorf("reused ID: old=%d, new=%d", id, id2) + } + }) + + t.Run("initial state is connecting", func(t *testing.T) { + r := NewRegistry(16) + id, _ := r.Register("mod", 100, 100, 30) + mod, ok := r.Lookup(id) + if !ok { + t.Fatal("lookup failed") + } + if mod.State != StateConnecting { + t.Errorf("initial state = %v, want StateConnecting", mod.State) + } + }) + + t.Run("metadata stored correctly", func(t *testing.T) { + r := NewRegistry(16) + before := time.Now() + id, _ := r.Register("weather", 320, 240, 60) + after := time.Now() + + mod, ok := r.Lookup(id) + if !ok { + t.Fatal("lookup failed") + } + if mod.Name != "weather" { + t.Errorf("name = %q, want %q", mod.Name, "weather") + } + if mod.Width != 320 { + t.Errorf("width = %d, want 320", mod.Width) + } + if mod.Height != 240 { + t.Errorf("height = %d, want 240", mod.Height) + } + if mod.FPS != 60 { + t.Errorf("fps = %d, want 60", mod.FPS) + } + if mod.ConnectedAt.Before(before) || mod.ConnectedAt.After(after) { + t.Errorf("ConnectedAt = %v, want between %v and %v", mod.ConnectedAt, before, after) + } + }) +} + +func TestRegistry_Unregister(t *testing.T) { + t.Run("removes module", func(t *testing.T) { + r := NewRegistry(16) + id, _ := r.Register("mod", 100, 100, 30) + r.Unregister(id) + + if r.Count() != 0 { + t.Errorf("count after unregister = %d, want 0", r.Count()) + } + _, ok := r.Lookup(id) + if ok { + t.Error("lookup succeeded after unregister, want not found") + } + }) + + t.Run("nonexistent ID is no-op", func(t *testing.T) { + r := NewRegistry(16) + r.Unregister(999) // should not panic + }) + + t.Run("frees capacity slot", func(t *testing.T) { + r := NewRegistry(2) + id1, _ := r.Register("a", 100, 100, 30) + _, _ = r.Register("b", 100, 100, 30) + + // At capacity. + _, err := r.Register("c", 100, 100, 30) + if !errors.Is(err, ErrMaxModules) { + t.Fatalf("expected ErrMaxModules, got %v", err) + } + + // Unregister one — should free a slot. + r.Unregister(id1) + _, err = r.Register("c", 100, 100, 30) + if err != nil { + t.Errorf("register after unregister failed: %v", err) + } + }) +} + +func TestRegistry_Lookup(t *testing.T) { + t.Run("existing module", func(t *testing.T) { + r := NewRegistry(16) + id, _ := r.Register("mod", 100, 200, 30) + mod, ok := r.Lookup(id) + if !ok { + t.Fatal("lookup failed") + } + if mod.ID != id || mod.Name != "mod" { + t.Errorf("lookup returned wrong module: %+v", mod) + } + }) + + t.Run("nonexistent module", func(t *testing.T) { + r := NewRegistry(16) + _, ok := r.Lookup(42) + if ok { + t.Error("lookup succeeded for nonexistent ID") + } + }) + + t.Run("returns copy", func(t *testing.T) { + r := NewRegistry(16) + id, _ := r.Register("mod", 100, 100, 30) + mod1, _ := r.Lookup(id) + mod1.Name = "mutated" + + mod2, _ := r.Lookup(id) + if mod2.Name == "mutated" { + t.Error("Lookup returned a reference instead of a copy") + } + }) +} + +func TestRegistry_LookupByName(t *testing.T) { + t.Run("existing module", func(t *testing.T) { + r := NewRegistry(16) + id, _ := r.Register("clock", 400, 120, 1) + mod, ok := r.LookupByName("clock") + if !ok { + t.Fatal("lookup by name failed") + } + if mod.ID != id { + t.Errorf("ID = %d, want %d", mod.ID, id) + } + }) + + t.Run("nonexistent name", func(t *testing.T) { + r := NewRegistry(16) + _, ok := r.LookupByName("nonexistent") + if ok { + t.Error("lookup by name succeeded for nonexistent name") + } + }) +} + +func TestRegistry_SetState(t *testing.T) { + t.Run("transitions state", func(t *testing.T) { + r := NewRegistry(16) + id, _ := r.Register("mod", 100, 100, 30) + + r.SetState(id, StateHandshaking) + mod, _ := r.Lookup(id) + if mod.State != StateHandshaking { + t.Errorf("state = %v, want StateHandshaking", mod.State) + } + + r.SetState(id, StateActive) + mod, _ = r.Lookup(id) + if mod.State != StateActive { + t.Errorf("state = %v, want StateActive", mod.State) + } + + r.SetState(id, StateDisconnected) + mod, _ = r.Lookup(id) + if mod.State != StateDisconnected { + t.Errorf("state = %v, want StateDisconnected", mod.State) + } + }) + + t.Run("nonexistent ID is no-op", func(t *testing.T) { + r := NewRegistry(16) + r.SetState(999, StateActive) // should not panic + }) +} + +func TestRegistry_UpdateLastFrame(t *testing.T) { + t.Run("updates timestamp", func(t *testing.T) { + r := NewRegistry(16) + id, _ := r.Register("mod", 100, 100, 30) + + ts := time.Now().Add(5 * time.Second) + r.UpdateLastFrame(id, ts) + + mod, _ := r.Lookup(id) + if !mod.LastFrameAt.Equal(ts) { + t.Errorf("LastFrameAt = %v, want %v", mod.LastFrameAt, ts) + } + }) + + t.Run("nonexistent ID is no-op", func(t *testing.T) { + r := NewRegistry(16) + r.UpdateLastFrame(999, time.Now()) // should not panic + }) +} + +func TestRegistry_Count(t *testing.T) { + r := NewRegistry(16) + if r.Count() != 0 { + t.Fatalf("initial count = %d, want 0", r.Count()) + } + + r.Register("a", 100, 100, 30) + r.Register("b", 100, 100, 30) + if r.Count() != 2 { + t.Errorf("count = %d, want 2", r.Count()) + } + + id, _ := r.Register("c", 100, 100, 30) + r.Unregister(id) + if r.Count() != 2 { + t.Errorf("count after unregister = %d, want 2", r.Count()) + } +} + +func TestRegistry_All(t *testing.T) { + t.Run("empty registry", func(t *testing.T) { + r := NewRegistry(16) + all := r.All() + if len(all) != 0 { + t.Errorf("All() len = %d, want 0", len(all)) + } + }) + + t.Run("returns all modules", func(t *testing.T) { + r := NewRegistry(16) + r.Register("a", 100, 100, 30) + r.Register("b", 200, 200, 60) + r.Register("c", 300, 300, 1) + + all := r.All() + if len(all) != 3 { + t.Fatalf("All() len = %d, want 3", len(all)) + } + + names := make(map[string]bool) + for _, m := range all { + names[m.Name] = true + } + for _, want := range []string{"a", "b", "c"} { + if !names[want] { + t.Errorf("missing module %q in All() result", want) + } + } + }) + + t.Run("returns copies", func(t *testing.T) { + r := NewRegistry(16) + r.Register("mod", 100, 100, 30) + + all := r.All() + all[0].Name = "mutated" + + mod, _ := r.LookupByName("mod") + if mod == nil { + t.Error("original module name was mutated via All() result") + } + }) +} + +func TestRegistry_ConcurrentAccess(t *testing.T) { + r := NewRegistry(1000) + var wg sync.WaitGroup + + // Concurrent writers (register). + for i := range 100 { + wg.Add(1) + go func(idx int) { + defer wg.Done() + name := fmt.Sprintf("module-%d", idx) + _, _ = r.Register(name, 100, 100, 30) + }(i) + } + + // Concurrent readers (lookup, count, all). + for range 50 { + wg.Add(1) + go func() { + defer wg.Done() + _ = r.Count() + _ = r.All() + _, _ = r.Lookup(1) + _, _ = r.LookupByName("module-0") + }() + } + + // Concurrent state updates. + for i := range 50 { + wg.Add(1) + go func(idx int) { + defer wg.Done() + r.SetState(uint64(idx+1), StateActive) + r.UpdateLastFrame(uint64(idx+1), time.Now()) + }(i) + } + + wg.Wait() + + // Verify consistency: count matches actual registered modules. + count := r.Count() + all := r.All() + if count != len(all) { + t.Errorf("count=%d != len(All())=%d", count, len(all)) + } +} + +func TestState_String(t *testing.T) { + tests := []struct { + state State + want string + }{ + {StateConnecting, "connecting"}, + {StateHandshaking, "handshaking"}, + {StateActive, "active"}, + {StateDisconnected, "disconnected"}, + {State(99), "unknown"}, + } + + for _, tt := range tests { + t.Run(tt.want, func(t *testing.T) { + got := tt.state.String() + if got != tt.want { + t.Errorf("State(%d).String() = %q, want %q", tt.state, got, tt.want) + } + }) + } +} + +func BenchmarkRegister(b *testing.B) { + r := NewRegistry(b.N + 1) + b.ResetTimer() + for i := range b.N { + name := fmt.Sprintf("module-%d", i) + _, _ = r.Register(name, 100, 100, 30) + } +} + +func BenchmarkLookup(b *testing.B) { + r := NewRegistry(1000) + for i := range 1000 { + r.Register(fmt.Sprintf("module-%d", i), 100, 100, 30) + } + + b.ResetTimer() + for i := range b.N { + id := uint64(i%1000) + 1 + _, _ = r.Lookup(id) + } +} + +func BenchmarkLookupByName(b *testing.B) { + r := NewRegistry(1000) + names := make([]string, 1000) + for i := range 1000 { + names[i] = fmt.Sprintf("module-%d", i) + r.Register(names[i], 100, 100, 30) + } + + b.ResetTimer() + for i := range b.N { + _, _ = r.LookupByName(names[i%1000]) + } +} diff --git a/internal/flow/controller.go b/internal/flow/controller.go new file mode 100644 index 0000000..c78a2c5 --- /dev/null +++ b/internal/flow/controller.go @@ -0,0 +1,198 @@ +package flow + +import ( + "sync" + "time" +) + +// defaultFPS is used when a module specifies 0 FPS (static content). +const defaultFPS = 1 + +// missedThreshold is the number of consecutive missed frames before the +// controller halves the effective request rate for a module. +const missedThreshold = 3 + +// moduleState holds per-module frame pacing state. +type moduleState struct { + targetFPS uint16 // module's preferred FPS (from handshake) + interval time.Duration // base interval: 1/targetFPS + lastRequest time.Time // when we last sent FrameRequest + lastDelivery time.Time // when we last received a frame + pending bool // true if FrameRequest sent but no frame received yet + missedCount int // consecutive missed requests (for adaptive rate) +} + +// effectiveInterval returns the current interval accounting for adaptive rate +// reduction. After missedThreshold consecutive misses, the interval doubles. +func (ms *moduleState) effectiveInterval() time.Duration { + if ms.missedCount >= missedThreshold { + return ms.interval * 2 + } + return ms.interval +} + +// Option configures a Controller. +type Option func(*Controller) + +// WithClock sets a custom time source for the controller. Intended for testing +// to provide deterministic time progression without time.Sleep. +func WithClock(fn func() time.Time) Option { + return func(c *Controller) { + c.now = fn + } +} + +// Controller manages pull-based frame pacing for connected modules. +// The compositor uses it to decide when to request frames from each module. +// All methods are safe for concurrent access from multiple goroutines. +type Controller struct { + mu sync.Mutex + modules map[uint64]*moduleState + now func() time.Time +} + +// New creates a flow controller. Use WithClock to inject a custom time source +// for testing. +func New(opts ...Option) *Controller { + c := &Controller{ + modules: make(map[uint64]*moduleState), + now: time.Now, + } + for _, opt := range opts { + opt(c) + } + return c +} + +// AddModule registers a module with its preferred FPS. +// If fps is 0, it defaults to 1 FPS (suitable for static content like a clock). +func (c *Controller) AddModule(moduleID uint64, fps uint16) { + if fps == 0 { + fps = defaultFPS + } + + interval := time.Second / time.Duration(fps) + + c.mu.Lock() + defer c.mu.Unlock() + + c.modules[moduleID] = &moduleState{ + targetFPS: fps, + interval: interval, + } +} + +// RemoveModule unregisters a module. If the module ID does not exist, this is +// a no-op. +func (c *Controller) RemoveModule(moduleID uint64) { + c.mu.Lock() + defer c.mu.Unlock() + + delete(c.modules, moduleID) +} + +// ShouldRequest returns true if it is time to request a frame from the +// specified module. The compositor should call this on each render tick. +// +// Returns false if: +// - The module ID is not registered +// - A request is already pending (module has not responded yet) +// - Not enough time has elapsed since the last request (respecting target FPS +// and adaptive rate reduction) +func (c *Controller) ShouldRequest(moduleID uint64) bool { + c.mu.Lock() + defer c.mu.Unlock() + + ms, ok := c.modules[moduleID] + if !ok { + return false + } + + if ms.pending { + return false + } + + now := c.now() + + // First request: always allow if no request has been made yet. + if ms.lastRequest.IsZero() { + return true + } + + elapsed := now.Sub(ms.lastRequest) + return elapsed >= ms.effectiveInterval() +} + +// FrameRequested marks that a FrameRequest was sent to the specified module. +// Records the request time and sets the pending flag. If the module ID does not +// exist, this is a no-op. +func (c *Controller) FrameRequested(moduleID uint64) { + c.mu.Lock() + defer c.mu.Unlock() + + ms, ok := c.modules[moduleID] + if !ok { + return + } + + ms.lastRequest = c.now() + ms.pending = true +} + +// FrameDelivered marks that a frame was received from the specified module. +// Resets the pending flag, records the delivery time, and clears the missed +// count. If the module ID does not exist, this is a no-op. +func (c *Controller) FrameDelivered(moduleID uint64) { + c.mu.Lock() + defer c.mu.Unlock() + + ms, ok := c.modules[moduleID] + if !ok { + return + } + + ms.lastDelivery = c.now() + ms.pending = false + ms.missedCount = 0 +} + +// FrameMissed marks that the specified module failed to deliver a frame in +// time. After missedThreshold (3) consecutive misses, the controller halves +// the effective request rate by doubling the interval. This prevents wasting +// bandwidth on slow or stuck modules. If the module ID does not exist, this +// is a no-op. +func (c *Controller) FrameMissed(moduleID uint64) { + c.mu.Lock() + defer c.mu.Unlock() + + ms, ok := c.modules[moduleID] + if !ok { + return + } + + ms.missedCount++ + ms.pending = false +} + +// PendingModules returns the count of modules with pending (unanswered) +// frame requests. +func (c *Controller) PendingModules() int { + c.mu.Lock() + defer c.mu.Unlock() + + count := 0 + for _, ms := range c.modules { + if ms.pending { + count++ + } + } + return count +} + +// ModuleCount returns the total number of registered modules. +func (c *Controller) ModuleCount() int { + c.mu.Lock() + defer c.mu.Unlock() + + return len(c.modules) +} diff --git a/internal/flow/controller_test.go b/internal/flow/controller_test.go new file mode 100644 index 0000000..fecc16f --- /dev/null +++ b/internal/flow/controller_test.go @@ -0,0 +1,739 @@ +package flow + +import ( + "sync" + "testing" + "time" +) + +// testClock returns a clock function and an advance function. +// The clock starts at a fixed epoch. Advance moves the clock forward +// by the given duration. All time progression is deterministic. +func testClock() (now func() time.Time, advance func(d time.Duration)) { + t := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC) + return func() time.Time { + return t + }, func(d time.Duration) { + t = t.Add(d) + } +} + +func TestNewController(t *testing.T) { + c := New() + if c == nil { + t.Fatal("New() returned nil") + } + if c.ModuleCount() != 0 { + t.Errorf("ModuleCount() = %d, want 0", c.ModuleCount()) + } + if c.PendingModules() != 0 { + t.Errorf("PendingModules() = %d, want 0", c.PendingModules()) + } +} + +func TestNewControllerWithClock(t *testing.T) { + clock, _ := testClock() + c := New(WithClock(clock)) + if c == nil { + t.Fatal("New(WithClock) returned nil") + } + if c.now == nil { + t.Fatal("clock function not set") + } +} + +func TestAddModule(t *testing.T) { + c := New() + + c.AddModule(1, 60) + if c.ModuleCount() != 1 { + t.Errorf("ModuleCount() = %d, want 1", c.ModuleCount()) + } + + c.AddModule(2, 30) + if c.ModuleCount() != 2 { + t.Errorf("ModuleCount() = %d, want 2", c.ModuleCount()) + } +} + +func TestAddModuleZeroFPS(t *testing.T) { + clock, _ := testClock() + c := New(WithClock(clock)) + + c.AddModule(1, 0) + + c.mu.Lock() + ms := c.modules[1] + c.mu.Unlock() + + if ms.targetFPS != 1 { + t.Errorf("targetFPS = %d, want 1 (default for 0 FPS)", ms.targetFPS) + } + if ms.interval != time.Second { + t.Errorf("interval = %v, want %v", ms.interval, time.Second) + } +} + +func TestAddModuleOverwrite(t *testing.T) { + c := New() + + c.AddModule(1, 60) + c.AddModule(1, 30) + + if c.ModuleCount() != 1 { + t.Errorf("ModuleCount() = %d, want 1 (overwrite, not duplicate)", c.ModuleCount()) + } + + c.mu.Lock() + ms := c.modules[1] + c.mu.Unlock() + + if ms.targetFPS != 30 { + t.Errorf("targetFPS = %d, want 30 (overwritten)", ms.targetFPS) + } +} + +func TestRemoveModule(t *testing.T) { + c := New() + + c.AddModule(1, 60) + c.AddModule(2, 30) + + c.RemoveModule(1) + if c.ModuleCount() != 1 { + t.Errorf("ModuleCount() = %d, want 1 after remove", c.ModuleCount()) + } + + c.RemoveModule(2) + if c.ModuleCount() != 0 { + t.Errorf("ModuleCount() = %d, want 0 after remove all", c.ModuleCount()) + } +} + +func TestRemoveUnknownModule(t *testing.T) { + c := New() + + // Should not panic. + c.RemoveModule(999) + + if c.ModuleCount() != 0 { + t.Errorf("ModuleCount() = %d, want 0", c.ModuleCount()) + } +} + +func TestShouldRequestFirstTime(t *testing.T) { + clock, _ := testClock() + c := New(WithClock(clock)) + + c.AddModule(1, 60) + + if !c.ShouldRequest(1) { + t.Error("ShouldRequest() = false for first request, want true") + } +} + +func TestShouldRequestUnknownModule(t *testing.T) { + c := New() + + if c.ShouldRequest(999) { + t.Error("ShouldRequest(999) = true for unknown module, want false") + } +} + +func TestShouldRequestRespectsInterval(t *testing.T) { + clock, advance := testClock() + c := New(WithClock(clock)) + + c.AddModule(1, 10) // 10 FPS = 100ms interval + + // First request: allowed. + if !c.ShouldRequest(1) { + t.Fatal("ShouldRequest() = false for first request") + } + + c.FrameRequested(1) + c.FrameDelivered(1) + + // Immediately after delivery: not enough time has passed. + if c.ShouldRequest(1) { + t.Error("ShouldRequest() = true immediately after request, want false") + } + + // Advance 50ms (half the interval): still too early. + advance(50 * time.Millisecond) + if c.ShouldRequest(1) { + t.Error("ShouldRequest() = true at 50ms (half interval), want false") + } + + // Advance another 50ms (total 100ms = interval): should be allowed. + advance(50 * time.Millisecond) + if !c.ShouldRequest(1) { + t.Error("ShouldRequest() = false at 100ms (full interval), want true") + } +} + +func TestShouldRequestBlocksWhilePending(t *testing.T) { + clock, advance := testClock() + c := New(WithClock(clock)) + + c.AddModule(1, 1) // 1 FPS = 1s interval + + // First request. + if !c.ShouldRequest(1) { + t.Fatal("ShouldRequest() = false for first request") + } + c.FrameRequested(1) + + // Even after the interval, pending blocks the request. + advance(2 * time.Second) + if c.ShouldRequest(1) { + t.Error("ShouldRequest() = true while pending, want false") + } + + // After delivery, pending is cleared. + c.FrameDelivered(1) + + // Now enough time has passed. + if !c.ShouldRequest(1) { + t.Error("ShouldRequest() = false after delivery + elapsed interval, want true") + } +} + +func TestFrameRequestedSetsTime(t *testing.T) { + clock, advance := testClock() + c := New(WithClock(clock)) + + c.AddModule(1, 10) // 100ms interval + + advance(500 * time.Millisecond) + c.FrameRequested(1) + + c.mu.Lock() + ms := c.modules[1] + lr := ms.lastRequest + c.mu.Unlock() + + expected := time.Date(2026, 1, 1, 0, 0, 0, 500_000_000, time.UTC) + if !lr.Equal(expected) { + t.Errorf("lastRequest = %v, want %v", lr, expected) + } +} + +func TestFrameRequestedUnknownModule(t *testing.T) { + c := New() + + // Should not panic. + c.FrameRequested(999) +} + +func TestFrameDeliveredClearsPending(t *testing.T) { + clock, _ := testClock() + c := New(WithClock(clock)) + + c.AddModule(1, 60) + + c.FrameRequested(1) + if c.PendingModules() != 1 { + t.Errorf("PendingModules() = %d after request, want 1", c.PendingModules()) + } + + c.FrameDelivered(1) + if c.PendingModules() != 0 { + t.Errorf("PendingModules() = %d after delivery, want 0", c.PendingModules()) + } +} + +func TestFrameDeliveredClearsMissedCount(t *testing.T) { + clock, _ := testClock() + c := New(WithClock(clock)) + + c.AddModule(1, 60) + + // Accumulate misses. + c.FrameMissed(1) + c.FrameMissed(1) + + c.mu.Lock() + countBefore := c.modules[1].missedCount + c.mu.Unlock() + + if countBefore != 2 { + t.Errorf("missedCount = %d before delivery, want 2", countBefore) + } + + // Delivery resets missed count. + c.FrameDelivered(1) + + c.mu.Lock() + countAfter := c.modules[1].missedCount + c.mu.Unlock() + + if countAfter != 0 { + t.Errorf("missedCount = %d after delivery, want 0", countAfter) + } +} + +func TestFrameDeliveredUnknownModule(t *testing.T) { + c := New() + + // Should not panic. + c.FrameDelivered(999) +} + +func TestFrameMissedIncrements(t *testing.T) { + c := New() + + c.AddModule(1, 60) + + for i := 1; i <= 5; i++ { + c.FrameMissed(1) + + c.mu.Lock() + count := c.modules[1].missedCount + c.mu.Unlock() + + if count != i { + t.Errorf("missedCount after %d misses = %d, want %d", i, count, i) + } + } +} + +func TestFrameMissedClearsPending(t *testing.T) { + clock, _ := testClock() + c := New(WithClock(clock)) + + c.AddModule(1, 60) + + c.FrameRequested(1) + if c.PendingModules() != 1 { + t.Fatal("PendingModules() != 1 after request") + } + + c.FrameMissed(1) + if c.PendingModules() != 0 { + t.Errorf("PendingModules() = %d after miss, want 0", c.PendingModules()) + } +} + +func TestFrameMissedUnknownModule(t *testing.T) { + c := New() + + // Should not panic. + c.FrameMissed(999) +} + +func TestAdaptiveRateReduction(t *testing.T) { + clock, advance := testClock() + c := New(WithClock(clock)) + + c.AddModule(1, 10) // 10 FPS = 100ms base interval + + // Request and deliver first frame to establish lastRequest. + c.FrameRequested(1) + c.FrameDelivered(1) + + // Accumulate 3 misses (threshold). + c.FrameRequested(1) + c.FrameMissed(1) + c.FrameRequested(1) + c.FrameMissed(1) + c.FrameRequested(1) + c.FrameMissed(1) + + // After 3 misses, effective interval should be 200ms (doubled). + // Advance 100ms from last request: should NOT be ready. + advance(100 * time.Millisecond) + if c.ShouldRequest(1) { + t.Error("ShouldRequest() = true at 100ms (base interval) after 3 misses, want false (doubled to 200ms)") + } + + // Advance another 100ms (total 200ms): now should be ready. + advance(100 * time.Millisecond) + if !c.ShouldRequest(1) { + t.Error("ShouldRequest() = false at 200ms (doubled interval) after 3 misses, want true") + } +} + +func TestAdaptiveRateResetOnDelivery(t *testing.T) { + clock, advance := testClock() + c := New(WithClock(clock)) + + c.AddModule(1, 10) // 10 FPS = 100ms base interval + + c.FrameRequested(1) + c.FrameDelivered(1) + + // Accumulate 3 misses. + c.FrameRequested(1) + c.FrameMissed(1) + c.FrameRequested(1) + c.FrameMissed(1) + c.FrameRequested(1) + c.FrameMissed(1) + + // Now deliver a successful frame. This resets missedCount. + c.FrameRequested(1) + c.FrameDelivered(1) + + // After delivery, effective interval should be back to 100ms. + advance(100 * time.Millisecond) + if !c.ShouldRequest(1) { + t.Error("ShouldRequest() = false at 100ms after miss reset, want true (back to base interval)") + } +} + +func TestAdaptiveRateBelowThreshold(t *testing.T) { + clock, advance := testClock() + c := New(WithClock(clock)) + + c.AddModule(1, 10) // 10 FPS = 100ms interval + + c.FrameRequested(1) + c.FrameDelivered(1) + + // Only 2 misses (below threshold of 3). + c.FrameRequested(1) + c.FrameMissed(1) + c.FrameRequested(1) + c.FrameMissed(1) + + // Effective interval should still be 100ms (not doubled). + advance(100 * time.Millisecond) + if !c.ShouldRequest(1) { + t.Error("ShouldRequest() = false at 100ms with only 2 misses, want true (below threshold)") + } +} + +func TestPendingModules(t *testing.T) { + clock, _ := testClock() + c := New(WithClock(clock)) + + c.AddModule(1, 60) + c.AddModule(2, 30) + c.AddModule(3, 10) + + if c.PendingModules() != 0 { + t.Errorf("PendingModules() = %d initially, want 0", c.PendingModules()) + } + + c.FrameRequested(1) + c.FrameRequested(2) + if c.PendingModules() != 2 { + t.Errorf("PendingModules() = %d after 2 requests, want 2", c.PendingModules()) + } + + c.FrameDelivered(1) + if c.PendingModules() != 1 { + t.Errorf("PendingModules() = %d after 1 delivery, want 1", c.PendingModules()) + } + + c.FrameMissed(2) + if c.PendingModules() != 0 { + t.Errorf("PendingModules() = %d after miss, want 0", c.PendingModules()) + } +} + +func TestModuleCount(t *testing.T) { + c := New() + + tests := []struct { + action string + fn func() + want int + }{ + {"initial", func() {}, 0}, + {"add 1", func() { c.AddModule(1, 60) }, 1}, + {"add 2", func() { c.AddModule(2, 30) }, 2}, + {"add 3", func() { c.AddModule(3, 10) }, 3}, + {"remove 2", func() { c.RemoveModule(2) }, 2}, + {"remove unknown", func() { c.RemoveModule(999) }, 2}, + {"remove 1", func() { c.RemoveModule(1) }, 1}, + {"remove 3", func() { c.RemoveModule(3) }, 0}, + } + + for _, tt := range tests { + t.Run(tt.action, func(t *testing.T) { + tt.fn() + if got := c.ModuleCount(); got != tt.want { + t.Errorf("ModuleCount() = %d, want %d", got, tt.want) + } + }) + } +} + +func TestHighFPSInterval(t *testing.T) { + clock, advance := testClock() + c := New(WithClock(clock)) + + c.AddModule(1, 120) // 120 FPS ~ 8.33ms interval + + c.FrameRequested(1) + c.FrameDelivered(1) + + // At 8ms: should not be ready yet. + advance(8 * time.Millisecond) + if c.ShouldRequest(1) { + t.Error("ShouldRequest() = true at 8ms for 120 FPS, want false") + } + + // At 9ms: should be ready (8.33ms interval). + advance(1 * time.Millisecond) + if !c.ShouldRequest(1) { + t.Error("ShouldRequest() = false at 9ms for 120 FPS, want true") + } +} + +func TestMultipleModulesIndependent(t *testing.T) { + clock, advance := testClock() + c := New(WithClock(clock)) + + c.AddModule(1, 10) // 100ms interval + c.AddModule(2, 5) // 200ms interval + + // Both should allow first request. + if !c.ShouldRequest(1) { + t.Error("module 1: ShouldRequest() = false for first request") + } + if !c.ShouldRequest(2) { + t.Error("module 2: ShouldRequest() = false for first request") + } + + c.FrameRequested(1) + c.FrameRequested(2) + c.FrameDelivered(1) + c.FrameDelivered(2) + + // At 100ms: module 1 ready, module 2 not. + advance(100 * time.Millisecond) + if !c.ShouldRequest(1) { + t.Error("module 1: ShouldRequest() = false at 100ms, want true") + } + if c.ShouldRequest(2) { + t.Error("module 2: ShouldRequest() = true at 100ms (needs 200ms), want false") + } + + // At 200ms: both ready. + advance(100 * time.Millisecond) + if !c.ShouldRequest(1) { + t.Error("module 1: ShouldRequest() = false at 200ms, want true") + } + if !c.ShouldRequest(2) { + t.Error("module 2: ShouldRequest() = false at 200ms, want true") + } +} + +func TestEffectiveInterval(t *testing.T) { + tests := []struct { + name string + missedCount int + baseMs int + wantMs int + }{ + {"no misses", 0, 100, 100}, + {"1 miss", 1, 100, 100}, + {"2 misses", 2, 100, 100}, + {"3 misses (threshold)", 3, 100, 200}, + {"4 misses", 4, 100, 200}, + {"10 misses", 10, 100, 200}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ms := &moduleState{ + interval: time.Duration(tt.baseMs) * time.Millisecond, + missedCount: tt.missedCount, + } + got := ms.effectiveInterval() + want := time.Duration(tt.wantMs) * time.Millisecond + if got != want { + t.Errorf("effectiveInterval() = %v, want %v", got, want) + } + }) + } +} + +func TestFullLifecycle(t *testing.T) { + clock, advance := testClock() + c := New(WithClock(clock)) + + // Register. + c.AddModule(1, 10) // 100ms interval + if c.ModuleCount() != 1 { + t.Fatal("ModuleCount() != 1 after add") + } + + // First request cycle. + if !c.ShouldRequest(1) { + t.Fatal("ShouldRequest() = false for first request") + } + c.FrameRequested(1) + if c.PendingModules() != 1 { + t.Fatal("PendingModules() != 1 after request") + } + c.FrameDelivered(1) + if c.PendingModules() != 0 { + t.Fatal("PendingModules() != 0 after delivery") + } + + // Second request after interval. + advance(100 * time.Millisecond) + if !c.ShouldRequest(1) { + t.Fatal("ShouldRequest() = false after interval") + } + c.FrameRequested(1) + c.FrameDelivered(1) + + // Miss some frames. + advance(100 * time.Millisecond) + c.FrameRequested(1) + c.FrameMissed(1) + advance(100 * time.Millisecond) + c.FrameRequested(1) + c.FrameMissed(1) + advance(100 * time.Millisecond) + c.FrameRequested(1) + c.FrameMissed(1) + + // Now at doubled interval. 100ms should not suffice. + advance(100 * time.Millisecond) + if c.ShouldRequest(1) { + t.Error("ShouldRequest() = true at base interval after 3 misses, want false") + } + + // 200ms should work. + advance(100 * time.Millisecond) + if !c.ShouldRequest(1) { + t.Error("ShouldRequest() = false at doubled interval, want true") + } + + // Deliver to reset. + c.FrameRequested(1) + c.FrameDelivered(1) + + // Back to normal interval. + advance(100 * time.Millisecond) + if !c.ShouldRequest(1) { + t.Error("ShouldRequest() = false after reset, want true") + } + + // Unregister. + c.RemoveModule(1) + if c.ModuleCount() != 0 { + t.Fatal("ModuleCount() != 0 after remove") + } + if c.ShouldRequest(1) { + t.Error("ShouldRequest() = true for removed module, want false") + } +} + +func TestConcurrentAccess(t *testing.T) { + c := New() + + const numModules = 100 + const iterations = 1000 + + // Add modules. + for i := uint64(0); i < numModules; i++ { + c.AddModule(i, uint16(10+i%50)) + } + + var wg sync.WaitGroup + + // Concurrent ShouldRequest calls. + wg.Add(1) + go func() { + defer wg.Done() + for range iterations { + for i := uint64(0); i < numModules; i++ { + c.ShouldRequest(i) + } + } + }() + + // Concurrent FrameRequested/Delivered calls. + wg.Add(1) + go func() { + defer wg.Done() + for range iterations { + for i := uint64(0); i < numModules; i++ { + c.FrameRequested(i) + c.FrameDelivered(i) + } + } + }() + + // Concurrent FrameMissed calls. + wg.Add(1) + go func() { + defer wg.Done() + for range iterations { + for i := uint64(0); i < numModules; i++ { + c.FrameMissed(i) + } + } + }() + + // Concurrent PendingModules/ModuleCount calls. + wg.Add(1) + go func() { + defer wg.Done() + for range iterations { + c.PendingModules() + c.ModuleCount() + } + }() + + // Concurrent Add/Remove. + wg.Add(1) + go func() { + defer wg.Done() + for range iterations { + c.AddModule(numModules+1, 60) + c.RemoveModule(numModules + 1) + } + }() + + wg.Wait() +} + +func BenchmarkShouldRequest(b *testing.B) { + clock, _ := testClock() + c := New(WithClock(clock)) + + c.AddModule(1, 60) + + b.ResetTimer() + for range b.N { + c.ShouldRequest(1) + } +} + +func BenchmarkShouldRequestMultipleModules(b *testing.B) { + clock, _ := testClock() + c := New(WithClock(clock)) + + const numModules = 16 + for i := uint64(1); i <= numModules; i++ { + c.AddModule(i, uint16(10*i)) + } + + b.ResetTimer() + for range b.N { + for i := uint64(1); i <= numModules; i++ { + c.ShouldRequest(i) + } + } +} + +func BenchmarkFrameRequestDeliverCycle(b *testing.B) { + clock, _ := testClock() + c := New(WithClock(clock)) + + c.AddModule(1, 60) + + b.ResetTimer() + for range b.N { + c.FrameRequested(1) + c.FrameDelivered(1) + } +} diff --git a/internal/flow/doc.go b/internal/flow/doc.go new file mode 100644 index 0000000..1f1687f --- /dev/null +++ b/internal/flow/doc.go @@ -0,0 +1,17 @@ +// Package flow provides pull-based frame pacing for the compose library. +// +// The compositor uses a [Controller] to decide when to request frames from each +// connected module. This implements the Wayland frame callback pattern: modules +// render only when the compositor asks, preventing them from flooding the +// compositor with unsolicited frames. +// +// The controller is passive (no goroutines, no timers). The compositor's render +// loop polls [Controller.ShouldRequest] on each tick and calls +// [Controller.FrameRequested] after sending a request to a module. When a frame +// arrives, the compositor calls [Controller.FrameDelivered]. If a module fails +// to respond in time, [Controller.FrameMissed] adaptively reduces the effective +// request rate. +// +// All exported methods are safe for concurrent use from multiple goroutines. +// The package is standalone with no internal dependencies. +package flow diff --git a/internal/protocol/doc.go b/internal/protocol/doc.go new file mode 100644 index 0000000..f645ec2 --- /dev/null +++ b/internal/protocol/doc.go @@ -0,0 +1,15 @@ +// Package protocol defines the wire format for gogpu/compose inter-process +// communication. It is the leaf package in the internal dependency graph — +// imported by transport implementations but importing nothing internal itself. +// +// The protocol is built around a fixed 64-byte binary frame header that +// precedes every message on the wire. Headers use little-endian byte order +// and are designed to be cache-line aligned on modern hardware. +// +// All encode/decode functions operate on caller-provided byte buffers to +// achieve zero allocations on the hot path. This is critical for sustained +// 60 FPS frame delivery. +// +// Wire format version: 1 +// Magic bytes: 0x43 0x4F 0x4D 0x50 ("COMP") +package protocol diff --git a/internal/protocol/handshake.go b/internal/protocol/handshake.go new file mode 100644 index 0000000..ce2af8d --- /dev/null +++ b/internal/protocol/handshake.go @@ -0,0 +1,326 @@ +package protocol + +import ( + "encoding/binary" + "errors" + "fmt" +) + +// HandshakeSize is the fixed size in bytes of both HelloMsg and WelcomeMsg. +const HandshakeSize = 128 + +// TransportType identifies the preferred or granted transport mechanism. +type TransportType uint8 + +const ( + // TransportSocket indicates Unix domain socket transport. + TransportSocket TransportType = 0 + + // TransportShm indicates shared memory transport. + TransportShm TransportType = 1 +) + +// String returns the human-readable name of the transport type. +func (t TransportType) String() string { + switch t { + case TransportSocket: + return "Socket" + case TransportShm: + return "SharedMemory" + default: + return fmt.Sprintf("TransportType(%d)", uint8(t)) + } +} + +// HelloMsg is sent by the module to the compositor during the handshake phase. +// Fixed 128 bytes. +// +// Layout: Magic(4) + Version(2) + Name(64) + Width(2) + Height(2) + +// PreferredFPS(2) + Transport(1) + Reserved(51) = 128 +type HelloMsg struct { + // Magic must be "COMP" (0x43, 0x4F, 0x4D, 0x50). + Magic [4]byte + + // Version is the protocol version the module supports. + Version uint16 + + // Name is the null-terminated human-readable module name (max 63 chars + NUL). + Name [64]byte + + // Width is the initial frame width in pixels. + Width uint16 + + // Height is the initial frame height in pixels. + Height uint16 + + // PreferredFPS is the module's preferred frame rate (e.g., 1 for clock, 60 for animation). + PreferredFPS uint16 + + // Transport is the module's preferred transport mechanism. + Transport TransportType + + // Reserved is padding for future use. Must be zero. + // Size: 128 - 4 - 2 - 64 - 2 - 2 - 2 - 1 = 51 + Reserved [51]byte +} + +// HelloMsg field offsets. +const ( + helloOffMagic = 0 + helloOffVersion = 4 + helloOffName = 6 + helloOffWidth = 70 + helloOffHeight = 72 + helloOffPreferredFPS = 74 + helloOffTransport = 76 + helloOffReserved = 77 +) + +// WelcomeMsg is sent by the compositor to the module after accepting the handshake. +// Fixed 128 bytes. +// +// Layout: Magic(4) + Version(2) + ModuleID(8) + Accepted(1) + Transport(1) + +// MinVersion(2) + MaxVersion(2) + Reserved(108) = 128 +type WelcomeMsg struct { + // Magic must be "COMP" (0x43, 0x4F, 0x4D, 0x50). + Magic [4]byte + + // Version is the protocol version the compositor selected. + Version uint16 + + // ModuleID is the compositor-assigned unique identifier for this module. + ModuleID uint64 + + // Accepted is 1 if the connection was accepted, 0 if rejected. + Accepted uint8 + + // Transport is the transport mechanism the compositor granted. + Transport TransportType + + // MinVersion is the minimum protocol version the compositor supports. + MinVersion uint16 + + // MaxVersion is the maximum protocol version the compositor supports. + MaxVersion uint16 + + // Reserved is padding for future use. Must be zero. + // Size: 128 - 4 - 2 - 8 - 1 - 1 - 2 - 2 = 108 + Reserved [108]byte +} + +// WelcomeMsg field offsets. +const ( + welcomeOffMagic = 0 + welcomeOffVersion = 4 + welcomeOffModuleID = 6 + welcomeOffAccepted = 14 + welcomeOffTransport = 15 + welcomeOffMinVersion = 16 + welcomeOffMaxVersion = 18 + welcomeOffReserved = 20 +) + +// Handshake errors. +var ( + // ErrHandshakeBufTooSmall is returned when the buffer is smaller than HandshakeSize. + ErrHandshakeBufTooSmall = errors.New("protocol: handshake buffer too small (need 128 bytes)") + + // ErrHandshakeInvalidMagic is returned when handshake magic bytes are wrong. + ErrHandshakeInvalidMagic = errors.New("protocol: handshake invalid magic (expected 0x434F4D50)") + + // ErrRejected is returned when the compositor rejected the connection. + ErrRejected = errors.New("protocol: connection rejected by compositor") +) + +// EncodeHello writes a HelloMsg into buf using little-endian byte order. +// buf must be at least HandshakeSize (128) bytes. Never allocates. +func EncodeHello(msg *HelloMsg, buf []byte) error { + if len(buf) < HandshakeSize { + return ErrHandshakeBufTooSmall + } + + // Zero the buffer to ensure reserved bytes are clean. + clear(buf[:HandshakeSize]) + + // Magic + buf[helloOffMagic] = msg.Magic[0] + buf[helloOffMagic+1] = msg.Magic[1] + buf[helloOffMagic+2] = msg.Magic[2] + buf[helloOffMagic+3] = msg.Magic[3] + + // Version + binary.LittleEndian.PutUint16(buf[helloOffVersion:], msg.Version) + + // Name (64 bytes, null-terminated) + copy(buf[helloOffName:helloOffName+64], msg.Name[:]) + + // Width + binary.LittleEndian.PutUint16(buf[helloOffWidth:], msg.Width) + + // Height + binary.LittleEndian.PutUint16(buf[helloOffHeight:], msg.Height) + + // PreferredFPS + binary.LittleEndian.PutUint16(buf[helloOffPreferredFPS:], msg.PreferredFPS) + + // Transport + buf[helloOffTransport] = uint8(msg.Transport) + + // Reserved already zeroed by clear() + + return nil +} + +// DecodeHello reads a HelloMsg from buf and validates the magic bytes. +// buf must be at least HandshakeSize (128) bytes. Never allocates. +func DecodeHello(buf []byte) (HelloMsg, error) { + if len(buf) < HandshakeSize { + return HelloMsg{}, ErrHandshakeBufTooSmall + } + + var msg HelloMsg + + // Magic + msg.Magic[0] = buf[helloOffMagic] + msg.Magic[1] = buf[helloOffMagic+1] + msg.Magic[2] = buf[helloOffMagic+2] + msg.Magic[3] = buf[helloOffMagic+3] + + if msg.Magic != Magic { + return HelloMsg{}, fmt.Errorf("%w: got [0x%02X 0x%02X 0x%02X 0x%02X]", + ErrHandshakeInvalidMagic, msg.Magic[0], msg.Magic[1], msg.Magic[2], msg.Magic[3]) + } + + // Version + msg.Version = binary.LittleEndian.Uint16(buf[helloOffVersion:]) + + // Name + copy(msg.Name[:], buf[helloOffName:helloOffName+64]) + + // Width + msg.Width = binary.LittleEndian.Uint16(buf[helloOffWidth:]) + + // Height + msg.Height = binary.LittleEndian.Uint16(buf[helloOffHeight:]) + + // PreferredFPS + msg.PreferredFPS = binary.LittleEndian.Uint16(buf[helloOffPreferredFPS:]) + + // Transport + msg.Transport = TransportType(buf[helloOffTransport]) + + // Reserved + copy(msg.Reserved[:], buf[helloOffReserved:helloOffReserved+51]) + + return msg, nil +} + +// EncodeWelcome writes a WelcomeMsg into buf using little-endian byte order. +// buf must be at least HandshakeSize (128) bytes. Never allocates. +func EncodeWelcome(msg *WelcomeMsg, buf []byte) error { + if len(buf) < HandshakeSize { + return ErrHandshakeBufTooSmall + } + + // Zero the buffer to ensure reserved bytes are clean. + clear(buf[:HandshakeSize]) + + // Magic + buf[welcomeOffMagic] = msg.Magic[0] + buf[welcomeOffMagic+1] = msg.Magic[1] + buf[welcomeOffMagic+2] = msg.Magic[2] + buf[welcomeOffMagic+3] = msg.Magic[3] + + // Version + binary.LittleEndian.PutUint16(buf[welcomeOffVersion:], msg.Version) + + // ModuleID + binary.LittleEndian.PutUint64(buf[welcomeOffModuleID:], msg.ModuleID) + + // Accepted + buf[welcomeOffAccepted] = msg.Accepted + + // Transport + buf[welcomeOffTransport] = uint8(msg.Transport) + + // MinVersion + binary.LittleEndian.PutUint16(buf[welcomeOffMinVersion:], msg.MinVersion) + + // MaxVersion + binary.LittleEndian.PutUint16(buf[welcomeOffMaxVersion:], msg.MaxVersion) + + // Reserved already zeroed by clear() + + return nil +} + +// DecodeWelcome reads a WelcomeMsg from buf and validates the magic bytes. +// buf must be at least HandshakeSize (128) bytes. Never allocates. +func DecodeWelcome(buf []byte) (WelcomeMsg, error) { + if len(buf) < HandshakeSize { + return WelcomeMsg{}, ErrHandshakeBufTooSmall + } + + var msg WelcomeMsg + + // Magic + msg.Magic[0] = buf[welcomeOffMagic] + msg.Magic[1] = buf[welcomeOffMagic+1] + msg.Magic[2] = buf[welcomeOffMagic+2] + msg.Magic[3] = buf[welcomeOffMagic+3] + + if msg.Magic != Magic { + return WelcomeMsg{}, fmt.Errorf("%w: got [0x%02X 0x%02X 0x%02X 0x%02X]", + ErrHandshakeInvalidMagic, msg.Magic[0], msg.Magic[1], msg.Magic[2], msg.Magic[3]) + } + + // Version + msg.Version = binary.LittleEndian.Uint16(buf[welcomeOffVersion:]) + + // ModuleID + msg.ModuleID = binary.LittleEndian.Uint64(buf[welcomeOffModuleID:]) + + // Accepted + msg.Accepted = buf[welcomeOffAccepted] + + // Transport + msg.Transport = TransportType(buf[welcomeOffTransport]) + + // MinVersion + msg.MinVersion = binary.LittleEndian.Uint16(buf[welcomeOffMinVersion:]) + + // MaxVersion + msg.MaxVersion = binary.LittleEndian.Uint16(buf[welcomeOffMaxVersion:]) + + // Reserved + copy(msg.Reserved[:], buf[welcomeOffReserved:welcomeOffReserved+108]) + + return msg, nil +} + +// SetName copies a string name into the HelloMsg's fixed-size Name field. +// The name is null-terminated. If name is longer than 63 bytes, it is truncated. +func SetName(msg *HelloMsg, name string) { + // Clear the name field first. + clear(msg.Name[:]) + + // Copy up to 63 bytes (leaving room for null terminator). + maxLen := len(msg.Name) - 1 + n := len(name) + if n > maxLen { + n = maxLen + } + copy(msg.Name[:n], name) + // The field is already zero-filled, so null terminator is implicit. +} + +// GetName reads the null-terminated string from the HelloMsg's Name field. +func GetName(msg *HelloMsg) string { + for i, b := range msg.Name { + if b == 0 { + return string(msg.Name[:i]) + } + } + // No null terminator found — return all 64 bytes as string. + return string(msg.Name[:]) +} diff --git a/internal/protocol/handshake_test.go b/internal/protocol/handshake_test.go new file mode 100644 index 0000000..46a08f7 --- /dev/null +++ b/internal/protocol/handshake_test.go @@ -0,0 +1,566 @@ +package protocol + +import ( + "errors" + "math" + "strings" + "testing" +) + +func TestHandshakeSize(t *testing.T) { + if HandshakeSize != 128 { + t.Fatalf("HandshakeSize = %d, want 128", HandshakeSize) + } +} + +func TestHelloMsg_RoundTrip(t *testing.T) { + tests := []struct { + name string + msg HelloMsg + }{ + { + name: "Typical", + msg: func() HelloMsg { + var m HelloMsg + m.Magic = Magic + m.Version = ProtocolVersion + SetName(&m, "clock") + m.Width = 400 + m.Height = 120 + m.PreferredFPS = 1 + m.Transport = TransportSocket + return m + }(), + }, + { + name: "AnimatedModule", + msg: func() HelloMsg { + var m HelloMsg + m.Magic = Magic + m.Version = ProtocolVersion + SetName(&m, "notification-popup") + m.Width = 1920 + m.Height = 1080 + m.PreferredFPS = 60 + m.Transport = TransportShm + return m + }(), + }, + { + name: "MaxValues", + msg: func() HelloMsg { + var m HelloMsg + m.Magic = Magic + m.Version = math.MaxUint16 + SetName(&m, strings.Repeat("x", 63)) + m.Width = math.MaxUint16 + m.Height = math.MaxUint16 + m.PreferredFPS = math.MaxUint16 + m.Transport = TransportType(0xFF) + return m + }(), + }, + { + name: "EmptyName", + msg: func() HelloMsg { + var m HelloMsg + m.Magic = Magic + m.Version = ProtocolVersion + m.Width = 100 + m.Height = 100 + m.PreferredFPS = 30 + return m + }(), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + buf := make([]byte, HandshakeSize) + if err := EncodeHello(&tt.msg, buf); err != nil { + t.Fatalf("EncodeHello() error: %v", err) + } + + got, err := DecodeHello(buf) + if err != nil { + t.Fatalf("DecodeHello() error: %v", err) + } + + if got != tt.msg { + t.Errorf("round-trip mismatch:\n got: %+v\n want: %+v", got, tt.msg) + } + }) + } +} + +func TestWelcomeMsg_RoundTrip(t *testing.T) { + tests := []struct { + name string + msg WelcomeMsg + }{ + { + name: "Accepted", + msg: WelcomeMsg{ + Magic: Magic, + Version: ProtocolVersion, + ModuleID: 42, + Accepted: 1, + Transport: TransportSocket, + MinVersion: 1, + MaxVersion: 1, + }, + }, + { + name: "Rejected", + msg: WelcomeMsg{ + Magic: Magic, + Version: ProtocolVersion, + ModuleID: 0, + Accepted: 0, + Transport: TransportSocket, + MinVersion: 1, + MaxVersion: 3, + }, + }, + { + name: "ShmGranted", + msg: WelcomeMsg{ + Magic: Magic, + Version: ProtocolVersion, + ModuleID: 99, + Accepted: 1, + Transport: TransportShm, + MinVersion: 1, + MaxVersion: 2, + }, + }, + { + name: "MaxValues", + msg: WelcomeMsg{ + Magic: Magic, + Version: math.MaxUint16, + ModuleID: math.MaxUint64, + Accepted: 0xFF, + Transport: TransportType(0xFF), + MinVersion: math.MaxUint16, + MaxVersion: math.MaxUint16, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + buf := make([]byte, HandshakeSize) + if err := EncodeWelcome(&tt.msg, buf); err != nil { + t.Fatalf("EncodeWelcome() error: %v", err) + } + + got, err := DecodeWelcome(buf) + if err != nil { + t.Fatalf("DecodeWelcome() error: %v", err) + } + + if got != tt.msg { + t.Errorf("round-trip mismatch:\n got: %+v\n want: %+v", got, tt.msg) + } + }) + } +} + +func TestEncodeHello_BufferTooSmall(t *testing.T) { + msg := HelloMsg{Magic: Magic} + tests := []struct { + name string + size int + }{ + {"Zero", 0}, + {"Half", 64}, + {"AlmostFull", 127}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + buf := make([]byte, tt.size) + err := EncodeHello(&msg, buf) + if !errors.Is(err, ErrHandshakeBufTooSmall) { + t.Errorf("EncodeHello() with %d-byte buf: got error %v, want ErrHandshakeBufTooSmall", tt.size, err) + } + }) + } +} + +func TestDecodeHello_BufferTooSmall(t *testing.T) { + tests := []struct { + name string + size int + }{ + {"Zero", 0}, + {"Half", 64}, + {"AlmostFull", 127}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + buf := make([]byte, tt.size) + _, err := DecodeHello(buf) + if !errors.Is(err, ErrHandshakeBufTooSmall) { + t.Errorf("DecodeHello() with %d-byte buf: got error %v, want ErrHandshakeBufTooSmall", tt.size, err) + } + }) + } +} + +func TestEncodeWelcome_BufferTooSmall(t *testing.T) { + msg := WelcomeMsg{Magic: Magic} + tests := []struct { + name string + size int + }{ + {"Zero", 0}, + {"Half", 64}, + {"AlmostFull", 127}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + buf := make([]byte, tt.size) + err := EncodeWelcome(&msg, buf) + if !errors.Is(err, ErrHandshakeBufTooSmall) { + t.Errorf("EncodeWelcome() with %d-byte buf: got error %v, want ErrHandshakeBufTooSmall", tt.size, err) + } + }) + } +} + +func TestDecodeWelcome_BufferTooSmall(t *testing.T) { + tests := []struct { + name string + size int + }{ + {"Zero", 0}, + {"Half", 64}, + {"AlmostFull", 127}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + buf := make([]byte, tt.size) + _, err := DecodeWelcome(buf) + if !errors.Is(err, ErrHandshakeBufTooSmall) { + t.Errorf("DecodeWelcome() with %d-byte buf: got error %v, want ErrHandshakeBufTooSmall", tt.size, err) + } + }) + } +} + +func TestDecodeHello_InvalidMagic(t *testing.T) { + tests := []struct { + name string + magic [4]byte + }{ + {"AllZeros", [4]byte{0, 0, 0, 0}}, + {"WrongFirst", [4]byte{0x00, 0x4F, 0x4D, 0x50}}, + {"Reversed", [4]byte{0x50, 0x4D, 0x4F, 0x43}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + buf := make([]byte, HandshakeSize) + buf[0] = tt.magic[0] + buf[1] = tt.magic[1] + buf[2] = tt.magic[2] + buf[3] = tt.magic[3] + + _, err := DecodeHello(buf) + if err == nil { + t.Error("DecodeHello() with invalid magic: expected error, got nil") + } + }) + } +} + +func TestDecodeWelcome_InvalidMagic(t *testing.T) { + tests := []struct { + name string + magic [4]byte + }{ + {"AllZeros", [4]byte{0, 0, 0, 0}}, + {"WrongLast", [4]byte{0x43, 0x4F, 0x4D, 0x00}}, + {"AllOnes", [4]byte{0xFF, 0xFF, 0xFF, 0xFF}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + buf := make([]byte, HandshakeSize) + buf[0] = tt.magic[0] + buf[1] = tt.magic[1] + buf[2] = tt.magic[2] + buf[3] = tt.magic[3] + + _, err := DecodeWelcome(buf) + if err == nil { + t.Error("DecodeWelcome() with invalid magic: expected error, got nil") + } + }) + } +} + +func TestSetName_GetName(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + {"Short", "clock", "clock"}, + {"Empty", "", ""}, + {"ExactMax", strings.Repeat("a", 63), strings.Repeat("a", 63)}, + {"Overflow", strings.Repeat("b", 100), strings.Repeat("b", 63)}, + {"Unicode", "часы", "часы"}, + {"WithSpaces", "my module name", "my module name"}, + {"SingleChar", "x", "x"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var msg HelloMsg + SetName(&msg, tt.input) + got := GetName(&msg) + if got != tt.want { + t.Errorf("SetName/GetName(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestSetName_NullTerminated(t *testing.T) { + var msg HelloMsg + SetName(&msg, "test") + + // Verify null terminator exists at position 4. + if msg.Name[4] != 0 { + t.Errorf("Name[4] = 0x%02X, want 0x00 (null terminator)", msg.Name[4]) + } + // Verify rest is zeroed. + for i := 5; i < len(msg.Name); i++ { + if msg.Name[i] != 0 { + t.Errorf("Name[%d] = 0x%02X, want 0x00", i, msg.Name[i]) + } + } +} + +func TestSetName_ClearsOldName(t *testing.T) { + var msg HelloMsg + SetName(&msg, strings.Repeat("x", 63)) + SetName(&msg, "ab") + + got := GetName(&msg) + if got != "ab" { + t.Errorf("after overwrite, GetName() = %q, want %q", got, "ab") + } + // Verify bytes after "ab" are zero. + if msg.Name[2] != 0 { + t.Errorf("Name[2] = 0x%02X, want 0x00 after overwrite", msg.Name[2]) + } +} + +func TestGetName_NoNullTerminator(t *testing.T) { + // Edge case: all 64 bytes are non-zero (no null terminator). + var msg HelloMsg + for i := range msg.Name { + msg.Name[i] = 'z' + } + + got := GetName(&msg) + if len(got) != 64 { + t.Errorf("GetName() with no NUL: len = %d, want 64", len(got)) + } +} + +func TestTransportType_String(t *testing.T) { + tests := []struct { + name string + tr TransportType + want string + }{ + {"Socket", TransportSocket, "Socket"}, + {"Shm", TransportShm, "SharedMemory"}, + {"Unknown", TransportType(99), "TransportType(99)"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.tr.String() + if got != tt.want { + t.Errorf("TransportType(%d).String() = %q, want %q", tt.tr, got, tt.want) + } + }) + } +} + +func TestEncodeHello_ReservedZeroed(t *testing.T) { + msg := HelloMsg{ + Magic: Magic, + Version: ProtocolVersion, + } + SetName(&msg, "test") + + buf := make([]byte, HandshakeSize) + // Fill with non-zero. + for i := range buf { + buf[i] = 0xFF + } + + if err := EncodeHello(&msg, buf); err != nil { + t.Fatalf("EncodeHello() error: %v", err) + } + + // Verify reserved bytes are zeroed. + for i := helloOffReserved; i < helloOffReserved+51; i++ { + if buf[i] != 0 { + t.Errorf("reserved byte buf[%d] = 0x%02X, want 0x00", i, buf[i]) + } + } +} + +func TestEncodeWelcome_ReservedZeroed(t *testing.T) { + msg := WelcomeMsg{ + Magic: Magic, + Version: ProtocolVersion, + ModuleID: 1, + Accepted: 1, + } + + buf := make([]byte, HandshakeSize) + // Fill with non-zero. + for i := range buf { + buf[i] = 0xFF + } + + if err := EncodeWelcome(&msg, buf); err != nil { + t.Fatalf("EncodeWelcome() error: %v", err) + } + + // Verify reserved bytes are zeroed. + for i := welcomeOffReserved; i < welcomeOffReserved+108; i++ { + if buf[i] != 0 { + t.Errorf("reserved byte buf[%d] = 0x%02X, want 0x00", i, buf[i]) + } + } +} + +func TestHelloMsg_LargerBuffer(t *testing.T) { + msg := HelloMsg{ + Magic: Magic, + Version: ProtocolVersion, + } + SetName(&msg, "test") + + buf := make([]byte, 256) + for i := range buf { + buf[i] = 0xAA + } + + if err := EncodeHello(&msg, buf); err != nil { + t.Fatalf("EncodeHello() error: %v", err) + } + + // Bytes beyond HandshakeSize should be untouched. + for i := HandshakeSize; i < len(buf); i++ { + if buf[i] != 0xAA { + t.Errorf("buf[%d] = 0x%02X, want 0xAA (sentinel should be untouched)", i, buf[i]) + } + } +} + +func TestWelcomeMsg_LargerBuffer(t *testing.T) { + msg := WelcomeMsg{ + Magic: Magic, + Version: ProtocolVersion, + ModuleID: 7, + Accepted: 1, + } + + buf := make([]byte, 256) + for i := range buf { + buf[i] = 0xBB + } + + if err := EncodeWelcome(&msg, buf); err != nil { + t.Fatalf("EncodeWelcome() error: %v", err) + } + + // Bytes beyond HandshakeSize should be untouched. + for i := HandshakeSize; i < len(buf); i++ { + if buf[i] != 0xBB { + t.Errorf("buf[%d] = 0x%02X, want 0xBB (sentinel should be untouched)", i, buf[i]) + } + } +} + +func BenchmarkEncodeHello(b *testing.B) { + msg := HelloMsg{ + Magic: Magic, + Version: ProtocolVersion, + Width: 400, + Height: 120, + PreferredFPS: 60, + Transport: TransportSocket, + } + SetName(&msg, "benchmark-module") + buf := make([]byte, HandshakeSize) + + b.ReportAllocs() + b.ResetTimer() + for b.Loop() { + _ = EncodeHello(&msg, buf) + } +} + +func BenchmarkDecodeHello(b *testing.B) { + msg := HelloMsg{ + Magic: Magic, + Version: ProtocolVersion, + Width: 400, + Height: 120, + PreferredFPS: 60, + Transport: TransportSocket, + } + SetName(&msg, "benchmark-module") + buf := make([]byte, HandshakeSize) + _ = EncodeHello(&msg, buf) + + b.ReportAllocs() + b.ResetTimer() + for b.Loop() { + _, _ = DecodeHello(buf) + } +} + +func BenchmarkEncodeWelcome(b *testing.B) { + msg := WelcomeMsg{ + Magic: Magic, + Version: ProtocolVersion, + ModuleID: 42, + Accepted: 1, + Transport: TransportSocket, + MinVersion: 1, + MaxVersion: 1, + } + buf := make([]byte, HandshakeSize) + + b.ReportAllocs() + b.ResetTimer() + for b.Loop() { + _ = EncodeWelcome(&msg, buf) + } +} + +func BenchmarkDecodeWelcome(b *testing.B) { + msg := WelcomeMsg{ + Magic: Magic, + Version: ProtocolVersion, + ModuleID: 42, + Accepted: 1, + Transport: TransportSocket, + MinVersion: 1, + MaxVersion: 1, + } + buf := make([]byte, HandshakeSize) + _ = EncodeWelcome(&msg, buf) + + b.ReportAllocs() + b.ResetTimer() + for b.Loop() { + _, _ = DecodeWelcome(buf) + } +} diff --git a/internal/protocol/header.go b/internal/protocol/header.go new file mode 100644 index 0000000..3d9f9f9 --- /dev/null +++ b/internal/protocol/header.go @@ -0,0 +1,270 @@ +package protocol + +import ( + "encoding/binary" + "errors" + "fmt" +) + +// HeaderSize is the fixed size in bytes of the wire frame header. +// It is designed to be cache-line aligned (64 bytes on modern CPUs). +const HeaderSize = 64 + +// Magic is the 4-byte protocol identifier at the start of every header. +// ASCII: "COMP" (0x43, 0x4F, 0x4D, 0x50). +var Magic = [4]byte{0x43, 0x4F, 0x4D, 0x50} + +// ProtocolVersion is the current wire protocol version. +const ProtocolVersion uint16 = 1 + +// Header is the 64-byte frame header that precedes every message on the wire. +// All multi-byte fields are little-endian encoded. +type Header struct { + // Magic is the protocol identifier (must be "COMP"). + Magic [4]byte + + // Version is the protocol version number. + Version uint16 + + // MsgType identifies the message kind (Frame, Handshake, Ack, etc.). + MsgType MsgType + + // Flags is a bitfield (DirtyValid, Compressed, Keyframe). + Flags Flag + + // ModuleID is the compositor-assigned module identifier. + ModuleID uint64 + + // Sequence is the monotonically increasing frame counter per module. + Sequence uint64 + + // TimestampNs is the monotonic clock timestamp in nanoseconds. + TimestampNs int64 + + // Width is the frame width in pixels. + Width uint16 + + // Height is the frame height in pixels. + Height uint16 + + // Stride is the number of bytes per row (typically Width * 4). + Stride uint32 + + // DirtyX is the X offset of the dirty rectangle. + DirtyX uint16 + + // DirtyY is the Y offset of the dirty rectangle. + DirtyY uint16 + + // DirtyW is the width of the dirty rectangle. + DirtyW uint16 + + // DirtyH is the height of the dirty rectangle. + DirtyH uint16 + + // PixelFormat identifies the pixel encoding (RGBA8, BGRA8). + PixelFormat PixelFormat + + // Compression identifies the payload compression algorithm. + Compression Compression + + // Reserved is padding for future use. Must be zero. + Reserved [6]byte + + // PayloadSize is the number of payload bytes following this header. + PayloadSize uint32 + + // UncompressedSize is the original payload size before compression. + // When Compression is None, this equals PayloadSize. + UncompressedSize uint32 +} + +// Field byte offsets within the 64-byte header. +const ( + offMagic = 0 + offVersion = 4 + offMsgType = 6 + offFlags = 7 + offModuleID = 8 + offSequence = 16 + offTimestampNs = 24 + offWidth = 32 + offHeight = 34 + offStride = 36 + offDirtyX = 40 + offDirtyY = 42 + offDirtyW = 44 + offDirtyH = 46 + offPixelFormat = 48 + offCompression = 49 + offReserved = 50 + offPayloadSize = 56 + offUncompressedSize = 60 +) + +// Errors returned by Encode and Decode. +var ( + // ErrBufferTooSmall is returned when the provided buffer is smaller than HeaderSize. + ErrBufferTooSmall = errors.New("protocol: buffer too small (need 64 bytes)") + + // ErrInvalidMagic is returned when the decoded magic bytes do not match "COMP". + ErrInvalidMagic = errors.New("protocol: invalid magic (expected 0x434F4D50)") + + // ErrUnknownMsgType is returned when the decoded message type is not recognized. + ErrUnknownMsgType = errors.New("protocol: unknown message type") +) + +// Encode writes the header h into buf using little-endian byte order. +// buf must be at least HeaderSize (64) bytes. Encode never allocates. +func Encode(h *Header, buf []byte) error { + if len(buf) < HeaderSize { + return ErrBufferTooSmall + } + + // Magic (4 bytes) + buf[offMagic] = h.Magic[0] + buf[offMagic+1] = h.Magic[1] + buf[offMagic+2] = h.Magic[2] + buf[offMagic+3] = h.Magic[3] + + // Version (2 bytes) + binary.LittleEndian.PutUint16(buf[offVersion:], h.Version) + + // MsgType (1 byte) + buf[offMsgType] = uint8(h.MsgType) + + // Flags (1 byte) + buf[offFlags] = uint8(h.Flags) + + // ModuleID (8 bytes) + binary.LittleEndian.PutUint64(buf[offModuleID:], h.ModuleID) + + // Sequence (8 bytes) + binary.LittleEndian.PutUint64(buf[offSequence:], h.Sequence) + + // TimestampNs (8 bytes, signed as uint64 bit pattern) + binary.LittleEndian.PutUint64(buf[offTimestampNs:], uint64(h.TimestampNs)) //nolint:gosec // bit-cast int64→uint64, no overflow + + // Width (2 bytes) + binary.LittleEndian.PutUint16(buf[offWidth:], h.Width) + + // Height (2 bytes) + binary.LittleEndian.PutUint16(buf[offHeight:], h.Height) + + // Stride (4 bytes) + binary.LittleEndian.PutUint32(buf[offStride:], h.Stride) + + // DirtyX (2 bytes) + binary.LittleEndian.PutUint16(buf[offDirtyX:], h.DirtyX) + + // DirtyY (2 bytes) + binary.LittleEndian.PutUint16(buf[offDirtyY:], h.DirtyY) + + // DirtyW (2 bytes) + binary.LittleEndian.PutUint16(buf[offDirtyW:], h.DirtyW) + + // DirtyH (2 bytes) + binary.LittleEndian.PutUint16(buf[offDirtyH:], h.DirtyH) + + // PixelFormat (1 byte) + buf[offPixelFormat] = uint8(h.PixelFormat) + + // Compression (1 byte) + buf[offCompression] = uint8(h.Compression) + + // Reserved (6 bytes, zero) + buf[offReserved] = 0 + buf[offReserved+1] = 0 + buf[offReserved+2] = 0 + buf[offReserved+3] = 0 + buf[offReserved+4] = 0 + buf[offReserved+5] = 0 + + // PayloadSize (4 bytes) + binary.LittleEndian.PutUint32(buf[offPayloadSize:], h.PayloadSize) + + // UncompressedSize (4 bytes) + binary.LittleEndian.PutUint32(buf[offUncompressedSize:], h.UncompressedSize) + + return nil +} + +// Decode reads a header from buf and validates the magic bytes and message type. +// buf must be at least HeaderSize (64) bytes. Decode never allocates. +func Decode(buf []byte) (Header, error) { + if len(buf) < HeaderSize { + return Header{}, ErrBufferTooSmall + } + + var h Header + + // Magic (4 bytes) + h.Magic[0] = buf[offMagic] + h.Magic[1] = buf[offMagic+1] + h.Magic[2] = buf[offMagic+2] + h.Magic[3] = buf[offMagic+3] + + if h.Magic != Magic { + return Header{}, fmt.Errorf("%w: got [0x%02X 0x%02X 0x%02X 0x%02X]", + ErrInvalidMagic, h.Magic[0], h.Magic[1], h.Magic[2], h.Magic[3]) + } + + // Version (2 bytes) + h.Version = binary.LittleEndian.Uint16(buf[offVersion:]) + + // MsgType (1 byte) + h.MsgType = MsgType(buf[offMsgType]) + if !h.MsgType.Valid() { + return Header{}, fmt.Errorf("%w: 0x%02X", ErrUnknownMsgType, uint8(h.MsgType)) + } + + // Flags (1 byte) + h.Flags = Flag(buf[offFlags]) + + // ModuleID (8 bytes) + h.ModuleID = binary.LittleEndian.Uint64(buf[offModuleID:]) + + // Sequence (8 bytes) + h.Sequence = binary.LittleEndian.Uint64(buf[offSequence:]) + + // TimestampNs (8 bytes) + h.TimestampNs = int64(binary.LittleEndian.Uint64(buf[offTimestampNs:])) //nolint:gosec // bit-cast uint64→int64, no overflow + + // Width (2 bytes) + h.Width = binary.LittleEndian.Uint16(buf[offWidth:]) + + // Height (2 bytes) + h.Height = binary.LittleEndian.Uint16(buf[offHeight:]) + + // Stride (4 bytes) + h.Stride = binary.LittleEndian.Uint32(buf[offStride:]) + + // DirtyX (2 bytes) + h.DirtyX = binary.LittleEndian.Uint16(buf[offDirtyX:]) + + // DirtyY (2 bytes) + h.DirtyY = binary.LittleEndian.Uint16(buf[offDirtyY:]) + + // DirtyW (2 bytes) + h.DirtyW = binary.LittleEndian.Uint16(buf[offDirtyW:]) + + // DirtyH (2 bytes) + h.DirtyH = binary.LittleEndian.Uint16(buf[offDirtyH:]) + + // PixelFormat (1 byte) + h.PixelFormat = PixelFormat(buf[offPixelFormat]) + + // Compression (1 byte) + h.Compression = Compression(buf[offCompression]) + + // Reserved (6 bytes) + copy(h.Reserved[:], buf[offReserved:offReserved+6]) + + // PayloadSize (4 bytes) + h.PayloadSize = binary.LittleEndian.Uint32(buf[offPayloadSize:]) + + // UncompressedSize (4 bytes) + h.UncompressedSize = binary.LittleEndian.Uint32(buf[offUncompressedSize:]) + + return h, nil +} diff --git a/internal/protocol/header_test.go b/internal/protocol/header_test.go new file mode 100644 index 0000000..5a029e7 --- /dev/null +++ b/internal/protocol/header_test.go @@ -0,0 +1,399 @@ +package protocol + +import ( + "errors" + "math" + "testing" +) + +func TestHeaderSize(t *testing.T) { + if HeaderSize != 64 { + t.Fatalf("HeaderSize = %d, want 64", HeaderSize) + } +} + +func TestHeader_RoundTrip(t *testing.T) { + tests := []struct { + name string + h Header + }{ + { + name: "TypicalFrame", + h: Header{ + Magic: Magic, + Version: ProtocolVersion, + MsgType: MsgFrame, + Flags: FlagDirtyValid | FlagKeyframe, + ModuleID: 42, + Sequence: 1001, + TimestampNs: 1_000_000_000, + Width: 1920, + Height: 1080, + Stride: 1920 * 4, + DirtyX: 100, + DirtyY: 200, + DirtyW: 300, + DirtyH: 400, + PixelFormat: PixelRGBA8, + Compression: CompressionNone, + PayloadSize: 1920 * 1080 * 4, + UncompressedSize: 1920 * 1080 * 4, + }, + }, + { + name: "CompressedFrame", + h: Header{ + Magic: Magic, + Version: ProtocolVersion, + MsgType: MsgFrame, + Flags: FlagCompressed, + ModuleID: 7, + Sequence: 55, + TimestampNs: -12345678, // negative timestamp is valid + Width: 800, + Height: 600, + Stride: 800 * 4, + PixelFormat: PixelBGRA8, + Compression: CompressionLZ4, + PayloadSize: 100000, + UncompressedSize: 800 * 600 * 4, + }, + }, + { + name: "Ack", + h: Header{ + Magic: Magic, + Version: ProtocolVersion, + MsgType: MsgAck, + }, + }, + { + name: "FrameRequest", + h: Header{ + Magic: Magic, + Version: ProtocolVersion, + MsgType: MsgFrameRequest, + ModuleID: 99, + Sequence: 12345, + }, + }, + { + name: "Disconnect", + h: Header{ + Magic: Magic, + Version: ProtocolVersion, + MsgType: MsgDisconnect, + ModuleID: 3, + }, + }, + { + name: "Resize", + h: Header{ + Magic: Magic, + Version: ProtocolVersion, + MsgType: MsgResize, + Width: 2560, + Height: 1440, + Stride: 2560 * 4, + }, + }, + { + name: "MaxValues", + h: Header{ + Magic: Magic, + Version: math.MaxUint16, + MsgType: MsgFrame, + Flags: Flag(0xFF), + ModuleID: math.MaxUint64, + Sequence: math.MaxUint64, + TimestampNs: math.MaxInt64, + Width: math.MaxUint16, + Height: math.MaxUint16, + Stride: math.MaxUint32, + DirtyX: math.MaxUint16, + DirtyY: math.MaxUint16, + DirtyW: math.MaxUint16, + DirtyH: math.MaxUint16, + PixelFormat: PixelFormat(0xFF), + Compression: Compression(0xFF), + PayloadSize: math.MaxUint32, + UncompressedSize: math.MaxUint32, + }, + }, + { + name: "ZeroValue", + h: Header{ + Magic: Magic, + MsgType: MsgFrame, + }, + }, + { + name: "NegativeTimestamp", + h: Header{ + Magic: Magic, + Version: ProtocolVersion, + MsgType: MsgHandshake, + TimestampNs: math.MinInt64, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + buf := make([]byte, HeaderSize) + if err := Encode(&tt.h, buf); err != nil { + t.Fatalf("Encode() error: %v", err) + } + + got, err := Decode(buf) + if err != nil { + t.Fatalf("Decode() error: %v", err) + } + + if got != tt.h { + t.Errorf("round-trip mismatch:\n got: %+v\n want: %+v", got, tt.h) + } + }) + } +} + +func TestEncode_BufferTooSmall(t *testing.T) { + h := Header{Magic: Magic, MsgType: MsgFrame} + tests := []struct { + name string + size int + }{ + {"Zero", 0}, + {"One", 1}, + {"Half", 32}, + {"AlmostFull", 63}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + buf := make([]byte, tt.size) + err := Encode(&h, buf) + if !errors.Is(err, ErrBufferTooSmall) { + t.Errorf("Encode() with %d-byte buf: got error %v, want ErrBufferTooSmall", tt.size, err) + } + }) + } +} + +func TestDecode_BufferTooSmall(t *testing.T) { + tests := []struct { + name string + size int + }{ + {"Zero", 0}, + {"One", 1}, + {"Half", 32}, + {"AlmostFull", 63}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + buf := make([]byte, tt.size) + _, err := Decode(buf) + if !errors.Is(err, ErrBufferTooSmall) { + t.Errorf("Decode() with %d-byte buf: got error %v, want ErrBufferTooSmall", tt.size, err) + } + }) + } +} + +func TestDecode_InvalidMagic(t *testing.T) { + tests := []struct { + name string + magic [4]byte + }{ + {"AllZeros", [4]byte{0, 0, 0, 0}}, + {"WrongFirst", [4]byte{0x00, 0x4F, 0x4D, 0x50}}, + {"WrongLast", [4]byte{0x43, 0x4F, 0x4D, 0x00}}, + {"Reversed", [4]byte{0x50, 0x4D, 0x4F, 0x43}}, + {"AllOnes", [4]byte{0xFF, 0xFF, 0xFF, 0xFF}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + buf := make([]byte, HeaderSize) + // Manually write magic to bypass Encode validation. + buf[0] = tt.magic[0] + buf[1] = tt.magic[1] + buf[2] = tt.magic[2] + buf[3] = tt.magic[3] + buf[offMsgType] = uint8(MsgFrame) + + _, err := Decode(buf) + if err == nil { + t.Errorf("Decode() with magic %v: expected error, got nil", tt.magic) + } + }) + } +} + +func TestDecode_UnknownMsgType(t *testing.T) { + tests := []struct { + name string + msgType uint8 + }{ + {"Zero", 0x00}, + {"Seven", 0x07}, + {"Max", 0xFF}, + {"JustAbove", 0x08}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + buf := make([]byte, HeaderSize) + // Valid magic. + buf[0] = Magic[0] + buf[1] = Magic[1] + buf[2] = Magic[2] + buf[3] = Magic[3] + buf[offMsgType] = tt.msgType + + _, err := Decode(buf) + if err == nil { + t.Errorf("Decode() with MsgType 0x%02X: expected error, got nil", tt.msgType) + } + }) + } +} + +func TestEncode_LargerBuffer(t *testing.T) { + // Encode into a buffer larger than 64 bytes should work and not corrupt beyond. + h := Header{ + Magic: Magic, + Version: ProtocolVersion, + MsgType: MsgFrame, + } + buf := make([]byte, 128) + for i := range buf { + buf[i] = 0xAA // sentinel + } + + if err := Encode(&h, buf); err != nil { + t.Fatalf("Encode() error: %v", err) + } + + // Bytes beyond 64 should be untouched. + for i := HeaderSize; i < len(buf); i++ { + if buf[i] != 0xAA { + t.Errorf("buf[%d] = 0x%02X, want 0xAA (sentinel should be untouched)", i, buf[i]) + } + } +} + +func TestDecode_ExactBuffer(t *testing.T) { + // Decode from a buffer that is exactly 64 bytes. + h := Header{ + Magic: Magic, + Version: ProtocolVersion, + MsgType: MsgFrame, + ModuleID: 42, + PayloadSize: 1024, + } + buf := make([]byte, HeaderSize) + if err := Encode(&h, buf); err != nil { + t.Fatalf("Encode() error: %v", err) + } + + got, err := Decode(buf) + if err != nil { + t.Fatalf("Decode() error: %v", err) + } + if got.ModuleID != 42 { + t.Errorf("ModuleID = %d, want 42", got.ModuleID) + } + if got.PayloadSize != 1024 { + t.Errorf("PayloadSize = %d, want 1024", got.PayloadSize) + } +} + +func TestHeader_ReservedZeroed(t *testing.T) { + // After encoding, reserved bytes should be zero. + h := Header{ + Magic: Magic, + Version: ProtocolVersion, + MsgType: MsgFrame, + } + buf := make([]byte, HeaderSize) + // Fill with non-zero to prove Encode zeros them. + for i := range buf { + buf[i] = 0xFF + } + + if err := Encode(&h, buf); err != nil { + t.Fatalf("Encode() error: %v", err) + } + + for i := offReserved; i < offReserved+6; i++ { + if buf[i] != 0 { + t.Errorf("reserved byte buf[%d] = 0x%02X, want 0x00", i, buf[i]) + } + } +} + +func TestMagicBytes(t *testing.T) { + // Verify magic is "COMP" in ASCII. + if Magic != [4]byte{'C', 'O', 'M', 'P'} { + t.Errorf("Magic = %v, want {0x43, 0x4F, 0x4D, 0x50} ('COMP')", Magic) + } +} + +func BenchmarkHeaderEncode(b *testing.B) { + h := Header{ + Magic: Magic, + Version: ProtocolVersion, + MsgType: MsgFrame, + Flags: FlagDirtyValid | FlagCompressed, + ModuleID: 42, + Sequence: 1001, + TimestampNs: 1_000_000_000, + Width: 1920, + Height: 1080, + Stride: 1920 * 4, + DirtyX: 100, + DirtyY: 200, + DirtyW: 300, + DirtyH: 400, + PixelFormat: PixelRGBA8, + Compression: CompressionLZ4, + PayloadSize: 100000, + UncompressedSize: 1920 * 1080 * 4, + } + buf := make([]byte, HeaderSize) + + b.ReportAllocs() + b.ResetTimer() + for b.Loop() { + _ = Encode(&h, buf) + } +} + +func BenchmarkHeaderDecode(b *testing.B) { + h := Header{ + Magic: Magic, + Version: ProtocolVersion, + MsgType: MsgFrame, + Flags: FlagDirtyValid | FlagCompressed, + ModuleID: 42, + Sequence: 1001, + TimestampNs: 1_000_000_000, + Width: 1920, + Height: 1080, + Stride: 1920 * 4, + DirtyX: 100, + DirtyY: 200, + DirtyW: 300, + DirtyH: 400, + PixelFormat: PixelRGBA8, + Compression: CompressionLZ4, + PayloadSize: 100000, + UncompressedSize: 1920 * 1080 * 4, + } + buf := make([]byte, HeaderSize) + _ = Encode(&h, buf) + + b.ReportAllocs() + b.ResetTimer() + for b.Loop() { + _, _ = Decode(buf) + } +} diff --git a/internal/protocol/message.go b/internal/protocol/message.go new file mode 100644 index 0000000..13368a7 --- /dev/null +++ b/internal/protocol/message.go @@ -0,0 +1,174 @@ +package protocol + +import "fmt" + +// MsgType identifies the kind of message carried by a frame header. +type MsgType uint8 + +const ( + // MsgFrame carries a rendered pixel buffer from module to compositor. + MsgFrame MsgType = 0x01 + + // MsgHandshake is the initial connection negotiation message. + MsgHandshake MsgType = 0x02 + + // MsgAck acknowledges receipt of a frame (compositor to module). + MsgAck MsgType = 0x03 + + // MsgFrameRequest is sent by the compositor to request the next frame + // from a module (pull-based flow control, Wayland pattern). + MsgFrameRequest MsgType = 0x04 + + // MsgResize notifies a module that its frame dimensions have changed. + MsgResize MsgType = 0x05 + + // MsgDisconnect signals a graceful disconnection. + MsgDisconnect MsgType = 0x06 +) + +// String returns the human-readable name of the message type. +func (m MsgType) String() string { + switch m { + case MsgFrame: + return "Frame" + case MsgHandshake: + return "Handshake" + case MsgAck: + return "Ack" + case MsgFrameRequest: + return "FrameRequest" + case MsgResize: + return "Resize" + case MsgDisconnect: + return "Disconnect" + default: + return fmt.Sprintf("MsgType(0x%02X)", uint8(m)) + } +} + +// Valid reports whether m is a known message type. +func (m MsgType) Valid() bool { + switch m { + case MsgFrame, MsgHandshake, MsgAck, MsgFrameRequest, MsgResize, MsgDisconnect: + return true + default: + return false + } +} + +// Flag is a bitfield carried in the header's Flags byte. +type Flag uint8 + +const ( + // FlagDirtyValid indicates that the DirtyRect fields contain valid data. + // When not set, the entire frame is considered dirty (keyframe). + FlagDirtyValid Flag = 1 << iota + + // FlagCompressed indicates that the payload is compressed. + // The Compression field specifies the algorithm. + FlagCompressed + + // FlagKeyframe marks the frame as a full keyframe (no delta dependency). + FlagKeyframe +) + +// Has reports whether the flag f is set in the bitmask flags. +func (f Flag) Has(flag Flag) bool { + return f&flag != 0 +} + +// Set returns the bitmask with flag set. +func (f Flag) Set(flag Flag) Flag { + return f | flag +} + +// Clear returns the bitmask with flag cleared. +func (f Flag) Clear(flag Flag) Flag { + return f &^ flag +} + +// String returns the human-readable name of the flag bit. +func (f Flag) String() string { + switch f { + case FlagDirtyValid: + return "DirtyValid" + case FlagCompressed: + return "Compressed" + case FlagKeyframe: + return "Keyframe" + default: + return fmt.Sprintf("Flag(0x%02X)", uint8(f)) + } +} + +// PixelFormat identifies the pixel encoding of frame payload data. +type PixelFormat uint8 + +const ( + // PixelRGBA8 is 8-bit RGBA, 4 bytes per pixel, premultiplied alpha. + PixelRGBA8 PixelFormat = 0x00 + + // PixelBGRA8 is 8-bit BGRA, 4 bytes per pixel, premultiplied alpha. + // Common on Windows/DX12 surfaces. + PixelBGRA8 PixelFormat = 0x01 +) + +// String returns the human-readable name of the pixel format. +func (p PixelFormat) String() string { + switch p { + case PixelRGBA8: + return "RGBA8" + case PixelBGRA8: + return "BGRA8" + default: + return fmt.Sprintf("PixelFormat(0x%02X)", uint8(p)) + } +} + +// Valid reports whether p is a known pixel format. +func (p PixelFormat) Valid() bool { + switch p { + case PixelRGBA8, PixelBGRA8: + return true + default: + return false + } +} + +// Compression identifies the payload compression algorithm. +type Compression uint8 + +const ( + // CompressionNone means the payload is uncompressed raw pixels. + CompressionNone Compression = 0x00 + + // CompressionLZ4 uses the LZ4 block compression algorithm. + CompressionLZ4 Compression = 0x01 + + // CompressionZstd uses the Zstandard compression algorithm. + CompressionZstd Compression = 0x02 +) + +// String returns the human-readable name of the compression algorithm. +func (c Compression) String() string { + switch c { + case CompressionNone: + return "None" + case CompressionLZ4: + return "LZ4" + case CompressionZstd: + return "Zstd" + default: + return fmt.Sprintf("Compression(0x%02X)", uint8(c)) + } +} + +// Valid reports whether c is a known compression algorithm. +func (c Compression) Valid() bool { + switch c { + case CompressionNone, CompressionLZ4, CompressionZstd: + return true + default: + return false + } +} diff --git a/internal/protocol/message_test.go b/internal/protocol/message_test.go new file mode 100644 index 0000000..a3c3bab --- /dev/null +++ b/internal/protocol/message_test.go @@ -0,0 +1,272 @@ +package protocol + +import ( + "testing" +) + +func TestMsgType_String(t *testing.T) { + tests := []struct { + name string + m MsgType + want string + }{ + {"Frame", MsgFrame, "Frame"}, + {"Handshake", MsgHandshake, "Handshake"}, + {"Ack", MsgAck, "Ack"}, + {"FrameRequest", MsgFrameRequest, "FrameRequest"}, + {"Resize", MsgResize, "Resize"}, + {"Disconnect", MsgDisconnect, "Disconnect"}, + {"Unknown_0xFF", MsgType(0xFF), "MsgType(0xFF)"}, + {"Unknown_0x00", MsgType(0x00), "MsgType(0x00)"}, + {"Unknown_0x07", MsgType(0x07), "MsgType(0x07)"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.m.String() + if got != tt.want { + t.Errorf("MsgType(%d).String() = %q, want %q", tt.m, got, tt.want) + } + }) + } +} + +func TestMsgType_Valid(t *testing.T) { + tests := []struct { + name string + m MsgType + want bool + }{ + {"Frame", MsgFrame, true}, + {"Handshake", MsgHandshake, true}, + {"Ack", MsgAck, true}, + {"FrameRequest", MsgFrameRequest, true}, + {"Resize", MsgResize, true}, + {"Disconnect", MsgDisconnect, true}, + {"Zero", MsgType(0x00), false}, + {"Unknown_0x07", MsgType(0x07), false}, + {"Unknown_0xFF", MsgType(0xFF), false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.m.Valid() + if got != tt.want { + t.Errorf("MsgType(%d).Valid() = %v, want %v", tt.m, got, tt.want) + } + }) + } +} + +func TestFlag_Operations(t *testing.T) { + t.Run("Has", func(t *testing.T) { + flags := FlagDirtyValid | FlagKeyframe + if !flags.Has(FlagDirtyValid) { + t.Error("expected FlagDirtyValid to be set") + } + if !flags.Has(FlagKeyframe) { + t.Error("expected FlagKeyframe to be set") + } + if flags.Has(FlagCompressed) { + t.Error("expected FlagCompressed to not be set") + } + }) + + t.Run("Set", func(t *testing.T) { + var f Flag + f = f.Set(FlagCompressed) + if !f.Has(FlagCompressed) { + t.Error("expected FlagCompressed after Set") + } + f = f.Set(FlagDirtyValid) + if !f.Has(FlagDirtyValid) { + t.Error("expected FlagDirtyValid after Set") + } + // Setting again is idempotent. + f = f.Set(FlagCompressed) + if f != FlagCompressed|FlagDirtyValid { + t.Errorf("expected 0x03, got 0x%02X", f) + } + }) + + t.Run("Clear", func(t *testing.T) { + f := FlagDirtyValid | FlagCompressed | FlagKeyframe + f = f.Clear(FlagCompressed) + if f.Has(FlagCompressed) { + t.Error("expected FlagCompressed to be cleared") + } + if !f.Has(FlagDirtyValid) { + t.Error("expected FlagDirtyValid to still be set") + } + if !f.Has(FlagKeyframe) { + t.Error("expected FlagKeyframe to still be set") + } + }) + + t.Run("ZeroHasNothing", func(t *testing.T) { + var f Flag + if f.Has(FlagDirtyValid) || f.Has(FlagCompressed) || f.Has(FlagKeyframe) { + t.Error("zero Flag should have no flags set") + } + }) +} + +func TestFlag_String(t *testing.T) { + tests := []struct { + name string + f Flag + want string + }{ + {"DirtyValid", FlagDirtyValid, "DirtyValid"}, + {"Compressed", FlagCompressed, "Compressed"}, + {"Keyframe", FlagKeyframe, "Keyframe"}, + {"Unknown", Flag(0x10), "Flag(0x10)"}, + {"Combined", FlagDirtyValid | FlagCompressed, "Flag(0x03)"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.f.String() + if got != tt.want { + t.Errorf("Flag(0x%02X).String() = %q, want %q", tt.f, got, tt.want) + } + }) + } +} + +func TestPixelFormat_String(t *testing.T) { + tests := []struct { + name string + p PixelFormat + want string + }{ + {"RGBA8", PixelRGBA8, "RGBA8"}, + {"BGRA8", PixelBGRA8, "BGRA8"}, + {"Unknown", PixelFormat(0x99), "PixelFormat(0x99)"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.p.String() + if got != tt.want { + t.Errorf("PixelFormat(0x%02X).String() = %q, want %q", tt.p, got, tt.want) + } + }) + } +} + +func TestPixelFormat_Valid(t *testing.T) { + tests := []struct { + name string + p PixelFormat + want bool + }{ + {"RGBA8", PixelRGBA8, true}, + {"BGRA8", PixelBGRA8, true}, + {"Unknown", PixelFormat(0x02), false}, + {"Max", PixelFormat(0xFF), false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.p.Valid() + if got != tt.want { + t.Errorf("PixelFormat(0x%02X).Valid() = %v, want %v", tt.p, got, tt.want) + } + }) + } +} + +func TestCompression_String(t *testing.T) { + tests := []struct { + name string + c Compression + want string + }{ + {"None", CompressionNone, "None"}, + {"LZ4", CompressionLZ4, "LZ4"}, + {"Zstd", CompressionZstd, "Zstd"}, + {"Unknown", Compression(0xAB), "Compression(0xAB)"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.c.String() + if got != tt.want { + t.Errorf("Compression(0x%02X).String() = %q, want %q", tt.c, got, tt.want) + } + }) + } +} + +func TestCompression_Valid(t *testing.T) { + tests := []struct { + name string + c Compression + want bool + }{ + {"None", CompressionNone, true}, + {"LZ4", CompressionLZ4, true}, + {"Zstd", CompressionZstd, true}, + {"Unknown_0x03", Compression(0x03), false}, + {"Unknown_0xFF", Compression(0xFF), false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.c.Valid() + if got != tt.want { + t.Errorf("Compression(0x%02X).Valid() = %v, want %v", tt.c, got, tt.want) + } + }) + } +} + +func TestFlagValues(t *testing.T) { + // Verify flag bit positions match the spec. + if FlagDirtyValid != 0x01 { + t.Errorf("FlagDirtyValid = 0x%02X, want 0x01", FlagDirtyValid) + } + if FlagCompressed != 0x02 { + t.Errorf("FlagCompressed = 0x%02X, want 0x02", FlagCompressed) + } + if FlagKeyframe != 0x04 { + t.Errorf("FlagKeyframe = 0x%02X, want 0x04", FlagKeyframe) + } +} + +func TestMsgTypeValues(t *testing.T) { + // Verify message type values match the spec. + if MsgFrame != 0x01 { + t.Errorf("MsgFrame = 0x%02X, want 0x01", MsgFrame) + } + if MsgHandshake != 0x02 { + t.Errorf("MsgHandshake = 0x%02X, want 0x02", MsgHandshake) + } + if MsgAck != 0x03 { + t.Errorf("MsgAck = 0x%02X, want 0x03", MsgAck) + } + if MsgFrameRequest != 0x04 { + t.Errorf("MsgFrameRequest = 0x%02X, want 0x04", MsgFrameRequest) + } + if MsgResize != 0x05 { + t.Errorf("MsgResize = 0x%02X, want 0x05", MsgResize) + } + if MsgDisconnect != 0x06 { + t.Errorf("MsgDisconnect = 0x%02X, want 0x06", MsgDisconnect) + } +} + +func TestPixelFormatValues(t *testing.T) { + if PixelRGBA8 != 0x00 { + t.Errorf("PixelRGBA8 = 0x%02X, want 0x00", PixelRGBA8) + } + if PixelBGRA8 != 0x01 { + t.Errorf("PixelBGRA8 = 0x%02X, want 0x01", PixelBGRA8) + } +} + +func TestCompressionValues(t *testing.T) { + if CompressionNone != 0x00 { + t.Errorf("CompressionNone = 0x%02X, want 0x00", CompressionNone) + } + if CompressionLZ4 != 0x01 { + t.Errorf("CompressionLZ4 = 0x%02X, want 0x01", CompressionLZ4) + } + if CompressionZstd != 0x02 { + t.Errorf("CompressionZstd = 0x%02X, want 0x02", CompressionZstd) + } +} diff --git a/internal/transport/socket/conn.go b/internal/transport/socket/conn.go new file mode 100644 index 0000000..f0445e8 --- /dev/null +++ b/internal/transport/socket/conn.go @@ -0,0 +1,183 @@ +package socket + +import ( + "bufio" + "fmt" + "io" + "net" + "sync" + + "github.com/gogpu/compose/internal/protocol" +) + +// bufferSize is the bufio.Reader buffer size. +// 256 KB handles most frames (including 192 KB payloads) in one syscall. +const bufferSize = 256 * 1024 + +// Conn wraps a net.Conn with framed header+payload read/write. +// Concurrent reads and writes are safe (separate locks). +// Concurrent reads from multiple goroutines are serialized by readMu. +// Concurrent writes from multiple goroutines are serialized by writeMu. +type Conn struct { + raw net.Conn + reader *bufio.Reader + writeMu sync.Mutex + readMu sync.Mutex + hdrBuf [protocol.HeaderSize]byte // reused scratch buffer for header encode/decode + hsBuf [protocol.HandshakeSize]byte +} + +// NewConn wraps an existing network connection with framed I/O. +func NewConn(c net.Conn) *Conn { + return &Conn{ + raw: c, + reader: bufio.NewReaderSize(c, bufferSize), + } +} + +// WriteFrame sends a header + payload atomically. +// The header is encoded into an internal buffer and written together with +// payload in a single locked section to prevent interleaving. +func (c *Conn) WriteFrame(hdr *protocol.Header, payload []byte) error { + c.writeMu.Lock() + defer c.writeMu.Unlock() + + if err := protocol.Encode(hdr, c.hdrBuf[:]); err != nil { + return fmt.Errorf("socket: encode header: %w", err) + } + + // Write header bytes. + if _, err := c.raw.Write(c.hdrBuf[:]); err != nil { + return fmt.Errorf("socket: write header: %w", err) + } + + // Write payload if present. + if len(payload) > 0 { + if _, err := c.raw.Write(payload); err != nil { + return fmt.Errorf("socket: write payload: %w", err) + } + } + + return nil +} + +// ReadFrame reads a header + payload from the connection. +// The returned payload slice is freshly allocated if PayloadSize > 0. +func (c *Conn) ReadFrame() (protocol.Header, []byte, error) { + return c.ReadFrameInto(nil) +} + +// ReadFrameInto reads a header + payload into a caller-provided buffer. +// If buf is nil or too small for the header's PayloadSize, a new buffer +// is allocated. The returned slice may be a sub-slice of buf or a new +// allocation. +func (c *Conn) ReadFrameInto(buf []byte) (protocol.Header, []byte, error) { + c.readMu.Lock() + defer c.readMu.Unlock() + + // Read exactly HeaderSize bytes into scratch buffer. + if _, err := io.ReadFull(c.reader, c.hdrBuf[:]); err != nil { + return protocol.Header{}, nil, fmt.Errorf("socket: read header: %w", err) + } + + hdr, err := protocol.Decode(c.hdrBuf[:]) + if err != nil { + return protocol.Header{}, nil, fmt.Errorf("socket: decode header: %w", err) + } + + // Read payload. + size := int(hdr.PayloadSize) + if size == 0 { + return hdr, nil, nil + } + + // Reuse or allocate payload buffer. + var payload []byte + if len(buf) >= size { + payload = buf[:size] + } else { + payload = make([]byte, size) + } + + if _, err := io.ReadFull(c.reader, payload); err != nil { + return protocol.Header{}, nil, fmt.Errorf("socket: read payload (%d bytes): %w", size, err) + } + + return hdr, payload, nil +} + +// WriteHandshakeHello sends a HelloMsg on the connection. +func (c *Conn) WriteHandshakeHello(msg *protocol.HelloMsg) error { + c.writeMu.Lock() + defer c.writeMu.Unlock() + + if err := protocol.EncodeHello(msg, c.hsBuf[:]); err != nil { + return fmt.Errorf("socket: encode hello: %w", err) + } + + if _, err := c.raw.Write(c.hsBuf[:]); err != nil { + return fmt.Errorf("socket: write hello: %w", err) + } + + return nil +} + +// ReadHandshakeHello reads a HelloMsg from the connection. +func (c *Conn) ReadHandshakeHello() (protocol.HelloMsg, error) { + c.readMu.Lock() + defer c.readMu.Unlock() + + if _, err := io.ReadFull(c.reader, c.hsBuf[:]); err != nil { + return protocol.HelloMsg{}, fmt.Errorf("socket: read hello: %w", err) + } + + msg, err := protocol.DecodeHello(c.hsBuf[:]) + if err != nil { + return protocol.HelloMsg{}, fmt.Errorf("socket: decode hello: %w", err) + } + + return msg, nil +} + +// WriteHandshakeWelcome sends a WelcomeMsg on the connection. +func (c *Conn) WriteHandshakeWelcome(msg *protocol.WelcomeMsg) error { + c.writeMu.Lock() + defer c.writeMu.Unlock() + + if err := protocol.EncodeWelcome(msg, c.hsBuf[:]); err != nil { + return fmt.Errorf("socket: encode welcome: %w", err) + } + + if _, err := c.raw.Write(c.hsBuf[:]); err != nil { + return fmt.Errorf("socket: write welcome: %w", err) + } + + return nil +} + +// ReadHandshakeWelcome reads a WelcomeMsg from the connection. +func (c *Conn) ReadHandshakeWelcome() (protocol.WelcomeMsg, error) { + c.readMu.Lock() + defer c.readMu.Unlock() + + if _, err := io.ReadFull(c.reader, c.hsBuf[:]); err != nil { + return protocol.WelcomeMsg{}, fmt.Errorf("socket: read welcome: %w", err) + } + + msg, err := protocol.DecodeWelcome(c.hsBuf[:]) + if err != nil { + return protocol.WelcomeMsg{}, fmt.Errorf("socket: decode welcome: %w", err) + } + + return msg, nil +} + +// Close closes the underlying network connection. +func (c *Conn) Close() error { + return c.raw.Close() +} + +// RemoteAddr returns the remote address of the underlying connection. +func (c *Conn) RemoteAddr() net.Addr { + return c.raw.RemoteAddr() +} diff --git a/internal/transport/socket/conn_test.go b/internal/transport/socket/conn_test.go new file mode 100644 index 0000000..0843b7a --- /dev/null +++ b/internal/transport/socket/conn_test.go @@ -0,0 +1,738 @@ +package socket + +import ( + "bytes" + "errors" + "io" + "net" + "sync" + "testing" + + "github.com/gogpu/compose/internal/protocol" +) + +// newPipeConns creates a pair of Conn values connected by net.Pipe. +func newPipeConns(t *testing.T) (client, server *Conn) { + t.Helper() + c1, c2 := net.Pipe() + t.Cleanup(func() { + c1.Close() + c2.Close() + }) + return NewConn(c1), NewConn(c2) +} + +// makeHeader returns a minimal valid header for testing. +func makeHeader(msgType protocol.MsgType, payloadSize uint32) protocol.Header { + return protocol.Header{ + Magic: protocol.Magic, + Version: protocol.ProtocolVersion, + MsgType: msgType, + PayloadSize: payloadSize, + } +} + +func TestWriteReadFrame_RoundTrip(t *testing.T) { + tests := []struct { + name string + msgType protocol.MsgType + payload []byte + }{ + { + name: "frame with payload", + msgType: protocol.MsgFrame, + payload: []byte("hello compose"), + }, + { + name: "ack empty payload", + msgType: protocol.MsgAck, + payload: nil, + }, + { + name: "frame request empty", + msgType: protocol.MsgFrameRequest, + payload: []byte{}, + }, + { + name: "single byte payload", + msgType: protocol.MsgFrame, + payload: []byte{0xFF}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client, server := newPipeConns(t) + + payloadLen := uint32(len(tt.payload)) + hdr := makeHeader(tt.msgType, payloadLen) + + // Write in background. + errCh := make(chan error, 1) + go func() { + errCh <- client.WriteFrame(&hdr, tt.payload) + }() + + gotHdr, gotPayload, err := server.ReadFrame() + if err != nil { + t.Fatalf("ReadFrame: %v", err) + } + + if werr := <-errCh; werr != nil { + t.Fatalf("WriteFrame: %v", werr) + } + + if gotHdr.MsgType != tt.msgType { + t.Errorf("MsgType = %v, want %v", gotHdr.MsgType, tt.msgType) + } + if gotHdr.PayloadSize != payloadLen { + t.Errorf("PayloadSize = %d, want %d", gotHdr.PayloadSize, payloadLen) + } + + if payloadLen == 0 { + if gotPayload != nil { + t.Errorf("payload = %v, want nil", gotPayload) + } + } else if !bytes.Equal(gotPayload, tt.payload) { + t.Errorf("payload mismatch: got %d bytes, want %d bytes", len(gotPayload), len(tt.payload)) + } + }) + } +} + +func TestWriteReadFrame_LargePayload(t *testing.T) { + client, server := newPipeConns(t) + + // 1 MB payload. + payload := make([]byte, 1<<20) + for i := range payload { + payload[i] = byte(i % 256) + } + + hdr := makeHeader(protocol.MsgFrame, uint32(len(payload))) + + errCh := make(chan error, 1) + go func() { + errCh <- client.WriteFrame(&hdr, payload) + }() + + gotHdr, gotPayload, err := server.ReadFrame() + if err != nil { + t.Fatalf("ReadFrame: %v", err) + } + if werr := <-errCh; werr != nil { + t.Fatalf("WriteFrame: %v", werr) + } + + if gotHdr.PayloadSize != uint32(len(payload)) { + t.Errorf("PayloadSize = %d, want %d", gotHdr.PayloadSize, len(payload)) + } + if !bytes.Equal(gotPayload, payload) { + t.Error("large payload content mismatch") + } +} + +func TestReadFrameInto_ReuseBuffer(t *testing.T) { + client, server := newPipeConns(t) + + payload := []byte("reuse this buffer please") + hdr := makeHeader(protocol.MsgFrame, uint32(len(payload))) + + errCh := make(chan error, 1) + go func() { + errCh <- client.WriteFrame(&hdr, payload) + }() + + // Provide a sufficiently large buffer. + buf := make([]byte, 1024) + _, gotPayload, err := server.ReadFrameInto(buf) + if err != nil { + t.Fatalf("ReadFrameInto: %v", err) + } + if werr := <-errCh; werr != nil { + t.Fatalf("WriteFrame: %v", werr) + } + + if !bytes.Equal(gotPayload, payload) { + t.Error("payload mismatch") + } + + // Verify the returned slice shares backing array with buf. + if &gotPayload[0] != &buf[0] { + t.Error("ReadFrameInto did not reuse provided buffer") + } +} + +func TestReadFrameInto_SmallBuffer_Allocates(t *testing.T) { + client, server := newPipeConns(t) + + payload := []byte("too big for tiny buffer") + hdr := makeHeader(protocol.MsgFrame, uint32(len(payload))) + + errCh := make(chan error, 1) + go func() { + errCh <- client.WriteFrame(&hdr, payload) + }() + + // Buffer smaller than payload. + buf := make([]byte, 4) + _, gotPayload, err := server.ReadFrameInto(buf) + if err != nil { + t.Fatalf("ReadFrameInto: %v", err) + } + if werr := <-errCh; werr != nil { + t.Fatalf("WriteFrame: %v", werr) + } + + if !bytes.Equal(gotPayload, payload) { + t.Error("payload mismatch") + } +} + +func TestConcurrentWriteRead(t *testing.T) { + client, server := newPipeConns(t) + + const numFrames = 100 + payload := []byte("concurrent test payload") + hdr := makeHeader(protocol.MsgFrame, uint32(len(payload))) + + // Writer goroutine: send numFrames frames. + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + for i := range numFrames { + h := hdr + h.Sequence = uint64(i) + if err := client.WriteFrame(&h, payload); err != nil { + t.Errorf("WriteFrame[%d]: %v", i, err) + return + } + } + }() + + // Reader: receive numFrames frames. + for i := range numFrames { + gotHdr, gotPayload, err := server.ReadFrame() + if err != nil { + t.Fatalf("ReadFrame[%d]: %v", i, err) + } + if gotHdr.Sequence != uint64(i) { + t.Errorf("frame %d: Sequence = %d, want %d", i, gotHdr.Sequence, i) + } + if !bytes.Equal(gotPayload, payload) { + t.Errorf("frame %d: payload mismatch", i) + } + } + + wg.Wait() +} + +func TestConcurrentWriters(t *testing.T) { + client, server := newPipeConns(t) + + const numWriters = 4 + const framesPerWriter = 25 + payload := []byte("multi-writer") + hdr := makeHeader(protocol.MsgFrame, uint32(len(payload))) + + var wg sync.WaitGroup + for w := range numWriters { + wg.Add(1) + go func(writerID int) { + defer wg.Done() + for i := range framesPerWriter { + h := hdr + h.ModuleID = uint64(writerID) + h.Sequence = uint64(i) + if err := client.WriteFrame(&h, payload); err != nil { + t.Errorf("writer %d frame %d: %v", writerID, i, err) + return + } + } + }(w) + } + + // Read all frames. + total := numWriters * framesPerWriter + for i := range total { + _, gotPayload, err := server.ReadFrame() + if err != nil { + t.Fatalf("ReadFrame[%d]: %v", i, err) + } + if !bytes.Equal(gotPayload, payload) { + t.Errorf("frame %d: payload mismatch", i) + } + } + + wg.Wait() +} + +func TestHandshake_RoundTrip(t *testing.T) { + client, server := newPipeConns(t) + + // Module sends Hello. + hello := protocol.HelloMsg{ + Magic: protocol.Magic, + Version: protocol.ProtocolVersion, + Width: 400, + Height: 120, + PreferredFPS: 60, + Transport: protocol.TransportSocket, + } + protocol.SetName(&hello, "test-module") + + errCh := make(chan error, 1) + go func() { + errCh <- client.WriteHandshakeHello(&hello) + }() + + gotHello, err := server.ReadHandshakeHello() + if err != nil { + t.Fatalf("ReadHandshakeHello: %v", err) + } + if werr := <-errCh; werr != nil { + t.Fatalf("WriteHandshakeHello: %v", werr) + } + + if gotHello.Magic != protocol.Magic { + t.Error("hello: magic mismatch") + } + if protocol.GetName(&gotHello) != "test-module" { + t.Errorf("hello: name = %q, want %q", protocol.GetName(&gotHello), "test-module") + } + if gotHello.Width != 400 { + t.Errorf("hello: Width = %d, want 400", gotHello.Width) + } + if gotHello.Height != 120 { + t.Errorf("hello: Height = %d, want 120", gotHello.Height) + } + if gotHello.PreferredFPS != 60 { + t.Errorf("hello: PreferredFPS = %d, want 60", gotHello.PreferredFPS) + } + + // Compositor sends Welcome. + welcome := protocol.WelcomeMsg{ + Magic: protocol.Magic, + Version: protocol.ProtocolVersion, + ModuleID: 42, + Accepted: 1, + Transport: protocol.TransportSocket, + MinVersion: 1, + MaxVersion: 1, + } + + go func() { + errCh <- server.WriteHandshakeWelcome(&welcome) + }() + + gotWelcome, err := client.ReadHandshakeWelcome() + if err != nil { + t.Fatalf("ReadHandshakeWelcome: %v", err) + } + if werr := <-errCh; werr != nil { + t.Fatalf("WriteHandshakeWelcome: %v", werr) + } + + if gotWelcome.ModuleID != 42 { + t.Errorf("welcome: ModuleID = %d, want 42", gotWelcome.ModuleID) + } + if gotWelcome.Accepted != 1 { + t.Errorf("welcome: Accepted = %d, want 1", gotWelcome.Accepted) + } +} + +func TestReadFrame_ConnectionClosed(t *testing.T) { + client, server := newPipeConns(t) + + // Close the writing side immediately. + client.Close() + + _, _, err := server.ReadFrame() + if err == nil { + t.Fatal("expected error on closed connection, got nil") + } + + // Should wrap io.EOF or io.ErrUnexpectedEOF. + if !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrUnexpectedEOF) { + // net.Pipe close produces io.EOF via io.ReadFull -> io.ErrUnexpectedEOF, + // but the exact wrapping depends on how many bytes were read. + // Accept any error as valid — the point is we don't hang. + t.Logf("got error (acceptable): %v", err) + } +} + +func TestReadHandshakeHello_ConnectionClosed(t *testing.T) { + client, server := newPipeConns(t) + client.Close() + + _, err := server.ReadHandshakeHello() + if err == nil { + t.Fatal("expected error on closed connection, got nil") + } +} + +func TestReadHandshakeWelcome_ConnectionClosed(t *testing.T) { + client, server := newPipeConns(t) + client.Close() + + _, err := server.ReadHandshakeWelcome() + if err == nil { + t.Fatal("expected error on closed connection, got nil") + } +} + +func TestConn_RemoteAddr(t *testing.T) { + client, _ := newPipeConns(t) + + addr := client.RemoteAddr() + if addr == nil { + t.Fatal("RemoteAddr returned nil") + } +} + +func TestConn_Close(t *testing.T) { + c1, c2 := net.Pipe() + conn := NewConn(c1) + + // Close should not error on first call. + if err := conn.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + + // The other end should see EOF. + buf := make([]byte, 1) + _, err := c2.Read(buf) + if err == nil { + t.Fatal("expected error after close, got nil") + } + + c2.Close() +} + +func TestWriteReadFrame_AllHeaderFields(t *testing.T) { + client, server := newPipeConns(t) + + payload := []byte{0xDE, 0xAD, 0xBE, 0xEF} + hdr := protocol.Header{ + Magic: protocol.Magic, + Version: protocol.ProtocolVersion, + MsgType: protocol.MsgFrame, + Flags: protocol.FlagDirtyValid | protocol.FlagKeyframe, + ModuleID: 12345, + Sequence: 99, + TimestampNs: 1_000_000_000, + Width: 1920, + Height: 1080, + Stride: 1920 * 4, + DirtyX: 10, + DirtyY: 20, + DirtyW: 100, + DirtyH: 200, + PixelFormat: protocol.PixelRGBA8, + Compression: protocol.CompressionNone, + PayloadSize: uint32(len(payload)), + UncompressedSize: uint32(len(payload)), + } + + errCh := make(chan error, 1) + go func() { + errCh <- client.WriteFrame(&hdr, payload) + }() + + got, gotPayload, err := server.ReadFrame() + if err != nil { + t.Fatalf("ReadFrame: %v", err) + } + if werr := <-errCh; werr != nil { + t.Fatalf("WriteFrame: %v", werr) + } + + // Verify all fields survived the round trip. + if got.ModuleID != 12345 { + t.Errorf("ModuleID = %d, want 12345", got.ModuleID) + } + if got.Sequence != 99 { + t.Errorf("Sequence = %d, want 99", got.Sequence) + } + if got.TimestampNs != 1_000_000_000 { + t.Errorf("TimestampNs = %d, want 1000000000", got.TimestampNs) + } + if got.Width != 1920 { + t.Errorf("Width = %d, want 1920", got.Width) + } + if got.Height != 1080 { + t.Errorf("Height = %d, want 1080", got.Height) + } + if got.Stride != 1920*4 { + t.Errorf("Stride = %d, want %d", got.Stride, 1920*4) + } + if got.DirtyX != 10 || got.DirtyY != 20 || got.DirtyW != 100 || got.DirtyH != 200 { + t.Errorf("DirtyRect = (%d,%d,%d,%d), want (10,20,100,200)", + got.DirtyX, got.DirtyY, got.DirtyW, got.DirtyH) + } + if got.Flags != (protocol.FlagDirtyValid | protocol.FlagKeyframe) { + t.Errorf("Flags = %d, want %d", got.Flags, protocol.FlagDirtyValid|protocol.FlagKeyframe) + } + if !bytes.Equal(gotPayload, payload) { + t.Error("payload mismatch") + } +} + +func TestWriteReadFrame_MultipleSequential(t *testing.T) { + client, server := newPipeConns(t) + + const count = 10 + errCh := make(chan error, 1) + go func() { + for i := range count { + p := []byte{byte(i)} + h := makeHeader(protocol.MsgFrame, 1) + h.Sequence = uint64(i) + if err := client.WriteFrame(&h, p); err != nil { + errCh <- err + return + } + } + errCh <- nil + }() + + for i := range count { + hdr, payload, err := server.ReadFrame() + if err != nil { + t.Fatalf("ReadFrame[%d]: %v", i, err) + } + if hdr.Sequence != uint64(i) { + t.Errorf("frame %d: Sequence = %d", i, hdr.Sequence) + } + if len(payload) != 1 || payload[0] != byte(i) { + t.Errorf("frame %d: payload mismatch", i) + } + } + + if werr := <-errCh; werr != nil { + t.Fatalf("WriteFrame: %v", werr) + } +} + +func TestWriteFrame_ClosedConn_HeaderWriteError(t *testing.T) { + c1, c2 := net.Pipe() + conn := NewConn(c1) + c2.Close() + + // Close the underlying conn so the header write fails. + c1.Close() + + hdr := makeHeader(protocol.MsgFrame, 5) + err := conn.WriteFrame(&hdr, []byte("hello")) + if err == nil { + t.Fatal("expected error writing to closed conn, got nil") + } +} + +func TestWriteFrame_ClosedConn_PayloadWriteError(t *testing.T) { + // Use a real socket pair so we can close mid-flight. + addr := testSocketPath(t) + ln, err := Listen(addr) + if err != nil { + t.Fatalf("Listen: %v", err) + } + defer ln.Close() + + acceptCh := make(chan *Conn, 1) + go func() { + c, err := ln.Accept() + if err != nil { + return + } + acceptCh <- c + }() + + dialer := NewDialer(addr) + clientConn, err := dialer.Dial() + if err != nil { + t.Fatalf("Dial: %v", err) + } + + serverConn := <-acceptCh + // Close the server side so writes from client eventually fail. + serverConn.Close() + + // Write a large payload — the header write may succeed but the + // large payload write should fail because the peer is closed. + bigPayload := make([]byte, 1<<20) // 1 MB + hdr := makeHeader(protocol.MsgFrame, uint32(len(bigPayload))) + + // Retry a few times — the first write might succeed due to kernel buffering. + var lastErr error + for range 100 { + lastErr = clientConn.WriteFrame(&hdr, bigPayload) + if lastErr != nil { + break + } + } + if lastErr == nil { + t.Log("write never failed (kernel buffered everything) — skipping payload error test") + } + + clientConn.Close() +} + +func TestWriteHandshakeHello_ClosedConn(t *testing.T) { + c1, c2 := net.Pipe() + conn := NewConn(c1) + c2.Close() + c1.Close() + + hello := protocol.HelloMsg{ + Magic: protocol.Magic, + Version: protocol.ProtocolVersion, + } + err := conn.WriteHandshakeHello(&hello) + if err == nil { + t.Fatal("expected error writing hello to closed conn, got nil") + } +} + +func TestWriteHandshakeWelcome_ClosedConn(t *testing.T) { + c1, c2 := net.Pipe() + conn := NewConn(c1) + c2.Close() + c1.Close() + + welcome := protocol.WelcomeMsg{ + Magic: protocol.Magic, + Version: protocol.ProtocolVersion, + } + err := conn.WriteHandshakeWelcome(&welcome) + if err == nil { + t.Fatal("expected error writing welcome to closed conn, got nil") + } +} + +func TestReadFrame_InvalidMagic(t *testing.T) { + c1, c2 := net.Pipe() + defer c1.Close() + defer c2.Close() + + // Write 64 bytes of garbage (invalid magic). + go func() { + garbage := make([]byte, protocol.HeaderSize) + garbage[0] = 0xFF // wrong magic + garbage[1] = 0xFF + garbage[2] = 0xFF + garbage[3] = 0xFF + // Set a valid MsgType at offset 6 to avoid ErrUnknownMsgType. + garbage[6] = uint8(protocol.MsgFrame) + c1.Write(garbage) + }() + + conn := NewConn(c2) + _, _, err := conn.ReadFrame() + if err == nil { + t.Fatal("expected error for invalid magic, got nil") + } +} + +func TestReadFrame_PayloadReadError(t *testing.T) { + c1, c2 := net.Pipe() + defer c2.Close() + + conn := NewConn(c2) + + // Write a valid header claiming 1000 bytes of payload, then close. + go func() { + hdr := makeHeader(protocol.MsgFrame, 1000) + var buf [protocol.HeaderSize]byte + protocol.Encode(&hdr, buf[:]) + c1.Write(buf[:]) + // Write only 10 bytes of the claimed 1000, then close. + c1.Write([]byte("short data")) + c1.Close() + }() + + _, _, err := conn.ReadFrame() + if err == nil { + t.Fatal("expected error for truncated payload, got nil") + } +} + +func TestReadHandshakeHello_InvalidMagic(t *testing.T) { + c1, c2 := net.Pipe() + defer c1.Close() + defer c2.Close() + + go func() { + garbage := make([]byte, protocol.HandshakeSize) + garbage[0] = 0xBA + garbage[1] = 0xAD + c1.Write(garbage) + }() + + conn := NewConn(c2) + _, err := conn.ReadHandshakeHello() + if err == nil { + t.Fatal("expected error for invalid hello magic, got nil") + } +} + +func TestReadHandshakeWelcome_InvalidMagic(t *testing.T) { + c1, c2 := net.Pipe() + defer c1.Close() + defer c2.Close() + + go func() { + garbage := make([]byte, protocol.HandshakeSize) + garbage[0] = 0xDE + garbage[1] = 0xAD + c1.Write(garbage) + }() + + conn := NewConn(c2) + _, err := conn.ReadHandshakeWelcome() + if err == nil { + t.Fatal("expected error for invalid welcome magic, got nil") + } +} + +func BenchmarkWriteReadFrame(b *testing.B) { + c1, c2 := net.Pipe() + defer c1.Close() + defer c2.Close() + + client := NewConn(c1) + server := NewConn(c2) + + // 192 KB payload (typical frame: 320x150 RGBA). + payload := make([]byte, 192*1024) + for i := range payload { + payload[i] = byte(i) + } + hdr := makeHeader(protocol.MsgFrame, uint32(len(payload))) + buf := make([]byte, len(payload)) + + // Writer goroutine. + done := make(chan struct{}) + go func() { + defer close(done) + for range b.N { + if err := client.WriteFrame(&hdr, payload); err != nil { + b.Errorf("WriteFrame: %v", err) + return + } + } + }() + + b.SetBytes(int64(protocol.HeaderSize + len(payload))) + b.ResetTimer() + + for range b.N { + _, _, err := server.ReadFrameInto(buf) + if err != nil { + b.Fatalf("ReadFrameInto: %v", err) + } + } + + b.StopTimer() + <-done +} diff --git a/internal/transport/socket/dialer.go b/internal/transport/socket/dialer.go new file mode 100644 index 0000000..ea7d9d7 --- /dev/null +++ b/internal/transport/socket/dialer.go @@ -0,0 +1,44 @@ +package socket + +import ( + "fmt" + "net" + "time" +) + +// defaultDialTimeout is the default connection timeout. +const defaultDialTimeout = 5 * time.Second + +// Dialer connects to a compositor's Unix domain socket. +// Reconnection with backoff belongs in higher layers (internal/conn.Manager); +// the dialer is a simple one-shot connect helper. +type Dialer struct { + addr string + timeout time.Duration +} + +// NewDialer creates a dialer for the given Unix domain socket address. +// The default timeout is 5 seconds. +func NewDialer(addr string) *Dialer { + return &Dialer{ + addr: addr, + timeout: defaultDialTimeout, + } +} + +// Dial connects to the compositor using the default timeout. +// The returned [Conn] is ready for handshake (WriteHandshakeHello / +// ReadHandshakeWelcome). +func (d *Dialer) Dial() (*Conn, error) { + return d.DialWithTimeout(d.timeout) +} + +// DialWithTimeout connects to the compositor with a custom timeout. +// A zero or negative timeout means no deadline. +func (d *Dialer) DialWithTimeout(timeout time.Duration) (*Conn, error) { + c, err := net.DialTimeout("unix", d.addr, timeout) + if err != nil { + return nil, fmt.Errorf("socket: dial %s: %w", d.addr, err) + } + return NewConn(c), nil +} diff --git a/internal/transport/socket/dialer_test.go b/internal/transport/socket/dialer_test.go new file mode 100644 index 0000000..33a51ed --- /dev/null +++ b/internal/transport/socket/dialer_test.go @@ -0,0 +1,248 @@ +package socket + +import ( + "testing" + "time" + + "github.com/gogpu/compose/internal/protocol" +) + +func TestNewDialer(t *testing.T) { + d := NewDialer("/tmp/test.sock") + if d.addr != "/tmp/test.sock" { + t.Errorf("addr = %q, want %q", d.addr, "/tmp/test.sock") + } + if d.timeout != defaultDialTimeout { + t.Errorf("timeout = %v, want %v", d.timeout, defaultDialTimeout) + } +} + +func TestDial_Success(t *testing.T) { + addr := testSocketPath(t) + + ln, err := Listen(addr) + if err != nil { + t.Fatalf("Listen: %v", err) + } + defer ln.Close() + + // Accept in background. + type result struct { + conn *Conn + err error + } + ch := make(chan result, 1) + go func() { + c, err := ln.Accept() + ch <- result{c, err} + }() + + dialer := NewDialer(addr) + conn, err := dialer.Dial() + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer conn.Close() + + r := <-ch + if r.err != nil { + t.Fatalf("Accept: %v", r.err) + } + t.Cleanup(func() { _ = r.conn.Close() }) +} + +func TestDial_Timeout(t *testing.T) { + // Use a non-existent socket path. + addr := testSocketPath(t) + + dialer := NewDialer(addr) + _, err := dialer.DialWithTimeout(100 * time.Millisecond) + if err == nil { + t.Fatal("expected error dialing non-existent socket, got nil") + } +} + +func TestDialWithTimeout_CustomTimeout(t *testing.T) { + addr := testSocketPath(t) + + ln, err := Listen(addr) + if err != nil { + t.Fatalf("Listen: %v", err) + } + defer ln.Close() + + // Accept in background. + ch := make(chan error, 1) + go func() { + c, err := ln.Accept() + if err == nil { + c.Close() + } + ch <- err + }() + + dialer := NewDialer(addr) + conn, err := dialer.DialWithTimeout(2 * time.Second) + if err != nil { + t.Fatalf("DialWithTimeout: %v", err) + } + conn.Close() + + if err := <-ch; err != nil { + t.Fatalf("Accept: %v", err) + } +} + +func TestDial_FullHandshake(t *testing.T) { + addr := testSocketPath(t) + + ln, err := Listen(addr) + if err != nil { + t.Fatalf("Listen: %v", err) + } + defer ln.Close() + + // Server Accept goroutine. + type result struct { + conn *Conn + err error + } + acceptCh := make(chan result, 1) + go func() { + c, err := ln.Accept() + acceptCh <- result{c, err} + }() + + // Client connects. + dialer := NewDialer(addr) + clientConn, err := dialer.Dial() + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer clientConn.Close() + + r := <-acceptCh + if r.err != nil { + t.Fatalf("Accept: %v", r.err) + } + serverConn := r.conn + defer serverConn.Close() + + // Full handshake: Hello then Welcome. + hello := protocol.HelloMsg{ + Magic: protocol.Magic, + Version: protocol.ProtocolVersion, + Width: 320, + Height: 240, + PreferredFPS: 1, + Transport: protocol.TransportSocket, + } + protocol.SetName(&hello, "clock") + + errCh := make(chan error, 1) + go func() { + errCh <- clientConn.WriteHandshakeHello(&hello) + }() + + gotHello, err := serverConn.ReadHandshakeHello() + if err != nil { + t.Fatalf("ReadHandshakeHello: %v", err) + } + if werr := <-errCh; werr != nil { + t.Fatalf("WriteHandshakeHello: %v", werr) + } + + if protocol.GetName(&gotHello) != "clock" { + t.Errorf("name = %q, want %q", protocol.GetName(&gotHello), "clock") + } + + welcome := protocol.WelcomeMsg{ + Magic: protocol.Magic, + Version: protocol.ProtocolVersion, + ModuleID: 7, + Accepted: 1, + Transport: protocol.TransportSocket, + MinVersion: 1, + MaxVersion: 1, + } + + go func() { + errCh <- serverConn.WriteHandshakeWelcome(&welcome) + }() + + gotWelcome, err := clientConn.ReadHandshakeWelcome() + if err != nil { + t.Fatalf("ReadHandshakeWelcome: %v", err) + } + if werr := <-errCh; werr != nil { + t.Fatalf("WriteHandshakeWelcome: %v", werr) + } + + if gotWelcome.ModuleID != 7 { + t.Errorf("ModuleID = %d, want 7", gotWelcome.ModuleID) + } + if gotWelcome.Accepted != 1 { + t.Errorf("Accepted = %d, want 1", gotWelcome.Accepted) + } + + // Post-handshake: send a frame. + payload := []byte("pixel data here") + hdr := protocol.Header{ + Magic: protocol.Magic, + Version: protocol.ProtocolVersion, + MsgType: protocol.MsgFrame, + ModuleID: 7, + Sequence: 1, + PayloadSize: uint32(len(payload)), + } + + go func() { + errCh <- clientConn.WriteFrame(&hdr, payload) + }() + + gotHdr, gotPayload, err := serverConn.ReadFrame() + if err != nil { + t.Fatalf("ReadFrame: %v", err) + } + if werr := <-errCh; werr != nil { + t.Fatalf("WriteFrame: %v", werr) + } + + if gotHdr.ModuleID != 7 { + t.Errorf("frame ModuleID = %d, want 7", gotHdr.ModuleID) + } + if string(gotPayload) != "pixel data here" { + t.Errorf("frame payload = %q, want %q", gotPayload, "pixel data here") + } +} + +func TestDial_ZeroTimeout(t *testing.T) { + addr := testSocketPath(t) + + ln, err := Listen(addr) + if err != nil { + t.Fatalf("Listen: %v", err) + } + defer ln.Close() + + ch := make(chan error, 1) + go func() { + c, err := ln.Accept() + if err == nil { + c.Close() + } + ch <- err + }() + + dialer := NewDialer(addr) + // Zero timeout = no deadline. + conn, err := dialer.DialWithTimeout(0) + if err != nil { + t.Fatalf("DialWithTimeout(0): %v", err) + } + conn.Close() + + if err := <-ch; err != nil { + t.Fatalf("Accept: %v", err) + } +} diff --git a/internal/transport/socket/doc.go b/internal/transport/socket/doc.go new file mode 100644 index 0000000..1b563a6 --- /dev/null +++ b/internal/transport/socket/doc.go @@ -0,0 +1,30 @@ +// Package socket provides a Unix domain socket transport for the compose +// protocol. It implements framed read/write of [protocol.Header] + payload +// messages over a single [net.Conn], along with server-side listening and +// client-side dialing. +// +// # Conn +// +// [Conn] wraps a [net.Conn] with framed I/O. Each frame on the wire is a +// 64-byte header (see [protocol.HeaderSize]) followed by PayloadSize bytes +// of payload. Reads and writes are independently locked, so concurrent +// producers and consumers are safe on the same connection. +// +// # Listener +// +// [Listener] binds a Unix domain socket (AF_UNIX) and accepts incoming +// connections. On startup it removes any stale socket file left by a +// previous crash, preventing "address already in use" errors. +// +// # Dialer +// +// [Dialer] connects to a compositor's Unix domain socket. It is intentionally +// simple — reconnection with backoff belongs in higher layers +// (internal/conn.Manager), keeping the transport layer focused on +// establishing a single connection. +// +// # Platform support +// +// Unix domain sockets are supported on Linux, macOS, and Windows 10 1803+ +// (AF_UNIX). No CGO required on any platform. +package socket diff --git a/internal/transport/socket/listener.go b/internal/transport/socket/listener.go new file mode 100644 index 0000000..c9ce23d --- /dev/null +++ b/internal/transport/socket/listener.go @@ -0,0 +1,64 @@ +package socket + +import ( + "fmt" + "net" + "os" +) + +// Listener accepts module connections on a Unix domain socket (AF_UNIX). +// On startup it removes any stale socket file left by a previous crash. +type Listener struct { + ln net.Listener + addr string +} + +// Listen creates a Unix domain socket listener at the given path. +// +// If a socket file already exists at addr, it is removed first. This is the +// standard Unix pattern to avoid "address already in use" after an unclean +// shutdown. On Windows, AF_UNIX is supported since Windows 10 version 1803. +func Listen(addr string) (*Listener, error) { + // Remove stale socket file. Ignore errors — the file may not exist, + // or the path may be invalid (net.Listen will catch that). + _ = os.Remove(addr) + + ln, err := net.Listen("unix", addr) + if err != nil { + return nil, fmt.Errorf("socket: listen %s: %w", addr, err) + } + + return &Listener{ + ln: ln, + addr: addr, + }, nil +} + +// Accept waits for and returns the next module connection. +// The returned [Conn] is ready for handshake (WriteHandshakeWelcome / +// ReadHandshakeHello). +func (l *Listener) Accept() (*Conn, error) { + c, err := l.ln.Accept() + if err != nil { + return nil, fmt.Errorf("socket: accept: %w", err) + } + return NewConn(c), nil +} + +// Close stops accepting connections and removes the socket file. +func (l *Listener) Close() error { + err := l.ln.Close() + + // Best-effort cleanup of socket file. + _ = os.Remove(l.addr) + + if err != nil { + return fmt.Errorf("socket: close listener: %w", err) + } + return nil +} + +// Addr returns the listener's network address. +func (l *Listener) Addr() net.Addr { + return l.ln.Addr() +} diff --git a/internal/transport/socket/listener_test.go b/internal/transport/socket/listener_test.go new file mode 100644 index 0000000..6de600f --- /dev/null +++ b/internal/transport/socket/listener_test.go @@ -0,0 +1,327 @@ +package socket + +import ( + "net" + "os" + "path/filepath" + "sync" + "testing" + "time" + + "github.com/gogpu/compose/internal/protocol" +) + +// testSocketPath returns a short temporary Unix socket path for testing. +// macOS limits Unix socket paths to 104 bytes. t.TempDir() on macOS CI +// produces paths like /var/folders/.../TestName.../test.sock which easily +// exceeds this limit. We use os.CreateTemp with a short prefix under /tmp +// (or %TEMP% on Windows) to guarantee a short path. +func testSocketPath(t *testing.T) string { + t.Helper() + f, err := os.CreateTemp("", "cs-*.sock") + if err != nil { + t.Fatal(err) + } + path := f.Name() + f.Close() + os.Remove(path) // Remove the file; we need the path for the socket. + t.Cleanup(func() { os.Remove(path) }) + return path +} + +func TestListen_Accept_Close(t *testing.T) { + addr := testSocketPath(t) + + ln, err := Listen(addr) + if err != nil { + t.Fatalf("Listen: %v", err) + } + + // Accept in background. + type result struct { + conn *Conn + err error + } + ch := make(chan result, 1) + go func() { + c, err := ln.Accept() + ch <- result{c, err} + }() + + // Connect from client side. + raw, err := net.Dial("unix", addr) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer raw.Close() + + r := <-ch + if r.err != nil { + t.Fatalf("Accept: %v", r.err) + } + defer r.conn.Close() + + // Verify the accepted connection is usable. + if r.conn.RemoteAddr() == nil { + t.Error("accepted conn RemoteAddr is nil") + } + + // Close listener. + if err := ln.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + + // Socket file should be removed. + if _, err := os.Stat(addr); !os.IsNotExist(err) { + t.Errorf("socket file still exists after Close: %v", err) + } +} + +func TestListen_StaleSocketCleanup(t *testing.T) { + addr := testSocketPath(t) + + // Create a stale socket file. + if err := os.WriteFile(addr, []byte("stale"), 0o600); err != nil { + t.Fatalf("create stale file: %v", err) + } + + // Listen should remove the stale file and succeed. + ln, err := Listen(addr) + if err != nil { + t.Fatalf("Listen with stale file: %v", err) + } + defer ln.Close() + + // Verify the listener is functional. + ch := make(chan error, 1) + go func() { + c, err := ln.Accept() + if err == nil { + c.Close() + } + ch <- err + }() + + raw, err := net.Dial("unix", addr) + if err != nil { + t.Fatalf("Dial: %v", err) + } + raw.Close() + + if err := <-ch; err != nil { + t.Fatalf("Accept: %v", err) + } +} + +func TestListen_MultipleConnections(t *testing.T) { + addr := testSocketPath(t) + + ln, err := Listen(addr) + if err != nil { + t.Fatalf("Listen: %v", err) + } + defer ln.Close() + + const numClients = 5 + + // Accept goroutine. + accepted := make([]*Conn, 0, numClients) + var mu sync.Mutex + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + for range numClients { + c, err := ln.Accept() + if err != nil { + t.Errorf("Accept: %v", err) + return + } + mu.Lock() + accepted = append(accepted, c) + mu.Unlock() + } + }() + + // Connect numClients clients. + clients := make([]net.Conn, numClients) + for i := range numClients { + c, err := net.Dial("unix", addr) + if err != nil { + t.Fatalf("Dial[%d]: %v", i, err) + } + clients[i] = c + } + + wg.Wait() + + mu.Lock() + if len(accepted) != numClients { + t.Errorf("accepted %d connections, want %d", len(accepted), numClients) + } + mu.Unlock() + + // Cleanup. + for _, c := range clients { + c.Close() + } + mu.Lock() + for _, c := range accepted { + c.Close() + } + mu.Unlock() +} + +func TestListen_CloseWhileAcceptBlocking(t *testing.T) { + addr := testSocketPath(t) + + ln, err := Listen(addr) + if err != nil { + t.Fatalf("Listen: %v", err) + } + + // Start blocking Accept. + ch := make(chan error, 1) + go func() { + _, err := ln.Accept() + ch <- err + }() + + // Give Accept time to block. On CI runners, goroutine scheduling + // can be slow, so 200ms gives ample headroom. + time.Sleep(200 * time.Millisecond) + + // Close should unblock Accept. + if err := ln.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + + select { + case err := <-ch: + if err == nil { + t.Error("Accept should return error after Close") + } + case <-time.After(5 * time.Second): + t.Fatal("Accept did not unblock after Close (timeout)") + } +} + +func TestListen_Addr(t *testing.T) { + addr := testSocketPath(t) + + ln, err := Listen(addr) + if err != nil { + t.Fatalf("Listen: %v", err) + } + defer ln.Close() + + got := ln.Addr() + if got == nil { + t.Fatal("Addr returned nil") + } + if got.Network() != "unix" { + t.Errorf("Network = %q, want %q", got.Network(), "unix") + } +} + +func TestListen_AcceptHandshake(t *testing.T) { + addr := testSocketPath(t) + + ln, err := Listen(addr) + if err != nil { + t.Fatalf("Listen: %v", err) + } + defer ln.Close() + + // Accept in background. + type result struct { + conn *Conn + err error + } + ch := make(chan result, 1) + go func() { + c, err := ln.Accept() + ch <- result{c, err} + }() + + // Connect and send Hello. + dialer := NewDialer(addr) + clientConn, err := dialer.Dial() + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer clientConn.Close() + + r := <-ch + if r.err != nil { + t.Fatalf("Accept: %v", r.err) + } + defer r.conn.Close() + + // Handshake: client Hello -> server Welcome. + hello := protocol.HelloMsg{ + Magic: protocol.Magic, + Version: protocol.ProtocolVersion, + Width: 800, + Height: 600, + PreferredFPS: 30, + Transport: protocol.TransportSocket, + } + protocol.SetName(&hello, "gallery") + + errCh := make(chan error, 1) + go func() { + errCh <- clientConn.WriteHandshakeHello(&hello) + }() + + gotHello, err := r.conn.ReadHandshakeHello() + if err != nil { + t.Fatalf("ReadHandshakeHello: %v", err) + } + if werr := <-errCh; werr != nil { + t.Fatalf("WriteHandshakeHello: %v", werr) + } + + if protocol.GetName(&gotHello) != "gallery" { + t.Errorf("name = %q, want %q", protocol.GetName(&gotHello), "gallery") + } + + // Server responds with Welcome. + welcome := protocol.WelcomeMsg{ + Magic: protocol.Magic, + Version: protocol.ProtocolVersion, + ModuleID: 1, + Accepted: 1, + Transport: protocol.TransportSocket, + MinVersion: 1, + MaxVersion: 1, + } + + go func() { + errCh <- r.conn.WriteHandshakeWelcome(&welcome) + }() + + gotWelcome, err := clientConn.ReadHandshakeWelcome() + if err != nil { + t.Fatalf("ReadHandshakeWelcome: %v", err) + } + if werr := <-errCh; werr != nil { + t.Fatalf("WriteHandshakeWelcome: %v", werr) + } + + if gotWelcome.ModuleID != 1 { + t.Errorf("ModuleID = %d, want 1", gotWelcome.ModuleID) + } + if gotWelcome.Accepted != 1 { + t.Errorf("Accepted = %d, want 1", gotWelcome.Accepted) + } +} + +func TestListen_InvalidPath(t *testing.T) { + // Path that is too long or invalid on most systems. + addr := filepath.Join(t.TempDir(), string(make([]byte, 300))) + _, err := Listen(addr) + if err == nil { + t.Fatal("expected error for invalid socket path, got nil") + } +} diff --git a/option.go b/option.go new file mode 100644 index 0000000..16b13a7 --- /dev/null +++ b/option.go @@ -0,0 +1,94 @@ +package compose + +// defaultMaxModules is the default maximum number of concurrent module +// connections a Server will accept. +const defaultMaxModules = 16 + +// defaultFPS is the default frames-per-second for a Client that does not +// specify WithFPS. +const defaultFPS = 1 + +// ServerOption configures a Server. Use With* functions to create options. +type ServerOption func(*serverConfig) + +// ClientOption configures a Client. Use With* functions to create options. +type ClientOption func(*clientConfig) + +type serverConfig struct { + maxModules int + compression string // "", "raw", "lz4" +} + +type clientConfig struct { + name string + width uint32 + height uint32 + fps uint16 +} + +func defaultServerConfig() serverConfig { + return serverConfig{ + maxModules: defaultMaxModules, + compression: "", + } +} + +func defaultClientConfig() clientConfig { + return clientConfig{ + name: "module", + width: 400, + height: 300, + fps: defaultFPS, + } +} + +// WithMaxModules sets the maximum number of concurrent module connections +// the Server will accept. Values less than 1 are clamped to 1. +// Default: 16. +func WithMaxModules(n int) ServerOption { + return func(c *serverConfig) { + if n < 1 { + n = 1 + } + c.maxModules = n + } +} + +// WithCompression enables frame payload compression on the Server. +// Supported values: "lz4". Any other value (including "") uses raw +// pass-through (no compression). +// Default: no compression. +func WithCompression(algo string) ServerOption { + return func(c *serverConfig) { + c.compression = algo + } +} + +// WithName sets the human-readable module name sent during the handshake. +// The name is used for logging, slot assignment, and module identification. +// Names longer than 63 bytes are silently truncated. +// Default: "module". +func WithName(name string) ClientOption { + return func(c *clientConfig) { + c.name = name + } +} + +// WithFrameSize sets the initial frame dimensions in pixels. +// The compositor may acknowledge different dimensions during handshake. +// Default: 400x300. +func WithFrameSize(width, height uint32) ClientOption { + return func(c *clientConfig) { + c.width = width + c.height = height + } +} + +// WithFPS sets the module's preferred frame rate. +// A value of 0 defaults to 1 FPS (suitable for static content). +// Default: 1. +func WithFPS(fps uint16) ClientOption { + return func(c *clientConfig) { + c.fps = fps + } +} diff --git a/server.go b/server.go new file mode 100644 index 0000000..464f09d --- /dev/null +++ b/server.go @@ -0,0 +1,418 @@ +package compose + +import ( + "fmt" + "image" + "sync" + "sync/atomic" + "time" + + "github.com/gogpu/compose/internal/codec" + "github.com/gogpu/compose/internal/conn" + "github.com/gogpu/compose/internal/flow" + "github.com/gogpu/compose/internal/protocol" + "github.com/gogpu/compose/internal/transport/socket" +) + +// moduleConn tracks a per-module connection on the server side. +type moduleConn struct { + conn *socket.Conn + moduleID uint64 + name string +} + +// Server is the compositor-side endpoint that accepts module connections +// and delivers frames. All methods are safe for concurrent use. +type Server struct { + listener *socket.Listener + manager *conn.Manager + flow *flow.Controller + codec codec.Codec + + mu sync.RWMutex + onFrame func(Frame) + onConnect func(id uint64, name string) + onDisconnect func(id uint64, name string) + + modulesMu sync.RWMutex + modules map[uint64]*moduleConn + + closed atomic.Bool + done chan struct{} + wg sync.WaitGroup +} + +// Listen creates a Server that accepts module connections on the given +// Unix domain socket address. The server immediately begins accepting +// connections in a background goroutine. +// +// Use ServerOption functions to configure behavior: +// +// srv, err := compose.Listen("/tmp/compose.sock", +// compose.WithMaxModules(8), +// compose.WithCompression("lz4"), +// ) +func Listen(addr string, opts ...ServerOption) (*Server, error) { + cfg := defaultServerConfig() + for _, o := range opts { + o(&cfg) + } + + ln, err := socket.Listen(addr) + if err != nil { + return nil, fmt.Errorf("compose: listen: %w", err) + } + + c := resolveCodec(cfg.compression) + + s := &Server{ + listener: ln, + manager: conn.NewManager(cfg.maxModules), + flow: flow.New(), + codec: c, + modules: make(map[uint64]*moduleConn), + done: make(chan struct{}), + } + + // Wire manager callbacks to server callbacks. + s.manager.OnConnect(func(id uint64, name string) { + s.mu.RLock() + cb := s.onConnect + s.mu.RUnlock() + if cb != nil { + cb(id, name) + } + }) + + s.manager.OnDisconnect(func(id uint64, name string) { + s.mu.RLock() + cb := s.onDisconnect + s.mu.RUnlock() + if cb != nil { + cb(id, name) + } + }) + + // Start accept loop. + s.wg.Add(1) + go s.acceptLoop() + + return s, nil +} + +// OnFrame registers a callback invoked when a module delivers a frame. +// The callback is called on an internal goroutine; it must not block for +// extended periods. Only one callback can be active; subsequent calls +// replace the previous one. Pass nil to remove. +func (s *Server) OnFrame(fn func(Frame)) { + s.mu.Lock() + defer s.mu.Unlock() + s.onFrame = fn +} + +// OnConnect registers a callback invoked when a module completes the +// handshake and becomes active. Only one callback can be active; +// subsequent calls replace the previous one. Pass nil to remove. +func (s *Server) OnConnect(fn func(id uint64, name string)) { + s.mu.Lock() + defer s.mu.Unlock() + s.onConnect = fn + + // Also update the manager callback so it fires the new function. + s.manager.OnConnect(func(id uint64, name string) { + s.mu.RLock() + cb := s.onConnect + s.mu.RUnlock() + if cb != nil { + cb(id, name) + } + }) +} + +// OnDisconnect registers a callback invoked when a module disconnects +// or crashes. Only one callback can be active; subsequent calls replace +// the previous one. Pass nil to remove. +func (s *Server) OnDisconnect(fn func(id uint64, name string)) { + s.mu.Lock() + defer s.mu.Unlock() + s.onDisconnect = fn + + // Also update the manager callback so it fires the new function. + s.manager.OnDisconnect(func(id uint64, name string) { + s.mu.RLock() + cb := s.onDisconnect + s.mu.RUnlock() + if cb != nil { + cb(id, name) + } + }) +} + +// RequestFrame sends a frame-request signal to the specified module +// (pull model). The module should respond with its next frame via +// PublishFrame. +// +// Returns ErrModuleNotFound if the module ID is not connected. +// Returns ErrClosed if the server has been shut down. +func (s *Server) RequestFrame(moduleID uint64) error { + if s.closed.Load() { + return ErrClosed + } + + s.modulesMu.RLock() + mc, ok := s.modules[moduleID] + s.modulesMu.RUnlock() + + if !ok { + return ErrModuleNotFound + } + + hdr := &protocol.Header{ + Magic: protocol.Magic, + Version: protocol.ProtocolVersion, + MsgType: protocol.MsgFrameRequest, + } + hdr.ModuleID = moduleID + + if err := mc.conn.WriteFrame(hdr, nil); err != nil { + return fmt.Errorf("compose: request frame: %w", err) + } + + s.flow.FrameRequested(moduleID) + return nil +} + +// Close shuts down the server, disconnects all modules, and releases +// resources. After Close returns, no more callbacks will be invoked. +func (s *Server) Close() error { + if !s.closed.CompareAndSwap(false, true) { + return ErrClosed + } + + close(s.done) + + // Close the listener to unblock Accept(). + err := s.listener.Close() + + // Close all module connections. + s.modulesMu.RLock() + for _, mc := range s.modules { + _ = mc.conn.Close() // best-effort + } + s.modulesMu.RUnlock() + + // Wait for all goroutines to finish. + s.wg.Wait() + + if err != nil { + return fmt.Errorf("compose: close server: %w", err) + } + return nil +} + +// acceptLoop runs in a goroutine, accepting new module connections. +func (s *Server) acceptLoop() { + defer s.wg.Done() + + for { + c, err := s.listener.Accept() + if err != nil { + // If server is closing, Accept error is expected. + if s.closed.Load() { + return + } + // Transient error — continue accepting. + continue + } + + s.wg.Add(1) + go s.handleModule(c) + } +} + +// handleModule runs in a goroutine for each connected module. +// It performs the handshake, then enters a frame read loop. +func (s *Server) handleModule(c *socket.Conn) { + defer s.wg.Done() + + // Read HelloMsg from module. + hello, err := c.ReadHandshakeHello() + if err != nil { + _ = c.Close() + return + } + + name := protocol.GetName(&hello) + + // Register module via Manager (allocates ID, fires OnConnect callback). + moduleID, err := s.manager.HandleConnect(name, hello.Width, hello.Height, hello.PreferredFPS) + if err != nil { + // Send rejection WelcomeMsg. + welcome := &protocol.WelcomeMsg{ + Magic: protocol.Magic, + Version: protocol.ProtocolVersion, + Accepted: 0, + } + _ = c.WriteHandshakeWelcome(welcome) // best-effort + _ = c.Close() + return + } + + // Register in flow controller. + s.flow.AddModule(moduleID, hello.PreferredFPS) + + // Track the module connection. + mc := &moduleConn{ + conn: c, + moduleID: moduleID, + name: name, + } + + s.modulesMu.Lock() + s.modules[moduleID] = mc + s.modulesMu.Unlock() + + // Send acceptance WelcomeMsg. + welcome := &protocol.WelcomeMsg{ + Magic: protocol.Magic, + Version: protocol.ProtocolVersion, + ModuleID: moduleID, + Accepted: 1, + Transport: protocol.TransportSocket, + MinVersion: protocol.ProtocolVersion, + MaxVersion: protocol.ProtocolVersion, + } + + if err := c.WriteHandshakeWelcome(welcome); err != nil { + s.cleanupModule(moduleID) + _ = c.Close() + return + } + + // Enter frame read loop. + s.readFrameLoop(mc) +} + +// readFrameLoop reads frames from a module connection until EOF or error. +func (s *Server) readFrameLoop(mc *moduleConn) { + defer func() { + s.cleanupModule(mc.moduleID) + _ = mc.conn.Close() + }() + + for { + // Check if server is shutting down. + select { + case <-s.done: + return + default: + } + + hdr, payload, err := mc.conn.ReadFrame() + if err != nil { + // EOF or connection error — module disconnected. + return + } + + if hdr.MsgType == protocol.MsgDisconnect { + return + } + + if hdr.MsgType != protocol.MsgFrame { + // Ignore non-frame messages in the frame read loop. + continue + } + + // Decompress payload if needed. + pixels, err := s.decodePayload(hdr, payload) + if err != nil { + // Corrupted frame — skip it but keep the connection. + continue + } + + // Build Frame from header fields. + frame := headerToFrame(hdr, pixels, mc.name) + + // Notify flow controller. + s.flow.FrameDelivered(mc.moduleID) + + // Update registry last frame time. + s.manager.Registry().UpdateLastFrame(mc.moduleID, time.Now()) + + // Fire OnFrame callback. + s.mu.RLock() + cb := s.onFrame + s.mu.RUnlock() + + if cb != nil { + cb(frame) + } + } +} + +// cleanupModule removes a module from all tracking structures and fires +// the OnDisconnect callback via Manager. +func (s *Server) cleanupModule(moduleID uint64) { + s.modulesMu.Lock() + delete(s.modules, moduleID) + s.modulesMu.Unlock() + + s.flow.RemoveModule(moduleID) + s.manager.HandleDisconnect(moduleID) +} + +// decodePayload decompresses the frame payload based on the header's +// compression flag. Returns the raw RGBA pixel data. +func (s *Server) decodePayload(hdr protocol.Header, payload []byte) ([]byte, error) { + if !hdr.Flags.Has(protocol.FlagCompressed) || hdr.Compression == protocol.CompressionNone { + return payload, nil + } + + // Look up the codec by compression ID. + c := codec.Get(byte(hdr.Compression)) + if c == nil { + return nil, fmt.Errorf("compose: unknown compression 0x%02X", hdr.Compression) + } + + // Allocate destination buffer sized to uncompressed size. + dst := make([]byte, hdr.UncompressedSize) + decoded, err := c.Decode(dst, payload) + if err != nil { + return nil, fmt.Errorf("compose: decompress: %w", err) + } + + return decoded, nil +} + +// headerToFrame converts a protocol.Header and pixel payload into a Frame. +func headerToFrame(hdr protocol.Header, pixels []byte, name string) Frame { + f := Frame{ + ModuleID: hdr.ModuleID, + Name: name, + Pixels: pixels, + Width: uint32(hdr.Width), + Height: uint32(hdr.Height), + Timestamp: hdr.TimestampNs, + } + + if hdr.Flags.Has(protocol.FlagDirtyValid) { + f.DirtyRect = image.Rect( + int(hdr.DirtyX), + int(hdr.DirtyY), + int(hdr.DirtyX)+int(hdr.DirtyW), + int(hdr.DirtyY)+int(hdr.DirtyH), + ) + } + + return f +} + +// resolveCodec returns the appropriate codec for the given compression name. +func resolveCodec(name string) codec.Codec { + switch name { + case "lz4": + return codec.LZ4() + default: + return codec.Raw() + } +}