Skip to content

Commit

Permalink
Add two new measurements
Browse files Browse the repository at this point in the history
- LabelNTKAlignment
- NTKEigenvectorAlignment
For both to work, add the following functions to matrix utils
- compute_vector_outer_product
- compute_matrix_alignment
  • Loading branch information
KonstiNik committed Aug 14, 2024
1 parent abdae59 commit 456ea29
Show file tree
Hide file tree
Showing 5 changed files with 342 additions and 5 deletions.
50 changes: 49 additions & 1 deletion CI/unit_tests/utils/test_matrix_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,19 @@
"""

import numpy as np
from numpy.testing import assert_array_almost_equal, assert_array_equal, assert_raises
from numpy.testing import (
assert_almost_equal,
assert_array_almost_equal,
assert_array_equal,
assert_raises,
)

from papyrus.utils import (
compute_gramian_diagonal_distribution,
compute_hermitian_eigensystem,
compute_l_pq_norm,
compute_matrix_alignment,
compute_vector_outer_product,
flatten_rank_4_tensor,
normalize_gram_matrix,
unflatten_rank_4_tensor,
Expand Down Expand Up @@ -180,6 +187,47 @@ def test_compute_l_pq_norm(self):
norm_numpy = np.linalg.norm(matrix, ord="fro")
assert_array_almost_equal(norm, norm_numpy)

def test_compute_vector_outer_product(self):
    """
    Check compute_vector_outer_product against hand-written reference values.

    Two inputs are exercised:
        - a rank-1 array of length 3, expected to give a (3, 3) result
        - a rank-2 array of shape (2, 2), expected to give a (2, 2, 2, 2) result
    """
    # Rank-1 input: entry (i, j) of the reference is i * j.
    flat_input = np.arange(3)
    expected_flat = np.array([[0, 0, 0], [0, 1, 2], [0, 2, 4]])
    assert_array_equal(compute_vector_outer_product(flat_input), expected_flat)

    # Rank-2 input: verify the shape first, then the individual entries.
    square_input = np.array([[0, 1], [1, 0]])
    expected_square = np.array(
        [[[[0, 0], [0, 1]], [[0, 0], [1, 0]]], [[[0, 1], [0, 0]], [[1, 0], [0, 0]]]]
    )
    result = compute_vector_outer_product(square_input)
    assert result.shape == (2, 2, 2, 2)
    assert_array_equal(result, expected_square)

def test_compute_matrix_alignment(self):
    """
    Test the computation of the matrix alignment.

    The test uses two fixed 2x2 matrices (the identity and the exchange
    matrix) and checks the following reference values:
        - alignment(A, B) is 0
        - alignment of a matrix with itself is 1
        - alignment(A, -A) is -1
        - alignment(A, 0.5 * A) is 1, i.e. positive rescaling has no effect
        - alignment(A, 0.5 * A + 0.5 * B) is sqrt(0.5)

    NOTE(review): these values are consistent with a normalised Frobenius
    inner product (cosine similarity) -- confirm against the implementation.
    """
    # Two fixed matrices whose element-wise product is zero everywhere.
    A = np.array([[1, 0], [0, 1]])
    B = np.array([[0, 1], [1, 0]])

    # Use a tolerance-based comparison throughout; the result is a float.
    assert_almost_equal(compute_matrix_alignment(A, B), 0)
    assert_almost_equal(compute_matrix_alignment(A, A), 1)
    assert_almost_equal(compute_matrix_alignment(B, B), 1)
    assert_almost_equal(compute_matrix_alignment(A, -A), -1)
    assert_almost_equal(compute_matrix_alignment(A, 0.5 * A), 1)
    assert_almost_equal(
        compute_matrix_alignment(A, 0.5 * A + 0.5 * B), np.sqrt(0.5)
    )

def test_flatten_rank_4_tensor(self):
"""
Test the flattening of a rank 4 tensor.
Expand Down
70 changes: 67 additions & 3 deletions examples/mnist_flax.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,18 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/konstantinnikolaou/Applications/miniconda3/envs/jax/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import jax\n",
"import jax.numpy as jnp # JAX NumPy\n",
Expand All @@ -38,6 +47,61 @@
"import neural_tangents as nt"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from papyrus.utils.matrix_utils import unflatten_rank_4_tensor, compute_matrix_alignment"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0, 0, 0],\n",
" [0, 1, 2],\n",
" [0, 2, 4]])"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"vec = np.arange(3)\n",
"# vec = np.array([ [0, 1], [1, 0] ])\n",
"outer = np.outer(vec, vec)\n",
"# out = unflatten_rank_4_tensor(outer, (vec.shape[0], vec.shape[0], vec.shape[1], vec.shape[1]))\n",
"outer"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"# Create two matrices that are orthogonal to each other\n",
"A = np.array([[1, 0], [0, 1]])\n",
"B = np.array([[0, 1], [1, 0]])\n",
"\n",
"A @ B\n",
"assert compute_matrix_alignment(A, B) == 0\n",
"np.testing.assert_almost_equal(compute_matrix_alignment(A, A), 1)\n",
"np.testing.assert_almost_equal(compute_matrix_alignment(B, B), 1)\n",
"np.testing.assert_almost_equal(compute_matrix_alignment(A, -A), -1)\n",
"np.testing.assert_almost_equal(compute_matrix_alignment(A, 0.5 * A), 1)\n",
"np.testing.assert_almost_equal(compute_matrix_alignment(A, 0.5*A + 0.5*B), np.sqrt(0.5))\n",
"np.testing.assert_almost_equal(compute_matrix_alignment(A, A + B), np.sqrt(0.5))"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -449,7 +513,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
"version": "3.11.9"
}
},
"nbformat": 4,
Expand Down
147 changes: 147 additions & 0 deletions papyrus/measurements/measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@
from papyrus.utils.matrix_utils import (
compute_gramian_diagonal_distribution,
compute_hermitian_eigensystem,
compute_matrix_alignment,
compute_vector_outer_product,
flatten_rank_4_tensor,
)


