From 6e896c64afa9baf51f54b9be9dc75c55a88ca46a Mon Sep 17 00:00:00 2001 From: Rahul Sharma Date: Wed, 1 Jul 2026 22:11:00 -0700 Subject: [PATCH] use only major version for rhel-specific driver images Signed-off-by: Rahul Sharma --- controllers/state_manager.go | 33 +---- controllers/state_manager_test.go | 105 ++++++++++++- deployments/gpu-operator/values.yaml | 2 +- internal/state/driver_test.go | 25 ++++ internal/state/nodepool.go | 35 +---- internal/state/nodepool_test.go | 212 ++++++++++++++++++++++++++- 6 files changed, 334 insertions(+), 78 deletions(-) diff --git a/controllers/state_manager.go b/controllers/state_manager.go index b046875aaa..37691059c1 100644 --- a/controllers/state_manager.go +++ b/controllers/state_manager.go @@ -20,7 +20,6 @@ import ( "context" "fmt" "path/filepath" - "strconv" "strings" "github.com/go-logr/logr" @@ -539,42 +538,16 @@ func (n *ClusterPolicyController) getGPUNodeOSInfo() (string, string, error) { if !ok { return "", "", fmt.Errorf("unable to retrieve OS version from label %s", nfdOSVersionIDLabelKey) } - osMajorVersion := strings.Split(osVersion, ".")[0] - - // If the OS is RockyLinux or RHEL 10 & above, we will omit the minor version when constructing the os image tag + // If the OS is RockyLinux or RHEL, we will omit the minor version when constructing the os image tag switch osName { - case "rocky": - osVersion = osMajorVersion - case "rhel": - osMajorNumber, err := parseOSMajorVersion(osVersion) - if err != nil { - return "", "", err - } - if osMajorNumber >= 10 { - osVersion = osMajorVersion - } + case "rocky", "rhel": + osVersion = strings.Split(osVersion, ".")[0] } osTag := fmt.Sprintf("%s%s", osName, osVersion) return osName, osTag, nil } -func parseOSMajorVersion(osVersion string) (int, error) { - osMajorVersion := strings.Split(osVersion, ".")[0] - osMajorVersion = strings.TrimSpace(osMajorVersion) - osMajorVersion = strings.TrimPrefix(strings.TrimPrefix(osMajorVersion, "v"), "V") - if osMajorVersion == "" { - return 0, fmt.Errorf("empty OS major version") - } - - osMajorNumber, err := strconv.Atoi(osMajorVersion) - if err != nil { - return 0, fmt.Errorf("error processing OS major version %s: %w", osMajorVersion, err) - } - - return osMajorNumber, nil -} - func (n *ClusterPolicyController) setPodSecurityLabelsForNamespace() error { ctx := n.ctx namespaceName := clusterPolicyCtrl.operatorNamespace diff --git a/controllers/state_manager_test.go b/controllers/state_manager_test.go index 6585de1961..b811806ad4 100644 --- a/controllers/state_manager_test.go +++ b/controllers/state_manager_test.go @@ -26,6 +26,7 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" "k8s.io/utils/ptr" + ctrlclient "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/client/fake" gpuv1 "github.com/NVIDIA/gpu-operator/api/nvidia/v1" @@ -46,6 +47,18 @@ func TestGetGPUNodeOSInfo(t *testing.T) { osVersion: "v1.12.6", expected: "talosv1.12.6", }, + { + name: "rhel 9 omits minor version", + osName: "rhel", + osVersion: "9.4", + expected: "rhel9", + }, + { + name: "rhel 8 omits minor version", + osName: "rhel", + osVersion: "8.10", + expected: "rhel8", + }, { name: "rhel 10 omits minor version", osName: "rhel", @@ -88,13 +101,6 @@ func TestGetGPUNodeOSInfo(t *testing.T) { osVersion: "rolling", expected: "archlinuxrolling", }, - { - name: "rhel invalid major version errors", - osName: "rhel", - osVersion: "A.10", - expectError: true, - errorContainsText: "error processing OS major version", - }, } for _, tc := range testCases { @@ -130,6 +136,91 @@ func TestGetGPUNodeOSInfo(t *testing.T) { } } +type errorListClient struct { + ctrlclient.Client + err error +} + +func (c errorListClient) List(ctx context.Context, list ctrlclient.ObjectList, opts ...ctrlclient.ListOption) error { + return c.err +} + +func TestGetGPUNodeOSInfoListError(t *testing.T) { + expectedErr := errors.New("list failed") + controller := ClusterPolicyController{ + ctx: context.Background(), + client: errorListClient{err: expectedErr}, + } + + osName, osTag, err := controller.getGPUNodeOSInfo() + require.ErrorIs(t, err, expectedErr) + require.Empty(t, osName) + require.Empty(t, osTag) + require.Contains(t, err.Error(), "unable to list nodes with GPU present") +} + +func TestGetGPUNodeOSInfoNoGPUNodes(t *testing.T) { + scheme := runtime.NewScheme() + require.NoError(t, corev1.AddToScheme(scheme)) + + client := fake.NewClientBuilder().WithScheme(scheme).Build() + controller := ClusterPolicyController{ctx: context.Background(), client: client} + + osName, osTag, err := controller.getGPUNodeOSInfo() + require.Error(t, err) + require.Empty(t, osName) + require.Empty(t, osTag) + require.Contains(t, err.Error(), "no nodes found with GPU present") +} + +func TestGetGPUNodeOSInfoMissingLabels(t *testing.T) { + testCases := []struct { + name string + labels map[string]string + errorContainsText string + }{ + { + name: "missing OS release label", + labels: map[string]string{ + commonGPULabelKey: commonGPULabelValue, + nfdOSVersionIDLabelKey: "9.4", + }, + errorContainsText: "unable to retrieve OS name", + }, + { + name: "missing OS version label", + labels: map[string]string{ + commonGPULabelKey: commonGPULabelValue, + nfdOSReleaseIDLabelKey: "rhel", + }, + errorContainsText: "unable to retrieve OS version", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + scheme := runtime.NewScheme() + require.NoError(t, corev1.AddToScheme(scheme)) + + node := &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Name: "gpu-node-1", + Labels: tc.labels, + }, + } + + client := fake.NewClientBuilder().WithScheme(scheme).WithObjects(node).Build() + controller := ClusterPolicyController{ctx: context.Background(), client: client} + + osName, osTag, err := controller.getGPUNodeOSInfo() + require.Error(t, err) + require.Empty(t, osName) + require.Empty(t, osTag) + require.Contains(t, err.Error(), tc.errorContainsText) + }) + } +} + func TestGetRuntimeString(t *testing.T) { testCases := []struct { description string diff --git a/deployments/gpu-operator/values.yaml b/deployments/gpu-operator/values.yaml index 90efb9fe1a..fc2b1430cf 100644 --- a/deployments/gpu-operator/values.yaml +++ b/deployments/gpu-operator/values.yaml @@ -148,7 +148,7 @@ driver: usePrecompiled: false repository: nvcr.io/nvidia image: driver - version: "595.58.03" + version: "595.71.05" imagePullPolicy: IfNotPresent imagePullSecrets: [] startupProbe: diff --git a/internal/state/driver_test.go b/internal/state/driver_test.go index 07690c04d1..9c47074b50 100644 --- a/internal/state/driver_test.go +++ b/internal/state/driver_test.go @@ -556,6 +556,31 @@ func TestDriverAdditionalConfigsSubscriptionMounts(t *testing.T) { } } +func TestDriverConfigPathHelpers(t *testing.T) { + repoConfigPath, err := getRepoConfigPath("rhel") + require.NoError(t, err) + assert.Equal(t, "/etc/yum.repos.d", repoConfigPath) + + certConfigPath, err := getCertConfigPath("rhcos") + require.NoError(t, err) + assert.Equal(t, "/etc/pki/ca-trust/extracted/pem", certConfigPath) + + subscriptionPaths, err := getSubscriptionPathsToVolumeSources("rhel") + require.NoError(t, err) + assert.Contains(t, subscriptionPaths, "/run/secrets/etc-pki-entitlement") + assert.Contains(t, subscriptionPaths, "/run/secrets/redhat.repo") + assert.Contains(t, subscriptionPaths, "/run/secrets/rhsm") + + _, err = getRepoConfigPath("unsupported") + require.ErrorContains(t, err, "distribution unsupported not supported") + + _, err = getCertConfigPath("unsupported") + require.ErrorContains(t, err, "distribution unsupported not supported") + + _, err = getSubscriptionPathsToVolumeSources("unsupported") + require.ErrorContains(t, err, "distribution unsupported not supported") +} + func TestDriverOpenshiftDriverToolkit(t *testing.T) { const ( testName = "driver-openshift-drivertoolkit" diff --git a/internal/state/nodepool.go b/internal/state/nodepool.go index 5f0dd5309b..0493927c57 100644 --- a/internal/state/nodepool.go +++ b/internal/state/nodepool.go @@ -20,7 +20,6 @@ import ( "context" "fmt" "maps" - "strconv" "strings" corev1 "k8s.io/api/core/v1" @@ -143,41 +142,13 @@ func getNodePools(ctx context.Context, k8sClient client.Client, cr *nvidiav1alph } func getOSTag(osRelease, osVersion string) (string, error) { - osMajorVersion := strings.Split(osVersion, ".")[0] - var osTagSuffix string - // If the OS is RockyLinux or RHEL 10 & above, we will omit the minor version when constructing the os image tag + // If the OS is RockyLinux or RHEL, we will omit the minor version when constructing the os image tag switch osRelease { - case "rocky": - osTagSuffix = osMajorVersion - case "rhel": - osMajorNumber, err := parseOSMajorVersion(osVersion) - if err != nil { - return "", fmt.Errorf("failed to parse os version: %w", err) - } - if osMajorNumber >= 10 { - osTagSuffix = osMajorVersion - } else { - osTagSuffix = osVersion - } + case "rocky", "rhel": + osTagSuffix = strings.Split(osVersion, ".")[0] default: osTagSuffix = osVersion } return fmt.Sprintf("%s%s", osRelease, osTagSuffix), nil } - -func parseOSMajorVersion(osVersion string) (int, error) { - osMajorVersion := strings.Split(osVersion, ".")[0] - osMajorVersion = strings.TrimSpace(osMajorVersion) - osMajorVersion = strings.TrimPrefix(strings.TrimPrefix(osMajorVersion, "v"), "V") - if osMajorVersion == "" { - return 0, fmt.Errorf("empty OS major version") - } - - osMajorNumber, err := strconv.Atoi(osMajorVersion) - if err != nil { - return 0, err - } - - return osMajorNumber, nil -} diff --git a/internal/state/nodepool_test.go b/internal/state/nodepool_test.go index 6d175d7d2a..4201c83d36 100644 --- a/internal/state/nodepool_test.go +++ b/internal/state/nodepool_test.go @@ -17,9 +17,18 @@ package state import ( + "context" "testing" "github.com/stretchr/testify/require" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes/scheme" + "k8s.io/utils/ptr" + "sigs.k8s.io/controller-runtime/pkg/client/fake" + + nvidiav1alpha1 "github.com/NVIDIA/gpu-operator/api/nvidia/v1alpha1" + "github.com/NVIDIA/gpu-operator/internal/consts" ) func TestGetOSTag(t *testing.T) { @@ -35,7 +44,14 @@ func TestGetOSTag(t *testing.T) { description: "valid os release & version", osRelease: "rhel", osVersion: "9.4", - expected: "rhel9.4", + expected: "rhel9", + expectError: false, + }, + { + description: "valid os release & version - rhel8", + osRelease: "rhel", + osVersion: "8.10", + expected: "rhel8", expectError: false, }, { @@ -73,13 +89,6 @@ func TestGetOSTag(t *testing.T) { expected: "archlinuxrolling", expectError: false, }, - { - description: "invalid os version", - osRelease: "rhel", - osVersion: "A.10", - expectError: true, - errorMessage: "failed to parse os version: strconv.Atoi: parsing \"A\": invalid syntax", - }, } for _, test := range tests { @@ -95,3 +104,190 @@ func TestGetOSTag(t *testing.T) { }) } } + +func TestGetNodePoolsGroupsNodesByOSTag(t *testing.T) { + require.NoError(t, corev1.AddToScheme(scheme.Scheme)) + + k8sClient := fake.NewClientBuilder(). + WithScheme(scheme.Scheme). + WithObjects( + &corev1.Node{ObjectMeta: metav1.ObjectMeta{ + Name: "rhel-node", + Labels: map[string]string{ + "pool": "gold", + consts.GPUPresentLabel: "true", + consts.NVIDIADriverOwnerLabel: "driver-a", + nfdOSReleaseIDLabelKey: "rhel", + nfdOSVersionIDLabelKey: "9.4", + }, + }}, + &corev1.Node{ObjectMeta: metav1.ObjectMeta{ + Name: "ubuntu-node", + Labels: map[string]string{ + "pool": "gold", + consts.GPUPresentLabel: "true", + consts.NVIDIADriverOwnerLabel: "driver-a", + nfdOSReleaseIDLabelKey: "ubuntu", + nfdOSVersionIDLabelKey: "22.04", + }, + }}, + &corev1.Node{ObjectMeta: metav1.ObjectMeta{ + Name: "other-pool-node", + Labels: map[string]string{ + "pool": "silver", + consts.GPUPresentLabel: "true", + consts.NVIDIADriverOwnerLabel: "driver-a", + nfdOSReleaseIDLabelKey: "ubuntu", + nfdOSVersionIDLabelKey: "20.04", + }, + }}, + ). + Build() + driver := &nvidiav1alpha1.NVIDIADriver{ + ObjectMeta: metav1.ObjectMeta{Name: "driver-a"}, + Spec: nvidiav1alpha1.NVIDIADriverSpec{ + NodeSelector: map[string]string{"pool": "gold"}, + }, + } + + nodePools, err := getNodePools(context.Background(), k8sClient, driver, false) + + require.NoError(t, err) + require.Len(t, nodePools, 2) + + poolsByName := nodePoolsByName(nodePools) + require.Contains(t, poolsByName, "rhel9") + require.Equal(t, "rhel", poolsByName["rhel9"].osRelease) + require.Equal(t, "9.4", poolsByName["rhel9"].osVersion) + require.Equal(t, "gold", poolsByName["rhel9"].nodeSelector["pool"]) + require.Equal(t, "driver-a", poolsByName["rhel9"].nodeSelector[consts.NVIDIADriverOwnerLabel]) + + require.Contains(t, poolsByName, "ubuntu22.04") + require.Equal(t, "ubuntu", poolsByName["ubuntu22.04"].osRelease) + require.Equal(t, "22.04", poolsByName["ubuntu22.04"].osVersion) +} + +func TestGetNodePoolsSkipsNodesMissingNFDOSLabels(t *testing.T) { + require.NoError(t, corev1.AddToScheme(scheme.Scheme)) + + k8sClient := fake.NewClientBuilder(). + WithScheme(scheme.Scheme). + WithObjects( + &corev1.Node{ObjectMeta: metav1.ObjectMeta{ + Name: "missing-os-release", + Labels: map[string]string{ + consts.GPUPresentLabel: "true", + consts.NVIDIADriverOwnerLabel: "driver-a", + nfdOSVersionIDLabelKey: "9.4", + }, + }}, + &corev1.Node{ObjectMeta: metav1.ObjectMeta{ + Name: "missing-os-version", + Labels: map[string]string{ + consts.GPUPresentLabel: "true", + consts.NVIDIADriverOwnerLabel: "driver-a", + nfdOSReleaseIDLabelKey: "rhel", + }, + }}, + ). + Build() + driver := &nvidiav1alpha1.NVIDIADriver{ + ObjectMeta: metav1.ObjectMeta{Name: "driver-a"}, + } + + nodePools, err := getNodePools(context.Background(), k8sClient, driver, false) + + require.NoError(t, err) + require.Empty(t, nodePools) +} + +func TestGetNodePoolsPartitionsPrecompiledNodesByKernel(t *testing.T) { + require.NoError(t, corev1.AddToScheme(scheme.Scheme)) + + k8sClient := fake.NewClientBuilder(). + WithScheme(scheme.Scheme). + WithObjects( + &corev1.Node{ObjectMeta: metav1.ObjectMeta{ + Name: "kernel-node", + Labels: map[string]string{ + consts.GPUPresentLabel: "true", + consts.NVIDIADriverOwnerLabel: "driver-a", + nfdOSReleaseIDLabelKey: "ubuntu", + nfdOSVersionIDLabelKey: "22.04", + nfdKernelLabelKey: "5.15.0-70-generic_x86_64", + }, + }}, + &corev1.Node{ObjectMeta: metav1.ObjectMeta{ + Name: "missing-kernel-node", + Labels: map[string]string{ + consts.GPUPresentLabel: "true", + consts.NVIDIADriverOwnerLabel: "driver-a", + nfdOSReleaseIDLabelKey: "ubuntu", + nfdOSVersionIDLabelKey: "22.04", + }, + }}, + ). + Build() + driver := &nvidiav1alpha1.NVIDIADriver{ + ObjectMeta: metav1.ObjectMeta{Name: "driver-a"}, + Spec: nvidiav1alpha1.NVIDIADriverSpec{ + UsePrecompiled: ptr.To(true), + }, + } + + nodePools, err := getNodePools(context.Background(), k8sClient, driver, false) + + require.NoError(t, err) + require.Len(t, nodePools, 1) + require.Equal(t, "ubuntu22.04-5.15.0-70-generic", nodePools[0].name) + require.Equal(t, "5.15.0-70-generic_x86_64", nodePools[0].kernel) + require.Equal(t, "5.15.0-70-generic_x86_64", nodePools[0].nodeSelector[nfdKernelLabelKey]) +} + +func TestGetNodePoolsPartitionsOpenShiftNodesByRHCOSVersion(t *testing.T) { + require.NoError(t, corev1.AddToScheme(scheme.Scheme)) + + k8sClient := fake.NewClientBuilder(). + WithScheme(scheme.Scheme). + WithObjects( + &corev1.Node{ObjectMeta: metav1.ObjectMeta{ + Name: "rhcos-node", + Labels: map[string]string{ + consts.GPUPresentLabel: "true", + consts.NVIDIADriverOwnerLabel: "driver-a", + nfdOSReleaseIDLabelKey: "rhcos", + nfdOSVersionIDLabelKey: "4.14", + nfdOSTreeVersionLabelKey: "414.92.202309282257", + }, + }}, + &corev1.Node{ObjectMeta: metav1.ObjectMeta{ + Name: "missing-rhcos-node", + Labels: map[string]string{ + consts.GPUPresentLabel: "true", + consts.NVIDIADriverOwnerLabel: "driver-a", + nfdOSReleaseIDLabelKey: "rhcos", + nfdOSVersionIDLabelKey: "4.14", + }, + }}, + ). + Build() + driver := &nvidiav1alpha1.NVIDIADriver{ + ObjectMeta: metav1.ObjectMeta{Name: "driver-a"}, + } + + nodePools, err := getNodePools(context.Background(), k8sClient, driver, true) + + require.NoError(t, err) + require.Len(t, nodePools, 1) + require.Equal(t, "414.92.202309282257", nodePools[0].name) + require.Equal(t, "414.92.202309282257", nodePools[0].rhcosVersion) + require.Equal(t, "414.92.202309282257", nodePools[0].nodeSelector[nfdOSTreeVersionLabelKey]) +} + +func nodePoolsByName(nodePools []nodePool) map[string]nodePool { + poolsByName := make(map[string]nodePool, len(nodePools)) + for _, pool := range nodePools { + poolsByName[pool.name] = pool + } + return poolsByName +}