本文整理汇总了Python中tensorboardX.SummaryWriter类的典型用法代码示例。如果您正苦于以下问题：Python tensorboardX.SummaryWriter类的具体用法？Python tensorboardX.SummaryWriter怎么用？Python tensorboardX.SummaryWriter使用的例子？那么恭喜您，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块tensorboardX的用法示例。
在下文中一共展示了tensorboardX.SummaryWriter类的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: __init__
# 需要导入模块: import tensorboardX [as 别名]
# 或者: from tensorboardX import SummaryWriter [as 别名]
def __init__(self, fsm: FileStructManager, is_continue: bool, network_name: str = None):
    """Open a TensorBoard writer and a ``log.txt`` file in the directory ``fsm`` assigns to this object.

    Args:
        fsm: registers this object's directory and resolves its path.
        is_continue: resume an earlier run. The existing directory is reused and
            ``log.txt`` is appended to. ``fsm.in_continue_mode()`` has the same effect.
        network_name: optional sub-directory appended to the resolved path.

    If ``fsm`` resolves no path, both writers stay ``None``. When not resuming and
    the target directory already exists, the first free ``<dir>_v<N>`` is used instead.
    """
    super().__init__()
    self.__writer = None
    self.__txt_log_file = None
    fsm.register_dir(self)
    log_dir = fsm.get_path(self)
    if log_dir is None:
        return
    if network_name is not None:
        log_dir = os.path.join(log_dir, network_name)

    # One flag drives both decisions (reuse the dir AND append to its log).
    # Otherwise continue mode would reuse the directory but truncate log.txt.
    continue_mode = fsm.in_continue_mode() or is_continue
    if not continue_mode and os.path.isdir(log_dir):
        idx = 0
        candidate = "{}_v{}".format(log_dir, idx)
        # Skip any existing path (dir or file): makedirs below cannot reuse a file.
        while os.path.exists(candidate):
            idx += 1
            candidate = "{}_v{}".format(log_dir, idx)
        log_dir = candidate
    os.makedirs(log_dir, exist_ok=True)
    self.__writer = SummaryWriter(log_dir)
    self.__txt_log_file = open(os.path.join(log_dir, "log.txt"), 'a' if continue_mode else 'w')
示例2: run
# 需要导入模块: import tensorboardX [as 别名]
# 或者: from tensorboardX import SummaryWriter [as 别名]
def run(self):
    """Worker loop: open the writer, try to log the model graph, then serve summary requests.

    Each item taken from ``self.queue`` is a ``(name, kwargs)`` pair. It is dispatched
    to ``self.summary_<name>(**kwargs)``. A ``None`` name ends the loop, and the
    writer is then closed so pending events are flushed.
    """
    self.writer = SummaryWriter(os.path.join(self.env.model_dir, self.env.args.run))
    try:
        # Config value 'image'/'size' holds "height width" separated by whitespace.
        height, width = tuple(map(int, self.config.get('image', 'size').split()))
        tensor = torch.randn(1, 3, height, width)
        _, _, dnn = self.env.load()
        self.writer.add_graph(dnn, (torch.autograd.Variable(tensor),))
    except Exception:
        # Graph export is best-effort; report and keep serving summaries.
        traceback.print_exc()
    try:
        while True:
            name, kwargs = self.queue.get()
            if name is None:
                break
            try:
                # Lookup is inside the try so an unknown name does not kill the loop.
                func = getattr(self, 'summary_' + name)
                func(**kwargs)
            except Exception:
                traceback.print_exc()
    finally:
        self.writer.close()
示例3: __init__
# 需要导入模块: import tensorboardX [as 别名]
# 或者: from tensorboardX import SummaryWriter [as 别名]
def __init__(self, state_dim, action_dim, max_action):
    """Build actor/critic networks with target copies, their optimizers, a replay buffer and a writer.

    Each target network starts with the same weights as its online network.
    The actor uses Adam with lr=1e-4 and the critic Adam with lr=1e-3.
    """
    online_actor = Actor(state_dim, action_dim, max_action).to(device)
    target_actor = Actor(state_dim, action_dim, max_action).to(device)
    target_actor.load_state_dict(online_actor.state_dict())
    self.actor, self.actor_target = online_actor, target_actor
    self.actor_optimizer = optim.Adam(online_actor.parameters(), lr=1e-4)

    online_critic = Critic(state_dim, action_dim).to(device)
    target_critic = Critic(state_dim, action_dim).to(device)
    target_critic.load_state_dict(online_critic.state_dict())
    self.critic, self.critic_target = online_critic, target_critic
    self.critic_optimizer = optim.Adam(online_critic.parameters(), lr=1e-3)

    self.replay_buffer = Replay_buffer()
    self.writer = SummaryWriter(directory)
    # Training-progress counters.
    self.num_critic_update_iteration = self.num_actor_update_iteration = self.num_training = 0
