Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
calad0i committed Nov 13, 2024
1 parent 225f6cb commit c4efd51
Showing 1 changed file with 34 additions and 0 deletions.
34 changes: 34 additions & 0 deletions test/pytest/test_transpose_concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,37 @@ def test_accuracy(data, keras_model, hls_model):
y_hls4ml = hls_model.predict(X).reshape(y_keras.shape)
# "accuracy" of hls4ml predictions vs keras
np.testing.assert_allclose(y_keras, y_hls4ml, rtol=0, atol=1e-04, verbose=True)


@pytest.fixture(scope='module')
def keras_model_highdim():
inp = Input(shape=(2, 3, 4, 5, 6), name='input_1')
out = Permute((3, 5, 4, 1, 2))(inp)
model = Model(inputs=inp, outputs=out)
return model


@pytest.fixture(scope='module')
def data_highdim():
X = np.random.randint(-128, 127, (100, 2, 3, 4, 5, 6)) / 128
X = X.astype(np.float32)
return X


@pytest.mark.parametrize('io_type', ['io_stream', 'io_parallel'])
@pytest.mark.parametrize('backend', ['Vivado', 'Vitis'])
def test_highdim_permute(data_highdim, keras_model_highdim, io_type, backend):
X = data_highdim
model = keras_model_highdim

model_hls = hls4ml.converters.convert_from_keras_model(
model,
io_type=io_type,
backend=backend,
output_dir=str(test_root_path / f'hls4mlprj_highdim_transpose_{backend}_{io_type}'),
)
model_hls.compile()
y_keras = model.predict(X)
y_hls4ml = model_hls.predict(X).reshape(y_keras.shape) # type: ignore

assert np.all(y_keras == y_hls4ml)

0 comments on commit c4efd51

Please sign in to comment.