From 472a45a10c2b44d0ea6b8c19de863422b1839f55 Mon Sep 17 00:00:00 2001 From: Pengfei Xuan Date: Sun, 24 Mar 2024 15:12:42 -0400 Subject: [PATCH 01/19] Add initial test for MPS backend --- CMakeLists.txt | 10 +++++++++- model.cpp | 6 +++--- opensplat.cpp | 5 ++++- rasterize_gaussians.cpp | 2 +- simple_trainer.cpp | 2 +- vendor/gsplat-cpu/gsplat_cpu.cpp | 4 ++-- 6 files changed, 20 insertions(+), 9 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 806c4f2..76f19cd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.21) project(opensplat) set(OPENSPLAT_BUILD_SIMPLE_TRAINER OFF CACHE BOOL "Build simple trainer applications") -set(GPU_RUNTIME "CUDA" CACHE STRING "HIP or CUDA") +set(GPU_RUNTIME "CUDA" CACHE STRING "HIP or CUDA or MPS") set(OPENCV_DIR "OPENCV_DIR-NOTFOUND" CACHE PATH "Path to the OPENCV installation directory") set(OPENSPLAT_MAX_CUDA_COMPATIBILITY OFF CACHE BOOL "Build for maximum CUDA device compatibility") @@ -81,6 +81,10 @@ elseif(GPU_RUNTIME STREQUAL "HIP") set(ROCM_ROOT "/opt/rocm" CACHE PATH "Root directory of the ROCm installation") endif() list(APPEND CMAKE_PREFIX_PATH "${ROCM_ROOT}") +elseif(GPU_RUNTIME STREQUAL "MPS") + set(USE_MPS ON CACHE BOOL "Use MPS for GPU acceleration") +else() + set(GPU_RUNTIME "CPU") endif() set(CMAKE_CXX_STANDARD 17) @@ -135,6 +139,8 @@ if(GPU_RUNTIME STREQUAL "HIP") target_compile_definitions(opensplat PRIVATE USE_HIP __HIP_PLATFORM_AMD__) elseif(GPU_RUNTIME STREQUAL "CUDA") target_compile_definitions(opensplat PRIVATE USE_CUDA) +elseif(GPU_RUNTIME STREQUAL "MPS") + target_compile_definitions(opensplat PRIVATE USE_MPS) endif() if(OPENSPLAT_BUILD_SIMPLE_TRAINER) @@ -149,6 +155,8 @@ if(OPENSPLAT_BUILD_SIMPLE_TRAINER) target_compile_definitions(simple_trainer PRIVATE USE_HIP __HIP_PLATFORM_AMD__) elseif(GPU_RUNTIME STREQUAL "CUDA") target_compile_definitions(simple_trainer PRIVATE USE_CUDA) + elseif(GPU_RUNTIME STREQUAL "MPS") + 
target_compile_definitions(simple_trainer PRIVATE USE_MPS) endif() endif() diff --git a/model.cpp b/model.cpp index 7b53ee4..bbfd075 100644 --- a/model.cpp +++ b/model.cpp @@ -89,7 +89,7 @@ torch::Tensor Model::forward(Camera& cam, int step){ torch::Tensor camDepths; // CPU-only torch::Tensor rgb; - if (device == torch::kCPU){ + if (device == torch::kMPS){ auto p = ProjectGaussiansCPU::apply(means, torch::exp(scales), 1, @@ -149,7 +149,7 @@ torch::Tensor Model::forward(Camera& cam, int step){ int degreesToUse = (std::min)(step / shDegreeInterval, shDegree); torch::Tensor rgbs; - if (device == torch::kCPU){ + if (device == torch::kMPS){ rgbs = SphericalHarmonicsCPU::apply(degreesToUse, viewDirs, colors); }else{ #if defined(USE_HIP) || defined(USE_CUDA) @@ -159,7 +159,7 @@ torch::Tensor Model::forward(Camera& cam, int step){ rgbs = torch::clamp_min(rgbs + 0.5f, 0.0f); - if (device == torch::kCPU){ + if (device == torch::kMPS){ rgb = RasterizeGaussiansCPU::apply( xys, radii, diff --git a/opensplat.cpp b/opensplat.cpp index 180e864..d9cd221 100644 --- a/opensplat.cpp +++ b/opensplat.cpp @@ -81,10 +81,13 @@ int main(int argc, char *argv[]){ torch::Device device = torch::kCPU; int displayStep = 1; - if (torch::cuda::is_available() && result.count("cpu") == 0) { + if (torch::hasCUDA() && result.count("cpu") == 0) { std::cout << "Using CUDA" << std::endl; device = torch::kCUDA; displayStep = 10; + } else if (torch::hasMPS() && result.count("cpu") == 0) { + std::cout << "Using MPS" << std::endl; + device = torch::kMPS; }else{ std::cout << "Using CPU" << std::endl; } diff --git a/rasterize_gaussians.cpp b/rasterize_gaussians.cpp index 154b5e9..ab76adb 100644 --- a/rasterize_gaussians.cpp +++ b/rasterize_gaussians.cpp @@ -210,7 +210,7 @@ tensor_list RasterizeGaussiansCPU::backward(AutogradContext *ctx, tensor_list gr v_outImg, v_outAlpha); - delete[] px2gid; + // delete[] px2gid; torch::Tensor v_xy = std::get<0>(t); diff --git a/simple_trainer.cpp b/simple_trainer.cpp index 
2da1704..e148853 100644 --- a/simple_trainer.cpp +++ b/simple_trainer.cpp @@ -139,7 +139,7 @@ int main(int argc, char **argv){ torch::Tensor outImg; for (size_t i = 0; i < iterations; i++){ - if (device == torch::kCPU){ + if (device == torch::kMPS){ auto p = ProjectGaussiansCPU::apply(means, scales, 1, quats, viewMat, viewMat, focal, focal, diff --git a/vendor/gsplat-cpu/gsplat_cpu.cpp b/vendor/gsplat-cpu/gsplat_cpu.cpp index 9385265..849b08d 100644 --- a/vendor/gsplat-cpu/gsplat_cpu.cpp +++ b/vendor/gsplat-cpu/gsplat_cpu.cpp @@ -162,7 +162,7 @@ std::tuple< torch::Tensor outImg = torch::zeros({height, width, channels}, torch::TensorOptions().dtype(torch::kFloat32).device(device)); torch::Tensor finalTs = torch::ones({height, width}, torch::TensorOptions().dtype(torch::kFloat32).device(device)); - torch::Tensor done = torch::zeros({height, width}, torch::TensorOptions().dtype(torch::kBool).device(device)); + torch::Tensor done = torch::zeros({height, width}, torch::TensorOptions().dtype(torch::kBool).device(device)).fill_(false); torch::Tensor sqCov2dX = 3.0f * torch::sqrt(cov2d.index({"...", 0, 0})); torch::Tensor sqCov2dY = 3.0f * torch::sqrt(cov2d.index({"...", 1, 1})); @@ -205,7 +205,7 @@ std::tuple< for (int i = minx; i < maxx; i++){ for (int j = miny; j < maxy; j++){ - size_t pixIdx = (i * width + j); + size_t pixIdx = (i * width/2 + j); if (pDone[pixIdx]) continue; float xCam = gX - j; From b8806715f0bb909b1209d440e716c3cb1cb19db7 Mon Sep 17 00:00:00 2001 From: Andrew Chan Date: Mon, 8 Apr 2024 19:37:52 -0700 Subject: [PATCH 02/19] convert to CPU before using CPU rasterizer --- .gitignore | 5 ++++ rasterize_gaussians.cpp | 52 +++++++++++++++++++++-------------------- 2 files changed, 32 insertions(+), 25 deletions(-) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..b4aa536 --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +# MacOS +.DS_Store + +# build +build/ \ No newline at end of file diff --git 
a/rasterize_gaussians.cpp b/rasterize_gaussians.cpp index ab76adb..049a967 100644 --- a/rasterize_gaussians.cpp +++ b/rasterize_gaussians.cpp @@ -155,24 +155,25 @@ torch::Tensor RasterizeGaussiansCPU::forward(AutogradContext *ctx, ){ int numPoints = xys.size(0); + ctx->saved_data["imgWidth"] = imgWidth; + ctx->saved_data["imgHeight"] = imgHeight; + torch::Device device = xys.device(); auto t = rasterize_forward_tensor_cpu(imgWidth, imgHeight, - xys, - conics, - colors, - opacity, - background, - cov2d, - camDepths + xys.to(torch::kCPU), + conics.to(torch::kCPU), + colors.to(torch::kCPU), + opacity.to(torch::kCPU), + background.to(torch::kCPU), + cov2d.to(torch::kCPU), + camDepths.to(torch::kCPU) ); // Final image - torch::Tensor outImg = std::get<0>(t); + torch::Tensor outImg = std::get<0>(t).to(device); - torch::Tensor finalTs = std::get<1>(t); + torch::Tensor finalTs = std::get<1>(t).to(device); std::vector *px2gid = std::get<2>(t); - ctx->saved_data["imgWidth"] = imgWidth; - ctx->saved_data["imgHeight"] = imgHeight; ctx->saved_data["px2gid"] = reinterpret_cast(px2gid); ctx->save_for_backward({ xys, conics, colors, opacity, background, cov2d, camDepths, finalTs }); @@ -196,27 +197,28 @@ tensor_list RasterizeGaussiansCPU::backward(AutogradContext *ctx, tensor_list gr torch::Tensor finalTs = saved[7]; torch::Tensor v_outAlpha = torch::zeros_like(v_outImg.index({"...", 0})); + torch::Device device = xys.device(); auto t = rasterize_backward_tensor_cpu(imgHeight, imgWidth, - xys, - conics, - colors, - opacity, - background, - cov2d, - camDepths, - finalTs, + xys.to(torch::kCPU), + conics.to(torch::kCPU), + colors.to(torch::kCPU), + opacity.to(torch::kCPU), + background.to(torch::kCPU), + cov2d.to(torch::kCPU), + camDepths.to(torch::kCPU), + finalTs.to(torch::kCPU), px2gid, - v_outImg, - v_outAlpha); + v_outImg.to(torch::kCPU), + v_outAlpha.to(torch::kCPU)); // delete[] px2gid; - torch::Tensor v_xy = std::get<0>(t); - torch::Tensor v_conic = std::get<1>(t); - 
torch::Tensor v_colors = std::get<2>(t); - torch::Tensor v_opacity = std::get<3>(t); + torch::Tensor v_xy = std::get<0>(t).to(device); + torch::Tensor v_conic = std::get<1>(t).to(device); + torch::Tensor v_colors = std::get<2>(t).to(device); + torch::Tensor v_opacity = std::get<3>(t).to(device); torch::Tensor none; return { v_xy, From 52cfc107f801a2cdb8c91c74f6c863681868e34f Mon Sep 17 00:00:00 2001 From: Andrew Chan Date: Tue, 9 Apr 2024 14:36:12 -0700 Subject: [PATCH 03/19] scaffolding for metal kernels --- CMakeLists.txt | 35 +++ README.md | 9 +- gsplat.hpp | 4 + model.cpp | 12 +- project_gaussians.cpp | 2 +- project_gaussians.hpp | 2 +- rasterize_gaussians.cpp | 2 +- rasterize_gaussians.hpp | 2 +- simple_trainer.cpp | 7 +- spherical_harmonics.cpp | 2 +- spherical_harmonics.hpp | 2 +- vendor/gsplat-metal/bindings.h | 183 +++++++++++++ vendor/gsplat-metal/gsplat_metal.metal | 3 + vendor/gsplat-metal/gsplat_metal.mm | 347 +++++++++++++++++++++++++ 14 files changed, 597 insertions(+), 15 deletions(-) create mode 100644 vendor/gsplat-metal/bindings.h create mode 100644 vendor/gsplat-metal/gsplat_metal.metal create mode 100644 vendor/gsplat-metal/gsplat_metal.mm diff --git a/CMakeLists.txt b/CMakeLists.txt index 76f19cd..0e05891 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -5,6 +5,7 @@ set(OPENSPLAT_BUILD_SIMPLE_TRAINER OFF CACHE BOOL "Build simple trainer applicat set(GPU_RUNTIME "CUDA" CACHE STRING "HIP or CUDA or MPS") set(OPENCV_DIR "OPENCV_DIR-NOTFOUND" CACHE PATH "Path to the OPENCV installation directory") set(OPENSPLAT_MAX_CUDA_COMPATIBILITY OFF CACHE BOOL "Build for maximum CUDA device compatibility") +set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}) if(NOT CMAKE_BUILD_TYPE) set(CMAKE_BUILD_TYPE "Release" CACHE STRING "Choose the type of build, options are: Debug Release RelWithDebInfo MinSizeRel." 
FORCE) @@ -82,6 +83,12 @@ elseif(GPU_RUNTIME STREQUAL "HIP") endif() list(APPEND CMAKE_PREFIX_PATH "${ROCM_ROOT}") elseif(GPU_RUNTIME STREQUAL "MPS") + find_library(FOUNDATION_LIBRARY Foundation REQUIRED) + find_library(METAL_FRAMEWORK Metal REQUIRED) + find_library(METALKIT_FRAMEWORK MetalKit REQUIRED) + message(STATUS "Metal framework found") + + set(XC_FLAGS -O3) set(USE_MPS ON CACHE BOOL "Use MPS for GPU acceleration") else() set(GPU_RUNTIME "CPU") @@ -123,6 +130,34 @@ if((GPU_RUNTIME STREQUAL "CUDA") OR (GPU_RUNTIME STREQUAL "HIP")) ${TORCH_INCLUDE_DIRS} ) set_target_properties(gsplat PROPERTIES LINKER_LANGUAGE CXX) +elseif(GPU_RUNTIME STREQUAL "MPS") + add_library(gsplat vendor/gsplat-metal/gsplat_metal.mm) + list(APPEND GSPLAT_LIBS gsplat) + target_link_libraries(gsplat PRIVATE + ${FOUNDATION_LIBRARY} + ${METAL_FRAMEWORK} + ${METALKIT_FRAMEWORK} + ) + target_include_directories(gsplat PRIVATE + ${PROJECT_SOURCE_DIR}/vendor/glm + ${TORCH_INCLUDE_DIRS} + ) + # copy shader files to bin directory + configure_file(vendor/gsplat-metal/gsplat_metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/gsplat_metal.metal COPYONLY) + add_custom_command( + OUTPUT ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib + COMMAND xcrun -sdk macosx metal ${XC_FLAGS} -c ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/gsplat_metal.metal -o ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/gsplat_metal.air + COMMAND xcrun -sdk macosx metallib ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/gsplat_metal.air -o ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib + COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/gsplat_metal.air + COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/gsplat_metal.metal + DEPENDS vendor/gsplat-metal/gsplat_metal.metal + COMMENT "Compiling Metal kernels" + ) + + add_custom_target( + gsplat_metal ALL + DEPENDS ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib + ) endif() add_library(gsplat_cpu vendor/gsplat-cpu/gsplat_cpu.cpp) diff --git a/README.md b/README.md index dcd3f0b..7959f8f 100644 --- a/README.md +++ 
b/README.md @@ -121,16 +121,23 @@ brew install opencv brew install pytorch ``` +You will also need to install Xcode and the Xcode command line tools to compile with metal support (if you are fine with CPU-only acceleration, you can skip this step): +1. Install Xcode from the Apple App Store. +2. Install the command line tools with `xcode-select --install`. This might do nothing on your machine. +3. If `xcode-select --print-path` prints `/Library/Developer/CommandLineTools`,then run `sudo xcode-select --switch /Applications/Xcode.app/Contents/Developer`. + Then run: ``` git clone https://github.com/pierotofy/OpenSplat OpenSplat cd OpenSplat mkdir build && cd build -cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch/ .. && make -j$(nproc) +cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch/ -DGPU_RUNTIME=MPS .. && make -j$(sysctl -n hw.logicalcpu) ./opensplat ``` +If building CPU-only, remove `-DGPU_RUNTIME=MPS`. + :warning: You will probably get a *libc10.dylib can’t be opened because Apple cannot check it for malicious software* error on first run. Open **System Settings** and go to **Privacy & Security** and find the **Allow** button. You might need to repeat this several times until all torch libraries are loaded. 
## Docker Build diff --git a/gsplat.hpp b/gsplat.hpp index d427655..fa887af 100644 --- a/gsplat.hpp +++ b/gsplat.hpp @@ -7,6 +7,10 @@ #include "vendor/gsplat/bindings.h" #endif +#if defined(USE_MPS) +#include "vendor/gsplat-metal/bindings.h" +#endif + #include "vendor/gsplat-cpu/bindings.h" #endif \ No newline at end of file diff --git a/model.cpp b/model.cpp index bbfd075..10a2784 100644 --- a/model.cpp +++ b/model.cpp @@ -89,7 +89,7 @@ torch::Tensor Model::forward(Camera& cam, int step){ torch::Tensor camDepths; // CPU-only torch::Tensor rgb; - if (device == torch::kMPS){ + if (device == torch::kCPU){ auto p = ProjectGaussiansCPU::apply(means, torch::exp(scales), 1, @@ -108,7 +108,7 @@ torch::Tensor Model::forward(Camera& cam, int step){ cov2d = p[3]; camDepths = p[4]; }else{ - #if defined(USE_HIP) || defined(USE_CUDA) + #if defined(USE_HIP) || defined(USE_CUDA) || defined(USE_MPS) TileBounds tileBounds = std::make_tuple((width + BLOCK_X - 1) / BLOCK_X, (height + BLOCK_Y - 1) / BLOCK_Y, @@ -149,17 +149,17 @@ torch::Tensor Model::forward(Camera& cam, int step){ int degreesToUse = (std::min)(step / shDegreeInterval, shDegree); torch::Tensor rgbs; - if (device == torch::kMPS){ + if (device == torch::kCPU){ rgbs = SphericalHarmonicsCPU::apply(degreesToUse, viewDirs, colors); }else{ - #if defined(USE_HIP) || defined(USE_CUDA) + #if defined(USE_HIP) || defined(USE_CUDA) || defined(USE_MPS) rgbs = SphericalHarmonics::apply(degreesToUse, viewDirs, colors); #endif } rgbs = torch::clamp_min(rgbs + 0.5f, 0.0f); - if (device == torch::kMPS){ + if (device == torch::kCPU){ rgb = RasterizeGaussiansCPU::apply( xys, radii, @@ -172,7 +172,7 @@ torch::Tensor Model::forward(Camera& cam, int step){ width, backgroundColor); }else{ - #if defined(USE_HIP) || defined(USE_CUDA) + #if defined(USE_HIP) || defined(USE_CUDA) || defined(USE_MPS) rgb = RasterizeGaussians::apply( xys, depths, diff --git a/project_gaussians.cpp b/project_gaussians.cpp index d57e1d6..9de74b5 100644 --- 
a/project_gaussians.cpp +++ b/project_gaussians.cpp @@ -1,6 +1,6 @@ #include "project_gaussians.hpp" -#if defined(USE_HIP) || defined(USE_CUDA) +#if defined(USE_HIP) || defined(USE_CUDA) || defined(USE_MPS) variable_list ProjectGaussians::forward(AutogradContext *ctx, torch::Tensor means, diff --git a/project_gaussians.hpp b/project_gaussians.hpp index 8891d22..c7a5bcb 100644 --- a/project_gaussians.hpp +++ b/project_gaussians.hpp @@ -7,7 +7,7 @@ using namespace torch::autograd; -#if defined(USE_HIP) || defined(USE_CUDA) +#if defined(USE_HIP) || defined(USE_CUDA) || defined(USE_MPS) class ProjectGaussians : public Function{ public: diff --git a/rasterize_gaussians.cpp b/rasterize_gaussians.cpp index 049a967..875134b 100644 --- a/rasterize_gaussians.cpp +++ b/rasterize_gaussians.cpp @@ -1,7 +1,7 @@ #include "rasterize_gaussians.hpp" #include "gsplat.hpp" -#if defined(USE_HIP) || defined(USE_CUDA) +#if defined(USE_HIP) || defined(USE_CUDA) || defined(USE_MPS) std::tuple{ public: diff --git a/vendor/gsplat-metal/bindings.h b/vendor/gsplat-metal/bindings.h new file mode 100644 index 0000000..bc98eb8 --- /dev/null +++ b/vendor/gsplat-metal/bindings.h @@ -0,0 +1,183 @@ +#include +#include +#include +#include +#include + +#define CHECK_MPS(x) TORCH_CHECK(x.is_mps(), #x " must be a MPS tensor") +#define CHECK_CONTIGUOUS(x) \ + TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) \ + CHECK_MPS(x); \ + CHECK_CONTIGUOUS(x) + +std::tuple< + torch::Tensor, // output conics + torch::Tensor> // output radii +compute_cov2d_bounds_tensor(const int num_pts, torch::Tensor &A); + +torch::Tensor compute_sh_forward_tensor( + unsigned num_points, + unsigned degree, + unsigned degrees_to_use, + torch::Tensor &viewdirs, + torch::Tensor &coeffs +); + +torch::Tensor compute_sh_backward_tensor( + unsigned num_points, + unsigned degree, + unsigned degrees_to_use, + torch::Tensor &viewdirs, + torch::Tensor &v_colors +); + +std::tuple< + torch::Tensor, + 
torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor> +project_gaussians_forward_tensor( + const int num_points, + torch::Tensor &means3d, + torch::Tensor &scales, + const float glob_scale, + torch::Tensor &quats, + torch::Tensor &viewmat, + torch::Tensor &projmat, + const float fx, + const float fy, + const float cx, + const float cy, + const unsigned img_height, + const unsigned img_width, + const std::tuple tile_bounds, + const float clip_thresh +); + +std::tuple< + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor> +project_gaussians_backward_tensor( + const int num_points, + torch::Tensor &means3d, + torch::Tensor &scales, + const float glob_scale, + torch::Tensor &quats, + torch::Tensor &viewmat, + torch::Tensor &projmat, + const float fx, + const float fy, + const float cx, + const float cy, + const unsigned img_height, + const unsigned img_width, + torch::Tensor &cov3d, + torch::Tensor &radii, + torch::Tensor &conics, + torch::Tensor &v_xy, + torch::Tensor &v_depth, + torch::Tensor &v_conic +); + + +std::tuple map_gaussian_to_intersects_tensor( + const int num_points, + const int num_intersects, + const torch::Tensor &xys, + const torch::Tensor &depths, + const torch::Tensor &radii, + const torch::Tensor &cum_tiles_hit, + const std::tuple tile_bounds +); + +torch::Tensor get_tile_bin_edges_tensor( + int num_intersects, + const torch::Tensor &isect_ids_sorted +); + +std::tuple< + torch::Tensor, + torch::Tensor, + torch::Tensor +> rasterize_forward_tensor( + const std::tuple tile_bounds, + const std::tuple block, + const std::tuple img_size, + const torch::Tensor &gaussian_ids_sorted, + const torch::Tensor &tile_bins, + const torch::Tensor &xys, + const torch::Tensor &conics, + const torch::Tensor &colors, + const torch::Tensor &opacities, + const torch::Tensor &background +); + +std::tuple< + torch::Tensor, + torch::Tensor, + torch::Tensor +> nd_rasterize_forward_tensor( + const std::tuple 
tile_bounds, + const std::tuple block, + const std::tuple img_size, + const torch::Tensor &gaussian_ids_sorted, + const torch::Tensor &tile_bins, + const torch::Tensor &xys, + const torch::Tensor &conics, + const torch::Tensor &colors, + const torch::Tensor &opacities, + const torch::Tensor &background +); + + +std:: + tuple< + torch::Tensor, // dL_dxy + torch::Tensor, // dL_dconic + torch::Tensor, // dL_dcolors + torch::Tensor // dL_dopacity + > + nd_rasterize_backward_tensor( + const unsigned img_height, + const unsigned img_width, + const torch::Tensor &gaussians_ids_sorted, + const torch::Tensor &tile_bins, + const torch::Tensor &xys, + const torch::Tensor &conics, + const torch::Tensor &colors, + const torch::Tensor &opacities, + const torch::Tensor &background, + const torch::Tensor &final_Ts, + const torch::Tensor &final_idx, + const torch::Tensor &v_output, // dL_dout_color + const torch::Tensor &v_output_alpha + ); + +std:: + tuple< + torch::Tensor, // dL_dxy + torch::Tensor, // dL_dconic + torch::Tensor, // dL_dcolors + torch::Tensor // dL_dopacity + > + rasterize_backward_tensor( + const unsigned img_height, + const unsigned img_width, + const torch::Tensor &gaussians_ids_sorted, + const torch::Tensor &tile_bins, + const torch::Tensor &xys, + const torch::Tensor &conics, + const torch::Tensor &colors, + const torch::Tensor &opacities, + const torch::Tensor &background, + const torch::Tensor &final_Ts, + const torch::Tensor &final_idx, + const torch::Tensor &v_output, // dL_dout_color + const torch::Tensor &v_output_alpha + ); \ No newline at end of file diff --git a/vendor/gsplat-metal/gsplat_metal.metal b/vendor/gsplat-metal/gsplat_metal.metal new file mode 100644 index 0000000..ccee151 --- /dev/null +++ b/vendor/gsplat-metal/gsplat_metal.metal @@ -0,0 +1,3 @@ +#include + +using namespace metal; \ No newline at end of file diff --git a/vendor/gsplat-metal/gsplat_metal.mm b/vendor/gsplat-metal/gsplat_metal.mm new file mode 100644 index 0000000..f417382 
--- /dev/null +++ b/vendor/gsplat-metal/gsplat_metal.mm @@ -0,0 +1,347 @@ +#import "bindings.h" + +#import + +#import + +// This function is used in both host and device code +// TODO(achan): Do I need to make this callable from the metal device? +unsigned num_sh_bases(const unsigned degree) { + if (degree == 0) + return 1; + if (degree == 1) + return 4; + if (degree == 2) + return 9; + if (degree == 3) + return 16; + return 25; +} + +std::tuple< + torch::Tensor, // output conics + torch::Tensor> // output radii +compute_cov2d_bounds_tensor(const int num_pts, torch::Tensor &covs2d) { + CHECK_INPUT(covs2d); + torch::Tensor conics = torch::zeros( + {num_pts, covs2d.size(1)}, covs2d.options().dtype(torch::kFloat32) + ); + torch::Tensor radii = + torch::zeros({num_pts, 1}, covs2d.options().dtype(torch::kFloat32)); + + return std::make_tuple(conics, radii); +} + +torch::Tensor compute_sh_forward_tensor( + unsigned num_points, + unsigned degree, + unsigned degrees_to_use, + torch::Tensor &viewdirs, + torch::Tensor &coeffs +) { + unsigned num_bases = num_sh_bases(degree); + if (coeffs.ndimension() != 3 || coeffs.size(0) != num_points || + coeffs.size(1) != num_bases || coeffs.size(2) != 3) { + AT_ERROR("coeffs must have dimensions (N, D, 3)"); + } + torch::Tensor colors = torch::empty({num_points, 3}, coeffs.options()); + return colors; +} + +torch::Tensor compute_sh_backward_tensor( + unsigned num_points, + unsigned degree, + unsigned degrees_to_use, + torch::Tensor &viewdirs, + torch::Tensor &v_colors +) { + if (viewdirs.ndimension() != 2 || viewdirs.size(0) != num_points || + viewdirs.size(1) != 3) { + AT_ERROR("viewdirs must have dimensions (N, 3)"); + } + if (v_colors.ndimension() != 2 || v_colors.size(0) != num_points || + v_colors.size(1) != 3) { + AT_ERROR("v_colors must have dimensions (N, 3)"); + } + unsigned num_bases = num_sh_bases(degree); + torch::Tensor v_coeffs = + torch::zeros({num_points, num_bases, 3}, v_colors.options()); + return v_coeffs; +} + 
+std::tuple< + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor> +project_gaussians_forward_tensor( + const int num_points, + torch::Tensor &means3d, + torch::Tensor &scales, + const float glob_scale, + torch::Tensor &quats, + torch::Tensor &viewmat, + torch::Tensor &projmat, + const float fx, + const float fy, + const float cx, + const float cy, + const unsigned img_height, + const unsigned img_width, + const std::tuple tile_bounds, + const float clip_thresh +) { + // Triangular covariance. + torch::Tensor cov3d_d = + torch::zeros({num_points, 6}, means3d.options().dtype(torch::kFloat32)); + torch::Tensor xys_d = + torch::zeros({num_points, 2}, means3d.options().dtype(torch::kFloat32)); + torch::Tensor depths_d = + torch::zeros({num_points}, means3d.options().dtype(torch::kFloat32)); + torch::Tensor radii_d = + torch::zeros({num_points}, means3d.options().dtype(torch::kInt32)); + torch::Tensor conics_d = + torch::zeros({num_points, 3}, means3d.options().dtype(torch::kFloat32)); + torch::Tensor num_tiles_hit_d = + torch::zeros({num_points}, means3d.options().dtype(torch::kInt32)); + + return std::make_tuple( + cov3d_d, xys_d, depths_d, radii_d, conics_d, num_tiles_hit_d + ); +} + +std::tuple< + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor> +project_gaussians_backward_tensor( + const int num_points, + torch::Tensor &means3d, + torch::Tensor &scales, + const float glob_scale, + torch::Tensor &quats, + torch::Tensor &viewmat, + torch::Tensor &projmat, + const float fx, + const float fy, + const float cx, + const float cy, + const unsigned img_height, + const unsigned img_width, + torch::Tensor &cov3d, + torch::Tensor &radii, + torch::Tensor &conics, + torch::Tensor &v_xy, + torch::Tensor &v_depth, + torch::Tensor &v_conic +) { + // Triangular covariance. 
+ torch::Tensor v_cov2d = + torch::zeros({num_points, 3}, means3d.options().dtype(torch::kFloat32)); + torch::Tensor v_cov3d = + torch::zeros({num_points, 6}, means3d.options().dtype(torch::kFloat32)); + torch::Tensor v_mean3d = + torch::zeros({num_points, 3}, means3d.options().dtype(torch::kFloat32)); + torch::Tensor v_scale = + torch::zeros({num_points, 3}, means3d.options().dtype(torch::kFloat32)); + torch::Tensor v_quat = + torch::zeros({num_points, 4}, means3d.options().dtype(torch::kFloat32)); + + return std::make_tuple(v_cov2d, v_cov3d, v_mean3d, v_scale, v_quat); +} + + +std::tuple map_gaussian_to_intersects_tensor( + const int num_points, + const int num_intersects, + const torch::Tensor &xys, + const torch::Tensor &depths, + const torch::Tensor &radii, + const torch::Tensor &cum_tiles_hit, + const std::tuple tile_bounds +) { + CHECK_INPUT(xys); + CHECK_INPUT(depths); + CHECK_INPUT(radii); + CHECK_INPUT(cum_tiles_hit); + + torch::Tensor gaussian_ids_unsorted = + torch::zeros({num_intersects}, xys.options().dtype(torch::kInt32)); + torch::Tensor isect_ids_unsorted = + torch::zeros({num_intersects}, xys.options().dtype(torch::kInt64)); + + return std::make_tuple(isect_ids_unsorted, gaussian_ids_unsorted); +} + +torch::Tensor get_tile_bin_edges_tensor( + int num_intersects, + const torch::Tensor &isect_ids_sorted +) { + CHECK_INPUT(isect_ids_sorted); + torch::Tensor tile_bins = torch::zeros( + {num_intersects, 2}, isect_ids_sorted.options().dtype(torch::kInt32) + ); + + return tile_bins; +} + +std::tuple< + torch::Tensor, + torch::Tensor, + torch::Tensor +> rasterize_forward_tensor( + const std::tuple tile_bounds, + const std::tuple block, + const std::tuple img_size, + const torch::Tensor &gaussian_ids_sorted, + const torch::Tensor &tile_bins, + const torch::Tensor &xys, + const torch::Tensor &conics, + const torch::Tensor &colors, + const torch::Tensor &opacities, + const torch::Tensor &background +) { + CHECK_INPUT(gaussian_ids_sorted); + 
CHECK_INPUT(tile_bins); + CHECK_INPUT(xys); + CHECK_INPUT(conics); + CHECK_INPUT(colors); + CHECK_INPUT(opacities); + CHECK_INPUT(background); + + const int channels = colors.size(1); + const int img_width = std::get<0>(img_size); + const int img_height = std::get<1>(img_size); + + torch::Tensor out_img = torch::zeros( + {img_height, img_width, channels}, xys.options().dtype(torch::kFloat32) + ); + torch::Tensor final_Ts = torch::zeros( + {img_height, img_width}, xys.options().dtype(torch::kFloat32) + ); + torch::Tensor final_idx = torch::zeros( + {img_height, img_width}, xys.options().dtype(torch::kInt32) + ); + + return std::make_tuple(out_img, final_Ts, final_idx); +} + +std::tuple< + torch::Tensor, + torch::Tensor, + torch::Tensor +> nd_rasterize_forward_tensor( + const std::tuple tile_bounds, + const std::tuple block, + const std::tuple img_size, + const torch::Tensor &gaussian_ids_sorted, + const torch::Tensor &tile_bins, + const torch::Tensor &xys, + const torch::Tensor &conics, + const torch::Tensor &colors, + const torch::Tensor &opacities, + const torch::Tensor &background +) { + CHECK_INPUT(gaussian_ids_sorted); + CHECK_INPUT(tile_bins); + CHECK_INPUT(xys); + CHECK_INPUT(conics); + CHECK_INPUT(colors); + CHECK_INPUT(opacities); + CHECK_INPUT(background); + + const int channels = colors.size(1); + const int img_width = std::get<0>(img_size); + const int img_height = std::get<1>(img_size); + + torch::Tensor out_img = torch::zeros( + {img_height, img_width, channels}, xys.options().dtype(torch::kFloat32) + ); + torch::Tensor final_Ts = torch::zeros( + {img_height, img_width}, xys.options().dtype(torch::kFloat32) + ); + torch::Tensor final_idx = torch::zeros( + {img_height, img_width}, xys.options().dtype(torch::kInt32) + ); + + return std::make_tuple(out_img, final_Ts, final_idx); +} + + +std:: + tuple< + torch::Tensor, // dL_dxy + torch::Tensor, // dL_dconic + torch::Tensor, // dL_dcolors + torch::Tensor // dL_dopacity + > + nd_rasterize_backward_tensor( + 
const unsigned img_height, + const unsigned img_width, + const torch::Tensor &gaussians_ids_sorted, + const torch::Tensor &tile_bins, + const torch::Tensor &xys, + const torch::Tensor &conics, + const torch::Tensor &colors, + const torch::Tensor &opacities, + const torch::Tensor &background, + const torch::Tensor &final_Ts, + const torch::Tensor &final_idx, + const torch::Tensor &v_output, // dL_dout_color + const torch::Tensor &v_output_alpha + ) { + CHECK_INPUT(xys); + CHECK_INPUT(colors); + + const int num_points = xys.size(0); + const int channels = colors.size(1); + + torch::Tensor v_xy = torch::zeros({num_points, 2}, xys.options()); + torch::Tensor v_conic = torch::zeros({num_points, 3}, xys.options()); + torch::Tensor v_colors = + torch::zeros({num_points, channels}, xys.options()); + torch::Tensor v_opacity = torch::zeros({num_points, 1}, xys.options()); + + return std::make_tuple(v_xy, v_conic, v_colors, v_opacity); +} + +std:: + tuple< + torch::Tensor, // dL_dxy + torch::Tensor, // dL_dconic + torch::Tensor, // dL_dcolors + torch::Tensor // dL_dopacity + > + rasterize_backward_tensor( + const unsigned img_height, + const unsigned img_width, + const torch::Tensor &gaussians_ids_sorted, + const torch::Tensor &tile_bins, + const torch::Tensor &xys, + const torch::Tensor &conics, + const torch::Tensor &colors, + const torch::Tensor &opacities, + const torch::Tensor &background, + const torch::Tensor &final_Ts, + const torch::Tensor &final_idx, + const torch::Tensor &v_output, // dL_dout_color + const torch::Tensor &v_output_alpha + ) { + CHECK_INPUT(xys); + CHECK_INPUT(colors); + + const int num_points = xys.size(0); + const int channels = colors.size(1); + + torch::Tensor v_xy = torch::zeros({num_points, 2}, xys.options()); + torch::Tensor v_conic = torch::zeros({num_points, 3}, xys.options()); + torch::Tensor v_colors = + torch::zeros({num_points, channels}, xys.options()); + torch::Tensor v_opacity = torch::zeros({num_points, 1}, xys.options()); + + return 
std::make_tuple(v_xy, v_conic, v_colors, v_opacity); +} \ No newline at end of file From cfcbf19bcd272c580f68143abf37eaa1ba161b56 Mon Sep 17 00:00:00 2001 From: Andrew Chan Date: Wed, 10 Apr 2024 16:29:03 -0700 Subject: [PATCH 04/19] wip --- vendor/gsplat-metal/gsplat_metal.metal | 595 ++++++++++++++++++++++++- vendor/gsplat-metal/gsplat_metal.mm | 276 ++++++++++++ 2 files changed, 870 insertions(+), 1 deletion(-) diff --git a/vendor/gsplat-metal/gsplat_metal.metal b/vendor/gsplat-metal/gsplat_metal.metal index ccee151..6045947 100644 --- a/vendor/gsplat-metal/gsplat_metal.metal +++ b/vendor/gsplat-metal/gsplat_metal.metal @@ -1,3 +1,596 @@ #include -using namespace metal; \ No newline at end of file +using namespace metal; + +#define BLOCK_X 16 +#define BLOCK_Y 16 +#define BLOCK_SIZE (BLOCK_X * BLOCK_Y) +#define CHANNELS 3 + +constant float SH_C0 = 0.28209479177387814f; +constant float SH_C1 = 0.4886025119029199f; +constant float SH_C2[] = { + 1.0925484305920792f, + -1.0925484305920792f, + 0.31539156525252005f, + -1.0925484305920792f, + 0.5462742152960396f}; +constant float SH_C3[] = { + -0.5900435899266435f, + 2.890611442640554f, + -0.4570457994644658f, + 0.3731763325901154f, + -0.4570457994644658f, + 1.445305721320277f, + -0.5900435899266435f}; +constant float SH_C4[] = { + 2.5033429417967046f, + -1.7701307697799304, + 0.9461746957575601f, + -0.6690465435572892f, + 0.10578554691520431f, + -0.6690465435572892f, + 0.47308734787878004f, + -1.7701307697799304f, + 0.6258357354491761f}; + +inline uint num_sh_bases(const uint degree) { + if (degree == 0) + return 1; + if (degree == 1) + return 4; + if (degree == 2) + return 9; + if (degree == 3) + return 16; + return 25; +} + +inline float ndc2pix(const float x, const float W, const float cx) { + return 0.5f * W * x + cx - 0.5; +} + +inline void get_bbox( + const float2 center, + const float2 dims, + const int3 img_size, + thread uint2 &bb_min, + thread uint2 &bb_max +) { + // get bounding box with center and dims, 
within bounds + // bounding box coords returned in tile coords, inclusive min, exclusive max + // clamp between 0 and tile bounds + bb_min.x = min(max(0, (int)(center.x - dims.x)), img_size.x); + bb_max.x = min(max(0, (int)(center.x + dims.x + 1)), img_size.x); + bb_min.y = min(max(0, (int)(center.y - dims.y)), img_size.y); + bb_max.y = min(max(0, (int)(center.y + dims.y + 1)), img_size.y); +} + +inline void get_tile_bbox( + const float2 pix_center, + const float pix_radius, + const int3 tile_bounds, + thread uint2 &tile_min, + thread uint2 &tile_max +) { + // gets gaussian dimensions in tile space, i.e. the span of a gaussian in + // tile_grid (image divided into tiles) + float2 tile_center = { + pix_center.x / (float)BLOCK_X, pix_center.y / (float)BLOCK_Y + }; + float2 tile_radius = { + pix_radius / (float)BLOCK_X, pix_radius / (float)BLOCK_Y + }; + get_bbox(tile_center, tile_radius, tile_bounds, tile_min, tile_max); +} + +// helper for applying R * p + T, expect mat to be ROW MAJOR +inline float3 transform_4x3(constant float *mat, const float3 p) { + float3 out = { + mat[0] * p.x + mat[1] * p.y + mat[2] * p.z + mat[3], + mat[4] * p.x + mat[5] * p.y + mat[6] * p.z + mat[7], + mat[8] * p.x + mat[9] * p.y + mat[10] * p.z + mat[11], + }; + return out; +} + +// helper to apply 4x4 transform to 3d vector, return homo coords +// expects mat to be ROW MAJOR +inline float4 transform_4x4(constant float *mat, const float3 p) { + float4 out = { + mat[0] * p.x + mat[1] * p.y + mat[2] * p.z + mat[3], + mat[4] * p.x + mat[5] * p.y + mat[6] * p.z + mat[7], + mat[8] * p.x + mat[9] * p.y + mat[10] * p.z + mat[11], + mat[12] * p.x + mat[13] * p.y + mat[14] * p.z + mat[15], + }; + return out; +} + +inline float3x3 quat_to_rotmat(const float4 quat) { + // quat to rotation matrix + float s = rsqrt( + quat.w * quat.w + quat.x * quat.x + quat.y * quat.y + quat.z * quat.z + ); + float w = quat.x * s; + float x = quat.y * s; + float y = quat.z * s; + float z = quat.w * s; + + // metal 
matrices are column-major + return float3x3( + 1.f - 2.f * (y * y + z * z), + 2.f * (x * y + w * z), + 2.f * (x * z - w * y), + 2.f * (x * y - w * z), + 1.f - 2.f * (x * x + z * z), + 2.f * (y * z + w * x), + 2.f * (x * z + w * y), + 2.f * (y * z - w * x), + 1.f - 2.f * (x * x + y * y) + ); +} + +// device helper for culling near points +inline bool clip_near_plane( + const float3 p, + constant float *viewmat, + thread float3 &p_view, + float thresh +) { + p_view = transform_4x3(viewmat, p); + if (p_view.z <= thresh) { + return true; + } + return false; +} + +inline float3x3 scale_to_mat(const float3 scale, const float glob_scale) { + float3x3 S = float3x3(1.f); + S[0][0] = glob_scale * scale.x; + S[1][1] = glob_scale * scale.y; + S[2][2] = glob_scale * scale.z; + return S; +} + +// device helper to get 3D covariance from scale and quat parameters +inline void scale_rot_to_cov3d( + const float3 scale, const float glob_scale, const float4 quat, device float *cov3d +) { + // printf("quat %.2f %.2f %.2f %.2f\n", quat.x, quat.y, quat.z, quat.w); + float3x3 R = quat_to_rotmat(quat); + // printf("R %.2f %.2f %.2f\n", R[0][0], R[1][1], R[2][2]); + float3x3 S = scale_to_mat(scale, glob_scale); + // printf("S %.2f %.2f %.2f\n", S[0][0], S[1][1], S[2][2]); + + float3x3 M = R * S; + float3x3 tmp = M * transpose(M); + // printf("tmp %.2f %.2f %.2f\n", tmp[0][0], tmp[1][1], tmp[2][2]); + + // save upper right because symmetric + cov3d[0] = tmp[0][0]; + cov3d[1] = tmp[0][1]; + cov3d[2] = tmp[0][2]; + cov3d[3] = tmp[1][1]; + cov3d[4] = tmp[1][2]; + cov3d[5] = tmp[2][2]; +} + +// device helper to approximate projected 2d cov from 3d mean and cov +float3 project_cov3d_ewa( + thread float3& mean3d, + device float* cov3d, + constant float* viewmat, + const float fx, + const float fy, + const float tan_fovx, + const float tan_fovy +) { + // clip the + // we expect row major matrices as input, metal uses column major + // upper 3x3 submatrix + float3x3 W = float3x3( + viewmat[0], + 
viewmat[4], + viewmat[8], + viewmat[1], + viewmat[5], + viewmat[9], + viewmat[2], + viewmat[6], + viewmat[10] + ); + float3 p = float3(viewmat[3], viewmat[7], viewmat[11]); + float3 t = W * float3(mean3d.x, mean3d.y, mean3d.z) + p; + + // clip so that the covariance + float lim_x = 1.3 * tan_fovx; + float lim_y = 1.3 * tan_fovy; + t.x = t.z * min(lim_x, max(-lim_x, t.x / t.z)); + t.y = t.z * min(lim_y, max(-lim_y, t.y / t.z)); + + float rz = 1.f / t.z; + float rz2 = rz * rz; + + // column major + // we only care about the top 2x2 submatrix + float3x3 J = float3x3( + fx * rz, + 0.f, + 0.f, + 0.f, + fy * rz, + 0.f, + -fx * t.x * rz2, + -fy * t.y * rz2, + 0.f + ); + float3x3 T = J * W; + + float3x3 V = float3x3( + cov3d[0], + cov3d[1], + cov3d[2], + cov3d[1], + cov3d[3], + cov3d[4], + cov3d[2], + cov3d[4], + cov3d[5] + ); + + float3x3 cov = T * V * transpose(T); + + // add a little blur along axes and save upper triangular elements + return float3(float(cov[0][0]) + 0.3f, float(cov[0][1]), float(cov[1][1]) + 0.3f); +} + +inline bool compute_cov2d_bounds( + const float3 cov2d, + thread float3 &conic, + thread float &radius +) { + // find eigenvalues of 2d covariance matrix + // expects upper triangular values of cov matrix as float3 + // then compute the radius and conic dimensions + // the conic is the inverse cov2d matrix, represented here with upper + // triangular values. 
+ float det = cov2d.x * cov2d.z - cov2d.y * cov2d.y; + if (det == 0.f) + return false; + float inv_det = 1.f / det; + + // inverse of 2x2 cov2d matrix + conic.x = cov2d.z * inv_det; + conic.y = -cov2d.y * inv_det; + conic.z = cov2d.x * inv_det; + + float b = 0.5f * (cov2d.x + cov2d.z); + float v1 = b + sqrt(max(0.1f, b * b - det)); + float v2 = b - sqrt(max(0.1f, b * b - det)); + // take 3 sigma of covariance + radius = ceil(3.f * sqrt(max(v1, v2))); + return true; +} + +inline float2 project_pix( + constant float *mat, const float3 p, const int3 img_size, const float2 pp +) { + // ROW MAJOR mat + float4 p_hom = transform_4x4(mat, p); + float rw = 1.f / (p_hom.w + 1e-6f); + float3 p_proj = {p_hom.x * rw, p_hom.y * rw, p_hom.z * rw}; + return { + ndc2pix(p_proj.x, img_size.x, pp.x), ndc2pix(p_proj.y, img_size.y, pp.y) + }; +} + +/* + !!!!IMPORTANT!!! + Metal does not support packed arrays of vectorized types like int2, float2, float3, etc. + and instead pads the elements of arrays of these types to fixed alignments. + Use the below functions to read and write from packed arrays of these types. 
+*/
+
+inline int2 read_packed_int2(constant int* arr, int idx) {
+    return int2(arr[2*idx], arr[2*idx+1]);
+}
+
+inline void write_packed_int2(device int* arr, int idx, int2 val) {
+    arr[2*idx] = val.x;
+    arr[2*idx+1] = val.y;
+}
+
+inline float2 read_packed_float2(constant float* arr, int idx) {
+    return float2(arr[2*idx], arr[2*idx+1]);
+}
+
+inline void write_packed_float2(device float* arr, int idx, float2 val) {
+    arr[2*idx] = val.x;
+    arr[2*idx+1] = val.y;
+}
+
+inline int3 read_packed_int3(constant int* arr, int idx) {
+    return int3(arr[3*idx], arr[3*idx+1], arr[3*idx+2]);
+}
+
+inline void write_packed_int3(device int* arr, int idx, int3 val) {
+    arr[3*idx] = val.x;
+    arr[3*idx+1] = val.y;
+    arr[3*idx+2] = val.z;
+}
+
+inline float3 read_packed_float3(constant float* arr, int idx) {
+    return float3(arr[3*idx], arr[3*idx+1], arr[3*idx+2]);
+}
+
+inline void write_packed_float3(device float* arr, int idx, float3 val) {
+    arr[3*idx] = val.x;
+    arr[3*idx+1] = val.y;
+    arr[3*idx+2] = val.z;
+}
+
+inline float4 read_packed_float4(constant float* arr, int idx) {
+    return float4(arr[4*idx], arr[4*idx+1], arr[4*idx+2], arr[4*idx+3]);
+}
+
+inline void write_packed_float4(device float* arr, int idx, float4 val) {
+    arr[4*idx] = val.x;
+    arr[4*idx+1] = val.y;
+    arr[4*idx+2] = val.z;
+    arr[4*idx+3] = val.w;
+}
+
+// Projects each gaussian to screen space: writes its 3D covariance (6 floats), conic, pixel centre, view depth, integer radius and tile count.
+// One thread per gaussian; a gaussian rejected at any early return keeps radii[idx] == 0 and num_tiles_hit[idx] == 0.
+kernel void project_gaussians_forward_kernel(
+    constant int& num_points,
+    constant float* means3d, // float3
+    constant float* scales, // float3
+    constant float& glob_scale,
+    constant float* quats, // float4
+    constant float* viewmat,
+    constant float* projmat,
+    constant float4& intrins,
+    constant int3& img_size,
+    constant int3& tile_bounds,
+    constant float& clip_thresh,
+    device float* covs3d,
+    device float* xys, // float2
+    device float* depths,
+    device int* radii,
+    device float* conics, // float3
+    device int32_t* num_tiles_hit,
+    uint idx [[thread_position_in_grid]]
+) {
+    if (idx >= num_points) {
+        return;
+    }
+    radii[idx] = 0;
+    num_tiles_hit[idx] = 0;
+
+    float3 p_world = read_packed_float3(means3d, idx); // means3d is a packed float array: read all three components, not one scalar splatted into a float3
+    float3 p_view;
+    if (clip_near_plane(p_world, viewmat, p_view, clip_thresh)) {
+        return;
+    }
+
+    // compute the projected covariance
+    float3 scale = read_packed_float3(scales, idx);
+    float4 quat = read_packed_float4(quats, idx);
+    device float *cur_cov3d = &(covs3d[6 * idx]);
+    scale_rot_to_cov3d(scale, glob_scale, quat, cur_cov3d);
+
+    // project to 2d with ewa approximation
+    float fx = intrins.x;
+    float fy = intrins.y;
+    float cx = intrins.z;
+    float cy = intrins.w;
+    float tan_fovx = 0.5f * img_size.x / fx;
+    float tan_fovy = 0.5f * img_size.y / fy;
+    float3 cov2d = project_cov3d_ewa(
+        p_world, cur_cov3d, viewmat, fx, fy, tan_fovx, tan_fovy
+    );
+
+    float3 conic;
+    float radius;
+    bool ok = compute_cov2d_bounds(cov2d, conic, radius);
+    if (!ok) {
+        return; // zero determinant
+    }
+    write_packed_float3(conics, idx, conic);
+
+    // compute the projected mean
+    float2 center = project_pix(projmat, p_world, img_size, {cx, cy});
+    uint2 tile_min, tile_max;
+    get_tile_bbox(center, radius, tile_bounds, tile_min, tile_max);
+    int32_t tile_area = (tile_max.x - tile_min.x) * (tile_max.y - tile_min.y);
+    if (tile_area <= 0) {
+        return;
+    }
+
+    num_tiles_hit[idx] = tile_area;
+    depths[idx] = p_view.z;
+    radii[idx] = (int)radius;
+    write_packed_float2(xys, idx, center);
+}
+
+kernel void rasterize_forward_kernel( // one thread per pixel: blends the gaussians binned to the pixel's tile in sorted order, then adds the background
+    constant int3& tile_bounds,
+    constant int3& img_size,
+    constant uint& channels,
+    constant int32_t* gaussian_ids_sorted,
+    constant int* tile_bins, // int2
+    constant float* xys, // float2
+    constant float* conics, // float3
+    constant float* colors,
+    constant float* opacities,
+    device float* final_Ts,
+    device int* final_index,
+    device float* out_img,
+    constant float* background,
+    constant uint2& blockDim,
+    uint2 blockIdx [[threadgroup_position_in_grid]],
+    uint2 threadIdx [[thread_position_in_threadgroup]]
+) {
+    // current naive implementation where tile data loading is redundant
+    // TODO tile data should be shared between tile threads
+    int32_t tile_id = blockIdx.y * tile_bounds.x + blockIdx.x;
+    int32_t i = blockIdx.y * blockDim.y + threadIdx.y;
+    int32_t j = blockIdx.x * blockDim.x + threadIdx.x;
+    float px = (float)j;
+    float py = (float)i;
+    int32_t pix_id = i * img_size.x + j;
+
+    // return if out of bounds
+    if (i >= img_size.y || j >= img_size.x) {
+        return;
+    }
+
+    // which gaussians to look through in this tile
+    int2 range = read_packed_int2(tile_bins, tile_id);
+    float T = 1.f;
+
+    // iterate over all gaussians and apply rendering EWA equation (e.q. 2 from
+    // paper)
+    int idx;
+    for (idx = range.x; idx < range.y; ++idx) {
+        const int32_t g = gaussian_ids_sorted[idx];
+        const float3 conic = read_packed_float3(conics, g);
+        const float2 center = read_packed_float2(xys, g);
+        const float2 delta = {center.x - px, center.y - py};
+
+        // Mahalanobis distance (here referred to as sigma) measures how many
+        // standard deviations away distance delta is. sigma = 0.5 * (d.T * conic
+        // * d); the sign is applied below as exp(-sigma)
+        const float sigma =
+            0.5f * (conic.x * delta.x * delta.x + conic.z * delta.y * delta.y) +
+            conic.y * delta.x * delta.y;
+        if (sigma < 0.f) {
+            continue;
+        }
+        const float opac = opacities[g];
+
+        const float alpha = min(0.999f, opac * exp(-sigma));
+
+        // break out conditions
+        if (alpha < 1.f / 255.f) {
+            continue;
+        }
+        const float next_T = T * (1.f - alpha);
+        if (next_T <= 1e-4f) {
+            // we want to render the last gaussian that contributes and note
+            // that here idx > range.x so we don't underflow
+            idx -= 1;
+            break;
+        }
+        const float vis = alpha * T;
+        for (uint c = 0; c < channels; ++c) {
+            out_img[channels * pix_id + c] += colors[channels * g + c] * vis;
+        }
+        T = next_T;
+    }
+    final_Ts[pix_id] = T; // transmittance at last gaussian in this pixel
+    final_index[pix_id] =
+        (idx == range.y)
+            ? idx - 1
+            : idx; // index of in bin of last gaussian in this pixel
+    for (uint c = 0; c < channels; ++c) {
+        out_img[channels * pix_id + c] += T * background[c];
+    }
+}
+
+void sh_coeffs_to_color( // evaluates spherical harmonics up to `degree` (0..4) along viewdir into colors[0..CHANNELS)
+    const uint degree,
+    const float3 viewdir,
+    constant float *coeffs,
+    device float *colors
+) {
+    // Expects colors to be len CHANNELS
+    // and coeffs to be num_bases * CHANNELS
+    for (int c = 0; c < CHANNELS; ++c) {
+        colors[c] = SH_C0 * coeffs[c];
+    }
+    if (degree < 1) {
+        return;
+    }
+
+    float norm = sqrt(
+        viewdir.x * viewdir.x + viewdir.y * viewdir.y + viewdir.z * viewdir.z
+    );
+    float x = viewdir.x / norm;
+    float y = viewdir.y / norm;
+    float z = viewdir.z / norm;
+
+    float xx = x * x;
+    float xy = x * y;
+    float xz = x * z;
+    float yy = y * y;
+    float yz = y * z;
+    float zz = z * z;
+    // expects CHANNELS * num_bases coefficients
+    // supports up to num_bases = 25
+    for (int c = 0; c < CHANNELS; ++c) {
+        colors[c] += SH_C1 * (-y * coeffs[1 * CHANNELS + c] +
+                              z * coeffs[2 * CHANNELS + c] -
+                              x * coeffs[3 * CHANNELS + c]);
+        if (degree < 2) {
+            continue;
+        }
+        colors[c] +=
+            (SH_C2[0] * xy * coeffs[4 * CHANNELS + c] +
+             SH_C2[1] * yz * coeffs[5 * CHANNELS + c] +
+             SH_C2[2] * (2.f * zz - xx - yy) * coeffs[6 * CHANNELS + c] +
+             SH_C2[3] * xz * coeffs[7 * CHANNELS + c] +
+             SH_C2[4] * (xx - yy) * coeffs[8 * CHANNELS + c]);
+        if (degree < 3) {
+            continue;
+        }
+        colors[c] +=
+            (SH_C3[0] * y * (3.f * xx - yy) * coeffs[9 * CHANNELS + c] +
+             SH_C3[1] * xy * z * coeffs[10 * CHANNELS + c] +
+             SH_C3[2] * y * (4.f * zz - xx - yy) * coeffs[11 * CHANNELS + c] +
+             SH_C3[3] * z * (2.f * zz - 3.f * xx - 3.f * yy) *
+                 coeffs[12 * CHANNELS + c] +
+             SH_C3[4] * x * (4.f * zz - xx - yy) * coeffs[13 * CHANNELS + c] +
+             SH_C3[5] * z * (xx - yy) * coeffs[14 * CHANNELS + c] +
+             SH_C3[6] * x * (xx - 3.f * yy) * coeffs[15 * CHANNELS + c]);
+        if (degree < 4) {
+            continue;
+        }
+        colors[c] +=
+            (SH_C4[0] * xy * (xx - yy) * coeffs[16 * CHANNELS + c] +
+             SH_C4[1] * yz * (3.f * xx - yy) * coeffs[17 * CHANNELS + c] +
+             SH_C4[2] * xy * (7.f * zz - 1.f) * coeffs[18 * CHANNELS + c] +
+             SH_C4[3] * yz * (7.f * zz - 3.f) * coeffs[19 * CHANNELS + c] +
+             SH_C4[4] * (zz * (35.f * zz - 30.f) + 3.f) *
+                 coeffs[20 * CHANNELS + c] +
+             SH_C4[5] * xz * (7.f * zz - 3.f) * coeffs[21 * CHANNELS + c] +
+             SH_C4[6] * (xx - yy) * (7.f * zz - 1.f) *
+                 coeffs[22 * CHANNELS + c] +
+             SH_C4[7] * xz * (xx - 3.f * yy) * coeffs[23 * CHANNELS + c] +
+             SH_C4[8] * (xx * (xx - 3.f * yy) - yy * (3.f * xx - yy)) *
+                 coeffs[24 * CHANNELS + c]);
+    }
+}
+
+kernel void compute_sh_forward_kernel( // one thread per point: coeffs holds num_sh_bases(degree) * 3 floats per point, colors receives 3 floats per point
+    constant uint& num_points,
+    constant uint& degree,
+    constant uint& degrees_to_use,
+    constant float* viewdirs, // float3
+    constant float* coeffs,
+    device float* colors,
+    uint idx [[thread_position_in_grid]] // must be the grid-wide index; the threadgroup-local one repeats in every threadgroup
+) {
+    if (idx >= num_points) {
+        return;
+    }
+    const uint num_channels = 3;
+    uint num_bases = num_sh_bases(degree);
+    uint idx_sh = num_bases * num_channels * idx;
+    uint idx_col = num_channels * idx;
+
+    sh_coeffs_to_color(
+        degrees_to_use, read_packed_float3(viewdirs, idx), &(coeffs[idx_sh]), &(colors[idx_col])
+    );
+}
\ No newline at end of file
diff --git a/vendor/gsplat-metal/gsplat_metal.mm b/vendor/gsplat-metal/gsplat_metal.mm
index f417382..5fbec07 100644
--- a/vendor/gsplat-metal/gsplat_metal.mm
+++ b/vendor/gsplat-metal/gsplat_metal.mm
@@ -4,6 +4,24 @@
 #import
+struct MetalContext { // process-wide Metal state shared by the gsplat MPS entry points
+    id<MTLDevice> device;
+    id<MTLCommandQueue> queue;
+    dispatch_queue_t d_queue; // obtained from torch::mps::get_dispatch_queue(); not owned by this struct
+
+    id<MTLComputePipelineState> nd_rasterize_forward_kernel_cpso;
+    id<MTLComputePipelineState> nd_rasterize_backward_kernel_cpso;
+    id<MTLComputePipelineState> rasterize_forward_kernel_cpso;
+    id<MTLComputePipelineState> rasterize_backward_kernel_cpso;
+    id<MTLComputePipelineState> project_gaussians_forward_kernel_cpso;
+    id<MTLComputePipelineState> project_gaussians_backward_kernel_cpso;
+    id<MTLComputePipelineState> compute_sh_forward_kernel_cpso;
+    id<MTLComputePipelineState> compute_sh_backward_kernel_cpso;
+    id<MTLComputePipelineState> compute_cov2d_bounds_kernel_cpso;
+    id<MTLComputePipelineState> map_gaussian_to_intersects_kernel_cpso;
+    id<MTLComputePipelineState> get_tile_bin_edges_kernel_cpso;
+};
+
 // This function is used in both host and device code
 // TODO(achan): Do I need to make this
callable from the metal device?
 unsigned num_sh_bases(const unsigned degree) {
@@ -18,6 +36,160 @@ unsigned num_sh_bases(const unsigned degree) {
     return 25;
 }
 
+// This empty class lets us query for files relative to this file's bundle path using NSBundle bundleForClass hack
+@interface DummyClassForPathHack : NSObject
+@end
+@implementation DummyClassForPathHack
+@end
+
+// Defined further down in this file; declared here so the failure path below can reuse it.
+void free_gsplat_metal_context(MetalContext* ctx);
+
+// Creates the Metal context used by the gsplat MPS entry points: the default device, a
+// command queue, torch's MPS dispatch queue and one pipeline state per registered kernel.
+// The library is loaded from a default.metallib found in this binary's bundle when there is
+// one, otherwise it is compiled from the .metal source located next to this file.
+// Returns NULL on failure, after releasing whatever had been acquired up to that point.
+MetalContext* init_gsplat_metal_context() {
+    // Zero-initialise: pipeline states that are never loaded must stay nil so that
+    // free_gsplat_metal_context can message them safely.
+    MetalContext* ctx = (MetalContext*)calloc(1, sizeof(MetalContext));
+    if (ctx == NULL) {
+        return NULL;
+    }
+    // Retrieve the default Metal device
+    id<MTLDevice> device = MTLCreateSystemDefaultDevice();
+    if (device == nil) {
+        printf("%s: error: no Metal device available\n", __func__);
+        free(ctx);
+        return NULL;
+    }
+
+    // Configure context
+    ctx->device = device;
+    ctx->queue = [ctx->device newCommandQueue];
+    ctx->d_queue = torch::mps::get_dispatch_queue();
+
+    NSError *error = nil;
+
+    id<MTLLibrary> metal_library = nil;
+
+    // Single failure path: release whatever has been acquired so far and report NULL.
+    auto fail = [&]() -> MetalContext* {
+        [metal_library release];
+        free_gsplat_metal_context(ctx);
+        return nullptr;
+    };
+
+    NSBundle * bundle = [NSBundle bundleForClass:[DummyClassForPathHack class]];
+    NSString * path_lib = [bundle pathForResource:@"default" ofType:@"metallib"];
+
+    if (path_lib != nil) {
+        // pre-compiled library found
+        NSURL * libURL = [NSURL fileURLWithPath:path_lib];
+        printf("%s: loading '%s'\n", __func__, [path_lib UTF8String]);
+
+        metal_library = [ctx->device newLibraryWithURL:libURL error:&error];
+        if (metal_library == nil) { // test the result, not `error`
+            printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
+            return fail();
+        }
+        printf("%s: loaded '%s', functions: %s\n", __func__, [path_lib UTF8String], [[[metal_library functionNames] componentsJoinedByString:@", "] UTF8String]);
+    } else {
+        printf("%s: default.metallib not found, loading from source\n", __func__);
+
+        // The shader source added by this change is gsplat_metal.metal, in the same directory as this file.
+        NSString * source_path = [[@ __FILE__ stringByDeletingLastPathComponent] stringByAppendingPathComponent:@"gsplat_metal.metal"];
+        printf("%s: loading '%s'\n", __func__, [source_path UTF8String]);
+
+        NSString * src = [NSString stringWithContentsOfFile:source_path encoding:NSUTF8StringEncoding error:&error];
+        if (src == nil) {
+            printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
+            return fail();
+        }
+
+        @autoreleasepool {
+            // dictionary of preprocessor macros
+            NSMutableDictionary * prep = [NSMutableDictionary dictionary];
+
+            MTLCompileOptions* options = [MTLCompileOptions new];
+            options.preprocessorMacros = prep;
+
+            metal_library = [ctx->device newLibraryWithSource:src options:options error:&error];
+            [options release];
+            if (metal_library == nil) { // `error` can also carry compiler warnings for a library that was built
+                printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
+                return fail();
+            }
+        }
+    }
+
+#define GSPLAT_METAL_ADD_KERNEL(NAME) \
+    { \
+        id<MTLFunction> metal_function = [metal_library newFunctionWithName:@#NAME]; \
+        if (metal_function == nil) { \
+            printf("%s: error: kernel %s not found in library\n", __func__, #NAME); \
+            return fail(); \
+        } \
+        printf("%s: load function %s with label: %s\n", __func__, #NAME, [[metal_function label] UTF8String]); \
+        ctx->NAME ## _cpso = [ctx->device newComputePipelineStateWithFunction:metal_function error:&error]; \
+        [metal_function release]; \
+        if (ctx->NAME ## _cpso == nil) { \
+            printf("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
+            return fail(); \
+        } \
+    }
+
+    // GSPLAT_METAL_ADD_KERNEL(nd_rasterize_forward_kernel);
+    // GSPLAT_METAL_ADD_KERNEL(nd_rasterize_backward_kernel);
+    GSPLAT_METAL_ADD_KERNEL(rasterize_forward_kernel);
+    // GSPLAT_METAL_ADD_KERNEL(rasterize_backward_kernel);
+    GSPLAT_METAL_ADD_KERNEL(project_gaussians_forward_kernel);
+    // GSPLAT_METAL_ADD_KERNEL(project_gaussians_backward_kernel);
+    GSPLAT_METAL_ADD_KERNEL(compute_sh_forward_kernel);
+    // GSPLAT_METAL_ADD_KERNEL(compute_sh_backward_kernel);
+    // GSPLAT_METAL_ADD_KERNEL(compute_cov2d_bounds_kernel);
+    // GSPLAT_METAL_ADD_KERNEL(map_gaussian_to_intersects_kernel);
+    // GSPLAT_METAL_ADD_KERNEL(get_tile_bin_edges_kernel);
+
+    [metal_library release];
+
+    return ctx;
+}
+
+// TODO(achan): Where do I call this?
+void free_gsplat_metal_context(MetalContext* ctx) { // releases every pipeline state, the queue and the device, then frees ctx (ctx must be non-NULL)
+    [ctx->nd_rasterize_forward_kernel_cpso release]; // NOTE(review): slots that init never filled must be nil for these releases to be safe
+    [ctx->nd_rasterize_backward_kernel_cpso release];
+    [ctx->rasterize_forward_kernel_cpso release];
+    [ctx->rasterize_backward_kernel_cpso release];
+    [ctx->project_gaussians_forward_kernel_cpso release];
+    [ctx->project_gaussians_backward_kernel_cpso release];
+    [ctx->compute_sh_forward_kernel_cpso release];
+    [ctx->compute_sh_backward_kernel_cpso release];
+    [ctx->compute_cov2d_bounds_kernel_cpso release];
+    [ctx->map_gaussian_to_intersects_kernel_cpso release];
+    [ctx->get_tile_bin_edges_kernel_cpso release];
+
+    [ctx->queue release];
+    [ctx->device release];
+    // We do not need to release `d_queue` here as that is managed by torch.
+
+    free(ctx);
+}
+
+MetalContext* get_global_context() { // lazily builds the shared context on first use; returns NULL when initialisation fails
+    static MetalContext* ctx = NULL;
+    if (ctx == NULL) { // a failed init leaves ctx NULL, so the next call tries again
+        ctx = init_gsplat_metal_context();
+    }
+    return ctx;
+}
+
+// Helper function to retrieve the `MTLBuffer` from a `torch::Tensor`.
+id<MTLBuffer> getMTLBufferStorage(const torch::Tensor& tensor) {
+    return __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
+}
+
 std::tuple<
     torch::Tensor, // output conics
     torch::Tensor> // output radii
@@ -45,6 +187,43 @@ unsigned num_sh_bases(const unsigned degree) {
         AT_ERROR("coeffs must have dimensions (N, D, 3)");
     }
     torch::Tensor colors = torch::empty({num_points, 3}, coeffs.options());
+
+    // Get a reference to the command buffer for the MPS stream
+    id<MTLCommandBuffer> command_buffer = torch::mps::get_command_buffer();
+    TORCH_CHECK(command_buffer, "Failed to retrieve command buffer reference");
+
+    // Dispatch the kernel
+    MetalContext* ctx = get_global_context();
+    TORCH_CHECK(ctx, "Failed to initialize the gsplat Metal context");
+    dispatch_sync(ctx->d_queue, ^(){
+        // Start a compute pass
+        id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];
+        TORCH_CHECK(encoder, "Failed to create compute command encoder");
+
+        // Encode the pipeline state object
+        id<MTLComputePipelineState> cpso = ctx->compute_sh_forward_kernel_cpso;
+        [encoder setComputePipelineState:cpso];
+
+        // Set the tensor buffers
+        [encoder setBytes:&num_points length:sizeof(num_points) atIndex:0];
+        [encoder setBytes:&degree length:sizeof(degree) atIndex:1];
+        [encoder setBuffer:getMTLBufferStorage(viewdirs) offset:viewdirs.storage_offset() * viewdirs.element_size() atIndex:2]; // NOTE(review): compute_sh_forward_kernel declares degrees_to_use between degree and viewdirs; nothing is bound for it here, so these three buffers look one slot early — TODO confirm and bind it
+        [encoder setBuffer:getMTLBufferStorage(coeffs) offset:coeffs.storage_offset() * coeffs.element_size() atIndex:3];
+        [encoder setBuffer:getMTLBufferStorage(colors) offset:colors.storage_offset() * colors.element_size() atIndex:4];
+
+        // Set the grid threadgroup sizes
+        MTLSize grid_size = MTLSizeMake(num_points, 1, 1);
+
+        NSUInteger num_threads_per_group = MIN(cpso.maxTotalThreadsPerThreadgroup, (NSUInteger)num_points);
+        MTLSize thread_group_size = MTLSizeMake(num_threads_per_group, 1, 1);
+
+        // Dispatch the compute command
+        [encoder dispatchThreads:grid_size threadsPerThreadgroup:thread_group_size];
+        [encoder endEncoding];
+
+        // Commit the work
+        torch::mps::synchronize();
+    });
     return colors;
 }
 
@@ -106,6 +284,58 @@ unsigned num_sh_bases(const unsigned degree) {
         torch::zeros({num_points, 3}, means3d.options().dtype(torch::kFloat32));
     torch::Tensor num_tiles_hit_d =
         torch::zeros({num_points}, means3d.options().dtype(torch::kInt32));
+
+    // Get a reference to the command buffer for the MPS stream
+    id<MTLCommandBuffer> command_buffer = torch::mps::get_command_buffer();
+    TORCH_CHECK(command_buffer, "Failed to retrieve command buffer reference");
+
+    // Dispatch the kernel
+    MetalContext* ctx = get_global_context();
+    TORCH_CHECK(ctx, "Failed to initialize the gsplat Metal context");
+    dispatch_sync(ctx->d_queue, ^(){
+        // Start a compute pass
+        id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];
+        TORCH_CHECK(encoder, "Failed to create compute command encoder");
+
+        float intrins[4] = {fx, fy, cx, cy};
+        int32_t img_size[3] = {(int32_t)img_width, (int32_t)img_height, 1};
+        int32_t tile_bounds_dim3[3] = {std::get<0>(tile_bounds), std::get<1>(tile_bounds), std::get<2>(tile_bounds)};
+
+        // Encode the pipeline state object
+        id<MTLComputePipelineState> cpso = ctx->project_gaussians_forward_kernel_cpso;
+        [encoder setComputePipelineState:cpso];
+
+        // Set the tensor buffers (indices follow the parameter order of project_gaussians_forward_kernel)
+        [encoder setBytes:&num_points length:sizeof(num_points) atIndex:0];
+        [encoder setBuffer:getMTLBufferStorage(means3d) offset:means3d.storage_offset() * means3d.element_size() atIndex:1];
+        [encoder setBuffer:getMTLBufferStorage(scales) offset:scales.storage_offset() * scales.element_size() atIndex:2];
+        [encoder setBytes:&glob_scale length:sizeof(glob_scale) atIndex:3];
+        [encoder setBuffer:getMTLBufferStorage(quats) offset:quats.storage_offset() * quats.element_size() atIndex:4];
+        [encoder setBuffer:getMTLBufferStorage(viewmat) offset:viewmat.storage_offset() * viewmat.element_size() atIndex:5];
+        [encoder setBuffer:getMTLBufferStorage(projmat) offset:projmat.storage_offset() * projmat.element_size() atIndex:6];
+        [encoder setBytes:intrins length:sizeof(intrins) atIndex:7];
+        [encoder setBytes:img_size length:sizeof(img_size) atIndex:8];
+        [encoder setBytes:tile_bounds_dim3 length:sizeof(tile_bounds_dim3) atIndex:9];
+        [encoder setBytes:&clip_thresh length:sizeof(clip_thresh) atIndex:10];
+        [encoder setBuffer:getMTLBufferStorage(cov3d_d) offset:cov3d_d.storage_offset() * cov3d_d.element_size() atIndex:11];
+        [encoder setBuffer:getMTLBufferStorage(xys_d) offset:xys_d.storage_offset() * xys_d.element_size() atIndex:12];
+        [encoder setBuffer:getMTLBufferStorage(depths_d) offset:depths_d.storage_offset() * depths_d.element_size() atIndex:13];
+        [encoder setBuffer:getMTLBufferStorage(radii_d) offset:radii_d.storage_offset() * radii_d.element_size() atIndex:14];
+        [encoder setBuffer:getMTLBufferStorage(conics_d) offset:conics_d.storage_offset() * conics_d.element_size() atIndex:15];
+        [encoder setBuffer:getMTLBufferStorage(num_tiles_hit_d) offset:num_tiles_hit_d.storage_offset() * num_tiles_hit_d.element_size() atIndex:16];
+
+        // Set the grid threadgroup sizes
+        MTLSize grid_size = MTLSizeMake(num_points, 1, 1);
+        NSUInteger num_threads_per_group = MIN(cpso.maxTotalThreadsPerThreadgroup, (NSUInteger)num_points);
+        MTLSize thread_group_size = MTLSizeMake(num_threads_per_group, 1, 1);
+
+        // Dispatch the compute command
+        [encoder dispatchThreads:grid_size threadsPerThreadgroup:thread_group_size];
+        [encoder endEncoding];
+
+        // Commit the work
+        torch::mps::synchronize();
+    });
 
     return std::make_tuple(
         cov3d_d, xys_d, depths_d, radii_d, conics_d, num_tiles_hit_d
@@ -227,6 +456,55 @@ unsigned num_sh_bases(const unsigned degree) {
         {img_height, img_width}, xys.options().dtype(torch::kInt32)
     );
 
+    // Get a reference to the command buffer for the MPS stream
+    id<MTLCommandBuffer> command_buffer = torch::mps::get_command_buffer();
+    TORCH_CHECK(command_buffer, "Failed to retrieve command buffer reference");
+
+    // Dispatch the kernel
+    MetalContext* ctx = get_global_context();
+    TORCH_CHECK(ctx, "Failed to initialize the gsplat Metal context");
+    dispatch_sync(ctx->d_queue, ^(){
+        // Start a compute pass
+        id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];
+        TORCH_CHECK(encoder, "Failed to create compute command encoder");
+
+        // Encode the pipeline state object
+        id<MTLComputePipelineState> cpso = ctx->rasterize_forward_kernel_cpso;
+        [encoder setComputePipelineState:cpso];
+
+        int32_t tile_bounds_dim3[3] = {std::get<0>(tile_bounds), std::get<1>(tile_bounds), std::get<2>(tile_bounds)};
+        int32_t img_size_dim3[3] = {std::get<0>(img_size), std::get<1>(img_size), std::get<2>(img_size)};
+        int32_t block_size_dim3[3] = {std::get<0>(block), std::get<1>(block), std::get<2>(block)};
+
+        // Set the tensor buffers (indices follow the parameter order of rasterize_forward_kernel)
+        [encoder setBytes:tile_bounds_dim3 length:sizeof(tile_bounds_dim3) atIndex:0];
+        [encoder setBytes:img_size_dim3 length:sizeof(img_size_dim3) atIndex:1];
+        [encoder setBytes:&channels length:sizeof(channels) atIndex:2]; // the kernel's third parameter; it was never bound, which shifted every later slot
+        [encoder setBuffer:getMTLBufferStorage(gaussian_ids_sorted) offset:gaussian_ids_sorted.storage_offset() * gaussian_ids_sorted.element_size() atIndex:3];
+        [encoder setBuffer:getMTLBufferStorage(tile_bins) offset:tile_bins.storage_offset() * tile_bins.element_size() atIndex:4];
+        [encoder setBuffer:getMTLBufferStorage(xys) offset:xys.storage_offset() * xys.element_size() atIndex:5];
+        [encoder setBuffer:getMTLBufferStorage(conics) offset:conics.storage_offset() * conics.element_size() atIndex:6];
+        [encoder setBuffer:getMTLBufferStorage(colors) offset:colors.storage_offset() * colors.element_size() atIndex:7];
+        [encoder setBuffer:getMTLBufferStorage(opacities) offset:opacities.storage_offset() * opacities.element_size() atIndex:8];
+        [encoder setBuffer:getMTLBufferStorage(final_Ts) offset:final_Ts.storage_offset() * final_Ts.element_size() atIndex:9];
+        [encoder setBuffer:getMTLBufferStorage(final_idx) offset:final_idx.storage_offset() * final_idx.element_size() atIndex:10];
+        [encoder setBuffer:getMTLBufferStorage(out_img) offset:out_img.storage_offset() * out_img.element_size() atIndex:11];
+        [encoder setBuffer:getMTLBufferStorage(background) offset:background.storage_offset() * background.element_size() atIndex:12];
+        [encoder setBytes:block_size_dim3 length:2*sizeof(int32_t) atIndex:13];
+
+        // Set the grid threadgroup sizes: x spans image columns and y spans rows, matching the kernel's bounds test (j against img_size.x, i against img_size.y)
+        MTLSize grid_size = MTLSizeMake(img_width, img_height, 1);
+        // TODO(achan): we should be able to remove the 3rd dimension of `block` as it is always set to 1
+        MTLSize thread_group_size = MTLSizeMake(block_size_dim3[0], block_size_dim3[1], block_size_dim3[2]);
+
+        // Dispatch the compute command
+        [encoder dispatchThreads:grid_size threadsPerThreadgroup:thread_group_size];
+        [encoder endEncoding];
+
+        // Commit the work
+        torch::mps::synchronize();
+    });
+
     return std::make_tuple(out_img, final_Ts, final_idx);
 }
 
From 574b66285f3cc97d73f1be604d7e2e49ac9fca82 Mon Sep 17 00:00:00 2001
From: Andrew Chan
Date: Wed, 10 Apr 2024 17:21:47 -0700
Subject: [PATCH 05/19] ..
--- vendor/gsplat-metal/gsplat_metal.metal | 298 ++++++++++++++++++++++++- vendor/gsplat-metal/gsplat_metal.mm | 165 ++++++++++++-- 2 files changed, 434 insertions(+), 29 deletions(-) diff --git a/vendor/gsplat-metal/gsplat_metal.metal b/vendor/gsplat-metal/gsplat_metal.metal index 6045947..2c1d904 100644 --- a/vendor/gsplat-metal/gsplat_metal.metal +++ b/vendor/gsplat-metal/gsplat_metal.metal @@ -273,7 +273,7 @@ inline bool compute_cov2d_bounds( } inline float2 project_pix( - constant float *mat, const float3 p, const int3 img_size, const float2 pp + constant float *mat, const float3 p, const int2 img_size, const float2 pp ) { // ROW MAJOR mat float4 p_hom = transform_4x4(mat, p); @@ -300,6 +300,14 @@ inline void write_packed_int2(device int* arr, int idx, int2 val) { arr[2*idx+1] = val.y; } +inline void write_packed_int2x(device int* arr, int idx, int x) { + arr[2*idx] = x; +} + +inline void write_packed_int2y(device int* arr, int idx, int y) { + arr[2*idx+1] = y; +} + inline float2 read_packed_float2(constant float* arr, int idx) { return float2(arr[2*idx], arr[2*idx+1]); } @@ -351,8 +359,7 @@ kernel void project_gaussians_forward_kernel( constant float* viewmat, constant float* projmat, constant float4& intrins, - constant int3& img_size, - constant int3& tile_bounds, + constant int2& img_size, constant float& clip_thresh, device float* covs3d, device float* xys, // float2 @@ -360,8 +367,10 @@ kernel void project_gaussians_forward_kernel( device int* radii, device float* conics, // float3 device int32_t* num_tiles_hit, - uint idx [[thread_position_in_grid]] + uint3 tile_bounds [[threadgroups_per_grid]], + uint3 gp [[thread_position_in_grid]] ) { + uint idx = gp.x; if (idx >= num_points) { return; } @@ -402,7 +411,7 @@ kernel void project_gaussians_forward_kernel( // compute the projected mean float2 center = project_pix(projmat, p_world, img_size, {cx, cy}); uint2 tile_min, tile_max; - get_tile_bbox(center, radius, tile_bounds, tile_min, tile_max); + 
get_tile_bbox(center, radius, (int3)tile_bounds, tile_min, tile_max); int32_t tile_area = (tile_max.x - tile_min.x) * (tile_max.y - tile_min.y); if (tile_area <= 0) { return; @@ -414,8 +423,8 @@ kernel void project_gaussians_forward_kernel( write_packed_float2(xys, idx, center); } +// TODO(achan): this is actually the nd_rasterize_forward_kernel kernel void rasterize_forward_kernel( - constant int3& tile_bounds, constant int3& img_size, constant uint& channels, constant int32_t* gaussian_ids_sorted, @@ -429,6 +438,7 @@ kernel void rasterize_forward_kernel( device float* out_img, constant float* background, constant uint2& blockDim, + uint2 tile_bounds [[threadgroups_per_grid]], uint2 blockIdx [[threadgroup_position_in_grid]], uint2 threadIdx [[thread_position_in_threadgroup]] ) { @@ -593,4 +603,280 @@ kernel void compute_sh_forward_kernel( sh_coeffs_to_color( degrees_to_use, read_packed_float3(viewdirs, idx), &(coeffs[idx_sh]), &(colors[idx_col]) ); +} + +// kernel to map each intersection from tile ID and depth to a gaussian +// writes output to isect_ids and gaussian_ids +kernel void map_gaussian_to_intersects_kernel( + constant int& num_points, + constant float* xys, // float2 + constant float* depths, + constant int* radii, + constant int32_t* cum_tiles_hit, + device int64_t* isect_ids, + device int32_t* gaussian_ids, + uint3 tile_bounds [[threadgroups_per_grid]], + uint3 gp [[thread_position_in_grid]] +) { + uint idx = gp.x; + if (idx >= num_points) + return; + if (radii[idx] <= 0) + return; + // get the tile bbox for gaussian + uint2 tile_min, tile_max; + float2 center = read_packed_float2(xys, idx); + get_tile_bbox(center, radii[idx], (int3)tile_bounds, tile_min, tile_max); + // printf("point %d, %d radius, min %d %d, max %d %d\n", idx, radii[idx], + // tile_min.x, tile_min.y, tile_max.x, tile_max.y); + + // update the intersection info for all tiles this gaussian hits + int32_t cur_idx = (idx == 0) ? 
0 : cum_tiles_hit[idx - 1]; + // printf("point %d starting at %d\n", idx, cur_idx); + int64_t depth_id = (int64_t) * (constant int32_t *)&(depths[idx]); + for (int i = tile_min.y; i < tile_max.y; ++i) { + for (int j = tile_min.x; j < tile_max.x; ++j) { + // isect_id is tile ID and depth as int32 + int64_t tile_id = i * tile_bounds.x + j; // tile within image + isect_ids[cur_idx] = (tile_id << 32) | depth_id; // tile | depth id + gaussian_ids[cur_idx] = idx; // 3D gaussian id + ++cur_idx; // handles gaussians that hit more than one tile + } + } + // printf("point %d ending at %d\n", idx, cur_idx); +} + +// kernel to map sorted intersection IDs to tile bins +// expect that intersection IDs are sorted by increasing tile ID +// i.e. intersections of a tile are in contiguous chunks +kernel void get_tile_bin_edges_kernel( + constant int& num_intersects, + constant int64_t* isect_ids_sorted, + device int* tile_bins, // int2 + uint idx [[thread_position_in_grid]] +) { + if (idx >= num_intersects) + return; + // save the indices where the tile_id changes + int32_t cur_tile_idx = (int32_t)(isect_ids_sorted[idx] >> 32); + if (idx == 0 || idx == num_intersects - 1) { + if (idx == 0) + write_packed_int2x(tile_bins, cur_tile_idx, 0); + if (idx == num_intersects - 1) + write_packed_int2y(tile_bins, cur_tile_idx, num_intersects); + return; + } + int32_t prev_tile_idx = (int32_t)(isect_ids_sorted[idx - 1] >> 32); + if (prev_tile_idx != cur_tile_idx) { + write_packed_int2y(tile_bins, prev_tile_idx, idx); + write_packed_int2x(tile_bins, cur_tile_idx, idx); + return; + } +} + +float block_reduce_sum(float val, uint tr, threadgroup float* shared) { + if (tr < BLOCK_SIZE) { + shared[tr] = val; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (uint s = 1; s < BLOCK_SIZE; s *= 2) { + if (tr % (2 * s) == 0 && tr + s < BLOCK_SIZE) { + shared[tr] += shared[tr + s]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + return shared[0]; +} + +float 
block_reduce_max(float val, uint tr, threadgroup float* shared) { + if (tr < BLOCK_SIZE) { + shared[tr] = val; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (uint s = 1; s < BLOCK_SIZE; s *= 2) { + if (tr % (2 * s) == 0 && tr + s < BLOCK_SIZE) { + shared[tr] = max(shared[tr + s], shared[tr]); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + return shared[0]; +} + +kernel void rasterize_backward_kernel( + constant uint2& img_size, + constant int32_t* gaussian_ids_sorted, + constant int* tile_bins, // int2 + constant float* xys, // float2 + constant float* conics, // float3 + constant float* rgbs, // float3 + constant float* opacities, + constant float* background, // single float3 + constant float* final_Ts, + constant int* final_index, + constant float* v_output, // float3 + constant float* v_output_alpha, + device atomic_float* v_xy, // float2 + device atomic_float* v_conic, // float3 + device atomic_float* v_rgb, // float3 + device atomic_float* v_opacity, + device int32_t* debug, + uint3 tile_bounds [[threadgroups_per_grid]], + uint3 gp [[thread_position_in_grid]], + uint3 block_index [[threadgroup_position_in_grid]], + uint tr [[thread_index_in_threadgroup]] +) { + int32_t tile_id = + block_index.y * tile_bounds.x + block_index.x; + uint i = gp.y; + uint j = gp.x; + + const float px = (float)j; + const float py = (float)i; + // clamp this value to the last pixel + const int32_t pix_id = min((int32_t)(i * img_size.x + j), (int32_t)(img_size.x * img_size.y - 1)); + + // keep not rasterizing threads around for reading data + const bool inside = (i < img_size.y && j < img_size.x); + + // this is the T AFTER the last gaussian in this pixel + float T_final = final_Ts[pix_id]; + float T = T_final; + // the contribution from gaussians behind the current one + float3 buffer = {0.f, 0.f, 0.f}; + // index of last gaussian to contribute to this pixel + const int bin_final = inside? 
final_index[pix_id] : 0; + + // have all threads in tile process the same gaussians in batches + // first collect gaussians between range.x and range.y in batches + // which gaussians to look through in this tile + const int2 range = read_packed_int2(tile_bins, tile_id); + const int num_batches = (range.y - range.x + BLOCK_SIZE - 1) / BLOCK_SIZE; + + threadgroup int32_t id_batch[BLOCK_SIZE]; + threadgroup float3 xy_opacity_batch[BLOCK_SIZE]; + threadgroup float3 conic_batch[BLOCK_SIZE]; + threadgroup float3 rgbs_batch[BLOCK_SIZE]; + + // df/d_out for this pixel + const float3 v_out = read_packed_float3(v_output, pix_id); + const float v_out_alpha = v_output_alpha[pix_id]; + + // collect and process batches of gaussians + // each thread loads one gaussian at a time before rasterizing + threadgroup float shared[BLOCK_SIZE]; + // TODO(achan): convert `block_reduce_max` to use SIMD groups + const int warp_bin_final = block_reduce_max(bin_final, tr, shared); + for (int b = 0; b < num_batches; ++b) { + // resync all threads before writing next batch of shared mem + threadgroup_barrier(mem_flags::mem_threadgroup); + + // each thread fetch 1 gaussian from back to front + // 0 index will be furthest back in batch + // index of gaussian to load + // batch end is the index of the last gaussian in the batch + const int batch_end = range.y - 1 - BLOCK_SIZE * b; + int batch_size = min(BLOCK_SIZE, batch_end + 1 - range.x); + const int idx = batch_end - tr; + if (idx >= range.x) { + int32_t g_id = gaussian_ids_sorted[idx]; + id_batch[tr] = g_id; + const float2 xy = read_packed_float2(xys, g_id); + const float opac = opacities[g_id]; + xy_opacity_batch[tr] = {xy.x, xy.y, opac}; + conic_batch[tr] = read_packed_float3(conics, g_id); + rgbs_batch[tr] = read_packed_float3(rgbs, g_id); + } + // wait for other threads to collect the gaussians in batch + threadgroup_barrier(mem_flags::mem_threadgroup); + + // process gaussians in the current batch for this pixel + // 0 index is the 
furthest back gaussian in the batch + for (int t = max(0,batch_end - warp_bin_final); t < batch_size; ++t) { + int valid = inside; + if (batch_end - t > bin_final) { + valid = 0; + } + float alpha; + float opac; + float2 delta; + float3 conic; + float vis; + if(valid){ + conic = conic_batch[t]; + float3 xy_opac = xy_opacity_batch[t]; + opac = xy_opac.z; + delta = {xy_opac.x - px, xy_opac.y - py}; + float sigma = 0.5f * (conic.x * delta.x * delta.x + + conic.z * delta.y * delta.y) + + conic.y * delta.x * delta.y; + vis = exp(-sigma); + alpha = min(0.99f, opac * vis); + if (sigma < 0.f || alpha < 1.f / 255.f) { + valid = 0; + } + } + // TODO(achan): if all threads are inactive in this warp, skip this loop iter here + + float3 v_rgb_local = {0.f, 0.f, 0.f}; + float3 v_conic_local = {0.f, 0.f, 0.f}; + float2 v_xy_local = {0.f, 0.f}; + float v_opacity_local = 0.f; + //initialize everything to 0, only set if the lane is valid + if(valid){ + // compute the current T for this gaussian + float ra = 1.f / (1.f - alpha); + T *= ra; + // update v_rgb for this gaussian + const float fac = alpha * T; + float v_alpha = 0.f; + v_rgb_local = {fac * v_out.x, fac * v_out.y, fac * v_out.z}; + + const float3 rgb = rgbs_batch[t]; + // contribution from this pixel + v_alpha += (rgb.x * T - buffer.x * ra) * v_out.x; + v_alpha += (rgb.y * T - buffer.y * ra) * v_out.y; + v_alpha += (rgb.z * T - buffer.z * ra) * v_out.z; + + v_alpha += T_final * ra * v_out_alpha; + // contribution from background pixel + v_alpha += -T_final * ra * background[0] * v_out.x; + v_alpha += -T_final * ra * background[1] * v_out.y; + v_alpha += -T_final * ra * background[2] * v_out.z; + // update the running sum + buffer.x += rgb.x * fac; + buffer.y += rgb.y * fac; + buffer.z += rgb.z * fac; + + const float v_sigma = -opac * vis * v_alpha; + v_conic_local = {0.5f * v_sigma * delta.x * delta.x, + 0.5f * v_sigma * delta.x * delta.y, + 0.5f * v_sigma * delta.y * delta.y}; + v_xy_local = {v_sigma * (conic.x * delta.x + 
conic.y * delta.y), + v_sigma * (conic.y * delta.x + conic.z * delta.y)}; + v_opacity_local = vis * v_alpha; + } + + // TODO(achan): Use SIMD groups to reduce atomic contention similarly to warps + int32_t g = id_batch[t]; + + atomic_fetch_add_explicit(v_rgb + 3*g + 0, v_rgb_local.x, memory_order_relaxed); + atomic_fetch_add_explicit(v_rgb + 3*g + 1, v_rgb_local.y, memory_order_relaxed); + atomic_fetch_add_explicit(v_rgb + 3*g + 2, v_rgb_local.z, memory_order_relaxed); + + atomic_fetch_add_explicit(v_conic + 3*g + 0, v_conic_local.x, memory_order_relaxed); + atomic_fetch_add_explicit(v_conic + 3*g + 1, v_conic_local.y, memory_order_relaxed); + atomic_fetch_add_explicit(v_conic + 3*g + 2, v_conic_local.z, memory_order_relaxed); + + atomic_fetch_add_explicit(v_xy + 2*g + 0, v_xy_local.x, memory_order_relaxed); + atomic_fetch_add_explicit(v_xy + 2*g + 1, v_xy_local.y, memory_order_relaxed); + + atomic_fetch_add_explicit(v_opacity + g, v_opacity_local, memory_order_relaxed); + } + } } \ No newline at end of file diff --git a/vendor/gsplat-metal/gsplat_metal.mm b/vendor/gsplat-metal/gsplat_metal.mm index 5fbec07..777ccb5 100644 --- a/vendor/gsplat-metal/gsplat_metal.mm +++ b/vendor/gsplat-metal/gsplat_metal.mm @@ -1,4 +1,5 @@ #import "bindings.h" +#import "../gsplat/config.h" #import @@ -72,7 +73,7 @@ @implementation DummyClassForPathHack } else { printf("%s: default.metallib not found, loading from source\n", __func__); - NSString * source_path = [[@ __FILE__ stringByDeletingLastPathComponent] stringByAppendingPathComponent:@"ggml-metal.metal"]; + NSString * source_path = [[@ __FILE__ stringByDeletingLastPathComponent] stringByAppendingPathComponent:@"gsplat_metal.metal"]; printf("%s: loading '%s'\n", __func__, [source_path UTF8String]); NSString * src = [NSString stringWithContentsOfFile:source_path encoding:NSUTF8StringEncoding error:&error]; @@ -112,14 +113,14 @@ @implementation DummyClassForPathHack // GSPLAT_METAL_ADD_KERNEL(nd_rasterize_forward_kernel); // 
GSPLAT_METAL_ADD_KERNEL(nd_rasterize_backward_kernel); GSPLAT_METAL_ADD_KERNEL(rasterize_forward_kernel); - // GSPLAT_METAL_ADD_KERNEL(rasterize_backward_kernel); + GSPLAT_METAL_ADD_KERNEL(rasterize_backward_kernel); GSPLAT_METAL_ADD_KERNEL(project_gaussians_forward_kernel); // GSPLAT_METAL_ADD_KERNEL(project_gaussians_backward_kernel); GSPLAT_METAL_ADD_KERNEL(compute_sh_forward_kernel); // GSPLAT_METAL_ADD_KERNEL(compute_sh_backward_kernel); // GSPLAT_METAL_ADD_KERNEL(compute_cov2d_bounds_kernel); - // GSPLAT_METAL_ADD_KERNEL(map_gaussian_to_intersects_kernel); - // GSPLAT_METAL_ADD_KERNEL(get_tile_bin_edges_kernel); + GSPLAT_METAL_ADD_KERNEL(map_gaussian_to_intersects_kernel); + GSPLAT_METAL_ADD_KERNEL(get_tile_bin_edges_kernel); [metal_library release]; @@ -206,9 +207,10 @@ void free_gsplat_metal_context(MetalContext* ctx) { // Set the tensor buffers [encoder setBytes:&num_points length:sizeof(num_points) atIndex:0]; [encoder setBytes:&degree length:sizeof(degree) atIndex:1]; - [encoder setBuffer:getMTLBufferStorage(viewdirs) offset:viewdirs.storage_offset() * viewdirs.element_size() atIndex:2]; - [encoder setBuffer:getMTLBufferStorage(coeffs) offset:coeffs.storage_offset() * coeffs.element_size() atIndex:3]; - [encoder setBuffer:getMTLBufferStorage(colors) offset:colors.storage_offset() * colors.element_size() atIndex:4]; + [encoder setBytes:&degrees_to_use length:sizeof(degrees_to_use) atIndex:2]; + [encoder setBuffer:getMTLBufferStorage(viewdirs) offset:viewdirs.storage_offset() * viewdirs.element_size() atIndex:3]; + [encoder setBuffer:getMTLBufferStorage(coeffs) offset:coeffs.storage_offset() * coeffs.element_size() atIndex:4]; + [encoder setBuffer:getMTLBufferStorage(colors) offset:colors.storage_offset() * colors.element_size() atIndex:5]; // Set the grid threadgroup sizes MTLSize grid_size = MTLSizeMake(num_points, 1, 1); @@ -297,8 +299,7 @@ void free_gsplat_metal_context(MetalContext* ctx) { TORCH_CHECK(encoder, "Failed to create compute command encoder");
float intrins[4] = {fx, fy, cx, cy}; - int32_t img_size[3] = {(int32_t)img_width, (int32_t)img_height, 1}; - int32_t tile_bounds_dim3[3] = {std::get<0>(tile_bounds), std::get<1>(tile_bounds), std::get<2>(tile_bounds)}; + int32_t img_size[2] = {(int32_t)img_width, (int32_t)img_height}; // Encode the pipeline state object id cpso = ctx->project_gaussians_forward_kernel_cpso; @@ -314,14 +315,13 @@ void free_gsplat_metal_context(MetalContext* ctx) { [encoder setBuffer:getMTLBufferStorage(projmat) offset:projmat.storage_offset() * projmat.element_size() atIndex:6]; [encoder setBytes:intrins length:sizeof(intrins) atIndex:7]; [encoder setBytes:img_size length:sizeof(img_size) atIndex:8]; - [encoder setBytes:tile_bounds_dim3 length:sizeof(tile_bounds_dim3) atIndex:9]; - [encoder setBytes:&clip_thresh length:sizeof(clip_thresh) atIndex:10]; - [encoder setBuffer:getMTLBufferStorage(cov3d_d) offset:cov3d_d.storage_offset() * cov3d_d.element_size() atIndex:11]; - [encoder setBuffer:getMTLBufferStorage(xys_d) offset:xys_d.storage_offset() * xys_d.element_size() atIndex:12]; - [encoder setBuffer:getMTLBufferStorage(depths_d) offset:depths_d.storage_offset() * depths_d.element_size() atIndex:13]; - [encoder setBuffer:getMTLBufferStorage(radii_d) offset:radii_d.storage_offset() * radii_d.element_size() atIndex:14]; - [encoder setBuffer:getMTLBufferStorage(conics_d) offset:conics_d.storage_offset() * conics_d.element_size() atIndex:15]; - [encoder setBuffer:getMTLBufferStorage(num_tiles_hit_d) offset:num_tiles_hit_d.storage_offset() * num_tiles_hit_d.element_size() atIndex:16]; + [encoder setBytes:&clip_thresh length:sizeof(clip_thresh) atIndex:9]; + [encoder setBuffer:getMTLBufferStorage(cov3d_d) offset:cov3d_d.storage_offset() * cov3d_d.element_size() atIndex:10]; + [encoder setBuffer:getMTLBufferStorage(xys_d) offset:xys_d.storage_offset() * xys_d.element_size() atIndex:11]; + [encoder setBuffer:getMTLBufferStorage(depths_d) offset:depths_d.storage_offset() * 
depths_d.element_size() atIndex:12]; + [encoder setBuffer:getMTLBufferStorage(radii_d) offset:radii_d.storage_offset() * radii_d.element_size() atIndex:13]; + [encoder setBuffer:getMTLBufferStorage(conics_d) offset:conics_d.storage_offset() * conics_d.element_size() atIndex:14]; + [encoder setBuffer:getMTLBufferStorage(num_tiles_hit_d) offset:num_tiles_hit_d.storage_offset() * num_tiles_hit_d.element_size() atIndex:15]; // Set the grid threadgroup sizes MTLSize grid_size = MTLSizeMake(num_points, 1, 1); @@ -402,6 +402,43 @@ void free_gsplat_metal_context(MetalContext* ctx) { torch::zeros({num_intersects}, xys.options().dtype(torch::kInt32)); torch::Tensor isect_ids_unsorted = torch::zeros({num_intersects}, xys.options().dtype(torch::kInt64)); + + // Get a reference to the command buffer for the MPS stream + id command_buffer = torch::mps::get_command_buffer(); + TORCH_CHECK(command_buffer, "Failed to retrieve command buffer reference"); + + // Dispatch the kernel + MetalContext* ctx = get_global_context(); + dispatch_sync(ctx->d_queue, ^(){ + // Start a compute pass + id encoder = [command_buffer computeCommandEncoder]; + TORCH_CHECK(encoder, "Failed to create compute command encoder"); + + // Encode the pipeline state object + id cpso = ctx->map_gaussian_to_intersects_kernel_cpso; + [encoder setComputePipelineState:cpso]; + + // Set the tensor buffers + [encoder setBytes:&num_points length:sizeof(num_points) atIndex:0]; + [encoder setBuffer:getMTLBufferStorage(xys) offset:xys.storage_offset() * xys.element_size() atIndex:1]; + [encoder setBuffer:getMTLBufferStorage(depths) offset:depths.storage_offset() * depths.element_size() atIndex:2]; + [encoder setBuffer:getMTLBufferStorage(radii) offset:radii.storage_offset() * radii.element_size() atIndex:3]; + [encoder setBuffer:getMTLBufferStorage(cum_tiles_hit) offset:cum_tiles_hit.storage_offset() * cum_tiles_hit.element_size() atIndex:4]; + [encoder setBuffer:getMTLBufferStorage(isect_ids_unsorted) 
offset:isect_ids_unsorted.storage_offset() * isect_ids_unsorted.element_size() atIndex:5]; + [encoder setBuffer:getMTLBufferStorage(gaussian_ids_unsorted) offset:gaussian_ids_unsorted.storage_offset() * gaussian_ids_unsorted.element_size() atIndex:6]; + + // Set the grid threadgroup sizes + MTLSize grid_size = MTLSizeMake(num_points, 1, 1); + NSUInteger num_threads_per_group = MIN(cpso.maxTotalThreadsPerThreadgroup, (NSUInteger)num_points); + MTLSize thread_group_size = MTLSizeMake(num_threads_per_group, 1, 1); + + // Dispatch the compute command + [encoder dispatchThreads:grid_size threadsPerThreadgroup:thread_group_size]; + [encoder endEncoding]; + + // Commit the work + torch::mps::synchronize(); + }); return std::make_tuple(isect_ids_unsorted, gaussian_ids_unsorted); } @@ -415,6 +452,39 @@ void free_gsplat_metal_context(MetalContext* ctx) { {num_intersects, 2}, isect_ids_sorted.options().dtype(torch::kInt32) ); + // Get a reference to the command buffer for the MPS stream + id command_buffer = torch::mps::get_command_buffer(); + TORCH_CHECK(command_buffer, "Failed to retrieve command buffer reference"); + + // Dispatch the kernel + MetalContext* ctx = get_global_context(); + dispatch_sync(ctx->d_queue, ^(){ + // Start a compute pass + id encoder = [command_buffer computeCommandEncoder]; + TORCH_CHECK(encoder, "Failed to create compute command encoder"); + + // Encode the pipeline state object + id cpso = ctx->get_tile_bin_edges_kernel_cpso; + [encoder setComputePipelineState:cpso]; + + // Set the tensor buffers + [encoder setBytes:&num_intersects length:sizeof(num_intersects) atIndex:0]; + [encoder setBuffer:getMTLBufferStorage(isect_ids_sorted) offset:isect_ids_sorted.storage_offset() * isect_ids_sorted.element_size() atIndex:1]; + [encoder setBuffer:getMTLBufferStorage(tile_bins) offset:tile_bins.storage_offset() * tile_bins.element_size() atIndex:2]; + + // Set the grid threadgroup sizes + MTLSize grid_size = MTLSizeMake(num_intersects, 1, 1); + NSUInteger 
num_threads_per_group = MIN(cpso.maxTotalThreadsPerThreadgroup, (NSUInteger)num_intersects); + MTLSize thread_group_size = MTLSizeMake(num_threads_per_group, 1, 1); + + // Dispatch the compute command + [encoder dispatchThreads:grid_size threadsPerThreadgroup:thread_group_size]; + [encoder endEncoding]; + + // Commit the work + torch::mps::synchronize(); + }); + return tile_bins; } @@ -442,7 +512,7 @@ void free_gsplat_metal_context(MetalContext* ctx) { CHECK_INPUT(opacities); CHECK_INPUT(background); - const int channels = colors.size(1); + const uint32_t channels = colors.size(1); const int img_width = std::get<0>(img_size); const int img_height = std::get<1>(img_size); @@ -471,13 +541,12 @@ void free_gsplat_metal_context(MetalContext* ctx) { id cpso = ctx->rasterize_forward_kernel_cpso; [encoder setComputePipelineState:cpso]; - int32_t tile_bounds_dim3[3] = {std::get<0>(tile_bounds), std::get<1>(tile_bounds), std::get<2>(tile_bounds)}; - int32_t img_size_dim3[3] = {std::get<0>(img_size), std::get<1>(img_size), std::get<2>(img_size)}; - int32_t block_size_dim3[3] = {std::get<0>(block), std::get<1>(block), std::get<2>(block)}; + int32_t img_size_dim3[4] = {std::get<0>(img_size), std::get<1>(img_size), std::get<2>(img_size), 0xDEAD}; + int32_t block_size_dim3[4] = {std::get<0>(block), std::get<1>(block), std::get<2>(block), 0xDEAD}; // Set the tensor buffers - [encoder setBytes:tile_bounds_dim3 length:sizeof(tile_bounds_dim3) atIndex:0]; - [encoder setBytes:img_size_dim3 length:sizeof(img_size_dim3) atIndex:1]; + [encoder setBytes:img_size_dim3 length:sizeof(img_size_dim3) atIndex:0]; + [encoder setBytes:&channels length:sizeof(&channels) atIndex:1]; [encoder setBuffer:getMTLBufferStorage(gaussian_ids_sorted) offset:gaussian_ids_sorted.storage_offset() * gaussian_ids_sorted.element_size() atIndex:2]; [encoder setBuffer:getMTLBufferStorage(tile_bins) offset:tile_bins.storage_offset() * tile_bins.element_size() atIndex:3]; [encoder setBuffer:getMTLBufferStorage(xys) 
offset:xys.storage_offset() * xys.element_size() atIndex:4]; @@ -619,5 +688,55 @@ void free_gsplat_metal_context(MetalContext* ctx) { torch::zeros({num_points, channels}, xys.options()); torch::Tensor v_opacity = torch::zeros({num_points, 1}, xys.options()); + torch::Tensor debug = torch::zeros({1}, xys.options().dtype(torch::kInt32)); + + // Get a reference to the command buffer for the MPS stream + id command_buffer = torch::mps::get_command_buffer(); + TORCH_CHECK(command_buffer, "Failed to retrieve command buffer reference"); + + // Dispatch the kernel + MetalContext* ctx = get_global_context(); + dispatch_sync(ctx->d_queue, ^(){ + // Start a compute pass + id encoder = [command_buffer computeCommandEncoder]; + TORCH_CHECK(encoder, "Failed to create compute command encoder"); + + // Encode the pipeline state object + id cpso = ctx->rasterize_backward_kernel_cpso; + [encoder setComputePipelineState:cpso]; + + uint32_t img_size[2] = {img_height, img_width}; + + // Set the tensor buffers + [encoder setBytes:img_size length:sizeof(img_size) atIndex:0]; + [encoder setBuffer:getMTLBufferStorage(gaussians_ids_sorted) offset:gaussians_ids_sorted.storage_offset() * gaussians_ids_sorted.element_size() atIndex:1]; + [encoder setBuffer:getMTLBufferStorage(tile_bins) offset:tile_bins.storage_offset() * tile_bins.element_size() atIndex:2]; + [encoder setBuffer:getMTLBufferStorage(xys) offset:xys.storage_offset() * xys.element_size() atIndex:3]; + [encoder setBuffer:getMTLBufferStorage(conics) offset:conics.storage_offset() * conics.element_size() atIndex:4]; + [encoder setBuffer:getMTLBufferStorage(colors) offset:colors.storage_offset() * colors.element_size() atIndex:5]; + [encoder setBuffer:getMTLBufferStorage(opacities) offset:opacities.storage_offset() * opacities.element_size() atIndex:6]; + [encoder setBuffer:getMTLBufferStorage(background) offset:background.storage_offset() * background.element_size() atIndex:7]; + [encoder setBuffer:getMTLBufferStorage(final_Ts) 
offset:final_Ts.storage_offset() * final_Ts.element_size() atIndex:8]; + [encoder setBuffer:getMTLBufferStorage(final_idx) offset:final_idx.storage_offset() * final_idx.element_size() atIndex:9]; + [encoder setBuffer:getMTLBufferStorage(v_output) offset:v_output.storage_offset() * v_output.element_size() atIndex:10]; + [encoder setBuffer:getMTLBufferStorage(v_output_alpha) offset:v_output_alpha.storage_offset() * v_output_alpha.element_size() atIndex:11]; + [encoder setBuffer:getMTLBufferStorage(v_xy) offset:v_xy.storage_offset() * v_xy.element_size() atIndex:12]; + [encoder setBuffer:getMTLBufferStorage(v_conic) offset:v_conic.storage_offset() * v_conic.element_size() atIndex:13]; + [encoder setBuffer:getMTLBufferStorage(v_colors) offset:v_colors.storage_offset() * v_colors.element_size() atIndex:14]; + [encoder setBuffer:getMTLBufferStorage(v_opacity) offset:v_opacity.storage_offset() * v_opacity.element_size() atIndex:15]; + [encoder setBuffer:getMTLBufferStorage(debug) offset:debug.storage_offset() * debug.element_size() atIndex:16]; + + // Set the grid threadgroup sizes + MTLSize grid_size = MTLSizeMake(img_height, img_width, 1); + MTLSize thread_group_size = MTLSizeMake(BLOCK_X, BLOCK_Y, 1); + + // Dispatch the compute command + [encoder dispatchThreads:grid_size threadsPerThreadgroup:thread_group_size]; + [encoder endEncoding]; + + // Commit the work + torch::mps::synchronize(); + }); + return std::make_tuple(v_xy, v_conic, v_colors, v_opacity); } \ No newline at end of file From 90f1e02f060c97ce8a61c3dc32c494fdf54d3fa2 Mon Sep 17 00:00:00 2001 From: Andrew Chan Date: Thu, 11 Apr 2024 15:53:29 -0700 Subject: [PATCH 06/19] macros --- vendor/gsplat-metal/gsplat_metal.mm | 134 ++++++++++++++-------------- 1 file changed, 69 insertions(+), 65 deletions(-) diff --git a/vendor/gsplat-metal/gsplat_metal.mm b/vendor/gsplat-metal/gsplat_metal.mm index 777ccb5..89a084d 100644 --- a/vendor/gsplat-metal/gsplat_metal.mm +++ b/vendor/gsplat-metal/gsplat_metal.mm @@ -161,6 
+161,10 @@ void free_gsplat_metal_context(MetalContext* ctx) { return __builtin_bit_cast(id, tensor.storage().data()); } +#define ENC_SCALAR(encoder, x, i) [encoder setBytes:&x length:sizeof(x) atIndex:i] +#define ENC_ARRAY(encoder, x, i) [encoder setBytes:x length:sizeof(x) atIndex:i] +#define ENC_TENSOR(encoder, x, i) [encoder setBuffer:getMTLBufferStorage(x) offset:x.storage_offset() * x.element_size() atIndex:i] + std::tuple< torch::Tensor, // output conics torch::Tensor> // output radii @@ -205,12 +209,12 @@ void free_gsplat_metal_context(MetalContext* ctx) { [encoder setComputePipelineState:cpso]; // Set the tensor buffers - [encoder setBytes:&num_points length:sizeof(num_points) atIndex:0]; - [encoder setBytes:&degree length:sizeof(degree) atIndex:1]; - [encoder setBytes:&degrees_to_use length:sizeof(degrees_to_use) atIndex:2]; - [encoder setBuffer:getMTLBufferStorage(viewdirs) offset:viewdirs.storage_offset() * viewdirs.element_size() atIndex:3]; - [encoder setBuffer:getMTLBufferStorage(coeffs) offset:coeffs.storage_offset() * coeffs.element_size() atIndex:4]; - [encoder setBuffer:getMTLBufferStorage(colors) offset:colors.storage_offset() * colors.element_size() atIndex:5]; + ENC_SCALAR(encoder, num_points, 0); + ENC_SCALAR(encoder, degree, 1); + ENC_SCALAR(encoder, degrees_to_use, 2); + ENC_TENSOR(encoder, viewdirs, 3); + ENC_TENSOR(encoder, coeffs, 4); + ENC_TENSOR(encoder, colors, 5); // Set the grid threadgroup sizes MTLSize grid_size = MTLSizeMake(num_points, 1, 1); @@ -306,22 +310,22 @@ void free_gsplat_metal_context(MetalContext* ctx) { [encoder setComputePipelineState:cpso]; // Set the tensor buffers - [encoder setBytes:&num_points length:sizeof(num_points) atIndex:0]; - [encoder setBuffer:getMTLBufferStorage(means3d) offset:means3d.storage_offset() * means3d.element_size() atIndex:1]; - [encoder setBuffer:getMTLBufferStorage(scales) offset:scales.storage_offset() * scales.element_size() atIndex:2]; - [encoder setBytes:&glob_scale
length:sizeof(glob_scale) atIndex:3]; - [encoder setBuffer:getMTLBufferStorage(quats) offset:quats.storage_offset() * quats.element_size() atIndex:4]; - [encoder setBuffer:getMTLBufferStorage(viewmat) offset:viewmat.storage_offset() * viewmat.element_size() atIndex:5]; - [encoder setBuffer:getMTLBufferStorage(projmat) offset:projmat.storage_offset() * projmat.element_size() atIndex:6]; - [encoder setBytes:intrins length:sizeof(intrins) atIndex:7]; - [encoder setBytes:img_size length:sizeof(img_size) atIndex:8]; - [encoder setBytes:&clip_thresh length:sizeof(clip_thresh) atIndex:9]; - [encoder setBuffer:getMTLBufferStorage(cov3d_d) offset:cov3d_d.storage_offset() * cov3d_d.element_size() atIndex:10]; - [encoder setBuffer:getMTLBufferStorage(xys_d) offset:xys_d.storage_offset() * xys_d.element_size() atIndex:11]; - [encoder setBuffer:getMTLBufferStorage(depths_d) offset:depths_d.storage_offset() * depths_d.element_size() atIndex:12]; - [encoder setBuffer:getMTLBufferStorage(radii_d) offset:radii_d.storage_offset() * radii_d.element_size() atIndex:13]; - [encoder setBuffer:getMTLBufferStorage(conics_d) offset:conics_d.storage_offset() * conics_d.element_size() atIndex:14]; - [encoder setBuffer:getMTLBufferStorage(num_tiles_hit_d) offset:num_tiles_hit_d.storage_offset() * num_tiles_hit_d.element_size() atIndex:15]; + ENC_SCALAR(encoder, num_points, 0); + ENC_TENSOR(encoder, means3d, 1); + ENC_TENSOR(encoder, scales, 2); + ENC_SCALAR(encoder, glob_scale, 3); + ENC_TENSOR(encoder, quats, 4); + ENC_TENSOR(encoder, viewmat, 5); + ENC_TENSOR(encoder, projmat, 6); + ENC_ARRAY(encoder, intrins, 7); + ENC_ARRAY(encoder, img_size, 8); + ENC_SCALAR(encoder, clip_thresh, 9); + ENC_TENSOR(encoder, cov3d_d, 10); + ENC_TENSOR(encoder, xys_d, 11); + ENC_TENSOR(encoder, depths_d, 12); + ENC_TENSOR(encoder, radii_d, 13); + ENC_TENSOR(encoder, conics_d, 14); + ENC_TENSOR(encoder, num_tiles_hit_d, 15); // Set the grid threadgroup sizes MTLSize grid_size = MTLSizeMake(num_points, 1, 1); 
@@ -419,13 +423,13 @@ void free_gsplat_metal_context(MetalContext* ctx) { [encoder setComputePipelineState:cpso]; // Set the tensor buffers - [encoder setBytes:&num_points length:sizeof(num_points) atIndex:0]; - [encoder setBuffer:getMTLBufferStorage(xys) offset:xys.storage_offset() * xys.element_size() atIndex:1]; - [encoder setBuffer:getMTLBufferStorage(depths) offset:depths.storage_offset() * depths.element_size() atIndex:2]; - [encoder setBuffer:getMTLBufferStorage(radii) offset:radii.storage_offset() * radii.element_size() atIndex:3]; - [encoder setBuffer:getMTLBufferStorage(cum_tiles_hit) offset:cum_tiles_hit.storage_offset() * cum_tiles_hit.element_size() atIndex:4]; - [encoder setBuffer:getMTLBufferStorage(isect_ids_unsorted) offset:isect_ids_unsorted.storage_offset() * isect_ids_unsorted.element_size() atIndex:5]; - [encoder setBuffer:getMTLBufferStorage(gaussian_ids_unsorted) offset:gaussian_ids_unsorted.storage_offset() * gaussian_ids_unsorted.element_size() atIndex:6]; + ENC_SCALAR(encoder, num_points, 0); + ENC_TENSOR(encoder, xys, 1); + ENC_TENSOR(encoder, depths, 2); + ENC_TENSOR(encoder, radii, 3); + ENC_TENSOR(encoder, cum_tiles_hit, 4); + ENC_TENSOR(encoder, isect_ids_unsorted, 5); + ENC_TENSOR(encoder, gaussian_ids_unsorted, 6); // Set the grid threadgroup sizes MTLSize grid_size = MTLSizeMake(num_points, 1, 1); @@ -468,9 +472,9 @@ void free_gsplat_metal_context(MetalContext* ctx) { [encoder setComputePipelineState:cpso]; // Set the tensor buffers - [encoder setBytes:&num_intersects length:sizeof(num_intersects) atIndex:0]; - [encoder setBuffer:getMTLBufferStorage(isect_ids_sorted) offset:isect_ids_sorted.storage_offset() * isect_ids_sorted.element_size() atIndex:1]; - [encoder setBuffer:getMTLBufferStorage(tile_bins) offset:tile_bins.storage_offset() * tile_bins.element_size() atIndex:2]; + ENC_SCALAR(encoder, num_intersects, 0); + ENC_TENSOR(encoder, isect_ids_sorted, 1); + ENC_TENSOR(encoder, tile_bins, 2); // Set the grid threadgroup sizes 
MTLSize grid_size = MTLSizeMake(num_intersects, 1, 1); @@ -494,6 +498,7 @@ void free_gsplat_metal_context(MetalContext* ctx) { torch::Tensor > rasterize_forward_tensor( const std::tuple tile_bounds, + // TODO(achan): we should be able to remove the 3rd dimension of `block` as it is always set to 1 const std::tuple block, const std::tuple img_size, const torch::Tensor &gaussian_ids_sorted, @@ -542,27 +547,26 @@ void free_gsplat_metal_context(MetalContext* ctx) { [encoder setComputePipelineState:cpso]; int32_t img_size_dim3[4] = {std::get<0>(img_size), std::get<1>(img_size), std::get<2>(img_size), 0xDEAD}; - int32_t block_size_dim3[4] = {std::get<0>(block), std::get<1>(block), std::get<2>(block), 0xDEAD}; + int32_t block_size_dim2[2] = {std::get<0>(block), std::get<1>(block)}; // Set the tensor buffers - [encoder setBytes:img_size_dim3 length:sizeof(img_size_dim3) atIndex:0]; - [encoder setBytes:&channels length:sizeof(&channels) atIndex:1]; - [encoder setBuffer:getMTLBufferStorage(gaussian_ids_sorted) offset:gaussian_ids_sorted.storage_offset() * gaussian_ids_sorted.element_size() atIndex:2]; - [encoder setBuffer:getMTLBufferStorage(tile_bins) offset:tile_bins.storage_offset() * tile_bins.element_size() atIndex:3]; - [encoder setBuffer:getMTLBufferStorage(xys) offset:xys.storage_offset() * xys.element_size() atIndex:4]; - [encoder setBuffer:getMTLBufferStorage(conics) offset:conics.storage_offset() * conics.element_size() atIndex:5]; - [encoder setBuffer:getMTLBufferStorage(colors) offset:colors.storage_offset() * colors.element_size() atIndex:6]; - [encoder setBuffer:getMTLBufferStorage(opacities) offset:opacities.storage_offset() * opacities.element_size() atIndex:7]; - [encoder setBuffer:getMTLBufferStorage(final_Ts) offset:final_Ts.storage_offset() * final_Ts.element_size() atIndex:8]; - [encoder setBuffer:getMTLBufferStorage(final_idx) offset:final_idx.storage_offset() * final_idx.element_size() atIndex:9]; - [encoder setBuffer:getMTLBufferStorage(out_img) 
offset:out_img.storage_offset() * out_img.element_size() atIndex:10]; - [encoder setBuffer:getMTLBufferStorage(background) offset:background.storage_offset() * background.element_size() atIndex:11]; - [encoder setBytes:block_size_dim3 length:2*sizeof(int32_t) atIndex:12]; + ENC_ARRAY(encoder, img_size_dim3, 0); + ENC_SCALAR(encoder, channels, 1); + ENC_TENSOR(encoder, gaussian_ids_sorted, 2); + ENC_TENSOR(encoder, tile_bins, 3); + ENC_TENSOR(encoder, xys, 4); + ENC_TENSOR(encoder, conics, 5); + ENC_TENSOR(encoder, colors, 6); + ENC_TENSOR(encoder, opacities, 7); + ENC_TENSOR(encoder, final_Ts, 8); + ENC_TENSOR(encoder, final_idx, 9); + ENC_TENSOR(encoder, out_img, 10); + ENC_TENSOR(encoder, background, 11); + ENC_ARRAY(encoder, block_size_dim2, 12); // Set the grid threadgroup sizes MTLSize grid_size = MTLSizeMake(img_height, img_width, 1); - // TODO(achan): we should be able to remove the 3rd dimension of `block` as it is always set to 1 - MTLSize thread_group_size = MTLSizeMake(block_size_dim3[0], block_size_dim3[1], block_size_dim3[2]); + MTLSize thread_group_size = MTLSizeMake(block_size_dim2[0], block_size_dim2[1], 1); // Dispatch the compute command [encoder dispatchThreads:grid_size threadsPerThreadgroup:thread_group_size]; @@ -708,23 +712,23 @@ void free_gsplat_metal_context(MetalContext* ctx) { uint32_t img_size[2] = {img_height, img_width}; // Set the tensor buffers - [encoder setBytes:img_size length:sizeof(img_size) atIndex:0]; - [encoder setBuffer:getMTLBufferStorage(gaussians_ids_sorted) offset:gaussians_ids_sorted.storage_offset() * gaussians_ids_sorted.element_size() atIndex:1]; - [encoder setBuffer:getMTLBufferStorage(tile_bins) offset:tile_bins.storage_offset() * tile_bins.element_size() atIndex:2]; - [encoder setBuffer:getMTLBufferStorage(xys) offset:xys.storage_offset() * xys.element_size() atIndex:3]; - [encoder setBuffer:getMTLBufferStorage(conics) offset:conics.storage_offset() * conics.element_size() atIndex:4]; - [encoder 
setBuffer:getMTLBufferStorage(colors) offset:colors.storage_offset() * colors.element_size() atIndex:5]; - [encoder setBuffer:getMTLBufferStorage(opacities) offset:opacities.storage_offset() * opacities.element_size() atIndex:6]; - [encoder setBuffer:getMTLBufferStorage(background) offset:background.storage_offset() * background.element_size() atIndex:7]; - [encoder setBuffer:getMTLBufferStorage(final_Ts) offset:final_Ts.storage_offset() * final_Ts.element_size() atIndex:8]; - [encoder setBuffer:getMTLBufferStorage(final_idx) offset:final_idx.storage_offset() * final_idx.element_size() atIndex:9]; - [encoder setBuffer:getMTLBufferStorage(v_output) offset:v_output.storage_offset() * v_output.element_size() atIndex:10]; - [encoder setBuffer:getMTLBufferStorage(v_output_alpha) offset:v_output_alpha.storage_offset() * v_output_alpha.element_size() atIndex:11]; - [encoder setBuffer:getMTLBufferStorage(v_xy) offset:v_xy.storage_offset() * v_xy.element_size() atIndex:12]; - [encoder setBuffer:getMTLBufferStorage(v_conic) offset:v_conic.storage_offset() * v_conic.element_size() atIndex:13]; - [encoder setBuffer:getMTLBufferStorage(v_colors) offset:v_colors.storage_offset() * v_colors.element_size() atIndex:14]; - [encoder setBuffer:getMTLBufferStorage(v_opacity) offset:v_opacity.storage_offset() * v_opacity.element_size() atIndex:15]; - [encoder setBuffer:getMTLBufferStorage(debug) offset:debug.storage_offset() * debug.element_size() atIndex:16]; + ENC_ARRAY(encoder, img_size, 0); + ENC_TENSOR(encoder, gaussians_ids_sorted, 1); + ENC_TENSOR(encoder, tile_bins, 2); + ENC_TENSOR(encoder, xys, 3); + ENC_TENSOR(encoder, conics, 4); + ENC_TENSOR(encoder, colors, 5); + ENC_TENSOR(encoder, opacities, 6); + ENC_TENSOR(encoder, background, 7); + ENC_TENSOR(encoder, final_Ts, 8); + ENC_TENSOR(encoder, final_idx, 9); + ENC_TENSOR(encoder, v_output, 10); + ENC_TENSOR(encoder, v_output_alpha, 11); + ENC_TENSOR(encoder, v_xy, 12); + ENC_TENSOR(encoder, v_conic, 13); + 
ENC_TENSOR(encoder, v_colors, 14); + ENC_TENSOR(encoder, v_opacity, 15); + ENC_TENSOR(encoder, debug, 16); // Set the grid threadgroup sizes MTLSize grid_size = MTLSizeMake(img_height, img_width, 1); From 8055f11a8b2babe0cf64e2127e8ae5e909eded9a Mon Sep 17 00:00:00 2001 From: Andrew Chan Date: Thu, 11 Apr 2024 17:01:51 -0700 Subject: [PATCH 07/19] .. --- vendor/gsplat-metal/gsplat_metal.metal | 416 ++++++++++++++++++++++++- vendor/gsplat-metal/gsplat_metal.mm | 99 +++++- 2 files changed, 504 insertions(+), 11 deletions(-) diff --git a/vendor/gsplat-metal/gsplat_metal.metal b/vendor/gsplat-metal/gsplat_metal.metal index 2c1d904..9e03935 100644 --- a/vendor/gsplat-metal/gsplat_metal.metal +++ b/vendor/gsplat-metal/gsplat_metal.metal @@ -273,14 +273,14 @@ inline bool compute_cov2d_bounds( } inline float2 project_pix( - constant float *mat, const float3 p, const int2 img_size, const float2 pp + constant float *mat, const float3 p, const uint2 img_size, const float2 pp ) { // ROW MAJOR mat float4 p_hom = transform_4x4(mat, p); float rw = 1.f / (p_hom.w + 1e-6f); float3 p_proj = {p_hom.x * rw, p_hom.y * rw, p_hom.z * rw}; return { - ndc2pix(p_proj.x, img_size.x, pp.x), ndc2pix(p_proj.y, img_size.y, pp.y) + ndc2pix(p_proj.x, (int)img_size.x, pp.x), ndc2pix(p_proj.y, (int)img_size.y, pp.y) }; } @@ -312,6 +312,10 @@ inline float2 read_packed_float2(constant float* arr, int idx) { return float2(arr[2*idx], arr[2*idx+1]); } +inline float2 read_packed_float2(device float* arr, int idx) { + return float2(arr[2*idx], arr[2*idx+1]); +} + inline void write_packed_float2(device float* arr, int idx, float2 val) { arr[2*idx] = val.x; arr[2*idx+1] = val.y; @@ -331,6 +335,10 @@ inline float3 read_packed_float3(constant float* arr, int idx) { return float3(arr[3*idx], arr[3*idx+1], arr[3*idx+2]); } +inline float3 read_packed_float3(device float* arr, int idx) { + return float3(arr[3*idx], arr[3*idx+1], arr[3*idx+2]); +} + inline void write_packed_float3(device float* arr, int idx, 
float3 val) { arr[3*idx] = val.x; arr[3*idx+1] = val.y; @@ -359,7 +367,7 @@ kernel void project_gaussians_forward_kernel( constant float* viewmat, constant float* projmat, constant float4& intrins, - constant int2& img_size, + constant uint2& img_size, constant float& clip_thresh, device float* covs3d, device float* xys, // float2 @@ -425,7 +433,7 @@ kernel void project_gaussians_forward_kernel( // TODO(achan): this is actually the nd_rasterize_forward_kernel kernel void rasterize_forward_kernel( - constant int3& img_size, + constant uint3& img_size, constant uint& channels, constant int32_t* gaussian_ids_sorted, constant int* tile_bins, // int2 @@ -449,10 +457,10 @@ kernel void rasterize_forward_kernel( int32_t j = blockIdx.x * blockDim.x + threadIdx.x; float px = (float)j; float py = (float)i; - int32_t pix_id = i * img_size.x + j; + int32_t pix_id = i * (int)img_size.x + j; // return if out of bounds - if (i >= img_size.y || j >= img_size.x) { + if (i >= (int)img_size.y || j >= (int)img_size.x) { return; } @@ -583,6 +591,98 @@ void sh_coeffs_to_color( } } +void sh_coeffs_to_color_vjp( + const uint degree, + const float3 viewdir, + constant float *v_colors, + device float *v_coeffs +) { + // Expects v_colors to be len CHANNELS + // and v_coeffs to be num_bases * CHANNELS + #pragma unroll + for (int c = 0; c < CHANNELS; ++c) { + v_coeffs[c] = SH_C0 * v_colors[c]; + } + if (degree < 1) { + return; + } + + float norm = sqrt( + viewdir.x * viewdir.x + viewdir.y * viewdir.y + viewdir.z * viewdir.z + ); + float x = viewdir.x / norm; + float y = viewdir.y / norm; + float z = viewdir.z / norm; + + float xx = x * x; + float xy = x * y; + float xz = x * z; + float yy = y * y; + float yz = y * z; + float zz = z * z; + + #pragma unroll + for (int c = 0; c < CHANNELS; ++c) { + float v1 = -SH_C1 * y; + float v2 = SH_C1 * z; + float v3 = -SH_C1 * x; + v_coeffs[1 * CHANNELS + c] = v1 * v_colors[c]; + v_coeffs[2 * CHANNELS + c] = v2 * v_colors[c]; + v_coeffs[3 * CHANNELS + c] = 
v3 * v_colors[c]; + if (degree < 2) { + continue; + } + float v4 = SH_C2[0] * xy; + float v5 = SH_C2[1] * yz; + float v6 = SH_C2[2] * (2.f * zz - xx - yy); + float v7 = SH_C2[3] * xz; + float v8 = SH_C2[4] * (xx - yy); + v_coeffs[4 * CHANNELS + c] = v4 * v_colors[c]; + v_coeffs[5 * CHANNELS + c] = v5 * v_colors[c]; + v_coeffs[6 * CHANNELS + c] = v6 * v_colors[c]; + v_coeffs[7 * CHANNELS + c] = v7 * v_colors[c]; + v_coeffs[8 * CHANNELS + c] = v8 * v_colors[c]; + if (degree < 3) { + continue; + } + float v9 = SH_C3[0] * y * (3.f * xx - yy); + float v10 = SH_C3[1] * xy * z; + float v11 = SH_C3[2] * y * (4.f * zz - xx - yy); + float v12 = SH_C3[3] * z * (2.f * zz - 3.f * xx - 3.f * yy); + float v13 = SH_C3[4] * x * (4.f * zz - xx - yy); + float v14 = SH_C3[5] * z * (xx - yy); + float v15 = SH_C3[6] * x * (xx - 3.f * yy); + v_coeffs[9 * CHANNELS + c] = v9 * v_colors[c]; + v_coeffs[10 * CHANNELS + c] = v10 * v_colors[c]; + v_coeffs[11 * CHANNELS + c] = v11 * v_colors[c]; + v_coeffs[12 * CHANNELS + c] = v12 * v_colors[c]; + v_coeffs[13 * CHANNELS + c] = v13 * v_colors[c]; + v_coeffs[14 * CHANNELS + c] = v14 * v_colors[c]; + v_coeffs[15 * CHANNELS + c] = v15 * v_colors[c]; + if (degree < 4) { + continue; + } + float v16 = SH_C4[0] * xy * (xx - yy); + float v17 = SH_C4[1] * yz * (3.f * xx - yy); + float v18 = SH_C4[2] * xy * (7.f * zz - 1.f); + float v19 = SH_C4[3] * yz * (7.f * zz - 3.f); + float v20 = SH_C4[4] * (zz * (35.f * zz - 30.f) + 3.f); + float v21 = SH_C4[5] * xz * (7.f * zz - 3.f); + float v22 = SH_C4[6] * (xx - yy) * (7.f * zz - 1.f); + float v23 = SH_C4[7] * xz * (xx - 3.f * yy); + float v24 = SH_C4[8] * (xx * (xx - 3.f * yy) - yy * (3.f * xx - yy)); + v_coeffs[16 * CHANNELS + c] = v16 * v_colors[c]; + v_coeffs[17 * CHANNELS + c] = v17 * v_colors[c]; + v_coeffs[18 * CHANNELS + c] = v18 * v_colors[c]; + v_coeffs[19 * CHANNELS + c] = v19 * v_colors[c]; + v_coeffs[20 * CHANNELS + c] = v20 * v_colors[c]; + v_coeffs[21 * CHANNELS + c] = v21 * v_colors[c]; + 
v_coeffs[22 * CHANNELS + c] = v22 * v_colors[c]; + v_coeffs[23 * CHANNELS + c] = v23 * v_colors[c]; + v_coeffs[24 * CHANNELS + c] = v24 * v_colors[c]; + } +} + kernel void compute_sh_forward_kernel( constant uint& num_points, constant uint& degree, @@ -590,7 +690,7 @@ kernel void compute_sh_forward_kernel( constant float* viewdirs, // float3 constant float* coeffs, device float* colors, - uint idx [[thread_position_in_threadgroup]] + uint idx [[thread_position_in_grid]] ) { if (idx >= num_points) { return; @@ -605,6 +705,28 @@ kernel void compute_sh_forward_kernel( ); } +kernel void compute_sh_backward_kernel( + constant uint& num_points, + constant uint& degree, + constant uint& degrees_to_use, + constant float* viewdirs, // float3 + constant float* v_colors, + device float* v_coeffs, + uint idx [[thread_position_in_grid]] +) { + if (idx >= num_points) { + return; + } + const uint num_channels = 3; + uint num_bases = num_sh_bases(degree); + uint idx_sh = num_bases * num_channels * idx; + uint idx_col = num_channels * idx; + + sh_coeffs_to_color_vjp( + degrees_to_use, read_packed_float3(viewdirs, idx), &(v_colors[idx_col]), &(v_coeffs[idx_sh]) + ); +} + // kernel to map each intersection from tile ID and depth to a gaussian // writes output to isect_ids and gaussian_ids kernel void map_gaussian_to_intersects_kernel( @@ -879,4 +1001,284 @@ kernel void rasterize_backward_kernel( atomic_fetch_add_explicit(v_opacity + g, v_opacity_local, memory_order_relaxed); } } +} + +// given v_xy_pix, get v_xyz +inline float3 project_pix_vjp( + constant float *mat, const float3 p, const uint2 img_size, const float2 v_xy +) { + // ROW MAJOR mat + float4 p_hom = transform_4x4(mat, p); + float rw = 1.f / (p_hom.w + 1e-6f); + + float3 v_ndc = {0.5f * img_size.x * v_xy.x, 0.5f * img_size.y * v_xy.y, 0.0f}; + float4 v_proj = { + v_ndc.x * rw, v_ndc.y * rw, 0., -(v_ndc.x + v_ndc.y) * rw * rw + }; + // df / d_world = df / d_cam * d_cam / d_world + // = v_proj * P[:3, :3] + return { + 
mat[0] * v_proj.x + mat[4] * v_proj.y + mat[8] * v_proj.z, + mat[1] * v_proj.x + mat[5] * v_proj.y + mat[9] * v_proj.z, + mat[2] * v_proj.x + mat[6] * v_proj.y + mat[10] * v_proj.z + }; +} + +// compute vjp from df/d_conic to df/c_cov2d +inline void cov2d_to_conic_vjp( + float3 conic, + float3 v_conic, + device float* v_cov2d // float3 +) { + // conic = inverse cov2d + // df/d_cov2d = -conic * df/d_conic * conic + float2x2 X = float2x2(conic.x, conic.y, conic.y, conic.z); + float2x2 G = float2x2(v_conic.x, v_conic.y, v_conic.y, v_conic.z); + float2x2 v_Sigma = -1. * X * G * X; + v_cov2d[0] = v_Sigma[0][0]; + v_cov2d[1] = v_Sigma[1][0] + v_Sigma[0][1]; + v_cov2d[2] = v_Sigma[1][1]; +} + +// output space: 2D covariance, input space: cov3d +void project_cov3d_ewa_vjp( + const float3 mean3d, + constant float* cov3d, + constant float* viewmat, + const float fx, + const float fy, + float3 v_cov2d, + device float* v_mean3d, // float3 + device float* v_cov3d +) { + // viewmat is row major, float3x3 is column major + // upper 3x3 submatrix + // clang-format off + float3x3 W = float3x3( + viewmat[0], viewmat[4], viewmat[8], + viewmat[1], viewmat[5], viewmat[9], + viewmat[2], viewmat[6], viewmat[10] + ); + // clang-format on + float3 p = float3(viewmat[3], viewmat[7], viewmat[11]); + float3 t = W * float3(mean3d.x, mean3d.y, mean3d.z) + p; + float rz = 1.f / t.z; + float rz2 = rz * rz; + + // column major + // we only care about the top 2x2 submatrix + // clang-format off + float3x3 J = float3x3( + fx * rz, 0.f, 0.f, + 0.f, fy * rz, 0.f, + -fx * t.x * rz2, -fy * t.y * rz2, 0.f + ); + float3x3 V = float3x3( + cov3d[0], cov3d[1], cov3d[2], + cov3d[1], cov3d[3], cov3d[4], + cov3d[2], cov3d[4], cov3d[5] + ); + // cov = T * V * Tt; G = df/dcov = v_cov + // -> d/dV = Tt * G * T + // -> df/dT = G * T * Vt + Gt * T * V + float3x3 v_cov = float3x3( + v_cov2d.x, 0.5f * v_cov2d.y, 0.f, + 0.5f * v_cov2d.y, v_cov2d.z, 0.f, + 0.f, 0.f, 0.f + ); + // clang-format on + + float3x3 T = J * W; 
+ float3x3 Tt = transpose(T); + float3x3 Vt = transpose(V); + float3x3 v_V = Tt * v_cov * T; + float3x3 v_T = v_cov * T * Vt + transpose(v_cov) * T * V; + + // vjp of cov3d parameters + // v_cov3d_i = v_V : dV/d_cov3d_i + // where : is frobenius inner product + v_cov3d[0] = v_V[0][0]; + v_cov3d[1] = v_V[0][1] + v_V[1][0]; + v_cov3d[2] = v_V[0][2] + v_V[2][0]; + v_cov3d[3] = v_V[1][1]; + v_cov3d[4] = v_V[1][2] + v_V[2][1]; + v_cov3d[5] = v_V[2][2]; + + // compute df/d_mean3d + // T = J * W + float3x3 v_J = v_T * transpose(W); + float rz3 = rz2 * rz; + float3 v_t = float3( + -fx * rz2 * v_J[2][0], + -fy * rz2 * v_J[2][1], + -fx * rz2 * v_J[0][0] + 2.f * fx * t.x * rz3 * v_J[2][0] - + fy * rz2 * v_J[1][1] + 2.f * fy * t.y * rz3 * v_J[2][1] + ); + // printf("v_t %.2f %.2f %.2f\n", v_t[0], v_t[1], v_t[2]); + // printf("W %.2f %.2f %.2f\n", W[0][0], W[0][1], W[0][2]); + v_mean3d[0] += (float)dot(v_t, W[0]); + v_mean3d[1] += (float)dot(v_t, W[1]); + v_mean3d[2] += (float)dot(v_t, W[2]); +} + +inline float4 quat_to_rotmat_vjp(const float4 quat, const float3x3 v_R) { + float s = rsqrt( + quat.w * quat.w + quat.x * quat.x + quat.y * quat.y + quat.z * quat.z + ); + float w = quat.x * s; + float x = quat.y * s; + float y = quat.z * s; + float z = quat.w * s; + + float4 v_quat; + // v_R is COLUMN MAJOR + // w element stored in x field + v_quat.x = + 2.f * ( + // v_quat.w = 2.f * ( + x * (v_R[1][2] - v_R[2][1]) + y * (v_R[2][0] - v_R[0][2]) + + z * (v_R[0][1] - v_R[1][0]) + ); + // x element in y field + v_quat.y = + 2.f * + ( + // v_quat.x = 2.f * ( + -2.f * x * (v_R[1][1] + v_R[2][2]) + y * (v_R[0][1] + v_R[1][0]) + + z * (v_R[0][2] + v_R[2][0]) + w * (v_R[1][2] - v_R[2][1]) + ); + // y element in z field + v_quat.z = + 2.f * + ( + // v_quat.y = 2.f * ( + x * (v_R[0][1] + v_R[1][0]) - 2.f * y * (v_R[0][0] + v_R[2][2]) + + z * (v_R[1][2] + v_R[2][1]) + w * (v_R[2][0] - v_R[0][2]) + ); + // z element in w field + v_quat.w = + 2.f * + ( + // v_quat.z = 2.f * ( + x * (v_R[0][2] + 
v_R[2][0]) + y * (v_R[1][2] + v_R[2][1]) - + 2.f * z * (v_R[0][0] + v_R[1][1]) + w * (v_R[0][1] - v_R[1][0]) + ); + return v_quat; +} + +// given cotangent v in output space (e.g. d_L/d_cov3d) in R(6) +// compute vJp for scale and rotation +void scale_rot_to_cov3d_vjp( + const float3 scale, + const float glob_scale, + const float4 quat, + const device float* v_cov3d, + device float* v_scale, // float3 + device float* v_quat // float4 +) { + // cov3d is upper triangular elements of matrix + // off-diagonal elements count grads from both ij and ji elements, + // must halve when expanding back into symmetric matrix + float3x3 v_V = float3x3( + v_cov3d[0], + 0.5 * v_cov3d[1], + 0.5 * v_cov3d[2], + 0.5 * v_cov3d[1], + v_cov3d[3], + 0.5 * v_cov3d[4], + 0.5 * v_cov3d[2], + 0.5 * v_cov3d[4], + v_cov3d[5] + ); + float3x3 R = quat_to_rotmat(quat); + float3x3 S = scale_to_mat(scale, glob_scale); + float3x3 M = R * S; + // https://math.stackexchange.com/a/3850121 + // for D = W * X, G = df/dD + // df/dW = G * XT, df/dX = WT * G + float3x3 v_M = 2.f * v_V * M; + v_scale[0] = (float)dot(R[0], v_M[0]); + v_scale[1] = (float)dot(R[1], v_M[1]); + v_scale[2] = (float)dot(R[2], v_M[2]); + + float3x3 v_R = v_M * S; + float4 out_v_quat = quat_to_rotmat_vjp(quat, v_R); + v_quat[0] = out_v_quat.x; + v_quat[1] = out_v_quat.y; + v_quat[2] = out_v_quat.z; + v_quat[3] = out_v_quat.w; +} + +kernel void project_gaussians_backward_kernel( + constant int& num_points, + constant float* means3d, // float3 + constant float* scales, // float3 + constant float& glob_scale, + constant float* quats, // float4 + constant float* viewmat, + constant float* projmat, + constant float4& intrins, + constant uint2& img_size, + constant float* cov3d, + constant int* radii, + constant float* conics, // float3 + constant float* v_xy, // float2 + constant float* v_depth, + constant float* v_conic, // float3 + device float* v_cov2d, // float3 + device float* v_cov3d, + device float* v_mean3d, // float3 + device 
float* v_scale, // float3 + device float* v_quat, // float4 + uint idx [[thread_position_in_grid]] +) { + if (idx >= num_points || radii[idx] <= 0) { + return; + } + float3 p_world = read_packed_float3(means3d, idx); + float fx = intrins.x; + float fy = intrins.y; + float cx = intrins.z; + float cy = intrins.w; + // get v_mean3d from v_xy + write_packed_float3( + v_mean3d, idx, + project_pix_vjp(projmat, p_world, img_size, read_packed_float2(v_xy, idx)) + ); + + // get z gradient contribution to mean3d gradient + // z = viemwat[8] * mean3d.x + viewmat[9] * mean3d.y + viewmat[10] * + // mean3d.z + viewmat[11] + float v_z = v_depth[idx]; + write_packed_float3( + v_mean3d, idx, + read_packed_float3(v_mean3d, idx) + float3(viewmat[8], viewmat[9], viewmat[10]) * v_z + ); + + // get v_cov2d + cov2d_to_conic_vjp( + read_packed_float3(conics, idx), + read_packed_float3(v_conic, idx), + &(v_cov2d[3*idx]) + ); + // get v_cov3d (and v_mean3d contribution) + project_cov3d_ewa_vjp( + p_world, + &(cov3d[6 * idx]), + viewmat, + fx, + fy, + read_packed_float3(v_cov2d, idx), + &(v_mean3d[3*idx]), + &(v_cov3d[6 * idx]) + ); + // get v_scale and v_quat + scale_rot_to_cov3d_vjp( + read_packed_float3(scales, idx), + glob_scale, + read_packed_float4(quats, idx), + &(v_cov3d[6 * idx]), + &(v_scale[3*idx]), + &(v_quat[4*idx]) + ); } \ No newline at end of file diff --git a/vendor/gsplat-metal/gsplat_metal.mm b/vendor/gsplat-metal/gsplat_metal.mm index 89a084d..e1fa6d8 100644 --- a/vendor/gsplat-metal/gsplat_metal.mm +++ b/vendor/gsplat-metal/gsplat_metal.mm @@ -115,9 +115,9 @@ @implementation DummyClassForPathHack GSPLAT_METAL_ADD_KERNEL(rasterize_forward_kernel); GSPLAT_METAL_ADD_KERNEL(rasterize_backward_kernel); GSPLAT_METAL_ADD_KERNEL(project_gaussians_forward_kernel); - // GSPLAT_METAL_ADD_KERNEL(project_gaussians_backward_kernel); + GSPLAT_METAL_ADD_KERNEL(project_gaussians_backward_kernel); GSPLAT_METAL_ADD_KERNEL(compute_sh_forward_kernel); - // 
GSPLAT_METAL_ADD_KERNEL(compute_sh_backward_kernel); + GSPLAT_METAL_ADD_KERNEL(compute_sh_backward_kernel); // GSPLAT_METAL_ADD_KERNEL(compute_cov2d_bounds_kernel); GSPLAT_METAL_ADD_KERNEL(map_gaussian_to_intersects_kernel); GSPLAT_METAL_ADD_KERNEL(get_tile_bin_edges_kernel); @@ -250,6 +250,44 @@ void free_gsplat_metal_context(MetalContext* ctx) { unsigned num_bases = num_sh_bases(degree); torch::Tensor v_coeffs = torch::zeros({num_points, num_bases, 3}, v_colors.options()); + + // Get a reference to the command buffer for the MPS stream + id command_buffer = torch::mps::get_command_buffer(); + TORCH_CHECK(command_buffer, "Failed to retrieve command buffer reference"); + + // Dispatch the kernel + MetalContext* ctx = get_global_context(); + dispatch_sync(ctx->d_queue, ^(){ + // Start a compute pass + id encoder = [command_buffer computeCommandEncoder]; + TORCH_CHECK(encoder, "Failed to create compute command encoder"); + + // Encode the pipeline state object + id cpso = ctx->compute_sh_backward_kernel_cpso; + [encoder setComputePipelineState:cpso]; + + // Set the tensor buffers + ENC_SCALAR(encoder, num_points, 0); + ENC_SCALAR(encoder, degree, 1); + ENC_SCALAR(encoder, degrees_to_use, 2); + ENC_TENSOR(encoder, viewdirs, 3); + ENC_TENSOR(encoder, v_colors, 4); + ENC_TENSOR(encoder, v_coeffs, 5); + + // Set the grid threadgroup sizes + MTLSize grid_size = MTLSizeMake(num_points, 1, 1); + + NSUInteger num_threads_per_group = MIN(cpso.maxTotalThreadsPerThreadgroup, (NSUInteger)num_points); + MTLSize thread_group_size = MTLSizeMake(num_threads_per_group, 1, 1); + + // Dispatch the compute command + [encoder dispatchThreads:grid_size threadsPerThreadgroup:thread_group_size]; + [encoder endEncoding]; + + // Commit the work + torch::mps::synchronize(); + }); + return v_coeffs; } @@ -303,7 +341,7 @@ void free_gsplat_metal_context(MetalContext* ctx) { TORCH_CHECK(encoder, "Failed to create compute command encoder"); float intrins[4] = {fx, fy, cx, cy}; - int32_t img_size[2] 
= {(int32_t)img_width, (int32_t)img_height}; + uint32_t img_size[2] = {img_width, img_height}; // Encode the pipeline state object id cpso = ctx->project_gaussians_forward_kernel_cpso; @@ -384,6 +422,59 @@ void free_gsplat_metal_context(MetalContext* ctx) { torch::Tensor v_quat = torch::zeros({num_points, 4}, means3d.options().dtype(torch::kFloat32)); + // Get a reference to the command buffer for the MPS stream + id command_buffer = torch::mps::get_command_buffer(); + TORCH_CHECK(command_buffer, "Failed to retrieve command buffer reference"); + + // Dispatch the kernel + MetalContext* ctx = get_global_context(); + dispatch_sync(ctx->d_queue, ^(){ + // Start a compute pass + id encoder = [command_buffer computeCommandEncoder]; + TORCH_CHECK(encoder, "Failed to create compute command encoder"); + + float intrins[4] = {fx, fy, cx, cy}; + uint32_t img_size[2] = {img_width, img_height}; + + // Encode the pipeline state object + id cpso = ctx->project_gaussians_backward_kernel_cpso; + [encoder setComputePipelineState:cpso]; + + // Set the tensor buffers + ENC_SCALAR(encoder, num_points, 0); + ENC_TENSOR(encoder, means3d, 1); + ENC_TENSOR(encoder, scales, 2); + ENC_SCALAR(encoder, glob_scale, 3); + ENC_TENSOR(encoder, quats, 4); + ENC_TENSOR(encoder, viewmat, 5); + ENC_TENSOR(encoder, projmat, 6); + ENC_ARRAY(encoder, intrins, 7); + ENC_ARRAY(encoder, img_size, 8); + ENC_TENSOR(encoder, cov3d, 9); + ENC_TENSOR(encoder, radii, 10); + ENC_TENSOR(encoder, conics, 11); + ENC_TENSOR(encoder, v_xy, 12); + ENC_TENSOR(encoder, v_depth, 13); + ENC_TENSOR(encoder, v_conic, 14); + ENC_TENSOR(encoder, v_cov2d, 15); + ENC_TENSOR(encoder, v_cov3d, 16); + ENC_TENSOR(encoder, v_mean3d, 17); + ENC_TENSOR(encoder, v_scale, 18); + ENC_TENSOR(encoder, v_quat, 19); + + // Set the grid threadgroup sizes + MTLSize grid_size = MTLSizeMake(num_points, 1, 1); + NSUInteger num_threads_per_group = MIN(cpso.maxTotalThreadsPerThreadgroup, (NSUInteger)num_points); + MTLSize thread_group_size = 
MTLSizeMake(num_threads_per_group, 1, 1); + + // Dispatch the compute command + [encoder dispatchThreads:grid_size threadsPerThreadgroup:thread_group_size]; + [encoder endEncoding]; + + // Commit the work + torch::mps::synchronize(); + }); + return std::make_tuple(v_cov2d, v_cov3d, v_mean3d, v_scale, v_quat); } @@ -546,7 +637,7 @@ void free_gsplat_metal_context(MetalContext* ctx) { id cpso = ctx->rasterize_forward_kernel_cpso; [encoder setComputePipelineState:cpso]; - int32_t img_size_dim3[4] = {std::get<0>(img_size), std::get<1>(img_size), std::get<2>(img_size), 0xDEAD}; + uint32_t img_size_dim3[4] = {(uint32_t)std::get<0>(img_size), (uint32_t)std::get<1>(img_size), (uint32_t)std::get<2>(img_size), 0xDEAD}; int32_t block_size_dim2[2] = {std::get<0>(block), std::get<1>(block)}; // Set the tensor buffers From 38cfc99deee8d284939a920037d2daf767deb177 Mon Sep 17 00:00:00 2001 From: Andrew Chan Date: Thu, 11 Apr 2024 17:51:02 -0700 Subject: [PATCH 08/19] replace macro with helper fn --- vendor/gsplat-metal/gsplat_metal.mm | 527 ++++++++++++---------------- 1 file changed, 222 insertions(+), 305 deletions(-) diff --git a/vendor/gsplat-metal/gsplat_metal.mm b/vendor/gsplat-metal/gsplat_metal.mm index e1fa6d8..9f75a7b 100644 --- a/vendor/gsplat-metal/gsplat_metal.mm +++ b/vendor/gsplat-metal/gsplat_metal.mm @@ -165,6 +165,96 @@ void free_gsplat_metal_context(MetalContext* ctx) { #define ENC_ARRAY(encoder, x, i) [encoder setBytes:x length:sizeof(x) atIndex:i] #define ENC_TENSOR(encoder, x, i) [encoder setBuffer:getMTLBufferStorage(x) offset:x.storage_offset() * x.element_size() atIndex:i] +enum struct EncodeType { + FLOAT, + INT, + UINT, + ARRAY, + TENSOR +}; + +struct EncodeArg { + static EncodeArg scalar(float x) { + return EncodeArg(EncodeType::FLOAT, x, 0, 0, nullptr, 0, nullptr); + } + static EncodeArg scalar(int32_t x) { + return EncodeArg(EncodeType::INT, 0, x, 0, nullptr, 0, nullptr); + } + static EncodeArg scalar(uint32_t x) { + return EncodeArg(EncodeType::UINT, 
0, 0, x, nullptr, 0, nullptr); + } + static EncodeArg array(void* x, size_t numBytes) { + return EncodeArg(EncodeType::ARRAY, 0, 0, 0, x, numBytes, nullptr); + } + static EncodeArg tensor(const torch::Tensor& x) { + return EncodeArg(EncodeType::TENSOR, 0, 0, 0, nullptr, 0, &x); + } +private: + EncodeArg( + EncodeType type, + float fScalar, + int32_t i32Scalar, + uint32_t u32Scalar, + void* array, + size_t arrayNumBytes, + const torch::Tensor* tensor + ) : _type(type), _fScalar(fScalar), _i32Scalar(i32Scalar), _u32Scalar(u32Scalar), _array(array), _arrayNumBytes(arrayNumBytes), _tensor(tensor) {} + EncodeType _type; + float _fScalar; + int32_t _i32Scalar; + uint32_t _u32Scalar; + void* _array; + size_t _arrayNumBytes; + const torch::Tensor* _tensor; + + friend void dispatchKernel(MetalContext* ctx, id<MTLComputePipelineState> cpso, MTLSize grid_size, MTLSize thread_group_size, std::vector<EncodeArg> args); +}; + +void dispatchKernel(MetalContext* ctx, id<MTLComputePipelineState> cpso, MTLSize grid_size, MTLSize thread_group_size, std::vector<EncodeArg> args) { + // Get a reference to the command buffer for the MPS stream + id<MTLCommandBuffer> command_buffer = torch::mps::get_command_buffer(); + TORCH_CHECK(command_buffer, "Failed to retrieve command buffer reference"); + + // Dispatch the kernel + dispatch_sync(ctx->d_queue, ^(){ + // Start a compute pass + id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder]; + TORCH_CHECK(encoder, "Failed to create compute command encoder"); + + // Encode the pipeline state object + [encoder setComputePipelineState:cpso]; + + // Encode arguments + for (size_t i = 0; i < args.size(); ++i) { + const EncodeArg& arg = args[i]; + switch (arg._type) { + case EncodeType::FLOAT: + [encoder setBytes:&arg._fScalar length:sizeof(arg._fScalar) atIndex:i]; + break; + case EncodeType::INT: + [encoder setBytes:&arg._i32Scalar length:sizeof(arg._i32Scalar) atIndex:i]; + break; + case EncodeType::UINT: + [encoder setBytes:&arg._u32Scalar length:sizeof(arg._u32Scalar) atIndex:i]; + break; + case EncodeType::ARRAY: + [encoder
setBytes:arg._array length:arg._arrayNumBytes atIndex:i]; + break; + case EncodeType::TENSOR: + [encoder setBuffer:getMTLBufferStorage(*arg._tensor) offset:arg._tensor->storage_offset() * arg._tensor->element_size() atIndex:i]; + break; + } + } + + // Dispatch the compute command + [encoder dispatchThreads:grid_size threadsPerThreadgroup:thread_group_size]; + [encoder endEncoding]; + + // Commit the work + torch::mps::synchronize(); + }); +} + std::tuple< torch::Tensor, // output conics torch::Tensor> // output radii @@ -193,41 +283,19 @@ void free_gsplat_metal_context(MetalContext* ctx) { } torch::Tensor colors = torch::empty({num_points, 3}, coeffs.options()); - // Get a reference to the command buffer for the MPS stream - id command_buffer = torch::mps::get_command_buffer(); - TORCH_CHECK(command_buffer, "Failed to retrieve command buffer reference"); - // Dispatch the kernel MetalContext* ctx = get_global_context(); - dispatch_sync(ctx->d_queue, ^(){ - // Start a compute pass - id encoder = [command_buffer computeCommandEncoder]; - TORCH_CHECK(encoder, "Failed to create compute command encoder"); - - // Encode the pipeline state object - id cpso = ctx->compute_sh_forward_kernel_cpso; - [encoder setComputePipelineState:cpso]; - - // Set the tensor buffers - ENC_SCALAR(encoder, num_points, 0); - ENC_SCALAR(encoder, degree, 1); - ENC_SCALAR(encoder, degrees_to_use, 2); - ENC_TENSOR(encoder, viewdirs, 3); - ENC_TENSOR(encoder, coeffs, 4); - ENC_TENSOR(encoder, colors, 5); - - // Set the grid threadgroup sizes - MTLSize grid_size = MTLSizeMake(num_points, 1, 1); - - NSUInteger num_threads_per_group = MIN(cpso.maxTotalThreadsPerThreadgroup, (NSUInteger)num_points); - MTLSize thread_group_size = MTLSizeMake(num_threads_per_group, 1, 1); - - // Dispatch the compute command - [encoder dispatchThreads:grid_size threadsPerThreadgroup:thread_group_size]; - [encoder endEncoding]; - - // Commit the work - torch::mps::synchronize(); + MTLSize grid_size = 
MTLSizeMake(num_points, 1, 1); + NSUInteger num_threads_per_group = + MIN(ctx->compute_sh_forward_kernel_cpso.maxTotalThreadsPerThreadgroup, (NSUInteger)num_points); + MTLSize thread_group_size = MTLSizeMake(num_threads_per_group, 1, 1); + dispatchKernel(ctx, ctx->compute_sh_forward_kernel_cpso, grid_size, thread_group_size, { + EncodeArg::scalar(num_points), + EncodeArg::scalar(degree), + EncodeArg::scalar(degrees_to_use), + EncodeArg::tensor(viewdirs), + EncodeArg::tensor(coeffs), + EncodeArg::tensor(colors) }); return colors; } @@ -251,41 +319,19 @@ void free_gsplat_metal_context(MetalContext* ctx) { torch::Tensor v_coeffs = torch::zeros({num_points, num_bases, 3}, v_colors.options()); - // Get a reference to the command buffer for the MPS stream - id command_buffer = torch::mps::get_command_buffer(); - TORCH_CHECK(command_buffer, "Failed to retrieve command buffer reference"); - // Dispatch the kernel MetalContext* ctx = get_global_context(); - dispatch_sync(ctx->d_queue, ^(){ - // Start a compute pass - id encoder = [command_buffer computeCommandEncoder]; - TORCH_CHECK(encoder, "Failed to create compute command encoder"); - - // Encode the pipeline state object - id cpso = ctx->compute_sh_backward_kernel_cpso; - [encoder setComputePipelineState:cpso]; - - // Set the tensor buffers - ENC_SCALAR(encoder, num_points, 0); - ENC_SCALAR(encoder, degree, 1); - ENC_SCALAR(encoder, degrees_to_use, 2); - ENC_TENSOR(encoder, viewdirs, 3); - ENC_TENSOR(encoder, v_colors, 4); - ENC_TENSOR(encoder, v_coeffs, 5); - - // Set the grid threadgroup sizes - MTLSize grid_size = MTLSizeMake(num_points, 1, 1); - - NSUInteger num_threads_per_group = MIN(cpso.maxTotalThreadsPerThreadgroup, (NSUInteger)num_points); - MTLSize thread_group_size = MTLSizeMake(num_threads_per_group, 1, 1); - - // Dispatch the compute command - [encoder dispatchThreads:grid_size threadsPerThreadgroup:thread_group_size]; - [encoder endEncoding]; - - // Commit the work - torch::mps::synchronize(); + MTLSize 
grid_size = MTLSizeMake(num_points, 1, 1); + NSUInteger num_threads_per_group = + MIN(ctx->compute_sh_backward_kernel_cpso.maxTotalThreadsPerThreadgroup, (NSUInteger)num_points); + MTLSize thread_group_size = MTLSizeMake(num_threads_per_group, 1, 1); + dispatchKernel(ctx, ctx->compute_sh_backward_kernel_cpso, grid_size, thread_group_size, { + EncodeArg::scalar(num_points), + EncodeArg::scalar(degree), + EncodeArg::scalar(degrees_to_use), + EncodeArg::tensor(viewdirs), + EncodeArg::tensor(v_colors), + EncodeArg::tensor(v_coeffs) }); return v_coeffs; @@ -329,53 +375,32 @@ void free_gsplat_metal_context(MetalContext* ctx) { torch::Tensor num_tiles_hit_d = torch::zeros({num_points}, means3d.options().dtype(torch::kInt32)); - // Get a reference to the command buffer for the MPS stream - id command_buffer = torch::mps::get_command_buffer(); - TORCH_CHECK(command_buffer, "Failed to retrieve command buffer reference"); + float intrins[4] = {fx, fy, cx, cy}; + uint32_t img_size[2] = {img_width, img_height}; // Dispatch the kernel MetalContext* ctx = get_global_context(); - dispatch_sync(ctx->d_queue, ^(){ - // Start a compute pass - id encoder = [command_buffer computeCommandEncoder]; - TORCH_CHECK(encoder, "Failed to create compute command encoder"); - - float intrins[4] = {fx, fy, cx, cy}; - uint32_t img_size[2] = {img_width, img_height}; - - // Encode the pipeline state object - id cpso = ctx->project_gaussians_forward_kernel_cpso; - [encoder setComputePipelineState:cpso]; - - // Set the tensor buffers - ENC_SCALAR(encoder, num_points, 0); - ENC_TENSOR(encoder, means3d, 1); - ENC_TENSOR(encoder, scales, 2); - ENC_SCALAR(encoder, glob_scale, 3); - ENC_TENSOR(encoder, quats, 4); - ENC_TENSOR(encoder, viewmat, 5); - ENC_TENSOR(encoder, projmat, 6); - ENC_ARRAY(encoder, intrins, 7); - ENC_ARRAY(encoder, img_size, 8); - ENC_SCALAR(encoder, clip_thresh, 9); - ENC_TENSOR(encoder, cov3d_d, 10); - ENC_TENSOR(encoder, xys_d, 11); - ENC_TENSOR(encoder, depths_d, 12); - 
ENC_TENSOR(encoder, radii_d, 13); - ENC_TENSOR(encoder, conics_d, 14); - ENC_TENSOR(encoder, num_tiles_hit_d, 15); - - // Set the grid threadgroup sizes - MTLSize grid_size = MTLSizeMake(num_points, 1, 1); - NSUInteger num_threads_per_group = MIN(cpso.maxTotalThreadsPerThreadgroup, (NSUInteger)num_points); - MTLSize thread_group_size = MTLSizeMake(num_threads_per_group, 1, 1); - - // Dispatch the compute command - [encoder dispatchThreads:grid_size threadsPerThreadgroup:thread_group_size]; - [encoder endEncoding]; - - // Commit the work - torch::mps::synchronize(); + MTLSize grid_size = MTLSizeMake(num_points, 1, 1); + NSUInteger num_threads_per_group = + MIN(ctx->project_gaussians_forward_kernel_cpso.maxTotalThreadsPerThreadgroup, (NSUInteger)num_points); + MTLSize thread_group_size = MTLSizeMake(num_threads_per_group, 1, 1); + dispatchKernel(ctx, ctx->project_gaussians_forward_kernel_cpso, grid_size, thread_group_size, { + EncodeArg::scalar(num_points), + EncodeArg::tensor(means3d), + EncodeArg::tensor(scales), + EncodeArg::scalar(glob_scale), + EncodeArg::tensor(quats), + EncodeArg::tensor(viewmat), + EncodeArg::tensor(projmat), + EncodeArg::array(intrins, sizeof(intrins)), + EncodeArg::array(img_size, sizeof(img_size)), + EncodeArg::scalar(clip_thresh), + EncodeArg::tensor(cov3d_d), + EncodeArg::tensor(xys_d), + EncodeArg::tensor(depths_d), + EncodeArg::tensor(radii_d), + EncodeArg::tensor(conics_d), + EncodeArg::tensor(num_tiles_hit_d) }); return std::make_tuple( @@ -422,57 +447,35 @@ void free_gsplat_metal_context(MetalContext* ctx) { torch::Tensor v_quat = torch::zeros({num_points, 4}, means3d.options().dtype(torch::kFloat32)); - // Get a reference to the command buffer for the MPS stream - id command_buffer = torch::mps::get_command_buffer(); - TORCH_CHECK(command_buffer, "Failed to retrieve command buffer reference"); + float intrins[4] = {fx, fy, cx, cy}; + uint32_t img_size[2] = {img_width, img_height}; - // Dispatch the kernel MetalContext* ctx = 
get_global_context(); - dispatch_sync(ctx->d_queue, ^(){ - // Start a compute pass - id encoder = [command_buffer computeCommandEncoder]; - TORCH_CHECK(encoder, "Failed to create compute command encoder"); - - float intrins[4] = {fx, fy, cx, cy}; - uint32_t img_size[2] = {img_width, img_height}; - - // Encode the pipeline state object - id cpso = ctx->project_gaussians_backward_kernel_cpso; - [encoder setComputePipelineState:cpso]; - - // Set the tensor buffers - ENC_SCALAR(encoder, num_points, 0); - ENC_TENSOR(encoder, means3d, 1); - ENC_TENSOR(encoder, scales, 2); - ENC_SCALAR(encoder, glob_scale, 3); - ENC_TENSOR(encoder, quats, 4); - ENC_TENSOR(encoder, viewmat, 5); - ENC_TENSOR(encoder, projmat, 6); - ENC_ARRAY(encoder, intrins, 7); - ENC_ARRAY(encoder, img_size, 8); - ENC_TENSOR(encoder, cov3d, 9); - ENC_TENSOR(encoder, radii, 10); - ENC_TENSOR(encoder, conics, 11); - ENC_TENSOR(encoder, v_xy, 12); - ENC_TENSOR(encoder, v_depth, 13); - ENC_TENSOR(encoder, v_conic, 14); - ENC_TENSOR(encoder, v_cov2d, 15); - ENC_TENSOR(encoder, v_cov3d, 16); - ENC_TENSOR(encoder, v_mean3d, 17); - ENC_TENSOR(encoder, v_scale, 18); - ENC_TENSOR(encoder, v_quat, 19); - - // Set the grid threadgroup sizes - MTLSize grid_size = MTLSizeMake(num_points, 1, 1); - NSUInteger num_threads_per_group = MIN(cpso.maxTotalThreadsPerThreadgroup, (NSUInteger)num_points); - MTLSize thread_group_size = MTLSizeMake(num_threads_per_group, 1, 1); - - // Dispatch the compute command - [encoder dispatchThreads:grid_size threadsPerThreadgroup:thread_group_size]; - [encoder endEncoding]; - - // Commit the work - torch::mps::synchronize(); + MTLSize grid_size = MTLSizeMake(num_points, 1, 1); + NSUInteger num_threads_per_group = + MIN(ctx->project_gaussians_backward_kernel_cpso.maxTotalThreadsPerThreadgroup, (NSUInteger)num_points); + MTLSize thread_group_size = MTLSizeMake(num_threads_per_group, 1, 1); + dispatchKernel(ctx, ctx->project_gaussians_backward_kernel_cpso, grid_size, thread_group_size, { + 
EncodeArg::scalar(num_points), + EncodeArg::tensor(means3d), + EncodeArg::tensor(scales), + EncodeArg::scalar(glob_scale), + EncodeArg::tensor(quats), + EncodeArg::tensor(viewmat), + EncodeArg::tensor(projmat), + EncodeArg::array(intrins, sizeof(intrins)), + EncodeArg::array(img_size, sizeof(img_size)), + EncodeArg::tensor(cov3d), + EncodeArg::tensor(radii), + EncodeArg::tensor(conics), + EncodeArg::tensor(v_xy), + EncodeArg::tensor(v_depth), + EncodeArg::tensor(v_conic), + EncodeArg::tensor(v_cov2d), + EncodeArg::tensor(v_cov3d), + EncodeArg::tensor(v_mean3d), + EncodeArg::tensor(v_scale), + EncodeArg::tensor(v_quat), }); return std::make_tuple(v_cov2d, v_cov3d, v_mean3d, v_scale, v_quat); @@ -497,42 +500,20 @@ void free_gsplat_metal_context(MetalContext* ctx) { torch::zeros({num_intersects}, xys.options().dtype(torch::kInt32)); torch::Tensor isect_ids_unsorted = torch::zeros({num_intersects}, xys.options().dtype(torch::kInt64)); - - // Get a reference to the command buffer for the MPS stream - id command_buffer = torch::mps::get_command_buffer(); - TORCH_CHECK(command_buffer, "Failed to retrieve command buffer reference"); - // Dispatch the kernel MetalContext* ctx = get_global_context(); - dispatch_sync(ctx->d_queue, ^(){ - // Start a compute pass - id encoder = [command_buffer computeCommandEncoder]; - TORCH_CHECK(encoder, "Failed to create compute command encoder"); - - // Encode the pipeline state object - id cpso = ctx->map_gaussian_to_intersects_kernel_cpso; - [encoder setComputePipelineState:cpso]; - - // Set the tensor buffers - ENC_SCALAR(encoder, num_points, 0); - ENC_TENSOR(encoder, xys, 1); - ENC_TENSOR(encoder, depths, 2); - ENC_TENSOR(encoder, radii, 3); - ENC_TENSOR(encoder, cum_tiles_hit, 4); - ENC_TENSOR(encoder, isect_ids_unsorted, 5); - ENC_TENSOR(encoder, gaussian_ids_unsorted, 6); - - // Set the grid threadgroup sizes - MTLSize grid_size = MTLSizeMake(num_points, 1, 1); - NSUInteger num_threads_per_group = 
MIN(cpso.maxTotalThreadsPerThreadgroup, (NSUInteger)num_points); - MTLSize thread_group_size = MTLSizeMake(num_threads_per_group, 1, 1); - - // Dispatch the compute command - [encoder dispatchThreads:grid_size threadsPerThreadgroup:thread_group_size]; - [encoder endEncoding]; - - // Commit the work - torch::mps::synchronize(); + MTLSize grid_size = MTLSizeMake(num_points, 1, 1); + NSUInteger num_threads_per_group = + MIN(ctx->map_gaussian_to_intersects_kernel_cpso.maxTotalThreadsPerThreadgroup, (NSUInteger)num_points); + MTLSize thread_group_size = MTLSizeMake(num_threads_per_group, 1, 1); + dispatchKernel(ctx, ctx->map_gaussian_to_intersects_kernel_cpso, grid_size, thread_group_size, { + EncodeArg::scalar(num_points), + EncodeArg::tensor(xys), + EncodeArg::tensor(depths), + EncodeArg::tensor(radii), + EncodeArg::tensor(cum_tiles_hit), + EncodeArg::tensor(isect_ids_unsorted), + EncodeArg::tensor(gaussian_ids_unsorted) }); return std::make_tuple(isect_ids_unsorted, gaussian_ids_unsorted); @@ -547,37 +528,15 @@ void free_gsplat_metal_context(MetalContext* ctx) { {num_intersects, 2}, isect_ids_sorted.options().dtype(torch::kInt32) ); - // Get a reference to the command buffer for the MPS stream - id command_buffer = torch::mps::get_command_buffer(); - TORCH_CHECK(command_buffer, "Failed to retrieve command buffer reference"); - - // Dispatch the kernel MetalContext* ctx = get_global_context(); - dispatch_sync(ctx->d_queue, ^(){ - // Start a compute pass - id encoder = [command_buffer computeCommandEncoder]; - TORCH_CHECK(encoder, "Failed to create compute command encoder"); - - // Encode the pipeline state object - id cpso = ctx->get_tile_bin_edges_kernel_cpso; - [encoder setComputePipelineState:cpso]; - - // Set the tensor buffers - ENC_SCALAR(encoder, num_intersects, 0); - ENC_TENSOR(encoder, isect_ids_sorted, 1); - ENC_TENSOR(encoder, tile_bins, 2); - - // Set the grid threadgroup sizes - MTLSize grid_size = MTLSizeMake(num_intersects, 1, 1); - NSUInteger 
num_threads_per_group = MIN(cpso.maxTotalThreadsPerThreadgroup, (NSUInteger)num_intersects); - MTLSize thread_group_size = MTLSizeMake(num_threads_per_group, 1, 1); - - // Dispatch the compute command - [encoder dispatchThreads:grid_size threadsPerThreadgroup:thread_group_size]; - [encoder endEncoding]; - - // Commit the work - torch::mps::synchronize(); + MTLSize grid_size = MTLSizeMake(num_intersects, 1, 1); + NSUInteger num_threads_per_group = + MIN(ctx->get_tile_bin_edges_kernel_cpso.maxTotalThreadsPerThreadgroup, (NSUInteger)num_intersects); + MTLSize thread_group_size = MTLSizeMake(num_threads_per_group, 1, 1); + dispatchKernel(ctx, ctx->get_tile_bin_edges_kernel_cpso, grid_size, thread_group_size, { + EncodeArg::scalar(num_intersects), + EncodeArg::tensor(isect_ids_sorted), + EncodeArg::tensor(tile_bins) }); return tile_bins; @@ -622,49 +581,26 @@ void free_gsplat_metal_context(MetalContext* ctx) { {img_height, img_width}, xys.options().dtype(torch::kInt32) ); - // Get a reference to the command buffer for the MPS stream - id command_buffer = torch::mps::get_command_buffer(); - TORCH_CHECK(command_buffer, "Failed to retrieve command buffer reference"); + uint32_t img_size_dim3[4] = {(uint32_t)std::get<0>(img_size), (uint32_t)std::get<1>(img_size), (uint32_t)std::get<2>(img_size), 0xDEAD}; + int32_t block_size_dim2[2] = {std::get<0>(block), std::get<1>(block)}; - // Dispatch the kernel MetalContext* ctx = get_global_context(); - dispatch_sync(ctx->d_queue, ^(){ - // Start a compute pass - id encoder = [command_buffer computeCommandEncoder]; - TORCH_CHECK(encoder, "Failed to create compute command encoder"); - - // Encode the pipeline state object - id cpso = ctx->rasterize_forward_kernel_cpso; - [encoder setComputePipelineState:cpso]; - - uint32_t img_size_dim3[4] = {(uint32_t)std::get<0>(img_size), (uint32_t)std::get<1>(img_size), (uint32_t)std::get<2>(img_size), 0xDEAD}; - int32_t block_size_dim2[2] = {std::get<0>(block), std::get<1>(block)}; - - // Set the 
tensor buffers - ENC_ARRAY(encoder, img_size_dim3, 0); - ENC_SCALAR(encoder, channels, 1); - ENC_TENSOR(encoder, gaussian_ids_sorted, 2); - ENC_TENSOR(encoder, tile_bins, 3); - ENC_TENSOR(encoder, xys, 4); - ENC_TENSOR(encoder, conics, 5); - ENC_TENSOR(encoder, colors, 6); - ENC_TENSOR(encoder, opacities, 7); - ENC_TENSOR(encoder, final_Ts, 8); - ENC_TENSOR(encoder, final_idx, 9); - ENC_TENSOR(encoder, out_img, 10); - ENC_TENSOR(encoder, background, 11); - ENC_ARRAY(encoder, block_size_dim2, 12); - - // Set the grid threadgroup sizes - MTLSize grid_size = MTLSizeMake(img_height, img_width, 1); - MTLSize thread_group_size = MTLSizeMake(block_size_dim2[0], block_size_dim2[1], 1); - - // Dispatch the compute command - [encoder dispatchThreads:grid_size threadsPerThreadgroup:thread_group_size]; - [encoder endEncoding]; - - // Commit the work - torch::mps::synchronize(); + MTLSize grid_size = MTLSizeMake(img_height, img_width, 1); + MTLSize thread_group_size = MTLSizeMake(block_size_dim2[0], block_size_dim2[1], 1); + dispatchKernel(ctx, ctx->rasterize_forward_kernel_cpso, grid_size, thread_group_size, { + EncodeArg::array(img_size_dim3, sizeof(img_size_dim3)), + EncodeArg::scalar(channels), + EncodeArg::tensor(gaussian_ids_sorted), + EncodeArg::tensor(tile_bins), + EncodeArg::tensor(xys), + EncodeArg::tensor(conics), + EncodeArg::tensor(colors), + EncodeArg::tensor(opacities), + EncodeArg::tensor(final_Ts), + EncodeArg::tensor(final_idx), + EncodeArg::tensor(out_img), + EncodeArg::tensor(background), + EncodeArg::array(block_size_dim2, sizeof(block_size_dim2)) }); return std::make_tuple(out_img, final_Ts, final_idx); @@ -789,48 +725,29 @@ void free_gsplat_metal_context(MetalContext* ctx) { id command_buffer = torch::mps::get_command_buffer(); TORCH_CHECK(command_buffer, "Failed to retrieve command buffer reference"); - // Dispatch the kernel - MetalContext* ctx = get_global_context(); - dispatch_sync(ctx->d_queue, ^(){ - // Start a compute pass - id encoder = 
[command_buffer computeCommandEncoder]; - TORCH_CHECK(encoder, "Failed to create compute command encoder"); + uint32_t img_size[2] = {img_height, img_width}; - // Encode the pipeline state object - id cpso = ctx->rasterize_backward_kernel_cpso; - [encoder setComputePipelineState:cpso]; - - uint32_t img_size[2] = {img_height, img_width}; - - // Set the tensor buffers - ENC_ARRAY(encoder, img_size, 0); - ENC_TENSOR(encoder, gaussians_ids_sorted, 1); - ENC_TENSOR(encoder, tile_bins, 2); - ENC_TENSOR(encoder, xys, 3); - ENC_TENSOR(encoder, conics, 4); - ENC_TENSOR(encoder, colors, 5); - ENC_TENSOR(encoder, opacities, 6); - ENC_TENSOR(encoder, background, 7); - ENC_TENSOR(encoder, final_Ts, 8); - ENC_TENSOR(encoder, final_idx, 9); - ENC_TENSOR(encoder, v_output, 10); - ENC_TENSOR(encoder, v_output_alpha, 11); - ENC_TENSOR(encoder, v_xy, 12); - ENC_TENSOR(encoder, v_conic, 13); - ENC_TENSOR(encoder, v_colors, 14); - ENC_TENSOR(encoder, v_opacity, 15); - ENC_TENSOR(encoder, debug, 16); - - // Set the grid threadgroup sizes - MTLSize grid_size = MTLSizeMake(img_height, img_width, 1); - MTLSize thread_group_size = MTLSizeMake(BLOCK_X, BLOCK_Y, 1); - - // Dispatch the compute command - [encoder dispatchThreads:grid_size threadsPerThreadgroup:thread_group_size]; - [encoder endEncoding]; - - // Commit the work - torch::mps::synchronize(); + MetalContext* ctx = get_global_context(); + MTLSize grid_size = MTLSizeMake(img_height, img_width, 1); + MTLSize thread_group_size = MTLSizeMake(BLOCK_X, BLOCK_Y, 1); + dispatchKernel(ctx, ctx->rasterize_backward_kernel_cpso, grid_size, thread_group_size, { + EncodeArg::array(img_size, sizeof(img_size)), + EncodeArg::tensor(gaussians_ids_sorted), + EncodeArg::tensor(tile_bins), + EncodeArg::tensor(xys), + EncodeArg::tensor(conics), + EncodeArg::tensor(colors), + EncodeArg::tensor(opacities), + EncodeArg::tensor(background), + EncodeArg::tensor(final_Ts), + EncodeArg::tensor(final_idx), + EncodeArg::tensor(v_output), + 
EncodeArg::tensor(v_output_alpha), + EncodeArg::tensor(v_xy), + EncodeArg::tensor(v_conic), + EncodeArg::tensor(v_colors), + EncodeArg::tensor(v_opacity), + EncodeArg::tensor(debug) }); return std::make_tuple(v_xy, v_conic, v_colors, v_opacity); From 8125246601c5e959fa6f5c3a83bedce43f3eb098 Mon Sep 17 00:00:00 2001 From: Andrew Chan Date: Thu, 11 Apr 2024 20:16:23 -0700 Subject: [PATCH 09/19] use simd groups --- vendor/gsplat-metal/gsplat_metal.metal | 114 ++++++++++++++----------- vendor/gsplat-metal/gsplat_metal.mm | 16 ++++ 2 files changed, 80 insertions(+), 50 deletions(-) diff --git a/vendor/gsplat-metal/gsplat_metal.metal b/vendor/gsplat-metal/gsplat_metal.metal index 9e03935..da74cc4 100644 --- a/vendor/gsplat-metal/gsplat_metal.metal +++ b/vendor/gsplat-metal/gsplat_metal.metal @@ -796,38 +796,45 @@ kernel void get_tile_bin_edges_kernel( } } -float block_reduce_sum(float val, uint tr, threadgroup float* shared) { - if (tr < BLOCK_SIZE) { - shared[tr] = val; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - for (uint s = 1; s < BLOCK_SIZE; s *= 2) { - if (tr % (2 * s) == 0 && tr + s < BLOCK_SIZE) { - shared[tr] += shared[tr + s]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - } +inline int warp_reduce_all_max(int val, const int warp_size) { + // This uses an xor so that all threads in a warp get the same result + for ( int mask = warp_size / 2; mask > 0; mask /= 2 ) + val = max(val, simd_shuffle_xor(val, mask)); - return shared[0]; + return val; } -float block_reduce_max(float val, uint tr, threadgroup float* shared) { - if (tr < BLOCK_SIZE) { - shared[tr] = val; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - for (uint s = 1; s < BLOCK_SIZE; s *= 2) { - if (tr % (2 * s) == 0 && tr + s < BLOCK_SIZE) { - shared[tr] = max(shared[tr + s], shared[tr]); - } - threadgroup_barrier(mem_flags::mem_threadgroup); - } +inline int warp_reduce_all_or(int val, const int warp_size) { + // This uses an xor so that all threads in a warp get 
the same result + for ( int mask = warp_size / 2; mask > 0; mask /= 2 ) + val = val | simd_shuffle_xor(val, mask); - return shared[0]; + return val; +} + +inline float warp_reduce_sum(float val, const int warp_size) { + for ( int offset = warp_size / 2; offset > 0; offset /= 2 ) + val += simd_shuffle_and_fill_down(val, 0., offset); + + return val; +} + +inline float3 warpSum3(float3 val, uint warp_size){ + val.x = warp_reduce_sum(val.x, warp_size); + val.y = warp_reduce_sum(val.y, warp_size); + val.z = warp_reduce_sum(val.z, warp_size); + return val; +} + +inline float2 warpSum2(float2 val, uint warp_size){ + val.x = warp_reduce_sum(val.x, warp_size); + val.y = warp_reduce_sum(val.y, warp_size); + return val; +} + +inline float warpSum(float val, uint warp_size){ + val = warp_reduce_sum(val, warp_size); + return val; } kernel void rasterize_backward_kernel( @@ -851,7 +858,9 @@ kernel void rasterize_backward_kernel( uint3 tile_bounds [[threadgroups_per_grid]], uint3 gp [[thread_position_in_grid]], uint3 block_index [[threadgroup_position_in_grid]], - uint tr [[thread_index_in_threadgroup]] + uint tr [[thread_index_in_threadgroup]], + uint warp_size [[threads_per_simdgroup]], + uint wr [[thread_index_in_simdgroup]] ) { int32_t tile_id = block_index.y * tile_bounds.x + block_index.x; @@ -891,9 +900,7 @@ kernel void rasterize_backward_kernel( // collect and process batches of gaussians // each thread loads one gaussian at a time before rasterizing - threadgroup float shared[BLOCK_SIZE]; - // TODO(achan): convert `block_reduce_max` to use SIMD groups - const int warp_bin_final = block_reduce_max(bin_final, tr, shared); + const int warp_bin_final = warp_reduce_all_max(bin_final, warp_size); for (int b = 0; b < num_batches; ++b) { // resync all threads before writing next batch of shared mem threadgroup_barrier(mem_flags::mem_threadgroup); @@ -943,7 +950,10 @@ kernel void rasterize_backward_kernel( valid = 0; } } - // TODO(achan): if all threads are inactive in this warp, 
skip this loop iter here + // if all threads are inactive in this warp, skip this loop + if (!warp_reduce_all_or(valid, warp_size)) { + continue; + } float3 v_rgb_local = {0.f, 0.f, 0.f}; float3 v_conic_local = {0.f, 0.f, 0.f}; @@ -984,21 +994,27 @@ kernel void rasterize_backward_kernel( v_opacity_local = vis * v_alpha; } - // TODO(achan): Use SIMD groups to reduce atomic contention similarly to warps - int32_t g = id_batch[t]; - - atomic_fetch_add_explicit(v_rgb + 3*g + 0, v_rgb_local.x, memory_order_relaxed); - atomic_fetch_add_explicit(v_rgb + 3*g + 1, v_rgb_local.y, memory_order_relaxed); - atomic_fetch_add_explicit(v_rgb + 3*g + 2, v_rgb_local.z, memory_order_relaxed); - - atomic_fetch_add_explicit(v_conic + 3*g + 0, v_conic_local.x, memory_order_relaxed); - atomic_fetch_add_explicit(v_conic + 3*g + 1, v_conic_local.y, memory_order_relaxed); - atomic_fetch_add_explicit(v_conic + 3*g + 2, v_conic_local.z, memory_order_relaxed); - - atomic_fetch_add_explicit(v_xy + 2*g + 0, v_xy_local.x, memory_order_relaxed); - atomic_fetch_add_explicit(v_xy + 2*g + 1, v_xy_local.y, memory_order_relaxed); - - atomic_fetch_add_explicit(v_opacity + g, v_opacity_local, memory_order_relaxed); + v_rgb_local = warpSum3(v_rgb_local, warp_size); + v_conic_local = warpSum3(v_conic_local, warp_size); + v_xy_local = warpSum2(v_xy_local, warp_size); + v_opacity_local = warpSum(v_opacity_local, warp_size); + + if (wr == 0) { + int32_t g = id_batch[t]; + + atomic_fetch_add_explicit(v_rgb + 3*g + 0, v_rgb_local.x, memory_order_relaxed); + atomic_fetch_add_explicit(v_rgb + 3*g + 1, v_rgb_local.y, memory_order_relaxed); + atomic_fetch_add_explicit(v_rgb + 3*g + 2, v_rgb_local.z, memory_order_relaxed); + + atomic_fetch_add_explicit(v_conic + 3*g + 0, v_conic_local.x, memory_order_relaxed); + atomic_fetch_add_explicit(v_conic + 3*g + 1, v_conic_local.y, memory_order_relaxed); + atomic_fetch_add_explicit(v_conic + 3*g + 2, v_conic_local.z, memory_order_relaxed); + + atomic_fetch_add_explicit(v_xy 
+ 2*g + 0, v_xy_local.x, memory_order_relaxed); + atomic_fetch_add_explicit(v_xy + 2*g + 1, v_xy_local.y, memory_order_relaxed); + + atomic_fetch_add_explicit(v_opacity + g, v_opacity_local, memory_order_relaxed); + } } } } @@ -1238,8 +1254,6 @@ kernel void project_gaussians_backward_kernel( float3 p_world = read_packed_float3(means3d, idx); float fx = intrins.x; float fy = intrins.y; - float cx = intrins.z; - float cy = intrins.w; // get v_mean3d from v_xy write_packed_float3( v_mean3d, idx, diff --git a/vendor/gsplat-metal/gsplat_metal.mm b/vendor/gsplat-metal/gsplat_metal.mm index 9f75a7b..5072f1a 100644 --- a/vendor/gsplat-metal/gsplat_metal.mm +++ b/vendor/gsplat-metal/gsplat_metal.mm @@ -289,6 +289,7 @@ void dispatchKernel(MetalContext* ctx, id cpso, MTLSize NSUInteger num_threads_per_group = MIN(ctx->compute_sh_forward_kernel_cpso.maxTotalThreadsPerThreadgroup, (NSUInteger)num_points); MTLSize thread_group_size = MTLSizeMake(num_threads_per_group, 1, 1); + printf("before dispatch for %s\n", __func__); dispatchKernel(ctx, ctx->compute_sh_forward_kernel_cpso, grid_size, thread_group_size, { EncodeArg::scalar(num_points), EncodeArg::scalar(degree), @@ -297,6 +298,7 @@ void dispatchKernel(MetalContext* ctx, id cpso, MTLSize EncodeArg::tensor(coeffs), EncodeArg::tensor(colors) }); + printf("after dispatch for %s\n", __func__); return colors; } @@ -325,6 +327,7 @@ void dispatchKernel(MetalContext* ctx, id cpso, MTLSize NSUInteger num_threads_per_group = MIN(ctx->compute_sh_backward_kernel_cpso.maxTotalThreadsPerThreadgroup, (NSUInteger)num_points); MTLSize thread_group_size = MTLSizeMake(num_threads_per_group, 1, 1); + printf("before dispatch for %s\n", __func__); dispatchKernel(ctx, ctx->compute_sh_backward_kernel_cpso, grid_size, thread_group_size, { EncodeArg::scalar(num_points), EncodeArg::scalar(degree), @@ -333,6 +336,7 @@ void dispatchKernel(MetalContext* ctx, id cpso, MTLSize EncodeArg::tensor(v_colors), EncodeArg::tensor(v_coeffs) }); + printf("after 
dispatch for %s\n", __func__); return v_coeffs; } @@ -384,6 +388,7 @@ void dispatchKernel(MetalContext* ctx, id cpso, MTLSize NSUInteger num_threads_per_group = MIN(ctx->project_gaussians_forward_kernel_cpso.maxTotalThreadsPerThreadgroup, (NSUInteger)num_points); MTLSize thread_group_size = MTLSizeMake(num_threads_per_group, 1, 1); + printf("before dispatch for %s\n", __func__); dispatchKernel(ctx, ctx->project_gaussians_forward_kernel_cpso, grid_size, thread_group_size, { EncodeArg::scalar(num_points), EncodeArg::tensor(means3d), @@ -402,6 +407,7 @@ void dispatchKernel(MetalContext* ctx, id cpso, MTLSize EncodeArg::tensor(conics_d), EncodeArg::tensor(num_tiles_hit_d) }); + printf("after dispatch for %s\n", __func__); return std::make_tuple( cov3d_d, xys_d, depths_d, radii_d, conics_d, num_tiles_hit_d @@ -455,6 +461,7 @@ void dispatchKernel(MetalContext* ctx, id cpso, MTLSize NSUInteger num_threads_per_group = MIN(ctx->project_gaussians_backward_kernel_cpso.maxTotalThreadsPerThreadgroup, (NSUInteger)num_points); MTLSize thread_group_size = MTLSizeMake(num_threads_per_group, 1, 1); + printf("before dispatch for %s\n", __func__); dispatchKernel(ctx, ctx->project_gaussians_backward_kernel_cpso, grid_size, thread_group_size, { EncodeArg::scalar(num_points), EncodeArg::tensor(means3d), @@ -477,6 +484,7 @@ void dispatchKernel(MetalContext* ctx, id cpso, MTLSize EncodeArg::tensor(v_scale), EncodeArg::tensor(v_quat), }); + printf("after dispatch for %s\n", __func__); return std::make_tuple(v_cov2d, v_cov3d, v_mean3d, v_scale, v_quat); } @@ -506,6 +514,7 @@ void dispatchKernel(MetalContext* ctx, id cpso, MTLSize NSUInteger num_threads_per_group = MIN(ctx->map_gaussian_to_intersects_kernel_cpso.maxTotalThreadsPerThreadgroup, (NSUInteger)num_points); MTLSize thread_group_size = MTLSizeMake(num_threads_per_group, 1, 1); + printf("before dispatch for %s\n", __func__); dispatchKernel(ctx, ctx->map_gaussian_to_intersects_kernel_cpso, grid_size, thread_group_size, { 
EncodeArg::scalar(num_points), EncodeArg::tensor(xys), @@ -515,6 +524,7 @@ void dispatchKernel(MetalContext* ctx, id cpso, MTLSize EncodeArg::tensor(isect_ids_unsorted), EncodeArg::tensor(gaussian_ids_unsorted) }); + printf("after dispatch for %s\n", __func__); return std::make_tuple(isect_ids_unsorted, gaussian_ids_unsorted); } @@ -533,11 +543,13 @@ void dispatchKernel(MetalContext* ctx, id cpso, MTLSize NSUInteger num_threads_per_group = MIN(ctx->get_tile_bin_edges_kernel_cpso.maxTotalThreadsPerThreadgroup, (NSUInteger)num_intersects); MTLSize thread_group_size = MTLSizeMake(num_threads_per_group, 1, 1); + printf("before dispatch for %s\n", __func__); dispatchKernel(ctx, ctx->get_tile_bin_edges_kernel_cpso, grid_size, thread_group_size, { EncodeArg::scalar(num_intersects), EncodeArg::tensor(isect_ids_sorted), EncodeArg::tensor(tile_bins) }); + printf("after dispatch for %s\n", __func__); return tile_bins; } @@ -587,6 +599,7 @@ void dispatchKernel(MetalContext* ctx, id cpso, MTLSize MetalContext* ctx = get_global_context(); MTLSize grid_size = MTLSizeMake(img_height, img_width, 1); MTLSize thread_group_size = MTLSizeMake(block_size_dim2[0], block_size_dim2[1], 1); + printf("before dispatch for %s\n", __func__); dispatchKernel(ctx, ctx->rasterize_forward_kernel_cpso, grid_size, thread_group_size, { EncodeArg::array(img_size_dim3, sizeof(img_size_dim3)), EncodeArg::scalar(channels), @@ -602,6 +615,7 @@ void dispatchKernel(MetalContext* ctx, id cpso, MTLSize EncodeArg::tensor(background), EncodeArg::array(block_size_dim2, sizeof(block_size_dim2)) }); + printf("after dispatch for %s\n", __func__); return std::make_tuple(out_img, final_Ts, final_idx); } @@ -730,6 +744,7 @@ void dispatchKernel(MetalContext* ctx, id cpso, MTLSize MetalContext* ctx = get_global_context(); MTLSize grid_size = MTLSizeMake(img_height, img_width, 1); MTLSize thread_group_size = MTLSizeMake(BLOCK_X, BLOCK_Y, 1); + printf("before dispatch for %s\n", __func__); dispatchKernel(ctx, 
ctx->rasterize_backward_kernel_cpso, grid_size, thread_group_size, { EncodeArg::array(img_size, sizeof(img_size)), EncodeArg::tensor(gaussians_ids_sorted), @@ -749,6 +764,7 @@ void dispatchKernel(MetalContext* ctx, id cpso, MTLSize EncodeArg::tensor(v_opacity), EncodeArg::tensor(debug) }); + printf("after dispatch for %s\n", __func__); return std::make_tuple(v_xy, v_conic, v_colors, v_opacity); } \ No newline at end of file From 1f541eeeeed4bd600317f2e7e9d04802706c250e Mon Sep 17 00:00:00 2001 From: Andrew Chan Date: Fri, 12 Apr 2024 00:28:46 -0700 Subject: [PATCH 10/19] fixes --- vendor/gsplat-metal/gsplat_metal.metal | 14 ++++---- vendor/gsplat-metal/gsplat_metal.mm | 45 +++++++++++++++++--------- 2 files changed, 36 insertions(+), 23 deletions(-) diff --git a/vendor/gsplat-metal/gsplat_metal.metal b/vendor/gsplat-metal/gsplat_metal.metal index da74cc4..de2a97a 100644 --- a/vendor/gsplat-metal/gsplat_metal.metal +++ b/vendor/gsplat-metal/gsplat_metal.metal @@ -368,6 +368,7 @@ kernel void project_gaussians_forward_kernel( constant float* projmat, constant float4& intrins, constant uint2& img_size, + constant uint3& tile_bounds, constant float& clip_thresh, device float* covs3d, device float* xys, // float2 @@ -375,7 +376,6 @@ kernel void project_gaussians_forward_kernel( device int* radii, device float* conics, // float3 device int32_t* num_tiles_hit, - uint3 tile_bounds [[threadgroups_per_grid]], uint3 gp [[thread_position_in_grid]] ) { uint idx = gp.x; @@ -385,7 +385,7 @@ kernel void project_gaussians_forward_kernel( radii[idx] = 0; num_tiles_hit[idx] = 0; - float3 p_world = means3d[idx*3]; + float3 p_world = read_packed_float3(means3d, idx); float3 p_view; if (clip_near_plane(p_world, viewmat, p_view, clip_thresh)) { return; @@ -433,6 +433,7 @@ kernel void project_gaussians_forward_kernel( // TODO(achan): this is actually the nd_rasterize_forward_kernel kernel void rasterize_forward_kernel( + constant uint3& tile_bounds, constant uint3& img_size, constant uint& 
channels, constant int32_t* gaussian_ids_sorted, @@ -446,7 +447,6 @@ kernel void rasterize_forward_kernel( device float* out_img, constant float* background, constant uint2& blockDim, - uint2 tile_bounds [[threadgroups_per_grid]], uint2 blockIdx [[threadgroup_position_in_grid]], uint2 threadIdx [[thread_position_in_threadgroup]] ) { @@ -735,9 +735,9 @@ kernel void map_gaussian_to_intersects_kernel( constant float* depths, constant int* radii, constant int32_t* cum_tiles_hit, + constant uint3& tile_bounds, device int64_t* isect_ids, device int32_t* gaussian_ids, - uint3 tile_bounds [[threadgroups_per_grid]], uint3 gp [[thread_position_in_grid]] ) { uint idx = gp.x; @@ -838,6 +838,7 @@ inline float warpSum(float val, uint warp_size){ } kernel void rasterize_backward_kernel( + constant uint3& tile_bounds, constant uint2& img_size, constant int32_t* gaussian_ids_sorted, constant int* tile_bins, // int2 @@ -855,15 +856,14 @@ kernel void rasterize_backward_kernel( device atomic_float* v_rgb, // float3 device atomic_float* v_opacity, device int32_t* debug, - uint3 tile_bounds [[threadgroups_per_grid]], uint3 gp [[thread_position_in_grid]], - uint3 block_index [[threadgroup_position_in_grid]], + uint3 blockIdx [[threadgroup_position_in_grid]], uint tr [[thread_index_in_threadgroup]], uint warp_size [[threads_per_simdgroup]], uint wr [[thread_index_in_simdgroup]] ) { int32_t tile_id = - block_index.y * tile_bounds.x + block_index.x; + blockIdx.y * tile_bounds.x + blockIdx.x; uint i = gp.y; uint j = gp.x; diff --git a/vendor/gsplat-metal/gsplat_metal.mm b/vendor/gsplat-metal/gsplat_metal.mm index 5072f1a..7c449dc 100644 --- a/vendor/gsplat-metal/gsplat_metal.mm +++ b/vendor/gsplat-metal/gsplat_metal.mm @@ -289,7 +289,6 @@ void dispatchKernel(MetalContext* ctx, id cpso, MTLSize NSUInteger num_threads_per_group = MIN(ctx->compute_sh_forward_kernel_cpso.maxTotalThreadsPerThreadgroup, (NSUInteger)num_points); MTLSize thread_group_size = MTLSizeMake(num_threads_per_group, 1, 1); 
- printf("before dispatch for %s\n", __func__); dispatchKernel(ctx, ctx->compute_sh_forward_kernel_cpso, grid_size, thread_group_size, { EncodeArg::scalar(num_points), EncodeArg::scalar(degree), @@ -298,7 +297,6 @@ void dispatchKernel(MetalContext* ctx, id cpso, MTLSize EncodeArg::tensor(coeffs), EncodeArg::tensor(colors) }); - printf("after dispatch for %s\n", __func__); return colors; } @@ -327,7 +325,6 @@ void dispatchKernel(MetalContext* ctx, id cpso, MTLSize NSUInteger num_threads_per_group = MIN(ctx->compute_sh_backward_kernel_cpso.maxTotalThreadsPerThreadgroup, (NSUInteger)num_points); MTLSize thread_group_size = MTLSizeMake(num_threads_per_group, 1, 1); - printf("before dispatch for %s\n", __func__); dispatchKernel(ctx, ctx->compute_sh_backward_kernel_cpso, grid_size, thread_group_size, { EncodeArg::scalar(num_points), EncodeArg::scalar(degree), @@ -336,7 +333,6 @@ void dispatchKernel(MetalContext* ctx, id cpso, MTLSize EncodeArg::tensor(v_colors), EncodeArg::tensor(v_coeffs) }); - printf("after dispatch for %s\n", __func__); return v_coeffs; } @@ -381,6 +377,12 @@ void dispatchKernel(MetalContext* ctx, id cpso, MTLSize float intrins[4] = {fx, fy, cx, cy}; uint32_t img_size[2] = {img_width, img_height}; + uint32_t tile_bounds_arr[4] = { + (uint32_t)std::get<0>(tile_bounds), + (uint32_t)std::get<1>(tile_bounds), + (uint32_t)std::get<2>(tile_bounds), + 0xDEAD + }; // Dispatch the kernel MetalContext* ctx = get_global_context(); @@ -388,7 +390,6 @@ void dispatchKernel(MetalContext* ctx, id cpso, MTLSize NSUInteger num_threads_per_group = MIN(ctx->project_gaussians_forward_kernel_cpso.maxTotalThreadsPerThreadgroup, (NSUInteger)num_points); MTLSize thread_group_size = MTLSizeMake(num_threads_per_group, 1, 1); - printf("before dispatch for %s\n", __func__); dispatchKernel(ctx, ctx->project_gaussians_forward_kernel_cpso, grid_size, thread_group_size, { EncodeArg::scalar(num_points), EncodeArg::tensor(means3d), @@ -399,6 +400,7 @@ void dispatchKernel(MetalContext* 
ctx, id cpso, MTLSize EncodeArg::tensor(projmat), EncodeArg::array(intrins, sizeof(intrins)), EncodeArg::array(img_size, sizeof(img_size)), + EncodeArg::array(tile_bounds_arr, sizeof(tile_bounds_arr)), EncodeArg::scalar(clip_thresh), EncodeArg::tensor(cov3d_d), EncodeArg::tensor(xys_d), @@ -407,7 +409,6 @@ void dispatchKernel(MetalContext* ctx, id cpso, MTLSize EncodeArg::tensor(conics_d), EncodeArg::tensor(num_tiles_hit_d) }); - printf("after dispatch for %s\n", __func__); return std::make_tuple( cov3d_d, xys_d, depths_d, radii_d, conics_d, num_tiles_hit_d @@ -461,7 +462,6 @@ void dispatchKernel(MetalContext* ctx, id cpso, MTLSize NSUInteger num_threads_per_group = MIN(ctx->project_gaussians_backward_kernel_cpso.maxTotalThreadsPerThreadgroup, (NSUInteger)num_points); MTLSize thread_group_size = MTLSizeMake(num_threads_per_group, 1, 1); - printf("before dispatch for %s\n", __func__); dispatchKernel(ctx, ctx->project_gaussians_backward_kernel_cpso, grid_size, thread_group_size, { EncodeArg::scalar(num_points), EncodeArg::tensor(means3d), @@ -484,7 +484,6 @@ void dispatchKernel(MetalContext* ctx, id cpso, MTLSize EncodeArg::tensor(v_scale), EncodeArg::tensor(v_quat), }); - printf("after dispatch for %s\n", __func__); return std::make_tuple(v_cov2d, v_cov3d, v_mean3d, v_scale, v_quat); } @@ -508,23 +507,29 @@ void dispatchKernel(MetalContext* ctx, id cpso, MTLSize torch::zeros({num_intersects}, xys.options().dtype(torch::kInt32)); torch::Tensor isect_ids_unsorted = torch::zeros({num_intersects}, xys.options().dtype(torch::kInt64)); + + uint32_t tile_bounds_arr[4] = { + (uint32_t)std::get<0>(tile_bounds), + (uint32_t)std::get<1>(tile_bounds), + (uint32_t)std::get<2>(tile_bounds), + 0xDEAD + }; MetalContext* ctx = get_global_context(); MTLSize grid_size = MTLSizeMake(num_points, 1, 1); NSUInteger num_threads_per_group = MIN(ctx->map_gaussian_to_intersects_kernel_cpso.maxTotalThreadsPerThreadgroup, (NSUInteger)num_points); MTLSize thread_group_size = 
MTLSizeMake(num_threads_per_group, 1, 1); - printf("before dispatch for %s\n", __func__); dispatchKernel(ctx, ctx->map_gaussian_to_intersects_kernel_cpso, grid_size, thread_group_size, { EncodeArg::scalar(num_points), EncodeArg::tensor(xys), EncodeArg::tensor(depths), EncodeArg::tensor(radii), EncodeArg::tensor(cum_tiles_hit), + EncodeArg::array(tile_bounds_arr, sizeof(tile_bounds_arr)), EncodeArg::tensor(isect_ids_unsorted), EncodeArg::tensor(gaussian_ids_unsorted) }); - printf("after dispatch for %s\n", __func__); return std::make_tuple(isect_ids_unsorted, gaussian_ids_unsorted); } @@ -543,13 +548,11 @@ void dispatchKernel(MetalContext* ctx, id cpso, MTLSize NSUInteger num_threads_per_group = MIN(ctx->get_tile_bin_edges_kernel_cpso.maxTotalThreadsPerThreadgroup, (NSUInteger)num_intersects); MTLSize thread_group_size = MTLSizeMake(num_threads_per_group, 1, 1); - printf("before dispatch for %s\n", __func__); dispatchKernel(ctx, ctx->get_tile_bin_edges_kernel_cpso, grid_size, thread_group_size, { EncodeArg::scalar(num_intersects), EncodeArg::tensor(isect_ids_sorted), EncodeArg::tensor(tile_bins) }); - printf("after dispatch for %s\n", __func__); return tile_bins; } @@ -594,13 +597,19 @@ void dispatchKernel(MetalContext* ctx, id cpso, MTLSize ); uint32_t img_size_dim3[4] = {(uint32_t)std::get<0>(img_size), (uint32_t)std::get<1>(img_size), (uint32_t)std::get<2>(img_size), 0xDEAD}; + uint32_t tile_bounds_arr[4] = { + (uint32_t)std::get<0>(tile_bounds), + (uint32_t)std::get<1>(tile_bounds), + (uint32_t)std::get<2>(tile_bounds), + 0xDEAD + }; int32_t block_size_dim2[2] = {std::get<0>(block), std::get<1>(block)}; MetalContext* ctx = get_global_context(); MTLSize grid_size = MTLSizeMake(img_height, img_width, 1); MTLSize thread_group_size = MTLSizeMake(block_size_dim2[0], block_size_dim2[1], 1); - printf("before dispatch for %s\n", __func__); dispatchKernel(ctx, ctx->rasterize_forward_kernel_cpso, grid_size, thread_group_size, { + EncodeArg::array(tile_bounds_arr, 
sizeof(tile_bounds_arr)), EncodeArg::array(img_size_dim3, sizeof(img_size_dim3)), EncodeArg::scalar(channels), EncodeArg::tensor(gaussian_ids_sorted), @@ -615,7 +624,6 @@ void dispatchKernel(MetalContext* ctx, id cpso, MTLSize EncodeArg::tensor(background), EncodeArg::array(block_size_dim2, sizeof(block_size_dim2)) }); - printf("after dispatch for %s\n", __func__); return std::make_tuple(out_img, final_Ts, final_idx); } @@ -740,12 +748,18 @@ void dispatchKernel(MetalContext* ctx, id cpso, MTLSize TORCH_CHECK(command_buffer, "Failed to retrieve command buffer reference"); uint32_t img_size[2] = {img_height, img_width}; + uint32_t tile_bounds_arr[4] = { + (img_width + BLOCK_X - 1) / BLOCK_X, + (img_height + BLOCK_Y - 1) / BLOCK_Y, + 1, + 0xDEAD + }; MetalContext* ctx = get_global_context(); MTLSize grid_size = MTLSizeMake(img_height, img_width, 1); MTLSize thread_group_size = MTLSizeMake(BLOCK_X, BLOCK_Y, 1); - printf("before dispatch for %s\n", __func__); dispatchKernel(ctx, ctx->rasterize_backward_kernel_cpso, grid_size, thread_group_size, { + EncodeArg::array(tile_bounds_arr, sizeof(tile_bounds_arr)), EncodeArg::array(img_size, sizeof(img_size)), EncodeArg::tensor(gaussians_ids_sorted), EncodeArg::tensor(tile_bins), @@ -764,7 +778,6 @@ void dispatchKernel(MetalContext* ctx, id cpso, MTLSize EncodeArg::tensor(v_opacity), EncodeArg::tensor(debug) }); - printf("after dispatch for %s\n", __func__); return std::make_tuple(v_xy, v_conic, v_colors, v_opacity); } \ No newline at end of file From 06d38c7e227654b4dfbb2848dc08da21738672e7 Mon Sep 17 00:00:00 2001 From: Andrew Chan Date: Sat, 13 Apr 2024 18:13:56 -0700 Subject: [PATCH 11/19] fixes and cleanup --- CMakeLists.txt | 5 +---- vendor/gsplat-metal/gsplat_metal.metal | 1 - vendor/gsplat-metal/gsplat_metal.mm | 18 ++++++++++++------ 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 0e05891..4e9f7e8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -138,10 +138,7 @@ 
elseif(GPU_RUNTIME STREQUAL "MPS") ${METAL_FRAMEWORK} ${METALKIT_FRAMEWORK} ) - target_include_directories(gsplat PRIVATE - ${PROJECT_SOURCE_DIR}/vendor/glm - ${TORCH_INCLUDE_DIRS} - ) + target_include_directories(gsplat PRIVATE ${TORCH_INCLUDE_DIRS}) # copy shader files to bin directory configure_file(vendor/gsplat-metal/gsplat_metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/gsplat_metal.metal COPYONLY) add_custom_command( diff --git a/vendor/gsplat-metal/gsplat_metal.metal b/vendor/gsplat-metal/gsplat_metal.metal index de2a97a..a751bd6 100644 --- a/vendor/gsplat-metal/gsplat_metal.metal +++ b/vendor/gsplat-metal/gsplat_metal.metal @@ -855,7 +855,6 @@ kernel void rasterize_backward_kernel( device atomic_float* v_conic, // float3 device atomic_float* v_rgb, // float3 device atomic_float* v_opacity, - device int32_t* debug, uint3 gp [[thread_position_in_grid]], uint3 blockIdx [[threadgroup_position_in_grid]], uint tr [[thread_index_in_threadgroup]], diff --git a/vendor/gsplat-metal/gsplat_metal.mm b/vendor/gsplat-metal/gsplat_metal.mm index 7c449dc..5490851 100644 --- a/vendor/gsplat-metal/gsplat_metal.mm +++ b/vendor/gsplat-metal/gsplat_metal.mm @@ -729,8 +729,17 @@ void dispatchKernel(MetalContext* ctx, id cpso, MTLSize const torch::Tensor &v_output, // dL_dout_color const torch::Tensor &v_output_alpha ) { + CHECK_INPUT(gaussians_ids_sorted); + CHECK_INPUT(tile_bins); CHECK_INPUT(xys); + CHECK_INPUT(conics); CHECK_INPUT(colors); + CHECK_INPUT(opacities); + CHECK_INPUT(background); + CHECK_INPUT(final_Ts); + CHECK_INPUT(final_idx); + CHECK_INPUT(v_output); + CHECK_INPUT(v_output_alpha); const int num_points = xys.size(0); const int channels = colors.size(1); @@ -741,13 +750,11 @@ void dispatchKernel(MetalContext* ctx, id cpso, MTLSize torch::zeros({num_points, channels}, xys.options()); torch::Tensor v_opacity = torch::zeros({num_points, 1}, xys.options()); - torch::Tensor debug = torch::zeros({1}, xys.options().dtype(torch::kInt32)); - // Get a reference to the 
command buffer for the MPS stream id command_buffer = torch::mps::get_command_buffer(); TORCH_CHECK(command_buffer, "Failed to retrieve command buffer reference"); - uint32_t img_size[2] = {img_height, img_width}; + uint32_t img_size[2] = {img_width, img_height}; uint32_t tile_bounds_arr[4] = { (img_width + BLOCK_X - 1) / BLOCK_X, (img_height + BLOCK_Y - 1) / BLOCK_Y, @@ -756,7 +763,7 @@ void dispatchKernel(MetalContext* ctx, id cpso, MTLSize }; MetalContext* ctx = get_global_context(); - MTLSize grid_size = MTLSizeMake(img_height, img_width, 1); + MTLSize grid_size = MTLSizeMake(img_width, img_height, 1); MTLSize thread_group_size = MTLSizeMake(BLOCK_X, BLOCK_Y, 1); dispatchKernel(ctx, ctx->rasterize_backward_kernel_cpso, grid_size, thread_group_size, { EncodeArg::array(tile_bounds_arr, sizeof(tile_bounds_arr)), @@ -775,8 +782,7 @@ void dispatchKernel(MetalContext* ctx, id cpso, MTLSize EncodeArg::tensor(v_xy), EncodeArg::tensor(v_conic), EncodeArg::tensor(v_colors), - EncodeArg::tensor(v_opacity), - EncodeArg::tensor(debug) + EncodeArg::tensor(v_opacity) }); return std::make_tuple(v_xy, v_conic, v_colors, v_opacity); From 6ce807366196ae7d8a52768eefc080632db86a60 Mon Sep 17 00:00:00 2001 From: Andrew Chan Date: Sat, 13 Apr 2024 18:58:23 -0700 Subject: [PATCH 12/19] remaining functions --- vendor/gsplat-metal/gsplat_metal.metal | 137 +++++++++++++++++++++++++ vendor/gsplat-metal/gsplat_metal.mm | 104 +++++++++++++++++-- 2 files changed, 233 insertions(+), 8 deletions(-) diff --git a/vendor/gsplat-metal/gsplat_metal.metal b/vendor/gsplat-metal/gsplat_metal.metal index a751bd6..e2c3f70 100644 --- a/vendor/gsplat-metal/gsplat_metal.metal +++ b/vendor/gsplat-metal/gsplat_metal.metal @@ -6,6 +6,7 @@ using namespace metal; #define BLOCK_Y 16 #define BLOCK_SIZE (BLOCK_X * BLOCK_Y) #define CHANNELS 3 +#define MAX_REGISTER_CHANNELS 3 constant float SH_C0 = 0.28209479177387814f; constant float SH_C1 = 0.4886025119029199f; @@ -1018,6 +1019,119 @@ kernel void 
rasterize_backward_kernel( } } +kernel void nd_rasterize_backward_kernel( + constant uint3& tile_bounds, + constant uint3& img_size, + constant uint& channels, + constant int32_t* gaussians_ids_sorted, + constant int* tile_bins, // int2 + constant float* xys, // float2 + constant float* conics, // float3 + constant float* rgbs, + constant float* opacities, + constant float* background, + constant float* final_Ts, + constant int* final_index, + constant float* v_output, + constant float* v_output_alpha, + device atomic_float* v_xy, // float2 + device atomic_float* v_conic, // float3 + device atomic_float* v_rgb, + device atomic_float* v_opacity, + device float* workspace, + uint3 blockIdx [[threadgroup_position_in_grid]], + uint3 blockDim [[threads_per_threadgroup]], + uint3 threadIdx [[thread_position_in_threadgroup]] +) { + if (channels > MAX_REGISTER_CHANNELS && workspace == nullptr) { + return; + } + // current naive implementation where tile data loading is redundant + // TODO tile data should be shared between tile threads + int32_t tile_id = blockIdx.y * tile_bounds.x + blockIdx.x; + uint i = blockIdx.y * blockDim.y + threadIdx.y; + uint j = blockIdx.x * blockDim.x + threadIdx.x; + float px = (float)j; + float py = (float)i; + int32_t pix_id = i * img_size.x + j; + + // return if out of bounds + if (i >= img_size.y || j >= img_size.x) { + return; + } + + // which gaussians get gradients for this pixel + int2 range = read_packed_int2(tile_bins, tile_id); + // df/d_out for this pixel + constant float *v_out = &(v_output[channels * pix_id]); + const float v_out_alpha = v_output_alpha[pix_id]; + // this is the T AFTER the last gaussian in this pixel + float T_final = final_Ts[pix_id]; + float T = T_final; + // the contribution from gaussians behind the current one + device float *S = &workspace[channels * pix_id]; + int bin_final = final_index[pix_id]; + + // iterate backward to compute the jacobians wrt rgb, opacity, mean2d, and + // conic recursively compute 
T_{n-1} from T_n, where T_i = prod(j < i) (1 - + // alpha_j), and S_{n-1} from S_n, where S_j = sum_{i > j}(rgb_i * alpha_i * + // T_i) df/dalpha_i = rgb_i * T_i - S_{i+1| / (1 - alpha_i) + for (int idx = bin_final - 1; idx >= range.x; --idx) { + const int32_t g = gaussians_ids_sorted[idx]; + const float3 conic = read_packed_float3(conics, g); + const float2 center = read_packed_float2(xys, g); + const float2 delta = {center.x - px, center.y - py}; + const float sigma = + 0.5f * (conic.x * delta.x * delta.x + conic.z * delta.y * delta.y) + + conic.y * delta.x * delta.y; + if (sigma < 0.f) { + continue; + } + const float opac = opacities[g]; + const float vis = exp(-sigma); + const float alpha = min(0.99f, opac * vis); + if (alpha < 1.f / 255.f) { + continue; + } + + // compute the current T for this gaussian + const float ra = 1.f / (1.f - alpha); + T *= ra; + // rgb = rgbs[g]; + // update v_rgb for this gaussian + const float fac = alpha * T; + float v_alpha = 0.f; + for (int c = 0; c < channels; ++c) { + // gradient wrt rgb + atomic_fetch_add_explicit(v_rgb + channels * g + c, fac * v_out[c], memory_order_relaxed); + // contribution from this pixel + v_alpha += (rgbs[channels * g + c] * T - S[c] * ra) * v_out[c]; + // contribution from background pixel + v_alpha += -T_final * ra * background[c] * v_out[c]; + // update the running sum + S[c] += rgbs[channels * g + c] * fac; + } + v_alpha += T_final * ra * v_out_alpha; + // update v_opacity for this gaussian + atomic_fetch_add_explicit(v_opacity + g, vis * v_alpha, memory_order_relaxed); + + // compute vjps for conics and means + // d_sigma / d_delta = conic * delta + // d_sigma / d_conic = delta * delta.T + const float v_sigma = -opac * vis * v_alpha; + + atomic_fetch_add_explicit(v_conic + 3*g + 0, 0.5f * v_sigma * delta.x * delta.x, memory_order_relaxed); + atomic_fetch_add_explicit(v_conic + 3*g + 1, 0.5f * v_sigma * delta.x * delta.y, memory_order_relaxed); + atomic_fetch_add_explicit(v_conic + 3*g + 2, 0.5f * 
v_sigma * delta.y * delta.y, memory_order_relaxed);
+        atomic_fetch_add_explicit(
+            v_xy + 2*g + 0, v_sigma * (conic.x * delta.x + conic.y * delta.y), memory_order_relaxed
+        );
+        atomic_fetch_add_explicit(
+            v_xy + 2*g + 1, v_sigma * (conic.y * delta.x + conic.z * delta.y), memory_order_relaxed
+        );
+    }
+}
+
 // given v_xy_pix, get v_xyz
 inline float3 project_pix_vjp(
     constant float *mat, const float3 p, const uint2 img_size, const float2 v_xy
@@ -1294,4 +1408,27 @@ kernel void project_gaussians_backward_kernel(
         &(v_scale[3*idx]),
         &(v_quat[4*idx])
     );
+}
+
+kernel void compute_cov2d_bounds_kernel( // one thread per point: reads a packed 3-float 2D covariance, runs compute_cov2d_bounds on it, stores the resulting conic and radius
+    constant uint& num_pts, // number of points; threads with row >= num_pts exit immediately
+    constant float* covs2d, // input, read as 3 consecutive floats per point
+    device float* conics, // output, written as 3 consecutive floats per point
+    device float* radii, // output, 1 float per point
+    uint row [[thread_position_in_grid]] // global index: the host dispatches a num_pts-wide grid with the threadgroup capped at maxTotalThreadsPerThreadgroup, so a threadgroup-local index would skip every row past the first group
+) {
+    if (row >= num_pts) { // guard threads past the end of the data
+        return;
+    }
+    int index = row * 3; // offset of this point's 3-float entry in covs2d / conics
+    float3 conic;
+    float radius;
+    float3 cov2d{
+        (float)covs2d[index], (float)covs2d[index + 1], (float)covs2d[index + 2]
+    };
+    compute_cov2d_bounds(cov2d, conic, radius); // fills conic and radius (helper defined elsewhere in this shader file)
+    conics[index] = conic.x;
+    conics[index + 1] = conic.y;
+    conics[index + 2] = conic.z;
+    radii[row] = radius;
 }
\ No newline at end of file
diff --git a/vendor/gsplat-metal/gsplat_metal.mm b/vendor/gsplat-metal/gsplat_metal.mm
index 5490851..7b59c96 100644
--- a/vendor/gsplat-metal/gsplat_metal.mm
+++ b/vendor/gsplat-metal/gsplat_metal.mm
@@ -10,7 +10,6 @@
     id queue;
     dispatch_queue_t d_queue;
 
-    id nd_rasterize_forward_kernel_cpso;
     id nd_rasterize_backward_kernel_cpso;
     id rasterize_forward_kernel_cpso;
     id rasterize_backward_kernel_cpso;
@@ -23,8 +22,6 @@
     id get_tile_bin_edges_kernel_cpso;
 };
 
-// This function is used in both host and device code
-// TODO(achan): Do I need to make this callable from the metal device?
unsigned num_sh_bases(const unsigned degree) { if (degree == 0) return 1; @@ -110,15 +107,14 @@ @implementation DummyClassForPathHack } \ } - // GSPLAT_METAL_ADD_KERNEL(nd_rasterize_forward_kernel); - // GSPLAT_METAL_ADD_KERNEL(nd_rasterize_backward_kernel); + GSPLAT_METAL_ADD_KERNEL(nd_rasterize_backward_kernel); GSPLAT_METAL_ADD_KERNEL(rasterize_forward_kernel); GSPLAT_METAL_ADD_KERNEL(rasterize_backward_kernel); GSPLAT_METAL_ADD_KERNEL(project_gaussians_forward_kernel); GSPLAT_METAL_ADD_KERNEL(project_gaussians_backward_kernel); GSPLAT_METAL_ADD_KERNEL(compute_sh_forward_kernel); GSPLAT_METAL_ADD_KERNEL(compute_sh_backward_kernel); - // GSPLAT_METAL_ADD_KERNEL(compute_cov2d_bounds_kernel); + GSPLAT_METAL_ADD_KERNEL(compute_cov2d_bounds_kernel); GSPLAT_METAL_ADD_KERNEL(map_gaussian_to_intersects_kernel); GSPLAT_METAL_ADD_KERNEL(get_tile_bin_edges_kernel); @@ -129,7 +125,6 @@ @implementation DummyClassForPathHack // TODO(achan): Where do I call this? void free_gsplat_metal_context(MetalContext* ctx) { - [ctx->nd_rasterize_forward_kernel_cpso release]; [ctx->nd_rasterize_backward_kernel_cpso release]; [ctx->rasterize_forward_kernel_cpso release]; [ctx->rasterize_backward_kernel_cpso release]; @@ -265,6 +260,19 @@ void dispatchKernel(MetalContext* ctx, id cpso, MTLSize ); torch::Tensor radii = torch::zeros({num_pts, 1}, covs2d.options().dtype(torch::kFloat32)); + + // Dispatch the kernel + MetalContext* ctx = get_global_context(); + MTLSize grid_size = MTLSizeMake(num_pts, 1, 1); + NSUInteger num_threads_per_group = + MIN(ctx->compute_cov2d_bounds_kernel_cpso.maxTotalThreadsPerThreadgroup, (NSUInteger)num_pts); + MTLSize thread_group_size = MTLSizeMake(num_threads_per_group, 1, 1); + dispatchKernel(ctx, ctx->compute_cov2d_bounds_kernel_cpso, grid_size, thread_group_size, { + EncodeArg::scalar(num_pts), + EncodeArg::tensor(covs2d), + EncodeArg::tensor(conics), + EncodeArg::tensor(radii) + }); return std::make_tuple(conics, radii); } @@ -634,6 +642,7 @@ void 
dispatchKernel(MetalContext* ctx, id cpso, MTLSize torch::Tensor > nd_rasterize_forward_tensor( const std::tuple tile_bounds, + // TODO(achan): we should be able to remove the 3rd dimension of `block` as it is always set to 1 const std::tuple block, const std::tuple img_size, const torch::Tensor &gaussian_ids_sorted, @@ -652,7 +661,7 @@ void dispatchKernel(MetalContext* ctx, id cpso, MTLSize CHECK_INPUT(opacities); CHECK_INPUT(background); - const int channels = colors.size(1); + const uint32_t channels = colors.size(1); const int img_width = std::get<0>(img_size); const int img_height = std::get<1>(img_size); @@ -666,6 +675,35 @@ void dispatchKernel(MetalContext* ctx, id cpso, MTLSize {img_height, img_width}, xys.options().dtype(torch::kInt32) ); + uint32_t img_size_dim3[4] = {(uint32_t)std::get<0>(img_size), (uint32_t)std::get<1>(img_size), (uint32_t)std::get<2>(img_size), 0xDEAD}; + uint32_t tile_bounds_arr[4] = { + (uint32_t)std::get<0>(tile_bounds), + (uint32_t)std::get<1>(tile_bounds), + (uint32_t)std::get<2>(tile_bounds), + 0xDEAD + }; + int32_t block_size_dim2[2] = {std::get<0>(block), std::get<1>(block)}; + + MetalContext* ctx = get_global_context(); + MTLSize grid_size = MTLSizeMake(img_height, img_width, 1); + MTLSize thread_group_size = MTLSizeMake(block_size_dim2[0], block_size_dim2[1], 1); + dispatchKernel(ctx, ctx->rasterize_forward_kernel_cpso, grid_size, thread_group_size, { + EncodeArg::array(tile_bounds_arr, sizeof(tile_bounds_arr)), + EncodeArg::array(img_size_dim3, sizeof(img_size_dim3)), + EncodeArg::scalar(channels), + EncodeArg::tensor(gaussian_ids_sorted), + EncodeArg::tensor(tile_bins), + EncodeArg::tensor(xys), + EncodeArg::tensor(conics), + EncodeArg::tensor(colors), + EncodeArg::tensor(opacities), + EncodeArg::tensor(final_Ts), + EncodeArg::tensor(final_idx), + EncodeArg::tensor(out_img), + EncodeArg::tensor(background), + EncodeArg::array(block_size_dim2, sizeof(block_size_dim2)) + }); + return std::make_tuple(out_img, final_Ts, 
final_idx); } @@ -692,8 +730,17 @@ void dispatchKernel(MetalContext* ctx, id cpso, MTLSize const torch::Tensor &v_output, // dL_dout_color const torch::Tensor &v_output_alpha ) { + CHECK_INPUT(gaussians_ids_sorted); + CHECK_INPUT(tile_bins); CHECK_INPUT(xys); + CHECK_INPUT(conics); CHECK_INPUT(colors); + CHECK_INPUT(opacities); + CHECK_INPUT(background); + CHECK_INPUT(final_Ts); + CHECK_INPUT(final_idx); + CHECK_INPUT(v_output); + CHECK_INPUT(v_output_alpha); const int num_points = xys.size(0); const int channels = colors.size(1); @@ -703,6 +750,47 @@ void dispatchKernel(MetalContext* ctx, id cpso, MTLSize torch::Tensor v_colors = torch::zeros({num_points, channels}, xys.options()); torch::Tensor v_opacity = torch::zeros({num_points, 1}, xys.options()); + torch::Tensor workspace = torch::zeros( + {img_height, img_width, channels}, + xys.options().dtype(torch::kFloat32) + ); + + // Get a reference to the command buffer for the MPS stream + id command_buffer = torch::mps::get_command_buffer(); + TORCH_CHECK(command_buffer, "Failed to retrieve command buffer reference"); + + uint32_t img_size[2] = {img_width, img_height}; + uint32_t tile_bounds_arr[4] = { + (img_width + BLOCK_X - 1) / BLOCK_X, + (img_height + BLOCK_Y - 1) / BLOCK_Y, + 1, + 0xDEAD + }; + + MetalContext* ctx = get_global_context(); + MTLSize grid_size = MTLSizeMake(img_width, img_height, 1); + MTLSize thread_group_size = MTLSizeMake(BLOCK_X, BLOCK_Y, 1); + dispatchKernel(ctx, ctx->nd_rasterize_backward_kernel_cpso, grid_size, thread_group_size, { + EncodeArg::array(tile_bounds_arr, sizeof(tile_bounds_arr)), + EncodeArg::array(img_size, sizeof(img_size)), + EncodeArg::scalar(channels), + EncodeArg::tensor(gaussians_ids_sorted), + EncodeArg::tensor(tile_bins), + EncodeArg::tensor(xys), + EncodeArg::tensor(conics), + EncodeArg::tensor(colors), + EncodeArg::tensor(opacities), + EncodeArg::tensor(background), + EncodeArg::tensor(final_Ts), + EncodeArg::tensor(final_idx), + EncodeArg::tensor(v_output), + 
EncodeArg::tensor(v_output_alpha), + EncodeArg::tensor(v_xy), + EncodeArg::tensor(v_conic), + EncodeArg::tensor(v_colors), + EncodeArg::tensor(v_opacity), + EncodeArg::tensor(workspace) + }); return std::make_tuple(v_xy, v_conic, v_colors, v_opacity); } From f258ee41be0ee3ade78e231b365435e6a5d89c74 Mon Sep 17 00:00:00 2001 From: Andrew Chan Date: Sat, 13 Apr 2024 19:02:26 -0700 Subject: [PATCH 13/19] cleanup --- rasterize_gaussians.cpp | 34 +++++++++++++------------- vendor/gsplat-metal/gsplat_metal.metal | 3 +-- vendor/gsplat-metal/gsplat_metal.mm | 28 +++------------------ 3 files changed, 22 insertions(+), 43 deletions(-) diff --git a/rasterize_gaussians.cpp b/rasterize_gaussians.cpp index 875134b..d6a9d40 100644 --- a/rasterize_gaussians.cpp +++ b/rasterize_gaussians.cpp @@ -160,13 +160,13 @@ torch::Tensor RasterizeGaussiansCPU::forward(AutogradContext *ctx, torch::Device device = xys.device(); auto t = rasterize_forward_tensor_cpu(imgWidth, imgHeight, - xys.to(torch::kCPU), - conics.to(torch::kCPU), - colors.to(torch::kCPU), - opacity.to(torch::kCPU), - background.to(torch::kCPU), - cov2d.to(torch::kCPU), - camDepths.to(torch::kCPU) + xys, + conics, + colors, + opacity, + background, + cov2d, + camDepths ); // Final image torch::Tensor outImg = std::get<0>(t).to(device); @@ -200,17 +200,17 @@ tensor_list RasterizeGaussiansCPU::backward(AutogradContext *ctx, tensor_list gr torch::Device device = xys.device(); auto t = rasterize_backward_tensor_cpu(imgHeight, imgWidth, - xys.to(torch::kCPU), - conics.to(torch::kCPU), - colors.to(torch::kCPU), - opacity.to(torch::kCPU), - background.to(torch::kCPU), - cov2d.to(torch::kCPU), - camDepths.to(torch::kCPU), - finalTs.to(torch::kCPU), + xys, + conics, + colors, + opacity, + background, + cov2d, + camDepths, + finalTs, px2gid, - v_outImg.to(torch::kCPU), - v_outAlpha.to(torch::kCPU)); + v_outImg, + v_outAlpha); // delete[] px2gid; diff --git a/vendor/gsplat-metal/gsplat_metal.metal 
b/vendor/gsplat-metal/gsplat_metal.metal index e2c3f70..a782ade 100644 --- a/vendor/gsplat-metal/gsplat_metal.metal +++ b/vendor/gsplat-metal/gsplat_metal.metal @@ -432,8 +432,7 @@ kernel void project_gaussians_forward_kernel( write_packed_float2(xys, idx, center); } -// TODO(achan): this is actually the nd_rasterize_forward_kernel -kernel void rasterize_forward_kernel( +kernel void nd_rasterize_forward_kernel( constant uint3& tile_bounds, constant uint3& img_size, constant uint& channels, diff --git a/vendor/gsplat-metal/gsplat_metal.mm b/vendor/gsplat-metal/gsplat_metal.mm index 7b59c96..ee2bb39 100644 --- a/vendor/gsplat-metal/gsplat_metal.mm +++ b/vendor/gsplat-metal/gsplat_metal.mm @@ -11,7 +11,7 @@ dispatch_queue_t d_queue; id nd_rasterize_backward_kernel_cpso; - id rasterize_forward_kernel_cpso; + id nd_rasterize_forward_kernel_cpso; id rasterize_backward_kernel_cpso; id project_gaussians_forward_kernel_cpso; id project_gaussians_backward_kernel_cpso; @@ -108,7 +108,7 @@ @implementation DummyClassForPathHack } GSPLAT_METAL_ADD_KERNEL(nd_rasterize_backward_kernel); - GSPLAT_METAL_ADD_KERNEL(rasterize_forward_kernel); + GSPLAT_METAL_ADD_KERNEL(nd_rasterize_forward_kernel); GSPLAT_METAL_ADD_KERNEL(rasterize_backward_kernel); GSPLAT_METAL_ADD_KERNEL(project_gaussians_forward_kernel); GSPLAT_METAL_ADD_KERNEL(project_gaussians_backward_kernel); @@ -123,26 +123,6 @@ @implementation DummyClassForPathHack return ctx; } -// TODO(achan): Where do I call this? 
-void free_gsplat_metal_context(MetalContext* ctx) { - [ctx->nd_rasterize_backward_kernel_cpso release]; - [ctx->rasterize_forward_kernel_cpso release]; - [ctx->rasterize_backward_kernel_cpso release]; - [ctx->project_gaussians_forward_kernel_cpso release]; - [ctx->project_gaussians_backward_kernel_cpso release]; - [ctx->compute_sh_forward_kernel_cpso release]; - [ctx->compute_sh_backward_kernel_cpso release]; - [ctx->compute_cov2d_bounds_kernel_cpso release]; - [ctx->map_gaussian_to_intersects_kernel_cpso release]; - [ctx->get_tile_bin_edges_kernel_cpso release]; - - [ctx->queue release]; - [ctx->device release]; - // We do not need to release `d_queue` here as that is managed by torch. - - free(ctx); -} - MetalContext* get_global_context() { static MetalContext* ctx = NULL; if (ctx == NULL) { @@ -616,7 +596,7 @@ void dispatchKernel(MetalContext* ctx, id cpso, MTLSize MetalContext* ctx = get_global_context(); MTLSize grid_size = MTLSizeMake(img_height, img_width, 1); MTLSize thread_group_size = MTLSizeMake(block_size_dim2[0], block_size_dim2[1], 1); - dispatchKernel(ctx, ctx->rasterize_forward_kernel_cpso, grid_size, thread_group_size, { + dispatchKernel(ctx, ctx->nd_rasterize_forward_kernel_cpso, grid_size, thread_group_size, { EncodeArg::array(tile_bounds_arr, sizeof(tile_bounds_arr)), EncodeArg::array(img_size_dim3, sizeof(img_size_dim3)), EncodeArg::scalar(channels), @@ -687,7 +667,7 @@ void dispatchKernel(MetalContext* ctx, id cpso, MTLSize MetalContext* ctx = get_global_context(); MTLSize grid_size = MTLSizeMake(img_height, img_width, 1); MTLSize thread_group_size = MTLSizeMake(block_size_dim2[0], block_size_dim2[1], 1); - dispatchKernel(ctx, ctx->rasterize_forward_kernel_cpso, grid_size, thread_group_size, { + dispatchKernel(ctx, ctx->nd_rasterize_forward_kernel_cpso, grid_size, thread_group_size, { EncodeArg::array(tile_bounds_arr, sizeof(tile_bounds_arr)), EncodeArg::array(img_size_dim3, sizeof(img_size_dim3)), EncodeArg::scalar(channels), From 
b899138b39d3831529b3a919d496b8a9aa2bc3ca Mon Sep 17 00:00:00 2001 From: Andrew Chan Date: Sat, 13 Apr 2024 19:10:56 -0700 Subject: [PATCH 14/19] readme --- README.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/README.md b/README.md index 7959f8f..7fff3a6 100644 --- a/README.md +++ b/README.md @@ -121,7 +121,7 @@ brew install opencv brew install pytorch ``` -You will also need to install Xcode and the Xcode command line tools to compile with metal support (if you are fine with CPU-only acceleration, you can skip this step): +You will also need to install Xcode and the Xcode command line tools to compile with metal support (otherwise, OpenSplat will build with CPU acceleration only): 1. Install Xcode from the Apple App Store. 2. Install the command line tools with `xcode-select --install`. This might do nothing on your machine. 3. If `xcode-select --print-path` prints `/Library/Developer/CommandLineTools`,then run `sudo xcode-select --switch /Applications/Xcode.app/Contents/Developer`. @@ -241,7 +241,6 @@ We recently released OpenSplat, so there's lots of work to do. 
* Support for running on AMD cards (more testing needed) * Improve speed / reduce memory usage - * Add Metal support on macOS * Distributed computation using multiple machines * Real-time training viewer output * Compressed scene outputs From fe78f58b3ae7cef0a3993caaf7ff2fd665c08649 Mon Sep 17 00:00:00 2001 From: Andrew Chan Date: Sun, 14 Apr 2024 12:21:41 -0700 Subject: [PATCH 15/19] revert changes to gsplat_cpu.cpp --- vendor/gsplat-cpu/gsplat_cpu.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vendor/gsplat-cpu/gsplat_cpu.cpp b/vendor/gsplat-cpu/gsplat_cpu.cpp index 849b08d..9385265 100644 --- a/vendor/gsplat-cpu/gsplat_cpu.cpp +++ b/vendor/gsplat-cpu/gsplat_cpu.cpp @@ -162,7 +162,7 @@ std::tuple< torch::Tensor outImg = torch::zeros({height, width, channels}, torch::TensorOptions().dtype(torch::kFloat32).device(device)); torch::Tensor finalTs = torch::ones({height, width}, torch::TensorOptions().dtype(torch::kFloat32).device(device)); - torch::Tensor done = torch::zeros({height, width}, torch::TensorOptions().dtype(torch::kBool).device(device)).fill_(false); + torch::Tensor done = torch::zeros({height, width}, torch::TensorOptions().dtype(torch::kBool).device(device)); torch::Tensor sqCov2dX = 3.0f * torch::sqrt(cov2d.index({"...", 0, 0})); torch::Tensor sqCov2dY = 3.0f * torch::sqrt(cov2d.index({"...", 1, 1})); @@ -205,7 +205,7 @@ std::tuple< for (int i = minx; i < maxx; i++){ for (int j = miny; j < maxy; j++){ - size_t pixIdx = (i * width/2 + j); + size_t pixIdx = (i * width + j); if (pDone[pixIdx]) continue; float xCam = gX - j; From e97adb836260de29619bfe0ef66ba745f2063dbe Mon Sep 17 00:00:00 2001 From: Andrew Chan Date: Sun, 14 Apr 2024 12:25:58 -0700 Subject: [PATCH 16/19] missing delete --- rasterize_gaussians.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rasterize_gaussians.cpp b/rasterize_gaussians.cpp index d6a9d40..8a32330 100644 --- a/rasterize_gaussians.cpp +++ b/rasterize_gaussians.cpp @@ -212,7 
+212,7 @@ tensor_list RasterizeGaussiansCPU::backward(AutogradContext *ctx, tensor_list gr v_outImg, v_outAlpha); - // delete[] px2gid; + delete[] px2gid; torch::Tensor v_xy = std::get<0>(t).to(device); From b6c877ec27c92e4dfe1c411939d18d5629403b42 Mon Sep 17 00:00:00 2001 From: Andrew Chan Date: Sun, 14 Apr 2024 14:52:11 -0700 Subject: [PATCH 17/19] fix w/h swap --- vendor/gsplat-metal/gsplat_metal.mm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vendor/gsplat-metal/gsplat_metal.mm b/vendor/gsplat-metal/gsplat_metal.mm index ee2bb39..e2a3464 100644 --- a/vendor/gsplat-metal/gsplat_metal.mm +++ b/vendor/gsplat-metal/gsplat_metal.mm @@ -594,7 +594,7 @@ void dispatchKernel(MetalContext* ctx, id cpso, MTLSize int32_t block_size_dim2[2] = {std::get<0>(block), std::get<1>(block)}; MetalContext* ctx = get_global_context(); - MTLSize grid_size = MTLSizeMake(img_height, img_width, 1); + MTLSize grid_size = MTLSizeMake(img_width, img_height, 1); MTLSize thread_group_size = MTLSizeMake(block_size_dim2[0], block_size_dim2[1], 1); dispatchKernel(ctx, ctx->nd_rasterize_forward_kernel_cpso, grid_size, thread_group_size, { EncodeArg::array(tile_bounds_arr, sizeof(tile_bounds_arr)), From 043c9b2e8ac6e93e904e3d7cf8816a346c1d3f20 Mon Sep 17 00:00:00 2001 From: Andrew Chan Date: Sun, 14 Apr 2024 14:53:55 -0700 Subject: [PATCH 18/19] ditto for nd_rasterize --- vendor/gsplat-metal/gsplat_metal.mm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vendor/gsplat-metal/gsplat_metal.mm b/vendor/gsplat-metal/gsplat_metal.mm index e2a3464..b17c673 100644 --- a/vendor/gsplat-metal/gsplat_metal.mm +++ b/vendor/gsplat-metal/gsplat_metal.mm @@ -665,7 +665,7 @@ void dispatchKernel(MetalContext* ctx, id cpso, MTLSize int32_t block_size_dim2[2] = {std::get<0>(block), std::get<1>(block)}; MetalContext* ctx = get_global_context(); - MTLSize grid_size = MTLSizeMake(img_height, img_width, 1); + MTLSize grid_size = MTLSizeMake(img_width, img_height, 1); MTLSize 
thread_group_size = MTLSizeMake(block_size_dim2[0], block_size_dim2[1], 1); dispatchKernel(ctx, ctx->nd_rasterize_forward_kernel_cpso, grid_size, thread_group_size, { EncodeArg::array(tile_bounds_arr, sizeof(tile_bounds_arr)), From 8c376c82750408060d96ddbedf4d0db51478acf4 Mon Sep 17 00:00:00 2001 From: Piero Toffanin Date: Sun, 14 Apr 2024 20:38:43 -0400 Subject: [PATCH 19/19] Revert RasterizeGaussiansCPU changes --- rasterize_gaussians.cpp | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/rasterize_gaussians.cpp b/rasterize_gaussians.cpp index 8a32330..736d4c0 100644 --- a/rasterize_gaussians.cpp +++ b/rasterize_gaussians.cpp @@ -155,9 +155,6 @@ torch::Tensor RasterizeGaussiansCPU::forward(AutogradContext *ctx, ){ int numPoints = xys.size(0); - ctx->saved_data["imgWidth"] = imgWidth; - ctx->saved_data["imgHeight"] = imgHeight; - torch::Device device = xys.device(); auto t = rasterize_forward_tensor_cpu(imgWidth, imgHeight, xys, @@ -169,12 +166,14 @@ torch::Tensor RasterizeGaussiansCPU::forward(AutogradContext *ctx, camDepths ); // Final image - torch::Tensor outImg = std::get<0>(t).to(device); + torch::Tensor outImg = std::get<0>(t); - torch::Tensor finalTs = std::get<1>(t).to(device); + torch::Tensor finalTs = std::get<1>(t); std::vector *px2gid = std::get<2>(t); ctx->saved_data["px2gid"] = reinterpret_cast(px2gid); + ctx->saved_data["imgWidth"] = imgWidth; + ctx->saved_data["imgHeight"] = imgHeight; ctx->save_for_backward({ xys, conics, colors, opacity, background, cov2d, camDepths, finalTs }); return outImg; @@ -197,7 +196,6 @@ tensor_list RasterizeGaussiansCPU::backward(AutogradContext *ctx, tensor_list gr torch::Tensor finalTs = saved[7]; torch::Tensor v_outAlpha = torch::zeros_like(v_outImg.index({"...", 0})); - torch::Device device = xys.device(); auto t = rasterize_backward_tensor_cpu(imgHeight, imgWidth, xys, @@ -215,10 +213,10 @@ tensor_list RasterizeGaussiansCPU::backward(AutogradContext *ctx, tensor_list gr delete[] 
px2gid; - torch::Tensor v_xy = std::get<0>(t).to(device); - torch::Tensor v_conic = std::get<1>(t).to(device); - torch::Tensor v_colors = std::get<2>(t).to(device); - torch::Tensor v_opacity = std::get<3>(t).to(device); + torch::Tensor v_xy = std::get<0>(t); + torch::Tensor v_conic = std::get<1>(t); + torch::Tensor v_colors = std::get<2>(t); + torch::Tensor v_opacity = std::get<3>(t); torch::Tensor none; return { v_xy,