Skip to content

Commit

Permalink
Merge pull request #11002 from rouault/ogr2ogr_GetArrowStream_once
Browse files Browse the repository at this point in the history
ogr2ogr: optim: call GetArrowStream() only once on source layer when using Arrow interface
  • Loading branch information
rouault authored Oct 15, 2024
2 parents 9350c8a + 9d058a9 commit 2e32ace
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 63 deletions.
110 changes: 60 additions & 50 deletions apps/ogr2ogr_lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,7 @@ struct TargetLayerInfo
const char *m_pszGeomField = nullptr;
std::vector<int> m_anDateTimeFieldIdx{};
bool m_bSupportCurves = false;
OGRArrowArrayStream m_sArrowArrayStream{};
};

struct AssociatedLayers
Expand All @@ -507,7 +508,8 @@ class SetupTargetLayer
bool CanUseWriteArrowBatch(OGRLayer *poSrcLayer, OGRLayer *poDstLayer,
bool bJustCreatedLayer,
const GDALVectorTranslateOptions *psOptions,
bool &bError);
bool bPreserveFID, bool &bError,
OGRArrowArrayStream &streamSrc);

public:
GDALDataset *m_poSrcDS = nullptr;
Expand Down Expand Up @@ -3990,13 +3992,46 @@ static int GetArrowGeomFieldIndex(const struct ArrowSchema *psLayerSchema,
return -1;
}

/************************************************************************/
/* BuildGetArrowStreamOptions() */
/************************************************************************/

/** Build the option list handed to OGRLayer::GetArrowStream().
 *
 * The list always contains SILENCE_GET_SCHEMA_ERROR=YES and
 * GEOMETRY_ENCODING=WKB. INCLUDE_FID=NO is added when the FID is not
 * preserved. MAX_FEATURES_IN_BATCH is set when a feature limit and/or a
 * transaction group size is configured:
 *  - with a limit (nLimit >= 0): min(nLimit, group size), the group size
 *    defaulting to 65536 when nGroupTransactions is not strictly positive;
 *  - without a limit: nGroupTransactions, if strictly positive.
 *
 * @param psOptions    Translation options (nLimit and nGroupTransactions
 *                     are read). Must not be null.
 * @param bPreserveFID Whether the FID of source features is preserved.
 * @return the option list.
 */
static CPLStringList
BuildGetArrowStreamOptions(const GDALVectorTranslateOptions *psOptions,
                           bool bPreserveFID)
{
    CPLStringList aosStreamOptions;
    aosStreamOptions.SetNameValue("SILENCE_GET_SCHEMA_ERROR", "YES");
    aosStreamOptions.SetNameValue("GEOMETRY_ENCODING", "WKB");
    if (!bPreserveFID)
    {
        aosStreamOptions.SetNameValue("INCLUDE_FID", "NO");
    }

    const auto nGroupSize = psOptions->nGroupTransactions;
    if (psOptions->nLimit >= 0)
    {
        // Never ask for batches larger than the requested feature limit.
        const GIntBig nBatchCap = nGroupSize > 0 ? nGroupSize : 65536;
        const GIntBig nBatchSize =
            std::min<GIntBig>(psOptions->nLimit, nBatchCap);
        aosStreamOptions.SetNameValue("MAX_FEATURES_IN_BATCH",
                                      CPLSPrintf(CPL_FRMT_GIB, nBatchSize));
    }
    else if (nGroupSize > 0)
    {
        aosStreamOptions.SetNameValue("MAX_FEATURES_IN_BATCH",
                                      CPLSPrintf("%d", nGroupSize));
    }

    return aosStreamOptions;
}

/************************************************************************/
/* SetupTargetLayer::CanUseWriteArrowBatch() */
/************************************************************************/

bool SetupTargetLayer::CanUseWriteArrowBatch(
OGRLayer *poSrcLayer, OGRLayer *poDstLayer, bool bJustCreatedLayer,
const GDALVectorTranslateOptions *psOptions, bool &bError)
const GDALVectorTranslateOptions *psOptions, bool bPreserveFID,
bool &bError, OGRArrowArrayStream &streamSrc)
{
bError = false;

Expand Down Expand Up @@ -4050,20 +4085,20 @@ bool SetupTargetLayer::CanUseWriteArrowBatch(
}
}

