Skip to content

Commit

Permalink
ARROW-8314: [Python] Add a Table.select method to select a subset of …
Browse files Browse the repository at this point in the history
…columns

This is a pure python implementation. It might be we want that on the C++ side (unless it already exists?), but having it available in Python is already useful IMO.

Closes apache#7272 from jorisvandenbossche/ARROW-8314-table-select

Authored-by: Joris Van den Bossche <jorisvandenbossche@gmail.com>
Signed-off-by: Wes McKinney <wesm@apache.org>
  • Loading branch information
jorisvandenbossche authored and wesm committed Jul 14, 2020
1 parent e771b94 commit 1413963
Show file tree
Hide file tree
Showing 8 changed files with 144 additions and 19 deletions.
20 changes: 20 additions & 0 deletions cpp/src/arrow/table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,26 @@ Result<std::shared_ptr<Table>> Table::RenameColumns(
return Table::Make(::arrow::schema(std::move(fields)), std::move(columns), num_rows());
}

Result<std::shared_ptr<Table>> Table::SelectColumns(
const std::vector<int>& indices) const {
int n = static_cast<int>(indices.size());

std::vector<std::shared_ptr<ChunkedArray>> columns(n);
std::vector<std::shared_ptr<Field>> fields(n);
for (int i = 0; i < n; i++) {
int pos = indices[i];
if (pos < 0 || pos > num_columns() - 1) {
return Status::Invalid("Invalid column index ", pos, " to select columns.");
}
columns[i] = column(pos);
fields[i] = field(pos);
}

auto new_schema =
std::make_shared<arrow::Schema>(std::move(fields), schema()->metadata());
return Table::Make(new_schema, std::move(columns), num_rows());
}

