diff --git a/doc/API-reference.rst b/doc/API-reference.rst index 7b6b411..1d3ada9 100644 --- a/doc/API-reference.rst +++ b/doc/API-reference.rst @@ -198,6 +198,11 @@ Spline regression .. autofunction:: cc .. autofunction:: te +Polynomial +---------- + +.. autofunction:: poly + Working with formulas programmatically -------------------------------------- diff --git a/patsy/__init__.py b/patsy/__init__.py index 29722d0..a4ec751 100644 --- a/patsy/__init__.py +++ b/patsy/__init__.py @@ -113,5 +113,8 @@ def _reexport(mod): import patsy.mgcv_cubic_splines _reexport(patsy.mgcv_cubic_splines) +import patsy.polynomials +_reexport(patsy.polynomials) + # XX FIXME: we aren't exporting any of the explicit parsing interface # yet. Need to figure out how to do that. diff --git a/patsy/contrasts.py b/patsy/contrasts.py index 3f1cf54..010afb3 100644 --- a/patsy/contrasts.py +++ b/patsy/contrasts.py @@ -17,6 +17,7 @@ from patsy.util import (repr_pretty_delegate, repr_pretty_impl, safe_issubdtype, no_pickling, assert_no_pickling) +from patsy.polynomials import Poly as Polynomial class ContrastMatrix(object): """A simple container for a matrix used for coding categorical factors. @@ -263,11 +264,9 @@ def _code_either(self, intercept, levels): # quadratic, etc., functions of the raw scores, and then use 'qr' to # orthogonalize each column against those to its left. scores -= scores.mean() - raw_poly = scores.reshape((-1, 1)) ** np.arange(n).reshape((1, -1)) - q, r = np.linalg.qr(raw_poly) - q *= np.sign(np.diag(r)) - q /= np.sqrt(np.sum(q ** 2, axis=1)) - # The constant term is always all 1's -- we don't normalize it. + raw_poly = Polynomial.vander(scores, n - 1) + alpha, norm, beta = Polynomial.gen_qr(raw_poly, n - 1) + q = Polynomial.apply_qr(raw_poly, n - 1, alpha, norm, beta) q[:, 0] = 1 names = [".Constant", ".Linear", ".Quadratic", ".Cubic"] names += ["^%s" % (i,) for i in range(4, n)] diff --git a/patsy/polynomials.py b/patsy/polynomials.py new file mode 100644 index 0000000..f8d54c2 --- /dev/null +++ b/patsy/polynomials.py @@ -0,0 +1,204 @@ +# This file is part of Patsy +# Copyright (C) 2012-2013 Nathaniel Smith +# See file LICENSE.txt for license information. + +# R-compatible poly function + +# These are made available in the patsy.* namespace +import numpy as np + +from patsy.util import have_pandas, no_pickling, assert_no_pickling +from patsy.state import stateful_transform + +__all__ = ["poly"] + +if have_pandas: + import pandas + + +class Poly(object): + """poly(x, degree=3, raw=False) + + Generates an orthogonal polynomial transformation of x of degree. + Generic usage is something along the lines of:: + + y ~ 1 + poly(x, 4) + + to fit ``y`` as a function of ``x``, with a 4th degree polynomial. + + :arg degree: The number of degrees for the polynomial expansion. + :arg raw: When raw is False (the default), will return orthogonal + polynomials. + + .. versionadded:: 0.4.1 + """ + + def __init__(self): + self._tmp = {} + + def memorize_chunk(self, x, degree=3, raw=False): + args = {"degree": degree, + "raw": raw + } + self._tmp["args"] = args + # XX: check whether we need x values before saving them + x = np.atleast_1d(x) + if x.ndim == 2 and x.shape[1] == 1: + x = x[:, 0] + if x.ndim > 1: + raise ValueError("input to 'poly' must be 1-d, " + "or a 2-d column vector") + # There's no better way to compute exact quantiles than memorizing + # all data. + x = np.array(x, dtype=float) + self._tmp.setdefault("xs", []).append(x) + + def memorize_finish(self): + tmp = self._tmp + args = tmp["args"] + del self._tmp + + if args["degree"] < 1: + raise ValueError("degree must be greater than 0 (not %r)" + % (args["degree"],)) + if int(args["degree"]) != args["degree"]: + raise ValueError("degree must be an integer (not %r)" + % (args['degree'],)) + + # These are guaranteed to all be 1d vectors by the code above + scores = np.concatenate(tmp["xs"]) + + n = args['degree'] + self.degree = n + self.raw = args['raw'] + + if not self.raw: + raw_poly = self.vander(scores, n) + self.alpha, self.norm, self.beta = self.gen_qr(raw_poly, n) + + def transform(self, x, degree=3, raw=False): + if have_pandas: + if isinstance(x, (pandas.Series, pandas.DataFrame)): + to_pandas = True + idx = x.index + else: + to_pandas = False + else: + to_pandas = False + x = np.array(x, ndmin=1).flatten() + + n = self.degree + p = self.vander(x, n) + + if not self.raw: + p = self.apply_qr(p, n, self.alpha, self.norm, self.beta) + + p = p[:, 1:] + if to_pandas: + p = pandas.DataFrame(p) + p.index = idx + return p + + @staticmethod + def vander(x, n): + raw_poly = np.polynomial.polynomial.polyvander(x, n) + return raw_poly + + @staticmethod + def gen_qr(raw_poly, n): + x = raw_poly[:, 1] + q, r = np.linalg.qr(raw_poly) + # Q is now orthognoal of degree n. To match what R is doing, we + # need to use the three-term recurrence technique to calculate + # new alpha, beta, and norm. + alpha = (np.sum(x.reshape((-1, 1)) * q[:, :n] ** 2, axis=0) + / np.sum(q[:, :n] ** 2, axis=0)) + + # For reasons I don't understand, the norms R uses are based off + # of the diagonal of the r upper triangular matrix. + + norm = np.linalg.norm(q * np.diag(r), axis=0) + beta = (norm[1:] / norm[:n]) ** 2 + return alpha, norm, beta + + @staticmethod + def apply_qr(x, n, alpha, norm, beta): + # This is where the three-term recurrence is unwound for the QR + # decomposition. + if np.ndim(x) == 2: + x = x[:, 1] + p = np.empty((x.shape[0], n + 1)) + p[:, 0] = 1 + + for i in np.arange(n): + p[:, i + 1] = (x - alpha[i]) * p[:, i] + if i > 0: + p[:, i + 1] = (p[:, i + 1] - (beta[i - 1] * p[:, i - 1])) + p /= norm + return p + __getstate__ = no_pickling + + +poly = stateful_transform(Poly) + + +def test_poly_compat(): + from patsy.test_state import check_stateful + from patsy.test_poly_data import (R_poly_test_x, + R_poly_test_data, + R_poly_num_tests) + from numpy.testing import assert_allclose + + lines = R_poly_test_data.split("\n") + tests_ran = 0 + start_idx = lines.index("--BEGIN TEST CASE--") + while True: + if not lines[start_idx] == "--BEGIN TEST CASE--": + break + start_idx += 1 + stop_idx = lines.index("--END TEST CASE--", start_idx) + block = lines[start_idx:stop_idx] + test_data = {} + for line in block: + key, value = line.split("=", 1) + test_data[key] = value + # Translate the R output into Python calling conventions + kwargs = { + # integer + "degree": int(test_data["degree"]), + # boolen + "raw": (test_data["raw"] == 'TRUE') + } + # Special case: in R, setting intercept=TRUE increases the effective + # dof by 1. Adjust our arguments to match. + # if kwargs["df"] is not None and kwargs["include_intercept"]: + # kwargs["df"] += 1 + output = np.asarray(eval(test_data["output"])) + # Do the actual test + check_stateful(Poly, False, R_poly_test_x, output, **kwargs) + raw_poly = Poly.vander(R_poly_test_x, kwargs['degree']) + if kwargs['raw']: + actual = raw_poly[:, 1:] + else: + alpha, norm, beta = Poly.gen_qr(raw_poly, kwargs['degree']) + actual = Poly.apply_qr(R_poly_test_x, kwargs['degree'], alpha, + norm, beta)[:, 1:] + assert_allclose(actual, output) + tests_ran += 1 + # Set up for the next one + start_idx = stop_idx + 1 + assert tests_ran == R_poly_num_tests + + +def test_poly_errors(): + from nose.tools import assert_raises + x = np.arange(27) + # Invalid input shape + assert_raises(ValueError, poly, x.reshape((3, 3, 3))) + assert_raises(ValueError, poly, x.reshape((3, 3, 3)), raw=True) + # Invalid degree + assert_raises(ValueError, poly, x, degree=-1) + assert_raises(ValueError, poly, x, degree=0) + assert_raises(ValueError, poly, x, degree=3.5) + + assert_no_pickling(Poly()) diff --git a/patsy/test_poly_data.py b/patsy/test_poly_data.py new file mode 100644 index 0000000..668ee28 --- /dev/null +++ b/patsy/test_poly_data.py @@ -0,0 +1,37 @@ +# This file auto-generated by tools/get-R-poly-test-vectors.R +# Using: R version 3.2.4 Revised (2016-03-16 r70336) +import numpy as np +R_poly_test_x = np.array([1, 1.5, 2.25, 3.375, 5.0625, 7.59375, 11.390625, 17.0859375, 25.62890625, 38.443359375, 57.6650390625, 86.49755859375, 129.746337890625, 194.6195068359375, 291.92926025390625, 437.89389038085938, 656.84083557128906, 985.26125335693359, 1477.8918800354004, 2216.8378200531006, ]) +R_poly_test_data = """ +--BEGIN TEST CASE-- +degree=1 +raw=TRUE +output=np.array([1, 1.5, 2.25, 3.375, 5.0625, 7.59375, 11.390625, 17.0859375, 25.62890625, 38.443359375, 57.6650390625, 86.49755859375, 129.746337890625, 194.6195068359375, 291.92926025390625, 437.89389038085938, 656.84083557128906, 985.26125335693359, 1477.8918800354004, 2216.8378200531006, ]).reshape((20, 1, ), order="F") +--END TEST CASE-- +--BEGIN TEST CASE-- +degree=1 +raw=FALSE +output=np.array([-0.12865949508274149, -0.12846539500908838, -0.12817424489860868, -0.12773751973288924, -0.12708243198431005, -0.12609980036144131, -0.12462585292713815, -0.12241493177568342, -0.11909855004850137, -0.11412397745772825, -0.10666211857156857, -0.095469330242329037, -0.07868014774846975, -0.053496374007680828, -0.015720713396497447, 0.040942777520277619, 0.12593801389544024, 0.25343086845818413, 0.4446701503023, 0.73152907306847381, ]).reshape((20, 1, ), order="F") +--END TEST CASE-- +--BEGIN TEST CASE-- +degree=3 +raw=TRUE +output=np.array([1, 1.5, 2.25, 3.375, 5.0625, 7.59375, 11.390625, 17.0859375, 25.62890625, 38.443359375, 57.6650390625, 86.49755859375, 129.746337890625, 194.6195068359375, 291.92926025390625, 437.89389038085938, 656.84083557128906, 985.26125335693359, 1477.8918800354004, 2216.8378200531006, 1, 2.25, 5.0625, 11.390625, 25.62890625, 57.6650390625, 129.746337890625, 291.92926025390625, 656.84083557128906, 1477.8918800354004, 3325.2567300796509, 7481.8276426792145, 16834.112196028233, 37876.752441063523, 85222.692992392927, 191751.05923288409, 431439.8832739892, 970739.73736647563, 2184164.4090745705, 4914369.9204177829, 1, 3.375, 11.390625, 38.443359375, 129.746337890625, 437.89389038085938, 1477.8918800354004, 4987.8850951194763, 16834.112196028233, 56815.128661595285, 191751.05923288409, 647159.82491098379, 2184164.4090745705, 7371554.8806266747, 24878997.722115029, 83966617.312138215, 283387333.4284665, 956432250.32107437, 3227958844.8336263, 10894361101.313488, ]).reshape((20, 3, ), order="F") +--END TEST CASE-- +--BEGIN TEST CASE-- +degree=3 +raw=FALSE +output=np.array([-0.12865949508274149, -0.12846539500908838, -0.12817424489860868, -0.12773751973288924, -0.12708243198431005, -0.12609980036144131, -0.12462585292713815, -0.12241493177568342, -0.11909855004850137, -0.11412397745772825, -0.10666211857156857, -0.095469330242329037, -0.07868014774846975, -0.053496374007680828, -0.015720713396497447, 0.040942777520277619, 0.12593801389544024, 0.25343086845818413, 0.4446701503023, 0.73152907306847381, 0.11682670564764953, 0.11622774987820758, 0.1153299112445243, 0.11398449209008393, 0.11196937564961051, 0.10895347864407183, 0.10444488285989936, 0.097716301062945765, 0.087700630095951776, 0.072850827534442664, 0.05096695744238839, 0.019020528242278005, -0.026920519697452645, -0.091380250921070119, -0.1780532062130448, -0.28552519567824058, -0.39602393206231051, -0.44767622905753701, -0.26843910749340033, 0.57802660073100254, -0.11560888340228653, -0.11436481217184656, -0.11250218782662975, -0.10971608089390825, -0.10555451667328646, -0.099351692934324679, -0.090136150925525155, -0.076511614727544461, -0.05651941299388475, -0.027522538371457093, 0.013772191900731716, 0.070864671671751547, 0.14593497036033168, 0.23591981919395397, 0.32391016867398448, 0.36336942185480259, 0.25890497941187346, -0.11572025100301592, -0.66076386903314166, 0.27159578788942196, ]).reshape((20, 3, ), order="F") +--END TEST CASE-- +--BEGIN TEST CASE-- +degree=5 +raw=TRUE +output=np.array([1, 1.5, 2.25, 3.375, 5.0625, 7.59375, 11.390625, 17.0859375, 25.62890625, 38.443359375, 57.6650390625, 86.49755859375, 129.746337890625, 194.6195068359375, 291.92926025390625, 437.89389038085938, 656.84083557128906, 985.26125335693359, 1477.8918800354004, 2216.8378200531006, 1, 2.25, 5.0625, 11.390625, 25.62890625, 57.6650390625, 129.746337890625, 291.92926025390625, 656.84083557128906, 1477.8918800354004, 3325.2567300796509, 7481.8276426792145, 16834.112196028233, 37876.752441063523, 85222.692992392927, 191751.05923288409, 431439.8832739892, 970739.73736647563, 2184164.4090745705, 4914369.9204177829, 1, 3.375, 11.390625, 38.443359375, 129.746337890625, 437.89389038085938, 1477.8918800354004, 4987.8850951194763, 16834.112196028233, 56815.128661595285, 191751.05923288409, 647159.82491098379, 2184164.4090745705, 7371554.8806266747, 24878997.722115029, 83966617.312138215, 283387333.4284665, 956432250.32107437, 3227958844.8336263, 10894361101.313488, 1, 5.0625, 25.62890625, 129.746337890625, 656.84083557128906, 3325.2567300796509, 16834.112196028233, 85222.692992392927, 431439.8832739892, 2184164.4090745705, 11057332.320940012, 55977744.87475881, 283387333.4284665, 1434648375.4816115, 7262907400.875659, 36768468716.933022, 186140372879.47342, 942335637702.33411, 4770574165868.0674, 24151031714707.086, 1, 7.59375, 57.6650390625, 437.89389038085938, 3325.2567300796509, 25251.168294042349, 191751.05923288409, 1456109.6060497134, 11057332.320940012, 83966617.31213823, 637621500.21404958, 4841938267.2504387, 36768468716.933022, 279210559319.21014, 2120255184830.252, 16100687809804.727, 122264598055704.64, 928446791485507, 7050392822843070, 53538920498464552, ]).reshape((20, 5, ), order="F") +--END TEST CASE-- +--BEGIN TEST CASE-- +degree=5 +raw=FALSE +output=np.array([-0.12865949508274149, -0.12846539500908838, -0.12817424489860868, -0.12773751973288924, -0.12708243198431005, -0.12609980036144131, -0.12462585292713815, -0.12241493177568342, -0.11909855004850137, -0.11412397745772825, -0.10666211857156857, -0.095469330242329037, -0.07868014774846975, -0.053496374007680828, -0.015720713396497447, 0.040942777520277619, 0.12593801389544024, 0.25343086845818413, 0.4446701503023, 0.73152907306847381, 0.11682670564764953, 0.11622774987820758, 0.1153299112445243, 0.11398449209008393, 0.11196937564961051, 0.10895347864407183, 0.10444488285989936, 0.097716301062945765, 0.087700630095951776, 0.072850827534442664, 0.05096695744238839, 0.019020528242278005, -0.026920519697452645, -0.091380250921070119, -0.1780532062130448, -0.28552519567824058, -0.39602393206231051, -0.44767622905753701, -0.26843910749340033, 0.57802660073100254, -0.11560888340228653, -0.11436481217184656, -0.11250218782662975, -0.10971608089390825, -0.10555451667328646, -0.099351692934324679, -0.090136150925525155, -0.076511614727544461, -0.05651941299388475, -0.027522538371457093, 0.013772191900731716, 0.070864671671751547, 0.14593497036033168, 0.23591981919395397, 0.32391016867398448, 0.36336942185480259, 0.25890497941187346, -0.11572025100301592, -0.66076386903314166, 0.27159578788942196, 0.11925766326375063, 0.11701962699862156, 0.11367531238125347, 0.10868744714732725, 0.10126981942884175, 0.090287103769210786, 0.074134201646975206, 0.050620044131431986, 0.016933017097416861, -0.030116712154368355, -0.093138533517390085, -0.17160263551697441, -0.25618209006285081, -0.3183631162695052, -0.29707753517866498, -0.10102478727647804, 0.30185248746535442, 0.55289166632880227, -0.46108564710186972, 0.081962667419115426, -0.12626707822019206, -0.12250155553682644, -0.11689136915447108, -0.10856147160045609, -0.096257598068575617, -0.078227654788373013, -0.052128116579684983, -0.015063001240831148, 0.035988153544508683, 0.10280803884977513, 0.18263307034840112, 0.26144732880503613, 0.30325203347309243, 0.24116709207723347, -0.00082575540196283526, -0.37830141983168153, -0.42887161757203512, 0.55207091753656046, -0.17171017635275559, 0.016240179713238136, ]).reshape((20, 5, ), order="F") +--END TEST CASE-- +""" +R_poly_num_tests = 6 diff --git a/tools/get-R-poly-test-vectors.R b/tools/get-R-poly-test-vectors.R new file mode 100644 index 0000000..68e9678 --- /dev/null +++ b/tools/get-R-poly-test-vectors.R @@ -0,0 +1,62 @@ +cat("# This file auto-generated by tools/get-R-poly-test-vectors.R\n") +cat(sprintf("# Using: %s\n", R.Version()$version.string)) +cat("import numpy as np\n") + +options(digits=20) +library(splines) +x <- (1.5)^(0:19) + +MISSING <- "MISSING" + +is.missing <- function(obj) { + length(obj) == 1 && obj == MISSING +} + +pyprint <- function(arr) { + if (is.missing(arr)) { + cat("None\n") + } else { + cat("np.array([") + for (val in arr) { + cat(val) + cat(", ") + } + cat("])") + if (!is.null(dim(arr))) { + cat(".reshape((") + for (size in dim(arr)) { + cat(sprintf("%s, ", size)) + } + cat("), order=\"F\")") + } + cat("\n") + } +} + +num.tests <- 0 +dump.poly <- function(degree, raw) { + cat("--BEGIN TEST CASE--\n") + cat(sprintf("degree=%s\n", degree)) + cat(sprintf("raw=%s\n", raw)) + + args <- list(x=x, degree=degree, raw=raw) + + result <- do.call(poly, args) + + cat("output=") + pyprint(result) + cat("--END TEST CASE--\n") + assign("num.tests", num.tests + 1, envir=.GlobalEnv) +} + +cat("R_poly_test_x = ") +pyprint(x) +cat("R_poly_test_data = \"\"\"\n") + +for (degree in c(1, 3, 5)) { + for (raw in c(TRUE, FALSE)) { + dump.poly(degree, raw) + } +} +cat("\"\"\"\n") +cat(sprintf("R_poly_num_tests = %s\n", num.tests))