本文整理汇总了Python中nets.nasnet.nasnet_utils.get_channel_dim方法的典型用法代码示例。如果您正苦于以下问题：Python nasnet_utils.get_channel_dim方法的具体用法？Python nasnet_utils.get_channel_dim怎么用？Python nasnet_utils.get_channel_dim使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块nets.nasnet.nasnet_utils的用法示例。
在下文中一共展示了nasnet_utils.get_channel_dim方法的7个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: testGetChannelDim
# 需要导入模块: from nets.nasnet import nasnet_utils [as 别名]
# 或者: from nets.nasnet.nasnet_utils import get_channel_dim [as 别名]
def testGetChannelDim(self):
    """Check get_channel_dim for both supported data-format strings.

    For the fixed shape [10, 20, 30, 40] the expected result is shape[3]
    under 'NHWC' and shape[1] under 'NCHW'.
    """
    shape = [10, 20, 30, 40]
    # Each entry pairs a data format with the index of the expected value
    # inside `shape`.
    for data_format, channel_axis in (('NHWC', 3), ('NCHW', 1)):
        expected = shape[channel_axis]
        actual = nasnet_utils.get_channel_dim(shape, data_format)
        self.assertEqual(actual, expected)
示例2: build_nasnet_cifar
# 需要导入模块: from nets.nasnet import nasnet_utils [as 别名]
# 或者: from nets.nasnet.nasnet_utils import get_channel_dim [as 别名]
def build_nasnet_cifar(
        images, num_classes, is_training=True):
    """Build the NASNet model for the Cifar dataset.

    Args:
        images: Input image tensor. It is transposed with permutation
            [0, 3, 1, 2] when the config's data format is 'NCHW'
            (presumably it arrives as NHWC -- confirm against callers).
        num_classes: Forwarded unchanged to `_build_nasnet_base`.
        is_training: Passed to `_cifar_config`, to the dropout / drop_path /
            batch_norm arg scope, and to `_build_nasnet_base`.

    Returns:
        Whatever `_build_nasnet_base` returns for the 'cifar' stem type.
    """
    hparams = _cifar_config(is_training=is_training)
    layout = hparams.data_format

    if tf.test.is_gpu_available() and layout == 'NHWC':
        tf.logging.info('A GPU is available on the machine, consider using NCHW '
                        'data format for increased speed on GPU.')
    if layout == 'NCHW':
        images = tf.transpose(images, [0, 3, 1, 2])

    # Total cell count: the configured cells plus 2 for the reduction cells.
    total_num_cells = hparams.num_cells + 2

    # The normal and reduction cells take the same positional arguments.
    cell_args = (
        hparams.num_conv_filters,
        hparams.drop_path_keep_prob,
        total_num_cells,
        hparams.total_training_steps,
    )
    normal_cell = nasnet_utils.NasNetANormalCell(*cell_args)
    reduction_cell = nasnet_utils.NasNetAReductionCell(*cell_args)

    # Ops that receive `is_training` through the outer arg scope.
    training_ops = [slim.dropout, nasnet_utils.drop_path, slim.batch_norm]
    # Ops that receive `data_format` through the inner arg scope.
    layout_ops = [
        slim.avg_pool2d,
        slim.max_pool2d,
        slim.conv2d,
        slim.batch_norm,
        slim.separable_conv2d,
        nasnet_utils.factorized_reduction,
        nasnet_utils.global_avg_pool,
        nasnet_utils.get_channel_index,
        nasnet_utils.get_channel_dim,
    ]
    with arg_scope(training_ops, is_training=is_training):
        with arg_scope(layout_ops, data_format=layout):
            return _build_nasnet_base(
                images,
                normal_cell=normal_cell,
                reduction_cell=reduction_cell,
                num_classes=num_classes,
                hparams=hparams,
                is_training=is_training,
                stem_type='cifar')
