Commit 38a0978
PoC to test libtorch support for onnxruntime-extensions
1 parent: d79299e
Showing 10 changed files with 262 additions and 1 deletion.
@@ -0,0 +1,3 @@
[submodule "cmake/pytorch"]
    path = cmake/pytorch
    url = https://github.com/pytorch/pytorch.git
@@ -0,0 +1,37 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include <torch/torch.h>

#include "ocos.h"

// TODO: Good example for CPU/CUDA op
// https://github.com/microsoft/onnxruntime-extensions/pull/739/files

// TODO: Add DLPack support to ONNXRuntime-extensions for perf improvement
// https://github.com/microsoft/onnxruntime/pull/6968

// TODO: Make templates for Tensor<T>? Testing for Tensor<float>
// template <typename T>
OrtStatusPtr com_amd_myrelu(const ortc::Tensor<float>& input_ort,
                            ortc::Tensor<float>& out_ort) {
  int64_t input_size = input_ort.NumberOfElement();
  if (0 == input_size) {
    return nullptr;
  }

  // Massaging the input to PyTorch format
  torch::Tensor X = torch::empty(input_ort.Shape()).contiguous();
  memcpy(X.data_ptr<float>(), input_ort.Data(), input_size * sizeof(float));  // TODO: replace with todlpack + torch::Tensor

  // Allocate the ORT output buffer
  float* out_ort_ptr = out_ort.Allocate(input_ort.Shape());

  // Do the computation, then massage the output back to ORT format
  auto out_torch = torch::relu(X);
  memcpy(out_ort_ptr, out_torch.data_ptr<float>(), input_size * sizeof(float));  // TODO: replace with todlpack + ortc::Tensor conversion

  return nullptr;
}
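Note: for the "ai.onnx.contrib::MyReLu" node emitted by the Python test below to resolve to this kernel, the function still has to be registered with the extensions' custom op loader. A minimal sketch of that wiring, assuming the OrtOpLoader / CustomCpuFunc / CustomCudaFuncV2 helpers used by the built-in operators; the exact helper names and the header "com_amd_myrelu.h" are assumptions, not part of this commit:

// Hedged sketch, not part of this commit: register the CPU and CUDA kernels
// under the op name "MyReLu" so they are discoverable via ai.onnx.contrib.
#include "ocos.h"
#include "com_amd_myrelu.h"      // hypothetical header name for the CPU op above
#ifdef USE_CUDA
#include "com_amd_myrelu_def.h"  // CUDA entry point declared in this commit
#endif

FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
  static OrtOpLoader op_loader(CustomCpuFunc("MyReLu", com_amd_myrelu)
#ifdef USE_CUDA
                                   ,
                               CustomCudaFuncV2("MyReLu", com_amd_myrelu_cuda)
#endif
  );
  return op_loader.GetCustomOps();
};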
@@ -0,0 +1,17 @@
#include <cuda_runtime.h>
#include "com_amd_myrelu.cuh"

// Elementwise ReLU. The original placeholder called torch::relu() from device
// code, which cannot compile: libtorch is host-side only, so the kernel
// computes ReLU directly.
__global__ void com_amd_myrelu_kernel(const float* input, float* out, int input_size) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < input_size) {
    out[idx] = input[idx] > 0.0f ? input[idx] : 0.0f;
  }
}

void com_amd_myrelu_impl(cudaStream_t stream,
                         const float* input, float* out, int size) {
  // One thread per element, instead of the original <<<1, 1>>> launch
  // that would have run a single thread.
  const int threads_per_block = 256;
  const int blocks = (size + threads_per_block - 1) / threads_per_block;
  com_amd_myrelu_kernel<<<blocks, threads_per_block, 0, stream>>>(input, out, size);
}
@@ -0,0 +1,9 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include <cuda.h>
#include <cuda_runtime.h>

void com_amd_myrelu_impl(cudaStream_t stream,
                         const float* input, float* out, int size);
@@ -0,0 +1,25 @@
#include "com_amd_myrelu.cuh"
#include "narrow.h"
#include "com_amd_myrelu_def.h"
#include <cuda.h>
#include <cuda_runtime.h>

OrtStatusPtr com_amd_myrelu_cuda(Ort::Custom::CUDAKernelContext* ctx,
                                 const ortc::Tensor<float>& input,
                                 ortc::Tensor<float>& out0_tensor) {
  int64_t input_size = input.NumberOfElement();
  if (0 == input_size) {
    return nullptr;
  }

  // The original draft staged the input through a host-side torch::Tensor and
  // mixed byte and element counts; on a CUDA kernel context, input.Data()
  // already points to device memory, so it can go to the kernel directly.
  float* out_ort_ptr = out0_tensor.Allocate(input.Shape());
  com_amd_myrelu_impl(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
                      input.Data(), out_ort_ptr, static_cast<int>(input_size));
  return nullptr;
}
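The DLPack TODOs in the CPU op above aim to remove the host memcpy round-trips. One possible shape of that bridge, sketched under the assumption that the standard ATen converters (at::fromDLPack / at::toDLPack) are available; the wrapper function and its name are illustrative only:

// Hedged sketch, not part of this commit: view an ORT-owned float buffer from
// libtorch without copying, by wrapping it as a DLPack tensor.
#include <ATen/DLConvertor.h>  // at::fromDLPack / at::toDLPack
#include <torch/torch.h>
#include <vector>

// `shape` must outlive the returned tensor, since DLPack borrows the pointer.
torch::Tensor view_ort_buffer_as_torch(float* data, std::vector<int64_t>& shape) {
  auto* dlmt = new DLManagedTensor{};
  dlmt->dl_tensor.data = data;
  dlmt->dl_tensor.device = {kDLCPU, 0};
  dlmt->dl_tensor.ndim = static_cast<int32_t>(shape.size());
  dlmt->dl_tensor.dtype = {kDLFloat, 32, 1};
  dlmt->dl_tensor.shape = shape.data();
  dlmt->dl_tensor.strides = nullptr;  // nullptr == compact row-major layout
  dlmt->dl_tensor.byte_offset = 0;
  dlmt->manager_ctx = nullptr;
  // The deleter only frees the wrapper; the buffer itself stays owned by ORT.
  dlmt->deleter = [](DLManagedTensor* self) { delete self; };
  return at::fromDLPack(dlmt);
}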
@@ -0,0 +1,9 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "ocos.h"

OrtStatusPtr com_amd_myrelu_cuda(Ort::Custom::CUDAKernelContext* ctx,
                                 const ortc::Tensor<float>& input,
                                 ortc::Tensor<float>& out0_tensor);
@@ -0,0 +1,124 @@
import io
import unittest

import numpy as np
import onnx
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

import onnxruntime as _ort
from onnxruntime_extensions import (
    onnx_op, PyCustomOpDef,
    get_library_path as _get_library_path)

def _create_test_model(device: str = "cpu", seed=42):
    # Basic setup
    use_cuda = "cuda" in device.lower() and torch.cuda.is_available()
    torch.manual_seed(seed)

    device = torch.device(device)

    # Data loader setup; keep batch_size=1 and add CUDA-friendly options when available
    export_kwargs = {'batch_size': 1}
    if use_cuda:
        export_kwargs.update({'num_workers': 1,
                              'pin_memory': True,
                              'shuffle': True})
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    export_dataset = datasets.MNIST('../data', train=True, download=True, transform=transform)
    export_loader = torch.utils.data.DataLoader(export_dataset, **export_kwargs)

    # Register a custom op symbolic for relu in the ONNX export and use it in the model.
    # The domain must be "ai.onnx.contrib" to be compatible with onnxruntime-extensions.
    from torch.onnx import register_custom_op_symbolic

    def com_amd_relu_1(g, input):
        return g.op("ai.onnx.contrib::MyReLu", input).setType(input.type())

    register_custom_op_symbolic("::relu", com_amd_relu_1, 9)

    # Model
    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(1, 10, 5)
            self.conv2 = nn.Conv2d(10, 20, 5)
            self.conv2_drop = nn.Dropout2d()
            self.dropout = nn.Dropout(0.5)
            self.fc1 = nn.Linear(320, 50)
            self.fc2 = nn.Linear(50, 10)
            self.relu = nn.ReLU()

        def forward(self, x):
            x = self.conv1(x)
            x = torch.max_pool2d(x, 2)
            x = self.relu(x)
            x = self.conv2(x)
            x = self.conv2_drop(x)
            x = torch.max_pool2d(x, 2)
            x = self.relu(x)
            x = x.view(-1, 320)
            x = self.fc1(x)
            x = self.relu(x)
            x = self.dropout(x)
            x = self.fc2(x)
            output = F.log_softmax(x, dim=1)
            return output

    # Exporting to ONNX with the custom op
    model = Net().to(device)
    input_sample = next(iter(export_loader))
    input_sample[0] = input_sample[0].to(device)  # torch.randn(1,1,28,28, dtype=torch.float)
    input_sample[1] = input_sample[1].to(device)  # torch.randint(1,10,(1,))
    f = io.BytesIO()
    with torch.no_grad():
        torch.onnx.export(model, (input_sample[0],), f)
    model_proto = onnx.ModelProto.FromString(f.getvalue())
    return model_proto, input_sample

class TestPythonOp(unittest.TestCase):

    # Used to test the custom op in Python.
    # The ONNX graph has the custom op, which is executed by the function below.
    # @classmethod
    # def setUpClass(cls):
    #     @onnx_op(op_type="MyReLu",
    #              inputs=[PyCustomOpDef.dt_float],
    #              outputs=[PyCustomOpDef.dt_float])
    #     def myrelu(x):
    #         return torch.relu(torch.from_numpy(x.copy()))

    def test_python_myrelu(self):
        # EPs = ['CPUExecutionProvider', 'CUDAExecutionProvider']
        EPs = ['CPUExecutionProvider']  # TODO: Test with CUDA
        DEVICEs = ["cpu", "cuda"]
        for dev_idx, ep in enumerate(EPs):
            so = _ort.SessionOptions()
            so.register_custom_ops_library(_get_library_path())
            onnx_model, inputs = _create_test_model(device=DEVICEs[dev_idx], seed=42)
            self.assertIn('op_type: "MyReLu"', str(onnx_model))
            sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=[ep])
            out = sess.run(None, {'input.1': inputs[0].numpy(force=True)})

    def test_cc_myrelu(self):
        # EPs = ['CPUExecutionProvider', 'CUDAExecutionProvider']
        EPs = ['CPUExecutionProvider']  # TODO: Test with CUDA
        DEVICEs = ["cpu", "cuda"]
        for dev_idx, ep in enumerate(EPs):
            so = _ort.SessionOptions()
            so.register_custom_ops_library(_get_library_path())
            onnx_model, inputs = _create_test_model(device=DEVICEs[dev_idx], seed=42)
            self.assertIn('op_type: "MyReLu"', str(onnx_model))
            sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=[ep])
            out = sess.run(None, {'input.1': inputs[0].numpy(force=True)})

if __name__ == "__main__":
    unittest.main()