diff --git a/configs/gpt2.yaml b/configs/gpt2.yaml index 1332694..0e28efb 100644 --- a/configs/gpt2.yaml +++ b/configs/gpt2.yaml @@ -1,12 +1,12 @@ global: tokenizer_name: 'r50k_base' - batch_size: 2 + batch_size: 24 optimizer: 'AdamW' - learning_rate: 5.0e-5 + learning_rate: 1.0e-4 weight_decay: 0.01 num_epochs: 1 save_dir: './checkpoints' - log_dir: './logs/gpt2_medium' + log_dir: './logs/gpt2_small' devices: ['cuda:0', 'cuda:1'] eval_frequency: 1000 # Evaluate every 1000 iterations num_val_batches: 50 # Use 50 validation batches for each evaluation diff --git a/docs/gpt2_logs.md b/docs/gpt2_logs.md index 80653cc..ef1c54c 100644 --- a/docs/gpt2_logs.md +++ b/docs/gpt2_logs.md @@ -1,11 +1,18 @@ # GPT-2 Medium Training Log ## Configuration -- Model: GPT-2 Medium - Training Duration: 24 hours - Batch Size: 4 - Learning Rate: 1e-4 - Gradient Clipping: 1.0 +models: + gpt2_medium: + context_length: 1024 + emb_dim: 1024 + n_heads: 16 + n_layers: 24 + drop_rate: 0.0 + qkv_bias: false ## Training Progress @@ -27,11 +34,50 @@ Fruits are rich in vitamins, fiber, and vitamins and are also rich in vitamins a Fruits are rich in vitamins and fiber as ``` -## Challenges and Solutions +# GPT-2 Large Training Log + +## Configuration +- Training Duration: 24 hours +- Batch Size: 2 +- Learning Rate: 5.0e-5 +- Gradient Clipping: 1.0 +models: + gpt2_large: + context_length: 1024 # Maximum sequence length + emb_dim: 1280 # Embedding dimension + n_heads: 20 # Number of attention heads + n_layers: 36 # Number of transformer layers + drop_rate: 0.0 # Dropout rate + qkv_bias: false # Whether to use bias in query, key, and value projections + +## Training Progress + +### Training Loss +![Training Loss](./images/gpt2_large_train_loss.png) + +### Validation Loss +![Validation Loss](./images/gpt2_large_val_loss.png) + +## Evaluation Results +Final validation loss: [Insert final validation loss here] + +## Sample Output +``` +Fruits are good for you because ____. 
+Fruits are a good way to enjoy your own health benefits. +Fruits are good for your health and will help you to improve your health. +Fruits and vegetables are good for your health. +1. Fruits are good for your health and can help you to maintain healthy blood sugar levels. +2. Fruits are good for children and adults all over the world. +3. Fruits are good for your health if they're good for your health +``` +Looks like this model generated a multiple choice question! + +# Challenges and Solutions - Issue: NaN loss Solution: Implemented gradient clipping with a max norm of 1.0 - Issue: Special token '<|endoftext|>' in training data Solution: Added `allowed_special` in the tokenization to fix it: `tokens = self.tokenizer.encode(text, allowed_special={"<|endoftext|>"})` - Issue: A learning rate of `3e-4` seemed fine, but I later found that a learning rate of `1e-4` led to faster convergence. I still have not tried even lower learning rates or learning rate schedules. - Issue: Mixed Precision did not reduce GPU memory usage - - I have not found a way to reduce GPU memory usage using mixed precision. My setup is 2 x RTX 4090 GPUs with 24GB RAM each. I have not yet tried gradient accumulation. \ No newline at end of file + - I have not found a way to reduce GPU memory usage using mixed precision. My setup is 2 x RTX 4090 GPUs with 24GB RAM each. I have not yet tried gradient accumulation. 
diff --git a/docs/images/gpt2_large_train_loss.png b/docs/images/gpt2_large_train_loss.png index 86118f8..af1031a 100644 Binary files a/docs/images/gpt2_large_train_loss.png and b/docs/images/gpt2_large_train_loss.png differ diff --git a/docs/images/gpt2_large_val_loss.png b/docs/images/gpt2_large_val_loss.png index 2fcbc52..0e9d61f 100644 Binary files a/docs/images/gpt2_large_val_loss.png and b/docs/images/gpt2_large_val_loss.png differ diff --git a/src/llm_forge/train.py b/src/llm_forge/train.py index 6bfa786..8dd4766 100644 --- a/src/llm_forge/train.py +++ b/src/llm_forge/train.py @@ -12,9 +12,8 @@ from llm_forge.utils.config import load_config -def train(model, train_loader, val_loader, loss_function, optimizer, num_epochs, - save_dir, log_dir, devices, eval_frequency, num_val_batches, - use_mixed_precision, model_name): +def train(model, train_loader, loss_function, optimizer, num_epochs, save_dir, + log_dir, devices, eval_frequency, use_mixed_precision, model_name): # Move the model to the primary device model = model.to(devices[0]) @@ -25,20 +24,22 @@ def train(model, train_loader, val_loader, loss_function, optimizer, num_epochs, writer = SummaryWriter(log_dir) os.makedirs(save_dir, exist_ok=True) - best_val_loss = float('inf') global_step = 0 - scaler = torch.amp.GradScaler('cuda', enabled=use_mixed_precision) + best_loss = float('inf') + last_checkpoint_path = None for epoch in range(num_epochs): model.train() progress_bar = tqdm(total=len(train_loader), desc=f"Epoch {epoch+1}/{num_epochs}") - for batch in train_loader: - # Move the batch to the primary device - for key in batch: - batch[key] = batch[key].to(devices[0]) + total_loss = 0 + steps_since_last_log = 0 + + for batch in train_loader: # Iterate over the DataLoader + # Move each tensor in the batch to the primary device + batch = {k: v.to(devices[0]) for k, v in batch.items()} optimizer.zero_grad(set_to_none=True) @@ -49,48 +50,67 @@ def train(model, train_loader, val_loader, loss_function, 
optimizer, num_epochs, continue if use_mixed_precision: - scaler.scale( - loss).backward() # Scale the loss for mixed precision - # apply gradient clipping + scaler.scale(loss).backward() scaler.unscale_(optimizer) - torch.nn.utils.clip_grad_norm_( - model.parameters(), - max_norm=1.0) # You can adjust max_norm as needed - scaler.step(optimizer) # Step the optimizer - scaler.update() # Update the scaler + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + scaler.step(optimizer) + scaler.update() else: - loss.backward() # Standard backward pass + loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) - optimizer.step() # Step the optimizer - - # torch.cuda.empty_cache() + optimizer.step() + total_loss += loss.item() + steps_since_last_log += 1 global_step += 1 progress_bar.update(1) progress_bar.set_postfix({'train_loss': f"{loss.item():.4f}"}) if global_step % eval_frequency == 0: - # Use the primary device for validation - avg_val_loss = validate(model, - val_loader, - loss_function, - num_val_batches, - device=devices[0]) - best_val_loss = log_and_save(writer, - model, optimizer, global_step, - loss.item(), avg_val_loss, - save_dir, best_val_loss, - model_name) + avg_loss = total_loss / steps_since_last_log + writer.add_scalar('Loss/train', avg_loss, global_step) + print( + f"\nStep {global_step}, Average Train Loss: {avg_loss:.4f}") + + # Save the last model + last_checkpoint_path = save_model(model, + optimizer, + global_step, + avg_loss, + save_dir, + model_name, + is_best=False) + + # Save the best model + if avg_loss < best_loss: + best_loss = avg_loss + save_model(model, + optimizer, + global_step, + avg_loss, + save_dir, + model_name, + is_best=True) + + # Reset loss tracking + total_loss = 0 + steps_since_last_log = 0 progress_bar.close() + # After training is complete, rename the last checkpoint to 'last' + if last_checkpoint_path: + last_model_path = os.path.join(save_dir, f'model_{model_name}_last.pth') + 
os.rename(last_checkpoint_path, last_model_path) + print(f"Renamed last checkpoint to: {last_model_path}") + writer.close() def process_batch(model, batch, loss_function, device): - input_ids = batch['input_ids'].to(device) - labels = batch['labels'].to(device) - attention_mask = batch['attention_mask'].to(device) + input_ids = batch['input_ids'] + attention_mask = batch['attention_mask'] + labels = batch['labels'] logits = model(input_ids, attention_mask=attention_mask) @@ -115,60 +135,40 @@ def process_batch(model, batch, loss_function, device): return loss -def validate(model, val_loader, loss_function, num_val_batches, device): - model.eval() - total_val_loss = 0 - val_steps = 0 - val_iter = iter(val_loader) - - with torch.no_grad(): - for _ in range(num_val_batches): - try: - val_batch = next(val_iter) - except StopIteration: - val_iter = iter(val_loader) - val_batch = next(val_iter) - - val_loss = process_batch(model, val_batch, loss_function, device) - if val_loss is not None: - total_val_loss += val_loss.item() - val_steps += 1 - - avg_val_loss = total_val_loss / val_steps if val_steps > 0 else float('inf') - return avg_val_loss - +def save_model(model, optimizer, global_step, avg_loss, save_dir, model_name, + is_best): + # Get the original model if it's wrapped in DataParallel + model_to_save = model.module if isinstance(model, + nn.DataParallel) else model + + # Prepare the checkpoint + checkpoint = { + 'model_state_dict': model_to_save.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'global_step': global_step, + 'avg_loss': avg_loss, + } + + # Save the best model + if is_best: + best_checkpoint_name = f'model_{model_name}_best.pth' + torch.save(checkpoint, os.path.join(save_dir, best_checkpoint_name)) + print( + f"Best model saved as {best_checkpoint_name} with average loss: {avg_loss:.4f}" + ) + return None -def log_and_save(writer, model, optimizer, global_step, train_loss, val_loss, - save_dir, best_val_loss, model_name): - 
writer.add_scalar('Loss/train', train_loss, global_step) - writer.add_scalar('Loss/val', val_loss, global_step) +def load_checkpoint(model, optimizer, checkpoint_path): + checkpoint = torch.load(checkpoint_path) + model.load_state_dict(checkpoint['model_state_dict']) + optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + global_step = checkpoint['global_step'] + avg_loss = checkpoint['avg_loss'] print( - f"\nStep {global_step}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}" + f"Loaded checkpoint from step {global_step} with average loss: {avg_loss:.4f}" ) - - if val_loss < best_val_loss: - best_val_loss = val_loss - - # Get the original model if it's wrapped in DataParallel - model_to_save = model.module if isinstance(model, - nn.DataParallel) else model - - # Save the model state dict with the new naming convention - checkpoint_name = f'model_{model_name}.pth' - torch.save( - { - 'model_state_dict': model_to_save.state_dict(), - 'optimizer_state_dict': optimizer.state_dict(), - 'global_step': global_step, - 'best_val_loss': best_val_loss, - }, os.path.join(save_dir, checkpoint_name)) - - print( - f"New best model saved as {checkpoint_name} with validation loss: {best_val_loss:.4f}" - ) - - return best_val_loss + return global_step, avg_loss def main(): @@ -183,6 +183,8 @@ def main(): parser.add_argument("--ds_path", default="datasets/smol_lm_corpus/fineweb_edu", help="Path to the dataset") + parser.add_argument("--checkpoint", + help="Path to a checkpoint to resume training from") args = parser.parse_args() global_config, model_config = load_config(args.config_file, args.model_name) @@ -193,7 +195,7 @@ def main(): model = ModelFactory.create_model('gpt', model_config) - train_loader, val_loader = create_dataloaders( + train_loader, _ = create_dataloaders( ds_path=args.ds_path, batch_size=global_config['batch_size'], max_length=model_config['context_length'], @@ -220,20 +222,16 @@ def main(): print( f"INFO: using 
{global_config['eval_frequency']} as evaluation frequency" ) - print( - f"INFO: using {global_config['num_val_batches']} as number of validation batches" - ) print( f"INFO: using {global_config['mixed_precision']} as mixed precision training" ) print(f"INFO: model name: {args.model_name}") print(f"INFO: model config: {model_config}") - train(model, train_loader, val_loader, loss_function, optimizer, + train(model, train_loader, loss_function, optimizer, global_config['num_epochs'], global_config['save_dir'], global_config['log_dir'], devices, global_config['eval_frequency'], - global_config['num_val_batches'], use_mixed_precision, - args.model_name) + use_mixed_precision, args.model_name) if __name__ == "__main__": diff --git a/tests/test_train.py b/tests/test_train.py index 50fae02..ee39a72 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -8,14 +8,12 @@ class TestTrain(unittest.TestCase): @patch('llm_forge.train.tqdm') @patch('llm_forge.train.SummaryWriter') - @patch('llm_forge.data.dataset.create_dataloaders') @patch('llm_forge.train.process_batch') - def test_train_function(self, mock_process_batch, mock_create_dataloaders, - mock_summary_writer, mock_tqdm): + def test_train_function(self, mock_process_batch, mock_summary_writer, + mock_tqdm): # Create mock objects mock_model = MagicMock() mock_train_loader = MagicMock() - mock_val_loader = MagicMock() mock_loss_function = MagicMock() mock_optimizer = MagicMock() @@ -27,13 +25,6 @@ def test_train_function(self, mock_process_batch, mock_create_dataloaders, 'labels': torch.randint(0, 100, (2, 10)), 'attention_mask': torch.ones(2, 10) }]) - mock_val_loader.__iter__.return_value = iter([{ - 'input_ids': torch.randint(0, 100, (2, 10)), - 'labels': torch.randint(0, 100, (2, 10)), - 'attention_mask': torch.ones(2, 10) - }]) - mock_create_dataloaders.return_value = (mock_train_loader, - mock_val_loader) # Mock process_batch to return a tensor that requires gradients mock_process_batch.return_value = 
torch.tensor(0.5, requires_grad=True) @@ -43,27 +34,19 @@ def test_train_function(self, mock_process_batch, mock_create_dataloaders, mock_tqdm_instance.__iter__.return_value = iter(mock_train_loader) mock_tqdm.return_value = mock_tqdm_instance - # Call the train function with the new model_name parameter - train(mock_model, - mock_train_loader, - mock_val_loader, - mock_loss_function, - mock_optimizer, + # Call the train function with the updated parameters + train(model=mock_model, + train_loader=mock_train_loader, + loss_function=mock_loss_function, + optimizer=mock_optimizer, num_epochs=1, save_dir='dummy_save_dir', log_dir='dummy_log_dir', devices=['cpu'], eval_frequency=5, - num_val_batches=1, use_mixed_precision=False, model_name='gpt2_medium') - # Debugging call counts - print(f"train() called: {mock_model.train.call_count}") - print(f"zero_grad() called: {mock_optimizer.zero_grad.call_count}") - print(f"step() called: {mock_optimizer.step.call_count}") - print(f"process_batch() called: {mock_process_batch.call_count}") - # Assert that certain methods were called self.assertTrue(mock_model.train.called, "model.train() was not called") self.assertTrue(mock_optimizer.zero_grad.called,