# dataloader.py
import json
import os
import pickle
import random
from typing import Union, List, Dict, Tuple, Callable, Optional

import numpy as np
import torch
import torch_geometric
from torch import Tensor
from torch.utils.data import DataLoader
from torch_geometric.data import Data, HeteroData
from torch_geometric.loader.utils import filter_data
from torch_geometric.sampler.utils import to_csc, to_hetero_csc
from torch_geometric.transforms import RandomNodeSplit
from torch_geometric.typing import NodeType, EdgeType
from torch_geometric.utils import subgraph
from tqdm import tqdm

# Change to your own data directory.
dir = '/data/zhihao/Bitcoin'
data_dir = '/data/zhihao/Bitcoin'

# Small I/O helpers for pickled / JSON artifacts under `data_dir`.
def load_pickle(fname):
    with open(os.path.join(data_dir, fname), 'rb') as f:
        return pickle.load(f)


def save_obj(obj, name):
    with open(os.path.join(data_dir, name) + '.pkl', 'wb') as f:
        pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)


def save_json(data, name):
    with open(os.path.join(data_dir, name) + '.json', 'w') as f:
        json.dump(data, f)


def load_json(fname):
    with open(os.path.join(data_dir, fname), 'r') as f:
        return json.load(f)
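
# Illustrative usage of the helpers (hypothetical file names, for example only):
#   save_json({'epoch': 1, 'auc': 0.93}, 'metrics')   # writes <data_dir>/metrics.json
#   metrics = load_json('metrics.json')               # reads it back
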
def subdata(data: torch_geometric.data.data.Data, subset, subedges=None, relabel_nodes=True):
    """Node-induced subgraph of `data`, optionally restricted to `subedges`.

    `subset` may be a boolean node mask or a list/tensor of node indices; if
    `subedges` (an edge mask or edge-index tensor) is given, edges are taken
    from it directly instead of by membership of both endpoints in `subset`.
    """
    device = data.edge_index.device
    num_nodes = data.num_nodes
    num_edges = data.edge_index.shape[1]
    if isinstance(subset, (list, tuple)):
        subset = torch.tensor(subset, dtype=torch.long, device=device)
    if subset.dtype == torch.bool or subset.dtype == torch.uint8:
        # Boolean node mask.
        node_mask = subset
        num_nodes = node_mask.size(0)
        if relabel_nodes:
            node_idx = torch.zeros(node_mask.size(0), dtype=torch.long,
                                   device=device)
            node_idx[subset] = torch.arange(subset.sum().item(), device=device)
    else:
        # Tensor of node indices.
        node_mask = torch.zeros(num_nodes, dtype=torch.bool, device=device)
        node_mask[subset] = 1
        if relabel_nodes:
            node_idx = torch.zeros(num_nodes, dtype=torch.long, device=device)
            node_idx[subset] = torch.arange(subset.size(0), device=device)
    sub_data = Data()
    if subedges is not None:
        if subedges.dtype == torch.bool:
            assert subedges.shape[0] == num_edges
        else:
            assert subedges.max() < num_edges
    # Copy node- and edge-level attributes, restricted to the subgraph.
    for key, item in data:
        if key in ['num_nodes', 'edge_index']:
            continue
        if isinstance(item, Tensor) and item.size(0) == num_nodes:
            sub_data[key] = item[subset]
        elif isinstance(item, Tensor) and item.size(0) == num_edges:
            if subedges is None:
                edge_index, sub_data[key] = subgraph(subset, data.edge_index, data[key],
                                                     relabel_nodes=relabel_nodes)
            else:
                sub_data[key] = item[subedges]
        else:
            sub_data[key] = item
    if subedges is None:
        sub_data.edge_index, _ = subgraph(subset, data.edge_index, relabel_nodes=relabel_nodes)
    else:
        edge_index = data.edge_index[:, subedges]
        if relabel_nodes:
            edge_index = node_idx[edge_index]
        sub_data.edge_index = edge_index
    return sub_data
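
# Illustrative sketch (not part of the original pipeline): the toy graph below
# is invented purely to show how `subdata` is typically called.
#
#   toy = Data(x=torch.randn(4, 3),
#              edge_index=torch.tensor([[0, 1, 2, 3],
#                                       [1, 2, 3, 0]]),
#              y=torch.tensor([0, 1, 0, 1]))
#   sub = subdata(toy, subset=[0, 1, 2], relabel_nodes=True)
#   # sub.x / sub.y keep only nodes 0-2; sub.edge_index keeps the edges whose
#   # endpoints both survive, relabeled into the range 0..2.
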
def load_data(data_path='./data/2015', use_unlabeled='SEMI', scale='minmax', graph_type='MultiDi',
              feature_type='edge', train_rate=0.5, anomaly_rate=None, random_state=5211):
    # Fix random seeds for reproducible splits.
    if random_state is not None:
        random.seed(random_state)
        np.random.seed(random_state)
        torch.manual_seed(random_state)
        torch.cuda.manual_seed_all(random_state)
    data = torch.load(data_path)
    if anomaly_rate:
        # Relabel a share of the positive nodes as unlabeled (-1) so that the
        # remaining positives match the requested anomaly rate.
        n_neg = (data.y == 0).sum().item()
        pos_ids = (data.y == 1).nonzero().view(-1).numpy()
        np.random.shuffle(pos_ids)
        drop_pos_ids = pos_ids[int(n_neg * anomaly_rate / (1 - anomaly_rate)):]
        data.y[drop_pos_ids] = -1
    # X = data.X
    labels = data.y  # node labels: 1 = anomalous, 0 = normal, -1 = unlabeled
    n_nodes = len(labels)
    all_id = np.arange(n_nodes)
    # `label_mask` marks which nodes are labeled (and hence attend loss
    # calculation) in the semi-supervised setting; the rest are unlabeled.
    if use_unlabeled == 'ALL':  # regard unlabeled nodes as normal users
        labels = np.where(labels == -1, 0, labels)
        label_mask = torch.ones(len(labels)).bool()
    elif use_unlabeled == 'NONE':  # drop unlabeled nodes entirely
        labels += 1
        nodes_id = labels.nonzero().reshape(-1)
        labels = labels[nodes_id]
        labels -= 1
        # Node features and other attributes are subset by `subdata` below.
        data = subdata(data, nodes_id, relabel_nodes=True)
        label_mask = torch.ones(len(labels)).bool()
    elif use_unlabeled == 'SEMI':  # keep unlabeled nodes but mask them out of the loss
        labels += 1
        label_id = labels.nonzero().reshape(-1)
        label_mask = torch.zeros(n_nodes)
        label_mask[label_id] = 1
        label_mask = label_mask.bool()
        labels -= 1
    n_nodes = len(labels)  # refresh n_nodes (it shrinks in the 'NONE' branch)
    all_id = np.arange(n_nodes)
    data['labels'] = torch.tensor(labels, dtype=int)
    # data.edge_index, data.edge_attr = add_self_loops(edge_index=data.edge_index, edge_attr=data.edge_attr, fill_value='mean')
    # Split data into train/val/test.
    np.random.shuffle(all_id)
    train_id = all_id[:int(n_nodes * train_rate)]
    # Validation and test split the remaining nodes in half.
    val_id = all_id[int(n_nodes * train_rate): int(n_nodes * (1 + train_rate) / 2)]
    test_id = all_id[int(n_nodes * (1 + train_rate) / 2):]
    train_mask = torch.zeros(n_nodes)
    train_mask[train_id] = 1
    train_mask = train_mask.bool()
    val_mask = torch.zeros(n_nodes)
    val_mask[val_id] = 1
    val_mask = val_mask.bool()
    test_mask = torch.zeros(n_nodes)
    test_mask[test_id] = 1
    test_mask = test_mask.bool()
    data.train_mask = train_mask
    data.test_mask = test_mask
    data.val_mask = val_mask
    # Restrict each split to nodes that carry a label (matters in the 'SEMI' setting).
    data['train_label'] = data.train_mask & label_mask
    data['val_label'] = data.val_mask & label_mask
    data['test_label'] = data.test_mask & label_mask
    return data, 2
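
# Minimal smoke test (a hedged sketch, not part of the original project): it
# builds a tiny synthetic graph with labels in {-1, 0, 1} (-1 = unlabeled),
# saves it to a temporary file, and runs `load_data` on it. The graph, path,
# and sizes are invented for illustration only. Note that recent PyTorch
# releases default `torch.load` to `weights_only=True`, which can reject
# pickled `Data` objects; this sketch assumes the older loading behaviour
# used by `load_data` above.
if __name__ == '__main__':
    import tempfile

    toy = Data(
        x=torch.randn(10, 4),                      # 10 nodes, 4 features each
        edge_index=torch.stack([torch.arange(10),  # a directed ring 0->1->...->9->0
                                torch.roll(torch.arange(10), -1)]),
        y=torch.tensor([1, 0, -1, 0, 1, -1, 0, 0, -1, 1]),
    )
    with tempfile.NamedTemporaryFile(suffix='.pt', delete=False) as f:
        torch.save(toy, f.name)
    graph, n_classes = load_data(data_path=f.name, use_unlabeled='SEMI', train_rate=0.5)
    print(graph, n_classes)
    print('train/val/test sizes:',
          graph.train_mask.sum().item(),
          graph.val_mask.sum().item(),
          graph.test_mask.sum().item())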