Question
I'm working on a machine learning pipeline to classify images. My problem is that my dataset is imbalanced: of my 5 image categories, one class has about 400 images and each of the other classes has about 20.
I would like to balance my training set by applying data augmentation only to certain classes of it.
Here's the code I'm using to create the training and validation sets:
import pathlib
import tensorflow as tf

# Import data
data_dir = pathlib.Path(r"C:\Train set")

# Define train and validation sets (80% - 20%)
batch_size = 32
img_height = 240
img_width = 240

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)

val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)
And here's how I apply data augmentation, although this would be for the entire train set:
from tensorflow import keras
from tensorflow.keras import layers

# Apply data augmentation
data_augmentation = keras.Sequential(
    [
        layers.experimental.preprocessing.RandomFlip(
            "horizontal", input_shape=(img_height, img_width, 3)),
        layers.experimental.preprocessing.RandomRotation(0.1),
        layers.experimental.preprocessing.RandomZoom(0.1),
    ]
)
Is there any way to go into my train set, extract those categories that have fewer images, and apply data augmentation only to them?
Thanks in advance!
Answer 1:
I suggest not using ImageDataGenerator but a customized tf.data.Dataset. In a mapping operation, you can treat categories differently, e.g.:
def preprocess(filepath):
    category = tf.strings.split(filepath, os.sep)[0]
    read_file = tf.io.read_file(filepath)
    decode = tf.image.decode_jpeg(read_file, channels=3)
    resize = tf.image.resize(decode, (200, 200))
    image = tf.expand_dims(resize, 0)
    if tf.equal(category, 'tf_astronauts'):
        image = tf.image.flip_up_down(image)
        image = tf.image.flip_left_right(image)
    # image = tf.image.convert_image_dtype(image, tf.float32)
    # category = tf.cast(tf.equal(category, 'tf_astronauts'), tf.int32)
    return image, category
Let me demonstrate it. Let's make you a folder with training images:
import tensorflow as tf
import matplotlib.pyplot as plt
import cv2
from skimage import data
from glob2 import glob
import os
cat = data.chelsea()
astronaut = data.astronaut()
for category, picture in zip(['tf_cats', 'tf_astronauts'], [cat, astronaut]):
    os.makedirs(category, exist_ok=True)
    for i in range(5):
        cv2.imwrite(os.path.join(category, category + f'_{i}.jpg'),
                    cv2.cvtColor(picture, cv2.COLOR_RGB2BGR))
# os.path.join keeps the pattern portable; on Windows it expands to 'tf_*\\*.jpg'
files = glob(os.path.join('tf_*', '*.jpg'))
Now you have these files:
['tf_astronauts\\tf_astronauts_0.jpg',
'tf_astronauts\\tf_astronauts_1.jpg',
'tf_astronauts\\tf_astronauts_2.jpg',
'tf_astronauts\\tf_astronauts_3.jpg',
'tf_astronauts\\tf_astronauts_4.jpg',
'tf_cats\\tf_cats_0.jpg',
'tf_cats\\tf_cats_1.jpg',
'tf_cats\\tf_cats_2.jpg',
'tf_cats\\tf_cats_3.jpg',
'tf_cats\\tf_cats_4.jpg']
Let's apply transformations only to the astronaut category, using the tf.image transformations.
def preprocess(filepath):
    category = tf.strings.split(filepath, os.sep)[0]
    read_file = tf.io.read_file(filepath)
    decode = tf.image.decode_jpeg(read_file, channels=3)
    resize = tf.image.resize(decode, (200, 200))
    image = tf.expand_dims(resize, 0)
    if tf.equal(category, 'tf_astronauts'):
        image = tf.image.flip_up_down(image)
        image = tf.image.flip_left_right(image)
    # image = tf.image.convert_image_dtype(image, tf.float32)
    # category = tf.cast(tf.equal(category, 'tf_astronauts'), tf.int32)
    return image, category
Then, we make the tf.data.Dataset:
train = tf.data.Dataset.from_tensor_slices(files).\
    shuffle(10).take(4).map(preprocess).batch(4)
And when you iterate the dataset, you'll see that only the astronaut is flipped:
fig = plt.figure()
plt.subplots_adjust(wspace=.1, hspace=.2)
images, labels = next(iter(train))
for index, (image, label) in enumerate(zip(images, labels)):
    ax = plt.subplot(2, 2, index + 1)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(label.numpy().decode())
    ax.imshow(image[0].numpy().astype(int))
plt.show()
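The demo flips every astronaut image the same way, so the class sizes themselves do not change. To move toward the balance asked about in the question, the minority files would also have to appear more often before they are mapped. Here is a rough sketch of that arithmetic, using the question's approximate counts (about 400 and about 20). The class names, the placeholder paths, and the `repeat_factors` helper are all hypothetical:

```python
def repeat_factors(counts):
    """Return how often each class's files should appear to roughly match the largest class."""
    target = max(counts.values())
    return {name: max(1, round(target / n)) for name, n in counts.items()}


# Approximate counts from the question; the class names are made up.
counts = {'class_a': 400, 'class_b': 20, 'class_c': 20, 'class_d': 20, 'class_e': 20}
factors = repeat_factors(counts)

# Repeat each file path according to its class (placeholder paths, category first).
example_files = ['class_a/img_0.jpg', 'class_b/img_0.jpg']
balanced_files = [path
                  for path in example_files
                  for _ in range(factors[path.split('/')[0]])]

print(factors)
print(len(balanced_files))
```

Repeating paths only helps if the augmentation is random (for example `tf.image.random_flip_left_right`). With the deterministic flips used in the demo, every repeated copy would come out identical.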
Please note, for training you will need to uncomment the two lines in preprocess so it returns an array of floats and an integer.
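As a rough sketch of what a training-ready mapping could look like, the function below returns a float image and an integer label, and augments only the selected categories. It differs from the demo in several ways, all of them my assumptions: it takes an already decoded image instead of a file path (so it runs without files on disk), it swaps the deterministic flips for the random `tf.image.random_flip_*` variants, it scales pixels by dividing by 255, and it omits the extra `expand_dims`. `CLASS_NAMES`, `AUGMENTED_CLASSES`, and `to_training_example` are made-up names.

```python
import tensorflow as tf

# Assumption: folder names double as class names, as in the demo.
CLASS_NAMES = tf.constant(['tf_astronauts', 'tf_cats'])
# Assumption: only these classes receive augmentation.
AUGMENTED_CLASSES = tf.constant(['tf_astronauts'])


def to_training_example(image, category):
    """Scale a decoded image to [0, 1] and map its category string to an integer label."""
    image = tf.cast(image, tf.float32) / 255.0
    # Augment only when the category is one of the selected classes.
    if tf.reduce_any(tf.equal(category, AUGMENTED_CLASSES)):
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_flip_up_down(image)
    # Index of the matching entry in CLASS_NAMES becomes the label.
    label = tf.argmax(tf.cast(tf.equal(CLASS_NAMES, category), tf.int32))
    return image, label


# Stand-in data: two all-white 200x200 images, one per category.
images = tf.fill((2, 200, 200, 3), 255.0)
categories = tf.constant(['tf_astronauts', 'tf_cats'])

dataset = (tf.data.Dataset.from_tensor_slices((images, categories))
           .map(to_training_example)
           .batch(2))
batch_images, batch_labels = next(iter(dataset))
print(batch_images.shape, batch_labels.numpy())
```

In a file-based pipeline, the same function could be called at the end of `preprocess`, after the category has been read from the path and the image has been decoded and resized.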
Source: https://stackoverflow.com/questions/64374691/apply-different-data-augmentation-to-part-of-the-train-set-based-on-the-category