Skip to content

Commit 9a5c3e0

Browse files
authored
Fix @task.kubernetes to receive input and send output (#28942)
* Fix @task.kubernetes to receive input and send output
* Pickle input and rm unnecessary env vars
* Back to env vars and make cmds easier to read
* Remove check for op_args and op_kwargs on input write
1 parent 73c8e7d commit 9a5c3e0

File tree

3 files changed

+130
-20
lines changed

3 files changed

+130
-20
lines changed

airflow/providers/cncf/kubernetes/decorators/kubernetes.py

Lines changed: 48 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -16,14 +16,17 @@
1616
# under the License.
1717
from __future__ import annotations
1818

19+
import base64
1920
import inspect
2021
import os
2122
import pickle
2223
import uuid
24+
from shlex import quote
2325
from tempfile import TemporaryDirectory
2426
from textwrap import dedent
2527
from typing import TYPE_CHECKING, Callable, Sequence
2628

29+
import dill
2730
from kubernetes.client import models as k8s
2831

2932
from airflow.decorators.base import DecoratedOperator, TaskDecorator, task_decorator_factory
@@ -37,21 +40,20 @@
3740
from airflow.utils.context import Context
3841

3942
_PYTHON_SCRIPT_ENV = "__PYTHON_SCRIPT"
43+
_PYTHON_INPUT_ENV = "__PYTHON_INPUT"
4044

41-
_FILENAME_IN_CONTAINER = "/tmp/script.py"
4245

43-
44-
def _generate_decode_command() -> str:
46+
def _generate_decoded_command(env_var: str, file: str) -> str:
4547
return (
4648
f'python -c "import base64, os;'
47-
rf"x = os.environ[\"{_PYTHON_SCRIPT_ENV}\"];"
48-
rf'f = open(\"{_FILENAME_IN_CONTAINER}\", \"w\"); f.write(x); f.close()"'
49+
rf"x = base64.b64decode(os.environ[\"{env_var}\"]);"
50+
rf'f = open(\"{file}\", \"wb\"); f.write(x); f.close()"'
4951
)
5052

5153

52-
def _read_file_contents(filename):
53-
with open(filename) as script_file:
54-
return script_file.read()
54+
def _read_file_contents(filename: str) -> str:
55+
with open(filename, "rb") as script_file:
56+
return base64.b64encode(script_file.read()).decode("utf-8")
5557

5658

5759
class _KubernetesDecoratedOperator(DecoratedOperator, KubernetesPodOperator):
@@ -62,17 +64,16 @@ class _KubernetesDecoratedOperator(DecoratedOperator, KubernetesPodOperator):
6264
{"op_args", "op_kwargs", *KubernetesPodOperator.template_fields} - {"cmds", "arguments"}
6365
)
6466

65-
# since we won't mutate the arguments, we should just do the shallow copy
67+
# Since we won't mutate the arguments, we should just do the shallow copy
6668
# there are some cases we can't deepcopy the objects (e.g protobuf).
6769
shallow_copy_attrs: Sequence[str] = ("python_callable",)
6870

69-
def __init__(self, namespace: str = "default", **kwargs) -> None:
70-
self.pickling_library = pickle
71+
def __init__(self, namespace: str = "default", use_dill: bool = False, **kwargs) -> None:
72+
self.pickling_library = dill if use_dill else pickle
7173
super().__init__(
7274
namespace=namespace,
7375
name=kwargs.pop("name", f"k8s_airflow_pod_{uuid.uuid4().hex}"),
74-
cmds=["bash"],
75-
arguments=["-cx", f"{_generate_decode_command()} && python {_FILENAME_IN_CONTAINER}"],
76+
cmds=["dummy-command"],
7677
**kwargs,
7778
)
7879

@@ -82,11 +83,41 @@ def _get_python_source(self):
8283
res = remove_task_decorator(res, "@task.kubernetes")
8384
return res
8485