示例4: __init__
# 需要导入模块: import tensorboardX [as 别名]
# 或者: from tensorboardX import SummaryWriter [as 别名]
def __init__(self, state_dim, action_dim, max_action):
    """Build the actor and twin critics, each with a target copy, plus optimizers, memory and writer.

    Target networks start with the same weights as their online networks.
    The replay memory is sized by ``args.capacity``.
    """
    def online_and_target(build):
        # Two independently constructed networks on `device`; the target copies the online weights.
        online = build().to(device)
        target = build().to(device)
        target.load_state_dict(online.state_dict())
        return online, target

    self.actor, self.actor_target = online_and_target(lambda: Actor(state_dim, action_dim, max_action))
    self.critic_1, self.critic_1_target = online_and_target(lambda: Critic(state_dim, action_dim))
    self.critic_2, self.critic_2_target = online_and_target(lambda: Critic(state_dim, action_dim))

    self.actor_optimizer = optim.Adam(self.actor.parameters())
    self.critic_1_optimizer = optim.Adam(self.critic_1.parameters())
    self.critic_2_optimizer = optim.Adam(self.critic_2.parameters())

    self.max_action = max_action
    self.memory = Replay_buffer(args.capacity)
    self.writer = SummaryWriter(directory)
    # Training-progress counters.
    self.num_critic_update_iteration = self.num_actor_update_iteration = self.num_training = 0
示例5: __init__
# 需要导入模块: import tensorboardX [as 别名]
# 或者: from tensorboardX import SummaryWriter [as 别名]
def __init__(self,
             controller,
             log_dir="logs/results",
             v_action=0,
             v_observation=0,
             v_reward=0,
             windows=None):
    """Wrap ``controller`` and open a TensorBoard writer at ``log_dir`` (flushed every 30 s).

    Args:
        controller: the wrapped object, stored as ``self.controller``.
        log_dir: TensorBoard log directory.
        v_action, v_observation, v_reward: stored as-is. These are presumably
            verbosity flags; TODO confirm against their users.
        windows: stored as ``self.windows``; defaults to ``[10, 100, 200]``.
            These are presumably window sizes over ``score_history``; confirm.
    """
    if windows is None:
        # Fresh list per instance: a list literal as default would be shared
        # (and mutable) across every instance created without ``windows``.
        windows = [10, 100, 200]
    self.controller = controller
    self.step_cntr = 0
    self.step_global = 0
    self.step_reset = 0
    self.score = 0
    self.score_history = []
    self.v_action = v_action
    self.v_observation = v_observation
    self.v_reward = v_reward
    self.windows = windows
    self.file_writer = SummaryWriter(log_dir, flush_secs=30)
示例6: __init__
# 需要导入模块: import tensorboardX [as 别名]
# 或者: from tensorboardX import SummaryWriter [as 别名]
def __init__(self, model, trainloader, valloader, args):
    """Store training components, set up loss/optimizer, load weights and open a TensorBoard writer.

    The steps run in this order: ``create_optimization()``, ``load_pretrained_model()``,
    then ``load_checkpoint(self.args.resume_from)``. The writer logs to ``args.summary_dir``.
    """
    self.model, self.trainloader, self.valloader, self.args = model, trainloader, valloader, args
    self.start_epoch, self.best_top1 = 0, 0.0

    # Placeholders; create_optimization() is expected to fill these in.
    self.loss = self.optimizer = None
    self.create_optimization()

    self.load_pretrained_model()
    self.load_checkpoint(self.args.resume_from)

    self.summary_writer = SummaryWriter(log_dir=args.summary_dir)
