當前位置: 首頁>>代碼示例>>Python>>正文


Python lax.scan方法代碼示例

本文整理匯總了Python中jax.lax.scan方法的典型用法代碼示例。如果您正苦於以下問題:Python lax.scan方法的具體用法?Python lax.scan怎麼用?Python lax.scan使用的例子?那麼, 這裡精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在jax.lax的用法示例。


在下文中一共展示了lax.scan方法的15個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於系統推薦出更棒的Python代碼示例。

示例1: test_scan_parametrized_cell_without_params

# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import scan [as 別名]
def test_scan_parametrized_cell_without_params():
    """Scan over a ``parametrized`` cell that declares no parameters.

    Checks that the initialized parameters of the enclosing ``rnn`` equal
    ``((),)`` and that applying it to three inputs yields a ``(3, 2)`` output.
    """
    @parametrized
    def cell(carry, x):
        # The next carry and the emitted output are the same value.
        step = jnp.array([2]) * carry * x
        return step, step

    @parametrized
    def rnn(inputs):
        initial_carry = jnp.zeros((2,))
        _final_carry, stacked = lax.scan(cell, initial_carry, inputs)
        return stacked

    inputs = jnp.zeros((3,))
    params = rnn.init_parameters(inputs, key=PRNGKey(0))
    assert_parameters_equal(((),), params)

    result = rnn.apply(params, inputs)
    assert result.shape == (3, 2)
開發者ID:JuliusKunze,項目名稱:jaxnet,代碼行數:20,代碼來源:test_core.py

示例2: test_scan_parametrized_cell

# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import scan [as 別名]
def test_scan_parametrized_cell():
    """Scan over a ``parametrized`` cell that owns a single parameter.

    Checks that the cell's parameter of shape ``(2,)`` is reachable as
    ``rnn_params.cell.parameter`` and that the scan output is ``(3, 2)``.
    """
    @parametrized
    def cell(carry, x):
        scale = parameter((2,), zeros)
        # The next carry and the emitted output are the same value.
        step = scale * jnp.array([2]) * carry * x
        return step, step

    @parametrized
    def rnn(inputs):
        initial_carry = jnp.zeros((2,))
        _final_carry, stacked = lax.scan(cell, initial_carry, inputs)
        return stacked

    inputs = jnp.zeros((3,))
    rnn_params = rnn.init_parameters(inputs, key=PRNGKey(0))
    assert rnn_params.cell.parameter.shape == (2,)

    result = rnn.apply(rnn_params, inputs)
    assert result.shape == (3, 2)
開發者ID:JuliusKunze,項目名稱:jaxnet,代碼行數:20,代碼來源:test_core.py

示例3: test_scan_parametrized_nonflat_cell

# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import scan [as 別名]
def test_scan_parametrized_nonflat_cell():
    """Scan over a ``parametrized`` cell whose carry is a dict.

    The carry is ``{'a': array}`` rather than a bare array. Checks the cell's
    parameter shape ``(2,)`` and the ``(3, 2)`` shape of the stacked outputs.
    """
    @parametrized
    def cell(carry, x):
        scale = parameter((2,), zeros)
        # Same value is stored in the new carry and emitted as the output.
        step = scale * jnp.array([2]) * carry['a'] * x
        return {'a': step}, step

    @parametrized
    def rnn(inputs):
        initial_carry = {'a': jnp.zeros((2,))}
        _final_carry, stacked = lax.scan(cell, initial_carry, inputs)
        return stacked

    inputs = jnp.zeros((3,))
    rnn_params = rnn.init_parameters(inputs, key=PRNGKey(0))
    assert rnn_params.cell.parameter.shape == (2,)

    result = rnn.apply(rnn_params, inputs)
    assert result.shape == (3, 2)
開發者ID:JuliusKunze,項目名稱:jaxnet,代碼行數:21,代碼來源:test_core.py

示例4: _scan

# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import scan [as 別名]
def _scan(f, init, sequences, non_sequences=None, length=None, reverse=False):
    """Run ``jla.scan`` with ``f`` converted by ``symjax_to_jax_fn``.

    Args:
        f: function to scan; it is converted with ``symjax_to_jax_fn`` and
            then called as ``f(carry, *slices, *non_sequences)``.
        init: initial carry handed to ``jla.scan``.
        sequences: values scanned over; each step's slice is unpacked into
            positional arguments.
        non_sequences: optional iterable of extra positional arguments
            appended unchanged on every step.
        length: forwarded to ``jla.scan``.
        reverse: forwarded to ``jla.scan``.

    Returns:
        Whatever ``jla.scan`` returns.
    """
    jax_fn = symjax_to_jax_fn(f)

    def finalf(carry, slices):
        # Adapter with the (carry, x) signature the scan primitive expects.
        if non_sequences is None:
            return jax_fn(carry, *slices)
        return jax_fn(carry, *slices, *non_sequences)

    return jla.scan(finalf, init, sequences, length=length, reverse=reverse)
