Skip to content

Commit

Permalink
Merge pull request #3519 from lonvia/api-error-handling
Browse files Browse the repository at this point in the history
Improve error handling around CLI api commands
  • Loading branch information
lonvia authored Aug 19, 2024
2 parents 8b41b80 + adce726 commit 968f1cd
Show file tree
Hide file tree
Showing 9 changed files with 120 additions and 114 deletions.
25 changes: 23 additions & 2 deletions src/nominatim_api/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ class NominatimAPIAsync: #pylint: disable=too-many-instance-attributes
This class shares most of the functions with its synchronous
version. There are some additional functions or parameters,
which are documented below.
This class should usually be used as a context manager in 'with' context.
"""
def __init__(self, project_dir: Path,
environ: Optional[Mapping[str, str]] = None,
Expand Down Expand Up @@ -166,6 +168,14 @@ async def close(self) -> None:
await self._engine.dispose()


async def __aenter__(self) -> 'NominatimAPIAsync':
return self


async def __aexit__(self, *_: Any) -> None:
await self.close()


@contextlib.asynccontextmanager
async def begin(self) -> AsyncIterator[SearchConnection]:
""" Create a new connection with automatic transaction handling.
Expand Down Expand Up @@ -351,6 +361,8 @@ class NominatimAPI:
""" This class provides a thin synchronous wrapper around the asynchronous
Nominatim functions. It creates its own event loop and runs each
synchronous function call to completion using that loop.
This class should usually be used as a context manager in 'with' context.
"""

