forked from PrefectHQ/prefect
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_states.py
346 lines (280 loc) · 11.7 KB
/
test_states.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
import uuid
import pytest
from prefect import flow
from prefect.exceptions import CancelledRun, CrashedRun, FailedRun
from prefect.results import (
LiteralResult,
PersistedResult,
ResultFactory,
UnpersistedResult,
)
from prefect.states import (
Cancelled,
Completed,
Crashed,
Failed,
Pending,
Running,
State,
StateGroup,
is_state,
is_state_iterable,
raise_state_exception,
return_value_to_state,
)
from prefect.utilities.annotations import quote
def test_is_state():
    """The object returned by ``Completed()`` is recognized by ``is_state``."""
    completed = Completed()
    assert is_state(completed)
def test_is_not_state():
    """``is_state`` is falsy for ``None`` and for a plain string."""
    for non_state in (None, "test"):
        assert not is_state(non_state)
def test_is_state_requires_instance():
    """The ``Completed`` object itself, without being called, is not a state."""
    uncalled = Completed
    assert not is_state(uncalled)
@pytest.mark.parametrize("iterable_type", [set, list, tuple])
def test_is_state_iterable(iterable_type):
    """A set, list, or tuple built from two ``Completed()`` states qualifies."""
    members = [Completed(), Completed()]
    collection = iterable_type(members)
    assert is_state_iterable(collection)
def test_is_not_state_iterable_if_unsupported_iterable_type():
    """A dict whose keys are states is not treated as a state iterable."""
    keyed_by_state = {Completed(): index for index in range(3)}
    assert not is_state_iterable(keyed_by_state)
@pytest.mark.parametrize("iterable_type", [set, list, tuple])
def test_is_not_state_iterable_if_empty(iterable_type):
    """An empty set, list, or tuple is not a state iterable."""
    empty = iterable_type()
    assert not is_state_iterable(empty)
@pytest.mark.parametrize("state_cls", [Failed, Crashed, Cancelled])
class TestRaiseStateException:
    """Run every ``raise_state_exception`` check per parametrized ``state_cls``."""

    # Wrapper exception each test expects for a given ``state_cls`` when the state
    # data is not itself an exception (string result or ``None``).
    _WRAPPER_FOR_STATE = {
        Failed: FailedRun,
        Crashed: CrashedRun,
        Cancelled: CancelledRun,
    }

    def test_works_in_sync_context(self, state_cls):
        """Called without ``await``, the exception stored as data is raised."""
        with pytest.raises(ValueError, match="Test"):
            state = state_cls(data=ValueError("Test"))
            raise_state_exception(state)

    async def test_raises_state_exception(self, state_cls):
        """Awaited, the exception stored as state data is raised."""
        with pytest.raises(ValueError, match="Test"):
            state = state_cls(data=ValueError("Test"))
            await raise_state_exception(state)

    async def test_returns_without_error_for_completed_states(self, state_cls):
        """A ``Completed()`` state yields ``None`` instead of raising."""
        # ``state_cls`` is unused here; it is accepted because the class-level
        # parametrization supplies it to every test method.
        outcome = await raise_state_exception(Completed())
        assert outcome is None

    async def test_raises_nested_state_exception(self, state_cls):
        """An exception held by a nested ``Failed`` state is raised."""
        with pytest.raises(ValueError, match="Test"):
            inner = Failed(data=ValueError("Test"))
            await raise_state_exception(state_cls(data=inner))

    async def test_raises_value_error_if_nested_state_is_not_failed(self, state_cls):
        """A nested ``Completed`` state produces a descriptive ``ValueError``."""
        expected = "Expected failed or crashed state got Completed"
        with pytest.raises(ValueError, match=expected):
            inner = Completed(data="test")
            await raise_state_exception(state_cls(data=inner))

    async def test_raises_first_nested_multistate_exception(self, state_cls):
        """With a list of nested states, the first failed one's exception is raised."""
        # TODO: We may actually want to raise a "multi-error" here where we have several
        #       exceptions displayed at once
        nested = [
            Completed(data="test"),
            Failed(data=ValueError("Test")),
            Failed(data=ValueError("Should not be raised")),
        ]
        with pytest.raises(ValueError, match="Test"):
            outer = state_cls(data=nested)
            await raise_state_exception(outer)

    async def test_value_error_if_all_multistates_are_not_failed(self, state_cls):
        """A list of nested states with no failed member raises ``ValueError``."""
        nested = [
            Completed(),
            Completed(),
            Completed(data=ValueError("Should not be raised")),
        ]
        expected = "Failed state result was an iterable of states but none were failed"
        with pytest.raises(ValueError, match=expected):
            outer = state_cls(data=nested)
            await raise_state_exception(outer)

    @pytest.mark.parametrize("value", ["foo", LiteralResult(value="foo")])
    async def test_raises_wrapper_with_message_if_result_is_string(
        self, state_cls, value
    ):
        """String data, bare or in a ``LiteralResult``, raises the wrapper exception."""
        wrapper = self._WRAPPER_FOR_STATE[state_cls]
        with pytest.raises(wrapper):
            await raise_state_exception(state_cls(data=value))

    async def test_raises_base_exception(self, state_cls):
        """A ``BaseException`` stored as state data is raised."""
        with pytest.raises(BaseException):
            state = state_cls(data=BaseException("foo"))
            await raise_state_exception(state)

    async def test_raises_wrapper_with_state_message_if_result_is_null(self, state_cls):
        """``None`` data plus a state message raises the wrapper exception."""
        wrapper = self._WRAPPER_FOR_STATE[state_cls]
        with pytest.raises(wrapper):
            await raise_state_exception(state_cls(data=None, message="foo"))

    async def test_raises_error_if_failed_state_does_not_contain_exception(
        self, state_cls
    ):
        """Integer data cannot be turned into an exception and raises ``TypeError``."""
        expected = "int cannot be resolved into an exception"
        with pytest.raises(TypeError, match=expected):
            await raise_state_exception(state_cls(data=2))

    async def test_quoted_state_does_not_raise_state_exception(self, state_cls):
        """A flow returning a ``quote``-wrapped state hands the wrapper back."""

        @flow
        def test_flow():
            return quote(state_cls())

        returned = test_flow()
        assert isinstance(returned, quote)
        assert isinstance(returned.unquote(), State)
