
Routing Networks: Adaptive Selection of Non-linear Functions for Multi-Task Learning (1711.01239v2)

Published 3 Nov 2017 in cs.LG, cs.CV, and cs.NE

Abstract: Multi-task learning (MTL) with neural networks leverages commonalities in tasks to improve performance, but often suffers from task interference which reduces the benefits of transfer. To address this issue we introduce the routing network paradigm, a novel neural network and training algorithm. A routing network is a kind of self-organizing neural network consisting of two components: a router and a set of one or more function blocks. A function block may be any neural network - for example a fully-connected or a convolutional layer. Given an input the router makes a routing decision, choosing a function block to apply and passing the output back to the router recursively, terminating when a fixed recursion depth is reached. In this way the routing network dynamically composes different function blocks for each input. We employ a collaborative multi-agent reinforcement learning (MARL) approach to jointly train the router and function blocks. We evaluate our model against cross-stitch networks and shared-layer baselines on multi-task settings of the MNIST, mini-imagenet, and CIFAR-100 datasets. Our experiments demonstrate a significant improvement in accuracy, with sharper convergence. In addition, routing networks have nearly constant per-task training cost while cross-stitch networks scale linearly with the number of tasks. On CIFAR-100 (20 tasks) we obtain cross-stitch performance levels with an 85% reduction in training time.

Citations (232)

Summary

  • The paper introduces a dynamic routing mechanism that adaptively selects non-linear function blocks to significantly reduce task interference.
  • It employs a collaborative multi-agent reinforcement learning strategy using a Weighted Policy Learner to optimize routing decisions.
  • Empirical evaluations on datasets like CIFAR-100 demonstrate up to 85% training time reduction, highlighting the architecture's scalability and efficiency.

Routing Networks: Adaptive Selection of Non-linear Functions for Multi-Task Learning

The paper introduces a novel approach to Multi-Task Learning (MTL) with neural networks, aimed at reducing task interference through Routing Networks. The architecture adaptively composes non-linear function blocks for each input, dynamically adjusting the network's structure to minimize negative transfer while preserving the benefits of positive transfer.

Architecture of Routing Networks

A Routing Network comprises two primary components: a router and a set of function blocks. Conditioned on the current input and, optionally, the task label, the router iteratively selects a function block, which may be any neural network module, such as a fully-connected or convolutional layer. Selection and composition continue recursively up to a fixed recursion depth. This dynamic selection lets the architecture customize its computation per input instance, shaping shared and task-specific representations on the fly.
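As a concrete illustration, the forward pass described above can be sketched in a few lines. Everything here is illustrative, not the paper's implementation: the function blocks are untrained random dense layers, and the router is a stub that reads a hand-written per-task preference table rather than a learned RL policy.

```python
import numpy as np

rng = np.random.default_rng(0)

def relu(x):
    return np.maximum(x, 0.0)

# Toy function blocks: each is a small dense layer with random
# (untrained, placeholder) weights.
DIM, N_BLOCKS, DEPTH = 8, 3, 2
blocks = [rng.standard_normal((DIM, DIM)) * 0.1 for _ in range(N_BLOCKS)]

def router(state, task_id):
    """Hypothetical router stub: picks a block from a hand-written
    per-task preference table. In the paper this decision is made by
    a policy trained with multi-agent RL."""
    prefs = np.array([[0.6, 0.3, 0.1],   # task 0 prefers block 0
                      [0.1, 0.3, 0.6]])  # task 1 prefers block 2
    return int(np.argmax(prefs[task_id]))

def routed_forward(x, task_id, depth=DEPTH):
    """Recursively compose function blocks chosen by the router,
    terminating at a fixed recursion depth."""
    state, route = x, []
    for _ in range(depth):
        choice = router(state, task_id)       # routing decision
        state = relu(state @ blocks[choice])  # apply chosen block
        route.append(choice)
    return state, route

x = rng.standard_normal(DIM)
out, route = routed_forward(x, task_id=0)  # route is [0, 0] for this stub
```

The key point the sketch captures is that the composition of blocks, and hence the effective network, differs per input and per task, while all tasks draw from one shared pool of blocks.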

The router is trained with a collaborative multi-agent reinforcement learning (MARL) approach, which sidesteps the non-differentiability of hard routing decisions. Specifically, a Weighted Policy Learner (WPL) algorithm trains multiple agents, one per task in the dataset. This adaptability yields both higher accuracy and lower computational cost than comparable multi-task architectures such as cross-stitch networks.
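A minimal sketch of one WPL step may help, assuming the standard form of the Weighted Policy Learner update: the gradient signal (reward minus a running average) is scaled by (1 - pi(a)) when moving probability toward an action and by pi(a) when moving away, which damps learning near the simplex boundary. The learning rate, probability floor, and re-normalization below are illustrative choices, not the paper's exact procedure.

```python
import numpy as np

def wpl_update(policy, action, reward, avg_reward, lr=0.1, eps=1e-3):
    """One Weighted Policy Learner step for a single agent (sketch).

    policy     -- current action probabilities over function blocks
    action     -- index of the block that was chosen
    reward     -- return observed for this routing decision
    avg_reward -- running average return, used as a baseline
    """
    policy = policy.copy()
    grad = reward - avg_reward
    if grad > 0:
        # Moving toward the action: scale by (1 - pi(a))
        policy[action] += lr * grad * (1.0 - policy[action])
    else:
        # Moving away from the action: scale by pi(a)
        policy[action] += lr * grad * policy[action]
    # Keep a small probability floor and re-normalize to the simplex
    policy = np.clip(policy, eps, None)
    return policy / policy.sum()
```

In a routing network, each task's agent maintains such a policy over function blocks at each decision point, and the reward combines task performance with a collaboration term that encourages agents to share blocks.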

Empirical Evaluation

The paper extensively evaluates the architecture on multi-task adaptations of the MNIST, mini-ImageNet, and CIFAR-100 datasets, where routing networks outperform strong baselines, including cross-stitch networks and a common joint-training strategy with shared layers. The results show significant gains in accuracy and convergence speed, and demonstrate that routing networks maintain a nearly constant per-task training cost, in contrast to the linear growth observed with cross-stitch networks.

On the CIFAR-100 dataset, configured as 20 tasks, the routing network matches cross-stitch performance with an 85% reduction in training time. This efficiency is a critical advance for scalable, resource-efficient multi-task learning, particularly as the number of tasks grows.

Theoretical and Practical Implications

Theoretically, this work broadens the conceptual understanding of task-specific routing in neural networks, making a compelling case for applying reinforcement learning to configure dynamic architectures. Practically, routing networks offer a highly scalable solution to MTL problems, with potential for broader application in areas requiring adaptive computation, from automated architecture search to continual learning.

The architectural flexibility inherent in routing networks, specifically the ability to expand or shrink the network's computational footprint dynamically, opens numerous avenues for further research. Scaling to deeper models and adapting hierarchical reinforcement learning frameworks are intriguing directions for subsequent studies. Future work could also apply these principles to online learning scenarios, where task sequences are not fixed but evolve over time.

Conclusion

This paper represents a significant step forward in neural network architecture design, offering a robust, flexible platform for tackling the intrinsic challenges of multi-task learning. By introducing routing networks, the authors have given the research community a powerful toolset for efficient, adaptive learning in complex, parallel task settings. This work may encourage further exploration of reinforcement learning-based dynamic model configurations, particularly as demand for scalable, intelligent systems continues to grow.