GoでAmazon Forecastに時系列データをimportしPredictorを作成して予測結果をS3にexportする

aws, golang, machinelearning

以前コンソール上で実行したAmazon Forecastでの時系列データの学習、予測をGoで行う。全体のコードはGitHubにある。

Amazon Forecastで時系列データの予測を行う - sambaiz-net

Datasetの作成

以前と同じく電力消費量のデータセットを、予測対象の時系列データ(DatasetTypeTargetTimeSeries)として登録する。データの頻度は1時間でドメインはCustom。

// CreateDataset registers a target-time-series dataset called name with an
// hourly ("H") data frequency in the CUSTOM domain and returns its ARN.
//
// The schema has three attributes: timestamp (timestamp), target_value
// (float) and item_id (string). The actual CreateDataset call is handed to
// skipIfAlreadyExists under the "dataset" kind, which decides whether it runs.
func (f Forecast) CreateDataset(ctx context.Context, name string) (*string, error) {
	attrs := []types.SchemaAttribute{
		{AttributeName: aws.String("timestamp"), AttributeType: types.AttributeTypeTimestamp},
		{AttributeName: aws.String("target_value"), AttributeType: types.AttributeTypeFloat},
		{AttributeName: aws.String("item_id"), AttributeType: types.AttributeTypeString},
	}
	create := func() (*string, error) {
		input := &forecast.CreateDatasetInput{
			DatasetName:   aws.String(name),
			DatasetType:   types.DatasetTypeTargetTimeSeries,
			DataFrequency: aws.String("H"),
			Domain:        types.DomainCustom,
			Schema:        &types.Schema{Attributes: attrs},
		}
		out, err := f.svc.CreateDataset(ctx, input)
		if err != nil {
			return nil, err
		}
		return out.DatasetArn, nil
	}
	return f.skipIfAlreadyExists("dataset", name, create)
}

DatasetImport/ForecastExport Job、Predictor、Forecastはリソースを作成してから使えるようになるまで時間がかかるので、DescribeしてStatusがActiveになるのを待っている。

// waitForActive polls h until the resource called name becomes ACTIVE.
//
// h describes the resource and returns its status and, when the API provides
// one, the estimated remaining minutes (nil otherwise). The status is checked
// immediately and then once per minute.
//
// It returns nil once the status is ACTIVE. It returns an error when:
//   - h fails (the error is returned unchanged),
//   - h reports no status,
//   - creation ends in CREATE_FAILED or CREATE_STOPPED,
//   - the status is not a CREATE_* status at all,
//   - ctx is cancelled (ctx.Err()), so callers do not mistake a cancelled
//     wait for success.
func (f Forecast) waitForActive(ctx context.Context, name string, h func() (*string, *int64, error)) error {
	ticker := time.NewTicker(time.Minute)
	defer ticker.Stop()
	for {
		status, remainingMin, err := h()
		if err != nil {
			return err
		}
		if status == nil {
			return fmt.Errorf("%s has no status", name)
		}
		switch s := *status; {
		case s == "ACTIVE":
			return nil
		case s == "CREATE_FAILED", s == "CREATE_STOPPED":
			// Terminal failures. They carry the CREATE prefix, so the prefix
			// check below would not catch them and the loop would never end.
			return fmt.Errorf("creating %s failed: status is %s", name, s)
		case !strings.HasPrefix(s, "CREATE"):
			return fmt.Errorf("%s is not creating but %s", name, s)
		}
		if remainingMin != nil {
			log.Printf("%s's status is %s. remaining %d mins", name, *status, *remainingMin)
		} else {
			log.Printf("%s's status is %s", name, *status)
		}
		select {
		case <-ticker.C:
		case <-ctx.Done():
			return ctx.Err()
		}
	}
}

