Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use DuckDB's filesystem for GDAL by default, Handle GDAL errors, Add ST_Union_Agg(), ST_Intersection_Agg(). #126

Merged
merged 2 commits into from
Sep 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion duckdb
Submodule duckdb updated 1371 files
4 changes: 3 additions & 1 deletion spatial/include/spatial/core/functions/aggregate.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ namespace spatial {
namespace core {

struct CoreAggregateFunctions {
static void Register(ClientContext &context);
public:
static void Register(ClientContext &context) {
}
};

} // namespace core
Expand Down
20 changes: 20 additions & 0 deletions spatial/include/spatial/gdal/file_handler.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#pragma once

#include "spatial/common.hpp"

namespace spatial {

namespace gdal {

struct GdalFileHandler {
static void Register(ClientContext &context);

// This is a workaround to allow the global file handler to access the current client context
// by storing it in a thread_local variable before executing a GDAL IO operation
static void SetLocalClientContext(ClientContext &context);
static ClientContext &GetLocalClientContext();
};

} // namespace gdal

} // namespace spatial
2 changes: 2 additions & 0 deletions spatial/include/spatial/gdal/functions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ struct GdalTableFunction : ArrowTableFunction {
static void RenameColumns(vector<string> &names);

static unique_ptr<GlobalTableFunctionState> InitGlobal(ClientContext &context, TableFunctionInitInput &input);
static unique_ptr<LocalTableFunctionState> InitLocal(ExecutionContext &context, TableFunctionInitInput &input,
GlobalTableFunctionState *global_state_p);

static void Scan(ClientContext &context, TableFunctionInput &input, DataChunk &output);

Expand Down
3 changes: 3 additions & 0 deletions spatial/include/spatial/geos/geos_wrappers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,9 @@ struct GeosContextWrapper {
string_t Serialize(Vector &result, const unique_ptr<GEOSGeometry, GeosDeleter<GEOSGeometry>> &geom);
};

GEOSGeometry *DeserializeGEOSGeometry(const string_t &blob, GEOSContextHandle_t ctx);
string_t SerializeGEOSGeometry(Vector &result, const GEOSGeometry *geom, GEOSContextHandle_t ctx);

} // namespace geos

} // namespace spatial
2 changes: 1 addition & 1 deletion spatial/src/spatial/core/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ void CoreModule::Register(ClientContext &context) {
CoreScalarFunctions::Register(context);
CoreCastFunctions::Register(context);
CoreTableFunctions::Register(context);
// CoreAggregateFunctions::Register(context);
CoreAggregateFunctions::Register(context);
}

} // namespace core
Expand Down
1 change: 1 addition & 0 deletions spatial/src/spatial/gdal/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@ add_subdirectory(functions)
set(EXTENSION_SOURCES
${EXTENSION_SOURCES}
${CMAKE_CURRENT_SOURCE_DIR}/module.cpp
${CMAKE_CURRENT_SOURCE_DIR}/file_handler.cpp
PARENT_SCOPE
)
200 changes: 200 additions & 0 deletions spatial/src/spatial/gdal/file_handler.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
#include "spatial/gdal/file_handler.hpp"

#include "cpl_vsi.h"
#include "cpl_string.h"

