Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tensorflow pluggable device #387

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,27 @@ jobs:
uses: actions/setup-python@v4
with:
python-version: 3.8
- name: Check macOS architecture
if: matrix.os == 'macos-latest'
id: check-arch
run: |
arch_name="$(uname -m)"
echo "Detected architecture: $arch_name"
echo "::set-output name=architecture::$arch_name"
- name: Setup environment for Apple Silicon
if: matrix.os == 'macos-latest' && steps.check-arch.outputs.architecture == 'arm64'
run: |
bash ~/miniconda.sh -b -p $HOME/miniconda
source ~/miniconda/bin/activate
conda install -c apple tensorflow-deps
SYSTEM_VERSION_COMPAT=0 pip install tensorflow-macos tensorflow-metal
- name: Setup environment for AMD
if: matrix.os == 'macos-latest' && steps.check-arch.outputs.architecture == 'x86_64'
run: |
python3 -m venv ~/venv-metal
source ~/venv-metal/bin/activate
python -m pip install -U pip
SYSTEM_VERSION_COMPAT=0 pip install tensorflow-macos tensorflow-metal
# Install pip and pytest
- name: Install dependencies
run: |
Expand All @@ -40,6 +61,7 @@ jobs:
- name: Execute test-all
run: ./test-all
shell: bash

# clippy:
# runs-on: ubuntu-latest
# strategy:
Expand Down
6 changes: 5 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
[package]
name = "tensorflow"
version = "0.20.0"
authors = ["Adam Crume <[email protected]>"]
authors = [
"Adam Crume <[email protected]>",
"Maciej Maślanka <[email protected]>"
]
description = "Rust language bindings for TensorFlow."
license = "Apache-2.0"
keywords = ["TensorFlow", "bindings"]
Expand Down Expand Up @@ -44,6 +47,7 @@ tensorflow_gpu = ["tensorflow-sys/tensorflow_gpu"]
tensorflow_unstable = []
tensorflow_runtime_linking = ["tensorflow-sys-runtime"]
eager = ["tensorflow-sys/eager"]
experimental = ["tensorflow-sys/experimental"]
# This is for testing purposes; users should not use this.
examples_system_alloc = ["tensorflow-sys/examples_system_alloc"]
private-docs-rs = ["tensorflow-sys/private-docs-rs"] # DO NOT RELY ON THIS
Expand Down
2 changes: 2 additions & 0 deletions tensorflow-sys/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ license = "Apache-2.0"
authors = [
"Adam Crume <[email protected]>",
"Ivan Ukhov <[email protected]>",
"Maciej Maślanka <[email protected]>",
]
description = "The package provides bindings to TensorFlow."
documentation = "https://tensorflow.github.io/rust"
Expand Down Expand Up @@ -36,6 +37,7 @@ zip = "0.6.4"
[features]
tensorflow_gpu = []
eager = []
experimental = []
# This is for testing purposes; users should not use this.
examples_system_alloc = []
private-docs-rs = [] # DO NOT RELY ON THIS
1 change: 1 addition & 0 deletions tensorflow-sys/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ compiled library will be picked up.

