# PyTorch Patterns
## Model Definition
import torch
import torch.nn as nn
import torch.nn.functional as F
class MLP(nn.Module):
    """Two-layer perceptron: Linear -> GELU -> Dropout -> Linear."""

    def __init__(self, in_dim, hidden_dim, out_dim, dropout=0.1):
        super().__init__()
        layers = [
            nn.Linear(in_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, out_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # (..., in_dim) -> (..., out_dim); leading batch dims pass through.
        return self.net(x)
## Training Loop
# Device-agnostic setup (matches the "Device Management" section below):
# runs on CPU when CUDA is unavailable instead of crashing on .cuda().
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = MLP(784, 256, 10).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
# Cosine decay over the whole run; T_max must match the epoch count.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

for epoch in range(100):
    model.train()
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()  # gradients accumulate by default
        loss = F.cross_entropy(model(x), y)
        loss.backward()
        # Clip the global grad norm to 1.0 to guard against exploding gradients.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
    scheduler.step()  # per-epoch LR update (after the inner batch loop)

    model.eval()  # disables dropout, uses running batch-norm stats
    with torch.no_grad():
        # evaluate on val_loader
        pass
## Key Modules

| Module | Use |
|---|---|
| `nn.Linear(in, out)` | fully connected layer |
| `nn.Conv2d(in, out, k)` | 2D convolution |
| `nn.MultiheadAttention(d, h)` | multi-head attention |
| `nn.LayerNorm(d)` | layer normalization |
| `nn.Embedding(V, d)` | token embeddings |
| `nn.TransformerEncoder` | stack of encoder layers |
## Common Gotchas

- `zero_grad()` before `backward()` — gradients accumulate by default
- `model.eval()` / `model.train()` — affects dropout and batch norm behavior
- `torch.no_grad()` for inference — saves memory and is typically faster
- `.detach()` to stop gradient flow through a tensor
- `.contiguous()` before `.view()` if the tensor is non-contiguous
## Device Management
# Pick the best available device once, then move model and data onto it.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = model.to(device)  # in-place for nn.Module params, but rebind anyway
x = x.to(device)          # tensors are NOT moved in place — rebinding is required
## Saving & Loading
# Save — bundle model and optimizer state so training can resume exactly.
torch.save({'model': model.state_dict(), 'optimizer': optimizer.state_dict()}, 'ckpt.pt')
# Load — map_location relocates tensors (e.g. a GPU checkpoint onto CPU).
ckpt = torch.load('ckpt.pt', map_location=device)
model.load_state_dict(ckpt['model'])
# The optimizer state was saved but never restored — required to resume training.
optimizer.load_state_dict(ckpt['optimizer'])
## DataLoader
from torch.utils.data import Dataset, DataLoader
class MyDataset(Dataset):
    """Minimal map-style dataset wrapping parallel `data` / `labels` sequences."""

    def __init__(self, data, labels):
        # The original sketch read self.data / self.labels without ever
        # assigning them; a constructor is required for the class to be usable.
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Return one (sample, label) pair; DataLoader collates these into batches.
        return self.data[idx], self.labels[idx]
# num_workers>0 loads batches in worker subprocesses; pin_memory=True uses
# page-locked host memory, which speeds host-to-GPU copies.
# NOTE(review): `dataset` is assumed to be an instance defined elsewhere — confirm.
loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4, pin_memory=True)