本文整理匯總了Python中torch.distributions.categorical.Categorical._new方法的典型用法代碼示例。如果您正苦於以下問題:Python Categorical._new方法的具體用法?Python Categorical._new怎麼用?Python Categorical._new使用的例子?那麼,這裡精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在類torch.distributions.categorical.Categorical的用法示例。
在下文中一共展示了Categorical._new方法的3個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼按讚,您的評價將有助於系統推薦出更棒的Python代碼示例。
示例1: ExpRelaxedCategorical
# 需要導入模塊: from torch.distributions.categorical import Categorical [as 別名]
# 或者: from torch.distributions.categorical.Categorical import _new [as 別名]
class ExpRelaxedCategorical(Distribution):
    r"""
    Creates a ExpRelaxedCategorical parameterized by `temperature` together
    with `probs` or `logits`; the latter two are handed unchanged to an
    internal :class:`~torch.distributions.Categorical`.
    Returns the log of a point in the simplex. Based on the interface to
    OneHotCategorical.

    Implementation based on [1].

    See also: :func:`torch.distributions.OneHotCategorical`

    Args:
        temperature (Tensor): relaxation temperature
        probs (Tensor): event probabilities
        logits (Tensor): the log probability of each event.
        validate_args (bool, optional): forwarded to ``Distribution.__init__``.

    [1] The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables
    (Maddison et al, 2017)

    [2] Categorical Reparametrization with Gumbel-Softmax
    (Jang et al, 2017)
    """
    arg_constraints = {'probs': constraints.simplex}
    support = constraints.real
    has_rsample = True

    def __init__(self, temperature, probs=None, logits=None, validate_args=None):
        # The wrapped Categorical owns the parameters; this class only adds
        # the temperature and the relaxed sampling / density on top of it.
        self._categorical = Categorical(probs, logits)
        self.temperature = temperature
        batch_shape = self._categorical.batch_shape
        # The last parameter dimension indexes the events.
        event_shape = self._categorical.param_shape[-1:]
        super(ExpRelaxedCategorical, self).__init__(batch_shape, event_shape, validate_args=validate_args)

    def _new(self, *args, **kwargs):
        """Delegate to the wrapped Categorical's ``_new``."""
        return self._categorical._new(*args, **kwargs)

    @property
    def param_shape(self):
        """Parameter shape of the wrapped Categorical."""
        return self._categorical.param_shape

    @property
    def logits(self):
        """Logits of the wrapped Categorical."""
        return self._categorical.logits

    @property
    def probs(self):
        """Probabilities of the wrapped Categorical."""
        return self._categorical.probs

    def rsample(self, sample_shape=torch.Size()):
        """
        Draw a reparameterized sample in log-space.

        Noise of the form ``-log(-log(u))`` with ``u`` uniform is added to the
        logits, the sum is divided by the temperature, and the result is
        normalized with a log-sum-exp over the last dimension.

        Args:
            sample_shape: leading shape of the sample batch.

        Returns:
            Tensor of shape ``sample_shape + batch_shape + event_shape``.
        """
        sample_shape = torch.Size(sample_shape)
        # Clamp the uniforms away from 0 and 1 before taking the double log.
        uniforms = clamp_probs(self.logits.new(self._extended_shape(sample_shape)).uniform_())
        gumbels = -((-(uniforms.log())).log())
        scores = (self.logits + gumbels) / self.temperature
        # keepdim=True so the normalizer broadcasts against ``scores``.
        return scores - scores.logsumexp(dim=-1, keepdim=True)

    def log_prob(self, value):
        """
        Evaluate the log-density at ``value`` (a point in log-space).

        Args:
            value (Tensor): sample whose last dimension indexes the events.

        Returns:
            Tensor with the event dimension summed out.
        """
        K = self._categorical._num_events
        if self._validate_args:
            self._validate_sample(value)
        logits, value = broadcast_all(self.logits, value)
        # lgamma(K) + (K - 1) * log(temperature), shaped like the temperature.
        log_scale = (torch.full_like(self.temperature, float(K)).lgamma() -
                     self.temperature.log().mul(-(K - 1)))
        score = logits - value.mul(self.temperature)
        score = (score - score.logsumexp(dim=-1, keepdim=True)).sum(-1)
        return score + log_scale
