A plethora of science

Understanding Flash Attention (Forward) with CUDA

Flash Attention is an exact attention algorithm for transformers that processes Q, K, and V in tiles held in on-chip shared memory, so the full attention score matrix never has to be written to global memory. That makes the operation faster and less memory-intensive. This post breaks down the forward pass of Flash Attention as implemented in CUDA.

Paper link


Overview of the Algorithm

In the forward kernel, the matrices Q (query), K (key), and V (value) are split into blocks small enough to fit in shared memory. The essential steps are:

  1. Divide Q, K, and V into blocks.
  2. Load the current blocks of the query, key, and value matrices into shared memory (SRAM).
  3. Compute the attention scores (scaled dot products of queries and keys) and apply softmax to them.
  4. Multiply the softmax weights by V to calculate the output.

Memory Allocation and Initialization

Before performing computations, we allocate and initialize device memory for the matrices:

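#include <cmath>
#include <cstdlib>
#include <vector>
#include <cuda_runtime.h>

// `size` is the allocation size in bytes.
template <typename T>
T* allocateAndInitializeDeviceMemory(size_t size, bool initializeToZero = false, bool initializeToNegativeInfinity = false) {
    T* device_ptr = nullptr;
    cudaMalloc(&device_ptr, size); // Allocate memory on the device

    if (initializeToZero) {
        cudaMemset(device_ptr, 0, size); // All-zero bytes represent 0.0f
    } else {
        // cudaMemset fills with a single byte value, so it cannot write -INFINITY
        // (bit pattern 0xFF800000). Build the values on the host and copy them over.
        std::vector<T> host_values(size / sizeof(T));
        for (T& value : host_values) {
            value = initializeToNegativeInfinity
                ? static_cast<T>(-INFINITY)
                : static_cast<T>(std::rand()) / static_cast<T>(RAND_MAX); // Random value in [0, 1]
        }
        cudaMemcpy(device_ptr, host_values.data(), size, cudaMemcpyHostToDevice);
    }

    return device_ptr;
}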
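A detail worth checking in any helper like this: cudaMemset, like std::memset, fills memory with a single byte value, so passing the 32-bit pattern of -INFINITY does not produce floats equal to negative infinity. The host-only sketch below illustrates the effect with std::memset, assuming IEEE 754 floats; float_bits and memset_one_float are names made up for this illustration.

```cpp
#include <cmath>
#include <cstdint>
#include <cstring>

// Bit pattern of a float, read via memcpy to avoid aliasing problems.
std::uint32_t float_bits(float value) {
    std::uint32_t bits;
    std::memcpy(&bits, &value, sizeof(bits));
    return bits;
}

// Mimics what a memset-style fill does to one float: only the low byte of
// fill_value is used, and it is repeated over all four bytes.
float memset_one_float(int fill_value) {
    float result;
    std::memset(&result, fill_value, sizeof(result));
    return result;
}
```

The low byte of 0xFF800000 is 0x00, so a memset-style fill with that pattern leaves the floats at 0.0f. Writing -INFINITY needs a host buffer plus cudaMemcpy, or a small fill kernel.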

CUDA Kernel: forward_kernel

The main CUDA kernel forward_kernel handles the computations.

Thread and Block Indexing

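The launch configuration shown later uses one-dimensional blocks of block_size_rows threads. In the snippets below, thread_index_x selects one row of the current tile, so each thread is responsible for one query row and the matching row of the output.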
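To make the indexing concrete, here is a small host-side sketch of how a (block, thread) pair could map to a row of the flattened [batch, head, sequence, embedding] arrays. It assumes blockIdx.x selects the row tile, blockIdx.y the head, and blockIdx.z the batch. That matches the order of the grid dimensions used later in this post but is not shown in the kernel excerpts, and RowLocation and locate_row are hypothetical names.

```cpp
#include <cstddef>

struct RowLocation {
    int sequence_row;           // row of Q (and of the output) handled by the thread
    std::size_t element_offset; // offset of that row's first element in the flat array
};

// block_x, block_y, block_z, and thread_x stand in for blockIdx.x/y/z and threadIdx.x.
RowLocation locate_row(int block_x, int block_y, int block_z, int thread_x,
                       int num_heads, int sequence_length,
                       int embedding_dimension, int block_size_rows) {
    int sequence_row = block_x * block_size_rows + thread_x;
    // Same row-major layout as the index computed in writeMatrixToFile.
    std::size_t head_index = static_cast<std::size_t>(block_z) * num_heads + block_y;
    std::size_t element_offset =
        (head_index * sequence_length + sequence_row) * embedding_dimension;
    return {sequence_row, element_offset};
}
```

With the parameters from this post, the thread with threadIdx.x = 3 in block (1, 2, 1) would handle sequence row 11 of head 2 in batch 1.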

SRAM for Q, K, V

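We declare dynamic shared memory and carve it into tiles for Q, K, V, and the scores:

extern __shared__ float shared_memory[];
// Assumed tile size, matching the 4 * block_size_columns * embedding_dimension floats requested at launch
const int tile_size = block_size_columns * embedding_dimension;
float* query_matrix_tile = shared_memory;
float* key_matrix_tile = &shared_memory[tile_size];
float* value_matrix_tile = &shared_memory[tile_size * 2];
float* score_matrix_tile = &shared_memory[tile_size * 3]; // Scores used in the later snippets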
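The pointer arithmetic above can be mirrored on the host to check how much dynamic shared memory the tiles need. This is only a sketch: it assumes tile_size equals block_size_columns * embedding_dimension and that a fourth tile holds the scores, consistent with the factor of 4 in the shared_memory_size computed in the host code later in this post. SharedLayout and make_shared_layout are illustrative names.

```cpp
#include <cstddef>

struct SharedLayout {
    std::size_t query_offset; // offsets are counted in floats
    std::size_t key_offset;
    std::size_t value_offset;
    std::size_t score_offset;
    std::size_t total_bytes;  // dynamic shared memory to request at launch
};

SharedLayout make_shared_layout(int block_size_columns, int embedding_dimension) {
    std::size_t tile_size =
        static_cast<std::size_t>(block_size_columns) * embedding_dimension;
    SharedLayout layout;
    layout.query_offset = 0;
    layout.key_offset = tile_size;
    layout.value_offset = 2 * tile_size;
    layout.score_offset = 3 * tile_size;
    layout.total_bytes = 4 * tile_size * sizeof(float);
    return layout;
}
```

For block_size_columns = 8 and embedding_dimension = 32 this comes to 4096 bytes, comfortably below the 48 KB per-block default that CUDA devices commonly allow without opting in to more.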

Loading Matrices into SRAM

We load tiles of K and V from global memory into shared memory:

for (int embedding_index = 0; embedding_index < embedding_dimension; embedding_index++) {
    key_matrix_tile[(thread_index_x * embedding_dimension) + embedding_index] = 
        key_matrix_device_pointer[qkv_offset + ...];
    value_matrix_tile[(thread_index_x * embedding_dimension) + embedding_index] = 
        value_matrix_device_pointer[qkv_offset + ...];
}
__syncthreads(); // Ensure all threads have completed loading

Computing Attention Scores

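Each thread computes the scaled dot product of its query row with every key row in the tile and tracks the row maximum needed for the softmax:

float row_max = -INFINITY;
for (int column_index_inner = 0; column_index_inner < block_size_columns; column_index_inner++) {
    float sum = 0;
    for (int embedding_index = 0; embedding_index < embedding_dimension; embedding_index++) {
        sum += query_matrix_tile[(thread_index_x * embedding_dimension) + embedding_index] *
               key_matrix_tile[(column_index_inner * embedding_dimension) + embedding_index];
    }
    score_matrix_tile[(block_size_columns * thread_index_x) + column_index_inner] = sum * softmax_scale;
    row_max = fmaxf(row_max, sum * softmax_scale); // Row maximum for the softmax step
}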

Softmax Calculation

Subtract the row maximum from each score for numerical stability, exponentiate, and accumulate the row sum that serves as the softmax denominator:

float row_sum = 0;
for (int column_index_inner = 0; column_index_inner < block_size_columns; column_index_inner++) {
    int score_index = (block_size_columns * thread_index_x) + column_index_inner;
    // Subtracting row_max keeps __expf from overflowing
    score_matrix_tile[score_index] = __expf(score_matrix_tile[score_index] - row_max);
    row_sum += score_matrix_tile[score_index];
}
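Because K and V are processed one tile at a time, the row maximum and row sum of a tile have to be combined with the statistics from earlier tiles. The post does not show that update, so the following is only a host-side sketch of the standard online-softmax recurrence. RowStats, tile_stats, and merge_stats are hypothetical names, and the kernel's own bookkeeping (through the sum and max matrices) may differ.

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

struct RowStats {
    float max; // largest score seen so far
    float sum; // sum of exp(score - max) over the scores seen so far
};

// Statistics of one non-empty tile of scores.
RowStats tile_stats(const std::vector<float>& scores) {
    float tile_max = *std::max_element(scores.begin(), scores.end());
    float tile_sum = 0.0f;
    for (float score : scores) {
        tile_sum += std::exp(score - tile_max);
    }
    return {tile_max, tile_sum};
}

// Combine running statistics with those of a new tile. Each sum is rescaled
// from its own maximum to the shared new maximum before the two are added.
RowStats merge_stats(RowStats running, RowStats tile) {
    float new_max = std::max(running.max, tile.max);
    float new_sum = running.sum * std::exp(running.max - new_max)
                  + tile.sum * std::exp(tile.max - new_max);
    return {new_max, new_sum};
}
```

Starting from {-INFINITY, 0.0f} and merging tiles one by one gives the same maximum and denominator as a single pass over the whole row, which is what lets the softmax be computed without ever holding a full row of scores.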

Calculating Outputs

Each thread then multiplies its row of exponentiated scores by the value tile, producing one output element per embedding index:

for (int embedding_index = 0; embedding_index < embedding_dimension; embedding_index++) {
    float probability_times_value = 0;
    for (int column_index_inner = 0; column_index_inner < block_size_columns; column_index_inner++) {
        probability_times_value += score_matrix_tile[(block_size_columns * thread_index_x) + column_index_inner] * 
                                   value_matrix_tile[(column_index_inner * embedding_dimension) + embedding_index];
    }
    output_matrix_device_pointer[qkv_offset + ...] = ...
}
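When the output is accumulated tile by tile, contributions computed against an older row maximum have to be rescaled before new ones are added. The assignment in the snippet above is elided and is not reconstructed here. What follows is a generic host-side sketch for a single query row, with hypothetical names (RowAccumulator, make_accumulator, accumulate_tile, finalize), that keeps an unnormalized output and divides by the running sum only at the end.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

struct RowAccumulator {
    float running_max;
    float running_sum;
    std::vector<float> unnormalized_output; // length = embedding dimension
};

RowAccumulator make_accumulator(std::size_t embedding_dimension) {
    return {-INFINITY, 0.0f, std::vector<float>(embedding_dimension, 0.0f)};
}

// scores: scaled query-key dot products for the keys of one non-empty tile.
// values: row-major [scores.size() x embedding dimension] value tile.
void accumulate_tile(RowAccumulator& acc, const std::vector<float>& scores,
                     const std::vector<float>& values) {
    std::size_t dim = acc.unnormalized_output.size();
    float tile_max = *std::max_element(scores.begin(), scores.end());
    float new_max = std::max(acc.running_max, tile_max);
    // Rescale what was accumulated against the old maximum (factor 0 on the first tile).
    float rescale = std::exp(acc.running_max - new_max);
    acc.running_sum *= rescale;
    for (float& element : acc.unnormalized_output) {
        element *= rescale;
    }
    for (std::size_t j = 0; j < scores.size(); ++j) {
        float weight = std::exp(scores[j] - new_max);
        acc.running_sum += weight;
        for (std::size_t e = 0; e < dim; ++e) {
            acc.unnormalized_output[e] += weight * values[j * dim + e];
        }
    }
    acc.running_max = new_max;
}

// Divide by the softmax denominator once all tiles have been processed.
std::vector<float> finalize(const RowAccumulator& acc) {
    std::vector<float> output(acc.unnormalized_output);
    for (float& element : output) {
        element /= acc.running_sum;
    }
    return output;
}
```

Calling accumulate_tile once per K/V tile and finalize at the end yields the same row as applying softmax over all scores at once and then multiplying by V.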

Synchronization and Final Steps

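__syncthreads() is a barrier for the threads of one block. Placing it after the computation ensures that every thread has finished reading the shared-memory tiles before any thread moves on and overwrites them: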

__syncthreads(); // Ensure all computations are complete before proceeding

Implementing the main Function for the CUDA Kernel

The main function is the heart of the CUDA program, where we set up the problem, allocate memory, and launch the kernel. Below, we walk through the steps involved in this function.

1. Problem Setup

We start by defining the key parameters for our task:

int batch_size = 2;
int num_heads = 8;
int sequence_length = 64;
int embedding_dimension = 32;
int block_size_columns = 8;
int block_size_rows = 8;
float softmax_scale = 1.0f / std::sqrt(embedding_dimension);

2. Memory Allocation and Initialization

We allocate and initialize device memory for the query, key, value matrices, as well as the sum and max matrices used during the softmax computation.

size_t matrix_size = batch_size * num_heads * sequence_length * embedding_dimension * sizeof(float);
float* query_matrix_device_pointer = allocateAndInitializeDeviceMemory<float>(matrix_size);
float* key_matrix_device_pointer = allocateAndInitializeDeviceMemory<float>(matrix_size);
float* value_matrix_device_pointer = allocateAndInitializeDeviceMemory<float>(matrix_size);
float* sum_matrix_device_pointer = allocateAndInitializeDeviceMemory<float>(
    batch_size * num_heads * sequence_length * sizeof(float), true);
float* max_matrix_device_pointer = allocateAndInitializeDeviceMemory<float>(
    batch_size * num_heads * sequence_length * sizeof(float), false, true);
float* output_matrix_device_pointer = allocateAndInitializeDeviceMemory<float>(matrix_size, true);

3. Grid and Block Configuration

CUDA requires defining the grid and block configuration. Here, we use a 3D grid (row tiles, heads, batches) and 1D blocks.

dim3 block(block_size_rows);
dim3 grid(sequence_length / block_size_rows, num_heads, batch_size);
size_t shared_memory_size = 4 * block_size_columns * embedding_dimension * sizeof(float);
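The integer divisions above drop any remainder, so the grid only covers the whole sequence when sequence_length is a multiple of the block sizes. A small host-side check along these lines can catch a mismatched configuration before launch. LaunchPlan and plan_launch are hypothetical names, and plain integers stand in for dim3 so the sketch compiles without the CUDA headers.

```cpp
#include <stdexcept>

struct LaunchPlan {
    int grid_x;            // row tiles, would feed dim3 grid(x, y, z)
    int grid_y;            // heads
    int grid_z;            // batches
    int threads_per_block; // would feed dim3 block(...)
};

LaunchPlan plan_launch(int batch_size, int num_heads, int sequence_length,
                       int block_size_columns, int block_size_rows) {
    if (block_size_rows <= 0 || block_size_columns <= 0) {
        throw std::invalid_argument("block sizes must be positive");
    }
    if (sequence_length % block_size_rows != 0 ||
        sequence_length % block_size_columns != 0) {
        throw std::invalid_argument(
            "sequence_length must be a multiple of both block sizes");
    }
    LaunchPlan plan;
    plan.grid_x = sequence_length / block_size_rows;
    plan.grid_y = num_heads;
    plan.grid_z = batch_size;
    plan.threads_per_block = block_size_rows;
    return plan;
}
```

For the values used in this post the plan is an 8 x 8 x 2 grid of 8-thread blocks. A sequence length of 60 with the same block sizes would be rejected instead of silently skipping the last four rows.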

4. Launching the Kernel

The forward_kernel is launched with the grid and block dimensions specified. We pass the required pointers and configuration parameters.

forward_kernel<<<grid, block, shared_memory_size>>>(
    query_matrix_device_pointer, key_matrix_device_pointer, value_matrix_device_pointer,
    sequence_length, embedding_dimension,
    sequence_length / block_size_columns, sequence_length / block_size_rows,
    block_size_columns, block_size_rows, softmax_scale,
    sum_matrix_device_pointer, max_matrix_device_pointer, output_matrix_device_pointer);
cudaDeviceSynchronize();

5. Validating and Writing Results

Writing Matrix to File: writeMatrixToFile Function

The writeMatrixToFile function is a utility for validating and analyzing the outputs of CUDA kernels: it copies a device matrix to the host and saves it to a file, such as a CSV, for debugging or further evaluation. This section explains its implementation and role in the workflow.

Implementation

The writeMatrixToFile function copies matrix data from device memory to host memory and writes it to a file.

#include <cuda_runtime.h>
#include <fstream>
#include <iostream>
#include <string>

void writeMatrixToFile(float* device_pointer, const std::string& filename, int batch_size, int num_heads, int sequence_length, int embedding_dimension) {
    size_t size = batch_size * num_heads * sequence_length * embedding_dimension;
    float* host_pointer = new float[size];

    // Copy data from device to host
    cudaMemcpy(host_pointer, device_pointer, size * sizeof(float), cudaMemcpyDeviceToHost);

    // Open file for writing
    std::ofstream file(filename);
    if (!file.is_open()) {
        std::cerr << "Failed to open file: " << filename << std::endl;
        delete[] host_pointer;
        return;
    }

    // Write data to the file
    for (int batch = 0; batch < batch_size; ++batch) {
        for (int head = 0; head < num_heads; ++head) {
            for (int seq = 0; seq < sequence_length; ++seq) {
                for (int embed = 0; embed < embedding_dimension; ++embed) {
                    size_t index = ((batch * num_heads + head) * sequence_length + seq) * embedding_dimension + embed;
                    file << host_pointer[index];
                    if (embed < embedding_dimension - 1) {
                        file << ",";
                    }
                }
                file << "\n";
            }
        }
    }

    // Clean up
    file.close();
    delete[] host_pointer;

    std::cout << "Matrix written to file: " << filename << std::endl;
}

Key Steps

  1. Copy Data: Transfers matrix data from the device to host memory using cudaMemcpy.
  2. File Handling: Opens a file in write mode and verifies successful opening.
  3. Data Writing: Iterates through matrix dimensions (batch_size, num_heads, sequence_length, embedding_dimension) and writes values in CSV format.
  4. Cleanup: Closes the file and releases dynamically allocated host memory.

Usage in Main Function

The writeMatrixToFile function is called after kernel execution to save the computed output matrix:

writeMatrixToFile(output_matrix_device_pointer, "output_matrix.csv", batch_size, num_heads, sequence_length, embedding_dimension);
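A saved output is only useful for validation if there is something to compare it against. One option is a straightforward CPU reference for a single (batch, head) slice. The sketch below is such a reference, not part of the original code; reference_attention and max_abs_difference are illustrative names.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// query, key, value: row-major [sequence_length x embedding_dimension]
// arrays for one (batch, head) slice. Returns softmax(scale * Q K^T) V.
std::vector<float> reference_attention(const std::vector<float>& query,
                                       const std::vector<float>& key,
                                       const std::vector<float>& value,
                                       int sequence_length, int embedding_dimension,
                                       float softmax_scale) {
    std::vector<float> output(query.size(), 0.0f);
    std::vector<float> scores(sequence_length);
    for (int i = 0; i < sequence_length; ++i) {
        float row_max = -INFINITY;
        for (int j = 0; j < sequence_length; ++j) {
            float dot = 0.0f;
            for (int e = 0; e < embedding_dimension; ++e) {
                dot += query[i * embedding_dimension + e] * key[j * embedding_dimension + e];
            }
            scores[j] = dot * softmax_scale;
            row_max = std::max(row_max, scores[j]);
        }
        float row_sum = 0.0f;
        for (int j = 0; j < sequence_length; ++j) {
            scores[j] = std::exp(scores[j] - row_max); // stable softmax numerator
            row_sum += scores[j];
        }
        for (int j = 0; j < sequence_length; ++j) {
            float weight = scores[j] / row_sum;
            for (int e = 0; e < embedding_dimension; ++e) {
                output[i * embedding_dimension + e] += weight * value[j * embedding_dimension + e];
            }
        }
    }
    return output;
}

// Largest element-wise absolute difference over the common length of a and b.
float max_abs_difference(const std::vector<float>& a, const std::vector<float>& b) {
    float worst = 0.0f;
    for (std::size_t i = 0; i < std::min(a.size(), b.size()); ++i) {
        worst = std::max(worst, std::fabs(a[i] - b[i]));
    }
    return worst;
}
```

Because the kernel uses the fast intrinsic __expf, exact equality with a CPU result should not be expected. Comparing max_abs_difference against a small tolerance is the more realistic check.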

6. Cleaning Up

Finally, we free the allocated device memory to avoid memory leaks.

cudaFree(query_matrix_device_pointer);
cudaFree(key_matrix_device_pointer);
cudaFree(value_matrix_device_pointer);
cudaFree(sum_matrix_device_pointer);
cudaFree(max_matrix_device_pointer);
cudaFree(output_matrix_device_pointer);

Full Main Function

Here's the complete main function:

int main() {
    int batch_size = 2, num_heads = 8, sequence_length = 64, embedding_dimension = 32;
    int block_size_columns = 8, block_size_rows = 8;
    float softmax_scale = 1.0f / std::sqrt(embedding_dimension);

    size_t matrix_size = batch_size * num_heads * sequence_length * embedding_dimension * sizeof(float);
    float* query_matrix_device_pointer = allocateAndInitializeDeviceMemory<float>(matrix_size);
    float* key_matrix_device_pointer = allocateAndInitializeDeviceMemory<float>(matrix_size);
    float* value_matrix_device_pointer = allocateAndInitializeDeviceMemory<float>(matrix_size);
    float* sum_matrix_device_pointer = allocateAndInitializeDeviceMemory<float>(
        batch_size * num_heads * sequence_length * sizeof(float), true);
    float* max_matrix_device_pointer = allocateAndInitializeDeviceMemory<float>(
        batch_size * num_heads * sequence_length * sizeof(float), false, true);
    float* output_matrix_device_pointer = allocateAndInitializeDeviceMemory<float>(matrix_size, true);

    dim3 block(block_size_rows);
    dim3 grid(sequence_length / block_size_rows, num_heads, batch_size);
    size_t shared_memory_size = 4 * block_size_columns * embedding_dimension * sizeof(float);

    forward_kernel<<<grid, block, shared_memory_size>>>(
        query_matrix_device_pointer, key_matrix_device_pointer, value_matrix_device_pointer,
        sequence_length, embedding_dimension,
        sequence_length / block_size_columns, sequence_length / block_size_rows,
        block_size_columns, block_size_rows, softmax_scale,
        sum_matrix_device_pointer, max_matrix_device_pointer, output_matrix_device_pointer);
    cudaDeviceSynchronize();

    writeMatrixToFile(output_matrix_device_pointer, "output_matrix.csv", batch_size, num_heads, sequence_length, embedding_dimension);

    cudaFree(query_matrix_device_pointer);
    cudaFree(key_matrix_device_pointer);
    cudaFree(value_matrix_device_pointer);
    cudaFree(sum_matrix_device_pointer);
    cudaFree(max_matrix_device_pointer);
    cudaFree(output_matrix_device_pointer);

    return 0;
}

Takeaways

  1. Memory Management: Properly allocate and deallocate device memory to ensure efficiency and avoid memory leaks.
  2. CUDA Configuration: Understanding grid and block configuration is crucial for maximizing performance.
  3. Debugging: Check the status returned by cudaDeviceSynchronize() (and by cudaGetLastError() right after the launch) to catch kernel errors early during development.
  4. Post-Processing: Save outputs to files for verification and further analysis.

Together, these steps give a small, self-contained harness for running the kernel and inspecting its output.

Conclusion

The forward pass of Flash Attention computes attention in parallel over tiles of Q, K, and V held in shared memory, so the intermediate scores stay on-chip instead of being written to global memory. Structuring the computation in blocks this way reduces memory traffic, which is where the speedup for transformers comes from.

The code

Umar Jamil's flash attention video