From f66dc37dbc83ba8ce74e79ae6f8e145c7effb4d4 Mon Sep 17 00:00:00 2001 From: CYHSM Date: Tue, 2 Jun 2026 09:28:37 +0200 Subject: [PATCH] extracted clean logging code from hpc branch --- src/modalities/batch.py | 37 +++- src/modalities/loss_functions.py | 74 ++++--- src/modalities/models/gpt2/gpt2_model.py | 47 ++++- src/modalities/trainer.py | 48 ++++- src/modalities/training/logging.py | 198 ++++++++++++++++++ .../subscriber_impl/test_logging_rich.py | 120 +++++++++++ 6 files changed, 481 insertions(+), 43 deletions(-) create mode 100644 src/modalities/training/logging.py create mode 100644 tests/logging_broker/subscriber_impl/test_logging_rich.py diff --git a/src/modalities/batch.py b/src/modalities/batch.py index bb55f245a..834ad6b1e 100644 --- a/src/modalities/batch.py +++ b/src/modalities/batch.py @@ -54,6 +54,26 @@ def __len__(self) -> int: return self.samples[key].shape[self.batch_dim] +def _apply_to(val, device): + if isinstance(val, torch.Tensor): + return val.to(device) + elif isinstance(val, dict): + return {k: _apply_to(v, device) for k, v in val.items()} + elif isinstance(val, list): + return [_apply_to(v, device) for v in val] + return val + + +def _apply_detach(val): + if isinstance(val, torch.Tensor): + return val.detach() + elif isinstance(val, dict): + return {k: _apply_detach(v) for k, v in val.items()} + elif isinstance(val, list): + return [_apply_detach(v) for v in val] + return val + + @dataclass class InferenceResultBatch(Batch, TorchDeviceMixin): """Stores targets and predictions of an entire batch.""" @@ -71,12 +91,12 @@ def device(self) -> torch.device: return self.targets[key].device def to(self, device: torch.device): - self.predictions = {k: v.to(device) for k, v in self.predictions.items()} - self.targets = {k: v.to(device) for k, v in self.targets.items()} + self.predictions = {k: _apply_to(v, device) for k, v in self.predictions.items()} + self.targets = {k: _apply_to(v, device) for k, v in self.targets.items()} def detach(self): - self.targets = {k: v.detach() for k, v in self.targets.items()} - self.predictions = {k: v.detach() for k, v in self.predictions.items()} + self.targets = {k: _apply_detach(v) for k, v in self.targets.items()} + self.predictions = {k: _apply_detach(v) for k, v in self.predictions.items()} def get_predictions(self, key: str) -> torch.Tensor: if key not in self.predictions: @@ -89,8 +109,13 @@ def get_targets(self, key: str) -> torch.Tensor: return self.targets[key] def __len__(self) -> int: - key = list(self.predictions.keys())[0] - return self.predictions[key].shape[self.batch_dim] + for v in self.predictions.values(): + if isinstance(v, torch.Tensor): + return v.shape[self.batch_dim] + for v in self.targets.values(): + if isinstance(v, torch.Tensor): + return v.shape[self.batch_dim] + raise ValueError("No tensor found in predictions or targets to determine batch length") @dataclass diff --git a/src/modalities/loss_functions.py b/src/modalities/loss_functions.py index e3be6100d..b4cd1367f 100644 --- a/src/modalities/loss_functions.py +++ b/src/modalities/loss_functions.py @@ -31,17 +31,29 @@ def __init__(self, target_key: str, prediction_key: str, tag: str = "CLMCrossEnt self.prediction_key = prediction_key # Mean over the tokens in the local-batch (batch per rank) self.loss_fun = CrossEntropyLoss(reduction="mean") + self._last_ce_loss = torch.tensor(0.0) + self._last_metrics = None @overload def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor: ... @overload - def __call__(self, outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: + def __call__(self, outputs: torch.Tensor | dict, targets: torch.Tensor) -> torch.Tensor: ... def __call__(self, *args, **kwargs) -> torch.Tensor: - labels, lm_logits = self._parse_arguments(args, kwargs) + labels, outputs = self._parse_arguments(args, kwargs) + + if isinstance(outputs, dict): + if "logits" in outputs: + lm_logits = outputs["logits"] + else: + lm_logits = outputs[self.prediction_key] + metrics = outputs.get("metrics", None) + else: + lm_logits = outputs + metrics = None # move labels to correct device to enable model parallelism labels = labels.to(lm_logits.device) @@ -49,42 +61,56 @@ def __call__(self, *args, **kwargs) -> torch.Tensor: shift_labels = labels.contiguous().long() # Flatten the tokens. We compute here, the loss per token. loss = self.loss_fun(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + self._last_ce_loss = loss.detach() + self._last_metrics = _detach_metrics(metrics) return loss + def get_metrics(self) -> dict: + return { + "ce_loss": self._last_ce_loss, + "metrics": self._last_metrics, + } + def _parse_arguments( self, - args: list[torch.Tensor] | list[InferenceResultBatch], - kwargs: dict[str, torch.Tensor] | dict[str, InferenceResultBatch], - ) -> tuple[torch.Tensor, torch.Tensor]: + args: list[torch.Tensor | dict] | list[InferenceResultBatch], + kwargs: dict[str, torch.Tensor | dict] | dict[str, InferenceResultBatch], + ) -> tuple[torch.Tensor, torch.Tensor | dict]: if len(args) == 1 and isinstance(args[0], InferenceResultBatch): forward_batch = args[0] labels = forward_batch.get_targets(self.target_key) - lm_logits = forward_batch.get_predictions(self.prediction_key) + outputs = forward_batch.predictions elif "forward_batch" in kwargs and isinstance(kwargs["forward_batch"], InferenceResultBatch): forward_batch = kwargs["forward_batch"] labels = forward_batch.get_targets(self.target_key) - lm_logits = forward_batch.get_predictions(self.prediction_key) - elif len(args) == 2 and all(isinstance(arg, torch.Tensor) for arg in args): - lm_logits, labels = args - elif ( - "outputs" in kwargs - and "targets" in kwargs - and isinstance(kwargs["outputs"], torch.Tensor) - and isinstance(kwargs["targets"], torch.Tensor) - ): - lm_logits = kwargs["outputs"] + outputs = forward_batch.predictions + elif len(args) == 2: + outputs, labels = args + elif "outputs" in kwargs and "targets" in kwargs: + outputs = kwargs["outputs"] labels = kwargs["targets"] - elif ( - len(args) == 1 - and "targets" in kwargs - and isinstance(args[0], torch.Tensor) - and isinstance(kwargs["targets"], torch.Tensor) - ): - lm_logits = args[0] + elif len(args) == 1 and "targets" in kwargs: + outputs = args[0] labels = kwargs["targets"] else: raise TypeError("Invalid arguments for CLMCrossEntropyLoss.__call__") - return labels, lm_logits + return labels, outputs + + +def _detach_metrics(metrics: dict | None) -> dict | None: + """Recursively detach all tensors in a nested dict.""" + if metrics is None: + return None + result = {} + for key, value in metrics.items(): + if isinstance(value, torch.Tensor): + result[key] = value.detach() + elif isinstance(value, dict): + result[key] = _detach_metrics(value) + else: + result[key] = value + return result def nce_loss( diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py index 2da4979c0..f73d56f16 100644 --- a/src/modalities/models/gpt2/gpt2_model.py +++ b/src/modalities/models/gpt2/gpt2_model.py @@ -939,7 +939,7 @@ def __init__( ) # https://paperswithcode.com/method/weight-tying @overload - def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor | dict]: """ Forward pass of the GPT2LLM module. @@ -948,8 +948,8 @@ def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: - sample_key (str): Key for the input tensor containing token ids. Returns: - dict[str, torch.Tensor]: A dictionary containing output tensors. - - prediction_key (str): Key for the output tensor containing logits. + dict[str, torch.Tensor | dict]: A dictionary containing output tensors and metrics. + - prediction_key (str): Key for the output containing logits and metrics dict. """ ... @@ -966,7 +966,7 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: """ ... - def forward(self, inputs: dict[str, torch.Tensor] | torch.Tensor) -> dict[str, torch.Tensor] | torch.Tensor: + def forward(self, inputs: dict[str, torch.Tensor] | torch.Tensor) -> dict[str, torch.Tensor | dict] | torch.Tensor: """ Forward pass of the GPT2LLM module. @@ -974,14 +974,19 @@ def forward(self, inputs: dict[str, torch.Tensor] | torch.Tensor) -> dict[str, t inputs (dict[str, torch.Tensor] | torch.Tensor): Input data. Returns: - dict[str, torch.Tensor] | torch.Tensor: Model output. + dict[str, torch.Tensor | dict] | torch.Tensor: Model output. """ if isinstance(inputs, dict): - return {self.prediction_key: self.forward_impl(inputs[self.sample_key])} + logits, metrics = self.forward_impl(inputs[self.sample_key]) + return { + self.prediction_key: logits, + "metrics": metrics, + } else: - return self.forward_impl(inputs) + logits, _ = self.forward_impl(inputs) + return logits - def forward_impl(self, inputs: torch.Tensor) -> torch.Tensor: + def forward_impl(self, inputs: torch.Tensor) -> tuple[torch.Tensor, dict]: """ Forward pass implementation of the GPT2LLM module. @@ -989,7 +994,7 @@ def forward_impl(self, inputs: torch.Tensor) -> torch.Tensor: inputs (torch.Tensor): A tensor containing input token ids. Returns: - torch.Tensor: A tensor containing output logits. + tuple[torch.Tensor, dict]: A tuple containing output logits and custom metrics. """ device = inputs.device seq_len = inputs.size(1) @@ -1009,11 +1014,31 @@ def forward_impl(self, inputs: torch.Tensor) -> torch.Tensor: # TODO: use drop out also without absolute position embedding? h = self.transformer.drop(h) if hasattr(self.transformer, "drop") else h + layer_norms = [] for layer_idx in self.transformer.h: h = self.transformer.h[layer_idx](h) + layer_norms.append(h.detach().norm(dim=-1).mean()) + h = self.transformer.lm_head_norm(h) if hasattr(self.transformer, "lm_head_norm") else h - h = self.transformer.lm_head(h) if hasattr(self.transformer, "lm_head") else h - return h + logits = self.transformer.lm_head(h) if hasattr(self.transformer, "lm_head") else h + + with torch.no_grad(): + log_p = torch.log_softmax(logits, dim=-1) + p = torch.exp(log_p) + entropy = -torch.sum(p * log_p, dim=-1).mean() + layer_norms_tensor = torch.stack(layer_norms) + + metrics = { + "scalars": { + "logits_entropy": entropy + }, + "per_layer_scalars": { + "layer_activation_norm": layer_norms_tensor + } + } + + return logits, metrics + def manual_scaled_dot_product_attention( diff --git a/src/modalities/trainer.py b/src/modalities/trainer.py index c715a01fa..1811037cc 100644 --- a/src/modalities/trainer.py +++ b/src/modalities/trainer.py @@ -51,6 +51,10 @@ class ThroughputAggregationKeys(Enum): FORWARD_BACKWARD_TIME = "FORWARD_BACKWARD_TIME" + +from modalities.training.logging import MetricsAccumulator, format_metrics + + class Trainer: def __init__( self, @@ -150,7 +154,7 @@ def _train_batch( operate the model. Defaults to None. Returns: - tuple[bool, int, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + tuple[bool, int, torch.Tensor, Optional[torch.Tensor]]: A tuple containing the following: - step_performed (bool): Indicates whether a training step was performed. - num_train_steps_done (int): The number of training steps done. @@ -234,6 +238,7 @@ def train( local_num_seen_samples = 0 cumulated_losses = torch.zeros(3).cuda() + metrics_accum = MetricsAccumulator() # throughput device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -291,6 +296,9 @@ def train( # it has less samples than the batch size cumulated_losses[-1] += 1 # number of local batches + if hasattr(loss_fun, "get_metrics"): + metrics_accum.accumulate(loss_fun.get_metrics()) + # gradient norm is already synced across all ranks if gradient_norm_score is not None: gradient_norm_scores.append(gradient_norm_score.item()) @@ -336,17 +344,52 @@ def train( reduced_losses[0], reduced_losses[1], ) + + adaptive_losses = {} + adaptive_metrics = {} + if metrics_accum.count > 0: + ( + sync_tensor, scalar_names, per_layer_names, per_layer_sizes, + hist_names, hist_shapes, + ) = metrics_accum.build_sync_tensor(device) + + reduce_scale = dist.get_world_size() / self.pp_degree + synced_tensor = Reducer.reduce( + tensor=sync_tensor, + operation=dist.ReduceOp.SUM, + post_processing_fun=lambda t: t / reduce_scale, + ) + + ( + synced_loss, synced_scalars, synced_per_layer, synced_hists, + ) = MetricsAccumulator.unpack_synced_tensor( + synced_tensor, scalar_names, per_layer_names, per_layer_sizes, + hist_names, hist_shapes, + ) + + adaptive_losses, adaptive_metrics = format_metrics( + loss=synced_loss, + scalars=synced_scalars, + per_layer_scalars=synced_per_layer, + per_layer_vectors=metrics_accum.last_per_layer_vectors, + per_layer_histograms=synced_hists, + ) + losses = { "train loss avg": ResultItem(train_loss_avg, decimal_places=2), "train loss last": ResultItem(train_loss_last_batch, decimal_places=2), + **adaptive_losses, } metrics = { "consumed tokens": ResultItem(torch.tensor(training_progress.num_seen_tokens_total), 0), "grad norm avg": ResultItem(torch.mean(torch.Tensor(gradient_norm_scores)), 2), "grad norm last": ResultItem(torch.tensor(gradient_norm_scores[-1]), 2), + **adaptive_metrics, } + gradient_norm_scores = [] + mfu_score = torch.tensor(-1.0) if self.mfu_calculator is not None: mfu_score = self.mfu_calculator.compute(num_samples_per_second=global_num_samples_per_second) @@ -384,11 +427,12 @@ def train( ) cumulated_losses.zero_() + metrics_accum.reset() if step_performed: self.gc.run(step_count=training_progress.num_seen_steps_total) evaluation_callback(num_train_steps_done=training_progress.num_seen_steps_total) checkpointing_callback(training_progress=training_progress) - + profiler_cm.step() @staticmethod diff --git a/src/modalities/training/logging.py b/src/modalities/training/logging.py new file mode 100644 index 000000000..18b632af9 --- /dev/null +++ b/src/modalities/training/logging.py @@ -0,0 +1,198 @@ +from typing import Optional +import torch +from modalities.batch import ResultItem + + +# ============================================================================= +# Generic Metrics Accumulator +# ============================================================================= + +class MetricsAccumulator: + """Accumulates metrics across batches and produces a single flat tensor + for cross-rank reduction. + """ + + def __init__(self): + self.reset() + + def reset(self): + self.loss_sum: float = 0.0 + self.scalar_sums: dict[str, float] = {} + self.per_layer_scalar_sums: dict[str, torch.Tensor] = {} + self.last_per_layer_vectors: dict[str, torch.Tensor] = {} + self.per_layer_hist_sums: dict[str, torch.Tensor] = {} + self.count: int = 0 + + def accumulate(self, loss_metrics: dict): + if "ce_loss" in loss_metrics: + self.loss_sum += loss_metrics["ce_loss"].item() + elif "loss" in loss_metrics: + self.loss_sum += loss_metrics["loss"].item() + self.count += 1 + + bag = loss_metrics.get("metrics") + if bag is None: + return + + for name, tensor in bag.get("scalars", {}).items(): + self.scalar_sums[name] = self.scalar_sums.get(name, 0.0) + tensor.item() + + for name, tensor in bag.get("per_layer_scalars", {}).items(): + if name not in self.per_layer_scalar_sums: + self.per_layer_scalar_sums[name] = torch.zeros_like(tensor, dtype=torch.float32) + self.per_layer_scalar_sums[name] += tensor.float() + + for name, tensor in bag.get("per_layer_vectors", {}).items(): + self.last_per_layer_vectors[name] = tensor + + for name, tensor in bag.get("per_layer_histograms", {}).items(): + if name not in self.per_layer_hist_sums: + self.per_layer_hist_sums[name] = torch.zeros_like(tensor, dtype=torch.float32) + self.per_layer_hist_sums[name] += tensor.float() + + def build_sync_tensor( + self, device: torch.device + ) -> tuple[ + torch.Tensor, + list[str], + list[str], + dict[str, int], + list[str], + dict[str, tuple], + ]: + if self.count == 0: + return torch.zeros(1, device=device), [], [], {}, [], {} + + n = self.count + values = [self.loss_sum / n] + + scalar_names = sorted(self.scalar_sums.keys()) + for name in scalar_names: + values.append(self.scalar_sums[name] / n) + + per_layer_names = sorted(self.per_layer_scalar_sums.keys()) + per_layer_sizes = {} + layer_tensors = [] + for name in per_layer_names: + t = self.per_layer_scalar_sums[name] / n + layer_tensors.append(t.to(device)) + per_layer_sizes[name] = t.numel() + + hist_names = sorted(self.per_layer_hist_sums.keys()) + hist_shapes: dict[str, tuple] = {} + hist_tensors = [] + for name in hist_names: + t = self.per_layer_hist_sums[name] / n + hist_shapes[name] = tuple(t.shape) + hist_tensors.append(t.to(device).flatten()) + + combined = torch.tensor(values, device=device, dtype=torch.float32) + if layer_tensors: + combined = torch.cat([combined, torch.cat(layer_tensors)]) + if hist_tensors: + combined = torch.cat([combined, torch.cat(hist_tensors)]) + + return combined, scalar_names, per_layer_names, per_layer_sizes, hist_names, hist_shapes + + @staticmethod + def unpack_synced_tensor( + synced: torch.Tensor, + scalar_names: list[str], + per_layer_names: list[str], + per_layer_sizes: dict[str, int], + hist_names: list[str] = None, + hist_shapes: dict[str, tuple] = None, + ) -> tuple[ + torch.Tensor, + dict[str, torch.Tensor], + dict[str, torch.Tensor], + dict[str, torch.Tensor], + ]: + hist_names = hist_names or [] + hist_shapes = hist_shapes or {} + + idx = 0 + loss = synced[idx]; idx += 1 + + scalars = {} + for name in scalar_names: + scalars[name] = synced[idx]; idx += 1 + + per_layer_scalars = {} + for name in per_layer_names: + size = per_layer_sizes[name] + per_layer_scalars[name] = synced[idx : idx + size]; idx += size + + per_layer_histograms = {} + for name in hist_names: + shape = hist_shapes[name] + size = 1 + for dim in shape: + size *= dim + per_layer_histograms[name] = synced[idx : idx + size].reshape(shape) + idx += size + + return loss, scalars, per_layer_scalars, per_layer_histograms + + +# ============================================================================= +# Metrics Formatter +# ============================================================================= + +def format_metrics( + loss: torch.Tensor, + scalars: dict[str, torch.Tensor], + per_layer_scalars: dict[str, torch.Tensor], + per_layer_vectors: dict[str, torch.Tensor], + summary_only: bool = False, + per_layer_histograms: Optional[dict[str, torch.Tensor]] = None, +) -> tuple[dict[str, ResultItem], dict[str, ResultItem]]: + per_layer_histograms = per_layer_histograms or {} + + losses = { + "loss/ce_avg": ResultItem(loss, decimal_places=2), + } + + metrics: dict[str, ResultItem] = {} + + for name, val in scalars.items(): + metrics[f"adaptive/{name}"] = ResultItem(val, 4) + + for name, vals in per_layer_scalars.items(): + metrics[f"summary/{name}"] = ResultItem(vals.mean(), 4) + if not summary_only: + for i, v in enumerate(vals): + metrics[f"layer_{i}/{name}"] = ResultItem(v, 4) + + for name, tensor in per_layer_vectors.items(): + if tensor.numel() == 0: + continue + t = tensor.float().cpu() + n_layers, n_loops = t.shape + + metrics[f"summary/{name}"] = ResultItem(t.mean(), 4) + + if not summary_only: + for i in range(n_layers): + metrics[f"layer_{i}/avg_{name}"] = ResultItem(t[i].mean(), 4) + for j in range(n_loops): + metrics[f"layer_{i}/{name}_{j}"] = ResultItem(t[i, j], 4) + + for j in range(n_loops): + metrics[f"loop_{j}/{name}"] = ResultItem(t[:, j].mean(), 4) + + for name, tensor in per_layer_histograms.items(): + if tensor.numel() == 0: + continue + t = tensor.float().cpu() + n_layers, n_bins = t.shape + + for b in range(n_bins): + metrics[f"hist/{name}/bin_{b}"] = ResultItem(t[:, b].mean(), 4) + + if not summary_only: + for i in range(n_layers): + for b in range(n_bins): + metrics[f"hist/{name}/layer_{i}/bin_{b}"] = ResultItem(t[i, b], 4) + + return losses, metrics diff --git a/tests/logging_broker/subscriber_impl/test_logging_rich.py b/tests/logging_broker/subscriber_impl/test_logging_rich.py new file mode 100644 index 000000000..89710fc2b --- /dev/null +++ b/tests/logging_broker/subscriber_impl/test_logging_rich.py @@ -0,0 +1,120 @@ +import torch +import pytest +from modalities.training.logging import MetricsAccumulator, format_metrics +from modalities.batch import ResultItem + + +def test_metrics_accumulator_accumulation_and_sync(): + device = torch.device("cpu") + accum = MetricsAccumulator() + + # Step 1: Accumulate first batch + metrics_1 = { + "scalars": { + "p_weight": torch.tensor(0.5) + }, + "per_layer_scalars": { + "cost": torch.tensor([1.0, 2.0]) + }, + "per_layer_vectors": { + "vec": torch.tensor([[0.1, 0.2], [0.3, 0.4]]) + }, + "per_layer_histograms": { + "hist": torch.tensor([[0.5, 0.5], [0.8, 0.2]]) + } + } + accum.accumulate({"ce_loss": torch.tensor(2.0), "metrics": metrics_1}) + + # Step 2: Accumulate second batch + metrics_2 = { + "scalars": { + "p_weight": torch.tensor(1.5) + }, + "per_layer_scalars": { + "cost": torch.tensor([3.0, 4.0]) + }, + "per_layer_vectors": { + # Vectors are last-batch-only in trainer design + "vec": torch.tensor([[1.1, 1.2], [1.3, 1.4]]) + }, + "per_layer_histograms": { + "hist": torch.tensor([[0.3, 0.7], [0.6, 0.4]]) + } + } + accum.accumulate({"ce_loss": torch.tensor(4.0), "metrics": metrics_2}) + + assert accum.count == 2 + + # Step 3: Build sync tensor + sync_tensor, scalar_names, pl_names, pl_sizes, hist_names, hist_shapes = accum.build_sync_tensor(device) + + # Expected averages: + # ce_loss = (2 + 4) / 2 = 3.0 + # scalars: p_weight = (0.5 + 1.5) / 2 = 1.0 + # per_layer_scalars: cost = [(1+3)/2, (2+4)/2] = [2.0, 3.0] + # per_layer_histograms: hist = [[(0.5+0.3)/2, (0.5+0.7)/2], [(0.8+0.6)/2, (0.2+0.4)/2]] = [[0.4, 0.6], [0.7, 0.3]] + + # Step 4: Unpack + loss, scalars, pl_scalars, pl_hist = MetricsAccumulator.unpack_synced_tensor( + sync_tensor, scalar_names, pl_names, pl_sizes, hist_names, hist_shapes + ) + + assert torch.allclose(loss, torch.tensor(3.0)) + assert torch.allclose(scalars["p_weight"], torch.tensor(1.0)) + assert torch.allclose(pl_scalars["cost"], torch.tensor([2.0, 3.0])) + assert torch.allclose(pl_hist["hist"], torch.tensor([[0.4, 0.6], [0.7, 0.3]])) + assert torch.allclose(accum.last_per_layer_vectors["vec"], torch.tensor([[1.1, 1.2], [1.3, 1.4]])) + + +def test_format_metrics(): + loss = torch.tensor(3.0) + scalars = {"p_weight": torch.tensor(1.0)} + pl_scalars = {"cost": torch.tensor([2.0, 3.0])} + pl_vectors = {"vec": torch.tensor([[1.1, 1.2], [1.3, 1.4]])} + pl_hists = {"hist": torch.tensor([[0.4, 0.6], [0.7, 0.3]])} + + # Test summary_only = False + losses, metrics = format_metrics( + loss=loss, + scalars=scalars, + per_layer_scalars=pl_scalars, + per_layer_vectors=pl_vectors, + summary_only=False, + per_layer_histograms=pl_hists + ) + + assert losses["loss/ce_avg"].value.item() == pytest.approx(3.0) + assert metrics["adaptive/p_weight"].value.item() == pytest.approx(1.0) + assert metrics["summary/cost"].value.item() == pytest.approx(2.5) + assert metrics["layer_0/cost"].value.item() == pytest.approx(2.0) + assert metrics["layer_1/cost"].value.item() == pytest.approx(3.0) + + # Vectors + assert metrics["summary/vec"].value.item() == pytest.approx(1.25) + assert metrics["layer_0/vec_0"].value.item() == pytest.approx(1.1) + assert metrics["layer_0/vec_1"].value.item() == pytest.approx(1.2) + assert metrics["layer_1/vec_0"].value.item() == pytest.approx(1.3) + assert metrics["layer_1/vec_1"].value.item() == pytest.approx(1.4) + + # Histograms + assert metrics["hist/hist/bin_0"].value.item() == pytest.approx(0.55) + assert metrics["hist/hist/bin_1"].value.item() == pytest.approx(0.45) + assert metrics["hist/hist/layer_0/bin_0"].value.item() == pytest.approx(0.4) + assert metrics["hist/hist/layer_1/bin_1"].value.item() == pytest.approx(0.3) + + # Test summary_only = True + losses, metrics = format_metrics( + loss=loss, + scalars=scalars, + per_layer_scalars=pl_scalars, + per_layer_vectors=pl_vectors, + summary_only=True, + per_layer_histograms=pl_hists + ) + + assert "layer_0/cost" not in metrics + assert "layer_0/vec_0" not in metrics + assert "hist/hist/layer_0/bin_0" not in metrics + assert metrics["summary/cost"].value.item() == pytest.approx(2.5) + assert metrics["summary/vec"].value.item() == pytest.approx(1.25) + assert metrics["hist/hist/bin_0"].value.item() == pytest.approx(0.55)