2025-09-21, Space 2
Accelerating scientific Python with JITs
We share our journey migrating a gravitational lensing likelihood calculation from Numba to JAX. Learn about performance gains, automatic differentiation benefits, and practical lessons for high-performance scientific computing in Python.
Python is widely used in scientific research, but pure Python can sometimes be too slow for computationally intensive tasks. Just-In-Time (JIT) compilers are essential tools for boosting performance, allowing Python code to run closer to native speeds. While libraries like Numba have long been popular for accelerating numerical Python functions, JAX offers a new paradigm, combining JIT compilation with powerful features like automatic differentiation (auto-diff) and execution across different hardware (CPUs, GPUs, TPUs).
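To make these two ideas concrete, here is a minimal sketch of JIT compilation and auto-diff in JAX. It uses a hypothetical toy chi-squared for a straight-line fit, not the lensing likelihood from the talk, and `chi_squared` and its arguments are illustrative names. `jax.jit` compiles the function with XLA, and `jax.grad` derives its gradient without any hand-written derivative code.

```python
import jax
import jax.numpy as jnp


def chi_squared(params, x, y, sigma):
    """Chi-squared of a hypothetical toy straight-line model y = a * x + b."""
    a, b = params
    residual = (y - (a * x + b)) / sigma
    return jnp.sum(residual**2)


# jax.jit traces the function for the given input shapes/dtypes and compiles
# it with XLA; later calls with matching shapes reuse the compiled code.
chi_squared_jit = jax.jit(chi_squared)

# jax.grad returns a new function computing d(chi_squared)/d(params).
# Transformations compose, so the gradient itself can be jit-compiled.
grad_chi_squared = jax.jit(jax.grad(chi_squared))

x = jnp.linspace(0.0, 1.0, 5)
y = 2.0 * x + 1.0  # data generated with a = 2, b = 1
sigma = jnp.ones_like(x)
trial = jnp.array([3.0, 1.0])  # deliberately wrong slope

print(chi_squared_jit(trial, x, y, sigma))
print(grad_chi_squared(trial, x, y, sigma))
print(jax.devices())  # the backend(s) JAX found: CPU, GPU, or TPU
```

For these inputs the residuals are simply -x, so chi-squared is sum(x**2) = 1.875 and the gradient is (2 * sum(x**2), 2 * sum(x)) = (3.75, 5.0), which is a quick way to check the auto-diff result by hand.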
This talk will take you on a journey through our experience optimizing a critical component of an astrophysics analysis pipeline: the calculation of the likelihood function for gravitational lensing models, used with data from the James Webb Space Telescope. We initially used Numba to accelerate this calculation, but the need for performance portability across hardware and the potential speed-up that gradients offer for model fitting led us to explore JAX's unique capabilities.
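As a hedged illustration of why gradients matter for model fitting, the sketch below pairs a Gaussian pixel likelihood with `jax.value_and_grad` and takes plain gradient-descent steps. The forward model is a hypothetical stand-in (a circular Gaussian blob rather than a ray-traced lens model), and the grid size, noise level, starting guess, and learning rate are arbitrary assumptions chosen only so the example runs quickly.

```python
import jax
import jax.numpy as jnp


def model_image(params, xx, yy):
    """Hypothetical toy forward model: a circular Gaussian blob.

    A real lens model would ray-trace a source through a deflection field;
    this stand-in only keeps the sketch short.
    """
    amp, x0, y0, width = params
    r2 = (xx - x0) ** 2 + (yy - y0) ** 2
    return amp * jnp.exp(-0.5 * r2 / width**2)


def neg_log_likelihood(params, data, noise_sigma, xx, yy):
    """Gaussian pixel likelihood: -ln L up to an additive constant."""
    resid = (data - model_image(params, xx, yy)) / noise_sigma
    return 0.5 * jnp.sum(resid**2)


# One compiled call returns both -ln L and its gradient w.r.t. params.
loss_and_grad = jax.jit(jax.value_and_grad(neg_log_likelihood))


@jax.jit
def gradient_step(params, data, noise_sigma, xx, yy, lr):
    """Plain gradient descent; a real analysis would use a proper optimizer or sampler."""
    loss, g = loss_and_grad(params, data, noise_sigma, xx, yy)
    return params - lr * g, loss


# Synthetic, noise-free data on a 32 x 32 pixel grid (assumed values).
grid = jnp.linspace(-1.0, 1.0, 32)
xx, yy = jnp.meshgrid(grid, grid)
true_params = jnp.array([1.0, 0.3, -0.2, 0.5])
data = model_image(true_params, xx, yy)
noise_sigma = 1.0

params = jnp.array([0.8, 0.2, -0.1, 0.6])  # deliberately offset starting guess
initial_loss, _ = loss_and_grad(params, data, noise_sigma, xx, yy)
for _ in range(200):
    params, _ = gradient_step(params, data, noise_sigma, xx, yy, 1e-4)
final_loss, _ = loss_and_grad(params, data, noise_sigma, xx, yy)
print(initial_loss, final_loss)
```

Because the gradient is derived from the same code as the likelihood, changing the forward model requires no new derivative code, and the same gradients can feed gradient-based optimizers or samplers.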
This session will walk through the practical steps, challenges, and insights gained from migrating this complex scientific code from its existing Numba implementation to a JAX-based one.
You will learn:
- Why leveraging performance tools like JITs is crucial for cutting-edge scientific analysis in Python.
- The practical considerations when migrating existing numerical code from Numba to JAX, including syntax changes and managing state.
- How JAX's auto-differentiation simplifies gradient calculations essential for scientific optimization and sampling tasks.
- The significant performance improvements achieved in our specific gravitational lensing case study by using JAX's compiled functions.
- Broader lessons learned about structuring scientific Python projects to effectively use modern JIT compilers and harness capabilities like auto-diff.
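One of the most visible syntax changes in a Numba-to-JAX port is that JAX arrays are immutable, so loop kernels that accumulate into a preallocated output are usually rewritten as array expressions, often with `jax.vmap`. The sketch below contrasts the two styles on a hypothetical toy kernel (a mass-weighted sum of softened log-distances, chosen for illustration only and not taken from the talk's code). The Numba-style version is left undecorated so it runs without Numba installed.

```python
import jax
import jax.numpy as jnp
import numpy as np


# Numba-style kernel: explicit loops accumulating into a preallocated array.
# In a Numba code base this would carry @numba.njit; it is left undecorated
# here so the sketch runs with plain NumPy.
def log_potential_loop(x, y, masses, xs, ys, eps):
    out = np.zeros_like(x)
    for i in range(x.shape[0]):  # loop over pixels
        for j in range(masses.shape[0]):  # loop over toy point masses
            dx = x[i] - xs[j]
            dy = y[i] - ys[j]
            out[i] += masses[j] * np.log(np.sqrt(dx * dx + dy * dy + eps * eps))
    return out


# JAX port: write the per-pixel result as an array expression over masses...
def log_potential_pixel(xi, yi, masses, xs, ys, eps):
    r = jnp.sqrt((xi - xs) ** 2 + (yi - ys) ** 2 + eps**2)
    return jnp.sum(masses * jnp.log(r))


# ...then vmap maps it over pixels (axis 0 of x and y only) and jit compiles it.
log_potential_jax = jax.jit(
    jax.vmap(log_potential_pixel, in_axes=(0, 0, None, None, None, None))
)

rng = np.random.default_rng(0)
x, y = rng.uniform(-1.0, 1.0, size=(2, 50))
masses = rng.uniform(0.5, 1.5, size=3)
xs, ys = rng.uniform(-1.0, 1.0, size=(2, 3))
reference = log_potential_loop(x, y, masses, xs, ys, 0.05)
ported = np.asarray(log_potential_jax(x, y, masses, xs, ys, 0.05))

# Where a single-element update is unavoidable, JAX's functional form returns
# a new array instead of mutating the old one.
a = jnp.zeros(3)
b = a.at[1].add(5.0)  # a is unchanged
```

With JAX's default single precision the two results agree to within float32 rounding, which is worth keeping in mind when validating a port against float64 Numba output.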
We'll conclude by comparing Numba and JAX on benchmark performance and developer ergonomics and by weighing the trade-offs between the two, giving you practical guidance for choosing the right tool for your scientific computing needs.
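When benchmarking, both tools compile on the first call, and JAX additionally dispatches work asynchronously, so a naive timer can end up measuring compilation, or only the dispatch. A minimal timing helper under those assumptions might look like this; `time_jitted` is a hypothetical name, and the function being timed is assumed to return a single JAX array.

```python
import time

import jax
import jax.numpy as jnp


def time_jitted(fn, *args, repeats=10):
    """Return (first_call_seconds, best_steady_state_seconds) for a jitted fn.

    The first call includes tracing and XLA compilation. JAX dispatches work
    asynchronously, so block_until_ready() is called before stopping the clock.
    """
    t0 = time.perf_counter()
    fn(*args).block_until_ready()
    first = time.perf_counter() - t0

    best = float("inf")
    for _ in range(repeats):
        t0 = time.perf_counter()
        fn(*args).block_until_ready()
        best = min(best, time.perf_counter() - t0)
    return first, best


f = jax.jit(lambda v: jnp.sum(jnp.sin(v) ** 2))
first, best = time_jitted(f, jnp.ones(1_000_000))
print(f"first call: {first:.4f} s, steady state: {best:.6f} s")
```

For a Numba `@njit` function the same idea applies without the blocking step: call it once to trigger compilation, then time subsequent calls.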
This case study offers a concrete example of how evolving Python libraries are enabling researchers to perform complex, high-performance computations directly within the Python ecosystem. Join us to see how tools like JAX are empowering scientific discoveries, one optimized function at a time.
This talk is suitable for intermediate Python programmers familiar with NumPy.
Intermediate
Kolen Cheung is a Research Software Engineer with a PhD in Physics (UC Berkeley), specializing in cosmic microwave background (CMB) data analysis, with a designated emphasis ("graduate minor") in Computational & Data Science & Engineering. Dr. Cheung has contributed to multiple CMB collaborations including POLARBEAR, LiteBIRD, and CMB-S4, and served as Software Deployment Lead at Simons Observatory. His current work spans gravitational lensing, quantum biological modeling, and HPC systems, combining theoretical physics with advanced computational methods.