示例7: visualize
# 需要导入模块: import tensorboardX [as 别名]
# 或者: from tensorboardX import SummaryWriter [as 别名]
def visualize(args):
    """Export the label embeddings of a trained model to TensorBoard's embedding projector.

    Loads ``<EXP_ROOT>/<model_id>_best.pt`` (its ``"state_dict"`` entry) into a CUDA model.
    It then takes one vector per label in ``constant.ANS2ID_DICT["open"]`` and writes
    them, with label names as metadata, under ``../visualize/<model_id>``.
    With ``args.gcn`` the vectors are the decoder's graph-transformed label weights;
    otherwise they are the raw ``decoder.linear`` weights.
    """
    saved_path = constant.EXP_ROOT
    model = models.Model(args, constant.ANSWER_NUM_DICT[args.goal])
    model.cuda()
    model.eval()
    model.load_state_dict(torch.load(saved_path + '/' + args.model_id + '_best.pt')["state_dict"])

    label2id = constant.ANS2ID_DICT["open"]
    label_list = list(label2id.keys())
    ids = [label2id[label] for label in label_list]

    # Inference only: no autograd graph is needed for exporting vectors.
    with torch.no_grad():
        decoder = model.decoder
        if args.gcn:
            connection_matrix = decoder.label_matrix + decoder.weight * decoder.affinity
            label_vectors = decoder.transform(
                connection_matrix.mm(decoder.linear.weight) / connection_matrix.sum(1, keepdim=True))
        else:
            label_vectors = decoder.linear.weight.data
        # Build the index on the same device as the vectors instead of hard-coding CUDA.
        index = torch.tensor(ids, device=label_vectors.device)
        interested_vectors = torch.index_select(label_vectors, 0, index)

    writer = SummaryWriter("../visualize/" + args.model_id)
    try:
        writer.add_embedding(interested_vectors, metadata=label_list, label_img=None)
    finally:
        # Flush and release the event file.
        writer.close()
示例8: __init__
# 需要导入模块: import tensorboardX [as 别名]
# 或者: from tensorboardX import SummaryWriter [as 别名]
def __init__(self, config, args):
    """Initialize the trainer: TensorBoard logging, checkpoint paths, audio processing and step limits.

    TensorBoard logs go to ``<args.log_dir>/<stem of args.checkpoint_dir>``.
    ``args.checkpoint_dir`` is created if missing.
    """
    super(Trainer, self).__init__(config, args)
    # Any real validation error will be below this sentinel.
    self.best_val_err = 1e10

    run_name = Path(args.checkpoint_dir).stem
    self.log_dir = str(Path(args.log_dir) / run_name)
    self.log_writer = SummaryWriter(self.log_dir)

    self.checkpoint_path, self.checkpoint_dir = args.checkpoint_path, args.checkpoint_dir
    os.makedirs(self.checkpoint_dir, exist_ok=True)

    self.config = config
    self.audio_processor = AudioProcessor(**config['audio'])

    self.step, self.max_step = 0, config['solver']['total_steps']
示例9: main
# 需要导入模块: import tensorboardX [as 别名]
# 或者: from tensorboardX import SummaryWriter [as 别名]
def main():
    """Parse CLI args, prepare output directories and run training.

    Only MPI rank 0 gets a TensorBoard writer; the other ranks pass ``writer=None``.
    The writer is closed once ``train`` returns or raises.
    """
    args = molecule_arg_parser().parse_args()
    print(args)
    args.name_full = args.env + '_' + args.dataset + '_' + args.name
    args.name_full_load = args.env + '_' + args.dataset_load + '_' + args.name_load + '_' + str(args.load_step)
    # exist_ok avoids the check-then-create race between concurrently starting MPI workers.
    os.makedirs('molecule_gen', exist_ok=True)
    os.makedirs('ckpt', exist_ok=True)
    # Only keep the first worker's results in TensorBoard.
    if MPI.COMM_WORLD.Get_rank() == 0:
        writer = SummaryWriter(comment='_' + args.dataset + '_' + args.name)
    else:
        writer = None
    try:
        train(args, seed=args.seed, writer=writer)
    finally:
        if writer is not None:
            writer.close()
