本文整理汇总了Python中tensorflow.core.util.event_pb2.SessionLog.CHECKPOINT属性的典型用法代码示例。如果您正苦于以下问题:Python SessionLog.CHECKPOINT属性的具体用法?Python SessionLog.CHECKPOINT怎么用?Python SessionLog.CHECKPOINT使用的例子?那么恭喜您,这里精选的属性代码示例或许可以为您提供帮助。您也可以进一步了解该属性所在类tensorflow.core.util.event_pb2.SessionLog的用法示例。
在下文中一共展示了SessionLog.CHECKPOINT属性的10个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: testSessionLogSummaries
# 需要导入模块: from tensorflow.core.util.event_pb2 import SessionLog [as 别名]
# 或者: from tensorflow.core.util.event_pb2.SessionLog import CHECKPOINT [as 别名]
def testSessionLogSummaries(self):
    """Check that session-log events are bucketed by status.

    Writes seven session-log events (steps 0..6) covering START, CHECKPOINT
    and STOP statuses, then asserts that exactly one inspection unit is found
    and that its printable dict reports the expected start/stop steps and
    checkpoint count.
    """
    # Index in this list is the step at which the status is written.
    status_by_step = [
        SessionLog.START,
        SessionLog.CHECKPOINT,
        SessionLog.CHECKPOINT,
        SessionLog.CHECKPOINT,
        SessionLog.STOP,
        SessionLog.START,
        SessionLog.STOP,
    ]
    data = [
        {'session_log': SessionLog(status=status), 'step': step}
        for step, status in enumerate(status_by_step)
    ]
    self._WriteScalarSummaries(data)

    units = efi.get_inspection_units(self.logdir)
    self.assertEqual(1, len(units))

    printable = efi.get_dict_to_print(units[0].field_to_obs)
    start_steps = printable['sessionlog:start']['steps']
    stop_steps = printable['sessionlog:stop']['steps']
    checkpoint_count = printable['sessionlog:checkpoint']['num_steps']
    self.assertEqual(start_steps, [0, 5])
    self.assertEqual(stop_steps, [4, 6])
    self.assertEqual(checkpoint_count, 3)
示例2: run_loop
# 需要导入模块: from tensorflow.core.util.event_pb2 import SessionLog [as 别名]
# 或者: from tensorflow.core.util.event_pb2.SessionLog import CHECKPOINT [as 别名]
def run_loop(self):
    """Save a checkpoint and, when possible, record it as a session log.

    The checkpoint is always written through the supervisor's saver. A
    CHECKPOINT `SessionLog` is added to the supervisor's summary writer only
    when a summary writer is set and the supervisor has a global step.
    """
    supervisor = self._sv
    logging.info("Saving checkpoint to path %s", supervisor.save_path)
    supervisor.saver.save(
        self._sess, supervisor.save_path, global_step=supervisor.global_step)

    # Nothing to record without both a summary writer and a global step.
    if not supervisor.summary_writer or supervisor.global_step is None:
        return

    current_step = training_util.global_step(
        self._sess, supervisor.global_step)
    checkpoint_event = SessionLog(
        status=SessionLog.CHECKPOINT, checkpoint_path=supervisor.save_path)
    supervisor.summary_writer.add_session_log(checkpoint_event, current_step)
# TODO(sherrym): All non-PEP8 compliant names will be deprecated shortly.
示例3: _save
# 需要导入模块: from tensorflow.core.util.event_pb2 import SessionLog [as 别名]
# 或者: from tensorflow.core.util.event_pb2.SessionLog import CHECKPOINT [as 别名]
def _save(self, step, session):
    """Write a checkpoint for `step`, notifying listeners around the save.

    Args:
      step: Global step value passed to the saver and recorded in the log.
      session: Session handed to the saver and to each listener.
    """
    logging.info("Saving checkpoints for %d into %s.", step, self._save_path)

    for listener in self._listeners:
        listener.before_save(session, step)

    self._get_saver().save(session, self._save_path, global_step=step)

    # Record the checkpoint in the event file via the summary writer.
    checkpoint_event = SessionLog(
        status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path)
    self._summary_writer.add_session_log(checkpoint_event, step)

    for listener in self._listeners:
        listener.after_save(session, step)
