From 944706014e901cc09e45bc970d4c9b1d18485fa5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E5=B9=BD=E9=9B=85?= Date: Thu, 9 Apr 2026 23:28:45 +0200 Subject: [PATCH 1/2] exp2_jepa --- JEPA/.DS_Store | Bin 0 -> 10244 bytes JEPA/configs/.DS_Store | Bin 0 -> 6148 bytes JEPA/configs/base.yaml | 118 +++ JEPA/configs/exp2_lora_encoder_predictor.yaml | 31 + JEPA/scripts/.DS_Store | Bin 0 -> 6148 bytes JEPA/scripts/_bootstrap.py | 12 + JEPA/scripts/decoder/test_lora.py | 14 + JEPA/scripts/decoder/test_projector.py | 14 + JEPA/scripts/decoder/train_lora.py | 14 + JEPA/scripts/decoder/train_projector.py | 14 + JEPA/scripts/encoder/embed.py | 14 + JEPA/scripts/encoder/train.py | 14 + JEPA/src/.DS_Store | Bin 0 -> 6148 bytes JEPA/src/__init__.py | 1 + JEPA/src/losses.py | 123 +++ JEPA/src/models.py | 316 +++++++ JEPA/src/tasks/.DS_Store | Bin 0 -> 6148 bytes JEPA/src/tasks/__init__.py | 1 + JEPA/src/tasks/decoder/__init__.py | 1 + JEPA/src/tasks/decoder/test_lora.py | 436 +++++++++ JEPA/src/tasks/decoder/test_projector.py | 381 ++++++++ JEPA/src/tasks/decoder/train_lora.py | 874 ++++++++++++++++++ JEPA/src/tasks/decoder/train_projector.py | 550 +++++++++++ JEPA/src/tasks/encoder/.DS_Store | Bin 0 -> 6148 bytes JEPA/src/tasks/encoder/__init__.py | 1 + JEPA/src/tasks/encoder/embed.py | 425 +++++++++ JEPA/src/tasks/encoder/train.py | 805 ++++++++++++++++ JEPA/src/utils.py | 318 +++++++ 28 files changed, 4477 insertions(+) create mode 100644 JEPA/.DS_Store create mode 100644 JEPA/configs/.DS_Store create mode 100644 JEPA/configs/base.yaml create mode 100644 JEPA/configs/exp2_lora_encoder_predictor.yaml create mode 100644 JEPA/scripts/.DS_Store create mode 100644 JEPA/scripts/_bootstrap.py create mode 100644 JEPA/scripts/decoder/test_lora.py create mode 100644 JEPA/scripts/decoder/test_projector.py create mode 100644 JEPA/scripts/decoder/train_lora.py create mode 100644 JEPA/scripts/decoder/train_projector.py create mode 100644 JEPA/scripts/encoder/embed.py create mode 100644 
JEPA/scripts/encoder/train.py create mode 100644 JEPA/src/.DS_Store create mode 100644 JEPA/src/__init__.py create mode 100644 JEPA/src/losses.py create mode 100644 JEPA/src/models.py create mode 100644 JEPA/src/tasks/.DS_Store create mode 100644 JEPA/src/tasks/__init__.py create mode 100644 JEPA/src/tasks/decoder/__init__.py create mode 100644 JEPA/src/tasks/decoder/test_lora.py create mode 100644 JEPA/src/tasks/decoder/test_projector.py create mode 100644 JEPA/src/tasks/decoder/train_lora.py create mode 100644 JEPA/src/tasks/decoder/train_projector.py create mode 100644 JEPA/src/tasks/encoder/.DS_Store create mode 100644 JEPA/src/tasks/encoder/__init__.py create mode 100644 JEPA/src/tasks/encoder/embed.py create mode 100644 JEPA/src/tasks/encoder/train.py create mode 100644 JEPA/src/utils.py diff --git a/JEPA/.DS_Store b/JEPA/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..a7de149a2e13354c45bfd17a2f2a974126e571c9 GIT binary patch literal 10244 zcmeHMTWl3Y7@lui=$z%Yg%-NV!Np>2p+Id*0YN-%DHlPDoYIz3IGoEuH*9y0XZM_f zgji8uyhWo?6N8HJ&4Bu1ln3IQ7!w~1A)+x}9(?h|geMbk|LpAHa*7W=plEiJoqy+_ z`Df?<=KFUt|1!qVQ_S7Xn86t1bPK7?q~ZpbXczaCO9?fi5oFJp#SSqqGvVbk_LSa{ zB0?ZSAVMHQAVMHQ;8s9@_H1$S8B#`Ngg}HqguoR9#P`9VTgXHxC#4KN9aIEI0Fu?D z4la78b3kBYgfbDzNht$U8dIJg5Sk)9VnCWxyCJ%hOoVb$N@>m@%^AWsBRruXcstn{ zL3f6Plu;QW5Fs!X0TDJeEX!Qxu!Gm0-`xzge3;4Gmgl#%U4&9qJ#Bgoui-QJzT}uY zk@VAk(d$VU_KMno>)M&}b2>9 zc4{<7(|W+oi2-v#i>@1Ya=usSdkZF&QmyJO%g7)-;E`wHELX0gDR={iX^o|++R@2l zT5h+U&X{(nTDy)^+lNXL-cWh(Mjoq0>u}Z`A18M_sn(4i&i9Kq??`)of8KPi%xUu$ z&b9L+-dMU|_D>MWLW|AYIS-bM^unkwXZdbH0%GX>obM6Cb?0(4lD42>{zJQToz}&2 zN)}O?)>|@L`+{@cm{=y4CjhlE(I#s#+N^Z8lD4L;Y?PTS&yKQZ*h%&}JIy{|=h%7n zIlIKZVqddw*bnR{cA5QxDpX@SYM?>KLfnQpmZ1ZkSdDJnkL}oj2a&`m_F+E`z=Q)A zc@%I2kKqI!$CG##FXI)wir4TqPU8&T#W{R}^Y|2>;Ud1pclaKc@hg7A?+RCHm8D9P za+lJqv?%S$8YLm8u9QP&MAplt(oIhl9|8&)O5}-l0UqB&k4L;QKr=h z4s$^^{dRRpT$@JRm(`_eQ(T))yqDFM`0}`>P`JtJ%6OZqRZ%d-0^5%Ga#aJDl{M-* 
zRja13lGP3BMpa`J=(5@^1|q8deXhRG&a#h*s~3o?KN3rSW`BTVHX5)PO=w07TCoCM zSc`R7j}7R<7Hkc;I)Fh8VGl--C8i$47%XC{k0MHVG~nwK#MkHWJYK+ycnK%*2HwOe zyoGo09zLpI;!s73I%hIzDQ6==K4Gryd0{vT{u*LhXpa~1hq zU|E@5|IhmxdQ;(}BtTR~2t)`(2t)`(2t)|nL02Xy^!m2!Vea z0aR~FZt9^~lyd`>6H-WF(0zbzafyDDQU)eeh%Q1Oq~oa~@`}x&F3g$)?TJt~DP>?z gc7Zubna#iHKLcbID?0y2=l>vieXAs_|HH1@V-^m;4Wg<&0T*E43hX&L&p$$qDprKhvt+--jT7}7np#A3 zem<@ulZcFPQ@L2!n>{z**++&mCkOWA81W14cNZlEfg7;MkzE(HCqgga^y>{tEnwC%0;vJ&^%eQ zLs35+`xjp>T08udRk&ztJBg)kg0KYx_z?{={zh2VrYi&oAz%xZeolR!=8{K?_2T55qW6 z!>Ss@L85a#(;#vpSI-wlqs{V0QI;#EaZ!R@F3QUK#(10)OUtWUd-e0E6RS&{Hhhop zL26mBIDsoL-mNM+Tl1Q#)18D&M|4!JX4fqqZm-joyZh~rqJnQkb=zN!cH#BnNX2I` z!ekY%rbZ?)KnxHAKWD%mE@J-YT~QCj05R~-7{K#Ef+D&WQ-k{GfI>e30JGrM0ye%S zFh^Q+Ev5#c1%#VaK$FVt6N8&{@Jk!#T1*X^bjI!DgWHk0eW7qYI`l7fIODECDv1GN z;4K3)rkcn5fA{nE|Jx+05d*})zhZ!AYhJAiOR{(C)Z*~2m7u4fC>WO-d`VAb3FNM?lj+1u^ie3_Jl{p>yK^ literal 0 HcmV?d00001 diff --git a/JEPA/scripts/_bootstrap.py b/JEPA/scripts/_bootstrap.py new file mode 100644 index 0000000..f44d248 --- /dev/null +++ b/JEPA/scripts/_bootstrap.py @@ -0,0 +1,12 @@ +from __future__ import annotations + +import sys +from pathlib import Path + + +def bootstrap() -> None: + repo_root = Path(__file__).resolve().parents[1] + src_dir = repo_root / "src" + src_str = str(src_dir) + if src_str not in sys.path: + sys.path.insert(0, src_str) diff --git a/JEPA/scripts/decoder/test_lora.py b/JEPA/scripts/decoder/test_lora.py new file mode 100644 index 0000000..e4a3cbd --- /dev/null +++ b/JEPA/scripts/decoder/test_lora.py @@ -0,0 +1,14 @@ +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parents[1])) + +from _bootstrap import bootstrap + +bootstrap() + +from jepa.tasks.decoder.test_lora import main + + +if __name__ == "__main__": + main() diff --git a/JEPA/scripts/decoder/test_projector.py b/JEPA/scripts/decoder/test_projector.py new file mode 100644 index 0000000..00b5af5 --- /dev/null +++ b/JEPA/scripts/decoder/test_projector.py @@ -0,0 
+1,14 @@ +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parents[1])) + +from _bootstrap import bootstrap + +bootstrap() + +from jepa.tasks.decoder.test_projector import main + + +if __name__ == "__main__": + main() diff --git a/JEPA/scripts/decoder/train_lora.py b/JEPA/scripts/decoder/train_lora.py new file mode 100644 index 0000000..7a7a5b9 --- /dev/null +++ b/JEPA/scripts/decoder/train_lora.py @@ -0,0 +1,14 @@ +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parents[1])) + +from _bootstrap import bootstrap + +bootstrap() + +from jepa.tasks.decoder.train_lora import parse_args, train + + +if __name__ == "__main__": + train(parse_args()) diff --git a/JEPA/scripts/decoder/train_projector.py b/JEPA/scripts/decoder/train_projector.py new file mode 100644 index 0000000..8a8cc5d --- /dev/null +++ b/JEPA/scripts/decoder/train_projector.py @@ -0,0 +1,14 @@ +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parents[1])) + +from _bootstrap import bootstrap + +bootstrap() + +from jepa.tasks.decoder.train_projector import parse_args, train + + +if __name__ == "__main__": + train(parse_args()) diff --git a/JEPA/scripts/encoder/embed.py b/JEPA/scripts/encoder/embed.py new file mode 100644 index 0000000..1a4260b --- /dev/null +++ b/JEPA/scripts/encoder/embed.py @@ -0,0 +1,14 @@ +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parents[1])) + +from _bootstrap import bootstrap + +bootstrap() + +from jepa.tasks.encoder.embed import main + + +if __name__ == "__main__": + main() diff --git a/JEPA/scripts/encoder/train.py b/JEPA/scripts/encoder/train.py new file mode 100644 index 0000000..e1aa07e --- /dev/null +++ b/JEPA/scripts/encoder/train.py @@ -0,0 +1,14 @@ +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parents[1])) + +from _bootstrap import bootstrap + +bootstrap() + +from 
jepa.tasks.encoder.train import main + + +if __name__ == "__main__": + main() diff --git a/JEPA/src/.DS_Store b/JEPA/src/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..0909df4bbf6c38dd9103c7d0fbc144d9f0f9de85 GIT binary patch literal 6148 zcmeHKy-or_5Z(opCB%e66U$93Y)AwoSlFD1FW?z1)Zm>*qbxwal)IhJb{ z)^BchDzzf7R7;~G2fI?_)$Q77lw)ffTf2wNPS6d-ttOJf7pG*!;27?}nAlNpRkklh zw>KW19_@-o!+fvo{^99$^Z8}(bAM%c7;e=Wmu|)=ufoTTPx?7O<|H1e5|8x@e zkO5@iUok*44ZBf?Et$J@W^?kc<)9a!Qe<4B@goHcbrnObyo%RBm4Kg11JE* z_tgt +- emb_std_mean: mean std across embedding dimensions (collapse indicator) +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Dict, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +# ------------------------- +# Core losses +# ------------------------- +class CosineAlignLoss(nn.Module): + """1 - cosine similarity (mean over batch).""" + + def __init__(self, eps: float = 1e-8): + super().__init__() + self.eps = float(eps) + + def forward(self, z_pred: torch.Tensor, z_tgt: torch.Tensor) -> torch.Tensor: + z_pred = F.normalize(z_pred, dim=-1, eps=self.eps) + z_tgt = F.normalize(z_tgt, dim=-1, eps=self.eps) + return 1.0 - (z_pred * z_tgt).sum(dim=-1).mean() + + +class VarianceLoss(nn.Module): + """ + VICReg variance term: + std = sqrt(var + eps) + loss = mean( relu(target_std - std) ) + """ + + def __init__(self, target_std: float = 1.0, eps: float = 1e-4): + super().__init__() + self.target_std = float(target_std) + self.eps = float(eps) + + def forward(self, z: torch.Tensor) -> torch.Tensor: + # z: [B, D] + std = torch.sqrt(z.var(dim=0, unbiased=False) + self.eps) + return F.relu(self.target_std - std).mean() + + +class EmaPredictiveLoss(nn.Module): + """ + Minimal JEPA-style loss for exp2 (EMA target + predictor): + + L = w_align * align(z_pred, z_tgt) + w_var * (var(z_ctx) + var(z_tgt)) + """ + def __init__( + self, + w_align: float = 1.0, + w_var: float = 1.0, 
+ align_eps: float = 1e-8, + var_target_std: float = 1.0, + var_eps: float = 1e-4, + ): + super().__init__() + self.w_align = float(w_align) + self.w_var = float(w_var) + self.align = CosineAlignLoss(eps=align_eps) + self.var = VarianceLoss(target_std=var_target_std, eps=var_eps) + + def forward(self, z_ctx, z_pred, z_tgt): + l_align = self.align(z_pred, z_tgt) + l_var = self.var(z_ctx) + self.var(z_tgt) + loss = self.w_align * l_align + self.w_var * l_var + return {"loss": loss, "align": l_align, "var": l_var} + + +def build_loss(cfg: Dict[str, Any]) -> EmaPredictiveLoss: + lcfg = cfg.get("loss", {}) + return EmaPredictiveLoss( + w_align=float(lcfg.get("w_align", 1.0)), + w_var=float(lcfg.get("w_var", 1.0)), + align_eps=float(lcfg.get("align_eps", 1e-8)), + var_target_std=float(lcfg.get("var_target_std", 1.0)), + var_eps=float(lcfg.get("var_eps", 1e-4)), + ) + + +# ------------------------- +# Metrics (no grad) +# ------------------------- +@torch.no_grad() +def retrieval_top1_acc(z_pred: torch.Tensor, z_tgt: torch.Tensor, eps: float = 1e-8) -> torch.Tensor: + """ + Batch retrieval accuracy: + For each i, find argmax_j cos(z_pred[i], z_tgt[j]) + count how many i match j=i + AMP-safe: cast to float32 for stable metric computation. + """ + z_pred = z_pred.float() + z_tgt = z_tgt.float() + + z_pred = F.normalize(z_pred, dim=-1, eps=eps) + z_tgt = F.normalize(z_tgt, dim=-1, eps=eps) + + sim = z_pred @ z_tgt.t() # [B, B] + pred_idx = sim.argmax(dim=1) + gt_idx = torch.arange(sim.size(0), device=sim.device) + return (pred_idx == gt_idx).float().mean() + + +@torch.no_grad() +def emb_std_mean(z: torch.Tensor) -> torch.Tensor: + z = z.float() + return z.std(dim=0, unbiased=False).mean() \ No newline at end of file diff --git a/JEPA/src/models.py b/JEPA/src/models.py new file mode 100644 index 0000000..41bbc6c --- /dev/null +++ b/JEPA/src/models.py @@ -0,0 +1,316 @@ +# models.py +# -*- coding: utf-8 -*- +""" +Models module for JEPA-style training on code pairs. 
+ +Goals: +- Flexible encoder: ModernBERT (or any HF AutoModel) with modes: + - frozen: no trainable params + - full: full fine-tuning + - lora: LoRA fine-tuning via PEFT (optional dependency) +- Flexible predictor: + - vit1d: your current "ViT on 1D embedding" predictor + - mlp: simple MLP predictor (swap-in baseline) +- Provide small "factory" functions: + - build_encoder(cfg, device) + - build_predictor(cfg, emb_dim, device) + - infer_emb_dim(model_name) +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Dict, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from transformers import AutoModel + + +# ------------------------- +# Helpers +# ------------------------- +def infer_emb_dim(model_name: str) -> int: + """Infer hidden size from HF config (no weights needed).""" + m = AutoModel.from_pretrained(model_name) + return int(getattr(m.config, "hidden_size")) + + +def _freeze_module(m: nn.Module) -> None: + for p in m.parameters(): + p.requires_grad = False + + +def _unfreeze_module(m: nn.Module) -> None: + for p in m.parameters(): + p.requires_grad = True + + +# ------------------------- +# Encoder +# ------------------------- +class HFMeanPoolEncoder(nn.Module): + """ + Encoder wrapper: + - backbone: HF AutoModel + - output: [B, D] mean-pooled embedding (mask-aware) + """ + + def __init__(self, model_name: str): + super().__init__() + self.backbone = AutoModel.from_pretrained(model_name) + + @property + def emb_dim(self) -> int: + return int(getattr(self.backbone.config, "hidden_size")) + + def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: + out = self.backbone(input_ids=input_ids, attention_mask=attention_mask) + h = out.last_hidden_state # [B, L, D] + mask = attention_mask.unsqueeze(-1).to(h.dtype) # [B, L, 1] + pooled = (h * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0) + return pooled # [B, D] + + +def 
_apply_lora(backbone: nn.Module, lora_cfg: Dict[str, Any]) -> nn.Module: + """ + Apply PEFT LoRA to a HF model. + lora_cfg expected keys (examples): + enabled: bool + r: int + alpha: int + dropout: float + target_modules: list[str] + bias: str (optional, default "none") + task_type: str (optional; for encoders we can omit) + """ + try: + from peft import LoraConfig, get_peft_model + except Exception as e: # pragma: no cover + raise RuntimeError( + "PEFT (peft) is required for LoRA mode but is not available in this environment. " + "Install it via `pip install peft`." + ) from e + + if not lora_cfg.get("target_modules"): + raise ValueError( + "LoRA requires `encoder.lora.target_modules` to be a non-empty list. " + "Please set it in your config." + ) + + bias = lora_cfg.get("bias", "none") + # For encoder-style models, PEFT task_type is not strictly required for plain LoRA. + # If you want, you can pass task_type=lora_cfg.get("task_type", None) + lcfg = LoraConfig( + r=int(lora_cfg.get("r", 16)), + lora_alpha=int(lora_cfg.get("alpha", 32)), + lora_dropout=float(lora_cfg.get("dropout", 0.05)), + target_modules=list(lora_cfg["target_modules"]), + bias=bias, + ) + return get_peft_model(backbone, lcfg) + + +def build_encoder(cfg: Dict[str, Any], device: torch.device) -> Tuple[Optional[nn.Module], int]: + """ + Build encoder according to cfg["encoder"]. + + Expected cfg structure (example): + cfg["encoder"] = { + "name": "answerdotai/ModernBERT-large", + "train_mode": "frozen" | "full" | "lora", + "lora": {"enabled": true, "r":16, "alpha":32, "dropout":0.05, "target_modules":[...]} + } + + Returns: + (encoder_module_or_None, emb_dim) + + Note: + If you later want "cached embeddings" mode, you can set cfg["encoder"]["name"]=None + and handle it upstream; here we return (None, emb_dim) only if emb_dim is provided. + """ + enc_cfg = cfg.get("encoder", {}) + model_name = enc_cfg.get("name", None) + + # If no encoder name, assume upstream will provide embeddings. 
+ if not model_name: + emb_dim = enc_cfg.get("emb_dim") + if emb_dim is None: + raise ValueError("encoder.name is None but encoder.emb_dim is not set.") + return None, int(emb_dim) + + encoder = HFMeanPoolEncoder(model_name=model_name).to(device) + emb_dim = encoder.emb_dim + + train_mode = (enc_cfg.get("train_mode") or "frozen").lower() + if train_mode not in ("frozen", "full", "lora"): + raise ValueError(f"Unknown encoder.train_mode: {train_mode}") + + if train_mode == "frozen": + _freeze_module(encoder) + encoder.eval() # usually frozen encoder in eval mode + elif train_mode == "full": + _unfreeze_module(encoder) + encoder.train() + else: # lora + lora_cfg = enc_cfg.get("lora", {}) + if not lora_cfg.get("enabled", True): + raise ValueError("encoder.train_mode is 'lora' but encoder.lora.enabled is false.") + # Apply LoRA to the backbone only (keeps wrapper intact) + encoder.backbone = _apply_lora(encoder.backbone, lora_cfg) + _unfreeze_module(encoder) # PEFT will set only LoRA params trainable; others frozen. + encoder.train() + + return encoder, emb_dim + + +# ------------------------- +# Predictors +# ------------------------- +class ViTPredictor1D(nn.Module): + """ + 1D "ViT-like" predictor operating on pooled embeddings [B, D]. + + It reshapes embedding into tokens: [B, T, patch], projects to d_model, + runs TransformerEncoder, projects back to patch and flattens. 
+ """ + + def __init__( + self, + dim: int, + patch: int = 32, + layers: int = 4, + heads: int = 8, + mlp_ratio: float = 4.0, + dropout: float = 0.0, + ): + super().__init__() + if dim % patch != 0: + raise ValueError(f"emb_dim={dim} must be divisible by patch={patch}") + self.dim = dim + self.patch = patch + self.num_tokens = dim // patch + + self.token_proj = nn.Linear(patch, dim) + + encoder_layer = nn.TransformerEncoderLayer( + d_model=dim, + nhead=heads, + dim_feedforward=int(dim * mlp_ratio), + dropout=float(dropout), + activation="gelu", + batch_first=True, + norm_first=True, + ) + self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=layers) + self.out = nn.Linear(dim, patch) + + def forward(self, z: torch.Tensor) -> torch.Tensor: + # z: [B, D] + b, d = z.shape + x = z.view(b, self.num_tokens, self.patch) # [B, T, patch] + x = self.token_proj(x) # [B, T, D] + x = self.encoder(x) # [B, T, D] + x = self.out(x) # [B, T, patch] + return x.reshape(b, d) # [B, D] + + +class MLPredictor(nn.Module): + """ + MLP predictor: [B, D] -> [B, D] + + Two styles supported: + 1) hidden_sizes list (your reference): e.g. 
[4096, 2048, 1024] + => Linear(D->4096)->Act->Linear(4096->2048)->Act->Linear(2048->1024)->Act->Linear(1024->D) + 2) hidden + layers (compact): layers>=1 + => (layers-1) blocks of Linear->Act(+LN/Dropout) then final Linear->D + """ + + def __init__( + self, + dim: int, + hidden_sizes: Optional[list[int]] = None, + hidden: int = 2048, + layers: int = 3, + activation: str = "relu", # "relu" | "gelu" + dropout: float = 0.0, + use_layernorm: bool = False, + residual: bool = False, + out_layernorm: bool = False, + ): + super().__init__() + + act = nn.ReLU() if activation.lower() == "relu" else nn.GELU() + + sizes = None + if hidden_sizes is not None and len(hidden_sizes) > 0: + sizes = [int(x) for x in hidden_sizes] + else: + if layers < 1: + raise ValueError("MLPredictor layers must be >= 1") + sizes = [hidden] * max(0, layers - 1) + + mods = [] + in_f = dim + for h in sizes: + mods.append(nn.Linear(in_f, h)) + if use_layernorm: + mods.append(nn.LayerNorm(h)) + mods.append(act) + if dropout and dropout > 0: + mods.append(nn.Dropout(float(dropout))) + in_f = h + + mods.append(nn.Linear(in_f, dim)) + self.net = nn.Sequential(*mods) + + self.residual = bool(residual) + self.out_ln = nn.LayerNorm(dim) if out_layernorm else nn.Identity() + + def forward(self, z: torch.Tensor) -> torch.Tensor: + y = self.net(z) + if self.residual: + y = y + z + return self.out_ln(y) + + +def build_predictor(cfg: Dict[str, Any], emb_dim: int, device: torch.device) -> nn.Module: + p_cfg = cfg.get("predictor", {}) + name = (p_cfg.get("name") or "vit1d").lower() + + if name == "vit1d": + v = p_cfg.get("vit", {}) + m = ViTPredictor1D( + dim=emb_dim, + patch=int(v.get("patch", 32)), + layers=int(v.get("layers", 4)), + heads=int(v.get("heads", 8)), + mlp_ratio=float(v.get("mlp_ratio", 4.0)), + dropout=float(v.get("dropout", 0.0)), + ).to(device) + return m + + if name == "mlp": + mcfg = p_cfg.get("mlp", {}) + # preferred: hidden_sizes list (matches your reference implementation) + hidden_sizes = 
mcfg.get("hidden_sizes", None) + if hidden_sizes is not None: + # YAML might load it as list already; if string, try parse later upstream + hidden_sizes = list(hidden_sizes) + + m = MLPredictor( + dim=emb_dim, + hidden_sizes=hidden_sizes, + hidden=int(mcfg.get("hidden", 2048)), + layers=int(mcfg.get("layers", 3)), + activation=str(mcfg.get("activation", "relu")), + dropout=float(mcfg.get("dropout", 0.0)), + use_layernorm=bool(mcfg.get("use_layernorm", False)), + residual=bool(mcfg.get("residual", False)), + out_layernorm=bool(mcfg.get("out_layernorm", False)), + ).to(device) + return m + + raise ValueError(f"Unknown predictor.name: {name}") \ No newline at end of file diff --git a/JEPA/src/tasks/.DS_Store b/JEPA/src/tasks/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..68cbf7d3678849d11ad1123429f5754f43fac9af GIT binary patch literal 6148 zcmeHKF=_)r43rXs4QX7a+%Mz@i*a6%4}|#KnH(gfzbfy_(=sE8fVp!zH)aId+0|-x z*(pvZGxOco@MN|%vkjbR-yG(~efrGqDnj%)W3Z!r44;S7+c?W!4v@Px@(wlz7JodE zu*Syw^49Wsib??~AO)m=6p#Y{D}b6!+dL#{lmb#f3j8U+??ZzVd*PHApAHPs0svPC zhhZMQ1h6px?1fVzA}~)XFsWWGh9@2IR(ZW}N=&+W+>CSTX0Hy#<95Vbq?`9djZ#1g zoGWmj%Ps5wGyFpTe@@a$3P^#QQovXHm;DY;s@ghx9BXZZzrvaG1E*mg6bw<0fl-dJ fU^#w_q|9raW4{+pi9ts^=s^7pP#2jLxV8enMS2^m literal 0 HcmV?d00001 diff --git a/JEPA/src/tasks/__init__.py b/JEPA/src/tasks/__init__.py new file mode 100644 index 0000000..cc8ce29 --- /dev/null +++ b/JEPA/src/tasks/__init__.py @@ -0,0 +1 @@ +"""Task entrypoints for JEPA experiments.""" diff --git a/JEPA/src/tasks/decoder/__init__.py b/JEPA/src/tasks/decoder/__init__.py new file mode 100644 index 0000000..b24e1d2 --- /dev/null +++ b/JEPA/src/tasks/decoder/__init__.py @@ -0,0 +1 @@ +"""Decoder downstream tasks.""" diff --git a/JEPA/src/tasks/decoder/test_lora.py b/JEPA/src/tasks/decoder/test_lora.py new file mode 100644 index 0000000..583eeed --- /dev/null +++ b/JEPA/src/tasks/decoder/test_lora.py @@ -0,0 +1,436 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" 
+test_exp3_decoder_lora.py + +Decode repaired code from embedding pt using: +- StarCoder2-3B +- trained prompt projector +- trained ln +- optional trained LoRA weights + +Expected emb_pt format: +{ + "z_pred": Tensor[N, D], + "global_indices": LongTensor[N] +} + +Supported checkpoint formats: +1) tunable-only checkpoint: +{ + "projector": ..., + "ln": ..., + "lora": ... +} + +2) lightweight resume checkpoint: +{ + "epoch": ..., + "global_step": ..., + "best_val": ..., + "projector": ..., + "ln": ..., + "lora": ..., + "optimizer": ..., + "scaler": ... +} + +Output JSONL format: +{ + "global_index": ..., + "problem_id": ..., + "buggy_submission_id": ..., + "fixed_submission_id": ..., + "language": ..., + "preds": [...], + "gt_fixed_code": ... +} +""" + +import os +import json +import argparse +from typing import Any, List, Tuple, Dict + +import torch +import torch.nn as nn +from torch.utils.data import Dataset, DataLoader +from tqdm import tqdm +from transformers import AutoTokenizer, AutoModelForCausalLM +from datasets import load_dataset +from peft import LoraConfig, TaskType, get_peft_model +from jepa.utils import ( + ddp_barrier, + ddp_cleanup, + ddp_enabled, + ddp_local_rank, + ddp_rank, + ddp_setup, + ddp_world, + is_main, + load_embeddings_pt, +) + + +# ===================== +# Model +# ===================== +class SoftPromptStarCoderDecoder(nn.Module): + def __init__(self, cond_dim: int, decoder_model, tokenizer, prompt_len: int = 128): + super().__init__() + self.decoder = decoder_model + self.tokenizer = tokenizer + self.prompt_len = int(prompt_len) + self.hidden_dim = int(decoder_model.config.hidden_size) + + inter_dim = cond_dim * 4 + self.prompt_proj = nn.Sequential( + nn.Linear(cond_dim, inter_dim), + nn.LayerNorm(inter_dim), + nn.SiLU(), + nn.Dropout(0.1), + nn.Linear(inter_dim, self.prompt_len * self.hidden_dim), + ) + self.ln = nn.LayerNorm(self.hidden_dim) + + self.pad_token_id = tokenizer.pad_token_id or tokenizer.eos_token_id or 0 + 
self.eos_token_id = tokenizer.eos_token_id + + @torch.no_grad() + def generate_fast(self, cond_emb: torch.Tensor, max_new_tokens: int = 512) -> List[str]: + self.eval() + B = cond_emb.shape[0] + + prompt = self.prompt_proj(cond_emb.to(self.prompt_proj[0].weight.dtype)) + prompt = prompt.view(B, self.prompt_len, self.hidden_dim) + prompt = self.ln(prompt).to(self.decoder.dtype) + + p_mask = torch.ones(B, self.prompt_len, device=cond_emb.device, dtype=torch.long) + + generated_ids = self.decoder.generate( + inputs_embeds=prompt, + attention_mask=p_mask, + max_new_tokens=max_new_tokens, + do_sample=False, + use_cache=True, + pad_token_id=self.pad_token_id, + eos_token_id=self.eos_token_id, + ) + return self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + + +# ===================== +# Checkpoint loader +# ===================== +def load_inference_weights(model: nn.Module, path: str, use_lora: bool) -> Dict[str, Any]: + """ + Load projector / ln / optional lora weights from either: + - tunable-only checkpoint + - lightweight resume checkpoint + + Returns checkpoint metadata if present. 
+ """ + ckpt = torch.load(path, map_location="cpu", weights_only=False) + + if "projector" not in ckpt or "ln" not in ckpt: + raise KeyError("Checkpoint must contain at least 'projector' and 'ln'") + + metadata = { + "epoch": ckpt.get("epoch", None), + "global_step": ckpt.get("global_step", None), + "best_val": ckpt.get("best_val", None), + } + + model.prompt_proj.load_state_dict(ckpt["projector"], strict=True) + model.ln.load_state_dict(ckpt["ln"], strict=True) + + print(f"[Load] projector loaded") + print(f"[Load] ln loaded") + + if use_lora: + if "lora" not in ckpt: + raise KeyError("Checkpoint does not contain 'lora', but --use_lora was specified.") + + current_lora_state = { + name: p + for name, p in model.decoder.named_parameters() + if p.requires_grad + } + + missing = [] + loaded = 0 + for name, tensor in ckpt["lora"].items(): + if name in current_lora_state: + current_lora_state[name].data.copy_(tensor) + loaded += 1 + else: + missing.append(name) + + print(f"[Load] lora params loaded: {loaded}") + if missing: + print(f"[Warn] missing current lora keys for ckpt items: {missing[:10]}") + else: + if "lora" in ckpt and is_main(): + print("[Info] checkpoint contains 'lora', but --use_lora was not specified. 
LoRA weights are ignored.") + + return metadata + + +# ===================== +# Dataset +# ===================== +class DecodeDataset(Dataset): + def __init__( + self, + emb_cpu: torch.Tensor, + global_indices: List[int], + problem_ids: List[Any], + buggy_sids: List[Any], + fixed_sids: List[Any], + languages: List[Any], + gt_fixed_codes: List[str], + ): + self.emb = emb_cpu + self.gidx = global_indices + self.pid = problem_ids + self.bsid = buggy_sids + self.fsid = fixed_sids + self.lang = languages + self.gt = gt_fixed_codes + + def __len__(self): + return int(self.emb.shape[0]) + + def __getitem__(self, i: int): + return ( + self.emb[i], + int(self.gidx[i]), + str(self.pid[i]), + int(self.bsid[i]), + int(self.fsid[i]), + str(self.lang[i]), + str(self.gt[i]), + ) + + +# ===================== +# Args +# ===================== +def parse_args(): + ap = argparse.ArgumentParser() + + ap.add_argument("--emb_pt", type=str, required=True) + ap.add_argument("--emb_key", type=str, default="z_pred") + ap.add_argument("--ckpt", type=str, required=True) + ap.add_argument("--out_jsonl", type=str, required=True) + + ap.add_argument("--decoder_model_id", type=str, default="bigcode/starcoder2-3b") + ap.add_argument("--prompt_len", type=int, default=128) + ap.add_argument("--batch_size", type=int, default=4) + ap.add_argument("--max_new_tokens", type=int, default=512) + ap.add_argument("--num_workers", type=int, default=4) + ap.add_argument("--max_items", type=int, default=-1, help="Per-rank max items. 
-1 means no limit.") + ap.add_argument("--use_bf16", action="store_true") + + ap.add_argument("--hf_dataset_id", type=str, default="ASSERT-KTH/RunBugRun-Final") + ap.add_argument("--hf_split", type=str, default="train") + ap.add_argument("--hf_fixed_field", type=str, default="fixed_code") + + # LoRA options (must match training if used) + ap.add_argument("--use_lora", action="store_true", help="Enable LoRA structure before loading checkpoint") + ap.add_argument("--lora_r", type=int, default=32) + ap.add_argument("--lora_alpha", type=int, default=64) + ap.add_argument("--lora_dropout", type=float, default=0.05) + ap.add_argument( + "--lora_target_modules", + type=str, + nargs="+", + default=["q_proj", "k_proj", "v_proj", "o_proj", "c_proj"], + ) + + return ap.parse_args() + + +@torch.no_grad() +def main(): + args = parse_args() + + if ddp_enabled(): + ddp_setup("nccl" if torch.cuda.is_available() else "gloo") + + device = torch.device(f"cuda:{ddp_local_rank()}") if torch.cuda.is_available() else torch.device("cpu") + r, w = ddp_rank(), ddp_world() + + if is_main(): + print(f"[DDP] enabled={ddp_enabled()} world_size={w}") + print(f"[Args] ckpt={args.ckpt}") + print(f"[Args] emb_pt={args.emb_pt}") + print(f"[Args] out_jsonl={args.out_jsonl}") + print(f"[Args] use_lora={args.use_lora}") + print(f"[Rank {r}] device={device}") + + base_out = args.out_jsonl + out_rank = base_out.replace(".jsonl", f".rank{r}.jsonl") + out_dir = os.path.dirname(out_rank) + if out_dir: + os.makedirs(out_dir, exist_ok=True) + + # load embeddings + emb_all, gidx_all = load_embeddings_pt(args.emb_pt, key=args.emb_key) + if is_main(): + print(f"[Data] emb_all shape={tuple(emb_all.shape)} from {args.emb_pt}") + + # shard by rank + emb_shard = emb_all[r::w].contiguous() + gidx_shard = gidx_all[r::w].tolist() + + # optional truncation (per-rank) + if args.max_items > 0: + emb_shard = emb_shard[:args.max_items].contiguous() + gidx_shard = gidx_shard[:args.max_items] + + print(f"[Rank {r}] emb_shard 
shape={tuple(emb_shard.shape)} num_items={len(gidx_shard)}") + + # load HF metadata + ds = load_dataset(args.hf_dataset_id, split=args.hf_split) + subset = ds.select([int(x) for x in gidx_shard]) + + gt_fixed_codes = [str(x) if x is not None else "" for x in subset[args.hf_fixed_field]] + languages = subset["language"] + problem_ids = subset["problem_id"] + buggy_sids = subset["buggy_submission_id"] + fixed_sids = subset["fixed_submission_id"] + + # tokenizer + tokenizer = AutoTokenizer.from_pretrained(args.decoder_model_id) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + # base decoder + decoder = AutoModelForCausalLM.from_pretrained( + args.decoder_model_id, + torch_dtype=torch.bfloat16 if args.use_bf16 else torch.float16, + low_cpu_mem_usage=True, + ).to(device) + + # optional LoRA structure + if args.use_lora: + lora_config = LoraConfig( + task_type=TaskType.CAUSAL_LM, + inference_mode=True, + r=args.lora_r, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, + target_modules=args.lora_target_modules, + ) + decoder = get_peft_model(decoder, lora_config) + if is_main(): + print("[Model] LoRA structure rebuilt for inference.") + else: + if is_main(): + print("[Model] LoRA disabled for inference.") + + decoder.eval() + + cond_dim = int(emb_shard.shape[1]) + model = SoftPromptStarCoderDecoder(cond_dim, decoder, tokenizer, prompt_len=args.prompt_len).to(device) + + if device.type == "cuda" and args.use_bf16: + try: + model.prompt_proj.to(torch.bfloat16) + model.ln.to(torch.bfloat16) + except Exception as e: + print(f"[Warn] bf16 cast failed: {e}") + + if not os.path.exists(args.ckpt): + raise FileNotFoundError(f"Checkpoint not found: {args.ckpt}") + + ckpt_meta = load_inference_weights(model, args.ckpt, use_lora=args.use_lora) + + if is_main(): + if ckpt_meta["epoch"] is not None: + print(f"[CKPT] epoch = {ckpt_meta['epoch']}") + if ckpt_meta["global_step"] is not None: + print(f"[CKPT] global_step = 
{ckpt_meta['global_step']}") + if ckpt_meta["best_val"] is not None: + print(f"[CKPT] best_val = {ckpt_meta['best_val']}") + + ds_decode = DecodeDataset( + emb_cpu=emb_shard, + global_indices=gidx_shard, + problem_ids=problem_ids, + buggy_sids=buggy_sids, + fixed_sids=fixed_sids, + languages=languages, + gt_fixed_codes=gt_fixed_codes, + ) + + dl = DataLoader( + ds_decode, + batch_size=args.batch_size, + shuffle=False, + num_workers=args.num_workers, + pin_memory=True, + ) + + with open(out_rank, "w", encoding="utf-8") as f: + pbar = tqdm(dl, desc=f"Decoding(rank {r})", leave=True) + for b_emb, b_gidx, b_pid, b_bsid, b_fsid, b_lang, b_gt in pbar: + b_emb = b_emb.to(device, non_blocking=True) + preds = model.generate_fast(b_emb, max_new_tokens=args.max_new_tokens) + + if torch.is_tensor(b_gidx): + b_gidx = b_gidx.tolist() + if torch.is_tensor(b_bsid): + b_bsid = b_bsid.tolist() + if torch.is_tensor(b_fsid): + b_fsid = b_fsid.tolist() + + for i in range(len(preds)): + rec = { + "global_index": int(b_gidx[i]), + "problem_id": str(b_pid[i]), + "buggy_submission_id": int(b_bsid[i]), + "fixed_submission_id": int(b_fsid[i]), + "language": str(b_lang[i]), + "preds": [preds[i]], + "gt_fixed_code": str(b_gt[i]), + } + f.write(json.dumps(rec, ensure_ascii=False) + "\n") + + print(f"[Rank {r}] wrote -> {out_rank}") + + ddp_barrier() + + # merge outputs + if ddp_enabled(): + if is_main(): + with open(base_out, "w", encoding="utf-8") as fout: + total_lines = 0 + for rr in range(w): + part = base_out.replace(".jsonl", f".rank{rr}.jsonl") + if not os.path.exists(part): + print(f"[Merge] missing shard: {part}") + continue + with open(part, "r", encoding="utf-8") as fin: + for line in fin: + fout.write(line) + total_lines += 1 + print(f"[Merge] merged -> {base_out} (total_lines={total_lines})") + else: + if out_rank != base_out: + with open(out_rank, "r", encoding="utf-8") as fin, open(base_out, "w", encoding="utf-8") as fout: + total_lines = 0 + for line in fin: + fout.write(line) + 
total_lines += 1 + if is_main(): + print(f"[Merge] single-rank copy -> {base_out} (total_lines={total_lines})") + + if ddp_enabled(): + ddp_cleanup() + + +if __name__ == "__main__": + main() diff --git a/JEPA/src/tasks/decoder/test_projector.py b/JEPA/src/tasks/decoder/test_projector.py new file mode 100644 index 0000000..d1cb6bf --- /dev/null +++ b/JEPA/src/tasks/decoder/test_projector.py @@ -0,0 +1,381 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +decoder_only_ddp.py (JSONL format aligned with your ViT pipeline) + +Goal: +- Load precomputed embeddings from a .pt file (pred_fixed_emb OR true fixed_emb) +- Decode with StarCoder2-3B + your tunable params (prompt_proj + ln) +- DDP parallel: each rank decodes a shard of embeddings (rows) +- Write JSONL records in the SAME format as your previous ViT pipeline: + { + "global_index": ..., + "problem_id": ..., + "buggy_submission_id": ..., + "fixed_submission_id": ..., + "language": ..., + "preds": [...], + "gt_fixed_code": ... + } +- Save JSONL per-rank, and optionally merge on rank0 + +Run (single GPU): + python3 decoder_only_ddp.py --emb_pt pred_fixed_emb_test.pt + +Run (multi GPU): + torchrun --standalone --nproc_per_node=4 decoder_only_ddp.py --emb_pt pred_fixed_emb_test.pt +""" + +import os +import json +import argparse +from typing import Any, Dict, List, Tuple, Optional + +import torch +import torch.nn as nn +from torch.utils.data import Dataset, DataLoader +from tqdm import tqdm +from transformers import AutoTokenizer, AutoModelForCausalLM +from datasets import load_dataset +from jepa.utils import ddp_barrier, ddp_cleanup, ddp_enabled, ddp_local_rank, ddp_rank, ddp_setup, ddp_world, is_main + + +# ===================== +# Keep your existing constants (DO NOT CHANGE) +# ===================== +DECODER_TUNABLE_CKPT = "/mimer/NOBACKUP/groups/naiss2025-5-243/youya/CodeRepair_JEPA/e2_decoder_runs_r1/r1_decA_zpred/checkpoints/ckpt_epoch6.pt" + +DECODER_MODEL_ID = "bigcode/starcoder2-3b" +PROMPT_LEN = 128 + 
+BATCH_DECODER = 20 +MAX_NEW_TOKENS = 512 + +SAVE_PRED_CODE_DIR = "results/pred_fixed_code_pred" +SAVE_PRED_CODE_JSONL = os.path.join(SAVE_PRED_CODE_DIR, "pred_fixed_code_test.jsonl") + +# HF dataset fields (same as your ViT pipeline) +HF_DATASET_ID = "ASSERT-KTH/RunBugRun-Final" +HF_SPLIT = "train" +HF_FIXED_FIELD = "fixed_code" + + +# ===================== +# Decoder wrapper(与你训练逻辑一致) +# ===================== +class SoftPromptStarCoderDecoder(nn.Module): + def __init__(self, cond_dim: int, decoder_model, tokenizer, prompt_len: int = 32): + super().__init__() + self.decoder = decoder_model + self.tokenizer = tokenizer + self.prompt_len = prompt_len + self.hidden_dim = decoder_model.config.hidden_size + + inter_dim = cond_dim * 4 + self.prompt_proj = nn.Sequential( + nn.Linear(cond_dim, inter_dim), + nn.LayerNorm(inter_dim), + nn.SiLU(), + nn.Dropout(0.1), + nn.Linear(inter_dim, prompt_len * self.hidden_dim) + ) + self.ln = nn.LayerNorm(self.hidden_dim) + + for p in self.decoder.parameters(): + p.requires_grad = False + self.decoder.eval() + + self.pad_token_id = tokenizer.pad_token_id or tokenizer.eos_token_id or 0 + self.eos_token_id = tokenizer.eos_token_id + + @torch.no_grad() + def generate_fast(self, cond_emb: torch.Tensor, max_new_tokens: int = 128) -> List[str]: + self.decoder.eval() + B = cond_emb.shape[0] + + prompt = self.prompt_proj(cond_emb.to(self.prompt_proj[0].weight.dtype)) + prompt = prompt.view(B, self.prompt_len, self.hidden_dim) + prompt = self.ln(prompt).to(self.decoder.dtype) + + p_mask = torch.ones(B, self.prompt_len, device=cond_emb.device, dtype=torch.long) + + generated_ids = self.decoder.generate( + inputs_embeds=prompt, + attention_mask=p_mask, + max_new_tokens=max_new_tokens, + do_sample=False, + use_cache=True, + pad_token_id=self.pad_token_id, + eos_token_id=self.eos_token_id + ) + return self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + + +def load_tunable_parameters(model, path: str): + state = torch.load(path, 
map_location="cpu", weights_only=False) + missing, unexpected = model.load_state_dict(state, strict=False) + print(f" >>> [Loaded decoder tunable params] <- {path}") + if len(unexpected) > 0: + print(f" [Warn] unexpected keys: {unexpected[:5]} ...") + if len(missing) > 0: + print(f" [Info] missing keys (expected): {missing[:5]} ...") + + +# ===================== +# Embedding loader +# ===================== +def load_embeddings_pt(path: str, emb_key: Optional[str] = None) -> Tuple[torch.Tensor, Optional[Dict[str, Any]]]: + obj = torch.load(path, map_location="cpu", weights_only=False) + + if torch.is_tensor(obj): + return obj, None + + if not isinstance(obj, dict): + raise ValueError(f"Unsupported emb_pt format: {type(obj)}. Expect Tensor or dict.") + + if emb_key: + if emb_key not in obj: + raise KeyError(f"--emb_key='{emb_key}' not found in keys={list(obj.keys())[:30]}") + emb = obj[emb_key] + if not torch.is_tensor(emb): + raise ValueError(f"obj['{emb_key}'] is not a tensor.") + meta = {k: v for k, v in obj.items() if k != emb_key} + return emb, meta + + candidates = ["pred_fixed_emb_test", "fixed_emb_test", "fixed_emb", "emb", "embeddings", "z_pred", "z_tgt"] + for k in candidates: + if k in obj and torch.is_tensor(obj[k]): + meta = {kk: vv for kk, vv in obj.items() if kk != k} + return obj[k], meta + + for k, v in obj.items(): + if torch.is_tensor(v): + meta = {kk: vv for kk, vv in obj.items() if kk != k} + print(f"[Warn] auto-picked tensor key='{k}' from emb_pt dict.") + return v, meta + + raise ValueError(f"No tensor found in emb_pt dict keys={list(obj.keys())[:30]}") + + +def _extract_global_indices(meta: Optional[Dict[str, Any]]) -> Optional[List[int]]: + if not isinstance(meta, dict): + return None + for k in ["global_test_indices", "sample_idx", "global_indices"]: + if k in meta: + v = meta[k] + if torch.is_tensor(v): + return v.cpu().long().tolist() + if isinstance(v, (list, tuple)): + return [int(x) for x in v] + return None + + +# ===================== 
+# Dataset for decoding (with HF metadata) +# ===================== +class DecodeDataset(Dataset): + """ + Each item includes: + emb, global_index, problem_id, buggy_submission_id, fixed_submission_id, language, gt_fixed_code + """ + def __init__( + self, + emb_cpu: torch.Tensor, + global_indices: List[int], + problem_ids: List[Any], + buggy_sids: List[Any], + fixed_sids: List[Any], + languages: List[Any], + gt_fixed_codes: List[str], + ): + self.emb = emb_cpu + self.gidx = global_indices + self.pid = problem_ids + self.bsid = buggy_sids + self.fsid = fixed_sids + self.lang = languages + self.gt = gt_fixed_codes + + def __len__(self): + return int(self.emb.shape[0]) + + def __getitem__(self, i: int): + return ( + self.emb[i], + int(self.gidx[i]), + str(self.pid[i]), + int(self.bsid[i]), + int(self.fsid[i]), + str(self.lang[i]), + str(self.gt[i]), + ) + + +def parse_args(): + ap = argparse.ArgumentParser() + ap.add_argument("--emb_pt", type=str, required=True, help="Path to pred_fixed_emb.pt OR fixed_emb.pt") + ap.add_argument("--emb_key", type=str, default="", help="If emb_pt is a dict, pick this key as embedding tensor") + ap.add_argument("--out_jsonl", type=str, default="", help="Override output jsonl path (default uses SAVE_PRED_CODE_JSONL)") + ap.add_argument("--max_items", type=int, default=-1, help="Decode at most N embeddings per rank (debug).") + return ap.parse_args() + + +@torch.no_grad() +def main(): + args = parse_args() + + # ---- DDP setup ---- + if ddp_enabled(): + ddp_setup("nccl" if torch.cuda.is_available() else "gloo") + + # ---- device ---- + if torch.cuda.is_available(): + device = torch.device(f"cuda:{ddp_local_rank()}") + else: + device = torch.device("cpu") + + r = ddp_rank() + w = ddp_world() + if ddp_enabled(): + if is_main(): + print(f"[DDP] enabled: world_size={w}") + print(f"[DDP] rank={r} local_rank={ddp_local_rank()} device={device}") + + # ---- output paths ---- + base_out = args.out_jsonl.strip() or SAVE_PRED_CODE_JSONL + out_rank = 
base_out.replace(".jsonl", f".rank{r}.jsonl") + os.makedirs(os.path.dirname(out_rank), exist_ok=True) + + # ---- load embeddings + global indices ---- + emb_key = args.emb_key.strip() or None + emb_all, meta = load_embeddings_pt(args.emb_pt, emb_key=emb_key) + + if emb_all.dim() != 2: + raise ValueError(f"Embedding tensor must be [N, D], got shape={tuple(emb_all.shape)}") + + global_indices_all = _extract_global_indices(meta) + if global_indices_all is None: + raise ValueError( + "Cannot find global indices in emb_pt meta. " + "Please ensure your embedding pt contains one of: " + "global_test_indices / sample_idx / global_indices." + ) + + if len(global_indices_all) != emb_all.shape[0]: + raise ValueError( + f"Length mismatch: len(global_indices)={len(global_indices_all)} " + f"but emb rows={emb_all.shape[0]}. They must align 1-to-1." + ) + + # ---- shard by rank (same as before) ---- + emb_shard = emb_all[r::w].contiguous() + gidx_shard = global_indices_all[r::w] + + if args.max_items > 0: + emb_shard = emb_shard[: args.max_items].contiguous() + gidx_shard = gidx_shard[: args.max_items] + + if is_main(): + print(f"[Info] emb_all shape={tuple(emb_all.shape)} from {args.emb_pt}") + print(f"[Info] rank={r} emb_shard shape={tuple(emb_shard.shape)} gidx_shard={len(gidx_shard)}") + + # ---- load HF dataset metadata for THIS shard only ---- + # This matches your ViT pipeline Step A format fields + ds = load_dataset(HF_DATASET_ID, split=HF_SPLIT) + subset = ds.select([int(x) for x in gidx_shard]) + + gt_fixed_codes = [str(x) if x is not None else "" for x in subset[HF_FIXED_FIELD]] + languages = subset["language"] + problem_ids = subset["problem_id"] + buggy_sids = subset["buggy_submission_id"] + fixed_sids = subset["fixed_submission_id"] + + # ---- load decoder tokenizer + model ---- + tokenizer = AutoTokenizer.from_pretrained(DECODER_MODEL_ID) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + print(f"[Rank {r}] Loading StarCoder decoder 
(bfloat16)...") + decoder = AutoModelForCausalLM.from_pretrained( + DECODER_MODEL_ID, + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True + ).to(device) + decoder.eval() + for p in decoder.parameters(): + p.requires_grad = False + + cond_dim = int(emb_shard.shape[1]) + model = SoftPromptStarCoderDecoder(cond_dim, decoder, tokenizer, prompt_len=PROMPT_LEN).to(device) + + try: + model.prompt_proj.to(torch.bfloat16) + model.ln.to(torch.bfloat16) + except Exception as e: + print(f"[Warn] Failed to cast prompt modules to bfloat16: {e}") + + if not os.path.exists(DECODER_TUNABLE_CKPT): + raise FileNotFoundError(f"Not found: {DECODER_TUNABLE_CKPT}") + load_tunable_parameters(model, DECODER_TUNABLE_CKPT) + + # ---- DataLoader ---- + ds_decode = DecodeDataset( + emb_cpu=emb_shard, + global_indices=gidx_shard, + problem_ids=problem_ids, + buggy_sids=buggy_sids, + fixed_sids=fixed_sids, + languages=languages, + gt_fixed_codes=gt_fixed_codes, + ) + dl = DataLoader(ds_decode, batch_size=BATCH_DECODER, shuffle=False, num_workers=4, pin_memory=True) + + # ---- decode + write JSONL (ViT pipeline format) ---- + with open(out_rank, "w", encoding="utf-8") as f: + pbar = tqdm(dl, desc=f"Decoding(rank {r})", leave=True) + for b_emb, b_gidx, b_pid, b_bsid, b_fsid, b_lang, b_gt in pbar: + b_emb = b_emb.to(device, non_blocking=True) + preds = model.generate_fast(b_emb, max_new_tokens=MAX_NEW_TOKENS) + + # ensure python types + if torch.is_tensor(b_gidx): b_gidx = b_gidx.tolist() + if torch.is_tensor(b_bsid): b_bsid = b_bsid.tolist() + if torch.is_tensor(b_fsid): b_fsid = b_fsid.tolist() + + # b_pid/b_lang/b_gt are lists (strings) after default collate + for i in range(len(preds)): + record = { + "global_index": int(b_gidx[i]), + "problem_id": str(b_pid[i]), + "buggy_submission_id": int(b_bsid[i]), + "fixed_submission_id": int(b_fsid[i]), + "language": str(b_lang[i]), + "preds": [preds[i]], + "gt_fixed_code": str(b_gt[i]), + } + f.write(json.dumps(record, ensure_ascii=False) + "\n") + + 
print(f"[Rank {r}] wrote -> {out_rank}") + + # ---- merge on rank0 ---- + ddp_barrier() + if ddp_enabled() and is_main(): + merged_path = base_out + with open(merged_path, "w", encoding="utf-8") as fout: + for rr in range(ddp_world()): + part = base_out.replace(".jsonl", f".rank{rr}.jsonl") + if not os.path.exists(part): + print(f"[Merge] missing shard: {part}") + continue + with open(part, "r", encoding="utf-8") as fin: + for line in fin: + fout.write(line) + print(f"[Merge] merged jsonl -> {merged_path}") + + if ddp_enabled(): + ddp_cleanup() + + +if __name__ == "__main__": + main() diff --git a/JEPA/src/tasks/decoder/train_lora.py b/JEPA/src/tasks/decoder/train_lora.py new file mode 100644 index 0000000..e9d4e72 --- /dev/null +++ b/JEPA/src/tasks/decoder/train_lora.py @@ -0,0 +1,874 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +train_decoder_softprompt.py + +Train decoder soft prompt (prompt_proj) with StarCoder2-3B frozen or LoRA-tuned. +Condition embeddings: z_pred from JEPA (train + val). + +Input embedding pt format (from embed_jepa.py): + train_emb_pt: {"z_pred": Tensor[N,D], "global_indices": LongTensor[N]} + val_emb_pt: {"z_pred": Tensor[M,D], "global_indices": LongTensor[M]} + +Supervision: + fixed_code from HF dataset (ASSERT-KTH/RunBugRun-Final, split=train) + tokenize with StarCoder2 tokenizer + teacher forcing loss (CrossEntropy) + +DDP: + torchrun --standalone --nproc_per_node=4 train_decoder_softprompt.py ... 
+ +Saves: + out_dir/ + checkpoints/ + ckpt_step{global_step}.pt # lightweight but true-resume checkpoint (trainable params + optimizer/scaler/meta) + ckpt_best_val.pt # tunable params only + ckpt_epoch{e}.pt # tunable params only + train_log.jsonl + resolved_args.json +""" + +from __future__ import annotations + +import os +import time +import argparse +from typing import Any, Dict, List, Tuple, Optional + +import torch +import torch.nn as nn +from torch.utils.data import Dataset, DataLoader +from torch.utils.data.distributed import DistributedSampler +from torch.profiler import profile, ProfilerActivity + +from peft import LoraConfig, TaskType, get_peft_model +from transformers import AutoTokenizer, AutoModelForCausalLM +from datasets import load_dataset +from tqdm import tqdm +from jepa.utils import ( + JSONLLogger, + ddp_all_reduce_sum, + ddp_barrier, + ddp_cleanup, + ddp_enabled, + ddp_local_rank, + ddp_rank, + ddp_setup, + ddp_world, + is_main, + load_embeddings_pt, + unwrap_ddp, +) + +try: + import wandb # type: ignore +except Exception: + wandb = None + + +def save_tunable_parameters(model: nn.Module, path: str) -> None: + """ + Save only tunable parameters for inference / lightweight eval use. 
+ """ + base_model = unwrap_ddp(model) + + saved = { + "projector": { + name: p.detach().to("cpu") + for name, p in base_model.prompt_proj.named_parameters() + }, + "ln": { + name: p.detach().to("cpu") + for name, p in base_model.ln.named_parameters() + }, + "lora": { + name: p.detach().to("cpu") + for name, p in base_model.decoder.named_parameters() + if p.requires_grad + } + } + torch.save(saved, path) + + +def save_resume_checkpoint( + model: nn.Module, + optimizer: torch.optim.Optimizer, + scaler: Optional[torch.cuda.amp.GradScaler], + path: str, + epoch: int, + global_step: int, + best_val: float, + args, +) -> None: + """ + Lightweight but true-resume checkpoint: + - trainable params only + - optimizer state + - scaler state + - training metadata + Does NOT save frozen decoder backbone. + """ + base_model = unwrap_ddp(model) + + ckpt = { + "epoch": int(epoch), + "global_step": int(global_step), + "best_val": float(best_val), + "args": vars(args), + + "projector": { + name: p.detach().to("cpu") + for name, p in base_model.prompt_proj.named_parameters() + }, + "ln": { + name: p.detach().to("cpu") + for name, p in base_model.ln.named_parameters() + }, + "lora": { + name: p.detach().to("cpu") + for name, p in base_model.decoder.named_parameters() + if p.requires_grad + }, + + "optimizer": optimizer.state_dict(), + "scaler": scaler.state_dict() if scaler is not None else None, + } + torch.save(ckpt, path) + +def load_resume_checkpoint( + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer], + scaler: Optional[torch.cuda.amp.GradScaler], + path: str, +) -> Tuple[int, int, float]: + """ + Restore trainable params + optimizer/scaler + metadata. 
+ Returns: + epoch, global_step, best_val + """ + ckpt = torch.load(path, map_location="cpu", weights_only=False) + base_model = unwrap_ddp(model) + + if "projector" in ckpt: + base_model.prompt_proj.load_state_dict(ckpt["projector"], strict=True) + + if "ln" in ckpt: + base_model.ln.load_state_dict(ckpt["ln"], strict=True) + + if "lora" in ckpt: + current_lora_state = { + name: p + for name, p in base_model.decoder.named_parameters() + if p.requires_grad + } + + missing = [] + for name, tensor in ckpt["lora"].items(): + if name in current_lora_state: + current_lora_state[name].data.copy_(tensor) + else: + missing.append(name) + + if missing and is_main(): + print(f"[Warn] Some LoRA keys from checkpoint were not found: {missing[:10]}") + + if optimizer is not None and ckpt.get("optimizer") is not None: + optimizer.load_state_dict(ckpt["optimizer"]) + + if scaler is not None and ckpt.get("scaler") is not None: + scaler.load_state_dict(ckpt["scaler"]) + + epoch = int(ckpt.get("epoch", 0)) + global_step = int(ckpt.get("global_step", 0)) + best_val = float(ckpt.get("best_val", float("inf"))) + + return epoch, global_step, best_val + + +def get_total_flops_from_prof(prof) -> float: + """ + Robust FLOPs aggregation over profiler events. 
+ """ + total_flops = 0.0 + try: + for evt in prof.key_averages(): + fl = getattr(evt, "flops", None) + if fl is not None: + total_flops += float(fl) + except Exception: + return 0.0 + return float(total_flops) + + +# ------------------------- +# Model: SoftPrompt Decoder +# ------------------------- +class SoftPromptStarCoderDecoder(nn.Module): + def __init__(self, cond_dim: int, decoder_model, tokenizer, prompt_len: int = 128): + super().__init__() + self.decoder = decoder_model + self.tokenizer = tokenizer + self.prompt_len = int(prompt_len) + self.hidden_dim = int(decoder_model.config.hidden_size) + + inter_dim = cond_dim * 4 + self.prompt_proj = nn.Sequential( + nn.Linear(cond_dim, inter_dim), + nn.LayerNorm(inter_dim), + nn.SiLU(), + nn.Dropout(0.1), + nn.Linear(inter_dim, self.prompt_len * self.hidden_dim), + ) + self.ln = nn.LayerNorm(self.hidden_dim) + + self.pad_token_id = tokenizer.pad_token_id or tokenizer.eos_token_id or 0 + self.eos_token_id = tokenizer.eos_token_id + + def forward(self, cond_emb: torch.Tensor, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: + """ + Teacher-forcing loss on fixed_code tokens. 
+ """ + B, T = input_ids.shape + + prompt = self.prompt_proj(cond_emb.to(self.prompt_proj[0].weight.dtype)) + prompt = prompt.view(B, self.prompt_len, self.hidden_dim) + prompt = self.ln(prompt) + + tok_emb = self.decoder.get_input_embeddings()(input_ids) + full_emb = torch.cat([prompt.to(self.decoder.dtype), tok_emb.to(self.decoder.dtype)], dim=1) + + prompt_mask = torch.ones(B, self.prompt_len, device=input_ids.device, dtype=attention_mask.dtype) + full_mask = torch.cat([prompt_mask, attention_mask], dim=1) + + full_labels = torch.full((B, self.prompt_len + T), -100, device=input_ids.device, dtype=torch.long) + code_labels = input_ids.clone() + code_labels[attention_mask == 0] = -100 + full_labels[:, self.prompt_len:] = code_labels + + out = self.decoder(inputs_embeds=full_emb, attention_mask=full_mask, labels=full_labels) + return out.loss + + +# ------------------------- +# Dataset +# ------------------------- +class CondCodeDataset(Dataset): + def __init__( + self, + emb_cpu: torch.Tensor, # [N,D] CPU float + gidx_cpu: torch.Tensor, # [N] CPU long + fixed_codes: List[str], + tokenizer, + max_len: int, + ): + self.emb = emb_cpu + self.gidx = gidx_cpu + self.fixed_codes = fixed_codes + self.tokenizer = tokenizer + self.max_len = int(max_len) + + def __len__(self) -> int: + return int(self.emb.shape[0]) + + def __getitem__(self, i: int): + e = self.emb[i] + gi = int(self.gidx[i]) + code = self.fixed_codes[i] + enc = self.tokenizer( + code, + truncation=True, + max_length=self.max_len, + padding=False, + return_tensors=None, + ) + return e, enc["input_ids"], enc["attention_mask"], gi + + +def collate_pad(batch, pad_token_id: int): + """ + batch: list of (emb, input_ids(list[int]), attention_mask(list[int]), global_idx) + Pads to max length in batch. 
+ """ + embs, ids_list, mask_list, gidx = zip(*batch) + + embs = torch.stack([x.float() for x in embs], dim=0) + gidx = torch.tensor(gidx, dtype=torch.long) + + max_t = max(len(x) for x in ids_list) + B = len(ids_list) + input_ids = torch.full((B, max_t), pad_token_id, dtype=torch.long) + attn = torch.zeros((B, max_t), dtype=torch.long) + + for i in range(B): + ids = torch.tensor(ids_list[i], dtype=torch.long) + m = torch.tensor(mask_list[i], dtype=torch.long) + input_ids[i, : ids.numel()] = ids + attn[i, : m.numel()] = m + + return embs, input_ids, attn, gidx + + +# ------------------------- +# Train / Eval +# ------------------------- +@torch.no_grad() +def evaluate(model: nn.Module, dl: DataLoader, device: torch.device, use_bf16: bool) -> float: + model.eval() + tot_loss = 0.0 + tot_n = 0 + + for b_emb, b_ids, b_mask, _ in dl: + bsz = b_emb.size(0) + b_emb = b_emb.to(device, non_blocking=True) + b_ids = b_ids.to(device, non_blocking=True) + b_mask = b_mask.to(device, non_blocking=True) + + autocast_enabled = (device.type == "cuda") + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=autocast_enabled and use_bf16): + loss = model(b_emb, b_ids, b_mask) + + if not torch.isnan(loss): + tot_loss += float(loss.item()) * bsz + tot_n += bsz + + t = torch.tensor([tot_loss, float(tot_n)], device=device, dtype=torch.float64) + ddp_all_reduce_sum(t) + return (t[0] / t[1]).item() if t[1].item() > 0 else 0.0 + + +def train(args) -> None: + if ddp_enabled(): + ddp_setup("nccl" if torch.cuda.is_available() else "gloo") + + device = torch.device(f"cuda:{ddp_local_rank()}") if torch.cuda.is_available() else torch.device("cpu") + r, w = ddp_rank(), ddp_world() + + os.makedirs(args.out_dir, exist_ok=True) + if is_main(): + os.makedirs(os.path.join(args.out_dir, "checkpoints"), exist_ok=True) + with open(os.path.join(args.out_dir, "resolved_args.json"), "w", encoding="utf-8") as f: + json.dump(vars(args), f, indent=2, ensure_ascii=False) + + logger = 
JSONLLogger(os.path.join(args.out_dir, "train_log.jsonl")) if is_main() else None + + # ------------------------- + # W&B init (rank0 only) + # ------------------------- + if args.wandb: + if wandb is None: + raise RuntimeError("wandb enabled but wandb is not installed in this environment.") + if is_main(): + wandb.init( + entity=args.wandb_entity or None, + project=args.wandb_project or None, + group=args.wandb_group or None, + name=args.wandb_run_name or None, + id=args.wandb_id or None, + resume="never", + config=vars(args), + settings=wandb.Settings(init_timeout=180), + ) + + # Load embeddings + train_emb, train_gidx = load_embeddings_pt(args.train_emb_pt, key=args.emb_key) + val_emb, val_gidx = load_embeddings_pt(args.val_emb_pt, key=args.emb_key) + + if is_main(): + print(f"[Data] train_emb={tuple(train_emb.shape)} val_emb={tuple(val_emb.shape)} emb_key={args.emb_key}") + + # Load HF dataset fixed_code by global indices + ds = load_dataset(args.hf_dataset_id, split=args.hf_split) + + train_fixed = [str(x) if x is not None else "" for x in ds.select(train_gidx.tolist())[args.hf_fixed_field]] + val_fixed = [str(x) if x is not None else "" for x in ds.select(val_gidx.tolist())[args.hf_fixed_field]] + + # Tokenizer/model (StarCoder2) + tokenizer = AutoTokenizer.from_pretrained(args.decoder_model_id) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + decoder = AutoModelForCausalLM.from_pretrained( + args.decoder_model_id, + torch_dtype=torch.bfloat16 if args.use_bf16 else torch.float16, + low_cpu_mem_usage=True, + ).to(device) + + # Add LoRA on decoder + if args.use_lora: + lora_config = LoraConfig( + task_type=TaskType.CAUSAL_LM, + inference_mode=False, + r=args.lora_r, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, + target_modules=args.lora_target_modules, + ) + decoder = get_peft_model(decoder, lora_config) + if is_main(): + decoder.print_trainable_parameters() + else: + for p in decoder.parameters(): + 
p.requires_grad = False + if is_main(): + print("[LoRA] disabled, decoder frozen") + + # Build soft prompt model + cond_dim = int(train_emb.shape[1]) + model = SoftPromptStarCoderDecoder(cond_dim, decoder, tokenizer, prompt_len=args.prompt_len).to(device) + + # Cast tunable modules to bf16 for speed + if device.type == "cuda" and args.use_bf16: + try: + model.prompt_proj.to(torch.bfloat16) + model.ln.to(torch.bfloat16) + except Exception as e: + if is_main(): + print(f"[Warn] Failed to cast prompt modules to bf16: {e}") + + # Wrap DDP + if ddp_enabled(): + model = torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[ddp_local_rank()], + find_unused_parameters=False + ) + + # Datasets / loaders + train_ds = CondCodeDataset(train_emb, train_gidx, train_fixed, tokenizer, max_len=args.max_len) + val_ds = CondCodeDataset(val_emb, val_gidx, val_fixed, tokenizer, max_len=args.max_len) + + train_sampler = DistributedSampler(train_ds, shuffle=True) if ddp_enabled() else None + val_sampler = DistributedSampler(val_ds, shuffle=False) if ddp_enabled() else None + + collate = lambda b: collate_pad(b, pad_token_id=tokenizer.pad_token_id) + + train_dl = DataLoader( + train_ds, + batch_size=args.batch_size, + sampler=train_sampler, + shuffle=(train_sampler is None), + num_workers=args.num_workers, + pin_memory=True, + collate_fn=collate, + drop_last=True, + ) + val_dl = DataLoader( + val_ds, + batch_size=args.eval_batch_size or args.batch_size, + sampler=val_sampler, + shuffle=False, + num_workers=args.num_workers, + pin_memory=True, + collate_fn=collate, + drop_last=False, + ) + + base_model = model.module if hasattr(model, "module") else model + + projector_params = [p for p in base_model.prompt_proj.parameters() if p.requires_grad] + ln_params = [p for p in base_model.ln.parameters() if p.requires_grad] + lora_params = [p for p in base_model.decoder.parameters() if p.requires_grad] + tunable_params = projector_params + ln_params + lora_params + + optimizer = 
torch.optim.AdamW(tunable_params, lr=args.lr, weight_decay=args.weight_decay) + + use_amp = (device.type == "cuda") and args.use_amp + scaler = torch.cuda.amp.GradScaler(enabled=use_amp) + + # ------------------------- + # Resume + # ------------------------- + start_epoch = 1 + global_step = 0 + best_val = float("inf") + + if args.resume_ckpt: + resumed_epoch, resumed_global_step, resumed_best_val = load_resume_checkpoint( + model=model, + optimizer=optimizer, + scaler=scaler, + path=args.resume_ckpt, + ) + start_epoch = resumed_epoch + 1 + global_step = resumed_global_step + best_val = resumed_best_val + + if is_main(): + print(f"[Resume] loaded from {args.resume_ckpt}") + print(f"[Resume] start_epoch={start_epoch} global_step={global_step} best_val={best_val:.4f}") + + # ------------------------- + # FLOPs profiler state + # ------------------------- + prof = None + profiling_active = bool(args.profile_flops and args.profile_steps > 0) + profiled_update_steps = 0 + estimated_flops_per_update = None # global total FLOPs/update across all ranks + + if profiling_active: + activities = [ProfilerActivity.CPU] + if device.type == "cuda": + activities.append(ProfilerActivity.CUDA) + + prof = profile( + activities=activities, + record_shapes=False, + profile_memory=False, + with_stack=False, + with_flops=True, + ) + prof.__enter__() + + if is_main(): + print(f"[FLOPs] profiling first {args.profile_steps} optimizer-update steps ...") + + if is_main(): + n_proj = sum(p.numel() for p in projector_params if p.requires_grad) + n_ln = sum(p.numel() for p in ln_params if p.requires_grad) + n_lora = sum(p.numel() for p in lora_params if p.requires_grad) + print(f"[Trainable] projector={n_proj:,} ln={n_ln:,} lora={n_lora:,} total={n_proj + n_ln + n_lora:,}") + print(f"[Train] world={w} batch={args.batch_size} grad_accum={args.grad_accum} lr={args.lr}") + print(f"[Train] max_len={args.max_len} prompt_len={args.prompt_len} use_bf16={args.use_bf16} use_amp={use_amp}") + + # 
------------------------- + # Training + # ------------------------- + for epoch in range(start_epoch, args.epochs + 1): + if ddp_enabled() and train_sampler is not None: + train_sampler.set_epoch(epoch) + + (model.module if hasattr(model, "module") else model).train() + optimizer.zero_grad(set_to_none=True) + + running_loss = 0.0 + running_n = 0 + t0 = time.time() + + pbar = tqdm(train_dl, desc=f"Epoch {epoch}/{args.epochs}", dynamic_ncols=True) if is_main() else train_dl + + for step, (b_emb, b_ids, b_mask, _) in enumerate(pbar): + b_emb = b_emb.to(device, non_blocking=True) + b_ids = b_ids.to(device, non_blocking=True) + b_mask = b_mask.to(device, non_blocking=True) + + autocast_enabled = (device.type == "cuda") + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=autocast_enabled and args.use_bf16): + loss = (model(b_emb, b_ids, b_mask) / args.grad_accum) + + # backward + if use_amp: + scaler.scale(loss).backward() + else: + loss.backward() + + # optimizer update + if (step + 1) % args.grad_accum == 0: + if args.clip_grad > 0: + if use_amp: + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(tunable_params, args.clip_grad) + + if use_amp: + scaler.step(optimizer) + scaler.update() + else: + optimizer.step() + optimizer.zero_grad(set_to_none=True) + + global_step += 1 + + # profiler step + if profiling_active and prof is not None and estimated_flops_per_update is None: + prof.step() + profiled_update_steps += 1 + + if profiled_update_steps >= args.profile_steps: + local_total_flops = get_total_flops_from_prof(prof) + + if is_main() and args.profile_print_tables: + try: + print("\n[FLOPs Debug] profiler table sorted by self_cuda_time_total") + print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=20)) + except Exception as e: + print(f"[Warn] Failed to print CUDA-time profiler table: {e}") + + try: + print("\n[FLOPs Debug] profiler table sorted by flops") + print(prof.key_averages().table(sort_by="flops", 
row_limit=20)) + except Exception as e: + print(f"[Warn] Failed to print FLOPs profiler table: {e}") + + prof.__exit__(None, None, None) + prof = None + profiling_active = False + + local_avg_flops = local_total_flops / float(profiled_update_steps) + + t_flops = torch.tensor([local_avg_flops], device=device, dtype=torch.float64) + ddp_all_reduce_sum(t_flops) + estimated_flops_per_update = t_flops[0].item() + + ddp_barrier() + if is_main(): + if estimated_flops_per_update > 0: + print(f"[FLOPs] estimated_global_flops_per_update = {estimated_flops_per_update:.3e}") + else: + print("[Warn] profiler finished but FLOPs are still 0. " + "This usually means PyTorch profiler did not attribute FLOPs " + "for the fused/transformer kernels in this run.") + + cumulative_flops = None + if estimated_flops_per_update is not None and estimated_flops_per_update > 0: + cumulative_flops = global_step * estimated_flops_per_update + + # track train loss (unscaled) + loss_item = float(loss.item()) * args.grad_accum + running_loss += loss_item * b_emb.size(0) + running_n += b_emb.size(0) + + if is_main() and (global_step % args.log_every == 0): + avg = running_loss / max(1, running_n) + pbar.set_postfix(loss=f"{avg:.4f}") + + if args.wandb: + log_dict = { + "train/loss": float(avg), + "train/epoch": epoch, + "train/step": global_step, + "lr": float(optimizer.param_groups[0]["lr"]), + } + if cumulative_flops is not None: + log_dict["train/cumulative_flops"] = float(cumulative_flops) + log_dict["train/flops_per_update"] = float(estimated_flops_per_update) + wandb.log(log_dict, step=global_step) + + # step-based resume checkpoint + if args.save_every_steps > 0 and (global_step % args.save_every_steps == 0): + if is_main(): + step_ckpt_path = os.path.join( + args.out_dir, "checkpoints", f"ckpt_step{global_step}.pt" + ) + save_resume_checkpoint( + model=model, + optimizer=optimizer, + scaler=scaler, + path=step_ckpt_path, + epoch=epoch, + global_step=global_step, + best_val=best_val, + 
args=args, + ) + print(f"[Save] step resume ckpt -> {step_ckpt_path}") + + # periodic eval + if args.eval_every_steps > 0 and (global_step % args.eval_every_steps == 0): + if ddp_enabled() and val_sampler is not None: + val_sampler.set_epoch(epoch) + + val_loss = evaluate(model, val_dl, device=device, use_bf16=args.use_bf16) + + if is_main(): + rec = { + "type": "eval", + "epoch": epoch, + "step": global_step, + "val_loss": float(val_loss), + "time": time.time(), + } + if estimated_flops_per_update is not None and estimated_flops_per_update > 0: + rec["flops_per_update"] = float(estimated_flops_per_update) + rec["cumulative_flops"] = float(global_step * estimated_flops_per_update) + + logger.log(rec) + print(f"[Eval] epoch={epoch} step={global_step} val_loss={val_loss:.4f}") + + if args.wandb: + val_log = { + "val/loss": float(val_loss), + "val/epoch": epoch, + "val/step": global_step, + } + if estimated_flops_per_update is not None and estimated_flops_per_update > 0: + val_log["val/cumulative_flops"] = float(global_step * estimated_flops_per_update) + wandb.log(val_log, step=global_step) + + # save best tunable-only checkpoint + if val_loss < best_val: + best_val = val_loss + best_path = os.path.join(args.out_dir, "checkpoints", "ckpt_best_val.pt") + save_tunable_parameters(model, best_path) + print(f"[Save] new best -> {best_path} (val_loss={best_val:.4f})") + + (model.module if hasattr(model, "module") else model).train() + + # epoch end train-loss reduce + t = torch.tensor([running_loss, float(running_n)], device=device, dtype=torch.float64) + ddp_all_reduce_sum(t) + tr_loss = (t[0] / t[1]).item() if t[1].item() > 0 else 0.0 + + # epoch end eval + if ddp_enabled() and val_sampler is not None: + val_sampler.set_epoch(epoch) + val_loss = evaluate(model, val_dl, device=device, use_bf16=args.use_bf16) + + if is_main(): + dt = time.time() - t0 + rec = { + "type": "epoch", + "epoch": epoch, + "step": global_step, + "train_loss": float(tr_loss), + "val_loss": 
float(val_loss), + "seconds": float(dt), + "time": time.time(), + } + if estimated_flops_per_update is not None and estimated_flops_per_update > 0: + rec["flops_per_update"] = float(estimated_flops_per_update) + rec["cumulative_flops"] = float(global_step * estimated_flops_per_update) + logger.log(rec) + + print(f"[Epoch {epoch}] train_loss={tr_loss:.4f} val_loss={val_loss:.4f} time={dt/60:.1f} min") + + if args.wandb: + epoch_log = { + "epoch/train_loss": float(tr_loss), + "epoch/val_loss": float(val_loss), + "epoch": epoch, + } + if estimated_flops_per_update is not None and estimated_flops_per_update > 0: + epoch_log["epoch/cumulative_flops"] = float(global_step * estimated_flops_per_update) + wandb.log(epoch_log, step=global_step) + + # save epoch tunable-only checkpoint + ep_path = os.path.join(args.out_dir, "checkpoints", f"ckpt_epoch{epoch}.pt") + save_tunable_parameters(model, ep_path) + print(f"[Save] epoch tunable ckpt -> {ep_path}") + + if val_loss < best_val: + best_val = val_loss + best_path = os.path.join(args.out_dir, "checkpoints", "ckpt_best_val.pt") + save_tunable_parameters(model, best_path) + print(f"[Save] new best -> {best_path} (val_loss={best_val:.4f})") + + ddp_barrier() + + # Safety: close profiler if training ended early + if prof is not None: + local_total_flops = get_total_flops_from_prof(prof) + + if is_main() and args.profile_print_tables: + try: + print("\n[FLOPs Debug] profiler table sorted by self_cuda_time_total") + print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=20)) + except Exception as e: + print(f"[Warn] Failed to print CUDA-time profiler table: {e}") + + try: + print("\n[FLOPs Debug] profiler table sorted by flops") + print(prof.key_averages().table(sort_by="flops", row_limit=20)) + except Exception as e: + print(f"[Warn] Failed to print FLOPs profiler table: {e}") + + prof.__exit__(None, None, None) + prof = None + + if profiled_update_steps > 0 and estimated_flops_per_update is None: + 
def parse_args(argv=None):
    """Parse CLI arguments for LoRA decoder training.

    Args:
        argv: Optional explicit argument list. ``None`` (the default) keeps
            the existing behavior of reading ``sys.argv[1:]``; passing a list
            makes the parser callable from tests without touching the process
            arguments. Backward compatible with all existing callers.

    Returns:
        argparse.Namespace with all training/runtime options.
    """
    ap = argparse.ArgumentParser()

    # Embeddings
    ap.add_argument("--train_emb_pt", type=str, required=True, help="train z_pred.pt (dict with z_pred + global_indices)")
    ap.add_argument("--val_emb_pt", type=str, required=True, help="val z_pred.pt (dict with z_pred + global_indices)")
    ap.add_argument("--emb_key", type=str, default="z_pred", help="key name of embedding tensor inside pt (default: z_pred)")

    # HF dataset
    ap.add_argument("--hf_dataset_id", type=str, default="ASSERT-KTH/RunBugRun-Final")
    ap.add_argument("--hf_split", type=str, default="train")
    ap.add_argument("--hf_fixed_field", type=str, default="fixed_code")

    # LoRA
    ap.add_argument("--use_lora", action="store_true", help="Enable LoRA on decoder")
    ap.add_argument("--lora_r", type=int, default=8)
    ap.add_argument("--lora_alpha", type=int, default=16)
    ap.add_argument("--lora_dropout", type=float, default=0.05)
    ap.add_argument(
        "--lora_target_modules",
        type=str,
        nargs="+",
        default=["q_proj", "k_proj", "v_proj", "o_proj"],
        help="Target modules for LoRA",
    )

    # Decoder
    ap.add_argument("--decoder_model_id", type=str, default="bigcode/starcoder2-3b")
    ap.add_argument("--prompt_len", type=int, default=128)
    ap.add_argument("--max_len", type=int, default=512)

    # Wandb
    ap.add_argument("--wandb", action="store_true", help="Enable Weights & Biases logging (rank0 only)")
    ap.add_argument("--wandb_entity", type=str, default="", help="W&B entity (optional)")
    ap.add_argument("--wandb_project", type=str, default="CodeRepair_JEPA", help="W&B project")
    ap.add_argument("--wandb_group", type=str, default="decoder_zpred", help="W&B group")
    ap.add_argument("--wandb_run_name", type=str, default="", help="W&B run name (optional)")
    ap.add_argument("--wandb_id", type=str, default="", help="W&B run id")
    ap.add_argument("--wandb_resume", type=str, default="auto", help="W&B resume mode: auto|must|never")

    # Train hyperparams
    ap.add_argument("--out_dir", type=str, required=True)
    ap.add_argument("--epochs", type=int, default=3)
    ap.add_argument("--batch_size", type=int, default=20)
    ap.add_argument("--eval_batch_size", type=int, default=0, help="0 means use batch_size")
    ap.add_argument("--lr", type=float, default=5e-5)
    ap.add_argument("--weight_decay", type=float, default=0.01)
    ap.add_argument("--grad_accum", type=int, default=1)
    ap.add_argument("--clip_grad", type=float, default=1.0)

    # Resume / save
    ap.add_argument("--resume_ckpt", type=str, default="", help="Path to step resume checkpoint")
    ap.add_argument("--save_every_steps", type=int, default=0, help="Save true-resume checkpoint every N optimizer steps; 0 disables")

    # Runtime
    ap.add_argument("--num_workers", type=int, default=4)
    ap.add_argument("--log_every", type=int, default=50)
    ap.add_argument("--eval_every_steps", type=int, default=0, help="0 disables step eval; epoch eval always runs")
    ap.add_argument("--use_bf16", action="store_true", help="use bf16 autocast on CUDA")
    ap.add_argument("--use_amp", action="store_true", help="use fp16 GradScaler (if you want fp16)")

    # Flops
    ap.add_argument("--profile_flops", action="store_true", help="Estimate FLOPs with torch.profiler")
    ap.add_argument("--profile_steps", type=int, default=5, help="Number of optimizer update steps used for FLOPs profiling")
    ap.add_argument("--profile_print_tables", action="store_true", help="Print profiler tables for FLOPs debugging")

    return ap.parse_args(argv)


if __name__ == "__main__":
    args = parse_args()
    train(args)
class SoftPromptStarCoderDecoder(nn.Module):
    """Frozen causal-LM decoder conditioned through a trainable soft prompt.

    A JEPA condition embedding is projected by ``prompt_proj`` into
    ``prompt_len`` virtual token embeddings that are prepended to the code
    token embeddings. Only the projection (and its LayerNorm) are trainable;
    the decoder itself is frozen in ``__init__``.
    """

    def __init__(self, cond_dim: int, decoder_model, tokenizer, prompt_len: int = 128):
        """
        Args:
            cond_dim: Dimensionality of the condition embedding (z_pred).
            decoder_model: A causal LM exposing ``config.hidden_size``,
                ``get_input_embeddings`` and the usual HF forward signature.
            tokenizer: Tokenizer providing ``pad_token_id`` / ``eos_token_id``.
            prompt_len: Number of virtual prompt tokens to generate.
        """
        super().__init__()
        self.decoder = decoder_model
        self.tokenizer = tokenizer
        self.prompt_len = int(prompt_len)
        self.hidden_dim = int(decoder_model.config.hidden_size)

        # cond_dim -> prompt_len * hidden_dim through a widened intermediate layer.
        inter_dim = cond_dim * 4
        self.prompt_proj = nn.Sequential(
            nn.Linear(cond_dim, inter_dim),
            nn.LayerNorm(inter_dim),
            nn.SiLU(),
            nn.Dropout(0.1),
            nn.Linear(inter_dim, self.prompt_len * self.hidden_dim),
        )
        self.ln = nn.LayerNorm(self.hidden_dim)

        # Freeze decoder: only the soft-prompt modules remain trainable.
        for p in self.decoder.parameters():
            p.requires_grad = False
        self.decoder.eval()

        # BUGFIX: the previous `pad_token_id or eos_token_id or 0` chain
        # discarded a legitimate pad id of 0 (0 is falsy) and silently
        # substituted the EOS id. Test for None explicitly instead.
        if tokenizer.pad_token_id is not None:
            self.pad_token_id = tokenizer.pad_token_id
        elif tokenizer.eos_token_id is not None:
            self.pad_token_id = tokenizer.eos_token_id
        else:
            self.pad_token_id = 0
        self.eos_token_id = tokenizer.eos_token_id

    def forward(self, cond_emb: torch.Tensor, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Teacher-forcing loss on fixed_code tokens.

        Args:
            cond_emb: [B, cond_dim] condition embeddings.
            input_ids: [B, T] target token ids.
            attention_mask: [B, T] mask (1 = real token, 0 = padding).

        Returns:
            Scalar cross-entropy loss from the decoder.
        """
        B, T = input_ids.shape

        # Project condition into prompt_len virtual token embeddings.
        prompt = self.prompt_proj(cond_emb.to(self.prompt_proj[0].weight.dtype))
        prompt = prompt.view(B, self.prompt_len, self.hidden_dim)
        prompt = self.ln(prompt)

        tok_emb = self.decoder.get_input_embeddings()(input_ids)
        full_emb = torch.cat([prompt.to(self.decoder.dtype), tok_emb.to(self.decoder.dtype)], dim=1)

        # Prompt positions are always attended to.
        prompt_mask = torch.ones(B, self.prompt_len, device=input_ids.device, dtype=attention_mask.dtype)
        full_mask = torch.cat([prompt_mask, attention_mask], dim=1)

        # Labels: -100 over the prompt and padded positions so CE ignores them.
        full_labels = torch.full((B, self.prompt_len + T), -100, device=input_ids.device, dtype=torch.long)
        code_labels = input_ids.clone()
        code_labels[attention_mask == 0] = -100
        full_labels[:, self.prompt_len:] = code_labels

        out = self.decoder(inputs_embeds=full_emb, attention_mask=full_mask, labels=full_labels)
        return out.loss
