当前位置: 首页>>代码示例>>Python>>正文


Python tensorboardX.SummaryWriter方法代码示例

本文整理汇总了Python中tensorboardX.SummaryWriter类的典型用法代码示例。如果您想了解tensorboardX.SummaryWriter的具体用法、调用方式或使用示例,这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块tensorboardX的其他用法示例。


下文共展示了tensorboardX.SummaryWriter类的15个代码示例,这些示例默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞,您的评价将有助于系统推荐出更好的Python代码示例。

示例1: __init__

# 需要导入模块: import tensorboardX [as 别名]
# 或者: from tensorboardX import SummaryWriter [as 别名]
def __init__(self, fsm: FileStructManager, is_continue: bool, network_name: str = None):
        """Create a TensorBoard writer and a plain-text log file in the fsm-managed directory.

        The directory is registered with ``fsm`` and resolved with ``fsm.get_path(self)``.
        If that returns ``None``, no writer and no log file are created. Otherwise:

        * if ``network_name`` is given, it is appended as a sub-directory;
        * when not in continue mode (neither ``fsm.in_continue_mode()`` nor
          ``is_continue``) and the directory already exists, the first free
          ``<dir>_v<N>`` directory is used instead, so earlier runs are kept;
        * ``log.txt`` is opened in append mode when continuing, truncating mode otherwise.

        Args:
            fsm: file structure manager that owns the output directory.
            is_continue: whether this run continues a previous one.
            network_name: optional sub-directory for this network's logs.
        """
        super().__init__()
        self.__writer = None
        self.__txt_log_file = None

        fsm.register_dir(self)
        log_dir = fsm.get_path(self)
        if log_dir is None:
            return

        if network_name is not None:
            log_dir = os.path.join(log_dir, network_name)

        # Continue mode may be requested by the fsm or by the caller. Use ONE flag both for
        # choosing the directory and for the log-file mode. Otherwise an fsm-driven
        # continue would reuse the directory but truncate its existing log.txt.
        continue_mode = fsm.in_continue_mode() or is_continue

        if not continue_mode and os.path.isdir(log_dir):
            # isdir() already implies exists(); find the first unused "_v<N>" suffix.
            version = 0
            candidate = "{}_v{}".format(log_dir, version)
            while os.path.isdir(candidate):
                version += 1
                candidate = "{}_v{}".format(log_dir, version)
            log_dir = candidate

        os.makedirs(log_dir, exist_ok=True)
        self.__writer = SummaryWriter(log_dir)
        self.__txt_log_file = open(os.path.join(log_dir, "log.txt"), 'a' if continue_mode else 'w')
开发者ID:toodef,项目名称:neural-pipeline,代码行数:25,代码来源:tensorboard.py

示例2: run

# 需要导入模块: import tensorboardX [as 别名]
# 或者: from tensorboardX import SummaryWriter [as 别名]
def run(self):
        """Consume summary requests from ``self.queue`` and write them with a SummaryWriter.

        The writer logs to ``<env.model_dir>/<env.args.run>``. First, as a best effort, the
        model returned by ``self.env.load()`` is traced with a random
        ``(1, 3, height, width)`` input and its graph is added. ``height`` and ``width``
        come from the ``[image] size`` config entry.

        Then each queue item is a ``(name, kwargs)`` pair. ``name is None`` stops the loop;
        otherwise ``self.summary_<name>(**kwargs)`` is called. Failures are printed and
        skipped, so one bad request does not stop logging. The writer is closed when the
        loop ends.
        """
        self.writer = SummaryWriter(os.path.join(self.env.model_dir, self.env.args.run))
        try:
            # Graph export is optional; never let it prevent summary logging.
            # `except Exception` (not bare except) lets KeyboardInterrupt/SystemExit through.
            try:
                height, width = tuple(map(int, self.config.get('image', 'size').split()))
                tensor = torch.randn(1, 3, height, width)
                _, _, dnn = self.env.load()
                self.writer.add_graph(dnn, (torch.autograd.Variable(tensor),))
            except Exception:
                traceback.print_exc()

            while True:
                name, kwargs = self.queue.get()
                if name is None:
                    break
                try:
                    # The lookup is inside the try, so an unknown summary name is reported
                    # and skipped instead of terminating the consumer loop.
                    func = getattr(self, 'summary_' + name)
                    func(**kwargs)
                except Exception:
                    traceback.print_exc()
        finally:
            # Flush pending events to disk and release the event file.
            self.writer.close()
