Skip to content

Commit

Permalink
async context
Browse files Browse the repository at this point in the history
  • Loading branch information
lsbardel committed Dec 1, 2017
1 parent b0bce14 commit b466aa2
Show file tree
Hide file tree
Showing 3 changed files with 174 additions and 0 deletions.
5 changes: 5 additions & 0 deletions pulsar/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Unsupported, UnprocessableEntity
)
from .utils.config import Config, Setting
from .utils.context import TaskContext
from .utils.lib import (
HAS_C_EXTENSIONS, EventHandler, Event, ProtocolConsumer, Protocol,
Producer, AbortEvent, isawaitable
Expand All @@ -27,6 +28,9 @@
from .apps.data import data_stores


context = TaskContext()


__all__ = [
#
# Protocols and Config
Expand All @@ -53,6 +57,7 @@
'async_while',
'isawaitable',
'AsyncObject',
'context',
#
# Actor Layer
'get_actor',
Expand Down
117 changes: 117 additions & 0 deletions pulsar/utils/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from asyncio import Task, get_event_loop
from contextlib import contextmanager


class ContextStack(list):
pass


class TaskContext:
_previous_task_factory = None

def __call__(self, loop, coro):
current = Task.current_task(loop=loop)
task = Task(coro, loop=loop)
try:
task._context = current._context.copy()
except AttributeError:
pass
try:
task._context_stack = current._context_stack.copy()
except AttributeError:
pass
return task

def setup(self):
loop = get_event_loop()
self._previous_task_factory = loop.get_task_factory()
loop.set_task_factory(self)

def remove(self):
loop = get_event_loop()
loop.set_task_factory(self._previous_task_factory)

@contextmanager
def begin(self, *args, **kwargs):
for key, value in mapping(*args, **kwargs):
self.stack_push(key, value)
try:
yield self
finally:
for key, _ in mapping(*args, **kwargs):
self.stack_pop(key)

def set(self, key, value):
"""Set a value in the task context
"""
task = Task.current_task()
try:
context = task._context
except AttributeError:
task._context = context = {}
context[key] = value

def get(self, key):
task = Task.current_task()
try:
context = task._context
except AttributeError:
return
return context.get(key)

def pop(self, key):
context = Task.current_task()._context
value = context.pop(key)
if isinstance(value, ContextStack):
stack_value = value.pop()
if value:
context[key] = value
return stack_value
return value

def stack_push(self, key, value):
"""Set a value in a task context stack
"""
task = Task.current_task()
try:
context = task._context_stack
except AttributeError:
task._context_stack = context = {}
if key not in context:
context[key] = []
context[key].append(value)

def stack_get(self, key):
"""Set a value in a task context stack
"""
task = Task.current_task()
try:
context = task._context_stack
except AttributeError:
task._context_stack = context = {}
if key in context:
return context[key][-1]

def stack_pop(self, key):
"""Remove a value in a task context stack
"""
task = Task.current_task()
try:
context = task._context_stack
except AttributeError:
task._context_stack = context = {}
value = context[key]
stack_value = value.pop()
if not value:
context.pop(key)
return stack_value


def mapping(*args, **kwargs):
if args:
if len(args) > 1:
raise TypeError('expected at most 1 arguments, got %d' % len(args))
for key, value in args[0].items():
yield key, value
for key, value in kwargs.items():
yield key, value
52 changes: 52 additions & 0 deletions tests/utils/test_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""Tests task context
"""
import unittest
import asyncio

from pulsar.api import context


class TestContext(unittest.TestCase):

@classmethod
def setUpClass(cls):
context.setup()

@classmethod
def tearDownClass(cls):
context.remove()

def test_task_factory(self):
self.assertEqual(asyncio.get_event_loop().get_task_factory(), context)

async def test_set_get_pop(self):
context.set('foo', 5)
self.assertEqual(context.get('foo'), 5)
self.assertEqual(context.pop('foo'), 5)
self.assertEqual(context.get('foo'), None)

async def test_set_get_pop_nested(self):
context.set('foo', 5)
self.assertEqual(context.get('foo'), 5)
await asyncio.get_event_loop().create_task(self.nested())
self.assertEqual(context.get('foo'), 5)
self.assertEqual(context.get('bla'), None)

async def test_stack(self):
with context.begin(text='ciao', planet='mars'):
self.assertEqual(context.stack_get('text'), 'ciao')
self.assertEqual(context.stack_get('planet'), 'mars')
self.assertEqual(context.stack_get('text'), None)
self.assertEqual(context.stack_get('planet'), None)

def test_typeerror(self):
with self.assertRaises(TypeError):
with context.begin(1, 2):
pass

async def nested(self):
context.set('bla', 7)
self.assertEqual(context.get('bla'), 7)
self.assertEqual(context.get('foo'), 5)
context.set('foo', 8)
self.assertEqual(context.get('foo'), 8)

0 comments on commit b466aa2

Please sign in to comment.