Python functional.gumbel_softmax方法代码示例

本文整理汇总了 Python 中 torch.nn.functional.gumbel_softmax 方法的典型用法代码示例。如果您正苦于以下问题：Python functional.gumbel_softmax 方法的具体用法？Python functional.gumbel_softmax 怎么用？Python functional.gumbel_softmax 使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块 torch.nn.functional 的用法示例。


示例1: _apply_activate

# 需要导入模块: from torch.nn import functional [as 别名]
# 或者: from torch.nn.functional import gumbel_softmax [as 别名]
def _apply_activate(self, data):
    """Apply the configured activation to each column slice of ``data``.

    ``self.transformer.output_info`` is iterated in order; for every entry,
    ``item[0]`` is the number of columns the slice spans and ``item[1]`` is
    the activation name.  Slices are consumed left to right along dim 1.

    * ``'tanh'``    -> ``torch.tanh``
    * ``'softmax'`` -> ``functional.gumbel_softmax`` with ``tau=0.2``

    Returns:
        The activated slices concatenated back along dim 1.

    Raises:
        AssertionError: if an entry names an unknown activation.
    """
    activated = []
    start = 0
    for item in self.transformer.output_info:
        kind = item[1]
        if kind == 'tanh':
            activation = torch.tanh
        elif kind == 'softmax':
            def activation(chunk):
                return functional.gumbel_softmax(chunk, tau=0.2)
        else:
            # Only unknown activation names are an error.  Raised explicitly
            # (rather than ``assert 0``) so the check survives ``python -O``.
            raise AssertionError(
                'unsupported activation: {!r}'.format(kind))

        end = start + item[0]
        activated.append(activation(data[:, start:end]))
        start = end

    return torch.cat(activated, dim=1)

示例2: assign

# 需要导入模块: from torch.nn import functional [as 别名]
# 或者: from torch.nn.functional import gumbel_softmax [as 别名]
def assign(self, points, distance='euclid', greedy=False):
    """Compute the responsibilities of ``self.centroids`` for each point.

    Args:
        points: 2-D tensor with one row per point.
        distance: ``'cosine'`` L2-normalizes both points and centroids along
            the last dim before measuring; any other value uses the raw
            vectors.
        greedy: if true, return a hard one-hot assignment to the nearest
            centroid (non-differentiable).  Otherwise draw a Gumbel-softmax
            sample from the negated distances using ``self.tau`` and
            ``self.hard``.

    Returns:
        Tensor with one row per point and one column per centroid.
    """
    centroids = self.centroids
    if distance == 'cosine':
        # Nearest neighbor in the centroids (cosine distance).
        points = F.normalize(points, dim=-1)
        centroids = F.normalize(centroids, dim=-1)

    # Squared distance via ||p||^2 + ||c||^2 - 2 * p.c, broadcast to
    # (number of points, number of centroids).
    point_sq = torch.sum(points**2, dim=1, keepdim=True)
    centroid_sq = torch.sum(centroids**2, dim=1, keepdim=True).t()
    distances = point_sq + centroid_sq - 2 * torch.matmul(points, centroids.t())
    # NOTE(review): debugging output kept from the original snippet.
    print('Distances:', distances[:3])

    if greedy:
        # Greedy non-differentiable responsibilities.
        indices = torch.argmin(distances, dim=-1)
        resp = torch.zeros(points.size(0), self.ne).type_as(points)
        resp.scatter_(1, indices.unsqueeze(1), 1)
    else:
        # Closer centroids get larger logits.
        resp = F.gumbel_softmax(-distances, tau=self.tau, hard=self.hard)
    return resp

示例3: assign

