10 Lightning

Note

This is an EARLY DRAFT.

In the previous chapter, we explored the theoretical foundations of advanced neural network architectures. Now, we turn our attention to the practical aspects of implementing and training these models efficiently. This chapter introduces PyTorch Lightning, a high-level framework that streamlines neural network implementation while maintaining PyTorch’s flexibility.

Lightning addresses a common challenge in deep learning research and applications: the separation of scientific code (model architecture, loss functions) from engineering code (training loops, device management, logging). By providing a structured organization for neural network development, Lightning enables more reproducible experiments, cleaner code, and faster iteration cycles—all critical factors for successful machine learning projects in economics and beyond.

Throughout this chapter, we will work with our familiar NYC taxi fare prediction task, progressively enhancing our implementation while exploring Lightning’s key features: logging, visualization with TensorBoard, model checkpointing, early stopping, and hyperparameter search. These tools represent the modern practice of deep learning, enabling more efficient model development, improved performance, and increased reproducibility.

10.1 Why Use Lightning?

Traditional PyTorch code, while flexible, often contains repetitive boilerplate that can obscure the core model logic. Consider a typical training loop in vanilla PyTorch:

# Vanilla PyTorch training loop
model = MyModel()
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

for epoch in range(num_epochs):
    model.train()
    for batch_idx, (x, y) in enumerate(train_loader):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        output = model(x)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()
    
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(device), y.to(device)
            output = model(x)
            val_loss += criterion(output, y).item()
    val_loss /= len(val_loader)
    
    print(f"Epoch {epoch}: Train Loss: {loss.item():.4f}, Val Loss: {val_loss:.4f}")

This code has several issues:

  1. Repetitive Patterns: Device management, gradient zeroing, and evaluation code are duplicated across projects.
  2. Mixed Concerns: Scientific logic (model, loss) is intertwined with engineering details (loops, device management).
  3. Limited Extensibility: Adding features like learning rate scheduling or early stopping requires significant code changes.
  4. Poor Reproducibility: Subtle differences in implementation can lead to different results across experiments.

PyTorch Lightning addresses these issues by providing a structured organization for neural network code while preserving PyTorch’s flexibility. The core abstraction in Lightning is the LightningModule class, which encapsulates:

  • The model architecture (__init__ method)
  • The forward pass (forward method)
  • Training, validation, and test logic (training_step, validation_step, test_step methods)
  • Optimization configuration (configure_optimizers method)

This organization makes the code more readable and maintainable while reducing potential sources of error.
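
To make this organization concrete, here is a minimal, illustrative skeleton (the class name SketchModule and its tiny linear network are placeholders rather than part of our taxi-fare models, which are developed in Section 10.3):

import lightning as L
from torch import nn, optim

class SketchModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(10, 1)            # model architecture

    def forward(self, x):                      # forward pass
        return self.net(x)

    def training_step(self, batch, batch_idx): # training logic (validation/test steps follow the same pattern)
        x, y = batch
        loss = nn.functional.mse_loss(self(x), y)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):            # optimization configuration
        return optim.Adam(self.parameters(), lr=1e-3)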

Lightning offers several advantages for economics research and applications:

  1. Code Organization: Lightning enforces a clean separation between research code (model architecture, loss functions) and engineering code (training loops, GPU handling).

  2. Reduced Boilerplate: Common operations like moving tensors to the correct device, gradient calculation, and parameter updates are handled automatically.

  3. Built-in Features: Lightning provides out-of-the-box support for logging, checkpointing, early stopping, and other training utilities.

  4. Scalability: The same code can easily scale from a single CPU to multiple GPUs or even multiple machines with minimal changes.

  5. Reproducibility: Lightning makes it easier to ensure consistent results by standardizing the training process.

  6. Readability and Communication: The structured format makes it easier to share code and results with colleagues and reviewers—a crucial consideration for academic research.

For economics applications in particular, reproducibility and transparency are paramount. By separating the scientific components (model definition, hyperparameters) from the engineering details, Lightning makes it easier to communicate and replicate research findings.

10.2 Loading and Preparing Data

Before diving into Lightning-specific code, let’s load and prepare our NYC taxi dataset. As in previous chapters, we’ll download the data, perform initial cleaning operations, and create appropriate data loaders.

from pathlib import Path
import requests
import pandas as pd
import numpy as np
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.model_selection import train_test_split
import torch
from torch import nn, optim, utils
import lightning as L

# Set random seed for reproducibility
L.pytorch.seed_everything(110, workers=True)

# Download data if needed
local_path = Path('data/fhvhv_tripdata_2024-01.parquet')
url = 'https://d37ci6vzurychx.cloudfront.net/trip-data/fhvhv_tripdata_2024-01.parquet'

if not local_path.exists():
    local_path.parent.mkdir(exist_ok=True)
    local_path.write_bytes(requests.get(url).content)

# Load and clean data
df = pd.read_parquet('data/fhvhv_tripdata_2024-01.parquet',
                     columns = ['hvfhs_license_num','request_datetime',
                                'trip_miles','trip_time','base_passenger_fare',
                                'driver_pay','PULocationID','DOLocationID']).sample(1_000_000)

# Clean the data by filtering outliers
df = df[(df['trip_miles']>=1) 
        & (df['trip_miles']<=20) 
        & (df['base_passenger_fare']<200)]

# Feature engineering
df['request_day_of_week'] = df['request_datetime'].dt.dayofweek
df['request_hour_of_day'] = df['request_datetime'].dt.hour
df['fare_per_mile'] = df['base_passenger_fare']/df['trip_miles']

Note that we’re sampling 1 million records from the dataset to make training more manageable. While this is still a substantial amount of data, it allows for faster experimentation without overly compromising model quality.

Next, we define our feature sets and split the data:

categorical_features = ['hvfhs_license_num', 'request_day_of_week', 'request_hour_of_day']
numerical_features = ['trip_miles', 'trip_time']

X = df[categorical_features+numerical_features+['PULocationID','DOLocationID']]
y = df['fare_per_mile']

# Create training, validation, and test sets
X_train_val, X_test, y_train_val, y_test = train_test_split(X, y,
                                    test_size=0.1, random_state=100)
X_train, X_val, y_train, y_val = train_test_split(X_train_val, y_train_val,
                                    test_size=0.1, random_state=100)
# Free memory
del X_train_val
del y_train_val

This three-way split is standard practice in deep learning:

  • Training set (81% of data): Used to update model parameters
  • Validation set (9% of data): Used to tune hyperparameters and monitor for overfitting
  • Test set (10% of data): Used for final evaluation only

The validation set helps us monitor the model’s generalization ability during training and make informed decisions about hyperparameters, while the test set provides an unbiased final evaluation.

For preprocessing our features, we’ll use scikit-learn’s ColumnTransformer:

ct = ColumnTransformer([
    ('num', StandardScaler(), numerical_features),
    ('cat', OneHotEncoder(sparse_output=False, handle_unknown='ignore'), categorical_features)
])

ct.fit(X_train)

# Get min and max values for location IDs to determine embedding sizes
PUmax, PUmin = df['PULocationID'].max(), df['PULocationID'].min()
PUvals = PUmax - PUmin + 1
DOmax, DOmin = df['DOLocationID'].max(), df['DOLocationID'].min()
DOvals = DOmax - DOmin + 1

Now, let’s create a function to process our data and turn it into PyTorch’s TensorDataset:

def mk_dataset(X, y):
    # Transform features with ColumnTransformer
    X_trans = ct.transform(X).astype(np.float32)
    return utils.data.TensorDataset(
        torch.from_numpy(X_trans),
        torch.from_numpy(X['PULocationID'].values - PUmin),
        torch.from_numpy(X['DOLocationID'].values - DOmin),
        torch.from_numpy(y.values.astype(np.float32))
    )

# Create dataloaders
train_dl = utils.data.DataLoader(mk_dataset(X_train, y_train),
                               shuffle=True, batch_size=1024, num_workers=4)
val_dl = utils.data.DataLoader(mk_dataset(X_val, y_val),
                             batch_size=1024, num_workers=4)
test_dl = utils.data.DataLoader(mk_dataset(X_test, y_test),
                              batch_size=1024, num_workers=4)

Note these important DataLoader parameters:

  • shuffle=True for the training set ensures that each epoch sees a different order of samples
  • batch_size=1024 defines how many samples are processed in each iteration
  • num_workers=4 enables parallel data loading, which can significantly speed up training
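
As a quick sanity check, it can help to pull a single batch and confirm the four-tuple structure (transformed features, pickup index, dropoff index, target) that our models will unpack later; a brief sketch:

# Grab one batch to verify shapes and the (x, pu, do, y) ordering
x, pu, do, y = next(iter(train_dl))
print(x.shape, pu.shape, do.shape, y.shape)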

With our data prepared, we’re ready to implement models using Lightning.

10.3 Creating a Basic Lightning Module

Let’s create a base LightningModule class that implements the common functionality needed by all our models:

class BasicModel(L.LightningModule):
    def __init__(self, lr=1e-4):
        super().__init__()
        self.lr = lr

    def training_step(self, batch, batch_idx):
        y, y_hat = self.common_step(batch, batch_idx)
        loss = nn.functional.mse_loss(y, y_hat)
        self.log("training_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        y, y_hat = self.common_step(batch, batch_idx)
        loss = nn.functional.mse_loss(y, y_hat)
        self.log("val_loss", loss)
        return loss

    def test_step(self, batch, batch_idx):
        y, y_hat = self.common_step(batch, batch_idx)
        loss = nn.functional.mse_loss(y, y_hat)
        self.log("test_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), self.lr)
        return optimizer

This BasicModel class:

  • Inherits from L.LightningModule
  • Accepts a learning rate parameter
  • Defines methods for training, validation, and testing steps
  • Uses mean squared error (MSE) as the loss function
  • Configures the Adam optimizer
  • Logs the loss for each phase (training, validation, test)

The common_step method is not implemented in the base class—it will be provided by each specific model subclass. This follows the template method pattern from software design: the base class defines the general algorithm structure, while subclasses implement the specific details.

Now, let’s implement a simple model that uses this base class:

class TrivialModel(BasicModel):
    def __init__(self, other_dim, lr=1e-4):
        # other_dim is the dimension of features other than PU and DO
        super().__init__(lr)
        self.model = nn.Sequential(
            nn.Linear(other_dim + 2, 512),  # +2 for the two location IDs
            nn.ReLU(),
            nn.Linear(512, 1)
        )
        
    def common_step(self, batch, batch_idx):
        x, pu, do, y = batch
        # Concatenate all features
        X = torch.hstack((x, torch.unsqueeze(pu, 1), torch.unsqueeze(do, 1)))
        y_hat = self.model(X)
        y = y.view(-1, 1)  # Reshape target to match prediction shape
        return y, y_hat

Before training, we need to determine the input dimension:

other_dim = train_dl.dataset[0][0].shape[0]

10.4 Training with Lightning Trainer

With our model defined, we can now train and evaluate it using Lightning’s Trainer class:

trivial_trainer = L.Trainer(deterministic=True, max_epochs=10)
trivial_model = TrivialModel(other_dim)
trivial_trainer.fit(trivial_model, train_dl, val_dl)
trivial_trainer.test(trivial_model, test_dl)

The L.Trainer class automates the entire training process. We specify:

  • deterministic=True to ensure reproducible results
  • max_epochs=10 to limit the number of training epochs

The fit method handles the training and validation process, while the test method evaluates the model on the test set. Under the hood, Lightning’s Trainer is implementing all the boilerplate code we would otherwise have to write:

  1. Iterating through epochs
  2. Processing batches from data loaders
  3. Moving tensors to the appropriate device (CPU or GPU)
  4. Computing gradients and updating parameters
  5. Tracking metrics and logging progress
  6. Validating after each epoch
  7. Managing computational resources efficiently

This automation drastically reduces the amount of code we need to write while ensuring best practices are followed. Lightning also handles many edge cases and potential bugs that might otherwise creep into custom training loops.
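
Incidentally, the test method also returns the metrics logged during testing, one dictionary per test dataloader, which is convenient for programmatic comparisons; a brief sketch using the objects above:

# test() returns a list with one dict of logged metrics per test dataloader
results = trivial_trainer.test(trivial_model, test_dl)
print(results[0]['test_loss'])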

10.5 Logging with Lightning

One of Lightning’s powerful features is its comprehensive logging system. In our basic model, we used self.log() to track loss values:

def training_step(self, batch, batch_idx):
    y, y_hat = self.common_step(batch, batch_idx)
    loss = nn.functional.mse_loss(y, y_hat)
    self.log("training_loss", loss)
    return loss

This simple method automatically:

  • Records the metric at each step
  • Associates it with the current epoch
  • Makes it available for monitoring
  • Syncs it across distributed processes (if using multi-GPU)

By default, Lightning aggregates logged metrics by computing their mean over each epoch. This behavior can be customized through additional parameters:

self.log("metric_name", value, on_step=True, on_epoch=True, prog_bar=True, logger=True)

Where:

  • on_step=True: Log the value at each step
  • on_epoch=True: Log the epoch-averaged value
  • prog_bar=True: Display the metric in the progress bar
  • logger=True: Send the metric to the configured logger(s)

For logging multiple metrics at once, Lightning provides the log_dict() method:

def validation_step(self, batch, batch_idx):
    y, y_hat = self.common_step(batch, batch_idx)
    mse_loss = nn.functional.mse_loss(y, y_hat)
    mae_loss = torch.abs(y - y_hat).mean()
    
    self.log_dict({
        "val_mse": mse_loss,
        "val_mae": mae_loss
    })
    return mse_loss

10.5.1 Using TensorBoard for Visualization

While Lightning’s default logger provides basic metric tracking, TensorBoard offers much richer visualization capabilities. TensorBoard is a web-based tool that provides graphical representations of model metrics, allowing you to:

  • Track metrics over time
  • Compare multiple runs
  • Visualize model architectures
  • Inspect weight distributions
  • Explore embedding spaces

To use TensorBoard with Lightning, we simply add a TensorBoard logger to our Trainer:

from lightning.pytorch.loggers import TensorBoardLogger

logger = TensorBoardLogger("tb_logs", name="taxi_fare_model")
trainer = L.Trainer(logger=logger, max_epochs=15)

This configuration will:

  1. Create a directory tb_logs/taxi_fare_model/
  2. Store all logged metrics in TensorBoard format
  3. Create a unique subdirectory for each run (version)

To view the TensorBoard dashboard during or after training, run:

%load_ext tensorboard
%tensorboard --logdir tb_logs

This launches the TensorBoard interface in your browser or notebook, providing real-time visualization of your model’s performance.

10.5.2 Advanced Logging Capabilities

Beyond simple metrics, TensorBoard can visualize various aspects of your model:

  1. Histograms: Track parameter distributions over time:
for name, param in self.named_parameters():
    self.logger.experiment.add_histogram(name, param, self.current_epoch)
  2. Images: Visualize input data, activations, or generated outputs:
self.logger.experiment.add_image('sample_image', img_tensor, self.current_epoch)
  3. Text: Log textual information such as predictions or attention weights:
self.logger.experiment.add_text('prediction', text_string, self.current_epoch)
  4. Embedding Projections: Visualize high-dimensional embeddings in 2D or 3D space:
self.logger.experiment.add_embedding(
    mat=embeddings,
    metadata=labels,
    global_step=self.current_epoch
)

These advanced logging capabilities provide deeper insights into your model’s behavior and learning process.
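
These snippets are meant to run inside one of the LightningModule's hooks. A minimal sketch, assuming a TensorBoardLogger is attached (so self.logger.experiment is a SummaryWriter) and using a hypothetical subclass of our TrivialModel:

class HistogramTrivialModel(TrivialModel):
    # Log parameter histograms once per validation epoch
    def on_validation_epoch_end(self):
        for name, param in self.named_parameters():
            self.logger.experiment.add_histogram(name, param, self.current_epoch)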

10.6 Model Checkpointing

In deep learning, training sessions can be time-consuming and susceptible to interruptions. Model checkpointing addresses this by periodically saving model weights, enabling you to:

  1. Resume training after interruptions
  2. Keep the best model according to validation metrics
  3. Perform post-training analysis
  4. Deploy trained models to production

Lightning provides a ModelCheckpoint callback that handles these tasks automatically:

from lightning.pytorch.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath='checkpoints/',
    filename='taxi-fare-{epoch:02d}-{val_loss:.4f}',
    save_top_k=3,
    mode='min',
)

trainer = L.Trainer(
    logger=logger,
    callbacks=[checkpoint_callback],
    max_epochs=15
)

This configuration:

  • Monitors the validation loss
  • Saves checkpoints in the ‘checkpoints/’ directory
  • Names files with epoch number and validation loss
  • Keeps the top 3 models with lowest validation loss
  • Deletes older checkpoints that don’t make the cut

10.6.1 Loading from Checkpoints

Lightning makes it easy to resume training from a checkpoint:

# Resume from a specific checkpoint by passing ckpt_path to fit()
trainer = L.Trainer()
trainer.fit(model, train_dl, val_dl, ckpt_path='checkpoints/taxi-fare-05-0.3456.ckpt')

# Or load a model from checkpoint without resuming training
model = TrivialModel.load_from_checkpoint('checkpoints/taxi-fare-10-0.1234.ckpt', other_dim=other_dim)

When loading a model from a checkpoint, you need to provide any positional arguments required by the model’s constructor (like other_dim in our example). Lightning automatically restores the model’s state, including parameters, buffers, and optimizer state if resuming training.
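
In practice, the easiest handle on the best checkpoint is the ModelCheckpoint callback itself, which records the path of the top-scoring file; a short sketch, assuming the trainer configured above has been fit with our TrivialModel:

# ModelCheckpoint remembers where it saved the best checkpoint
best_path = checkpoint_callback.best_model_path
best_model = TrivialModel.load_from_checkpoint(best_path, other_dim=other_dim)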

10.6.2 Saving and Loading Best Practices

Here are some best practices for effective checkpoint management:

  1. Monitor the Right Metric: Choose a validation metric that truly reflects model performance for your problem.

  2. Use Meaningful Filenames: Include relevant metrics and epoch numbers in filenames for easier identification.

  3. Save Hyperparameters: Use the self.save_hyperparameters() method in your __init__ to automatically save constructor arguments:

def __init__(self, other_dim, hidden_dim=512, lr=1e-4):
    super().__init__(lr)
    self.save_hyperparameters()
    # ... rest of init
  4. Limit Checkpoint Frequency: For large models or fast training, save checkpoints less frequently to reduce storage and I/O overhead:
checkpoint_callback = ModelCheckpoint(
    every_n_epochs=5,  # Save every 5 epochs
    # ... other params
)
  5. Lighter Checkpoints: For large models, save only the weights (omitting the optimizer and other training state) to reduce checkpoint size:
checkpoint_callback = ModelCheckpoint(
    save_weights_only=True,  # Save only weights, not optimizer state
    # ... other params
)

These practices ensure efficient and reliable model checkpointing.
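
One payoff of item 3 is that save_hyperparameters() stores the constructor arguments inside the checkpoint, so the model can be reloaded without repeating them; a brief sketch (using the EmbeddingModel defined in Section 10.8 and an illustrative checkpoint path):

# Hyperparameters saved via save_hyperparameters() are restored automatically,
# so no constructor arguments need to be re-passed here
model = EmbeddingModel.load_from_checkpoint('checkpoints/embed-07-0.2345.ckpt')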

10.7 Early Stopping

Training neural networks to convergence can be time-consuming, and continuing training after performance plateaus often leads to overfitting. Early stopping addresses this by halting training when performance on the validation set no longer improves.

Lightning provides an EarlyStopping callback:

from lightning.pytorch.callbacks import EarlyStopping

early_stop_callback = EarlyStopping(
    monitor='val_loss',
    min_delta=0.00,
    patience=5,
    verbose=True,
    mode='min'
)

trainer = L.Trainer(
    logger=logger,
    callbacks=[checkpoint_callback, early_stop_callback],
    max_epochs=50  # Set a high number, early stopping will prevent reaching this
)

This configuration:

  • Monitors validation loss
  • Stops training if the loss doesn’t improve by at least min_delta for patience epochs
  • Displays a message when stopping
  • Uses ‘min’ mode because we want to minimize loss (use ‘max’ for metrics like accuracy)

10.7.1 Benefits of Early Stopping

Early stopping provides several benefits:

  1. Reduced Overfitting: Prevents the model from memorizing training data
  2. Time Efficiency: Saves computational resources
  3. Automatic Regularization: Acts as a form of regularization
  4. Best Model Selection: Combined with checkpointing, ensures you keep the best model

10.7.2 Early Stopping Considerations

When implementing early stopping, consider:

  1. Appropriate Patience: Set patience based on the expected learning curve - too low might stop prematurely, too high wastes computation.

  2. Monitor the Right Metric: Choose a metric that reflects generalization performance.

  3. Validation Set Quality: Ensure your validation set is representative and large enough.

  4. Combine with Checkpointing: Lightning’s EarlyStopping callback only halts training; it does not restore the weights from the best epoch. Pair it with a ModelCheckpoint that monitors the same metric and reload the best checkpoint after training:

early_stop_callback = EarlyStopping(
    monitor='val_loss',
    patience=5,
    verbose=True,
    mode='min'
)

# After trainer.fit(...), reload the best checkpoint kept by ModelCheckpoint
best_model = TrivialModel.load_from_checkpoint(
    checkpoint_callback.best_model_path, other_dim=other_dim
)

This combination ensures that even if training continues for several epochs after the best performance, the model you keep has the weights from the best epoch.

10.8 A More Sophisticated Model

Now that we’ve covered the basics of Lightning, let’s implement a more sophisticated model that incorporates embedding layers for location IDs:

class EmbeddingModel(BasicModel):
    def __init__(self, other_dim, embed_dim=16, lr=1e-4):
        super().__init__(lr)
        # Save hyperparameters for checkpointing
        self.save_hyperparameters()
        
        self.model = nn.Sequential(
            nn.Linear(other_dim + 2*embed_dim, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1)
        )
        self.pu_embed = torch.nn.Embedding(PUvals, embed_dim)
        self.do_embed = torch.nn.Embedding(DOvals, embed_dim)
        
    def common_step(self, batch, batch_idx):
        x, pu, do, y = batch
        pu_vec = self.pu_embed(pu)
        do_vec = self.do_embed(do)
        X = torch.hstack((x, pu_vec, do_vec))
        y_hat = self.model(X)
        y = y.view(-1, 1)
        return y, y_hat
        
    def validation_step(self, batch, batch_idx):
        y, y_hat = self.common_step(batch, batch_idx)
        mse_loss = nn.functional.mse_loss(y, y_hat)
        mae_loss = torch.abs(y - y_hat).mean()
        
        # Log multiple metrics
        self.log_dict({
            "val_loss": mse_loss,  # Primary metric for monitoring
            "val_mae": mae_loss     # Additional metric for analysis
        })
        return mse_loss

This model:

  • Uses embedding layers for pickup and dropoff locations
  • Saves hyperparameters for better checkpointing
  • Logs multiple metrics during validation
  • Has a larger network capacity for better performance

Let’s train this model with our complete set of Lightning features:

from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping

# Configure logger
logger = TensorBoardLogger("tb_logs", name="embedding_model")

# Configure checkpointing
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath='checkpoints/',
    filename='embed-{epoch:02d}-{val_loss:.4f}',
    save_top_k=3,
    mode='min'
)

# Configure early stopping
early_stop_callback = EarlyStopping(
    monitor='val_loss',
    min_delta=0.001,
    patience=5,
    verbose=True,
    mode='min'
)

# Create and train model
embed_model = EmbeddingModel(other_dim, embed_dim=32)
trainer = L.Trainer(
    logger=logger,
    callbacks=[checkpoint_callback, early_stop_callback],
    max_epochs=25,
    deterministic=True
)

trainer.fit(embed_model, train_dl, val_dl)
trainer.test(embed_model, test_dl)

This training setup includes:

  • Comprehensive logging with TensorBoard
  • Model checkpointing to save the best models
  • Early stopping to prevent overfitting
  • A longer maximum training duration (25 epochs)

With this configuration, we can train a more powerful model while maintaining efficient use of computational resources and automatically selecting the best-performing version.
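
Once training finishes, it is worth checking which checkpoint was selected and what the final logged metrics were; a brief sketch:

# Inspect the outcome of training
print(checkpoint_callback.best_model_path)   # path of the best-scoring checkpoint
print(checkpoint_callback.best_model_score)  # its monitored val_loss value
print(trainer.callback_metrics)              # most recently logged metrics (as tensors)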

10.10 Advanced Lightning Techniques

Beyond the features we’ve covered, Lightning offers several advanced capabilities that can further enhance your deep learning workflow.

10.10.1 Learning Rate Scheduling

Learning rate scheduling can significantly improve training dynamics. Lightning makes it easy to implement:

def configure_optimizers(self):
    optimizer = optim.Adam(self.parameters(), lr=self.lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=3, verbose=True
    )
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
            "monitor": "val_loss",
            "frequency": 1
        }
    }

This configuration reduces the learning rate by half whenever the validation loss plateaus for 3 epochs.
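
To verify that the schedule behaves as intended, Lightning also provides a LearningRateMonitor callback that logs the current learning rate to the configured logger; a small sketch:

from lightning.pytorch.callbacks import LearningRateMonitor

# Record the learning rate once per epoch so the schedule is visible in TensorBoard
lr_monitor = LearningRateMonitor(logging_interval='epoch')
trainer = L.Trainer(callbacks=[lr_monitor], max_epochs=25)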

10.10.2 Gradient Accumulation

For training with large batch sizes on limited memory, gradient accumulation is helpful:

trainer = L.Trainer(
    accumulate_grad_batches=4,  # Accumulate gradients over 4 batches
    # ... other parameters
)

This effectively simulates a batch size 4 times larger than what fits in memory; with our batch_size=1024 data loaders, each optimizer step would use gradients accumulated over 4,096 samples.

10.10.3 Mixed Precision Training

Mixed precision training uses lower-precision arithmetic (e.g., float16) to speed up training:

trainer = L.Trainer(
    precision="16-mixed",  # Use mixed precision
    # ... other parameters
)

This can provide significant speedups, especially on modern GPUs with tensor cores.

10.10.4 Multi-GPU Training

Lightning makes it trivial to scale to multiple GPUs:

trainer = L.Trainer(
    accelerator="gpu",
    devices=4,  # Use 4 GPUs
    strategy="ddp"  # Distributed Data Parallel
)

The same model code works seamlessly across different hardware configurations.

10.10.5 Profiling

For performance optimization, Lightning includes built-in profiling:

trainer = L.Trainer(
    profiler="simple",  # or "advanced" for more detailed profiling
    # ... other parameters
)

This helps identify bottlenecks in your training pipeline.

10.11 A Production-Ready ResNet Model

To demonstrate the full power of Lightning, let’s implement a production-ready model that incorporates all the advanced techniques we’ve discussed. This model uses a residual architecture with batch normalization and dropout:

class ResidualBlock(nn.Module):
    def __init__(self, dim, dropout_rate=0.2):
        super().__init__()
        self.block = nn.Sequential(
            nn.Linear(dim, dim),
            nn.BatchNorm1d(dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(dim, dim),
            nn.BatchNorm1d(dim)
        )
        self.relu = nn.ReLU()
        
    def forward(self, x):
        residual = x
        out = self.block(x)
        out += residual  # Residual connection
        out = self.relu(out)
        return out

class ResNetModel(BasicModel):
    def __init__(self, other_dim, embed_dim=32, hidden_dim=512, num_blocks=3, dropout_rate=0.2, lr=1e-4):
        super().__init__(lr)
        self.save_hyperparameters()
        
        # Input projection
        self.input_proj = nn.Sequential(
            nn.Linear(other_dim + 2*embed_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU()
        )
        
        # Residual blocks
        self.res_blocks = nn.ModuleList([
            ResidualBlock(hidden_dim, dropout_rate) for _ in range(num_blocks)
        ])
        
        # Output layer
        self.output_layer = nn.Linear(hidden_dim, 1)
        
        # Embeddings
        self.pu_embed = nn.Embedding(PUvals, embed_dim)
        self.do_embed = nn.Embedding(DOvals, embed_dim)
        
    def common_step(self, batch, batch_idx):
        x, pu, do, y = batch
        
        # Get embeddings
        pu_vec = self.pu_embed(pu)
        do_vec = self.do_embed(do)
        
        # Combine features
        X = torch.hstack((x, pu_vec, do_vec))
        
        # Forward pass
        out = self.input_proj(X)
        
        # Apply residual blocks
        for block in self.res_blocks:
            out = block(out)
        
        # Final prediction
        y_hat = self.output_layer(out)
        y = y.view(-1, 1)
        
        return y, y_hat
    
    def validation_step(self, batch, batch_idx):
        y, y_hat = self.common_step(batch, batch_idx)
        mse_loss = nn.functional.mse_loss(y, y_hat)
        mae_loss = torch.abs(y - y_hat).mean()
        
        self.log_dict({
            "val_loss": mse_loss,
            "val_mae": mae_loss
        })
        return mse_loss
    
    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=self.lr, weight_decay=1e-4)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=3, verbose=True
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss",
                "frequency": 1
            }
        }

This model incorporates:

  • Embedding layers for categorical features
  • Residual connections for better gradient flow
  • Batch normalization for training stability
  • Dropout for regularization
  • AdamW optimizer with weight decay
  • Learning rate scheduling

Let’s train this model with all the Lightning features we’ve discussed:

# Configure logging
logger = TensorBoardLogger("tb_logs", name="resnet_model")

# Configure callbacks
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath='checkpoints/',
    filename='resnet-{epoch:02d}-{val_loss:.4f}',
    save_top_k=1,
    mode='min'
)

early_stop_callback = EarlyStopping(
    monitor='val_loss',
    min_delta=0.001,
    patience=5,
    verbose=True,
    mode='min'
)

# Create model
resnet_model = ResNetModel(
    other_dim=other_dim,
    embed_dim=32,
    hidden_dim=512,
    num_blocks=3,
    dropout_rate=0.2,
    lr=1e-4
)

# Configure trainer with advanced features
trainer = L.Trainer(
    logger=logger,
    callbacks=[checkpoint_callback, early_stop_callback],
    max_epochs=30,
    precision="16-mixed",  # Use mixed precision for speed
    accumulate_grad_batches=2,  # Larger effective batch size
    deterministic=True
)

# Train model
trainer.fit(resnet_model, train_dl, val_dl)
trainer.test(resnet_model, test_dl)

This production-ready training pipeline incorporates all best practices and advanced techniques, resulting in a highly optimized model for our taxi fare prediction task.

10.12 Conclusion

In this chapter, we’ve explored PyTorch Lightning, a powerful framework that streamlines deep learning implementation while maintaining flexibility. We’ve covered:

  1. Lightning’s Core Architecture: Understanding the LightningModule and Trainer classes
  2. Logging and Visualization: Using TensorBoard for comprehensive monitoring
  3. Checkpointing: Saving and loading models effectively
  4. Early Stopping: Preventing overfitting through automated training termination
  5. Hyperparameter Search: Finding optimal configurations systematically
  6. Advanced Techniques: Learning rate scheduling, mixed precision, and more

By adopting Lightning, we’ve transformed our PyTorch code in several important ways:

  1. From Imperative to Declarative: Instead of explicitly coding how to train models, we declare what our models are and what metrics to track.
  2. From Low-Level to High-Level: We focus on scientific code (model architecture, loss functions) rather than engineering details.
  3. From Custom to Standardized: We leverage battle-tested implementations of common patterns rather than reinventing them.
  4. From Brittle to Robust: Our code is less prone to errors and works across different hardware configurations.

These transformations lead to significant benefits for economics research and applications:

  1. Improved Reproducibility: Standardized training procedures ensure consistent results.
  2. Increased Productivity: Less boilerplate means more focus on model design and analysis.
  3. Enhanced Collaboration: Structured code is easier to share and understand.
  4. Better Resource Utilization: Built-in optimizations maximize computational efficiency.

The combination of advanced neural network architectures (as covered in the previous chapter) with efficient implementation through Lightning provides a powerful toolkit for tackling complex economic problems with deep learning.

As you continue your deep learning journey, we encourage you to explore Lightning’s extensive documentation and community resources. The framework continues to evolve with new features and optimizations, making it an invaluable tool for both research and production applications in economics and beyond.