Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 31 additions & 6 deletions src/modalities/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,26 @@ def __len__(self) -> int:
return self.samples[key].shape[self.batch_dim]


def _apply_to(val, device):
if isinstance(val, torch.Tensor):
return val.to(device)
elif isinstance(val, dict):
return {k: _apply_to(v, device) for k, v in val.items()}
elif isinstance(val, list):
return [_apply_to(v, device) for v in val]
return val


def _apply_detach(val):
if isinstance(val, torch.Tensor):
return val.detach()
elif isinstance(val, dict):
return {k: _apply_detach(v) for k, v in val.items()}
elif isinstance(val, list):
return [_apply_detach(v) for v in val]
return val


@dataclass
class InferenceResultBatch(Batch, TorchDeviceMixin):
"""Stores targets and predictions of an entire batch."""
Expand All @@ -71,12 +91,12 @@ def device(self) -> torch.device:
return self.targets[key].device

def to(self, device: torch.device):
self.predictions = {k: v.to(device) for k, v in self.predictions.items()}
self.targets = {k: v.to(device) for k, v in self.targets.items()}
self.predictions = {k: _apply_to(v, device) for k, v in self.predictions.items()}
self.targets = {k: _apply_to(v, device) for k, v in self.targets.items()}

def detach(self):
self.targets = {k: v.detach() for k, v in self.targets.items()}
self.predictions = {k: v.detach() for k, v in self.predictions.items()}
self.targets = {k: _apply_detach(v) for k, v in self.targets.items()}
self.predictions = {k: _apply_detach(v) for k, v in self.predictions.items()}

def get_predictions(self, key: str) -> torch.Tensor:
if key not in self.predictions:
Expand All @@ -89,8 +109,13 @@ def get_targets(self, key: str) -> torch.Tensor:
return self.targets[key]

def __len__(self) -> int:
key = list(self.predictions.keys())[0]
return self.predictions[key].shape[self.batch_dim]
for v in self.predictions.values():
if isinstance(v, torch.Tensor):
return v.shape[self.batch_dim]
for v in self.targets.values():
if isinstance(v, torch.Tensor):
return v.shape[self.batch_dim]
raise ValueError("No tensor found in predictions or targets to determine batch length")