開發者ID:SymJAX,項目名稱:SymJAX,代碼行數:18,代碼來源:control_flow.py

示例5: forward_log_prob

# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import scan [as 別名]
def forward_log_prob(init_log_prob, words, transition_log_prob, emission_log_prob, unroll_loop=False):
    """Fold ``forward_one_step`` over ``words`` starting from ``init_log_prob``.

    Args:
        init_log_prob: initial value of the running log-probability.
        words: sequence iterated over (scanned along its leading axis).
        transition_log_prob: passed unchanged to ``forward_one_step``.
        emission_log_prob: passed unchanged to ``forward_one_step``.
        unroll_loop: when true, use a plain Python ``for`` loop instead of
            ``lax.scan``.

    Returns:
        The running log-probability after the last word.
    """
    if unroll_loop:
        # Naive implementation; per the original authors' note it is very slow
        # to compile and run inference with, hence it is opt-in only.
        log_prob = init_log_prob
        for word in words:
            log_prob = forward_one_step(log_prob, word, transition_log_prob, emission_log_prob)
        return log_prob

    def scan_fn(running_log_prob, word):
        updated = forward_one_step(running_log_prob, word, transition_log_prob, emission_log_prob)
        # No per-step output is needed, so emit an empty array.
        return updated, jnp.zeros((0,))

    final_log_prob, _unused = lax.scan(scan_fn, init_log_prob, words)
    return final_log_prob
開發者ID:pyro-ppl,項目名稱:numpyro,代碼行數:19,代碼來源:hmm.py

示例6: scan_wrapper

# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import scan [as 別名]
def scan_wrapper(f, init, xs, length, reverse, rng_key=None, substitute_stack=None):
    """Run ``f`` under ``lax.scan`` while tracing each step with handlers.

    Args:
        f: step function called as ``f(carry, x)`` and returning ``(carry, y)``.
        init: initial carry for ``f``.
        xs: pytree of values scanned over.
        length: number of scan steps; when ``None`` it is taken from the
            leading dimension of the first leaf of ``xs``.
        reverse: forwarded to ``lax.scan``.
        rng_key: optional random key; when given it is split on every step and
            the sub-key is used to seed ``f``.
        substitute_stack: optional iterable of ``(subs_type, subs_map)`` pairs.
            Entries with type ``'condition'`` or ``'substitute'`` wrap ``f``
            in the matching handler; other types are ignored. Defaults to no
            entries.

    Returns:
        The result of ``lax.scan``: the final ``(i, rng_key, carry)`` triple
        and the stacked ``(PytreeTrace(trace), y)`` per-step outputs.
    """
    # A ``None`` sentinel replaces the former mutable ``[]`` default so that no
    # list object is shared between calls.
    if substitute_stack is None:
        substitute_stack = []

    if length is None:
        length = tree_flatten(xs)[0][0].shape[0]

    def body_fn(wrapped_carry, x):
        i, rng_key, carry = wrapped_carry
        rng_key, subkey = random.split(rng_key) if rng_key is not None else (None, None)

        with handlers.block():
            seeded_fn = handlers.seed(f, subkey) if subkey is not None else f
            for subs_type, subs_map in substitute_stack:
                # Bind the current step index and scan length so the wrapper
                # can pick the value belonging to this step.
                subs_fn = partial(_subs_wrapper, subs_map, i, length)
                if subs_type == 'condition':
                    seeded_fn = handlers.condition(seeded_fn, condition_fn=subs_fn)
                elif subs_type == 'substitute':
                    seeded_fn = handlers.substitute(seeded_fn, substitute_fn=subs_fn)

            with handlers.trace() as trace:
                carry, y = seeded_fn(carry, x)

        return (i + 1, rng_key, carry), (PytreeTrace(trace), y)

    return lax.scan(body_fn, (jnp.array(0), rng_key, init), xs, length=length, reverse=reverse)
開發者ID:pyro-ppl,項目名稱:numpyro,代碼行數:25,代碼來源:scan.py

示例7: _jax_scan

# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import scan [as 別名]
def _jax_scan(f, xs, init_value, axis=0, remat=False):
  """Scans ``f`` over the given axis of ``xs``.

  Equivalent pseudo-python:

  def scan(f, xs, init_value, axis):
    xs  = [xs[..., i, ...] for i in range(xs.shape[axis])]
    cur_value = init_value
    ys = []
    for x in xs:
      y, cur_value = f(x, cur_value)
      ys.append(y)
    return np.stack(ys, axis), cur_value

  Args:
    f: function (x, carry) -> (y, new_carry)
    xs: tensor, x will be xs slices on axis
    init_value: tensor, initial value of the carry-over
    axis: int, the axis on which to slice xs
    remat: whether to re-materialize f

  Returns:
    A pair (ys, last_value) as described above.
  """
  def exchange_with_leading(x):
    # Permutation that swaps dimension 0 with dimension `axis`.
    perm = list(range(len(x.shape)))
    perm[axis], perm[0] = 0, axis
    return jnp.transpose(x, axes=perm)

  def carry_first(carry, x):
    # lax.scan wants (carry, x) -> (carry, y); f uses the opposite order.
    y, new_carry = f(x, carry)
    return new_carry, y

  if axis != 0:
    xs = nested_map(exchange_with_leading, xs)

  body = jax.remat(carry_first) if remat else carry_first
  last_value, ys = lax.scan(body, init_value, xs)

  if axis != 0:
    ys = nested_map(exchange_with_leading, ys)
  return ys, last_value
開發者ID:google,項目名稱:trax,代碼行數:43,代碼來源:jax.py

示例8: Rnn

# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import scan [as 別名]
def Rnn(cell, carry_init):
    """Layer construction function for recurrent neural nets.

    Expects inputs of shape (batch, sequence, channels). ``carry_init`` is
    called with the batch size to build the initial carry, and ``cell`` is
    scanned along the sequence axis.
    TODO allow returning last carry."""

    @parametrized
    def rnn(xs):
        # lax.scan iterates over the leading axis, so put the sequence first.
        sequence_major = np.swapaxes(xs, 0, 1)
        batch_size = sequence_major.shape[1]
        _last_carry, stacked = lax.scan(cell, carry_init(batch_size), sequence_major)
        # Restore the batch-first layout for the caller.
        return np.swapaxes(stacked, 0, 1)

    return rnn
開發者ID:JuliusKunze,項目名稱:jaxnet,代碼行數:14,代碼來源:modules.py

示例9: test_scan_unparametrized_cell

# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import scan [as 別名]
def test_scan_unparametrized_cell():
    """Scan a plain (undecorated) cell function inside a ``parametrized`` rnn.

    Checks that initialization and application succeed and that three inputs
    produce a stacked output of shape ``(3, 2)``.
    """
    def cell(carry, x):
        # The next carry and the emitted output are the same value.
        step = jnp.array([2]) * carry * x
        return step, step

    @parametrized
    def rnn(inputs):
        initial_carry = jnp.zeros((2,))
        _final_carry, stacked = lax.scan(cell, initial_carry, inputs)
        return stacked

    inputs = jnp.zeros((3,))
    params = rnn.init_parameters(inputs, key=PRNGKey(0))

    result = rnn.apply(params, inputs)
    assert result.shape == (3, 2)
開發者ID:JuliusKunze,項目名稱:jaxnet,代碼行數:17,代碼來源:test_core.py

示例10: rnn

# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import scan [as 別名]
def rnn(self, step_function, input, initial_states, **kwargs):
        """Run ``step_function`` along axis 1 of ``input`` via ``self.scan``.

        ``step_function`` is called as ``step_function(x_t, state, **kwargs)``
        and must return a pair whose second element is the new state; the
        first element of that state is collected as the step's output.

        Returns the collected outputs with axes 0 and 1 swapped back.
        """
        def step(state, input_):
            _, next_state = step_function(input_, state, **kwargs)
            return next_state, next_state[0]

        # The scan iterates over the leading axis, so move axis 1 to the front.
        leading_time = np.swapaxes(input, 0, 1)
        _final_state, stacked = self.scan(step, leading_time, initial_states)
        return np.swapaxes(stacked, 0, 1)
開發者ID:sharadmv,項目名稱:deepx,代碼行數:9,代碼來源:jax.py

示例11: scan

# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import scan [as 別名]
def scan(self, fn, elems, initializer=None):
        """Scan ``fn`` over ``elems`` with ``lax.scan``.

        Note the argument order: ``initializer`` is the initial carry and
        ``elems`` the scanned values, i.e. ``lax.scan(fn, initializer, elems)``.

        Returns the ``(final_carry, stacked_outputs)`` pair from ``lax.scan``.
        """
        final_carry, stacked_outputs = lax.scan(fn, initializer, elems)
        return final_carry, stacked_outputs
開發者ID:sharadmv,項目名稱:deepx,代碼行數:4,代碼來源:jax.py

示例12: cholesky_update

# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import scan [as 別名]
def cholesky_update(L, x, coef=1):
    """
    Finds cholesky of L @ L.T + coef * x @ x.T.

    The factor is first rewritten as ``L @ D @ L.T`` with a unit-diagonal
    ``L``; the rank-one update is then applied column by column with
    ``lax.scan``.

    Args:
        L: Cholesky factor; its last two dimensions are the matrix
            dimensions, leading dimensions are batch dimensions.
        x: update vector; its last dimension is the vector dimension, leading
            dimensions are batch dimensions broadcast against those of ``L``.
        coef: scalar weight of the rank-one term.

    Returns:
        The updated Cholesky factor with the broadcast batch shape.

    **References:**

        1. A more efficient rank-one covariance matrix update for evolution strategies,
           Oswin Krause and Christian Igel
    """
    batch_shape = lax.broadcast_shapes(L.shape[:-2], x.shape[:-1])
    L = jnp.broadcast_to(L, batch_shape + L.shape[-2:])
    x = jnp.broadcast_to(x, batch_shape + x.shape[-1:])
    diag = jnp.diagonal(L, axis1=-2, axis2=-1)
    # convert to unit diagonal triangular matrix: L @ D @ L.T
    L = L / diag[..., None, :]
    D = jnp.square(diag)

    def scan_fn(carry, val):
        b, w = carry
        j, Dj, L_j = val
        wj = w[..., j]
        gamma = b * Dj + coef * jnp.square(wj)
        Dj_new = gamma / b
        # Recurrence b <- b + coef * wj**2 / Dj, i.e. gamma divided by the
        # *old* diagonal entry. Dividing by Dj_new would give back b unchanged
        # and make every later diagonal entry wrong.
        b = gamma / Dj

        # update vectors w and L_j
        w = w - wj[..., None] * L_j
        L_j = L_j + (coef * wj / gamma)[..., None] * w
        return (b, w), (Dj_new, L_j)

    D, L = jnp.moveaxis(D, -1, 0), jnp.moveaxis(L, -1, 0)  # move scan dim to front
    _, (D, L) = lax.scan(scan_fn, (jnp.ones(batch_shape), x), (jnp.arange(D.shape[0]), D, L))
    D, L = jnp.moveaxis(D, 0, -1), jnp.moveaxis(L, 0, -1)  # move scan dim back
    return L * jnp.sqrt(D)[..., None, :]
開發者ID:pyro-ppl,項目名稱:numpyro,代碼行數:36,代碼來源:util.py

示例13: _subs_wrapper

# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import scan [as 別名]
def _subs_wrapper(subs_map, i, length, site):
    """Pick the substituted value for ``site`` at scan step ``i``.

    Args:
        subs_map: either a dict keyed by site name, or a callable taking the
            site and returning a value (or ``None``).
        i: index of the current scan step.
        length: total number of scan steps.
        site: site dict; ``site['name']``, ``site['kwargs']`` and
            ``site['fn']`` are read.

    Returns:
        ``None`` when ``subs_map`` provides nothing for this site; otherwise
        the value itself, or its ``i``-th slice when a whole series was given.

    Raises:
        RuntimeError: if a substituted series is longer than ``length`` or the
            value's number of dimensions is unexpected.
    """
    if isinstance(subs_map, dict) and site['name'] in subs_map:
        value = subs_map[site['name']]
    elif callable(subs_map):
        rng_key = site['kwargs'].get('rng_key')
        if rng_key is not None:
            subs_map = handlers.seed(subs_map, rng_seed=rng_key)
        value = subs_map(site)
    else:
        value = None

    if value is None:
        return None

    value_ndim = jnp.ndim(value)
    sample_shape = site['kwargs']['sample_shape']
    fn_ndim = len(sample_shape + site['fn'].shape())

    if value_ndim == fn_ndim:
        # this branch happens when substitute_fn is init_strategy,
        # where we apply init_strategy to each element in the scanned series
        return value

    if value_ndim != fn_ndim + 1:
        raise RuntimeError(f"Something goes wrong. Expected ndim = {fn_ndim} or {fn_ndim+1},"
                           f" but got {value_ndim}. Please report the issue to us!")

    # from here on a whole series of values was substituted
    shape = jnp.shape(value)
    if shape[0] == length:
        return value[i]

    if shape[0] < length:
        rng_key = site['kwargs']['rng_key']
        assert rng_key is not None
        # we use the substituted values if i < shape[0]
        # and generate a new sample otherwise
        return lax.cond(i < shape[0],
                        (value, i),
                        lambda val: val[0][val[1]],
                        rng_key,
                        lambda val: site['fn'](rng_key=val, sample_shape=sample_shape))

    raise RuntimeError(f"Substituted value for site {site['name']} "
                       "requires length less than or equal to scan length."
                       f" Expected length <= {length}, but got {shape[0]}.")
