当前位置: 首页>>代码示例>>Python>>正文


Python tqdm.write方法代码示例

本文整理汇总了Python中tqdm.tqdm.write方法的典型用法代码示例。如果您正苦于以下问题：Python tqdm.write方法的具体用法是什么？Python tqdm.write怎么用？有哪些Python tqdm.write的使用例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类tqdm.tqdm的用法示例。


在下文中一共展示了tqdm.write方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: test

# 需要导入模块: from tqdm import tqdm [as 别名]
# 或者: from tqdm.tqdm import write [as 别名]
def test(model, data_tar, e):
    """Evaluate ``model`` on the target loader, then print and log the result.

    Args:
        model: Network invoked as ``model(data, data)``; the first element of
            the returned triple is used as the class scores.
        data_tar: Iterable of ``(data, target)`` batches that also exposes
            ``.dataset`` (its length is used as the sample count).
        e: Epoch identifier stored alongside the metrics in ``RESULT_TEST``.

    Side effects:
        Writes the summary line through ``tqdm.write``, appends
        ``[e, total_loss, accuracy]`` to ``RESULT_TEST`` and appends the
        summary line to ``log_test``.
    """
    criterion = nn.CrossEntropyLoss()
    total_loss_test = 0.0
    correct = 0
    # Switching to eval mode does not depend on the batch, so do it once.
    model.eval()
    with torch.no_grad():
        for data, target in data_tar:
            # Each sample is flattened to a 28 * 28 vector before the forward pass.
            data = data.view(-1, 28 * 28).to(DEVICE)
            target = target.to(DEVICE)
            ypred, _, _ = model(data, data)
            # .item() keeps the accumulators as plain Python numbers, so the
            # values stored in RESULT_TEST are not tensors and the accuracy
            # below is computed with ordinary float division.
            total_loss_test += criterion(ypred, target).item()
            pred = ypred.max(1)[1]  # index of the highest score per sample
            correct += pred.eq(target.view_as(pred)).sum().item()
    num_samples = len(data_tar.dataset)
    accuracy = correct * 100. / num_samples
    res = 'Test: total loss: {:.6f}, correct: [{}/{}], testing accuracy: {:.4f}%'.format(
        total_loss_test, correct, num_samples, accuracy
    )
    tqdm.write(res)
    RESULT_TEST.append([e, total_loss_test, accuracy])
    log_test.write(res + '\n')
开发者ID:jindongwang,项目名称:transferlearning,代码行数:22,代码来源:main.py

示例2: test

# 需要导入模块: from tqdm import tqdm [as 别名]
# 或者: from tqdm.tqdm import write [as 别名]
def test(model, loader, criterion, device, dtype, child):
    """Run one evaluation pass and return average loss plus top-1/top-5 rates.

    Args:
        model: Network to evaluate; put into eval mode here.
        loader: Batch iterable yielding ``(data, target)``; ``len(loader)`` and
            ``len(loader.sampler)`` are used as normalisers.
        criterion: Loss callable applied to ``(output, target)``.
        device: Device the batches are moved to.
        dtype: Dtype the input batch is cast to.
        child: When truthy, no progress bar is shown and no summary is written.

    Returns:
        Tuple ``(average_loss, top1_rate, top5_rate)``.
    """
    model.eval()
    loss_sum = 0
    top1_hits = 0
    top5_hits = 0

    # Only the non-child caller gets a progress bar around the loader.
    batch_source = loader if child else tqdm(loader)
    numbered_batches = enumerate(batch_source)

    with torch.no_grad():
        for _, (inputs, labels) in numbered_batches:
            inputs = inputs.to(device=device, dtype=dtype)
            labels = labels.to(device=device)
            logits = model(inputs)
            loss_sum += criterion(logits, labels).item()  # accumulate per-batch loss
            hits = correct(logits, labels, topk=(1, 5))
            top1_hits += hits[0]
            top5_hits += hits[1]

    loss_sum /= len(loader)
    sample_count = len(loader.sampler)
    if not child:
        summary = (
            '\nTest set: Average loss: {:.4f}, Top1: {}/{} ({:.2f}%), '
            'Top5: {}/{} ({:.2f}%)'
        ).format(
            loss_sum,
            int(top1_hits), sample_count, 100. * top1_hits / sample_count,
            int(top5_hits), sample_count, 100. * top5_hits / sample_count,
        )
        tqdm.write(summary)
    return loss_sum, top1_hits / sample_count, top5_hits / sample_count
