Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add general transpose for vivado/vitis #1124

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 47 additions & 14 deletions hls4ml/backends/vivado/passes/reshaping_templates.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from math import prod

import numpy as np

from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate
from hls4ml.model.layers import Resize, Transpose, ZeroPadding1D, ZeroPadding2D

Expand Down Expand Up @@ -97,16 +101,45 @@ def format(self, node):

# Transpose templates

transpose_config_template = """struct config{index} : nnet::transpose_config {{
static const unsigned depth = {depth};
static const unsigned height = {height};
static const unsigned width = {width};
static constexpr unsigned perm[3] = {{{perm_str}}};
}};\n"""

transpose_function_template = 'nnet::transpose_{dim}<{input_t}, {output_t}, {config}>({input}, {output});'
transpose_include_list = ['nnet_utils/nnet_transpose.h', 'nnet_utils/nnet_transpose_stream.h']

transpose_config_template = """struct {config_name} {{
static const unsigned dims = {dims};
static const unsigned N = {N};
static const unsigned* const from_shape;
static const unsigned* const to_shape;
static const unsigned* const perm;
static const unsigned* const perm_strides;
}};

unsigned {config_name}_from_shape[{dims}] = {{{from_shape}}};
unsigned {config_name}_to_shape[{dims}] = {{{to_shape}}};
unsigned {config_name}_perm[{dims}] = {{{perm}}};
unsigned {config_name}_perm_strides[{dims}] = {{{perm_strides}}};

const unsigned* const {config_name}::from_shape = {config_name}_from_shape;
const unsigned* const {config_name}::to_shape = {config_name}_to_shape;
const unsigned* const {config_name}::perm = {config_name}_perm;
const unsigned* const {config_name}::perm_strides = {config_name}_perm_strides;
"""

transpose_function_template = 'nnet::transpose<{input_t}, {output_t}, {config_name}>({input}, {output});'

transpose_include_list = ['nnet_utils/nnet_array.h', 'nnet_utils/nnet_stream.h']

def permute_config_gen(name: str, shape: tuple[int, ...], perm: tuple[int, ...]):
new_shape = tuple(shape[i] for i in perm)
strides = np.cumprod((shape[1:] + (1,))[::-1])[::-1]
perm_strides = tuple(int(strides[i]) for i in perm)
return transpose_config_template.format(
dims=len(shape),
N=prod(shape),
from_shape=', '.join(str(x) for x in shape),
perm=', '.join(str(x) for x in perm),
perm_strides=', '.join(str(x) for x in perm_strides),
to_shape=', '.join(str(x) for x in new_shape),
config_name=name,
)


class TransposeConfigTemplate(LayerConfigTemplate):
Expand All @@ -115,18 +148,18 @@ def __init__(self):
self.template = transpose_config_template

def format(self, node):
params = self._default_config_params(node)

return self.template.format(**params)
shape = tuple(node.get_input_variable().shape)
perm = tuple(node.get_attr('perm'))
name = f'config{node.index}'
return permute_config_gen(name, shape, perm)


class TransposeFunctionTemplate(FunctionCallTemplate):
def __init__(self):
super().__init__(Transpose, include_header=transpose_include_list)
self.template = transpose_function_template
super().__init__(Transpose, include_header=transpose_include_list)

def format(self, node):
params = self._default_function_params(node)
params['dim'] = node.get_attr('dim')

params['config_name'] = f'config{node.index}'
return self.template.format(**params)
8 changes: 5 additions & 3 deletions hls4ml/model/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1167,8 +1167,7 @@ def initialize(self):
perm = self.get_attr('perm')
self.set_attr('dim', f'{len(inp.shape)}d')

if len(perm) > 3:
raise Exception('ERROR: Transpose of tensors with rank > 3 is not yet supported.')
# TODO: dim>3 is only supported for vivado/vitis backend

