当前位置: 首页>>代码示例>>Python>>正文


Python torch.rot90方法代码示例

本文整理汇总了Python中torch.rot90方法的典型用法代码示例。如果您正苦于以下问题:Python torch.rot90方法的具体用法?Python torch.rot90怎么用?Python torch.rot90使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在torch的用法示例。


在下文中一共展示了torch.rot90方法的12个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: map_pool

# 需要导入模块: import torch [as 别名]
# 或者: from torch import rot90 [as 别名]
def map_pool(output_size, dtype=np.float32):
    """Build a max-pooling map transform for a single (unbatched) map image.

    ``_thunk(obs_space)`` returns ``(runner, space)``. ``runner`` takes one
    square H x W x C map (torch tensor or numpy array) and processes it as follows:

    1. Moves it to the GPU as float32 and divides it by 255.
    2. Applies a 3x3, stride-1, padding-1 max-pool, which keeps H and W.
    3. Turns it by 180 degrees in the (H, W) plane.
    4. Applies ``2x - 1``.

    It returns a C x H x W tensor on the CPU. ``space`` is
    ``spaces.Box(-1, 1, output_size, dtype)``.

    NOTE(review): the [-1, 1] range of ``space`` assumes input values in
    [0, 255]; confirm with the producer of ``x``.

    Args:
        output_size: shape of the returned ``spaces.Box``.
        dtype: dtype of the returned ``spaces.Box``.
    """
    def _thunk(obs_space):  # obs_space is not used
        def runner(x):
            with torch.no_grad():
                # input: H x W x C (single image, no batch dim)
                # output: C x H x W and normalized, ready to pass into net.forward
                assert x.shape[0] == x.shape[1], 'we are only using square data, data format: H,W,C'
                if isinstance(x, torch.Tensor):  # for training
                    x = torch.cuda.FloatTensor(x.cuda())
                else:  # for testing; the legacy constructor already allocates on the GPU
                    x = torch.cuda.FloatTensor(x.copy())

                x = x.unsqueeze(0)  # add a batch dim: 1 x H x W x C
                x = x.permute(0, 3, 1, 2) / 255.0  # 1 x C x H x W
                x = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
                # Face north (this could be computed using a different world2agent transform)
                x = torch.rot90(x, k=2, dims=(2, 3))
                x = 2.0 * x - 1.0
                return x.squeeze(0).cpu()  # drop the batch dim again

        return runner, spaces.Box(-1, 1, output_size, dtype)
    return _thunk
开发者ID:alexsax,项目名称:midlevel-reps,代码行数:27,代码来源:transforms.py

示例2: map_pool_collated

# 需要导入模块: import torch [as 别名]
# 或者: from torch import rot90 [as 别名]
def map_pool_collated(output_size, dtype=np.float32):
    """Build a max-pooling map transform for a batch of map images.

    ``_thunk(obs_space)`` returns ``(runner, space)``. ``runner`` takes a batch
    laid out as N x H x W x C with H == W (torch tensor or numpy array) and
    processes it as follows:

    1. Puts it on the GPU as float32 and divides it by 255.
    2. Applies a 3x3, stride-1, padding-1 max-pool.
    3. Turns it by 180 degrees in the (H, W) plane.
    4. Applies ``2x - 1``.

    The result is N x C x H x W and stays on the GPU. ``space`` is
    ``spaces.Box(-1, 1, output_size, dtype)``.
    """
    def _thunk(obs_space):
        def runner(x):
            with torch.no_grad():
                assert x.shape[2] == x.shape[1], 'we are only using square data, data format: N,H,W,C'
                if not isinstance(x, torch.Tensor):
                    # numpy input (testing path)
                    batch = torch.cuda.FloatTensor(x.copy()).cuda()
                else:
                    # tensor input (training path)
                    batch = torch.cuda.FloatTensor(x.cuda())
                channels_first = batch.permute(0, 3, 1, 2) / 255.0
                pooled = F.max_pool2d(channels_first, kernel_size=3, stride=1, padding=1)
                # Face north (this could be computed using a different world2agent transform)
                north_up = torch.rot90(pooled, k=2, dims=(2, 3))
                return 2.0 * north_up - 1.0

        return runner, spaces.Box(-1, 1, output_size, dtype)
    return _thunk