開發者ID:pyro-ppl,項目名稱:numpyro,代碼行數:41,代碼來源:scan.py

示例14: test_improper

# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import scan [as 別名]
def test_improper():
    """Run 10000 SVI updates on a model with ``ImproperUniform`` priors.

    The test only checks that building the autoguide, initializing SVI and
    scanning the update step complete without raising.
    """
    y = random.normal(random.PRNGKey(0), (100,))

    def model(y):
        real_support = dist.constraints.real
        lambda1 = numpyro.sample('lambda1', dist.ImproperUniform(real_support, (), ()))
        lambda2 = numpyro.sample('lambda2', dist.ImproperUniform(real_support, (), ()))
        sigma = numpyro.sample('sigma', dist.ImproperUniform(dist.constraints.positive, (), ()))
        mu = numpyro.deterministic('mu', lambda1 + lambda2)
        numpyro.sample('y', dist.Normal(mu, sigma), obs=y)

    guide = AutoDiagonalNormal(model)
    svi = SVI(model, guide, optim.Adam(0.003), ELBO(), y=y)
    initial_state = svi.init(random.PRNGKey(2))

    def run_update(state, _step):
        # The scanned value is a dummy; only the SVI state is threaded through.
        return svi.update(state)

    lax.scan(run_update, initial_state, jnp.zeros(10000))
開發者ID:pyro-ppl,項目名稱:numpyro,代碼行數:16,代碼來源:test_autoguide.py

示例15: map

# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import scan [as 別名]
def map(f, sequences, non_sequences=None):
    """Map a function over leading array axes.

    Like Python's builtin map, except inputs and outputs are in the form of
    stacked arrays. Consider using the ``jax.vmap`` transform instead, unless you
    need to apply a function element by element for reduced memory usage or
    heterogeneous computation with other control flow primitives.

    When ``xs`` is an array type, the semantics of ``map`` are given by this
    Python implementation::

      def map(f, xs):
        return np.stack([f(x) for x in xs])

    Like ``scan``, ``map`` is implemented in terms of JAX primitives so many of
    the same advantages over a Python loop apply: ``xs`` may be an arbitrary
    nested pytree type, and the mapped computation is compiled only once.

    Args:
      f: a Python function to apply element-wise over the first axis or axes of
        ``sequences``.
      sequences: list of values over which to map along the leading axis.
      non_sequences: list of values passed the same at each call

    Returns:
      Mapped values.

    Example:
        example of creating a diagonal matrix:
        .. doctest::

           >>> import symjax.tensor as T
           >>> import symjax
           >>> x = T.ones(3)
           >>> y = T.zeros(3)
           >>> w = T.arange(3)
           >>> out = T.map(lambda x, i, w: T.index_update(w, i, x), sequences=[x, w], non_sequences=[y])
           >>> f = symjax.function(outputs=out)
           >>> f()
           array([[1., 0., 0.],
                  [0., 1., 0.],
                  [0., 0., 1.]], dtype=float32)
    """

    def step(_, *args):
        # ``scan`` needs a carry; map has none, so a constant placeholder is
        # threaded through and only the per-step output is kept.
        return 1, f(*args)

    # ``isinstance`` (rather than an exact ``type(...) == list`` comparison)
    # so that list subclasses are converted to tuples as well.
    if isinstance(non_sequences, list):
        non_sequences = tuple(non_sequences)
    if isinstance(sequences, list):
        sequences = tuple(sequences)

    return scan(step, 0, sequences, non_sequences=non_sequences)[1]
開發者ID:SymJAX,項目名稱:SymJAX,代碼行數:55,代碼來源:control_flow.py


注:本文中的jax.lax.scan方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。