diff --git a/.bazelrc b/.bazelrc index 99aacce..448a201 100644 --- a/.bazelrc +++ b/.bazelrc @@ -13,7 +13,7 @@ # limitations under the License. common --experimental_repo_remote_exec -common --experimental_remote_download_regex='.*\/dataproxy_sdk$|.*\/arrow$' +common --modify_execution_info=CppLink=+no-remote-cache build --incompatible_new_actions_api=false diff --git a/.clang-format b/.clang-format index 16b3e5e..2ed65ab 100644 --- a/.clang-format +++ b/.clang-format @@ -9,7 +9,7 @@ IncludeCategories: Priority: 2 - Regex: '.*\.pb\.h"$' Priority: 5 - - Regex: '^"secretflow_serving.*' + - Regex: '^"dataproxy_sdk.*' Priority: 4 - Regex: '^".*' Priority: 3 diff --git a/build/Dockerfiles/dataproxy.Dockerfile b/build/Dockerfiles/dataproxy.Dockerfile index 31f7193..e165256 100644 --- a/build/Dockerfiles/dataproxy.Dockerfile +++ b/build/Dockerfiles/dataproxy.Dockerfile @@ -1,11 +1,11 @@ # Copyright 2024 Ant Group Co., Ltd. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/config/application.yaml b/config/application.yaml index 086ee20..614120e 100644 --- a/config/application.yaml +++ b/config/application.yaml @@ -1,11 +1,11 @@ # Copyright 2024 Ant Group Co., Ltd. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/dataproxy-common/src/main/java/org/secretflow/dataproxy/common/model/datasource/conn/OdpsConnConfig.java b/dataproxy-common/src/main/java/org/secretflow/dataproxy/common/model/datasource/conn/OdpsConnConfig.java index 9fefa43..8f875b0 100644 --- a/dataproxy-common/src/main/java/org/secretflow/dataproxy/common/model/datasource/conn/OdpsConnConfig.java +++ b/dataproxy-common/src/main/java/org/secretflow/dataproxy/common/model/datasource/conn/OdpsConnConfig.java @@ -1,12 +1,12 @@ /* * Copyright 2024 Ant Group Co., Ltd. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/dataproxy-common/src/main/java/org/secretflow/dataproxy/common/model/datasource/location/OdpsTableInfo.java b/dataproxy-common/src/main/java/org/secretflow/dataproxy/common/model/datasource/location/OdpsTableInfo.java index a8883be..9b813b4 100644 --- a/dataproxy-common/src/main/java/org/secretflow/dataproxy/common/model/datasource/location/OdpsTableInfo.java +++ b/dataproxy-common/src/main/java/org/secretflow/dataproxy/common/model/datasource/location/OdpsTableInfo.java @@ -1,12 +1,12 @@ /* * Copyright 2024 Ant Group Co., Ltd. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/dataproxy-common/src/main/java/org/secretflow/dataproxy/common/serializer/SensitiveDataSerializer.java b/dataproxy-common/src/main/java/org/secretflow/dataproxy/common/serializer/SensitiveDataSerializer.java index aec3cb4..cca960c 100644 --- a/dataproxy-common/src/main/java/org/secretflow/dataproxy/common/serializer/SensitiveDataSerializer.java +++ b/dataproxy-common/src/main/java/org/secretflow/dataproxy/common/serializer/SensitiveDataSerializer.java @@ -1,12 +1,12 @@ /* * Copyright 2024 Ant Group Co., Ltd. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/dataproxy-common/src/main/java/org/secretflow/dataproxy/common/serializer/package-info.java b/dataproxy-common/src/main/java/org/secretflow/dataproxy/common/serializer/package-info.java index 4501b30..82786d4 100644 --- a/dataproxy-common/src/main/java/org/secretflow/dataproxy/common/serializer/package-info.java +++ b/dataproxy-common/src/main/java/org/secretflow/dataproxy/common/serializer/package-info.java @@ -1,12 +1,12 @@ /* * Copyright 2024 Ant Group Co., Ltd. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/dataproxy-manager/src/main/java/org/secretflow/dataproxy/manager/connector/odps/OdpsDataWriter.java b/dataproxy-manager/src/main/java/org/secretflow/dataproxy/manager/connector/odps/OdpsDataWriter.java index 0446c84..0f10919 100644 --- a/dataproxy-manager/src/main/java/org/secretflow/dataproxy/manager/connector/odps/OdpsDataWriter.java +++ b/dataproxy-manager/src/main/java/org/secretflow/dataproxy/manager/connector/odps/OdpsDataWriter.java @@ -67,6 +67,8 @@ public class OdpsDataWriter implements DataWriter { private final boolean overwrite = true; + private boolean isTemporarilyCreatedTable = false; + private TableTunnel.UploadSession uploadSession = null; private RecordWriter recordWriter = null; @@ -147,9 +149,9 @@ private void initOdps() throws TunnelException, IOException { Odps odps = initOdpsClient(this.connConfig); // Pre-processing preProcessing(odps, connConfig.getProjectName(), tableInfo.tableName()); - // init download session + // init upload session TableTunnel tunnel = new TableTunnel(odps); - if (tableInfo.partitionSpec() != null && !tableInfo.partitionSpec().isEmpty()) { + if (tableInfo.partitionSpec() != null && !tableInfo.partitionSpec().isEmpty() && !isTemporarilyCreatedTable) { PartitionSpec partitionSpec = new PartitionSpec(tableInfo.partitionSpec()); uploadSession = tunnel.createUploadSession(connConfig.getProjectName(), tableInfo.tableName(), partitionSpec, overwrite); } else { @@ -244,6 +246,7 @@ private void preProcessing(Odps odps, String projectName, String tableName) { if (!odpsTable) { throw DataproxyException.of(DataproxyErrorCode.ODPS_CREATE_TABLE_FAILED); } + isTemporarilyCreatedTable = true; } log.info("odps table is exists or create table successful, project: {}, table name: {}", projectName, tableName); } diff --git a/dataproxy-manager/src/main/java/org/secretflow/dataproxy/manager/connector/odps/OdpsSplitArrowReader.java b/dataproxy-manager/src/main/java/org/secretflow/dataproxy/manager/connector/odps/OdpsSplitArrowReader.java index 29655b5..3773757 100644 --- a/dataproxy-manager/src/main/java/org/secretflow/dataproxy/manager/connector/odps/OdpsSplitArrowReader.java +++ b/dataproxy-manager/src/main/java/org/secretflow/dataproxy/manager/connector/odps/OdpsSplitArrowReader.java @@ -22,7 +22,6 @@ import com.aliyun.odps.data.Record; import com.aliyun.odps.data.ResultSet; import com.aliyun.odps.task.SQLTask; -import com.aliyun.odps.utils.StringUtils; import lombok.extern.slf4j.Slf4j; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.vector.BigIntVector; @@ -44,12 +43,7 @@ import org.secretflow.dataproxy.manager.SplitReader; import java.io.IOException; -import java.nio.charset.StandardCharsets; -import java.util.ArrayList; -import java.util.HashMap; import java.util.List; -import java.util.Map; -import java.util.regex.Matcher; import java.util.regex.Pattern; /** @@ -155,147 +149,16 @@ public ArrowReader startRead() { return this; } - private String buildSql(String tableName, List<String> fields, String partition) { + private String buildSql(String tableName, List<String> fields, String whereClause) { if (!columnOrValuePattern.matcher(tableName).matches()) { throw DataproxyException.of(DataproxyErrorCode.PARAMS_UNRELIABLE, "Invalid tableName:" + tableName); } - String transformedPartition = buildWhereClause(partition); - return "select " + String.join(",", fields) + " from " + tableName + (transformedPartition.isEmpty() ? "" : " where " + transformedPartition) + ";"; + return "select " + String.join(",", fields) + " from " + tableName + (whereClause.isEmpty() ? "" : " where " + whereClause) + ";"; } - /** - * Deprecated method; to be removed later. - * - * @param partition partition fields - * @return boolean - */ - @Deprecated - private String transformPartition(String partition) { - - Map<String, List<String>> fieldValuesMap = new HashMap<>(); - - if (partition != null) { - String[] split = StringUtils.split(partition, ';'); - for (String s : split) { - String[] kv = StringUtils.split(s, '='); - if (kv.length != 2 || kv[0].isEmpty() || kv[1].isEmpty()) { - throw DataproxyException.of(DataproxyErrorCode.INVALID_PARTITION_SPEC); - } - if (fieldValuesMap.containsKey(kv[0])) { - fieldValuesMap.get(kv[0]).add(kv[1]); - } else { - fieldValuesMap.put(kv[0], new ArrayList<>(List.of(kv[1]))); - } - } - } - - return buildEqualClause(fieldValuesMap).toString(); - } - - /** - * Converts multi-valued equality conditions into "in" clauses; single values keep the "=" form. - * - * @param fieldValuesMap field values - * @return where clause string - */ - private StringBuilder buildEqualClause(Map<String, List<String>> fieldValuesMap) { - StringBuilder sb = new StringBuilder(); - if (!fieldValuesMap.isEmpty()) { - - boolean first = true; - for (Map.Entry<String, List<String>> entry : fieldValuesMap.entrySet()) { - if (!first) { - sb.append(" and "); - } - first = false; - sb.append(entry.getKey()); - List<String> values = entry.getValue(); - if (values.size() > 1) { - sb.append(" in ("); - for (String value : values) { - sb.append("'").append(value).append("'").append(", "); - } - sb.setLength(sb.length() - 2); - sb.append(")"); - } else { - sb.append(" = ").append("'").append(values.get(0)).append("'"); - } - } - } - - return sb; - } - - /** - * TODO: for JDBC-based access, this logic could be factored out - * - * @param conditionString condition fields - * @return where clause - */ - private String buildWhereClause(String conditionString) { - - if (conditionString == null || conditionString.isEmpty()) { - return ""; - } - - String[] conditions = conditionString.split(";"); - - StringBuilder whereClause = new StringBuilder(); - Pattern pattern = Pattern.compile("^(\\w+)(>=|<=|<>|!=|=|>|<| LIKE | like )(.*)$"); - - - Map<String, List<String>> equalFieldValuesMap = new HashMap<>(); - - for (String condition : conditions) { - Matcher matcher = pattern.matcher(condition.trim()); - - if (!matcher.matches() || matcher.groupCount() != 3) { - throw new DataproxyException(DataproxyErrorCode.INVALID_PARTITION_SPEC, "Invalid condition format: " + condition); - } - - String column = matcher.group(1).trim(); - String operator = matcher.group(2); - String value = matcher.group(3).trim(); - - if (!columnOrValuePattern.matcher(column).matches()) { - throw new DataproxyException(DataproxyErrorCode.INVALID_PARTITION_SPEC, "Invalid condition format: " + column); - } - - if (!columnOrValuePattern.matcher(value).matches()) { - throw new DataproxyException(DataproxyErrorCode.INVALID_PARTITION_SPEC, "Invalid condition format: " + column); - } - - // Sanitize user-supplied values; adjust as specific needs require - value = value.replace("'", "''"); // naive single-quote escaping - - if ("=".equals(operator)) { - if (equalFieldValuesMap.containsKey(column)) { - equalFieldValuesMap.get(column).add(value); - } else { - equalFieldValuesMap.put(column, new ArrayList<>(List.of(value))); - } - } else { - if (!whereClause.isEmpty()) { - whereClause.append(" and "); - } - whereClause.append(column).append(' ').append(operator).append(" '").append(value).append("'"); - } - } - StringBuilder equalFieldClause = buildEqualClause(equalFieldValuesMap); - - if (whereClause.isEmpty()) { - return equalFieldClause.toString(); - } - - if (!equalFieldClause.isEmpty()) { - whereClause.append(" and ").append(equalFieldClause); - } - return whereClause.toString(); - } - - private void toArrowVector(Record record, VectorSchemaRoot root, int rowIndex) throws IOException { + private void toArrowVector(Record record, VectorSchemaRoot root, int rowIndex) { FieldVector vector; String columnName; for (Field field : schema.getFields()) { @@ -327,7 +190,8 @@ private void setValue(ArrowType type, FieldVector vector, int rowIndex, Record r) { } case Utf8 -> { if (vector instanceof VarCharVector varcharVector) { - varcharVector.setSafe(rowIndex, record.getString(columnName).getBytes(StandardCharsets.UTF_8)); + // Record#getBytes defaults to UTF-8 + varcharVector.setSafe(rowIndex, record.getBytes(columnName)); } else { log.warn("Unsupported type: {}", type); } diff --git a/dataproxy-server/src/main/resources/application.yaml b/dataproxy-server/src/main/resources/application.yaml index
086ee20..614120e 100644 --- a/dataproxy-server/src/main/resources/application.yaml +++ b/dataproxy-server/src/main/resources/application.yaml @@ -1,11 +1,11 @@ # Copyright 2024 Ant Group Co., Ltd. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/dataproxy_sdk/bazel/repositories.bzl b/dataproxy_sdk/bazel/repositories.bzl index 425af23..39c9e3b 100644 --- a/dataproxy_sdk/bazel/repositories.bzl +++ b/dataproxy_sdk/bazel/repositories.bzl @@ -55,10 +55,10 @@ def _kuscia(): http_archive, name = "kuscia", urls = [ - "https://github.com/secretflow/kuscia/archive/refs/tags/v0.9.0b0.tar.gz", + "https://github.com/secretflow/kuscia/archive/refs/tags/v0.11.0b0.tar.gz", ], - strip_prefix = "kuscia-0.9.0b0", - sha256 = "851455f4a3ba70850c8a751a78ebfbbb9fd6d78ec902d0cbf32c2c565d1c8410", + strip_prefix = "kuscia-0.11.0b0", + sha256 = "c8de425a5f442ba3fa30a9b5943f9fd056efd9ab610ddc2168d5ffcf71224974", ) def _bazel_rules_pkg(): diff --git a/dataproxy_sdk/cc/BUILD.bazel b/dataproxy_sdk/cc/BUILD.bazel index 66e7111..a848258 100644 --- a/dataproxy_sdk/cc/BUILD.bazel +++ b/dataproxy_sdk/cc/BUILD.bazel @@ -68,17 +68,6 @@ dataproxy_cc_library( ], ) -dataproxy_cc_test( - name = "file_help_test", - srcs = ["file_help_test.cc"], - deps = [ - "utils", - ":exception", - ":file_help", - "@org_apache_arrow//:arrow", - ], -) - dataproxy_cc_library( name = "data_proxy_conn", srcs = ["data_proxy_conn.cc"], @@ -102,3 +91,42 @@ dataproxy_cc_library( "@org_apache_arrow//:arrow_flight", ], ) + +dataproxy_cc_test( + name = "file_help_test", + srcs = ["file_help_test.cc"], + deps = [ + ":file_help", + "//dataproxy_sdk/test:random", + "//dataproxy_sdk/test:test_utils", + ], +) + +dataproxy_cc_test( + name = "data_proxy_conn_test", + srcs = ["data_proxy_conn_test.cc"], + deps = [ + ":data_proxy_conn", + "//dataproxy_sdk/test:data_mesh_mock", + "//dataproxy_sdk/test:random", + ], +) + +dataproxy_cc_test( + name = "data_proxy_file_test", + srcs = ["data_proxy_file_test.cc"], + deps = [ + ":data_proxy_file", + "//dataproxy_sdk/test:data_mesh_mock", + "//dataproxy_sdk/test:random", + "//dataproxy_sdk/test:test_utils", + ], +) + +dataproxy_cc_test( + name = "data_proxy_pb_test", + srcs = ["data_proxy_pb_test.cc"], + deps = [ + ":proto", + ], +) diff --git a/dataproxy_sdk/cc/data_proxy_conn.cc b/dataproxy_sdk/cc/data_proxy_conn.cc index a76223d..80263e7 100644 --- a/dataproxy_sdk/cc/data_proxy_conn.cc +++ b/dataproxy_sdk/cc/data_proxy_conn.cc @@ -135,6 +135,12 @@ FlightStreamReaderWrapper::ReadRecordBatch() { return chunk.data; } +std::shared_ptr<arrow::Schema> FlightStreamReaderWrapper::GetSchema() { + std::shared_ptr<arrow::Schema> ret; + ASSIGN_ARROW_OR_THROW(ret, stream_reader_->GetSchema()); + return ret; +} + DataProxyConn::DataProxyConn() { impl_ = std::make_unique<Impl>(); } diff --git a/dataproxy_sdk/cc/data_proxy_conn.h b/dataproxy_sdk/cc/data_proxy_conn.h index 606fbcd..5c60bbb 100644 --- a/dataproxy_sdk/cc/data_proxy_conn.h +++ b/dataproxy_sdk/cc/data_proxy_conn.h @@ -48,6 +48,7 @@ class DoPutResultWrapper { class FlightStreamReaderWrapper { public: std::shared_ptr<arrow::RecordBatch> ReadRecordBatch(); + std::shared_ptr<arrow::Schema> GetSchema(); public: FlightStreamReaderWrapper( diff --git a/dataproxy_sdk/cc/data_proxy_conn_test.cc b/dataproxy_sdk/cc/data_proxy_conn_test.cc new file mode 100644 index 0000000..7260856 --- /dev/null +++ b/dataproxy_sdk/cc/data_proxy_conn_test.cc @@ -0,0 +1,90 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "dataproxy_sdk/cc/data_proxy_conn.h" + +#include <memory> +#include <string> + +#include "gtest/gtest.h" + +#include "dataproxy_sdk/cc/exception.h" +#include "dataproxy_sdk/test/data_mesh_mock.h" +#include "dataproxy_sdk/test/random.h" + +namespace dataproxy_sdk { + +static const std::string kDataMeshAddress = "127.0.0.1:23333"; +static const std::string kDataProxyAddress = "127.0.0.1:23334"; + +class TestDataProxyConn : public ::testing::Test { + public: + void SetUp() { + data_mesh_ = DataMeshMock::Make(); + CHECK_ARROW_OR_THROW(data_mesh_->StartServer(kDataMeshAddress)); + + data_ = RandomBatchGenerator::ExampleGenerate(); + } + + protected: + std::shared_ptr<DataMeshMock> data_mesh_; + std::shared_ptr<arrow::RecordBatch> data_; +}; + +class TestDataProxyConnUseDP : public ::testing::Test { + public: + void SetUp() { + data_mesh_ = DataMeshMock::Make(); + CHECK_ARROW_OR_THROW(data_mesh_->StartServer(kDataProxyAddress, true)); + + data_ = RandomBatchGenerator::ExampleGenerate(); + } + + protected: + std::unique_ptr<DataMeshMock> data_mesh_; + std::shared_ptr<arrow::RecordBatch> data_; +}; + +std::shared_ptr<arrow::RecordBatch> DataProxyConnPutAndGet( + const std::string& ip, const std::shared_ptr<arrow::RecordBatch>& batch) { + arrow::flight::FlightClientOptions options = + arrow::flight::FlightClientOptions::Defaults(); + auto dp_conn = DataProxyConn::Connect(ip, false, options); + auto descriptor = arrow::flight::FlightDescriptor::Command(""); + + auto put_result = dp_conn->DoPut(descriptor, batch->schema()); + put_result->WriteRecordBatch(*batch); + put_result->Close(); + + std::shared_ptr<arrow::RecordBatch> result_batch; + auto get_result = dp_conn->DoGet(descriptor); + result_batch = get_result->ReadRecordBatch(); + + dp_conn->Close(); + return result_batch; +} + +TEST_F(TestDataProxyConn, PutAndGet) { + auto result = DataProxyConnPutAndGet(kDataMeshAddress, data_); + + EXPECT_TRUE(data_->Equals(*result)); +} + +TEST_F(TestDataProxyConnUseDP, PutAndGet) { + auto result = DataProxyConnPutAndGet(kDataProxyAddress, data_); + + EXPECT_TRUE(data_->Equals(*result)); +} + +} // namespace dataproxy_sdk diff --git a/dataproxy_sdk/cc/data_proxy_file.cc b/dataproxy_sdk/cc/data_proxy_file.cc index b4270d3..785f233 100644 --- a/dataproxy_sdk/cc/data_proxy_file.cc +++ b/dataproxy_sdk/cc/data_proxy_file.cc @@ -19,11 +19,13 @@ #include "arrow/buffer.h" #include "arrow/flight/api.h" +#include "arrow/util/byte_size.h" +#include "spdlog/spdlog.h" + #include "dataproxy_sdk/cc/data_proxy_conn.h" #include "dataproxy_sdk/cc/exception.h" #include "dataproxy_sdk/cc/file_help.h" #include "dataproxy_sdk/cc/utils.h" -#include "spdlog/spdlog.h" namespace dataproxy_sdk { @@ -45,6 +47,17 @@ class DataProxyFile::Impl { config.has_tls_config(), options); } + FileHelpWrite::Options BuildWriteOptions(const proto::DownloadInfo &info) { +
FileHelpWrite::Options options = FileHelpWrite::Options::Defaults(); + if (info.has_orc_info()) { + options.compression = + static_cast<arrow::Compression::type>(info.orc_info().compression()); + options.compression_block_size = info.orc_info().compression_block_size(); + options.stripe_size = info.orc_info().stripe_size(); + } + return options; + } + void DownloadFile(const proto::DownloadInfo &info, const std::string &file_path, proto::FileFormat file_format) { @@ -57,8 +70,15 @@ class DataProxyFile::Impl { auto stream_reader = dp_conn_->DoGet(descriptor); // 4. Download the data from the read stream + auto write_options = BuildWriteOptions(info); std::unique_ptr<FileHelpWrite> file_write = - FileHelpWrite::Make(file_format, file_path); + FileHelpWrite::Make(file_format, file_path, write_options); + // When no data is transferred, we still need to produce a file that carries the schema + std::shared_ptr<arrow::RecordBatch> empty_batch; + ASSIGN_ARROW_OR_THROW( + empty_batch, arrow::RecordBatch::MakeEmpty(stream_reader->GetSchema())); + file_write->DoWrite(empty_batch); + while (true) { auto record_batch = stream_reader->ReadRecordBatch(); if (record_batch == nullptr) { @@ -93,6 +113,12 @@ class DataProxyFile::Impl { auto put_result = dp_conn_->DoPut(descriptor, file_read->Schema()); + static const int64_t kMaxBatchSize = 64 * 1024 * 1024; + int64_t slice_size = 0; + int64_t slice_len = 0; + int64_t slice_offset = 0; + int64_t slice_left = 0; + int64_t batch_size = 0; // 5. Write the file data to the write stream while (true) { std::shared_ptr<arrow::RecordBatch> batch; @@ -100,7 +126,21 @@ if (batch.get() == nullptr) { break; } - put_result->WriteRecordBatch(*batch); + + ASSIGN_DP_OR_THROW(batch_size, arrow::util::ReferencedBufferSize(*batch)); + if (batch_size > kMaxBatchSize) { + slice_size = (batch_size + kMaxBatchSize - 1) / kMaxBatchSize; + slice_left = batch->num_rows(); + slice_len = (slice_left + slice_size - 1) / slice_size; + while (slice_left > 0) { + put_result->WriteRecordBatch( + *(batch->Slice(slice_offset, std::min(slice_len, slice_left)))); + slice_offset += slice_len; + slice_left -= slice_len; + } + } else { + put_result->WriteRecordBatch(*batch); + } } put_result->Close(); @@ -160,11 +200,20 @@ class DataProxyFile::Impl { std::unique_ptr<DataProxyFile> DataProxyFile::Make( const proto::DataProxyConfig &config) { + proto::DataProxyConfig dp_config; + dp_config.CopyFrom(config); + GetDPConfigValueFromEnv(&dp_config); + std::unique_ptr<DataProxyFile> ret = std::make_unique<DataProxyFile>(); - ret->impl_->Init(config); + ret->impl_->Init(dp_config); return ret; } +std::unique_ptr<DataProxyFile> DataProxyFile::Make() { + proto::DataProxyConfig config; + return DataProxyFile::Make(config); +} + DataProxyFile::DataProxyFile() { impl_ = std::make_unique<Impl>(); } diff --git a/dataproxy_sdk/cc/data_proxy_file.h b/dataproxy_sdk/cc/data_proxy_file.h index 7b1f02b..fb4403f 100644 --- a/dataproxy_sdk/cc/data_proxy_file.h +++ b/dataproxy_sdk/cc/data_proxy_file.h @@ -26,6 +26,8 @@ class DataProxyFile { static std::unique_ptr<DataProxyFile> Make( const proto::DataProxyConfig& config); + static std::unique_ptr<DataProxyFile> Make(); + public: DataProxyFile(); ~DataProxyFile(); diff --git a/dataproxy_sdk/cc/data_proxy_file_test.cc b/dataproxy_sdk/cc/data_proxy_file_test.cc new file mode 100644 index 0000000..66525a5 --- /dev/null +++ b/dataproxy_sdk/cc/data_proxy_file_test.cc @@ -0,0 +1,146 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "dataproxy_sdk/cc/data_proxy_file.h" + +#include <iostream> + +#include "gtest/gtest.h" + +#include "dataproxy_sdk/cc/exception.h" +#include "dataproxy_sdk/cc/file_help.h" +#include "dataproxy_sdk/test/data_mesh_mock.h" +#include "dataproxy_sdk/test/random.h" +#include "dataproxy_sdk/test/test_utils.h" + +namespace dataproxy_sdk { + +class TestDataProxyFile : public ::testing::Test { + public: + void SetUp() { + data_mesh_ = DataMeshMock::Make(); + CHECK_ARROW_OR_THROW(data_mesh_->StartServer(kDataMeshAddress)); + + dataproxy_sdk::proto::DataProxyConfig sdk_config; + sdk_config.set_data_proxy_addr(kDataMeshAddress); + data_proxy_file_ = DataProxyFile::Make(sdk_config); + + data_ = RandomBatchGenerator::ExampleGenerate(); + } + + protected: + std::shared_ptr<DataMeshMock> data_mesh_; + std::shared_ptr<arrow::RecordBatch> data_; + std::unique_ptr<DataProxyFile> data_proxy_file_; + const std::string kDataMeshAddress = "127.0.0.1:23335"; +}; + +TEST_F(TestDataProxyFile, UploadAndDownload) { + const std::string upload_file = "tmp_upload.orc"; + const std::string download_file = "tmp_download.orc"; + auto write_options = FileHelpWrite::Options::Defaults(); + auto file_writer = FileHelpWrite::Make(GetFileFormat(upload_file), + upload_file, write_options); + file_writer->DoWrite(data_); + file_writer->DoClose(); + + proto::UploadInfo upload_info; + upload_info.set_domaindata_id(""); + upload_info.set_type("table"); + for (const auto& field : data_->schema()->fields()) { + auto column = upload_info.add_columns(); + column->set_name(field->name()); + column->set_type(field->type()->name()); + } + data_proxy_file_->UploadFile(upload_info, upload_file, + GetFileFormat(upload_file)); + + proto::DownloadInfo download_info; + download_info.set_domaindata_id("test"); + data_proxy_file_->DownloadFile(download_info, download_file, + GetFileFormat(download_file)); + data_proxy_file_->Close(); + + auto read_options = FileHelpRead::Options::Defaults(); + auto file_reader = FileHelpRead::Make(GetFileFormat(download_file), + download_file, read_options); + std::shared_ptr<arrow::RecordBatch> result_batch; + file_reader->DoRead(&result_batch); + file_reader->DoClose(); + + std::cout << data_->ToString() << std::endl; + std::cout << result_batch->ToString() << std::endl; + + EXPECT_TRUE(data_->Equals(*result_batch)); +} + +class TestDataProxyFileEmpty : public ::testing::Test { + public: + void SetUp() { + data_mesh_ = DataMeshMock::Make(); + CHECK_ARROW_OR_THROW(data_mesh_->StartServer(kDataMeshAddress)); + + dataproxy_sdk::proto::DataProxyConfig sdk_config; + sdk_config.set_data_proxy_addr(kDataMeshAddress); + data_proxy_file_ = DataProxyFile::Make(sdk_config); + + data_ = RandomBatchGenerator::ExampleGenerate(0); + } + + protected: + std::shared_ptr<DataMeshMock> data_mesh_; + std::shared_ptr<arrow::RecordBatch> data_; + std::unique_ptr<DataProxyFile> data_proxy_file_; + const std::string kDataMeshAddress = "127.0.0.1:23336"; +}; + +TEST_F(TestDataProxyFileEmpty, UploadAndDownload) { + const std::string upload_file = "empty_upload.orc"; + const std::string download_file = "empty_download.orc"; + auto write_options = FileHelpWrite::Options::Defaults(); + auto file_writer = FileHelpWrite::Make(GetFileFormat(upload_file), + upload_file, write_options); + file_writer->DoWrite(data_); + file_writer->DoClose(); + + proto::UploadInfo upload_info; + upload_info.set_domaindata_id(""); + upload_info.set_type("table"); + for (const auto& field : data_->schema()->fields()) { + auto column = upload_info.add_columns(); + column->set_name(field->name()); + column->set_type(field->type()->name()); + } + data_proxy_file_->UploadFile(upload_info, upload_file, + GetFileFormat(upload_file)); + + proto::DownloadInfo download_info; + download_info.set_domaindata_id("test"); + data_proxy_file_->DownloadFile(download_info, download_file, + GetFileFormat(download_file)); + data_proxy_file_->Close(); + + auto read_options = FileHelpRead::Options::Defaults(); + auto file_reader = FileHelpRead::Make(GetFileFormat(download_file), + download_file, read_options); + std::shared_ptr<arrow::RecordBatch> result_batch; + file_reader->DoRead(&result_batch); + + EXPECT_TRUE(file_reader->Schema()->Equals(data_->schema())); + file_reader->DoClose(); + + EXPECT_TRUE(result_batch == nullptr); +} + +} // namespace dataproxy_sdk diff --git a/dataproxy_sdk/cc/data_proxy_pb.cc b/dataproxy_sdk/cc/data_proxy_pb.cc index 48bfe0b..5d2a786 100644 --- a/dataproxy_sdk/cc/data_proxy_pb.cc +++ b/dataproxy_sdk/cc/data_proxy_pb.cc @@ -24,7 +24,7 @@ inline proto::ContentType FormatToContentType(proto::FileFormat format) { return proto::ContentType::RAW; case proto::FileFormat::CSV: case proto::FileFormat::ORC: - return proto::ContentType::CSV; + return proto::ContentType::Table; default: DATAPROXY_THROW("do not support this type of format:{}", proto::FileFormat_Name(format)); @@ -50,8 +50,7 @@ google::protobuf::Any BuildDownloadAny(const proto::DownloadInfo& info, google::protobuf::Any any; proto::CommandDomainDataQuery msg; msg.set_domaindata_id(info.domaindata_id()); - // requires a newer kuscia version + // msg.set_partition_spec(info.partition_spec()); + msg.set_partition_spec(info.partition_spec()); msg.set_content_type(FormatToContentType(file_format)); any.PackFrom(msg); @@ -64,6 +63,11 @@ google::protobuf::Any BuildUploadAny(const proto::UploadInfo& info, proto::CommandDomainDataUpdate msg; msg.set_domaindata_id(info.domaindata_id()); msg.set_content_type(FormatToContentType(file_format)); + if (file_format != proto::FileFormat::BINARY) { + msg.mutable_file_write_options() + ->mutable_csv_options() + ->set_field_delimiter(","); + } any.PackFrom(msg); return any; @@ -104,7 +108,7 @@ proto::CreateDomainDataResponse GetActionCreateDomainDataResponse( void CheckUploadInfo(const proto::UploadInfo& info) { // Enum: table,model,rule,report,unknown if (info.type() != "table" && info.type() != "model" && - info.type() != "rule" && info.type() != "report") { + info.type() != "rule" && info.type() != "serving_model") { DATAPROXY_THROW("type[{}] not support in UploadInfo!", info.type()); } @@ -114,4 +118,30 @@ void CheckUploadInfo(const proto::UploadInfo& info) { } } -} // namespace dataproxy_sdk \ No newline at end of file +static inline char* GetEnvValue(std::string_view key) { + if (char* env_p = std::getenv(key.data())) { + if (strlen(env_p) != 0) { + return env_p; + } + } + return nullptr; +} + +void GetDPConfigValueFromEnv(proto::DataProxyConfig* config) { + if (config == nullptr) return; + + if (char* env_value = GetEnvValue("CLIENT_CERT_FILE")) { + config->mutable_tls_config()->set_certificate_path(env_value); + } + if (char* env_value = GetEnvValue("CLIENT_PRIVATE_KEY_FILE")) { + config->mutable_tls_config()->set_private_key_path(env_value); + } +
if (char* env_value = GetEnvValue("TRUSTED_CA_FILE")) { + config->mutable_tls_config()->set_ca_file_path(env_value); + } + if (char* env_value = GetEnvValue("KUSCIA_DATA_MESH_ADDR")) { + config->set_data_proxy_addr(env_value); + } +} + +} // namespace dataproxy_sdk diff --git a/dataproxy_sdk/cc/data_proxy_pb.h b/dataproxy_sdk/cc/data_proxy_pb.h index 1d13f93..df23ab9 100644 --- a/dataproxy_sdk/cc/data_proxy_pb.h +++ b/dataproxy_sdk/cc/data_proxy_pb.h @@ -44,4 +44,6 @@ proto::CreateDomainDataResponse GetActionCreateDomainDataResponse( void CheckUploadInfo(const proto::UploadInfo& info); -} // namespace dataproxy_sdk \ No newline at end of file +void GetDPConfigValueFromEnv(proto::DataProxyConfig* config); + +} // namespace dataproxy_sdk diff --git a/dataproxy_sdk/cc/data_proxy_pb_test.cc b/dataproxy_sdk/cc/data_proxy_pb_test.cc new file mode 100644 index 0000000..25e4c43 --- /dev/null +++ b/dataproxy_sdk/cc/data_proxy_pb_test.cc @@ -0,0 +1,144 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "dataproxy_sdk/cc/data_proxy_pb.h" + +#include <unordered_map> + +#include "gtest/gtest.h" + +namespace dataproxy_sdk { + +TEST(DataProxyPD, GetConfigFromEnv) { + std::unordered_map<std::string, std::string> env_values = { + {"CLIENT_CERT_FILE", "aaa"}, + {"CLIENT_PRIVATE_KEY_FILE", "bbb"}, + {"TRUSTED_CA_FILE", "ccc"}, + {"KUSCIA_DATA_MESH_ADDR", "ddd"}}; + for (const auto& it : env_values) { + setenv(it.first.c_str(), it.second.c_str(), 0); + } + + proto::DataProxyConfig config; + GetDPConfigValueFromEnv(&config); + + for (const auto& it : env_values) { + unsetenv(it.first.c_str()); + } + + std::cout << config.DebugString() << std::endl; + + EXPECT_EQ(config.tls_config().certificate_path(), + env_values["CLIENT_CERT_FILE"]); + EXPECT_EQ(config.tls_config().private_key_path(), + env_values["CLIENT_PRIVATE_KEY_FILE"]); + EXPECT_EQ(config.tls_config().ca_file_path(), env_values["TRUSTED_CA_FILE"]); + EXPECT_EQ(config.data_proxy_addr(), env_values["KUSCIA_DATA_MESH_ADDR"]); +} + +TEST(DataProxyPD, GetConfig) { + static const std::string kCertificate = "eee"; + static const std::string kPrivateKey = "fff"; + static const std::string kCa = "ggg"; + static const std::string kAddress = "hhh"; + + proto::DataProxyConfig config; + config.mutable_tls_config()->set_certificate_path(kCertificate); + config.mutable_tls_config()->set_private_key_path(kPrivateKey); + config.mutable_tls_config()->set_ca_file_path(kCa); + config.set_data_proxy_addr(kAddress); + + GetDPConfigValueFromEnv(&config); + + std::cout << config.DebugString() << std::endl; + + EXPECT_EQ(config.tls_config().certificate_path(), kCertificate); + EXPECT_EQ(config.tls_config().private_key_path(), kPrivateKey); + EXPECT_EQ(config.tls_config().ca_file_path(), kCa); + EXPECT_EQ(config.data_proxy_addr(), kAddress); +} + +TEST(DataProxyPD, GetConfigWithNullEnv) { + static const std::string kCertificate = "iii"; + static const std::string kPrivateKey = "jjj"; + static const std::string kCa = "kkk"; + static const std::string kAddress = "lll"; + + std::unordered_map<std::string, std::string> env_values = { + {"CLIENT_CERT_FILE", ""}, + {"CLIENT_PRIVATE_KEY_FILE", ""}, + {"TRUSTED_CA_FILE", ""}, + {"KUSCIA_DATA_MESH_ADDR", ""}}; + for (const auto& it : env_values) { + setenv(it.first.c_str(), it.second.c_str(), 0); + } + + proto::DataProxyConfig config; + config.mutable_tls_config()->set_certificate_path(kCertificate); + config.mutable_tls_config()->set_private_key_path(kPrivateKey); + config.mutable_tls_config()->set_ca_file_path(kCa); + config.set_data_proxy_addr(kAddress); + + GetDPConfigValueFromEnv(&config); + + for (const auto& it : env_values) { + unsetenv(it.first.c_str()); + } + + std::cout << config.DebugString() << std::endl; + + EXPECT_EQ(config.tls_config().certificate_path(), kCertificate); + EXPECT_EQ(config.tls_config().private_key_path(), kPrivateKey); + EXPECT_EQ(config.tls_config().ca_file_path(), kCa); + EXPECT_EQ(config.data_proxy_addr(), kAddress); +} + +TEST(DataProxyPD, GetConfigWithEnv) { + static const std::string kCertificate = "mmm"; + static const std::string kPrivateKey = "nnn"; + static const std::string kCa = "ooo"; + static const std::string kAddress = "ppp"; + + std::unordered_map<std::string, std::string> env_values = { + {"CLIENT_CERT_FILE", "qqq"}, + {"CLIENT_PRIVATE_KEY_FILE", "rrr"}, + {"TRUSTED_CA_FILE", "sss"}, + {"KUSCIA_DATA_MESH_ADDR", "ttt"}}; + for (const auto& it : env_values) { + setenv(it.first.c_str(), it.second.c_str(), 0); + } + + proto::DataProxyConfig config; + config.mutable_tls_config()->set_certificate_path(kCertificate); + config.mutable_tls_config()->set_private_key_path(kPrivateKey); + config.mutable_tls_config()->set_ca_file_path(kCa); + config.set_data_proxy_addr(kAddress); + + GetDPConfigValueFromEnv(&config); + + for (const auto& it : env_values) { + unsetenv(it.first.c_str()); + } + + std::cout << config.DebugString() << std::endl; + + EXPECT_EQ(config.tls_config().certificate_path(), + env_values["CLIENT_CERT_FILE"]); + EXPECT_EQ(config.tls_config().private_key_path(), + env_values["CLIENT_PRIVATE_KEY_FILE"]); + EXPECT_EQ(config.tls_config().ca_file_path(), env_values["TRUSTED_CA_FILE"]); + EXPECT_EQ(config.data_proxy_addr(), env_values["KUSCIA_DATA_MESH_ADDR"]); +} + +} // namespace dataproxy_sdk diff --git a/dataproxy_sdk/cc/exception.h b/dataproxy_sdk/cc/exception.h index fefa1de..ac20290 100644 --- a/dataproxy_sdk/cc/exception.h +++ b/dataproxy_sdk/cc/exception.h @@ -50,4 +50,13 @@ namespace dataproxy_sdk { lhs = std::move(__s__).ValueOrDie(); \ } while (false) +#define ASSIGN_DP_OR_THROW(lhs, rexpr) \ + auto&& _error_or_value = (rexpr); \ + do { \ + if ((__builtin_expect(!!(!(_error_or_value).ok()), 0))) { \ + DATAPROXY_THROW((_error_or_value).status().message()); \ + } \ + } while (0); \ + lhs = std::move(_error_or_value).ValueUnsafe(); + } // namespace dataproxy_sdk \ No newline at end of file diff --git a/dataproxy_sdk/cc/file_help.cc b/dataproxy_sdk/cc/file_help.cc index 10d30a3..2028be0 100644 --- a/dataproxy_sdk/cc/file_help.cc +++ b/dataproxy_sdk/cc/file_help.cc @@ -20,14 +20,19 @@ #include "arrow/builder.h" #include "arrow/csv/api.h" #include "arrow/io/api.h" +#include "arrow/ipc/writer.h" + #include "dataproxy_sdk/cc/exception.h" -#include "file_help.h" namespace dataproxy_sdk { class BinaryFileWrite : public FileHelpWrite { public: void DoWrite(std::shared_ptr<arrow::RecordBatch>& record_batch) { + if (record_batch->num_rows() == 0) { + return; + } + DATAPROXY_ENFORCE_EQ(record_batch->num_columns(), 1); auto binary_array = @@ -38,7 +43,8 @@ class BinaryFileWrite : public FileHelpWrite { void DoClose() {
CHECK_ARROW_OR_THROW(out_stream_->Close()); } protected: - void DoOpen(const std::string& file_name) { + void DoOpen(const std::string& file_name, + const FileHelpWrite::Options& options) { ASSIGN_ARROW_OR_THROW(out_stream_, arrow::io::FileOutputStream::Open(file_name)); } @@ -50,20 +56,34 @@ class BinaryFileWrite : public FileHelpWrite { private: std::shared_ptr<arrow::io::FileOutputStream> out_stream_; }; class CSVFileWrite : public FileHelpWrite { public: void DoWrite(std::shared_ptr<arrow::RecordBatch>& record_batch) { - CHECK_ARROW_OR_THROW(arrow::csv::WriteCSV( - *record_batch, arrow::csv::WriteOptions::Defaults(), - out_stream_.get())); + // Each call to WriteCSV emits the column header into the file, so MakeCSVWriter is invoked only on the first write + if (!writer_) { + ASSIGN_ARROW_OR_THROW( + writer_, arrow::csv::MakeCSVWriter(out_stream_, + record_batch->schema(), options_)); + } + CHECK_ARROW_OR_THROW(writer_->WriteRecordBatch(*record_batch)); + } + void DoClose() { + if (writer_) { + CHECK_ARROW_OR_THROW(writer_->Close()); + } + CHECK_ARROW_OR_THROW(out_stream_->Close()); } - void DoClose() { CHECK_ARROW_OR_THROW(out_stream_->Close()); } protected: - void DoOpen(const std::string& file_name) { + void DoOpen(const std::string& file_name, + const FileHelpWrite::Options& options) { + options_ = arrow::csv::WriteOptions::Defaults(); + options_.quoting_style = arrow::csv::QuotingStyle::None; ASSIGN_ARROW_OR_THROW(out_stream_, arrow::io::FileOutputStream::Open(file_name)); } private: std::shared_ptr<arrow::io::FileOutputStream> out_stream_; + std::shared_ptr<arrow::ipc::RecordBatchWriter> writer_; + arrow::csv::WriteOptions options_; }; class ORCFileWrite : public FileHelpWrite { @@ -71,18 +91,25 @@ public: void DoWrite(std::shared_ptr<arrow::RecordBatch>& record_batch) { CHECK_ARROW_OR_THROW(orc_writer_->Write(*record_batch)); } + void DoClose() { CHECK_ARROW_OR_THROW(orc_writer_->Close()); CHECK_ARROW_OR_THROW(out_stream_->Close()); }; protected: - void DoOpen(const std::string& file_name) { + void DoOpen(const std::string& file_name, + const FileHelpWrite::Options& options) { ASSIGN_ARROW_OR_THROW(out_stream_, arrow::io::FileOutputStream::Open(file_name)); - ASSIGN_ARROW_OR_THROW( - orc_writer_, - arrow::adapters::orc::ORCFileWriter::Open(out_stream_.get())); + + arrow::adapters::orc::WriteOptions write_opts; + write_opts.compression = options.compression; + write_opts.compression_block_size = options.compression_block_size; + write_opts.stripe_size = options.stripe_size; + ASSIGN_ARROW_OR_THROW(orc_writer_, + arrow::adapters::orc::ORCFileWriter::Open( + out_stream_.get(), write_opts)); } private: @@ -91,7 +118,8 @@ }; std::unique_ptr<FileHelpWrite> FileHelpWrite::Make( - proto::FileFormat file_format, const std::string& file_name) { + proto::FileFormat file_format, const std::string& file_name, + const FileHelpWrite::Options& options) { std::unique_ptr<FileHelpWrite> ret; switch (file_format) { case proto::FileFormat::CSV: @@ -108,18 +136,18 @@ std::unique_ptr<FileHelpWrite> FileHelpWrite::Make( proto::FileFormat_Name(file_format)); break; } - ret->DoOpen(file_name); + ret->DoOpen(file_name, options); return ret; } -class BinaryFileRead : public FileHelpRead { - public: - BinaryFileRead(FileHelpRead::Options options) : FileHelpRead(options) {} - ~BinaryFileRead() = default; +FileHelpWrite::Options FileHelpWrite::Options::Defaults() { + return FileHelpWrite::Options(); +} +class BinaryFileRead : public FileHelpRead { private: - const int64_t kReadBytesLen = 128 * 1024; - const int64_t kchunksNum = 8; + static const int64_t kReadBytesLen = 128 * 1024; + static const int64_t kChunkNum = 8; public: static std::shared_ptr<arrow::Schema> kBinaryFileSchema; @@ -127,7 +155,7 @@ class
BinaryFileRead : public FileHelpRead { public: void DoRead(std::shared_ptr<arrow::RecordBatch>* record_batch) { arrow::BinaryBuilder binary_build; - for (int i = 0; i < kchunksNum; ++i) { + for (int i = 0; i < kChunkNum; ++i) { std::shared_ptr<arrow::Buffer> buffer; ASSIGN_ARROW_OR_THROW(buffer, read_stream_->Read(kReadBytesLen)); CHECK_ARROW_OR_THROW(binary_build.Append(buffer->data(), buffer->size())); @@ -145,7 +173,8 @@ class BinaryFileRead : public FileHelpRead { std::shared_ptr<arrow::Schema> Schema() { return kBinaryFileSchema; } protected: - void DoOpen(const std::string& file_name) { + void DoOpen(const std::string& file_name, + const FileHelpRead::Options& options) { std::shared_ptr<arrow::io::ReadableFile> file_stream; ASSIGN_ARROW_OR_THROW(file_stream, arrow::io::ReadableFile::Open(file_name)); @@ -163,17 +192,6 @@ std::shared_ptr<arrow::Schema> BinaryFileRead::kBinaryFileSchema = arrow::schema({arrow::field("binary_data", arrow::binary())}); class CSVFileRead : public FileHelpRead { - public: - CSVFileRead(FileHelpRead::Options options) - : FileHelpRead(options), - convert_options_(arrow::csv::ConvertOptions::Defaults()) { - for (auto& pair : options.column_types) { - convert_options_.column_types.emplace(pair.first, pair.second); - convert_options_.include_columns.push_back(pair.first); - } - } - ~CSVFileRead() = default; - public: void DoRead(std::shared_ptr<arrow::RecordBatch>* record_batch) { CHECK_ARROW_OR_THROW(file_reader_->ReadNext(record_batch)); @@ -182,33 +200,31 @@ std::shared_ptr<arrow::Schema> Schema() { return file_reader_->schema(); } protected: - void DoOpen(const std::string& file_name) { + void DoOpen(const std::string& file_name, + const FileHelpRead::Options& options) { std::shared_ptr<arrow::io::ReadableFile> file_stream; ASSIGN_ARROW_OR_THROW(file_stream, arrow::io::ReadableFile::Open(file_name)); + + arrow::csv::ConvertOptions convert_options = + arrow::csv::ConvertOptions::Defaults(); + for (auto& pair : options.column_types) { + convert_options.column_types.emplace(pair.first, pair.second); + convert_options.include_columns.push_back(pair.first); + } ASSIGN_ARROW_OR_THROW( file_reader_, arrow::csv::StreamingReader::Make( arrow::io::default_io_context(), file_stream, arrow::csv::ReadOptions::Defaults(), - arrow::csv::ParseOptions::Defaults(), convert_options_)); + arrow::csv::ParseOptions::Defaults(), convert_options)); } private: std::shared_ptr<arrow::csv::StreamingReader> file_reader_; - arrow::csv::ConvertOptions convert_options_; }; class ORCFileRead : public FileHelpRead { - public: - ORCFileRead(FileHelpRead::Options options) - : FileHelpRead(options), current_stripe_(0) { - for (auto& pair : options.column_types) { - include_names_.push_back(pair.first); - } - } - ~ORCFileRead() = default; - public: void DoRead(std::shared_ptr<arrow::RecordBatch>* record_batch) { if (current_stripe_ >= orc_reader_->NumberOfStripes()) return; @@ -230,7 +246,12 @@ } protected: - void DoOpen(const std::string& file_name) { + void DoOpen(const std::string& file_name, + const FileHelpRead::Options& options) { + for (auto& pair : options.column_types) { + include_names_.push_back(pair.first); + } + ASSIGN_ARROW_OR_THROW(file_stream_, arrow::io::ReadableFile::Open(file_name)); ASSIGN_ARROW_OR_THROW(orc_reader_, @@ -239,7 +260,7 @@ } private: - int64_t current_stripe_; + int64_t current_stripe_ = 0; std::unique_ptr<arrow::adapters::orc::ORCFileReader> orc_reader_; std::shared_ptr<arrow::io::ReadableFile> file_stream_; std::vector<std::string> include_names_; @@ -251,20 +272,20 @@ std::unique_ptr<FileHelpRead> FileHelpRead::Make( std::unique_ptr<FileHelpRead> ret; switch (file_format) { case proto::FileFormat::CSV: - ret = std::make_unique<CSVFileRead>(options); + ret = std::make_unique<CSVFileRead>(); break; case proto::FileFormat::BINARY: - ret = std::make_unique<BinaryFileRead>(options); + ret = std::make_unique<BinaryFileRead>(); break; case proto::FileFormat::ORC: - ret = std::make_unique<ORCFileRead>(options); + ret = std::make_unique<ORCFileRead>(); break; default: DATAPROXY_THROW("format[{}] not support.", proto::FileFormat_Name(file_format)); break; } - ret->DoOpen(file_name); + ret->DoOpen(file_name, options); return ret; } diff --git a/dataproxy_sdk/cc/file_help.h b/dataproxy_sdk/cc/file_help.h index 63e0d85..91e9c0e 100644 --- a/dataproxy_sdk/cc/file_help.h +++ b/dataproxy_sdk/cc/file_help.h @@ -18,36 +18,41 @@ #include <memory> #include "arrow/type.h" +#include "arrow/util/type_fwd.h" + #include "dataproxy_sdk/cc/data_proxy_pb.h" namespace dataproxy_sdk { -class FileHelpBase { - public: - FileHelpBase() = default; - virtual ~FileHelpBase() = default; - +class FileHelpWrite { public: - virtual void DoClose() = 0; + struct Options { + // only orc use by sf + arrow::Compression::type compression = arrow::Compression::ZSTD; + // only orc use by sf + int64_t compression_block_size = 256 * 1024; + // only orc use by sf + int64_t stripe_size = 64 * 1024 * 1024; - protected: - virtual void DoOpen(const std::string& file_name) = 0; -}; + static Options Defaults(); + }; -class FileHelpWrite : public FileHelpBase { public: static std::unique_ptr<FileHelpWrite> Make(proto::FileFormat file_format, - const std::string& file_name); + const std::string& file_name, + const Options& options); public: FileHelpWrite() = default; virtual ~FileHelpWrite() = default; public: + virtual void DoOpen(const std::string& file_name, const Options& options) = 0; + virtual void DoClose() = 0; virtual void DoWrite(std::shared_ptr<arrow::RecordBatch>& record_batch) = 0; }; -class FileHelpRead : public FileHelpBase { +class FileHelpRead { public: struct Options { std::unordered_map<std::string, std::shared_ptr<arrow::DataType>> @@ -60,18 +65,16 @@ class FileHelpRead : public FileHelpBase { static std::unique_ptr<FileHelpRead> Make(proto::FileFormat file_format, const std::string& file_name, const Options& options); - static std::unique_ptr<FileHelpRead> Make(proto::FileFormat file_format, - const std::string& file_name) { - return Make(file_format, file_name, Options::Defaults()); - } public: - explicit FileHelpRead(const Options& options){}; + FileHelpRead() = default; virtual ~FileHelpRead() = default; public: - virtual std::shared_ptr<arrow::Schema> Schema() = 0; + virtual void DoOpen(const std::string& file_name, const Options& options) = 0; + virtual void DoClose() = 0; virtual void DoRead(std::shared_ptr<arrow::RecordBatch>* record_batch) = 0; + virtual std::shared_ptr<arrow::Schema> Schema() = 0; }; } // namespace dataproxy_sdk \ No newline at end of file diff --git a/dataproxy_sdk/cc/file_help_test.cc b/dataproxy_sdk/cc/file_help_test.cc index 144014c..be2b4d3 100644 --- a/dataproxy_sdk/cc/file_help_test.cc +++ b/dataproxy_sdk/cc/file_help_test.cc @@ -15,132 +15,39 @@ #include "dataproxy_sdk/cc/file_help.h" #include <memory> -#include <random> -#include <string> -#include <unordered_map> -#include <vector> -#include "arrow/adapters/orc/adapter.h" #include "arrow/builder.h" -#include "arrow/csv/api.h" -#include "arrow/io/api.h" -#include "dataproxy_sdk/cc/exception.h" #include "gtest/gtest.h" -namespace dataproxy_sdk { -class RandomBatchGenerator { - public: - std::shared_ptr<arrow::Schema> schema; - RandomBatchGenerator(std::shared_ptr<arrow::Schema> schema) - : schema(schema){}; - - static std::shared_ptr<arrow::RecordBatch> Generate( - std::shared_ptr<arrow::Schema> schema, int32_t num_rows) { - RandomBatchGenerator generator(schema); - - std::shared_ptr<arrow::RecordBatch> batch; - ASSIGN_ARROW_OR_THROW(batch, generator.Generate(num_rows)); - return batch; - } - - arrow::Result<std::shared_ptr<arrow::RecordBatch>> Generate( - int32_t
num_rows) { - num_rows_ = num_rows; - for (std::shared_ptr<arrow::Field> field : schema->fields()) { - ARROW_RETURN_NOT_OK(arrow::VisitTypeInline(*field->type(), this)); - } - return arrow::RecordBatch::Make(schema, num_rows, arrays_); - } - - // Default implementation - arrow::Status Visit(const arrow::DataType &type) { - return arrow::Status::NotImplemented("Generating data for", - type.ToString()); - } - - arrow::Status Visit(const arrow::BinaryType &) { - auto builder = arrow::BinaryBuilder(); - // std::normal_distribution<> d{ - // /*mean=*/0x05, - // }; // normal distribution - for (int32_t i = 0; i < num_rows_; ++i) { - ARROW_RETURN_NOT_OK(builder.Append("03", 2)); - } - - ARROW_ASSIGN_OR_RAISE(auto array, builder.Finish()); - arrays_.push_back(array); - return arrow::Status::OK(); - } - - arrow::Status Visit(const arrow::DoubleType &) { - auto builder = arrow::DoubleBuilder(); - std::normal_distribution<> d{/*mean=*/5.0, /*stddev=*/2.0}; // normal distribution - for (int32_t i = 0; i < num_rows_; ++i) { - ARROW_RETURN_NOT_OK(builder.Append(d(gen_))); - } - - ARROW_ASSIGN_OR_RAISE(auto array, builder.Finish()); - arrays_.push_back(array); - return arrow::Status::OK(); - } - - arrow::Status Visit(const arrow::Int64Type &) { - // Generate offsets first, which determines number of values in sub-array - std::poisson_distribution<> d{ - /*mean=*/4}; // random non-negative integers following a discrete (Poisson) distribution - auto builder = arrow::Int64Builder(); - for (int32_t i = 0; i < num_rows_; ++i) { - ARROW_RETURN_NOT_OK(builder.Append(d(gen_))); - } - - ARROW_ASSIGN_OR_RAISE(auto array, builder.Finish()); - arrays_.push_back(array); - return arrow::Status::OK(); - } - - protected: - std::random_device rd_{}; - std::mt19937 gen_{rd_()}; // random seed - std::vector<std::shared_ptr<arrow::Array>> arrays_; - int32_t num_rows_; - -}; // RandomBatchGenerator -proto::FileFormat GetFormat(const std::string &file) { - if (file.find(".csv") != std::string::npos) - return proto::FileFormat::CSV; - else if (file.find(".orc") != std::string::npos) - return proto::FileFormat::ORC; - - return proto::FileFormat::BINARY; -} - -static std::shared_ptr<arrow::RecordBatch> GetRecordBatch(int data_num = 2) { - static std::shared_ptr<arrow::Schema> gSchema = arrow::schema( - {arrow::field("x", arrow::int64()), arrow::field("y", arrow::int64()), - arrow::field("z", arrow::int64())}); +#include "dataproxy_sdk/cc/exception.h" +#include "dataproxy_sdk/test/random.h" +#include "dataproxy_sdk/test/test_utils.h" - return RandomBatchGenerator::Generate(gSchema, data_num); -} +namespace dataproxy_sdk { const std::string kCSVFilePath = "test.csv"; const std::string kORCFilePath = "test.orc"; const std::string kBianryFilePath = "test.txt"; +template <typename T> +std::unique_ptr<T> GetDefaultFileHelp(const std::string &file_path) { + auto options = T::Options::Defaults(); + auto ret = T::Make(GetFileFormat(file_path), file_path, options); + return ret; +} + TEST(FileHelpTest, Binary) { std::shared_ptr<arrow::Schema> schema = arrow::schema({arrow::field("binary_data", arrow::binary())}); std::shared_ptr<arrow::RecordBatch> batch = RandomBatchGenerator::Generate(schema, 1); - auto writer = - FileHelpWrite::Make(GetFormat(kBianryFilePath), kBianryFilePath); + auto writer = GetDefaultFileHelp<FileHelpWrite>(kBianryFilePath); writer->DoWrite(batch); writer->DoClose(); std::shared_ptr<arrow::RecordBatch> read_batch; - auto reader = FileHelpRead::Make(GetFormat(kBianryFilePath), kBianryFilePath); + auto reader = GetDefaultFileHelp<FileHelpRead>(kBianryFilePath); reader->DoRead(&read_batch); reader->DoClose(); @@ -163,13 +70,12 @@ TEST(FileHelpTest, ZeroBinary) { std::shared_ptr<arrow::RecordBatch> batch = arrow::RecordBatch::Make(schema, arrays.size(), arrays); - auto writer = -
FileHelpWrite::Make(GetFormat(kBianryFilePath), kBianryFilePath); + auto writer = GetDefaultFileHelp<FileHelpWrite>(kBianryFilePath); writer->DoWrite(batch); writer->DoClose(); std::shared_ptr<arrow::RecordBatch> read_batch; - auto reader = FileHelpRead::Make(GetFormat(kBianryFilePath), kBianryFilePath); + auto reader = GetDefaultFileHelp<FileHelpRead>(kBianryFilePath); reader->DoRead(&read_batch); reader->DoClose(); @@ -180,14 +86,15 @@ } TEST(FileHelpTest, CSV) { - std::shared_ptr<arrow::RecordBatch> batch = GetRecordBatch(); + std::shared_ptr<arrow::RecordBatch> batch = + RandomBatchGenerator::ExampleGenerate(); - auto writer = FileHelpWrite::Make(GetFormat(kCSVFilePath), kCSVFilePath); + auto writer = GetDefaultFileHelp<FileHelpWrite>(kCSVFilePath); writer->DoWrite(batch); writer->DoClose(); std::shared_ptr<arrow::RecordBatch> read_batch; - auto reader = FileHelpRead::Make(GetFormat(kCSVFilePath), kCSVFilePath); + auto reader = GetDefaultFileHelp<FileHelpRead>(kCSVFilePath); reader->DoRead(&read_batch); reader->DoClose(); @@ -198,14 +105,15 @@ } TEST(FileHelpTest, ORC) { - std::shared_ptr<arrow::RecordBatch> batch = GetRecordBatch(); + std::shared_ptr<arrow::RecordBatch> batch = + RandomBatchGenerator::ExampleGenerate(); - auto writer = FileHelpWrite::Make(GetFormat(kORCFilePath), kORCFilePath); + auto writer = GetDefaultFileHelp<FileHelpWrite>(kORCFilePath); writer->DoWrite(batch); writer->DoClose(); std::shared_ptr<arrow::RecordBatch> read_batch; - auto reader = FileHelpRead::Make(GetFormat(kORCFilePath), kORCFilePath); + auto reader = GetDefaultFileHelp<FileHelpRead>(kORCFilePath); reader->DoRead(&read_batch); reader->DoClose(); @@ -227,14 +135,15 @@ std::vector<std::string> GetSelectColumns() { } TEST(FileHelpTestWithOption, CSV) { - std::shared_ptr<arrow::RecordBatch> batch = GetRecordBatch(); + std::shared_ptr<arrow::RecordBatch> batch = + RandomBatchGenerator::ExampleGenerate(); - auto writer = FileHelpWrite::Make(GetFormat(kCSVFilePath), kCSVFilePath); + auto writer = GetDefaultFileHelp<FileHelpWrite>(kCSVFilePath); writer->DoWrite(batch); writer->DoClose(); std::shared_ptr<arrow::RecordBatch> read_batch; - auto reader = FileHelpRead::Make(GetFormat(kCSVFilePath), kCSVFilePath, + auto reader = FileHelpRead::Make(GetFileFormat(kCSVFilePath), kCSVFilePath, GetReadOptions()); reader->DoRead(&read_batch); reader->DoClose(); @@ -247,14 +156,15 @@ } TEST(FileHelpTestWithOption, ORC) { - std::shared_ptr<arrow::RecordBatch> batch = GetRecordBatch(); + std::shared_ptr<arrow::RecordBatch> batch = + RandomBatchGenerator::ExampleGenerate(); - auto writer = FileHelpWrite::Make(GetFormat(kORCFilePath), kORCFilePath); + auto writer = GetDefaultFileHelp<FileHelpWrite>(kORCFilePath); writer->DoWrite(batch); writer->DoClose(); std::shared_ptr<arrow::RecordBatch> read_batch; - auto reader = FileHelpRead::Make(GetFormat(kORCFilePath), kORCFilePath, + auto reader = FileHelpRead::Make(GetFileFormat(kORCFilePath), kORCFilePath, GetReadOptions()); reader->DoRead(&read_batch); reader->DoClose(); @@ -273,29 +183,29 @@ FileHelpRead::Options GetErrorOptions() { } TEST(FileHelpTestWithOption, ErrorCSV) { - std::shared_ptr<arrow::RecordBatch> batch = GetRecordBatch(); + auto batch = RandomBatchGenerator::ExampleGenerate(); - auto writer = FileHelpWrite::Make(GetFormat(kCSVFilePath), kCSVFilePath); + auto writer = GetDefaultFileHelp<FileHelpWrite>(kCSVFilePath); writer->DoWrite(batch); writer->DoClose(); std::shared_ptr<arrow::RecordBatch> read_batch; - EXPECT_THROW(FileHelpRead::Make(GetFormat(kCSVFilePath), kCSVFilePath, + EXPECT_THROW(FileHelpRead::Make(GetFileFormat(kCSVFilePath), kCSVFilePath, GetErrorOptions()), yacl::Exception); } TEST(FileHelpTestWithOption, ErrorORC) { - std::shared_ptr<arrow::RecordBatch> batch = GetRecordBatch(); + auto batch = RandomBatchGenerator::ExampleGenerate(); - auto writer = FileHelpWrite::Make(GetFormat(kORCFilePath), kORCFilePath); + auto writer = GetDefaultFileHelp<FileHelpWrite>(kORCFilePath); writer->DoWrite(batch); writer->DoClose(); std::shared_ptr<arrow::RecordBatch> read_batch; - auto reader = FileHelpRead::Make(GetFormat(kORCFilePath), kORCFilePath, + auto reader = FileHelpRead::Make(GetFileFormat(kORCFilePath), kORCFilePath, GetErrorOptions()); EXPECT_THROW(reader->DoRead(&read_batch), yacl::Exception); } -} // namespace dataproxy_sdk \ No newline at end of file +} // namespace dataproxy_sdk diff --git a/dataproxy_sdk/proto/data_proxy_pb.proto b/dataproxy_sdk/proto/data_proxy_pb.proto index a43ad97..e84f437 100644 --- a/dataproxy_sdk/proto/data_proxy_pb.proto +++ b/dataproxy_sdk/proto/data_proxy_pb.proto @@ -46,10 +46,33 @@ message DataProxyConfig { TlSConfig tls_config = 2; } +message ORCFileInfo { + enum CompressionType { + UNCOMPRESSED = 0; + SNAPPY = 1; + GZIP = 2; + BROTLI = 3; + ZSTD = 4; + LZ4 = 5; + LZ4_FRAME = 6; + LZO = 7; + BZ2 = 8; + LZ4_HADOOP = 9; + } + + CompressionType compression = 1; + int64 compression_block_size = 2; + int64 stripe_size = 3; +} + message DownloadInfo { string domaindata_id = 1; // specific the partition column and value, such as "dmdt=20240520" string partition_spec = 2; + + oneof file_info { + ORCFileInfo orc_info = 10; + } } message UploadInfo { diff --git a/dataproxy_sdk/python/BUILD.bazel b/dataproxy_sdk/python/BUILD.bazel index 0621a76..8289c83 100644 --- a/dataproxy_sdk/python/BUILD.bazel +++ b/dataproxy_sdk/python/BUILD.bazel @@ -11,61 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") -load("@rules_python//python:defs.bzl", "py_library") - -package(default_visibility = ["//visibility:public"]) - -exports_files( - [ - "exported_symbols.lds", - "version_script.lds", - ], - visibility = ["//visibility:private"], -) - -pybind_extension( - name = "libdataproxy", - srcs = ["libdataproxy.cc"], - linkopts = select({ - "@bazel_tools//src/conditions:darwin": [ - "-Wl,-exported_symbols_list,$(location //dataproxy_sdk/python:exported_symbols.lds)", - ], - "//conditions:default": [ - "-Wl,--version-script,$(location //dataproxy_sdk/python:version_script.lds)", - ], - }), - deps = [ - ":exported_symbols.lds", - ":version_script.lds", - "//dataproxy_sdk/cc:dataproxy_sdk_cc", - ], -) - -py_library( - name = "data_proxy_file_py", - srcs = [ - "dp_file_adapter.py", - ], - data = [ - ":libdataproxy.so", - ], -) - -py_library( - name = "protos", - srcs = [ - "dp_pb2.py", - "//dataproxy_sdk/proto:data_proxy_proto_py", - ], -) - -py_library( - name = "init", - srcs = [ - "__init__.py", - ":data_proxy_file_py", - ":protos", - ], -) diff --git a/dataproxy_sdk/python/dataproxy/BUILD.bazel b/dataproxy_sdk/python/dataproxy/BUILD.bazel new file mode 100644 index 0000000..1d41f21 --- /dev/null +++ b/dataproxy_sdk/python/dataproxy/BUILD.bazel @@ -0,0 +1,74 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + +load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") +load("@rules_python//python:defs.bzl", "py_library") + +package(default_visibility = ["//visibility:public"]) + +exports_files( + [ + "exported_symbols.lds", + "version_script.lds", + ], + visibility = ["//visibility:private"], +) + +pybind_extension( + name = "libdataproxy", + srcs = ["libdataproxy.cc"], + linkopts = select({ + "@bazel_tools//src/conditions:darwin": [ + "-Wl,-exported_symbols_list,$(location :exported_symbols.lds)", + ], + "//conditions:default": [ + "-Wl,--version-script,$(location :version_script.lds)", + ], + }), + deps = [ + ":exported_symbols.lds", + ":version_script.lds", + "//dataproxy_sdk/cc:dataproxy_sdk_cc", + ], +) + +py_library( + name = "data_proxy_file_py", + srcs = [ + "dp_file_adapter.py", + ], + data = [ + ":libdataproxy.so", + ], +) + +py_library( + name = "protos", + srcs = [ + "dp_pb2.py", + "//dataproxy_sdk/proto:data_proxy_proto_py", + ], +) + +py_library( + name = "init", + srcs = [ + "__init__.py", + ":data_proxy_file_py", + ":protos", + ], + data = [ + ":libdataproxy.so", + ], +) diff --git a/dataproxy_sdk/python/dataproxy/__init__.py b/dataproxy_sdk/python/dataproxy/__init__.py new file mode 100644 index 0000000..2190371 --- /dev/null +++ b/dataproxy_sdk/python/dataproxy/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import sdk + +__all__ = [ + "sdk", +] diff --git a/dataproxy_sdk/python/dataproxy/dp_file_adapter.py b/dataproxy_sdk/python/dataproxy/dp_file_adapter.py new file mode 100644 index 0000000..4785267 --- /dev/null +++ b/dataproxy_sdk/python/dataproxy/dp_file_adapter.py @@ -0,0 +1,53 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import libdataproxy +from . 
import dp_pb2 as proto +import logging +import os + + +class DataProxyFileAdapter: + def __init__(self, config: proto.DataProxyConfig): + self.data_proxy_file = libdataproxy.DataProxyFile(config.SerializeToString()) + + + def close(self): + self.data_proxy_file.close() + + + def download_file( + self, info: proto.DownloadInfo, file_path: str, file_format: proto.FileFormat + ): + self.data_proxy_file.download_file( + info.SerializeToString(), file_path, file_format + ) + + size = os.path.getsize(file_path) + logging.info( + f"dataproxy sdk: download_file[{file_path}], type[{file_format}], size[{size}]" + ) + + + def upload_file( + self, info: proto.UploadInfo, file_path: str, file_format: proto.FileFormat + ): + self.data_proxy_file.upload_file( + info.SerializeToString(), file_path, file_format + ) + + size = os.path.getsize(file_path) + logging.info( + f"dataproxy sdk: upload_file[{file_path}], type[{file_format}], size[{size}]" + ) diff --git a/dataproxy_sdk/python/dataproxy/dp_pb2.py b/dataproxy_sdk/python/dataproxy/dp_pb2.py new file mode 100644 index 0000000..4e68795 --- /dev/null +++ b/dataproxy_sdk/python/dataproxy/dp_pb2.py @@ -0,0 +1,15 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataproxy_sdk.proto.data_proxy_pb_pb2 import * diff --git a/dataproxy_sdk/python/exported_symbols.lds b/dataproxy_sdk/python/dataproxy/exported_symbols.lds similarity index 100% rename from dataproxy_sdk/python/exported_symbols.lds rename to dataproxy_sdk/python/dataproxy/exported_symbols.lds diff --git a/dataproxy_sdk/python/libdataproxy.cc b/dataproxy_sdk/python/dataproxy/libdataproxy.cc similarity index 99% rename from dataproxy_sdk/python/libdataproxy.cc rename to dataproxy_sdk/python/dataproxy/libdataproxy.cc index 167a86d..dfe3b35 100644 --- a/dataproxy_sdk/python/libdataproxy.cc +++ b/dataproxy_sdk/python/dataproxy/libdataproxy.cc @@ -12,9 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "pybind11/pybind11.h" + #include "dataproxy_sdk/cc/api.h" #include "dataproxy_sdk/cc/exception.h" -#include "pybind11/pybind11.h" namespace py = pybind11; diff --git a/dataproxy_sdk/python/dataproxy/sdk.py b/dataproxy_sdk/python/dataproxy/sdk.py new file mode 100644 index 0000000..cd4a158 --- /dev/null +++ b/dataproxy_sdk/python/dataproxy/sdk.py @@ -0,0 +1,16 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .dp_file_adapter import * +from .dp_pb2 import * diff --git a/dataproxy_sdk/python/dataproxy/version.py b/dataproxy_sdk/python/dataproxy/version.py new file mode 100644 index 0000000..acdbf5c --- /dev/null +++ b/dataproxy_sdk/python/dataproxy/version.py @@ -0,0 +1,16 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +__version__ = "0.1.0.dev$$DATE$$" diff --git a/dataproxy_sdk/python/version_script.lds b/dataproxy_sdk/python/dataproxy/version_script.lds similarity index 100% rename from dataproxy_sdk/python/version_script.lds rename to dataproxy_sdk/python/dataproxy/version_script.lds diff --git a/dataproxy_sdk/python/requirements.txt b/dataproxy_sdk/python/requirements.txt new file mode 100644 index 0000000..fca26b9 --- /dev/null +++ b/dataproxy_sdk/python/requirements.txt @@ -0,0 +1,2 @@ +protobuf>=4,<5 +kuscia==0.0.3b0 \ No newline at end of file diff --git a/dataproxy_sdk/python/setup.py b/dataproxy_sdk/python/setup.py new file mode 100644 index 0000000..cb2abad --- /dev/null +++ b/dataproxy_sdk/python/setup.py @@ -0,0 +1,204 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import platform +import re +import shutil +import subprocess +import sys +from datetime import date +from pathlib import Path + +from setuptools import Extension, setup, Distribution, find_packages +from setuptools.command.build_ext import build_ext + +BAZEL_MAX_JOBS = os.getenv("BAZEL_MAX_JOBS") +ROOT_DIR = os.path.dirname(__file__) +SKIP_BAZEL_CLEAN = os.getenv("SKIP_BAZEL_CLEAN") +BAZEL_CACHE_DIR = os.getenv("BAZEL_CACHE_DIR") +BAZEL_BIN = "../../bazel-bin/" + + +# Calls Bazel in PATH +def bazel_invoke(invoker, cmdline, *args, **kwargs): + try: + result = invoker(["bazel"] + cmdline, *args, **kwargs) + return result + except IOError: + raise + + +# NOTE: The lists below must be kept in sync with dataproxy/BUILD.bazel. +ops_lib_files = [BAZEL_BIN + "dataproxy_sdk/python/dataproxy/libdataproxy.so"] + +# These are the directories where automatically generated Python protobuf +# bindings are created. 
+generated_python_directories = [ + BAZEL_BIN + "dataproxy_sdk/proto", +] + + +def remove_prefix(text, prefix): + return text[text.startswith(prefix) and len(prefix) :] + + +def copy_file(target_dir, filename, rootdir): + source = os.path.relpath(filename, rootdir) + if source.startswith(BAZEL_BIN + "dataproxy_sdk/python"): + destination = os.path.join( + target_dir, remove_prefix(source, BAZEL_BIN + "dataproxy_sdk/python/") + ) + else: + destination = os.path.join(target_dir, remove_prefix(source, BAZEL_BIN)) + + # Create the target directory if it doesn't already exist. + print(f"Create dir {os.path.dirname(destination)}") + os.makedirs(os.path.dirname(destination), exist_ok=True) + if not os.path.exists(destination): + print(f"Copy file from {source} to {destination}") + shutil.copy(source, destination, follow_symlinks=True) + return 1 + return 0 + + +class BinaryDistribution(Distribution): + def has_ext_modules(self): + return True + + +class BazelExtension(Extension): + def __init__(self, name: str, sourcedir: str = "") -> None: + super().__init__(name, sources=[]) + self.sourcedir = os.fspath(Path(sourcedir).resolve()) + + +class BazelBuild(build_ext): + def build_extension(self, ext: BazelExtension) -> None: + bazel_env = dict(os.environ, PYTHON3_BIN_PATH=sys.executable) + + bazel_flags = ["--verbose_failures"] + if BAZEL_MAX_JOBS: + n = int(BAZEL_MAX_JOBS) # the value must be an int + bazel_flags.append("--jobs") + bazel_flags.append(f"{n}") + if BAZEL_CACHE_DIR: + bazel_flags.append(f"--repository_cache={BAZEL_CACHE_DIR}") + + bazel_precmd_flags = [] + + bazel_targets = ["//dataproxy_sdk/python/dataproxy:init"] + + bazel_flags.extend(["-c", "opt"]) + + if platform.machine() == "x86_64": + bazel_flags.extend(["--config=avx"]) + + bazel_invoke( + subprocess.check_call, + bazel_precmd_flags + ["build"] + bazel_flags + ["--"] + bazel_targets, + env=bazel_env, + ) + + copied_files = 0 + files_to_copy = ops_lib_files + + # Copy over the autogenerated protobuf Python bindings. + for directory in generated_python_directories: + for filename in os.listdir(directory): + if filename[-3:] == ".py": + files_to_copy.append(os.path.join(directory, filename)) + + for filename in files_to_copy: + copied_files += copy_file(self.build_lib, filename, ROOT_DIR) + print("{} of files copied to {}".format(copied_files, self.build_lib)) + + +# Ensure no remaining lib files. 
+build_dir = os.path.join(ROOT_DIR, "build") +if os.path.isdir(build_dir): + shutil.rmtree(build_dir) + + +if not SKIP_BAZEL_CLEAN: + bazel_invoke(subprocess.check_call, ["clean"]) + + +# Default Linux platform tag +plat_name = "manylinux2014_x86_64" +if sys.platform == "darwin": + # Due to a bug in conda x64 python, platform tag has to be 10_16 for X64 wheel + if platform.machine() == "x86_64": + plat_name = "macosx_10_16_x86_64" + else: + plat_name = "macosx_12_0_arm64" +elif platform.machine() == "aarch64": + # Linux aarch64 + plat_name = "manylinux_2_28_aarch64" + + +def read_requirements(*filepath): + requirements = [] + with open(os.path.join(ROOT_DIR, *filepath)) as file: + requirements = file.read().splitlines() + return requirements + + +def complete_version_file(*filepath): + today = date.today() + dstr = today.strftime("%Y%m%d") + with open(os.path.join(".", *filepath), "r") as fp: + content = fp.read() + + content = content.replace("$$DATE$$", dstr) + + with open(os.path.join(".", *filepath), "w+") as fp: + fp.write(content) + + +def find_version(*filepath): + complete_version_file(*filepath) + # Extract version information from filepath + with open(os.path.join(".", *filepath)) as fp: + version_match = re.search( + r"^__version__ = ['\"]([^'\"]*)['\"]", fp.read(), re.M + ) + if version_match: + return version_match.group(1) + print("Unable to find version string.") + exit(-1) + + +setup( + name="secretflow-dataproxy", + version=find_version("dataproxy", "version.py"), + author="SecretFlow Team", + author_email="secretflow-contact@service.alipay.com", + description="DataProxy SDK", + long_description="", + license="Apache 2.0", + classifiers=[ + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + ], + packages=find_packages(), + ext_modules=[BazelExtension("dataproxy")], + cmdclass={"build_ext": BazelBuild}, + distclass=BinaryDistribution, + python_requires=">=3.9, <3.12", + install_requires=read_requirements("requirements.txt"), + setup_requires=["wheel"], + options={"bdist_wheel": {"plat_name": plat_name}}, +) diff --git a/dataproxy_sdk/test/BUILD.bazel b/dataproxy_sdk/test/BUILD.bazel new file mode 100644 index 0000000..7b76960 --- /dev/null +++ b/dataproxy_sdk/test/BUILD.bazel @@ -0,0 +1,45 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+load("//dataproxy_sdk/bazel:defs.bzl", "dataproxy_cc_library")
+
+package(default_visibility = ["//visibility:public"])
+
+dataproxy_cc_library(
+    name = "random",
+    srcs = ["random.cc"],
+    hdrs = ["random.h"],
+    deps = [
+        "//dataproxy_sdk/cc:exception",
+        "@org_apache_arrow//:arrow",
+    ],
+)
+
+dataproxy_cc_library(
+    name = "data_mesh_mock",
+    srcs = ["data_mesh_mock.cc"],
+    hdrs = ["data_mesh_mock.h"],
+    deps = [
+        "@org_apache_arrow//:arrow_flight",
+    ],
+)
+
+dataproxy_cc_library(
+    name = "test_utils",
+    srcs = ["test_utils.cc"],
+    hdrs = ["test_utils.h"],
+    deps = [
+        "//dataproxy_sdk/cc:proto",
+    ],
+)
diff --git a/dataproxy_sdk/test/data_mesh_mock.cc b/dataproxy_sdk/test/data_mesh_mock.cc
new file mode 100644
index 0000000..2db096f
--- /dev/null
+++ b/dataproxy_sdk/test/data_mesh_mock.cc
@@ -0,0 +1,152 @@
+// Copyright 2024 Ant Group Co., Ltd.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "data_mesh_mock.h"
+
+#include <memory>
+#include <thread>
+
+#include "arrow/flight/api.h"
+#include "arrow/table.h"
+
+namespace dataproxy_sdk {
+
+class DataMeshMockServer : public arrow::flight::FlightServerBase {
+ public:
+  DataMeshMockServer(bool open_dp) : open_dp_(open_dp) {}
+
+ public:
+  arrow::Status GetFlightInfo(
+      const arrow::flight::ServerCallContext &,
+      const arrow::flight::FlightDescriptor &descriptor,
+      std::unique_ptr<arrow::flight::FlightInfo> *info) override {
+    ARROW_ASSIGN_OR_RAISE(auto flight_info, MakeFlightInfo());
+    *info = std::unique_ptr<arrow::flight::FlightInfo>(
+        new arrow::flight::FlightInfo(std::move(flight_info)));
+
+    return arrow::Status::OK();
+  }
+
+  arrow::Status DoPut(
+      const arrow::flight::ServerCallContext &,
+      std::unique_ptr<arrow::flight::FlightMessageReader> reader,
+      std::unique_ptr<arrow::flight::FlightMetadataWriter>) override {
+    ARROW_ASSIGN_OR_RAISE(table_, reader->ToTable());
+
+    return arrow::Status::OK();
+  }
+
+  arrow::Status DoGet(
+      const arrow::flight::ServerCallContext &,
+      const arrow::flight::Ticket &request,
+      std::unique_ptr<arrow::flight::FlightDataStream> *stream) override {
+    std::vector<std::shared_ptr<arrow::RecordBatch>> batches;
+    std::shared_ptr<arrow::RecordBatchReader> owning_reader;
+    std::shared_ptr<arrow::Schema> schema;
+
+    if (table_) {
+      arrow::TableBatchReader batch_reader(*table_);
+      ARROW_ASSIGN_OR_RAISE(batches, batch_reader.ToRecordBatches());
+      schema = table_->schema();
+    }
+    ARROW_ASSIGN_OR_RAISE(owning_reader, arrow::RecordBatchReader::Make(
+                                             std::move(batches), schema));
+    *stream = std::unique_ptr<arrow::flight::FlightDataStream>(
+        new arrow::flight::RecordBatchStream(owning_reader));
+
+    return arrow::Status::OK();
+  }
+
+  arrow::Status DoAction(
+      const arrow::flight::ServerCallContext &,
+      const arrow::flight::Action &action,
+      std::unique_ptr<arrow::flight::ResultStream> *result) override {
+    std::vector<arrow::flight::Result> results;
+    ARROW_ASSIGN_OR_RAISE(auto flight_result,
+                          arrow::flight::Result::Deserialize(""));
+    results.push_back(flight_result);
+
+    *result = std::unique_ptr<arrow::flight::ResultStream>(
+        new arrow::flight::SimpleResultStream(std::move(results)));
+
+    return arrow::Status::OK();
+  }
+
+ private:
+  arrow::Result<arrow::flight::FlightInfo> MakeFlightInfo() {
+    auto descriptor = arrow::flight::FlightDescriptor::Command("");
+    arrow::flight::FlightEndpoint endpoint;
+    if (open_dp_) {
+      endpoint.locations.push_back(location());
+    }
+    else {
+      ARROW_ASSIGN_OR_RAISE(
+          auto location, arrow::flight::Location::Parse("kuscia://datamesh"));
+      endpoint.locations.push_back(location);
+    }
+
+    arrow::SchemaBuilder builder;
+    ARROW_ASSIGN_OR_RAISE(auto schema, builder.Finish());
+
+    return arrow::flight::FlightInfo::Make(*schema, descriptor, {endpoint}, 0,
+                                           0);
+  }
+
+  bool open_dp_;
+  std::shared_ptr<arrow::Table> table_;
+};
+
+class DataMeshMock::Impl {
+ public:
+  arrow::Status StartServer(const std::string &dm_address, bool open_dp) {
+    ARROW_ASSIGN_OR_RAISE(auto options, arrow::flight::Location::Parse(
+                                            "grpc+tcp://" + dm_address));
+    arrow::flight::FlightServerOptions server_location(options);
+    server_ = std::make_shared<DataMeshMockServer>(open_dp);
+    RETURN_NOT_OK(server_->Init(server_location));
+
+    auto thread = std::thread(&DataMeshMockServer::Serve, server_);
+    thread.detach();
+
+    return arrow::Status::OK();
+  }
+  arrow::Status CloseServer() {
+    if (server_) RETURN_NOT_OK(server_->Shutdown());
+
+    return arrow::Status::OK();
+  }
+
+ public:
+  Impl() = default;
+  ~Impl() { auto status = CloseServer(); }
+
+ private:
+  std::shared_ptr<DataMeshMockServer> server_;
+};
+
+std::unique_ptr<DataMeshMock> DataMeshMock::Make() {
+  return std::make_unique<DataMeshMock>();
+}
+
+DataMeshMock::DataMeshMock() { impl_ = std::make_unique<Impl>(); }
+
+DataMeshMock::~DataMeshMock() = default;
+
+arrow::Status DataMeshMock::StartServer(const std::string &dm_address,
+                                        bool open_dp) {
+  return impl_->StartServer(dm_address, open_dp);
+}
+
+arrow::Status DataMeshMock::CloseServer() { return impl_->CloseServer(); }
+
+}  // namespace dataproxy_sdk
diff --git a/dataproxy_sdk/test/data_mesh_mock.h b/dataproxy_sdk/test/data_mesh_mock.h
new file mode 100644
index 0000000..34b0ed0
--- /dev/null
+++ b/dataproxy_sdk/test/data_mesh_mock.h
@@ -0,0 +1,40 @@
+// Copyright 2024 Ant Group Co., Ltd.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+
+#include "arrow/status.h"
+
+namespace dataproxy_sdk {
+
+class DataMeshMock {
+ public:
+  arrow::Status StartServer(const std::string& dm_address,
+                            bool open_dp = false);
+  arrow::Status CloseServer();
+
+ public:
+  static std::unique_ptr<DataMeshMock> Make();
+  DataMeshMock();
+  ~DataMeshMock();
+
+ private:
+  class Impl;
+  std::unique_ptr<Impl> impl_;
+};
+
+}  // namespace dataproxy_sdk
diff --git a/dataproxy_sdk/test/random.cc b/dataproxy_sdk/test/random.cc
new file mode 100644
index 0000000..ceca46e
--- /dev/null
+++ b/dataproxy_sdk/test/random.cc
@@ -0,0 +1,115 @@
+// Copyright 2024 Ant Group Co., Ltd.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "dataproxy_sdk/test/random.h"
+
+#include <limits>
+#include <random>
+
+#include "arrow/builder.h"
+#include "arrow/record_batch.h"
+
+#include "dataproxy_sdk/cc/exception.h"
+
+namespace dataproxy_sdk {
+
+class RandomBatchGeneratorImpl {
+ public:
+  RandomBatchGeneratorImpl(const std::shared_ptr<arrow::Schema> &schema,
+                           int32_t num_rows)
+      : schema_(schema), num_rows_(num_rows) {}
+
+ public:
+  std::shared_ptr<arrow::RecordBatch> Generate() {
+    for (std::shared_ptr<arrow::Field> field : schema_->fields()) {
+      CHECK_ARROW_OR_THROW(arrow::VisitTypeInline(*field->type(), this));
+    }
+    auto ret = arrow::RecordBatch::Make(schema_, num_rows_, arrays_);
+    return ret;
+  }
+
+  // Default implementation
+  arrow::Status Visit(const arrow::DataType &type) {
+    return arrow::Status::NotImplemented("Generating data for",
+                                         type.ToString());
+  }
+
+  arrow::Status Visit(const arrow::BinaryType &) {
+    auto builder = arrow::BinaryBuilder();
+    uint32_t max = std::numeric_limits<uint8_t>::max() > num_rows_
+                       ? num_rows_
+                       : std::numeric_limits<uint8_t>::max();
+    std::uniform_int_distribution<uint32_t> d(0, max);
+
+    uint8_t *buff = (uint8_t *)alloca(sizeof(uint8_t) * num_rows_);
+    for (int32_t i = 0; i < num_rows_; ++i) {
+      buff[i] = d(gen_);
+    }
+
+    CHECK_ARROW_OR_THROW(builder.Append(buff, num_rows_));
+    ASSIGN_DP_OR_THROW(auto array, builder.Finish());
+    arrays_.push_back(array);
+    return arrow::Status::OK();
+  }
+
+  arrow::Status Visit(const arrow::DoubleType &) {
+    auto builder = arrow::DoubleBuilder();
+    std::normal_distribution<> d{/*mean=*/5.0, /*stddev=*/2.0};  // normal distribution
+    for (int32_t i = 0; i < num_rows_; ++i) {
+      CHECK_ARROW_OR_THROW(builder.Append(d(gen_)));
+    }
+
+    ASSIGN_DP_OR_THROW(auto array, builder.Finish());
+    arrays_.push_back(array);
+    return arrow::Status::OK();
+  }
+
+  arrow::Status Visit(const arrow::Int64Type &) {
+    // Fill the column with random non-negative integers drawn from a
+    // discrete (Poisson) distribution.
+    std::poisson_distribution<> d{/*mean=*/4};
+    auto builder = arrow::Int64Builder();
+    for (int32_t i = 0; i < num_rows_; ++i) {
+      CHECK_ARROW_OR_THROW(builder.Append(d(gen_)));
+    }
+
+    ASSIGN_DP_OR_THROW(auto array, builder.Finish());
+    arrays_.push_back(array);
+    return arrow::Status::OK();
+  }
+
+ protected:
+  std::random_device rd_{};
+  std::mt19937 gen_{rd_()};  // random seed
+  std::vector<std::shared_ptr<arrow::Array>> arrays_;
+  std::shared_ptr<arrow::Schema> schema_;
+  int32_t num_rows_;
+};
+
+std::shared_ptr<arrow::RecordBatch> RandomBatchGenerator::Generate(
+    const std::shared_ptr<arrow::Schema> &schema, int32_t num_rows) {
+  RandomBatchGeneratorImpl generator(schema, num_rows);
+
+  return generator.Generate();
+}
+
+std::shared_ptr<arrow::RecordBatch> RandomBatchGenerator::ExampleGenerate(
+    int row) {
+  auto f0 = arrow::field("x", arrow::int64());
+  auto f1 = arrow::field("y", arrow::int64());
+  auto f2 = arrow::field("z", arrow::int64());
+  std::shared_ptr<arrow::Schema> schema = arrow::schema({f0, f1, f2});
+  return RandomBatchGenerator::Generate(schema, row);
+}
+
+}  // namespace dataproxy_sdk
diff --git a/dataproxy_sdk/test/random.h b/dataproxy_sdk/test/random.h
new file mode 100644
index 0000000..4ad65b4
--- /dev/null
+++ b/dataproxy_sdk/test/random.h
@@ -0,0 +1,29 @@
+// Copyright 2024 Ant Group Co., Ltd.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "arrow/record_batch.h"
+
+namespace dataproxy_sdk {
+
+class RandomBatchGenerator {
+ public:
+  static std::shared_ptr<arrow::RecordBatch> Generate(
+      const std::shared_ptr<arrow::Schema>& schema, int32_t num_rows);
+
+  static std::shared_ptr<arrow::RecordBatch> ExampleGenerate(int row = 10);
+};
+
+}  // namespace dataproxy_sdk
diff --git a/dataproxy_sdk/test/test_utils.cc b/dataproxy_sdk/test/test_utils.cc
new file mode 100644
index 0000000..c1a055b
--- /dev/null
+++ b/dataproxy_sdk/test/test_utils.cc
@@ -0,0 +1,29 @@
+// Copyright 2024 Ant Group Co., Ltd.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "dataproxy_sdk/test/test_utils.h"
+
+namespace dataproxy_sdk {
+
+proto::FileFormat GetFileFormat(const std::string& file) {
+  if (file.find(".csv") != std::string::npos) {
+    return proto::FileFormat::CSV;
+  } else if (file.find(".orc") != std::string::npos) {
+    return proto::FileFormat::ORC;
+  }
+
+  return proto::FileFormat::BINARY;
+}
+
+}  // namespace dataproxy_sdk
diff --git a/dataproxy_sdk/test/test_utils.h b/dataproxy_sdk/test/test_utils.h
new file mode 100644
index 0000000..8f04b43
--- /dev/null
+++ b/dataproxy_sdk/test/test_utils.h
@@ -0,0 +1,23 @@
+// Copyright 2024 Ant Group Co., Ltd.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "dataproxy_sdk/cc/data_proxy_pb.h"
+
+namespace dataproxy_sdk {
+
+proto::FileFormat GetFileFormat(const std::string& file);
+
+}  // namespace dataproxy_sdk
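Reviewer note: a minimal sketch of how the three test helpers added in this patch are meant to compose — RandomBatchGenerator supplies a random Arrow batch, DataMeshMock stands in for the DataMesh Flight endpoint, and GetFileFormat maps a file path to the proto::FileFormat enum. The local address and the CSV path below are illustrative assumptions, not code from this patch:

```cpp
#include <iostream>

#include "dataproxy_sdk/test/data_mesh_mock.h"
#include "dataproxy_sdk/test/random.h"
#include "dataproxy_sdk/test/test_utils.h"

int main() {
  using namespace dataproxy_sdk;

  // Generate a 10-row batch with the example x/y/z int64 schema.
  auto batch = RandomBatchGenerator::ExampleGenerate();

  // Stand up the in-process DataMesh mock; the address is an assumption.
  auto data_mesh = DataMeshMock::Make();
  if (!data_mesh->StartServer("127.0.0.1:23333").ok()) return 1;

  // Infer the wire format from a (hypothetical) file path.
  proto::FileFormat format = GetFileFormat("/tmp/test.csv");  // -> CSV

  std::cout << "rows=" << batch->num_rows() << " format=" << format << "\n";

  // A real test would now drive upload/download through the SDK here.
  auto status = data_mesh->CloseServer();
  return status.ok() ? 0 : 1;
}
```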