From 1070c3773a429addfd88f6e8ea5c245c695f8114 Mon Sep 17 00:00:00 2001 From: Parity Agent Date: Mon, 16 Mar 2026 10:57:19 +0000 Subject: [PATCH] feat(container): no symlink creation hooks for legacy paths --- cmd/amd-ctk/hook/hook.go | 36 +++++ cmd/amd-ctk/hook/symlinks/symlinks.go | 173 +++++++++++++++++++++ cmd/amd-ctk/hook/symlinks/symlinks_test.go | 171 ++++++++++++++++++++ cmd/amd-ctk/main.go | 6 +- internal/cdi/cdi.go | 17 +- 5 files changed, 399 insertions(+), 4 deletions(-) create mode 100644 cmd/amd-ctk/hook/hook.go create mode 100644 cmd/amd-ctk/hook/symlinks/symlinks.go create mode 100644 cmd/amd-ctk/hook/symlinks/symlinks_test.go diff --git a/cmd/amd-ctk/hook/hook.go b/cmd/amd-ctk/hook/hook.go new file mode 100644 index 0000000..71cae4a --- /dev/null +++ b/cmd/amd-ctk/hook/hook.go @@ -0,0 +1,36 @@ +/** +# Copyright (c) Advanced Micro Devices, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package hook + +import ( + "github.com/ROCm/container-toolkit/cmd/amd-ctk/hook/symlinks" + "github.com/ROCm/container-toolkit/internal/logger" + "github.com/urfave/cli/v2" +) + +// AddNewCommand creates the hook command group +func AddNewCommand() *cli.Command { + logger.Init(false) + + return &cli.Command{ + Name: "hook", + Usage: "OCI hook operations", + Subcommands: []*cli.Command{ + symlinks.NewCommand(), + }, + } +} diff --git a/cmd/amd-ctk/hook/symlinks/symlinks.go b/cmd/amd-ctk/hook/symlinks/symlinks.go new file mode 100644 index 0000000..fc0b56c --- /dev/null +++ b/cmd/amd-ctk/hook/symlinks/symlinks.go @@ -0,0 +1,173 @@ +/** +# Copyright (c) Advanced Micro Devices, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package symlinks + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/ROCm/container-toolkit/internal/logger" + "github.com/moby/sys/symlink" + "github.com/opencontainers/runtime-spec/specs-go" + "github.com/urfave/cli/v2" +) + +type command struct{} + +type config struct { + links []string + containerSpec string +} + +// NewCommand constructs the create-symlinks hook command +func NewCommand() *cli.Command { + c := command{} + return c.build() +} + +func (m command) build() *cli.Command { + cfg := config{} + + return &cli.Command{ + Name: "create-symlinks", + Usage: "Create symlinks in the container for legacy ROCm paths", + Action: func(_ context.Context, cmd *cli.Command) error { + return m.run(cmd, &cfg) + }, + Flags: []cli.Flag{ + &cli.StringSliceFlag{ + Name: "link", + Usage: "Symlink specification: target::link. Example: /opt/rocm/lib::/opt/rocm-5.7.0/lib", + Destination: &cfg.links, + }, + &cli.StringFlag{ + Name: "container-spec", + Usage: "Path to OCI container spec (for testing)", + Destination: &cfg.containerSpec, + Hidden: true, + }, + }, + } +} + +func (m command) run(_ *cli.Command, cfg *config) error { + containerRoot, err := m.getContainerRoot(cfg.containerSpec) + if err != nil { + return fmt.Errorf("failed to determine container root: %w", err) + } + + created := make(map[string]bool) + for _, l := range cfg.links { + if created[l] { + logger.Log.Printf("Link %v already processed", l) + continue + } + parts := strings.Split(l, "::") + if len(parts) != 2 { + return fmt.Errorf("invalid symlink specification %v (expected target::link)", l) + } + + if err := m.createLink(containerRoot, parts[0], parts[1]); err != nil { + return fmt.Errorf("failed to create link %v: %w", parts, err) + } + created[l] = true + } + return nil +} + +// getContainerRoot determines the container root from the OCI spec +func (m command) getContainerRoot(specPath string) (string, error) { + if specPath == "" || specPath == "-" { + specPath = "/dev/stdin" + } + + file, err := os.Open(specPath) + if err != nil { + return "", fmt.Errorf("failed to open spec: %w", err) + } + defer file.Close() + + var spec specs.Spec + if err := json.NewDecoder(file).Decode(&spec); err != nil { + return "", fmt.Errorf("failed to decode spec: %w", err) + } + + if spec.Root == nil { + return "", fmt.Errorf("spec.Root is nil") + } + + return spec.Root.Path, nil +} + +// createLink creates a symbolic link in the container root +func (m command) createLink(containerRoot, targetPath, linkPath string) error { + fullLinkPath := filepath.Join(containerRoot, linkPath) + + // Check if link already exists with correct target + exists, err := linkExists(targetPath, fullLinkPath) + if err != nil { + return fmt.Errorf("failed to check link existence: %w", err) + } + if exists { + logger.Log.Printf("Link %s already exists with correct target", fullLinkPath) + return nil + } + + // Resolve parent directory within container root + resolvedParent, err := symlink.FollowSymlinkInScope(filepath.Dir(fullLinkPath), containerRoot) + if err != nil { + return fmt.Errorf("failed to resolve link parent: %w", err) + } + resolvedLinkPath := filepath.Join(resolvedParent, filepath.Base(fullLinkPath)) + + logger.Log.Printf("Creating symlink: %s -> %s", resolvedLinkPath, targetPath) + + // Create parent directories + if err := os.MkdirAll(filepath.Dir(resolvedLinkPath), 0755); err != nil { + return fmt.Errorf("failed to create parent directory: %w", err) + } + + // Remove existing file/link if present + if err := os.Remove(resolvedLinkPath); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to remove existing path: %w", err) + } + + // Create symlink + if err := os.Symlink(targetPath, resolvedLinkPath); err != nil { + return fmt.Errorf("failed to create symlink: %w", err) + } + + return nil +} + +// linkExists checks if a symlink exists and points to the expected target +func linkExists(target, link string) (bool, error) { + currentTarget, err := os.Readlink(link) + if errors.Is(err, os.ErrNotExist) { + return false, nil + } + if err != nil { + // Not a symlink or other error + return false, nil + } + return currentTarget == target, nil +} diff --git a/cmd/amd-ctk/hook/symlinks/symlinks_test.go b/cmd/amd-ctk/hook/symlinks/symlinks_test.go new file mode 100644 index 0000000..ffc83bf --- /dev/null +++ b/cmd/amd-ctk/hook/symlinks/symlinks_test.go @@ -0,0 +1,171 @@ +package symlinks + +import ( + "os" + "path/filepath" + "testing" + + "github.com/ROCm/container-toolkit/internal/logger" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func setup(t *testing.T) { + logger.Init(true) +} + +func TestLinkExists(t *testing.T) { + setup(t) + tmpDir := t.TempDir() + + // Create a symlink + target := "target-file" + link := filepath.Join(tmpDir, "test-link") + require.NoError(t, os.Symlink(target, link)) + + // Test existing link with correct target + exists, err := linkExists(target, link) + assert.NoError(t, err) + assert.True(t, exists) + + // Test existing link with wrong target + exists, err = linkExists("different-target", link) + assert.NoError(t, err) + assert.False(t, exists) + + // Test non-existent link + exists, err = linkExists("foo", filepath.Join(tmpDir, "nonexistent")) + assert.NoError(t, err) + assert.False(t, exists) +} + +func TestCreateLink(t *testing.T) { + setup(t) + cmd := command{} + + tests := []struct { + name string + target string + linkPath string + setup func(t *testing.T, root string) + wantErr bool + validateLink func(t *testing.T, root, target, linkPath string) + }{ + { + name: "simple_relative_link", + target: "librocm.so.5", + linkPath: "/opt/rocm/lib/librocm.so", + validateLink: func(t *testing.T, root, target, linkPath string) { + fullPath := filepath.Join(root, linkPath) + resolvedTarget, err := os.Readlink(fullPath) + require.NoError(t, err) + assert.Equal(t, target, resolvedTarget) + }, + }, + { + name: "absolute_target", + target: "/opt/rocm-5.7.0/lib/libhip.so", + linkPath: "/opt/rocm/lib/libhip.so", + validateLink: func(t *testing.T, root, target, linkPath string) { + fullPath := filepath.Join(root, linkPath) + resolvedTarget, err := os.Readlink(fullPath) + require.NoError(t, err) + assert.Equal(t, target, resolvedTarget) + }, + }, + { + name: "nested_directory_creation", + target: "../lib/librocm.so", + linkPath: "/opt/rocm-5.7.0/compat/lib/librocm.so", + validateLink: func(t *testing.T, root, target, linkPath string) { + fullPath := filepath.Join(root, linkPath) + _, err := os.Stat(fullPath) + require.NoError(t, err) + }, + }, + { + name: "overwrites_existing_link", + target: "new-target", + linkPath: "/test/link", + setup: func(t *testing.T, root string) { + linkPath := filepath.Join(root, "/test/link") + require.NoError(t, os.MkdirAll(filepath.Dir(linkPath), 0755)) + require.NoError(t, os.Symlink("old-target", linkPath)) + }, + validateLink: func(t *testing.T, root, target, linkPath string) { + fullPath := filepath.Join(root, linkPath) + resolvedTarget, err := os.Readlink(fullPath) + require.NoError(t, err) + assert.Equal(t, target, resolvedTarget) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + containerRoot := t.TempDir() + + if tt.setup != nil { + tt.setup(t, containerRoot) + } + + err := cmd.createLink(containerRoot, tt.target, tt.linkPath) + + if tt.wantErr { + assert.Error(t, err) + return + } + + require.NoError(t, err) + + if tt.validateLink != nil { + tt.validateLink(t, containerRoot, tt.target, tt.linkPath) + } + }) + } +} + +func TestGetContainerRoot(t *testing.T) { + setup(t) + cmd := command{} + + tmpDir := t.TempDir() + specPath := filepath.Join(tmpDir, "config.json") + + // Create a minimal OCI spec + spec := `{ + "ociVersion": "1.0.0", + "root": { + "path": "/container/root" + } + }` + + require.NoError(t, os.WriteFile(specPath, []byte(spec), 0644)) + + root, err := cmd.getContainerRoot(specPath) + require.NoError(t, err) + assert.Equal(t, "/container/root", root) +} + +func TestCreateLinkIdempotent(t *testing.T) { + setup(t) + cmd := command{} + + containerRoot := t.TempDir() + target := "librocm.so.5" + linkPath := "/opt/rocm/lib/librocm.so" + + // Create link first time + err := cmd.createLink(containerRoot, target, linkPath) + require.NoError(t, err) + + // Create same link second time (should succeed) + err = cmd.createLink(containerRoot, target, linkPath) + require.NoError(t, err) + + // Verify link still points to correct target + fullPath := filepath.Join(containerRoot, linkPath) + resolvedTarget, err := os.Readlink(fullPath) + require.NoError(t, err) + assert.Equal(t, target, resolvedTarget) +} diff --git a/cmd/amd-ctk/main.go b/cmd/amd-ctk/main.go index 690e7f8..bec4f0c 100644 --- a/cmd/amd-ctk/main.go +++ b/cmd/amd-ctk/main.go @@ -1,14 +1,14 @@ /** # Copyright (c) Advanced Micro Devices, Inc. All rights reserved. # -# Licensed under the Apache License, Version 2.0 (the \"License\"); +# Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an \"AS IS\" BASIS, +# distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. @@ -22,6 +22,7 @@ import ( "github.com/ROCm/container-toolkit/cmd/amd-ctk/cdi" "github.com/ROCm/container-toolkit/cmd/amd-ctk/gpu-tracker" + "github.com/ROCm/container-toolkit/cmd/amd-ctk/hook" "github.com/ROCm/container-toolkit/cmd/amd-ctk/runtime" "github.com/ROCm/container-toolkit/internal/logger" "github.com/urfave/cli/v2" @@ -64,6 +65,7 @@ func main() { runtime.AddNewCommand(), cdi.AddNewCommand(), gpuTracker.AddNewCommand(), + hook.AddNewCommand(), } err := amdCtkCli.Run(os.Args) diff --git a/internal/cdi/cdi.go b/internal/cdi/cdi.go index 3319b7f..683f456 100644 --- a/internal/cdi/cdi.go +++ b/internal/cdi/cdi.go @@ -1,14 +1,14 @@ /** # Copyright (c) Advanced Micro Devices, Inc. All rights reserved. # -# Licensed under the Apache License, Version 2.0 (the \"License\"); +# Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an \"AS IS\" BASIS, +# distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. @@ -155,6 +155,19 @@ func (cdi *cdi_t) GenerateSpec() error { Name: "all", ContainerEdits: specs.ContainerEdits{ DeviceNodes: allDNs, + Hooks: []*specs.Hook{ + { + HookName: "createContainer", + Path: "/usr/local/bin/amd-ctk", + Args: []string{ + "amd-ctk", + "hook", + "create-symlinks", + "--link", "/opt/rocm/lib::/opt/rocm-5.7.0/lib", + "--link", "/opt/rocm/lib::/opt/rocm-5.6.0/lib", + }, + }, + }, }, } cdiDevs = append(cdiDevs, allCdiDev)