示例4: get_field_to_observations_map
# 需要导入模块: from tensorflow.core.util.event_pb2 import SessionLog [as 别名]
# 或者: from tensorflow.core.util.event_pb2.SessionLog import CHECKPOINT [as 别名]
def get_field_to_observations_map(generator, query_for_tag=''):
    """Build a dict from tracked field name to a list of observations.

    Args:
      generator: An iterable of event protos.
      query_for_tag: If non-empty, graph and session-log events are ignored
        and only summary values whose tag equals this string are recorded.

    Returns:
      A dict with one key per entry of `TRACKED_FIELDS`; each value is a list
      of `Observation` records converted to dicts via `_asdict()`.
    """
    field_to_obs = {field: [] for field in TRACKED_FIELDS}

    # Session-log statuses that are tracked, and the field each one feeds.
    status_to_field = {
        SessionLog.START: 'sessionlog:start',
        SessionLog.STOP: 'sessionlog:stop',
        SessionLog.CHECKPOINT: 'sessionlog:checkpoint',
    }

    def increment(stat, event, tag=''):
        assert stat in TRACKED_FIELDS
        observation = Observation(
            step=event.step, wall_time=event.wall_time, tag=tag)
        field_to_obs[stat].append(observation._asdict())

    for event in generator:
        untagged_query = not query_for_tag

        if event.HasField('graph_def') and untagged_query:
            increment('graph', event)

        if event.HasField('session_log') and untagged_query:
            field = status_to_field.get(event.session_log.status)
            if field is not None:
                increment(field, event)
            continue

        if not event.HasField('summary'):
            continue

        for value in event.summary.value:
            if query_for_tag and value.tag != query_for_tag:
                continue
            for proto_name, display_name in SUMMARY_TYPE_TO_FIELD.items():
                if value.HasField(proto_name):
                    increment(display_name, event, value.tag)

    return field_to_obs
示例5: _save
# 需要导入模块: from tensorflow.core.util.event_pb2 import SessionLog [as 别名]
# 或者: from tensorflow.core.util.event_pb2.SessionLog import CHECKPOINT [as 别名]
def _save(self, step, session):
    """Write a checkpoint for `step` unless that step was already saved.

    Args:
      step: Global step value passed to the saver and recorded in the log.
      session: Session handed to the saver.
    """
    already_saved = step == self._last_saved_step
    if already_saved:
        return

    logging.info("Saving checkpoints for %d into %s.", step, self._save_path)
    self._last_saved_time = time.time()
    self._last_saved_step = step

    # Use the scaffold's saver only when no explicit saver is set.
    saver = self._scaffold.saver if self._saver is None else self._saver
    saver.save(session, self._save_path, global_step=step)

    checkpoint_event = SessionLog(
        status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path)
    self._summary_writer.add_session_log(checkpoint_event, step)
示例6: run_loop
# 需要导入模块: from tensorflow.core.util.event_pb2 import SessionLog [as 别名]
# 或者: from tensorflow.core.util.event_pb2.SessionLog import CHECKPOINT [as 别名]
def run_loop(self):
    """Save a checkpoint and, when possible, record it as a session log.

    The supervisor's saver always writes the checkpoint. The CHECKPOINT
    `SessionLog` is emitted only if the supervisor has both a summary writer
    and a global step.
    """
    sv = self._sv
    sv.saver.save(self._sess, sv.save_path, global_step=sv.global_step)

    can_log = sv.summary_writer and sv.global_step is not None
    if not can_log:
        return

    current_step = training_util.global_step(self._sess, sv.global_step)
    checkpoint_event = SessionLog(
        status=SessionLog.CHECKPOINT, checkpoint_path=sv.save_path)
    sv.summary_writer.add_session_log(checkpoint_event, current_step)
