Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
274 changes: 243 additions & 31 deletions pytorch/decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,8 @@ def forward(
x_flat = x.flatten(2).transpose(1, 2)
readout = cls_token.unsqueeze(1).expand(-1, x_flat.shape[1], -1)
x_cat = torch.cat([x_flat, readout], dim=-1)
x_proj = F.gelu(self.readout_projects[i](x_cat))
# JAX GELU uses tanh approximation by default.
x_proj = F.gelu(self.readout_projects[i](x_cat), approximate='tanh')
x = x_proj.transpose(1, 2).reshape(b, d, h, w)
x = self.out_projections[i](x)
x = self.resize_layers[i](x)
Expand Down Expand Up @@ -164,8 +165,10 @@ def __init__(
channels: int = 256,
post_process_channels: Tuple[int, ...] = (128, 256, 512, 1024),
readout_type: str = "project",
output_activation: bool = False,
) -> None:
super().__init__()
self.output_activation = output_activation
self.reassemble = ReassembleBlocks(
input_embed_dim=input_embed_dim,
out_channels=post_process_channels,
Expand Down Expand Up @@ -196,6 +199,10 @@ def forward(
out = self.fusion_blocks[i](out, residual=x[-(i + 1)])

out = self.project(out)
# NOTE: The reference implementation applies no output activation, so ReLU
# after the project layer is opt-in (enabled via `output_activation=True`).
if self.output_activation:
out = F.relu(out)
return out


Expand All @@ -217,6 +224,7 @@ def __init__(
channels: int = 256,
post_process_channels: Tuple[int, ...] = (128, 256, 512, 1024),
readout_type: str = "project",
output_activation: bool = False,
) -> None:
super().__init__()
self.channels = channels
Expand All @@ -226,6 +234,7 @@ def __init__(
channels=channels,
post_process_channels=post_process_channels,
readout_type=readout_type,
output_activation=output_activation,
)
# Common head for all dense prediction tasks
self.head = nn.Linear(self.channels, self.out_channels)
Expand Down Expand Up @@ -264,36 +273,50 @@ def __init__(self, num_classes: int = 150, **kwargs) -> None:


class DepthDecoder(Decoder):
"""Decoder for monocular depth prediction using classification bins."""
"""Decoder for monocular depth prediction using classification bins.

def __init__(self, min_depth: float = 0.001, max_depth: float = 10.0, **kwargs) -> None:
# Decoder requires out_channels, we pass 256 as we use channels as bins,
# although we bypass the head in forward().
super().__init__(out_channels=256, **kwargs)
Predicts depth by classifying each pixel into uniformly-spaced depth bins
and computing the expected depth value.
"""

def __init__(
self,
num_depth_bins: int = 256,
min_depth: float = 0.001,
max_depth: float = 10.0,
**kwargs,
) -> None:
super().__init__(out_channels=num_depth_bins, **kwargs)
self.min_depth = min_depth
self.max_depth = max_depth

self.num_depth_bins = num_depth_bins
self.register_buffer(
"bin_centers",
torch.linspace(min_depth, max_depth, num_depth_bins),
)

def forward(
self,
intermediate_features: List[Tuple[torch.Tensor, torch.Tensor]],
image_size: Optional[Tuple[int, int]] = None,
) -> torch.Tensor:
# Bypass super().forward() to avoid the linear head applied there,
# and use raw DPT features as logits.
logits = self.dpt(intermediate_features) # (B, C, H', W')
# Apply ReLU and shift
# 1. Get DPT features + task head (nn.Linear) via parent class.
# Output shape: (B, num_depth_bins, H', W')
logits = super().forward(intermediate_features)

# 2. Classification-based depth prediction:
# relu + shift -> linear normalisation -> expectation over bins.
logits = torch.relu(logits) + self.min_depth
# Normalize to probabilities along the channel dimension
probs = logits / torch.sum(logits, dim=1, keepdim=True)
# Compute expectation: sum(prob * bin_center)
bin_centers = torch.linspace(
self.min_depth, self.max_depth, 256, device=logits.device, dtype=logits.dtype
)
depth_map = torch.einsum("bchw,c->bhw", probs, bin_centers)
depth_map = torch.einsum("bchw,c->bhw", probs, self.bin_centers.to(logits.device))

# 3. Upsample to target resolution.
if image_size is not None:
depth_map = F.interpolate(
depth_map.unsqueeze(1), size=image_size, mode="bilinear", align_corners=False
depth_map.unsqueeze(1),
size=image_size,
mode="bilinear",
align_corners=False,
).squeeze(1)
return depth_map.unsqueeze(1)

Expand All @@ -314,37 +337,226 @@ def __init__(self, **kwargs) -> None:
"convs.": "dpt.convs.",
"fusion_blocks.": "dpt.fusion_blocks.",
"project.": "dpt.project.",
# Task-specific head keys (Scenic Dense -> PyTorch head.*)
"segmentation_head.": "head.",
"pixel_segmentation.": "head.",
"pixel_depth_classif.": "head.",
"pixel_depth_regress.": "head.",
"pixel_normals.": "head.",
}

# Names under which a Scenic/Flax checkpoint may store the task head; whichever
# one is present is converted to the PyTorch ``head.*`` parameters.
_SCENIC_HEAD_NAMES = {
    "pixel_depth_classif",
    "pixel_depth_regress",
    "pixel_normals",
    "pixel_segmentation",
    "segmentation_head",
}

# State-dict keys of the ConvTranspose weights that need spatial flipping.
_CONV_TRANSPOSE_KEYS = {
    "dpt.reassemble.resize_layers.0.weight",
    "dpt.reassemble.resize_layers.1.weight",
}


def _is_scenic_format(keys):
    """Report whether any checkpoint key uses Scenic/Flax ``/``-separated naming.

    Args:
        keys: Iterable of checkpoint key strings (a dict iterates its keys).

    Returns:
        True as soon as one key containing ``/`` is found, otherwise False
        (including for an empty iterable).
    """
    for name in keys:
        if "/" in name:
            return True
    return False


def _scenic_array_to_tensor(array):
    """Return a torch tensor backed by a fresh C-contiguous copy of ``array``."""
    # The copy materialises transposed/flipped views; torch.from_numpy cannot
    # wrap arrays with negative strides.
    return torch.from_numpy(array.copy())


def _scenic_conv_weight(kernel):
    """Reorder a conv kernel from (H, W, Cin, Cout) to (Cout, Cin, H, W)."""
    return _scenic_array_to_tensor(kernel.transpose(3, 2, 0, 1))


def _scenic_conv_transpose_weight(kernel):
    """Flip a ConvTranspose kernel 180 degrees and reorder its axes.

    The two leading (spatial) axes are reversed, then the axes are reordered
    from (H, W, Cin, Cout) to (Cin, Cout, H, W).
    """
    flipped = kernel[::-1, ::-1, :, :]
    return _scenic_array_to_tensor(flipped.transpose(2, 3, 0, 1))


def _scenic_dense_weight(kernel):
    """Transpose a Dense kernel from (in, out) to (out, in)."""
    return _scenic_array_to_tensor(kernel.T)


def _scenic_export_layer(sd, prefix, params, weight_fn, with_bias=True):
    """Store one layer's kernel (and optionally bias) in ``sd`` under ``prefix``.

    Entries missing from ``params`` are skipped silently, so a partial
    checkpoint yields a partial state dict.
    """
    if "kernel" in params:
        sd[f"{prefix}.weight"] = weight_fn(params["kernel"])
    if with_bias and "bias" in params:
        sd[f"{prefix}.bias"] = _scenic_array_to_tensor(params["bias"])


def _convert_scenic_checkpoint(weights):
    """Convert a flat Flax/Scenic checkpoint into a PyTorch-style state dict.

    Keys are ``/``-separated parameter-tree paths, e.g.
    ``decoder/dpt/reassemble_blocks/out_projection_0/kernel``, which becomes
    ``dpt.reassemble.out_projections.0.weight``.

    Weight conversions:
      - Conv kernels: (H, W, Cin, Cout) -> (Cout, Cin, H, W)
      - ConvTranspose kernels: 180-degree spatial flip, then
        (H, W, Cin, Cout) -> (Cin, Cout, H, W)
      - Dense/Linear kernels: (in, out) -> (out, in)
      - Biases: direct copy

    Args:
        weights: Mapping from checkpoint key to an array-like value.

    Returns:
        A dict mapping state-dict keys to ``torch.Tensor`` values. Only
        parameters present in ``weights`` are emitted. If several known task
        heads are present, the alphabetically first name is used so that the
        result does not depend on set iteration order.
    """
    sd = {}

    # Rebuild the nested parameter tree from the flat "a/b/c" keys.
    tree = {}
    for key, value in weights.items():
        # Strip the "decoder/" prefix if present.
        path = key[len("decoder/"):] if key.startswith("decoder/") else key
        *parents, leaf = path.split("/")
        node = tree
        for name in parents:
            node = node.setdefault(name, {})
        node[leaf] = np.asarray(value)

    # Fall back to the tree root when there is no "dpt" sub-tree.
    dpt_params = tree.get("dpt", tree)

    # --- ReassembleBlocks ---
    reassemble = dpt_params.get("reassemble_blocks", {})
    for i in range(4):
        # out_projections (Conv2d 1x1)
        _scenic_export_layer(
            sd,
            f"dpt.reassemble.out_projections.{i}",
            reassemble.get(f"out_projection_{i}", {}),
            _scenic_conv_weight,
        )
        # readout_projects (Linear)
        _scenic_export_layer(
            sd,
            f"dpt.reassemble.readout_projects.{i}",
            reassemble.get(f"readout_projects_{i}", {}),
            _scenic_dense_weight,
        )

    # resize_layers: 0 and 1 are ConvTranspose, 2 is Identity (no weights),
    # 3 is a regular Conv.
    resize_converters = (
        (0, _scenic_conv_transpose_weight),
        (1, _scenic_conv_transpose_weight),
        (3, _scenic_conv_weight),
    )
    for idx, weight_fn in resize_converters:
        _scenic_export_layer(
            sd,
            f"dpt.reassemble.resize_layers.{idx}",
            reassemble.get(f"resize_layers_{idx}", {}),
            weight_fn,
        )

    # --- Convs (3x3, no bias) ---
    for i in range(4):
        _scenic_export_layer(
            sd,
            f"dpt.convs.{i}",
            dpt_params.get(f"convs_{i}", {}),
            _scenic_conv_weight,
            with_bias=False,
        )

    # --- Fusion blocks ---
    for i in range(4):
        fusion = dpt_params.get(f"fusion_blocks_{i}", {})
        if i == 0:
            # Block 0 has no residual unit: its single unit is the main unit.
            units = (("PreActResidualConvUnit_0", "main_unit"),)
        else:
            units = (
                ("PreActResidualConvUnit_0", "residual_unit"),
                ("PreActResidualConvUnit_1", "main_unit"),
            )
        for scenic_unit, torch_unit in units:
            unit_params = fusion.get(scenic_unit, {})
            for conv_name in ("conv1", "conv2"):
                # Only kernels are exported for these convs.
                _scenic_export_layer(
                    sd,
                    f"dpt.fusion_blocks.{i}.{torch_unit}.{conv_name}",
                    unit_params.get(conv_name, {}),
                    _scenic_conv_weight,
                    with_bias=False,
                )
        # out_conv (Conv2d 1x1) -- Scenic names it Conv_0.
        _scenic_export_layer(
            sd,
            f"dpt.fusion_blocks.{i}.out_conv",
            fusion.get("Conv_0", fusion.get("out_conv", {})),
            _scenic_conv_weight,
        )

    # --- Project ---
    _scenic_export_layer(
        sd, "dpt.project", dpt_params.get("project", {}), _scenic_conv_weight
    )

    # --- Task head (Dense/Linear) ---
    # sorted() fixes the lookup order; iterating the set directly would make
    # the chosen head depend on string hashing when several are present.
    for head_name in sorted(_SCENIC_HEAD_NAMES):
        if head_name in tree:
            _scenic_export_layer(
                sd, "head", tree[head_name], _scenic_dense_weight
            )
            break

    return sd


def load_decoder_weights(
model: Decoder,
checkpoint_path: str,
) -> Decoder:
"""Load pre-converted PyTorch weights into a Decoder.
"""Load weights into a Decoder from a checkpoint file.

Supports both the legacy flat key format (e.g. ``reassemble.…``) and the
new hierarchical format (e.g. ``dpt.reassemble.…``).
Supports three checkpoint formats:
1. Flax format: keys with ``/`` separators and ``kernel``/``bias``
naming (e.g. ``decoder/dpt/reassemble_blocks/out_projection_0/kernel``).
Weights are automatically transposed from Flax layout to PyTorch layout.
2. Legacy flat format: keys like ``reassemble.…`` that get remapped to
``dpt.reassemble.…``.
3. PyTorch hierarchical format: keys already match the model state_dict.

Args:
model: A Decoder instance (SegmentationDecoder, DepthDecoder, etc.).
checkpoint_path: Path to a checkpoint file.
checkpoint_path: Path to a ``.npz`` checkpoint file.

Returns:
The model with loaded weights.
"""
weights = dict(np.load(checkpoint_path, allow_pickle=False))

sd = {}
for key, value in weights.items():
new_key = key
# Remap legacy flat keys to hierarchical names.
for old_prefix, new_prefix in _LEGACY_KEY_PREFIXES.items():
if key.startswith(old_prefix):
new_key = new_prefix + key[len(old_prefix):]
break
sd[new_key] = torch.from_numpy(value)
if _is_scenic_format(weights):
sd = _convert_scenic_checkpoint(weights)
else:
# Legacy flat or PyTorch hierarchical format.
sd = {}
for key, value in weights.items():
new_key = key
for old_prefix, new_prefix in _LEGACY_KEY_PREFIXES.items():
if key.startswith(old_prefix):
new_key = new_prefix + key[len(old_prefix):]
break
t = torch.from_numpy(value)
# Flip ConvTranspose kernels spatially for Flax->PyTorch parity.
if new_key in _CONV_TRANSPOSE_KEYS and t.ndim == 4:
t = t.flip([2, 3])
sd[new_key] = t

# Add registered buffers not in checkpoint (e.g. bin_centers).
for name, buf in model.named_buffers():
if name not in sd:
sd[name] = buf

model.load_state_dict(sd, strict=True)
print(f"Loaded decoder weights from {checkpoint_path} ({len(sd)} tensors)")
Expand Down