Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support machines with multiple NICs #576

Merged
merged 26 commits into from
Jul 11, 2024
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion awscrt/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,13 @@ class S3Client(NativeResource):
client can use for buffering data for requests.
Default values scale with target throughput and are currently
between 2GiB and 8GiB (may change in future)

network_interface_names: (Optional[Sequence(str)]) A sequence of network interface names. The client will distribute the
connections across network interfaces. If any interface name is invalid, goes down,
or has any issues like network access, you will see connection failures.
This option is only supported on Linux, MacOS, and platforms that have either SO_BINDTODEVICE or IP_BOUND_IF. It
is not supported on Windows. `AWS_ERROR_PLATFORM_NOT_SUPPORTED` will be raised on unsupported platforms. On
Linux, SO_BINDTODEVICE is used and requires kernel version >= 5.7 or root privileges.
"""

__slots__ = ('shutdown_event', '_region')
Expand All @@ -222,7 +229,8 @@ def __init__(
multipart_upload_threshold=None,
throughput_target_gbps=None,
enable_s3express=False,
memory_limit=None):
memory_limit=None,
network_interface_names=None):
assert isinstance(bootstrap, ClientBootstrap) or bootstrap is None
assert isinstance(region, str)
assert isinstance(signing_config, AwsSigningConfig) or signing_config is None
Expand Down Expand Up @@ -284,6 +292,7 @@ def on_shutdown():
throughput_target_gbps,
enable_s3express,
memory_limit,
network_interface_names,
s3_client_core)

def make_request(
Expand Down
90 changes: 64 additions & 26 deletions source/s3_client.c
Original file line number Diff line number Diff line change
Expand Up @@ -245,22 +245,24 @@ PyObject *aws_py_s3_client_new(PyObject *self, PyObject *args) {

struct aws_allocator *allocator = aws_py_get_allocator();

PyObject *bootstrap_py; /* O */
PyObject *signing_config_py; /* O */
PyObject *credential_provider_py; /* O */
PyObject *tls_options_py; /* O */
PyObject *on_shutdown_py; /* O */
struct aws_byte_cursor region; /* s# */
int tls_mode; /* i */
uint64_t part_size; /* K */
uint64_t multipart_upload_threshold; /* K */
double throughput_target_gbps; /* d */
int enable_s3express; /* p */
uint64_t mem_limit; /* K */
PyObject *py_core; /* O */
PyObject *bootstrap_py; /* O */
PyObject *signing_config_py; /* O */
PyObject *credential_provider_py; /* O */
PyObject *tls_options_py; /* O */
PyObject *on_shutdown_py; /* O */
struct aws_byte_cursor region; /* s# */
int tls_mode; /* i */
uint64_t part_size; /* K */
uint64_t multipart_upload_threshold; /* K */
double throughput_target_gbps; /* d */
int enable_s3express; /* p */
uint64_t mem_limit; /* K */
PyObject *network_interface_names_py; /* O */
PyObject *py_core; /* O */

if (!PyArg_ParseTuple(
args,
"OOOOOs#iKKdpKO",
"OOOOOs#iKKdpKOO",
&bootstrap_py,
&signing_config_py,
&credential_provider_py,
Expand All @@ -274,6 +276,7 @@ PyObject *aws_py_s3_client_new(PyObject *self, PyObject *args) {
&throughput_target_gbps,
&enable_s3express,
&mem_limit,
&network_interface_names_py,
&py_core)) {
return NULL;
}
Expand Down Expand Up @@ -304,10 +307,16 @@ PyObject *aws_py_s3_client_new(PyObject *self, PyObject *args) {

struct aws_signing_config_aws *signing_config = NULL;
struct aws_credentials *anonymous_credentials = NULL;
struct aws_byte_cursor *network_interface_names = NULL;
size_t num_network_interface_names = 0;
PyObject *capsule = NULL;
/* From hereon, we need to clean up if errors occur */
bool success = false;

if (signing_config_py != Py_None) {
signing_config = aws_py_get_signing_config(signing_config_py);
if (!signing_config) {
return NULL;
goto cleanup;
}
} else if (credential_provider) {
aws_s3_init_default_signing_config(&default_signing_config, region, credential_provider);
Expand All @@ -321,13 +330,10 @@ PyObject *aws_py_s3_client_new(PyObject *self, PyObject *args) {

struct s3_client_binding *s3_client = aws_mem_calloc(allocator, 1, sizeof(struct s3_client_binding));

/* From hereon, we need to clean up if errors occur */

PyObject *capsule = PyCapsule_New(s3_client, s_capsule_name_s3_client, s_s3_client_capsule_destructor);
capsule = PyCapsule_New(s3_client, s_capsule_name_s3_client, s_s3_client_capsule_destructor);
if (!capsule) {
graebm marked this conversation as resolved.
Show resolved Hide resolved
aws_credentials_release(anonymous_credentials);
aws_mem_release(allocator, s3_client);
return NULL;
goto cleanup;
}

s3_client->on_shutdown = on_shutdown_py;
Expand All @@ -336,6 +342,33 @@ PyObject *aws_py_s3_client_new(PyObject *self, PyObject *args) {
s3_client->py_core = py_core;
Py_INCREF(s3_client->py_core);

if (network_interface_names_py != Py_None) {
if (!PySequence_Check(network_interface_names_py)) {
PyErr_SetString(PyExc_TypeError, "Expected network_interface_names to be a sequence.");
goto cleanup;
}
Py_ssize_t list_size = PySequence_Size(network_interface_names_py);
if (list_size < 0) {
goto cleanup;
}
num_network_interface_names = (size_t)list_size;
network_interface_names =
aws_mem_calloc(allocator, num_network_interface_names, sizeof(struct aws_byte_cursor));
for (size_t i = 0; i < num_network_interface_names; ++i) {
PyObject *str_obj = PySequence_GetItem(network_interface_names_py, i); /* New reference */
if (!str_obj) {
PyErr_SetString(PyExc_TypeError, "Expected network_interface_names elements to be non-null.");
waahm7 marked this conversation as resolved.
Show resolved Hide resolved
goto cleanup;
}
network_interface_names[i] = aws_byte_cursor_from_pyunicode(str_obj);
Py_DECREF(str_obj);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

😬 now that I'm thinking about it, this isn't 1000% safe either

if (in theory) the sequence wasn't a tuple or list but something with a custom __getitem__() that returned a new object, then when we decref it here, it would get cleaned up, and our char * would be pointing at discarded memory

I looked at how PySequence_Fast_GET_ITEM() could get away with a borrowed reference, and it does it by ensuring a list or tuple was passed it, and if it was something else, it creates a list instead so it can be sure the items live at least as long as whatever's returned (see code)

anyway blugh ugh ugh ugh ugh I realize this bug exists elsewhere in our C-bindings

maybe keep these items alive until cleanup, similar to the array of cursors. Like:

  • declare them before any gotos
    struct aws_byte_cursor *network_interface_name_cursors = NULL;
    PyObject **network_interface_name_pyobjects = NULL;
    
  • aws_mem_calloc() them at the same time
  • network_interface_name_pyobjects[i] = PySequence_GetItem(network_interface_names_py, i); /* New reference */
  • then in cleanup, loop through and Py_XDECREF() them all before aws_mem_release of the tmp array

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not TRYING to make your life miserable. I'm just like "oh hey PySequence exists, let's use it" ... "uh oh"

Another option is doing like PySequence_Fast .. and turning it into a list (if necessary) before processing it. In python we could just be like:

if network_interface_names is not None:
    # ensure this is a list, so it's simpler to process in C
    if not isinstance(network_interface_names, list):
        network_interface_names = list(network_interface_names)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, I have updated it to the second option.

if (network_interface_names[i].ptr == NULL) {
PyErr_SetString(PyExc_TypeError, "Expected all network_interface_names elements to be strings.");
goto cleanup;
}
}
}

struct aws_s3_client_config s3_config = {
.region = region,
.client_bootstrap = bootstrap,
Expand All @@ -349,18 +382,23 @@ PyObject *aws_py_s3_client_new(PyObject *self, PyObject *args) {
.shutdown_callback = s_s3_client_shutdown,
.shutdown_callback_user_data = s3_client,
.enable_s3express = enable_s3express,
.network_interface_names_array = network_interface_names,
.num_network_interface_names = num_network_interface_names,
};

s3_client->native = aws_s3_client_new(allocator, &s3_config);
if (s3_client->native == NULL) {
PyErr_SetAwsLastError();
goto error;
goto cleanup;
}
aws_credentials_release(anonymous_credentials);
return capsule;
success = true;

error:
cleanup:
aws_credentials_release(anonymous_credentials);
Py_DECREF(capsule);
return NULL;
aws_mem_release(allocator, network_interface_names);
if (!success) {
Py_XDECREF(capsule);
return NULL;
}
return capsule;
}
10 changes: 8 additions & 2 deletions test/test_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,8 @@ def s3_client_new(
part_size=0,
is_cancel_test=False,
enable_s3express=False,
mem_limit=None):
mem_limit=None,
network_interface_names=None):

if is_cancel_test:
# for cancellation tests, make things slow, so it's less likely that
Expand Down Expand Up @@ -189,7 +190,8 @@ def s3_client_new(
part_size=part_size,
throughput_target_gbps=throughput_target_gbps,
enable_s3express=enable_s3express,
memory_limit=mem_limit)
memory_limit=mem_limit,
network_interface_names=network_interface_names)

return s3_client

Expand Down Expand Up @@ -221,6 +223,10 @@ def test_sanity_secure(self):
s3_client = s3_client_new(True, self.region)
self.assertIsNotNone(s3_client)

def test_sanity_network_interface_names(self):
s3_client = s3_client_new(True, self.region, network_interface_names=["eth0", "eth1"])
self.assertIsNotNone(s3_client)

def test_wait_shutdown(self):
s3_client = s3_client_new(False, self.region)
self.assertIsNotNone(s3_client)
Expand Down
Loading