Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Add a workaround for CloudWatch GetLogEvents empty results #1652

Merged
merged 1 commit into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 37 additions & 16 deletions src/dstack/_internal/server/services/logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,25 +81,10 @@ def poll_logs(self, project: ProjectModel, request: PollLogsRequest) -> JobSubmi
stream = self._get_stream_name(
project.name, request.run_name, request.job_submission_id, log_producer
)
parameters = {
"logGroupName": self._group,
"logStreamName": stream,
"limit": request.limit,
"startFromHead": (not request.descending),
}
if request.start_time:
# XXX: Since callers use start_time/end_time for pagination, one millisecond is added
# to avoid an infinite loop because startTime boundary is inclusive.
parameters["startTime"] = _datetime_to_unix_time_ms(request.start_time) + 1
if request.end_time:
# No need to substract one millisecond in this case, though, seems that endTime is
# exclusive, that is, time interval boundaries are [startTime, entTime)
parameters["endTime"] = _datetime_to_unix_time_ms(request.end_time)
cw_events: List[_CloudWatchLogEvent]
with self._wrap_boto_errors():
try:
response = self._client.get_log_events(**parameters)
cw_events = response["events"]
cw_events = self._get_log_events(stream, request)
except botocore.exceptions.ClientError as e:
if not self._is_resource_not_found_exception(e):
raise
Expand All @@ -122,6 +107,42 @@ def poll_logs(self, project: ProjectModel, request: PollLogsRequest) -> JobSubmi
]
return JobSubmissionLogs(logs=logs)

def _get_log_events(self, stream: str, request: PollLogsRequest) -> List[_CloudWatchLogEvent]:
    """Fetch one page of CloudWatch log events for the given stream.

    Builds the GetLogEvents parameters from the request and, when reading
    backwards (``request.descending``), retries through empty pages to work
    around https://github.com/boto/boto3/issues/3718.
    """
    read_forwards = not request.descending
    params = {
        "logGroupName": self._group,
        "logStreamName": stream,
        "limit": request.limit,
        "startFromHead": read_forwards,
    }
    if request.start_time:
        # Callers paginate via start_time/end_time; the startTime boundary is
        # inclusive, so shift by one millisecond to avoid an infinite loop of
        # re-reading the same event.
        params["startTime"] = _datetime_to_unix_time_ms(request.start_time) + 1
    if request.end_time:
        # endTime appears to be exclusive — the interval is [startTime, endTime) —
        # so no one-millisecond adjustment is needed on this side.
        params["endTime"] = _datetime_to_unix_time_ms(request.end_time)
    response = self._client.get_log_events(**params)
    events: List[_CloudWatchLogEvent] = response["events"]
    if read_forwards or events:
        return events
    # Workaround for https://github.com/boto/boto3/issues/3718
    # Required only when startFromHead = false (the default value).
    token: str = response["nextBackwardToken"]
    # Cap the retries so a misbehaving API cannot loop us forever.
    for _ in range(10):
        params["nextToken"] = token
        response = self._client.get_log_events(**params)
        events = response["events"]
        new_token = response["nextBackwardToken"]
        if events or new_token == token:
            # Either we got data, or an unchanged token means end of stream.
            return events
        token = new_token
    logger.warning("too many empty responses from stream %s, returning dummy response", stream)
    return []

