Skip to content

Commit 232cd01

Browse files
committed
automated tagging
1 parent f48e7a6 commit 232cd01

File tree

7 files changed

+156
-3
lines changed

7 files changed

+156
-3
lines changed

.pre-commit-config.yaml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,12 @@ repos:
3535
language: python
3636
name: Check all notebooks appear in table of contents
3737
types: [jupyter]
38+
- id: add-tags
39+
entry: python scripts/add_tags.py
40+
language: python
41+
name: Add PyMC3 classes used to tags
42+
types: [jupyter]
43+
additional_dependencies:
44+
- nbqa==1.1.1
45+
- beautifulsoup4==4.9.3
46+
- myst_parser==0.13.7

examples/case_studies/multilevel_modeling.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
"# A Primer on Bayesian Methods for Multilevel Modeling\n",
88
"\n",
99
":::{post} 30 Aug, 2021\n",
10-
":tags: hierarchical\n",
10+
":tags: hierarchical, pymc3.Data, pymc3.Deterministic, pymc3.Exponential, pymc3.LKJCholeskyCov, pymc3.Model, pymc3.MvNormal, pymc3.Normal\n",
1111
":category: intermediate\n",
1212
":::"
1313
]

examples/case_studies/rugby_analytics.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
"# A Hierarchical model for Rugby prediction\n",
88
"\n",
99
":::{post} 30 Aug, 2021\n",
10-
":tags: hierarchical, sports\n",
10+
":tags: hierarchical, pymc3.Data, pymc3.Deterministic, pymc3.HalfNormal, pymc3.Model, pymc3.Normal, pymc3.Poisson, sports\n",
1111
":category: intermediate\n",
1212
":::"
1313
]

examples/getting_started.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
"Note: This text is based on the [PeerJ CS publication on PyMC3](https://peerj.com/articles/cs-55/).\n",
1212
"\n",
1313
":::{post} 30 Aug, 2021\n",
14-
":tags: glm, mcmc, exploratory analysis\n",
14+
":tags: exploratory analysis, glm, mcmc, pymc3.Data, pymc3.Deterministic, pymc3.DiscreteUniform, pymc3.Exponential, pymc3.GaussianRandomWalk, pymc3.HalfNormal, pymc3.Model, pymc3.Normal, pymc3.Poisson, pymc3.Slice, pymc3.StudentT\n",
1515
":category: beginner\n",
1616
":::"
1717
]

scripts/__init__.py

Whitespace-only changes.

scripts/add_tags.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
"""
2+
Automatically add tags to a notebook based on which PyMC3 classes are used.
3+
4+
E.g. if a notebook contains a section like
5+
6+
:::{post} 30 Aug, 2021
7+
:tags: glm, mcmc, exploratory analysis
8+
:category: beginner
9+
:::
10+
11+
in a markdown cell, and uses the class pymc3.Categorical, then this script
12+
will change that part of the markdown cell to:
13+
14+
:::{post} 30 Aug, 2021
15+
:tags: glm, mcmc, exploratory analysis, pymc3.Categorical
16+
:category: beginner
17+
:::
18+
"""
19+
import sys
20+
from myst_parser.main import to_tokens, MdParserConfig
21+
import subprocess
22+
import os
23+
import json
24+
import argparse
25+
26+
27+
# Prefixes that introduce the tags line inside a ```{post} / :::{post} block.
_TAG_PREFIXES = ("tags: ", ":tags: ")


def _used_pymc3_classes(file):
    """Return the ``pymc3.<Class>`` tags for the PyMC3 classes used in *file*.

    Runs ``scripts.find_pm_classes`` on the notebook through ``nbqa`` and
    reads one class name per line from its standard output.
    """
    output = subprocess.run(
        ["nbqa", "scripts.find_pm_classes", file],
        stdout=subprocess.PIPE,
        text=True,
    )
    # Ignore blank lines so that they do not turn into a bare "pymc3." tag.
    return {
        f"pymc3.{obj.strip()}"
        for obj in output.stdout.splitlines()
        if obj.strip()
    }


def _cell_text(cell):
    """Return the source of a notebook cell as a single string.

    The ``source`` entry is handled both as a list of lines and as one
    plain string; joining a plain string would interleave a newline
    between every character.
    """
    source = cell["source"]
    if isinstance(source, str):
        return source
    return "\n".join(source)


def _find_tags_line(tokens):
    """Return ``(prefix, line)`` for the tags line of a ``{post}`` block.

    Returns ``None`` when no ``{post}`` code block with a tags line exists.
    When several blocks qualify, the last one is returned.
    """
    found = None
    for token in tokens:
        if token.tag != "code" or not token.info.startswith("{post}"):
            continue
        for line in token.content.splitlines():
            prefix = next((p for p in _TAG_PREFIXES if line.startswith(p)), None)
            if prefix is not None:
                found = (prefix, line)
                break
    return found


