I am new to machine learning and deep learning, and I have a question about using train_test_split before training.
I have a dataset of shape (302, 100, 5), where (207, 100, 5) belongs to class 0 and (95, 100, 5) belongs to class 1.
I would like to perform classification using an LSTM (since this is sequence data).
How should I split my dataset for training, given that the two classes are not equally represented?
Option 1: Take the whole dataset [(302, 100, 5), both classes 0 and 1], shuffle it, apply train_test_split, and proceed with training.
Option 2: Use the same number of samples from each class [(95, 100, 5) from class 0 and (95, 100, 5) from class 1], shuffle them, apply train_test_split, and proceed with training.
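A minimal NumPy sketch of the two options, for illustration only. X and y are random placeholders with the class counts above (not the real data), shuffle_split is a hypothetical helper rather than sklearn's train_test_split, and a 20% test split is assumed:

```python
import numpy as np

rng = np.random.default_rng(0)

# Placeholder data with the shapes from the question (random values, not the real data):
# 207 class-0 sequences and 95 class-1 sequences, each 100 time steps x 5 features.
X = rng.random((302, 100, 5))
y = np.array([0] * 207 + [1] * 95)

def shuffle_split(X, y, test_size, rng):
    """Hypothetical helper: shuffle all samples, then hold out a `test_size` fraction."""
    idx = rng.permutation(len(X))
    n_test = int(round(test_size * len(X)))
    test_idx, train_idx = idx[:n_test], idx[n_test:]
    return X[train_idx], X[test_idx], y[train_idx], y[test_idx]

# Option 1: use all 302 samples, shuffle, split.
X_tr1, X_te1, y_tr1, y_te1 = shuffle_split(X, y, test_size=0.2, rng=rng)

# Option 2: randomly keep only 95 of the 207 class-0 samples (random undersampling),
# so both classes have 95 samples, then shuffle and split.
idx0 = rng.choice(np.flatnonzero(y == 0), size=95, replace=False)
idx1 = np.flatnonzero(y == 1)
keep = np.concatenate([idx0, idx1])
X_tr2, X_te2, y_tr2, y_te2 = shuffle_split(X[keep], y[keep], test_size=0.2, rng=rng)

print(X_tr1.shape, X_te1.shape)  # (242, 100, 5) (60, 100, 5)
print(X_tr2.shape, X_te2.shape)  # (152, 100, 5) (38, 100, 5)
```

Note that Option 2 throws away 112 of the 207 class-0 sequences, which is a large share of an already small dataset.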
Which way of splitting is better before training, so that I get better results in terms of loss reduction, accuracy, and prediction?
If there are other options besides these two, please recommend them.
Based on the comments, here is part of my data:
X_train: shape (241, 100, 5)
Each row of a (100, 5) segment corresponds to one time step, so the 100 rows correspond to 100 time steps (in milliseconds).
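This (samples, time steps, features) layout is the input format Keras-style LSTM layers expect. A short sketch of how it is indexed, using a random placeholder array of the same shape (not the real data):

```python
import numpy as np

# Random placeholder with the same shape as X_train: (samples, time steps, features).
X_train = np.random.default_rng(0).random((241, 100, 5))

n_samples, n_timesteps, n_features = X_train.shape
segment = X_train[0]           # one sequence: 100 time steps x 5 features
first_step = X_train[0, 0]     # the 5 feature values at the first time step
feature_0 = X_train[0, :, 0]   # how feature 0 evolves over the 100 time steps

print(n_samples, n_timesteps, n_features)                # 241 100 5
print(segment.shape, first_step.shape, feature_0.shape)  # (100, 5) (5,) (100,)
```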
array([[[0.98620635, 0. , 0.12752912, 0.60897341, 0.46903766],
[0.97345112, 0. , 0.12752912, 0.49205995, 0.38709902],
[0.9566397 , 0. , 0.12752912, 0.45728718, 0.42154812],
...,
[0.28669754, 0.8852459 , 0.12752912, 0.8786213 , 0.80125523],
[0.31559784, 0.8852459 , 0.20968731, 0.89087803, 0.79476987],
[0.34368841, 0.8852459 , 0.12752912, 0.89087803, 0.71066946]],
[[0.97957188, 0.14909194, 0.04159147, 0.50548561, 0.34209531],
[0.9687237 , 0.13964397, 0.04159147, 0.55926067, 0.64613533],
[0.96596236, 0.13553813, 0.04159147, 0.55903796, 0.85299319],
...,
[0.49309139, 0.72396527, 0.04159147, 0.81998825, 0.12362443],
[0.52072591, 0.70872926, 0.04159147, 0.82361951, 0.89639432],
[0.54441507, 0.71835207, 0.04159147, 0.84964602, 1. ]],
[[0.48151381, 0.875 , 0.16666667, 0.90637286, 0.62737926],
[0.53325374, 0.8625 , 0.33333333, 0.87881677, 0.5321154 ],
[0.57506452, 0.81859091, 0.16666667, 0.84915758, 0.3552661 ],
...,
[0.34456041, 0.92993213, 0.33333333, 0.92953899, 0.78782408],
[0.39496018, 0.90523485, 0.33333333, 0.9117954 , 0.54579383],
[0.44187985, 0.8625 , 0.33333333, 0.84163194, 0.25789356]],
...,
[[0.16368355, 0. , 0.15313225, 0.40101906, 0.36784741],
[0.15679684, 0. , 0.15313225, 0.4435126 , 0.67351994],
[0.15544309, 0.06132052, 0.15313225, 0.40101906, 0.36611345],
...,
[0.43936628, 0.68292683, 0.15313225, 0.82305329, 0.36784741],
[0.49751546, 0.68292683, 0.07764888, 0.84141109, 0.42828833],
[0.53288488, 0.68292683, 0.15313225, 0.85959823, 0.36784741]],
[[0.9418247 , 0.30821318, 0.03072816, 0.744977 , 0.93769733],
[0.9537216 , 0.28989357, 0.03072816, 0.74576381, 0.98468743],
[0.96455286, 0.21736423, 0.03072816, 0.74182977, 1. ],
...,
[0.36273884, 0.60113245, 0.06145633, 0.85409181, 0.32277415],
[0.38774614, 0.57789971, 0.05844559, 0.82937631, 0. ],
[0.41546859, 0.57789971, 0.03072816, 0.79315883, 0.31256578]],
[[0.97868688, 0.06451613, 0.00411829, 0.64705259, 0.69827586],
[0.97999663, 0.06451613, 0.02256676, 0.66812232, 0.75195925],
[0.97143037, 0.02476377, 0.02256676, 0.66317859, 0.78487461],
...,
[0.50336862, 0.73867709, 0.02256676, 0.84921606, 0.1226489 ],
[0.54003486, 0.72043011, 0.02256676, 0.82679269, 0.20297806],
[0.57594039, 0.70967742, 0.02256676, 0.83350205, 0. ]]])
Y_train: shape (241,)
[1. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 0. 0. 1. 1. 0. 0. 0.
1. 1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 1. 0. 1. 0. 0. 0. 0. 0. 1. 1.
0. 0. 1. 0. 0. 0. 0. 1. 0. 1. 0. 0. 0. 0. 1. 1. 0. 0. 1. 0. 1. 0. 1. 0.
0. 1. 0. 0. 1. 1. 0. 0. 0. 1. 1. 0. 0. 1. 0. 0. 1. 0. 0. 1. 0. 0. 0. 1.
1. 0. 0. 1. 0. 1. 0. 0. 0. 1. 1. 0. 0. 0. 0. 1. 1. 1. 0. 0. 0. 0. 0. 1.
0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 1. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 1. 1. 0. 1. 1. 0. 0. 0. 0. 0. 1. 0.
0. 0. 0. 0. 0. 0. 0. 1. 0. 1. 0. 1. 0. 0. 0. 1. 1. 0. 0. 1. 1. 1. 0. 1.
0. 1. 0. 1. 0. 0. 0. 1. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.
1. 0. 0. 1. 1. 1. 0. 1. 0. 0. 1. 0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 1.
1.]
For reference:
As you can see above, X_train is large, so I cannot include all of it. To give a better idea of what the data looks like, here is a single segment, X_train[0] (shape (100, 5)). The remaining 240 segments look more or less like this:
array([[9.86206354e-01, 0.00000000e+00, 1.27529123e-01, 2.29139335e-02,
6.08973407e-01, 4.69037657e-01],
[9.73451120e-01, 0.00000000e+00, 1.27529123e-01, 2.60807671e-02,
4.92059955e-01, 3.87099024e-01],
[9.56639704e-01, 0.00000000e+00, 1.27529123e-01, 2.64184174e-02,
4.57287179e-01, 4.21548117e-01],
[9.34897700e-01, 0.00000000e+00, 1.27529123e-01, 2.64184174e-02,
4.84177685e-01, 4.69037657e-01],
[9.18030989e-01, 0.00000000e+00, 1.27529123e-01, 2.64184174e-02,
4.86406180e-01, 4.08577406e-01],
[9.02168015e-01, 0.00000000e+00, 1.27529123e-01, 2.64020795e-02,
4.84920517e-01, 4.04184100e-01],
[8.82551572e-01, 0.00000000e+00, 1.27529123e-01, 2.56783096e-02,
4.51195959e-01, 3.78661088e-01],
[8.69975342e-01, 0.00000000e+00, 1.27529123e-01, 2.40477851e-02,
4.70286733e-01, 4.23640167e-01],
[8.41027241e-01, 0.00000000e+00, 1.27529123e-01, 1.75387576e-02,
5.04754123e-01, 4.34728033e-01],
[8.28189535e-01, 5.28763040e-01, 1.27529123e-01, 6.89133486e-03,
4.98662903e-01, 4.58368201e-01],
[8.21784739e-01, 8.21162444e-01, 1.27529123e-01, 1.06196483e-02,
5.87431288e-01, 5.72594142e-01],
[8.26651597e-01, 9.96721311e-01, 1.27529123e-01, 1.75044480e-02,
6.89050661e-01, 5.40376569e-01],
[8.42115326e-01, 1.00000000e+00, 1.27529123e-01, 1.71205069e-02,
8.35388501e-01, 4.69037657e-01],
[8.64071009e-01, 9.26875310e-01, 1.27529123e-01, 1.34068975e-02,
1.00000000e+00, 4.65062762e-01],
[8.79579724e-01, 7.60158967e-01, 1.27529123e-01, 4.65303975e-03,
9.61744169e-01, 3.65481172e-01],
[9.03630040e-01, 7.61549925e-01, 1.27529123e-01, 4.21518348e-03,
9.22076957e-01, 3.78033473e-01],
[9.18435858e-01, 6.72429210e-01, 1.27529123e-01, 2.70229205e-03,
9.39979201e-01, 5.03138075e-01],
[9.29983046e-01, 6.85345256e-01, 1.27529123e-01, 9.05120794e-04,
8.53736443e-01, 5.52510460e-01],
[9.48081232e-01, 5.78539493e-01, 1.27529123e-01, 6.96485550e-03,
8.84415391e-01, 3.04602510e-01],
[9.48112160e-01, 5.55091903e-01, 1.27529123e-01, 1.10493356e-02,
8.19046204e-01, 4.78661088e-01],
[9.61281634e-01, 5.08693492e-01, 1.27529123e-01, 9.36162843e-03,
8.23651761e-01, 3.21548117e-01],
[9.72179346e-01, 4.91803279e-01, 1.27529123e-01, 9.82725917e-03,
7.57391175e-01, 4.96025105e-01],
[9.84752763e-01, 4.91803279e-01, 1.27529123e-01, 7.04491131e-03,
7.59322538e-01, 3.95397490e-01],
[9.90300024e-01, 4.91803279e-01, 1.27529123e-01, 8.19346712e-03,
7.64819492e-01, 4.69037657e-01],
[9.88306609e-01, 3.77049180e-01, 1.27529123e-01, 8.62642201e-03,
7.93492795e-01, 4.16945607e-01],
[9.91084457e-01, 3.93442623e-01, 1.27529123e-01, 9.16557339e-03,
7.10741346e-01, 4.72175732e-01],
[1.00000000e+00, 3.78936910e-01, 1.27529123e-01, 1.16538387e-02,
6.93359085e-01, 4.76987448e-01],
[9.98925974e-01, 3.93442623e-01, 1.27529123e-01, 1.21309060e-02,
7.16609716e-01, 3.46025105e-01],
[9.92838888e-01, 3.32141083e-01, 1.27529123e-01, 1.19315833e-02,
7.31540633e-01, 4.16527197e-01],
[9.90637415e-01, 3.36910084e-01, 1.27529123e-01, 9.95632874e-03,
7.12524142e-01, 4.15481172e-01],
[9.90761125e-01, 3.38301043e-01, 1.27529123e-01, 6.59235091e-03,
6.86970732e-01, 4.37656904e-01],
[9.90274720e-01, 3.27868852e-01, 2.10913550e-01, 5.68396253e-03,
7.09181399e-01, 4.99372385e-01],
[9.83015202e-01, 3.27868852e-01, 1.27529123e-01, 2.14974358e-02,
7.31392067e-01, 6.41631799e-01],
[9.77392028e-01, 2.85245902e-01, 1.47762109e-01, 2.52861995e-02,
7.09478532e-01, 6.07112971e-01],
[9.75300207e-01, 2.78688525e-01, 1.27529123e-01, 2.91468501e-02,
6.70257020e-01, 6.28242678e-01],
[9.74917831e-01, 2.71733731e-01, 1.27529123e-01, 3.58780734e-02,
6.70257020e-01, 5.72594142e-01],
[9.64950755e-01, 2.62295082e-01, 1.27529123e-01, 3.92992339e-02,
6.36383895e-01, 6.67991632e-01],
[9.63159774e-01, 2.62295082e-01, 1.27529123e-01, 4.82932591e-02,
6.93581934e-01, 5.46443515e-01],
[9.54983679e-01, 2.90511674e-01, 1.27529123e-01, 4.90627752e-02,
6.59708810e-01, 7.40376569e-01],
[9.57595643e-01, 3.11475410e-01, 1.27529123e-01, 4.72492660e-02,
6.49977715e-01, 5.61297071e-01],
[9.51511369e-01, 2.95081967e-01, 1.27529123e-01, 1.82576261e-02,
6.64314366e-01, 5.22384937e-01],
[9.48528275e-01, 2.95081967e-01, 1.27529123e-01, 3.89659403e-03,
6.29846977e-01, 3.20711297e-01],
[9.47085931e-01, 2.95081967e-01, 1.27529123e-01, 6.86682798e-03,
6.48417769e-01, 4.38284519e-01],
[9.38153518e-01, 2.95081967e-01, 1.27529123e-01, 5.73951146e-03,
7.04130144e-01, 5.32635983e-01],
[9.38114156e-01, 2.95081967e-01, 1.27529123e-01, 2.05955826e-02,
6.85782202e-01, 5.47280335e-01],
[9.35597786e-01, 2.95081967e-01, 1.27529123e-01, 2.91141743e-02,
6.69142772e-01, 7.13807531e-01],
[9.29311077e-01, 2.72826627e-01, 1.27529123e-01, 2.91141743e-02,
6.81622344e-01, 5.72594142e-01],
[9.25495753e-01, 2.23646299e-01, 1.27529123e-01, 2.65507546e-02,
6.35566781e-01, 6.41004184e-01],
[9.18525829e-01, 2.08643815e-03, 1.27529123e-01, 2.37618715e-02,
6.09641955e-01, 5.02928870e-01],
[8.91801693e-01, 0.00000000e+00, 1.27529123e-01, 9.27013608e-03,
5.26073392e-01, 4.21338912e-01],
[8.77693149e-01, 0.00000000e+00, 1.27529123e-01, 8.13628440e-03,
4.22522656e-01, 3.44560669e-01],
[8.61894841e-01, 0.00000000e+00, 1.27529123e-01, 1.49639014e-02,
4.52755906e-01, 3.65481172e-01],
[8.44254943e-01, 0.00000000e+00, 1.27529123e-01, 2.29515107e-02,
4.59069975e-01, 3.76150628e-01],
[8.21183060e-01, 0.00000000e+00, 1.27529123e-01, 3.97583295e-02,
4.60852771e-01, 2.60460251e-01],
[8.04116726e-01, 0.00000000e+00, 1.27529123e-01, 5.89292454e-02,
4.26905363e-01, 1.97907950e-01],
[7.81311943e-01, 0.00000000e+00, 1.27529123e-01, 8.53656345e-02,
4.37379290e-01, 1.00836820e-01],
[7.60863270e-01, 0.00000000e+00, 1.27529123e-01, 1.03087377e-01,
4.37379290e-01, 6.98744770e-02],
[7.41227145e-01, 0.00000000e+00, 1.27529123e-01, 1.14206966e-01,
4.27128213e-01, 1.58368201e-01],
[7.26694052e-01, 0.00000000e+00, 1.27529123e-01, 1.17776801e-01,
4.37379290e-01, 0.00000000e+00],
[7.08716764e-01, 0.00000000e+00, 1.27529123e-01, 1.17288297e-01,
4.48596048e-01, 2.18619247e-01],
[6.90483621e-01, 0.00000000e+00, 1.27529123e-01, 1.08491961e-01,
4.58549993e-01, 1.26987448e-01],
[6.67451099e-01, 0.00000000e+00, 1.27529123e-01, 8.38217010e-02,
4.99628584e-01, 3.55020921e-01],
[6.51610618e-01, 0.00000000e+00, 1.27529123e-01, 4.32889541e-02,
5.10919626e-01, 4.83054393e-01],
[6.31195684e-01, 0.00000000e+00, 1.27529123e-01, 1.29200275e-02,
5.21170703e-01, 4.97907950e-01],
[6.14317726e-01, 0.00000000e+00, 2.26241570e-01, 9.32895259e-04,
4.98960036e-01, 4.69037657e-01],
[5.98165158e-01, 0.00000000e+00, 5.90435316e-01, 0.00000000e+00,
4.61892735e-01, 5.03556485e-01],
[5.68221755e-01, 0.00000000e+00, 6.33353771e-01, 1.61745413e-03,
4.25122567e-01, 4.69037657e-01],
[5.35292447e-01, 0.00000000e+00, 1.00000000e+00, 8.99402522e-03,
3.58490566e-01, 5.10041841e-01],
[5.10766973e-01, 0.00000000e+00, 3.93010423e-01, 3.39894098e-02,
3.27068786e-01, 6.15690377e-01],
[4.78939807e-01, 0.00000000e+00, 5.32188841e-01, 5.98114931e-02,
3.27068786e-01, 6.22175732e-01],
[4.47053597e-01, 0.00000000e+00, 4.31023912e-01, 8.44245703e-02,
3.24023176e-01, 6.76150628e-01],
[4.13654754e-01, 0.00000000e+00, 5.32188841e-01, 1.07209434e-01,
2.90298618e-01, 7.08577406e-01],
[3.80151882e-01, 0.00000000e+00, 7.97057020e-01, 1.21122807e-01,
1.19150201e-01, 4.95397490e-01],
[3.28235926e-01, 0.00000000e+00, 3.56223176e-01, 1.23820198e-01,
0.00000000e+00, 6.65271967e-01],
[2.83452966e-01, 0.00000000e+00, 2.28694053e-01, 1.22658572e-01,
2.65933739e-02, 5.55648536e-01],
[2.38616587e-01, 0.00000000e+00, 2.28694053e-01, 1.22990232e-01,
9.41910563e-02, 4.92887029e-01],
[1.82964031e-01, 0.00000000e+00, 5.19926426e-01, 1.30564491e-01,
8.97340663e-02, 4.94142259e-01],
[1.43835174e-01, 0.00000000e+00, 5.25444513e-01, 1.64135650e-01,
1.14618927e-01, 7.40585774e-01],
[1.04402664e-01, 0.00000000e+00, 1.55119559e-01, 2.41378071e-01,
1.98261774e-01, 6.50418410e-01],
[7.96438281e-02, 0.00000000e+00, 7.11220110e-02, 3.27145618e-01,
2.89110088e-01, 7.45188285e-01],
[6.36065353e-02, 0.00000000e+00, 0.00000000e+00, 4.11129065e-01,
4.05140395e-01, 6.88912134e-01],
[4.11672585e-02, 0.00000000e+00, 2.52605763e-01, 5.62182942e-01,
4.54315852e-01, 1.00000000e+00],
[2.87063044e-02, 0.00000000e+00, 1.27529123e-01, 6.81786323e-01,
4.59515674e-01, 9.32217573e-01],
[1.70269716e-02, 1.58966716e-03, 1.27529123e-01, 7.33474602e-01,
4.37453573e-01, 6.07322176e-01],
[3.30361486e-03, 6.37853949e-01, 1.27529123e-01, 8.06276376e-01,
4.69692468e-01, 7.54602510e-01],
[0.00000000e+00, 7.89369101e-01, 1.27529123e-01, 8.85843682e-01,
5.10919626e-01, 8.70502092e-01],
[5.13114648e-03, 8.19672131e-01, 1.27529123e-01, 9.60932765e-01,
5.99316595e-01, 8.79288703e-01],
[2.16829598e-02, 8.36065574e-01, 1.27529123e-01, 9.99121020e-01,
7.28866439e-01, 8.56903766e-01],
[4.27951674e-02, 8.36065574e-01, 1.27529123e-01, 1.00000000e+00,
8.67181697e-01, 7.88912134e-01],
[7.02334461e-02, 8.36065574e-01, 1.27529123e-01, 9.93500775e-01,
8.46308127e-01, 9.78451883e-01],
[9.73680733e-02, 8.36065574e-01, 1.27529123e-01, 9.87896869e-01,
8.66364582e-01, 8.59414226e-01],
[1.23611427e-01, 8.36065574e-01, 1.27529123e-01, 9.69613102e-01,
8.35685634e-01, 9.17991632e-01],
[1.52157471e-01, 8.68852459e-01, 1.27529123e-01, 9.22226597e-01,
7.96686971e-01, 9.65062762e-01],
[1.77979087e-01, 8.68852459e-01, 1.27529123e-01, 8.61132577e-01,
8.29594414e-01, 8.14225941e-01],
[2.03010647e-01, 8.84252360e-01, 1.27529123e-01, 8.13277174e-01,
8.29594414e-01, 9.11506276e-01],
[2.32490138e-01, 8.85245902e-01, 1.27529123e-01, 7.59549923e-01,
8.41851137e-01, 9.52301255e-01],
[2.58952796e-01, 8.85245902e-01, 1.27529123e-01, 6.97804020e-01,
8.55667806e-01, 8.68200837e-01],
[2.86697538e-01, 8.85245902e-01, 1.27529123e-01, 6.25149288e-01,
8.78621304e-01, 8.01255230e-01],
[3.15597842e-01, 8.85245902e-01, 2.09687308e-01, 5.51940700e-01,
8.90878027e-01, 7.94769874e-01],
[3.43688409e-01, 8.85245902e-01, 1.27529123e-01, 4.75801089e-01,
8.90878027e-01, 7.10669456e-01]])
TL;DR: Try both!
I have been in similar situations before where my dataset was imbalanced. I used train_test_split or KFold to get through them.
However, I later ran into the broader problem of handling imbalanced datasets and came across oversampling and undersampling techniques. For these, I would recommend the imblearn (imbalanced-learn) library.
It provides various techniques for handling cases where one class outnumbers the other. I have used SMOTE a lot and have had relatively better results with it in such cases.
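The basic idea behind random over- and undersampling can be sketched in plain NumPy. random_oversample and random_undersample are hypothetical helpers (not imblearn's API), and the data are random placeholders with the question's class counts:

```python
import numpy as np

def random_oversample(X, y, rng):
    """Duplicate randomly chosen samples until every class matches the largest one."""
    classes, counts = np.unique(y, return_counts=True)
    target = counts.max()
    keep = []
    for c, n in zip(classes, counts):
        idx = np.flatnonzero(y == c)
        extra = rng.choice(idx, size=target - n, replace=True)  # sampling with replacement
        keep.append(np.concatenate([idx, extra]))
    keep = np.concatenate(keep)
    return X[keep], y[keep]

def random_undersample(X, y, rng):
    """Randomly drop samples until every class matches the smallest one."""
    classes, counts = np.unique(y, return_counts=True)
    target = counts.min()
    keep = np.concatenate([rng.choice(np.flatnonzero(y == c), size=target, replace=False)
                           for c in classes])
    return X[keep], y[keep]

rng = np.random.default_rng(0)
X = rng.random((302, 100, 5))         # placeholder data with the question's shapes
y = np.array([0] * 207 + [1] * 95)

X_over, y_over = random_oversample(X, y, rng)
X_under, y_under = random_undersample(X, y, rng)
print(np.bincount(y_over), np.bincount(y_under))  # [207 207] [95 95]
```

Whichever sampler you use, apply it to the training data only, after splitting; if duplicated samples end up in both the training and test sets, the test score becomes overly optimistic. Random undersampling to 95 samples per class is what Option 2 in the question does.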
Other references:
https://www.analyticsvidhya.com/blog/2017/03/imbalanced-classification-problem/
https://towardsdatascience.com/handling-imbalanced-datasets-in-machine-learning-7a0e84220f28
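You can use the stratify option of train_test_split, which splits each class according to the given test_size, so the class proportions in the training and test sets match those of the full dataset. Note that train_test_split returns X_train, X_test, y_train, y_test in that order:
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, stratify=y)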
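What stratify=y does can be sketched in NumPy: each class is shuffled and split separately with the same test fraction. stratified_split is a hypothetical helper, the data are random placeholders, and scikit-learn's exact rounding of the per-class counts may differ by a sample:

```python
import numpy as np

def stratified_split(X, y, test_size, rng):
    """Hypothetical helper: send the same fraction of each class to the test set."""
    test_parts, train_parts = [], []
    for c in np.unique(y):
        idx = rng.permutation(np.flatnonzero(y == c))
        n_test = int(round(test_size * len(idx)))
        test_parts.append(idx[:n_test])
        train_parts.append(idx[n_test:])
    # Shuffle again so the classes are interleaved in each split.
    test_idx = rng.permutation(np.concatenate(test_parts))
    train_idx = rng.permutation(np.concatenate(train_parts))
    return X[train_idx], X[test_idx], y[train_idx], y[test_idx]

rng = np.random.default_rng(0)
X = rng.random((302, 100, 5))         # placeholder data with the question's shapes
y = np.array([0] * 207 + [1] * 95)

X_train, X_test, y_train, y_test = stratified_split(X, y, test_size=0.2, rng=rng)
# Class counts: train 166 (class 0) / 76 (class 1), test 41 / 19,
# i.e. roughly the same 69% / 31% ratio as the full dataset.
print(np.bincount(y_train), np.bincount(y_test))
```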
I am working on a project where I experiment with a credit dataset (an imbalanced dataset with 1% minority class and 99% majority class) for fraud detection using different sampling methods, and I found that SMOTE gives better results on imbalanced datasets.
SMOTE (Synthetic Minority Oversampling Technique) is a powerful sampling method that goes beyond simple random under- or oversampling. It creates new minority-class instances as convex combinations of neighbouring minority-class instances.
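The interpolation step can be sketched in NumPy. This is a simplified illustration of the idea, not imblearn's implementation: smote_like is a hypothetical helper, the data are random placeholders, and each (100, 5) sequence is flattened to a 500-value vector because SMOTE works on feature vectors. As far as I know, imblearn's SMOTE also expects a 2D feature matrix, so the LSTM data would need X.reshape(len(X), -1) before fit_resample and a reshape back afterwards.

```python
import numpy as np

def smote_like(X_min, n_new, k, rng):
    """Hypothetical, simplified SMOTE-style oversampling of minority samples X_min.

    Each synthetic sample is x + lam * (neighbour - x) with lam in [0, 1), i.e. a
    random point on the segment between a minority sample and one of its k nearest
    minority neighbours.
    """
    shape = X_min.shape[1:]
    flat = X_min.reshape(len(X_min), -1)               # (n, 100, 5) -> (n, 500)
    sq = np.sum(flat ** 2, axis=1)
    d = sq[:, None] + sq[None, :] - 2.0 * flat @ flat.T  # squared Euclidean distances
    np.fill_diagonal(d, np.inf)                        # a sample is not its own neighbour
    neighbours = np.argsort(d, axis=1)[:, :k]          # k nearest minority neighbours
    base = rng.integers(0, len(flat), size=n_new)      # random minority sample per new point
    nn = neighbours[base, rng.integers(0, k, size=n_new)]
    lam = rng.random((n_new, 1))
    synthetic = flat[base] + lam * (flat[nn] - flat[base])
    return synthetic.reshape((n_new,) + shape)          # back to (n_new, 100, 5)

rng = np.random.default_rng(0)
X = rng.random((302, 100, 5))         # placeholder data with the question's shapes
y = np.array([0] * 207 + [1] * 95)

X_min = X[y == 1]
X_syn = smote_like(X_min, n_new=207 - 95, k=5, rng=rng)
X_bal = np.concatenate([X, X_syn])
y_bal = np.concatenate([y, np.ones(len(X_syn), dtype=int)])
print(X_syn.shape, np.bincount(y_bal))  # (112, 100, 5) [207 207]
```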
I used SMOTE together with K-fold cross-validation. Cross-validation gives a more reliable estimate of how well the model generalizes and helps reveal whether it is learning real patterns in the data or fitting noise.
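When combining oversampling with K-fold cross-validation, the resampling should happen inside each training fold only; otherwise samples derived from the held-out fold leak into training and inflate the scores. A minimal NumPy sketch of that loop, using random oversampling as a stand-in for SMOTE to keep it short. stratified_fold_ids and oversample_indices are hypothetical helpers, and the labels are placeholders with the question's class counts. imblearn also provides imblearn.pipeline.Pipeline, which applies samplers only when fitting and so handles this with scikit-learn's cross-validation tools.

```python
import numpy as np

def stratified_fold_ids(y, n_splits, rng):
    """Assign every sample a fold id so that each class is spread evenly over the folds."""
    fold = np.empty(len(y), dtype=int)
    for c in np.unique(y):
        idx = rng.permutation(np.flatnonzero(y == c))
        fold[idx] = np.arange(len(idx)) % n_splits
    return fold

def oversample_indices(idx, y, rng):
    """Random oversampling (stand-in for SMOTE) applied only to the indices in idx."""
    classes, counts = np.unique(y[idx], return_counts=True)
    parts = []
    for c, n in zip(classes, counts):
        cls_idx = idx[y[idx] == c]
        extra = rng.choice(cls_idx, size=counts.max() - n, replace=True)
        parts.append(np.concatenate([cls_idx, extra]))
    return np.concatenate(parts)

rng = np.random.default_rng(0)
y = np.array([0] * 207 + [1] * 95)   # placeholder labels with the question's class counts
fold = stratified_fold_ids(y, n_splits=5, rng=rng)

for k in range(5):
    test_idx = np.flatnonzero(fold == k)
    # Resample the training folds only; the held-out fold keeps its natural class balance.
    train_idx = oversample_indices(np.flatnonzero(fold != k), y, rng)
    assert np.intersect1d(train_idx, test_idx).size == 0  # no leakage into the test fold
    print(k, np.bincount(y[train_idx]), np.bincount(y[test_idx]))
    # ...train on X[train_idx], y[train_idx] and evaluate on X[test_idx], y[test_idx]
```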
On an imbalanced dataset, a model can reach 99% accuracy, which seems impressive, while completely ignoring the minority class. So, in addition to accuracy, I used the Matthews correlation coefficient (MCC) and the F1 score to measure performance.
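To see why accuracy alone is misleading, here is a small sketch that computes accuracy, F1 and MCC from the confusion-matrix counts for a hypothetical "model" that always predicts the majority class on labels with a 1% minority class. The labels are made up for illustration; when F1 or MCC is undefined (division by zero), the sketch returns 0, the same convention scikit-learn's metrics use by default.

```python
import numpy as np

def accuracy_f1_mcc(y_true, y_pred):
    """Accuracy, F1 and Matthews correlation coefficient for binary 0/1 labels."""
    tp = int(np.sum((y_true == 1) & (y_pred == 1)))
    tn = int(np.sum((y_true == 0) & (y_pred == 0)))
    fp = int(np.sum((y_true == 0) & (y_pred == 1)))
    fn = int(np.sum((y_true == 1) & (y_pred == 0)))
    acc = (tp + tn) / len(y_true)
    f1 = 2 * tp / (2 * tp + fp + fn) if (2 * tp + fp + fn) else 0.0
    denom = np.sqrt(float((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)))
    mcc = float((tp * tn - fp * fn) / denom) if denom else 0.0
    return acc, f1, mcc

# Hypothetical labels: 990 majority (0) and 10 minority (1) samples.
y_true = np.array([0] * 990 + [1] * 10)
y_pred = np.zeros_like(y_true)           # always predicts the majority class
print(accuracy_f1_mcc(y_true, y_pred))   # (0.99, 0.0, 0.0)
```

The 99% accuracy hides the fact that not a single minority-class sample was detected, which F1 and MCC both expose.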
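Code:
from imblearn.over_sampling import SMOTE
from sklearn.model_selection import train_test_split
# Split first, then oversample only the training set so the test set stays untouched
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)
sm = SMOTE(random_state=2)
# fit_resample replaces the old fit_sample method, which current imbalanced-learn versions no longer provide
X_train_res, y_train_res = sm.fit_resample(X_train, y_train.ravel())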
Source: https://stackoverflow.com/questions/57142772/what-is-the-correct-procedure-to-split-the-data-sets-for-classification-problem