Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 45 additions & 1 deletion cmd/nvidia-ctk-installer/container/runtime/nri/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import (
"fmt"
"os"
"strings"
"sync/atomic"
"time"

"github.com/containerd/nri/pkg/api"
"github.com/containerd/nri/pkg/plugin"
Expand All @@ -37,6 +39,9 @@ var (
const (
// nriCDIAnnotationDomain is the domain name used for CDI device annotations
nriCDIAnnotationDomain = "nvidia.cdi.k8s.io"

// nriReconnectBackoff is the backoff time between retries when attempting to connect the NRI Plugin to the ttrpc server
nriReconnectBackoff = 2 * time.Second
)

type Plugin struct {
Expand All @@ -45,6 +50,11 @@ type Plugin struct {

namespace string
stub stub.Stub

// stopped is set before Stop() so OnClose does not reconnect during shutdown.
stopped atomic.Bool
// reconnectInProgress ensures that only one NRI plugin reconnect operation runs at any given time.
reconnectInProgress atomic.Bool
}

// NewPlugin creates a new NRI plugin for injecting CDI devices
Expand Down Expand Up @@ -119,7 +129,8 @@ func (p *Plugin) Start(ctx context.Context, nriSocketPath, nriPluginIdx string)
stub.WithPluginIdx(nriPluginIdx),
stub.WithLogger(toNriLogger{p.logger}),
stub.WithOnClose(func() {
p.logger.Infof("NRI ttrpc connection to %s is down. NRI plugin stopped.", nriSocketPath)
p.logger.Infof("NRI ttrpc connection to %s is down; attempting to reconnect...", nriSocketPath)
p.scheduleReconnect(nriSocketPath)
}),
}
if len(nriSocketPath) > 0 {
Expand All @@ -141,10 +152,43 @@ func (p *Plugin) Start(ctx context.Context, nriSocketPath, nriPluginIdx string)
return nil
}

// scheduleReconnect runs stub.Start in a loop until success, shutdown, or context cancellation.
func (p *Plugin) scheduleReconnect(nriSocketPath string) {
if !p.reconnectInProgress.CompareAndSwap(false, true) {
return
}
go func() {
defer p.reconnectInProgress.Store(false)
for i := 1; ; i++ {
if p.stopped.Load() {
p.logger.Infof("NRI plugin stopped. Stopping all reconnect attempts...")
return
}
select {
case <-p.ctx.Done():
return
case <-time.After(nriReconnectBackoff):
}
p.logger.Infof("NRI plugin reconnecting to %s (attempt %d)...", nriSocketPath, i)
if err := p.stub.Start(p.ctx); err != nil {
p.logger.Warningf("NRI plugin reconnect failed: %v", err)
if p.stopped.Load() {
p.logger.Infof("NRI plugin stopped. Stopping all reconnect attempts...")
return
}
continue
}
p.logger.Infof("NRI plugin reconnected to %s", nriSocketPath)
return
}
}()
}

// Stop stops the NRI plugin
func (p *Plugin) Stop() {
if p == nil || p.stub == nil {
return
}
p.stopped.Store(true)
p.stub.Stop()
}
Loading