This article collects typical usage examples of the torch.nn.functional.gumbel_softmax function in Python. If you are wondering what functional.gumbel_softmax does, how to call it, or what it looks like in real code, the curated examples below may help. You can also explore further usage examples from the torch.nn.functional module.
The following shows 12 code examples of functional.gumbel_softmax, sorted by popularity by default.
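Before the examples, here is a minimal sketch of the function's two basic modes (the soft relaxation and the hard straight-through variant); the shapes and temperature below are arbitrary and chosen only for illustration:

import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)                          # 4 rows of unnormalized log-probs over 10 categories
soft = F.gumbel_softmax(logits, tau=1.0)             # soft, differentiable samples; each row sums to 1
hard = F.gumbel_softmax(logits, tau=1.0, hard=True)  # one-hot in the forward pass, soft gradients in backward
print(soft.sum(dim=-1), hard.sum(dim=-1))            # rows of ~1.0 and exactly 1.0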
Example 1: _apply_activate
# Required module import: from torch.nn import functional [as alias]
# Or: from torch.nn.functional import gumbel_softmax [as alias]
def _apply_activate(self, data):
    data_t = []
    st = 0
    for item in self.transformer.output_info:
        if item[1] == 'tanh':
            ed = st + item[0]
            data_t.append(torch.tanh(data[:, st:ed]))
            st = ed
        elif item[1] == 'softmax':
            ed = st + item[0]
            data_t.append(functional.gumbel_softmax(data[:, st:ed], tau=0.2))
            st = ed
        else:
            assert 0
    return torch.cat(data_t, dim=1)
Example 2: assign
# Required module import: from torch.nn import functional [as alias]
# Or: from torch.nn.functional import gumbel_softmax [as alias]
def assign(self, points, distance='euclid', greedy=False):
    # points = points.data
    centroids = self.centroids
    if distance == 'cosine':
        # nearest neighbor in the centroids (cosine distance):
        points = F.normalize(points, dim=-1)
        centroids = F.normalize(centroids, dim=-1)
    distances = (torch.sum(points**2, dim=1, keepdim=True) +
                 torch.sum(centroids**2, dim=1, keepdim=True).t() -
                 2 * torch.matmul(points, centroids.t()))  # T*B, e
    print('Distances:', distances[:3])
    if not greedy:
        logits = - distances
        resp = F.gumbel_softmax(logits, tau=self.tau, hard=self.hard)
    else:
        # Greedy non-differentiable responsibilities:
        indices = torch.argmin(distances, dim=-1)  # T*B
        resp = torch.zeros(points.size(0), self.ne).type_as(points)
        resp.scatter_(1, indices.unsqueeze(1), 1)
    return resp
Example 3: assign
# Required module import: from torch.nn import functional [as alias]
# Or: from torch.nn.functional import gumbel_softmax [as alias]
def assign(self, points, distance='euclid', greedy=False):
    points = points.data
    centroids = self.centroids
    if distance == 'cosine':
        # nearest neighbor in the centroids (cosine distance):
        points = F.normalize(points, dim=-1)
        centroids = F.normalize(centroids, dim=-1)
    distances = (torch.sum(points**2, dim=1, keepdim=True) +
                 torch.sum(centroids**2, dim=1, keepdim=True).t() -
                 2 * torch.matmul(points, centroids.t()))  # T*B, e
    if not greedy:
        logits = - distances
        resp = F.gumbel_softmax(logits, tau=self.tau, hard=self.hard)
        # batch_counts = resp.sum(dim=0).view(-1).data
    else:
        # Greedy non-differentiable responsibilities:
        indices = torch.argmin(distances, dim=-1)  # T*B
        resp = torch.zeros(points.size(0), self.ne).type_as(points)
        resp.scatter_(1, indices.unsqueeze(1), 1)
    return resp
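Examples 2 and 3 turn negative pairwise squared distances into (soft) cluster responsibilities via Gumbel-softmax. Here is a minimal standalone sketch of the same pattern; the point and centroid tensors and the temperature are hypothetical, chosen only for illustration:

import torch
import torch.nn.functional as F

points = torch.randn(32, 64)     # hypothetical: 32 points, 64-dim features
centroids = torch.randn(8, 64)   # hypothetical: 8 centroids

# Pairwise squared Euclidean distances (32 x 8), as in the assign() examples above
distances = (points.pow(2).sum(dim=1, keepdim=True)
             + centroids.pow(2).sum(dim=1, keepdim=True).t()
             - 2 * points @ centroids.t())

# Differentiable responsibilities: closer centroids get higher probability
resp = F.gumbel_softmax(-distances, tau=0.5)   # 32 x 8, each row sums to 1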
Example 4: forward
# Required module import: from torch.nn import functional [as alias]
# Or: from torch.nn.functional import gumbel_softmax [as alias]
def forward(self, x, attention_mask, gumbel_softmax=True, tau=None):
    extended_attention_mask = self.convert_mask(attention_mask)
    h = self.bert_layer(x, extended_attention_mask)
    h = self.linear_layer(h)
    log_probs = self.log_sigmoid(h).squeeze(dim=2)
    if gumbel_softmax:
        tau = self.tau if tau is None else tau
        return F.gumbel_softmax(log_probs, tau=tau)
    else:
        return log_probs
Example 5: forward
# Required module import: from torch.nn import functional [as alias]
# Or: from torch.nn.functional import gumbel_softmax [as alias]
def forward(self, input, discrete=False, normalize=False):
    # NASBench only has one input to each cell
    s0 = self.stem(input)
    for i, cell in enumerate(self.cells):
        if i in [self._layers // 3, 2 * self._layers // 3]:
            # Perform down-sampling by factor 1/2
            # Equivalent to https://github.com/google-research/nasbench/blob/master/nasbench/lib/model_builder.py#L68
            s0 = nn.MaxPool2d(kernel_size=2, stride=2, padding=1)(s0)
        # If using a discrete architecture from random_ws (random search with weight sharing),
        # pass the architecture weights through directly.
        # For GDAS use hard Gumbel-softmax, so only a single operation is evaluated per mixed block.
        preprocess_op_mixed_op = lambda x: x if discrete else F.gumbel_softmax(x, tau=self.tau, hard=True, dim=-1)
        # Don't use hard for the rest, because it very quickly gave exploding gradients
        preprocess_op = lambda x: x if discrete else F.gumbel_softmax(x, tau=self.tau, hard=False, dim=-1)
        # Normalize the mixed-op weights for the choice blocks in the graph
        mixed_op_weights = preprocess_op_mixed_op(self._arch_parameters[0])
        # Normalize the output weights
        output_weights = preprocess_op(self._arch_parameters[1]) if self._output_weights else None
        # Normalize the input weights for the nodes in the cell
        input_weights = [preprocess_op(alpha) for alpha in self._arch_parameters[2:]]
        s0 = cell(s0, mixed_op_weights, output_weights, input_weights)
    # Include one more preprocessing step here
    s0 = self.postprocess(s0)  # [N, C_max * (steps + 1), w, h] -> [N, C_max, w, h]
    # Global Average Pooling by averaging over the last two remaining spatial dimensions
    # As in nasbench: https://github.com/google-research/nasbench/blob/master/nasbench/lib/model_builder.py#L92
    out = s0.view(*s0.shape[:2], -1).mean(-1)
    logits = self.classifier(out.view(out.size(0), -1))
    return logits
Example 6: apply_activate
# Required module import: from torch.nn import functional [as alias]
# Or: from torch.nn.functional import gumbel_softmax [as alias]
def apply_activate(data, output_info):
    data_t = []
    st = 0
    for item in output_info:
        if item[1] == 'tanh':
            ed = st + item[0]
            data_t.append(torch.tanh(data[:, st:ed]))
            st = ed
        elif item[1] == 'softmax':
            ed = st + item[0]
            data_t.append(F.gumbel_softmax(data[:, st:ed], tau=0.2))
            st = ed
        else:
            assert 0
    return torch.cat(data_t, dim=1)
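Examples 1 and 6 assume output_info is a list of (width, activation) pairs describing how the output columns are grouped. A short usage sketch, assuming the apply_activate helper from Example 6 is in scope (with torch and torch.nn.functional as F imported) and using a made-up column layout:

import torch

# Hypothetical layout: one continuous column (tanh) followed by a 3-way one-hot block (softmax)
output_info = [(1, 'tanh'), (3, 'softmax')]

raw = torch.randn(5, 4)                       # 5 rows, 1 + 3 = 4 raw output columns
activated = apply_activate(raw, output_info)  # tanh on column 0, gumbel_softmax on columns 1:4
print(activated.shape)                        # torch.Size([5, 4])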
Example 7: forward
# Required module import: from torch.nn import functional [as alias]
# Or: from torch.nn.functional import gumbel_softmax [as alias]
def forward(self, x, key):
    T, B, C = x.size()
    loss = torch.zeros(1).type_as(x).to(x.device)
    if key is not None:
        Tr = 1
    else:
        key = x
        Tr = T
    if self.tau:
        resp = F.gumbel_softmax(
            self.assign(key.contiguous().view(Tr*B, self.key_dim)),
            tau=self.tau,
            hard=self.hard
        )  # T*B, ne
    else:
        resp = torch.softmax(
            self.assign(key.contiguous().view(Tr*B, self.key_dim)),
            dim=-1
        )  # T*B, ne
    importance = resp.sum(dim=0)
    loss = self.loss_scale * torch.std(importance) / torch.mean(importance)
    print('importance', importance.data.round())
    # w = torch.matmul(resp, self.pw_w1)  # T*B, C_out * C_in
    # w = w.view(T, B, self.output_dim, self.input_dim)
    # x = torch.matmul(w, x.unsqueeze(-1)).squeeze(-1)
    # if self.pw_bias is not None:
    #     x = x + self.pw_bias(x0)
    # First evaluate each expert output
    resp = resp.view(Tr, B, self.ne, 1)
    residual = x
    x = torch.matmul(self.pw_w1, x.unsqueeze(2).unsqueeze(-1)).squeeze(-1)  # T, B, ne, out
    x = F.relu(x)
    x = torch.sum(resp * x, dim=2)
    if self.pw_bias is not None:
        x = x + self.pw_bias(key)
    return x + residual, loss
Example 8: step
# Required module import: from torch.nn import functional [as alias]
# Or: from torch.nn.functional import gumbel_softmax [as alias]
def step(self, x, n, cumul=None, total_computes=None):
    """
    n is the index of the upcoming block.
    Given the current activation, decide whether to go in or skip/exit.
    Returns the binary decision and the log-(p, 1-p).
    """
    T, B, C = x.size()
    if self.detach_before_classifier:
        x = x.detach()
    x = self.halting_predictors[n if self.separate_halting_predictors else 0](x)
    halt_logits = F.logsigmoid(x)  # the log-p of halting
    # Apply the Gumbel trick
    halt = halt_logits.view(-1, 2)
    halt = F.gumbel_softmax(halt, tau=self.gumbel_tau).view(T, B, 2)
    return halt
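The "Gumbel trick" in Example 8 is just gumbel_softmax over two logits per position (halt vs. continue), which yields a differentiable binary decision. A minimal sketch, with made-up tensor shapes and temperature:

import torch
import torch.nn.functional as F

T, B = 7, 4                          # hypothetical sequence length and batch size
scores = torch.randn(T, B, 2)        # per-position logits for (halt, continue)

halt = F.gumbel_softmax(scores.view(-1, 2), tau=1.0).view(T, B, 2)
# Each (t, b) row is a soft binary decision; lowering tau pushes it closer to one-hot.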
Example 9: forward
# Required module import: from torch.nn import functional [as alias]
# Or: from torch.nn.functional import gumbel_softmax [as alias]
def forward(self, q, k, v):
    b_q, t_q, dim_q = list(q.size())
    b_k, t_k, dim_k = list(k.size())
    b_v, t_v, dim_v = list(v.size())
    assert(b_q == b_k and b_k == b_v)  # batch sizes should be equal
    assert(dim_q == dim_k)  # dims should be equal
    assert(t_k == t_v)  # times should be equal
    b = b_q
    qk = torch.bmm(q, k.transpose(1, 2))  # b x t_q x t_k
    qk = qk / (dim_k ** 0.5)
    mask = None
    with torch.no_grad():
        if self.causal and t_q > 1:
            causal_mask = q.data.new(t_q, t_k).byte().fill_(1).triu_(1)
            mask = causal_mask.unsqueeze(0).expand(b, t_q, t_k)
        if self.mask_k is not None:
            mask_k = self.mask_k.unsqueeze(1).expand(b, t_q, t_k)
            mask = mask_k if mask is None else mask | mask_k
        if self.mask_q is not None:
            mask_q = self.mask_q.unsqueeze(2).expand(b, t_q, t_k)
            mask = mask_q if mask is None else mask | mask_q
    if mask is not None:
        qk.masked_fill_(mask, -1e12)
    if self.gumbel:
        sm_qk = F.gumbel_softmax(qk, dim=2, hard=True)
    else:
        sm_qk = F.softmax(qk, dim=2,
                          dtype=torch.float32 if qk.dtype == torch.float16 else qk.dtype)
    sm_qk = self.dropout(sm_qk)
    return torch.bmm(sm_qk, v), sm_qk  # b x t_q x dim_v
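In Example 9, the self.gumbel branch makes each query attend (approximately) to a single key in the forward pass while remaining differentiable. A small, self-contained comparison of the two branches on random scores (shapes are hypothetical):

import torch
import torch.nn.functional as F

b, t_q, t_k, d = 2, 3, 5, 16
q, k = torch.randn(b, t_q, d), torch.randn(b, t_k, d)
qk = torch.bmm(q, k.transpose(1, 2)) / d ** 0.5        # b x t_q x t_k attention scores

soft_attn = F.softmax(qk, dim=2)                        # standard soft attention weights
hard_attn = F.gumbel_softmax(qk, dim=2, hard=True)      # one key selected per query (one-hot rows)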
Example 10: forward
# Required module import: from torch.nn import functional [as alias]
# Or: from torch.nn.functional import gumbel_softmax [as alias]
def forward(self, logits, training=True, temperature=None):
    # gumbel-softmax (training and evaluation)
    if temperature is not None:
        return F.gumbel_softmax(logits, hard=not training, tau=temperature)
    # softmax training
    elif training:
        return F.softmax(logits, dim=1)
    # softmax evaluation
    else:
        return OneHotCategorical(logits=logits).sample()
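Example 10 contrasts a Gumbel-softmax sample (hard=not training) with an exact, non-differentiable draw from OneHotCategorical at evaluation time. A brief sketch of the difference in the gradient path (shapes chosen arbitrarily):

import torch
import torch.nn.functional as F
from torch.distributions import OneHotCategorical

logits = torch.randn(6, 4, requires_grad=True)

relaxed = F.gumbel_softmax(logits, tau=0.5, hard=True)  # one-hot forward pass, gradients still flow to logits
exact = OneHotCategorical(logits=logits).sample()       # exact categorical sample, no gradient path
print(relaxed.requires_grad, exact.requires_grad)       # True False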
Example 11: test_gumbel_softmax
# Required module import: from torch.nn import functional [as alias]
# Or: from torch.nn.functional import gumbel_softmax [as alias]
def test_gumbel_softmax(self):
    inp = torch.randn(16, 1024, device='cuda', dtype=self.dtype)
    output = F.gumbel_softmax(inp, tau=1, hard=False, eps=1e-10, dim=-1)
Example 12: parse_gumbel
# Required module import: from torch.nn import functional [as alias]
# Or: from torch.nn.functional import gumbel_softmax [as alias]
def parse_gumbel(alpha, beta, k):
    """
    Parse continuous alpha to a discrete gene.
    alpha is a ParameterList:
    ParameterList [
        Parameter(n_edges1, n_ops),
        Parameter(n_edges2, n_ops),
        ...
    ]
    beta is a ParameterList:
    ParameterList [
        Parameter(n_edges1),
        Parameter(n_edges2),
        ...
    ]
    gene is a list:
    [
        [('node1_ops_1', node_idx), ..., ('node1_ops_k', node_idx)],
        [('node2_ops_1', node_idx), ..., ('node2_ops_k', node_idx)],
        ...
    ]
    Each node has two edges (k=2) in CNN.
    """
    gene = []
    assert PRIMITIVES[-1] == 'none'  # assume the last PRIMITIVE is 'none'
    # 1) Convert the mixed op to a discrete edge (single op) by choosing the top-1 weight edge
    # 2) Choose top-k edges per node by edge score (top-1 weight in edge)
    # output the connect idx [(node_idx, connect_idx, op_idx), ..., (), ()]
    connect_idx = []
    for edges, w in zip(alpha, beta):
        # edges: Tensor(n_edges, n_ops)
        discrete_a = F.gumbel_softmax(edges[:, :-1].reshape(-1), tau=1, hard=True)
        for i in range(k-1):
            discrete_a = discrete_a + F.gumbel_softmax(edges[:, :-1].reshape(-1), tau=1, hard=True)
        discrete_a = discrete_a.reshape(-1, len(PRIMITIVES)-1)
        reserved_edge = (discrete_a > 0).nonzero()
        node_gene = []
        node_idx = []
        for i in range(reserved_edge.shape[0]):
            edge_idx = reserved_edge[i][0].item()
            prim_idx = reserved_edge[i][1].item()
            prim = PRIMITIVES[prim_idx]
            node_gene.append((prim, edge_idx))
            node_idx.append((edge_idx, prim_idx))
        gene.append(node_gene)
        connect_idx.append(node_idx)
    return gene, connect_idx
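parse_gumbel selects up to k (edge, op) pairs per node by summing k independent hard Gumbel-softmax samples over the flattened (edge, op) logits, excluding the 'none' op. A self-contained sketch of just that selection step, with a hypothetical PRIMITIVES list and random edge weights:

import torch
import torch.nn.functional as F

PRIMITIVES = ['max_pool_3x3', 'sep_conv_3x3', 'skip_connect', 'none']  # hypothetical op set, 'none' last
n_edges, k = 3, 2
edges = torch.randn(n_edges, len(PRIMITIVES))            # alpha weights for one node

flat = edges[:, :-1].reshape(-1)                         # flatten (edge, op) pairs, excluding 'none'
picks = sum(F.gumbel_softmax(flat, tau=1, hard=True) for _ in range(k))
picks = picks.reshape(n_edges, len(PRIMITIVES) - 1)
selected = (picks > 0).nonzero()                         # sampled (edge_idx, op_idx) pairs, up to k of them
print(selected)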