当前位置: 首页>>代码示例>>Python>>正文


Python wandb.log方法代码示例

本文整理汇总了Python中wandb.log方法的典型用法代码示例。如果您想了解：wandb.log方法的具体用法是什么？wandb.log该怎么用？有哪些实际使用的例子？那么这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在的wandb模块的其他用法示例。


下文共展示了wandb.log方法的15个代码示例，默认按受欢迎程度排序。注意：示例1和示例8中的wandb.log调用已被注释掉，示例4、示例5和示例10并未直接调用wandb.log（它们是同一项目中与日志记录相关的代码）。您可以为喜欢或觉得有用的代码点赞，您的评价将有助于系统推荐出更好的Python代码示例。

示例1: remove_duplicates

# 需要导入模块: import wandb [as 别名]
# 或者: from wandb import log [as 别名]
def remove_duplicates(tr_eps, val_eps, test_eps, test_labels):
    """Remove test observations (and their labels) that also appear in train/val.

    Observations are compared by the raw bytes of ``obs.numpy()``, so two
    observations match when their underlying buffers are byte-identical.

    Args:
        tr_eps: training episodes; each episode is an iterable of objects
            exposing ``.numpy()`` (e.g. CPU tensors).
        val_eps: validation episodes, same layout as ``tr_eps``.
        test_eps: test episodes. Modified in place: each episode is replaced
            by a list holding only the observations not seen in train/val.
        test_labels: per-episode label lists aligned with ``test_eps``.
            Modified in place the same way (labels are paired with
            observations via ``zip``).

    Returns:
        The mutated ``(test_eps, test_labels)``.
    """
    def _key(obs):
        # ndarray.tostring() is deprecated since NumPy 1.19 and removed in
        # NumPy 2.x; tobytes() returns exactly the same bytes.
        return obs.numpy().tobytes()

    tr_val_set = {
        _key(obs)
        for obs in chain(chain.from_iterable(tr_eps), chain.from_iterable(val_eps))
    }
    total_before = sum(1 for _ in chain.from_iterable(test_eps))

    for i, episode in enumerate(test_eps):
        # Compute each observation's key once and reuse the mask for both
        # the labels and the observations.
        keep = [_key(obs) not in tr_val_set for obs in episode]
        test_labels[i] = [label for label, k in zip(test_labels[i], keep) if k]
        test_eps[i] = [obs for obs, k in zip(episode, keep) if k]

    test_len = sum(len(episode) for episode in test_eps)
    dups = total_before - test_len
    print('Duplicates: {}, Test Len: {}'.format(dups, test_len))
    # wandb.log({'Duplicates': dups, 'Test Len': test_len})  (disabled, as in the original)
    return test_eps, test_labels
开发者ID:mila-iqia,项目名称:atari-representation-learning,代码行数:19,代码来源:label_preprocess.py

示例2: write_log

# 需要导入模块: import wandb [as 别名]
# 或者: from wandb import log [as 别名]
def write_log(self, log_value: tuple):
    """Print an episode summary and, when ``self.args.log`` is set, send it to wandb.

    ``log_value`` is unpacked as
    ``(i_episode, n_step, score, actor_loss, critic_loss, total_loss)``.
    """
    (i_episode, n_step, score,
     actor_loss, critic_loss, total_loss) = log_value

    template = (
        "[INFO] episode %d\tepisode steps: %d\ttotal score: %d\n"
        "total loss: %f\tActor loss: %f\tCritic loss: %f\n"
    )
    print(template % (i_episode, n_step, score, total_loss, actor_loss, critic_loss))

    if not self.args.log:
        return

    metrics = {
        "total loss": total_loss,
        "actor loss": actor_loss,
        "critic loss": critic_loss,
        "score": score,
    }
    wandb.log(metrics)
开发者ID:medipixel,项目名称:rl_algorithms,代码行数:21,代码来源:agent.py

示例3: write_log