示例3: build_pnasnet_large
# 需要导入模块: from nets.nasnet import nasnet_utils [as 别名]
# 或者: from nets.nasnet.nasnet_utils import get_channel_dim [as 别名]
def build_pnasnet_large(images,
                        num_classes,
                        is_training=True,
                        final_endpoint=None,
                        config=None):
    """Build the PNASNet Large model for the ImageNet dataset.

    Args:
        images: Input image tensor. It is transposed with permutation
            [0, 3, 1, 2] when the hparams' data format is 'NCHW'
            (presumably it arrives as NHWC -- confirm against callers).
        num_classes: Forwarded unchanged to `_build_pnasnet_base`.
        is_training: Passed to `nasnet._update_hparams`, to the dropout /
            drop_path / batch_norm arg scope, and to `_build_pnasnet_base`.
        final_endpoint: Forwarded unchanged to `_build_pnasnet_base`.
        config: Optional hparams object. A truthy value is deep-copied and
            used; otherwise `large_imagenet_config()` supplies the hparams.

    Returns:
        Whatever `_build_pnasnet_base` returns.
    """
    if config:
        # Work on a copy so the caller's object is not the one handed to
        # `_update_hparams`.
        hparams = copy.deepcopy(config)
    else:
        hparams = large_imagenet_config()
    # pylint: disable=protected-access
    nasnet._update_hparams(hparams, is_training)
    # pylint: enable=protected-access

    layout = hparams.data_format
    if tf.test.is_gpu_available() and layout == 'NHWC':
        tf.logging.info('A GPU is available on the machine, consider using NCHW '
                        'data format for increased speed on GPU.')
    if layout == 'NCHW':
        images = tf.transpose(images, [0, 3, 1, 2])

    # PNAS makes no distinction between reduction and normal cells, so the
    # total is the number of normal cells plus the stem cells (two by
    # default).
    total_num_cells = hparams.num_cells + 2
    normal_cell = PNasNetNormalCell(
        hparams.num_conv_filters,
        hparams.drop_path_keep_prob,
        total_num_cells,
        hparams.total_training_steps)

    # Ops that receive `is_training` through the outer arg scope.
    training_ops = [slim.dropout, nasnet_utils.drop_path, slim.batch_norm]
    # Ops that receive `data_format` through the inner arg scope.
    layout_ops = [
        slim.avg_pool2d,
        slim.max_pool2d,
        slim.conv2d,
        slim.batch_norm,
        slim.separable_conv2d,
        nasnet_utils.factorized_reduction,
        nasnet_utils.global_avg_pool,
        nasnet_utils.get_channel_index,
        nasnet_utils.get_channel_dim,
    ]
    with arg_scope(training_ops, is_training=is_training):
        with arg_scope(layout_ops, data_format=layout):
            return _build_pnasnet_base(
                images,
                normal_cell=normal_cell,
                num_classes=num_classes,
                hparams=hparams,
                is_training=is_training,
                final_endpoint=final_endpoint)
示例4: build_pnasnet_mobile
# 需要导入模块: from nets.nasnet import nasnet_utils [as 别名]
# 或者: from nets.nasnet.nasnet_utils import get_channel_dim [as 别名]
def build_pnasnet_mobile(images,
                         num_classes,
                         is_training=True,
                         final_endpoint=None,
                         config=None):
    """Build the PNASNet Mobile model for the ImageNet dataset.

    Args:
        images: Input image tensor. It is transposed with permutation
            [0, 3, 1, 2] when the hparams' data format is 'NCHW'
            (presumably it arrives as NHWC -- confirm against callers).
        num_classes: Forwarded unchanged to `_build_pnasnet_base`.
        is_training: Passed to `nasnet._update_hparams`, to the dropout /
            drop_path / batch_norm arg scope, and to `_build_pnasnet_base`.
        final_endpoint: Forwarded unchanged to `_build_pnasnet_base`.
        config: Optional hparams object. A truthy value is deep-copied and
            used; otherwise `mobile_imagenet_config()` supplies the hparams.

    Returns:
        Whatever `_build_pnasnet_base` returns.
    """
    if config:
        # Work on a copy so the caller's object is not the one handed to
        # `_update_hparams`.
        hparams = copy.deepcopy(config)
    else:
        hparams = mobile_imagenet_config()
    # pylint: disable=protected-access
    nasnet._update_hparams(hparams, is_training)
    # pylint: enable=protected-access

    layout = hparams.data_format
    if tf.test.is_gpu_available() and layout == 'NHWC':
        tf.logging.info('A GPU is available on the machine, consider using NCHW '
                        'data format for increased speed on GPU.')
    if layout == 'NCHW':
        images = tf.transpose(images, [0, 3, 1, 2])

    # PNAS makes no distinction between reduction and normal cells, so the
    # total is the number of normal cells plus the stem cells (two by
    # default).
    total_num_cells = hparams.num_cells + 2
    normal_cell = PNasNetNormalCell(
        hparams.num_conv_filters,
        hparams.drop_path_keep_prob,
        total_num_cells,
        hparams.total_training_steps)

    # Ops that receive `is_training` through the outer arg scope.
    training_ops = [slim.dropout, nasnet_utils.drop_path, slim.batch_norm]
    # Ops that receive `data_format` through the inner arg scope.
    layout_ops = [
        slim.avg_pool2d,
        slim.max_pool2d,
        slim.conv2d,
        slim.batch_norm,
        slim.separable_conv2d,
        nasnet_utils.factorized_reduction,
        nasnet_utils.global_avg_pool,
        nasnet_utils.get_channel_index,
        nasnet_utils.get_channel_dim,
    ]
    with arg_scope(training_ops, is_training=is_training):
        with arg_scope(layout_ops, data_format=layout):
            return _build_pnasnet_base(
                images,
                normal_cell=normal_cell,
                num_classes=num_classes,
                hparams=hparams,
                is_training=is_training,
                final_endpoint=final_endpoint)
