From 02c92777e9d88471736fcf53943315e49f3915a3 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Sun, 26 Apr 2026 15:02:22 +0000 Subject: [PATCH] device, cmd/check-lockorder: add static analysis tool for lock ordering This adds cmd/check-lockorder, a static analyzer that builds a lock-after directed graph from the device package and reports cycles (potential deadlocks) and reentrant RLocks (which deadlock with a pending writer), and hooks it up to CI so that the build fails if we regress. This also does some no-op (I believe) refactoring of some device code to make locking easier for both humans and cmd/check-lockorder to follow, without having to hard-code exceptions in cmd/check-lockorder or make it more complicated. This is the tool that previously found the deadlock fixed by 770e3f59265. Updates tailscale/tailscale#19513 Signed-off-by: Brad Fitzpatrick --- .github/workflows/test.yml | 12 ++ cmd/check-lockorder/analysis.go | 327 +++++++++++++++++++++++++++++ cmd/check-lockorder/graph.go | 269 ++++++++++++++++++++++++ cmd/check-lockorder/inventory.go | 130 ++++++++++++ cmd/check-lockorder/main.go | 94 +++++++++ cmd/check-lockorder/registry.go | 131 ++++++++++++ cmd/check-lockorder/resolve.go | 339 +++++++++++++++++++++++++++++++ cmd/check-lockorder/verbose.go | 69 +++++++ device/channels.go | 4 +- device/lock-ordering.md | 27 --- device/pools.go | 2 - device/receive.go | 206 +++++++++++-------- device/send.go | 122 ++++++----- go.mod | 11 +- go.sum | 12 ++ 15 files changed, 1587 insertions(+), 168 deletions(-) create mode 100644 cmd/check-lockorder/analysis.go create mode 100644 cmd/check-lockorder/graph.go create mode 100644 cmd/check-lockorder/inventory.go create mode 100644 cmd/check-lockorder/main.go create mode 100644 cmd/check-lockorder/registry.go create mode 100644 cmd/check-lockorder/resolve.go create mode 100644 cmd/check-lockorder/verbose.go delete mode 100644 device/lock-ordering.md diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index
a37099545..f5935aff1 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -66,3 +66,15 @@ jobs:
           go-version-file: go.mod
       - name: test
         run: go test -race -v ./...
