# Copyright 2024 The PyMC Labs Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Quasi-Experiment classes for OLS inference"""
import warnings
from typing import Optional
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from patsy import build_design_matrices, dmatrices
from causalpy.data_validation import (
DiDDataValidator,
PrePostFitDataValidator,
RDDataValidator,
)
from causalpy.utils import round_num
LEGEND_FONT_SIZE = 12
[docs]
class ExperimentalDesign:
    """Common base for the experiment designs defined in this module.

    Stores the model used for fitting and provides a helper that prints the
    fitted coefficients. ``print_coefficients`` reads ``self.labels``, which
    this base class does not set; subclasses assign it while building their
    design matrices.
    """

    model = None
    expt_type = None
    outcome_variable_name = None

    def __init__(self, model=None, **kwargs):
        """Store ``model``, falling back to a class-level ``model`` if present.

        :param model:
            Model object to use for fitting. When omitted, the class attribute
            ``model`` is used instead.
        :raises ValueError:
            If no model was passed and no class-level model is defined.
        """
        # An explicitly passed model shadows the class-level default.
        if model is not None:
            self.model = model
        if self.model is None:
            raise ValueError("fitting_model not set or passed.")

    def print_coefficients(self, round_to=None) -> None:
        """
        Print the model coefficients, one aligned ``label<TAB>value`` row each.

        :param round_to:
            Rounding precision forwarded unchanged to ``round_num``.
        """
        print("Model coefficients:")
        # Pad every label to the width of the longest one so the rows line up.
        label_width = max(len(label) for label in self.labels)
        # Only the first row of ``coef_`` is reported.
        coefficients = self.model.coef_[0]
        for label, coefficient in zip(self.labels, coefficients):
            padded_label = f"{label:<{label_width}}"
            rounded = round_num(coefficient, round_to)
            # Values are right-aligned in a 10 character wide column.
            padded_value = f"{rounded:>10}"
            print(f" {padded_label}\t{padded_value}")
