當前位置: 華文星空 > 知識

Pytorch有什麽節省視訊記憶體的小技巧?

2020-04-28知識

你聽說過checkpointing嗎?

checkpointing透過在前向傳播的時候,丟棄一些計算好的結果,而在反向傳播需要的時候重新計算這些結果,從而節省中間的記憶體消耗。

這類別的工作有很多,經典的比如萬惡之源Treeverse,在DL界更多人知道的Sublinear memory cost,首個自動對一切程式進行checkpointing的Checkpointing for Arbitrary Programs,rnn特化版的Memory-Efficient Backpropagation Through Time。

這裏面,有一個不是問題的‘問題’ - checkpointing是一項‘利用重計算節省記憶體’的手段,不過這好像。。跟自動求導沒啥關系啊?

用一個比較軟件工程的說法:自動求導是一種尋找導數的演算法,重計算是一種節省記憶體的方法,為什麽我們在checkpointing裏面,把這兩個耦合起來呢?換句話說,我們能不能找一個 跟自動求導無關 的X,使得X + 自動求導 = checkpointing?

(P.S. Checkmate的確符合X,但是Checkmate僅僅是把問題粗暴的丟進一個ILP solver,不能帶來任何insight,也沒辦法scale,我們看看還有沒有其他方法?)

那我們試著倒推一下 - 假設我們實作了一套自動重計算的系統,這套系統的API,該如何設計呢?

首先,既然重計算不關心裏面存著什麽數據,是如何算出的(只需要忠實地重新執行計算操作就好),那我們知道,這是一個可以裝各種類別數據的泛型容器。

data Recomputable a = ???

在這上面,最基礎的操作自然是獲得裏面的數據 - 或者直接獲得,或者重新計算下儲存下來,然後獲得之。

get :: Recomputable a -> a

這過程跟haskell的thunk不謀而合。同時,重計算的過程中可能會呼叫其他Recomputable的get,從而導致遞迴的重計算。這是可組合性下自然誕生的產物,是充話費送的,不需要刻意支持。

那,我們如何構造一個Recomputable呢?我們可以從一個a構造,作為一個不可重計算的輸入:

pure :: a -> Recomputable a

也可以由多個Recomputable數據組合而來 - 之所以是多個而不是一個,是因為計算可能會分叉,然後合並,於是會形成一個圖而不是鏈或者樹。

( <*> ) :: Recomputable ( a -> b ) -> Recomputable a -> Recomputable b

這剛好形成一個Applicative,性質也很好驗證 - Recomputable在外部看來,應該跟

newtype Identity a = Identity { runIdentity :: a }

是一樣的 - 所有的區別,都只是我們會在有需要的時候,把數據去掉 - 但是get的時候會透過重計算原樣返回。這只會產生效能上的差異,並不會更改語意。

OK,那我們還需要考慮如何去掉(drop)一個Recomputable的數據。

回想起Recomputable跟Lazy很像,drop這個操作就是Lazy的force的逆 - force強制計算一個lazy數據,並且儲存結果,那drop就是去掉儲存結果,然後只要我們還有計算用的thunk,就大功告成。於是,我們只需要選擇一個‘force後保留原thunk’的lazy實作則可。

module LazyOption : Lazy with type ' a t = ' a option ref * ( unit -> ' a ) = LazyImpl ( struct type ' a t = ' a option ref * ( unit -> ' a ) let mk f = ( ref None , f ) let get ( o , f ) = match ! o with None -> let ret = f () in o := Some ret ; ret | Some x -> x end );;

那,就到了最後也是最重要的問題了 - 選那個數據drop?什麽時候drop?

什麽時候drop有一個很簡單的做法 - 當我們記憶體不夠(OS瘋狂swap/runtime瘋狂gc/allocator直接罷工)的時候drop就可以了,其他時候不存在。這樣的話,當記憶體足夠的時候,我們只會在一旁看戲,只有記憶體告急,才會像安全網一樣空間換時間。

那 - 我們drop什麽數據呢?我們思考一下,CS有啥地方需要‘考慮存什麽數據不存什麽數據’的 - 答案是caching: 硬件的memory hierarchy, os的page table, 軟件層的software cache...

