Multiply Robust Estimation for Local Distribution Shifts with Multiple Domains (2402.14145v2)
Abstract: Distribution shifts are ubiquitous in real-world machine learning applications, posing a challenge to the generalization of models trained on one data distribution to another. We focus on scenarios where data distributions vary across multiple segments of the entire population, and we make only local assumptions about the differences between training and test (deployment) distributions within each segment. We propose a two-stage multiply robust estimation method for tabular data analysis that improves model performance on each individual segment. The method fits a linear combination of base models, each learned on a cluster of training data drawn from multiple segments, and then applies a refinement step for each segment. It is designed to work with commonly used off-the-shelf machine learning models. We establish a theoretical guarantee in the form of a generalization bound on the test risk. Through extensive experiments on synthetic and real datasets, we demonstrate that the proposed method substantially improves over existing alternatives in prediction accuracy and robustness on both regression and classification tasks. We also assess its effectiveness on a user city prediction dataset from Meta.
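To make the two-stage recipe concrete, below is a minimal Python sketch assuming scikit-learn-style regressors. The function names, the choice of KMeans for clustering, and the residual-based refinement step are illustrative assumptions on my part, not the paper's exact specification; the paper's actual losses and refinement procedure may differ.

```python
# Illustrative sketch of a two-stage "combine then refine" estimator in the
# spirit of the method described in the abstract. All names and modeling
# choices here are hypothetical placeholders.
import numpy as np
from sklearn.cluster import KMeans
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.linear_model import LinearRegression


def fit_base_models(X, y, n_clusters=5, seed=0):
    """Cluster the pooled training data from all segments and fit one
    off-the-shelf base model per cluster (assumed clustering: KMeans)."""
    km = KMeans(n_clusters=n_clusters, random_state=seed, n_init=10).fit(X)
    models = []
    for k in range(n_clusters):
        mask = km.labels_ == k
        models.append(GradientBoostingRegressor(random_state=seed).fit(X[mask], y[mask]))
    return models


def base_predictions(models, X):
    """Stack base-model predictions column-wise: shape (n_samples, n_models)."""
    return np.column_stack([m.predict(X) for m in models])


def fit_segment_estimator(models, X_seg, y_seg):
    """Stage 1: fit a linear combination of the base models on one segment.
    Stage 2 (assumed form): refine with a low-capacity model on residuals."""
    P = base_predictions(models, X_seg)
    combiner = LinearRegression().fit(P, y_seg)
    residuals = y_seg - combiner.predict(P)
    refiner = GradientBoostingRegressor(
        n_estimators=50, max_depth=2, random_state=0
    ).fit(X_seg, residuals)
    return combiner, refiner


def predict_segment(models, combiner, refiner, X):
    """Combined prediction for a segment: linear ensemble plus refinement."""
    P = base_predictions(models, X)
    return combiner.predict(P) + refiner.predict(X)
```

In this sketch, the per-segment linear combination lets each segment reweight the cluster-level base models, and the low-capacity refiner absorbs what remains of the local shift without overfitting small segments; the off-the-shelf base learners can be swapped for any regressor or classifier with a `fit`/`predict` interface.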