warnings.simplefilter(action='ignore', category=NumbaDeprecationWarning)


-def load_data() -> dict:
-
-
+def load_data(dataset_name=args.dataset_name) -> dict:
    """
    :return: dict that contains the following fields
    z: np.ndarray of shape (num_nodes, num_timesteps, num_dims): node embeddings
@@ -24,16 +22,16 @@ def load_data() -> dict:
    """

    config = json.load(
-        open(osp.join("config", f"{args.dataset_name}.json"), 'r',
+        open(osp.join("config", f"{dataset_name}.json"), 'r',
             encoding='utf-8'))

    z = np.load(
-        osp.join("data", args.dataset_name, f"embeds_{args.dataset_name}.npy"))
+        osp.join("data", dataset_name, f"embeds_{dataset_name}.npy"))
    node2idx = json.load(
-        open(osp.join("data", args.dataset_name, "node2idx.json"), 'r',
+        open(osp.join("data", dataset_name, "node2idx.json"), 'r',
             encoding='utf-8'))
    perplexity = config["perplexity"]
-
+    model_name = config["model_name"]
    idx_reference_snapshot = config["idx_reference_snapshot"]

    # Optional argument
@@ -50,7 +48,7 @@ def load_data() -> dict:

    try:
        node2label = json.load(
-            open(osp.join("data", args.dataset_name, "node2label.json"), 'r',
+            open(osp.join("data", dataset_name, "node2label.json"), 'r',
                 encoding='utf-8'))


@@ -61,7 +59,7 @@ def load_data() -> dict:
    if isinstance(config['reference_nodes'], str) and config[
            'reference_nodes'].endswith("json"):
        reference_nodes = json.load(
-            open(osp.join("data", args.dataset_name, config['reference_nodes']),
+            open(osp.join("data", dataset_name, config['reference_nodes']),
                 'r', encoding='utf-8'))

    elif isinstance(config['reference_nodes'], list):
@@ -73,7 +71,7 @@ def load_data() -> dict:
    if isinstance(config['projected_nodes'], str) and config[
            'projected_nodes'].endswith("json"):
        projected_nodes = json.load(
-            open(osp.join("data", args.dataset_name, config['projected_nodes']),
+            open(osp.join("data", dataset_name, config['projected_nodes']),
                 'r', encoding='utf-8'))

    elif isinstance(config['projected_nodes'], list):
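For orientation, the config fields consumed above imply a per-dataset JSON file of roughly the following shape. This is a hypothetical sketch: every value is illustrative, and reference_nodes / projected_nodes may each be either an inline list of node names or the filename of a JSON list under data/<dataset_name>/:

    {
        "perplexity": 30,
        "model_name": "GConvGRU",
        "idx_reference_snapshot": 0,
        "reference_nodes": ["BUDAPEST", "PEST", "ZALA"],
        "projected_nodes": "projected_nodes.json"
    }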
@@ -92,9 +90,7 @@ def load_data() -> dict:

    highlighted_nodes = []

-
-
-    if args.dataset_name == "Chickenpox":
+    if dataset_name == "Chickenpox":
        highlighted_nodes = np.array(
            ["BUDAPEST", "PEST", "BORSOD", "ZALA", "NOGRAD", "TOLNA", "VAS"])

@@ -109,7 +105,7 @@ def load_data() -> dict:
        snapshot_names = snapshot_names[0:100]

        weekly_cases = pd.read_csv(
-            osp.join("data", args.dataset_name, "hungary_chickenpox.csv"))
+            osp.join("data", dataset_name, "hungary_chickenpox.csv"))

        ys = weekly_cases[reference_nodes].values

@@ -119,26 +115,32 @@ def load_data() -> dict:
        metadata_df["Country"] = "Hungary"


-    elif args.dataset_name == "DGraphFin":
+    elif dataset_name == "DGraphFin":
        plot_anomaly_labels = True
        # Eliminate background nodes
        node2label = {n: l for n, l in node2label.items() if l in [0, 1]}


-    elif args.dataset_name == "BMCBioinformatics2021":
+
+    # elif dataset_name == "Reddit":
+    #     # 2018-01, ..., 2022-12
+    #     snapshot_names = const.dataset_names_60_months
+
+
+    elif dataset_name == "BMCBioinformatics2021":
+

        plot_anomaly_labels = True

        metadata_df = pd.read_excel(
-            osp.join("data", args.dataset_name, "metadata.xlsx"))
+            osp.join("data", dataset_name, "metadata.xlsx"))
        metadata_df = metadata_df.rename(columns={"entrez": "node"})
        metadata_df = metadata_df.astype({"node": str})
        node2idx = {str(k): v for k, v in node2idx.items()}

        metadata_df = metadata_df.drop(
            columns=["summary", "lineage", "gene_type"])

-
-    elif args.dataset_name == "HistWords-CN-GNN":
+    elif dataset_name == "HistWords-CN-GNN":

        metadata_df = pd.read_csv(
@@ -161,14 +163,16 @@ def load_data() -> dict:
    if node_presence is None:
        try:
            node_presence = np.load(
-                osp.join("data", args.dataset_name, "node_presence.npy"))
+                osp.join("data", dataset_name, "node_presence.npy"))

        except FileNotFoundError:
            print(
                "node_presence.npy not found. Assuming all nodes are present at all timesteps.")
            node_presence = np.ones((z.shape[0], z.shape[1]), dtype=bool)

    return {
+        "dataset_name": dataset_name,
+        "model_name": model_name,
        "display_node_type": display_node_type,
        "highlighted_nodes": highlighted_nodes,
        "idx_reference_snapshot": idx_reference_snapshot,