当前位置: 首页>>代码示例>>Python>>正文


Python transforms.Compose方法代码示例

本文整理汇总了Python中torch_geometric.transforms.Compose方法的典型用法代码示例。如果您正苦于以下问题:Python transforms.Compose方法的具体用法?Python transforms.Compose怎么用?Python transforms.Compose使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在torch_geometric.transforms的用法示例。


在下文中一共展示了transforms.Compose方法的4个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: test_compose

# 需要导入模块: from torch_geometric import transforms [as 别名]
# 或者: from torch_geometric.transforms import Compose [as 别名]
def test_compose():
    transform = T.Compose([T.Center(), T.AddSelfLoops()])
    assert transform.__repr__() == ('Compose([\n'
                                    '    Center(),\n'
                                    '    AddSelfLoops(),\n'
                                    '])')

    pos = torch.Tensor([[0, 0], [2, 0], [4, 0]])
    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])

    data = Data(edge_index=edge_index, pos=pos)
    data = transform(data)
    assert len(data) == 2
    assert data.pos.tolist() == [[-2, 0], [0, 0], [2, 0]]
    assert data.edge_index.tolist() == [[0, 0, 1, 1, 1, 2, 2],
                                        [0, 1, 0, 1, 2, 1, 2]] 
开发者ID:rusty1s,项目名称:pytorch_geometric,代码行数:18,代码来源:test_compose.py

示例2: get_planetoid_dataset

# 需要导入模块: from torch_geometric import transforms [as 别名]
# 或者: from torch_geometric.transforms import Compose [as 别名]
def get_planetoid_dataset(name, normalize_features=False, transform=None):
    path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', name)
    dataset = Planetoid(path, name)

    if transform is not None and normalize_features:
        dataset.transform = T.Compose([T.NormalizeFeatures(), transform])
    elif normalize_features:
        dataset.transform = T.NormalizeFeatures()
    elif transform is not None:
        dataset.transform = transform

    return dataset 
开发者ID:rusty1s,项目名称:pytorch_geometric,代码行数:14,代码来源:datasets.py

示例3: __init__

# 需要导入模块: from torch_geometric import transforms [as 别名]
# 或者: from torch_geometric.transforms import Compose [as 别名]
def __init__(self):
        dataset = "QM9"
        path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset)

        target=0
        class MyTransform(object):
            def __call__(self, data):
                # Specify target.
                data.y = data.y[:, target]
                return data

        class Complete(object):
            def __call__(self, data):
                device = data.edge_index.device
                row = torch.arange(data.num_nodes, dtype=torch.long, device=device)
                col = torch.arange(data.num_nodes, dtype=torch.long, device=device)
                row = row.view(-1, 1).repeat(1, data.num_nodes).view(-1)
                col = col.repeat(data.num_nodes)
                edge_index = torch.stack([row, col], dim=0)
                edge_attr = None
                if data.edge_attr is not None:
                    idx = data.edge_index[0] * data.num_nodes + data.edge_index[1]
                    size = list(data.edge_attr.size())
                    size[0] = data.num_nodes * data.num_nodes
                    edge_attr = data.edge_attr.new_zeros(size)
                    edge_attr[idx] = data.edge_attr
                edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)
                data.edge_attr = edge_attr
                data.edge_index = edge_index
                return data

        transform = T.Compose([MyTransform(), Complete(), T.Distance(norm=False)])
        if not osp.exists(path):
            QM9(path)
        super(QM9Dataset, self).__init__(path) 
开发者ID:THUDM,项目名称:cogdl,代码行数:37,代码来源:pyg.py

示例4: get_dataset

# 需要导入模块: from torch_geometric import transforms [as 别名]
# 或者: from torch_geometric.transforms import Compose [as 别名]
def get_dataset(name, sparse=True, cleaned=False):
    path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', name)
    dataset = TUDataset(path, name, cleaned=cleaned)
    dataset.data.edge_attr = None

    if dataset.data.x is None:
        max_degree = 0
        degs = []
        for data in dataset:
            degs += [degree(data.edge_index[0], dtype=torch.long)]
            max_degree = max(max_degree, degs[-1].max().item())

        if max_degree < 1000:
            dataset.transform = T.OneHotDegree(max_degree)
        else:
            deg = torch.cat(degs, dim=0).to(torch.float)
            mean, std = deg.mean().item(), deg.std().item()
            dataset.transform = NormalizedDegree(mean, std)

    if not sparse:
        num_nodes = max_num_nodes = 0
        for data in dataset:
            num_nodes += data.num_nodes
            max_num_nodes = max(data.num_nodes, max_num_nodes)

        # Filter out a few really large graphs in order to apply DiffPool.
        if name == 'REDDIT-BINARY':
            num_nodes = min(int(num_nodes / len(dataset) * 1.5), max_num_nodes)
        else:
            num_nodes = min(int(num_nodes / len(dataset) * 5), max_num_nodes)

        indices = []
        for i, data in enumerate(dataset):
            if data.num_nodes <= num_nodes:
                indices.append(i)
        dataset = dataset[torch.tensor(indices)]

        if dataset.transform is None:
            dataset.transform = T.ToDense(num_nodes)
        else:
            dataset.transform = T.Compose(
                [dataset.transform, T.ToDense(num_nodes)])

    return dataset 
开发者ID:rusty1s,项目名称:pytorch_geometric,代码行数:46,代码来源:datasets.py


注:本文中的torch_geometric.transforms.Compose方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。