本文整理匯總了Python中utils.count_parameters_in_MB方法的典型用法代碼示例。如果您正苦於以下問題：Python utils.count_parameters_in_MB方法的具體用法？Python utils.count_parameters_in_MB怎麼用？Python utils.count_parameters_in_MB使用的例子？那麼，這裡精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在類utils的用法示例。
在下文中一共展示了utils.count_parameters_in_MB方法的15個代碼示例，這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚，您的評價將有助於系統推薦出更棒的Python代碼示例。
示例1: build_cifar10
# 需要導入模塊: import utils [as 別名]
# 或者: from utils import count_parameters_in_MB [as 別名]
def build_cifar10(model_state_dict=None, optimizer_state_dict=None, **kwargs):
    """Assemble CIFAR-10 loaders, the weight-sharing child model and its optimiser.

    Both loaders read the CIFAR-10 *training* split (with different
    transforms). A shuffled index list is cut at ``ratio``: the first part is
    sampled for training, the remainder for validation.

    Args:
        model_state_dict: Optional state dict loaded into the model.
        optimizer_state_dict: Optional state dict loaded into the optimizer.
        **kwargs: Must contain ``epoch`` (handed to the scheduler as
            ``last_epoch``) and ``ratio`` (fraction of the samples used for
            training).

    Returns:
        Tuple ``(train_queue, valid_queue, model, train_criterion,
        eval_criterion, optimizer, scheduler)``.

    Raises:
        KeyError: If ``epoch`` or ``ratio`` is missing from ``kwargs``.
    """
    last_epoch = kwargs.pop('epoch')
    train_fraction = kwargs.pop('ratio')

    transform_train, transform_valid = utils._data_transforms_cifar10(args.child_cutout_size)
    train_set = dset.CIFAR10(root=args.data, train=True, download=True, transform=transform_train)
    valid_set = dset.CIFAR10(root=args.data, train=True, download=True, transform=transform_valid)

    total = len(train_set)
    assert total == len(valid_set)

    # Shuffle once so the two samplers draw from disjoint index ranges.
    order = list(range(total))
    cut = int(np.floor(train_fraction * total))
    np.random.shuffle(order)

    subset_sampler = torch.utils.data.sampler.SubsetRandomSampler
    train_queue = torch.utils.data.DataLoader(
        train_set,
        batch_size=args.child_batch_size,
        sampler=subset_sampler(order[:cut]),
        pin_memory=True,
        num_workers=16,
    )
    valid_queue = torch.utils.data.DataLoader(
        valid_set,
        batch_size=args.child_eval_batch_size,
        sampler=subset_sampler(order[cut:total]),
        pin_memory=True,
        num_workers=16,
    )

    model = NASWSNetworkCIFAR(
        args, 10, args.child_layers, args.child_nodes, args.child_channels,
        args.child_keep_prob, args.child_drop_path_keep_prob,
        args.child_use_aux_head, args.steps,
    )
    model = model.cuda()
    train_criterion = nn.CrossEntropyLoss().cuda()
    eval_criterion = nn.CrossEntropyLoss().cuda()
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=args.child_lr_max,
        momentum=0.9,
        weight_decay=args.child_l2_reg,
    )

    # Restore previous state, if any, before the scheduler is created.
    if model_state_dict is not None:
        model.load_state_dict(model_state_dict)
    if optimizer_state_dict is not None:
        optimizer.load_state_dict(optimizer_state_dict)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, args.child_epochs, eta_min=args.child_lr_min, last_epoch=last_epoch)
    return train_queue, valid_queue, model, train_criterion, eval_criterion, optimizer, scheduler
