Each month, we break down a key ML topic with clarity and engaging visuals. Subscribe for in-depth insights and stay at the forefront of Machine Learning and MLOps.
⛽ How GPU Memory Hierarchy Fuels the Idea Behind FlashAttention
FlashAttention Part Three: Understanding the GPU memory hierarchy and how each component can be used to optimize performance.
Dimitris Poulopoulos
May 8th
GPU memory hierarchy: faster than your company’s organization chart!
In our last chapter, we delved into the attention mechanism—today's superstar in the world of Deep Learning. We now have a basic understanding of how attention works, so, before we explore the various types of attention mechanisms, let's circle back to this month's topic: FlashAttention.
However, in this story we won’t be talking about FlashAttention directly. Instead, we’ll take a detour to understand a bit about GPU architecture and, specifically, the device’s memory hierarchy. Solidifying this theoretical framework is crucial for getting to grips with FlashAttention.
When you hear “hierarchy”, you might think of the army's rigid structure or the complex layers of a multinational company. However, when it comes to GPUs, we're dealing with a different kind of hierarchy altogether. Like most system devices, GPUs have a memory hierarchy determined by the distance between the memory chip and the computation unit—what we call cores in modern CPUs, or, in the GPU case, the SM (Streaming Multiprocessor).
Typically, the further you are from the computation unit, the slower the access speeds; on the other hand, capacity increases and cost per byte drops. In this chapter, we'll dissect the GPU memory hierarchy to show why understanding it is crucial for mastering FlashAttention. I’m certain that a few hints at the end will leave you wondering: why did no one think of FlashAttention sooner?
Let’s begin!
As we touched on in the prologue, GPUs—much like other systems—feature a memory hierarchy. We’ll start our exploration with the type of storage that sits closest to the execution unit. First stop on this silicon road-trip: the registers!
Local Storage
Each execution thread operates within its own logically isolated local storage space, where the most critical resource is the registers. These registers, embedded within the GPU, are high-speed storage units that act as both the starting point and the destination for nearly all low-level machine instructions. Essentially, when the GPU crunches numbers, it pulls its input from these registers and sends the output right back to them.
Registers are the fastest memory resource, but as a programmer you typically have no say in how they are used during execution; they are largely managed by the compiler. So, there isn’t much optimization we can do at this level. Let’s move on to something we can influence, and should take advantage of.
Shared Memory (SRAM) / L1 Cache
The shared memory and the L1 cache are both integral parts of the GPU die and do not exist off-chip. Let’s start with the things you can control!
Shared memory is a local memory array, directly accessible and explicitly allocatable by programmers for their temporary data stash needs. Think of it as a box that can hold up to 48KB per thread block (maybe a bit more in modern architectures), where you can deposit small items that you can’t hold in your hands but also don’t want to move to the attic, because you’ll need them again in a minute or two.
Shared memory is distinct in that it is a per SM (Streaming Multiprocessor) resource and possesses a dual nature:
It facilitates inter-thread communication within a block: Multiple workers (threads) may be working in a factory (SM). Everyone is tasked with doing the exact same job, using different resources (input data). However, each worker can share their work (output data) by placing the output of their effort in a common chest (shared memory), where other workers can pick it up.
It aims to reduce redundant global memory accesses and optimize the corresponding data access patterns: Imagine you store every small item, even your forks and spoons, in the attic. Now, every time you’d like to eat, you’d have to open the attic door, climb up the ladder, open the box labeled “cutlery”, pick up a fork, climb back down, eat your meal, wash your fork, and follow the same process to store it back in the attic. It’s like a workout you didn’t sign up for. Since you will need the fork again in a few hours, it’s better to keep it in the kitchen drawer. That’s what programmers do with shared memory: they use it to store data they want to use in the next operation (see the sketch after this list).
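If you’d like to see what this looks like in practice, here is a minimal sketch of a block-wide sum written with Numba’s CUDA interface. To be clear, this is my own illustrative example, not code from FlashAttention: it assumes a CUDA-capable GPU with numba and numpy installed, and the kernel name and sizes are arbitrary. Each thread reads one value from global memory, the whole block cooperates through shared memory, and only a single result travels back to HBM.

```python
import numpy as np
from numba import cuda, float32

BLOCK = 256  # threads per block (arbitrary choice for this sketch)

@cuda.jit
def block_sum(x, out):
    # Explicitly allocated shared memory: the per-block "kitchen drawer".
    tile = cuda.shared.array(shape=BLOCK, dtype=float32)

    tid = cuda.threadIdx.x
    gid = cuda.grid(1)

    # One read from global memory (HBM) per thread, staged on-chip.
    if gid < x.shape[0]:
        tile[tid] = x[gid]
    else:
        tile[tid] = 0.0
    cuda.syncthreads()

    # Tree reduction done entirely in shared memory: threads exchange
    # partial sums through the "common chest" without touching HBM again.
    stride = BLOCK // 2
    while stride > 0:
        if tid < stride:
            tile[tid] += tile[tid + stride]
        cuda.syncthreads()
        stride //= 2

    # A single write back to HBM per block.
    if tid == 0:
        out[cuda.blockIdx.x] = tile[0]

x = np.ones(1 << 20, dtype=np.float32)
d_x = cuda.to_device(x)                                    # system RAM -> HBM
d_out = cuda.device_array(x.size // BLOCK, dtype=np.float32)
block_sum[x.size // BLOCK, BLOCK](d_x, d_out)
print(d_out.copy_to_host()[0])                             # 256.0: one block's worth of ones
```

Without the shared-memory tile, every partial sum would have to bounce through global memory; with it, the block does all of its chatting on-chip.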
On the other hand, life gets a bit easier for programmers when dealing with the L1 cache, and that’s because the L1 cache is on auto-pilot! Unlike the hands-on approach needed for shared memory, the L1 cache operates behind the scenes. It automatically holds on to recently used data, in the hope that it will be reused soon. It’s like having a smart assistant who keeps your files on the desk because you’ll probably ask for them in a minute.
L2 Cache
The L2 cache is like the L1 cache’s bigger sibling with a broader responsibility. While the L1 cache sticks close to its home base in each SM (Streaming Multiprocessor), the L2 cache has the run of the entire device.
It's the central hub of data traffic: think of it as the main bus station in a city. All data must pass through this station before heading back to the vast countryside of global memory. This device-wide role means that if any piece of data needs to travel, it first checks in with the L2 cache, so the same data doesn't have to be fetched from global memory over and over.
Global Memory (HBM)
Global memory is like the distant warehouse of the GPU world—it's where everything starts when you're gearing up for computation. It is the first stop for any CUDA programmer needing to transfer data from the system memory to the GPU's processing powerhouse. Since it's accessible by all threads and the host, it's the communal pool everyone dips into.
However, it's not without its drawbacks. Global memory has relatively high latency, meaning it’s a bit slow on the draw. Requesting data from this memory is like ordering a package online; you click 'order' and then wait patiently (or not so patiently) for it to arrive ready for use. This can be a snag for memory-bound algorithms that are as impatient as a kid on Christmas morning. But here's the silver lining: compared to the system's memory, global memory is still pretty speedy. It’s like getting a package from within your own country, as opposed to system memory's international shipping scenario, where your order has to clear customs.
Additionally, while global memory boasts higher throughput than your average system memory, it still moves like a snail compared to other memory types within the GPU. It's like having a fast car in a world where everyone else is driving supercars.
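In code, that first leg of the journey is usually the very first (and very last) thing a GPU program does. Here is a tiny, purely illustrative sketch, again using Numba (my choice of tool, not something prescribed by the article):

```python
import numpy as np
from numba import cuda

x_host = np.random.rand(1 << 24).astype(np.float32)  # lives in system RAM

d_x = cuda.to_device(x_host)   # system RAM -> global memory (HBM): the "international shipping" leg
# ... any kernels launched here read and write d_x directly in global memory ...
x_back = d_x.copy_to_host()    # global memory (HBM) -> system RAM: the return trip
```

Everything a kernel ever touches has to make it into global memory first; the interesting question, as we’ll see, is how often we force data to make that trip again once it’s already on the device.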
Why Should You Care?
Armed with this knowledge, you might be wondering, "Why should I care about all this if I just want to understand how FlashAttention works?" Great question! All the concepts we've discussed form the very core of FlashAttention. The distinctions between shared memory and global memory aren't just technical details; they're fundamental to grasping why FlashAttention is so much faster than the vanilla attention algorithm. If you've been following along, you're probably starting to see the bigger picture.
If you still don’t, I’m sure the following picture of how the vanilla attention algorithm flows through memory will make everything clear:
Notice how each operation sequence starts with loading data from HBM and ends with the phrase “write back to HBM”? This loop is a significant bottleneck. As we discussed before, loading data from HBM is similar to ordering items online. Similarly, writing back to HBM can be compared to sending items back using the post office. Just as you might dread the delays and potential hassles of shipping items back and forth, the GPU faces efficiency challenges with these repetitive trips to and from the HBM.
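In code, that flow looks roughly like the sketch below: a plain numpy rendering of the standard attention computation (the shapes and names are my own example, not taken from the article’s diagram), with comments marking where a naive GPU implementation would read from or write back to HBM.

```python
import numpy as np

def vanilla_attention(Q, K, V):
    d = Q.shape[-1]

    # 1. Load Q and K from HBM, compute S, write the N x N matrix S back to HBM.
    S = Q @ K.T / np.sqrt(d)

    # 2. Load S from HBM, compute the row-wise softmax P, write P back to HBM.
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    P /= P.sum(axis=-1, keepdims=True)

    # 3. Load P and V from HBM, compute the output O, write O back to HBM.
    return P @ V

N, d = 1024, 64  # sequence length and head dimension (arbitrary for this sketch)
Q, K, V = (np.random.randn(N, d).astype(np.float32) for _ in range(3))
O = vanilla_attention(Q, K, V)
print(O.shape)   # (1024, 64), but S and P were full 1024 x 1024 round trips through HBM
```

The two N x N intermediates, S and P, are exactly the packages that keep getting shipped back and forth, and they grow quadratically with the sequence length.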
So, what can we do? This will be the focus of our next—and final—story on FlashAttention. But if you’ve been paying close attention, you might already have a hint of what’s coming!
Photo by Rock'n Roll Monkey on Unsplash