diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py
index 42a19b99a..16ce440aa 100644
--- a/src/modalities/config/config.py
+++ b/src/modalities/config/config.py
@@ -167,6 +167,22 @@ class AdamWOptimizerConfig(BaseModel):
     fused: bool | None = None
 
 
+class MuonOptimizerConfig(BaseModel):
+    """Config of the "muon" optimizer component; the fields mirror the arguments of OptimizerFactory.get_muon."""
+
+    lr: float
+    weight_decay: float
+    momentum: float = 0.95
+    nesterov: bool = True
+    ns_coefficients: tuple[float, float, float] = (3.4445, -4.775, 2.0315)
+    eps: float = 1e-07
+    ns_steps: int = Field(default=5)
+    # Only the values documented for torch.optim.Muon are accepted, so a typo fails at config validation.
+    adjust_lr_fn: str | None = Field(default=None, pattern=r"^(original|match_rms_adamw)$")
+    weight_decay_groups_excluded: list[str]
+    wrapped_model: PydanticPytorchModuleOrListType
+
+
 class DummyLRSchedulerConfig(BaseModel):
     optimizer: PydanticOptimizerIFType
 
diff --git a/src/modalities/optimizers/optimizer_factory.py b/src/modalities/optimizers/optimizer_factory.py
index 2eebe89e1..72a1959b3 100644
--- a/src/modalities/optimizers/optimizer_factory.py
+++ b/src/modalities/optimizers/optimizer_factory.py
@@ -6,7 +6,7 @@
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP1
 from torch.distributed.tensor import DTensor
-from torch.optim import Adam, AdamW, Optimizer
+from torch.optim import Adam, AdamW, Muon, Optimizer
 
 from modalities.checkpointing.checkpoint_loading import FSDP1CheckpointLoadingIF
 from modalities.exceptions import OptimizerError
