
Meta learners #170


Draft · wants to merge 62 commits into base: main

Commits (62)
2c0d551
implemented frequentist S, T and X learners
Feb 24, 2023
ab999b6
Reformatted. Added bootstrapping. Added DRLearner.
Feb 26, 2023
9791b9b
Fixed doc-string for DRLearner
Feb 26, 2023
100f8d7
renamed meta_learners.py to skl_meta_learners.py
Feb 26, 2023
5874281
imported skl_meta_learners
Feb 26, 2023
df90a52
minor code style fixes
Feb 27, 2023
b8a3dff
mostly stylistic changes
Feb 27, 2023
020a65f
fixed an import
Feb 27, 2023
667d3b4
bootstraping does not overwrite self.models anymore
matekadlicsko Feb 28, 2023
d05c156
fixed a citation in docstring
matekadlicsko Mar 1, 2023
542e129
added _fit function to reduce boilerplate code
matekadlicsko Mar 1, 2023
5f8a62f
refactored
matekadlicsko Mar 1, 2023
759b9e2
added BARTModel
matekadlicsko Mar 1, 2023
8c03319
outlined pymc meta-learners
matekadlicsko Mar 1, 2023
18baff5
minor changes helping pymc integration
matekadlicsko Mar 2, 2023
f9d9817
minor changes
matekadlicsko Mar 2, 2023
9917a83
continuing to integrate pymc models
matekadlicsko Mar 2, 2023
a8d6467
bugfix
matekadlicsko Mar 2, 2023
55b43df
more minor bugfixes
matekadlicsko Mar 2, 2023
9d5bb61
added logistic regression
matekadlicsko Mar 2, 2023
3f77e76
added bayesian DRLearner
matekadlicsko Mar 4, 2023
faf0db5
fixed some issues with X and DR learners
matekadlicsko Mar 5, 2023
c1bbf33
small bugfixes
matekadlicsko Mar 6, 2023
2f689dd
added (incomplete) notebook explaining meta-learners
matekadlicsko Mar 6, 2023
b57e31a
wrote section on X-learner
matekadlicsko Mar 7, 2023
483d55b
fixed major error in DRLearner implementation
matekadlicsko Mar 7, 2023
d62eb18
minor changes
matekadlicsko Mar 8, 2023
95e010e
implemented cross_fitting option for DR-learner
matekadlicsko Mar 9, 2023
3e1182d
wrote subsection on DR-learner
matekadlicsko Mar 9, 2023
806cd0f
added docstring + some small changes suggested by @juanitorduz
matekadlicsko Mar 10, 2023
21d0b15
fixed a dependency
matekadlicsko Mar 12, 2023
c4f124b
improvements on LogisticRegression
matekadlicsko Mar 12, 2023
90fddd7
several improvements
matekadlicsko Mar 12, 2023
917216c
BayesianDR now works
matekadlicsko Mar 15, 2023
bb588b9
BayesianXLearner now works
matekadlicsko Mar 15, 2023
f39b856
removed redundant _compute_cate function
matekadlicsko Mar 15, 2023
2ca0ebd
formatting
matekadlicsko Mar 15, 2023
48c8105
added score method
matekadlicsko Mar 16, 2023
ddaebb4
formatting
matekadlicsko Mar 16, 2023
3bb16fe
reworded introduction + included some suggestions by @juanitorduz
matekadlicsko Mar 16, 2023
0d98c53
minor changes
matekadlicsko Mar 16, 2023
02b78e1
formatting
matekadlicsko Mar 17, 2023
3e845bf
added correct docstring
matekadlicsko Mar 17, 2023
d4830cc
added aesera to list of dependencies
matekadlicsko Mar 22, 2023
02d592c
improved docstrings.
matekadlicsko Mar 27, 2023
2007685
XLearner computations were wrong
matekadlicsko Mar 27, 2023
a936306
added summary file
matekadlicsko Mar 29, 2023
e682b27
summary now returns a summary object
matekadlicsko Mar 29, 2023
4751aeb
minor fix
matekadlicsko Mar 29, 2023
aba9255
new summary objects are displayed
matekadlicsko Mar 29, 2023
5fe6c53
changed plot method
matekadlicsko Apr 2, 2023
8fd71ec
Added some docstrings
matekadlicsko Apr 2, 2023
14fac30
fixed pymc-bart import
matekadlicsko Apr 9, 2023
1cbe477
summary now performs bootstrapping only once
matekadlicsko Apr 9, 2023
46a33d2
added summary
matekadlicsko Apr 9, 2023
d88472c
imported summary
matekadlicsko Apr 9, 2023
c154979
Merge branch 'pymc-labs:main' into meta-learners
matekadlicsko Apr 13, 2023
18b6934
made notebook a bit more clear
matekadlicsko Apr 17, 2023
1beda78
Merge branch 'meta-learners' of https://github.com/matekadlicsko/Caus…
matekadlicsko Apr 17, 2023
b43752e
Merge branch 'pymc-labs:main' into meta-learners
matekadlicsko Apr 20, 2023
92b655d
Merge branch 'pymc-labs:main' into meta-learners
matekadlicsko May 10, 2023
9d26c40
Merge branch 'pymc-labs:main' into meta-learners
matekadlicsko Jun 8, 2023
2 changes: 2 additions & 0 deletions causalpy/__init__.py
@@ -3,7 +3,9 @@
import causalpy.pymc_experiments
import causalpy.pymc_models
import causalpy.skl_experiments
import causalpy.skl_meta_learners
import causalpy.skl_models
import causalpy.summary
from causalpy.version import __version__

from .data import load_data
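
The two added imports expose the new modules as attributes of the top-level package. A minimal sketch of how they would be reached after this change — not part of the diff, and the module contents are inferred from the files added below:

import causalpy as cp

cp.skl_meta_learners  # scikit-learn based S-, T-, X- and DR-learners added in this PR
cp.summary            # Summary / Record table helpers defined in causalpy/summary.py below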
427 changes: 427 additions & 0 deletions causalpy/pymc_meta_learners.py

Large diffs are not rendered by default.

119 changes: 119 additions & 0 deletions causalpy/pymc_models.py
@@ -4,7 +4,9 @@
import numpy as np
import pandas as pd
import pymc as pm
import pymc_bart as pmb
from arviz import r2_score
from pymc.distributions.distribution import DistributionMeta


class ModelBuilder(pm.Model):
@@ -113,3 +115,120 @@ def build_model(self, X, y, coords):
sigma = pm.HalfNormal("sigma", 1)
mu = pm.Deterministic("mu", pm.math.dot(X, beta), dims="obs_ind")
pm.Normal("y_hat", mu, sigma, observed=y, dims="obs_ind")


class BARTRegressor(ModelBuilder):
"""
Class for building BART based regressors for meta-learners.
Parameters
----------
m : int.
Number of trees to fit.
sigma : float.
Prior standard deviation.
sample_kwargs : dict.
Keyword arguments for sampler.
"""

def __init__(
self,
m: int = 20,
sigma: float = 1.0,
sample_kwargs: Optional[dict[str, Any]] = None,
):
self.m = m
self.sigma = sigma
super().__init__(sample_kwargs)

def build_model(self, X, y, coords=None):
with self:
self.add_coords(coords)
X_ = pm.MutableData("X", X, dims=["obs_ind", "coeffs"])
mu = pmb.BART("mu", X_, y, m=self.m, dims="obs_ind")
pm.Normal("y_hat", mu=mu, sigma=self.sigma, observed=y, dims="obs_ind")
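
A minimal usage sketch for BARTRegressor, not part of the diff. It assumes the ModelBuilder.fit(X, y, coords) interface used elsewhere in CausalPy; the coordinate names mirror the dims declared in build_model and the data are synthetic:

import numpy as np
from causalpy.pymc_models import BARTRegressor

rng = np.random.default_rng(42)
X = rng.normal(size=(100, 3))
y = np.sin(X[:, 0]) + 0.5 * X[:, 1] + rng.normal(scale=0.1, size=100)

model = BARTRegressor(m=20, sigma=1.0, sample_kwargs={"draws": 500, "tune": 500})
model.fit(X, y, coords={"coeffs": ["x0", "x1", "x2"], "obs_ind": np.arange(100)})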


class BARTClassifier(ModelBuilder):
"""
Class for building BART based classifiers for meta-learners.
Parameters
----------
m : int.
Number of trees to fit.
sample_kwargs : dict.
Keyword arguments for sampler.
"""

def __init__(
self,
m: int = 20,
sample_kwargs: Optional[dict[str, Any]] = None,
):
self.m = m
super().__init__(sample_kwargs)

def build_model(self, X, y, coords=None):
with self:
self.add_coords(coords)
X_ = pm.MutableData("X", X, dims=["obs_ind", "coeffs"])
mu_ = pmb.BART("mu_", X_, y, m=self.m, dims="obs_ind")
mu = pm.Deterministic("mu", pm.math.sigmoid(mu_), dims="obs_ind")
pm.Bernoulli("y_hat", mu, observed=y, dims="obs_ind")
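
BARTClassifier follows the same pattern but models a binary outcome through a Bernoulli likelihood; a sketch under the same assumptions about ModelBuilder.fit, with made-up 0/1 labels:

import numpy as np
from causalpy.pymc_models import BARTClassifier

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 2))
y = (X[:, 0] + rng.normal(scale=0.5, size=100) > 0).astype(int)  # binary labels

clf = BARTClassifier(m=20)
clf.fit(X, y, coords={"coeffs": ["x0", "x1"], "obs_ind": np.arange(100)})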


class LogisticRegression(ModelBuilder):
"""
Custom PyMC model for logistic regression.
Parameters
----------
coeff_distribution : PyMC distribution.
Prior distribution of coefficient vector.
coeff_distribution_kwargs : dict.
Keyword arguments for prior distribution.
sample_kwargs : dict.
Keyword arguments for sampler.
Examples
--------
>>> import numpy as np
>>> import pymc as pm
>>> from causalpy.pymc_models import LogisticRegression
>>>
>>> X = np.random.rand(10, 10)
>>> y = np.random.randint(0, 2, 10)
>>> m = LogisticRegression(
...     coeff_distribution=pm.Cauchy,
...     coeff_distribution_kwargs={"alpha": 0, "beta": 1},
... )
>>>
>>> m.fit(X, y)
"""

def __init__(
self,
sample_kwargs: Optional[dict[str, Any]] = None,
coeff_distribution: DistributionMeta = pm.Normal,
coeff_distribution_kwargs: Optional[dict[str, Any]] = None,
):
self.coeff_distribution = coeff_distribution
if coeff_distribution_kwargs is None:
self.coeff_distribution_kwargs = {"mu": 0, "sigma": 50}
else:
self.coeff_distribution_kwargs = coeff_distribution_kwargs

super().__init__(sample_kwargs)

def build_model(self, X, y, coords) -> None:
with self:
self.add_coords(coords)
X_ = pm.MutableData("X", X, dims=["obs_ind", "coeffs"])
beta = self.coeff_distribution(
"beta", dims="coeffs", **self.coeff_distribution_kwargs
)
mu = pm.Deterministic(
"mu", pm.math.sigmoid(pm.math.dot(X_, beta)), dims="obs_ind"
)
pm.Bernoulli("y_hat", mu, observed=y, dims="obs_ind")
865 changes: 865 additions & 0 deletions causalpy/skl_meta_learners.py

Large diffs are not rendered by default.

245 changes: 245 additions & 0 deletions causalpy/summary.py
@@ -0,0 +1,245 @@
"Summary objects."

from dataclasses import dataclass
from typing import Optional


def html_header(columns: list[str], colspan: int = 1) -> str:
"""
Returns HTML code for table header.
Parameters
----------
columns : list[str].
List containing column names.
colspan : int.
Column span.
"""
string = (
f'<th colspan={colspan} style="text-align: center"> '
+ (f'</th> <th style="text-align: center" colspan={colspan}>'.join(columns))
+ "</th>"
)
return '<tr style="text-align: center">' + string + "</tr>"


def html_rows(index: str, values: list, colspan: int = 1) -> str:
"""
Returns HTML code for table rows.
Parameters
----------
index : str.
Index (row label) of the row.
values : list.
List containing the values of the row.
colspan : int.
Column span.
"""
string = ""
string += f'<th style="text-align: left"> {index} </th>'

if not isinstance(values, list):
values = [values]

values = map(str, values)
string += (
f'<td colspan={colspan} style="text-align: center">'
+ ('</td> <td style="text-align: center">'.join(values))
+ "</td>"
)

string = f"<tr> {string} </tr>"

return string


def str_header(columns: list[str], length: int) -> str:
"""
Returns table header string.
Parameters
----------
columns : list[str].
List containing column names.
length : int.
Length of box containing header.
"""
# If the first column is empty, its box should not be displayed
first_col_empty = columns[0] == ""

if first_col_empty:
columns = columns[1:]

n_cols = len(columns)

top = f"{length * '═'}╦"
bot = f"{length * '═'}╩"

spaces = first_col_empty * (length + 1) * " "

return f"""\
{spaces}{(n_cols - 1) * top}{length * "═"}
{spaces}{"║".join(map(lambda x: x.center(length), columns))}
{length * '━'}{(n_cols - 1) * bot}{length * "═"}\
"""


def str_row(index: str, values: str, length: int, row_type: str = "inner") -> str:
"""
Returns string table row.
Parameters
----------
index : str.
Index of the row.
values : list.
Values of the row.
length : int.
Length of boxes.
row_type : str.
One of "inner", "first", "last".
"""
n_cols = len(values)
values = map(str, values)

s = ""

if row_type == "first":
s += f""" ┏{length * '━'}┱"""
s += f"{(n_cols - 1) * (length * '─' + '┬')}"
s += f"{(length * '─' + '┐')}"
s += "\n"

s += f"""\
{index.center(length)}{"│".join(map(lambda x: x.center(length), values))}
"""
if row_type == "last":
s += f"┗{length * '━'}{(n_cols - 1) * (length * '─' + '┴')}{length * '─'}┘"
else:
s += f"┣{length * '━'}{(n_cols - 1) * (length * '─' + '┼')}{length * '─'}┤"

return s


def str_title(title: str):
"""
Returns title in a box.
Parameters
----------
title : str.
String to return in a box.
"""
length = len(title)
return f"""\
{(length + 2) * '═'}
{title}
{(length + 2) * '═'}
"""


@dataclass
class Record:
"""
Class representing either a header or a row of a table.
Parameters
----------
record_type : str.
Either 'header', 'title' or 'row'.
values : list.
List of values in record.
colspan : int.
Column span.
"""

record_type: str
values: list
colspan: int
index: Optional[str] = None

def to_html(self):
"Returns HTML code for record."
if self.record_type in ["header", "title"]:
return html_header(self.values, self.colspan)
else:
return html_rows(self.index, self.values, self.colspan)


class Summary:
"""
Base summary class.
"""

def __init__(self):
self.records = []

def add_header(self, col_names, colspan) -> None:
"Adds a header."
self.records.append(Record("header", col_names, colspan))

def add_row(self, index: str, values: list, colspan: int) -> None:
"Adds a row."
self.records.append(Record("row", values, colspan, index=index))

def add_title(self, title) -> None:
"Adds title."
self.records.append(Record("title", title, 1))

def get_longest_length(self) -> int:
"""
Returns the length of the longest item currently in the table.
"""
non_titles = [r for r in self.records if not r.record_type == "title"]

def length_of_items(L):
return [len(str(x)) for x in L]

vals = [length_of_items(r.values) for r in non_titles]
longest_value_length = max([max(x) for x in vals])
indx = [len(str(x.index)) for x in self.records]
longest_index_length = max(indx)
return max(longest_index_length, longest_value_length)

def to_string(self) -> str:
"""
Return string summary table in string form.
"""
type_of_next_row = "first"
length = self.get_longest_length()
s = ""

for i, x in enumerate(self.records):
if i == len(self.records) - 1:
type_of_next_row = "last"
elif self.records[i + 1].record_type == "title":
type_of_next_row = "last"

if x.record_type == "header":
s += str_header(x.values, length)
type_of_next_row = "inner"
elif x.record_type == "title":
s += str_title(x.values[0])
type_of_next_row = "first"
else:
s += str_row(x.index, x.values, length, type_of_next_row)
type_of_next_row = "inner"
s += "\n"
return s

def to_html(self) -> str:
"""
Returns HTML code for summary table.
"""
s = ""

for x in self.records:
s += x.to_html()

return "<table>" + s + "</table>"

def _repr_html_(self) -> str:
return self.to_html()

def __repr__(self) -> str:
return self.to_string()
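
A small illustration of how a table could be assembled with this API — not part of the diff; the labels and numbers are invented and only the methods defined above are used:

summary = Summary()
summary.add_title(["Meta-learner summary"])
summary.add_header(["", "Estimate", "94% HDI"], colspan=1)
summary.add_row("ATE", [0.52, "(0.31, 0.73)"], colspan=1)
summary.add_row("ATT", [0.47, "(0.22, 0.69)"], colspan=1)

print(summary)     # box-drawing text table via __repr__ -> to_string()
summary.to_html()  # HTML table, also what _repr_html_ renders in notebooks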
13 changes: 13 additions & 0 deletions causalpy/utils.py
@@ -1,5 +1,18 @@
import pandas as pd

from causalpy.pymc_models import ModelBuilder


def _fit(model, X, y, coords):
"""
Fit model to X and y, where model is either a scikit-learn estimator or a ModelBuilder
instance. In the latter case coords is passed through to fit; in the former it is ignored.
"""
if isinstance(model, ModelBuilder):
model.fit(X, y, coords)
else:
model.fit(X, y)


def _is_variable_dummy_coded(series: pd.Series) -> bool:
"""Check if a data in the provided Series is dummy coded. It should be 0 or 1
819 changes: 819 additions & 0 deletions docs/notebooks/meta_learners_synthetic_data.ipynb

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion pyproject.toml
@@ -35,10 +35,11 @@ dependencies = [
"pandas",
"patsy",
"pymc>=5.0.0",
"pymc-bart",
"scikit-learn>=1",
"scipy",
"seaborn>=0.11.2",
"xarray>=v2022.11.0",
"xarray>=v2022.11.0"
]

# List additional groups of dependencies here (e.g. development dependencies). Users