JAX and Flax: Function Transformations and Neural Networks
2022-08-31 , Aula

Modern accelerators (graphics processing units and tensor processing units) enable high-performance computing at massive scale. JAX traces computation in Python programs written against the familiar NumPy API, and uses XLA to compile these programs so that they run efficiently on such accelerators. A set of composable function transformations makes it possible to express versatile scientific computing with an elegant syntax.
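A minimal sketch of the ideas in the paragraph above, assuming a recent `jax` install: a function written with `jax.numpy` is traced into an intermediate representation (printed with `jax.make_jaxpr`), and two transformations are composed, first `jax.grad` to differentiate it and then `jax.jit` to compile the gradient with XLA. The function `f` is purely illustrative.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Illustrative function written with the NumPy-like jax.numpy API
    return jnp.sum(jnp.sin(x) ** 2)

x = jnp.arange(3.0)

# Tracing: JAX records the primitive operations f applies to abstract values
print(jax.make_jaxpr(f)(x))

# Composition: differentiate f, then compile the resulting gradient function
df = jax.jit(jax.grad(f))
print(df(x))  # analytically, d/dx sin(x)^2 = 2 sin(x) cos(x) = sin(2x)
```

Because each transformation takes a function and returns a function, the order of composition is up to the user; `jax.grad(jax.jit(f))` would also work.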

Flax provides abstractions on top of JAX that make it easy to handle the weights and other state required to solve problems with neural networks.

This talk first presents the basic JAX API that allows for computing gradients, compiling functions, or vectorizing computation. It then proceeds to cover other parts of the JAX ecosystem commonly used for neural network programming, such as basic building blocks and optimizers.
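The three basic transformations named above can be combined in a few lines. The sketch below assumes only `jax`; the linear model, the synthetic data, the learning rate, and the plain gradient-descent update are illustrative assumptions, not material from the talk. In the wider ecosystem, libraries such as Optax package update rules like this one as reusable optimizers.

```python
import jax
import jax.numpy as jnp

def predict(params, x):
    # Illustrative linear model for a single example x of shape (d,)
    w, b = params
    return jnp.dot(w, x) + b

# vmap: vectorize the per-example function over the leading axis of x
batched_predict = jax.vmap(predict, in_axes=(None, 0))

def loss(params, xs, ys):
    # Mean squared error over the batch
    return jnp.mean((batched_predict(params, xs) - ys) ** 2)

@jax.jit  # compile the whole update step with XLA
def sgd_step(params, xs, ys, lr=0.1):
    grads = jax.grad(loss)(params, xs, ys)  # gradient w.r.t. the params pytree
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

# Synthetic, noise-free data from a known linear function
key = jax.random.PRNGKey(0)
xs = jax.random.normal(key, (64, 2))
w_true, b_true = jnp.array([2.0, -1.0]), 0.5
ys = xs @ w_true + b_true

params = (jnp.zeros(2), jnp.array(0.0))
for _ in range(200):
    params = sgd_step(params, xs, ys)
print(params)  # should approach (w_true, b_true)
```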


Expected audience expertise: Python

some

Domains

Machine Learning, Open Source Library, Parallel Computing / HPC

Expected audience expertise: Domain

some

Project Homepage / Git

https://jax.readthedocs.io/

Abstract as a tweet

JAX and Flax: Function Transformations and Neural Networks

After originally studying medicine, completing an MSc in bio-electronics, and earning an MD with the Swiss Tropical Health Institute, Andreas has been working at Google since 2015. His main focus there has been machine learning with TensorFlow, data mining, and the development of internal tools for data analysis.