@torch.no_grad()
def evaluate(model: nn.Module, dl: DataLoader, device: torch.device, use_bf16: bool) -> float:
    """Mean teacher-forcing loss over *dl*, summed across all DDP ranks.

    NaN batch losses are excluded from the running totals; returns 0.0 when
    no valid samples were seen on any rank.
    """
    model.eval()
    loss_sum = 0.0
    sample_count = 0
    amp_ok = device.type == "cuda"

    for emb, ids, mask, _ in dl:
        n = emb.size(0)
        emb = emb.to(device, non_blocking=True)
        ids = ids.to(device, non_blocking=True)
        mask = mask.to(device, non_blocking=True)

        # bf16 autocast on GPU for speed; no-op on CPU or when disabled.
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=amp_ok and use_bf16):
            batch_loss = model(emb, ids, mask)

        if not torch.isnan(batch_loss):
            loss_sum += float(batch_loss.item()) * n
            sample_count += n

    # Reduce totals across ranks before taking the mean.
    totals = torch.tensor([loss_sum, float(sample_count)], device=device, dtype=torch.float64)
    ddp_all_reduce_sum(totals)
    return (totals[0] / totals[1]).item() if totals[1].item() > 0 else 0.0
def train(args) -> None:
    """Train the soft-prompt projector on top of a frozen StarCoder2 decoder.

    Orchestrates (optionally DDP) training: loads z_pred embeddings plus the
    supervising fixed_code texts, builds the frozen decoder with a trainable
    prompt projector, and runs the teacher-forcing loop with periodic and
    epoch-end evaluation. Only rank 0 writes logs and checkpoints.

    Args:
        args: Parsed CLI namespace from ``parse_args``.
    """
    # BUGFIX: this module never imported `json`, so the resolved_args.json
    # dump below raised NameError on rank 0. Import locally since this is
    # the only use site in the module.
    import json

    if ddp_enabled():
        ddp_setup("nccl" if torch.cuda.is_available() else "gloo")

    device = torch.device(f"cuda:{ddp_local_rank()}") if torch.cuda.is_available() else torch.device("cpu")
    w = ddp_world()

    os.makedirs(args.out_dir, exist_ok=True)
    if is_main():
        os.makedirs(os.path.join(args.out_dir, "checkpoints"), exist_ok=True)
        with open(os.path.join(args.out_dir, "resolved_args.json"), "w", encoding="utf-8") as f:
            json.dump(vars(args), f, indent=2, ensure_ascii=False)

    logger = JSONLLogger(os.path.join(args.out_dir, "train_log.jsonl")) if is_main() else None

    # -------------------------
    # W&B init (rank0 only)
    # -------------------------
    if args.wandb:
        if wandb is None:
            raise RuntimeError("wandb enabled but wandb is not installed in this environment.")
        if is_main():
            wandb.init(
                entity=args.wandb_entity or None,
                project=args.wandb_project or None,
                group=args.wandb_group or None,
                name=args.wandb_run_name or None,
                id=args.wandb_id or None,
                resume=args.wandb_resume or "auto",
                config=vars(args),
                dir=args.out_dir,
            )

    # Condition embeddings (z_pred) plus their global dataset indices.
    train_emb, train_gidx = load_embeddings_pt(args.train_emb_pt, key=args.emb_key)
    val_emb, val_gidx = load_embeddings_pt(args.val_emb_pt, key=args.emb_key)

    if is_main():
        print(f"[Data] train_emb={tuple(train_emb.shape)} val_emb={tuple(val_emb.shape)} emb_key={args.emb_key}")

    # Supervision texts, aligned with the embeddings via saved global indices.
    ds = load_dataset(args.hf_dataset_id, split=args.hf_split)
    train_fixed = [str(x) if x is not None else "" for x in ds.select(train_gidx.tolist())[args.hf_fixed_field]]
    val_fixed = [str(x) if x is not None else "" for x in ds.select(val_gidx.tolist())[args.hf_fixed_field]]

    # Tokenizer / frozen decoder (StarCoder2).
    tokenizer = AutoTokenizer.from_pretrained(args.decoder_model_id)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    decoder = AutoModelForCausalLM.from_pretrained(
        args.decoder_model_id,
        torch_dtype=torch.bfloat16 if args.use_bf16 else torch.float16,
        low_cpu_mem_usage=True,
    ).to(device)

    # Build soft prompt model.
    cond_dim = int(train_emb.shape[1])
    model = SoftPromptStarCoderDecoder(cond_dim, decoder, tokenizer, prompt_len=args.prompt_len).to(device)

    # Cast tunable modules to bf16 for speed (keeps behavior close to exp1).
    if device.type == "cuda" and args.use_bf16:
        try:
            model.prompt_proj.to(torch.bfloat16)
            model.ln.to(torch.bfloat16)
        except Exception as e:
            if is_main():
                print(f"[Warn] Failed to cast prompt modules to bf16: {e}")

    # Keep exp2 consistent with exp1: train projector only, freeze ln.
    for p in model.ln.parameters():
        p.requires_grad = False

    if ddp_enabled():
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[ddp_local_rank()], find_unused_parameters=False
        )

    # Datasets / loaders.
    train_ds = CondCodeDataset(train_emb, train_gidx, train_fixed, tokenizer, max_len=args.max_len)
    val_ds = CondCodeDataset(val_emb, val_gidx, val_fixed, tokenizer, max_len=args.max_len)

    train_sampler = DistributedSampler(train_ds, shuffle=True) if ddp_enabled() else None
    val_sampler = DistributedSampler(val_ds, shuffle=False) if ddp_enabled() else None

    collate = lambda b: collate_pad(b, pad_token_id=tokenizer.pad_token_id)

    train_dl = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        sampler=train_sampler,
        shuffle=(train_sampler is None),
        num_workers=args.num_workers,
        pin_memory=True,
        collate_fn=collate,
        drop_last=True,
    )
    val_dl = DataLoader(
        val_ds,
        batch_size=args.eval_batch_size or args.batch_size,
        sampler=val_sampler,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
        collate_fn=collate,
        drop_last=False,
    )

    # Optimizer over the projector only (exp1 parity).
    base_model = model.module if hasattr(model, "module") else model
    tunable_params = list(base_model.prompt_proj.parameters())
    optimizer = torch.optim.AdamW(tunable_params, lr=args.lr, weight_decay=args.weight_decay)

    # GradScaler is only needed for fp16; bf16 autocast does not require it.
    use_amp = (device.type == "cuda") and args.use_amp
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    best_val = float("inf")
    global_step = 0

    if is_main():
        print(f"[Train] world={w} batch={args.batch_size} grad_accum={args.grad_accum} lr={args.lr}")
        print(f"[Train] max_len={args.max_len} prompt_len={args.prompt_len} use_bf16={args.use_bf16} use_amp={use_amp}")

    # -------------------------
    # Training
    # -------------------------
    for epoch in range(1, args.epochs + 1):
        if ddp_enabled() and train_sampler is not None:
            train_sampler.set_epoch(epoch)

        (model.module if hasattr(model, "module") else model).train()
        optimizer.zero_grad(set_to_none=True)

        running_loss = 0.0
        running_n = 0
        t0 = time.time()

        pbar = tqdm(train_dl, desc=f"Epoch {epoch}/{args.epochs}", dynamic_ncols=True) if is_main() else train_dl

        for step, (b_emb, b_ids, b_mask, _) in enumerate(pbar):
            b_emb = b_emb.to(device, non_blocking=True)
            b_ids = b_ids.to(device, non_blocking=True)
            b_mask = b_mask.to(device, non_blocking=True)

            autocast_enabled = (device.type == "cuda")
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=autocast_enabled and args.use_bf16):
                loss = model(b_emb, b_ids, b_mask) / args.grad_accum

            # backward
            if use_amp:
                scaler.scale(loss).backward()
            else:
                loss.backward()

            # optimizer update every grad_accum micro-batches
            if (step + 1) % args.grad_accum == 0:
                if args.clip_grad > 0:
                    if use_amp:
                        scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(tunable_params, args.clip_grad)

                if use_amp:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()
                optimizer.zero_grad(set_to_none=True)

                global_step += 1

            # track train loss (unscaled)
            loss_item = float(loss.item()) * args.grad_accum
            running_loss += loss_item * b_emb.size(0)
            running_n += b_emb.size(0)

            if is_main() and (global_step % args.log_every == 0):
                avg = running_loss / max(1, running_n)
                pbar.set_postfix(loss=f"{avg:.4f}")
                if args.wandb:
                    wandb.log(
                        {
                            "train/loss": float(avg),
                            "train/epoch": epoch,
                            "train/step": global_step,
                            "lr": float(optimizer.param_groups[0]["lr"]),
                        },
                        step=global_step,
                    )

            # periodic eval
            if args.eval_every_steps > 0 and (global_step % args.eval_every_steps == 0):
                if ddp_enabled() and val_sampler is not None:
                    val_sampler.set_epoch(epoch)

                val_loss = evaluate(model, val_dl, device=device, use_bf16=args.use_bf16)

                if is_main():
                    rec = {
                        "type": "eval",
                        "epoch": epoch,
                        "step": global_step,
                        "val_loss": float(val_loss),
                        "time": time.time(),
                    }
                    logger.log(rec)
                    print(f"[Eval] epoch={epoch} step={global_step} val_loss={val_loss:.4f}")

                    if args.wandb:
                        wandb.log(
                            {"val/loss": float(val_loss), "val/epoch": epoch, "val/step": global_step},
                            step=global_step,
                        )

                    # Save best on rank 0 only (avoids concurrent writes to
                    # the same checkpoint file; val_loss is identical on all
                    # ranks after the all-reduce inside evaluate()).
                    if val_loss < best_val:
                        best_val = val_loss
                        best_path = os.path.join(args.out_dir, "checkpoints", "ckpt_best_val.pt")
                        save_tunable_parameters(model.module if hasattr(model, "module") else model, best_path)
                        print(f"[Save] new best -> {best_path} (val_loss={best_val:.4f})")

                # IMPORTANT: switch back to train mode after eval
                (model.module if hasattr(model, "module") else model).train()

        # epoch end: compute epoch train loss (reduce across ranks)
        t = torch.tensor([running_loss, float(running_n)], device=device, dtype=torch.float64)
        ddp_all_reduce_sum(t)
        tr_loss = (t[0] / t[1]).item() if t[1].item() > 0 else 0.0

        # epoch end eval
        if ddp_enabled() and val_sampler is not None:
            val_sampler.set_epoch(epoch)
        val_loss = evaluate(model, val_dl, device=device, use_bf16=args.use_bf16)

        if is_main():
            dt = time.time() - t0
            rec = {
                "type": "epoch",
                "epoch": epoch,
                "step": global_step,
                "train_loss": float(tr_loss),
                "val_loss": float(val_loss),
                "seconds": float(dt),
                "time": time.time(),
            }
            logger.log(rec)
            print(f"[Epoch {epoch}] train_loss={tr_loss:.4f} val_loss={val_loss:.4f} time={dt/60:.1f} min")

            if args.wandb:
                wandb.log(
                    {"epoch/train_loss": float(tr_loss), "epoch/val_loss": float(val_loss), "epoch": epoch},
                    step=global_step,
                )

            # save epoch ckpt (rank 0 only)
            ep_path = os.path.join(args.out_dir, "checkpoints", f"ckpt_epoch{epoch}.pt")
            save_tunable_parameters(model.module if hasattr(model, "module") else model, ep_path)
            print(f"[Save] epoch ckpt -> {ep_path}")

            if val_loss < best_val:
                best_val = val_loss
                best_path = os.path.join(args.out_dir, "checkpoints", "ckpt_best_val.pt")
                save_tunable_parameters(model.module if hasattr(model, "module") else model, best_path)
                print(f"[Save] new best -> {best_path} (val_loss={best_val:.4f})")

        ddp_barrier()

    if is_main():
        print("Training done.")

    if args.wandb and is_main():
        wandb.finish()

    if ddp_enabled():
        ddp_cleanup()
