Python, Machine learning

PyTorch is an automatic differentiation library used for machine learning in Python.

PyTorch tricks

I don’t know who originally made this list. I also don’t know how many of these have been addressed in recent versions. If some of these tricks are not valid anymore, let me know:

  • DataLoader has bad default settings: tune num_workers > 0 and default to pin_memory=True.
  • Use torch.backends.cudnn.benchmark = True to autotune cuDNN kernel choice.
  • Max out the batch size for each GPU to amortize compute.
  • Do not forget bias=False in weight layers before BatchNorms: the bias there is a no-op that bloats the model.
  • Use for p in model.parameters(): p.grad = None instead of model.zero_grad().
  • Careful to disable debug APIs in prod (detect_anomaly/profiler/emit_nvtx/gradcheck).
  • Use DistributedDataParallel not DataParallel, even if not running distributed.
  • Careful to load balance compute on all GPUs if inputs are variably sized, or GPUs will idle.
  • Use an apex fused optimizer (default PyTorch optim for loop iterates individual parameters, yikes).
  • Use checkpointing to recompute memory-intensive, compute-efficient ops in the backward pass (e.g. activations, upsampling,…).
  • Use @torch.jit.script, especially to fuse long sequences of pointwise ops like in GELU.
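
A few of the tricks above (DataLoader settings, the cuDNN autotuner, bias=False before BatchNorm, and resetting gradients to None) can be sketched in one small training loop. This is only a minimal sketch: the toy dataset, the model and the hyperparameters are made up for illustration, and num_workers is left at 0 here so the snippet runs anywhere.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# Let cuDNN autotune its kernel choice (setting the flag is harmless on CPU).
torch.backends.cudnn.benchmark = True

# Toy data (hypothetical): 64 RGB images of 8x8 pixels, 10 classes.
dataset = TensorDataset(torch.randn(64, 3, 8, 8), torch.randint(0, 10, (64,)))

loader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    num_workers=0,        # 0 keeps this sketch portable; tune > 0 on a real dataset
    pin_memory=use_cuda,  # pinned host memory only pays off when copying to a GPU
)

model = nn.Sequential(
    # bias=False: BatchNorm subtracts the per-channel mean, which cancels a conv bias.
    nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(8 * 8 * 8, 10),
).to(device)

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

for inputs, targets in loader:
    # non_blocking=True lets the copy overlap with compute when memory is pinned.
    inputs = inputs.to(device, non_blocking=True)
    targets = targets.to(device, non_blocking=True)

    # Drop the gradient tensors instead of filling them with zeros.
    for p in model.parameters():
        p.grad = None

    loss = loss_fn(model(inputs), targets)
    loss.backward()
    optimizer.step()
```

As far as I know, recent PyTorch releases expose the same behaviour through optimizer.zero_grad(set_to_none=True), and I believe this became the default in 2.0, so the manual loop may no longer be needed there.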
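
The checkpointing trick can be sketched with torch.utils.checkpoint. The block and tensor sizes below are made up; the point is only that the checkpointed call gives the same gradients while recomputing the block's activations in the backward pass instead of storing them.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)

# Hypothetical block whose intermediate activations we would rather recompute than store.
block = nn.Sequential(nn.Linear(32, 128), nn.ReLU(), nn.Linear(128, 32))
x = torch.randn(4, 32, requires_grad=True)

# Plain forward/backward: activations inside `block` are kept until backward.
block(x).sum().backward()
grad_plain = x.grad.clone()

# Checkpointed: activations are dropped after the forward pass
# and recomputed when backward needs them.
x.grad = None
block.zero_grad()
checkpoint(block, x, use_reentrant=False).sum().backward()
grad_checkpointed = x.grad.clone()
```

The trade is extra compute (one more forward pass through the block) for lower peak memory, so it makes most sense for ops that are cheap to recompute but produce large outputs.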
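
For the @torch.jit.script trick, here is a minimal sketch using the tanh approximation of GELU written as a chain of pointwise ops. The function name gelu_tanh is my own, and whether the ops actually get fused depends on the device and PyTorch version.

```python
import torch

@torch.jit.script
def gelu_tanh(x: torch.Tensor) -> torch.Tensor:
    # tanh approximation of GELU; 0.7978845608 is sqrt(2 / pi).
    # Every op here is pointwise, which is what a JIT fuser can merge into one kernel.
    return 0.5 * x * (1.0 + torch.tanh(0.7978845608 * (x + 0.044715 * x * x * x)))

x = torch.linspace(-3.0, 3.0, steps=13)
y = gelu_tanh(x)
```

Without scripting, each of these pointwise ops would run as its own kernel and write an intermediate tensor to memory.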
