你聽說過checkpointing嗎?
checkpointing透過在前向傳播的時候,丟棄一些計算好的結果,而在反向傳播需要的時候重新計算這些結果,從而節省中間的記憶體消耗。
這類別的工作有很多,經典的比如萬惡之源Treeverse,在DL界更多人知道的Sublinear memory cost,首個自動對一切程式進行checkpointing的Checkpointing for Arbitrary Programs,rnn特化版的Memory-Efficient Backpropagation Through Time。
這裏面,有一個不是問題的‘問題’ - checkpointing是一項‘利用重計算節省記憶體’的手段,不過這好像。。跟自動求導沒啥關系啊?
用一個比較軟件工程的說法:自動求導是一種尋找導數的演算法,重計算是一種節省記憶體的方法,為什麽我們在checkpointing裏面,把這兩個耦合起來呢?換句話說,我們能不能找一個 跟自動求導無關 的X,使得X + 自動求導 = checkpointing?
(P.S. Checkmate的確符合X,但是Checkmate僅僅是把問題粗暴的丟進一個ILP solver,不能帶來任何insight,也沒辦法scale,我們看看還有沒有其他方法?)
那我們試著倒推一下 - 假設我們實作了一套自動重計算的系統,這套系統的API,該如何設計呢?
首先,既然重計算不關心裏面存著什麽數據,是如何算出的(只需要忠實地重新執行計算操作就好),那我們知道,這是一個可以裝各種類別數據的泛型容器。
data
Recomputable
a
=
???
在這上面,最基礎的操作自然是獲得裏面的數據 - 或者直接獲得,或者重新計算下儲存下來,然後獲得之。
get
::
Recomputable
a
->
a
這過程跟haskell的thunk不謀而合。同時,重計算的過程中可能會呼叫其他Recomputable的get,從而導致遞迴的重計算。這是可組合性下自然誕生的產物,是充話費送的,不需要刻意支持。
那,我們如何構造一個Recomputable呢?我們可以從一個a構造,作為一個不可重計算的輸入:
pure
::
a
->
Recomputable
a
也可以由多個Recomputable數據組合而來 - 之所以是多個而不是一個,是因為計算可能會分叉,然後合並,於是會形成一個圖而不是鏈或者樹。
(
<*>
)
::
Recomputable
(
a
->
b
)
->
Recomputable
a
->
Recomputable
b
這剛好形成一個Applicative,性質也很好驗證 - Recomputable在外部看來,應該跟
newtype
Identity
a
=
Identity
{
runIdentity
::
a
}
是一樣的 - 所有的區別,都只是我們會在有需要的時候,把數據去掉 - 但是get的時候會透過重計算原樣返回。這只會產生效能上的差異,並不會更改語意。
OK,那我們還需要考慮如何去掉(drop)一個Recomputable的數據。
回想起Recomputable跟Lazy很像,drop這個操作就是Lazy的force的逆 - force強制計算一個lazy數據,並且儲存結果,那drop就是去掉儲存結果,然後只要我們還有計算用的thunk,就大功告成。於是,我們只需要選擇一個‘force後保留原thunk’的lazy實作則可。
module
LazyOption
:
Lazy
with
type
'
a
t
=
'
a
option
ref
*
(
unit
->
'
a
)
=
LazyImpl
(
struct
type
'
a
t
=
'
a
option
ref
*
(
unit
->
'
a
)
let
mk
f
=
(
ref
None
,
f
)
let
get
(
o
,
f
)
=
match
!
o
with
None
->
let
ret
=
f
()
in
o
:=
Some
ret
;
ret
|
Some
x
->
x
end
);;
那,就到了最後也是最重要的問題了 - 選那個數據drop?什麽時候drop?
什麽時候drop有一個很簡單的做法 - 當我們記憶體不夠(OS瘋狂swap/runtime瘋狂gc/allocator直接罷工)的時候drop就可以了,其他時候不存在。這樣的話,當記憶體足夠的時候,我們只會在一旁看戲,只有記憶體告急,才會像安全網一樣空間換時間。
那 - 我們drop什麽數據呢?我們思考一下,CS有啥地方需要‘考慮存什麽數據不存什麽數據’的 - 答案是caching: 硬件的memory hierarchy, os的page table, 軟件層的software cache...
那我們可以像cache這樣,給所有的recomputable算‘儲存偏好’,在記憶體不夠的時候去掉認為最不該存下的數據。那,我們用最經典的LRU cache可不可行?
橋豆麻袋!我們思考一下一般的caching跟我們的caching有沒有不同的地方。在一個page table內,所有數據的大小都是一樣的(page size),而且所有數據的計算開銷都差不多(讀寫磁盤)。
但這,在深度學習中不成立。一個relu會比conv2d快數量級以上,pooling以前的tensor也比pooling以後的tensor大得多。所以,需要在LRU的基礎上,再算上計算時間跟記憶體開銷。
把這三個數都一起思考,我們可以得出compute / (memory * staleness)這個代表‘重計算單位收益開銷’的數。我們可以搜尋這個值最低的recomputable,並且drop - 哦,不對,既然是cache,就該叫evict了 - 之。
這時候,就麻雀雖小五臟俱全了。我們再檢查一下這整套系統效能上有沒有明顯問題 - 有沒有東東會成為瓶頸,有沒有兩個東東之間會帶來不良互動。
可以發現很明顯的三個問題 -
0:每次malloc,都可能觸發evict,而evict需要全域搜尋,這導致O(n)的程式會變成O(n2)的。幸運的是,深度學習中tensor數量不多,於是n2也可以頂。但是,依然可以做一些trick最佳化降低搜尋開銷 - 比如每次只搜尋sqrt(n)這麽多的tensor,或者比如有些tensor(太小的,或者計算量太多的)不搜尋。
1:cache policy跟recursive thunk不相容 - 如果重計算X需要觸發一系列的重計算,我們的compute只會算上X。我們並不能簡單的去遞迴計算compute來修復這一問題(因為這樣大O又多了一個n)。我們需要上incremental algorithm,計算下遞迴的cost存下來,而且在數據(因為evict或者get)更新的時候更新。一個剛好挺合適的數據結構是union find - 我們把各個聯通的evict的節點當成union find內同一集合的元素,就可以了。這是近似的,跟暴力算出來的結果不一樣,但是沒關系,cache的特點就是很健壯,policy不‘對’不影響結果。
2:整個系統跟傳統的memory management(GC/Region Based Memory Management/Ref Counting)沖突 - 由於變成了一個圖,加入了backpointer,那怕本來數據用完了,可以free,為了重計算,所有的過往數據都會被儲存下來。這其實挺好解決 - 進行GC/etc時我們假裝所有recomputable都沒有內部的指標,但是當我們發現recomputable可以被回收的時候,不進行回收而僅僅是evict - 換句話說,我們把所有‘本應發生的GC’改成evict則可。
ok,這時候我們整個系統就設計完畢了。這時候,我們來思考一下eval - 我們最初的設計初衷是‘decouple recompute from ad’。那我們思考一下這新設計有什麽更現實的價值。我們也繼續以前的操作,也就是從給定性質為基礎,進行推導。
0:我們不是ad - 這代表,我們可以在比如純inference條件下使用,或者在不是深度學習框架/自動求導工具外使用
1:我們是caching - 而這帶來了多個好處
1.0:這導致我們的演算法是adaptive的 - 記憶體越多,重計算越小,跑得越快,甚至我們可以在執行時動態調整,從而支持multi tenancy場景。
1.1:也可以靈活設計cache policy,在裏面簡單的融入memory/compute等常數項 - 而treeverse跟graph based algorithm則往往忽略了這些常數項,轉去分析程式的structure,從而承受忽略數量級常數項帶來的效能損失 - 這可以看checkmate的evaluation,傳統方法往往比checkmate差一截,但我們的方法卻跟checkmate這最優演算法效能接近。
1.2:cache是一種local又很快的東東。這代表我們可以很容易支持動態圖框架,跟動態神經網絡。同時,cache的速度又保證把這移動到執行時並不會構成瓶頸。
到這時候,我們就從最原始的,一些理論上細微的不協調,推出了一個有不錯現實價值的系統。
congrats, you just invented DTR。