本文整理匯總了Python中jax.lax.scan方法的典型用法代碼示例。如果您正苦於以下問題：Python lax.scan方法的具體用法？Python lax.scan怎麼用？Python lax.scan使用的例子？那麼，這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在模塊jax.lax的用法示例。
在下文中一共展示了lax.scan方法的15個代碼示例，這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚，您的評價將有助於系統推薦出更棒的Python代碼示例。
示例1: test_scan_parametrized_cell_without_params
# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import scan [as 別名]
def test_scan_parametrized_cell_without_params():
    """A ``@parametrized`` scan cell that creates no parameters yields the tree ``((),)``.

    The scanned output stacks one (2,)-vector per element of the length-3 input.
    """
    @parametrized
    def cell(carry, x):
        # The new carry doubles as the per-step output.
        out = jnp.array([2]) * carry * x
        return out, out

    @parametrized
    def rnn(inputs):
        _, stacked = lax.scan(cell, jnp.zeros((2,)), inputs)
        return stacked

    sequence = jnp.zeros((3,))
    params = rnn.init_parameters(sequence, key=PRNGKey(0))
    assert_parameters_equal(((),), params)

    result = rnn.apply(params, sequence)
    assert result.shape == (3, 2)
示例2: test_scan_parametrized_cell
# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import scan [as 別名]
def test_scan_parametrized_cell():
    """A scan cell owning a (2,)-shaped parameter exposes it as ``rnn_params.cell.parameter``.

    Applying the RNN to a length-3 input stacks three (2,)-shaped outputs.
    """
    @parametrized
    def cell(carry, x):
        scale = parameter((2,), zeros)
        # Same value is used as next carry and as the per-step output.
        out = scale * jnp.array([2]) * carry * x
        return out, out

    @parametrized
    def rnn(inputs):
        _, stacked = lax.scan(cell, jnp.zeros((2,)), inputs)
        return stacked

    sequence = jnp.zeros((3,))
    rnn_params = rnn.init_parameters(sequence, key=PRNGKey(0))
    assert rnn_params.cell.parameter.shape == (2,)

    result = rnn.apply(rnn_params, sequence)
    assert result.shape == (3, 2)
示例3: test_scan_parametrized_nonflat_cell
# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import scan [as 別名]
def test_scan_parametrized_nonflat_cell():
    """Like ``test_scan_parametrized_cell`` but the carry is a dict ``{'a': array}``.

    Checks that a nested (non-flat) carry pytree works with a parametrized cell.
    """
    @parametrized
    def cell(carry, x):
        scale = parameter((2,), zeros)
        out = scale * jnp.array([2]) * carry['a'] * x
        # Wrap the new carry back into the dict structure; emit the bare array.
        return {'a': out}, out

    @parametrized
    def rnn(inputs):
        _, stacked = lax.scan(cell, {'a': jnp.zeros((2,))}, inputs)
        return stacked

    sequence = jnp.zeros((3,))
    rnn_params = rnn.init_parameters(sequence, key=PRNGKey(0))
    assert rnn_params.cell.parameter.shape == (2,)

    result = rnn.apply(rnn_params, sequence)
    assert result.shape == (3, 2)
示例4: _scan
# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import scan [as 別名]
def _scan(f, init, sequences, non_sequences=None, length=None, reverse=False):
    """Convert ``f`` with ``symjax_to_jax_fn`` and scan it with ``jla.scan``.

    At every step the converted function is called as
    ``f(carry, *step_slices, *non_sequences)``: the per-step slices of
    ``sequences`` are unpacked as positional arguments, followed by the
    constant ``non_sequences`` (nothing extra when ``non_sequences`` is None).

    Returns:
        Whatever ``jla.scan`` returns for the given ``init``, ``length`` and
        ``reverse`` settings.
    """
    jax_f = symjax_to_jax_fn(f)
    trailing = () if non_sequences is None else non_sequences

    def step(carry, sliced):
        return jax_f(carry, *sliced, *trailing)

    return jla.scan(step, init, sequences, length=length, reverse=reverse)