// CreateDatasetImportJob starts an import job called name that loads the S3
// data at src into the dataset datasetArn (with TimeZone set to
// America/Los_Angeles). It then waits via waitForActive until the job is
// ACTIVE and returns the job ARN.
//
// The create call is handed to skipIfAlreadyExists under the kind
// "dataset-import-job/<datasetName>".
func (f Forecast) CreateDatasetImportJob(ctx context.Context, name, datasetName, datasetArn string, src *types.S3Config) (*string, error) {
	kind := fmt.Sprintf("dataset-import-job/%s", datasetName)
	jobArn, err := f.skipIfAlreadyExists(kind, name, func() (*string, error) {
		input := &forecast.CreateDatasetImportJobInput{
			DatasetImportJobName: aws.String(name),
			DatasetArn:           &datasetArn,
			TimeZone:             aws.String("America/Los_Angeles"),
			DataSource:           &types.DataSource{S3Config: src},
		}
		out, createErr := f.svc.CreateDatasetImportJob(ctx, input)
		if createErr != nil {
			return nil, createErr
		}
		return out.DatasetImportJobArn, nil
	})
	if err != nil {
		return nil, err
	}

	describe := func() (*string, *int64, error) {
		out, descErr := f.svc.DescribeDatasetImportJob(ctx, &forecast.DescribeDatasetImportJobInput{
			DatasetImportJobArn: jobArn,
		})
		if descErr != nil {
			return nil, nil, descErr
		}
		return out.Status, out.EstimatedTimeRemainingInMinutes, nil
	}
	if err = f.waitForActive(ctx, "dataset-import-job", describe); err != nil {
		return nil, err
	}
	return jobArn, nil
}

DatasetGroupの作成

登録したDatasetでDatasetGroupを作成する。

// CreateDatasetGroup creates a CUSTOM-domain dataset group called name that
// contains the datasets listed in datasetArns, and returns its ARN.
//
// The create call is handed to skipIfAlreadyExists under the "dataset-group"
// kind.
func (f Forecast) CreateDatasetGroup(ctx context.Context, name string, datasetArns []string) (*string, error) {
	create := func() (*string, error) {
		out, err := f.svc.CreateDatasetGroup(ctx, &forecast.CreateDatasetGroupInput{
			DatasetGroupName: aws.String(name),
			Domain:           types.DomainCustom,
			DatasetArns:      datasetArns,
		})
		if err != nil {
			return nil, err
		}
		return out.DatasetGroupArn, nil
	}
	return f.skipIfAlreadyExists("dataset-group", name, create)
}

Predictorの学習

AutoMLで、USの祝日を有効にして学習を始める。予測期間は72時間。Activeになるまで180分ほどかかった。

// CreatePredictor trains a predictor called name on the dataset group
// datasetGroupArn. It waits via waitForActive until the predictor is ACTIVE,
// then returns its ARN.
//
// AutoML is enabled (PerformAutoML), the forecast horizon is 72 steps at an
// hourly ("H") forecast frequency, and the "holiday" supplementary feature is
// set to "US". The create call is handed to skipIfAlreadyExists under the
// "predictor" kind.
func (f Forecast) CreatePredictor(ctx context.Context, name, datasetGroupArn string) (*string, error) {
	arn, err := f.skipIfAlreadyExists("predictor", name, func() (*string, error) {
		predictor, err := f.svc.CreatePredictor(ctx, &forecast.CreatePredictorInput{
			PredictorName:   aws.String(name),
			ForecastHorizon: aws.Int32(72), // 3 days 2015-01-01T00:00:00 - 2015-01-04T00:00:00
			FeaturizationConfig: &types.FeaturizationConfig{
				ForecastFrequency: aws.String("H"),
			},
			PerformAutoML: aws.Bool(true),
			InputDataConfig: &types.InputDataConfig{
				DatasetGroupArn: aws.String(datasetGroupArn),
				SupplementaryFeatures: []types.SupplementaryFeature{
					{
						Name:  aws.String("holiday"),
						Value: aws.String("US"),
					},
				},
			},
		})
		if err != nil {
			return nil, err
		}
		return predictor.PredictorArn, nil
	})
	if err != nil {
		return nil, err
	}

	if err := f.waitForActive(ctx, "predictor", func() (*string, *int64, error) {
		desc, err := f.svc.DescribePredictor(ctx, &forecast.DescribePredictorInput{
			PredictorArn: arn,
		})
		if err != nil {
			return nil, nil, err
		}
		return desc.Status, desc.EstimatedTimeRemainingInMinutes, nil
	}); err != nil {
		return nil, err
	}
	// Return an explicit nil, as the other Create* helpers do, rather than
	// the outer err variable (which is already known to be nil here).
	return arn, nil
}

