Fix another S3 multipart upload issue with marc exporter (PP-1693) (#2053)

* Fix another issue with multipart s3 upload and add additional test.
* Code review feedback
* Another fix
jonathangreen authored Sep 11, 2024
1 parent 95277ec commit aa8f01a
Showing 5 changed files with 106 additions and 32 deletions.
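
For readers unfamiliar with the underlying API, here is a minimal sketch of the S3 multipart upload flow this exporter wraps. This is illustrative boto3 usage, not code from this PR; the bucket and key names are placeholders.

```python
import boto3

s3 = boto3.client("s3")
bucket, key = "example-bucket", "example.mrc"  # placeholders

# Start the multipart upload and remember its id.
upload = s3.create_multipart_upload(Bucket=bucket, Key=key)
upload_id = upload["UploadId"]

parts = []
chunks = [b"a" * 5 * 1024 * 1024, b"tail"]  # only the last part may be < 5 MiB
for part_number, chunk in enumerate(chunks, start=1):
    # AWS requires PartNumber to be between 1 and 10000.
    response = s3.upload_part(
        Bucket=bucket,
        Key=key,
        UploadId=upload_id,
        PartNumber=part_number,
        Body=chunk,
    )
    parts.append({"ETag": response["ETag"], "PartNumber": part_number})

# Parts must be listed in ascending PartNumber order.
s3.complete_multipart_upload(
    Bucket=bucket,
    Key=key,
    UploadId=upload_id,
    MultipartUpload={"Parts": parts},
)
```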
50 changes: 27 additions & 23 deletions src/palace/manager/marc/uploader.py
@@ -6,7 +6,7 @@
 from typing_extensions import Self

 from palace.manager.service.redis.models.marc import MarcFileUploadSession
-from palace.manager.service.storage.s3 import S3Service
+from palace.manager.service.storage.s3 import MultipartS3UploadPart, S3Service
 from palace.manager.sqlalchemy.model.resource import Representation
 from palace.manager.util.log import LoggerMixin

@@ -40,6 +40,14 @@ def update_number(self) -> int:
     def add_record(self, key: str, record: bytes) -> None:
         self._buffers[key] += record.decode()

+    def _s3_upload_part(self, key: str, upload_id: str) -> MultipartS3UploadPart:
+        part_number, data = self.upload_session.get_part_num_and_buffer(key)
+        upload_part = self.storage_service.multipart_upload(
+            key, upload_id, part_number, data.encode()
+        )
+        self.upload_session.add_part_and_clear_buffer(key, upload_part)
+        return upload_part
+
     def _s3_sync(self, needs_upload: Sequence[str]) -> None:
         upload_ids = self.upload_session.get_upload_ids(needs_upload)
         for key in needs_upload:
@@ -50,18 +58,18 @@ def _s3_sync(self, needs_upload: Sequence[str]) -> None:
                 self.upload_session.set_upload_id(key, upload_id)
                 upload_ids[key] = upload_id

-            part_number, data = self.upload_session.get_part_num_and_buffer(key)
-            upload_part = self.storage_service.multipart_upload(
-                key, upload_ids[key], part_number, data.encode()
-            )
-            self.upload_session.add_part_and_clear_buffer(key, upload_part)
+            self._s3_upload_part(key, upload_ids[key])

-    def sync(self) -> None:
-        # First sync our buffers to redis
+    def _sync_buffers_to_redis(self) -> dict[str, int]:
         buffer_lengths = self.upload_session.append_buffers(self._buffers)
         self._buffers.clear()
+        return buffer_lengths

-        # Then, if any of our redis buffers are large enough, upload them to S3
+    def sync(self) -> None:
+        # First sync our buffers to redis
+        buffer_lengths = self._sync_buffers_to_redis()
+
+        # Then, if any of our redis buffers are large enough, sync them to S3.
         needs_upload = [
             key
             for key, length in buffer_lengths.items()
@@ -93,36 +101,32 @@ def _abort(self) -> None:
         self.remove_session()

     def complete(self) -> set[str]:
-        # Make sure any local data we have is synced
-        self.sync()
+        # Ensure any local data is synced to Redis.
+        self._sync_buffers_to_redis()

         in_progress = self.upload_session.get()
         for key, upload in in_progress.items():
             if upload.upload_id is None:
-                # We haven't started the upload. At this point there is no reason to start a
-                # multipart upload, just upload the file directly and continue.
+                # The multipart upload hasn't started. Perform a regular S3 upload since all data is in the buffer.
                 self.storage_service.store(
                     key, upload.buffer, Representation.MARC_MEDIA_TYPE
                 )
             else:
                 if upload.buffer != "":
-                    # Upload the last chunk if the buffer is not empty, the final part has no
-                    # minimum size requirement.
-                    upload_part = self.storage_service.multipart_upload(
-                        key, upload.upload_id, len(upload.parts), upload.buffer.encode()
-                    )
-                    upload.parts.append(upload_part)
+                    # Upload the last chunk if the buffer is not empty. The final part has no minimum size requirement.
+                    last_part = self._s3_upload_part(key, upload.upload_id)
+                    upload.parts.append(last_part)

-                # Complete the multipart upload
+                # Complete the multipart upload.
                 self.storage_service.multipart_complete(
                     key, upload.upload_id, upload.parts
                 )

-        # Delete our in-progress uploads data from redis
+        # Delete the in-progress uploads data from Redis.
         if in_progress:
             self.upload_session.clear_uploads()

-        # Return the keys that were uploaded
+        # Return the keys that were uploaded.
         return set(in_progress.keys())

     def remove_session(self) -> None:
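To see why `complete()` now delegates to `_s3_upload_part` rather than computing the part number from `len(upload.parts)`, consider this simplified illustration. It is not code from this PR; it assumes part numbers are 1-based, as the Redis model below now enforces.

```python
# Illustration only: how a 1-based part sequence collides with len()
# when numbering the final chunk.
uploaded_part_numbers = [1, 2, 3]

# Old approach in complete(): use len(upload.parts) for the final chunk.
# With 1-based numbering this re-uses the last number already taken.
old_final_part = len(uploaded_part_numbers)      # 3 -- collides with part 3

# New approach: _s3_upload_part asks the Redis session for the next
# number, which is the stored part count plus one.
new_final_part = len(uploaded_part_numbers) + 1  # 4 -- correct next part
assert new_final_part not in uploaded_part_numbers
```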
5 changes: 5 additions & 0 deletions src/palace/manager/service/redis/models/marc.py
@@ -298,6 +298,11 @@ def get_part_num_and_buffer(self, key: str) -> tuple[int, str]:
         )

         buffer_data: str = self._parse_value_or_raise(results[0])
+        # AWS S3 requires part numbers to start at 1, so we need to increment by 1.
+        #
+        # NOTE: This is not true in MinIO (our local development environment). MinIO
+        # allows both 0 and 1 as the first part number. Therefore, tests will pass if
+        # this is changed, but they will fail when running in an actual AWS environment.
         part_number: int = self._parse_value_or_raise(results[1]) + 1

         return part_number, buffer_data
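
A hedged sketch of the contract `get_part_num_and_buffer` now guarantees. The `FakeUploadSession` below is a stand-in for illustration, not the Redis-backed implementation; its locking and argument types are simplified.

```python
# Stand-in with the same 1-based numbering contract (illustration only).
class FakeUploadSession:
    def __init__(self) -> None:
        self.parts: list[str] = []
        self.buffer = ""

    def get_part_num_and_buffer(self, key: str) -> tuple[int, str]:
        # AWS part numbers are 1-based: stored part count plus one.
        return len(self.parts) + 1, self.buffer

    def add_part_and_clear_buffer(self, key: str, part: str) -> None:
        self.parts.append(part)
        self.buffer = ""

session = FakeUploadSession()
session.buffer = "record data"
assert session.get_part_num_and_buffer("key1") == (1, "record data")
session.add_part_and_clear_buffer("key1", "part-1")
assert session.get_part_num_and_buffer("key1") == (2, "")
```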
16 changes: 10 additions & 6 deletions tests/fixtures/s3.py
@@ -86,7 +86,7 @@ class MockMultipartUploadPart:
 class MockMultipartUpload:
     key: str
     upload_id: str
-    parts: list[MockMultipartUploadPart] = field(default_factory=list)
+    parts: dict[int, MockMultipartUploadPart] = field(default_factory=dict)
     content_type: str | None = None

@@ -143,8 +143,8 @@ def multipart_upload(
         part = MultipartS3UploadPart(etag=etag, part_number=part_number)
         assert key in self.upload_in_progress
         assert self.upload_in_progress[key].upload_id == upload_id
-        self.upload_in_progress[key].parts.append(
-            MockMultipartUploadPart(part, content)
+        self.upload_in_progress[key].parts[part_number] = MockMultipartUploadPart(
+            part, content
         )
         return part

@@ -154,11 +154,15 @@ def multipart_complete(
         assert key in self.upload_in_progress
         assert self.upload_in_progress[key].upload_id == upload_id
         complete_upload = self.upload_in_progress.pop(key)
-        for part_stored, part_passed_in in zip(complete_upload.parts, parts):
-            assert part_stored.part_data == part_passed_in
+        assert len(complete_upload.parts) == len(parts)
+        expected_parts = [x.part_data for x in complete_upload.parts.values()]
+        expected_parts.sort(key=lambda x: x.part_number)
+        assert parts == expected_parts
         self.uploads[key] = MockS3ServiceUpload(
             key,
-            b"".join(part_stored.content for part_stored in complete_upload.parts),
+            b"".join(
+                part_stored.content for part_stored in complete_upload.parts.values()
+            ),
             complete_upload.content_type,
         )
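
The switch from a list to a dict keyed by part number mirrors how S3 itself treats parts: the part number is an identifier, and re-uploading an existing number replaces the earlier data. A small illustration (not PR code):

```python
# S3 treats PartNumber as an identifier, so re-uploading an existing
# number replaces the earlier data. A dict keyed by part number models
# that; a list of arrivals cannot.
parts: dict[int, bytes] = {}
parts[1] = b"first attempt"
parts[1] = b"retry of part 1"  # overwrites, as S3 would
parts[2] = b"second part"

# Assembly joins content in part-number order, like multipart_complete.
assembled = b"".join(parts[n] for n in sorted(parts))
assert assembled == b"retry of part 1second part"
```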

61 changes: 59 additions & 2 deletions tests/manager/marc/test_uploader.py
@@ -10,7 +10,7 @@
 )
 from palace.manager.sqlalchemy.model.resource import Representation
 from tests.fixtures.redis import RedisFixture
-from tests.fixtures.s3 import S3ServiceFixture
+from tests.fixtures.s3 import S3ServiceFixture, S3ServiceIntegrationFixture


 class MarcUploadManagerFixture:
@@ -189,7 +189,7 @@ def test_sync(self, marc_upload_manager_fixture: MarcUploadManagerFixture):
         ]
         assert upload.upload_id is not None
         assert upload.content_type is Representation.MARC_MEDIA_TYPE
-        [part] = upload.parts
+        [part] = upload.parts.values()
         assert part.content == marc_upload_manager_fixture.test_record1 * 5

         # And the s3 part data and upload_id is synced to redis
@@ -332,3 +332,60 @@ def test__abort(

         # The redis record should have been deleted
         mock_delete.assert_called_once()
+
+    def test_real_storage_service(
+        self,
+        redis_fixture: RedisFixture,
+        s3_service_integration_fixture: S3ServiceIntegrationFixture,
+    ):
+        """
+        Full end-to-end test of the MarcUploadManager using the real S3Service
+        """
+        s3_service = s3_service_integration_fixture.public
+        uploads = MarcFileUploadSession(redis_fixture.client, 99)
+        uploader = MarcUploadManager(s3_service, uploads)
+        batch_size = s3_service.MINIMUM_MULTIPART_UPLOAD_SIZE + 1
+
+        with uploader.begin() as locked:
+            assert locked
+
+            # Test three buffer size cases for the complete() method.
+            #
+            # 1. A small record that isn't in S3 at the time complete() is called (test1).
+            # 2. A large record that needs to be uploaded in parts. On the first sync()
+            #    call, its buffer is large enough to trigger an upload. When complete() is
+            #    called, the buffer has data waiting for upload (test2).
+            # 3. A large record that needs to be uploaded in parts. On the first sync()
+            #    call, its buffer is large enough to trigger the upload. When complete()
+            #    is called, the buffer is empty (test3).
+
+            uploader.add_record("test1", b"test_record")
+            uploader.add_record("test2", b"a" * batch_size)
+            uploader.add_record("test3", b"b" * batch_size)
+
+            # Start the sync. This will begin the multipart upload for test2 and test3.
+            uploader.sync()
+
+            # Add some more data
+            uploader.add_record("test1", b"test_record")
+            uploader.add_record("test2", b"a" * batch_size)
+
+            # Complete the uploads
+            completed = uploader.complete()
+
+        assert completed == {"test1", "test2", "test3"}
+        assert uploads.get() == {}
+        assert set(s3_service_integration_fixture.list_objects("public")) == completed
+
+        assert (
+            s3_service_integration_fixture.get_object("public", "test1")
+            == b"test_record" * 2
+        )
+        assert (
+            s3_service_integration_fixture.get_object("public", "test2")
+            == b"a" * batch_size * 2
+        )
+        assert (
+            s3_service_integration_fixture.get_object("public", "test3")
+            == b"b" * batch_size
+        )
6 changes: 5 additions & 1 deletion tests/manager/service/redis/models/test_marc.py
@@ -414,7 +414,8 @@ def test_get_part_num_and_buffer(

         marc_file_upload_session_fixture.load_test_data()

-        # If the buffer has been set, but no parts have been added
+        # If the buffer has been set but no parts have been added, the first part number
+        # should be 1 and the buffer should be the same as the original data.
         assert uploads.get_part_num_and_buffer(
             marc_file_upload_session_fixture.mock_upload_key_1
         ) == (
@@ -426,10 +427,12 @@
         with uploads.lock() as locked:
             assert locked
+            # Add part 1
             uploads.add_part_and_clear_buffer(
                 marc_file_upload_session_fixture.mock_upload_key_1,
                 marc_file_upload_session_fixture.part_1,
             )
+            # Add part 2
             uploads.add_part_and_clear_buffer(
                 marc_file_upload_session_fixture.mock_upload_key_1,
                 marc_file_upload_session_fixture.part_2,
@@ -440,6 +443,7 @@
             }
         )

+        # The next part number should be 3, and the buffer should be the new data
         assert uploads.get_part_num_and_buffer(
             marc_file_upload_session_fixture.mock_upload_key_1
         ) == (3, "1234567")
