Skip to content

Commit

Permalink
gh-89967: make WeakKeyDictionary and WeakValueDictionary thread safe (#…
Browse files Browse the repository at this point in the history
…125325)

Make `WeakKeyDictionary` and `WeakValueDictionary` thread safe by copying the underlying the dict before iterating over it.
  • Loading branch information
kumaraditya303 authored Oct 13, 2024
1 parent 0848932 commit cd0f9d1
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 174 deletions.
25 changes: 0 additions & 25 deletions Lib/_weakrefset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,31 +8,6 @@
__all__ = ['WeakSet']


class _IterationGuard:
# This context manager registers itself in the current iterators of the
# weak container, such as to delay all removals until the context manager
# exits.
# This technique should be relatively thread-safe (since sets are).

def __init__(self, weakcontainer):
# Don't create cycles
self.weakcontainer = ref(weakcontainer)

def __enter__(self):
w = self.weakcontainer()
if w is not None:
w._iterating.add(self)
return self

def __exit__(self, e, t, b):
w = self.weakcontainer()
if w is not None:
s = w._iterating
s.remove(self)
if not s:
w._commit_removals()


class WeakSet:
def __init__(self, data=None):
self.data = set()
Expand Down
198 changes: 49 additions & 149 deletions Lib/weakref.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
ReferenceType,
_remove_dead_weakref)

from _weakrefset import WeakSet, _IterationGuard
from _weakrefset import WeakSet

import _collections_abc # Import after _weakref to avoid circular import.
import sys
Expand Down Expand Up @@ -105,53 +105,27 @@ def __init__(self, other=(), /, **kw):
def remove(wr, selfref=ref(self), _atomic_removal=_remove_dead_weakref):
self = selfref()
if self is not None:
if self._iterating:
self._pending_removals.append(wr.key)
else:
# Atomic removal is necessary since this function
# can be called asynchronously by the GC
_atomic_removal(self.data, wr.key)
# Atomic removal is necessary since this function
# can be called asynchronously by the GC
_atomic_removal(self.data, wr.key)
self._remove = remove
# A list of keys to be removed
self._pending_removals = []
self._iterating = set()
self.data = {}
self.update(other, **kw)

def _commit_removals(self, _atomic_removal=_remove_dead_weakref):
pop = self._pending_removals.pop
d = self.data
# We shouldn't encounter any KeyError, because this method should
# always be called *before* mutating the dict.
while True:
try:
key = pop()
except IndexError:
return
_atomic_removal(d, key)

def __getitem__(self, key):
if self._pending_removals:
self._commit_removals()
o = self.data[key]()
if o is None:
raise KeyError(key)
else:
return o

def __delitem__(self, key):
if self._pending_removals:
self._commit_removals()
del self.data[key]

def __len__(self):
if self._pending_removals:
self._commit_removals()
return len(self.data)

def __contains__(self, key):
if self._pending_removals:
self._commit_removals()
try:
o = self.data[key]()
except KeyError:
Expand All @@ -162,38 +136,28 @@ def __repr__(self):
return "<%s at %#x>" % (self.__class__.__name__, id(self))

def __setitem__(self, key, value):
if self._pending_removals:
self._commit_removals()
self.data[key] = KeyedRef(value, self._remove, key)

def copy(self):
if self._pending_removals:
self._commit_removals()
new = WeakValueDictionary()
with _IterationGuard(self):
for key, wr in self.data.items():
o = wr()
if o is not None:
new[key] = o
for key, wr in self.data.copy().items():
o = wr()
if o is not None:
new[key] = o
return new

__copy__ = copy

def __deepcopy__(self, memo):
from copy import deepcopy
if self._pending_removals:
self._commit_removals()
new = self.__class__()
with _IterationGuard(self):
for key, wr in self.data.items():
o = wr()
if o is not None:
new[deepcopy(key, memo)] = o
for key, wr in self.data.copy().items():
o = wr()
if o is not None:
new[deepcopy(key, memo)] = o
return new

def get(self, key, default=None):
if self._pending_removals:
self._commit_removals()
try:
wr = self.data[key]
except KeyError:
Expand All @@ -207,21 +171,15 @@ def get(self, key, default=None):
return o

def items(self):
if self._pending_removals:
self._commit_removals()
with _IterationGuard(self):
for k, wr in self.data.items():
v = wr()
if v is not None:
yield k, v
for k, wr in self.data.copy().items():
v = wr()
if v is not None:
yield k, v

def keys(self):
if self._pending_removals:
self._commit_removals()
with _IterationGuard(self):
for k, wr in self.data.items():
if wr() is not None:
yield k
for k, wr in self.data.copy().items():
if wr() is not None:
yield k

__iter__ = keys

