Equivariant Neural Functional Networks for Transformers

Published in The Thirteenth International Conference on Learning Representations (ICLR 2025)

This paper systematically explores neural functional networks (NFN) for transformer architectures. NFN are specialized neural networks that treat the weights, gradients, or sparsity patterns of a deep neural network (DNN) as input data, and they have proven valuable for tasks such as learnable optimizers, implicit data representations, and weight editing. While NFN have been extensively developed for MLPs and CNNs, no prior work has addressed their design for transformers, despite the central role transformers play in modern deep learning. We address this gap by systematically studying NFN for transformers. We first determine the maximal symmetric group of the weights in a multi-head attention module, together with a necessary and sufficient condition under which two sets of hyperparameters of the module define the same function. We then define the weight space of transformer architectures and its associated group action, which leads to design principles for NFN on transformers. Based on these, we introduce Transformer-NFN, an NFN that is equivariant under this group action. Additionally, we release a dataset of over 125,000 Transformer model checkpoints trained on two datasets across two tasks, providing a benchmark for evaluating Transformer-NFN and encouraging further research on transformer training and performance.
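
As a concrete, partial illustration of the kind of weight-space symmetry the paper studies: permuting the attention heads of a multi-head attention module, together with the matching blocks of its output projection, leaves the computed function unchanged. The NumPy sketch below checks this numerically. It is a minimal reimplementation of standard multi-head attention for illustration only, not the paper's construction or the full maximal symmetric group; the function and variable names (`multi_head_attention`, `Wq`, `Wk`, `Wv`, `Wo`) are assumptions, not names from the paper or its released code.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(X, Wq, Wk, Wv, Wo):
    """Self-attention over X with per-head weights (illustrative sketch).

    Wq, Wk, Wv have shape (h, d_model, d_head); Wo has shape (h, d_head, d_model).
    Summing each head's output through its own block of Wo is equivalent to
    concatenating the heads and applying one (h * d_head, d_model) projection.
    """
    out = np.zeros((X.shape[0], Wo.shape[-1]))
    for q, k, v, o in zip(Wq, Wk, Wv, Wo):
        Q, K, V = X @ q, X @ k, X @ v
        A = softmax(Q @ K.T / np.sqrt(q.shape[-1]))  # (n, n) attention weights
        out += (A @ V) @ o
    return out

rng = np.random.default_rng(0)
n, d_model, h, d_head = 5, 16, 4, 4
X = rng.normal(size=(n, d_model))
Wq, Wk, Wv = (rng.normal(size=(h, d_model, d_head)) for _ in range(3))
Wo = rng.normal(size=(h, d_head, d_model))

perm = rng.permutation(h)  # relabel the heads
out_original = multi_head_attention(X, Wq, Wk, Wv, Wo)
out_permuted = multi_head_attention(X, Wq[perm], Wk[perm], Wv[perm], Wo[perm])
print(np.allclose(out_original, out_permuted))  # True: permuted weights define the same function
```

Storing the output projection as per-head blocks `Wo[i]` makes the symmetry explicit: relabeling the heads merely reorders the summands of the output, so any NFN that processes these weights should treat head order (and the other symmetries characterized in the paper) as equivalent inputs.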

Download Paper | Download BibTeX