Skip to content

Commit

Permalink
cache spectrogram filterbank and stft window; make bin_frequencies a …
Browse files Browse the repository at this point in the history
…property
  • Loading branch information
Sebastian Böck committed Jul 28, 2016
1 parent 924b43d commit 43f64f7
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 22 deletions.
35 changes: 21 additions & 14 deletions madmom/audio/spectrogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,6 @@ def __new__(cls, stft, **kwargs):
obj = np.asarray(data).view(cls)
# save additional attributes
obj.stft = stft
obj.bin_frequencies = stft.bin_frequencies
# return the object
return obj

Expand All @@ -218,7 +217,6 @@ def __array_finalize__(self, obj):
return
# set default values here, also needed for views
self.stft = getattr(obj, 'stft', None)
self.bin_frequencies = getattr(obj, 'bin_frequencies', None)
# Note: these attributes are added for compatibility, if they are
# present any spectrogram sub-class behaves exactly the same
self.filterbank = getattr(obj, 'filterbank', None)
Expand All @@ -235,6 +233,11 @@ def num_bins(self):
"""Number of bins."""
return int(self.shape[1])

@property
def bin_frequencies(self):
"""Bin frequencies."""
return self.stft.bin_frequencies

def diff(self, **kwargs):
"""
Return the difference of the magnitude spectrogram.
Expand Down Expand Up @@ -461,21 +464,24 @@ def __new__(cls, spectrogram, filterbank=FILTERBANK, num_bands=NUM_BANDS,
obj = np.asarray(data).view(cls)
# save additional attributes
obj.filterbank = filterbank
# use the center frequencies of the filterbank as bin_frequencies
obj.bin_frequencies = filterbank.center_frequencies
# and those from the given spectrogram
obj.stft = spectrogram.stft
obj.mul = spectrogram.mul
obj.add = spectrogram.add
# return the object
return obj

def __array_finalize__(self, obj):
if obj is None:
return
# set default values here, also needed for views
self.filterbank = getattr(obj, 'filterbank', None)
super(FilteredSpectrogram, self).__array_finalize__(obj)
# def __array_finalize__(self, obj):
# if obj is None:
# return
# # set default values here, also needed for views
# self.filterbank = getattr(obj, 'filterbank', None)
# super(FilteredSpectrogram, self).__array_finalize__(obj)

@property
def bin_frequencies(self):
# use the center frequencies of the filterbank as bin_frequencies
return self.filterbank.center_frequencies


class FilteredSpectrogramProcessor(Processor):
Expand Down Expand Up @@ -507,7 +513,6 @@ def __init__(self, filterbank=FILTERBANK, num_bands=NUM_BANDS, fmin=FMIN,
fmax=FMAX, fref=A4, norm_filters=NORM_FILTERS,
unique_filters=UNIQUE_FILTERS, **kwargs):
# pylint: disable=unused-argument

self.filterbank = filterbank
self.num_bands = num_bands
self.fmin = fmin
Expand All @@ -534,12 +539,15 @@ def process(self, data, **kwargs):
"""
# instantiate a FilteredSpectrogram and return it
return FilteredSpectrogram(data, filterbank=self.filterbank,
data = FilteredSpectrogram(data, filterbank=self.filterbank,
num_bands=self.num_bands, fmin=self.fmin,
fmax=self.fmax, fref=self.fref,
norm_filters=self.norm_filters,
unique_filters=self.unique_filters,
**kwargs)
# cache the filterbank
self.filterbank = data.filterbank
return data


# logarithmic spectrogram stuff
Expand Down Expand Up @@ -611,7 +619,7 @@ def __new__(cls, spectrogram, log=LOG, mul=MUL, add=ADD, **kwargs):
obj.add = add
# and those from the given spectrogram
obj.stft = spectrogram.stft
obj.bin_frequencies = spectrogram.bin_frequencies
# obj.bin_frequencies = spectrogram.bin_frequencies
obj.filterbank = spectrogram.filterbank
# return the object
return obj
Expand Down Expand Up @@ -1043,7 +1051,6 @@ def __new__(cls, spectrogram, diff_ratio=DIFF_RATIO,
obj.positive_diffs = positive_diffs
# and those from the given spectrogram
obj.filterbank = spectrogram.filterbank
obj.bin_frequencies = spectrogram.bin_frequencies
obj.mul = spectrogram.mul
obj.add = spectrogram.add
# return the object
Expand Down
23 changes: 15 additions & 8 deletions madmom/audio/stft.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,8 +335,6 @@ def __new__(cls, frames, window=np.hanning, fft_size=None,
obj = np.asarray(data).view(cls)
# save the other parameters
obj.frames = frames
obj.bin_frequencies = fft_frequencies(obj.shape[1],
frames.signal.sample_rate)
obj.window = window
obj.fft_window = fft_window
obj.fft_size = fft_size if fft_size else frame_size
Expand All @@ -349,12 +347,15 @@ def __array_finalize__(self, obj):
return
# set default values here, also needed for views
self.frames = getattr(obj, 'frames', None)
self.bin_frequencies = getattr(obj, 'bin_frequencies', None)
self.window = getattr(obj, 'window', np.hanning)
self.fft_window = getattr(obj, 'fft_window', None)
self.fft_size = getattr(obj, 'fft_size', None)
self.circular_shift = getattr(obj, 'circular_shift', False)

@property
def bin_frequencies(self):
return fft_frequencies(self.num_bins, self.frames.signal.sample_rate)

def spec(self, **kwargs):
"""
Returns the magnitude spectrogram of the STFT.
Expand Down Expand Up @@ -459,10 +460,13 @@ def process(self, data, **kwargs):
"""
# instantiate a STFT
return ShortTimeFourierTransform(data, window=self.window,
data = ShortTimeFourierTransform(data, window=self.window,
fft_size=self.fft_size,
circular_shift=self.circular_shift,
**kwargs)
# cache the window
self.window = data.window
return data

@staticmethod
def add_arguments(parser, window=None, fft_size=None):
Expand Down Expand Up @@ -563,7 +567,6 @@ def __new__(cls, stft, **kwargs):
obj = np.asarray(phase(stft)).view(cls)
# save additional attributes
obj.stft = stft
obj.bin_frequencies = stft.bin_frequencies
# return the object
return obj

Expand All @@ -572,7 +575,10 @@ def __array_finalize__(self, obj):
return
# set default values here, also needed for views
self.stft = getattr(obj, 'stft', None)
self.bin_frequencies = getattr(obj, 'bin_frequencies', None)

@property
def bin_frequencies(self):
return self.stft.bin_frequencies

def local_group_delay(self, **kwargs):
"""
Expand Down Expand Up @@ -645,7 +651,6 @@ def __new__(cls, phase, **kwargs):
# save additional attributes
obj.phase = phase
obj.stft = phase.stft
obj.bin_frequencies = phase.bin_frequencies
# return the object
return obj

Expand All @@ -655,7 +660,9 @@ def __array_finalize__(self, obj):
# set default values here, also needed for views
self.phase = getattr(obj, 'phase', None)
self.stft = getattr(obj, 'stft', None)
self.bin_frequencies = getattr(obj, 'bin_frequencies', None)

@property
def bin_frequencies(self):
return self.stft.bin_frequencies

LGD = LocalGroupDelay

0 comments on commit 43f64f7

Please sign in to comment.