示例2: OneHotCategorical
# 需要導入模塊: from torch.distributions.categorical import Categorical [as 別名]
# 或者: from torch.distributions.categorical.Categorical import _new [as 別名]
class OneHotCategorical(Distribution):
    r"""
    Creates a one-hot categorical distribution parameterized by :attr:`probs` or
    :attr:`logits`.

    Samples are one-hot coded vectors of size ``probs.size(-1)``.

    .. note:: :attr:`probs` will be normalized to be summing to 1.

    See also: :func:`torch.distributions.Categorical` for specifications of
    :attr:`probs` and :attr:`logits`.

    Example::

        >>> m = OneHotCategorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ]))
        >>> m.sample()  # equal probability of 0, 1, 2, 3
        tensor([ 0.,  0.,  0.,  1.])

    Args:
        probs (Tensor): event probabilities
        logits (Tensor): event log probabilities
        validate_args (bool, optional): forwarded to ``Distribution.__init__``.
    """
    arg_constraints = {'probs': constraints.simplex}
    support = constraints.simplex
    has_enumerate_support = True

    def __init__(self, probs=None, logits=None, validate_args=None):
        # All parameter handling is delegated to an index-valued Categorical.
        self._categorical = Categorical(probs, logits)
        batch_shape = self._categorical.batch_shape
        # The last parameter dimension indexes the events.
        event_shape = self._categorical.param_shape[-1:]
        super(OneHotCategorical, self).__init__(batch_shape, event_shape, validate_args=validate_args)

    def _new(self, *args, **kwargs):
        """Delegate to the wrapped Categorical's ``_new``."""
        return self._categorical._new(*args, **kwargs)

    @property
    def probs(self):
        """Probabilities of the wrapped Categorical."""
        return self._categorical.probs

    @property
    def logits(self):
        """Logits of the wrapped Categorical."""
        return self._categorical.logits

    @property
    def mean(self):
        """Mean of a one-hot vector: the event probabilities themselves."""
        return self._categorical.probs

    @property
    def variance(self):
        """Element-wise variance ``p * (1 - p)``."""
        return self._categorical.probs * (1 - self._categorical.probs)

    @property
    def param_shape(self):
        """Parameter shape of the wrapped Categorical."""
        return self._categorical.param_shape

    def sample(self, sample_shape=torch.Size()):
        """
        Draw one-hot samples.

        Args:
            sample_shape: leading shape of the sample batch.

        Returns:
            Tensor of shape ``sample_shape + batch_shape + event_shape`` with a
            single 1 along the last dimension.
        """
        sample_shape = torch.Size(sample_shape)
        probs = self._categorical.probs
        one_hot = probs.new(self._extended_shape(sample_shape)).zero_()
        indices = self._categorical.sample(sample_shape)
        # The index tensor lacks the trailing event dimension that scatter_
        # needs, so add it when missing.
        if indices.dim() < one_hot.dim():
            indices = indices.unsqueeze(-1)
        return one_hot.scatter_(-1, indices, 1)

    def log_prob(self, value):
        """
        Log-probability of one-hot ``value``.

        The position of the maximum along the last dimension is used as the
        category index and scored by the wrapped Categorical.
        """
        if self._validate_args:
            self._validate_sample(value)
        indices = value.max(-1)[1]
        return self._categorical.log_prob(indices)

    def entropy(self):
        """Entropy of the wrapped Categorical."""
        return self._categorical.entropy()

    def enumerate_support(self, expand=True):
        """
        Enumerate every one-hot vector in the support.

        Args:
            expand (bool): when True (the default, matching the previous
                behavior) the result is expanded over ``batch_shape``; when
                False the batch dimensions are left as size-1 placeholders.

        Returns:
            Tensor of shape ``(n,) + batch_shape + (n,)`` if ``expand`` else
            ``(n,) + (1,) * len(batch_shape) + (n,)``, where ``n`` is the
            number of events.
        """
        n = self.event_shape[0]
        # Sizes are passed as separate arguments so the scratch buffer is
        # allocated as (n, n) up front and torch.eye only has to fill it.
        values = self._new(n, n)
        torch.eye(n, out=values)
        values = values.view((n,) + (1,) * len(self.batch_shape) + (n,))
        if expand:
            values = values.expand((n,) + self.batch_shape + (n,))
        return values