def write_logs(
self,
project: ProjectModel,
Expand Down
164 changes: 149 additions & 15 deletions src/tests/_internal/server/services/test_logs.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import itertools
from datetime import datetime, timezone
from pathlib import Path
from unittest.mock import Mock, call
Expand Down Expand Up @@ -61,7 +62,11 @@ async def project(self, test_db, session: AsyncSession) -> ProjectModel:
def mock_client(self, monkeypatch: pytest.MonkeyPatch) -> Mock:
    """Return a mocked boto3 CloudWatch Logs client patched into boto3.Session.

    The default get_log_events response is empty but carries both pagination
    tokens, matching the shape of a real API response.
    """
    mock = Mock()
    monkeypatch.setattr("boto3.Session.client", Mock(return_value=mock))
    mock.get_log_events.return_value = {
        "events": [],
        "nextBackwardToken": "bwd",
        # Fixed key typo: the real API field is "nextForwardToken",
        # not "nextFormartToken".
        "nextForwardToken": "fwd",
    }
    return mock

@pytest.fixture
Expand Down Expand Up @@ -160,19 +165,17 @@ def test_ensure_stream_exists_cached_forced(
)

@pytest.mark.asyncio
async def test_poll_logs_response(
async def test_poll_logs_non_empty_response(
self,
project: ProjectModel,
log_storage: CloudWatchLogStorage,
mock_client: Mock,
poll_logs_request: PollLogsRequest,
):
mock_client.get_log_events.return_value = {
"events": [
{"timestamp": 1696586513234, "message": "SGVsbG8="},
{"timestamp": 1696586513235, "message": "V29ybGQ="},
]
}
mock_client.get_log_events.return_value["events"] = [
{"timestamp": 1696586513234, "message": "SGVsbG8="},
{"timestamp": 1696586513235, "message": "V29ybGQ="},
]
job_submission_logs = log_storage.poll_logs(project, poll_logs_request)

assert job_submission_logs.logs == [
Expand All @@ -189,19 +192,33 @@ async def test_poll_logs_response(
]

@pytest.mark.asyncio
async def test_poll_logs_response_descending(
async def test_poll_logs_empty_response(
self,
project: ProjectModel,
log_storage: CloudWatchLogStorage,
mock_client: Mock,
poll_logs_request: PollLogsRequest,
):
mock_client.get_log_events.return_value = {
"events": [
{"timestamp": 1696586513234, "message": "SGVsbG8="},
{"timestamp": 1696586513235, "message": "V29ybGQ="},
]
}
# Check that we don't use the workaround when descending=False -> startFromHead=True
# https://github.com/dstackai/dstack/issues/1647
mock_client.get_log_events.return_value["events"] = []
job_submission_logs = log_storage.poll_logs(project, poll_logs_request)

assert job_submission_logs.logs == []
mock_client.get_log_events.assert_called_once()

@pytest.mark.asyncio
async def test_poll_logs_descending_non_empty_response_on_first_call(
self,
project: ProjectModel,
log_storage: CloudWatchLogStorage,
mock_client: Mock,
poll_logs_request: PollLogsRequest,
):
mock_client.get_log_events.return_value["events"] = [
{"timestamp": 1696586513234, "message": "SGVsbG8="},
{"timestamp": 1696586513235, "message": "V29ybGQ="},
]
poll_logs_request.descending = True
job_submission_logs = log_storage.poll_logs(project, poll_logs_request)

Expand All @@ -218,6 +235,118 @@ async def test_poll_logs_response_descending(
),
]

@pytest.mark.asyncio
async def test_poll_logs_descending_two_first_calls_return_empty_response(
    self,
    project: ProjectModel,
    log_storage: CloudWatchLogStorage,
    mock_client: Mock,
    poll_logs_request: PollLogsRequest,
):
    # The first two pages are empty but carry *different* backward tokens,
    # i.e. more events may still be available; the third page has the data.
    # https://github.com/dstackai/dstack/issues/1647
    empty_pages = [
        {"events": [], "nextBackwardToken": f"bwd{i}", "nextForwardToken": "fwd"}
        for i in (1, 2)
    ]
    data_page = {
        "events": [
            {"timestamp": 1696586513234, "message": "SGVsbG8="},
            {"timestamp": 1696586513235, "message": "V29ybGQ="},
        ],
        "nextBackwardToken": "bwd3",
        "nextForwardToken": "fwd",
    }
    mock_client.get_log_events.side_effect = [*empty_pages, data_page]
    poll_logs_request.descending = True
    job_submission_logs = log_storage.poll_logs(project, poll_logs_request)

    expected_logs = [
        LogEvent(
            timestamp=datetime(2023, 10, 6, 10, 1, 53, 235000, tzinfo=timezone.utc),
            log_source=LogEventSource.STDOUT,
            message="V29ybGQ=",
        ),
        LogEvent(
            timestamp=datetime(2023, 10, 6, 10, 1, 53, 234000, tzinfo=timezone.utc),
            log_source=LogEventSource.STDOUT,
            message="SGVsbG8=",
        ),
    ]
    assert job_submission_logs.logs == expected_logs
    assert mock_client.get_log_events.call_count == 3

@pytest.mark.asyncio
async def test_poll_logs_descending_empty_response_with_same_token(
    self,
    project: ProjectModel,
    log_storage: CloudWatchLogStorage,
    mock_client: Mock,
    poll_logs_request: PollLogsRequest,
):
    # An empty page whose backward token equals the token we sent means the
    # end of the stream was reached — polling must stop after the second call
    # and never fetch the third (data-carrying) response.
    # https://github.com/dstackai/dstack/issues/1647
    end_of_stream_page = {
        "events": [],
        "nextBackwardToken": "bwd",
        "nextForwardToken": "fwd",
    }
    unreachable_page = {
        "events": [
            {"timestamp": 1696586513234, "message": "SGVsbG8="},
        ],
        "nextBackwardToken": "bwd2",
        "nextForwardToken": "fwd",
    }
    mock_client.get_log_events.side_effect = [
        end_of_stream_page,
        dict(end_of_stream_page),
        unreachable_page,
    ]
    poll_logs_request.descending = True
    job_submission_logs = log_storage.poll_logs(project, poll_logs_request)

    assert job_submission_logs.logs == []
    assert mock_client.get_log_events.call_count == 2

@pytest.mark.asyncio
async def test_poll_logs_descending_empty_response_max_tries(
    self,
    project: ProjectModel,
    log_storage: CloudWatchLogStorage,
    mock_client: Mock,
    poll_logs_request: PollLogsRequest,
):
    # Circuit-breaker test: every response is empty yet carries a fresh
    # backward token, so token-based pagination alone would never terminate.
    # https://github.com/dstackai/dstack/issues/1647
    def _endless_empty_pages():
        for i in itertools.count():
            yield {
                "events": [],
                "nextBackwardToken": f"bwd{i}",
                "nextForwardToken": "fwd",
            }

    mock_client.get_log_events.side_effect = _endless_empty_pages()
    poll_logs_request.descending = True
    job_submission_logs = log_storage.poll_logs(project, poll_logs_request)

    assert job_submission_logs.logs == []
    assert mock_client.get_log_events.call_count == 11  # initial call + 10 tries

@pytest.mark.asyncio
async def test_poll_logs_request_params_asc_no_diag_no_dates(
self,
Expand Down Expand Up @@ -245,6 +374,11 @@ async def test_poll_logs_request_params_desc_diag_with_dates(
mock_client: Mock,
poll_logs_request: PollLogsRequest,
):
# Ensure the first response has events to avoid triggering a workaround for
# https://github.com/dstackai/dstack/issues/1647
mock_client.get_log_events.return_value["events"] = [
{"timestamp": 1696586513234, "message": "SGVsbG8="}
]
poll_logs_request.start_time = datetime(
2023, 10, 6, 10, 1, 53, 234000, tzinfo=timezone.utc
)
Expand Down