- tags
- Python, Machine learning
PyTorch is an autodiff library used to do machine learning in Python.
PyTorch tricks
I don’t know who originally made this list. I also don’t know how many of these have been addressed in recent versions. If some of these tricks are no longer valid, let me know:
- `DataLoader` has bad default settings; tune `num_workers > 0` and default to `pin_memory = True` (sketch after this list).
- Use `torch.backends.cudnn.benchmark = True` to autotune cudnn kernel choice.
- Max out the batch size for each GPU to amortize compute.
- Do not forget `bias=False` in weight layers before BatchNorms; the bias is a no-op that just bloats the model (sketch below).
- Use `for p in model.parameters(): p.grad = None` instead of `model.zero_grad()` (sketch below).
- Be careful to disable debug APIs in prod (`detect_anomaly`/`profiler`/`emit_nvtx`/`gradcheck`).
- Use `DistributedDataParallel`, not `DataParallel`, even if not running distributed (sketch below).
- Be careful to load balance compute across all GPUs if inputs are variably sized, or GPUs will idle.
- Use an apex fused optimizer (the default PyTorch optim for-loop iterates over individual parameters, yikes) (sketch below).
- Use checkpointing to recompute memory-intensive but compute-efficient ops in the backward pass (e.g. activations, upsampling, ...) (sketch below).
- Use `@torch.jit.script`, especially to fuse long sequences of pointwise ops like in GELU (sketch below).
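A minimal sketch of the `DataLoader` tip; the dataset, batch size, and worker count are placeholders:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Placeholder dataset.
dataset = TensorDataset(torch.randn(1000, 32), torch.randint(0, 10, (1000,)))

# num_workers > 0 loads batches in background worker processes;
# pin_memory=True speeds up host-to-GPU copies (pair with .to(device, non_blocking=True)).
loader = DataLoader(dataset, batch_size=64, shuffle=True,
                    num_workers=4, pin_memory=True)
```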
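For the `bias=False` tip: BatchNorm subtracts the mean and adds its own learnable shift, so a bias in the preceding layer is redundant. A minimal sketch with placeholder channel counts:

```python
import torch.nn as nn

# The conv bias would be cancelled by BatchNorm's mean subtraction,
# so drop it; BatchNorm2d supplies its own learnable shift (beta).
block = nn.Sequential(
    nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(128),
    nn.ReLU(inplace=True),
)
```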
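Setting `.grad` to `None` skips the memset of `zero_grad()` and lets the next backward allocate fresh gradient tensors. A sketch with a placeholder model and optimizer; note that recent PyTorch versions expose the same behaviour via `zero_grad(set_to_none=True)`:

```python
import torch

model = torch.nn.Linear(10, 2)                            # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)   # placeholder optimizer

# Instead of model.zero_grad(), drop the gradient tensors entirely;
# the next backward() writes fresh .grad tensors instead of accumulating into zeroed ones.
for p in model.parameters():
    p.grad = None

# Equivalent on recent PyTorch versions:
optimizer.zero_grad(set_to_none=True)
```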
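A minimal `DistributedDataParallel` sketch, assuming a single node launched with `torchrun` (e.g. `torchrun --nproc_per_node=NUM_GPUS train.py`); the model and training loop are placeholders:

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def main():
    # torchrun sets LOCAL_RANK, RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT.
    local_rank = int(os.environ["LOCAL_RANK"])
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(local_rank)

    model = torch.nn.Linear(32, 10).cuda(local_rank)   # placeholder model
    model = DDP(model, device_ids=[local_rank])

    # ... build a DistributedSampler-backed DataLoader and train as usual ...

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```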
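For the fused-optimizer tip, a sketch assuming NVIDIA apex is installed; newer PyTorch versions also ship a fused CUDA implementation behind the `fused=True` flag of Adam/AdamW, which may make apex unnecessary:

```python
import torch

model = torch.nn.Linear(32, 10).cuda()   # placeholder model, params must be on CUDA

# Option A: apex's fused optimizer (requires the NVIDIA apex package).
from apex.optimizers import FusedAdam
optimizer = FusedAdam(model.parameters(), lr=1e-3)

# Option B: recent PyTorch has a fused kernel built into the stock optimizers.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, fused=True)
```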
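A gradient-checkpointing sketch using `torch.utils.checkpoint`; the checkpointed block is a placeholder, and `use_reentrant=False` assumes a reasonably recent PyTorch (drop it on older versions):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # Placeholder memory-hungry sub-module.
        self.block = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512))
        self.head = nn.Linear(512, 10)

    def forward(self, x):
        # Activations inside self.block are not stored; they are recomputed
        # during the backward pass, trading compute for memory.
        x = checkpoint(self.block, x, use_reentrant=False)
        return self.head(x)

net = Net()
out = net(torch.randn(8, 512, requires_grad=True))
out.sum().backward()
```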
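A sketch of the `@torch.jit.script` tip: the tanh-approximation GELU is a long chain of pointwise ops, which the JIT fuser can merge into a single kernel:

```python
import torch

@torch.jit.script
def fused_gelu(x):
    # A chain of pointwise ops; on GPU tensors the JIT fuser can compile
    # them into one kernel after a couple of warm-up calls.
    return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * x * x * x)))

y = fused_gelu(torch.randn(1024, 1024))
```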