K重交叉驗證 和 網格搜索驗證

本文介紹Keras一些常見的驗證和調參技巧,快速地驗證模型和調節超參。

小技巧:

CSV數據文件加載Dense初始化警告

驗證與調參:

模型驗證(Validation)K重交叉驗證(K-fold Cross-Validation)網格搜索驗證(Grid Search Cross-Validation)

Keras

CSV數據文件加載

使用NumPy的 loadtxt() 方法加載CSV數據文件

delimiter:數據單元的分割符;skiprows:略過首行標題;

Dense初始化警告

Dense初始化參數的警告:

將init參數替換為 kernel_initializer 參數即可。

模型驗證

fit()自動 劃分驗證集:

通過設置參數 validation_split 的值(0~1)確定驗證集的比例。

實現:

fit()手動 劃分驗證集:

train_test_split 來源sklearn.model_selection:

test_size :驗證集的比例;random_state :隨機數的種子;

通過參數 validation_data 添加驗證數據,格式是 數據+標籤 的元組。

實現:

交叉驗證

K重交叉驗證(K-fold Cross-Validation)是常見的模型評估統計。

人工模式

交叉驗證函數 StratifiedKFold() 來源於sklearn.model_selection:

n_splits :交叉的重數,即N重交叉驗證;shuffle :數據和標籤是否隨機洗牌;random_state :隨機數種子;skf.split(X, y) :劃分數據和標籤的索引。

cvscores用於統計K重交叉驗證的結果,計算均值和方差。

實現:

輸出:

Wrapper模式

通過 cross_val_score() 函數集成模型和交叉驗證邏輯。

將模型封裝成wrapper,注意使用 內置函數 ,而 調用,沒有括號 () 。epochs 即輪次, batch_size 即批次數;StratifiedKFold是K重交叉驗證的邏輯;

cross_val_score 的輸入是模型wrapper、數據X、標籤Y、交叉驗證cv;輸出是每次驗證的結果,再計算均值和方差。

實現:

輸出:

網格搜索驗證

網格搜索驗證(Grid Search Cross-Validation)用於選擇模型的最優超參值。

交叉驗證函數 GridSearchCV() 來源於sklearn.model_selection:

設置超參列表,如optimizers、 init_modes 、epochs、batches;創建參數字典,key值是模型的參數,或者wrapper的參數;estimator是模型, param_grid 是網格參數字典, n_jobs 是進程數;輸出最優結果和其他排列組合結果。

實現:

輸出: