Description
Implement the structured sparsity variants described in Sections 3.2 and 4.3 of the L0 paper.
Variants to Implement
1. Neuron-wise Sparsity (MLPs)
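```python
import torch.nn as nn
import torch.nn.functional as F

class L0NeuronLinear(nn.Module):
    """Prune entire neurons (rows in weight matrix)."""
    def __init__(self, in_features, out_features, **kwargs):
        super().__init__()
        # Reuse nn.Linear's default init; weight is (out_features, in_features)
        linear = nn.Linear(in_features, out_features)
        self.weight, self.bias = linear.weight, linear.bias
        # One gate per output neuron (HardConcrete is the existing gate module)
        self.neuron_gates = HardConcrete(out_features, **kwargs)

    def forward(self, x):
        gates = self.neuron_gates()                       # (out,)
        masked_weight = self.weight * gates.unsqueeze(1)  # zero whole rows: (out, in)
        masked_bias = self.bias * gates                   # a pruned neuron outputs exactly 0
        return F.linear(x, masked_weight, masked_bias)
```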
2. Filter-wise Sparsity (CNNs)
Already partially implemented; the existing code still needs to be checked to ensure it matches the paper exactly.
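One detail worth confirming when checking the filter-wise variant is how the penalty counts a pruned filter. The sketch below assumes each gate's probability of being non-zero is weighted by the number of weights it controls, |g| = in_channels * kh * kw for PyTorch's `(out_channels, in_channels, kh, kw)` conv weight layout. That weighting is an assumption to verify against the paper, not a confirmed detail of it, and the helper name is hypothetical.

```python
import math

# Hard concrete constants from the L0 paper.
BETA, GAMMA, ZETA = 2 / 3, -0.1, 1.1


def prob_nonzero(log_alpha):
    """P(gate != 0) for one hard concrete gate."""
    return 1.0 / (1.0 + math.exp(-(log_alpha - BETA * math.log(-GAMMA / ZETA))))


def expected_filter_l0(log_alphas, conv_weight_shape):
    """Expected count of non-zero conv weights with one gate per output filter.

    conv_weight_shape is (out_channels, in_channels, kh, kw). ASSUMPTION: each
    gate is weighted by its group size |g| = in_channels * kh * kw.
    """
    out_channels, in_channels, kh, kw = conv_weight_shape
    if len(log_alphas) != out_channels:
        raise ValueError("expected one log_alpha per output filter")
    group_size = in_channels * kh * kw
    return sum(group_size * prob_nonzero(a) for a in log_alphas)
```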
3. Block-wise Sparsity
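```python
import torch.nn as nn

class L0BlockSparse(nn.Module):
    """Prune blocks of weights (e.g., 4x4 blocks)."""
    def __init__(self, in_features, out_features, block_size=4):
        super().__init__()
        # Floor division below would silently leave edge weights ungated
        if in_features % block_size or out_features % block_size:
            raise ValueError("in_features and out_features must be divisible by block_size")
        self.block_size = block_size
        # One gate per block
        n_blocks = (out_features // block_size) * (in_features // block_size)
        self.block_gates = HardConcrete(n_blocks)
        ...
```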
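The `forward` of `L0BlockSparse` is elided above, but any version of it has to map each block gate back onto the weights it covers. Here is a stdlib-only sketch of that mapping, assuming blocks are numbered row-major (block `(r, c)` has index `r * (in_features // block_size) + c`); that ordering is an illustrative choice rather than something this issue specifies, and the function name is hypothetical.

```python
def expand_block_gates(block_gates, out_features, in_features, block_size):
    """Turn one gate per block into a dense (out_features x in_features) mask.

    ASSUMPTION: blocks are numbered row-major, so block (r, c) has index
    r * (in_features // block_size) + c.
    """
    if out_features % block_size or in_features % block_size:
        raise ValueError("both dimensions must be divisible by block_size")
    blocks_per_row = in_features // block_size
    expected = (out_features // block_size) * blocks_per_row
    if len(block_gates) != expected:
        raise ValueError(f"expected {expected} gates, got {len(block_gates)}")
    # Weight (i, j) belongs to block (i // block_size, j // block_size).
    return [
        [block_gates[(i // block_size) * blocks_per_row + j // block_size]
         for j in range(in_features)]
        for i in range(out_features)
    ]
```

With tensors, the same expansion can be written as `gates.view(out_features // b, in_features // b).repeat_interleave(b, dim=0).repeat_interleave(b, dim=1)` and multiplied into the weight, mirroring how `L0NeuronLinear.forward` applies its row gates.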
Validation
- Reproduce Table 2 from the paper (structured sparsity results)
- Compare the resulting speedups against unstructured sparsity
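For the speedup comparison, the advantage of structured sparsity is that pruned neurons can be removed outright, leaving smaller dense matrices. The stdlib-only sketch below uses hypothetical helper names and a toy two-layer ReLU MLP: it drops every hidden neuron whose gate is exactly 0 (the matching rows of the first layer and columns of the second), checks that the output is unchanged, and counts multiply-accumulates before and after. An unstructured mask with the same number of zeros would leave both matrices at full size, so a dense kernel would still perform every multiply-accumulate; any speedup there depends on sparse kernels.

```python
def matvec(w, x):
    """W @ x for a matrix stored as a list of rows."""
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


def mlp(w1, b1, w2, x):
    """Toy two-layer MLP: w2 @ relu(w1 @ x + b1), no output bias."""
    h = [max(0.0, v + b) for v, b in zip(matvec(w1, x), b1)]
    return matvec(w2, h)


def apply_gates(w1, b1, gates):
    """Gated but still full-size first layer: scale row k and bias k by gate k."""
    w1g = [[g * w for w in row] for row, g in zip(w1, gates)]
    b1g = [g * b for b, g in zip(b1, gates)]
    return w1g, b1g


def compact(w1, b1, w2, gates):
    """Drop hidden neurons whose gate is exactly 0; fold the other gates in."""
    keep = [k for k, g in enumerate(gates) if g != 0.0]
    w1c = [[gates[k] * w for w in w1[k]] for k in keep]
    b1c = [gates[k] * b1[k] for k in keep]
    w2c = [[row[k] for k in keep] for row in w2]  # matching columns of layer 2
    return w1c, b1c, w2c


def macs(w1, w2):
    """Multiply-accumulates of one dense forward pass through both layers."""
    return sum(len(row) for row in w1) + sum(len(row) for row in w2)


# Toy numbers: 2 inputs, 3 hidden neurons, 1 output; the middle neuron is pruned.
w1 = [[1.0, -2.0], [0.5, 0.5], [3.0, 1.0]]
b1 = [0.1, -0.2, 0.3]
w2 = [[1.0, 2.0, -1.0]]
gates = [1.0, 0.0, 0.5]
x = [0.4, -0.3]

w1c, b1c, w2c = compact(w1, b1, w2, gates)
print(mlp(*apply_gates(w1, b1, gates), w2, x), mlp(w1c, b1c, w2c, x))
print(macs(w1, w2), "->", macs(w1c, w2c))
```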