

Python Adam.state_dict Method Code Examples

This article collects typical code examples of the Python method torch.optim.Adam.state_dict. If you are wondering what exactly Adam.state_dict does, how to call it, or what working code that uses it looks like, the curated example below should help. You can also explore further usage examples of torch.optim.Adam, the class this method belongs to.


One code example of the Adam.state_dict method is shown below.
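
Before the full example, here is a minimal sketch of what Adam.state_dict returns. The linear model is a stand-in for illustration only and is not part of the example below:

import torch
from torch import nn
from torch.optim import Adam

model = nn.Linear(10, 2)                 # stand-in model
optimizer = Adam(model.parameters(), lr=0.001)

# Per-parameter state (exp_avg, exp_avg_sq, step) only exists after a step.
model(torch.randn(4, 10)).sum().backward()
optimizer.step()

sd = optimizer.state_dict()
print(list(sd.keys()))                   # ['state', 'param_groups']
print(sd['param_groups'][0]['lr'])       # 0.001
print(sorted(sd['state'][0].keys()))     # ['exp_avg', 'exp_avg_sq', 'step']

The returned dict is plain Python containers and tensors, which is what makes it suitable for torch.save and for load_state_dict when resuming training, as the example demonstrates.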

Example 1: _run

# Module to import: from torch.optim import Adam [as alias]
# Or alternatively: from torch.optim.Adam import state_dict [as alias]

#......... some code omitted here .........
                transfer batch to Torch: 1 ms, (1508424955295, 1508424955296)
                Generating one example time: 2~5 ms, (1508458881118, 1508458881122)
                Generating one document time: 50~400 ms, (1508458881118, 1508458881122)
                Generating one batch time: 650~700 ms, (1508458880690, 1508458881122)
            After changing to torch.sampler
                Generating one example time: 4~7 ms
                Generating one batch time: 900~1200 ms

    '''

    # Distributed Memory (PV-DM) model: learns one trainable vector per
    # document alongside the word vectors.
    model = DistributedMemory(
        vec_dim,
        num_docs=len(dataset),
        num_words=vocabulary_size)

    # Negative-sampling loss; Adam optimizes all model parameters.
    cost_func = NegativeSampling()
    optimizer = Adam(params=model.parameters(), lr=lr)
    logger = logging.getLogger('root')

    if torch.cuda.is_available():
        model.cuda()
        logger.info("Running on GPU - CUDA")
    else:
        logger.info("Running on CPU")

    logger.info("Dataset comprised of {:d} documents.".format(len(dataset)))
    logger.info("Vocabulary size is {:d}.\n".format(vocabulary_size))
    logger.info("Training started.")

    best_loss = float_info.max
    prev_model_file_path = ""

    progbar = Progbar(num_batches, batch_size=batch_size, total_examples=number_examples)

    for epoch_i in range(num_epochs):
        epoch_start_time = time.time()
        loss = []

        for batch_i in range(num_batches):
            start_time = current_milli_time()
            batch = next(data_generator)
            current_time = current_milli_time()
            print('data-prepare time: %d ms' % round(current_time - start_time))

            start_time = current_milli_time()
            x = model(                  # invokes model.forward
                batch.context_ids,
                batch.doc_ids,
                batch.target_noise_ids)
            x = cost_func(x)
            loss.append(x.item())       # x.data[0] on PyTorch < 0.4
            print('forward time: %d ms' % round(current_milli_time() - start_time))

            start_time = current_milli_time()
            model.zero_grad()
            x.backward()
            optimizer.step()
            print('backward time: %d ms' % round(current_milli_time() - start_time))

            progbar.update(epoch_i, batch_i)
            # _print_progress(epoch_i, batch_i, num_batches)

        # end of epoch
        loss = torch.mean(torch.FloatTensor(loss)).item()  # average epoch loss as a plain float
        is_best_loss = loss < best_loss
        best_loss = min(loss, best_loss)
        progbar.update(epoch_i, batch_i, [('loss', loss), ('best_loss', best_loss)])

        model_file_name = MODEL_NAME.format(
            data_file_name[:-4],
            model_ver,
            vec_combine_method,
            context_size,
            num_noise_words,
            vec_dim,
            batch_size,
            lr,
            epoch_i + 1,
            loss)
        model_file_path = join(MODELS_DIR, model_file_name)
        if not os.path.exists(MODELS_DIR):
            os.makedirs(MODELS_DIR)
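        # Pair the model weights with Adam's state_dict (per-parameter moment
        # estimates and step counts) so training can be resumed exactly.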
        state = {
            'epoch': epoch_i + 1,
            'model_state_dict': model.state_dict(),
            'best_loss': best_loss,
            'optimizer_state_dict': optimizer.state_dict()
        }
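        # Either save a checkpoint every epoch, or keep only the best one so far.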
        if save_all:
            torch.save(state, model_file_path)
        elif is_best_loss:
            try:
                remove(prev_model_file_path)
            except FileNotFoundError:
                pass
            torch.save(state, model_file_path)
            prev_model_file_path = model_file_path

        epoch_total_time = round(time.time() - epoch_start_time)
        logger.info(" ({:d}s) - loss: {:.4f}".format(epoch_total_time, loss))
Developer: memray, Project: paragraph-vectors, Lines of code: 104, Source file: train.py
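
To resume training from a checkpoint written by the loop above, both state dicts are loaded back into freshly constructed objects. A minimal self-contained sketch follows; the linear model and file name are placeholders, whereas in the example above the model would be DistributedMemory and the path the saved model_file_path:

import torch
from torch import nn
from torch.optim import Adam

model = nn.Linear(10, 2)                       # placeholder for the trained model
optimizer = Adam(model.parameters(), lr=0.001)

# Write a checkpoint shaped like `state` in the example above.
torch.save({'epoch': 1,
            'model_state_dict': model.state_dict(),
            'best_loss': 0.5,
            'optimizer_state_dict': optimizer.state_dict()}, 'checkpoint.pt')

# Restore: load both state dicts, then continue training from the saved epoch.
checkpoint = torch.load('checkpoint.pt')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']
best_loss = checkpoint['best_loss']

Restoring the optimizer's state_dict, not just the model's, matters for Adam in particular: without the saved moment estimates, a resumed run would restart the adaptive learning rates from scratch.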


Note: the torch.optim.Adam.state_dict examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets come from community open-source projects, and copyright remains with the original authors; consult each project's License before use or distribution, and do not reproduce without permission.