# 需要导入模块: import wandb [as 别名]
# 或者: from wandb import log [as 别名]
def write_log(self, log_value: tuple):
    """Print policy/value losses for an episode; mirror them to wandb if enabled.

    ``log_value`` is ``(i, score, policy_loss, value_loss)``; the reported total
    loss is ``policy_loss + value_loss``. The step count comes from
    ``self.episode_step``.
    """
    i, score, policy_loss, value_loss = log_value
    total_loss = policy_loss + value_loss

    header = "[INFO] episode %d\tepisode step: %d\ttotal score: %d\n"
    losses = "total loss: %.4f\tpolicy loss: %.4f\tvalue loss: %.4f\n"
    values = (i, self.episode_step, score, total_loss, policy_loss, value_loss)
    print((header + losses) % values)

    if self.args.log:
        wandb.log({
            "total loss": total_loss,
            "policy loss": policy_loss,
            "value loss": value_loss,
            "score": score,
        })
开发者ID:medipixel,项目名称:rl_algorithms,代码行数:21,代码来源:agent.py

示例4: run

# 需要导入模块: import wandb [as 别名]
# 或者: from wandb import log [as 别名]
def run(self):
    """Main logging loop: keep receiving packets until ``max_update_step`` is reached.

    For every packet found in ``self.log_info_queue`` the packet is deserialized
    with ``pa.deserialize``. The local model is then synchronized with its
    ``state_dict`` and evaluated with ``self.test``. The resulting average score
    is stored in ``log_value["avg_score"]`` before ``self.write_log`` is called.
    """
    if self.args.log:
        self.set_wandb()

    while self.update_step < self.args.max_update_step:
        self.recv_log_info()
        if not self.log_info_queue:
            continue

        # pop() with no argument takes the last-appended item (assuming a
        # list/deque) — i.e. the newest packet is processed first.
        packet = pa.deserialize(self.log_info_queue.pop())
        state_dict = packet["state_dict"]
        log_value = packet["log_value"]
        self.update_step = log_value["update_step"]

        self.synchronize(state_dict)
        log_value["avg_score"] = self.test(self.update_step)
        self.write_log(log_value)
开发者ID:medipixel,项目名称:rl_algorithms,代码行数:20,代码来源:distributed_logger.py

示例5: __init__

# 需要导入模块: import wandb [as 别名]
# 或者: from wandb import log [as 别名]
def __init__(
    self,
    train_tensors=None,
    wandb_name=None,
    wandb_project=None,
    args=None,
    update_freq=25,
):
    """
    Args:
        train_tensors: list of tensors to evaluate and log based on training
            batches. Defaults to a new empty list for each instance.
        wandb_name: wandb experiment name
        wandb_project: wandb project name
        args: argparse flags - will be logged as hyperparameters
        update_freq: frequency with which to log updates
    """
    super().__init__()

    if not _WANDB_AVAILABLE:
        # Best-effort: report the missing dependency but still construct the callback.
        logging.error("Could not import wandb. Did you install it (pip install --upgrade wandb)?")

    self._update_freq = update_freq
    # A literal ``[]`` default is evaluated once and would be the same list
    # object for every instance created without ``train_tensors``; create a
    # fresh list per call instead.
    self._train_tensors = [] if train_tensors is None else train_tensors
    self._name = wandb_name
    self._project = wandb_project
    self._args = args
开发者ID:NVIDIA,项目名称:NeMo,代码行数:23,代码来源:deprecated_callbacks.py

示例6: visualize_recon

