Airflow: Run a task when some of its upstream tasks are skipped by ShortCircuitOperator

难免孤独 2021-01-17 20:11

I have a task, which I'll call "final", that has multiple upstream connections. When one of the upstreams gets skipped by ShortCircuitOperator this tas

  •  臣服心动
    2021-01-17 20:52

    I ended up developing a custom ShortCircuitOperator based on the original one:

    from airflow.models import SkipMixin
    # Airflow 1.10 import path; on Airflow 2.x PythonOperator lives in airflow.operators.python
    from airflow.operators.python_operator import PythonOperator
    
    
    class ShortCircuitOperator(PythonOperator, SkipMixin):
        """
        Allows a workflow to continue only if a condition is met. Otherwise, the
        workflow "short-circuits" and downstream tasks that only rely on this operator
        are skipped.
    
        The ShortCircuitOperator is derived from the PythonOperator. It evaluates a
        condition and short-circuits the workflow if the condition is False. Any
        downstream tasks that only rely on this operator are marked with a state of "skipped".
        If the condition is True, downstream tasks proceed as normal.
    
        The condition is determined by the result of `python_callable`.
        """
    
        def find_tasks_to_skip(self, task, found_tasks=None):
            # Recursively collect downstream tasks whose only upstream is the task
            # being visited; a task with several upstreams (and everything below it)
            # is left alone.
            if found_tasks is None:
                found_tasks = []
            direct_relatives = task.get_direct_relatives(upstream=False)
            for t in direct_relatives:
                if len(t.upstream_task_ids) == 1:
                    found_tasks.append(t)
                    self.find_tasks_to_skip(t, found_tasks)
            return found_tasks
    
        def execute(self, context):
            # PythonOperator.execute runs python_callable; its return value is the condition.
            condition = super(ShortCircuitOperator, self).execute(context)
            self.log.info("Condition result is %s", condition)
    
            if condition:
                self.log.info('Proceeding with downstream tasks...')
                return
    
            self.log.info(
                'Skipping downstream tasks that only rely on this path...')
    
            tasks_to_skip = self.find_tasks_to_skip(context['task'])
            self.log.debug("Tasks to skip: %s", tasks_to_skip)
    
            if tasks_to_skip:
                # Mark only the single-upstream chain as skipped for this DAG run.
                self.skip(context['dag_run'], context['ti'].execution_date,
                          tasks_to_skip)
    
            self.log.info("Done.")
    

    This operator ensures that a downstream task relying on multiple paths is not skipped just because one of those paths was short-circuited.
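    To see which tasks the custom rule actually selects, here is a minimal, Airflow-free sketch in plain Python. It assumes the DAG is modelled as a dict mapping each task id to its direct downstream task ids (the graph and the task names check, a1, a2, other, final and report are hypothetical). It reimplements the len(upstream_task_ids) == 1 rule from find_tasks_to_skip and, for contrast, collects every descendant, which is assumed here to approximate what the stock operator skips.

```python
from typing import Dict, List, Set

# Hypothetical toy DAG: task id -> ids of its direct downstream tasks.
Graph = Dict[str, List[str]]


def upstream_ids(graph: Graph) -> Dict[str, Set[str]]:
    """Invert the edge map so every task knows its direct upstream task ids."""
    ups: Dict[str, Set[str]] = {t: set() for t in graph}
    for parent, children in graph.items():
        for child in children:
            ups.setdefault(child, set()).add(parent)
    return ups


def find_tasks_to_skip(graph: Graph, task: str) -> List[str]:
    """Same rule as the custom operator: descend only into tasks that have
    exactly one upstream; stop at tasks fed by several paths."""
    ups = upstream_ids(graph)
    found: List[str] = []

    def visit(current: str) -> None:
        for child in graph.get(current, []):
            if len(ups[child]) == 1:
                found.append(child)
                visit(child)

    visit(task)
    return found


def all_descendants(graph: Graph, task: str) -> List[str]:
    """Every task reachable downstream (assumed stock-operator behaviour)."""
    seen: List[str] = []
    stack = list(graph.get(task, []))
    while stack:
        t = stack.pop()
        if t not in seen:
            seen.append(t)
            stack.extend(graph.get(t, []))
    return seen


dag: Graph = {
    "check": ["a1"],   # the short-circuit task
    "a1": ["a2"],
    "a2": ["final"],
    "other": ["final"],  # second, independent path into final
    "final": ["report"],
    "report": [],
}

print(find_tasks_to_skip(dag, "check"))  # ['a1', 'a2']
print(all_descendants(dag, "check"))     # ['a1', 'a2', 'final', 'report']
```

    Only a1 and a2 are selected; final is left alone because it has a second upstream (other), so nothing below it is visited either. The sketch covers only what the operator marks as skipped itself; whether final then runs still depends on its trigger_rule.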
