Airflow: Run a task when some upstream is skipped by ShortCircuitOperator

难免孤独 2021-01-17 20:11

I have a task that I'll call final that has multiple upstream connections. When one of the upstreams gets skipped by ShortCircuitOperator this task…

5 answers
  •  隐瞒了意图╮
    2021-01-17 20:59

    This question is still relevant with Airflow 1.10.

    ShortCircuitOperator skips all downstream tasks, regardless of the trigger_rule set on them.
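
To make that behaviour concrete, here is a minimal pure-Python sketch rather than Airflow code. It assumes a small hypothetical DAG (A -> K -> B -> C -> D -> Y, plus a second branch A -> F -> G -> P -> Y) and a hypothetical helper, flat_downstream, that models what the stock operator in 1.10 appears to do: gather every direct and indirect downstream task and mark them all as skipped, so their trigger_rule never gets a chance to be evaluated.

```python
from collections import deque

# Hypothetical DAG, written as task_id -> direct downstream task_ids.
DOWNSTREAM = {
    "A": ["K", "F"],
    "K": ["B"],
    "B": ["C"],
    "C": ["D"],
    "D": ["Y"],
    "F": ["G"],
    "G": ["P"],
    "P": ["Y"],
    "Y": [],
}


def flat_downstream(task_id, downstream=DOWNSTREAM):
    """Return every task reachable downstream of task_id (direct or indirect)."""
    seen = set()
    queue = deque(downstream[task_id])
    while queue:
        current = queue.popleft()
        if current in seen:
            continue
        seen.add(current)
        queue.extend(downstream[current])
    return seen


# If K short-circuits, every task below it is marked skipped, including Y,
# even though Y also has the untouched upstream branch F -> G -> P.
print(sorted(flat_downstream("K")))  # ['B', 'C', 'D', 'Y']
```

In this model Y ends up in the skipped set purely because it is reachable from K, which is the situation the question describes.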


    The solution from @michael-spector only works for simple cases, not for this one:

    With @michael-spector's solution, task L will not be skipped (only tasks E, F, G and H will be skipped).

    Here is a solution (based on @michael-spector's proposal):

    # Imports for Airflow 1.10.x; the module paths differ in later releases.
    from airflow.models import SkipMixin
    from airflow.operators.python_operator import PythonOperator


    class ShortCircuitOperatorOnlyDirectDownStream(PythonOperator, SkipMixin):
        r"""
        Works like a ShortCircuitOperator, but only skips the tasks whose upstream
        tasks are all on the short-circuited path.

        A task that has this task in its upstream AND another, non-skipped task
        is not skipped.

                -> B -> C -> D ------\
              /                       \
        A -> K                         -> Y
         \                            /
           -> F -> G -> P ----------/

        If K is a normal ShortCircuitOperator and the condition is False, then
        B, C, D and Y are skipped.

        If K is a ShortCircuitOperatorOnlyDirectDownStream and the condition is
        False, then B, C and D are skipped, but not Y.
        """

        def find_tasks_to_skip(self, task, found_tasks_to_skip=None, found_tasks_to_skip_names=None):
            """
            Recursively collect the downstream tasks that depend only on skipped tasks.

            found_tasks_to_skip holds the task objects collected so far;
            found_tasks_to_skip_names holds their task_ids.

            :return: found_tasks_to_skip
            """
            if found_tasks_to_skip is None:  # list of tasks to skip
                found_tasks_to_skip = []

            # Needed because found_tasks_to_skip stores task objects, not task_ids.
            if found_tasks_to_skip_names is None:
                found_tasks_to_skip_names = set()

            direct_relatives = task.get_direct_relatives(upstream=False)
            for t in direct_relatives:
                self.log.info("UPSTREAM : %s", t.upstream_task_ids)
                self.log.info(
                    "Do the skipped tasks %s contain the upstream tasks %s?",
                    found_tasks_to_skip_names, t.upstream_task_ids,
                )

                # Avoid collecting the same task twice when it is reachable through several paths.
                if t.task_id in found_tasks_to_skip_names:
                    continue

                # If len == 1, the task is preceded only by a skipped task;
                # otherwise check whether ALL upstream tasks are skipped.
                if len(t.upstream_task_ids) == 1 or all(
                        elem in found_tasks_to_skip_names for elem in t.upstream_task_ids):
                    found_tasks_to_skip.append(t)
                    found_tasks_to_skip_names.add(t.task_id)
                    self.find_tasks_to_skip(t, found_tasks_to_skip, found_tasks_to_skip_names)

            return found_tasks_to_skip

        def execute(self, context):
            # PythonOperator.execute returns the value of python_callable.
            condition = super(ShortCircuitOperatorOnlyDirectDownStream, self).execute(context)
            self.log.info("Condition result is %s", condition)

            if condition:
                self.log.info('Proceeding with downstream tasks...')
                return

            self.log.info(
                'Skipping downstream tasks that only rely on this path...')

            tasks_to_skip = self.find_tasks_to_skip(context['task'])
            self.log.debug("Tasks to skip: %s", tasks_to_skip)

            if tasks_to_skip:
                self.skip(context['dag_run'], context['ti'].execution_date,
                          tasks_to_skip)

            self.log.info("Done.")
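
The traversal can be checked without an Airflow installation. The sketch below is a simplified stand-in, not the operator itself: it assumes the DAG drawn in the docstring, encoded in a hypothetical DOWNSTREAM map of plain task_id strings, and the helpers upstream_ids and find_tasks_to_skip are illustrative names that mirror the rule used in the class (skip a task when it has a single upstream, or when all of its upstream tasks are already skipped).

```python
# Hypothetical encoding of the docstring's DAG: task_id -> direct downstream task_ids.
DOWNSTREAM = {
    "A": ["K", "F"],
    "K": ["B"],
    "B": ["C"],
    "C": ["D"],
    "D": ["Y"],
    "F": ["G"],
    "G": ["P"],
    "P": ["Y"],
    "Y": [],
}


def upstream_ids(task_id, downstream):
    """Direct upstream task_ids, derived from the downstream map."""
    return {parent for parent, children in downstream.items() if task_id in children}


def find_tasks_to_skip(task_id, skipped=None, downstream=DOWNSTREAM):
    """Collect downstream tasks whose upstream tasks all lie on the skipped path."""
    if skipped is None:
        skipped = []
    for child in downstream[task_id]:
        if child in skipped:
            continue  # already collected through another path
        parents = upstream_ids(child, downstream)
        # Same rule as the operator: a single upstream, or every upstream already skipped.
        if len(parents) == 1 or parents <= set(skipped):
            skipped.append(child)
            find_tasks_to_skip(child, skipped, downstream)
    return skipped


# K short-circuits: B, C and D are skipped, while Y survives because its
# other upstream task P is not on the skipped path.
print(find_tasks_to_skip("K"))  # ['B', 'C', 'D']
```

Y is left alone here, so its own trigger_rule decides whether it runs once D is skipped and P has finished.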
    