@@ -49,6 +49,59 @@ def get_adam_w(
         optimizer = AdamW(params=optimizer_groups, lr=lr, betas=betas, eps=eps, foreach=foreach, fused=fused)
         return optimizer
 
+    @staticmethod
+    def get_muon(
+        lr: float,
+        weight_decay: float,
+        momentum: float,
+        nesterov: bool,
+        ns_coefficients: tuple[float, float, float],
+        eps: float,
+        ns_steps: int,
+        adjust_lr_fn: str | None,
+        weight_decay_groups_excluded: list[str],
+        wrapped_model: nn.Module,
+    ) -> Optimizer:
+        """Creates a torch.optim.Muon optimizer on top of the parameter groups of the model.
+
+        The parameter groups are built by get_optimizer_groups from wrapped_model, weight_decay and
+        weight_decay_groups_excluded. weight_decay is additionally passed to Muon as optimizer-wide default.
+
+        NOTE(review): The PyTorch documentation describes Muon as an optimizer for the 2D parameters of hidden
+        layers. Whether the groups built for wrapped_model contain other parameters (e.g., biases or norm
+        weights) is not visible here -- confirm before using this on a full model.
+
+        Args:
+            lr (float): Learning rate.
+            weight_decay (float): Weight decay coefficient.
+            momentum (float): Momentum factor.
+            nesterov (bool): Whether to use Nesterov momentum.
+            ns_coefficients (tuple[float, float, float]): Coefficients of the Newton-Schulz iteration.
+            eps (float): Term added for numerical stability.
+            ns_steps (int): Number of Newton-Schulz iteration steps.
+            adjust_lr_fn (str | None): Learning rate adjustment, "original" or "match_rms_adamw".
+                None leaves the choice to Muon's default.
+            weight_decay_groups_excluded (list[str]): Forwarded to get_optimizer_groups; presumably the names of
+                the parameter groups that are exempt from weight decay.
+            wrapped_model (nn.Module): Model whose parameters are optimized.
+
+        Returns:
+            Optimizer: The Muon optimizer.
+        """
+        optimizer_groups = get_optimizer_groups(wrapped_model, weight_decay, weight_decay_groups_excluded)
+        optimizer = Muon(
+            params=optimizer_groups,
+            lr=lr,
+            weight_decay=weight_decay,
+            momentum=momentum,
+            nesterov=nesterov,
+            ns_coefficients=ns_coefficients,
+            eps=eps,
+            ns_steps=ns_steps,
+            adjust_lr_fn=adjust_lr_fn,
+        )
+        return optimizer
+
     @staticmethod
     def get_fsdp1_checkpointed_optimizer_(
         checkpoint_loading: FSDP1CheckpointLoadingIF,
@@ -149,7 +202,7 @@ def _create_optimizer_groups(
 
     else:
         raise OptimizerError(
-            f"model {type(model)} is not an instance of FSDP1 or FSDP2. " "Please use the correct model type."
+            f"model {type(model)} is not an instance of FSDP1 or FSDP2. Please use the correct model type."
         )
 
     if (
diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py
index 26df9b432..1fe61967e 100644
--- a/src/modalities/registry/components.py
+++ b/src/modalities/registry/components.py
@@ -51,6 +51,7 @@
     LinearWarmupCosineAnnealingLRSchedulerConfig,
     LLMDataLoaderConfig,
     MemMapDatasetConfig,
+    MuonOptimizerConfig,
     OneCycleLRSchedulerConfig,
     PackedMemMapDatasetContinuousConfig,
     PackedMemMapDatasetMegatronConfig,
@@ -257,6 +258,9 @@ class ComponentEntity:
     ComponentEntity(
         "optimizer", "adam_w", maybe_model_list_for_optimizer(OptimizerFactory.get_adam_w), AdamWOptimizerConfig
     ),
+    ComponentEntity(
+        "optimizer", "muon", maybe_model_list_for_optimizer(OptimizerFactory.get_muon), MuonOptimizerConfig
+    ),
     ComponentEntity(
         "optimizer",
         "fsdp1_checkpointed",
diff --git a/tests/test_optimizer_factory.py b/tests/test_optimizer_factory.py
index d6c9a7b37..7d45870be 100644
--- a/tests/test_optimizer_factory.py
+++ b/tests/test_optimizer_factory.py
@@ -1,5 +1,6 @@
 from pathlib import Path
 from typing import Literal
+from unittest.mock import MagicMock
 
 import pytest
 import torch
@@ -14,7 +15,8 @@
 from modalities.models.coca.coca_model import CoCa, CoCaConfig
 from modalities.models.gpt2.gpt2_model import GPT2LLM
 from modalities.models.model_factory import ModelFactory
-from modalities.optimizers.optimizer_factory import get_optimizer_groups
+from modalities.optimizers import optimizer_factory as optimizer_factory_module
+from modalities.optimizers.optimizer_factory import OptimizerFactory, get_optimizer_groups
 from modalities.registry.components import COMPONENTS
 from modalities.registry.registry import Registry
 from modalities.running_env.env_utils import MixedPrecisionSettings
@@ -72,6 +74,131 @@ def test_get_optimizer_groups(
     )
 
 
+def test_get_adam_builds_optimizer_from_groups(monkeypatch):
+    wrapped_model = MagicMock()
+    optimizer_groups = [{"params": [MagicMock()], "weight_decay": 0.1}]
+    optimizer = MagicMock()
+    get_optimizer_groups_mock = MagicMock(return_value=optimizer_groups)
+    adam_mock = MagicMock(return_value=optimizer)
+
+    monkeypatch.setattr(optimizer_factory_module, "get_optimizer_groups", get_optimizer_groups_mock)
+    monkeypatch.setattr(optimizer_factory_module, "Adam", adam_mock)
+
+    result = OptimizerFactory.get_adam(
+        lr=1e-3,
+        betas=(0.9, 0.95),
+        eps=1e-8,
+        weight_decay=0.1,
+        weight_decay_groups_excluded=["embedding"],
+        wrapped_model=wrapped_model,
+        foreach=True,
+        fused=False,
+    )
+
+    assert result is optimizer
+    get_optimizer_groups_mock.assert_called_once_with(wrapped_model, 0.1, ["embedding"])
+    adam_mock.assert_called_once_with(
+        params=optimizer_groups,
+        lr=1e-3,
+        betas=(0.9, 0.95),
+        eps=1e-8,
+        foreach=True,
+        fused=False,
+    )
+
+
+def test_get_adam_w_builds_optimizer_from_groups(monkeypatch):
+    wrapped_model = MagicMock()
+    optimizer_groups = [{"params": [MagicMock()], "weight_decay": 0.2}]
+    optimizer = MagicMock()
+    get_optimizer_groups_mock = MagicMock(return_value=optimizer_groups)
+    adam_w_mock = MagicMock(return_value=optimizer)
+
+    monkeypatch.setattr(optimizer_factory_module, "get_optimizer_groups", get_optimizer_groups_mock)
+    monkeypatch.setattr(optimizer_factory_module, "AdamW", adam_w_mock)
+
+    result = OptimizerFactory.get_adam_w(
+        lr=2e-4,
+        betas=(0.8, 0.99),
+        eps=1e-6,
+        weight_decay=0.2,
+        weight_decay_groups_excluded=["layernorm"],
+        wrapped_model=wrapped_model,
+        foreach=False,
+        fused=True,
+    )
+
+    assert result is optimizer
+    get_optimizer_groups_mock.assert_called_once_with(wrapped_model, 0.2, ["layernorm"])
+    adam_w_mock.assert_called_once_with(
+        params=optimizer_groups,
+        lr=2e-4,
+        betas=(0.8, 0.99),
+        eps=1e-6,
+        foreach=False,
+        fused=True,
+    )
+
+
+def test_get_muon_builds_optimizer_from_groups(monkeypatch):
+    wrapped_model = MagicMock()
+    optimizer_groups = [{"params": [MagicMock()], "weight_decay": 0.05}]
+    optimizer = MagicMock()
+    get_optimizer_groups_mock = MagicMock(return_value=optimizer_groups)
+    muon_mock = MagicMock(return_value=optimizer)
+
+    monkeypatch.setattr(optimizer_factory_module, "get_optimizer_groups", get_optimizer_groups_mock)
+    monkeypatch.setattr(optimizer_factory_module, "Muon", muon_mock)
+
+    result = OptimizerFactory.get_muon(
+        lr=3e-4,
+        weight_decay=0.05,
+        momentum=0.95,
+        nesterov=True,
+        ns_coefficients=(1.0, 0.5, 0.25),
+        eps=1e-9,
+        ns_steps=7,
+        adjust_lr_fn="match_rms_adamw",
+        weight_decay_groups_excluded=["embedding"],
+        wrapped_model=wrapped_model,
+    )
+
+    assert result is optimizer
+    get_optimizer_groups_mock.assert_called_once_with(wrapped_model, 0.05, ["embedding"])
+    muon_mock.assert_called_once_with(
+        params=optimizer_groups,
+        lr=3e-4,
+        weight_decay=0.05,
+        momentum=0.95,
+        nesterov=True,
+        ns_coefficients=(1.0, 0.5, 0.25),
+        eps=1e-9,
+        ns_steps=7,
+        adjust_lr_fn="match_rms_adamw",
+    )
+
+
+def test_get_fsdp1_checkpointed_optimizer_loads_optimizer_state():
+    checkpoint_loading = MagicMock()
+    checkpoint_path = Path("/tmp/checkpoint")
+    wrapped_model = MagicMock()
+    optimizer = MagicMock()
+
+    result = OptimizerFactory.get_fsdp1_checkpointed_optimizer_(
+        checkpoint_loading=checkpoint_loading,
+        checkpoint_path=checkpoint_path,
+        wrapped_model=wrapped_model,
+        optimizer=optimizer,
+    )
+
+    assert result is optimizer
+    checkpoint_loading.load_optimizer_checkpoint_.assert_called_once_with(
+        file_path=checkpoint_path,
+        optimizer=optimizer,
+        model=wrapped_model,
+    )
+
+
 def _run_single_optimizer_group_case(
     process_id: int,
     world_size: int,