+9

GEMM: Hiểu thêm về viên gạch tạo nên các mô hình Deep Learning

1. GEMM không phải là nhân ma trận thôi à?

Đúng vậy, GEMM chính là phép nhân ma trận. GEMM là viết tắt của GEneral Matrix Multiplication, là một phần của đặc tả bậc 3 của BLAS (Basic Linear Algebra Subprogram). GEMM được đặc tả với công thức tổng quát sau:

C=αAB+βCC = \alpha A B + \beta C

Trong đó AABB là các ma trận đầu với size lần lượt là RM×N\mathbb{R}^{M \times N}RN×P\mathbb{R}^{N \times P}, α\alphaβ\beta là đại lượng vô hướng, CC là ma trận đầu ra được khởi tạo từ trước với size RM×P\mathbb{R}^{M \times P} và kết quả sẽ được viết đè lên ma trận này. Có thể dễ dàng thấy với α=1\alpha = 1β=0\beta = 0, ta sẽ có được phép nhân ma trận quen thuộc mà ta được dạy từ giải tích cấp 3.

2. Tại sao GEMM lại quan trọng?

Gần như mọi layer quan trọng trong các mạng Deep Learning đều sử dụng đến phép tính trên. Nổi bật nhất có lẽ là các lớp Fully Connected, phép toán Attention trong kiến trúc Transformer, hay đến cả phép Convolution cũng sử dụng GEMM. Có thể dễ dàng thấy phần lớn thời gian tính toán của các Deep Learning nằm ở phép toán này. Chính vì vậy, việc tối ưu phép toán trên cực kì quan trọng trong việc giảm tài nguyên và thời gian dành ra để huấn luyện mô hình cũng như cải thiện hiệu năng của mô hình trong quá trình sử dụng.

3. Triton, CUDA nhưng lại là python?

Một trong những điểm đáng chú ý của GEMM là phép tính này có thể dễ dàng tính toán song song, hay nói cách khác chúng ta có thể dùng GPU để tăng đáng kể hiệu năng so với CPU. Đây cũng chính là lý do GPU được sử dụng triệt để trong Deep Learning. Và khi nói đến các ông kẹ trong mảng này thì không ai có thể soán ngôi nổi NVIDIA với dòng GPU của mình cùng CUDA. Tuy nhiên, không phải ai cũng có thể lập trình ngôn ngữ bậc thấp như C/C++ thành thạo, đó là chưa kể chúng ta phải xử lý multithread trên chính ngôn ngữ này nữa. Việc này lại càng khó nhằn hơn đối với những bạn thuần Data Science và chưa động vào gì khác trừ python. Chính vì lý do đó mà đội phát triển Kernel của OpenAI đã phát triển một ngôn ngữ mới, có syntax giống python nhưng compile ra PTX (assembly nhưng dành chu GPU). Triton tự động tối ưu công việc quản lý bộ nhớ chung cũng như scheduling trong SM, từ đó cho phép lập trình viên chú trọng hơn vào logic của thuật toán, giảm thời gian vật lộn với code, giảm số dòng code phải viết cũng như đảm bảo hiệu năng thực thi không thua kém code CUDA thuần. Và đặc biệt, cài đặt triton rất đơn giản: pip install triton

3.1 Element wise

Cách đơn giản nhất đó là tách phép toán ra thành các phép toán đơn giản hơn. Ta có thể dễ dàng thấy phép toán trên là tổ hợp của phép nhân vô hướng, phép cộng 1 số với 1 ma trận và phép nhân ma trận. Với phép nhân vô hướng và phép cộng, ta có thể dễ dàng định nghĩa nó thành phép toán theo từng cặp phần tử (element wise)

Bắt đầu bằng việc import các thư viện cần thiết

import torch
import triton
import triton.language as tl

Khi được gọi, việc thực thi kernel sẽ được đảm nhận bởi các program trên 1 đoạn dữ liệu. Mỗi program sẽ spawn các threads và thực thi song song kernel trên các threads đó.

Đầu tiên, chúng ta khai báo với triton rằng nó cần compile hàm sau thành kernel để chạy trên GPU thông qua decorator

@triton.jit

