PyTorch已經足夠簡單易用,但是簡單易用不等於方便快捷。特別是做大量實驗的時候,很多東西都會變得複雜,代碼也會變得龐大,這時候就容易出錯。
針對這個問題,就有了PyTorch Lightning。它可以重構你的PyTorch代碼,抽出複雜重複部分,讓你專注於核心的構建,讓你的實驗更快速更便捷地開展迭代。
1. Lightning 簡約哲學
大部分的DL/ML代碼都可以分為以下這三部分:
- 研究代碼 Research code
- 工程代碼 Engineering code
- 非必要代碼 Non-essential code
1.1 研究代碼 Research code
這部分屬於模型(神經網絡)部分,一般處理模型的結構、訓練等定製化部分。
在Linghtning中,這部分代碼抽象為 LightningModule 類。
1.2 工程代碼 Engineering code
這部分代碼很重要的特點是:重複性強,比如說設置early stopping、16位精度、GPUs分佈訓練。
在Linghtning中,這部分抽象為 Trainer 類。
1.3 非必要代碼 Non-essential code
這部分代碼有利於實驗的進行,但是和實驗沒有直接關係,甚至可以不使用。比如說檢查梯度、給tensorboard輸出log。
在Linghtning中,這部分抽象為 Callbacks 類。
2. 典型的AI研究項目
在大多數研究項目中,研究代碼 通常可以歸納到以下關鍵部分:
- 模型
- 訓練/驗證/測試 數據
- 優化器
- 訓練/驗證/測試 計算
上面已經提到,研究代碼 在 Lightning 中,是抽象為 LightningModule 類;而這個類與我們平時使用的 torch.nn.Module 是一樣的(在原有代碼中直接替換 Module 而不改其他代碼也是可以的),但不同的是,Lightning 圍繞 torch.nn.Module 做了很多功能性的補充,把上面4個關鍵部分都囊括了進來。
這麼做的意義在於:我們的 研究代碼 都是圍繞 我們的神經網絡模型 來運行的,所以 Lightning 把這部分代碼都集合在一個類裡。
所以我們接下來的介紹,都是圍繞 LightningModule 類來展開。
3. 生命週期
為了讓大家先有一個總體的概念,在這裡我先讓大家清楚 LightningModule 中運行的生命流程。
以下所有的函數,都是在 LightningModule 類 裡。
這部分是訓練開始之後的執行 “一般(默認)順序”。
- 首先是準備工作,包括初始化 LightningModule,準備數據 和 配置優化器。
這部分代碼 只執行一次。
<code>1. `__init__()`(初始化 LightningModule )
2. `prepare_data()` (準備數據,包括下載數據、預處理等等)
3. `configure_optimizers()` (配置優化器)/<code>
- 測試 “驗證代碼”。
提前來做的意義在於:不需要等待漫長的訓練過程才發現驗證代碼有錯。這部分就是提前執行 “驗證代碼”,所以和下面的驗證部分是一樣的。
<code>1. `val_dataloader()`
2. `validation_step()`
3. `validation_epoch_end()`/<code>
- 開始加載dataloader,用來給訓練加載數據
<code>1. `train_dataloader()`
2. `val_dataloader()` (如果你定義了)/<code>
- 下面部分就是循環訓練了,_step() 的意思就是按batch來進行的部分;_epoch_end() 就是所有batch執行完後要進行的部分。
<code># 循環訓練與驗證
1. `training_step()`
2. `validation_step()`
3. `validation_epoch_end()`/<code>
- 最後訓練完了,就要進行測試,但測試部分需要手動調用 .test(),這是為了避免誤操作。
<code># 測試(需要手動調用)
1. `test_dataloader()`
2. `test_step()`
3. `test_epoch_end()`
/<code>
在這裡,我們很容易總結出,在訓練部分,主要是三部分:_dataloader/_step/_epoch_end。Lightning把訓練的三部分抽象成三個函數,而我們只需要“填鴨式”地補充這三部分,就可以完成模型訓練部分代碼的編寫。
為了讓大家更清晰地瞭解這三部分的具體位置,下面用 PyTorch實現方式 來展現其位置。
<code>for epoch in epochs:
for batch in train_dataloader:
# train_step
# ....
# train_step
loss.backward()
optimizer.step()
optimizer.zero_grad()
for batch in val_dataloader:
# validation_step
# ....
# validation_step
# *_step_end
# ....
# *_step_end
/<code>
4. 使用Lightning的好處
- 只需要專注於 研究代碼
不需要寫一大堆的 .cuda() 和 .to(device),Lightning會幫你自動處理。如果要新建一個tensor,可以使用type_as來使得新tensor處於相同的處理器上。
<code>def training_step(self, batch, batch_idx):
x, y = batch
# 把z放在和x一樣的處理器上
z = sample_noise()
z = z.type_as(x)/<code>
在這裡,有個地方需要注意的是,不是所有的在LightningModule 的 tensor 都會被自動處理,而是隻有從 Dataloader 裡獲取的 tensor 才會被自動處理,所以對於 transductive learning 的訓練,最好自己寫Dataloader的處理函數。
- 工程代碼參數化
平時我們寫模型訓練的時候,這部分代碼會不斷重複,但又不得不做,不如說ealy stopping,精度的調整,顯存內存之間的數據轉移。這部分代碼雖然不難,但減少這部分代碼會使得 研究代碼 更加清晰,整體也更加簡潔。
下面是簡單的展示,表示使用 LightningModule 建立好模型後,如何進行訓練。
<code>model = LightningModuleClass()
trainer = pl.Trainer(gpus="0", # 用來配置使用什麼GPU
precision=32, # 用來配置使用什麼精度,默認是32
max_epochs=200 # 迭代次數
)
trainer.fit(model) # 開始訓練
trainer.test() # 訓練完之後測試
/<code>
結語
以上就是我對於 PyTorch Lightning 的入門總結,自己在這裡也走了很多坑,也把官方文檔過了一遍,但我的目的不是仿照官方文檔翻譯一遍,而是希望有自己的實踐體會和相對於官方文檔的規範更直觀。
閱讀更多 圖網絡與機器學習 的文章