Athenaのデータソースコネクタとユーザー定義関数(UDF)を実装する

awsjavaetl

AthenaにはLambdaをコネクタとしてS3以外のデータソースにアクセスできるFederate Queryという機能があって、公式のリポジトリでBigQueryやSnowflakeなど様々なデータソースのコネクタが提供されているが自作することもできる。 今回はExample Connectorを参考にしながら最低限のコネクタを実装しその動作を確認する。全体のコードはGitHubにある。

AthenaのFederated QueryでTPC-DS Connectorを用いてデータを生成する - sambaiz-net

CompositeHandler

テーブルのスキーマなどを返すMetadataHandlerとデータを返すExampleRecordHandlerをまとめるためのHandlerで、 必要ならUDFを実行するUserDefinedFuncHandlerもこれに含めることができる。

package net.sambaiz.athena_connector_udf_example;

import com.amazonaws.athena.connector.lambda.handlers.CompositeHandler;

public class App extends CompositeHandler {
    public App() {
        super(new ExampleMetadataHandler(), new ExampleRecordHandler(), new ExampleUserDefinedFuncHandler());
    }
}

MetadataHandler

テーブルのスキーマやパーティションといったデータソースのメタデータを返すHandler。

describe `lambda:athena-connector-udf-example`.sample_db.sample_table
/*
foo struct<bar:int>
# Partition Information
# col_name  data_type   comment
year    int
*/

データベースやテーブルの一覧を返す doListSchemaNames()doListTables()

@Override
public ListSchemasResponse doListSchemaNames(BlockAllocator allocator, ListSchemasRequest request) {
    logger.info("MetadataHandler.doListSchemaNames() with requestType: " + request.getRequestType());

    Set<String> schemas = new HashSet<>();
    schemas.add("sample_db");
    return new ListSchemasResponse(request.getCatalogName(), schemas);
}

@Override
public ListTablesResponse doListTables(BlockAllocator allocator, ListTablesRequest request) {
    logger.info("MetadataHandler.doListTables() with requestType: " + request.getRequestType());

    List<TableName> tables = new ArrayList<>();
    tables.add(new TableName(request.getSchemaName(), "sample_table"));
    String nextToken = null;
    return new ListTablesResponse(request.getCatalogName(), tables, nextToken);
}

テーブルのスキーマを返す doGetTable()

@Override
public GetTableResponse doGetTable(BlockAllocator allocator, GetTableRequest request) {
    logger.info("MetadataHandler.doGetTable() with tableName: " + request.getTableName());

    Set<String> partitionColNames = new HashSet<>();
    partitionColNames.add("year");

    SchemaBuilder tableSchemaBuilder = SchemaBuilder.newBuilder();
    tableSchemaBuilder
        .addIntField("year")
        .addStructField("foo")
        .addChildField("foo", "bar", Types.MinorType.INT.getType());

    return new GetTableResponse(request.getCatalogName(),
            request.getTableName(),
            tableSchemaBuilder.build(),
            partitionColNames);
}

パーティションを返す getPartitions() と並列実行の単位であるSplitを返す doGetSplits() からなる。

@Override
public void getPartitions(BlockWriter blockWriter, GetTableLayoutRequest request, QueryStatusChecker queryStatusChecker) throws Exception {
    logger.info("MetadataHandler.getPartitions() with schema: " + request.getSchema().toJson());
    for (int year = 2000; year <= 2022; year++) {
        final int yearVal = year;
        blockWriter.writeRows((Block block, int row) -> {
            boolean matched = true;
            matched &= block.setValue("year", row, yearVal);
            return matched ? 1 : 0;
        });
    }
}