[docs]
class PrePostFit(ExperimentalDesign, PrePostFitDataValidator):
    """
    A class to analyse quasi-experiments where parameter estimation is based on just
    the pre-intervention data.

    :param data:
        A pandas data frame
    :param treatment_time:
        The index or time value of when treatment begins
    :param formula:
        A statistical model formula
    :param model:
        A scikit-learn model object

    Example
    --------
    >>> from sklearn.linear_model import LinearRegression
    >>> import causalpy as cp
    >>> df = cp.load_data("sc")
    >>> treatment_time = 70
    >>> result = cp.skl_experiments.PrePostFit(
    ...     df,
    ...     treatment_time,
    ...     formula="actual ~ 0 + a + b + c + d + e + f + g",
    ...     model = cp.skl_models.WeightedProportion()
    ... )
    >>> result.get_coeffs()
    array(...)
    """

    # Declared on the class rather than assigned in ``__init__`` so that a
    # subclass defining its own class-level ``expt_type`` is not shadowed by an
    # instance attribute.
    expt_type = "Pre-Post Fit"

    def __init__(
        self,
        data,
        treatment_time,
        formula,
        model=None,
        **kwargs,
    ):
        """Fit ``model`` on pre-intervention rows and predict the counterfactual.

        The pre/post design matrices, predictions, impacts and cumulative
        post-intervention impact are stored as attributes on the instance.
        """
        super().__init__(model=model, **kwargs)
        self._input_validation(data, treatment_time)
        self.treatment_time = treatment_time

        # Split the data: rows whose index equals ``treatment_time`` count as
        # post-intervention.
        self.datapre = data[data.index < self.treatment_time]
        self.datapost = data[data.index >= self.treatment_time]

        self.formula = formula

        # Build the design matrices from the pre-intervention data only.
        y, X = dmatrices(formula, self.datapre)
        self.outcome_variable_name = y.design_info.column_names[0]
        self._y_design_info = y.design_info
        self._x_design_info = X.design_info
        self.labels = X.design_info.column_names
        self.pre_y, self.pre_X = np.asarray(y), np.asarray(X)

        # Re-use the pre-intervention design info so the post-intervention
        # matrices are encoded identically.
        (new_y, new_x) = build_design_matrices(
            [self._y_design_info, self._x_design_info], self.datapost
        )
        self.post_X = np.asarray(new_x)
        self.post_y = np.asarray(new_y)

        # Fit the model to the observed (pre-intervention) data.
        self.model.fit(X=self.pre_X, y=self.pre_y)

        # Score the goodness of fit to the pre-intervention data.
        self.score = self.model.score(X=self.pre_X, y=self.pre_y)

        # Model predictions of the observed (pre-intervention) data.
        self.pre_pred = self.model.predict(X=self.pre_X)

        # The counterfactual: model predictions for the post-intervention rows.
        self.post_pred = self.model.predict(X=self.post_X)

        # Causal impact pre (i.e. the residuals of the model fit to observed).
        self.pre_impact = self.pre_y - self.pre_pred

        # Causal impact post (i.e. the impact of the intervention).
        self.post_impact = self.post_y - self.post_pred

        # Cumulative impact post.
        self.post_impact_cumulative = np.cumsum(self.post_impact)

    def plot(self, counterfactual_label="Counterfactual", round_to=None, **kwargs):
        """Plot experiment results on three stacked axes sharing the x axis.

        The axes show, top to bottom: observations with model fit and
        counterfactual, the causal impact, and the cumulative causal impact.

        :param counterfactual_label:
            Legend label used for the post-intervention prediction.
        :param round_to:
            Rounding precision forwarded unchanged to ``round_num`` for the
            score shown in the title.
        :return:
            Tuple ``(fig, ax)`` with the figure and its array of three axes.
        """
        fig, ax = plt.subplots(3, 1, sharex=True, figsize=(7, 8))

        # Top panel: observations, model fit and counterfactual.
        ax[0].plot(self.datapre.index, self.pre_y, "k.")
        ax[0].plot(self.datapost.index, self.post_y, "k.")
        ax[0].plot(self.datapre.index, self.pre_pred, c="k", label="model fit")
        ax[0].plot(
            self.datapost.index,
            self.post_pred,
            label=counterfactual_label,
            ls=":",
            c="k",
        )
        ax[0].set(
            title=f"$R^2$ on pre-intervention data = {round_num(self.score, round_to)}"
        )

        # Middle panel: pointwise causal impact.
        ax[1].plot(self.datapre.index, self.pre_impact, "k.")
        ax[1].plot(
            self.datapost.index,
            self.post_impact,
            "k.",
            label=counterfactual_label,
        )
        ax[1].axhline(y=0, c="k")
        ax[1].set(title="Causal Impact")

        # Bottom panel: cumulative causal impact.
        ax[2].plot(self.datapost.index, self.post_impact_cumulative, c="k")
        ax[2].axhline(y=0, c="k")
        ax[2].set(title="Cumulative Causal Impact")

        # Shaded causal effect.
        ax[0].fill_between(
            self.datapost.index,
            y1=np.squeeze(self.post_pred),
            y2=np.squeeze(self.post_y),
            color="C0",
            alpha=0.25,
            label="Causal impact",
        )
        ax[1].fill_between(
            self.datapost.index,
            y1=np.squeeze(self.post_impact),
            color="C0",
            alpha=0.25,
            label="Causal impact",
        )

        # Intervention line
        # TODO: make this work when self.treatment_time is a datetime
        for i in [0, 1, 2]:
            ax[i].axvline(
                x=self.treatment_time,
                ls="-",
                lw=3,
                color="r",
                label="Treatment time",
            )

        ax[0].legend(fontsize=LEGEND_FONT_SIZE)

        return (fig, ax)

    def get_coeffs(self):
        """
        Returns model coefficients with singleton dimensions removed
        """
        return np.squeeze(self.model.coef_)

    def plot_coeffs(self):
        """Plots coefficient bar plot"""
        df = pd.DataFrame(
            {"predictor variable": self.labels, "ols_coef": self.get_coeffs()}
        )
        sns.barplot(
            data=df,
            x="ols_coef",
            y="predictor variable",
            palette=sns.color_palette("husl"),
        )

    def summary(self, round_to=None) -> None:
        """
        Print text output summarising the results

        :param round_to:
            Rounding precision forwarded unchanged to ``print_coefficients``.
        """
        print(f"{self.expt_type:=^80}")
        print(f"Formula: {self.formula}")
        self.print_coefficients(round_to)
