Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions src/ideogram4/pipeline_ideogram4.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def _load_fp8_text_encoder(
)
model = AutoModel.from_config(config, trust_remote_code=True)
state_dict = _load_subfolder_state_dict(repo_id, text_encoder_subfolder, "model")
swap_linears_to_fp8(model, state_dict, compute_dtype=dtype)
swap_linears_to_fp8(model, state_dict, compute_dtype=dtype, device=device)
# assign=True so unquantized params take the loaded dtype and the computed
# rotary buffers (absent from the checkpoint) survive; tied weights, if any,
# surface as benign missing keys.
Expand Down Expand Up @@ -168,7 +168,7 @@ def _build_transformer(
# Weight-only FP8: cast the unquantized params to the compute dtype first,
# then swap in Fp8Linear layers (which keep their weights as float8).
model.to(dtype)
swap_linears_to_fp8(model, state_dict, compute_dtype=dtype)
swap_linears_to_fp8(model, state_dict, compute_dtype=dtype, device=device)
load_fp8_state_dict(model, state_dict, device=device, dtype=dtype)
else:
model.load_state_dict(state_dict)
Expand Down Expand Up @@ -584,9 +584,16 @@ def __call__(
device=self.device,
)

# step_intervals is fixed for the whole generation, so warp it through the
# schedule once and read the values out as a Python list, rather than calling
# the schedule per element (twice per step) inside the loop. Pass a CPU tensor
# so the warp (which runs on CPU anyway) doesn't bounce the result back to the
# device just for tolist() to pull it off again.
schedule_values = schedule(step_intervals.cpu()).tolist()

for i in range(num_steps - 1, -1, -1):
t_val = float(schedule(step_intervals[i + 1].unsqueeze(0)).item())
s_val = float(schedule(step_intervals[i].unsqueeze(0)).item())
t_val = schedule_values[i + 1]
s_val = schedule_values[i]
t = torch.full((batch_size,), t_val, dtype=torch.float32, device=self.device)

pos_z = torch.cat([text_z_padding, z], dim=1)
Expand Down
96 changes: 85 additions & 11 deletions src/ideogram4/quantized_loading.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import functools
import warnings

import bitsandbytes as bnb
Expand Down Expand Up @@ -161,16 +162,44 @@ def is_fp8_state_dict(state_dict: dict[str, torch.Tensor]) -> bool:
)


@functools.lru_cache(maxsize=None)
def _probe_fp8_support(device_str: str) -> bool:
try:
torch.zeros(1, dtype=FP8_WEIGHT_DTYPE, device=device_str).to(torch.float32)
except Exception:
return False
return True
Comment thread
sammcj marked this conversation as resolved.


def device_supports_fp8(device: torch.device) -> bool:
"""Whether ``device`` can store and cast ``float8_e4m3fn`` tensors.

Probed at runtime (cached per concrete device, e.g. ``cuda:0``, so a
heterogeneous multi-device setup is checked individually) rather than
hard-coded, since float8 support can vary by PyTorch build and device. In
practice PyTorch's MPS (Apple Silicon) backend has no float8 support - it can
neither hold the dtype nor convert it - and fails the probe, so there the FP8
weights must be dequantized to the compute dtype at load time; CUDA and CPU
pass on current builds. The live probe is authoritative.
"""
Comment thread
sammcj marked this conversation as resolved.
return _probe_fp8_support(str(torch.device(device)))


class Fp8Linear(nn.Module):
"""Linear layer holding an e4m3 float8 weight + per-row float32 scale.

The weight and scale are registered as buffers (not parameters) so they load
via ``load_state_dict`` and are excluded from optimizer/grad machinery. The
dequantized matmul runs in ``compute_dtype``.

When ``store_fp8`` is False the layer holds an already-dequantized weight in
``compute_dtype`` and no scale buffer; this is the path for devices that can't
store float8 (MPS). The half-size checkpoint download is still used; only the
in-memory weight is expanded.
"""

weight: torch.Tensor
weight_scale: torch.Tensor
weight_scale: torch.Tensor | None
bias: torch.Tensor | None