def __init__(self, project_dir: Path,
Expand All @@ -376,8 +388,17 @@ def close(self) -> None:
This function also closes the asynchronous worker loop making
the NominatimAPI object unusable.
"""
self._loop.run_until_complete(self._async_api.close())
self._loop.close()
if not self._loop.is_closed():
self._loop.run_until_complete(self._async_api.close())
self._loop.close()


def __enter__(self) -> 'NominatimAPI':
return self


def __exit__(self, *_: Any) -> None:
self.close()


@property
Expand Down
116 changes: 67 additions & 49 deletions src/nominatim_db/clicmd/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,29 +180,32 @@ def run(self, args: NominatimArgs) -> int:
raise UsageError(f"Unsupported format '{args.format}'. "
'Use --list-formats to see supported formats.')

api = napi.NominatimAPI(args.project_dir)
params: Dict[str, Any] = {'max_results': args.limit + min(args.limit, 10),
'address_details': True, # needed for display name
'geometry_output': _get_geometry_output(args),
'geometry_simplification': args.polygon_threshold,
'countries': args.countrycodes,
'excluded': args.exclude_place_ids,
'viewbox': args.viewbox,
'bounded_viewbox': args.bounded,
'locales': _get_locales(args, api.config.DEFAULT_LANGUAGE)
}

if args.query:
results = api.search(args.query, **params)
else:
results = api.search_address(amenity=args.amenity,
street=args.street,
city=args.city,
county=args.county,
state=args.state,
postalcode=args.postalcode,
country=args.country,
**params)
try:
with napi.NominatimAPI(args.project_dir) as api:
params: Dict[str, Any] = {'max_results': args.limit + min(args.limit, 10),
'address_details': True, # needed for display name
'geometry_output': _get_geometry_output(args),
'geometry_simplification': args.polygon_threshold,
'countries': args.countrycodes,
'excluded': args.exclude_place_ids,
'viewbox': args.viewbox,
'bounded_viewbox': args.bounded,
'locales': _get_locales(args, api.config.DEFAULT_LANGUAGE)
}

if args.query:
results = api.search(args.query, **params)
else:
results = api.search_address(amenity=args.amenity,
street=args.street,
city=args.city,
county=args.county,
state=args.state,
postalcode=args.postalcode,
country=args.country,
**params)
except napi.UsageError as ex:
raise UsageError(ex) from ex

if args.dedupe and len(results) > 1:
results = deduplicate_results(results, args.limit)
Expand Down Expand Up @@ -260,14 +263,19 @@ def run(self, args: NominatimArgs) -> int:
if args.lat is None or args.lon is None:
raise UsageError("lat' and 'lon' parameters are required.")

api = napi.NominatimAPI(args.project_dir)
result = api.reverse(napi.Point(args.lon, args.lat),
max_rank=zoom_to_rank(args.zoom or 18),
layers=_get_layers(args, napi.DataLayer.ADDRESS | napi.DataLayer.POI),
address_details=True, # needed for display name
geometry_output=_get_geometry_output(args),
geometry_simplification=args.polygon_threshold,
locales=_get_locales(args, api.config.DEFAULT_LANGUAGE))
layers = _get_layers(args, napi.DataLayer.ADDRESS | napi.DataLayer.POI)

try:
with napi.NominatimAPI(args.project_dir) as api:
result = api.reverse(napi.Point(args.lon, args.lat),
max_rank=zoom_to_rank(args.zoom or 18),
layers=layers,
address_details=True, # needed for display name
geometry_output=_get_geometry_output(args),
geometry_simplification=args.polygon_threshold,
locales=_get_locales(args, api.config.DEFAULT_LANGUAGE))
except napi.UsageError as ex:
raise UsageError(ex) from ex

if args.format == 'debug':
print(loglib.get_and_disable())
Expand Down Expand Up @@ -323,12 +331,15 @@ def run(self, args: NominatimArgs) -> int:

places = [napi.OsmID(o[0], int(o[1:])) for o in args.ids]

api = napi.NominatimAPI(args.project_dir)
results = api.lookup(places,
address_details=True, # needed for display name
geometry_output=_get_geometry_output(args),
geometry_simplification=args.polygon_threshold or 0.0,
locales=_get_locales(args, api.config.DEFAULT_LANGUAGE))
try:
with napi.NominatimAPI(args.project_dir) as api:
results = api.lookup(places,
address_details=True, # needed for display name
geometry_output=_get_geometry_output(args),
geometry_simplification=args.polygon_threshold or 0.0,
locales=_get_locales(args, api.config.DEFAULT_LANGUAGE))
except napi.UsageError as ex:
raise UsageError(ex) from ex

if args.format == 'debug':
print(loglib.get_and_disable())
Expand Down Expand Up @@ -410,17 +421,20 @@ def run(self, args: NominatimArgs) -> int:
raise UsageError('One of the arguments --node/-n --way/-w '
'--relation/-r --place_id/-p is required/')

api = napi.NominatimAPI(args.project_dir)
locales = _get_locales(args, api.config.DEFAULT_LANGUAGE)
result = api.details(place,
address_details=args.addressdetails,
linked_places=args.linkedplaces,
parented_places=args.hierarchy,
keywords=args.keywords,
geometry_output=napi.GeometryFormat.GEOJSON
if args.polygon_geojson
else napi.GeometryFormat.NONE,
locales=locales)
try:
with napi.NominatimAPI(args.project_dir) as api:
locales = _get_locales(args, api.config.DEFAULT_LANGUAGE)
result = api.details(place,
address_details=args.addressdetails,
linked_places=args.linkedplaces,
parented_places=args.hierarchy,
keywords=args.keywords,
geometry_output=napi.GeometryFormat.GEOJSON
if args.polygon_geojson
else napi.GeometryFormat.NONE,
locales=locales)
except napi.UsageError as ex:
raise UsageError(ex) from ex

if args.format == 'debug':
print(loglib.get_and_disable())
Expand Down Expand Up @@ -465,7 +479,11 @@ def run(self, args: NominatimArgs) -> int:
raise UsageError(f"Unsupported format '{args.format}'. "
'Use --list-formats to see supported formats.')

status = napi.NominatimAPI(args.project_dir).status()
try:
with napi.NominatimAPI(args.project_dir) as api:
status = api.status()
except napi.UsageError as ex:
raise UsageError(ex) from ex

if args.format == 'debug':
print(loglib.get_and_disable())
Expand Down
7 changes: 7 additions & 0 deletions test/python/api/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"""
from pathlib import Path
import pytest
import pytest_asyncio
import time
import datetime as dt

Expand Down Expand Up @@ -244,3 +245,9 @@ def mkapi(apiobj, options=None):

for api in testapis:
api.close()


@pytest_asyncio.fixture
async def api(temp_db):
async with napi.NominatimAPIAsync(Path('/invalid')) as api:
yield api
7 changes: 3 additions & 4 deletions test/python/api/search/test_icu_query_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,9 @@ async def conn(table_factory):
table_factory('word',
definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB')

api = NominatimAPIAsync(Path('/invalid'), {})
async with api.begin() as conn:
yield conn
await api.close()
async with NominatimAPIAsync(Path('/invalid'), {}) as api:
async with api.begin() as conn:
yield conn


@pytest.mark.asyncio
Expand Down
7 changes: 3 additions & 4 deletions test/python/api/search/test_legacy_query_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,9 @@ class TEXT, type TEXT, country_code TEXT,
temp_db_cursor.execute("""CREATE OR REPLACE FUNCTION make_standard_name(name TEXT)
RETURNS TEXT AS $$ SELECT lower(name); $$ LANGUAGE SQL;""")

api = NominatimAPIAsync(Path('/invalid'), {})
async with api.begin() as conn:
yield conn
await api.close()
async with NominatimAPIAsync(Path('/invalid'), {}) as api:
async with api.begin() as conn:
yield conn


@pytest.mark.asyncio
Expand Down
14 changes: 3 additions & 11 deletions test/python/api/search/test_query_analyzer_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,47 +11,39 @@

import pytest

from nominatim_api import NominatimAPIAsync
from nominatim_api.search.query_analyzer_factory import make_query_analyzer
from nominatim_api.search.icu_tokenizer import ICUQueryAnalyzer

@pytest.mark.asyncio
async def test_import_icu_tokenizer(table_factory):
async def test_import_icu_tokenizer(table_factory, api):
table_factory('nominatim_properties',
definition='property TEXT, value TEXT',
content=(('tokenizer', 'icu'),
('tokenizer_import_normalisation', ':: lower();'),
('tokenizer_import_transliteration', "'1' > '/1/'; 'ä' > 'ä '")))

api = NominatimAPIAsync(Path('/invalid'), {})
async with api.begin() as conn:
ana = await make_query_analyzer(conn)

assert isinstance(ana, ICUQueryAnalyzer)
await api.close()


@pytest.mark.asyncio
async def test_import_missing_property(table_factory):
api = NominatimAPIAsync(Path('/invalid'), {})
async def test_import_missing_property(table_factory, api):
table_factory('nominatim_properties',
definition='property TEXT, value TEXT')

async with api.begin() as conn:
with pytest.raises(ValueError, match='Property.*not found'):
await make_query_analyzer(conn)
await api.close()


@pytest.mark.asyncio
async def test_import_missing_module(table_factory):
api = NominatimAPIAsync(Path('/invalid'), {})
async def test_import_missing_module(table_factory, api):
table_factory('nominatim_properties',
definition='property TEXT, value TEXT',
content=(('tokenizer', 'missing'),))

async with api.begin() as conn:
with pytest.raises(RuntimeError, match='Tokenizer not found'):
await make_query_analyzer(conn)
await api.close()

39 changes: 14 additions & 25 deletions test/python/api/test_api_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,45 +9,34 @@
"""
from pathlib import Path
import pytest
import pytest_asyncio

import sqlalchemy as sa

from nominatim_api import NominatimAPIAsync

@pytest_asyncio.fixture
async def apiobj(temp_db):
""" Create an asynchronous SQLAlchemy engine for the test DB.
"""
api = NominatimAPIAsync(Path('/invalid'), {})
yield api
await api.close()


@pytest.mark.asyncio
async def test_run_scalar(apiobj, table_factory):
async def test_run_scalar(api, table_factory):
table_factory('foo', definition='that TEXT', content=(('a', ),))

async with apiobj.begin() as conn:
async with api.begin() as conn:
assert await conn.scalar(sa.text('SELECT * FROM foo')) == 'a'


@pytest.mark.asyncio
async def test_run_execute(apiobj, table_factory):
async def test_run_execute(api, table_factory):
table_factory('foo', definition='that TEXT', content=(('a', ),))

async with apiobj.begin() as conn:
async with api.begin() as conn:
result = await conn.execute(sa.text('SELECT * FROM foo'))
assert result.fetchone()[0] == 'a'


@pytest.mark.asyncio
async def test_get_property_existing_cached(apiobj, table_factory):
async def test_get_property_existing_cached(api, table_factory):
table_factory('nominatim_properties',
definition='property TEXT, value TEXT',
content=(('dbv', '96723'), ))

async with apiobj.begin() as conn:
async with api.begin() as conn:
assert await conn.get_property('dbv') == '96723'

await conn.execute(sa.text('TRUNCATE nominatim_properties'))
Expand All @@ -56,12 +45,12 @@ async def test_get_property_existing_cached(apiobj, table_factory):


@pytest.mark.asyncio
async def test_get_property_existing_uncached(apiobj, table_factory):
async def test_get_property_existing_uncached(api, table_factory):
table_factory('nominatim_properties',
definition='property TEXT, value TEXT',
content=(('dbv', '96723'), ))

async with apiobj.begin() as conn:
async with api.begin() as conn:
assert await conn.get_property('dbv') == '96723'

await conn.execute(sa.text("UPDATE nominatim_properties SET value = '1'"))
Expand All @@ -71,23 +60,23 @@ async def test_get_property_existing_uncached(apiobj, table_factory):

@pytest.mark.asyncio
@pytest.mark.parametrize('param', ['foo', 'DB:server_version'])
async def test_get_property_missing(apiobj, table_factory, param):
async def test_get_property_missing(api, table_factory, param):
table_factory('nominatim_properties',
definition='property TEXT, value TEXT')

async with apiobj.begin() as conn:
async with api.begin() as conn:
with pytest.raises(ValueError):
await conn.get_property(param)


@pytest.mark.asyncio
async def test_get_db_property_existing(apiobj):
async with apiobj.begin() as conn:
async def test_get_db_property_existing(api):
async with api.begin() as conn:
assert await conn.get_db_property('server_version') > 0


@pytest.mark.asyncio
async def test_get_db_property_existing(apiobj):
async with apiobj.begin() as conn:
async def test_get_db_property_existing(api):
async with api.begin() as conn:
with pytest.raises(ValueError):
await conn.get_db_property('dfkgjd.rijg')
Loading

0 comments on commit 968f1cd

Please sign in to comment.