示例3: OneHotCategorical
# 需要導入模塊: from torch.distributions.categorical import Categorical [as 別名]
# 或者: from torch.distributions.categorical.Categorical import _new [as 別名]
class OneHotCategorical(Distribution):
    r"""
    Creates a one-hot categorical distribution parameterized by `probs` or
    `logits`, both of which are handed unchanged to an internal Categorical.

    Samples are one-hot coded vectors of size probs.size(-1).

    See also: :func:`torch.distributions.Categorical`

    Example::

        >>> m = OneHotCategorical(torch.Tensor([ 0.25, 0.25, 0.25, 0.25 ]))
        >>> m.sample()  # equal probability of 0, 1, 2, 3
         0
         0
         1
         0
        [torch.FloatTensor of size 4]

    Args:
        probs (Tensor or Variable): event probabilities
        logits (Tensor or Variable): passed to Categorical as its second
            argument
    """
    params = {'probs': constraints.simplex}
    support = constraints.simplex
    has_enumerate_support = True

    def __init__(self, probs=None, logits=None):
        # All parameter handling is delegated to an index-valued Categorical.
        self._categorical = Categorical(probs, logits)
        batch_shape = self._categorical.batch_shape
        # The last parameter dimension indexes the events.
        event_shape = self._categorical.param_shape[-1:]
        super(OneHotCategorical, self).__init__(batch_shape, event_shape)

    def _new(self, *args, **kwargs):
        """Delegate to the wrapped Categorical's ``_new``."""
        return self._categorical._new(*args, **kwargs)

    @property
    def probs(self):
        """Probabilities of the wrapped Categorical."""
        return self._categorical.probs

    @property
    def logits(self):
        """Logits of the wrapped Categorical."""
        return self._categorical.logits

    @property
    def mean(self):
        """Mean of a one-hot vector: the event probabilities themselves."""
        return self._categorical.probs

    @property
    def variance(self):
        """Element-wise variance ``p * (1 - p)``."""
        return self._categorical.probs * (1 - self._categorical.probs)

    @property
    def param_shape(self):
        """Parameter shape of the wrapped Categorical."""
        return self._categorical.param_shape

    def sample(self, sample_shape=torch.Size()):
        """
        Draw one-hot samples.

        Args:
            sample_shape: leading shape of the sample batch.

        Returns:
            Tensor of shape ``sample_shape + batch_shape + event_shape`` with a
            single 1 along the last dimension.
        """
        sample_shape = torch.Size(sample_shape)
        probs = self._categorical.probs
        one_hot = probs.new(self._extended_shape(sample_shape)).zero_()
        indices = self._categorical.sample(sample_shape)
        # The index tensor lacks the trailing event dimension that scatter_
        # needs, so add it when missing.
        if indices.dim() < one_hot.dim():
            indices = indices.unsqueeze(-1)
        return one_hot.scatter_(-1, indices, 1)

    def log_prob(self, value):
        """
        Log-probability of one-hot ``value``.

        The position of the maximum along the last dimension is used as the
        category index and scored by the wrapped Categorical.
        """
        indices = value.max(-1)[1]
        return self._categorical.log_prob(indices)

    def entropy(self):
        """Entropy of the wrapped Categorical."""
        return self._categorical.entropy()

    def enumerate_support(self, expand=True):
        """
        Enumerate every one-hot vector in the support.

        Args:
            expand (bool): when True (the default, matching the previous
                behavior) the result is expanded over ``batch_shape``; when
                False the batch dimensions are left as size-1 placeholders.

        Returns:
            Tensor of shape ``(n,) + batch_shape + (n,)`` if ``expand`` else
            ``(n,) + (1,) * len(batch_shape) + (n,)``, where ``n`` is the
            number of events.
        """
        n = self.event_shape[0]
        # Sizes are passed as separate arguments so the buffer already has
        # shape (n, n). torch.eye writes through ``values.data`` when a
        # Variable is returned, so the fill must not depend on ``out=``
        # reshaping its target for ``values`` to end up as the identity.
        values = self._new(n, n)
        torch.eye(n, out=values.data if isinstance(values, Variable) else values)
        values = values.view((n,) + (1,) * len(self.batch_shape) + (n,))
        if expand:
            values = values.expand((n,) + self.batch_shape + (n,))
        return values