「Deep Learning Javaプログラミング」で気になったこと

昨日(12/17)、Java読書会BOF:『「Deep Learning Javaプログラミング 深層学習の理論と実装」を読む会』に参加しました。その本に載っているコードについて読書会中に考えていたことを自分用メモをかねてまとめました。

pp.45-46にある、Perceptronsクラスの一部
p.46のtrain()メソッド

public int train(double[] x, int t, double learningRate) {

    int classified = 0;
    double c = 0.;

    // check if the data is classified correctly
    for (int i = 0; i < nIn; i++) {
        c += w[i] * x[i] * t;
    }

    // apply gradient descent method if the data is wrongly classified
    if (c > 0) {
        classified = 1;
    } else {
        for (int i = 0; i < nIn; i++) {
            w[i] += learningRate * x[i] * t;
        }
    }

    return classified;
}

と、p.45のmain()内処理

while (true) {
    int classified_ = 0;

    for (int i=0; i < train_N; i++) {
        classified_ += classifier.train(train_X[i], train_T[i], learningRate);
    }

    if (classified_ == train_N) break;  // when all data classified correctly

    epoch++;
    if (epoch > epochs) break;
}

です。

p.46のclassifiedは、p.42下方の説明にあるように、p.43の式(2.5.4)が成り立つときは入力データが正しく分類されたという意味の「フラグ」として使用されています。
しかしp.45では、classifier.train()の戻りを計算可能な「数値」として使用し、classified_に加算しています。そして、classified_とループ回数(train_N)が同じ値であれば全て正しく分類されたと判定しています。

「フラグ」をintで表すところもですが、それを「数値」として扱うところにもやもやしてしまいます。

自分ならこんな感じにするかなということで。

Javaにはboolean型があるので「フラグ」はbooleanで表して
p.46は、

public boolean isTrainClassified(double[] x, int t, double learningRate) {
    double c = 0.;

    // check if the data is classified correctly
    for (int i = 0; i < nIn; i++) {
        c += w[i] * x[i] * t;
    }
    final boolean isClassified = c > 0;

    // apply gradient descent method if the data is wrongly classified
    if ( !isClassified ) {
        for (int i = 0; i < nIn; i++) {
            w[i] += learningRate * x[i] * t;
        }
    }
    return isClassified;
}

とし、
p.45は「ループ数と比較する」という間接的な判定ではなく、「全て正しく分類された」という直接的な判定にする。

for (int i = 0; i < = epochs; i++) {
    boolean isAllClassified = true;

    for (int j = 0; j < train_N; j++) {
        isAllClassified &= classifier.isTrainClassified(train_X[j], train_T[j], learningRate);
    }

    // when all data classified correctly
    if (isAllClassified) {
        break;
    }
}

おまけ
p.48のPerceptionsの実行結果が「こんなもんかな」と考えていた時の頭の中の絵。

dlwjp48

コメントを残す

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です

CAPTCHA