Kernel của chúng ta sẽ nhận vào là con trỏ tới ma trận đầu vào, đơn vị vô hướng ta cần thực hiện phép tính, con trỏ tới ma trận đầu ra, tổng số phần tử và độ lớn đoạn dữ liệu sẽ được xử lý bởi program, tl.constexpr khai báo biến sẽ được khởi tạo tại compile-time. Đối với đầu vào là ma trận, Triton tự động chuyển torch.Tensor thành con trỏ trỏ tới phần tử đầu tiên trong bộ nhớ. Tiếp theo ta cần xác định program mà ta đang thực thi để xác định đoạn dữ liệu cần duyệt. Hiểu nôm na khi thực thi, mỗi program sẽ chia ma trận thành các slice có dạng sau

[program_id : program_id + BLOCK_SIZE]

và thực thi trên đó (ma trận nhiều chiều thực chất là mảng 1 chiều đã dàn phẳng bằng

torch.flatten()

triton.drawio.png

Program có thể được truy cập trên 3 trục x, y, z (do CUDA vốn được tạo ra để thực hiện phép tính trên không gian 3 chiều dùng trong đồ hoạ), với phép tính trên ta chỉ cần map trên trục x mà thôi

pid = tl.program_id(0)

Tiếp đến ta xác định slice trên mảng dữ liệu

offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

và mask lại đề phòng out of range access trên mảng

mask = offsets < numels

Tiếp theo là thực hiện load mảng đầu vào đã được xác định, tính toán và viết kết quả ra bộ nhớ

x = tl.load(x_ptr + offsets, mask=mask)
y = x * scalar
tl.store(output_ptr + offsets, y, mask=mask)

Vậy là ta đã có một kernel hoàn chỉnh để thực thi phép tính nhân 1 số với 1 ma trận rồi. Sau đây là toàn bộ code phần kernel trên:

@triton.jit
def scalar_multiply_kernel(x_ptr, scalar, output_ptr, numels, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < numels

    x = tl.load(x_ptr + offsets, mask=mask)

    output = x * scalar

    tl.store(output_ptr + offsets, output, mask=mask)

Đối với phép cộng ma trận với ma trận, ta chỉ cần thay scalar thành con trỏ tới ma trận đầu vào thứ 2 và thực hiện load giống ma trận thứ 1 và tính toán là xong

@triton.jit
def elementwise_add_kernel(x_ptr, y_ptr, output_ptr, numels, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < numels

    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)

    output = x + y

    tl.store(output_ptr + offsets, output, mask=mask)

Để gọi thực thi kernel, ngoài các parameters được định nghĩa, ta cần truyền thêm grid config. Grid config là nhằm chỉ đến số lượng programs và số threads trong mỗi program mà GPU để thực thi kernel. Do mỗi thread sẽ đảm nhận phép tính trên 1 phần tử nên ta sẽ có được số program là cận trên của phép tính so^ˊpha^ˋntBLOCK_SIZE\dfrac{số\,phần\,tử} {BLOCK\_SIZE}. Hàm thực thi sẽ như sau:

def scalar_multiply(x: torch.Tensor, s: float):
    output = torch.empty_like(x)

    numels = output.numel()

    BLOCK_SIZE = min(triton.next_power_of_2(numels), 1024)

    grid = lambda meta: (triton.cdiv(numels, meta["BLOCK_SIZE"]),)
    scalar_multiply_kernel[grid](x, s, output, numels, BLOCK_SIZE=BLOCK_SIZE)
    return output

def elementwise_add(x: torch.Tensor, y: torch.Tensor):
    output = torch.empty_like(x)

    numels = output.numel()

    BLOCK_SIZE = min(triton.next_power_of_2(numels), 1024)

    grid = lambda meta: (triton.cdiv(numels, meta["BLOCK_SIZE"]),)
    elementwise_add_kernel[grid](x, y, output, numels, BLOCK_SIZE=BLOCK_SIZE)
    return output

3.2 Matrix Multiplication

Giờ chúng ta sẽ đến với mảnh ghép chính, phép nhân ma trận. Phép nhân giữa 2 ma trận ARM×NA \in \mathbb{R}^{M \times N}BRN×PB \in \mathbb{R}^{N \times P} được định nghĩa bởi công thức sau:

Ci,j=k=1NAi,kBk,jC_{i, j} = \sum_{k=1}^{N}A_{i,k}B_{k,j}

image.png

Do phép tính được thực hiện trên con trỏ trỏ đến dữ liệu đầu tiên trong ma trận, ta duyệt đến phần tử i, j như sau:

A[i, k] = A + i * A_stride_i  + k * stride_k

Và ta cần dữ liệu khi load dưới dạng block nên ta sẽ tính con trỏ tới hàng ma trận A và cột ma trận B như sau:

offset_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offset_p = pid_p * BLOCK_SIZE_P + tl.arange(0, BLOCK_SIZE_P)

offset_am = offset_m % M
offset_bp = offset_p % P

offset_k = tl.arange(0, BLOCK_SIZE_N)

a_ptrs = a_ptr + offset_am[:, None] * stride_am + offset_k[None, :] * stride an
b_ptrs = b_ptr + offset_k[:, None] * stride_bn + offset_bp[None, :] * stride_bp

Đối với outer loop, ta thực hiện masking bằng cách chia lấy dư cho độ dài off_set_am = offset_m % Moff_set_bp = offset_p % P

Đối với inner loop, với mỗi một cặp hàng i ma trận A và cột j ma trận B, ta sẽ tính dot product trên BLOCK_SIZE_K phần tử của chúng, lưu trong biến acc và cập nhật con trỏ của ma trận A và B như sau:

acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_P), dtype=tl.float32)
for k in range(0, tl.cdiv(N, BLOCK_SIZE_N)):
    a = tl.load(a_ptrs, mask=offset_k[None, :] < N - k * BLOCK_SIZE_N, other=0.0)
    b = tl.load(b_ptrs, mask=offset_k[:, None] < N - k * BLOCK_SIZE_N, other=0.0)
    acc = tl.dot(a, b, acc)
    a_ptrs += stride_an * BLOCK_SIZE_N
    b_ptrs += stride_bn * BLOCK_SIZE_N
