10 Lightning
This is an EARLY DRAFT.
In the previous chapter, we explored the theoretical foundations of advanced neural network architectures. Now, we turn our attention to the practical aspects of implementing and training these models efficiently. This chapter introduces PyTorch Lightning, a high-level framework that streamlines neural network implementation while maintaining PyTorch’s flexibility.
Lightning addresses a common challenge in deep learning research and applications: the separation of scientific code (model architecture, loss functions) from engineering code (training loops, device management, logging). By providing a structured organization for neural network development, Lightning enables more reproducible experiments, cleaner code, and faster iteration cycles—all critical factors for successful machine learning projects in economics and beyond.
Throughout this chapter, we will work with our familiar NYC taxi fare prediction task, progressively enhancing our implementation while exploring Lightning’s key features: logging, visualization with TensorBoard, model checkpointing, early stopping, and hyperparameter search. These tools represent the modern practice of deep learning, enabling more efficient model development, improved performance, and increased reproducibility.
10.1 Why Use Lightning?
Traditional PyTorch code, while flexible, often contains repetitive boilerplate that can obscure the core model logic. Consider a typical training loop in vanilla PyTorch:
# Vanilla PyTorch training loop
model = MyModel()
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

for epoch in range(num_epochs):
    model.train()
    for batch_idx, (x, y) in enumerate(train_loader):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        output = model(x)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()

    model.eval()
    val_loss = 0
    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(device), y.to(device)
            output = model(x)
            val_loss += criterion(output, y).item()
    val_loss /= len(val_loader)

    print(f"Epoch {epoch}: Train Loss: {loss.item():.4f}, Val Loss: {val_loss:.4f}")
This code has several issues:
- Repetitive Patterns: Device management, gradient zeroing, and evaluation code are duplicated across projects.
- Mixed Concerns: Scientific logic (model, loss) is intertwined with engineering details (loops, device management).
- Limited Extensibility: Adding features like learning rate scheduling or early stopping requires significant code changes.
- Poor Reproducibility: Subtle differences in implementation can lead to different results across experiments.
PyTorch Lightning addresses these issues by providing a structured organization for neural network code while preserving PyTorch's flexibility. The core abstraction in Lightning is the LightningModule class, which encapsulates:
- The model architecture (the __init__ method)
- The forward pass (the forward method)
- Training, validation, and test logic (the training_step, validation_step, and test_step methods)
- Optimization configuration (the configure_optimizers method)
This organization makes the code more readable and maintainable while reducing potential sources of error.
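To make the structure concrete before we return to it with real data, here is a minimal sketch of a LightningModule for a generic regression task (the module name, feature dimension, and metric name below are illustrative, not part of the taxi example):

import lightning as L
from torch import nn, optim

class SketchModule(L.LightningModule):
    def __init__(self, lr=1e-3):
        super().__init__()
        self.lr = lr
        self.net = nn.Linear(10, 1)  # model architecture lives in __init__

    def forward(self, x):
        return self.net(x)  # the forward pass

    def training_step(self, batch, batch_idx):
        x, y = batch  # assumes y has shape (batch_size, 1)
        loss = nn.functional.mse_loss(self(x), y)  # scientific logic only
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.lr)  # optimization setup

Notice what is absent: device placement, gradient zeroing, and the epoch loop. Lightning's Trainer supplies all of that engineering code.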
Lightning offers several advantages for economics research and applications:
Code Organization: Lightning enforces a clean separation between research code (model architecture, loss functions) and engineering code (training loops, GPU handling).
Reduced Boilerplate: Common operations like moving tensors to the correct device, gradient calculation, and parameter updates are handled automatically.
Built-in Features: Lightning provides out-of-the-box support for logging, checkpointing, early stopping, and other training utilities.
Scalability: The same code can easily scale from a single CPU to multiple GPUs or even multiple machines with minimal changes.
Reproducibility: Lightning makes it easier to ensure consistent results by standardizing the training process.
Readability and Communication: The structured format makes it easier to share code and results with colleagues and reviewers—a crucial consideration for academic research.
For economics applications in particular, reproducibility and transparency are paramount. By separating the scientific components (model definition, hyperparameters) from the engineering details, Lightning makes it easier to communicate and replicate research findings.
10.2 Loading and Preparing Data
Before diving into Lightning-specific code, let’s load and prepare our NYC taxi dataset. As in previous chapters, we’ll download the data, perform initial cleaning operations, and create appropriate data loaders.
from pathlib import Path
import requests
import pandas as pd
import numpy as np
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.model_selection import train_test_split
import torch
from torch import nn, optim, utils
import lightning as L

# Set random seed for reproducibility
L.pytorch.seed_everything(110, workers=True)

# Download data if needed
local_path = Path('data/fhvhv_tripdata_2024-01.parquet')
url = 'https://d37ci6vzurychx.cloudfront.net/trip-data/fhvhv_tripdata_2024-01.parquet'

if not local_path.exists():
    local_path.parent.mkdir(exist_ok=True)
    local_path.write_bytes(requests.get(url).content)

# Load and clean data
df = pd.read_parquet('data/fhvhv_tripdata_2024-01.parquet',
                     columns=['hvfhs_license_num', 'request_datetime',
                              'trip_miles', 'trip_time', 'base_passenger_fare',
                              'driver_pay', 'PULocationID', 'DOLocationID']).sample(1_000_000)

# Clean the data by filtering outliers
df = df[(df['trip_miles'] >= 1)
        & (df['trip_miles'] <= 20)
        & (df['base_passenger_fare'] < 200)]

# Feature engineering
df['request_day_of_week'] = df['request_datetime'].dt.dayofweek
df['request_hour_of_day'] = df['request_datetime'].dt.hour
df['fare_per_mile'] = df['base_passenger_fare'] / df['trip_miles']
Note that we’re sampling 1 million records from the dataset to make training more manageable. While this is still a substantial amount of data, it allows for faster experimentation without overly compromising model quality.
Next, we define our feature sets and split the data:
categorical_features = ['hvfhs_license_num', 'request_day_of_week', 'request_hour_of_day']
numerical_features = ['trip_miles', 'trip_time']

X = df[categorical_features + numerical_features + ['PULocationID', 'DOLocationID']]
y = df['fare_per_mile']

# Create training, validation, and test sets
X_train_val, X_test, y_train_val, y_test = train_test_split(X, y,
                                                             test_size=0.1, random_state=100)
X_train, X_val, y_train, y_val = train_test_split(X_train_val, y_train_val,
                                                  test_size=0.1, random_state=100)

# Free memory
del X_train_val
del y_train_val
This three-way split is standard practice in deep learning:
- Training set (81% of data): Used to update model parameters
- Validation set (9% of data): Used to tune hyperparameters and monitor for overfitting
- Test set (10% of data): Used for final evaluation only
The validation set helps us monitor the model’s generalization ability during training and make informed decisions about hyperparameters, while the test set provides an unbiased final evaluation.
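As a quick sanity check, you can print the sizes of the three splits; given the code above, the proportions should come out at roughly 81%, 9%, and 10% of the cleaned sample:

# Expect roughly 81% / 9% / 10% of the cleaned sample
print(len(X_train), len(X_val), len(X_test))
print(round(len(X_train) / len(df), 2),
      round(len(X_val) / len(df), 2),
      round(len(X_test) / len(df), 2))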
For preprocessing our features, we'll use scikit-learn's ColumnTransformer:

ct = ColumnTransformer([
    ('num', StandardScaler(), numerical_features),
    ('cat', OneHotEncoder(sparse_output=False, handle_unknown='ignore'), categorical_features)
])

ct.fit(X_train)

# Get min and max values for location IDs to determine embedding sizes
PUmax, PUmin = df['PULocationID'].max(), df['PULocationID'].min()
PUvals = PUmax - PUmin + 1
DOmax, DOmin = df['DOLocationID'].max(), df['DOLocationID'].min()
DOvals = DOmax - DOmin + 1
Now, let's create a function to process our data and turn it into PyTorch's TensorDataset:

def mk_dataset(X, y):
    # Transform features with ColumnTransformer
    X_trans = ct.transform(X).astype(np.float32)
    return utils.data.TensorDataset(
        torch.from_numpy(X_trans),
        torch.from_numpy(X['PULocationID'].values - PUmin),
        torch.from_numpy(X['DOLocationID'].values - DOmin),
        torch.from_numpy(y.values.astype(np.float32))
    )

# Create dataloaders
train_dl = utils.data.DataLoader(mk_dataset(X_train, y_train),
                                 shuffle=True, batch_size=1024, num_workers=4)
val_dl = utils.data.DataLoader(mk_dataset(X_val, y_val),
                               batch_size=1024, num_workers=4)
test_dl = utils.data.DataLoader(mk_dataset(X_test, y_test),
                                batch_size=1024, num_workers=4)
Note these important DataLoader parameters:
- shuffle=True for the training set ensures that each epoch sees a different order of samples
- batch_size=1024 defines how many samples are processed in each iteration
- num_workers=4 enables parallel data loading, which can significantly speed up training
With our data prepared, we’re ready to implement models using Lightning.
10.3 Creating a Basic Lightning Module
Let's create a base LightningModule class that implements the common functionality needed by all our models:

class BasicModel(L.LightningModule):
    def __init__(self, lr=1e-4):
        super().__init__()
        self.lr = lr

    def training_step(self, batch, batch_idx):
        y, y_hat = self.common_step(batch, batch_idx)
        loss = nn.functional.mse_loss(y, y_hat)
        self.log("training_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        y, y_hat = self.common_step(batch, batch_idx)
        loss = nn.functional.mse_loss(y, y_hat)
        self.log("val_loss", loss)
        return loss

    def test_step(self, batch, batch_idx):
        y, y_hat = self.common_step(batch, batch_idx)
        loss = nn.functional.mse_loss(y, y_hat)
        self.log("test_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), self.lr)
        return optimizer
This BasicModel class:
- Inherits from L.LightningModule
- Accepts a learning rate parameter
- Defines methods for training, validation, and testing steps
- Uses mean squared error (MSE) as the loss function
- Configures the Adam optimizer
- Logs the loss for each phase (training, validation, test)
The common_step method is not implemented in the base class—it will be provided by each specific model subclass. This follows the template method pattern from software design: the base class defines the general algorithm structure, while subclasses implement the specific details.
Now, let’s implement a simple model that uses this base class:
class TrivialModel(BasicModel):
    def __init__(self, other_dim, lr=1e-4):
        # other_dim is the dimension of features other than PU and DO
        super().__init__(lr)
        self.model = nn.Sequential(
            nn.Linear(other_dim + 2, 512),  # +2 for the two location IDs
            nn.ReLU(),
            nn.Linear(512, 1)
        )

    def common_step(self, batch, batch_idx):
        x, pu, do, y = batch
        # Concatenate all features
        X = torch.hstack((x, torch.unsqueeze(pu, 1), torch.unsqueeze(do, 1)))
        y_hat = self.model(X)
        y = y.view(-1, 1)  # Reshape target to match prediction shape
        return y, y_hat
Before training, we need to determine the input dimension:
other_dim = train_dl.dataset[0][0].shape[0]
10.4 Training with Lightning Trainer
With our model defined, we can now train and evaluate it using Lightning's Trainer class:

trivial_trainer = L.Trainer(deterministic=True, max_epochs=10)
trivial_model = TrivialModel(other_dim)

trivial_trainer.fit(trivial_model, train_dl, val_dl)
trivial_trainer.test(trivial_model, test_dl)
The L.Trainer class automates the entire training process. We specify:
- deterministic=True to ensure reproducible results
- max_epochs=10 to limit the number of training epochs
The fit method handles the training and validation process, while the test method evaluates the model on the test set. Under the hood, Lightning's Trainer is implementing all the boilerplate code we would otherwise have to write:
- Iterating through epochs
- Processing batches from data loaders
- Moving tensors to the appropriate device (CPU or GPU)
- Computing gradients and updating parameters
- Tracking metrics and logging progress
- Validating after each epoch
- Managing computational resources efficiently
This automation drastically reduces the amount of code we need to write while ensuring best practices are followed. Lightning also handles many edge cases and potential bugs that might otherwise creep into custom training loops.
10.5 Logging with Lightning
One of Lightning's powerful features is its comprehensive logging system. In our basic model, we used self.log() to track loss values:
def training_step(self, batch, batch_idx):
    y, y_hat = self.common_step(batch, batch_idx)
    loss = nn.functional.mse_loss(y, y_hat)
    self.log("training_loss", loss)
    return loss
This simple method automatically:
- Records the metric at each step
- Associates it with the current epoch
- Makes it available for monitoring
- Syncs it across distributed processes (if using multi-GPU)
By default, Lightning aggregates logged metrics by computing their mean over each epoch. This behavior can be customized through additional parameters:
self.log("metric_name", value, on_step=True, on_epoch=True, prog_bar=True, logger=True)
Where:
- on_step=True: Log the value at each step
- on_epoch=True: Log the epoch-averaged value
- prog_bar=True: Display the metric in the progress bar
- logger=True: Send the metric to the configured logger(s)
For logging multiple metrics at once, Lightning provides the log_dict() method:
def validation_step(self, batch, batch_idx):
    y, y_hat = self.common_step(batch, batch_idx)
    mse_loss = nn.functional.mse_loss(y, y_hat)
    mae_loss = torch.abs(y - y_hat).mean()

    self.log_dict({
        "val_mse": mse_loss,
        "val_mae": mae_loss
    })
    return mse_loss
10.5.1 Using TensorBoard for Visualization
While Lightning’s default logger provides basic metric tracking, TensorBoard offers much richer visualization capabilities. TensorBoard is a web-based tool that provides graphical representations of model metrics, allowing you to:
- Track metrics over time
- Compare multiple runs
- Visualize model architectures
- Inspect weight distributions
- Explore embedding spaces
To use TensorBoard with Lightning, we simply add a TensorBoard logger to our Trainer:
from lightning.pytorch.loggers import TensorBoardLogger

logger = TensorBoardLogger("tb_logs", name="taxi_fare_model")
trainer = L.Trainer(logger=logger, max_epochs=15)
This configuration will:
1. Create a directory tb_logs/taxi_fare_model/
2. Store all logged metrics in TensorBoard format
3. Create a unique subdirectory for each run (version)
To view the TensorBoard dashboard during or after training, run:
%load_ext tensorboard
%tensorboard --logdir tb_logs
This launches the TensorBoard interface in your browser or notebook, providing real-time visualization of your model’s performance.
10.5.2 Advanced Logging Capabilities
Beyond simple metrics, TensorBoard can visualize various aspects of your model:
- Histograms: Track parameter distributions over time:
for name, param in self.named_parameters():
    self.logger.experiment.add_histogram(name, param, self.current_epoch)
- Images: Visualize input data, activations, or generated outputs:
self.logger.experiment.add_image('sample_image', img_tensor, self.current_epoch)
- Text: Log textual information such as predictions or attention weights:
self.logger.experiment.add_text('prediction', text_string, self.current_epoch)
- Embedding Projections: Visualize high-dimensional embeddings in 2D or 3D space:
self.logger.experiment.add_embedding(
    mat=embeddings,
    metadata=labels,
    global_step=self.current_epoch
)
These advanced logging capabilities provide deeper insights into your model’s behavior and learning process.
10.6 Model Checkpointing
In deep learning, training sessions can be time-consuming and susceptible to interruptions. Model checkpointing addresses this by periodically saving model weights, enabling you to:
- Resume training after interruptions
- Keep the best model according to validation metrics
- Perform post-training analysis
- Deploy trained models to production
Lightning provides a ModelCheckpoint callback that handles these tasks automatically:
from lightning.pytorch.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath='checkpoints/',
    filename='taxi-fare-{epoch:02d}-{val_loss:.4f}',
    save_top_k=3,
    mode='min',
)

trainer = L.Trainer(
    logger=logger,
    callbacks=[checkpoint_callback],
    max_epochs=15
)
This configuration:
- Monitors the validation loss
- Saves checkpoints in the 'checkpoints/' directory
- Names files with epoch number and validation loss
- Keeps the top 3 models with lowest validation loss
- Deletes older checkpoints that don't make the cut
10.6.1 Loading from Checkpoints
Lightning makes it easy to resume training from a checkpoint:
# Resume training from a specific checkpoint by passing ckpt_path to fit()
trainer = L.Trainer()
trainer.fit(model, train_dl, val_dl, ckpt_path='checkpoints/taxi-fare-05-0.3456.ckpt')

# Or load a model from a checkpoint without resuming training
model = TrivialModel.load_from_checkpoint('checkpoints/taxi-fare-10-0.1234.ckpt', other_dim=other_dim)
When loading a model from a checkpoint, you need to provide any positional arguments required by the model's constructor (like other_dim in our example). Lightning automatically restores the model's state, including parameters, buffers, and optimizer state if resuming training.
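One convenience worth noting: if the module calls self.save_hyperparameters() in its __init__ (as the models later in this chapter do), the constructor arguments are stored inside the checkpoint, and load_from_checkpoint can rebuild the model without them. A minimal sketch, using a hypothetical model class and checkpoint path:

# Hypothetical module that called self.save_hyperparameters() in __init__
model = SomeModel.load_from_checkpoint('checkpoints/some-run.ckpt')
model.eval()  # switch to evaluation mode before making predictions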
10.6.2 Saving and Loading Best Practices
Here are some best practices for effective checkpoint management:
Monitor the Right Metric: Choose a validation metric that truly reflects model performance for your problem.
Use Meaningful Filenames: Include relevant metrics and epoch numbers in filenames for easier identification.
Save Hyperparameters: Use the self.save_hyperparameters() method in your __init__ to automatically save constructor arguments:
def __init__(self, other_dim, hidden_dim=512, lr=1e-4):
super().__init__(lr)
self.save_hyperparameters()
# ... rest of init
- Limit Checkpoint Frequency: For large models or fast training, save checkpoints less frequently to reduce storage and I/O overhead:
checkpoint_callback = ModelCheckpoint(
    every_n_epochs=5,  # Save every 5 epochs
    # ... other params
)
- Smaller Checkpoints: For large models, save only the model weights (omitting the optimizer state) to reduce disk usage:
checkpoint_callback = ModelCheckpoint(
    save_weights_only=True,  # Save only weights, not optimizer state
    # ... other params
)
These practices ensure efficient and reliable model checkpointing.
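Putting these recommendations together, a checkpoint configuration along the following lines covers most use cases; the parameter values here are illustrative and should be tuned to your task:

from lightning.pytorch.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',                               # metric that reflects generalization
    dirpath='checkpoints/',
    filename='taxi-fare-{epoch:02d}-{val_loss:.4f}',  # meaningful, metric-bearing filenames
    save_top_k=3,                                     # keep only the best few models
    mode='min',
    every_n_epochs=1,                                 # reduce for large or fast-training models
    save_weights_only=False                           # set True to shrink checkpoint size
)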
10.7 Early Stopping
Training neural networks to convergence can be time-consuming, and continuing training after performance plateaus often leads to overfitting. Early stopping addresses this by halting training when performance on the validation set no longer improves.
Lightning provides an EarlyStopping callback:
from lightning.pytorch.callbacks import EarlyStopping

early_stop_callback = EarlyStopping(
    monitor='val_loss',
    min_delta=0.00,
    patience=5,
    verbose=True,
    mode='min'
)

trainer = L.Trainer(
    logger=logger,
    callbacks=[checkpoint_callback, early_stop_callback],
    max_epochs=50  # Set a high number; early stopping will prevent reaching this
)
This configuration:
- Monitors validation loss
- Stops training if the loss doesn't improve by at least min_delta for patience epochs
- Displays a message when stopping
- Uses 'min' mode because we want to minimize loss (use 'max' for metrics like accuracy)
10.7.1 Benefits of Early Stopping
Early stopping provides several benefits:
- Reduced Overfitting: Prevents the model from memorizing training data
- Time Efficiency: Saves computational resources
- Automatic Regularization: Acts as a form of regularization
- Best Model Selection: Combined with checkpointing, ensures you keep the best model
10.7.2 Early Stopping Considerations
When implementing early stopping, consider:
Appropriate Patience: Set patience based on the expected learning curve - too low might stop prematurely, too high wastes computation.
Monitor the Right Metric: Choose a metric that reflects generalization performance.
Validation Set Quality: Ensure your validation set is representative and large enough.
Combine with Checkpointing: Lightning's EarlyStopping callback halts training but does not itself restore the best weights. Pair it with a ModelCheckpoint that monitors the same metric and reload the best checkpoint after training:
trainer.fit(model, train_dl, val_dl)
# checkpoint_callback tracks the path of the best-scoring checkpoint
best_model = TrivialModel.load_from_checkpoint(checkpoint_callback.best_model_path, other_dim=other_dim)
This ensures that even if training continues for several epochs after the best performance, the model you keep has the weights from the best epoch.
10.8 A More Sophisticated Model
Now that we’ve covered the basics of Lightning, let’s implement a more sophisticated model that incorporates embedding layers for location IDs:
class EmbeddingModel(BasicModel):
    def __init__(self, other_dim, embed_dim=16, lr=1e-4):
        super().__init__(lr)
        # Save hyperparameters for checkpointing
        self.save_hyperparameters()
        self.model = nn.Sequential(
            nn.Linear(other_dim + 2*embed_dim, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1)
        )
        self.pu_embed = torch.nn.Embedding(PUvals, embed_dim)
        self.do_embed = torch.nn.Embedding(DOvals, embed_dim)

    def common_step(self, batch, batch_idx):
        x, pu, do, y = batch
        pu_vec = self.pu_embed(pu)
        do_vec = self.do_embed(do)
        X = torch.hstack((x, pu_vec, do_vec))
        y_hat = self.model(X)
        y = y.view(-1, 1)
        return y, y_hat

    def validation_step(self, batch, batch_idx):
        y, y_hat = self.common_step(batch, batch_idx)
        mse_loss = nn.functional.mse_loss(y, y_hat)
        mae_loss = torch.abs(y - y_hat).mean()

        # Log multiple metrics
        self.log_dict({
            "val_loss": mse_loss,  # Primary metric for monitoring
            "val_mae": mae_loss    # Additional metric for analysis
        })
        return mse_loss
This model:
- Uses embedding layers for pickup and dropoff locations
- Saves hyperparameters for better checkpointing
- Logs multiple metrics during validation
- Has a larger network capacity for better performance
Let’s train this model with our complete set of Lightning features:
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping

# Configure logger
logger = TensorBoardLogger("tb_logs", name="embedding_model")

# Configure checkpointing
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath='checkpoints/',
    filename='embed-{epoch:02d}-{val_loss:.4f}',
    save_top_k=3,
    mode='min'
)

# Configure early stopping
early_stop_callback = EarlyStopping(
    monitor='val_loss',
    min_delta=0.001,
    patience=5,
    verbose=True,
    mode='min'
)

# Create and train model
embed_model = EmbeddingModel(other_dim, embed_dim=32)
trainer = L.Trainer(
    logger=logger,
    callbacks=[checkpoint_callback, early_stop_callback],
    max_epochs=25,
    deterministic=True
)

trainer.fit(embed_model, train_dl, val_dl)
trainer.test(embed_model, test_dl)
This training setup includes:
- Comprehensive logging with TensorBoard
- Model checkpointing to save the best models
- Early stopping to prevent overfitting
- A longer maximum training duration (25 epochs)
With this configuration, we can train a more powerful model while maintaining efficient use of computational resources and automatically selecting the best-performing version.
10.9 Hyperparameter Search
Selecting optimal hyperparameters is crucial for neural network performance. While manual tuning is common, automatic hyperparameter search can identify better configurations more efficiently. Lightning integrates with libraries like Optuna and Ray Tune to facilitate hyperparameter optimization.
Let’s implement a hyperparameter search using Lightning and Optuna:
import optuna
from optuna.integration import PyTorchLightningPruningCallback

def objective(trial):
    # Define hyperparameter search space
    embed_dim = trial.suggest_int('embed_dim', 8, 64)
    hidden_dim = trial.suggest_int('hidden_dim', 256, 1024)
    lr = trial.suggest_float('lr', 1e-5, 1e-3, log=True)
    dropout = trial.suggest_float('dropout', 0.1, 0.5)

    # Define model with trial hyperparameters
    class TrialModel(BasicModel):
        def __init__(self):
            super().__init__(lr)
            self.embed_dim = embed_dim
            self.hidden_dim = hidden_dim
            self.dropout_rate = dropout
            self.model = nn.Sequential(
                nn.Linear(other_dim + 2*embed_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim, hidden_dim // 2),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim // 2, 1)
            )
            self.pu_embed = torch.nn.Embedding(PUvals, embed_dim)
            self.do_embed = torch.nn.Embedding(DOvals, embed_dim)

        def common_step(self, batch, batch_idx):
            x, pu, do, y = batch
            pu_vec = self.pu_embed(pu)
            do_vec = self.do_embed(do)
            X = torch.hstack((x, pu_vec, do_vec))
            y_hat = self.model(X)
            y = y.view(-1, 1)
            return y, y_hat

    # Create model
    model = TrialModel()

    # Configure callbacks
    pruning_callback = PyTorchLightningPruningCallback(trial, monitor="val_loss")
    early_stop_callback = EarlyStopping(
        monitor='val_loss',
        patience=3,
        verbose=False,
        mode='min'
    )

    # Configure trainer
    trainer = L.Trainer(
        logger=False,  # Disable logging to avoid clutter
        callbacks=[pruning_callback, early_stop_callback],
        max_epochs=10,
        enable_progress_bar=False  # Disable progress bar for cleaner output
    )

    # Train model
    trainer.fit(model, train_dl, val_dl)

    # Return best validation loss
    return trainer.callback_metrics["val_loss"].item()

# Create study
study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=20)

# Get best parameters
best_params = study.best_params
print(f"Best parameters: {best_params}")
print(f"Best validation loss: {study.best_value}")
This hyperparameter search:
- Defines a search space for embedding dimension, hidden dimension, learning rate, and dropout rate
- Creates a trial model for each hyperparameter configuration
- Uses Optuna's pruning callback to terminate unpromising trials early
- Trains each model for up to 10 epochs with early stopping
- Returns the best validation loss for each trial
- Collects results across 20 trials to find the optimal configuration
10.9.1 Advanced Hyperparameter Search
For more sophisticated hyperparameter optimization, consider:
- Conditional Parameters: Define parameters that depend on other parameters:
# Only include batch normalization if a specific architecture is chosen
if trial.suggest_categorical('architecture', ['simple', 'complex']) == 'complex':
    use_batchnorm = trial.suggest_categorical('use_batchnorm', [True, False])
else:
    use_batchnorm = False
- Pruning Inefficient Trials: Terminate poorly performing trials early to save resources:
pruner = optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=5)
study = optuna.create_study(direction="minimize", pruner=pruner)
- Parallel Optimization: Run multiple trials concurrently:
study.optimize(objective, n_trials=20, n_jobs=4)  # Run 4 trials in parallel
- Visualizing Results: Analyze the search results:
# Plot optimization history
optuna.visualization.plot_optimization_history(study)
# Plot parameter importance
optuna.visualization.plot_param_importances(study)
# Plot parameter relationships
optuna.visualization.plot_contour(study)
These visualizations provide insights into the hyperparameter landscape, highlighting which parameters have the strongest influence on model performance.
10.9.2 Final Model with Optimal Hyperparameters
After identifying the optimal hyperparameters, we can train our final model:
# Create model with best hyperparameters
final_model = EmbeddingModel(
    other_dim=other_dim,
    embed_dim=best_params['embed_dim'],
    lr=best_params['lr']
)

# Configure comprehensive logging and checkpointing
logger = TensorBoardLogger("tb_logs", name="final_model")
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath='checkpoints/',
    filename='final-{epoch:02d}-{val_loss:.4f}',
    save_top_k=1,
    mode='min'
)

# Train with optimal configuration
final_trainer = L.Trainer(
    logger=logger,
    callbacks=[checkpoint_callback, early_stop_callback],
    max_epochs=30,
    deterministic=True
)

final_trainer.fit(final_model, train_dl, val_dl)
final_trainer.test(final_model, test_dl)
This final training run incorporates all the best practices we've discussed:
- Using the optimal hyperparameters from our search
- Comprehensive logging with TensorBoard
- Model checkpointing to save the best model
- Early stopping to prevent overfitting
- A longer maximum training duration (30 epochs)
The result is a model that achieves the best possible performance for our taxi fare prediction task.
10.10 Advanced Lightning Techniques
Beyond the features we’ve covered, Lightning offers several advanced capabilities that can further enhance your deep learning workflow.
10.10.1 Learning Rate Scheduling
Learning rate scheduling can significantly improve training dynamics. Lightning makes it easy to implement:
def configure_optimizers(self):
    optimizer = optim.Adam(self.parameters(), lr=self.lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=3, verbose=True
    )
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
            "monitor": "val_loss",
            "frequency": 1
        }
    }
This configuration reduces the learning rate by half whenever the validation loss plateaus for 3 epochs.
10.10.2 Gradient Accumulation
For training with large batch sizes on limited memory, gradient accumulation is helpful:
trainer = L.Trainer(
    accumulate_grad_batches=4,  # Accumulate gradients over 4 batches
    # ... other parameters
)
This effectively simulates a batch size four times larger than what fits in memory: with batch_size=1024, each optimizer step would see 4 × 1024 = 4096 samples.
10.10.3 Mixed Precision Training
Mixed precision training uses lower-precision arithmetic (e.g., float16) to speed up training:
trainer = L.Trainer(
    precision="16-mixed",  # Use mixed precision
    # ... other parameters
)
This can provide significant speedups, especially on modern GPUs with tensor cores.
10.10.4 Multi-GPU Training
Lightning makes it trivial to scale to multiple GPUs:
trainer = L.Trainer(
    accelerator="gpu",
    devices=4,  # Use 4 GPUs
    strategy="ddp"  # Distributed Data Parallel
)
The same model code works seamlessly across different hardware configurations.
10.10.5 Profiling
For performance optimization, Lightning includes built-in profiling:
trainer = L.Trainer(
    profiler="simple",  # or "advanced" for more detailed profiling
    # ... other parameters
)
This helps identify bottlenecks in your training pipeline.
10.11 A Production-Ready ResNet Model
To demonstrate the full power of Lightning, let’s implement a production-ready model that incorporates all the advanced techniques we’ve discussed. This model uses a residual architecture with batch normalization and dropout:
class ResidualBlock(nn.Module):
    def __init__(self, dim, dropout_rate=0.2):
        super().__init__()
        self.block = nn.Sequential(
            nn.Linear(dim, dim),
            nn.BatchNorm1d(dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(dim, dim),
            nn.BatchNorm1d(dim)
        )
        self.relu = nn.ReLU()

    def forward(self, x):
        residual = x
        out = self.block(x)
        out += residual  # Residual connection
        out = self.relu(out)
        return out


class ResNetModel(BasicModel):
    def __init__(self, other_dim, embed_dim=32, hidden_dim=512, num_blocks=3, dropout_rate=0.2, lr=1e-4):
        super().__init__(lr)
        self.save_hyperparameters()

        # Input projection
        self.input_proj = nn.Sequential(
            nn.Linear(other_dim + 2*embed_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU()
        )

        # Residual blocks
        self.res_blocks = nn.ModuleList([
            ResidualBlock(hidden_dim, dropout_rate)
            for _ in range(num_blocks)
        ])

        # Output layer
        self.output_layer = nn.Linear(hidden_dim, 1)

        # Embeddings
        self.pu_embed = nn.Embedding(PUvals, embed_dim)
        self.do_embed = nn.Embedding(DOvals, embed_dim)

    def common_step(self, batch, batch_idx):
        x, pu, do, y = batch

        # Get embeddings
        pu_vec = self.pu_embed(pu)
        do_vec = self.do_embed(do)

        # Combine features
        X = torch.hstack((x, pu_vec, do_vec))

        # Forward pass
        out = self.input_proj(X)

        # Apply residual blocks
        for block in self.res_blocks:
            out = block(out)

        # Final prediction
        y_hat = self.output_layer(out)
        y = y.view(-1, 1)

        return y, y_hat

    def validation_step(self, batch, batch_idx):
        y, y_hat = self.common_step(batch, batch_idx)
        mse_loss = nn.functional.mse_loss(y, y_hat)
        mae_loss = torch.abs(y - y_hat).mean()

        self.log_dict({
            "val_loss": mse_loss,
            "val_mae": mae_loss
        })
        return mse_loss

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=self.lr, weight_decay=1e-4)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=3, verbose=True
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss",
                "frequency": 1
            }
        }
This model incorporates:
- Embedding layers for categorical features
- Residual connections for better gradient flow
- Batch normalization for training stability
- Dropout for regularization
- AdamW optimizer with weight decay
- Learning rate scheduling
Let’s train this model with all the Lightning features we’ve discussed:
# Configure logging
logger = TensorBoardLogger("tb_logs", name="resnet_model")

# Configure callbacks
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath='checkpoints/',
    filename='resnet-{epoch:02d}-{val_loss:.4f}',
    save_top_k=1,
    mode='min'
)

early_stop_callback = EarlyStopping(
    monitor='val_loss',
    min_delta=0.001,
    patience=5,
    verbose=True,
    mode='min'
)

# Create model
resnet_model = ResNetModel(
    other_dim=other_dim,
    embed_dim=32,
    hidden_dim=512,
    num_blocks=3,
    dropout_rate=0.2,
    lr=1e-4
)

# Configure trainer with advanced features
trainer = L.Trainer(
    logger=logger,
    callbacks=[checkpoint_callback, early_stop_callback],
    max_epochs=30,
    precision="16-mixed",  # Use mixed precision for speed
    accumulate_grad_batches=2,  # Larger effective batch size
    deterministic=True
)

# Train model
trainer.fit(resnet_model, train_dl, val_dl)
trainer.test(resnet_model, test_dl)
This production-ready training pipeline incorporates all best practices and advanced techniques, resulting in a highly optimized model for our taxi fare prediction task.
10.12 Conclusion
In this chapter, we’ve explored PyTorch Lightning, a powerful framework that streamlines deep learning implementation while maintaining flexibility. We’ve covered:
- Lightning's Core Architecture: Understanding the LightningModule and Trainer classes
- Logging and Visualization: Using TensorBoard for comprehensive monitoring
- Checkpointing: Saving and loading models effectively
- Early Stopping: Preventing overfitting through automated training termination
- Hyperparameter Search: Finding optimal configurations systematically
- Advanced Techniques: Learning rate scheduling, mixed precision, and more
By adopting Lightning, we’ve transformed our PyTorch code in several important ways:
- From Imperative to Declarative: Instead of explicitly coding how to train models, we declare what our models are and what metrics to track.
- From Low-Level to High-Level: We focus on scientific code (model architecture, loss functions) rather than engineering details.
- From Custom to Standardized: We leverage battle-tested implementations of common patterns rather than reinventing them.
- From Brittle to Robust: Our code is less prone to errors and works across different hardware configurations.
These transformations lead to significant benefits for economics research and applications:
- Improved Reproducibility: Standardized training procedures ensure consistent results.
- Increased Productivity: Less boilerplate means more focus on model design and analysis.
- Enhanced Collaboration: Structured code is easier to share and understand.
- Better Resource Utilization: Built-in optimizations maximize computational efficiency.
The combination of advanced neural network architectures (as covered in the previous chapter) with efficient implementation through Lightning provides a powerful toolkit for tackling complex economic problems with deep learning.
As you continue your deep learning journey, we encourage you to explore Lightning’s extensive documentation and community resources. The framework continues to evolve with new features and optimizations, making it an invaluable tool for both research and production applications in economics and beyond.