教師データを,訓練データと検証データに分ける関数sklearn.model_selection.train_test_split()
が便利そうだったので使ってみました.
sklearn.model_selection.train_test_split — scikit-learn 1.1.1 documentation
この関数の引数にshuffle
がありますが,デフォルトでTrue
なので注意しなければなりません.
環境
tensorflow 2.9.1
教師データの準備
例えば,2次元空間の-1から1までの範囲のランダムな点に対して,XOR分類する機械学習用データを以下のように作成するのは誤りです.
import numpy as np import tensorflow as tf from sklearn.model_selection import train_test_split tf.random.set_seed(1) np.random.seed(1) x = np.random.uniform(low=-1, high=1, size=(1024, 2)).astype("float32") y = np.array([1 if a * b > 0 else 0 for a, b in x]).astype("int8") x_train, x_test = train_test_split(x, test_size=0.2) y_train, y_test = train_test_split(y, test_size=0.2)
train_test_split()
を実行するとシャッフルされるので,訓練データと検証データの対応付けが壊れます.
x1, x2, y = 0.8185864 -0.19825253 1 -0.9665688 -0.20268188 1 0.5745269 -0.7995307 1 0.5891567 0.014159846 1 0.27671063 -0.34486112 0 0.49400964 -0.3192414 0 ...
shuffle
オプションをFalse
にします.
x_train, x_test = train_test_split(x, test_size=0.2, shuffle=False) y_train, y_test = train_test_split(y, test_size=0.2, shuffle=False)
これで対応したままになります.
x1, x2, y = 0.00601039 0.28671584 1 -0.7775087 0.5161233 0 0.7038915 0.68849975 1 0.78101367 -0.2363152 0 -0.1123896 -0.79683703 1 -0.54838616 0.94264096 0 ...
学習結果
以下のモデルで学習させました.
model = tf.keras.Sequential() model.add(tf.keras.Input(shape=(2,))) model.add(tf.keras.layers.Dense(units=4, activation="relu")) model.add(tf.keras.layers.Dense(units=1, activation="sigmoid")) model.summary() model.compile( optimizer="sgd", loss="binary_crossentropy", metrics=["binary_accuracy"], ) history = model.fit( x_train, y_train, batch_size=2, epochs=32, validation_data=(x_test, y_test), )
_________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense (Dense) (None, 4) 12 dense_1 (Dense) (None, 1) 5 ================================================================= Total params: 17 Trainable params: 17 Non-trainable params: 0 _________________________________________________________________
シャッフルして対応付けが壊れた教師データ(shuffle
)は全く学習できていませんが,正しく準備した教師データ(noshuffle
)は学習できています.