Skip to content

Commit 082377a

Browse files
authored
GH-47713: [C++][FlightRPC] ODBC return number of affected rows (#48037)
### Rationale for this change Make ODBC return number of affected rows as -1 which to BI tools means number of affected rows is unknown. This is because ODBC only supports `select` statement and doesn't support queries that affect rows. ### What changes are included in this PR? - SQLRowCount & tests ### Are these changes tested? Tested locally on MSVC ### Are there any user-facing changes? N/A * GitHub Issue: #47713 Authored-by: Alina (Xi) Li <[email protected]> Signed-off-by: David Li <[email protected]>
1 parent 01bc1bd commit 082377a

File tree

4 files changed

+67
-2
lines changed

4 files changed

+67
-2
lines changed

cpp/src/arrow/flight/sql/odbc/odbc_api.cc

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1156,8 +1156,13 @@ SQLRETURN SQLNumResultCols(SQLHSTMT stmt, SQLSMALLINT* column_count_ptr) {
11561156
SQLRETURN SQLRowCount(SQLHSTMT stmt, SQLLEN* row_count_ptr) {
11571157
ARROW_LOG(DEBUG) << "SQLRowCount called with stmt: " << stmt
11581158
<< ", column_count_ptr: " << static_cast<const void*>(row_count_ptr);
1159-
// GH-47713 TODO: Implement SQLRowCount
1160-
return SQL_INVALID_HANDLE;
1159+
1160+
using ODBC::ODBCStatement;
1161+
return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() {
1162+
ODBCStatement* statement = reinterpret_cast<ODBCStatement*>(stmt);
1163+
statement->GetRowCount(row_count_ptr);
1164+
return SQL_SUCCESS;
1165+
});
11611166
}
11621167

11631168
SQLRETURN SQLTables(SQLHSTMT stmt, SQLWCHAR* catalog_name,

cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_statement.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -761,6 +761,17 @@ SQLRETURN ODBCStatement::GetData(SQLSMALLINT record_number, SQLSMALLINT c_type,
761761
data_ptr, buffer_length, indicator_ptr);
762762
}
763763

764+
void ODBCStatement::GetRowCount(SQLLEN* row_count_ptr) {
765+
if (!row_count_ptr) {
766+
// row count pointer is not valid, do nothing as ODBC spec does not mention this as an
767+
// error
768+
return;
769+
}
770+
// Will always be -1 (meaning number of rows unknown) since only SELECT is supported by
771+
// driver
772+
*row_count_ptr = -1;
773+
}
774+
764775
void ODBCStatement::ReleaseStatement() {
765776
CloseCursor(true);
766777
connection_.DropStatement(this);

cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_statement.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,11 @@ class ODBCStatement : public ODBCHandle<ODBCStatement> {
7878
SQLRETURN GetData(SQLSMALLINT record_number, SQLSMALLINT c_type, SQLPOINTER data_ptr,
7979
SQLLEN buffer_length, SQLLEN* indicator_ptr);
8080

81+
/// \brief Return number of rows affected by an UPDATE, INSERT, or DELETE statement\
82+
///
83+
/// -1 is returned as driver only supports SELECT statement
84+
void GetRowCount(SQLLEN* row_count_ptr);
85+
8186
/// \brief Closes the cursor. This does _not_ un-prepare the statement or change
8287
/// bindings.
8388
void CloseCursor(bool suppress_errors);

cpp/src/arrow/flight/sql/odbc/tests/statement_test.cc

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1281,6 +1281,50 @@ TYPED_TEST(StatementTest, TestSQLNativeSqlReturnsErrorOnBadInputs) {
12811281
VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, kErrorStateHY090);
12821282
}
12831283

1284+
TYPED_TEST(StatementTest, SQLRowCountReturnsNegativeOneOnSelect) {
1285+
SQLLEN row_count = 0;
1286+
SQLLEN expected_value = -1;
1287+
SQLWCHAR sql_query[] = L"SELECT 1 AS col1, 'One' AS col2, 3 AS col3";
1288+
SQLINTEGER query_length = static_cast<SQLINTEGER>(wcslen(sql_query));
1289+
1290+
ASSERT_EQ(SQL_SUCCESS, SQLExecDirect(this->stmt, sql_query, query_length));
1291+
1292+
ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt));
1293+
1294+
CheckIntColumn(this->stmt, 1, 1);
1295+
CheckStringColumnW(this->stmt, 2, L"One");
1296+
CheckIntColumn(this->stmt, 3, 3);
1297+
1298+
ASSERT_EQ(SQL_SUCCESS, SQLRowCount(this->stmt, &row_count));
1299+
1300+
EXPECT_EQ(expected_value, row_count);
1301+
}
1302+
1303+
TYPED_TEST(StatementTest, SQLRowCountReturnsSuccessOnNullptr) {
1304+
SQLWCHAR sql_query[] = L"SELECT 1 AS col1, 'One' AS col2, 3 AS col3";
1305+
SQLINTEGER query_length = static_cast<SQLINTEGER>(wcslen(sql_query));
1306+
1307+
ASSERT_EQ(SQL_SUCCESS, SQLExecDirect(this->stmt, sql_query, query_length));
1308+
1309+
ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt));
1310+
1311+
CheckIntColumn(this->stmt, 1, 1);
1312+
CheckStringColumnW(this->stmt, 2, L"One");
1313+
CheckIntColumn(this->stmt, 3, 3);
1314+
1315+
ASSERT_EQ(SQL_SUCCESS, SQLRowCount(this->stmt, nullptr));
1316+
}
1317+
1318+
TYPED_TEST(StatementTest, SQLRowCountFunctionSequenceErrorOnNoQuery) {
1319+
SQLLEN row_count = 0;
1320+
SQLLEN expected_value = 0;
1321+
1322+
ASSERT_EQ(SQL_ERROR, SQLRowCount(this->stmt, &row_count));
1323+
VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorStateHY010);
1324+
1325+
EXPECT_EQ(expected_value, row_count);
1326+
}
1327+
12841328
TYPED_TEST(StatementTest, TestSQLCloseCursor) {
12851329
std::wstring wsql = L"SELECT 1;";
12861330
std::vector<SQLWCHAR> sql0(wsql.begin(), wsql.end());

0 commit comments

Comments
 (0)