Expand All @@ -235,32 +193,22 @@ def itervaluerefs(self):
keep the values around longer than needed.
"""
if self._pending_removals:
self._commit_removals()
with _IterationGuard(self):
yield from self.data.values()
yield from self.data.copy().values()

def values(self):
if self._pending_removals:
self._commit_removals()
with _IterationGuard(self):
for wr in self.data.values():
obj = wr()
if obj is not None:
yield obj
for wr in self.data.copy().values():
obj = wr()
if obj is not None:
yield obj

def popitem(self):
if self._pending_removals:
self._commit_removals()
while True:
key, wr = self.data.popitem()
o = wr()
if o is not None:
return key, o

def pop(self, key, *args):
if self._pending_removals:
self._commit_removals()
try:
o = self.data.pop(key)()
except KeyError:
Expand All @@ -279,16 +227,12 @@ def setdefault(self, key, default=None):
except KeyError:
o = None
if o is None:
if self._pending_removals:
self._commit_removals()
self.data[key] = KeyedRef(default, self._remove, key)
return default
else:
return o

def update(self, other=None, /, **kwargs):
if self._pending_removals:
self._commit_removals()
d = self.data
if other is not None:
if not hasattr(other, "items"):
Expand All @@ -308,9 +252,7 @@ def valuerefs(self):
keep the values around longer than needed.
"""
if self._pending_removals:
self._commit_removals()
return list(self.data.values())
return list(self.data.copy().values())

def __ior__(self, other):
self.update(other)
Expand Down Expand Up @@ -369,57 +311,22 @@ def __init__(self, dict=None):
def remove(k, selfref=ref(self)):
self = selfref()
if self is not None:
if self._iterating:
self._pending_removals.append(k)
else:
try:
del self.data[k]
except KeyError:
pass
try:
del self.data[k]
except KeyError:
pass
self._remove = remove
# A list of dead weakrefs (keys to be removed)
self._pending_removals = []
self._iterating = set()
self._dirty_len = False
if dict is not None:
self.update(dict)

def _commit_removals(self):
# NOTE: We don't need to call this method before mutating the dict,
# because a dead weakref never compares equal to a live weakref,
# even if they happened to refer to equal objects.
# However, it means keys may already have been removed.
pop = self._pending_removals.pop
d = self.data
while True:
try:
key = pop()
except IndexError:
return

try:
del d[key]
except KeyError:
pass

def _scrub_removals(self):
d = self.data
self._pending_removals = [k for k in self._pending_removals if k in d]
self._dirty_len = False

def __delitem__(self, key):
self._dirty_len = True
del self.data[ref(key)]

def __getitem__(self, key):
return self.data[ref(key)]

def __len__(self):
if self._dirty_len and self._pending_removals:
# self._pending_removals may still contain keys which were
# explicitly removed, we have to scrub them (see issue #21173).
self._scrub_removals()
return len(self.data) - len(self._pending_removals)
return len(self.data)

def __repr__(self):
return "<%s at %#x>" % (self.__class__.__name__, id(self))
Expand All @@ -429,23 +336,21 @@ def __setitem__(self, key, value):

def copy(self):
new = WeakKeyDictionary()
with _IterationGuard(self):
for key, value in self.data.items():
o = key()
if o is not None:
new[o] = value
for key, value in self.data.copy().items():
o = key()
if o is not None:
new[o] = value
return new

__copy__ = copy

def __deepcopy__(self, memo):
from copy import deepcopy
new = self.__class__()
with _IterationGuard(self):
for key, value in self.data.items():
o = key()
if o is not None:
new[o] = deepcopy(value, memo)
for key, value in self.data.copy().items():
o = key()
if o is not None:
new[o] = deepcopy(value, memo)
return new

def get(self, key, default=None):
Expand All @@ -459,26 +364,23 @@ def __contains__(self, key):
return wr in self.data

def items(self):
with _IterationGuard(self):
for wr, value in self.data.items():
key = wr()
if key is not None:
yield key, value
for wr, value in self.data.copy().items():
key = wr()
if key is not None:
yield key, value

def keys(self):
with _IterationGuard(self):
for wr in self.data:
obj = wr()
if obj is not None:
yield obj
for wr in self.data.copy():
obj = wr()
if obj is not None:
yield obj

__iter__ = keys

def values(self):
with _IterationGuard(self):
for wr, value in self.data.items():
if wr() is not None:
yield value
for wr, value in self.data.copy().items():
if wr() is not None:
yield value

def keyrefs(self):
"""Return a list of weak references to the keys.
Expand All @@ -493,15 +395,13 @@ def keyrefs(self):
return list(self.data)

def popitem(self):
self._dirty_len = True
while True:
key, value = self.data.popitem()
o = key()
if o is not None:
return o, value

def pop(self, key, *args):
self._dirty_len = True
return self.data.pop(ref(key), *args)

def setdefault(self, key, default=None):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Make :class:`~weakref.WeakKeyDictionary` and :class:`~weakref.WeakValueDictionary` safe against concurrent mutations from other threads. Patch by Kumar Aditya.

0 comments on commit cd0f9d1

Please sign in to comment.