示例2: build_cifar100
# 需要導入模塊: import utils [as 別名]
# 或者: from utils import count_parameters_in_MB [as 別名]
def build_cifar100(model_state_dict=None, optimizer_state_dict=None, **kwargs):
    """Assemble CIFAR-100 loaders, the weight-sharing child model and its optimiser.

    Both loaders read the CIFAR-100 *training* split (with different
    transforms). A shuffled index list is cut at ``ratio``: the first part is
    sampled for training, the remainder for validation.

    Args:
        model_state_dict: Optional state dict loaded into the model.
        optimizer_state_dict: Optional state dict loaded into the optimizer.
        **kwargs: Must contain ``epoch`` (handed to the scheduler as
            ``last_epoch``) and ``ratio`` (fraction of the samples used for
            training).

    Returns:
        Tuple ``(train_queue, valid_queue, model, train_criterion,
        eval_criterion, optimizer, scheduler)``.

    Raises:
        KeyError: If ``epoch`` or ``ratio`` is missing from ``kwargs``.
    """
    last_epoch = kwargs.pop('epoch')
    train_fraction = kwargs.pop('ratio')

    # NOTE(review): this reads args.cutout_size while every other setting in
    # this function is args.child_* -- confirm that attribute is the intended one.
    transform_train, transform_valid = utils._data_transforms_cifar10(args.cutout_size)
    train_set = dset.CIFAR100(root=args.data, train=True, download=True, transform=transform_train)
    valid_set = dset.CIFAR100(root=args.data, train=True, download=True, transform=transform_valid)

    total = len(train_set)
    assert total == len(valid_set)

    # Shuffle once so the two samplers draw from disjoint index ranges.
    order = list(range(total))
    cut = int(np.floor(train_fraction * total))
    np.random.shuffle(order)

    subset_sampler = torch.utils.data.sampler.SubsetRandomSampler
    train_queue = torch.utils.data.DataLoader(
        train_set,
        batch_size=args.child_batch_size,
        sampler=subset_sampler(order[:cut]),
        pin_memory=True,
        num_workers=16,
    )
    valid_queue = torch.utils.data.DataLoader(
        valid_set,
        batch_size=args.child_eval_batch_size,
        sampler=subset_sampler(order[cut:total]),
        pin_memory=True,
        num_workers=16,
    )

    model = NASWSNetworkCIFAR(
        args, 100, args.child_layers, args.child_nodes, args.child_channels,
        args.child_keep_prob, args.child_drop_path_keep_prob,
        args.child_use_aux_head, args.steps,
    )
    model = model.cuda()
    train_criterion = nn.CrossEntropyLoss().cuda()
    eval_criterion = nn.CrossEntropyLoss().cuda()
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=args.child_lr_max,
        momentum=0.9,
        weight_decay=args.child_l2_reg,
    )

    # Restore previous state, if any, before the scheduler is created.
    if model_state_dict is not None:
        model.load_state_dict(model_state_dict)
    if optimizer_state_dict is not None:
        optimizer.load_state_dict(optimizer_state_dict)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, args.child_epochs, eta_min=args.child_lr_min, last_epoch=last_epoch)
    return train_queue, valid_queue, model, train_criterion, eval_criterion, optimizer, scheduler
示例3: build_cifar10
# 需要導入模塊: import utils [as 別名]
# 或者: from utils import count_parameters_in_MB [as 別名]
def build_cifar10(model_state_dict, optimizer_state_dict, **kwargs):
    """Build CIFAR-10 train/test loaders, the fixed-architecture model and optimiser.

    Training uses the CIFAR-10 training split (shuffled); validation uses the
    CIFAR-10 test split (not shuffled).

    Args:
        model_state_dict: State dict loaded into the model, or ``None``.
        optimizer_state_dict: State dict loaded into the optimizer, or ``None``.
        **kwargs: Must contain ``epoch`` (handed to the scheduler as
            ``last_epoch``).

    Returns:
        Tuple ``(train_queue, valid_queue, model, train_criterion,
        eval_criterion, optimizer, scheduler)``.

    Raises:
        KeyError: If ``epoch`` is missing from ``kwargs``.
    """
    last_epoch = kwargs.pop('epoch')

    transform_train, transform_test = utils._data_transforms_cifar10(args.cutout_size)
    train_set = dset.CIFAR10(root=args.data, train=True, download=True, transform=transform_train)
    test_set = dset.CIFAR10(root=args.data, train=False, download=True, transform=transform_test)

    train_queue = torch.utils.data.DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=16,
    )
    valid_queue = torch.utils.data.DataLoader(
        test_set,
        batch_size=args.eval_batch_size,
        shuffle=False,
        pin_memory=True,
        num_workers=16,
    )

    model = NASNetworkCIFAR(
        args, 10, args.layers, args.nodes, args.channels, args.keep_prob,
        args.drop_path_keep_prob, args.use_aux_head, args.steps, args.arch,
    )
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))
    logging.info("multi adds = %fM", model.multi_adds / 1000000)

    # The checkpoint is loaded before any DataParallel wrapping takes place.
    if model_state_dict is not None:
        model.load_state_dict(model_state_dict)

    gpu_count = torch.cuda.device_count()
    if gpu_count > 1:
        logging.info("Use %d %s", gpu_count, "GPUs !")
        model = nn.DataParallel(model)
    model = model.cuda()

    train_criterion = nn.CrossEntropyLoss().cuda()
    eval_criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=args.lr_max,
        momentum=0.9,
        weight_decay=args.l2_reg,
    )
    if optimizer_state_dict is not None:
        optimizer.load_state_dict(optimizer_state_dict)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs), eta_min=args.lr_min, last_epoch=last_epoch)
    return train_queue, valid_queue, model, train_criterion, eval_criterion, optimizer, scheduler
