Get Started with Variational Inference using Python
2019-09-05, 10:30–11:00, Track 3 (Oteiza)

The objective is to help the audience understand the mechanics of Variational Inference by implementing it in Python. We will build inference algorithms including Coordinate Ascent VI, Black Box VI and Automatic Differentiation VI using PyTorch.


The ability to estimate uncertainty and incorporate it into decision making is crucial in sensitive applications such as self-driving cars, autonomous or semi-autonomous surgery, industrial automation and personalised (precision) medicine, where the consequences of a wrong decision are potentially catastrophic. A probabilistic program is a natural way to express such probabilistic models. A Probabilistic Programming Language (PPL) combines programming and probabilistic modelling to enable stochastic computation.

Typically, a PPL consists of two primary modules: Model and Inference. The former lets the user define a probabilistic model; the latter performs posterior inference over the unknown variables, conditioned on data. The details of the inference procedure are abstracted away from the user, and this is by design.

There are two major families of algorithms for probabilistic inference. Markov Chain Monte Carlo (MCMC) is regarded as the gold standard and is the workhorse of inference in mature PPLs like PyMC. Although slow, MCMC produces samples that converge to the exact posterior over the unknowns. Variational Inference (VI), a more recent trend, is a faster but typically less accurate alternative to MCMC. Modern PPLs like Pyro (PyTorch backend) and Edward (TensorFlow backend) rely heavily on VI, a form of approximate inference.

This talk addresses the question "How do I get started with Variational Inference?", raised in a Reddit thread on /r/MachineLearning. Following in the footsteps of Allen B. Downey, the author of Think Bayes, we will understand the mechanics of Variational Inference by implementing it in Python. The talk is organised as follows:

  • Gaussian Mixture Model (GMM)
    • Theory
    • Implementation in PyTorch
    • Expectation Maximisation for clustering synthetic data
  • Variational Inference - Theory
    • Introduction
    • Mean-Field VI
  • Coordinate Ascent VI
    • ELBO derivation
    • Derivation of GMM Parameter Updates
    • Implementation in PyTorch
    • Inferring latent variables in GMM
  • Black Box VI
    • Theory
    • Implementation in Python
  • Automatic Differentiation VI
    • Theory
    • Implementation in Python
    • BBVI vs ADVI
  • Recent Advances / Limitations of Mean Field Exponential Families
    • Variance Reduction
    • Expressivity
    • Scalability
    • Generalizability

Gaussian Mixture Model

We use the Gaussian Mixture Model (GMM) as our running example for studying VI. A GMM models data as a mixture of Gaussian components. The latent variables of this model are the component membership of each data point (which component does it belong to?) and the mean and standard deviation of each Gaussian component. The component memberships are drawn from a uniform categorical distribution, while the means and standard deviations of the components are drawn from Gaussian priors. The objective is to infer the posterior distribution over these latent variables.
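To make the generative story concrete, here is a minimal sketch of sampling synthetic data from a 1-D GMM in plain Python (standard library only, unlike the talk's PyTorch code). As a simplifying assumption, every component shares a fixed standard deviation `component_sd` instead of drawing it from a prior; the function name `sample_gmm` and its parameters are hypothetical.

```python
import random

def sample_gmm(n, k, prior_sd=5.0, component_sd=1.0, seed=0):
    """Draw n points from a 1-D Gaussian mixture with k components.

    Assumed generative process (a simplified variant of the talk's model):
      mu_j ~ Normal(0, prior_sd)             for each component j
      c_i  ~ Categorical(1/k, ..., 1/k)      uniform component membership
      x_i  ~ Normal(mu_{c_i}, component_sd)  component_sd is fixed, not inferred
    """
    rng = random.Random(seed)
    means = [rng.gauss(0.0, prior_sd) for _ in range(k)]
    assignments = [rng.randrange(k) for _ in range(n)]
    data = [rng.gauss(means[c], component_sd) for c in assignments]
    return data, assignments, means

data, assignments, means = sample_gmm(n=500, k=3)
print(len(data), [round(mu, 2) for mu in means])
```

Inference then runs in the opposite direction: given only `data`, recover a distribution over `assignments` and `means`.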


Variational Inference

In VI, we consider a family of distributions Q (with "nice" properties) and search for the member of that family that minimises the KL-divergence to the exact posterior over the unknowns. VI trades accuracy for computation time.
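In symbols, writing x for the observed data and z for all latent variables (notation assumed here, not taken from the talk), VI solves the following optimisation problem. Note that the KL-divergence is taken from q to the posterior, and that its second term involves the exact posterior itself:

```latex
q^{*}(\mathbf{z}) = \operatorname*{arg\,min}_{q \in \mathcal{Q}} \; \mathrm{KL}\big(q(\mathbf{z}) \,\|\, p(\mathbf{z} \mid \mathbf{x})\big),
\qquad
\mathrm{KL}\big(q(\mathbf{z}) \,\|\, p(\mathbf{z} \mid \mathbf{x})\big) = \mathbb{E}_{q}\big[\log q(\mathbf{z})\big] - \mathbb{E}_{q}\big[\log p(\mathbf{z} \mid \mathbf{x})\big]
```

Because the second expectation requires the intractable posterior, this objective cannot be evaluated directly, which is what motivates optimising the ELBO instead.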

On what basis do we choose a family for our variational distribution? Ideally, its members have the following properties:

  • Exponential family (means and variances are easy to compute)
  • Low-dimensional factors (the distribution factorises into individual groups of variables)

We choose the mean-field variational family, which assumes that the latent variables are mutually independent. The obvious consequence of this choice is that we lose all information about the covariance between variables, since we assume they do not interact with each other.
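Written out, a mean-field distribution over M latent variables is a product of independent factors. For a GMM with K components and n data points, one common choice (assumed here for illustration, and ignoring the component standard deviations) gives each mean its own Gaussian factor with parameters m_k, s_k² and each assignment its own categorical factor with probabilities φ_i:

```latex
q(\mathbf{z}) = \prod_{j=1}^{M} q_j(z_j),
\qquad
q(\boldsymbol{\mu}, \mathbf{c}) = \prod_{k=1}^{K} \mathcal{N}\big(\mu_k \mid m_k, s_k^2\big) \prod_{i=1}^{n} \mathrm{Categorical}\big(c_i \mid \boldsymbol{\varphi}_i\big)
```

Under any such q, Cov_q(z_i, z_j) = 0 for i ≠ j, which is exactly the covariance information the mean-field assumption discards.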

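We optimise the variational parameters (the parameters of the latent factors) by minimising the KL-divergence between the exact posterior p and our approximation q. Rearranging the expression for the KL-divergence gives the Evidence Lower Bound (ELBO). Because the log evidence log p(x) is fixed and equals the ELBO plus the KL-divergence, increasing the ELBO is equivalent to reducing the KL-divergence. Unlike the KL-divergence, the ELBO can be evaluated without the exact posterior: it connects the variational density only to the data and the model.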
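The identity behind this is log p(x) = ELBO(q) + KL(q(z) ‖ p(z | x)), where ELBO(q) = E_q[log p(x, z)] − E_q[log q(z)]. The sketch below checks it numerically on a toy conjugate model where every term has a closed form; the model (z ~ N(0, 1), x_i | z ~ N(z, 1)), the data values and all function names are assumptions for illustration, not part of the talk.

```python
import math

LOG_2PI = math.log(2 * math.pi)

def normal_logpdf(v, mean, var):
    return -0.5 * (LOG_2PI + math.log(var) + (v - mean) ** 2 / var)

def exact_posterior(xs):
    # z ~ N(0, 1), x_i | z ~ N(z, 1)  =>  z | x ~ N(sum(x) / (n + 1), 1 / (n + 1))
    n = len(xs)
    return sum(xs) / (n + 1), 1.0 / (n + 1)

def log_evidence(xs):
    # Bayes' rule rearranged: log p(x) = log p(x | z) + log p(z) - log p(z | x), for any z
    post_mean, post_var = exact_posterior(xs)
    z = 0.0
    log_lik = sum(normal_logpdf(x, z, 1.0) for x in xs)
    return log_lik + normal_logpdf(z, 0.0, 1.0) - normal_logpdf(z, post_mean, post_var)

def elbo(xs, m, s2):
    # Closed-form ELBO for q(z) = N(m, s2): E_q[log p(z)] + E_q[log p(x | z)] + entropy of q
    expected_log_prior = -0.5 * (LOG_2PI + m ** 2 + s2)
    expected_log_lik = sum(-0.5 * (LOG_2PI + (x - m) ** 2 + s2) for x in xs)
    entropy = 0.5 * (LOG_2PI + 1.0 + math.log(s2))
    return expected_log_prior + expected_log_lik + entropy

def kl_normal(m, s2, mu, var):
    # KL( N(m, s2) || N(mu, var) )
    return 0.5 * (math.log(var / s2) + (s2 + (m - mu) ** 2) / var - 1.0)

xs = [0.8, 1.3, 0.4, 1.9, 1.1]
post_mean, post_var = exact_posterior(xs)
for m, s2 in [(0.0, 1.0), (1.0, 0.5), (post_mean, post_var)]:
    gap = log_evidence(xs) - elbo(xs, m, s2)
    kl = kl_normal(m, s2, post_mean, post_var)
    print(f"q = N({m:.3f}, {s2:.3f}): ELBO = {elbo(xs, m, s2):.4f}, gap = {gap:.6f}, KL = {kl:.6f}")
```

For every q the gap between log p(x) and the ELBO equals the KL-divergence, and both vanish when q is the exact posterior.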

In this talk, we will explore three algorithms that build on these assumptions to estimate the variational parameters, exploiting PyTorch's tensor manipulation and automatic differentiation capabilities along the way.

Coordinate Ascent Variational Inference (CAVI)

CAVI maximises the ELBO by iteratively optimising each variational parameter while holding the others fixed. The expression for the ELBO is highly model-specific, so every model demands a manual derivation of the ELBO. On top of that, the closed-form update for each parameter (the value that maximises the ELBO with all other parameters held fixed) must also be derived by hand.
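Here is a minimal CAVI sketch for a simplified 1-D GMM, assuming unit-variance components and a Gaussian prior with known variance `prior_var` on each mean. This is the standard worked example from the VI literature (e.g. Blei, Kucukelbir and McAuliffe, "Variational Inference: A Review for Statisticians"), written in plain Python; it is not the talk's PyTorch implementation, and `cavi_gmm` is a hypothetical name. Each inner loop applies the closed-form update for one family of factors while the others are held fixed.

```python
import math
import random

def cavi_gmm(xs, k, prior_var=25.0, iters=100):
    """CAVI for a 1-D GMM with unit-variance components.

    Model:  mu_j ~ N(0, prior_var),  c_i ~ Categorical(1/k),  x_i ~ N(mu_{c_i}, 1)
    Family: q(mu_j) = N(m[j], s2[j]),  q(c_i) = Categorical(phi[i])
    """
    ordered = sorted(xs)
    # Heuristic initialisation: place the k means at evenly spaced data quantiles
    m = [ordered[int(len(xs) * (2 * j + 1) / (2 * k))] for j in range(k)]
    s2 = [1.0] * k
    phi = [[1.0 / k] * k for _ in xs]
    for _ in range(iters):
        # Update every q(c_i) with q(mu) fixed:
        # phi_ij proportional to exp(E[mu_j] * x_i - E[mu_j^2] / 2)
        for i, x in enumerate(xs):
            logits = [m[j] * x - 0.5 * (m[j] ** 2 + s2[j]) for j in range(k)]
            top = max(logits)
            w = [math.exp(v - top) for v in logits]  # subtract the max for stability
            total = sum(w)
            phi[i] = [wj / total for wj in w]
        # Update every q(mu_j) with q(c) fixed: a conjugate Gaussian update
        for j in range(k):
            s2[j] = 1.0 / (1.0 / prior_var + sum(p[j] for p in phi))
            m[j] = s2[j] * sum(p[j] * x for p, x in zip(phi, xs))
    return m, s2, phi

rng = random.Random(1)
xs = [rng.gauss(mu, 1.0) for mu in (-6.0, 0.0, 6.0) for _ in range(100)]
m, s2, phi = cavi_gmm(xs, k=3)
print([round(v, 2) for v in sorted(m)])
```

CAVI only reaches a local optimum of the ELBO, so the quantile initialisation is a convenience; in practice one would run from several initialisations and keep the run with the highest ELBO.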

Black Box Variational Inference (BBVI)

The ELBO expression and custom updates in CAVI are painstaking to derive. Black Box Variational Inference (BBVI) instead uses sampling to approximate the gradient of the ELBO for stochastic optimisation. The key idea behind BBVI is that the gradient of the ELBO can be written as an expectation with respect to the variational distribution q. Instead of deriving this gradient in closed form, we average Monte Carlo samples of the quantity inside the expectation to obtain a noisy but unbiased estimate of it.
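The expectation in question is the score-function form ∇_λ ELBO(λ) = E_q[∇_λ log q(z; λ) (log p(x, z) − log q(z; λ))], which needs only samples from q, its score, and pointwise evaluations of log p(x, z). The sketch below applies it to the toy conjugate model z ~ N(0, 1), x_i | z ~ N(z, 1) (an assumption, chosen because the exact posterior is known) with q(z) = N(m, exp(log_s)²). The running-average baseline is a simple variance-reduction trick added here, not necessarily what the talk uses, and `bbvi` is a hypothetical name.

```python
import math
import random

LOG_2PI = math.log(2 * math.pi)

def normal_logpdf(v, mean, var):
    return -0.5 * (LOG_2PI + math.log(var) + (v - mean) ** 2 / var)

def log_joint(z, xs):
    # log p(x, z) for the toy model z ~ N(0, 1), x_i | z ~ N(z, 1)
    return normal_logpdf(z, 0.0, 1.0) + sum(normal_logpdf(x, z, 1.0) for x in xs)

def bbvi(xs, steps=1500, num_samples=50, lr=0.05, seed=0):
    """Score-function BBVI for q(z) = N(m, exp(log_s)^2); returns (m, variance)."""
    rng = random.Random(seed)
    m, log_s = 0.0, 0.0
    baseline = 0.0
    for t in range(steps):
        s2 = math.exp(2 * log_s)
        zs = [rng.gauss(m, math.sqrt(s2)) for _ in range(num_samples)]
        fs = [log_joint(z, xs) - normal_logpdf(z, m, s2) for z in zs]
        # Monte Carlo average of score * (log p(x, z) - log q(z) - baseline)
        # d log q / d m = (z - m) / s2,  d log q / d log_s = (z - m)^2 / s2 - 1
        grad_m = sum((z - m) / s2 * (f - baseline) for z, f in zip(zs, fs)) / num_samples
        grad_log_s = sum(((z - m) ** 2 / s2 - 1.0) * (f - baseline)
                         for z, f in zip(zs, fs)) / num_samples
        # The baseline only uses past samples, so the estimate stays unbiased
        baseline = 0.9 * baseline + 0.1 * sum(fs) / num_samples
        rate = lr / (1.0 + t / 200)  # decaying step size (Robbins-Monro style)
        m += rate * grad_m           # gradient ascent on the ELBO
        log_s += rate * grad_log_s
    return m, math.exp(2 * log_s)

xs = [0.8, 1.3, 0.4, 1.9, 1.1]
m, s2 = bbvi(xs)
print(round(m, 3), round(s2, 3))  # exact posterior is N(5.5 / 6, 1 / 6)
```

Nothing in the loop is specific to this model except `log_joint`, which is what makes the method "black box".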

Automatic Differentiation Variational Inference (ADVI)

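Automatic Differentiation Variational Inference (ADVI) derives an efficient variational inference algorithm automatically. It transforms constrained latent variables onto the unconstrained real space, posits a Gaussian variational distribution there, and uses the reparameterisation trick so that automatic differentiation libraries such as PyTorch or autograd can compute the derivatives of the ELBO with respect to the variational parameters.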
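A minimal PyTorch sketch of these steps, assuming a toy model that is not the talk's GMM: a positive precision tau ~ Gamma(a, b) for zero-mean Gaussian data. ADVI works with zeta = log(tau), adds the log-Jacobian of the transform to the log joint, and lets autograd differentiate the reparameterised Monte Carlo ELBO. The function `advi_precision`, the learning-rate schedule and all defaults are illustrative choices.

```python
import math
import torch
from torch.distributions import Gamma, Normal

def advi_precision(x, a=1.0, b=1.0, steps=2000, num_samples=16, lr=0.05, seed=0):
    """ADVI sketch: infer the precision tau of zero-mean Gaussian data x.

    Model: tau ~ Gamma(a, rate=b),  x_i ~ N(0, sd = tau ** -0.5)
    ADVI:  zeta = log(tau) is unconstrained;  q(zeta) = N(mu, exp(omega) ** 2)
    """
    torch.manual_seed(seed)
    mu = torch.zeros((), requires_grad=True)
    omega = torch.zeros((), requires_grad=True)  # log standard deviation of q
    opt = torch.optim.Adam([mu, omega], lr=lr)
    # Decaying step size so the stochastic iterates settle down
    sched = torch.optim.lr_scheduler.LambdaLR(opt, lambda t: 1.0 / (1.0 + t / 100))
    prior = Gamma(torch.tensor(a), torch.tensor(b))
    for _ in range(steps):
        opt.zero_grad()
        eps = torch.randn(num_samples)
        zeta = mu + torch.exp(omega) * eps   # reparameterisation trick
        tau = torch.exp(zeta)                # map back to the constrained space
        log_lik = Normal(0.0, tau.rsqrt().unsqueeze(-1)).log_prob(x).sum(-1)
        log_jac = zeta                       # log |d tau / d zeta|
        # ELBO = E_q[log p(x, tau) + log-Jacobian] + entropy (omega, up to a constant)
        elbo = (log_lik + prior.log_prob(tau) + log_jac).mean() + omega
        (-elbo).backward()                   # autograd computes the ELBO gradient
        opt.step()
        sched.step()
    return mu.item(), torch.exp(omega).item()

torch.manual_seed(1)
x = 0.5 * torch.randn(200)                   # synthetic data with true precision 4
mu, s = advi_precision(x)
alpha = 1.0 + x.numel() / 2                  # exact Gamma posterior shape (prior a = 1)
beta = 1.0 + (x ** 2).sum().item() / 2       # exact Gamma posterior rate (prior b = 1)
print("ADVI E_q[tau]:", math.exp(mu + s ** 2 / 2), "exact E[tau | x]:", alpha / beta)
```

The conjugate Gamma prior is chosen only so that the exact posterior Gamma(a + n/2, b + Σx²/2) is available for comparison; the ADVI loop itself never uses it.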


Finally, we briefly explore the limitations of approximating the posterior with mean-field exponential-family distributions, and address fundamental issues in the emerging field of Bayesian Deep Learning: scalability (amortised VI), expressivity (normalising flows), variance reduction and generalizability.

The code is available at suriyadeepan/BayesianML/gmm.


Abstract as a tweet – Learn the mechanics of Variational Inference by implementing it in Python (CAVI, BBVI, ADVI).
Domains – Machine Learning, Statistics
Domain Expertise – some
Python Skill Level – professional
Project Homepage / Git – https://github.com/suriyadeepan/BayesianML/tree/master/gmm