Dynamic Tensor Rematerialization (DTR)

Marisa Kirisame, Steven Lyubomirsky, Altan Haan, Jennifer Brennan, Mike He, Jared Roesch, Tianqi Chen, Zachary Tatlock
[paper] [slides] [video]



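| Model          | Maximum non-remat batch size | Remat batch size | Throughput relative to non-remat |
|----------------|------------------------------|------------------|----------------------------------|
| BERT-base-mlm  | 16                           | 32               | 94.3%                            |
| BERT-large-mlm | 8                            | 16               | 96.5%                            |
| GPT2           | 8                            | 24               | 93.2%                            |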



Technical Dive

  1. Gradient checkpointing is a technique for saving memory when training deep neural networks.
  2. Or, more generally, for reverse-mode automatic differentiation.
  3. However, finding an optimal memory plan is NP-complete.
  4. Checkpointing also has to handle programs with arbitrary control flow.
  5. To cope with this, previous work imposed various restrictions that sacrifice performance or usability.
  6. Some work models the program as a stack machine with no heap...
  7. ...and suffers performance degradation when that assumption is broken!
  8. (For example, networks with highway connections or branching.)
  9. Other work uses an ILP solver, which can take a long time to find the optimal memory plan.
  10. It can also only be used for programs/frameworks without control flow, which hinders real-world adoption.
  11. Additionally, gradient checkpointing couples derivative calculation with memory saving via recomputation.
  12. This adds complexity and limits the range of applications.
  13. DTR tackles these problems by planning memory greedily at runtime instead of in a compiler pass.
  14. Essentially, DTR pairs every tensor with a thunk (think lazy evaluation), so any tensor can be recomputed after it is evicted.
  15. This solves the control-flow and stack-machine issues, as we do not model the program in any way!
  16. Despite this, with a novel cache eviction policy, we still achieve good performance (93.2% to 96.5% of non-remat throughput for the models in the table above).
  17. Surprisingly, on linear (chain-structured) networks DTR 'accidentally' discovers the O(n)-time, O(n^0.5)-space scheme and the O(n log n)-time, O(log n)-space scheme.
  18. Both are well-known results.
  19. Furthermore, DTR switches smoothly between normal execution without recomputation, the O(n^0.5)-space scheme, and the O(log n)-space scheme, depending on the memory budget.
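The runtime idea in items 13 to 16 (every tensor keeps a thunk, and memory is freed greedily at runtime) can be illustrated with a toy Python model. This is a minimal sketch, not the DTR implementation: the `Tensor` and `DTRRuntime` classes, the scalar "tensors", and the abstract size/cost units are hypothetical, and the eviction score `cost / (size * staleness)` is assumed here as a simplified, local stand-in for the paper's heuristic.

```python
from dataclasses import dataclass
from typing import AbstractSet, Callable, List, Optional


@dataclass(eq=False)  # identity-based __eq__/__hash__, so tensors can go in sets
class Tensor:
    """A (hypothetical) tensor paired with a thunk: the op and parents that recompute it."""
    name: str
    size: int                      # memory footprint, in abstract units
    cost: float                    # compute cost of running `op` once
    parents: List["Tensor"]
    op: Callable[..., float]
    pinned: bool = False           # pinned tensors (e.g. inputs) are never evicted
    value: Optional[float] = None  # None = not resident (evicted or never computed)
    computed: bool = False         # has this tensor ever been computed?
    last_access: int = 0


class DTRRuntime:
    """Toy greedy rematerialization runtime (illustrative only)."""

    def __init__(self, budget: int):
        self.budget = budget
        self.clock = 0
        self.resident: List[Tensor] = []
        self.recomputations = 0
        self.peak = 0

    def memory_used(self) -> int:
        return sum(t.size for t in self.resident)

    def score(self, t: Tensor) -> float:
        # Assumed simplified heuristic: cheap, large, stale tensors score lowest.
        staleness = self.clock - t.last_access + 1
        return t.cost / (t.size * staleness)

    def make_room(self, needed: int, protected: AbstractSet[Tensor]) -> None:
        # Greedily evict the lowest-scoring tensor until `needed` units fit.
        while self.memory_used() + needed > self.budget:
            candidates = [t for t in self.resident
                          if not t.pinned and t not in protected]
            if not candidates:
                raise MemoryError("budget too small for this operation")
            victim = min(candidates, key=self.score)
            victim.value = None  # drop the buffer; the thunk (op + parents) remains
            self.resident.remove(victim)

    def get(self, t: Tensor, protected: AbstractSet[Tensor] = frozenset()) -> float:
        """Return t's value, recursively rematerializing evicted ancestors."""
        self.clock += 1
        if t.value is None:
            guard = set(protected) | {t}
            args = []
            for p in t.parents:
                args.append(self.get(p, guard))
                guard.add(p)  # keep inputs resident while t is computed
            self.make_room(t.size, guard)
            t.value = t.op(*args)
            self.resident.append(t)
            self.peak = max(self.peak, self.memory_used())
            if t.computed:
                self.recomputations += 1
            t.computed = True
        t.last_access = self.clock
        return t.value


def run_chain(n: int, budget: int):
    """Forward over an n-layer chain, then touch activations in reverse (like backprop)."""
    rt = DTRRuntime(budget)
    layers = [Tensor("x0", 1, 1.0, [], lambda: 0.0, pinned=True)]
    for i in range(1, n + 1):
        layers.append(Tensor(f"x{i}", 1, 1.0, [layers[-1]], lambda v: v + 1.0))
        rt.get(layers[-1])
    values = [rt.get(t) for t in reversed(layers)]
    return values, rt.recomputations, rt.peak
```

For example, `run_chain(n=16, budget=5)` returns the same activations as `run_chain(n=16, budget=17)`, but stays within 5 units of memory by recomputing evicted activations, while the larger budget needs no recomputation. With equal costs and sizes, as in this chain, the assumed score reduces to least-recently-used eviction; the cost and size terms only matter when tensors differ.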
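For reference, the two well-known bounds in items 17 to 19 can be derived for a chain of n layers with unit-cost, unit-size activations. This is the textbook argument for a uniform-segment scheme and a simple recursive-halving scheme (one way to reach these bounds), not a derivation of DTR's own behavior:

```latex
\begin{aligned}
&\text{Uniform segments of length } k \text{ (store every } k\text{-th activation):}\\
&\quad M(k) = \underbrace{n/k}_{\text{checkpoints}} + \underbrace{k}_{\text{one segment}},
\qquad \frac{dM}{dk} = 1 - \frac{n}{k^{2}} = 0 \;\Rightarrow\; k = \sqrt{n},\quad M = 2\sqrt{n} = O(n^{0.5}),\\
&\quad T = \underbrace{n}_{\text{forward}} + \underbrace{n}_{\text{one extra forward}} + \underbrace{n}_{\text{backward}} = O(n).\\[4pt]
&\text{Recursive halving (checkpoint the midpoint, reverse the right half, then the left half):}\\
&\quad M(n) = M(n/2) + O(1) = O(\log n),\\
&\quad T(n) = 2\,T(n/2) + n/2 = O(n \log n).
\end{aligned}
```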