Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions configs/gpt2.yaml
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
global:
tokenizer_name: 'r50k_base'
batch_size: 2
batch_size: 24
optimizer: 'AdamW'
learning_rate: 5.0e-5
learning_rate: 1.0e-4
weight_decay: 0.01
num_epochs: 1
save_dir: './checkpoints'
log_dir: './logs/gpt2_medium'
log_dir: './logs/gpt2_small'
devices: ['cuda:0', 'cuda:1']
eval_frequency: 1000 # Evaluate every 1000 iterations
num_val_batches: 50 # Use 50 validation batches for each evaluation
Expand Down
52 changes: 49 additions & 3 deletions docs/gpt2_logs.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
# GPT-2 Medium Training Log

## Configuration
- Model: GPT-2 Medium
- Training Duration: 24 hours
- Batch Size: 4
- Learning Rate: 1e-4
- Gradient Clipping: 1.0
models:
gpt2_medium:
context_length: 1024
emb_dim: 1024
n_heads: 16
n_layers: 24
drop_rate: 0.0
qkv_bias: false

## Training Progress

Expand All @@ -27,11 +34,50 @@ Fruits are rich in vitamins, fiber, and vitamins and are also rich in vitamins a
Fruits are rich in vitamins and fiber as
```

## Challenges and Solutions
# GPT-2 Large Training Log

## Configuration
- Training Duration: 24 hours
- Batch Size: 2
- Learning Rate: 5.0e-5
- Gradient Clipping: 1.0
models:
gpt2_large:
context_length: 1024 # Maximum sequence length
emb_dim: 1280 # Embedding dimension
n_heads: 20 # Number of attention heads
n_layers: 36 # Number of transformer layers
drop_rate: 0.0 # Dropout rate
qkv_bias: false # Whether to use bias in query, key, and value projections

## Training Progress

### Training Loss
![Training Loss](./images/gpt2_large_train_loss.png)

### Validation Loss
![Validation Loss](./images/gpt2_large_val_loss.png)

## Evaluation Results
Final validation loss: [Insert final validation loss here]

## Sample Output
```
Fruits are good for you because ____.
Fruits are a good way to enjoy your own health benefits.
Fruits are good for your health and will help you to improve your health.
Fruits and vegetables are good for your health.
1. Fruits are good for your health and can help you to maintain healthy blood sugar levels.
2. Fruits are good for children and adults all over the world.
3. Fruits are good for your health if they��re good for your health
```
Looks like this model generated a multiple choice question!

# Challenges and Solutions
- Issue: NaN loss
Solution: Implemented gradient clipping with a max norm of 1.0
- Issue: Special token '<|endoftext|>' in training data
Solution: Added `allowed_special` in the tokenization to fix it: `tokens = self.tokenizer.encode(text, allowed_special={"<|endoftext|>"})`
- Issue: Learning rate of `3e-4` seemed fine, but later found out that the learning rate should be `1e-4` led to faster convergence. I still have not tried even lower learning rates or learning rate schedules.
- Issue: Mixed Precision did not reduce GPU memory usage
- I have not found a way to reduce GPU memory usage using mixed precision. My setup is 2 x RTX 4090 GPUs with 24GB RAM each. I have not yet tried gradient accumulation.
- I have not found a way to reduce GPU memory usage using mixed precision. My setup is 2 x RTX 4090 GPUs with 24GB RAM each. I have not yet tried gradient accumulation.
Binary file modified docs/images/gpt2_large_train_loss.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/images/gpt2_large_val_loss.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
182 changes: 90 additions & 92 deletions src/llm_forge/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@
from llm_forge.utils.config import load_config


def train(model, train_loader, val_loader, loss_function, optimizer, num_epochs,
save_dir, log_dir, devices, eval_frequency, num_val_batches,
use_mixed_precision, model_name):
def train(model, train_loader, loss_function, optimizer, num_epochs, save_dir,
log_dir, devices, eval_frequency, use_mixed_precision, model_name):
# Move the model to the primary device
model = model.to(devices[0])

Expand All @@ -25,20 +24,22 @@ def train(model, train_loader, val_loader, loss_function, optimizer, num_epochs,
writer = SummaryWriter(log_dir)
os.makedirs(save_dir, exist_ok=True)

best_val_loss = float('inf')
global_step = 0

scaler = torch.amp.GradScaler('cuda', enabled=use_mixed_precision)
best_loss = float('inf')
last_checkpoint_path = None

for epoch in range(num_epochs):
model.train()
progress_bar = tqdm(total=len(train_loader),
desc=f"Epoch {epoch+1}/{num_epochs}")

for batch in train_loader:
# Move the batch to the primary device
for key in batch:
batch[key] = batch[key].to(devices[0])
total_loss = 0
steps_since_last_log = 0

for batch in train_loader: # Iterate over the DataLoader
# Move each tensor in the batch to the primary device
batch = {k: v.to(devices[0]) for k, v in batch.items()}

optimizer.zero_grad(set_to_none=True)

Expand All @@ -49,48 +50,67 @@ def train(model, train_loader, val_loader, loss_function, optimizer, num_epochs,
continue

if use_mixed_precision:
scaler.scale(
loss).backward() # Scale the loss for mixed precision
# apply gradient clipping
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(
model.parameters(),
max_norm=1.0) # You can adjust max_norm as needed
scaler.step(optimizer) # Step the optimizer
scaler.update() # Update the scaler
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)
scaler.update()
else:
loss.backward() # Standard backward pass
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step() # Step the optimizer

# torch.cuda.empty_cache()
optimizer.step()

total_loss += loss.item()
steps_since_last_log += 1
global_step += 1
progress_bar.update(1)
progress_bar.set_postfix({'train_loss': f"{loss.item():.4f}"})

if global_step % eval_frequency == 0:
# Use the primary device for validation
avg_val_loss = validate(model,
val_loader,
loss_function,
num_val_batches,
device=devices[0])
best_val_loss = log_and_save(writer,
model, optimizer, global_step,
loss.item(), avg_val_loss,
save_dir, best_val_loss,
model_name)
avg_loss = total_loss / steps_since_last_log
writer.add_scalar('Loss/train', avg_loss, global_step)
print(
f"\nStep {global_step}, Average Train Loss: {avg_loss:.4f}")

# Save the last model
last_checkpoint_path = save_model(model,
optimizer,
global_step,
avg_loss,
save_dir,
model_name,
is_best=False)

# Save the best model
if avg_loss < best_loss:
best_loss = avg_loss
save_model(model,
optimizer,
global_step,
avg_loss,
save_dir,
model_name,
is_best=True)

# Reset loss tracking
total_loss = 0
steps_since_last_log = 0

progress_bar.close()

# After training is complete, rename the last checkpoint to 'last'
if last_checkpoint_path:
last_model_path = os.path.join(save_dir, f'model_{model_name}_last.pth')
os.rename(last_checkpoint_path, last_model_path)
print(f"Renamed last checkpoint to: {last_model_path}")

writer.close()


def process_batch(model, batch, loss_function, device):
input_ids = batch['input_ids'].to(device)
labels = batch['labels'].to(device)
attention_mask = batch['attention_mask'].to(device)
input_ids = batch['input_ids']
attention_mask = batch['attention_mask']
labels = batch['labels']

logits = model(input_ids, attention_mask=attention_mask)

Expand All @@ -115,60 +135,40 @@ def process_batch(model, batch, loss_function, device):
return loss


def validate(model, val_loader, loss_function, num_val_batches, device):
model.eval()
total_val_loss = 0
val_steps = 0
val_iter = iter(val_loader)

with torch.no_grad():
for _ in range(num_val_batches):
try:
val_batch = next(val_iter)
except StopIteration:
val_iter = iter(val_loader)
val_batch = next(val_iter)

val_loss = process_batch(model, val_batch, loss_function, device)
if val_loss is not None:
total_val_loss += val_loss.item()
val_steps += 1

avg_val_loss = total_val_loss / val_steps if val_steps > 0 else float('inf')
return avg_val_loss

def save_model(model, optimizer, global_step, avg_loss, save_dir, model_name,
is_best):
# Get the original model if it's wrapped in DataParallel
model_to_save = model.module if isinstance(model,
nn.DataParallel) else model

# Prepare the checkpoint
checkpoint = {
'model_state_dict': model_to_save.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'global_step': global_step,
'avg_loss': avg_loss,
}

# Save the best model
if is_best:
best_checkpoint_name = f'model_{model_name}_best.pth'
torch.save(checkpoint, os.path.join(save_dir, best_checkpoint_name))
print(
f"Best model saved as {best_checkpoint_name} with average loss: {avg_loss:.4f}"
)
return None

def log_and_save(writer, model, optimizer, global_step, train_loss, val_loss,
save_dir, best_val_loss, model_name):
writer.add_scalar('Loss/train', train_loss, global_step)
writer.add_scalar('Loss/val', val_loss, global_step)

def load_checkpoint(model, optimizer, checkpoint_path):
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
global_step = checkpoint['global_step']
avg_loss = checkpoint['avg_loss']
print(
f"\nStep {global_step}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}"
f"Loaded checkpoint from step {global_step} with average loss: {avg_loss:.4f}"
)

if val_loss < best_val_loss:
best_val_loss = val_loss

# Get the original model if it's wrapped in DataParallel
model_to_save = model.module if isinstance(model,
nn.DataParallel) else model

# Save the model state dict with the new naming convention
checkpoint_name = f'model_{model_name}.pth'
torch.save(
{
'model_state_dict': model_to_save.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'global_step': global_step,
'best_val_loss': best_val_loss,
}, os.path.join(save_dir, checkpoint_name))

print(
f"New best model saved as {checkpoint_name} with validation loss: {best_val_loss:.4f}"
)

return best_val_loss
return global_step, avg_loss


def main():
Expand All @@ -183,6 +183,8 @@ def main():
parser.add_argument("--ds_path",
default="datasets/smol_lm_corpus/fineweb_edu",
help="Path to the dataset")
parser.add_argument("--checkpoint",
help="Path to a checkpoint to resume training from")
args = parser.parse_args()

global_config, model_config = load_config(args.config_file, args.model_name)
Expand All @@ -193,7 +195,7 @@ def main():

model = ModelFactory.create_model('gpt', model_config)

train_loader, val_loader = create_dataloaders(
train_loader, _ = create_dataloaders(
ds_path=args.ds_path,
batch_size=global_config['batch_size'],
max_length=model_config['context_length'],
Expand All @@ -220,20 +222,16 @@ def main():
print(
f"INFO: using {global_config['eval_frequency']} as evaluation frequency"
)
print(
f"INFO: using {global_config['num_val_batches']} as number of validation batches"
)
print(
f"INFO: using {global_config['mixed_precision']} as mixed precision training"
)
print(f"INFO: model name: {args.model_name}")
print(f"INFO: model config: {model_config}")

train(model, train_loader, val_loader, loss_function, optimizer,
train(model, train_loader, loss_function, optimizer,
global_config['num_epochs'], global_config['save_dir'],
global_config['log_dir'], devices, global_config['eval_frequency'],
global_config['num_val_batches'], use_mixed_precision,
args.model_name)
use_mixed_precision, args.model_name)


if __name__ == "__main__":
Expand Down
Loading