Skip to content

Commit

Permalink
PoC to test libtorch support for onnxruntime-extensions
Browse files Browse the repository at this point in the history
  • Loading branch information
thiagocrepaldi committed Jul 23, 2024
1 parent d79299e commit bb4f87a
Show file tree
Hide file tree
Showing 9 changed files with 259 additions and 1 deletion.
31 changes: 31 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,17 @@
cmake_minimum_required(VERSION 3.25)
project(onnxruntime_extensions LANGUAGES C CXX)

# Optional LibTorch integration (proof of concept).
#
# When the OCOS_LIBTORCH_PATH environment variable is defined at configure
# time, it is treated as the root of an unpacked LibTorch distribution and
# Torch is located through find_package(). The environment is only read while
# configuring, so changing the variable requires re-running CMake.
#
# LibTorch archives tried so far:
#   https://download.pytorch.org/libtorch/nightly/cu121/libtorch-shared-with-deps-latest.zip            <-- NOT TESTED
#   https://download.pytorch.org/libtorch/nightly/cu121/libtorch-cxx11-abi-shared-with-deps-latest.zip  <-- NOT TESTED
#   https://download.pytorch.org/libtorch/nightly/cpu/libtorch-shared-with-deps-latest.zip              <-- WORKS
# TODO: Maybe use export _GLIBCXX_USE_CXX11_ABI=1 if building pytorch from source
if(DEFINED ENV{OCOS_LIBTORCH_PATH})
  # Normalise the path (e.g. Windows backslashes) before handing it to CMake.
  file(TO_CMAKE_PATH "$ENV{OCOS_LIBTORCH_PATH}" ocos_libtorch_path)
  message(STATUS "OCOS_LIBTORCH_PATH=${ocos_libtorch_path}")

  # Append instead of overwriting so that a CMAKE_PREFIX_PATH supplied by the
  # user (command line, toolchain file, presets) keeps working.
  list(APPEND CMAKE_PREFIX_PATH "${ocos_libtorch_path}")
  find_package(Torch REQUIRED)

  # TORCH_CXX_FLAGS carries the ABI-related flags LibTorch was built with.
  string(APPEND CMAKE_CXX_FLAGS " ${TORCH_CXX_FLAGS}")
else()
  message(STATUS "OCOS_LIBTORCH_PATH is not set - building without LibTorch")
endif()

# set(CMAKE_VERBOSE_MAKEFILE ON)
if(NOT CMAKE_BUILD_TYPE)
message(STATUS "Build type not set - using RelWithDebInfo")
Expand Down Expand Up @@ -175,6 +186,18 @@ if (MSVC)
# See https://developercommunity.visualstudio.com/t/Access-violation-with-std::mutex::lock-a/10664660#T-N10668856
# Remove this definition once the conda msvcp140.dll dll is updated.
add_compile_definitions(_DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR)

if(DEFINED ENV{OCOS_LIBTORCH_PATH})
  # The following code block is suggested to be used on Windows.
  # According to https://github.com/pytorch/pytorch/issues/25457,
  # the DLLs need to be copied to avoid memory errors.
  file(GLOB TORCH_DLLS "${TORCH_INSTALL_PREFIX}/lib/*.dll")
  if(TORCH_DLLS)
    # add_custom_command(TARGET ...) only accepts a target that already
    # exists. Deferring the call to the end of this directory's CMakeLists.txt
    # makes the copy step independent of where ocos_operators is created.
    # NOTE(review): this assumes ocos_operators is created in this directory
    # scope, which POST_BUILD commands require anyway - confirm.
    cmake_language(DEFER CALL add_custom_command
      TARGET ocos_operators
      POST_BUILD
      COMMAND "${CMAKE_COMMAND}" -E copy_if_different
              ${TORCH_DLLS}
              "$<TARGET_FILE_DIR:ocos_operators>"
      VERBATIM)
  else()
    # With an empty list, copy_if_different would be invoked with only a
    # destination and fail at build time; report the problem at configure time.
    message(WARNING
      "No LibTorch DLLs found in ${TORCH_INSTALL_PREFIX}/lib; "
      "nothing will be copied next to ocos_operators")
  endif()
