Skip to content

Commit

Permalink
Cherry pick bn 3d for rocm rel 6.3 (#3387)
Browse files Browse the repository at this point in the history
* fix bn 3d issue

* fix review comments

* fix typo
  • Loading branch information
bghimireamd authored Nov 15, 2024
1 parent 7a3c295 commit 4bd61bb
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 85 deletions.
19 changes: 13 additions & 6 deletions src/batch_norm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,27 +66,34 @@ void DeriveBNTensorDescriptor(TensorDescriptor& derivedBnDesc,

TensorDescriptor BuildReshaped4DTensorDescriptor(const miopen::TensorDescriptor& tDesc)
{
std::vector<size_t> dims(tDesc.GetLengths());

auto dataType = tDesc.GetType();
auto layout = tDesc.GetLayout_t();
if(layout == miopenTensorNCDHW)
{
layout = miopenTensorNCHW;

// NxCxDxHxW -> NxCx(D*H)xW
dims[2] *= dims[3];
dims[3] = dims[4];
dims.pop_back();
}
else if(layout == miopenTensorNDHWC)
{
layout = miopenTensorNHWC;

// NxDxHxWxC -> Nx(D*H)xWxC
dims[1] *= dims[2];
dims[2] = dims[3];
dims[3] = dims[4];
dims.pop_back();
}
else
{
std::cout << "Cannot handle layout : " << layout << "\n";
exit(EXIT_FAILURE); // NOLINT (concurrency-mt-unsafe)
}
std::vector<size_t> dims(tDesc.GetLengths());

// NxCxDxHxW -> NxCx(D*H)xW
dims[2] *= dims[3];
dims[3] = dims[4];
dims.pop_back();

return {dataType, layout, dims};
}
Expand Down
160 changes: 81 additions & 79 deletions src/batch_norm_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ miopenBatchNormalizationForwardInference_V2(miopenHandle_t handle,
const miopenTensorDescriptor_t yDesc,
void* y,
const miopenTensorDescriptor_t scaleDesc,
const miopenTensorDescriptor_t BiasDesc,
const miopenTensorDescriptor_t biasDesc,
const miopenTensorDescriptor_t estMeanDesc,
const miopenTensorDescriptor_t estVarianceDesc,
void* bnScale,
Expand All @@ -222,7 +222,7 @@ miopenBatchNormalizationForwardInference_V2(miopenHandle_t handle,
yDesc,
y,
scaleDesc,
BiasDesc,
biasDesc,
estMeanDesc,
estVarianceDesc,
bnScale,
Expand All @@ -239,31 +239,31 @@ miopenBatchNormalizationForwardInference_V2(miopenHandle_t handle,
nullptr,
nullptr,
miopen::debug::BatchNormDirection_t::ForwardInference);

// In case of NxCxDxHxW
int size{0};
miopenGetTensorDescriptorSize(xDesc, &size);
// In case of NxCxDxHxW
auto ReshapeIfNeeded = [size](const auto desc) {
return (size == 5) ? miopen::BuildReshaped4DTensorDescriptor(miopen::deref(desc))
: miopen::deref(desc);
};
return miopen::try_([&] {
miopen::BatchNormForwardInference(
miopen::deref(handle),
bn_mode,
alpha,
beta,
(size == 5) ? miopen::BuildReshaped4DTensorDescriptor(miopen::deref(xDesc))
: miopen::deref(xDesc),
DataCast(x),
(size == 5) ? miopen::BuildReshaped4DTensorDescriptor(miopen::deref(yDesc))
: miopen::deref(yDesc),
DataCast(y),
miopen::deref(scaleDesc),
miopen::deref(BiasDesc),
miopen::deref(estMeanDesc),
miopen::deref(estVarianceDesc),
DataCast(bnScale),
DataCast(bnBias),
DataCast(estimatedMean),
DataCast(estimatedVariance),
epsilon);
miopen::BatchNormForwardInference(miopen::deref(handle),
bn_mode,
alpha,
beta,
ReshapeIfNeeded(xDesc),
DataCast(x),
ReshapeIfNeeded(yDesc),
DataCast(y),
ReshapeIfNeeded(scaleDesc),
ReshapeIfNeeded(biasDesc),
ReshapeIfNeeded(estMeanDesc),
ReshapeIfNeeded(estVarianceDesc),
DataCast(bnScale),
DataCast(bnBias),
DataCast(estimatedMean),
DataCast(estimatedVariance),
epsilon);
});
}

Expand All @@ -277,7 +277,7 @@ miopenBatchNormalizationForwardTraining_V2(miopenHandle_t handle,
const miopenTensorDescriptor_t yDesc,
void* y,
const miopenTensorDescriptor_t scaleDesc,
const miopenTensorDescriptor_t BiasDesc,
const miopenTensorDescriptor_t biasDesc,
const miopenTensorDescriptor_t savedMeanDesc,
const miopenTensorDescriptor_t savedVarianceDesc,
void* bnScale,
Expand All @@ -296,7 +296,7 @@ miopenBatchNormalizationForwardTraining_V2(miopenHandle_t handle,
yDesc,
y,
scaleDesc,
BiasDesc,
biasDesc,
savedMeanDesc,
savedVarianceDesc,
bnScale,
Expand All @@ -316,33 +316,35 @@ miopenBatchNormalizationForwardTraining_V2(miopenHandle_t handle,
resultSaveMean,
resultSaveInvVariance,
miopen::debug::BatchNormDirection_t::ForwardTraining);
// In case of NxCxDxHxW

int size{0};
miopenGetTensorDescriptorSize(xDesc, &size);
// In case of NxCxDxHxW
auto ReshapeIfNeeded = [size](const auto desc) {
return (size == 5) ? miopen::BuildReshaped4DTensorDescriptor(miopen::deref(desc))
: miopen::deref(desc);
};
return miopen::try_([&] {
miopen::BatchNormForwardTraining(
miopen::deref(handle),
bn_mode,
alpha,
beta,
(size == 5) ? miopen::BuildReshaped4DTensorDescriptor(miopen::deref(xDesc))
: miopen::deref(xDesc),
DataCast(x),
(size == 5) ? miopen::BuildReshaped4DTensorDescriptor(miopen::deref(yDesc))
: miopen::deref(yDesc),
DataCast(y),
miopen::deref(scaleDesc),
miopen::deref(BiasDesc),
miopen::deref(savedMeanDesc),
miopen::deref(savedVarianceDesc),
DataCast(bnScale),
DataCast(bnBias),
expAvgFactor,
DataCast(resultRunningMean),
DataCast(resultRunningVariance),
epsilon,
DataCast(resultSaveMean),
DataCast(resultSaveInvVariance));
miopen::BatchNormForwardTraining(miopen::deref(handle),
bn_mode,
alpha,
beta,
ReshapeIfNeeded(xDesc),
DataCast(x),
ReshapeIfNeeded(yDesc),
DataCast(y),
ReshapeIfNeeded(scaleDesc),
ReshapeIfNeeded(biasDesc),
ReshapeIfNeeded(savedMeanDesc),
ReshapeIfNeeded(savedVarianceDesc),
DataCast(bnScale),
DataCast(bnBias),
expAvgFactor,
DataCast(resultRunningMean),
DataCast(resultRunningVariance),
epsilon,
DataCast(resultSaveMean),
DataCast(resultSaveInvVariance));
});
}

Expand All @@ -360,7 +362,7 @@ miopenBatchNormalizationBackward_V2(miopenHandle_t handle,
const miopenTensorDescriptor_t dxDesc,
void* dx,
const miopenTensorDescriptor_t scaleDesc,
const miopenTensorDescriptor_t BiasDesc,
const miopenTensorDescriptor_t biasDesc,
const miopenTensorDescriptor_t savedMeanDesc,
const miopenTensorDescriptor_t savedVarianceDesc,
const void* bnScale,
Expand All @@ -379,7 +381,7 @@ miopenBatchNormalizationBackward_V2(miopenHandle_t handle,
dxDesc,
dx,
scaleDesc,
BiasDesc,
biasDesc,
savedMeanDesc,
savedVarianceDesc,
bnScale,
Expand All @@ -396,35 +398,35 @@ miopenBatchNormalizationBackward_V2(miopenHandle_t handle,
savedMean,
savedInvVariance,
miopen::debug::BatchNormDirection_t::Backward);
// In case of NxCxDxHxW
int size{0};
miopenGetTensorDescriptorSize(xDesc, &size);
// In case of NxCxDxHxW
auto ReshapeIfNeeded = [size](const auto desc) {
return (size == 5) ? miopen::BuildReshaped4DTensorDescriptor(miopen::deref(desc))
: miopen::deref(desc);
};
return miopen::try_([&] {
miopen::BatchNormBackward(
miopen::deref(handle),
bn_mode,
alphaDataDiff,
betaDataDiff,
alphaParamDiff,
betaParamDiff,
(size == 5) ? miopen::BuildReshaped4DTensorDescriptor(miopen::deref(xDesc))
: miopen::deref(xDesc),
DataCast(x),
(size == 5) ? miopen::BuildReshaped4DTensorDescriptor(miopen::deref(dyDesc))
: miopen::deref(dyDesc),
DataCast(dy),
(size == 5) ? miopen::BuildReshaped4DTensorDescriptor(miopen::deref(dxDesc))
: miopen::deref(dxDesc),
DataCast(dx),
miopen::deref(scaleDesc),
miopen::deref(BiasDesc),
miopen::deref(savedMeanDesc),
miopen::deref(savedVarianceDesc),
DataCast(bnScale),
DataCast(resultBnScaleDiff),
DataCast(resultBnBiasDiff),
epsilon,
DataCast(savedMean),
DataCast(savedInvVariance));
miopen::BatchNormBackward(miopen::deref(handle),
bn_mode,
alphaDataDiff,
betaDataDiff,
alphaParamDiff,
betaParamDiff,
ReshapeIfNeeded(xDesc),
DataCast(x),
ReshapeIfNeeded(dyDesc),
DataCast(dy),
ReshapeIfNeeded(dxDesc),
DataCast(dx),
ReshapeIfNeeded(scaleDesc),
ReshapeIfNeeded(biasDesc),
ReshapeIfNeeded(savedMeanDesc),
ReshapeIfNeeded(savedVarianceDesc),
DataCast(bnScale),
DataCast(resultBnScaleDiff),
DataCast(resultBnBiasDiff),
epsilon,
DataCast(savedMean),
DataCast(savedInvVariance));
});
}

0 comments on commit 4bd61bb

Please sign in to comment.