Forecastの生成

学習したPredictorで予測する。

// CreateForecast generates a forecast called name from the trained predictor
// predictorArn. It waits via waitForActive until the forecast is ACTIVE, then
// returns its ARN.
//
// The create call is handed to skipIfAlreadyExists under the "forecast" kind.
func (f Forecast) CreateForecast(ctx context.Context, name, predictorArn string) (*string, error) {
	arn, err := f.skipIfAlreadyExists("forecast", name, func() (*string, error) {
		// The result is named out rather than forecast so that it does not
		// shadow the imported forecast package.
		out, err := f.svc.CreateForecast(ctx, &forecast.CreateForecastInput{
			ForecastName: aws.String(name),
			PredictorArn: aws.String(predictorArn),
		})
		if err != nil {
			return nil, err
		}
		return out.ForecastArn, nil
	})
	if err != nil {
		return nil, err
	}

	if err := f.waitForActive(ctx, "forecast", func() (*string, *int64, error) {
		desc, err := f.svc.DescribeForecast(ctx, &forecast.DescribeForecastInput{
			ForecastArn: arn,
		})
		if err != nil {
			return nil, nil, err
		}
		return desc.Status, desc.EstimatedTimeRemainingInMinutes, nil
	}); err != nil {
		return nil, err
	}
	return arn, nil
}

予測結果をS3にExportする。

// CreateForecastExportJob exports the forecast forecastArn to the S3
// destination dest as a job called name. It waits via waitForActive until the
// job is ACTIVE, then returns the job ARN.
//
// The create call is handed to skipIfAlreadyExists under the kind
// "forecast-export-job/<forecastName>". Any error from waiting, including a
// failed export, is returned to the caller.
func (f Forecast) CreateForecastExportJob(ctx context.Context, name, forecastName, forecastArn string, dest *types.S3Config) (*string, error) {
	arn, err := f.skipIfAlreadyExists(fmt.Sprintf("forecast-export-job/%s", forecastName), name, func() (*string, error) {
		job, err := f.svc.CreateForecastExportJob(ctx, &forecast.CreateForecastExportJobInput{
			ForecastExportJobName: aws.String(name),
			ForecastArn:           aws.String(forecastArn),
			Destination: &types.DataDestination{
				S3Config: dest,
			},
		})
		if err != nil {
			return nil, err
		}
		return job.ForecastExportJobArn, nil
	})
	if err != nil {
		return nil, err
	}

	// Bind waitForActive's result in the if-initializer so that the condition
	// actually tests the wait error, not the outer (already nil) err.
	if err := f.waitForActive(ctx, "forecast-export-job", func() (*string, *int64, error) {
		desc, err := f.svc.DescribeForecastExportJob(ctx, &forecast.DescribeForecastExportJobInput{
			ForecastExportJobArn: arn,
		})
		if err != nil {
			return nil, nil, err
		}
		// DescribeForecastExportJob does not report a remaining-time estimate,
		// so pass nil instead of a fabricated 0.
		return desc.Status, nil, nil
	}); err != nil {
		return nil, err
	}
	return arn, nil
}

完了すると指定したパスに次のようなcsvが保存される。

