forked from karlicoss/cachew
-
Notifications
You must be signed in to change notification settings - Fork 0
/
marshall.py
258 lines (208 loc) · 9.12 KB
/
marshall.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
import shutil
import sqlite3
import sys
from typing import (
Any,
List,
Literal,
Union,
)
import orjson
import pytest
import pytz
from ..marshall.common import Json
from ..marshall.cachew import CachewMarshall
from ..legacy import NTBinder
from .utils import (
gc_control,
profile,
running_on_ci,
timer,
)
# Names of the (de)serialization backends that do_test knows how to benchmark.
Impl = Literal[
    'cachew',  # our custom deserialization
    'cattrs',
    'legacy',  # our legacy deserialization
]
# Backends the tests below are parametrized over.
# don't include legacy by default, it's only here just for the sake of comparing once before switch
Impls: List[Impl] = ['cachew', 'cattrs']
def do_test(*, test_name: str, Type, factory, count: int, impl: Impl = 'cachew') -> None:
    """
    Benchmark a full serialization round trip for ``count`` objects of ``Type``.

    The stages are: build objects -> serialize to json-ish values -> orjson dump
    -> write/read via sqlite -> write/read via a jsonl file -> orjson load
    -> deserialize. Each stage runs inside the ``profile`` and ``timer`` context
    managers (imported from ``.utils``), labelled with ``test_name`` and the stage name.
    At the end the first/last 100 round-tripped objects are compared with the originals.

    :param test_name: prefix used for the per-stage ``profile`` labels
    :param Type: the type handed to the selected marshalling backend
    :param factory: callable invoked as ``factory(count=count)``; must return an iterable
        yielding (at least) ``count`` objects of ``Type``
    :param count: number of objects to push through the pipeline; runs with more than 100
        objects are skipped when ``running_on_ci`` is truthy
    :param impl: which backend to benchmark, see ``Impl``
    :raises ValueError: if ``impl`` is not one of the known backends
    """
    # stdlib helpers, imported locally so the module-level import block stays untouched
    import tempfile
    from contextlib import closing

    if count > 100 and running_on_ci:
        pytest.skip("test too heavy for CI, only meant to run manually")

    to_json: Any
    from_json: Any
    if impl == 'cachew':
        marshall = CachewMarshall(Type_=Type)
        to_json = marshall.dump
        from_json = marshall.load
    elif impl == 'legacy':
        # NOTE: legacy binder emits a tuple which can be inserted directly into the database
        # so 'json dump' and 'json load' should really be disregarded for this flavor
        # if you're comparing with <other> implementation, you should compare
        # legacy serializing as the sum of <other> serializing + <other> json dump
        # that said, this way legacy will have a bit of an advantage since custom types (e.g. datetime)
        # would normally be handled by sqlalchemy instead
        binder = NTBinder.make(Type)
        to_json = binder.to_row
        from_json = binder.from_row
    elif impl == 'cattrs':
        from cattrs import Converter
        from cattrs.strategies import configure_tagged_union

        converter = Converter()

        # the imports below are only needed by the (currently disabled) Union handling
        from typing import get_args, get_origin
        from typing import Union
        import types

        # TODO use later
        # def is_union(type_) -> bool:
        #     origin = get_origin(type_)
        #     return origin is Union or origin is types.UnionType

        # NOTE: not registered anywhere yet, see the commented out calls below
        def union_structure_hook_factory(_):
            def union_hook(data, type_):
                args = get_args(type_)

                if data is None:  # we don't try to coerce None into anything
                    return None

                # try every member of the union in order, first one that structures wins
                for t in args:
                    try:
                        return converter.structure(data, t)
                    except Exception:
                        continue
                raise ValueError(f"Could not cast {data} to {type_}")

            return union_hook

        # borrowed from https://github.com/python-attrs/cattrs/issues/423
        # uhh, this doesn't really work straightaway...
        # likely need to combine what cattr does with configure_tagged_union
        # converter.register_structure_hook_factory(is_union, union_structure_hook_factory)
        # configure_tagged_union(
        #     union=Type,
        #     converter=converter,
        # )

        # NOTE: this seems to give a bit of speedup... maybe raise an issue or something?
        unstruct_func = converter._unstructure_func.dispatch(Type)  # about 20% speedup
        struct_func = converter._structure_func.dispatch(Type)  # TODO speedup

        to_json = unstruct_func
        # todo would be nice to use partial? but how do we bind a positional arg?
        from_json = lambda x: struct_func(x, Type)
    else:
        # explicit raise rather than 'assert False': asserts are stripped under 'python -O'
        raise ValueError(f'unexpected impl: {impl!r}')

    print('', file=sys.stderr)  # kinda annoying, pytest starts printing on the same line as test name

    # NOTE: the timed loops below intentionally keep their presized-list/index shape,
    # so the numbers stay comparable with previously recorded benchmark results

    with profile(test_name + ':baseline'), timer(f'building {count} objects of type {Type}'):
        objects = list(factory(count=count))

    jsons: List[Json] = [None for _ in range(count)]
    with profile(test_name + ':serialize'), timer(f'serializing {count} objects of type {Type}'):
        for i in range(count):
            jsons[i] = to_json(objects[i])

    strs: List[bytes] = [None for _ in range(count)]  # type: ignore
    with profile(test_name + ':json_dump'), timer(f'json dump {count} objects of type {Type}'):
        for i in range(count):
            # TODO any orjson options to speed up?
            strs[i] = orjson.dumps(jsons[i])  # pylint: disable=no-member

    # Fresh scratch directory per call: portable, doesn't clash with concurrent runs,
    # and the (potentially huge) db/jsonl files are removed once the I/O stages are done.
    # The removal happens outside of the profile/timer blocks, so it isn't measured.
    with tempfile.TemporaryDirectory(prefix='cachew_test_') as tmp_dir:
        db = Path(tmp_dir) / 'db.sqlite'

        with profile(test_name + ':sqlite_dump'), timer(f'sqlite dump {count} objects of type {Type}'):
            # the sqlite3 connection context manager only commits/rolls back,
            # closing() makes sure the connection is released even on errors
            with closing(sqlite3.connect(db)) as conn, conn:
                conn.execute('CREATE TABLE data (value BLOB)')
                conn.executemany('INSERT INTO data (value) VALUES (?)', [(s,) for s in strs])

        strs2: List[bytes] = [None for _ in range(count)]  # type: ignore
        with profile(test_name + ':sqlite_load'), timer(f'sqlite load {count} objects of type {Type}'):
            with closing(sqlite3.connect(db)) as conn, conn:
                i = 0
                for (value,) in conn.execute('SELECT value FROM data'):
                    strs2[i] = value
                    i += 1

        cache = db.parent / 'cache.jsonl'

        with profile(test_name + ':jsonl_dump'), timer(f'jsonl dump {count} objects of type {Type}'):
            with cache.open('wb') as fw:
                for s in strs:
                    fw.write(s + b'\n')

        strs3: List[bytes] = [None for _ in range(count)]  # type: ignore
        with profile(test_name + ':jsonl_load'), timer(f'jsonl load {count} objects of type {Type}'):
            i = 0
            with cache.open('rb') as fr:
                for line in fr:
                    strs3[i] = line.rstrip(b'\n')
                    i += 1

    # both storage backends must hand back the same bytes; only the head/tail are
    # compared to keep the check cheap for multi-million element runs
    assert strs2[:100] + strs2[-100:] == strs3[:100] + strs3[-100:]  # just in case

    jsons2: List[Json] = [None for _ in range(count)]
    with profile(test_name + ':json_load'), timer(f'json load {count} objects of type {Type}'):
        for i in range(count):
            # TODO any orjson options to speed up?
            jsons2[i] = orjson.loads(strs2[i])  # pylint: disable=no-member

    objects2 = [None for _ in range(count)]
    with profile(test_name + ':deserialize'), timer(f'deserializing {count} objects of type {Type}'):
        for i in range(count):
            objects2[i] = from_json(jsons2[i])

    # round trip sanity check, again limited to the head/tail of the data
    assert objects[:100] + objects[-100:] == objects2[:100] + objects2[-100:]
