当前位置: 首页>>代码示例>>Python>>正文


Python torch.distributions.OneHotCategorical类代码示例

本文整理汇总了Python中torch.distributions.OneHotCategorical类的典型用法代码示例。如果您正苦于以下问题：Python distributions.OneHotCategorical类的具体用法？Python distributions.OneHotCategorical怎么用？Python distributions.OneHotCategorical使用的例子？那么恭喜您，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块torch.distributions的用法示例。


在下文中一共展示了distributions.OneHotCategorical类的7个代码示例，这些例子默认根据受欢迎程度排序。（注：示例6仅在已注释掉的代码中使用该类。）您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: generate

# 需要导入模块: from torch import distributions [as 别名]
# 或者: from torch.distributions import OneHotCategorical [as 别名]
def generate(self,
                 prior: torch.Tensor,
                 length=2048,
                 tf_board_writer: SummaryWriter = None,
                 *,
                 greedy: bool = False):
        """Autoregressively extend ``prior`` by ``length`` tokens.

        Each step runs the decoder on a sliding window of the most recent
        tokens. Once the window reaches ``config.threshold_len``, its oldest
        token is dropped. The step then applies ``self.fc`` and a softmax, and
        chooses the next token from the distribution at the last position.

        Args:
            prior: Seed token indices. ``size(1)`` is used as the sequence
                length, so presumably ``[batch, time]`` -- TODO confirm.
            length: Number of tokens to generate.
            tf_board_writer: Optional writer. When given, the per-step softmax
                output is logged via ``add_image("logits", ...)``.
            greedy: Keyword-only. If False (default), the next token is sampled
                from ``OneHotCategorical``. If True, the most probable token is
                taken. This replaces the previously unreachable ``u > 1`` branch.

        Returns:
            ``result_array[0]``: the first batch row of the prior followed by
            all generated tokens.
        """
        decode_array = prior
        result_array = prior
        # NOTE(review): debug output kept for backward compatibility;
        # consider routing through logging.
        print(config)
        print(length)
        for i in Bar('generating').iter(range(length)):
            # Slide the decoder window so it never exceeds threshold_len.
            if decode_array.size(1) >= config.threshold_len:
                decode_array = decode_array[:, 1:]
            # NOTE(review): the look-ahead mask is computed but not passed to
            # the decoder below (it is called with ``None``).
            _, _, look_ahead_mask = \
                utils.get_masked_with_pad_tensor(decode_array.size(1), decode_array, decode_array, pad_token=config.pad_token)

            result, _ = self.Decoder(decode_array, None)
            result = self.fc(result)
            result = result.softmax(-1)

            if tf_board_writer is not None:
                tf_board_writer.add_image("logits", result, global_step=i)

            # Distribution over the next token, taken at the last position.
            last_step_probs = result[:, -1]
            if greedy:
                next_token = last_step_probs.argmax(-1).to(decode_array.dtype)
            else:
                pdf = dist.OneHotCategorical(probs=last_step_probs)
                # Convert the one-hot sample back to a token index.
                next_token = pdf.sample().argmax(-1)
            next_token = next_token.unsqueeze(-1)
            # Append to both the decoder window and the full output, in both
            # modes (the old greedy branch forgot result_array).
            decode_array = torch.cat((decode_array, next_token), dim=-1)
            result_array = torch.cat((result_array, next_token), dim=-1)
            del look_ahead_mask
        return result_array[0]
开发者ID:jason9693,项目名称:MusicTransformer-pytorch,代码行数:38,代码来源:model.py

示例2: _hard_categorical

