Skip to content

Commit

Permalink
feat: Handle exception thrown from native side (#61)
Browse files Browse the repository at this point in the history
This PR catches exceptions thrown from native side via calling Java methods, and convert them into a `CometError::JavaException` which can then be properly propagated to the JVM.
  • Loading branch information
sunchao authored Feb 20, 2024
1 parent 7018225 commit 180f962
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 33 deletions.
3 changes: 3 additions & 0 deletions core/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,9 @@ pub enum CometError {
#[from]
source: DataFusionError,
},

#[error("{class}: {msg}")]
JavaException { class: String, msg: String },
}

pub fn init() {
Expand Down
26 changes: 13 additions & 13 deletions core/src/execution/datafusion/expressions/subquery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ impl PhysicalExpr for Subquery {
let mut env = JVMClasses::get_env();

unsafe {
let is_null = jni_static_call!(env,
let is_null = jni_static_call!(&mut env,
comet_exec.is_null(self.exec_context_id, self.id) -> jboolean
)?;

Expand All @@ -105,50 +105,50 @@ impl PhysicalExpr for Subquery {

match &self.data_type {
DataType::Boolean => {
let r = jni_static_call!(env,
let r = jni_static_call!(&mut env,
comet_exec.get_bool(self.exec_context_id, self.id) -> jboolean
)?;
Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(r > 0))))
}
DataType::Int8 => {
let r = jni_static_call!(env,
let r = jni_static_call!(&mut env,
comet_exec.get_byte(self.exec_context_id, self.id) -> jbyte
)?;
Ok(ColumnarValue::Scalar(ScalarValue::Int8(Some(r))))
}
DataType::Int16 => {
let r = jni_static_call!(env,
let r = jni_static_call!(&mut env,
comet_exec.get_short(self.exec_context_id, self.id) -> jshort
)?;
Ok(ColumnarValue::Scalar(ScalarValue::Int16(Some(r))))
}
DataType::Int32 => {
let r = jni_static_call!(env,
let r = jni_static_call!(&mut env,
comet_exec.get_int(self.exec_context_id, self.id) -> jint
)?;
Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(r))))
}
DataType::Int64 => {
let r = jni_static_call!(env,
let r = jni_static_call!(&mut env,
comet_exec.get_long(self.exec_context_id, self.id) -> jlong
)?;
Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(r))))
}
DataType::Float32 => {
let r = jni_static_call!(env,
let r = jni_static_call!(&mut env,
comet_exec.get_float(self.exec_context_id, self.id) -> f32
)?;
Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some(r))))
}
DataType::Float64 => {
let r = jni_static_call!(env,
let r = jni_static_call!(&mut env,
comet_exec.get_double(self.exec_context_id, self.id) -> f64
)?;

Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(r))))
}
DataType::Decimal128(p, s) => {
let bytes = jni_static_call!(env,
let bytes = jni_static_call!(&mut env,
comet_exec.get_decimal(self.exec_context_id, self.id) -> BinaryWrapper
)?;
let bytes: &JByteArray = bytes.get().into();
Expand All @@ -161,14 +161,14 @@ impl PhysicalExpr for Subquery {
)))
}
DataType::Date32 => {
let r = jni_static_call!(env,
let r = jni_static_call!(&mut env,
comet_exec.get_int(self.exec_context_id, self.id) -> jint
)?;

Ok(ColumnarValue::Scalar(ScalarValue::Date32(Some(r))))
}
DataType::Timestamp(TimeUnit::Microsecond, timezone) => {
let r = jni_static_call!(env,
let r = jni_static_call!(&mut env,
comet_exec.get_long(self.exec_context_id, self.id) -> jlong
)?;

Expand All @@ -178,15 +178,15 @@ impl PhysicalExpr for Subquery {
)))
}
DataType::Utf8 => {
let string = jni_static_call!(env,
let string = jni_static_call!(&mut env,
comet_exec.get_string(self.exec_context_id, self.id) -> StringWrapper
)?;

let string = env.get_string(string.get()).unwrap().into();
Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(string))))
}
DataType::Binary => {
let bytes = jni_static_call!(env,
let bytes = jni_static_call!(&mut env,
comet_exec.get_binary(self.exec_context_id, self.id) -> BinaryWrapper
)?;
let bytes: &JByteArray = bytes.get().into();
Expand Down
157 changes: 137 additions & 20 deletions core/src/jvm_bridge/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
use jni::{
errors::{Error, Result as JniResult},
objects::{JClass, JObject, JString, JValueGen, JValueOwned},
objects::{JClass, JMethodID, JObject, JString, JThrowable, JValueGen, JValueOwned},
signature::ReturnType,
AttachGuard, JNIEnv,
};
use once_cell::sync::OnceCell;
Expand Down Expand Up @@ -58,29 +59,52 @@ macro_rules! jni_new_string {
/// jname and value are the arguments.
macro_rules! jni_call {
($env:expr, $clsname:ident($obj:expr).$method:ident($($args:expr),* $(,)?) -> $ret:ty) => {{
$crate::jvm_bridge::jni_map_error!(
$env,
$env.call_method_unchecked(
$obj,
paste::paste! {$crate::jvm_bridge::JVMClasses::get().[<$clsname>].[<method_ $method>]},
paste::paste! {$crate::jvm_bridge::JVMClasses::get().[<$clsname>].[<method_ $method _ret>]}.clone(),
$crate::jvm_bridge::jvalues!($($args,)*)
)
).and_then(|result| $crate::jvm_bridge::jni_map_error!($env, <$ret>::try_from(result)))
let method_id = paste::paste! {
$crate::jvm_bridge::JVMClasses::get().[<$clsname>].[<method_ $method>]
};
let ret_type = paste::paste! {
$crate::jvm_bridge::JVMClasses::get().[<$clsname>].[<method_ $method _ret>]
}.clone();
let args = $crate::jvm_bridge::jvalues!($($args,)*);

// Call the JVM method and obtain the returned value
let ret = $env.call_method_unchecked($obj, method_id, ret_type, args);

// Check if JVM has thrown any exception, and handle it if so.
let result = if let Some(exception) = $crate::jvm_bridge::check_exception($env).unwrap() {
Err(exception.into())
} else {
$crate::jvm_bridge::jni_map_error!($env, ret)
};

result.and_then(|result| $crate::jvm_bridge::jni_map_error!($env, <$ret>::try_from(result)))
}}
}

macro_rules! jni_static_call {
($env:expr, $clsname:ident.$method:ident($($args:expr),* $(,)?) -> $ret:ty) => {{
$crate::jvm_bridge::jni_map_error!(
$env,
$env.call_static_method_unchecked(
&paste::paste! {$crate::jvm_bridge::JVMClasses::get().[<$clsname>].[<class>]},
paste::paste! {$crate::jvm_bridge::JVMClasses::get().[<$clsname>].[<method_ $method>]},
paste::paste! {$crate::jvm_bridge::JVMClasses::get().[<$clsname>].[<method_ $method _ret>]}.clone(),
$crate::jvm_bridge::jvalues!($($args,)*)
)
).and_then(|result| $crate::jvm_bridge::jni_map_error!($env, <$ret>::try_from(result)))
let clazz = &paste::paste! {
$crate::jvm_bridge::JVMClasses::get().[<$clsname>].[<class>]
};
let method_id = paste::paste! {
$crate::jvm_bridge::JVMClasses::get().[<$clsname>].[<method_ $method>]
};
let ret_type = paste::paste! {
$crate::jvm_bridge::JVMClasses::get().[<$clsname>].[<method_ $method _ret>]
}.clone();
let args = $crate::jvm_bridge::jvalues!($($args,)*);

// Call the JVM static method and obtain the returned value
let ret = $env.call_static_method_unchecked(clazz, method_id, ret_type, args);

// Check if JVM has thrown any exception, and handle it if so.
let result = if let Some(exception) = $crate::jvm_bridge::check_exception($env).unwrap() {
Err(exception.into())
} else {
$crate::jvm_bridge::jni_map_error!($env, ret)
};

result.and_then(|result| $crate::jvm_bridge::jni_map_error!($env, <$ret>::try_from(result)))
}}
}

Expand Down Expand Up @@ -167,11 +191,21 @@ pub fn get_global_jclass(env: &mut JNIEnv, cls: &str) -> JniResult<JClass<'stati
mod comet_exec;
pub use comet_exec::*;
mod comet_metric_node;
use crate::JAVA_VM;
use crate::{
errors::{CometError, CometResult},
JAVA_VM,
};
pub use comet_metric_node::*;

/// The JVM classes that are used in the JNI calls.
pub struct JVMClasses<'a> {
/// Cached method ID for "java.lang.Object#getClass"
pub object_get_class_method: JMethodID,
/// Cached method ID for "java.lang.Class#getName"
pub class_get_name_method: JMethodID,
/// Cached method ID for "java.lang.Throwable#getMessage"
pub throwable_get_message_method: JMethodID,

/// The CometMetricNode class. Used for updating the metrics.
pub comet_metric_node: CometMetricNode<'a>,
/// The static CometExec class. Used for getting the subquery result.
Expand All @@ -192,7 +226,25 @@ impl JVMClasses<'_> {
// `JNIEnv` except for creating the global references of the classes.
let env = unsafe { std::mem::transmute::<_, &'static mut JNIEnv>(env) };

let clazz = env.find_class("java/lang/Object").unwrap();
let object_get_class_method = env
.get_method_id(clazz, "getClass", "()Ljava/lang/Class;")
.unwrap();

let clazz = env.find_class("java/lang/Class").unwrap();
let class_get_name_method = env
.get_method_id(clazz, "getName", "()Ljava/lang/String;")
.unwrap();

let clazz = env.find_class("java/lang/Throwable").unwrap();
let throwable_get_message_method = env
.get_method_id(clazz, "getMessage", "()Ljava/lang/String;")
.unwrap();

JVMClasses {
object_get_class_method,
class_get_name_method,
throwable_get_message_method,
comet_metric_node: CometMetricNode::new(env).unwrap(),
comet_exec: CometExec::new(env).unwrap(),
}
Expand All @@ -211,3 +263,68 @@ impl JVMClasses<'_> {
}
}
}

pub(crate) fn check_exception(env: &mut JNIEnv) -> CometResult<Option<CometError>> {
let result = if env.exception_check()? {
let exception = env.exception_occurred()?;
env.exception_clear()?;
let exception_err = convert_exception(env, &exception)?;
Some(exception_err)
} else {
None
};

Ok(result)
}

/// Given a `JThrowable` which is thrown from calling a Java method on the native side,
/// this converts it into a `CometError::JavaException` with the exception class name
/// and exception message. This error can then be populated to the JVM side to let
/// users know the cause of the native side error.
pub(crate) fn convert_exception(
env: &mut JNIEnv,
throwable: &JThrowable,
) -> CometResult<CometError> {
unsafe {
let cache = JVMClasses::get();

// get the class name of the exception by:
// 1. get the `Class` object of the input `throwable` via `Object#getClass` method
// 2. get the exception class name via calling `Class#getName` on the above object
let class_obj = env
.call_method_unchecked(
throwable,
cache.object_get_class_method,
ReturnType::Object,
&[],
)?
.l()?;
let exception_class_name = env
.call_method_unchecked(
class_obj,
cache.class_get_name_method,
ReturnType::Object,
&[],
)?
.l()?
.into();
let exception_class_name_str = env.get_string(&exception_class_name)?.into();

// get the exception message via calling `Throwable#getMessage` on the throwable object
let message = env
.call_method_unchecked(
throwable,
cache.throwable_get_message_method,
ReturnType::Object,
&[],
)?
.l()?
.into();
let message_str = env.get_string(&message)?.into();

Ok(CometError::JavaException {
class: exception_class_name_str,
msg: message_str,
})
}
}

0 comments on commit 180f962

Please sign in to comment.