From 789b982f5fe4a082be7235248073d11258ccf152 Mon Sep 17 00:00:00 2001
From: Ludovic DEHON
Date: Thu, 31 Mar 2022 21:11:22 +0200
Subject: [PATCH] feat(snowflake): add download & upload from stage

---
 plugin-jdbc-snowflake/build.gradle            |   1 +
 .../AbstractSnowflakeConnection.java          |  36 ++++++
 .../plugin/jdbc/snowflake/Download.java       | 111 ++++++++++++++++
 .../kestra/plugin/jdbc/snowflake/Query.java   |  17 +--
 .../kestra/plugin/jdbc/snowflake/Upload.java  | 122 ++++++++++++++++++
 .../jdbc/snowflake/UploadDownloadTest.java    |  89 +++++++++++++
 .../kestra/plugin/jdbc/AbstractJdbcBatch.java |  14 +-
 .../plugin/jdbc/AbstractJdbcConnection.java   |  22 +---
 .../kestra/plugin/jdbc/AbstractJdbcQuery.java |  25 ++--
 .../plugin/jdbc/AbstractJdbcStatement.java    |  32 +++++
 10 files changed, 418 insertions(+), 51 deletions(-)
 create mode 100644 plugin-jdbc-snowflake/src/main/java/io/kestra/plugin/jdbc/snowflake/AbstractSnowflakeConnection.java
 create mode 100644 plugin-jdbc-snowflake/src/main/java/io/kestra/plugin/jdbc/snowflake/Download.java
 create mode 100644 plugin-jdbc-snowflake/src/main/java/io/kestra/plugin/jdbc/snowflake/Upload.java
 create mode 100644 plugin-jdbc-snowflake/src/test/java/io/kestra/plugin/jdbc/snowflake/UploadDownloadTest.java
 create mode 100644 plugin-jdbc/src/main/java/io/kestra/plugin/jdbc/AbstractJdbcStatement.java

diff --git a/plugin-jdbc-snowflake/build.gradle b/plugin-jdbc-snowflake/build.gradle
index 01afd61b..78f38742 100644
--- a/plugin-jdbc-snowflake/build.gradle
+++ b/plugin-jdbc-snowflake/build.gradle
@@ -1,6 +1,7 @@
 dependencies {
     implementation("net.snowflake:snowflake-jdbc:3.13.14")
     implementation project(':plugin-jdbc')
+    implementation("javax.xml.bind:jaxb-api:2.3.1")
 
     testImplementation project(':plugin-jdbc').sourceSets.test.output
 }
diff --git a/plugin-jdbc-snowflake/src/main/java/io/kestra/plugin/jdbc/snowflake/AbstractSnowflakeConnection.java b/plugin-jdbc-snowflake/src/main/java/io/kestra/plugin/jdbc/snowflake/AbstractSnowflakeConnection.java
new file mode 100644
index 00000000..bbebb45c
--- /dev/null
+++ b/plugin-jdbc-snowflake/src/main/java/io/kestra/plugin/jdbc/snowflake/AbstractSnowflakeConnection.java
@@ -0,0 +1,36 @@
+package io.kestra.plugin.jdbc.snowflake;
+
+import io.kestra.core.exceptions.IllegalVariableEvaluationException;
+import io.kestra.core.runners.RunContext;
+import io.kestra.plugin.jdbc.AbstractJdbcConnection;
+import lombok.EqualsAndHashCode;
+import lombok.Getter;
+import lombok.NoArgsConstructor;
+import lombok.ToString;
+import lombok.experimental.SuperBuilder;
+
+import java.io.IOException;
+import java.sql.DriverManager;
+import java.sql.SQLException;
+import java.util.Properties;
+
+@SuperBuilder
+@ToString
+@EqualsAndHashCode
+@Getter
+@NoArgsConstructor
+public abstract class AbstractSnowflakeConnection extends AbstractJdbcConnection implements SnowflakeInterface {
+    @Override
+    protected void registerDriver() throws SQLException {
+        DriverManager.registerDriver(new net.snowflake.client.jdbc.SnowflakeDriver());
+    }
+
+    @Override
+    protected Properties connectionProperties(RunContext runContext) throws IllegalVariableEvaluationException, IOException {
+        Properties properties = super.connectionProperties(runContext);
+
+        this.renderProperties(runContext, properties);
+
+        return properties;
+    }
+}
diff --git a/plugin-jdbc-snowflake/src/main/java/io/kestra/plugin/jdbc/snowflake/Download.java b/plugin-jdbc-snowflake/src/main/java/io/kestra/plugin/jdbc/snowflake/Download.java
new file mode 100644
index 00000000..1c3a7446
--- /dev/null
+++ b/plugin-jdbc-snowflake/src/main/java/io/kestra/plugin/jdbc/snowflake/Download.java
@@ -0,0 +1,111 @@
+package io.kestra.plugin.jdbc.snowflake;
+
+import io.kestra.core.models.annotations.Example;
+import io.kestra.core.models.annotations.Plugin;
+import io.kestra.core.models.annotations.PluginProperty;
+import io.kestra.core.models.tasks.RunnableTask;
+import io.kestra.core.runners.RunContext;
+import io.swagger.v3.oas.annotations.media.Schema;
+import lombok.*;
+import lombok.experimental.SuperBuilder;
+import net.snowflake.client.jdbc.SnowflakeConnection;
+import org.apache.commons.io.IOUtils;
+import org.slf4j.Logger;
+
+import java.io.BufferedOutputStream;
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.InputStream;
+import java.net.URI;
+import java.sql.Connection;
+import javax.validation.constraints.NotNull;
+
+@SuperBuilder
+@ToString
+@EqualsAndHashCode
+@Getter
+@NoArgsConstructor
+@Schema(
+    title = "Download data from an internal stage"
+)
+@Plugin(
+    examples = {
+        @Example(
+            code = {
+                "stageName: MYSTAGE",
+                "fileName: prefix/destFile.csv"
+            }
+        )
+    }
+)
+public class Download extends AbstractSnowflakeConnection implements RunnableTask<Download.Output> {
+    private String database;
+    private String warehouse;
+    private String schema;
+    private String role;
+
+    @Schema(
+        title = "The stage name",
+        description = "~ or table name or stage name"
+    )
+    @PluginProperty(dynamic = true)
+    @NotNull
+    private String stageName;
+
+    @Schema(
+        title = "file name on the stage to download"
+    )
+    @PluginProperty(dynamic = true)
+    @NotNull
+    private String fileName;
+
+    @Schema(
+        title = "whether to decompress the data when downloading the stream"
+    )
+    @PluginProperty(dynamic = false)
+    @NotNull
+    @Builder.Default
+    private Boolean compress = true;
+
+    @Override
+    public Download.Output run(RunContext runContext) throws Exception {
+        Logger logger = runContext.logger();
+        File tempFile = runContext.tempFile().toFile();
+
+        try (
+            Connection conn = this.connection(runContext);
+            BufferedOutputStream outputStream = new BufferedOutputStream(new FileOutputStream(tempFile))
+        ) {
+            String stageName = runContext.render(this.stageName);
+            String filename = runContext.render(this.fileName);
+
+            logger.info("Starting download from stage '{}' with name '{}'", stageName, filename);
+
+            InputStream inputStream = conn
+                .unwrap(SnowflakeConnection.class)
+                .downloadStream(
+                    stageName,
+                    filename,
+                    this.compress
+                );
+
+            IOUtils.copyLarge(inputStream, outputStream);
+
+            outputStream.flush();
+
+            return Output
+                .builder()
+                .uri(runContext.putTempFile(tempFile))
+                .build();
+        }
+    }
+
+    @Builder
+    @Getter
+    public static class Output implements io.kestra.core.models.tasks.Output {
+        @Schema(
+            title = "The URI of the file on Kestra internal storage"
+        )
+        private final URI uri;
+    }
+}
diff --git a/plugin-jdbc-snowflake/src/main/java/io/kestra/plugin/jdbc/snowflake/Query.java b/plugin-jdbc-snowflake/src/main/java/io/kestra/plugin/jdbc/snowflake/Query.java
index 4ccf11c4..8415a3e3 100644
--- a/plugin-jdbc-snowflake/src/main/java/io/kestra/plugin/jdbc/snowflake/Query.java
+++ b/plugin-jdbc-snowflake/src/main/java/io/kestra/plugin/jdbc/snowflake/Query.java
@@ -1,18 +1,18 @@
 package io.kestra.plugin.jdbc.snowflake;
 
 import io.kestra.core.exceptions.IllegalVariableEvaluationException;
-import io.swagger.v3.oas.annotations.media.Schema;
-import lombok.EqualsAndHashCode;
-import lombok.Getter;
-import lombok.NoArgsConstructor;
-import lombok.ToString;
-import lombok.experimental.SuperBuilder;
 import io.kestra.core.models.annotations.Example;
 import io.kestra.core.models.annotations.Plugin;
 import io.kestra.core.models.tasks.RunnableTask;
 import io.kestra.core.runners.RunContext;
 import io.kestra.plugin.jdbc.AbstractCellConverter;
 import io.kestra.plugin.jdbc.AbstractJdbcQuery;
+import io.swagger.v3.oas.annotations.media.Schema;
+import lombok.EqualsAndHashCode;
+import lombok.Getter;
+import lombok.NoArgsConstructor;
+import lombok.ToString;
+import lombok.experimental.SuperBuilder;
 
 import java.io.IOException;
 import java.sql.DriverManager;
@@ -75,9 +75,4 @@ protected AbstractCellConverter getCellConverter(ZoneId zoneId) {
     protected void registerDriver() throws SQLException {
         DriverManager.registerDriver(new net.snowflake.client.jdbc.SnowflakeDriver());
     }
-
-    @Override
-    public Output run(RunContext runContext) throws Exception {
-        return super.run(runContext);
-    }
 }
diff --git a/plugin-jdbc-snowflake/src/main/java/io/kestra/plugin/jdbc/snowflake/Upload.java b/plugin-jdbc-snowflake/src/main/java/io/kestra/plugin/jdbc/snowflake/Upload.java
new file mode 100644
index 00000000..0d854e24
--- /dev/null
+++ b/plugin-jdbc-snowflake/src/main/java/io/kestra/plugin/jdbc/snowflake/Upload.java
@@ -0,0 +1,122 @@
+package io.kestra.plugin.jdbc.snowflake;
+
+import io.kestra.core.models.annotations.Example;
+import io.kestra.core.models.annotations.Plugin;
+import io.kestra.core.models.annotations.PluginProperty;
+import io.kestra.core.models.tasks.RunnableTask;
+import io.kestra.core.runners.RunContext;
+import io.swagger.v3.oas.annotations.media.Schema;
+import lombok.*;
+import lombok.experimental.SuperBuilder;
+import net.snowflake.client.jdbc.SnowflakeConnection;
+import org.apache.commons.lang3.StringUtils;
+import org.slf4j.Logger;
+
+import java.io.InputStream;
+import java.net.URI;
+import java.sql.Connection;
+import javax.validation.constraints.NotNull;
+
+@SuperBuilder
+@ToString
+@EqualsAndHashCode
+@Getter
+@NoArgsConstructor
+@Schema(
+    title = "Upload data to an internal stage"
+)
+@Plugin(
+    examples = {
+        @Example(
+            code = {
+                "stageName: MYSTAGE",
+                "prefix: testUploadStream",
+                "fileName: destFile.csv"
+            }
+        )
+    }
+)
+public class Upload extends AbstractSnowflakeConnection implements RunnableTask<Upload.Output> {
+    private String database;
+    private String warehouse;
+    private String schema;
+    private String role;
+
+    @Schema(
+        title = "The file to copy"
+    )
+    @PluginProperty(dynamic = true)
+    @NotNull
+    private String from;
+
+    @Schema(
+        title = "The stage name",
+        description = "~ or table name or stage name"
+    )
+    @PluginProperty(dynamic = true)
+    @NotNull
+    private String stageName;
+
+    @Schema(
+        title = "path / prefix under which the data should be uploaded on the stage"
+    )
+    @PluginProperty(dynamic = true)
+    @NotNull
+    private String prefix;
+
+    @Schema(
+        title = "destination file name to use"
+    )
+    @PluginProperty(dynamic = true)
+    @NotNull
+    private String fileName;
+
+    @Schema(
+        title = "compress data or not before uploading stream"
+    )
+    @PluginProperty(dynamic = false)
+    @NotNull
+    @Builder.Default
+    private Boolean compress = true;
+
+    @Override
+    public Upload.Output run(RunContext runContext) throws Exception {
+        Logger logger = runContext.logger();
+
+        URI from = new URI(runContext.render(this.from));
+        try (
+            Connection conn = this.connection(runContext);
+            InputStream inputStream = runContext.uriToInputStream(from);
+        ) {
+            String stageName = runContext.render(this.stageName);
+            String prefix = runContext.render(this.prefix);
+            String filename = runContext.render(this.fileName);
+
+            logger.info("Starting upload to stage '{}' on '{}' with name '{}'", stageName, prefix, filename);
+
+            conn
+                .unwrap(SnowflakeConnection.class)
+                .uploadStream(
+                    stageName,
+                    prefix,
+                    inputStream,
+                    filename,
+                    this.compress
+                );
+
+            return Output
+                .builder()
+                .uri(URI.create(StringUtils.stripEnd(prefix, "/") + "/" + filename + (this.compress ? ".gz" : "")))
+                .build();
+        }
+    }
+
+    @Builder
+    @Getter
+    public static class Output implements io.kestra.core.models.tasks.Output {
+        @Schema(
+            title = "The URI of the staged file"
+        )
+        private final URI uri;
+    }
+}
diff --git a/plugin-jdbc-snowflake/src/test/java/io/kestra/plugin/jdbc/snowflake/UploadDownloadTest.java b/plugin-jdbc-snowflake/src/test/java/io/kestra/plugin/jdbc/snowflake/UploadDownloadTest.java
new file mode 100644
index 00000000..5513ef62
--- /dev/null
+++ b/plugin-jdbc-snowflake/src/test/java/io/kestra/plugin/jdbc/snowflake/UploadDownloadTest.java
@@ -0,0 +1,89 @@
+package io.kestra.plugin.jdbc.snowflake;
+
+import com.google.common.base.Charsets;
+import com.google.common.collect.ImmutableMap;
+import io.kestra.core.runners.RunContext;
+import io.kestra.core.runners.RunContextFactory;
+import io.kestra.core.storages.StorageInterface;
+import io.kestra.core.utils.IdUtils;
+import io.micronaut.context.annotation.Value;
+import io.micronaut.test.extensions.junit5.annotation.MicronautTest;
+import jakarta.inject.Inject;
+import org.apache.commons.io.IOUtils;
+import org.junit.jupiter.api.Disabled;
+import org.junit.jupiter.api.Test;
+
+import java.io.FileInputStream;
+import java.net.URI;
+import java.net.URL;
+import java.util.Objects;
+
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.notNullValue;
+
+@MicronautTest
+public class UploadDownloadTest {
+    @Value("${snowflake.host}")
+    protected String host;
+
+    @Value("${snowflake.username}")
+    protected String username;
+
+    @Value("${snowflake.password}")
+    protected String password;
+
+    @Inject
+    protected RunContextFactory runContextFactory;
+
+    @Inject
+    protected StorageInterface storageInterface;
+
+    @Test
+    @Disabled("no server for unit test")
+    void success() throws Exception {
+        URL resource = UploadDownloadTest.class.getClassLoader().getResource("scripts/snowflake.sql");
+
+        URI put = storageInterface.put(
+            new URI("/file/storage/snowflake.sql"),
+            new FileInputStream(Objects.requireNonNull(resource).getFile())
+        );
+
+        RunContext runContext = runContextFactory.of(ImmutableMap.of());
+
+        Upload upload = Upload.builder()
+            .url("jdbc:snowflake://" + this.host + "/?loginTimeout=3")
+            .username(this.username)
+            .password(this.password)
+            .warehouse("COMPUTE_WH")
+            .database("UNITTEST")
+            .from(put.toString())
+            .schema("public")
+            .stageName("UNITSTAGE")
+            .prefix("ut_" + IdUtils.create())
+            .fileName("test.sql")
+            .build();
+
+        Upload.Output uploadRun = upload.run(runContext);
+        assertThat(uploadRun.getUri(), notNullValue());
+
+        Download download = Download.builder()
+            .url("jdbc:snowflake://" + this.host + "/?loginTimeout=3")
+            .username(this.username)
+            .password(this.password)
+            .warehouse("COMPUTE_WH")
+            .database("UNITTEST")
+            .schema("public")
+            .stageName("UNITSTAGE")
+            .fileName(uploadRun.getUri().toString())
+            .build();
+
+        Download.Output downloadRun = download.run(runContext);
+        assertThat(downloadRun.getUri(), notNullValue());
+
+        assertThat(
+            IOUtils.toString(this.storageInterface.get(downloadRun.getUri()), Charsets.UTF_8),
+            is(IOUtils.toString(this.storageInterface.get(put), Charsets.UTF_8))
+        );
+    }
+}
diff --git a/plugin-jdbc/src/main/java/io/kestra/plugin/jdbc/AbstractJdbcBatch.java b/plugin-jdbc/src/main/java/io/kestra/plugin/jdbc/AbstractJdbcBatch.java
index 10bcaf5b..53bc3d4c 100644
--- a/plugin-jdbc/src/main/java/io/kestra/plugin/jdbc/AbstractJdbcBatch.java
+++ b/plugin-jdbc/src/main/java/io/kestra/plugin/jdbc/AbstractJdbcBatch.java
@@ -29,7 +29,7 @@
 @EqualsAndHashCode
 @Getter
 @NoArgsConstructor
-public abstract class AbstractJdbcBatch extends AbstractJdbcConnection {
+public abstract class AbstractJdbcBatch extends AbstractJdbcStatement {
     @NotNull
     @io.swagger.v3.oas.annotations.media.Schema(
         title = "Source file URI"
@@ -63,11 +63,6 @@ public abstract class AbstractJdbcBatch extends AbstractJdbcConnection {
     @PluginProperty(dynamic = true)
     private List<String> columns;
 
-    @Schema(
-        title = "The time zone id to use for date/time manipulation. Default value is the worker default zone id."
-    )
-    private String timeZoneId;
-
     protected abstract AbstractCellConverter getCellConverter(ZoneId zoneId);
 
     public Output run(RunContext runContext) throws Exception {
@@ -76,12 +71,7 @@ public Output run(RunContext runContext) throws Exception {
 
         AtomicLong count = new AtomicLong();
 
-        ZoneId zoneId = TimeZone.getDefault().toZoneId();
-        if (this.timeZoneId != null) {
-            zoneId = ZoneId.of(timeZoneId);
-        }
-
-        AbstractCellConverter cellConverter = this.getCellConverter(zoneId);
+        AbstractCellConverter cellConverter = this.getCellConverter(this.zoneId());
 
         String sql = runContext.render(this.sql);
         logger.debug("Starting prepared statement: {}", sql);
diff --git a/plugin-jdbc/src/main/java/io/kestra/plugin/jdbc/AbstractJdbcConnection.java b/plugin-jdbc/src/main/java/io/kestra/plugin/jdbc/AbstractJdbcConnection.java
index 7e7e34a6..b4cc43f9 100644
--- a/plugin-jdbc/src/main/java/io/kestra/plugin/jdbc/AbstractJdbcConnection.java
+++ b/plugin-jdbc/src/main/java/io/kestra/plugin/jdbc/AbstractJdbcConnection.java
@@ -5,11 +5,13 @@
 import io.kestra.core.models.tasks.Task;
 import io.kestra.core.runners.RunContext;
 import io.swagger.v3.oas.annotations.media.Schema;
-import lombok.*;
+import lombok.EqualsAndHashCode;
+import lombok.Getter;
+import lombok.NoArgsConstructor;
+import lombok.ToString;
 import lombok.experimental.SuperBuilder;
 
 import java.io.IOException;
-import java.nio.file.Path;
 import java.sql.Connection;
 import java.sql.DriverManager;
 import java.sql.SQLException;
@@ -39,22 +41,8 @@ public abstract class AbstractJdbcConnection extends Task {
     @PluginProperty(dynamic = true)
     protected String password;
 
-    @Schema(
-        title = "If autocommit is enabled",
-        description = "Sets this connection's auto-commit mode to the given state. If a connection is in auto-commit " +
-            "mode, then all its SQL statements will be executed and committed as individual transactions. Otherwise, " +
-            "its SQL statements are grouped into transactions that are terminated by a call to either the method commit" +
-            "or the method rollback. By default, new connections are in auto-commit mode except if you are using a " +
-            "`store` properties that will disabled autocommit whenever this properties values."
-    )
-    @PluginProperty(dynamic = false)
-    protected final Boolean autoCommit = true;
-
-    @Getter(AccessLevel.NONE)
-    private transient Path cleanupDirectory;
-
     /**
-     * JDBC driver may be auto-registered. See https://docs.oracle.com/javase/8/docs/api/java/sql/DriverManager.html
+     * JDBC driver may be auto-registered. See DriverManager
      *
      * @throws SQLException registerDrivers failed
      */
diff --git a/plugin-jdbc/src/main/java/io/kestra/plugin/jdbc/AbstractJdbcQuery.java b/plugin-jdbc/src/main/java/io/kestra/plugin/jdbc/AbstractJdbcQuery.java
index 4eaa51f6..4ed5faed 100644
--- a/plugin-jdbc/src/main/java/io/kestra/plugin/jdbc/AbstractJdbcQuery.java
+++ b/plugin-jdbc/src/main/java/io/kestra/plugin/jdbc/AbstractJdbcQuery.java
@@ -16,7 +16,10 @@
 import java.io.FileWriter;
 import java.io.IOException;
 import java.net.URI;
-import java.sql.*;
+import java.sql.Connection;
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.sql.Statement;
 import java.time.ZoneId;
 import java.util.*;
 import java.util.function.Consumer;
@@ -26,7 +29,7 @@
 @EqualsAndHashCode
 @Getter
 @NoArgsConstructor
-public abstract class AbstractJdbcQuery extends AbstractJdbcConnection {
+public abstract class AbstractJdbcQuery extends AbstractJdbcStatement {
     @Schema(
         title = "The sql query to run"
     )
@@ -59,9 +62,15 @@ public abstract class AbstractJdbcQuery extends AbstractJdbcConnection {
     private final Boolean fetch = false;
 
     @Schema(
-        title = "The time zone id to use for date/time manipulation. Default value is the worker default zone id."
+        title = "Whether autocommit is enabled",
+        description = "Sets this connection's auto-commit mode to the given state. If a connection is in auto-commit " +
+            "mode, then all its SQL statements will be executed and committed as individual transactions. Otherwise, " +
+            "its SQL statements are grouped into transactions that are terminated by a call to either the commit " +
+            "or the rollback method. By default, new connections are in auto-commit mode, except when the `store` " +
+            "property is used, which disables autocommit."
     )
-    private String timeZoneId;
+    @PluginProperty(dynamic = false)
+    protected final Boolean autoCommit = true;
 
     @Schema(
         title = "Number of rows that should be fetched",
@@ -84,13 +93,7 @@ public abstract class AbstractJdbcQuery extends AbstractJdbcConnection {
 
     public AbstractJdbcQuery.Output run(RunContext runContext) throws Exception {
         Logger logger = runContext.logger();
-
-        ZoneId zoneId = TimeZone.getDefault().toZoneId();
-        if (this.timeZoneId != null) {
-            zoneId = ZoneId.of(timeZoneId);
-        }
-
-        AbstractCellConverter cellConverter = getCellConverter(zoneId);
+        AbstractCellConverter cellConverter = getCellConverter(this.zoneId());
 
         try (
             Connection conn = this.connection(runContext);
diff --git a/plugin-jdbc/src/main/java/io/kestra/plugin/jdbc/AbstractJdbcStatement.java b/plugin-jdbc/src/main/java/io/kestra/plugin/jdbc/AbstractJdbcStatement.java
new file mode 100644
index 00000000..7c9298d9
--- /dev/null
+++ b/plugin-jdbc/src/main/java/io/kestra/plugin/jdbc/AbstractJdbcStatement.java
@@ -0,0 +1,32 @@
+package io.kestra.plugin.jdbc;
+
+import io.swagger.v3.oas.annotations.media.Schema;
+import lombok.EqualsAndHashCode;
+import lombok.Getter;
+import lombok.NoArgsConstructor;
+import lombok.ToString;
+import lombok.experimental.SuperBuilder;
+
+import java.time.ZoneId;
+import java.util.TimeZone;
+
+@SuperBuilder
+@ToString
+@EqualsAndHashCode
+@Getter
+@NoArgsConstructor
+public abstract class AbstractJdbcStatement extends AbstractJdbcConnection {
+    @Schema(
+        title = "The time zone id to use for date/time manipulation. Default value is the worker default zone id."
+    )
+    private String timeZoneId;
+
+    protected ZoneId zoneId() {
+        ZoneId zoneId = TimeZone.getDefault().toZoneId();
+        if (this.timeZoneId != null) {
+            zoneId = ZoneId.of(timeZoneId);
+        }
+
+        return zoneId;
+    }
+}
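
Below is a minimal flow sketch (not part of the patch) showing how the two new tasks might be wired together. Only the task property names and output behaviour come from the classes and UploadDownloadTest above; the flow id, namespace, account URL, credentials, and input wiring are placeholders and assume the standard Kestra flow YAML:

    id: snowflake_stage_example
    namespace: io.kestra.tests

    inputs:
      - name: file
        type: FILE

    tasks:
      - id: upload
        type: io.kestra.plugin.jdbc.snowflake.Upload
        url: jdbc:snowflake://<account>.snowflakecomputing.com/?loginTimeout=3
        username: <username>
        password: <password>
        warehouse: COMPUTE_WH
        database: UNITTEST
        schema: public
        stageName: UNITSTAGE
        prefix: myPrefix
        fileName: destFile.csv
        from: "{{ inputs.file }}"

      - id: download
        type: io.kestra.plugin.jdbc.snowflake.Download
        url: jdbc:snowflake://<account>.snowflakecomputing.com/?loginTimeout=3
        username: <username>
        password: <password>
        warehouse: COMPUTE_WH
        database: UNITTEST
        schema: public
        stageName: UNITSTAGE
        fileName: "{{ outputs.upload.uri }}"

As in the test, the Download task reuses the URI returned by Upload, which already contains the prefix and the ".gz" suffix appended when compress is left at its default of true.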