モバイル&ワイヤレスブロードバンドでインターネットへ

gwaw.jp

tfjs で Early Stopping。スマホのテキストエディタで。

TensorFlow.js (tfjs) ブラウザと Node.js 内で機械学習モデルの訓練とデプロイを行うための JavaScript ライブラリです。

https://www.tensorflow.org/js/tutorials

今回は、この tfjs Early Stopping を試します。

機械学習では、訓練データを繰り返し学習することで誤差を小さくしていきますが、その学習回数が多すぎると、途中から検証データにおいて誤差が大きくなるという問題があります。これは、オーバーフィッティング ( Overfitting ) 、または、過学習といわれています。この問題に対するテクニックが Early Stopping (早期終了)です。

現時点で最新の tfjs バージョン 1.1.2 では、 tf.callbacks.earlyStopping() が用意されています。この API Reference の example で動作が確認できます。

次のコードは、その example を HTML ドキュメントに展開して、async の即時関数 Immediately Invoked Function Expression ( IIFE ) により実行できるようにしてみました。console.log() の出力にも対応して、動作が確認できます。



<!DOCTYPE html>
<html lang="ja">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>サンプル</title>

<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"> </script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-vis"></script>

<script>
var _debug;
var _dct = 0;
function _debugStart(){
  _debug = document.getElementById("DEBUG");
  _debug.textContent = "";
  _debugHTML("DEBUG START");
}
function _debugHTML(_deHTML){
  _debug.insertAdjacentHTML("afterbegin", "<div><em>" + ++_dct + " : </em>" + _deHTML + "<div>");
}
console.log = function(_dobj){
  _debugHTML(_dobj);
};
</script>
</head>
<body>

<div id="DEBUG"></div>

<script>
_debugStart();
_debugHTML("tfjs Early Stopping START");

// 即時関数 Immediately Invoked Function Expression ( IIFE )
(async () => {

  const model = tf.sequential();
  model.add(tf.layers.dense({
    units: 3,
    activation: "softmax",
    kernelInitializer: "ones",
    inputShape: [2]
  }));
  const xs = tf.tensor2d([1, 2, 3, 4], [2, 2]);
  const ys = tf.tensor2d([[1, 0, 0], [0, 1, 0]], [2, 3]);
  const xsVal = tf.tensor2d([4, 3, 2, 1], [2, 2]);
  const ysVal = tf.tensor2d([[0, 0, 1], [0, 1, 0]], [2, 3]);
  model.compile(
    {loss: "categoricalCrossentropy", optimizer: "sgd", metrics: ["acc"]});

// Without the EarlyStopping callback, the val_acc value would be:
//   0.5, 0.5, 0.5, 0.5, ...
// With val_acc being monitored, training should stop after the 2nd epoch.
  const history = await model.fit(xs, ys, {
    epochs: 10,
    validationData: [xsVal, ysVal],
    callbacks: tf.callbacks.earlyStopping({monitor: "val_acc"})
  });

// Expect to see a length-2 array.
  console.log(JSON.stringify(history.history));

})();

_debugHTML("tfjs Early Stopping E N D");
</script>
</body>
</html>

 

エポック数とバッチサイズについて、カンタンにおさらいします。

訓練データの学習回数をエポック ( epoch ) 数といいます。確率的勾配降下法 stochastic gradient descent ( SGD ) は、ランダムにシャッフルした訓練データ1つ1つで計算した勾配からモデルを更新し、それをエポック数だけ繰り返します。

ミニバッチ勾配降下法 mini-batch gradient descent は、訓練データをミニバッチサイズごとで計算した勾配からモデルを更新します。確率的勾配降下法は、このミニバッチ勾配降下法に取り込まれつつ、そう呼ばれていることもあります。

