From e5dc83937d92df968ba14fb98c85aa345972b2d0 Mon Sep 17 00:00:00 2001
From: David Corvoysier
Date: Wed, 13 Mar 2024 15:09:03 +0000
Subject: [PATCH] feat(library): add CUDA unpack kernel

---
 quanto/library/ext/__init__.py            |  3 +
 quanto/library/ext/cpp/__init__.py        |  2 +-
 quanto/library/ext/cuda/README.md         | 11 +++
 quanto/library/ext/cuda/__init__.py       | 30 ++++++++
 quanto/library/ext/cuda/pybind_module.cpp | 13 ++++
 quanto/library/ext/cuda/unpack.cu         | 83 +++++++++++++++++++++++
 quanto/library/ext/cuda/unpack.h          |  3 +
 7 files changed, 144 insertions(+), 1 deletion(-)
 create mode 100644 quanto/library/ext/cuda/README.md
 create mode 100644 quanto/library/ext/cuda/__init__.py
 create mode 100644 quanto/library/ext/cuda/pybind_module.cpp
 create mode 100644 quanto/library/ext/cuda/unpack.cu
 create mode 100644 quanto/library/ext/cuda/unpack.h

diff --git a/quanto/library/ext/__init__.py b/quanto/library/ext/__init__.py
index 834b6fa3..611eee96 100644
--- a/quanto/library/ext/__init__.py
+++ b/quanto/library/ext/__init__.py
@@ -3,5 +3,8 @@
 from .cpp import *
 
 
+if torch.cuda.is_available():
+    from .cuda import *
+
 if torch.backends.mps.is_available():
     from .mps import *
diff --git a/quanto/library/ext/cpp/__init__.py b/quanto/library/ext/cpp/__init__.py
index 476a28a8..717730f6 100644
--- a/quanto/library/ext/cpp/__init__.py
+++ b/quanto/library/ext/cpp/__init__.py
@@ -33,6 +33,6 @@ def dqmm_cpp(input: torch.Tensor, other: torch.Tensor, other_scale: torch.Tensor
     return ext().dqmm(input, other, other_scale)
 
 
-@torch.library.impl("quanto_ext::unpack", ["CPU", "CUDA"])
+@torch.library.impl("quanto_ext::unpack", ["CPU"])
 def unpack_cpp(t: torch.Tensor, bits: int):
     return ext().unpack(t, bits)
diff --git a/quanto/library/ext/cuda/README.md b/quanto/library/ext/cuda/README.md
new file mode 100644
index 00000000..f3b9e8a0
--- /dev/null
+++ b/quanto/library/ext/cuda/README.md
@@ -0,0 +1,11 @@
+# Quanto generic CUDA extension
+
+Kernels in this extension can be written using either plain C++ or CUDA syntax.
+
+They can use any PyTorch operation defined under `aten::` or `c10::`.
+
+To add a new implementation for an operation defined in `library/ops.py`:
+
+- add the corresponding `.cpp` or `.cu` file to the list of sources in `__init__.py`,
+- add a binding to `pybind_module.cpp`,
+- provide an implementation calling the binding in `__init__.py`.
diff --git a/quanto/library/ext/cuda/__init__.py b/quanto/library/ext/cuda/__init__.py
new file mode 100644
index 00000000..35b4f714
--- /dev/null
+++ b/quanto/library/ext/cuda/__init__.py
@@ -0,0 +1,30 @@
+import os
+
+import torch
+from torch.utils.cpp_extension import load
+
+
+__all__ = []
+
+
+_ext = None
+
+
+def ext():
+    """Helper to load the CUDA extension only when it is required"""
+    global _ext
+    if _ext is None:
+        module_path = os.path.dirname(__file__)
+        _ext = load(
+            name="quanto_cuda",
+            sources=[
+                f"{module_path}/unpack.cu",
+                f"{module_path}/pybind_module.cpp",
+            ],
+        )
+    return _ext
+
+
+@torch.library.impl("quanto_ext::unpack", ["CUDA"])
+def unpack_cuda(t: torch.Tensor, bits: int):
+    return ext().unpack(t, bits)
diff --git a/quanto/library/ext/cuda/pybind_module.cpp b/quanto/library/ext/cuda/pybind_module.cpp
new file mode 100644
index 00000000..013f2e6f
--- /dev/null
+++ b/quanto/library/ext/cuda/pybind_module.cpp
@@ -0,0 +1,13 @@
+#include <torch/extension.h>
+#include "unpack.h"
+
+// !IMPORTANT! Some Python objects, such as dtype or device, are not mapped to C++ types,
+// and need to be explicitly converted using dedicated helpers before calling a C++ method.
+// As a consequence, when an operation takes such an object as a parameter, instead
+// of creating a binding directly to the C++ method, you must create a binding to a
+// lambda that converts the unmapped types and calls the C++ method.
+// See the binding of quantize_symmetric for an example.
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("unpack", &unpack, "unpack");
+}
diff --git a/quanto/library/ext/cuda/unpack.cu b/quanto/library/ext/cuda/unpack.cu
new file mode 100644
index 00000000..3512a60b
--- /dev/null
+++ b/quanto/library/ext/cuda/unpack.cu
@@ -0,0 +1,83 @@
+#include <torch/extension.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <c10/cuda/CUDAException.h>
+
+inline unsigned int cdiv(unsigned int a, unsigned int b) { return (a + b - 1) / b; }
+#define BLOCK_SIZE 256
+
+using namespace at;
+
+
+static torch::Tensor allocate_output(const torch::Tensor& input, int bits) {
+    int n_packed = 8 / bits;
+    auto output_shape = input.sizes().vec();
+    output_shape[0] = output_shape[0] * n_packed;
+    return torch::empty(output_shape, input.options());
+}
+
+__global__ void unpack_4bit_kernel(unsigned char* input, unsigned char* output, int n) {
+    int i = blockIdx.x * blockDim.x + threadIdx.x;
+    if (i >= n) return;
+
+    output[i] = (input[i] & 0x0F);
+    output[i + n] = (input[i] & 0xF0) >> 4;
+}
+
+static torch::Tensor unpack_4bit(const torch::Tensor& input){
+
+    auto output = allocate_output(input, 4);
+
+    const auto numel = input.numel();
+    int blocks = cdiv(numel, BLOCK_SIZE);
+    unpack_4bit_kernel<<<blocks, BLOCK_SIZE>>>(
+        input.data_ptr<unsigned char>(),
+        output.data_ptr<unsigned char>(),
+        numel
+    );
+
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+
+    return output;
+}
+
+__global__ void unpack_2bit_kernel(unsigned char* input, unsigned char* output, int n) {
+    int i = blockIdx.x * blockDim.x + threadIdx.x;
+    if (i >= n) return;
+
+    output[i] = (input[i] & 0x03);
+    output[i + n] = (input[i] & 0x0C) >> 2;
+    output[i + n*2] = (input[i] & 0x30) >> 4;
+    output[i + n*3] = (input[i] & 0xC0) >> 6;
+}
+
+static torch::Tensor unpack_2bit(const torch::Tensor& input){
+
+    auto output = allocate_output(input, 2);
+
+    const auto numel = input.numel();
+    int blocks = cdiv(numel, BLOCK_SIZE);
+    unpack_2bit_kernel<<<blocks, BLOCK_SIZE>>>(
+        input.data_ptr<unsigned char>(),
+        output.data_ptr<unsigned char>(),
+        numel
+    );
+
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+
+    return output;
+}
+
+torch::Tensor unpack(torch::Tensor &t, int bits) {
+    TORCH_CHECK(t.scalar_type() == torch::kUInt8, "Unsupported data type: ", t.scalar_type());
+    TORCH_CHECK(t.device().is_cuda(), "t must be a CUDA tensor.");
+    TORCH_CHECK(t.is_contiguous(), "t must be contiguous.");
+    switch(bits) {
+      case 4:
+        return unpack_4bit(t);
+      case 2:
+        return unpack_2bit(t);
+      default:
+        throw std::invalid_argument("Can only unpack 2-bit or 4-bit tensors.");
+    }
+}
diff --git a/quanto/library/ext/cuda/unpack.h b/quanto/library/ext/cuda/unpack.h
new file mode 100644
index 00000000..ddabf77a
--- /dev/null
+++ b/quanto/library/ext/cuda/unpack.h
@@ -0,0 +1,3 @@
+#include <torch/extension.h>
+
+torch::Tensor unpack(torch::Tensor &t, int bits);
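
Editor's note: the two kernels above share one layout convention. Each uint8 element of the input packs 8 / bits values, and the value in bit group k (bits k*bits up to (k+1)*bits - 1) is written n elements further along the flattened output, which is why allocate_output grows the first dimension by 8 / bits. Below is a minimal pure-PyTorch sketch of that convention, usable as a CPU reference when testing the kernels; reference_unpack is illustrative and not part of the patch.

import torch

def reference_unpack(packed: torch.Tensor, bits: int) -> torch.Tensor:
    # Illustrative CPU equivalent of the CUDA kernels above: group k is
    # recovered by shifting right by k * bits and masking, then the groups
    # are stacked along dim 0, mirroring the output[i + n * k] writes.
    assert packed.dtype == torch.uint8
    n_packed = 8 // bits
    mask = (1 << bits) - 1  # 0x0F for 4-bit, 0x03 for 2-bit
    groups = [(packed >> (k * bits)) & mask for k in range(n_packed)]
    return torch.cat(groups, dim=0)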
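
Because the patch registers the CUDA implementation under the existing quanto_ext::unpack op (declared in library/ops.py, per the torch.library.impl registrations above), it can be exercised end to end from Python. A hedged usage sketch, assuming a CUDA device and toolchain are available; the first call triggers torch.utils.cpp_extension.load, which JIT-compiles unpack.cu:

import torch

packed = torch.randint(0, 256, (16, 8), dtype=torch.uint8, device="cuda")
unpacked = torch.ops.quanto_ext.unpack(packed, 4)  # dispatches to unpack_cuda
assert unpacked.shape == (32, 8)  # first dim grows by 8 / bits
# Cross-check against the CPU reference sketch above.
assert torch.equal(unpacked.cpu(), reference_unpack(packed.cpu(), 4))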
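
The README's three steps can be illustrated with the dqmm operation that the C++ extension already implements (see the cpp/__init__.py hunk above). Assuming a hypothetical dqmm.cu had been added to the sources list and bound in pybind_module.cpp (steps 1 and 2), step 3 would mirror unpack_cuda in quanto/library/ext/cuda/__init__.py; the names here are illustrative, not part of the patch:

@torch.library.impl("quanto_ext::dqmm", ["CUDA"])
def dqmm_cuda(input: torch.Tensor, other: torch.Tensor, other_scale: torch.Tensor):
    # Delegate to the (hypothetical) dqmm binding exposed by the CUDA extension.
    return ext().dqmm(input, other, other_scale)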