

Python logger.info Function Code Examples

This article collects typical usage examples of the Python function tensorpack.utils.logger.info. If you are wondering how to call logger.info, how it is used in practice, or what real-world examples look like, the curated code examples below should help.


Below are 15 code examples of the info function, sorted by popularity by default.
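
Before the examples, here is a minimal, self-contained usage sketch of tensorpack.utils.logger.info. The numeric values and the extra logger.warning call are illustrative assumptions and do not come from the examples below; they only assume that the module exposes the standard logging levels.

from tensorpack.utils import logger

nr_tower = 4                      # illustrative values only
batch = 128 // nr_tower
# logger.info takes an already-formatted string; the caller does the formatting,
# as in all of the examples below.
logger.info("Running on {} towers. Batch size per tower: {}".format(nr_tower, batch))
# The other standard logging levels are exposed in the same way.
logger.warning("This is a warning message.")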

Example 1: get_config

def get_config(model, nr_tower):
    batch = TOTAL_BATCH_SIZE // nr_tower

    logger.info("Running on {} towers. Batch size per tower: {}".format(nr_tower, batch))
    dataset_train = get_data('train', batch)
    dataset_val = get_data('val', batch)

    step_size = 1280000 // TOTAL_BATCH_SIZE
    max_iter = 3 * 10**5
    max_epoch = (max_iter // step_size) + 1
    callbacks = [
        ModelSaver(),
        ScheduledHyperParamSetter('learning_rate',
                                  [(0, 0.5), (max_iter, 0)],
                                  interp='linear', step_based=True),
    ]
    infs = [ClassificationError('wrong-top1', 'val-error-top1'),
            ClassificationError('wrong-top5', 'val-error-top5')]
    if nr_tower == 1:
        # single-GPU inference with queue prefetch
        callbacks.append(InferenceRunner(QueueInput(dataset_val), infs))
    else:
        # multi-GPU inference (with mandatory queue prefetch)
        callbacks.append(DataParallelInferenceRunner(
            dataset_val, infs, list(range(nr_tower))))

    return TrainConfig(
        model=model,
        dataflow=dataset_train,
        callbacks=callbacks,
        steps_per_epoch=step_size,
        max_epoch=max_epoch,
    )
Developer: quanlzheng, Project: tensorpack, Lines: 33, Source: shufflenet.py

Example 2: _sync

 def _sync(self):
     logger.info("Updating weights ...")
     dic = {v.name: v.eval() for v in self.vars}
     self.shared_dic['params'] = dic
     self.condvar.acquire()
     self.condvar.notify_all()
     self.condvar.release()
Developer: j50888, Project: tensorpack, Lines: 7, Source: simulator.py

Example 3: __init__

    def __init__(self,
                 predictor_io_names,
                 player,
                 state_shape,
                 batch_size,
                 memory_size, init_memory_size,
                 exploration, end_exploration, exploration_epoch_anneal,
                 update_frequency, history_len):
        """
        Args:
            predictor_io_names (tuple of list of str): input/output names to
                predict Q value from state.
            player (RLEnvironment): the player.
            history_len (int): length of history frames to concat. Zero-filled
                initial frames.
            update_frequency (int): number of new transitions to add to memory
                after sampling a batch of transitions for training.
        """
        init_memory_size = int(init_memory_size)

        for k, v in locals().items():
            if k != 'self':
                setattr(self, k, v)
        self.num_actions = player.get_action_space().num_actions()
        logger.info("Number of Legal actions: {}".format(self.num_actions))

        self.rng = get_rng(self)
        self._init_memory_flag = threading.Event()  # tell if memory has been initialized

        # TODO just use a semaphore?
        # a queue to receive notifications to populate memory
        self._populate_job_queue = queue.Queue(maxsize=5)

        self.mem = ReplayMemory(memory_size, state_shape, history_len)
Developer: j50888, Project: tensorpack, Lines: 34, Source: expreplay.py

Example 4: _init_memory

    def _init_memory(self):
        logger.info("Populating replay memory with epsilon={} ...".format(self.exploration))

        with get_tqdm(total=self.init_memory_size) as pbar:
            while len(self.mem) < self.init_memory_size:
                self._populate_exp()
                pbar.update()
        self._init_memory_flag.set()
Developer: tobyma, Project: tensorpack, Lines: 8, Source: expreplay.py

Example 5: __init__

 def __init__(self, dirname, label='phoneme'):
     self.dirname = dirname
     assert os.path.isdir(dirname), dirname
     self.filelists = [k for k in fs.recursive_walk(self.dirname)
                       if k.endswith('.wav')]
     logger.info("Found {} wav files ...".format(len(self.filelists)))
     assert len(self.filelists), "Found no '.wav' files!"
     assert label in ['phoneme', 'letter'], label
     self.label = label
Developer: ahuirecome, Project: tensorpack, Lines: 9, Source: create-lmdb.py

Example 6: eval_model_multithread

def eval_model_multithread(pred, nr_eval, get_player_fn):
    """
    Args:
        pred (OfflinePredictor): state -> Qvalue
    """
    NR_PROC = min(multiprocessing.cpu_count() // 2, 8)
    with pred.sess.as_default():
        mean, max = eval_with_funcs([pred] * NR_PROC, nr_eval, get_player_fn)
    logger.info("Average Score: {}; Max Score: {}".format(mean, max))
Developer: ahuirecome, Project: tensorpack, Lines: 9, Source: common.py

Example 7: update_target_param

 def update_target_param():
     vars = tf.global_variables()
     ops = []
     G = tf.get_default_graph()
     for v in vars:
         target_name = v.op.name
         if target_name.startswith('target'):
             new_name = target_name.replace('target/', '')
             logger.info("{} <- {}".format(target_name, new_name))
             ops.append(v.assign(G.get_tensor_by_name(new_name + ':0')))
     return tf.group(*ops, name='update_target_network')
Developer: caserzer, Project: tensorpack, Lines: 11, Source: DQNModel.py

Example 8: eval_with_funcs

def eval_with_funcs(predictors, nr_eval, get_player_fn, verbose=False):
    """
    Args:
        predictors ([PredictorBase])
    """
    class Worker(StoppableThread, ShareSessionThread):
        def __init__(self, func, queue):
            super(Worker, self).__init__()
            self._func = func
            self.q = queue

        def func(self, *args, **kwargs):
            if self.stopped():
                raise RuntimeError("stopped!")
            return self._func(*args, **kwargs)

        def run(self):
            with self.default_sess():
                player = get_player_fn(train=False)
                while not self.stopped():
                    try:
                        score = play_one_episode(player, self.func)
                    except RuntimeError:
                        return
                    self.queue_put_stoppable(self.q, score)

    q = queue.Queue()
    threads = [Worker(f, q) for f in predictors]

    for k in threads:
        k.start()
        time.sleep(0.1)  # avoid simulator bugs
    stat = StatCounter()

    def fetch():
        r = q.get()
        stat.feed(r)
        if verbose:
            logger.info("Score: {}".format(r))

    for _ in tqdm(range(nr_eval), **get_tqdm_kwargs()):
        fetch()
    # waiting is necessary, otherwise the estimated mean score is biased
    logger.info("Waiting for all the workers to finish the last run...")
    for k in threads:
        k.stop()
    for k in threads:
        k.join()
    while q.qsize():
        fetch()

    if stat.count > 0:
        return (stat.average, stat.max)
    return (0, 0)
Developer: quanlzheng, Project: tensorpack, Lines: 54, Source: common.py

Example 9: convert_param_name

def convert_param_name(param):
    resnet_param = {}
    for k, v in six.iteritems(param):
        try:
            newname = name_conversion(k)
        except Exception:
            logger.error("Exception when processing caffe layer {}".format(k))
            raise
        logger.info("Name Transform: " + k + ' --> ' + newname)
        resnet_param[newname] = v
    return resnet_param
Developer: caserzer, Project: tensorpack, Lines: 11, Source: load-resnet.py

Example 10: compute_mean_std

def compute_mean_std(db, fname):
    ds = LMDBSerializer.load(db, shuffle=False)
    ds.reset_state()
    o = OnlineMoments()
    for dp in get_tqdm(ds):
        feat = dp[0]  # len x dim
        for f in feat:
            o.feed(f)
    logger.info("Writing to {} ...".format(fname))
    with open(fname, 'wb') as f:
        f.write(serialize.dumps([o.mean, o.std]))
Developer: quanlzheng, Project: tensorpack, Lines: 11, Source: create-lmdb.py

Example 11: run

 def run(self):
     self.clients = defaultdict(self.ClientState)
     try:
         while True:
             msg = loads(self.c2s_socket.recv(copy=False).bytes)
             ident, state, reward, isOver = msg
             client = self.clients[ident]
             if client.ident is None:
                 client.ident = ident
             # maybe check history and warn about dead client?
             self._process_msg(client, state, reward, isOver)
     except zmq.ContextTerminated:
         logger.info("[Simulator] Context was terminated.")
Developer: ahuirecome, Project: tensorpack, Lines: 13, Source: simulator.py

Example 12: compute_mean_std

def compute_mean_std(db, fname):
    ds = LMDBDataPoint(db, shuffle=False)
    ds.reset_state()
    o = OnlineMoments()
    with get_tqdm(total=ds.size()) as bar:
        for dp in ds.get_data():
            feat = dp[0]  # len x dim
            for f in feat:
                o.feed(f)
            bar.update()
    logger.info("Writing to {} ...".format(fname))
    with open(fname, 'wb') as f:
        f.write(serialize.dumps([o.mean, o.std]))
Developer: ahuirecome, Project: tensorpack, Lines: 13, Source: create-lmdb.py

Example 13: _trigger_epoch

 def _trigger_epoch(self):
     if self.exploration > self.end_exploration:
         self.exploration -= self.exploration_epoch_anneal
         logger.info("Exploration changed to {}".format(self.exploration))
     # log player statistics
     stats = self.player.stats
     for k, v in six.iteritems(stats):
         try:
             mean, max = np.mean(v), np.max(v)
             self.trainer.add_scalar_summary('expreplay/mean_' + k, mean)
             self.trainer.add_scalar_summary('expreplay/max_' + k, max)
         except Exception:
             pass
     self.player.reset_stat()
Developer: j50888, Project: tensorpack, Lines: 14, Source: expreplay.py

Example 14: print_class_histogram

    def print_class_histogram(self, imgs):
        nr_class = len(COCOMeta.class_names)
        hist_bins = np.arange(nr_class + 1)

        # Histogram of ground-truth objects
        gt_hist = np.zeros((nr_class,), dtype=np.int64)  # np.int was removed in recent NumPy; use an explicit dtype
        for entry in imgs:
            # filter crowd?
            gt_inds = np.where(
                (entry['class'] > 0) & (entry['is_crowd'] == 0))[0]
            gt_classes = entry['class'][gt_inds]
            gt_hist += np.histogram(gt_classes, bins=hist_bins)[0]
        data = [[COCOMeta.class_names[i], v] for i, v in enumerate(gt_hist)]
        data.append(['total', sum([x[1] for x in data])])
        table = tabulate(data, headers=['class', '#box'], tablefmt='pipe')
        logger.info("Ground-Truth Boxes:\n" + colored(table, 'cyan'))
Developer: quanlzheng, Project: tensorpack, Lines: 16, Source: coco.py

Example 15: texture_loss

                def texture_loss(x, p=16):
                    _, h, w, c = x.get_shape().as_list()
                    x = normalize(x)
                    assert h % p == 0 and w % p == 0
                    logger.info('Create texture loss for layer {} with shape {}'.format(x.name, x.get_shape()))

                    x = tf.space_to_batch_nd(x, [p, p], [[0, 0], [0, 0]])  # [b * ?, h/p, w/p, c]
                    x = tf.reshape(x, [p, p, -1, h // p, w // p, c])       # [p, p, b, h/p, w/p, c]
                    x = tf.transpose(x, [2, 3, 4, 0, 1, 5])                # [b, h/p, w/p, p, p, c]
                    patches_a, patches_b = tf.split(x, 2, axis=0)          # each is b,h/p,w/p,p,p,c

                    patches_a = tf.reshape(patches_a, [-1, p, p, c])       # [b * ?, p, p, c]
                    patches_b = tf.reshape(patches_b, [-1, p, p, c])       # [b * ?, p, p, c]
                    return tf.losses.mean_squared_error(
                        gram_matrix(patches_a),
                        gram_matrix(patches_b),
                        reduction=Reduction.MEAN
                    )
Developer: quanlzheng, Project: tensorpack, Lines: 18, Source: enet-pat.py


Note: The tensorpack.utils.logger.info examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The code snippets are drawn from open-source projects contributed by their authors; copyright remains with the original authors, and distribution or use should follow the corresponding projects' licenses. Do not republish without permission.