Authors:
(1) P Aditya Sreekar, Amazon; these authors contributed equally to this work;
(2) Sahil Verm, Amazon; these authors contributed equally to this work;
(3) Varun Madhavan, Indian Institute of Technology, Kharagpur; work done during an internship at Amazon;
(4) Abhishek Persad, Amazon.
Tree-based algorithms are widely used in machine learning for tabular data. Decision trees recursively split the data into multiple parts using axis-aligned hyperplanes (Hastie et al., 2009). Random Forests (RF) (Breiman, 2001) and Gradient Boosted Decision Trees (GBDT) (Friedman, 2001) are the most commonly used tree-based ensembles. RF fits multiple decision trees on random subsets of the data and averages their predictions (regression) or takes a majority vote over them (classification), which alleviates the tendency of individual decision trees to overfit. GBDT, XGBoost (Chen and Guestrin, 2016), and CatBoost (Prokhorenkova et al., 2018) are boosted ensembles that build decision trees sequentially, with each new tree fitted to correct the errors made by the previous trees; this improves performance on complex datasets with non-linear relationships.
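The mechanics described above can be illustrated with a minimal NumPy sketch for regression with squared loss, using depth-1 trees (stumps) as base learners. The helper names (`fit_stump`, `fit_bagged`, `fit_boosted`) and all hyperparameters are hypothetical, and the sketch is far simpler than the RF, XGBoost, or CatBoost implementations: a full RF also samples a random subset of features at each split, and the boosting libraries add deeper trees, regularization, and other refinements.

```python
import numpy as np

def fit_stump(X, y):
    """Depth-1 regression tree: one axis-aligned split X[:, j] <= t minimizing squared error."""
    best = None
    for j in range(X.shape[1]):
        # Candidate thresholds: every distinct value except the largest, so both sides are non-empty.
        for t in np.unique(X[:, j])[:-1]:
            left = X[:, j] <= t
            yl, yr = y[left], y[~left]
            sse = ((yl - yl.mean()) ** 2).sum() + ((yr - yr.mean()) ** 2).sum()
            if best is None or sse < best[0]:
                best = (sse, j, t, yl.mean(), yr.mean())
    if best is None:  # every feature is constant: fall back to predicting the mean
        mean = y.mean()
        return lambda Z: np.full(len(Z), mean)
    _, j, t, v_left, v_right = best
    return lambda Z: np.where(Z[:, j] <= t, v_left, v_right)

def fit_bagged(X, y, n_trees=25, seed=0):
    """Bagging: fit stumps on bootstrap resamples and average them (no per-split feature sampling)."""
    rng = np.random.default_rng(seed)
    stumps = []
    for _ in range(n_trees):
        idx = rng.integers(0, len(X), size=len(X))  # bootstrap sample, drawn with replacement
        stumps.append(fit_stump(X[idx], y[idx]))
    return lambda Z: np.mean([s(Z) for s in stumps], axis=0)

def fit_boosted(X, y, n_rounds=50, lr=0.1):
    """Gradient boosting for squared loss: each stump is fitted to the current residuals."""
    base = y.mean()
    pred = np.full(len(y), base)
    stumps = []
    for _ in range(n_rounds):
        residual = y - pred  # negative gradient of 0.5 * (y - f)^2 with respect to f
        s = fit_stump(X, residual)
        pred = pred + lr * s(X)  # small, learning-rate-scaled correction
        stumps.append(s)
    return lambda Z: base + lr * sum(s(Z) for s in stumps)

# Toy data: a step in feature 0 plus a little noise; feature 1 is irrelevant.
rng = np.random.default_rng(0)
X = rng.uniform(-1.0, 1.0, size=(200, 2))
y = np.where(X[:, 0] > 0, 1.0, -1.0) + 0.1 * rng.normal(size=200)

rf_pred = fit_bagged(X, y)(X)
gb_pred = fit_boosted(X, y)(X)
print("bagged training MSE:", np.mean((y - rf_pred) ** 2))
print("boosted training MSE:", np.mean((y - gb_pred) ** 2))
```

The bagged model reduces variance by averaging trees fitted to resampled data, whereas the boosted model reduces error step by step by adding corrections fitted to what the current ensemble still gets wrong.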
Recently, there has been considerable interest in deep learning (DL) models for tabular data. Some methods replace the hard decision functions used in decision trees with differentiable approximations, so that tree-like models can be trained with gradient-based optimization (Hazimeh et al., 2020; Popov et al., 2019). These methods outperform purely tree-based models on some problems, but they are not consistently better (Gorishniy et al., 2021).
Other methods use attention mechanisms to adapt DL to tabular data (Arik et al., 2019; Huang et al., 2020; Gorishniy et al., 2021; Somepalli et al., 2021; Chen et al., 2022). TabNet (Arik et al., 2019) proposes a sparse attention mechanism, stacked over multiple layers, that mimics the recursive splitting of decision trees. Inspired by the success of self-attention transformers (Vaswani et al., 2017) in many domains (Devlin et al., 2019; Dosovitskiy et al., 2021; Gong et al., 2021), methods such as TabTransformer (Huang et al., 2020), FT-Transformer (Gorishniy et al., 2021), and SAINT (Somepalli et al., 2021) were proposed. TabTransformer embeds all categorical variables into a common embedding space, and the resulting sequence of categorical embeddings is passed through self-attention transformer layers. FT-Transformer extends this to numerical features as well, mapping each one to a continuous embedding so that attention operates over all features. SAINT builds on FT-Transformer with intersample attention, a new kind of attention that captures interactions between the samples of a batch. However, intersample attention is only effective when the number of feature dimensions is high relative to the number of samples, so SAINT offers no advantage over FT-Transformer for our problem statement; we therefore do not compare RCT against SAINT (Somepalli et al., 2021).
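The tokenize-then-attend pattern described above can be sketched in NumPy. This is a minimal illustration under our own assumptions, not the TabTransformer, FT-Transformer, or SAINT reference code: a single attention head with random, untrained weights; no LayerNorm, feed-forward sublayer, residual connections, or [CLS] token; and hypothetical helper names (`tokenize`, `attention`). Numerical features get a per-feature continuous embedding `x_j * W_num[j] + b_num[j]`, and categorical features are looked up in per-feature tables (TabTransformer would pass only the categorical tokens through attention). The last step reshapes the batch so that each row becomes a single token, which is the idea behind SAINT's intersample attention.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8  # embedding width per feature token (assumed)

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def attention(T, Wq, Wk, Wv):
    """Single-head scaled dot-product self-attention over the second-to-last axis of T."""
    Q, K, V = T @ Wq, T @ Wk, T @ Wv
    scores = Q @ np.swapaxes(K, -1, -2) / np.sqrt(T.shape[-1])
    return softmax(scores) @ V

def tokenize(x_num, x_cat, W_num, b_num, cat_tables):
    """Turn every feature of every row into a d-dim token; returns (B, n_num + n_cat, d).

    x_num: (B, n_num) floats, embedded as x_j * W_num[j] + b_num[j] (continuous embedding)
    x_cat: (B, n_cat) ints, embedded by lookup in cat_tables[j] of shape (cardinality_j, d)
    """
    num_tok = x_num[:, :, None] * W_num[None] + b_num[None]
    cat_tok = np.stack([tab[x_cat[:, j]] for j, tab in enumerate(cat_tables)], axis=1)
    return np.concatenate([num_tok, cat_tok], axis=1)

# A toy batch: 4 rows, 3 numerical and 2 categorical features (cardinality 5).
B, n_num, n_cat, card = 4, 3, 2, 5
x_num = rng.normal(size=(B, n_num))
x_cat = rng.integers(0, card, size=(B, n_cat))
W_num, b_num = rng.normal(size=(n_num, d)), rng.normal(size=(n_num, d))
cat_tables = [rng.normal(size=(card, d)) for _ in range(n_cat)]

T = tokenize(x_num, x_cat, W_num, b_num, cat_tables)  # (B, F, d), F = n_num + n_cat
F = T.shape[1]

# Feature attention: within each row, every feature token attends to every other feature token.
W = [rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(3)]
feat_out = attention(T, *W)

# Intersample attention: flatten each row into one token, then rows attend to rows in the batch.
R = T.reshape(1, B, F * d)
W_row = [rng.normal(size=(F * d, F * d)) / np.sqrt(F * d) for _ in range(3)]
row_out = attention(R, *W_row).reshape(B, F, d)

print(T.shape, feat_out.shape, row_out.shape)  # (4, 5, 8) (4, 5, 8) (4, 5, 8)
```

Feature attention mixes information only within a row, whereas intersample attention lets every row attend to every other row in the batch; perturbing one row therefore changes the other rows' outputs only in the second case.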
This paper is available on arXiv under the CC BY-NC-ND 4.0 DEED license.