-
Notifications
You must be signed in to change notification settings - Fork 79
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Upsample Layer #2255
Add Upsample Layer #2255
Conversation
60a182d
to
f0fda83
Compare
edb7f9f
to
8dc219b
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall looks pretty good. Just a couple of small things.
template <typename TensorDataType, data_layout Layout, El::Device Device> | ||
template <typename ArchiveT> | ||
void upsample_layer<TensorDataType, Layout, Device>::serialize(ArchiveT& ar) | ||
{ | ||
using DataTypeLayer = data_type_layer<TensorDataType>; | ||
ar(::cereal::make_nvp("DataTypeLayer", | ||
::cereal::base_class<DataTypeLayer>(this))); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this save more of the state? At least the upscale mode and scaling factors?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
src/layers/transform/upsample.cpp
Outdated
|
||
/// Pooling forward propagation with im2col | ||
template <typename TensorDataType, data_layout Layout, El::Device Dev> | ||
void upsample_layer<TensorDataType, Layout, Dev>::fp_compute_im2col() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You might consider just deleting this and the corresponding bp_
function to just remove unnecessary code. The call-sites are commented out.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
src/layers/transform/upsample.cpp
Outdated
|
||
/// Pooling forward propagation with im2col | ||
template <typename TensorDataType, data_layout Layout, El::Device Dev> | ||
void upsample_layer<TensorDataType, Layout, Dev>::bp_compute_im2col() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
Co-authored-by: Tom Benson <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A few suggestions for further tidying.
#include "lbann/utils/dnn_lib/upsample.hpp" | ||
#endif // LBANN_HAS_DNN_LIB | ||
#include "lbann/utils/exception.hpp" | ||
#include "lbann/utils/im2col.hpp" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
#include "lbann/utils/im2col.hpp" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
/// Pooling forward propagation with im2col | ||
void fp_compute_im2col(); | ||
|
||
/// Pooling forward propagation with im2col | ||
void bp_compute_im2col(); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
/// Pooling forward propagation with im2col | |
void fp_compute_im2col(); | |
/// Pooling forward propagation with im2col | |
void bp_compute_im2col(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
/// Pooling forward propagation with DNN library | ||
void fp_compute_dnn(); | ||
|
||
/// Pooling backward propagation with DNN library |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
/// Pooling forward propagation with DNN library | |
void fp_compute_dnn(); | |
/// Pooling backward propagation with DNN library | |
/// Upscaling forward propagation with DNN library | |
void fp_compute_dnn(); | |
/// Upscaling backward propagation with DNN library |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
Co-authored-by: Tal Ben-Nun <[email protected]>
Adds a distconv supported upsampling layer. Currently, this layer can only perform nearest neighbor upsampling and only works on NVIDIA machines.