From d36a80ba33390f4bdb6453c72a5da7bd2d2f4b4c Mon Sep 17 00:00:00 2001 From: Piero Toffanin Date: Wed, 20 Mar 2024 16:14:43 -0400 Subject: [PATCH] Add --val-render --- opensplat.cpp | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/opensplat.cpp b/opensplat.cpp index e0c410d..971bdee 100644 --- a/opensplat.cpp +++ b/opensplat.cpp @@ -17,6 +17,7 @@ int main(int argc, char *argv[]){ ("s,save-every", "Save output scene every these many steps (set to -1 to disable)", cxxopts::value()->default_value("-1")) ("val", "Withhold a camera shot for validating the scene loss") ("val-image", "Filename of the image to withhold for validating scene loss", cxxopts::value()->default_value("random")) + ("val-render", "Path of the directory where to render validation images", cxxopts::value()->default_value("")) ("cpu", "Force CPU execution") ("n,num-iters", "Number of iterations to run", cxxopts::value()->default_value("30000")) @@ -57,8 +58,10 @@ int main(int argc, char *argv[]){ const std::string projectRoot = result["input"].as(); const std::string outputScene = result["output"].as(); const int saveEvery = result["save-every"].as(); - const bool validate = result.count("val") > 0; + const bool validate = result.count("val") > 0 || result.count("val-render") > 0; const std::string valImage = result["val-image"].as(); + const std::string valRender = result["val-render"].as(); + if (!valRender.empty() && !fs::exists(valRender)) fs::create_directories(valRender); const float downScaleFactor = (std::max)(result["downscale-factor"].as(), 1.0f); const int numIters = result["num-iters"].as(); @@ -79,6 +82,7 @@ int main(int argc, char *argv[]){ torch::Device device = torch::kCPU; int displayStep = 1; + if (torch::cuda::is_available() && result.count("cpu") == 0) { std::cout << "Using CUDA" << std::endl; device = torch::kCUDA; @@ -132,6 +136,13 @@ int main(int argc, char *argv[]){ fs::path p(outputScene); model.savePlySplat((p.replace_filename(fs::path(p.stem().string() + "_" + std::to_string(step) + p.extension().string())).string())); } + + if (!valRender.empty() && step % displayStep == 0){ + torch::Tensor rgb = model.forward(*valCam, step); + cv::Mat image = tensorToImage(rgb.detach().cpu()); + cv::cvtColor(image, image, cv::COLOR_RGB2BGR); + cv::imwrite((fs::path(valRender) / (std::to_string(step) + ".png")).string(), image); + } } model.savePlySplat(outputScene);