# 需要导入模块: import wandb [as 别名]
# 或者: from wandb import log [as 别名]
def visualize_recon(self, input_image, recon_image, test=False):
    """Show input and reconstruction side by side; optionally save and log it.

    Both batches are tiled with ``torchvision.utils.make_grid``. The grids are
    then concatenated along width with a 10-pixel column of ones between them
    (white for images in [0, 1]).

    Saving, when ``self.file_save`` is set: test images go to
    ``self.test_output_dir`` with the iteration in the file name. Train images
    are written to a fixed file name in ``self.train_output_dir``, overwriting
    the previous one.

    When ``self.use_wandb`` is set, the composite is logged with
    ``step=self.iter``.
    """
    input_image = torchvision.utils.make_grid(input_image)
    recon_image = torchvision.utils.make_grid(recon_image)

    # make_grid's output height grows with the number of grid rows (i.e. the
    # batch size), so a separator cached on an earlier call — e.g. from a
    # train batch — can have the wrong height for a later test batch and
    # torch.cat would raise. Rebuild it whenever the height no longer matches.
    grid_height = input_image.size(1)
    if self.white_line is None or self.white_line.size(1) != grid_height:
        self.white_line = torch.ones((3, grid_height, 10)).to(self.device)

    samples = torch.cat([input_image, self.white_line, recon_image], dim=2)

    if self.file_save:
        if test:
            file_name = os.path.join(self.test_output_dir, '{}_{}.{}'.format(c.RECON, self.iter, c.JPG))
        else:
            file_name = os.path.join(self.train_output_dir, '{}.{}'.format(c.RECON, c.JPG))
        torchvision.utils.save_image(samples, file_name)

    if self.use_wandb:
        import wandb
        wandb.log({c.RECON_IMAGE: wandb.Image(samples, caption=str(self.iter))},
                  step=self.iter)
开发者ID:amir-abdi,项目名称:disentanglement-pytorch,代码行数:22,代码来源:base_disentangler.py

示例7: visualize_figure

# 需要导入模块: import wandb [as 别名]
# 或者: from wandb import log [as 别名]
def visualize_figure(self, fig):
    """Log ``fig`` under the key ``"figure"`` to the active wandb run.

    Always emits an info log line; the figure is only sent when
    ``wandb.run`` is truthy (i.e. a run has been started).
    """
    import wandb

    logger.info("wandb.visualize_figure() called...")
    if not wandb.run:
        return
    wandb.log({"figure": fig})
开发者ID:uber,项目名称:ludwig,代码行数:7,代码来源:wandb.py

示例8: remove_low_entropy_labels

# 需要导入模块: import wandb [as 别名]
# 或者: from wandb import log [as 别名]
def remove_low_entropy_labels(episode_labels, entropy_threshold=0.3):
    flat_label_list = list(chain.from_iterable(episode_labels))
    counts = {}

    for label_dict in flat_label_list:
        for k in label_dict:
            counts[k] = counts.get(k, {})
            v = label_dict[k]
            counts[k][v] = counts[k].get(v, 0) + 1
    low_entropy_labels = []

    entropy_dict = {}
    for k in counts:
        entropy = torch.distributions.Categorical(
            torch.tensor([x / len(flat_label_list) for x in counts[k].values()])).entropy()
        entropy_dict['entropy_' + k] = entropy
        if entropy < entropy_threshold:
            print("Deleting {} for being too low in entropy! Sorry, dood!".format(k))
            low_entropy_labels.append(k)

    for e in episode_labels:
        for obs in e:
            for key in low_entropy_labels:
                del obs[key]
    # wandb.log(entropy_dict)
    return episode_labels, entropy_dict 
开发者ID:mila-iqia,项目名称:atari-representation-learning,代码行数:28,代码来源:label_preprocess.py

示例9: get_pretrained_rl_representations

# 需要导入模块: import wandb [as 别名]
# 或者: from wandb import log [as 别名]
def get_pretrained_rl_representations(args, steps):
    """Collect representations from the PPO checkpoint chosen by ``args.checkpoint_index``.

    The checkpoint is looked up in ``checkpointed_steps_full_sorted``. The mean
    reward and checkpoint are logged to wandb.

    Returns:
        ``(episodes, episode_labels)`` as produced by ``get_ppo_representations``.
    """
    chosen_checkpoint = checkpointed_steps_full_sorted[args.checkpoint_index]
    collected = get_ppo_representations(args, steps, chosen_checkpoint)
    episodes, episode_labels, mean_reward = collected
    wandb.log({"reward": mean_reward, "checkpoint": chosen_checkpoint})
    return episodes, episode_labels
开发者ID:mila-iqia,项目名称:atari-representation-learning,代码行数:7,代码来源:pretrained_agents.py