示例5: forward_log_prob
# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import scan [as 別名]
def forward_log_prob(init_log_prob, words, transition_log_prob, emission_log_prob, unroll_loop=False):
    """Fold ``forward_one_step`` over ``words`` starting from ``init_log_prob``.

    A naive Python loop over ``words`` is very slow to compile and to run, so by
    default the fold is expressed with ``lax.scan``. Pass ``unroll_loop=True``
    to use the plain Python loop instead.

    Returns:
        The log-probability after the last word has been consumed.
    """
    def step(log_prob, word):
        return forward_one_step(log_prob, word, transition_log_prob, emission_log_prob)

    if not unroll_loop:
        def scan_fn(log_prob, word):
            # Per-step output is an empty placeholder; only the final carry matters.
            return step(log_prob, word), jnp.zeros((0,))

        final_log_prob, _ = lax.scan(scan_fn, init_log_prob, words)
        return final_log_prob

    log_prob = init_log_prob
    for word in words:
        log_prob = step(log_prob, word)
    return log_prob
示例6: scan_wrapper
# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import scan [as 別名]
def scan_wrapper(f, init, xs, length, reverse, rng_key=None, substitute_stack=None):
    """Scan ``f`` over ``xs`` with each step run under NumPyro effect handlers.

    Each step, inside ``handlers.block()``:

    * splits ``rng_key`` (when given) and seeds ``f`` with the fresh subkey;
    * wraps ``f`` with ``handlers.condition`` / ``handlers.substitute`` for each
      ``(subs_type, subs_map)`` entry of ``substitute_stack``, using
      ``_subs_wrapper`` bound to the current step index ``i`` and ``length``;
    * records the step's sites with ``handlers.trace()``.

    Args:
        f: callable ``(carry, x) -> (new_carry, y)``.
        init: initial carry.
        xs: scanned inputs; when ``length`` is None it is taken from the leading
            dimension of the first leaf of ``xs``.
        length: number of scan steps, or None to infer it.
        reverse: forwarded to ``lax.scan``.
        rng_key: optional PRNG key; when None, ``f`` is not seeded.
        substitute_stack: optional iterable of ``(subs_type, subs_map)`` pairs,
            where ``subs_type`` is ``'condition'`` or ``'substitute'``.
            Defaults to no substitutions.

    Returns:
        ``lax.scan``'s result:
        ``((step_count, final_rng_key, final_carry), (PytreeTrace, ys))``.
    """
    if substitute_stack is None:
        # Sentinel instead of a shared mutable `[]` default; the stack is only read.
        substitute_stack = ()

    def body_fn(wrapped_carry, x):
        i, rng_key, carry = wrapped_carry
        rng_key, subkey = random.split(rng_key) if rng_key is not None else (None, None)

        with handlers.block():
            seeded_fn = handlers.seed(f, subkey) if subkey is not None else f
            for subs_type, subs_map in substitute_stack:
                subs_fn = partial(_subs_wrapper, subs_map, i, length)
                if subs_type == 'condition':
                    seeded_fn = handlers.condition(seeded_fn, condition_fn=subs_fn)
                elif subs_type == 'substitute':
                    seeded_fn = handlers.substitute(seeded_fn, substitute_fn=subs_fn)

            with handlers.trace() as trace:
                carry, y = seeded_fn(carry, x)

        return (i + 1, rng_key, carry), (PytreeTrace(trace), y)

    if length is None:
        # Infer the scan length from the leading dimension of the first leaf of xs.
        # body_fn reads `length` through its closure, so it sees this value.
        length = tree_flatten(xs)[0][0].shape[0]
    return lax.scan(body_fn, (jnp.array(0), rng_key, init), xs, length=length, reverse=reverse)
