 # obtain one at https://mozilla.org/MPL/2.0/.

 import json
+import math
 import statistics
 from pathlib import Path

-import matplotlib.pyplot as plt
-import seaborn as sns
-
-data_path = Path(__file__).parent / "data.json"
-with open(data_path) as f:
-    data = json.loads(f.read())
-
-old_runs = data["old"]
-new_runs = data["new"]
-all_runs = old_runs + new_runs
-
-# every run should involve the same functions
-names = set()
-for run in all_runs:
-    names.add(frozenset(run.keys()))
-
-intersection = frozenset.intersection(*names)
-diff = frozenset.union(*[intersection.symmetric_difference(n) for n in names])
-
-print(f"skipping these tests which were not present in all runs: {', '.join(diff)}")
-names = list(intersection)
-
-# the similar invariant for number of minimal calls per run is not true: functions
-# may make a variable number of minimal() calls.
-# it would be nice to compare identically just the ones which don't vary, to get
-# a very fine grained comparison instead of averaging.
-# sizes = []
-# for run in all_runs:
-#     sizes.append(tuple(len(value) for value in run.values()))
-# assert len(set(sizes)) == 1
-
-new_names = []
-for name in names:
-    if all(all(x == 0 for x in run[name]) for run in all_runs):
-        print(f"no shrinks for {name}, skipping")
-        continue
-    new_names.append(name)
-names = new_names
-
-# either "time" or "calls"
-statistic = "time"
-# name : average calls
-old_values = {}
-new_values = {}
-for name in names:
-
-    # mean across the different minimal() calls in a single test function, then
-    # median across the n iterations we ran that for to reduce error
-    old_vals = [statistics.mean(r[statistic] for r in run[name]) for run in old_runs]
-    new_vals = [statistics.mean(r[statistic] for r in run[name]) for run in new_runs]
-    old_values[name] = statistics.median(old_vals)
-    new_values[name] = statistics.median(new_vals)
-
-# name : (absolute difference, times difference)
-diffs = {}
-for name in names:
-    old = old_values[name]
-    new = new_values[name]
-    diff = old - new
-    if old == 0:
-        diff_times = 0
-    else:
-        diff_times = (old - new) / old
-    if 0 < diff_times < 1:
-        diff_times = (1 / (1 - diff_times)) - 1
-    diffs[name] = (diff, diff_times)

-    print(f"{name} {diff} ({old} -> {new}, {round(diff_times, 1)}✕)")
-
-diffs = dict(sorted(diffs.items(), key=lambda kv: kv[1][0]))
-diffs_value = [v[0] for v in diffs.values()]
-diffs_percentage = [v[1] for v in diffs.values()]
-
-print(f"mean: {statistics.mean(diffs_value)}, median: {statistics.median(diffs_value)}")
-
-
-# https://stackoverflow.com/a/65824524
-def align_axes(ax1, ax2):
-    ax1_ylims = ax1.axes.get_ylim()
-    ax1_yratio = ax1_ylims[0] / ax1_ylims[1]
-
-    ax2_ylims = ax2.axes.get_ylim()
-    ax2_yratio = ax2_ylims[0] / ax2_ylims[1]
-
-    if ax1_yratio < ax2_yratio:
-        ax2.set_ylim(bottom=ax2_ylims[1] * ax1_yratio)
-    else:
-        ax1.set_ylim(bottom=ax1_ylims[1] * ax2_yratio)
-
-
-ax1 = sns.barplot(diffs_value, color="b", alpha=0.7, label="absolute change")
-ax2 = plt.twinx()
-sns.barplot(diffs_percentage, color="r", alpha=0.7, ax=ax2, label="n✕ change")
-
-ax1.set_title(
-    "old shrinks - new shrinks (aka shrinks saved, higher is better)"
-    if statistic == "calls"
-    else "old time - new time in seconds (aka time saved, higher is better)"
-)
-ax1.set_xticks([])
-align_axes(ax1, ax2)
-legend1 = ax1.legend(loc="upper left")
-legend1.legend_handles[0].set_color("b")
-legend2 = ax2.legend(loc="lower right")
-legend2.legend_handles[0].set_color("r")
-
-plt.show()
+import click
+
+
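+# render `vega_spec` to a png at `to` using vl-convert, injecting `data` as the
+# first dataset (named "source") and each entry of `parameters` as a vega
+# signal for the spec to reference.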
+def plot_vega(vega_spec, data, *, to, parameters=None):
+    import vl_convert
+
+    parameters = parameters or {}
+
+    spec = json.loads(vega_spec.read_text())
+    spec["data"].insert(0, {"name": "source", "values": data})
+    if "signals" not in spec:
+        spec["signals"] = []
+
+    for key, value in parameters.items():
+        spec["signals"].append({"name": key, "value": value})
+
+    with open(to, "wb") as f:
+        # default ppi is 72, which is somewhat blurry.
+        f.write(vl_convert.vega_to_png(spec, ppi=200))
+
+
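+# half-width of a two-sample t confidence interval for the difference in means,
+# i.e. t_crit * pooled_std * sqrt(1/len(n1) + 1/len(n2)). the interval itself
+# is mean(n1) - mean(n2) ± the returned value.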
+def _mean_difference_ci(n1, n2, *, confidence):
+    from scipy import stats
+
+    var1 = statistics.variance(n1)
+    var2 = statistics.variance(n2)
+    df = len(n1) + len(n2) - 2
+    # this assumes equal variances between the populations of n1 and n2. This
+    # is not necessarily true (new might be more consistent than old), but it's
+    # good enough.
+    pooled_std = math.sqrt(((len(n1) - 1) * var1 + (len(n2) - 1) * var2) / df)
+    se = pooled_std * math.sqrt(1 / len(n1) + 1 / len(n2))
+    t_crit = stats.t.ppf((1 + confidence) / 2, df)
+    return t_crit * se
+
+
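+# turn the raw benchmark json into a list of per-test dicts for the vega spec,
+# sorted by mean absolute improvement, plus the summed mean call counts for
+# the old and new runs.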
+def _process_benchmark_data(data):
+    assert set(data) == {"old", "new"}
+    old_calls = data["old"]["calls"]
+    new_calls = data["new"]["calls"]
+    assert set(old_calls) == set(new_calls), set(old_calls).symmetric_difference(
+        set(new_calls)
+    )
+
+    graph_data = []
+
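+    # a symmetric "n times" change: improvements (0 < v < 1) are folded onto
+    # the same scale as regressions. e.g. old=30, new=10 gives v = 2/3, which
+    # is remapped to 1 / (1 - 2/3) - 1 = 2; old=10, new=30 gives -2 directly.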
+    def _diff_times(old, new):
+        if old == 0 and new == 0:
+            return 0
+        if old == 0:
+            # there aren't any great options here, but 0 is more reasonable than inf.
+            return 0
+        v = (old - new) / old
+        if 0 < v < 1:
+            v = (1 / (1 - v)) - 1
+        return v
+
+    sums = {"old": 0, "new": 0}
+    for node_id in old_calls:
+        old = old_calls[node_id]
+        new = new_calls[node_id]
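+        # skip tests which never shrank in either run, and tests with a
+        # mismatched number of measurements, which can't be compared pairwise.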
+        if set(old) | set(new) == {0} or len(old) != len(new):
+            print(f"skipping {node_id}")
+            continue
+
+        sums["old"] += statistics.mean(old)
+        sums["new"] += statistics.mean(new)
+        diffs = [n_old - n_new for n_old, n_new in zip(old, new)]
+        diffs_times = [_diff_times(n_old, n_new) for n_old, n_new in zip(old, new)]
+        ci_shrink = (
+            _mean_difference_ci(old, new, confidence=0.95) if len(old) > 1 else 0
+        )
+
+        graph_data.append(
+            {
+                "node_id": node_id,
+                "absolute": statistics.mean(diffs),
+                "absolute_ci_lower": ci_shrink,
+                "absolute_ci_upper": ci_shrink,
+                "nx": statistics.mean(diffs_times),
+                "nx_ci_lower": 0,
+                "nx_ci_upper": 0,
+            }
+        )
+
+    graph_data = sorted(graph_data, key=lambda d: d["absolute"])
+    return graph_data, sums
+
+
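+# hypothetical invocation, assuming this script is saved as plot.py:
+#   python plot.py benchmark-data.json out.png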
+@click.command()
+@click.argument("data", type=click.Path(exists=True, path_type=Path))
+@click.argument("out", type=click.Path(path_type=Path))
+def plot(data, out):
+    data = json.loads(data.read_text())
+    data, sums = _process_benchmark_data(data)
+    plot_vega(
+        Path(__file__).parent / "spec.json",
+        data=data,
+        to=out,
+        parameters={
+            "title": "Shrinking benchmark (calls)",
+            "sum_old": sums["old"],
+            "sum_new": sums["new"],
+            "absolute_axis_title": ("shrink call change (old - new, larger is good)"),
+        },
+    )
+
+
+if __name__ == "__main__":
+    plot()