forked from quantmind/pulsar
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
174 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
from asyncio import Task, get_event_loop | ||
from contextlib import contextmanager | ||
|
||
|
||
class ContextStack(list): | ||
pass | ||
|
||
|
||
class TaskContext: | ||
_previous_task_factory = None | ||
|
||
def __call__(self, loop, coro): | ||
current = Task.current_task(loop=loop) | ||
task = Task(coro, loop=loop) | ||
try: | ||
task._context = current._context.copy() | ||
except AttributeError: | ||
pass | ||
try: | ||
task._context_stack = current._context_stack.copy() | ||
except AttributeError: | ||
pass | ||
return task | ||
|
||
def setup(self): | ||
loop = get_event_loop() | ||
self._previous_task_factory = loop.get_task_factory() | ||
loop.set_task_factory(self) | ||
|
||
def remove(self): | ||
loop = get_event_loop() | ||
loop.set_task_factory(self._previous_task_factory) | ||
|
||
@contextmanager | ||
def begin(self, *args, **kwargs): | ||
for key, value in mapping(*args, **kwargs): | ||
self.stack_push(key, value) | ||
try: | ||
yield self | ||
finally: | ||
for key, _ in mapping(*args, **kwargs): | ||
self.stack_pop(key) | ||
|
||
def set(self, key, value): | ||
"""Set a value in the task context | ||
""" | ||
task = Task.current_task() | ||
try: | ||
context = task._context | ||
except AttributeError: | ||
task._context = context = {} | ||
context[key] = value | ||
|
||
def get(self, key): | ||
task = Task.current_task() | ||
try: | ||
context = task._context | ||
except AttributeError: | ||
return | ||
return context.get(key) | ||
|
||
def pop(self, key): | ||
context = Task.current_task()._context | ||
value = context.pop(key) | ||
if isinstance(value, ContextStack): | ||
stack_value = value.pop() | ||
if value: | ||
context[key] = value | ||
return stack_value | ||
return value | ||
|
||
def stack_push(self, key, value): | ||
"""Set a value in a task context stack | ||
""" | ||
task = Task.current_task() | ||
try: | ||
context = task._context_stack | ||
except AttributeError: | ||
task._context_stack = context = {} | ||
if key not in context: | ||
context[key] = [] | ||
context[key].append(value) | ||
|
||
def stack_get(self, key): | ||
"""Set a value in a task context stack | ||
""" | ||
task = Task.current_task() | ||
try: | ||
context = task._context_stack | ||
except AttributeError: | ||
task._context_stack = context = {} | ||
if key in context: | ||
return context[key][-1] | ||
|
||
def stack_pop(self, key): | ||
"""Remove a value in a task context stack | ||
""" | ||
task = Task.current_task() | ||
try: | ||
context = task._context_stack | ||
except AttributeError: | ||
task._context_stack = context = {} | ||
value = context[key] | ||
stack_value = value.pop() | ||
if not value: | ||
context.pop(key) | ||
return stack_value | ||
|
||
|
||
def mapping(*args, **kwargs): | ||
if args: | ||
if len(args) > 1: | ||
raise TypeError('expected at most 1 arguments, got %d' % len(args)) | ||
for key, value in args[0].items(): | ||
yield key, value | ||
for key, value in kwargs.items(): | ||
yield key, value |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
"""Tests task context | ||
""" | ||
import unittest | ||
import asyncio | ||
|
||
from pulsar.api import context | ||
|
||
|
||
class TestContext(unittest.TestCase): | ||
|
||
@classmethod | ||
def setUpClass(cls): | ||
context.setup() | ||
|
||
@classmethod | ||
def tearDownClass(cls): | ||
context.remove() | ||
|
||
def test_task_factory(self): | ||
self.assertEqual(asyncio.get_event_loop().get_task_factory(), context) | ||
|
||
async def test_set_get_pop(self): | ||
context.set('foo', 5) | ||
self.assertEqual(context.get('foo'), 5) | ||
self.assertEqual(context.pop('foo'), 5) | ||
self.assertEqual(context.get('foo'), None) | ||
|
||
async def test_set_get_pop_nested(self): | ||
context.set('foo', 5) | ||
self.assertEqual(context.get('foo'), 5) | ||
await asyncio.get_event_loop().create_task(self.nested()) | ||
self.assertEqual(context.get('foo'), 5) | ||
self.assertEqual(context.get('bla'), None) | ||
|
||
async def test_stack(self): | ||
with context.begin(text='ciao', planet='mars'): | ||
self.assertEqual(context.stack_get('text'), 'ciao') | ||
self.assertEqual(context.stack_get('planet'), 'mars') | ||
self.assertEqual(context.stack_get('text'), None) | ||
self.assertEqual(context.stack_get('planet'), None) | ||
|
||
def test_typeerror(self): | ||
with self.assertRaises(TypeError): | ||
with context.begin(1, 2): | ||
pass | ||
|
||
async def nested(self): | ||
context.set('bla', 7) | ||
self.assertEqual(context.get('bla'), 7) | ||
self.assertEqual(context.get('foo'), 5) | ||
context.set('foo', 8) | ||
self.assertEqual(context.get('foo'), 8) |