示例10: __init__
# 需要导入模块: import tensorboardX [as 别名]
# 或者: from tensorboardX import SummaryWriter [as 别名]
def __init__(self, mazes_path, tensorboard_dir, vector_length, mp, n_stack, action_frame_repeat=4,
             scaled_resolution=(42, 42)):
    """Load one evaluation environment per maze folder and open a TensorBoard writer.

    Every sub-directory of ``mazes_path`` is one evaluation maze. The alphabetically
    first ``.cfg`` file inside it is passed to ``load_stable_baselines_env``.

    Raises:
        FileNotFoundError: if ``mazes_path`` has no sub-directories, or a maze
            folder contains no ``.cfg`` file.
    """
    self.tensorboard_dir = tensorboard_dir

    def first_cfg(folder):
        # Path of the alphabetically first ``.cfg`` file in ``folder``.
        cfgs = [name for name in sorted(os.listdir(folder)) if name.endswith('.cfg')]
        if not cfgs:
            raise FileNotFoundError("No .cfg file found in {}".format(folder))
        return os.path.join(folder, cfgs[0])

    # Only directories are mazes (stray files would crash os.listdir above).
    # Sorting makes the evaluation order deterministic.
    mazes_folders = [(name, os.path.join(mazes_path, name))
                     for name in sorted(os.listdir(mazes_path))
                     if os.path.isdir(os.path.join(mazes_path, name))]
    self.eval_cfgs = [(name, first_cfg(folder)) for name, folder in mazes_folders]
    if not self.eval_cfgs:
        raise FileNotFoundError("No eval cfgs found")
    number_maps = 1  # number of maps inside each eval map path
    self.eval_envs = [(name, load_stable_baselines_env(cfg_path, vector_length, mp, n_stack, number_maps,
                                                       action_frame_repeat, scaled_resolution))
                      for name, cfg_path in self.eval_cfgs]
    self.vector_length = vector_length
    self.mp = mp
    self.n_stack = n_stack
    self.eval_summary_writer = SummaryWriter(tensorboard_dir)
示例11: init_logdir
# 需要导入模块: import tensorboardX [as 别名]
# 或者: from tensorboardX import SummaryWriter [as 别名]
def init_logdir(self):
    """Open the log writers under ``self.logdir``, once.

    Does nothing when already initialized or when ``self.logdir`` is empty.
    Otherwise it opens JSON writers for whichever of ``testing_log`` / ``training_log``
    are set. If no summary writer exists yet, it also tries to create a tensorboardX
    writer, logging an error instead if tensorboardX cannot be imported.
    """
    if self._has_init or not self.logdir:
        return
    if self.testing_log:
        self._testing_log = StreamingJSONWriter(os.path.join(self.logdir, self.testing_log))
    if self.training_log:
        self._training_log = StreamingJSONWriter(os.path.join(self.logdir, self.training_log))
    if self.summary_writer is None:
        # tensorboardX is optional: import lazily and degrade to no TensorBoard output.
        try:
            from tensorboardX import SummaryWriter
            self.summary_writer = SummaryWriter(self.logdir)
        except ImportError:
            logger.error(
                "Could not import tensorboardX. "
                "SafeLifeLogger will not write data to tensorboard.")
    self._has_init = True
示例12: build_report_manager
# 需要导入模块: import tensorboardX [as 别名]
# 或者: from tensorboardX import SummaryWriter [as 别名]
def build_report_manager(opt):
    """Create a ``ReportMgr``, attaching a tensorboardX writer when ``opt.tensorboard`` is set.

    For a fresh run (``opt.train_from`` empty), the TensorBoard directory gets a
    timestamp suffix such as ``/Jan-01_12-00-00``.
    """
    writer = None
    if opt.tensorboard:
        # Imported lazily so tensorboardX is only required when TensorBoard output is enabled.
        from tensorboardX import SummaryWriter
        log_dir = opt.tensorboard_log_dir
        if not opt.train_from:
            log_dir = log_dir + datetime.now().strftime("/%b-%d_%H-%M-%S")
        writer = SummaryWriter(log_dir, comment="Unmt")
    return ReportMgr(opt.report_every, start_time=-1, tensorboard_writer=writer)
