2024-07-11 – For Loop (3.2)
Large Language Models (LLMs) have become ubiquitous across many domains. So far, however, the Julia ecosystem has lacked an efficient way to train these models, hindering the adoption of Julia by ML practitioners and users. In this talk we demonstrate parallel training of LLMs using Dagger.jl and Flux.jl. We discuss the various components required to write an efficient parallel training routine. Further, we present the scaling performance achievable with this method and discuss future developments.
Training large neural networks has been a long-standing challenge in the Julia ecosystem: it requires infrastructure that can scale training across many GPUs simultaneously. In this talk we demonstrate the pieces required to train LLMs such as Llama2 in Julia. Using Dagger.jl, we show scaling to multiple GPUs while training a Flux.jl model.
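To give a flavour of the approach, the sketch below is a minimal, illustrative example (not the exact code presented in the talk): it data-parallelizes one training step by spawning per-shard gradient computations as Dagger tasks and averaging the results before a single optimizer update. The toy model, shard sizes, and the `shard_grad` helper are placeholders.

```julia
using Dagger, Flux

# Toy stand-in for an LLM; a real run would use a Transformer model.
model = Chain(Dense(128 => 256, relu), Dense(256 => 128))
opt_state = Flux.setup(Adam(1f-3), model)

# Simple squared-error loss, just for the sketch.
loss(m, x, y) = sum(abs2, m(x) .- y)

# Gradient of the loss for one data shard (returns the model's gradient tree).
shard_grad(m, x, y) = Flux.gradient(m_ -> loss(m_, x, y), m)[1]

# Split a batch into shards; each gradient becomes a Dagger task that the
# scheduler is free to place on different threads, workers, or GPUs.
shards = [(rand(Float32, 128, 32), rand(Float32, 128, 32)) for _ in 1:4]
tasks  = [Dagger.@spawn shard_grad(model, x, y) for (x, y) in shards]
grads  = fetch.(tasks)

# Average the per-shard gradients leaf-by-leaf (skipping `nothing` leaves)
# and apply a single optimizer update.
avg = Flux.fmap((gs...) -> first(gs) === nothing ? nothing : sum(gs) ./ length(gs),
                grads...)
Flux.update!(opt_state, model, avg)
```

In an actual multi-GPU run one would additionally pin tasks to devices (e.g. via Dagger's scope mechanism) and keep model replicas and data device-resident; the sketch leaves placement entirely to the scheduler.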
This approach has broad applicability: because the underlying implementation is agnostic to the model being trained, it can be reused to parallelize training across a wide variety of models, helping solve current and emerging problems faster than before.
In practice, LLMs aren't trained from scratch. Pre-trained models are typically loaded and then fine-tuned on specialised tasks with techniques such as Low-Rank Adaptation (LoRA). This greatly reduces the training time required to adapt a model to a specific task.
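As a rough illustration of the LoRA idea (again, not the talk's code), the sketch below wraps a frozen Flux `Dense` layer with a trainable low-rank update W + (α/r)·B·A. The layer name, rank, and scaling are illustrative choices, and the `Flux.@layer ... trainable=...` form assumes a recent Flux release.

```julia
using Flux

# Wrap a frozen, pretrained Dense layer with a low-rank update: W + (α/r)·B·A.
struct LoRADense{D,M}
    base::D       # pretrained layer, kept frozen
    A::M          # rank × in, small random init
    B::M          # out × rank, zero init, so the correction starts at zero
    scale::Float32
end

# Only the adapter matrices A and B receive gradients; the base layer is frozen.
Flux.@layer LoRADense trainable=(A, B)

function LoRADense(base::Dense; rank::Integer=8, alpha::Real=16)
    d_out, d_in = size(base.weight)
    LoRADense(base,
              0.01f0 .* randn(Float32, rank, d_in),
              zeros(Float32, d_out, rank),
              Float32(alpha / rank))
end

# Forward pass: frozen base output plus the scaled low-rank correction.
(l::LoRADense)(x) = l.base(x) .+ l.scale .* (l.B * (l.A * x))
```

To fine-tune, one could swap such adapters into a pretrained model, e.g. `Chain(map(l -> l isa Dense ? LoRADense(l) : l, model.layers)...)`, so that the optimizer only updates the small A and B matrices while the pretrained weights stay fixed.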
We will present the various components of the pipeline, from the model and data description to setting up the parallel training job. We will report the training-time scaling achieved and the size of models we could accommodate with this approach, walk through the various challenges encountered along the way, and discuss how we envision scaling large GPU workloads with the Julia ecosystem in the future.
I'm a Research Software Engineer at MIT's JuliaLab, working on parallel programming with Dagger.jl and AMDGPU.jl. I love working on low-level runtimes and compilers, as well as building out high-level, user-friendly parallel programming interfaces.