开发者ID:alexsax,项目名称:midlevel-reps,代码行数:22,代码来源:transforms.py

示例3: apply_tta

# 需要导入模块: import torch [as 别名]
# 或者: from torch import rot90 [as 别名]
def apply_tta(input):
    """Return the eight flip/rotation test-time-augmentation views of ``input``.

    Dims 2 and 3 are treated as the image plane (presumably NCHW input;
    confirm with the caller). The views are returned in this order:

    1. identity
    2. flip along dim 2
    3. flip along dim 3
    4-6. rotations by 90, 180 and 270 degrees
    7-8. the dim-2 flip rotated by 90 and by 270 degrees

    For a non-symmetric image these are the eight distinct symmetries of the
    square.
    """
    flipped = torch.flip(input, dims=[2])
    turned = [torch.rot90(input, k=quarter, dims=[2, 3]) for quarter in (1, 2, 3)]
    turned_flips = [torch.rot90(flipped, k=quarter, dims=[2, 3]) for quarter in (1, 3)]
    return [input, flipped, torch.flip(input, dims=[3]), *turned, *turned_flips]
开发者ID:4uiiurz1,项目名称:kaggle-aptos2019-blindness-detection,代码行数:13,代码来源:test.py

示例4: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import rot90 [as 别名]
def forward(self, x, y):
        """Apply one shared random rotation/flip to ``x`` and ``y``, then score them.

        When ``self.rotations`` is set, both tensors are rotated by the same
        ``k`` in {-1, 0, 1} quarter turns in the (2, 3) plane. When
        ``self.flips`` is set, both are flipped along dim 2 and then along dim 3,
        each with probability 1/2. Returns ``self.loss(x, y)`` on the
        transformed pair.
        """
        pair = (x, y)
        if self.rotations:
            turns = random.choice([-1, 0, 1])
            pair = tuple(torch.rot90(t, turns, [2, 3]) for t in pair)
        if self.flips:
            # One coin toss per axis, dim 2 first, so the RNG draw order is fixed.
            for axis in (2, 3):
                if random.choice([True, False]):
                    pair = tuple(torch.flip(t, (axis,)) for t in pair)
        return self.loss(*pair)
开发者ID:ManuelFritsche,项目名称:real-world-sr,代码行数:15,代码来源:loss.py

示例5: rot90

# 需要导入模块: import torch [as 别名]
# 或者: from torch import rot90 [as 别名]
def rot90(x, k=1):
    """Rotate a batch of images by ``k`` quarter turns.

    The rotation plane is dims 2 and 3 (H and W for NCHW input). A positive
    ``k`` turns from dim 2 toward dim 3, following the ``torch.rot90`` convention.
    """
    image_plane = (2, 3)
    return torch.rot90(x, k, image_plane)
开发者ID:qubvel,项目名称:ttach,代码行数:5,代码来源:functional.py

示例6: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import rot90 [as 别名]
def forward(self, batch, angle):
        """Rotate ``batch`` counterclockwise by ``angle`` degrees in the (2, 3) plane.

        ``angle`` is floor-divided by 90, so only whole quarter turns are applied
        (for example, 135 gives one quarter turn and -90 gives -1).
        """
        quarter_turns = angle // 90  # rotation is counterclockwise for k > 0
        return torch.rot90(batch, quarter_turns, (2, 3))
开发者ID:bonlime,项目名称:pytorch-tools,代码行数:6,代码来源:functional.py

示例7: transform

# 需要导入模块: import torch [as 别名]
# 或者: from torch import rot90 [as 别名]
def transform(self, x):
        """Rotate ``x`` by ``self.k`` quarter turns in its last two dims.

        The dims are given as ``(-1, -2)``, so a positive ``k`` turns from the
        last dim toward the second-to-last one. That is the opposite direction
        of ``torch.rot90(x, k, dims=(-2, -1))``.
        """
        turns = self.k
        last_two = (-1, -2)
        return torch.rot90(x, k=turns, dims=last_two)
开发者ID:szymonmaszke,项目名称:torchlayers,代码行数:4,代码来源:preprocessing.py

示例8: backward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import rot90 [as 别名]
def backward(self, inp):
        """Undo the augmentation by applying ``self.forward`` a second time.

        This is only an inverse when ``forward`` is its own inverse (an
        involution). NOTE(review): confirm this holds for the owning class.
        """
        undone = self.forward(inp)
        return undone