# ONNX double transpose specific, sometimes ONNX injects
# useless double transpose layers when converting
Expand All @@ -1188,11 +1187,14 @@ def initialize(self):
self.set_attr('depth', 1)
self.set_attr('height', inp.shape[0])
self.set_attr('width', inp.shape[1])
elif len(shape) > 2:
elif len(shape) == 3:
dims = [f'OUT_DEPTH_{self.index}', f'OUT_HEIGHT_{self.index}', f'OUT_WIDTH_{self.index}']
self.set_attr('depth', inp.shape[0])
self.set_attr('height', inp.shape[1])
self.set_attr('width', inp.shape[2])
elif len(shape) > 3:
# Differentiate between 2/3/3+ dim does not really appear to be needed. To be removed?
dims = [f'OUT_DIM_{i}_{self.index}' for i in range(1, len(shape) + 1)]
self.add_output_variable(shape, dims, precision=inp.type.precision)


Expand Down
52 changes: 0 additions & 52 deletions hls4ml/templates/vivado/nnet_utils/nnet_array.h

This file was deleted.

23 changes: 0 additions & 23 deletions hls4ml/templates/vivado/nnet_utils/nnet_stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,29 +179,6 @@ void broadcast_stream(hls::stream<data_T> &data, hls::stream<res_T> &res) {
}
}

template <class data_T, class res_T, typename CONFIG_T>
void transpose_2d(hls::stream<data_T> &data, hls::stream<res_T> &res) {
typename data_T::value_type data_array[CONFIG_T::height * CONFIG_T::width];
#pragma HLS ARRAY_PARTITION variable=data_array complete

for (int i = 0; i < CONFIG_T::height * CONFIG_T::width / data_T::size; i++) {
#pragma HLS PIPELINE
data_T in_data = data.read();
for (int j = 0; j < data_T::size; j++) {
data_array[i * data_T::size + j] = typename data_T::value_type(in_data[j]);
}
}

for (int i = 0; i < CONFIG_T::height * CONFIG_T::width / res_T::size; i++) {
#pragma HLS PIPELINE
res_T out_data;
PRAGMA_DATA_PACK(out_data)
for (int j = 0; j < res_T::size; j++) {
out_data[j] = typename res_T::value_type(data_array[j * data_T::size + i]);
}
res.write(out_data);
}
}
} // namespace nnet

#endif
39 changes: 39 additions & 0 deletions hls4ml/templates/vivado/nnet_utils/nnet_transpose.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#ifndef NNET_PERMUTE_H_
#define NNET_PERMUTE_H_

namespace nnet {

struct transpose_config {
static const unsigned dims;
static const unsigned N;
// vivado/vitis hls can't index constexpr array for some reason
// and vivado hls don't like template recursion either (vitis is fine)
// thus this appears to be the only workaround (or overkill it with codegen)
static const unsigned *const from_shape;
static const unsigned *const to_shape;
static const unsigned *const perm;
static const unsigned *const perm_strides;
};

template <typename CONFIG_T> unsigned transfer_idx(int index) {
// Given output idx in c-order flat array, return input idx
int idx = 0;
for (int i = CONFIG_T::dims - 1; i >= 0; i--) {
idx += (index % CONFIG_T::to_shape[i]) * CONFIG_T::perm_strides[i];
index /= CONFIG_T::to_shape[i];
}
return idx;
}

template <typename data_T, typename res_T, typename CONFIG_T>
void transpose(const data_T data[CONFIG_T::N], res_T res[CONFIG_T::N]) {
for (int i = 0; i < CONFIG_T::N; i++) {
#pragma HLS UNROLL
int idx = transfer_idx<CONFIG_T>(i);
res[i] = data[idx];
}
}

} // namespace nnet

#endif
67 changes: 67 additions & 0 deletions hls4ml/templates/vivado/nnet_utils/nnet_transpose_stream.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#ifndef NNET_TRANSPOSE_STREAM_H
#define NNET_TRANSPOSE_STREAM_H

#include "hls_stream.h"
#include "nnet_transpose.h"
#include <type_traits>