# 需要导入模块: from torch.nn import functional [as 别名]
# 或者: from torch.nn.functional import gumbel_softmax [as 别名]
def assign(self, points, distance='euclid', greedy=False):
    """Compute the responsibilities of ``self.centroids`` for each point.

    The points are detached first, so no gradient flows back into them
    through the assignment.

    Args:
        points: 2-D tensor with one row per point.
        distance: ``'cosine'`` L2-normalizes both points and centroids along
            the last dim before measuring; any other value uses the raw
            vectors.
        greedy: if true, return a hard one-hot assignment to the nearest
            centroid (non-differentiable).  Otherwise draw a Gumbel-softmax
            sample from the negated distances using ``self.tau`` and
            ``self.hard``.

    Returns:
        Tensor with one row per point and one column per centroid.
    """
    # ``detach()`` is the supported replacement for the legacy ``.data``.
    points = points.detach()
    centroids = self.centroids
    if distance == 'cosine':
        # Nearest neighbor in the centroids (cosine distance).
        points = F.normalize(points, dim=-1)
        centroids = F.normalize(centroids, dim=-1)

    # Squared distance via ||p||^2 + ||c||^2 - 2 * p.c, broadcast to
    # (number of points, number of centroids).
    point_sq = torch.sum(points**2, dim=1, keepdim=True)
    centroid_sq = torch.sum(centroids**2, dim=1, keepdim=True).t()
    distances = point_sq + centroid_sq - 2 * torch.matmul(points, centroids.t())

    if greedy:
        # Greedy non-differentiable responsibilities.
        indices = torch.argmin(distances, dim=-1)
        resp = torch.zeros(points.size(0), self.ne).type_as(points)
        resp.scatter_(1, indices.unsqueeze(1), 1)
    else:
        # Closer centroids get larger logits.
        resp = F.gumbel_softmax(-distances, tau=self.tau, hard=self.hard)
    return resp

示例4: forward

# 需要导入模块: from torch.nn import functional [as 别名]
# 或者: from torch.nn.functional import gumbel_softmax [as 别名]
def forward(self, x, attention_mask, gumbel_softmax=True, tau=None):
    """Score ``x`` and optionally relax the scores with Gumbel-softmax.

    ``x`` is passed through ``self.bert_layer`` (with the mask produced by
    ``self.convert_mask``), ``self.linear_layer`` and ``self.log_sigmoid``;
    dim 2 of the result is then squeezed away.

    Args:
        x: input forwarded to ``self.bert_layer``.
        attention_mask: mask converted by ``self.convert_mask``.
        gumbel_softmax: if true, return a Gumbel-softmax sample of the
            log-probabilities; otherwise return the log-probabilities.
        tau: Gumbel-softmax temperature; defaults to ``self.tau``.

    Returns:
        The Gumbel-softmax sample or the raw log-probabilities.
    """
    extended_attention_mask = self.convert_mask(attention_mask)
    hidden = self.bert_layer(x, extended_attention_mask)
    hidden = self.linear_layer(hidden)
    log_probs = self.log_sigmoid(hidden).squeeze(dim=2)

    if not gumbel_softmax:
        return log_probs

    temperature = self.tau if tau is None else tau
    return F.gumbel_softmax(log_probs, tau=temperature)

示例5: forward

# 需要导入模块: from torch.nn import functional [as 别名]
# 或者: from torch.nn.functional import gumbel_softmax [as 别名]
def forward(self, input, discrete=False, normalize=False):
    """Run the stem, every cell and the classifier on ``input``.

    Args:
        input: batch fed to ``self.stem``.
        discrete: if true, the architecture parameters are handed to the
            cells unchanged (discrete architecture, e.g. from random search
            with weight sharing).  Otherwise they are relaxed with
            Gumbel-softmax at temperature ``self.tau``, re-sampled for every
            cell.
        normalize: accepted but not used by this method.

    Returns:
        The output of ``self.classifier``.
    """

    def relax(alpha, hard):
        # Discrete architectures pass their weights through untouched.
        if discrete:
            return alpha
        return F.gumbel_softmax(alpha, tau=self.tau, hard=hard, dim=-1)

    # NASBench only has one input to each cell.
    features = self.stem(input)
    for index, cell in enumerate(self.cells):
        if index in [self._layers // 3, 2 * self._layers // 3]:
            # Down-sample by a factor of 1/2, as in
            # https://github.com/google-research/nasbench/blob/master/nasbench/lib/model_builder.py#L68
            pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=1)
            features = pool(features)

        # GDAS: hard Gumbel-softmax for the mixed op, so only a single
        # operation is evaluated per mixed block.
        mixed_op_weights = relax(self._arch_parameters[0], hard=True)
        # Soft relaxation for everything else; the original author notes
        # that hard sampling here very quickly gave exploding gradients.
        if self._output_weights:
            output_weights = relax(self._arch_parameters[1], hard=False)
        else:
            output_weights = None
        input_weights = [
            relax(alpha, hard=False) for alpha in self._arch_parameters[2:]
        ]
        features = cell(features, mixed_op_weights, output_weights, input_weights)

    # One more preprocessing step:
    # [N, C_max * (steps + 1), w, h] -> [N, C_max, w, h]
    features = self.postprocess(features)

    # Global average pooling over the two remaining spatial dimensions, as in
    # https://github.com/google-research/nasbench/blob/master/nasbench/lib/model_builder.py#L92
    pooled = features.view(*features.shape[:2], -1).mean(-1)
    return self.classifier(pooled.view(pooled.size(0), -1))