# TODO: add a Rot90Augment whose backward applies the inverse rotation (k -> -k), as sketched below.
# class Rot90Augment:
#     def __init__(self, k, dims):
#         self.k = k
#         self.dims = dims
#
#     def forward(self, inp):
#         return torch.rot90(inp, k=self.k, dims=self.dims)
#
#     def backward(self, inp):
#         return torch.rot90(inp, k=-self.k, dims=self.dims) 
开发者ID:ELEKTRONN,项目名称:elektronn3,代码行数:17,代码来源:inference.py

示例9: rot90

# 需要导入模块: import torch [as 别名]
# 或者: from torch import rot90 [as 别名]
def rot90(data: torch.Tensor, k: int, dims: Union[int, Sequence[int]]):
    """
    Rotate by 90 degrees ``k`` times in the plane spanned by two spatial dims

    Args:
        data: input data; ``dims`` are offset by 2, so the first two dims of
            ``data`` (presumably batch and channel) are never rotated
        k: number of times to rotate (converted with ``int``); positive
            values rotate from the first towards the second of ``dims``
        dims: the two spatial dimensions spanning the rotation plane, counted
            from the first spatial dim (``0`` refers to ``data`` dim 2)

    Returns:
        torch.Tensor: rotated tensor

    Raises:
        TypeError: if ``dims`` is a single int, because a rotation plane
            needs two dims
    """
    if isinstance(dims, int):
        raise TypeError(
            f"rot90 needs two spatial dims to define the rotation plane, got a single int: {dims}")
    dims = [int(d + 2) for d in dims]  # skip the two leading non-spatial dims
    return torch.rot90(data, int(k), dims)
开发者ID:PhoenixDL,项目名称:rising,代码行数:16,代码来源:spatial.py

示例10: train

# 需要导入模块: import torch [as 别名]
# 或者: from torch import rot90 [as 别名]
def train():
    """Run one training epoch with the auxiliary rotation-prediction task.

    Reads the module-level ``net``, ``train_loader``, ``adversary``,
    ``scheduler``, ``optimizer`` and ``state``. Stores the epoch's
    exponentially smoothed loss in ``state['train_loss']``.

    Each batch of size B is expanded to 5B images:
    - The first B are the original images, used for the classification loss.
    - The remaining 4B are 0/90/180/270-degree rotations of the batch, labelled
      0, 1, 2, 3 and used for the rotation-prediction loss.
    """
    net.train()  # enter train mode
    loss_avg = 0.0
    for bx, by in train_loader:
        curr_batch_size = bx.size(0)
        # Rotation labels for the 4B rotated copies: B zeros, B ones, B twos, B threes.
        by_prime = torch.cat((torch.zeros(bx.size(0)), torch.ones(bx.size(0)),
                              2*torch.ones(bx.size(0)), 3*torch.ones(bx.size(0))), 0).long()
        bx = bx.numpy()
        # use torch.rot90 in later versions of pytorch
        # (torch.rot90(t, k, dims=(2, 3)) uses the same direction convention as
        # np.rot90 and would avoid this NumPy round-trip).
        # The layout is [original, rot0, rot90, rot180, rot270], matching by_prime for the last 4B.
        # Rotating axes 2 and 3 assumes NCHW batches; confirm against train_loader.
        bx = np.concatenate((bx, bx, np.rot90(bx, 1, axes=(2, 3)),
                             np.rot90(bx, 2, axes=(2, 3)), np.rot90(bx, 3, axes=(2, 3))), 0)
        bx = torch.FloatTensor(bx)
        bx, by, by_prime = bx.cuda(), by.cuda(), by_prime.cuda()

        # NOTE(review): `adversary` is opaque here. Presumably it returns a perturbed
        # batch with the same shape as bx; confirm.
        adv_bx = adversary(net, bx, by, by_prime, curr_batch_size)

        # forward (`* 2 - 1` presumably maps inputs from [0, 1] to [-1, 1])
        logits, pen = net(adv_bx * 2 - 1)

        # backward
        # NOTE(review): since PyTorch 1.1, scheduler.step() is expected after
        # optimizer.step(). Calling it first skips the first value of the LR
        # schedule. Left as-is because reordering changes training.
        scheduler.step()
        optimizer.zero_grad()
        # Classification loss on the first B (unrotated) images ...
        loss = F.cross_entropy(logits[:curr_batch_size], by)
        # ... plus the rotation-prediction loss on the remaining 4B images, weighted 0.5.
        loss += 0.5 * F.cross_entropy(net.module.rot_pred(pen[curr_batch_size:]), by_prime)
        loss.backward()
        optimizer.step()

        # exponential moving average
        loss_avg = loss_avg * 0.9 + float(loss) * 0.1

    state['train_loss'] = loss_avg

