Merge pull request blaze#1496 from kwmsmith/feature-cat
str_cat with Pandas `.str.cat()` interface
kwmsmith committed May 2, 2016
2 parents 62a2911 + 26d2d9a commit 0094aa4
Showing 9 changed files with 395 additions and 25 deletions.
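
In short, this commit adds a `str_cat` expression to Blaze whose interface mirrors pandas' `Series.str.cat`. A minimal usage sketch against an in-memory DataFrame (the symbol and column names here are illustrative, not taken from the changed files):

import pandas as pd
from blaze import symbol, compute

df = pd.DataFrame({'name': ['Alice', 'Bob'], 'sex': ['F', 'M']})
t = symbol('t', 'var * {name: string, sex: string}')

# Concatenate two string columns, optionally with a separator.
expr = t.name.str_cat(t.sex, sep=' -- ')
result = compute(expr, df)   # backed by df.name.str.cat(df.sex, sep=' -- ')
print(list(result))          # ['Alice -- F', 'Bob -- M']
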
7 changes: 7 additions & 0 deletions blaze/compute/pandas.py
@@ -98,6 +98,7 @@
    summary,
    symbol,
    var,
    StrCat,
)

__all__ = []
@@ -296,6 +297,12 @@ def compute_up(expr, data, **kwargs):
    return getattr(data.str, string_func_names.get(name, name))()


@dispatch(StrCat, Series, Series)
def compute_up(expr, lhs_data, rhs_data, **kwargs):
    res = lhs_data.str.cat(rhs_data, sep=expr.sep)
    return res
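
The pandas backend above defers entirely to `Series.str.cat`, so missing values behave as pandas defines them: with the default `na_rep=None`, a missing operand on either side makes that row's result missing. A small illustrative sketch, not part of this commit:

import pandas as pd

lhs = pd.Series(['Alice', None, 'Drew'])
rhs = pd.Series(['F', 'M', None])

# Any missing operand yields a missing result for that row.
print(lhs.str.cat(rhs, sep=' -- ').tolist())
# ['Alice -- F', nan, nan]
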


def unpack(seq):
""" Unpack sequence of length one
56 changes: 46 additions & 10 deletions blaze/compute/sql.py
@@ -43,7 +43,7 @@
from toolz.curried import map

from .core import compute_up, compute, base
from ..compatibility import reduce
from ..compatibility import reduce, basestring
from ..dispatch import dispatch
from ..expr import (
    BinOp,
@@ -90,6 +90,7 @@
    std,
    str_len,
    strlen,
    StrCat,
    var,
)
from ..expr.broadcast import broadcast_collect
@@ -271,10 +272,12 @@ def compute_up(t, data, **kwargs):
else:
return f(t, t.lhs, column)

@compute_up.register(
type_, (Select, ColumnElement, base), (Select, ColumnElement),
)
@compute_up.register(type_, (Select, ColumnElement), base)
@compute_up.register(type_,
(Select, ColumnElement, base),
(Select, ColumnElement))
@compute_up.register(type_,
(Select, ColumnElement),
base)
def binop_sql(t, lhs, rhs, **kwargs):
if isinstance(lhs, Select):
assert len(lhs.c) == 1, (
@@ -1182,11 +1185,9 @@ def compute_up(t, s, **kwargs):
return s.like(t.pattern.replace('*', '%').replace('?', '_'))


string_func_names = {
    # <blaze function name>: <SQL function name>
    'str_upper': 'upper',
    'str_lower': 'lower',
}
string_func_names = {  # <blaze function name>: <SQL function name>
    'str_upper': 'upper',
    'str_lower': 'lower'}


# TODO: remove if the alternative fix goes into PyHive
@@ -1205,6 +1206,41 @@ def compute_up(expr, data, **kwargs):
return sa.sql.functions.char_length(data).label(expr._name)


@compute_up.register(StrCat, Select, basestring)
@compute_up.register(StrCat, basestring, Select)
def str_cat_sql(expr, lhs, rhs, **kwargs):
    if isinstance(lhs, Select):
        orig = lhs
        lhs = first(lhs.inner_columns)
    else:
        orig = rhs
        rhs = first(rhs.inner_columns)
    if expr.sep:
        result = (lhs + expr.sep + rhs).label(expr.lhs._name)
    else:
        result = (lhs + rhs).label(expr.lhs._name)
    return reconstruct_select([result], orig)


@compute_up.register(StrCat, Select, Select)
def str_cat_sql(expr, lhs, rhs, **kwargs):
    left, right = first(lhs.inner_columns), first(rhs.inner_columns)
    if expr.sep:
        result = (left + expr.sep + right).label(expr.lhs._name)
    else:
        result = (left + right).label(expr.lhs._name)
    return reconstruct_select([result], lhs)


@compute_up.register(StrCat, (ColumnElement, basestring), ColumnElement)
@compute_up.register(StrCat, ColumnElement, basestring)
def str_cat_sql(expr, lhs, rhs, **kwargs):
    if expr.sep:
        return (lhs + expr.sep + rhs).label(expr.lhs._name)
    else:
        return (lhs + rhs).label(expr.lhs._name)
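
These registrations build the concatenation with SQLAlchemy's `+` operator on String columns, which the compiler renders as the SQL `||` operator, so on backends such as PostgreSQL a NULL operand yields a NULL result. A rough sketch of the generated statement, using illustrative table and column names and SQLAlchemy 1.x's list-style select():

import sqlalchemy as sa

metadata = sa.MetaData()
accounts = sa.Table('accounts', metadata,
                    sa.Column('name', sa.String),
                    sa.Column('comment', sa.String))

# `+` on String columns becomes || in the compiled statement.
expr = (accounts.c.name + ' -- ' + accounts.c.comment).label('name')
print(sa.select([expr]))
# roughly: SELECT accounts.name || :name_1 || accounts.comment AS name FROM accounts
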


@dispatch(UnaryStringFunction, ColumnElement)
def compute_up(expr, data, **kwargs):
    func_name = type(expr).__name__
57 changes: 57 additions & 0 deletions blaze/compute/tests/test_pandas_compute.py
@@ -45,6 +45,36 @@
columns=['name', 'sex', 'amount', 'id'])


# for now just copy this, but will open a PR to see if we can remove some of
# the repetitive copying
tbgr = symbol('tbgr',
              """ var * {name: string,
                         sex: string[1],
                         amount: int,
                         id: int,
                         comment: ?string}
              """)

dfbgr = DataFrame([['Alice', 'F', 100, 1, 'Alice comment'],
                   ['Alice', 'F', 100, 3, None],
                   ['Drew', 'F', 100, 4, 'Drew comment'],
                   ['Drew', 'M', 100, 5, 'Drew comment 2'],
                   ['Drew', 'M', 200, 5, None]],
                  columns=['name', 'sex', 'amount', 'id', 'comment'])


@pytest.fixture(scope='module')
def df_add_null():
    rows = [(None, 'M', 300, 6),
            ('first', None, 300, 6),
            (None, None, 300, 6)]
    df_add_null = dfbig.append(DataFrame(rows,
                                         columns=dfbig.columns
                                         ), ignore_index=True)

    return df_add_null


def test_series_broadcast():
    s = Series([1, 2, 3], name='a')
    t = symbol('t', 'var * {a: int64}')
@@ -680,6 +710,33 @@ def test_str_lower():
    assert_series_equal(expected, result)


def test_str_cat():
    res = compute(tbig.name.str_cat(tbig.sex), dfbig)
    assert all(dfbig.name.str.cat(dfbig.sex) == res)


def test_str_cat_sep():
    res = compute(tbig.name.str_cat(tbig.sex, sep=' -- '), dfbig)
    assert all(dfbig.name.str.cat(dfbig.sex, sep=' -- ') == res)


def test_str_cat_null_row(df_add_null):
    res = compute(tbig.name.str_cat(tbig.sex, sep=' -- '), df_add_null)
    exp_res = df_add_null.name.str.cat(df_add_null.sex, sep=' -- ')

    assert all(exp_res.isnull() == res.isnull())
    assert all(exp_res[~exp_res.isnull()] == res[~res.isnull()])


def test_str_cat_chain_operation():
    expr = tbgr.name.str_cat(tbgr.comment.str_cat(tbgr.sex, sep=' --- '),
                             sep=' +++ ')
    res = compute(expr, dfbgr)
    exp_res = dfbgr.name.str.cat(dfbgr.comment.str.cat(dfbgr.sex, sep=' --- '),
                                 sep=' +++ ')
    assert all(exp_res.isnull() == res.isnull())
    assert all(exp_res[~exp_res.isnull()] == res[~res.isnull()])


def test_rowwise_by():
    f = lambda _, id, name: id + len(name)
95 changes: 94 additions & 1 deletion blaze/compute/tests/test_postgresql_compute.py
@@ -16,7 +16,6 @@

from datashape import dshape
from odo import odo, drop, discover
from odo.utils import tmpfile
from blaze import (
    data,
    atan2,
@@ -65,6 +64,36 @@ def sql(url):
drop(t)


@pytest.yield_fixture
def sql_with_null(url):
    ds = dshape(""" var * {name: ?string,
                           sex: ?string,
                           amount: int,
                           id: int,
                           comment: ?string}
                """)
    rows = [('Alice', 'F', 100, 1, 'Alice comment'),
            (None, 'M', 300, 2, None),
            ('Drew', 'F', 100, 4, 'Drew comment'),
            ('Bob', 'M', 100, 5, 'Bob comment 2'),
            ('Drew', 'M', 200, 5, None),
            ('first', None, 300, 4, 'Missing info'),
            (None, None, 300, 6, None)]
    try:
        x = url % next(names)
        t = data(x, dshape=ds)
        print(x)
    except sa.exc.OperationalError as e:
        pytest.skip(str(e))
    else:
        assert t.dshape == ds
        t = data(odo(rows, t))
        try:
            yield t
        finally:
            drop(t)


@pytest.yield_fixture(scope='module')
def nyc(pg_ip):
# odoing csv -> pandas -> postgres is more robust, as it doesn't require
@@ -667,6 +696,70 @@ def test_sample(big_sql):
    assert len(result) == len(result2)
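
Because standard SQL (and PostgreSQL in particular) treats `NULL || anything` as NULL, rows with a missing operand are expected to come back as None. A tiny hypothetical helper, not part of this commit, capturing the behaviour the str_cat tests below assert:

def expected_cat(lhs, rhs, sep=None):
    # Mirror SQL semantics: a missing operand makes the whole result missing.
    if lhs is None or rhs is None:
        return None
    return lhs + rhs if sep is None else lhs + sep + rhs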


@pytest.mark.parametrize("sep", [None, " -- "])
def test_str_cat_with_null(sql_with_null, sep):
    t = symbol('t', discover(sql_with_null))
    res = compute(t.name.str_cat(t.sex, sep=sep), sql_with_null,
                  return_type=list)
    res = [r[0] for r in res]
    cols = compute(t[['name', 'sex']], sql_with_null, return_type=list)

    for r, (n, s) in zip(res, cols):
        if n is None or s is None:
            assert r is None
        else:
            assert (r == n + s if sep is None else r == n + sep + s)


def test_chain_str_cat_with_null(sql_with_null):
    t = symbol('t', discover(sql_with_null))
    expr = (t.name
            .str_cat(t.comment, sep=' ++ ')
            .str_cat(t.sex, sep=' -- '))
    res = compute(expr, sql_with_null, return_type=list)
    res = [r[0] for r in res]
    cols = compute(t[['name', 'comment', 'sex']], sql_with_null,
                   return_type=list)

    for r, (n, c, s) in zip(res, cols):
        if n is None or c is None or s is None:
            assert r is None
        else:
            assert (r == n + ' ++ ' + c + ' -- ' + s)


def test_str_cat_bcast(sql_with_null):
    t = symbol('t', discover(sql_with_null))
    lit_sym = symbol('s', 'string')
    s = t[t.amount <= 200]
    result = compute(s.comment.str_cat(lit_sym, sep=' '),
                     {t: sql_with_null, lit_sym: '!!'},
                     return_type=pd.Series)
    df = compute(s, sql_with_null,
                 return_type=pd.DataFrame)
    expected = df.comment.str.cat(['!!']*len(df.comment), sep=' ')

    assert all(expected[~expected.isnull()] == result[~result.isnull()])
    assert all(expected[expected.isnull()].index == result[result.isnull()].index)



def test_str_cat_where_clause(sql_with_null):
    """
    Invokes the (Select, Select) path for compute_up
    """
    t = symbol('t', discover(sql_with_null))
    s = t[t.amount <= 200]
    c1 = s.comment.str_cat(s.sex, sep=' -- ')

    bres = compute(c1, sql_with_null, return_type=pd.Series)
    df_s = compute(s, sql_with_null, return_type=pd.DataFrame)
    exp = df_s.comment.str.cat(df_s.sex, ' -- ')

    assert all(exp[~exp.isnull()] == bres[~bres.isnull()])
    assert all(exp[exp.isnull()].index == bres[bres.isnull()].index)


def test_core_compute(nyc):
    t = symbol('t', discover(nyc))
    assert isinstance(compute(t, nyc, return_type='core'), pd.DataFrame)
71 changes: 64 additions & 7 deletions blaze/compute/tests/test_sql_compute.py
@@ -78,6 +78,13 @@ def city_data():


t = symbol('t', 'var * {name: string, amount: int, id: int}')
t_str_cat = symbol('t',
                   """var * {name: string,
                             amount: int,
                             id: int,
                             comment: string,
                             product: string}""")

nt = symbol('t', 'var * {name: ?string, amount: float64, id: int}')

metadata = sa.MetaData()
@@ -87,6 +94,13 @@
sa.Column('amount', sa.Integer),
sa.Column('id', sa.Integer, primary_key=True))

s_str_cat = sa.Table('accounts2', metadata,
                     sa.Column('name', sa.String),
                     sa.Column('amount', sa.Integer),
                     sa.Column('id', sa.Integer, primary_key=True),
                     sa.Column('comment', sa.String),
                     sa.Column('product', sa.String))

tdate = symbol('t',
"""var * {
name: string,
@@ -809,22 +823,65 @@ def test_str_upper():
    expected = "SELECT upper(accounts.name) as name FROM accounts"
    assert normalize(result) == normalize(expected)


def test_str_lower():
    expr = t.name.str_lower()
    result = str(compute(expr, s, return_type='native'))
    expected = "SELECT lower(accounts.name) as name FROM accounts"
    assert normalize(result) == normalize(expected)


@pytest.mark.parametrize("sep", [None, " sep "])
def test_str_cat(sep):
    """
    Need at least two string columns to test str_cat
    """

    if sep is None:
        expr = t_str_cat.name.str_cat(t_str_cat.comment)
        expected = """
            SELECT accounts2.name || accounts2.comment
            AS anon_1 FROM accounts2
        """
    else:
        expr = t_str_cat.name.str_cat(t_str_cat.comment, sep=sep)
        expected = """
            SELECT accounts2.name || :name_1 || accounts2.comment
            AS anon_1 FROM accounts2
        """

    result = str(compute(expr, s_str_cat, return_type='native'))
    assert normalize(result) == normalize(expected)


def test_str_cat_chain():
    expr = (t_str_cat.name
            .str_cat(t_str_cat.comment, sep=' -- ')
            .str_cat(t_str_cat.product, sep=' ++ '))
    result = str(compute(expr, {t_str_cat: s_str_cat}, return_type='native'))
    expected = """
        SELECT accounts2.name || :name_1 || accounts2.comment ||
        :param_1 || accounts2.product AS anon_1 FROM accounts2
    """
    assert normalize(result) == normalize(expected)


def test_str_cat_no_runtime_exception():
    """
    No exception raised if resource is the same
    """
    expr = t_str_cat.comment.str_cat(t.name)
    compute(expr, {t: s_str_cat, t_str_cat: s_str_cat}, return_type='native')


def test_columnwise_on_complex_selection():
    result = str(select(compute(t[t.amount > 0].amount + 1, s, return_type='native')))
    assert normalize(result) == \
        normalize("""
            SELECT accounts.amount + :amount_1 AS amount
            FROM accounts
            WHERE accounts.amount > :amount_2
        """)
    expected = """
        SELECT accounts.amount + :amount_1 AS amount
        FROM accounts
        WHERE accounts.amount > :amount_2
    """
    assert normalize(result) == normalize(expected)


def test_reductions_on_complex_selections():