示例5: build_nasnet_cifar
# 需要导入模块: from nets.nasnet import nasnet_utils [as 别名]
# 或者: from nets.nasnet.nasnet_utils import get_channel_dim [as 别名]
def build_nasnet_cifar(images, num_classes,
                       is_training=True,
                       config=None,
                       current_step=None):
    """Build the NASNet model for the Cifar dataset.

    Args:
        images: Input image tensor. It is transposed with permutation
            [0, 3, 1, 2] when the hparams' data format is 'NCHW'
            (presumably it arrives as NHWC -- confirm against callers).
        num_classes: Forwarded unchanged to `_build_nasnet_base`.
        is_training: Passed to `_update_hparams`, to the dropout /
            drop_path / batch_norm arg scope, and to `_build_nasnet_base`.
        config: Optional hparams object. When None, `cifar_config()` is
            used; otherwise a deep copy of `config` is used.
        current_step: Forwarded unchanged to `_build_nasnet_base`.

    Returns:
        Whatever `_build_nasnet_base` returns for the 'cifar' stem type.
    """
    if config is None:
        hparams = cifar_config()
    else:
        # Work on a copy so the caller's object is not the one handed to
        # `_update_hparams`.
        hparams = copy.deepcopy(config)
    _update_hparams(hparams, is_training)

    layout = hparams.data_format
    if tf.test.is_gpu_available() and layout == 'NHWC':
        tf.logging.info('A GPU is available on the machine, consider using NCHW '
                        'data format for increased speed on GPU.')
    if layout == 'NCHW':
        images = tf.transpose(images, [0, 3, 1, 2])

    # Total cell count: the configured cells plus 2 for the reduction cells.
    total_num_cells = hparams.num_cells + 2

    # The normal and reduction cells take the same positional arguments.
    cell_args = (
        hparams.num_conv_filters,
        hparams.drop_path_keep_prob,
        total_num_cells,
        hparams.total_training_steps,
    )
    normal_cell = nasnet_utils.NasNetANormalCell(*cell_args)
    reduction_cell = nasnet_utils.NasNetAReductionCell(*cell_args)

    # Ops that receive `is_training` through the outer arg scope.
    training_ops = [slim.dropout, nasnet_utils.drop_path, slim.batch_norm]
    # Ops that receive `data_format` through the inner arg scope.
    layout_ops = [
        slim.avg_pool2d,
        slim.max_pool2d,
        slim.conv2d,
        slim.batch_norm,
        slim.separable_conv2d,
        nasnet_utils.factorized_reduction,
        nasnet_utils.global_avg_pool,
        nasnet_utils.get_channel_index,
        nasnet_utils.get_channel_dim,
    ]
    with arg_scope(training_ops, is_training=is_training):
        with arg_scope(layout_ops, data_format=layout):
            return _build_nasnet_base(
                images,
                normal_cell=normal_cell,
                reduction_cell=reduction_cell,
                num_classes=num_classes,
                hparams=hparams,
                is_training=is_training,
                stem_type='cifar',
                current_step=current_step)
