Skip to content

Commit

Permalink
♻️ Standardize types and improve consistency
Browse files Browse the repository at this point in the history
  • Loading branch information
odiseo0 committed Jun 19, 2022
1 parent e0ef5b1 commit dd71d42
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 26 deletions.
26 changes: 13 additions & 13 deletions realtime/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,18 @@

import asyncio
import json
from collections import namedtuple
from typing import List, TYPE_CHECKING
from typing import Any, List, Dict, TYPE_CHECKING, NamedTuple

from realtime.types import Callback

if TYPE_CHECKING:
from realtime.connection import Socket

"""
Callback Listener is a tuple with `event` and `callback`
"""
CallbackListener = namedtuple("CallbackListener", "event callback")

class CallbackListener(NamedTuple):
"""A tuple with `event` and `callback` """
event: str
callback: Callback


class Channel:
Expand All @@ -22,17 +24,17 @@ class Channel:
Topic-Channel has a 1-many relationship.
"""

def __init__(self, socket: Socket, topic: str, params: dict = {}) -> None:
def __init__(self, socket: Socket, topic: str, params: Dict[str, Any] = {}) -> None:
"""
:param socket: Socket object
:param topic: Topic that it subscribes to on the realtime server
:param params:
"""
self.socket = socket
self.topic: str = topic
self.params: dict = params
self.params = params
self.topic = topic
self.listeners: List[CallbackListener] = []
self.joined: bool = False
self.joined = False

def join(self) -> Channel:
"""
Expand All @@ -54,18 +56,16 @@ async def _join(self) -> None:

try:
await self.socket.ws_connection.send(json.dumps(join_req))

except Exception as e:
print(str(e)) # TODO: better error propagation
return

def on(self, event: str, callback) -> Channel:
def on(self, event: str, callback: Callback) -> Channel:
"""
:param event: A specific event will have a specific callback
:param callback: Callback that takes msg payload as its first argument
:return: Channel
"""

cl = CallbackListener(event=event, callback=callback)
self.listeners.append(cl)
return self
Expand Down
23 changes: 13 additions & 10 deletions realtime/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,22 @@
import logging
from collections import defaultdict
from functools import wraps
from typing import Any, Callable
from typing import Any, Callable, List, Dict, cast

import websockets

from realtime.channel import Channel
from realtime.exceptions import NotConnectedError
from realtime.message import HEARTBEAT_PAYLOAD, PHOENIX_CHANNEL, ChannelEvents, Message
from realtime.types import T_ParamSpec, T_Retval

logging.basicConfig(
format="%(asctime)s:%(levelname)s - %(message)s", level=logging.INFO)


def ensure_connection(func: Callable):
def ensure_connection(func: Callable[T_ParamSpec, T_Retval]):
@wraps(func)
def wrapper(*args: Any, **kwargs: Any):
def wrapper(*args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs) -> T_Retval:
if not args[0].connected:
raise NotConnectedError(func.__name__)

Expand All @@ -27,7 +28,7 @@ def wrapper(*args: Any, **kwargs: Any):


class Socket:
def __init__(self, url: str, params: dict = {}, hb_interval: int = 5) -> None:
def __init__(self, url: str, params: Dict[str, Any] = {}, hb_interval: int = 5) -> None:
"""
`Socket` is the abstraction for an actual socket connection that receives and 'reroutes' `Message` according to its `topic` and `event`.
Socket-Channel has a 1-many relationship.
Expand All @@ -39,10 +40,12 @@ def __init__(self, url: str, params: dict = {}, hb_interval: int = 5) -> None:
self.url = url
self.channels = defaultdict(list)
self.connected = False
self.params: dict = params
self.hb_interval: int = hb_interval
self.params = params
self.hb_interval = hb_interval
self.ws_connection: websockets.client.WebSocketClientProtocol
self.kept_alive: bool = False
self.kept_alive = False

self.channels = cast(defaultdict[str, List[Channel]], self.channels)

@ensure_connection
def listen(self) -> None:
Expand All @@ -64,13 +67,14 @@ async def _listen(self) -> None:
try:
msg = await self.ws_connection.recv()
msg = Message(**json.loads(msg))

if msg.event == ChannelEvents.reply:
continue

for channel in self.channels.get(msg.topic, []):
for cl in channel.listeners:
if cl.event == msg.event:
cl.callback(msg.payload)

except websockets.exceptions.ConnectionClosed:
logging.exception("Connection closed")
break
Expand All @@ -84,8 +88,8 @@ def connect(self) -> None:
self.connected = True

async def _connect(self) -> None:

ws_connection = await websockets.connect(self.url)

if ws_connection.open:
logging.info("Connection was successful")
self.ws_connection = ws_connection
Expand Down Expand Up @@ -119,7 +123,6 @@ def set_channel(self, topic: str) -> Channel:
:param topic: Initializes a channel and creates a two-way association with the socket
:return: Channel
"""

chan = Channel(self, topic, self.params)
self.channels[topic].append(chan)

Expand Down
5 changes: 2 additions & 3 deletions realtime/message.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
from dataclasses import dataclass
from enum import Enum
from typing import Any
from typing import Any, Dict, Any


@dataclass
class Message:
"""
Dataclass abstraction for message
"""

event: str
payload: dict
payload: Dict[str, Any]
ref: Any
topic: str

Expand Down
16 changes: 16 additions & 0 deletions realtime/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import sys
from typing import Callable, TypeVar

if sys.version_info > (3, 9):
from typing import ParamSpec
else:
from typing_extensions import ParamSpec


# Generic types
T = TypeVar("T")
T_Retval = TypeVar("T_Retval")
T_ParamSpec = ParamSpec("T_ParamSpec")

# Custom types
Callback = Callable[T_ParamSpec, T_Retval]

0 comments on commit dd71d42

Please sign in to comment.