stateful RNNs #243
Conversation
Force-pushed 79bbafc to af11059
def __setstate__(self, state):
    # restore instance attributes
    self.__dict__.update(state)
    # TODO: old models do not have the online attribute, thus create it
Isn't this an explanation, not a "TODO"?
see below...
Looks nice overall :)
# TODO: old models do not have the online attribute, thus create it
# remove this initialisation code after updating the models
if not hasattr(self, 'online'):
    self.online = None
Why is it initialised with `None`, and not `False`?
The reason is that not all old models are "offline models"; some are "online models". Thus this is a TODO and `None` is used as the default. After we update at least the online models, we can set it to `False`; after updating all models (which is unlikely to happen in the near future) this can be removed completely. Updated the TODO accordingly.
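The backwards-compatibility pattern under discussion can be sketched in isolation like this (the class name `Processor` and the `name` attribute are hypothetical; only the `__setstate__` hook and the `online` attribute come from the diff above):

```python
import pickle

class Processor:
    def __init__(self):
        self.online = False

    def __setstate__(self, state):
        # restore instance attributes
        self.__dict__.update(state)
        # old pickles do not contain the `online` attribute; create it
        if not hasattr(self, 'online'):
            self.online = None

# simulate an "old" pickled instance that predates the attribute
old = Processor.__new__(Processor)
old.name = 'old-model'
restored = pickle.loads(pickle.dumps(old))
print(restored.online)  # None, the backwards-compatible default
```

Instances pickled after the attribute was introduced pass through `__setstate__` unchanged, since `hasattr(self, 'online')` is then true.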
return self.activate(*args)
return self.activate(*args, **kwargs)

def __getstate__(self):
IMHO, this shouldn't be in the `Layer` class. Every subclass of `Layer` should be responsible for ensuring that its state is not pickled. This does not lead to much code duplication: right now, only `RecurrentLayer` and `LSTMLayer` are stateful, and I think for `LSTMLayer` you'll have to overwrite `__getstate__` anyway, because you don't want to pickle the previous state and the cell state.
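The per-subclass approach the reviewer suggests could look roughly like this (a sketch under assumed names; `Layer`, `RecurrentLayer` and `_prev` appear in the PR, the constructor signature is illustrative):

```python
import pickle
import numpy as np

class Layer:
    """Generic network layer; knows nothing about pickled state."""
    def activate(self, data):
        raise NotImplementedError('must be implemented by subclass.')

class RecurrentLayer(Layer):
    def __init__(self, init):
        self.init = init    # initial hidden state, part of the model
        self._prev = init   # runtime hidden state, must not be pickled

    def __getstate__(self):
        state = self.__dict__.copy()
        state.pop('_prev', None)    # filter out the runtime state
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self._prev = self.init      # recreate the runtime state

layer = RecurrentLayer(np.zeros(3))
layer._prev = np.ones(3)  # pretend some data was processed
restored = pickle.loads(pickle.dumps(layer))
print(np.array_equal(restored._prev, restored.init))  # True
```

A stateful subclass then owns both halves of the contract: dropping its runtime state on pickling and re-initialising it on unpickling.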
@@ -44,16 +51,24 @@ def activate(self, data):
    """
    raise NotImplementedError('must be implemented by subclass.')

@staticmethod
Not sure this should be a static method: sure, it does not change the object state in the `Layer` class, but its functionality is intended to change the object state.
Yes, this was a normal method before, changed it because some checker criticised it. Will change it back.
"""
# reset previous time step to initial value
self._prev = self.init if init is None else init
`self._prev = init or self.init`?
👍
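One hedged caveat on this shorthand (my aside, not raised in the review): for NumPy array states, `init or self.init` is not equivalent to the explicit `None` check, because multi-element arrays have no defined truth value:

```python
import numpy as np

default = np.zeros(3)

def reset_explicit(init):
    # explicit check: works for any array, including all-zero ones
    return default if init is None else init

def reset_shorthand(init):
    # `or` evaluates bool(init), which raises for multi-element arrays
    return init or default

print(reset_explicit(np.ones(3)))  # [1. 1. 1.]
try:
    reset_shorthand(np.ones(3))
except ValueError:
    print('ambiguous truth value')
```

The shorthand is only safe if `init` is guaranteed to be `None` or a scalar.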
# add non-pickled attributes needed for stateful processing
self._prev = self.init
self._state = self.cell_init
Shouldn't `LSTMLayer` also have a `__getstate__` for filtering out `self._prev` and `self._state` for pickling?
@@ -413,7 +572,7 @@ def activate(self, data, reset_gate, prev):
    return self.activation_fn(out)
This functionality looks very similar to what the `Gate` class already provides. Do we really need a separate class for the GRU cell?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unfortunately yes, since it is only similar...
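For context, a rough sketch of why the two are only similar (this uses one common GRU formulation; the function names and the exact placement of the reset gate are assumptions, not necessarily what this project implements):

```python
import numpy as np

def sigmoid(x):
    return 1. / (1. + np.exp(-x))

def gate(data, prev, weights, recurrent_weights, bias):
    # a generic gate: sigmoid over input and recurrent contributions
    return sigmoid(np.dot(data, weights) +
                   np.dot(prev, recurrent_weights) + bias)

def gru_cell(data, reset_gate, prev, weights, recurrent_weights, bias):
    # the GRU candidate state: the previous state is scaled by the
    # reset gate *before* the recurrent weights are applied, and tanh
    # replaces sigmoid; structurally close to a gate, but not identical
    return np.tanh(np.dot(data, weights) +
                   np.dot(reset_gate * prev, recurrent_weights) + bias)
```

The extra argument and the gating of `prev` inside the nonlinearity are what keep the cell from being expressed as a plain `Gate`.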
Force-pushed af11059 to d528b7e
Thanks, force pushed the requested changes.
self.init = init
# keep the state of the layer
self._prev = self.init
Missing a `__getstate__` here, because it's no longer in `Layer`. (I still think it's good to have it here instead of there.)
Force-pushed 51314d5 to 1ce9cf6
added initialisation of hidden states to layers; fixes #230; renamed GRU parameters to be consistent with all other layers
Force-pushed 1ce9cf6 to a92614d
This is the latest attempt to solve #230 and #234. It also partly addresses #185, at least the parts relevant to NNs.
Supersedes #233 and #235.