示例10: __init__

# 需要导入模块: import wandb [as 别名]
# 或者: from wandb import log [as 别名]
def __init__(self, hp, logdir):
    """Store ``hp`` and start whichever logging backends ``hp.log`` enables.

    With ``hp.log.use_tensorboard``, a ``SummaryWriter`` writing to ``logdir``
    is created as ``self.tensorboard``.

    With ``hp.log.use_wandb``, ``wandb.init`` is called with
    ``hp.log.wandb_init_conf`` and the full ``hp`` as the run config.
    """
    self.hp = hp
    log_cfg = hp.log

    if log_cfg.use_tensorboard:
        self.tensorboard = SummaryWriter(logdir)

    if log_cfg.use_wandb:
        init_kwargs = log_cfg.wandb_init_conf.to_dict()
        init_kwargs["config"] = hp.to_dict()
        wandb.init(**init_kwargs)
开发者ID:ryul99,项目名称:pytorch-project-template,代码行数:10,代码来源:writer.py

示例11: logging_with_step

# 需要导入模块: import wandb [as 别名]
# 或者: from wandb import log [as 别名]
def logging_with_step(self, value, step, logging_name):
    """Record scalar ``value`` at ``step`` under ``logging_name`` on every enabled backend."""
    log_cfg = self.hp.log
    if log_cfg.use_tensorboard:
        self.tensorboard.add_scalar(logging_name, value, step)
    if log_cfg.use_wandb:
        wandb.log({logging_name: value}, step=step)
开发者ID:ryul99,项目名称:pytorch-project-template,代码行数:7,代码来源:writer.py

示例12: test_epoch_end

# 需要导入模块: import wandb [as 别名]
# 或者: from wandb import log [as 别名]
def test_epoch_end(self, epoch_num: int):
    """Report test metrics once the test epoch has finished.

    Prints one table per label namespace and logs the aggregated metric. When
    ``self.use_wandb`` is set, the metric is also sent to wandb as a string.
    Finally ``self.summaryWriter`` is closed.

    Parameters
    ----------
    epoch_num : int
        Epoch after which the test dataset was run; it is displayed as
        ``epoch_num + 1``.

    """
    calc = self.test_metric_calc
    for namespace, table in calc.report_metrics().items():
        self.msg_printer.divider(text=f"Test Metrics for {namespace.upper()}")
        print(table)

    metrics = calc.get_metric()
    shown_epoch = epoch_num + 1
    self.msg_printer.divider(f"Test @ Epoch {shown_epoch}")
    self.test_logger.info(f"Test Metrics @ Epoch {shown_epoch} - {metrics}")
    if self.use_wandb:
        wandb.log({"test_metrics": str(metrics)})

    self.summaryWriter.close()
开发者ID:abhinavkashyap,项目名称:sciwing,代码行数:28,代码来源:engine.py

示例13: write_log

# 需要导入模块: import wandb [as 别名]
# 或者: from wandb import log [as 别名]
def write_log(self, log_value: tuple):
    """Print DQN progress (loss, Q-value, epsilon) and optionally send it to wandb.

    ``log_value`` is ``(i, loss, score, avg_time_cost)``. ``loss[0]`` is
    reported as the DQN loss and ``loss[1]`` as the average Q-value.
    """
    i, loss, score, avg_time_cost = log_value
    dqn_loss, avg_q_value = loss[0], loss[1]

    template = (
        "[INFO] episode %d, episode step: %d, total step: %d, total score: %f\n"
        "epsilon: %f, loss: %f, avg q-value: %f (spent %.6f sec/step)\n"
    )
    print(template % (
        i, self.episode_step, self.total_step, score,
        self.epsilon, dqn_loss, avg_q_value, avg_time_cost,
    ))

    if self.args.log:
        wandb.log({
            "score": score,
            "epsilon": self.epsilon,
            "dqn loss": dqn_loss,
            "avg q values": avg_q_value,
            "time per each step": avg_time_cost,
            "total_step": self.total_step,
        })

    # pylint: disable=no-self-use, unnecessary-pass 