S3のオブジェクト

item_id,date,p10,p50,p90
client_355,2015-01-01T01:00:00Z,12.408531189,14.5442657471,16.5890808105
client_355,2015-01-01T02:00:00Z,11.4778079987,13.4856367111,15.8068180084
client_355,2015-01-01T03:00:00Z,12.2634305954,14.1838378906,17.9517936707
client_355,2015-01-01T04:00:00Z,12.0927619934,14.6298713684,18.1633453369
client_355,2015-01-01T05:00:00Z,12.13489151,15.9390125275,19.2308177948
client_355,2015-01-01T06:00:00Z,12.7612571716,16.1918830872,19.3734836578
client_355,2015-01-01T07:00:00Z,13.0059919357,16.4345035553,19.6970005035
client_355,2015-01-01T08:00:00Z,13.1267080307,17.3538475037,20.8151321411
client_355,2015-01-01T09:00:00Z,18.6357440948,23.1925506592,28.9863166809
client_355,2015-01-01T10:00:00Z,23.7633323669,31.5546112061,41.632358551
client_355,2015-01-01T11:00:00Z,21.1502532959,30.0297222137,42.0502319336
client_355,2015-01-01T12:00:00Z,19.8031330109,28.0423240662,39.4894638062
client_355,2015-01-01T13:00:00Z,21.9451351166,29.55493927,44.4517326355
client_355,2015-01-01T14:00:00Z,21.8753509521,31.9131622314,48.4647483826
client_355,2015-01-01T15:00:00Z,20.5444030762,31.8455886841,49.697052002

Clean up

依存しているリソースが残っていると削除に失敗するので作成した逆順で削除していく。

if err := f.DeleteForecastExportJob(ctx, *forecastExportJobArn); err != nil {
  log.Fatal(err)
}
if err := f.DeleteForecast(ctx, *forecastArn); err != nil {
  log.Fatal(err)
}
if err := f.DeletePredictor(ctx, *predictorArn); err != nil {
  log.Fatal(err)
}
if err := f.DeleteDatasetGroup(ctx, *datasetGroupArn); err != nil {
  log.Fatal(err)
}
if err := f.DeleteDatasetImportJob(ctx, *datasetImportJobArn); err != nil {
  log.Fatal(err)
}
if err := f.DeleteDataset(ctx, *datasetArn); err != nil {
  log.Fatal(err)
}

リソースを削除する際も若干時間がかかるので同様にStatusを見ている。

// waitForDeleted polls h until the resource called name no longer exists.
//
// h describes the resource and returns its status. The status is checked
// immediately and then once per minute. A *types.ResourceNotFoundException
// from h means the deletion has finished, and nil is returned.
//
// It returns an error when:
//   - h fails with any other error (returned unchanged),
//   - h reports no status,
//   - the status is DELETE_FAILED,
//   - the status is not a DELETE_* status,
//   - ctx is cancelled (ctx.Err()).
func (f Forecast) waitForDeleted(ctx context.Context, name string, h func() (*string, error)) error {
	ticker := time.NewTicker(time.Minute)
	defer ticker.Stop()
	for {
		status, err := h()
		if err != nil {
			var notFound *types.ResourceNotFoundException
			if errors.As(err, &notFound) {
				return nil
			}
			return err
		}
		if status == nil {
			return fmt.Errorf("%s has no status", name)
		}
		switch s := *status; {
		case s == "DELETE_FAILED":
			// Terminal failure. It carries the DELETE prefix, so it must be
			// matched explicitly or the loop would never end.
			return fmt.Errorf("deleting %s failed", name)
		case !strings.HasPrefix(s, "DELETE"):
			return fmt.Errorf("%s is not deleting but %s", name, s)
		}
		log.Printf("%s's status is %s", name, *status)
		select {
		case <-ticker.C:
		case <-ctx.Done():
			return ctx.Err()
		}
	}
}