本文整理汇总了Python中torch.distributions.OneHotCategorical方法的典型用法代码示例。如果您正苦于以下问题：Python distributions.OneHotCategorical方法的具体用法？Python distributions.OneHotCategorical怎么用？Python distributions.OneHotCategorical使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块torch.distributions的用法示例。
在下文中一共展示了distributions.OneHotCategorical方法的7个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: generate
# 需要导入模块: from torch import distributions [as 别名]
# 或者: from torch.distributions import OneHotCategorical [as 别名]
def generate(self,
             prior: torch.Tensor,
             length=2048,
             tf_board_writer: SummaryWriter = None):
    """Autoregressively sample ``length`` tokens that continue ``prior``.

    At every step the decoder output is projected with ``self.fc`` and
    softmaxed. A token is then drawn from the distribution at the last
    position with ``OneHotCategorical`` and appended to the sequence. The
    decoder input is a sliding window: once it reaches
    ``config.threshold_len`` tokens, the oldest token is dropped before
    decoding.

    Args:
        prior: priming token ids. Dim 1 is treated as the sequence axis;
            presumably shape ``(batch, seq)`` -- TODO confirm.
        length: number of tokens to generate.
        tf_board_writer: optional writer. When given, the per-step softmax
            is logged with ``add_image`` under the tag ``"logits"``.

    Returns:
        ``prior`` followed by every generated token, for the first batch
        element only (index 0 of the batch axis).
    """
    decode_array = prior
    result_array = prior
    # Bar(...) is a progress-bar wrapper around the step loop.
    for step in Bar('generating').iter(range(length)):
        # Sliding context window: keep the decoder input below
        # config.threshold_len by discarding the oldest token.
        if decode_array.size(1) >= config.threshold_len:
            decode_array = decode_array[:, 1:]
        # NOTE(review): the decoder receives None as its second argument.
        # A look-ahead mask used to be built here on every step and then
        # thrown away. That unused computation was removed. Confirm whether
        # self.Decoder needs a causal mask for correct decoding.
        hidden, _ = self.Decoder(decode_array, None)
        probs = self.fc(hidden).softmax(-1)
        if tf_board_writer is not None:
            tf_board_writer.add_image("logits", probs, global_step=step)
        # Draw the next token from the distribution at the last position.
        # Cast it to the running sequence's dtype so torch.cat stays
        # dtype-consistent with `prior`.
        pdf = dist.OneHotCategorical(probs=probs[:, -1])
        next_token = pdf.sample().argmax(-1).unsqueeze(-1).to(decode_array.dtype)
        decode_array = torch.cat((decode_array, next_token), dim=-1)
        result_array = torch.cat((result_array, next_token), dim=-1)
    return result_array[0]
示例2: _hard_categorical
# 需要导入模块: from torch import distributions [as 别名]
# 或者: from torch.distributions import OneHotCategorical [as 别名]
def _hard_categorical(self, dist):
    """Return a ``OneHotCategorical`` sharing ``dist``'s logits.

    Args:
        dist: a distribution object exposing a ``logits`` tensor, for
            example ``torch.distributions.Categorical``.

    Returns:
        ``torch.distributions.OneHotCategorical(logits=dist.logits)``.
    """
    # Inside this function the name ``dist`` is the argument, not the
    # torch.distributions module alias. ``dist.OneHotCategorical`` would
    # therefore be looked up on the distribution instance and raise
    # AttributeError, so import the class explicitly.
    from torch.distributions import OneHotCategorical
    return OneHotCategorical(logits=dist.logits)
