I have a task that I'll call final
that has multiple upstream connections. When one of the upstreams gets skipped by ShortCircuitOperator,
this task gets skipped as well.
I've made it work by making the final
task check the statuses of its upstream task instances. It is not beautiful, as the only way I've found to access their state was by querying the Airflow DB.
# # additional imports to ones in question code
# from airflow import AirflowException
# from airflow.models import TaskInstance
# from airflow.operators.python_operator import PythonOperator
# from airflow.settings import Session
# from airflow.utils.state import State
# from airflow.utils.trigger_rule import TriggerRule
def all_upstreams_either_succeeded_or_skipped(dag, task, task_instance, **context):
    """
    Check that every direct upstream task instance succeeded or was skipped.

    Queries the Airflow metadata DB for the task instances of the directly
    upstream tasks (same DAG, same execution date) and collects the ones whose
    state is neither SUCCESS nor SKIPPED.

    :param dag: DAG of the current task; its ``dag_id`` is used in the query.
    :param task: current task; its direct upstream relatives are inspected.
    :param task_instance: current task instance; provides the execution date.
    :param context: remaining context entries (unused).
    :return: True if no upstream task instance has a non-preferred state.
    """
    upstream_task_ids = [t.task_id for t in task.get_direct_relatives(upstream=True)]
    session = Session()
    try:
        upstream_task_instances = (
            session
            .query(TaskInstance)
            .filter(
                TaskInstance.dag_id == dag.dag_id,
                TaskInstance.execution_date.in_([task_instance.execution_date]),
                TaskInstance.task_id.in_(upstream_task_ids),
            )
            .all()
        )
        unhappy_task_instances = [
            ti for ti in upstream_task_instances
            if ti.state not in (State.SUCCESS, State.SKIPPED)
        ]
        print(unhappy_task_instances)
        return not unhappy_task_instances
    finally:
        # Release the DB connection even if the query raises; the original
        # never closed the session it opened.
        session.close()
def final_fn(**context):
    """
    Callable for the final task: abort when an upstream task instance ended
    in a state other than the accepted ones.

    :param context: keyword arguments forwarded unchanged to
        ``all_upstreams_either_succeeded_or_skipped``.
    :raises AirflowException: if that check returns a falsy value.
    """
    upstreams_ok = all_upstreams_either_succeeded_or_skipped(**context)
    if not upstreams_ok:
        raise AirflowException("Not all upstream tasks succeeded.")
    # Do things
# ALL_DONE: this task runs once every upstream task instance is done,
# failed ones included; final_fn then decides whether to fail.
final = PythonOperator(
    task_id="final",
    python_callable=final_fn,
    provide_context=True,
    trigger_rule=TriggerRule.ALL_DONE,
    dag=dag,
)
This may have been added after you asked your initial question, but Airflow now conveniently has a trigger_rule value of none_failed. If you set this on your final task, it should complete whether upstream tasks are skipped or succeeded, just not when they fail.
More info: https://airflow.apache.org/concepts.html#trigger-rules
I'm posting another possible workaround for this since this is a method that does not require a custom operator implementation.
I was influenced by the solution in this blog using a PythonOperator which raises an AirflowSkipException which skips the task itself and then downstream tasks individually.
https://godatadriven.com/blog/the-zen-of-python-and-apache-airflow/
This then respects the trigger_rule of the final downstream task, which in my case I set to trigger_rule='none_failed'.
Modified example as per the blog to include a final task:
def fn_short_circuit(**context):
    """
    Raise AirflowSkipException when the condition below holds.

    :param context: keyword arguments passed to the callable (not read here).
    :raises AirflowSkipException: if the condition is true.
    """
    # Template placeholder: replace <<<some condition>>> with a real expression.
    if <<<some condition>>>:
        raise AirflowSkipException("Skip this task and individual downstream tasks while respecting trigger rules.")
# NOTE(review): `_check_date` is not defined in this snippet and `check_date`
# is never wired into the graph below -- presumably left over from the blog
# example; confirm whether it is still needed.
check_date = PythonOperator(
    task_id="check_if_min_date",
    python_callable=_check_date,
    provide_context=True,
    dag=dag,
)

task1 = DummyOperator(task_id="task1", dag=dag)
task2 = DummyOperator(task_id="task2", dag=dag)

work = DummyOperator(dag=dag, task_id='work')

# A plain PythonOperator is used on purpose: fn_short_circuit signals the skip
# by raising AirflowSkipException and returns None otherwise. A
# ShortCircuitOperator would treat that None as a false condition and skip
# every downstream task, including final_task.
short = PythonOperator(
    dag=dag,
    task_id='short_circuit',
    python_callable=fn_short_circuit,
    provide_context=True,
)

# 'none_failed' lets final_task run when upstream tasks succeeded or were
# skipped, but not when any of them failed.
final_task = DummyOperator(
    task_id="final_task",
    trigger_rule='none_failed',
    dag=dag,
)