namespace spatial {

namespace gdal {

//--------------------------------------------------------------------------
// Local Client Context
//--------------------------------------------------------------------------

static thread_local ClientContext *local_context = nullptr;

void GdalFileHandler::SetLocalClientContext(ClientContext &context) {
local_context = &context;
}

ClientContext &GdalFileHandler::GetLocalClientContext() {
if (!local_context) {
throw InternalException("No local client context set");
}
return *local_context;
}

//--------------------------------------------------------------------------
// Required Callbacks
//--------------------------------------------------------------------------

static void *DuckDBOpen(void *, const char *file_name, const char *access) {
auto &context = GdalFileHandler::GetLocalClientContext();
auto &fs = context.db->GetFileSystem();

// TODO: Double check that this is correct
uint8_t flags;
auto len = strlen(access);
if (access[0] == 'r') {
flags = FileFlags::FILE_FLAGS_READ;
if (len > 1 && access[1] == '+') {
flags |= FileFlags::FILE_FLAGS_WRITE;
}
} else if (access[0] == 'w') {
flags = FileFlags::FILE_FLAGS_WRITE | FileFlags::FILE_FLAGS_FILE_CREATE_NEW;
if (len > 1 && access[1] == '+') {
flags |= FileFlags::FILE_FLAGS_READ;
}
} else if (access[0] == 'a') {
flags = FileFlags::FILE_FLAGS_APPEND;
if (len > 1 && access[1] == '+') {
flags |= FileFlags::FILE_FLAGS_READ;
}
} else {
throw InternalException("Unknown file access type");
}

try {
auto file = fs.OpenFile(file_name, flags);
return file.release();
} catch (std::exception &ex) {
return nullptr;
}
}

static vsi_l_offset DuckDBTell(void *file) {
auto file_handle = static_cast<FileHandle *>(file);
auto offset = file_handle->SeekPosition();
return static_cast<vsi_l_offset>(offset);
}

static int DuckDBSeek(void *file, vsi_l_offset offset, int whence) {
auto file_handle = static_cast<FileHandle *>(file);
switch (whence) {
case SEEK_SET:
file_handle->Seek(offset);
break;
case SEEK_CUR:
file_handle->Seek(file_handle->SeekPosition() + offset);
break;
case SEEK_END:
file_handle->Seek(file_handle->GetFileSize() + offset);
break;
default:
throw InternalException("Unknown seek type");
}
return 0;
}

static size_t DuckDBRead(void *pFile, void *pBuffer, size_t n_size, size_t n_count) {
auto file_handle = static_cast<FileHandle *>(pFile);
auto read_bytes = file_handle->Read(pBuffer, n_size * n_count);
// Return the number of items read
return static_cast<size_t>(read_bytes / n_size);
}

static size_t DuckDBWrite(void *file, const void *buffer, size_t n_size, size_t n_count) {
auto file_handle = static_cast<FileHandle *>(file);
auto written_bytes = file_handle->Write(const_cast<void *>(buffer), n_size * n_count);
// Return the number of items written
return static_cast<size_t>(written_bytes / n_size);
}

static int DuckDBEoF(void *file) {
// TODO: Is this correct?
auto file_handle = static_cast<FileHandle *>(file);
return file_handle->SeekPosition() == file_handle->GetFileSize() ? TRUE : FALSE;
}

static int DuckDBTruncate(void *file, vsi_l_offset size) {
auto file_handle = static_cast<FileHandle *>(file);
file_handle->Truncate(static_cast<int64_t>(size));
return 0;
}

static int DuckDBClose(void *file) {
auto file_handle = static_cast<FileHandle *>(file);
file_handle->Close();
delete file_handle;
return 0;
}

static int DuckDBFlush(void *file) {
auto file_handle = static_cast<FileHandle *>(file);
file_handle->Sync();
return 0;
}

static int DuckDBMakeDir(void *, const char *dir_name, long mode) {
auto &context = GdalFileHandler::GetLocalClientContext();
auto &fs = context.db->GetFileSystem();

fs.CreateDirectory(dir_name);
return 0;
}

static int DuckDBDeleteDir(void *, const char *dir_name) {
auto &context = GdalFileHandler::GetLocalClientContext();
auto &fs = context.db->GetFileSystem();

fs.RemoveDirectory(dir_name);
return 0;
}

static char **DuckDBReadDir(void *, const char *dir_name, int max_files) {
auto &context = GdalFileHandler::GetLocalClientContext();
auto &fs = context.db->GetFileSystem();

CPLStringList files;
auto files_count = 0;
fs.ListFiles(dir_name, [&](const string &file_name, bool is_dir) {
if (files_count >= max_files) {
return;
}
files.AddString(file_name.c_str());
files_count++;
});
return files.StealList();
}

static char **DuckDBSiblingFiles(void *, const char *dir_name) {
auto &context = GdalFileHandler::GetLocalClientContext();
auto &fs = context.db->GetFileSystem();

CPLStringList files;
auto file_vector = fs.Glob(dir_name);
for (auto &file : file_vector) {
files.AddString(file.c_str());
}
return files.StealList();
}

//--------------------------------------------------------------------------
// Register
//--------------------------------------------------------------------------
void GdalFileHandler::Register(ClientContext &context) {

auto callbacks = VSIAllocFilesystemPluginCallbacksStruct();

callbacks->nCacheSize = 16384000; // same as /vsicurl/
callbacks->open = DuckDBOpen;
callbacks->read = DuckDBRead;
callbacks->write = DuckDBWrite;
callbacks->close = DuckDBClose;
callbacks->tell = DuckDBTell;
callbacks->seek = DuckDBSeek;
callbacks->eof = DuckDBEoF;
callbacks->flush = DuckDBFlush;
callbacks->truncate = DuckDBTruncate;
callbacks->mkdir = DuckDBMakeDir;
callbacks->rmdir = DuckDBDeleteDir;
callbacks->read_dir = DuckDBReadDir;
callbacks->sibling_files = DuckDBSiblingFiles;

// Override this as the default file system
VSIInstallPluginHandler("", callbacks);
}

} // namespace gdal

} // namespace spatial
25 changes: 23 additions & 2 deletions spatial/src/spatial/gdal/functions/st_read.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "spatial/common.hpp"
#include "spatial/core/types.hpp"
#include "spatial/gdal/functions.hpp"
#include "spatial/gdal/file_handler.hpp"

