Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
Signed-off-by: Levko Kravets <levko.ne@gmail.com>
  • Loading branch information
kravets-levko committed Jul 2, 2024
1 parent cdebaf4 commit 9e204f7
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 262 deletions.
27 changes: 8 additions & 19 deletions tests/unit/test_cloud_fetch_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,23 +49,25 @@ def test_initializer_adds_links(self, mock_create_next_table):
result_links = self.create_result_links(10)
queue = utils.CloudFetchQueue(schema_bytes, result_links=result_links, max_download_threads=10)

assert len(queue.download_manager.download_handlers) == 10
assert len(queue.download_manager._pending_links) == 10
assert len(queue.download_manager._download_tasks) == 0
mock_create_next_table.assert_called()

def test_initializer_no_links_to_add(self):
schema_bytes = MagicMock()
result_links = []
queue = utils.CloudFetchQueue(schema_bytes, result_links=result_links, max_download_threads=10)

assert len(queue.download_manager.download_handlers) == 0
assert len(queue.download_manager._pending_links) == 0
assert len(queue.download_manager._download_tasks) == 0
assert queue.table is None

@patch("databricks.sql.cloudfetch.download_manager.ResultFileDownloadManager.get_next_downloaded_file", return_value=None)
def test_create_next_table_no_download(self, mock_get_next_downloaded_file):
queue = utils.CloudFetchQueue(MagicMock(), result_links=[], max_download_threads=10)

assert queue._create_next_table() is None
assert mock_get_next_downloaded_file.called_with(0)
mock_get_next_downloaded_file.assert_called_with(0)

@patch("databricks.sql.utils.create_arrow_table_from_arrow_file")
@patch("databricks.sql.cloudfetch.download_manager.ResultFileDownloadManager.get_next_downloaded_file",
Expand All @@ -76,8 +78,8 @@ def test_initializer_create_next_table_success(self, mock_get_next_downloaded_fi
queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10)
expected_result = self.make_arrow_table()

assert mock_create_arrow_table.called_with(b"1234567890", True, schema_bytes, description)
assert mock_get_next_downloaded_file.called_with(0)
mock_get_next_downloaded_file.assert_called_with(0)
mock_create_arrow_table.assert_called_with(b"1234567890", description)
assert queue.table == expected_result
assert queue.table.num_rows == 4
assert queue.table_row_index == 0
Expand Down Expand Up @@ -130,20 +132,6 @@ def test_next_n_rows_more_than_one_table(self, mock_create_next_table):
assert queue.table_row_index == 3
assert result == pyarrow.concat_tables([self.make_arrow_table(), self.make_arrow_table()])[:7]

@patch("databricks.sql.utils.CloudFetchQueue._create_next_table")
def test_next_n_rows_more_than_one_table(self, mock_create_next_table):
mock_create_next_table.return_value = self.make_arrow_table()
schema_bytes, description = MagicMock(), MagicMock()
queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10)
assert queue.table == self.make_arrow_table()
assert queue.table.num_rows == 4
assert queue.table_row_index == 0

result = queue.next_n_rows(7)
assert result.num_rows == 7
assert queue.table_row_index == 3
assert result == pyarrow.concat_tables([self.make_arrow_table(), self.make_arrow_table()])[:7]

@patch("databricks.sql.utils.CloudFetchQueue._create_next_table")
def test_next_n_rows_only_one_table_returned(self, mock_create_next_table):
mock_create_next_table.side_effect = [self.make_arrow_table(), None]
Expand All @@ -165,6 +153,7 @@ def test_next_n_rows_empty_table(self, mock_create_next_table):
assert queue.table is None

result = queue.next_n_rows(100)
mock_create_next_table.assert_called()
assert result == pyarrow.ipc.open_stream(bytearray(schema_bytes)).read_all()

@patch("databricks.sql.utils.CloudFetchQueue._create_next_table")
Expand Down
178 changes: 14 additions & 164 deletions tests/unit/test_download_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from unittest.mock import patch, MagicMock

import databricks.sql.cloudfetch.download_manager as download_manager
import databricks.sql.cloudfetch.downloader as downloader
from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink


