diff --git a/core/src/execution/jni_api.rs b/core/src/execution/jni_api.rs
index d83e2a845..a8db9e153 100644
--- a/core/src/execution/jni_api.rs
+++ b/core/src/execution/jni_api.rs
@@ -103,7 +103,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
     iterators: jobjectArray,
     serialized_query: jbyteArray,
     metrics_node: JObject,
-    task_memory_manager_obj: JObject,
+    comet_task_memory_manager_obj: JObject,
 ) -> jlong {
     try_unwrap_or_throw(&e, |mut env| {
         // Init JVM classes
@@ -148,7 +148,8 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
             let input_source = Arc::new(jni_new_global_ref!(env, input_source)?);
             input_sources.push(input_source);
         }
-        let task_memory_manager = Arc::new(jni_new_global_ref!(env, task_memory_manager_obj)?);
+        let task_memory_manager =
+            Arc::new(jni_new_global_ref!(env, comet_task_memory_manager_obj)?);
 
         // We need to keep the session context alive. Some session state like temporary
         // dictionaries are stored in session context. If it is dropped, the temporary
@@ -177,7 +178,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
 /// Parse Comet configs and configure DataFusion session context.
 fn prepare_datafusion_session_context(
     conf: &HashMap<String, String>,
-    task_memory_manager: Arc<GlobalRef>,
+    comet_task_memory_manager: Arc<GlobalRef>,
 ) -> CometResult<SessionContext> {
     // Get the batch size from Comet JVM side
     let batch_size = conf
@@ -189,16 +190,17 @@ fn prepare_datafusion_session_context(
 
     let mut rt_config = RuntimeConfig::new().with_disk_manager(DiskManagerConfig::NewOs);
 
+    // Check if we are using unified memory manager integrated with Spark. Default to false if not
+    // set.
     let use_unified_memory_manager = conf
         .get("use_unified_memory_manager")
-        .ok_or(CometError::Internal(
-            "Config 'use_unified_memory_manager' is not specified from Comet JVM side".to_string(),
-        ))?
+        .map(String::as_str)
+        .unwrap_or("false")
         .parse::<bool>()?;
 
     if use_unified_memory_manager {
         // Set Comet memory pool for native
-        let memory_pool = CometMemoryPool::new(task_memory_manager);
+        let memory_pool = CometMemoryPool::new(comet_task_memory_manager);
         rt_config = rt_config.with_memory_pool(Arc::new(memory_pool));
     } else {
         // Use the memory pool from DF
diff --git a/spark/src/main/java/org/apache/spark/CometTaskMemoryManager.java b/spark/src/main/java/org/apache/spark/CometTaskMemoryManager.java
index 120434c04..1933e4857 100644
--- a/spark/src/main/java/org/apache/spark/CometTaskMemoryManager.java
+++ b/spark/src/main/java/org/apache/spark/CometTaskMemoryManager.java
@@ -28,10 +28,14 @@
  * memory manager. This assumes Spark's off-heap memory mode is enabled.
  */
 public class CometTaskMemoryManager {
+  /** The id uniquely identifies the native plan this memory manager is associated to */
+  private final long id;
+
   private final TaskMemoryManager internal;
   private final NativeMemoryConsumer nativeMemoryConsumer;
 
-  public CometTaskMemoryManager() {
+  public CometTaskMemoryManager(long id) {
+    this.id = id;
     this.internal = TaskContext$.MODULE$.get().taskMemoryManager();
     this.nativeMemoryConsumer = new NativeMemoryConsumer();
   }
@@ -62,5 +66,10 @@ public long spill(long size, MemoryConsumer trigger) throws IOException {
       // No spilling
       return 0;
     }
+
+    @Override
+    public String toString() {
+      return String.format("NativeMemoryConsumer(id=%d)", id);
+    }
   }
 }
diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
index 2d2a9976c..b3604c9e0 100644
--- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
+++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
@@ -60,7 +60,7 @@ class CometExecIterator(
       cometBatchIterators,
       protobufQueryPlan,
       nativeMetrics,
-      new CometTaskMemoryManager)
+      new CometTaskMemoryManager(id))
   }
 
   private var nextBatch: Option[ColumnarBatch] = None