Deep treatment-adaptive network for causal inference

Publisher:
Springer Science and Business Media LLC
Publication Type:
Conference Proceeding
Citation:
VLDB Journal, 2022, pp. 1-16
Issue Date:
2022-01-01
Causal inference can estimate the treatment effect (i.e., the causal effect of a treatment on an outcome) to support decision making in various domains. One fundamental challenge in this research is treatment assignment bias in observational data. To increase the validity of observational studies, state-of-the-art representation-based methods have demonstrated superior performance in treatment effect estimation. Most representation-based methods assume that all observed covariates are pre-treatment (i.e., not affected by the treatment) and learn a balanced representation from these covariates to estimate the treatment effect. Unfortunately, this assumption is often too strict in practice, as some covariates are changed by intervening on the treatment (i.e., they are post-treatment). In such cases, a balanced representation learned from covariates assumed to be unchanged biases the treatment effect estimation. In light of this, we propose a deep treatment-adaptive architecture (DTANet) that can handle post-treatment covariates and provide an unbiased treatment effect estimate. The contributions of this work are threefold. First, our theoretical results guarantee that DTANet can identify the treatment effect from observations. Second, we introduce a novel orthogonal projection regularization to ensure that the learned confounder representation is invariant to, and not contaminated by, the treatment, while the mediator representation remains informative and discriminative for predicting the outcome. Finally, we build on optimal transport to learn a treatment-invariant representation of the unobserved confounders and thereby alleviate confounding bias.
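The abstract names two regularizers without giving their exact form. The sketch below shows one plausible reading of each, under assumptions that are ours and not the paper's: the orthogonality term is written as the squared Frobenius norm of the cross-product between a confounder representation and a mediator representation (a generic soft-orthogonality penalty), and the optimal transport term is an entropy-regularized (Sinkhorn) transport cost between treated and control representations, in the spirit of Wasserstein-based balancing. The function names and parameters (`orthogonality_penalty`, `sinkhorn_distance`, `reg`, `n_iters`) are hypothetical.

```python
import numpy as np


def orthogonality_penalty(conf_rep, med_rep):
    """Soft-orthogonality penalty ||conf_rep^T @ med_rep||_F^2.

    ASSUMPTION: a generic stand-in for DTANet's orthogonal projection
    regularization. It is zero exactly when every confounder feature column
    is orthogonal to every mediator feature column over the batch.
    conf_rep: (n, d_conf), med_rep: (n, d_med).
    """
    cross = conf_rep.T @ med_rep  # (d_conf, d_med) cross-product matrix
    return float(np.sum(cross ** 2))


def sinkhorn_distance(x, y, reg=0.1, n_iters=200):
    """Entropy-regularized optimal transport cost between two point clouds.

    ASSUMPTION: a generic Sinkhorn balancing term, not the paper's exact loss.
    x: (n, d) representations of treated units; y: (m, d) of control units.
    Uses uniform weights and a squared Euclidean ground cost. The cost is
    rescaled to [0, 1] inside the kernel so exp() does not underflow, so
    `reg` is relative to the largest pairwise cost.
    """
    cost = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)  # (n, m)
    scale = cost.max() if cost.max() > 0 else 1.0
    kernel = np.exp(-cost / (reg * scale))
    a = np.full(x.shape[0], 1.0 / x.shape[0])  # treated marginal
    b = np.full(y.shape[0], 1.0 / y.shape[0])  # control marginal
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iters):
        # Alternate scaling updates to match the row and column marginals.
        u = a / (kernel @ v)
        v = b / (kernel.T @ u)
    plan = u[:, None] * kernel * v[None, :]  # approximate transport plan
    return float(np.sum(plan * cost))  # transport cost in original units


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    treated = rng.normal(size=(50, 2))
    control = rng.normal(size=(60, 2))
    # Balanced groups should be cheaper to transport than shifted ones.
    print(sinkhorn_distance(treated, control) < sinkhorn_distance(treated, control + 3.0))
```

In a training loop, both terms would typically be added to the factual outcome loss with tunable weights; how DTANet actually weights and combines them is not stated in this record.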