Understanding Flash Attention (Forward) with CUDA
Flash Attention is an IO-aware attention algorithm used in transformers: instead of materializing the full attention score matrix in global memory, it processes the inputs in tiles that fit in on-chip memory, making attention faster and far less memory-intensive. This post breaks down the forward pass of Flash Attention as implemented in CUDA.
Overview of the Algorithm
In the forward kernel, we split the Q (query), K (key), and V (value) matrices into blocks (tiles) so that each piece fits in on-chip memory. The essential steps, sketched in code right after this list, are:
- Divide Q, K, and V into blocks along the sequence dimension.
- Load tiles of the query, key, and value matrices into shared memory (SRAM).
- Compute the attention scores for each tile and apply a numerically stable (online) softmax.
- Accumulate the output tile by tile.
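Before diving into the kernel, here is a structural sketch of how those steps fit together. It is commented pseudocode rather than the real kernel: in the implementation below, part of this looping is handled by the CUDA grid.
// Structural sketch only (N = sequence_length, d = embedding_dimension,
// Br = block_size_rows, Bc = block_size_columns).
for (int j = 0; j < N / Bc; j++) {        // loop over key/value column tiles
    // load K_j and V_j into shared memory
    for (int i = 0; i < N / Br; i++) {    // loop over query row tiles
        // load Q_i into shared memory
        // S = (Q_i * K_j^T) * softmax_scale            -> Br x Bc scores
        // update the running row max m_i and row sum l_i (online softmax)
        // O_i = rescale(previous O_i) + softmax(S) * V_j
    }
}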
Memory Allocation and Initialization
Before performing computations, we allocate and initialize device memory for the matrices:
#include <cuda_runtime.h>
#include <limits>
#include <vector>

template <typename T>
T* allocateAndInitializeDeviceMemory(size_t size, bool initializeToZero = false,
                                     bool initializeToNegativeInfinity = false) {
    T* device_ptr;
    cudaMalloc(&device_ptr, size); // Allocate 'size' bytes on the device
    if (initializeToZero) {
        cudaMemset(device_ptr, 0, size); // Byte-wise zero fill
    } else if (initializeToNegativeInfinity) {
        // cudaMemset repeats a single byte, so it cannot produce the -INFINITY
        // bit pattern; fill a host buffer and copy it to the device instead.
        std::vector<T> host_buffer(size / sizeof(T), -std::numeric_limits<T>::infinity());
        cudaMemcpy(device_ptr, host_buffer.data(), size, cudaMemcpyHostToDevice);
    } else {
        // Otherwise fill with random numbers (see the helper sketched below).
    }
    return device_ptr;
}
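The random branch is left empty in the listing above. A minimal host-side fill that could slot into that branch is sketched below; fillDeviceWithRandom is a hypothetical helper name introduced here, and the repo may use a different initialization scheme.
#include <cuda_runtime.h>
#include <random>
#include <vector>

// Hypothetical helper: fill a device buffer with uniform random floats generated
// on the host, then copy it over. 'count' is the number of elements, not bytes.
void fillDeviceWithRandom(float* device_ptr, size_t count, unsigned seed = 42) {
    std::mt19937 gen(seed);
    std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
    std::vector<float> host_buffer(count);
    for (size_t i = 0; i < count; ++i) {
        host_buffer[i] = dist(gen);
    }
    cudaMemcpy(device_ptr, host_buffer.data(), count * sizeof(float), cudaMemcpyHostToDevice);
}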
CUDA Kernel: forward_kernel
The main CUDA kernel, forward_kernel, handles the computation.
Thread and Block Indexing
Each thread block works on one query row tile for a specific (batch, head) pair; within the block, each thread computes the output for a single query row of that tile.
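Given the launch configuration used later (grid x = query row tiles, y = heads, z = batch elements; one-dimensional thread blocks), the index bookkeeping at the top of the kernel can be sketched as follows. The variable names mirror the rest of the post; the exact offset arithmetic in the real kernel may differ.
// Inside forward_kernel (sketch).
int thread_index_x = threadIdx.x;   // query row within the current row tile
int row_block_index = blockIdx.x;   // which query row tile
int head_index      = blockIdx.y;   // which attention head
int batch_index     = blockIdx.z;   // which batch element

// Offset of this (batch, head) slice inside the flattened
// [batch, head, sequence, embedding] tensors Q, K, V, and O.
int qkv_offset = (batch_index * gridDim.y + head_index)
               * sequence_length * embedding_dimension;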
SRAM for Q, K, V, and Scores
We carve the dynamically allocated shared memory into tiles for Q, K, V, and the score matrix S, where each tile holds block_size_columns * embedding_dimension floats:
extern __shared__ float shared_memory[];
int tile_size = block_size_columns * embedding_dimension;
float* query_matrix_tile = shared_memory;
float* key_matrix_tile = &shared_memory[tile_size];
float* value_matrix_tile = &shared_memory[tile_size * 2];
float* score_matrix_tile = &shared_memory[tile_size * 3]; // holds the Q·K^T scores used below
Loading Matrices into SRAM
We load tiles of K and V from global memory into shared memory:
for (int embedding_index = 0; embedding_index < embedding_dimension; embedding_index++) {
    key_matrix_tile[(thread_index_x * embedding_dimension) + embedding_index] =
        key_matrix_device_pointer[qkv_offset + ...];
    value_matrix_tile[(thread_index_x * embedding_dimension) + embedding_index] =
        value_matrix_device_pointer[qkv_offset + ...];
}
__syncthreads(); // Ensure all threads have completed loading
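The query tile for this block's rows is loaded with the same pattern. The following is only a sketch: it assumes the row-tile offset is row_block_index * block_size_rows * embedding_dimension, whereas the elided offsets above are spelled out in the actual kernel.
// Sketch: each thread copies its own query row of the current row tile.
for (int embedding_index = 0; embedding_index < embedding_dimension; embedding_index++) {
    query_matrix_tile[(thread_index_x * embedding_dimension) + embedding_index] =
        query_matrix_device_pointer[qkv_offset
                                    + (row_block_index * block_size_rows * embedding_dimension)
                                    + (thread_index_x * embedding_dimension)
                                    + embedding_index];
}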
Computing Attention Scores
For each query row handled by a thread, we take the dot product of that query with every key in the current tile and scale the result:
for (int column_index_inner = 0; column_index_inner < block_size_columns; column_index_inner++) {
    float sum = 0;
    for (int embedding_index = 0; embedding_index < embedding_dimension; embedding_index++) {
        sum += query_matrix_tile[(thread_index_x * embedding_dimension) + embedding_index] *
               key_matrix_tile[(column_index_inner * embedding_dimension) + embedding_index];
    }
    score_matrix_tile[(block_size_columns * thread_index_x) + column_index_inner] = sum * softmax_scale;
}
Softmax Calculation
For numerical stability we first find the maximum score in the row, then exponentiate each shifted score and accumulate the row sum:
float row_max = -INFINITY;
for (int column_index_inner = 0; column_index_inner < block_size_columns; column_index_inner++) {
    row_max = fmaxf(row_max, score_matrix_tile[(block_size_columns * thread_index_x) + column_index_inner]);
}
float row_sum = 0;
for (int column_index_inner = 0; column_index_inner < block_size_columns; column_index_inner++) {
    score_matrix_tile[(block_size_columns * thread_index_x) + column_index_inner] =
        __expf(score_matrix_tile[(block_size_columns * thread_index_x) + column_index_inner] - row_max);
    row_sum += score_matrix_tile[(block_size_columns * thread_index_x) + column_index_inner];
}
Calculating Outputs
Compute the output by combining the (unnormalized) probabilities with the value tile; the elided write merges this partial result with the previously accumulated output using the running statistics, as sketched after the loop:
for (int embedding_index = 0; embedding_index < embedding_dimension; embedding_index++) {
    float probability_times_value = 0;
    for (int column_index_inner = 0; column_index_inner < block_size_columns; column_index_inner++) {
        probability_times_value += score_matrix_tile[(block_size_columns * thread_index_x) + column_index_inner] *
                                   value_matrix_tile[(column_index_inner * embedding_dimension) + embedding_index];
    }
    output_matrix_device_pointer[qkv_offset + ...] = ...
}
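The elided write is where FlashAttention's online softmax bookkeeping happens. Below is a hedged sketch of the standard update, assuming row_max_previous and row_sum_previous were read from the max and sum matrices for this thread's row before the loop, and using output_index and row_index as placeholders for the elided index expressions.
// Once per row (before the embedding loop): combine this tile's statistics
// with the running ones read from the max and sum matrices.
float row_max_new = fmaxf(row_max_previous, row_max);
float row_sum_new = (__expf(row_max_previous - row_max_new) * row_sum_previous) +
                    (__expf(row_max - row_max_new) * row_sum);

// Inside the embedding loop: rescale the previously accumulated output and
// add this tile's contribution.
output_matrix_device_pointer[output_index] =
    (1.0f / row_sum_new) *
    ((row_sum_previous * __expf(row_max_previous - row_max_new) *
      output_matrix_device_pointer[output_index]) +
     (__expf(row_max - row_max_new) * probability_times_value));

// After the embedding loop: persist the running statistics for the next tile.
max_matrix_device_pointer[row_index] = row_max_new;
sum_matrix_device_pointer[row_index] = row_sum_new;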
Synchronization and Final Steps
At the end of each tile iteration we synchronize again, so that no thread starts loading the next K and V tiles into shared memory while others are still reading the current ones:
__syncthreads(); // Ensure all computations are complete before proceeding
Implementing the main Function for the CUDA Kernel
The main function is the heart of the CUDA program, where we set up the problem, allocate memory, and launch the kernel. Below, we walk through the steps involved in this function.
1. Problem Setup
We start by defining the key parameters for our task:
- batch_size, num_heads, sequence_length, and embedding_dimension define the input dimensions.
- block_size_columns and block_size_rows control how many key/value columns and query rows go into each tile (and therefore how threads are grouped within blocks).
- softmax_scale scales the scores before the softmax (1 / sqrt(embedding_dimension)).
int batch_size = 2;
int num_heads = 8;
int sequence_length = 64;
int embedding_dimension = 32;
int block_size_columns = 8;
int block_size_rows = 8;
float softmax_scale = 1.0f / std::sqrt(embedding_dimension);
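From these parameters the tile counts follow directly; they are the two extra values passed to the kernel at launch (the variable names here are only for clarity, the launch below computes them inline):
// Number of key/value (column) tiles and query (row) tiles along the sequence.
int total_blocks_columns = sequence_length / block_size_columns; // 64 / 8 = 8
int total_blocks_rows    = sequence_length / block_size_rows;    // 64 / 8 = 8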
2. Memory Allocation and Initialization
We allocate and initialize device memory for the query, key, and value matrices, for the sum and max matrices that hold the running softmax statistics, and for the output matrix.
size_t matrix_size = batch_size * num_heads * sequence_length * embedding_dimension * sizeof(float);
float* query_matrix_device_pointer = allocateAndInitializeDeviceMemory<float>(matrix_size);
float* key_matrix_device_pointer = allocateAndInitializeDeviceMemory<float>(matrix_size);
float* value_matrix_device_pointer = allocateAndInitializeDeviceMemory<float>(matrix_size);
float* sum_matrix_device_pointer = allocateAndInitializeDeviceMemory<float>(
batch_size * num_heads * sequence_length * sizeof(float), true);
float* max_matrix_device_pointer = allocateAndInitializeDeviceMemory<float>(
batch_size * num_heads * sequence_length * sizeof(float), false, true);
float* output_matrix_device_pointer = allocateAndInitializeDeviceMemory<float>(matrix_size, true);
3. Grid and Block Configuration
CUDA requires defining the grid and block configuration. Here we use a 3D grid (query row tiles, heads, batches) and 1D thread blocks with one thread per query row of a tile.
dim3 block(block_size_rows);
dim3 grid(sequence_length / block_size_rows, num_heads, batch_size);
size_t shared_memory_size = 4 * block_size_columns * embedding_dimension * sizeof(float);
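The size reserves room for four tiles (Q, K, V, and the score matrix). It is worth verifying that this fits within the device's per-block shared memory limit; a small sanity check, assuming device 0, could look like this:
// Query the device's per-block shared memory limit and compare.
int max_shared_memory_per_block = 0;
cudaDeviceGetAttribute(&max_shared_memory_per_block,
                       cudaDevAttrMaxSharedMemoryPerBlock, /*device=*/0);
if (shared_memory_size > static_cast<size_t>(max_shared_memory_per_block)) {
    std::cerr << "Requested " << shared_memory_size << " bytes of shared memory, but the device "
              << "only allows " << max_shared_memory_per_block << " bytes per block." << std::endl;
    return 1;
}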
4. Launching the Kernel
The forward_kernel is launched with the grid and block dimensions specified above, along with the dynamic shared memory size, the required device pointers, and the configuration parameters.
forward_kernel<<<grid, block, shared_memory_size>>>(
query_matrix_device_pointer, key_matrix_device_pointer, value_matrix_device_pointer,
sequence_length, embedding_dimension,
sequence_length / block_size_columns, sequence_length / block_size_rows,
block_size_columns, block_size_rows, softmax_scale,
sum_matrix_device_pointer, max_matrix_device_pointer, output_matrix_device_pointer);
cudaDeviceSynchronize();
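Neither the launch nor the kernel reports errors unless we ask for them. A minimal check, placed right after (or instead of) the bare cudaDeviceSynchronize() call, might look like:
// Check for launch errors (bad configuration) and runtime errors (faults inside the kernel).
cudaError_t launch_status = cudaGetLastError();
cudaError_t sync_status = cudaDeviceSynchronize();
if (launch_status != cudaSuccess || sync_status != cudaSuccess) {
    std::cerr << "Kernel failed: "
              << cudaGetErrorString(launch_status != cudaSuccess ? launch_status : sync_status)
              << std::endl;
    return 1;
}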
5. Validating and Writing Results
Writing Matrix to File: The writeMatrixToFile Function
The writeMatrixToFile function is an essential utility for validating and analyzing the outputs of CUDA kernels. It saves device matrices to a file, such as a CSV, for debugging or further evaluation. This section explains its implementation and role in the workflow.
Purpose
- Debugging: Provides a way to inspect intermediate or final results of computations.
- Analysis: Enables data visualization or comparison with expected outputs in external tools such as Python (a comparison script is included in the repo); a host-side C++ reference check is also sketched at the end of this section.
Implementation
The writeMatrixToFile function copies matrix data from device memory to host memory and writes it to a file.
#include <cuda_runtime.h>
#include <fstream>
#include <iostream>
#include <string>

void writeMatrixToFile(float* device_pointer, const std::string& filename, int batch_size, int num_heads, int sequence_length, int embedding_dimension) {
    size_t size = batch_size * num_heads * sequence_length * embedding_dimension;
    float* host_pointer = new float[size];
    // Copy data from device to host
    cudaMemcpy(host_pointer, device_pointer, size * sizeof(float), cudaMemcpyDeviceToHost);
    // Open file for writing
    std::ofstream file(filename);
    if (!file.is_open()) {
        std::cerr << "Failed to open file: " << filename << std::endl;
        delete[] host_pointer;
        return;
    }
    // Write data to the file, one CSV row per sequence position
    for (int batch = 0; batch < batch_size; ++batch) {
        for (int head = 0; head < num_heads; ++head) {
            for (int seq = 0; seq < sequence_length; ++seq) {
                for (int embed = 0; embed < embedding_dimension; ++embed) {
                    size_t index = ((batch * num_heads + head) * sequence_length + seq) * embedding_dimension + embed;
                    file << host_pointer[index];
                    if (embed < embedding_dimension - 1) {
                        file << ",";
                    }
                }
                file << "\n";
            }
        }
    }
    // Clean up
    file.close();
    delete[] host_pointer;
    std::cout << "Matrix written to file: " << filename << std::endl;
}
Key Steps
- Copy Data: Transfers matrix data from the device to host memory using cudaMemcpy.
- File Handling: Opens a file in write mode and verifies it opened successfully.
- Data Writing: Iterates through the matrix dimensions (batch_size, num_heads, sequence_length, embedding_dimension) and writes values in CSV format.
- Cleanup: Closes the file and releases the dynamically allocated host memory.
Usage in Main Function
The writeMatrixToFile function is called after kernel execution to save the computed output matrix:
writeMatrixToFile(output_matrix_device_pointer, "output_matrix.csv", batch_size, num_heads, sequence_length, embedding_dimension);
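The same helper can also dump the running-sum and running-max matrices for debugging by treating their last dimension as 1; this is a usage sketch, not something the original code does:
// The sum and max matrices have shape [batch, heads, sequence], so pass 1 as
// the embedding dimension to get one value per row.
writeMatrixToFile(sum_matrix_device_pointer, "sum_matrix.csv",
                  batch_size, num_heads, sequence_length, 1);
writeMatrixToFile(max_matrix_device_pointer, "max_matrix.csv",
                  batch_size, num_heads, sequence_length, 1);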
Advantages
- Versatility: Supports matrices of arbitrary dimensions.
- Ease of Use: Automates the process of exporting CUDA output for offline analysis.
- Readability: Organizes data in a structured format suitable for debugging and visualization.
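For a quick end-to-end check without leaving C++, one can also copy Q, K, V, and the output back to the host and compare against a naive attention computed directly. This is only a sketch (the repo's Python comparison plays the same role); naiveAttention is a name introduced here, and it assumes the kernel writes fully normalized outputs.
#include <algorithm>
#include <cmath>
#include <vector>

// Naive single-(batch, head) attention on the host, for comparison with the kernel output.
// All pointers are host arrays of length sequence_length * embedding_dimension.
void naiveAttention(const float* Q, const float* K, const float* V, float* O,
                    int sequence_length, int embedding_dimension, float softmax_scale) {
    for (int i = 0; i < sequence_length; ++i) {
        // Scores for query row i against every key, with a numerically stable softmax.
        std::vector<float> scores(sequence_length);
        float row_max = -INFINITY;
        for (int j = 0; j < sequence_length; ++j) {
            float s = 0.0f;
            for (int k = 0; k < embedding_dimension; ++k) {
                s += Q[i * embedding_dimension + k] * K[j * embedding_dimension + k];
            }
            scores[j] = s * softmax_scale;
            row_max = std::max(row_max, scores[j]);
        }
        float row_sum = 0.0f;
        for (int j = 0; j < sequence_length; ++j) {
            scores[j] = std::exp(scores[j] - row_max);
            row_sum += scores[j];
        }
        // Weighted sum of value rows, normalized by the row sum.
        for (int k = 0; k < embedding_dimension; ++k) {
            float acc = 0.0f;
            for (int j = 0; j < sequence_length; ++j) {
                acc += scores[j] * V[j * embedding_dimension + k];
            }
            O[i * embedding_dimension + k] = acc / row_sum;
        }
    }
}
To use it, copy one (batch, head) slice of Q, K, V, and the output back with cudaMemcpy (each slice is sequence_length * embedding_dimension floats, starting at offset (batch * num_heads + head) * sequence_length * embedding_dimension), run naiveAttention on the Q/K/V slice, and check that the result matches the kernel's output slice within a small tolerance.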
6. Cleaning Up
Finally, we free the allocated device memory to avoid memory leaks.
cudaFree(query_matrix_device_pointer);
cudaFree(key_matrix_device_pointer);
cudaFree(value_matrix_device_pointer);
cudaFree(sum_matrix_device_pointer);
cudaFree(max_matrix_device_pointer);
cudaFree(output_matrix_device_pointer);
Full Main Function
Here's the complete main function:
int main() {
int batch_size = 2, num_heads = 8, sequence_length = 64, embedding_dimension = 32;
int block_size_columns = 8, block_size_rows = 8;
float softmax_scale = 1.0f / std::sqrt(embedding_dimension);
size_t matrix_size = batch_size * num_heads * sequence_length * embedding_dimension * sizeof(float);
float* query_matrix_device_pointer = allocateAndInitializeDeviceMemory<float>(matrix_size);
float* key_matrix_device_pointer = allocateAndInitializeDeviceMemory<float>(matrix_size);
float* value_matrix_device_pointer = allocateAndInitializeDeviceMemory<float>(matrix_size);
float* sum_matrix_device_pointer = allocateAndInitializeDeviceMemory<float>(
batch_size * num_heads * sequence_length * sizeof(float), true);
float* max_matrix_device_pointer = allocateAndInitializeDeviceMemory<float>(
batch_size * num_heads * sequence_length * sizeof(float), false, true);
float* output_matrix_device_pointer = allocateAndInitializeDeviceMemory<float>(matrix_size, true);
dim3 block(block_size_rows);
dim3 grid(sequence_length / block_size_rows, num_heads, batch_size);
size_t shared_memory_size = 4 * block_size_columns * embedding_dimension * sizeof(float);
forward_kernel<<<grid, block, shared_memory_size>>>(
query_matrix_device_pointer, key_matrix_device_pointer, value_matrix_device_pointer,
sequence_length, embedding_dimension,
sequence_length / block_size_columns, sequence_length / block_size_rows,
block_size_columns, block_size_rows, softmax_scale,
sum_matrix_device_pointer, max_matrix_device_pointer, output_matrix_device_pointer);
cudaDeviceSynchronize();
writeMatrixToFile(output_matrix_device_pointer, "output_matrix.csv", batch_size, num_heads, sequence_length, embedding_dimension);
cudaFree(query_matrix_device_pointer);
cudaFree(key_matrix_device_pointer);
cudaFree(value_matrix_device_pointer);
cudaFree(sum_matrix_device_pointer);
cudaFree(max_matrix_device_pointer);
cudaFree(output_matrix_device_pointer);
return 0;
}
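To build and run the example, everything can be compiled into a single executable with nvcc (the file name is an assumption here; choose an -arch flag matching your GPU):
nvcc -O2 -arch=sm_70 flash_attention_forward.cu -o flash_attention_forward
./flash_attention_forward
A successful run ends with the message Matrix written to file: output_matrix.csv.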
Takeaways
- Memory Management: Properly allocate and deallocate device memory to ensure efficiency and avoid memory leaks.
- CUDA Configuration: Understanding grid and block configuration is crucial for maximizing performance.
- Debugging: Use cudaDeviceSynchronize() to catch kernel errors early during development.
- Post-Processing: Save outputs to files for verification and further analysis.
This setup keeps the implementation structured: parameter setup, memory allocation, kernel launch, validation, and cleanup each live in a clearly separated step.
Conclusion
The forward pass of Flash Attention computes attention scores tile by tile in parallel, leveraging shared memory and the online softmax so that the full score matrix never has to be written to global memory. By structuring the computation in blocks and using GPU resources efficiently, Flash Attention substantially speeds up attention in transformers, especially for long sequences.