当前位置: 华文星空 > 知识

Pytorch有什么节省显存的小技巧?

2020-04-28知识

你听说过checkpointing吗?

checkpointing通过在前向传播的时候,丢弃一些计算好的结果,而在反向传播需要的时候重新计算这些结果,从而节省中间的内存消耗。

这类型的工作有很多,经典的比如万恶之源Treeverse,在DL界更多人知道的Sublinear memory cost,首个自动对一切程序进行checkpointing的Checkpointing for Arbitrary Programs,rnn特化版的Memory-Efficient Backpropagation Through Time。

这里面,有一个不是问题的‘问题’ - checkpointing是一项‘利用重计算节省内存’的手段,不过这好像。。跟自动求导没啥关系啊?

用一个比较软件工程的说法:自动求导是一种寻找导数的算法,重计算是一种节省内存的方法,为什么我们在checkpointing里面,把这两个耦合起来呢?换句话说,我们能不能找一个 跟自动求导无关 的X,使得X + 自动求导 = checkpointing?

(P.S. Checkmate的确符合X,但是Checkmate仅仅是把问题粗暴的丢进一个ILP solver,不能带来任何insight,也没办法scale,我们看看还有没有其他方法?)

那我们试着倒推一下 - 假设我们实现了一套自动重计算的系统,这套系统的API,该如何设计呢?

首先,既然重计算不关心里面存着什么数据,是如何算出的(只需要忠实地重新执行计算操作就好),那我们知道,这是一个可以装各种类型数据的泛型容器。

data Recomputable a = ???

在这上面,最基础的操作自然是获得里面的数据 - 或者直接获得,或者重新计算下储存下来,然后获得之。

get :: Recomputable a -> a

这过程跟haskell的thunk不谋而合。同时,重计算的过程中可能会调用其他Recomputable的get,从而导致递归的重计算。这是可组合性下自然诞生的产物,是充话费送的,不需要刻意支持。

那,我们如何构造一个Recomputable呢?我们可以从一个a构造,作为一个不可重计算的输入:

pure :: a -> Recomputable a

也可以由多个Recomputable数据组合而来 - 之所以是多个而不是一个,是因为计算可能会分叉,然后合并,于是会形成一个图而不是链或者树。

( <*> ) :: Recomputable ( a -> b ) -> Recomputable a -> Recomputable b

这刚好形成一个Applicative,性质也很好验证 - Recomputable在外部看来,应该跟

newtype Identity a = Identity { runIdentity :: a }

是一样的 - 所有的区别,都只是我们会在有需要的时候,把数据去掉 - 但是get的时候会通过重计算原样返回。这只会产生性能上的差异,并不会更改语义。

OK,那我们还需要考虑如何去掉(drop)一个Recomputable的数据。

回想起Recomputable跟Lazy很像,drop这个操作就是Lazy的force的逆 - force强制计算一个lazy数据,并且储存结果,那drop就是去掉储存结果,然后只要我们还有计算用的thunk,就大功告成。于是,我们只需要选择一个‘force后保留原thunk’的lazy实现则可。

module LazyOption : Lazy with type ' a t = ' a option ref * ( unit -> ' a ) = LazyImpl ( struct type ' a t = ' a option ref * ( unit -> ' a ) let mk f = ( ref None , f ) let get ( o , f ) = match ! o with None -> let ret = f () in o := Some ret ; ret | Some x -> x end );;

那,就到了最后也是最重要的问题了 - 选那个数据drop?什么时候drop?

什么时候drop有一个很简单的做法 - 当我们内存不够(OS疯狂swap/runtime疯狂gc/allocator直接罢工)的时候drop就可以了,其他时候不存在。这样的话,当内存足够的时候,我们只会在一旁看戏,只有内存告急,才会像安全网一样空间换时间。

那 - 我们drop什么数据呢?我们思考一下,CS有啥地方需要‘考虑存什么数据不存什么数据’的 - 答案是caching: 硬件的memory hierarchy, os的page table, 软件层的software cache...