@Override
public GetSplitsResponse doGetSplits(BlockAllocator allocator, GetSplitsRequest request)
{
    String partitionFieldNames = request.getPartitions().getFields().stream()
            .map((a) -> a.getName())
            .collect(Collectors.joining(","));
    logger.info("MetadataHandler.doGetSplits() with requestType: partitionFieldNames: " + partitionFieldNames);

    Set<Split> splits = new HashSet<>();
    Block partitions = request.getPartitions();
    FieldReader year = partitions.getFieldReader("year");
    for (int i = 0; i < partitions.getRowCount(); i++) {
        year.setPosition(i);

        // Splits are parallelizable units of work.
        // For each partition in the request, create 1 or more splits.
        Split split = Split.newBuilder(makeSpillLocation(request), makeEncryptionKey())
                .add("year", String.valueOf(year.readInteger()))
                .build();
        splits.add(split);
    }

    return new GetSplitsResponse(request.getCatalogName(), splits);
}

RecordHandler

MetadataHandlerによって返されたテーブルのデータを返すHandler。

select * from "lambda:athena-connector-udf-example".sample_db.sample_table where year = 2022
/*
#	year	foo
1	2022	{bar=0}
2	2022	{bar=1}
3	2022	{bar=2}
4	2022	{bar=3}
5	2022	{bar=4}

BlockSpiller.writeRows() でデータを書き込んだ行数のデータが返っている。

@Override
protected void readWithConstraint(BlockSpiller spiller, ReadRecordsRequest recordsRequest, QueryStatusChecker queryStatusChecker) throws IOException {
    String splitProperties = Joiner.on(",").withKeyValueSeparator("=").join(recordsRequest.getSplit().getProperties());
    logger.info("RecordHandler.readWithConstraint() with splitProperties: " + splitProperties);

    GeneratedRowWriter.RowWriterBuilder builder = GeneratedRowWriter.newBuilder(recordsRequest.getConstraints());
    builder.withExtractor("year", (IntExtractor) (Object context, NullableIntHolder value) -> {
        value.isSet = 1;
        value.value = Integer.parseInt(((String[]) context)[0]);
    });
    builder.withFieldWriterFactory("foo",
        (FieldVector vector, Extractor extractor, ConstraintProjector constraint) ->
        (Object context, int rowNum) -> {
            Map<String, Object> eventMap = new HashMap<>();
            eventMap.put("bar", Integer.parseInt(((String[])context)[1]));
            BlockUtils.setComplexValue(vector, rowNum, FieldResolver.DEFAULT, eventMap);
            return true;
        });
    GeneratedRowWriter rowWriter = builder.build();

    int splitYear = recordsRequest.getSplit().getPropertyAsInt("year");
    for (int i = 0; i < 5; i++) {
        String[] data = {
                String.valueOf(splitYear), /* year */
                String.valueOf(i) /* foo.bar */
        };
        spiller.writeRows((Block block, int rowNum) -> rowWriter.writeRow(block, rowNum, data) ? 1 : 0);
    }
}

UserDefinedFunctionHandler

ユーザー定義関数(UDF)を処理するHandler。

USING 
    EXTERNAL FUNCTION plus_one(value INT) RETURNS INT LAMBDA 'athena-connector-udf-example'
SELECT plus_one(100)
/* 
101 
*/

任意の関数名でpublicなメソッドを実装するとUDFとして扱われる。

package net.sambaiz.athena_connector_udf_example;

import com.amazonaws.athena.connector.lambda.handlers.UserDefinedFunctionHandler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ExampleUserDefinedFuncHandler extends UserDefinedFunctionHandler {
    private static final Logger logger = LoggerFactory.getLogger(ExampleUserDefinedFuncHandler.class);

    private static final String SOURCE_TYPE = "custom";

    public ExampleUserDefinedFuncHandler()
    {
        super(SOURCE_TYPE);
    }

    public Integer plus_one(Integer n) {
        logger.info("UserDefinedFunctionHandler.plus_one() with " + n);
        return n + 1;
    }
}

パフォーマンス最適化のためUDFは複数行でバッチ実行される。

(追記: 2023-02-21) Re:Invent 2022 で発表された Lambda SnapStart を有効にすることで起動にかかる時間を削減しパフォーマンスを向上させることができる。