Skip to content

Commit 0006d1b

Browse files
committed
2 parents 4b46a8d + 0f6a48d commit 0006d1b

File tree

4 files changed

+644
-21
lines changed

4 files changed

+644
-21
lines changed

dygetviz/arguments.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,7 @@
4343
help="Comment for each run. Useful for identifying each run on Tensorboard")
4444
parser.add_argument('--data_dir', type=str, default="data",
4545
help="Location to store all the data.")
46-
parser.add_argument('--dataset_name', type=str, help="")
46+
parser.add_argument('--dataset_name', type=str, default='Chickenpox', help="Name of dataset.")
4747
parser.add_argument('--device', type=str, default=DEVICE, help="Device to use. When using multi-gpu, this is the 'master' device where all operations are performed.")
4848
parser.add_argument('--device2', type=str, default='cpu',
4949
help="For Multi-GPU training")

dygetviz/data/dataloader.py

Lines changed: 23 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -12,9 +12,7 @@
1212
warnings.simplefilter(action='ignore', category=NumbaDeprecationWarning)
1313

1414

15-
def load_data() -> dict:
16-
17-
15+
def load_data(dataset_name=args.dataset_name) -> dict:
1816
"""
1917
:return: dict that contains the following fields
2018
z: np.ndarray of shape (num_nodes, num_timesteps, num_dims): node embeddings
@@ -24,16 +22,16 @@ def load_data() -> dict:
2422
"""
2523

2624
config = json.load(
27-
open(osp.join("config", f"{args.dataset_name}.json"), 'r',
25+
open(osp.join("config", f"{dataset_name}.json"), 'r',
2826
encoding='utf-8'))
2927

3028
z = np.load(
31-
osp.join("data", args.dataset_name, f"embeds_{args.dataset_name}.npy"))
29+
osp.join("data", dataset_name, f"embeds_{dataset_name}.npy"))
3230
node2idx = json.load(
33-
open(osp.join("data", args.dataset_name, "node2idx.json"), 'r',
31+
open(osp.join("data", dataset_name, "node2idx.json"), 'r',
3432
encoding='utf-8'))
3533
perplexity = config["perplexity"]
36-
34+
model_name = config["model_name"]
3735
idx_reference_snapshot = config["idx_reference_snapshot"]
3836

3937
# Optional argument
@@ -50,7 +48,7 @@ def load_data() -> dict:
5048

5149
try:
5250
node2label = json.load(
53-
open(osp.join("data", args.dataset_name, "node2label.json"), 'r',
51+
open(osp.join("data", dataset_name, "node2label.json"), 'r',
5452
encoding='utf-8'))
5553

5654

@@ -61,7 +59,7 @@ def load_data() -> dict:
6159
if isinstance(config['reference_nodes'], str) and config[
6260
'reference_nodes'].endswith("json"):
6361
reference_nodes = json.load(
64-
open(osp.join("data", args.dataset_name, config['reference_nodes']),
62+
open(osp.join("data", dataset_name, config['reference_nodes']),
6563
'r', encoding='utf-8'))
6664

6765
elif isinstance(config['reference_nodes'], list):
@@ -73,7 +71,7 @@ def load_data() -> dict:
7371
if isinstance(config['projected_nodes'], str) and config[
7472
'projected_nodes'].endswith("json"):
7573
projected_nodes = json.load(
76-
open(osp.join("data", args.dataset_name, config['projected_nodes']),
74+
open(osp.join("data", dataset_name, config['projected_nodes']),
7775
'r', encoding='utf-8'))
7876

7977
elif isinstance(config['projected_nodes'], list):
@@ -92,9 +90,7 @@ def load_data() -> dict:
9290

9391
highlighted_nodes = []
9492

95-
96-
97-
if args.dataset_name == "Chickenpox":
93+
if dataset_name == "Chickenpox":
9894
highlighted_nodes = np.array(
9995
["BUDAPEST", "PEST", "BORSOD", "ZALA", "NOGRAD", "TOLNA", "VAS"])
10096

@@ -109,7 +105,7 @@ def load_data() -> dict:
109105
snapshot_names = snapshot_names[0:100]
110106

111107
weekly_cases = pd.read_csv(
112-
osp.join("data", args.dataset_name, "hungary_chickenpox.csv"))
108+
osp.join("data", dataset_name, "hungary_chickenpox.csv"))
113109

114110
ys = weekly_cases[reference_nodes].values
115111

@@ -119,26 +115,32 @@ def load_data() -> dict:
119115
metadata_df["Country"] = "Hungary"
120116

121117

122-
elif args.dataset_name == "DGraphFin":
118+
elif dataset_name == "DGraphFin":
123119
plot_anomaly_labels = True
124120
# Eliminate background nodes
125121
node2label = {n: l for n, l in node2label.items() if l in [0, 1]}
126122

127123

128-
elif args.dataset_name == "BMCBioinformatics2021":
124+
125+
# elif dataset_name == "Reddit":
126+
# # 2018-01, ..., 2022-12
127+
# snapshot_names = const.dataset_names_60_months
128+
129+
130+
elif dataset_name == "BMCBioinformatics2021":
131+
129132

130133
plot_anomaly_labels = True
131134

132135
metadata_df = pd.read_excel(
133-
osp.join("data", args.dataset_name, "metadata.xlsx"))
136+
osp.join("data", dataset_name, "metadata.xlsx"))
134137
metadata_df = metadata_df.rename(columns={"entrez": "node"})
135138
metadata_df = metadata_df.astype({"node": str})
136139
node2idx = {str(k): v for k, v in node2idx.items()}
137140

138141
metadata_df = metadata_df.drop(
139142
columns=["summary", "lineage", "gene_type"])
140143

141-
142144
elif args.dataset_name == "HistWords-CN-GNN":
143145

144146
metadata_df = pd.read_csv(
@@ -161,14 +163,16 @@ def load_data() -> dict:
161163
if node_presence is None:
162164
try:
163165
node_presence = np.load(
164-
osp.join("data", args.dataset_name, "node_presence.npy"))
166+
osp.join("data", dataset_name, "node_presence.npy"))
165167

166168
except FileNotFoundError:
167169
print(
168170
"node_presence.npy not found. Assuming all nodes are present at all timesteps.")
169171
node_presence = np.ones((z.shape[0], z.shape[1]), dtype=bool)
170172

171173
return {
174+
"dataset_name": dataset_name,
175+
"model_name": model_name,
172176
"display_node_type": display_node_type,
173177
"highlighted_nodes": highlighted_nodes,
174178
"idx_reference_snapshot": idx_reference_snapshot,

0 commit comments

Comments
 (0)