86+
def _generate_cmds(self) -> list[str]:
87+
script_filename = "/tmp/script.py"
88+
input_filename = "/tmp/script.in"
89+
output_filename = "/airflow/xcom/return.json"
90+
91+
write_local_script_file_cmd = (
92+
f"{_generate_decoded_command(quote(_PYTHON_SCRIPT_ENV), quote(script_filename))}"
93+
)
94+
write_local_input_file_cmd = (
95+
f"{_generate_decoded_command(quote(_PYTHON_INPUT_ENV), quote(input_filename))}"
96+
)
97+
make_xcom_dir_cmd = "mkdir -p /airflow/xcom"
98+
exec_python_cmd = f"python {script_filename} {input_filename} {output_filename}"
99+
return [
100+
"bash",
101+
"-cx",
102+
" && ".join(
103+
[
104+
write_local_script_file_cmd,
105+
write_local_input_file_cmd,
106+
make_xcom_dir_cmd,
107+
exec_python_cmd,
108+
]
109+
),
110+
]
111+
85112
def execute(self, context: Context):
86113
with TemporaryDirectory(prefix="venv") as tmp_dir:
87114
script_filename = os.path.join(tmp_dir, "script.py")
88-
py_source = self._get_python_source()
115+
input_filename = os.path.join(tmp_dir, "script.in")
116+
117+
with open(input_filename, "wb") as file:
118+
self.pickling_library.dump({"args": self.op_args, "kwargs": self.op_kwargs}, file)
89119

