通过实时数据增强生成批量张量图像数据。
用法
tf.keras.preprocessing.image.ImageDataGenerator(
featurewise_center=False, samplewise_center=False,
featurewise_std_normalization=False, samplewise_std_normalization=False,
zca_whitening=False, zca_epsilon=1e-06, rotation_range=0, width_shift_range=0.0,
height_shift_range=0.0, brightness_range=None, shear_range=0.0, zoom_range=0.0,
channel_shift_range=0.0, fill_mode='nearest', cval=0.0,
horizontal_flip=False, vertical_flip=False, rescale=None,
preprocessing_function=None, data_format=None, validation_split=0.0, dtype=None
)
参数
-
featurewise_center
布尔值。在数据集feature-wise 上将输入均值设置为 0。 -
samplewise_center
布尔值。将每个样本均值设置为 0。 -
featurewise_std_normalization
布尔值。将输入除以数据集的 std,feature-wise。 -
samplewise_std_normalization
布尔值。将每个输入除以其标准。 -
zca_epsilon
用于 ZCA 美白的 epsilon。默认值为 1e-6。 -
zca_whitening
布尔值。应用 ZCA 美白。 -
rotation_range
Int. 随机旋转的度数范围。 -
width_shift_range
浮点数,一维 array-like 或 int- 浮点数:总宽度的分数,如果 < 1,或像素,如果 >= 1。
- 一维array-like:数组中的随机元素。
- int:间隔
(-width_shift_range, +width_shift_range)
的整数像素数 width_shift_range=2
可能的值是整数[-1, 0, +1]
,与width_shift_range=[-1, 0, +1]
相同,而width_shift_range=1.0
可能的值是区间 [-1.0, +1.0) 中的浮点数。
-
height_shift_range
浮点数,一维 array-like 或 int- 浮点数:总高度的分数,如果 < 1,或像素,如果 >= 1。
- 一维array-like:数组中的随机元素。
- int:间隔
(-height_shift_range, +height_shift_range)
的整数像素数 height_shift_range=2
可能的值是整数[-1, 0, +1]
,与height_shift_range=[-1, 0, +1]
相同,而height_shift_range=1.0
可能的值是区间 [-1.0, +1.0) 中的浮点数。
-
brightness_range
元组或两个浮点数的列表。从中选择亮度偏移值的范围。 -
shear_range
浮点数。剪切强度(逆时针方向的剪切角,以度为单位) -
zoom_range
浮点数或[下,上]。随机缩放范围。如果是浮点数,[lower, upper] = [1-zoom_range, 1+zoom_range]
。 -
channel_shift_range
浮点数。随机通道移位的范围。 -
fill_mode
{"constant"、"nearest"、"reflect" 或 "wrap"} 之一。默认为'nearest'。根据给定的模式填充输入边界之外的点:- 'constant':kkkkkkkk|abcd|kkkkkkkk (cval=k)
- 'nearest': aaaaaaaa|abcd|dddddddd
- 'reflect': abcddcba|abcd|dcbaabcd
- 'wrap':abcdabcd|abcd|abcdabcd
-
cval
浮点数或 Int.fill_mode = "constant"
时用于边界外点的值。 -
horizontal_flip
布尔值。水平随机翻转输入。 -
vertical_flip
布尔值。垂直随机翻转输入。 -
rescale
重新缩放因子。默认为无。如果 None 或 0,则不应用重新缩放,否则我们将数据乘以提供的值(在应用所有其他转换之后)。 -
preprocessing_function
将应用于每个输入的函数。该函数将在图像调整大小和增强后运行。该函数应采用一个参数:一张图像(等级为 3 的 Numpy 张量),并且应输出具有相同形状的 Numpy 张量。 -
data_format
图像数据格式,"channels_first" 或 "channels_last"。 "channels_last" 模式意味着图像应该具有形状(samples, height, width, channels)
,"channels_first" 模式意味着图像应该具有形状(samples, channels, height, width)
。它默认为您的 Keras 配置文件中的image_data_format
值~/.keras/keras.json
。如果您从未设置它,那么它将是"channels_last"。 -
validation_split
浮点数。保留用于验证的图像分数(严格在 0 和 1 之间)。 -
dtype
用于生成的数组的 Dtype。
抛出
-
ValueError
如果参数的值data_format
不是"channels_last"
或"channels_first"
。 -
ValueError
如果参数的值,validation_split
> 1 或validation_split
数据将被循环(分批)。
例子:
使用 .flow(x, y)
的示例:
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
y_train = utils.to_categorical(y_train, num_classes)
y_test = utils.to_categorical(y_test, num_classes)
datagen = ImageDataGenerator(
featurewise_center=True,
featurewise_std_normalization=True,
rotation_range=20,
width_shift_range=0.2,
height_shift_range=0.2,
horizontal_flip=True,
validation_split=0.2)
# compute quantities required for featurewise normalization
# (std, mean, and principal components if ZCA whitening is applied)
datagen.fit(x_train)
# fits the model on batches with real-time data augmentation:
model.fit(datagen.flow(x_train, y_train, batch_size=32,
subset='training'),
validation_data=datagen.flow(x_train, y_train,
batch_size=8, subset='validation'),
steps_per_epoch=len(x_train) / 32, epochs=epochs)
# here's a more "manual" example
for e in range(epochs):
print('Epoch', e)
batches = 0
for x_batch, y_batch in datagen.flow(x_train, y_train, batch_size=32):
model.fit(x_batch, y_batch)
batches += 1
if batches >= len(x_train) / 32:
# we need to break the loop by hand because
# the generator loops indefinitely
break
使用 .flow_from_directory(directory)
的示例:
train_datagen = ImageDataGenerator(
rescale=1./255,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True)
test_datagen = ImageDataGenerator(rescale=1./255)
train_generator = train_datagen.flow_from_directory(
'data/train',
target_size=(150, 150),
batch_size=32,
class_mode='binary')
validation_generator = test_datagen.flow_from_directory(
'data/validation',
target_size=(150, 150),
batch_size=32,
class_mode='binary')
model.fit(
train_generator,
steps_per_epoch=2000,
epochs=50,
validation_data=validation_generator,
validation_steps=800)
一起转换图像和蒙版的示例。
# we create two instances with the same arguments
data_gen_args = dict(featurewise_center=True,
featurewise_std_normalization=True,
rotation_range=90,
width_shift_range=0.1,
height_shift_range=0.1,
zoom_range=0.2)
image_datagen = ImageDataGenerator(**data_gen_args)
mask_datagen = ImageDataGenerator(**data_gen_args)
# Provide the same seed and keyword arguments to the fit and flow methods
seed = 1
image_datagen.fit(images, augment=True, seed=seed)
mask_datagen.fit(masks, augment=True, seed=seed)
image_generator = image_datagen.flow_from_directory(
'data/images',
class_mode=None,
seed=seed)
mask_generator = mask_datagen.flow_from_directory(
'data/masks',
class_mode=None,
seed=seed)
# combine generators into one which yields image and masks
train_generator = zip(image_generator, mask_generator)
model.fit(
train_generator,
steps_per_epoch=2000,
epochs=50)
相关用法
- Python tf.keras.preprocessing.image.smart_resize用法及代码示例
- Python tf.keras.preprocessing.image.DirectoryIterator.__len__用法及代码示例
- Python tf.keras.preprocessing.sequence.TimeseriesGenerator用法及代码示例
- Python tf.keras.preprocessing.sequence.pad_sequences用法及代码示例
- Python tf.keras.preprocessing.sequence.make_sampling_table用法及代码示例
- Python tf.keras.preprocessing.text.text_to_word_sequence用法及代码示例
- Python tf.keras.applications.inception_resnet_v2.preprocess_input用法及代码示例
- Python tf.keras.metrics.Mean.merge_state用法及代码示例
- Python tf.keras.layers.InputLayer用法及代码示例
- Python tf.keras.callbacks.ReduceLROnPlateau用法及代码示例
- Python tf.keras.layers.serialize用法及代码示例
- Python tf.keras.metrics.Hinge用法及代码示例
- Python tf.keras.experimental.WideDeepModel.compute_loss用法及代码示例
- Python tf.keras.metrics.SparseCategoricalAccuracy.merge_state用法及代码示例
- Python tf.keras.metrics.RootMeanSquaredError用法及代码示例
- Python tf.keras.applications.resnet50.preprocess_input用法及代码示例
- Python tf.keras.metrics.SparseCategoricalCrossentropy.merge_state用法及代码示例
- Python tf.keras.metrics.sparse_categorical_accuracy用法及代码示例
- Python tf.keras.layers.Dropout用法及代码示例
- Python tf.keras.activations.softplus用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.keras.preprocessing.image.ImageDataGenerator。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。