c = acc.to(tl.float16)

Biến acc được khởi tạo với kiểu float32 để tăng độ chính xác khi tính toán, sau đó sẽ được convert lại về kiểu float16 Để tránh lặp lại phép tính khi duyệt trên BLOCK_SIZE_N, ta mask các phần tử đó lại thành 0.0 để tránh ảnh hưởng đến biến acc. Thực hiện phép tính dot product trên hàng i ma trận A và cột j ma trận B với độ dài là BLOCK_SIZE_N, sau đó ta cập nhật lại con trỏ đến hàng và cột tiếp theo.

Cuối cùng ta ghi kết quả ra VRAM:

c_ptrs = c_ptr + offset_m[:, None] * stride_cm + offset_p[None, :] * stride_cp

c_mask = (offset_m[:, None] < M) & (offset_p[None, :] < P)
tl.store(c_ptrs, c, mask=c_mask)

Chúng ta kernel đầy đủ như sau:

@triton.jit
def matmul_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    stride_am,
    stride_an,
    stride_bn,
    stride_bp,
    stride_cm,
    stride_cp,
    M,
    N,
    P,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_P: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_p = tl.program_id(1)

    offset_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offset_p = pid_p * BLOCK_SIZE_P + tl.arange(0, BLOCK_SIZE_P)

    offset_am = offset_m % M
    offset_bp = offset_p % P

    offset_k = tl.arange(0, BLOCK_SIZE_N)

    a_ptrs = a_ptr + offset_am[:, None] * stride_am + offset_k[None, :] * stride_an
    b_ptrs = b_ptr + offset_k[:, None] * stride_bn + offset_bp[None, :] * stride_bp

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_P), dtype=tl.float32)
    for k in range(0, tl.cdiv(N, BLOCK_SIZE_N)):
        a = tl.load(a_ptrs, mask=offset_k[None, :] < N - k * BLOCK_SIZE_N, other=0.0)
        b = tl.load(b_ptrs, mask=offset_k[:, None] < N - k * BLOCK_SIZE_N, other=0.0)
        acc = tl.dot(a, b, acc)
        a_ptrs += stride_an * BLOCK_SIZE_N
        b_ptrs += stride_bn * BLOCK_SIZE_N
    c = acc.to(tl.float16)

    c_ptrs = c_ptr + offset_m[:, None] * stride_cm + offset_p[None, :] * stride_cp

    c_mask = (offset_m[:, None] < M) & (offset_p[None, :] < P)
    tl.store(c_ptrs, c, mask=c_mask)

Hàm gọi kernel của chúng ta sẽ như sau:

def matmul(a, b):
    M, N = a.shape
    N, P = b.shape

    c = torch.empty((M, P), device="cuda", dtype=torch.float16)

    def find_block_size(n):
        return max(min(triton.next_power_of_2(n), 256), 64)

    BLOCK_SIZE_M = find_block_size(M)
    BLOCK_SIZE_P = find_block_size(N)

    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_SIZE_M"]),
        triton.cdiv(P, meta["BLOCK_SIZE_P"]),
    )

    matmul_kernel[grid](
        a,
        b,
        c,
        a.stride(0),
        a.stride(1),
        b.stride(0),
        b.stride(1),
        c.stride(0),
        c.stride(1),
        M,
        N,
        P,
        BLOCK_SIZE_M=BLOCK_SIZE_M,
        BLOCK_SIZE_P=BLOCK_SIZE_P,
        BLOCK_SIZE_N=64,
    )

    return c

