當前位置: 首頁>>代碼示例>>Python>>正文


Python Adam.state_dict方法代碼示例

本文整理匯總了Python中torch.optim.Adam.state_dict方法的典型用法代碼示例。如果您正苦於以下問題:Python Adam.state_dict方法的具體用法?Python Adam.state_dict怎麽用?Python Adam.state_dict使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在torch.optim.Adam的用法示例。


在下文中一共展示了Adam.state_dict方法的1個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。

示例1: _run

# 需要導入模塊: from torch.optim import Adam [as 別名]
# 或者: from torch.optim.Adam import state_dict [as 別名]

#.........這裏部分代碼省略.........
                transfer batch to Torch: 1 ms, (1508424955295, 1508424955296)
                Generating one example time: 2~5 ms, (1508458881118, 1508458881122)
                Generating one document time: 50~400 ms, (1508458881118, 1508458881122)
                Generating one batch time: 650~700 ms, (1508458880690, 1508458881122)
            After changing to torch.sampler
                Generating one example time: 4~7 ms
                Generating one batch time: 900~1200 ms

    '''

    model = DistributedMemory(
        vec_dim,
        num_docs=len(dataset),
        num_words=vocabulary_size)

    cost_func = NegativeSampling()
    optimizer = Adam(params=model.parameters(), lr=lr)
    logger = logging.getLogger('root')

    if torch.cuda.is_available():
        model.cuda()
        logger.info("Running on GPU - CUDA")
    else:
        logger.info("Running on CPU")

    logger.info("Dataset comprised of {:d} documents.".format(len(dataset)))
    logger.info("Vocabulary size is {:d}.\n".format(vocabulary_size))
    logger.info("Training started.")

    best_loss = float_info.max
    prev_model_file_path = ""

    progbar = Progbar(num_batches, batch_size=batch_size, total_examples = number_examples)

    for epoch_i in range(num_epochs):
        epoch_start_time = time.time()
        loss = []

        for batch_i in range(num_batches):
            start_time = current_milli_time()
            batch = next(data_generator)
            current_time = current_milli_time()
            print('data-prepare time: %d ms' % (round(current_time - start_time)))

            start_time = current_milli_time()
            x = model.forward(
                batch.context_ids,
                batch.doc_ids,
                batch.target_noise_ids)
            x = cost_func.forward(x)
            loss.append(x.data[0])
            print('forward time: %d ms' % round(current_milli_time() - start_time))

            start_time = current_milli_time()
            model.zero_grad()
            x.backward()
            optimizer.step()
            print('backward time: %d ms' % round(current_milli_time() - start_time))

            progbar.update(epoch_i, batch_i, )
            # _print_progress(epoch_i, batch_i, num_batches)

        # end of epoch
        loss = torch.mean(torch.FloatTensor(loss))
        is_best_loss = loss < best_loss
        best_loss = min(loss, best_loss)
        progbar.update(epoch_i, batch_i, [('loss', loss), ('best_loss', best_loss)])

        model_file_name = MODEL_NAME.format(
            data_file_name[:-4],
            model_ver,
            vec_combine_method,
            context_size,
            num_noise_words,
            vec_dim,
            batch_size,
            lr,
            epoch_i + 1,
            loss)
        model_file_path = join(MODELS_DIR, model_file_name)
        if not os.path.exists(MODELS_DIR):
            os.makedirs(MODELS_DIR)
        state = {
            'epoch': epoch_i + 1,
            'model_state_dict': model.state_dict(),
            'best_loss': best_loss,
            'optimizer_state_dict': optimizer.state_dict()
        }
        if save_all:
            torch.save(state, model_file_path)
        elif is_best_loss:
            try:
                remove(prev_model_file_path)
            except FileNotFoundError:
                pass
            torch.save(state, model_file_path)
            prev_model_file_path = model_file_path

        epoch_total_time = round(time.time() - epoch_start_time)
        logger.info(" ({:d}s) - loss: {:.4f}".format(epoch_total_time, loss))
開發者ID:memray,項目名稱:paragraph-vectors,代碼行數:104,代碼來源:train.py


注:本文中的torch.optim.Adam.state_dict方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。