struct ArrowArrayStream streamSrc;
const char *const apszOptions[] = {"SILENCE_GET_SCHEMA_ERROR=YES",
nullptr};
if (poSrcLayer->GetArrowStream(&streamSrc, apszOptions))
const CPLStringList aosGetArrowStreamOptions(
BuildGetArrowStreamOptions(psOptions, bPreserveFID));
if (poSrcLayer->GetArrowStream(streamSrc.get(),
aosGetArrowStreamOptions.List()))
{
struct ArrowSchema schemaSrc;
if (streamSrc.get_schema(&streamSrc, &schemaSrc) == 0)
if (streamSrc.get_schema(&schemaSrc) == 0)
{
if (psOptions->bTransform &&
GetArrowGeomFieldIndex(&schemaSrc,
poSrcLayer->GetGeometryColumn()) < 0)
{
schemaSrc.release(&schemaSrc);
streamSrc.release(&streamSrc);
streamSrc.clear();
return false;
}

Expand Down Expand Up @@ -4145,7 +4180,7 @@ bool SetupTargetLayer::CanUseWriteArrowBatch(
"Cannot create field %s",
pszFieldName);
schemaSrc.release(&schemaSrc);
streamSrc.release(&streamSrc);
streamSrc.clear();
return false;
}
}
Expand All @@ -4157,7 +4192,8 @@ bool SetupTargetLayer::CanUseWriteArrowBatch(
// check that it looks to be the same as the source
// one
struct ArrowArrayStream streamDst;
if (poDstLayer->GetArrowStream(&streamDst, nullptr))
if (poDstLayer->GetArrowStream(
&streamDst, aosGetArrowStreamOptions.List()))
{
struct ArrowSchema schemaDst;
if (streamDst.get_schema(&streamDst, &schemaDst) ==
Expand Down Expand Up @@ -4188,7 +4224,8 @@ bool SetupTargetLayer::CanUseWriteArrowBatch(
}
schemaSrc.release(&schemaSrc);
}
streamSrc.release(&streamSrc);
if (!bUseWriteArrowBatch)
streamSrc.clear();
}
}
return bUseWriteArrowBatch;
Expand Down Expand Up @@ -4915,8 +4952,10 @@ SetupTargetLayer::Setup(OGRLayer *poSrcLayer, const char *pszNewLayerName,
}

bool bError = false;
const bool bUseWriteArrowBatch = CanUseWriteArrowBatch(
poSrcLayer, poDstLayer, bJustCreatedLayer, psOptions, bError);
OGRArrowArrayStream streamSrc;
const bool bUseWriteArrowBatch =
CanUseWriteArrowBatch(poSrcLayer, poDstLayer, bJustCreatedLayer,
psOptions, bPreserveFID, bError, streamSrc);
if (bError)
return nullptr;

Expand Down Expand Up @@ -5378,7 +5417,7 @@ SetupTargetLayer::Setup(OGRLayer *poSrcLayer, const char *pszNewLayerName,
nTotalEventsDone = 0;
}

std::unique_ptr<TargetLayerInfo> psInfo(new TargetLayerInfo);
auto psInfo = std::make_unique<TargetLayerInfo>();
psInfo->m_bUseWriteArrowBatch = bUseWriteArrowBatch;
psInfo->m_nFeaturesRead = 0;
psInfo->m_bPerFeatureCT = false;
Expand Down Expand Up @@ -5475,6 +5514,8 @@ SetupTargetLayer::Setup(OGRLayer *poSrcLayer, const char *pszNewLayerName,
psInfo->m_bSupportCurves =
CPL_TO_BOOL(poDstLayer->TestCapability(OLCCurveGeometries));

psInfo->m_sArrowArrayStream = std::move(streamSrc);

return psInfo;
}