示例3: forward
# 需要导入模块: from torch import distributions [as 别名]
# 或者: from torch.distributions import OneHotCategorical [as 别名]
def forward(self, x, return_latents=False, feature_matching=False):
    """Score ``x`` and build the InfoGAN latent-code distributions.

    Args:
        x: input batch fed to ``self.model``.
        return_latents: only the literal ``True`` (identity check) puts the
            continuous-code distribution into the third result slot.
        feature_matching: when literally ``True``, the ``self.model``
            features are returned immediately.

    Returns:
        The features if ``feature_matching is True``. Otherwise the 3-tuple
        ``(critic_score, dist_dis, third)``:

        - ``dist_dis`` is a ``OneHotCategorical`` over the discrete code.
        - ``third`` is the ``Normal`` over the continuous code when
          ``return_latents is True``, and ``critic_score`` again otherwise.

    NOTE(review): a 3-tuple comes back even when ``return_latents`` is
    False, because the conditional only selects the third element. This
    reads like an operator-precedence slip for
    ``(critic_score, dist_dis, dist_cont) if return_latents else
    critic_score``. Confirm with callers before changing the return shape.
    """
    features = self.model(x)
    if feature_matching is True:
        return features
    critic_score = self.disc(features)
    # Flatten the dist_conv output using the channel count of the features
    # *before* dist_conv.
    conv_out = self.dist_conv(features)
    latent_in = conv_out.view(-1, features.size(1))
    dist_dis = distributions.OneHotCategorical(
        logits=self.dis_categorical(latent_in)
    )
    mean = self.cont_mean(latent_in)
    # exp(0.5 * v) is used as the standard deviation, i.e. v is a log-variance.
    std = torch.exp(0.5 * self.cont_logvar(latent_in))
    dist_cont = distributions.Normal(loc=mean, scale=std)
    third = dist_cont if return_latents is True else critic_score
    return critic_score, dist_dis, third
示例4: test_infogan_discriminator
# 需要导入模块: from torch import distributions [as 别名]
# 或者: from torch.distributions import OneHotCategorical [as 别名]
def test_infogan_discriminator(self):
    """Check InfoGANDiscriminator output shapes and latent distribution types.

    Two configurations are built from the parallel parameter lists below.
    For each, a batch of 10 random images is run with ``return_latents``
    set to ``True``, and the critic score and both latent distributions are
    checked.
    """
    configurations = zip(
        [3, 4],                         # channels
        [32, 64],                       # in_size
        [10, 20],                       # dim_cont
        [30, 40],                       # dim_dis
        [64, 128],                      # step
        [True, False],                  # batchnorm
        [None, torch.nn.ELU(0.5)],      # nonlinearities
        [None, torch.nn.RReLU(0.25)],   # last_nonlinearity
    )
    for (n_channels, size, n_cont, n_dis, step_size,
         use_batchnorm, nonlinearity, final_nonlinearity) in configurations:
        images = torch.randn(10, n_channels, size, size)
        discriminator = InfoGANDiscriminator(
            n_dis,
            n_cont,
            size,
            n_channels,
            step_size,
            use_batchnorm,
            nonlinearity,
            final_nonlinearity,
        )
        score, latent_dis, latent_cont = discriminator(images, True)
        assert score.shape == (10, 1, 1, 1)
        assert isinstance(latent_dis, distributions.OneHotCategorical)
        assert isinstance(latent_cont, distributions.Normal)
        assert latent_dis.sample().shape == (10, n_dis)
        assert latent_cont.sample().shape == (10, n_cont)
示例5: sample
# 需要导入模块: from torch import distributions [as 别名]
# 或者: from torch.distributions import OneHotCategorical [as 别名]
def sample(self, params):
    """Draw one action per entry from a Gaussian mixture.

    Args:
        params: mapping with keys ``'pi'``, ``'mean'`` and ``'log_std'``.

            - ``pi`` is passed to ``OneHotCategorical`` as ``probs``, so it
              must hold mixture probabilities with the component axis last,
              shape ``(..., K)``.
            - ``mean`` and ``log_std`` must have shape ``(..., K, D)``, so
              that ``pi``'s one-hot sample, unsqueezed to ``(..., K, 1)``,
              broadcasts against them.

    Returns:
        Tensor of shape ``(..., D)``. Each entry is a sample
        ``mean + eps * exp(log_std)`` (``eps`` standard normal) taken from
        the component selected by the one-hot draw.
    """
    pi, mean, log_std = params['pi'], params['mean'], params['log_std']
    # One-hot component choice, shape (..., K).
    chosen = OneHotCategorical(probs=pi).sample()
    # Sample every component, then mask out all but the chosen one.
    per_component = mean + torch.randn_like(mean) * torch.exp(log_std)
    # The mixture axis is the second-to-last axis of `mean`. Reducing over
    # dim=-2 instead of the fixed dim=1 gives the same result for (B, K, D)
    # inputs and also handles unbatched (K, D) or extra leading batch dims.
    return torch.sum(per_component * chosen.unsqueeze(-1), dim=-2)
