This article collects typical usage examples of Python's gin.configurable. If you are wondering what gin.configurable does or how to use it in practice, the curated code examples below may help; they are also a useful starting point for exploring the rest of the gin module.
The sections below present 15 code examples of gin.configurable, ordered by popularity by default.
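Before the examples, here is a minimal self-contained sketch (not taken from any of the projects below; build_optimizer is a made-up name) of the pattern they all rely on: register a function with gin.configurable, then override its parameters from a config string or .gin file. Note that the snippets below generally omit the decorator line itself and show only function and method bodies.

import gin

@gin.configurable
def build_optimizer(learning_rate=0.001, momentum=0.9):
  """Toy configurable: gin can override the defaults of its arguments."""
  return {'learning_rate': learning_rate, 'momentum': momentum}

# Bindings usually live in a .gin file; parse_config also accepts a string.
gin.parse_config("""
build_optimizer.learning_rate = 0.01
build_optimizer.momentum = 0.95
""")

print(build_optimizer())  # {'learning_rate': 0.01, 'momentum': 0.95}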
Example 1: get_random_number_generator_and_set_seed
# Required import: import gin
# Or: from gin import configurable
def get_random_number_generator_and_set_seed(seed=None):
  """Get a JAX random number generator and set random seed everywhere."""
  random.seed(seed)
  # Python's random accepts None as a seed (it then seeds from time/OS
  # entropy), but some other functions expect an integer, so create one here.
  if seed is None:
    seed = random.randint(0, 2**31 - 1)
  tf.set_random_seed(seed)
  numpy.random.seed(seed)
  return jax_random.get_prng(seed)

# TODO(trax):
# * Make configurable:
#   * loss
#   * metrics
# * Training loop callbacks/hooks/...
# * Save/restore: pickle unsafe. Use np.array.savez + MessagePack?
# * Move metrics to metrics.py
# * Setup namedtuples for interfaces (e.g. lr fun constructors can take a
#   LearningRateInit, metric funs, etc.).
# * Allow disabling eval
Example 2: _step_impl
# Required import: import gin
# Or: from gin import configurable
def _step_impl(self, state, action):
  """Run one timestep of the environment's dynamics.

  At each timestep, x is flipped from zero to one or one to zero.

  Args:
    state: A `State` object containing the current state.
    action: An action in `action_space`.

  Returns:
    A `State` object containing the updated state.
  """
  del action  # Unused.
  state.x = [1 - x for x in state.x]
  return state

# TODO(): There isn't actually anything to configure in DummyMetric,
# but we mark it as configurable so that we can refer to it on the
# right-hand side of expressions in gin configurations. Find out whether
# there's a better way of indicating that than gin.configurable.
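The TODO above refers to using a configurable purely so it can appear on the right-hand side of a gin binding. A minimal self-contained sketch of that pattern, with a stand-in DummyMetric and a made-up consumer function (run_simulation is not part of the original project):

import gin

@gin.configurable
class DummyMetric(object):
  """Stand-in for the DummyMetric above; there is nothing to configure."""

@gin.configurable
def run_simulation(metrics=()):
  return list(metrics)

# Registering DummyMetric lets it appear on the right-hand side of a binding;
# `@DummyMetric()` constructs an instance when the binding is evaluated.
gin.parse_config('run_simulation.metrics = [@DummyMetric()]')

print(run_simulation())  # [<DummyMetric object at ...>]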
Example 3: initialize
# Required import: import gin
# Or: from gin import configurable
def initialize(self):
  """Initialize the teacher model from the checkpoint.

  This function will be called after the graph has been constructed.
  """
  if self.fraction_soft == 0.0:
    # Do nothing if we do not need the teacher.
    return
  vars_to_restore = tf.get_collection(
      tf.GraphKeys.GLOBAL_VARIABLES, scope="teacher")
  tf.train.init_from_checkpoint(
      self.teacher_checkpoint,
      {v.name[len("teacher/"):].split(":")[0]: v for v in vars_to_restore})

# gin-configurable constructors
Example 4: _output_dir_or_default
# Required import: import gin
# Or: from gin import configurable
def _output_dir_or_default():
  """Returns a path to the output directory."""
  if FLAGS.output_dir:
    output_dir = FLAGS.output_dir
    trainer_lib.log('Using --output_dir {}'.format(output_dir))
    return os.path.expanduser(output_dir)
  # Else, generate a default output dir (under the user's home directory).
  try:
    dataset_name = gin.query_parameter('data_streams.dataset_name')
  except ValueError:
    dataset_name = 'random'
  output_name = '{model_name}_{dataset_name}_{timestamp}'.format(
      model_name=gin.query_parameter('train.model').configurable.name,
      dataset_name=dataset_name,
      timestamp=datetime.datetime.now().strftime('%Y%m%d_%H%M'),
  )
  output_dir = os.path.join('~', 'trax', output_name)
  output_dir = os.path.expanduser(output_dir)
  print()
  trainer_lib.log('No --output_dir specified')
  trainer_lib.log('Using default output_dir: {}'.format(output_dir))
  return output_dir

# TODO(afrozm): Share between trainer.py and rl_trainer.py
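gin.query_parameter, used twice above to derive a default directory name, returns whatever value the current config binds to a parameter and raises ValueError when there is no binding, which is exactly what the try/except exploits. A self-contained sketch of that behaviour with a stand-in configurable (this data_streams is a dummy, not trax's):

import gin

@gin.configurable
def data_streams(dataset_name='random'):
  return dataset_name

try:
  gin.query_parameter('data_streams.dataset_name')
except ValueError:
  print('no binding yet')  # only explicit bindings are visible to queries

gin.parse_config("data_streams.dataset_name = 'mnist'")
print(gin.query_parameter('data_streams.dataset_name'))  # 'mnist'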
Example 5: create_train_op
# Required import: import gin
# Or: from gin import configurable
def create_train_op(self,
                    loss,
                    optimizer,
                    update_ops=None,
                    train_outputs=None):
  """Create meta-training op.

  MAMLModel has a configurable var_scope used to select which variables to
  train on. Note that MAMLInnerLoopGradientDescent also has such a parameter
  to decide which variables to update in the *inner* loop. If you don't want
  to update a set of variables in both the inner and outer loop, you'll need
  to configure var_scope for both MAMLModel *and*
  MAMLInnerLoopGradientDescent.

  Args:
    loss: The loss we compute within model_train_fn.
    optimizer: An instance of `tf.train.Optimizer`.
    update_ops: List of update ops to execute alongside the training op.
    train_outputs: (Optional) A dict with additional tensors the training
      model generates.

  Returns:
    train_op: Op for the training step.
  """
  vars_to_train = tf.trainable_variables()
  if self._var_scope is not None:
    vars_to_train = [
        v for v in vars_to_train if v.op.name.startswith(self._var_scope)]
  summarize_gradients = self._summarize_gradients
  if self.is_device_tpu:
    # TPUs don't currently support summaries, so we override the
    # user-provided summarize_gradients option and force it to False.
    if self._summarize_gradients:
      logging.info('We cannot use summarize_gradients on TPUs.')
    summarize_gradients = False
  return contrib_training.create_train_op(
      loss,
      optimizer,
      variables_to_train=vars_to_train,
      summarize_gradients=summarize_gradients,
      update_ops=update_ops)
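To make the var_scope note in the docstring concrete: a config that restricts training to one variable scope in both the inner and the outer loop has to bind the parameter on both classes. The scope name below is a placeholder, and the exact binding paths depend on how the classes are registered with gin, so read this as a sketch of the config's shape rather than a verified snippet:

# Restrict both the outer (MAMLModel) and inner (MAMLInnerLoopGradientDescent)
# loop updates to variables under the same scope.
MAMLModel.var_scope = 'adaptable_weights'
MAMLInnerLoopGradientDescent.var_scope = 'adaptable_weights'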
Example 6: loss_fn
# Required import: import gin
# Or: from gin import configurable
def loss_fn(self, labels, inference_outputs, mode, params=None):
  """This implements the outer loss and configurable inner losses."""
  if params and params.get('is_outer_loss', False):
    pass
  if self._num_mixture_components > 1:
    gm = mdn.get_mixture_distribution(
        inference_outputs['dist_params'], self._num_mixture_components,
        self._action_size,
        self._output_mean if self._normalize_outputs else None)
    return -tf.reduce_mean(gm.log_prob(labels.action))
  else:
    return self._outer_loss_multiplier * tf.losses.mean_squared_error(
        labels=labels.action,
        predictions=inference_outputs['inference_output'])
Example 7: parse_gin_defaults_and_flags
# Required import: import gin
# Or: from gin import configurable
def parse_gin_defaults_and_flags():
  """Parses all default gin files and those provided via flags."""
  # Register .gin file search paths with gin.
  for gin_file_path in FLAGS.gin_location_prefix:
    gin.add_config_file_search_path(gin_file_path)
  # Set up the default values for the configurable parameters. These values
  # will be overridden by any user-provided gin files/parameters.
  gin.parse_config_file(
      pkg_resources.resource_filename(__name__, _DEFAULT_CONFIG_FILE))
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)

# TODO(noam): maybe add gin-config to mtf.get_variable so we can delete
# this stupid VariableDtype class and stop passing it all over creation.
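The helper above is a thin wrapper around two gin entry points: parse_config_file for the package defaults and parse_config_files_and_bindings for user-supplied files and parameters. A self-contained sketch of the same ordering, with a made-up configurable and a temporary file standing in for _DEFAULT_CONFIG_FILE:

import os
import tempfile

import gin

@gin.configurable
def run(batch_size=32, learning_rate=0.001):  # hypothetical training entry point
  return batch_size, learning_rate

# Defaults are parsed first, mirroring _DEFAULT_CONFIG_FILE above.
defaults_path = os.path.join(tempfile.mkdtemp(), 'defaults.gin')
with open(defaults_path, 'w') as f:
  f.write('run.learning_rate = 0.01\n')
gin.parse_config_file(defaults_path)

# User files (FLAGS.gin_file) and bindings (FLAGS.gin_param) are parsed
# last, so they override the defaults.
gin.parse_config_files_and_bindings([], ['run.batch_size = 64'])

print(run())  # (64, 0.01)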
Example 8: separate_vocabularies
# Required import: import gin
# Or: from gin import configurable
def separate_vocabularies(inputs=gin.REQUIRED, targets=gin.REQUIRED):
  """Gin-configurable helper function to generate a tuple of vocabularies."""
  return (inputs, targets)

# TODO(katherinelee): Update layout_rules string when noam updates the
# definition in run.
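This two-line helper exists only so that a pair of vocabularies can be assembled inside a gin config and passed wherever a single vocabulary value is expected. A self-contained sketch of that usage; build_model and the string vocabularies are stand-ins, not names from the original project:

import gin

@gin.configurable
def separate_vocabularies(inputs=gin.REQUIRED, targets=gin.REQUIRED):
  return (inputs, targets)

@gin.configurable
def build_model(vocabulary=gin.REQUIRED):
  return vocabulary

gin.parse_config("""
separate_vocabularies.inputs = 'inputs_vocab'    # stand-in for a vocabulary object
separate_vocabularies.targets = 'targets_vocab'
build_model.vocabulary = @separate_vocabularies()
""")

print(build_model())  # ('inputs_vocab', 'targets_vocab')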
Example 9: attention_internal
# Required import: import gin
# Or: from gin import configurable
def attention_internal(self, context, q, m, memory_length, bias):
  """Computes attention weights over memory and returns the projected output."""
  logits = mtf.layers.us_einsum(
      [q, m], reduced_dims=[context.model.model_dim])
  if bias is not None:
    logits += bias
  weights = mtf.softmax(logits, memory_length)
  # TODO(noam): make dropout_broadcast_dims configurable
  dropout_broadcast_dims = [context.length_dim]
  weights = mtf.dropout(
      weights, rate=self.dropout_rate if context.train else 0.0,
      noise_shape=weights.shape - dropout_broadcast_dims)
  u = mtf.einsum([weights, m], reduced_dims=[memory_length])
  return self.compute_y(context, u)
Example 10: load
# Required import: import gin
# Or: from gin import configurable
def load(
    environment_name: Text,
    discount: types.Float = 1.0,
    max_episode_steps: Optional[types.Int] = None,
    gym_env_wrappers: Sequence[types.GymEnvWrapper] = (),
    env_wrappers: Sequence[types.PyEnvWrapper] = (),
    spec_dtype_map: Optional[Dict[gym.Space, np.dtype]] = None
) -> py_environment.PyEnvironment:
  """Loads the selected environment and wraps it with the specified wrappers.

  Note that by default a TimeLimit wrapper is used to limit episode lengths
  to the default benchmarks defined by the registered environments.

  Args:
    environment_name: Name for the environment to load.
    discount: Discount to use for the environment.
    max_episode_steps: If None the max_episode_steps will be set to the default
      step limit defined in the environment's spec. No limit is applied if set
      to 0 or if there is no timestep_limit set in the environment's spec.
    gym_env_wrappers: Iterable with references to wrapper classes to use
      directly on the gym environment.
    env_wrappers: Iterable with references to wrapper classes to use on the
      gym_wrapped environment.
    spec_dtype_map: A dict that maps gym specs to tf dtypes to use as the
      default dtype for the tensors. An easy way to configure a custom
      mapping through Gin is to define a gin-configurable function that
      returns the desired mapping and call it in your Gin config file, for
      example: `suite_gym.load.spec_dtype_map = @get_custom_mapping()`.

  Returns:
    A PyEnvironmentBase instance.
  """
  return suite_gym.load(environment_name, discount, max_episode_steps,
                        gym_env_wrappers, env_wrappers, spec_dtype_map)
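The spec_dtype_map note in the docstring can be made concrete. get_custom_mapping is the hypothetical helper named by the docstring itself, and the dtype mapping below is only an illustration; with tf_agents installed, importing suite_gym registers the load configurable that the binding refers to:

import gin
import gym
import numpy as np
from tf_agents.environments import suite_gym

@gin.configurable
def get_custom_mapping():
  """Hypothetical helper returning a gym-space-to-dtype mapping."""
  return {gym.spaces.Box: np.float32}

# Equivalent to putting this line in a .gin file, as the docstring suggests.
gin.parse_config('suite_gym.load.spec_dtype_map = @get_custom_mapping()')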
Example 11: compute_optimal_action_with_classification_environment
# Required import: import gin
# Or: from gin import configurable
def compute_optimal_action_with_classification_environment(
    observation, environment):
  """Helper function for the gin-configurable SuboptimalArms metric."""
  del observation
  return environment.compute_optimal_action()
Example 12: compute_optimal_reward_with_classification_environment
# Required import: import gin
# Or: from gin import configurable
def compute_optimal_reward_with_classification_environment(
    observation, environment):
  """Helper function for the gin-configurable Regret metric."""
  del observation
  return environment.compute_optimal_reward()
Example 13: rate_unsupervised
# Required import: import gin
# Or: from gin import configurable
def rate_unsupervised(task, value=1e6):
  """Gin-configurable mixing rate for the unsupervised co-training task."""
  del task
  return value
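rate_unsupervised is normally not called directly; it is referenced from a gin config so that the mixing rate can be changed without touching code. A self-contained sketch of overriding its default value (the rate 1000.0 is arbitrary):

import gin

@gin.configurable
def rate_unsupervised(task, value=1e6):
  """Gin-configurable mixing rate for the unsupervised co-training task."""
  del task
  return value

gin.parse_config('rate_unsupervised.value = 1000.0')

print(rate_unsupervised(task=None))  # 1000.0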
Example 14: default_input_fn_tmpl
# Required import: import gin
# Or: from gin import configurable
def default_input_fn_tmpl(
    file_patterns,
    batch_size,
    feature_spec,
    label_spec,
    num_parallel_calls=4,
    is_training=False,
    preprocess_fn=None,
    shuffle_buffer_size=500,
    prefetch_buffer_size=tf.data.experimental.AUTOTUNE,
    parallel_shards=10):
  """Generic gin-configurable tf.data input pipeline."""
  if isinstance(file_patterns, dict):
    file_patterns_map = file_patterns
  else:
    file_patterns_map = {'': file_patterns}
  datasets = {}
  # Read each dataset.
  for dataset_key, file_patterns in file_patterns_map.items():
    data_format, filenames = get_data_format_and_filenames(file_patterns)
    filenames_dataset = tf.data.Dataset.list_files(
        filenames, shuffle=is_training)
    if is_training:
      cycle_length = min(parallel_shards, len(filenames))
    else:
      cycle_length = 1
    dataset = filenames_dataset.apply(
        tf.data.experimental.parallel_interleave(
            DATA_FORMAT[data_format],
            cycle_length=cycle_length,
            sloppy=is_training))
    if is_training:
      dataset = dataset.shuffle(buffer_size=shuffle_buffer_size).repeat()
    else:
      dataset = dataset.repeat()
    dataset = dataset.batch(batch_size, drop_remainder=True)
    datasets[dataset_key] = dataset
  # Merge the dict of datasets of batched serialized examples into a single
  # dataset of dicts of batched serialized examples.
  dataset = tf.data.Dataset.zip(datasets)
  # Parse all datasets together.
  dataset = serialized_to_parsed(
      dataset, feature_spec, label_spec, num_parallel_calls=num_parallel_calls)
  if preprocess_fn is not None:
    # TODO(psanketi): Consider adding num_parallel_calls here.
    dataset = dataset.map(preprocess_fn, num_parallel_calls=parallel_shards)
  if prefetch_buffer_size is not None:
    dataset = dataset.prefetch(prefetch_buffer_size)
  return dataset
Example 15: make_bitransformer
# Required import: import gin
# Or: from gin import configurable
def make_bitransformer(
    input_vocab_size=gin.REQUIRED,
    output_vocab_size=gin.REQUIRED,
    layout=None,
    mesh_shape=None,
    encoder_name="encoder",
    decoder_name="decoder"):
  """Gin-configurable bitransformer constructor.

  In your config file you need to set the encoder and decoder layers like this:

    encoder/make_layer_stack.layers = [
      @transformer_layers.SelfAttention,
      @transformer_layers.DenseReluDense,
    ]
    decoder/make_layer_stack.layers = [
      @transformer_layers.SelfAttention,
      @transformer_layers.EncDecAttention,
      @transformer_layers.DenseReluDense,
    ]

  Args:
    input_vocab_size: an integer
    output_vocab_size: an integer
    layout: optional - an input to mtf.convert_to_layout_rules
      Some layers (e.g. MoE layers) cheat by looking at layout and mesh_shape.
    mesh_shape: optional - an input to mtf.convert_to_shape
      Some layers (e.g. MoE layers) cheat by looking at layout and mesh_shape.
    encoder_name: optional - a string giving the Unitransformer encoder name.
    decoder_name: optional - a string giving the Unitransformer decoder name.

  Returns:
    a Bitransformer
  """
  with gin.config_scope("encoder"):
    encoder = Unitransformer(
        layer_stack=make_layer_stack(),
        input_vocab_size=input_vocab_size,
        output_vocab_size=None,
        autoregressive=False,
        name=encoder_name,
        layout=layout,
        mesh_shape=mesh_shape)
  with gin.config_scope("decoder"):
    decoder = Unitransformer(
        layer_stack=make_layer_stack(),
        input_vocab_size=output_vocab_size,
        output_vocab_size=output_vocab_size,
        autoregressive=True,
        name=decoder_name,
        layout=layout,
        mesh_shape=mesh_shape)
  return Bitransformer(encoder, decoder)
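Putting the docstring's instructions together, a config for this constructor binds the two gin.REQUIRED vocabulary sizes and the scoped layer stacks. The vocabulary size below is a placeholder, and the exact binding paths depend on how the Mesh TensorFlow modules are registered, so read this as a sketch of the overall shape rather than a verified config:

make_bitransformer.input_vocab_size = 32000
make_bitransformer.output_vocab_size = 32000

encoder/make_layer_stack.layers = [
    @transformer_layers.SelfAttention,
    @transformer_layers.DenseReluDense,
]
decoder/make_layer_stack.layers = [
    @transformer_layers.SelfAttention,
    @transformer_layers.EncDecAttention,
    @transformer_layers.DenseReluDense,
]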