本文整理汇总了Python中jax.numpy.concatenate方法的典型用法代码示例。如果您正苦于以下问题：Python jax.numpy.concatenate方法的具体用法是什么？jax.numpy.concatenate怎么用？有哪些使用示例？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块jax.numpy的用法示例。
在下文中一共展示了jax.numpy.concatenate方法的13个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: GRUCell
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import concatenate [as 别名]
def GRUCell(carry_size, param_init):
    """Build a bias-free GRU cell as a ``(gru_cell, carry_init)`` pair.

    ``gru_cell(carry, x)`` is a ``@parametrized`` module that creates three
    kernels, ``update_kernel``, ``reset_kernel`` and ``compute_kernel``. Each
    kernel has shape ``(x.shape[1] + carry_size, carry_size)`` and is
    initialised with ``param_init``. The cell returns the new carry twice,
    as ``(new_carry, output)``.

    ``carry_init(batch_size)`` returns a zero carry of shape
    ``(batch_size, carry_size)``.
    """

    @parametrized
    def gru_cell(carry, x):
        # Inputs and carry are joined along axis 1, so every kernel maps
        # (input features + carry_size) -> carry_size.
        kernel_shape = (x.shape[1] + carry_size, carry_size)

        def kernel(name):
            return parameter(kernel_shape, param_init, name)

        # NOTE: kernels are requested in the order update -> reset -> compute.
        # Keep this order; parameter registration may depend on it.
        x_and_carry = np.concatenate((x, carry), axis=1)
        update_gate = sigmoid(np.dot(x_and_carry, kernel('update_kernel')))
        reset_gate = sigmoid(np.dot(x_and_carry, kernel('reset_kernel')))

        # The candidate state only sees the part of the carry the reset gate lets through.
        x_and_reset_carry = np.concatenate((x, reset_gate * carry), axis=1)
        candidate = np.tanh(np.dot(x_and_reset_carry, kernel('compute_kernel')))

        # Interpolate between the candidate and the previous carry.
        new_carry = update_gate * candidate + (1 - update_gate) * carry
        return new_carry, new_carry

    def carry_init(batch_size):
        return np.zeros((batch_size, carry_size))

    return gru_cell, carry_init
示例2: test_submodule_order
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import concatenate [as 别名]
def test_submodule_order():
    """Parameter order must be decided by the first call of each submodule.

    ``p`` is called first (``a = p()``) and called again at the very end
    (``k = p()``). That reuse must not move it behind the ``parameter(...)``
    submodules created in between.
    """

    @parametrized
    def net():
        p = Parameter(lambda key: jnp.zeros((1,)))
        a = p()
        b = parameter((2,), zeros)
        c = parameter((3,), zeros)
        d = parameter((4,), zeros)
        e = parameter((5,), zeros)
        f = parameter((6,), zeros)
        # must not mess up order (decided by first submodule call):
        k = p()
        # Each concatenation has 7 elements (1+6, 2+5, 3+4). The length-1
        # ``k`` broadcasts, so the result has shape (7,).
        return jnp.concatenate([a, f]) + jnp.concatenate([b, e]) + jnp.concatenate([c, d]) + k

    params = net.init_parameters(key=PRNGKey(0))
    # ``parameter0`` is expected to be p's value. Comparing zeros((1,)) with a
    # longer array would broadcast to a multi-element result, whose truth value
    # is ambiguous. So this assert also requires parameter0 to be 1-element.
    assert jnp.zeros((1,)) == params.parameter0
    out = net.apply(params)
    assert (7,) == out.shape
