Implement Athena's data source connectors and user defined functions (UDF)

awsjava

Athena has a feature called Federated Query that can access data sources other than S3 by using Lambda functions as connectors. The official repository provides connectors for various data sources such as BigQuery and Snowflake, but you can also implement your own. In this article, I implement a minimal connector, referring to the Example Connector, and run it. The full code has been pushed to GitHub.

Generate data with TPC-DS Connector in Athena’s Federated Query - sambaiz-net

CompositeHandler

CompositeHandler is a handler that groups a MetadataHandler, which returns the schema of a table, and a RecordHandler, which returns its data. It can also include a UserDefinedFunctionHandler that runs UDFs if needed.

package net.sambaiz.athena_connector_udf_example;

import com.amazonaws.athena.connector.lambda.handlers.CompositeHandler;

/**
 * Lambda entry point for the connector.
 *
 * <p>Bundles the three handlers into one {@link CompositeHandler}:
 * <ul>
 *   <li>{@code ExampleMetadataHandler} for schemas, tables, partitions and splits,</li>
 *   <li>{@code ExampleRecordHandler} for the table data, and</li>
 *   <li>{@code ExampleUserDefinedFuncHandler} for user-defined functions.</li>
 * </ul>
 */
public class App extends CompositeHandler {
    /** Wires up the metadata, record and UDF handlers, in that order. */
    public App() {
        super(
                new ExampleMetadataHandler(),
                new ExampleRecordHandler(),
                new ExampleUserDefinedFuncHandler());
    }
}

MetadataHandler

MetadataHandler is a handler that returns data source metadata such as table schemas and partitions.

describe `lambda:athena-connector-udf-example`.sample_db.sample_table
/*
foo struct<bar:int>
# Partition Information
# col_name  data_type   comment
year    int
*/

It consists of doListSchemaNames(), which returns database names, doListTables(), which returns table names,

/**
 * Lists the databases (schemas) exposed by this connector.
 *
 * <p>The connector exposes exactly one, fixed database: {@code sample_db}.
 *
 * @param allocator block allocator supplied by the SDK (unused here)
 * @param request the list-schemas request; its catalog name is echoed back
 * @return a response containing the single schema name
 */
@Override
public ListSchemasResponse doListSchemaNames(BlockAllocator allocator, ListSchemasRequest request) {
    logger.info("MetadataHandler.doListSchemaNames() with requestType: " + request.getRequestType());

    final String catalogName = request.getCatalogName();
    final Set<String> schemaNames = new HashSet<>();
    schemaNames.add("sample_db");
    return new ListSchemasResponse(catalogName, schemaNames);
}

/**
 * Lists the tables of the requested schema.
 *
 * <p>Whatever schema is requested, the response contains the single table
 * {@code sample_table}, qualified with that schema name.
 *
 * @param allocator block allocator supplied by the SDK (unused here)
 * @param request the list-tables request; its catalog and schema names are echoed back
 * @return a response with one table and no next-page token
 */
@Override
public ListTablesResponse doListTables(BlockAllocator allocator, ListTablesRequest request) {
    logger.info("MetadataHandler.doListTables() with requestType: " + request.getRequestType());

    TableName sampleTable = new TableName(request.getSchemaName(), "sample_table");
    List<TableName> tableNames = new ArrayList<>();
    tableNames.add(sampleTable);

    // No pagination: the whole listing fits in one response, so no next token is given
    // (presumably null marks the last page — confirm against the SDK docs).
    String noMorePages = null;
    return new ListTablesResponse(request.getCatalogName(), tableNames, noMorePages);
}

doGetTable(), which returns the table schema,

/**
 * Returns the schema of {@code sample_table}.
 *
 * <p>The table has:
 * <ul>
 *   <li>an {@code int} column {@code year}, declared as the partition column, and</li>
 *   <li>a struct column {@code foo} with a single {@code int} child {@code bar}.</li>
 * </ul>
 * The requested table name is echoed back as-is; it is not validated.
 *
 * @param allocator block allocator supplied by the SDK (unused here)
 * @param request the get-table request
 * @return the table's schema together with its partition column names
 */