[docs]
class InterruptedTimeSeries(PrePostFit):
    """
    Interrupted time series design.

    Adds no behaviour of its own: construction, plotting and summarising are
    all inherited from :class:`PrePostFit`. The class only declares its own
    class-level ``expt_type`` label.

    :param data:
        A pandas data frame
    :param treatment_time:
        The index or time value of when treatment begins
    :param formula:
        A statistical model formula
    :param model:
        An sklearn model object

    Example
    --------
    >>> from sklearn.linear_model import LinearRegression
    >>> import pandas as pd
    >>> import causalpy as cp
    >>> df = (
    ...     cp.load_data("its")
    ...     .assign(date=lambda x: pd.to_datetime(x["date"]))
    ...     .set_index("date")
    ... )
    >>> treatment_time = pd.to_datetime("2017-01-01")
    >>> result = cp.skl_experiments.InterruptedTimeSeries(
    ...     df,
    ...     treatment_time,
    ...     formula="y ~ 1 + t + C(month)",
    ...     model = LinearRegression()
    ... )
    """

    # Class-level label for this design.
    expt_type = "Interrupted Time Series"
[docs]
class SyntheticControl(PrePostFit):
    """
    Synthetic control design, built on top of the PrePostFit class.

    Fitting is inherited unchanged; only ``plot`` is extended so that the
    predictor columns can be drawn behind the fit.

    :param data:
        A pandas data frame
    :param treatment_time:
        The index or time value of when treatment begins
    :param formula:
        A statistical model formula
    :param model:
        An sklearn model object

    Example
    --------
    >>> from sklearn.linear_model import LinearRegression
    >>> import causalpy as cp
    >>> df = cp.load_data("sc")
    >>> treatment_time = 70
    >>> result = cp.skl_experiments.SyntheticControl(
    ...     df,
    ...     treatment_time,
    ...     formula="actual ~ 0 + a + b + c + d + e + f + g",
    ...     model = cp.skl_models.WeightedProportion()
    ... )
    """

    def plot(self, plot_predictors=False, round_to=None, **kwargs):
        """Plot the results, labelling the counterfactual "Synthetic control".

        :param plot_predictors:
            When true, the columns of the pre- and post-intervention design
            matrices are also drawn as grey lines on the top axis.
        :param round_to:
            Rounding precision forwarded unchanged to the parent ``plot``.
        :return:
            Tuple ``(fig, ax)`` as returned by the parent ``plot``.
        """
        fig, ax = super().plot(
            counterfactual_label="Synthetic control", round_to=round_to, **kwargs
        )
        if not plot_predictors:
            return (fig, ax)

        # Draw the control units in light grey behind everything else.
        light_grey = [0.8, 0.8, 0.8]
        segments = (
            (self.datapre.index, self.pre_X),
            (self.datapost.index, self.post_X),
        )
        for time_index, predictors in segments:
            ax[0].plot(time_index, predictors, "-", c=light_grey, zorder=1)
        return (fig, ax)