+
+  check-lockorder:
+    runs-on: ubuntu-22.04
+    steps:
+      - name: checkout
+        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+      - name: setup go
+        uses: actions/setup-go@d35c59abb061a4a6fb18e82ac0862c26744d6ab5 # v5.5.0
+        with:
+          go-version-file: go.mod
+      - name: check-lockorder
+        run: go run ./cmd/check-lockorder
diff --git a/cmd/check-lockorder/analysis.go b/cmd/check-lockorder/analysis.go
new file mode 100644
index 000000000..7ab1f606b
--- /dev/null
+++ b/cmd/check-lockorder/analysis.go
@@ -0,0 +1,327 @@
+package main
+
+import (
+	"go/ast"
+	"go/token"
+	"go/types"
+	"sort"
+)
+
+// HeldLock is one lock in the "currently held" set.
+type HeldLock struct {
+	Lock LockID
+	Kind LockKind
+	Pos  token.Pos
+}
+
+// CallFrame is one step in a call chain for edge attribution.
+type CallFrame struct {
+	FuncName string
+	Pos      token.Pos
+}
+
+// Edge represents a lock-after relationship: To was acquired while From was held.
+type Edge struct {
+	From     LockID
+	FromKind LockKind
+	To       LockID
+	ToKind   LockKind
+	Chain    []CallFrame // how we got here
+}
+
+// LockAcq records a lock that a function (transitively) acquires.
+type LockAcq struct {
+	Lock  LockID
+	Kind  LockKind
+	Chain []CallFrame
+}
+
+// analyzer performs per-function held-set analysis and inter-procedural expansion.
+type analyzer struct {
+	fset     *token.FileSet
+	info     *types.Info
+	pkg      *types.Package
+	registry map[string]LockID
+	resolver *resolver
+
+	funcs    map[string]*FuncInfo // funcFullName → FuncInfo
+	funcDecl map[string]*ast.FuncDecl
+	edges    []Edge
+
+	// Transitive lock acquisitions per function, computed during expansion.
+	transitive map[string][]LockAcq // funcFullName → transitive locks
+	expanding  map[string]bool      // cycle detection during expansion
+}
+
+// newAnalyzer returns an analyzer with empty function tables and a resolver
+// built from the same file set, type info, package and lock registry.
+func newAnalyzer(fset *token.FileSet, info *types.Info, pkg *types.Package, registry map[string]LockID) *analyzer {
+	return &analyzer{
+		fset:       fset,
+		info:       info,
+		pkg:        pkg,
+		registry:   registry,
+		resolver:   newResolver(fset, info, pkg, registry),
+		funcs:      make(map[string]*FuncInfo),
+		funcDecl:   make(map[string]*ast.FuncDecl),
+		transitive: make(map[string][]LockAcq),
+		expanding:  make(map[string]bool),
+	}
+}
+
+// addFile processes all function declarations in a file.
+// Declarations without a body, and declarations the resolver cannot
+// extract information for, are skipped.
+func (a *analyzer) addFile(file *ast.File) {
+	for _, decl := range file.Decls {
+		fn, ok := decl.(*ast.FuncDecl)
+		if !ok || fn.Body == nil {
+			continue
+		}
+		fi := a.resolver.extractFuncInfo(fn)
+		if fi == nil {
+			continue
+		}
+		a.funcs[fi.Name] = fi
+		a.funcDecl[fi.Name] = fn
+	}
+}
+
+// analyze runs the full analysis: held-set computation, transitive expansion,
+// and edge collection.
+func (a *analyzer) analyze() []Edge {
+	// Visit functions in name order. a.funcs is a map and Go randomizes map
+	// iteration; since only the first edge recorded for a given lock pair is
+	// kept for attribution, ranging over the map directly would make the
+	// reported call chain differ from run to run.
+	names := make([]string, 0, len(a.funcs))
+	for name := range a.funcs {
+		names = append(names, name)
+	}
+	sort.Strings(names)
+
+	// Phase 1: Compute held-sets and direct edges for each function.
+	for _, name := range names {
+		a.computeHeldSets(a.funcs[name])
+	}
+
+	// Phase 2: Transitive expansion — for each call under lock,
+	// expand the callee's transitive lock acquisitions and add edges.
+	for _, name := range names {
+		a.expandCalls(a.funcs[name])
+	}
+
+	return a.edges
+}
+
+// computeHeldSets walks a function's lock ops and calls in order,
+// maintaining the set of currently-held locks. It records direct edges
+// (lock acquired while another is held) and annotates calls with held sets.
+func (a *analyzer) computeHeldSets(fi *FuncInfo) {
+	if fi.Decl == nil || fi.Decl.Body == nil {
+		return
+	}
+
+	held := make(map[LockID]HeldLock)
+	a.walkBody(fi.Decl.Body.List, held, fi)
+}
+
+// walkBody processes statements in order, tracking the held-lock set.
+func (a *analyzer) walkBody(stmts []ast.Stmt, held map[LockID]HeldLock, fi *FuncInfo) {
+	for _, stmt := range stmts {
+		a.walkStmtForHeld(stmt, held, fi, false)
+	}
+}
+
+// walkStmtForHeld dispatches on the statement kind and forwards every call
+// expression it can reach to processExprForHeld, recursing into nested
+// bodies. inDefer is true when the statement being processed is the call of
+// a defer statement.
+//
+// The walk is flow-insensitive: every branch of an if/switch/select and
+// every loop body mutates the same held map, so an explicit unlock inside
+// one branch also removes the lock for the statements that follow.
+// Goroutine launches are deliberately not followed.
+func (a *analyzer) walkStmtForHeld(stmt ast.Stmt, held map[LockID]HeldLock, fi *FuncInfo, inDefer bool) {
+	switch s := stmt.(type) {
+	case *ast.ExprStmt:
+		a.processExprForHeld(s.X, held, fi, inDefer)
+	case *ast.AssignStmt:
+		for _, expr := range s.Rhs {
+			a.processExprForHeld(expr, held, fi, inDefer)
+		}
+	case *ast.DeferStmt:
+		a.processExprForHeld(s.Call, held, fi, true)
+	case *ast.GoStmt:
+		// New goroutine — don't track through
+	case *ast.LabeledStmt:
+		// A label only names the statement it wraps (typically a loop that
+		// is the target of break/continue); the wrapped statement still
+		// executes and must be analyzed.
+		a.walkStmtForHeld(s.Stmt, held, fi, inDefer)
+	case *ast.ReturnStmt:
+		// `return f()` calls f while the current locks are still held.
+		for _, expr := range s.Results {
+			a.processExprForHeld(expr, held, fi, inDefer)
+		}
+	case *ast.BlockStmt:
+		a.walkBody(s.List, held, fi)
+	case *ast.IfStmt:
+		if s.Init != nil {
+			a.walkStmtForHeld(s.Init, held, fi, inDefer)
+		}
+		a.walkBody(s.Body.List, held, fi)
+		if s.Else != nil {
+			a.walkStmtForHeld(s.Else, held, fi, inDefer)
+		}
+	case *ast.ForStmt:
+		a.walkBody(s.Body.List, held, fi)
+	case *ast.RangeStmt:
+		a.walkBody(s.Body.List, held, fi)
+	case *ast.SwitchStmt:
+		if s.Init != nil {
+			a.walkStmtForHeld(s.Init, held, fi, inDefer)
+		}
+		a.walkBody(s.Body.List, held, fi)
+	case *ast.TypeSwitchStmt:
+		a.walkBody(s.Body.List, held, fi)
+	case *ast.CaseClause:
+		a.walkBody(s.Body, held, fi)
+	case *ast.SelectStmt:
+		a.walkBody(s.Body.List, held, fi)
+	case *ast.CommClause:
+		a.walkBody(s.Body, held, fi)
+	}
+}
+
+// processExprForHeld handles one expression; anything that is not a call
+// expression is ignored. For a call it does one of three things:
+//
+//   - immediately-invoked function literal: walk its body with a copy of held;
+//   - Lock/RLock/Unlock/RUnlock on a tracked lock: update held and, for an
+//     acquisition, append an edge from every lock currently held;
+//   - call to a function or method of the analyzed package while at least one
+//     lock is held: append an edge from every held lock to every lock the
+//     callee transitively acquires.
+//
+// Only the outermost call of expr is examined; calls nested in its arguments
+// are not visited.
+func (a *analyzer) processExprForHeld(expr ast.Expr, held map[LockID]HeldLock, fi *FuncInfo, inDefer bool) {
+	call, ok := expr.(*ast.CallExpr)
+	if !ok {
+		return
+	}
+
+	// Immediately-invoked function literal: func() { ... }()
+	// The closure inherits the caller's held locks but its deferred unlocks
+	// fire when the closure returns, not when the outer function returns.
+	// We analyze the closure body with a copy of the held set so that
+	// deferred unlocks inside the closure don't leak into the outer scope.
+	if funcLit, ok := call.Fun.(*ast.FuncLit); ok {
+		closureHeld := make(map[LockID]HeldLock, len(held))
+		for k, v := range held {
+			closureHeld[k] = v
+		}
+		a.walkBody(funcLit.Body.List, closureHeld, fi)
+		return
+	}
+
+	sel, isSel := call.Fun.(*ast.SelectorExpr)
+
+	// Check for lock/unlock operations
+	if isSel && isLockMethod(sel.Sel.Name) {
+		lockID, kind := a.resolver.identifyLock(sel.X, sel.Sel.Name)
+		if lockID != "" {
+			isUnlock := sel.Sel.Name == "Unlock" || sel.Sel.Name == "RUnlock"
+			if isUnlock && !inDefer {
+				// Explicit (non-deferred) unlock: remove from held set
+				delete(held, lockID)
+			} else if !isUnlock {
+				// Lock acquisition: record edges from all currently held locks.
+				// If lockID itself is already held, this loop emits the
+				// self-loop edge that is later reported as reentrancy.
+				for _, h := range held {
+					a.edges = append(a.edges, Edge{
+						From:     h.Lock,
+						FromKind: h.Kind,
+						To:       lockID,
+						ToKind:   kind,
+						Chain: []CallFrame{{
+							FuncName: fi.Name,
+							Pos:      call.Pos(),
+						}},
+					})
+				}
+				held[lockID] = HeldLock{
+					Lock: lockID,
+					Kind: kind,
+					Pos:  call.Pos(),
+				}
+			}
+			// Deferred unlock: lock stays in held set (released at function return)
+			return
+		}
+	}
+
+	// Non-lock function call: record as a call under lock if we hold any locks.
+	if len(held) == 0 {
+		return
+	}
+
+	var calleeFunc *types.Func
+	if isSel {
+		obj := a.info.ObjectOf(sel.Sel)
+		if fn, ok := obj.(*types.Func); ok && fn.Pkg() == a.pkg {
+			calleeFunc = fn
+		}
+	} else if ident, ok := call.Fun.(*ast.Ident); ok {
+		obj := a.info.ObjectOf(ident)
+		if fn, ok := obj.(*types.Func); ok && fn.Pkg() == a.pkg {
+			calleeFunc = fn
+		}
+	}
+
+	if calleeFunc == nil {
+		return
+	}
+
+	// Record edges: each held lock → each lock the callee transitively acquires
+	calleeName := calleeFunc.FullName()
+	transLocks := a.getTransitiveLocks(calleeName)
+	for _, h := range held {
+		for _, tl := range transLocks {
+			chain := []CallFrame{{
+				FuncName: fi.Name,
+				Pos:      call.Pos(),
+			}}
+			chain = append(chain, tl.Chain...)
+			a.edges = append(a.edges, Edge{
+				From:     h.Lock,
+				FromKind: h.Kind,
+				To:       tl.Lock,
+				ToKind:   tl.Kind,
+				Chain:    chain,
+			})
+		}
+	}
+}
+
+// expandCalls is a no-op now — transitive expansion happens lazily in getTransitiveLocks.
+func (a *analyzer) expandCalls(fi *FuncInfo) {}
+
+// getTransitiveLocks returns all locks that a function may acquire, directly
+// or through callees. Results are memoized.
+func (a *analyzer) getTransitiveLocks(funcName string) []LockAcq {
+	if cached, ok := a.transitive[funcName]; ok {
+		return cached
+	}
+
+	root, ok := a.funcs[funcName]
+	if !ok {
+		return nil
+	}
+
+	// The call graph is walked breadth-first from funcName instead of
+	// recursing into callees and reusing their memoized results. A recursive
+	// walk has to cut call-graph cycles, and whatever it memoizes for a
+	// function inside the cycle is then missing the locks that are only
+	// reachable through the cut edge. Walking per root avoids that: only the
+	// finished result for funcName itself is cached.
+
+	// pending is a function still to be scanned plus the call sites that
+	// lead to it from funcName.
+	type pending struct {
+		fi   *FuncInfo
+		path []CallFrame
+	}
+	// extend copies path before appending so queued paths never share a
+	// backing array.
+	extend := func(path []CallFrame, name string, pos token.Pos) []CallFrame {
+		out := make([]CallFrame, len(path), len(path)+1)
+		copy(out, path)
+		return append(out, CallFrame{FuncName: name, Pos: pos})
+	}
+
+	var result []LockAcq
+	seen := make(map[LockID]bool)
+	visited := map[string]bool{funcName: true}
+	queue := []pending{{fi: root}}
+
+	for len(queue) > 0 {
+		cur := queue[0]
+		queue = queue[1:]
+
+		// Lock acquisitions made directly by cur. As before, each lock is
+		// reported once, with the kind and chain of its first occurrence.
+		for _, op := range cur.fi.LockOps {
+			if op.IsUnlock || seen[op.Lock] {
+				continue
+			}
+			seen[op.Lock] = true
+			result = append(result, LockAcq{
+				Lock:  op.Lock,
+				Kind:  op.Kind,
+				Chain: extend(cur.path, cur.fi.Name, op.Pos),
+			})
+		}
+
+		// Queue callees that have not been scanned yet. Callees without a
+		// FuncInfo (no body in the analyzed package) contribute nothing.
+		for _, c := range cur.fi.Calls {
+			calleeName := c.Callee.FullName()
+			if visited[calleeName] {
+				continue
+			}
+			callee, ok := a.funcs[calleeName]
+			if !ok {
+				continue
+			}
+			visited[calleeName] = true
+			queue = append(queue, pending{
+				fi:   callee,
+				path: extend(cur.path, cur.fi.Name, c.Pos),
+			})
+		}
+	}
+
+	a.transitive[funcName] = result
+	return result
+}
diff --git a/cmd/check-lockorder/graph.go b/cmd/check-lockorder/graph.go
new file mode 100644
index 000000000..b2c966976
--- /dev/null
+++ b/cmd/check-lockorder/graph.go
@@ -0,0 +1,269 @@
+package main
+
+import (
+	"fmt"
+	"go/token"
+	"sort"
+	"strings"
+)
+
+// Violation is a problem found by the analysis.
+type Violation struct {
+	Kind    string // "cycle" or "reentrant-rlock"
+	Message string
+	Detail  string // multi-line detail with file:line references
+}
+
+// findViolations examines the collected edges for cycles and reentrant RLock.
+// Cycles or reentrance involving an InstanceLocal lock are not reported:
+// per-instance locks where each instance is owned by exactly one goroutine
+// can't deadlock with themselves under static cycle analysis (different
+// holders are different instances).
+func findViolations(edges []Edge, fset *token.FileSet) []Violation {
+	var violations []Violation
+	instLocal := instanceLocalLocks()
+
+	// Deduplicate edges into an adjacency list with attribution.
+	// The first edge seen for a given (from, to, kinds) tuple is kept.
+	type edgeKey struct {
+		from, to         LockID
+		fromKind, toKind LockKind
+	}
+	best := make(map[edgeKey]Edge)
+	for _, e := range edges {
+		k := edgeKey{e.From, e.To, e.FromKind, e.ToKind}
+		if _, ok := best[k]; !ok {
+			best[k] = e
+		}
+	}
+
+	// Walk the deduplicated edges in a fixed order. Ranging over the map
+	// directly would make "Path A"/"Path B" and the chosen reverse edge
+	// depend on Go's randomized map iteration.
+	keys := make([]edgeKey, 0, len(best))
+	for k := range best {
+		keys = append(keys, k)
+	}
+	sort.Slice(keys, func(i, j int) bool {
+		x, y := keys[i], keys[j]
+		switch {
+		case x.from != y.from:
+			return x.from < y.from
+		case x.to != y.to:
+			return x.to < y.to
+		case x.fromKind != y.fromKind:
+			return x.fromKind < y.fromKind
+		}
+		return x.toKind < y.toKind
+	})
+
+	// 1. Detect reentrant RLock (self-loops). Skip instance-local locks
+	// since each instance is owned by a single goroutine.
+	for _, k := range keys {
+		if k.from != k.to || instLocal[k.from] {
+			continue
+		}
+		desc := "REENTRANT "
+		if k.fromKind == Shared && k.toKind == Shared {
+			desc += "RLOCK"
+		} else {
+			desc += "LOCK"
+		}
+		violations = append(violations, Violation{
+			Kind:    "reentrant-rlock",
+			Message: fmt.Sprintf("%s on %s", desc, k.from),
+			Detail:  formatChain(best[k].Chain, fset),
+		})
+	}
+
+	// 2. Detect ordering cycles.
+	// Build adjacency list (excluding self-loops, already reported above,
+	// and edges touching instance-local locks).
+	adj := make(map[LockID]map[LockID]bool)
+	for _, k := range keys {
+		if k.from == k.to {
+			continue
+		}
+		if instLocal[k.from] || instLocal[k.to] {
+			continue
+		}
+		if adj[k.from] == nil {
+			adj[k.from] = make(map[LockID]bool)
+		}
+		adj[k.from][k.to] = true
+	}
+
+	// Find all 2-cycles (A→B and B→A).
+	// This is the most common deadlock pattern.
+	reported := make(map[[2]LockID]bool)
+	for _, k := range keys {
+		if k.from == k.to {
+			continue
+		}
+		if instLocal[k.from] || instLocal[k.to] {
+			continue
+		}
+		// Check if the reverse edge exists
+		if adj[k.to] == nil || !adj[k.to][k.from] {
+			continue
+		}
+		pair := [2]LockID{k.from, k.to}
+		if k.from > k.to {
+			pair = [2]LockID{k.to, k.from}
+		}
+		if reported[pair] {
+			continue
+		}
+		reported[pair] = true
+
+		// Find the reverse edge for the detail
+		e := best[k]
+		var reverseEdge Edge
+		for _, rk := range keys {
+			if rk.from == k.to && rk.to == k.from {
+				reverseEdge = best[rk]
+				break
+			}
+		}
+
+		violations = append(violations, Violation{
+			Kind: "cycle",
+			Message: fmt.Sprintf("LOCK ORDERING CYCLE: %s <-> %s",
+				pair[0], pair[1]),
+			Detail: fmt.Sprintf(
+				" Path A (%s -> %s):\n%s\n Path B (%s -> %s):\n%s",
+				e.From, e.To,
+				formatChain(e.Chain, fset),
+				reverseEdge.From, reverseEdge.To,
+				formatChain(reverseEdge.Chain, fset),
+			),
+		})
+	}
+
+	// Also find longer cycles using DFS.
+	longerCycles := findLongerCycles(adj)
+	for _, cycle := range longerCycles {
+		if len(cycle) == 2 {
+			continue // already reported as a 2-cycle above
+		}
+		// Check if this is just a combination of known 2-cycles.
+		// If all adjacent pairs are already reported 2-cycles, skip.
+		allPairsKnown := true
+		for i := 0; i < len(cycle); i++ {
+			a := cycle[i]
+			b := cycle[(i+1)%len(cycle)]
+			pair := [2]LockID{a, b}
+			if a > b {
+				pair = [2]LockID{b, a}
+			}
+			if !reported[pair] {
+				allPairsKnown = false
+				break
+			}
+		}
+		if allPairsKnown {
+			continue
+		}
+		violations = append(violations, Violation{
+			Kind:    "cycle",
+			Message: fmt.Sprintf("LOCK ORDERING CYCLE: %s", formatCycle(cycle)),
+			Detail:  fmt.Sprintf(" Cycle: %s", formatCycle(cycle)),
+		})
+	}
+
+	// Stable sort: violations whose Kind and Message are equal keep the
+	// deterministic order in which they were appended above.
+	sort.SliceStable(violations, func(i, j int) bool {
+		if violations[i].Kind != violations[j].Kind {
+			return violations[i].Kind < violations[j].Kind
+		}
+		return violations[i].Message < violations[j].Message
+	})
+
+	return violations
+}
+
+// findLongerCycles runs a depth-first search over adj and returns every cycle
+// of three or more locks that the search closes through a back edge (an edge
+// to a node still on the DFS stack). Each cycle is returned as the slice of
+// LockIDs along the stack from the back edge's target to its source.
+//
+// This is not an exhaustive enumeration of elementary cycles: a cycle is
+// only found if the traversal happens to close it. Roots and successors are
+// visited in sorted order so that the result is the same on every run.
+func findLongerCycles(adj map[LockID]map[LockID]bool) [][]LockID {
+	var cycles [][]LockID
+
+	// Standard DFS-based cycle detection
+	color := make(map[LockID]int) // 0=white, 1=gray, 2=black
+
+	var dfs func(node LockID, path []LockID)
+	dfs = func(node LockID, path []LockID) {
+		color[node] = 1
+		path = append(path, node)
+
+		// Visit successors in sorted order; ranging over the map directly
+		// would let the set of reported cycles vary between runs.
+		succs := make([]LockID, 0, len(adj[node]))
+		for next := range adj[node] {
+			succs = append(succs, next)
+		}
+		sort.Slice(succs, func(i, j int) bool {
+			return succs[i] < succs[j]
+		})
+
+		for _, next := range succs {
+			if color[next] == 1 {
+				// Found a cycle — extract it
+				start := -1
+				for i, n := range path {
+					if n == next {
+						start = i
+						break
+					}
+				}
+				if start >= 0 {
+					cycle := make([]LockID, len(path)-start)
+					copy(cycle, path[start:])
+					if len(cycle) > 2 {
+						cycles = append(cycles, cycle)
+					}
+				}
+			} else if color[next] == 0 {
+				dfs(next, path)
+			}
+		}
+
+		color[node] = 2
+	}
+
+	// Get sorted nodes for deterministic output
+	var nodes []LockID
+	for n := range adj {
+		nodes = append(nodes, n)
+	}
+	sort.Slice(nodes, func(i, j int) bool {
+		return nodes[i] < nodes[j]
+	})
+
+	for _, n := range nodes {
+		if color[n] == 0 {
+			dfs(n, nil)
+		}
+	}
+
+	return cycles
+}
+
+// formatChain renders one line per call frame as "file:line function", using
+// only the base name of the file and the short form of the function name.
+// An empty chain yields a "(no attribution)" placeholder.
+func formatChain(chain []CallFrame, fset *token.FileSet) string {
+	if len(chain) == 0 {
+		return " (no attribution)"
+	}
+	var b strings.Builder
+	for _, frame := range chain {
+		pos := fset.Position(frame.Pos)
+		// Strip to just filename:line
+		file := pos.Filename
+		if idx := strings.LastIndex(file, "/"); idx >= 0 {
+			file = file[idx+1:]
+		}
+		fmt.Fprintf(&b, " %s:%d %s\n", file, pos.Line, shortName(frame.FuncName))
+	}
+	return strings.TrimRight(b.String(), "\n")
+}
+
+// formatCycle renders a cycle as "A -> B -> C -> A", repeating the first lock
+// at the end to close the loop. An empty cycle renders as "".
+func formatCycle(cycle []LockID) string {
+	if len(cycle) == 0 {
+		return ""
+	}
+	parts := make([]string, len(cycle))
+	for i, c := range cycle {
+		parts[i] = string(c)
+	}
+	return strings.Join(parts, " -> ") + " -> " + parts[0]
+}
+
+// shortName strips the package path from a function name.
+//
+//	"(*github.com/tailscale/wireguard-go/device.Device).SetPrivateKey" → "(*Device).SetPrivateKey"
+//	"(github.com/tailscale/wireguard-go/device.Peer).Foo" → "(Peer).Foo"
+//	"github.com/tailscale/wireguard-go/device.SomeFunc" → "SomeFunc"
+func shortName(fullName string) string {
+	leading := ""
+	s := fullName
+	switch {
+	case strings.HasPrefix(s, "(*"):
+		leading, s = "(*", s[2:]
+	case strings.HasPrefix(s, "("):
+		leading, s = "(", s[1:]
+	}
+	// Drop the import path up to the last slash, then the package name up
+	// to the first remaining dot.
+	if idx := strings.LastIndex(s, "/"); idx >= 0 {
+		s = s[idx+1:]
+	}
+	if dot := strings.Index(s, "."); dot >= 0 {
+		s = s[dot+1:]
+	}
+	return leading + s
+}
diff --git a/cmd/check-lockorder/inventory.go b/cmd/check-lockorder/inventory.go
new file mode 100644
index 000000000..65b6d0d1f
--- /dev/null
+++ b/cmd/check-lockorder/inventory.go
@@ -0,0 +1,130 @@
+package main
+
+import (
+	"go/types"
+	"sort"
+)
+
+// lockSite identifies a sync.Mutex or sync.RWMutex field declared in
+// the device package. DefType is the named struct type that declares
+// the field. FieldPath is the dot-separated path from a value of that
+// type to the mutex, or "" if the mutex is directly embedded on the
+// struct.
+//
+// Examples:
+//
+//	{DefType: "Device", FieldPath: "ipcMutex"} // device.ipcMutex (named field)
+//	{DefType: "Device", FieldPath: "state"} // device.state.{Mutex} (anon-struct field embeds sync.Mutex)
+//	{DefType: "Handshake", FieldPath: "mutex"} // handshake.mutex (named field)
+//	{DefType: "Keypairs", FieldPath: ""} // keypairs.{RWMutex} (embedded directly on Keypairs)
+type lockSite struct {
+	DefType   string
+	FieldPath string
+}
+
+// String returns "DefType.FieldPath", or just DefType when FieldPath is
+// empty (mutex embedded directly on the type).
+func (s lockSite) String() string {
+	if s.FieldPath == "" {
+		return s.DefType
+	}
+	return s.DefType + "." + s.FieldPath
+}
+
+// findLockSites walks every named struct type in pkg and returns all
+// sync.Mutex/sync.RWMutex fields reachable without crossing into
+// another named type. Anonymous (inline) struct fields are descended
+// into; named struct fields are not (they are inventoried separately
+// when they appear at the top level of pkg).
+//
+// The result is sorted by DefType, then FieldPath.
+func findLockSites(pkg *types.Package) []lockSite {
+	var sites []lockSite
+	scope := pkg.Scope()
+	for _, name := range scope.Names() {
+		obj := scope.Lookup(name)
+		tn, ok := obj.(*types.TypeName)
+		if !ok {
+			continue
+		}
+		named, ok := tn.Type().(*types.Named)
+		if !ok {
+			continue
+		}
+		st, ok := named.Underlying().(*types.Struct)
+		if !ok {
+			continue
+		}
+		walkStructForLocks(name, st, "", &sites)
+	}
+	sort.Slice(sites, func(i, j int) bool {
+		if sites[i].DefType != sites[j].DefType {
+			return sites[i].DefType < sites[j].DefType
+		}
+		return sites[i].FieldPath < sites[j].FieldPath
+	})
+	return sites
+}
+
+// walkStructForLocks appends to out one lockSite for every mutex field of
+// st, recursing into fields whose type is an inline struct. defType is the
+// named type being inventoried; prefix is the field path from defType to st.
+// An embedded mutex or inline struct contributes no path component of its
+// own.
+func walkStructForLocks(defType string, st *types.Struct, prefix string, out *[]lockSite) {
+	for i := 0; i < st.NumFields(); i++ {
+		f := st.Field(i)
+		if isMutexType(f.Type()) {
+			path := prefix
+			if !f.Anonymous() {
+				path = joinPath(prefix, f.Name())
+			}
+			*out = append(*out, lockSite{DefType: defType, FieldPath: path})
+			continue
+		}
+		// Descend only into anonymous (inline) struct types. Named
+		// types — including pointers and other struct types defined
+		// elsewhere — are inventoried at the top level on their own
+		// (or are out-of-scope, e.g. *time.Timer).
+		if _, ok := f.Type().(*types.Named); ok {
+			continue
+		}
+		anon, ok := f.Type().Underlying().(*types.Struct)
+		if !ok {
+			continue
+		}
+		sub := prefix
+		if !f.Anonymous() {
+			sub = joinPath(prefix, f.Name())
+		}
+		walkStructForLocks(defType, anon, sub, out)
+	}
+}
+
+// joinPath appends name to a dot-separated field path; an empty prefix
+// yields name alone.
+func joinPath(prefix, name string) string {
+	if prefix == "" {
+		return name
+	}
+	return prefix + "." + name
+}
+
+// isMutexType reports whether t is exactly the named type sync.Mutex or
+// sync.RWMutex. Pointers to those types do not match.
+func isMutexType(t types.Type) bool {
+	n, ok := t.(*types.Named)
+	if !ok {
+		return false
+	}
+	obj := n.Obj()
+	if obj.Pkg() == nil || obj.Pkg().Path() != "sync" {
+		return false
+	}
+	return obj.Name() == "Mutex" || obj.Name() == "RWMutex"
+}
+
+// checkInventory returns the sync.Mutex/sync.RWMutex sites in pkg that
+// are not registered in trackedLocks. An empty result means every
+// mutex in the package is part of the lock-ordering analysis.
+func checkInventory(pkg *types.Package) []lockSite {
+	sites := findLockSites(pkg)
+	known := make(map[lockSite]bool, len(trackedLocks))
+	for _, tl := range trackedLocks {
+		known[lockSite{DefType: tl.DefType, FieldPath: tl.DefPath}] = true
+	}
+	var unknown []lockSite
+	for _, s := range sites {
+		if !known[s] {
+			unknown = append(unknown, s)
+		}
+	}
+	return unknown
+}
diff --git a/cmd/check-lockorder/main.go b/cmd/check-lockorder/main.go
new file mode 100644
index 000000000..fe25aca1c
--- /dev/null
+++ b/cmd/check-lockorder/main.go
@@ -0,0 +1,94 @@
+// check-lockorder statically analyzes the device package for lock-ordering
+// deadlocks. It builds a lock-after directed graph and reports cycles
+// (potential deadlocks) and reentrant RLock (deadlock with pending writer).
+//
+// Exit code 0 means no violations found; 1 means violations (or mutexes
+// missing from the registry) were found; 2 means the package could not be
+// loaded.
+//
+// Usage:
+//
+//	go run ./cmd/check-lockorder [-v]
+//
+// With -v, print the full lock inventory and every observed lock-after edge.
+package main
+
+import (
+	"flag"
+	"fmt"
+	"os"
+
+	"golang.org/x/tools/go/packages"
+)
+
+const targetPkg = "github.com/tailscale/wireguard-go/device"
+
+// main loads targetPkg, fails fast if any mutex is missing from the
+// registry, then runs the analyzer and prints the violations it finds.
+// Exit status: 0 clean, 1 unregistered mutexes or violations, 2 load errors.
+func main() {
+	verbose := flag.Bool("v", false, "list the lock inventory and all observed lock-after edges")
+	flag.Parse()
+
+	cfg := &packages.Config{
+		Mode: packages.NeedName |
+			packages.NeedFiles |
+			packages.NeedSyntax |
+			packages.NeedTypes |
+			packages.NeedTypesInfo,
+	}
+	pkgs, err := packages.Load(cfg, targetPkg)
+	if err != nil {
+		fmt.Fprintf(os.Stderr, "failed to load package: %v\n", err)
+		os.Exit(2)
+	}
+	if len(pkgs) == 0 {
+		fmt.Fprintf(os.Stderr, "package %s not found\n", targetPkg)
+		os.Exit(2)
+	}
+	pkg := pkgs[0]
+	if len(pkg.Errors) > 0 {
+		for _, e := range pkg.Errors {
+			fmt.Fprintf(os.Stderr, "package error: %v\n", e)
+		}
+		os.Exit(2)
+	}
+
+	// Inventory check runs before the analysis: an unregistered mutex would
+	// otherwise be invisible to the lock-order graph.
+	if unknown := checkInventory(pkg.Types); len(unknown) > 0 {
+		fmt.Fprintf(os.Stderr, "found %d unregistered sync.Mutex/sync.RWMutex field(s) in %s:\n",
+			len(unknown), targetPkg)
+		for _, s := range unknown {
+			fmt.Fprintf(os.Stderr, " %s\n", s)
+		}
+		fmt.Fprintln(os.Stderr)
+		fmt.Fprintln(os.Stderr, "Every mutex in the device package must be registered in trackedLocks")
+		fmt.Fprintln(os.Stderr, "(cmd/check-lockorder/registry.go). The analyzer will treat the new lock")
+		fmt.Fprintln(os.Stderr, "as part of the partial order and fail on any cycle it participates in.")
+		os.Exit(1)
+	}
+
+	registry := buildRegistry()
+	a := newAnalyzer(pkg.Fset, pkg.TypesInfo, pkg.Types, registry)
+
+	for _, file := range pkg.Syntax {
+		a.addFile(file)
+	}
+
+	edges := a.analyze()
+
+	// Progress goes to stderr; the inventory/edge listing and the verdict
+	// go to stdout.
+	fmt.Fprintf(os.Stderr, "analyzed %d functions, found %d lock-after edges\n",
+		len(a.funcs), len(edges))
+
+	if *verbose {
+		printVerbose(os.Stdout, edges, pkg.Fset)
+		fmt.Fprintln(os.Stdout)
+	}
+
+	violations := findViolations(edges, pkg.Fset)
+
+	if len(violations) == 0 {
+		fmt.Println("no lock-ordering violations found")
+		return
+	}
+
+	fmt.Printf("found %d lock-ordering violation(s):\n\n", len(violations))
+	for i, v := range violations {
+		fmt.Printf("%d. %s\n%s\n\n", i+1, v.Message, v.Detail)
+	}
+	os.Exit(1)
+}
diff --git a/cmd/check-lockorder/registry.go b/cmd/check-lockorder/registry.go
new file mode 100644
index 000000000..37c90e86c
--- /dev/null
+++ b/cmd/check-lockorder/registry.go
@@ -0,0 +1,131 @@
+// Package main implements a static lock-order checker for the device package.
+package main
+
+// LockID is the canonical name of a lock class, e.g. "Device.staticIdentity".
+// All instances of the same struct share the same LockID.
+type LockID string
+
+// LockKind distinguishes exclusive (Lock) from shared (RLock).
+type LockKind int
+
+const (
+	Exclusive LockKind = iota // Lock()
+	Shared                    // RLock()
+)
+
+// String returns "RLock" for Shared and "Lock" for every other value.
+func (k LockKind) String() string {
+	if k == Shared {
+		return "RLock"
+	}
+	return "Lock"
+}
+
+// MutexKind distinguishes Mutex from RWMutex.
+type MutexKind int
+
+const (
+	PlainMutex     MutexKind = iota // sync.Mutex — only Lock/Unlock
+	ReadWriteMutex                  // sync.RWMutex — Lock/Unlock/RLock/RUnlock
+)
+
+// TrackedLock describes one lock class we want to track.
+//
+// (OwnerType, FieldPath) is the access-path key: it matches the selector
+// chain that callers write at the lock site, e.g. `device.staticIdentity`
+// or `peer.handshake.mutex`. (DefType, DefPath) identifies the struct and
+// field where the mutex is actually declared, which is what the inventory
+// check uses to match against the package's struct types. The two pairs
+// differ when a lock lives on a sub-type accessed through a parent — for
+// example, `Device.allowedips.mu` is owned via Device but defined on the
+// AllowedIPs type as field `mu`.
+type TrackedLock struct {
+	ID        LockID    // canonical lock-class name; the node identity in the lock-after graph
+	OwnerType string    // e.g. "Device", "Peer", "Timer" — the struct through which this lock is reached
+	FieldPath string    // dot-separated field path from owner, e.g. "handshake.mutex"
+	Kind      MutexKind // whether the field is a sync.Mutex or a sync.RWMutex
+	DefType   string    // struct type that declares the mutex field
+	DefPath   string    // path within DefType to the mutex; "" if directly embedded
+
+	// InstanceLocal indicates that the lock is per-instance and that
+	// each instance is owned by exactly one goroutine at a time
+	// (typically because instances flow through a channel). Static cycle
+	// detection cannot distinguish instances of a lock class, so an
+	// instance-local lock would otherwise produce false-positive cycles
+	// when two different goroutines hold two different instances. Cycle
+	// and reentrance detection skip any cycle that contains an
+	// instance-local lock; the lock is still tracked for inventory and
+	// edge listing.
+	InstanceLocal bool
+}
+
+// trackedLocks is the registry of every lock in the device package.
+// Adding a new sync.Mutex or sync.RWMutex anywhere in the package
+// without registering it here causes the inventory check to fail.
+//
+// FieldPath is matched against the selector chain from the receiver
+// variable up to (but not including) the Lock/Unlock method call. For
+// directly-embedded mutexes (FieldPath: "") the registry key has a
+// trailing dot, e.g. "Keypairs." — see buildRegistry.
+//
+// Lock ordering is determined topologically from the lock-after edges
+// the analyzer observes. There are no level numbers in the data: a
+// lock with no outgoing edges is a leaf in the partial order; a cycle
+// involving any pair of locks is a violation regardless of how
+// "isolated" the participants might seem.
+var trackedLocks = []TrackedLock{
+	{ID: "Device.state", OwnerType: "Device", FieldPath: "state", Kind: PlainMutex, DefType: "Device", DefPath: "state"},
+	{ID: "Device.ipcMutex", OwnerType: "Device", FieldPath: "ipcMutex", Kind: ReadWriteMutex, DefType: "Device", DefPath: "ipcMutex"},
+	{ID: "Device.net", OwnerType: "Device", FieldPath: "net", Kind: ReadWriteMutex, DefType: "Device", DefPath: "net"},
+	{ID: "Device.staticIdentity", OwnerType: "Device", FieldPath: "staticIdentity", Kind: ReadWriteMutex, DefType: "Device", DefPath: "staticIdentity"},
+	{ID: "Device.peers", OwnerType: "Device", FieldPath: "peers", Kind: ReadWriteMutex, DefType: "Device", DefPath: "peers"},
+	{ID: "Device.allowedips.mu", OwnerType: "Device", FieldPath: "allowedips.mu", Kind: ReadWriteMutex, DefType: "AllowedIPs", DefPath: "mu"},
+	{ID: "Device.indexTable", OwnerType: "Device", FieldPath: "indexTable", Kind: ReadWriteMutex, DefType: "IndexTable", DefPath: ""},
+	{ID: "Device.cookieChecker", OwnerType: "Device", FieldPath: "cookieChecker", Kind: ReadWriteMutex, DefType: "CookieChecker", DefPath: ""},
+	{ID: "Peer.state", OwnerType: "Peer", FieldPath: "state", Kind: PlainMutex, DefType: "Peer", DefPath: "state"},
+	{ID: "Peer.handshake.mutex", OwnerType: "Peer", FieldPath: "handshake.mutex", Kind: ReadWriteMutex, DefType: "Handshake", DefPath: "mutex"},
+	{ID: "Peer.keypairs", OwnerType: "Peer", FieldPath: "keypairs", Kind: ReadWriteMutex, DefType: "Keypairs", DefPath: ""},
+	{ID: "Peer.endpoint", OwnerType: "Peer", FieldPath: "endpoint", Kind: PlainMutex, DefType: "Peer", DefPath: "endpoint"},
+	{ID: "Peer.cookieGenerator", OwnerType: "Peer", FieldPath: "cookieGenerator", Kind: ReadWriteMutex, DefType: "CookieGenerator", DefPath: ""},
+	{ID: "Timer.modifyingLock", OwnerType: "Timer", FieldPath: "modifyingLock", Kind: ReadWriteMutex, DefType: "Timer", DefPath: "modifyingLock"},
+	{ID: "Timer.runningLock", OwnerType: "Timer", FieldPath: "runningLock", Kind: PlainMutex, DefType: "Timer", DefPath: "runningLock"},
+	{ID: "WaitPool.lock", OwnerType: "WaitPool", FieldPath: "lock", Kind: PlainMutex, DefType: "WaitPool", DefPath: "lock"},
+}
+
+// instanceLocalLocks is the set of LockIDs marked InstanceLocal in
+// trackedLocks; used by cycle detection to skip false positives caused
+// by treating per-instance locks as a single class.
+//
+// NOTE(review): no entry in trackedLocks above sets InstanceLocal, so with
+// the table as written this returns an empty map.
+func instanceLocalLocks() map[LockID]bool {
+	m := map[LockID]bool{}
+	for _, tl := range trackedLocks {
+		if tl.InstanceLocal {
+			m[tl.ID] = true
+		}
+	}
+	return m
+}
+
+// alternateResolutions maps (TypeName.fieldPath) to the canonical LockID
+// for cases where a lock is accessed through an intermediate type rather
+// than the top-level owner. For example, handshake.mutex where handshake
+// is *Handshake (obtained from the index table) rather than
+// peer.handshake.mutex. Keys for directly-embedded mutexes have a
+// trailing dot — see buildRegistry.
+var alternateResolutions = map[string]LockID{
+	"Handshake.mutex":  "Peer.handshake.mutex",
+	"Keypairs.":        "Peer.keypairs",        // keypairs.Lock() inside *Keypairs methods
+	"CookieChecker.":   "Device.cookieChecker", // st.Lock() inside *CookieChecker methods
+	"CookieGenerator.": "Peer.cookieGenerator", // st.Lock() inside *CookieGenerator methods
+}
+
+// buildRegistry creates a lookup map from (ownerType.fieldPath) → LockID.
+// Entries from alternateResolutions are added last and therefore win if a
+// key appears in both tables.
+func buildRegistry() map[string]LockID {
+	m := make(map[string]LockID, len(trackedLocks)+len(alternateResolutions))
+	for _, tl := range trackedLocks {
+		key := tl.OwnerType + "." + tl.FieldPath
+		m[key] = tl.ID
+	}
+	for key, id := range alternateResolutions {
+		m[key] = id
+	}
+	return m
+}
diff --git a/cmd/check-lockorder/resolve.go b/cmd/check-lockorder/resolve.go
new file mode 100644
index 000000000..978ba87c4
--- /dev/null
+++ b/cmd/check-lockorder/resolve.go
@@ -0,0 +1,339 @@
+package main
+
+import (
+	"go/ast"
+	"go/token"
+	"go/types"
+	"strings"
+)
+
+// LockOp represents a single lock or unlock operation found in source.
+type LockOp struct {
+	Lock     LockID
+	Kind     LockKind // Exclusive or Shared
+	IsUnlock bool
+	IsDefer  bool
+	Pos      token.Pos
+}
+
+// CallUnderLock records a call to a function or method of the analyzed
+// package. The struct itself carries no held-lock set; which locks are held
+// at the call is worked out separately by the analyzer.
+type CallUnderLock struct {
+	Callee   *types.Func
+	CalleeFn *ast.FuncDecl // nil for non-local callees
+	Pos      token.Pos
+}
+
+// FuncInfo holds the raw extracted information about a single function.
+type FuncInfo struct {
+	Name      string // types.Func.FullName() of the function; also its key in analyzer.funcs
+	Obj       *types.Func
+	Decl      *ast.FuncDecl
+	LockOps   []LockOp
+	Calls     []CallUnderLock
+	BodyStmts []ast.Stmt // the function body statements
+}
+
+// resolver extracts lock operations, variable aliases, and call sites from functions.
+type resolver struct {
+	fset     *token.FileSet
+	info     *types.Info
+	registry map[string]LockID // (OwnerType.fieldPath) → LockID
+	pkg      *types.Package
+
+	// Per-function state
+	aliases map[string]string // localVar → resolved receiver expression string
+}
+
+// newResolver returns a resolver for pkg. The per-function alias table is
+// left nil here; extractFuncInfo allocates a fresh one for each function.
+func newResolver(fset *token.FileSet, info *types.Info, pkg *types.Package, registry map[string]LockID) *resolver {
+	return &resolver{
+		fset:     fset,
+		info:     info,
+		registry: registry,
+		pkg:      pkg,
+	}
+}
+
+// extractFuncInfo analyzes a single function declaration.
+func (r *resolver) extractFuncInfo(fn *ast.FuncDecl) *FuncInfo {
+	if fn.Body == nil {
+		return nil // nothing to walk
+	}
+
+	obj := r.info.ObjectOf(fn.Name)
+	if obj == nil {
+		return nil
+	}
+	funcObj, ok := obj.(*types.Func)
+	if !ok {
+		return nil
+	}
+
+	fi := &FuncInfo{
+		Name:      funcObj.FullName(),
+		Obj:       funcObj,
+		Decl:      fn,
+		BodyStmts: fn.Body.List,
+	}
+
+	r.aliases = make(map[string]string) // aliases are per-function; start clean
+	r.walkStmtList(fn.Body.List, fi)
+	return fi
+}
+
+// walkStmtList processes a list of statements, extracting lock ops and calls.
+func (r *resolver) walkStmtList(stmts []ast.Stmt, fi *FuncInfo) {
+	for _, stmt := range stmts {
+		r.walkStmt(stmt, fi, false)
+	}
+}
+
+// walkStmt records the lock ops and package-local calls found in stmt and
+// recurses into nested statement bodies. Only statement-level calls are
+// examined; conditions, call arguments and return values are not.
+func (r *resolver) walkStmt(stmt ast.Stmt, fi *FuncInfo, inDefer bool) {
+	switch s := stmt.(type) {
+	case *ast.ExprStmt:
+		r.walkExpr(s.X, fi, inDefer)
+	case *ast.AssignStmt:
+		r.checkAlias(s) // record alias patterns such as x := &y.field
+		for _, expr := range s.Rhs {
+			r.walkExpr(expr, fi, inDefer) // the RHS may itself be a lock op or call
+		}
+	case *ast.DeferStmt:
+		r.walkExpr(s.Call, fi, true)
+	case *ast.LabeledStmt:
+		r.walkStmt(s.Stmt, fi, inDefer) // look through "label:" to its statement
+	case *ast.BlockStmt:
+		r.walkStmtList(s.List, fi)
+	case *ast.IfStmt:
+		if s.Init != nil {
+			r.walkStmt(s.Init, fi, inDefer)
+		}
+		r.walkStmtList(s.Body.List, fi)
+		if s.Else != nil {
+			r.walkStmt(s.Else, fi, inDefer)
+		}
+	case *ast.ForStmt:
+		r.walkStmtList(s.Body.List, fi)
+	case *ast.RangeStmt:
+		r.walkStmtList(s.Body.List, fi)
+	case *ast.SwitchStmt:
+		if s.Init != nil {
+			r.walkStmt(s.Init, fi, inDefer)
+		}
+		r.walkStmtList(s.Body.List, fi)
+	case *ast.TypeSwitchStmt:
+		r.walkStmtList(s.Body.List, fi)
+	case *ast.CaseClause:
+		r.walkStmtList(s.Body, fi)
+	case *ast.SelectStmt:
+		r.walkStmtList(s.Body.List, fi)
+	case *ast.CommClause:
+		r.walkStmtList(s.Body, fi)
+	case *ast.GoStmt, *ast.ReturnStmt:
+		// Goroutine launches start fresh — don't track through them.
+		// Return statements: nothing to track.
+	}
+}
+
+// walkExpr records expr in fi if it is a call: as a LockOp when it is a
+// Lock/Unlock/RLock/RUnlock on a tracked lock, otherwise as a call to a
+// function or method declared in r.pkg. Non-call expressions are ignored.
+func (r *resolver) walkExpr(expr ast.Expr, fi *FuncInfo, inDefer bool) {
+	call, ok := expr.(*ast.CallExpr)
+	if !ok {
+		return
+	}
+
+	// Check for immediately-invoked function literals: func() { ... }()
+	if funcLit, ok := call.Fun.(*ast.FuncLit); ok {
+		r.walkStmtList(funcLit.Body.List, fi) // NOTE(review): inDefer is reset to false inside the literal
+		return
+	}
+
+	sel, ok := call.Fun.(*ast.SelectorExpr)
+	if !ok {
+		// Non-selector call: could be a package-level function
+		if ident, ok := call.Fun.(*ast.Ident); ok {
+			obj := r.info.ObjectOf(ident)
+			if fn, ok := obj.(*types.Func); ok && fn.Pkg() == r.pkg {
+				fi.Calls = append(fi.Calls, CallUnderLock{
+					Callee: fn,
+					Pos:    call.Pos(),
+				})
+			}
+		}
+		return
+	}
+
+	methodName := sel.Sel.Name
+
+	// Is this a lock/unlock operation?
+	if isLockMethod(methodName) {
+		if lockID, kind := r.identifyLock(sel.X, methodName); lockID != "" {
+			fi.LockOps = append(fi.LockOps, LockOp{
+				Lock:     lockID,
+				Kind:     kind,
+				IsUnlock: methodName == "Unlock" || methodName == "RUnlock",
+				IsDefer:  inDefer,
+				Pos:      call.Pos(),
+			})
+			return
+		}
+	}
+
+	// Otherwise, it's a method call — track as a potential callee
+	selObj := r.info.ObjectOf(sel.Sel)
+	if fn, ok := selObj.(*types.Func); ok && fn.Pkg() == r.pkg {
+		fi.Calls = append(fi.Calls, CallUnderLock{
+			Callee: fn,
+			Pos:    call.Pos(),
+		})
+	}
+}
+
+// isLockMethod returns true for sync.Mutex/RWMutex method names.
+func isLockMethod(name string) bool {
+	return name == "Lock" || name == "Unlock" || name == "RLock" || name == "RUnlock"
+}
+
+// identifyLock resolves the receiver expression of a lock method call to a LockID.
+// Returns ("", 0) if the lock is not in our tracked set.
+func (r *resolver) identifyLock(receiver ast.Expr, methodName string) (LockID, LockKind) {
+	// Reduce the receiver to (owner type, field path), resolving aliases.
+	ownerType, fieldPath := r.resolveReceiver(receiver)
+	if ownerType == "" {
+		return "", 0
+	}
+
+	key := ownerType + "." + fieldPath
+	lockID, ok := r.registry[key]
+	if !ok {
+		return "", 0
+	}
+
+	if methodName == "RLock" || methodName == "RUnlock" {
+		return lockID, Shared
+	}
+	return lockID, Exclusive
+}
+
+// resolveReceiver walks a selector expression chain to extract (ownerType, fieldPath).
+// For example, device.staticIdentity → ("Device", "staticIdentity")
+// For peer.handshake.mutex → ("Peer", "handshake.mutex")
+// Returns ("", "") if expr is not a chain rooted at an alias or a recognized-type variable.
+func (r *resolver) resolveReceiver(expr ast.Expr) (ownerType string, fieldPath string) {
+	var fields []string
+	cur := expr
+
+	for {
+		switch e := cur.(type) {
+		case *ast.SelectorExpr:
+			fields = append(fields, e.Sel.Name)
+			cur = e.X
+		case *ast.Ident:
+			// Check if this ident has an alias
+			if resolved, ok := r.aliases[e.Name]; ok {
+				// The alias is a string like "Peer.handshake": the owner
+				// type, then an optional field path in outer→inner order.
+				owner, aliasPath, hasPath := strings.Cut(resolved, ".")
+				// Our own fields were collected inner→outer; flip them first.
+				reverse(fields)
+				if hasPath {
+					// Prepend the alias path as-is. It must not take part in
+					// the reversal, or a multi-segment path comes out flipped.
+					fields = append(strings.Split(aliasPath, "."), fields...)
+				}
+				return owner, strings.Join(fields, ".")
+			}
+
+			// Not an alias — check the variable's type
+			ownerType = r.typeName(e)
+			if ownerType == "" {
+				return "", ""
+			}
+			reverse(fields)
+			return ownerType, strings.Join(fields, ".")
+		case *ast.ParenExpr:
+			cur = e.X
+		case *ast.StarExpr:
+			cur = e.X
+		default:
+			return "", ""
+		}
+	}
+}
+
+// typeName returns the name of ident's named type (after stripping one pointer)
+// if it is on typeName's list of recognized owner types; it returns "" otherwise.
+func (r *resolver) typeName(ident *ast.Ident) string {
+	obj := r.info.ObjectOf(ident)
+	if obj == nil {
+		return ""
+	}
+	t := obj.Type()
+	// Dereference one level of pointer
+	if ptr, ok := t.(*types.Pointer); ok {
+		t = ptr.Elem()
+	}
+	if named, ok := t.(*types.Named); ok {
+		name := named.Obj().Name() // bare type name; the declaring package is not checked
+		switch name {
+		case "Device", "Peer", "AllowedIPs", "IndexTable",
+			"Handshake", "Keypairs",
+			"CookieChecker", "CookieGenerator",
+			"Timer", "WaitPool":
+			return name
+		}
+	}
+	return ""
+}
+
+// checkAlias records simple variable alias patterns:
+//
+//	x := &y.field → aliases[x] = resolved(y.field)
+//	x := y.field  → aliases[x] = resolved(y) + "." + field
+func (r *resolver) checkAlias(assign *ast.AssignStmt) {
+	if len(assign.Lhs) != 1 || len(assign.Rhs) != 1 {
+		return // only single-value assignments are considered
+	}
+	lhs, ok := assign.Lhs[0].(*ast.Ident)
+	if !ok {
+		return // the left-hand side must be a plain identifier
+	}
+
+	rhs := assign.Rhs[0]
+
+	// Pattern: x := &y.field
+	if unary, ok := rhs.(*ast.UnaryExpr); ok && unary.Op == token.AND {
+		ownerType, fieldPath := r.resolveReceiver(unary.X)
+		if ownerType != "" {
+			if fieldPath != "" {
+				r.aliases[lhs.Name] = ownerType + "." + fieldPath
+			} else {
+				r.aliases[lhs.Name] = ownerType
+			}
+		}
+		return
+	}
+
+	// Pattern: x := y.field (direct field access without &)
+	if sel, ok := rhs.(*ast.SelectorExpr); ok {
+		ownerType, fieldPath := r.resolveReceiver(sel.X)
+		if ownerType != "" {
+			path := fieldPath
+			if path != "" {
+				path += "."
+			}
+			path += sel.Sel.Name
+			r.aliases[lhs.Name] = ownerType + "." + path
+		}
+	}
+}
+
+func reverse(s []string) {
+	for i, j := 0, len(s)-1; i < j; i, j = i+1, j-1 {
+		s[i], s[j] = s[j], s[i]
+	}
+}
diff --git a/cmd/check-lockorder/verbose.go b/cmd/check-lockorder/verbose.go
new file mode 100644
index 000000000..da72dee02
--- /dev/null
+++ b/cmd/check-lockorder/verbose.go
@@ -0,0 +1,69 @@
+package main
+
+import (
+	"fmt"
+	"go/token"
+	"io"
+	"sort"
+)
+
+// printVerbose writes the full lock inventory and every observed lock-after
+// edge to w. Edges are deduplicated by (from, fromKind, to, toKind); for
+// each unique pair the shortest attribution chain is shown.
+func printVerbose(w io.Writer, edges []Edge, fset *token.FileSet) {
+	fmt.Fprintf(w, "== Lock inventory (%d) ==\n", len(trackedLocks))
+	for _, tl := range trackedLocks {
+		kind := "Mutex"
+		if tl.Kind == ReadWriteMutex {
+			kind = "RWMutex"
+		}
+		def := tl.DefType
+		if tl.DefPath != "" {
+			def += "." + tl.DefPath
+		}
+		note := ""
+		if tl.InstanceLocal {
+			note = " [instance-local; cycles not checked]"
+		}
+		fmt.Fprintf(w, " %-32s %-7s defined at %s%s\n", tl.ID, kind, def, note)
+	}
+	fmt.Fprintln(w)
+
+	type edgeKey struct {
+		from, to         LockID
+		fromKind, toKind LockKind
+	}
+	best := map[edgeKey]Edge{}
+	for _, e := range edges {
+		k := edgeKey{e.From, e.To, e.FromKind, e.ToKind}
+		if cur, ok := best[k]; !ok || len(e.Chain) < len(cur.Chain) {
+			best[k] = e
+		}
+	}
+
+	keys := make([]edgeKey, 0, len(best))
+	for k := range best {
+		keys = append(keys, k)
+	}
+	sort.Slice(keys, func(i, j int) bool {
+		if keys[i].from != keys[j].from {
+			return keys[i].from < keys[j].from
+		}
+		if keys[i].fromKind != keys[j].fromKind {
+			return keys[i].fromKind < keys[j].fromKind
+		}
+		if keys[i].to != keys[j].to {
+			return keys[i].to < keys[j].to
+		}
+		return keys[i].toKind < keys[j].toKind
+	})
+
+	fmt.Fprintf(w, "== Lock-after edges (%d unique pairs from %d observations) ==\n",
+		len(best), len(edges))
+	for _, k := range keys {
+		e := best[k]
+		fmt.Fprintf(w, "\n%s.%s -> %s.%s\n",
+			k.from, k.fromKind, k.to, k.toKind)
+		fmt.Fprintln(w, formatChain(e.Chain, fset))
+	}
+}
diff --git a/device/channels.go b/device/channels.go index e526f6bb1..bdba7dee2 100644 --- a/device/channels.go +++ b/device/channels.go @@ -91,7 +91,7 @@ func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) { for { select { case elemsContainer := <-q.c: - elemsContainer.Lock() + elemsContainer.filling.Wait() for _, elem := range elemsContainer.elems {
device.PutMessageBuffer(elem.buffer) device.PutInboundElement(elem) @@ -124,7 +124,7 @@ func (device *Device) flushOutboundQueue(q *autodrainingOutboundQueue) { for { select { case elemsContainer := <-q.c: - elemsContainer.Lock() + elemsContainer.filling.Wait() for _, elem := range elemsContainer.elems { device.PutMessageBuffer(elem.buffer) device.PutOutboundElement(elem) diff --git a/device/lock-ordering.md b/device/lock-ordering.md deleted file mode 100644 index 55a15c0b7..000000000 --- a/device/lock-ordering.md +++ /dev/null @@ -1,27 +0,0 @@ -# Lock Ordering in wireguard-go/device - -## Lock hierarchy - -Locks must be acquired in the order listed below. A goroutine holding a -lock with a higher number must never attempt to acquire a lock with a -lower number. - -``` -Level 0 device.state.Mutex -Level 1 device.ipcMutex (sync.RWMutex) -Level 2 device.net.RWMutex -Level 3 device.staticIdentity.RWMutex -Level 4 device.peers.RWMutex -Level 5 peer.state.Mutex -Level 6 peer.handshake.mutex (sync.RWMutex) -Level 7 peer.keypairs.RWMutex -Level 8 device.allowedips.mu (sync.RWMutex) -Level 9 device.indexTable.RWMutex -Level 10 peer.endpoint.Mutex -Level 11 device.cookieChecker.RWMutex -Level 12 peer.cookieGenerator.RWMutex -Level 13 Timer.modifyingLock / Timer.runningLock -``` - -Not every pair of locks appears in practice; the ordering above is the -transitive closure of the pairs that do. 
diff --git a/device/pools.go b/device/pools.go index 55d2be7df..47f952b3c 100644 --- a/device/pools.go +++ b/device/pools.go @@ -68,7 +68,6 @@ func (device *Device) PopulatePools() { func (device *Device) GetInboundElementsContainer() *QueueInboundElementsContainer { c := device.pool.inboundElementsContainer.Get().(*QueueInboundElementsContainer) - c.Mutex = sync.Mutex{} return c } @@ -82,7 +81,6 @@ func (device *Device) PutInboundElementsContainer(c *QueueInboundElementsContain func (device *Device) GetOutboundElementsContainer() *QueueOutboundElementsContainer { c := device.pool.outboundElementsContainer.Get().(*QueueOutboundElementsContainer) - c.Mutex = sync.Mutex{} return c } diff --git a/device/receive.go b/device/receive.go index 56cde1047..9fd2aec70 100644 --- a/device/receive.go +++ b/device/receive.go @@ -8,6 +8,7 @@ package device import ( "encoding/binary" "errors" + "fmt" "net" "net/netip" "sync" @@ -35,8 +36,13 @@ type QueueInboundElement struct { } type QueueInboundElementsContainer struct { - sync.Mutex - elems []*QueueInboundElement + // filling is a one-shot barrier signaling decryption→receive + // handoff. RoutineReceiveIncoming calls Add(1) before sending the + // container down the decryption and inbound queues; RoutineDecryption + // calls Done after decrypting; RoutineSequentialReceiver calls Wait + // before reading the decrypted packets. + filling sync.WaitGroup + elems []*QueueInboundElement } // clearPointers clears elem fields that contain pointers. 
@@ -178,7 +184,6 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive elemsForPeer, ok := elemsByPeer[peer] if !ok { elemsForPeer = device.GetInboundElementsContainer() - elemsForPeer.Lock() elemsByPeer[peer] = elemsForPeer } elemsForPeer.elems = append(elemsForPeer.elems, elem) @@ -222,6 +227,7 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive } for peer, elemsContainer := range elemsByPeer { if peer.isRunning.Load() { + elemsContainer.filling.Add(1) peer.queue.inbound.c <- elemsContainer device.queue.decryption.c <- elemsContainer } else { @@ -263,7 +269,7 @@ func (device *Device) RoutineDecryption(id int) { elem.packet = nil } } - elemsContainer.Unlock() + elemsContainer.filling.Done() } } @@ -441,102 +447,128 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) { if elemsContainer == nil { return } - elemsContainer.Lock() - validTailPacket := -1 - dataPacketReceived := false - rxBytesLen := uint64(0) - for i, elem := range elemsContainer.elems { - if elem.packet == nil { - // decryption failed - continue - } + peer.processInboundContainer(elemsContainer, bufs[:0]) + } +} + +// processInboundContainer waits for the decryption routine to finish +// filling elemsContainer, then writes the valid packets to the TUN +// device and returns the container to the pool. +// +// scratch is a length-0 slice used to assemble the per-packet buffers +// passed to tun.device.Write; its backing array is reused across calls. +func (peer *Peer) processInboundContainer(elemsContainer *QueueInboundElementsContainer, scratch [][]byte) { + // Invariants from RoutineSequentialReceiver; all should be unreachable. 
+ if len(scratch) != 0 || cap(scratch) == 0 { + panic(fmt.Sprintf("processInboundContainer: scratch must be empty with non-zero cap; got len=%d cap=%d", + len(scratch), cap(scratch))) + } + if cap(scratch) < len(elemsContainer.elems) { + panic(fmt.Sprintf("processInboundContainer: scratch cap %d < elems %d", + cap(scratch), len(elemsContainer.elems))) + } - if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) { + device := peer.device + defer device.PutInboundElementsContainer(elemsContainer) + + // Wait for RoutineDecryption to finish filling the container. After + // Wait returns we have happens-before with that goroutine and are the + // sole owner of the container until Put hands it back to the pool. + elemsContainer.filling.Wait() + elems := elemsContainer.elems + + validTailPacket := -1 + dataPacketReceived := false + rxBytesLen := uint64(0) + for i, elem := range elems { + if elem.packet == nil { + // decryption failed + continue + } + + if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) { + continue + } + + validTailPacket = i + if peer.ReceivedWithKeypair(elem.keypair) { + peer.SetEndpointFromPacket(elem.endpoint) + peer.timersHandshakeComplete() + peer.SendStagedPackets() + } + if ep, ok := elem.endpoint.(conn.PeerAwareEndpoint); ok { + ep.FromPeer(peer.handshake.remoteStatic) + } + rxBytesLen += uint64(len(elem.packet) + MinMessageSize) + + if len(elem.packet) == 0 { + device.log.Verbosef("%v - Receiving keepalive packet", peer) + continue + } + dataPacketReceived = true + + switch elem.packet[0] >> 4 { + case 4: + if len(elem.packet) < ipv4.HeaderLen { continue } - - validTailPacket = i - if peer.ReceivedWithKeypair(elem.keypair) { - peer.SetEndpointFromPacket(elem.endpoint) - peer.timersHandshakeComplete() - peer.SendStagedPackets() + field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2] + length := binary.BigEndian.Uint16(field) + if int(length) > len(elem.packet) || int(length) 
< ipv4.HeaderLen { + continue } - if ep, ok := elem.endpoint.(conn.PeerAwareEndpoint); ok { - ep.FromPeer(peer.handshake.remoteStatic) + elem.packet = elem.packet[:length] + src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] + srcAddr, _ := netip.AddrFromSlice(src) + if !peer.AllowedPeerSourceIP(srcAddr) { + device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer) + continue } - rxBytesLen += uint64(len(elem.packet) + MinMessageSize) - if len(elem.packet) == 0 { - device.log.Verbosef("%v - Receiving keepalive packet", peer) + case 6: + if len(elem.packet) < ipv6.HeaderLen { continue } - dataPacketReceived = true - - switch elem.packet[0] >> 4 { - case 4: - if len(elem.packet) < ipv4.HeaderLen { - continue - } - field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2] - length := binary.BigEndian.Uint16(field) - if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen { - continue - } - elem.packet = elem.packet[:length] - src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] - srcAddr, _ := netip.AddrFromSlice(src) - if !peer.AllowedPeerSourceIP(srcAddr) { - device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer) - continue - } - - case 6: - if len(elem.packet) < ipv6.HeaderLen { - continue - } - field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2] - length := binary.BigEndian.Uint16(field) - length += ipv6.HeaderLen - if int(length) > len(elem.packet) { - continue - } - elem.packet = elem.packet[:length] - src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] - srcAddr, _ := netip.AddrFromSlice(src) - if !peer.AllowedPeerSourceIP(srcAddr) { - device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer) - continue - } - - default: - device.log.Verbosef("Packet with invalid IP version from %v", peer) + field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2] + length := binary.BigEndian.Uint16(field) + length += 
ipv6.HeaderLen + if int(length) > len(elem.packet) { + continue + } + elem.packet = elem.packet[:length] + src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] + srcAddr, _ := netip.AddrFromSlice(src) + if !peer.AllowedPeerSourceIP(srcAddr) { + device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer) continue } - bufs = append(bufs, elem.buffer[:MessageTransportOffsetContent+len(elem.packet)]) + default: + device.log.Verbosef("Packet with invalid IP version from %v", peer) + continue } - peer.rxBytes.Add(rxBytesLen) - if validTailPacket >= 0 { - peer.SetEndpointFromPacket(elemsContainer.elems[validTailPacket].endpoint) - peer.keepKeyFreshReceiving() - peer.timersAnyAuthenticatedPacketTraversal() - peer.timersAnyAuthenticatedPacketReceived() - } - if dataPacketReceived { - peer.timersDataReceived() - } - if len(bufs) > 0 { - _, err := device.tun.device.Write(bufs, MessageTransportOffsetContent) - if err != nil && !device.isClosed() { - device.log.Errorf("Failed to write packets to TUN device: %v", err) - } - } - for _, elem := range elemsContainer.elems { - device.PutMessageBuffer(elem.buffer) - device.PutInboundElement(elem) + scratch = append(scratch, elem.buffer[:MessageTransportOffsetContent+len(elem.packet)]) + } + + peer.rxBytes.Add(rxBytesLen) + if validTailPacket >= 0 { + peer.SetEndpointFromPacket(elems[validTailPacket].endpoint) + peer.keepKeyFreshReceiving() + peer.timersAnyAuthenticatedPacketTraversal() + peer.timersAnyAuthenticatedPacketReceived() + } + if dataPacketReceived { + peer.timersDataReceived() + } + if len(scratch) > 0 { + _, err := device.tun.device.Write(scratch, MessageTransportOffsetContent) + if err != nil && !device.isClosed() { + device.log.Errorf("Failed to write packets to TUN device: %v", err) } - bufs = bufs[:0] - device.PutInboundElementsContainer(elemsContainer) + } + for _, elem := range elems { + device.PutMessageBuffer(elem.buffer) + device.PutInboundElement(elem) } } diff --git a/device/send.go 
b/device/send.go index 89269fc07..497c39930 100644 --- a/device/send.go +++ b/device/send.go @@ -8,6 +8,7 @@ package device import ( "encoding/binary" "errors" + "fmt" "net" "net/netip" "os" @@ -58,8 +59,13 @@ type QueueOutboundElement struct { } type QueueOutboundElementsContainer struct { - sync.Mutex - elems []*QueueOutboundElement + // filling is a one-shot barrier signaling encryption→send handoff. + // SendStagedPackets calls Add(1) before sending the container down + // the encryption and outbound queues; RoutineEncryption calls Done + // after encrypting; RoutineSequentialSender calls Wait before + // reading the encrypted packets. + filling sync.WaitGroup + elems []*QueueOutboundElement } func (device *Device) NewOutboundElement() *QueueOutboundElement { @@ -378,7 +384,6 @@ top: elem.keypair = keypair } - elemsContainer.Lock() elemsContainer.elems = elemsContainer.elems[:i] if elemsContainerOOO != nil { @@ -392,6 +397,7 @@ top: // add to parallel and sequential queue if peer.isRunning.Load() { + elemsContainer.filling.Add(1) peer.queue.outbound.c <- elemsContainer peer.device.queue.encryption.c <- elemsContainer } else { @@ -483,7 +489,7 @@ func (device *Device) RoutineEncryption(id int) { // re-slice packet to include encapsulating transport space elem.packet = elem.buffer[:MessageEncapsulatingTransportSize+len(elem.packet)] } - elemsContainer.Unlock() + elemsContainer.filling.Done() } } @@ -498,58 +504,82 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { bufs := make([][]byte, 0, maxBatchSize) for elemsContainer := range peer.queue.outbound.c { - bufs = bufs[:0] if elemsContainer == nil { return } - if !peer.isRunning.Load() { - // peer has been stopped; return re-usable elems to the shared pool. - // This is an optimization only. It is possible for the peer to be stopped - // immediately after this check, in which case, elem will get processed. - // The timers and SendBuffers code are resilient to a few stragglers. 
- // TODO: rework peer shutdown order to ensure - // that we never accidentally keep timers alive longer than necessary. - elemsContainer.Lock() - for _, elem := range elemsContainer.elems { - device.PutMessageBuffer(elem.buffer) - device.PutOutboundElement(elem) - } - device.PutOutboundElementsContainer(elemsContainer) - continue - } - dataSent := false - elemsContainer.Lock() - for _, elem := range elemsContainer.elems { - if len(elem.packet[MessageEncapsulatingTransportSize:]) != MessageKeepaliveSize { - dataSent = true - } - bufs = append(bufs, elem.packet) - } + peer.processOutboundContainer(elemsContainer, bufs[:0]) + } +} - peer.timersAnyAuthenticatedPacketTraversal() - peer.timersAnyAuthenticatedPacketSent() +// processOutboundContainer waits for the encryption routine to finish +// filling elemsContainer, then sends the batch (or drops it, if the peer +// has been stopped) and returns the container to the pool. +// +// scratch is a length-0 slice used to assemble the per-packet buffers +// passed to SendBuffers; its backing array is reused across calls. +func (peer *Peer) processOutboundContainer(elemsContainer *QueueOutboundElementsContainer, scratch [][]byte) { + // Invariants from RoutineSequentialSender; all should be unreachable. + if len(scratch) != 0 || cap(scratch) == 0 { + panic(fmt.Sprintf("processOutboundContainer: scratch must be empty with non-zero cap; got len=%d cap=%d", + len(scratch), cap(scratch))) + } + if cap(scratch) < len(elemsContainer.elems) { + panic(fmt.Sprintf("processOutboundContainer: scratch cap %d < elems %d", + cap(scratch), len(elemsContainer.elems))) + } - err := peer.SendBuffers(bufs) - if dataSent { - peer.timersDataSent() - } + device := peer.device + defer device.PutOutboundElementsContainer(elemsContainer) + + // Wait for RoutineEncryption to finish filling the container. After + // Wait returns we have happens-before with that goroutine and are the + // sole owner of the container until Put hands it back to the pool.
+ elemsContainer.filling.Wait() + + if !peer.isRunning.Load() { + // peer has been stopped; return re-usable elems to the shared pool. + // This is an optimization only. It is possible for the peer to be stopped + // immediately after this check, in which case, elem will get processed. + // The timers and SendBuffers code are resilient to a few stragglers. + // TODO: rework peer shutdown order to ensure + // that we never accidentally keep timers alive longer than necessary. for _, elem := range elemsContainer.elems { device.PutMessageBuffer(elem.buffer) device.PutOutboundElement(elem) } - device.PutOutboundElementsContainer(elemsContainer) - if err != nil { - var errGSO conn.ErrUDPGSODisabled - if errors.As(err, &errGSO) { - device.log.Verbosef(err.Error()) - err = errGSO.RetryErr - } - } - if err != nil { - device.log.Errorf("%v - Failed to send data packets: %v", peer, err) - continue + return + } + + dataSent := false + for _, elem := range elemsContainer.elems { + if len(elem.packet[MessageEncapsulatingTransportSize:]) != MessageKeepaliveSize { + dataSent = true } + scratch = append(scratch, elem.packet) + } - peer.keepKeyFreshSending() + peer.timersAnyAuthenticatedPacketTraversal() + peer.timersAnyAuthenticatedPacketSent() + + err := peer.SendBuffers(scratch) + if dataSent { + peer.timersDataSent() + } + for _, elem := range elemsContainer.elems { + device.PutMessageBuffer(elem.buffer) + device.PutOutboundElement(elem) } + if err != nil { + var errGSO conn.ErrUDPGSODisabled + if errors.As(err, &errGSO) { + device.log.Verbosef("%v", err) // err text may contain '%'; never use it as the format string + err = errGSO.RetryErr + } + } + if err != nil { + device.log.Errorf("%v - Failed to send data packets: %v", peer, err) + return + } + + peer.keepKeyFreshSending() } diff --git a/go.mod b/go.mod index 37b0b6daf..ddb479350 100644 --- a/go.mod +++ b/go.mod @@ -1,16 +1,19 @@ module github.com/tailscale/wireguard-go -go 1.25 +go 1.25.0 require ( - golang.org/x/crypto v0.13.0 - golang.org/x/net v0.15.0 - golang.org/x/sys v0.12.0
+ golang.org/x/crypto v0.50.0 + golang.org/x/net v0.53.0 + golang.org/x/sys v0.43.0 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 ) require ( github.com/google/btree v1.0.1 // indirect + golang.org/x/mod v0.35.0 // indirect + golang.org/x/sync v0.20.0 // indirect golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 // indirect + golang.org/x/tools v0.44.0 // indirect ) diff --git a/go.sum b/go.sum index 6bcecea3f..9821ebb23 100644 --- a/go.sum +++ b/go.sum @@ -2,12 +2,24 @@ github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4= github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA= golang.org/x/crypto v0.13.0 h1:mvySKfSWJ+UKUii46M40LOvyWfN0s2U+46/jDd0e6Ck= golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= +golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI= +golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q= +golang.org/x/mod v0.35.0 h1:Ww1D637e6Pg+Zb2KrWfHQUnH2dQRLBQyAtpr/haaJeM= +golang.org/x/mod v0.35.0/go.mod h1:+GwiRhIInF8wPm+4AoT6L0FA1QWAad3OMdTRx4tFYlU= golang.org/x/net v0.15.0 h1:ugBLEUaxABaB5AJqW9enI0ACdci2RUd4eP51NTBvuJ8= golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= +golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA= +golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= +golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/time 
v0.0.0-20220210224613-90d013bbcef8 h1:vVKdlvoWBphwdxWKrFZEuM0kGgGLxUOYcY4U/2Vjg44= golang.org/x/time v0.0.0-20220210224613-90d013bbcef8/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/tools v0.44.0 h1:UP4ajHPIcuMjT1GqzDWRlalUEoY+uzoZKnhOjbIPD2c= +golang.org/x/tools v0.44.0/go.mod h1:KA0AfVErSdxRZIsOVipbv3rQhVXTnlU6UhKxHd1seDI= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ=