Heterogeneous Treatment Effect Estimation: Function Approximation
Introduction
I’ve come to think of Causal Inference as the study of complex, only partially observed systems. If we only observe a part of a system, how does that impact our ability to learn about it? It might seem like if we can’t completely observe a system, we can’t learn much of anything about it, because how do we know the parts we don’t observe don’t impact the conclusions we draw? But it turns out, under certain circumstances, we really can learn at least certain aspects of the system, and that’s what Causal Inference is all about.
In this sense, Causal Inference is different from other types of Machine Learning. In Deep Learning, we observe the entire system; it’s just that the system is complex, so we need a lot of data to learn about it. Causal Inference is kind of like Deep Learning, but without all the data Deep Learning would need to work well.
In a not-so-recent post, I wrote about interpreting supervised learning as an approach to function approximation. In this post I want to discuss what a “causal effect” really is, and how we go about estimating it. Suppose we want to learn about some unknown function $f(t, x, z),$ where $t$ and $x$ are observed but $z$ is unobserved. Specifically, we want to figure out how $f$ depends on $t,$ which I’ll call the treatment. Perhaps $f$ is the GDP of a country. Clearly this is too complex to ever understand completely, but we’d like to understand how it is influenced by a specific input, say the corporate tax rate.
We might wish to estimate $\partial f / \partial t$, for example. Of course, $f$ might not be differentiable with respect to $t$, or perhaps $t$ only takes on discrete values, such as 0 and 1. In that case we would want to compare $f(0, x, z)$ and $f(1, x, z)$. Slightly more generally, if there are two values of $t$ that are of interest, say $t_1$ and $t_0$, we may wish to estimate $\xi(x, z) = f(t_1, x, z) - f(t_0, x, z).$ This last formulation can stand in for either of the previous quantities, since the partial derivative is simply a suitable limit of differences. So in this post I will focus on estimating $\xi(x, z),$ which I will call the treatment effect, or the causal effect.
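To make this concrete, here’s a toy sketch in Python. The function and the numbers are entirely made up for illustration: a hypothetical $f$ with an unobserved input $z,$ and the treatment effect computed as the difference of two evaluations of $f.$

```python
# A made-up system: t is a binary treatment, x is observed, z is unobserved.
def f(t, x, z):
    return 2.0 * x + (1.0 + 0.5 * z) * t + 0.3 * z

# The treatment effect: compare f at t1 = 1 and t0 = 0, holding x and z fixed.
def treatment_effect(x, z):
    return f(1, x, z) - f(0, x, z)  # here this works out to 1 + 0.5 * z

print(treatment_effect(x=0.2, z=0.0))  # 1.0
print(treatment_effect(x=0.2, z=1.0))  # 1.5: the effect depends on the unobserved z
```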
Is $\xi$ really causal? We have it beaten into our heads that correlation does not imply causation, and that we can only draw causal conclusions under certain circumstances, such as random assignment. But physics draws causal conclusions all the time: if I drop a ball and it falls to the ground, it’s because of gravity. I can’t exactly A/B test gravity, so can we really know that gravity caused the ball to fall? Of course we can, but why is physics different? It’s because in physics, we actually know what $f$ is. We actually know the equation governing gravity, whereas we don’t know the equation governing economic productivity. Physical systems are simple and self-contained; the economy is complex and interacts with all of human existence. That’s the difference. The simple formula we have provided for the causal effect is the correct one; it’s just that when we don’t know $f$, and when the system is complex and only partially observed, we struggle to learn anything about it.
Notably, there is nothing random about $f$ or $\xi$, in keeping with my belief that randomness does not exist except in quantum mechanics. There’s nothing random about the economy; it’s just really complicated. But as I discussed in that previous post, we introduce randomness (technically pseudo-randomness, but that’s close enough for statistics) in order to approximate a partially observed system. That is, random sampling is an artificial device we humans can use to learn about a system. There are several non-random problems in mathematics where introducing randomness can help solve the problem, such as evaluating complicated integrals with Monte Carlo methods, and Machine Learning is in the same vein. But probability is so central to solving these problems that we often forget it is not necessarily part of the problem itself, just the way that we solve it.
Approximation to the Treatment Effect
Now, there are a few situations where it’s actually really easy to learn about the relationship between $f$ and $t.$ When we observe all the relevant factors and have a suitably large (and representative) dataset, we can just fit a deep neural network and base any conclusions on that. (I think Deep Learning folks sometimes consider Causal Inference to be trivial, because they don’t appreciate the impact of not observing all the relevant information, or they assume we can just impute the missing data or something.) Or when we have the ability to randomly assign the treatment, as in a typical A/B test, and $\xi(x, z)$ is a constant, we can estimate it as the difference in response between the two groups.
But when we observe all the relevant factors, or can at least impute them, that’s not really a partially observed system. And if the treatment effect is constant, that’s not an especially complicated system. What we’re interested in here is the case of heterogeneous treatment effects with unobserved factors. And in general, if we don’t observe $z$, then we can’t estimate $\xi(x, z).$ In an A/B test, we end up estimating the average treatment effect, but we can do better. We can try to estimate a particular approximation to $\xi(x, z)$ that is a function of $x$ alone. Define $$ \hat{\xi}^\ast(x) = \underset{\zeta \in \mathcal{H}}{\operatorname{arg\,min}} \int_{\mathcal{X}} \int_{\mathcal{Z}} D(\zeta(x), \xi(x, z)) \cdot w(x, z) \,dz \,dx, \quad (1) $$ where $\mathcal{X} \times \mathcal{Z}$ is the domain of $\xi,$ $D$ is some distance or loss function, $w$ is some non-negative weighting function satisfying $\int_{\mathcal{X}} \int_{\mathcal{Z}} w(x, z) \, dz \, dx = 1$ and $\mathcal{H}$ is some family of functions with support on $\mathcal{X}.$
Now, $\hat{\xi}^\ast(x)$ is by definition the best approximation to $\xi(x, z)$ in the family $\mathcal{H},$ where best is interpreted relative to the loss function $D$ and the weighting function $w.$ We will show that under certain circumstances we can calculate $\hat{\xi}^\ast(x)$ in terms of approximations to $f.$ Of course, if it turns out that the treatment effect does not actually depend on $z$, then $\hat{\xi}^\ast(x)$ is the treatment effect, not just an approximation.
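As a small numerical sketch of Equation (1), here is a discretized version with a made-up $\xi$ and a made-up weight function. With squared loss and an unrestricted family $\mathcal{H},$ the minimization separates over $x,$ and the minimizer at each $x$ turns out to be the $w$-weighted average of $\xi(x, z)$ over $z$ (we derive this in the next section); the grid search at the end is just a sanity check of that.

```python
import numpy as np

# Grids standing in for the domains X and Z (a toy discretization of Equation (1)).
xs = np.linspace(-1, 1, 51)
zs = np.linspace(-1, 1, 101)
X, Z = np.meshgrid(xs, zs, indexing="ij")

# A made-up treatment effect xi(x, z) and a made-up non-negative weight w(x, z).
xi = 1.0 + 0.5 * Z + 0.2 * X * Z
w = np.exp(-(X**2 + (Z - 0.3) ** 2))
w = w / w.sum()  # normalize so the weights sum to one

# For squared loss, minimizing separately at each x gives the w-weighted
# average of xi(x, z) over z.
xi_star = (xi * w).sum(axis=1) / w.sum(axis=1)

# Sanity check at one value of x: a brute-force grid search over candidate
# constants agrees with the weighted average.
i = 25  # xs[25] == 0.0
candidates = np.linspace(0.0, 2.0, 2001)
losses = [((c - xi[i]) ** 2 * w[i]).sum() for c in candidates]
print(xi_star[i], candidates[np.argmin(losses)])  # the two should (nearly) match
```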
The Calculus of Variations
We will need some results from the Calculus of Variations. Let $$ \hat{f}^\ast(t, x) = \underset{g \in \mathcal{G}}{\operatorname{arg\,min}} \int_{\mathcal{T}} \int_{\mathcal{X}} \int_{\mathcal{Z}} D(g(t, x), f(t, x, z)) \cdot w(t, x, z) \,dz \,dx \, dt. $$ We first suppose that $\mathcal{G}$ includes all functions with support on $\mathcal{T} \times \mathcal{X}.$ In this case, the calculus of variations allows us to calculate $\hat{f}^\ast.$ Define $$ L(t, x, g) = \int_{\mathcal{Z}} D(g, f(t, x, z)) \cdot w(t, x, z) \, dz. $$ Then to calculate $\hat{f}^\ast(t, x) = \operatorname{arg\,min}_g \int_{\mathcal{T}} \int_{\mathcal{X}} L(t, x, g) \, dx \, dt,$ we differentiate $L$ with respect to $g$ as if it were a variable and set the result equal to zero, just as when minimizing any other function (since $L$ contains no derivatives of $g,$ the Euler-Lagrange condition reduces to this pointwise equation): $$ \frac{\partial L}{\partial g} = \int_{\mathcal{Z}} \frac{\partial D}{\partial g} \cdot w(t, x, z) \, dz = 0. $$ For example, when $D(g, f) = (g - f)^2,$ then $$ \begin{align} 0 &= \int_{\mathcal{Z}} 2 \cdot (\hat{f}^\ast(t, x) - f(t, x, z)) \cdot w(t, x, z) \, dz \\\ &= 2 \left( \hat{f}^\ast(t, x) \int_{\mathcal{Z}} w(t, x, z) \, dz - \int_{\mathcal{Z}} f(t, x, z) \cdot w(t, x, z) \, dz \right) \\\ \Rightarrow \hat{f}^\ast(t, x) &= \frac{\int_{\mathcal{Z}} f(t, x, z) \cdot w(t, x, z) \, dz}{\int_{\mathcal{Z}} w(t, x, z) \, dz}, \quad (2) \end{align} $$ that is, $\hat{f}^\ast$ is a weighted average of $f.$ It can be shown that when $f$ takes on binary values as in a classification problem, and $D(g, f) = -(f \cdot \log(g) + (1 - f) \cdot \log(1 - g)),$ we actually wind up with the same formula for $\hat{f}^\ast,$ but I’ll leave this exercise to the reader. I’ll call this loss the logistic loss since it is the loss function in logistic regression. Recall that $\hat{f}^\ast$ is the best approximation to $f$ not depending on $z,$ so it makes sense that it would be the average value, integrating over $z.$
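Here is a quick numerical version of that exercise, with made-up binary values and weights at one fixed $(t, x)$: minimizing the logistic loss over candidate constants $g$ lands on the same weighted average that Equation (2) gives.

```python
import numpy as np

# Made-up binary values f(t, x, z) over a grid of z at one fixed (t, x),
# together with made-up non-negative weights w(t, x, z).
rng = np.random.default_rng(0)
f_vals = rng.integers(0, 2, size=200).astype(float)
w_vals = rng.uniform(0.1, 1.0, size=200)

# Equation (2): the weighted average of f over z.
weighted_avg = (f_vals * w_vals).sum() / w_vals.sum()

# Minimize the logistic loss -(f log g + (1 - f) log(1 - g)), weighted by w,
# over a fine grid of candidate constants g in (0, 1).
g = np.linspace(1e-4, 1 - 1e-4, 10_001)
loss = -np.outer(f_vals, np.log(g)) - np.outer(1 - f_vals, np.log(1 - g))
total = (w_vals[:, None] * loss).sum(axis=0)
print(weighted_avg, g[np.argmin(total)])  # agree up to the grid resolution
```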
Next we derive an important property of the residual function, $\epsilon(t, x, z) := f(t, x, z) - \hat{f}^\ast(t, x):$ $$ \begin{align} \int_{\mathcal{Z}} \epsilon(t, x, z) \cdot w(t, x, z) \, dz &= \int_{\mathcal{Z}} (f(t, x, z) - \hat{f}^\ast(t, x)) \cdot w(t, x, z) \, dz \\\ &= \int_{\mathcal{Z}} f(t, x, z) \cdot w(t, x, z) \, dz \\\ &\phantom{=} \hspace{20pt} - \hat{f}^\ast(t, x) \cdot \int_{\mathcal{Z}} w(t, x, z) \, dz \\\ &= 0 \textrm{ for all } t, x \end{align} $$ under Equation (2), applicable for squared loss as well as the logistic loss. In words, the residual function integrates to zero. This is the key result we need for the next section.
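A quick sanity check of this property, again with made-up values of $f$ and $w$ over a grid of $z$ at one fixed $(t, x)$:

```python
import numpy as np

rng = np.random.default_rng(1)
f_vals = rng.normal(size=500)             # made-up values f(t, x, z) over a z-grid
w_vals = rng.uniform(0.0, 1.0, size=500)  # made-up weights w(t, x, z) on the same grid

f_hat_star = (f_vals * w_vals).sum() / w_vals.sum()  # Equation (2)
residual = f_vals - f_hat_star
print((residual * w_vals).sum())  # ~0, up to floating-point error
```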
We have placed no restrictions on the class $\mathcal{G},$ but it is straightforward to show from Equation (2) that if $f$ is linear in $t$ and $x,$ then $\hat{f}^\ast$ is linear as well, provided the weighted average over $z$ of the $z$-dependent part of $f$ is itself linear (or constant) in $t$ and $x.$ In this case, we can restrict $\mathcal{G}$ to be the space of linear functions, but arrive at the same formula for $\hat{f}^\ast,$ and the residual function still integrates to zero. Whereas, if $f$ is nonlinear and $\mathcal{G}$ is restricted to linear functions, the residual function will not necessarily integrate to zero.
A Simple Relationship
Now return to $\hat{\xi}^\ast,$ as defined in Equation (1). Notably, the weight function in the definition of $\hat{\xi}^\ast$ depends on $x$ and $z,$ whereas in the previous section it also depended on $t.$ We require a consistency condition, $w^\prime(t, x, z) := \psi(t, x) \cdot w(x, z),$ where $\psi(t, x) > 0,$ and with $w^\prime$ standing in for the $w$ used in the last section. This requirement may seem odd, but we will have more to say about its interpretation below.
Suppose $D(\zeta, \xi) = (\zeta - \xi)^2.$ Then $$ \begin{align} D(\zeta(x), \xi(x, z)) &= (\zeta(x) - \xi(x, z))^2 \\\ &= (\zeta(x) - (f(t_1, x, z) - f(t_0, x, z)))^2 \\\ &= (\zeta(x) - (\hat{f}^\ast(t_1, x) + \epsilon(t_1, x, z) - \hat{f}^\ast(t_0, x) - \epsilon(t_0, x, z)))^2 \\\ &= (\zeta(x) - (\hat{f}^\ast(t_1, x) - \hat{f}^\ast(t_0, x)) - (\epsilon(t_1, x, z) - \epsilon(t_0, x, z)))^2 \\\ &= (\zeta(x) - (\hat{f}^\ast(t_1, x) - \hat{f}^\ast(t_0, x)))^2 \\\ &\phantom{=} - 2 \cdot (\zeta(x) - (\hat{f}^\ast(t_1, x) - \hat{f}^\ast(t_0, x))) \cdot (\epsilon(t_1, x, z) - \epsilon(t_0, x, z))\\\ &\phantom{=} + (\epsilon(t_1, x, z) - \epsilon(t_0, x, z))^2. \end{align} $$
Thus, $$ \begin{align} \hat{\xi}^\ast(x) &= \underset{\zeta}{\operatorname{arg\,min}} \left\{ \int_{\mathcal{X}} \int_{\mathcal{Z}} D(\zeta(x), \xi(x, z)) \cdot w(x, z) \,dz \,dx \right\} \\\ &= \underset{\zeta}{\operatorname{arg\,min}} \left\{ \int_{\mathcal{X}} (\zeta(x) - (\hat{f}^\ast(t_1, x) - \hat{f}^\ast(t_0, x)))^2 \, \int_{\mathcal{Z}} w(x, z) \, dz \, dx \right. \\\ &\phantom{= \operatorname{arg \,min} \{} - 2 \int_{\mathcal{X}} (\zeta(x) - (\hat{f}^\ast(t_1, x) - \hat{f}^\ast(t_0, x))) \\\ &\phantom{= \operatorname{arg \,min} \{}\hspace{30pt} \times \int_{\mathcal{Z}} (\epsilon(t_1, x, z) - \epsilon(t_0, x, z)) \cdot w(x, z) \, dz \, dx \\\ &\phantom{= \operatorname{arg \,min} \{} + \left. \int_{\mathcal{X}} \int_{\mathcal{Z}} (\epsilon(t_1, x, z) - \epsilon(t_0, x, z))^2 \cdot w(x, z) \, dz \, dx \right\}. \quad (3) \end{align} $$ The third term in this expression does not depend on $\zeta$ and thus does not affect the solution. And the second term in this expression is zero since $$ \begin{align} \int_{\mathcal{Z}} (\epsilon(t_1, x, z) - \epsilon(t_0, x, z)) \cdot w(x, z) \, dz &= \frac{1}{\psi(t_1, x)} \int_{\mathcal{Z}} \epsilon(t_1, x, z) \cdot w^\prime(t_1, x, z) \, dz \\\ &\phantom{=} - \frac{1}{\psi(t_0, x)} \int_{\mathcal{Z}} \epsilon(t_0, x, z) \cdot w^\prime(t_0, x, z) \, dz \\\ &= 0, \end{align} $$ by the consistency condition on the weights, and the result from the last section that the residuals integrate to zero. That leaves $$ \hat{\xi}^\ast(x) = \underset{\zeta}{\operatorname{arg\,min}} \left\{ \int_{\mathcal{X}} (\zeta(x) - (\hat{f}^\ast(t_1, x) - \hat{f}^\ast(t_0, x)))^2 \, \int_{\mathcal{Z}} w(x, z) \, dz \, dx \right\}. $$ The remaining objective is the integral of a non-negative quantity, and it is driven to its minimum value of zero by making the squared factor vanish wherever the weight is positive. Thus, the minimum is attained at $$ \hat{\xi}^\ast(x) = \hat{f}^\ast(t_1, x) - \hat{f}^\ast(t_0, x). $$ In words, the approximation to the treatment effect is equal to the difference of the approximations to the function.
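Here is a numerical sketch of this identity on a discretized grid, with every function made up for illustration. The weights factor as $w^\prime(t, x, z) = \psi(t, x) \cdot w(x, z),$ and the difference of the two approximations from Equation (2) matches the direct approximation of the treatment effect from Equation (1).

```python
import numpy as np

# Grids standing in for X and Z, and two treatment levels t0 and t1.
xs = np.linspace(-1, 1, 41)
zs = np.linspace(-1, 1, 81)
X, Z = np.meshgrid(xs, zs, indexing="ij")
t0, t1 = 0.0, 1.0

# A made-up f with a z-dependent treatment effect.
def f(t, X, Z):
    return 2.0 * X + (1.0 + 0.5 * Z + 0.3 * X) * t + 0.4 * Z

# Weights satisfying the consistency condition: w'(t, x, z) = psi(t, x) * w(x, z).
w = np.exp(-(X**2 + Z**2))  # w(x, z); normalization is not needed, it cancels
psi = {t0: 0.4 + 0.2 * np.abs(xs), t1: 0.6 - 0.2 * np.abs(xs)}  # psi(t, x) > 0

# Equation (2): the w'(t, x, .)-weighted average of f over z.
def f_hat_star(t):
    w_prime = psi[t][:, None] * w
    return (f(t, X, Z) * w_prime).sum(axis=1) / w_prime.sum(axis=1)

# Equation (1) with squared loss, computed directly: the w-weighted average
# over z of xi(x, z) = f(t1, x, z) - f(t0, x, z).
xi = f(t1, X, Z) - f(t0, X, Z)
xi_star_direct = (xi * w).sum(axis=1) / w.sum(axis=1)

# The identity: the difference of the approximations to f equals the
# approximation to the treatment effect.
xi_star_from_f = f_hat_star(t1) - f_hat_star(t0)
print(np.max(np.abs(xi_star_direct - xi_star_from_f)))  # ~0, up to float error
```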
Commentary
This result is so simple it seems like we shouldn’t have had to derive it. We defined $\xi(x, z) = f(t_1, x, z) - f(t_0, x, z)$, and we demonstrated that $\hat{\xi}^\ast(x) = \hat{f}^\ast(t_1, x) - \hat{f}^\ast(t_0, x)$. This seems obvious, but it really isn’t! And it’s only true under certain circumstances.
Here’s why it matters: we wish to calculate an approximation to the treatment effect, $\xi(x, z),$ that depends on $x$ alone. However, in practice, we never actually observe the treatment effect directly; we only observe the function $f.$ What our result demonstrates is that if we have the ability to approximate $f,$ we can simply calculate $\hat{\xi}^\ast$ in terms of this approximation. In other words, there is a simple and intuitive relationship between the approximation to the treatment effect and the approximation to the function, $f$.
This only works under special circumstances:
- The loss function on the treatment effect approximation is squared error.
- The loss function on the function $f$ is squared error or the logistic loss.
- The family of functions used to approximate $f$ either satisfies a universal approximation property, as deep neural networks do, or $f$ itself is linear.
- The weight function $w^\prime(t, x, z) = \psi(t, x) \cdot w(x, z),$ with $\psi(t, x) > 0.$
Of these, the last seems most in need of discussion. I’ll rewrite the consistency condition as: $w^\prime(t, x, z) / w(x, z) = \psi(t, x)$ (when $w(x, z) > 0$). Recall from the last post that $w^\prime$ has the interpretation of a probability distribution that we sample from when generating the dataset used to fit a model to $f.$ We can think of $w(x, z)$ as being the weight function reflecting where we want $\hat{\xi}^\ast(x)$ to best approximate $\xi(x, z).$ And in order to estimate $\hat{\xi}^\ast(x)$ we sample from the domain of $f,$ with a sampling distribution given by $w^\prime(t, x, z).$ But in order to get a valid estimate, we need $w^\prime$ to be related to $w$ in a specific way: the ratio $w^\prime / w$ could be a constant, or it could depend on the treatment, or it could depend on both the treatment and the observed covariates $x$, but it cannot depend on the unobserved covariates $z$. (I think this requirement is analogous to the unconfoundedness assumption in Causal Inference.) Only special sampling strategies satisfy this property!
In the last paragraph we decided on a $w$ that we cared about, and then designed a corresponding $w^\prime$. But in an observational study it’s the opposite: nature hands us a $w^\prime$, and we can only hope that it can be factored appropriately. Otherwise, the approximation to the treatment effect is not necessarily equal to the difference in the approximations to $f.$
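One way to picture which sampling strategies work: write the sampling distribution for a binary treatment as $w^\prime(t, x, z) = p(t \mid x, z) \cdot w(x, z),$ where $p$ is the probability of assigning treatment $t$ (this factored form is my own illustration, not something we derived). The consistency condition then says the assignment probability cannot depend on $z.$ A small sketch:

```python
import numpy as np

# Grids standing in for the observed x and the unobserved z.
xs = np.linspace(-1, 1, 21)
zs = np.linspace(-1, 1, 21)
X, Z = np.meshgrid(xs, zs, indexing="ij")

def ratio_varies_with_z(p_treat):
    """Does the ratio w'/w (here, the assignment probability) change with z?"""
    return np.max(np.abs(p_treat - p_treat.mean(axis=1, keepdims=True))) > 1e-12

p_randomized = np.full_like(X, 0.5)            # a classic A/B test: p = 0.5
p_x_only = 1.0 / (1.0 + np.exp(-X))            # depends only on observed x: still fine
p_confounded = 1.0 / (1.0 + np.exp(-(X + Z)))  # depends on unobserved z: violates it

for name, p in [("randomized", p_randomized), ("x-only", p_x_only),
                ("confounded", p_confounded)]:
    status = "violates" if ratio_varies_with_z(p) else "satisfies"
    print(f"{name}: {status} the consistency condition")
```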
But especially in an experimental setting, this provides a recipe for estimating heterogeneous treatment effects. Fit models $\hat{f}_1(x) = \hat{f}^\ast(t_1, x)$ and $\hat{f}_0(x) = \hat{f}^\ast(t_0, x)$, then calculate $\hat{\xi}^\ast(x) = \hat{f}_1(x) - \hat{f}_0(x)$. In the literature on heterogeneous treatment effect estimation (also known as uplift modeling), this is called the two-model approach. It turns out there are better approaches, but the two-model approach is simple and intuitive.
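Here is a minimal sketch of the two-model approach on simulated data. The data-generating process and the choice of gradient boosting are mine, just for illustration; any regression model that approximates $f$ well would do.

```python
import numpy as np
from sklearn.ensemble import GradientBoostingRegressor

rng = np.random.default_rng(42)

# Simulated data: x is observed, z is unobserved, and t is randomly assigned,
# so the sampling distribution factorizes as required.
n = 10_000
x = rng.uniform(-1, 1, size=(n, 2))
z = rng.normal(size=n)          # unobserved: not given to the models below
t = rng.integers(0, 2, size=n)  # randomized assignment
y = 2 * x[:, 0] + (1 + 0.5 * x[:, 1] + 0.3 * z) * t + 0.4 * z

# Two-model approach: fit separate models for treated and control units using
# only the observed covariates, then take the difference of their predictions.
model_1 = GradientBoostingRegressor().fit(x[t == 1], y[t == 1])
model_0 = GradientBoostingRegressor().fit(x[t == 0], y[t == 0])

x_new = np.array([[0.2, -0.5], [0.2, 0.5]])
uplift = model_1.predict(x_new) - model_0.predict(x_new)
print(uplift)  # roughly 1 + 0.5 * x2, i.e. about 0.75 and 1.25
```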
As a parting thought, suppose we wanted to draw some insights about $\xi(x, z)$. That is, suppose we are interested in how the treatment effect depends on one of the observed covariates, say $x^{(1)}$. I’ll rewrite the treatment effect as $\xi(x^{(1)}, x^\prime, z),$ where $x^\prime$ just denotes all the other observed covariates. Then we might want to estimate $\partial \xi / \partial x^{(1)}$ or $\xi(x_1^{(1)}, x^\prime, z) - \xi(x_0^{(1)}, x^\prime, z),$ where $x_1^{(1)}$ and $x_0^{(1)}$ denote two values of interest. Then we’re actually in exactly the same position as we were before: we want to learn about some function that depends on unobserved factors, and we can only do this under the same conditions outlined above. Perhaps most importantly, the sampling strategy has to be unconfounded with respect to $x^{(1)}$. And just because the sampling strategy is unconfounded with respect to $t$, doesn’t mean it is unconfounded with respect to $x^{(1)}$.
So when estimating heterogeneous treatment effects, we need to be careful about interpreting the resulting model. While we can certainly predict the treatment effect for any set of covariates $x$, we cannot say that differing treatment effects are because of particular covariate values, except under special circumstances.
Summary
In this post, I described Causal Inference as the study of complex, only partially observed systems. I defined the treatment effect as a comparison between two values of a treatment, keeping all other factors constant.
The treatment effect itself may depend on unobserved factors, but under certain circumstances we can calculate an approximation to the treatment effect in terms of approximations to the system or function itself. The most important requirement is a factorization property of the sampling distribution used to approximate the system that is related to the unconfoundedness assumption in the Causal Inference literature. In general, this can only be guaranteed in the context of a controlled experiment.
Finally, I provided some commentary on the challenges of drawing causal conclusions about the treatment effects themselves. While we can predict the treatment effect whenever the treatment assignment is unconfounded, we also need a covariate of interest to be unconfounded in order to draw causal conclusions about its effect on the treatment effect.
Further Reading
The Calculus of Variations plays a central role in Classical Mechanics, and that’s where I learned about it. V.I. Arnold’s book, Mathematical Methods of Classical Mechanics, is a gem. Calculus of Variations by I. M. Gelfand and S. V. Fomin is a more general-purpose reference.
One of my colleagues, Huigang Chen, is a primary contributor to CausalML, a Python package for uplift modeling. The algorithms implemented are much fancier than the two-model approach I describe above! Their GitHub page contains a list of references for folks who want to learn more about Uplift Modeling.