diff --git a/cache/.gitignore b/cache/.gitignore new file mode 100644 index 0000000..d6b7ef3 --- /dev/null +++ b/cache/.gitignore @@ -0,0 +1,2 @@ +* +!.gitignore diff --git a/output/.gitignore b/output/.gitignore new file mode 100644 index 0000000..d6b7ef3 --- /dev/null +++ b/output/.gitignore @@ -0,0 +1,2 @@ +* +!.gitignore diff --git a/pg_nearest_city/__init__.py b/pg_nearest_city/__init__.py index ad1c9f2..c70ef0e 100644 --- a/pg_nearest_city/__init__.py +++ b/pg_nearest_city/__init__.py @@ -2,6 +2,6 @@ from ._async.nearest_city import AsyncNearestCity from ._sync.nearest_city import NearestCity -from .base_nearest_city import DbConfig, Location +from .base_nearest_city import DbConfig, Location, geo_test_cases -__all__ = ["NearestCity", "AsyncNearestCity", "DbConfig", "Location"] +__all__ = ["NearestCity", "AsyncNearestCity", "DbConfig", "Location", "geo_test_cases"] diff --git a/pg_nearest_city/base_nearest_city.py b/pg_nearest_city/base_nearest_city.py index 7e9d93b..71ace02 100644 --- a/pg_nearest_city/base_nearest_city.py +++ b/pg_nearest_city/base_nearest_city.py @@ -125,49 +125,107 @@ def validate_coordinates(lon: float, lat: float) -> Optional[Location]: @staticmethod def _get_tableexistence_query() -> sql.SQL: """Check if a table exists via SQL.""" - return sql.SQL(""" + return sql.SQL( + """ SELECT EXISTS ( SELECT FROM information_schema.tables WHERE table_name = 'pg_nearest_city_geocoding' ); - """) + """ + ) @staticmethod def _get_table_structure_query() -> sql.SQL: """Get the fields from the pg_nearest_city_geocoding table.""" - return sql.SQL(""" + return sql.SQL( + """ SELECT column_name, data_type FROM information_schema.columns WHERE table_name = 'pg_nearest_city_geocoding' - """) + """ + ) @staticmethod def _get_data_completeness_query() -> sql.SQL: """Check data was loaded into correct structure.""" - return sql.SQL(""" + return sql.SQL( + """ SELECT COUNT(*) as total_cities, COUNT(*) FILTER (WHERE voronoi IS NOT NULL) as cities_with_voronoi FROM pg_nearest_city_geocoding; - """) + """ + ) @staticmethod def _get_spatial_index_check_query() -> sql.SQL: """Check index was created correctly.""" - return sql.SQL(""" + return sql.SQL( + """ SELECT EXISTS ( SELECT FROM pg_indexes WHERE tablename = 'pg_nearest_city_geocoding' AND indexname = 'geocoding_voronoi_idx' ); - """) + """ + ) @staticmethod def _get_reverse_geocoding_query(lon: float, lat: float): """The query to do the reverse geocode!""" - return sql.SQL(""" - SELECT city, country, lat, lon - FROM pg_nearest_city_geocoding - WHERE ST_Contains(voronoi, ST_SetSRID(ST_MakePoint({}, {}), 4326)) + return sql.SQL( + """ + WITH query_point AS ( + SELECT ST_SetSRID( + ST_MakePoint({}, {}), 4326) AS geom + ) + SELECT g.city, g.country, g.lon, g.lat + FROM query_point qp + JOIN country c ON ST_ContainsProperly(c.geom, qp.geom) + JOIN geocoding g ON c.alpha2 = g.country + ORDER BY g.geom <-> qp.geom LIMIT 1 - """).format(sql.Literal(lon), sql.Literal(lat)) + """ + ).format(sql.Literal(lon), sql.Literal(lat)) + + +@dataclass +class GeoTestCase: + """Class representing points with their expected values. + + The given points lie either close to country borders, + are islands, or both (St. Martin / Sint Marteen). + + All longitude / latitudes are in EPSG 4326. 
+ + lon: longitude + lat: latitude + expected city: name of city expected, exactly as stored in the DB + expected country: ISO 3166-1 alpha2 code of the country expected + + """ + + lon: float + lat: float + expected_city: str + expected_country: str + + +geo_test_cases: list[GeoTestCase] = [ + GeoTestCase(7.397405, 43.750402, "La Turbie", "FR"), + GeoTestCase(-79.0647, 43.0896, "Niagara Falls", "US"), + GeoTestCase(-117.1221, 32.5422, "Imperial Beach", "US"), + GeoTestCase(-5.3525, 36.1658, "La Línea de la Concepción", "ES"), + GeoTestCase(12.4534, 41.9033, "Vatican City", "VA"), + GeoTestCase(-63.0822, 18.0731, "Marigot", "MF"), + GeoTestCase(-63.11852, 18.03783, "Simpson Bay Village", "SX"), + GeoTestCase(-63.0458, 18.0255, "Philipsburg", "SX"), + GeoTestCase(7.6194, 47.5948, "Weil am Rhein", "DE"), + GeoTestCase(10.2640, 47.1274, "St Anton am Arlberg", "AT"), + GeoTestCase(4.9312, 51.4478, "Baarle-Nassau", "NL"), + GeoTestCase(-6.3390, 54.1751, "Newry", "GB"), + GeoTestCase(55.478017, -21.297475, "Saint-Pierre", "RE"), + GeoTestCase(-6.271183, 55.687669, "Bowmore", "GB"), + GeoTestCase(88.136284, 26.934422, "Mirik", "IN"), + GeoTestCase(114.060691, 22.512898, "San Tin", "HK"), +] diff --git a/pg_nearest_city/db/data_cleanup.py b/pg_nearest_city/db/data_cleanup.py new file mode 100644 index 0000000..e678e06 --- /dev/null +++ b/pg_nearest_city/db/data_cleanup.py @@ -0,0 +1,338 @@ +"""Data with known issues to be cleaned.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from enum import Enum + +from psycopg import sql + +from pg_nearest_city.db.tables import get_all_table_classes + + +class _Comment(Enum): + COORDINATES = "COORDINATES" + ERRATUM = "ERRATUM" + MISSING = "MISSING" + SPELLING = "SPELLING" + + +class _DML(Enum): + DELETE = "DELETE" + INSERT = "INSERT" # not yet implemented + UPDATE = "UPDATE" + + +class _PredicateComparison(Enum): + BETWEEN = "BETWEEN" + BETWEEN_SYM = "BETWEEN SYMMETRIC" + DISTINCT = "IS DISTINCT FROM" + EQUAL = "=" + FALSE = "FALSE" + GT = ">" + GTE = ">=" + IN = "IN" + LT = "<" + LTE = "<=" + NOT_BETWEEN = "NOT BETWEEN" + NOT_BETWEEN_SYM = "NOT BETWEEN SYMMETRIC" + NOT_DISTINCT = "IS NOT DISTINCT FROM" + NOT_EQ = "<>" + NOT_FALSE = "IS NOT FALSE" + NOT_IN = "NOT IN" + NOT_NULL = "IS NOT NULL" + NOT_TRUE = "IS NOT TRUE" + NOT_UNK = "IS NOT UNKNOWN" + NULL = "IS NULL" + TRUE = "IS TRUE" + UNK = "IS UNKNOWN" + + @property + def requires_two_values(self) -> bool: + """Return True if this comparison requires two values (e.g., BETWEEN).""" + return self in { + _PredicateComparison.BETWEEN, + _PredicateComparison.BETWEEN_SYM, + _PredicateComparison.NOT_BETWEEN, + _PredicateComparison.NOT_BETWEEN_SYM, + } + + @property + def requires_column_comparison(self) -> bool: + """Return True if this comparison can compare two columns.""" + return self in { + _PredicateComparison.DISTINCT, + _PredicateComparison.NOT_DISTINCT, + } + + @property + def requires_no_values(self) -> bool: + """Return True if this comparison requires no values (e.g., IS NULL).""" + return self in { + _PredicateComparison.FALSE, + _PredicateComparison.NOT_FALSE, + _PredicateComparison.NULL, + _PredicateComparison.NOT_NULL, + _PredicateComparison.TRUE, + _PredicateComparison.NOT_TRUE, + _PredicateComparison.UNK, + _PredicateComparison.NOT_UNK, + } + + @property + def supports_lists(self) -> bool: + """Return True if this comparison supports list values (e.g., IN).""" + return self in { + _PredicateComparison.IN, + _PredicateComparison.NOT_IN, + } + + +@dataclass 
+class PredicateData: + """Class representing a predicate for a query. + + col_name: name of column being used as a predicate + comparison: type of comparison for col_name + col_val: value to compare column against + col_val_2: additional value to compare column against (for BETWEEN, etc.) + col_name_2: second column name for column-to-column comparisons + + """ + + col_name: str + comparison: _PredicateComparison + col_name_2: str | None = None + col_val: float | int | str | list | None = None + col_val_2: float | int | str | None = None + + def __post_init__(self) -> None: + """Validate predicate data based on comparison type.""" + if self.comparison.requires_two_values: + if self.col_val_2 is None: + raise ValueError(f"Must provide col_val_2 for {self.comparison.value}") + + elif self.comparison.requires_no_values: + if ( + self.col_val is not None + or self.col_val_2 is not None + or self.col_name_2 is not None + ): + raise ValueError( + f"Must not provide any values for {self.comparison.value}" + ) + + elif self.comparison.supports_lists: + if self.col_val is None: + raise ValueError( + "Must provide col_val (list or single value) " + f"for {self.comparison.value}" + ) + + elif self.comparison.requires_column_comparison: + if self.col_val is None and self.col_name_2 is None: + raise ValueError( + "Must provide either col_val or col_name_2 " + f"for {self.comparison.value}" + ) + + else: + if self.col_val is None: + raise ValueError(f"Must provide col_val for {self.comparison.value}") + + +@dataclass +class RowData: + """Class representing a partial row transformation. + + comment: type of correction (purely for observation) + dml: type of _DML + tbl_name: table name containing the target column + col_name: column name to transform. + col_val: value to assign to the specified column + predicate_cols: one or more PredicateData instances + result_limit: maximum number of rows that to be affected (0 = no limit) + val_is_query: bool indicating whether col_val is a query to be parsed + """ + + comment: _Comment + dml: _DML + tbl_name: str + col_name: str | None = None + col_val: float | int | str | None = None + predicate_cols: list[PredicateData] = field(default_factory=list) + result_limit: int = 1 + val_is_query: bool = False + + def __post_init__(self): + """Performs validation of various supplied values.""" + if self.tbl_name not in [ + x.name for x in get_all_table_classes() + ] and not self.tbl_name.startswith("tmp_"): + raise ValueError(f"Table {self.tbl_name} does not exist") + + if self.dml is _DML.INSERT: + raise NotImplementedError("INSERT operations not yet implemented") + + if self.dml is not _DML.DELETE and ( + self.col_name is None or self.col_val is None + ): + raise ValueError( + f"Column name and value must be provided for {self.dml.value}" + ) + + if not self.predicate_cols: + raise ValueError("At least one predicate must be provided") + + if self.result_limit < 0: + raise ValueError("result_limit must be non-negative") + + +def format_predicate(predicate: PredicateData) -> sql.Composed: + """Format a single predicate into SQL.""" + col_name = sql.Identifier(predicate.col_name) + comparison = sql.SQL(predicate.comparison.value) + + if predicate.comparison.requires_no_values: + return sql.SQL("{col_name} {comparison}").format( + col_name=col_name, comparison=comparison + ) + + elif predicate.comparison.requires_two_values: + return sql.SQL("{col_name} {comparison} {val1} AND {val2}").format( + col_name=col_name, + comparison=comparison, + 
val1=sql.Literal(predicate.col_val), + val2=sql.Literal(predicate.col_val_2), + ) + + elif predicate.comparison.supports_lists and isinstance(predicate.col_val, list): + values = sql.SQL("({})").format( + sql.SQL(", ").join(sql.Literal(val) for val in predicate.col_val) + ) + return sql.SQL("{col_name} {comparison} {values}").format( + col_name=col_name, comparison=comparison, values=values + ) + + elif predicate.comparison.requires_column_comparison and predicate.col_name_2: + return sql.SQL("{col_name} {comparison} {col_name_2}").format( + col_name=col_name, + comparison=comparison, + col_name_2=sql.Identifier(predicate.col_name_2), + ) + + else: + return sql.SQL("{col_name} {comparison} {col_val}").format( + col_name=col_name, + comparison=comparison, + col_val=sql.Literal(predicate.col_val), + ) + + +PC = _PredicateComparison +PD = PredicateData +ROWS_TO_CLEAN: list[RowData] = [ + RowData( + comment=_Comment.SPELLING, + col_name="city", + col_val="Simpson Bay Village", + dml=_DML.UPDATE, + predicate_cols=[ + PD( + col_name="city", + col_val="Simson Bay Village", + comparison=PC.EQUAL, + ), + PD( + col_name="country", + col_val="SX", + comparison=PC.EQUAL, + ), + ], + tbl_name="geocoding", + result_limit=1, + val_is_query=False, + ), + RowData( + comment=_Comment.COORDINATES, + col_name="geom", + col_val=""" + ST_Difference( + geom, + (SELECT ST_Union(geom) FROM tmp_country_bounds_adm1) + )""", + dml=_DML.UPDATE, + predicate_cols=[ + PD( + col_name="alpha2", + col_val="CN", + comparison=PC.EQUAL, + ), + PD( + col_name="geom", + comparison=PC.NOT_NULL, + ), + ], + tbl_name="country", + result_limit=1, + val_is_query=True, + ), + RowData( + comment=_Comment.ERRATUM, + col_name="alpha3", + col_val="XKX", + dml=_DML.UPDATE, + predicate_cols=[ + PD( + col_name="alpha3", + col_val="XKO", + comparison=PC.EQUAL, + ), + ], + tbl_name="tmp_country_bounds_adm0", + result_limit=1, + val_is_query=False, + ), +] + +SQL_CLEAN_BASE_DEL: sql.SQL = sql.SQL("DELETE FROM {tbl_name} WHERE ") +SQL_CLEAN_BASE_UPD: sql.SQL = sql.SQL( + "UPDATE {tbl_name} SET {col_name} = {col_val} WHERE " +) +SQL_CLEAN_PREDICATES: sql.SQL = sql.SQL("{col_name} {comparison} {col_val}") + + +def make_queries(rows: list[RowData]) -> list[sql.Composed]: + """Generate SQL queries from row data specifications.""" + queries: list[sql.Composed] = [] + + for row in rows: + predicate_clauses = [ + format_predicate(predicate) for predicate in row.predicate_cols + ] + predicates = sql.SQL(" AND ").join(predicate_clauses) + + if row.dml is _DML.UPDATE: + base_query = SQL_CLEAN_BASE_UPD.format( + tbl_name=sql.Identifier(row.tbl_name), + col_name=sql.Identifier(row.col_name), + col_val=( + sql.Literal(row.col_val) + if not row.val_is_query + else sql.SQL(row.col_val) + ), + ) + elif row.dml is _DML.DELETE: + base_query = SQL_CLEAN_BASE_DEL.format( + tbl_name=sql.Identifier(row.tbl_name), + ) + elif row.dml is _DML.INSERT: + raise NotImplementedError("INSERT operations not yet implemented") + else: + raise ValueError(f"Unsupported _DML operation: {row.dml}") + + full_query = base_query + predicates + + queries.append(full_query) + + return queries diff --git a/pg_nearest_city/db/tables.py b/pg_nearest_city/db/tables.py index 5dab89f..dcddddf 100644 --- a/pg_nearest_city/db/tables.py +++ b/pg_nearest_city/db/tables.py @@ -28,11 +28,10 @@ class Country(BaseTable): CREATE TABLE country ( alpha2 CHAR(2) NOT NULL, alpha3 CHAR(3) NOT NULL, - numeric CHAR(3) NOT NULL, name TEXT NOT NULL, + geom GEOMETRY(MultiPolygon,4326) DEFAULT NULL, CONSTRAINT 
country_pkey PRIMARY KEY (alpha2), CONSTRAINT country_alpha3_unq UNIQUE (alpha3), - CONSTRAINT country_numeric_unq UNIQUE (numeric), CONSTRAINT country_name_len_chk CHECK ( char_length(name) <= 126 ) @@ -64,7 +63,6 @@ class Geocoding(BaseTable): geom GEOMETRY(Point,4326) GENERATED ALWAYS AS ( ST_SetSRID(ST_MakePoint(lon, lat), 4326) ) STORED, - voronoi GEOMETRY(Polygon,4326), CONSTRAINT geocoding_pkey PRIMARY KEY (id), CONSTRAINT geocoding_city_len_chk CHECK ( char_length(city) <= 126 diff --git a/pg_nearest_city/scripts/voronoi_generator.py b/pg_nearest_city/scripts/voronoi_generator.py index 71e4b05..dcf7e1e 100644 --- a/pg_nearest_city/scripts/voronoi_generator.py +++ b/pg_nearest_city/scripts/voronoi_generator.py @@ -11,19 +11,68 @@ import gzip import logging import os +import re import shutil +import subprocess import tempfile import urllib.request import zipfile -from dataclasses import dataclass +from collections import ChainMap +from dataclasses import dataclass, field +from datetime import datetime, timedelta from pathlib import Path from typing import Optional import psycopg +from pg_nearest_city.db.data_cleanup import ROWS_TO_CLEAN, make_queries from pg_nearest_city.db.tables import get_tables_in_creation_order from psycopg.rows import dict_row +@dataclass +class Header: + """Class to create a k:v pair instead of dataclasses' asdict method.""" + + key: str + value: str + + def _to_dict(self): + return {self.key: self.value} + + +@dataclass +class URLConfig: + """Class representing a data file. + + domain: the domain portion of the URL (i.e. after http[s]:// and before the first /) + path: the path portion of the URL (i.e. everything after the domain) + alpha3_column: the name of the column in the file with an ISO 3166-1 alpha-3 code + scheme: the scheme portion of the URL (i.e. http, https) + slug: the last portion of the URL (automatically generated) + zip_name: the name of the file once downloaded + headers: a list of Headers to be passed in (e.g. 
Referer, User-Agent) + + """ + + domain: str + path: str + alpha3_column: str = "" + scheme: str = "https" + slug: str = field(init=False) + zip_name: str = "" + headers: list[Header] = field(default_factory=list) + _headers: dict[str, str] = field(default_factory=dict) + url: str = field(init=False) + + def __post_init__(self): + """Creates necessary portions of dataclass from supplied values.""" + self._headers = dict(ChainMap(*[header._to_dict() for header in self.headers])) + self.domain = self.domain.rstrip("/") + self.path = self.path.lstrip("/") + self.slug = self.path.rsplit("/", maxsplit=1)[-1] + self.url = f"{self.scheme}://{self.domain}/{self.path}" + + @dataclass class Config: """Configuration parameters for the Voronoi generator.""" @@ -35,24 +84,67 @@ class Config: db_host: str = os.environ.get("PGNEAREST_DB_HOST", "localhost") db_port: int = int(os.environ.get("PGNEAREST_DB_PORT", "5432")) - # Data sources - geonames_url: str = "http://download.geonames.org/export/dump/cities1000.zip" - _zip_path: str = "" + # Cache configuration + cache_dir: Path = Path("./cache") + cur_dt: datetime = datetime.now() # Output configuration - output_dir: Path = Path("/data/output") # Default output directory + output_dir: Path = Path("/data/output") + cache_files: bool = True compress_output: bool = True # Processing options country_filter: Optional[str] = None # Optional filter for testing (e.g., "IT") + # Data sources + country_boundaries: URLConfig = field(init=False) + geonames: URLConfig = field(init=False) + + def __post_init__(self) -> None: + """Creates URLConfig objects for Config.""" + self.geonames: URLConfig = URLConfig( + domain="download.geonames.org", + path="export/dump/cities500.zip", + zip_name="cities500.zip", + ) + self.geonames_old: URLConfig = URLConfig( + domain="download.geonames.org", + path="export/dump/cities1000.zip", + zip_name="cities1000.zip", + ) + self.country_boundaries: URLConfig = URLConfig( + alpha3_column="GID_0", + domain="", + path="export/dump/ADM_0.zip", + zip_name="gadm_0.zip", + ) + self.country_boundaries_geo_boundaries: URLConfig = URLConfig( + alpha3_column="shapeGroup", + domain="www.github.com", + path="wmgeolab/geoBoundaries/raw/refs/tags/v6.0.0/releaseData/CGAZ/geoBoundariesCGAZ_ADM0.zip", + zip_name="geoBoundariesCGAZ_ADM0.zip", + ) + self.country_boundaries_natural_earth: URLConfig = URLConfig( + alpha3_column="ISO_A3", + domain="www.naturalearthdata.com", + # This path is not a typo, it really has http//www... 
after the domain + path="http//www.naturalearthdata.com/download/10m/cultural/ne_10m_admin_0_countries.zip", + zip_name="countries10m.zip", + headers=[ + Header(key="Referer", value="https://www.naturalearthdata.com/"), + Header(key="User-Agent", value="curl/8.7.1"), + ], + ) + def get_connection_string(self) -> str: """Generate PostgreSQL connection string.""" return f"postgresql://{self.db_user}:{self.db_password}@{self.db_host}:{self.db_port}/{self.db_name}" - def ensure_output_directories(self): - """Ensure all output directories exist.""" + def ensure_directories(self): + """Ensure all directories exist.""" # Ensure output directory exists + if self.cache_files: + self.cache_dir.mkdir(parents=True, exist_ok=True) self.output_dir.mkdir(parents=True, exist_ok=True) @@ -62,8 +154,9 @@ class VoronoiGenerator: def __init__(self, config: Config, logger: Optional[logging.Logger] = None): """Initialise class.""" self.config = config + self.cache_dir: Path = self.config.cache_dir self.logger = logger or logging.getLogger("voronoi_generator") - self.temp_dir = None + self.temp_dir: Path | None = None def run_pipeline(self): """Execute the full data pipeline.""" @@ -75,14 +168,23 @@ def run_pipeline(self): atexit.register(self._cleanup_temp_dir) try: - # Ensure output directories exist - self.config.ensure_output_directories() - - if not self.config._zip_path or not Path(self.config._zip_path).is_file(): - self._download_data() - self.config._zip_path = "" - self._extract_data() - self._clean_data() + # Ensure directories exist + self.config.ensure_directories() + + for url_config in (self.config.geonames, self.config.country_boundaries): + if ( + not self.config.cache_files + or not self.config.cache_dir / url_config.zip_name + or not Path(self.config.cache_dir / url_config.zip_name).is_file() + ): + self._download_data(url_config) + else: + self._check_cached_file_mtime(url_config) + self._copy_file_from_cache(url_config) + if self.config.cache_files: + self._copy_file_to_cache(url_config) + self._extract_data(url_config) + self._clean_geonames() # Connect to database with psycopg.connect( @@ -91,19 +193,27 @@ def run_pipeline(self): # Run each stage with the same connection self._setup_database(conn) self._setup_country_table(conn) - self._import_data(conn) + self._import_geonames(conn) + self._cleanup_geonames_db(conn) + self._import_country_boundaries(conn) self._create_country_index(conn) - self._create_spatial_index(conn) + self._create_spatial_indices(conn) self._compute_voronoi(conn) self._export_wkb(conn) + # Can't perform VACUUM inside of a transaction + with psycopg.connect( + self.config.get_connection_string(), autocommit=True + ) as conn: + self._vacuum_full_and_analyze_db(conn) + # Verify output files self._verify_output_files() self.logger.info("Pipeline completed successfully.") except Exception as e: - self.logger.error(f"Pipeline failed: {str(e)}") + self.logger.error(f"Pipeline failed: {e}") raise finally: # Cleanup is also handled by atexit, but we do it here as well @@ -125,6 +235,7 @@ def _setup_database(self, conn): with conn.cursor() as cur: try: cur.execute("CREATE EXTENSION IF NOT EXISTS postgis") + cur.execute("CREATE EXTENSION IF NOT EXISTS btree_gist") for table in get_tables_in_creation_order(): if table.drop_first: cur.execute( @@ -162,11 +273,11 @@ def _setup_country_table(self, conn): [ "CREATE TEMP TABLE country_tmp", "ON COMMIT DROP", - "AS SELECT *", + "AS SELECT alpha2, alpha3, numeric, name", "FROM country WITH NO DATA", ], [ - "INSERT INTO country", + 
"INSERT INTO country (alpha2, alpha3, numeric, name)", "SELECT *", "FROM country_tmp", "ORDER BY alpha2", @@ -190,45 +301,96 @@ def _setup_country_table(self, conn): cur.execute(" ".join(prep_stmt[1])) conn.commit() - def _download_data(self): - """Download GeoNames data.""" - self.logger.info(f"Downloading data from {self.config.geonames_url}") + def _download_data(self, url_config: URLConfig): + """Download data from a given URL.""" + self.logger.info(f"Downloading data from {url_config.domain}") - zip_path = self.temp_dir / "cities1000.zip" + assert isinstance(self.temp_dir, Path) + zip_name = self.temp_dir / url_config.zip_name try: - urllib.request.urlretrieve(self.config.geonames_url, zip_path) - except urllib.error.URLError as e: + _request = urllib.request.Request( + url_config.url, headers=url_config._headers + ) + with urllib.request.urlopen(_request) as resp: + with open(zip_name, "wb") as f: + shutil.copyfileobj(resp, f) + except (urllib.error.URLError, urllib.error.HTTPError) as e: self.logger.error(f"Failed to download data: {e}") raise + except (OSError, PermissionError) as e: + self.logger.error(f"Failed to save data: {e}") + raise - def _extract_data(self): - """Extract GeoNames data.""" - zip_path = self.temp_dir / "cities1000.zip" + def _check_cached_file_mtime(self, url_config: URLConfig): + """Check modification time of cached file to determine freshness.""" + assert isinstance(self.cache_dir, Path) + zip_name = self.cache_dir / url_config.zip_name + assert zip_name.is_file() + if datetime.fromtimestamp( + zip_name.stat().st_mtime + ) < self.config.cur_dt - timedelta(weeks=1): + self.logger.warning(f"{url_config.zip_name} is more than one week old") + if ( + _download_file := input(f"Download {url_config.zip_name} again (y/n)? ") + ).lower() == "y": + self._download_data(url_config) + return + self.logger.info(f"User declined to re-download {url_config.zip_name}") + + def _copy_file_to_cache(self, url_config: URLConfig): + """Copy a file from temp directory to local cache directory.""" + assert isinstance(self.cache_dir, Path) + assert isinstance(self.temp_dir, Path) + zip_name = self.temp_dir / url_config.zip_name + try: + shutil.copy2(zip_name, self.cache_dir / url_config.zip_name) + except (FileNotFoundError, PermissionError) as e: + self.logger.error(f"Failed to copy zip file: {e}") + raise - if self.config._zip_path: - shutil.copy2(self.config._zip_path, zip_path) + def _copy_file_from_cache(self, url_config: URLConfig): + """Copy a file from local cache directory to temp directory.""" + assert isinstance(self.cache_dir, Path) + assert isinstance(self.temp_dir, Path) + zip_name = self.temp_dir / url_config.zip_name try: - with zipfile.ZipFile(zip_path, "r") as zip_ref: + shutil.copy2(self.cache_dir / url_config.zip_name, zip_name) + except (FileNotFoundError, PermissionError) as e: + self.logger.error(f"Failed to copy zip file: {e}") + raise + + def _extract_data(self, url_config: URLConfig): + """Extract data from a given zip file.""" + assert isinstance(self.temp_dir, Path) + zip_name = self.temp_dir / url_config.zip_name + try: + with zipfile.ZipFile(zip_name, "r") as zip_ref: zip_ref.extractall(self.temp_dir) except (FileNotFoundError, PermissionError, zipfile.BadZipFile) as e: self.logger.error(f"Failed to extract zip file: {e}") raise - def _clean_data(self): - """Clean GeoNames data to the simplified format.""" - self.logger.info("Cleaning data to simplified format") + def _clean_geonames(self): + """Clean GeoNames data to simplified format.""" + 
self.logger.info("Cleaning GeoNames data to simplified format") - raw_file = self.temp_dir / "cities1000.txt" + raw_file = self.temp_dir / Path(self.config.geonames.zip_name).with_suffix( + ".txt" + ) clean_file = self.temp_dir / "cities_clean.txt" # This is the file format expected by the package - simplified_file = self.temp_dir / "cities_1000_simple.txt" - simplified_gz = self.temp_dir / "cities_1000_simple.txt.gz" + _simplified_file = Path( + "_".join(re.split(r"(\d+)", Path(raw_file).stem)) + "simple" + ) + simplified_file = self.temp_dir / _simplified_file.with_suffix(".txt") + simplified_gz = self.temp_dir / _simplified_file.with_suffix(".txt.gz") # Output path for the package - output_cities_gz = self.config.output_dir / "cities_1000_simple.txt.gz" - + output_cities_gz = self.config.output_dir / _simplified_file.with_suffix( + ".txt.gz" + ) try: with open(raw_file, "r", newline="") as f: tsv_raw = [x for x in csv.reader(f, delimiter="\t", escapechar="\\")] @@ -282,10 +444,66 @@ def _clean_data(self): self.logger.info(f"Data cleaned and saved to {clean_file}") return clean_file - def _import_data(self, conn): - """Import the cleaned data into PostgreSQL.""" + def _import_country_boundaries(self, conn) -> None: + """Import the country boundaries into PostgreSQL.""" + self.logger.info("Importing country boundaries data") + + assert isinstance(self.temp_dir, Path) + if not shutil.which("ogr2ogr"): + raise RuntimeError("Couldn't find ogr2ogr - please install it") + _shpfile_path = Path(self.config.country_boundaries.slug) + ogr_cmd: list[str] = [ + "ogr2ogr", + "-nln", + "tmp_country_bounds", + "-nlt", + "PROMOTE_TO_MULTI", + "-lco", + "GEOMETRY_NAME=geom", + "-lco", + "PRECISION=NO", + "--config", + "PG_USE_COPY=YES", + "-f", + "PostgreSQL", + "-sql", + f"SELECT {self.config.country_boundaries.alpha3_column} \ + AS alpha3 FROM {_shpfile_path.stem}", + f"PG:{self.config.get_connection_string()}", + f"{self.temp_dir / _shpfile_path.with_suffix('.shp')}", + ] + + update_sql: str = """ + UPDATE country + SET geom = t.geom + FROM tmp_country_bounds t + WHERE country.alpha3 = t.alpha3 + """ + drop_sql: str = "DROP TABLE tmp_country_bounds" + + try: + subprocess.run(ogr_cmd, check=True) + except subprocess.CalledProcessError: + self.logger.error( + "Failed to extract country boundaries from " + f"{_shpfile_path.with_suffix('.shp')}" + ) + raise + + try: + with conn.cursor() as cur: + cur.execute(update_sql) + cur.execute(drop_sql) + conn.commit() + except Exception as e: + conn.rollback() + self.logger.error(f"Failed to update country with geom: {e}") + raise + + def _import_geonames(self, conn): + """Import the cleaned GeoNames data into PostgreSQL.""" clean_file = self.temp_dir / "cities_clean.txt" - self.logger.info(f"Importing data from {clean_file}") + self.logger.info(f"Importing GeoNames data from {clean_file}") if not clean_file.exists(): self.logger.error(f"Clean data file not found: {clean_file}") @@ -330,19 +548,54 @@ def _import_data(self, conn): self.logger.error(f"Failed to import data: {e}") raise - def _create_spatial_index(self, conn): - """Create spatial index for efficient processing.""" - self.logger.info("Creating spatial index on geometry") + def _cleanup_geonames_db(self, conn): + """Manually fix known issues with dataset.""" + self.logger.info("Cleaning up geonames") + query_data = zip(ROWS_TO_CLEAN, make_queries(ROWS_TO_CLEAN), strict=False) + + with conn.cursor() as cur: + for query_info, query in query_data: + try: + cur.execute(query) + if cur.rowcount == 
query_info.result_limit: + conn.commit() + continue + elif cur.rowcount > query_info.result_limit: + self.logger.error( + f"Expected {query_info.result_limit} affected rows, " + f"got {cur.rowcount} affected rows - " + "tighten predicates and try again" + ) + conn.rollback() + return + elif not cur.rowcount: + self.logger.warning( + f"Expected {query_info.result_limit} affected rows, " + "got 0 affected rows" + ) + except Exception as e: + conn.rollback() + self.logger.error(f"Failed to update data: {e}") + raise + + def _create_spatial_indices(self, conn): + """Create spatial indices for efficient processing.""" + self.logger.info("Creating spatial indices on geometry columns") with conn.cursor() as cur: try: cur.execute( - "CREATE INDEX geocoding_geom_idx ON geocoding USING GIST(geom)" + "CREATE INDEX IF NOT EXISTS geocoding_country_geom_gist_idx " + "ON geocoding USING GIST (country, geom)" + ) + cur.execute( + "CREATE INDEX IF NOT EXISTS country_geom_idx " + "ON country USING GIST(geom)" ) conn.commit() - self.logger.info("Spatial index created") + self.logger.info("Spatial indices created") except Exception as e: conn.rollback() - self.logger.error(f"Failed to create spatial index: {e}") + self.logger.error(f"Failed to create spatial indices: {e}") raise def _create_country_index(self, conn): @@ -350,9 +603,12 @@ def _create_country_index(self, conn): self.logger.info("Creating B+tree index on country") with conn.cursor() as cur: try: - cur.execute("CREATE INDEX geocoding_country_idx ON geocoding (country)") + cur.execute( + "CREATE INDEX IF NOT EXISTS geocoding_country_idx " + "ON geocoding (country)" + ) conn.commit() - self.logger.info("country index created") + self.logger.info("B+tree index created on country") except Exception as e: conn.rollback() self.logger.error(f"Failed to create index on country: {e}") @@ -468,7 +724,8 @@ def _export_wkb(self, conn): def _verify_output_files(self): """Verify that all required output files exist.""" - cities_file = self.config.output_dir / "cities_1000_simple.txt.gz" + # cities_file = self.config.output_dir / "cities_1000_simple.txt.gz" + cities_file = self.config.output_dir / "cities_500_simple.txt.gz" voronoi_file = ( self.config.output_dir / "voronois.wkb.gz" if self.config.compress_output @@ -496,6 +753,18 @@ def _verify_output_files(self): self.logger.info(f" - {cities_file}") self.logger.info(f" - {voronoi_file}") + def _vacuum_full_and_analyze_db(self, conn): + """Perform VACUUM (ANALYZE, FULL) on tables to cleanup dead tuples.""" + self.logger.info("Performing VACUUM (ANALYZE, FULL) on geocoding tables") + with conn.cursor() as cur: + try: + cur.execute("VACUUM (ANALYZE, FULL) geocoding") + cur.execute("VACUUM (ANALYZE, FULL) country") + self.logger.info("Tables vacuumed") + except Exception as e: + self.logger.error(f"Failed to vacuum tables: {e}") + raise + def setup_logging(): """Configure logging for the script.""" @@ -519,15 +788,18 @@ def parse_args(): group_db.add_argument("--db-user", help="Database username") group_db.add_argument("--db-password", help="Database password") + parser.add_argument( + "--cache-dir", default="./cache", help="Directory to cache downloaded files" + ) parser.add_argument("--country", help="Filter to specific country code (e.g. 
IT)") parser.add_argument( - "--no-compress", action="store_true", help="Don't compress output" + "--no-cache", action="store_true", help="Don't cache downloaded files" ) parser.add_argument( - "--output-dir", default="/data/output", help="Directory for output files" + "--no-compress", action="store_true", help="Don't compress output" ) parser.add_argument( - "--zip-path", help="Path to existing cities1000.zip (avoids re-downloading)" + "--output-dir", default="/data/output", help="Directory for output files" ) return parser.parse_args() @@ -539,7 +811,9 @@ def parse_args(): # Create config with consistent Path objects config = Config( + cache_dir=Path(args.cache_dir), output_dir=Path(args.output_dir), + cache_files=not args.no_cache, compress_output=not args.no_compress, country_filter=args.country, ) @@ -550,17 +824,27 @@ def parse_args(): config.db_name = args.db_name or config.db_name config.db_user = args.db_user or config.db_user config.db_password = args.db_password or config.db_password - config._zip_path = args.zip_path or config._zip_path generator = VoronoiGenerator(config, logger) - + geonames_output_match = re.match( + r"([a-z]+)([0-9]+)", config.geonames.zip_name, re.I + ) + if geonames_output_match: + geonames_output = f"{'_'.join(geonames_output_match.groups())}_simple.txt.gz" + else: + logger.warning( + "Failed to match filename for geonames simple output - " + "check output directory for file like " + f"'{Path(config.geonames.zip_name).stem}'" + ) + geonames_output = "?" try: generator.run_pipeline() logger.info("Generation complete!") # Print summary info logger.info("\nOutput files created:") - logger.info(f" - {config.output_dir}/cities_1000_simple.txt.gz") + logger.info(f" - {config.output_dir}/{geonames_output}") logger.info( f" - {config.output_dir}/voronois.wkb" f"{'.gz' if config.compress_output else ''}" @@ -568,4 +852,4 @@ def parse_args(): logger.info("\nThese files are ready for use with the pg-nearest-city package.") except Exception as e: logger.error(f"Pipeline failed: {e}") - raise SystemExit(1) from e + raise Exception from e diff --git a/tests/_async/test_nearest_city.py b/tests/_async/test_nearest_city.py index 185f629..4cee88a 100644 --- a/tests/_async/test_nearest_city.py +++ b/tests/_async/test_nearest_city.py @@ -2,11 +2,10 @@ import os +import psycopg import pytest import pytest_asyncio -import psycopg - -from pg_nearest_city import AsyncNearestCity, Location, DbConfig +from pg_nearest_city import AsyncNearestCity, DbConfig, Location, geo_test_cases # NOTE we define the fixture here and not in conftest.py to allow @@ -86,12 +85,14 @@ async def test_check_initialization_incomplete_table(test_db): geocoder = AsyncNearestCity(test_db) async with test_db.cursor() as cur: - await cur.execute(""" + await cur.execute( + """ CREATE TABLE pg_nearest_city_geocoding ( city varchar, country varchar ); - """) + """ + ) await test_db.commit() status = await geocoder._check_initialization_status(cur) @@ -188,3 +189,13 @@ async def test_invalid_coordinates(test_db): with pytest.raises(ValueError): await geocoder.query(0, 181) # Invalid longitude + + +@pytest.mark.parametrize("case", geo_test_cases) +async def test_cities_close_country_boundaries(case): + async with AsyncNearestCity() as geocoder: + location = await geocoder.query(lon=case.lon, lat=case.lat) + assert location is not None + assert isinstance(location, Location) + assert location.city == case.expected_city + assert location.country == case.expected_country diff --git a/tests/_sync/test_cleanup.py 
b/tests/_sync/test_cleanup.py
new file mode 100644
index 0000000..d765f07
--- /dev/null
+++ b/tests/_sync/test_cleanup.py
@@ -0,0 +1,144 @@
+from unittest.mock import Mock, patch
+
+import psycopg
+import pytest
+from pg_nearest_city.db.data_cleanup import (
+    PredicateData,
+    RowData,
+    _Comment,
+    _DML,
+    _PredicateComparison,
+    make_queries,
+)
+
+
+@pytest.fixture()
+def test_db(test_db_conn_string):
+    """Provide a clean database connection for each test."""
+    conn = psycopg.Connection.connect(test_db_conn_string)
+
+    yield conn
+
+    conn.close()
+
+
+class TestMakeQueries:
+    """Test cases for make_queries function."""
+
+    @patch("pg_nearest_city.db.data_cleanup.get_all_table_classes")
+    def test_make_queries_update_single_predicate(self, mock_get_tables, test_db):
+        """Test make_queries with single UPDATE operation."""
+        mock_table = Mock()
+        mock_table.name = "geocoding"
+        mock_get_tables.return_value = [mock_table]
+
+        predicate = PredicateData(
+            col_name="city", comparison=_PredicateComparison.EQUAL, col_val="Old City"
+        )
+
+        row_data = RowData(
+            col_name="city",
+            col_val="New City",
+            comment=_Comment.SPELLING,
+            dml=_DML.UPDATE,
+            predicate_cols=[predicate],
+            tbl_name="geocoding",
+        )
+
+        queries = make_queries([row_data])
+
+        assert len(queries) == 1
+        assert isinstance(queries[0], psycopg.sql.Composed)
+
+        query_str = queries[0].as_string(test_db)
+        assert query_str == (
+            """UPDATE "geocoding" SET "city" = 'New City' WHERE """
+            """"city" = 'Old City'"""
+        )
+
+    @patch("pg_nearest_city.db.data_cleanup.get_all_table_classes")
+    def test_make_queries_update_multiple_predicates(self, mock_get_tables, test_db):
+        """Test make_queries with UPDATE operation using multiple predicates."""
+        mock_table = Mock()
+        mock_table.name = "geocoding"
+        mock_get_tables.return_value = [mock_table]
+
+        predicates = [
+            PredicateData(
+                col_name="city",
+                comparison=_PredicateComparison.EQUAL,
+                col_val="Old City",
+            ),
+            PredicateData(
+                col_name="created_at",
+                comparison=_PredicateComparison.BETWEEN,
+                col_val="1995-05-23",
+                col_val_2="2038-01-01",
+            ),
+        ]
+
+        row_data = RowData(
+            col_name="city",
+            col_val="New City",
+            comment=_Comment.SPELLING,
+            dml=_DML.UPDATE,
+            predicate_cols=predicates,
+            tbl_name="geocoding",
+        )
+
+        queries = make_queries([row_data])
+
+        assert len(queries) == 1
+        assert isinstance(queries[0], psycopg.sql.Composed)
+
+        query_str = queries[0].as_string(test_db)
+        assert query_str == (
+            """UPDATE "geocoding" SET "city" = 'New City' WHERE """
+            """"city" = 'Old City' AND "created_at" BETWEEN """
+            """'1995-05-23' AND '2038-01-01'"""
+        )
+
+    @patch("pg_nearest_city.db.data_cleanup.get_all_table_classes")
+    def test_make_queries_delete_operation(self, mock_get_tables, test_db):
+        """Test make_queries with DELETE operation."""
+        mock_table = Mock()
+        mock_table.name = "geocoding"
+        mock_get_tables.return_value = [mock_table]
+
+        predicate = PredicateData(
+            col_name="city",
+            comparison=_PredicateComparison.EQUAL,
+            col_val="Dallas",
+        )
+
+        row_data = RowData(
+            comment=_Comment.ERRATUM,
+            dml=_DML.DELETE,
+            predicate_cols=[predicate],
+            tbl_name="geocoding",
+        )
+
+        queries = make_queries([row_data])
+
+        assert len(queries) == 1
+        assert isinstance(queries[0], psycopg.sql.Composed)
+
+        query_str = queries[0].as_string(test_db)
+        assert query_str == """DELETE FROM "geocoding" WHERE "city" = 'Dallas'"""
+
+    @patch("pg_nearest_city.db.data_cleanup.get_all_table_classes")
+    def 
test_make_queries_insert_raises_not_implemented(self, mock_get_tables): + """Test make_queries raises NotImplemented for INSERT operations.""" + mock_table = Mock() + mock_table.name = "geocoding" + mock_get_tables.return_value = [mock_table] + + with pytest.raises(NotImplementedError): + row_data = RowData( + col_name="city", + col_val="New City", + comment=_Comment.MISSING, + dml=_DML.INSERT, + tbl_name="geocoding", + ) + make_queries([row_data]) diff --git a/tests/_sync/test_nearest_city.py b/tests/_sync/test_nearest_city.py index 2ab6f05..5f127cc 100644 --- a/tests/_sync/test_nearest_city.py +++ b/tests/_sync/test_nearest_city.py @@ -2,11 +2,10 @@ import os -import pytest - import psycopg +import pytest -from pg_nearest_city import NearestCity, Location, DbConfig +from pg_nearest_city import NearestCity, DbConfig, Location, geo_test_cases # NOTE we define the fixture here and not in conftest.py to allow @@ -86,12 +85,14 @@ def test_check_initialization_incomplete_table(test_db): geocoder = NearestCity(test_db) with test_db.cursor() as cur: - cur.execute(""" + cur.execute( + """ CREATE TABLE pg_nearest_city_geocoding ( city varchar, country varchar ); - """) + """ + ) test_db.commit() status = geocoder._check_initialization_status(cur) @@ -188,3 +189,13 @@ def test_invalid_coordinates(test_db): with pytest.raises(ValueError): geocoder.query(0, 181) # Invalid longitude + + +@pytest.mark.parametrize("case", geo_test_cases) +def test_cities_close_country_boundaries(case): + with NearestCity() as geocoder: + location = geocoder.query(lon=case.lon, lat=case.lat) + assert location is not None + assert isinstance(location, Location) + assert location.city == case.expected_city + assert location.country == case.expected_country
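For reference, the newly exported `geo_test_cases` can also be exercised outside the bundled pytest suites. Below is a minimal sketch, assuming a database that has already been initialised for pg-nearest-city and whose connection settings resolve through the default `DbConfig`/environment handling; only `NearestCity`, `geo_test_cases`, and the `query(lon=..., lat=...)` / `Location` API shown in the tests above are used.

```python
from pg_nearest_city import NearestCity, geo_test_cases

# Sketch only: assumes the geocoding data is already loaded and that
# NearestCity() can build its own connection from DbConfig defaults.
with NearestCity() as geocoder:
    for case in geo_test_cases:
        location = geocoder.query(lon=case.lon, lat=case.lat)
        # Report anything that does not match the expected city/country pair.
        if location is None or (location.city, location.country) != (
            case.expected_city,
            case.expected_country,
        ):
            print(f"mismatch at ({case.lon}, {case.lat}): {location}")
```

Each entry in `geo_test_cases` sits close to a country border or on a small island, so a quick pass over the list is a cheap smoke test after regenerating the boundary and Voronoi data.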