示例6: apply_activate

# 需要导入模块: from torch.nn import functional [as 别名]
# 或者: from torch.nn.functional import gumbel_softmax [as 别名]
def apply_activate(data, output_info):
    """Apply the configured activation to each column slice of ``data``.

    Args:
        data: 2-D tensor whose columns are consumed left to right.
        output_info: iterable of entries where ``item[0]`` is the number of
            columns in the slice and ``item[1]`` is the activation name:
            ``'tanh'`` (``torch.tanh``) or ``'softmax'``
            (``F.gumbel_softmax`` with ``tau=0.2``).

    Returns:
        The activated slices concatenated back along dim 1.

    Raises:
        AssertionError: if an entry names an unknown activation.
    """
    activated = []
    start = 0
    for item in output_info:
        kind = item[1]
        if kind == 'tanh':
            activation = torch.tanh
        elif kind == 'softmax':
            def activation(chunk):
                return F.gumbel_softmax(chunk, tau=0.2)
        else:
            # Only unknown activation names are an error.  Raised explicitly
            # (rather than ``assert 0``) so the check survives ``python -O``.
            raise AssertionError(
                'unsupported activation: {!r}'.format(kind))

        end = start + item[0]
        activated.append(activation(data[:, start:end]))
        start = end
    return torch.cat(activated, dim=1)

示例7: forward

# 需要导入模块: from torch.nn import functional [as 别名]
# 或者: from torch.nn.functional import gumbel_softmax [as 别名]
def forward(self, x, key):
    """Mix the outputs of ``self.ne`` experts and add a residual connection.

    Args:
        x: tensor unpacked as ``(T, B, C)``.
        key: tensor used to compute the expert responsibilities, or ``None``
            to use ``x`` itself.  When given, it is reshaped to
            ``(B, self.key_dim)`` and the same responsibilities are broadcast
            over all ``T`` positions.

    Returns:
        ``(output, loss)`` where ``output`` is the mixed expert output plus
        ``x`` and ``loss`` is ``self.loss_scale`` times the coefficient of
        variation (std / mean) of the per-expert summed responsibilities.
    """
    T, B, C = x.size()
    if key is not None:
        Tr = 1
    else:
        # No separate key: every position of x selects its own experts.
        key = x
        Tr = T

    logits = self.assign(key.contiguous().view(Tr * B, self.key_dim))
    if self.tau:
        # NOTE(review): no temperature is passed here, so gumbel_softmax
        # runs with its default tau; confirm whether tau=self.tau was meant.
        resp = F.gumbel_softmax(logits)  # Tr*B, ne
    else:
        # torch.softmax needs an explicit dim; use the expert (last) axis,
        # matching gumbel_softmax's default of dim=-1.
        resp = torch.softmax(logits, dim=-1)  # Tr*B, ne

    importance = resp.sum(dim=0)
    loss = self.loss_scale * torch.std(importance) / torch.mean(importance)
    # NOTE(review): debugging output kept from the original snippet.
    print('importance', importance.data.round())

    # First evaluate each expert output, then weight by responsibility.
    resp = resp.view(Tr, B, self.ne, 1)
    residual = x
    x = torch.matmul(self.pw_w1, x.unsqueeze(2).unsqueeze(-1)).squeeze(-1)  # T, B, ne, out
    x = F.relu(x)
    x = torch.sum(resp * x, dim=2)
    if self.pw_bias is not None:
        x = x + self.pw_bias(key)
    return x + residual, loss

