diff --git a/src/database.cpp b/src/database.cpp index e26275628..284760a31 100644 --- a/src/database.cpp +++ b/src/database.cpp @@ -1,4 +1,5 @@ #include "rapi.hpp" +#include "r_progress_bar_display.hpp" #include "duckdb/main/client_context.hpp" #include "duckdb/parser/parsed_data/create_table_function_info.hpp" @@ -16,6 +17,51 @@ static bool CastRstringToVarchar(Vector &source, Vector &result, idx_t count, Ca return true; } +unique_ptr RProgressBarDisplay::Create() { + return make_uniq(); +} + +void RProgressBarDisplay::Initialize() { + auto progress_display = Rf_GetOption(RStrings::get().progress_display_sym, R_BaseEnv); + if (Rf_isFunction(progress_display)) { + progress_callback = progress_display; + } + D_ASSERT(progress_callback != R_NilValue); +} + +RProgressBarDisplay::RProgressBarDisplay() : ProgressBarDisplay() { + // Empty +} + +void RProgressBarDisplay::Update(double percentage) { + if (progress_callback == R_NilValue) { + Initialize(); + } + if (progress_callback != R_NilValue) { + try { + cpp11::sexp call = Rf_lang2(progress_callback, Rf_ScalarReal(percentage)); + cpp11::safe[Rf_eval](call, R_BaseEnv); + } catch (std::exception &e) { + cpp11::stop("RProgressBarDisplay: Failed to update progress bar: %s", e.what()); + } + } +} + +void RProgressBarDisplay::Finish() { + Update(100); +} + +static void SetDefaultConfigArguments(ClientContext &context) { + auto progress_display = Rf_GetOption(RStrings::get().progress_display_sym, R_BaseEnv); + if (Rf_isFunction(progress_display)) { + auto &config = ClientConfig::GetConfig(context); + config.enable_progress_bar = true; + } + + // Set the function used to create the display for the progress bar + context.config.display_create_func = RProgressBarDisplay::Create; +} + [[cpp11::register]] duckdb::db_eptr_t rapi_startup(std::string dbdir, bool readonly, cpp11::list configsexp, bool environment_scan) { const char *dbdirchar; @@ -76,6 +122,7 @@ static bool CastRstringToVarchar(Vector &source, Vector &result, idx_t count, Ca CreateTableFunctionInfo info(scan_fun); Connection conn(*wrapper->db); auto &context = *conn.context; + SetDefaultConfigArguments(context); auto &catalog = Catalog::GetSystemCatalog(context); context.transaction.BeginTransaction(); diff --git a/src/include/r_progress_bar_display.hpp b/src/include/r_progress_bar_display.hpp new file mode 100644 index 000000000..6fd867daa --- /dev/null +++ b/src/include/r_progress_bar_display.hpp @@ -0,0 +1,28 @@ +#pragma once + +#include "rapi.hpp" +#include "duckdb/common/progress_bar/progress_bar_display.hpp" +#include "duckdb/common/helper.hpp" + +namespace duckdb { + +class RProgressBarDisplay : public ProgressBarDisplay { +public: + RProgressBarDisplay(); + virtual ~RProgressBarDisplay() { + } + + static unique_ptr Create(); + +public: + void Update(double percentage) override; + void Finish() override; + +private: + void Initialize(); + +private: + SEXP progress_callback = R_NilValue; +}; + +} // namespace duckdb diff --git a/src/include/rapi.hpp b/src/include/rapi.hpp index 50fe9a86b..213990a7c 100644 --- a/src/include/rapi.hpp +++ b/src/include/rapi.hpp @@ -171,6 +171,7 @@ struct RStrings { SEXP ImportRecordBatchReader_sym; SEXP materialize_callback_sym; SEXP materialize_message_sym; + SEXP progress_display_sym; SEXP duckdb_row_names_sym; SEXP duckdb_vector_sym; diff --git a/src/utils.cpp b/src/utils.cpp index 746386566..1f0e3aabf 100644 --- a/src/utils.cpp +++ b/src/utils.cpp @@ -82,6 +82,7 @@ RStrings::RStrings() { Table__from_record_batches_sym = Rf_install("Table__from_record_batches"); materialize_message_sym = Rf_install("duckdb.materialize_message"); materialize_callback_sym = Rf_install("duckdb.materialize_callback"); + progress_display_sym = Rf_install("duckdb.progress_display"); duckdb_row_names_sym = Rf_install("duckdb_row_names"); duckdb_vector_sym = Rf_install("duckdb_vector"); }