class TestReturnValueToState:
    """Checks of ``return_value_to_state`` using a default ``ResultFactory``."""

    @pytest.fixture
    async def factory(self, prefect_client):
        """Build the default result factory from the ``prefect_client`` fixture."""
        default = await ResultFactory.default_factory(client=prefect_client)
        return default

    async def test_returns_single_state_unaltered(self, factory):
        """A single state carrying data comes back as the very same object."""
        original = Completed(data="hello!")
        returned = await return_value_to_state(original, factory)
        assert returned is original

    async def test_returns_single_state_with_null_data(self, factory):
        """A state with ``None`` data is returned with an unpersisted result."""
        original = Completed(data=None)
        returned = await return_value_to_state(original, factory)

        assert returned is original
        assert isinstance(returned.data, UnpersistedResult)
        resolved = await returned.result()
        assert resolved is None

    async def test_returns_single_state_with_data_to_persist(self, factory):
        """With persistence enabled on the factory, the data becomes persisted."""
        factory.persist_result = True
        original = Completed(data=1)
        returned = await return_value_to_state(original, factory)

        assert returned is original
        assert isinstance(returned.data, PersistedResult)
        resolved = await returned.result()
        assert resolved == 1

    async def test_returns_single_state_unaltered_with_user_created_reference(
        self, factory
    ):
        """A state already holding a factory-created result is left as is."""
        user_result = await factory.create_result("test")
        original = Completed(data=user_result)
        returned = await return_value_to_state(original, factory)

        assert returned is original
        # Pydantic makes a copy of the result type during state so we cannot assert that
        # it is the original `result` object but we can assert there is not a copy in
        # `return_value_to_state`
        assert returned.data is original.data
        assert returned.data == user_result
        resolved = await returned.result()
        assert resolved == "test"

    async def test_all_completed_states(self, factory):
        """A list of completed states aggregates into a completed state."""
        members = [Completed(message="hi"), Completed(message="bye")]
        aggregate = await return_value_to_state(members, factory)

        # The member states are stored as the aggregate's data
        stored = await aggregate.result()
        assert stored == members
        # The message describes the aggregate
        assert aggregate.message == "All states completed."
        # The aggregate type is completed
        assert aggregate.is_completed()

    async def test_some_failed_states(self, factory):
        """A list containing failed states aggregates into a failed state."""
        members = [
            Completed(message="hi"),
            Failed(message="bye"),
            Failed(message="err"),
        ]
        aggregate = await return_value_to_state(members, factory)

        # The member states are stored as the aggregate's data
        stored = await aggregate.result(raise_on_failure=False)
        assert stored == members
        # The message describes the aggregate
        assert aggregate.message == "2/3 states failed."
        # The aggregate type is failed
        assert aggregate.is_failed()

    async def test_some_unfinal_states(self, factory):
        """A list containing non-final states aggregates into a failed state."""
        members = [
            Completed(message="hi"),
            Running(message="bye"),
            Pending(message="err"),
        ]
        aggregate = await return_value_to_state(members, factory)

        # The member states are stored as the aggregate's data
        stored = await aggregate.result(raise_on_failure=False)
        assert stored == members
        # The message describes the aggregate
        assert aggregate.message == "2/3 states are not final."
        # The aggregate type is failed
        assert aggregate.is_failed()

    @pytest.mark.parametrize("run_identifier", ["task_run_id", "flow_run_id"])
    async def test_single_state_in_future_is_processed(self, run_identifier, factory):
        """A state tagged with a run id is wrapped in a new completed state."""
        details = {run_identifier: uuid.uuid4()}
        tagged = Completed(data="test", state_details=details)

        # The engine is responsible for resolving the futures
        wrapped = await return_value_to_state(tagged, factory)

        stored = await wrapped.result()
        assert stored == tagged
        assert wrapped.is_completed()
        assert wrapped.message == "All states completed."

    async def test_non_prefect_types_return_completed_state(self, factory):
        """A plain string return value yields a completed state holding it."""
        wrapped = await return_value_to_state("foo", factory)
        assert wrapped.is_completed()
        stored = await wrapped.result()
        assert stored == "foo"
