C2DEVEL-16211: Index optimizations #4

Open · wants to merge 9 commits into master
Changes from 1 commit
locks: hide sensitive actions behind locks
Yurzs committed Sep 26, 2024
commit 195bf758330f11d619a230c7457ab22f5dd14e1a
27 changes: 23 additions & 4 deletions ihashmap/cache.py
@@ -1,11 +1,12 @@
 import functools
 import logging
+import threading
 from typing import Any, Callable, Generator, List, Mapping, Optional, Union
 
 from typing_extensions import Protocol, Self, final
 
 from ihashmap.action import Action
-from ihashmap.helpers import match_query
+from ihashmap.helpers import locked, match_query
 
 LOG = logging.getLogger(__name__)

@@ -33,6 +34,9 @@ def delete(self, name: str, key: str) -> None:
     def keys(self, name: str) -> List[str]:
         ...
 
+    def pop(self, name: str, key: str, default: Optional[Any] = None) -> Optional[str]:
+        ...
+

@final
class PipelineContext:
@@ -175,6 +179,9 @@ class Cache:
     __INSTANCE__ = None
     """Singleton instance."""
 
+    LOCK = threading.RLock()
+    """Lock for thread-safe operations."""
+
     def __new__(cls, protocol: CacheProtocol) -> Self:
         """Singleton instance creation."""

@@ -196,6 +203,7 @@ def instance(cls):
         return cls.__INSTANCE__
 
     @PIPELINE.set
+    @locked
     def set(self, name: str, key: str, value: Mapping[str, Any]) -> None:
         """Wrapper for pipeline execution.
 
@@ -228,6 +236,7 @@ def get(self, name: str, key: str, default: Optional[Any] = None):
         return self.protocol.get(name, key, default)
 
     @PIPELINE.update
+    @locked
     def update(self, name: str, key: str, value: Mapping[str, Any]) -> None:
         """Wrapper for pipeline execution.
 
@@ -236,14 +245,18 @@ def update(self, name: str, key: str, value: Mapping[str, Any]) -> None:
         :param dict value: stored value.
         """
 
-        if value.get(self.PRIMARY_KEY) is not None and value.get(self.PRIMARY_KEY) != key:
+        if (
+            value.get(self.PRIMARY_KEY) is not None
+            and value.get(self.PRIMARY_KEY) != key
+        ):
             raise ValueError(
                 f"Primary key mismatch: {value[self.PRIMARY_KEY]} != {key}"
             )
 
         return self.protocol.update(name, key, value)
 
     @PIPELINE.delete
+    @locked
     def delete(self, name: str, key: str) -> None:
         """Wrapper for pipeline execution.
 
@@ -253,6 +266,7 @@ def delete(self, name: str, key: str) -> None:
 
         return self.protocol.delete(name, key)
 
+    @locked
     def all(self, name: str) -> Generator[Mapping[str, Any], None, None]:
         """Finds all values in cache.
 
@@ -264,6 +278,7 @@ def all(self, name: str) -> Generator[Mapping[str, Any], None, None]:
             if value is not None:
                 yield value
 
+    @locked
     def search(
         self,
         name: str,
@@ -310,18 +325,22 @@ def search(
 
         if not hit_indexes:
             LOG.warning(
-                "Complete index miss for query: %s. Query will be slow.", search_query
+                "Complete index miss for %s query: %s. Query will be slow.",
+                name,
+                search_query,
             )
 
             matched = [{self.PRIMARY_KEY: key} for key in self.protocol.keys(name)]
             rest_query = search_query
 
         result = []
         for value in matched:
             entity = self.protocol.get(name, value[self.PRIMARY_KEY])
-            result += match_query(entity, rest_query if hit_indexes else search_query)
+            result += match_query(entity, rest_query)
 
         return result
 
+    @locked
     def find_all(self, name: str) -> Generator[Union[Mapping, List[str]], None, None]:
         """Internal method to get all values from cache."""
 
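Note on the locking scheme: Cache.LOCK is a class-level threading.RLock rather than a plain Lock, and the @locked decorator (added in ihashmap/helpers.py below) simply wraps each call in "with self.LOCK". The reentrant variant is presumably chosen so that a thread already holding the lock can enter another locked method without deadlocking. A minimal sketch of the pattern using a hypothetical ToyCache class (not the real Cache):

import threading

from ihashmap.helpers import locked  # helper added by this PR


class ToyCache:
    LOCK = threading.RLock()  # reentrant: the owning thread may re-acquire it

    def __init__(self):
        self.data = {}

    @locked
    def set(self, key, value):
        self.data[key] = value

    @locked
    def get_or_set(self, key, default):
        # Calls another @locked method while LOCK is already held;
        # with a plain threading.Lock this nested acquire would deadlock.
        if key not in self.data:
            self.set(key, default)
        return self.data[key]


cache = ToyCache()
assert cache.get_or_set("a", 1) == 1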
12 changes: 12 additions & 0 deletions ihashmap/helpers.py
@@ -1,3 +1,4 @@
+import functools
from types import FunctionType
from typing import Any, List, Mapping

@@ -27,3 +28,14 @@ def match_query(
             matched.append(value)
 
     return matched
+
+
+def locked(f):
+    """Decorator for thread-safe methods."""
+
+    @functools.wraps(f)
+    def wrapper(self, *args, **kwargs):
+        with self.LOCK:
+            return f(self, *args, **kwargs)
+
+    return wrapper
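The locked helper is deliberately duck-typed: it only assumes that the first positional argument of the wrapped function (self for instance methods, or cls for classmethods when @classmethod sits above @locked, as in ihashmap/index.py below) exposes a LOCK attribute usable as a context manager. A small self-contained usage sketch with a hypothetical ToyCounter class:

import threading

from ihashmap.helpers import locked


class ToyCounter:
    LOCK = threading.Lock()  # any context-manager lock works here
    value = 0

    @classmethod
    @locked
    def bump(cls):
        # The wrapper receives cls as its first argument, so cls.LOCK is used.
        cls.value += 1


threads = [threading.Thread(target=ToyCounter.bump) for _ in range(100)]
for t in threads:
    t.start()
for t in threads:
    t.join()
assert ToyCounter.value == 100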
107 changes: 51 additions & 56 deletions ihashmap/index.py
@@ -2,10 +2,11 @@
 import threading
 from functools import partial
 from types import FunctionType
-from typing import Any, Dict, List, Mapping, Optional, Set, Tuple, TypeVar, Union
+from typing import (Any, Dict, List, Mapping, Optional, Set, Tuple, TypeVar,
+                    Union)
 
 from ihashmap.cache import Cache, PipelineContext
-from ihashmap.helpers import match_query
+from ihashmap.helpers import locked, match_query
 
 T = TypeVar("T")
 
@@ -60,7 +61,9 @@ def get_keys(cls) -> List[str]:
 
         result = []
         for key in cls.keys:
-            result.append(key if key != cls.PK_KEY_PLACEHOLDER else cls.cache().PRIMARY_KEY)
+            result.append(
+                key if key != cls.PK_KEY_PLACEHOLDER else cls.cache().PRIMARY_KEY
+            )
 
         return result
 
@@ -114,7 +117,7 @@ def before_create(cls, ctx: PipelineContext):
         """Stores original value for after_create usage."""
 
         key, value = ctx.args
-        ctx.local_data["original_value"] = value
+        ctx.local_data["value"] = value
 
         if cls.unique:
             index_key = cls.get_key(value)
@@ -130,17 +133,7 @@ def after_create(cls, ctx: PipelineContext):
         :return:
         """
 
-        value: Mapping = ctx.local_data["original_value"]
-
-        with cls.LOCK:
-            pk = value[cls.cache().PRIMARY_KEY]
-            key = cls.get_key(value)
-
-            index_value: Set[str] = set(cls.get(ctx.name, key, default=[]))
-            index_value.add(pk)
-
-            cls.set(ctx.name, cls.get_key(value), list(index_value))
-            cls.set(ctx.name, pk, key, reverse=True)
+        cls.append(ctx.name, ctx.local_data["value"])
 
     @classmethod
     def before_delete(cls, ctx: PipelineContext):
@@ -149,9 +142,7 @@ def before_delete(cls, ctx: PipelineContext):
         :param ctx: PipelineManager context.
         """
 
-        ctx.local_data["pk"] = cls.cache().protocol.get(
-            ctx.name, ctx.args[0][cls.cache().PRIMARY_KEY]
-        )
+        ctx.local_data["value"] = cls.cache().protocol.get(ctx.name, ctx.args[0])
 
     @classmethod
     def after_delete(cls, ctx: PipelineContext):
@@ -160,19 +151,15 @@ def after_delete(cls, ctx: PipelineContext):
         :param dict ctx: PipelineManager context.
         """
 
-        with cls.LOCK:
-            key = cls.get(ctx.name, ctx.local_data["pk"], reverse=True)
-
-            if key is not None:
-                cls.delete(ctx.name, key)
-                cls.delete(ctx.name, ctx.local_data["pk"], reverse=True)
+        if ctx.local_data["value"]:
+            cls.remove(ctx.name, ctx.local_data["value"])
 
     @classmethod
     def before_update(cls, ctx: PipelineContext):
         """Creates value copy for after_update usage."""
 
-        key, value = ctx.args
-        ctx.local_data["original_value"] = value
+        key, _ = ctx.args
+        ctx.local_data["value"] = cls.cache().protocol.get(ctx.name, key)
 
     @classmethod
     def after_update(cls, ctx: PipelineContext):
@@ -181,19 +168,8 @@ def after_update(cls, ctx: PipelineContext):
         :param dict ctx: PipelineManager context.
         """
 
-        value = ctx.local_data["original_value"]
-
-        with cls.LOCK:
-            key = cls.get_key(value)
-            pk = value[cls.cache().PRIMARY_KEY]
-
-            index_key = cls.get(ctx.name, pk, reverse=True)
-            if index_key is not None:
-                cls.delete(ctx.name, index_key)
-                cls.delete(ctx.name, pk, reverse=True)
-
-            cls.set(ctx.name, key, pk)
-            cls.set(ctx.name, pk, key, reverse=True)
+        cls.remove(ctx.name, ctx.local_data["value"])
+        cls.append(ctx.name, ctx.args[1])
 
     @classmethod
     def find_index_for_cache(cls, cache_name: str) -> List["Index"]:
@@ -213,35 +189,54 @@ def get(
         cls,
         cache_name: str,
         key: str,
-        reverse: bool = False,
         default: Optional[T] = None,
     ) -> Union[List[str], str, T]:
         return cls.cache().protocol.get(
-            cls.get_name(cache_name, reverse=reverse),
+            cls.get_name(cache_name),
             key,
             default=default,
         )
 
     @classmethod
-    def keys_(cls, cache_name: str, reverse: bool = False):
-        return cls.cache().protocol.keys(cls.get_name(cache_name, reverse=reverse))
+    def keys_(cls, cache_name: str):
+        return cls.cache().protocol.keys(cls.get_name(cache_name))
 
     @classmethod
     @Cache.PIPELINE.index_set
-    def set(
-        cls, cache_name, key: str, value: Union[List[str], str], reverse: bool = False
-    ) -> None:
-        return cls.cache().protocol.set(
-            cls.get_name(cache_name, reverse=reverse), key, value
-        )
+    @locked
+    def append(cls, cache_name, value: Mapping[str, Any]) -> None:
+        """Appends value to index."""
+
+        index_key = cls.get_key(value)
+        value_pk = value[cls.cache().PRIMARY_KEY]
+
+        current_value = set(cls.get(cache_name, index_key, default=[]))
+        current_value.add(value_pk)
+
+        cls.cache().protocol.set(cls.get_name(cache_name), index_key, list(current_value))
+        cls.cache().protocol.set(
+            cls.get_name(cache_name, reverse=True),
+            value_pk,
+            index_key,
+        )
 
     @classmethod
     @Cache.PIPELINE.index_delete
-    def delete(cls, cache_name: str, key: str, reverse: bool = False) -> None:
-        return cls.cache().protocol.delete(
-            cls.get_name(cache_name, reverse=reverse), key
-        )
+    @locked
+    def remove(cls, cache_name: str, entity: Mapping[str, Any]) -> None:
+        """Removes value from index."""
+
+        entity_pk = entity[cls.cache().PRIMARY_KEY]
+
+        index_key = cls.cache().protocol.pop(
+            cls.get_name(cache_name, reverse=True),
+            entity_pk,
+            default=None,
+        )
+
+        if index_key is not None:
+            cls.cache().protocol.delete(cls.get_name(cache_name), index_key)
 
     @classmethod
     def combine(
         cls,
@@ -254,7 +249,7 @@ def combine(
         combined_index_keys = set()
         matches: List[Mapping[str, Any]] = []
 
-        cache_pk = cls.cache().PRIMARY_KEY
+        pk_key = cls.cache().PRIMARY_KEY
 
         for index in indexes:
             subquery = index.cut_data(query, exclude_none=True)
@@ -275,21 +270,21 @@ def combine(
                 index_data = filter(filter_func, index_data)
 
                 matches.extend(
-                    {cache_pk: pk, **index.cut_data(i)}
+                    {pk_key: pk, **index.cut_data(i)}
                     for i in index_data
                     for pk in index.get(cache_name, index.get_key(i), default=[])
                 )
 
             else:
                 search_key = index.get_key(subquery)
                 matches.extend(
-                    {cache_pk: pk, **index.cut_data(subquery)}
+                    {pk_key: pk, **index.cut_data(subquery)}
                     for pk in index.get(cache_name, search_key, default=[])
                 )
 
         combined_index = {}
         for match in matches:
-            combined_index.setdefault(match[cache_pk], {}).update(match)
+            combined_index.setdefault(match[pk_key], {}).update(match)
 
         result = []
 
@@ -302,7 +297,7 @@ def combine(
             result.append(doc)
 
         return (
-            sorted(result, key=lambda v: v[cache_pk]),
+            sorted(result, key=lambda v: v[pk_key]),
             combined_index_keys,
         )
Expand Down