示例7: _jax_scan
# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import scan [as 別名]
def _jax_scan(f, xs, init_value, axis=0, remat=False):
"""Scans the f over the given axis of xs.
In pseudo-python, the scan function would look as follows:
def scan(f, xs, init_value, axis):
xs = [xs[..., i, ...] for i in range(xs.shape[axis])]
cur_value = init_value
ys = []
for x in xs:
y, cur_value = f(x, cur_value)
ys.append(y)
return np.stack(ys, axis), cur_value
Args:
f: function (x, carry) -> (y, new_carry)
xs: tensor, x will be xs slices on axis
init_value: tensor, initial value of the carry-over
axis: int, the axis on which to slice xs
remat: whether to re-materialize f
Returns:
A pair (ys, last_value) as described above.
"""
def swapaxes(x):
transposed_axes = list(range(len(x.shape)))
transposed_axes[axis] = 0
transposed_axes[0] = axis
return jnp.transpose(x, axes=transposed_axes)
if axis != 0:
xs = nested_map(swapaxes, xs)
def transposed_f(c, x):
y, d = f(x, c)
return d, y
if remat:
last_value, ys = lax.scan(jax.remat(transposed_f), init_value, xs)
else:
last_value, ys = lax.scan(transposed_f, init_value, xs)
if axis != 0:
ys = nested_map(swapaxes, ys)
return ys, last_value
示例8: Rnn
# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import scan [as 別名]
def Rnn(cell, carry_init):
    """Layer construction function for recurrent neural nets.

    Expecting input shape (batch, sequence, channels). The input is made
    time-major, ``cell`` is scanned over the sequence axis starting from
    ``carry_init(batch_size)``, and the stacked outputs are returned
    batch-major again.

    TODO allow returning last carry.
    """
    @parametrized
    def rnn(xs):
        time_major = np.swapaxes(xs, 0, 1)
        batch_size = time_major.shape[1]
        _, ys = lax.scan(cell, carry_init(batch_size), time_major)
        return np.swapaxes(ys, 0, 1)

    return rnn
示例9: test_scan_unparametrized_cell
# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import scan [as 別名]
def test_scan_unparametrized_cell():
    """A plain (non-``@parametrized``) function can be scanned inside a parametrized RNN."""
    def cell(carry, x):
        out = jnp.array([2]) * carry * x
        return out, out

    @parametrized
    def rnn(inputs):
        _, stacked = lax.scan(cell, jnp.zeros((2,)), inputs)
        return stacked

    sequence = jnp.zeros((3,))
    params = rnn.init_parameters(sequence, key=PRNGKey(0))
    result = rnn.apply(params, sequence)
    assert result.shape == (3, 2)
示例10: rnn
# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import scan [as 別名]
def rnn(self, step_function, input, initial_states, **kwargs):
    """Run ``step_function`` over axis 1 of ``input`` via ``self.scan``.

    Assumes ``input`` is batch-major, e.g. (batch, time, ...) — TODO confirm.
    ``step_function(input_, states, **kwargs)`` must return
    ``(output, new_states)``; its ``output`` is ignored, and the first entry of
    ``new_states`` is recorded as the per-step output.

    Returns:
        The stacked per-step outputs with axes 0 and 1 swapped back.
    """
    def step(state, input_):
        _, new_state = step_function(input_, state, **kwargs)
        return new_state, new_state[0]

    time_major = np.swapaxes(input, 0, 1)
    _, outputs = self.scan(step, time_major, initial_states)
    return np.swapaxes(outputs, 0, 1)
示例11: scan
# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import scan [as 別名]
def scan(self, fn, elems, initializer=None):
    """Scan ``fn`` over the leading axis of ``elems``, starting from ``initializer``.

    A thin adapter that reorders the arguments into ``lax.scan(f, init, xs)``
    form. ``fn`` follows the ``(carry, x) -> (new_carry, y)`` convention.

    Returns:
        The ``(final_carry, stacked_ys)`` pair produced by ``lax.scan``.
    """
    init_carry = initializer
    return lax.scan(fn, init_carry, xs=elems)