@dataclass
Expand Down
74 changes: 50 additions & 24 deletions src/modalities/loss_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,60 +31,86 @@ def __init__(self, target_key: str, prediction_key: str, tag: str = "CLMCrossEnt
self.prediction_key = prediction_key
# Mean over the tokens in the local-batch (batch per rank)
self.loss_fun = CrossEntropyLoss(reduction="mean")
self._last_ce_loss = torch.tensor(0.0)
self._last_metrics = None

@overload
def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor:
...

@overload
def __call__(self, outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
def __call__(self, outputs: torch.Tensor | dict, targets: torch.Tensor) -> torch.Tensor:
...

def __call__(self, *args, **kwargs) -> torch.Tensor:
labels, lm_logits = self._parse_arguments(args, kwargs)
labels, outputs = self._parse_arguments(args, kwargs)

if isinstance(outputs, dict):
if "logits" in outputs:
lm_logits = outputs["logits"]
else:
lm_logits = outputs[self.prediction_key]
metrics = outputs.get("metrics", None)
else:
lm_logits = outputs
metrics = None

# move labels to correct device to enable model parallelism
labels = labels.to(lm_logits.device)
shift_logits = lm_logits.contiguous()
shift_labels = labels.contiguous().long()
# Flatten the tokens. We compute here, the loss per token.
loss = self.loss_fun(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

self._last_ce_loss = loss.detach()
self._last_metrics = _detach_metrics(metrics)
return loss

def get_metrics(self) -> dict:
return {
"ce_loss": self._last_ce_loss,
"metrics": self._last_metrics,
}

def _parse_arguments(
self,
args: list[torch.Tensor] | list[InferenceResultBatch],
kwargs: dict[str, torch.Tensor] | dict[str, InferenceResultBatch],
) -> tuple[torch.Tensor, torch.Tensor]:
args: list[torch.Tensor | dict] | list[InferenceResultBatch],
kwargs: dict[str, torch.Tensor | dict] | dict[str, InferenceResultBatch],
) -> tuple[torch.Tensor, torch.Tensor | dict]:
if len(args) == 1 and isinstance(args[0], InferenceResultBatch):
forward_batch = args[0]
labels = forward_batch.get_targets(self.target_key)
lm_logits = forward_batch.get_predictions(self.prediction_key)
outputs = forward_batch.predictions
elif "forward_batch" in kwargs and isinstance(kwargs["forward_batch"], InferenceResultBatch):
forward_batch = kwargs["forward_batch"]
labels = forward_batch.get_targets(self.target_key)
lm_logits = forward_batch.get_predictions(self.prediction_key)
elif len(args) == 2 and all(isinstance(arg, torch.Tensor) for arg in args):
lm_logits, labels = args
elif (
"outputs" in kwargs
and "targets" in kwargs
and isinstance(kwargs["outputs"], torch.Tensor)
and isinstance(kwargs["targets"], torch.Tensor)
):
lm_logits = kwargs["outputs"]
outputs = forward_batch.predictions
elif len(args) == 2:
outputs, labels = args
elif "outputs" in kwargs and "targets" in kwargs:
outputs = kwargs["outputs"]
labels = kwargs["targets"]
elif (
len(args) == 1
and "targets" in kwargs
and isinstance(args[0], torch.Tensor)
and isinstance(kwargs["targets"], torch.Tensor)
):
lm_logits = args[0]
elif len(args) == 1 and "targets" in kwargs:
outputs = args[0]
labels = kwargs["targets"]
else:
raise TypeError("Invalid arguments for CLMCrossEntropyLoss.__call__")
return labels, lm_logits
return labels, outputs


def _detach_metrics(metrics: dict | None) -> dict | None:
"""Recursively detach all tensors in a nested dict."""
if metrics is None:
return None
result = {}
for key, value in metrics.items():
if isinstance(value, torch.Tensor):
result[key] = value.detach()
elif isinstance(value, dict):
result[key] = _detach_metrics(value)
else:
result[key] = value
return result


def nce_loss(
Expand Down
47 changes: 36 additions & 11 deletions src/modalities/models/gpt2/gpt2_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -939,7 +939,7 @@ def __init__(
) # https://paperswithcode.com/method/weight-tying

@overload
def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor | dict]:
"""
Forward pass of the GPT2LLM module.

Expand All @@ -948,8 +948,8 @@ def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
- sample_key (str): Key for the input tensor containing token ids.

Returns:
dict[str, torch.Tensor]: A dictionary containing output tensors.
- prediction_key (str): Key for the output tensor containing logits.
dict[str, torch.Tensor | dict]: A dictionary containing output tensors and metrics.
- prediction_key (str): Key for the output containing logits and metrics dict.
"""
...

Expand All @@ -966,30 +966,35 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
"""
...

def forward(self, inputs: dict[str, torch.Tensor] | torch.Tensor) -> dict[str, torch.Tensor] | torch.Tensor:
def forward(self, inputs: dict[str, torch.Tensor] | torch.Tensor) -> dict[str, torch.Tensor | dict] | torch.Tensor:
"""
Forward pass of the GPT2LLM module.

Args:
inputs (dict[str, torch.Tensor] | torch.Tensor): Input data.

Returns:
dict[str, torch.Tensor] | torch.Tensor: Model output.
dict[str, torch.Tensor | dict] | torch.Tensor: Model output.
"""
if isinstance(inputs, dict):
return {self.prediction_key: self.forward_impl(inputs[self.sample_key])}
logits, metrics = self.forward_impl(inputs[self.sample_key])
return {
self.prediction_key: logits,
"metrics": metrics,
}
else:
return self.forward_impl(inputs)
logits, _ = self.forward_impl(inputs)
return logits

def forward_impl(self, inputs: torch.Tensor) -> torch.Tensor:
def forward_impl(self, inputs: torch.Tensor) -> tuple[torch.Tensor, dict]:
"""
Forward pass implementation of the GPT2LLM module.

Args:
inputs (torch.Tensor): A tensor containing input token ids.

Returns:
torch.Tensor: A tensor containing output logits.
tuple[torch.Tensor, dict]: A tuple containing output logits and custom metrics.
"""
device = inputs.device
seq_len = inputs.size(1)
Expand All @@ -1009,11 +1014,31 @@ def forward_impl(self, inputs: torch.Tensor) -> torch.Tensor:
# TODO: use drop out also without absolute position embedding?
h = self.transformer.drop(h) if hasattr(self.transformer, "drop") else h

layer_norms = []
for layer_idx in self.transformer.h:
h = self.transformer.h[layer_idx](h)
layer_norms.append(h.detach().norm(dim=-1).mean())

h = self.transformer.lm_head_norm(h) if hasattr(self.transformer, "lm_head_norm") else h
h = self.transformer.lm_head(h) if hasattr(self.transformer, "lm_head") else h
return h
logits = self.transformer.lm_head(h) if hasattr(self.transformer, "lm_head") else h

with torch.no_grad():
log_p = torch.log_softmax(logits, dim=-1)
p = torch.exp(log_p)
entropy = -torch.sum(p * log_p, dim=-1).mean()
layer_norms_tensor = torch.stack(layer_norms)

metrics = {
"scalars": {
"logits_entropy": entropy
},
"per_layer_scalars": {
"layer_activation_norm": layer_norms_tensor
}
}

return logits, metrics



def manual_scaled_dot_product_attention(
Expand Down
48 changes: 46 additions & 2 deletions src/modalities/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ class ThroughputAggregationKeys(Enum):
FORWARD_BACKWARD_TIME = "FORWARD_BACKWARD_TIME"



from modalities.training.logging import MetricsAccumulator, format_metrics


class Trainer:
def __init__(
self,
Expand Down Expand Up @@ -150,7 +154,7 @@ def _train_batch(
operate the model. Defaults to None.

Returns:
tuple[bool, int, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
tuple[bool, int, torch.Tensor, Optional[torch.Tensor]]:
A tuple containing the following:
- step_performed (bool): Indicates whether a training step was performed.
- num_train_steps_done (int): The number of training steps done.
Expand Down Expand Up @@ -234,6 +238,7 @@ def train(

local_num_seen_samples = 0
cumulated_losses = torch.zeros(3).cuda()
metrics_accum = MetricsAccumulator()

# throughput
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Expand Down Expand Up @@ -291,6 +296,9 @@ def train(
# it has less samples than the batch size
cumulated_losses[-1] += 1 # number of local batches

if hasattr(loss_fun, "get_metrics"):
metrics_accum.accumulate(loss_fun.get_metrics())

# gradient norm is already synced across all ranks
if gradient_norm_score is not None:
gradient_norm_scores.append(gradient_norm_score.item())
Expand Down Expand Up @@ -336,17 +344,52 @@ def train(
reduced_losses[0],
reduced_losses[1],
)

adaptive_losses = {}
adaptive_metrics = {}
if metrics_accum.count > 0:
(
sync_tensor, scalar_names, per_layer_names, per_layer_sizes,
hist_names, hist_shapes,
) = metrics_accum.build_sync_tensor(device)

reduce_scale = dist.get_world_size() / self.pp_degree
synced_tensor = Reducer.reduce(
tensor=sync_tensor,
operation=dist.ReduceOp.SUM,
post_processing_fun=lambda t: t / reduce_scale,
)

(
synced_loss, synced_scalars, synced_per_layer, synced_hists,
) = MetricsAccumulator.unpack_synced_tensor(
synced_tensor, scalar_names, per_layer_names, per_layer_sizes,
hist_names, hist_shapes,
)

adaptive_losses, adaptive_metrics = format_metrics(
loss=synced_loss,
scalars=synced_scalars,
per_layer_scalars=synced_per_layer,
per_layer_vectors=metrics_accum.last_per_layer_vectors,
per_layer_histograms=synced_hists,
)

losses = {
"train loss avg": ResultItem(train_loss_avg, decimal_places=2),
"train loss last": ResultItem(train_loss_last_batch, decimal_places=2),
**adaptive_losses,
}

metrics = {
"consumed tokens": ResultItem(torch.tensor(training_progress.num_seen_tokens_total), 0),
"grad norm avg": ResultItem(torch.mean(torch.Tensor(gradient_norm_scores)), 2),
"grad norm last": ResultItem(torch.tensor(gradient_norm_scores[-1]), 2),
**adaptive_metrics,
}

gradient_norm_scores = []

mfu_score = torch.tensor(-1.0)
if self.mfu_calculator is not None:
mfu_score = self.mfu_calculator.compute(num_samples_per_second=global_num_samples_per_second)
Expand Down Expand Up @@ -384,11 +427,12 @@ def train(
)

cumulated_losses.zero_()
metrics_accum.reset()
if step_performed:
self.gc.run(step_count=training_progress.num_seen_steps_total)
evaluation_callback(num_train_steps_done=training_progress.num_seen_steps_total)
checkpointing_callback(training_progress=training_progress)

profiler_cm.step()

@staticmethod
Expand Down
Loading