@Override
public GetTableResponse doGetTable(BlockAllocator allocator, GetTableRequest request) {
    logger.info("MetadataHandler.doGetTable() with tableName: " + request.getTableName());

    // Column order matters: "year" first, then the struct "foo" with its child "bar".
    SchemaBuilder schemaBuilder = SchemaBuilder.newBuilder();
    schemaBuilder.addIntField("year");
    schemaBuilder.addStructField("foo");
    schemaBuilder.addChildField("foo", "bar", Types.MinorType.INT.getType());

    Set<String> partitionColumns = new HashSet<>();
    partitionColumns.add("year");

    return new GetTableResponse(
            request.getCatalogName(),
            request.getTableName(),
            schemaBuilder.build(),
            partitionColumns);
}

getPartitions(), which returns partitions, and doGetSplits(), which returns Splits, the units of work that can be processed in parallel.

/**
 * Writes one partition row for every year from 2000 through 2022 (inclusive).
 *
 * <p>The boolean returned by {@code block.setValue} is used as the row's "match" flag. The
 * row counts as written (1) only when it is {@code true}; presumably {@code false} means the
 * value was rejected by the query's constraints (e.g. {@code year = 2022}) — confirm against
 * the SDK.
 *
 * @param blockWriter sink for partition rows
 * @param request the table-layout request; its schema is logged
 * @param queryStatusChecker supplied by the SDK (unused here)
 */
@Override
public void getPartitions(BlockWriter blockWriter, GetTableLayoutRequest request, QueryStatusChecker queryStatusChecker) throws Exception {
    logger.info("MetadataHandler.getPartitions() with schema: " + request.getSchema().toJson());

    final int firstYear = 2000;
    final int lastYear = 2022;
    for (int y = firstYear; y <= lastYear; y++) {
        final int partitionYear = y;
        blockWriter.writeRows((Block block, int row) ->
                block.setValue("year", row, partitionYear) ? 1 : 0);
    }
}

/**
 * Creates one {@code Split} per partition row, carrying that partition's {@code year}.
 *
 * <p>Splits are the units of work Athena may process in parallel. The record handler's
 * {@code readWithConstraint} reads the {@code year} split property back to decide which
 * rows to produce.
 *
 * @param allocator block allocator supplied by the SDK (unused here)
 * @param request the request holding the partitions selected for this query
 * @return all splits for the request in a single response (no continuation token)
 */
@Override
public GetSplitsResponse doGetSplits(BlockAllocator allocator, GetSplitsRequest request) {
    Block partitions = request.getPartitions();

    String partitionFieldNames = partitions.getFields().stream()
            .map(field -> field.getName())
            .collect(Collectors.joining(","));
    logger.info("MetadataHandler.doGetSplits() with partitionFieldNames: " + partitionFieldNames);

    Set<Split> splits = new HashSet<>();
    FieldReader yearReader = partitions.getFieldReader("year");
    for (int row = 0; row < partitions.getRowCount(); row++) {
        yearReader.setPosition(row);

        // One split per partition; each gets its own spill location and encryption key.
        Split split = Split.newBuilder(makeSpillLocation(request), makeEncryptionKey())
                .add("year", String.valueOf(yearReader.readInteger()))
                .build();
        splits.add(split);
    }

    return new GetSplitsResponse(request.getCatalogName(), splits);
}

RecordHandler

RecordHandler is a handler that returns the data of the tables described by the MetadataHandler.

select * from "lambda:athena-connector-udf-example".sample_db.sample_table where year = 2022
/*
#	year	foo
1	2022	{bar=0}
2	2022	{bar=1}
3	2022	{bar=2}
4	2022	{bar=3}
5	2022	{bar=4}
*/

Rows are written with BlockSpiller.writeRows().

/**
 * Produces the rows of one split and writes them through the spiller.
 *
 * <p>The split's {@code year} property selects the year. Five rows are emitted for it,
 * with {@code foo.bar} = 0, 1, 2, 3, 4. Each row is handed to the {@code GeneratedRowWriter}
 * as a {@code String[]} of {@code {year, foo.bar}}, which the extractor and field-writer
 * factory registered below parse back into typed values.
 *
 * @param spiller sink for the produced rows
 * @param recordsRequest the read request (split properties and query constraints)
 * @param queryStatusChecker supplied by the SDK (unused here)
 */
