本文整理汇总了Python中torch_geometric.data.InMemoryDataset方法的典型用法代码示例。如果您正苦于以下问题:Python data.InMemoryDataset方法的具体用法?Python data.InMemoryDataset怎么用?Python data.InMemoryDataset使用的例子?那么恭喜您,这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类 torch_geometric.data 的用法示例。
在下文中一共展示了data.InMemoryDataset方法的7个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: test_in_memory_dataset
# 需要导入模块: from torch_geometric import data [as 别名]
# 或者: from torch_geometric.data import InMemoryDataset [as 别名]
def test_in_memory_dataset():
    """Build a tiny two-graph InMemoryDataset and check len/indexing behaviour."""

    class TestDataset(InMemoryDataset):
        # Minimal concrete dataset: eagerly collate the given graph list.
        def __init__(self, data_list):
            super(TestDataset, self).__init__('/tmp/TestDataset')
            self.data, self.slices = self.collate(data_list)

    # Shared per-graph attributes: 3 node features, 4 directed edges, one face,
    # plus one int and one str extra attribute.
    x = torch.Tensor([[1], [1], [1]])
    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
    face = torch.tensor([[0], [1], [2]])
    test_int = 1
    test_str = '1'

    first = Data(x=x, edge_index=edge_index, face=face, test_int=test_int,
                 test_str=test_str)
    first.num_nodes = 10  # explicit override, not inferred from x
    second = Data(x=x, edge_index=edge_index, face=face, test_int=test_int,
                  test_str=test_str)
    second.num_nodes = 5

    dataset = TestDataset([first, second])
    assert len(dataset) == 2
    # num_nodes must round-trip through collate; each graph keeps 5 keys.
    assert dataset[0].num_nodes == 10
    assert len(dataset[0]) == 5
    assert dataset[1].num_nodes == 5
    assert len(dataset[1]) == 5
示例2: __init__
# 需要导入模块: from torch_geometric import data [as 别名]
# 或者: from torch_geometric.data import InMemoryDataset [as 别名]
def __init__(self,
             dataset: InMemoryDataset,
             hidden: List[int] = [64],
             dropout: float = 0.5):
    """Stack GCNConv layers sized input-features -> hidden... -> num_classes.

    NOTE(review): ``hidden`` has a mutable list default; it is only read
    (via list concatenation) and never mutated, so the shared-default
    pitfall does not bite here — but confirm before mutating it.
    """
    super(GCN, self).__init__()
    # Layer widths: dataset feature dim, the hidden sizes, then the classes.
    widths = [dataset.data.x.shape[1]] + hidden + [dataset.num_classes]
    convs = [GCNConv(n_in, n_out) for n_in, n_out in zip(widths, widths[1:])]
    self.layers = ModuleList(convs)
    # Only the first layer's weights are treated as regularised parameters.
    self.reg_params = list(convs[0].parameters())
    self.non_reg_params = [p for conv in convs[1:] for p in conv.parameters()]
    self.dropout = Dropout(p=dropout)
    self.act_fn = ReLU()
示例3: __init__
# 需要导入模块: from torch_geometric import data [as 别名]
# 或者: from torch_geometric.data import InMemoryDataset [as 别名]
def __init__(self, args):
    """Set up the dataset, dense label matrix, and model from ``args``."""
    super(UnsupervisedNodeClassification, self).__init__(args)

    dataset = build_dataset(args)
    self.data = dataset[0]

    # PyG-style datasets (first base class derives from InMemoryDataset)
    # carry integer class ids in data.y; densify them into a one-hot matrix.
    if issubclass(dataset.__class__.__bases__[0], InMemoryDataset):
        self.num_nodes = self.data.y.shape[0]
        self.num_classes = dataset.num_classes
        one_hot = np.zeros((self.num_nodes, self.num_classes), dtype=int)
        one_hot[range(self.num_nodes), self.data.y] = 1
        self.label_matrix = one_hot
        self.data.edge_attr = self.data.edge_attr.t()
    else:
        # Otherwise y is assumed to already be a (num_nodes, num_classes)
        # label matrix — TODO confirm against build_dataset's other outputs.
        self.label_matrix = self.data.y
        self.num_nodes, self.num_classes = self.data.y.shape

    self.model = build_model(args)
    self.model_name = args.model
    self.hidden_size = args.hidden_size
    self.num_shuffle = args.num_shuffle
    self.save_dir = args.save_dir
    self.enhance = args.enhance
    self.args = args
    self.is_weighted = self.data.edge_attr is not None
