
stateful RNNs #243

Merged
merged 4 commits into from
Jan 20, 2017

Conversation

superbock
Collaborator

This is the latest attempt to solve #230 and #234. It also partly addresses #185, at least the parts relevant to NNs.

Supersedes #233 and #235.

@superbock superbock force-pushed the stateful_rnns branch 2 times, most recently from 79bbafc to af11059 Compare January 18, 2017 13:21
@superbock superbock requested a review from fdlm January 18, 2017 13:24
def __setstate__(self, state):
# restore instance attributes
self.__dict__.update(state)
# TODO: old models do not have the online attribute, thus create it
Contributor

Isn't this an explanation, not a "TODO"?

Collaborator Author

see below...

Contributor

@fdlm fdlm left a comment

Looks nice overall :)

# TODO: old models do not have the online attribute, thus create it
# remove this initialisation code after updating the models
if not hasattr(self, 'online'):
self.online = None
Contributor

Why is it initialised with None, and not False?

Collaborator Author

The reason is that not all old models are "offline" models; some are "online" models. That is why this is a TODO and None is used as the default. Once we update at least the online models, we can change the default to False; once all models are updated (which is unlikely to happen in the near future), this code can be removed completely. Updated the TODO accordingly.
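The pattern under discussion, restoring pickles that predate a new attribute, can be sketched as follows (a minimal, self-contained example; the class body is hypothetical, not the actual madmom code):

```python
class Layer(object):
    """Minimal stand-in for a layer whose `online` attribute was
    added after some models were already pickled."""

    def __init__(self, online=None):
        self.online = online

    def __setstate__(self, state):
        # restore the pickled instance attributes
        self.__dict__.update(state)
        # old pickles predate the `online` attribute; use None as the
        # default, since old models may be either online or offline
        if not hasattr(self, 'online'):
            self.online = None


# simulate unpickling an old model whose state lacks `online`
old = Layer.__new__(Layer)
old.__setstate__({'some_weight': 1.0})
print(old.online)  # None
```

`__setstate__` is called by pickle with the saved `__dict__`; invoking it directly here just keeps the example self-contained.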

return self.activate(*args)
return self.activate(*args, **kwargs)

def __getstate__(self):
Contributor

IMHO, this shouldn't be in the Layer class. Every subclass of Layer should be responsible for ensuring that its state is not pickled. This does not lead to much code duplication; right now, only RecurrentLayer and LSTMLayer are stateful, and I think for LSTMLayer you'll have to override __getstate__ anyway, because you don't want to pickle the previous state and the cell state.
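The suggested split, each stateful subclass filtering its own transient attributes out of the pickle, could look roughly like this (an illustrative sketch; only the attribute names `_prev` and `_state` are taken from the diff, the rest is hypothetical):

```python
class LSTMLayer(object):

    def __init__(self, init, cell_init):
        self.init = init
        self.cell_init = cell_init
        # transient state needed for stateful processing; it must not
        # end up in the pickle
        self._prev = init
        self._state = cell_init

    def __getstate__(self):
        # copy the dict so the live object is not mutated
        state = self.__dict__.copy()
        # drop the transient attributes
        state.pop('_prev', None)
        state.pop('_state', None)
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        # re-create the transient attributes from the initial values
        self._prev = self.init
        self._state = self.cell_init


layer = LSTMLayer(init=0.0, cell_init=0.0)
layer._prev = 42.0            # pretend some data has been processed
state = layer.__getstate__()  # what pickle would serialise
clone = LSTMLayer.__new__(LSTMLayer)
clone.__setstate__(state)     # what unpickling would do
print(clone._prev)            # 0.0 -- reset to init, not 42.0
```

Pickle calls `__getstate__`/`__setstate__` automatically during a dump/load round trip; they are invoked directly here only to keep the example self-contained.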

@@ -44,16 +51,24 @@ def activate(self, data):
"""
raise NotImplementedError('must be implemented by subclass.')

@staticmethod
Contributor

Not sure this should be a static method - sure, it does not change the object state in the Layer class, but its functionality is intended to change the object state.

Collaborator Author

Yes, this was a normal method before, changed it because some checker criticised it. Will change it back.


"""
# reset previous time step to initial value
self._prev = self.init if init is None else init
Contributor

self._prev = init or self.init ?

Collaborator Author

👍
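A side note on the `init or self.init` shortcut agreed on above: `or` tests truthiness rather than identity with None, so a deliberately falsy reset value (an empty sequence, 0, etc.) silently falls back to the default, and a multi-element NumPy array on the left of `or` raises a ValueError. A stdlib-only illustration of the difference:

```python
def reset_explicit(init, default):
    # original form: only None falls back to the default
    return default if init is None else init


def reset_shortcut(init, default):
    # suggested form: any falsy value falls back to the default
    return init or default


default = [1.0, 2.0]
print(reset_explicit([], default))  # [] -- an empty state is kept
print(reset_shortcut([], default))  # [1.0, 2.0] -- silently replaced
```

Both forms behave identically when `init` is None; they only differ for falsy non-None values.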

# add non-pickled attributes needed for stateful processing
self._prev = self.init
self._state = self.cell_init

Contributor

Shouldn't LSTMLayer also have a __getstate__ for filtering out self._prev and self._state for pickling?

@@ -413,7 +572,7 @@ def activate(self, data, reset_gate, prev):
return self.activation_fn(out)
Contributor

This functionality looks very similar to what the Gate class already provides. Do we really need a separate class for the GRU cell?

Collaborator Author

Unfortunately yes, since it is only similar...

Collaborator Author

@superbock superbock left a comment

Thanks, force pushed the requested changes.

self.init = init
# keep the state of the layer
self._prev = self.init

Contributor

Missing a __getstate__ here, because it's no longer in Layer. (I still think it's good to have it here instead of there)

@superbock superbock force-pushed the stateful_rnns branch 2 times, most recently from 51314d5 to 1ce9cf6 Compare January 19, 2017 08:50
added initialisation of hidden states to layers; fixes #230
renamed GRU parameters to be consistent with all other layers
@superbock superbock merged commit 98787a4 into master Jan 20, 2017
@superbock superbock deleted the stateful_rnns branch January 20, 2017 10:28
@superbock superbock changed the title [WIP] stateful RNNs stateful RNNs Jan 20, 2017