def parse_args(argv=None):
    """Parse CLI arguments for soft-prompt projector training.

    Args:
        argv: Optional explicit argument list. ``None`` (the default) keeps
            the existing behavior of reading ``sys.argv[1:]``; passing a list
            makes the parser callable from tests without touching the process
            arguments. Backward compatible with all existing callers.

    Returns:
        argparse.Namespace with all training/runtime options.
    """
    ap = argparse.ArgumentParser()

    # Embeddings
    ap.add_argument("--train_emb_pt", type=str, required=True, help="train z_pred.pt (dict with z_pred + global_indices)")
    ap.add_argument("--val_emb_pt", type=str, required=True, help="val z_pred.pt (dict with z_pred + global_indices)")
    ap.add_argument("--emb_key", type=str, default="z_pred", help="key name of embedding tensor inside pt (default: z_pred)")

    # HF dataset
    ap.add_argument("--hf_dataset_id", type=str, default="ASSERT-KTH/RunBugRun-Final")
    ap.add_argument("--hf_split", type=str, default="train")
    ap.add_argument("--hf_fixed_field", type=str, default="fixed_code")

    # Decoder
    ap.add_argument("--decoder_model_id", type=str, default="bigcode/starcoder2-3b")
    ap.add_argument("--prompt_len", type=int, default=128)
    ap.add_argument("--max_len", type=int, default=512)

    # Wandb
    ap.add_argument("--wandb", action="store_true", help="Enable Weights & Biases logging (rank0 only)")
    ap.add_argument("--wandb_entity", type=str, default="", help="W&B entity (optional)")
    ap.add_argument("--wandb_project", type=str, default="CodeRepair_JEPA", help="W&B project")
    ap.add_argument("--wandb_group", type=str, default="decoder_zpred", help="W&B group")
    ap.add_argument("--wandb_run_name", type=str, default="", help="W&B run name (optional)")
    ap.add_argument("--wandb_id", type=str, default="", help="W&B run id (for resume)")
    ap.add_argument("--wandb_resume", type=str, default="auto", help="W&B resume mode: auto|must|never")

    # Train hyperparams
    ap.add_argument("--out_dir", type=str, required=True)
    ap.add_argument("--epochs", type=int, default=3)
    ap.add_argument("--batch_size", type=int, default=20)
    ap.add_argument("--eval_batch_size", type=int, default=0, help="0 means use batch_size")
    ap.add_argument("--lr", type=float, default=5e-5)
    ap.add_argument("--weight_decay", type=float, default=0.01)
    ap.add_argument("--grad_accum", type=int, default=1)
    ap.add_argument("--clip_grad", type=float, default=1.0)

    # Runtime
    ap.add_argument("--num_workers", type=int, default=4)
    ap.add_argument("--log_every", type=int, default=50)
    ap.add_argument("--eval_every_steps", type=int, default=0, help="0 disables step eval; epoch eval always runs")
    ap.add_argument("--use_bf16", action="store_true", help="use bf16 autocast on CUDA")
    ap.add_argument("--use_amp", action="store_true", help="use fp16 GradScaler (if you want fp16)")

    return ap.parse_args(argv)


