前回の記事「NVIDIA DIGITSでファインチューニングしてみた」で作成したalexnet_finetune.prototxtを例に、NVIDIA DIGITS用のprototxtについて説明します。
NVIDIA DIGITS用のprototxt
Caffeでは、トレーニング用のsolver.prototxtとtrain_val.prototxt、推論用のdeploy.prototxtの3種類のprototxtを使用します。
これに対して、NVIDIA DIGITSでは、トレーニング用のtrain_val.prototxtと推論用のdeploy.prototxtを一つにまとめたprototxtを用意しておき、DIGIT上の「New Image Classification Model」ページ等での設定に対応したsolver.prototxtを生成すると共に、一つにまとめたprototxtから、トレーニング用のtrain_val.prototxtと推論用のdeploy.prototxtを生成します。
また、DIGITSで生成されるtrain_val.prototxtとdeploy.prototxtは、Datasetのクラス数(又は、カテゴリー数)に合わせて自動的に設定されます。
alexnet_finetune.prototxt
alexnet_finetune.prototxtは、DIGITSのスタンダードネットワーク(https://github.com/NVIDIA/DIGITS/tree/digits-4.0/digits/standard-networks/caffe)のalexnet.prototxtをファインチューニング用に変更したもので、差分は以下のようになります。
$ diff -u alexnet.prototxt alexnet_finetune.prototxt --- alexnet.prototxt 2017-02-12 13:32:48.254066528 +0900 +++ alexnet_finetune.prototxt 2017-02-12 13:45:48.018104266 +0900 @@ -1,5 +1,5 @@ -# AlexNet -name: "AlexNet" +# AlexNet Fine Tuning +name: "AlexNet-Finetune" layer { name: "train-data" type: "Data" @@ -331,16 +331,16 @@ } } layer { - name: "fc8" + name: "fc8_ft" type: "InnerProduct" bottom: "fc7" - top: "fc8" + top: "fc8_ft" param { - lr_mult: 1 + lr_mult: 10 decay_mult: 1 } param { - lr_mult: 2 + lr_mult: 20 decay_mult: 0 } inner_product_param { @@ -361,15 +361,16 @@ layer { name: "accuracy" type: "Accuracy" - bottom: "fc8" + bottom: "fc8_ft" bottom: "label" top: "accuracy" + include { stage: "train" } include { stage: "val" } } layer { name: "loss" type: "SoftmaxWithLoss" - bottom: "fc8" + bottom: "fc8_ft" bottom: "label" top: "loss" exclude { stage: "deploy" } @@ -377,7 +378,7 @@ layer { name: "softmax" type: "Softmax" - bottom: "fc8" + bottom: "fc8_ft" top: "softmax" include { stage: "deploy" } }
'fc8 -> fc8_ft'
は、最終レイヤを新たに学習させるための設定、'lr_mult: 1 -> 10'
と'lr_mult: 2 -> 20'
は、最終レイヤを他のレイヤよりも早く学習させるための設定です。
また、accuracyレイヤへの'include { stage: "train" }'
の追加は、Validationだけではなく、TrainのAccuracyも表示させるための設定です。
まとめ
前回の記事「NVIDIA DIGITSでファインチューニングしてみた」で作成したalexnet_finetune.prototxtを例に、NVIDIA DIGITS用のプロトテキストについて説明しました。