endif()
endif()

if(NOT OCOS_BUILD_PYTHON AND OCOS_ENABLE_PYTHON)
Expand Down Expand Up @@ -596,6 +619,10 @@ target_include_directories(ocos_operators PUBLIC
${PROJECT_SOURCE_DIR}/base
${PROJECT_SOURCE_DIR}/operators)

# LibTorch headers are exposed on ocos_operators (and, being PUBLIC, on its
# consumers) only when OCOS_LIBTORCH_PATH is defined in the environment.
if(DEFINED ENV{OCOS_LIBTORCH_PATH})
  target_include_directories(ocos_operators
    PUBLIC
      ${TORCH_INCLUDE_DIRS})
endif()

# CUDA builds additionally expose the cutlass headers and examples.
if(OCOS_USE_CUDA)
  target_include_directories(ocos_operators
    PUBLIC
      ${cutlass_SOURCE_DIR}/include
      ${cutlass_SOURCE_DIR}/examples)
endif()
Expand Down Expand Up @@ -724,6 +751,10 @@ list(APPEND ocos_libraries noexcep_operators)
# Compile definitions and link dependencies stay PRIVATE to ocos_operators.
target_compile_definitions(ocos_operators
  PRIVATE
    ${OCOS_COMPILE_DEFINITIONS})
target_link_libraries(ocos_operators
  PRIVATE
    ${ocos_libraries})

# Link the LibTorch libraries when OCOS_LIBTORCH_PATH is defined in the
# environment.
if(DEFINED ENV{OCOS_LIBTORCH_PATH})
  target_link_libraries(ocos_operators
    PRIVATE
      "${TORCH_LIBRARIES}")
endif()

set (file_patterns "shared/lib/*.cc")
if (OCOS_ENABLE_C_API)
list(APPEND file_patterns "shared/api/*.h*" "shared/api/*.cc")
Expand Down
1 change: 1 addition & 0 deletions cmake/pytorch
Submodule pytorch added at e880cb
37 changes: 37 additions & 0 deletions operators/math/com_amd_myrelu.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include <torch/torch.h>

#include "ocos.h"

// TODO: Good example for CPU/CUDA op
// https://github.com/microsoft/onnxruntime-extensions/pull/739/files

// TODO: Add DLPack support to ONNXRuntime-extensions for perf improvement
// https://github.com/microsoft/onnxruntime/pull/6968

// TODO: Make templates for Tensor<T>? Testing for Tensor<float>
// template <typename T>
// CPU implementation of the "MyReLu" custom op: applies torch::relu to a float
// tensor and writes the result into an output tensor of the same shape.
//
// input_ort: source tensor (read only).
// out_ort:   destination tensor; allocated here with the shape of input_ort.
// Returns nullptr on every path (no error status is produced here).
//
// Declared `inline` because the definition lives in a header; without it,
// including this header from more than one translation unit would produce
// duplicate-symbol errors at link time.
inline OrtStatusPtr com_amd_myrelu(const ortc::Tensor<float>& input_ort,
                                   ortc::Tensor<float>& out_ort) {
  const auto& shape = input_ort.Shape();

  // Allocate the output before the early return so that an empty input still
  // produces an (empty) output tensor instead of leaving it unallocated.
  float* out_ort_ptr = out_ort.Allocate(shape);

  const int64_t input_size = input_ort.NumberOfElement();
  if (0 == input_size) {
    return nullptr;
  }
  const size_t num_bytes = static_cast<size_t>(input_size) * sizeof(float);

  // Massaging the input to Pytorch format.
  // The dtype is requested explicitly: data_ptr<float>() would throw if the
  // process-wide default dtype had been changed to something else.
  torch::Tensor X = torch::empty(shape, torch::kFloat);
  memcpy(X.data_ptr<float>(), input_ort.Data(), num_bytes);  // TODO: replace with todlpack + torch::Tensor

  // Do computation. contiguous() guarantees the flat copy below reads the
  // elements in logical order.
  const torch::Tensor out_torch = torch::relu(X).contiguous();

  // Massaging the output to ORT format
  memcpy(out_ort_ptr, out_torch.data_ptr<float>(), num_bytes);  // TODO: replace with todlpack + ortc::Tensor conversion

  return nullptr;
}
17 changes: 17 additions & 0 deletions operators/math/cuda/com_amd_myrelu.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// #include <torch/extension.h>
#include <torch/torch.h>
#include <cuda_runtime.h>
#include "com_amd_myrelu.cuh"