示例6: _create_data
# 需要导入模块: from torch import distributions [as 别名]
# 或者: from torch.distributions import OneHotCategorical [as 别名]
def _create_data(self, rotate=True):
    """Populate ``self.data`` with a noisy 2-D grid of Gaussian blobs.

    Construction:

    1. Build a ``self.width`` x ``self.width`` grid of centres spanning
       ``[-self.bound, self.bound]`` on both axes. Each coordinate is
       nudged by a uniform jitter in ``[0, 1e-3)``.
    2. Pick ``self.num_points`` centres uniformly with replacement.
    3. Add isotropic Gaussian noise with standard deviation ``self.std``.
    4. If ``rotate`` is true, right-multiply the row vectors by a 45-degree
       rotation matrix (a clockwise turn).

    The result is stored as a float32 tensor of shape
    ``(self.num_points, 2)``. All randomness comes from NumPy's global RNG,
    so seeding ``np.random`` makes the output reproducible.
    """
    # Reference formulation with torch.distributions, kept for comparison
    # and not executed. A uniform OneHotCategorical over the width**2
    # centres selects one MultivariateNormal component per point:
    #
    #   probs = (1 / self.width**2) * torch.ones(self.width**2)
    #   mixture = distributions.OneHotCategorical(probs=probs)
    #   components = distributions.MultivariateNormal(
    #       loc=means, covariance_matrix=self.std**2 * torch.eye(2)[None, ...]
    #       .repeat(self.width**2, 1, 1))
    #   mask = mixture.sample((self.num_points,))[..., None].repeat(1, 1, 2)
    #   samples = components.sample((self.num_points,))
    #   data = torch.sum(mask * samples, dim=-2)
    n_centres = self.width ** 2
    axis_ticks = np.linspace(-self.bound, self.bound, self.width)
    grid_x, grid_y = np.meshgrid(axis_ticks, axis_ticks, indexing='ij')
    # Centres are enumerated x-major (x outer, y inner). Each centre takes
    # one uniform draw for x, then one for y, so a single (n_centres, 2)
    # draw consumes the global RNG in exactly that order.
    jitter = np.random.rand(n_centres, 2)
    centres = np.stack((grid_x.ravel(), grid_y.ravel()), axis=1) + 1e-3 * jitter
    # Uniform centre choice per point, plus noise shaped by the factor std * I.
    picks = np.random.choice(n_centres, size=self.num_points, replace=True)
    noise = np.random.randn(self.num_points, 2)
    points = centres[picks] + noise @ (self.std * np.eye(2))
    if rotate:
        inv_sqrt2 = 1 / np.sqrt(2)
        rotation = np.array([[inv_sqrt2, -inv_sqrt2],
                             [inv_sqrt2, inv_sqrt2]])
        points = points @ rotation
    self.data = torch.Tensor(points.astype(np.float32))
