Override TaskGroup in Airflow

115 views Asked by At

The problem is that I have a loop that generates several tasks (let's say two: "task_a" and "task_b"). "task_a" returns a value that I get from XCom, and I pass this value as a parameter to "task_b". A group is created inside "task_b"; it cannot be created externally (technically it can, but it must be done this way for auto-generation). The issue is that if you create "task_a", for example, through PythonOperator without a task_group, then the group of this task will be the root group by default. I need to override the group of "task_a" so that it is the same as that of "task_b". I can get the TaskGroup itself from the last task, but how can I override it for "task_a"?

import pendulum
from airflow.decorators import dag
from airflow.utils.task_group import TaskGroup
from airflow.operators.empty import EmptyOperator


class Macros:
    """Factory for a per-load sub-group and the action task placed inside it."""

    # Most recent TaskGroup built by get_task(); stays None until the first call.
    task_group: TaskGroup = None

    @staticmethod
    def get_task(parent_group: TaskGroup, load_name: str = ""):
        """Create ``subgroup_<load_name>`` under *parent_group* and return its task.

        The freshly created group is also published on ``Macros.task_group``
        so that the caller can read it back after the call.
        """
        subgroup = TaskGroup(
            parent_group=parent_group,
            group_id=f"subgroup_{load_name}",
        )
        # Expose the new group through the class attribute before building the task.
        Macros.task_group = subgroup
        return EmptyOperator(
            task_group=subgroup,
            task_id=f'some_action_{load_name}',
        )


# DAG with no schedule (schedule_interval=None) and catchup disabled, used to
# reproduce the grouping problem: "load_*" tasks should live in the same
# sub-group as the matching "some_action_*" task.
@dag(
    dag_id="test_dag",
    schedule_interval=None,
    start_date=pendulum.datetime(2024, 1, 1, tz="UTC"),
    catchup=False,
    render_template_as_native_obj=True
)
def load_test():
    start = EmptyOperator(task_id="start_load")
    # Parent group; Macros.get_task attaches one sub-group per load name to it.
    main_group = TaskGroup(group_id='main_group')
    my_list_loads = ["a", "b", "c"]
    for load_name in my_list_loads:
    
        # Created without a task_group argument; per the question text this
        # puts task_a in the root group instead of subgroup_<load_name>.
        task_a = EmptyOperator(
            task_id=f'load_{load_name}',
        )
        # Builds subgroup_<load_name> under main_group and returns its task;
        # the new group is stored on Macros.task_group as a side effect.
        task_b = Macros.get_task(load_name=load_name, parent_group=main_group)
        task_group = Macros.task_group  # the sub-group that get_task just created
        # task_group.add(task_a)  # raises an error; this is the step the question asks about
        task_a >> task_b    
    end = EmptyOperator(task_id="end_load")
    # Chain: start -> main_group -> end.
    start >> main_group >> end

# Calling the decorated function produces the DAG object kept at module level.
load_test_dag = load_test()

Expected graph:

graph

Important! subgroup_a/b/c does not exist before the "Macros.get_task" function is called. Consider that inside the Macros class there is a task_group attribute, which, after calling the static method "get_task", takes the value subgroup_a/b/c.

1

There is 1 answer

2
Danila Ganchar On

Updated


import pendulum
from airflow import DAG
from airflow.operators.empty import EmptyOperator
from airflow.utils.task_group import TaskGroup

# Module-level DAG: no schedule (schedule_interval=None) and at most one
# active run at a time (max_active_runs=1).
dag = DAG(
    dag_id='test',
    schedule_interval=None,
    max_active_runs=1,
    start_date=pendulum.datetime(2024, 1, 1),
)


class Macros:
    """Builds one sub-group plus its action task for a given load name."""

    @staticmethod
    def get_task(parent_group: TaskGroup, load_name: str, _dag: DAG):
        """Return a new ``some_action_<load_name>`` task inside a fresh sub-group.

        A ``subgroup_<load_name>`` TaskGroup is created under *parent_group*,
        and both the group and the returned task are bound to *_dag*. The
        group is passed to the task as its ``task_group`` argument.
        """
        subgroup = TaskGroup(
            dag=_dag,
            parent_group=parent_group,
            group_id=f'subgroup_{load_name}',
        )
        action_task = EmptyOperator(
            dag=_dag,
            task_group=subgroup,
            task_id=f'some_action_{load_name}',
        )
        return action_task


# Parent group that receives one sub-group per load name.
main_group = TaskGroup(dag=dag, group_id='main_group')
for name in ['a', 'b', 'c']:
    # Creates subgroup_<name> under main_group and returns its action task.
    action = Macros.get_task(main_group, name, dag)
    # Built with only a task_id: no dag= and no task_group= is passed here.
    load = EmptyOperator(task_id=f'load_{name}')
    # Put the load task into the same sub-group as the action task.
    action.task_group.add(load)

    # Order the pair: load first, then action.
    load >> action

JFYI: DummyOperator is deprecated; use EmptyOperator instead.