SageMakerで学習したPyTorchのモデルをElastic Inferenceを有効にしてデプロイする

pythonpytorchmachinelearningaws

学習させたモデルをSageMakerのホスティングサービスにデプロイする。

SageMakerでPyTorchのモデルを学習させる - sambaiz-net

推論時に呼ばれる関数

推論時には次の関数が呼ばれる

  • model_fn(model_dir): モデルをロードする
  • input_fn(request_body, request_content_type): リクエストボディのデシリアライズ
  • predict_fn(input_data, model): モデルで推論する
  • output_fn(prediction, content_type): Content-Typeに応じたシリアライズ

input_fn()output_fn() はJSON, CSV, NPYに対応した実装が、predict_fn() はモデルを呼び出す実装がデフォルトとして用意されていて、 model_fn() も後述するElastic Inferenceを使う場合model.ptというファイルをロードするデフォルト実装が使われる。 ただしその場合モデルがtorch.jit.save()TorchScriptとして保存してある必要がある

今回は predict_fn() のみ実装した。

$ cat inference.py
import torch

def predict_fn(input_data, model):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    input_data = input_data.to(device)
    model.eval()
    with torch.jit.optimized_execution(True, {"target_device": "eia:0"}):
        output = model(input_data)
        return output.max(1)[1]

デプロイ

Training jobがモデルを保存したS3のパスを取ってきてPyTorchModelを作る。

from sagemaker.pytorch.model import PyTorchModel

training_job_name = 'pytorch-training-2020-07-25-08-41-45-674'
training_job = sess.client('sagemaker').describe_training_job(TrainingJobName=training_job_name)
model = PyTorchModel(model_data=training_job['ModelArtifacts']['S3ModelArtifacts'], 
                     role=sagemaker.get_execution_role(),
                     framework_version='1.3.1',
                     py_version='py3',
                     source_dir='/root/sagemaker-pytorch-mnist',
                     entry_point='inference.py')

deploy()するとModelsとEndpoint configurations、Endpointsが作成される。結構時間がかかる。

Endpoints

accelerator_typeに指定しているElastic Inferenceというのは、適した量のGPUリソースを各CPUインスタンスにアタッチしてくれるもので、GPUインスタンスのリソースを十分使えていない場合コストを下げることができる。ただEI用のイメージのPyTorchのバージョンが古くバージョンを下げる必要があった。

model.deploy(instance_type='ml.c4.xlarge', initial_instance_count=1, endpoint_name='pytorch-mnist-test', accelerator_type='ml.eia2.medium')

呼び出してみる。

from torchvision import datasets, transforms

dataset = datasets.MNIST('mnist', train=False, transform=transforms.ToTensor(), download=False)
result = predictor.predict(dataset[0][0].view(-1, 1, 28, 28))
print(result) # [7]
print(dataset[0][1]) # 7

shapeの不一致などでエラーを起こすとタイムアウトし、ログにも EI Error Description: Internal error と表示されるだけでトラブルシューティングが難しかった。 例外をcatchしてログを出したり入力のバリデーションを行うと良いと思う。

後片付け

デプロイしたリソースを削除する。

predictor.delete_endpoint()
predictor.delete_model()