# 需要导入模块: from torch import distributions [as 别名]
# 或者: from torch.distributions import OneHotCategorical [as 别名]
def _hard_categorical(self, dist):
  """Return a ``OneHotCategorical`` sharing the logits of ``dist``.

  Args:
    dist: A distribution exposing a ``logits`` attribute (e.g. a
      ``Categorical``).

  Returns:
    ``torch.distributions.OneHotCategorical(logits=dist.logits)``.
  """
  # The parameter ``dist`` shadows any module-level ``dist`` alias for
  # ``torch.distributions``. ``dist.OneHotCategorical`` would therefore be
  # looked up on the distribution instance and fail, so import the class
  # explicitly instead.
  from torch.distributions import OneHotCategorical
  return OneHotCategorical(logits=dist.logits)
开发者ID:mjendrusch,项目名称:torchsupport,代码行数:4,代码来源:__init__.py

示例3: forward

# 需要导入模块: from torch import distributions [as 别名]
# 或者: from torch.distributions import OneHotCategorical [as 别名]
def forward(self, x, return_latents=False, feature_matching=False):
        """Score ``x`` and build the InfoGAN latent-code distributions.

        Args:
            x: Input batch passed to ``self.model``.
            return_latents: Affects only the third element of the returned
                tuple (see Returns).
            feature_matching: If True, return the ``self.model`` features
                immediately, without scoring.

        Returns:
            ``self.model(x)`` when ``feature_matching`` is True. Otherwise the
            3-tuple ``(critic_score, dist_dis, last)``, where ``dist_dis`` is a
            ``OneHotCategorical`` built from ``self.dis_categorical`` logits.
            ``last`` is the ``Normal`` built from ``self.cont_mean`` /
            ``self.cont_logvar`` when ``return_latents`` is True, and
            ``critic_score`` otherwise.

            NOTE(review): a 3-tuple is returned even when ``return_latents`` is
            False. If the intent was to return only the score in that case, the
            conditional would need to cover the whole tuple. The behaviour is
            kept because callers may unpack three values.
        """
        features = self.model(x)
        if feature_matching is True:
            return features

        critic_score = self.disc(features)
        # The width given to ``view`` is the channel count of ``features``
        # (before dist_conv), matching the original one-liner.
        flat = self.dist_conv(features).view(-1, features.size(1))

        dis_logits = self.dis_categorical(flat)
        dist_dis = distributions.OneHotCategorical(logits=dis_logits)

        cont_loc = self.cont_mean(flat)
        # exp(0.5 * v) is the standard deviation if v is a log-variance.
        cont_scale = torch.exp(0.5 * self.cont_logvar(flat))
        dist_cont = distributions.Normal(loc=cont_loc, scale=cont_scale)

        if return_latents is True:
            last = dist_cont
        else:
            last = critic_score
        return critic_score, dist_dis, last
开发者ID:torchgan,项目名称:torchgan,代码行数:17,代码来源:infogan.py

示例4: test_infogan_discriminator

# 需要导入模块: from torch import distributions [as 别名]
# 或者: from torch.distributions import OneHotCategorical [as 别名]
def test_infogan_discriminator(self):
        """Check InfoGANDiscriminator outputs for two configurations.

        For each configuration, verify the critic-score shape, the latent
        distribution types and the sample shapes for a batch of 10 inputs.
        """
        # Each tuple holds the value for configuration 0 and configuration 1.
        configurations = zip(
            (3, 4),                         # channels
            (32, 64),                       # in_size
            (10, 20),                       # dim_cont
            (30, 40),                       # dim_dis
            (64, 128),                      # step
            (True, False),                  # batchnorm
            (None, torch.nn.ELU(0.5)),      # nonlinearity
            (None, torch.nn.RReLU(0.25)),   # last_nonlinearity
        )
        for (n_channels, size, n_cont, n_dis, n_step,
             use_bn, act, last_act) in configurations:
            x = torch.randn(10, n_channels, size, size)
            D = InfoGANDiscriminator(
                n_dis, n_cont, size, n_channels, n_step, use_bn, act, last_act
            )
            score, dist_dis, dist_cont = D(x, True)
            assert score.shape == (10, 1, 1, 1)
            assert isinstance(dist_dis, distributions.OneHotCategorical)
            assert isinstance(dist_cont, distributions.Normal)
            assert dist_dis.sample().shape == (10, n_dis)
            assert dist_cont.sample().shape == (10, n_cont)