那我们可以像cache这样,给所有的recomputable算‘储存偏好’,在内存不够的时候去掉认为最不该存下的数据。那,我们用最经典的LRU cache可不可行?

桥豆麻袋!我们思考一下一般的caching跟我们的caching有没有不同的地方。在一个page table内,所有数据的大小都是一样的(page size),而且所有数据的计算开销都差不多(读写磁盘)。

但这,在深度学习中不成立。一个relu会比conv2d快数量级以上,pooling以前的tensor也比pooling以后的tensor大得多。所以,需要在LRU的基础上,再算上计算时间跟内存开销。

把这三个数都一起思考,我们可以得出compute / (memory * staleness)这个代表‘重计算单位收益开销’的数。我们可以搜索这个值最低的recomputable,并且drop - 哦,不对,既然是cache,就该叫evict了 - 之。

这时候,就麻雀虽小五脏俱全了。我们再检查一下这整套系统性能上有没有明显问题 - 有没有东东会成为瓶颈,有没有两个东东之间会带来不良交互。

可以发现很明显的三个问题 -

0:每次malloc,都可能触发evict,而evict需要全局搜索,这导致O(n)的程序会变成O(n2)的。幸运的是,深度学习中tensor数量不多,于是n2也可以顶。但是,依然可以做一些trick优化降低搜索开销 - 比如每次只搜索sqrt(n)这么多的tensor,或者比如有些tensor(太小的,或者计算量太多的)不搜索。

1:cache policy跟recursive thunk不兼容 - 如果重计算X需要触发一系列的重计算,我们的compute只会算上X。我们并不能简单的去递归计算compute来修复这一问题(因为这样大O又多了一个n)。我们需要上incremental algorithm,计算下递归的cost存下来,而且在数据(因为evict或者get)更新的时候更新。一个刚好挺合适的数据结构是union find - 我们把各个联通的evict的节点当成union find内同一集合的元素,就可以了。这是近似的,跟暴力算出来的结果不一样,但是没关系,cache的特点就是很健壮,policy不‘对’不影响结果。

2:整个系统跟传统的memory management(GC/Region Based Memory Management/Ref Counting)冲突 - 由于变成了一个图,加入了backpointer,那怕本来数据用完了,可以free,为了重计算,所有的过往数据都会被储存下来。这其实挺好解决 - 进行GC/etc时我们假装所有recomputable都没有内部的指针,但是当我们发现recomputable可以被回收的时候,不进行回收而仅仅是evict - 换句话说,我们把所有‘本应发生的GC’改成evict则可。

ok,这时候我们整个系统就设计完毕了。这时候,我们来思考一下eval - 我们最初的设计初衷是‘decouple recompute from ad’。那我们思考一下这新设计有什么更现实的价值。我们也继续以前的操作,也就是从给定性质为基础,进行推导。

0:我们不是ad - 这代表,我们可以在比如纯inference条件下使用,或者在不是深度学习框架/自动求导工具外使用

1:我们是caching - 而这带来了多个好处

1.0:这导致我们的算法是adaptive的 - 内存越多,重计算越小,跑得越快,甚至我们可以在运行时动态调整,从而支持multi tenancy场景。

1.1:也可以灵活设计cache policy,在里面简单的融入memory/compute等常数项 - 而treeverse跟graph based algorithm则往往忽略了这些常数项,转去分析程序的structure,从而承受忽略数量级常数项带来的性能损失 - 这可以看checkmate的evaluation,传统方法往往比checkmate差一截,但我们的方法却跟checkmate这最优算法性能接近。

1.2:cache是一种local又很快的东东。这代表我们可以很容易支持动态图框架,跟动态神经网络。同时,cache的速度又保证把这移动到运行时并不会构成瓶颈。

到这时候,我们就从最原始的,一些理论上细微的不协调,推出了一个有不错现实价值的系统。

congrats, you just invented DTR。