if __name__ == "__main__":
    args = parse_args()
    train(args)
def resolve_indices_filename(idx_cfg: Dict[str, Any], embed_cfg: Dict[str, Any]) -> str:
    """Pick the indices filename for the embed run.

    Resolution order:
      1) an explicit ``embed.indices_file`` wins;
      2) otherwise ``embed.split`` (train/val/test, case-insensitive,
         default "test") selects the filename from ``idx_cfg``, falling back
         to the conventional ``{split}_idx.npy`` name.

    Raises:
        ValueError: if the split is not one of train/val/test.
    """
    explicit = embed_cfg.get("indices_file")
    if explicit:
        return str(explicit)

    split = str(embed_cfg.get("split", "test")).lower()
    fallbacks = {
        "train": "train_idx.npy",
        "val": "val_idx.npy",
        "test": "test_idx.npy",
    }
    if split not in fallbacks:
        raise ValueError(f"Unknown embed.split='{split}'. Use train/val/test or set embed.indices_file.")
    return str(idx_cfg.get(split, fallbacks[split]))
+ hf_cfg = cfg["data"]["hf"] + idx_cfg = cfg["data"]["indices"] + + dataset_id = hf_cfg["dataset_id"] + split_name = hf_cfg.get("split", "train") + buggy_key = hf_cfg["fields"]["buggy"] + fixed_key = hf_cfg["fields"]["fixed"] + + indices_dir = idx_cfg["dir"] + global_target = load_indices(os.path.join(indices_dir, idx_cfg["global_target"])) + + idx_filename = resolve_indices_filename(idx_cfg, embed_cfg) + subset_idx = load_indices(os.path.join(indices_dir, idx_filename)) + + # ds_full global indices aligned with ds_subset_selected order + global_indices_all = global_target[subset_idx] # np.ndarray [N_subset] + + # load HF dataset + ds_full = load_dataset(dataset_id, split=split_name) + ds_subset = ds_full.select(global_target.tolist()) + ds_sel = ds_subset.select(subset_idx.tolist()) + + # tokenizer (encoder tokenizer) + enc_name = cfg["encoder"]["name"] + tokenizer = AutoTokenizer.from_pretrained(enc_name, use_fast=True) + + def collate_fn(batch): + buggy = [str(x.get(buggy_key, "")) if x.get(buggy_key, None) is not None else "" for x in batch] + fixed = [str(x.get(fixed_key, "")) if x.get(fixed_key, None) is not None else "" for x in batch] + tok_buggy = tokenizer( + buggy, + padding=True, + truncation=True, + max_length=int(cfg["encoder"]["max_len"]), + return_tensors="pt", + ) + tok_fixed = tokenizer( + fixed, + padding=True, + truncation=True, + max_length=int(cfg["encoder"]["max_len"]), + return_tensors="pt", + ) + return tok_buggy, tok_fixed + + sampler = DistributedSampler(ds_sel, shuffle=False) if ddp_enabled else None + dl = DataLoader( + ds_sel, + batch_size=int(embed_cfg.get("batch_size", cfg["train"]["batch_size"])), + shuffle=False, + sampler=sampler, + num_workers=int(cfg["data"].get("num_workers", 4)), + pin_memory=True, + collate_fn=collate_fn, + drop_last=False, + ) + + # build models + enc_ctx, emb_dim = build_encoder(cfg, device=device) + if enc_ctx is None: + raise ValueError("End2end embed expects encoder != None.") + enc_tgt, _ = 
build_encoder(cfg, device=device) + predictor = build_predictor(cfg, emb_dim=emb_dim, device=device) + + if ddp_enabled: + find_unused = bool(cfg.get("ddp", {}).get("find_unused_parameters", False)) + enc_ctx = DDP(enc_ctx, device_ids=[local_rank()], find_unused_parameters=find_unused) + predictor = DDP(predictor, device_ids=[local_rank()], find_unused_parameters=find_unused) + + # load checkpoint + if is_main(): + print(f"[Embed] Loading ckpt: {ckpt_path}") + print(f"[Embed] indices_file: {idx_filename} (N={len(ds_sel)})") + print(f"[Embed] save_path: {save_path}") + + if ddp_enabled and torch.distributed.is_initialized(): + torch.distributed.barrier() + ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False) + if ddp_enabled and torch.distributed.is_initialized(): + torch.distributed.barrier() + + unwrap_ddp(enc_ctx).load_state_dict(ckpt["enc_ctx"], strict=True) + enc_tgt.load_state_dict(ckpt["enc_tgt"], strict=True) + unwrap_ddp(predictor).load_state_dict(ckpt["predictor"], strict=True) + + enc_ctx.eval() + predictor.eval() + enc_tgt.eval() + for p in enc_tgt.parameters(): + p.requires_grad = False + + if save_dtype == "float16": + out_cast = torch.float16 + elif save_dtype == "float32": + out_cast = torch.float32 + else: + raise ValueError("cfg.embed.save_dtype must be 'float16' or 'float32'") + + # storage + out_ctx: List[torch.Tensor] = [] + # out_pred: List[torch.Tensor] = [] + # out_tgt: List[torch.Tensor] = [] + out_gidx: List[int] = [] + + + # metrics meters (rank-local) + m_top1_sum = 0.0 + m_cos_sum = 0.0 + m_std_pred_sum = 0.0 + m_std_tgt_sum = 0.0 + m_n = 0 + + def cosine_mean(a: torch.Tensor, b: torch.Tensor) -> float: + a = a.float() + b = b.float() + a = torch.nn.functional.normalize(a, dim=-1) + b = torch.nn.functional.normalize(b, dim=-1) + return (a * b).sum(dim=-1).mean().item() + + if ddp_enabled and sampler is not None: + sampler.set_epoch(0) + + it_dl = tqdm(dl, desc="Embed", dynamic_ncols=True) if is_main() else dl + + # 
IMPORTANT: track position within THIS rank's sampler order (rank-local) + pos_in_rank = 0 + + for (tok_buggy, tok_fixed) in it_dl: + if max_items > 0 and m_n >= max_items: + break + + tok_buggy = to_device(tok_buggy, device) + tok_fixed = to_device(tok_fixed, device) + + z_ctx = enc_ctx(tok_buggy["input_ids"], tok_buggy["attention_mask"]) + # z_pred = predictor(z_ctx) + # z_tgt = enc_tgt(tok_fixed["input_ids"], tok_fixed["attention_mask"]) + + # metrics (batch-level) + # top1 = retrieval_top1_acc(z_pred, z_tgt).item() + # cos = cosine_mean(z_pred, z_tgt) + # std_pred = emb_std_mean(z_pred).item() + # std_tgt = emb_std_mean(z_tgt).item() + + bsz = int(z_ctx.size(0)) + # m_top1_sum += top1 * bsz + # m_cos_sum += cos * bsz + # m_std_pred_sum += std_pred * bsz + # m_std_tgt_sum += std_tgt * bsz + m_n += bsz + + # save tensors + out_ctx.append(z_ctx.detach().to("cpu", dtype=out_cast)) + # out_pred.append(z_pred.detach().to("cpu", dtype=out_cast)) + # out_tgt.append(z_tgt.detach().to("cpu", dtype=out_cast)) + + # save ds_full GLOBAL indices aligned with embeddings + ds_pos = np.arange(pos_in_rank, pos_in_rank + bsz, dtype=np.int64) + pos_in_rank += bsz + gidx = global_indices_all[ds_pos] + out_gidx.extend(gidx.tolist()) + + # concat + z_ctx_all = torch.cat(out_ctx, dim=0) if out_ctx else torch.empty((0, emb_dim), dtype=out_cast) + # z_pred_all = torch.cat(out_pred, dim=0) if out_pred else torch.empty((0, emb_dim), dtype=out_cast) + # z_tgt_all = torch.cat(out_tgt, dim=0) if out_tgt else torch.empty((0, emb_dim), dtype=out_cast) + gidx_all = torch.tensor(out_gidx, dtype=torch.long) + + # save per-rank shards as dicts (format aligned with your request) + ctx_path = os.path.join(save_path, f"z_ctx.rank{rank()}.pt") + # pred_path = os.path.join(save_path, f"z_pred.rank{rank()}.pt") + # tgt_path = os.path.join(save_path, f"z_tgt.rank{rank()}.pt") + + save_pt({"z_ctx": z_ctx_all, "global_indices": gidx_all}, ctx_path) + # save_pt({"z_pred": z_pred_all, "global_indices": 
gidx_all}, pred_path) + # save_pt({"z_tgt": z_tgt_all, "global_indices": gidx_all}, tgt_path) + + if is_main(): + # print(f"[Embed] wrote shards:\n {ctx_path}\n {pred_path}\n {tgt_path}") + print(f"[Embed] wrote shards:\n {ctx_path}") + + # reduce metrics across ranks + # t = torch.tensor( + # [m_top1_sum, m_cos_sum, m_std_pred_sum, m_std_tgt_sum, float(m_n)], + # device=device, + # dtype=torch.float64, + # ) + # ddp_all_reduce_sum(t) + # top1_avg = (t[0] / t[4]).item() if t[4].item() > 0 else 0.0 + # cos_avg = (t[1] / t[4]).item() if t[4].item() > 0 else 0.0 + # std_pred_avg = (t[2] / t[4]).item() if t[4].item() > 0 else 0.0 + # std_tgt_avg = (t[3] / t[4]).item() if t[4].item() > 0 else 0.0 + + # metrics = { + # "top1": float(top1_avg), + # "cos": float(cos_avg), + # "std_pred": float(std_pred_avg), + # "std_tgt": float(std_tgt_avg), + # "n": int(t[4].item()), + # "indices_file": idx_filename, + # } + + t = torch.tensor([float(m_n)], device=device, dtype=torch.float64) + ddp_all_reduce_sum(t) + + metrics = { + "n": int(t[0].item()), + "indices_file": idx_filename, + } + + save_json(metrics, os.path.join(save_path, f"metrics.rank{rank()}.json")) + if is_main(): + print(f"[Embed] metrics: {metrics}") + + # merge on rank0 + if ddp_enabled and torch.distributed.is_initialized(): + torch.distributed.barrier() + try_merge_on_rank0(save_path=save_path, world=world) + + if ddp_enabled: + ddp_cleanup() + + +# ------------------------- +# CLI +# ------------------------- +def parse_args(): + ap = argparse.ArgumentParser() + ap.add_argument("--config", type=str, required=True, help="Path to exp config YAML") + ap.add_argument("--set", type=str, action="append", default=[], help="Override config, e.g. 
--set embed.split=train") + return ap.parse_args() + +def main(): + args = parse_args() + cfg = resolve_config(args.config, args.set) + run_embed(cfg) + +if __name__ == "__main__": + main() diff --git a/JEPA/src/tasks/encoder/train.py b/JEPA/src/tasks/encoder/train.py new file mode 100644 index 0000000..13d3459 --- /dev/null +++ b/JEPA/src/tasks/encoder/train.py @@ -0,0 +1,805 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +train_jepa.py (refactored) + +End-to-end JEPA-style training (exp2): +- Student/context encoder: ModernBERT (full / LoRA) +- Teacher/target encoder: EMA copy (stop-grad) +- Predictor: vit1d or mlp +- Loss: exp2 minimal (cosine align + variance regularizer) + +Features: +- Config merge: configs/base.yaml + configs/exp*.yaml + CLI --set overrides +- Split by saved_indices/*.npy +- DDP via torchrun +- W&B logging (project/group/run_name placeholders allowed) + +Expected files: + models.py, losses.py, utils.py +""" + +from __future__ import annotations + +import os +import json +import time +import argparse +import random +from typing import Any, Dict, List, Optional + +import numpy as np +import torch +import torch.nn as nn +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from torch.nn.parallel import DistributedDataParallel as DDP +import torch.nn.functional as F + +from transformers import AutoTokenizer +from pathlib import Path +from tqdm import tqdm + +from jepa.models import build_encoder, build_predictor +from jepa.losses import build_loss, retrieval_top1_acc, emb_std_mean +from jepa.utils import ( + load_yaml, + deep_update, + apply_overrides, + save_resolved_config, + is_ddp_env, + is_main, + rank, + local_rank, + ddp_setup, + ddp_cleanup, + ddp_all_reduce_sum, + seed_everything, + ema_update, + unwrap_ddp, + AverageMeter, +) + +from datasets import load_dataset + + +# W&B is optional; only used if cfg["wandb"]["enabled"] = True +try: + import wandb # type: ignore +except 
# -------------------------
# Config resolve
# -------------------------
def resolve_config(exp_config_path: str, overrides: List[str]) -> Dict[str, Any]:
    """Merge base.yaml (sibling of the exp config) with the exp config, then apply CLI overrides."""
    cfg = load_yaml(os.path.join(os.path.dirname(exp_config_path), "base.yaml"))
    cfg = deep_update(cfg, load_yaml(exp_config_path))
    return apply_overrides(cfg, overrides)


