Honey, I broke the PyTorch model >.< - Debugging custom PyTorch models in a structured manner
2023-04-17, B05-B06

When building PyTorch models for custom applications from scratch, there's usually one problem: the model does not learn anything. In a complex project, it can be tricky to identify the cause. Is it the data? A bug in the model? Choosing the wrong loss function at 3 am after an 8-hour coding session?

In this talk, we will build a toolbox to find the culprits in a structured manner. We will focus on simple ways to verify that a training loop is correct, to generate synthetic training data that tells us whether we are dealing with a model bug or with problematic real-world data, and to leverage pytest for safely refactoring PyTorch models.
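A minimal sketch of the synthetic-data idea follows. It assumes a classification model that maps tensors of shape (batch, channels, length) to class logits. The helper `make_synthetic_dataset` and the tiny `SimpleNet` are hypothetical stand-ins, not code from the talk. Each class adds a distinct constant offset to low-amplitude noise, so any working model should separate the classes easily. If the model fails on data this simple, the bug is most likely in the model or training code rather than in the real-world data.

```python
import torch
from torch import nn


def make_synthetic_dataset(n_samples, n_classes, input_shape, noise=0.1, seed=0):
    """Hypothetical helper: noise plus a class-dependent offset, so labels are trivially learnable."""
    gen = torch.Generator().manual_seed(seed)
    labels = torch.randint(0, n_classes, (n_samples,), generator=gen)
    x = noise * torch.randn((n_samples, *input_shape), generator=gen)
    # Broadcast each label as a constant offset over every non-batch dimension.
    offset = labels.float().view(-1, *([1] * len(input_shape)))
    return x + offset, labels


class SimpleNet(nn.Module):
    """Hypothetical stand-in for the model under test: (batch, channels, length) -> logits."""

    def __init__(self, in_channels, n_classes):
        super().__init__()
        self.conv = nn.Conv1d(in_channels, 16, kernel_size=3, padding=1)
        self.head = nn.Linear(16, n_classes)

    def forward(self, x):
        h = torch.relu(self.conv(x))
        return self.head(h.mean(dim=-1))  # average over the length (time) axis


def synthetic_accuracy(model, n_classes=4, input_shape=(3, 32), steps=300):
    """Train on the synthetic data and return training accuracy."""
    x, y = make_synthetic_dataset(256, n_classes, input_shape)
    opt = torch.optim.Adam(model.parameters(), lr=1e-2)
    loss_fn = nn.CrossEntropyLoss()
    model.train()
    for _ in range(steps):
        opt.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        opt.step()
    model.eval()
    with torch.no_grad():
        return (model(x).argmax(dim=1) == y).float().mean().item()


torch.manual_seed(0)  # makes the model initialization reproducible
acc = synthetic_accuracy(SimpleNet(in_channels=3, n_classes=4))
print(f"accuracy on synthetic data: {acc:.2f}")
```

Suppose the same model reaches high accuracy here but stalls on the real dataset. In that case, the data (labels, preprocessing, dataset size) becomes the prime suspect.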

After this talk, visitors will be well equipped to take the right steps when a model is not learning, quickly identify the underlying reasons, and prevent bugs in the future.


PyTorch models for off-the-shelf applications are easy to build and debug. In real-world ML applications, however, debugging can become quite tricky, especially when model complexity is high and only noisy data is available.

When our DNN is not learning, many factors can be at fault:
- Is there a bug in the model structure, for example mixed-up channels or timesteps?
- Is our dataset too small or too heterogeneous to learn anything? Have we mixed up labels during preprocessing?
- Have we picked an unsuitable loss function, accidentally skipped layers, or used inappropriate activation functions?
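Bugs of the first kind can often be caught automatically with a gradient-based check, sketched here as pytest-style tests. The approach assumes the model takes a batch-first input. Backpropagate from a single sample's output. If any *other* sample in the batch receives a nonzero gradient, a reshape, permute, or layer configuration has mixed the batch axis with timesteps or channels.

`SequenceModel` is a hypothetical toy model. Its `batch_first=False` variant reproduces a classic bug: `nn.GRU` then reads the batch axis as time. Layers that legitimately couple samples, such as BatchNorm in training mode, would also fail this check, so the model is switched to eval mode first.

```python
import torch
from torch import nn


class SequenceModel(nn.Module):
    """Hypothetical model under test: (batch, time, features) -> (batch, time, hidden)."""

    def __init__(self, features=3, hidden=8, batch_first=True):
        super().__init__()
        # With batch_first=False, nn.GRU expects (time, batch, features) instead.
        self.rnn = nn.GRU(features, hidden, batch_first=batch_first)

    def forward(self, x):
        out, _ = self.rnn(x)
        return out


def check_batch_independence(model, batch_size=4, seq_len=5, features=3):
    """Return True if every sample's output depends only on that sample's input."""
    model.eval()  # layers such as BatchNorm legitimately couple samples in train mode
    x = torch.randn(batch_size, seq_len, features, requires_grad=True)
    out = model(x)
    for i in range(batch_size):
        # Gradient of sample i's output with respect to the whole input batch.
        (grad,) = torch.autograd.grad(out[i].sum(), x, retain_graph=True)
        others = torch.ones(batch_size, dtype=torch.bool)
        others[i] = False
        if grad[others].abs().sum().item() > 0:
            return False  # sample i's output "sees" other samples
    return True


def test_sequence_model_keeps_samples_independent():
    torch.manual_seed(0)
    assert check_batch_independence(SequenceModel(batch_first=True))


def test_check_flags_batch_first_bug():
    # Without batch_first=True, the GRU treats the batch axis as time and mixes samples.
    torch.manual_seed(0)
    assert not check_batch_independence(SequenceModel(batch_first=False))
```

The check loops over all samples rather than probing only one. With the `batch_first=False` bug, sample 0's output still depends only on sample 0, so a single-sample probe at index 0 would miss it.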

The plethora of potential reasons can be overwhelming to engineers. This talk will introduce a structured approach and valuable tools for efficiently debugging PyTorch models.
We'll start with techniques to verify that a training loop is correct, such as making sure our model can overfit a single training example. Next, we'll investigate how to generate simple synthetic data for arbitrary input and output formats to validate our model. Finally, we'll look at how to avoid model bugs altogether by setting up universal tests that can be used during development and refactoring to prevent breaking PyTorch models.
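A minimal sketch of the single-example overfitting check follows. It assumes a standard classification setup with `nn.CrossEntropyLoss`. The toy model, input shape, learning rate, and step count are illustrative choices, not values from the talk. If the loop (zero_grad, forward, loss, backward, step) is wired correctly, a model with enough capacity should drive the loss on one fixed example close to zero. If it does not, the problem lies in the model or the loop, not in the dataset.

```python
import torch
from torch import nn


def overfit_single_example(model, x, y, steps=300, lr=1e-2):
    """Train repeatedly on one (x, y) pair and return the first and last recorded loss."""
    loss_fn = nn.CrossEntropyLoss()
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    model.train()
    losses = []
    for _ in range(steps):
        opt.zero_grad()  # forgetting this silently accumulates gradients across steps
        loss = loss_fn(model(x), y)
        loss.backward()
        opt.step()
        losses.append(loss.item())
    return losses[0], losses[-1]


torch.manual_seed(0)
# Hypothetical toy model and a single example (batch dimension of 1).
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 32), nn.ReLU(), nn.Linear(32, 10))
x = torch.randn(1, 3, 8, 8)
y = torch.tensor([7])
first, last = overfit_single_example(model, x, y)
print(f"loss: {first:.3f} -> {last:.5f}")
```

A related quick check: with ten classes and small initial logits, the first loss should be close to ln(10) ≈ 2.30. A much larger value often points to a badly scaled output layer.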


Abstract as a tweet

Honey, I broke the PyTorch model >.< No problem! In this talk, we'll build a toolbox to debug our models and prevent this from happening again, all by leveraging DL logic, synthetic data and pytest. Let's make our models unbreakable <3

Expected audience expertise: Domain

Advanced

Expected audience expertise: Python

Intermediate

I'm a former ML Engineer in the geospatial domain and currently a Ph.D. student in trustworthy ML and Data Science at the RC Trust Ruhr. My main field of interest is Computer Vision, and my guilty pleasure is assigning probability densities to all relevant variables in CV models. Application-wise, I focus on Remote Sensing (Synthetic Aperture Radar) and Neuroscience (modeling trajectories of disease severity from MRI scans). I've also picked up a splash of experience in autonomous driving.
I'm classically trained in Bayesian Statistics and interested in combining Bayesian approaches with self-supervised learning and deterministic DNNs.