示例4: build_cifar100
# 需要導入模塊: import utils [as 別名]
# 或者: from utils import count_parameters_in_MB [as 別名]
def build_cifar100(model_state_dict, optimizer_state_dict, **kwargs):
    """Build CIFAR-100 train/test loaders, the fixed-architecture model and optimiser.

    Training uses the CIFAR-100 training split (shuffled); validation uses the
    CIFAR-100 test split (not shuffled), so the labels seen during evaluation
    match the 100-way classifier that is trained.

    Args:
        model_state_dict: State dict loaded into the model, or ``None``.
        optimizer_state_dict: State dict loaded into the optimizer, or ``None``.
        **kwargs: Must contain ``epoch`` (handed to the scheduler as
            ``last_epoch``).

    Returns:
        Tuple ``(train_queue, valid_queue, model, train_criterion,
        eval_criterion, optimizer, scheduler)``.

    Raises:
        KeyError: If ``epoch`` is missing from ``kwargs``.
    """
    epoch = kwargs.pop('epoch')

    train_transform, valid_transform = utils._data_transforms_cifar10(args.cutout_size)
    train_data = dset.CIFAR100(root=args.data, train=True, download=True, transform=train_transform)
    # Validation must come from the same dataset as training: evaluating the
    # 100-class model on CIFAR-10 labels would produce meaningless accuracy.
    valid_data = dset.CIFAR100(root=args.data, train=False, download=True, transform=valid_transform)

    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=16)
    valid_queue = torch.utils.data.DataLoader(
        valid_data, batch_size=args.eval_batch_size, shuffle=False, pin_memory=True, num_workers=16)

    model = NASNetworkCIFAR(
        args, 100, args.layers, args.nodes, args.channels, args.keep_prob,
        args.drop_path_keep_prob, args.use_aux_head, args.steps, args.arch,
    )
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))
    logging.info("multi adds = %fM", model.multi_adds / 1000000)

    # The checkpoint is loaded before any DataParallel wrapping takes place.
    if model_state_dict is not None:
        model.load_state_dict(model_state_dict)

    if torch.cuda.device_count() > 1:
        logging.info("Use %d %s", torch.cuda.device_count(), "GPUs !")
        model = nn.DataParallel(model)
    model = model.cuda()

    train_criterion = nn.CrossEntropyLoss().cuda()
    eval_criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(
        model.parameters(),
        args.lr_max,
        momentum=0.9,
        weight_decay=args.l2_reg,
    )
    if optimizer_state_dict is not None:
        optimizer.load_state_dict(optimizer_state_dict)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(args.epochs), args.lr_min, epoch)
    return train_queue, valid_queue, model, train_criterion, eval_criterion, optimizer, scheduler