开发者ID:ruiminshen,项目名称:yolo2-pytorch,代码行数:20,代码来源:train.py

示例3: __init__

# 需要导入模块: import tensorboardX [as 别名]
# 或者: from tensorboardX import SummaryWriter [as 别名]
def __init__(self, state_dim, action_dim, max_action):
        """Build the DDPG actor/critic networks, their target copies, optimizers and logger.

        Args:
            state_dim: size of the state input passed to ``Actor``/``Critic``.
            action_dim: size of the action passed to ``Actor``/``Critic``.
            max_action: forwarded to ``Actor``.
        """
        actor_args = (state_dim, action_dim, max_action)
        critic_args = (state_dim, action_dim)

        # Actor: online network, a target initialised from it, and its optimizer (lr 1e-4).
        # Construction order is kept as-is; it presumably affects random weight init.
        self.actor = Actor(*actor_args).to(device)
        self.actor_target = Actor(*actor_args).to(device)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=1e-4)

        # Critic: same pattern, with lr 1e-3.
        self.critic = Critic(*critic_args).to(device)
        self.critic_target = Critic(*critic_args).to(device)
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=1e-3)

        self.replay_buffer = Replay_buffer()
        self.writer = SummaryWriter(directory)

        # Training counters.
        self.num_critic_update_iteration = self.num_actor_update_iteration = self.num_training = 0
开发者ID:sweetice,项目名称:Deep-reinforcement-learning-with-pytorch,代码行数:18,代码来源:DDPG.py

示例4: __init__

# 需要导入模块: import tensorboardX [as 别名]
# 或者: from tensorboardX import SummaryWriter [as 别名]
def __init__(self, state_dim, action_dim, max_action):
        """Build the TD3 actor, twin critics, their target copies, optimizers and logger.

        Args:
            state_dim: size of the state input passed to ``Actor``/``Critic``.
            action_dim: size of the action passed to ``Actor``/``Critic``.
            max_action: forwarded to ``Actor`` and stored on ``self.max_action``.
        """
        actor_args = (state_dim, action_dim, max_action)
        critic_args = (state_dim, action_dim)

        # Networks are created in the same fixed order as before; it presumably
        # matters for random weight initialisation.
        self.actor = Actor(*actor_args).to(device)
        self.actor_target = Actor(*actor_args).to(device)
        self.critic_1 = Critic(*critic_args).to(device)
        self.critic_1_target = Critic(*critic_args).to(device)
        self.critic_2 = Critic(*critic_args).to(device)
        self.critic_2_target = Critic(*critic_args).to(device)

        # One Adam optimizer (default hyper-parameters) per online network.
        self.actor_optimizer = optim.Adam(self.actor.parameters())
        self.critic_1_optimizer = optim.Adam(self.critic_1.parameters())
        self.critic_2_optimizer = optim.Adam(self.critic_2.parameters())

        # Each target starts as an exact copy of its online network.
        for online, target in ((self.actor, self.actor_target),
                               (self.critic_1, self.critic_1_target),
                               (self.critic_2, self.critic_2_target)):
            target.load_state_dict(online.state_dict())

        self.max_action = max_action
        self.memory = Replay_buffer(args.capacity)
        self.writer = SummaryWriter(directory)
        self.num_critic_update_iteration = self.num_actor_update_iteration = self.num_training = 0
开发者ID:sweetice,项目名称:Deep-reinforcement-learning-with-pytorch,代码行数:25,代码来源:TD3_BipedalWalker-v2.py

示例5: __init__

