Skip to content

Commit

Permalink
AMDA: Few small refac
Browse files Browse the repository at this point in the history
Signed-off-by: Alexis Jeandet <alexis.jeandet@member.fsf.org>
  • Loading branch information
jeandet committed Aug 13, 2021
1 parent 79749fc commit 2067e3c
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 19 deletions.
9 changes: 6 additions & 3 deletions speasy/common/catalog.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from .datetime_range import DateTimeRange
from datetime import datetime
from typing import List
from ..common import listify
from ..common import listify, all_of_type


def _all_are_events(event_list):
return all(map(lambda e: type(e) == Event, event_list))
return all_of_type(event_list, Event)


class Event(DateTimeRange):
Expand All @@ -17,7 +17,7 @@ def __init__(self, start_time: datetime, stop_time: datetime, meta=None):

def __eq__(self, other):
return (self.meta == other.meta) and super().__eq__(other)

def __repr__(self):
return f"<Event: {self.start_time.isoformat()} -> {self.stop_time.isoformat()} | {self.meta}>"

Expand Down Expand Up @@ -51,3 +51,6 @@ def __iadd__(self, other: Event or List[Event]):

def pop(self, index=-1):
return self._events.pop(index)

def __repr__(self):
return f"""<Catalog: {self.name}>"""
7 changes: 5 additions & 2 deletions speasy/common/timetable.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from .datetime_range import DateTimeRange
from typing import List
from ..common import listify
from ..common import listify, all_of_type


def _all_are_datetime_ranges(dt_list):
return all(map(lambda e: type(e) == DateTimeRange, dt_list))
return all_of_type(dt_list, DateTimeRange)


class TimeTable:
Expand Down Expand Up @@ -36,3 +36,6 @@ def __iadd__(self, other: DateTimeRange or List[DateTimeRange]):

def pop(self, index=-1):
return self._storage.pop(index)

def __repr__(self):
return f"""<TimeTable: {self.name}>"""
25 changes: 11 additions & 14 deletions speasy/common/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,23 @@ class SpeasyVariable(object):
:type meta: dict
:param columns: column names
:type columns: list[str]
:param y:
:type y:
:param y:
:type y:
"""
__slots__ = ['meta', 'time', 'values', 'columns', 'y']

def __init__(self, time=np.empty(0), data=np.empty((0, 1)), meta=None, columns=None, y=None):
def __init__(self, time=np.empty(0), data=np.empty((0, 1)), meta: Optional[dict] = None,
columns: Optional[list[str]] = None, y: Optional[np.ndarray] = None):
"""Constructor
"""
if meta is None:
meta = dict()
if columns is None:
columns = []
self.meta = meta
self.meta = meta or {}
self.columns = columns or []
if len(data.shape) == 1:
self.values = data.reshape((data.shape[0], 1)) # to be consistent with pandas
else:
self.values = data
self.time = time
self.columns = columns
self.y = y

def view(self, time_range):
Expand Down Expand Up @@ -71,7 +68,7 @@ def __getitem__(self, key):
"""Item getter
:param key: key
:type key: slice
:type key: slice
:return: data slice
:rtype: speasy.common.variable.SpeasyVariable
"""
Expand Down Expand Up @@ -135,15 +132,15 @@ def from_dataframe(df: pds.DataFrame) -> 'SpeasyVariable':
time = np.array([d.timestamp() for d in df.index])
else:
time = df.index.values
return SpeasyVariable(time, df.values, {}, [c for c in df.columns])
return SpeasyVariable(time=time, data=df.values, meta={}, columns=list(df.columns))


def from_dataframe(df: pds.DataFrame) -> SpeasyVariable:
"""Convert a dataframe to SpeasyVariable.
:param df: input dataframe
:type df: pandas.DataFrame
:return: speasy variable
:return: speasy variable
:rtype: speasy.common.variable.SpeasyVariable
"""
return SpeasyVariable.from_dataframe(df)
Expand Down Expand Up @@ -199,8 +196,8 @@ def merge(variables: List[SpeasyVariable]) -> Optional[SpeasyVariable]:
data = np.zeros((dest_len, sorted_var_list[0].values.shape[1])) if len(
sorted_var_list[0].values.shape) == 2 else np.zeros(dest_len)

units = [var.values.unit for var in sorted_var_list if hasattr(var.values, 'unit')]
if len(units):
units = set([var.values.unit for var in sorted_var_list if hasattr(var.values, 'unit')])
if len(units) == 1:
data *= units[0]

pos = 0
Expand Down

0 comments on commit 2067e3c

Please sign in to comment.