Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Aareon Sullivan committed Oct 28, 2024
1 parent 251b2a8 commit 6c0c770
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 127 deletions.
217 changes: 97 additions & 120 deletions geonames/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,50 @@ class QueryProtocol(Protocol[T_co]):
async def __call__(self, session: AsyncSession, *args: Any) -> T_co: ...


def _format_detailed_result(geoname: Geoname) -> Dict[str, Any]:
"""
Format a Geoname object into a detailed result dictionary.
Args:
geoname: The Geoname object to format
Returns:
Dictionary with all available location fields using standardized key names
"""
return {
"name": geoname.place_name, # Match _format_search_result
"postal_code": geoname.postal_code,
"country": geoname.country_code, # Match _format_search_result
"state": geoname.admin_name1,
"state_code": geoname.admin_code1,
"province": geoname.admin_name2,
"province_code": geoname.admin_code2,
"community": geoname.admin_name3,
"community_code": geoname.admin_code3,
"latitude": geoname.latitude,
"longitude": geoname.longitude,
"accuracy": geoname.accuracy,
}


def _format_search_result(geoname: Geoname) -> Dict[str, Any]:
"""
Format a Geoname object into a standardized search result dictionary.
Args:
geoname: The Geoname object to format
Returns:
Dictionary with standardized location fields
"""
return {
"name": geoname.place_name, # Standardize on "name" for place name
"country": geoname.country_code, # Standardize on "country" for country code
"latitude": geoname.latitude,
"longitude": geoname.longitude,
}


