Skip to content

Commit 5d30ae3

Browse files
authored
Support repeatedly calling WhenAny on a progressively smaller list of already scheduled tasks (#446)
1 parent 19fba1a commit 5d30ae3

File tree

4 files changed

+69
-5
lines changed

4 files changed

+69
-5
lines changed

azure/durable_functions/models/DurableOrchestrationContext.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -683,7 +683,7 @@ def _get_function_name(self, name: FunctionBuilder,
683683
name = name._function._name
684684
return name
685685
else:
686-
if(trigger_type == OrchestrationTrigger):
686+
if (trigger_type == OrchestrationTrigger):
687687
trigger_type = "OrchestrationTrigger"
688688
else:
689689
trigger_type = "ActivityTrigger"

azure/durable_functions/models/Task.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,14 @@ def __init__(self, id_: Union[int, str], actions: Union[List[Action], Action]):
5656
self.result: Any = None
5757
self.action_repr: Union[List[Action], Action] = actions
5858
self.is_played = False
59+
self._is_scheduled_flag = False
60+
61+
@property
62+
def _is_scheduled(self) -> bool:
63+
return self._is_scheduled_flag
64+
65+
def _set_is_scheduled(self, is_scheduled: bool):
66+
self._is_scheduled_flag = is_scheduled
5967

6068
@property
6169
def is_completed(self) -> bool:
@@ -158,7 +166,8 @@ def __init__(self, tasks: List[TaskBase], compound_action_constructor=None):
158166
if isinstance(action_repr, list):
159167
child_actions.extend(action_repr)
160168
else:
161-
child_actions.append(action_repr)
169+
if not task._is_scheduled:
170+
child_actions.append(action_repr)
162171
if compound_action_constructor is None:
163172
self.action_repr = child_actions
164173
else: # replay_schema is ReplaySchema.V2
@@ -176,6 +185,10 @@ def __init__(self, tasks: List[TaskBase], compound_action_constructor=None):
176185
if not (child.state is TaskState.RUNNING):
177186
self.handle_completion(child)
178187

188+
@property
189+
def _is_scheduled(self) -> bool:
190+
return all([child._is_scheduled for child in self.children])
191+
179192
def handle_completion(self, child: TaskBase):
180193
"""Manage sub-task completion events.
181194

azure/durable_functions/models/TaskOrchestrationExecutor.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from azure.durable_functions.models.Task import TaskBase, TaskState, AtomicTask
1+
from azure.durable_functions.models.Task import TaskBase, TaskState, AtomicTask, CompoundTask
22
from azure.durable_functions.models.OrchestratorState import OrchestratorState
33
from azure.durable_functions.models.DurableOrchestrationContext import DurableOrchestrationContext
44
from typing import Any, List, Optional, Union
@@ -229,7 +229,8 @@ def resume_user_code(self):
229229
task_succeeded = current_task.state is TaskState.SUCCEEDED
230230
new_task = self.generator.send(
231231
task_value) if task_succeeded else self.generator.throw(task_value)
232-
self.context._add_to_open_tasks(new_task)
232+
if isinstance(new_task, TaskBase) and not (new_task._is_scheduled):
233+
self.context._add_to_open_tasks(new_task)
233234
except StopIteration as stop_exception:
234235
# the orchestration returned,
235236
# flag it as such and capture its output
@@ -245,9 +246,17 @@ def resume_user_code(self):
245246
# user yielded the same task multiple times, continue executing code
246247
# until a new/not-previously-yielded task is encountered
247248
self.resume_user_code()
248-
else:
249+
elif not (self.current_task._is_scheduled):
249250
# new task is received. it needs to be resolved to a value
250251
self.context._add_to_actions(self.current_task.action_repr)
252+
self._mark_as_scheduled(self.current_task)
253+
254+
def _mark_as_scheduled(self, task: TaskBase):
255+
if isinstance(task, CompoundTask):
256+
for task in task.children:
257+
self._mark_as_scheduled(task)
258+
else:
259+
task._set_is_scheduled(True)
251260

252261
def get_orchestrator_state_str(self) -> str:
253262
"""Obtain a JSON-formatted string representing the orchestration's state.

tests/orchestrator/test_sequential_orchestrator.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,24 @@ def generator_function_duplicate_yield(context):
7777

7878
return ""
7979

80+
def generator_function_reducing_when_all(context):
81+
task1 = context.call_activity("Hello", "Tokyo")
82+
task2 = context.call_activity("Hello", "Seattle")
83+
pending_tasks = [task1, task2]
84+
85+
# Yield until first task is completed
86+
finished_task1 = yield context.task_any(pending_tasks)
87+
88+
# Remove completed task from pending tasks
89+
pending_tasks.remove(finished_task1)
90+
91+
# Yield remaining task
92+
yield context.task_any(pending_tasks)
93+
94+
# Ensure we can still schedule new tasks
95+
yield context.call_activity("Hello", "London")
96+
return ""
97+
8098
def generator_function_compound_tasks(context):
8199
yield context.call_activity("Hello", "Tokyo")
82100

@@ -689,6 +707,30 @@ def test_duplicate_yields_do_not_add_duplicate_actions():
689707
assert_valid_schema(result)
690708
assert_orchestration_state_equals(expected, result)
691709

710+
def test_reducing_when_any_pattern():
711+
"""Tests that a user can call when_any on a progressively smaller list of already scheduled tasks"""
712+
context_builder = ContextBuilder('test_reducing_when_any', replay_schema=ReplaySchema.V2)
713+
add_hello_completed_events(context_builder, 0, "\"Hello Tokyo!\"")
714+
add_hello_completed_events(context_builder, 1, "\"Hello Seattle!\"")
715+
add_hello_completed_events(context_builder, 2, "\"Hello London!\"")
716+
717+
result = get_orchestration_state_result(
718+
context_builder, generator_function_reducing_when_all)
719+
720+
# this scenario is only supported for V2 replay
721+
expected_state = base_expected_state("",replay_schema=ReplaySchema.V2)
722+
expected_state._actions = [
723+
[WhenAnyAction(
724+
[CallActivityAction("Hello", "Seattle"), CallActivityAction("Hello", "Tokyo")]),
725+
CallActivityAction("Hello", "London")
726+
]
727+
]
728+
729+
expected_state._is_done = True
730+
expected = expected_state.to_json()
731+
732+
assert_orchestration_state_equals(expected, result)
733+
692734
def test_compound_tasks_return_single_action_in_V2():
693735
"""Tests that compound tasks, in the v2 replay schema, are represented as a single "deep" action"""
694736
context_builder = ContextBuilder('test_v2_replay_schema', replay_schema=ReplaySchema.V2)

0 commit comments

Comments
 (0)