// Element-wise ReLU kernel: out[i] = max(input[i], 0) for i in [0, input_size).
//
// input, out:  buffers holding at least `input_size` floats each.
// input_size:  number of float elements to process.
//
// The loop strides by the total number of launched threads, so the result is
// correct for any launch configuration.
__global__ void com_amd_myrelu_kernel(const float* input, float* out, int input_size) {
  const int stride = blockDim.x * gridDim.x;
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < input_size; idx += stride) {
    const float value = input[idx];
    // Written as "negative -> 0" so that NaN inputs pass through unchanged.
    out[idx] = value < 0.0f ? 0.0f : value;
  }
}

// Enqueues the ReLU kernel on `stream`.
//
// stream: CUDA stream the kernel is launched on.
// input:  source buffer.
// out:    destination buffer.
// size:   length of the buffers in BYTES. The caller in com_amd_myrelu_def.cc
//         passes NumberOfElement() * sizeof(float), so this unit is kept for
//         compatibility and converted to an element count here.
void com_amd_myrelu_impl(cudaStream_t stream,
                         const float* input, float* out, int size) {
  const int count = size / static_cast<int>(sizeof(float));
  if (count <= 0) {
    return;  // nothing to do; a launch with zero blocks would be an error
  }

  constexpr int kThreadsPerBlock = 256;
  const int blocks = (count + kThreadsPerBlock - 1) / kThreadsPerBlock;
  com_amd_myrelu_kernel<<<blocks, kThreadsPerBlock, 0, stream>>>(input, out, count);
}
9 changes: 9 additions & 0 deletions operators/math/cuda/com_amd_myrelu.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include <cuda.h>
#include <cuda_runtime.h>

// Host-side launcher for the MyReLu CUDA kernel (defined in com_amd_myrelu.cu).
// Enqueues the kernel on `stream` for the buffers `input` and `out`.
// NOTE(review): the unit of `size` is not stated by this signature; the caller
// in com_amd_myrelu_def.cc passes NumberOfElement() * sizeof(float), i.e. a
// byte count - confirm the intended unit.
void com_amd_myrelu_impl(cudaStream_t stream,
const float* input, float* out, int size);
25 changes: 25 additions & 0 deletions operators/math/cuda/com_amd_myrelu_def.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#include "com_amd_myrelu.cuh"
#include "narrow.h"
#include "com_amd_myrelu_def.h"
#include <cuda.h>
#include <cuda_runtime.h>

// CUDA entry point of the "MyReLu" custom op.
//
// ctx:         kernel context; supplies the CUDA stream the work is queued on.
// input:       source tensor.
// out0_tensor: destination tensor; allocated here with the shape of `input`.
// Returns nullptr on every path (no error status is produced here).
//
// The buffers are handed to the CUDA launcher as-is. The previous version
// staged the input through a host-side torch::Tensor using memcpy and passed
// that tensor where a `const float*` is expected, which does not compile.
// NOTE(review): presumably input.Data() and the allocated output live in
// device memory for a CUDA op - confirm against the execution provider.
OrtStatusPtr com_amd_myrelu_cuda(Ort::Custom::CUDAKernelContext* ctx,
                                 const ortc::Tensor<float>& input,
                                 ortc::Tensor<float>& out0_tensor) {
  // Allocate first so that an empty input still yields an (empty) output.
  float* out_ptr = out0_tensor.Allocate(input.Shape());

  // com_amd_myrelu_impl takes the buffer length in bytes.
  const int64_t input_size = input.NumberOfElement() * static_cast<int64_t>(sizeof(float));
  if (0 == input_size) {
    return nullptr;
  }

  // NOTE(review): the launcher takes an `int`; tensors larger than INT_MAX
  // bytes would be truncated by this cast - consider a checked narrowing.
  com_amd_myrelu_impl(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
                      input.Data(), out_ptr, static_cast<int>(input_size));
  return nullptr;
}
9 changes: 9 additions & 0 deletions operators/math/cuda/com_amd_myrelu_def.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "ocos.h"

