Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/modalities/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,19 @@ class AdamWOptimizerConfig(BaseModel):
fused: bool | None = None


class MuonOptimizerConfig(BaseModel):
    """Configuration schema for the ``muon`` optimizer component.

    The field names match the keyword arguments of ``OptimizerFactory.get_muon``,
    which this config is registered together with.
    """

    # Required hyperparameters: no default is provided, so they must be configured.
    lr: float
    weight_decay: float

    # Optional hyperparameters with defaults.
    momentum: float = 0.95
    nesterov: bool = True
    ns_coefficients: tuple[float, float, float] = (3.4445, -4.775, 2.0315)
    eps: float = 1e-7
    ns_steps: int = 5
    # NOTE(review): any string is accepted here; presumably only the strategy names
    # understood by torch.optim.Muon are meaningful -- confirm and consider validating.
    adjust_lr_fn: str | None = None

    # Names of the parameter groups that are excluded from weight decay (required).
    weight_decay_groups_excluded: list[str]
    # Model (or list of models) whose parameters are handed to the optimizer (required).
    wrapped_model: PydanticPytorchModuleOrListType


class DummyLRSchedulerConfig(BaseModel):
optimizer: PydanticOptimizerIFType

Expand Down
31 changes: 29 additions & 2 deletions src/modalities/optimizers/optimizer_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP1
from torch.distributed.tensor import DTensor
from torch.optim import Adam, AdamW, Optimizer
from torch.optim import Adam, AdamW, Muon, Optimizer