示例3: make_dataset
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import concatenate [as 别名]
def make_dataset(rng_key) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Simulate a sales-call dataset of 51342 + 48658 = 100000 customers.

    Called customers purchase with probability 0.084 and the others with
    0.061, so a sales call adds roughly 2 percentage points.

    Returns:
      ``(design_matrix, made_purchase)``. ``design_matrix`` has the columns
      ``[intercept (all ones), got_called, is_female]``. The first 51342 rows
      are the called customers. ``made_purchase`` holds the simulated
      Bernoulli outcomes in the same row order.
    """
    called_key, uncalled_key, gender_key = random.split(rng_key, 3)
    n_called, n_uncalled = 51342, 48658
    n_total = n_called + n_uncalled

    # Outcomes for called customers come first, then for the rest.
    purchases = jnp.concatenate([
        dist.Bernoulli(0.084).sample(called_key, sample_shape=(n_called,)),
        dist.Bernoulli(0.061).sample(uncalled_key, sample_shape=(n_uncalled,)),
    ])
    is_female = dist.Bernoulli(0.5).sample(gender_key, sample_shape=(n_total,))
    got_called = jnp.concatenate([jnp.ones(n_called), jnp.zeros(n_uncalled)])

    # column_stack turns each 1-D vector into an (N, 1) column before joining.
    design_matrix = jnp.column_stack([jnp.ones(n_total), got_called, is_female])
    return design_matrix, purchases
示例4: _flatten
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import concatenate [as 别名]
def _flatten(self, matrix_tups):
    """Ravel every array in ``matrix_tups`` and join them into one 1-D array.

    ``matrix_tups`` is an iterable of tuples of arrays. The arrays are
    visited in order, tuple by tuple.
    """
    pieces = []
    for group in matrix_tups:
        pieces.extend(arr.flatten() for arr in group)
    return jnp.concatenate(pieces)
示例5: _flatten_batch
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import concatenate [as 别名]
def _flatten_batch(self, matrix_tups):
    """Flatten all except leading dim & concatenate results together in channel dim.

    Every array ``v`` in every tuple of ``matrix_tups`` is reshaped to
    ``(v.shape[0], prod(v.shape[1:]))``. The results are then concatenated
    along axis 1, the channel dim (whatever dim follows the leading dim).
    A 1-D array of shape ``(B,)`` becomes a single ``(B, 1)`` column. All
    arrays must share the same leading dimension.

    Args:
      matrix_tups: iterable of tuples of arrays with a common leading dim.

    Returns:
      A 2-D array of shape ``(B, total_channels)``.
    """
    out_vecs = []
    for t in matrix_tups:
        for v in t:
            # Always produce a 2-D (B, C) array. Previously a 1-D input stayed
            # (B,) and made the axis=1 concatenation below fail. np.prod(())
            # is the float 1.0, hence the int().
            channels = int(np.prod(v.shape[1:]))
            out_vecs.append(v.reshape((v.shape[0], channels)))
    return jnp.concatenate(out_vecs, axis=1)
示例6: _cat
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import concatenate [as 别名]
def _cat(dim, *x):
    """Concatenate the arrays in ``x`` along axis ``dim``.

    A single array is returned unchanged (the same object, no copy).
    """
    if len(x) != 1:
        return np.concatenate(x, axis=dim)
    (only,) = x
    return only
示例7: concat_elu
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import concatenate [as 别名]
def concat_elu(x, axis=-1):
    """Apply ``elu`` to ``x`` joined with its negation along ``axis``.

    The result is twice as long as ``x`` along ``axis``. Assuming ``elu`` is
    elementwise, the first half is ``elu(x)`` and the second half ``elu(-x)``.
    """
    mirrored = jnp.concatenate([x, -x], axis=axis)
    return elu(mirrored)
示例8: concatenate
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import concatenate [as 别名]
def concatenate(self, tensors, axis=-1):
    """Coerce each tensor to ``self.floatx()`` and concatenate along ``axis``.

    ``axis`` is passed through ``int()`` before use.
    """
    coerced = []
    for tensor in tensors:
        coerced.append(self.coerce(tensor, dtype=self.floatx()))
    return np.concatenate(coerced, axis=int(axis))
示例9: concat
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import concatenate [as 别名]
def concat(self, values, axis=-1):
    """Alias that forwards ``values`` and ``axis`` to ``self.concatenate``."""
    join = self.concatenate
    return join(values, axis=axis)
示例10: concatenate
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import concatenate [as 别名]
def concatenate(self, sequence, axis=0):
    """Join a sequence of arrays along an existing axis.

    Args:
      sequence: sequence of tensors to join. They must agree in shape
        except along ``axis``.
      axis: dimension along which to concatenate (default 0).

    Returns:
      The concatenated tensor.
    """
    joined = np.concatenate(sequence, axis=axis)
    return joined
示例11: _ravel_list
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import concatenate [as 别名]
def _ravel_list(*leaves):
    """Flatten ``leaves`` into one 1-D array and return a matching inverse.

    Returns:
      ``(flat, unravel_list)``.

      ``flat`` is the concatenation of every leaf raveled to 1-D, or
      ``jnp.array([])`` when there are no leaves.

      ``unravel_list(arr)`` slices a flat array back into a list of arrays,
      restoring each leaf's original shape and (canonicalized) dtype.
    """
    # One pytree_metadata record per leaf, built from (raveled values, shape,
    # size, canonicalized dtype). Fields are read below as .flat/.shape/.size/.dtype.
    # NOTE(review): ``tree_map`` walks ``leaves`` as a pytree, so this assumes
    # every entry is a plain array leaf. Confirm against callers.
    leaves_metadata = tree_map(lambda l: pytree_metadata(
        jnp.ravel(l), jnp.shape(l), jnp.size(l), canonicalize_dtype(lax.dtype(l))), leaves)
    # Start offset of each leaf in the flat vector: cumsum of (0, size_0, size_1, ...).
    leaves_idx = jnp.cumsum(jnp.array((0,) + tuple(d.size for d in leaves_metadata)))

    def unravel_list(arr):
        # Take m.size elements starting at leaf i's offset, then restore shape
        # and dtype (concatenating mixed dtypes into ``flat`` may promote them).
        return [jnp.reshape(lax.dynamic_slice_in_dim(arr, leaves_idx[i], m.size),
                            m.shape).astype(m.dtype)
                for i, m in enumerate(leaves_metadata)]

    # jnp.concatenate rejects an empty list, hence the explicit empty fallback.
    flat = jnp.concatenate([m.flat for m in leaves_metadata]) if leaves_metadata else jnp.array([])
    return flat, unravel_list
示例12: inv
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import concatenate [as 别名]
def inv(self, y):
    """Map a square matrix ``y`` (last two axes) to a vector.

    The vector is ``matrix_to_tril_vec(y, diagonal=-1)`` (presumably the
    entries below the main diagonal) followed by the log of ``y``'s diagonal,
    joined along the last axis.
    """
    below_diagonal = matrix_to_tril_vec(y, diagonal=-1)
    log_diagonal = jnp.log(jnp.diagonal(y, axis1=-2, axis2=-1))
    return jnp.concatenate((below_diagonal, log_diagonal), axis=-1)
示例13: __call__
# 需要导入模块: from jax import numpy [as 别名]
# 或者: from jax.numpy import concatenate [as 别名]
def __call__(self, x):
    """Map ``x`` along its last axis to an increasing sequence.

    The first element is kept as-is. Every later element is exponentiated,
    making it a positive increment, and a cumulative sum is then taken.
    """
    first = x[..., :1]
    increments = jnp.exp(x[..., 1:])
    return jnp.cumsum(jnp.concatenate((first, increments), axis=-1), axis=-1)