JuliaCon 2022 (Times are UTC)

Training Spiking Neural Networks in pure Julia
07-28, 17:30–17:40 (UTC), Blue

Training artificial neural networks to recapitulate the dynamics of biological neuronal recordings has become a prominent tool for understanding computations in the brain. We present an implementation of a recursive least-squares algorithm to train units in a recurrent spiking network. Our code can reproduce the activity of 50,000 neurons of a mouse performing a decision-making task in less than an hour of training time. It can scale to a million neurons on a GPU with 80 GB of memory.

Spiking networks that operate in a fluctuation-driven regime are a common way to model brain activity. Individual neurons within a spiking artificial neural network are trained to reproduce the spiking activity of individual neurons in the brain. In doing so, they capture the structured activity patterns of the recorded neurons, as well as the spiking irregularities and the trial-to-trial variability. Such trained networks can be analyzed in silico to gain insights into the dynamics and connectivity of the cortical circuits underlying the recorded neural activity, insights that would otherwise be difficult to obtain in vivo.

The number of simultaneously recorded neurons in behaving animals has been increasing exponentially over the last few years. It is now possible to simultaneously record from about 1000 neurons using electrophysiology in behaving animals, and up to 100,000 using calcium imaging. When several recording sessions are combined, the amount of data becomes huge and could grow to millions of recorded neurons in the next few years. There is therefore a need for fast algorithms and code bases to train networks of spiking neurons on ever larger data sets.

Here we use a recursive least-squares training algorithm (RLS; also known as FORCE; Sussillo and Abbott, 2009), adapted for spiking networks (Kim and Chow, 2018, 2021), which uses an online estimate of the inverse covariance matrix between connected neurons to update the strength of plastic synapses. We make the code more performant through a combination of data parallelism, use of BLAS, symmetric packed arrays, reduced storage precision, and refactoring for GPUs.
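As a rough illustration of the update described above, the following is a minimal sketch of the textbook RLS step for a single neuron, written in plain Python for readability. It is not the optimized implementation presented in the talk. The names `rls_step` and `demo`, the regularization value of 1.0, and the uniform random inputs standing in for filtered presynaptic spike trains are all assumptions made for this example.

```python
import random

def rls_step(w, P, r, target):
    """One textbook RLS update for a single read-out (here: one neuron).

    w      : plastic weights onto the neuron (updated in place)
    P      : running estimate of the inverse correlation matrix of the
             presynaptic activity (symmetric, updated in place)
    r      : presynaptic activity at this time step
    target : desired synaptic current at this time step
    Returns the error measured before the update.
    """
    n = len(w)
    k = [sum(P[i][j] * r[j] for j in range(n)) for i in range(n)]  # k = P r
    denom = 1.0 + sum(r[i] * k[i] for i in range(n))               # 1 + r' P r
    err = sum(w[i] * r[i] for i in range(n)) - target
    for i in range(n):
        for j in range(n):
            P[i][j] -= k[i] * k[j] / denom   # symmetric rank-one downdate of P
        w[i] -= err * k[i] / denom           # equals err * (updated P) r
    return err

def demo(n=5, steps=500, seed=0):
    """Hypothetical toy problem: recover fixed weights from random inputs."""
    rng = random.Random(seed)
    w_true = [rng.uniform(-1.0, 1.0) for _ in range(n)]
    w = [0.0] * n
    # Assumed initialization: P = I / lambda with lambda = 1.0
    P = [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]
    errors = []
    for _ in range(steps):
        r = [rng.uniform(0.0, 1.0) for _ in range(n)]
        target = sum(a * b for a, b in zip(w_true, r))
        errors.append(abs(rls_step(w, P, r, target)))
    return w_true, w, errors
```

Because each neuron keeps its own `w` and `P`, the step can be applied to all neurons independently. The matrix-vector product and the rank-one update of `P` are the two operations that dominate the cost.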

Our goal is to train the synaptic current input to each neuron such that the resulting spikes follow the target activity pattern over a time interval. We use a leaky integrate-and-fire neuron model with current-based synapses. The peri-stimulus time histograms of the spike trains are converted to the equivalent target synaptic currents using the transfer function of the neuron model. We treat every neuron's synaptic current as a read-out, which makes our task equivalent to training a recurrently connected read-out for each neuron. Since a neuron's synaptic current can be expressed as a weighted sum of the spiking activities of its presynaptic neurons, we adjust the strength of the incoming synaptic connections with the RLS algorithm in order to generate the target activity.
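To make the rate-to-current conversion concrete, here is a minimal sketch that uses the noise-free transfer function of a leaky integrate-and-fire neuron driven by a constant current. This is a deliberate simplification: the abstract does not say which transfer function is used, and a fluctuation-driven network would normally call for one that accounts for input noise. The parameter values (10 ms membrane time constant, threshold 1, reset 0, no refractory period) and all function names are assumptions chosen for illustration.

```python
import math

TAU_M = 0.010    # membrane time constant in seconds (assumed value)
V_TH = 1.0       # spike threshold (assumed, arbitrary units)
V_RESET = 0.0    # reset potential (assumed, arbitrary units)

def lif_rate(current):
    """Firing rate (Hz) of a noise-free LIF neuron driven by a constant current."""
    if current <= V_TH:
        return 0.0   # the membrane potential never reaches threshold
    return 1.0 / (TAU_M * math.log((current - V_RESET) / (current - V_TH)))

def rate_to_current(rate):
    """Invert lif_rate: the constant current that yields `rate` (Hz, > 0)."""
    x = 1.0 / (rate * TAU_M)   # inter-spike interval in units of TAU_M
    return (V_TH - V_RESET * math.exp(-x)) / (1.0 - math.exp(-x))

def simulate_lif(current, duration=1.0, dt=1e-5):
    """Count spikes of the same neuron using forward-Euler integration."""
    v = V_RESET
    spikes = 0
    for _ in range(int(round(duration / dt))):
        v += dt * (current - v) / TAU_M
        if v >= V_TH:
            spikes += 1
            v = V_RESET
    return spikes
```

Applying `rate_to_current` to each bin of a peri-stimulus time histogram would give a target current trace under these assumptions, and `simulate_lif` offers a quick check that the inverted current drives the model at roughly the requested rate.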

This training scheme allows us to set up independent objective functions for each neuron and to update them in parallel. A CPU version of the algorithm partitions the neurons onto threads and uses the standard BLAS libraries to perform the matrix operations. As the vast majority of memory is consumed by the inverse covariance matrix, larger models can be accommodated by reducing precision for all state variables and using a packed symmetric matrix for the covariance (see SymmetricFormats.jl for the SymmetricPacked type definition). These memory optimizations also make the code faster. For the GPU version, custom batched BLAS kernels were written for packed symmetric matrices (see BatchedBLAS.jl for the batched_spmv! and batched_spr! functions).
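The idea behind packed symmetric storage can be sketched as follows. The layout shown is the standard BLAS upper-triangle, column-major packed format. Whether SymmetricPacked uses exactly this layout is an assumption here, and the pure-Python functions below only mirror what the BLAS routines spmv (symmetric packed matrix-vector product) and spr (symmetric packed rank-one update) compute. They are not the APIs of SymmetricFormats.jl or BatchedBLAS.jl.

```python
def packed_index(i, j):
    """0-based position of element (i, j) in upper-triangle, column-major packing."""
    if i > j:
        i, j = j, i              # symmetry: (i, j) and (j, i) share one slot
    return i + j * (j + 1) // 2

def pack_upper(A):
    """Pack a symmetric n x n matrix into n * (n + 1) / 2 numbers."""
    n = len(A)
    return [A[i][j] for j in range(n) for i in range(j + 1)]

def spmv(n, ap, x):
    """y = A x, where A is given in packed form `ap`."""
    y = [0.0] * n
    for j in range(n):
        for i in range(j + 1):
            a = ap[i + j * (j + 1) // 2]
            y[i] += a * x[j]
            if i != j:
                y[j] += a * x[i]   # mirrored lower-triangle contribution
    return y

def spr(n, alpha, x, ap):
    """In-place rank-one update A += alpha * x x' on the packed form `ap`."""
    for j in range(n):
        for i in range(j + 1):
            ap[i + j * (j + 1) // 2] += alpha * x[i] * x[j]
```

Packing stores n(n + 1)/2 values instead of n², which roughly halves the memory used by each neuron's inverse covariance matrix. The matrix-vector product and rank-one update are the two operations an RLS step needs on that matrix.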

We benchmarked on synthetic targets consisting of sinusoids with identical frequencies and random phases. For a model with one million neurons, 512 static connections per neuron, and 45 plastic connections per neuron, the CPU code took 1260 seconds per training iteration on a 48-core Intel machine and the GPU code took 48 seconds on an Nvidia A100. For this connectivity pattern, one million is the largest number of neurons (within a factor of two) that could fit in the 80 GB of GPU memory. The CPU machine, with 768 GB of RAM, accommodated four million neurons with this connectivity.

We also tested our algorithm's ability to learn real target functions using 50,000 neurons recorded in five different brain regions from a mouse performing a decision-making task. The recording intervals were 3 s long and spike rates averaged 7 Hz. Replacing the static connections with random Gaussian noise and using 256 plastic connections, the model achieved a correlation of 0.8 between the desired and learned currents in 30 minutes of training time.

Our work enables one to train spiking recurrent networks to reproduce the spiking activity of huge data sets of recorded neurons in a reasonable amount of time. By doing so, it facilitates analyzing the relations between connectivity patterns, network dynamics and brain functions in networks of networks in the brain. We also introduce two new Julia packages to better support packed symmetric matrices.