@dataclass
class Name:
    """Two string fields; used as the non-``str`` member of ``Union[str, Name]`` in test_union_str_dataclass."""

    first: str
    last: str
@pytest.mark.parametrize('impl', Impls)
@pytest.mark.parametrize('count', [99, 1_000_000, 5_000_000])
@pytest.mark.parametrize('gc_on', [True, False], ids=['gc_on', 'gc_off'])
def test_union_str_dataclass(impl: Impl, count: int, gc_control, request) -> None:
    """
    Round trip benchmark for ``Union[str, Name]``: even indices produce plain strings,
    odd indices produce ``Name`` instances.

    ``gc_control`` is a fixture imported from ``.utils``
    (presumably it switches garbage collection according to ``gc_on`` -- see its definition).
    """
    # NOTE: this used to be union_str_namedtuple, it was switched to a dataclass so it works with cattrs for now
    # the perf difference between dataclass/namedtuple seems negligible here, so old benchmark results should apply
    if impl == 'cattrs':
        pytest.skip('TODO need to adjust the handling of Union types..')

    union_type = Union[str, Name]

    def factory(count: int):
        items: List[Union[str, Name]] = [
            Name(first=f'first {n}', last=f'last {n}') if n % 2 else str(n)
            for n in range(count)
        ]
        return items

    do_test(
        test_name=request.node.name,
        Type=union_type,
        factory=factory,
        count=count,
        impl=impl,
    )
