Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions cmd/driver/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ func main() {
"Mount path for the supervisor binary volume in the agent container")
flag.StringVar(&cfg.GatewayEndpoint, "gateway-endpoint", cfg.GatewayEndpoint,
"Gateway gRPC endpoint for supervisor callback (OPENSHELL_ENDPOINT)")
flag.StringVar(&cfg.TLSCASecret, "tls-ca-secret", cfg.TLSCASecret,
"Secret name containing ca.crt for sandbox TLS verification (OPENSHELL_TLS_CA)")
flag.StringVar(&cfg.TLSClientSecret, "tls-client-secret", cfg.TLSClientSecret,
"Secret name containing tls.crt and tls.key for sandbox mTLS client auth")
flag.Parse()

if cfg.Tenant == "" {
Expand Down
2 changes: 2 additions & 0 deletions internal/driver/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ type Config struct {
DtachBinaryPath string
SupervisorMountPath string
GatewayEndpoint string
TLSCASecret string // Secret name containing ca.crt for gateway TLS verification
TLSClientSecret string // Secret name containing tls.crt and tls.key for mTLS client auth
}

func DefaultConfig() Config {
Expand Down
124 changes: 94 additions & 30 deletions internal/driver/provisioner.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@ var sandboxGVR = schema.GroupVersionResource{
}

const (
labelSandboxID = "openshell.ai/sandbox-id"
labelManagedBy = "openshell.ai/managed-by"
labelKagenti = "kagenti.io/type"
labelTenant = "openshell.ai/tenant"
labelKagentiTeam = "kagenti.io/team"
labelSandboxID = "openshell.ai/sandbox-id"
labelManagedBy = "openshell.ai/managed-by"
labelKagenti = "kagenti.io/type"
labelKagentiInject = "kagenti.io/inject"
labelTenant = "openshell.ai/tenant"
labelKagentiTeam = "kagenti.io/team"
)

// K8sProvisioner implements SandboxProvisioner using the Kubernetes API. It
Expand Down Expand Up @@ -74,9 +75,10 @@ func (p *K8sProvisioner) Create(ctx context.Context, sb *pb.DriverSandbox) error
tmpl := spec.GetTemplate()

labels := mergeMaps(tmpl.GetLabels(), map[string]string{
labelSandboxID: sb.GetId(),
labelManagedBy: "openshell",
labelKagenti: "agent",
labelSandboxID: sb.GetId(),
labelManagedBy: "openshell",
labelKagenti: "agent",
labelKagentiInject: "disabled",
})
if p.cfg.Tenant != "" {
labels[labelTenant] = p.cfg.Tenant
Expand Down Expand Up @@ -248,6 +250,28 @@ func (p *K8sProvisioner) buildSandboxSpec(sb *pb.DriverSandbox) map[string]inter
}

// Agent container runs the supervisor and mounts it read-only.
agentVolumeMounts := []interface{}{
map[string]interface{}{
"name": "supervisor-bin",
"mountPath": p.cfg.SupervisorMountPath,
"readOnly": true,
},
}
if p.cfg.TLSCASecret != "" {
agentVolumeMounts = append(agentVolumeMounts, map[string]interface{}{
"name": "tls-ca",
"mountPath": "/tls/ca",
"readOnly": true,
})
}
if p.cfg.TLSClientSecret != "" {
agentVolumeMounts = append(agentVolumeMounts, map[string]interface{}{
"name": "tls-client",
"mountPath": "/tls/client",
"readOnly": true,
})
}

container := map[string]interface{}{
"name": "agent",
"image": tmpl.GetImage(),
Expand All @@ -260,29 +284,47 @@ func (p *K8sProvisioner) buildSandboxSpec(sb *pb.DriverSandbox) map[string]inter
"add": []interface{}{"SYS_ADMIN", "NET_ADMIN", "SYS_PTRACE", "SYSLOG"},
},
},
"volumeMounts": []interface{}{
map[string]interface{}{
"name": "supervisor-bin",
"mountPath": p.cfg.SupervisorMountPath,
"readOnly": true,
},
},
"volumeMounts": agentVolumeMounts,
}

if res := tmpl.GetResources(); res != nil {
container["resources"] = buildResources(res, spec.GetGpu())
}