开发者ID:Randl,项目名称:MobileNetV3-pytorch,代码行数:26,代码来源:run.py

示例3: test

# 需要导入模块: from tqdm import tqdm [as 别名]
# 或者: from tqdm.tqdm import write [as 别名]
def test(capsule_net, test_loader, epoch):
    """Evaluate the capsule network on ``test_loader`` and print the result.

    Args:
        capsule_net: Model whose call returns ``(output, reconstructions,
            masked)`` and which provides ``loss(data, output, target,
            reconstructions)``.
        test_loader: Iterable of ``(data, target)`` batches exposing
            ``.dataset``; ``target`` holds integer class indices.
        epoch: Epoch number shown in the progress message.

    Side effects:
        Writes one line with accuracy and mean batch loss via ``tqdm.write``.
    """
    capsule_net.eval()
    test_loss = 0.0
    correct = 0
    # Rows of the identity matrix act as one-hot labels for the 10 classes.
    one_hot_rows = torch.eye(10)

    # Evaluation needs no gradients; skipping them also saves memory.
    with torch.no_grad():
        for data, target in test_loader:
            target = one_hot_rows.index_select(dim=0, index=target)

            if USE_CUDA:
                data, target = data.cuda(), target.cuda()

            output, reconstructions, masked = capsule_net(data)
            loss = capsule_net.loss(data, output, target, reconstructions)

            # loss.data[0] fails on 0-dim tensors in current PyTorch; .item()
            # is the supported way to read a scalar loss.
            test_loss += loss.item()
            predicted = np.argmax(masked.detach().cpu().numpy(), 1)
            expected = np.argmax(target.detach().cpu().numpy(), 1)
            correct += int(np.sum(predicted == expected))

    tqdm.write(
        "Epoch: [{}/{}], test accuracy: {:.6f}, loss: {:.6f}".format(
            epoch, N_EPOCHS, correct / len(test_loader.dataset),
            test_loss / len(test_loader)))
开发者ID:jindongwang,项目名称:Pytorch-CapsuleNet,代码行数:24,代码来源:test_capsnet.py

示例4: __call__

# 需要导入模块: from tqdm import tqdm [as 别名]
# 或者: from tqdm.tqdm import write [as 别名]
def __call__(self, transformer, callback_data, phase, data, idx):
        """Dispatch on the callback phase to set up, record, or report losses.

        * ``train_pre_``: reads the total iteration count from the config and
          creates one ``cost/<name>`` dataset per loss key plus ``time/loss``.
        * ``train_post``: evaluates the losses once and writes them out.
        * ``minibatch_post`` on every ``self.frequency``-th iteration:
          evaluates the losses, stores them and the evaluation time at the
          interval slot, and writes a progress line.

        Any other phase (or off-interval minibatch) does nothing.
        """
        if phase == CallbackPhase.train_pre_:
            self.total_iterations = callback_data['config'].attrs['total_iterations']
            slot_count = self.total_iterations // self.frequency
            dataset_names = ["cost/{}".format(key)
                             for key in self.interval_loss_comp.output_keys]
            dataset_names.append("time/loss")
            for dataset_name in dataset_names:
                callback_data.create_dataset(dataset_name, (slot_count,))
            return

        if phase == CallbackPhase.train_post:
            final_losses = loop_eval(self.dataset, self.interval_loss_comp)
            tqdm.write("Training complete.  Avg losses: {}".format(final_losses))
            return

        off_interval = (idx + 1) % self.frequency != 0
        if phase != CallbackPhase.minibatch_post or off_interval:
            return

        # Time only the loss evaluation and bookkeeping for this interval.
        started_at = default_timer()
        slot = idx // self.frequency

        interval_losses = loop_eval(self.dataset, self.interval_loss_comp)

        for key, value in interval_losses.items():
            callback_data["cost/{}".format(key)][slot] = value

        callback_data["time/loss"][slot] = (default_timer() - started_at)
        message = "Interval {} Iteration {} complete.  Avg losses: {}".format(
            slot + 1, idx + 1, interval_losses)
        tqdm.write(message)