# OK, performance with calling this manually (not via pytest) is the same
# do_test_union_str_dataclass(count=1_000_000, test_name='adhoc')
@pytest.mark.parametrize('impl', Impls)
@pytest.mark.parametrize('count', [99, 1_000_000, 5_000_000])
@pytest.mark.parametrize('gc_on', [True, False], ids=['gc_on', 'gc_off'])
def test_datetimes(impl: Impl, count: int, gc_control, request) -> None:
    """
    Round trip benchmark for timezone-aware ``datetime`` objects, evenly spread
    between 1990-01-01 and 2030-01-01 and cycling through three timezones.
    """
    if impl == 'cattrs':
        pytest.skip('TODO support datetime with pytz for cattrs')

    def factory(*, count: int):
        zones = (
            pytz.timezone('Europe/Berlin'),
            timezone.utc,
            pytz.timezone('America/New_York'),
        )
        n_zones = len(zones)

        first = datetime.fromisoformat('1990-01-01T00:00:00')
        last = datetime.fromisoformat('2030-01-01T00:00:00')
        delta = (last - first) / count

        for idx in range(count):
            # NOTE(review): pytz documents tz.localize() as the way to attach its zones;
            # replace(tzinfo=...) likely picks the zone's initial (LMT) offset.
            # Kept as is, since the benchmark only relies on round trip equality -- confirm if offsets matter
            yield (first + delta * idx).replace(tzinfo=zones[idx % n_zones])

    do_test(
        test_name=request.node.name,
        Type=datetime,
        factory=factory,
        count=count,
        impl=impl,
    )
@pytest.mark.parametrize('impl', Impls)
@pytest.mark.parametrize('count', [99, 1_000_000])
@pytest.mark.parametrize('gc_on', [True, False], ids=['gc_on', 'gc_off'])
def test_nested_dataclass(impl: Impl, count: int, gc_control, request) -> None:
    """
    Round trip benchmark for a dataclass (``TE2``) that embeds another dataclass (``UUU``)
    between two plain int fields.
    """
    # NOTE: was previously named test_many_from_cachew

    @dataclass
    class UUU:
        xx: int
        yy: int

    @dataclass
    class TE2:
        value: int
        uuu: UUU
        value2: int

    def factory(*, count: int):
        # every field of the n-th object is derived from n
        yield from (
            TE2(value=n, uuu=UUU(xx=n, yy=n), value2=n)
            for n in range(count)
        )

    do_test(
        test_name=request.node.name,
        Type=TE2,
        factory=factory,
        count=count,
        impl=impl,
    )
# TODO next test should probably be a RuntimeError?