示例5: build_cifar10
# 需要導入模塊: import utils [as 別名]
# 或者: from utils import count_parameters_in_MB [as 別名]
def build_cifar10(model_state_dict, optimizer_state_dict, **kwargs):
    """Build CIFAR-10 train/test loaders, the fixed-architecture model and optimiser.

    The transforms are requested with both ``args.cutout_size`` and
    ``args.autoaugment``. Training uses the CIFAR-10 training split
    (shuffled); validation uses the CIFAR-10 test split (not shuffled).

    Args:
        model_state_dict: State dict loaded into the model, or ``None``.
        optimizer_state_dict: State dict loaded into the optimizer, or ``None``.
        **kwargs: Must contain ``epoch`` (handed to the scheduler as
            ``last_epoch``).

    Returns:
        Tuple ``(train_queue, valid_queue, model, train_criterion,
        eval_criterion, optimizer, scheduler)``.

    Raises:
        KeyError: If ``epoch`` is missing from ``kwargs``.
    """
    last_epoch = kwargs.pop('epoch')

    transform_train, transform_test = utils._data_transforms_cifar10(args.cutout_size, args.autoaugment)
    train_set = dset.CIFAR10(root=args.data, train=True, download=True, transform=transform_train)
    test_set = dset.CIFAR10(root=args.data, train=False, download=True, transform=transform_test)

    train_queue = torch.utils.data.DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=16,
    )
    valid_queue = torch.utils.data.DataLoader(
        test_set,
        batch_size=args.eval_batch_size,
        shuffle=False,
        pin_memory=True,
        num_workers=16,
    )

    model = NASNetworkCIFAR(
        args, 10, args.layers, args.nodes, args.channels, args.keep_prob,
        args.drop_path_keep_prob, args.use_aux_head, args.steps, args.arch,
    )
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    # The checkpoint is loaded before any DataParallel wrapping takes place.
    if model_state_dict is not None:
        model.load_state_dict(model_state_dict)

    gpu_count = torch.cuda.device_count()
    if gpu_count > 1:
        logging.info("Use %d %s", gpu_count, "GPUs !")
        model = nn.DataParallel(model)
    model = model.cuda()

    train_criterion = nn.CrossEntropyLoss().cuda()
    eval_criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=args.lr_max,
        momentum=0.9,
        weight_decay=args.l2_reg,
    )
    if optimizer_state_dict is not None:
        optimizer.load_state_dict(optimizer_state_dict)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs), eta_min=args.lr_min, last_epoch=last_epoch)
    return train_queue, valid_queue, model, train_criterion, eval_criterion, optimizer, scheduler
示例6: build_cifar10
# 需要導入模塊: import utils [as 別名]
# 或者: from utils import count_parameters_in_MB [as 別名]
def build_cifar10(model_state_dict=None, optimizer_state_dict=None, **kwargs):
    """Assemble CIFAR-10 loaders, the weight-sharing child model and its optimiser.

    Both loaders read the CIFAR-10 *training* split (with different
    transforms). A shuffled index list is cut at ``ratio``: the first part is
    sampled for training, the remainder for validation.

    Args:
        model_state_dict: Optional state dict loaded into the model.
        optimizer_state_dict: Optional state dict loaded into the optimizer.
        **kwargs: Must contain ``epoch`` (handed to the scheduler as
            ``last_epoch``) and ``ratio`` (fraction of the samples used for
            training).

    Returns:
        Tuple ``(train_queue, valid_queue, model, train_criterion,
        eval_criterion, optimizer, scheduler)``.

    Raises:
        KeyError: If ``epoch`` or ``ratio`` is missing from ``kwargs``.
    """
    last_epoch = kwargs.pop('epoch')
    train_fraction = kwargs.pop('ratio')

    transform_train, transform_valid = utils._data_transforms_cifar10(args.child_cutout_size)
    train_set = dset.CIFAR10(root=args.data, train=True, download=True, transform=transform_train)
    valid_set = dset.CIFAR10(root=args.data, train=True, download=True, transform=transform_valid)

    total = len(train_set)
    assert total == len(valid_set)

    # Shuffle once so the two samplers draw from disjoint index ranges.
    order = list(range(total))
    cut = int(np.floor(train_fraction * total))
    np.random.shuffle(order)

    subset_sampler = torch.utils.data.sampler.SubsetRandomSampler
    train_queue = torch.utils.data.DataLoader(
        train_set,
        batch_size=args.child_batch_size,
        sampler=subset_sampler(order[:cut]),
        pin_memory=True,
        num_workers=16,
    )
    valid_queue = torch.utils.data.DataLoader(
        valid_set,
        batch_size=args.child_eval_batch_size,
        sampler=subset_sampler(order[cut:total]),
        pin_memory=True,
        num_workers=16,
    )

    model = NASWSNetworkCIFAR(
        10, args.child_layers, args.child_nodes, args.child_channels,
        args.child_keep_prob, args.child_drop_path_keep_prob,
        args.child_use_aux_head, args.steps,
    )
    model = model.cuda()
    train_criterion = nn.CrossEntropyLoss().cuda()
    eval_criterion = nn.CrossEntropyLoss().cuda()
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=args.child_lr_max,
        momentum=0.9,
        weight_decay=args.child_l2_reg,
    )

    # Restore previous state, if any, before the scheduler is created.
    if model_state_dict is not None:
        model.load_state_dict(model_state_dict)
    if optimizer_state_dict is not None:
        optimizer.load_state_dict(optimizer_state_dict)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, args.child_epochs, eta_min=args.child_lr_min, last_epoch=last_epoch)
    return train_queue, valid_queue, model, train_criterion, eval_criterion, optimizer, scheduler
