SageMakerのBatch Transformのパラメータの挙動をentrypointの関数の呼び出しと引数から確認する

awsmachinelearning

SageMaker の Batch Transform は単発のバッチ推論ジョブを実行する機能。 その際推論エンドポイントなどの場合と同じく呼ばれる Model の entrypoint の関数とその引数から、ジョブのパラメータがどのようにはたらくかを確認する。

SageMakerで学習したPyTorchのモデルをElastic Inferenceを有効にしてデプロイする - sambaiz-net

from sagemaker.transformer import Transformer

transformer = Transformer(
  model_name=model_name,
  instance_type='ml.m5.xlarge',
  instance_count=1,
  output_path=f's3://{os.getenv("S3_BUCKET")}/output/',
  # strategy='SingleRecord',
  # max_payload=1,
  # assemble_with='Line',
  # accept='application/json',
)
transformer.transform(
  data=f's3://{os.getenv("S3_BUCKET")}/batch_input/',
  content_type='application/json',
  # compression_type='Gzip',
  # split_type='Line',
  # output_filter="$['SageMakerOutput','value']",
  # join_source='Input',
)

SageMaker Inference Toolkit の Transformer クラスを見ると transform_fn() から次の関数が呼ばれていたのでそれらの引数を print してみる。

import json

class Model:
  warmup = False
  def predict(self, data):
    if not self.warmup:
      raise Exception("Model not warmed up")
    return ["🍣", "🍵"][data % 2]

def pre_model_fn(model_dir, context=None):
  print("pre_model_fn")
  print(model_dir, context)

def model_fn(model_dir, context=None):
  print("model_fn")
  print(model_dir, context)
  return Model()

def model_warmup_fn(model_dir, model, context=None):
  print("model_warmup_fn")
  print(model_dir, model, context)
  model.warmup = True

def input_fn(input_data, content_type, context=None):
  print("input_fn")
  print(input_data, content_type, context)
  return json.loads(input_data)

def predict_fn(input_data, model, context=None):
  print("predict_fn")
  print(input_data, model, context)
  return model.predict(input_data['value'])

def output_fn(prediction, accept, context=None):
  print("output_fn")
  print(prediction, accept, context)
  return json.dumps({"output": f"I want to {prediction}"}, ensure_ascii=False)

次のようなjsonが 100000 行含まれる2ファイル (各 1.2 MB)を入力データとする。

$ cat testdata1.jsonl
{"value":1}
{"value":3}
{"value":4}
{"value":15}
{"value":13}
...
{"value":3000}

strategy=‘SingleRecord’, max_payload=1 (MB) を渡したとき

ファイルの内容がそのまま渡された結果 max_payload にひっかかり Too much data for max payload size エラーになった。

2023-08-13T10:31:04,142 [INFO ] W-9003-model_1.0-stdout MODEL_LOG - input_fn
2023-08-13T10:40:59,040 [INFO ] W-9002-model_1.0-stdout MODEL_LOG - {"value":1}
2023-08-13T10:40:59,040 [INFO ] W-9002-model_1.0-stdout MODEL_LOG - {"value":3}
2023-08-13T10:40:59,040 [INFO ] W-9002-model_1.0-stdout MODEL_LOG - {"value":4}
2023-08-13T10:40:59,040 [INFO ] W-9002-model_1.0-stdout MODEL_LOG - {"value":15}
...
2023-08-13T10:40:59,061 [INFO ] W-9002-model_1.0-stdout MODEL_LOG - {"value":3} application/json <ts.context.Context object at 0x7f4cde8cef10>
...
on <ts.context.Context object at 0x7f4cde8cef10>
2023-08-13T10:40:59,061 [INFO ] W-9002-model_1.0-stdout MODEL_METRICS - PredictionTime.Milliseconds:0.51|#ModelName:model,Level:Model|#hostname:26f3ed57bfc4,requestID:987d258e-61f6-4dd1-9ab9-7cea700d688d,timestamp:1691923258

2023-08-13T10:40:57.646:[sagemaker logs]: MaxConcurrentTransforms=1, MaxPayloadInMB=6, BatchStrategy=SINGLE_RECORD
2023-08-13T10:40:58.380:[sagemaker logs]: ****/batch_input/testdata2.jsonl: Bad HTTP status received from algorithm: 500
2023-08-13T10:40:58.381:[sagemaker logs]: ****/batch_input/testdata2.jsonl: 
2023-08-13T10:40:58.381:[sagemaker logs]: ****/batch_input/testdata2.jsonl: Message:
2023-08-13T10:40:58.385:[sagemaker logs]: ****/batch_input/testdata2.jsonl: Extra data: line 2 column 1 (char 12)
...
2023-08-13T10:40:59.069:[sagemaker logs]: ****/batch_input/testdata1.jsonl:     return json.loads(input_data)