示例8: step

# 需要导入模块: from torch.nn import functional [as 别名]
# 或者: from torch.nn.functional import gumbel_softmax [as 别名]
def step(self, x, n, cumul=None, total_computes=None):
    """Sample a relaxed halting decision before the upcoming block.

    Given the current activation, decide whether to go into block ``n`` or
    to skip/exit.

    Args:
        x: current activation, unpacked as ``(T, B, C)``.
        n: index of the upcoming block; selects the halting predictor when
            ``self.separate_halting_predictors`` is set, otherwise
            predictor 0 is used.
        cumul: accepted but not used by this method.
        total_computes: accepted but not used by this method.

    Returns:
        Tensor of shape ``(T, B, 2)``: a Gumbel-softmax sample (temperature
        ``self.gumbel_tau``) over the two log-sigmoid predictor outputs.
    """
    T, B, C = x.size()
    if self.detach_before_classifier:
        # Keep the halting predictor's gradient out of the main network.
        x = x.detach()
    predictor_index = n if self.separate_halting_predictors else 0
    scores = self.halting_predictors[predictor_index](x)
    halt_logits = F.logsigmoid(scores)  # the log-p of halting
    # Apply the gumbel trick over each pair of logits.
    halt = F.gumbel_softmax(halt_logits.view(-1, 2), tau=self.gumbel_tau)
    return halt.view(T, B, 2)

示例9: forward

# 需要导入模块: from torch.nn import functional [as 别名]
# 或者: from torch.nn.functional import gumbel_softmax [as 别名]
def forward(self, q, k, v):
    """Scaled dot-product attention with optional causal/padding masks.

    Args:
        q: queries, ``(b, t_q, dim)``.
        k: keys, ``(b, t_k, dim)``.
        v: values, ``(b, t_k, dim_v)``.

    Masked positions (``self.causal`` upper triangle, ``self.mask_k`` keys,
    ``self.mask_q`` queries) have their scores filled with ``-1e12`` before
    normalization.  With ``self.gumbel`` set, the attention weights are a
    hard Gumbel-softmax sample; otherwise an ordinary softmax (computed in
    float32 when the scores are float16).

    Returns:
        ``(output, weights)`` where ``output`` is ``(b, t_q, dim_v)`` and
        ``weights`` is the ``(b, t_q, t_k)`` attention matrix after dropout.
    """
    b_q, t_q, dim_q = list(q.size())
    b_k, t_k, dim_k = list(k.size())
    b_v, t_v, dim_v = list(v.size())
    assert(b_q == b_k and b_k == b_v)  # batch size should be equal
    assert(dim_q == dim_k)  # dims should be equal
    assert(t_k == t_v)  # times should be equal
    b = b_q
    qk = torch.bmm(q, k.transpose(1, 2))  # b x t_q x t_k
    qk = qk / (dim_k ** 0.5)

    # Masks are combined as bool tensors: masked_fill_ expects a boolean
    # mask, and byte masks are rejected by current PyTorch releases.
    mask = None
    with torch.no_grad():
        if self.causal and t_q > 1:
            causal_mask = torch.ones(t_q, t_k, device=q.device).triu_(1).bool()
            mask = causal_mask.unsqueeze(0).expand(b, t_q, t_k)
        if self.mask_k is not None:
            mask_k = self.mask_k.bool().unsqueeze(1).expand(b, t_q, t_k)
            mask = mask_k if mask is None else mask | mask_k
        if self.mask_q is not None:
            mask_q = self.mask_q.bool().unsqueeze(2).expand(b, t_q, t_k)
            mask = mask_q if mask is None else mask | mask_q
    if mask is not None:
        qk.masked_fill_(mask, -1e12)

    if self.gumbel:
        sm_qk = F.gumbel_softmax(qk, dim=2, hard=True)
    else:
        sm_qk = F.softmax(qk, dim=2,
                          dtype=torch.float32 if qk.dtype == torch.float16 else qk.dtype)
    sm_qk = self.dropout(sm_qk)
    return torch.bmm(sm_qk, v), sm_qk  # b x t_q x dim_v

