みーのぺーじ

みーが趣味でやっているPCやソフトウェアについて.Python, Javascript, Processing, Unityなど.

train_test_split()はデフォルトでシャッフルされる

教師データを,訓練データと検証データに分ける関数sklearn.model_selection.train_test_split()が便利そうだったので使ってみました.

sklearn.model_selection.train_test_split — scikit-learn 1.1.1 documentation

この関数の引数にshuffleがありますが,デフォルトでTrueなので注意しなければなりません.

環境

tensorflow 2.9.1

教師データの準備

例えば,2次元空間の-1から1までの範囲のランダムな点に対して,XOR分類する機械学習用データを以下のように作成するのは誤りです.

import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split

tf.random.set_seed(1)
np.random.seed(1)

x = np.random.uniform(low=-1, high=1, size=(1024, 2)).astype("float32")
y = np.array([1 if a * b > 0 else 0 for a, b in x]).astype("int8")

x_train, x_test = train_test_split(x, test_size=0.2)
y_train, y_test = train_test_split(y, test_size=0.2)

train_test_split()を実行するとシャッフルされるので,訓練データと検証データの対応付けが壊れます.

x1, x2, y =
0.8185864 -0.19825253 1
-0.9665688 -0.20268188 1
0.5745269 -0.7995307 1
0.5891567 0.014159846 1
0.27671063 -0.34486112 0
0.49400964 -0.3192414 0
...

shuffleオプションをFalseにします.

x_train, x_test = train_test_split(x, test_size=0.2, shuffle=False)
y_train, y_test = train_test_split(y, test_size=0.2, shuffle=False)

これで対応したままになります.

x1, x2, y =
0.00601039 0.28671584 1
-0.7775087 0.5161233 0
0.7038915 0.68849975 1
0.78101367 -0.2363152 0
-0.1123896 -0.79683703 1
-0.54838616 0.94264096 0
...

学習結果

以下のモデルで学習させました.

model = tf.keras.Sequential()
model.add(tf.keras.Input(shape=(2,)))
model.add(tf.keras.layers.Dense(units=4, activation="relu"))
model.add(tf.keras.layers.Dense(units=1, activation="sigmoid"))

model.summary()

model.compile(
    optimizer="sgd",
    loss="binary_crossentropy",
    metrics=["binary_accuracy"],
)

history = model.fit(
    x_train,
    y_train,
    batch_size=2,
    epochs=32,
    validation_data=(x_test, y_test),
)
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense (Dense)               (None, 4)                 12        
                                                                 
 dense_1 (Dense)             (None, 1)                 5         
                                                                 
=================================================================
Total params: 17
Trainable params: 17
Non-trainable params: 0
_________________________________________________________________

シャッフルして対応付けが壊れた教師データ(shuffle)は全く学習できていませんが,正しく準備した教師データ(noshuffle)は学習できています.