diff --git a/ci_test/common_python/tools.py b/ci_test/common_python/tools.py index 18a2fe34a8b..9729c36695d 100644 --- a/ci_test/common_python/tools.py +++ b/ci_test/common_python/tools.py @@ -835,7 +835,7 @@ def test_func(cluster, dirname, weekly): ) assert_success(return_code, stderr_log_file) if post_test_func is not None: - post_test_func(lbann, weekly) + post_test_func(lbann, weekly, **_kwargs) return { 'return_code': return_code, 'work_dir': work_dir, diff --git a/ci_test/unit_tests/test_unit_layer_upsample.py b/ci_test/unit_tests/test_unit_layer_upsample.py new file mode 100644 index 00000000000..20601a2fe0e --- /dev/null +++ b/ci_test/unit_tests/test_unit_layer_upsample.py @@ -0,0 +1,288 @@ +import functools +import math +import operator +import os +import os.path +import sys +import numpy as np +import lbann.contrib.args + +# Bamboo utilities +current_file = os.path.realpath(__file__) +current_dir = os.path.dirname(current_file) +sys.path.insert(0, os.path.join(os.path.dirname(current_dir), "common_python")) +import tools + +# ============================================== +# Objects for Python data reader +# ============================================== +# Note: The Python data reader imports this file as a module and calls +# the functions below to ingest data. + + +def make_random_array(shape, seed): + """Hacked function to generate a random array. + + NumPy's RNG produces different values with different NumPy + versions. This function is helpful when array values must be + identical across all runs, e.g. when checking against precomputed + metric values. + + Args: + shape (Iterable of int): Array dimensions + seed (int): Parameter for RNG. Must be non-zero. + Returns: + numpy.ndarray: Array of `np.float32`. Values will be in + [-0.5,0.5). + + """ + size = functools.reduce(operator.mul, shape) + eps = np.finfo(np.float32).eps + x = (seed / np.linspace(math.sqrt(eps), 0.1, size)) % 1 - 0.5 + return x.reshape(shape).astype(np.float32) + + +# Data +_num_samples = 64 +_sample_dims = [6, 11, 7] +_sample_dims_3d = [2, 3, 11, 7] +_sample_size = functools.reduce(operator.mul, _sample_dims) +_samples = make_random_array([_num_samples] + _sample_dims, 7) + + +# Sample access functions +def get_sample(index): + return _samples[index, :].reshape(-1) + + +def num_samples(): + return _num_samples + + +def sample_dims(): + return (_sample_size,) + + +# ============================================== +# Setup LBANN experiment +# ============================================== + + +def setup_experiment(lbann, weekly): + """Construct LBANN experiment. 
+ + Args: + lbann (module): Module for LBANN Python frontend + + """ + mini_batch_size = num_samples() // 2 + trainer = lbann.Trainer(mini_batch_size) + model = construct_model(lbann) + data_reader = construct_data_reader(lbann) + optimizer = lbann.NoOptimizer() + return ( + trainer, + model, + data_reader, + optimizer, + None, + ) # Don't request any specific number of nodes + + +upsample_configs = [] + +# 3x3 upsampling +for mode in ["nearest"]: + upsample_configs.append( + { + "name": "3x3 {} upsample".format(mode), + "scale_factors": (3, 3), + "upsample_mode": mode, + } + ) + +# 2x4 upsampling +for mode in ["nearest"]: + upsample_configs.append( + { + "name": "2x4 {} upsample".format(mode), + "scale_factors": (2, 4), + "upsample_mode": mode, + } + ) + +# 2x2x2 3D upsampling +for mode in ["nearest"]: + upsample_configs.append( + { + "name": "2x2x2 {} upsample".format(mode), + "scale_factors": (2, 2, 2), + "upsample_mode": mode, + } + ) + + +def construct_model(lbann): + """Construct LBANN model. + + Args: + lbann (module): Module for LBANN Python frontend + + """ + + # Input data + # Note: Sum with a weights layer so that gradient checking will + # verify that error signals are correct. + x_weights = lbann.Weights( + optimizer=lbann.SGD(), + initializer=lbann.ConstantInitializer(value=0.0), + name="input_weights", + ) + x = lbann.Sum( + lbann.Reshape(lbann.Input(data_field="samples"), dims=_sample_dims), + lbann.WeightsLayer(weights=x_weights, dims=_sample_dims), + ) + x_lbann = x + + # Objects for LBANN model + obj = [] + metrics = [] + callbacks = [] + + # ------------------------------------------ + # Upsample + # ------------------------------------------ + + for u in upsample_configs: + uname = u["name"].split(" ")[0] + + # Apply upsampling + x = x_lbann + if len(u["scale_factors"]) == 3: + x = lbann.Reshape(x, dims=_sample_dims_3d) + x = lbann.Identity(x, name=f"in_{uname}") + + y = lbann.Upsample( + x, + num_dims=len(u["scale_factors"]), + has_vectors=True, + scale_factors=u["scale_factors"], + upsample_mode=u["upsample_mode"], + ) + y = lbann.Identity(y, name=f"out_{uname}") + z = lbann.L2Norm2(y) + + obj.append(z) + + # Save the inputs and outputs to check later. + callbacks.append( + lbann.CallbackDumpOutputs( + layers=f"in_{uname} out_{uname}", directory="outputs" + ) + ) + + # ------------------------------------------ + # Gradient checking + # ------------------------------------------ + + callbacks.append(lbann.CallbackCheckGradients(error_on_failure=True)) + + # ------------------------------------------ + # Construct model + # ------------------------------------------ + + num_epochs = 0 + return lbann.Model( + num_epochs, + layers=lbann.traverse_layer_graph(x_lbann), + objective_function=obj, + metrics=metrics, + callbacks=callbacks, + ) + + +def construct_data_reader(lbann): + """Construct Protobuf message for Python data reader. + + The Python data reader will import the current Python file to + access the sample access functions. + + Args: + lbann (module): Module for LBANN Python frontend + + """ + + # Note: The training data reader should be removed when + # https://github.com/LLNL/lbann/issues/1098 is resolved. 
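+    # Both readers serve the same in-memory samples; the dumps that
+    # check_output() inspects come from the test-phase forward pass
+    # (files named "sgd.testing.epoch.0.step.0_..."), so the training
+    # reader exists only to satisfy the requirement noted above.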
+ message = lbann.reader_pb2.DataReader() + message.reader.extend( + [ + tools.create_python_data_reader( + lbann, current_file, "get_sample", "num_samples", "sample_dims", "train" + ) + ] + ) + message.reader.extend( + [ + tools.create_python_data_reader( + lbann, current_file, "get_sample", "num_samples", "sample_dims", "test" + ) + ] + ) + return message + + +# ============================================== +# Setup PyTest +# ============================================== + + +def check_output(lbann, weekly, **kwargs): + for u in upsample_configs: + uname = u["name"].split(" ")[0] + in_data = np.loadtxt( + os.path.join( + kwargs["work_dir"], + "outputs", + "trainer0", + "model0", + f"sgd.testing.epoch.0.step.0_in_{uname}_output0.csv", + ), + delimiter=",", + ) + out_data = np.loadtxt( + os.path.join( + kwargs["work_dir"], + "outputs", + "trainer0", + "model0", + f"sgd.testing.epoch.0.step.0_out_{uname}_output0.csv", + ), + delimiter=",", + ) + + ndims = len(u["scale_factors"]) + upsampled_data = in_data.copy().reshape( + [-1] + (_sample_dims if ndims == 2 else _sample_dims_3d) + ) + for i, scale_fac in enumerate(u["scale_factors"]): + if u["upsample_mode"] == "nearest": + upsampled_data = upsampled_data.repeat(scale_fac, axis=i + 2) + + assert np.allclose(upsampled_data.ravel(), out_data.ravel()) + + +# Runtime parameters/arguments +environment = lbann.contrib.args.get_distconv_environment() +environment["LBANN_KEEP_ERROR_SIGNALS"] = 1 + +# Create test functions that can interact with PyTest +# Note: Create test name by removing ".py" from file name +_test_name = os.path.splitext(os.path.basename(current_file))[0] +for _test_func in tools.create_tests( + setup_experiment, + _test_name, + post_test_func=check_output, + skip_clusters=["tioga", "corona"], +): + globals()[_test_func.__name__] = _test_func diff --git a/ci_test/unit_tests/test_unit_layer_upsample_distconv.py b/ci_test/unit_tests/test_unit_layer_upsample_distconv.py new file mode 100644 index 00000000000..039f74f04ee --- /dev/null +++ b/ci_test/unit_tests/test_unit_layer_upsample_distconv.py @@ -0,0 +1,378 @@ +import functools +import math +import operator +import os +import os.path +import sys +import pytest +import numpy as np +import lbann.contrib.args + +# Bamboo utilities +current_file = os.path.realpath(__file__) +current_dir = os.path.dirname(current_file) +sys.path.insert(0, os.path.join(os.path.dirname(current_dir), "common_python")) +import tools + +# ============================================== +# Objects for Python data reader +# ============================================== +# Note: The Python data reader imports this file as a module and calls +# the functions below to ingest data. + + +def make_random_array(shape, seed): + """Hacked function to generate a random array. + + NumPy's RNG produces different values with different NumPy + versions. This function is helpful when array values must be + identical across all runs, e.g. when checking against precomputed + metric values. + + Args: + shape (Iterable of int): Array dimensions + seed (int): Parameter for RNG. Must be non-zero. + Returns: + numpy.ndarray: Array of `np.float32`. Values will be in + [-0.5,0.5). 
+ + """ + size = functools.reduce(operator.mul, shape) + eps = np.finfo(np.float32).eps + x = (seed / np.linspace(math.sqrt(eps), 0.1, size)) % 1 - 0.5 + return x.reshape(shape).astype(np.float32) + + +# Data +_num_samples = 64 +_sample_dims = [6, 11, 7] +_sample_dims_3d = [2, 3, 11, 7] +_sample_size = functools.reduce(operator.mul, _sample_dims) +_samples = make_random_array([_num_samples] + _sample_dims, 7) + + +# Sample access functions +def get_sample(index): + return _samples[index, :].reshape(-1) + + +def num_samples(): + return _num_samples + + +def sample_dims(): + return (_sample_size,) + + +# ============================================== +# Setup LBANN experiment +# ============================================== + + +def setup_experiment(lbann, weekly): + """Construct LBANN experiment. + + Args: + lbann (module): Module for LBANN Python frontend + + """ + mini_batch_size = num_samples() // 2 + trainer = lbann.Trainer(mini_batch_size) + model = construct_model(lbann) + data_reader = construct_data_reader(lbann) + optimizer = lbann.NoOptimizer() + return ( + trainer, + model, + data_reader, + optimizer, + None, + ) # Don't request any specific number of nodes + + +upsample_configs = [] + +# 3x3 upsampling +for mode in ["nearest"]: + upsample_configs.append( + { + "name": "3x3 {} upsample".format(mode), + "scale_factors": (3, 3), + "upsample_mode": mode, + } + ) + +# 2x4 upsampling +for mode in ["nearest"]: + upsample_configs.append( + { + "name": "2x4 {} upsample".format(mode), + "scale_factors": (2, 4), + "upsample_mode": mode, + } + ) + +# 2x2x2 3D upsampling +for mode in ["nearest"]: + upsample_configs.append( + { + "name": "2x2x2 {} upsample".format(mode), + "scale_factors": (2, 2, 2), + "upsample_mode": mode, + } + ) + + +def create_parallel_strategy(num_height_groups): + return {"height_groups": num_height_groups} + + +def construct_model(lbann): + """Construct LBANN model. + + Args: + lbann (module): Module for LBANN Python frontend + + """ + + num_height_groups = tools.gpus_per_node(lbann) + if num_height_groups == 0: + e = "this test requires GPUs." + print("Skip - " + e) + pytest.skip(e) + + # Input data + # Note: Sum with a weights layer so that gradient checking will + # verify that error signals are correct. 
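+    # The idiom, in outline (illustrative, mirroring the code below):
+    #
+    #   x = lbann.Sum(input_, lbann.WeightsLayer(weights=w, dims=dims))
+    #
+    # With w initialized to zero the data is unchanged, but the gradient
+    # with respect to w equals the error signal reaching x, which gives
+    # CallbackCheckGradients a handle to verify backprop through every
+    # downstream layer.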
+ x_weights = lbann.Weights( + optimizer=lbann.SGD(), + initializer=lbann.ConstantInitializer(value=0.0), + name="input_weights", + ) + x = lbann.Sum( + lbann.Reshape(lbann.Input(data_field="samples"), dims=_sample_dims), + lbann.WeightsLayer(weights=x_weights, dims=_sample_dims), + ) + x_lbann = x + + # Objects for LBANN model + obj = [] + metrics = [] + callbacks = [] + + # ------------------------------------------ + # Upsample + # ------------------------------------------ + + for u in upsample_configs: + uname = u["name"].split(" ")[0] + num_dims = len(u["scale_factors"]) + + # Apply upsampling + x = x_lbann + if len(u["scale_factors"]) == 3: + x = lbann.Reshape(x, dims=_sample_dims_3d) + x = lbann.Identity(x, name=f"in_{uname}") + + # Convolution settings + kernel_dims = [_sample_dims[0] if num_dims == 2 else _sample_dims_3d[0]] * 2 + [ + 1 + ] * num_dims + strides = [1] * num_dims + pads = [0] * num_dims + dilations = [1] * num_dims + + kernel = np.zeros(kernel_dims) + for i in range(len(kernel)): + if num_dims == 2: + kernel[i, i, 0, 0] = 1 + else: + kernel[i, i, 0, 0, 0] = 1 + + # Apply convolution + kernel_weights = lbann.Weights( + optimizer=lbann.SGD(), + initializer=lbann.ValueInitializer(values=np.nditer(kernel)), + name="kernel1_{}".format(uname), + ) + + x = lbann.Convolution( + x, + weights=(kernel_weights,), + num_dims=num_dims, + out_channels=kernel_dims[0], + kernel_size=kernel_dims[2:], + stride=strides, + padding=pads, + dilation=dilations, + has_bias=False, + parallel_strategy=create_parallel_strategy(num_height_groups), + name=f"conv1_{uname}", + ) + + y = lbann.Upsample( + x, + num_dims=len(u["scale_factors"]), + has_vectors=True, + scale_factors=u["scale_factors"], + upsample_mode=u["upsample_mode"], + parallel_strategy=create_parallel_strategy(num_height_groups), + name=f"upsample_{uname}", + ) + + # Convolution settings + kernel_dims = [_sample_dims[0] if num_dims == 2 else _sample_dims_3d[0]] * 2 + [ + 3 + ] * num_dims + strides = [1] * num_dims + pads = [1] * num_dims + dilations = [1] * num_dims + + kernel = np.zeros(kernel_dims) + for i in range(len(kernel)): + if num_dims == 2: + kernel[i, i, 1, 1] = 1 + else: + kernel[i, i, 1, 1, 1] = 1 + + # Apply convolution + kernel_weights = lbann.Weights( + optimizer=lbann.SGD(), + initializer=lbann.ValueInitializer(values=np.nditer(kernel)), + name="kernel2_{}".format(uname), + ) + + y = lbann.Convolution( + y, + weights=(kernel_weights,), + num_dims=num_dims, + out_channels=kernel_dims[0], + kernel_size=kernel_dims[2:], + stride=strides, + padding=pads, + dilation=dilations, + has_bias=False, + parallel_strategy=create_parallel_strategy(num_height_groups), + name=f"conv2_{uname}", + ) + + y = lbann.Identity(y, name=f"out_{uname}") + z = lbann.L2Norm2(y) + + obj.append(z) + + # Save the inputs and outputs to check later. + callbacks.append( + lbann.CallbackDumpOutputs( + layers=f"in_{uname} out_{uname}", directory="outputs" + ) + ) + + # ------------------------------------------ + # Gradient checking + # ------------------------------------------ + + callbacks.append(lbann.CallbackCheckGradients(error_on_failure=True)) + + # ------------------------------------------ + # Construct model + # ------------------------------------------ + + num_epochs = 0 + return lbann.Model( + num_epochs, + layers=lbann.traverse_layer_graph(x_lbann), + objective_function=obj, + metrics=metrics, + callbacks=callbacks, + ) + + +def construct_data_reader(lbann): + """Construct Protobuf message for Python data reader. 
+ + The Python data reader will import the current Python file to + access the sample access functions. + + Args: + lbann (module): Module for LBANN Python frontend + + """ + + # Note: The training data reader should be removed when + # https://github.com/LLNL/lbann/issues/1098 is resolved. + message = lbann.reader_pb2.DataReader() + message.reader.extend( + [ + tools.create_python_data_reader( + lbann, current_file, "get_sample", "num_samples", "sample_dims", "train" + ) + ] + ) + message.reader.extend( + [ + tools.create_python_data_reader( + lbann, current_file, "get_sample", "num_samples", "sample_dims", "test" + ) + ] + ) + return message + + +# ============================================== +# Setup PyTest +# ============================================== + + +def check_output(lbann, weekly, **kwargs): + + # Check that the output values are close to the input values. + for u in upsample_configs: + uname = u["name"].split(" ")[0] + in_data = np.loadtxt( + os.path.join( + kwargs["work_dir"], + "outputs", + "trainer0", + "model0", + f"sgd.testing.epoch.0.step.0_in_{uname}_output0.csv", + ), + delimiter=",", + ) + out_data = np.loadtxt( + os.path.join( + kwargs["work_dir"], + "outputs", + "trainer0", + "model0", + f"sgd.testing.epoch.0.step.0_out_{uname}_output0.csv", + ), + delimiter=",", + ) + + ndims = len(u["scale_factors"]) + upsampled_data = in_data.copy().reshape( + [-1] + (_sample_dims if ndims == 2 else _sample_dims_3d) + ) + for i, scale_fac in enumerate(u["scale_factors"]): + if u["upsample_mode"] == "nearest": + upsampled_data = upsampled_data.repeat(scale_fac, axis=i + 2) + + assert np.allclose(upsampled_data.ravel(), out_data.ravel()) + + +# Runtime parameters/arguments +environment = lbann.contrib.args.get_distconv_environment() +environment["LBANN_KEEP_ERROR_SIGNALS"] = 1 + +# Create test functions that can interact with PyTest +# Note: Create test name by removing ".py" from file name +_test_name = os.path.splitext(os.path.basename(current_file))[0] +for _test_func in tools.create_tests( + setup_experiment, + _test_name, + post_test_func=check_output, + skip_clusters=["tioga", "corona"], + environment=environment, +): + globals()[_test_func.__name__] = _test_func diff --git a/include/lbann/layers/transform/CMakeLists.txt b/include/lbann/layers/transform/CMakeLists.txt index c152fe7deb9..691eb4edb9b 100644 --- a/include/lbann/layers/transform/CMakeLists.txt +++ b/include/lbann/layers/transform/CMakeLists.txt @@ -56,6 +56,7 @@ set_full_path(THIS_DIR_HEADERS scatter.hpp gather.hpp multidim_reduction.hpp + upsample.hpp ) if (LBANN_HAS_DISTCONV AND LBANN_HAS_NVSHMEM) diff --git a/include/lbann/layers/transform/transform_builders.hpp b/include/lbann/layers/transform/transform_builders.hpp index 26f655dbb83..79cdc70c39c 100644 --- a/include/lbann/layers/transform/transform_builders.hpp +++ b/include/lbann/layers/transform/transform_builders.hpp @@ -60,6 +60,7 @@ LBANN_DEFINE_LAYER_BUILDER(sum); LBANN_DEFINE_LAYER_BUILDER(tessellate); LBANN_DEFINE_LAYER_BUILDER(uniform); LBANN_DEFINE_LAYER_BUILDER(unpooling); +LBANN_DEFINE_LAYER_BUILDER(upsample); LBANN_DEFINE_LAYER_BUILDER(weighted_sum); LBANN_DEFINE_LAYER_BUILDER(weights); diff --git a/include/lbann/layers/transform/upsample.hpp b/include/lbann/layers/transform/upsample.hpp new file mode 100644 index 00000000000..798ccdc0ddb --- /dev/null +++ b/include/lbann/layers/transform/upsample.hpp @@ -0,0 +1,294 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright (c) 2014-2024, Lawrence 
Livermore National Security, LLC.
+// Produced at the Lawrence Livermore National Laboratory.
+// Written by the LBANN Research Team (B. Van Essen, et al.) listed in
+// the CONTRIBUTORS file.
+//
+// LLNL-CODE-697807.
+// All rights reserved.
+//
+// This file is part of LBANN: Livermore Big Artificial Neural Network
+// Toolkit. For details, see http://software.llnl.gov/LBANN or
+// https://github.com/LLNL/LBANN.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you
+// may not use this file except in compliance with the License. You may
+// obtain a copy of the License at:
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+// implied. See the License for the specific language governing
+// permissions and limitations under the License.
+////////////////////////////////////////////////////////////////////////////////
+
+#ifndef LBANN_LAYER_UPSAMPLE_HPP_INCLUDED
+#define LBANN_LAYER_UPSAMPLE_HPP_INCLUDED
+
+#include "lbann/layers/data_type_layer.hpp"
+#include "lbann/models/model.hpp"
+#include "lbann/utils/dim_helpers.hpp"
+#include "lbann/utils/dnn_enums.hpp"
+#ifdef LBANN_HAS_DNN_LIB
+#include "lbann/utils/dnn_lib/helpers.hpp"
+#include "lbann/utils/dnn_lib/upsample.hpp"
+#endif // LBANN_HAS_DNN_LIB
+#include "lbann/utils/exception.hpp"
+
+#include <utility>
+#include <vector>
+
+#ifdef LBANN_HAS_DISTCONV
+#include "lbann/utils/distconv.hpp"
+#endif // LBANN_HAS_DISTCONV
+
+namespace lbann {
+
+enum class upsample_mode
+{
+  NEAREST
+};
+
+inline upsample_mode to_upsample_mode(std::string m)
+{
+  if (m == "nearest")
+    return upsample_mode::NEAREST;
+  else {
+    LBANN_ERROR("Invalid upsample mode requested.");
+  }
+}
+
+#ifdef LBANN_HAS_DISTCONV
+
+namespace dc {
+using Shape = ::distconv::tensor::Shape;
+using Backend = ::distconv::BackendDNNLib;
+} // namespace dc
+
+template <typename TensorDataType, data_layout T_layout, El::Device Dev>
+class upsample_distconv_adapter
+  : public data_type_distconv_adapter<TensorDataType>
+{
+public:
+  using TensorDevType =
+    typename data_type_distconv_adapter<TensorDataType>::TensorDevType;
+  upsample_distconv_adapter(Layer& layer)
+    : data_type_distconv_adapter<TensorDataType>(layer)
+  {}
+  virtual ~upsample_distconv_adapter() = default;
+  dc::Shape get_activations_local_shape(int index = 0) const override;
+  void setup_layer(size_t workspace_capacity) override;
+  void
+  fp_compute(bool training = true); // training=true for max back-compatibility.
+  void bp_compute();
+
+private:
+  dnn_lib::TensorDescriptor m_xdesc;
+  dnn_lib::TensorDescriptor m_ydesc;
+  dnn_lib::TensorDescriptor m_dxdesc;
+  dnn_lib::TensorDescriptor m_dydesc;
+};
+#endif // LBANN_HAS_DISTCONV
+
+template <typename TensorDataType, data_layout T_layout, El::Device Dev>
+class upsample_layer : public data_type_layer<TensorDataType>
+{
+  static_assert(T_layout == data_layout::DATA_PARALLEL,
+                "upsample only supports DATA_PARALLEL");
+
+private:
+  /** Upsample mode. */
+  upsample_mode m_upsample_mode;
+
+  /** Output scale factors. */
+  std::vector<int> m_scale_factors;
+
+#ifdef LBANN_HAS_DNN_LIB
+  /** Pooling descriptor. */
+  dnn_lib::PoolingDescriptor m_pooling_dnn_desc;
+  /** Tensor DNN library descriptors. */
+  dnn_lib::data_parallel_layer_tensor_manager<TensorDataType>
+    m_tensors_dnn_desc;
+#endif // LBANN_HAS_DNN_LIB
+
+public:
+  upsample_layer(lbann_comm* comm,
+                 int num_data_dims,
+                 int scale_factors,
+                 upsample_mode mode)
+    : upsample_layer(comm,
+                     num_data_dims,
+                     std::vector<int>(num_data_dims, scale_factors),
+                     mode)
+  {}
+
+  upsample_layer(lbann_comm* comm,
+                 int num_data_dims,
+                 std::vector<int> scale_factors,
+                 upsample_mode mode)
+    : data_type_layer<TensorDataType>(comm),
+      m_upsample_mode(mode),
+      m_scale_factors{std::move(scale_factors)}
+#ifdef LBANN_HAS_DNN_LIB
+      ,
+      m_tensors_dnn_desc(this)
+#endif // LBANN_HAS_DNN_LIB
+  {}
+
+  upsample_layer(const upsample_layer& other)
+    : data_type_layer<TensorDataType>(other),
+      m_upsample_mode(other.m_upsample_mode),
+      m_scale_factors(other.m_scale_factors)
+#ifdef LBANN_HAS_DNN_LIB
+      ,
+      m_pooling_dnn_desc(other.m_pooling_dnn_desc),
+      m_tensors_dnn_desc(other.m_tensors_dnn_desc)
+#endif // LBANN_HAS_DNN_LIB
+  {
+#ifdef LBANN_HAS_DNN_LIB
+    m_tensors_dnn_desc.set_layer(this);
+#endif // LBANN_HAS_DNN_LIB
+  }
+
+  upsample_layer& operator=(const upsample_layer& other)
+  {
+    data_type_layer<TensorDataType>::operator=(other);
+    m_upsample_mode = other.m_upsample_mode;
+    m_scale_factors = other.m_scale_factors;
+#ifdef LBANN_HAS_DNN_LIB
+    m_pooling_dnn_desc = other.m_pooling_dnn_desc;
+    m_tensors_dnn_desc = other.m_tensors_dnn_desc;
+    m_tensors_dnn_desc.set_layer(this);
+#endif // LBANN_HAS_DNN_LIB
+    return *this;
+  }
+
+  ~upsample_layer() override = default;
+
+  upsample_layer* copy() const override { return new upsample_layer(*this); }
+
+  /** @name Serialization */
+  ///@{
+
+  template <typename ArchiveT>
+  void serialize(ArchiveT& ar);
+
+  ///@}
+
+  std::string get_type() const override { return "upsample"; }
+  data_layout get_data_layout() const override { return T_layout; }
+  El::Device get_device_allocation() const override { return Dev; }
+
+#ifdef LBANN_HAS_ONNX
+  void fill_onnx_node(onnx::GraphProto& graph) const override;
+#endif // LBANN_HAS_ONNX
+
+  description get_description() const override
+  {
+    auto desc = data_type_layer<TensorDataType>::get_description();
+
+    // Upsample mode
+    switch (m_upsample_mode) {
+    case upsample_mode::NEAREST:
+      desc.add("Upsample mode", "nearest");
+      break;
+    default:
+      desc.add("Upsample mode", "invalid");
+    }
+
+    // Upsample scale factors
+    std::ostringstream ss;
+    for (size_t i = 0; i < m_scale_factors.size(); ++i) {
+      ss << (i > 0 ? ", " : "") << m_scale_factors[i];
+    }
+    desc.add("Scale factors", ss.str());
+
+    // Result
+    return desc;
+  }
+
+protected:
+  /** Add layer specific data to prototext */
+  void write_specific_proto(lbann_data::Layer& proto) const final;
+
+  friend class cereal::access;
+  upsample_layer() : upsample_layer(nullptr, 1, 1, upsample_mode::NEAREST) {}
+
+  void setup_dims() override
+  {
+    data_type_layer<TensorDataType>::setup_dims();
+    const auto& input_dims = this->get_input_dims();
+    auto output_dims = input_dims;
+    for (size_t i = 0; i < output_dims.size() - 1; ++i) {
+      output_dims[i + 1] = m_scale_factors[i] * input_dims[i + 1];
+    }
+    this->set_output_dims(output_dims);
+  }
+
+  /// Initialize GPU objects
+  void setup_gpu() override
+  {
+    data_type_layer<TensorDataType>::setup_gpu();
+#ifndef LBANN_HAS_DNN_LIB
+    LBANN_ERROR("DNN library not detected");
+#else
+
+    // Set upsample descriptor
+    int ndims = m_scale_factors.size();
+    std::vector<int> padding(ndims, 0);
+    m_pooling_dnn_desc.set(pooling_mode::AVERAGE_COUNT_INCLUDE_PADDING,
+                           dnn_lib::DNN_PROPAGATE_NAN,
+                           ndims,
+                           m_scale_factors.data(),
+                           padding.data(),
+                           m_scale_factors.data());
+
+#endif // #ifndef LBANN_HAS_DNN_LIB
+  }
+
+  void fp_compute() override;
+
+  void bp_compute() override;
+
+private:
+  /// Upsample forward propagation with DNN library
+  void fp_compute_dnn();
+
+  /// Upsample backward propagation with DNN library
+  void bp_compute_dnn();
+
+#ifdef LBANN_HAS_DISTCONV
+  friend class upsample_distconv_adapter<TensorDataType, T_layout, Dev>;
+
+protected:
+  bool is_distconv_supported() const override;
+  void setup_distconv_adapter() override
+  {
+    this->get_distconv_adapter_ptr() = std::make_unique<
+      upsample_distconv_adapter<TensorDataType, T_layout, Dev>>(*this);
+  }
+  upsample_distconv_adapter<TensorDataType, T_layout, Dev>&
+  get_distconv_adapter() override;
+  const upsample_distconv_adapter<TensorDataType, T_layout, Dev>&
+  get_distconv_adapter() const override;
+#endif // LBANN_HAS_DISTCONV
+};
+
+#ifndef LBANN_UPSAMPLE_LAYER_INSTANTIATE
+#define PROTO_DEVICE(T, Device)                                                \
+  extern template class upsample_layer<T, data_layout::DATA_PARALLEL, Device>;
+
+#include "lbann/macros/instantiate_device.hpp"
+#undef PROTO_DEVICE
+#endif // LBANN_UPSAMPLE_LAYER_INSTANTIATE
+
+} // namespace lbann
+
+#endif // LBANN_LAYER_UPSAMPLE_HPP_INCLUDED
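For reference, setup_dims() above scales every spatial dimension by its factor while leaving the channel dimension unchanged, and the unit tests at the top of this diff verify the same semantics with np.repeat. A minimal NumPy sketch of the intended nearest-neighbor behavior (the helper name is illustrative, not part of the PR):

import numpy as np

def upsample_nearest(x, scale_factors):
    # x has layout (batch, channels, spatial...); axis 0 is the batch
    # and axis 1 the channels, so spatial axes start at 2.
    y = x
    for i, s in enumerate(scale_factors):
        y = y.repeat(s, axis=i + 2)
    return y

x = np.arange(2 * 1 * 2 * 2, dtype=np.float32).reshape(2, 1, 2, 2)
y = upsample_nearest(x, (2, 3))
assert y.shape == (2, 1, 4, 6)  # each spatial dim scaled by its factor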
diff --git a/include/lbann/lbann.hpp b/include/lbann/lbann.hpp
index eb01662cdf3..9b7871fbb53 100644
--- a/include/lbann/lbann.hpp
+++ b/include/lbann/lbann.hpp
@@ -95,6 +95,7 @@
 #include "lbann/layers/transform/tessellate.hpp"
 #include "lbann/layers/transform/uniform.hpp"
 #include "lbann/layers/transform/unpooling.hpp"
+#include "lbann/layers/transform/upsample.hpp"
 #include "lbann/layers/transform/weighted_sum.hpp"
 #include "lbann/layers/transform/weights.hpp"
diff --git a/include/lbann/utils/dnn_lib/cudnn/upsample.hpp b/include/lbann/utils/dnn_lib/cudnn/upsample.hpp
new file mode 100644
index 00000000000..5348ca4ac7f
--- /dev/null
+++ b/include/lbann/utils/dnn_lib/cudnn/upsample.hpp
@@ -0,0 +1,139 @@
+////////////////////////////////////////////////////////////////////////////////
+// Copyright (c) 2014-2024, Lawrence Livermore National Security, LLC.
+// Produced at the Lawrence Livermore National Laboratory.
+// Written by the LBANN Research Team (B. Van Essen, et al.) listed in
+// the CONTRIBUTORS file.
+//
+// LLNL-CODE-697807.
+// All rights reserved.
+//
+// This file is part of LBANN: Livermore Big Artificial Neural Network
+// Toolkit. For details, see http://software.llnl.gov/LBANN or
+// https://github.com/LLNL/LBANN.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you
+// may not use this file except in compliance with the License. You may
+// obtain a copy of the License at:
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+// implied. See the License for the specific language governing
+// permissions and limitations under the License.
+////////////////////////////////////////////////////////////////////////////////
+#ifndef LBANN_UTILS_DNN_LIB_CUDNN_UPSAMPLE_HPP_
+#define LBANN_UTILS_DNN_LIB_CUDNN_UPSAMPLE_HPP_
+
+#include "lbann/utils/dnn_enums.hpp"
+#include "lbann/utils/dnn_lib/helpers.hpp"
+#include "lbann/utils/gpu/helpers.hpp"
+
+#include "utils.hpp"
+
+namespace lbann {
+
+#ifdef LBANN_HAS_CUDNN
+namespace dnn_lib {
+
+using namespace cudnn;
+
+template <typename TensorDataType, typename ScalarParameterType>
+void upsample_nearest_forward(PoolingDescriptor const& poolingDesc,
+                              ScalarParameterType const& alpha_in,
+                              TensorDescriptor const& xDesc,
+                              TensorDataType const* x,
+                              ScalarParameterType const& beta_in,
+                              TensorDescriptor const& yDesc,
+                              TensorDataType* y,
+                              dnnHandle_t handle)
+{
+  using LibScalingParamT = dnn_lib::ScalingParamType<TensorDataType>;
+  auto alpha = El::To<LibScalingParamT>(alpha_in);
+  auto beta = El::To<LibScalingParamT>(beta_in);
+  CHECK_CUDNN(cudnnPoolingBackward(handle,
+                                   poolingDesc,
+                                   &alpha,
+                                   NULL,
+                                   NULL,
+                                   xDesc,
+                                   x,
+                                   NULL,
+                                   NULL,
+                                   &beta,
+                                   yDesc,
+                                   y));
+}
+
+template <typename TensorDataType, typename ScalarParameterType>
+void upsample_nearest_forward(PoolingDescriptor const& poolingDesc,
+                              ScalarParameterType const& alpha_in,
+                              TensorDescriptor const& xDesc,
+                              El::AbstractMatrix<TensorDataType> const& x,
+                              ScalarParameterType const& beta_in,
+                              TensorDescriptor const& yDesc,
+                              El::AbstractMatrix<TensorDataType>& y)
+{
+  auto multisync =
+    El::MakeMultiSync(gpu::get_sync_info(y), gpu::get_sync_info(x));
+  auto handle_manager = internal::make_default_handle_manager(multisync);
+  upsample_nearest_forward(poolingDesc,
+                           alpha_in,
+                           xDesc,
+                           x.LockedBuffer(),
+                           beta_in,
+                           yDesc,
+                           y.Buffer(),
+                           handle_manager.get());
+}
+
+template <typename TensorDataType, typename ScalarParameterType>
+void upsample_nearest_backward(PoolingDescriptor const& poolingDesc,
+                               ScalarParameterType const& alpha_in,
+                               TensorDescriptor const& dyDesc,
+                               TensorDataType const* dy,
+                               ScalarParameterType const& beta_in,
+                               TensorDescriptor const& dxDesc,
+                               TensorDataType* dx,
+                               dnnHandle_t handle)
+{
+  using LibScalingParamT = dnn_lib::ScalingParamType<TensorDataType>;
+  auto alpha = El::To<LibScalingParamT>(alpha_in);
+  auto beta = El::To<LibScalingParamT>(beta_in);
+  CHECK_CUDNN(cudnnPoolingForward(handle,
+                                  poolingDesc,
+                                  &alpha,
+                                  dyDesc,
+                                  dy,
+                                  &beta,
+                                  dxDesc,
+                                  dx));
+}
+
+template <typename TensorDataType, typename ScalarParameterType>
+void upsample_nearest_backward(PoolingDescriptor const& poolingDesc,
+                               ScalarParameterType const& alpha_in,
+                               TensorDescriptor const& dyDesc,
+                               El::AbstractMatrix<TensorDataType> const& dy,
+                               ScalarParameterType const& beta_in,
+                               TensorDescriptor const& dxDesc,
+                               El::AbstractMatrix<TensorDataType>& dx)
+{
+  auto multisync =
+    El::MakeMultiSync(gpu::get_sync_info(dx), gpu::get_sync_info(dy));
+  auto handle_manager = internal::make_default_handle_manager(multisync);
+  upsample_nearest_backward(poolingDesc,
+                            alpha_in,
+                            dyDesc,
+                            dy.LockedBuffer(),
+                            beta_in,
+                            dxDesc,
+                            dx.Buffer(),
+                            handle_manager.get());
+}
+
+} // namespace dnn_lib
+#endif // LBANN_HAS_CUDNN
+} // namespace lbann
+#endif // LBANN_UTILS_DNN_LIB_CUDNN_UPSAMPLE_HPP_
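The wrappers above implement upsampling without a dedicated kernel by reusing average pooling with kernel size = stride = scale factor: cudnnPoolingBackward spreads each input value over its pooling window with weight 1/(window size), so scaling by alpha = prod(scale_factors), as the layer does, recovers an exact nearest-neighbor copy, while cudnnPoolingForward gives the matching block-sum gradient. A NumPy sketch of the 1-D identity (function names illustrative):

import numpy as np

scale = 2   # pooling window = stride = scale factor
alpha = scale  # the layer passes alpha = prod(scale_factors)

def avg_pool_1d(x):
    # Forward average pooling; reused here as the upsample *backward* pass.
    return x.reshape(-1, scale).mean(axis=1)

def avg_pool_backward_1d(dy):
    # Backward of average pooling: each position in a window receives
    # dy / window_size; reused here as the upsample *forward* pass.
    return np.repeat(dy, scale) / scale

x = np.array([1.0, 2.0, 3.0], dtype=np.float32)
assert np.allclose(alpha * avg_pool_backward_1d(x), np.repeat(x, scale))

dy = np.arange(6, dtype=np.float32)  # gradient w.r.t. the upsampled output
assert np.allclose(alpha * avg_pool_1d(dy), dy.reshape(-1, scale).sum(axis=1))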
", " : "") << m_scale_factors[i]; + } + desc.add("Scale factors", ss.str()); + + // Result + return desc; + } + +protected: + /** Add layer specific data to prototext */ + void write_specific_proto(lbann_data::Layer& proto) const final; + + friend class cereal::access; + upsample_layer() : upsample_layer(nullptr, 1, 1, upsample_mode::NEAREST) {} + + void setup_dims() override + { + data_type_layer::setup_dims(); + const auto& input_dims = this->get_input_dims(); + auto output_dims = input_dims; + for (size_t i = 0; i < output_dims.size() - 1; ++i) { + output_dims[i + 1] = m_scale_factors[i] * input_dims[i + 1]; + } + this->set_output_dims(output_dims); + } + + /// Initialize GPU objects + void setup_gpu() override + { + data_type_layer::setup_gpu(); +#ifndef LBANN_HAS_DNN_LIB + LBANN_ERROR("DNN library not detected"); +#else + + // Set upsample descriptor + int ndims = m_scale_factors.size(); + std::vector padding(ndims, 0); + m_pooling_dnn_desc.set(pooling_mode::AVERAGE_COUNT_INCLUDE_PADDING, + dnn_lib::DNN_PROPAGATE_NAN, + ndims, + m_scale_factors.data(), + padding.data(), + m_scale_factors.data()); + +#endif // #ifndef LBANN_HAS_DNN_LIB + } + + void fp_compute() override; + + void bp_compute() override; + +private: + /// Upsample forward propagation with DNN library + void fp_compute_dnn(); + + /// Upsample backward propagation with DNN library + void bp_compute_dnn(); + +#ifdef LBANN_HAS_DISTCONV + friend class upsample_distconv_adapter; + +protected: + bool is_distconv_supported() const override; + void setup_distconv_adapter() override + { + this->get_distconv_adapter_ptr() = std::make_unique< + upsample_distconv_adapter>(*this); + } + upsample_distconv_adapter& + get_distconv_adapter() override; + const upsample_distconv_adapter& + get_distconv_adapter() const override; +#endif // LBANN_HAS_DISTCONV +}; + +#ifndef LBANN_UPSAMPLE_LAYER_INSTANTIATE +#define PROTO_DEVICE(T, Device) \ + extern template class upsample_layer + +#include "lbann/macros/instantiate_device.hpp" +#undef PROTO_DEVICE +#endif // LBANN_UPSAMPLE_LAYER_INSTANTIATE + +} // namespace lbann + +#endif // LBANN_LAYER_UPSAMPLE_HPP_INCLUDED diff --git a/include/lbann/lbann.hpp b/include/lbann/lbann.hpp index eb01662cdf3..9b7871fbb53 100644 --- a/include/lbann/lbann.hpp +++ b/include/lbann/lbann.hpp @@ -95,6 +95,7 @@ #include "lbann/layers/transform/tessellate.hpp" #include "lbann/layers/transform/uniform.hpp" #include "lbann/layers/transform/unpooling.hpp" +#include "lbann/layers/transform/upsample.hpp" #include "lbann/layers/transform/weighted_sum.hpp" #include "lbann/layers/transform/weights.hpp" diff --git a/include/lbann/utils/dnn_lib/cudnn/upsample.hpp b/include/lbann/utils/dnn_lib/cudnn/upsample.hpp new file mode 100644 index 00000000000..5348ca4ac7f --- /dev/null +++ b/include/lbann/utils/dnn_lib/cudnn/upsample.hpp @@ -0,0 +1,139 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright (c) 2014-2024, Lawrence Livermore National Security, LLC. +// Produced at the Lawrence Livermore National Laboratory. +// Written by the LBANN Research Team (B. Van Essen, et al.) listed in +// the CONTRIBUTORS file. +// +// LLNL-CODE-697807. +// All rights reserved. +// +// This file is part of LBANN: Livermore Big Artificial Neural Network +// Toolkit. For details, see http://software.llnl.gov/LBANN or +// https://github.com/LLNL/LBANN. +// +// Licensed under the Apache License, Version 2.0 (the "Licensee"); you +// may not use this file except in compliance with the License. 
diff --git a/include/lbann/utils/dnn_lib/upsample.hpp b/include/lbann/utils/dnn_lib/upsample.hpp
new file mode 100644
index 00000000000..430c4d8e27b
--- /dev/null
+++ b/include/lbann/utils/dnn_lib/upsample.hpp
@@ -0,0 +1,41 @@
+////////////////////////////////////////////////////////////////////////////////
+// Copyright (c) 2014-2024, Lawrence Livermore National Security, LLC.
+// Produced at the Lawrence Livermore National Laboratory.
+// Written by the LBANN Research Team (B. Van Essen, et al.) listed in
+// the CONTRIBUTORS file.
+//
+// LLNL-CODE-697807.
+// All rights reserved.
+//
+// This file is part of LBANN: Livermore Big Artificial Neural Network
+// Toolkit. For details, see http://software.llnl.gov/LBANN or
+// https://github.com/LLNL/LBANN.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you
+// may not use this file except in compliance with the License. You may
+// obtain a copy of the License at:
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+// implied. See the License for the specific language governing
+// permissions and limitations under the License.
+////////////////////////////////////////////////////////////////////////////////
+#ifndef LBANN_UTILS_DNN_LIB_UPSAMPLE_HPP
+#define LBANN_UTILS_DNN_LIB_UPSAMPLE_HPP
+
+#include "lbann_config.hpp"
+
+#if defined LBANN_HAS_CUDNN
+#include "lbann/utils/dnn_lib/cudnn/upsample.hpp"
+#elif defined LBANN_HAS_MIOPEN
+#include "lbann/utils/dnn_lib/miopen/upsample.hpp"
+#else
+static_assert(false,
+              "This file must be included only if there is support from a "
+              "valid DNN library.");
+#endif // LBANN_HAS_CUDNN
+
+#endif // LBANN_UTILS_DNN_LIB_UPSAMPLE_HPP
diff --git a/src/base.cpp b/src/base.cpp
index 1d41324238d..3b00c67b269 100644
--- a/src/base.cpp
+++ b/src/base.cpp
@@ -522,6 +522,7 @@ CEREAL_FORCE_DYNAMIC_INIT(tessellate_layer);
 CEREAL_FORCE_DYNAMIC_INIT(top_k_categorical_accuracy_layer);
 CEREAL_FORCE_DYNAMIC_INIT(uniform_layer);
 CEREAL_FORCE_DYNAMIC_INIT(unpooling_layer);
+CEREAL_FORCE_DYNAMIC_INIT(upsample_layer);
 CEREAL_FORCE_DYNAMIC_INIT(variance_layer);
 CEREAL_FORCE_DYNAMIC_INIT(weighted_sum_layer);
 CEREAL_FORCE_DYNAMIC_INIT(weights_layer);
diff --git a/src/layers/transform/CMakeLists.txt b/src/layers/transform/CMakeLists.txt
index e843eb22277..2c1caf5de0d 100644
--- a/src/layers/transform/CMakeLists.txt
+++ b/src/layers/transform/CMakeLists.txt
@@ -54,6 +54,7 @@ set_full_path(THIS_DIR_SOURCES
   tessellate.cpp
   uniform.cpp
   unpooling.cpp
+  upsample.cpp
   weighted_sum.cpp
   weights.cpp
diff --git a/src/layers/transform/cereal_registration/CMakeLists.txt b/src/layers/transform/cereal_registration/CMakeLists.txt
index 425d82b5bc3..f246776eb58 100644
--- a/src/layers/transform/cereal_registration/CMakeLists.txt
+++ b/src/layers/transform/cereal_registration/CMakeLists.txt
@@ -52,6 +52,7 @@ set_full_path(THIS_DIR_SOURCES
   tessellate.cpp
   uniform.cpp
   unpooling.cpp
+  upsample.cpp
   weighted_sum.cpp
   weights.cpp
)
diff --git a/src/layers/transform/cereal_registration/upsample.cpp b/src/layers/transform/cereal_registration/upsample.cpp
new file mode 100644
index 00000000000..cd088b01986
--- /dev/null
+++ b/src/layers/transform/cereal_registration/upsample.cpp
@@ -0,0 +1,45 @@
+////////////////////////////////////////////////////////////////////////////////
+// Copyright (c) 2014-2024, Lawrence Livermore National Security, LLC.
+// Produced at the Lawrence Livermore National Laboratory.
+// Written by the LBANN Research Team (B. Van Essen, et al.) listed in
+// the CONTRIBUTORS file.
+//
+// LLNL-CODE-697807.
+// All rights reserved.
+//
+// This file is part of LBANN: Livermore Big Artificial Neural Network
+// Toolkit. For details, see http://software.llnl.gov/LBANN or
+// https://github.com/LLNL/LBANN.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you
+// may not use this file except in compliance with the License. You may
+// obtain a copy of the License at:
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+// implied. See the License for the specific language governing
+// permissions and limitations under the License.
+////////////////////////////////////////////////////////////////////////////////
+#include "lbann/utils/serialize.hpp"
+#include <lbann/layers/transform/upsample.hpp>
+
+namespace lbann {
+
+template <typename TensorDataType, data_layout Layout, El::Device Device>
+template <typename ArchiveT>
+void upsample_layer<TensorDataType, Layout, Device>::serialize(ArchiveT& ar)
+{
+  using DataTypeLayer = data_type_layer<TensorDataType>;
+  ar(::cereal::make_nvp("DataTypeLayer",
+                        ::cereal::base_class<DataTypeLayer>(this)),
+     CEREAL_NVP(m_upsample_mode),
+     CEREAL_NVP(m_scale_factors));
+}
+
+} // namespace lbann
+
+#define LBANN_LAYER_NAME upsample_layer
+#include <lbann/macros/register_layer_with_cereal.hpp>
diff --git a/src/layers/transform/upsample.cpp b/src/layers/transform/upsample.cpp
new file mode 100644
index 00000000000..4d6778347b7
--- /dev/null
+++ b/src/layers/transform/upsample.cpp
@@ -0,0 +1,349 @@
+////////////////////////////////////////////////////////////////////////////////
+// Copyright (c) 2014-2024, Lawrence Livermore National Security, LLC.
+// Produced at the Lawrence Livermore National Laboratory.
+// Written by the LBANN Research Team (B. Van Essen, et al.) listed in
+// the CONTRIBUTORS file.
+//
+// LLNL-CODE-697807.
+// All rights reserved.
+//
+// This file is part of LBANN: Livermore Big Artificial Neural Network
+// Toolkit. For details, see http://software.llnl.gov/LBANN or
+// https://github.com/LLNL/LBANN.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you
+// may not use this file except in compliance with the License. You may
+// obtain a copy of the License at:
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+// implied. See the License for the specific language governing
+// permissions and limitations under the License.
+////////////////////////////////////////////////////////////////////////////////
+
+#define LBANN_UPSAMPLE_LAYER_INSTANTIATE
+#include "lbann/layers/transform/upsample.hpp"
+
+#include "lbann/execution_algorithms/execution_context.hpp"
+#include "lbann/proto/datatype_helpers.hpp"
+#include "lbann/proto/proto_common.hpp"
+#include "lbann/utils/protobuf.hpp"
+
+#ifdef LBANN_HAS_DISTCONV
+#include "lbann/layers/data_type_distconv_adapter.hpp"
+using dc_backend = ::distconv::GPUDNNBackend;
+#endif // LBANN_HAS_DISTCONV
+
+#include "lbann/proto/layers.pb.h"
+#include "lbann/proto/lbann.pb.h"
+
+namespace lbann {
+namespace {
+
+template <typename TensorDataType, data_layout L, El::Device D>
+struct Builder
+{
+  template <typename... Args>
+  static std::unique_ptr<Layer> Build(Args&&...)
+  {
+    LBANN_ERROR("Attempted to instantiate layer \"upsample\" with "
+                "Layout=",
+                to_string(L),
+                ", Device=",
+                El::DeviceName<D>(),
+                ".\nThis layer is only "
+                "supported on GPU with DATA_PARALLEL data layout.");
+    return nullptr;
+  }
+};
+
+#ifdef LBANN_HAS_GPU
+template <typename TensorDataType>
+struct Builder<TensorDataType, data_layout::DATA_PARALLEL, El::Device::GPU>
+{
+  template <typename... Args>
+  static std::unique_ptr<Layer> Build(Args&&... args)
+  {
+    using LayerType = upsample_layer<TensorDataType,
+                                     data_layout::DATA_PARALLEL,
+                                     El::Device::GPU>;
+    return std::make_unique<LayerType>(std::forward<Args>(args)...);
+  }
+};
+#endif // LBANN_HAS_GPU
+} // namespace
+
+template <typename TensorDataType, data_layout T_layout, El::Device Dev>
+void upsample_layer<TensorDataType, T_layout, Dev>::fp_compute()
+{
+  if (this->using_gpus()) {
+#ifdef LBANN_HAS_DISTCONV
+    if (this->distconv_enabled()) {
+      const auto& mode =
+        this->m_model->get_execution_context().get_execution_mode();
+      get_distconv_adapter().fp_compute(mode == execution_mode::training);
+      return;
+    }
+#endif // LBANN_HAS_DISTCONV
+    fp_compute_dnn();
+  }
+  else {
+    LBANN_ERROR("Upsampling with CPU is not implemented.");
+  }
+}
+
+template <typename TensorDataType, data_layout T_layout, El::Device Dev>
+void upsample_layer<TensorDataType, T_layout, Dev>::bp_compute()
+{
+  if (this->using_gpus()) {
+#ifdef LBANN_HAS_DISTCONV
+    if (this->distconv_enabled()) {
+      get_distconv_adapter().bp_compute();
+      return;
+    }
+#endif // LBANN_HAS_DISTCONV
+    bp_compute_dnn();
+  }
+  else {
+    LBANN_ERROR("Upsampling with CPU is not implemented.");
+  }
+}
+
+/// Upsample forward propagation with DNN library
+template <typename TensorDataType, data_layout T_layout, El::Device Dev>
+void upsample_layer<TensorDataType, T_layout, Dev>::fp_compute_dnn()
+{
+#ifndef LBANN_HAS_DNN_LIB
+  LBANN_ERROR("DNN library not detected");
+#else
+  using ScalingType = dnn_lib::ScalingParamType<TensorDataType>;
+  const auto& local_input = this->get_local_prev_activations();
+  auto& local_output = this->get_local_activations();
+  if (local_input.Height() > 0 && local_input.Width() > 0) {
+    const auto zero = El::TypeTraits<ScalingType>::Zero();
+    const auto alpha = El::To<ScalingType>(get_linear_size(m_scale_factors));
+    dnn_lib::upsample_nearest_forward(m_pooling_dnn_desc,
+                                      alpha,
+                                      m_tensors_dnn_desc.get_prev_activations(),
+                                      local_input,
+                                      zero,
+                                      m_tensors_dnn_desc.get_activations(),
+                                      local_output);
+  }
+#endif // #ifndef LBANN_HAS_DNN_LIB
+}
+
+/// Upsample backward propagation with DNN library
+template <typename TensorDataType, data_layout T_layout, El::Device Dev>
+void upsample_layer<TensorDataType, T_layout, Dev>::bp_compute_dnn()
+{
+#ifndef LBANN_HAS_DNN_LIB
+  LBANN_ERROR("DNN library not detected");
+#else
+  using ScalingType = dnn_lib::ScalingParamType<TensorDataType>;
+  const auto& local_gradient_wrt_output = this->get_local_prev_error_signals();
+  auto& local_gradient_wrt_input = this->get_local_error_signals();
+  if (local_gradient_wrt_output.Height() > 0 &&
+      local_gradient_wrt_output.Width() > 0) {
+
+    // Useful constants
+    const auto alpha = El::To<ScalingType>(get_linear_size(m_scale_factors));
+    const auto zero = El::TypeTraits<ScalingType>::Zero();
+
+    // Perform backprop on GPU
+    dnn_lib::upsample_nearest_backward(
+      m_pooling_dnn_desc,
+      alpha,
+      m_tensors_dnn_desc.get_prev_error_signals(),
+      local_gradient_wrt_output,
+      zero,
+      m_tensors_dnn_desc.get_error_signals(),
+      local_gradient_wrt_input);
+  }
+#endif // #ifndef LBANN_HAS_DNN_LIB
+}
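+
+// Note: In both DNN-library paths above, alpha = get_linear_size(
+// m_scale_factors) is the pooling window volume, which cancels the
+// 1/volume factor that average pooling applies. The net effect is a pure
+// nearest-neighbor copy in the forward pass and a block-wise gradient sum
+// in the backward pass.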
+
+template <typename TensorDataType, data_layout T_layout, El::Device Dev>
+void upsample_layer<TensorDataType, T_layout, Dev>::write_specific_proto(
+  lbann_data::Layer& proto) const
+{
+  proto.set_datatype(proto::ProtoDataType<TensorDataType>);
+  auto* msg = proto.mutable_upsample();
+  switch (m_upsample_mode) {
+  case upsample_mode::NEAREST:
+    msg->set_upsample_mode("nearest");
+    break;
+  default:
+    LBANN_ERROR("Invalid upsample mode requested.");
+  }
+  msg->set_num_dims(m_scale_factors.size());
+  msg->set_has_vectors(true);
+  protobuf::assign_to_repeated(*msg->mutable_scale_factors(), m_scale_factors);
+}
+
+#ifdef LBANN_HAS_DISTCONV
+template <typename TensorDataType, data_layout T_layout, El::Device Dev>
+upsample_distconv_adapter<TensorDataType, T_layout, Dev>&
+upsample_layer<TensorDataType, T_layout, Dev>::get_distconv_adapter()
+{
+  return const_cast<upsample_distconv_adapter<TensorDataType, T_layout, Dev>&>(
+    static_cast<const upsample_layer<TensorDataType, T_layout, Dev>&>(*this)
+      .get_distconv_adapter());
+}
+
+template <typename TensorDataType, data_layout T_layout, El::Device Dev>
+const upsample_distconv_adapter<TensorDataType, T_layout, Dev>&
+upsample_layer<TensorDataType, T_layout, Dev>::get_distconv_adapter() const
+{
+  return dynamic_cast<
+    const upsample_distconv_adapter<TensorDataType, T_layout, Dev>&>(
+    data_type_layer<TensorDataType>::get_distconv_adapter());
+}
+
+template <typename TensorDataType, data_layout T_layout, El::Device Dev>
+bool upsample_layer<TensorDataType, T_layout, Dev>::is_distconv_supported()
+  const
+{
+  return Dev == El::Device::GPU && T_layout == data_layout::DATA_PARALLEL;
+}
+
+template <typename TensorDataType, data_layout T_layout, El::Device Dev>
+dc::Shape upsample_distconv_adapter<TensorDataType, T_layout, Dev>::
+  get_activations_local_shape(int index) const
+{
+  assert_eq(index, 0);
+  const auto& layer =
+    dynamic_cast<const upsample_layer<TensorDataType, T_layout, Dev>&>(
+      this->layer());
+  auto scale_factors = layer.m_scale_factors;
+  // Distconv shapes are ordered fastest-varying dimension first, while
+  // m_scale_factors is ordered slowest first, hence the reversal.
+  std::reverse(std::begin(scale_factors), std::end(scale_factors));
+  auto output_spatial_local_shape =
+    this->get_prev_activations(index).get_local_shape();
+  for (size_t i = 0; i < scale_factors.size(); i++) {
+    output_spatial_local_shape[i] *= scale_factors[i];
+  }
+  return output_spatial_local_shape;
+}
+
+template <typename TensorDataType, data_layout T_layout, El::Device Dev>
+void upsample_distconv_adapter<TensorDataType, T_layout, Dev>::setup_layer(
+  size_t workspace_capacity)
+{
+  m_xdesc.create();
+  m_ydesc.create();
+  m_dxdesc.create();
+  m_dydesc.create();
+
+  auto& l = dynamic_cast<upsample_layer<TensorDataType, T_layout, Dev>&>(
+    this->layer());
+
+  // Only nearest-neighbor mode has a DISTCONV implementation.
+  switch (l.m_upsample_mode) {
+  case upsample_mode::NEAREST:
+    break;
+  default:
+    LBANN_ERROR("upsample_layer: no DISTCONV implementation for upsample mode");
+  }
+}
+
+template <typename TensorDataType, data_layout T_layout, El::Device Dev>
+void upsample_distconv_adapter<TensorDataType, T_layout, Dev>::fp_compute(
+  bool const training)
+{
+  auto& l = dynamic_cast<upsample_layer<TensorDataType, T_layout, Dev>&>(
+    this->layer());
+
+  auto& prev_activations = this->get_prev_activations();
+  auto& activations = this->get_activations();
+
+  auto xdesc = const_cast<dc_backend::TensorDescriptor_t>(m_xdesc.get());
+  auto ydesc = const_cast<dc_backend::TensorDescriptor_t>(m_ydesc.get());
+  dc_backend::setup_tensor_descriptor(xdesc,
+                                      prev_activations,
+                                      prev_activations.get_local_shape());
+  dc_backend::setup_tensor_descriptor(ydesc,
+                                      activations,
+                                      activations.get_local_shape());
+
+  using ScalingType = dnn_lib::ScalingParamType<TensorDataType>;
+  const auto zero = El::TypeTraits<ScalingType>::Zero();
+  const auto alpha = El::To<ScalingType>(get_linear_size(l.m_scale_factors));
+
+  dnn_lib::upsample_nearest_forward(l.m_pooling_dnn_desc,
+                                    alpha,
+                                    m_xdesc,
+                                    prev_activations.get_const_base_ptr(),
+                                    zero,
+                                    m_ydesc,
+                                    activations.get_base_ptr(),
+                                    dc::get_backend().get_handle());
+}
+
+template <typename TensorDataType, data_layout T_layout, El::Device Dev>
+void upsample_distconv_adapter<TensorDataType, T_layout, Dev>::bp_compute()
+{
+  auto& l = dynamic_cast<upsample_layer<TensorDataType, T_layout, Dev>&>(
+    this->layer());
+
+  auto& prev_error_signals = this->get_prev_error_signals();
+  auto& error_signals = this->get_error_signals();
+
+  auto dxdesc = const_cast<dc_backend::TensorDescriptor_t>(m_dxdesc.get());
+  auto dydesc = const_cast<dc_backend::TensorDescriptor_t>(m_dydesc.get());
+  dc_backend::setup_tensor_descriptor(dxdesc,
+                                      error_signals,
+                                      error_signals.get_local_shape());
+  dc_backend::setup_tensor_descriptor(dydesc,
+                                      prev_error_signals,
+                                      prev_error_signals.get_local_shape());
+
+  using ScalingType = dnn_lib::ScalingParamType<TensorDataType>;
+  const auto zero = El::TypeTraits<ScalingType>::Zero();
+  const auto alpha = El::To<ScalingType>(get_linear_size(l.m_scale_factors));
+
+  dnn_lib::upsample_nearest_backward(l.m_pooling_dnn_desc,
+                                     alpha,
+                                     m_dydesc,
+                                     prev_error_signals.get_const_base_ptr(),
+                                     zero,
+                                     m_dxdesc,
+                                     error_signals.get_base_ptr(),
+                                     dc::get_backend().get_handle());
+}
+#endif // LBANN_HAS_DISTCONV
+
+template <typename TensorDataType, data_layout Layout, El::Device Device>
+std::unique_ptr<Layer>
+build_upsample_layer_from_pbuf(lbann_comm* comm,
+                               lbann_data::Layer const& proto_layer)
+{
+  LBANN_ASSERT_MSG_HAS_FIELD(proto_layer, upsample);
+
+  using BuilderType = Builder<TensorDataType, Layout, Device>;
+  const auto& params = proto_layer.upsample();
+  upsample_mode const mode = to_upsample_mode(params.upsample_mode());
+  if (params.has_vectors()) {
+    return BuilderType::Build(comm,
+                              params.scale_factors_size(),
+                              protobuf::to_vector<int>(params.scale_factors()),
+                              mode);
+  }
+  else {
+    return BuilderType::Build(comm,
+                              params.num_dims(),
+                              params.scale_factors_i(),
+                              mode);
+  }
+}
+
+#define PROTO_DEVICE(T, Device)                                                \
+  template class upsample_layer<T, data_layout::DATA_PARALLEL, Device>;        \
+  LBANN_LAYER_BUILDER_ETI(upsample, T, Device)
+
+#include "lbann/macros/instantiate_device.hpp"
+
+} // namespace lbann
diff --git a/src/proto/factories/layer_factory.cpp b/src/proto/factories/layer_factory.cpp
index add478c1686..e1577f08521 100644
--- a/src/proto/factories/layer_factory.cpp
+++ b/src/proto/factories/layer_factory.cpp
@@ -160,6 +160,7 @@ class factory_manager
     LBANN_REGISTER_BUILDER(Tessellate, tessellate);
     LBANN_REGISTER_BUILDER(Uniform, uniform);
     LBANN_REGISTER_BUILDER(Unpooling, unpooling);
+    LBANN_REGISTER_BUILDER(Upsample, upsample);
     LBANN_REGISTER_BUILDER(WeightedSum, weighted_sum);
     LBANN_REGISTER_BUILDER(WeightsLayer, weights);
diff --git a/src/proto/layers.proto b/src/proto/layers.proto
index 12a936a766c..9e726ec188b 100644
--- a/src/proto/layers.proto
+++ b/src/proto/layers.proto
@@ -169,6 +169,7 @@ message Layer {
     TensorPermute permute = 77;
     IdentityZero identity_zero = 78;
     MultiDimReduction multidim_reduction = 79;
+    Upsample upsample = 80;
 
     CategoricalRandom categorical_random = 406; // Deprecated
     DiscreteRandom discrete_random = 407; // Deprecated
@@ -1058,6 +1059,47 @@
   */
   message IdentityZero {}
 
+  /** @brief Upsample
+   *
+   *  Spatially upsamples a tensor.
+   */
+  message Upsample {
+    /** @brief Upsample operation
+     *
+     *  Options: nearest
+     */
+    string upsample_mode = 1;
+
+    /** @brief Number of spatial dimensions
+     *
+     *  The first data dimension is treated as the channel dimension,
+     *  and all others are treated as spatial dimensions (recall that
+     *  the mini-batch dimension is implicit).
+     */
+    int64 num_dims = 2;
+
+    /** @brief Whether to use vector-valued options
+     *
+     *  If true, the scale factors are taken from the repeated
+     *  @c scale_factors field. Otherwise, @c scale_factors_i is
+     *  applied to every spatial dimension.
+     */
+    bool has_vectors = 3;
+
+    /** @brief Upsampling scale factors (vector-valued)
+     *
+     *  List of integers, one for each spatial
+     *  dimension. Used when @c has_vectors is enabled.
+     */
+    repeated int64 scale_factors = 4;
+
+    /** @brief Upsampling scale factor (integer-valued)
+     *
+     *  Applied to every spatial dimension. Used when @c has_vectors
+     *  is disabled.
+     */
+    int64 scale_factors_i = 5;
+  }
+
   /// Deprecated
   message CategoricalRandom {}
   /// Deprecated
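For completeness, a minimal sketch of the new layer as the unit tests above exercise it through the Python frontend:

import lbann

x = lbann.Input(data_field="samples")
x = lbann.Reshape(x, dims=[6, 11, 7])  # (channels, height, width)
y = lbann.Upsample(
    x,
    num_dims=2,
    has_vectors=True,
    scale_factors=[2, 2],
    upsample_mode="nearest",
)  # output shape: (6, 22, 14)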