示例4: get_dataset
# 需要导入模块: from torch_geometric import data [as 别名]
# 或者: from torch_geometric.data import InMemoryDataset [as 别名]
def get_dataset(name: str, use_lcc: bool = True) -> InMemoryDataset:
    """Load a benchmark dataset by name, optionally restricted to its
    largest connected component (LCC).

    Args:
        name: one of 'Cora', 'Citeseer', 'Pubmed', 'Computers', 'Photo',
            'CoauthorCS'.
        use_lcc: if True, rewrite ``dataset.data`` so it contains only the
            LCC, with edges re-indexed via ``remap_edges`` and fresh
            all-False train/val/test masks.

    Returns:
        The loaded (and possibly LCC-filtered) dataset.

    Raises:
        ValueError: if ``name`` is not a known dataset.
    """
    path = os.path.join(DATA_PATH, name)
    if name in ['Cora', 'Citeseer', 'Pubmed']:
        dataset = Planetoid(path, name)
    elif name in ['Computers', 'Photo']:
        dataset = Amazon(path, name)
    elif name == 'CoauthorCS':
        dataset = Coauthor(path, 'CS')
    else:
        raise ValueError('Unknown dataset.')
    if use_lcc:
        lcc = get_largest_connected_component(dataset)
        x_new = dataset.data.x[lcc]
        y_new = dataset.data.y[lcc]
        row, col = dataset.data.edge_index.numpy()
        # Membership tests against an ndarray are O(|lcc|) each; build a set
        # once so filtering the edges is O(E) instead of O(E * |lcc|).
        lcc_set = set(lcc)
        edges = [[i, j] for i, j in zip(row, col)
                 if i in lcc_set and j in lcc_set]
        edges = remap_edges(edges, get_node_mapper(lcc))
        num_nodes = y_new.size()[0]
        data = Data(
            x=x_new,
            edge_index=torch.LongTensor(edges),
            y=y_new,
            train_mask=torch.zeros(num_nodes, dtype=torch.bool),
            test_mask=torch.zeros(num_nodes, dtype=torch.bool),
            val_mask=torch.zeros(num_nodes, dtype=torch.bool),
        )
        dataset.data = data
    return dataset
示例5: get_component
# 需要导入模块: from torch_geometric import data [as 别名]
# 或者: from torch_geometric.data import InMemoryDataset [as 别名]
def get_component(dataset: InMemoryDataset, start: int = 0) -> set:
    """Return the set of nodes reachable from ``start`` by repeatedly
    following edges of ``dataset.data.edge_index`` (forward direction).

    Args:
        dataset: object whose ``data.edge_index`` tensor holds the edges
            as a (2, E) source/target array.
        start: node id at which the traversal starts.

    Returns:
        The set of node ids in ``start``'s forward-reachable component.
    """
    from collections import defaultdict

    row, col = dataset.data.edge_index.numpy()
    # Build the adjacency list once — O(V + E) overall instead of an
    # O(E) np.where scan for every single visited node.
    neighbors_of = defaultdict(list)
    for src, dst in zip(row, col):
        neighbors_of[src].append(dst)

    visited_nodes = set()
    queued_nodes = {start}
    while queued_nodes:
        current_node = queued_nodes.pop()
        visited_nodes.add(current_node)
        queued_nodes.update(
            n for n in neighbors_of[current_node] if n not in visited_nodes
        )
    return visited_nodes
示例6: get_largest_connected_component
# 需要导入模块: from torch_geometric import data [as 别名]
# 或者: from torch_geometric.data import InMemoryDataset [as 别名]
def get_largest_connected_component(dataset: InMemoryDataset) -> np.ndarray:
    """Return the node ids of the dataset's largest connected component.

    Peels off the component containing the smallest still-unassigned node
    until every node belongs to some component, then keeps the biggest one
    (ties broken by discovery order, matching np.argmax semantics).
    """
    unassigned = set(range(dataset.data.x.shape[0]))
    components = []
    while unassigned:
        component = get_component(dataset, min(unassigned))
        components.append(component)
        unassigned -= component
    largest = max(components, key=len)
    return np.array(list(largest))
示例7: get_adj_matrix
# 需要导入模块: from torch_geometric import data [as 别名]
# 或者: from torch_geometric.data import InMemoryDataset [as 别名]
def get_adj_matrix(dataset: InMemoryDataset) -> np.ndarray:
    """Return the dense (num_nodes, num_nodes) adjacency matrix.

    Entry [i, j] is 1.0 iff the edge (i, j) appears in
    ``dataset.data.edge_index``; all other entries are 0.0. Duplicate
    edges simply re-set the same cell, matching the original loop.
    """
    num_nodes = dataset.data.x.shape[0]
    adj_matrix = np.zeros(shape=(num_nodes, num_nodes))
    # Vectorized scatter via integer-array indexing — one C-level pass
    # instead of a Python loop over every edge.
    src, dst = dataset.data.edge_index.numpy()
    adj_matrix[src, dst] = 1.
    return adj_matrix