# test function 
开发者ID:hendrycks,项目名称:ss-ood,代码行数:35,代码来源:train.py

示例11: get_random_augmenters

# 需要导入模块: import torch [as 别名]
# 或者: from torch import rot90 [as 别名]
def get_random_augmenters(
            ndim: int
    ) -> Tuple[Callable[[torch.Tensor], torch.Tensor], Callable[[torch.Tensor], torch.Tensor]]:
        """Produce a pair of functions ``augment, reverse_augment``, where
        the ``augment`` function applies a random augmentation to a torch
        tensor and the ``reverse_augment`` function performs the reverse
        augmentations if applicable (i.e. for geometrical transformations,
        so pixel-level loss calculation is still correct).

        The random parameters are drawn once, when this function is called, so
        both returned functions use the same rotation and flips.

        Args:
            ndim: number of dims of the tensors to augment. Dims 0 and 1 are
                treated as (batch, channel) and are never flipped or rotated.
                Rotation needs at least two spatial dims, so it is skipped
                for ``ndim < 4``.

        Returns:
            ``(augment, reverse_augment)``, such that
            ``reverse_augment(augment(x))`` reproduces ``x`` for a tensor with
            ``ndim`` dims.

        Note that all augmentations are performed on the compute device that
        holds the input, so generally on the GPU.
        """

        # Random rotation angle (in 90 degree steps). It is always drawn so that
        # the RNG stream does not depend on ndim.
        k90 = torch.randint(0, 4, ()).item()
        if ndim < 4:
            # With fewer than two spatial dims, rotating dims (-1, -2) would mix
            # the channel (or batch) dim into the rotation, so don't rotate.
            k90 = 0
        # Get a random selection of spatial dims (ranging from [] to [2, 3, ..., example.ndim - 1]
        flip_dims_binary = torch.randint(0, 2, (ndim - 2,))
        flip_dims = (torch.nonzero(flip_dims_binary, as_tuple=False).squeeze(1) + 2).tolist()

        @torch.no_grad()
        def augment(x: torch.Tensor) -> torch.Tensor:
            x = torch.rot90(x, +k90, (-1, -2))
            if len(flip_dims) > 0:
                x = torch.flip(x, flip_dims)

            # # Uncomment to enable additional random brightness and contrast augmentations
            # contrast_std = 0.1
            # brightness_std = 0.1
            # a = torch.randn(x.shape[:2], device=x.device, dtype=x.dtype) * contrast_std + 1.0
            # b = torch.randn(x.shape[:2], device=x.device, dtype=x.dtype) * brightness_std
            # for n in range(x.shape[0]):
            #     for c in range(x.shape[1]):
            #         # Formula based on tf.image.{adjust_contrast,adjust_brightness}
            #         # See https://www.tensorflow.org/api_docs/python/tf/image
            #         m = torch.mean(x[n, c])
            #         x[n, c] = a[n, c] * (x[n, c] - m) + m + b[n, c]

            # # Uncomment to enable additional additive gaussian noise augmentations
            # agn_std = 0.1
            # x.add_(torch.randn_like(x).mul_(agn_std))

            return x

        @torch.no_grad()
        def reverse_augment(x: torch.Tensor) -> torch.Tensor:
            # Undo in reverse order: first the flips (self-inverse), then the rotation.
            if len(flip_dims) > 0:  # Check is necessary only on cuda
                x = torch.flip(x, flip_dims)
            x = torch.rot90(x, -k90, (-1, -2))
            return x

        return augment, reverse_augment
开发者ID:ELEKTRONN,项目名称:elektronn3,代码行数:53,代码来源:loss.py

示例12: construct_occupancy_map