# The variables are named task1/task2 (not task_1/task_2).
task1 >> short >> work >> final_task
task1 >> task2 >> final_task
I ended up developing a custom ShortCircuitOperator based on the original one:
class ShortCircuitOperator(PythonOperator, SkipMixin):
    """
    Let the workflow continue only when a condition holds.

    This operator derives from the PythonOperator and uses the result of
    `python_callable` as the condition. When the condition is True the
    downstream tasks proceed as normal. When it is False the workflow
    "short-circuits": downstream tasks that only rely on this operator are
    marked with a state of "skipped".
    """

    def find_tasks_to_skip(self, task, found_tasks=None):
        """
        Collect, depth first, the downstream tasks of ``task`` that have
        exactly one upstream task.

        :param task: task whose direct downstream relatives are examined.
        :param found_tasks: list accumulated so far; a fresh list is used when
            it is None or empty.
        :return: the list of collected tasks.
        """
        collected = found_tasks if found_tasks else []
        for child in task.get_direct_relatives(upstream=False):
            if len(child.upstream_task_ids) != 1:
                # More than one upstream: leave this task (and whatever is
                # below it on this path) alone.
                continue
            collected.append(child)
            self.find_tasks_to_skip(child, collected)
        return collected

    def execute(self, context):
        """
        Evaluate the condition and skip the single-path downstream tasks when
        it is falsy.

        :param context: task context; reads ``task``, ``dag_run`` and ``ti``.
        """
        condition = super(ShortCircuitOperator, self).execute(context)
        self.log.info("Condition result is %s", condition)

        if not condition:
            self.log.info(
                'Skipping downstream tasks that only rely on this path...')
            tasks_to_skip = self.find_tasks_to_skip(context['task'])
            self.log.debug("Tasks to skip: %s", tasks_to_skip)
            if tasks_to_skip:
                execution_date = context['ti'].execution_date
                self.skip(context['dag_run'], execution_date, tasks_to_skip)
            self.log.info("Done.")
        else:
            self.log.info('Proceeding with downstream tasks...')
This operator makes sure that no downstream task that relies on multiple paths gets skipped because of one skipped task.
This question is still relevant with Airflow 1.10.
ShortCircuitOperator will skip all downstream tasks, whatever trigger_rule is set.
The solution of @michael-spector only works for simple cases, and not for this case:
with @michael-spector's solution, task L will not be skipped (only tasks E, F, G and H will be skipped).
A solution is this (based on @michael-spector's proposal):
class ShortCircuitOperatorOnlyDirectDownStream(PythonOperator, SkipMixin):
    r"""
    Works like a ShortCircuitOperator, but only skips the downstream tasks
    whose upstream tasks are all on the short-circuited path.

    If a task has this task in its upstream AND another, non-skipped task, it
    is not skipped.

    Diagram (layout reconstructed from whitespace-collapsed text -- confirm)::

                -> B -> C -> D ------\
              /                       \
        A -> K                         -> Y
         \                            /
          -> F -> G - P -------------/

    If K is a normal ShortCircuitOperator and the condition is False, then
    B, C, D and Y are skipped.
    If K is a ShortCircuitOperatorOnlyDirectDownStream and the condition is
    False, then B, C and D are skipped, but not Y.
    """

    def find_tasks_to_skip(self, task, found_tasks_to_skip=None, found_tasks_to_skip_names=None):
        """
        Collect, depth first, the downstream tasks of ``task`` to skip.

        A downstream task is collected when it has exactly one upstream task,
        or when every one of its upstream tasks has already been collected.

        :param task: task whose direct downstream relatives are examined.
        :param found_tasks_to_skip: task objects collected so far.
        :param found_tasks_to_skip_names: task_ids of the collected tasks,
            kept alongside the list for membership checks.
        :return: the list of tasks to skip.
        """
        if found_tasks_to_skip is None:
            found_tasks_to_skip = []
        if found_tasks_to_skip_names is None:
            found_tasks_to_skip_names = set()

        for t in task.get_direct_relatives(upstream=False):
            if t.task_id in found_tasks_to_skip_names:
                # Already collected via another path; collecting it again
                # would duplicate it and re-walk everything below it.
                continue
            self.log.info("UPSTREAM : %s", t.upstream_task_ids)
            self.log.info(
                " Does all skipped task %s contain the upstream tasks%s",
                found_tasks_to_skip_names,
                t.upstream_task_ids,
            )
            # A single upstream means the task is only preceded by the task
            # being walked; otherwise ALL upstream tasks must be skipped.
            if len(t.upstream_task_ids) == 1 or all(
                    elem in found_tasks_to_skip_names for elem in t.upstream_task_ids):
                found_tasks_to_skip.append(t)
                found_tasks_to_skip_names.add(t.task_id)
                self.find_tasks_to_skip(t, found_tasks_to_skip, found_tasks_to_skip_names)
        return found_tasks_to_skip

    def execute(self, context):
        """
        Evaluate the condition and, when it is falsy, skip the tasks returned
        by ``find_tasks_to_skip``.

        :param context: task context; reads ``task``, ``dag_run`` and ``ti``.
        """
        condition = super(ShortCircuitOperatorOnlyDirectDownStream, self).execute(context)
        self.log.info("Condition result is %s", condition)

        if condition:
            self.log.info('Proceeding with downstream tasks...')
            return

        self.log.info(
            'Skipping downstream tasks that only rely on this path...')

        tasks_to_skip = self.find_tasks_to_skip(context['task'])
        self.log.debug("Tasks to skip: %s", tasks_to_skip)

        if tasks_to_skip:
            self.skip(context['dag_run'], context['ti'].execution_date,
                      tasks_to_skip)

        self.log.info("Done.")