16
16
# under the License.
17
17
from __future__ import annotations
18
18
19
+ import base64
19
20
import inspect
20
21
import os
21
22
import pickle
22
23
import uuid
24
+ from shlex import quote
23
25
from tempfile import TemporaryDirectory
24
26
from textwrap import dedent
25
27
from typing import TYPE_CHECKING , Callable , Sequence
26
28
29
+ import dill
27
30
from kubernetes .client import models as k8s
28
31
29
32
from airflow .decorators .base import DecoratedOperator , TaskDecorator , task_decorator_factory
37
40
from airflow .utils .context import Context
38
41
39
42
# Name of the pod environment variable that carries the base64-encoded script file.
_PYTHON_SCRIPT_ENV = "__PYTHON_SCRIPT"
# Name of the pod environment variable that carries the base64-encoded pickled
# ``{"args": ..., "kwargs": ...}`` payload for the decorated callable.
_PYTHON_INPUT_ENV = "__PYTHON_INPUT"
40
44
41
- _FILENAME_IN_CONTAINER = "/tmp/script.py"
42
45
43
-
44
def _generate_decoded_command(env_var: str, file: str) -> str:
    """Build a shell command that decodes a base64 environment variable into a file.

    The returned string invokes ``python -c`` with a one-line program that reads
    ``os.environ[env_var]``, base64-decodes it and writes the resulting bytes to
    ``file`` (opened in binary mode).

    :param env_var: name of the environment variable holding the base64 text.
    :param file: path of the file the decoded bytes are written to.
    :return: the command as a single string; the inner double quotes are
        backslash-escaped because the Python program itself is wrapped in
        double quotes.
    """
    # Raw f-strings keep the backslash in front of every inner quote verbatim,
    # so the escapes survive into the generated command line.
    return (
        'python -c "import base64, os;'
        rf"x = base64.b64decode(os.environ[\"{env_var}\"]);"
        rf'f = open(\"{file}\", \"wb\"); f.write(x); f.close()"'
    )
50
52
51
53
52
def _read_file_contents(filename: str) -> str:
    """Return the contents of ``filename`` as base64 text.

    The file is read in binary mode so that arbitrary bytes (for example a
    pickled payload) survive; the base64 result is returned as ``str``.

    :param filename: path of the file to read.
    :return: base64 encoding of the file's bytes, decoded to a string.
    """
    with open(filename, "rb") as script_file:
        raw_bytes = script_file.read()
    return base64.b64encode(raw_bytes).decode("utf-8")
55
57
56
58
57
59
class _KubernetesDecoratedOperator (DecoratedOperator , KubernetesPodOperator ):
@@ -62,17 +64,16 @@ class _KubernetesDecoratedOperator(DecoratedOperator, KubernetesPodOperator):
62
64
{"op_args" , "op_kwargs" , * KubernetesPodOperator .template_fields } - {"cmds" , "arguments" }
63
65
)
64
66
65
- # since we won't mutate the arguments, we should just do the shallow copy
67
+ # Since we won't mutate the arguments, we should just do the shallow copy
66
68
# there are some cases we can't deepcopy the objects (e.g protobuf).
67
69
shallow_copy_attrs : Sequence [str ] = ("python_callable" ,)
68
70
69
def __init__(self, namespace: str = "default", use_dill: bool = False, **kwargs) -> None:
    """Set up the operator and choose the serialization library.

    :param namespace: namespace forwarded to the parent operator.
    :param use_dill: when true, ``dill`` is used instead of ``pickle`` to
        serialize the callable's arguments.
    :param kwargs: remaining keyword arguments forwarded to the parent
        constructor; an optional ``name`` entry overrides the generated pod name.
    """
    self.pickling_library = dill if use_dill else pickle
    # Pop ``name`` before ``**kwargs`` is expanded so it is not passed twice;
    # without an explicit name a unique one is generated.
    pod_name = kwargs.pop("name", f"k8s_airflow_pod_{uuid.uuid4().hex}")
    super().__init__(
        namespace=namespace,
        name=pod_name,
        # Placeholder: execute() replaces self.cmds with the output of
        # _generate_cmds() before running.
        cmds=["dummy-command"],
        **kwargs,
    )
78
79
@@ -82,11 +83,41 @@ def _get_python_source(self):
82
83
res = remove_task_decorator (res , "@task.kubernetes" )
83
84
return res
84
85
86
def _generate_cmds(self) -> list[str]:
    """Build the container command that writes out and runs the task script.

    The command is a ``bash -cx`` invocation chaining four steps with ``&&``:
    decode the script environment variable to ``/tmp/script.py``, decode the
    input environment variable to ``/tmp/script.in``, create ``/airflow/xcom``,
    and run the script with the input and output file paths as arguments.

    :return: the command as an argument list.
    """
    script_filename = "/tmp/script.py"
    input_filename = "/tmp/script.in"
    output_filename = "/airflow/xcom/return.json"

    # The helper already returns a string, so no f-string wrapper is needed.
    write_local_script_file_cmd = _generate_decoded_command(
        quote(_PYTHON_SCRIPT_ENV), quote(script_filename)
    )
    write_local_input_file_cmd = _generate_decoded_command(
        quote(_PYTHON_INPUT_ENV), quote(input_filename)
    )
    # The directory of ``output_filename`` must exist before the script runs.
    make_xcom_dir_cmd = "mkdir -p /airflow/xcom"
    exec_python_cmd = f"python {script_filename} {input_filename} {output_filename}"

    steps = [
        write_local_script_file_cmd,
        write_local_input_file_cmd,
        make_xcom_dir_cmd,
        exec_python_cmd,
    ]
    # ``&&`` stops the chain at the first failing step.
    return ["bash", "-cx", " && ".join(steps)]
111
+
85
112
def execute (self , context : Context ):
86
113
with TemporaryDirectory (prefix = "venv" ) as tmp_dir :
87
114
script_filename = os .path .join (tmp_dir , "script.py" )
88
- py_source = self ._get_python_source ()
115
+ input_filename = os .path .join (tmp_dir , "script.in" )
116
+
117
+ with open (input_filename , "wb" ) as file :
118
+ self .pickling_library .dump ({"args" : self .op_args , "kwargs" : self .op_kwargs }, file )
89
119
120
+ py_source = self ._get_python_source ()
90
121
jinja_context = {
91
122
"op_args" : self .op_args ,
92
123
"op_kwargs" : self .op_kwargs ,
@@ -100,7 +131,10 @@ def execute(self, context: Context):
100
131
self .env_vars = [
101
132
* self .env_vars ,
102
133
k8s .V1EnvVar (name = _PYTHON_SCRIPT_ENV , value = _read_file_contents (script_filename )),
134
+ k8s .V1EnvVar (name = _PYTHON_INPUT_ENV , value = _read_file_contents (input_filename )),
103
135
]
136
+
137
+ self .cmds = self ._generate_cmds ()
104
138
return super ().execute (context )
105
139
106
140
0 commit comments