当前位置: 首页>>代码示例>>Python>>正文


Python model.RNN_ENCODER属性代码示例

本文整理汇总了Python中model.RNN_ENCODER属性的典型用法代码示例。如果您正苦于以下问题:Python model.RNN_ENCODER属性的具体用法?Python model.RNN_ENCODER怎么用?Python model.RNN_ENCODER使用的例子?那么, 这里精选的属性代码示例或许可以为您提供帮助。您也可以进一步了解该属性所在model的用法示例。


在下文中一共展示了model.RNN_ENCODER属性的2个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: models

# 需要导入模块: import model [as 别名]
# 或者: from model import RNN_ENCODER [as 别名]
def models(word_len):
    #print(word_len)
    text_encoder = cache.get('text_encoder')
    if text_encoder is None:
        #print("text_encoder not cached")
        text_encoder = RNN_ENCODER(word_len, nhidden=cfg.TEXT.EMBEDDING_DIM)
        state_dict = torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        if cfg.CUDA:
            text_encoder.cuda()
        text_encoder.eval()
        cache.set('text_encoder', text_encoder, timeout=60 * 60 * 24)

    netG = cache.get('netG')
    if netG is None:
        #print("netG not cached")
        netG = G_NET()
        state_dict = torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        if cfg.CUDA:
            netG.cuda()
        netG.eval()
        cache.set('netG', netG, timeout=60 * 60 * 24)

    return text_encoder, netG 
开发者ID:taoxugit,项目名称:AttnGAN,代码行数:27,代码来源:eval.py

示例2: build_models

# 需要导入模块: import model [as 别名]
# 或者: from model import RNN_ENCODER [as 别名]
def build_models():
    # build model ############################################################
    text_encoder = RNN_ENCODER(dataset.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    labels = Variable(torch.LongTensor(range(batch_size)))
    start_epoch = 0
    if cfg.TRAIN.NET_E != '':
        state_dict = torch.load(cfg.TRAIN.NET_E)
        text_encoder.load_state_dict(state_dict)
        print('Load ', cfg.TRAIN.NET_E)
        #
        name = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
        state_dict = torch.load(name)
        image_encoder.load_state_dict(state_dict)
        print('Load ', name)

        istart = cfg.TRAIN.NET_E.rfind('_') + 8
        iend = cfg.TRAIN.NET_E.rfind('.')
        start_epoch = cfg.TRAIN.NET_E[istart:iend]
        start_epoch = int(start_epoch) + 1
        print('start_epoch', start_epoch)
    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
        image_encoder = image_encoder.cuda()
        labels = labels.cuda()

    return text_encoder, image_encoder, labels, start_epoch 
开发者ID:MinfengZhu,项目名称:DM-GAN,代码行数:29,代码来源:pretrain_DAMSM.py


注:本文中的model.RNN_ENCODER属性示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。