Airflow : Run a task when some upstream is skipped by shortcircuit

后端 未结 5 437
难免孤独
难免孤独 2021-01-17 20:11

I have a task that I'll call final that has multiple upstream connections. When one of the upstreams gets skipped by ShortCircuitOperator this tas

5条回答
  •  不思量自难忘°
    2021-01-17 20:36

    I got this working by having the final task check the states of its upstream task instances. It is not pretty, because the only way I found to access their state was to query the Airflow DB.

    # # additional imports to ones in question code
    # from airflow import AirflowException
    # from airflow.models import TaskInstance
    # from airflow.operators.python_operator import PythonOperator
    # from airflow.settings import Session
    # from airflow.utils.state import State
    # from airflow.utils.trigger_rule import TriggerRule
    
    def all_upstreams_either_succeeded_or_skipped(dag, task, task_instance, **context):
        """
        Report whether every direct upstream task instance succeeded or was skipped.

        Queries the Airflow metadata DB for the task instances of ``task``'s
        direct upstream tasks that belong to the same DAG and execution date
        as ``task_instance``, and collects those whose state is neither
        ``State.SUCCESS`` nor ``State.SKIPPED``.

        :param dag: DAG whose ``dag_id`` is used to filter task instances.
        :param task: task whose direct upstream relatives are inspected.
        :param task_instance: running task instance; its ``execution_date``
            selects the DAG run to look at.
        :param context: remaining keyword arguments (ignored), so the function
            can be called with a full context dict via ``**context``.
        :return: True if no upstream task instance has a non-preferred state.
        """
        upstream_task_ids = [t.task_id for t in task.get_direct_relatives(upstream=True)]
        if not upstream_task_ids:
            # Nothing upstream means nothing can be unhappy; skip the DB round
            # trip (and an empty ``IN ()`` clause).
            return True

        session = Session()
        try:
            upstream_task_instances = (
                session
                .query(TaskInstance)
                .filter(
                    TaskInstance.dag_id == dag.dag_id,
                    TaskInstance.execution_date.in_([task_instance.execution_date]),
                    TaskInstance.task_id.in_(upstream_task_ids),
                )
                .all()
            )
            # Evaluate states while the session is still open.
            unhappy_task_instances = [
                ti for ti in upstream_task_instances
                if ti.state not in [State.SUCCESS, State.SKIPPED]
            ]
        finally:
            # Always release the DB connection; the original never closed it.
            session.close()

        print(unhappy_task_instances)
        return not unhappy_task_instances
    
    def final_fn(**context):
        """
        Callable for the final task.

        Raises AirflowException when any direct upstream task instance ended
        in a state other than success or skipped; otherwise falls through to
        the task's real work.
        """
        upstreams_ok = all_upstreams_either_succeeded_or_skipped(**context)
        if not upstreams_ok:
            raise AirflowException("Not all upstream tasks succeeded.")
        # Do things
    
    # ALL_DONE: run once every upstream task instance has finished,
    # including the ones that failed.
    final = PythonOperator(
        task_id="final",
        python_callable=final_fn,
        provide_context=True,
        trigger_rule=TriggerRule.ALL_DONE,
        dag=dag,
    )
    

提交回复
热议问题