示例7: build_cifar100
# 需要導入模塊: import utils [as 別名]
# 或者: from utils import count_parameters_in_MB [as 別名]
def build_cifar100(model_state_dict=None, optimizer_state_dict=None, **kwargs):
    """Assemble CIFAR-100 loaders, the weight-sharing child model and its optimiser.

    Both loaders read the CIFAR-100 *training* split (with different
    transforms). A shuffled index list is cut at ``ratio``: the first part is
    sampled for training, the remainder for validation.

    Args:
        model_state_dict: Optional state dict loaded into the model.
        optimizer_state_dict: Optional state dict loaded into the optimizer.
        **kwargs: Must contain ``epoch`` (handed to the scheduler as
            ``last_epoch``) and ``ratio`` (fraction of the samples used for
            training).

    Returns:
        Tuple ``(train_queue, valid_queue, model, train_criterion,
        eval_criterion, optimizer, scheduler)``.

    Raises:
        KeyError: If ``epoch`` or ``ratio`` is missing from ``kwargs``.
    """
    last_epoch = kwargs.pop('epoch')
    train_fraction = kwargs.pop('ratio')

    # NOTE(review): this reads args.cutout_size while every other setting in
    # this function is args.child_* -- confirm that attribute is the intended one.
    transform_train, transform_valid = utils._data_transforms_cifar10(args.cutout_size)
    train_set = dset.CIFAR100(root=args.data, train=True, download=True, transform=transform_train)
    valid_set = dset.CIFAR100(root=args.data, train=True, download=True, transform=transform_valid)

    total = len(train_set)
    assert total == len(valid_set)

    # Shuffle once so the two samplers draw from disjoint index ranges.
    order = list(range(total))
    cut = int(np.floor(train_fraction * total))
    np.random.shuffle(order)

    subset_sampler = torch.utils.data.sampler.SubsetRandomSampler
    train_queue = torch.utils.data.DataLoader(
        train_set,
        batch_size=args.child_batch_size,
        sampler=subset_sampler(order[:cut]),
        pin_memory=True,
        num_workers=16,
    )
    valid_queue = torch.utils.data.DataLoader(
        valid_set,
        batch_size=args.child_eval_batch_size,
        sampler=subset_sampler(order[cut:total]),
        pin_memory=True,
        num_workers=16,
    )

    model = NASWSNetworkCIFAR(
        100, args.child_layers, args.child_nodes, args.child_channels,
        args.child_keep_prob, args.child_drop_path_keep_prob,
        args.child_use_aux_head, args.steps,
    )
    model = model.cuda()
    train_criterion = nn.CrossEntropyLoss().cuda()
    eval_criterion = nn.CrossEntropyLoss().cuda()
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=args.child_lr_max,
        momentum=0.9,
        weight_decay=args.child_l2_reg,
    )

    # Restore previous state, if any, before the scheduler is created.
    if model_state_dict is not None:
        model.load_state_dict(model_state_dict)
    if optimizer_state_dict is not None:
        optimizer.load_state_dict(optimizer_state_dict)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, args.child_epochs, eta_min=args.child_lr_min, last_epoch=last_epoch)
    return train_queue, valid_queue, model, train_criterion, eval_criterion, optimizer, scheduler
