本文介紹Keras一些常見的驗證和調參技巧,快速地驗證模型和調節超參。
小技巧:
- CSV數據文件加載
- Dense初始化警告
驗證與調參:
- 模型驗證(Validation)
- K重交叉驗證(K-fold Cross-Validation)
- 網格搜索驗證(Grid Search Cross-Validation)
Keras
CSV數據文件加載
使用NumPy的 loadtxt() 方法加載CSV數據文件
- delimiter:數據單元的分割符;
- skiprows:略過首行標題;
Dense初始化警告
Dense初始化參數的警告:
將init參數替換為 kernel_initializer 參數即可。
模型驗證
在 fit() 中 自動 劃分驗證集:
通過設置參數 validation_split 的值(0~1)確定驗證集的比例。
實現:
在
fit() 中 手動 劃分驗證集:train_test_split 來源sklearn.model_selection:
- test_size :驗證集的比例;
- random_state :隨機數的種子;
通過參數 validation_data 添加驗證數據,格式是 數據+標籤 的元組。
實現:
交叉驗證
K重交叉驗證(K-fold Cross-Validation)是常見的模型評估統計。
人工模式
交叉驗證函數 StratifiedKFold() 來源於sklearn.model_selection:
- n_splits :交叉的重數,即N重交叉驗證;
- shuffle :數據和標籤是否隨機洗牌;
- random_state :隨機數種子;
- skf.split(X, y) :劃分數據和標籤的索引。
cvscores用於統計K重交叉驗證的結果,計算均值和方差。
實現:
輸出:
Wrapper模式
通過 cross_val_score() 函數集成模型和交叉驗證邏輯。
- 將模型封裝成wrapper,注意使用 內置函數 ,而 非 調用,沒有括號 () 。
- epochs 即輪次, batch_size 即批次數;
- StratifiedKFold是K重交叉驗證的邏輯;
cross_val_score 的輸入是模型wrapper、數據X、標籤Y、交叉驗證cv;輸出是每次驗證的結果,再計算均值和方差。
實現:
輸出:
網格搜索驗證
網格搜索驗證(Grid Search Cross-Validation)用於選擇模型的最優超參值。
交叉驗證函數 GridSearchCV() 來源於sklearn.model_selection:
- 設置超參列表,如optimizers、 init_modes 、epochs、batches;
- 創建參數字典,key值是模型的參數,或者wrapper的參數;
- estimator是模型, param_grid 是網格參數字典, n_jobs 是進程數;
- 輸出最優結果和其他排列組合結果。
實現:
輸出:
閱讀更多 sandag 的文章
關鍵字: 驗證 調參 Validation