diff --git a/causalpy/experiments/prepostnegd.py b/causalpy/experiments/prepostnegd.py
index a187b7aa..32c1ceb1 100644
--- a/causalpy/experiments/prepostnegd.py
+++ b/causalpy/experiments/prepostnegd.py
@@ -82,7 +82,7 @@ class PrePostNEGD(BaseExperiment):
         Intercept      -0.5, 94% HDI [-1, 0.2]
         C(group)[T.1]  2, 94% HDI [2, 2]
         pre            1, 94% HDI [1, 1]
-        sigma          0.5, 94% HDI [0.5, 0.6]
+        y_hat_sigma    0.5, 94% HDI [0.5, 0.6]
     """
 
     supports_ols = False
diff --git a/causalpy/pymc_models.py b/causalpy/pymc_models.py
index 50bfb0cb..d3ecbb6e 100644
--- a/causalpy/pymc_models.py
+++ b/causalpy/pymc_models.py
@@ -22,6 +22,7 @@
 import pytensor.tensor as pt
 import xarray as xr
 from arviz import r2_score
+from pymc_extras.prior import Prior
 
 from causalpy.utils import round_num
 
@@ -89,7 +90,18 @@ class PyMCModel(pm.Model):
     Inference data...
     """
 
-    def __init__(self, sample_kwargs: Optional[Dict[str, Any]] = None):
+    @property
+    def default_priors(self):
+        return {}
+
+    def priors_from_data(self, X, y) -> Dict[str, Any]:
+        return {}
+
+    def __init__(
+        self,
+        sample_kwargs: Optional[Dict[str, Any]] = None,
+        priors: Optional[Dict[str, Any]] = None,
+    ):
         """
         :param sample_kwargs: A dictionary of kwargs that get unpacked and passed to the
             :func:`pymc.sample` function. Defaults to an empty dictionary.
@@ -98,6 +110,8 @@ def __init__(self, sample_kwargs: Optional[Dict[str, Any]] = None):
         self.idata = None
         self.sample_kwargs = sample_kwargs if sample_kwargs is not None else {}
 
+        self.priors = {**self.default_priors, **(priors or {})}  # user-supplied priors override class defaults
+
     def build_model(self, X, y, coords) -> None:
         """Build the model, must be implemented by subclass."""
         raise NotImplementedError("This method must be implemented by a subclass")
@@ -143,6 +157,8 @@ def fit(self, X, y, coords: Optional[Dict[str, Any]] = None) -> None:
         # sample_posterior_predictive() if provided in sample_kwargs.
         random_seed = self.sample_kwargs.get("random_seed", None)
 
+        self.priors = {**self.priors_from_data(X, y), **self.priors}  # default/user priors take precedence over data-derived ones
+
         self.build_model(X, y, coords)
         with self:
             self.idata = pm.sample(**self.sample_kwargs)
@@ -238,26 +254,34 @@ def print_coefficients_for_unit(
         ) -> None:
             """Print coefficients for a single unit"""
             # Determine the width of the longest label
-            max_label_length = max(len(name) for name in labels + ["sigma"])
+            max_label_length = max(len(name) for name in labels + ["y_hat_sigma"])
 
             for name in labels:
                 coeff_samples = unit_coeffs.sel(coeffs=name)
                 print_row(max_label_length, name, coeff_samples, round_to)
 
             # Add coefficient for measurement std
-            print_row(max_label_length, "sigma", unit_sigma, round_to)
+            print_row(max_label_length, "y_hat_sigma", unit_sigma, round_to)
 
         print("Model coefficients:")
         coeffs = az.extract(self.idata.posterior, var_names="beta")
 
-        # Always has treated_units dimension - no branching needed!
+        # Check if sigma or y_hat_sigma variable exists
+        sigma_var_name = None
+        if "sigma" in self.idata.posterior:
+            sigma_var_name = "sigma"
+        elif "y_hat_sigma" in self.idata.posterior:
+            sigma_var_name = "y_hat_sigma"
+        else:
+            raise ValueError("Neither 'sigma' nor 'y_hat_sigma' found in posterior")
+
         treated_units = coeffs.coords["treated_units"].values
         for unit in treated_units:
             if len(treated_units) > 1:
                 print(f"\nTreated unit: {unit}")
 
             unit_coeffs = coeffs.sel(treated_units=unit)
-            unit_sigma = az.extract(self.idata.posterior, var_names="sigma").sel(
+            unit_sigma = az.extract(self.idata.posterior, var_names=sigma_var_name).sel(
                 treated_units=unit
             )
             print_coefficients_for_unit(unit_coeffs, unit_sigma, labels, round_to or 2)
@@ -300,6 +324,15 @@ class LinearRegression(PyMCModel):
     Inference data...
     """  # noqa: W605
 