@Override
protected void readWithConstraint(BlockSpiller spiller, ReadRecordsRequest recordsRequest, QueryStatusChecker queryStatusChecker) throws IOException {
    String splitProperties = Joiner.on(",")
            .withKeyValueSeparator("=")
            .join(recordsRequest.getSplit().getProperties());
    logger.info("RecordHandler.readWithConstraint() with splitProperties: " + splitProperties);

    GeneratedRowWriter.RowWriterBuilder writerBuilder =
            GeneratedRowWriter.newBuilder(recordsRequest.getConstraints());

    // "year" is a scalar INT column: a plain extractor reads element 0 of the row array.
    writerBuilder.withExtractor("year", (IntExtractor) (Object context, NullableIntHolder value) -> {
        value.isSet = 1;
        String rawYear = ((String[]) context)[0];
        value.value = Integer.parseInt(rawYear);
    });

    // "foo" is a struct column, so it needs a custom field writer. The struct's value is
    // passed to BlockUtils.setComplexValue as a map keyed by child name (element 1 of the
    // row array becomes "bar").
    writerBuilder.withFieldWriterFactory("foo",
            (FieldVector vector, Extractor extractor, ConstraintProjector constraint) ->
                    (Object context, int rowNum) -> {
                        Map<String, Object> fooValue = new HashMap<>();
                        String rawBar = ((String[]) context)[1];
                        fooValue.put("bar", Integer.parseInt(rawBar));
                        BlockUtils.setComplexValue(vector, rowNum, FieldResolver.DEFAULT, fooValue);
                        return true;
                    });

    GeneratedRowWriter rowWriter = writerBuilder.build();

    int splitYear = recordsRequest.getSplit().getPropertyAsInt("year");
    String yearText = String.valueOf(splitYear);
    for (int bar = 0; bar < 5; bar++) {
        final String[] row = {yearText, String.valueOf(bar)};
        spiller.writeRows((Block block, int rowNum) -> rowWriter.writeRow(block, rowNum, row) ? 1 : 0);
    }
}

UserDefinedFunctionHandler

UserDefinedFunctionHandler is a handler that processes user-defined functions (UDFs).

USING 
    EXTERNAL FUNCTION plus_one(value INT) RETURNS INT LAMBDA 'athena-connector-udf-example'
SELECT plus_one(100)
/* 
101 
*/

Public methods of the handler are treated as UDFs.

package net.sambaiz.athena_connector_udf_example;

import com.amazonaws.athena.connector.lambda.handlers.UserDefinedFunctionHandler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * UDF handler for this connector.
 *
 * <p>Its public methods are callable from Athena as UDFs, e.g.
 * {@code USING EXTERNAL FUNCTION plus_one(value INT) RETURNS INT LAMBDA '...'}.
 */
public class ExampleUserDefinedFuncHandler extends UserDefinedFunctionHandler {
    private static final Logger logger = LoggerFactory.getLogger(ExampleUserDefinedFuncHandler.class);

    /** Source type identifier passed to the base handler. */
    private static final String SOURCE_TYPE = "custom";

    public ExampleUserDefinedFuncHandler() {
        super(SOURCE_TYPE);
    }

    /**
     * Adds one to its argument.
     *
     * @param n the input value; may be {@code null} (e.g. for a NULL SQL argument)
     * @return {@code n + 1}, or {@code null} when {@code n} is {@code null}
     * @throws ArithmeticException if {@code n} is {@link Integer#MAX_VALUE}, because the
     *     result does not fit in an {@code INT} and would otherwise silently wrap around
     */
    public Integer plus_one(Integer n) {
        logger.info("UserDefinedFunctionHandler.plus_one() with {}", n);
        if (n == null) {
            // NULL in, NULL out; unboxing a null Integer would otherwise throw an NPE.
            return null;
        }
        return Math.addExact(n, 1);
    }
}

To optimize performance, UDFs are executed on batches of rows.