Expand Down Expand Up @@ -5769,49 +5810,19 @@ bool LayerTranslator::TranslateArrow(
GIntBig *pnReadFeatureCount, GDALProgressFunc pfnProgress,
void *pProgressArg, const GDALVectorTranslateOptions *psOptions)
{
struct ArrowArrayStream stream;
struct ArrowSchema schema;
CPLStringList aosOptionsGetArrowStream;
CPLStringList aosOptionsWriteArrowBatch;
aosOptionsGetArrowStream.SetNameValue("GEOMETRY_ENCODING", "WKB");
if (!psInfo->m_bPreserveFID)
aosOptionsGetArrowStream.SetNameValue("INCLUDE_FID", "NO");
else
if (psInfo->m_bPreserveFID)
{
aosOptionsWriteArrowBatch.SetNameValue(
"FID", psInfo->m_poSrcLayer->GetFIDColumn());
aosOptionsWriteArrowBatch.SetNameValue("IF_FID_NOT_PRESERVED",
"WARNING");
}
if (psOptions->nLimit >= 0)
{
aosOptionsGetArrowStream.SetNameValue(
"MAX_FEATURES_IN_BATCH",
CPLSPrintf(CPL_FRMT_GIB,
std::min<GIntBig>(psOptions->nLimit,
(psOptions->nGroupTransactions > 0
? psOptions->nGroupTransactions
: 65536))));
}
else if (psOptions->nGroupTransactions > 0)
{
aosOptionsGetArrowStream.SetNameValue(
"MAX_FEATURES_IN_BATCH",
CPLSPrintf("%d", psOptions->nGroupTransactions));
}
if (psInfo->m_poSrcLayer->GetArrowStream(&stream,
aosOptionsGetArrowStream.List()))
{
if (stream.get_schema(&stream, &schema) != 0)
{
CPLError(CE_Failure, CPLE_AppDefined, "stream.get_schema() failed");
stream.release(&stream);
return false;
}
}
else

if (psInfo->m_sArrowArrayStream.get_schema(&schema) != 0)
{
CPLError(CE_Failure, CPLE_AppDefined, "GetArrowStream() failed");
CPLError(CE_Failure, CPLE_AppDefined, "stream.get_schema() failed");
return false;
}

Expand Down Expand Up @@ -5865,7 +5876,7 @@ bool LayerTranslator::TranslateArrow(
{
struct ArrowArray array;
// Acquire source batch
if (stream.get_next(&stream, &array) != 0)
if (psInfo->m_sArrowArrayStream.get_next(&array) != 0)
{
CPLError(CE_Failure, CPLE_AppDefined, "stream.get_next() failed");
bRet = false;
Expand Down Expand Up @@ -6043,7 +6054,6 @@ bool LayerTranslator::TranslateArrow(

schema.release(&schema);

stream.release(&stream);
return bRet;
}

Expand Down
31 changes: 18 additions & 13 deletions ogr/ogr_recordbatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,16 @@

#include <stdint.h>

// Spec and documentation: https://arrow.apache.org/docs/format/CDataInterface.html

#ifdef __cplusplus
extern "C"
{
#endif

#ifndef ARROW_C_DATA_INTERFACE
#define ARROW_C_DATA_INTERFACE

#define ARROW_FLAG_DICTIONARY_ORDERED 1
#define ARROW_FLAG_NULLABLE 2
#define ARROW_FLAG_MAP_KEYS_SORTED 4
Expand Down Expand Up @@ -69,28 +74,27 @@ extern "C"
void *private_data;
};

// EXPERIMENTAL: C stream interface
#endif // ARROW_C_DATA_INTERFACE

#ifndef ARROW_C_STREAM_INTERFACE
#define ARROW_C_STREAM_INTERFACE

struct ArrowArrayStream
{
// Callback to get the stream type
// (will be the same for all arrays in the stream).
//
// Return value: 0 if successful, an `errno`-compatible error code
// otherwise.
// Return value: 0 if successful, an `errno`-compatible error code otherwise.
//
// If successful, the ArrowSchema must be released independently from
// the stream.
// If successful, the ArrowSchema must be released independently from the stream.
int (*get_schema)(struct ArrowArrayStream *, struct ArrowSchema *out);

// Callback to get the next array
// (if no error and the array is released, the stream has ended)
//
// Return value: 0 if successful, an `errno`-compatible error code
// otherwise.
// Return value: 0 if successful, an `errno`-compatible error code otherwise.
//
// If successful, the ArrowArray must be released independently from the
// stream.
// If successful, the ArrowArray must be released independently from the stream.
int (*get_next)(struct ArrowArrayStream *, struct ArrowArray *out);

// Callback to get optional detailed error information.
Expand All @@ -100,19 +104,20 @@ extern "C"
// Return value: pointer to a null-terminated character array describing
// the last error, or NULL if no description is available.
//
// The returned pointer is only valid until the next operation on this
// stream (including release).
// The returned pointer is only valid until the next operation on this stream
// (including release).
const char *(*get_last_error)(struct ArrowArrayStream *);

// Release callback: release the stream's own resources.
// Note that arrays returned by `get_next` must be individually
// released.
// Note that arrays returned by `get_next` must be individually released.
void (*release)(struct ArrowArrayStream *);

// Opaque producer-specific data
void *private_data;
};

#endif // ARROW_C_STREAM_INTERFACE

#ifdef __cplusplus
}
#endif
Expand Down
71 changes: 71 additions & 0 deletions ogr/ogrsf_frmts/generic/ogrlayerarrow.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
#include <map>
#include <string>

#include "ogr_recordbatch.h"

constexpr const char *ARROW_EXTENSION_NAME_KEY = "ARROW:extension:name";
constexpr const char *ARROW_EXTENSION_METADATA_KEY = "ARROW:extension:metadata";
constexpr const char *EXTENSION_NAME_OGC_WKB = "ogc.wkb";
Expand All @@ -34,4 +36,73 @@ bool CPL_DLL OGRCloneArrowArray(const struct ArrowSchema *schema,
bool CPL_DLL OGRCloneArrowSchema(const struct ArrowSchema *schema,
struct ArrowSchema *out_schema);

/** C++ wrapper on top of ArrowArrayStream.
 *
 * The wrapper owns the stream it holds: when the release callback of the
 * underlying ArrowArrayStream is set, it is invoked by clear() and by the
 * destructor. Instances cannot be copied nor move-constructed; ownership can
 * only be transferred through the move assignment operator.
 */
class OGRArrowArrayStream
{
  public:
    /** Constructor: instantiate an empty ArrowArrayStream.
     *
     * All members of the underlying structure are zero-initialized through
     * the default member initializer of m_stream.
     */
    OGRArrowArrayStream() = default;

    /** Destructor: call release() on the ArrowArrayStream if not already done */
    ~OGRArrowArrayStream()
    {
        clear();
    }

    /** Call release() on the ArrowArrayStream if not already done */
    // cppcheck-suppress functionStatic
    void clear()
    {
        if (m_stream.release)
        {
            m_stream.release(&m_stream);
            // Make sure a later clear() / destruction does not release twice.
            m_stream.release = nullptr;
        }
    }

    /** Return the raw ArrowArrayStream*.
     *
     * The returned pointer refers to the structure owned by this object.
     */
    ArrowArrayStream *get()
    {
        return &m_stream;
    }

    /** Get the schema.
     *
     * The get_schema callback is invoked unconditionally: the stream must
     * have been populated (non-null callback) before calling this method.
     *
     * @param schema Output schema.
     * @return the value returned by the get_schema callback.
     */
    // cppcheck-suppress functionStatic
    int get_schema(struct ArrowSchema *schema)
    {
        return m_stream.get_schema(&m_stream, schema);
    }

    /** Get the next ArrowArray batch.
     *
     * The get_next callback is invoked unconditionally: the stream must
     * have been populated (non-null callback) before calling this method.
     *
     * @param array Output array.
     * @return the value returned by the get_next callback.
     */
    // cppcheck-suppress functionStatic
    int get_next(struct ArrowArray *array)
    {
        return m_stream.get_next(&m_stream, array);
    }

    /** Move assignment operator.
     *
     * Releases the stream currently held (if any), takes over the one of
     * other, and leaves other empty so that it no longer releases anything.
     */
    OGRArrowArrayStream &operator=(OGRArrowArrayStream &&other) noexcept
    {
        if (this != &other)
        {
            clear();
            // Plain struct copy + value-initialization: no dependency on
            // memcpy()/memset() and thus on <cstring>.
            m_stream = other.m_stream;
            other.m_stream = {};
        }
        return *this;
    }

  private:
    struct ArrowArrayStream m_stream
    {
    };

    OGRArrowArrayStream(const OGRArrowArrayStream &) = delete;
    OGRArrowArrayStream(OGRArrowArrayStream &&) = delete;
    OGRArrowArrayStream &operator=(const OGRArrowArrayStream &) = delete;
};

#endif // OGRLAYERARROW_H_DEFINED

0 comments on commit 2e32ace

Please sign in to comment.