diff --git a/fiftyone/service/ipc.py b/fiftyone/service/ipc.py index 5de3a92f57..049030ba1a 100644 --- a/fiftyone/service/ipc.py +++ b/fiftyone/service/ipc.py @@ -109,7 +109,7 @@ def send_request(port, message): Returns: response (any type) """ - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - s.connect(("localhost", port)) - pickle.dump(message, s.makefile("wb")) - return pickle.load(s.makefile("rb")) + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.connect(("localhost", port)) + pickle.dump(message, s.makefile("wb")) + return pickle.load(s.makefile("rb")) diff --git a/requirements/test.txt b/requirements/test.txt index 9173566d9c..c2ee8be8fe 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -1,6 +1,6 @@ open3d>=0.16.0 itsdangerous==2.0.1 -werkzeug==3.0.3 +werkzeug>=2.0.3 pydicom<3 pytest==7.3.1 pytest-cov==4.0.0 diff --git a/tests/ipc_tests.py b/tests/ipc_tests.py index f94c5352e2..86f856b184 100644 --- a/tests/ipc_tests.py +++ b/tests/ipc_tests.py @@ -13,6 +13,9 @@ import sys import threading import time +import unittest +from io import BytesIO +from unittest.mock import MagicMock import psutil import pytest @@ -120,6 +123,31 @@ def test_run_in_background(): assert requests == [2, 3] +@unittest.mock.patch("socket.socket") +def test_socket_closes_on_exception(mock_socket): + mock_socket_instance = MagicMock() + mock_socket.return_value = mock_socket_instance + mock_wb = BytesIO() + mock_socket_instance.makefile.side_effect = [mock_wb] + + # Test + with unittest.mock.patch( + "pickle.dump", side_effect=Exception("Test exception") + ): + try: + send_request(12345, "test message") + except Exception as e: + assert str(e) == "Test exception" + + # Ensure that the context manager enters and exits + mock_socket_instance.__enter__.assert_called_once() + mock_socket_instance.__exit__.assert_called_once_with( + Exception, # The exception type + unittest.mock.ANY, # The exception instance + unittest.mock.ANY, # The traceback object + ) + + def test_find_processes_by_args(): assert current_process in list( find_processes_by_args(current_process.cmdline())