Expand All @@ -11,10 +10,8 @@ class DownloadManagerTests(unittest.TestCase):
Unit tests for checking download manager logic.
"""

def create_download_manager(self):
max_download_threads = 10
lz4_compressed = True
return download_manager.ResultFileDownloadManager(max_download_threads, lz4_compressed)
def create_download_manager(self, links, max_download_threads=10, lz4_compressed=True):
return download_manager.ResultFileDownloadManager(links, max_download_threads, lz4_compressed)

def create_result_link(
self,
Expand All @@ -36,172 +33,25 @@ def create_result_links(self, num_files: int, start_row_offset: int = 0):

def test_add_file_links_zero_row_count(self):
links = [self.create_result_link(row_count=0, bytes_num=0)]
manager = self.create_download_manager()
manager.add_file_links(links)
manager = self.create_download_manager(links)

assert not manager.download_handlers
assert len(manager._pending_links) == 0 # the only link supplied contains no data, so should be skipped
assert len(manager._download_tasks) == 0

def test_add_file_links_success(self):
links = self.create_result_links(num_files=10)
manager = self.create_download_manager()
manager.add_file_links(links)
manager = self.create_download_manager(links)

assert len(manager.download_handlers) == 10

def test_remove_past_handlers_one(self):
links = self.create_result_links(num_files=10)
manager = self.create_download_manager()
manager.add_file_links(links)

manager._remove_past_handlers(8000)
assert len(manager.download_handlers) == 9

def test_remove_past_handlers_all(self):
links = self.create_result_links(num_files=10)
manager = self.create_download_manager()
manager.add_file_links(links)

manager._remove_past_handlers(8000*10)
assert len(manager.download_handlers) == 0

@patch("concurrent.futures.ThreadPoolExecutor.submit")
def test_schedule_downloads_partial_already_scheduled(self, mock_submit):
links = self.create_result_links(num_files=10)
manager = self.create_download_manager()
manager.add_file_links(links)

for i in range(5):
manager.download_handlers[i].is_download_scheduled = True

manager._schedule_downloads()
assert mock_submit.call_count == 5
assert sum([1 if handler.is_download_scheduled else 0 for handler in manager.download_handlers]) == 10

@patch("concurrent.futures.ThreadPoolExecutor.submit")
def test_schedule_downloads_will_not_schedule_twice(self, mock_submit):
links = self.create_result_links(num_files=10)
manager = self.create_download_manager()
manager.add_file_links(links)

for i in range(5):
manager.download_handlers[i].is_download_scheduled = True

manager._schedule_downloads()
assert mock_submit.call_count == 5
assert sum([1 if handler.is_download_scheduled else 0 for handler in manager.download_handlers]) == 10

manager._schedule_downloads()
assert mock_submit.call_count == 5

@patch("concurrent.futures.ThreadPoolExecutor.submit", side_effect=[True, KeyError("foo")])
def test_schedule_downloads_submit_fails(self, mock_submit):
links = self.create_result_links(num_files=10)
manager = self.create_download_manager()
manager.add_file_links(links)

manager._schedule_downloads()
assert mock_submit.call_count == 2
assert sum([1 if handler.is_download_scheduled else 0 for handler in manager.download_handlers]) == 1

@patch("concurrent.futures.ThreadPoolExecutor.submit")
def test_find_next_file_index_all_scheduled_next_row_0(self, mock_submit):
links = self.create_result_links(num_files=10)
manager = self.create_download_manager()
manager.add_file_links(links)
manager._schedule_downloads()

assert manager._find_next_file_index(0) == 0
assert len(manager._pending_links) == len(links)
assert len(manager._download_tasks) == 0

@patch("concurrent.futures.ThreadPoolExecutor.submit")
def test_find_next_file_index_all_scheduled_next_row_7999(self, mock_submit):
def test_schedule_downloads(self, mock_submit):
max_download_threads = 4
links = self.create_result_links(num_files=10)
manager = self.create_download_manager()
manager.add_file_links(links)
manager._schedule_downloads()

assert manager._find_next_file_index(7999) is None

@patch("concurrent.futures.ThreadPoolExecutor.submit")
def test_find_next_file_index_all_scheduled_next_row_8000(self, mock_submit):
links = self.create_result_links(num_files=10)
manager = self.create_download_manager()
manager.add_file_links(links)
manager._schedule_downloads()

assert manager._find_next_file_index(8000) == 1

@patch("concurrent.futures.ThreadPoolExecutor.submit", side_effect=[True, KeyError("foo")])
def test_find_next_file_index_one_scheduled_next_row_8000(self, mock_submit):
links = self.create_result_links(num_files=10)
manager = self.create_download_manager()
manager.add_file_links(links)
manager._schedule_downloads()
manager = self.create_download_manager(links, max_download_threads=max_download_threads)

assert manager._find_next_file_index(8000) is None

@patch("databricks.sql.cloudfetch.downloader.ResultSetDownloadHandler.is_file_download_successful",
return_value=True)
@patch("concurrent.futures.ThreadPoolExecutor.submit")
def test_check_if_download_successful_happy(self, mock_submit, mock_is_file_download_successful):
links = self.create_result_links(num_files=10)
manager = self.create_download_manager()
manager.add_file_links(links)
manager._schedule_downloads()

status = manager._check_if_download_successful(manager.download_handlers[0])
assert status
assert manager.num_consecutive_result_file_download_retries == 0

@patch("databricks.sql.cloudfetch.downloader.ResultSetDownloadHandler.is_file_download_successful",
return_value=False)
def test_check_if_download_successful_link_expired(self, mock_is_file_download_successful):
manager = self.create_download_manager()
handler = downloader.ResultSetDownloadHandler(manager.downloadable_result_settings, self.create_result_link())
handler.is_link_expired = True

status = manager._check_if_download_successful(handler)
mock_is_file_download_successful.assert_called()
assert not status
assert manager.fetch_need_retry

@patch("databricks.sql.cloudfetch.downloader.ResultSetDownloadHandler.is_file_download_successful",
return_value=False)
def test_check_if_download_successful_download_timed_out_no_retries(self, mock_is_file_download_successful):
manager = self.create_download_manager()
handler = downloader.ResultSetDownloadHandler(manager.downloadable_result_settings, self.create_result_link())
handler.is_download_timedout = True

status = manager._check_if_download_successful(handler)
mock_is_file_download_successful.assert_called()
assert not status
assert manager.fetch_need_retry

@patch("concurrent.futures.ThreadPoolExecutor.submit")
@patch("databricks.sql.cloudfetch.downloader.ResultSetDownloadHandler.is_file_download_successful",
return_value=False)
def test_check_if_download_successful_download_timed_out_1_retry(self, mock_is_file_download_successful, mock_submit):
manager = self.create_download_manager()
manager.downloadable_result_settings = download_manager.DownloadableResultSettings(
is_lz4_compressed=True,
download_timeout=0,
max_consecutive_file_download_retries=1,
)
handler = downloader.ResultSetDownloadHandler(manager.downloadable_result_settings, self.create_result_link())
handler.is_download_timedout = True

status = manager._check_if_download_successful(handler)
assert mock_is_file_download_successful.call_count == 2
assert mock_submit.call_count == 1
assert not status
assert manager.fetch_need_retry

@patch("databricks.sql.cloudfetch.downloader.ResultSetDownloadHandler.is_file_download_successful",
return_value=False)
def test_check_if_download_successful_other_reason(self, mock_is_file_download_successful):
manager = self.create_download_manager()
handler = downloader.ResultSetDownloadHandler(manager.downloadable_result_settings, self.create_result_link())

status = manager._check_if_download_successful(handler)
mock_is_file_download_successful.assert_called()
assert not status
assert manager.fetch_need_retry
assert mock_submit.call_count == max_download_threads
assert len(manager._pending_links) == len(links) - max_download_threads
assert len(manager._download_tasks) == max_download_threads
Loading

0 comments on commit 9e204f7

Please sign in to comment.