Skip to content

Commit 6b6374c

Browse files
committed
Train on Chickenpox Dataset
Add data download Load from training config
1 parent 0006d1b commit 6b6374c

14 files changed

+820
-42
lines changed

config/Ant.json

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
{
2+
"embedding_dim": 128,
3+
"idx_reference_snapshot": 7,
4+
"interpolation": 0.2,
5+
"model_name": "GConvGRU",
6+
"num_nearest_neighbors": [20],
7+
"perplexity": 2,
8+
"projected_nodes": "projected_nodes.json",
9+
"reference_nodes": "reference_nodes.json"
10+
}

config/Chickenpox.json

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,18 @@
1010
"perplexity": 2,
1111
"reference_nodes": ["BACS", "BARANYA", "BEKES", "BORSOD", "BUDAPEST", "CSONGRAD",
1212
"FEJER", "GYOR", "HAJDU", "HEVES", "JASZ", "KOMAROM", "NOGRAD",
13-
"PEST", "SOMOGY", "SZABOLCS", "TOLNA", "VAS", "VESZPREM", "ZALA"]
13+
"PEST", "SOMOGY", "SZABOLCS", "TOLNA", "VAS", "VESZPREM", "ZALA"],
14+
15+
16+
"do_node_classification": false,
17+
"do_node_regression": true,
18+
"do_edge_classification": false,
19+
"do_edge_regression": false,
20+
"num_classes_nodes": 0,
21+
"num_classes_edges": 0,
22+
"do_link_prediction": true,
23+
24+
"tasks": ["node_regression","link_pred"]
25+
26+
1427
}

dygetviz/arguments.py

Lines changed: 53 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,30 @@
11
import argparse
22
import os
33
import os.path as osp
4-
import uuid
5-
from pprint import pprint
64

75
import const
86
from const import *
97

10-
if platform.system() == "Windows":
11-
DEVICE = "cuda:0"
8+
if platform.system() in ["Windows", "Linux"]:
9+
import torch
1210

13-
elif platform.system() == "Linux":
14-
DEVICE = "cuda:0"
11+
if torch.cuda.is_available():
12+
DEFAULT_DEVICE = "cuda:0"
13+
else:
14+
DEFAULT_DEVICE = "cpu"
1515

1616

1717
elif platform.system() == "Darwin":
18-
DEVICE = "mps:0"
18+
DEFAULT_DEVICE = "mps:0"
1919

2020

2121
else:
2222
raise NotImplementedError("Unknown System")
2323

24-
print(f"Your system: {platform.system()}. Default device: {DEVICE}")
24+
print(f"Your system: {platform.system()}. Default device: {DEFAULT_DEVICE}")
2525

26-
27-
28-
29-
parser = argparse.ArgumentParser(description="Dynamic Graph Embedding Trajectory.")
26+
parser = argparse.ArgumentParser(
27+
description="Dynamic Graph Embedding Trajectory.")
3028
# Parameters for Analysis
3129
parser.add_argument('--do_visual', action='store_true',
3230
help="Whether to do visualization")
@@ -43,22 +41,25 @@
4341
help="Comment for each run. Useful for identifying each run on Tensorboard")
4442
parser.add_argument('--data_dir', type=str, default="data",
4543
help="Location to store all the data.")
46-
parser.add_argument('--dataset_name', type=str, default='Chickenpox', help="Name of dataset.")
47-
parser.add_argument('--device', type=str, default=DEVICE, help="Device to use. When using multi-gpu, this is the 'master' device where all operations are performed.")
48-
parser.add_argument('--device2', type=str, default='cpu',
49-
help="For Multi-GPU training")
44+
parser.add_argument('--dataset_name', type=str, default='Chickenpox',
45+
help="Name of dataset.")
46+
parser.add_argument('--device', type=str, default=DEFAULT_DEVICE,
47+
help="Device to use. When using multi-gpu, this is the 'master' device where all operations are performed.")
5048

