@staticmethod
def _update_slurm_nodes_with_ec2_info(  # noqa: C901
    nodes,
    cluster_instances,
    instance_manager=None,
    fallback_match_count_map=None,
    compute_node_missing_ip_max_count=10,
):
    """
    Associate EC2 instances with Slurm nodes.

    Primary matching is by IP address: an instance is linked to the node whose ``nodeaddr`` equals one of
    the instance's private IPs. Instances that report no private IP at all are then matched by instance ID,
    using the instance-ID-to-node-name mapping returned by
    ``instance_manager.get_instance_id_to_node_name_mapping()``.

    Args:
        nodes: Slurm nodes to update; ``node.instance`` is set on every matched node.
        cluster_instances: EC2 instances of the cluster; ``instance.slurm_node`` is set on every match.
        instance_manager: object providing ``get_instance_id_to_node_name_mapping()``. When it is falsy,
            no fallback matching is attempted.
        fallback_match_count_map: dict (node name -> int) tracking for how many consecutive calls a node
            needed fallback matching. It is mutated in place. Once a node's count exceeds
            ``compute_node_missing_ip_max_count`` the instance is no longer associated with the node.
            A node's entry is removed as soon as a call completes without that node needing the
            fallback, so that the count really is consecutive.
        compute_node_missing_ip_max_count: max consecutive calls to tolerate a missing IP before the
            fallback association is refused.
    """
    if not cluster_instances:
        # With no instances nothing can be waiting for an IP: make sure no stale counter survives.
        if fallback_match_count_map is not None:
            fallback_match_count_map.clear()
        return

    # First pass: match by IP address (primary method)
    ip_to_slurm_node_map = {node.nodeaddr: node for node in nodes}
    unmatched_instances = []
    for instance in cluster_instances:
        for private_ip in instance.all_private_ips:
            if private_ip in ip_to_slurm_node_map:
                slurm_node = ip_to_slurm_node_map[private_ip]
                slurm_node.instance = instance
                instance.slurm_node = slurm_node
                break
        else:
            # Only instances with no IP at all are candidates for the instance-ID fallback.
            if not instance.all_private_ips:
                unmatched_instances.append(instance)

    # Second pass: fallback matching by instance ID for instances with missing IPs
    nodes_needing_fallback = set()
    mapping_available = False
    if unmatched_instances and instance_manager:
        log.info(
            "Found %d instances with missing IPs, attempting fallback matching by instance ID",
            len(unmatched_instances),
        )
        try:
            instance_id_to_node_name = instance_manager.get_instance_id_to_node_name_mapping()
        except Exception as e:
            log.warning("Failed to read instance-ID-to-node-name mapping for fallback matching: %s", e)
            instance_id_to_node_name = {}

        if instance_id_to_node_name:
            mapping_available = True
            name_to_slurm_node_map = {node.name: node for node in nodes}
            for instance in unmatched_instances:
                node_name = instance_id_to_node_name.get(instance.id)
                if not node_name or node_name not in name_to_slurm_node_map:
                    continue
                slurm_node = name_to_slurm_node_map[node_name]
                if slurm_node.instance:
                    # The node already got an instance through the primary path.
                    continue
                nodes_needing_fallback.add(node_name)
                # Track consecutive fallback match count
                if fallback_match_count_map is not None:
                    count = fallback_match_count_map.get(node_name, 0) + 1
                    fallback_match_count_map[node_name] = count
                    if count > compute_node_missing_ip_max_count:
                        log.warning(
                            "Instance %s matched to node %s via fallback for %d consecutive iterations "
                            "(IP never appeared). Dissociating to let health checks handle it.",
                            instance.id,
                            node_name,
                            count,
                        )
                        continue
                log.info(
                    "Matched instance %s to node %s via instance ID fallback (IP not yet available)",
                    instance.id,
                    node_name,
                )
                slurm_node.instance = instance
                instance.slurm_node = slurm_node

    # Drop the counters of nodes that did not need fallback matching in this call (IP appeared, or the
    # instance is gone). Without this a counter would resume from its old value on a later, unrelated
    # missing-IP episode. Counters are left untouched when instances are missing IPs but no mapping could
    # be obtained, because in that case we cannot tell which nodes are still waiting.
    if fallback_match_count_map is not None and (mapping_available or not unmatched_instances):
        for stale_node_name in set(fallback_match_count_map) - nodes_needing_fallback:
            del fallback_match_count_map[stale_node_name]
" + "Missing top-level fields: %s. " + "Adding with instance ID only to allow fallback matching.", instance_info["InstanceId"], type(e).__name__, e, + missing_fields if missing_fields else "none", + ) + instances.append( + EC2Instance( + instance_info["InstanceId"], + "", + "", + set(), + instance_info.get("LaunchTime"), + ) ) return instances @@ -311,6 +326,39 @@ def terminate_all_compute_nodes(self, terminate_batch_size): logger.error("Failed when terminating compute fleet with error %s", e) return False + def get_instance_id_to_node_name_mapping(self): + """Read instance-ID-to-node-name mapping from DynamoDB. + + Returns a dict mapping instance_id -> node_name. + Used by clustermgtd for fallback matching when PrivateIpAddress is missing. + """ + if not self._table: + logger.warning("DynamoDB table not configured, cannot read instance-ID-to-node-name mapping") + return {} + try: + mapping = {} + response = self._table.scan(ProjectionExpression="Id, InstanceId") + for item in response.get("Items", []): + instance_id = item.get("InstanceId") + node_name = item.get("Id") + if instance_id and node_name: + mapping[instance_id] = node_name + # Handle pagination + while "LastEvaluatedKey" in response: + response = self._table.scan( + ProjectionExpression="Id, InstanceId", + ExclusiveStartKey=response["LastEvaluatedKey"], + ) + for item in response.get("Items", []): + instance_id = item.get("InstanceId") + node_name = item.get("Id") + if instance_id and node_name: + mapping[instance_id] = node_name + return mapping + except Exception as e: + logger.warning("Failed to read instance-ID-to-node-name mapping from DynamoDB: %s", e) + return {} + def _update_failed_nodes(self, nodeset, error_code="Exception", override=True): """Update failed nodes dict with error code as key and nodeset value.""" if not override: diff --git a/tests/slurm_plugin/test_instance_manager.py b/tests/slurm_plugin/test_instance_manager.py index 28f809b0..12ace2ed 100644 --- 
a/tests/slurm_plugin/test_instance_manager.py +++ b/tests/slurm_plugin/test_instance_manager.py @@ -907,6 +907,7 @@ def get_unhealthy_cluster_instance_status( generate_error=False, ), [ + EC2Instance("i-1", "", "", set(), datetime(2020, 1, 1, tzinfo=timezone.utc)), EC2Instance("i-2", "ip-2", "hostname", {"ip-2"}, datetime(2020, 1, 1, tzinfo=timezone.utc)), ], False,