学習させたモデルをSageMakerのホスティングサービスにデプロイする。
SageMakerでPyTorchのモデルを学習させる - sambaiz-net
推論時に呼ばれる関数
推論時には次の関数が呼ばれる。
- model_fn(model_dir): モデルをロードする
- input_fn(request_body, request_content_type): リクエストボディのデシリアライズ
- predict_fn(input_data, model): モデルで推論する
- output_fn(prediction, content_type): Content-Typeに応じたシリアライズ
input_fn()
と output_fn()
はJSON, CSV, NPYに対応した実装が、predict_fn()
はモデルを呼び出す実装がデフォルトとして用意されていて、
model_fn()
も後述するElastic Inferenceを使う場合model.pt
というファイルをロードするデフォルト実装が使われる。
ただしその場合モデルがtorch.jit.save()
でTorchScriptとして保存してある必要がある。
今回は predict_fn()
のみ実装した。
$ cat inference.py
import torch
def predict_fn(input_data, model):
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
input_data = input_data.to(device)
model.eval()
with torch.jit.optimized_execution(True, {"target_device": "eia:0"}):
output = model(input_data)
return output.max(1)[1]
デプロイ
Training jobがモデルを保存したS3のパスを取ってきてPyTorchModelを作る。
from sagemaker.pytorch.model import PyTorchModel
training_job_name = 'pytorch-training-2020-07-25-08-41-45-674'
training_job = sess.client('sagemaker').describe_training_job(TrainingJobName=training_job_name)
model = PyTorchModel(model_data=training_job['ModelArtifacts']['S3ModelArtifacts'],
role=sagemaker.get_execution_role(),
framework_version='1.3.1',
py_version='py3',
source_dir='/root/sagemaker-pytorch-mnist',
entry_point='inference.py')
deploy()
するとModelsとEndpoint configurations、Endpointsが作成される。結構時間がかかる。
accelerator_type
に指定しているElastic Inferenceというのは、適した量のGPUリソースを各CPUインスタンスにアタッチしてくれるもので、GPUインスタンスのリソースを十分使えていない場合コストを下げることができる。ただEI用のイメージのPyTorchのバージョンが古くバージョンを下げる必要があった。
model.deploy(instance_type='ml.c4.xlarge', initial_instance_count=1, endpoint_name='pytorch-mnist-test', accelerator_type='ml.eia2.medium')
呼び出してみる。
from torchvision import datasets, transforms
dataset = datasets.MNIST('mnist', train=False, transform=transforms.ToTensor(), download=False)
result = predictor.predict(dataset[0][0].view(-1, 1, 28, 28))
print(result) # [7]
print(dataset[0][1]) # 7
shapeの不一致などでエラーを起こすとタイムアウトし、ログにも EI Error Description: Internal error
と表示されるだけでトラブルシューティングが難しかった。
例外をcatchしてログを出したり入力のバリデーションを行うと良いと思う。
後片付け
デプロイしたリソースを削除する。
predictor.delete_endpoint()
predictor.delete_model()