示例12: cholesky_update
# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import scan [as 別名]
def cholesky_update(L, x, coef=1):
    """
    Finds cholesky of L @ L.T + coef * x @ x.T.

    Computes the lower Cholesky factor of ``L @ L.T + coef * outer(x, x)`` with
    an O(n^2) rank-one update of ``L``, without refactorizing. Leading batch
    dimensions of ``L`` (``[..., n, n]``) and ``x`` (``[..., n]``) are
    broadcast against each other.

    :param L: lower-triangular Cholesky factor with nonzero diagonal,
        shape ``[..., n, n]``.
    :param x: update vector, shape ``[..., n]``.
    :param coef: scalar weight of the rank-one term. A negative value performs a
        downdate, which requires the result to stay positive definite
        (otherwise NaNs may appear).
    :return: the updated lower-triangular factor, shape ``batch_shape + (n, n)``.

    **References;**

        1. A more efficient rank-one covariance matrix update for evolution strategies,
           Oswin Krause and Christian Igel
    """
    batch_shape = lax.broadcast_shapes(L.shape[:-2], x.shape[:-1])
    L = jnp.broadcast_to(L, batch_shape + L.shape[-2:])
    x = jnp.broadcast_to(x, batch_shape + x.shape[-1:])
    diag = jnp.diagonal(L, axis1=-2, axis2=-1)
    # convert to unit diagonal triangular matrix: L @ D @ L.T
    L = L / diag[..., None, :]
    D = jnp.square(diag)

    def scan_fn(carry, val):
        b, w = carry
        j, Dj, L_j = val
        wj = w[..., j]
        gamma = b * Dj + coef * jnp.square(wj)
        Dj_new = gamma / b
        # The scaling recurrence uses the *old* D_j: b_{j+1} = gamma / D_j.
        # (Dividing by Dj_new would just return b and leave it stuck at 1.)
        b = gamma / Dj
        # update vectors w and L_j (column j of the unit-diagonal factor)
        w = w - wj[..., None] * L_j
        L_j = L_j + (coef * wj / gamma)[..., None] * w
        return (b, w), (Dj_new, L_j)

    D, L = jnp.moveaxis(D, -1, 0), jnp.moveaxis(L, -1, 0)  # move scan dim to front
    _, (D, L) = lax.scan(scan_fn, (jnp.ones(batch_shape), x), (jnp.arange(D.shape[0]), D, L))
    D, L = jnp.moveaxis(D, 0, -1), jnp.moveaxis(L, 0, -1)  # move scan dim back
    return L * jnp.sqrt(D)[..., None, :]
