Skip to content

Commit 85aa2b1

Browse files
authored
Merge pull request #343 from DoubleML/o-rdd
Add possibility of user defined bandwidth in rdrobust
2 parents 9df0711 + c01fd2e commit 85aa2b1

File tree

2 files changed

+35
-3
lines changed

2 files changed

+35
-3
lines changed

doubleml/rdd/rdd.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,14 @@ def __init__(
143143

144144
self._check_effect_sign()
145145

146-
# TODO: Add further input checks
146+
if found_keys := {"h", "b"} & kwargs.keys():
147+
warnings.warn(
148+
(
149+
f"Key-worded arguments contain: {found_keys}.\n"
150+
"Iterative bandwidth selection will be overwritten by provided values."
151+
)
152+
)
153+
147154
self.kwargs = kwargs
148155

149156
self._smpls = DoubleMLResampling(
@@ -453,10 +460,16 @@ def _update_weights(self):
453460
def _fit_rdd(self, h=None, b=None):
454461
if self.fuzzy:
455462
rdd_res = rdrobust.rdrobust(
456-
y=self._M_Y[:, self._i_rep], x=self._score, fuzzy=self._M_D[:, self._i_rep], h=h, b=b, **self.kwargs
463+
y=self._M_Y[:, self._i_rep],
464+
x=self._score,
465+
fuzzy=self._M_D[:, self._i_rep],
466+
c=0,
467+
**({"h": h, "b": b} | self.kwargs),
457468
)
458469
else:
459-
rdd_res = rdrobust.rdrobust(y=self._M_Y[:, self._i_rep], x=self._score, h=h, b=b, **self.kwargs)
470+
rdd_res = rdrobust.rdrobust(
471+
y=self._M_Y[:, self._i_rep], x=self._score, fuzzy=None, c=0, **({"h": h, "b": b} | self.kwargs)
472+
)
460473
return rdd_res
461474

462475
def _set_coefs(self, rdd_res, h):

doubleml/rdd/tests/test_rdd_exceptions.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,3 +237,22 @@ def test_rdd_exception_fit():
237237
msg = "The number of iterations for the iterative bandwidth fitting has to be positive. 0 was passed."
238238
with pytest.raises(ValueError, match=msg):
239239
rdd_model.fit(n_iterations=0)
240+
241+
242+
@pytest.mark.ci_rdd
243+
def test_rdd_warning_kwargs():
244+
msg = r"Key-worded arguments contain: {'h'}.\n" "Iterative bandwidth selection will be overwritten by provided values."
245+
with pytest.warns(UserWarning, match=msg):
246+
_ = RDFlex(dml_data, ml_g, h=0.1)
247+
248+
msg = r"Key-worded arguments contain: {'b'}.\n" "Iterative bandwidth selection will be overwritten by provided values."
249+
with pytest.warns(UserWarning, match=msg):
250+
_ = RDFlex(dml_data, ml_g, b=0.1)
251+
252+
# The order in the set is not guaranteed
253+
msg = (
254+
r"Key-worded arguments contain: {'[hb]', '[hb]'}.\n"
255+
"Iterative bandwidth selection will be overwritten by provided values."
256+
)
257+
with pytest.warns(UserWarning, match=msg):
258+
_ = RDFlex(dml_data, ml_g, h=0.1, b=0.1)

0 commit comments

Comments
 (0)