I have a task, which I'll call final,
that has multiple upstream connections. When one of the upstreams gets skipped by ShortCircuitOperator,
this task gets skipped as well.
This question is still relevant with Airflow 1.10.
ShortCircuitOperator skips all downstream tasks, regardless of the trigger_rule set on them.
The solution from @michael-spector only works for simple cases, not for this one:
with @michael-spector's solution, task L will not be skipped (only tasks E, F, G and H will be skipped).
Here is a solution (based on @michael-spector's proposal):
class ShortCircuitOperatorOnlyDirectDownStream(PythonOperator, SkipMixin):
    r"""
    Short-circuit operator that only skips tasks fed exclusively by this path.

    Works like a ShortCircuitOperator, but a downstream task is skipped only
    when it has a single upstream task or when all of its upstream tasks have
    already been marked as skipped. A task that also has an upstream task
    outside the skipped set is not skipped.

                -> B -> C -> D ------\
               /                      \
        A -> K                         -> Y
         \                            /
          -> F -> G -> P ------------/

    If K is a normal ShortCircuitOperator and the condition is False, then
    B, C, D and Y will be skipped.
    If K is a ShortCircuitOperatorOnlyDirectDownStream and the condition is
    False, then B, C and D will be skipped, but not Y.

    (Raw docstring: the diagram contains backslashes that must stay literal.)
    """

    def find_tasks_to_skip(self, task, found_tasks_to_skip=None, found_tasks_to_skip_names=None):
        """
        Recursively collect the downstream tasks of ``task`` that must be skipped.

        A direct downstream task is collected when it has exactly one upstream
        task, or when every one of its upstream task ids is already in
        ``found_tasks_to_skip_names``. Collection then continues from that task.

        :param task: task whose direct downstream tasks are examined
        :param found_tasks_to_skip: accumulator list of task objects to skip;
            a new list is created when None
        :param found_tasks_to_skip_names: accumulator set holding the task_id of
            every task in ``found_tasks_to_skip`` (kept separately so upstream
            ids can be compared by name); a new set is created when None
        :return: found_tasks_to_skip
        """
        # Compare with None (not truthiness) so that an empty accumulator passed
        # in by the caller is filled in place instead of being replaced.
        if found_tasks_to_skip is None:
            found_tasks_to_skip = []
        if found_tasks_to_skip_names is None:
            found_tasks_to_skip_names = set()

        for t in task.get_direct_relatives(upstream=False):
            if t.task_id in found_tasks_to_skip_names:
                # Already collected through another path and its downstream
                # tasks were walked then; adding it again would only create
                # duplicate entries and repeat the traversal.
                continue

            upstream_ids = t.upstream_task_ids
            self.log.info("UPSTREAM : %s", upstream_ids)
            self.log.info(
                " Does all skipped task %s contain the upstream tasks%s",
                found_tasks_to_skip_names,
                upstream_ids,
            )
            # if len == 1 then the task is only preceded by a skipped task
            # otherwise check if ALL upstream tasks are skipped
            if len(upstream_ids) == 1 or found_tasks_to_skip_names.issuperset(upstream_ids):
                found_tasks_to_skip.append(t)
                found_tasks_to_skip_names.add(t.task_id)
                self.find_tasks_to_skip(t, found_tasks_to_skip, found_tasks_to_skip_names)
        return found_tasks_to_skip

    def execute(self, context):
        """
        Evaluate the condition and skip the tasks that rely only on this path.

        The condition is the value returned by the parent class's ``execute``.
        When it is truthy nothing is skipped. Otherwise the tasks returned by
        ``find_tasks_to_skip`` for ``context['task']`` are passed to ``skip``.

        :param context: task context; the keys 'task', 'dag_run' and 'ti' are read
        :return: None
        """
        condition = super(ShortCircuitOperatorOnlyDirectDownStream, self).execute(context)
        self.log.info("Condition result is %s", condition)

        if condition:
            self.log.info('Proceeding with downstream tasks...')
            return

        self.log.info(
            'Skipping downstream tasks that only rely on this path...')

        tasks_to_skip = self.find_tasks_to_skip(context['task'])
        self.log.debug("Tasks to skip: %s", tasks_to_skip)

        if tasks_to_skip:
            self.skip(context['dag_run'], context['ti'].execution_date,
                      tasks_to_skip)

        self.log.info("Done.")