开发者ID:NervanaSystems,项目名称:ngraph-python,代码行数:24,代码来源:callbacks.py

示例5: train_with_dataset

# 需要导入模块: from tqdm import tqdm [as 别名]
# 或者: from tqdm.tqdm import write [as 别名]
def train_with_dataset(self,dataset,batch_size,include_action=False,iter=10000,l2_reg=0.01,debug=False):
        """Run ``iter`` update steps on batches drawn from ``dataset``.

        Each step asks ``dataset.batch`` for a batch, feeds it to the default
        TensorFlow session and runs the loss, L2 loss, accuracy and update ops.
        With ``debug`` set, the metrics are written for the first ten steps and
        for every hundredth step.
        """
        sess = tf.get_default_session()
        fetches = [self.loss, self.l2_loss, self.acc, self.update_op]

        for step in tqdm(range(iter), dynamic_ncols=True):
            b_x, b_y, x_split, y_split, b_l = dataset.batch(
                batch_size=batch_size, include_action=include_action)
            feed = {
                self.x: b_x,
                self.y: b_y,
                self.x_split: x_split,
                self.y_split: y_split,
                self.l: b_l,
                self.l2_reg: l2_reg,
            }
            loss, l2_loss, acc, _ = sess.run(fetches, feed_dict=feed)

            should_report = debug and (step % 100 == 0 or step < 10)
            if should_report:
                tqdm.write('loss: %f (l2_loss: %f), acc: %f' % (loss, l2_loss, acc))
开发者ID:hiwonjoon,项目名称:ICML2019-TREX,代码行数:19,代码来源:preference_learning.py

示例6: loadFromFile

# 需要导入模块: from tqdm import tqdm [as 别名]
# 或者: from tqdm.tqdm import write [as 别名]
def loadFromFile(cls, file_path):
        """
        Loads a Merkle-tree from the provided file, the latter being the result
        of an export (cf. the *MerkleTree.export()* method)

        :param file_path: relative path of the file to load from with
                respect to the current working directory
        :type file_path: str
        :returns: The tree loaded from the provided file
        :rtype: MerkleTree

        :raises WrongJSONFormat: if the JSON object loaded from within the
                    provided file is not a Merkle-tree export (missing
                    ``header``/``hashes`` entries, missing header fields, or
                    a JSON value that cannot be indexed by key at all)
        """
        with open(file_path, 'r') as export_file:
            loaded_object = json.load(export_file)

        # Validate the whole export before building anything: a JSON root that
        # is not an object raises TypeError on key lookup, and a missing
        # 'hashes' entry would otherwise only fail after the tree was created.
        try:
            header = loaded_object['header']
            digests = loaded_object['hashes']
            config = {
                field: header[field]
                for field in ('hash_type', 'encoding', 'raw_bytes', 'security')
            }
        except (KeyError, TypeError):
            raise WrongJSONFormat

        try:
            tree = cls(**config)
        except KeyError:
            # Kept from the original behaviour: a KeyError raised while
            # constructing the tree is reported as a malformed export.
            raise WrongJSONFormat

        tqdm.write('\nFile has been loaded')
        update = tree.update
        for digest in tqdm(digests, desc='Retrieving tree...'):
            update(digest=digest)
        tqdm.write('Tree has been retrieved')
        return tree


    # Comparison 
开发者ID:fmerg,项目名称:pymerkle,代码行数:37,代码来源:tree.py

示例7: test