volumes := []interface{}{
map[string]interface{}{
"name": "supervisor-bin",
"emptyDir": map[string]interface{}{},
},
}
if p.cfg.TLSCASecret != "" {
volumes = append(volumes, map[string]interface{}{
"name": "tls-ca",
"secret": map[string]interface{}{
"secretName": p.cfg.TLSCASecret,
"items": []interface{}{
map[string]interface{}{
"key": "ca.crt",
"path": "ca.crt",
},
},
},
})
}
if p.cfg.TLSClientSecret != "" {
volumes = append(volumes, map[string]interface{}{
"name": "tls-client",
"secret": map[string]interface{}{
"secretName": p.cfg.TLSClientSecret,
},
})
}

podSpec := map[string]interface{}{
"initContainers": []interface{}{initContainer},
"containers": []interface{}{container},
"initContainers": []interface{}{initContainer},
"containers": []interface{}{container},
"serviceAccountName": "openshell-sandbox",
"volumes": []interface{}{
map[string]interface{}{
"name": "supervisor-bin",
"emptyDir": map[string]interface{}{},
},
},
"volumes": volumes,
}

// Apply platform_config passthrough fields.
Expand All @@ -294,9 +336,10 @@ func (p *K8sProvisioner) buildSandboxSpec(sb *pb.DriverSandbox) map[string]inter
}

podLabels := mergeMaps(tmpl.GetLabels(), map[string]string{
labelSandboxID: sb.GetId(),
labelManagedBy: "openshell",
labelKagenti: "agent",
labelSandboxID: sb.GetId(),
labelManagedBy: "openshell",
labelKagenti: "agent",
labelKagentiInject: "disabled",
})
if p.cfg.Tenant != "" {
podLabels[labelTenant] = p.cfg.Tenant
Expand All @@ -321,16 +364,26 @@ func (p *K8sProvisioner) buildFullEnvList(
envList := buildEnvList(spec.GetEnvironment(), tmpl.GetEnvironment())

gatewayEnv := map[string]string{
"OPENSHELL_SANDBOX_ID": sb.GetId(),
"OPENSHELL_SANDBOX": sb.GetName(),
"OPENSHELL_SANDBOX_COMMAND": "sleep infinity",
"OPENSHELL_SANDBOX_ID": sb.GetId(),
"OPENSHELL_SANDBOX": sb.GetName(),
"OPENSHELL_SANDBOX_COMMAND": "sleep infinity",
"OPENSHELL_SSH_SOCKET_PATH": "/tmp/openshell-ssh.sock",
}
if p.cfg.GatewayEndpoint != "" {
gatewayEnv["OPENSHELL_ENDPOINT"] = p.cfg.GatewayEndpoint
}
if p.cfg.TLSCASecret != "" {
gatewayEnv["OPENSHELL_TLS_CA"] = "/tls/ca/ca.crt"
}
if p.cfg.TLSClientSecret != "" {
gatewayEnv["OPENSHELL_TLS_CERT"] = "/tls/client/tls.crt"
gatewayEnv["OPENSHELL_TLS_KEY"] = "/tls/client/tls.key"
}

gatewayEnv["ANTHROPIC_BASE_URL"] = "https://inference.local/v1"
gatewayEnv["OPENSHELL_LOG_LEVEL"] = "debug"
gatewayEnv["ANTHROPIC_BASE_URL"] = "https://inference.local"
gatewayEnv["OPENAI_BASE_URL"] = "https://inference.local/v1"
gatewayEnv["CLAUDE_CODE_DISABLE_EXPERIMENTAL_BETAS"] = "1"

for k, v := range gatewayEnv {
envList = append(envList, map[string]interface{}{
Expand All @@ -339,5 +392,16 @@ func (p *K8sProvisioner) buildFullEnvList(
})
}

// SSH handshake secret sourced from the gateway secrets
envList = append(envList, map[string]interface{}{
"name": "OPENSHELL_SSH_HANDSHAKE_SECRET",
"valueFrom": map[string]interface{}{
"secretKeyRef": map[string]interface{}{
"name": "openshell-gateway-secrets",
"key": "ssh-handshake-secret",
},
},
})

return envList
}
Loading