Skip to content

Commit 6778ecf

Browse files
committed
Use sum type for WorkflowRunType
1 parent 02f7806 commit 6778ecf

File tree

1 file changed

+26
-14
lines changed

1 file changed

+26
-14
lines changed

Diff for: src/ci/github-actions/calculate-job-matrix.py

+26-14
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88
and filters them based on the event that happened on CI.
99
"""
1010
import dataclasses
11-
import enum
1211
import json
1312
import logging
1413
import os
14+
import typing
1515
from pathlib import Path
1616
from typing import List, Dict, Any, Optional
1717

@@ -44,10 +44,22 @@ def add_base_env(jobs: List[Job], environment: Dict[str, str]) -> List[Job]:
4444
return jobs
4545

4646

47-
class WorkflowRunType(enum.Enum):
48-
PR = enum.auto()
49-
Try = enum.auto()
50-
Auto = enum.auto()
47+
@dataclasses.dataclass
48+
class PRRunType:
49+
pass
50+
51+
52+
@dataclasses.dataclass
53+
class TryRunType:
54+
custom_jobs: List[str]
55+
56+
57+
@dataclasses.dataclass
58+
class AutoRunType:
59+
pass
60+
61+
62+
WorkflowRunType = typing.Union[PRRunType, TryRunType, AutoRunType]
5163

5264

5365
@dataclasses.dataclass
@@ -59,7 +71,7 @@ class GitHubCtx:
5971

6072
def find_run_type(ctx: GitHubCtx) -> Optional[WorkflowRunType]:
6173
if ctx.event_name == "pull_request":
62-
return WorkflowRunType.PR
74+
return PRRunType()
6375
elif ctx.event_name == "push":
6476
old_bors_try_build = (
6577
ctx.ref in ("refs/heads/try", "refs/heads/try-perf") and
@@ -72,20 +84,20 @@ def find_run_type(ctx: GitHubCtx) -> Optional[WorkflowRunType]:
7284
try_build = old_bors_try_build or new_bors_try_build
7385

7486
if try_build:
75-
return WorkflowRunType.Try
87+
return TryRunType()
7688

7789
if ctx.ref == "refs/heads/auto" and ctx.repository == "rust-lang-ci/rust":
78-
return WorkflowRunType.Auto
90+
return AutoRunType()
7991

8092
return None
8193

8294

8395
def calculate_jobs(run_type: WorkflowRunType, job_data: Dict[str, Any]) -> List[Job]:
84-
if run_type == WorkflowRunType.PR:
96+
if isinstance(run_type, PRRunType):
8597
return add_base_env(name_jobs(job_data["pr"], "PR"), job_data["envs"]["pr"])
86-
elif run_type == WorkflowRunType.Try:
98+
elif isinstance(run_type, TryRunType):
8799
return add_base_env(name_jobs(job_data["try"], "try"), job_data["envs"]["try"])
88-
elif run_type == WorkflowRunType.Auto:
100+
elif isinstance(run_type, AutoRunType):
89101
return add_base_env(name_jobs(job_data["auto"], "auto"), job_data["envs"]["auto"])
90102

91103
return []
@@ -107,11 +119,11 @@ def get_github_ctx() -> GitHubCtx:
107119

108120

109121
def format_run_type(run_type: WorkflowRunType) -> str:
110-
if run_type == WorkflowRunType.PR:
122+
if isinstance(run_type, PRRunType):
111123
return "pr"
112-
elif run_type == WorkflowRunType.Auto:
124+
elif isinstance(run_type, AutoRunType):
113125
return "auto"
114-
elif run_type == WorkflowRunType.Try:
126+
elif isinstance(run_type, TryRunType):
115127
return "try"
116128
else:
117129
raise AssertionError()

0 commit comments

Comments
 (0)