开发者ID:torchgan,项目名称:torchgan,代码行数:29,代码来源:test_models.py

示例5: sample

# 需要导入模块: from torch import distributions [as 别名]
# 或者: from torch.distributions import OneHotCategorical [as 别名]
def sample(self, params):
        """Draw one action per row from a Gaussian mixture.

        A single component is chosen per row by sampling a one-hot mask from
        ``OneHotCategorical(probs=params['pi'])``. Every component gets a
        Gaussian draw ``mean + exp(log_std) * eps``, and the mask keeps only
        the chosen one by summing over dimension 1.
        Shapes are presumably ``pi: [B, K]`` and ``mean``/``log_std``:
        ``[B, K, D]``, giving a ``[B, D]`` result -- TODO confirm with callers.
        """
        pi = params['pi']
        mean = params['mean']
        log_std = params['log_std']
        # Component selection is drawn before the Gaussian noise (RNG order).
        component_mask = OneHotCategorical(probs=pi).sample().unsqueeze(-1)
        noisy = mean + torch.randn_like(mean) * torch.exp(log_std)
        return (noisy * component_mask).sum(dim=1)
开发者ID:DeepX-inc,项目名称:machina,代码行数:8,代码来源:mixture_gaussian_pd.py

示例6: _create_data

# 需要导入模块: from torch import distributions [as 别名]
# 或者: from torch.distributions import OneHotCategorical [as 别名]
def _create_data(self, rotate=True):
        """Fill ``self.data`` with samples from a jittered 2-D grid of Gaussians.

        The method builds a ``self.width x self.width`` grid of centres that
        spans ``[-self.bound, self.bound]`` on both axes. Each coordinate is
        shifted by a uniform jitter of ``1e-3 * U[0, 1)``. ``self.num_points``
        centres are then picked uniformly with replacement, and isotropic
        Gaussian noise scaled by ``self.std`` is added. With ``rotate``, the
        points are right-multiplied by a 45-degree rotation matrix. The result
        is stored as a float32 ``torch.Tensor`` of shape
        ``(self.num_points, 2)``. All randomness comes from NumPy's global RNG.
        """
        axis = np.linspace(-self.bound, self.bound, self.width)
        grid_x, grid_y = np.meshgrid(axis, axis, indexing='ij')
        centres = np.stack((grid_x.ravel(), grid_y.ravel()), axis=1)
        n_centres = self.width ** 2
        # Row-major (n, 2) draws consume the RNG in the same x-then-y order
        # as two scalar draws per grid point.
        centres = centres + 1e-3 * np.random.rand(n_centres, 2)

        scale = self.std * np.eye(2)
        picks = np.random.choice(range(n_centres), size=self.num_points, replace=True)
        points = centres[picks] + np.random.randn(self.num_points, 2) @ scale

        if rotate:
            c = 1 / np.sqrt(2)
            points = points @ np.array([[c, -c], [c, c]])
        self.data = torch.Tensor(points.astype(np.float32))
开发者ID:bayesiains,项目名称:nsf,代码行数:50,代码来源:plane.py

示例7: _sample_batch_from_proposal