#include "ogrsf_frmts.h"

Expand Down Expand Up @@ -134,6 +135,9 @@ struct ScopedOption {
}
};

//------------------------------------------------------------------------------
// Bind
//------------------------------------------------------------------------------
unique_ptr<FunctionData> GdalTableFunction::Bind(ClientContext &context, TableFunctionBindInput &input,
vector<LogicalType> &return_types, vector<string> &names) {

Expand All @@ -142,6 +146,9 @@ unique_ptr<FunctionData> GdalTableFunction::Bind(ClientContext &context, TableFu
throw PermissionException("Scanning GDAL files is disabled through configuration");
}

// Set the local client context so that we can access it from the filesystem handler
GdalFileHandler::SetLocalClientContext(context);

// First scan for "options" parameter
auto gdal_open_options = vector<char const *>();
auto options_param = input.named_parameters.find("open_options");
Expand Down Expand Up @@ -453,7 +460,9 @@ OGRLayer *open_layer(const GdalScanFunctionData &data) {
return layer;
}

// init global
//-----------------------------------------------------------------------------
// Init global
//-----------------------------------------------------------------------------
unique_ptr<GlobalTableFunctionState> GdalTableFunction::InitGlobal(ClientContext &context,
TableFunctionInitInput &input) {
auto &data = input.bind_data->Cast<GdalScanFunctionData>();
Expand Down Expand Up @@ -500,7 +509,19 @@ unique_ptr<GlobalTableFunctionState> GdalTableFunction::InitGlobal(ClientContext
return std::move(global_state);
}

//-----------------------------------------------------------------------------
// Init Local
//-----------------------------------------------------------------------------
unique_ptr<LocalTableFunctionState> GdalTableFunction::InitLocal(ExecutionContext &context,
TableFunctionInitInput &input,
GlobalTableFunctionState *global_state_p) {
GdalFileHandler::SetLocalClientContext(context.client);
return ArrowTableFunction::ArrowScanInitLocal(context, input, global_state_p);
}

//-----------------------------------------------------------------------------
// Scan
//-----------------------------------------------------------------------------
void GdalTableFunction::Scan(ClientContext &context, TableFunctionInput &input, DataChunk &output) {
if (!input.local_state) {
return;
Expand Down Expand Up @@ -578,7 +599,7 @@ void GdalTableFunction::Register(ClientContext &context) {

TableFunctionSet set("st_read");
TableFunction scan({LogicalType::VARCHAR}, GdalTableFunction::Scan, GdalTableFunction::Bind,
GdalTableFunction::InitGlobal, ArrowTableFunction::ArrowScanInitLocal);
GdalTableFunction::InitGlobal, GdalTableFunction::InitLocal);

scan.cardinality = GdalTableFunction::Cardinality;
scan.get_batch_index = ArrowTableFunction::ArrowGetBatchIndex;
Expand Down
9 changes: 5 additions & 4 deletions spatial/src/spatial/gdal/functions/st_write.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
#include "duckdb/parser/parsed_data/create_table_function_info.hpp"
#include "spatial/core/types.hpp"
#include "spatial/core/geometry/geometry_factory.hpp"
#include "spatial/core/geometry/wkb_writer.hpp"
#include "spatial/gdal/functions.hpp"
#include "spatial/gdal/file_handler.hpp"

#include "ogrsf_frmts.h"

Expand Down Expand Up @@ -115,6 +115,7 @@ static unique_ptr<FunctionData> Bind(ClientContext &context, CopyInfo &info, vec
// Init Local
//===--------------------------------------------------------------------===//
static unique_ptr<LocalFunctionData> InitLocal(ExecutionContext &context, FunctionData &bind_data) {
GdalFileHandler::SetLocalClientContext(context.client);
auto local_data = make_uniq<LocalState>(context.client);
return std::move(local_data);
}
Expand Down Expand Up @@ -208,9 +209,9 @@ static unique_ptr<OGRFieldDefn> OGRFieldTypeFromLogicalType(const string &name,
}
static unique_ptr<GlobalFunctionData> InitGlobal(ClientContext &context, FunctionData &bind_data,
const string &file_path) {
// auto gdal_data = (BindData&)bind_data;
// auto global_data = make_uniq<GlobalState>(file_path, "FlatGeobuf");
// return std::move(global_data);

// Set the local client context so that we can access it from the filesystem handler
GdalFileHandler::SetLocalClientContext(context);

auto &gdal_data = (BindData &)bind_data;
GDALDriver *driver = GetGDALDriverManager()->GetDriverByName(gdal_data.driver_name.c_str());
Expand Down
Loading