# -------------------------
# Data
# -------------------------
def load_indices(indices_dir: str, filename: str) -> np.ndarray:
    """Load a saved index array from <indices_dir>/<filename>."""
    path = os.path.join(indices_dir, filename)
    return np.load(path)


def to_device(batch_tok: Dict[str, torch.Tensor], device: torch.device) -> Dict[str, torch.Tensor]:
    """Move every tensor of a tokenizer batch dict to *device* (non-blocking)."""
    moved: Dict[str, torch.Tensor] = {}
    for name, tensor in batch_tok.items():
        moved[name] = tensor.to(device, non_blocking=True)
    return moved


# -------------------------
# RNG helpers (for more faithful resume)
# -------------------------
def get_rng_state() -> Dict[str, Any]:
    """Snapshot python / numpy / torch (and per-CUDA-device) RNG states for resume."""
    cuda_states = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None
    return {
        "python_random_state": random.getstate(),
        "numpy_random_state": np.random.get_state(),
        "torch_rng_state": torch.random.get_rng_state(),
        "cuda_rng_state_all": cuda_states,
    }


def set_rng_state(state: Dict[str, Any]) -> None:
    """Best-effort restore of a snapshot made by get_rng_state(); failures only warn."""
    try:
        py_state = state.get("python_random_state", None)
        if py_state is not None:
            random.setstate(py_state)
        np_state = state.get("numpy_random_state", None)
        if np_state is not None:
            np.random.set_state(np_state)
        torch_state = state.get("torch_rng_state", None)
        if torch_state is not None:
            torch.random.set_rng_state(torch_state)
        cuda_states = state.get("cuda_rng_state_all", None)
        if torch.cuda.is_available() and cuda_states is not None:
            torch.cuda.set_rng_state_all(cuda_states)
    except Exception as e:
        # RNG restore is optional; never fail a resume because of it.
        if is_main():
            print(f"[Resume] Warning: failed to restore RNG state: {e}")