+    default_priors = {
+        "beta": Prior("Normal", mu=0, sigma=50, dims=["treated_units", "coeffs"]),
+        "y_hat": Prior(
+            "Normal",
+            sigma=Prior("HalfNormal", sigma=1, dims=["treated_units"]),
+            dims=["obs_ind", "treated_units"],
+        ),
+    }
+
     def build_model(self, X, y, coords):
         """
         Defines the PyMC model
@@ -313,12 +346,13 @@ def build_model(self, X, y, coords):
             self.add_coords(coords)
             X = pm.Data("X", X, dims=["obs_ind", "coeffs"])
             y = pm.Data("y", y, dims=["obs_ind", "treated_units"])
-            beta = pm.Normal("beta", 0, 50, dims=["treated_units", "coeffs"])
-            sigma = pm.HalfNormal("sigma", 1, dims="treated_units")
+            # "beta" prior is configurable via the `priors` argument (see default_priors)
+            beta = self.priors["beta"].create_variable("beta")
             mu = pm.Deterministic(
                 "mu", pt.dot(X, beta.T), dims=["obs_ind", "treated_units"]
             )
-            pm.Normal("y_hat", mu, sigma, observed=y, dims=["obs_ind", "treated_units"])
+            # likelihood; the nested sigma prior yields the "y_hat_sigma" variable
+            self.priors["y_hat"].create_likelihood_variable("y_hat", mu=mu, observed=y)
 
 
 class WeightedSumFitter(PyMCModel):
@@ -361,23 +395,35 @@ class WeightedSumFitter(PyMCModel):
     Inference data...
     """  # noqa: W605
 
+    default_priors = {
+        "y_hat": Prior(
+            "Normal",
+            sigma=Prior("HalfNormal", sigma=1, dims=["treated_units"]),
+            dims=["obs_ind", "treated_units"],
+        ),
+    }
+
+    def priors_from_data(self, X, y) -> Dict[str, Any]:
+        n_predictors = X.shape[1]  # number of columns in X (the "coeffs" dimension)
+        return {
+            "beta": Prior(
+                "Dirichlet", a=np.ones(n_predictors), dims=["treated_units", "coeffs"]
+            ),
+        }
+
     def build_model(self, X, y, coords):
         """
         Defines the PyMC model
         """
         with self:
             self.add_coords(coords)
-            n_predictors = X.sizes["coeffs"]
             X = pm.Data("X", X, dims=["obs_ind", "coeffs"])
             y = pm.Data("y", y, dims=["obs_ind", "treated_units"])
-            beta = pm.Dirichlet(
-                "beta", a=np.ones(n_predictors), dims=["treated_units", "coeffs"]
-            )
-            sigma = pm.HalfNormal("sigma", 1, dims="treated_units")
+            beta = self.priors["beta"].create_variable("beta")
             mu = pm.Deterministic(
                 "mu", pt.dot(X, beta.T), dims=["obs_ind", "treated_units"]
             )
-            pm.Normal("y_hat", mu, sigma, observed=y, dims=["obs_ind", "treated_units"])
+            self.priors["y_hat"].create_likelihood_variable("y_hat", mu=mu, observed=y)
 
 
 class InstrumentalVariableRegression(PyMCModel):
