diff --git a/.gitignore b/.gitignore index 5c5c3148ef19a..fbd3ed029dfc9 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,16 @@ +# aten/src/THH/ +# c10/hip +# aten/src/ATen/hip +# aten/src/ATen/native/hip +# aten/src/ATen/native/cudnn/hip +# aten/src/ATen/native/nested/hip +# aten/src/ATen/native/quantized/cudnn/hip +# aten/src/ATen/native/quantized/hip +# aten/src/ATen/native/transformers/hip +# aten/src/ATen/test/hip +# aten/src/ATen/test/test_install/hip +# binaries/hip +# aten/src/ATen/native/sparse/hip/ # READ THIS BEFORE YOU REFACTOR ME # # setup.py uses the list of patterns in this file to decide diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000000000..e15c98b3a0ba0 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,164 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "bf16 nhwc v2", + "type": "debugpy", + "request": "launch", + "pythonArgs": ["-u"], + "cwd": "${workspaceFolder}/test", + "program": "test_nn.py", + "console": "integratedTerminal", + "args": [ + "-v", + "-k", + "test_batchnorm_nhwc_miopen_cuda_bfloat16" + ], + "env": { + "MIOPEN_ENABLE_LOGGING_CMD": "1", + "PYTORCH_MIOPEN_EXTRA_LOGGING": "1", + "PYTORCH_MIOPEN_USE_API_V2": "1", + // "MIOPEN_LOG_LEVEL": "6", + // "MIOPEN_ENABLE_LOGGING": "1", + // "AMD_LOG_LEVEL": "6", + } + }, + { + "name": "bf16 nhwc v1", + "type": "debugpy", + "request": "launch", + "pythonArgs": ["-u"], + "cwd": "${workspaceFolder}/test", + "program": "test_nn.py", + "console": "integratedTerminal", + "args": [ + "-v", + "-k", + "test_batchnorm_nhwc_miopen_cuda_bfloat16" + ], + "env": { + "MIOPEN_ENABLE_LOGGING_CMD": "1", + "PYTORCH_MIOPEN_EXTRA_LOGGING": "1", + // "PYTORCH_MIOPEN_USE_API_V2": "1", + // "MIOPEN_LOG_LEVEL": "6", + // "MIOPEN_ENABLE_LOGGING": "1", + // "AMD_LOG_LEVEL": "6", + } + }, + { + "name": "bf16 nchw", + "type": "debugpy", + "request": "launch", + "pythonArgs": ["-u"], + "cwd": "${workspaceFolder}/test", + "program": "test_nn.py", + "console": "integratedTerminal", + "args": [ + "-v", + "-k", + "test_batchnorm_nchw_miopen_cuda_bfloat16" + ], + "env": { + "MIOPEN_ENABLE_LOGGING_CMD": "1", + "PYTORCH_MIOPEN_EXTRA_LOGGING": "1", + "PYTORCH_MIOPEN_USE_API_V2": "1", + // "MIOPEN_LOG_LEVEL": "6", + // "MIOPEN_ENABLE_LOGGING": "1", + // "AMD_LOG_LEVEL": "6", + } + }, + { + "name": "fp16 nhwc", + "type": "debugpy", + "request": "launch", + "cwd": "${workspaceFolder}/test", + "program": "test_nn.py", + "console": "integratedTerminal", + "args": [ + "-v", + "-k", + "test_batchnorm_nhwc_miopen_cuda_float16" + ], + "env": { + "MIOPEN_ENABLE_LOGGING_CMD": "1", + "PYTORCH_MIOPEN_EXTRA_LOGGING": "1", + "PYTORCH_MIOPEN_USE_API_V2": "1", + } + }, + { + "name": "fp16 nchw", + "type": "debugpy", + "request": "launch", + "cwd": "${workspaceFolder}/test", + "program": "test_nn.py", + "console": "integratedTerminal", + "args": [ + "-v", + "-k", + "test_batchnorm_nchw_miopen_cuda_float16" + ], + "env": { + "MIOPEN_ENABLE_LOGGING_CMD": "1", + "PYTORCH_MIOPEN_EXTRA_LOGGING": "1", + "PYTORCH_MIOPEN_USE_API_V2": "1", + } + }, + { + "name": "fp32 nChw", + "type": "debugpy", + "request": "launch", + "cwd": "${workspaceFolder}/test", + "program": "test_nn.py", + "console": "integratedTerminal", + "args": [ + "-v", + "-k", + "test_batchnorm_nchw_miopen_cuda_float32", + ], + "env": { + "MIOPEN_ENABLE_LOGGING_CMD": "1", + "PYTORCH_MIOPEN_EXTRA_LOGGING": "1", + "PYTORCH_MIOPEN_USE_API_V2": "1", + } + }, + { + "name": "fp32 nHwc", + "type": "debugpy", + "request": "launch", + "cwd": "${workspaceFolder}/test", + "program": "test_nn.py", + "console": "integratedTerminal", + "args": [ + "-v", + "-k", + "test_batchnorm_nhwc_miopen_cuda_float32" + ], + "env": { + "MIOPEN_ENABLE_LOGGING_CMD": "1", + "PYTORCH_MIOPEN_EXTRA_LOGGING": "1", + "PYTORCH_MIOPEN_USE_API_V2": "1", + } + }, + { + "name": "eval", + "type": "debugpy", + "request": "launch", + "cwd": "${workspaceFolder}/test", + "program": "test_nn.py", + "console": "integratedTerminal", + "args": [ + "-v", + "-k", + "test_batchnorm_nhwc_cuda" + ], + "env": { + "MIOPEN_ENABLE_LOGGING_CMD": "1", + "PYTORCH_MIOPEN_EXTRA_LOGGING": "1", + "PYTORCH_MIOPEN_USE_API_V2": "1", + } + } + ] +} \ No newline at end of file diff --git a/aten/src/ATen/native/Normalization.cpp b/aten/src/ATen/native/Normalization.cpp index 0d2b6bfced09e..247fb6b44f12e 100644 --- a/aten/src/ATen/native/Normalization.cpp +++ b/aten/src/ATen/native/Normalization.cpp @@ -61,6 +61,7 @@ #include #include #include +#include static const int MIOPEN_DIM_MAX = 5; @@ -153,7 +154,6 @@ std::tuple batch_norm_cpu_transform_input_template( } return std::make_tuple(output, save_mean, save_invstd); } - const int64_t ndim = input.dim(); // Helper to convert 1d tensors to an nd tensor that broadcasts with input // All elements go into the channel dimension @@ -484,10 +484,13 @@ std::tuple batch_norm_backward_cpu_template( return std::make_tuple(grad_input, grad_weight, grad_bias); } +bool PYTORCH_MIOPEN_EXTRA_LOGGING = c10::utils::check_env("PYTORCH_MIOPEN_EXTRA_LOGGING").value_or(false); + BatchNormBackend _select_batch_norm_backend( const Tensor& input, const Tensor& weight, const Tensor& bias, const Tensor& running_mean, const Tensor& running_var, bool training, double eps) { - + if (at::native::PYTORCH_MIOPEN_EXTRA_LOGGING) + std :: cout << "********************* _select_batch_norm_backend" << std::endl; auto& ctx = at::globalContext(); bool cudnn_enabled = ctx.userEnabledCuDNN(); @@ -514,25 +517,44 @@ BatchNormBackend _select_batch_norm_backend( // See #64427 // non static variable is used to be able to change environment variable in runtime for testing bool PYTORCH_MIOPEN_SUGGEST_NHWC = c10::utils::check_env("PYTORCH_MIOPEN_SUGGEST_NHWC").value_or(false); - + + if (PYTORCH_MIOPEN_EXTRA_LOGGING) + std::cout << "**+** SUGGEST_NHWC=" << PYTORCH_MIOPEN_SUGGEST_NHWC + << " cudnn_enabled=" << cudnn_enabled + << " dim=" << input.dim() + << " memory_format=" << input.suggest_memory_format() + << " input.dtype=" << input.scalar_type() + << " weight.dtype=" << (weight.defined()?"+":"-") << weight.scalar_type() + << " bias.dtype=" << (bias.defined()?"+":"-") << bias.scalar_type() + << " running_mean.dtype=" << (running_mean.defined()?"+":"-") << running_mean.scalar_type() + << " running_var.dtype=" << (running_mean.defined()?"+":"-") << running_mean.scalar_type() + << " training=" << training + << std::endl; if ( input.is_cuda() - && input.dim() <= MIOPEN_DIM_MAX - && input.scalar_type() != at::kDouble - && input.scalar_type() != at::kBFloat16 - && (weight.scalar_type() != at::kHalf) - && weight.defined() && bias.defined() - && ((running_mean.defined() && running_var.defined()) - || (!running_mean.defined() && !running_var.defined() && training)) - && (input.dim() >= 3) && detail::getCUDAHooks().compiledWithMIOpen() && cudnn_enabled - && (input.suggest_memory_format() == MemoryFormat::Contiguous - || (input.suggest_memory_format() == MemoryFormat::ChannelsLast && PYTORCH_MIOPEN_SUGGEST_NHWC)) + && input.dim() <= MIOPEN_DIM_MAX + && (input.dim() >= 3) + && + ( + (input.scalar_type() == at::kFloat && input.suggest_memory_format() == MemoryFormat::Contiguous && weight.scalar_type() == at::kFloat) + || + (input.scalar_type() == at::kFloat && input.suggest_memory_format() == MemoryFormat::ChannelsLast && PYTORCH_MIOPEN_SUGGEST_NHWC && weight.scalar_type() == at::kFloat) + || + (input.scalar_type() == at::kHalf) // && input.suggest_memory_format() == MemoryFormat::ChannelsLast /* && weight.scalar_type() == at::kFloat*/) + || + (input.scalar_type() == at::kBFloat16) // && input.suggest_memory_format() == MemoryFormat::ChannelsLast && PYTORCH_MIOPEN_SUGGEST_NHWC && weight.scalar_type() == at::kBFloat16) + ) + && weight.defined() && bias.defined() + && ((running_mean.defined() && running_var.defined()) || (!running_mean.defined() && !running_var.defined() && training)) ) { + if (PYTORCH_MIOPEN_EXTRA_LOGGING) + std::cout << "***** BatchNormBackend::Miopen" << std::endl; return BatchNormBackend::Miopen; } - + if (PYTORCH_MIOPEN_EXTRA_LOGGING) + std::cout << "***** BatchNormBackend::Native" << std::endl; return BatchNormBackend::Native; } @@ -546,6 +568,20 @@ std::tuple _batch_norm_impl_index( const Tensor& input, const std::optional& weight_opt /* optional */, const std::optional& bias_opt /* optional */, const std::optional& running_mean_opt /* optional */, const std::optional& running_var_opt /* optional */, bool training, double momentum, double eps, bool cudnn_enabled) { // See [Note: hacky wrapper removal for optional tensor] + if (PYTORCH_MIOPEN_EXTRA_LOGGING) + std :: cout + << "********************* _batch_norm_impl_index" + << " input=" << input.scalar_type() + << " weight=" << (weight_opt.has_value() ? weight_opt.value().scalar_type() : at::ScalarType::Undefined) + << " bias=" << (bias_opt.has_value() ? bias_opt.value().scalar_type() : at::ScalarType::Undefined) + << " running_mean=" << (running_mean_opt.has_value() ? running_mean_opt.value().scalar_type() : at::ScalarType::Undefined) + << " running_var=" << (running_var_opt.has_value() ? running_var_opt.value().scalar_type() : at::ScalarType::Undefined) + << " training=" << training + // << " momentum=" << momentum + // << " eps=" << eps + << " cudnn_enabled=" << cudnn_enabled + << std::endl; + c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); const Tensor& weight = *weight_maybe_owned; const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();}); @@ -605,10 +641,12 @@ std::tuple _batch_norm_impl_index( Tensor reserve = at::empty({0}, input.options().dtype(kByte)); - if (backend == BatchNormBackend::Miopen) { - return std::tuple_cat( + if (backend == BatchNormBackend::Miopen) { + return std::tuple_cat( at::miopen_batch_norm( - input.contiguous(input.suggest_memory_format()), weight.contiguous(), bias.contiguous(), + input.contiguous(input.suggest_memory_format()), + weight.contiguous(), + bias.contiguous(), running_mean.defined() ? running_mean.contiguous() : running_mean, running_var.defined() ? running_var.contiguous() : running_var, training, momentum, eps), @@ -625,9 +663,17 @@ std::tuple _batch_norm_impl_index( std::tuple _batch_norm_impl_index_backward( int64_t impl_index, - const Tensor& input, const Tensor& grad_output, const std::optional& weight_opt /* optional */, const std::optional& running_mean_opt /* optional */, const std::optional& running_var_opt /* optional */, const std::optional& save_mean_opt /* optional */, const std::optional& save_var_transform_opt /* optional */, + const Tensor& input, + const Tensor& grad_output, + const std::optional& weight_opt /* optional */, + const std::optional& running_mean_opt /* optional */, + const std::optional& running_var_opt /* optional */, + const std::optional& save_mean_opt /* optional */, + const std::optional& save_var_transform_opt /* optional */, bool train, double epsilon, std::array output_mask, const Tensor &reservedSpace) { // See [Note: hacky wrapper removal for optional tensor] + if (PYTORCH_MIOPEN_EXTRA_LOGGING) + std :: cout << "********************* _batch_norm_impl_index_backward" << std::endl; c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); const Tensor& weight = *weight_maybe_owned; const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();}); @@ -674,6 +720,20 @@ Tensor batch_norm( const Tensor& input, const std::optional& weight_opt, const std::optional& bias_opt, const std::optional& running_mean_opt, const std::optional& running_var_opt, bool training, double momentum, double eps, bool cudnn_enabled) { + if (PYTORCH_MIOPEN_EXTRA_LOGGING) + std :: cout + << "********************* batch_norm" + << " input=" << input.scalar_type() + << " weight=" << (weight_opt.has_value() ? weight_opt.value().scalar_type() : at::ScalarType::Undefined) + << " bias=" << (bias_opt.has_value() ? bias_opt.value().scalar_type() : at::ScalarType::Undefined) + << " running_mean=" << (running_mean_opt.has_value() ? running_mean_opt.value().scalar_type() : at::ScalarType::Undefined) + << " running_var=" << (running_var_opt.has_value() ? running_var_opt.value().scalar_type() : at::ScalarType::Undefined) + << " training=" << training + // << " momentum=" << momentum + // << " eps=" << eps + << " cudnn_enabled=" << cudnn_enabled + << std::endl; + const Tensor& weight = c10::value_or_else(weight_opt, [] {return Tensor();}); const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();}); const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();}); diff --git a/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp b/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp index 5491f85b5d184..5beb5360f5d18 100644 --- a/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp +++ b/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp @@ -7,6 +7,7 @@ #include #else #include +#include #include #include #endif @@ -15,6 +16,8 @@ // don't build this file as part of CPU build. #include +#include + #if !AT_ROCM_ENABLED() namespace at { namespace native { @@ -57,10 +60,33 @@ Tensor expandScale(const Tensor& t, int64_t dim) { } // namespace -std::tuple miopen_batch_norm( +bool PYTORCH_MIOPEN_EXTRA_LOGGING = c10::utils::check_env("PYTORCH_MIOPEN_EXTRA_LOGGING").value_or(false); +bool PYTORCH_MIOPEN_USE_API_V2 = c10::utils::check_env("PYTORCH_MIOPEN_USE_API_V2").value_or(false); +bool PYTORCH_MIOPEN_BATCHNORM_ENABLE_CK = c10::utils::check_env("PYTORCH_MIOPEN_BATCHNORM_ENABLE_CK").value_or(false); + +miopenBatchNormMode_t getMiopenBatchNormMode(const Tensor& t) +{ + return (t.dim() == 2) ? miopenBNPerActivation : miopenBNSpatial; +} + +std::tuple miopen_batch_norm_train_forward_v2( const Tensor& input_t, const Tensor& weight_t, const std::optional& bias_t_opt, const std::optional& running_mean_t_opt, const std::optional& running_var_t_opt, bool training, double exponential_average_factor, double epsilon) { + if (PYTORCH_MIOPEN_EXTRA_LOGGING) + std::cout + << "$$$$$ miopen_batch_norm_train_forward V2" + << " input_t=" << input_t.scalar_type() + << " weight_t=" << weight_t.scalar_type() + << " bias_t_opt=" << (bias_t_opt.has_value() ? bias_t_opt.value().scalar_type() : at::ScalarType::Undefined) + << " running_mean_t_opt=" << (running_mean_t_opt.has_value() ? running_mean_t_opt.value().scalar_type() : at::ScalarType::Undefined) + << " running_var_t_opt=" << (running_var_t_opt.has_value() ? running_var_t_opt.value().scalar_type() : at::ScalarType::Undefined) + << " training=" << training + // << " exponential_average_factor=" << exponential_average_factor + // << " epsilon=" << epsilon + << std::endl; + + const bool use_CK = (input_t.scalar_type() == at::kBFloat16 || input_t.scalar_type() == at::kHalf);// && PYTORCH_MIOPEN_BATCHNORM_USE_CK; // See [Note: hacky wrapper removal for optional tensor] c10::MaybeOwned bias_t_maybe_owned = at::borrow_from_optional_tensor(bias_t_opt); const Tensor& bias_t = *bias_t_maybe_owned; @@ -70,19 +96,306 @@ std::tuple miopen_batch_norm( TensorArg input{ input_t, "input", 1 }, weight{ weight_t, "weight", 2 }, bias{ bias_t, "bias", 3 }, - running_mean{ running_mean_t, "running_mean", 4 }, - running_var{ running_var_t, "running_var", 5 }; + running_mean{ /*use_CK ? running_mean_t.to(at::kFloat) :*/ running_mean_t, "running_mean", 4 }, + running_var{ /*use_CK? running_var_t.to(at::kFloat) :*/ running_var_t, "running_var", 5 }; CheckedFrom c = "miopen_batch_norm"; - checkAllDefined(c, {input, weight, bias}); - if (!training) { - checkAllDefined(c, {running_mean, running_var}); + // if (!training) { + // checkAllDefined(c, {running_mean, running_var}); + // } + checkAllSameGPU(c, {input, weight, bias, running_mean, running_var}); + // if (input->scalar_type() != ScalarType::Half || input->scalar_type() != ScalarType::BFloat16) { + checkAllSameType(c, {input, weight}); + // } + // checkAllSameType(c, {weight, bias, running_mean, running_var}); + checkAllContiguous(c, {weight, bias, running_mean, running_var}); + TORCH_CHECK(input->is_contiguous(input->suggest_memory_format())); + checkDimRange(c, input, 2, 6 /* exclusive */); + auto num_features = input->size(1); + for (auto t : {weight, bias, running_mean, running_var}) { + if (t->defined()) { + checkNumel(c, t, num_features); + } } + + miopenBatchNormMode_t mode = getMiopenBatchNormMode(input_t); + + auto output_t = at::empty(input->sizes(), input->options(), input->suggest_memory_format()); + TensorArg output{ output_t, "output", 0 }; + + auto dataType = getMiopenDataType(*input); + Constant one(dataType, 1); + Constant zero(dataType, 0); + Tensor save_mean, save_var; + + if (use_CK) + { + save_mean = at::empty(num_features, at::kFloat, weight_t.layout(), weight_t.device(), weight_t.is_pinned(), weight_t.suggest_memory_format()); + save_var = at::empty(num_features, at::kFloat, weight_t.layout(), weight_t.device(), weight_t.is_pinned(), weight_t.suggest_memory_format()); + } + else + { + save_mean = at::empty({ num_features }, weight_t.options()); + save_var = at::empty({ num_features }, weight_t.options()); + } + + auto handle = getMiopenHandle(); + + if (PYTORCH_MIOPEN_EXTRA_LOGGING) + std::cout << "##### miopenBatchNormalizationForward V2 Training " + << " use_CK=" << use_CK + << " training=" << training + << " mode=" << mode + << " input=" << input->scalar_type() // in + << " output=" << output->scalar_type() // out + << " weight=" << weight->scalar_type() // in + << " bias=" << bias->scalar_type() // in + // << " eaf=" << exponential_average_factor + << " running_mean=" << running_mean->scalar_type() // out + << " running_var=" << running_var->scalar_type() // out + // << " epsilon=" << epsilon + << " save_mean=" << save_mean.scalar_type() // out + << " save_var=" << save_var.scalar_type() // out + << std::endl; + // std::cout << "*** XXXXXXXXX INPUT miopenBatchNormalizationForward running_mean = " << running_mean->data() << std::endl; + TensorDescriptor idesc{ *input, 4 }; // input descriptor + TensorDescriptor wdesc{ expandScale(*weight, input->dim()), 4 }; // descriptor for weight, bias, running_mean, etc. + TensorDescriptor sdesc{ expandScale(save_mean, input->dim()), 4}; + + MIOPEN_CHECK(miopenBatchNormalizationForwardTraining_V2( + handle, mode, &one, &zero, + idesc.desc(), input->const_data_ptr(), + idesc.desc(), output->data_ptr(), + wdesc.desc(), // weight + wdesc.desc(), // bias + sdesc.desc(), // saved_mean + sdesc.desc(), // saved_var + // NOTE: MIOpen docs say that the bnScale and bnBias args are only inputs, + // not outputs. However, unfortunately the function signature only takes + // non-const pointers, presumably by accident + const_cast(weight->const_data_ptr()), + const_cast(bias->const_data_ptr()), + exponential_average_factor, + at::maybe_data_ptr(running_mean), + at::maybe_data_ptr(running_var), + epsilon, + save_mean.mutable_data_ptr(), + save_var.mutable_data_ptr())); + // std::cout << "*** XXXXXXX OUTPUT miopenBatchNormalizationForward running_mean = " << running_mean->data() << std::endl; + // save_mean and save_var can be undefined + // If this causes problems, we can initialize them to empty tensors + // of the correct type + if (PYTORCH_MIOPEN_EXTRA_LOGGING) + std::cout << "##### miopenBatchNormalizationForward V2 RETURN" + << " training=" << training + << " output=" << output->scalar_type() + << " save_mean=" << save_mean.scalar_type() + << " save_var=" << save_var.scalar_type() + << std::endl; + // if (use_CK) + // { + // std::cout << "##### miopenBatchNormalizationForward RETURN convert to " << input->scalar_type() << std::endl; + // return std::tuple{output_t, save_mean.to(input->scalar_type()), save_var.to(input->scalar_type())}; + // } + // else + return std::tuple{output_t, save_mean, save_var}; +} + +std::tuple miopen_batch_norm_train_forward( + const Tensor& input_t, const Tensor& weight_t, const std::optional& bias_t_opt, const std::optional& running_mean_t_opt, const std::optional& running_var_t_opt, + bool training, double exponential_average_factor, double epsilon) +{ + if (PYTORCH_MIOPEN_EXTRA_LOGGING) + std::cout + << "$$$$$ miopen_batch_norm_train_forward" + << " input_t=" << input_t.scalar_type() + << " weight_t=" << weight_t.scalar_type() + << " bias_t_opt=" << (bias_t_opt.has_value() ? bias_t_opt.value().scalar_type() : at::ScalarType::Undefined) + << " running_mean_t_opt=" << (running_mean_t_opt.has_value() ? running_mean_t_opt.value().scalar_type() : at::ScalarType::Undefined) + << " running_var_t_opt=" << (running_var_t_opt.has_value() ? running_var_t_opt.value().scalar_type() : at::ScalarType::Undefined) + << " training=" << training + // << " exponential_average_factor=" << exponential_average_factor + // << " epsilon=" << epsilon + << std::endl; + + const bool use_CK = (input_t.scalar_type() == at::kBFloat16 || input_t.scalar_type() == at::kHalf); + // See [Note: hacky wrapper removal for optional tensor] + c10::MaybeOwned bias_t_maybe_owned = at::borrow_from_optional_tensor(bias_t_opt); + const Tensor& bias_t = *bias_t_maybe_owned; + const Tensor& running_mean_t = c10::value_or_else(running_mean_t_opt, [] {return Tensor();}); + const Tensor& running_var_t = c10::value_or_else(running_var_t_opt, [] {return Tensor();}); + + // if (input_t.scalar_type() != ScalarType::Half || input_t.scalar_type() != ScalarType::BFloat16) + + TensorArg input{ input_t, "input", 1 }, + weight{ weight_t, "weight", 2 }, + bias{ bias_t, "bias", 3 }, + running_mean{ /*use_CK ? running_mean_t.to(at::kFloat) :*/ running_mean_t, "running_mean", 4 }, + running_var{ /*use_CK? running_var_t.to(at::kFloat) :*/ running_var_t, "running_var", 5 }; + CheckedFrom c = "miopen_batch_norm"; + checkAllDefined(c, {input, weight, bias}); + // if (!training) { + // checkAllDefined(c, {running_mean, running_var}); + // } checkAllSameGPU(c, {input, weight, bias, running_mean, running_var}); - if (input->scalar_type() != ScalarType::Half) { - checkAllSameType(c, {input, weight}); + // if (input->scalar_type() != ScalarType::Half || input->scalar_type() != ScalarType::BFloat16) { + checkAllSameType(c, {input, weight}); + // } + // checkAllSameType(c, {weight, bias, running_mean, running_var}); + checkAllContiguous(c, {weight, bias, running_mean, running_var}); + TORCH_CHECK(input->is_contiguous(input->suggest_memory_format())); + checkDimRange(c, input, 2, 6 /* exclusive */); + auto num_features = input->size(1); + for (auto t : {weight, bias, running_mean, running_var}) { + if (t->defined()) { + checkNumel(c, t, num_features); + } } - checkAllSameType(c, {weight, bias, running_mean, running_var}); + + miopenBatchNormMode_t mode = getMiopenBatchNormMode(input_t); + + auto output_t = at::empty(input->sizes(), input->options(), input->suggest_memory_format()); + TensorArg output{ output_t, "output", 0 }; + + auto handle = getMiopenHandle(); + auto dataType = getMiopenDataType(*input); + TensorDescriptor idesc{ *input, 4 }; // input descriptor + TensorDescriptor wdesc{ expandScale(*weight, input->dim()), 4 }; // descriptor for weight, bias, running_mean, etc. + + Constant one(dataType, 1); + Constant zero(dataType, 0); + Tensor save_mean, save_var; + + // int64_t num_features = input_t.size(1); + + if (use_CK) + { + save_mean = at::empty(num_features, at::kFloat, weight_t.layout(), weight_t.device(), weight_t.is_pinned(), weight_t.suggest_memory_format()); + save_var = at::empty(num_features, at::kFloat, weight_t.layout(), weight_t.device(), weight_t.is_pinned(), weight_t.suggest_memory_format()); + } + else + { + save_mean = at::empty({ num_features }, weight_t.options()); + save_var = at::empty({ num_features }, weight_t.options()); + } + if (PYTORCH_MIOPEN_EXTRA_LOGGING) + std::cout << "##### miopenBatchNormalizationForward Training " + << " use_CK=" << use_CK + << " training=" << training + << " mode=" << mode + << " input=" << input->scalar_type() // in + << " output=" << output->scalar_type() // out + << " weight=" << weight->scalar_type() // in + << " bias=" << bias->scalar_type() // in + // << " eaf=" << exponential_average_factor + << " running_mean=" << running_mean->scalar_type() // out + << " running_var=" << running_var->scalar_type() // out + // << " epsilon=" << epsilon + << " save_mean=" << save_mean.scalar_type() // out + << " save_var=" << save_var.scalar_type() // out + << std::endl; + // std::cout << "*** XXXXXXXXX INPUT miopenBatchNormalizationForward running_mean = " << running_mean->data() << std::endl; + MIOPEN_CHECK(miopenBatchNormalizationForwardTraining( + handle, mode, &one, &zero, + idesc.desc(), input->const_data_ptr(), + idesc.desc(), output->data_ptr(), + wdesc.desc(), + // NOTE: MIOpen docs say that the bnScale and bnBias args are only inputs, + // not outputs. However, unfortunately the function signature only takes + // non-const pointers, presumably by accident + const_cast(weight->const_data_ptr()), + const_cast(bias->const_data_ptr()), + exponential_average_factor, + at::maybe_data_ptr(running_mean), + at::maybe_data_ptr(running_var), + epsilon, + save_mean.mutable_data_ptr(), + save_var.mutable_data_ptr())); + // std::cout << "*** XXXXXXX OUTPUT miopenBatchNormalizationForward running_mean = " << running_mean->data() << std::endl; + // save_mean and save_var can be undefined + // If this causes problems, we can initialize them to empty tensors + // of the correct type + if (PYTORCH_MIOPEN_EXTRA_LOGGING) + std::cout << "##### miopenBatchNormalizationForward RETURN" + << " training=" << training + << " output=" << output->scalar_type() + << " save_mean=" << save_mean.scalar_type() + << " save_var=" << save_var.scalar_type() + << std::endl; + // if (use_CK) + // { + // std::cout << "##### miopenBatchNormalizationForward RETURN convert to " << input->scalar_type() << std::endl; + // return std::tuple{output_t, save_mean.to(input->scalar_type()), save_var.to(input->scalar_type())}; + // } + // else + return std::tuple{output_t, save_mean, save_var}; +} + +std::tuple miopen_batch_norm_inference( + const Tensor& input_t, const Tensor& weight_t, const std::optional& bias_t_opt, const std::optional& running_mean_t_opt, const std::optional& running_var_t_opt, + bool training, double exponential_average_factor, double epsilon) +{ + if (PYTORCH_MIOPEN_EXTRA_LOGGING) + std::cout + << "$$$$$ miopen_batch_norm_inference" + << " input_t=" << input_t.scalar_type() + << " weight_t=" << weight_t.scalar_type() + << " bias_t_opt=" << (bias_t_opt.has_value() ? bias_t_opt.value().scalar_type() : at::ScalarType::Undefined) + << " running_mean_t_opt=" << (running_mean_t_opt.has_value() ? running_mean_t_opt.value().scalar_type() : at::ScalarType::Undefined) + << " running_var_t_opt=" << (running_var_t_opt.has_value() ? running_var_t_opt.value().scalar_type() : at::ScalarType::Undefined) + << " training=" << training + // << " exponential_average_factor=" << exponential_average_factor + // << " epsilon=" << epsilon + << std::endl; + + const bool use_CK = (input_t.scalar_type() == at::kBFloat16 || input_t.scalar_type() == at::kHalf); + // See [Note: hacky wrapper removal for optional tensor] + c10::MaybeOwned bias_t_maybe_owned = at::borrow_from_optional_tensor(bias_t_opt); + const Tensor& bias_t = *bias_t_maybe_owned; + const Tensor& running_mean_t = c10::value_or_else(running_mean_t_opt, [] {return Tensor();}); + const Tensor& running_var_t = c10::value_or_else(running_var_t_opt, [] {return Tensor();}); + + TensorArg input{ input_t, "input", 1 }, + weight{ weight_t, "weight", 2 }, + bias{ bias_t, "bias", 3 }, + running_mean{ /*use_CK ? running_mean_t.to(at::kFloat):*/running_mean_t, "running_mean", 4 }, + running_var{ /*use_CK ? running_var_t.to(at::kFloat):*/running_var_t, "running_var", 5 }; + CheckedFrom c = "miopen_batch_norm"; + + if (PYTORCH_MIOPEN_EXTRA_LOGGING) + std::cout << "$$$$$XXXXX" + << " training=" << training + << " dim=" << input->dim() + << " memory_format=" << input->suggest_memory_format() + << "\ninput[" + << " dtype=" << input->scalar_type() + << " sizes=" << input->sizes() + << " strides=" << input->strides() + << " ]\nweight[" + << " dtype=" << weight->scalar_type() + << " sizes=" << weight->sizes() + << " strides=" << weight->strides() + << " ]\nbias[" + << " dtype=" << bias->scalar_type() + << " sizes=" << bias->sizes() + << " strides=" << bias->strides() + << " ]\nrunning_mean[" + << " dtype=" << running_mean->scalar_type() + << " sizes=" << running_mean->sizes() + << " strides=" << running_mean->strides() + << " ]\nrunning_var[" + << " dtype=" << running_var->scalar_type() + << " sizes=" << running_var->sizes() + << " strides=" << running_var->strides() + + << std::endl; + checkAllDefined(c, {input, weight, bias}); + checkAllDefined(c, {running_mean, running_var}); + checkAllSameGPU(c, {input, weight, bias, running_mean, running_var}); + // if (input->scalar_type() != ScalarType::Half || input->scalar_type() != ScalarType::BFloat16) { + checkAllSameType(c, {input, weight}); + // } + // checkAllSameType(c, {weight, bias, running_mean, running_var}); checkAllContiguous(c, {weight, bias, running_mean, running_var}); TORCH_CHECK(input->is_contiguous(input->suggest_memory_format())); checkDimRange(c, input, 2, 6 /* exclusive */); @@ -93,13 +406,161 @@ std::tuple miopen_batch_norm( } } - miopenBatchNormMode_t mode; - if (input->dim() == 2) { - mode = miopenBNPerActivation; - } else { - mode = miopenBNSpatial; + auto mode= getMiopenBatchNormMode(input_t); + + auto output_t = at::empty(input->sizes(), input->options(), input->suggest_memory_format()); + TensorArg output{ output_t, "output", 0 }; + + auto handle = getMiopenHandle(); + auto dataType = getMiopenDataType(*input); + TensorDescriptor idesc{ *input, 4 }; // input descriptor + TensorDescriptor wdesc{ expandScale(*weight, input->dim()), 4 }; // descriptor for weight, bias, running_mean, etc. + + Constant one(dataType, 1); + Constant zero(dataType, 0); + Tensor save_mean, save_var; + save_mean = at::empty({ num_features }, weight_t.options()); + save_var = at::empty({ num_features }, weight_t.options()); + if (use_CK) /* && input->suggest_memory_format() == MemoryFormat::ChannelsLast */ + { + save_mean = save_mean.to(at::kFloat); + save_var = save_var.to(at::kFloat); + } + + if (PYTORCH_MIOPEN_EXTRA_LOGGING) + { + std::cout << "##### INPUT miopenBatchNormalizationForward running_mean = " << (float*)running_mean->data_ptr() << std::endl; + std::cout << "##### miopenBatchNormalizationForward Inference " + << " use_CK=" << use_CK + << " training=" << training + << " mode=" << mode + << " input=" << input->scalar_type() + << " output=" << output->scalar_type() + << " weight=" << weight->scalar_type() + << " bias=" << bias->scalar_type() + // << " eaf=" << exponential_average_factor + << " running_mean=" << running_mean->scalar_type() + << " running_var=" << running_var->scalar_type() + // << " epsilon=" << epsilon + << " save_mean=" << save_mean.scalar_type() + << " save_var=" << save_var.scalar_type() + << std::endl; + } + MIOPEN_CHECK(miopenBatchNormalizationForwardInference( + handle, mode, &one, &zero, + idesc.desc(), input->const_data_ptr(), // in + idesc.desc(), output->data_ptr(), // out + wdesc.desc(), + // NOTE: MIOpen docs say that the bnScale and bnBias args are only inputs, + // not outputs. However, unfortunately the function signature only takes + // non-const pointers, presumably by accident + const_cast(weight->const_data_ptr()), // in + const_cast(bias->const_data_ptr()), // in + running_mean->data_ptr(), // in + running_var->data_ptr(), // in + epsilon)); + + // save_mean and save_var can be undefined + // If this causes problems, we can initialize them to empty tensors + // of the correct type + if (PYTORCH_MIOPEN_EXTRA_LOGGING) + { + std::cout << "#####*** OUTPUT miopenBatchNormalizationForward running_mean = " << "AAA" /*(float*)running_mean->data_ptr()*/ << std::endl; + std::cout << "#####XXXXX miopenBatchNormalizationForward RETURN" + << " training=" << training + << " output=" << output->scalar_type() + << " save_mean=" << save_mean.scalar_type() + << " save_var=" << save_var.scalar_type() + << std::endl; + } + if (use_CK) + { + if (PYTORCH_MIOPEN_EXTRA_LOGGING) + std::cout << "##### miopenBatchNormalizationForward Inference RETURN convert to " << input->scalar_type() << std::endl; + return std::tuple{output_t, save_mean.to(input->scalar_type()), save_var.to(input->scalar_type())}; + } + else + return std::tuple{output_t, save_mean, save_var}; + +} + +std::tuple miopen_batch_norm_inference_v2( + const Tensor& input_t, const Tensor& weight_t, const std::optional& bias_t_opt, const std::optional& running_mean_t_opt, const std::optional& running_var_t_opt, + bool training, double exponential_average_factor, double epsilon) +{ + if (PYTORCH_MIOPEN_EXTRA_LOGGING) + std::cout + << "$$$$$ miopen_batch_norm_inference V2" + << " input_t=" << input_t.scalar_type() + << " weight_t=" << weight_t.scalar_type() + << " bias_t_opt=" << (bias_t_opt.has_value() ? bias_t_opt.value().scalar_type() : at::ScalarType::Undefined) + << " running_mean_t_opt=" << (running_mean_t_opt.has_value() ? running_mean_t_opt.value().scalar_type() : at::ScalarType::Undefined) + << " running_var_t_opt=" << (running_var_t_opt.has_value() ? running_var_t_opt.value().scalar_type() : at::ScalarType::Undefined) + << " training=" << training + // << " exponential_average_factor=" << exponential_average_factor + // << " epsilon=" << epsilon + << std::endl; + + const bool use_CK = (input_t.scalar_type() == at::kBFloat16 || input_t.scalar_type() == at::kHalf); + // See [Note: hacky wrapper removal for optional tensor] + c10::MaybeOwned bias_t_maybe_owned = at::borrow_from_optional_tensor(bias_t_opt); + const Tensor& bias_t = *bias_t_maybe_owned; + const Tensor& running_mean_t = c10::value_or_else(running_mean_t_opt, [] {return Tensor();}); + const Tensor& running_var_t = c10::value_or_else(running_var_t_opt, [] {return Tensor();}); + + TensorArg input{ input_t, "input", 1 }, + weight{ weight_t, "weight", 2 }, + bias{ bias_t, "bias", 3 }, + running_mean{ /*use_CK ? running_mean_t.to(at::kFloat):*/running_mean_t, "running_mean", 4 }, + running_var{ /*use_CK ? running_var_t.to(at::kFloat):*/running_var_t, "running_var", 5 }; + CheckedFrom c = "miopen_batch_norm"; + + if (PYTORCH_MIOPEN_EXTRA_LOGGING) + std::cout << "$$$$$XXXXX V2" + << " training=" << training + << " dim=" << input->dim() + << " memory_format=" << input->suggest_memory_format() + << "\ninput[" + << " dtype=" << input->scalar_type() + << " sizes=" << input->sizes() + << " strides=" << input->strides() + << " ]\nweight[" + << " dtype=" << weight->scalar_type() + << " sizes=" << weight->sizes() + << " strides=" << weight->strides() + << " ]\nbias[" + << " dtype=" << bias->scalar_type() + << " sizes=" << bias->sizes() + << " strides=" << bias->strides() + << " ]\nrunning_mean[" + << " dtype=" << running_mean->scalar_type() + << " sizes=" << running_mean->sizes() + << " strides=" << running_mean->strides() + << " ]\nrunning_var[" + << " dtype=" << running_var->scalar_type() + << " sizes=" << running_var->sizes() + << " strides=" << running_var->strides() + + << std::endl; + checkAllDefined(c, {input, weight, bias}); + checkAllDefined(c, {running_mean, running_var}); + checkAllSameGPU(c, {input, weight, bias, running_mean, running_var}); + // if (input->scalar_type() != ScalarType::Half || input->scalar_type() != ScalarType::BFloat16) { + checkAllSameType(c, {input, weight}); + // } + // checkAllSameType(c, {weight, bias, running_mean, running_var}); + checkAllContiguous(c, {weight, bias, running_mean, running_var}); + TORCH_CHECK(input->is_contiguous(input->suggest_memory_format())); + checkDimRange(c, input, 2, 6 /* exclusive */); + auto num_features = input->size(1); + for (auto t : {weight, bias, running_mean, running_var}) { + if (t->defined()) { + checkNumel(c, t, num_features); + } } + auto mode= getMiopenBatchNormMode(input_t); + auto output_t = at::empty(input->sizes(), input->options(), input->suggest_memory_format()); TensorArg output{ output_t, "output", 0 }; @@ -107,53 +568,109 @@ std::tuple miopen_batch_norm( auto dataType = getMiopenDataType(*input); TensorDescriptor idesc{ *input, 4 }; // input descriptor TensorDescriptor wdesc{ expandScale(*weight, input->dim()), 4 }; // descriptor for weight, bias, running_mean, etc. + TensorDescriptor rdesc{ expandScale(*running_mean, input->dim()), 4 }; Constant one(dataType, 1); Constant zero(dataType, 0); Tensor save_mean, save_var; + save_mean = at::empty({ num_features }, weight_t.options()); + save_var = at::empty({ num_features }, weight_t.options()); + if (use_CK) /* && input->suggest_memory_format() == MemoryFormat::ChannelsLast */ + { + save_mean = save_mean.to(at::kFloat); + save_var = save_var.to(at::kFloat); + } - if (training) { - int64_t num_features = input_t.size(1); - save_mean = at::empty({ num_features }, weight_t.options()); - save_var = at::empty({ num_features }, weight_t.options()); - MIOPEN_CHECK(miopenBatchNormalizationForwardTraining( - handle, mode, &one, &zero, - idesc.desc(), input->const_data_ptr(), - idesc.desc(), output->data_ptr(), - wdesc.desc(), - // NOTE: MIOpen docs say that the bnScale and bnBias args are only inputs, - // not outputs. However, unfortunately the function signature only takes - // non-const pointers, presumably by accident - const_cast(weight->const_data_ptr()), - const_cast(bias->const_data_ptr()), - exponential_average_factor, - at::maybe_data_ptr(running_mean), - at::maybe_data_ptr(running_var), - epsilon, - save_mean.mutable_data_ptr(), - save_var.mutable_data_ptr())); - } else { - save_mean = at::empty({0}, weight_t.options()); - save_var = at::empty({0}, weight_t.options()); - MIOPEN_CHECK(miopenBatchNormalizationForwardInference( - handle, mode, &one, &zero, - idesc.desc(), input->const_data_ptr(), - idesc.desc(), output->data_ptr(), - wdesc.desc(), - // NOTE: MIOpen docs say that the bnScale and bnBias args are only inputs, - // not outputs. However, unfortunately the function signature only takes - // non-const pointers, presumably by accident - const_cast(weight->const_data_ptr()), - const_cast(bias->const_data_ptr()), - running_mean->data_ptr(), - running_var->data_ptr(), - epsilon)); + if (PYTORCH_MIOPEN_EXTRA_LOGGING) + { + std::cout << "##### INPUT miopenBatchNormalizationForward running_mean = " << (float*)running_mean->data_ptr() << std::endl; + std::cout << "##### miopenBatchNormalizationForward Inference " + << " use_CK=" << use_CK + << " training=" << training + << " mode=" << mode + << " input=" << input->scalar_type() + << " output=" << output->scalar_type() + << " weight=" << weight->scalar_type() + << " bias=" << bias->scalar_type() + // << " eaf=" << exponential_average_factor + << " running_mean=" << running_mean->scalar_type() + << " running_var=" << running_var->scalar_type() + // << " epsilon=" << epsilon + << " save_mean=" << save_mean.scalar_type() + << " save_var=" << save_var.scalar_type() + << std::endl; } + MIOPEN_CHECK(miopenBatchNormalizationForwardInference_V2( + handle, mode, &one, &zero, + idesc.desc(), input->const_data_ptr(), // in + idesc.desc(), output->data_ptr(), // out + wdesc.desc(), // weight + wdesc.desc(), // bias + rdesc.desc(), // running_mean + rdesc.desc(), // running_var + // NOTE: MIOpen docs say that the bnScale and bnBias args are only inputs, + // not outputs. However, unfortunately the function signature only takes + // non-const pointers, presumably by accident + const_cast(weight->const_data_ptr()), // in + const_cast(bias->const_data_ptr()), // in + running_mean->data_ptr(), // in + running_var->data_ptr(), // in + epsilon)); // save_mean and save_var can be undefined // If this causes problems, we can initialize them to empty tensors // of the correct type - return std::tuple{output_t, save_mean, save_var}; + if (PYTORCH_MIOPEN_EXTRA_LOGGING) + { + std::cout << "#####*** OUTPUT miopenBatchNormalizationForward running_mean = " << "AAA" /*(float*)running_mean->data_ptr()*/ << std::endl; + std::cout << "#####XXXXX miopenBatchNormalizationForward RETURN" + << " training=" << training + << " output=" << output->scalar_type() + << " save_mean=" << save_mean.scalar_type() + << " save_var=" << save_var.scalar_type() + << std::endl; + } + if (use_CK) + { + if (PYTORCH_MIOPEN_EXTRA_LOGGING) + std::cout << "##### miopenBatchNormalizationForward Inference RETURN convert to " << input->scalar_type() << std::endl; + return std::tuple{output_t, save_mean.to(input->scalar_type()), save_var.to(input->scalar_type())}; + } + else + return std::tuple{output_t, save_mean, save_var}; + +} + +std::tuple miopen_batch_norm( + const Tensor& input_t, const Tensor& weight_t, const std::optional& bias_t_opt, const std::optional& running_mean_t_opt, const std::optional& running_var_t_opt, + bool training, double exponential_average_factor, double epsilon) +{ + if (PYTORCH_MIOPEN_EXTRA_LOGGING) + std::cout + << "$$$$$ miopen_batch_norm" + << " input_t=" << input_t.scalar_type() + << " weight_t=" << weight_t.scalar_type() + << " bias_t_opt=" << (bias_t_opt.has_value() ? bias_t_opt.value().scalar_type() : at::ScalarType::Undefined) + << " running_mean_t_opt=" << (running_mean_t_opt.has_value() ? running_mean_t_opt.value().scalar_type() : at::ScalarType::Undefined) + << " running_var_t_opt=" << (running_var_t_opt.has_value() ? running_var_t_opt.value().scalar_type() : at::ScalarType::Undefined) + << " training=" << training + // << " exponential_average_factor=" << exponential_average_factor + // << " epsilon=" << epsilon + << std::endl; + + if (training) + return PYTORCH_MIOPEN_USE_API_V2? + miopen_batch_norm_train_forward_v2(input_t, weight_t, bias_t_opt, running_mean_t_opt, running_var_t_opt, + training, exponential_average_factor, epsilon) + : miopen_batch_norm_train_forward(input_t, weight_t, bias_t_opt, running_mean_t_opt, running_var_t_opt, + training, exponential_average_factor, epsilon); + else + return PYTORCH_MIOPEN_USE_API_V2? + miopen_batch_norm_inference_v2(input_t, weight_t, bias_t_opt, running_mean_t_opt, running_var_t_opt, + training, exponential_average_factor, epsilon) + : miopen_batch_norm_inference(input_t, weight_t, bias_t_opt, running_mean_t_opt, running_var_t_opt, + training, exponential_average_factor, epsilon); + } std::tuple miopen_batch_norm_backward( @@ -167,6 +684,21 @@ std::tuple miopen_batch_norm_backward( const std::optional& save_mean_t_opt, const std::optional& save_var_t_opt, double epsilon) { + + const bool use_CK = (input_t.scalar_type() == at::kBFloat16 || input_t.scalar_type() == at::kHalf); + if (PYTORCH_MIOPEN_EXTRA_LOGGING) + std::cout + << "$$$$$ miopen_batch_norm_backward" + << " use_CK=" << use_CK + << " input_t=" << input_t.scalar_type() + << " grad_output_t=" << grad_output_t.scalar_type() + << " weight_t=" << weight_t.scalar_type() + << " running_mean_opt=" << (running_mean_opt.has_value() ? running_mean_opt.value().scalar_type() : at::ScalarType::Undefined) + << " running_var_opt=" << (running_var_opt.has_value() ? running_var_opt.value().scalar_type() : at::ScalarType::Undefined) + << " save_mean_t_opt=" << (save_mean_t_opt.has_value() ? save_mean_t_opt.value().scalar_type() : at::ScalarType::Undefined) + << " save_var_t_opt=" << (save_var_t_opt.has_value() ? save_var_t_opt.value().scalar_type() : at::ScalarType::Undefined) + // << " epsilon=" << epsilon + << std::endl; // See [Note: hacky wrapper removal for optional tensor] const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] { return Tensor(); }); @@ -177,8 +709,36 @@ std::tuple miopen_batch_norm_backward( const Tensor& save_var_t = c10::value_or_else(save_var_t_opt, [] { return Tensor(); }); - auto grad_output_contig = - grad_output_t.contiguous(input_t.suggest_memory_format()); + // auto grad_output_contig = + // grad_output_t.contiguous(input_t.suggest_memory_format()); + + at::Tensor grad_input_t, grad_weight_t, grad_bias_t, grad_output_contig; + + if (use_CK /* && input_t.suggest_memory_format() == MemoryFormat::ChannelsLast */) + { + if (PYTORCH_MIOPEN_EXTRA_LOGGING) + std::cout << "##### miopen_batch_norm_backward (BF16/FP16 NHWC)" + << " input_t=" << input_t.scalar_type() << " : " // << (c10::MemoryFormat) input_t.suggest_memory_format() + << " weight_t=" << weight_t.scalar_type() << " : " // << (c10::MemoryFormat) weight_t.suggest_memory_format() + << std::endl; + grad_input_t = at::empty(input_t.sizes(), at::kFloat, input_t.layout(), input_t.device(), input_t.is_pinned(), input_t.suggest_memory_format()); + grad_weight_t = at::empty(weight_t.sizes(), at::kFloat, weight_t.layout(), weight_t.device(), weight_t.is_pinned(), MemoryFormat::Contiguous); + grad_bias_t = at::empty(weight_t.sizes(), at::kFloat, weight_t.layout(), weight_t.device(), weight_t.is_pinned(), MemoryFormat::Contiguous); + grad_output_contig = grad_output_t.to(at::kFloat).contiguous(input_t.suggest_memory_format()); + } + else + { + if (PYTORCH_MIOPEN_EXTRA_LOGGING) + std::cout << "##### miopen_batch_norm_backward non (BF16/FP16 NHWC)" + << " input_t=" << input_t.scalar_type() << " : " // << (c10::MemoryFormat) input_t.suggest_memory_format() + << " weight_t=" << weight_t.scalar_type() << " : " // << (c10::MemoryFormat) weight_t.suggest_memory_format() + << std::endl; + grad_input_t = at::empty(input_t.sizes(), input_t.scalar_type(), input_t.layout(), input_t.device(), input_t.is_pinned(), input_t.suggest_memory_format()); + grad_weight_t = at::empty(weight_t.sizes(), weight_t.options()); + grad_bias_t = at::empty(weight_t.sizes(), weight_t.options()); + grad_output_contig = grad_output_t.contiguous(input_t.suggest_memory_format()); + } + TensorArg input{ input_t, "input", 1 }, grad_output{ grad_output_contig, "grad_output", 2 }, weight{ weight_t, "weight", 3 }, @@ -188,13 +748,13 @@ std::tuple miopen_batch_norm_backward( checkAllDefined(c, {input, grad_output, weight, save_mean, save_var}); checkAllSameGPU(c, {input, grad_output, weight, save_mean, save_var}); - if (input->scalar_type() == ScalarType::Half) { - checkScalarType(c, weight, ScalarType::Float); - } else { - checkAllSameType(c, {input, weight}); - } - checkAllSameType(c, {input, grad_output}); - checkAllSameType(c, {weight, save_mean, save_var}); + // // if (input->scalar_type() == ScalarType::Half) { + // // checkScalarType(c, weight, ScalarType::Float); + // // } else { + checkAllSameType(c, {input, weight}); + // // } + // checkAllSameType(c, {input, grad_output}); + // checkAllSameType(c, {weight, save_mean, save_var}); checkAllContiguous(c, {save_mean, save_var}); TORCH_CHECK(input->is_contiguous(input->suggest_memory_format())); TORCH_CHECK(grad_output->is_contiguous(input->suggest_memory_format())); @@ -212,33 +772,85 @@ std::tuple miopen_batch_norm_backward( mode = miopenBNSpatial; } - auto grad_input_t = at::empty( - input->sizes(), input->options(), input->suggest_memory_format()); - auto grad_weight_t = at::empty(weight->sizes(), weight->options()); - auto grad_bias_t = at::empty(weight->sizes(), weight->options()); - auto handle = getMiopenHandle(); auto dataType = getMiopenDataType(*input); - TensorDescriptor idesc{ *input, 4 }; // input, output, grad_output descriptor - TensorDescriptor wdesc{ expandScale(*weight, input->dim()), 4 }; // descriptor for weight, bias, save_mean, etc. - + Constant one(dataType, 1); Constant zero(dataType, 0); + if (PYTORCH_MIOPEN_EXTRA_LOGGING) + std::cout + << "##### miopenBatchNormalizationBackward " + << " mode=" << mode + << " input=" << input->scalar_type() + << " grad_output=" << grad_output->scalar_type() + << " grad_input=" << grad_input_t.scalar_type() + << " weight=" << weight->scalar_type() + << " grad_weight=" << grad_weight_t.scalar_type() + << " grad_bias=" << grad_bias_t.scalar_type() + // << " epsilon=" << epsilon + << " save_mean=" << save_mean->scalar_type() + << " save_var=" << save_var->scalar_type() + << std::endl; + if (PYTORCH_MIOPEN_USE_API_V2) + { + TensorDescriptor inputdesc{ *input, 4 }; // input, output, grad_output descriptor + TensorDescriptor gradoutdesc{ *grad_output, 4 }; // grad_input descriptor + TensorDescriptor gradinputdesc{ grad_input_t, 4 }; // grad_input descriptor + TensorDescriptor weightdesc{ expandScale(*weight, input->dim()), 4 }; // descriptor for weight, bias, save_mean, etc. + TensorDescriptor biasgraddesc { expandScale(grad_bias_t, input->dim()), 4 }; + TensorDescriptor sdesc { expandScale(*save_mean, input->dim()), 4 }; + + MIOPEN_CHECK(miopenBatchNormalizationBackward_V2( + handle, mode, &one, &zero, &one, &zero, + inputdesc.desc(), input->const_data_ptr(), + gradoutdesc.desc(), grad_output->const_data_ptr(), + gradinputdesc.desc(), grad_input_t.data_ptr(), + weightdesc.desc(), // weight + biasgraddesc.desc(), // grad bias + sdesc.desc(), // saved mean + sdesc.desc(), // saved var + weight->const_data_ptr(), + grad_weight_t.data_ptr(), + grad_bias_t.data_ptr(), + epsilon, + save_mean->const_data_ptr(), + save_var->const_data_ptr())); + } + else{ + TensorDescriptor inputdesc{ *input, 4 }; // input, output, grad_output descriptor + TensorDescriptor gradoutdesc{ *grad_output, 4 }; // grad_input descriptor + TensorDescriptor gradinputdesc{ grad_input_t, 4 }; // grad_input descriptor + TensorDescriptor weightdesc{ expandScale(*weight, input->dim()), 4 }; // descriptor for weight, bias, save_mean, etc. + // TensorDescriptor biasgraddesc { expandScale(grad_bias_t, input->dim()), 4 }; + // TensorDescriptor sdesc { expandScale(*save_mean, input->dim()), 4 }; + MIOPEN_CHECK(miopenBatchNormalizationBackward( handle, mode, &one, &zero, &one, &zero, - idesc.desc(), input->const_data_ptr(), - idesc.desc(), grad_output->const_data_ptr(), - idesc.desc(), grad_input_t.data_ptr(), - wdesc.desc(), weight->const_data_ptr(), + inputdesc.desc(), input->const_data_ptr(), + gradoutdesc.desc(), grad_output->const_data_ptr(), + gradinputdesc.desc(), grad_input_t.data_ptr(), + weightdesc.desc(), weight->const_data_ptr(), grad_weight_t.data_ptr(), grad_bias_t.data_ptr(), epsilon, save_mean->const_data_ptr(), save_var->const_data_ptr())); - - return std::tuple{grad_input_t, grad_weight_t, grad_bias_t}; + } + if (PYTORCH_MIOPEN_EXTRA_LOGGING) + std::cout << "##### miopenBatchNormalizationBackward RETURN" + << " grad_input=" << grad_input_t.scalar_type() + << " grad_weight=" << grad_weight_t.scalar_type() + << " grad_bias=" << grad_bias_t.scalar_type() + << std::endl; + + // if (use_CK) + // { + // return std::tuple{grad_input_t.to(input_t.scalar_type()), grad_weight_t.to(input_t.scalar_type()), grad_bias_t.to(input_t.scalar_type())}; + // } + // else + return std::tuple{grad_input_t, grad_weight_t, grad_bias_t}; } }} // namespace native diff --git a/test/test_nn.py b/test/test_nn.py index c1706d32128f2..a7116010ea8c5 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -1,3 +1,4 @@ +import time # Owner(s): ["module: nn"] import contextlib @@ -5062,7 +5063,8 @@ def test_batchnorm_buffer_update_when_stats_are_not_tracked(self): @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def test_batchnorm_nhwc_cuda(self): - for dtype in (torch.half, torch.float): + # for dtype in (torch.half, torch.float): + for dtype in (torch.bfloat16,): (N, C, H, W) = 2, 64, 50, 50 model = torch.nn.BatchNorm2d(C, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) model = model.eval().cuda().to(dtype) @@ -8239,30 +8241,74 @@ def test_affine_3d_rotateRandom(self, device): self.assertEqual(scipy_ary, gridsample_ary.reshape_as(scipy_ary)) def batchnorm2d_miopen(self, dtype, memory_format): - def run_test(input, grad_output): + def run_test(input, grad_output, enable_native = True, enable_cpu = False): + print(f"XXXXXXXXXXXXXX {torch.__file__}") c = input.size(1) - mod = nn.BatchNorm2d(c).cuda().to(dtype=input.dtype) + mod = nn.BatchNorm2d(c, device='cuda', dtype=input.dtype) mod.weight.data.uniform_() mod.bias.data.uniform_() - ref_input = input.detach().clone(memory_format=torch.preserve_format).requires_grad_(True) - ref_grad = grad.detach().clone(memory_format=torch.preserve_format) - ref_mod = nn.BatchNorm2d(c).cuda().to(dtype=input.dtype) - ref_mod.load_state_dict(mod.state_dict()) + if enable_native: + ref_input = input.detach().clone(memory_format=torch.preserve_format).requires_grad_(True) + ref_grad = grad.detach().clone(memory_format=torch.preserve_format) + ref_mod = nn.BatchNorm2d(c).cuda().to(dtype=input.dtype) + ref_mod.load_state_dict(mod.state_dict()) + + if enable_cpu: + cpu_input = input.detach().clone(memory_format=torch.preserve_format).cpu().requires_grad_(True) + cpu_grad = grad.detach().cpu().clone(memory_format=torch.preserve_format) + cpu_mod = nn.BatchNorm2d(c).cpu().to(dtype=input.dtype) + cpu_mod.load_state_dict(mod.state_dict()) + + print("---------------- forward ----------------") + time.sleep(1) out = mod(input) + # return + if enable_cpu: + print("---------------- cpu_forward ----------------") + time.sleep(1) + cpu_out = cpu_mod(cpu_input) + if enable_native: + print("---------------- ref_forward ----------------") + with torch.backends.cudnn.flags(enabled=False): # force to use native nhwc batchnorm + time.sleep(1) + ref_out = ref_mod(ref_input) + + print("---------------- backward ----------------") + time.sleep(1) + # if input.dtype == torch.bfloat16 and memory_format==torch.channels_last: + # grad_output = grad_output.to(torch.float) # .contiguous(memory_format=torch.channels_last) out.backward(grad_output) - with torch.backends.cudnn.flags(enabled=False): # force to use native nhwc batchnorm - ref_out = ref_mod(ref_input) + if enable_cpu: + print("---------------- cpu_backward ----------------") + time.sleep(1) + cpu_out.backward(cpu_grad) + if enable_native: + print("---------------- ref_backward ----------------") + time.sleep(1) ref_out.backward(ref_grad) + print("---------------- check ----------------") + time.sleep(1) self.assertTrue(out.is_contiguous(memory_format=memory_format)) - self.assertTrue(ref_out.is_contiguous(memory_format=memory_format)) - self.assertEqual(out, ref_out) - self.assertEqual(mod.weight.grad, ref_mod.weight.grad) - self.assertEqual(mod.bias.grad, ref_mod.bias.grad) - self.assertEqual(mod.running_mean, ref_mod.running_mean) - self.assertEqual(mod.running_var, ref_mod.running_var) - self.assertEqual(input.grad, ref_input.grad) - - size = (4, 8, 2, 2) + if enable_cpu: + self.assertTrue(cpu_out.is_contiguous(memory_format=memory_format)) + self.assertEqual(out, cpu_out) + self.assertEqual(mod.weight.grad, cpu_mod.weight.grad) + self.assertEqual(mod.bias.grad, cpu_mod.bias.grad) + self.assertEqual(mod.running_mean, cpu_mod.running_mean) + self.assertEqual(mod.running_var, cpu_mod.running_var) + self.assertEqual(input.grad, cpu_input.grad) + if enable_native: + self.assertTrue(ref_out.is_contiguous(memory_format=memory_format)) + self.assertEqual(out, ref_out) + self.assertEqual(mod.weight.grad, ref_mod.weight.grad) + self.assertEqual(mod.bias.grad, ref_mod.bias.grad) + self.assertEqual(mod.running_mean, ref_mod.running_mean, atol=1e-2, rtol=3e-2, exact_dtype=False) + self.assertEqual(mod.running_var, ref_mod.running_var, atol=1e-2, rtol=3e-2, exact_dtype=False) + self.assertEqual(input.grad, ref_input.grad) + print("---------------- end ----------------") + + # size = (4, 8, 2, 2) + size = (8, 32, 470, 725) input = torch.randint(1, 10, size=size, dtype=dtype, device="cuda") input = input.contiguous(memory_format=memory_format).detach().requires_grad_() grad = torch.randint(1, 10, size=size, dtype=dtype, device="cuda") @@ -8270,15 +8316,15 @@ def run_test(input, grad_output): run_test(input, grad) # see #42588, grad is channels_last contiguous, but grad.suggest_memory_format (rightly) return "contiguous" # not channels_last - input = torch.randint(1, 10, (2, 8, 8, 1), dtype=dtype, device="cuda") - input = input.contiguous(memory_format=memory_format).detach().requires_grad_() - grad = torch.randint(1, 10, (2, 8, 8, 1), dtype=dtype, device="cuda") - grad = grad.permute(0, 2, 1, 3) - run_test(input, grad) + # input = torch.randint(1, 10, (2, 8, 8, 1), dtype=dtype, device="cuda") + # input = input.contiguous(memory_format=memory_format).detach().requires_grad_() + # grad = torch.randint(1, 10, (2, 8, 8, 1), dtype=dtype, device="cuda") + # grad = grad.permute(0, 2, 1, 3) + # run_test(input, grad) @onlyCUDA - @dtypes(torch.float) + @dtypes(torch.float, torch.float16, torch.bfloat16) def test_batchnorm_nhwc_miopen(self, dtype): # TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC once ROCm officially supports NHWC in MIOpen PYTORCH_MIOPEN_SUGGEST_NHWC = "PYTORCH_MIOPEN_SUGGEST_NHWC" @@ -8293,7 +8339,7 @@ def test_batchnorm_nhwc_miopen(self, dtype): os.environ[PYTORCH_MIOPEN_SUGGEST_NHWC] = prev_val @onlyCUDA - @dtypes(torch.float) + @dtypes(torch.float, torch.float16, torch.bfloat16) def test_batchnorm_nchw_miopen(self, dtype): self.batchnorm2d_miopen(dtype, torch.contiguous_format) diff --git a/torch/csrc/autograd/autograd.cpp b/torch/csrc/autograd/autograd.cpp index 4a550e7006389..a40df28f48a08 100644 --- a/torch/csrc/autograd/autograd.cpp +++ b/torch/csrc/autograd/autograd.cpp @@ -14,6 +14,8 @@ #include +#include + namespace torch { namespace autograd { @@ -96,11 +98,14 @@ static variable_list run_backward( const variable_list& inputs, bool allow_unused, bool accumulate_grad) { + size_t num_tensors = outputs.size(); + std::cout << "^^^^^^^^^^ run_backward num_tensors=" << num_tensors << std::endl; edge_list roots; roots.reserve(num_tensors); for (const auto i : c10::irange(num_tensors)) { const Variable& output = outputs[i]; + std::cout << "^^^^^^^^^^ run_backward output[" << i << "]=" << output << std::endl; auto gradient_edge = impl::gradient_edge(output); TORCH_CHECK( gradient_edge.function, @@ -113,9 +118,11 @@ static variable_list run_backward( edge_list output_edges; if (!inputs.empty()) { size_t num_inputs = inputs.size(); + std::cout << "^^^^^^^^^^ run_backward num_inputs=" << num_inputs << std::endl; output_edges.reserve(num_inputs); for (const auto i : c10::irange(num_inputs)) { const Variable& input = inputs[i]; + std::cout << "^^^^^^^^^^ run_backward input[" << i << "]=" << input << std::endl; const auto output_nr = input.output_nr(); auto grad_fn = input.grad_fn(); if (!grad_fn) { @@ -172,6 +179,11 @@ void backward( if (!retain_graph) { retain_graph = create_graph; } + std::cout << "^^^^^^^^^^ backward" + << " tensors.size()=" << tensors.size() + << " grad_tensors.size()=" << grad_tensors.size() + << " inputs.size()=" << inputs.size() + << std::endl; run_backward( tensors, gradients, @@ -193,6 +205,11 @@ variable_list grad( if (!retain_graph) { retain_graph = create_graph; } + std::cout << "^^^^^^^^^^ grad" + << " outputs.size()=" << outputs.size() + << " inputs.size()=" << inputs.size() + << " grad_outputs.size()=" << grad_outputs.size() + << std::endl; return run_backward( outputs, gradients, diff --git a/torch/nn/modules/batchnorm.py b/torch/nn/modules/batchnorm.py index 8ba9ad24f1165..ae370da9ee4ef 100644 --- a/torch/nn/modules/batchnorm.py +++ b/torch/nn/modules/batchnorm.py @@ -185,6 +185,23 @@ def forward(self, input: Tensor) -> Tensor: else: bn_training = (self.running_mean is None) and (self.running_var is None) + # ROCM only + if torch.version.hip \ + and torch._C._get_cudnn_enabled() \ + and input.device.type == "cuda" : + # and input.is_contiguous(memory_format=torch.channels_last): + if input.dtype == torch.bfloat16 : + # NOTE: This is a workaround for a BF16 NHWC/NCHW in ROCm batchnorm implementation + self.weight = Parameter(self.weight.to(torch.bfloat16)) + self.bias = Parameter(self.bias.to(torch.bfloat16)) + self.running_mean = self.running_mean.to(torch.float32) + self.running_var = self.running_var.to(torch.float32) + elif input.dtype == torch.float16: + # NOTE: This is a workaround for a FP16 NHWC/NCHW in ROCm batchnorm implementation + self.weight = Parameter(self.weight.to(torch.float16)) + self.bias = Parameter(self.bias.to(torch.float16)) + self.running_mean = self.running_mean.to(torch.float32) + self.running_var = self.running_var.to(torch.float32) r""" Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are