# -------------------------
# Checkpoint
# -------------------------
def save_checkpoint(
    path: str,
    enc_ctx: nn.Module,
    enc_tgt: nn.Module,
    predictor: nn.Module,
    optimizer: torch.optim.Optimizer,
    scaler: Optional[torch.cuda.amp.GradScaler],
    step: int,
    epoch: int,
    it: int,
    cfg: Dict[str, Any],
) -> None:
    """Serialize the full training state (models, optimizer, scaler, RNG) to *path*.

    Args:
        path: destination file for torch.save.
        enc_ctx / enc_tgt / predictor: student, EMA-teacher, and predictor modules
            (possibly DDP-wrapped; unwrapped before state_dict()).
        optimizer / scaler: optimizer and (optional) AMP grad scaler.
        step: global optimizer-step counter.
        epoch: current epoch.
        it: dataloader batch index inside the epoch (best-effort; used to skip
            already-seen batches on resume).
        cfg: resolved config, stored alongside the weights for reproducibility.
    """
    ckpt = {
        "step": int(step),
        "epoch": int(epoch),
        "it": int(it),
        "cfg": cfg,
        "enc_ctx": unwrap_ddp(enc_ctx).state_dict(),
        # BUGFIX: the previous `... if hasattr(enc_tgt, "state_dict") else
        # enc_tgt.state_dict()` fallback invoked the very method whose absence
        # it was guarding against, so the else-branch could never succeed.
        # Every nn.Module has state_dict(); call it unconditionally (through
        # unwrap_ddp, which is a no-op for non-DDP modules).
        "enc_tgt": unwrap_ddp(enc_tgt).state_dict(),
        "predictor": unwrap_ddp(predictor).state_dict(),
        "optimizer": optimizer.state_dict(),
        "scaler": scaler.state_dict() if scaler is not None else None,
        "rng_state": get_rng_state(),
    }
    torch.save(ckpt, path)


