diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2d1808b393..1da2823b4c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -32,6 +32,27 @@ jobs: uses: actions/setup-python@v4 with: python-version: 3.8 + - name: Check macOS architecture + if: matrix.os == 'macos-latest' + id: check-arch + run: | + arch_name="$(uname -m)" + echo "Detected architecture: $arch_name" + echo "::set-output name=architecture::$arch_name" + - name: Setup environment for Apple Silicon + if: matrix.os == 'macos-latest' && steps.check-arch.outputs.architecture == 'arm64' + run: | + bash ~/miniconda.sh -b -p $HOME/miniconda + source ~/miniconda/bin/activate + conda install -c apple tensorflow-deps + SYSTEM_VERSION_COMPAT=0 pip install tensorflow-macos tensorflow-metal + - name: Setup environment for AMD + if: matrix.os == 'macos-latest' && steps.check-arch.outputs.architecture == 'x86_64' + run: | + python3 -m venv ~/venv-metal + source ~/venv-metal/bin/activate + python -m pip install -U pip + SYSTEM_VERSION_COMPAT=0 pip install tensorflow-macos tensorflow-metal # Install pip and pytest - name: Install dependencies run: | @@ -40,6 +61,7 @@ jobs: - name: Execute test-all run: ./test-all shell: bash + # clippy: # runs-on: ubuntu-latest # strategy: diff --git a/Cargo.toml b/Cargo.toml index 6fc2182ee8..2f5ccea2e2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,10 @@ [package] name = "tensorflow" version = "0.20.0" -authors = ["Adam Crume "] +authors = [ + "Adam Crume ", + "Maciej Maślanka " +] description = "Rust language bindings for TensorFlow." license = "Apache-2.0" keywords = ["TensorFlow", "bindings"] @@ -44,6 +47,7 @@ tensorflow_gpu = ["tensorflow-sys/tensorflow_gpu"] tensorflow_unstable = [] tensorflow_runtime_linking = ["tensorflow-sys-runtime"] eager = ["tensorflow-sys/eager"] +experimental = ["tensorflow-sys/experimental"] # This is for testing purposes; users should not use this. examples_system_alloc = ["tensorflow-sys/examples_system_alloc"] private-docs-rs = ["tensorflow-sys/private-docs-rs"] # DO NOT RELY ON THIS diff --git a/tensorflow-sys/Cargo.toml b/tensorflow-sys/Cargo.toml index 9e4ff2d280..6caf57dbed 100644 --- a/tensorflow-sys/Cargo.toml +++ b/tensorflow-sys/Cargo.toml @@ -5,6 +5,7 @@ license = "Apache-2.0" authors = [ "Adam Crume ", "Ivan Ukhov ", + "Maciej Maślanka ", ] description = "The package provides bindings to TensorFlow." documentation = "https://tensorflow.github.io/rust" @@ -36,6 +37,7 @@ zip = "0.6.4" [features] tensorflow_gpu = [] eager = [] +experimental = [] # This is for testing purposes; users should not use this. examples_system_alloc = [] private-docs-rs = [] # DO NOT RELY ON THIS diff --git a/tensorflow-sys/README.md b/tensorflow-sys/README.md index 680c966650..287d1a44de 100644 --- a/tensorflow-sys/README.md +++ b/tensorflow-sys/README.md @@ -59,6 +59,7 @@ compiled library will be picked up. **macOS Note**: Via [Homebrew](https://brew.sh/), you can just run `brew install libtensorflow`. +[tensorflow metal plugin]: https://developer.apple.com/metal/tensorflow-plugin/ ## Resources diff --git a/tensorflow-sys/generate_bindgen_rs.sh b/tensorflow-sys/generate_bindgen_rs.sh index 6ce66390ee..1c9883bf98 100755 --- a/tensorflow-sys/generate_bindgen_rs.sh +++ b/tensorflow-sys/generate_bindgen_rs.sh @@ -18,3 +18,8 @@ bindgen_options_eager="--allowlist-function TFE_.+ --allowlist-type TFE_.+ --all cmd="bindgen ${bindgen_options_eager} ${include_dir}/tensorflow/c/eager/c_api.h --output src/eager/c_api.rs -- -I ${include_dir}" echo ${cmd} ${cmd} + +bindgen_options_experimental="--no-derive-copy --allowlist-function TF_LoadPluggableDeviceLibrary --allowlist-function TF_DeletePluggableDeviceLibraryHandle --allowlist-var TF_Buffer* --allowlist-type TF_ShapeAndTypeList --allowlist-type TF_ShapeAndType --allowlist-type TF_CheckpointReader --allowlist-type TF_AttrBuilder --size_t-is-usize --default-enum-style=rust --generate-inline-functions --blocklist-type TF_Library --blocklist-type TF_DataType --blocklist-type TF_Status" +cmd="bindgen ${bindgen_options_experimental} ${include_dir}/tensorflow/c/c_api_experimental.h --output src/experimental/c_api.rs -- -I ${include_dir}" +echo ${cmd} +${cmd} diff --git a/tensorflow-sys/src/experimental/c_api.rs b/tensorflow-sys/src/experimental/c_api.rs new file mode 100644 index 0000000000..44449340a2 --- /dev/null +++ b/tensorflow-sys/src/experimental/c_api.rs @@ -0,0 +1,113 @@ +/* automatically generated by rust-bindgen 0.61.0 */ + +#[repr(C)] +#[derive(Debug)] +pub struct TF_CheckpointReader { + _unused: [u8; 0], +} +#[repr(C)] +#[derive(Debug)] +pub struct TF_AttrBuilder { + _unused: [u8; 0], +} +#[repr(C)] +pub struct TF_ShapeAndType { + pub num_dims: ::std::os::raw::c_int, + pub dims: *mut i64, + pub dtype: TF_DataType, +} +#[test] +fn bindgen_test_layout_TF_ShapeAndType() { + const UNINIT: ::std::mem::MaybeUninit = ::std::mem::MaybeUninit::uninit(); + let ptr = UNINIT.as_ptr(); + assert_eq!( + ::std::mem::size_of::(), + 24usize, + concat!("Size of: ", stringify!(TF_ShapeAndType)) + ); + assert_eq!( + ::std::mem::align_of::(), + 8usize, + concat!("Alignment of ", stringify!(TF_ShapeAndType)) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).num_dims) as usize - ptr as usize }, + 0usize, + concat!( + "Offset of field: ", + stringify!(TF_ShapeAndType), + "::", + stringify!(num_dims) + ) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).dims) as usize - ptr as usize }, + 8usize, + concat!( + "Offset of field: ", + stringify!(TF_ShapeAndType), + "::", + stringify!(dims) + ) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).dtype) as usize - ptr as usize }, + 16usize, + concat!( + "Offset of field: ", + stringify!(TF_ShapeAndType), + "::", + stringify!(dtype) + ) + ); +} +#[repr(C)] +#[derive(Debug)] +pub struct TF_ShapeAndTypeList { + pub num_items: ::std::os::raw::c_int, + pub items: *mut TF_ShapeAndType, +} +#[test] +fn bindgen_test_layout_TF_ShapeAndTypeList() { + const UNINIT: ::std::mem::MaybeUninit = ::std::mem::MaybeUninit::uninit(); + let ptr = UNINIT.as_ptr(); + assert_eq!( + ::std::mem::size_of::(), + 16usize, + concat!("Size of: ", stringify!(TF_ShapeAndTypeList)) + ); + assert_eq!( + ::std::mem::align_of::(), + 8usize, + concat!("Alignment of ", stringify!(TF_ShapeAndTypeList)) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).num_items) as usize - ptr as usize }, + 0usize, + concat!( + "Offset of field: ", + stringify!(TF_ShapeAndTypeList), + "::", + stringify!(num_items) + ) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).items) as usize - ptr as usize }, + 8usize, + concat!( + "Offset of field: ", + stringify!(TF_ShapeAndTypeList), + "::", + stringify!(items) + ) + ); +} +extern "C" { + pub fn TF_LoadPluggableDeviceLibrary( + library_filename: *const ::std::os::raw::c_char, + status: *mut TF_Status, + ) -> *mut TF_Library; +} +extern "C" { + pub fn TF_DeletePluggableDeviceLibraryHandle(lib_handle: *mut TF_Library); +} diff --git a/tensorflow-sys/src/experimental/mod.rs b/tensorflow-sys/src/experimental/mod.rs new file mode 100644 index 0000000000..da35649e0a --- /dev/null +++ b/tensorflow-sys/src/experimental/mod.rs @@ -0,0 +1,3 @@ +use crate::{TF_DataType, TF_Library, TF_Status}; + +include!("c_api.rs"); diff --git a/tensorflow-sys/src/lib.rs b/tensorflow-sys/src/lib.rs index 2e3f1999d7..87204b27db 100644 --- a/tensorflow-sys/src/lib.rs +++ b/tensorflow-sys/src/lib.rs @@ -11,3 +11,8 @@ include!("c_api.rs"); pub use crate::TF_AttrType::*; pub use crate::TF_Code::*; pub use crate::TF_DataType::*; + +#[cfg(feature = "experimental")] +mod experimental; +#[cfg(feature = "experimental")] +pub use experimental::*; diff --git a/tensorflow-sys/tests/lib.rs b/tensorflow-sys/tests/lib.rs index 946b9844db..d8ccea75da 100644 --- a/tensorflow-sys/tests/lib.rs +++ b/tensorflow-sys/tests/lib.rs @@ -46,3 +46,24 @@ fn tfe_tensor_handle() { ffi::TF_DeleteTensor(tf_tensor); } } + +/// Test that the experimental API works. +#[cfg(all(feature = "experimental", target_os = "macos"))] +#[test] +fn load_plugable_device() { + let c_filename = std::ffi::CString::new("libmetal_plugin.dylib").expect("CString::new failed"); + unsafe { + let raw_status = ffi::TF_NewStatus(); + let lib_handle = ffi::TF_LoadPluggableDeviceLibrary(c_filename.as_ptr(), raw_status); + if ffi::TF_GetCode(raw_status) != ffi::TF_OK { + panic!( + "{}", + std::ffi::CStr::from_ptr(ffi::TF_Message(raw_status)) + .to_string_lossy() + .into_owned() + ); + } + ffi::TF_DeletePluggableDeviceLibraryHandle(lib_handle); + ffi::TF_DeleteStatus(raw_status); + }; +}