示例6: build_nasnet_mobile
# 需要导入模块: from nets.nasnet import nasnet_utils [as 别名]
# 或者: from nets.nasnet.nasnet_utils import get_channel_dim [as 别名]
def build_nasnet_mobile(images, num_classes,
                        is_training=True,
                        final_endpoint=None,
                        config=None,
                        current_step=None):
    """Build the NASNet Mobile model for the ImageNet dataset.

    Args:
        images: Input image tensor. It is transposed with permutation
            [0, 3, 1, 2] when the hparams' data format is 'NCHW'
            (presumably it arrives as NHWC -- confirm against callers).
        num_classes: Forwarded unchanged to `_build_nasnet_base`.
        is_training: Passed to `_update_hparams`, to the dropout /
            drop_path / batch_norm arg scope, and to `_build_nasnet_base`.
        final_endpoint: Forwarded unchanged to `_build_nasnet_base`.
        config: Optional hparams object. When None,
            `mobile_imagenet_config()` is used; otherwise a deep copy of
            `config` is used.
        current_step: Forwarded unchanged to `_build_nasnet_base`.

    Returns:
        Whatever `_build_nasnet_base` returns for the 'imagenet' stem type.
    """
    if config is None:
        hparams = mobile_imagenet_config()
    else:
        # Work on a copy so the caller's object is not the one handed to
        # `_update_hparams`.
        hparams = copy.deepcopy(config)
    _update_hparams(hparams, is_training)

    layout = hparams.data_format
    if tf.test.is_gpu_available() and layout == 'NHWC':
        tf.logging.info('A GPU is available on the machine, consider using NCHW '
                        'data format for increased speed on GPU.')
    if layout == 'NCHW':
        images = tf.transpose(images, [0, 3, 1, 2])

    # Total cell count: the configured cells, plus 2 for the reduction cells,
    # plus an additional 2 for the ImageNet stem cells.
    total_num_cells = (hparams.num_cells + 2) + 2

    # The normal and reduction cells take the same positional arguments.
    cell_args = (
        hparams.num_conv_filters,
        hparams.drop_path_keep_prob,
        total_num_cells,
        hparams.total_training_steps,
    )
    normal_cell = nasnet_utils.NasNetANormalCell(*cell_args)
    reduction_cell = nasnet_utils.NasNetAReductionCell(*cell_args)

    # Ops that receive `is_training` through the outer arg scope.
    training_ops = [slim.dropout, nasnet_utils.drop_path, slim.batch_norm]
    # Ops that receive `data_format` through the inner arg scope.
    layout_ops = [
        slim.avg_pool2d,
        slim.max_pool2d,
        slim.conv2d,
        slim.batch_norm,
        slim.separable_conv2d,
        nasnet_utils.factorized_reduction,
        nasnet_utils.global_avg_pool,
        nasnet_utils.get_channel_index,
        nasnet_utils.get_channel_dim,
    ]
    with arg_scope(training_ops, is_training=is_training):
        with arg_scope(layout_ops, data_format=layout):
            return _build_nasnet_base(
                images,
                normal_cell=normal_cell,
                reduction_cell=reduction_cell,
                num_classes=num_classes,
                hparams=hparams,
                is_training=is_training,
                stem_type='imagenet',
                final_endpoint=final_endpoint,
                current_step=current_step)
示例7: build_nasnet_mobile
# 需要导入模块: from nets.nasnet import nasnet_utils [as 别名]
# 或者: from nets.nasnet.nasnet_utils import get_channel_dim [as 别名]
def build_nasnet_mobile(images, num_classes,
                        is_training=True,
                        final_endpoint=None):
    """Build the NASNet Mobile model for the ImageNet dataset.

    Args:
        images: Input image tensor. It is transposed with permutation
            [0, 3, 1, 2] when the config's data format is 'NCHW'
            (presumably it arrives as NHWC -- confirm against callers).
        num_classes: Forwarded unchanged to `_build_nasnet_base`.
        is_training: Passed to the dropout / drop_path / batch_norm arg
            scope and to `_build_nasnet_base`.
        final_endpoint: Forwarded unchanged to `_build_nasnet_base`.

    Returns:
        Whatever `_build_nasnet_base` returns for the 'imagenet' stem type.
    """
    hparams = _mobile_imagenet_config()
    layout = hparams.data_format

    if tf.test.is_gpu_available() and layout == 'NHWC':
        tf.logging.info('A GPU is available on the machine, consider using NCHW '
                        'data format for increased speed on GPU.')
    if layout == 'NCHW':
        images = tf.transpose(images, [0, 3, 1, 2])

    # Total cell count: the configured cells, plus 2 for the reduction cells,
    # plus an additional 2 for the ImageNet stem cells.
    total_num_cells = (hparams.num_cells + 2) + 2

    # The normal and reduction cells take the same positional arguments.
    cell_args = (
        hparams.num_conv_filters,
        hparams.drop_path_keep_prob,
        total_num_cells,
        hparams.total_training_steps,
    )
    normal_cell = nasnet_utils.NasNetANormalCell(*cell_args)
    reduction_cell = nasnet_utils.NasNetAReductionCell(*cell_args)

    # Ops that receive `is_training` through the outer arg scope.
    training_ops = [slim.dropout, nasnet_utils.drop_path, slim.batch_norm]
    # Ops that receive `data_format` through the inner arg scope.
    layout_ops = [
        slim.avg_pool2d,
        slim.max_pool2d,
        slim.conv2d,
        slim.batch_norm,
        slim.separable_conv2d,
        nasnet_utils.factorized_reduction,
        nasnet_utils.global_avg_pool,
        nasnet_utils.get_channel_index,
        nasnet_utils.get_channel_dim,
    ]
    with arg_scope(training_ops, is_training=is_training):
        with arg_scope(layout_ops, data_format=layout):
            return _build_nasnet_base(
                images,
                normal_cell=normal_cell,
                reduction_cell=reduction_cell,
                num_classes=num_classes,
                hparams=hparams,
                is_training=is_training,
                stem_type='imagenet',
                final_endpoint=final_endpoint)