namespace nnet {

template <typename data_T, typename res_T, typename CONFIG_T>
typename std::enable_if<CONFIG_T::dims == 2, void>::type transpose(hls::stream<data_T> &data, hls::stream<res_T> &res) {
// #pragma HLS INLINE RECURSIVE
typename data_T::value_type data_array[CONFIG_T::N];
#pragma HLS ARRAY_PARTITION variable=data_array complete

for (int i = 0; i < CONFIG_T::N / data_T::size; i++) {
#pragma HLS PIPELINE
data_T in_data = data.read();
for (int j = 0; j < data_T::size; j++) {
#pragma HLS UNROLL
data_array[i * data_T::size + j] = typename data_T::value_type(in_data[j]);
}
}

for (int i = 0; i < CONFIG_T::N / res_T::size; i++) {
#pragma HLS PIPELINE
res_T out_data;
PRAGMA_DATA_PACK(out_data)
for (int j = 0; j < res_T::size; j++) {
#pragma HLS UNROLL
out_data[j] = typename res_T::value_type(data_array[j * CONFIG_T::from_shape[1] + i]);
}
res.write(out_data);
}
}

// This sfinae is for vivado_hls, which has some overhead using the transfer_idx in io_stream.
// In vitis both performs exactly the same, thus this is not removed out of convenience.
template <typename data_T, typename res_T, typename CONFIG_T>
typename std::enable_if<CONFIG_T::dims != 2, void>::type transpose(hls::stream<data_T> &data, hls::stream<res_T> &res) {
// #pragma HLS INLINE RECURSIVE
typename data_T::value_type data_array[CONFIG_T::N];
#pragma HLS ARRAY_PARTITION variable=data_array complete

for (int i = 0; i < CONFIG_T::N / data_T::size; i++) {
#pragma HLS PIPELINE
data_T in_data = data.read();
for (int j = 0; j < data_T::size; j++) {
#pragma HLS UNROLL
data_array[i * data_T::size + j] = typename data_T::value_type(in_data[j]);
}
}

for (int i = 0; i < CONFIG_T::N / res_T::size; i++) {
#pragma HLS PIPELINE
res_T out_data;
PRAGMA_DATA_PACK(out_data)
for (int j = 0; j < res_T::size; j++) {
#pragma HLS UNROLL
out_data[j] = typename res_T::value_type(data_array[transfer_idx<CONFIG_T>(i * res_T::size + j)]);
}
res.write(out_data);
}
}

} // namespace nnet
#endif
34 changes: 34 additions & 0 deletions test/pytest/test_transpose_concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,37 @@ def test_accuracy(data, keras_model, hls_model):
y_hls4ml = hls_model.predict(X).reshape(y_keras.shape)
# "accuracy" of hls4ml predictions vs keras
np.testing.assert_allclose(y_keras, y_hls4ml, rtol=0, atol=1e-04, verbose=True)


@pytest.fixture(scope='module')
def keras_model_highdim():
inp = Input(shape=(2, 3, 4, 5, 6), name='input_1')
out = Permute((3, 5, 4, 1, 2))(inp)
model = Model(inputs=inp, outputs=out)
return model


@pytest.fixture(scope='module')
def data_highdim():
X = np.random.randint(-128, 127, (100, 2, 3, 4, 5, 6)) / 128
X = X.astype(np.float32)
return X


@pytest.mark.parametrize('io_type', ['io_stream', 'io_parallel'])
@pytest.mark.parametrize('backend', ['Vivado', 'Vitis'])
def test_highdim_permute(data_highdim, keras_model_highdim, io_type, backend):
X = data_highdim
model = keras_model_highdim

model_hls = hls4ml.converters.convert_from_keras_model(
model,
io_type=io_type,
backend=backend,
output_dir=str(test_root_path / f'hls4mlprj_highdim_transpose_{backend}_{io_type}'),
)
model_hls.compile()
y_keras = model.predict(X)
y_hls4ml = model_hls.predict(X).reshape(y_keras.shape) # type: ignore

assert np.all(y_keras == y_hls4ml)
Loading