5149
parser.add_argument('--do_test', action='store_true')
5250
parser.add_argument('--do_val', action='store_true')
53-
parser.add_argument('--do_weighted', action='store_true', help="Construct weighted graph instead of multigraph for each graph snapshot")
51+
parser.add_argument('--do_weighted', action='store_true',
52+
help="Construct weighted graph instead of multigraph for each graph snapshot")
5453

5554
parser.add_argument('--dropout', type=float, default=0.1,
5655
help="Dropout rate (1 - keep probability).")
5756

58-
59-
parser.add_argument('--embedding_dim', type=int, default=64, help="the embedding size of model")
60-
parser.add_argument('--embedding_dim_user', type=int, default=32, help="The embedding size for the users")
61-
parser.add_argument('--embedding_dim_resource', type=int, default=32, help="The embedding size for the resource (e.g. video)")
57+
parser.add_argument('--embedding_dim', type=int, default=64,
58+
help="the embedding size of model")
59+
parser.add_argument('--embedding_dim_user', type=int, default=32,
60+
help="The embedding size for the users")
61+
parser.add_argument('--embedding_dim_resource', type=int, default=32,
62+
help="The embedding size for the resource (e.g. video)")
6263

6364
parser.add_argument('--epochs', type=int, default=50,
6465
help="Number of epochs to train.")
@@ -85,24 +86,30 @@
8586

8687
parser.add_argument('--i_end', type=int, default=None,
8788
help="Index of the end dataset.")
88-
89+
parser.add_argument('--in_channels', type=int, default=None,
90+
help="Index of the end dataset.")
8991

9092
parser.add_argument('--lr', type=float, default=1e-3, help="Learning rate")
9193
parser.add_argument('--max_seq_length', type=int, default=128,
9294
help="Maximum sequence length")
9395

94-
parser.add_argument('--model', type=str, default=None, help="Model Name")
96+
parser.add_argument('--model', type=str, default=None, help="Model name")
9597

