
PyTorch Frame: A Modular Framework for Multi-Modal Tabular Learning (2404.00776v2)

Published 31 Mar 2024 in cs.LG, cs.DB, and stat.ML

Abstract: We present PyTorch Frame, a PyTorch-based framework for deep learning over multi-modal tabular data. PyTorch Frame makes tabular deep learning easy by providing a PyTorch-based data structure to handle complex tabular data, introducing a model abstraction to enable modular implementation of tabular models, and allowing external foundation models to be incorporated to handle complex columns (e.g., LLMs for text columns). We demonstrate the usefulness of PyTorch Frame by implementing diverse tabular models in a modular way, successfully applying these models to complex multi-modal tabular data, and integrating our framework with PyTorch Geometric, a PyTorch library for Graph Neural Networks (GNNs), to perform end-to-end learning over relational databases.


Summary

  • The paper introduces PyTorch Frame, a modular framework that uses the novel Tensor Frame data structure to streamline multi-modal tabular learning.
  • It employs an encoding and column-wise interaction mechanism to transform complex tabular data into unified embedded representations.
  • Integration with foundation models and PyTorch Geometric demonstrates its ability to outperform conventional models on diverse, relational datasets.

PyTorch Frame: A Comprehensive Framework for Multi-Modal Tabular Learning

Introduction to PyTorch Frame

PyTorch Frame is a PyTorch-based framework for deep learning over complex, multi-modal tabular data. It centers on a new data structure, Tensor Frame, for representing heterogeneous columns, a modular abstraction for implementing diverse tabular models, and hooks for incorporating external foundation models to process complex columns such as text.

Core Components of PyTorch Frame

Data Materialization

PyTorch Frame introduces Tensor Frame, a PyTorch-friendly data structure that manages arbitrarily complex columns by grouping column data according to semantic type. This materialization step converts raw columns of different modalities, including numerical, categorical, multicategorical, timestamp, textual, and embedding columns, into tensors ready for downstream models.
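
To make this concrete, the sketch below shows how a small pandas DataFrame could be materialized into a Tensor Frame by assigning each column a semantic type. The class and attribute names (Dataset, col_to_stype, materialize, tensor_frame, feat_dict) are taken from the open-source pytorch-frame package but should be treated as assumptions that may differ across versions.

```python
# Hedged sketch of Tensor Frame materialization (names assumed from the
# open-source pytorch-frame package; exact signatures may differ by version).
import pandas as pd
import torch_frame
from torch_frame.data import Dataset

df = pd.DataFrame({
    "age":   [25, 40, 31],                      # numerical column
    "job":   ["engineer", "teacher", "nurse"],  # categorical column
    "label": [0, 1, 0],                         # prediction target
})

# Each column is assigned a semantic type (stype); materialization groups
# columns of the same stype into dense tensors inside a TensorFrame.
dataset = Dataset(
    df,
    col_to_stype={
        "age": torch_frame.numerical,
        "job": torch_frame.categorical,
        "label": torch_frame.categorical,
    },
    target_col="label",
)
dataset.materialize()
tf = dataset.tensor_frame  # TensorFrame: feature tensors keyed by stype
print(tf.feat_dict[torch_frame.numerical].shape)  # [num_rows, num_numerical_cols]
```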

Encoding Process

The encoding stage transforms the materialized data into an embedded representation in which each column is independently mapped into a space of uniform dimension. This step applies feature normalization and column-specific embedding techniques tailored to the characteristics of each semantic type.
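
As an illustration of per-column encoding, the sketch below uses plain PyTorch modules (not the library's actual encoder classes) to embed each numerical and categorical column independently into a shared d-dimensional space:

```python
import torch
import torch.nn as nn

# Illustrative per-column encoder (not the library's classes): every column is
# embedded independently into a common d-dimensional space.
class PerColumnEncoder(nn.Module):
    def __init__(self, num_numerical: int, cat_cardinalities: list[int], d: int = 32):
        super().__init__()
        # One linear map per numerical column, one embedding table per categorical column.
        self.num_encoders = nn.ModuleList([nn.Linear(1, d) for _ in range(num_numerical)])
        self.cat_encoders = nn.ModuleList([nn.Embedding(c, d) for c in cat_cardinalities])

    def forward(self, x_num: torch.Tensor, x_cat: torch.Tensor) -> torch.Tensor:
        # x_num: [batch, num_numerical] floats; x_cat: [batch, num_categorical] int64 codes.
        cols = [enc(x_num[:, i:i + 1]) for i, enc in enumerate(self.num_encoders)]
        cols += [enc(x_cat[:, j]) for j, enc in enumerate(self.cat_encoders)]
        return torch.stack(cols, dim=1)  # [batch, num_columns, d]

encoder = PerColumnEncoder(num_numerical=2, cat_cardinalities=[5, 3])
out = encoder(torch.randn(8, 2), torch.randint(0, 3, (8, 2)))  # -> [8, 4, 32]
```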

Column-wise Interaction

After encoding, PyTorch Frame applies a column-wise interaction mechanism that iteratively updates each column's embedding using information from the other columns. This step captures inter-column relationships within the table and enriches the representational capacity of the embeddings.
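
A common way to realize such column-wise interaction is self-attention across the column axis. The sketch below is illustrative only and does not reproduce the library's own interaction layers:

```python
import torch
import torch.nn as nn

# Illustrative column-wise interaction: self-attention over the column axis,
# so each column embedding is updated using information from all other columns.
d, num_heads = 32, 4
interaction = nn.TransformerEncoderLayer(d_model=d, nhead=num_heads, batch_first=True)

x = torch.randn(8, 5, d)  # 8 rows, 5 columns, d-dimensional column embeddings
x = interaction(x)        # same shape; each column is now contextualized
```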

Decoding for Prediction

The final stage decodes the enriched column embeddings into row-wise embeddings that can be used directly for prediction or fed into downstream deep learning models. This decoding step summarizes the interactions among columns into a single consolidated representation for each row.
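
A minimal decoder can pool the contextualized column embeddings into a single row embedding and attach a task head; the following sketch is illustrative rather than the library's actual decoder:

```python
import torch
import torch.nn as nn

# Illustrative decoder: collapse the column axis into one row embedding,
# then map it to class logits for a downstream prediction task.
d, num_classes = 32, 2
head = nn.Linear(d, num_classes)

x = torch.randn(8, 5, d)  # contextualized column embeddings [batch, cols, d]
row_emb = x.mean(dim=1)   # row-wise embedding [batch, d]
logits = head(row_emb)    # predictions, or input to a downstream model
```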

Advantages and Integrations

Integration with Foundation Models

A prominent feature of PyTorch Frame is its ability to incorporate external foundation models, particularly for complex columns such as text and images. By leveraging pre-trained models or fine-tuning them end to end, PyTorch Frame substantially improves the handling and predictive modeling of multi-modal data.
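
As one plausible pattern for this integration, a text column can be pre-embedded with an external sentence encoder and then treated like any other semantic type. The wrapper below is a hedged sketch; the model name and the config class mentioned in the comments are assumptions:

```python
# Hedged sketch: wrapping an external sentence encoder so text columns can be
# pre-embedded during materialization (model name and config plumbing assumed).
import torch
from sentence_transformers import SentenceTransformer

class TextToEmbedding:
    """Callable that maps a list of strings to an [n, dim] embedding tensor."""
    def __init__(self, model_name: str = "all-MiniLM-L6-v2", device: str = "cpu"):
        self.model = SentenceTransformer(model_name, device=device)

    def __call__(self, sentences: list[str]) -> torch.Tensor:
        emb = self.model.encode(sentences, convert_to_numpy=True)
        return torch.from_numpy(emb)

# Such a callable would typically be handed to the dataset through a
# text-embedder config (e.g. pytorch-frame's TextEmbedderConfig), so text
# columns flow through the same materialize/encode pipeline as other stypes.
```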

Compatibility with PyTorch Geometric

PyTorch Frame seamlessly integrates with PyTorch Geometric (PyG) for learning over relational databases. This integration combines the strengths of tabular deep learning and GNNs, enabling end-to-end learning that exploits both tabular and relational data characteristics for improved prediction accuracy.
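
The sketch below illustrates the general pattern rather than the paper's exact pipeline: a tabular encoder produces per-row embeddings that serve as node features, and a PyTorch Geometric GNN propagates them along primary-/foreign-key edges:

```python
import torch
import torch.nn as nn
from torch_geometric.nn import SAGEConv

# Illustrative pattern for combining a tabular encoder with a GNN (not the
# paper's exact pipeline): table rows become nodes, key links become edges.
class TabularGNN(nn.Module):
    def __init__(self, row_encoder: nn.Module, d: int, num_classes: int):
        super().__init__()
        self.row_encoder = row_encoder        # e.g. a tabular model producing row embeddings
        self.conv1 = SAGEConv(d, d)
        self.conv2 = SAGEConv(d, num_classes)

    def forward(self, tabular_inputs, edge_index: torch.Tensor) -> torch.Tensor:
        x = self.row_encoder(tabular_inputs)  # [num_rows, d] row embeddings
        x = self.conv1(x, edge_index).relu()  # message passing over key links
        return self.conv2(x, edge_index)      # node-level predictions
```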

Empirical Validation

PyTorch Frame has been empirically tested across various datasets to demonstrate its efficacy in multi-modal tabular learning. The framework shows promising results in handling traditional datasets with numerical and categorical features, along with modern datasets containing complex columns and relational structures. Notably, the integration of PyTorch Frame with foundation models and PyG outperforms conventional models like LightGBM, especially in datasets enriched with textual information and relational data.

Conclusion

PyTorch Frame represents a significant advancement in tabular deep learning, offering a comprehensive, efficient, and flexible framework for handling complex multi-modal tabular data. By encapsulating the entire process from data materialization to prediction and enabling the integration with external models and PyG, PyTorch Frame paves the way for innovative applications in fields requiring sophisticated tabular data analysis.