示例8: build_cifar10
# 需要導入模塊: import utils [as 別名]
# 或者: from utils import count_parameters_in_MB [as 別名]
def build_cifar10(model_state_dict, optimizer_state_dict, **kwargs):
    """Build CIFAR-10 train/test loaders, the fixed-architecture model and optimiser.

    Training uses the CIFAR-10 training split (shuffled); validation uses the
    CIFAR-10 test split (not shuffled).

    Args:
        model_state_dict: State dict loaded into the model, or ``None``.
        optimizer_state_dict: State dict loaded into the optimizer, or ``None``.
        **kwargs: Must contain ``epoch`` (handed to the scheduler as
            ``last_epoch``).

    Returns:
        Tuple ``(train_queue, valid_queue, model, train_criterion,
        eval_criterion, optimizer, scheduler)``.

    Raises:
        KeyError: If ``epoch`` is missing from ``kwargs``.
    """
    last_epoch = kwargs.pop('epoch')

    transform_train, transform_test = utils._data_transforms_cifar10(args.cutout_size)
    train_set = dset.CIFAR10(root=args.data, train=True, download=True, transform=transform_train)
    test_set = dset.CIFAR10(root=args.data, train=False, download=True, transform=transform_test)

    train_queue = torch.utils.data.DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=16,
    )
    valid_queue = torch.utils.data.DataLoader(
        test_set,
        batch_size=args.eval_batch_size,
        shuffle=False,
        pin_memory=True,
        num_workers=16,
    )

    model = NASNetworkCIFAR(
        args, 10, args.layers, args.nodes, args.channels, args.keep_prob,
        args.drop_path_keep_prob, args.use_aux_head, args.steps, args.arch,
    )
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    # The checkpoint is loaded before any DataParallel wrapping takes place.
    if model_state_dict is not None:
        model.load_state_dict(model_state_dict)

    gpu_count = torch.cuda.device_count()
    if gpu_count > 1:
        logging.info("Use %d %s", gpu_count, "GPUs !")
        model = nn.DataParallel(model)
    model = model.cuda()

    train_criterion = nn.CrossEntropyLoss().cuda()
    eval_criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=args.lr_max,
        momentum=0.9,
        weight_decay=args.l2_reg,
    )
    if optimizer_state_dict is not None:
        optimizer.load_state_dict(optimizer_state_dict)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs), eta_min=args.lr_min, last_epoch=last_epoch)
    return train_queue, valid_queue, model, train_criterion, eval_criterion, optimizer, scheduler
示例9: num_parameters
# 需要導入模塊: import utils [as 別名]
# 或者: from utils import count_parameters_in_MB [as 別名]
def num_parameters(self):
    """Return the summed ``count_parameters_in_MB`` of all four networks.

    The total covers both generators (``netG_A``, ``netG_B``) and both
    discriminators (``netD_A``, ``netD_B``); each network is counted exactly
    once.
    """
    networks = (self.netG_A, self.netG_B, self.netD_A, self.netD_B)
    return sum(count_parameters_in_MB(net) for net in networks)
示例10: main
# 需要導入模塊: import utils [as 別名]
# 或者: from utils import count_parameters_in_MB [as 別名]
def main():
    """Load a trained network and log its accuracy on the CIFAR-10 test split.

    Exits the process with status 1 when no CUDA device is available. All
    settings are read from the module-level ``args``.
    """
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    # Seed NumPy and PyTorch (CPU and current GPU) from the same value.
    seed = args.seed
    np.random.seed(seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    torch.manual_seed(seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)

    # NOTE(review): eval() executes whatever text args.arch holds; only safe
    # while that value comes from a trusted command line.
    genotype = eval("genotypes.%s" % args.arch)
    model = Network(args.init_channels, CIFAR_CLASSES, args.layers, args.auxiliary, genotype).cuda()
    utils.load(model, args.model_path)
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    criterion = nn.CrossEntropyLoss().cuda()

    # Only the evaluation transform is needed here.
    _, test_transform = utils._data_transforms_cifar10(args)
    test_set = dset.CIFAR10(root=args.data, train=False, download=True, transform=test_transform)
    test_queue = torch.utils.data.DataLoader(
        test_set,
        batch_size=args.batch_size,
        shuffle=False,
        pin_memory=True,
        num_workers=2,
    )

    model.drop_path_prob = args.drop_path_prob
    test_acc, _ = infer(test_queue, model, criterion)
    logging.info('test_acc %f', test_acc)
示例11: initialize_model
# 需要導入模塊: import utils [as 別名]
# 或者: from utils import count_parameters_in_MB [as 別名]
def initialize_model(self):
    """Create the model, place it on the requested devices and build its optimiser.

    ``self.parallel_model`` is always set. With exactly one GPU it is the
    model moved to CUDA; with any other positive GPU count the raw model is
    kept in ``self.model`` and wrapped in ``nn.DataParallel``; otherwise the
    model is used as created.

    :return: tuple ``(model, optimizer, scheduler)`` where ``model`` is
        ``self.parallel_model``.
    """
    cfg = self.args
    net = self.model_fn(cfg)

    if cfg.gpus == 1:
        self.parallel_model = net.cuda()
    elif cfg.gpus > 0:
        self.model = net
        self.parallel_model = nn.DataParallel(self.model).cuda()
    else:
        self.parallel_model = net

    # From here on work with the (possibly wrapped) model only.
    net = self.parallel_model
    logging.info("param size = %fMB", utils.count_parameters_in_MB(net))

    sgd = torch.optim.SGD(
        net.parameters(),
        lr=cfg.learning_rate,
        momentum=cfg.momentum,
        weight_decay=cfg.weight_decay,
    )

    # Cosine-annealed learning-rate schedule over the configured epochs.
    cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
        sgd, float(cfg.epochs), eta_min=cfg.learning_rate_min)
    return net, sgd, cosine