Expand Down Expand Up @@ -701,3 +704,147 @@ def apply(self, predictions: np.ndarray, targets: np.ndarray) -> np.ndarray:
The derivative of the loss with respect to the neural network outputs.
"""
return self.apply_fn(predictions, targets)


class LabelNTKAlignment(BaseMeasurement):
    """
    Measurement class to record the alignment of the labels with the NTK.

    Neural State Keys
    -----------------
    ntk : np.ndarray
            The Neural Tangent Kernel (NTK) matrix.
    targets : np.ndarray
            The target values (labels) of the neural network.
    """

    def __init__(
        self,
        name: str = "label_ntk_alignment",
        rank: int = 0,
    ):
        """
        Constructor method of the LabelNTKAlignment class.

        Parameters
        ----------
        name : str (default="label_ntk_alignment")
                The name of the measurement, defining how the instance in the
                database will be identified.
        rank : int (default=0)
                The rank of the measurement, defining the tensor order of the
                measurement.
        """
        super().__init__(name, rank)

    def apply(self, ntk: np.ndarray, targets: np.ndarray) -> np.ndarray:
        """
        Method to record the alignment of the labels with the NTK.

        The outer product of the targets with themselves is computed,
        flattened from a rank 4 tensor into a matrix and then compared to the
        NTK using the matrix alignment.

        Parameters need to be provided as keyword arguments.

        Parameters
        ----------
        ntk : np.ndarray
                The Neural Tangent Kernel (NTK) matrix. Must be a square tensor
                of rank 2.
                NOTE(review): presumably its shape has to match the flattened
                label matrix -- confirm against compute_matrix_alignment.
        targets : np.ndarray
                The target values of the neural network. Must be a tensor of
                rank 2.

        Returns
        -------
        np.ndarray
                The alignment of the labels with the NTK.

        Raises
        ------
        ValueError
                If the NTK is not a square tensor of rank 2 or if the targets
                are not a tensor of rank 2.
        """
        # Check the rank first: indexing shape[1] on a rank 1 input would
        # otherwise fail with an IndexError before a helpful message is raised.
        if len(ntk.shape) != 2:
            raise ValueError(
                "To compute the alignment of the labels with the NTK, the NTK "
                "matrix must be a tensor of rank 2, but got a tensor of rank "
                f"{len(ntk.shape)}."
            )
        if ntk.shape[0] != ntk.shape[1]:
            raise ValueError(
                "To compute the alignment of the labels with the NTK, the NTK "
                f"matrix must be a square matrix, but got a matrix of shape "
                f"{ntk.shape}."
            )
        # Assert that the labels are one-hot encoded
        if len(targets.shape) != 2:
            raise ValueError(
                "To compute the alignment of the labels with the NTK, the labels must"
                " be a one-hot encoded tensor of rank 2, but got a tensor of rank "
                f"{len(targets.shape)}."
            )
        label_matrix = compute_vector_outer_product(targets)
        # The second return value (the original shape) is not needed here.
        flat_label_matrix, _ = flatten_rank_4_tensor(label_matrix)
        return compute_matrix_alignment(ntk, flat_label_matrix)


class NTKEigenvectorAlignment(BaseMeasurement):
    """
    Measurement class to record the alignment of the eigenvectors of the NTK.

    The eigenvectors of the NTK passed to `apply` are compared to the
    eigenvectors stored from the previous call to `apply`.

    TODO: Add larger memory size to store and evaluate the alignment of the
    eigenvectors over more than one step.

    Neural State Keys
    -----------------
    ntk : np.ndarray
            The Neural Tangent Kernel (NTK) matrix.
    """

    def __init__(
        self,
        name: str = "ntk_eigenvector_alignment",
        rank: int = 1,
    ):
        """
        Constructor method of the NTKEigenvectorAlignment class.

        Parameters
        ----------
        name : str (default="ntk_eigenvector_alignment")
                The name of the measurement, defining how the instance in the
                database will be identified.
        rank : int (default=1)
                The rank of the measurement, defining the tensor order of the
                measurement.
        """
        super().__init__(name, rank)
        # Eigenvectors from the previous call to `apply`; None until the first
        # call has been made.
        self.previous_eigenvectors = None

    def apply(self, ntk: np.ndarray) -> np.ndarray:
        """
        Method to record the alignment of the eigenvectors of the NTK.

        On the first call there are no previous eigenvectors to compare to, so
        the eigenvectors are stored and 0 is returned. On every later call the
        alignment between the current and the previously stored eigenvectors
        is returned and the stored eigenvectors are replaced by the current
        ones.

        Parameters need to be provided as keyword arguments.

        Parameters
        ----------
        ntk : np.ndarray
                The Neural Tangent Kernel (NTK) matrix. Must be a square tensor
                of rank 2.

        Returns
        -------
        np.ndarray
                The alignment of the eigenvectors of the NTK with those of the
                previous call, or 0 on the first call.

        Raises
        ------
        ValueError
                If the NTK is not a square tensor of rank 2.
        """
        # Check the rank first: indexing shape[1] on a rank 1 input would
        # otherwise fail with an IndexError before a helpful message is raised.
        if len(ntk.shape) != 2:
            raise ValueError(
                "To compute the alignment of the eigenvectors of the NTK, the NTK "
                "matrix must be a tensor of rank 2, but got a tensor of rank "
                f"{len(ntk.shape)}."
            )
        if ntk.shape[0] != ntk.shape[1]:
            raise ValueError(
                "To compute the alignment of the eigenvectors of the NTK, the NTK "
                f"matrix must be a square matrix, but got a matrix of shape {ntk.shape}."
            )
        # The eigensystem is needed on every path, so compute it exactly once.
        eigenvectors = compute_hermitian_eigensystem(ntk)[1]
        if self.previous_eigenvectors is None:
            self.previous_eigenvectors = eigenvectors
            return 0
        alignment = compute_matrix_alignment(eigenvectors, self.previous_eigenvectors)
        # Only overwrite the stored eigenvectors after a successful comparison.
        self.previous_eigenvectors = eigenvectors
        return alignment
4 changes: 4 additions & 0 deletions papyrus/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
compute_gramian_diagonal_distribution,
compute_hermitian_eigensystem,
compute_l_pq_norm,
compute_matrix_alignment,
compute_vector_outer_product,
flatten_rank_4_tensor,
normalize_gram_matrix,
unflatten_rank_4_tensor,
Expand All @@ -45,4 +47,6 @@
compute_trace.__name__,
compute_shannon_entropy.__name__,
compute_von_neumann_entropy.__name__,
compute_vector_outer_product.__name__,
compute_matrix_alignment.__name__,
]
Loading

0 comments on commit 456ea29

Please sign in to comment.