开发者ID:medipixel,项目名称:rl_algorithms,代码行数:33,代码来源:agent.py

示例14: write_log

# 需要导入模块: import wandb [as 别名]
# 或者: from wandb import log [as 别名]
def write_log(self, log_value: tuple):
    """Print SAC losses for an episode and optionally send them to wandb.

    ``log_value`` is ``(i, loss, score, policy_update_freq, avg_time_cost)``.
    ``loss`` must support ``.sum()`` and indexing 0-5 (looks like a NumPy
    array — confirm against the caller). The entries are reported as:
    actor (multiplied by ``policy_update_freq`` for display), qf_1, qf_2, vf,
    alpha, and n_qf_mask. The last one is printed only.

    NOTE(review): the reported total loss is ``loss.sum()``, which adds the
    unscaled actor entry and also ``loss[5]`` (labelled n_qf_mask, which reads
    like a count rather than a loss) — confirm this is intended.
    """
    i, loss, score, policy_update_freq, avg_time_cost = log_value
    total_loss = loss.sum()
    actor_loss = loss[0] * policy_update_freq
    qf_1_loss, qf_2_loss, vf_loss, alpha_loss, n_qf_mask = (
        loss[1], loss[2], loss[3], loss[4], loss[5]
    )

    template = (
        "[INFO] episode %d, episode_step %d, total step %d, total score: %d\n"
        "total loss: %.3f actor_loss: %.3f qf_1_loss: %.3f qf_2_loss: %.3f "
        "vf_loss: %.3f alpha_loss: %.3f n_qf_mask: %d (spent %.6f sec/step)\n"
    )
    print(template % (
        i, self.episode_step, self.total_step, score,
        total_loss, actor_loss, qf_1_loss, qf_2_loss,
        vf_loss, alpha_loss, n_qf_mask, avg_time_cost,
    ))

    if self.args.log:
        wandb.log({
            "score": score,
            "total loss": total_loss,
            "actor loss": actor_loss,
            "qf_1 loss": qf_1_loss,
            "qf_2 loss": qf_2_loss,
            "vf loss": vf_loss,
            "alpha loss": alpha_loss,
            "time per each step": avg_time_cost,
        })
开发者ID:medipixel,项目名称:rl_algorithms,代码行数:40,代码来源:sac_agent.py

示例15: write_log

# 需要导入模块: import wandb [as 别名]
# 或者: from wandb import log [as 别名]
def write_log(self, log_value: tuple):
    """Print DDPG losses for an episode and optionally send them to wandb.

    ``log_value`` is ``(i, loss, score, avg_time_cost)``. ``loss`` must support
    ``.sum()`` and indexing 0-2 (looks like a NumPy array — confirm). Entries
    are reported as actor loss, critic loss and n_qf_mask; n_qf_mask is
    printed but not sent to wandb.

    NOTE(review): the total loss is ``loss.sum()``, which also adds ``loss[2]``
    (labelled n_qf_mask, which reads like a count rather than a loss) —
    confirm this is intended.
    """
    i, loss, score, avg_time_cost = log_value
    total_loss = loss.sum()
    actor_loss, critic_loss, n_qf_mask = loss[0], loss[1], loss[2]

    template = (
        "[INFO] episode %d, episode step: %d, total step: %d, total score: %d\n"
        "total loss: %f actor_loss: %.3f critic_loss: %.3f, n_qf_mask: %d "
        "(spent %.6f sec/step)\n"
    )
    print(template % (
        i, self.episode_step, self.total_step, score,
        total_loss, actor_loss, critic_loss, n_qf_mask, avg_time_cost,
    ))

    if self.args.log:
        wandb.log({
            "score": score,
            "total loss": total_loss,
            "actor loss": actor_loss,
            "critic loss": critic_loss,
            "time per each step": avg_time_cost,
        })
开发者ID:medipixel,项目名称:rl_algorithms,代码行数:34,代码来源:ddpg_agent.py


注:本文中的wandb.log方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。