class TestStateGroup:
    """Checks of the counters and predicates exposed by ``StateGroup``."""

    def test_fail_count(self):
        """Three failed plus two crashed states give a ``fail_count`` of 5."""
        failed = [Failed(data=ValueError(label)) for label in ("1", "2", "3")]
        crashed = [Crashed(data=ValueError(label)) for label in ("4", "5")]
        group = StateGroup(failed + crashed)
        assert group.fail_count == 5

    def test_all_completed(self):
        """``all_completed`` holds only when every member is completed."""
        only_completed = [
            Completed(data="test"),
            Completed(data="test"),
            Completed(data="test"),
        ]
        assert StateGroup(only_completed).all_completed()

        with_failure = [
            Completed(data="test"),
            Failed(data=ValueError("1")),
        ]
        assert not StateGroup(with_failure).all_completed()

    def test_any_cancelled(self):
        """``any_cancelled`` holds when at least one member is cancelled."""
        with_cancelled = [
            Cancelled(),
            Failed(data=ValueError("1")),
        ]
        assert StateGroup(with_cancelled).any_cancelled()

        without_cancelled = [
            Completed(data="test"),
            Failed(data=ValueError("1")),
        ]
        assert not StateGroup(without_cancelled).any_cancelled()

    def test_any_failed(self):
        """``any_failed`` holds when at least one member is failed."""
        with_failure = [
            Completed(data="test"),
            Failed(data=ValueError("1")),
        ]
        assert StateGroup(with_failure).any_failed()

        only_completed = [
            Completed(data="test"),
            Completed(data="test"),
            Completed(data="test"),
        ]
        assert not StateGroup(only_completed).any_failed()

    def test_all_final(self):
        """``all_final`` stops holding once a ``Running`` state is included."""
        final_only = [
            Failed(data=ValueError("failed")),
            Crashed(data=ValueError("crashed")),
            Completed(data="complete"),
            Cancelled(data="cancelled"),
        ]
        assert StateGroup(final_only).all_final()

        with_running = [
            Failed(data=ValueError("failed")),
            Crashed(data=ValueError("crashed")),
            Completed(data="complete"),
            Cancelled(data="cancelled"),
            Running(),
        ]
        assert not StateGroup(with_running).all_final()

    def test_counts_message_all_final(self):
        """The counts message reports the total and one entry per state type."""
        group = StateGroup(
            [
                Failed(data=ValueError("failed")),
                Crashed(data=ValueError("crashed")),
                Completed(data="complete"),
                Cancelled(data="cancelled"),
            ]
        )
        message = group.counts_message()

        expected_fragments = (
            "total=4",
            "'FAILED'=1",
            "'CRASHED'=1",
            "'COMPLETED'=1",
            "'CANCELLED'=1",
        )
        for fragment in expected_fragments:
            assert fragment in message

    def test_counts_message_some_non_final(self):
        """The counts message also reports how many members are not final."""
        group = StateGroup(
            [
                Failed(data=ValueError("failed")),
                Running(),
                Crashed(data=ValueError("crashed")),
                Running(),
            ]
        )
        message = group.counts_message()

        expected_fragments = (
            "total=4",
            "not_final=2",
            "'FAILED'=1",
            "'CRASHED'=1",
            "'RUNNING'=2",
        )
        for fragment in expected_fragments:
            assert fragment in message