diff --git a/pytorch/decoders.py b/pytorch/decoders.py index 4bbc1b1..72e6595 100644 --- a/pytorch/decoders.py +++ b/pytorch/decoders.py @@ -136,7 +136,8 @@ def forward( x_flat = x.flatten(2).transpose(1, 2) readout = cls_token.unsqueeze(1).expand(-1, x_flat.shape[1], -1) x_cat = torch.cat([x_flat, readout], dim=-1) - x_proj = F.gelu(self.readout_projects[i](x_cat)) + # JAX GELU uses tanh approximation by default. + x_proj = F.gelu(self.readout_projects[i](x_cat), approximate='tanh') x = x_proj.transpose(1, 2).reshape(b, d, h, w) x = self.out_projections[i](x) x = self.resize_layers[i](x) @@ -164,8 +165,10 @@ def __init__( channels: int = 256, post_process_channels: Tuple[int, ...] = (128, 256, 512, 1024), readout_type: str = "project", + output_activation: bool = False, ) -> None: super().__init__() + self.output_activation = output_activation self.reassemble = ReassembleBlocks( input_embed_dim=input_embed_dim, out_channels=post_process_channels, @@ -196,6 +199,10 @@ def forward( out = self.fusion_blocks[i](out, residual=x[-(i + 1)]) out = self.project(out) + # NOTE: By default, the reference implementation does not apply output activation, + # so NO ReLU is applied after the project layer by default. + if self.output_activation: + out = F.relu(out) return out @@ -217,6 +224,7 @@ def __init__( channels: int = 256, post_process_channels: Tuple[int, ...] = (128, 256, 512, 1024), readout_type: str = "project", + output_activation: bool = False, ) -> None: super().__init__() self.channels = channels @@ -226,6 +234,7 @@ def __init__( channels=channels, post_process_channels=post_process_channels, readout_type=readout_type, + output_activation=output_activation, ) # Common head for all dense prediction tasks self.head = nn.Linear(self.channels, self.out_channels) @@ -264,36 +273,50 @@ def __init__(self, num_classes: int = 150, **kwargs) -> None: class DepthDecoder(Decoder): - """Decoder for monocular depth prediction using classification bins.""" + """Decoder for monocular depth prediction using classification bins. - def __init__(self, min_depth: float = 0.001, max_depth: float = 10.0, **kwargs) -> None: - # Decoder requires out_channels, we pass 256 as we use channels as bins, - # although we bypass the head in forward(). - super().__init__(out_channels=256, **kwargs) + Predicts depth by classifying each pixel into uniformly-spaced depth bins + and computing the expected depth value. + """ + + def __init__( + self, + num_depth_bins: int = 256, + min_depth: float = 0.001, + max_depth: float = 10.0, + **kwargs, + ) -> None: + super().__init__(out_channels=num_depth_bins, **kwargs) self.min_depth = min_depth self.max_depth = max_depth - + self.num_depth_bins = num_depth_bins + self.register_buffer( + "bin_centers", + torch.linspace(min_depth, max_depth, num_depth_bins), + ) def forward( self, intermediate_features: List[Tuple[torch.Tensor, torch.Tensor]], image_size: Optional[Tuple[int, int]] = None, ) -> torch.Tensor: - # Bypass super().forward() to avoid the linear head applied there, - # and use raw DPT features as logits. - logits = self.dpt(intermediate_features) # (B, C, H', W') - # Apply ReLU and shift + # 1. Get DPT features + task head (nn.Linear) via parent class. + # Output shape: (B, num_depth_bins, H', W') + logits = super().forward(intermediate_features) + + # 2. Classification-based depth prediction: + # relu + shift -> linear normalisation -> expectation over bins. logits = torch.relu(logits) + self.min_depth - # Normalize to probabilities along the channel dimension probs = logits / torch.sum(logits, dim=1, keepdim=True) - # Compute expectation: sum(prob * bin_center) - bin_centers = torch.linspace( - self.min_depth, self.max_depth, 256, device=logits.device, dtype=logits.dtype - ) - depth_map = torch.einsum("bchw,c->bhw", probs, bin_centers) + depth_map = torch.einsum("bchw,c->bhw", probs, self.bin_centers.to(logits.device)) + + # 3. Upsample to target resolution. if image_size is not None: depth_map = F.interpolate( - depth_map.unsqueeze(1), size=image_size, mode="bilinear", align_corners=False + depth_map.unsqueeze(1), + size=image_size, + mode="bilinear", + align_corners=False, ).squeeze(1) return depth_map.unsqueeze(1) @@ -314,37 +337,226 @@ def __init__(self, **kwargs) -> None: "convs.": "dpt.convs.", "fusion_blocks.": "dpt.fusion_blocks.", "project.": "dpt.project.", + # Task-specific head keys (Scenic Dense -> PyTorch head.*) "segmentation_head.": "head.", + "pixel_segmentation.": "head.", + "pixel_depth_classif.": "head.", + "pixel_depth_regress.": "head.", + "pixel_normals.": "head.", +} + +# Scenic/Flax head param names -> PyTorch head.* +_SCENIC_HEAD_NAMES = { + "pixel_segmentation", + "pixel_depth_classif", + "pixel_depth_regress", + "pixel_normals", + "segmentation_head", +} + +# ConvTranspose keys that need spatial flipping. +_CONV_TRANSPOSE_KEYS = { + 'dpt.reassemble.resize_layers.0.weight', + 'dpt.reassemble.resize_layers.1.weight', } +def _is_scenic_format(keys): + """Check if checkpoint keys use Scenic/Flax naming (``/`` separators).""" + return any('/' in k for k in keys) + + +def _convert_scenic_checkpoint(weights): + """Convert Flax parameter tree checkpoint to PyTorch state_dict. + + These checkpoints use Flax parameter tree naming: + decoder/dpt/reassemble_blocks/out_projection_0/kernel + which maps to PyTorch: + dpt.reassemble.out_projections.0.weight + + Weight conversions: + - Conv kernels: (H, W, Cin, Cout) -> (Cout, Cin, H, W) + - ConvTranspose kernels: same + 180-degree spatial flip + - Dense/Linear kernels: (in, out) -> (out, in) + - Biases: direct copy + """ + sd = {} + + # Build a nested dict from flat Scenic keys + tree = {} + for key, value in weights.items(): + # Strip "decoder/" prefix if present. + k = key[len("decoder/"):] if key.startswith("decoder/") else key + parts = k.split("/") + d = tree + for p in parts[:-1]: + d = d.setdefault(p, {}) + d[parts[-1]] = np.array(value) + + dpt_params = tree.get("dpt", tree) + + # --- ReassembleBlocks --- + rb = dpt_params.get("reassemble_blocks", {}) + for i in range(4): + # out_projections (Conv2d 1x1) + op = rb.get(f"out_projection_{i}", {}) + if "kernel" in op: + sd[f"dpt.reassemble.out_projections.{i}.weight"] = torch.from_numpy( + op["kernel"].transpose(3, 2, 0, 1).copy() + ) + if "bias" in op: + sd[f"dpt.reassemble.out_projections.{i}.bias"] = torch.from_numpy( + op["bias"].copy() + ) + # readout_projects (Linear) + rp = rb.get(f"readout_projects_{i}", {}) + if "kernel" in rp: + sd[f"dpt.reassemble.readout_projects.{i}.weight"] = torch.from_numpy( + rp["kernel"].T.copy() + ) + if "bias" in rp: + sd[f"dpt.reassemble.readout_projects.{i}.bias"] = torch.from_numpy( + rp["bias"].copy() + ) + + # resize_layers: 0=ConvTranspose, 1=ConvTranspose, 2=Identity, 3=Conv + for idx in [0, 1]: + rl = rb.get(f"resize_layers_{idx}", {}) + if "kernel" in rl: + w = rl["kernel"][::-1, ::-1, :, :].copy() # 180-degree spatial flip + sd[f"dpt.reassemble.resize_layers.{idx}.weight"] = torch.from_numpy( + w.transpose(2, 3, 0, 1).copy() + ) + if "bias" in rl: + sd[f"dpt.reassemble.resize_layers.{idx}.bias"] = torch.from_numpy( + rl["bias"].copy() + ) + # resize_layers_2 = Identity (no weights) + rl3 = rb.get("resize_layers_3", {}) + if "kernel" in rl3: + sd["dpt.reassemble.resize_layers.3.weight"] = torch.from_numpy( + rl3["kernel"].transpose(3, 2, 0, 1).copy() + ) + if "bias" in rl3: + sd["dpt.reassemble.resize_layers.3.bias"] = torch.from_numpy( + rl3["bias"].copy() + ) + + # --- Convs (3x3, no bias) --- + for i in range(4): + c = dpt_params.get(f"convs_{i}", {}) + if "kernel" in c: + sd[f"dpt.convs.{i}.weight"] = torch.from_numpy( + c["kernel"].transpose(3, 2, 0, 1).copy() + ) + + # --- Fusion blocks --- + for i in range(4): + fb = dpt_params.get(f"fusion_blocks_{i}", {}) + if i == 0: + # No residual unit, only 1 PreActResidualConvUnit -> main_unit + pacu = fb.get("PreActResidualConvUnit_0", {}) + for cname in ["conv1", "conv2"]: + if cname in pacu and "kernel" in pacu[cname]: + sd[f"dpt.fusion_blocks.{i}.main_unit.{cname}.weight"] = ( + torch.from_numpy( + pacu[cname]["kernel"].transpose(3, 2, 0, 1).copy() + ) + ) + else: + # Residual unit (index 0) + main unit (index 1) + pacu0 = fb.get("PreActResidualConvUnit_0", {}) + pacu1 = fb.get("PreActResidualConvUnit_1", {}) + for cname in ["conv1", "conv2"]: + if cname in pacu0 and "kernel" in pacu0[cname]: + sd[f"dpt.fusion_blocks.{i}.residual_unit.{cname}.weight"] = ( + torch.from_numpy( + pacu0[cname]["kernel"].transpose(3, 2, 0, 1).copy() + ) + ) + if cname in pacu1 and "kernel" in pacu1[cname]: + sd[f"dpt.fusion_blocks.{i}.main_unit.{cname}.weight"] = ( + torch.from_numpy( + pacu1[cname]["kernel"].transpose(3, 2, 0, 1).copy() + ) + ) + # out_conv (Conv2d 1x1) -- Scenic names it Conv_0 + oc = fb.get("Conv_0", fb.get("out_conv", {})) + if "kernel" in oc: + sd[f"dpt.fusion_blocks.{i}.out_conv.weight"] = torch.from_numpy( + oc["kernel"].transpose(3, 2, 0, 1).copy() + ) + if "bias" in oc: + sd[f"dpt.fusion_blocks.{i}.out_conv.bias"] = torch.from_numpy( + oc["bias"].copy() + ) + + # --- Project --- + proj = dpt_params.get("project", {}) + if "kernel" in proj: + sd["dpt.project.weight"] = torch.from_numpy( + proj["kernel"].transpose(3, 2, 0, 1).copy() + ) + if "bias" in proj: + sd["dpt.project.bias"] = torch.from_numpy(proj["bias"].copy()) + + # --- Task head (Dense/Linear) --- + for head_name in _SCENIC_HEAD_NAMES: + if head_name in tree: + h = tree[head_name] + if "kernel" in h: + sd["head.weight"] = torch.from_numpy(h["kernel"].T.copy()) + if "bias" in h: + sd["head.bias"] = torch.from_numpy(h["bias"].copy()) + break + + return sd + + def load_decoder_weights( model: Decoder, checkpoint_path: str, ) -> Decoder: - """Load pre-converted PyTorch weights into a Decoder. + """Load weights into a Decoder from a checkpoint file. - Supports both the legacy flat key format (e.g. ``reassemble.…``) and the - new hierarchical format (e.g. ``dpt.reassemble.…``). + Supports three checkpoint formats: + 1. Flax format: keys with ``/`` separators and ``kernel``/``bias`` + naming (e.g. ``decoder/dpt/reassemble_blocks/out_projection_0/kernel``). + Weights are automatically transposed from Flax layout to PyTorch layout. + 2. Legacy flat format: keys like ``reassemble.…`` that get remapped to + ``dpt.reassemble.…``. + 3. PyTorch hierarchical format: keys already match the model state_dict. Args: model: A Decoder instance (SegmentationDecoder, DepthDecoder, etc.). - checkpoint_path: Path to a checkpoint file. + checkpoint_path: Path to a ``.npz`` checkpoint file. Returns: The model with loaded weights. """ weights = dict(np.load(checkpoint_path, allow_pickle=False)) - sd = {} - for key, value in weights.items(): - new_key = key - # Remap legacy flat keys to hierarchical names. - for old_prefix, new_prefix in _LEGACY_KEY_PREFIXES.items(): - if key.startswith(old_prefix): - new_key = new_prefix + key[len(old_prefix):] - break - sd[new_key] = torch.from_numpy(value) + if _is_scenic_format(weights): + sd = _convert_scenic_checkpoint(weights) + else: + # Legacy flat or PyTorch hierarchical format. + sd = {} + for key, value in weights.items(): + new_key = key + for old_prefix, new_prefix in _LEGACY_KEY_PREFIXES.items(): + if key.startswith(old_prefix): + new_key = new_prefix + key[len(old_prefix):] + break + t = torch.from_numpy(value) + # Flip ConvTranspose kernels spatially for Flax->PyTorch parity. + if new_key in _CONV_TRANSPOSE_KEYS and t.ndim == 4: + t = t.flip([2, 3]) + sd[new_key] = t + + # Add registered buffers not in checkpoint (e.g. bin_centers). + for name, buf in model.named_buffers(): + if name not in sd: + sd[name] = buf model.load_state_dict(sd, strict=True) print(f"Loaded decoder weights from {checkpoint_path} ({len(sd)} tensors)")