diff --git a/daemon/listen.go b/daemon/listen.go new file mode 100644 index 0000000..981dd23 --- /dev/null +++ b/daemon/listen.go @@ -0,0 +1,165 @@ +package daemon + +import ( + "context" + "errors" + "fmt" + "net" + "os" + "path/filepath" + "runtime" + "time" +) + +const defaultStaleSocketProbeTimeout = 500 * time.Millisecond + +type listenOptions struct { + // Store provides the daemon runtime listen lock used to serialize Unix + // socket stale cleanup and bind. Ignored when LockPath is set. + store RuntimeStore + + // LockPath overrides the lock file used for Unix socket startup. + lockPath string + + // StaleSocketProbeTimeout bounds the local dial used to prove that an + // existing Unix socket path is stale before removing it. + staleSocketProbeTimeout time.Duration +} + +// ListenOption configures Listen. +type ListenOption func(*listenOptions) + +// WithRuntimeStore uses store's daemon listen lock to serialize Unix socket +// stale cleanup and bind. Ignored when WithListenLockPath is also supplied. +func WithRuntimeStore(store RuntimeStore) ListenOption { + return func(opts *listenOptions) { + opts.store = store + } +} + +// WithListenLockPath overrides the lock file used for Unix socket startup. +func WithListenLockPath(lockPath string) ListenOption { + return func(opts *listenOptions) { + opts.lockPath = lockPath + } +} + +// WithStaleSocketProbeTimeout bounds the local dial used to prove that an +// existing Unix socket path is stale before removing it. +func WithStaleSocketProbeTimeout(timeout time.Duration) ListenOption { + return func(opts *listenOptions) { + opts.staleSocketProbeTimeout = timeout + } +} + +// Listen binds ep for daemon serving. +// +// For Unix sockets, Listen serializes stale socket probing/removal and the +// subsequent bind under an inter-process lock. Existing live sockets and +// non-socket paths are rejected. TCP endpoints and Windows retain Endpoint's +// normal Listen behavior. +func Listen(ctx context.Context, ep Endpoint, options ...ListenOption) (net.Listener, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + opts := listenOptions{} + for _, option := range options { + option(&opts) + } + if !ep.IsUnix() || runtime.GOOS == "windows" { + return ep.Listen() + } + if err := prepareUnixListenEndpoint(ep); err != nil { + return nil, err + } + lockPath, err := opts.listenLockPath(ep) + if err != nil { + return nil, err + } + unlock, err := acquireDaemonLock(ctx, lockPath, "acquire daemon listen lock") + if err != nil { + return nil, err + } + defer unlock() + if err := removeStaleUnixSocket(ctx, ep, opts); err != nil { + return nil, err + } + return ep.Listen() +} + +func prepareUnixListenEndpoint(ep Endpoint) error { + if ep.Address == "" { + return fmt.Errorf("empty daemon endpoint address") + } + if !filepath.IsAbs(ep.Address) { + return fmt.Errorf("unix socket path %q must be absolute", ep.Address) + } + if err := validatePrivateRuntimeDir(filepath.Dir(ep.Address)); err != nil { + return fmt.Errorf("validate unix socket dir: %w", err) + } + return nil +} + +func (opts listenOptions) listenLockPath(ep Endpoint) (string, error) { + if opts.lockPath != "" { + return opts.lockPath, nil + } + if opts.store.Dir != "" { + return opts.store.ListenLockPath() + } + if ep.Address == "" { + return "", fmt.Errorf("empty daemon endpoint address") + } + return ep.Address + ".lock", nil +} + +func (opts listenOptions) staleProbeTimeout() time.Duration { + if opts.staleSocketProbeTimeout > 0 { + return opts.staleSocketProbeTimeout + } + return defaultStaleSocketProbeTimeout +} + +func removeStaleUnixSocket(ctx context.Context, ep Endpoint, opts listenOptions) error { + if ep.Address == "" { + return fmt.Errorf("empty daemon endpoint address") + } + info, err := os.Lstat(ep.Address) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return nil + } + return fmt.Errorf("inspect unix socket %s: %w", ep.Address, err) + } + if info.Mode()&os.ModeSocket == 0 { + return fmt.Errorf("refusing to remove non-socket path %s", ep.Address) + } + stale, err := unixSocketStale(ctx, ep.Address, opts.staleProbeTimeout()) + if err != nil { + return err + } + if !stale { + return fmt.Errorf("daemon already listening on unix socket %s", ep.Address) + } + if err := os.Remove(ep.Address); err != nil && !errors.Is(err, os.ErrNotExist) { + return fmt.Errorf("remove stale unix socket %s: %w", ep.Address, err) + } + return nil +} + +func unixSocketStale(ctx context.Context, path string, timeout time.Duration) (bool, error) { + probeCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + conn, err := (&net.Dialer{}).DialContext(probeCtx, NetworkUnix, path) + if err == nil { + _ = conn.Close() + return false, nil + } + if ctxErr := probeCtx.Err(); ctxErr != nil { + return false, fmt.Errorf("probe unix socket %s: %w", path, ctxErr) + } + if isStaleUnixSocketDialError(err) { + return true, nil + } + return false, fmt.Errorf("probe unix socket %s: %w", path, err) +} diff --git a/daemon/listen_internal_unix_test.go b/daemon/listen_internal_unix_test.go new file mode 100644 index 0000000..e406883 --- /dev/null +++ b/daemon/listen_internal_unix_test.go @@ -0,0 +1,25 @@ +//go:build !windows + +package daemon + +import ( + "context" + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestUnixSocketStaleTreatsMissingSocketAsStale(t *testing.T) { + dir, err := os.MkdirTemp("/tmp", "kitd-probe") + require.NoError(t, err) + t.Cleanup(func() { _ = os.RemoveAll(dir) }) + + stale, err := unixSocketStale(context.Background(), filepath.Join(dir, "missing.sock"), 50*time.Millisecond) + + require.NoError(t, err) + assert.True(t, stale) +} diff --git a/daemon/listen_unix.go b/daemon/listen_unix.go new file mode 100644 index 0000000..ddd6f1a --- /dev/null +++ b/daemon/listen_unix.go @@ -0,0 +1,12 @@ +//go:build !windows + +package daemon + +import ( + "errors" + "syscall" +) + +func isStaleUnixSocketDialError(err error) bool { + return errors.Is(err, syscall.ECONNREFUSED) || errors.Is(err, syscall.ENOENT) +} diff --git a/daemon/listen_unix_test.go b/daemon/listen_unix_test.go new file mode 100644 index 0000000..d05d8f6 --- /dev/null +++ b/daemon/listen_unix_test.go @@ -0,0 +1,227 @@ +//go:build !windows + +package daemon_test + +import ( + "context" + "net" + "os" + "path/filepath" + "strings" + "syscall" + "testing" + "time" + + "github.com/gofrs/flock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.kenn.io/kit/daemon" +) + +func TestListenUnixRemovesStaleSocketAndBinds(t *testing.T) { + socketPath := staleUnixSocket(t) + ep := daemon.Endpoint{Network: daemon.NetworkUnix, Address: socketPath} + + listener, err := daemon.Listen(context.Background(), ep) + require.NoError(t, err) + t.Cleanup(func() { _ = listener.Close() }) + + conn, err := net.DialTimeout(daemon.NetworkUnix, socketPath, time.Second) + require.NoError(t, err) + _ = conn.Close() +} + +func TestListenUnixRejectsNonSocketPath(t *testing.T) { + socketPath := unixSocketPath(t) + require.NoError(t, os.WriteFile(socketPath, []byte("not a socket"), 0o600)) + ep := daemon.Endpoint{Network: daemon.NetworkUnix, Address: socketPath} + + listener, err := daemon.Listen(context.Background(), ep) + require.Error(t, err) + assert.Nil(t, listener) + assert.Contains(t, err.Error(), "refusing to remove non-socket path") + + body, readErr := os.ReadFile(socketPath) + require.NoError(t, readErr) + assert.Equal(t, "not a socket", string(body)) +} + +func TestListenUnixRejectsLiveSocket(t *testing.T) { + socketPath := unixSocketPath(t) + live, err := net.Listen(daemon.NetworkUnix, socketPath) + require.NoError(t, err) + t.Cleanup(func() { _ = live.Close() }) + ep := daemon.Endpoint{Network: daemon.NetworkUnix, Address: socketPath} + + listener, err := daemon.Listen(context.Background(), ep) + require.Error(t, err) + assert.Nil(t, listener) + assert.Contains(t, err.Error(), "daemon already listening") + + conn, err := net.DialTimeout(daemon.NetworkUnix, socketPath, time.Second) + require.NoError(t, err) + _ = conn.Close() +} + +func TestListenUnixSerializesConcurrentStaleSocketStartup(t *testing.T) { + socketPath := staleUnixSocket(t) + lockPath := filepath.Join(filepath.Dir(socketPath), "daemon.lock") + ep := daemon.Endpoint{Network: daemon.NetworkUnix, Address: socketPath} + opt := daemon.WithListenLockPath(lockPath) + + const starters = 16 + start := make(chan struct{}) + results := make(chan listenResult, starters) + for range starters { + go func() { + <-start + listener, err := daemon.Listen(context.Background(), ep, opt) + results <- listenResult{listener: listener, err: err} + }() + } + close(start) + + var winner net.Listener + var errors []error + for range starters { + result := <-results + if result.err == nil { + require.Nil(t, winner, "only one daemon start should bind the socket") + winner = result.listener + continue + } + errors = append(errors, result.err) + } + require.NotNil(t, winner) + t.Cleanup(func() { _ = winner.Close() }) + require.Len(t, errors, starters-1) + for _, err := range errors { + assert.True(t, + strings.Contains(err.Error(), "daemon already listening") || + strings.Contains(err.Error(), "bind: address already in use"), + "unexpected listen error: %v", err) + } + + conn, err := net.DialTimeout(daemon.NetworkUnix, socketPath, time.Second) + require.NoError(t, err) + _ = conn.Close() +} + +func TestListenUnixProbesAfterAcquiringLock(t *testing.T) { + socketPath := staleUnixSocket(t) + lockPath := filepath.Join(filepath.Dir(socketPath), "daemon.lock") + heldLock := flock.New(lockPath) + require.NoError(t, heldLock.Lock()) + locked := true + t.Cleanup(func() { + if locked { + _ = heldLock.Unlock() + } + }) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + ep := daemon.Endpoint{Network: daemon.NetworkUnix, Address: socketPath} + resultCh := make(chan listenResult, 1) + go func() { + listener, err := daemon.Listen(ctx, ep, daemon.WithListenLockPath(lockPath)) + resultCh <- listenResult{listener: listener, err: err} + }() + + require.NoError(t, os.Remove(socketPath)) + live, err := net.Listen(daemon.NetworkUnix, socketPath) + require.NoError(t, err) + t.Cleanup(func() { _ = live.Close() }) + + require.NoError(t, heldLock.Unlock()) + locked = false + result := <-resultCh + require.Error(t, result.err) + assert.Nil(t, result.listener) + assert.Contains(t, result.err.Error(), "daemon already listening") + + conn, err := net.DialTimeout(daemon.NetworkUnix, socketPath, time.Second) + require.NoError(t, err) + _ = conn.Close() +} + +func TestListenUnixRejectsUnsafeLockDirectory(t *testing.T) { + socketPath := staleUnixSocket(t) + base, err := os.MkdirTemp("/tmp", "kitd-lock") + require.NoError(t, err) + t.Cleanup(func() { _ = os.RemoveAll(base) }) + target := filepath.Join(base, "target") + link := filepath.Join(base, "link") + require.NoError(t, os.MkdirAll(target, 0o700)) + require.NoError(t, os.Symlink(target, link)) + + ep := daemon.Endpoint{Network: daemon.NetworkUnix, Address: socketPath} + listener, err := daemon.Listen(context.Background(), ep, daemon.WithListenLockPath(filepath.Join(link, "daemon.lock"))) + + require.Error(t, err) + assert.Nil(t, listener) + assert.Contains(t, err.Error(), "prepare daemon lock dir") + assert.Contains(t, err.Error(), "symlink") + _, statErr := os.Lstat(socketPath) + require.NoError(t, statErr, "stale socket should not be touched when lock dir is unsafe") +} + +func TestListenUnixRejectsRelativeLockPath(t *testing.T) { + ep := daemon.Endpoint{Network: daemon.NetworkUnix, Address: unixSocketPath(t)} + listener, err := daemon.Listen(context.Background(), ep, daemon.WithListenLockPath("daemon.lock")) + + require.Error(t, err) + assert.Nil(t, listener) + assert.Contains(t, err.Error(), "daemon lock path") + assert.Contains(t, err.Error(), "must be absolute") +} + +func TestListenUnixRejectsRelativeSocketPath(t *testing.T) { + lockPath := filepath.Join(t.TempDir(), "daemon.lock") + ep := daemon.Endpoint{Network: daemon.NetworkUnix, Address: "daemon.sock"} + listener, err := daemon.Listen(context.Background(), ep, daemon.WithListenLockPath(lockPath)) + + require.Error(t, err) + assert.Nil(t, listener) + assert.Contains(t, err.Error(), "unix socket path") + assert.Contains(t, err.Error(), "must be absolute") +} + +func TestListenUnixRejectsSharedSocketDirectoryEvenWithStoreLock(t *testing.T) { + socketPath := filepath.Join("/tmp", "kitd-shared-socket.sock") + t.Cleanup(func() { _ = os.Remove(socketPath) }) + ep := daemon.Endpoint{Network: daemon.NetworkUnix, Address: socketPath} + listener, err := daemon.Listen(context.Background(), ep, daemon.WithRuntimeStore(daemon.RuntimeStore{Dir: t.TempDir()})) + + require.Error(t, err) + assert.Nil(t, listener) + assert.Contains(t, err.Error(), "validate unix socket dir") + _, statErr := os.Lstat(socketPath) + assert.True(t, os.IsNotExist(statErr), "socket in shared dir should not be created: %v", statErr) +} + +type listenResult struct { + listener net.Listener + err error +} + +func staleUnixSocket(t *testing.T) string { + t.Helper() + socketPath := unixSocketPath(t) + fd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM, 0) + require.NoError(t, err) + defer func() { _ = syscall.Close(fd) }() + require.NoError(t, syscall.Bind(fd, &syscall.SockaddrUnix{Name: socketPath})) + if _, err := os.Lstat(socketPath); err != nil { + t.Fatalf("bound unix socket did not leave a socket path: %v", err) + } + return socketPath +} + +func unixSocketPath(t *testing.T) string { + t.Helper() + dir, err := os.MkdirTemp("/tmp", "kitd") + require.NoError(t, err) + t.Cleanup(func() { _ = os.RemoveAll(dir) }) + return filepath.Join(dir, "d.sock") +} diff --git a/daemon/listen_windows.go b/daemon/listen_windows.go new file mode 100644 index 0000000..883ae10 --- /dev/null +++ b/daemon/listen_windows.go @@ -0,0 +1,7 @@ +//go:build windows + +package daemon + +func isStaleUnixSocketDialError(error) bool { + return false +} diff --git a/daemon/lock.go b/daemon/lock.go new file mode 100644 index 0000000..da3cc88 --- /dev/null +++ b/daemon/lock.go @@ -0,0 +1,59 @@ +package daemon + +import ( + "context" + "errors" + "fmt" + "path/filepath" + "sync" + "time" + + "github.com/gofrs/flock" + "golang.org/x/sync/semaphore" +) + +var daemonLocks sync.Map + +// daemonLockRetryDelay is the poll interval; the caller context bounds total wait. +const daemonLockRetryDelay = 50 * time.Millisecond + +type daemonLock struct { + local *semaphore.Weighted + file *flock.Flock +} + +func acquireDaemonLock(ctx context.Context, lockPath, action string) (func(), error) { + if lockPath == "" { + return nil, fmt.Errorf("%s: empty daemon lock path", action) + } + if !filepath.IsAbs(lockPath) { + return nil, fmt.Errorf("%s: daemon lock path %q must be absolute", action, lockPath) + } + if err := ensurePrivateRuntimeDir(filepath.Dir(lockPath)); err != nil { + return nil, fmt.Errorf("prepare daemon lock dir: %w", err) + } + value, _ := daemonLocks.LoadOrStore(lockPath, &daemonLock{ + local: semaphore.NewWeighted(1), + file: flock.New(lockPath), + }) + lock := value.(*daemonLock) + if err := lock.local.Acquire(ctx, 1); err != nil { + return nil, fmt.Errorf("%s: %w", action, err) + } + locked, err := lock.file.TryLockContext(ctx, daemonLockRetryDelay) + if err != nil { + lock.local.Release(1) + return nil, fmt.Errorf("%s: %w", action, err) + } + if !locked { + lock.local.Release(1) + if err := ctx.Err(); err != nil { + return nil, fmt.Errorf("%s: %w", action, err) + } + return nil, errors.New(action + ": lock not acquired") + } + return func() { + _ = lock.file.Unlock() + lock.local.Release(1) + }, nil +} diff --git a/daemon/manager.go b/daemon/manager.go index f92f1f7..3a8c165 100644 --- a/daemon/manager.go +++ b/daemon/manager.go @@ -4,25 +4,9 @@ import ( "context" "errors" "fmt" - "os" - "path/filepath" - "sync" "time" - - "github.com/gofrs/flock" - "golang.org/x/sync/semaphore" ) -var startLocks sync.Map - -// startLockRetryDelay is the poll interval; the caller context bounds total wait. -const startLockRetryDelay = 50 * time.Millisecond - -type startLock struct { - local *semaphore.Weighted - file *flock.Flock -} - // CompatibleFunc returns true when a discovered daemon can serve this client. type CompatibleFunc func(RuntimeRecord, PingInfo) bool @@ -102,31 +86,5 @@ func (m Manager) lockStart(ctx context.Context) (func(), error) { if err != nil { return nil, err } - if err := os.MkdirAll(filepath.Dir(lockPath), 0o700); err != nil { - return nil, fmt.Errorf("mkdir daemon lock dir: %w", err) - } - value, _ := startLocks.LoadOrStore(lockPath, &startLock{ - local: semaphore.NewWeighted(1), - file: flock.New(lockPath), - }) - lock := value.(*startLock) - if err := lock.local.Acquire(ctx, 1); err != nil { - return nil, fmt.Errorf("acquire daemon start lock: %w", err) - } - locked, err := lock.file.TryLockContext(ctx, startLockRetryDelay) - if err != nil { - lock.local.Release(1) - return nil, fmt.Errorf("acquire daemon start lock: %w", err) - } - if !locked { - lock.local.Release(1) - if err := ctx.Err(); err != nil { - return nil, fmt.Errorf("acquire daemon start lock: %w", err) - } - return nil, errors.New("daemon start lock not acquired") - } - return func() { - _ = lock.file.Unlock() - lock.local.Release(1) - }, nil + return acquireDaemonLock(ctx, lockPath, "acquire daemon start lock") } diff --git a/daemon/runtime.go b/daemon/runtime.go index da19c5b..62ed1c2 100644 --- a/daemon/runtime.go +++ b/daemon/runtime.go @@ -79,6 +79,9 @@ func (s RuntimeStore) prepareDir() error { if s.Dir == "" { return fmt.Errorf("runtime dir is empty") } + if !filepath.IsAbs(s.Dir) { + return fmt.Errorf("runtime dir %q must be absolute", s.Dir) + } if err := ensurePrivateRuntimeDir(s.Dir); err != nil { return fmt.Errorf("prepare runtime dir: %w", err) } @@ -109,6 +112,19 @@ func (s RuntimeStore) LockPath() (string, error) { return filepath.Join(s.Dir, prefix+".lock"), nil } +// ListenLockPath returns the path used to serialize daemon server listen setup +// for the store. +func (s RuntimeStore) ListenLockPath() (string, error) { + prefix, err := s.validatePrefix() + if err != nil { + return "", err + } + if err := s.prepareDir(); err != nil { + return "", err + } + return filepath.Join(s.Dir, prefix+".listen.lock"), nil +} + // Write saves rec atomically and returns the final path. func (s RuntimeStore) Write(rec RuntimeRecord) (string, error) { if err := s.prepareDir(); err != nil { diff --git a/daemon/runtime_test.go b/daemon/runtime_test.go index e8ef707..90b7d49 100644 --- a/daemon/runtime_test.go +++ b/daemon/runtime_test.go @@ -82,3 +82,26 @@ func TestRuntimeStoreRejectsPrefixTraversal(t *testing.T) { _, err = store.CleanupDead() require.Error(t, err) } + +func TestRuntimeStoreRejectsRelativeDirBeforePreparing(t *testing.T) { + store := daemon.RuntimeStore{Dir: "relative-runtime"} + + _, err := store.LockPath() + require.Error(t, err) + assert.Contains(t, err.Error(), "must be absolute") + + _, statErr := os.Stat("relative-runtime") + assert.True(t, os.IsNotExist(statErr), "relative runtime dir should not be created: %v", statErr) +} + +func TestRuntimeStoreListenLockPathIsSeparateFromStartLock(t *testing.T) { + store := daemon.RuntimeStore{Dir: t.TempDir(), Prefix: "kata"} + + startLock, err := store.LockPath() + require.NoError(t, err) + listenLock, err := store.ListenLockPath() + require.NoError(t, err) + + assert.Equal(t, filepath.Join(store.Dir, "kata.lock"), startLock) + assert.Equal(t, filepath.Join(store.Dir, "kata.listen.lock"), listenLock) +} diff --git a/daemon/safefileio.go b/daemon/safefileio.go index b02674d..0bd7717 100644 --- a/daemon/safefileio.go +++ b/daemon/safefileio.go @@ -10,6 +10,10 @@ func ensurePrivateRuntimeDir(path string) error { return safefileio.EnsurePrivateDir(path) } +func validatePrivateRuntimeDir(path string) error { + return safefileio.ValidatePrivateDir(path) +} + func openRuntimeFile(path string) (*os.File, error) { return safefileio.OpenCurrentUserFile(path) } diff --git a/safefileio/private_dir_unix.go b/safefileio/private_dir_unix.go index f5a577f..f59bc96 100644 --- a/safefileio/private_dir_unix.go +++ b/safefileio/private_dir_unix.go @@ -70,3 +70,32 @@ func EnsurePrivateDir(path string) error { } return nil } + +// ValidatePrivateDir verifies path is a non-symlink directory owned by the +// current user with mode 0700. It never creates or chmods the directory. +func ValidatePrivateDir(path string) error { + if path == "" { + return fmt.Errorf("path is empty") + } + info, err := os.Lstat(path) + if err != nil { + return err + } + if info.Mode()&os.ModeSymlink != 0 { + return fmt.Errorf("%s is a symlink", path) + } + if !info.IsDir() { + return fmt.Errorf("%s is not a directory", path) + } + stat, ok := info.Sys().(*syscall.Stat_t) + if !ok { + return fmt.Errorf("stat %s: missing owner information", path) + } + if stat.Uid != uint32(os.Getuid()) { + return fmt.Errorf("%s is not owned by current user", path) + } + if info.Mode().Perm() != 0o700 { + return fmt.Errorf("%s is not mode 0700", path) + } + return nil +} diff --git a/safefileio/private_dir_unix_test.go b/safefileio/private_dir_unix_test.go index a69615d..ff9760a 100644 --- a/safefileio/private_dir_unix_test.go +++ b/safefileio/private_dir_unix_test.go @@ -27,6 +27,29 @@ func TestEnsurePrivateDirRepairsPublicDir(t *testing.T) { require.Equal(t, os.FileMode(0o700), info.Mode().Perm()) } +func TestValidatePrivateDirRejectsWithoutRepairingPublicDir(t *testing.T) { + dir := filepath.Join("/tmp", fmt.Sprintf("kit-safefileio-validate-public-%d", os.Getpid())) + t.Cleanup(func() { _ = os.RemoveAll(dir) }) + require.NoError(t, os.RemoveAll(dir)) + require.NoError(t, os.MkdirAll(dir, 0o700)) + require.NoError(t, os.Chmod(dir, 0o777)) + + require.Error(t, safefileio.ValidatePrivateDir(dir)) + + info, err := os.Stat(dir) + require.NoError(t, err) + require.Equal(t, os.FileMode(0o777), info.Mode().Perm()) +} + +func TestValidatePrivateDirAcceptsPrivateDir(t *testing.T) { + dir := filepath.Join("/tmp", fmt.Sprintf("kit-safefileio-validate-private-%d", os.Getpid())) + t.Cleanup(func() { _ = os.RemoveAll(dir) }) + require.NoError(t, os.RemoveAll(dir)) + require.NoError(t, os.MkdirAll(dir, 0o700)) + + require.NoError(t, safefileio.ValidatePrivateDir(dir)) +} + func TestEnsurePrivateDirRejectsEmptyPath(t *testing.T) { require.Error(t, safefileio.EnsurePrivateDir("")) } diff --git a/safefileio/private_dir_windows.go b/safefileio/private_dir_windows.go index a614855..60b559f 100644 --- a/safefileio/private_dir_windows.go +++ b/safefileio/private_dir_windows.go @@ -7,6 +7,7 @@ import ( "encoding/hex" "fmt" "os" + "unsafe" "golang.org/x/sys/windows" ) @@ -53,6 +54,41 @@ func EnsurePrivateDir(path string) error { return restrictWindowsDir(handle, userSID) } +// ValidatePrivateDir verifies path is a non-reparse directory owned by the +// current token user or token owner. It never creates or changes the directory. +func ValidatePrivateDir(path string) error { + if path == "" { + return fmt.Errorf("path is empty") + } + if err := rejectWindowsReparsePoint(path); err != nil { + return err + } + info, err := os.Lstat(path) + if err != nil { + return err + } + if !info.IsDir() { + return fmt.Errorf("%s is not a directory", path) + } + handle, err := openWindowsDir(path) + if err != nil { + return err + } + defer func() { _ = windows.CloseHandle(handle) }() + userSID, err := currentWindowsUserSID() + if err != nil { + return err + } + ownerSID, err := currentWindowsOwnerSID() + if err != nil { + return err + } + if err := verifyWindowsDirHandle(path, handle, userSID, ownerSID); err != nil { + return err + } + return verifyWindowsDirDACL(path, handle, userSID, ownerSID) +} + // CurrentUserID returns a stable filesystem-safe identifier for the current // Windows account. func CurrentUserID() (string, error) { @@ -137,6 +173,63 @@ func verifyWindowsDirHandle(path string, handle windows.Handle, userSID, ownerSI return nil } +func verifyWindowsDirDACL(path string, handle windows.Handle, userSID, ownerSID *windows.SID) error { + descriptor, err := windows.GetSecurityInfo( + handle, + windows.SE_FILE_OBJECT, + windows.DACL_SECURITY_INFORMATION, + ) + if err != nil { + return err + } + control, _, err := descriptor.Control() + if err != nil { + return err + } + if control&windows.SE_DACL_PROTECTED == 0 { + return fmt.Errorf("%s DACL is not protected", path) + } + dacl, _, err := descriptor.DACL() + if err != nil { + return err + } + if dacl == nil { + return fmt.Errorf("%s DACL is empty", path) + } + system, err := windows.CreateWellKnownSid(windows.WinLocalSystemSid) + if err != nil { + return err + } + admins, err := windows.CreateWellKnownSid(windows.WinBuiltinAdministratorsSid) + if err != nil { + return err + } + allowed := []*windows.SID{userSID, ownerSID, system, admins} + for i := uint16(0); i < dacl.AceCount; i++ { + var ace *windows.ACCESS_ALLOWED_ACE + if err := windows.GetAce(dacl, uint32(i), &ace); err != nil { + return err + } + if ace.Header.AceType != windows.ACCESS_ALLOWED_ACE_TYPE { + return fmt.Errorf("%s DACL contains non-allow ACE", path) + } + sid := (*windows.SID)(unsafe.Pointer(&ace.SidStart)) + if !windowsAnyOwnerMatches(sid, allowed) { + return fmt.Errorf("%s DACL grants access to unexpected principal", path) + } + } + return nil +} + +func windowsAnyOwnerMatches(owner *windows.SID, allowed []*windows.SID) bool { + for _, sid := range allowed { + if sid != nil && owner != nil && owner.Equals(sid) { + return true + } + } + return false +} + func restrictWindowsDir(handle windows.Handle, userSID *windows.SID) error { system, err := windows.CreateWellKnownSid(windows.WinLocalSystemSid) if err != nil { diff --git a/safefileio/private_dir_windows_test.go b/safefileio/private_dir_windows_test.go index 3eec675..4f1d0d4 100644 --- a/safefileio/private_dir_windows_test.go +++ b/safefileio/private_dir_windows_test.go @@ -30,6 +30,41 @@ func TestEnsurePrivateDirCreatesOwnedDirectory(t *testing.T) { require.True(t, owner.Equals(ownerSID)) } +func TestValidatePrivateDirAcceptsPrivateDir(t *testing.T) { + dir := filepath.Join(t.TempDir(), "runtime") + require.NoError(t, EnsurePrivateDir(dir)) + + require.NoError(t, ValidatePrivateDir(dir)) +} + +func TestValidatePrivateDirRejectsBroadDACL(t *testing.T) { + dir := filepath.Join(t.TempDir(), "runtime") + require.NoError(t, EnsurePrivateDir(dir)) + handle, err := openWindowsDir(dir) + require.NoError(t, err) + defer func() { _ = windows.CloseHandle(handle) }() + userSID, err := currentWindowsUserSID() + require.NoError(t, err) + world, err := windows.CreateWellKnownSid(windows.WinWorldSid) + require.NoError(t, err) + acl, err := windows.ACLFromEntries([]windows.EXPLICIT_ACCESS{ + allowFullControl(userSID, windows.TRUSTEE_IS_USER), + allowFullControl(world, windows.TRUSTEE_IS_WELL_KNOWN_GROUP), + }, nil) + require.NoError(t, err) + require.NoError(t, windows.SetSecurityInfo( + handle, + windows.SE_FILE_OBJECT, + windows.DACL_SECURITY_INFORMATION|windows.PROTECTED_DACL_SECURITY_INFORMATION, + nil, + nil, + acl, + nil, + )) + + require.Error(t, ValidatePrivateDir(dir)) +} + func TestEnsurePrivateDirRejectsEmptyPath(t *testing.T) { require.Error(t, EnsurePrivateDir("")) }