示例12: main
# 需要導入模塊: import utils [as 別名]
# 或者: from utils import count_parameters_in_MB [as 別名]
def main():
    """Load a trained network and log its accuracy on the CIFAR-10 test split.

    Exits the process with status 1 when no CUDA device is available. All
    settings are read from the module-level ``args``.
    """
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    # Seed NumPy and PyTorch (CPU and current GPU) from the same value.
    seed = args.seed
    np.random.seed(seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    torch.manual_seed(seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)

    # NOTE(review): eval() executes whatever text args.arch holds; only safe
    # while that value comes from a trusted command line.
    genotype = eval("genotypes.%s" % args.arch)
    model = Network(args.init_channels, CIFAR_CLASSES, args.layers, args.auxiliary, genotype).cuda()
    utils.load(model, args.model_path)
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    criterion = nn.CrossEntropyLoss().cuda()

    # Only the evaluation transform is needed here.
    _, test_transform = utils.data_transforms_cifar10(args)
    test_set = dset.CIFAR10(root=args.data, train=False, download=True, transform=test_transform)
    test_queue = torch.utils.data.DataLoader(
        test_set,
        batch_size=args.batch_size,
        shuffle=False,
        pin_memory=True,
        num_workers=2,
    )

    model.drop_path_prob = args.drop_path_prob
    test_acc, _ = infer(test_queue, model, criterion)
    logging.info('test_acc %f', test_acc)
示例13: main
# 需要導入模塊: import utils [as 別名]
# 或者: from utils import count_parameters_in_MB [as 別名]
def main():
    """Load a trained network and log its accuracy on the CIFAR-10 test split.

    Exits the process with status 1 when no CUDA device is available. The
    evaluation pass runs under ``torch.no_grad()``. All settings are read
    from the module-level ``args``.
    """
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    # Seed NumPy and PyTorch (CPU and current GPU) from the same value.
    seed = args.seed
    np.random.seed(seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    torch.manual_seed(seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)

    # NOTE(review): eval() executes whatever text args.arch holds; only safe
    # while that value comes from a trusted command line.
    genotype = eval("genotypes.%s" % args.arch)
    model = Network(args.init_channels, CIFAR_CLASSES, args.layers, args.auxiliary, genotype).cuda()
    utils.load(model, args.model_path)
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    criterion = nn.CrossEntropyLoss().cuda()

    # Only the evaluation transform is needed here.
    _, test_transform = utils._data_transforms_cifar10(args)
    test_set = dset.CIFAR10(root=args.data, train=False, download=True, transform=test_transform)
    test_queue = torch.utils.data.DataLoader(
        test_set,
        batch_size=args.batch_size,
        shuffle=False,
        pin_memory=True,
        num_workers=2,
    )

    model.drop_path_prob = args.drop_path_prob
    with torch.no_grad():
        test_acc, _ = infer(test_queue, model, criterion)
    logging.info('test_acc %f', test_acc)
示例14: main
# 需要導入模塊: import utils [as 別名]
# 或者: from utils import count_parameters_in_MB [as 別名]
def main():
    """Evaluate a trained network on an image-folder validation set.

    Images are read from ``<args.data>/val``, resized to 256, centre-cropped
    to 224, converted to tensors and normalised. Top-1 and top-5 accuracy are
    logged. Exits the process with status 1 when no CUDA device is available.
    """
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    cudnn.enabled = True
    logging.info("args = %s", args)

    # NOTE(review): eval() executes whatever text args.arch holds; only safe
    # while that value comes from a trusted command line.
    genotype = eval("genotypes.%s" % args.arch)
    model = Network(args.init_channels, CLASSES, args.layers, args.auxiliary, genotype)
    # The model is wrapped in DataParallel before the checkpoint is loaded.
    model = nn.DataParallel(model).cuda()
    checkpoint = torch.load(args.model_path)
    model.load_state_dict(checkpoint['state_dict'])
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    criterion = nn.CrossEntropyLoss().cuda()

    val_dir = os.path.join(args.data, 'val')
    preprocessing = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    val_set = dset.ImageFolder(val_dir, preprocessing)
    valid_queue = torch.utils.data.DataLoader(
        val_set,
        batch_size=args.batch_size,
        shuffle=False,
        pin_memory=False,
        num_workers=4,
    )

    # Set on the wrapped module, not on the DataParallel wrapper.
    model.module.drop_path_prob = 0.0
    top1, top5, _ = infer(valid_queue, model, criterion)
    logging.info('Valid_acc_top1 %f', top1)
    logging.info('Valid_acc_top5 %f', top5)
示例15: initialize_model
# 需要導入模塊: import utils [as 別名]
# 或者: from utils import count_parameters_in_MB [as 別名]
def initialize_model(self):
    """Create the child model, the ENAS controller and their optimisers.

    ``self.model_fn`` / ``self.fixmodel_fn`` are overridden according to
    ``self.args.search_space``; ``self.train_fn`` is cleared and
    ``self.eval_fn`` is set to ``nao_model_validation_nasbench``. The
    controller is moved to CUDA and stored in ``self.controller``.

    ``self.parallel_model`` is always set. With exactly one GPU it is the
    model moved to CUDA; with any other positive GPU count the raw model is
    kept in ``self.model`` and wrapped in ``nn.DataParallel``; otherwise the
    model is used as created.

    :return: tuple ``(model, optimizer, scheduler, controller,
        controller_optimizer)``.
    """
    args = self.args

    # Override the model / evaluation hooks for this policy.
    self.train_fn = None
    self.eval_fn = nao_model_validation_nasbench

    if self.args.search_space == 'nasbench':
        self.model_fn = NasBenchNetSearchENAS
        self.fixmodel_fn = NasBenchNet
        net = self.model_fn(args)
        utils = enas_nasbench_utils  # not used further inside this function
        controller = MicroControllerNasbench(args=args)
    else:
        utils = enas_utils  # not used further inside this function
        self.model_fn = ENASWSCNN
        self.fixmodel_fn = None
        net = self.model_fn(args)
        controller = MicroController(args)

    controller = controller.cuda()
    logging.info("ENAS RNN sampler param size = %fMB", project_utils.count_parameters_in_MB(controller))
    self.controller = controller

    if args.gpus == 1:
        self.parallel_model = net.cuda()
    elif args.gpus > 0:
        self.model = net
        self.parallel_model = nn.DataParallel(self.model).cuda()
    else:
        self.parallel_model = net

    # From here on work with the (possibly wrapped) model only.
    net = self.parallel_model
    logging.info("param size = %fMB", project_utils.count_parameters_in_MB(net))

    child_optimizer = torch.optim.SGD(
        net.parameters(),
        lr=args.child_lr_max,
        momentum=0.9,
        weight_decay=args.child_l2_reg,
    )
    controller_optimizer = torch.optim.Adam(
        controller.parameters(),
        lr=args.controller_lr,
        betas=(0.1, 0.999),
        eps=1e-3,
    )

    # Cosine-annealed learning-rate schedule for the child model.
    child_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        child_optimizer, args.epochs, eta_min=args.child_lr_min)
    return net, child_optimizer, child_scheduler, controller, controller_optimizer