本文整理汇总了Python中model.D_NET256类的典型用法代码示例。如果您正苦于以下问题：Python model.D_NET256类的具体用法？Python model.D_NET256怎么用？Python model.D_NET256使用的例子？那么，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块model的用法示例。
在下文中一共展示了model.D_NET256类的2个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: load_network
# 需要导入模块: import model [as 别名]
# 或者: from model import D_NET256 [as 别名]
def load_network(gpus):
    """Build G, the per-branch discriminators and Inception; optionally restore weights.

    Args:
        gpus: device ids passed to ``torch.nn.DataParallel`` for the generator
            and for every discriminator.

    Returns:
        A tuple ``(netG, netsD, num_Ds, inception_model, count)``:
        ``netsD`` holds one DataParallel-wrapped discriminator per tree branch
        (at most five), ``num_Ds == len(netsD)``, ``inception_model`` is in
        eval mode, and ``count`` is 0 when ``cfg.TRAIN.NET_G`` is empty,
        otherwise the integer found between the last '_' and the last '.' of
        that path, plus one.

    Raises:
        ValueError: if ``cfg.TRAIN.NET_G`` is set but the text between its
            last '_' and last '.' is not an integer.
    """
    netG = G_NET()
    netG.apply(weights_init)
    netG = torch.nn.DataParallel(netG, device_ids=gpus)
    print(netG)

    # One discriminator per tree branch, in increasing resolution order;
    # branch k exists when BRANCH_NUM > k.
    # TODO: if cfg.TREE.BRANCH_NUM > 5: (more branches are currently ignored)
    d_classes = [D_NET64, D_NET128, D_NET256, D_NET512, D_NET1024]
    netsD = [d_cls() for branch, d_cls in enumerate(d_classes)
             if cfg.TREE.BRANCH_NUM > branch]
    for i, netD in enumerate(netsD):
        netD.apply(weights_init)
        netsD[i] = torch.nn.DataParallel(netD, device_ids=gpus)
    print('# of netsD', len(netsD))

    # load_state_dict copies values into the existing parameters, so loading
    # the checkpoint tensors onto CPU first is safe wherever the model lives,
    # and lets GPU-saved checkpoints load on CPU-only hosts.
    to_cpu = lambda storage, loc: storage

    count = 0
    if cfg.TRAIN.NET_G != '':
        state_dict = torch.load(cfg.TRAIN.NET_G, map_location=to_cpu)
        netG.load_state_dict(state_dict)
        print('Load ', cfg.TRAIN.NET_G)
        # Resume from the number embedded in the filename,
        # e.g. '.../netG_1000.pth' -> 1001.
        istart = cfg.TRAIN.NET_G.rfind('_') + 1
        iend = cfg.TRAIN.NET_G.rfind('.')
        count = int(cfg.TRAIN.NET_G[istart:iend]) + 1

    if cfg.TRAIN.NET_D != '':
        for i, netD in enumerate(netsD):
            # The file actually loaded is '<NET_D><i>.pth' (no '_' separator);
            # log that same path instead of a different pattern.
            d_path = '%s%d.pth' % (cfg.TRAIN.NET_D, i)
            print('Load %s' % d_path)
            state_dict = torch.load(d_path, map_location=to_cpu)
            netD.load_state_dict(state_dict)

    inception_model = INCEPTION_V3()
    if cfg.CUDA:
        netG.cuda()
        for netD in netsD:
            netD.cuda()
        inception_model = inception_model.cuda()
    inception_model.eval()

    return netG, netsD, len(netsD), inception_model, count
示例2: load_network
# 需要导入模块: import model [as 别名]
# 或者: from model import D_NET256 [as 别名]
def load_network(gpus):
    """Build G, the per-branch discriminators and Inception; optionally restore weights.

    Args:
        gpus: device ids passed to ``torch.nn.DataParallel`` for the generator
            and for every discriminator.

    Returns:
        A tuple ``(netG, netsD, num_Ds, inception_model, count)``:
        ``netsD`` holds one DataParallel-wrapped discriminator per tree branch
        (at most five), ``num_Ds == len(netsD)``, ``inception_model`` is in
        eval mode, and ``count`` is 0 when ``cfg.TRAIN.NET_G`` is empty,
        otherwise the resume counter plus one. The counter is the integer
        between the last '_' and the last '.' of the NET_G path or, when that
        text is not an integer, the integer read from
        ``<DATA_DIR>/<LAST_RUN_DIR>/Model/count.txt``.
    """
    netG = G_NET()
    netG.apply(weights_init)
    netG = torch.nn.DataParallel(netG, device_ids=gpus)
    print(netG)

    # One discriminator per tree branch, in increasing resolution order;
    # branch k exists when BRANCH_NUM > k.
    # TODO: if cfg.TREE.BRANCH_NUM > 5: (more branches are currently ignored)
    d_classes = [D_NET64, D_NET128, D_NET256, D_NET512, D_NET1024]
    netsD = [d_cls() for branch, d_cls in enumerate(d_classes)
             if cfg.TREE.BRANCH_NUM > branch]
    for i, netD in enumerate(netsD):
        netD.apply(weights_init)
        netsD[i] = torch.nn.DataParallel(netD, device_ids=gpus)
    print('# of netsD', len(netsD))

    # load_state_dict copies values into the existing parameters, so loading
    # the checkpoint tensors onto CPU first is safe wherever the model lives,
    # and lets GPU-saved checkpoints load on CPU-only hosts.
    to_cpu = lambda storage, loc: storage

    count = 0
    if cfg.TRAIN.NET_G != '':
        state_dict = torch.load(cfg.TRAIN.NET_G, map_location=to_cpu)
        netG.load_state_dict(state_dict)
        print('Load ', cfg.TRAIN.NET_G)
        istart = cfg.TRAIN.NET_G.rfind('_') + 1
        iend = cfg.TRAIN.NET_G.rfind('.')
        try:
            # e.g. '.../netG_1000.pth' -> 1000
            count = int(cfg.TRAIN.NET_G[istart:iend])
        except ValueError:
            # Only int() can fail here. A bare except would also hide
            # KeyboardInterrupt/SystemExit and real bugs.
            # The filename has no numeric suffix: use the counter saved in
            # the previous run's Model directory instead.
            last_run_dir = cfg.DATA_DIR + '/' + cfg.LAST_RUN_DIR + '/Model'
            with open(last_run_dir + '/count.txt', 'r') as f:
                count = int(f.read())
        count += 1

    if cfg.TRAIN.NET_D != '':
        for i, netD in enumerate(netsD):
            # The file actually loaded is '<NET_D><i>.pth' (no '_' separator);
            # log that same path instead of a different pattern.
            d_path = '%s%d.pth' % (cfg.TRAIN.NET_D, i)
            print('Load %s' % d_path)
            state_dict = torch.load(d_path, map_location=to_cpu)
            netD.load_state_dict(state_dict)

    inception_model = INCEPTION_V3()
    if cfg.CUDA:
        netG.cuda()
        for netD in netsD:
            netD.cuda()
        inception_model = inception_model.cuda()
    inception_model.eval()

    return netG, netsD, len(netsD), inception_model, count