I have a task, which I'll call final, that has multiple upstream connections. When one of the upstreams gets skipped by ShortCircuitOperator, this task gets skipped as well.
I ended up developing a custom ShortCircuitOperator based on the original one:
```python
# Airflow 1.10-style import paths.
from airflow.models import SkipMixin
from airflow.operators.python_operator import PythonOperator


class ShortCircuitOperator(PythonOperator, SkipMixin):
    """
    Allows a workflow to continue only if a condition is met. Otherwise, the
    workflow "short-circuits" and downstream tasks that only rely on this operator
    are skipped.

    The ShortCircuitOperator is derived from the PythonOperator. It evaluates a
    condition and short-circuits the workflow if the condition is False. Any
    downstream tasks that only rely on this operator are marked with a state of "skipped".
    If the condition is True, downstream tasks proceed as normal.

    The condition is determined by the result of `python_callable`.
    """

    def find_tasks_to_skip(self, task, found_tasks=None):
        # Collect downstream tasks whose only upstream lies on this path.
        # Tasks with several upstreams (such as "final") are left alone.
        if found_tasks is None:
            found_tasks = []
        direct_relatives = task.get_direct_relatives(upstream=False)
        for t in direct_relatives:
            if len(t.upstream_task_ids) == 1:
                found_tasks.append(t)
                # Keep walking down this single-parent branch.
                self.find_tasks_to_skip(t, found_tasks)
        return found_tasks

    def execute(self, context):
        condition = super(ShortCircuitOperator, self).execute(context)
        self.log.info("Condition result is %s", condition)

        if condition:
            self.log.info('Proceeding with downstream tasks...')
            return

        self.log.info(
            'Skipping downstream tasks that only rely on this path...')

        tasks_to_skip = self.find_tasks_to_skip(context['task'])
        self.log.debug("Tasks to skip: %s", tasks_to_skip)

        if tasks_to_skip:
            self.skip(context['dag_run'], context['ti'].execution_date,
                      tasks_to_skip)

        self.log.info("Done.")
```
This operator ensures that downstream tasks relying on multiple paths are not skipped just because one of those paths was short-circuited.
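To see the difference without running Airflow, here is a minimal sketch that replays the same traversal rule on a toy graph. The `Task` class, `all_downstream` and `tasks_to_skip` are hypothetical stand-ins, not Airflow APIs; `all_downstream` assumes the stock operator skips every transitive downstream task, which is the behaviour described in the question.

```python
class Task:
    """Hypothetical stand-in for an Airflow task: it only tracks graph edges."""

    def __init__(self, task_id):
        self.task_id = task_id
        self.upstream = set()
        self.downstream = set()

    def __rshift__(self, other):
        # Mimic Airflow's `a >> b` dependency syntax and allow chaining.
        self.downstream.add(other)
        other.upstream.add(self)
        return other


def all_downstream(task):
    """Every transitive downstream task (assumed stock short-circuit behaviour)."""
    found = set()
    stack = list(task.downstream)
    while stack:
        t = stack.pop()
        if t not in found:
            found.add(t)
            stack.extend(t.downstream)
    return found


def tasks_to_skip(task, found=None):
    """Same rule as find_tasks_to_skip: follow only single-upstream tasks."""
    if found is None:
        found = []
    for t in task.downstream:
        if len(t.upstream) == 1:
            found.append(t)
            tasks_to_skip(t, found)
    return found


# "final" depends on two branches; only the "check" branch is short-circuited.
check, a, other, b, final = (Task(n) for n in ["check", "a", "other", "b", "final"])
check >> a >> final
other >> b >> final

print(sorted(t.task_id for t in all_downstream(check)))  # ['a', 'final']
print(sorted(t.task_id for t in tasks_to_skip(check)))   # ['a']
```

Only `a` is skipped by the custom rule, because `final` still has a live upstream (`b`).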