Skip to content

Commit

Permalink
Downbeats: allow weighting of beats per bar
Browse files Browse the repository at this point in the history
 * Optional parameter, implicitly defaults to ones for the array
 * Clean up the handling of lengths into the constructor, it was getting verbose
 * Check weights don't sum to zero, to avoid divide-by-zero pain.
 * Weight the HMM results in log space by normalised weight values, as suggested by @superbock
 * Add new test to prove that (sufficient, but arbitrary) weighting to 3-time (over 4-time)
   does indeed return 3-time beats results.

This fixes CPJKU#402.
  • Loading branch information
declension committed Jan 10, 2019
1 parent 3942e9f commit fd670e6
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 8 deletions.
31 changes: 23 additions & 8 deletions madmom/features/downbeats.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,11 @@ class DBNDownBeatTrackingProcessor(Processor):
(down-)beat activation function).
fps : float, optional
Frames per second.
beats_per_bar_weights : list, optional
Weight the beats_per_bar list when choosing.
Higher numbers favour the beat number at the same index, e.g.
for beats_per_bar of [3, 4], a value here for [1, 2] will bias the
the choice towards 4 beats per bar.
References
----------
Expand Down Expand Up @@ -200,11 +205,15 @@ class DBNDownBeatTrackingProcessor(Processor):
def __init__(self, beats_per_bar, min_bpm=MIN_BPM, max_bpm=MAX_BPM,
num_tempi=NUM_TEMPI, transition_lambda=TRANSITION_LAMBDA,
observation_lambda=OBSERVATION_LAMBDA, threshold=THRESHOLD,
correct=CORRECT, fps=None, **kwargs):
correct=CORRECT, fps=None, beats_per_bar_weights=None,
**kwargs):
# pylint: disable=unused-argument
# pylint: disable=no-name-in-module
# expand arguments to arrays
beats_per_bar = np.array(beats_per_bar, ndmin=1)
beats_per_bar_weights = (np.array(beats_per_bar_weights, ndmin=1)
if beats_per_bar_weights
else np.ones(beats_per_bar.shape))
min_bpm = np.array(min_bpm, ndmin=1)
max_bpm = np.array(max_bpm, ndmin=1)
num_tempi = np.array(num_tempi, ndmin=1)
Expand All @@ -220,11 +229,14 @@ def __init__(self, beats_per_bar, min_bpm=MIN_BPM, max_bpm=MAX_BPM,
if len(transition_lambda) != len(beats_per_bar):
transition_lambda = np.repeat(transition_lambda,
len(beats_per_bar))
if not (len(min_bpm) == len(max_bpm) == len(num_tempi) ==
len(beats_per_bar) == len(transition_lambda)):
raise ValueError('`min_bpm`, `max_bpm`, `num_tempi`, `num_beats` '
'and `transition_lambda` must all have the same '
'length.')
lengths = [len(a) for a in (min_bpm, max_bpm, num_tempi, beats_per_bar,
transition_lambda, beats_per_bar_weights)]
if not sum(beats_per_bar_weights):
raise ValueError("`beats_per_bar_weights` cannot total zero")
if np.var(lengths):
raise ValueError('`min_bpm`, `max_bpm`, `num_tempi`, `num_beats`, '
'`beats_per_bar_weights` and `transition_lambda` '
'must all have the same length.')
# get num_threads from kwargs
num_threads = min(len(beats_per_bar), kwargs.get('num_threads', 1))
# init a pool of workers (if needed)
Expand All @@ -245,6 +257,7 @@ def __init__(self, beats_per_bar, min_bpm=MIN_BPM, max_bpm=MAX_BPM,
self.hmms.append(HiddenMarkovModel(tm, om))
# save variables
self.beats_per_bar = beats_per_bar
self.beats_per_bar_weights = beats_per_bar_weights
self.threshold = threshold
self.correct = correct
self.fps = fps
Expand Down Expand Up @@ -283,8 +296,10 @@ def process(self, activations, **kwargs):
# (parallel) decoding of the activations with HMM
results = list(self.map(_process_dbn, zip(self.hmms,
it.repeat(activations))))
# choose the best HMM (highest log probability)
best = np.argmax(np.asarray(results)[:, 1])
# choose the best HMM (highest log probability) after weighting
weights = self.beats_per_bar_weights
scores = np.asarray(results)[:, 1] + np.log(weights / np.sum(weights))
best = np.argmax(scores)
# the best path through the state space
path, _ = results[best]
# the state space and observation model of the best HMM
Expand Down
10 changes: 10 additions & 0 deletions tests/test_features_downbeats.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,16 @@ def test_process(self):
downbeats = self.processor(sample_downbeat_act)
self.assertTrue(np.allclose(downbeats, np.empty((0, 2))))

def test_weighting_measure(self):
self.processor = DBNDownBeatTrackingProcessor(
[3, 4], fps=sample_downbeat_act.fps,
beats_per_bar_weights=[100, 1], correct=False)
downbeats = self.processor(sample_downbeat_act)
correct = np.array([[0.08, 1], [0.43, 2], [0.77, 3],
[1.11, 1], [1.45, 2], [1.79, 3],
[2.13, 1], [2.47, 2]])
self.assertTrue(np.allclose(downbeats, correct))


class TestPatternTrackingProcessorClass(unittest.TestCase):

Expand Down

0 comments on commit fd670e6

Please sign in to comment.