# 需要导入模块: import tensorboardX [as 别名]
# 或者: from tensorboardX import SummaryWriter [as 别名]
def __init__(self,
                 controller,
                 log_dir="logs/results",
                 v_action=0,
                 v_observation=0,
                 v_reward=0,
                 windows=None):
        """Wrap ``controller`` and prepare a SummaryWriter for its metrics.

        Args:
            controller: the wrapped controller object.
            log_dir: directory for the TensorBoard event files.
            v_action, v_observation, v_reward: stored as-is. Presumably these are
                per-signal logging switches/levels; confirm against the methods that
                read them.
            windows: window sizes stored on ``self.windows``; defaults to
                ``[10, 100, 200]``. A new list is built per instance, so instances never
                share (and accidentally mutate) a single default list object.
        """
        self.controller = controller

        self.step_cntr = 0
        self.step_global = 0
        self.step_reset = 0

        self.score = 0
        self.score_history = []

        self.v_action = v_action
        self.v_observation = v_observation
        self.v_reward = v_reward
        self.windows = [10, 100, 200] if windows is None else windows

        # flush_secs=30: pending events are flushed to disk at least every 30 seconds.
        self.file_writer = SummaryWriter(log_dir, flush_secs=30)
开发者ID:aidudezzz,项目名称:deepbots,代码行数:24,代码来源:tensorboard_wrapper.py

示例6: __init__

# 需要导入模块: import tensorboardX [as 别名]
# 或者: from tensorboardX import SummaryWriter [as 别名]
def __init__(self, model, trainloader, valloader, args):
        """Store the model, data loaders and options; build the optimizer; load weights.

        Steps run in this order: optimizer creation, then pretrained-weight loading, then
        checkpoint loading from ``args.resume_from``. The order is kept because the
        checkpoint loader may rely on the optimizer already existing (TODO confirm).
        Finally a SummaryWriter is opened in ``args.summary_dir``.
        """
        self.model, self.trainloader, self.valloader, self.args = model, trainloader, valloader, args
        self.start_epoch, self.best_top1 = 0, 0.0

        # Placeholders filled in by create_optimization().
        self.loss = self.optimizer = None
        self.create_optimization()

        # Weight loading: pretrained first, then (optionally) a resumed checkpoint.
        self.load_pretrained_model()
        self.load_checkpoint(self.args.resume_from)

        self.summary_writer = SummaryWriter(log_dir=args.summary_dir)
开发者ID:MG2033,项目名称:MobileNet-V2,代码行数:21,代码来源:train.py

示例7: visualize

# 需要导入模块: import tensorboardX [as 别名]
# 或者: from tensorboardX import SummaryWriter [as 别名]
def visualize(args):
  """Export a trained model's label embeddings to TensorBoard's embedding projector.

  Loads ``<constant.EXP_ROOT>/<args.model_id>_best.pt`` into a CUDA model in eval mode.
  It takes one vector per label in ``constant.ANS2ID_DICT["open"]``: the GCN-transformed
  decoder weights when ``args.gcn`` is set, otherwise the raw ``decoder.linear`` weights.
  The vectors are written with their label names as metadata to
  ``../visualize/<args.model_id>``.
  """
  saved_path = constant.EXP_ROOT
  model = models.Model(args, constant.ANSWER_NUM_DICT[args.goal])
  model.cuda()
  model.eval()
  model.load_state_dict(torch.load(saved_path + '/' + args.model_id + '_best.pt')["state_dict"])

  label2id = constant.ANS2ID_DICT["open"]
  label_list = list(label2id.keys())
  ids = [label2id[label] for label in label_list]

  # Inference only: no autograd graph is needed to compute the projection.
  with torch.no_grad():
    if args.gcn:
      decoder = model.decoder
      connection_matrix = decoder.label_matrix + decoder.weight * decoder.affinity
      label_vectors = decoder.transform(
          connection_matrix.mm(decoder.linear.weight) / connection_matrix.sum(1, keepdim=True))
    else:
      label_vectors = model.decoder.linear.weight.data
    # Build the index on the same device as the vectors (CUDA here, since model.cuda()).
    index = torch.tensor(ids, device=label_vectors.device)
    interested_vectors = torch.index_select(label_vectors, 0, index)

  # Use a distinct name so the `visualize` function is not shadowed. Close the writer
  # so the projector data is flushed to disk.
  writer = SummaryWriter("../visualize/" + args.model_id)
  try:
    writer.add_embedding(interested_vectors, metadata=label_list, label_img=None)
  finally:
    writer.close()
