After his original studies in medicine, an MSc in bio-electronics, and an MD with the Swiss Tropical Health Institute, Andreas has been working at Google since 2015. His main focus there has been on machine learning with TensorFlow, data mining, and the development of internal tools for data analysis.
Modern accelerators such as graphics processing units (GPUs) and tensor processing units (TPUs) enable high-performance computing at massive scale. JAX traces computations in Python programs written against the familiar NumPy API and uses XLA to compile them into programs that run efficiently on these accelerators. A set of composable function transformations makes it possible to express a wide range of scientific computations with an elegant syntax.
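As a rough illustration of tracing and composable transformations, the sketch below writes a scalar function with `jax.numpy`, prints the traced program (jaxpr) that JAX hands to XLA, and then stacks `grad`, `vmap`, and `jit` into a single compiled, batched derivative. The function `f` and the input values are arbitrary choices made for this example.

```python
import jax
import jax.numpy as jnp

def f(x):
    # A scalar function written with the NumPy-like jax.numpy API
    return jnp.sin(x) * x ** 2

# Trace f with an example input and show the resulting intermediate program
print(jax.make_jaxpr(f)(1.0))

# Compose transformations: derivative of f, mapped over a batch, compiled with XLA
df_batched = jax.jit(jax.vmap(jax.grad(f)))

xs = jnp.linspace(0.0, 2.0, 5)
print(df_batched(xs))  # equals cos(x) * x**2 + 2 * x * sin(x) elementwise
```

Each transformation takes a function and returns a new function, so they can be nested in any order that makes sense for the computation.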
Flax provides abstractions on top of JAX that make it easy to handle the weights and other state required to solve problems with neural networks.
This talk first presents the basic JAX API for computing gradients, compiling functions, and vectorizing computations. It then covers other parts of the JAX ecosystem commonly used for neural network programming, such as basic building blocks and optimizers.