JuliaCon 2022

ImplicitDifferentiation.jl: differentiating implicit functions
2022-07-29, Blue

We present a Julia package for differentiating through functions that are defined implicitly. It can be used to compute derivatives for a wide array of "black box" procedures, from optimization algorithms to fixed-point iterations and systems of nonlinear equations.
Since it mostly relies on defining custom chain rules, our code is lightweight and integrates nicely with Julia's automatic differentiation and machine learning ecosystem.


Introduction

Differentiable programming is a core ingredient of modern machine learning, and it is one of the areas where Julia truly shines. By defining new kinds of differentiable layers, we can hope to increase the expressivity of deep learning pipelines without having to scale up the number of parameters.

For instance, in structured prediction settings, domain knowledge can be encoded into optimization problems of many flavors: linear, quadratic, conic, nonlinear or even combinatorial. In domain adaptation, differentiable distances based on optimal transport are often computed using the Sinkhorn fixed-point iteration. Last but not least, in differential-equation-constrained optimization and neural differential equations, one often needs the derivatives of solutions to nonlinear systems of equations with respect to their parameters.

Note that these complex functions are all defined implicitly, through a condition that their output must satisfy. As a consequence, differentiating their output (e.g. the minimizer of an optimization problem) with respect to their input (e.g. the cost vector or constraint matrix) calls for an automated application of the implicit function theorem.
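
For reference, here is the step that needs automating, stated under the usual assumptions (F continuously differentiable, with a partial Jacobian in y that is invertible at the point considered). Differentiating the condition F(x, y(x)) = 0 with respect to x gives:

```latex
\partial_x F\bigl(x, y(x)\bigr) + \partial_y F\bigl(x, y(x)\bigr)\,\partial y(x) = 0
\quad\Longrightarrow\quad
\partial y(x) = -\Bigl[\partial_y F\bigl(x, y(x)\bigr)\Bigr]^{-1}\,\partial_x F\bigl(x, y(x)\bigr)
```

The Jacobian of y therefore depends only on the output y(x) and on partial derivatives of F, not on the procedure used to compute y(x).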

Related works

When trying to differentiate through iterative procedures, unrolling the loop is a natural approach. However, it is computationally demanding (in reverse mode, the intermediate iterates typically have to be stored), and it only works for pure Julia code with no external "black box" calls. On the other hand, the implicit function theorem lets us decouple the derivative from the procedure that computes the function: only the condition satisfied by the output needs to be differentiated. See "Efficient and Modular Implicit Differentiation" (Blondel et al.) for an overview of the related theory.
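
To make the contrast concrete, here is a minimal scalar sketch in plain Julia (it does not use ImplicitDifferentiation.jl; the map g and all function names are hypothetical). Unrolling propagates the derivative through every iteration, whereas implicit differentiation ignores the iterations and applies the implicit function theorem once, to the condition F(x, y) = g(x, y) - y = 0 at the solution.

```julia
# Hypothetical example: y(x) is the fixed point of g(x, y) = cos(y)/2 + x,
# so it satisfies the condition F(x, y) = g(x, y) - y = 0.
g(x, y) = cos(y) / 2 + x
dg_dx(x, y) = 1.0
dg_dy(x, y) = -sin(y) / 2

# Unrolling: carry dy/dx alongside every iteration (forward mode by hand),
# i.e. differentiate the loop itself.
function fixed_point_unrolled(x; iterations=100)
    y, dy = 0.0, 0.0
    for _ in 1:iterations
        # Both right-hand sides are evaluated with the previous iterate.
        y, dy = g(x, y), dg_dx(x, y) + dg_dy(x, y) * dy
    end
    return y, dy
end

# Implicit differentiation: run the loop without tracking derivatives, then
# apply the implicit function theorem once at the solution:
# dy/dx = -(∂F/∂y)^(-1) * ∂F/∂x, with ∂F/∂y = dg_dy - 1 and ∂F/∂x = dg_dx.
function fixed_point_implicit(x; iterations=100)
    y = 0.0
    for _ in 1:iterations
        y = g(x, y)
    end
    dy = -dg_dx(x, y) / (dg_dy(x, y) - 1)
    return y, dy
end

y_unrolled, dy_unrolled = fixed_point_unrolled(0.3)
y_implicit, dy_implicit = fixed_point_implicit(0.3)
```

Both functions return the same derivative here, but only the implicit version would keep working if the loop were replaced by a call to an external solver that Julia cannot differentiate.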

In the last few years, this implicit differentiation paradigm has given rise to several Python libraries such as OpenMDAO, cvxpylayers and JAXopt. In Julia, the most advanced one is DiffOpt.jl, which allows the user to differentiate through a JuMP.jl optimization model. A more generic approach was recently explored in NonconvexUtils.jl: our goal with ImplicitDifferentiation.jl is to make it more efficient, more reliable and easy to use for everyone.

Package content

Our package provides a simple toolbox that can differentiate through any kind of user-specified function x -> y(x). The only requirement is that its output be characterized by a condition of the form F(x, y(x)) = 0.
Beyond the generic machinery of implicit differentiation, we also include several tutorials covering use cases such as unconstrained and constrained optimization, fixed-point algorithms and systems of nonlinear equations.
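
As an illustration of the optimization use case, here is a minimal sketch in plain Julia (it does not use ImplicitDifferentiation.jl; the ridge regression problem and the names ridge_forward, ridge_conditions and ridge_derivative are hypothetical). The forward pass is a "black box" gradient descent, the condition is the first-order optimality condition (the gradient in y vanishes), and the derivative with respect to the regularization strength follows from the implicit function theorem.

```julia
using LinearAlgebra

# Hypothetical use case (ridge regression):
# y(x) = argmin_y ½‖A*y - b‖² + (x/2)‖y‖², with the regularization x as input.
A = [2.0 0.0; 1.0 1.0; 0.0 2.0]
b = [1.0, 2.0, 3.0]

# "Black box" forward pass: plain gradient descent; any solver would do.
function ridge_forward(x; iterations=1_000)
    y = zeros(size(A, 2))
    step = 1 / (opnorm(A)^2 + x)  # inverse Lipschitz constant of the gradient
    for _ in 1:iterations
        y -= step * (A' * (A * y - b) + x * y)
    end
    return y
end

# Optimality condition: the gradient of the objective vanishes, F(x, y) = 0.
ridge_conditions(x, y) = A' * (A * y - b) + x * y

# Implicit function theorem: ∂F/∂y = A'A + xI and ∂F/∂x = y,
# hence dy/dx = -(A'A + xI)^(-1) * y. The solver loop is never differentiated.
function ridge_derivative(x)
    y = ridge_forward(x)
    return -((A' * A + x * I) \ y)
end
```

For instance, ridge_derivative(0.5) should match a central finite difference of ridge_forward around 0.5, even though gradient descent itself is treated as opaque.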

Technical details

The central construct of our package is a wrapper of the form

struct ImplicitFunction{O,C}
    forward::O
    conditions::C
end

where forward computes the mapping x -> y(x), while conditions corresponds to (x, y) -> F(x, y). By defining custom pushforwards and pullbacks, we ensure that ImplicitFunction objects can be used with any ChainRules.jl-compatible automatic differentiation backend, in forward or reverse mode.
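
The following is a toy, dependency-free sketch of what such a pushforward and pullback compute; it is not the package's implementation. The type ToyImplicitFunction, its extra fields jac_x and jac_y, and the functions pushforward and pullback are hypothetical: the package differentiates conditions automatically, whereas this sketch assumes the user supplies the partial Jacobians of F.

```julia
using LinearAlgebra

# Hypothetical toy version of the wrapper above. The package differentiates
# `conditions` automatically; this sketch instead assumes the user supplies
# the partial Jacobians ∂F/∂x and ∂F/∂y as functions of (x, y).
struct ToyImplicitFunction{O,C,JX,JY}
    forward::O     # x -> y(x), may be any black box
    conditions::C  # (x, y) -> F(x, y), kept to mirror the original struct
    jac_x::JX      # (x, y) -> ∂F/∂x (assumption: user-supplied)
    jac_y::JY      # (x, y) -> ∂F/∂y (assumption: user-supplied)
end

# Calling the wrapper just runs the forward procedure.
(f::ToyImplicitFunction)(x) = f.forward(x)

# Pushforward (JVP): solve ∂F/∂y * dy = -∂F/∂x * dx for dy.
function pushforward(f::ToyImplicitFunction, x, dx)
    y = f(x)
    dy = -(f.jac_y(x, y) \ (f.jac_x(x, y) * dx))
    return y, dy
end

# Pullback (VJP): solve (∂F/∂y)' * u = ybar, then xbar = -(∂F/∂x)' * u.
function pullback(f::ToyImplicitFunction, x, ybar)
    y = f(x)
    u = f.jac_y(x, y)' \ ybar
    xbar = -(f.jac_x(x, y)' * u)
    return y, xbar
end

# Example: elementwise square root, characterized by F(x, y) = y.^2 - x = 0.
sqrt_implicit = ToyImplicitFunction(
    x -> sqrt.(x),
    (x, y) -> y .^ 2 .- x,
    (x, y) -> -I,                # ∂F/∂x, as a lazy uniform scaling
    (x, y) -> Diagonal(2 .* y),  # ∂F/∂y
)

x = [1.0, 4.0, 9.0]
y, dy = pushforward(sqrt_implicit, x, ones(3))
```

Here dy equals 1 ./ (2 .* sqrt.(x)), as expected for the square root. With ChainRules.jl, such rules are normally registered as frule and rrule methods so that compatible backends pick them up; they are plain functions here to keep the sketch self-contained, and the linear systems are solved directly with backslash for simplicity.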

To attain maximum efficiency, we never actually store a full Jacobian matrix: we only work with vector-Jacobian and Jacobian-vector products. Consequently, to solve the linear systems that implicit differentiation requires, we rely on iterative Krylov subspace methods, which only access the matrix through such products and can therefore handle lazy linear operators.
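
To illustrate what handling a lazy linear operator means, here is a minimal conjugate gradient sketch that sees the matrix only through a product function. It is a stand-in, not the solver used by the package: conjugate gradient assumes a symmetric positive definite system (as for a Hessian at a strict minimum), whereas general implicit systems call for non-symmetric Krylov methods such as GMRES. The names conjugate_gradient and hessian_vector_product are hypothetical.

```julia
using LinearAlgebra

# Minimal conjugate gradient (CG) solver that accesses the matrix only through
# `apply_op(v)`, which returns the product M * v. Assumption: M is symmetric
# positive definite, e.g. ∂F/∂y when F is the gradient of a strongly convex objective.
function conjugate_gradient(apply_op, rhs; tol=1e-10, maxiter=10 * length(rhs))
    sol = zero(rhs)
    r = copy(rhs)        # residual rhs - M * sol, with sol = 0 initially
    p = copy(r)          # current search direction
    rs = dot(r, r)
    for _ in 1:maxiter
        sqrt(rs) < tol && break
        Mp = apply_op(p)
        alpha = rs / dot(p, Mp)
        sol .+= alpha .* p
        r .-= alpha .* Mp
        rs_new = dot(r, r)
        p .= r .+ (rs_new / rs) .* p
        rs = rs_new
    end
    return sol
end

# Lazy operator: product with A'A + reg * I, computed without ever forming A'A.
A_data = [2.0 0.0; 1.0 1.0; 0.0 2.0]
reg = 0.5
hessian_vector_product(v) = A_data' * (A_data * v) + reg * v

w = [1.0, -2.0]
v = conjugate_gradient(hessian_vector_product, w)
```

Only matrix-vector products are ever computed, so the same solver would work if hessian_vector_product were itself obtained from automatic differentiation instead of an explicit matrix.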

Scientist at PumasAI Inc.

PhD student at École des Ponts (France), working on machine learning and operations research with applications to railway planning.
