fix potential divide by zero fault when bf16s / fp16s enabled, fix #2125
nihui committed Sep 16, 2020
1 parent 1c5af3d commit b766c8c
Showing 2 changed files with 11 additions and 11 deletions.
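
The old guards computed bits per lane by hand, e.g. bottom_blob.elemsize / bottom_blob.elempack == 2u; for an empty blob elempack is 0, so enabling bf16/fp16 storage could hit an integer divide-by-zero. The patch routes every such check through Mat::elembits() instead. Below is a minimal illustrative sketch (not ncnn's actual Mat), assuming only that elembits() reports bits per scalar lane and guards the elempack == 0 case:

// Illustrative sketch only -- assumes a Mat-like struct with elemsize
// (bytes per packed element) and elempack (lanes per element, 0 when empty).
#include <cstddef>
#include <cstdio>

struct MatSketch
{
    size_t elemsize; // bytes per packed element
    int elempack;    // scalar lanes packed per element, 0 for an empty Mat

    // Hypothetical elembits()-style accessor: bits per scalar lane,
    // returning 0 instead of dividing by a zero elempack.
    int elembits() const
    {
        return elempack ? static_cast<int>(elemsize * 8) / elempack : 0;
    }
};

int main()
{
    MatSketch empty = {0u, 0};   // unallocated blob: the old hand-written check would fault here
    MatSketch packed = {16u, 4}; // 4 packed fp32 lanes -> 32 bits per lane

    std::printf("empty: %d bits, packed: %d bits\n",
                empty.elembits(), packed.elembits());
    return 0;
}

With a guard of that shape, the comparisons against 32 or 16 in the diffs below stay well-defined even before a blob is allocated.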
src/command.cpp (1 addition & 1 deletion)
@@ -2513,7 +2513,7 @@ void VkTransfer::record_upload(const Mat& src, VkMat& dst, const Option& opt, bo
// NCNN_LOGE("record_upload src = %d | %d %d %d @ %d", src.dims, src.w, src.h, src.c, src.elempack);

// NOTE keep the hack here ?
- if (src.elemsize == src.elempack * 4u)
+ if (src.elembits() == 32)
{
if (opt.use_fp16_storage || (opt.use_fp16_packed && src.elempack % 4 == 0))
{
src/net.cpp (10 additions & 10 deletions)
@@ -1150,13 +1150,13 @@ int Net::forward_layer(int layer_index, std::vector<Mat>& blob_mats, const Optio
#if NCNN_ARM82
if (opt.use_fp16_storage && cpu_support_arm_asimdhp())
{
- if (bottom_blob.elemsize / bottom_blob.elempack == 4u && layer->support_fp16_storage)
+ if (bottom_blob.elembits() == 32 && layer->support_fp16_storage)
{
Mat bottom_blob_fp16;
cast_float32_to_float16(bottom_blob, bottom_blob_fp16, opt);
bottom_blob = bottom_blob_fp16;
}
- if (bottom_blob.elemsize / bottom_blob.elempack == 2u && !layer->support_fp16_storage)
+ if (bottom_blob.elembits() == 16 && !layer->support_fp16_storage)
{
Mat bottom_blob_fp32;
cast_float16_to_float32(bottom_blob, bottom_blob_fp32, opt);
@@ -1167,13 +1167,13 @@ int Net::forward_layer(int layer_index, std::vector<Mat>& blob_mats, const Optio
#endif // NCNN_ARM82
if (opt.use_bf16_storage)
{
- if (bottom_blob.elemsize / bottom_blob.elempack == 4u && layer->support_bf16_storage)
+ if (bottom_blob.elembits() == 32 && layer->support_bf16_storage)
{
Mat bottom_blob_bf16;
cast_float32_to_bfloat16(bottom_blob, bottom_blob_bf16, opt);
bottom_blob = bottom_blob_bf16;
}
- if (bottom_blob.elemsize / bottom_blob.elempack == 2u && !layer->support_bf16_storage)
+ if (bottom_blob.elembits() == 16 && !layer->support_bf16_storage)
{
Mat bottom_blob_fp32;
cast_bfloat16_to_float32(bottom_blob, bottom_blob_fp32, opt);
@@ -1283,13 +1283,13 @@ int Net::forward_layer(int layer_index, std::vector<Mat>& blob_mats, const Optio
#if NCNN_ARM82
if (opt.use_fp16_storage && cpu_support_arm_asimdhp())
{
- if (bottom_blobs[i].elemsize / bottom_blobs[i].elempack == 4u && layer->support_fp16_storage)
+ if (bottom_blobs[i].elembits() == 32 && layer->support_fp16_storage)
{
Mat bottom_blob_fp16;
cast_float32_to_float16(bottom_blobs[i], bottom_blob_fp16, opt);
bottom_blobs[i] = bottom_blob_fp16;
}
- if (bottom_blobs[i].elemsize / bottom_blobs[i].elempack == 2u && !layer->support_fp16_storage)
+ if (bottom_blobs[i].elembits() == 16 && !layer->support_fp16_storage)
{
Mat bottom_blob_fp32;
cast_float16_to_float32(bottom_blobs[i], bottom_blob_fp32, opt);
@@ -1300,13 +1300,13 @@ int Net::forward_layer(int layer_index, std::vector<Mat>& blob_mats, const Optio
#endif // NCNN_ARM82
if (opt.use_bf16_storage)
{
- if (bottom_blobs[i].elemsize / bottom_blobs[i].elempack == 4u && layer->support_bf16_storage)
+ if (bottom_blobs[i].elembits() == 32 && layer->support_bf16_storage)
{
Mat bottom_blob_bf16;
cast_float32_to_bfloat16(bottom_blobs[i], bottom_blob_bf16, opt);
bottom_blobs[i] = bottom_blob_bf16;
}
- if (bottom_blobs[i].elemsize / bottom_blobs[i].elempack == 2u && !layer->support_bf16_storage)
+ if (bottom_blobs[i].elembits() == 16 && !layer->support_bf16_storage)
{
Mat bottom_blob_fp32;
cast_bfloat16_to_float32(bottom_blobs[i], bottom_blob_fp32, opt);
@@ -2779,7 +2779,7 @@ int Extractor::extract(int blob_index, Mat& feat)
#if NCNN_ARM82
if (opt.use_fp16_storage && cpu_support_arm_asimdhp())
{
- if (feat.elemsize / feat.elempack == 2u)
+ if (feat.elembits() == 16)
{
Mat feat_fp32;
cast_float16_to_float32(feat, feat_fp32, opt);
@@ -2790,7 +2790,7 @@ int Extractor::extract(int blob_index, Mat& feat)
#endif // NCNN_ARM82
if (opt.use_bf16_storage)
{
- if (feat.elemsize / feat.elempack == 2u)
+ if (feat.elembits() == 16)
{
Mat feat_fp32;
cast_bfloat16_to_float32(feat, feat_fp32, opt);