开发者ID:xwhan,项目名称:Extremely-Fine-Grained-Entity-Typing,代码行数:23,代码来源:main.py

示例8: __init__

# 需要导入模块: import tensorboardX [as 别名]
# 或者: from tensorboardX import SummaryWriter [as 别名]
def __init__(self, config, args):
        """Initialise the trainer: logging, checkpoint directory, audio processing, step counters.

        The TensorBoard log directory is ``<args.log_dir>/<stem of args.checkpoint_dir>``.
        The checkpoint directory is created if it does not exist.
        """
        super(Trainer, self).__init__(config, args)
        # Start "best" validation error at a huge value so the first real error beats it.
        self.best_val_err = 1e10

        run_name = Path(args.checkpoint_dir).stem
        self.log_dir = str(Path(args.log_dir) / run_name)
        self.log_writer = SummaryWriter(self.log_dir)

        self.checkpoint_path = args.checkpoint_path
        self.checkpoint_dir = args.checkpoint_dir
        os.makedirs(self.checkpoint_dir, exist_ok=True)

        self.config = config
        self.audio_processor = AudioProcessor(**config['audio'])

        # Training progress.
        self.step = 0
        self.max_step = config['solver']['total_steps']
开发者ID:ttaoREtw,项目名称:Tacotron-pytorch,代码行数:19,代码来源:solver.py

示例9: main

# 需要导入模块: import tensorboardX [as 别名]
# 或者: from tensorboardX import SummaryWriter [as 别名]
def main():
    """Parse CLI arguments, prepare output directories and launch training.

    Sets ``args.name_full`` and ``args.name_full_load`` from the parsed options. Only MPI
    rank 0 gets a SummaryWriter, so TensorBoard holds a single worker's results; other
    ranks pass ``writer=None`` to ``train``.
    """
    args = molecule_arg_parser().parse_args()
    print(args)
    args.name_full = args.env + '_' + args.dataset + '_' + args.name
    args.name_full_load = args.env + '_' + args.dataset_load + '_' + args.name_load + '_' + str(args.load_step)

    # exist_ok=True avoids the exists()-then-makedirs() race: several MPI workers start
    # concurrently, and the loser of the race would otherwise get FileExistsError.
    os.makedirs('molecule_gen', exist_ok=True)
    os.makedirs('ckpt', exist_ok=True)

    # Only keep the first worker's results in TensorBoard.
    writer = None
    if MPI.COMM_WORLD.Get_rank() == 0:
        writer = SummaryWriter(comment='_' + args.dataset + '_' + args.name)
    try:
        train(args, seed=args.seed, writer=writer)
    finally:
        # Flush and close the event file even if training fails.
        if writer is not None:
            writer.close()
开发者ID:bowenliu16,项目名称:rl_graph_generation,代码行数:19,代码来源:run_molecule.py

示例10: __init__

# 需要导入模块: import tensorboardX [as 别名]
# 或者: from tensorboardX import SummaryWriter [as 别名]
def __init__(self, mazes_path, tensorboard_dir, vector_length, mp, n_stack, action_frame_repeat=4,
                 scaled_resolution=(42, 42)):
        """Load one evaluation environment per maze folder and open a SummaryWriter.

        Each sub-directory of ``mazes_path`` is one evaluation maze; its first ``.cfg``
        file (in sorted order) configures the environment. Non-directory entries in
        ``mazes_path`` are ignored.

        Raises:
            FileNotFoundError: if a maze folder has no ``.cfg`` file, or no maze folders
                are found at all.
        """
        self.tensorboard_dir = tensorboard_dir

        self.eval_cfgs = []
        # Sorted for a deterministic evaluation order (os.listdir order is arbitrary).
        for maze_name in sorted(os.listdir(mazes_path)):
            maze_dir = os.path.join(mazes_path, maze_name)
            if not os.path.isdir(maze_dir):
                continue  # stray files next to the maze folders are not mazes
            cfg_files = sorted(f for f in os.listdir(maze_dir) if f.endswith('.cfg'))
            if not cfg_files:
                raise FileNotFoundError("No .cfg file found in {}".format(maze_dir))
            self.eval_cfgs.append((maze_name, os.path.join(maze_dir, cfg_files[0])))

        if not self.eval_cfgs:
            raise FileNotFoundError("No eval cfgs found")

        number_maps = 1  # number of maps inside each eval map path

        self.eval_envs = [(name, load_stable_baselines_env(cfg_path, vector_length, mp, n_stack, number_maps,
                                                           action_frame_repeat, scaled_resolution))
                          for name, cfg_path in self.eval_cfgs]

        self.vector_length = vector_length
        self.mp = mp
        self.n_stack = n_stack

        self.eval_summary_writer = SummaryWriter(tensorboard_dir)
