Airflow の Callback で複数の Task からなる処理のリトライを行う

airflowetl

EMR クラスタで処理を行う際、EmrAddStepsOperator で EMR クラスタに Step を追加した後、EmrStepSensor でその実行が終わるのを待つが、 Step の処理が失敗しても Failed するのは Sensor の方なので、リトライしても Step が再実行されないという問題がある。

CDK で Amazon Managed Workflow for Apache Airflow (MWAA) の環境を作成しワークフローを実行する - sambaiz-net

複数の Task からなる処理を TaskGroup に入れると UI 上はまとめられるが、Task のようなリトライの設定は現状できない。 また、SubDAG は パフォーマンスなどの問題があり deprecated となっている。 そこで Task の失敗時などに呼ばれる Callback を用いてリトライされるようにする。

Callback には次のようなオブジェクトが渡される。

# def on_failure(context):
#   print(context)
{
  'conf': <***.configuration.AirflowConfigParser object at 0xffffa02942d0>, 
  'dag': <DAG: test_callback>, 
  'dag_run': <DagRun test_callback @ 2022-12-18 03:11:29.683523+00:00: manual__2022-12-18T03:11:29.683523+00:00, state:running, queued_at: 2022-12-18 03:11:29.759660+00:00. externally triggered: True>, 
  'data_interval_end': DateTime(2022, 12, 18, 3, 0, 0, tzinfo=Timezone('UTC')), 
  'data_interval_start': DateTime(2022, 12, 18, 2, 0, 0, tzinfo=Timezone('UTC')), 
  'ds': '2022-12-18', 
  'ds_nodash': '20221218', 
  'execution_date': DateTime(2022, 12, 18, 3, 11, 29, 683523, tzinfo=Timezone('UTC')), 
  'inlets': [], 
  'logical_date': DateTime(2022, 12, 18, 3, 11, 29, 683523, tzinfo=Timezone('UTC')), 
  'macros': <module '***.macros' from '/home/***/.local/lib/python3.7/site-packages/***/macros/__init__.py'>, 
  'next_ds': '2022-12-18', 
  'next_ds_nodash': '20221218', 
  'next_execution_date': DateTime(2022, 12, 18, 3, 11, 29, 683523, tzinfo=Timezone('UTC')), 
  'outlets': [], 
  'params': {}, 
  'prev_data_interval_start_success': DateTime(2022, 12, 18, 2, 0, 0, tzinfo=Timezone('UTC')), 
  'prev_data_interval_end_success': DateTime(2022, 12, 18, 3, 0, 0, tzinfo=Timezone('UTC')), 
  'prev_ds': '2022-12-18', 
  'prev_ds_nodash': '20221218', 
  'prev_execution_date': DateTime(2022, 12, 18, 3, 11, 29, 683523, tzinfo=Timezone('UTC')), 
  'prev_execution_date_success': DateTime(2022, 12, 18, 3, 0, 9, 505185, tzinfo=Timezone('UTC')), 
  'prev_start_date_success': DateTime(2022, 12, 18, 3, 0, 9, 577917, tzinfo=Timezone('UTC')), 
  'run_id': 'manual__2022-12-18T03:11:29.683523+00:00', 
  'task': <Task(BashOperator): task2>, 
  'task_instance': <TaskInstance: test_callback.task2 manual__2022-12-18T03:11:29.683523+00:00 [failed]>, 
  'task_instance_key_str': 'test_callback__task2__20221218', 
  'test_mode': False, 
  'ti': <TaskInstance: test_callback.task2 manual__2022-12-18T03:11:29.683523+00:00 [failed]>, 
  'tomorrow_ds': '2022-12-19', 
  'tomorrow_ds_nodash': '20221219', 
  'triggering_dataset_events': <Proxy at 0xffff8b9e2a50 with factory <function TaskInstance.get_template_context.<locals>.get_triggering_events at 0xffff8ba4b950>>, 
  'ts': '2022-12-18T03:11:29.683523+00:00', 
  'ts_nodash': '20221218T031129', 
  'ts_nodash_with_tz': '20221218T031129.683523+0000', 
  'var': {'json': None, 'value': None}, 
  'conn': None, 
  'yesterday_ds': '2022-12-17', 
  'yesterday_ds_nodash': '20221217', 
  'exception': AirflowException('Bash command failed. The command returned a non-zero exit code 1.')}

次の DAG は EMR の例と同様に、Failed するのは 2 つ目の Task で、それのみをリトライしても結果は変わらないが、 そのリトライ/失敗時に upstream の Task の state を併せて更新することで 1 つの Task のようにリトライ/失敗するようにしている。

from datetime import timedelta
from airflow import DAG
from airflow.utils.state import State
from airflow.operators.bash import BashOperator
from airflow.utils.dates import days_ago
from datetime import timedelta

def retry_upstream(context):
  tasks = context["dag_run"].get_task_instances()
  for task in tasks:
    if task.task_id in context["task"].upstream_task_ids:
      task.set_state(State.UP_FOR_RETRY)

def fail_upstream(context):
  tasks = context["dag_run"].get_task_instances()
  for task in tasks:
    if task.task_id in context["task"].upstream_task_ids:
      task.set_state(State.FAILED)

with DAG(
  'test_callback',
  start_date=days_ago(1),
  default_args={
    'retries': 1,
    'retry_delay': timedelta(seconds=10),
    'retry_exponential_backoff': True
  }
) as dag:

  t1 = BashOperator(
    task_id='task1',
    bash_command='echo $(expr $RANDOM % 3)',
  )
  t2 = BashOperator(
    task_id='task2',
    bash_command='exit {{ti.xcom_pull(task_ids="task1")}}',
    on_retry_callback=retry_upstream,
    on_failure_callback=fail_upstream,
  )

  t1 >> t2

EMR の場合 DescribeStep が Rate exceeded で失敗することもあるので、エラー内容によって 1 つ目の Task を再実行するか判断すると良い。

should_readd_task_to_emr = 'ThrottlingException' in str(context['exception'])

参考

Bet you didn’t know this about Airflow! | by Jyoti Dhiman | Towards Data Science