示例7: _sample_batch_from_proposal
# 需要导入模块: from torch import distributions [as 别名]
# 或者: from torch.distributions import OneHotCategorical [as 别名]
def _sample_batch_from_proposal(self, batch_size,
                                return_log_density_of_samples=False):
    """Draw ``batch_size`` samples from the autoregressive proposal.

    Dimensions are filled one at a time. For each ``dim`` in
    ``range(self.autoregressive_net.input_dim)``:

    1. Re-run the autoregressive net on the partially filled ``samples``.
    2. Read the proposal parameters for that dimension from its output.
    3. Sample one value per row.
    4. Write the value and its detached proposal log-density into
       column ``dim``.

    If ``self.Component`` is not None, each dimension's proposal is a
    ``distributions_.MixtureSameFamily`` with ``OneHotCategorical`` mixing
    weights and ``self.Component(loc=..., scale=...)`` components.
    Otherwise it is ``Uniform(-4, 4)``: the net is still run, but its output
    does not influence the sample. The last proposal built is left in
    ``self.proposal``.

    Args:
        batch_size: number of samples (rows) to draw.
        return_log_density_of_samples: if True, also return each sample's
            proposal log-density summed over dimensions.

    Returns:
        ``samples`` of shape ``(batch_size, input_dim)``, or the tuple
        ``(samples, log_density)`` where ``log_density`` has shape
        ``(batch_size,)``.
    """
    # one pass through the autoregressive net per input dimension
    # NOTE(review): created on the default device/dtype, not necessarily the
    # net's device -- confirm for GPU use.
    samples = torch.zeros(batch_size, self.autoregressive_net.input_dim)
    log_density_of_samples = torch.zeros(batch_size,
                                         self.autoregressive_net.input_dim)
    for dim in range(self.autoregressive_net.input_dim):
        # compute autoregressive outputs
        # reshaped to [B, self.dim, output_dim_multiplier]; B is presumably
        # batch_size -- TODO confirm
        autoregressive_outputs = self.autoregressive_net(samples).reshape(-1,
                                                                          self.dim,
                                                                          self.autoregressive_net.output_dim_multiplier)
        # grab proposal params for the dth dimension (context features
        # [:context_dim] are skipped); result is 2-D: [B, multiplier - context_dim]
        proposal_params = autoregressive_outputs[..., dim, self.context_dim:]
        # make mixture coefficients, locs, and scales for proposal
        logits = proposal_params[...,
                                 :self.n_proposal_mixture_components]  # [B, M]
        # NOTE(review): logits is [1, M] here, so this reshape only succeeds
        # when self.dim == 1 (and is then a no-op); locs/scales are not
        # reshaped alongside it -- confirm intent.
        if logits.shape[0] == 1:
            logits = logits.reshape(self.dim, self.n_proposal_mixture_components)
        locs = proposal_params[...,
                               self.n_proposal_mixture_components:(
                                   2 * self.n_proposal_mixture_components)]  # [B, M]
        # scale_activation output is offset by mixture_component_min_scale so
        # scales never go below that floor (assuming the activation is
        # non-negative -- TODO confirm). Columns from 2*M onward are used,
        # i.e. [B, M] if the params hold exactly 3*M entries.
        scales = self.mixture_component_min_scale + self.scale_activation(
            proposal_params[...,
                            (2 * self.n_proposal_mixture_components):])  # [B, M]
        # create proposal
        if self.Component is not None:
            # validate_args=True makes torch check the logits at construction.
            mixture_distribution = distributions.OneHotCategorical(
                logits=logits,
                validate_args=True
            )
            components_distribution = self.Component(loc=locs, scale=scales)
            self.proposal = distributions_.MixtureSameFamily(
                mixture_distribution=mixture_distribution,
                components_distribution=components_distribution
            )
            # a single draw (leading sample dim of size 1); exact trailing
            # shape depends on distributions_.MixtureSameFamily -- verify
            proposal_samples = self.proposal.sample((1,))  # [S, B, D]
        else:
            self.proposal = distributions.Uniform(low=-4, high=4)
            # shape (1, batch_size, 1)
            proposal_samples = self.proposal.sample(
                (1, batch_size, 1)
            )
        proposal_samples = proposal_samples.permute(1, 2, 0)  # [B, D, S]
        proposal_log_density = self.proposal.log_prob(proposal_samples)
        # Each column starts at zero and is written exactly once, so these
        # `+=` act as assignments. Both reshape(-1) calls assume exactly
        # batch_size values (one per row). Values are detached, so no
        # gradient flows back through the sampled path.
        log_density_of_samples[:, dim] += proposal_log_density.reshape(-1).detach()
        samples[:, dim] += proposal_samples.reshape(-1).detach()
    if return_log_density_of_samples:
        # total log-density per sample = sum over dimensions
        return samples, torch.sum(log_density_of_samples, dim=-1)
    else:
        return samples