本文整理汇总了Python中model.RNN_ENCODER属性的典型用法代码示例。如果您正苦于以下问题:Python model.RNN_ENCODER属性的具体用法?Python model.RNN_ENCODER怎么用?Python model.RNN_ENCODER使用的例子?那么, 这里精选的属性代码示例或许可以为您提供帮助。您也可以进一步了解该属性所在类model
的用法示例。
在下文中一共展示了model.RNN_ENCODER属性的2个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: models
# 需要导入模块: import model [as 别名]
# 或者: from model import RNN_ENCODER [as 别名]
def models(word_len):
#print(word_len)
text_encoder = cache.get('text_encoder')
if text_encoder is None:
#print("text_encoder not cached")
text_encoder = RNN_ENCODER(word_len, nhidden=cfg.TEXT.EMBEDDING_DIM)
state_dict = torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
text_encoder.load_state_dict(state_dict)
if cfg.CUDA:
text_encoder.cuda()
text_encoder.eval()
cache.set('text_encoder', text_encoder, timeout=60 * 60 * 24)
netG = cache.get('netG')
if netG is None:
#print("netG not cached")
netG = G_NET()
state_dict = torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
netG.load_state_dict(state_dict)
if cfg.CUDA:
netG.cuda()
netG.eval()
cache.set('netG', netG, timeout=60 * 60 * 24)
return text_encoder, netG
示例2: build_models
# 需要导入模块: import model [as 别名]
# 或者: from model import RNN_ENCODER [as 别名]
def build_models():
# build model ############################################################
text_encoder = RNN_ENCODER(dataset.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
labels = Variable(torch.LongTensor(range(batch_size)))
start_epoch = 0
if cfg.TRAIN.NET_E != '':
state_dict = torch.load(cfg.TRAIN.NET_E)
text_encoder.load_state_dict(state_dict)
print('Load ', cfg.TRAIN.NET_E)
#
name = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
state_dict = torch.load(name)
image_encoder.load_state_dict(state_dict)
print('Load ', name)
istart = cfg.TRAIN.NET_E.rfind('_') + 8
iend = cfg.TRAIN.NET_E.rfind('.')
start_epoch = cfg.TRAIN.NET_E[istart:iend]
start_epoch = int(start_epoch) + 1
print('start_epoch', start_epoch)
if cfg.CUDA:
text_encoder = text_encoder.cuda()
image_encoder = image_encoder.cuda()
labels = labels.cuda()
return text_encoder, image_encoder, labels, start_epoch