# 需要导入模块: import torch [as 别名]
# 或者: from torch import rot90 [as 别名]
def construct_occupancy_map(self) -> np.ndarray:
        """Render the visited positions and the point goal into an agent-centric map.

        Positions from ``self.history`` are made relative to the agent's current
        position from ``self.last_state``. They are then rotated by the heading
        change since ``self.init_heading`` and quantized with ``pos_to_map``.
        The goal comes from ``self.last_pointgoal`` and is not rotated.

        Returns:
            np.ndarray: uint8 array of shape (MAP_SIZE, MAP_SIZE, 3), with
            channels as follows:
            - Background cells are 128.
            - Channel 0 marks visited cells with 255.
            - Channel 1 marks the goal cell with 255.
            - Channel 2 stays at 128, because the agent-itself marker is
              commented out.
            If ``self.max_pool`` is set, the 255 marks are dilated by a 3x3
            max-pool. When ``self.history`` is empty an all-zero map is
            returned instead. Note that this is 0, not the 128 background.
        """
        # does not need second_last_pointgoal
        if len(self.history) == 0:
            return np.zeros((MAP_SIZE, MAP_SIZE, 3), dtype=np.uint8)

                
        # last_state: [:2] is the agent's polar position, [2] its heading.
        cur_agent_pos_polar, cur_agent_heading = self.last_state[:2], self.last_state[2]
        cur_agent_pos_xy = convert_polar_to_xy(cur_agent_pos_polar)

        # NOTE(review): np.array already copies self.history, so deepcopy is redundant.
        global_coords_polar = copy.deepcopy(np.array(self.history))[:,:2]  # throw away heading
        global_coords_xy = convert_polar_to_xy(global_coords_polar)


        # translate then rotate by negative angle only because we rotate everything by PI before return
        # rotation subtracts initial heading so that the initial agent always points 'north'
        # NOTE(review): the PI rotation before return (the torch.rot90 line further
        # down) is currently commented out, so the first comment above looks stale; confirm.
        agent_coords = global_coords_xy - cur_agent_pos_xy
        agent_coords = rotate(agent_coords, -1 * (cur_agent_heading - self.init_heading))
#         agent_coords = rotate(agent_coords, -1 * (np.pi + cur_agent_heading - self.init_heading))


        # calculate goal coordinates (independent of forward model)
        last_pointgoal_rotated = self.last_pointgoal #+ np.array([0, np.pi])
        goal = convert_polar_to_xy(last_pointgoal_rotated)
        goal_coords = np.array([goal])

        # quantize
        # Shift by half the building size so the agent sits at the map centre.
        # NOTE(review): pos_to_map is presumably metric -> integer cell indices. Cells
        # outside [0, MAP_SIZE) would raise IndexError below or wrap for negatives; confirm.
        visitation_cells = pos_to_map(agent_coords + self.max_building_size / 2, cell_size=self.cell_size)
        goal_cells = pos_to_map(goal_coords + self.max_building_size / 2, cell_size=self.cell_size)


        # plot (make ambient pixels 128 so that they are 0 when pass into nn)
        omap = torch.full((3, MAP_SIZE, MAP_SIZE), fill_value=128, dtype=torch.uint8, device=None, requires_grad=False) # Avoid multiplies, stack, and copying to torch
        omap[0][visitation_cells[:, 0], visitation_cells[:, 1]] = 255 # Agent visitation
        omap[1][goal_cells[:, 0], goal_cells[:, 1]] = 255 # Goal
        # omap[2][visitation_cells[-1][0], visitation_cells[-1][1]] = 255 # Agent itself

        # omap = np.rot90(omap, k=2, axes=(0,1))
        # WARNING: with code checkpoints, we need the map to be rotated
        # omap = torch.rot90(omap, k=2, dims=(1,2)) # Face north (this could be computed using a different world2agent transform)
        
        if self.max_pool:
            # 3x3, stride-1 max-pool keeps the size and spreads each 255 mark to its neighbours.
            omap = F.max_pool2d(omap.float(), kernel_size=3, stride=1, padding=1).byte()

        # (C, H, W) -> (H, W, C) numpy array
        omap = omap.permute(1, 2, 0).cpu().numpy()
        assert omap.dtype == np.uint8, f'Omap needs to be uint8, currently {omap.dtype}'
        return omap
开发者ID:alexsax,项目名称:midlevel-reps,代码行数:48,代码来源:occupancy_map.py


注:本文中的torch.rot90方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。