From c9c71d48d90987c33977d89527828d858f2b6f25 Mon Sep 17 00:00:00 2001 From: Piero Toffanin Date: Wed, 20 Mar 2024 20:21:29 -0400 Subject: [PATCH] Fix SH degrees variable mismatch --- model.cpp | 1 - vendor/gsplat-cpu/gsplat_cpu.cpp | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/model.cpp b/model.cpp index 86a6d2f..49c3d1b 100644 --- a/model.cpp +++ b/model.cpp @@ -159,7 +159,6 @@ torch::Tensor Model::forward(Camera& cam, int step){ rgbs = torch::clamp_min(rgbs + 0.5f, 0.0f); - if (device == torch::kCPU){ rgb = RasterizeGaussiansCPU::apply( xys, diff --git a/vendor/gsplat-cpu/gsplat_cpu.cpp b/vendor/gsplat-cpu/gsplat_cpu.cpp index 250b394..9385265 100644 --- a/vendor/gsplat-cpu/gsplat_cpu.cpp +++ b/vendor/gsplat-cpu/gsplat_cpu.cpp @@ -431,7 +431,7 @@ torch::Tensor compute_sh_forward_tensor_cpu( const int numChannels = 3; unsigned numBases = numShBases(degrees_to_use); - torch::Tensor result = torch::zeros({viewdirs.size(0), numBases}, torch::TensorOptions().dtype(torch::kFloat32).device(viewdirs.device())); + torch::Tensor result = torch::zeros({viewdirs.size(0), numShBases(degree)}, torch::TensorOptions().dtype(torch::kFloat32).device(viewdirs.device())); result.index_put_({"...", 0}, SH_C0); if (numBases > 1){ @@ -478,6 +478,6 @@ torch::Tensor compute_sh_forward_tensor_cpu( } } } - + return (result.index({"...", None}) * coeffs).sum(-2); } \ No newline at end of file