// CUDA entry point of the "MyReLu" custom op; math.cc registers it through
// CustomCudaFuncV2("MyReLu", com_amd_myrelu_cuda) when USE_CUDA is defined.
// ctx:         kernel context, used by the implementation to obtain the CUDA stream.
// input:       source tensor.
// out0_tensor: destination tensor.
// Returns an OrtStatusPtr; the implementation in com_amd_myrelu_def.cc returns nullptr.
OrtStatusPtr com_amd_myrelu_cuda(Ort::Custom::CUDAKernelContext* ctx,
const ortc::Tensor<float>& input,
ortc::Tensor<float>& out0_tensor);
7 changes: 6 additions & 1 deletion operators/math/math.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "ocos.h"
#include "negpos.hpp"
#include "com_amd_myrelu.hpp"
#ifdef ENABLE_DLIB
#include "dlib/inverse.hpp"
#include "dlib/stft_norm.hpp"
Expand All @@ -9,19 +10,23 @@

#ifdef USE_CUDA
#include "cuda/negpos_def.h"
#include "cuda/com_amd_myrelu_def.h"
#endif // USE_CUDA

FxLoadCustomOpFactory LoadCustomOpClasses_Math = []() -> CustomOpArray& {
static OrtOpLoader op_loader(CustomCpuFuncV2("NegPos", neg_pos),
#ifdef USE_CUDA
CustomCudaFuncV2("NegPos", neg_pos_cuda),
CustomCudaFuncV2("MyReLu", com_amd_myrelu_cuda),
#endif
CustomCpuFuncV2("MyReLu", com_amd_myrelu),
#ifdef ENABLE_DLIB
CustomCpuFuncV2("Inverse", inverse),
CustomCpuStructV2("StftNorm", StftNormal),
#endif
CustomCpuFuncV2("SegmentExtraction", segment_extraction),
CustomCpuFuncV2("SegmentSum", segment_sum));
CustomCpuFuncV2("SegmentSum", segment_sum)
);

#if defined(USE_CUDA)
// CustomCudaFunc("NegPos", neg_pos_cuda),
Expand Down
124 changes: 124 additions & 0 deletions test/test_myrelu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import io
import onnx
import unittest
import torch
import numpy as np
import onnxruntime as _ort
from onnxruntime_extensions import (
onnx_op, PyCustomOpDef,
get_library_path as _get_library_path)


import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms


def _create_test_model(device: str = "cpu", seed=42):
    """Export a small MNIST-style CNN to ONNX with ReLU mapped to a custom op.

    Every ``relu`` in the model is exported as ``ai.onnx.contrib::MyReLu`` so
    that the resulting graph exercises the custom operator.

    Args:
        device: Torch device string the model and the sample are moved to.
        seed: Seed passed to ``torch.manual_seed`` so that the weights and the
            sample input are reproducible.

    Returns:
        A ``(model_proto, input_sample)`` tuple. ``model_proto`` is the
        exported ``onnx.ModelProto``; ``input_sample`` is a two-element list
        ``[image, label]`` with a float image of shape ``(1, 1, 28, 28)`` and
        an integer label of shape ``(1,)``.
    """
    from torch.onnx import (
        register_custom_op_symbolic,
        unregister_custom_op_symbolic,
    )

    torch.manual_seed(seed)
    device = torch.device(device)

    # Domain must be "ai.onnx.contrib" to be compatible with onnxruntime-extensions
    def com_amd_relu_1(g, input):
        return g.op("ai.onnx.contrib::MyReLu", input).setType(input.type())

    # Model
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(1, 10, 5)
            self.conv2 = nn.Conv2d(10, 20, 5)
            self.conv2_drop = nn.Dropout2d()
            self.dropout = nn.Dropout(0.5)
            self.fc1 = nn.Linear(320, 50)
            self.fc2 = nn.Linear(50, 10)
            self.relu = nn.ReLU()

        def forward(self, x):
            x = self.conv1(x)
            x = torch.max_pool2d(x, 2)
            x = self.relu(x)
            x = self.conv2(x)
            x = self.conv2_drop(x)
            x = torch.max_pool2d(x, 2)
            x = self.relu(x)
            x = x.view(-1, 320)
            x = self.fc1(x)
            x = self.relu(x)
            x = self.dropout(x)
            x = self.fc2(x)
            output = F.log_softmax(x, dim=1)
            return output

    model = Net().to(device)

    # A synthetic MNIST-shaped sample replaces the dataset download: the test
    # needs neither network access nor a writable "../data" directory, and the
    # sample is reproducible. It is created on CPU and moved afterwards so the
    # values do not depend on the target device.
    input_sample = [
        torch.randn(1, 1, 28, 28, dtype=torch.float).to(device),
        torch.randint(0, 10, (1,)).to(device),
    ]

    # Exporting to ONNX with custom op. The symbolic override is process-wide,
    # so it is removed again once the export is done; otherwise later exports
    # in the same process would also emit MyReLu.
    # NOTE(review): relies on unregister restoring the stock relu symbolic,
    # as documented for recent torch releases - confirm for the pinned version.
    register_custom_op_symbolic("::relu", com_amd_relu_1, 9)
    try:
        f = io.BytesIO()
        with torch.no_grad():
            torch.onnx.export(model, (input_sample[0],), f)
    finally:
        unregister_custom_op_symbolic("::relu", 9)

    model_proto = onnx.ModelProto.FromString(f.getvalue())
    return model_proto, input_sample

class TestPythonOp(unittest.TestCase):
    """Runs a model containing the ``MyReLu`` custom op through ONNX Runtime."""

    # Used to test custom op on Python.
    # The ONNX graph has the custom op, which is executed by the function below
    # @classmethod
    # def setUpClass(cls):
    #     @onnx_op(op_type="MyReLu",
    #              inputs=[PyCustomOpDef.dt_float],
    #              outputs=[PyCustomOpDef.dt_float])
    #     def myrelu(x):
    #         return torch.relu(torch.from_numpy(x.copy()))

    # Execution providers under test and the torch device paired with each.
    # TODO: Test with CUDA -> add ("CUDAExecutionProvider", "cuda")
    _EP_DEVICE_PAIRS = (("CPUExecutionProvider", "cpu"),)

    def _run_myrelu(self):
        """Export the model, run it with each provider and sanity-check the output."""
        for ep, device in self._EP_DEVICE_PAIRS:
            with self.subTest(ep=ep, device=device):
                so = _ort.SessionOptions()
                so.register_custom_ops_library(_get_library_path())
                onnx_model, inputs = _create_test_model(device=device, seed=42)
                self.assertIn('op_type: "MyReLu"', str(onnx_model))
                sess = _ort.InferenceSession(
                    onnx_model.SerializeToString(), so, providers=[ep])

                # Ask the session for the input name instead of relying on the
                # exporter's auto-generated "input.1".
                input_name = sess.get_inputs()[0].name
                out = sess.run(None, {input_name: inputs[0].numpy(force=True)})

                # The model ends in log_softmax over 10 classes: expect one
                # finite row whose probabilities sum to 1.
                log_probs = out[0]
                self.assertEqual(log_probs.shape, (1, 10))
                self.assertTrue(np.isfinite(log_probs).all())
                np.testing.assert_allclose(
                    np.exp(log_probs).sum(axis=1), 1.0, rtol=1e-3)

    def test_python_myrelu(self):
        self._run_myrelu()

    def test_cc_myrelu(self):
        self._run_myrelu()


if __name__ == "__main__":
unittest.main()

0 comments on commit bb4f87a

Please sign in to comment.