@@ -566,13 +612,17 @@ class PropensityScore(PyMCModel):
     Inference...
     """  # noqa: W605
 
+    default_priors = {
+        "b": Prior("Normal", mu=0, sigma=1, dims="coeffs"),
+    }
+
     def build_model(self, X, t, coords):
         "Defines the PyMC propensity model"
         with self:
             self.add_coords(coords)
             X_data = pm.Data("X", X, dims=["obs_ind", "coeffs"])
             t_data = pm.Data("t", t.flatten(), dims="obs_ind")
-            b = pm.Normal("b", mu=0, sigma=1, dims="coeffs")
+            b = self.priors["b"].create_variable("b")
             mu = pt.dot(X_data, b)
             p = pm.Deterministic("p", pm.math.invlogit(mu))
             pm.Bernoulli("t_pred", p=p, observed=t_data, dims="obs_ind")
diff --git a/causalpy/tests/test_pymc_models.py b/causalpy/tests/test_pymc_models.py
index 22f3a045..e5fc9582 100644
--- a/causalpy/tests/test_pymc_models.py
+++ b/causalpy/tests/test_pymc_models.py
@@ -45,7 +45,7 @@ def build_model(self, X, y, coords):
             X_ = pm.Data(name="X", value=X, dims=["obs_ind", "coeffs"])
             y_ = pm.Data(name="y", value=y, dims=["obs_ind", "treated_units"])
             beta = pm.Normal("beta", mu=0, sigma=1, dims=["treated_units", "coeffs"])
-            sigma = pm.HalfNormal("sigma", sigma=1, dims="treated_units")
+            sigma = pm.HalfNormal("y_hat_sigma", sigma=1, dims="treated_units")
             mu = pm.Deterministic(
                 "mu", pm.math.dot(X_, beta.T), dims=["obs_ind", "treated_units"]
             )
@@ -159,7 +159,7 @@ def test_fit_predict(self, coords, rng, mock_pymc_sample) -> None:
             2,
             2 * 2,
         )  # (treated_units, coeffs, sample)
-        assert az.extract(data=model.idata, var_names=["sigma"]).shape == (
+        assert az.extract(data=model.idata, var_names=["y_hat_sigma"]).shape == (
             1,
             2 * 2,
         )  # (treated_units, sample)
@@ -402,7 +402,7 @@ def test_multi_unit_coefficients(self, synthetic_control_data):
 
         # Extract coefficients
         beta = az.extract(wsf.idata.posterior, var_names="beta")
-        sigma = az.extract(wsf.idata.posterior, var_names="sigma")
+        sigma = az.extract(wsf.idata.posterior, var_names="y_hat_sigma")
 
         # Check beta dimensions: should be (sample, treated_units, coeffs)
         assert "treated_units" in beta.dims
@@ -461,7 +461,7 @@ def test_print_coefficients_multi_unit(self, synthetic_control_data, capsys):
             assert control in output
 
         # Check that sigma is printed for each unit
-        assert output.count("sigma") == len(treated_units)
+        assert output.count("y_hat_sigma") == len(treated_units)
 
     def test_scoring_multi_unit(self, synthetic_control_data):
         """Test that scoring works with multiple treated units."""
diff --git a/docs/source/_static/interrogate_badge.svg b/docs/source/_static/interrogate_badge.svg
index d2d886ad..5b70fde2 100644
--- a/docs/source/_static/interrogate_badge.svg
+++ b/docs/source/_static/interrogate_badge.svg
@@ -1,10 +1,10 @@
 <svg width="140" height="20" viewBox="0 0 140 20" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" xml:space="preserve" xmlns:serif="http://www.serif.com/" style="fill-rule:evenodd;clip-rule:evenodd;stroke-linejoin:round;stroke-miterlimit:2;">
-    <title>interrogate: 95.4%</title>
+    <title>interrogate: 94.2%</title>
     <g transform="matrix(1,0,0,1,22,0)">
         <g id="backgrounds" transform="matrix(1.32789,0,0,1,-22.3892,0)">
             <rect x="0" y="0" width="71" height="20" style="fill:rgb(85,85,85);"/>
         </g>
-        <rect x="71" y="0" width="47" height="20" data-interrogate="color" style="fill:#4c1"/>
+        <rect x="71" y="0" width="47" height="20" data-interrogate="color" style="fill:#97CA00"/>
         <g transform="matrix(1.19746,0,0,1,-22.3744,-4.85723e-16)">
             <rect x="0" y="0" width="118" height="20" style="fill:url(#_Linear1);"/>
         </g>
@@ -12,8 +12,8 @@
     <g fill="#fff" text-anchor="middle" font-family="DejaVu Sans,Verdana,Geneva,sans-serif" font-size="110">
         <text x="590" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="610">interrogate</text>
         <text x="590" y="140" transform="scale(.1)" textLength="610">interrogate</text>
-        <text x="1160" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="370" data-interrogate="result">95.4%</text>
-        <text x="1160" y="140" transform="scale(.1)" textLength="370" data-interrogate="result">95.4%</text>
+        <text x="1160" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="370" data-interrogate="result">94.2%</text>
+        <text x="1160" y="140" transform="scale(.1)" textLength="370" data-interrogate="result">94.2%</text>
     </g>
     <g id="logo-shadow" serif:id="logo shadow" transform="matrix(0.854876,0,0,0.854876,-6.73514,1.732)">
         <g transform="matrix(0.299012,0,0,0.299012,9.70229,-6.68582)">
diff --git a/environment.yml b/environment.yml
index 02b7f920..2bc8ed20 100644
--- a/environment.yml
+++ b/environment.yml
@@ -15,3 +15,4 @@ dependencies:
   - seaborn>=0.11.2
   - statsmodels
   - xarray>=v2022.11.0
+  - pymc-extras>=0.2.7
diff --git a/pyproject.toml b/pyproject.toml
index 909f7969..88df3f87 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -40,6 +40,7 @@ dependencies = [
     "seaborn>=0.11.2",
     "statsmodels",
     "xarray>=v2022.11.0",
+    "pymc-extras>=0.2.7",
 ]
 
 # List additional groups of dependencies here (e.g. development dependencies). Users