diff --git a/README.md b/README.md index cb0f603..2e8fd4f 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ # pyMLIR: Python Interface for the Multi-Level Intermediate Representation -pyMLIR is a full Python interface to parse, process, and output [MLIR](https://mlir.llvm.org/) files according to the +pyMLIR is a full Python interface to parse, process, output and run [MLIR](https://mlir.llvm.org/) files according to the syntax described in the [MLIR documentation](https://github.com/llvm/llvm-project/tree/master/mlir/docs). pyMLIR supports the basic dialects and can be extended with other dialects. It uses [Lark](https://github.com/lark-parser/lark) to parse the MLIR syntax, and mirrors the classes into Python classes. Custom dialects can also be implemented with a @@ -19,10 +19,10 @@ Note that the tool *does not depend on LLVM or MLIR*. It can be installed and in **Requirements:** Python 3.6 or newer, and the requirements in `setup.py` or `requirements.txt`. To manually install the requirements, use `pip install -r requirements.txt` -**Problem parsing MLIR files?** Run the file through LLVM's `mlir-opt --mlir-print-op-generic` to get the generic form -of the IR (instructions on how to build/install MLIR can be found [here](https://mlir.llvm.org/getting_started/)): -``` -$ mlir-opt file.mlir --mlir-print-op-generic > output.mlir +**Problem parsing MLIR files?** Run the file through LLVM's `mlir-opt` as `mlir.run.mlir_opt(source, ["--mlir-print-op-generic"])` to +get the generic form of the IR (instructions on how to build/install MLIR can be found [here](https://mlir.llvm.org/getting_started/)): +```python +source = mlir.run.mlir_opt(source, ["--mlir-print-op-generic"]) ``` **Found other problems parsing files?** Not all dialects and modes are supported. Feel free to send us an issue or @@ -130,3 +130,38 @@ print(m.dump_ast()) All dialect implementations can be found in the [dialects](mlir/dialects) subfolder. Additional uses of the library, including a custom dialect implementation, can be found in the [tests](tests) subfolder. + + +### Call `mlir-opt` and invoke functions + +Note that invoking MLIR functions depends on LLVM toolchain. The following binaries must be present in `$PATH`: +- `mlir-opt` +- `mlir-translate` +- `llc` + +```python +source = """ +#identity = affine_map<(i,j) -> (i,j)> +#attrs = { + indexing_maps = [#identity, #identity, #identity], + iterator_types = ["parallel", "parallel"] +} +func @example(%A: memref, %B: memref, %C: memref) { + linalg.generic #attrs ins(%A, %B: memref, memref) outs(%C: memref) { + ^bb0(%a: f64, %b: f64, %c: f64): + %d = addf %a, %b : f64 + linalg.yield %d : f64 + } + return +}""" + +source = mlirrun.mlir_opt(source, ["-convert-linalg-to-loops", + "-convert-scf-to-std"]) +a = np.random.rand(10, 10) +b = np.random.rand(10, 10) +c = np.empty_like(a) + +mlirrun.call_function(source, "example", [a, b, c]) + +assert (c == a+b).all() +``` diff --git a/mlir/run.py b/mlir/run.py new file mode 100644 index 0000000..8f3ad81 --- /dev/null +++ b/mlir/run.py @@ -0,0 +1,267 @@ +""" MLIR kernel invocation.""" + +__copyright__ = "Copyright (C) 2020 Kaushik Kulkarni" + +__license__ = """ +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +""" + + +import ctypes +import tempfile +import numpy as np +from dataclasses import dataclass +from typing import Tuple, List, Any, Optional +from pytools import memoize_method +from pytools.prefork import call_capture_output, ExecError +from codepy.jit import compile_from_string +from codepy.toolchain import ToolchainGuessError, GCCToolchain +from codepy.toolchain import guess_toolchain as guess_toolchain_base + + +# {{{ Memref + +def get_nd_memref_struct_type(n: int): + nd_long = ctypes.c_long * n + + class NDMemrefStruct(ctypes.Structure): + _fields_ = [("data", ctypes.c_void_p), + ("alignedData", ctypes.c_void_p), + ("offset", ctypes.c_long), + ("shape", nd_long), + ("strides", nd_long)] + + return NDMemrefStruct + + +@dataclass(init=True) +class Memref: + data_ptr: int + shape: Tuple[int, ...] + strides: Tuple[int, ...] + + @staticmethod + def from_numpy(ary): + """ + Create a :class:`Memref` from a :class:`numpy.ndarray` + """ + shape = ary.shape + strides = tuple(stride // ary.itemsize for stride in ary.strides) + return Memref(ary.ctypes.data, + shape, + strides) + + @property + def ndim(self): + return len(self.shape) + + @property + @memoize_method + def ctype(self): + struct_cls = get_nd_memref_struct_type(self.ndim) + + typemap = dict(struct_cls._fields_) + dataptr_cls = typemap["data"] + shape_cls = typemap["shape"] + strides_cls = typemap["strides"] + + return struct_cls(dataptr_cls(self.data_ptr), + dataptr_cls(self.data_ptr), + 0, # offset is alway zero for numpy arrays + shape_cls(*self.shape), + strides_cls(*self.strides)) + + @property + @memoize_method + def pointer_ctype(self): + return ctypes.pointer(self.ctype) + +# }}} + + +# {{{ run kernels + +def guess_toolchain(): + # copied from loopy/target/c/c_execution.py + try: + toolchain = guess_toolchain_base() + except (ToolchainGuessError, ExecError): + # missing compiler python was built with (likely, Conda) + # use a default GCCToolchain + # this is ugly, but I'm not sure there's a clean way to copy the + # default args + toolchain = GCCToolchain( + cc="gcc", + cflags="-std=c99 -O3 -fPIC".split(), + ldflags=["-shared"], + libraries=[], + library_dirs=[], + defines=[], + undefines=[], + source_suffix="c", + so_ext=".so", + o_ext=".o", + include_dirs=[]) + + return toolchain + + +def get_mlir_opt_version(mlir_opt="mlir-opt"): + cmdline = [mlir_opt, "-version"] + result, stdout, stderr = call_capture_output(cmdline) + return stdout.decode() + + +def mlir_opt(source: str, options: List[str], mlir_opt="mlir-opt"): + """ + Calls ``mlir-opt`` on *source* with *options* as additional arguments. + + :arg source: The code to be passed to mlir-opt. + :arg options: An instance of :class:`list`. + :return: Transformed *source* as emitted by ``mlir-opt``. + """ + assert "-o" not in options + with tempfile.NamedTemporaryFile(mode="w", suffix=".mlir") as fp: + fp.write(source) + fp.file.flush() + + cmdline = [mlir_opt, fp.name] + options + result, stdout, stderr = call_capture_output(cmdline) + + return stdout.decode() + + +def mlir_translate(source, options, mlir_translate="mlir-translate"): + """ + Calls ``mlir-translate`` on *source* with *options* as additional arguments. + + :arg source: The code to be passed to mlir-translate. + :arg options: An instance of :class:`list`. + :return: Transformed *source* as emitted by ``mlir-translate``. + """ + with tempfile.NamedTemporaryFile(mode="w", suffix=".mlir", delete=False) as fp: + fp.write(source) + fp.file.flush() + cmdline = [mlir_translate, fp.name] + options + result, stdout, stderr = call_capture_output(cmdline) + + return stdout.decode() + + +def mlir_to_llvmir(source, debug=False): + """ + Converts MLIR *source* to LLVM IR. Invokes ``mlir-tranlate -mlir-to-llvmir`` + under the hood. + """ + if debug: + return mlir_translate(source, ["-mlir-to-llvmir", "-debugify-level=location+variables"]) + else: + return mlir_translate(source, ["-mlir-to-llvmir"]) + + +def llvmir_to_obj(source, llc="llc"): + """ + Returns the compiled object code for the LLVM code *source*. + """ + with tempfile.NamedTemporaryFile(mode="w", suffix=".ll") as llfp: + llfp.write(source) + llfp.file.flush() + with tempfile.NamedTemporaryFile(suffix=".o", mode="rb") as objfp: + cmdline = [llc, llfp.name, "-o", objfp.name, "-filetype=obj"] + result, stdout, stderr = call_capture_output(cmdline) + + obj_code = objfp.read() + + return obj_code + + +def preprocess_arg(arg): + if isinstance(arg, Memref): + return arg.pointer_ctype + elif isinstance(arg, np.ndarray): + return Memref.from_numpy(arg).pointer_ctype + elif isinstance(arg, np.number): + return arg + else: + raise NotImplementedError(f"Unknown type: {type(arg)}.") + + +def guess_argtypes(args): + argtypes = [] + for arg in args: + if isinstance(arg, Memref): + argtypes.append(ctypes.c_void_p) + elif isinstance(arg, np.ndarray): + argtypes.append(ctypes.c_void_p) + elif isinstance(arg, np.number): + argtypes.append(np.ctypeslib.as_ctypes_type(arg.dtype)) + else: + raise NotImplementedError(f"Unknown type: {type(arg)}.") + + return argtypes + + +def call_function(source: str, fn_name: str, args: List[Any], + argtypes: Optional[List[ctypes._SimpleCData]] = None): + """ + Calls the function *fn_name* in *source*. + + :arg source: The MLIR code whose function is to be called. + :arg args: A list of args to be passed to the function. Each arg can have + one of the following types: + - :class:`numpy.ndarray` + - :class:`numpy.number + - :class:`Memref` + :arg fn_name: Name of the function op which is the to be called + """ + + source = mlir_opt(source, ["-convert-std-to-llvm=emit-c-wrappers"]) + fn_name = f"_mlir_ciface_{fn_name}" + + if argtypes is None: + argtypes = guess_argtypes(args) + + args = [preprocess_arg(arg) for arg in args] + + obj_code = llvmir_to_obj(mlir_to_llvmir(source)) + + toolchain = guess_toolchain() + + _, mod_name, ext_file, recompiled = \ + compile_from_string(toolchain, fn_name, obj_code, + ["module.o"], + source_is_binary=True) + + f = ctypes.CDLL(ext_file) + fn = getattr(f, fn_name) + fn.argtypes = argtypes + fn.restype = None + fn(*args) + +# }}} + + +# vim: fdm=marker diff --git a/setup.py b/setup.py index 6377a43..574fe12 100644 --- a/setup.py +++ b/setup.py @@ -27,6 +27,10 @@ install_requires=[ 'lark-parser', 'parse' ], - tests_require=['pytest', 'pytest-cov'], + extras_require={ + 'run': ['pytools', 'codepy', 'numpy'], + 'test': ['pytest', 'pytest-cov', 'pytools', 'codepy', 'numpy'] + }, + tests_require=['pytest', 'pytest-cov', 'pytools', 'codepy', 'numpy'], test_suite='pytest', scripts=[]) diff --git a/tests/test_run.py b/tests/test_run.py new file mode 100644 index 0000000..67ba047 --- /dev/null +++ b/tests/test_run.py @@ -0,0 +1,108 @@ +__copyright__ = "Copyright (C) 2020 Kaushik Kulkarni" + +__license__ = """ +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +""" + + +import numpy as np +import sys +import mlir.run as mlirrun +import pytest +from pytools.prefork import ExecError + + +def is_mlir_opt_present(): + try: + mlirrun.get_mlir_opt_version() + return True + except ExecError: + return False + + +@pytest.mark.skipif(not is_mlir_opt_present(), reason="mlir-opt not found") +def test_add(): + source = """ + #identity = affine_map<(i,j) -> (i,j)> + #attrs = { + indexing_maps = [#identity, #identity, #identity], + iterator_types = ["parallel", "parallel"] + } + func @example(%A: memref, %B: memref, %C: memref) { + linalg.generic #attrs ins(%A, %B: memref, memref) outs(%C: memref) { + ^bb0(%a: f64, %b: f64, %c: f64): + %d = addf %a, %b : f64 + linalg.yield %d : f64 + } + return + }""" + + source = mlirrun.mlir_opt(source, ["-convert-linalg-to-loops", + "-convert-scf-to-std"]) + a = np.random.rand(10, 10) + b = np.random.rand(10, 10) + c = np.empty_like(a) + + mlirrun.call_function(source, "example", [a, b, c]) + + np.testing.assert_allclose(c, a+b) + + +@pytest.mark.skipif(not is_mlir_opt_present(), reason="mlir-opt not found") +def test_axpy(): + source = """ +func @saxpy(%a : f32, %x : memref, %y : memref) { + %c0 = constant 0: index + %n = dim %x, %c0 : memref + + affine.for %i = 0 to %n { + %xi = affine.load %x[%i] : memref + %axi = mulf %a, %xi : f32 + %yi = affine.load %y[%i] : memref + %axpyi = addf %yi, %axi : f32 + affine.store %axpyi, %y[%i] : memref + } + return +}""" + + source = mlirrun.mlir_opt(source, ["-lower-affine", + "-convert-scf-to-std"]) + alpha = np.float32(np.random.rand()) + x_in = np.random.rand(10).astype(np.float32) + y_in = np.random.rand(10).astype(np.float32) + y_out = y_in.copy() + + mlirrun.call_function(source, "saxpy", [alpha, x_in, y_out]) + + np.testing.assert_allclose(y_out, alpha*x_in+y_in) + + +if __name__ == "__main__": + if len(sys.argv) > 1: + exec(sys.argv[1]) + else: + from pytest import main + main([__file__])