示例10: forward

# 需要导入模块: from torch.nn import functional [as 别名]
# 或者: from torch.nn.functional import gumbel_softmax [as 别名]
def forward(self, logits, training=True, temperature=None):
    """Turn ``logits`` into (relaxed) categorical outputs.

    Args:
        logits: unnormalized scores.
        training: selects the training or the evaluation behavior.
        temperature: if not ``None``, Gumbel-softmax is used with this
            ``tau`` for both training (soft sample) and evaluation (hard
            sample).

    Returns:
        * ``temperature`` given: a Gumbel-softmax sample, hard iff not
          ``training``.
        * otherwise, when ``training``: ``softmax`` over dim 1.
        * otherwise: a one-hot sample from ``OneHotCategorical(logits)``.
    """
    # Gumbel-softmax (training and evaluation).
    if temperature is not None:
        return F.gumbel_softmax(logits, hard=not training, tau=temperature)
    # Softmax training.
    if training:
        return F.softmax(logits, dim=1)
    # Softmax evaluation: draw a discrete one-hot sample.
    return OneHotCategorical(logits=logits).sample()

示例11: test_gumbel_softmax

# 需要导入模块: from torch.nn import functional [as 别名]
# 或者: from torch.nn.functional import gumbel_softmax [as 别名]
def test_gumbel_softmax(self):
    """Smoke-run ``F.gumbel_softmax`` on a CUDA tensor of ``self.dtype``.

    Nothing is asserted; the call only has to complete.
    """
    shape = (16, 1024)
    logits = torch.randn(*shape, device='cuda', dtype=self.dtype)
    sampling_options = dict(tau=1, hard=False, eps=1e-10, dim=-1)
    output = F.gumbel_softmax(logits, **sampling_options)

示例12: parse_gumbel

# 需要导入模块: from torch.nn import functional [as 别名]
# 或者: from torch.nn.functional import gumbel_softmax [as 别名]
def parse_gumbel(alpha, beta, k):
    """Parse continuous ``alpha`` into a discrete gene by Gumbel sampling.

    Args:
        alpha: iterable of per-node tensors of shape ``(n_edges, n_ops)``;
            the last op column is excluded from sampling (the last
            primitive is asserted to be ``'none'``).
        beta: iterable paired with ``alpha``; only its length matters here,
            its values are not read.
        k: number of hard Gumbel-softmax draws per node.  Each draw picks
            one (edge, op) pair; repeated picks collapse, so a node keeps
            at most ``k`` pairs.

    Returns:
        ``(gene, connect_idx)``, both with one entry per node:

        * ``gene[i]`` is a list of ``(primitive_name, edge_idx)`` tuples.
        * ``connect_idx[i]`` is a list of ``(edge_idx, prim_idx)`` tuples.
    """
    gene = []
    assert PRIMITIVES[-1] == 'none'  # assume last PRIMITIVE is 'none'

    connect_idx = []
    n_ops = len(PRIMITIVES) - 1
    for edges, _ in zip(alpha, beta):
        # edges: Tensor(n_edges, n_ops); drop the trailing 'none' column and
        # sample over all remaining (edge, op) pairs at once.
        candidate_logits = edges[:, :-1].reshape(-1)
        discrete_a = F.gumbel_softmax(candidate_logits, tau=1, hard=True)
        for _draw in range(k - 1):
            discrete_a = discrete_a + F.gumbel_softmax(candidate_logits, tau=1, hard=True)
        discrete_a = discrete_a.reshape(-1, n_ops)
        reserved_edge = (discrete_a > 0).nonzero()

        node_gene = []
        node_idx = []
        for edge_idx, prim_idx in reserved_edge.tolist():
            node_gene.append((PRIMITIVES[prim_idx], edge_idx))
            node_idx.append((edge_idx, prim_idx))

        # Record this node's selection; without this the function would
        # always return two empty lists.
        connect_idx.append(node_idx)
        gene.append(node_gene)

    return gene, connect_idx