**macOS Note**: Via [Homebrew](https://brew.sh/), you can just run
`brew install libtensorflow`.
[tensorflow metal plugin]: https://developer.apple.com/metal/tensorflow-plugin/
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't render properly (https://github.com/Apurer/rust/tree/feature/experimental/tensorflow-sys). It also needs some explanation. Do users need to follow that link and do anything specific?

Copy link
Author

@Apurer Apurer Apr 22, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll fix the rendering issue. Yes, users need to click on the provided link to set up the TensorFlow plugin on Apple computers. I considered including the steps listed on developer.apple.com, but I was unsure if the steps would be updated.


## Resources

Expand Down
5 changes: 5 additions & 0 deletions tensorflow-sys/generate_bindgen_rs.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,8 @@ bindgen_options_eager="--allowlist-function TFE_.+ --allowlist-type TFE_.+ --all
cmd="bindgen ${bindgen_options_eager} ${include_dir}/tensorflow/c/eager/c_api.h --output src/eager/c_api.rs -- -I ${include_dir}"
echo ${cmd}
${cmd}

bindgen_options_experimental="--no-derive-copy --allowlist-function TF_LoadPluggableDeviceLibrary --allowlist-function TF_DeletePluggableDeviceLibraryHandle --allowlist-var TF_Buffer* --allowlist-type TF_ShapeAndTypeList --allowlist-type TF_ShapeAndType --allowlist-type TF_CheckpointReader --allowlist-type TF_AttrBuilder --size_t-is-usize --default-enum-style=rust --generate-inline-functions --blocklist-type TF_Library --blocklist-type TF_DataType --blocklist-type TF_Status"
cmd="bindgen ${bindgen_options_experimental} ${include_dir}/tensorflow/c/c_api_experimental.h --output src/experimental/c_api.rs -- -I ${include_dir}"
echo ${cmd}
${cmd}
113 changes: 113 additions & 0 deletions tensorflow-sys/src/experimental/c_api.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/* automatically generated by rust-bindgen 0.61.0 */

#[repr(C)]
#[derive(Debug)]
pub struct TF_CheckpointReader {
_unused: [u8; 0],
}
#[repr(C)]
#[derive(Debug)]
pub struct TF_AttrBuilder {
_unused: [u8; 0],
}
#[repr(C)]
pub struct TF_ShapeAndType {
pub num_dims: ::std::os::raw::c_int,
pub dims: *mut i64,
pub dtype: TF_DataType,
}
#[test]
fn bindgen_test_layout_TF_ShapeAndType() {
const UNINIT: ::std::mem::MaybeUninit<TF_ShapeAndType> = ::std::mem::MaybeUninit::uninit();
let ptr = UNINIT.as_ptr();
assert_eq!(
::std::mem::size_of::<TF_ShapeAndType>(),
24usize,
concat!("Size of: ", stringify!(TF_ShapeAndType))
);
assert_eq!(
::std::mem::align_of::<TF_ShapeAndType>(),
8usize,
concat!("Alignment of ", stringify!(TF_ShapeAndType))
);
assert_eq!(
unsafe { ::std::ptr::addr_of!((*ptr).num_dims) as usize - ptr as usize },
0usize,
concat!(
"Offset of field: ",
stringify!(TF_ShapeAndType),
"::",
stringify!(num_dims)
)
);
assert_eq!(
unsafe { ::std::ptr::addr_of!((*ptr).dims) as usize - ptr as usize },
8usize,
concat!(
"Offset of field: ",
stringify!(TF_ShapeAndType),
"::",
stringify!(dims)
)
);
assert_eq!(
unsafe { ::std::ptr::addr_of!((*ptr).dtype) as usize - ptr as usize },
16usize,
concat!(
"Offset of field: ",
stringify!(TF_ShapeAndType),
"::",
stringify!(dtype)
)
);
}
#[repr(C)]
#[derive(Debug)]
pub struct TF_ShapeAndTypeList {
pub num_items: ::std::os::raw::c_int,
pub items: *mut TF_ShapeAndType,
}
#[test]
fn bindgen_test_layout_TF_ShapeAndTypeList() {
const UNINIT: ::std::mem::MaybeUninit<TF_ShapeAndTypeList> = ::std::mem::MaybeUninit::uninit();
let ptr = UNINIT.as_ptr();
assert_eq!(
::std::mem::size_of::<TF_ShapeAndTypeList>(),
16usize,
concat!("Size of: ", stringify!(TF_ShapeAndTypeList))
);
assert_eq!(
::std::mem::align_of::<TF_ShapeAndTypeList>(),
8usize,
concat!("Alignment of ", stringify!(TF_ShapeAndTypeList))
);
assert_eq!(
unsafe { ::std::ptr::addr_of!((*ptr).num_items) as usize - ptr as usize },
0usize,
concat!(
"Offset of field: ",
stringify!(TF_ShapeAndTypeList),
"::",
stringify!(num_items)
)
);
assert_eq!(
unsafe { ::std::ptr::addr_of!((*ptr).items) as usize - ptr as usize },
8usize,
concat!(
"Offset of field: ",
stringify!(TF_ShapeAndTypeList),
"::",
stringify!(items)
)
);
}
extern "C" {
pub fn TF_LoadPluggableDeviceLibrary(
library_filename: *const ::std::os::raw::c_char,
status: *mut TF_Status,
) -> *mut TF_Library;
}
extern "C" {
pub fn TF_DeletePluggableDeviceLibraryHandle(lib_handle: *mut TF_Library);
}
3 changes: 3 additions & 0 deletions tensorflow-sys/src/experimental/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
use crate::{TF_DataType, TF_Library, TF_Status};

include!("c_api.rs");
5 changes: 5 additions & 0 deletions tensorflow-sys/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,8 @@ include!("c_api.rs");
pub use crate::TF_AttrType::*;
pub use crate::TF_Code::*;
pub use crate::TF_DataType::*;

#[cfg(feature = "experimental")]
mod experimental;
#[cfg(feature = "experimental")]
pub use experimental::*;
21 changes: 21 additions & 0 deletions tensorflow-sys/tests/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,24 @@ fn tfe_tensor_handle() {
ffi::TF_DeleteTensor(tf_tensor);
}
}

/// Test that the experimental API works.
#[cfg(all(feature = "experimental", target_os = "macos"))]
#[test]
fn load_plugable_device() {
let c_filename = std::ffi::CString::new("libmetal_plugin.dylib").expect("CString::new failed");
unsafe {
let raw_status = ffi::TF_NewStatus();
let lib_handle = ffi::TF_LoadPluggableDeviceLibrary(c_filename.as_ptr(), raw_status);
if ffi::TF_GetCode(raw_status) != ffi::TF_OK {
panic!(
"{}",
std::ffi::CStr::from_ptr(ffi::TF_Message(raw_status))
.to_string_lossy()
.into_owned()
);
}
ffi::TF_DeletePluggableDeviceLibraryHandle(lib_handle);
ffi::TF_DeleteStatus(raw_status);
};
}