120+
py_source = self._get_python_source()
90121
jinja_context = {
91122
"op_args": self.op_args,
92123
"op_kwargs": self.op_kwargs,
@@ -100,7 +131,10 @@ def execute(self, context: Context):
100131
self.env_vars = [
101132
*self.env_vars,
102133
k8s.V1EnvVar(name=_PYTHON_SCRIPT_ENV, value=_read_file_contents(script_filename)),
134+
k8s.V1EnvVar(name=_PYTHON_INPUT_ENV, value=_read_file_contents(input_filename)),
103135
]
136+
137+
self.cmds = self._generate_cmds()
104138
return super().execute(context)
105139

106140

airflow/providers/cncf/kubernetes/python_kubernetes_script.jinja2

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717
under the License.
1818
-#}
1919

20+
import json
2021
import {{ pickling_library }}
2122
import sys
2223

@@ -42,3 +43,8 @@ arg_dict = {"args": [], "kwargs": {}}
4243
# Script
4344
{{ python_callable_source }}
4445
res = {{ python_callable }}(*arg_dict["args"], **arg_dict["kwargs"])
46+
47+
# Write output
48+
with open(sys.argv[2], "w") as file:
49+
if res is not None:
50+
file.write(json.dumps(res))

tests/providers/cncf/kubernetes/decorators/test_kubernetes.py

Lines changed: 76 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,8 @@
1616
# under the License.
1717
from __future__ import annotations
1818

19+
import base64
20+
import pickle
1921
from unittest import mock
2022

2123
import pytest
@@ -29,6 +31,8 @@
2931
POD_MANAGER_CLASS = "airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager"
3032
HOOK_CLASS = "airflow.providers.cncf.kubernetes.operators.kubernetes_pod.KubernetesHook"
3133

34+
XCOM_IMAGE = "XCOM_IMAGE"
35+
3236

3337
@pytest.fixture(autouse=True)
3438
def mock_create_pod() -> mock.Mock:
@@ -40,6 +44,18 @@ def mock_await_pod_start() -> mock.Mock:
4044
return mock.patch(f"{POD_MANAGER_CLASS}.await_pod_start").start()
4145

4246

47+
@pytest.fixture(autouse=True)
48+
def await_xcom_sidecar_container_start() -> mock.Mock:
49+
return mock.patch(f"{POD_MANAGER_CLASS}.await_xcom_sidecar_container_start").start()
50+
51+
52+
@pytest.fixture(autouse=True)
53+
def extract_xcom() -> mock.Mock:
54+
f = mock.patch(f"{POD_MANAGER_CLASS}.extract_xcom").start()
55+
f.return_value = '{"key1": "value1", "key2": "value2"}'
56+
return f
57+
58+
4359
@pytest.fixture(autouse=True)
4460
def mock_await_pod_completion() -> mock.Mock:
4561
f = mock.patch(f"{POD_MANAGER_CLASS}.await_pod_completion").start()
@@ -81,11 +97,65 @@ def f():
8197

8298
containers = mock_create_pod.call_args[1]["pod"].spec.containers
8399
assert len(containers) == 1
84-
assert containers[0].command == ["bash"]
100+
assert containers[0].command[0] == "bash"
101+
assert len(containers[0].args) == 0
102+
assert containers[0].env[0].name == "__PYTHON_SCRIPT"
103+
assert containers[0].env[0].value
104+
assert containers[0].env[1].name == "__PYTHON_INPUT"
105+
106+
# Ensure we pass input through a b64 encoded env var
107+
decoded_input = pickle.loads(base64.b64decode(containers[0].env[1].value))
108+
assert decoded_input == {"args": [], "kwargs": {}}
109+
110+
111+
def test_kubernetes_with_input_output(
112+
dag_maker, session, mock_create_pod: mock.Mock, mock_hook: mock.Mock
113+
) -> None:
114+
with dag_maker(session=session) as dag:
115+
116+
@task.kubernetes(
117+
image="python:3.10-slim-buster",
118+
in_cluster=False,
119+
cluster_context="default",
120+
config_file="/tmp/fake_file",
121+
)
122+
def f(arg1, arg2, kwarg1=None, kwarg2=None):
123+
return {"key1": "value1", "key2": "value2"}
124+
125+
f.override(task_id="my_task_id", do_xcom_push=True)("arg1", "arg2", kwarg1="kwarg1")
126+
127+
dr = dag_maker.create_dagrun()
128+
(ti,) = dr.task_instances
129+
130+
mock_hook.return_value.get_xcom_sidecar_container_image.return_value = XCOM_IMAGE
131+
132+
dag.get_task("my_task_id").execute(context=ti.get_template_context(session=session))
133+
134+
mock_hook.assert_called_once_with(
135+
conn_id=None,
136+
in_cluster=False,
137+
cluster_context="default",
138+
config_file="/tmp/fake_file",
139+
)
140+
assert mock_create_pod.call_count == 1
141+
assert mock_hook.return_value.get_xcom_sidecar_container_image.call_count == 1
142+
143+
containers = mock_create_pod.call_args[1]["pod"].spec.containers
144+
145+
# First container is Python script
146+
assert len(containers) == 2
147+
assert containers[0].command[0] == "bash"
148+
assert len(containers[0].args) == 0
149+
150+
assert containers[0].env[0].name == "__PYTHON_SCRIPT"
151+
assert containers[0].env[0].value
152+
assert containers[0].env[1].name == "__PYTHON_INPUT"
153+
assert containers[0].env[1].value
85154

86-
assert len(containers[0].args) == 2
87-
assert containers[0].args[0] == "-cx"
88-
assert containers[0].args[1].endswith("/tmp/script.py")
155+
# Ensure we pass input through a b64 encoded env var
156+
decoded_input = pickle.loads(base64.b64decode(containers[0].env[1].value))
157+
assert decoded_input == {"args": ("arg1", "arg2"), "kwargs": {"kwarg1": "kwarg1"}}
89158

90-
assert containers[0].env[-1].name == "__PYTHON_SCRIPT"
91-
assert containers[0].env[-1].value
159+
# Second container is xcom image
160+
assert containers[1].image == XCOM_IMAGE
161+
assert containers[1].volume_mounts[0].mount_path == "/airflow/xcom"

0 commit comments

Comments
 (0)