Skip to content

Commit

Permalink
Convert object store connectors to SourcePage
Browse files Browse the repository at this point in the history
  • Loading branch information
dain committed Nov 5, 2024
1 parent 51ad8aa commit 0bd2fb5
Show file tree
Hide file tree
Showing 78 changed files with 1,801 additions and 2,692 deletions.
11 changes: 8 additions & 3 deletions lib/trino-orc/src/main/java/io/trino/orc/OrcReader.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
import io.trino.orc.metadata.PostScript.HiveWriterVersion;
import io.trino.orc.stream.OrcChunkLoader;
import io.trino.orc.stream.OrcInputStream;
import io.trino.spi.Page;
import io.trino.spi.connector.SourcePage;
import io.trino.spi.type.Type;
import org.joda.time.DateTimeZone;

Expand Down Expand Up @@ -252,6 +252,7 @@ public CompressionKind getCompressionKind()
public OrcRecordReader createRecordReader(
List<OrcColumn> readColumns,
List<Type> readTypes,
boolean appendRowNumberColumn,
OrcPredicate predicate,
DateTimeZone legacyFileTimeZone,
AggregatedMemoryContext memoryUsage,
Expand All @@ -263,6 +264,7 @@ public OrcRecordReader createRecordReader(
readColumns,
readTypes,
Collections.nCopies(readColumns.size(), fullyProjectedLayout()),
appendRowNumberColumn,
predicate,
0,
orcDataSource.getEstimatedSize(),
Expand All @@ -277,6 +279,7 @@ public OrcRecordReader createRecordReader(
List<OrcColumn> readColumns,
List<Type> readTypes,
List<ProjectedLayout> readLayouts,
boolean appendRowNumberColumn,
OrcPredicate predicate,
long offset,
long length,
Expand All @@ -291,6 +294,7 @@ public OrcRecordReader createRecordReader(
requireNonNull(readColumns, "readColumns is null"),
requireNonNull(readTypes, "readTypes is null"),
requireNonNull(readLayouts, "readLayouts is null"),
appendRowNumberColumn,
requireNonNull(predicate, "predicate is null"),
footer.getNumberOfRows(),
footer.getStripes(),
Expand Down Expand Up @@ -416,6 +420,7 @@ static void validateFile(
try (OrcRecordReader orcRecordReader = orcReader.createRecordReader(
orcReader.getRootColumn().getNestedColumns(),
readTypes,
false,
OrcPredicate.TRUE,
UTC,
newSimpleAggregatedMemoryContext(),
Expand All @@ -424,9 +429,9 @@ static void validateFile(
throwIfUnchecked(exception);
return new RuntimeException(exception);
})) {
for (Page page = orcRecordReader.nextPage(); page != null; page = orcRecordReader.nextPage()) {
for (SourcePage page = orcRecordReader.nextPage(); page != null; page = orcRecordReader.nextPage()) {
// fully load the page
page.getLoadedPage();
page.getPage();
}
}
}
Expand Down
180 changes: 165 additions & 15 deletions lib/trino-orc/src/main/java/io/trino/orc/OrcRecordReader.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Maps;
import com.google.common.io.Closer;
import com.google.errorprone.annotations.CheckReturnValue;
import com.google.errorprone.annotations.FormatMethod;
import io.airlift.slice.Slice;
import io.airlift.units.DataSize;
Expand All @@ -40,7 +41,10 @@
import io.trino.orc.stream.InputStreamSources;
import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.spi.block.LongArrayBlock;
import io.trino.spi.connector.SourcePage;
import io.trino.spi.type.Type;
import jakarta.annotation.Nullable;
import org.joda.time.DateTimeZone;

import java.io.Closeable;
Expand All @@ -54,6 +58,7 @@
import java.util.Optional;
import java.util.OptionalInt;
import java.util.function.Function;
import java.util.function.ObjLongConsumer;
import java.util.function.Predicate;
import java.util.stream.Collectors;

Expand All @@ -70,14 +75,17 @@
import static java.lang.Math.min;
import static java.lang.Math.toIntExact;
import static java.util.Comparator.comparingLong;
import static java.util.Objects.checkIndex;
import static java.util.Objects.requireNonNull;

public class OrcRecordReader
implements Closeable
{
private static final int INSTANCE_SIZE = instanceSize(OrcRecordReader.class);

private final List<OrcColumn> columns;
private final OrcDataSource orcDataSource;
private final boolean appendRowNumberColumn;

private final ColumnReader[] columnReaders;
private final long[] currentBytesPerCell;
Expand Down Expand Up @@ -129,6 +137,7 @@ public OrcRecordReader(
List<OrcColumn> readColumns,
List<Type> readTypes,
List<OrcReader.ProjectedLayout> readLayouts,
boolean appendRowNumberColumn,
OrcPredicate predicate,
long numberOfRows,
List<StripeInformation> fileStripes,
Expand All @@ -152,7 +161,7 @@ public OrcRecordReader(
FieldMapperFactory fieldMapperFactory)
throws OrcCorruptionException
{
requireNonNull(readColumns, "readColumns is null");
this.columns = requireNonNull(readColumns, "readColumns is null");
checkArgument(readColumns.stream().distinct().count() == readColumns.size(), "readColumns contains duplicate entries");
requireNonNull(readTypes, "readTypes is null");
checkArgument(readColumns.size() == readTypes.size(), "readColumns and readTypes must have the same size");
Expand All @@ -168,6 +177,7 @@ public OrcRecordReader(
requireNonNull(userMetadata, "userMetadata is null");
requireNonNull(memoryUsage, "memoryUsage is null");
requireNonNull(exceptionTransform, "exceptionTransform is null");
this.appendRowNumberColumn = appendRowNumberColumn;

this.writeValidation = requireNonNull(writeValidation, "writeValidation is null");
this.writeChecksumBuilder = writeValidation.map(validation -> createWriteChecksumBuilder(orcTypes, readTypes));
Expand Down Expand Up @@ -304,6 +314,11 @@ static OrcDataSource wrapWithCacheIfTinyStripes(OrcDataSource dataSource, List<S
return new CachingOrcDataSource(dataSource, createTinyStripesRangeFinder(stripes, maxMergeDistance, tinyStripeThreshold));
}

public List<OrcColumn> getColumns()
{
return columns;
}

/**
* Return the row position relative to the start of the file.
*/
Expand Down Expand Up @@ -406,7 +421,7 @@ public void close()
}
}

public Page nextPage()
public SourcePage nextPage()
throws IOException
{
// update position for current row group (advancing resets them)
Expand Down Expand Up @@ -447,21 +462,156 @@ public Page nextPage()
// create a lazy page
blockFactory.nextPage();
Arrays.fill(currentBytesPerCell, 0);
Block[] blocks = new Block[columnReaders.length];
for (int i = 0; i < columnReaders.length; i++) {
int columnIndex = i;
blocks[columnIndex] = blockFactory.createBlock(
currentBatchSize,
columnReaders[columnIndex]::readBlock,
false);
listenForLoads(blocks[columnIndex], block -> blockLoaded(columnIndex, block));
}

Page page = new Page(currentBatchSize, blocks);
SourcePage page = new OrcSourcePage(currentBatchSize);
validateWritePageChecksum(page);
return page;
}

private class OrcSourcePage
implements SourcePage
{
private final Block[] blocks = new Block[columnReaders.length + (appendRowNumberColumn ? 1 : 0)];
private final int rowNumberColumnIndex = appendRowNumberColumn ? columnReaders.length : -1;
private SelectedPositions selectedPositions;

public OrcSourcePage(int positionCount)
{
selectedPositions = new SelectedPositions(positionCount, null);
}

@Override
public int getPositionCount()
{
return selectedPositions.positionCount();
}

@Override
public long getSizeInBytes()
{
long sizeInBytes = 0;
for (Block block : blocks) {
if (block != null) {
sizeInBytes += block.getSizeInBytes();
}
}
return sizeInBytes;
}

@Override
public long getRetainedSizeInBytes()
{
long retainedSizeInBytes = 0;
for (Block block : blocks) {
if (block != null) {
retainedSizeInBytes += block.getRetainedSizeInBytes();
}
}
return retainedSizeInBytes;
}

@Override
public void retainedBytesForEachPart(ObjLongConsumer<Object> consumer)
{
for (Block block : blocks) {
if (block != null) {
block.retainedBytesForEachPart(consumer);
}
}
}

@Override
public int getChannelCount()
{
return blocks.length;
}

@Override
public Block getBlock(int channel)
{
checkIndex(channel, blocks.length);

Block block = blocks[channel];
if (block == null) {
if (channel == rowNumberColumnIndex) {
block = selectedPositions.createRowNumberBlock(filePosition);
}
else {
// todo use selected positions to improve read performance
block = blockFactory.createBlock(
currentBatchSize,
columnReaders[channel]::readBlock,
false);
listenForLoads(block, nestedBlock -> blockLoaded(channel, nestedBlock));
block = selectedPositions.apply(block);
}
blocks[channel] = block;
}
return block;
}

@Override
public Page getPage()
{
// ensure all blocks are loaded
for (int i = 0; i < blocks.length; i++) {
getBlock(i);
}
return new Page(selectedPositions.positionCount(), blocks);
}

@Override
public void selectPositions(int[] positions, int offset, int size)
{
selectedPositions = selectedPositions.selectPositions(positions, offset, size);
for (int i = 0; i < blocks.length; i++) {
Block block = blocks[i];
if (block != null) {
block = selectedPositions.apply(block);
blocks[i] = block;
}
}
}
}

private record SelectedPositions(int positionCount, @Nullable int[] positions)
{
@CheckReturnValue
public Block apply(Block block)
{
if (positions == null) {
return block;
}
return block.getPositions(positions, 0, positionCount);
}

public Block createRowNumberBlock(long filePosition)
{
long[] rowNumbers = new long[positionCount];
for (int i = 0; i < positionCount; i++) {
int position = positions == null ? i : positions[i];
rowNumbers[i] = filePosition + position;
}
return new LongArrayBlock(positionCount, Optional.empty(), rowNumbers);
}

@CheckReturnValue
public SelectedPositions selectPositions(int[] positions, int offset, int size)
{
if (this.positions == null) {
for (int i = 0; i < size; i++) {
checkIndex(offset + i, positionCount);
}
return new SelectedPositions(size, Arrays.copyOfRange(positions, offset, offset + size));
}

int[] newPositions = new int[size];
for (int i = 0; i < size; i++) {
newPositions[i] = this.positions[positions[offset + i]];
}
return new SelectedPositions(size, newPositions);
}
}

private void blockLoaded(int columnIndex, Block block)
{
if (block.getPositionCount() <= 0) {
Expand Down Expand Up @@ -586,10 +736,10 @@ private void validateWriteStripe(int rowCount)
writeChecksumBuilder.ifPresent(builder -> builder.addStripe(rowCount));
}

private void validateWritePageChecksum(Page page)
private void validateWritePageChecksum(SourcePage sourcePage)
{
if (writeChecksumBuilder.isPresent()) {
page = page.getLoadedPage();
Page page = sourcePage.getPage();
writeChecksumBuilder.get().addPage(page);
rowGroupStatisticsValidation.get().addPage(page);
stripeStatisticsValidation.get().addPage(page);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import io.trino.plugin.tpch.DecimalTypeMapping;
import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.spi.connector.SourcePage;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.SqlDecimal;
import io.trino.spi.type.SqlTimestamp;
Expand Down Expand Up @@ -328,8 +329,8 @@ public Object readLineitem(LineitemBenchmarkData data)
{
List<Page> pages = new ArrayList<>();
try (OrcRecordReader recordReader = data.createRecordReader()) {
for (Page page = recordReader.nextPage(); page != null; page = recordReader.nextPage()) {
pages.add(page.getLoadedPage());
for (SourcePage page = recordReader.nextPage(); page != null; page = recordReader.nextPage()) {
pages.add(page.getPage());
}
}
return pages;
Expand Down Expand Up @@ -375,7 +376,7 @@ private Object readFirstColumn(OrcRecordReader recordReader)
throws IOException
{
List<Block> blocks = new ArrayList<>();
for (Page page = recordReader.nextPage(); page != null; page = recordReader.nextPage()) {
for (SourcePage page = recordReader.nextPage(); page != null; page = recordReader.nextPage()) {
blocks.add(page.getBlock(0).getLoadedBlock());
}
return blocks;
Expand Down Expand Up @@ -429,6 +430,7 @@ OrcRecordReader createRecordReader()
return orcReader.createRecordReader(
orcReader.getRootColumn().getNestedColumns(),
types,
false,
OrcPredicate.TRUE,
UTC, // arbitrary
newSimpleAggregatedMemoryContext(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
package io.trino.orc;

import com.google.common.collect.ImmutableList;
import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.spi.connector.SourcePage;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.SqlDecimal;
import org.joda.time.DateTimeZone;
Expand Down Expand Up @@ -71,8 +71,8 @@ public Object readDecimal(BenchmarkData data)
{
OrcRecordReader recordReader = data.createRecordReader();
List<Block> blocks = new ArrayList<>();
for (Page page = recordReader.nextPage(); page != null; page = recordReader.nextPage()) {
blocks.add(page.getBlock(0).getLoadedBlock());
for (SourcePage page = recordReader.nextPage(); page != null; page = recordReader.nextPage()) {
blocks.add(page.getBlock(0));
}
return blocks;
}
Expand Down Expand Up @@ -118,6 +118,7 @@ private OrcRecordReader createRecordReader()
return orcReader.createRecordReader(
orcReader.getRootColumn().getNestedColumns(),
ImmutableList.of(DECIMAL_TYPE),
false,
OrcPredicate.TRUE,
DateTimeZone.UTC, // arbitrary
newSimpleAggregatedMemoryContext(),
Expand Down
Loading

0 comments on commit 0bd2fb5

Please sign in to comment.