class JSONLLogger:
    """Append-only JSON-lines logger; creates the parent directory on construction."""

    def __init__(self, path: str):
        self.path = path
        Path(os.path.dirname(path)).mkdir(parents=True, exist_ok=True)

    def log(self, record: Dict[str, Any]):
        """Append one record as a single JSON line (UTF-8, non-ASCII preserved)."""
        with open(self.path, "a", encoding="utf-8") as f:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")


# -------------------------
# W&B
# -------------------------
def wandb_init_if_needed(cfg: Dict[str, Any], out_dir: str) -> None:
    """Initialise Weights & Biases on the main rank when cfg.wandb.enabled is set.

    No-op on non-main ranks or when logging is disabled.

    Raises:
        RuntimeError: wandb requested in the config but the package is not installed.
    """
    wcfg = cfg.get("wandb", {})
    if not bool(wcfg.get("enabled", False)):
        return
    if wandb is None:
        raise RuntimeError("wandb enabled but wandb not installed.")
    if not is_main():
        return

    project = wcfg.get("project") or None
    group = wcfg.get("group") or None
    name = wcfg.get("run_name") or None
    entity = wcfg.get("entity") or None

    run_id = wcfg.get("id") or None
    resume_mode = wcfg.get("resume", "auto")  # "auto" or "must"

    wandb.init(
        entity=entity,
        project=project,
        group=group,
        name=name,
        id=run_id,
        resume=resume_mode,
        config=cfg,
        dir=out_dir,
    )
{}).get("enabled", False)): + return + if wandb is None or not is_main(): + return + wandb.log(metrics, step=step) + + +def wandb_finish_if_needed(cfg: Dict[str, Any]) -> None: + if not bool(cfg.get("wandb", {}).get("enabled", False)): + return + if wandb is None or not is_main(): + return + wandb.finish() + + +# ------------------------- +# Train +# ------------------------- +def train(cfg: Dict[str, Any]) -> None: + # DDP init + ddp_enabled = bool(cfg.get("ddp", {}).get("enabled", False)) and is_ddp_env() + if ddp_enabled: + ddp_setup(cfg.get("ddp", {}).get("backend", "nccl")) + + device = torch.device("cuda", local_rank()) if torch.cuda.is_available() else torch.device("cpu") + + # Seed (rank-shift) + seed = int(cfg.get("seed", 42)) + rank() + seed_everything(seed) + + # Output dir per run (good for sbatch arrays) + run_name = cfg.get("run", {}).get("run_name", "run") + save_root = cfg.get("run", {}).get("save_dir", "./checkpoints_jepa") + job_id = os.environ.get("SLURM_JOB_ID", "") + ts = time.strftime("%Y%m%d_%H%M%S") + run_folder = f"{run_name}_{ts}" + (f"_job{job_id}" if job_id else "") + out_dir = os.path.join(save_root, run_folder) + ckpt_dir = os.path.join(out_dir, "checkpoints") + metrics_logger = None + if is_main(): + metrics_logger = JSONLLogger(os.path.join(out_dir, "metrics.jsonl")) + + best_val_top1 = -1.0 + best_path = os.path.join(ckpt_dir, "ckpt_best.pt") + last_path = os.path.join(ckpt_dir, "ckpt_last.pt") + + if is_main(): + os.makedirs(ckpt_dir, exist_ok=True) + save_resolved_config(cfg, out_dir) + + # W&B + wandb_init_if_needed(cfg, out_dir) + + # Tokenizer + enc_name = cfg["encoder"]["name"] + tokenizer = AutoTokenizer.from_pretrained(enc_name, use_fast=True) + + # Data (HF + two-level indices) + assert cfg["data"]["source"] == "hf", "This train script currently expects data.source=hf." 
+ + hf_cfg = cfg["data"]["hf"] + idx_cfg = cfg["data"]["indices"] + + dataset_id = hf_cfg["dataset_id"] + split_name = hf_cfg.get("split", "train") + + indices_dir = idx_cfg["dir"] + global_target_idx = load_indices(indices_dir, idx_cfg["global_target"]) + train_idx = load_indices(indices_dir, idx_cfg["train"]) + val_idx = load_indices(indices_dir, idx_cfg["val"]) + + ds_full = load_dataset(dataset_id, split=split_name) + ds_subset = ds_full.select(global_target_idx.tolist()) + ds_train = ds_subset.select(train_idx.tolist()) + ds_val = ds_subset.select(val_idx.tolist()) + + buggy_key = hf_cfg["fields"]["buggy"] + fixed_key = hf_cfg["fields"]["fixed"] + + def collate_fn(batch): + buggy = [str(x.get(buggy_key, "")) if x.get(buggy_key, None) is not None else "" for x in batch] + fixed = [str(x.get(fixed_key, "")) if x.get(fixed_key, None) is not None else "" for x in batch] + tok_buggy = tokenizer( + buggy, + padding=True, + truncation=True, + max_length=int(cfg["encoder"]["max_len"]), + return_tensors="pt", + ) + tok_fixed = tokenizer( + fixed, + padding=True, + truncation=True, + max_length=int(cfg["encoder"]["max_len"]), + return_tensors="pt", + ) + return tok_buggy, tok_fixed + + train_sampler = DistributedSampler(ds_train, shuffle=True) if ddp_enabled else None + val_sampler = DistributedSampler(ds_val, shuffle=False) if ddp_enabled else None + + dl_train = DataLoader( + ds_train, + batch_size=int(cfg["train"]["batch_size"]), + shuffle=(train_sampler is None), + sampler=train_sampler, + num_workers=int(cfg["data"].get("num_workers", 4)), + pin_memory=True, + collate_fn=collate_fn, + drop_last=True, + ) + dl_val = DataLoader( + ds_val, + batch_size=int(cfg["train"]["batch_size"]), + shuffle=False, + sampler=val_sampler, + num_workers=int(cfg["data"].get("num_workers", 4)), + pin_memory=True, + collate_fn=collate_fn, + drop_last=False, + ) + + # Models + enc_ctx, emb_dim = build_encoder(cfg, device=device) + if enc_ctx is None: + raise ValueError("This 
train_jepa.py is end2end; encoder must not be None.") + + # Teacher encoder (EMA target): SAME architecture as student (including LoRA), but frozen + enc_tgt, _ = build_encoder(cfg, device=device) + enc_tgt.load_state_dict(unwrap_ddp(enc_ctx).state_dict(), strict=True) + for p in enc_tgt.parameters(): + p.requires_grad = False + enc_tgt.eval() + + predictor = build_predictor(cfg, emb_dim=emb_dim, device=device) + + # Wrap student + predictor with DDP + if ddp_enabled: + find_unused = bool(cfg.get("ddp", {}).get("find_unused_parameters", False)) + enc_ctx = DDP(enc_ctx, device_ids=[local_rank()], find_unused_parameters=find_unused) + predictor = DDP(predictor, device_ids=[local_rank()], find_unused_parameters=find_unused) + + # ------------------------- + # Resume (part 1): load weights + global_step (+ epoch/it if present) + # ------------------------- + resume_path = cfg["train"].get("resume_from", "") or "" + resume_strict = bool(cfg["train"].get("resume_strict", True)) + restore_rng = bool(cfg["train"].get("resume_restore_rng", False)) # optional + ckpt = None + + if resume_path: + if is_main(): + print(f"[Resume] Loading checkpoint: {resume_path}") + # ckpt = torch.load(resume_path, map_location="cpu") + + # (optional) DDP barrier for more stable shared-fs IO + if ddp_enabled and torch.distributed.is_available() and torch.distributed.is_initialized(): + torch.distributed.barrier() + + ckpt = torch.load(resume_path, map_location="cpu", weights_only=False) + + if ddp_enabled and torch.distributed.is_available() and torch.distributed.is_initialized(): + torch.distributed.barrier() + + unwrap_ddp(enc_ctx).load_state_dict(ckpt["enc_ctx"], strict=resume_strict) + enc_tgt.load_state_dict(ckpt["enc_tgt"], strict=resume_strict) + unwrap_ddp(predictor).load_state_dict(ckpt["predictor"], strict=resume_strict) + + global_step = int(ckpt.get("step", 0)) + resume_epoch = int(ckpt.get("epoch", 0)) + resume_it = int(ckpt.get("it", 0)) + + if restore_rng and 
ckpt.get("rng_state", None) is not None: + set_rng_state(ckpt["rng_state"]) + + if is_main(): + print(f"[Resume] global_step restored to {global_step}") + print(f"[Resume] epoch/it from ckpt: epoch={resume_epoch}, it={resume_it}") + else: + global_step = 0 + resume_epoch = 0 + resume_it = 0 + + # Loss + loss_fn = build_loss(cfg).to(device) + + # Optimizer (separate LR for encoder vs predictor) + lr_encoder = float(cfg["train"].get("lr_encoder", cfg["train"].get("lr", 2e-5))) + lr_predictor = float(cfg["train"].get("lr_predictor", cfg["train"].get("lr", 2e-5))) + wd = float(cfg["train"].get("weight_decay", 0.01)) + betas = cfg.get("optim", {}).get("betas", [0.9, 0.999]) + eps = float(cfg.get("optim", {}).get("eps", 1e-8)) + + enc_params = [p for p in unwrap_ddp(enc_ctx).parameters() if p.requires_grad] + pred_params = [p for p in unwrap_ddp(predictor).parameters() if p.requires_grad] + + optimizer = torch.optim.AdamW( + [ + {"params": enc_params, "lr": lr_encoder}, + {"params": pred_params, "lr": lr_predictor}, + ], + weight_decay=wd, + betas=tuple(betas), + eps=eps, + ) + + # AMP + use_fp16 = bool(cfg["train"].get("fp16", True)) and device.type == "cuda" + scaler = torch.cuda.amp.GradScaler(enabled=use_fp16) + + # ------------------------- + # Resume (part 2): load optimizer/scaler (must be after creation) + # ------------------------- + if ckpt is not None: + if ckpt.get("optimizer", None) is not None: + optimizer.load_state_dict(ckpt["optimizer"]) + if scaler is not None and ckpt.get("scaler", None) is not None: + scaler.load_state_dict(ckpt["scaler"]) + if is_main(): + print("[Resume] optimizer/scaler restored.") + + + + # Train settings + epochs = int(cfg["train"]["epochs"]) + grad_accum = int(cfg["train"].get("grad_accum", 1)) + log_every = int(cfg["train"].get("log_every", 50)) + tau = float(cfg.get("ema", {}).get("tau", 0.996)) + save_ckpt = bool(cfg["train"].get("save_ckpt", True)) + + init_path = os.path.join(ckpt_dir, "ckpt_init.pt") + + if save_ckpt and 
is_main() and ckpt is None: + save_checkpoint( + init_path, + enc_ctx, + enc_tgt, + predictor, + optimizer, + scaler, + step=0, + epoch=0, + it=0, + cfg=cfg, + ) + print(f"[Init] Saved initial checkpoint to {init_path}") + + save_every_epoch = bool(cfg["train"].get("save_every_epoch", True)) + eval_every_steps = int(cfg["train"].get("eval_every_steps", 0)) # 0 means disabled + save_every_steps = int(cfg["train"].get("save_every_steps", 0)) # 0 means disabled + + # --------------------------------------------------------- + # IMPORTANT FIX: + # Do NOT reset global_step here. Previously you had: global_step = 0 + # That overwrote the resumed step and made it look like training restarted. + # --------------------------------------------------------- + + def cosine_mean(a: torch.Tensor, b: torch.Tensor) -> float: + return F.cosine_similarity(a.float(), b.float(), dim=-1).mean().item() + + def run_validation(epoch: int, global_step: int) -> Dict[str, float]: + enc_ctx.eval() + predictor.eval() + enc_tgt.eval() + if ddp_enabled and val_sampler is not None: + val_sampler.set_epoch(epoch) + + v_loss = AverageMeter() + v_align = AverageMeter() + v_var = AverageMeter() + v_top1 = AverageMeter() + v_std_ctx = AverageMeter() + v_std_tgt = AverageMeter() + v_cos = AverageMeter() + + with torch.no_grad(): + for tok_buggy, tok_fixed in dl_val: + tok_buggy = to_device(tok_buggy, device) + tok_fixed = to_device(tok_fixed, device) + + z_ctx = enc_ctx(tok_buggy["input_ids"], tok_buggy["attention_mask"]) + z_tgt = enc_tgt(tok_fixed["input_ids"], tok_fixed["attention_mask"]) + z_pred = predictor(z_ctx) + + out = loss_fn(z_ctx=z_ctx, z_pred=z_pred, z_tgt=z_tgt) + + bsz = int(z_ctx.size(0)) + top1 = retrieval_top1_acc(z_pred, z_tgt).item() + std_ctx = emb_std_mean(z_ctx).item() + std_tgt = emb_std_mean(z_tgt).item() + cos = cosine_mean(z_pred, z_tgt) + + v_loss.update(out["loss"].item(), bsz) + v_align.update(out["align"].item(), bsz) + v_var.update(out["var"].item(), bsz) + 
v_top1.update(top1, bsz) + v_std_ctx.update(std_ctx, bsz) + v_std_tgt.update(std_tgt, bsz) + v_cos.update(cos, bsz) + + # DDP reduce + t = torch.tensor( + [ + v_loss.sum, v_loss.count, + v_align.sum, v_align.count, + v_var.sum, v_var.count, + v_top1.sum, v_top1.count, + v_std_ctx.sum, v_std_ctx.count, + v_std_tgt.sum, v_std_tgt.count, + v_cos.sum, v_cos.count, + ], + device=device, + dtype=torch.float64, + ) + ddp_all_reduce_sum(t) + + vals = t.tolist() + + def avg(s, c): + return s / max(1.0, c) + + metrics = { + "val_loss": avg(vals[0], vals[1]), + "val_align": avg(vals[2], vals[3]), + "val_var": avg(vals[4], vals[5]), + "val_top1": avg(vals[6], vals[7]), + "val_std_ctx": avg(vals[8], vals[9]), + "val_std_tgt": avg(vals[10], vals[11]), + "val_cos": avg(vals[12], vals[13]), + } + + if is_main(): + print( + f"[ep {epoch} step {global_step}] VAL " + f"loss={metrics['val_loss']:.4f} align={metrics['val_align']:.4f} var={metrics['val_var']:.4f} " + f"top1={metrics['val_top1']:.3f} std_ctx={metrics['val_std_ctx']:.3f} std_tgt={metrics['val_std_tgt']:.3f} " + f"cos={metrics['val_cos']:.3f}" + ) + + wandb_log_if_needed( + cfg, + { + "val/loss": metrics["val_loss"], + "val/align": metrics["val_align"], + "val/var": metrics["val_var"], + "val/top1": metrics["val_top1"], + "val/std_ctx": metrics["val_std_ctx"], + "val/std_tgt": metrics["val_std_tgt"], + "val/cos": metrics["val_cos"], + "epoch": epoch, + }, + step=global_step, + ) + + if metrics_logger is not None: + metrics_logger.log( + { + "split": "val", + "step": global_step, + "epoch": epoch, + "loss": metrics["val_loss"], + "align": metrics["val_align"], + "var": metrics["val_var"], + "top1": metrics["val_top1"], + "std_ctx": metrics["val_std_ctx"], + "std_tgt": metrics["val_std_tgt"], + "cos": metrics["val_cos"], + "time": time.time(), + } + ) + + return metrics + + # --------------------------------------------------------- + # Resume-aware epoch + dataloader skipping + # 
--------------------------------------------------------- + # global_step counts optimizer steps (after grad_accum). + steps_per_epoch = max(1, (len(dl_train) // max(1, grad_accum))) + start_epoch = global_step // steps_per_epoch + start_step_in_epoch = global_step % steps_per_epoch + + if is_main(): + print(f"[Resume] steps_per_epoch={steps_per_epoch} (len(dl_train)={len(dl_train)}, grad_accum={grad_accum})") + print(f"[Resume] start_epoch={start_epoch}, start_step_in_epoch={start_step_in_epoch}") + + for epoch in range(start_epoch, epochs): + if ddp_enabled: + train_sampler.set_epoch(epoch) + + enc_ctx.train() + predictor.train() + + meter_loss = AverageMeter() + meter_align = AverageMeter() + meter_var = AverageMeter() + meter_top1 = AverageMeter() + meter_std_ctx = AverageMeter() + meter_std_tgt = AverageMeter() + meter_cos = AverageMeter() + + optimizer.zero_grad(set_to_none=True) + + # Only skip for the very first resumed epoch. + skip_batches = 0 + if epoch == start_epoch and start_step_in_epoch > 0: + skip_batches = start_step_in_epoch * grad_accum + if is_main(): + print(f"[Resume] Skipping {skip_batches} batches to align with global_step={global_step}.") + + if is_main(): + iter_dl = tqdm(dl_train, desc=f"Epoch {epoch}", dynamic_ncols=True) + else: + iter_dl = dl_train + + for it, (tok_buggy, tok_fixed) in enumerate(iter_dl): + if skip_batches > 0 and it < skip_batches: + continue + + tok_buggy = to_device(tok_buggy, device) + tok_fixed = to_device(tok_fixed, device) + + with torch.cuda.amp.autocast(enabled=use_fp16): + z_ctx = enc_ctx(tok_buggy["input_ids"], tok_buggy["attention_mask"]) + with torch.no_grad(): + z_tgt = enc_tgt(tok_fixed["input_ids"], tok_fixed["attention_mask"]) + z_pred = predictor(z_ctx) + + out = loss_fn(z_ctx=z_ctx, z_pred=z_pred, z_tgt=z_tgt) + loss = out["loss"] / grad_accum + + scaler.scale(loss).backward() + + # Step + if (it + 1) % grad_accum == 0: + scaler.step(optimizer) + scaler.update() + 
optimizer.zero_grad(set_to_none=True) + + # EMA update teacher + ema_update(enc_tgt, unwrap_ddp(enc_ctx), tau=tau) + + global_step += 1 + + # ---- step-based checkpoint saving ---- + if save_ckpt and save_every_steps > 0 and (global_step % save_every_steps == 0) and is_main(): + save_checkpoint( + os.path.join(ckpt_dir, f"ckpt_step{global_step}.pt"), + enc_ctx, enc_tgt, predictor, optimizer, scaler, + global_step, epoch, it, cfg + ) + + # ---- step-based validation + best checkpoint ---- + if eval_every_steps > 0 and (global_step % eval_every_steps == 0): + metrics = run_validation(epoch=epoch, global_step=global_step) + + if save_ckpt and is_main() and metrics["val_top1"] > best_val_top1: + best_val_top1 = metrics["val_top1"] + save_checkpoint( + best_path, + enc_ctx, enc_tgt, predictor, optimizer, scaler, + global_step, epoch, it, cfg + ) + print(f"[ep {epoch} step {global_step}] New BEST ckpt: val_top1={best_val_top1:.3f}") + wandb_log_if_needed(cfg, {"best/val_top1": best_val_top1, "best/epoch": epoch}, step=global_step) + + # resume train mode + enc_ctx.train() + predictor.train() + + # Metrics (local) + with torch.no_grad(): + top1 = retrieval_top1_acc(z_pred, z_tgt).item() + std_ctx = emb_std_mean(z_ctx).item() + std_tgt = emb_std_mean(z_tgt).item() + cos = cosine_mean(z_pred, z_tgt) + + bsz = int(z_ctx.size(0)) + meter_loss.update(out["loss"].item(), bsz) + meter_align.update(out["align"].item(), bsz) + meter_var.update(out["var"].item(), bsz) + meter_top1.update(top1, bsz) + meter_std_ctx.update(std_ctx, bsz) + meter_std_tgt.update(std_tgt, bsz) + meter_cos.update(cos, bsz) + + if global_step % log_every == 0: + # DDP reduce meters (sum,count) using a tensor + t = torch.tensor( + [ + meter_loss.sum, meter_loss.count, + meter_align.sum, meter_align.count, + meter_var.sum, meter_var.count, + meter_top1.sum, meter_top1.count, + meter_std_ctx.sum, meter_std_ctx.count, + meter_std_tgt.sum, meter_std_tgt.count, + meter_cos.sum, meter_cos.count, + ], + 
device=device, + dtype=torch.float64, + ) + ddp_all_reduce_sum(t) + + if is_main(): + vals = t.tolist() + + def avg(s, c): + return s / max(1.0, c) + + tr_loss = avg(vals[0], vals[1]) + tr_align = avg(vals[2], vals[3]) + tr_var = avg(vals[4], vals[5]) + tr_top1 = avg(vals[6], vals[7]) + tr_std_ctx = avg(vals[8], vals[9]) + tr_std_tgt = avg(vals[10], vals[11]) + tr_cos = avg(vals[12], vals[13]) + + print( + f"[ep {epoch} step {global_step}] " + f"train loss={tr_loss:.4f} align={tr_align:.4f} var={tr_var:.4f} " + f"top1={tr_top1:.3f} std_ctx={tr_std_ctx:.3f} std_tgt={tr_std_tgt:.3f} " + f"cos={tr_cos:.3f}" + ) + + wandb_log_if_needed( + cfg, + { + "train/loss": tr_loss, + "train/align": tr_align, + "train/var": tr_var, + "train/top1": tr_top1, + "train/std_ctx": tr_std_ctx, + "train/std_tgt": tr_std_tgt, + "train/cos": tr_cos, + "epoch": epoch, + "lr/encoder": optimizer.param_groups[0]["lr"], + "lr/predictor": optimizer.param_groups[1]["lr"], + }, + step=global_step, + ) + + if metrics_logger is not None: + metrics_logger.log( + { + "split": "train", + "step": global_step, + "epoch": epoch, + "loss": tr_loss, + "align": tr_align, + "var": tr_var, + "top1": tr_top1, + "std_ctx": tr_std_ctx, + "std_tgt": tr_std_tgt, + "cos": tr_cos, + "lr_encoder": optimizer.param_groups[0]["lr"], + "lr_predictor": optimizer.param_groups[1]["lr"], + "time": time.time(), + } + ) + + # after finishing this epoch, no more skipping + start_step_in_epoch = 0 + + # Optional: save last checkpoint each epoch + if save_ckpt and is_main() and save_every_epoch: + save_checkpoint( + os.path.join(ckpt_dir, f"ckpt_epoch{epoch}.pt"), + enc_ctx, enc_tgt, predictor, optimizer, scaler, + global_step, epoch, it, cfg + ) + # also update a "last" symlink-ish file + save_checkpoint( + last_path, + enc_ctx, enc_tgt, predictor, optimizer, scaler, + global_step, epoch, it, cfg + ) + if is_main(): + print(f"[ep {epoch}] Saved epoch ckpt + last ckpt at step={global_step}") + + if is_main(): + print("Training 
# -------------------------
# Config
# -------------------------
def load_yaml(path: str) -> Dict[str, Any]:
    """Parse a YAML file (UTF-8) and return the resulting object."""
    with open(path, "r", encoding="utf-8") as fh:
        return yaml.safe_load(fh)

