当前位置: 首页>>代码示例>>Python>>正文


Python numpy.concatenate方法代码示例

本文整理汇总了Python中jax.numpy.concatenate方法的典型用法代码示例。如果您正苦于以下问题:Python numpy.concatenate方法的具体用法?Python numpy.concatenate怎么用?Python numpy.concatenate使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在jax.numpy的用法示例。


在下文中一共展示了numpy.concatenate方法的13个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: GRUCell

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import concatenate [as 别名]
def GRUCell(carry_size, param_init):
    @parametrized
    def gru_cell(carry, x):
        def param(name):
            return parameter((x.shape[1] + carry_size, carry_size), param_init, name)

        both = np.concatenate((x, carry), axis=1)
        update = sigmoid(np.dot(both, param('update_kernel')))
        reset = sigmoid(np.dot(both, param('reset_kernel')))
        both_reset_carry = np.concatenate((x, reset * carry), axis=1)
        compute = np.tanh(np.dot(both_reset_carry, param('compute_kernel')))
        out = update * compute + (1 - update) * carry
        return out, out

    def carry_init(batch_size):
        return np.zeros((batch_size, carry_size))

    return gru_cell, carry_init 
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:20,代码来源:modules.py

示例2: test_submodule_order

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import concatenate [as 别名]
def test_submodule_order():
    @parametrized
    def net():
        p = Parameter(lambda key: jnp.zeros((1,)))
        a = p()
        b = parameter((2,), zeros)
        c = parameter((3,), zeros)
        d = parameter((4,), zeros)
        e = parameter((5,), zeros)
        f = parameter((6,), zeros)

        # must not mess up order (decided by first submodule call):
        k = p()

        return jnp.concatenate([a, f]) + jnp.concatenate([b, e]) + jnp.concatenate([c, d]) + k

    params = net.init_parameters(key=PRNGKey(0))

    assert jnp.zeros((1,)) == params.parameter0
    out = net.apply(params)
    assert (7,) == out.shape 
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:23,代码来源:test_core.py

示例3: make_dataset

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import concatenate [as 别名]
def make_dataset(rng_key) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Make simulated dataset where potential customers who get a
    sales calls have ~2% higher chance of making another purchase.
    """
    key1, key2, key3 = random.split(rng_key, 3)

    num_calls = 51342
    num_no_calls = 48658

    made_purchase_got_called = dist.Bernoulli(0.084).sample(key1, sample_shape=(num_calls,))
    made_purchase_no_calls = dist.Bernoulli(0.061).sample(key2, sample_shape=(num_no_calls,))

    made_purchase = jnp.concatenate([made_purchase_got_called, made_purchase_no_calls])

    is_female = dist.Bernoulli(0.5).sample(key3, sample_shape=(num_calls + num_no_calls,))
    got_called = jnp.concatenate([jnp.ones(num_calls), jnp.zeros(num_no_calls)])
    design_matrix = jnp.hstack([jnp.ones((num_no_calls + num_calls, 1)),
                               got_called.reshape(-1, 1),
                               is_female.reshape(-1, 1)])

    return design_matrix, made_purchase 
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:24,代码来源:proportion_test.py

示例4: _flatten

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import concatenate [as 别名]
def _flatten(self, matrix_tups):
        """Flatten everything and concatenate it together."""
        out_vecs = [v.flatten() for t in matrix_tups for v in t]
        return jnp.concatenate(out_vecs) 
开发者ID:HumanCompatibleAI,项目名称:imitation,代码行数:6,代码来源:tabular_irl.py

示例5: _flatten_batch

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import concatenate [as 别名]
def _flatten_batch(self, matrix_tups):
        """Flatten all except leading dim & concatenate results together in channel dim.

        (Channel dim is whatever the dim after the leading dim is)."""
        out_vecs = []
        for t in matrix_tups:
            for v in t:
                new_shape = (v.shape[0],)
                if len(v.shape) > 1:
                    new_shape = new_shape + (np.prod(v.shape[1:]),)
                out_vecs.append(v.reshape(new_shape))
        return jnp.concatenate(out_vecs, axis=1) 
开发者ID:HumanCompatibleAI,项目名称:imitation,代码行数:14,代码来源:tabular_irl.py

示例6: _cat

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import concatenate [as 别名]
def _cat(dim, *x):
    if len(x) == 1:
        return x[0]
    return np.concatenate(x, axis=dim) 
开发者ID:pyro-ppl,项目名称:funsor,代码行数:6,代码来源:ops.py

示例7: concat_elu

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import concatenate [as 别名]
def concat_elu(x, axis=-1):
    return elu(jnp.concatenate((x, -x), axis)) 
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:4,代码来源:pixelcnn.py

示例8: concatenate

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import concatenate [as 别名]
def concatenate(self, tensors, axis=-1):
        values = [self.coerce(v, dtype=self.floatx()) for v in tensors]
        return np.concatenate(values, axis=int(axis)) 
开发者ID:sharadmv,项目名称:deepx,代码行数:5,代码来源:jax.py

示例9: concat

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import concatenate [as 别名]
def concat(self, values, axis=-1):
        return self.concatenate(values, axis=axis) 
开发者ID:sharadmv,项目名称:deepx,代码行数:4,代码来源:jax.py

示例10: concatenate

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import concatenate [as 别名]
def concatenate(self, sequence, axis=0):
        """
        Join a sequence of arrays along an existing axis.

        Args:
            sequence: sequence of tensors
            axis: dimension along which to concatenate

        Returns:
            output: the concatenated tensor

        """
        return np.concatenate(sequence, axis=axis) 
开发者ID:scikit-hep,项目名称:pyhf,代码行数:15,代码来源:jax_backend.py

示例11: _ravel_list

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import concatenate [as 别名]
def _ravel_list(*leaves):
    leaves_metadata = tree_map(lambda l: pytree_metadata(
        jnp.ravel(l), jnp.shape(l), jnp.size(l), canonicalize_dtype(lax.dtype(l))), leaves)
    leaves_idx = jnp.cumsum(jnp.array((0,) + tuple(d.size for d in leaves_metadata)))

    def unravel_list(arr):
        return [jnp.reshape(lax.dynamic_slice_in_dim(arr, leaves_idx[i], m.size),
                            m.shape).astype(m.dtype)
                for i, m in enumerate(leaves_metadata)]

    flat = jnp.concatenate([m.flat for m in leaves_metadata]) if leaves_metadata else jnp.array([])
    return flat, unravel_list 
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:14,代码来源:util.py

示例12: inv

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import concatenate [as 别名]
def inv(self, y):
        z = matrix_to_tril_vec(y, diagonal=-1)
        return jnp.concatenate([z, jnp.log(jnp.diagonal(y, axis1=-2, axis2=-1))], axis=-1) 
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:5,代码来源:transforms.py

示例13: __call__

# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import concatenate [as 别名]
def __call__(self, x):
        z = jnp.concatenate([x[..., :1], jnp.exp(x[..., 1:])], axis=-1)
        return jnp.cumsum(z, axis=-1) 
开发者ID:pyro-ppl,项目名称:numpyro,代码行数:5,代码来源:transforms.py


注:本文中的jax.numpy.concatenate方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。