示例13: __init__
# 需要导入模块: import tensorboardX [as 别名]
# 或者: from tensorboardX import SummaryWriter [as 别名]
def __init__(self, log_dir):
    """Create a TensorBoard ``SummaryWriter`` for ``log_dir`` as ``self.writer``.

    Prefers ``torch.utils.tensorboard`` and falls back to ``tensorboardX``.
    Old tensorboardX versions (< 1.7) are detected by a specific ``TypeError``
    when ``log_dir`` is passed positionally. They are retried with the keyword
    and a ``DeprecationWarning`` is emitted.

    Raises:
        RuntimeError: if neither backend can be imported.
    """
    try:
        from torch.utils.tensorboard import SummaryWriter
    except ImportError:
        try:
            from tensorboardX import SummaryWriter
        except ImportError:
            # Either backend is acceptable, so the message must say so.
            raise RuntimeError("This contrib module requires either tensorboardX or torch >= 1.2.0 "
                               "with the tensorboard package. "
                               "You may install tensorboardX with command: \n pip install tensorboardX")
    try:
        self.writer = SummaryWriter(log_dir)
    except TypeError as err:
        if "type object got multiple values for keyword argument 'logdir'" == str(err):
            self.writer = SummaryWriter(log_dir=log_dir)
            warnings.warn('tensorboardX version < 1.7 will not be supported '
                          'after ignite 0.3.0; please upgrade',
                          DeprecationWarning)
        else:
            raise
示例14: set_save_name_log_nvdm
# 需要导入模块: import tensorboardX [as 别名]
# 或者: from tensorboardX import SummaryWriter [as 别名]
def set_save_name_log_nvdm(args):
    """Derive the NVDM run directory, open a TensorBoard writer there and configure logging.

    ``args.save_name`` becomes ``<root_path>/<exp_path>/<run name>``, where the run
    name encodes the hyper-parameters. ``logging.basicConfig`` targets
    ``<save_name>.log`` at INFO level, and an INFO console handler is added to the
    root logger.

    Returns:
        ``(args, writer)``: the updated ``args`` and a ``SummaryWriter`` for ``args.save_name``.
    """
    root_path, exp_path = args.root_path, args.exp_path
    run_name = 'Data{}_Dist{}_Model{}_Emb{}_Hid{}_lat{}_lr{}_drop{}_kappa{}_auxw{}_normf{}'.format(
        args.data_name, str(args.dist), args.model, args.emsize, args.nhid, args.lat_dim,
        args.lr, args.dropout, args.kappa, args.aux_weight, str(args.norm_func))
    args.save_name = os.path.join(root_path, exp_path, run_name)
    writer = SummaryWriter(log_dir=args.save_name)

    logging.basicConfig(filename=args.save_name + '.log', level=logging.INFO)
    # Mirror INFO records to the console with a compact format.
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s'))
    logging.getLogger('').addHandler(console_handler)
    return args, writer
示例15: set_save_name_log_nvrnn
# 需要导入模块: import tensorboardX [as 别名]
# 或者: from tensorboardX import SummaryWriter [as 别名]
def set_save_name_log_nvrnn(args):
    """Derive the NVRNN run directory, open a TensorBoard writer there and configure logging.

    ``args.save_name`` becomes ``<exp_path>/<run name>``, where the run name encodes
    the hyper-parameters. ``logging.basicConfig`` targets ``<save_name>.log`` at INFO
    level, and an INFO console handler is added to the root logger.

    Returns:
        ``(args, writer)``: the updated ``args`` and a ``SummaryWriter`` for ``args.save_name``.
    """
    exp_path = args.exp_path
    template = ('Data{}_'
                'Dist{}_Model{}_Enc{}Bi{}_Emb{}_Hid{}_lat{}_lr{}_drop{}_kappa{}_auxw{}_normf{}'
                '_nlay{}_mixunk{}_inpz{}_cdbit{}_cdbow{}_ann{}')
    run_name = template.format(
        args.data_name, str(args.dist), args.model, args.enc_type, args.bi,
        args.emsize, args.nhid, args.lat_dim, args.lr, args.dropout, args.kappa,
        args.aux_weight, str(args.norm_func), args.nlayers, args.mix_unk, args.input_z,
        args.cd_bit, args.cd_bow, args.anneal)
    args.save_name = os.path.join(exp_path, run_name)
    writer = SummaryWriter(log_dir=args.save_name)

    logging.basicConfig(filename=args.save_name + '.log', level=logging.INFO)
    # Mirror INFO records to the console with a compact format.
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s'))
    logging.getLogger('').addHandler(console_handler)
    return args, writer