from modalities.checkpointing.checkpoint_loading import FSDP1CheckpointLoadingIF
from modalities.exceptions import OptimizerError
Expand Down Expand Up @@ -49,6 +49,33 @@ def get_adam_w(
optimizer = AdamW(params=optimizer_groups, lr=lr, betas=betas, eps=eps, foreach=foreach, fused=fused)
return optimizer

@staticmethod
def get_muon(
    lr: float,
    weight_decay: float,
    momentum: float,
    nesterov: bool,
    ns_coefficients: tuple[float, float, float],
    eps: float,
    ns_steps: int,
    adjust_lr_fn: str | None,
    weight_decay_groups_excluded: list[str],
    wrapped_model: nn.Module,
) -> Optimizer:
    """Create a ``Muon`` optimizer for the parameter groups of ``wrapped_model``.

    Args:
        lr: Learning rate forwarded to ``Muon``.
        weight_decay: Weight decay used to build the parameter groups and also
            forwarded to ``Muon``.
        momentum: Momentum forwarded to ``Muon``.
        nesterov: Nesterov flag forwarded to ``Muon``.
        ns_coefficients: Three coefficients forwarded to ``Muon`` as ``ns_coefficients``.
        eps: Epsilon forwarded to ``Muon``.
        ns_steps: Step count forwarded to ``Muon`` as ``ns_steps``.
        adjust_lr_fn: Optional learning-rate adjustment name forwarded to ``Muon``.
        weight_decay_groups_excluded: Group names passed to ``get_optimizer_groups``
            as the groups excluded from weight decay.
        wrapped_model: Model whose parameters are grouped and optimized.

    Returns:
        The constructed ``Muon`` optimizer.

    NOTE(review): torch.optim.Muon is documented as supporting 2-D parameters only.
    Confirm that the groups produced by ``get_optimizer_groups`` never contain
    parameters of another rank (e.g. biases or norm weights) for the models this
    is used with.
    """
    param_groups = get_optimizer_groups(wrapped_model, weight_decay, weight_decay_groups_excluded)
    # Every hyperparameter is passed by keyword, exactly as received.
    muon_kwargs = {
        "lr": lr,
        "weight_decay": weight_decay,
        "momentum": momentum,
        "nesterov": nesterov,
        "ns_coefficients": ns_coefficients,
        "eps": eps,
        "ns_steps": ns_steps,
        "adjust_lr_fn": adjust_lr_fn,
    }
    return Muon(params=param_groups, **muon_kwargs)

@staticmethod
def get_fsdp1_checkpointed_optimizer_(
checkpoint_loading: FSDP1CheckpointLoadingIF,
Expand Down Expand Up @@ -149,7 +176,7 @@ def _create_optimizer_groups(

else:
raise OptimizerError(
f"model {type(model)} is not an instance of FSDP1 or FSDP2. " "Please use the correct model type."
f"model {type(model)} is not an instance of FSDP1 or FSDP2. Please use the correct model type."
)

if (
Expand Down
4 changes: 4 additions & 0 deletions src/modalities/registry/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
LinearWarmupCosineAnnealingLRSchedulerConfig,
LLMDataLoaderConfig,
MemMapDatasetConfig,
MuonOptimizerConfig,
OneCycleLRSchedulerConfig,
PackedMemMapDatasetContinuousConfig,
PackedMemMapDatasetMegatronConfig,
Expand Down Expand Up @@ -257,6 +258,9 @@ class ComponentEntity:
ComponentEntity(
"optimizer", "adam_w", maybe_model_list_for_optimizer(OptimizerFactory.get_adam_w), AdamWOptimizerConfig
),
ComponentEntity(
"optimizer", "muon", maybe_model_list_for_optimizer(OptimizerFactory.get_muon), MuonOptimizerConfig
),
ComponentEntity(
"optimizer",
"fsdp1_checkpointed",
Expand Down
129 changes: 128 additions & 1 deletion tests/test_optimizer_factory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from pathlib import Path
from typing import Literal
from unittest.mock import MagicMock

import pytest
import torch
Expand All @@ -14,7 +15,8 @@
from modalities.models.coca.coca_model import CoCa, CoCaConfig
from modalities.models.gpt2.gpt2_model import GPT2LLM
from modalities.models.model_factory import ModelFactory
from modalities.optimizers.optimizer_factory import get_optimizer_groups
from modalities.optimizers import optimizer_factory as optimizer_factory_module
from modalities.optimizers.optimizer_factory import OptimizerFactory, get_optimizer_groups
from modalities.registry.components import COMPONENTS
from modalities.registry.registry import Registry
from modalities.running_env.env_utils import MixedPrecisionSettings
Expand Down Expand Up @@ -72,6 +74,131 @@ def test_get_optimizer_groups(
)


def test_get_adam_builds_optimizer_from_groups(monkeypatch):
    """Verify that ``OptimizerFactory.get_adam`` hands grouped params and hyperparameters to ``Adam``."""
    model_stub = MagicMock()
    built_optimizer = MagicMock()
    expected_groups = [{"params": [MagicMock()], "weight_decay": 0.1}]
    groups_spy = MagicMock(return_value=expected_groups)
    adam_spy = MagicMock(return_value=built_optimizer)

    # Replace the names as they are looked up inside the factory module, so that
    # neither a real model nor a real optimizer is needed.
    for attr_name, replacement in (("get_optimizer_groups", groups_spy), ("Adam", adam_spy)):
        monkeypatch.setattr(optimizer_factory_module, attr_name, replacement)

    # Hyperparameters that must reach the Adam constructor unchanged.
    forwarded = {
        "lr": 1e-3,
        "betas": (0.9, 0.95),
        "eps": 1e-8,
        "foreach": True,
        "fused": False,
    }

    result = OptimizerFactory.get_adam(
        weight_decay=0.1,
        weight_decay_groups_excluded=["embedding"],
        wrapped_model=model_stub,
        **forwarded,
    )

    assert result is built_optimizer
    groups_spy.assert_called_once_with(model_stub, 0.1, ["embedding"])
    adam_spy.assert_called_once_with(params=expected_groups, **forwarded)


def test_get_adam_w_builds_optimizer_from_groups(monkeypatch):
    """Verify that ``OptimizerFactory.get_adam_w`` hands grouped params and hyperparameters to ``AdamW``."""
    decay = 0.2
    excluded_groups = ["layernorm"]
    betas = (0.8, 0.99)

    fake_model = MagicMock()
    fake_optimizer = MagicMock()
    fake_groups = [{"params": [MagicMock()], "weight_decay": decay}]
    grouping_mock = MagicMock(return_value=fake_groups)
    adam_w_ctor_mock = MagicMock(return_value=fake_optimizer)

    # Patch the grouping helper and the optimizer class inside the factory module.
    monkeypatch.setattr(optimizer_factory_module, "get_optimizer_groups", grouping_mock)
    monkeypatch.setattr(optimizer_factory_module, "AdamW", adam_w_ctor_mock)

    result = OptimizerFactory.get_adam_w(
        lr=2e-4,
        betas=betas,
        eps=1e-6,
        weight_decay=decay,
        weight_decay_groups_excluded=excluded_groups,
        wrapped_model=fake_model,
        foreach=False,
        fused=True,
    )

    # The factory must return exactly what the (mocked) constructor produced.
    assert result is fake_optimizer
    grouping_mock.assert_called_once_with(fake_model, 0.2, ["layernorm"])
    adam_w_ctor_mock.assert_called_once_with(
        params=fake_groups, lr=2e-4, betas=(0.8, 0.99), eps=1e-6, foreach=False, fused=True
    )


def test_get_muon_builds_optimizer_from_groups(monkeypatch):
    """Verify that ``OptimizerFactory.get_muon`` hands grouped params and hyperparameters to ``Muon``.

    ``Muon`` and ``get_optimizer_groups`` are replaced by mocks, so only the
    forwarding of arguments is checked; no real optimizer is constructed.
    """
    wrapped_model = MagicMock()
    optimizer_groups = [{"params": [MagicMock()], "weight_decay": 0.05}]
    optimizer = MagicMock()
    get_optimizer_groups_mock = MagicMock(return_value=optimizer_groups)
    muon_mock = MagicMock(return_value=optimizer)

    monkeypatch.setattr(optimizer_factory_module, "get_optimizer_groups", get_optimizer_groups_mock)
    monkeypatch.setattr(optimizer_factory_module, "Muon", muon_mock)

    # Use a value the real torch.optim.Muon accepts ("original" or "match_rms_adamw"),
    # so the test does not advertise a configuration that fails without the mock.
    adjust_lr_fn = "match_rms_adamw"

    result = OptimizerFactory.get_muon(
        lr=3e-4,
        weight_decay=0.05,
        momentum=0.95,
        nesterov=True,
        ns_coefficients=(1.0, 0.5, 0.25),
        eps=1e-9,
        ns_steps=7,
        adjust_lr_fn=adjust_lr_fn,
        weight_decay_groups_excluded=["embedding"],
        wrapped_model=wrapped_model,
    )

    assert result is optimizer
    get_optimizer_groups_mock.assert_called_once_with(wrapped_model, 0.05, ["embedding"])
    muon_mock.assert_called_once_with(
        params=optimizer_groups,
        lr=3e-4,
        weight_decay=0.05,
        momentum=0.95,
        nesterov=True,
        ns_coefficients=(1.0, 0.5, 0.25),
        eps=1e-9,
        ns_steps=7,
        adjust_lr_fn=adjust_lr_fn,
    )


def test_get_fsdp1_checkpointed_optimizer_loads_optimizer_state():
    """Verify that the FSDP1 checkpointed-optimizer factory delegates to the checkpoint loader.

    The factory is expected to call ``load_optimizer_checkpoint_`` once with the
    given path, optimizer and model, and to return the optimizer it was given.
    """
    loader_stub = MagicMock()
    model_stub = MagicMock()
    optimizer_stub = MagicMock()
    ckpt_path = Path("/tmp/checkpoint")

    returned = OptimizerFactory.get_fsdp1_checkpointed_optimizer_(
        checkpoint_loading=loader_stub,
        checkpoint_path=ckpt_path,
        wrapped_model=model_stub,
        optimizer=optimizer_stub,
    )

    assert returned is optimizer_stub
    loader_stub.load_optimizer_checkpoint_.assert_called_once_with(
        file_path=ckpt_path, optimizer=optimizer_stub, model=model_stub
    )


def _run_single_optimizer_group_case(
process_id: int,
world_size: int,
Expand Down
Loading