96-
parser.add_argument('--node_types', type=str, choices=["v_subreddit", "author_subreddit", "author_resource"], default="v_subreddit",
98+
parser.add_argument('--node_types', type=str,
99+
choices=["v_subreddit", "author_subreddit",
100+
"author_resource"], default="v_subreddit",
97101
help="What types of node to include in the GCN bipartite graph?")
98102

99103
parser.add_argument('--num_negative_candidates', type=int, default=1000,
100104
help="How many negative examples to sample for each video during the initial sampling?")
101-
parser.add_argument('--num_neighbors', type=int, default=10, help="Number of neighboring nodes in GNN")
105+
parser.add_argument('--num_neighbors', type=int, default=10,
106+
help="Number of neighboring nodes in GNN")
102107
parser.add_argument('--num_resource_prototypes', type=int, default=-1, help="")
103108

104-
parser.add_argument('--num_workers', type=int, default=1, help="Number of workers for multiprocessing")
105-
parser.add_argument('--perplexity', type=int, default=20, help="Perplexity of the generated t-SNE plot")
109+
parser.add_argument('--num_workers', type=int, default=1,
110+
help="Number of workers for multiprocessing")
111+
parser.add_argument('--perplexity', type=int, default=20,
112+
help="Perplexity of the generated t-SNE plot")
106113
parser.add_argument('--pretrained_embeddings_epoch', type=int, default=195,
107114
help="Which epoch of the pretrained embeddings (Node2Vec, GCN ...) to use")
108115
parser.add_argument('--output_dir', type=str, default="outputs")
@@ -122,19 +129,11 @@
122129
parser.add_argument('--num_sample_author', type=int, default=-1,
123130
help="Number of resource to sample in our dataset. Set to -1 if we do not want to sample")
124131

125-
126-
127-
128132
parser.add_argument('--port', type=int, default=8050)
129133

130-
131134
parser.add_argument('--save_embed_every', type=int, default=10,
132135
help="How many epochs to save embeddings for visualization?")
133136

134-
parser.add_argument('--save_resource_embed', action='store_true',
135-
help="Whether to save the embeddings for resources (videos, URLs, Misinformative URLs)?")
136-
137-
138137
parser.add_argument('--save_model_every', type=int, default=-1,
139138
help="How many epochs to save the model weights?")
140139
parser.add_argument('--seed', type=int, default=42, help="Random seed.")
@@ -154,17 +153,33 @@
154153
parser.add_argument('--snapshot_interval', type=int, default=1,
155154
help="Time interval (in days) between each snapshot. Default: 1 month. Interactions happening within this time interval will be grouped into one snapshot.")
156155

156+
parser.add_argument('--transform_input', action='store_true',
157+
help="Whether to transform the input to a new embedding space. This field is automatically set to True if in_channels does not equal to embedding_dim")
158+
157159
parser.add_argument('--suffix', type=str, default="",
158160
help="Suffix to append to the end of the log file name")
159161

160-
parser.add_argument('--visualization_dim', type=int, choices=[2, 3], default=2, help="Dimension of the generated visualization. Can be 2- or 3-dimensional.")
162+
parser.add_argument('--tasks', type=str,
163+
default="['node_classification','link_pred']",
164+
help="Tasks to run, passed as a list of strings")
165+
166+
parser.add_argument('--visualization_dim', type=int, choices=[2, 3], default=2,
167+
help="Dimension of the generated visualization. Can be 2- or 3-dimensional.")
161168

162-
parser.add_argument('--visualization_model', type=str, choices=[const.TSNE, const.UMAP, const.PCA, const.ISOMAP, const.MDS], default=const.TSNE,
169+
parser.add_argument('--visualization_model', type=str,
170+
choices=[const.TSNE, const.UMAP, const.PCA, const.ISOMAP,
171+
const.MDS], default=const.TSNE,
163172
help="Visualization model to use")
164173

165174
args = parser.parse_args()
166175

176+
if args.in_channels is None:
177+
args.in_channels = args.embedding_dim
167178
args.num_nearest_neighbors = eval(args.num_nearest_neighbors)
168179

169180
args.visual_dir = osp.join(args.output_dir, "visual", args.dataset_name)
170-
os.makedirs(args.visual_dir, exist_ok=True)
181+
os.makedirs(args.visual_dir, exist_ok=True)
182+
183+
args.transform_input = args.in_channels != args.embedding_dim
184+
args.tasks = eval(args.tasks)
185+
print(args.tasks)

dygetviz/const.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
import platform
22

33

4+
DATASET2FILEID = {
5+
"Chickenpox": "1oAO5S1ikjxbbgPzBhZJf7Xf9bodbwwCE",
6+
7+
}
8+
49
# Features in Subreddit
510
AUTHOR = 'author'
611
AUTHOR_FULLNAME = 'author_fullname'

dygetviz/data/chickenpox.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
import os.path as osp
2+
import pickle
3+
import zipfile
4+
5+
import torch
6+
import pandas as pd
7+
import numpy as np
8+
from typing import List, Union
9+
from torch_geometric.data import Data
10+
from torch_geometric_temporal import DynamicGraphStaticSignal
11+
12+
from dygetviz.data.dygetviz_dataset import DyGETVizDataset
13+
from dygetviz.data.static_graph_static_signal import StaticGraphStaticSignal
14+
15+
Edge_Index = Union[np.ndarray, None]
16+
Edge_Weight = Union[np.ndarray, None]
17+
Node_Features = List[Union[np.ndarray, None]]
18+
Targets = List[Union[np.ndarray, None]]
19+
Additional_Features = List[np.ndarray]
20+
21+
class ChickenpoxDataset(StaticGraphStaticSignal, DyGETVizDataset):
22+
def __init__(self, args, **kwargs: Additional_Features):
23+
self.args = args
24+
25+
self.dataset_name = "Chickenpox"
26+
27+
DyGETVizDataset.__init__(self, self.dataset_name, **kwargs)
28+
29+
30+
if osp.exists(self.dataset_path):
31+
32+
with open(self.dataset_path, "rb") as f:
33+
d = pickle.load(f)
34+
35+
else:
36+
self.download()
37+
38+
d = self.process()
39+
40+
41+
42+
node2idx = d["node2idx"]
43+
targets = d["targets"]
44+
node_presence = d["node_presence"]
45+
edge_index = d["edge_index"]
46+
edge_weight = d["edge_weight"]
47+
48+
self.num_nodes = len(node2idx)
49+
50+
limit = np.sqrt(6 / (self.num_nodes + args.embedding_dim))
51+
features = np.random.uniform(-limit, limit, size=(
52+
self.num_nodes, args.embedding_dim))
53+
54+
StaticGraphStaticSignal.__init__(self,
55+
edge_index=edge_index,
56+
edge_weight=edge_weight,
57+
features=features,
58+
targets=targets,
59+
node_masks=node_presence,
60+
**kwargs
61+
)
62+
63+
64+
def process(self):
65+
mapping = pd.read_excel(osp.join(self.cache_dir, "raw_data", "idx2county.xlsx"))
66+
67+
self.nodes = mapping["county"].values
68+
69+
node2idx = {row["id"]: row["county"] for idx, row in
70+
mapping.iterrows()}
71+
72+
idx2node = {v: k for k, v in node2idx.items()}
73+
74+
edges = pd.read_csv(osp.join(self.cache_dir, "raw_data", "hungary_edges.csv"))
75+
76+
edge_index = [edges[["id_1", "id_2"]].values.T for i in range(522)]
77+
78+
edge_weight = [np.ones(edges.shape[0]) for i in range(522)]
79+
80+
# We use the actual #weekly cases as the ground-truth
81+
weekly_cases = pd.read_csv(
82+
osp.join(self.cache_dir, "raw_data", "hungary_weekly_chickenpox_cases.csv"))
83+
84+
85+
# We predict the log2 of the weekly cases
86+
targets = weekly_cases.loc[:, self.nodes].values
87+
# There are 522 weeks in total
88+
targets = [np.log2(targets[i] + 1) for i in range(522)]
89+
90+
node_presence = np.ones((522, len(self.nodes)))
91+
92+
93+
94+
d = {
95+
"targets": targets,
96+
"node_presence": node_presence,
97+
"node2idx": node2idx,
98+
"idx2node": idx2node,
99+
"edge_index": edge_index,
100+
"edge_weight": edge_weight
101+
}
102+
103+
with open(self.dataset_path, "wb") as f:
104+
pickle.dump(d, f)
105+
106+
return d

dygetviz/data/dataloader.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66
import pandas as pd
77
from numba import NumbaDeprecationWarning
88

9-
import const
109
from arguments import args
1110

11+
from dygetviz.data.chickenpox import ChickenpoxDataset
12+
1213
warnings.simplefilter(action='ignore', category=NumbaDeprecationWarning)
1314

1415

@@ -21,6 +22,8 @@ def load_data(dataset_name=args.dataset_name) -> dict:
2122
node_presence: np.ndarray of shape (num_nodes, num_timesteps): 1 if node is present at timestep, 0 otherwise
2223
"""
2324

25+
26+
2427
config = json.load(
2528
open(osp.join("config", f"{dataset_name}.json"), 'r',
2629
encoding='utf-8'))
@@ -191,3 +194,27 @@ def load_data(dataset_name=args.dataset_name) -> dict:
191194
"z": z,
192195

193196
}
197+
198+
199+
200+
def load_data_dtdg(dataset_name: str):
201+
"""
202+
Load data for embedding training on Discrete-Time Dynamic-Graph (DTDG) models.
203+
"""
204+
from torch_geometric_temporal.signal import temporal_signal_split
205+
206+
if dataset_name == "UNComtrade":
207+
208+
path = osp.join(args.cache_dir, f"full_dataset_{args.dataset_name}.pt")
209+
full_dataset = UNComtradeDataset(args)
210+
211+
elif dataset_name == "Chickenpox":
212+
full_dataset = ChickenpoxDataset(args)
213+
214+
else:
215+
raise NotImplementedError
216+
217+
train_dataset, test_dataset = temporal_signal_split(full_dataset,
218+
train_ratio=1.)
219+
220+
return train_dataset, test_dataset, full_dataset

0 commit comments

Comments
 (0)