Skip to content

Commit

Permalink
Merge pull request #15 from sklam/features/to_from_pandas
Browse files Browse the repository at this point in the history
Support to_pandas and from_pandas
  • Loading branch information
seibert authored Jun 12, 2017
2 parents 415972c + 5dc25bf commit ac0ff8d
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 8 deletions.
1 change: 1 addition & 0 deletions conda_environments/testing_py35.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@ dependencies:
- cudatoolkit=8.0
- llvmlite>=0.18
- numpy=1.12.1
- pandas
- numba>=0.33
- libgdf_cffi>=0.1.0a1.dev
106 changes: 98 additions & 8 deletions pygdf/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,26 @@ class DataFrame(object):
3 3 0.887777513028
4 4 0.55838311246
[5 more rows]
Convert from a Pandas DataFrame.
>>> import pandas as pd
>>> from pygdf.dataframe import DataFrame
>>> pdf = pd.DataFrame({'a': [0, 1, 2, 3],
... 'b': [0.1, 0.2, None, 0.3]})
>>> pdf
a b
0 0 0.1
1 1 0.2
2 2 NaN
3 3 0.3
>>> df = DataFrame.from_pandas(pdf)
>>> df
a b
0 0 0.1
1 1 0.2
2 2 nan
3 3 0.3
"""

def __init__(self, name_series=None):
Expand Down Expand Up @@ -97,6 +117,7 @@ def to_string(self, nrows=5, ncols=8):
return ''
if nrows is None:
nrows = len(self)
nrows = min(nrows, len(self)) # cap row count
if ncols is None:
ncols = len(self.columns)
lastcol = None
Expand Down Expand Up @@ -284,7 +305,8 @@ def as_gpu_matrix(self, columns=None):

matrix = cuda.device_array(shape=(nrow, ncol), dtype=dtype, order="F")
for colidx, inpcol in enumerate(cols):
matrix[:, colidx].copy_to_device(inpcol.to_gpu_array())
dense = inpcol.to_gpu_array(fillna='pandas')
matrix[:, colidx].copy_to_device(dense)

return matrix

Expand Down Expand Up @@ -331,6 +353,33 @@ def one_hot_encoding(self, column, prefix, cats, prefix_sep='_',
outdf.add_column(name, col)
return outdf

def to_pandas(self):
"""Convert to a Pandas DataFrame.
"""
import pandas as pd

dct = {k: c.to_array(fillna='pandas') for k, c in self._cols.items()}
return pd.DataFrame.from_dict(dct)

@classmethod
def from_pandas(cls, dataframe):
"""Convert from a Pandas DataFrame.
Raises
------
TypeError for invalid input type.
"""
import pandas as pd

if not isinstance(dataframe, pd.DataFrame):
raise TypeError('not a pandas.DataFrame')

df = cls()

for colk in dataframe.columns:
df[colk] = dataframe[colk].values
return df


class Loc(object):
"""
Expand Down Expand Up @@ -674,16 +723,33 @@ def fillna(self, value):
value=value)
return self.from_array(out)

def to_dense_buffer(self):
def to_dense_buffer(self, fillna=None):
"""Get dense (no null values) ``Buffer`` of the data.
Parameters
----------
fillna : str or None
See *fillna* in ``.to_array``.
Notes
-----
Null values are skipped. Therefore, the output size could be smaller.
if ``fillna`` is ``None``, null values are skipped. Therefore, the
output size could be smaller.
"""
if fillna not in {None, 'pandas'}:
raise ValueError('invalid for fillna')

if self.has_null_mask:
return self._copy_to_dense_buffer()
if fillna == 'pandas':
# cast non-float types to float64
col = (self.astype(np.float64)
if self.dtype.kind != 'f'
else self)
# fill nan
return col.fillna(np.nan)
else:
return self._copy_to_dense_buffer()
else:
return self._data

Expand All @@ -693,15 +759,39 @@ def _copy_to_dense_buffer(self):
nnz, mem = cudautils.copy_to_dense(data=data, mask=mask)
return Buffer(mem, size=nnz, capacity=mem.size)

def to_array(self):
def to_array(self, fillna=None):
"""Get a dense numpy array for the data.
Parameters
----------
fillna : str or None
Defaults to None, which will skip null values.
If it equals "pandas", null values are filled with NaNs.
Non integral dtype is promoted to np.float64.
Notes
-----
if ``fillna`` is ``None``, null values are skipped. Therefore, the
output size could be smaller.
"""
return self.to_dense_buffer().to_array()
return self.to_dense_buffer(fillna=fillna).to_array()

def to_gpu_array(self):
def to_gpu_array(self, fillna=None):
"""Get a dense numba device array for the data.
Parameters
----------
fillna : str or None
See *fillna* in ``.to_array``.
Notes
-----
if ``fillna`` is ``None``, null values are skipped. Therefore, the
output size could be smaller.
"""
return self.to_dense_buffer().to_gpu_array()
return self.to_dense_buffer(fillna=fillna).to_gpu_array()

@property
def data(self):
Expand Down
52 changes: 52 additions & 0 deletions pygdf/tests/test_pandas_interop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import numpy as np
import pandas as pd

from pygdf.dataframe import DataFrame


def test_to_pandas():
df = DataFrame()
df['a'] = np.arange(10, dtype=np.int32)
df['b'] = np.arange(10, 20, dtype=np.float64)

pdf = df.to_pandas()

assert tuple(df.columns) == tuple(pdf.columns)

assert df['a'].dtype == pdf['a'].dtype
assert df['b'].dtype == pdf['b'].dtype

assert len(df['a']) == len(pdf['a'])
assert len(df['b']) == len(pdf['b'])


def test_from_pandas():
pdf = pd.DataFrame()
pdf['a'] = np.arange(10, dtype=np.int32)
pdf['b'] = np.arange(10, 20, dtype=np.float64)

df = DataFrame.from_pandas(pdf)

assert tuple(df.columns) == tuple(pdf.columns)

assert df['a'].dtype == pdf['a'].dtype
assert df['b'].dtype == pdf['b'].dtype

assert len(df['a']) == len(pdf['a'])
assert len(df['b']) == len(pdf['b'])


def test_from_pandas_ex1():
pdf = pd.DataFrame({'a': [0, 1, 2, 3],
'b': [0.1, 0.2, None, 0.3]})
print(pdf)
df = DataFrame.from_pandas(pdf)
print(df)

assert tuple(df.columns) == tuple(pdf.columns)
assert np.all(df['a'].to_array() == pdf['a'])
matches = df['b'].to_array() == pdf['b']
# the 3d element is False due to (nan == nan) == False
assert np.all(matches == [True, True, False, True])
assert np.isnan(df['b'].to_array()[2])
assert np.isnan(pdf['b'][2])
13 changes: 13 additions & 0 deletions pygdf/tests/test_sparse_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,19 @@ def test_fillna():
assert not dense.has_null_mask


def test_to_dense_array():
data = np.random.random(8)
mask = np.asarray([0b11010110], dtype=np.byte)

sr = Series.from_masked_array(data=data, mask=mask, null_count=3)
assert sr.null_count > 0
assert sr.null_count != len(sr)
filled = sr.to_array(fillna='pandas')
dense = sr.to_array()
assert dense.size < filled.size
assert filled.size == len(sr)


def test_reading_arrow_sparse_data():
darr = read_data()
gar = GpuArrowReader(darr)
Expand Down

0 comments on commit ac0ff8d

Please sign in to comment.