Skip to content

Commit

Permalink
Use the timeout context manager in the connection path (#1087)
Browse files Browse the repository at this point in the history
Drop timeout management gymnastics from the `connect()` path and use the
`timeout` context manager instead.
  • Loading branch information
elprans authored Oct 9, 2023
1 parent 8b45beb commit 313b2b2
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 55 deletions.
6 changes: 6 additions & 0 deletions asyncpg/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,9 @@ async def wait_closed(stream):
from ._asyncio_compat import wait_for as wait_for # noqa: F401
else:
from asyncio import wait_for as wait_for # noqa: F401


if sys.version_info < (3, 11):
from ._asyncio_compat import timeout_ctx as timeout # noqa: F401
else:
from asyncio import timeout as timeout # noqa: F401
45 changes: 12 additions & 33 deletions asyncpg/connect_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import stat
import struct
import sys
import time
import typing
import urllib.parse
import warnings
Expand Down Expand Up @@ -55,7 +54,6 @@ def parse(cls, sslmode):
'ssl',
'sslmode',
'direct_tls',
'connect_timeout',
'server_settings',
'target_session_attrs',
])
Expand Down Expand Up @@ -262,7 +260,7 @@ def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]:

def _parse_connect_dsn_and_args(*, dsn, host, port, user,
password, passfile, database, ssl,
direct_tls, connect_timeout, server_settings,
direct_tls, server_settings,
target_session_attrs):
# `auth_hosts` is the version of host information for the purposes
# of reading the pgpass file.
Expand Down Expand Up @@ -655,14 +653,14 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
params = _ConnectionParameters(
user=user, password=password, database=database, ssl=ssl,
sslmode=sslmode, direct_tls=direct_tls,
connect_timeout=connect_timeout, server_settings=server_settings,
server_settings=server_settings,
target_session_attrs=target_session_attrs)

return addrs, params


def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
database, timeout, command_timeout,
database, command_timeout,
statement_cache_size,
max_cached_statement_lifetime,
max_cacheable_statement_size,
Expand Down Expand Up @@ -695,7 +693,7 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
dsn=dsn, host=host, port=port, user=user,
password=password, passfile=passfile, ssl=ssl,
direct_tls=direct_tls, database=database,
connect_timeout=timeout, server_settings=server_settings,
server_settings=server_settings,
target_session_attrs=target_session_attrs)