# TODO(sherrym): All non-PEP8 compliant names will be deprecated shortly.
示例7: _save
# 需要导入模块: from tensorflow.core.util.event_pb2 import SessionLog [as 别名]
# 或者: from tensorflow.core.util.event_pb2.SessionLog import CHECKPOINT [as 别名]
def _save(self, session, step, asynchronous=True):
    """Save the latest checkpoint, optionally on a background thread.

    Args:
      session: Session handed to the saver and to each listener.
      step: Global step value passed to the saver and recorded in the log.
      asynchronous: When true (the default) the save runs on a new thread.
        If a previously started save thread is still alive after a 0.1s
        join, this checkpoint is skipped. When false the save runs inline.

    Returns:
      None on every path.
    """

    def _save_fn():
        """Run the saver process and notify listeners around it."""
        logging.info("Saving checkpoints for %d into %s.", step,
                     self._save_path)

        start_time = time.time()
        for listener in self._listeners:
            listener.before_save(session, step)

        self._get_saver().save(session, self._save_path, global_step=step)
        self._summary_writer.add_session_log(
            SessionLog(
                status=SessionLog.CHECKPOINT,
                checkpoint_path=self._save_path),
            step)
        end_time = time.time()
        logging.info("Checkpoint actual writing time: (%.3f sec)",
                     end_time - start_time)
        logging.info("Checkpoint finished for %d into %s.", step,
                     self._save_path)

        # The checkpoint is on disk at this point, so listeners get the
        # post-save callback (previously `before_save` was invoked a second
        # time here and `after_save` never fired).
        for listener in self._listeners:
            listener.after_save(session, step)

    if not asynchronous:
        _save_fn()
        return

    if self._save_thread is not None:
        # Give the previous save a brief chance to finish before deciding.
        self._save_thread.join(timeout=0.1)
        if self._save_thread.is_alive():
            logging.info("Saver thread still in progress, skipping checkpoint.")
            return

    self._save_thread = threading.Thread(target=_save_fn)
    self._save_thread.start()
示例8: _save
# 需要导入模块: from tensorflow.core.util.event_pb2 import SessionLog [as 别名]
# 或者: from tensorflow.core.util.event_pb2.SessionLog import CHECKPOINT [as 别名]
def _save(self, step, session):
    """Write a checkpoint for `step` and record it as a session log.

    The explicit saver takes precedence; the scaffold's saver is used only
    when no explicit saver is set. If neither is available no checkpoint is
    written, but the CHECKPOINT session log is still emitted.

    Args:
      step: Global step value passed to the saver and recorded in the log.
      session: Session handed to the saver.
    """
    logging.info("Saving checkpoints for %d into %s.", step, self._save_path)

    if self._saver is None:
        if self._scaffold is not None:
            self._scaffold.saver.save(
                session, self._save_path, global_step=step)
    else:
        self._saver.save(session, self._save_path, global_step=step)

    checkpoint_event = SessionLog(
        status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path)
    self._summary_writer.add_session_log(checkpoint_event, step)
示例9: _save
# 需要导入模块: from tensorflow.core.util.event_pb2 import SessionLog [as 别名]
# 或者: from tensorflow.core.util.event_pb2.SessionLog import CHECKPOINT [as 别名]
def _save(self, step, session):
    """Save the latest checkpoint and log it if a summary writer exists.

    Args:
      step: A python integer, running step.
      session: A TensorFlow Session.
    """
    self._saver.save(session, self._save_path, global_step=step)

    message = "Saving checkpoints for {} into {}".format(step, self._save_path)
    tf.logging.info(message)

    # The session log is optional: skip it when no summary writer is set.
    if self._summary_writer is None:
        return

    checkpoint_event = SessionLog(
        status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path)
    self._summary_writer.add_session_log(checkpoint_event, step)
示例10: _save
# 需要导入模块: from tensorflow.core.util.event_pb2 import SessionLog [as 别名]
# 或者: from tensorflow.core.util.event_pb2.SessionLog import CHECKPOINT [as 别名]
def _save(self, session, step):
    """Write a checkpoint for `step`, notifying listeners around the save.

    Args:
      session: Session handed to the saver and to each listener.
      step: Global step value passed to the saver and recorded in the log.
    """
    logging.info("Saving checkpoints for %d into %s.", step, self._save_path)

    for hook in self._listeners:
        hook.before_save(session, step)

    saver = self._get_saver()
    saver.save(session, self._save_path, global_step=step)

    # Record the checkpoint in the event file via the summary writer.
    checkpoint_event = SessionLog(
        status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path)
    self._summary_writer.add_session_log(checkpoint_event, step)

    for hook in self._listeners:
        hook.after_save(session, step)
开发者ID:PacktPublishing,项目名称:Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda,代码行数:17,代码来源:basic_session_run_hooks.py