From 6524b24f3d7a898e5f188f9d240f9375987f66fd Mon Sep 17 00:00:00 2001 From: rrutmann Date: Mon, 11 May 2026 12:49:48 +0000 Subject: [PATCH 1/9] feat: Implement context extension with yarn Co-authored-by: Copilot --- src/modalities/models/gpt2/gpt2_model.py | 167 +++++++++++++++++++++-- 1 file changed, 159 insertions(+), 8 deletions(-) diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py index 2da4979c0..7a972b577 100644 --- a/src/modalities/models/gpt2/gpt2_model.py +++ b/src/modalities/models/gpt2/gpt2_model.py @@ -120,7 +120,15 @@ class RotaryTransform(QueryKeyValueTransform): XFormers implementation and removed in this implementation.# """ - def __init__(self, n_embd: int, n_head: int, seq_length_dim: int = -2, base_freq: int = 10000): + def __init__( + self, + n_embd: int, + n_head: int, + seq_length_dim: int = -2, + base_freq: int = 10000, + max_position_embeddings: int | None = None, + rope_scaling: dict[str, object] | None = None, + ): """ Initializes the RotaryTransform object. @@ -136,16 +144,114 @@ def __init__(self, n_embd: int, n_head: int, seq_length_dim: int = -2, base_freq self.dim_model = n_embd // n_head self.seq_length_dim = seq_length_dim self.base_freq = base_freq + self.max_position_embeddings = max_position_embeddings + + self.rope_scaling = rope_scaling + self.attention_scaling = 1.0 self.reset_parameters() + def _compute_yarn_parameters(self, device: torch.device | None) -> tuple[torch.Tensor, float]: + if self.rope_scaling is None: + raise ValueError("YaRN requires a rope_scaling config.") + if self.max_position_embeddings is None: + raise ValueError("YaRN requires max_position_embeddings to be set.") + + original_max_position_embeddings = self.rope_scaling.get("original_max_position_embeddings") + if ( + original_max_position_embeddings is None + or not isinstance(original_max_position_embeddings, int) + or original_max_position_embeddings <= 0 + ): + raise ValueError("YaRN requires original_max_position_embeddings to be a positive integer") + + factor = self.rope_scaling.get("factor") + if factor is None: + factor = self.max_position_embeddings / original_max_position_embeddings + if not isinstance(factor, (int, float)) or factor < 1.0: + raise ValueError("YaRN requires rope_scaling.factor to be a float >= 1.0") + factor_float = float(factor) + + attention_factor = self.rope_scaling.get("attention_factor") + mscale = self.rope_scaling.get("mscale") + mscale_all_dim = self.rope_scaling.get("mscale_all_dim") + beta_fast = self.rope_scaling.get("beta_fast") or 32 + beta_slow = self.rope_scaling.get("beta_slow") or 1 + truncate = self.rope_scaling.get("truncate", True) + + def get_mscale(scale: float, mscale: float = 1.0) -> float: + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + if attention_factor is None: + if isinstance(mscale, (int, float)) and isinstance(mscale_all_dim, (int, float)): + attention_factor = float( + get_mscale(factor_float, float(mscale)) / get_mscale(factor_float, float(mscale_all_dim)) + ) + else: + attention_factor = get_mscale(factor_float) + elif not isinstance(attention_factor, (int, float)) or attention_factor <= 0: + raise ValueError("YaRN requires rope_scaling.attention_factor to be a float > 0") + + def find_correction_dim(num_rotations, dim, base, max_position_embeddings): + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) + + def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings, truncate): + low = find_correction_dim(low_rot, dim, base, max_position_embeddings) + high = find_correction_dim(high_rot, dim, base, max_position_embeddings) + if truncate: + low = math.floor(low) + high = math.ceil(high) + return max(low, 0), min(high, dim - 1) + + def linear_ramp_factor(min_value, max_value, dim): + if min_value == max_value: + max_value += 0.001 + linear_func = (torch.arange(dim, dtype=torch.float32, device=device) - min_value) / (max_value - min_value) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + dim = self.dim_model + base = self.base_freq + + pos_freqs = base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float) / dim) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (factor_float * pos_freqs) + + low, high = find_correction_range( + beta_fast, + beta_slow, + dim, + base, + original_max_position_embeddings, + bool(truncate), + ) + inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).to(device=device, dtype=torch.float) + inv_freq = ( + inv_freq_interpolation * (1 - inv_freq_extrapolation_factor) + + inv_freq_extrapolation * inv_freq_extrapolation_factor + ) + + return inv_freq, float(attention_factor) + def reset_parameters(self): # If previously initialized on or moved to a device, reuse that device. # Otherwise, use the default device of the current environment. - device = self.inv_freq.device if hasattr(self, "inv_freq") else None - inv_freq = 1.0 / ( - self.base_freq ** (torch.arange(0, self.dim_model, 2, device=device).float() / self.dim_model) - ) + device = self.inv_freq.device if hasattr(self, "inv_freq") and isinstance(self.inv_freq, torch.Tensor) else None + + rope_type = "default" + if self.rope_scaling is not None: + rope_type = str(self.rope_scaling.get("rope_type", "default")) + + if rope_type == "yarn": + inv_freq, self.attention_scaling = self._compute_yarn_parameters(device=device) + else: + inv_freq = 1.0 / ( + self.base_freq ** (torch.arange(0, self.dim_model, 2, device=device).float() / self.dim_model) + ) + self.attention_scaling = 1.0 + self.register_buffer("inv_freq", inv_freq) self._seq_len_cached = None @@ -172,15 +278,21 @@ def _update_cos_sin_tables(self, x): # Reset the tables if the sequence length has changed, # or if we're on a new device (possibly due to tracing for instance) - if seq_len != self._seq_len_cached or self._cos_cached.device != x.device or self._cos_cached.dtype != x.dtype: + if ( + seq_len != self._seq_len_cached + or self._cos_cached is None + or self._sin_cached is None + or self._cos_cached.device != x.device + or self._cos_cached.dtype != x.dtype + ): self._seq_len_cached = seq_len t = torch.arange(x.shape[self.seq_length_dim], device=x.device, dtype=torch.float32) freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(x.dtype)) emb = torch.cat((freqs, freqs), dim=-1).to( x.device ) # here, we combine the two matrices (not zipping them). - self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype) - self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype) + self._cos_cached = (emb.cos() * self.attention_scaling)[None, None, :, :].to(x.dtype) + self._sin_cached = (emb.sin() * self.attention_scaling)[None, None, :, :].to(x.dtype) return self._cos_cached, self._sin_cached @@ -295,6 +407,45 @@ class RotaryTransformConfig(BaseModel): n_head: Annotated[int, Field(strict=True, ge=0)] seq_length_dim: Annotated[int, Field(strict=True)] base_freq: Annotated[int, Field(strict=True, ge=10000)] + max_position_embeddings: Optional[Annotated[int, Field(strict=True, ge=1)]] = None + rope_scaling: Optional[dict[str, object]] = None + + @model_validator(mode="after") + def validate_rope_scaling(self) -> "AttentionConfig.QueryKeyValueTransformConfig.RotaryTransformConfig": + if self.rope_scaling is None: + return self + + if not isinstance(self.rope_scaling, dict): + raise ValueError("rope_scaling must be a dictionary") + + rope_scaling = dict(self.rope_scaling) + if "type" in rope_scaling and "rope_type" not in rope_scaling: + rope_scaling["rope_type"] = rope_scaling["type"] + + rope_type = rope_scaling.get("rope_type", "default") + if rope_type not in {"default", "yarn"}: + raise ValueError( + f"Unsupported rope_scaling.rope_type '{rope_type}'. Supported values are 'default' and 'yarn'." + ) + + if rope_type == "yarn": + if self.max_position_embeddings is None: + raise ValueError("YaRN requires max_position_embeddings to be set") + + original_max_position_embeddings = rope_scaling.get("original_max_position_embeddings") + if ( + original_max_position_embeddings is None + or not isinstance(original_max_position_embeddings, int) + or original_max_position_embeddings <= 0 + ): + raise ValueError("YaRN requires original_max_position_embeddings to be a positive integer") + + factor = rope_scaling.get("factor") + if factor is not None and (not isinstance(factor, (int, float)) or factor < 1.0): + raise ValueError("YaRN requires rope_scaling.factor to be a float >= 1.0") + + self.rope_scaling = rope_scaling + return self @validator("type_hint", pre=True, always=True) def parse_sharding_strategy_by_name(cls, name): From f87eabb923b4534190b8eea52226fe31d3580604 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Mon, 11 May 2026 14:00:34 +0000 Subject: [PATCH 2/9] test: Add test for yarn Co-authored-by: Copilot --- tests/test_rotary_qkv_transform.py | 80 ++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) diff --git a/tests/test_rotary_qkv_transform.py b/tests/test_rotary_qkv_transform.py index fa82715b1..9da3bd652 100644 --- a/tests/test_rotary_qkv_transform.py +++ b/tests/test_rotary_qkv_transform.py @@ -1,3 +1,4 @@ +import pytest import torch from modalities.models.gpt2.gpt2_model import RotaryTransform @@ -41,3 +42,82 @@ def test_rotary_transform(): comp_rot_h = torch.cat([-comp_h_2, comp_h_1], dim=-1) comp_rot_expected = comp * cos_m_theta + comp_rot_h * sin_m_theta assert torch.equal(comp_rot_expected, comp_rot) + + +def _apply_rotary(x: torch.Tensor, cos_cached: torch.Tensor, sin_cached: torch.Tensor) -> torch.Tensor: + cos_local = cos_cached[:, :, : x.shape[-2], :] + sin_local = sin_cached[:, :, : x.shape[-2], :] + x1, x2 = x.chunk(2, dim=-1) + x_rot = torch.cat((-x2, x1), dim=-1) + return (x * cos_local) + (x_rot * sin_local) + + +def _assert_yarn_outputs_match_reference( + rotary_transform: RotaryTransform, + q: torch.Tensor, + k: torch.Tensor, + q_rot: torch.Tensor, + k_rot: torch.Tensor, + seq_length: int, +) -> None: + t = torch.arange(seq_length, device=q.device, dtype=torch.float32) + freqs = torch.einsum("i,j->ij", t, rotary_transform.inv_freq.to(q.dtype)) + emb = torch.cat((freqs, freqs), dim=-1) + cos = (emb.cos() * rotary_transform.attention_scaling)[None, None, :, :].to(q.dtype) + sin = (emb.sin() * rotary_transform.attention_scaling)[None, None, :, :].to(q.dtype) + + q_expected = _apply_rotary(q, cos, sin) + k_expected = _apply_rotary(k, cos, sin) + + assert torch.allclose(q_rot, q_expected, atol=1e-5, rtol=1e-5) + assert torch.allclose(k_rot, k_expected, atol=1e-5, rtol=1e-5) + + +@pytest.mark.parametrize( + "rope_scaling", + [ + { + "rope_type": "yarn", + "factor": 2.0, + "beta_fast": 32, + "beta_slow": 1, + "original_max_position_embeddings": 4, + }, + { + "rope_type": "yarn", + "beta_fast": 32, + "beta_slow": 1, + "original_max_position_embeddings": 4, + }, + ], +) +def test_rotary_transform_yarn_matches_reference(rope_scaling: dict): + bs = 1 + n_heads = 2 + embedding_dim = 8 + seq_length = 8 + head_dim = embedding_dim // n_heads + + q = torch.randn(bs, n_heads, seq_length, head_dim) + k = torch.randn(bs, n_heads, seq_length, head_dim) + v = torch.randn(bs, n_heads, seq_length, head_dim) + + rotary_transform = RotaryTransform( + n_embd=embedding_dim, + n_head=n_heads, + base_freq=10000, + max_position_embeddings=seq_length, + rope_scaling=rope_scaling, + ) + + q_rot, k_rot, v_rot = rotary_transform(q=q, k=k, v=v) + assert torch.equal(v, v_rot) + + _assert_yarn_outputs_match_reference( + rotary_transform=rotary_transform, + q=q, + k=k, + q_rot=q_rot, + k_rot=k_rot, + seq_length=seq_length, + ) From 779e7c180e003979658de101a1b455147f3ace55 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Tue, 12 May 2026 07:27:27 +0000 Subject: [PATCH 3/9] docs: Add type annotations Co-authored-by: Copilot --- src/modalities/models/gpt2/gpt2_model.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py index 7a972b577..f103d6b0a 100644 --- a/src/modalities/models/gpt2/gpt2_model.py +++ b/src/modalities/models/gpt2/gpt2_model.py @@ -175,8 +175,10 @@ def _compute_yarn_parameters(self, device: torch.device | None) -> tuple[torch.T attention_factor = self.rope_scaling.get("attention_factor") mscale = self.rope_scaling.get("mscale") mscale_all_dim = self.rope_scaling.get("mscale_all_dim") - beta_fast = self.rope_scaling.get("beta_fast") or 32 - beta_slow = self.rope_scaling.get("beta_slow") or 1 + beta_fast_raw = self.rope_scaling.get("beta_fast") + beta_slow_raw = self.rope_scaling.get("beta_slow") + beta_fast = float(beta_fast_raw) if isinstance(beta_fast_raw, (int, float)) else 32.0 + beta_slow = float(beta_slow_raw) if isinstance(beta_slow_raw, (int, float)) else 1.0 truncate = self.rope_scaling.get("truncate", True) def get_mscale(scale: float, mscale: float = 1.0) -> float: @@ -194,10 +196,17 @@ def get_mscale(scale: float, mscale: float = 1.0) -> float: elif not isinstance(attention_factor, (int, float)) or attention_factor <= 0: raise ValueError("YaRN requires rope_scaling.attention_factor to be a float > 0") - def find_correction_dim(num_rotations, dim, base, max_position_embeddings): + def find_correction_dim(num_rotations: float, dim: int, base: int, max_position_embeddings: int) -> float: return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) - def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings, truncate): + def find_correction_range( + low_rot: float, + high_rot: float, + dim: int, + base: int, + max_position_embeddings: int, + truncate: bool, + ) -> tuple[float, float]: low = find_correction_dim(low_rot, dim, base, max_position_embeddings) high = find_correction_dim(high_rot, dim, base, max_position_embeddings) if truncate: @@ -205,7 +214,7 @@ def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings, high = math.ceil(high) return max(low, 0), min(high, dim - 1) - def linear_ramp_factor(min_value, max_value, dim): + def linear_ramp_factor(min_value: float, max_value: float, dim: int) -> torch.Tensor: if min_value == max_value: max_value += 0.001 linear_func = (torch.arange(dim, dtype=torch.float32, device=device) - min_value) / (max_value - min_value) From 2126b0bc9e391a3701a5f32008ce7241769b4080 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Tue, 12 May 2026 07:35:20 +0000 Subject: [PATCH 4/9] docs: Add docstrings Co-authored-by: Copilot --- src/modalities/models/gpt2/gpt2_model.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py index f103d6b0a..905ad289b 100644 --- a/src/modalities/models/gpt2/gpt2_model.py +++ b/src/modalities/models/gpt2/gpt2_model.py @@ -152,6 +152,7 @@ def __init__( self.reset_parameters() def _compute_yarn_parameters(self, device: torch.device | None) -> tuple[torch.Tensor, float]: + """Compute YaRN inverse frequencies and the attention scaling factor.""" if self.rope_scaling is None: raise ValueError("YaRN requires a rope_scaling config.") if self.max_position_embeddings is None: @@ -182,6 +183,7 @@ def _compute_yarn_parameters(self, device: torch.device | None) -> tuple[torch.T truncate = self.rope_scaling.get("truncate", True) def get_mscale(scale: float, mscale: float = 1.0) -> float: + """Return the YaRN mscale coefficient for a given scaling factor.""" if scale <= 1: return 1.0 return 0.1 * mscale * math.log(scale) + 1.0 @@ -197,6 +199,7 @@ def get_mscale(scale: float, mscale: float = 1.0) -> float: raise ValueError("YaRN requires rope_scaling.attention_factor to be a float > 0") def find_correction_dim(num_rotations: float, dim: int, base: int, max_position_embeddings: int) -> float: + """Map a target number of rotations to a rotary dimension index.""" return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) def find_correction_range( @@ -207,6 +210,7 @@ def find_correction_range( max_position_embeddings: int, truncate: bool, ) -> tuple[float, float]: + """Compute the lower and upper rotary-dimension correction bounds for YaRN.""" low = find_correction_dim(low_rot, dim, base, max_position_embeddings) high = find_correction_dim(high_rot, dim, base, max_position_embeddings) if truncate: @@ -215,6 +219,7 @@ def find_correction_range( return max(low, 0), min(high, dim - 1) def linear_ramp_factor(min_value: float, max_value: float, dim: int) -> torch.Tensor: + """Create a clamped linear ramp used to blend interpolation and extrapolation.""" if min_value == max_value: max_value += 0.001 linear_func = (torch.arange(dim, dtype=torch.float32, device=device) - min_value) / (max_value - min_value) @@ -421,6 +426,7 @@ class RotaryTransformConfig(BaseModel): @model_validator(mode="after") def validate_rope_scaling(self) -> "AttentionConfig.QueryKeyValueTransformConfig.RotaryTransformConfig": + """Validate and normalize rope_scaling, including YaRN-specific constraints.""" if self.rope_scaling is None: return self From 309d147dedf214c1cec970d3fb64ebd3228f1ceb Mon Sep 17 00:00:00 2001 From: rrutmann Date: Tue, 12 May 2026 12:14:22 +0000 Subject: [PATCH 5/9] fix: Write to unique filenames Co-authored-by: Copilot --- .../test_tensor_parallelism.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/tests/fsdp2_parallelization/test_tensor_parallelism.py b/tests/fsdp2_parallelization/test_tensor_parallelism.py index d3ccd46c2..25abc686b 100644 --- a/tests/fsdp2_parallelization/test_tensor_parallelism.py +++ b/tests/fsdp2_parallelization/test_tensor_parallelism.py @@ -20,14 +20,15 @@ from tests.utility import find_free_port -def patch_config_file(original_config_path: Path, activation_type: str, tmp_dir: Path) -> Path: +def patch_config_file(original_config_path: Path, activation_type: str, tmp_dir: Path, file_tag: str = "") -> Path: """Patches the original configuration file to set a custom activation type.""" with original_config_path.open("r", encoding="utf-8") as f: config_dict = yaml.safe_load(f) config_dict["model_raw"]["config"]["activation_type"] = activation_type - tmp_file_path = tmp_dir / original_config_path.name + file_suffix = f"_{file_tag}" if file_tag else "" + tmp_file_path = tmp_dir / f"{original_config_path.stem}{file_suffix}{original_config_path.suffix}" with tmp_file_path.open("w", encoding="utf-8") as f: yaml.safe_dump(config_dict, f) @@ -103,12 +104,16 @@ def _test_tp_sharding_impl( ): # Seed before FSDP2 instantiation torch.manual_seed(42) - fsdp2_path = patch_config_file(fsdp2_config_path, activation_type, tmp_config_dir) + fsdp2_path = patch_config_file( + fsdp2_config_path, activation_type, tmp_config_dir, file_tag=f"{activation_type}_rank{process_id}_fsdp2" + ) fsdp2_model, fsdp2_mesh = self._get_components(fsdp2_path, tmp_path) # Seed again before TP instantiation to match torch.manual_seed(42) - tp_path = patch_config_file(tp_config_path, activation_type, tmp_config_dir) + tp_path = patch_config_file( + tp_config_path, activation_type, tmp_config_dir, file_tag=f"{activation_type}_rank{process_id}_tp" + ) tp_model, tp_mesh = self._get_components(tp_path, tmp_path) # Ensure models use the correct MLP From a06d6b451e42ace7df96dfb9decbdfc87f58893a Mon Sep 17 00:00:00 2001 From: rrutmann Date: Tue, 26 May 2026 10:17:36 +0000 Subject: [PATCH 6/9] chore: Apply black formatter --- src/modalities/trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/modalities/trainer.py b/src/modalities/trainer.py index c715a01fa..4ad54b226 100644 --- a/src/modalities/trainer.py +++ b/src/modalities/trainer.py @@ -1,6 +1,6 @@ +import gc from datetime import datetime from enum import Enum -import gc from typing import Callable, Optional import torch @@ -388,7 +388,7 @@ def train( self.gc.run(step_count=training_progress.num_seen_steps_total) evaluation_callback(num_train_steps_done=training_progress.num_seen_steps_total) checkpointing_callback(training_progress=training_progress) - + profiler_cm.step() @staticmethod From 82019f18f6e579cac0edb877210403cd95893b9d Mon Sep 17 00:00:00 2001 From: rrutmann Date: Tue, 2 Jun 2026 09:01:17 +0000 Subject: [PATCH 7/9] fix: validate yarn rope scaling inputs --- src/modalities/models/gpt2/gpt2_model.py | 63 ++++++++- tests/test_rotary_qkv_transform.py | 165 ++++++++++++++++++++++- 2 files changed, 220 insertions(+), 8 deletions(-) diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py index 905ad289b..5074443d1 100644 --- a/src/modalities/models/gpt2/gpt2_model.py +++ b/src/modalities/models/gpt2/gpt2_model.py @@ -2,6 +2,7 @@ import math from abc import abstractmethod from enum import Enum +from numbers import Real from typing import Annotated, Optional, overload import torch @@ -31,6 +32,50 @@ # GPT2 implementation taken from nanogpt https://github.com/karpathy/nanoGPT +def _get_optional_rope_scaling_float( + rope_scaling: dict[str, object], + key: str, + default: float, + *, + min_value: float | None = None, +) -> float: + """Return a validated float from rope_scaling or a default when the key is absent.""" + if key not in rope_scaling: + return default + + value = rope_scaling[key] + if isinstance(value, bool) or not isinstance(value, Real): + raise ValueError(f"rope_scaling.{key} must be a float") + + value_float = float(value) + if min_value is not None and value_float < min_value: + raise ValueError(f"rope_scaling.{key} must be a float >= {min_value}") + + return value_float + + +def _get_optional_rope_scaling_float_pair( + rope_scaling: dict[str, object], + first_key: str, + second_key: str, + *, + min_value: float | None = None, +) -> tuple[float, float] | None: + """Return a validated float pair when both keys are present, otherwise None if both are absent.""" + first_present = first_key in rope_scaling + second_present = second_key in rope_scaling + + if not first_present and not second_present: + return None + if first_present != second_present: + raise ValueError(f"rope_scaling.{first_key} and rope_scaling.{second_key} must be provided together") + + return ( + _get_optional_rope_scaling_float(rope_scaling, first_key, 0.0, min_value=min_value), + _get_optional_rope_scaling_float(rope_scaling, second_key, 0.0, min_value=min_value), + ) + + class LayerNorms(LookupEnum): """ Enum lookup class for LayerNorms. @@ -174,12 +219,11 @@ def _compute_yarn_parameters(self, device: torch.device | None) -> tuple[torch.T factor_float = float(factor) attention_factor = self.rope_scaling.get("attention_factor") - mscale = self.rope_scaling.get("mscale") - mscale_all_dim = self.rope_scaling.get("mscale_all_dim") - beta_fast_raw = self.rope_scaling.get("beta_fast") - beta_slow_raw = self.rope_scaling.get("beta_slow") - beta_fast = float(beta_fast_raw) if isinstance(beta_fast_raw, (int, float)) else 32.0 - beta_slow = float(beta_slow_raw) if isinstance(beta_slow_raw, (int, float)) else 1.0 + mscale_pair = _get_optional_rope_scaling_float_pair( + self.rope_scaling, "mscale", "mscale_all_dim", min_value=0.0 + ) + beta_fast = _get_optional_rope_scaling_float(self.rope_scaling, "beta_fast", 32.0, min_value=0.0) + beta_slow = _get_optional_rope_scaling_float(self.rope_scaling, "beta_slow", 1.0, min_value=0.0) truncate = self.rope_scaling.get("truncate", True) def get_mscale(scale: float, mscale: float = 1.0) -> float: @@ -189,7 +233,8 @@ def get_mscale(scale: float, mscale: float = 1.0) -> float: return 0.1 * mscale * math.log(scale) + 1.0 if attention_factor is None: - if isinstance(mscale, (int, float)) and isinstance(mscale_all_dim, (int, float)): + if mscale_pair is not None: + mscale, mscale_all_dim = mscale_pair attention_factor = float( get_mscale(factor_float, float(mscale)) / get_mscale(factor_float, float(mscale_all_dim)) ) @@ -459,6 +504,10 @@ def validate_rope_scaling(self) -> "AttentionConfig.QueryKeyValueTransformConfig if factor is not None and (not isinstance(factor, (int, float)) or factor < 1.0): raise ValueError("YaRN requires rope_scaling.factor to be a float >= 1.0") + _get_optional_rope_scaling_float(rope_scaling, "beta_fast", 32.0, min_value=0.0) + _get_optional_rope_scaling_float(rope_scaling, "beta_slow", 1.0, min_value=0.0) + _get_optional_rope_scaling_float_pair(rope_scaling, "mscale", "mscale_all_dim", min_value=0.0) + self.rope_scaling = rope_scaling return self diff --git a/tests/test_rotary_qkv_transform.py b/tests/test_rotary_qkv_transform.py index 9da3bd652..5cebfba0a 100644 --- a/tests/test_rotary_qkv_transform.py +++ b/tests/test_rotary_qkv_transform.py @@ -1,7 +1,7 @@ import pytest import torch -from modalities.models.gpt2.gpt2_model import RotaryTransform +from modalities.models.gpt2.gpt2_model import AttentionConfig, RotaryTransform def test_rotary_transform(): @@ -121,3 +121,166 @@ def test_rotary_transform_yarn_matches_reference(rope_scaling: dict): k_rot=k_rot, seq_length=seq_length, ) + + +@pytest.mark.parametrize( + ("key", "value"), + [ + ("beta_fast", "32"), + ("beta_slow", torch.tensor(1.0)), + ("beta_fast", True), + ], +) +def test_rotary_transform_yarn_rejects_invalid_beta_values(key: str, value: object): + rope_scaling = { + "rope_type": "yarn", + "factor": 2.0, + "original_max_position_embeddings": 4, + key: value, + } + + with pytest.raises(ValueError, match=rf"rope_scaling\.{key} must be a float"): + RotaryTransform( + n_embd=8, + n_head=2, + base_freq=10000, + max_position_embeddings=8, + rope_scaling=rope_scaling, + ) + + +@pytest.mark.parametrize( + ("key", "value"), + [ + ("beta_fast", "32"), + ("beta_slow", torch.tensor(1.0)), + ("beta_slow", False), + ], +) +def test_rotary_transform_config_yarn_rejects_invalid_beta_values(key: str, value: object): + rope_scaling = { + "rope_type": "yarn", + "factor": 2.0, + "original_max_position_embeddings": 4, + key: value, + } + + with pytest.raises(ValueError, match=rf"rope_scaling\.{key} must be a float"): + AttentionConfig.QueryKeyValueTransformConfig.RotaryTransformConfig( + n_embd=8, + n_head=2, + seq_length_dim=-2, + base_freq=10000, + max_position_embeddings=8, + rope_scaling=rope_scaling, + ) + + +@pytest.mark.parametrize( + ("rope_scaling", "match"), + [ + ( + { + "rope_type": "yarn", + "factor": 2.0, + "original_max_position_embeddings": 4, + "mscale": "1.0", + "mscale_all_dim": 1.0, + }, + r"rope_scaling\.mscale must be a float", + ), + ( + { + "rope_type": "yarn", + "factor": 2.0, + "original_max_position_embeddings": 4, + "mscale": 1.0, + "mscale_all_dim": torch.tensor(1.0), + }, + r"rope_scaling\.mscale_all_dim must be a float", + ), + ( + { + "rope_type": "yarn", + "factor": 2.0, + "original_max_position_embeddings": 4, + "mscale": True, + "mscale_all_dim": 1.0, + }, + r"rope_scaling\.mscale must be a float", + ), + ( + { + "rope_type": "yarn", + "factor": 2.0, + "original_max_position_embeddings": 4, + "mscale": 1.0, + }, + r"rope_scaling\.mscale and rope_scaling\.mscale_all_dim must be provided together", + ), + ( + { + "rope_type": "yarn", + "factor": 2.0, + "original_max_position_embeddings": 4, + "mscale_all_dim": 1.0, + }, + r"rope_scaling\.mscale and rope_scaling\.mscale_all_dim must be provided together", + ), + ], +) +def test_rotary_transform_yarn_rejects_invalid_mscale_values(rope_scaling: dict, match: str): + with pytest.raises(ValueError, match=match): + RotaryTransform( + n_embd=8, + n_head=2, + base_freq=10000, + max_position_embeddings=8, + rope_scaling=rope_scaling, + ) + + +@pytest.mark.parametrize( + ("rope_scaling", "match"), + [ + ( + { + "rope_type": "yarn", + "factor": 2.0, + "original_max_position_embeddings": 4, + "mscale": "1.0", + "mscale_all_dim": 1.0, + }, + r"rope_scaling\.mscale must be a float", + ), + ( + { + "rope_type": "yarn", + "factor": 2.0, + "original_max_position_embeddings": 4, + "mscale": 1.0, + "mscale_all_dim": torch.tensor(1.0), + }, + r"rope_scaling\.mscale_all_dim must be a float", + ), + ( + { + "rope_type": "yarn", + "factor": 2.0, + "original_max_position_embeddings": 4, + "mscale": 1.0, + }, + r"rope_scaling\.mscale and rope_scaling\.mscale_all_dim must be provided together", + ), + ], +) +def test_rotary_transform_config_yarn_rejects_invalid_mscale_values(rope_scaling: dict, match: str): + with pytest.raises(ValueError, match=match): + AttentionConfig.QueryKeyValueTransformConfig.RotaryTransformConfig( + n_embd=8, + n_head=2, + seq_length_dim=-2, + base_freq=10000, + max_position_embeddings=8, + rope_scaling=rope_scaling, + ) From b91762af0170745c46ca83a95023af9f4519a3b3 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Tue, 2 Jun 2026 11:39:41 +0000 Subject: [PATCH 8/9] refactor: use typed rope scaling configs for rotary transform --- src/modalities/models/gpt2/gpt2_model.py | 180 ++++++++++------------- tests/test_rotary_qkv_transform.py | 105 ++----------- 2 files changed, 89 insertions(+), 196 deletions(-) diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py index 5074443d1..31f25f575 100644 --- a/src/modalities/models/gpt2/gpt2_model.py +++ b/src/modalities/models/gpt2/gpt2_model.py @@ -3,11 +3,11 @@ from abc import abstractmethod from enum import Enum from numbers import Real -from typing import Annotated, Optional, overload +from typing import Annotated, Literal, Optional, overload import torch import torch.nn as nn -from pydantic import BaseModel, Field, model_validator, validator +from pydantic import BaseModel, Field, field_validator, model_validator, validator from modalities.config.lookup_enum import LookupEnum from modalities.config.utils import convert_base_model_config_to_dict @@ -32,48 +32,58 @@ # GPT2 implementation taken from nanogpt https://github.com/karpathy/nanoGPT -def _get_optional_rope_scaling_float( - rope_scaling: dict[str, object], - key: str, - default: float, - *, - min_value: float | None = None, -) -> float: - """Return a validated float from rope_scaling or a default when the key is absent.""" - if key not in rope_scaling: - return default - - value = rope_scaling[key] +def _validate_numeric_field(field_name: str, value: object) -> float: + """Validate that a value is a real number (excluding bool) and cast to float.""" if isinstance(value, bool) or not isinstance(value, Real): - raise ValueError(f"rope_scaling.{key} must be a float") - - value_float = float(value) - if min_value is not None and value_float < min_value: - raise ValueError(f"rope_scaling.{key} must be a float >= {min_value}") - - return value_float - - -def _get_optional_rope_scaling_float_pair( - rope_scaling: dict[str, object], - first_key: str, - second_key: str, - *, - min_value: float | None = None, -) -> tuple[float, float] | None: - """Return a validated float pair when both keys are present, otherwise None if both are absent.""" - first_present = first_key in rope_scaling - second_present = second_key in rope_scaling - - if not first_present and not second_present: - return None - if first_present != second_present: - raise ValueError(f"rope_scaling.{first_key} and rope_scaling.{second_key} must be provided together") - - return ( - _get_optional_rope_scaling_float(rope_scaling, first_key, 0.0, min_value=min_value), - _get_optional_rope_scaling_float(rope_scaling, second_key, 0.0, min_value=min_value), + raise ValueError(f"rope_scaling.{field_name} must be a float") + return float(value) + + +class DefaultRopeScalingConfig(BaseModel): + """Configuration for default RoPE behavior.""" + + rope_type: Literal["default"] = "default" + + +class YarnRopeScalingConfig(BaseModel): + """Configuration for YaRN RoPE scaling.""" + + rope_type: Literal["yarn"] = "yarn" + original_max_position_embeddings: Annotated[int, Field(strict=True, ge=1)] + factor: Optional[Annotated[float, Field(ge=1.0)]] = None + attention_factor: Optional[Annotated[float, Field(gt=0.0)]] = None + mscale: Optional[Annotated[float, Field(ge=0.0)]] = None + mscale_all_dim: Optional[Annotated[float, Field(ge=0.0)]] = None + beta_fast: Annotated[float, Field(ge=0.0)] = 32.0 + beta_slow: Annotated[float, Field(ge=0.0)] = 1.0 + truncate: bool = True + + @field_validator( + "factor", + "attention_factor", + "mscale", + "mscale_all_dim", + "beta_fast", + "beta_slow", + mode="before", ) + @classmethod + def validate_numeric_fields(cls, value: object, info): + if value is None: + return value + return _validate_numeric_field(info.field_name, value) + + @model_validator(mode="after") + def validate_mscale_pair(self) -> "YarnRopeScalingConfig": + if (self.mscale is None) != (self.mscale_all_dim is None): + raise ValueError("rope_scaling.mscale and rope_scaling.mscale_all_dim must be provided together") + return self + + +RopeScalingConfig = Annotated[ + DefaultRopeScalingConfig | YarnRopeScalingConfig, + Field(discriminator="rope_type"), +] class LayerNorms(LookupEnum): @@ -172,7 +182,7 @@ def __init__( seq_length_dim: int = -2, base_freq: int = 10000, max_position_embeddings: int | None = None, - rope_scaling: dict[str, object] | None = None, + rope_scaling: RopeScalingConfig | None = None, ): """ Initializes the RotaryTransform object. @@ -191,6 +201,11 @@ def __init__( self.base_freq = base_freq self.max_position_embeddings = max_position_embeddings + if rope_scaling is not None and not isinstance(rope_scaling, (DefaultRopeScalingConfig, YarnRopeScalingConfig)): + raise TypeError( + "rope_scaling must be an instance of DefaultRopeScalingConfig, YarnRopeScalingConfig, or None" + ) + self.rope_scaling = rope_scaling self.attention_scaling = 1.0 @@ -198,33 +213,25 @@ def __init__( def _compute_yarn_parameters(self, device: torch.device | None) -> tuple[torch.Tensor, float]: """Compute YaRN inverse frequencies and the attention scaling factor.""" - if self.rope_scaling is None: + if not isinstance(self.rope_scaling, YarnRopeScalingConfig): raise ValueError("YaRN requires a rope_scaling config.") if self.max_position_embeddings is None: raise ValueError("YaRN requires max_position_embeddings to be set.") - original_max_position_embeddings = self.rope_scaling.get("original_max_position_embeddings") - if ( - original_max_position_embeddings is None - or not isinstance(original_max_position_embeddings, int) - or original_max_position_embeddings <= 0 - ): - raise ValueError("YaRN requires original_max_position_embeddings to be a positive integer") - - factor = self.rope_scaling.get("factor") + original_max_position_embeddings = self.rope_scaling.original_max_position_embeddings + factor = self.rope_scaling.factor if factor is None: factor = self.max_position_embeddings / original_max_position_embeddings - if not isinstance(factor, (int, float)) or factor < 1.0: - raise ValueError("YaRN requires rope_scaling.factor to be a float >= 1.0") factor_float = float(factor) - attention_factor = self.rope_scaling.get("attention_factor") - mscale_pair = _get_optional_rope_scaling_float_pair( - self.rope_scaling, "mscale", "mscale_all_dim", min_value=0.0 - ) - beta_fast = _get_optional_rope_scaling_float(self.rope_scaling, "beta_fast", 32.0, min_value=0.0) - beta_slow = _get_optional_rope_scaling_float(self.rope_scaling, "beta_slow", 1.0, min_value=0.0) - truncate = self.rope_scaling.get("truncate", True) + attention_factor = self.rope_scaling.attention_factor + mscale_pair = None + if self.rope_scaling.mscale is not None and self.rope_scaling.mscale_all_dim is not None: + mscale_pair = (self.rope_scaling.mscale, self.rope_scaling.mscale_all_dim) + + beta_fast = self.rope_scaling.beta_fast + beta_slow = self.rope_scaling.beta_slow + truncate = self.rope_scaling.truncate def get_mscale(scale: float, mscale: float = 1.0) -> float: """Return the YaRN mscale coefficient for a given scaling factor.""" @@ -240,8 +247,6 @@ def get_mscale(scale: float, mscale: float = 1.0) -> float: ) else: attention_factor = get_mscale(factor_float) - elif not isinstance(attention_factor, (int, float)) or attention_factor <= 0: - raise ValueError("YaRN requires rope_scaling.attention_factor to be a float > 0") def find_correction_dim(num_rotations: float, dim: int, base: int, max_position_embeddings: int) -> float: """Map a target number of rotations to a rotary dimension index.""" @@ -299,9 +304,7 @@ def reset_parameters(self): # Otherwise, use the default device of the current environment. device = self.inv_freq.device if hasattr(self, "inv_freq") and isinstance(self.inv_freq, torch.Tensor) else None - rope_type = "default" - if self.rope_scaling is not None: - rope_type = str(self.rope_scaling.get("rope_type", "default")) + rope_type = self.rope_scaling.rope_type if self.rope_scaling is not None else "default" if rope_type == "yarn": inv_freq, self.attention_scaling = self._compute_yarn_parameters(device=device) @@ -467,48 +470,13 @@ class RotaryTransformConfig(BaseModel): seq_length_dim: Annotated[int, Field(strict=True)] base_freq: Annotated[int, Field(strict=True, ge=10000)] max_position_embeddings: Optional[Annotated[int, Field(strict=True, ge=1)]] = None - rope_scaling: Optional[dict[str, object]] = None + rope_scaling: Optional[RopeScalingConfig] = None @model_validator(mode="after") def validate_rope_scaling(self) -> "AttentionConfig.QueryKeyValueTransformConfig.RotaryTransformConfig": - """Validate and normalize rope_scaling, including YaRN-specific constraints.""" - if self.rope_scaling is None: - return self - - if not isinstance(self.rope_scaling, dict): - raise ValueError("rope_scaling must be a dictionary") - - rope_scaling = dict(self.rope_scaling) - if "type" in rope_scaling and "rope_type" not in rope_scaling: - rope_scaling["rope_type"] = rope_scaling["type"] - - rope_type = rope_scaling.get("rope_type", "default") - if rope_type not in {"default", "yarn"}: - raise ValueError( - f"Unsupported rope_scaling.rope_type '{rope_type}'. Supported values are 'default' and 'yarn'." - ) - - if rope_type == "yarn": - if self.max_position_embeddings is None: - raise ValueError("YaRN requires max_position_embeddings to be set") - - original_max_position_embeddings = rope_scaling.get("original_max_position_embeddings") - if ( - original_max_position_embeddings is None - or not isinstance(original_max_position_embeddings, int) - or original_max_position_embeddings <= 0 - ): - raise ValueError("YaRN requires original_max_position_embeddings to be a positive integer") - - factor = rope_scaling.get("factor") - if factor is not None and (not isinstance(factor, (int, float)) or factor < 1.0): - raise ValueError("YaRN requires rope_scaling.factor to be a float >= 1.0") - - _get_optional_rope_scaling_float(rope_scaling, "beta_fast", 32.0, min_value=0.0) - _get_optional_rope_scaling_float(rope_scaling, "beta_slow", 1.0, min_value=0.0) - _get_optional_rope_scaling_float_pair(rope_scaling, "mscale", "mscale_all_dim", min_value=0.0) - - self.rope_scaling = rope_scaling + """Validate rope_scaling cross-field constraints.""" + if isinstance(self.rope_scaling, YarnRopeScalingConfig) and self.max_position_embeddings is None: + raise ValueError("YaRN requires max_position_embeddings to be set") return self @validator("type_hint", pre=True, always=True) diff --git a/tests/test_rotary_qkv_transform.py b/tests/test_rotary_qkv_transform.py index 5cebfba0a..b44868e4b 100644 --- a/tests/test_rotary_qkv_transform.py +++ b/tests/test_rotary_qkv_transform.py @@ -1,7 +1,7 @@ import pytest import torch -from modalities.models.gpt2.gpt2_model import AttentionConfig, RotaryTransform +from modalities.models.gpt2.gpt2_model import AttentionConfig, RotaryTransform, YarnRopeScalingConfig def test_rotary_transform(): @@ -76,22 +76,20 @@ def _assert_yarn_outputs_match_reference( @pytest.mark.parametrize( "rope_scaling", [ - { - "rope_type": "yarn", - "factor": 2.0, - "beta_fast": 32, - "beta_slow": 1, - "original_max_position_embeddings": 4, - }, - { - "rope_type": "yarn", - "beta_fast": 32, - "beta_slow": 1, - "original_max_position_embeddings": 4, - }, + YarnRopeScalingConfig( + factor=2.0, + beta_fast=32, + beta_slow=1, + original_max_position_embeddings=4, + ), + YarnRopeScalingConfig( + beta_fast=32, + beta_slow=1, + original_max_position_embeddings=4, + ), ], ) -def test_rotary_transform_yarn_matches_reference(rope_scaling: dict): +def test_rotary_transform_yarn_matches_reference(rope_scaling: YarnRopeScalingConfig): bs = 1 n_heads = 2 embedding_dim = 8 @@ -123,23 +121,14 @@ def test_rotary_transform_yarn_matches_reference(rope_scaling: dict): ) -@pytest.mark.parametrize( - ("key", "value"), - [ - ("beta_fast", "32"), - ("beta_slow", torch.tensor(1.0)), - ("beta_fast", True), - ], -) -def test_rotary_transform_yarn_rejects_invalid_beta_values(key: str, value: object): +def test_rotary_transform_rejects_dict_rope_scaling(): rope_scaling = { "rope_type": "yarn", "factor": 2.0, "original_max_position_embeddings": 4, - key: value, } - with pytest.raises(ValueError, match=rf"rope_scaling\.{key} must be a float"): + with pytest.raises(TypeError, match="rope_scaling must be an instance"): RotaryTransform( n_embd=8, n_head=2, @@ -176,70 +165,6 @@ def test_rotary_transform_config_yarn_rejects_invalid_beta_values(key: str, valu ) -@pytest.mark.parametrize( - ("rope_scaling", "match"), - [ - ( - { - "rope_type": "yarn", - "factor": 2.0, - "original_max_position_embeddings": 4, - "mscale": "1.0", - "mscale_all_dim": 1.0, - }, - r"rope_scaling\.mscale must be a float", - ), - ( - { - "rope_type": "yarn", - "factor": 2.0, - "original_max_position_embeddings": 4, - "mscale": 1.0, - "mscale_all_dim": torch.tensor(1.0), - }, - r"rope_scaling\.mscale_all_dim must be a float", - ), - ( - { - "rope_type": "yarn", - "factor": 2.0, - "original_max_position_embeddings": 4, - "mscale": True, - "mscale_all_dim": 1.0, - }, - r"rope_scaling\.mscale must be a float", - ), - ( - { - "rope_type": "yarn", - "factor": 2.0, - "original_max_position_embeddings": 4, - "mscale": 1.0, - }, - r"rope_scaling\.mscale and rope_scaling\.mscale_all_dim must be provided together", - ), - ( - { - "rope_type": "yarn", - "factor": 2.0, - "original_max_position_embeddings": 4, - "mscale_all_dim": 1.0, - }, - r"rope_scaling\.mscale and rope_scaling\.mscale_all_dim must be provided together", - ), - ], -) -def test_rotary_transform_yarn_rejects_invalid_mscale_values(rope_scaling: dict, match: str): - with pytest.raises(ValueError, match=match): - RotaryTransform( - n_embd=8, - n_head=2, - base_freq=10000, - max_position_embeddings=8, - rope_scaling=rope_scaling, - ) - - @pytest.mark.parametrize( ("rope_scaling", "match"), [ From e12db1ab46012bdee464b659c5fab03fd5177ad4 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Tue, 2 Jun 2026 11:45:04 +0000 Subject: [PATCH 9/9] chore: Place private methods below the public interface --- src/modalities/models/gpt2/gpt2_model.py | 158 +++++++++++------------ 1 file changed, 79 insertions(+), 79 deletions(-) diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py index 31f25f575..f43e6e87b 100644 --- a/src/modalities/models/gpt2/gpt2_model.py +++ b/src/modalities/models/gpt2/gpt2_model.py @@ -211,6 +211,85 @@ def __init__( self.reset_parameters() + def reset_parameters(self): + # If previously initialized on or moved to a device, reuse that device. + # Otherwise, use the default device of the current environment. + device = self.inv_freq.device if hasattr(self, "inv_freq") and isinstance(self.inv_freq, torch.Tensor) else None + + rope_type = self.rope_scaling.rope_type if self.rope_scaling is not None else "default" + + if rope_type == "yarn": + inv_freq, self.attention_scaling = self._compute_yarn_parameters(device=device) + else: + inv_freq = 1.0 / ( + self.base_freq ** (torch.arange(0, self.dim_model, 2, device=device).float() / self.dim_model) + ) + self.attention_scaling = 1.0 + + self.register_buffer("inv_freq", inv_freq) + + self._seq_len_cached = None + self._cos_cached = None + self._sin_cached = None + + def rotate_half(self, x: torch.Tensor): + """ + Rearrange tensor elements. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The output tensor. + + """ + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + def apply_rotary_pos_emb(self, x, cos, sin): + """ + Applies rotary positional embedding to the input tensor. + + Args: + x (torch.Tensor): Input tensor. + cos (torch.Tensor): Cosine values for rotary positional embedding. + sin (torch.Tensor): Sine values for rotary positional embedding. + + Returns: + torch.Tensor: Tensor after applying rotary positional embedding. + """ + # NOTE: This could probably be moved to Triton + + # Handle a possible sequence length mismatch in between q and k + cos = cos[:, :, : x.shape[self.seq_length_dim], :] + sin = sin[:, :, : x.shape[self.seq_length_dim], :] + + # the rotation is not really a rotation in higher dimensions, + # It merely swaps and negates certain dimensions to make + # the rotation below work + return (x * cos) + (self.rotate_half(x) * sin) + + def forward( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Forward pass of the RotaryTransform module. + + Args: + q (torch.Tensor): Query tensor. + k (torch.Tensor): Key tensor. + v (torch.Tensor): Value tensor. + + Returns: + tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + Tuple containing the modified query tensor, key tensor, and value tensor. + """ + self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k) + q = self.apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached) + k = self.apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached) + + return q, k, v + def _compute_yarn_parameters(self, device: torch.device | None) -> tuple[torch.Tensor, float]: """Compute YaRN inverse frequencies and the attention scaling factor.""" if not isinstance(self.rope_scaling, YarnRopeScalingConfig): @@ -299,41 +378,6 @@ def linear_ramp_factor(min_value: float, max_value: float, dim: int) -> torch.Te return inv_freq, float(attention_factor) - def reset_parameters(self): - # If previously initialized on or moved to a device, reuse that device. - # Otherwise, use the default device of the current environment. - device = self.inv_freq.device if hasattr(self, "inv_freq") and isinstance(self.inv_freq, torch.Tensor) else None - - rope_type = self.rope_scaling.rope_type if self.rope_scaling is not None else "default" - - if rope_type == "yarn": - inv_freq, self.attention_scaling = self._compute_yarn_parameters(device=device) - else: - inv_freq = 1.0 / ( - self.base_freq ** (torch.arange(0, self.dim_model, 2, device=device).float() / self.dim_model) - ) - self.attention_scaling = 1.0 - - self.register_buffer("inv_freq", inv_freq) - - self._seq_len_cached = None - self._cos_cached = None - self._sin_cached = None - - def rotate_half(self, x: torch.Tensor): - """ - Rearrange tensor elements. - - Args: - x (torch.Tensor): The input tensor. - - Returns: - torch.Tensor: The output tensor. - - """ - x1, x2 = x.chunk(2, dim=-1) - return torch.cat((-x2, x1), dim=-1) - def _update_cos_sin_tables(self, x): # Update the cosine and sine tables. seq_len = x.shape[self.seq_length_dim] @@ -358,50 +402,6 @@ def _update_cos_sin_tables(self, x): return self._cos_cached, self._sin_cached - def apply_rotary_pos_emb(self, x, cos, sin): - """ - Applies rotary positional embedding to the input tensor. - - Args: - x (torch.Tensor): Input tensor. - cos (torch.Tensor): Cosine values for rotary positional embedding. - sin (torch.Tensor): Sine values for rotary positional embedding. - - Returns: - torch.Tensor: Tensor after applying rotary positional embedding. - """ - # NOTE: This could probably be moved to Triton - - # Handle a possible sequence length mismatch in between q and k - cos = cos[:, :, : x.shape[self.seq_length_dim], :] - sin = sin[:, :, : x.shape[self.seq_length_dim], :] - - # the rotation is not really a rotation in higher dimensions, - # It merely swaps and negates certain dimensions to make - # the rotation below work - return (x * cos) + (self.rotate_half(x) * sin) - - def forward( - self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Forward pass of the RotaryTransform module. - - Args: - q (torch.Tensor): Query tensor. - k (torch.Tensor): Key tensor. - v (torch.Tensor): Value tensor. - - Returns: - tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - Tuple containing the modified query tensor, key tensor, and value tensor. - """ - self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k) - q = self.apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached) - k = self.apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached) - - return q, k, v - class QueryKeyValueTransformType(Enum): """