I have a task that I'll call final
that has multiple upstream connections. When one of the upstreams gets skipped by ShortCircuitOperator
this task gets skipped as well.
I've made it work by having the final
task check the states of its upstream task instances. It's not pretty, as the only way I've found to access their state is by querying the Airflow DB.
# # additional imports to ones in question code
# from airflow import AirflowException
# from airflow.models import TaskInstance
# from airflow.operators.python_operator import PythonOperator
# from airflow.settings import Session
# from airflow.utils.state import State
# from airflow.utils.trigger_rule import TriggerRule
def all_upstreams_either_succeeded_or_skipped(dag, task, task_instance, **context):
    """
    Report whether every direct upstream task instance succeeded or was skipped.

    Finds the task instances of the directly upstream tasks for the same DAG
    and execution date as ``task_instance`` and collects the ones whose state
    is neither ``State.SUCCESS`` nor ``State.SKIPPED``.

    :param dag: DAG of the current task; its ``dag_id`` scopes the query.
    :param task: current task, used to list its direct upstream tasks.
    :param task_instance: current task instance; supplies the execution date.
    :param context: remaining keyword arguments, accepted and ignored.
    :return: True if no upstream task instance has a non-preferred state.
    """
    upstream_task_ids = [t.task_id for t in task.get_direct_relatives(upstream=True)]
    session = Session()
    try:
        upstream_task_instances = (
            session.query(TaskInstance)
            .filter(
                TaskInstance.dag_id == dag.dag_id,
                TaskInstance.execution_date == task_instance.execution_date,
                TaskInstance.task_id.in_(upstream_task_ids),
            )
            .all()
        )
        unhappy_task_instances = [
            ti
            for ti in upstream_task_instances
            if ti.state not in [State.SUCCESS, State.SKIPPED]
        ]
        # Printed so the offending instances show up in the task log.
        print(unhappy_task_instances)
        return len(unhappy_task_instances) == 0
    finally:
        # This function opened the session, so it must release it too;
        # otherwise every call leaves a session/connection behind.
        session.close()
def final_fn(**context):
    """
    Abort with an AirflowException when an upstream task instance is in an
    unwanted state; otherwise continue with the task's own work.

    :param context: keyword arguments forwarded unchanged to
        ``all_upstreams_either_succeeded_or_skipped``.
    :raises AirflowException: if that check returns a falsy value.
    """
    upstreams_ok = all_upstreams_either_succeeded_or_skipped(**context)
    if not upstreams_ok:
        raise AirflowException("Not all upstream tasks succeeded.")
    # Do things
# ALL_DONE: runs once the upstream task instances are done, including failed
# ones; final_fn then decides whether to fail.
final = PythonOperator(
    task_id="final",
    python_callable=final_fn,
    provide_context=True,
    trigger_rule=TriggerRule.ALL_DONE,
    dag=dag,
)