[docs]
class DifferenceInDifferences(ExperimentalDesign, DiDDataValidator):
    """
    A class to analyse difference in differences experiments.

    .. note::

        There is no pre/post intervention data distinction for DiD, we fit all the data
        available.

    :param data:
        A pandas data frame
    :param formula:
        A statistical model formula
    :param time_variable_name:
        Name of the data column for the time variable
    :param group_variable_name:
        Name of the data column for the group variable
    :param treated:
        Level of the group variable that was treated
    :param untreated:
        Level of the group variable that was untreated
    :param model:
        A scikit-learn model for difference in differences

    Example
    --------
    >>> import causalpy as cp
    >>> from sklearn.linear_model import LinearRegression
    >>> df = cp.load_data("did")
    >>> result = cp.skl_experiments.DifferenceInDifferences(
    ...     df,
    ...     formula="y ~ 1 + group*post_treatment",
    ...     time_variable_name="t",
    ...     group_variable_name="group",
    ...     treated=1,
    ...     untreated=0,
    ...     model=LinearRegression(),
    ... )
    """

    def __init__(
        self,
        data: pd.DataFrame,
        formula: str,
        time_variable_name: str,
        group_variable_name: str,
        treated: str,
        untreated: str,
        model=None,
        **kwargs,
    ):
        """Fit ``model`` to all of ``data`` and derive the group predictions.

        Stores predictions for the control group, the treatment group and the
        counterfactual (treated group, post-treatment, interaction removed),
        plus ``causal_impact``.
        """
        super().__init__(model=model, **kwargs)
        self.data = data
        self.expt_type = "Difference in Differences"
        self.formula = formula
        self.time_variable_name = time_variable_name
        self.group_variable_name = group_variable_name
        self._input_validation()
        self.treated = treated  # level of the group_variable_name that was treated
        self.untreated = (
            untreated  # level of the group_variable_name that was untreated
        )

        y, X = dmatrices(formula, self.data)
        self._y_design_info = y.design_info
        self._x_design_info = X.design_info
        self.labels = X.design_info.column_names
        self.y, self.X = np.asarray(y), np.asarray(X)
        self.outcome_variable_name = y.design_info.column_names[0]

        # fit the model to all the data
        self.model.fit(X=self.X, y=self.y)

        # predicted outcome for control group
        self.x_pred_control = self._first_row_per_time_point(self.untreated)
        assert not self.x_pred_control.empty, "no rows for the untreated group"
        (new_x,) = build_design_matrices([self._x_design_info], self.x_pred_control)
        self.y_pred_control = self.model.predict(np.asarray(new_x))

        # predicted outcome for treatment group
        self.x_pred_treatment = self._first_row_per_time_point(self.treated)
        assert not self.x_pred_treatment.empty, "no rows for the treated group"
        (new_x,) = build_design_matrices([self._x_design_info], self.x_pred_treatment)
        self.y_pred_treatment = self.model.predict(np.asarray(new_x))

        # predicted outcome for counterfactual. This is given by removing the influence
        # of the interaction term between the group and the post_treatment variable
        self.x_pred_counterfactual = self._first_row_per_time_point(
            self.treated, post_treatment_only=True
        )
        assert (
            not self.x_pred_counterfactual.empty
        ), "no post-treatment rows for the treated group"
        (new_x,) = build_design_matrices(
            [self._x_design_info], self.x_pred_counterfactual, return_type="dataframe"
        )
        # INTERVENTION: set the interaction term between the group and the
        # post_treatment variable to zero. This is the counterfactual.
        for i, label in enumerate(self.labels):
            if "post_treatment" in label and self.group_variable_name in label:
                new_x.iloc[:, i] = 0
        self.y_pred_counterfactual = self.model.predict(np.asarray(new_x))

        # calculate causal impact
        # This is the coefficient on the interaction term
        # TODO: THIS IS NOT YET CORRECT
        self.causal_impact = self.y_pred_treatment[1] - self.y_pred_counterfactual[0]

    def _first_row_per_time_point(self, group_level, post_treatment_only=False):
        """Return predictor rows for one group level, one row per time point.

        :param group_level:
            Value of the group variable selecting the rows to keep.
        :param post_treatment_only:
            When true, additionally keep only rows where ``post_treatment`` is
            true.
        :return:
            Data frame without the outcome column, holding the first row found
            for each value of the time variable.
        """
        rows = self.data.query(f"{self.group_variable_name} == @group_level")
        if post_treatment_only:
            rows = rows.query("post_treatment == True")
        return (
            rows
            # drop the outcome variable
            .drop(self.outcome_variable_name, axis=1)
            # We may have multiple units per time point, we only want one time point
            .groupby(self.time_variable_name)
            .first()
            .reset_index()
        )

    def plot(self, round_to=None):
        """Plot results

        :param round_to:
            Rounding precision forwarded unchanged to ``round_num`` for the
            causal impact shown in the title.
        :return:
            Tuple ``(fig, ax)`` with the figure and its single axis.
        """
        fig, ax = plt.subplots()

        # Plot raw data, coloured by the configured group variable.
        sns.lineplot(
            self.data,
            x=self.time_variable_name,
            y=self.outcome_variable_name,
            hue=self.group_variable_name,
            units="unit",
            estimator=None,
            alpha=0.25,
            ax=ax,
        )
        # Plot model fit to control group
        ax.plot(
            self.x_pred_control[self.time_variable_name],
            self.y_pred_control,
            "o",
            c="C0",
            markersize=10,
            label="model fit (control group)",
        )
        # Plot model fit to treatment group
        ax.plot(
            self.x_pred_treatment[self.time_variable_name],
            self.y_pred_treatment,
            "o",
            c="C1",
            markersize=10,
            label="model fit (treatment group)",
        )
        # Plot counterfactual - post-test for treatment group IF no treatment
        # had occurred.
        ax.plot(
            self.x_pred_counterfactual[self.time_variable_name],
            self.y_pred_counterfactual,
            "go",
            markersize=10,
            label="counterfactual",
        )

        # Both annotations share the same two endpoints: the same elements that
        # ``causal_impact`` is computed from, squeezed so that matplotlib
        # receives plain coordinates rather than nested arrays.
        counterfactual_y = np.squeeze(self.y_pred_counterfactual[0])
        treated_post_y = np.squeeze(self.y_pred_treatment[1])

        # arrow to label the causal impact
        ax.annotate(
            "",
            xy=(1.05, counterfactual_y),
            xycoords="data",
            xytext=(1.05, treated_post_y),
            textcoords="data",
            arrowprops={"arrowstyle": "<->", "color": "green", "lw": 3},
        )
        ax.annotate(
            "causal\nimpact",
            xy=(
                1.05,
                np.mean([counterfactual_y, treated_post_y]),
            ),
            xycoords="data",
            xytext=(5, 0),
            textcoords="offset points",
            color="green",
            va="center",
        )
        # formatting
        ax.set(
            xlim=[-0.05, 1.1],
            xticks=[0, 1],
            xticklabels=["pre", "post"],
            title=f"Causal impact = {round_num(self.causal_impact[0], round_to)}",
        )
        ax.legend(fontsize=LEGEND_FONT_SIZE)
        return (fig, ax)

    def summary(self, round_to=None) -> None:
        """
        Print text output summarising the results.

        :param round_to:
            Rounding precision forwarded unchanged to ``round_num`` and
            ``print_coefficients``.
        """
        print(f"{self.expt_type:=^80}")
        print(f"Formula: {self.formula}")
        print("\nResults:")
        print(f"Causal impact = {round_num(self.causal_impact[0], round_to)}")
        self.print_coefficients(round_to)