示例13: _subs_wrapper
# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import scan [as 別名]
def _subs_wrapper(subs_map, i, length, site):
value = None
if isinstance(subs_map, dict) and site['name'] in subs_map:
value = subs_map[site['name']]
elif callable(subs_map):
rng_key = site['kwargs'].get('rng_key')
subs_map = handlers.seed(subs_map, rng_seed=rng_key) if rng_key is not None else subs_map
value = subs_map(site)
if value is not None:
value_ndim = jnp.ndim(value)
sample_shape = site['kwargs']['sample_shape']
fn_ndim = len(sample_shape + site['fn'].shape())
if value_ndim == fn_ndim:
# this branch happens when substitute_fn is init_strategy,
# where we apply init_strategy to each element in the scanned series
return value
elif value_ndim == fn_ndim + 1:
# this branch happens when we substitute a series of values
shape = jnp.shape(value)
if shape[0] == length:
return value[i]
elif shape[0] < length:
rng_key = site['kwargs']['rng_key']
assert rng_key is not None
# we use the substituted values if i < shape[0]
# and generate a new sample otherwise
return lax.cond(i < shape[0],
(value, i),
lambda val: val[0][val[1]],
rng_key,
lambda val: site['fn'](rng_key=val, sample_shape=sample_shape))
else:
raise RuntimeError(f"Substituted value for site {site['name']} "
"requires length less than or equal to scan length."
f" Expected length <= {length}, but got {shape[0]}.")
else:
raise RuntimeError(f"Something goes wrong. Expected ndim = {fn_ndim} or {fn_ndim+1},"
f" but got {value_ndim}. Please report the issue to us!")
示例14: test_improper
# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import scan [as 別名]
def test_improper():
    """Smoke test: SVI with an AutoDiagonalNormal guide runs 10000 scanned updates
    on a model whose priors are all ``ImproperUniform``."""
    y = random.normal(random.PRNGKey(0), (100,))

    def improper(support):
        return dist.ImproperUniform(support, (), ())

    def model(y):
        lambda1 = numpyro.sample('lambda1', improper(dist.constraints.real))
        lambda2 = numpyro.sample('lambda2', improper(dist.constraints.real))
        sigma = numpyro.sample('sigma', improper(dist.constraints.positive))
        mu = numpyro.deterministic('mu', lambda1 + lambda2)
        numpyro.sample('y', dist.Normal(mu, sigma), obs=y)

    guide = AutoDiagonalNormal(model)
    svi = SVI(model, guide, optim.Adam(0.003), ELBO(), y=y)
    state = svi.init(random.PRNGKey(2))

    def body(carry, _):
        # svi.update returns (new_state, loss): the new carry and the per-step output.
        return svi.update(carry)

    lax.scan(body, state, jnp.zeros(10000))
示例15: map
# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import scan [as 別名]
def map(f, sequences, non_sequences=None):
    """Map a function over leading array axes.

    Like Python's builtin map, except inputs and outputs are in the form of
    stacked arrays. Consider using the ``jax.vmap`` transform instead, unless you
    need to apply a function element by element for reduced memory usage or
    heterogeneous computation with other control flow primitives.

    ``f`` is called once per leading-axis index with the slices of every entry
    of ``sequences`` followed by the unchanged ``non_sequences``; roughly::

        def map(f, sequences, non_sequences=()):
            return np.stack([f(*xs, *non_sequences) for xs in zip(*sequences)])

    Like ``scan``, ``map`` is implemented in terms of JAX primitives, so many of
    the same advantages over a Python loop apply: the mapped computation is
    compiled only once.

    Note: this function shadows the builtin ``map`` in its module.

    Args:
        f: a Python function to apply element-wise over the first axis of each
            entry of ``sequences``.
        sequences: list/tuple of values over which to map along the leading axis.
        non_sequences: list/tuple of values passed unchanged at each call.

    Returns:
        Mapped values, stacked along a new leading axis.

    Example:
        example of creating a diagonal matrix:

        .. doctest::

            >>> import symjax.tensor as T
            >>> import symjax
            >>> x = T.ones(3)
            >>> y = T.zeros(3)
            >>> w = T.arange(3)
            >>> out = T.map(lambda x, i, w: T.index_update(w, i, x), sequences=[x, w], non_sequences=[y])
            >>> f = symjax.function(outputs=out)
            >>> f()
            array([[1., 0., 0.],
                   [0., 1., 0.],
                   [0., 0., 1.]], dtype=float32)
    """
    def g(_carry, *args):
        # The carry is a dummy; scan is used only to stack f's per-step outputs.
        return 1, f(*args)

    # Lists (including list subclasses, which JAX does not treat as pytree
    # containers) are passed to scan as tuples.
    if isinstance(non_sequences, list):
        non_sequences = tuple(non_sequences)
    if isinstance(sequences, list):
        sequences = tuple(sequences)
    ys = scan(g, 0, sequences, non_sequences=non_sequences)[1]
    return ys