PyTorch Patterns

Model Definition

import torch
import torch.nn as nn
import torch.nn.functional as F

class MLP(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim, dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x):
        return self.net(x)

Training Loop

model = MLP(784, 256, 10).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

for epoch in range(100):
    model.train()
    for x, y in train_loader:
        x, y = x.cuda(), y.cuda()
        optimizer.zero_grad()
        loss = F.cross_entropy(model(x), y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
    scheduler.step()

    model.eval()
    with torch.no_grad():
        # evaluate on val_loader
        correct = total = 0
        for x, y in val_loader:
            x, y = x.cuda(), y.cuda()
            pred = model(x).argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.numel()

Key Modules

Module                        Use
nn.Linear(in, out)            fully connected layer
nn.Conv2d(in, out, k)         2D convolution
nn.MultiheadAttention(d, h)   multi-head attention
nn.LayerNorm(d)               layer normalization
nn.Embedding(V, d)            token embeddings
nn.TransformerEncoder         stack of encoder layers
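
A minimal sketch wiring three of these modules into a single self-attention step (the dimensions, batch size, and `batch_first=True` choice are illustrative assumptions, not from the table):

```python
import torch
import torch.nn as nn

vocab, d_model, n_heads = 1000, 64, 4
emb = nn.Embedding(vocab, d_model)                           # token IDs -> vectors
attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
norm = nn.LayerNorm(d_model)

tokens = torch.randint(0, vocab, (2, 16))   # (batch, seq) of token IDs
x = emb(tokens)                             # (2, 16, 64)
out, weights = attn(x, x, x)                # self-attention: query = key = value
x = norm(x + out)                           # residual connection + layer norm
print(x.shape)                              # torch.Size([2, 16, 64])
```

This residual-plus-norm pattern is the building block that `nn.TransformerEncoder` stacks internally.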

Common Gotchas

  • zero_grad() before backward() — gradients accumulate by default
  • model.eval() / model.train() — affects dropout and batch norm behavior
  • torch.no_grad() for inference — saves memory and compute by skipping graph construction
  • .detach() to stop gradient flow through a tensor
  • contiguous() before .view() if tensor is non-contiguous
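
Two of these gotchas are easy to reproduce in a few lines (a minimal sketch; the shapes and numbers are illustrative):

```python
import torch

# Gotcha: gradients accumulate across backward() calls
w = torch.ones(1, requires_grad=True)
(w * 3).sum().backward()
print(w.grad)              # tensor([3.])
(w * 3).sum().backward()
print(w.grad)              # tensor([6.]) -- accumulated, not replaced
w.grad.zero_()             # what optimizer.zero_grad() does per parameter
(w * 3).sum().backward()
print(w.grad)              # tensor([3.]) again

# Gotcha: .view() requires a contiguous tensor
x = torch.randn(2, 3).t()      # transpose makes the tensor non-contiguous
# x.view(6)                    # would raise a RuntimeError here
y = x.contiguous().view(6)     # copy into contiguous memory first
```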

Device Management

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
x = x.to(device)

Saving & Loading

# Save
torch.save({'model': model.state_dict(), 'optimizer': optimizer.state_dict()}, 'ckpt.pt')

# Load (restore the optimizer state too when resuming training)
ckpt = torch.load('ckpt.pt', map_location=device)
model.load_state_dict(ckpt['model'])
optimizer.load_state_dict(ckpt['optimizer'])

DataLoader

from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    def __init__(self, data, labels):
        self.data, self.labels = data, labels
    def __len__(self): return len(self.data)
    def __getitem__(self, idx): return self.data[idx], self.labels[idx]

loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4, pin_memory=True)
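
For tensors that are already in memory, `torch.utils.data.TensorDataset` avoids writing the class by hand; a minimal end-to-end sketch with illustrative shapes (100 fake MNIST-sized samples):

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

data = torch.randn(100, 784)            # 100 flattened 28x28 images
labels = torch.randint(0, 10, (100,))   # 100 class labels in [0, 10)

dataset = TensorDataset(data, labels)
loader = DataLoader(dataset, batch_size=64, shuffle=True)

xb, yb = next(iter(loader))    # pull the first batch
print(xb.shape, yb.shape)      # torch.Size([64, 784]) torch.Size([64])
```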