diff --git a/08-seq_classification.ipynb b/08-seq_classification.ipynb index b8792e4ef..1255c3b6e 100644 --- a/08-seq_classification.ipynb +++ b/08-seq_classification.ipynb @@ -497,7 +497,7 @@ "\n", " with torch.no_grad():\n", " for batch_idx in range(len(data_generator)):\n", - " data, target = test_data_gen[batch_idx]\n", + " data, target = data_generator[batch_idx]\n", " data, target = torch.from_numpy(data).float().to(device), torch.from_numpy(target).long().to(device)\n", "\n", " data_decoded = data_generator.decode_x_batch(data.cpu().numpy())\n", @@ -527,18 +527,22 @@ " print(f'{label}: {num_correct} / {count_classes[label]} correct')\n", "\n", " # Report some random sequences for examination\n", + " num_sequences_to_print = min(10, len(correct))\n", + " idxs = random.sample(range(len(correct)), num_sequences_to_print)\n", + "\n", " print('\\nHere are some example sequences:')\n", - " for i in range(10):\n", - " sequence, truth, prediction = correct[random.randrange(0, 10)]\n", + " for i in idxs:\n", + " sequence, truth, prediction = correct[i]\n", " print(f'{sequence} -> {truth} was labelled {prediction}')\n", "\n", " # Report misclassified sequences for investigation\n", - " if incorrect and verbose:\n", - " print('\\nThe following sequences were misclassified:')\n", - " for sequence, truth, prediction in incorrect:\n", - " print(f'{sequence} -> {truth} was labelled {prediction}')\n", - " else:\n", - " print('\\nThere were no misclassified sequences.')" + " if verbose:\n", + " if incorrect:\n", + " print('\\nThe following sequences were misclassified:')\n", + " for sequence, truth, prediction in incorrect:\n", + " print(f'{sequence} -> {truth} was labelled {prediction}')\n", + " else:\n", + " print('\\nThere were no misclassified sequences.')" ] }, {