Refactored scikit-learn flavour of DifferenceInDifferences and allowed custom column names for post_treatment variable. #515
Codecov Report
✅ All modified and coverable lines are covered by tests.
@@            Coverage Diff             @@
##             main     #515      +/-   ##
==========================================
+ Coverage   95.19%   95.29%   +0.09%
==========================================
  Files          28       28
  Lines        2457     2507      +50
==========================================
+ Hits         2339     2389      +50
  Misses        118      118
- Looks like the remote checks are failing. Sometimes you need to run the pre-commit checks locally twice - the interrogate thing is a bit fiddly.
- And looks like we'll need to increase test coverage. So obvious ones would be to include tests where we use the default, or a user-provided post treatment variable name.
Overall, this is looking good. Thanks for the PR :)
Oh, remember to update from main regularly :)
causalpy/experiments/diff_in_diff.py (outdated)
)
# Check if post_treatment_variable_name is in formula
if self.post_treatment_variable_name not in self.formula:
    if self.post_treatment_variable_name == "post_treatment":
I've got a minor preference to just give one generic exception message, rather than a custom one dependent on self.post_treatment_variable_name. That will also cut down on the number of tests required to achieve high test coverage.
Yeah absolutely!! More generic ones like "Missing required variable '{self.post_treatment_variable_name}' in formula" can be used
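A minimal sketch of the single generic message discussed here. The two exception classes are local stand-ins, and `validate_post_treatment` with its `columns` argument is a hypothetical helper for illustration, not the code in this PR.

```python
class FormulaException(Exception):
    """Local stand-in for a formula-related exception (assumption)."""


class DataException(Exception):
    """Local stand-in for a data-related exception (assumption)."""


def validate_post_treatment(
    formula, columns, post_treatment_variable_name="post_treatment"
):
    """Raise one generic error, whatever the variable happens to be called."""
    name = post_treatment_variable_name
    # Same substring check as in the diff under review.
    if name not in formula:
        raise FormulaException(f"Missing required variable '{name}' in formula")
    if name not in columns:
        raise DataException(f"Missing required variable '{name}' in dataset")


# Passes silently: the custom name appears in both the formula and the columns.
validate_post_treatment("y ~ 1 + group*after", ["y", "group", "after"], "after")
```

With one message per check, two tests (one per exception) are enough to cover both branches, regardless of whether the default or a custom name is used.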
causalpy/experiments/diff_in_diff.py (outdated)
# Check if post_treatment_variable_name is in data columns
if self.post_treatment_variable_name not in self.data.columns:
    if self.post_treatment_variable_name == "post_treatment":
Same comment as above. Just give one more generic exception message, regardless of what self.post_treatment_variable_name is.
# Store the coefficient into dictionary {intercept:value}
coef_map = dict(zip(self.labels, self.model.get_coeffs()))
# Create and find the interaction term based on the values user provided
interaction_term = (
Nice. We'll need more tests anyway to ensure test coverage, so when you do that, can you add cases for when people specify formulas like post_treatment:a and post_treatment*b? It should work because we'll always get a coefficient for post_treatment:a, but it is worth adding the test.
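As a rough illustration of why both spellings should work: a sketch that looks up the interaction coefficient by label. The `labels` list is written by hand and assumed to match what the design matrix would report for `y ~ 1 + post_treatment*b`; the coefficient values are made up, and `find_interaction_labels` is a hypothetical helper.

```python
def find_interaction_labels(labels, post_treatment_variable_name="post_treatment"):
    """Return labels of interaction columns that include the given variable."""
    matches = []
    for label in labels:
        factors = label.split(":")
        # An interaction label has several factors joined by ":".
        if len(factors) > 1 and post_treatment_variable_name in factors:
            matches.append(label)
    return matches


# Hand-written labels (assumption) and made-up coefficient values.
labels = ["Intercept", "post_treatment", "b", "post_treatment:b"]
coeffs = [1.0, 0.5, -0.2, 2.0]
coef_map = dict(zip(labels, coeffs))

(term,) = find_interaction_labels(labels)  # unpacking fails unless exactly one match
effect = coef_map[term]
```

The single-element unpacking is a deliberate choice: it fails loudly when there are zero or several matching labels instead of silently picking one.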
Yeah, will add some tests for cases where a user provides a post treatment variable name, and check for FormulaException and DataException. But @drbenvincent, can you elaborate on this specific test? Are we also checking the coefficient value where two interaction terms are used?
I'd not thought of that. I guess it's easy to find an interaction term of the post treatment variable and something else. But if there are two interaction terms, both including the post treatment variable, then that might get messy. Can we think of any situations where that would be a good idea? If not, then maybe that could throw an exception and we just say we can't deal with a formula like that?
Since our users can write any formula freely (unlike other libraries that rely on closed systems), they could specify a formula like post_treatment * group + post_treatment * group * male, which might be uncommon but is entirely possible in our setup.
The users can obtain estimates for exactly what they define in the formula. However, we've built this DiD object specifically for two-way diff-in-diff with a single interaction term, so the other features might get messed up, as you said.
So yeah @drbenvincent, I agree that we could throw an exception if we encounter two or more interaction terms with post_treatment, to move forward.
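One way the agreed behaviour could look, as a sketch only. `check_single_interaction` and the local `FormulaException` are hypothetical stand-ins, and the parsing is deliberately naive: it splits the right-hand side on `+`, so formulas with parentheses or `-` terms are out of scope.

```python
import re


class FormulaException(Exception):
    """Local stand-in for a formula-related exception (assumption)."""


def check_single_interaction(formula, post_treatment_variable_name="post_treatment"):
    """Return interaction terms using the variable; raise if more than one."""
    rhs = formula.split("~", 1)[-1]
    hits = []
    for term in rhs.split("+"):
        factors = [f.strip() for f in re.split(r"[*:]", term)]
        # A term is an interaction if it has more than one factor.
        if len(factors) > 1 and post_treatment_variable_name in factors:
            hits.append(term.strip())
    if len(hits) > 1:
        raise FormulaException(
            f"Formula contains {len(hits)} interaction terms with "
            f"'{post_treatment_variable_name}'; only one is supported: {hits}"
        )
    return hits


check_single_interaction("y ~ 1 + group*post_treatment")  # one term, no error
```

The formula from the comment above, `post_treatment * group + post_treatment * group * male`, contains two such terms and would raise.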
@@ -128,6 +130,12 @@ def __init__(
    }
    self.model.fit(X=self.X, y=self.y, coords=COORDS)
elif isinstance(self.model, RegressorMixin):
    # For scikit-learn models, automatically set fit_intercept=False
Nice
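A small sketch of what forcing `fit_intercept=False` might look like. The motivation is presumably that a formula such as `y ~ 1 + ...` already puts an intercept column in the design matrix, so a second intercept fitted by the estimator would be redundant. `disable_fit_intercept` is a hypothetical helper, and `StandInRegressor` is a stand-in so the snippet runs without scikit-learn installed.

```python
def disable_fit_intercept(model):
    """Set fit_intercept=False on estimators that expose that attribute."""
    if hasattr(model, "fit_intercept"):
        model.fit_intercept = False
    return model


class StandInRegressor:
    """Minimal stand-in for a scikit-learn style estimator (not a real one)."""

    def __init__(self, fit_intercept=True):
        self.fit_intercept = fit_intercept


model = disable_fit_intercept(StandInRegressor())
```

Estimators without a `fit_intercept` attribute pass through unchanged, which keeps the helper safe to call on any regressor.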
…teraction terms,more generic messages
Hi @drbenvincent, here is a draft with the following changes:
Sorry about the late review - work has been rather busy!
Just a couple of requests/suggestions, and then I'm very happy to merge :)
@@ -84,6 +86,7 @@ def __init__(
    formula: str,
    time_variable_name: str,
    group_variable_name: str,
    post_treatment_variable_name: str = "post_treatment",
Can you add post_treatment_variable_name into the docstring to make it ultra clear what it does?
@@ -236,6 +262,61 @@ def input_validation(self):
        coded. Consisting of 0's and 1's only."""
    )

def _get_interaction_terms(self):
Suggestion (not a requirement). This could be made a static method and you could just pass in the formula string as an argument. Or it could be a simple utility function that could go in utils.py. It might help make testing the function marginally simpler.
Ideally we'd add some test cases for _get_interaction_terms. As in, come up with a set of example formulas and the expected outputs of the function.
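A sketch of the suggestion as a standalone utility with a table of example formulas and expected outputs. `get_interaction_terms` is a hypothetical reimplementation, not the method from this PR, and it assumes the right-hand side can be split on `+`, so parenthesised expressions are not handled.

```python
import re


def get_interaction_terms(formula):
    """Return right-hand-side terms containing '*' or ':' with whitespace removed."""
    rhs = formula.split("~", 1)[-1]
    terms = [re.sub(r"\s+", "", term) for term in rhs.split("+")]
    return [term for term in terms if "*" in term or ":" in term]


# Example formulas paired with the output expected from the helper above.
CASES = [
    ("y ~ 1 + group + post_treatment", []),
    ("y ~ 1 + group*post_treatment", ["group*post_treatment"]),
    ("y ~ 1 + group:post_treatment", ["group:post_treatment"]),
    ("y ~ 1 + post_treatment * b + x", ["post_treatment*b"]),
    (
        "y ~ 1 + post_treatment * group + post_treatment * group * male",
        ["post_treatment*group", "post_treatment*group*male"],
    ),
]

for formula, expected in CASES:
    assert get_interaction_terms(formula) == expected, formula
```

Because the function takes only the formula string, the same `CASES` table could be fed to `pytest.mark.parametrize` without constructing an experiment object.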
PS. I pushed a small change to get the pre-commit checks to work. So remember to pull the latest.
closes issues #390 and #514
📚 Documentation preview 📚: https://causalpy--515.org.readthedocs.build/en/515/