Mirage: A Multi-Level Superoptimizer for Tensor Programs

Published 9 May 2024 in cs.LG, cs.AI, and cs.PL | (2405.05751v3)

Abstract: We introduce Mirage, the first multi-level superoptimizer for tensor programs. A key idea in Mirage is $\mu$Graphs, a uniform representation of tensor programs at the kernel, thread block, and thread levels of the GPU compute hierarchy. $\mu$Graphs enable Mirage to discover novel optimizations that combine algebraic transformations, schedule transformations, and generation of new custom kernels. To navigate the large search space, Mirage introduces a pruning technique based on abstraction that significantly reduces the search space and provides a certain optimality guarantee. To ensure that the optimized $\mu$Graph is equivalent to the input program, Mirage introduces a probabilistic equivalence verification procedure with strong theoretical guarantees. Our evaluation shows that Mirage outperforms existing approaches by up to 3.3$\times$ even for DNNs that are widely used and heavily optimized. Mirage is publicly available at https://github.com/mirage-project/mirage.

Summary

  • The paper introduces Mirage, a novel superoptimizer that unifies GPU compute levels with a multi-level intermediate representation for tensor programs.
  • It leverages algebraic and scheduling transformations along with custom kernel generation to optimize complex tensor operations.
  • Using abstraction-based pruning and probabilistic equivalence verification, Mirage achieves up to 3.3x speedups over existing optimization frameworks.

The paper "Mirage: A Multi-Level Superoptimizer for Tensor Programs" (2405.05751) introduces Mirage, a superoptimizer designed to optimize tensor programs across multiple levels of the GPU compute hierarchy. Its central innovation is a unified intermediate representation (IR) and search strategy that enable the discovery of complex, cross-level optimizations previously inaccessible to single-level optimizers or rule-based compilers.

$\mu$Graphs: A Unified Multi-Level Intermediate Representation

Mirage employs a novel IR called $\mu$Graphs ($\mu$Gs) to represent tensor programs. Unlike traditional approaches, which often use separate IRs for different optimization stages (e.g., high-level computation graphs, loop nests, low-level assembly), $\mu$Gs provide a single, unified representation capturing computation details at the kernel, thread block (CTA), and thread levels.

A $\mu$G is essentially a dataflow graph where nodes represent operations (e.g., tensor contractions, elementwise operations, memory loads/stores, synchronization primitives) and edges represent data dependencies. Crucially, $\mu$Gs incorporate hierarchy directly: nodes can represent computations performed by threads, CTAs, or entire kernels. This hierarchical structure allows Mirage to explicitly model and manipulate interactions between different levels of parallelism and memory access patterns within the GPU architecture.

For instance, a $\mu$G can represent a matrix multiplication kernel by describing, within the same graph structure, the kernel launched over tensors in device memory (kernel level), how the work is partitioned across thread blocks and staged through shared memory (CTA level), and the computation each thread performs on register-resident data (thread level). This unification is key to enabling optimizations that span these levels.
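The nesting can be pictured as graphs inside graphs. The sketch below is a hypothetical Python data model, not Mirage's API: the class names and the fields `grid_dims` and `forloop_range` are illustrative assumptions. A kernel-level node either stands for a predefined library kernel or carries a block graph describing one thread block's work, and a block-level node may in turn carry a thread graph.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class ThreadGraph:
    """Per-thread computation on register-resident values (hypothetical)."""
    ops: list[str] = field(default_factory=list)


@dataclass
class BlockOp:
    """An operator executed by one thread block (CTA) on shared-memory tiles."""
    name: str
    thread_graph: ThreadGraph | None = None  # optional nested thread-level graph


@dataclass
class BlockGraph:
    """What every thread block of a custom kernel does."""
    grid_dims: tuple[int, int, int]  # number of thread blocks launched per axis
    forloop_range: int               # iterations each block makes over its input tiles
    ops: list[BlockOp] = field(default_factory=list)


@dataclass
class KernelOp:
    """A kernel launch over device-memory tensors."""
    name: str
    block_graph: BlockGraph | None = None  # None means a predefined library kernel


@dataclass
class KernelGraph:
    """Top level of the hierarchy: a dataflow graph of kernels."""
    ops: list[KernelOp] = field(default_factory=list)

    def levels(self) -> set:
        """Report which levels of the GPU hierarchy this graph describes."""
        found = {"kernel"} if self.ops else set()
        for kop in self.ops:
            if kop.block_graph is None:
                continue
            found.add("block")
            if any(bop.thread_graph is not None for bop in kop.block_graph.ops):
                found.add("thread")
        return found


# A custom matmul kernel whose epilogue is described down to the thread level.
fused = KernelGraph([
    KernelOp("matmul_relu", BlockGraph(
        grid_dims=(64, 1, 1),
        forloop_range=16,
        ops=[
            BlockOp("load_tile_A"),
            BlockOp("load_tile_B"),
            BlockOp("matmul_accumulate"),
            BlockOp("relu_and_store", ThreadGraph(["relu"])),
        ],
    )),
])
library_only = KernelGraph([KernelOp("gemm")])

print(sorted(fused.levels()))         # ['block', 'kernel', 'thread']
print(sorted(library_only.levels()))  # ['kernel']
```

Keeping every level in one structure is what lets a single transformation touch, say, a kernel boundary and the per-thread epilogue at the same time.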

Cross-Level Optimization Capabilities

The unified $\mu$G representation empowers Mirage to explore a significantly richer optimization space than conventional compilers or single-level superoptimizers. Mirage applies transformations directly to the $\mu$G, encompassing:

  1. Algebraic Transformations: Exploiting properties like associativity, distributivity, and commutativity to restructure computations (e.g., $A(BC) \rightarrow (AB)C$). In the $\mu$G context, these can be applied at any granularity represented by the graph nodes.
  2. Schedule Transformations: Modifying the execution order, parallelism mapping, and memory access patterns. This includes traditional loop transformations (tiling, fusion, reordering) represented implicitly through changes in the $\mu$G structure, as well as explicit mapping of computation to threads/CTAs and data to memory hierarchies (registers, shared memory, global memory). Mirage can discover complex tiling strategies or alter thread-level instruction sequences based on CTA-level data layout.
  3. Custom Kernel Generation: Mirage is not limited to predefined kernel implementations. By manipulating the $\mu$G, it can effectively synthesize novel, specialized kernel structures tailored to the specific tensor program and target hardware, potentially combining parts of different standard algorithms or inventing entirely new dataflows.
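A worked count shows why an algebraic rewrite such as $A(BC) \rightarrow (AB)C$ can matter. The sketch uses a deliberately simple cost model, the number of multiply-accumulate (MAC) operations, and ignores memory traffic and parallelism, which a real GPU cost would include; the shapes are arbitrary examples.

```python
def matmul_macs(m: int, k: int, n: int) -> int:
    """Multiply-accumulate operations for an (m x k) @ (k x n) product."""
    return m * k * n


def chain_costs(m: int, k: int, n: int, p: int) -> dict:
    """MAC counts for both parenthesizations of A @ B @ C,
    where A is m x k, B is k x n, and C is n x p."""
    return {
        "(AB)C": matmul_macs(m, k, n) + matmul_macs(m, n, p),
        "A(BC)": matmul_macs(k, n, p) + matmul_macs(m, k, p),
    }


# A row vector times two square matrices.
costs = chain_costs(m=1, k=1024, n=1024, p=1024)
print(costs)  # {'(AB)C': 2097152, 'A(BC)': 1074790400}
```

Both groupings compute the same product, yet here one does more than 500 times the arithmetic of the other, which is why an optimizer benefits from being free to re-associate.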

The ability to consider these transformations concurrently within the unified $\mu$G framework allows Mirage to find optimizations that require coordinated changes across compute levels, such as modifying thread-level computations to enable better CTA-level shared memory reuse or applying an algebraic identity that simplifies the computation into a form more amenable to efficient mapping onto the GPU architecture.
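Schedule transformations change how work is mapped onto the hardware without changing the result. The sketch emulates, in plain Python rather than on a GPU, a tiled matrix multiplication in which each output tile stands for one thread block and the inner loop over K tiles stands for that block's loop over staged input tiles; the function names and tile sizes are illustrative assumptions, not Mirage output.

```python
import random


def matmul_reference(A, B):
    """Plain triple-loop product of list-of-lists matrices."""
    K, N = len(B), len(B[0])
    return [[sum(row[t] * B[t][j] for t in range(K)) for j in range(N)] for row in A]


def matmul_tiled(A, B, tile_m, tile_n, tile_k):
    """Same product under a tiled schedule.

    Each (bi, bj) pair plays the role of one thread block computing an output
    tile; the loop over bk plays the role of that block's for-loop, which
    stages an A tile and a B tile (as shared memory would) and accumulates
    their partial product.
    """
    M, K, N = len(A), len(B), len(B[0])
    C = [[0] * N for _ in range(M)]
    for bi in range(0, M, tile_m):          # grid axis over output rows
        for bj in range(0, N, tile_n):      # grid axis over output columns
            rows, cols = min(tile_m, M - bi), min(tile_n, N - bj)
            acc = [[0] * cols for _ in range(rows)]
            for bk in range(0, K, tile_k):  # per-block loop over the reduction axis
                a_tile = [r[bk:bk + tile_k] for r in A[bi:bi + rows]]
                b_tile = [r[bj:bj + cols] for r in B[bk:bk + tile_k]]
                for i in range(rows):
                    for j in range(cols):
                        acc[i][j] += sum(a_tile[i][t] * b_tile[t][j]
                                         for t in range(len(b_tile)))
            for i in range(rows):
                C[bi + i][bj:bj + cols] = acc[i]
    return C


rng = random.Random(0)
A = [[rng.randint(-5, 5) for _ in range(7)] for _ in range(5)]   # 5 x 7
B = [[rng.randint(-5, 5) for _ in range(6)] for _ in range(7)]   # 7 x 6
# Tile sizes that do not divide the shapes also exercise the edge tiles.
assert matmul_tiled(A, B, tile_m=2, tile_n=4, tile_k=3) == matmul_reference(A, B)
```

Integer inputs keep the comparison exact; with floating point, a different accumulation order can change results in the last bits, which is one reason equivalence checking needs care.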

Search Strategy and Abstraction-Based Pruning

Superoptimization inherently involves navigating vast search spaces. Mirage explores the space of candidate $\mu$Graphs with an enumerative search, and to manage the resulting combinatorial explosion it contributes an abstraction-based pruning technique.

This technique keeps what a computation produces and abstracts away how it is scheduled: each tensor in a candidate $\mu$Graph is summarized by an abstract expression over the program's inputs, with low-level details such as the distribution of work across kernels, thread blocks, and threads omitted. While a candidate is being constructed, Mirage checks whether the abstract expression of the partial $\mu$Graph is a subexpression of the input program's abstract expression, reasoning with a set of algebraic axioms. A partial $\mu$Graph that fails this check cannot be completed into a program equivalent to the input, so it is pruned together with all of its extensions. This yields an optimality guarantee relative to the abstraction: no $\mu$Graph in the searchable space whose equivalence to the input follows from the axioms is pruned. The result balances search tractability against optimality and substantially reduces search effort compared with brute-force enumeration.

The abstraction and the axioms used to reason about it are crucial design elements: they must be coarse enough for the check to be cheap and to prune aggressively, yet sound enough that no candidate the axioms can show to be equivalent to the input is discarded.
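A toy enumerator shows how pruning by abstraction shrinks a search. As simplifying assumptions, a candidate's abstraction is just its symbolic expression tree, and a candidate is discarded as soon as it is not a syntactic subterm of the target expression, so it is never combined further; the names and the check are hypothetical and much cruder than Mirage's.

```python
from itertools import product


def subterms(expr, acc=None):
    """All subterms of a nested-tuple expression, including expr itself."""
    if acc is None:
        acc = set()
    acc.add(expr)
    if isinstance(expr, tuple):           # (op, left, right)
        _, left, right = expr
        subterms(left, acc)
        subterms(right, acc)
    return acc


def enumerate_exprs(inputs, ops, depth, target=None):
    """Bottom-up enumeration of binary expressions for `depth` rounds.

    With a target, a candidate survives only if it is a subterm of the
    target (toy abstraction check), so pruned candidates are never combined
    further. Returns (surviving expressions, number of candidates generated).
    """
    allowed = subterms(target) if target is not None else None
    pool = set(inputs)
    generated = 0
    for _ in range(depth):
        fresh = set()
        for op, left, right in product(ops, pool, pool):
            generated += 1
            cand = (op, left, right)
            if allowed is None or cand in allowed:
                fresh.add(cand)
        pool |= fresh
    return pool, generated


target = ("add", ("mul", "A", "B"), ("mul", "A", "C"))   # AB + AC
full, n_full = enumerate_exprs(["A", "B", "C"], ["add", "mul"], 2)
pruned, n_pruned = enumerate_exprs(["A", "B", "C"], ["add", "mul"], 2, target)
print(n_full, n_pruned)                  # 900 68
print(target in full, target in pruned)  # True True
```

Pruning cuts the generated candidates from 900 to 68 while the target is still found. Because this check is purely syntactic, it would also discard algebraically rewritten but equivalent programs such as A(B + C); a practical abstraction has to be robust to such rewrites to preserve any optimality guarantee.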

Probabilistic Equivalence Verification

Ensuring that optimized programs are functionally equivalent to the original is paramount. Superoptimizers often rely on formal verification or extensive testing. Mirage introduces a probabilistic equivalence verification procedure specifically for $\mu$Graphs.

Given an original $\mu G_{orig}$ and an optimized $\mu G_{opt}$, the goal is to establish that $\mu G_{orig}(x) = \mu G_{opt}(x)$ for all valid inputs $x$. The procedure is based on random input testing, but with stronger theoretical underpinnings than simple fuzzing: it draws on polynomial identity testing and the Schwartz-Zippel lemma, applied to the computations represented in the $\mu$Graphs. Both graphs are evaluated on randomly selected inputs drawn from a sufficiently large domain; if the two $\mu$Graphs are not equivalent, they disagree on some sampled input with probability at least $1 - \epsilon$, where $\epsilon$ can be made arbitrarily small. The paper provides theoretical guarantees on this probability of detecting non-equivalence, linking it to the properties of the computations represented in the $\mu$G and the size of the input domain used for testing. This offers a practical alternative to formal methods, which can be computationally intractable for complex tensor programs and GPU schedules.
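The random-testing idea can be sketched on small matrix programs. In the style of polynomial identity testing, the sketch evaluates two programs on random inputs with exact arithmetic modulo a prime $P$, so rounding can neither mask nor fake a difference. By the Schwartz-Zippel lemma, if $f - g$ is a nonzero polynomial of total degree $d$, a point drawn uniformly from the field is a root with probability at most $d/P$. The prime, the three programs, and the helper names are illustrative assumptions; this is not Mirage's verifier.

```python
import random

P = 2_147_483_647  # the prime 2**31 - 1; all arithmetic is modulo P


def matmul_mod(X, Y):
    return [[sum(X[i][t] * Y[t][j] for t in range(len(Y))) % P
             for j in range(len(Y[0]))] for i in range(len(X))]


def add_mod(X, Y):
    return [[(a + b) % P for a, b in zip(rx, ry)] for rx, ry in zip(X, Y)]


def original(A, B, C):   # AB + AC
    return add_mod(matmul_mod(A, B), matmul_mod(A, C))


def factored(A, B, C):   # A(B + C): equivalent by distributivity
    return matmul_mod(A, add_mod(B, C))


def buggy(A, B, C):      # AB + CA: not equivalent, since AC != CA in general
    return add_mod(matmul_mod(A, B), matmul_mod(C, A))


def probably_equivalent(f, g, n, trials=4, seed=0):
    """Random testing over the field Z_P on n x n matrix inputs.

    Any disagreement proves f and g differ; agreement on every trial means
    they are equivalent with high probability (Schwartz-Zippel)."""
    rng = random.Random(seed)
    for _ in range(trials):
        A, B, C = ([[rng.randrange(P) for _ in range(n)] for _ in range(n)]
                   for _ in range(3))
        if f(A, B, C) != g(A, B, C):
            return False
    return True


print(probably_equivalent(original, factored, n=3))  # True
print(probably_equivalent(original, buggy, n=3))     # False
```

The test is one-sided: a "False" is a proof of non-equivalence, while a "True" is only probabilistic, with the chance of a missed bug shrinking as the field grows or more trials are run.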

Evaluation and Performance Results

Mirage was evaluated against state-of-the-art tensor program optimization frameworks, potentially including TVM/Ansor/Meta-Schedule, Halide, TASO, or vendor libraries like cuDNN/cuBLAS. The evaluation focused on optimizing various DNN models and constituent tensor operations.

The standout result reported is that Mirage achieves speedups of up to 3.3x over existing, heavily optimized baselines, even for widely used DNNs. This substantial improvement suggests that current approaches may be trapped in local optima or unable to explore the complex, cross-level optimizations that Mirage unlocks via $\mu$Graphs. The performance gains are attributed to Mirage's ability to discover novel combinations of algebraic transformations and scheduling choices (like tiling and fusion), and to generate custom kernels that outperform standard library implementations or the outputs of auto-schedulers operating at a single level or with limited transformation rules. The paper likely presents detailed breakdowns for specific operators (e.g., convolutions, matrix multiplications, attention layers) on specific GPU architectures, demonstrating the breadth of applicability.

Implementation

Mirage is implemented and publicly available (github.com/mirage-project/mirage). The implementation likely consists of the $\mu$Graph representation, the transformation engine applying rules to the $\mu$G, the search algorithm incorporating abstraction-based pruning, and the probabilistic equivalence verifier. Key implementation considerations include the engineering effort required to build and maintain the multi-level $\mu$G representation, the computational cost of the search process (which, despite pruning, can still be significant for complex programs), and the tuning of the probabilistic verifier (e.g., number of random inputs needed for a desired confidence level). The practicality hinges on the trade-off between the potentially substantial offline optimization time required by Mirage and the runtime performance gains achieved.
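Choosing the number of random inputs for a target confidence is a short calculation under the Schwartz-Zippel bound: if one trial misses a genuine difference of total degree at most $d$ with probability at most $d/|S|$, then $t$ independent trials all miss it with probability at most $(d/|S|)^t$. The hypothetical helper below finds the smallest sufficient $t$; treating the degree bound $d$ as known is an assumption, since in practice it has to be derived from the program.

```python
def trials_needed(degree: int, field_size: int, eps: float) -> int:
    """Smallest t with (degree / field_size) ** t <= eps.

    Under the Schwartz-Zippel bound, t independent random trials then miss
    a genuine difference of total degree <= degree with probability <= eps."""
    if not 0 < degree < field_size:
        raise ValueError("need 0 < degree < field_size")
    if not 0 < eps < 1:
        raise ValueError("need 0 < eps < 1")
    per_trial = degree / field_size   # worst-case miss probability of one trial
    t = 1
    while per_trial ** t > eps:
        t += 1
    return t


P = 2_147_483_647  # a 31-bit prime field
print(trials_needed(4, P, 1e-9))   # 2
print(trials_needed(4, P, 1e-30))  # 4
```

Because the bound shrinks geometrically in $t$, a large field keeps the number of trials, and hence the verification cost, small even for very strict confidence targets.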

Conclusion

Mirage represents a significant advancement in tensor program optimization by introducing a multi-level superoptimization approach centered on the unified $\mu$Graph representation. Its ability to explore and combine algebraic, scheduling, and kernel generation optimizations across GPU compute levels allows it to discover highly efficient implementations missed by previous methods, yielding substantial performance improvements (up to 3.3x) on important workloads. The abstraction-based pruning and probabilistic verification techniques address the key challenges of search space complexity and correctness assurance inherent in superoptimization.
