Airflow: Run a task when some upstream is skipped by ShortCircuitOperator

难免孤独 2021-01-17 20:11

I have a task that I'll call final that has multiple upstream connections. When one of the upstreams gets skipped by ShortCircuitOperator, this task gets skipped as well. How can I make final run even when some of its upstream tasks are skipped?

5 Answers
  • 2021-01-17 20:36

    I've made it work by having the final task check the statuses of its upstream task instances. It's not beautiful, as the only way I found to access their state is by querying the Airflow DB.

    # additional imports to the ones in the question's code
    from airflow import AirflowException
    from airflow.models import TaskInstance
    from airflow.operators.python_operator import PythonOperator
    from airflow.settings import Session
    from airflow.utils.state import State
    from airflow.utils.trigger_rule import TriggerRule
    
    def all_upstreams_either_succeeded_or_skipped(dag, task, task_instance, **context):
        """
        Find directly-upstream task instances and count how many are not in preferred statuses.
        Return True if there are no instances with non-preferred statuses.
        """
        upstream_task_ids = [t.task_id for t in task.get_direct_relatives(upstream=True)]
        session = Session()
        query = (session
            .query(TaskInstance)
            .filter(
                TaskInstance.dag_id == dag.dag_id,
                TaskInstance.execution_date.in_([task_instance.execution_date]),
                TaskInstance.task_id.in_(upstream_task_ids)
            )
        )
        upstream_task_instances = query.all()
        unhappy_task_instances = [ti for ti in upstream_task_instances if ti.state not in [State.SUCCESS, State.SKIPPED]]
        print(unhappy_task_instances)
        return len(unhappy_task_instances) == 0
    
    def final_fn(**context):
        """
        fail if upstream task instances have unwanted statuses
        """
        if not all_upstreams_either_succeeded_or_skipped(**context):
            raise AirflowException("Not all upstream tasks succeeded.")
        # Do things
    
    # will run when upstream task instances are done, including failed
    final = PythonOperator(
        dag=dag,
        task_id="final",
        trigger_rule=TriggerRule.ALL_DONE,
        python_callable=final_fn,
        provide_context=True)
    
  • 2021-01-17 20:37

    This may have been added after you asked your initial question, but Airflow now conveniently has a trigger_rule value of none_failed. If you set this on your final task, it should complete whether upstream tasks are skipped or succeeded, just not when they fail.

    More info: https://airflow.apache.org/concepts.html#trigger-rules
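
    As a minimal sketch (assuming a dag object and the upstream tasks from the question already exist; the operator choice here is only illustrative):

    from airflow.operators.dummy_operator import DummyOperator

    # `final` runs when every upstream is either successful or skipped,
    # but not when any upstream has failed
    final = DummyOperator(
        task_id="final",
        trigger_rule="none_failed",
        dag=dag,
    )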

  • 2021-01-17 20:45

    I'm posting another possible workaround for this, since it is a method that does not require a custom operator implementation.

    I was influenced by the solution in this blog post, which uses a PythonOperator that raises an AirflowSkipException: the task skips itself, and downstream tasks are then skipped (or not) individually, according to their own trigger rules.

    https://godatadriven.com/blog/the-zen-of-python-and-apache-airflow/

    This then respects the trigger_rule of the final downstream task, which in my case I set to trigger_rule='none_failed'.

    Modified example from the blog, extended to include a final task:

    def fn_short_circuit(**context):
        # raising AirflowSkipException skips this task; downstream tasks are then
        # skipped (or not) individually, according to their own trigger_rule
        if <<<some condition>>>:
            raise AirflowSkipException("Skip this task and individual downstream tasks while respecting trigger rules.")

    # a plain PythonOperator takes the place of the ShortCircuitOperator
    short = PythonOperator(
        task_id="short_circuit",
        python_callable=fn_short_circuit,
        provide_context=True,
        dag=dag,
    )

    task1 = DummyOperator(task_id="task1", dag=dag)
    task2 = DummyOperator(task_id="task2", dag=dag)
    work = DummyOperator(task_id="work", dag=dag)
    final_task = DummyOperator(
        task_id="final_task",
        trigger_rule="none_failed",
        dag=dag,
    )

    task1 >> short >> work >> final_task
    task1 >> task2 >> final_task
    
  • 2021-01-17 20:52

    I've ended up developing a custom ShortCircuitOperator based on the original one:

    # imports assumed for Airflow 1.x
    from airflow.models import SkipMixin
    from airflow.operators.python_operator import PythonOperator


    class ShortCircuitOperator(PythonOperator, SkipMixin):
        """
        Allows a workflow to continue only if a condition is met. Otherwise, the
        workflow "short-circuits" and downstream tasks that only rely on this operator
        are skipped.
    
        The ShortCircuitOperator is derived from the PythonOperator. It evaluates a
        condition and short-circuits the workflow if the condition is False. Any
        downstream tasks that only rely on this operator are marked with a state of "skipped".
        If the condition is True, downstream tasks proceed as normal.
    
        The condition is determined by the result of `python_callable`.
        """
    
        def find_tasks_to_skip(self, task, found_tasks=None):
            if not found_tasks:
                found_tasks = []
            direct_relatives = task.get_direct_relatives(upstream=False)
            for t in direct_relatives:
                if len(t.upstream_task_ids) == 1:
                    found_tasks.append(t)
                    self.find_tasks_to_skip(t, found_tasks)
            return found_tasks
    
        def execute(self, context):
            condition = super(ShortCircuitOperator, self).execute(context)
            self.log.info("Condition result is %s", condition)
    
            if condition:
                self.log.info('Proceeding with downstream tasks...')
                return
    
            self.log.info(
                'Skipping downstream tasks that only rely on this path...')
    
            tasks_to_skip = self.find_tasks_to_skip(context['task'])
            self.log.debug("Tasks to skip: %s", tasks_to_skip)
    
            if tasks_to_skip:
                self.skip(context['dag_run'], context['ti'].execution_date,
                          tasks_to_skip)
    
            self.log.info("Done.")
    

    This operator makes sure that downstream tasks which rely on multiple paths do not get skipped because of one skipped path.
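
    A hypothetical usage sketch (the dag object and the task/condition names are illustrative, not from the original answer):

    def check_condition(**context):
        # return False to short-circuit the tasks that only depend on this path;
        # the real condition would go here
        return False

    check = ShortCircuitOperator(
        task_id="check_condition",
        python_callable=check_condition,
        provide_context=True,
        dag=dag,
    )

    optional_work = DummyOperator(task_id="optional_work", dag=dag)  # skipped when condition is False
    other_work = DummyOperator(task_id="other_work", dag=dag)
    final = DummyOperator(task_id="final", trigger_rule="all_done", dag=dag)  # runs once all upstreams are done

    check >> optional_work >> final
    other_work >> final  # final has another upstream path, so it is not marked skipped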

  • 2021-01-17 20:59

    This question is still relevant with Airflow 1.10:

    ShortCircuitOperator will skip all downstream tasks, whatever trigger_rule is set.

    The solution of @michael-spector only works for simple cases, not for more complex DAGs: there, the task L will not be skipped (only the E, F, G and H tasks will be skipped), because that operator only follows downstream tasks that have a single upstream task.

    A solution is this (based on @michael-spector's proposition):

    # same imports as in @michael-spector's answer (Airflow 1.x)
    from airflow.models import SkipMixin
    from airflow.operators.python_operator import PythonOperator


    class ShortCircuitOperatorOnlyDirectDownStream(PythonOperator, SkipMixin):
        """
        Works like a ShortCircuitOperator, but it only skips the tasks that have
        this task (or other skipped tasks) in their upstream.

        So if a task has this task AND another non-skipped task in its upstream,
        it will not be skipped.

                -> B -> C -> D ------\
              /                       \
        A -> K                         -> Y
         \                            /
           -> F -> G -> P ----------/

        If K is a normal ShortCircuitOperator and the condition is False,
        then B, C, D and Y will be skipped.

        If K is a ShortCircuitOperatorOnlyDirectDownStream and the condition is False,
        then B, C, D will be skipped, but not Y.

        found_tasks_to_skip_names contains the task_ids of the previously skipped tasks;
        found_tasks_to_skip contains the task objects of the previously skipped tasks,
        and is what find_tasks_to_skip returns.
        """
    
        def find_tasks_to_skip(self, task, found_tasks_to_skip=None, found_tasks_to_skip_names=None):
            if not found_tasks_to_skip:  # list of task objects to skip
                found_tasks_to_skip = []

            # a separate set of task_ids is needed because found_tasks_to_skip
            # stores task objects, not task_ids
            if not found_tasks_to_skip_names:
                found_tasks_to_skip_names = set()

            direct_relatives = task.get_direct_relatives(upstream=False)
            for t in direct_relatives:
                self.log.info("UPSTREAM : " + str(t.upstream_task_ids))
                self.log.info(
                    "Do the already-skipped tasks " +
                    str(found_tasks_to_skip_names) +
                    " contain all the upstream tasks " +
                    str(t.upstream_task_ids)
                )

                # if len == 1 then the task is only preceded by a skipped task;
                # otherwise check whether ALL upstream tasks are skipped
                if len(t.upstream_task_ids) == 1 or all(elem in found_tasks_to_skip_names for elem in t.upstream_task_ids):
                    found_tasks_to_skip.append(t)
                    found_tasks_to_skip_names.add(t.task_id)
                    self.find_tasks_to_skip(t, found_tasks_to_skip, found_tasks_to_skip_names)

            return found_tasks_to_skip
    
        def execute(self, context):
            condition = super(ShortCircuitOperatorOnlyDirectDownStream, self).execute(context)
            self.log.info("Condition result is %s", condition)

            if condition:
                self.log.info('Proceeding with downstream tasks...')
                return

            self.log.info(
                'Skipping downstream tasks that only rely on this path...')

            tasks_to_skip = self.find_tasks_to_skip(context['task'])
            self.log.debug("Tasks to skip: %s", tasks_to_skip)

            if tasks_to_skip:
                self.skip(context['dag_run'], context['ti'].execution_date,
                          tasks_to_skip)

            self.log.info("Done.")
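
    As a hypothetical wiring sketch matching the diagram in the docstring (the dag object and the always-False condition are illustrative only):

    from airflow.operators.dummy_operator import DummyOperator

    a = DummyOperator(task_id="A", dag=dag)
    k = ShortCircuitOperatorOnlyDirectDownStream(
        task_id="K",
        python_callable=lambda: False,  # condition is False, so the B/C/D branch is skipped
        dag=dag,
    )
    b = DummyOperator(task_id="B", dag=dag)
    c = DummyOperator(task_id="C", dag=dag)
    d = DummyOperator(task_id="D", dag=dag)
    f = DummyOperator(task_id="F", dag=dag)
    g = DummyOperator(task_id="G", dag=dag)
    p = DummyOperator(task_id="P", dag=dag)
    y = DummyOperator(task_id="Y", trigger_rule="none_failed", dag=dag)

    a >> k >> b >> c >> d >> y
    a >> f >> g >> p >> y
    # B, C and D are marked skipped by K; Y is not, and with trigger_rule="none_failed"
    # it still runs through the F -> G -> P path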
    