# 需要导入模块: from tqdm import tqdm [as 别名]
# 或者: from tqdm.tqdm import write [as 别名]
def test(model, loader, criterion, device, dtype):
    """Run one evaluation pass with a progress bar and report the metrics.

    Args:
        model: Network to evaluate; put into eval mode here.
        loader: Batch iterable yielding ``(data, target)``; ``len(loader)`` and
            ``len(loader.dataset)`` are used as normalisers.
        criterion: Loss callable applied to ``(output, target)``.
        device: Device the batches are moved to.
        dtype: Dtype the input batch is cast to.

    Returns:
        Tuple ``(average_loss, top1_rate, top5_rate)``.
    """
    model.eval()
    loss_sum = 0
    top1_hits = 0
    top5_hits = 0

    for _, (inputs, labels) in enumerate(tqdm(loader)):
        inputs = inputs.to(device=device, dtype=dtype)
        labels = labels.to(device=device)
        # Only the forward pass and metric computation run without gradients.
        with torch.no_grad():
            logits = model(inputs)
            loss_sum += criterion(logits, labels).item()  # accumulate per-batch loss
            hits = correct(logits, labels, topk=(1, 5))
        top1_hits += hits[0]
        top5_hits += hits[1]

    loss_sum /= len(loader)
    sample_count = len(loader.dataset)

    summary = (
        '\nTest set: Average loss: {:.4f}, Top1: {}/{} ({:.2f}%), '
        'Top5: {}/{} ({:.2f}%)'
    ).format(
        loss_sum,
        int(top1_hits), sample_count, 100. * top1_hits / sample_count,
        int(top5_hits), sample_count, 100. * top5_hits / sample_count,
    )
    tqdm.write(summary)
    return loss_sum, top1_hits / sample_count, top5_hits / sample_count
开发者ID:Randl,项目名称:MobileNetV2-pytorch,代码行数:24,代码来源:run.py

示例8: loadConversations

# 需要导入模块: from tqdm import tqdm [as 别名]
# 或者: from tqdm.tqdm import write [as 别名]
def loadConversations(self, dirName):
        """
        Collect the QA pairs from every ``gz`` file found under ``dirName``.

        Args:
            dirName (str): folder to load
        Return:
            array(question, answer): the extracted QA pairs
        """
        collected = []
        for filepath in tqdm(self.filesInDir(dirName), "OpenSubtitles data files"):
            # Only files whose name ends in 'gz' are parsed.
            if not filepath.endswith('gz'):
                continue
            try:
                parsed = self.getXML(filepath)
                collected.extend(self.genList(parsed))
            except ValueError:
                # A file that fails with ValueError is skipped, not fatal.
                tqdm.write("Skipping file %s with errors." % filepath)
            except BaseException:
                # Anything else is announced and then propagated unchanged.
                print("Unexpected error:", sys.exc_info()[0])
                raise
        return collected
开发者ID:PENGZhaoqing,项目名称:deepQA,代码行数:22,代码来源:opensubsdata.py

示例9: calculate_stats

# 需要导入模块: from tqdm import tqdm [as 别名]
# 或者: from tqdm.tqdm import write [as 别名]
def calculate_stats(stats, opts):
    """Feed dataloader samples into the running statistics in ``stats``.

    Args:
        stats: Mapping from statistic name (``root_x``, ``lankle_z``, ``red``,
            ``index``, ...) to an accumulator providing ``add_samples``.
        opts: Options object; this function reads ``dataset``, ``batch_size``,
            ``examples_per_epoch``, ``with_image`` and ``epochs``.

    Side effects:
        Updates the accumulators in ``stats`` in place and writes the epoch
        number and ``repr(stats)`` after each epoch via ``tqdm.write``.
    """
    model = create_model(Default_MargiPose_Desc)
    skeleton = CanonicalSkeletonDesc
    loader = create_train_dataloader(
        [opts.dataset], model.data_specs, opts.batch_size, opts.examples_per_epoch, False)
    loader.dataset.without_image = not opts.with_image

    # These do not change between batches, so resolve them once instead of
    # repeating the list search and length lookup on every iteration.
    tracked_joints = (
        ('root', skeleton.root_joint_id),
        ('lankle', skeleton.joint_names.index('left_ankle')),
    )
    axis_names = ('x', 'y', 'z')
    channel_names = ('red', 'green', 'blue')
    last_index = len(loader.dataset) - 1

    for epoch in range(opts.epochs):
        for batch in tqdm(loader, total=len(loader), leave=False, ascii=True):
            joints_3d = np.asarray(batch['target'])
            for prefix, joint_id in tracked_joints:
                for axis, axis_name in enumerate(axis_names):
                    stats[f'{prefix}_{axis_name}'].add_samples(joints_3d[:, joint_id, axis])
            if opts.with_image:
                image = np.asarray(batch['input'])
                for channel, channel_name in enumerate(channel_names):
                    stats[channel_name].add_samples(image[:, channel].ravel())
            # Example indices are divided by the last valid index before sampling.
            stats['index'].add_samples(
                np.asarray(batch['index'], dtype=np.float32) / last_index)
        tqdm.write(f'Epoch {epoch + 1:3d}')
        tqdm.write(repr(stats))
    tqdm.write('Done.')
