diff --git a/core/src/errors.rs b/core/src/errors.rs index 7188ebd1d..936d97d35 100644 --- a/core/src/errors.rs +++ b/core/src/errors.rs @@ -122,6 +122,9 @@ pub enum CometError { #[from] source: DataFusionError, }, + + #[error("{class}: {msg}")] + JavaException { class: String, msg: String }, } pub fn init() { diff --git a/core/src/execution/datafusion/expressions/subquery.rs b/core/src/execution/datafusion/expressions/subquery.rs index a4b32ba16..7cae12963 100644 --- a/core/src/execution/datafusion/expressions/subquery.rs +++ b/core/src/execution/datafusion/expressions/subquery.rs @@ -93,7 +93,7 @@ impl PhysicalExpr for Subquery { let mut env = JVMClasses::get_env(); unsafe { - let is_null = jni_static_call!(env, + let is_null = jni_static_call!(&mut env, comet_exec.is_null(self.exec_context_id, self.id) -> jboolean )?; @@ -105,50 +105,50 @@ impl PhysicalExpr for Subquery { match &self.data_type { DataType::Boolean => { - let r = jni_static_call!(env, + let r = jni_static_call!(&mut env, comet_exec.get_bool(self.exec_context_id, self.id) -> jboolean )?; Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(r > 0)))) } DataType::Int8 => { - let r = jni_static_call!(env, + let r = jni_static_call!(&mut env, comet_exec.get_byte(self.exec_context_id, self.id) -> jbyte )?; Ok(ColumnarValue::Scalar(ScalarValue::Int8(Some(r)))) } DataType::Int16 => { - let r = jni_static_call!(env, + let r = jni_static_call!(&mut env, comet_exec.get_short(self.exec_context_id, self.id) -> jshort )?; Ok(ColumnarValue::Scalar(ScalarValue::Int16(Some(r)))) } DataType::Int32 => { - let r = jni_static_call!(env, + let r = jni_static_call!(&mut env, comet_exec.get_int(self.exec_context_id, self.id) -> jint )?; Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(r)))) } DataType::Int64 => { - let r = jni_static_call!(env, + let r = jni_static_call!(&mut env, comet_exec.get_long(self.exec_context_id, self.id) -> jlong )?; Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(r)))) } DataType::Float32 => { - let r = jni_static_call!(env, + let r = jni_static_call!(&mut env, comet_exec.get_float(self.exec_context_id, self.id) -> f32 )?; Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some(r)))) } DataType::Float64 => { - let r = jni_static_call!(env, + let r = jni_static_call!(&mut env, comet_exec.get_double(self.exec_context_id, self.id) -> f64 )?; Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(r)))) } DataType::Decimal128(p, s) => { - let bytes = jni_static_call!(env, + let bytes = jni_static_call!(&mut env, comet_exec.get_decimal(self.exec_context_id, self.id) -> BinaryWrapper )?; let bytes: &JByteArray = bytes.get().into(); @@ -161,14 +161,14 @@ impl PhysicalExpr for Subquery { ))) } DataType::Date32 => { - let r = jni_static_call!(env, + let r = jni_static_call!(&mut env, comet_exec.get_int(self.exec_context_id, self.id) -> jint )?; Ok(ColumnarValue::Scalar(ScalarValue::Date32(Some(r)))) } DataType::Timestamp(TimeUnit::Microsecond, timezone) => { - let r = jni_static_call!(env, + let r = jni_static_call!(&mut env, comet_exec.get_long(self.exec_context_id, self.id) -> jlong )?; @@ -178,7 +178,7 @@ impl PhysicalExpr for Subquery { ))) } DataType::Utf8 => { - let string = jni_static_call!(env, + let string = jni_static_call!(&mut env, comet_exec.get_string(self.exec_context_id, self.id) -> StringWrapper )?; @@ -186,7 +186,7 @@ impl PhysicalExpr for Subquery { Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(string)))) } DataType::Binary => { - let bytes = jni_static_call!(env, + let bytes = jni_static_call!(&mut env, comet_exec.get_binary(self.exec_context_id, self.id) -> BinaryWrapper )?; let bytes: &JByteArray = bytes.get().into(); diff --git a/core/src/jvm_bridge/mod.rs b/core/src/jvm_bridge/mod.rs index 331e7768d..d3db7ba48 100644 --- a/core/src/jvm_bridge/mod.rs +++ b/core/src/jvm_bridge/mod.rs @@ -19,7 +19,8 @@ use jni::{ errors::{Error, Result as JniResult}, - objects::{JClass, JObject, JString, JValueGen, JValueOwned}, + objects::{JClass, JMethodID, JObject, JString, JThrowable, JValueGen, JValueOwned}, + signature::ReturnType, AttachGuard, JNIEnv, }; use once_cell::sync::OnceCell; @@ -58,29 +59,52 @@ macro_rules! jni_new_string { /// jname and value are the arguments. macro_rules! jni_call { ($env:expr, $clsname:ident($obj:expr).$method:ident($($args:expr),* $(,)?) -> $ret:ty) => {{ - $crate::jvm_bridge::jni_map_error!( - $env, - $env.call_method_unchecked( - $obj, - paste::paste! {$crate::jvm_bridge::JVMClasses::get().[<$clsname>].[]}, - paste::paste! {$crate::jvm_bridge::JVMClasses::get().[<$clsname>].[]}.clone(), - $crate::jvm_bridge::jvalues!($($args,)*) - ) - ).and_then(|result| $crate::jvm_bridge::jni_map_error!($env, <$ret>::try_from(result))) + let method_id = paste::paste! { + $crate::jvm_bridge::JVMClasses::get().[<$clsname>].[] + }; + let ret_type = paste::paste! { + $crate::jvm_bridge::JVMClasses::get().[<$clsname>].[] + }.clone(); + let args = $crate::jvm_bridge::jvalues!($($args,)*); + + // Call the JVM method and obtain the returned value + let ret = $env.call_method_unchecked($obj, method_id, ret_type, args); + + // Check if JVM has thrown any exception, and handle it if so. + let result = if let Some(exception) = $crate::jvm_bridge::check_exception($env).unwrap() { + Err(exception.into()) + } else { + $crate::jvm_bridge::jni_map_error!($env, ret) + }; + + result.and_then(|result| $crate::jvm_bridge::jni_map_error!($env, <$ret>::try_from(result))) }} } macro_rules! jni_static_call { ($env:expr, $clsname:ident.$method:ident($($args:expr),* $(,)?) -> $ret:ty) => {{ - $crate::jvm_bridge::jni_map_error!( - $env, - $env.call_static_method_unchecked( - &paste::paste! {$crate::jvm_bridge::JVMClasses::get().[<$clsname>].[]}, - paste::paste! {$crate::jvm_bridge::JVMClasses::get().[<$clsname>].[]}, - paste::paste! {$crate::jvm_bridge::JVMClasses::get().[<$clsname>].[]}.clone(), - $crate::jvm_bridge::jvalues!($($args,)*) - ) - ).and_then(|result| $crate::jvm_bridge::jni_map_error!($env, <$ret>::try_from(result))) + let clazz = &paste::paste! { + $crate::jvm_bridge::JVMClasses::get().[<$clsname>].[] + }; + let method_id = paste::paste! { + $crate::jvm_bridge::JVMClasses::get().[<$clsname>].[] + }; + let ret_type = paste::paste! { + $crate::jvm_bridge::JVMClasses::get().[<$clsname>].[] + }.clone(); + let args = $crate::jvm_bridge::jvalues!($($args,)*); + + // Call the JVM static method and obtain the returned value + let ret = $env.call_static_method_unchecked(clazz, method_id, ret_type, args); + + // Check if JVM has thrown any exception, and handle it if so. + let result = if let Some(exception) = $crate::jvm_bridge::check_exception($env).unwrap() { + Err(exception.into()) + } else { + $crate::jvm_bridge::jni_map_error!($env, ret) + }; + + result.and_then(|result| $crate::jvm_bridge::jni_map_error!($env, <$ret>::try_from(result))) }} } @@ -167,11 +191,21 @@ pub fn get_global_jclass(env: &mut JNIEnv, cls: &str) -> JniResult { + /// Cached method ID for "java.lang.Object#getClass" + pub object_get_class_method: JMethodID, + /// Cached method ID for "java.lang.Class#getName" + pub class_get_name_method: JMethodID, + /// Cached method ID for "java.lang.Throwable#getMessage" + pub throwable_get_message_method: JMethodID, + /// The CometMetricNode class. Used for updating the metrics. pub comet_metric_node: CometMetricNode<'a>, /// The static CometExec class. Used for getting the subquery result. @@ -192,7 +226,25 @@ impl JVMClasses<'_> { // `JNIEnv` except for creating the global references of the classes. let env = unsafe { std::mem::transmute::<_, &'static mut JNIEnv>(env) }; + let clazz = env.find_class("java/lang/Object").unwrap(); + let object_get_class_method = env + .get_method_id(clazz, "getClass", "()Ljava/lang/Class;") + .unwrap(); + + let clazz = env.find_class("java/lang/Class").unwrap(); + let class_get_name_method = env + .get_method_id(clazz, "getName", "()Ljava/lang/String;") + .unwrap(); + + let clazz = env.find_class("java/lang/Throwable").unwrap(); + let throwable_get_message_method = env + .get_method_id(clazz, "getMessage", "()Ljava/lang/String;") + .unwrap(); + JVMClasses { + object_get_class_method, + class_get_name_method, + throwable_get_message_method, comet_metric_node: CometMetricNode::new(env).unwrap(), comet_exec: CometExec::new(env).unwrap(), } @@ -211,3 +263,68 @@ impl JVMClasses<'_> { } } } + +pub(crate) fn check_exception(env: &mut JNIEnv) -> CometResult> { + let result = if env.exception_check()? { + let exception = env.exception_occurred()?; + env.exception_clear()?; + let exception_err = convert_exception(env, &exception)?; + Some(exception_err) + } else { + None + }; + + Ok(result) +} + +/// Given a `JThrowable` which is thrown from calling a Java method on the native side, +/// this converts it into a `CometError::JavaException` with the exception class name +/// and exception message. This error can then be populated to the JVM side to let +/// users know the cause of the native side error. +pub(crate) fn convert_exception( + env: &mut JNIEnv, + throwable: &JThrowable, +) -> CometResult { + unsafe { + let cache = JVMClasses::get(); + + // get the class name of the exception by: + // 1. get the `Class` object of the input `throwable` via `Object#getClass` method + // 2. get the exception class name via calling `Class#getName` on the above object + let class_obj = env + .call_method_unchecked( + throwable, + cache.object_get_class_method, + ReturnType::Object, + &[], + )? + .l()?; + let exception_class_name = env + .call_method_unchecked( + class_obj, + cache.class_get_name_method, + ReturnType::Object, + &[], + )? + .l()? + .into(); + let exception_class_name_str = env.get_string(&exception_class_name)?.into(); + + // get the exception message via calling `Throwable#getMessage` on the throwable object + let message = env + .call_method_unchecked( + throwable, + cache.throwable_get_message_method, + ReturnType::Object, + &[], + )? + .l()? + .into(); + let message_str = env.get_string(&message)?.into(); + + Ok(CometError::JavaException { + class: exception_class_name_str, + msg: message_str, + }) + } +}