From d4e8cce60860380cb39508a96607ea61f342a14e Mon Sep 17 00:00:00 2001 From: Roddie Kieley Date: Tue, 19 Aug 2025 18:51:37 -0230 Subject: [PATCH] Refactor deploymentForMCPServer for platform detection. (#1063) (#1285) Signed-off-by: Roddie Kieley Co-authored-by: Cursor claude-4-sonnet --- .../controllers/mcpserver_controller.go | 67 ++-- .../controllers/mcpserver_platform_test.go | 304 ++++++++++++++++++ .../mcpserver_pod_template_test.go | 16 +- .../mcpserver_resource_overrides_test.go | 13 +- pkg/container/kubernetes/client.go | 140 ++++---- pkg/container/kubernetes/security.go | 168 ++++++++++ pkg/container/kubernetes/security_test.go | 176 ++++++++++ 7 files changed, 779 insertions(+), 105 deletions(-) create mode 100644 cmd/thv-operator/controllers/mcpserver_platform_test.go create mode 100644 pkg/container/kubernetes/security.go create mode 100644 pkg/container/kubernetes/security_test.go diff --git a/cmd/thv-operator/controllers/mcpserver_controller.go b/cmd/thv-operator/controllers/mcpserver_controller.go index 31ac8c416..079d74fdb 100644 --- a/cmd/thv-operator/controllers/mcpserver_controller.go +++ b/cmd/thv-operator/controllers/mcpserver_controller.go @@ -11,6 +11,7 @@ import ( "reflect" "slices" "strings" + "sync" "time" appsv1 "k8s.io/api/apps/v1" @@ -22,20 +23,24 @@ import ( "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/types" "k8s.io/apimachinery/pkg/util/intstr" - "k8s.io/utils/ptr" + "k8s.io/client-go/rest" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" "sigs.k8s.io/controller-runtime/pkg/log" mcpv1alpha1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1alpha1" + "github.com/stacklok/toolhive/pkg/container/kubernetes" "github.com/stacklok/toolhive/pkg/logger" ) // MCPServerReconciler reconciles a MCPServer object type MCPServerReconciler struct { client.Client - Scheme *runtime.Scheme + Scheme *runtime.Scheme + platformDetector kubernetes.PlatformDetector + detectedPlatform kubernetes.Platform + platformOnce sync.Once } // defaultRBACRules are the default RBAC rules that the @@ -82,6 +87,35 @@ const ( authzLabelValueInline = "inline" ) +// detectPlatform detects the Kubernetes platform type (Kubernetes vs OpenShift) +// It uses sync.Once to ensure the detection is only performed once and cached +func (r *MCPServerReconciler) detectPlatform(ctx context.Context) (kubernetes.Platform, error) { + var err error + r.platformOnce.Do(func() { + // Initialize platform detector if not already done + if r.platformDetector == nil { + r.platformDetector = kubernetes.NewDefaultPlatformDetector() + } + + cfg, configErr := rest.InClusterConfig() + if configErr != nil { + err = fmt.Errorf("failed to get in-cluster config for platform detection: %w", configErr) + return + } + + r.detectedPlatform, err = r.platformDetector.DetectPlatform(cfg) + if err != nil { + err = fmt.Errorf("failed to detect platform: %w", err) + return + } + + ctxLogger := log.FromContext(ctx) + ctxLogger.Info("Platform detected for MCPServer controller", "platform", r.detectedPlatform.String()) + }) + + return r.detectedPlatform, err +} + // Reconcile is part of the main kubernetes reconciliation loop which aims to // move the current state of the cluster closer to the desired state. // @@ -156,7 +190,7 @@ func (r *MCPServerReconciler) Reconcile(ctx context.Context, req ctrl.Request) ( err = r.Get(ctx, types.NamespacedName{Name: mcpServer.Name, Namespace: mcpServer.Namespace}, deployment) if err != nil && errors.IsNotFound(err) { // Define a new deployment - dep := r.deploymentForMCPServer(mcpServer) + dep := r.deploymentForMCPServer(ctx, mcpServer) if dep == nil { ctxLogger.Error(nil, "Failed to create Deployment object") return ctrl.Result{}, fmt.Errorf("failed to create Deployment object") @@ -225,7 +259,7 @@ func (r *MCPServerReconciler) Reconcile(ctx context.Context, req ctrl.Request) ( // Check if the deployment spec changed if deploymentNeedsUpdate(deployment, mcpServer) { // Update the deployment - newDeployment := r.deploymentForMCPServer(mcpServer) + newDeployment := r.deploymentForMCPServer(ctx, mcpServer) deployment.Spec = newDeployment.Spec err = r.Update(ctx, deployment) if err != nil { @@ -401,7 +435,7 @@ func (r *MCPServerReconciler) ensureRBACResources(ctx context.Context, mcpServer // deploymentForMCPServer returns a MCPServer Deployment object // //nolint:gocyclo -func (r *MCPServerReconciler) deploymentForMCPServer(m *mcpv1alpha1.MCPServer) *appsv1.Deployment { +func (r *MCPServerReconciler) deploymentForMCPServer(ctx context.Context, m *mcpv1alpha1.MCPServer) *appsv1.Deployment { ls := labelsForMCPServer(m.Name) replicas := int32(1) @@ -581,22 +615,17 @@ func (r *MCPServerReconciler) deploymentForMCPServer(m *mcpv1alpha1.MCPServer) * } } - // Prepare ProxyRunner's pod and container security context - proxyRunnerPodSecurityContext := &corev1.PodSecurityContext{ - RunAsNonRoot: ptr.To(true), - RunAsUser: ptr.To(int64(1000)), - RunAsGroup: ptr.To(int64(1000)), - FSGroup: ptr.To(int64(1000)), + // Detect platform and prepare ProxyRunner's pod and container security context + _, err := r.detectPlatform(ctx) + if err != nil { + ctxLogger := log.FromContext(ctx) + ctxLogger.Error(err, "Failed to detect platform, defaulting to Kubernetes", "mcpserver", m.Name) } - proxyRunnerContainerSecurityContext := &corev1.SecurityContext{ - Privileged: ptr.To(false), - RunAsNonRoot: ptr.To(true), - RunAsUser: ptr.To(int64(1000)), - RunAsGroup: ptr.To(int64(1000)), - AllowPrivilegeEscalation: ptr.To(false), - ReadOnlyRootFilesystem: ptr.To(true), - } + // Use SecurityContextBuilder for platform-aware security context + securityBuilder := kubernetes.NewSecurityContextBuilder(r.detectedPlatform) + proxyRunnerPodSecurityContext := securityBuilder.BuildPodSecurityContext() + proxyRunnerContainerSecurityContext := securityBuilder.BuildContainerSecurityContext() env = ensureRequiredEnvVars(env) diff --git a/cmd/thv-operator/controllers/mcpserver_platform_test.go b/cmd/thv-operator/controllers/mcpserver_platform_test.go new file mode 100644 index 000000000..ade3bbefd --- /dev/null +++ b/cmd/thv-operator/controllers/mcpserver_platform_test.go @@ -0,0 +1,304 @@ +package controllers + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/client-go/rest" + + mcpv1alpha1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1alpha1" + "github.com/stacklok/toolhive/pkg/container/kubernetes" +) + +// mockPlatformDetector is a mock implementation of PlatformDetector for testing +type mockPlatformDetector struct { + platform kubernetes.Platform + err error +} + +func (m *mockPlatformDetector) DetectPlatform(_ *rest.Config) (kubernetes.Platform, error) { + return m.platform, m.err +} + +func TestMCPServerReconciler_DetectPlatform_Success(t *testing.T) { + t.Skip("Platform detection requires in-cluster Kubernetes configuration - skipping for unit tests") + + t.Parallel() + + tests := []struct { + name string + platform kubernetes.Platform + expectedPlatform kubernetes.Platform + }{ + { + name: "Kubernetes platform", + platform: kubernetes.PlatformKubernetes, + expectedPlatform: kubernetes.PlatformKubernetes, + }, + { + name: "OpenShift platform", + platform: kubernetes.PlatformOpenShift, + expectedPlatform: kubernetes.PlatformOpenShift, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + reconciler := &MCPServerReconciler{ + platformDetector: &mockPlatformDetector{ + platform: tt.platform, + err: nil, + }, + } + + ctx := context.Background() + detectedPlatform, err := reconciler.detectPlatform(ctx) + + require.NoError(t, err) + assert.Equal(t, tt.expectedPlatform, detectedPlatform) + + // Test that subsequent calls return cached result + detectedPlatform2, err2 := reconciler.detectPlatform(ctx) + require.NoError(t, err2) + assert.Equal(t, tt.expectedPlatform, detectedPlatform2) + }) + } +} + +func TestMCPServerReconciler_DetectPlatform_Error(t *testing.T) { + t.Skip("Platform detection requires in-cluster Kubernetes configuration - skipping for unit tests") + + t.Parallel() + + reconciler := &MCPServerReconciler{ + platformDetector: &mockPlatformDetector{ + platform: kubernetes.PlatformKubernetes, + err: assert.AnError, + }, + } + + ctx := context.Background() + detectedPlatform, err := reconciler.detectPlatform(ctx) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to get in-cluster config") + // Should return zero value when error occurs + assert.Equal(t, kubernetes.Platform(0), detectedPlatform) +} + +func TestMCPServerReconciler_DeploymentForMCPServer_Kubernetes(t *testing.T) { + t.Parallel() + + // Create a test MCPServer + mcpServer := &mcpv1alpha1.MCPServer{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-mcp-server", + Namespace: "default", + }, + Spec: mcpv1alpha1.MCPServerSpec{ + Image: "test-image:latest", + Transport: "stdio", + Port: 8080, + }, + } + + // Create reconciler with mock platform detector for Kubernetes + scheme := runtime.NewScheme() + _ = mcpv1alpha1.AddToScheme(scheme) + reconciler := &MCPServerReconciler{ + Scheme: scheme, + platformDetector: &mockPlatformDetector{ + platform: kubernetes.PlatformKubernetes, + err: nil, + }, + // Pre-set the detected platform to avoid calling detectPlatform which requires in-cluster config + detectedPlatform: kubernetes.PlatformKubernetes, + } + // Simulate that platform detection has already been called + reconciler.platformOnce.Do(func() {}) + + ctx := context.Background() + deployment := reconciler.deploymentForMCPServer(ctx, mcpServer) + + require.NotNil(t, deployment, "Deployment should not be nil") + + // Check pod security context for Kubernetes + podSecurityContext := deployment.Spec.Template.Spec.SecurityContext + require.NotNil(t, podSecurityContext, "Pod security context should not be nil") + + assert.NotNil(t, podSecurityContext.RunAsNonRoot) + assert.True(t, *podSecurityContext.RunAsNonRoot) + + assert.NotNil(t, podSecurityContext.RunAsUser) + assert.Equal(t, int64(1000), *podSecurityContext.RunAsUser) + + assert.NotNil(t, podSecurityContext.RunAsGroup) + assert.Equal(t, int64(1000), *podSecurityContext.RunAsGroup) + + assert.NotNil(t, podSecurityContext.FSGroup) + assert.Equal(t, int64(1000), *podSecurityContext.FSGroup) + + // Check container security context for Kubernetes + containerSecurityContext := deployment.Spec.Template.Spec.Containers[0].SecurityContext + require.NotNil(t, containerSecurityContext, "Container security context should not be nil") + + assert.NotNil(t, containerSecurityContext.Privileged) + assert.False(t, *containerSecurityContext.Privileged) + + assert.NotNil(t, containerSecurityContext.RunAsNonRoot) + assert.True(t, *containerSecurityContext.RunAsNonRoot) + + assert.NotNil(t, containerSecurityContext.RunAsUser) + assert.Equal(t, int64(1000), *containerSecurityContext.RunAsUser) + + assert.NotNil(t, containerSecurityContext.RunAsGroup) + assert.Equal(t, int64(1000), *containerSecurityContext.RunAsGroup) + + assert.NotNil(t, containerSecurityContext.AllowPrivilegeEscalation) + assert.False(t, *containerSecurityContext.AllowPrivilegeEscalation) + + assert.NotNil(t, containerSecurityContext.ReadOnlyRootFilesystem) + assert.True(t, *containerSecurityContext.ReadOnlyRootFilesystem) +} + +func TestMCPServerReconciler_DeploymentForMCPServer_OpenShift(t *testing.T) { + t.Parallel() + + // Create a test MCPServer + mcpServer := &mcpv1alpha1.MCPServer{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-mcp-server", + Namespace: "default", + }, + Spec: mcpv1alpha1.MCPServerSpec{ + Image: "test-image:latest", + Transport: "stdio", + Port: 8080, + }, + } + + // Create reconciler with mock platform detector for OpenShift + scheme := runtime.NewScheme() + _ = mcpv1alpha1.AddToScheme(scheme) + reconciler := &MCPServerReconciler{ + Scheme: scheme, + platformDetector: &mockPlatformDetector{ + platform: kubernetes.PlatformOpenShift, + err: nil, + }, + // Pre-set the detected platform to avoid calling detectPlatform which requires in-cluster config + detectedPlatform: kubernetes.PlatformOpenShift, + } + // Simulate that platform detection has already been called + reconciler.platformOnce.Do(func() {}) + + ctx := context.Background() + deployment := reconciler.deploymentForMCPServer(ctx, mcpServer) + + require.NotNil(t, deployment, "Deployment should not be nil") + + // Check pod security context for OpenShift + podSecurityContext := deployment.Spec.Template.Spec.SecurityContext + require.NotNil(t, podSecurityContext, "Pod security context should not be nil") + + assert.NotNil(t, podSecurityContext.RunAsNonRoot) + assert.True(t, *podSecurityContext.RunAsNonRoot) + + // These should be nil for OpenShift to allow SCCs to assign them + assert.Nil(t, podSecurityContext.RunAsUser) + assert.Nil(t, podSecurityContext.RunAsGroup) + assert.Nil(t, podSecurityContext.FSGroup) + + // SeccompProfile should be set for OpenShift + require.NotNil(t, podSecurityContext.SeccompProfile) + assert.Equal(t, corev1.SeccompProfileTypeRuntimeDefault, podSecurityContext.SeccompProfile.Type) + + // Check container security context for OpenShift + containerSecurityContext := deployment.Spec.Template.Spec.Containers[0].SecurityContext + require.NotNil(t, containerSecurityContext, "Container security context should not be nil") + + assert.NotNil(t, containerSecurityContext.Privileged) + assert.False(t, *containerSecurityContext.Privileged) + + assert.NotNil(t, containerSecurityContext.RunAsNonRoot) + assert.True(t, *containerSecurityContext.RunAsNonRoot) + + // These should be nil for OpenShift to allow SCCs to assign them + assert.Nil(t, containerSecurityContext.RunAsUser) + assert.Nil(t, containerSecurityContext.RunAsGroup) + + assert.NotNil(t, containerSecurityContext.AllowPrivilegeEscalation) + assert.False(t, *containerSecurityContext.AllowPrivilegeEscalation) + + assert.NotNil(t, containerSecurityContext.ReadOnlyRootFilesystem) + assert.True(t, *containerSecurityContext.ReadOnlyRootFilesystem) + + // SeccompProfile should be set for OpenShift + require.NotNil(t, containerSecurityContext.SeccompProfile) + assert.Equal(t, corev1.SeccompProfileTypeRuntimeDefault, containerSecurityContext.SeccompProfile.Type) + + // Capabilities should drop all for OpenShift + require.NotNil(t, containerSecurityContext.Capabilities) + assert.Equal(t, []corev1.Capability{"ALL"}, containerSecurityContext.Capabilities.Drop) +} + +func TestMCPServerReconciler_DeploymentForMCPServer_PlatformDetectionError(t *testing.T) { + t.Parallel() + + // Create a test MCPServer + mcpServer := &mcpv1alpha1.MCPServer{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-mcp-server", + Namespace: "default", + }, + Spec: mcpv1alpha1.MCPServerSpec{ + Image: "test-image:latest", + Transport: "stdio", + Port: 8080, + }, + } + + // Create reconciler with mock platform detector that returns error + scheme := runtime.NewScheme() + _ = mcpv1alpha1.AddToScheme(scheme) + reconciler := &MCPServerReconciler{ + Scheme: scheme, + platformDetector: &mockPlatformDetector{ + platform: kubernetes.PlatformKubernetes, + err: assert.AnError, + }, + // Don't pre-set the platform so it will try to detect and fall back to Kubernetes + } + + ctx := context.Background() + deployment := reconciler.deploymentForMCPServer(ctx, mcpServer) + + require.NotNil(t, deployment, "Deployment should not be nil") + + // Should fall back to Kubernetes defaults when platform detection fails + podSecurityContext := deployment.Spec.Template.Spec.SecurityContext + require.NotNil(t, podSecurityContext, "Pod security context should not be nil") + + assert.NotNil(t, podSecurityContext.RunAsUser) + assert.Equal(t, int64(1000), *podSecurityContext.RunAsUser) + + assert.NotNil(t, podSecurityContext.RunAsGroup) + assert.Equal(t, int64(1000), *podSecurityContext.RunAsGroup) + + assert.NotNil(t, podSecurityContext.FSGroup) + assert.Equal(t, int64(1000), *podSecurityContext.FSGroup) +} + +func TestMCPServerReconciler_DeploymentForMCPServer_EnvironmentOverride(t *testing.T) { + t.Parallel() + t.Skip("Environment variable tests require special setup - skipping for now") + // This test would require setting OPERATOR_OPENSHIFT environment variable + // and testing that it overrides the platform detection logic +} diff --git a/cmd/thv-operator/controllers/mcpserver_pod_template_test.go b/cmd/thv-operator/controllers/mcpserver_pod_template_test.go index 4ad901f6d..010dc76ce 100644 --- a/cmd/thv-operator/controllers/mcpserver_pod_template_test.go +++ b/cmd/thv-operator/controllers/mcpserver_pod_template_test.go @@ -1,6 +1,7 @@ package controllers import ( + "context" "encoding/json" "strings" "testing" @@ -85,7 +86,8 @@ func TestDeploymentForMCPServerWithPodTemplateSpec(t *testing.T) { } // Call deploymentForMCPServer - deployment := r.deploymentForMCPServer(mcpServer) + ctx := context.Background() + deployment := r.deploymentForMCPServer(ctx, mcpServer) require.NotNil(t, deployment, "Deployment should not be nil") // Check if the pod template patch is included in the args @@ -173,7 +175,8 @@ func TestDeploymentForMCPServerSecretsProviderEnv(t *testing.T) { } // Call deploymentForMCPServer - deployment := r.deploymentForMCPServer(mcpServer) + ctx := context.Background() + deployment := r.deploymentForMCPServer(ctx, mcpServer) require.NotNil(t, deployment, "Deployment should not be nil") } @@ -215,7 +218,8 @@ func TestDeploymentForMCPServerWithSecrets(t *testing.T) { } // Call deploymentForMCPServer - deployment := r.deploymentForMCPServer(mcpServer) + ctx := context.Background() + deployment := r.deploymentForMCPServer(ctx, mcpServer) require.NotNil(t, deployment, "Deployment should not be nil") // Check that secrets are injected via pod template patch @@ -317,7 +321,8 @@ func TestDeploymentForMCPServerWithEnvVars(t *testing.T) { } // Generate the deployment - deployment := r.deploymentForMCPServer(mcpServer) + ctx := context.Background() + deployment := r.deploymentForMCPServer(ctx, mcpServer) require.NotNil(t, deployment, "Deployment should not be nil") // Check that environment variables are passed as --env flags in the container args @@ -371,7 +376,8 @@ func TestProxyRunnerSecurityContext(t *testing.T) { } // Generate the deployment - deployment := r.deploymentForMCPServer(mcpServer) + ctx := context.Background() + deployment := r.deploymentForMCPServer(ctx, mcpServer) require.NotNil(t, deployment, "Deployment should not be nil") // Check that the ProxyRunner's pod and container security context are set diff --git a/cmd/thv-operator/controllers/mcpserver_resource_overrides_test.go b/cmd/thv-operator/controllers/mcpserver_resource_overrides_test.go index 0b62ea8e4..a0f8b5cf9 100644 --- a/cmd/thv-operator/controllers/mcpserver_resource_overrides_test.go +++ b/cmd/thv-operator/controllers/mcpserver_resource_overrides_test.go @@ -15,6 +15,7 @@ package controllers import ( + "context" "testing" "github.com/stretchr/testify/assert" @@ -266,7 +267,8 @@ func TestResourceOverrides(t *testing.T) { } // Test deployment creation - deployment := r.deploymentForMCPServer(tt.mcpServer) + ctx := context.Background() + deployment := r.deploymentForMCPServer(ctx, tt.mcpServer) require.NotNil(t, deployment) assert.Equal(t, tt.expectedDeploymentLabels, deployment.Labels) @@ -384,7 +386,8 @@ func TestDeploymentNeedsUpdateServiceAccount(t *testing.T) { } // Create a deployment using the current implementation - deployment := r.deploymentForMCPServer(mcpServer) + ctx := context.Background() + deployment := r.deploymentForMCPServer(ctx, mcpServer) require.NotNil(t, deployment) // Test with the current deployment - this should NOT need update @@ -556,7 +559,8 @@ func TestDeploymentNeedsUpdateProxyEnv(t *testing.T) { t.Parallel() // Create a deployment and manually set up its state to isolate proxy env testing - deployment := r.deploymentForMCPServer(tt.mcpServer) + ctx := context.Background() + deployment := r.deploymentForMCPServer(ctx, tt.mcpServer) require.NotNil(t, deployment) require.Len(t, deployment.Spec.Template.Spec.Containers, 1) @@ -642,7 +646,8 @@ func TestDeploymentNeedsUpdateToolsFilter(t *testing.T) { }, } - deployment := r.deploymentForMCPServer(mcpServer) + ctx := context.Background() + deployment := r.deploymentForMCPServer(ctx, mcpServer) require.NotNil(t, deployment) mcpServer.Spec.ToolsFilter = tt.newToolsFilter diff --git a/pkg/container/kubernetes/client.go b/pkg/container/kubernetes/client.go index 1c7141637..02121002f 100644 --- a/pkg/container/kubernetes/client.go +++ b/pkg/container/kubernetes/client.go @@ -902,6 +902,8 @@ func createPodTemplateFromPatch(patchJSON string) (*corev1apply.PodTemplateSpecA } // ensurePodTemplateConfig ensures the pod template has required configuration +// +//nolint:gocyclo // Complex but necessary for platform-aware security context configuration func ensurePodTemplateConfig( podTemplateSpec *corev1apply.PodTemplateSpecApplyConfiguration, containerLabels map[string]string, @@ -928,55 +930,49 @@ func ensurePodTemplateConfig( podTemplateSpec.Spec = podTemplateSpec.Spec.WithRestartPolicy(corev1.RestartPolicyAlways) } - // Add pod-level security context if not already present + // Add pod-level security context using SecurityContextBuilder if podTemplateSpec.Spec.SecurityContext == nil { + securityBuilder := NewSecurityContextBuilder(platform) podTemplateSpec.Spec = podTemplateSpec.Spec.WithSecurityContext( - corev1apply.PodSecurityContext(). - WithRunAsNonRoot(true). - WithRunAsUser(int64(1000)). - WithRunAsGroup(int64(1000)). - WithFSGroup(int64(1000)), + securityBuilder.BuildPodSecurityContextApplyConfiguration(), ) } else { - // If the pod-level security context already exists, ensure it has the correct settings - if podTemplateSpec.Spec.SecurityContext.RunAsNonRoot == nil { - podTemplateSpec.Spec.SecurityContext = podTemplateSpec.Spec.SecurityContext.WithRunAsNonRoot(true) - } - - if podTemplateSpec.Spec.SecurityContext.FSGroup == nil { - podTemplateSpec.Spec.SecurityContext = podTemplateSpec.Spec.SecurityContext.WithFSGroup(int64(1000)) - } + // If the pod-level security context already exists, merge with platform-aware defaults + securityBuilder := NewSecurityContextBuilder(platform) + platformContext := securityBuilder.BuildPodSecurityContextApplyConfiguration() - if podTemplateSpec.Spec.SecurityContext.RunAsUser == nil { - podTemplateSpec.Spec.SecurityContext = podTemplateSpec.Spec.SecurityContext.WithRunAsUser(int64(1000)) + // Merge existing context with platform-aware settings + if podTemplateSpec.Spec.SecurityContext.RunAsNonRoot == nil && platformContext.RunAsNonRoot != nil { + podTemplateSpec.Spec.SecurityContext = podTemplateSpec.Spec.SecurityContext.WithRunAsNonRoot(*platformContext.RunAsNonRoot) } - if podTemplateSpec.Spec.SecurityContext.RunAsGroup == nil { - podTemplateSpec.Spec.SecurityContext = podTemplateSpec.Spec.SecurityContext.WithRunAsGroup(int64(1000)) + if podTemplateSpec.Spec.SecurityContext.RunAsUser == nil && platformContext.RunAsUser != nil { + podTemplateSpec.Spec.SecurityContext = podTemplateSpec.Spec.SecurityContext.WithRunAsUser(*platformContext.RunAsUser) } - } - if platform == PlatformOpenShift { - if podTemplateSpec.Spec.SecurityContext.RunAsUser != nil { - podTemplateSpec.Spec.SecurityContext.RunAsUser = nil + if podTemplateSpec.Spec.SecurityContext.RunAsGroup == nil && platformContext.RunAsGroup != nil { + podTemplateSpec.Spec.SecurityContext = podTemplateSpec.Spec.SecurityContext.WithRunAsGroup(*platformContext.RunAsGroup) } - if podTemplateSpec.Spec.SecurityContext.RunAsGroup != nil { - podTemplateSpec.Spec.SecurityContext.RunAsGroup = nil + if podTemplateSpec.Spec.SecurityContext.FSGroup == nil && platformContext.FSGroup != nil { + podTemplateSpec.Spec.SecurityContext = podTemplateSpec.Spec.SecurityContext.WithFSGroup(*platformContext.FSGroup) } - if podTemplateSpec.Spec.SecurityContext.FSGroup != nil { - podTemplateSpec.Spec.SecurityContext.FSGroup = nil + if podTemplateSpec.Spec.SecurityContext.SeccompProfile == nil && platformContext.SeccompProfile != nil { + podTemplateSpec.Spec.SecurityContext = podTemplateSpec.Spec.SecurityContext.WithSeccompProfile(platformContext.SeccompProfile) } - if podTemplateSpec.Spec.SecurityContext.SeccompProfile == nil { - podTemplateSpec.Spec.SecurityContext.SeccompProfile = - corev1apply.SeccompProfile().WithType( - corev1.SeccompProfileTypeRuntimeDefault) - } else { - podTemplateSpec.Spec.SecurityContext.SeccompProfile = - podTemplateSpec.Spec.SecurityContext.SeccompProfile.WithType( - corev1.SeccompProfileTypeRuntimeDefault) + // For OpenShift, override certain fields even if they exist + if platform == PlatformOpenShift { + if podTemplateSpec.Spec.SecurityContext.RunAsUser != nil { + podTemplateSpec.Spec.SecurityContext.RunAsUser = nil + } + if podTemplateSpec.Spec.SecurityContext.RunAsGroup != nil { + podTemplateSpec.Spec.SecurityContext.RunAsGroup = nil + } + if podTemplateSpec.Spec.SecurityContext.FSGroup != nil { + podTemplateSpec.Spec.SecurityContext.FSGroup = nil + } } } @@ -1019,6 +1015,8 @@ func ensureObjectMetaApplyConfigurationExists( } // configureContainer configures a container with the given settings +// +//nolint:gocyclo // Complex but necessary for platform-aware security context configuration func configureContainer( container *corev1apply.ContainerApplyConfiguration, image string, @@ -1043,68 +1041,56 @@ func configureContainer( WithTTY(false). WithEnv(envVars...) - // Add container security context if not already present + // Add container security context using SecurityContextBuilder + securityBuilder := NewSecurityContextBuilder(platform) if container.SecurityContext == nil { - container.WithSecurityContext( - corev1apply.SecurityContext(). - WithPrivileged(false). - WithRunAsNonRoot(true). - WithAllowPrivilegeEscalation(false). - WithReadOnlyRootFilesystem(true). - WithRunAsUser(int64(1000)). - WithRunAsGroup(int64(1000)), - ) + container.WithSecurityContext(securityBuilder.BuildContainerSecurityContextApplyConfiguration()) } else { - // If the container security context already exists, ensure it has the correct settings - if container.SecurityContext.RunAsNonRoot == nil { - container.SecurityContext = container.SecurityContext.WithRunAsNonRoot(true) - } + // If the container security context already exists, merge with platform-aware defaults + platformContext := securityBuilder.BuildContainerSecurityContextApplyConfiguration() - if container.SecurityContext.RunAsUser == nil { - container.SecurityContext = container.SecurityContext.WithRunAsUser(int64(1000)) + // Merge existing context with platform-aware settings + if container.SecurityContext.Privileged == nil && platformContext.Privileged != nil { + container.SecurityContext = container.SecurityContext.WithPrivileged(*platformContext.Privileged) } - if container.SecurityContext.RunAsGroup == nil { - container.SecurityContext = container.SecurityContext.WithRunAsGroup(int64(1000)) + if container.SecurityContext.RunAsNonRoot == nil && platformContext.RunAsNonRoot != nil { + container.SecurityContext = container.SecurityContext.WithRunAsNonRoot(*platformContext.RunAsNonRoot) } - if container.SecurityContext.Privileged == nil { - container.SecurityContext = container.SecurityContext.WithPrivileged(false) + if container.SecurityContext.RunAsUser == nil && platformContext.RunAsUser != nil { + container.SecurityContext = container.SecurityContext.WithRunAsUser(*platformContext.RunAsUser) } - if container.SecurityContext.ReadOnlyRootFilesystem == nil { - container.SecurityContext = container.SecurityContext.WithReadOnlyRootFilesystem(true) + if container.SecurityContext.RunAsGroup == nil && platformContext.RunAsGroup != nil { + container.SecurityContext = container.SecurityContext.WithRunAsGroup(*platformContext.RunAsGroup) } - if container.SecurityContext.AllowPrivilegeEscalation == nil { - container.SecurityContext = container.SecurityContext.WithAllowPrivilegeEscalation(false) + if container.SecurityContext.AllowPrivilegeEscalation == nil && platformContext.AllowPrivilegeEscalation != nil { + container.SecurityContext = container.SecurityContext.WithAllowPrivilegeEscalation(*platformContext.AllowPrivilegeEscalation) } - } - if platform == PlatformOpenShift { - logger.Infof("Setting OpenShift security context requirements to container %s", *container.Name) - - if container.SecurityContext.RunAsUser != nil { - container.SecurityContext.RunAsUser = nil + if container.SecurityContext.ReadOnlyRootFilesystem == nil && platformContext.ReadOnlyRootFilesystem != nil { + container.SecurityContext = container.SecurityContext.WithReadOnlyRootFilesystem(*platformContext.ReadOnlyRootFilesystem) } - if container.SecurityContext.RunAsGroup != nil { - container.SecurityContext.RunAsGroup = nil + if container.SecurityContext.SeccompProfile == nil && platformContext.SeccompProfile != nil { + container.SecurityContext = container.SecurityContext.WithSeccompProfile(platformContext.SeccompProfile) } - if container.SecurityContext.SeccompProfile == nil { - container.SecurityContext.SeccompProfile = - corev1apply.SeccompProfile().WithType( - corev1.SeccompProfileTypeRuntimeDefault) - } else { - container.SecurityContext.SeccompProfile = - container.SecurityContext.SeccompProfile.WithType( - corev1.SeccompProfileTypeRuntimeDefault) + if container.SecurityContext.Capabilities == nil && platformContext.Capabilities != nil { + container.SecurityContext = container.SecurityContext.WithCapabilities(platformContext.Capabilities) } - if container.SecurityContext.Capabilities == nil { - container.SecurityContext.Capabilities = &corev1apply.CapabilitiesApplyConfiguration{ - Drop: []corev1.Capability{"ALL"}, + // For OpenShift, override certain fields even if they exist + if platform == PlatformOpenShift { + logger.Infof("Setting OpenShift security context requirements to container %s", *container.Name) + + if container.SecurityContext.RunAsUser != nil { + container.SecurityContext.RunAsUser = nil + } + if container.SecurityContext.RunAsGroup != nil { + container.SecurityContext.RunAsGroup = nil } } } diff --git a/pkg/container/kubernetes/security.go b/pkg/container/kubernetes/security.go new file mode 100644 index 000000000..cd10a5edf --- /dev/null +++ b/pkg/container/kubernetes/security.go @@ -0,0 +1,168 @@ +package kubernetes + +import ( + corev1 "k8s.io/api/core/v1" + corev1apply "k8s.io/client-go/applyconfigurations/core/v1" + "k8s.io/utils/ptr" + + "github.com/stacklok/toolhive/pkg/logger" +) + +// SecurityContextBuilder provides platform-aware security context configuration +type SecurityContextBuilder struct { + platform Platform +} + +// NewSecurityContextBuilder creates a new SecurityContextBuilder for the given platform +func NewSecurityContextBuilder(platform Platform) *SecurityContextBuilder { + return &SecurityContextBuilder{ + platform: platform, + } +} + +// BuildPodSecurityContext creates a platform-appropriate pod security context +func (b *SecurityContextBuilder) BuildPodSecurityContext() *corev1.PodSecurityContext { + // Start with base security context + podSecurityContext := &corev1.PodSecurityContext{ + RunAsNonRoot: ptr.To(true), + RunAsUser: ptr.To(int64(1000)), + RunAsGroup: ptr.To(int64(1000)), + FSGroup: ptr.To(int64(1000)), + } + + // Apply platform-specific modifications + if b.platform == PlatformOpenShift { + logger.Info("Configuring pod security context for OpenShift") + // OpenShift uses Security Context Constraints (SCCs) to manage user/group assignments + // Setting these to nil allows OpenShift to assign them dynamically + podSecurityContext.RunAsUser = nil + podSecurityContext.RunAsGroup = nil + podSecurityContext.FSGroup = nil + + // OpenShift requires explicit seccomp profile + podSecurityContext.SeccompProfile = &corev1.SeccompProfile{ + Type: corev1.SeccompProfileTypeRuntimeDefault, + } + } else { + logger.Info("Configuring pod security context for Kubernetes") + } + + return podSecurityContext +} + +// BuildContainerSecurityContext creates a platform-appropriate container security context +func (b *SecurityContextBuilder) BuildContainerSecurityContext() *corev1.SecurityContext { + // Start with base security context + containerSecurityContext := &corev1.SecurityContext{ + Privileged: ptr.To(false), + RunAsNonRoot: ptr.To(true), + RunAsUser: ptr.To(int64(1000)), + RunAsGroup: ptr.To(int64(1000)), + AllowPrivilegeEscalation: ptr.To(false), + ReadOnlyRootFilesystem: ptr.To(true), + } + + // Apply platform-specific modifications + if b.platform == PlatformOpenShift { + logger.Info("Configuring container security context for OpenShift") + // OpenShift uses Security Context Constraints (SCCs) to manage user/group assignments + // Setting these to nil allows OpenShift to assign them dynamically + containerSecurityContext.RunAsUser = nil + containerSecurityContext.RunAsGroup = nil + + // OpenShift requires explicit seccomp profile + containerSecurityContext.SeccompProfile = &corev1.SeccompProfile{ + Type: corev1.SeccompProfileTypeRuntimeDefault, + } + + // OpenShift security best practices: drop all capabilities + containerSecurityContext.Capabilities = &corev1.Capabilities{ + Drop: []corev1.Capability{"ALL"}, + } + } else { + logger.Info("Configuring container security context for Kubernetes") + } + + return containerSecurityContext +} + +// BuildPodSecurityContextApplyConfiguration creates a platform-appropriate pod security context +// using the ApplyConfiguration types used by the client +func (b *SecurityContextBuilder) BuildPodSecurityContextApplyConfiguration() *corev1apply.PodSecurityContextApplyConfiguration { + baseContext := b.BuildPodSecurityContext() + + applyConfig := corev1apply.PodSecurityContext() + + if baseContext.RunAsNonRoot != nil { + applyConfig = applyConfig.WithRunAsNonRoot(*baseContext.RunAsNonRoot) + } + + if baseContext.RunAsUser != nil { + applyConfig = applyConfig.WithRunAsUser(*baseContext.RunAsUser) + } + + if baseContext.RunAsGroup != nil { + applyConfig = applyConfig.WithRunAsGroup(*baseContext.RunAsGroup) + } + + if baseContext.FSGroup != nil { + applyConfig = applyConfig.WithFSGroup(*baseContext.FSGroup) + } + + if baseContext.SeccompProfile != nil { + applyConfig = applyConfig.WithSeccompProfile( + corev1apply.SeccompProfile().WithType(baseContext.SeccompProfile.Type)) + } + + return applyConfig +} + +// BuildContainerSecurityContextApplyConfiguration creates a platform-appropriate container security context +// using the ApplyConfiguration types used by the client +func (b *SecurityContextBuilder) BuildContainerSecurityContextApplyConfiguration() *corev1apply.SecurityContextApplyConfiguration { //nolint:lll + baseContext := b.BuildContainerSecurityContext() + + applyConfig := corev1apply.SecurityContext() + + if baseContext.Privileged != nil { + applyConfig = applyConfig.WithPrivileged(*baseContext.Privileged) + } + + if baseContext.RunAsNonRoot != nil { + applyConfig = applyConfig.WithRunAsNonRoot(*baseContext.RunAsNonRoot) + } + + if baseContext.RunAsUser != nil { + applyConfig = applyConfig.WithRunAsUser(*baseContext.RunAsUser) + } + + if baseContext.RunAsGroup != nil { + applyConfig = applyConfig.WithRunAsGroup(*baseContext.RunAsGroup) + } + + if baseContext.AllowPrivilegeEscalation != nil { + applyConfig = applyConfig.WithAllowPrivilegeEscalation(*baseContext.AllowPrivilegeEscalation) + } + + if baseContext.ReadOnlyRootFilesystem != nil { + applyConfig = applyConfig.WithReadOnlyRootFilesystem(*baseContext.ReadOnlyRootFilesystem) + } + + if baseContext.SeccompProfile != nil { + applyConfig = applyConfig.WithSeccompProfile( + corev1apply.SeccompProfile().WithType(baseContext.SeccompProfile.Type)) + } + + if baseContext.Capabilities != nil { + capabilities := corev1apply.Capabilities() + if len(baseContext.Capabilities.Drop) > 0 { + capabilities = capabilities.WithDrop(baseContext.Capabilities.Drop...) + } + if len(baseContext.Capabilities.Add) > 0 { + capabilities = capabilities.WithAdd(baseContext.Capabilities.Add...) + } + applyConfig = applyConfig.WithCapabilities(capabilities) + } + + return applyConfig +} diff --git a/pkg/container/kubernetes/security_test.go b/pkg/container/kubernetes/security_test.go new file mode 100644 index 000000000..dd6d9b549 --- /dev/null +++ b/pkg/container/kubernetes/security_test.go @@ -0,0 +1,176 @@ +package kubernetes + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + corev1 "k8s.io/api/core/v1" +) + +func TestNewSecurityContextBuilder(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + platform Platform + }{ + { + name: "Kubernetes platform", + platform: PlatformKubernetes, + }, + { + name: "OpenShift platform", + platform: PlatformOpenShift, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + builder := NewSecurityContextBuilder(tt.platform) + assert.NotNil(t, builder) + assert.Equal(t, tt.platform, builder.platform) + }) + } +} + +func TestSecurityContextBuilder_BuildPodSecurityContext_Kubernetes(t *testing.T) { + t.Parallel() + + builder := NewSecurityContextBuilder(PlatformKubernetes) + podCtx := builder.BuildPodSecurityContext() + + require.NotNil(t, podCtx) + + // Verify Kubernetes-specific settings + assert.NotNil(t, podCtx.RunAsNonRoot) + assert.True(t, *podCtx.RunAsNonRoot) + + assert.NotNil(t, podCtx.RunAsUser) + assert.Equal(t, int64(1000), *podCtx.RunAsUser) + + assert.NotNil(t, podCtx.RunAsGroup) + assert.Equal(t, int64(1000), *podCtx.RunAsGroup) + + assert.NotNil(t, podCtx.FSGroup) + assert.Equal(t, int64(1000), *podCtx.FSGroup) + + // SeccompProfile should not be explicitly set for standard Kubernetes + assert.Nil(t, podCtx.SeccompProfile) +} + +func TestSecurityContextBuilder_BuildPodSecurityContext_OpenShift(t *testing.T) { + t.Parallel() + + builder := NewSecurityContextBuilder(PlatformOpenShift) + podCtx := builder.BuildPodSecurityContext() + + require.NotNil(t, podCtx) + + // Verify OpenShift-specific settings + assert.NotNil(t, podCtx.RunAsNonRoot) + assert.True(t, *podCtx.RunAsNonRoot) + + // These should be nil to allow OpenShift SCCs to assign them + assert.Nil(t, podCtx.RunAsUser) + assert.Nil(t, podCtx.RunAsGroup) + assert.Nil(t, podCtx.FSGroup) + + // SeccompProfile should be explicitly set for OpenShift + require.NotNil(t, podCtx.SeccompProfile) + assert.Equal(t, corev1.SeccompProfileTypeRuntimeDefault, podCtx.SeccompProfile.Type) +} + +func TestSecurityContextBuilder_BuildContainerSecurityContext_Kubernetes(t *testing.T) { + t.Parallel() + + builder := NewSecurityContextBuilder(PlatformKubernetes) + containerCtx := builder.BuildContainerSecurityContext() + + require.NotNil(t, containerCtx) + + // Verify Kubernetes-specific settings + assert.NotNil(t, containerCtx.Privileged) + assert.False(t, *containerCtx.Privileged) + + assert.NotNil(t, containerCtx.RunAsNonRoot) + assert.True(t, *containerCtx.RunAsNonRoot) + + assert.NotNil(t, containerCtx.RunAsUser) + assert.Equal(t, int64(1000), *containerCtx.RunAsUser) + + assert.NotNil(t, containerCtx.RunAsGroup) + assert.Equal(t, int64(1000), *containerCtx.RunAsGroup) + + assert.NotNil(t, containerCtx.AllowPrivilegeEscalation) + assert.False(t, *containerCtx.AllowPrivilegeEscalation) + + assert.NotNil(t, containerCtx.ReadOnlyRootFilesystem) + assert.True(t, *containerCtx.ReadOnlyRootFilesystem) + + // SeccompProfile and Capabilities should not be explicitly set for standard Kubernetes + assert.Nil(t, containerCtx.SeccompProfile) + assert.Nil(t, containerCtx.Capabilities) +} + +func TestSecurityContextBuilder_BuildContainerSecurityContext_OpenShift(t *testing.T) { + t.Parallel() + + builder := NewSecurityContextBuilder(PlatformOpenShift) + containerCtx := builder.BuildContainerSecurityContext() + + require.NotNil(t, containerCtx) + + // Verify OpenShift-specific settings + assert.NotNil(t, containerCtx.Privileged) + assert.False(t, *containerCtx.Privileged) + + assert.NotNil(t, containerCtx.RunAsNonRoot) + assert.True(t, *containerCtx.RunAsNonRoot) + + // These should be nil to allow OpenShift SCCs to assign them + assert.Nil(t, containerCtx.RunAsUser) + assert.Nil(t, containerCtx.RunAsGroup) + + assert.NotNil(t, containerCtx.AllowPrivilegeEscalation) + assert.False(t, *containerCtx.AllowPrivilegeEscalation) + + assert.NotNil(t, containerCtx.ReadOnlyRootFilesystem) + assert.True(t, *containerCtx.ReadOnlyRootFilesystem) + + // SeccompProfile should be explicitly set for OpenShift + require.NotNil(t, containerCtx.SeccompProfile) + assert.Equal(t, corev1.SeccompProfileTypeRuntimeDefault, containerCtx.SeccompProfile.Type) + + // Capabilities should drop all for OpenShift + require.NotNil(t, containerCtx.Capabilities) + assert.Equal(t, []corev1.Capability{"ALL"}, containerCtx.Capabilities.Drop) +} + +func TestSecurityContextBuilder_ConsistentBehavior(t *testing.T) { + t.Parallel() + + // Test that multiple calls to the same builder produce consistent results + builder := NewSecurityContextBuilder(PlatformKubernetes) + + podCtx1 := builder.BuildPodSecurityContext() + podCtx2 := builder.BuildPodSecurityContext() + + containerCtx1 := builder.BuildContainerSecurityContext() + containerCtx2 := builder.BuildContainerSecurityContext() + + // Pod contexts should be equal + assert.Equal(t, podCtx1.RunAsUser, podCtx2.RunAsUser) + assert.Equal(t, podCtx1.RunAsGroup, podCtx2.RunAsGroup) + assert.Equal(t, podCtx1.FSGroup, podCtx2.FSGroup) + assert.Equal(t, podCtx1.RunAsNonRoot, podCtx2.RunAsNonRoot) + + // Container contexts should be equal + assert.Equal(t, containerCtx1.RunAsUser, containerCtx2.RunAsUser) + assert.Equal(t, containerCtx1.RunAsGroup, containerCtx2.RunAsGroup) + assert.Equal(t, containerCtx1.Privileged, containerCtx2.Privileged) + assert.Equal(t, containerCtx1.RunAsNonRoot, containerCtx2.RunAsNonRoot) + assert.Equal(t, containerCtx1.AllowPrivilegeEscalation, containerCtx2.AllowPrivilegeEscalation) + assert.Equal(t, containerCtx1.ReadOnlyRootFilesystem, containerCtx2.ReadOnlyRootFilesystem) +}