Implement exclusive/atomic write (#917)
martindurant authored Dec 5, 2024
1 parent b6a0095 commit ca949ab
Showing 10 changed files with 171 additions and 113 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/ci.yml
@@ -10,7 +10,6 @@ jobs:
fail-fast: false
matrix:
python-version:
- "3.8"
- "3.9"
- "3.10"
- "3.11"
@@ -39,9 +38,10 @@ jobs:
shell: bash -l {0}
run: |
pip install git+https://github.com/fsspec/filesystem_spec
pip install --upgrade "aiobotocore${{ matrix.aiobotocore-version }}" boto3 # boto3 to ensure compatibility
pip install --upgrade "aiobotocore${{ matrix.aiobotocore-version }}"
pip install --upgrade "botocore" --no-deps
pip install . --no-deps
pip show aiobotocore boto3 botocore
pip list
- name: Run Tests
shell: bash -l {0}
2 changes: 1 addition & 1 deletion ci/env.yaml
@@ -14,6 +14,6 @@ dependencies:
- black
- httpretty
- aiobotocore
- "moto>=4,<5"
- moto
- flask
- fsspec
1 change: 1 addition & 0 deletions docs/source/changelog.rst
@@ -7,6 +7,7 @@ Changelog
- invalidate cache in one-shot pipe file (#904)
- make pipe() concurrent (#901)
- add py3.13 (#898)
- support R2 multi-part uploads (#888)

2024.9.0
--------
4 changes: 0 additions & 4 deletions pytest.ini
@@ -1,6 +1,2 @@
[pytest]
testpaths = s3fs
env =
BOTO_PATH=/dev/null
AWS_ACCESS_KEY_ID=dummy_key
AWS_SECRET_ACCESS_KEY=dummy_secret
165 changes: 104 additions & 61 deletions s3fs/core.py
@@ -642,6 +642,10 @@ def _open(
mode: string
One of 'r', 'w', 'a', 'rb', 'wb', or 'ab'. These have the same meaning
as they do for the built-in `open` function.
"x" mode, exclusive write, is only known to work on AWS S3, and
requires botocore>1.35.20. If the file is multi-part (i.e., has more
than one block), the condition is only checked on commit; if this fails,
the MPU is aborted.
block_size: int
Size of data-node blocks if reading
fill_cache: bool
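
A minimal usage sketch of the exclusive-write mode added here (bucket and key names are hypothetical); on AWS S3 with botocore>1.35.20, the second attempt fails with FileExistsError rather than overwriting:

    import s3fs

    s3 = s3fs.S3FileSystem()

    # "xb" creates the object only if the key does not exist yet.
    with s3.open("my-bucket/data/report.csv", "xb") as f:
        f.write(b"col_a,col_b\n1,2\n")

    # A second exclusive write to the same key raises FileExistsError
    # (translated from the PreconditionFailed response, see s3fs/errors.py below).
    try:
        with s3.open("my-bucket/data/report.csv", "xb") as f:
            f.write(b"overwritten?\n")
    except FileExistsError:
        print("object already exists; nothing was overwritten")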
@@ -1135,15 +1139,30 @@ async def _call_and_read():
return await _error_wrapper(_call_and_read, retries=self.retries)

async def _pipe_file(
self, path, data, chunksize=50 * 2**20, max_concurrency=None, **kwargs
self,
path,
data,
chunksize=50 * 2**20,
max_concurrency=None,
mode="overwrite",
**kwargs,
):
"""
mode=="create", exclusive write, is only known to work on AWS S3, and
requires botocore>1.35.20
"""
bucket, key, _ = self.split_path(path)
concurrency = max_concurrency or self.max_concurrency
size = len(data)
if mode == "create":
match = {"IfNoneMatch": "*"}
else:
match = {}

# 5 GB is the limit for an S3 PUT
if size < min(5 * 2**30, 2 * chunksize):
out = await self._call_s3(
"put_object", Bucket=bucket, Key=key, Body=data, **kwargs
"put_object", Bucket=bucket, Key=key, Body=data, **kwargs, **match
)
self.invalidate_cache(path)
return out
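
The same condition is reachable through pipe_file; a hedged sketch (bucket and key are hypothetical), where mode="create" becomes the IfNoneMatch="*" argument on the single-shot put_object above:

    import s3fs

    s3 = s3fs.S3FileSystem()

    # Small payload: one conditional PUT.
    s3.pipe_file("my-bucket/locks/run-1.json", b'{"owner": "worker-a"}', mode="create")

    # Re-running the call fails because the object now exists.
    try:
        s3.pipe_file("my-bucket/locks/run-1.json", b'{"owner": "worker-b"}', mode="create")
    except FileExistsError:
        print("lock already taken")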
@@ -1155,32 +1174,37 @@ async def _pipe_file(
ranges = list(range(0, len(data), chunksize))
inds = list(range(0, len(ranges), concurrency)) + [len(ranges)]
parts = []
for start, stop in zip(inds[:-1], inds[1:]):
out = await asyncio.gather(
*[
self._call_s3(
"upload_part",
Bucket=bucket,
PartNumber=i + 1,
UploadId=mpu["UploadId"],
Body=data[ranges[i] : ranges[i] + chunksize],
Key=key,
)
for i in range(start, stop)
]
)
parts.extend(
{"PartNumber": i + 1, "ETag": o["ETag"]}
for i, o in zip(range(start, stop), out)
try:
for start, stop in zip(inds[:-1], inds[1:]):
out = await asyncio.gather(
*[
self._call_s3(
"upload_part",
Bucket=bucket,
PartNumber=i + 1,
UploadId=mpu["UploadId"],
Body=data[ranges[i] : ranges[i] + chunksize],
Key=key,
)
for i in range(start, stop)
]
)
parts.extend(
{"PartNumber": i + 1, "ETag": o["ETag"]}
for i, o in zip(range(start, stop), out)
)
await self._call_s3(
"complete_multipart_upload",
Bucket=bucket,
Key=key,
UploadId=mpu["UploadId"],
MultipartUpload={"Parts": parts},
**match,
)
await self._call_s3(
"complete_multipart_upload",
Bucket=bucket,
Key=key,
UploadId=mpu["UploadId"],
MultipartUpload={"Parts": parts},
)
self.invalidate_cache(path)
self.invalidate_cache(path)
except Exception:
await self._abort_mpu(bucket, key, mpu["UploadId"])
raise

async def _put_file(
self,
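
The try/abort discipline added above can be sketched against a plain boto3 client. This is a simplified, synchronous illustration of the pattern, not the s3fs implementation: client, bucket, key and chunks are assumed, every chunk except the last must be at least 5 MiB, and IfNoneMatch on complete_multipart_upload needs botocore>1.35.20.

    import boto3

    def exclusive_multipart_put(client, bucket, key, chunks):
        """Upload chunks as one object, but only if key does not already exist."""
        mpu = client.create_multipart_upload(Bucket=bucket, Key=key)
        try:
            parts = []
            for i, chunk in enumerate(chunks):
                out = client.upload_part(
                    Bucket=bucket,
                    Key=key,
                    PartNumber=i + 1,
                    UploadId=mpu["UploadId"],
                    Body=chunk,
                )
                parts.append({"PartNumber": i + 1, "ETag": out["ETag"]})
            # The exclusivity condition is only evaluated here, on completion.
            client.complete_multipart_upload(
                Bucket=bucket,
                Key=key,
                UploadId=mpu["UploadId"],
                MultipartUpload={"Parts": parts},
                IfNoneMatch="*",
            )
        except Exception:
            # Leave no orphaned upload behind if a part or the completion fails.
            client.abort_multipart_upload(
                Bucket=bucket, Key=key, UploadId=mpu["UploadId"]
            )
            raise

    # Example call (names hypothetical):
    # exclusive_multipart_put(boto3.client("s3"), "my-bucket", "big/object.bin", chunks)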
@@ -1189,8 +1213,13 @@ async def _put_file(
callback=_DEFAULT_CALLBACK,
chunksize=50 * 2**20,
max_concurrency=None,
mode="overwrite",
**kwargs,
):
"""
mode=="create", exclusive write, is only known to work on AWS S3, and
requires botocore>1.35.20
"""
bucket, key, _ = self.split_path(rpath)
if os.path.isdir(lpath):
if key:
@@ -1200,6 +1229,10 @@
await self._mkdir(lpath)
size = os.path.getsize(lpath)
callback.set_size(size)
if mode == "create":
match = {"IfNoneMatch": "*"}
else:
match = {}

if "ContentType" not in kwargs:
content_type, _ = mimetypes.guess_type(lpath)
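
A corresponding usage sketch for the put side (local path and destination are hypothetical; the extra keyword is assumed to pass through the sync wrapper):

    import s3fs

    s3 = s3fs.S3FileSystem()

    # Upload a local file only if the destination key does not exist yet.
    try:
        s3.put_file(
            "results/summary.parquet",
            "my-bucket/results/summary.parquet",
            mode="create",
        )
    except FileExistsError:
        print("refusing to overwrite existing results")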
@@ -1210,33 +1243,39 @@ async def _put_file(
if size < min(5 * 2**30, 2 * chunksize):
chunk = f0.read()
await self._call_s3(
"put_object", Bucket=bucket, Key=key, Body=chunk, **kwargs
"put_object", Bucket=bucket, Key=key, Body=chunk, **kwargs, **match
)
callback.relative_update(size)
else:

mpu = await self._call_s3(
"create_multipart_upload", Bucket=bucket, Key=key, **kwargs
)
out = await self._upload_file_part_concurrent(
bucket,
key,
mpu,
f0,
callback=callback,
chunksize=chunksize,
max_concurrency=max_concurrency,
)
parts = [
{"PartNumber": i + 1, "ETag": o["ETag"]} for i, o in enumerate(out)
]
await self._call_s3(
"complete_multipart_upload",
Bucket=bucket,
Key=key,
UploadId=mpu["UploadId"],
MultipartUpload={"Parts": parts},
)
try:
out = await self._upload_file_part_concurrent(
bucket,
key,
mpu,
f0,
callback=callback,
chunksize=chunksize,
max_concurrency=max_concurrency,
)
parts = [
{"PartNumber": i + 1, "ETag": o["ETag"]}
for i, o in enumerate(out)
]
await self._call_s3(
"complete_multipart_upload",
Bucket=bucket,
Key=key,
UploadId=mpu["UploadId"],
MultipartUpload={"Parts": parts},
**match,
)
except Exception:
await self._abort_mpu(bucket, key, mpu["UploadId"])
raise
while rpath:
self.invalidate_cache(rpath)
rpath = self._parent(rpath)
@@ -1939,18 +1978,22 @@ async def _list_multipart_uploads(self, bucket):

list_multipart_uploads = sync_wrapper(_list_multipart_uploads)

async def _abort_mpu(self, bucket, key, mpu):
await self._call_s3(
"abort_multipart_upload",
Bucket=bucket,
Key=key,
UploadId=mpu,
)

abort_mpu = sync_wrapper(_abort_mpu)

async def _clear_multipart_uploads(self, bucket):
"""Remove any partial uploads in the bucket"""
out = await self._list_multipart_uploads(bucket)
await asyncio.gather(
*[
self._call_s3(
"abort_multipart_upload",
Bucket=bucket,
Key=upload["Key"],
UploadId=upload["UploadId"],
)
for upload in out
self._abort_mpu(bucket, upload["Key"], upload["UploadId"])
for upload in await self._list_multipart_uploads(bucket)
]
)
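
With the abort logic factored into _abort_mpu, cleaning up stale uploads from the synchronous API might look like this (bucket name and upload id are hypothetical, and the sync wrappers are assumed to mirror the async methods):

    import s3fs

    s3 = s3fs.S3FileSystem()

    # Abort every unfinished multipart upload left in a bucket; the cleanup
    # now funnels through the new _abort_mpu helper.
    s3.clear_multipart_uploads("my-bucket")

    # Or abort a single known upload directly.
    s3.abort_mpu("my-bucket", "stale/key", "example-upload-id")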

@@ -2412,13 +2455,18 @@ def commit(self):
raise RuntimeError
else:
logger.debug("Complete multi-part upload for %s " % self)
if "x" in self.mode:
match = {"IfNoneMatch": "*"}
else:
match = {}
part_info = {"Parts": self.parts}
write_result = self._call_s3(
"complete_multipart_upload",
Bucket=self.bucket,
Key=self.key,
UploadId=self.mpu["UploadId"],
MultipartUpload=part_info,
**match,
)

if self.fs.version_aware:
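
Because the condition rides on complete_multipart_upload, a multi-block exclusive write only fails at commit time, i.e. at close(). A hedged illustration, assuming my-bucket/big/object.bin already exists (names and sizes are hypothetical):

    import s3fs

    s3 = s3fs.S3FileSystem()
    data = b"x" * (12 * 2**20)  # ~12 MiB with 5 MiB blocks -> multipart upload

    f = s3.open("my-bucket/big/object.bin", "xb", block_size=5 * 2**20)
    f.write(data)          # parts upload without any existence check
    try:
        f.close()          # conditional complete_multipart_upload happens here
    except FileExistsError:
        pass               # per the docstring above, the MPU is aborted on this failure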
@@ -2441,12 +2489,7 @@ def discard(self):

def _abort_mpu(self):
if self.mpu:
self._call_s3(
"abort_multipart_upload",
Bucket=self.bucket,
Key=self.key,
UploadId=self.mpu["UploadId"],
)
self.fs.abort_mpu(self.bucket, self.key, self.mpu["UploadId"])
self.mpu = None


9 changes: 8 additions & 1 deletion s3fs/errors.py
@@ -137,11 +137,18 @@ def translate_boto_error(error, message=None, set_cause=True, *args, **kwargs):
recognized, an IOError with the original error message is returned.
"""
error_response = getattr(error, "response", None)

if error_response is None:
# non-http error, or response is None:
return error
code = error_response["Error"].get("Code")
constructor = ERROR_CODE_TO_EXCEPTION.get(code)
if (
code == "PreconditionFailed"
and error_response["Error"].get("Condition", "") == "If-None-Match"
):
constructor = FileExistsError
else:
constructor = ERROR_CODE_TO_EXCEPTION.get(code)
if constructor:
if not message:
message = error_response["Error"].get("Message", str(error))
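
As a quick sanity check of the new branch, a PreconditionFailed error whose Condition is If-None-Match should now translate to FileExistsError; the response dict below is hand-built for illustration:

    from botocore.exceptions import ClientError

    from s3fs.errors import translate_boto_error

    err = ClientError(
        {
            "Error": {
                "Code": "PreconditionFailed",
                "Message": "At least one of the pre-conditions you specified did not hold",
                "Condition": "If-None-Match",
            }
        },
        "PutObject",
    )
    assert isinstance(translate_boto_error(err), FileExistsError)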
38 changes: 5 additions & 33 deletions s3fs/tests/derived/s3fs_fixtures.py
@@ -75,45 +75,17 @@ def _get_boto3_client(self):

@pytest.fixture(scope="class")
def _s3_base(self):
# writable local S3 system
import shlex
import subprocess
# copy of s3_base in test_s3fs
from moto.moto_server.threaded_moto_server import ThreadedMotoServer

try:
# should fail since we didn't start server yet
r = requests.get(endpoint_uri)
except:
pass
else:
if r.ok:
raise RuntimeError("moto server already up")
server = ThreadedMotoServer(ip_address="127.0.0.1", port=port)
server.start()
if "AWS_SECRET_ACCESS_KEY" not in os.environ:
os.environ["AWS_SECRET_ACCESS_KEY"] = "foo"
if "AWS_ACCESS_KEY_ID" not in os.environ:
os.environ["AWS_ACCESS_KEY_ID"] = "foo"
proc = subprocess.Popen(
shlex.split("moto_server s3 -p %s" % port),
stderr=subprocess.DEVNULL,
stdout=subprocess.DEVNULL,
stdin=subprocess.DEVNULL,
)

timeout = 5
while timeout > 0:
try:
print("polling for moto server")
r = requests.get(endpoint_uri)
if r.ok:
break
except:
pass
timeout -= 0.1
time.sleep(0.1)
if proc.poll() is not None:
proc.terminate()
raise RuntimeError("Starting moto server failed")
print("server up")
yield
print("moto done")
proc.terminate()
proc.wait()
server.stop()
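
For reference, pointing a filesystem at the in-process moto server typically only needs the endpoint override (endpoint_uri and the dummy credentials come from the fixture module; the port value shown is illustrative):

    import s3fs

    fs = s3fs.S3FileSystem(
        key="foo",
        secret="foo",
        client_kwargs={"endpoint_url": endpoint_uri},  # e.g. "http://127.0.0.1:5555"
    )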
13 changes: 13 additions & 0 deletions s3fs/tests/derived/s3fs_test.py
@@ -1,3 +1,5 @@
import pytest

import fsspec.tests.abstract as abstract
from s3fs.tests.derived.s3fs_fixtures import S3fsFixtures

@@ -12,3 +14,14 @@ class TestS3fsGet(abstract.AbstractGetTests, S3fsFixtures):

class TestS3fsPut(abstract.AbstractPutTests, S3fsFixtures):
pass


class TestS3fsPipe(abstract.AbstractPipeTests, S3fsFixtures):
pass


class TestS3fsOpen(abstract.AbstractOpenTests, S3fsFixtures):

test_open_exclusive = pytest.mark.xfail(
reason="complete_multipart_upload doesn't implement condition in moto"
)(abstract.AbstractOpenTests.test_open_exclusive)