def deep_update(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
    """Recursively merge *override* into *base* (in place) and return *base*.

    Nested dicts are merged key by key; any non-dict value — or a dict landing
    on a non-dict — simply replaces what was there.
    """
    for key, new_val in override.items():
        old_val = base.get(key)
        if isinstance(new_val, dict) and isinstance(old_val, dict):
            deep_update(old_val, new_val)
        else:
            base[key] = new_val
    return base
in s or "e" in s.lower(): + return float(s) + return int(s) + except ValueError: + return s + +def set_by_dotted_key(cfg: Dict[str, Any], dotted: str, value: Any) -> None: + keys = dotted.split(".") + cur = cfg + for k in keys[:-1]: + if k not in cur or not isinstance(cur[k], dict): + cur[k] = {} + cur = cur[k] + cur[keys[-1]] = value + +def apply_overrides(cfg: Dict[str, Any], overrides: List[str]) -> Dict[str, Any]: + for ov in overrides: + if "=" not in ov: + raise ValueError(f"Bad override '{ov}', expected key=value") + k, v = ov.split("=", 1) + set_by_dotted_key(cfg, k.strip(), _cast_value(v.strip())) + return cfg + +def save_resolved_config(cfg: Dict[str, Any], out_dir: str, filename: str = "resolved_config.json") -> None: + os.makedirs(out_dir, exist_ok=True) + with open(os.path.join(out_dir, filename), "w", encoding="utf-8") as f: + json.dump(cfg, f, indent=2, ensure_ascii=False) + + +# ------------------------- +# File / JSON helpers +# ------------------------- +def ensure_parent_dir(path: str) -> None: + parent = os.path.dirname(path) + if parent: + os.makedirs(parent, exist_ok=True) + + +def save_json(obj: Any, path: str, indent: int = 2) -> None: + ensure_parent_dir(path) + with open(path, "w", encoding="utf-8") as f: + json.dump(obj, f, ensure_ascii=False, indent=indent) + + +def load_json(path: str) -> Any: + with open(path, "r", encoding="utf-8") as f: + return json.load(f) + + +def save_pt(obj: Any, path: str) -> None: + ensure_parent_dir(path) + torch.save(obj, path) + + +def load_pt(path: str) -> Any: + return torch.load(path, map_location="cpu", weights_only=False) + + +def write_jsonl(records: List[Dict[str, Any]], path: str) -> None: + ensure_parent_dir(path) + with open(path, "w", encoding="utf-8") as f: + for record in records: + f.write(json.dumps(record, ensure_ascii=False) + "\n") + + +class JSONLLogger: + def __init__(self, path: str): + self.path = path + ensure_parent_dir(path) + + def log(self, record: Dict[str, Any]) -> None: + with 
open(self.path, "a", encoding="utf-8") as f: + f.write(json.dumps(record, ensure_ascii=False) + "\n") + + +# ------------------------- +# DDP helpers +# ------------------------- +def is_ddp_env() -> bool: + return "RANK" in os.environ and "WORLD_SIZE" in os.environ + + +def ddp_enabled() -> bool: + return is_ddp_env() + +def rank() -> int: + return int(os.environ.get("RANK", "0")) + + +def ddp_rank() -> int: + return rank() + +def local_rank() -> int: + return int(os.environ.get("LOCAL_RANK", "0")) + + +def ddp_local_rank() -> int: + return local_rank() + +def world_size() -> int: + return int(os.environ.get("WORLD_SIZE", "1")) + + +def ddp_world() -> int: + return world_size() + +def is_main() -> bool: + return rank() == 0 + +def ddp_setup(backend: str = "nccl") -> None: + if not ddp_enabled(): + return + if not dist.is_initialized(): + dist.init_process_group(backend=backend) + if torch.cuda.is_available(): + torch.cuda.set_device(local_rank()) + +def ddp_cleanup() -> None: + if dist.is_initialized(): + dist.destroy_process_group() + + +def ddp_barrier() -> None: + if dist.is_initialized(): + dist.barrier() + +def ddp_all_reduce_sum(t: torch.Tensor) -> torch.Tensor: + if dist.is_initialized(): + dist.all_reduce(t, op=dist.ReduceOp.SUM) + return t + + +# ------------------------- +# Reproducibility +# ------------------------- +def seed_everything(seed: int) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def setup_tf32(allow_tf32: bool) -> None: + if torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = bool(allow_tf32) + torch.backends.cudnn.allow_tf32 = bool(allow_tf32) + + +# ------------------------- +# EMA / model helpers +# ------------------------- +@torch.no_grad() +def ema_update(teacher: nn.Module, student: nn.Module, tau: float) -> None: + for p_t, p_s in zip(teacher.parameters(), student.parameters()): + p_t.data.mul_(tau).add_(p_s.data, alpha=1.0 - tau) + +def 
unwrap_ddp(m: nn.Module) -> nn.Module: + return m.module if hasattr(m, "module") else m + +def set_requires_grad(m: nn.Module, requires_grad: bool) -> None: + for p in m.parameters(): + p.requires_grad = requires_grad + + +def extract_named_trainable_parameters(module: nn.Module) -> Dict[str, torch.Tensor]: + base = unwrap_ddp(module) + return { + name: p.detach().to("cpu") + for name, p in base.named_parameters() + if p.requires_grad + } + + +def load_named_parameters(module: nn.Module, saved: Dict[str, torch.Tensor], strict: bool = False) -> List[str]: + base = unwrap_ddp(module) + current = dict(base.named_parameters()) + missing: List[str] = [] + for name, tensor in saved.items(): + if name in current: + current[name].data.copy_(tensor) + elif strict: + missing.append(name) + return missing + + +def extract_from_dict(data: Any, prefer_keys: List[str]) -> Any: + if isinstance(data, dict): + for key in prefer_keys: + if key in data: + return data[key] + for value in data.values(): + if isinstance(value, (list, np.ndarray)): + return value + raise KeyError(f"Cannot find array in dict keys={list(data.keys())[:50]}") + return data + + +# ------------------------- +# Simple meters +# ------------------------- +class AverageMeter: + def __init__(self): + self.reset() + + def reset(self): + self.sum = 0.0 + self.count = 0 + + def update(self, val: float, n: int = 1): + self.sum += float(val) * n + self.count += int(n) + + @property + def avg(self) -> float: + return self.sum / max(1, self.count) + + +# ------------------------- +# FLOPs / Profiler helpers +# ------------------------- +def move_batch_to_device( + batch, + device: torch.device, +): + """ + Move a batch of (b_emb, b_ids, b_mask, gidx/anything) to device. + Keeps the 4th item untouched. 
+ """ + b_emb, b_ids, b_mask, b_extra = batch + return ( + b_emb.to(device, non_blocking=True), + b_ids.to(device, non_blocking=True), + b_mask.to(device, non_blocking=True), + b_extra, + ) + + +def get_total_flops_from_prof(prof) -> float: + """ + Sum FLOPs over all profiler events. + Note: + PyTorch profiler FLOPs are formula-based estimates and do not cover every op. + Still useful for relative comparison across runs. + """ + total_flops = 0.0 + for evt in prof.key_averages(): + flops = getattr(evt, "flops", None) + if flops is not None: + total_flops += float(flops) + return total_flops + + +# ------------------------- +# Embedding helpers +# ------------------------- +def load_embeddings_pt(path: str, key: str = "z_pred", index_key: str = "global_indices") -> Tuple[torch.Tensor, torch.Tensor]: + obj = load_pt(path) + if not isinstance(obj, dict): + raise ValueError(f"emb_pt must be a dict, got {type(obj)}") + if key not in obj: + raise KeyError(f"Missing key '{key}' in {path}. keys={list(obj.keys())[:20]}") + if index_key not in obj: + raise KeyError(f"Missing key '{index_key}' in {path}. 
keys={list(obj.keys())[:20]}") + + emb = obj[key] + gidx = obj[index_key] + if not torch.is_tensor(emb) or not torch.is_tensor(gidx): + raise ValueError("emb and indices must be torch tensors.") + if emb.dim() != 2: + raise ValueError(f"emb must be [N,D], got {tuple(emb.shape)}") + if gidx.dim() != 1 or gidx.numel() != emb.shape[0]: + raise ValueError(f"indices must be [N], got {tuple(gidx.shape)} vs N={emb.shape[0]}") + return emb.contiguous(), gidx.long().contiguous() From d3690ea19eb362ba49834bbc495440092c1e2135 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E5=B9=BD=E9=9B=85?= Date: Thu, 9 Apr 2026 23:54:10 +0200 Subject: [PATCH 2/2] add notes --- JEPA/.DS_Store | Bin 10244 -> 6148 bytes JEPA/configs/base.yaml | 8 ++++---- JEPA/configs/exp2_lora_encoder_predictor.yaml | 16 ++++++---------- JEPA/src/.DS_Store | Bin 6148 -> 6148 bytes JEPA/src/tasks/.DS_Store | Bin 6148 -> 6148 bytes JEPA/src/tasks/decoder/.DS_Store | Bin 0 -> 6148 bytes JEPA/src/tasks/decoder/test_projector.py | 2 +- 7 files changed, 11 insertions(+), 15 deletions(-) create mode 100644 JEPA/src/tasks/decoder/.DS_Store diff --git a/JEPA/.DS_Store b/JEPA/.DS_Store index a7de149a2e13354c45bfd17a2f2a974126e571c9..6b3ae3c0f6e2e553ea44ab7355fbec3ce8319357 100644 GIT binary patch delta 174 zcmZn(XfcprU|?W$DortDU=RQ@Ie-{Mvv5r;6q~50D9R3!2a9Dgq%#z!6es5-xp z&MXa*VP{BY$Y;o7NMk^f&Ci*f=00IGS gAmIwKa%16l=E?jjo*)M>FhLvxav#Iyc%C`T0Cz1U{Qv*} literal 10244 zcmeHMTWl3Y7@lui=$z%Yg%-NV!Np>2p+Id*0YN-%DHlPDoYIz3IGoEuH*9y0XZM_f zgji8uyhWo?6N8HJ&4Bu1ln3IQ7!w~1A)+x}9(?h|geMbk|LpAHa*7W=plEiJoqy+_ z`Df?<=KFUt|1!qVQ_S7Xn86t1bPK7?q~ZpbXczaCO9?fi5oFJp#SSqqGvVbk_LSa{ zB0?ZSAVMHQAVMHQ;8s9@_H1$S8B#`Ngg}HqguoR9#P`9VTgXHxC#4KN9aIEI0Fu?D z4la78b3kBYgfbDzNht$U8dIJg5Sk)9VnCWxyCJ%hOoVb$N@>m@%^AWsBRruXcstn{ zL3f6Plu;QW5Fs!X0TDJeEX!Qxu!Gm0-`xzge3;4Gmgl#%U4&9qJ#Bgoui-QJzT}uY zk@VAk(d$VU_KMno>)M&}b2>9 zc4{<7(|W+oi2-v#i>@1Ya=usSdkZF&QmyJO%g7)-;E`wHELX0gDR={iX^o|++R@2l zT5h+U&X{(nTDy)^+lNXL-cWh(Mjoq0>u}Z`A18M_sn(4i&i9Kq??`)of8KPi%xUu$ 
z&b9L+-dMU|_D>MWLW|AYIS-bM^unkwXZdbH0%GX>obM6Cb?0(4lD42>{zJQToz}&2 zN)}O?)>|@L`+{@cm{=y4CjhlE(I#s#+N^Z8lD4L;Y?PTS&yKQZ*h%&}JIy{|=h%7n zIlIKZVqddw*bnR{cA5QxDpX@SYM?>KLfnQpmZ1ZkSdDJnkL}oj2a&`m_F+E`z=Q)A zc@%I2kKqI!$CG##FXI)wir4TqPU8&T#W{R}^Y|2>;Ud1pclaKc@hg7A?+RCHm8D9P za+lJqv?%S$8YLm8u9QP&MAplt(oIhl9|8&)O5}-l0UqB&k4L;QKr=h z4s$^^{dRRpT$@JRm(`_eQ(T))yqDFM`0}`>P`JtJ%6OZqRZ%d-0^5%Ga#aJDl{M-* zRja13lGP3BMpa`J=(5@^1|q8deXhRG&a#h*s~3o?KN3rSW`BTVHX5)PO=w07TCoCM zSc`R7j}7R<7Hkc;I)Fh8VGl--C8i$47%XC{k0MHVG~nwK#MkHWJYK+ycnK%*2HwOe zyoGo09zLpI;!s73I%hIzDQ6==K4Gryd0{vT{u*LhXpa~1hq zU|E@5|IhmxdQ;(}BtTR~2t)`(2t)`(2t)|nL02Xy^!m2!Vea z0aR~FZt9^~lyd`>6H-WF(0zbzafyDDQU)eeh%Q1Oq~oa~@`}x&F3g$)?TJt~DP>?z gc7Zubna#iHKLcbID?0y2=l>vieXAs_|Hp>I9m51| delta 12 TcmZoMXfc?ujFDmE@>p>I9liu? diff --git a/JEPA/src/tasks/decoder/.DS_Store b/JEPA/src/tasks/decoder/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6 GIT binary patch literal 6148 zcmeH~Jr2S!425mzP>H1@V-^m;4Wg<&0T*E43hX&L&p$$qDprKhvt+--jT7}7np#A3 zem<@ulZcFPQ@L2!n>{z**++&mCkOWA81W14cNZlEfg7;MkzE(HCqgga^y>{tEnwC%0;vJ&^%eQ zLs35+`xjp>T0