# 需要导入模块: from torch import distributions [as 别名]
# 或者: from torch.distributions import OneHotCategorical [as 别名]
def _sample_batch_from_proposal(self, batch_size,
                                    return_log_density_of_samples=False):
        """Draw ``batch_size`` samples from the proposal, one dimension at a time.

        Each input dimension ``dim`` is handled in order:

        - The autoregressive net is run on the samples drawn so far.
        - The proposal parameters for ``dim`` are read from its output.
        - A proposal distribution is built and sampled, and the draws are
          written into column ``dim`` of ``samples``.

        When ``self.Component`` is set, the proposal is a
        ``distributions_.MixtureSameFamily`` of ``OneHotCategorical`` mixture
        weights and ``self.Component(loc, scale)``. Otherwise it is
        ``Uniform(-4, 4)``.

        Side effect: ``self.proposal`` is overwritten on every iteration, so it
        holds the last dimension's proposal when the method returns.

        Args:
            batch_size: Number of samples (rows) to draw.
            return_log_density_of_samples: If True, also return the sum over
                dimensions of each sample's proposal log-density.

        Returns:
            ``samples`` of shape ``[batch_size, input_dim]``. If
            ``return_log_density_of_samples`` is True, instead returns
            ``(samples, log_density)`` with ``log_density`` of shape
            ``[batch_size]``.
        """
        # One pass through the autoregressive net per input dimension.
        # NOTE(review): both buffers use torch.zeros defaults (default
        # device/dtype), independent of the net's device -- confirm intended.
        samples = torch.zeros(batch_size, self.autoregressive_net.input_dim)
        log_density_of_samples = torch.zeros(batch_size,
                                             self.autoregressive_net.input_dim)
        for dim in range(self.autoregressive_net.input_dim):
            # compute autoregressive outputs, reshaped to
            # [B, self.dim, output_dim_multiplier]
            autoregressive_outputs = self.autoregressive_net(samples).reshape(-1,
                                                                              self.dim,
                                                                              self.autoregressive_net.output_dim_multiplier)

            # grab proposal params for dimension `dim`, skipping the first
            # context_dim entries -> [B, output_dim_multiplier - context_dim]
            proposal_params = autoregressive_outputs[..., dim, self.context_dim:]

            # split into mixture logits, locs, and scales
            # (M = n_proposal_mixture_components)
            logits = proposal_params[...,
                     :self.n_proposal_mixture_components]  # [B, M]
            # NOTE(review): with a single row, logits are reshaped to
            # (self.dim, M); the element count only matches when
            # self.dim == 1 -- confirm intent.
            if logits.shape[0] == 1:
                logits = logits.reshape(self.dim, self.n_proposal_mixture_components)
            locs = proposal_params[...,
                   self.n_proposal_mixture_components:(
                           2 * self.n_proposal_mixture_components)]  # [B, M]
            # scales use all remaining columns after 2*M; presumably exactly M
            # of them -- TODO confirm output_dim_multiplier layout.
            scales = self.mixture_component_min_scale + self.scale_activation(
                proposal_params[...,
                (2 * self.n_proposal_mixture_components):])  # [B, remaining]

            # create proposal
            if self.Component is not None:
                mixture_distribution = distributions.OneHotCategorical(
                    logits=logits,
                    validate_args=True
                )
                components_distribution = self.Component(loc=locs, scale=scales)
                self.proposal = distributions_.MixtureSameFamily(
                    mixture_distribution=mixture_distribution,
                    components_distribution=components_distribution
                )
                # original annotation: [S, B, D]; exact shape depends on the
                # project's MixtureSameFamily -- TODO confirm (permute below
                # requires a 3-D tensor)
                proposal_samples = self.proposal.sample((1,))  # [S, B, D]

            else:
                # scalar low/high, so this yields shape [1, batch_size, 1]
                self.proposal = distributions.Uniform(low=-4, high=4)
                proposal_samples = self.proposal.sample(
                    (1, batch_size, 1)
                )
            proposal_samples = proposal_samples.permute(1, 2, 0)  # [B, D, S]
            proposal_log_density = self.proposal.log_prob(proposal_samples)
            # detach: the stored samples/log-densities carry no gradient
            log_density_of_samples[:, dim] += proposal_log_density.reshape(-1).detach()
            samples[:, dim] += proposal_samples.reshape(-1).detach()

        if return_log_density_of_samples:
            return samples, torch.sum(log_density_of_samples, dim=-1)
        else:
            return samples 
开发者ID:conormdurkan,项目名称:autoregressive-energy-machines,代码行数:56,代码来源:aem.py


注:本文中的torch.distributions.OneHotCategorical方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。