BLOCK_SIZE chúng ta có thể đặt là bội số của 2. Thông thường GPU có số threads tối đa với mỗi block là 1024 để ta sử dụng. Ở đây ta sẽ đặt BLOCK_SIZE_N là 64, BLOCK_SIZE_M và BLOCK_SIZE_P sẽ là bội số gần nhất đối với M và P với cận dưới là 32 và cận trên là 256.

3.3 Hiệu năng thì sao?

Vậy là chúng ta đã có đủ mảnh ghép cho phép GEMM tổng quát rồi.

def gemm(a, b, c=None, alpha=1.0, beta=0.0):
    c_hat = scalar_multiply(matmul(a, b), alpha)
    if c is None:
        c = torch.ones_like(c_hat)
    c = scalar_multiply(c, beta)
    return elementwise_add(c_hat, c)

Tuy nhiên, xét về hiệu năng thì sao? Triton có hỗ trợ benchmark API để ta dễ dàng so sánh các implementation với nhau. API hỗ trợ xuất thông số ra dưới dạng file .csv cũng như vẽ line graph để ta dễ hình dung hiệu năng của kernel ta vừa implement. Ta sẽ so sánh hiệu năng với implementation của Pytorch, một framwork về deeplearning cực kỳ nổi tiếng. Sau đấy là hàm tính GEMM đơn giản bằng pytorch:

def torch_gemm(a, b, c=None, alpha=1.0, beta=0.0):
    c_hat = alpha * torch.matmul(a, b)
    if c is None:
        c = torch.ones_like(c_hat)
    return c_hat + beta * c

Và đây là code benchmark của chúng ta:

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["size"], 
        x_vals=[
            2**i for i in range(10, 15, 1)
        ],  
        x_log=True, 
        line_arg="provider",  
        line_vals=["triton", "torch"],
        line_names=["Triton", "Torch"], 
        styles=[("blue", "-"), ("green", "-")],  
        ylabel="ms",  
        plot_name="elementwise-multiply-performance", 
        args={},  
    )
)
def benchmark(size, provider):
    m = size // 2
    n = m // 2
    p = size - m - n
    a = torch.rand((m, n), device="cuda", dtype=torch.float16)
    b = torch.rand((n, p), device="cuda", dtype=torch.float16)
    c = torch.rand((m, p), device="cuda", dtype=torch.float16)
    alpha = 2.0
    beta = 2.0
    quantiles = [0.5, 0.2, 0.8]
    if provider == "torch":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: torch_gemm(a, b, c, alpha, beta), quantiles=quantiles
        )
    if provider == "triton":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: gemm(a, b, c, alpha, beta), quantiles=quantiles
        )
    return ms, max_ms, min_ms

benchmark.run(print_data=True, show_plots=True)

Và chúng ta có kết quả sau (chạy trên Nvidia RTX 3060):

Size Triton Torch
1024 0.601088 0.072704
2048 0.326656 0.069632
4096 3.391488 0.286720
8192 26.287104 1.770496
16384 200.496124 12.006400

image.png

Có thể dễ dàng thấy hiệu năng của chúng ta một trời một vực đối với hàng chính chủ của Pytorch. Khác biệt càng lớn hơn khi ma trận càng to lên, Khi số phần tử lên đến 16000, cách biệt đã lên đến

4. Kết luận

Vậy là chúng ta đã tập tững được những bước đầu trong con đường viết kernel của chúng ta, vẫn còn rất nhiều kĩ thuật nhằm tối ưu hoá hiệu năng của kernel như tuning, fused kernel, tiled,... Chúng ta sẽ từng bước tìm hiểu và áp dụng những kĩ thuật tối ưu trên vào GEMM kernel của chúng ta cũng như profile bằng các công cụ trong gói NVIDIA Tools Kit nhằm xác định các nút thắt hiệu năng trong những bài sau. Cuối cùng chúng ta sẽ viết lại bằng chính CUDA API nhằm vắt kiệt từng chút hiệu năng còn sót lại của chiếc GPU nhà NVIDIA.

5. References


All Rights Reserved

Viblo
Let's register a Viblo Account to get more interesting posts.