do not reset NN layers in online (streaming) mode
Sebastian Böck committed Dec 8, 2016
1 parent eb0310e commit 5a2e813
Showing 3 changed files with 17 additions and 12 deletions.
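Usage sketch (illustrative, not part of the commit): the toy layer, random weights and local sigmoid below merely stand in for a real madmom model; the point is that with online=True the network keeps its layers' state between successive process() calls, so data can be fed one frame at a time.

import numpy as np
from madmom.ml.nn import NeuralNetwork
from madmom.ml.nn.layers import FeedForwardLayer


def sigmoid(x):
    # simple logistic activation, standing in for madmom's activation functions
    return 1. / (1. + np.exp(-x))


# toy stand-in for a real model: one feed-forward layer with random weights
layer = FeedForwardLayer(np.random.rand(10, 5), np.zeros(5), sigmoid)

# offline (default): the layers are reset before every call to process()
nn = NeuralNetwork([layer])
predictions = nn.process(np.random.rand(100, 10))

# online / streaming: keep the layers' state and feed one frame per call
nn_online = NeuralNetwork([layer], online=True)
for frame in np.random.rand(100, 10):
    prediction = nn_online.process(np.atleast_2d(frame))
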
madmom/ml/nn/__init__.py: 8 changes (6 additions, 2 deletions)
@@ -71,8 +71,9 @@ class NeuralNetwork(Processor):
     """

-    def __init__(self, layers):
+    def __init__(self, layers, online=False):
         self.layers = layers
+        self.online = online

     def process(self, data):
         """
@@ -89,13 +90,16 @@ def process(self, data):
             Network predictions for this data.

         """
+        # reset the layers? (online: do not reset, keep the state)
+        # Note: use getattr to be able to process old models
+        reset = not getattr(self, 'online', False)
         # check the dimensions of the data
         if data.ndim == 1:
             data = np.atleast_2d(data).T
         # loop over all layers
         for layer in self.layers:
             # activate the layer and feed the output into the next one
-            data = layer(data)
+            data = layer(data, reset=reset)
         # ravel the predictions if needed
         if data.ndim == 2 and data.shape[1] == 1:
             data = data.ravel()
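The getattr in the hunk above is what keeps previously pickled models working: a NeuralNetwork saved before this change has no online attribute, so plain attribute access would raise an AttributeError after unpickling. A standalone illustration (hypothetical class, not madmom code):

import pickle


class OldModel(object):
    # stands in for a NeuralNetwork pickled before the `online` attribute existed
    def __init__(self):
        self.layers = []


old = pickle.loads(pickle.dumps(OldModel()))
# getattr falls back to False, so old models keep the resetting (offline) behaviour
reset = not getattr(old, 'online', False)
print(reset)   # True
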
madmom/ml/nn/layers.py: 20 changes (10 additions, 10 deletions)
@@ -23,9 +23,9 @@ class Layer(object):
     """

-    def __call__(self, *args):
+    def __call__(self, *args, **kwargs):
         # this magic method makes a Layer callable
-        return self.activate(*args)
+        return self.activate(*args, **kwargs)

     def activate(self, data):
         """
@@ -65,7 +65,7 @@ def __init__(self, weights, bias, activation_fn):
         self.bias = bias
         self.activation_fn = activation_fn

-    def activate(self, data):
+    def activate(self, data, **kwargs):
         """
         Activate the layer.
@@ -180,7 +180,7 @@ def __init__(self, fwd_layer, bwd_layer):
         self.fwd_layer = fwd_layer
         self.bwd_layer = bwd_layer

-    def activate(self, data):
+    def activate(self, data, **kwargs):
         """
         Activate the layer.
@@ -200,9 +200,9 @@ def activate(self, data):
         """
         # activate in forward direction
-        fwd = self.fwd_layer(data)
+        fwd = self.fwd_layer(data, **kwargs)
         # also activate with reverse input
-        bwd = self.bwd_layer(data[::-1])
+        bwd = self.bwd_layer(data[::-1], **kwargs)
         # stack data
         return np.hstack((fwd, bwd[::-1]))

@@ -700,7 +700,7 @@ def __init__(self, weights, bias, stride=1, pad='valid',
             raise NotImplementedError('only `pad` == "valid" implemented.')
         self.pad = pad

-    def activate(self, data):
+    def activate(self, data, **kwargs):
         """
         Activate the layer.
@@ -757,7 +757,7 @@ class StrideLayer(Layer):
     def __init__(self, block_size):
         self.block_size = block_size

-    def activate(self, data):
+    def activate(self, data, **kwargs):
         """
         Activate the layer.
@@ -798,7 +798,7 @@ def __init__(self, size, stride=None):
             stride = size
         self.stride = stride

-    def activate(self, data):
+    def activate(self, data, **kwargs):
         """
         Activate the layer.
@@ -862,7 +862,7 @@ def __init__(self, beta, gamma, mean, inv_std, activation_fn):
         self.inv_std = inv_std
         self.activation_fn = activation_fn

-    def activate(self, data):
+    def activate(self, data, **kwargs):
         """
         Activate the layer.
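All layers touched in this file are stateless, so they only accept and ignore the new keyword; the ones that actually act on it are the stateful (recurrent) layers, which are unchanged in this diff. A rough sketch of how such a layer could honour the flag, purely illustrative and not madmom's implementation:

import numpy as np


class StatefulLayerSketch(object):
    # illustrative only: a layer with internal state that survives calls
    def __init__(self, size):
        self.size = size
        self._prev = np.zeros(size)

    def activate(self, data, reset=True, **kwargs):
        if reset:
            # offline mode: start every call from a clean state
            self._prev = np.zeros(self.size)
        out = np.empty_like(data)
        for i, frame in enumerate(data):
            # toy recurrence mixing the current frame with the previous output
            self._prev = 0.5 * frame + 0.5 * self._prev
            out[i] = self._prev
        return out
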
madmom/processors.py: 1 change (1 addition, 0 deletions)
@@ -847,3 +847,4 @@ def io_arguments(parser, output_suffix='.txt', pickle=True, online=False):
         sp.set_defaults(origin='future')
         sp.set_defaults(num_frames=1)
         sp.set_defaults(stream=None)
+        sp.set_defaults(online=True)
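The new default is injected via argparse's sub-parser set_defaults() mechanism; a self-contained sketch of how that surfaces in the parsed arguments (hypothetical parser and sub-command, not madmom's io_arguments):

import argparse

parser = argparse.ArgumentParser()
sub_parsers = parser.add_subparsers()
sp = sub_parsers.add_parser('online')   # hypothetical sub-command name
sp.set_defaults(online=True)

args = parser.parse_args(['online'])
print(args.online)   # True, without any explicit --online option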
