Skip to content

Commit

Permalink
address more comments
Browse files (browse the repository at this point in the history)
  • Loading branch information
lithomas1 committed Jun 13, 2024
1 parent 699efd3 commit e242182
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 13 deletions.
10 changes: 8 additions & 2 deletions python/cudf/cudf/_lib/pylibcudf/io/types.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,8 @@ cdef class SinkInfo:
if isinstance(sinks[0], io.StringIO):
data_sinks.reserve(len(sinks))
for s in sinks:
if not isinstance(s, io.StringIO):
raise ValueError("All sinks must be of the same type!")
self.sink_storage.push_back(
unique_ptr[data_sink](new iobase_data_sink(s))
)
Expand All @@ -219,8 +221,10 @@ cdef class SinkInfo:
}:
raise NotImplementedError(f"Unsupported encoding {s.encoding}")
sink = move(unique_ptr[data_sink](new iobase_data_sink(s.buffer)))
else:
elif isinstance(s, io.BytesIO):
sink = move(unique_ptr[data_sink](new iobase_data_sink(s)))
else:
raise ValueError("All sinks must be of the same type!")

self.sink_storage.push_back(
move(sink)
Expand All @@ -230,7 +234,9 @@ cdef class SinkInfo:
elif isinstance(sinks[0], (basestring, os.PathLike)):
paths.reserve(len(sinks))
for s in sinks:
if not isinstance(s, (basestring, os.PathLike)):
raise ValueError("All sinks must be of the same type!")
paths.push_back(<string> os.path.expanduser(s).encode())
self.c_obj = sink_info(move(paths))
else:
raise TypeError("Unrecognized input type: {}".format(type(sinks)))
raise TypeError("Unrecognized input type: {}".format(type(sinks[0])))
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,24 @@
import cudf._lib.pylibcudf as plc


@pytest.fixture(params=[plc.io.SourceInfo, plc.io.SinkInfo])
def io_class(request):
    """Yield each pylibcudf I/O info class in turn.

    The fixture is parametrized over ``plc.io.SourceInfo`` and
    ``plc.io.SinkInfo``, so every test that requests ``io_class`` runs once
    per class and receives the class object itself (not an instance).
    """
    # ``request.param`` is the current entry of the ``params`` list above.
    return request.param


@pytest.mark.parametrize(
"source", ["a.txt", b"hello world", io.BytesIO(b"hello world")]
)
def test_source_info_ctor(source, tmp_path):
def test_source_info_ctor(io_class, source, tmp_path):
if isinstance(source, str):
file = tmp_path / source
file.write_bytes("hello world".encode("utf-8"))
source = str(file)

plc.io.SourceInfo([source])
if io_class is plc.io.SinkInfo and isinstance(source, bytes):
pytest.skip("bytes is not a valid input for SinkInfo")

# TODO: test contents of source_info buffer is correct
# once buffers are exposed on python side
io_class([source])


@pytest.mark.parametrize(
Expand All @@ -30,18 +35,17 @@ def test_source_info_ctor(source, tmp_path):
[io.BytesIO(b"hello world"), io.BytesIO(b"hello there")],
],
)
def test_source_info_ctor_multiple(sources, tmp_path):
def test_source_info_ctor_multiple(io_class, sources, tmp_path):
for i in range(len(sources)):
source = sources[i]
if isinstance(source, str):
file = tmp_path / source
file.write_bytes("hello world".encode("utf-8"))
sources[i] = str(file)
elif io_class is plc.io.SinkInfo and isinstance(source, bytes):
pytest.skip("bytes is not a valid input for SinkInfo")

plc.io.SourceInfo(sources)

# TODO: test contents of source_info buffer is correct
# once buffers are exposed on python side
io_class(sources)


@pytest.mark.parametrize(
Expand All @@ -56,7 +60,7 @@ def test_source_info_ctor_multiple(sources, tmp_path):
],
],
)
def test_source_info_ctor_mixing_invalid(sources, tmp_path):
def test_source_info_ctor_mixing_invalid(io_class, sources, tmp_path):
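# NOTE(review): the original comment here read "Unlike the previous test
# don't create files so that they are missing", but the visible loop below
# still writes each str source to tmp_path, just like the previous test.
# Presumably the expected ValueError comes from mixing source types, not
# from missing files -- confirm and reword or drop this comment.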
for i in range(len(sources)):
Expand All @@ -65,5 +69,7 @@ def test_source_info_ctor_mixing_invalid(sources, tmp_path):
file = tmp_path / source
file.write_bytes("hello world".encode("utf-8"))
sources[i] = str(file)
elif io_class is plc.io.SinkInfo and isinstance(source, bytes):
pytest.skip("bytes is not a valid input for SinkInfo")
with pytest.raises(ValueError):
plc.io.SourceInfo(sources)
io_class(sources)

0 comments on commit e242182

Please sign in to comment.