开发者ID:anibali,项目名称:margipose,代码行数:27,代码来源:calc_dataloader_stats.py

示例10: gt_roidb

# 需要导入模块: from tqdm import tqdm [as 别名]
# 或者: from tqdm.tqdm import write [as 别名]
def gt_roidb(self):
        """
        Build the database of ground-truth regions of interest and cache it.

        Any existing cache file is deleted first, so the annotations are
        loaded again for every image index on each call and the pickle file is
        rewritten afterwards.
        """
        cache_name = self.name + '_gt_roidb.pkl'
        cache_file = os.path.join(self.cache_path, cache_name)
        if os.path.exists(cache_file):
            os.remove(cache_file)

        roidb = []
        for index in self.image_index:
            roidb.append(self._load_pascal_annotation(index))

        with open(cache_file, 'wb') as cache_handle:
            pickle.dump(roidb, cache_handle, pickle.HIGHEST_PROTOCOL)
        tqdm.write('wrote gt roidb to {}'.format(cache_file))

        return roidb
开发者ID:jinyu121,项目名称:CIOD,代码行数:19,代码来源:pascal_voc.py

示例11: stream_handler

# 需要导入模块: from tqdm import tqdm [as 别名]
# 或者: from tqdm.tqdm import write [as 别名]
def stream_handler(loglevel, is_gui):
    """ Build the console logging handler writing to stdout """
    # The console never goes below level 15, whatever level was requested.
    console_level = max(loglevel, 15)
    formatter = FaceswapFormatter("%(asctime)s %(levelname)-8s %(message)s",
                                  datefmt="%m/%d/%Y %H:%M:%S")

    # tqdm.write inserts extra lines in the GUI, so the GUI uses the plain
    # stream handler while the cli goes through the tqdm-aware handler.
    handler_type = logging.StreamHandler if is_gui else TqdmHandler
    console = handler_type(sys.stdout)
    console.setFormatter(formatter)
    console.setLevel(console_level)
    return console
开发者ID:deepfakes,项目名称:faceswap,代码行数:18,代码来源:logger.py

示例12: crash_log

# 需要导入模块: from tqdm import tqdm [as 别名]
# 或者: from tqdm.tqdm import write [as 别名]
def crash_log():
    """ Write debug_buffer to a crash log on crash

    The report is created next to the launched script and contains the
    buffered debug lines, the traceback of the exception being handled and the
    system information (or a note explaining why it could not be imported).

    Returns
    -------
    str
        Full path of the crash report that was written
    """
    # Capture the traceback first, before the import attempt below can raise
    # and replace the exception being reported.
    original_traceback = traceback.format_exc()
    path = os.path.dirname(os.path.realpath(sys.argv[0]))
    filename = os.path.join(path, datetime.now().strftime("crash_report.%Y.%m.%d.%H%M%S%f.log"))
    freeze_log = list(debug_buffer)
    try:
        from lib.sysinfo import sysinfo  # pylint:disable=import-outside-toplevel
    except Exception:  # pylint:disable=broad-except
        sysinfo = ("\n\nThere was an error importing System Information from lib.sysinfo. This is "
                   "probably a bug which should be fixed:\n{}".format(traceback.format_exc()))
    # Use an explicit encoding and replace unencodable characters so that
    # writing the crash report cannot itself fail on a narrow locale encoding.
    with open(filename, "w", encoding="utf8", errors="replace") as outfile:
        outfile.writelines(freeze_log)
        outfile.write(original_traceback)
        outfile.write(sysinfo)
    return filename
