JAXMg Enables Scalable Multi-GPU Linear Solves Beyond Single-GPU Memory Limits

Scientists are tackling a persistent bottleneck in modern scientific computing: efficiently solving large, dense linear systems. Roeland Wiersema from the Center for Computational Quantum Physics, Flatiron Institute, and colleagues demonstrate a novel solution with JAXMg, a multi-GPU linear solver built within the JAX framework. This research is significant because it bridges the gap between highly optimised, yet often inflexible, multi-GPU libraries and the increasingly popular, composable JIT-compiled workflows of JAX. By seamlessly integrating with cuSOLVERMg via XLA, JAXMg allows scalable linear algebra to be embedded directly into JAX programs, unlocking multi-GPU performance for end-to-end scientific applications.

Existing multi-GPU solver libraries often prove difficult to integrate into these composable Python workflows, requiring users to exit the JAX execution model and manually manage memory. JAXMg circumvents these challenges by providing a unified, JIT-compatible interface, allowing researchers to leverage the power of multiple GPUs without sacrificing the benefits of JAX’s streamlined programming environment. This approach is particularly valuable for applications demanding repeated linear system solves or eigenvalue decompositions within larger simulation loops or differentiable optimization processes.

JAXMg supports CUDA 12 and CUDA 13 devices and offers JIT-able interfaces to core cuSOLVERMg routines: potrs for solving symmetric positive-definite systems, potri for computing matrix inverses, and syevd for eigenvalue decomposition. The implementation uses a 1D block-cyclic data distribution scheme, mapping matrix columns to GPUs in fixed-size tiles of user-configurable size TA to balance the computational load. The distribution is realised through deterministic, in-place rotations using peer-to-peer GPU copies and small staging buffers, minimising data movement. The library also handles the memory management needed for both single-program, multiple-data (SPMD) and multiple-program, multiple-data (MPMD) execution modes.
In SPMD mode, shared virtual address spaces facilitate straightforward pointer sharing, while MPMD mode leverages the CUDA IPC API to enable inter-process communication and GPU allocation sharing. Benchmarks conducted on a system equipped with 8 NVIDIA H200 GPUs (143 GB VRAM each) reveal that JAXMg consistently outperforms native single-GPU linear algebra routines, particularly for larger problem sizes, and scales effectively with increasing numbers of GPUs. The team reports performance gains across various data types, including float32, float64, complex64, and complex128, demonstrating the versatility of this new approach.
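To give a sense of how such a JIT-able interface might be used, the following is a minimal sketch of a multi-GPU Cholesky-based solve embedded in a jitted, sharded JAX program. The article names jaxmg.potri and jaxmg.syevd; the jaxmg.potrs call, its tile-size keyword T_A, and the import name are assumptions inferred from that naming, not the library’s documented API.

```python
# Hedged sketch (assumed API): a multi-GPU Cholesky-based solve inside a
# JIT-compiled JAX program, with the matrix columns sharded across GPUs.
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

jax.config.update("jax_enable_x64", True)  # allow float64, as benchmarked above

import jaxmg  # assumed import name

N = 8192
M = jax.random.normal(jax.random.PRNGKey(0), (N, N), dtype=jnp.float64)
A = M @ M.T + N * jnp.eye(N, dtype=jnp.float64)  # symmetric positive definite
b = jnp.ones((N, 1), dtype=jnp.float64)          # right-hand side (1, ..., 1)^T

# Distribute columns over all visible GPUs, mirroring the 1D column
# distribution described above.
mesh = Mesh(np.array(jax.devices()), ("gpu",))
A = jax.device_put(A, NamedSharding(mesh, P(None, "gpu")))

@jax.jit
def solve(A, b):
    # Assumed signature: SPD matrix, right-hand side, and a tile size T_A
    # controlling the block-cyclic layout handed to cuSOLVERMg's potrs.
    return jaxmg.potrs(A, b, T_A=1024)

x = solve(A, b)
```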

JAXMg multi-GPU linear algebra via cuSOLVERMg offers significant performance gains

The researchers engineered a system that integrates the cuSOLVERMg routines potrs, potri, and syevd to solve symmetric positive-definite systems, perform matrix inversion, and compute eigenvalues and eigenvectors within the JAX ecosystem. The implementation supports the JAX data types float32, float64, complex64, and complex128 and is compatible with CUDA 12 and CUDA 13 devices. A 1D block-cyclic data distribution, constructed in a C++ backend, ensures a balanced workload across GPUs by assigning matrix columns in fixed-size tiles handed out in round-robin fashion. Efficient in-place redistribution is achieved by decomposing column-index mappings into disjoint permutation cycles, enabling peer-to-peer GPU transfers via cudaMemcpyPeerAsync with minimal staging overhead. Using JAX’s sharding machinery, the system exposes per-device shards and passes the corresponding GPU pointers to the backend, supporting both SPMD and MPMD execution while keeping a single controlling process able to access all device memory. This design enables scalable multi-GPU linear algebra for problems exceeding single-GPU memory limits, preserves composability within JAX pipelines, and overcomes key memory-management challenges in existing multi-GPU solutions.
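The redistribution idea can be made concrete with a small, purely illustrative Python sketch (the real implementation lives in the C++ backend): one helper gives the block-cyclic rule that decides which GPU owns a column, and the other decomposes a tile-to-slot mapping into disjoint permutation cycles, which is what allows each cycle to be rotated in place with peer-to-peer copies and a single small staging buffer.

```python
# Illustrative sketch only: block-cyclic column ownership and cycle
# decomposition of a redistribution mapping (the library does this in C++).

def block_cyclic_owner(col: int, tile_size: int, n_gpus: int) -> int:
    """GPU index that owns a column under a 1D block-cyclic layout."""
    return (col // tile_size) % n_gpus

def permutation_cycles(mapping: dict[int, int]) -> list[list[int]]:
    """Decompose a tile permutation {source_slot: target_slot} into cycles."""
    seen, cycles = set(), []
    for start in mapping:
        if start in seen:
            continue
        cycle, cur = [], start
        while cur not in seen:
            seen.add(cur)
            cycle.append(cur)
            cur = mapping[cur]
        cycles.append(cycle)
    return cycles

# Example: columns 0..7, tiles of 2 columns, 4 GPUs.
print([block_cyclic_owner(c, tile_size=2, n_gpus=4) for c in range(8)])
# -> [0, 0, 1, 1, 2, 2, 3, 3]

# Rotating 6 tiles one slot to the right is a single cycle, so it can be done
# in place: stage one tile, then chain peer-to-peer copies along the cycle.
print(permutation_cycles({i: (i + 1) % 6 for i in range(6)}))
# -> [[0, 1, 2, 3, 4, 5]]
```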

JAXMg delivers scalable multi-GPU linear algebra for machine learning and scientific computing

Experiments revealed that JAXMg surpasses native single-GPU linear algebra routines in performance, particularly for larger matrices. For the potrs benchmark, with right-hand side b = (1, ..., 1)^T, the team varied the tile size TA and recorded wall-clock timings, demonstrating improved scaling with JAXMg. The largest solvable problem reached N = 524,288, utilising over 1 TB of memory, a scale previously infeasible on a single GPU, while still delivering a performance gain. Larger tile sizes improved performance once the problem size became sufficiently large, consistent with increased GPU utilisation, while tile size had minimal impact on syevd.
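A tile-size sweep of this kind could be timed along the following lines; the jaxmg.potrs signature and the T_A keyword are again assumed, and block_until_ready() is used so that JAX’s asynchronous dispatch does not distort the wall-clock measurements.

```python
# Hedged timing sketch (assumed jaxmg.potrs API) for a tile-size sweep.
import time
from functools import partial

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

import jaxmg  # assumed import name

N = 8192
M = jax.random.normal(jax.random.PRNGKey(0), (N, N), dtype=jnp.float64)
A = M @ M.T + N * jnp.eye(N, dtype=jnp.float64)  # SPD test matrix
b = jnp.ones((N, 1), dtype=jnp.float64)          # b = (1, ..., 1)^T

for tile in (256, 1024, 4096):
    solve = jax.jit(partial(jaxmg.potrs, T_A=tile))  # T_A stays a Python int
    solve(A, b).block_until_ready()                  # warm-up / compile
    t0 = time.perf_counter()
    solve(A, b).block_until_ready()                  # timed run
    print(f"T_A={tile}: {time.perf_counter() - t0:.3f} s")
```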

Tests show that both syevd and potri require significantly more workspace memory than potrs, which limits the maximum achievable matrix sizes. Measurements confirm that JAXMg enables dense linear solves and eigendecompositions that would otherwise be bottlenecked by single-GPU memory capacity, all while maintaining JAX’s composability and JIT-compiled programming model. Comparisons of jaxmg.potri with jax.numpy.linalg.inv for a complex128 matrix and of jaxmg.syevd with jax.numpy.linalg.eigh for a float64 matrix further validate the library’s effectiveness. The data show a strong dependence of potri on the tile size TA, whereas syevd is largely insensitive to it. The result is the ability to tackle matrix sizes previously out of reach on a single GPU and to raise throughput by pooling device memory and compute, opening new avenues for large-scale scientific simulation and analysis. This work highlights the potential for JAXMg to accelerate research across fields that rely on large-scale linear algebra.
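The comparisons described above could be reproduced in outline as follows; the jaxmg.potri and jaxmg.syevd names come from the article, while their exact signatures and return conventions are assumptions.

```python
# Hedged sketch of the potri/inv and syevd/eigh comparisons (assumed jaxmg
# signatures; the single-GPU references are standard jax.numpy routines).
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

import jaxmg  # assumed import name

N = 4096
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(1), 3)

# Hermitian positive-definite complex128 matrix for the inversion comparison.
Z = jax.random.normal(k1, (N, N)) + 1j * jax.random.normal(k2, (N, N))
H = (Z @ Z.conj().T + N * jnp.eye(N)).astype(jnp.complex128)
inv_multi = jaxmg.potri(H)        # assumed: Cholesky-based inverse, multi-GPU
inv_single = jnp.linalg.inv(H)    # single-GPU reference

# Symmetric float64 matrix for the eigendecomposition comparison.
S = jax.random.normal(k3, (N, N), dtype=jnp.float64)
S = (S + S.T) / 2
w_multi, v_multi = jaxmg.syevd(S)        # assumed return: (eigenvalues, eigenvectors)
w_single, v_single = jnp.linalg.eigh(S)  # single-GPU reference

print(jnp.max(jnp.abs(w_multi - w_single)))      # eigenvalues should agree closely
print(jnp.max(jnp.abs(inv_multi - inv_single)))  # inverses should agree closely
```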

JAXMg enables scalable, composable GPU linear algebra

This tool enables Cholesky-based linear solves and symmetric eigendecompositions for matrices exceeding the memory capacity of a single GPU. The key achievement lies in maintaining composability with JAX transformations, allowing multi-GPU execution within complete scientific workflows, a feature often lacking in existing highly optimised multi-GPU solver libraries. Benchmarks demonstrate JAXMg’s ability to solve problems with matrices up to 524,288 × 524,288, utilising over 1 TB of memory, and show competitive performance against established JAX-based linear algebra routines. The authors acknowledge a performance dependence on the tile size TA and note that further optimisation may be possible. Future work could extend JAXMg to a wider range of linear algebra operations and solvers, broadening its applicability across scientific domains.

👉 More information
🗞 JAXMg: A multi-GPU linear solver in JAX
🧠 ArXiv: https://arxiv.org/abs/2601.14466

Rohail T.

I am a quantum scientist exploring the frontiers of physics and technology. My work focuses on uncovering how quantum mechanics, computing, and emerging technologies are transforming our understanding of reality. I share research-driven insights that make complex ideas in quantum science clear, engaging, and relevant to the modern world.
