Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issues caused by differences between redis and elasticache (PP-1693) #2045

Merged
merged 5 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 27 additions & 22 deletions src/palace/manager/celery/tasks/marc.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,32 +133,37 @@ def marc_export_collection(
# Sync the upload_manager to ensure that all the data is written to storage.
upload_manager.sync()

if len(works) == batch_size:
# This task is complete, but there are more works waiting to be exported. So we requeue ourselves
# to process the next batch.
raise task.replace(
marc_export_collection.s(
collection_id=collection_id,
start_time=start_time,
libraries=[l.dict() for l in libraries_info],
batch_size=batch_size,
last_work_id=works[-1].id,
update_number=upload_manager.update_number,
if len(works) != batch_size:
# We have finished generating MARC records. Cleanup and exit.
with task.transaction() as session:
collection = MarcExporter.collection(session, collection_id)
collection_name = collection.name if collection else "unknown"
completed_uploads = upload_manager.complete()
MarcExporter.create_marc_upload_records(
session,
start_time,
collection_id,
libraries_info,
completed_uploads,
)
upload_manager.remove_session()
task.log.info(
f"Finished generating MARC records for collection '{collection_name}' ({collection_id})."
)
return

# If we got here, we have finished generating MARC records. Cleanup and exit.
with task.transaction() as session:
collection = MarcExporter.collection(session, collection_id)
collection_name = collection.name if collection else "unknown"
completed_uploads = upload_manager.complete()
MarcExporter.create_marc_upload_records(
session, start_time, collection_id, libraries_info, completed_uploads
)
upload_manager.remove_session()
task.log.info(
f"Finished generating MARC records for collection '{collection_name}' ({collection_id})."
# This task is complete, but there are more works waiting to be exported. So we requeue ourselves
# to process the next batch.
raise task.replace(
marc_export_collection.s(
collection_id=collection_id,
start_time=start_time,
libraries=[l.dict() for l in libraries_info],
batch_size=batch_size,
last_work_id=works[-1].id,
update_number=upload_manager.update_number,
)
)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The changes in src/palace/manager/celery/tasks/marc.py could come in as a separate PR if desired. They fix a concurrency issue I found while doing other testing. Occasionally, since we were re-queuing while still holding the lock, another worker would pick up the new task before this one had released the lock. This change reorders the operations so that we release the lock before calling task.replace.



@shared_task(queue=QueueNames.default, bind=True)
Expand Down
81 changes: 81 additions & 0 deletions src/palace/manager/service/redis/escape.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from __future__ import annotations

import json
from functools import cached_property

from palace.manager.core.exceptions import PalaceValueError


class JsonPathEscapeMixin:
    r"""
    Mixin providing escape/unescape helpers so arbitrary strings can be used as
    object keys in Redis / AWS ElastiCache and still be addressed via JSONPath.

    Some characters are mishandled by AWS ElastiCache, and others are
    problematic for Redis itself:

    - In ElastiCache, a ``~`` in a key causes updates to be silently dropped
      (the server reports success but the key is never written), and a ``/``
      is interpreted as a path separator, nesting a new key for every slash.
    - In Redis, ``\`` acts as an escape character and ``"`` terminates a
      string in a JSONPath expression.

    To keep such keys safe and queryable via JSONPath, each problematic
    character is replaced by the escape character (a backtick) followed by a
    single substitute character taken from ``_MAPPING``. A literal backtick is
    represented by doubling it.
    """

    _ESCAPE_CHAR = "`"

    _MAPPING = {
        "/": "s",
        "\\": "b",
        '"': "'",
        "~": "t",
    }

    @cached_property
    def _FORWARD_MAPPING(self) -> dict[str, str]:
        # Problematic char -> backtick-prefixed substitute; the backtick
        # itself escapes to a doubled backtick.
        forward = {
            original: self._ESCAPE_CHAR + substitute
            for original, substitute in self._MAPPING.items()
        }
        forward[self._ESCAPE_CHAR] = self._ESCAPE_CHAR * 2
        return forward

    @cached_property
    def _REVERSE_MAPPING(self) -> dict[str, str]:
        # Substitute char -> original char; a backtick maps back to itself.
        reverse = {
            substitute: original for original, substitute in self._MAPPING.items()
        }
        reverse[self._ESCAPE_CHAR] = self._ESCAPE_CHAR
        return reverse

    def _escape_path(self, path: str, elasticache: bool = False) -> str:
        """Escape ``path`` for use as an object key; see the class docstring."""
        forward = self._FORWARD_MAPPING
        escaped = "".join(forward.get(char, char) for char in path)
        if not elasticache:
            return escaped
        # On top of the simple escaping above, ElastiCache needs the path fully
        # escaped as if it were a JSON string, so run it through json.dumps and
        # strip the surrounding quotes, keeping only the escaped content.
        return json.dumps(escaped)[1:-1]

    def _unescape_path(self, path: str) -> str:
        """Reverse ``_escape_path``, raising PalaceValueError on bad input."""
        unescaped: list[str] = []
        pending_escape = False
        for char in path:
            if pending_escape:
                original = self._REVERSE_MAPPING.get(char)
                if original is None:
                    raise PalaceValueError(
                        f"Invalid escape sequence '{self._ESCAPE_CHAR}{char}'"
                    )
                unescaped.append(original)
                pending_escape = False
            elif char == self._ESCAPE_CHAR:
                pending_escape = True
            else:
                unescaped.append(char)

        if pending_escape:
            raise PalaceValueError("Unterminated escape sequence.")

        return "".join(unescaped)
16 changes: 15 additions & 1 deletion src/palace/manager/service/redis/models/lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
from contextlib import contextmanager
from datetime import timedelta
from functools import cached_property
from typing import TypeVar, cast
from typing import Any, TypeVar, cast
from uuid import uuid4

from redis.exceptions import ResponseError

from palace.manager.celery.task import Task
from palace.manager.core.exceptions import BasePalaceException
from palace.manager.service.redis.redis import Redis
Expand Down Expand Up @@ -374,6 +376,18 @@ def _parse_value_or_raise(cls, value: Sequence[T] | None) -> T:
raise LockError(f"Could not parse value ({json.dumps(value)})")
return parsed_value

@staticmethod
def _validate_pipeline_results(results: list[Any]) -> bool:
    """
    Check that every result of a Redis pipeline execution succeeded.

    NOTE: AWS ElastiCache reports failures slightly differently than Redis.
    In Redis, an unsuccessful result when a key is not found is `None`, but
    AWS returns it as a `ResponseError`, so both cases are treated as a
    failed result here.
    """
    for result in results:
        if isinstance(result, ResponseError) or not result:
            return False
    return True

def acquire(self) -> bool:
return (
self._acquire_script(
Expand Down
36 changes: 19 additions & 17 deletions src/palace/manager/service/redis/models/marc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
from typing import Any

from pydantic import BaseModel
from redis import ResponseError, WatchError
from redis import WatchError

from palace.manager.service.redis.escape import JsonPathEscapeMixin
from palace.manager.service.redis.models.lock import LockError, RedisJsonLock
from palace.manager.service.redis.redis import Pipeline, Redis
from palace.manager.service.storage.s3 import MultipartS3UploadPart
Expand Down Expand Up @@ -40,7 +41,7 @@ class MarcFileUploadState(StrEnum):
UPLOADING = auto()


class MarcFileUploadSession(RedisJsonLock, LoggerMixin):
class MarcFileUploadSession(RedisJsonLock, JsonPathEscapeMixin, LoggerMixin):
"""
This class is used as a lock for the Celery MARC export task, to ensure that only one
task can upload MARC files for a given collection at a time. It increments an update
Expand Down Expand Up @@ -106,7 +107,8 @@ def _upload_initial_value(buffer_data: str) -> dict[str, Any]:
return MarcFileUpload(buffer=buffer_data).dict(exclude_none=True)

def _upload_path(self, upload_key: str) -> str:
return f"{self._uploads_json_key}['{upload_key}']"
upload_key = self._escape_path(upload_key, self._redis_client.elasticache)
return f'{self._uploads_json_key}["{upload_key}"]'

def _buffer_path(self, upload_key: str) -> str:
upload_path = self._upload_path(upload_key)
Expand Down Expand Up @@ -166,7 +168,7 @@ def _execute_pipeline(
pipe.json().numincrby(self.key, self._update_number_json_key, updates)
pipe.pexpire(self.key, self._lock_timeout_ms)
try:
pipe_results = pipe.execute()
pipe_results = pipe.execute(raise_on_error=False)
except WatchError as e:
raise MarcFileUploadSessionError(
"Failed to update buffers. Another process is modifying the buffers."
Expand All @@ -184,6 +186,7 @@ def append_buffers(self, data: Mapping[str, str]) -> dict[str, int]:
existing_uploads: list[str] = self._parse_value_or_raise(
pipe.json().objkeys(self.key, self._uploads_json_key)
)
existing_uploads = [self._unescape_path(p) for p in existing_uploads]
pipe.multi()
for key, value in data.items():
if value == "":
Expand All @@ -195,14 +198,14 @@ def append_buffers(self, data: Mapping[str, str]) -> dict[str, int]:
else:
pipe.json().set(
self.key,
path=self._upload_path(key),
path=(self._upload_path(key)),
obj=self._upload_initial_value(value),
)
set_results[key] = len(value)

pipe_results = self._execute_pipeline(pipe, len(data))

if not all(pipe_results):
if not self._validate_pipeline_results(pipe_results):
raise MarcFileUploadSessionError("Failed to append buffers.")

return {
Expand All @@ -224,7 +227,7 @@ def add_part_and_clear_buffer(self, key: str, part: MultipartS3UploadPart) -> No
)
pipe_results = self._execute_pipeline(pipe, 1)

if not all(pipe_results):
if not self._validate_pipeline_results(pipe_results):
raise MarcFileUploadSessionError("Failed to add part and clear buffer.")

def set_upload_id(self, key: str, upload_id: str) -> None:
Expand All @@ -237,15 +240,15 @@ def set_upload_id(self, key: str, upload_id: str) -> None:
)
pipe_results = self._execute_pipeline(pipe, 1)

if not all(pipe_results):
if not self._validate_pipeline_results(pipe_results):
raise MarcFileUploadSessionError("Failed to set upload ID.")

def clear_uploads(self) -> None:
with self._pipeline() as pipe:
pipe.json().clear(self.key, self._uploads_json_key)
pipe_results = self._execute_pipeline(pipe, 1)

if not all(pipe_results):
if not self._validate_pipeline_results(pipe_results):
raise MarcFileUploadSessionError("Failed to clear uploads.")

def _get_specific(
Expand All @@ -269,7 +272,7 @@ def _get_all(self, key: str) -> dict[str, Any]:
if results is None:
return {}

return results
return {self._unescape_path(k): v for k, v in results.items()}

def get(self, keys: str | Sequence[str] | None = None) -> dict[str, MarcFileUpload]:
if keys is None:
Expand All @@ -285,15 +288,14 @@ def get_upload_ids(self, keys: str | Sequence[str]) -> dict[str, str]:
return self._get_specific(keys, self._upload_id_path)

def get_part_num_and_buffer(self, key: str) -> tuple[int, str]:
try:
with self._redis_client.pipeline() as pipe:
pipe.json().get(self.key, self._buffer_path(key))
pipe.json().arrlen(self.key, self._parts_path(key))
results = pipe.execute()
except ResponseError as e:
with self._redis_client.pipeline() as pipe:
pipe.json().get(self.key, self._buffer_path(key))
pipe.json().arrlen(self.key, self._parts_path(key))
results = pipe.execute(raise_on_error=False)
if not self._validate_pipeline_results(results):
raise MarcFileUploadSessionError(
"Failed to get part number and buffer data."
) from e
)

buffer_data: str = self._parse_value_or_raise(results[0])
part_number: int = self._parse_value_or_raise(results[1])
Expand Down
13 changes: 13 additions & 0 deletions src/palace/manager/service/redis/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from abc import ABC, abstractmethod
from collections.abc import Sequence
from functools import cached_property
from typing import TYPE_CHECKING, Any

import redis
Expand Down Expand Up @@ -76,6 +77,7 @@ def key_args(self, args: list[Any]) -> Sequence[str]:
cmd.name: cmd
for cmd in [
RedisCommandNoArgs("SCRIPT LOAD"),
RedisCommandNoArgs("INFO"),
RedisCommandArgs("KEYS"),
RedisCommandArgs("GET"),
RedisCommandArgs("EXPIRE"),
Expand Down Expand Up @@ -156,6 +158,17 @@ def pipeline(self, transaction: bool = True, shard_hint: Any = None) -> Pipeline
key_generator=self.get_key,
)

@cached_property
def elasticache(self) -> bool:
    """
    Detect whether this Redis client is actually connected to AWS ElastiCache
    rather than genuine Redis.

    AWS ElastiCache is supposed to be API compatible with Redis, but there are
    some behavioral differences that can cause issues. Callers can use this
    property to branch on those differences. The answer is derived from the
    "os" field of the server's INFO response and cached for the lifetime of
    this client.
    """
    server_info = self.info()
    return server_info.get("os") == "Amazon ElastiCache"


class Pipeline(RedisPipeline, RedisPrefixCheckMixin):
"""
Expand Down
11 changes: 7 additions & 4 deletions tests/manager/service/redis/models/test_marc.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import string

import pytest

from palace.manager.service.redis.models.marc import (
Expand All @@ -21,9 +23,10 @@ def __init__(self, redis_fixture: RedisFixture):
self._redis_fixture.client, self.mock_collection_id
)

self.mock_upload_key_1 = "test1"
self.mock_upload_key_2 = "test2"
self.mock_upload_key_3 = "test3"
# Some keys with special characters to make sure they are handled correctly.
self.mock_upload_key_1 = "test/test1/?$xyz.abc"
self.mock_upload_key_2 = "t'est💣/tëst2.\"ext`"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Love the 💣 here! 🤣

self.mock_upload_key_3 = string.printable

self.mock_unset_upload_key = "test4"

Expand All @@ -49,7 +52,7 @@ def load_test_data(self) -> dict[str, int]:

return return_value

def test_data_records(self, *keys: str):
def test_data_records(self, *keys: str) -> dict[str, MarcFileUpload]:
return {key: MarcFileUpload(buffer=self.test_data[key]) for key in keys}


Expand Down
Loading