config = _ClientConfiguration(
Expand Down Expand Up @@ -799,17 +797,13 @@ async def _connect_addr(
*,
addr,
loop,
timeout,
params,
config,
connection_class,
record_class
):
assert loop is not None

if timeout <= 0:
raise asyncio.TimeoutError

params_input = params
if callable(params.password):
password = params.password()
Expand All @@ -827,21 +821,16 @@ async def _connect_addr(
params_retry = params._replace(ssl=None)
else:
# skip retry if we don't have to
return await __connect_addr(params, timeout, False, *args)
return await __connect_addr(params, False, *args)

# first attempt
before = time.monotonic()
try:
return await __connect_addr(params, timeout, True, *args)
return await __connect_addr(params, True, *args)
except _RetryConnectSignal:
pass

# second attempt
timeout -= time.monotonic() - before
if timeout <= 0:
raise asyncio.TimeoutError
else:
return await __connect_addr(params_retry, timeout, False, *args)
return await __connect_addr(params_retry, False, *args)


class _RetryConnectSignal(Exception):
Expand All @@ -850,7 +839,6 @@ class _RetryConnectSignal(Exception):

async def __connect_addr(
params,
timeout,
retry,
addr,
loop,
Expand Down Expand Up @@ -882,15 +870,10 @@ async def __connect_addr(
else:
connector = loop.create_connection(proto_factory, *addr)

connector = asyncio.ensure_future(connector)
before = time.monotonic()
tr, pr = await compat.wait_for(connector, timeout=timeout)
timeout -= time.monotonic() - before
tr, pr = await connector

try:
if timeout <= 0:
raise asyncio.TimeoutError
await compat.wait_for(connected, timeout=timeout)
await connected
except (
exceptions.InvalidAuthorizationSpecificationError,
exceptions.ConnectionDoesNotExistError, # seen on Windows
Expand Down Expand Up @@ -993,23 +976,21 @@ async def _can_use_connection(connection, attr: SessionAttribute):
return await can_use(connection)


async def _connect(*, loop, timeout, connection_class, record_class, **kwargs):
async def _connect(*, loop, connection_class, record_class, **kwargs):
if loop is None:
loop = asyncio.get_event_loop()

addrs, params, config = _parse_connect_arguments(timeout=timeout, **kwargs)
addrs, params, config = _parse_connect_arguments(**kwargs)
target_attr = params.target_session_attrs

candidates = []
chosen_connection = None
last_error = None
for addr in addrs:
before = time.monotonic()
try:
conn = await _connect_addr(
addr=addr,
loop=loop,
timeout=timeout,
params=params,
config=config,
connection_class=connection_class,
Expand All @@ -1019,10 +1000,8 @@ async def _connect(*, loop, timeout, connection_class, record_class, **kwargs):
if await _can_use_connection(conn, target_attr):
chosen_connection = conn
break
except (OSError, asyncio.TimeoutError, ConnectionError) as ex:
except OSError as ex:
last_error = ex
finally:
timeout -= time.monotonic() - before
else:
if target_attr == SessionAttribute.prefer_standby and candidates:
chosen_connection = random.choice(candidates)
Expand Down
43 changes: 22 additions & 21 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import warnings
import weakref

from . import compat
from . import connect_utils
from . import cursor
from . import exceptions
Expand Down Expand Up @@ -2184,27 +2185,27 @@ async def connect(dsn=None, *,
if loop is None:
loop = asyncio.get_event_loop()

return await connect_utils._connect(
loop=loop,
timeout=timeout,
connection_class=connection_class,
record_class=record_class,
dsn=dsn,
host=host,
port=port,
user=user,
password=password,
passfile=passfile,
ssl=ssl,
direct_tls=direct_tls,
database=database,
server_settings=server_settings,
command_timeout=command_timeout,
statement_cache_size=statement_cache_size,
max_cached_statement_lifetime=max_cached_statement_lifetime,
max_cacheable_statement_size=max_cacheable_statement_size,
target_session_attrs=target_session_attrs
)
async with compat.timeout(timeout):
return await connect_utils._connect(
loop=loop,
connection_class=connection_class,
record_class=record_class,
dsn=dsn,
host=host,
port=port,
user=user,
password=password,
passfile=passfile,
ssl=ssl,
direct_tls=direct_tls,
database=database,
server_settings=server_settings,
command_timeout=command_timeout,
statement_cache_size=statement_cache_size,
max_cached_statement_lifetime=max_cached_statement_lifetime,
max_cacheable_statement_size=max_cacheable_statement_size,
target_session_attrs=target_session_attrs
)


class _StatementCacheEntry:
Expand Down
17 changes: 17 additions & 0 deletions tests/test_adversity.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,23 @@ async def test_connection_close_timeout(self):
with self.assertRaises(asyncio.TimeoutError):
await con.close(timeout=0.5)

@tb.with_timeout(30.0)
async def test_pool_acquire_timeout(self):
pool = await self.create_pool(
database='postgres', min_size=2, max_size=2)
try:
self.proxy.trigger_connectivity_loss()
for _ in range(2):
with self.assertRaises(asyncio.TimeoutError):
async with pool.acquire(timeout=0.5):
pass
self.proxy.restore_connectivity()
async with pool.acquire(timeout=0.5):
pass
finally:
self.proxy.restore_connectivity()
pool.terminate()

@tb.with_timeout(30.0)
async def test_pool_release_timeout(self):
pool = await self.create_pool(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -891,7 +891,7 @@ def run_testcase(self, testcase):
addrs, params = connect_utils._parse_connect_dsn_and_args(
dsn=dsn, host=host, port=port, user=user, password=password,
passfile=passfile, database=database, ssl=sslmode,
direct_tls=False, connect_timeout=None,
direct_tls=False,
server_settings=server_settings,
target_session_attrs=target_session_attrs)

Expand Down

0 comments on commit 313b2b2

Please sign in to comment.