[docs]
class RegressionDiscontinuity(ExperimentalDesign, RDDataValidator):
    """
    A class to analyse sharp regression discontinuity experiments.

    :param data:
        A pandas dataframe
    :param formula:
        A statistical model formula
    :param treatment_threshold:
        A scalar threshold value at which the treatment is applied
    :param model:
        A scikit-learn model object
    :param running_variable_name:
        The name of the predictor variable that the treatment threshold is based upon
    :param epsilon:
        A small scalar value which determines how far above and below the treatment
        threshold to evaluate the causal impact.
    :param bandwidth:
        Data outside of the bandwidth (relative to the discontinuity) is not used to fit
        the model.

    Example
    --------
    >>> import causalpy as cp
    >>> from sklearn.linear_model import LinearRegression
    >>> data = cp.load_data("rd")
    >>> result = cp.skl_experiments.RegressionDiscontinuity(
    ...     data,
    ...     formula="y ~ 1 + x + treated",
    ...     model=LinearRegression(),
    ...     treatment_threshold=0.5,
    ... )
    """

    # Label printed as the header of ``summary``.
    expt_type = "Regression Discontinuity"

    def __init__(
        self,
        data,
        formula,
        treatment_threshold,
        model=None,
        running_variable_name="x",
        epsilon: float = 0.001,
        bandwidth: Optional[float] = None,
        **kwargs,
    ):
        """Fit ``model`` and evaluate the jump in prediction at the threshold.

        Stores the fit score, predictions over a 200-point grid of the running
        variable, and ``discontinuity_at_threshold``.
        """
        super().__init__(model=model, **kwargs)
        self.data = data
        self.formula = formula
        self.running_variable_name = running_variable_name
        self.treatment_threshold = treatment_threshold
        self.bandwidth = bandwidth
        self.epsilon = epsilon
        self._input_validation()

        if self.bandwidth is not None:
            fmin = self.treatment_threshold - self.bandwidth
            fmax = self.treatment_threshold + self.bandwidth
            # Filter on the configured running variable (inclusive bounds),
            # not on a column literally named "x".
            running = self.data[self.running_variable_name]
            filtered_data = self.data[(running >= fmin) & (running <= fmax)]
            if len(filtered_data) <= 10:
                warnings.warn(
                    f"Choice of bandwidth parameter has lead to only {len(filtered_data)} remaining datapoints. Consider increasing the bandwidth parameter.",  # noqa: E501
                    UserWarning,
                )
            y, X = dmatrices(formula, filtered_data)
        else:
            y, X = dmatrices(formula, self.data)

        self._y_design_info = y.design_info
        self._x_design_info = X.design_info
        self.labels = X.design_info.column_names
        self.y, self.X = np.asarray(y), np.asarray(X)
        self.outcome_variable_name = y.design_info.column_names[0]

        # TODO: `treated` is a deterministic function of x and treatment_threshold, so
        # this could be a function rather than supplied data

        # fit the model to the data selected above
        self.model.fit(X=self.X, y=self.y)

        # score the goodness of fit on that same data
        self.score = self.model.score(X=self.X, y=self.y)

        # Grid of running-variable values at which to evaluate the fitted model:
        # the bandwidth window if one is set, otherwise the observed range.
        if self.bandwidth is not None:
            xi = np.linspace(fmin, fmax, 200)
        else:
            xi = np.linspace(
                np.min(self.data[self.running_variable_name]),
                np.max(self.data[self.running_variable_name]),
                200,
            )
        self.x_pred = pd.DataFrame(
            {self.running_variable_name: xi, "treated": self._is_treated(xi)}
        )
        (new_x,) = build_design_matrices([self._x_design_info], self.x_pred)
        self.pred = self.model.predict(X=np.asarray(new_x))

        # calculate discontinuity by evaluating the difference in model expectation on
        # either side of the discontinuity
        # NOTE: `"treated": np.array([0, 1])`` assumes treatment is applied above
        # (not below) the threshold
        self.x_discon = pd.DataFrame(
            {
                self.running_variable_name: np.array(
                    [
                        self.treatment_threshold - self.epsilon,
                        self.treatment_threshold + self.epsilon,
                    ]
                ),
                "treated": np.array([0, 1]),
            }
        )
        (new_x,) = build_design_matrices([self._x_design_info], self.x_discon)
        self.pred_discon = self.model.predict(X=np.asarray(new_x))
        self.discontinuity_at_threshold = np.squeeze(self.pred_discon[1]) - np.squeeze(
            self.pred_discon[0]
        )

    def _is_treated(self, x):
        """Returns ``True`` if ``x`` is greater than or equal to the treatment
        threshold.

        .. warning::

            Assumes treatment is given to those ABOVE the treatment threshold.
        """
        return np.greater_equal(x, self.treatment_threshold)

    def plot(self, round_to=None):
        """Plot results

        :param round_to:
            Rounding precision forwarded unchanged to ``round_num`` for the
            values shown in the title.
        :return:
            Tuple ``(fig, ax)`` with the figure and its single axis.
        """
        fig, ax = plt.subplots()
        # Plot raw data
        sns.scatterplot(
            self.data,
            x=self.running_variable_name,
            y=self.outcome_variable_name,
            c="k",  # hue="treated",
            ax=ax,
        )
        # Plot model fit to data
        ax.plot(
            self.x_pred[self.running_variable_name],
            self.pred,
            "k",
            markersize=10,
            label="model fit",
        )
        # create strings to compose title
        r2 = f"$R^2$ on all data = {round_num(self.score, round_to)}"
        discon = f"Discontinuity at threshold = {round_num(self.discontinuity_at_threshold, round_to)}"
        ax.set(title=r2 + "\n" + discon)
        # Intervention line
        ax.axvline(
            x=self.treatment_threshold,
            ls="-",
            lw=3,
            color="r",
            label="treatment threshold",
        )
        ax.legend(fontsize=LEGEND_FONT_SIZE)
        return (fig, ax)

    def summary(self, round_to=None) -> None:
        """
        Print text output summarising the results

        :param round_to:
            Rounding precision forwarded unchanged to ``round_num`` and
            ``print_coefficients``.
        """
        # Same header style as the other designs' summaries.
        print(f"{self.expt_type:=^80}")
        print(f"Formula: {self.formula}")
        print(f"Running variable: {self.running_variable_name}")
        print(f"Threshold on running variable: {self.treatment_threshold}")
        print("\nResults:")
        # Honour ``round_to`` here as ``plot`` does for the same quantity.
        print(
            f"Discontinuity at threshold = {round_num(self.discontinuity_at_threshold, round_to)}"
        )
        print("\n")
        self.print_coefficients(round_to)