split_type=‘Line’, strategy=‘SingleRecord’, max_payload=1 (MB) を渡したとき

split_type=‘Line’ を渡すと行ごとのデータが input_fn に渡されるようになった。ただ出力が改行されていない。

2023-08-13T10:42:50,220 [INFO ] W-9001-model_1.0-stdout MODEL_LOG - input_fn
2023-08-13T10:42:50,220 [INFO ] W-9001-model_1.0-stdout MODEL_LOG - {"value":8}
2023-08-13T10:42:50,220 [INFO ] W-9001-model_1.0-stdout MODEL_LOG -  application/json <ts.context.Context object at 0x7f9caa819d90>
2023-08-13T10:42:50,220 [INFO ] W-9001-model_1.0-stdout MODEL_LOG - predict_fn
2023-08-13T10:42:50,220 [INFO ] W-9001-model_1.0-stdout MODEL_LOG - {'value': 8} <inference.Model object at 0x7f9caa827340> <ts.context.Context object at 0x7f9caa819d90>
2023-08-13T10:42:50,220 [INFO ] W-9001-model_1.0-stdout MODEL_LOG - output_fn
2023-08-13T10:42:50,220 [INFO ] W-9001-model_1.0-stdout MODEL_LOG - 🍣 application/json <ts.context.Context object at 0x7f9caa819d90>

$ cat testdata1.jsonl.out
{"output": "I want to 🍵"}{"output": "I want to 🍵"}{"output": "I want to 🍣"}{"output": "I want to 🍵"}{"output": "I want to 🍵"}{"output": "I want to 🍣"}...

$ cat testdata2.jsonl.out
{"output": "I want to 🍵"}{"output": "I want to 🍵"}{"output": "I want to 🍣"}{"output": "I want to 🍵"}{"output": "I want to 🍵"}{"output": "I want to 🍣"}...

split_type=‘Line’, strategy=‘SingleRecord’, assemble_with=‘Line’ を渡したとき

assemble_with=‘Line を渡すと改行されるようになった。

$ cat testdata1.jsonl.out
{"output": "I want to 🍵"}
{"output": "I want to 🍵"}
{"output": "I want to 🍣"}
{"output": "I want to 🍵"}
{"output": "I want to 🍵"}
...

split_type=‘Line’, strategy=‘MultiRecord’, max_payload=1 を渡したとき

strategy を MultiRecord にすると input_fn には複数行データの配列が渡ると思いきや、split_type=None のときのような改行を含むデータが渡された。 ただし split_type=None のときとは異なり max_payload には当たらないようになっているため Too much data ではなく json.loads() エラーになった。

2023-08-13T14:51:12.281:[sagemaker logs]: ****/batch_input/testdata2.jsonl: Bad HTTP status received from algorithm: 500
2023-08-13T14:51:12.281:[sagemaker logs]: ****/batch_input/testdata2.jsonl:
2023-08-13T14:51:12.281:[sagemaker logs]: ****/batch_input/testdata2.jsonl: Message:
2023-08-13T14:51:12.282:[sagemaker logs]: ****/batch_input/testdata2.jsonl: Extra data: line 2 column 1 (char 12)
...
2023-08-13T14:51:18.004:[sagemaker logs]: ****/batch_input/testdata1.jsonl:     return json.loads(input_data)

input_fn で split して output_fn で join するようにしたところ SingleRecord のときと同じ出力が得られた。

def input_fn(input_data, content_type, context=None):
  print("input_fn")
  print(input_data, content_type, context)
  lines = filter(lambda x: x != '', input_data.split('\n'))
  return map(lambda x: json.loads(x), lines)

def predict_fn(input_data, model, context=None):
  print("predict_fn")
  print(input_data, model, context)
  return map(lambda x: model.predict(x['value']), input_data)

def output_fn(prediction, accept, context=None):
  print("output_fn")
  print(prediction, accept, context)
  return '\n'.join(map(lambda x: json.dumps({"output": f"I want to {x}"}, ensure_ascii=False), prediction))

join_source=‘Input’ を渡したとき

output_fn の値が SageMakerOutput フィールドとして入力データに追加されたものが返される。 content_type と accept が同じである必要がある。

{"SageMakerOutput":{"output":"I want to 🍵"},"value":1}
{"SageMakerOutput":{"output":"I want to 🍵"},"value":3}
{"SageMakerOutput":{"output":"I want to 🍣"},"value":4}
{"SageMakerOutput":{"output":"I want to 🍵"},"value":15}
{"SageMakerOutput":{"output":"I want to 🍵"},"value":13}

参考

Amazon SageMaker Batch Transform を試してみた。