def main(argv=None):
    """Add a ``pymc3.<Class>`` tag for every PyMC3 class used in each notebook.

    Parameters
    ----------
    argv : list of str, optional
        Notebook paths. Defaults to the command-line arguments.

    Returns
    -------
    None
        Notebooks are updated in place; a file is only rewritten when its
        tags line actually changes.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("files", nargs="*")
    args = parser.parse_args(argv)

    for file in args.files:
        # Find which PyMC3 classes are used in the code.
        classes = _used_pymc3_classes(file)

        # Tokenize the notebook's markdown cells.
        with open(file, encoding="utf-8") as fd:
            content = fd.read()
        notebook = json.loads(content)
        markdown_cells = "\n".join(
            _cell_text(cell)
            for cell in notebook["cells"]
            if cell["cell_type"] == "markdown"
        )
        config = MdParserConfig(enable_extensions=["dollarmath", "colon_fence"])
        tokens = to_tokens(markdown_cells, config=config)

        # Find a ```{post} or :::{post} code block, and look for a line
        # starting with tags: or :tags:.
        match = _find_tags_line(tokens)
        if match is None:
            continue
        line_start, original_line = match

        tags = {tag.strip() for tag in original_line[len(line_start):].split(",")}
        # A trailing comma or an empty value would otherwise yield an empty
        # tag and a rewritten line that starts with ", ".
        tags.discard("")

        new_tags = ", ".join(sorted(tags | classes))
        new_line = f"{line_start}{new_tags}"

        # The replacement is done on the raw file text so that the rest of
        # the notebook keeps its exact formatting.
        # NOTE(review): a tags line containing characters that are escaped in
        # the JSON text (e.g. double quotes) will not match here - confirm
        # whether that case needs support.
        new_content = content.replace(original_line, new_line)
        if new_content != content:
            with open(file, "w", encoding="utf-8") as fd:
                fd.write(new_content)
84+
85+
86+
if __name__ == "__main__":
    # Use sys.exit rather than the bare exit() helper: the latter is injected
    # by the site module and is not guaranteed to exist (e.g. with python -S).
    sys.exit(main())

scripts/find_pm_classes.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
"""
2+
Find all PyMC3 classes used in script.
3+
4+
This finds both calls of the form
5+
6+
pymc3.Categorical(...
7+
8+
and
9+
10+
from pymc3 import Categorical
11+
Categorical
12+
"""
13+
import ast
14+
import sys
15+
16+
17+
class ImportVisitor(ast.NodeVisitor):
    """Collect capitalised names imported with ``from pymc3... import ...``.

    After visiting a tree, ``imports`` holds every imported name whose first
    character is upper case and whose source module is ``pymc3`` or one of
    its submodules.
    """

    def __init__(self, file):
        # Kept for parity with CallVisitor, which also stores the file.
        self.file = file
        self.imports = set()

    def visit_ImportFrom(self, node: ast.ImportFrom):
        """Record the capitalised names imported from the pymc3 package."""
        # node.module is None for relative imports such as ``from . import x``.
        module = node.module or ""
        if module.split(".")[0] != "pymc3":
            return
        self.imports.update(
            alias.name for alias in node.names if alias.name[0].isupper()
        )
26+
27+
28+
class CallVisitor(ast.NodeVisitor):
    """Collect the PyMC3 classes that are called in a syntax tree.

    Two call shapes are recognised:

    * ``pm.Name(...)`` / ``pymc3.Name(...)`` where ``Name`` starts with an
      upper-case letter;
    * ``Name(...)`` where ``Name`` is one of the names passed in ``imports``.

    After visiting a tree, ``classes_used`` holds the matching names.
    """

    def __init__(self, file, imports):
        self.file = file
        self.imports = imports
        self.classes_used = set()

    def visit_Call(self, node: ast.Call):
        """Record the called class, then keep walking into the call."""
        func = node.func
        if isinstance(func, ast.Attribute):
            if (
                isinstance(func.value, ast.Name)
                and func.value.id in {"pm", "pymc3"}
                and func.attr[0].isupper()
            ):
                self.classes_used.add(func.attr)
        elif isinstance(func, ast.Name):
            if func.id in self.imports:
                self.classes_used.add(func.id)
        # Defining visit_Call replaces the default traversal for Call nodes,
        # so recurse explicitly; otherwise calls nested in the arguments
        # (e.g. ``f(pm.Normal(...))``) would never be inspected.
        self.generic_visit(node)
43+
44+
45+
if __name__ == "__main__":
    # Print, one per line, every PyMC3 class used in each file given on the
    # command line.
    for file in sys.argv[1:]:
        # Read with an explicit encoding so the result does not depend on the
        # platform's locale.
        with open(file, encoding="utf-8") as fd:
            content = fd.read()
        tree = ast.parse(content)

        import_visitor = ImportVisitor(file)
        import_visitor.visit(tree)

        visitor = CallVisitor(file, import_visitor.imports)
        visitor.visit(tree)
        # Sorted so that the output order is deterministic.
        for i in sorted(visitor.classes_used):
            print(i)

0 commit comments

Comments
 (0)