开发者ID:deepfakes,项目名称:faceswap,代码行数:18,代码来源:logger.py

示例13: _check_alignments

# 需要导入模块: from tqdm import tqdm [as 别名]
# 或者: from tqdm.tqdm import write [as 别名]
def _check_alignments(self, frame_name):
        """ Report whether alignments exist for the given frame.

        When the frame has no alignments a "skipping" message is written so
        the caller can move on to the next frame.

        Parameters
        ----------
        frame_name: str
            The name of the frame to look up in the alignments

        Returns
        -------
        bool
            The result of the alignments lookup: ``True`` if alignments exist for this
            frame, otherwise ``False``
        """
        found = self._alignments.frame_exists(frame_name)
        if found:
            return found
        message = "No alignment found for {}, skipping".format(frame_name)
        tqdm.write(message)
        return found
开发者ID:deepfakes,项目名称:faceswap,代码行数:22,代码来源:convert.py

示例14: download_objects_of_interest

# 需要导入模块: from tqdm import tqdm [as 别名]
# 或者: from tqdm.tqdm import write [as 别名]
def download_objects_of_interest(download_list):
    """Download every URL in ``download_list`` into ``OUTPUT_DIR`` in parallel.

    Args:
        download_list: Sized collection of URL strings; each file is saved
            under the last path component of its URL.

    Downloads run on a pool of 20 threads while a progress bar counts the
    finished URLs. A failed download does not stop the others; it is reported
    through ``tqdm.write``.
    """
    def fetch_url(url):
        # Always hand back the URL so a failure can be attributed to it.
        try:
            urllib.request.urlretrieve(url, os.path.join(OUTPUT_DIR, url.split("/")[-1]))
            return url, None
        except Exception as e:
            return url, e

    # Context managers make sure the worker threads are shut down and the
    # progress bar is closed even if the loop is interrupted.
    with ThreadPool(20) as pool, \
            tqdm(total=len(download_list), position=1, desc="Download %: ") as df_pbar:
        for url, error in pool.imap_unordered(fetch_url, download_list):
            df_pbar.update(1)
            if error is not None:
                # tqdm.write prints without corrupting the active progress bar.
                tqdm.write("error fetching {}: {}".format(url, error))
开发者ID:harshilpatel312,项目名称:open-images-downloader,代码行数:23,代码来源:download.py

示例15: colorize_exceptions

# 需要导入模块: from tqdm import tqdm [as 别名]
# 或者: from tqdm.tqdm import write [as 别名]
def colorize_exceptions() -> None:
    """Colorizes the system stderr output using pygments if installed"""
    try:
        import traceback
        from pygments import highlight
        from pygments.lexers import get_lexer_by_name
        from pygments.formatters import TerminalFormatter
    except ModuleNotFoundError:
        # Without pygments the default exception hook is left untouched.
        return

    def colorized_excepthook(type_: Type[BaseException],
                             value: BaseException,
                             tb: TracebackType) -> None:
        # Render the traceback as text, then highlight it for the terminal.
        rendered = ''.join(traceback.format_exception(type_, value, tb))
        pytb_lexer = get_lexer_by_name("pytb", stripall=True)
        terminal_formatter = TerminalFormatter()
        sys.stderr.write(highlight(rendered, pytb_lexer, terminal_formatter))

    sys.excepthook = colorized_excepthook  # type: ignore
开发者ID:asappresearch,项目名称:flambe,代码行数:22,代码来源:logging.py


注:本文中的tqdm.tqdm.write方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。