std::string Table::ToString() const {
std::stringstream ss;
ARROW_CHECK_OK(PrettyPrint(*this, 0, &ss));
Expand Down
3 changes: 3 additions & 0 deletions cpp/src/arrow/table.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,9 @@ class ARROW_EXPORT Table {
Result<std::shared_ptr<Table>> RenameColumns(
const std::vector<std::string>& names) const;

/// \brief Return new table with specified columns
Result<std::shared_ptr<Table>> SelectColumns(const std::vector<int>& indices) const;

/// \brief Replace schema key-value metadata with new metadata (EXPERIMENTAL)
/// \since 0.5.0
///
Expand Down
16 changes: 16 additions & 0 deletions cpp/src/arrow/table_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,22 @@ TEST_F(TestTable, RenameColumns) {
ASSERT_RAISES(Invalid, table->RenameColumns({"hello", "world"}));
}

TEST_F(TestTable, SelectColumns) {
MakeExample1(10);
auto table = Table::Make(schema_, columns_);

ASSERT_OK_AND_ASSIGN(auto subset, table->SelectColumns({0, 2}));
ASSERT_OK(subset->ValidateFull());

auto expexted_schema = ::arrow::schema({schema_->field(0), schema_->field(2)});
auto expected = Table::Make(expexted_schema, {table->column(0), table->column(2)});
ASSERT_TRUE(subset->Equals(*expected));

// Out of bounds indices
ASSERT_RAISES(Invalid, table->SelectColumns({0, 3}));
ASSERT_RAISES(Invalid, table->SelectColumns({-1}));
}

TEST_F(TestTable, RemoveColumnEmpty) {
// ARROW-1865
const int64_t length = 10;
Expand Down
5 changes: 1 addition & 4 deletions python/pyarrow/feather.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,4 @@ def read_table(source, columns=None, memory_map=True):
return table
else:
# follow exact order / selection of names
new_fields = [table.schema.field(c) for c in columns]
new_schema = schema(new_fields, metadata=table.schema.metadata)
new_columns = [table.column(c) for c in columns]
return Table.from_arrays(new_columns, schema=new_schema)
return table.select(columns)
1 change: 1 addition & 0 deletions python/pyarrow/includes/libarrow.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -766,6 +766,7 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil:

vector[c_string] ColumnNames()
CResult[shared_ptr[CTable]] RenameColumns(const vector[c_string]&)
CResult[shared_ptr[CTable]] SelectColumns(const vector[int]&)

CResult[shared_ptr[CTable]] Flatten(CMemoryPool* pool)

Expand Down
63 changes: 50 additions & 13 deletions python/pyarrow/table.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -1144,6 +1144,37 @@ cdef class Table(_PandasConvertible):
"""
return _pc().take(self, indices)

def select(self, object columns):
"""
Select columns of the Table.
Returns a new Table with the specified columns, and metadata
preserved.
Parameters
----------
columns : list-like
The column names or integer indices to select.
Returns
-------
Table
"""
cdef:
shared_ptr[CTable] c_table
vector[int] c_indices

for idx in columns:
idx = self._ensure_integer_index(idx)
idx = _normalize_index(idx, self.num_columns)
c_indices.push_back(<int> idx)

with nogil:
c_table = GetResultValue(self.table.SelectColumns(move(c_indices)))

return pyarrow_wrap_table(c_table)

def replace_schema_metadata(self, metadata=None):
"""
EXPERIMENTAL: Create shallow copy of table by replacing schema
Expand Down Expand Up @@ -1583,18 +1614,9 @@ cdef class Table(_PandasConvertible):
"""
return self.schema.field(i)

def column(self, i):
def _ensure_integer_index(self, i):
"""
Select a column by its column name, or numeric index.
Parameters
----------
i : int or string
The index or name of the column to retrieve.
Returns
-------
pyarrow.ChunkedArray
Ensure integer index (convert string column name to integer if needed).
"""
if isinstance(i, (bytes, str)):
field_indices = self.schema.get_all_field_indices(i)
Expand All @@ -1606,12 +1628,27 @@ cdef class Table(_PandasConvertible):
raise KeyError("Field \"{}\" exists {} times in table schema"
.format(i, len(field_indices)))
else:
return self._column(field_indices[0])
return field_indices[0]
elif isinstance(i, int):
return self._column(i)
return i
else:
raise TypeError("Index must either be string or integer")

def column(self, i):
"""
Select a column by its column name, or numeric index.
Parameters
----------
i : int or string
The index or name of the column to retrieve.
Returns
-------
pyarrow.ChunkedArray
"""
return self._column(self._ensure_integer_index(i))

def _column(self, int i):
"""
Select a column by its numeric index.
Expand Down
4 changes: 2 additions & 2 deletions python/pyarrow/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,8 +758,8 @@ def assert_yields_projected(fragment, row_slice,
column_names = columns if columns else table.column_names
assert actual.column_names == column_names

expected = table.slice(*row_slice).to_pandas()[[*column_names]]
assert actual.equals(pa.Table.from_pandas(expected))
expected = table.slice(*row_slice).select(column_names)
assert actual.equals(expected)

fragment = list(dataset.get_fragments())[0]
parquet_format = fragment.format
Expand Down
51 changes: 51 additions & 0 deletions python/pyarrow/tests/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1413,3 +1413,54 @@ def test_table_take_non_consecutive():
['f1', 'f2'])

assert table.take(pa.array([1, 3])).equals(result_non_consecutive)


def test_table_select():
a1 = pa.array([1, 2, 3, None, 5])
a2 = pa.array(['a', 'b', 'c', 'd', 'e'])
a3 = pa.array([[1, 2], [3, 4], [5, 6], None, [9, 10]])
table = pa.table([a1, a2, a3], ['f1', 'f2', 'f3'])

# selecting with string names
result = table.select(['f1'])
expected = pa.table([a1], ['f1'])
assert result.equals(expected)

result = table.select(['f3', 'f2'])
expected = pa.table([a3, a2], ['f3', 'f2'])
assert result.equals(expected)

# selecting with integer indices
result = table.select([0])
expected = pa.table([a1], ['f1'])
assert result.equals(expected)

result = table.select([2, 1])
expected = pa.table([a3, a2], ['f3', 'f2'])
assert result.equals(expected)

# preserve metadata
table2 = table.replace_schema_metadata({"a": "test"})
result = table2.select(["f1", "f2"])
assert b"a" in result.schema.metadata

# selecting non-existing column raises
with pytest.raises(KeyError, match='Field "f5" does not exist'):
table.select(['f5'])

with pytest.raises(IndexError, match="index out of bounds"):
table.select([5])

# duplicate selection gives duplicated names in resulting table
result = table.select(['f2', 'f2'])
expected = pa.table([a2, a2], ['f2', 'f2'])
assert result.equals(expected)

# selection duplicated column raises
table = pa.table([a1, a2, a3], ['f1', 'f2', 'f1'])
with pytest.raises(KeyError, match='Field "f1" exists 2 times'):
table.select(['f1'])

result = table.select(['f2'])
expected = pa.table([a2], ['f2'])
assert result.equals(expected)

0 comments on commit 1413963

Please sign in to comment.