那我們可以像cache這樣,給所有的recomputable算‘儲存偏好’,在記憶體不夠的時候去掉認為最不該存下的數據。那,我們用最經典的LRU cache可不可行?

橋豆麻袋!我們思考一下一般的caching跟我們的caching有沒有不同的地方。在一個page table內,所有數據的大小都是一樣的(page size),而且所有數據的計算開銷都差不多(讀寫磁盤)。

但這,在深度學習中不成立。一個relu會比conv2d快數量級以上,pooling以前的tensor也比pooling以後的tensor大得多。所以,需要在LRU的基礎上,再算上計算時間跟記憶體開銷。

把這三個數都一起思考,我們可以得出compute / (memory * staleness)這個代表‘重計算單位收益開銷’的數。我們可以搜尋這個值最低的recomputable,並且drop - 哦,不對,既然是cache,就該叫evict了 - 之。

這時候,就麻雀雖小五臟俱全了。我們再檢查一下這整套系統效能上有沒有明顯問題 - 有沒有東東會成為瓶頸,有沒有兩個東東之間會帶來不良互動。

可以發現很明顯的三個問題 -

0:每次malloc,都可能觸發evict,而evict需要全域搜尋,這導致O(n)的程式會變成O(n2)的。幸運的是,深度學習中tensor數量不多,於是n2也可以頂。但是,依然可以做一些trick最佳化降低搜尋開銷 - 比如每次只搜尋sqrt(n)這麽多的tensor,或者比如有些tensor(太小的,或者計算量太多的)不搜尋。

1:cache policy跟recursive thunk不相容 - 如果重計算X需要觸發一系列的重計算,我們的compute只會算上X。我們並不能簡單的去遞迴計算compute來修復這一問題(因為這樣大O又多了一個n)。我們需要上incremental algorithm,計算下遞迴的cost存下來,而且在數據(因為evict或者get)更新的時候更新。一個剛好挺合適的數據結構是union find - 我們把各個聯通的evict的節點當成union find內同一集合的元素,就可以了。這是近似的,跟暴力算出來的結果不一樣,但是沒關系,cache的特點就是很健壯,policy不‘對’不影響結果。

2:整個系統跟傳統的memory management(GC/Region Based Memory Management/Ref Counting)沖突 - 由於變成了一個圖,加入了backpointer,那怕本來數據用完了,可以free,為了重計算,所有的過往數據都會被儲存下來。這其實挺好解決 - 進行GC/etc時我們假裝所有recomputable都沒有內部的指標,但是當我們發現recomputable可以被回收的時候,不進行回收而僅僅是evict - 換句話說,我們把所有‘本應發生的GC’改成evict則可。

ok,這時候我們整個系統就設計完畢了。這時候,我們來思考一下eval - 我們最初的設計初衷是‘decouple recompute from ad’。那我們思考一下這新設計有什麽更現實的價值。我們也繼續以前的操作,也就是從給定性質為基礎,進行推導。

0:我們不是ad - 這代表,我們可以在比如純inference條件下使用,或者在不是深度學習框架/自動求導工具外使用

1:我們是caching - 而這帶來了多個好處

1.0:這導致我們的演算法是adaptive的 - 記憶體越多,重計算越小,跑得越快,甚至我們可以在執行時動態調整,從而支持multi tenancy場景。

1.1:也可以靈活設計cache policy,在裏面簡單的融入memory/compute等常數項 - 而treeverse跟graph based algorithm則往往忽略了這些常數項,轉去分析程式的structure,從而承受忽略數量級常數項帶來的效能損失 - 這可以看checkmate的evaluation,傳統方法往往比checkmate差一截,但我們的方法卻跟checkmate這最優演算法效能接近。

1.2:cache是一種local又很快的東東。這代表我們可以很容易支持動態圖框架,跟動態神經網絡。同時,cache的速度又保證把這移動到執行時並不會構成瓶頸。

到這時候,我們就從最原始的,一些理論上細微的不協調,推出了一個有不錯現實價值的系統。

congrats, you just invented DTR。