diff --git a/src/UnloadCopyUtility/redshift_unload_copy.py b/src/UnloadCopyUtility/redshift_unload_copy.py
index 30b6e493..278cc6d6 100755
--- a/src/UnloadCopyUtility/redshift_unload_copy.py
+++ b/src/UnloadCopyUtility/redshift_unload_copy.py
@@ -16,7 +16,7 @@
 from global_config import GlobalConfigParametersReader, config_parameters
 from util.s3_utils import S3Helper, S3Details
 from util.redshift_cluster import RedshiftCluster
-from util.resources import ResourceFactory, TableResource, DBResource
+from util.resources import ResourceFactory, TableResource, DBResource, SchemaResource
 from util.tasks import TaskManager, FailIfResourceDoesNotExistsTask, CreateIfTargetDoesNotExistTask, \
     FailIfResourceClusterDoesNotExistsTask, UnloadDataToS3Task, CopyDataFromS3Task, CleanupS3StagingAreaTask, \
     NoOperationTask
@@ -77,7 +77,19 @@ def __init__(self,
 
         src_config = self.config_helper.config['unloadSource']
         dest_config = self.config_helper.config['copyTarget']
-        if(src_config['tableNames']):
+
+        if "tableNames" in src_config or "tableName" in src_config:
+            self.setup_table_tasks(src_config, dest_config, global_config_values)
+        elif "schemaNames" in src_config or "schemaName" in src_config:
+            self.setup_schema_tasks(src_config, dest_config, global_config_values)
+        else:
+            logger.fatal("Invalid configuration: must configure either a table ('tableName(s)') or a schema ('schemaName(s)')")
+            raise ValueError("Invalid configuration")
+
+        self.task_manager.run()
+
+    def setup_table_tasks(self, src_config, dest_config, global_config_values):
+        if src_config.get('tableNames'):
             src_tables = src_config['tableNames']
             dest_tables = dest_config['tableNames']
             logger.info("Migrating multiple tables")
@@ -97,7 +109,36 @@ def __init__(self,
             destination = ResourceFactory.get_target_resource_from_config_helper(self.config_helper, self.region)
             self.add_src_dest_tasks(source,destination,global_config_values)
 
-        self.task_manager.run()
+    def setup_schema_tasks(self, src_config, dest_config, global_config_values):
+        if src_config.get('schemaNames'):
+            src_schemas = src_config['schemaNames']
+            dest_schemas = dest_config.get('schemaNames', [])
+            logger.info("Migrating multiple schemas")
+            if not dest_schemas or len(src_schemas) != len(dest_schemas):
+                logger.fatal(
+                    "When migrating multiple schemas, the 'schemaNames' property must be set in both unloadSource and copyTarget, and the two lists must have the same length"
+                )
+                raise NotImplementedError
+            for idx in range(len(src_schemas)):
+                logger.info("Migrating schema: " + src_schemas[idx])
+                src_config['schemaName'] = src_schemas[idx]
+                dest_config['schemaName'] = dest_schemas[idx]
+                source: SchemaResource = ResourceFactory.get_source_resource_from_config_helper(
+                    self.config_helper
+                )
+                destination: SchemaResource = ResourceFactory.get_target_resource_from_config_helper(
+                    self.config_helper
+                )
+                self.add_src_dest_tasks(source, destination, global_config_values)
+        else:
+            logger.info("Migrating a single schema")
+            source: SchemaResource = ResourceFactory.get_source_resource_from_config_helper(
+                self.config_helper
+            )
+            destination: SchemaResource = ResourceFactory.get_target_resource_from_config_helper(
+                self.config_helper
+            )
+            self.add_src_dest_tasks(source, destination, global_config_values)
 
     def add_src_dest_tasks(self,source,destination,global_config_values):
         # TODO: Check whether both resources are of type table if that is not the case then perform other scenario's
@@ -112,6 +153,11 @@ def add_src_dest_tasks(self,source,destination,global_config_values):
                 logger.fatal('Destination should be a database resource')
                 raise NotImplementedError
             pass
+        elif isinstance(source, SchemaResource):
+            if not isinstance(destination, SchemaResource):
+                logger.fatal("Destination should be a schema resource")
+                raise NotImplementedError
+            self.add_schema_migration(source, destination, global_config_values)
         else:
             # TODO: add additional scenario's
             # For example if both resources are of type schema then create target schema and migrate all tables
@@ -165,6 +211,26 @@ def add_table_migration(self, source, destination, global_config_values):
         s3_cleanup = CleanupS3StagingAreaTask(s3_details)
         self.task_manager.add_task(s3_cleanup, dependencies=copy_data)
 
+    def add_schema_migration(self, source: SchemaResource, destination: SchemaResource, global_config_values):
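+        """
+        Fan a schema migration out into per-table migrations: every table
+        reported by the source schema is queued onto the existing per-table
+        UNLOAD/COPY task pipeline.
+        """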
+        tables = source.list_tables()
+        for table_name in tables:
+            src_table = TableResource(
+                source.get_cluster(),
+                source.get_schema(),
+                table_name
+            )
+            dest_table = TableResource(
+                destination.get_cluster(),
+                destination.get_schema(),
+                table_name
+            )
+            self.add_table_migration(src_table, dest_table, global_config_values)
+
 
 def main(args):
     global region
diff --git a/src/UnloadCopyUtility/util/resources.py b/src/UnloadCopyUtility/util/resources.py
index 735412a5..89930cf3 100644
--- a/src/UnloadCopyUtility/util/resources.py
+++ b/src/UnloadCopyUtility/util/resources.py
@@ -1,6 +1,7 @@
 import re
 from abc import abstractmethod
 import logging
+from typing import List
 from util.child_object import ChildObject
 from util.kms_helper import KMSHelper
@@ -247,6 +248,18 @@ def __str__(self):
     def get_statement_to_retrieve_ddl_create_statement_text(self):
         return SchemaDDLHelper().get_schema_ddl_SQL(schema_name=self.get_schema())
 
+    def list_tables(self) -> List[str]:
+        sql = f"""
+            SHOW TABLES FROM SCHEMA {self.get_db()}.{self.get_schema()};
+        """
+        sql = SQLTextHelper.get_sql_without_commands_newlines_and_whitespace(sql)
+        rows = self.get_cluster().get_query_full_result_as_list_of_dict(
+            sql,
+            self.get_cluster().db,
+        )
+        # Keep base tables only; SHOW TABLES also reports views.
+        return [row["table_name"] for row in rows if row["table_type"] == "TABLE"]
+
     def clone_structure_from(self, other):
         ddl = other.get_create_sql(generate=True)
         if self.get_schema() != other.get_schema():
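
For reference, a sketch of how a schema-level run would be configured. Only the keys this patch actually reads are shown; everything else in the config file (cluster endpoints, credentials, the s3Staging section) is unchanged by this change and omitted here, and all values are placeholders. A single-schema run uses a scalar "schemaName" key in both sections instead.

    # Hypothetical, pared-down view of the parsed config as seen by
    # self.config_helper.config -- illustrative only, not part of the patch.
    config = {
        "unloadSource": {
            "schemaNames": ["public", "analytics"],  # schemas to unload
        },
        "copyTarget": {
            "schemaNames": ["public", "analytics"],  # same length, pairwise targets
        },
    }

For each schema pair, setup_schema_tasks resolves source and target SchemaResource objects, add_schema_migration enumerates the schema's tables via SHOW TABLES FROM SCHEMA, and each table is handed to the existing add_table_migration, so every table gets the same unload/copy/cleanup task chain as a table-level run.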