async def database_exists(engine: AsyncEngine) -> bool:
"""
Check if the database exists and is populated.
Expand Down Expand Up @@ -123,55 +167,25 @@ async def optimize_database(engine: AsyncEngine) -> None:


async def get_geolocation(
engine: AsyncEngine, country: str, zipcode: str
engine: AsyncEngine,
country: str,
zipcode: str
) -> List[Dict[str, Any]]:
"""
Get geolocation information for a given country and postal code.
Args:
engine: The database engine
country: Country code (e.g., 'US', 'GB')
zipcode: Postal code to search for
Returns:
List of dictionaries containing location information
"""

"""Get geolocation data for a postal code."""
async def query(session: AsyncSession, args: Tuple[str, str]) -> List[Geoname]:
c, z = args
logger.debug(f"Searching for country: '{c}', zipcode: '{z}'")
result = await session.execute(
select(Geoname).where(
and_(Geoname.postal_code == z, Geoname.country_code == c)
)
)
locations = list(result.scalars().all())
logger.debug(f"Found {len(locations)} locations")
return locations
return list(result.scalars().all())

try:
geonames: List[Geoname] = await execute_query(
engine, query, (country.strip().upper(), zipcode.strip())
)

results = [
{
"name": geoname.place_name,
"postal_code": geoname.postal_code,
"country_code": geoname.country_code,
"state": geoname.admin_name1,
"state_code": geoname.admin_code1,
"province": geoname.admin_name2,
"province_code": geoname.admin_code2,
"community": geoname.admin_name3,
"community_code": geoname.admin_code3,
"latitude": geoname.latitude,
"longitude": geoname.longitude,
"accuracy": geoname.accuracy,
}
for geoname in geonames
]

results = [_format_detailed_result(geoname) for geoname in geonames]
logger.debug(f"Processed {len(results)} results")
return results
except Exception as e:
Expand Down Expand Up @@ -285,46 +299,19 @@ async def setup_database(config: Config) -> AsyncEngine:


async def search_locations(
engine: AsyncEngine, query_func: Callable[..., Any], *args: Any
engine: AsyncEngine,
query_func: QueryProtocol[List[Geoname]],
*args: Any
) -> List[Dict[str, Any]]:
"""
Execute a location search query and format the results.
Args:
engine: The database engine
query_func: The query function to execute
*args: Arguments to pass to the query function
Returns:
List of dictionaries containing location information
"""
"""Search locations using the provided query function."""
try:
logger.debug(f"Executing search_locations with args: {args}")
geonames = await execute_query(engine, query_func, *args)
geonames: List[Geoname] = await execute_query(engine, query_func, *args)
logger.debug(f"Found {len(geonames)} results from search query")

results = [
{
"name": geoname.place_name,
"postal_code": geoname.postal_code,
"country_code": geoname.country_code,
"state": geoname.admin_name1,
"state_code": geoname.admin_code1,
"province": geoname.admin_name2,
"province_code": geoname.admin_code2,
"community": geoname.admin_name3,
"community_code": geoname.admin_code3,
"latitude": geoname.latitude,
"longitude": geoname.longitude,
"accuracy": geoname.accuracy,
}
for geoname in geonames
]

logger.debug(f"Processed {len(results)} results")
return results

return [_format_search_result(geoname) for geoname in geonames]
except Exception as e:
logger.error(f"Error in search_locations: {str(e)}", exc_info=True)
logger.error(f"Error in search_locations: {e}")
return []


Expand Down Expand Up @@ -418,70 +405,60 @@ async def query(session: AsyncSession, cc: str) -> List[Geoname]:


async def search_by_coordinates(
engine: AsyncEngine, lat: float, lon: float, radius: float, limit: int = 100
engine: AsyncEngine,
lat: float,
lon: float,
radius: float,
limit: int = 100
) -> List[Dict[str, Any]]:
"""
Search for locations near the specified coordinates.
Search for locations within a radius of given coordinates.
Args:
engine: Database engine
lat: Latitude of the search point
lon: Longitude of the search point
radius: Search radius in kilometers
limit: Maximum number of results to return (default: 100)
lat: Latitude
lon: Longitude
radius: Search radius in km (must be positive)
limit: Maximum number of results to return
Returns:
List of locations ordered by distance from the search point
List of matching locations
Raises:
ValueError: If the input coordinates or radius are not valid numbers
ValueError: If radius is not positive
"""
try:
lat = float(lat)
lon = float(lon)
radius = float(radius)
limit = max(1, int(limit))
except ValueError:
raise ValueError(
"Invalid input: latitude, longitude, and radius must be numeric values"
)

async def query(
session: AsyncSession, lat: float, lon: float, radius: float, limit: int
) -> List[Geoname]:
# Calculate the bounding box for initial filtering
lat_min = lat - radius / 111.0
lat_max = lat + radius / 111.0
# Adjust longitude range based on latitude to account for convergence
lon_range = radius / (111.0 * cos(radians(lat)))
lon_min = lon - lon_range
lon_max = lon + lon_range

# Create a more accurate distance calculation using the Haversine formula
distance_formula = (
"6371 * 2 * ASIN(SQRT("
"POWER(SIN(RADIANS(:lat - latitude) / 2), 2) + "
"COS(RADIANS(:lat)) * COS(RADIANS(latitude)) * "
"POWER(SIN(RADIANS(:lon - longitude) / 2), 2)"
"))"
)

result = await session.execute(
select(Geoname)
.where(
and_(
Geoname.latitude.between(lat_min, lat_max),
Geoname.longitude.between(lon_min, lon_max),

if radius <= 0:
raise ValueError("Radius must be positive")

async def query(
session: AsyncSession,
args: Tuple[float, float, float, int]
) -> List[Geoname]:
lat, lon, radius, limit = args
result = await session.execute(
select(Geoname)
.where(
and_(
Geoname.latitude.between(lat - radius / 111, lat + radius / 111),
Geoname.longitude.between(lon - radius / 111, lon + radius / 111),
)
)
.order_by(
func.abs(Geoname.latitude - lat) + func.abs(Geoname.longitude - lon)
)
.limit(limit)
)
.order_by(text(distance_formula))
.params(lat=lat, lon=lon)
.limit(limit)
)
return list(result.scalars().all())
return list(result.scalars().all())

try:
return await search_locations(engine, query, lat, lon, radius, limit)
except ValueError as e:
logger.error(f"Invalid input: {str(e)}")
raise
except Exception as e:
logger.error(f"Error in search_by_coordinates: {e}")
return []
Expand Down
15 changes: 8 additions & 7 deletions tests/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,8 +295,8 @@ async def test_get_geolocation(mock_engine):
with patch("geonames.database.execute_query", return_value=mock_result):
result = await get_geolocation(mock_engine, "US", "12345")
assert len(result) == 1
assert result[0]["city"] == "Test City"
assert result[0]["state"] == "State"
assert result[0]["name"] == "Test City" # Use standardized key
assert result[0]["country"] == "US" # Use standardized key


@pytest.mark.asyncio
Expand Down Expand Up @@ -423,7 +423,10 @@ async def test_search_locations(mock_engine):
"""Test the search_locations helper function."""
mock_result = [
Geoname(
place_name="Test Location", country_code="US", latitude=1.0, longitude=1.0
place_name="Test Location",
country_code="US",
latitude=1.0,
longitude=1.0
)
]

Expand All @@ -433,10 +436,8 @@ async def mock_query_func(session):
with patch("geonames.database.execute_query", return_value=mock_result):
result = await search_locations(mock_engine, mock_query_func)
assert len(result) == 1
assert result[0]["name"] == "Test Location"
assert result[0]["country"] == "US"
assert result[0]["latitude"] == 1.0
assert result[0]["longitude"] == 1.0
assert result[0]["name"] == "Test Location" # Use standardized key
assert result[0]["country"] == "US" # Use standardized key


@pytest.mark.asyncio
Expand Down

0 comments on commit 6c0c770

Please sign in to comment.