diff --git a/.gitignore b/.gitignore
index d30f7bf..9a06dc6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -498,3 +498,4 @@ cmake-build-*/
 *.fbs
 **/fletcherfiltering_test_workspace/**
 **/mysql-data/**
+vivado-projects/**
diff --git a/.idea/sqldialects.xml b/.idea/sqldialects.xml
new file mode 100644
index 0000000..56782ca
--- /dev/null
+++ b/.idea/sqldialects.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/run-pytest.sh b/run-pytest.sh
index 605149f..545412b 100755
--- a/run-pytest.sh
+++ b/run-pytest.sh
@@ -2,4 +2,4 @@
 
 source activate FletcherFiltering
 
-PYTHONPATH="$PYTHONPATH:`pwd`/src:`pwd`/../transpyle:`pwd`/../fletcher/codegen:`pwd`/../moz-sql-parser" python -m pytest -s --show-progress --print-relative-time --verbose --cov=fletcherfiltering "$@" tests/
+PYTHONPATH="$PYTHONPATH:`pwd`/src:`pwd`/../transpyle:`pwd`/../fletcher/codegen:`pwd`/../moz-sql-parser" python -m pytest -rxXs --show-progress --print-relative-time --verbose --cov=fletcherfiltering "$@" tests/
diff --git a/run_pytest.ps1 b/run_pytest.ps1
index b048e8a..98a03af 100644
--- a/run_pytest.ps1
+++ b/run_pytest.ps1
@@ -6,4 +6,4 @@ conda activate FletcherFiltering
 $pwd = Get-Location
 $env:PYTHONPATH ="$pwd\src;$pwd\..\transpyle;$pwd\..\fletcher\codegen;$pwd\..\moz-sql-parser"
 
-python -m pytest -s --show-progress --print-relative-time --verbose --cov=fletcherfiltering @Passthrough tests/
+python -m pytest -rxXs --show-progress --print-relative-time --verbose --cov=fletcherfiltering @Passthrough tests/
diff --git a/src/fletcherfiltering/codegen/compiler.py b/src/fletcherfiltering/codegen/compiler.py
index eeb31a1..b7ebf66 100644
--- a/src/fletcherfiltering/codegen/compiler.py
+++ b/src/fletcherfiltering/codegen/compiler.py
@@ -24,6 +24,8 @@
 from .transformations.ConstantPropagationTransform import ConstantPropagationTransform
 from .transformations.PythonASTTransform import PythonASTTransform
 
+from collections import namedtuple
+
 # These templates are all formatted, so double up curly braces.
 source_header_header = """#pragma once
 #if _MSC_VER && !__INTEL_COMPILER
@@ -48,6 +50,8 @@
 
 source_header_test_footer = """}}"""
 
+TemplateData = namedtuple('TemplateData', ['source', 'destination'])
+
 
 class Compiler(object):
     def __init__(self, in_schema: pa.Schema, out_schema: pa.Schema):
@@ -75,7 +79,7 @@ def __call__(self, query_str: str, output_dir: Path = Path('.'), query_name: str
                  include_build_system: bool = True, include_test_system: bool = True,
                  extra_include_dirs: List[PurePath] = '', extra_link_dirs: List[PurePath] = '',
-                 extra_link_libraries: List[str] = '', part_name: str = 'xa7a12tcsg325-1q'):
+                 extra_link_libraries: List[str] = '', part_name: str = settings.PART_NAME):
         assert isinstance(query_str, str)
 
         queries = self.parse(query_str)
@@ -100,6 +104,7 @@
         out_str_columns = [x.name for x in self.out_schema if x.type == pa.string()]
 
         template_data = {
+            'VAR_LENGTH': settings.VAR_LENGTH,
             'query_name': query_name,
             'generated_files': " ".join(generated_files),
             'extra_include_dirs': ' '.join([d.as_posix() for d in extra_include_dirs]),
@@ -113,32 +118,37 @@
             'out_columns_tb_new': "\n ".join([tb_new.format(x, settings.VAR_LENGTH + 1) for x in out_str_columns]),
         }
 
-        if include_build_system:
-            self.copy_files(self.template_path, output_dir.resolve(),
-                            [Path('fletcherfiltering.cpp'), Path('fletcherfiltering.h')])
+        build_system_files = [
+            TemplateData(self.template_path / Path('template.fletcherfiltering.cpp'),
+                         output_dir / Path('fletcherfiltering.cpp')),
+            TemplateData(self.template_path / Path('template.fletcherfiltering.h'),
+                         output_dir / Path('fletcherfiltering.h')),
+            TemplateData(self.template_path / Path('template.CMakeLists.txt'),
+                         output_dir / Path('CMakeLists.txt')),
+        ]
+
+        test_system_files = [
+            TemplateData(self.template_path / Path('template.run_complete_hls.tcl'),
+                         output_dir / Path('run_complete_hls.tcl')),
+            TemplateData(self.template_path / Path('template.testbench.cpp'),
+                         output_dir / Path('{0}{1}.cpp'.format(query_name, settings.TESTBENCH_SUFFIX))),
+            TemplateData(self.template_path / Path('template.data.h'),
+                         output_dir / Path('{0}{1}.h'.format(query_name, settings.DATA_SUFFIX))),
+        ]
 
-        with open(self.template_path / 'template.CMakeLists.txt', 'r') as template_file:
-            cmake_list = string.Template(template_file.read())
-            with open(output_dir / Path('CMakeLists.txt'), 'w') as cmake_file:
-                cmake_file.write(cmake_list.safe_substitute(template_data))
+        if include_build_system:
+            for file in build_system_files:
+                with open(file.source, 'r') as template_file:
+                    output_data = string.Template(template_file.read())
+                    with open(file.destination, 'w') as output_file:
+                        output_file.write(output_data.safe_substitute(template_data))
 
         if include_test_system:
-            with open(self.template_path / 'template.run_complete_hls.tcl', 'r') as template_file:
-                hls_tcl = string.Template(template_file.read())
-                with open(output_dir / Path('run_complete_hls.tcl'), 'w') as hls_tcl_file:
-                    hls_tcl_file.write(hls_tcl.safe_substitute(template_data))
-
-            with open(self.template_path / 'template.testbench.cpp', 'r') as template_file:
-                teshbench_cpp = string.Template(template_file.read())
-                with open(output_dir / Path('{0}{1}.cpp'.format(query_name, settings.TESTBENCH_SUFFIX)),
-                          'w') as teshbench_cpp_file:
-                    teshbench_cpp_file.write(teshbench_cpp.safe_substitute(template_data))
-
-            with open(self.template_path / 'template.data.h', 'r') as template_file:
-                data_cpp = string.Template(template_file.read())
-                with open(output_dir / Path('{0}{1}.h'.format(query_name, settings.DATA_SUFFIX)),
-                          'w') as data_cpp_file:
-                    data_cpp_file.write(data_cpp.safe_substitute(template_data))
+            for file in test_system_files:
+                with open(file.source, 'r') as template_file:
+                    output_data = string.Template(template_file.read())
+                    with open(file.destination, 'w') as output_file:
+                        output_file.write(output_data.safe_substitute(template_data))
 
     def copy_files(self, source_dir: PurePath, output_dir: PurePath, file_list: List[Path]):
         if source_dir == output_dir:
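Note: every `template.*` file above is rendered with Python's `string.Template`, so `${...}` placeholders are filled from the `template_data` dict while `safe_substitute()` leaves unknown placeholders (and the doubled `{{ }}` braces) untouched. A minimal sketch of that mechanism, with an illustrative template string and placeholder name that are not taken from the repository:

```python
import string

# Illustrative template text; 'missing' is a deliberately unknown placeholder.
template = string.Template("#define VAR_LENGTH ${VAR_LENGTH} /* ${missing} stays */")

# Keys mirror the template_data dict built in Compiler.__call__.
print(template.safe_substitute({'VAR_LENGTH': 32}))
# -> #define VAR_LENGTH 32 /* ${missing} stays */
```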
diff --git a/src/fletcherfiltering/codegen/templates/fletcherfiltering.cpp b/src/fletcherfiltering/codegen/templates/template.fletcherfiltering.cpp
similarity index 62%
rename from src/fletcherfiltering/codegen/templates/fletcherfiltering.cpp
rename to src/fletcherfiltering/codegen/templates/template.fletcherfiltering.cpp
index 2ba9728..3090a73 100644
--- a/src/fletcherfiltering/codegen/templates/fletcherfiltering.cpp
+++ b/src/fletcherfiltering/codegen/templates/template.fletcherfiltering.cpp
@@ -6,8 +6,11 @@ bool __sql_builtin_like(char* data, int len, const char* pattern_name){
 }
 
 void __sql_builtin_concat(char* buffer, int* offset, const char* value, int length){
-    for(int i = 0; i < length && *offset < STRING_SIZE; i++, (*offset)++){
-        buffer[*offset] = value[i];
+    for(int i = 0, j = *offset; i < length && j < VAR_LENGTH; i++, (j)++){
+        #pragma HLS PIPELINE II=1
+        buffer[j] = value[i];
     }
+    *offset += length;
+    if(*offset >= VAR_LENGTH) *offset = VAR_LENGTH - 1; // keep the terminator below in bounds when the copy truncates
     buffer[*offset] = '\0';
 }
\ No newline at end of file
diff --git a/src/fletcherfiltering/codegen/templates/fletcherfiltering.h b/src/fletcherfiltering/codegen/templates/template.fletcherfiltering.h
similarity index 89%
rename from src/fletcherfiltering/codegen/templates/fletcherfiltering.h
rename to src/fletcherfiltering/codegen/templates/template.fletcherfiltering.h
index 5b8f859..431c4cb 100644
--- a/src/fletcherfiltering/codegen/templates/fletcherfiltering.h
+++ b/src/fletcherfiltering/codegen/templates/template.fletcherfiltering.h
@@ -1,6 +1,6 @@
 #pragma once
 #include <cstdint>
-#define STRING_SIZE 255
+#define VAR_LENGTH ${VAR_LENGTH}
 
 template<typename T>
 struct nullable {
diff --git a/src/fletcherfiltering/codegen/templates/template.run_complete_hls.tcl b/src/fletcherfiltering/codegen/templates/template.run_complete_hls.tcl
index eb6cc2c..3c16a38 100644
--- a/src/fletcherfiltering/codegen/templates/template.run_complete_hls.tcl
+++ b/src/fletcherfiltering/codegen/templates/template.run_complete_hls.tcl
@@ -16,6 +16,6 @@ set_part {${part_name}}
 create_clock -period 10 -name default
 csim_design -O
 csynth_design
-cosim_design -trace_level all -rtl vhdl
+cosim_design -O -trace_level all -rtl vhdl
 #export_design -rtl vhdl -format ip_catalog
 exit
\ No newline at end of file
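For reference, the concatenation helper as patched above (including the truncation clamp) behaves like the following plain-Python model. `VAR_LENGTH = 32` matches the new default in `settings.py`, and `sql_builtin_concat` is just an illustrative mirror of the C helper, not repository code:

```python
VAR_LENGTH = 32  # mirrors the ${VAR_LENGTH} value rendered into the header

def sql_builtin_concat(buffer: bytearray, offset: int, value: bytes) -> int:
    """Append value at offset, truncate at VAR_LENGTH, keep a NUL terminator in bounds."""
    j = offset
    for byte in value:
        if j >= VAR_LENGTH:
            break
        buffer[j] = byte
        j += 1
    offset = min(offset + len(value), VAR_LENGTH - 1)
    buffer[offset] = 0
    return offset

buf = bytearray(VAR_LENGTH)
off = sql_builtin_concat(buf, 0, b"123456")
off = sql_builtin_concat(buf, off, b"abc")
assert buf[:off] == b"123456abc"
```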
diff --git a/src/fletcherfiltering/codegen/transformations/PythonASTTransform.py b/src/fletcherfiltering/codegen/transformations/PythonASTTransform.py
index 7819d20..d7be6b4 100644
--- a/src/fletcherfiltering/codegen/transformations/PythonASTTransform.py
+++ b/src/fletcherfiltering/codegen/transformations/PythonASTTransform.py
@@ -1,4 +1,5 @@
 import typed_ast.ast3 as ast
+import horast
 from typing import Union, Tuple
 from moz_sql_parser import ast_nodes
 import pyarrow as pa
@@ -25,7 +26,8 @@ def transform(self, node, query_name: str = 'query'):
 
     # region Schema Helpers
 
-    def get_schema_ast(self, schema: pa.Schema, schema_name: str = "schema", test_data=False, is_input=False) -> ast.Expr:
+    def get_schema_ast(self, schema: pa.Schema, schema_name: str = "schema", test_data=False,
+                       is_input=False) -> ast.Expr:
         schema_ast = []
         for col in schema:
             col_def = ast.AnnAssign(
@@ -68,7 +70,7 @@ def get_load_schema_ast(self, schema: pa.Schema):
                             id=col.name,
                             ctx=ast.Store()
                         ),
-                        slice=ast.Index(ast.Num(settings.VAR_LENGTH)),
+                        slice=ast.Index(ast.Name(id="VAR_LENGTH", ctx=ast.Load())),
                         ctx=ast.Store()
                     ),
                     annotation=self.type_resolver.resolve(arrow_type=col.type),
@@ -107,27 +109,28 @@ def get_load_schema_ast(self, schema: pa.Schema):
                         ],
                         keywords=[]
                     ),
-                    body=[ast.Expr(ast.BinOp(
-                        left=ast.Attribute(
-                            value=ast.Name(
-                                id='in_data',
-                                ctx=ast.Load()),
-                            attr=col.name,
-                            ctx=ast.Load()),
-                        op=ast.RShift(),
-                        right=ast.Subscript(
-                            value=ast.Name(
-                                id=col.name,
-                                ctx=ast.Load()
-                            ),
-                            slice=ast.Index(ast.Name(
-                                id='i',
-                                ctx=ast.Load()
-                            )),
-                            ctx=ast.Store()
-                        ),
-                        type_comment=None)
-                    )],
+                    body=[make_comment("pragma HLS PIPELINE"),
+                          ast.Expr(ast.BinOp(
+                              left=ast.Attribute(
+                                  value=ast.Name(
+                                      id=settings.INPUT_NAME,
+                                      ctx=ast.Load()),
+                                  attr=col.name,
+                                  ctx=ast.Load()),
+                              op=ast.RShift(),
+                              right=ast.Subscript(
+                                  value=ast.Name(
+                                      id=col.name,
+                                      ctx=ast.Load()
+                                  ),
+                                  slice=ast.Index(ast.Name(
+                                      id='i',
+                                      ctx=ast.Load()
+                                  )),
+                                  ctx=ast.Store()
+                              ),
+                              type_comment=None)
+                          )],
                     orelse=None,
                     type_comment=ast.Name(
                         id='int',
@@ -138,7 +141,7 @@ def get_load_schema_ast(self, schema: pa.Schema):
             col_load = ast.Expr(ast.BinOp(
                 left=ast.Attribute(
                     value=ast.Name(
-                        id='in_data',
+                        id=settings.INPUT_NAME,
                         ctx=ast.Load()),
                     attr=col.name,
                     ctx=ast.Load()),
@@ -150,7 +153,7 @@ def get_load_schema_ast(self, schema: pa.Schema):
             col_load_len = ast.Expr(ast.BinOp(
                 left=ast.Attribute(
                     value=ast.Name(
-                        id='in_data',
+                        id=settings.INPUT_NAME,
                         ctx=ast.Load()),
                     attr=col.name + settings.LENGTH_SUFFIX,
                     ctx=ast.Load()),
@@ -187,29 +190,30 @@ def get_load_test_ast(self, schema: pa.Schema):
                         ],
                         keywords=[]
                    ),
-                    body=[ast.Expr(ast.BinOp(
-                        left=ast.Attribute(
-                            value=ast.Name(
-                                id='in_data',
-                                ctx=ast.Load()),
-                            attr=col.name,
-                            ctx=ast.Load()),
-                        op=ast.LShift(),
-                        right=ast.Subscript(
-                            value=ast.Attribute(
-                                value=ast.Name(
-                                    id='in_data_test',
-                                    ctx=ast.Load()),
-                                attr=col.name,
-                                ctx=ast.Load()),
-                            slice=ast.Index(ast.Name(
-                                id='i',
-                                ctx=ast.Load()
-                            )),
-                            ctx=ast.Store()
-                        ),
-                        type_comment=None)
-                    )],
+                    body=[make_comment("pragma HLS PIPELINE"),
+                          ast.Expr(ast.BinOp(
+                              left=ast.Attribute(
+                                  value=ast.Name(
+                                      id=settings.INPUT_NAME,
+                                      ctx=ast.Load()),
+                                  attr=col.name,
+                                  ctx=ast.Load()),
+                              op=ast.LShift(),
+                              right=ast.Subscript(
+                                  value=ast.Attribute(
+                                      value=ast.Name(
+                                          id=settings.INPUT_NAME + settings.TEST_SUFFIX,
+                                          ctx=ast.Load()),
+                                      attr=col.name,
+                                      ctx=ast.Load()),
+                                  slice=ast.Index(ast.Name(
+                                      id='i',
+                                      ctx=ast.Load()
+                                  )),
+                                  ctx=ast.Store()
+                              ),
+                              type_comment=None)
+                          )],
                     orelse=None,
                     type_comment=ast.Name(
                         id='int',
@@ -220,14 +224,14 @@ def get_load_test_ast(self, schema: pa.Schema):
             col_load = ast.Expr(ast.BinOp(
                 left=ast.Attribute(
                     value=ast.Name(
-                        id='in_data',
+                        id=settings.INPUT_NAME,
                         ctx=ast.Load()),
                     attr=col.name,
                     ctx=ast.Load()),
                 op=ast.LShift(),
                 right=ast.Attribute(
                     value=ast.Name(
-                        id='in_data_test',
+                        id=settings.INPUT_NAME + settings.TEST_SUFFIX,
                         ctx=ast.Load()),
                     attr=col.name,
                     ctx=ast.Load()),
@@ -246,7 +250,7 @@ def get_load_test_ast(self, schema: pa.Schema):
                 args=[
                     ast.Attribute(
                         value=ast.Name(
-                            id='in_data_test',
+                            id=settings.INPUT_NAME + settings.TEST_SUFFIX,
                             ctx=ast.Load()
                         ),
                         attr=col.name,
@@ -258,7 +262,7 @@ def get_load_test_ast(self, schema: pa.Schema):
             col_load_len = ast.Expr(ast.BinOp(
                 left=ast.Attribute(
                     value=ast.Name(
-                        id='in_data',
+                        id=settings.INPUT_NAME,
                         ctx=ast.Load()),
                     attr=col.name + settings.LENGTH_SUFFIX,
                     ctx=ast.Load()),
@@ -297,28 +301,29 @@ def get_store_schema_ast(self, schema: pa.Schema):
                         ],
                         keywords=[]
                    ),
-                    body=[ast.Expr(ast.BinOp(
-                        left=
-                        ast.Attribute(
-                            value=ast.Name(
-                                id='out_data',
-                                ctx=ast.Load()),
-                            attr=col.name,
-                            ctx=ast.Store()),
-                        op=ast.LShift(),
-                        right=ast.Subscript(
-                            value=ast.Name(
-                                id=col_name_code,
-                                ctx=ast.Load()
-                            ),
-                            slice=ast.Index(ast.Name(
-                                id='i',
-                                ctx=ast.Load()
-                            )),
-                            ctx=ast.Load()
-                        ),
-                        type_comment=None))
-                    ],
+                    body=[make_comment("pragma HLS PIPELINE"),
+                          ast.Expr(ast.BinOp(
+                              left=
+                              ast.Attribute(
+                                  value=ast.Name(
+                                      id=settings.OUTPUT_NAME,
+                                      ctx=ast.Load()),
+                                  attr=col.name,
+                                  ctx=ast.Store()),
+                              op=ast.LShift(),
+                              right=ast.Subscript(
+                                  value=ast.Name(
+                                      id=col_name_code,
+                                      ctx=ast.Load()
+                                  ),
+                                  slice=ast.Index(ast.Name(
+                                      id='i',
+                                      ctx=ast.Load()
+                                  )),
+                                  ctx=ast.Load()
+                              ),
+                              type_comment=None))
+                          ],
                     orelse=None,
                     type_comment=ast.Name(
                         id='int',
@@ -330,7 +335,7 @@ def get_store_schema_ast(self, schema: pa.Schema):
                 left=
                 ast.Attribute(
                     value=ast.Name(
-                        id='out_data',
+                        id=settings.OUTPUT_NAME,
                         ctx=ast.Load()),
                     attr=col.name,
                     ctx=ast.Store()),
@@ -345,7 +350,7 @@ def get_store_schema_ast(self, schema: pa.Schema):
                 left=
                 ast.Attribute(
                     value=ast.Name(
-                        id='out_data',
+                        id=settings.OUTPUT_NAME,
                         ctx=ast.Load()),
                     attr=col.name + settings.LENGTH_SUFFIX,
                     ctx=ast.Store()),
@@ -387,7 +392,7 @@ def get_store_test_ast(self, schema: pa.Schema):
                 left=
                 ast.Attribute(
                     value=ast.Name(
-                        id='out_data',
+                        id=settings.OUTPUT_NAME,
                         ctx=ast.Load()),
                     attr=col.name,
                     ctx=ast.Store()),
@@ -395,7 +400,7 @@ def get_store_test_ast(self, schema: pa.Schema):
                 right=ast.Subscript(
                     value=ast.Attribute(
                         value=ast.Name(
-                            id='out_data_test',
+                            id=settings.OUTPUT_NAME + settings.TEST_SUFFIX,
                             ctx=ast.Load()),
                         attr=col.name,
                         ctx=ast.Load()),
@@ -418,14 +423,14 @@ def get_store_test_ast(self, schema: pa.Schema):
                 left=
                 ast.Attribute(
                     value=ast.Name(
-                        id='out_data',
+                        id=settings.OUTPUT_NAME,
                         ctx=ast.Load()),
                     attr=col.name,
                     ctx=ast.Store()),
                 op=ast.RShift(),
                 right=ast.Attribute(
                     value=ast.Name(
-                        id='out_data_test',
+                        id=settings.OUTPUT_NAME + settings.TEST_SUFFIX,
                         ctx=ast.Load()),
                     attr=col.name,
                     ctx=ast.Load()),
@@ -436,7 +441,7 @@ def get_store_test_ast(self, schema: pa.Schema):
                     ast.Subscript(
                         value=ast.Attribute(
                             value=ast.Name(
-                                id='out_data_test',
+                                id=settings.OUTPUT_NAME + settings.TEST_SUFFIX,
                                 ctx=ast.Load()),
                             attr=col.name,
                             ctx=ast.Load()),
@@ -456,7 +461,7 @@ def get_store_test_ast(self, schema: pa.Schema):
                 left=
                 ast.Attribute(
                     value=ast.Name(
-                        id='out_data',
+                        id=settings.OUTPUT_NAME,
                         ctx=ast.Load()),
                     attr=col.name + settings.LENGTH_SUFFIX,
                     ctx=ast.Store()),
@@ -496,6 +501,43 @@ def get_out_schema_ast(self, test_data=False) -> ast.Expr:
     def get_input_ast(self) -> list:
         return self.get_load_schema_ast(self.in_schema)
 
+    def get_port_pragma_ast(self) -> list:
+        port_pragma_ast = []
+        port_pragma_ast.append(make_comment('pragma HLS INTERFACE {} port=return'.format(settings.BLOCK_LEVEL_IO_TYPE)))
+
+        # TODO Add name={2}{1} back in
+        for col in self.in_schema:
+            port_pragma_ast.append(
+                make_comment(
+                    'pragma HLS INTERFACE {0} port={3}.{1}'.format(settings.PORT_TYPE, col.name,
+                                                                   settings.INPORT_PREFIX,
+                                                                   settings.INPUT_NAME)))
+            if col.type in settings.VAR_LENGTH_TYPES:
+                port_pragma_ast.append(
+                    make_comment(
+                        'pragma HLS INTERFACE {0} port={3}.{1}'.format(settings.PORT_TYPE,
+                                                                       col.name + settings.LENGTH_SUFFIX,
+                                                                       settings.INPORT_PREFIX,
+                                                                       settings.INPUT_NAME)))
+
+        for col in self.out_schema:
+            port_pragma_ast.append(
+                make_comment(
+                    'pragma HLS INTERFACE {0} port={3}.{1}'.format(settings.PORT_TYPE, col.name,
+                                                                   settings.OUTPORT_PREFIX,
+                                                                   settings.OUTPUT_NAME)))
+            if col.type in settings.VAR_LENGTH_TYPES:
+                port_pragma_ast.append(
+                    make_comment(
+                        'pragma HLS INTERFACE {0} port={3}.{1}'.format(settings.PORT_TYPE,
+                                                                       col.name + settings.LENGTH_SUFFIX,
+                                                                       settings.OUTPORT_PREFIX,
+                                                                       settings.OUTPUT_NAME)))
+
+        # port_pragma_ast.append(make_comment('pragma HLS DATAFLOW'))
+
+        return port_pragma_ast
+
     def get_output_ast(self) -> list:
         return [
             ast.If(
@@ -510,14 +552,14 @@ def get_input_test_ast(self) -> list:
         return [
             ast.AnnAssign(
                 target=ast.Name(
-                    id='in_data',
+                    id=settings.INPUT_NAME,
                     ctx=ast.Store()),
                 annotation=ast.Name(id='in_schema', ctx=ast.Load()),
                 value=None,
                 simple=1),
             ast.AnnAssign(
                 target=ast.Name(
-                    id='out_data',
+                    id=settings.OUTPUT_NAME,
                     ctx=ast.Store()),
                 annotation=ast.Name(id='out_schema', ctx=ast.Load()),
                 value=None,
@@ -730,7 +772,7 @@ def visit_Select(self, node: ast_nodes.Select) -> list:
 
             col_value = self.visit(select_col)
 
-            col_len_ast = ast.Num(settings.VAR_LENGTH)
+            col_len_ast = ast.Name(id="VAR_LENGTH", ctx=ast.Load())
 
             if isinstance(col_value, PrePostAST):
                 col_value_ast = col_value.ast
@@ -810,13 +852,13 @@ def visit_Query(self, node: ast_nodes.Query) -> Tuple[list, list, list, list]:
             args=ast.arguments(
                 args=[
                     ast.arg(
-                        arg='in_data',
+                        arg=settings.INPUT_NAME,
                         annotation=ast.Index(ast.Name(
                             id='in_schema',
                             ctx=ast.Load())),
                         type_comment=None),
                     ast.arg(
-                        arg='out_data',
+                        arg=settings.OUTPUT_NAME,
                         annotation=ast.Index(ast.Name(
                             id='out_schema',
                             ctx=ast.Load())),
@@ -827,8 +869,9 @@
                 kw_defaults=[],
                 kwarg=None,
                 defaults=[]),
-            body=self.get_input_ast() + [make_comment("Start of data processing")] + self.visit(node.q) + [
-                make_comment("End of data processing")] + self.get_output_ast() + [
+            body=self.get_port_pragma_ast() + self.get_input_ast() + [
+                make_comment("Start of data processing")] + self.visit(node.q) + [
+                make_comment("End of data processing")] + self.get_output_ast() + [
                 ast.Expr(ast.Return(
                     value=ast.Name(
                         id=settings.PASS_VAR_NAME,
@@ -848,13 +891,13 @@
             args=ast.arguments(
                 args=[
                     ast.arg(
-                        arg='in_data' + settings.TEST_SUFFIX,
+                        arg=settings.INPUT_NAME + settings.TEST_SUFFIX,
                         annotation=ast.Index(ast.Name(
                             id='in_schema' + settings.TEST_SUFFIX,
                             ctx=ast.Load())),
                         type_comment=None),
                     ast.arg(
-                        arg='out_data' + settings.TEST_SUFFIX,
+                        arg=settings.OUTPUT_NAME + settings.TEST_SUFFIX,
                         annotation=ast.Index(ast.Name(
                             id='out_schema' + settings.TEST_SUFFIX,
                             ctx=ast.Load())),
@@ -883,11 +926,11 @@
                 ),
                 args=[
                     ast.Name(
-                        id='in_data',
+                        id=settings.INPUT_NAME,
                         ctx=ast.Load()
                     ),
                     ast.Name(
-                        id='out_data',
+                        id=settings.OUTPUT_NAME,
                         ctx=ast.Load()
                     )
                 ],
@@ -911,10 +954,11 @@
             ),
             type_comment=None)
 
-        header_ast = [self.get_in_schema_ast(),
-                      self.get_out_schema_ast(),
-                      function_ast
-                      ]
+        header_ast = [
+            self.get_in_schema_ast(),
+            self.get_out_schema_ast(),
+            function_ast
+        ]
 
         code_ast = [
             function_ast
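To make `get_port_pragma_ast` concrete: with the settings introduced in this diff (`PORT_TYPE = 'axis'`, `INPUT_NAME = 'in'`, `LENGTH_SUFFIX = '_len'`, `BLOCK_LEVEL_IO_TYPE = 'ap_ctrl_hs'`), the comments it emits for a hypothetical schema look like this. `port_pragmas` below is an illustrative stand-in, not repository code:

```python
PORT_TYPE, LENGTH_SUFFIX = 'axis', '_len'

def port_pragmas(columns, struct_name):
    # columns: (name, is_var_length) pairs; string columns get an extra length port.
    pragmas = ['pragma HLS INTERFACE ap_ctrl_hs port=return']
    for name, is_var_length in columns:
        pragmas.append('pragma HLS INTERFACE {0} port={1}.{2}'.format(PORT_TYPE, struct_name, name))
        if is_var_length:
            pragmas.append('pragma HLS INTERFACE {0} port={1}.{2}{3}'.format(
                PORT_TYPE, struct_name, name, LENGTH_SUFFIX))
    return pragmas

for p in port_pragmas([('id', False), ('string1', True)], 'in'):
    print(p)
# pragma HLS INTERFACE ap_ctrl_hs port=return
# pragma HLS INTERFACE axis port=in.id
# pragma HLS INTERFACE axis port=in.string1
# pragma HLS INTERFACE axis port=in.string1_len
```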
diff --git a/src/fletcherfiltering/codegen/transformations/helpers/ArrowTypeResolver.py b/src/fletcherfiltering/codegen/transformations/helpers/ArrowTypeResolver.py
index c23dc46..314cd6f 100644
--- a/src/fletcherfiltering/codegen/transformations/helpers/ArrowTypeResolver.py
+++ b/src/fletcherfiltering/codegen/transformations/helpers/ArrowTypeResolver.py
@@ -95,6 +95,18 @@ def type_timestamp_ms_(self, arrow_type, as_pointer: bool = False, as_const: boo
             ctx=ast.Load()
         )
 
+    def type_timestamp_us_(self, arrow_type, as_pointer: bool = False, as_const: bool = False):
+        return ast.Name(
+            id=("const " if as_const else "") + "uint64_t" + ("*" if as_pointer else ''),
+            ctx=ast.Load()
+        )
+
+    def type_timestamp_ns_(self, arrow_type, as_pointer: bool = False, as_const: bool = False):
+        return ast.Name(
+            id=("const " if as_const else "") + "uint64_t" + ("*" if as_pointer else ''),
+            ctx=ast.Load()
+        )
+
     def type_string(self, arrow_type, as_pointer: bool = False, as_const: bool = False):
         return ast.Name(
             id=("const " if as_const else "") + "char" + ("*" if as_pointer else ''),
diff --git a/src/fletcherfiltering/codegen/transformations/helpers/FunctionResolver.py b/src/fletcherfiltering/codegen/transformations/helpers/FunctionResolver.py
index 676d3dc..9e6eedb 100644
--- a/src/fletcherfiltering/codegen/transformations/helpers/FunctionResolver.py
+++ b/src/fletcherfiltering/codegen/transformations/helpers/FunctionResolver.py
@@ -1,17 +1,14 @@
 import pyarrow as pa
-import collections
+from collections import namedtuple
 import typed_ast.ast3 as ast
 
-from . import grouped
+from . import grouped, make_comment
 
 from ... import settings
 
-FunctionMetadata = collections.namedtuple('FunctionMetadata', ['resolved_name', 'generator', 'id'])
+FunctionMetadata = namedtuple('FunctionMetadata', ['resolved_name', 'generator', 'id'])
 
-PrePostAST = collections.namedtuple('FunctionAST', ['pre_ast', 'ast', 'post_ast', 'len_ast'])
-
-
-import itertools
+PrePostAST = namedtuple('FunctionAST', ['pre_ast', 'ast', 'post_ast', 'len_ast'])
 
 
 class FunctionResolver(object):
@@ -36,7 +33,7 @@ def resolve(self, funcname: str, length_func: bool = False) -> FunctionMetadata:
         return method(length_func)
 
     def func_CONCAT(self, length_func: bool = False) -> FunctionMetadata:
-        return FunctionMetadata('__sql_builtin_concat' + (settings.LENGHT_SUFFIX if length_func else ''), self.gen_CONCAT,
+        return FunctionMetadata('__sql_builtin_concat' + (settings.LENGTH_SUFFIX if length_func else ''), self.gen_CONCAT,
                                 self.get_next_id())
 
     @staticmethod
@@ -57,7 +54,7 @@ def gen_CONCAT(func_ast, func: FunctionMetadata) -> PrePostAST:
                     id=buf_name,
                     ctx=ast.Store()
                 ),
-                slice=ast.Index(ast.Num(256)),
+                slice=ast.Index(ast.Name(id="VAR_LENGTH", ctx=ast.Load())),
                 ctx=ast.Store()
             ),
             annotation=ast.Name(
@@ -81,6 +78,7 @@ def gen_CONCAT(func_ast, func: FunctionMetadata) -> PrePostAST:
         buffer_ast = ast.Name(id=buf_name, ctx=ast.Load())
         offset_ast = ast.Name(id=offset_name, ctx=ast.Load())
         offset_value_ast = ast.Name(id=offset_ref_name, ctx=ast.Load())
+        extra_ast.append(make_comment("pragma HLS INLINE REGION"))
         for arg, arg_len in grouped(func_ast.args, 2):
             extra_ast.append(
                 ast.Expr(
@@ -91,6 +89,8 @@ def gen_CONCAT(func_ast, func: FunctionMetadata) -> PrePostAST:
                 )
             )
 
+
+        extra_ast.append(make_comment("pragma HLS INLINE OFF"))
         func_ast.args = [ast.Name(id=buf_name, ctx=ast.Load())] + func_ast.args
 
         value_ast = ast.Name(
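`gen_CONCAT` walks its arguments in (value, length) pairs via the `grouped` helper imported at the top of this file. Its definition is not part of this diff; the conventional implementation for that kind of import would be something like:

```python
def grouped(iterable, n):
    """Yield non-overlapping n-tuples: grouped('ABCD', 2) -> ('A', 'B'), ('C', 'D')."""
    return zip(*[iter(iterable)] * n)

# As used above: each CONCAT argument travels with its length expression.
for value, length in grouped(['a', 3, 'b', 7], 2):
    print(value, length)  # a 3, then b 7
```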
diff --git a/src/fletcherfiltering/settings.py b/src/fletcherfiltering/settings.py
index c0b167e..1e43f84 100644
--- a/src/fletcherfiltering/settings.py
+++ b/src/fletcherfiltering/settings.py
@@ -5,11 +5,18 @@
 MINIMAL_QUERY_LENGTH = len('SELECT *')
 
 VAR_LENGTH_TYPES = [pa.string()]
-VAR_LENGTH = 255
+VAR_LENGTH = 32
+PORT_TYPE = 'axis'  # axis would be proper, but it has weird reserved names resulting in TID suffixes, and has longer latency.
+BLOCK_LEVEL_IO_TYPE = 'ap_ctrl_hs'  # ap_ctrl_chained if backpressure is required.
 PASS_VAR_NAME = '__pass_record'
 LENGTH_SUFFIX = '_len'
 TEST_SUFFIX = '_test'
 DATA_SUFFIX = '_data'
 TESTBENCH_SUFFIX = '_tb'
 OUTPUT_SUFFIX = '_o'
-LENGTH_TYPE = pa.int32()
\ No newline at end of file
+INPUT_NAME = 'in'
+OUTPUT_NAME = 'out'
+INPORT_PREFIX = 'i_'
+OUTPORT_PREFIX = 'o_'
+LENGTH_TYPE = pa.int32()
+PART_NAME = 'xc7z020clg400-1'  # 'xa7a12tcsg325-1q' 40/40/16k/8k
\ No newline at end of file
diff --git a/tests/helpers/base_query.py b/tests/helpers/base_query.py
index cdd6f3e..d7ecce4 100644
--- a/tests/helpers/base_query.py
+++ b/tests/helpers/base_query.py
@@ -44,7 +44,10 @@ def __init__(self, printer, cnx, working_dir_base: Path, name='query', has_data_
         self.cnx = cnx
         assert working_dir_base.is_dir()
         self.data = []
-        self.cursor = cnx.cursor(dictionary=True, buffered=True)
+        if self.cnx:
+            self.cursor = cnx.cursor(dictionary=True, buffered=True)
+        else:
+            self.cursor = None
         self.name = name
         self.has_data_file = has_data_file
         self.in_schema = None
@@ -59,27 +62,30 @@ def __init__(self, printer, cnx, working_dir_base: Path, name='query', has_data_
         self.working_dir = working_dir_base / test_settings.WORKSPACE_NAME
 
     def setup(self):
-        if not self.create_table():
-            if not self.drop_table():
-                pytest.fail("Could not drop table successfully.")
+        if 'sql' in test_settings.TEST_PARTS:
             if not self.create_table():
-                pytest.fail("Could not create table successfully on second try.")
-
-        if not self.working_dir.exists():
-            self.printer("Creating workspace directory '{}'".format(self.working_dir))
-            self.working_dir.mkdir(parents=True, exist_ok=True)
-        else:
-            if self.clean_workdir:
-                self.printer("Re-creating workspace directory '{}'".format(self.working_dir))
-                shutil.rmtree(self.working_dir)
+                if not self.drop_table():
+                    pytest.fail("Could not drop table successfully.")
+                if not self.create_table():
+                    pytest.fail("Could not create table successfully on second try.")
+
+        if 'fletcherfiltering' in test_settings.TEST_PARTS or 'vivado' in test_settings.TEST_PARTS:
+            if not self.working_dir.exists():
+                self.printer("Creating workspace directory '{}'".format(self.working_dir))
                 self.working_dir.mkdir(parents=True, exist_ok=True)
             else:
-                self.printer("Using workspace directory '{}'".format(self.working_dir))
+                if self.clean_workdir:
+                    self.printer("Re-creating workspace directory '{}'".format(self.working_dir))
+                    shutil.rmtree(self.working_dir)
+                    self.working_dir.mkdir(parents=True, exist_ok=True)
+                else:
+                    self.printer("Using workspace directory '{}'".format(self.working_dir))
 
         if not self.has_data_file:
             self.generate_random_data()
-        self.insert_data()
+
+        if 'sql' in test_settings.TEST_PARTS:
+            self.insert_data()
 
         return True
 
@@ -113,7 +119,7 @@ def generate_random_data(self):
                     pytest.fail("Unsupported PK column type {} for {}".format(col.type, col.name))
             else:
                 if col.type == pa.string():
-                    record[col.name] = str_gen.generate(maxlength=128)
+                    record[col.name] = str_gen.generate(maxlength=int(settings.VAR_LENGTH / 2))
                 elif col.type == pa.int8():
                     record[col.name] = int_gen.generate(8)
                 elif col.type == pa.uint8():
@@ -130,7 +136,7 @@ def generate_random_data(self):
                     record[col.name] = int_gen.generate(64)
                 elif col.type == pa.uint64():
                     record[col.name] = uint_gen.generate(64)
-                elif col.type == pa.timestamp('ms'):
+                elif pa.types.is_timestamp(col.type):
                     record[col.name] = uint_gen.generate(64)
                 elif col.type == pa.float16():
                     record[col.name] = float_gen.generate(16)
@@ -162,7 +168,6 @@ def insert_data(self):
         return True
 
     def compile(self):
-
         self.printer("Compiling SQL to HLS C++...")
 
         compiler = Compiler(self.in_schema, self.out_schema)
@@ -170,6 +175,14 @@ def compile(self):
                  extra_include_dirs=test_settings.HLS_INCLUDE_PATH,
                  extra_link_dirs=test_settings.HLS_LINK_PATH,
                  extra_link_libraries=test_settings.HLS_LIBS)
 
+        with open(self.working_dir / "in_schema.fbs", 'wb') as file:
+            s_serialized = self.in_schema.serialize()
+            file.write(s_serialized)
+
+        with open(self.working_dir / "out_schema.fbs", 'wb') as file:
+            s_serialized = self.out_schema.serialize()
+            file.write(s_serialized)
+
     def build_schema_class(self, schema: pa.Schema, suffix: str):
         schema_name = "Struct{}{}".format(self.name, suffix)
         schema_ast = python_class_generator.get_class_ast(schema, schema_name)
@@ -179,22 +192,23 @@ def build_schema_class(self, schema: pa.Schema, suffix: str):
         return schema_local_scope[schema_name]
 
     def run_fletcherfiltering(self):
-        with open(os.devnull, "w") as f:
-            redir = f
-            if not self.swallow_build_output:
-                redir = None
-
-            self.printer("Running CMake Generate...")
-            result = ProcessRunner(self.printer, ['cmake', '-G', test_settings.CMAKE_GENERATOR,
-                                                  '-DCMAKE_BUILD_TYPE={}'.format(test_settings.BUILD_CONFIG), '.'],
-                                   shell=False, cwd=self.working_dir, stdout=redir)
-            if result != 0:
-                pytest.fail("CMake Generate exited with code {}".format(result))
-            self.printer("Running CMake Build...")
-            result = ProcessRunner(self.printer, ['cmake', '--build', '.', '--config', test_settings.BUILD_CONFIG],
-                                   shell=False, cwd=self.working_dir, stdout=redir)
-            if result != 0:
-                pytest.fail("CMake Build exited with code {}".format(result))
+        if not self.swallow_build_output:
+            cmake_printer = self.printer
+        else:
+            cmake_printer = lambda val: None
+
+        self.printer("Running CMake Generate...")
+        result = ProcessRunner(cmake_printer, ['cmake', '-G', test_settings.CMAKE_GENERATOR,
+                                               '-DCMAKE_BUILD_TYPE={}'.format(test_settings.BUILD_CONFIG), '.'],
+                               shell=False, cwd=self.working_dir)
+        if result != 0:
+            pytest.fail("CMake Generate exited with code {}".format(result))
+        self.printer("Running CMake Build...")
+        result = ProcessRunner(cmake_printer, ['cmake', '--build', '.', '--config', test_settings.BUILD_CONFIG],
+                               shell=False, cwd=self.working_dir)
+        if result != 0:
+            pytest.fail("CMake Build exited with code {}".format(result))
 
         in_schema_type = self.build_schema_class(self.in_schema, 'In')
@@ -222,7 +236,7 @@
         for col in self.in_schema:
             if col.type == pa.string():
                 setattr(in_schema, col.name,
-                        ctypes.cast(ctypes.create_string_buffer(data_item[col.name].encode('ascii', 'replace'),
+                        ctypes.cast(ctypes.create_string_buffer(data_item[col.name].encode('utf-8', 'replace'),
                                                                 size=settings.VAR_LENGTH),
                                     ctypes.c_char_p))
             elif col.type == pa.float16():
@@ -237,7 +251,7 @@
         for col in self.out_schema:
             if col.type == pa.string():
                 try:
-                    out_data[col.name] = copy.copy(getattr(out_schema, col.name)).decode('ascii')
+                    out_data[col.name] = copy.copy(getattr(out_schema, col.name)).decode('utf-8')
                 except UnicodeDecodeError:
                     print(getattr(out_schema, col.name))
             elif col.type == pa.float16():
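The `compile()` hunk above persists both Arrow schemas next to the generated sources. A round-trip sketch of the pyarrow calls involved, using the same file name as the diff:

```python
import pyarrow as pa

in_schema = pa.schema([('id', pa.int32(), False), ('string1', pa.string(), False)])

# Write the IPC-serialized schema, as compile() now does.
with open("in_schema.fbs", 'wb') as f:
    f.write(in_schema.serialize())

# Read it back (e.g. for tooling that inspects the generated workspace).
with open("in_schema.fbs", 'rb') as f:
    restored = pa.ipc.read_schema(pa.BufferReader(f.read()))
assert restored.equals(in_schema)
```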
@@ -277,7 +291,7 @@ def run_vivado(self):
                     data_item_lst.append("{}ll".format(data_item[col.name]))
                 elif col.type == pa.uint64():
                     data_item_lst.append("{}ull".format(data_item[col.name]))
-                elif col.type == pa.timestamp('ms'):
+                elif pa.types.is_timestamp(col.type):
                     data_item_lst.append("{}ull".format(data_item[col.name]))
                 else:
                     data_item_lst.append("{}".format(data_item[col.name]))
@@ -293,21 +307,24 @@ def run_vivado(self):
             data_file.seek(0)
             data_file.write(data_cpp.safe_substitute(template_data))
             data_file.truncate()
-        with open(os.devnull, "w") as f:
-            redir = f
-            if not self.swallow_build_output:
-                redir = None
-            result, sim_result = VivadoHLSProcessRunner(self.printer,
-                                                        [str(test_settings.VIVADO_BIN_DIR / 'vivado_hls.bat'), '-f',
-                                                         'run_complete_hls.tcl'],
-                                                        shell=False, cwd=self.working_dir, stdout=redir,
-                                                        env={**os.environ, **vivado_env})
-            if result != 0:
-                pytest.fail("Failed to run Vivado. Exited with code {}.".format(result))
-
-            self.printer("Vivado reported C/RTL co-simulation result: {}".format(sim_result))
-
-            assert sim_result == 'PASS'
+
+        if not self.swallow_build_output:
+            vivado_printer = self.printer
+        else:
+            vivado_printer = lambda val: None
+
+        result, sim_result = VivadoHLSProcessRunner(vivado_printer,
+                                                    [str(test_settings.VIVADO_BIN_DIR / 'vivado_hls.bat'), '-f',
+                                                     'run_complete_hls.tcl'],
+                                                    shell=False, cwd=self.working_dir,
+                                                    env={**os.environ, **vivado_env})
+
+        if result != 0:
+            pytest.fail("Failed to run Vivado. Exited with code {}.".format(result))
+
+        self.printer("Vivado reported C/RTL co-simulation result: {}".format(sim_result))
+
+        assert sim_result == 'PASS'
 
         xor = XSIMOutputReader(self.in_schema, self.out_schema)
@@ -326,19 +343,25 @@ def run_sql(self):
 
     def run(self):
         self.compile()
-        self.printer("Executing query on MySQL...")
-        sql_data = self.run_sql()
 
-        if platform.system() == 'Darwin' or platform.system() == 'Linux':
+        if 'sql' in test_settings.TEST_PARTS:
+            self.printer("Executing query on MySQL...")
+            sql_data = self.run_sql()
+        else:
+            sql_data = None
+
+        if (platform.system() == 'Darwin' or platform.system() == 'Linux') and 'fletcherfiltering' in test_settings.TEST_PARTS:
             self.printer("Executing query on FletcherFiltering...")
             fletcher_data = self.run_fletcherfiltering()
         else:
             fletcher_data = None
 
-        if platform.system() == 'Windows' or platform.system() == 'Linux':
+        if (platform.system() == 'Windows' or platform.system() == 'Linux') and 'vivado' in test_settings.TEST_PARTS:
             self.printer("Executing query on Vivado XSIM...")
             vivado_data = self.run_vivado()
         else:
             vivado_data = None
 
+        if sql_data is None:
+            pytest.xfail("No MySQL data was gathered. Cannot compare results.")
+
         if fletcher_data is None and vivado_data is None:
             pytest.xfail("No implementation data was gathered. Platform possibly unsupported.")
@@ -390,6 +413,8 @@ def check_record_set(self, reference, candidate):
                     continue
 
                 if col.type == pa.float16():
+                    reference[col.name] = self.clamp_float16(reference[col.name])
+                    candidate[col.name] = self.clamp_float16(candidate[col.name])
                     if not math.isclose(reference[col.name], candidate[col.name],
                                         rel_tol=test_settings.REL_TOL_FLOAT16):
                         self.printer("Column {} has a larger difference than the configured tolerance: {}.".format(col.name, test_settings.REL_TOL_FLOAT16))
@@ -408,9 +433,21 @@ def check_record_set(self, reference, candidate):
                     if not reference[col.name] == candidate[col.name]:
                         self.printer("Column {} does not have the same value in both records.".format(col.name))
                         errors += 1
-        self.printer("Record errors: {}".format(errors))
+        if errors > 0:
+            self.printer("Record errors: {}".format(errors))
         return errors == 0
 
+    @staticmethod
+    def clamp_float16(value):
+        if value > test_settings.FLOAT16_MAX:
+            return float("inf")
+        elif value < -test_settings.FLOAT16_MAX:
+            return float("-inf")
+        elif -test_settings.FLOAT16_MIN < value < test_settings.FLOAT16_MIN:
+            return 0
+
+        return value
+
     def drop_table(self):
         query = """DROP TABLE `{0}`;""".format(self.name)
         self.printer("Dropping table for test {}".format(self.name))
diff --git a/tests/helpers/ctypes_type_mapper.py b/tests/helpers/ctypes_type_mapper.py
index ebec932..2a4dfe9 100644
--- a/tests/helpers/ctypes_type_mapper.py
+++ b/tests/helpers/ctypes_type_mapper.py
@@ -45,6 +45,12 @@ def type_int64(self, arrow_type):
     def type_uint64(self, arrow_type):
         return ctypes.c_uint64
 
+    def type_timestamp_ns_(self, arrow_type):
+        return ctypes.c_uint64
+
+    def type_timestamp_us_(self, arrow_type):
+        return ctypes.c_uint64
+
     def type_timestamp_ms_(self, arrow_type):
         return ctypes.c_uint64
 
diff --git a/tests/helpers/mysql_type_mapper.py b/tests/helpers/mysql_type_mapper.py
index 5733ba4..46adeef 100644
--- a/tests/helpers/mysql_type_mapper.py
+++ b/tests/helpers/mysql_type_mapper.py
@@ -2,6 +2,7 @@
 
 import typed_ast.ast3 as ast
 
+from fletcherfiltering import settings
 
 class MySQLTypeMapper(object):
 
@@ -50,10 +51,16 @@ def type_timestamp_ms_(self, arrow_type):
 
     # TODO: figure out a way to deal with the 32-bitness of TIMESTAMP columns in SQL (UNIX standard)
     def type_timestamp_s_(self, arrow_type):
-        return "TIMESTAMP"
+        return "BIGINT UNSIGNED"
+
+    def type_timestamp_us_(self, arrow_type):
+        return "BIGINT UNSIGNED"
+
+    def type_timestamp_ns_(self, arrow_type):
+        return "BIGINT UNSIGNED"
 
     def type_string(self, arrow_type):
-        return "VARCHAR(255)"
+        return "VARCHAR({})".format(settings.VAR_LENGTH)
 
     # TODO: Maybe replace with the correct decimal column
     def type_halffloat(self, arrow_type):
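The new `clamp_float16` squeezes reference and candidate values into float16's representable range before the tolerance check, so values MySQL can store but float16 cannot no longer count as mismatches. A standalone check against the `FLOAT16_MAX`/`FLOAT16_MIN` constants added to `test_settings.py` further down:

```python
FLOAT16_MAX = 65504.0      # largest finite float16
FLOAT16_MIN = 0.000061035  # smallest positive normal float16

def clamp_float16(value):
    if value > FLOAT16_MAX:
        return float("inf")
    elif value < -FLOAT16_MAX:
        return float("-inf")
    elif -FLOAT16_MIN < value < FLOAT16_MIN:
        return 0
    return value

assert clamp_float16(1e6) == float("inf")    # overflows float16
assert clamp_float16(-1e6) == float("-inf")
assert clamp_float16(1e-8) == 0              # flushed to zero
assert clamp_float16(123.0) == 123.0         # in range, unchanged
```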
diff --git a/tests/helpers/process_runner.py b/tests/helpers/process_runner.py
index 8cc9158..ccbb375 100644
--- a/tests/helpers/process_runner.py
+++ b/tests/helpers/process_runner.py
@@ -26,7 +26,7 @@ def VivadoHLSProcessRunner(printer, proc_args: Sequence[Union[bytes, str, PathLi
 
     printed_levels = ['ERROR', 'WARNING']
 
-    printed_modules = ['COSIM','HLS', 'COMMON']
+    printed_modules = ['COSIM', 'HLS', 'COMMON']
 
     test_output = False
 
@@ -47,10 +47,11 @@ def VivadoHLSProcessRunner(printer, proc_args: Sequence[Union[bytes, str, PathLi
 
                 if groupdict['module'].upper() in printed_modules or groupdict['level'].upper() in printed_levels:
                     printer("{level}: [{module} {code}] {message}".format(**matches.groupdict()))
-                    if groupdict['module'].upper() == 'COSIM' or groupdict['level'].upper() == 'INFO':
-                        result_matches = result_regex.match(groupdict['message'])
-                        if result_matches:
-                            sim_result = result_matches.group('result').strip()
+
+                if groupdict['module'].upper() == 'COSIM' or groupdict['level'].upper() == 'INFO':
+                    result_matches = result_regex.match(groupdict['message'])
+                    if result_matches:
+                        sim_result = result_matches.group('result').strip()
             elif "== Start ==" in output:
                 test_output = True
                 printer(output.strip())
diff --git a/tests/helpers/struct_type_mapper.py b/tests/helpers/struct_type_mapper.py
index 7c16a99..26989ac 100644
--- a/tests/helpers/struct_type_mapper.py
+++ b/tests/helpers/struct_type_mapper.py
@@ -45,6 +45,12 @@ def type_int64(self, arrow_type):
     def type_uint64(self, arrow_type):
         return 'Q', 8
 
+    def type_timestamp_us_(self, arrow_type):
+        return 'Q', 8
+
+    def type_timestamp_ns_(self, arrow_type):
+        return 'Q', 8
+
     def type_timestamp_ms_(self, arrow_type):
         return 'Q', 8
 
diff --git a/tests/helpers/xsim_output_reader.py b/tests/helpers/xsim_output_reader.py
index 5b4df7e..fa4bf7b 100644
--- a/tests/helpers/xsim_output_reader.py
+++ b/tests/helpers/xsim_output_reader.py
@@ -43,8 +43,8 @@ def read(self, data_path: PurePath, query_name: str):
                 length_column = False
                 col_type = pa.bool_()
                 if column[0] is not col_name:
-                    if not column[0].startswith("in_data"):
-                        colname_regex = re.compile(r"out_data_([a-zA-Z0-9_]+?)(_len)?_V")
+                    if not column[0].startswith(settings.INPUT_NAME):
+                        colname_regex = re.compile(r"{0}_([a-zA-Z0-9_]+?)(_len)?_V".format(settings.OUTPUT_NAME))
                         matches = colname_regex.match(column[0])
                         if matches:
                             col_name = matches.group(1)
diff --git a/tests/queries/test_combination1.py b/tests/queries/test_combination1.py
index 9553ef9..0097e58 100644
--- a/tests/queries/test_combination1.py
+++ b/tests/queries/test_combination1.py
@@ -4,19 +4,37 @@ class Combination1(BaseQuery):
     def __init__(self, printer, cnx, working_dir_base='/tmp', **kwargs):
-        super().__init__(printer, cnx, working_dir_base, name=self.__class__.__name__, has_data_file=False, separate_work_dir=True, **kwargs)
+        super().__init__(printer, cnx, working_dir_base, name=self.__class__.__name__, has_data_file=False,
+                         separate_work_dir=True, **kwargs)
 
         self.in_schema = pa.schema([('id', pa.int32(), False),
                                     ('int1', pa.int32(), False),
                                     ('int2', pa.int32(), False),
                                     ('string1', pa.string(), False),
-                                    ('timestamp1', pa.timestamp('ms'), False)])
+                                    ('timestamp1', pa.timestamp('s'), False),
+                                    ('timestamp2', pa.timestamp('us'), False),
+                                    ('timestamp3', pa.timestamp('ms'), False),
+                                    ('timestamp4', pa.timestamp('ns'), False)
+                                    ])
 
         self.in_schema_pk = 'id'
 
         self.out_schema = pa.schema([('int1', pa.int32(), False),
                                      ('concat', pa.string(), False),
-                                     ('concat2', pa.string(), False)])
+                                     ('concat2', pa.string(), False),
+                                     ('timestamp1', pa.timestamp('s'), False),
+                                     ('timestamp2', pa.timestamp('us'), False),
+                                     ('timestamp3', pa.timestamp('ms'), False),
+                                     ('timestamp4', pa.timestamp('ns'), False)
+                                     ])
 
-        self.query = """select `int1`+`int2` as `int1`, CONCAT(string1,1<<4,'NULL') as concat, CONCAT('123456',string1,True,False) as concat2 FROM `{0}` WHERE `int1` > 4 AND `int2` < 18""".format(
+        self.query = """select
+            `int1`+`int2` as `int1`,
+            CONCAT(string1,1<<4,'NULL') as concat,
+            CONCAT('123456',string1,True,False) as concat2,
+            timestamp1,
+            timestamp2,
+            timestamp3,
+            timestamp4
+            FROM `{0}`
+            WHERE `int1` > 4 AND `int2` < 18""".format(
             self.name)
-
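Several hunks in this change replace exact `pa.timestamp('ms')` comparisons with `pa.types.is_timestamp`, so a single branch now covers all four timestamp units added to the test schema. A quick illustration of the difference:

```python
import pyarrow as pa

for unit in ('s', 'ms', 'us', 'ns'):
    t = pa.timestamp(unit)
    assert pa.types.is_timestamp(t)  # unit-agnostic, as used in the diff
    # Equality is unit-specific, which is why the old check missed 's'/'us'/'ns':
    assert (t == pa.timestamp('ms')) == (unit == 'ms')
```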
diff --git a/tests/test_compiler.py b/tests/test_compiler.py
index 016461c..fad5a5f 100644
--- a/tests/test_compiler.py
+++ b/tests/test_compiler.py
@@ -4,9 +4,12 @@ def test_query(printer, test_class):
     printer('Started')
 
-    cnx = mysql.connector.connect(user=test_settings.MYSQL_USER, password=test_settings.MYSQL_PASSWORD,
-                                  host=test_settings.MYSQL_HOST,
-                                  database=test_settings.MYSQL_DATABASE)
+    if 'sql' in test_settings.TEST_PARTS:
+        cnx = mysql.connector.connect(user=test_settings.MYSQL_USER, password=test_settings.MYSQL_PASSWORD,
+                                      host=test_settings.MYSQL_HOST,
+                                      database=test_settings.MYSQL_DATABASE)
+    else:
+        cnx = None
 
     test = test_class(printer, cnx, working_dir_base=Path('.'), clean_workdir=test_settings.CLEAN_WORKDIR)
     try:
diff --git a/tests/test_settings.py b/tests/test_settings.py
index 1158769..9e17e37 100644
--- a/tests/test_settings.py
+++ b/tests/test_settings.py
@@ -11,6 +11,9 @@
 REL_TOL_FLOAT32 = 1e-5  # digits for single float: 7.22
 REL_TOL_FLOAT64 = 1e-9  # digits for double float: 15.95
 
+FLOAT16_MAX = 65504.0
+FLOAT16_MIN = 0.000061035
+
 MYSQL_USER='fletcherfiltering'
 MYSQL_PASSWORD='pfUcFN4S9Qq7X6NDBMHk'
 MYSQL_HOST='127.0.0.1'
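The gating added throughout the tests reads `test_settings.TEST_PARTS`, whose definition does not appear in this diff. Inferred from the three names the tests check for, a plausible addition to `tests/test_settings.py` would be the following (an assumption, not part of the patch):

```python
# Assumed setting: which test stages to run. The diffed code only ever
# tests membership of these three names.
TEST_PARTS = {'sql', 'fletcherfiltering', 'vivado'}
```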