Airflow branching: A task that only sometimes depends on an upstream task


I have two tasks: task_a and task_b. There are DAG parameters run_task_a and run_task_b that determine whether each task should be run. There is a further parameter that is an input for task_a. Here's the important part:

If task_a is run, then task_b should start only after task_a has finished. However, if task_a is not run, then task_b can start whenever.

(Motivation: task_a is the main task. A new run of task_a can result in defunct artifacts, which task_b cleans up. However, one may wish to trigger task_b independently.)

This is what I have written so far:

from airflow.decorators import dag, task
from airflow.models.param import Param
from datetime import datetime

default_args = {
  'owner': 'xyz',
  'email_on_retry': False,
  'email_on_failure': False,
  'retries': 0,
  'provide_context': True,
  'depends_on_past': False
}

@dag(
  default_args=default_args,
  start_date=datetime(2024, 3, 7),
  schedule_interval=None,
  params={
    'run_task_a': Param(
      True,
      type='boolean'),
    'run_task_b': Param(
      True,
      type='boolean'),
    'param_for_task_a': Param(
      'foo',
      enum=['foo','bar'],
      type='string')
      }
)
def my_dag():

  @task
  def get_context_values(**context):

    context_values = dict()
    context_values['params'] = context['params']

    return context_values

  @task.branch
  def branching(context_values):
    tasks_to_run = []

    if context_values['params']['run_task_a']:
      tasks_to_run.append('task_a')

    if context_values['params']['run_task_b']:
      tasks_to_run.append('task_b')

    return tasks_to_run

  @task
  def task_a(context_values):

    param_for_task_a = context_values['params']['param_for_task_a']

    if param_for_task_a == 'foo':
      # Do some stuff
      pass

    if param_for_task_a == 'bar':
      # Do some different stuff
      pass

    return None

  @task
  def task_b():

    # Do some more stuff
    
    return None

  # Taskflow
  context_values = get_context_values()
  branching(context_values) >> [task_a(context_values),task_b()]

my_dag()


The problem arises when run_task_a == True and run_task_b == True: both tasks run, but task_b does not wait for task_a to finish before starting, because there is no dependency between them. I've tried to add this dependency by making task_b a downstream task of task_a, but then task_b does not run when run_task_a == False and run_task_b == True. Trigger rules alone also don't seem to be the solution, since task_b should not run at all when run_task_b == False.

There are 3 answers

Accepted answer, by dwolfeu:

After a lot of trial and error, we managed to get it to work using short-circuiting. Two details make it work: short_circuit_a is created with ignore_downstream_trigger_rules=False, so when it short-circuits it skips only task_a and leaves task_b's trigger rule in effect; and task_b uses the none_failed_min_one_success trigger rule, so it runs once task_a has either succeeded or been skipped, provided short_circuit_b succeeded (and is itself skipped when short_circuit_b short-circuits):

from airflow.decorators import dag, task
from airflow.models.param import Param
from airflow.utils.trigger_rule import TriggerRule
from datetime import datetime

default_args = {
  'owner': 'xyz',
  'email_on_retry': False,
  'email_on_failure': False,
  'retries': 0,
  'provide_context': True,
  'depends_on_past': False
}

@dag(
  default_args=default_args,
  start_date=datetime(2024, 3, 7),
  schedule_interval=None,
  params={
    'run_task_a': Param(
      True,
      type='boolean'),
    'run_task_b': Param(
      True,
      type='boolean'),
    'param_for_task_a': Param(
      'foo',
      enum=['foo','bar'],
      type='string')
      }
)
def my_dag():

  @task
  def get_context_values(**context):

    context_values = dict()
    context_values['params'] = context['params']

    return context_values

  @task.short_circuit
  def short_circuit(context_values,key):
    return context_values['params'][key]

  @task
  def task_a(context_values):

    param_for_task_a = context_values['params']['param_for_task_a']

    if param_for_task_a == 'foo':
      # Do some stuff
      pass

    if param_for_task_a == 'bar':
      # Do some different stuff
      pass

    return None

  @task(trigger_rule=TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS)
  def task_b():

    # Do some more stuff
    
    return None

  # Taskflow
  context_values = get_context_values()
  short_circuit_a = short_circuit.override(
    task_id='short_circuit_a',ignore_downstream_trigger_rules=False)(context_values,'run_task_a')
  a = task_a(context_values)
  short_circuit_b = short_circuit.override(
    task_id='short_circuit_b')(context_values,'run_task_b')
  b = task_b()
  short_circuit_a >> a
  short_circuit_b >> b
  a >> b

my_dag()


Answer by ghowkay:

You can modify the branching function to control the flow based on the parameters. You'll need to make sure task_b depends on task_a only when task_a is supposed to run.

from airflow.decorators import dag, task
from airflow.models.param import Param
from datetime import datetime
from datetime import datetime

default_args = {
  'owner': 'xyz',
  'email_on_retry': False,
  'email_on_failure': False,
  'retries': 0,
  'provide_context': True,
  'depends_on_past': False
}

@dag(
  default_args=default_args,
  start_date=datetime(2024, 3, 7),
  schedule_interval=None,
  params={
    'run_task_a': Param(True, type='boolean'),
    'run_task_b': Param(True, type='boolean'),
    'param_for_task_a': Param('foo', enum=['foo', 'bar'], type='string')
  }
)
def my_dag():

  @task
  def get_context_values(**context):
    return context['params']

  @task
  def task_a(param_for_task_a):
    if param_for_task_a == 'foo':
      # Do some stuff
      pass
    elif param_for_task_a == 'bar':
      # Do some different stuff
      pass
    return None

  @task
  def task_b():
    # Do some more stuff
    return None

  # Taskflow
  context_values = get_context_values()

  # Branching logic
  # NB: context_values is an XComArg reference, so the `if` below is
  # evaluated at DAG parse time (where XComArg objects are always truthy),
  # not at run time -- per-run param values will not affect which branch
  # of this conditional builds the DAG.
  run_task_a = context_values['run_task_a']
  run_task_b = context_values['run_task_b']

  if run_task_a and run_task_b:
      task_a_output = task_a(context_values['param_for_task_a'])
      task_a_output >> task_b()
  elif run_task_b:
      task_b()
  # Add an else clause if you want to handle cases where both are False

my_dag_instance = my_dag()
Answer by Collin McNulty:

A straightforward fix would be to make task_b dependent on task_a with >> while also changing task_b's trigger_rule to none_failed. In the case where your branching causes task_a to be skipped, task_b will still run even though its upstream task did not succeed.
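A minimal sketch of this approach, reusing the branching task and task names from the question (the business logic inside task_a and task_b is elided; this is a DAG-structure illustration, not a tested implementation):

```python
from datetime import datetime

from airflow.decorators import dag, task
from airflow.models.param import Param
from airflow.utils.trigger_rule import TriggerRule


@dag(
  start_date=datetime(2024, 3, 7),
  schedule_interval=None,
  params={
    'run_task_a': Param(True, type='boolean'),
    'run_task_b': Param(True, type='boolean'),
  }
)
def my_dag():

  @task.branch
  def branching(**context):
    # Select only the tasks enabled by the DAG params; tasks not
    # returned here are skipped by the branch operator.
    tasks_to_run = []
    if context['params']['run_task_a']:
      tasks_to_run.append('task_a')
    if context['params']['run_task_b']:
      tasks_to_run.append('task_b')
    return tasks_to_run

  @task
  def task_a():
    pass

  # none_failed: run as long as no upstream task failed -- a skipped
  # task_a does not block task_b. When the branch does not select
  # task_b at all, the branch skips it directly, so the trigger rule
  # never resurrects it.
  @task(trigger_rule=TriggerRule.NONE_FAILED)
  def task_b():
    pass

  a = task_a()
  b = task_b()
  branching() >> [a, b]
  a >> b  # task_b waits for task_a whenever task_a actually runs


my_dag()
```

This is the standard "join after a branch" pattern: the explicit a >> b edge enforces ordering when both tasks are selected, while the relaxed trigger rule keeps task_b runnable when task_a is skipped.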