さて、今回使用している環境は、Android スマホのテキストエディタアプリ QuickEdit Text Editor Pro です。プレビュー機能で HTML ドキュメントの JavaScript コードを実行することで確認しています。 QuickEdit Text Editor Pro で TensorFlow.js の CODE SAMPLE FOR SCRIPT TAG SETUP をプレビュー で記事投稿した方法です。文法 Error の Debug はできませんが、console.log を出力するためのコードを追加しています。

次に、 tf.callbacks.earlyStopping() がまだなかったときに作成した Early Stopping のコードです。この記事の最後に紹介している書籍の解説と Python 実装コードを参考にしています。これを tf.callbacks.earlyStopping() の example に使用してみました。

さらに、 tfvis.show.history による可視化で動作を確認します。



<!DOCTYPE html>
<html lang="ja">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>サンプル</title>

<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"> </script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-vis"></script>

<script>
var _debug;
var _dct = 0;
function _debugStart(){
  _debug = document.getElementById("DEBUG");
  _debug.textContent = "";
  _debugHTML("DEBUG START");
}
function _debugHTML(_deHTML){
  _debug.insertAdjacentHTML("afterbegin", "<div><em>" + ++_dct + " : </em>" + _deHTML + "<div>");
}
console.log = function(_dobj){
  _debugHTML(_dobj);
};
</script>
</head>
<body>

<div id="historyChart"></div>

<div id="DEBUG"></div>

<script>
_debugStart();
_debugHTML("tfjs Early Stopping START");

class EarlyStopping{
  constructor(patience=0, verbose=0){
    this._step = 0;
    this._loss = Number.POSITIVE_INFINITY;
    this.patience = patience;
    this.verbose = verbose;
  }
  validate(loss){
    if(this._loss < loss){
      this._step += 1;
      if(this._step > this.patience){
        if(this.verbose){
          console.log("early stopping");
        }
        return true;
      }
    }else{
      this._step = 0;
      this._loss = loss;
    }
    return false;
  }
}

// 即時関数 Immediately Invoked Function Expression ( IIFE )
(async () => {

  const model = tf.sequential();
  model.add(tf.layers.dense({
    units: 3,
    activation: "softmax",
    kernelInitializer: "ones",
    inputShape: [2]
  }));
  const xs = tf.tensor2d([1, 2, 3, 4], [2, 2]);
  const ys = tf.tensor2d([[1, 0, 0], [0, 1, 0]], [2, 3]);
  const xsVal = tf.tensor2d([4, 3, 2, 1], [2, 2]);
  const ysVal = tf.tensor2d([[0, 0, 1], [0, 1, 0]], [2, 3]);
  model.compile(
    {loss: "categoricalCrossentropy", optimizer: "sgd", metrics: ["acc"]});
  const early_stopping = new EarlyStopping(0, 1);
// Without the EarlyStopping callback, the val_acc value would be:
//   0.5, 0.5, 0.5, 0.5, ...
// With val_acc being monitored, training should stop after the 2nd epoch.
  const history = await model.fit(xs, ys, {
    epochs: 10,
    validationData: [xsVal, ysVal],
    callbacks: {
      onEpochEnd: async (epoch, logs) => {
        if(early_stopping.validate(logs["val_loss"])){
          model.stopTraining = true;
        }
      }
    }
  });
// Expect to see a length-2 array.
  console.log(JSON.stringify(history.history));

  const historyContainer = document.getElementById("historyChart");
  tfvis.show.history(
    historyContainer,
    history,
    ["val_acc"],
    {
      width: 300,
      height: 300,
      xLabel: "epoch",
      yLabel: "accuracy",
      xType: "ordinal",
  });
})();

_debugHTML("tfjs Early Stopping E N D");
</script>
</body>
</html>

 

今回参照している書籍です。なお、本書での実装は Python です。

詳解 ディープラーニング ~TensorFlow・Kerasによる時系列データ処理~

Amazon.co.jp
商品詳細リンク

『tfjs で Early Stopping。スマホのテキストエディタで。』を公開しました。