开发者ID:microsoft,项目名称:MazeExplorer,代码行数:25,代码来源:evaluator.py

示例11: init_logdir

# 需要导入模块: import tensorboardX [as 别名]
# 或者: from tensorboardX import SummaryWriter [as 别名]
def init_logdir(self):
        """Lazily open the JSON log writers and the TensorBoard writer in ``self.logdir``.

        The work is done only on the first call and only when ``self.logdir`` is set.
        Once the call returns normally, ``self._has_init`` is True. Later calls are
        therefore no-ops, even if ``logdir`` is assigned afterwards. If tensorboardX
        cannot be imported, an error is logged and no summary writer is created.
        """
        if self._has_init or not self.logdir:
            self._has_init = True
            return

        if self.testing_log:
            testing_path = os.path.join(self.logdir, self.testing_log)
            self._testing_log = StreamingJSONWriter(testing_path)
        if self.training_log:
            training_path = os.path.join(self.logdir, self.training_log)
            self._training_log = StreamingJSONWriter(training_path)

        if self.summary_writer is None:
            try:
                from tensorboardX import SummaryWriter
                self.summary_writer = SummaryWriter(self.logdir)
            except ImportError:
                logger.error(
                    "Could not import tensorboardX. "
                    "SafeLifeLogger will not write data to tensorboard.")

        self._has_init = True
开发者ID:PartnershipOnAI,项目名称:safelife,代码行数:19,代码来源:safelife_logger.py

示例12: build_report_manager

# 需要导入模块: import tensorboardX [as 别名]
# 或者: from tensorboardX import SummaryWriter [as 别名]
def build_report_manager(opt):
    """Create a ``ReportMgr`` that reports every ``opt.report_every`` steps.

    When ``opt.tensorboard`` is set, a SummaryWriter is attached. Fresh runs (no
    ``opt.train_from``) get a timestamped sub-directory of ``opt.tensorboard_log_dir``;
    resumed runs log directly into ``opt.tensorboard_log_dir``.
    """
    writer = None
    if opt.tensorboard:
        from tensorboardX import SummaryWriter

        log_dir = opt.tensorboard_log_dir
        if not opt.train_from:
            log_dir = log_dir + datetime.now().strftime("/%b-%d_%H-%M-%S")
        # NOTE(review): tensorboardX appends `comment` only to an auto-generated logdir,
        # so it likely has no effect when log_dir is given — confirm.
        writer = SummaryWriter(log_dir, comment="Unmt")

    return ReportMgr(opt.report_every, start_time=-1, tensorboard_writer=writer)
开发者ID:lizekang,项目名称:ITDD,代码行数:18,代码来源:report_manager.py

示例13: __init__

# 需要导入模块: import tensorboardX [as 别名]
# 或者: from tensorboardX import SummaryWriter [as 别名]
def __init__(self, log_dir):
        """Create a TensorBoard ``SummaryWriter`` writing to ``log_dir``.

        ``torch.utils.tensorboard`` is preferred; ``tensorboardX`` is the fallback.
        tensorboardX < 1.7 rejects the positional call with a "got multiple values for
        keyword argument 'logdir'" TypeError. In that case the writer is re-created with
        ``log_dir=`` as a keyword and a DeprecationWarning is emitted.

        Raises:
            RuntimeError: if neither backend can be imported.
        """
        try:
            from torch.utils.tensorboard import SummaryWriter
        except ImportError:
            try:
                from tensorboardX import SummaryWriter
            except ImportError as err:
                raise RuntimeError("This contrib module requires tensorboardX to be installed. "
                                   "Please install it with command: \n pip install tensorboardX") from err

        try:
            self.writer = SummaryWriter(log_dir)
        except TypeError as err:
            # Match only the stable part of the message. The prefix ("type object",
            # "SummaryWriter()", ...) differs between Python versions, so an exact
            # equality check would miss the case on newer interpreters.
            if "got multiple values for keyword argument 'logdir'" not in str(err):
                raise
            self.writer = SummaryWriter(log_dir=log_dir)
            warnings.warn('tensorboardX version < 1.7 will not be supported '
                          'after ignite 0.3.0; please upgrade',
                          DeprecationWarning)