def __init__(
Expand All @@ -179,23 +208,39 @@ def __init__(
out_features: int,
bias: bool,
compute_dtype: torch.dtype,
*,
store_fp8: bool = True,
) -> None:
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.compute_dtype = compute_dtype
self.register_buffer(
"weight",
torch.empty(out_features, in_features, dtype=FP8_WEIGHT_DTYPE),
)
self.register_buffer("weight_scale", torch.empty(out_features, dtype=torch.float32))
self.store_fp8 = store_fp8
if store_fp8:
self.register_buffer(
"weight",
torch.empty(out_features, in_features, dtype=FP8_WEIGHT_DTYPE),
)
self.register_buffer(
"weight_scale", torch.empty(out_features, dtype=torch.float32)
)
else:
self.register_buffer(
"weight",
torch.empty(out_features, in_features, dtype=compute_dtype),
)
self.weight_scale = None
if bias:
self.register_buffer("bias", torch.empty(out_features, dtype=compute_dtype))
else:
self.bias = None

def forward(self, x: torch.Tensor) -> torch.Tensor:
w = self.weight.to(x.dtype) * self.weight_scale.to(x.dtype).unsqueeze(1)
if self.store_fp8:
assert self.weight_scale is not None
w = self.weight.to(x.dtype) * self.weight_scale.to(x.dtype).unsqueeze(1)
else:
w = self.weight.to(x.dtype)
bias = self.bias.to(x.dtype) if self.bias is not None else None
return F.linear(x, w, bias)

Expand All @@ -204,15 +249,18 @@ def swap_linears_to_fp8(
module: nn.Module,
state_dict: dict[str, torch.Tensor],
compute_dtype: torch.dtype,
device: torch.device,
*,
prefix: str = "",
) -> None:
"""Replace each ``nn.Linear`` that has a saved FP8 scale with an ``Fp8Linear``.

Gating on the presence of ``<name>.weight_scale`` means only layers that were
actually quantized at save time are swapped; everything else loads normally in
the compute dtype.
the compute dtype. On devices that can't store float8 (MPS) the swapped layers
hold a dequantized ``compute_dtype`` weight instead (see ``Fp8Linear``).
"""
store_fp8 = device_supports_fp8(device)
for name, child in list(module.named_children()):
child_prefix = f"{prefix}{name}"
if (
Expand All @@ -226,10 +274,13 @@ def swap_linears_to_fp8(
child.out_features,
bias=child.bias is not None,
compute_dtype=compute_dtype,
store_fp8=store_fp8,
),
)
else:
swap_linears_to_fp8(child, state_dict, compute_dtype, prefix=f"{child_prefix}.")
swap_linears_to_fp8(
child, state_dict, compute_dtype, device, prefix=f"{child_prefix}."
)


def load_fp8_state_dict(
Expand All @@ -255,13 +306,36 @@ def load_fp8_state_dict(

``strict=False`` downgrades missing keys to a warning (e.g. tied weights that a
``transformers`` model resolves itself); unexpected keys always raise.

On devices that can't store float8 (MPS) the FP8 weights are dequantized to
``dtype`` here using their per-row scale, and the now-unused ``.weight_scale``
entries are dropped to match the dequantized ``Fp8Linear`` layout.
"""
store_fp8 = device_supports_fp8(device)
prepared: dict[str, torch.Tensor] = {}
for k, v in state_dict.items():
if v.dtype == FP8_WEIGHT_DTYPE:
prepared[k] = v.to(device=device)
if store_fp8:
prepared[k] = v.to(device=device)
else:
# MPS can't cast float8, so dequantize on CPU (where the fp8 weights are
# loaded) before moving the result across to the target device.
if not k.endswith(".weight"):
raise RuntimeError(
f"unexpected FP8 tensor key {k!r} (expected it to end with '.weight')"
)
scale_key = k[: -len(".weight")] + FP8_SCALE_SUFFIX
if scale_key not in state_dict:
raise RuntimeError(
f"FP8 weight {k!r} has no matching scale {scale_key!r} in the checkpoint"
)
Comment thread
sammcj marked this conversation as resolved.
scale = state_dict[scale_key]
w = v.cpu().to(torch.float32) * scale.cpu().to(torch.float32).unsqueeze(1)
prepared[k] = w.to(device=device, dtype=dtype)
Comment thread
sammcj marked this conversation as resolved.
elif k.endswith(FP8_SCALE_SUFFIX):
prepared[k] = v.to(device=device, dtype=torch.float32)
if store_fp8:
prepared[k] = v.to(device=device, dtype=torch.float32)
# else: folded into the dequantized weight above; drop the scale key.
elif v.is_floating_point():
prepared[k] = v.to(device=device, dtype=dtype)
else:
Expand Down
17 changes: 15 additions & 2 deletions src/ideogram4/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,27 @@ class LogitNormalSchedule:
logsnr_max: float = 18.0

def __call__(self, t: torch.Tensor) -> torch.Tensor:
t = t.to(torch.float64)
"""Map step positions to noise levels (inference-time, non-differentiable).

Called from the sampler loop under ``torch.no_grad`` and read out as Python
scalars, so the CPU detour taken on MPS never participates in autograd.
"""
device = t.device
# The float64 warp (ndtri/expit) needs precision at the tails. MPS supports
# neither float64 nor these special functions, so there the warp runs on CPU
# (a fused `.to(cpu, float64)` would still cast on the MPS side first and
# raise, hence the split). Other devices keep the original on-device path.
if device.type == "mps":
t = t.cpu().to(torch.float64)
else:
t = t.to(torch.float64)
Comment thread
sammcj marked this conversation as resolved.
Comment thread
sammcj marked this conversation as resolved.
z = torch.special.ndtri(t)
y = self.mean + self.std * z
t_ = torch.special.expit(y)
t_ = 1 - t_
t_min = 1.0 / (1 + math.exp(0.5 * self.logsnr_max))
t_max = 1.0 / (1 + math.exp(0.5 * self.logsnr_min))
return t_.clamp(t_min, t_max).to(torch.float32)
return t_.clamp(t_min, t_max).to(device=device, dtype=torch.float32)
Comment thread
sammcj marked this conversation as resolved.
Comment thread
sammcj marked this conversation as resolved.


def get_schedule_for_resolution(
Expand Down