From 98638712d20667755f70daa74c73f93e7d5167d1 Mon Sep 17 00:00:00 2001 From: davidu1975 Date: Mon, 16 Dec 2024 10:44:17 +0000 Subject: [PATCH] refactor add write option --- dataservices/management/commands/helpers.py | 4 +- .../import_dbt_investment_opportunities.py | 4 +- .../management/commands/import_dbt_sectors.py | 4 +- ...import_eyb_business_cluster_information.py | 145 +++++++++--------- .../commands/import_eyb_rent_data.py | 4 +- .../commands/import_eyb_salary_data.py | 4 +- .../commands/import_postcodes_from_s3.py | 4 +- .../import_sectors_gva_value_bands.py | 4 +- 8 files changed, 89 insertions(+), 84 deletions(-) diff --git a/dataservices/management/commands/helpers.py b/dataservices/management/commands/helpers.py index 95510ce9..a3b6d42f 100644 --- a/dataservices/management/commands/helpers.py +++ b/dataservices/management/commands/helpers.py @@ -102,11 +102,11 @@ def save_import_data(self, data): def handle(self, *args, **options): if not options['write']: - data = self.load_data(save_data=False) + data = self.load_data(delete_temp_tables=True) prefix = 'Would create' else: prefix = 'Created' - data = self.load_data(save_data=True) + data = self.load_data(delete_temp_tables=False) self.save_import_data(data) if isinstance(data, list): diff --git a/dataservices/management/commands/import_dbt_investment_opportunities.py b/dataservices/management/commands/import_dbt_investment_opportunities.py index 5237a661..6f6d6574 100644 --- a/dataservices/management/commands/import_dbt_investment_opportunities.py +++ b/dataservices/management/commands/import_dbt_investment_opportunities.py @@ -66,7 +66,7 @@ class Command(BaseS3IngestionCommand, S3DownloadMixin): help = 'Import DBT investment opportunities data from s3' - def load_data(self, save_data=True, *args, **options): + def load_data(self, delete_temp_tables=True, *args, **options): data = self.do_handle(prefix=settings.INVESTMENT_OPPORTUNITIES_S3_PREFIX) return data @@ -84,3 +84,5 @@ def batches(_): yield get_investment_opportunities_batch(data, data_table) ingest_data(engine, metadata, on_before_visible, batches) + + return data diff --git a/dataservices/management/commands/import_dbt_sectors.py b/dataservices/management/commands/import_dbt_sectors.py index 7d4c229c..b10b71fc 100644 --- a/dataservices/management/commands/import_dbt_sectors.py +++ b/dataservices/management/commands/import_dbt_sectors.py @@ -55,7 +55,7 @@ class Command(BaseS3IngestionCommand, S3DownloadMixin): help = 'Import DBT Sector list data from s3' - def load_data(self, save_data=True, *args, **options): + def load_data(self, delete_temp_tables=True, *args, **options): data = self.do_handle( prefix=settings.DBT_SECTOR_S3_PREFIX, ) @@ -76,3 +76,5 @@ def batches(_): yield get_dbtsector_table_batch(data, data_table) ingest_data(engine, metadata, on_before_visible, batches) + + return data diff --git a/dataservices/management/commands/import_eyb_business_cluster_information.py b/dataservices/management/commands/import_eyb_business_cluster_information.py index d0d5ca3d..26e7ca38 100644 --- a/dataservices/management/commands/import_eyb_business_cluster_information.py +++ b/dataservices/management/commands/import_eyb_business_cluster_information.py @@ -272,88 +272,82 @@ class Command(BaseS3IngestionCommand, S3DownloadMixin): help = 'Import ONS total UK business and employee counts per region and section, 2 and 5 digit Standard Industrial Classification' # noqa:E501 - def save_tmp_table_data(self, save_data): - data = self.do_handle( - 
prefix=settings.NOMIS_UK_BUSINESS_EMPLOYEE_COUNTS_FROM_S3_PREFIX, - ) - save_uk_business_employee_counts_tmp_data(data) - data = self.do_handle( - prefix=settings.REF_SIC_CODES_MAPPING_FROM_S3_PREFIX, - ) - save_ref_sic_codes_mapping_data(data) - data = self.do_handle( - prefix=settings.SECTOR_REFERENCE_DATASET_FROM_S3_PREFIX, - ) - save_sector_reference_dataset_data(data) - if save_data: - return + def load_data(self, delete_temp_tables=True, *args, **options): + try: + data = self.do_handle( + prefix=settings.NOMIS_UK_BUSINESS_EMPLOYEE_COUNTS_FROM_S3_PREFIX, + ) + save_uk_business_employee_counts_tmp_data(data) + data = self.do_handle( + prefix=settings.REF_SIC_CODES_MAPPING_FROM_S3_PREFIX, + ) + save_ref_sic_codes_mapping_data(data) + data = self.do_handle( + prefix=settings.SECTOR_REFERENCE_DATASET_FROM_S3_PREFIX, + ) + save_sector_reference_dataset_data(data) - data = self.save_import_data(save_data=save_data) - return data + return self.save_import_data(delete_temp_tables=delete_temp_tables) - def load_data(self, save_data=True, *args, **options): - try: - data = self.save_tmp_table_data(save_data) - return data except Exception: logger.exception("import_eyb_business_cluster_information failed to ingest data from s3") finally: - self.delete_temp_tables(TEMP_TABLES) - - def save_import_data(self, data=[], save_data=True): - if save_data: - self.save_tmp_table_data(save_data) - - sql = """ - SELECT - nubec.geo_description, - nubec.geo_code, - nubec.sic_code, - nubec.sic_description, - nubec.total_business_count, - nubec.business_count_release_year, - nubec.total_employee_count, - nubec.employee_count_release_year, - sector_mapping.dbt_full_sector_name, - sector_mapping.dbt_sector_name - FROM public.dataservices_tmp_eybbusinessclusterinformation nubec - LEFT JOIN ( - SELECT - dataservices_tmp_sector_reference.full_sector_name as dbt_full_sector_name, - dataservices_tmp_sector_reference.field_04 as dbt_sector_name, - -- necessary because sic codes are stored as integer in source table meaning leading 0 was dropped - substring(((dataservices_tmp_ref_sic_codes_mapping.sic_code + 100000)::varchar) from 2 for 5) as five_digit_sic -- # noqa:E501 - FROM public.dataservices_tmp_ref_sic_codes_mapping - INNER JOIN public.dataservices_tmp_sector_reference ON public.dataservices_tmp_ref_sic_codes_mapping.dit_sector_list_id = public.dataservices_tmp_sector_reference.id - ) as sector_mapping - ON nubec.sic_code = sector_mapping.five_digit_sic - """ + if delete_temp_tables: + self.delete_temp_tables(TEMP_TABLES) + + def save_import_data(self, data=[], delete_temp_tables=True): engine = sa.create_engine(settings.DATABASE_URL, future=True) - data = [] - - with engine.connect() as connection: - chunks = pd.read_sql_query(sa.text(sql), connection, chunksize=5000) - - for chunk in chunks: - for _, row in chunk.iterrows(): - data.append( - { - 'geo_description': row.geo_description, - 'geo_code': row.geo_code, - 'sic_code': row.sic_code, - 'sic_description': row.sic_description, - 'total_business_count': row.total_business_count, - 'business_count_release_year': row.business_count_release_year, - 'total_employee_count': row.total_employee_count, - 'employee_count_release_year': row.employee_count_release_year, - 'dbt_full_sector_name': row.dbt_full_sector_name, - 'dbt_sector_name': row.dbt_sector_name, - } - ) - - if not save_data: + if not data: + sql = """ + SELECT + nubec.geo_description, + nubec.geo_code, + nubec.sic_code, + nubec.sic_description, + nubec.total_business_count, + 
nubec.business_count_release_year, + nubec.total_employee_count, + nubec.employee_count_release_year, + sector_mapping.dbt_full_sector_name, + sector_mapping.dbt_sector_name + FROM public.dataservices_tmp_eybbusinessclusterinformation nubec + LEFT JOIN ( + SELECT + dataservices_tmp_sector_reference.full_sector_name as dbt_full_sector_name, + dataservices_tmp_sector_reference.field_04 as dbt_sector_name, + -- necessary because sic codes are stored as integer in source table meaning leading 0 was dropped + substring(((dataservices_tmp_ref_sic_codes_mapping.sic_code + 100000)::varchar) from 2 for 5) as five_digit_sic -- # noqa:E501 + FROM public.dataservices_tmp_ref_sic_codes_mapping + INNER JOIN public.dataservices_tmp_sector_reference ON public.dataservices_tmp_ref_sic_codes_mapping.dit_sector_list_id = public.dataservices_tmp_sector_reference.id + ) as sector_mapping + ON nubec.sic_code = sector_mapping.five_digit_sic + """ + + data = [] + + with engine.connect() as connection: + chunks = pd.read_sql_query(sa.text(sql), connection, chunksize=5000) + + for chunk in chunks: + for _, row in chunk.iterrows(): + data.append( + { + 'geo_description': row.geo_description, + 'geo_code': row.geo_code, + 'sic_code': row.sic_code, + 'sic_description': row.sic_description, + 'total_business_count': row.total_business_count, + 'business_count_release_year': row.business_count_release_year, + 'total_employee_count': row.total_employee_count, + 'employee_count_release_year': row.employee_count_release_year, + 'dbt_full_sector_name': row.dbt_full_sector_name, + 'dbt_sector_name': row.dbt_sector_name, + } + ) + + if not delete_temp_tables: return data metadata = sa.MetaData() @@ -370,7 +364,6 @@ def batches(_): ingest_data(engine, metadata, on_before_visible, batches) - if save_data: - self.delete_temp_tables(TEMP_TABLES) + self.delete_temp_tables(TEMP_TABLES) return data diff --git a/dataservices/management/commands/import_eyb_rent_data.py b/dataservices/management/commands/import_eyb_rent_data.py index cb4806b7..35b2ddcf 100644 --- a/dataservices/management/commands/import_eyb_rent_data.py +++ b/dataservices/management/commands/import_eyb_rent_data.py @@ -70,7 +70,7 @@ class Command(BaseS3IngestionCommand, S3DownloadMixin): help = 'Import Statista commercial rent data from s3' - def load_data(self, save_data=True, *args, **options): + def load_data(self, delete_temp_tables=True, *args, **options): data = self.do_handle( prefix=settings.EYB_RENT_S3_PREFIX, ) @@ -91,3 +91,5 @@ def batches(_): yield get_eyb_rent_batch(data, data_table) ingest_data(engine, metadata, on_before_visible, batches) + + return data diff --git a/dataservices/management/commands/import_eyb_salary_data.py b/dataservices/management/commands/import_eyb_salary_data.py index 0b01524a..fe676a3a 100644 --- a/dataservices/management/commands/import_eyb_salary_data.py +++ b/dataservices/management/commands/import_eyb_salary_data.py @@ -66,7 +66,7 @@ class Command(BaseS3IngestionCommand, S3DownloadMixin): help = 'Import Statista salary data from s3' - def load_data(self, save_data=True, *args, **options): + def load_data(self, delete_temp_tables=True, *args, **options): data = self.do_handle( prefix=settings.EYB_SALARY_S3_PREFIX, ) @@ -87,3 +87,5 @@ def batches(_): yield get_eyb_salary_batch(data, data_table) ingest_data(engine, metadata, on_before_visible, batches) + + return data diff --git a/dataservices/management/commands/import_postcodes_from_s3.py b/dataservices/management/commands/import_postcodes_from_s3.py index c470dbad..b47d29d4 
100644 --- a/dataservices/management/commands/import_postcodes_from_s3.py +++ b/dataservices/management/commands/import_postcodes_from_s3.py @@ -76,7 +76,7 @@ class Command(BaseS3IngestionCommand, S3DownloadMixin): help = 'Import Postcode data from s3' - def load_data(self, save_data=True, *args, **options): + def load_data(self, delete_temp_tables=True, *args, **options): data = self.do_handle( prefix=settings.POSTCODE_FROM_S3_PREFIX, ) @@ -96,3 +96,5 @@ def batches(_): yield get_postcode_table_batch(data, data_table) ingest_data(engine, metadata, on_before_visible, batches) + + return data diff --git a/dataservices/management/commands/import_sectors_gva_value_bands.py b/dataservices/management/commands/import_sectors_gva_value_bands.py index 16501422..27b5c93b 100644 --- a/dataservices/management/commands/import_sectors_gva_value_bands.py +++ b/dataservices/management/commands/import_sectors_gva_value_bands.py @@ -62,7 +62,7 @@ class Command(BaseS3IngestionCommand, S3DownloadMixin): help = 'Import sector GVA value bands data from s3' - def load_data(self, save_data=True, *args, **options): + def load_data(self, delete_temp_tables=True, *args, **options): data = self.do_handle( prefix=settings.DBT_SECTORS_GVA_VALUE_BANDS_DATA_S3_PREFIX, ) @@ -83,3 +83,5 @@ def batches(_): yield get_sectors_gva_value_bands_batch(data, data_table) ingest_data(engine, metadata, on_before_visible, batches) + + return data
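
For context, the dry-run/write flow these importers now share can be sketched as below. This is an illustrative example only, not code from the patch: it assumes a Django management command exposing a --write flag (as options['write'] in helpers.py implies) and BaseS3IngestionCommand-style load_data / save_import_data hooks; the method bodies are placeholders, not the real S3 ingestion logic.

# Illustrative sketch, not part of the patch: mirrors the handle() flow in
# dataservices/management/commands/helpers.py after this change.
from django.core.management.base import BaseCommand


class ExampleS3IngestionCommand(BaseCommand):
    help = 'Sketch of the dry-run vs --write flow used by these importers'

    def add_arguments(self, parser):
        # Assumed flag name; the real base command reads options['write'].
        parser.add_argument('--write', action='store_true', help='Persist the imported data')

    def load_data(self, delete_temp_tables=True, *args, **options):
        # Real commands download from S3 via self.do_handle(...), stage rows
        # in temp tables, and (per this patch) always return what they ingested.
        data = [{'example': 'row'}]  # placeholder for the S3 payload
        if delete_temp_tables:
            pass  # real commands drop their temp tables here
        return data

    def save_import_data(self, data):
        # Real commands copy the staged rows into the live tables.
        pass

    def handle(self, *args, **options):
        if not options['write']:
            # Dry run: temp tables are cleaned up and nothing is persisted.
            data = self.load_data(delete_temp_tables=True)
            prefix = 'Would create'
        else:
            # Write run: keep temp tables until save_import_data has read them.
            data = self.load_data(delete_temp_tables=False)
            prefix = 'Created'
            self.save_import_data(data)
        count = len(data) if isinstance(data, list) else 0
        self.stdout.write(f'{prefix} {count} records')

A dry run would then look like `manage.py import_dbt_sectors` and a persisted run like `manage.py import_dbt_sectors --write`; the flag name and the final count reporting are assumptions inferred from options['write'] and the isinstance(data, list) check in helpers.py, not confirmed by this patch.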