开发者ID:budui,项目名称:Human-Pose-Transfer,代码行数:22,代码来源:tensorboard_logger.py

示例14: set_save_name_log_nvdm

# 需要导入模块: import tensorboardX [as 别名]
# 或者: from tensorboardX import SummaryWriter [as 别名]
def set_save_name_log_nvdm(args):
    """Derive the NVDM run directory, open a SummaryWriter and configure logging.

    ``args.save_name`` becomes ``<root_path>/<exp_path>/<run name>``, where the run name
    encodes the main hyper-parameters. The SummaryWriter writes into that directory.
    Log records go to ``<save_name>.log`` via ``logging.basicConfig`` and are also echoed
    to the console by an extra handler on the root logger.

    Returns:
        ``(args, writer)``; ``args`` is the same object with ``save_name`` set.
    """
    run_name = ('Data{}_Dist{}_Model{}_Emb{}_Hid{}_lat{}_lr{}_drop{}_kappa{}_auxw{}_normf{}'
                .format(args.data_name, str(args.dist), args.model, args.emsize, args.nhid,
                        args.lat_dim, args.lr, args.dropout, args.kappa, args.aux_weight,
                        str(args.norm_func)))
    args.save_name = os.path.join(args.root_path, args.exp_path, run_name)
    writer = SummaryWriter(log_dir=args.save_name)

    logging.basicConfig(filename=args.save_name + '.log', level=logging.INFO)

    # Mirror INFO-and-above records to the console in a compact format.
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s'))
    logging.getLogger('').addHandler(console_handler)

    return args, writer
开发者ID:jiacheng-xu,项目名称:vmf_vae_nlp,代码行数:22,代码来源:nvll.py

示例15: set_save_name_log_nvrnn

# 需要导入模块: import tensorboardX [as 别名]
# 或者: from tensorboardX import SummaryWriter [as 别名]
def set_save_name_log_nvrnn(args):
    """Derive the NVRNN run directory, open a SummaryWriter and configure logging.

    ``args.save_name`` becomes ``<exp_path>/<run name>``, where the run name encodes the
    model and training hyper-parameters. The SummaryWriter writes into that directory.
    Log records go to ``<save_name>.log`` via ``logging.basicConfig`` and are also echoed
    to the console by an extra handler on the root logger.

    Returns:
        ``(args, writer)``; ``args`` is the same object with ``save_name`` set.
    """
    name_template = ('Data{}_'
                     'Dist{}_Model{}_Enc{}Bi{}_Emb{}_Hid{}_lat{}_lr{}_drop{}_kappa{}_auxw{}_normf{}'
                     '_nlay{}_mixunk{}_inpz{}_cdbit{}_cdbow{}_ann{}')
    run_name = name_template.format(
        args.data_name, str(args.dist), args.model, args.enc_type, args.bi,
        args.emsize, args.nhid, args.lat_dim, args.lr,
        args.dropout, args.kappa, args.aux_weight, str(args.norm_func), args.nlayers,
        args.mix_unk, args.input_z, args.cd_bit, args.cd_bow, args.anneal)
    args.save_name = os.path.join(args.exp_path, run_name)
    writer = SummaryWriter(log_dir=args.save_name)

    logging.basicConfig(filename=args.save_name + '.log', level=logging.INFO)

    # Mirror INFO-and-above records to the console in a compact format.
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s'))
    logging.getLogger('').addHandler(console_handler)

    return args, writer
开发者ID:jiacheng-xu,项目名称:vmf_vae_nlp,代码行数:24,代码来源:nvll.py


注:本文中的tensorboardX.SummaryWriter方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。