From 28946b06c92776d8ee51cfcb7b8459701d0a3350 Mon Sep 17 00:00:00 2001 From: weimingdiit Date: Wed, 8 Apr 2026 14:22:19 +0800 Subject: [PATCH] [AURON #2183] Implement native support for ORC InsertIntoHiveTable writes Signed-off-by: weimingdiit --- .../auron-jni-bridge/src/jni_bridge.rs | 41 ++ native-engine/auron-planner/proto/auron.proto | 14 + native-engine/auron-planner/src/planner.rs | 14 + native-engine/auron/src/rt.rs | 2 + native-engine/datafusion-ext-plans/src/lib.rs | 1 + .../datafusion-ext-plans/src/orc_sink_exec.rs | 561 ++++++++++++++++++ .../apache/spark/sql/auron/ShimsImpl.scala | 18 + .../NativeOrcInsertIntoHiveTableExec.scala | 344 +++++++++++ .../auron/plan/NativeOrcSinkExec.scala | 41 ++ .../apache/auron/BaseAuronHiveSQLSuite.scala | 28 + .../apache/auron/exec/AuronExecSuite.scala | 1 + .../auron/exec/AuronHiveExecSuite.scala | 251 ++++++++ .../SparkAuronConfiguration.java | 12 + .../auron/plan/NativeOrcSinkUtils.java | 30 + .../spark/sql/auron/AuronConverters.scala | 82 ++- .../org/apache/spark/sql/auron/Shims.scala | 11 + .../auron/arrowio/ArrowFFIExporter.scala | 79 ++- .../auron/plan/ConvertToNativeBase.scala | 14 +- .../NativeOrcInsertIntoHiveTableBase.scala | 196 ++++++ .../auron/plan/NativeOrcSinkBase.scala | 135 +++++ 20 files changed, 1855 insertions(+), 20 deletions(-) create mode 100644 native-engine/datafusion-ext-plans/src/orc_sink_exec.rs create mode 100644 spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeOrcInsertIntoHiveTableExec.scala create mode 100644 spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeOrcSinkExec.scala create mode 100644 spark-extension-shims-spark/src/test/scala/org/apache/auron/BaseAuronHiveSQLSuite.scala create mode 100644 spark-extension-shims-spark/src/test/scala/org/apache/auron/exec/AuronHiveExecSuite.scala create mode 100644 spark-extension/src/main/java/org/apache/spark/sql/execution/auron/plan/NativeOrcSinkUtils.java create mode 100644 spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeOrcInsertIntoHiveTableBase.scala create mode 100644 spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeOrcSinkBase.scala diff --git a/native-engine/auron-jni-bridge/src/jni_bridge.rs b/native-engine/auron-jni-bridge/src/jni_bridge.rs index 85b7598d3..a3c0257f9 100644 --- a/native-engine/auron-jni-bridge/src/jni_bridge.rs +++ b/native-engine/auron-jni-bridge/src/jni_bridge.rs @@ -448,6 +448,7 @@ pub struct JavaClasses<'a> { pub cSparkUDAFMemTracker: SparkUDAFMemTracker<'a>, pub cAuronRssPartitionWriterBase: AuronRssPartitionWriterBase<'a>, pub cAuronOnHeapSpillManager: AuronOnHeapSpillManager<'a>, + pub cAuronNativeOrcSinkUtils: AuronNativeOrcSinkUtils<'a>, pub cAuronNativeParquetSinkUtils: AuronNativeParquetSinkUtils<'a>, pub cAuronBlockObject: AuronBlockObject<'a>, pub cAuronJsonFallbackWrapper: AuronJsonFallbackWrapper<'a>, @@ -504,6 +505,7 @@ impl JavaClasses<'static> { c_spark_udaf_mem_tracker, c_auron_rss_partition_writer_base, c_auron_on_heap_spill_manager, + c_auron_native_orc_sink_utils, c_auron_native_parquet_sink_utils, c_auron_block_object, c_auron_json_fallback_wrapper, @@ -517,6 +519,7 @@ impl JavaClasses<'static> { SparkUDAFMemTracker::new(env)?, AuronRssPartitionWriterBase::new(env)?, AuronOnHeapSpillManager::new(env)?, + AuronNativeOrcSinkUtils::new(env)?, AuronNativeParquetSinkUtils::new(env)?, AuronBlockObject::new(env)?, AuronJsonFallbackWrapper::new(env)?, @@ -530,6 +533,7 @@ impl JavaClasses<'static> { SparkUDAFMemTracker::default(), AuronRssPartitionWriterBase::default(), AuronOnHeapSpillManager::default(), + AuronNativeOrcSinkUtils::default(), AuronNativeParquetSinkUtils::default(), AuronBlockObject::default(), AuronJsonFallbackWrapper::default(), @@ -568,6 +572,7 @@ impl JavaClasses<'static> { cSparkUDAFMemTracker: c_spark_udaf_mem_tracker, cAuronRssPartitionWriterBase: c_auron_rss_partition_writer_base, cAuronOnHeapSpillManager: c_auron_on_heap_spill_manager, + cAuronNativeOrcSinkUtils: c_auron_native_orc_sink_utils, cAuronNativeParquetSinkUtils: c_auron_native_parquet_sink_utils, cAuronBlockObject: c_auron_block_object, cAuronJsonFallbackWrapper: c_auron_json_fallback_wrapper, @@ -1603,6 +1608,42 @@ impl<'a> AuronNativeParquetSinkUtils<'a> { } } +#[allow(non_snake_case)] +pub struct AuronNativeOrcSinkUtils<'a> { + pub class: JClass<'a>, + pub method_getTaskOutputPath: JStaticMethodID, + pub method_getTaskOutputPath_ret: ReturnType, + pub method_completeOutput: JStaticMethodID, + pub method_completeOutput_ret: ReturnType, +} +impl<'a> AuronNativeOrcSinkUtils<'a> { + pub const SIG_TYPE: &'static str = + "org/apache/spark/sql/execution/auron/plan/NativeOrcSinkUtils"; + + pub fn new(env: &JNIEnv<'a>) -> JniResult> { + let class = get_global_jclass(env, Self::SIG_TYPE)?; + Ok(AuronNativeOrcSinkUtils { + class, + method_getTaskOutputPath: env.get_static_method_id( + class, + "getTaskOutputPath", + "()Ljava/lang/String;", + )?, + method_getTaskOutputPath_ret: ReturnType::Object, + method_completeOutput: env.get_static_method_id( + class, + "completeOutput", + "(Ljava/lang/String;JJ)V", + )?, + method_completeOutput_ret: ReturnType::Primitive(Primitive::Void), + }) + } + + fn default() -> Self { + unsafe { std::mem::zeroed() } + } +} + #[allow(non_snake_case)] pub struct AuronBlockObject<'a> { pub class: JClass<'a>, diff --git a/native-engine/auron-planner/proto/auron.proto b/native-engine/auron-planner/proto/auron.proto index b0618b971..f896daad4 100644 --- a/native-engine/auron-planner/proto/auron.proto +++ b/native-engine/auron-planner/proto/auron.proto @@ -52,6 +52,7 @@ message PhysicalPlanNode { ParquetSinkExecNode parquet_sink = 24; OrcScanExecNode orc_scan = 25; KafkaScanExecNode kafka_scan = 26; + OrcSinkExecNode orc_sink = 27; } } @@ -622,6 +623,19 @@ message ParquetProp { string value = 2; } +message OrcSinkExecNode { + PhysicalPlanNode input = 1; + string fs_resource_id = 2; + int32 num_dyn_parts = 3; + Schema schema = 4; + repeated OrcProp prop = 5; +} + +message OrcProp { + string key = 1; + string value = 2; +} + message IpcWriterExecNode { PhysicalPlanNode input = 1; string ipc_consumer_resource_id = 2; diff --git a/native-engine/auron-planner/src/planner.rs b/native-engine/auron-planner/src/planner.rs index 84a625734..2591ae07f 100644 --- a/native-engine/auron-planner/src/planner.rs +++ b/native-engine/auron-planner/src/planner.rs @@ -79,6 +79,7 @@ use datafusion_ext_plans::{ ipc_writer_exec::IpcWriterExec, limit_exec::LimitExec, orc_exec::OrcExec, + orc_sink_exec::OrcSinkExec, parquet_exec::ParquetExec, parquet_sink_exec::ParquetSinkExec, project_exec::ProjectExec, @@ -802,6 +803,19 @@ impl PhysicalPlanner { props, ))) } + PhysicalPlanType::OrcSink(orc_sink) => { + let mut props: Vec<(String, String)> = vec![]; + for prop in &orc_sink.prop { + props.push((prop.key.clone(), prop.value.clone())); + } + Ok(Arc::new(OrcSinkExec::new( + convert_box_required!(self, orc_sink.input)?, + orc_sink.fs_resource_id.clone(), + orc_sink.num_dyn_parts as usize, + Arc::new(convert_required!(orc_sink.schema)?), + props, + ))) + } PhysicalPlanType::KafkaScan(kafka_scan) => { let schema = Arc::new(convert_required!(kafka_scan.schema)?); if !kafka_scan.mock_data_json_array.is_empty() { diff --git a/native-engine/auron/src/rt.rs b/native-engine/auron/src/rt.rs index b3d4adddd..b03ccdeb3 100644 --- a/native-engine/auron/src/rt.rs +++ b/native-engine/auron/src/rt.rs @@ -46,6 +46,7 @@ use datafusion_ext_commons::{df_execution_err, downcast_any}; use datafusion_ext_plans::{ common::execution_context::{ExecutionContext, cancel_all_tasks}, ipc_writer_exec::IpcWriterExec, + orc_sink_exec::OrcSinkExec, parquet_sink_exec::ParquetSinkExec, shuffle_writer_exec::ShuffleWriterExec, }; @@ -156,6 +157,7 @@ impl NativeExecutionRuntime { // coalesce output stream if necessary if downcast_any!(execution_plan_cloned, EmptyExec).is_err() + && downcast_any!(execution_plan_cloned, OrcSinkExec).is_err() && downcast_any!(execution_plan_cloned, ParquetSinkExec).is_err() && downcast_any!(execution_plan_cloned, IpcWriterExec).is_err() && downcast_any!(execution_plan_cloned, ShuffleWriterExec).is_err() diff --git a/native-engine/datafusion-ext-plans/src/lib.rs b/native-engine/datafusion-ext-plans/src/lib.rs index c6339e146..4d62d2be2 100644 --- a/native-engine/datafusion-ext-plans/src/lib.rs +++ b/native-engine/datafusion-ext-plans/src/lib.rs @@ -47,6 +47,7 @@ pub mod ipc_writer_exec; pub mod joins; pub mod limit_exec; pub mod orc_exec; +pub mod orc_sink_exec; pub mod parquet_exec; pub mod parquet_sink_exec; pub mod project_exec; diff --git a/native-engine/datafusion-ext-plans/src/orc_sink_exec.rs b/native-engine/datafusion-ext-plans/src/orc_sink_exec.rs new file mode 100644 index 000000000..37c723a27 --- /dev/null +++ b/native-engine/datafusion-ext-plans/src/orc_sink_exec.rs @@ -0,0 +1,561 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{ + any::Any, + fmt::Formatter, + io::Write, + sync::{Arc, mpsc}, +}; + +use arrow::{ + datatypes::SchemaRef, + record_batch::{RecordBatch, RecordBatchOptions}, +}; +use auron_jni_bridge::{jni_call_static, jni_get_string, jni_new_global_ref, jni_new_string}; +use datafusion::{ + common::{Result, ScalarValue, Statistics}, + execution::context::TaskContext, + physical_expr::EquivalenceProperties, + physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties, + SendableRecordBatchStream, + execution_plan::{Boundedness, EmissionType}, + metrics::{Count, ExecutionPlanMetricsSet, MetricsSet, Time}, + }, +}; +use datafusion_ext_commons::{ + arrow::{array_size::BatchSize, cast::cast}, + df_execution_err, + hadoop_fs::{FsDataOutputWrapper, FsProvider}, +}; +use futures::StreamExt; +use once_cell::sync::OnceCell; +use orc_rust::ArrowWriterBuilder; +use tokio::sync::oneshot; + +use crate::common::execution_context::ExecutionContext; + +#[derive(Debug)] +pub struct OrcSinkExec { + fs_resource_id: String, + input: Arc, + num_dyn_parts: usize, + schema: SchemaRef, + props: Vec<(String, String)>, + metrics: ExecutionPlanMetricsSet, + plan_props: OnceCell, +} + +impl OrcSinkExec { + pub fn new( + input: Arc, + fs_resource_id: String, + num_dyn_parts: usize, + schema: SchemaRef, + props: Vec<(String, String)>, + ) -> Self { + Self { + input, + fs_resource_id, + num_dyn_parts, + schema, + props, + metrics: ExecutionPlanMetricsSet::new(), + plan_props: OnceCell::new(), + } + } +} + +impl DisplayAs for OrcSinkExec { + fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + write!(f, "OrcSink") + } +} + +impl ExecutionPlan for OrcSinkExec { + fn name(&self) -> &str { + "OrcSinkExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.input.schema() + } + + fn properties(&self) -> &PlanProperties { + self.plan_props.get_or_init(|| { + PlanProperties::new( + EquivalenceProperties::new(self.schema()), + self.input.output_partitioning().clone(), + EmissionType::Both, + Boundedness::Bounded, + ) + }) + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.input] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(Self::new( + children[0].clone(), + self.fs_resource_id.clone(), + self.num_dyn_parts, + self.schema.clone(), + self.props.clone(), + ))) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + let exec_ctx = ExecutionContext::new(context, partition, self.schema(), &self.metrics); + let elapsed_compute = exec_ctx.baseline_metrics().elapsed_compute().clone(); + let _timer = elapsed_compute.timer(); + let io_time = exec_ctx.register_timer_metric("io_time"); + + let orc_sink_context = Arc::new(OrcSinkContext::try_new( + &self.fs_resource_id, + self.num_dyn_parts, + self.schema.clone(), + &io_time, + &self.props, + )?); + + let input = exec_ctx.execute_with_input_stats(&self.input)?; + execute_orc_sink(orc_sink_context, input, exec_ctx) + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + + fn statistics(&self) -> Result { + todo!() + } +} + +struct OrcSinkContext { + fs_provider: FsProvider, + schema: SchemaRef, + num_dyn_parts: usize, + batch_size: usize, + stripe_byte_size: usize, +} + +impl OrcSinkContext { + fn try_new( + fs_resource_id: &str, + num_dyn_parts: usize, + schema: SchemaRef, + io_time: &Time, + props: &[(String, String)], + ) -> Result { + let fs_provider = { + let resource_id = jni_new_string!(&fs_resource_id)?; + let fs = jni_call_static!(JniBridge.getResource(resource_id.as_obj()) -> JObject)?; + FsProvider::new(jni_new_global_ref!(fs.as_obj())?, io_time) + }; + + let batch_size = props + .iter() + .find(|(key, _)| key == "orc.row.batch.size") + .and_then(|(_, value)| value.parse::().ok()) + .unwrap_or(1024); + let stripe_byte_size = props + .iter() + .find(|(key, _)| key == "orc.stripe.size") + .and_then(|(_, value)| value.parse::().ok()) + .unwrap_or(64 * 1024 * 1024); + + Ok(Self { + fs_provider, + schema, + num_dyn_parts, + batch_size, + stripe_byte_size, + }) + } +} + +fn execute_orc_sink( + orc_sink_context: Arc, + mut input: SendableRecordBatchStream, + exec_ctx: Arc, +) -> Result { + let bytes_written = exec_ctx.register_counter_metric("bytes_written"); + + Ok(exec_ctx + .clone() + .output_with_sender("OrcSink", move |sender| async move { + let (part_writer_tx, part_writer_rx) = mpsc::channel(); + let part_writer_handle = { + let orc_sink_context = orc_sink_context.clone(); + tokio::task::spawn_blocking(move || { + part_writer_worker_loop(orc_sink_context, part_writer_rx) + }) + }; + let mut active_part_values: Option> = None; + + macro_rules! part_writer_init { + ($batch:expr, $part_values:expr) => {{ + log::info!("starts writing partition: {:?}", $part_values); + sender.send($batch.slice(0, 1)).await; + open_part_writer(&part_writer_tx, $part_values.to_vec()).await?; + active_part_values = Some($part_values.to_vec()); + }}; + } + macro_rules! part_writer_close { + () => {{ + if active_part_values.take().is_some() { + if let Some(file_stat) = close_part_writer(&part_writer_tx).await? { + jni_call_static!( + AuronNativeOrcSinkUtils.completeOutput( + jni_new_string!(&file_stat.path)?.as_obj(), + file_stat.num_rows as i64, + file_stat.num_bytes as i64, + ) -> () + )?; + exec_ctx.baseline_metrics().output_rows().add(file_stat.num_rows); + bytes_written.add(file_stat.num_bytes); + } + } + }} + } + + while let Some(mut batch) = input.next().await.transpose()? { + let _timer = exec_ctx.baseline_metrics().elapsed_compute().timer(); + if batch.num_rows() == 0 { + continue; + } + + while batch.num_rows() > 0 { + let part_values = + get_dyn_part_values(&batch, orc_sink_context.num_dyn_parts, 0)?; + let part_writer_outdated = active_part_values.as_ref() != Some(&part_values); + + if part_writer_outdated { + part_writer_close!(); + part_writer_init!(batch, &part_values); + continue; + } + + let batch_mem_size = batch.get_batch_mem_size(); + let num_sub_batches = (batch_mem_size / 1048576).max(1); + let num_sub_batch_rows = (batch.num_rows() / num_sub_batches).max(16); + + let m = rfind_part_values(&batch, &part_values)?; + let cur_batch = batch.slice(0, m); + batch = batch.slice(m, batch.num_rows() - m); + + let cur_batch = adapt_schema(&cur_batch, &orc_sink_context.schema)?; + let mut offset = 0; + while offset < cur_batch.num_rows() { + let sub_batch_size = num_sub_batch_rows.min(cur_batch.num_rows() - offset); + let sub_batch = cur_batch.slice(offset, sub_batch_size); + offset += sub_batch_size; + + write_part_writer(&part_writer_tx, sub_batch).await?; + } + } + } + part_writer_close!(); + shutdown_part_writer(&part_writer_tx).await?; + part_writer_handle + .await + .or_else(|e| df_execution_err!("orc writer thread error: {e}"))??; + Ok(()) + })) +} + +enum PartWriterCommand { + Open { + part_values: Vec, + response: oneshot::Sender>, + }, + Write { + batch: RecordBatch, + response: oneshot::Sender>, + }, + Close { + response: oneshot::Sender>>, + }, + Shutdown { + response: oneshot::Sender>, + }, +} + +fn part_writer_worker_loop( + orc_sink_context: Arc, + command_rx: mpsc::Receiver, +) -> Result<()> { + let mut part_writer: Option = None; + while let Ok(command) = command_rx.recv() { + match command { + PartWriterCommand::Open { + part_values, + response, + } => { + let result = (|| -> Result<()> { + if part_writer.is_some() { + return df_execution_err!( + "opening orc file error: partition writer already open" + ); + } + part_writer = Some( + PartWriter::try_new(orc_sink_context.clone(), &part_values) + .or_else(|e| df_execution_err!("opening orc file error: {e}"))?, + ); + Ok(()) + })(); + let _ = response.send(result); + } + PartWriterCommand::Write { batch, response } => { + let result = match part_writer.as_mut() { + Some(writer) => writer + .write(&batch) + .or_else(|e| df_execution_err!("writing orc file error: {e}")), + None => df_execution_err!("writing orc file error: missing partition writer"), + }; + let _ = response.send(result); + } + PartWriterCommand::Close { response } => { + let result = close_current_part_writer(&mut part_writer) + .or_else(|e| df_execution_err!("closing orc file error: {e}")); + let _ = response.send(result); + } + PartWriterCommand::Shutdown { response } => { + let result = close_current_part_writer(&mut part_writer) + .or_else(|e| df_execution_err!("closing orc file error: {e}")) + .map(|_| ()); + let _ = response.send(result); + break; + } + } + } + close_current_part_writer(&mut part_writer).map(|_| ())?; + Ok(()) +} + +fn close_current_part_writer(part_writer: &mut Option) -> Result> { + part_writer.take().map(|writer| writer.close()).transpose() +} + +async fn open_part_writer( + part_writer_tx: &mpsc::Sender, + part_values: Vec, +) -> Result<()> { + let (response_tx, response_rx) = oneshot::channel(); + part_writer_tx + .send(PartWriterCommand::Open { + part_values, + response: response_tx, + }) + .or_else(|e| df_execution_err!("opening orc writer command error: {e}"))?; + response_rx + .await + .or_else(|e| df_execution_err!("opening orc writer response error: {e}"))? +} + +async fn write_part_writer( + part_writer_tx: &mpsc::Sender, + batch: RecordBatch, +) -> Result<()> { + let (response_tx, response_rx) = oneshot::channel(); + part_writer_tx + .send(PartWriterCommand::Write { + batch, + response: response_tx, + }) + .or_else(|e| df_execution_err!("writing orc writer command error: {e}"))?; + response_rx + .await + .or_else(|e| df_execution_err!("writing orc writer response error: {e}"))? +} + +async fn close_part_writer( + part_writer_tx: &mpsc::Sender, +) -> Result> { + let (response_tx, response_rx) = oneshot::channel(); + part_writer_tx + .send(PartWriterCommand::Close { + response: response_tx, + }) + .or_else(|e| df_execution_err!("closing orc writer command error: {e}"))?; + response_rx + .await + .or_else(|e| df_execution_err!("closing orc writer response error: {e}"))? +} + +async fn shutdown_part_writer(part_writer_tx: &mpsc::Sender) -> Result<()> { + let (response_tx, response_rx) = oneshot::channel(); + part_writer_tx + .send(PartWriterCommand::Shutdown { + response: response_tx, + }) + .or_else(|e| df_execution_err!("shutting down orc writer command error: {e}"))?; + response_rx + .await + .or_else(|e| df_execution_err!("shutting down orc writer response error: {e}"))? +} + +fn adapt_schema(batch: &RecordBatch, schema: &SchemaRef) -> Result { + let num_rows = batch.num_rows(); + let mut casted_cols = vec![]; + + for (col_idx, casted_field) in schema.fields().iter().enumerate() { + casted_cols.push(cast(batch.column(col_idx), casted_field.data_type())?); + } + Ok(RecordBatch::try_new_with_options( + schema.clone(), + casted_cols, + &RecordBatchOptions::new().with_row_count(Some(num_rows)), + )?) +} + +fn rfind_part_values(batch: &RecordBatch, part_values: &[ScalarValue]) -> Result { + for row_idx in (0..batch.num_rows()).rev() { + if get_dyn_part_values(batch, part_values.len(), row_idx)? == part_values { + return Ok(row_idx + 1); + } + } + Ok(0) +} + +#[derive(Debug)] +struct PartFileStat { + path: String, + num_rows: usize, + num_bytes: usize, +} + +struct PartWriter { + path: String, + _orc_sink_context: Arc, + orc_writer: orc_rust::ArrowWriter, + part_values: Vec, + rows_written: Count, + bytes_written: Count, +} + +impl PartWriter { + fn try_new(orc_sink_context: Arc, part_values: &[ScalarValue]) -> Result { + if !part_values.is_empty() { + log::info!("starts outputting dynamic partition: {part_values:?}"); + } + let part_file = jni_get_string!( + jni_call_static!(AuronNativeOrcSinkUtils.getTaskOutputPath() -> JObject)? + .as_obj() + .into() + )?; + log::info!("starts writing orc file: {part_file}"); + + let fs = orc_sink_context.fs_provider.provide(&part_file)?; + let bytes_written = Count::new(); + let rows_written = Count::new(); + let fout = Arc::into_inner(fs.create(&part_file)?).expect("Arc::into_inner"); + let data_writer = FSDataWriter::new(fout, &bytes_written); + let orc_writer = ArrowWriterBuilder::new(data_writer, orc_sink_context.schema.clone()) + .with_batch_size(orc_sink_context.batch_size) + .with_stripe_byte_size(orc_sink_context.stripe_byte_size) + .try_build() + .or_else(|e| df_execution_err!("building orc writer error: {e}"))?; + Ok(Self { + path: part_file, + _orc_sink_context: orc_sink_context, + orc_writer, + part_values: part_values.to_vec(), + rows_written, + bytes_written, + }) + } + + fn write(&mut self, batch: &RecordBatch) -> Result<()> { + self.orc_writer + .write(batch) + .or_else(|e| df_execution_err!("encoding orc batch error: {e}"))?; + self.rows_written.add(batch.num_rows()); + Ok(()) + } + + fn close(self) -> Result { + let rows_written = self.rows_written.value(); + let bytes_written = self.bytes_written.value(); + self.orc_writer + .close() + .or_else(|e| df_execution_err!("closing orc writer error: {e}"))?; + + let stat = PartFileStat { + path: self.path, + num_rows: rows_written, + num_bytes: bytes_written, + }; + log::info!("finished writing orc file: {stat:?}"); + Ok(stat) + } +} + +fn get_dyn_part_values( + batch: &RecordBatch, + num_dyn_parts: usize, + row_idx: usize, +) -> Result> { + batch + .columns() + .iter() + .skip(batch.num_columns() - num_dyn_parts) + .map(|part_col| ScalarValue::try_from_array(part_col, row_idx)) + .collect() +} + +struct FSDataWriter { + inner: FsDataOutputWrapper, + bytes_written: Count, +} + +impl FSDataWriter { + pub fn new(inner: FsDataOutputWrapper, bytes_written: &Count) -> Self { + Self { + inner, + bytes_written: bytes_written.clone(), + } + } +} + +impl Write for FSDataWriter { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + self.inner + .write_fully(&buf) + .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err))?; + self.bytes_written.add(buf.len()); + Ok(buf.len()) + } + + fn flush(&mut self) -> std::io::Result<()> { + Ok(()) + } +} diff --git a/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala b/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala index 0e8a2c8e7..db1ddb627 100644 --- a/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala +++ b/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala @@ -73,11 +73,16 @@ import org.apache.spark.sql.execution.auron.plan.NativeGlobalLimitBase import org.apache.spark.sql.execution.auron.plan.NativeGlobalLimitExec import org.apache.spark.sql.execution.auron.plan.NativeLocalLimitBase import org.apache.spark.sql.execution.auron.plan.NativeLocalLimitExec +import org.apache.spark.sql.execution.auron.plan.NativeOrcInsertIntoHiveTableBase +import org.apache.spark.sql.execution.auron.plan.NativeOrcInsertIntoHiveTableExec import org.apache.spark.sql.execution.auron.plan.NativeOrcScanExec +import org.apache.spark.sql.execution.auron.plan.NativeOrcSinkBase +import org.apache.spark.sql.execution.auron.plan.NativeOrcSinkExec import org.apache.spark.sql.execution.auron.plan.NativeParquetInsertIntoHiveTableBase import org.apache.spark.sql.execution.auron.plan.NativeParquetInsertIntoHiveTableExec import org.apache.spark.sql.execution.auron.plan.NativeParquetScanBase import org.apache.spark.sql.execution.auron.plan.NativeParquetScanExec +import org.apache.spark.sql.execution.auron.plan.NativeParquetSinkBase import org.apache.spark.sql.execution.auron.plan.NativeProjectBase import org.apache.spark.sql.execution.auron.plan.NativeRenameColumnsBase import org.apache.spark.sql.execution.auron.plan.NativeShuffleExchangeBase @@ -330,6 +335,11 @@ class ShimsImpl extends Shims with Logging { child: SparkPlan): NativeParquetInsertIntoHiveTableBase = NativeParquetInsertIntoHiveTableExec(cmd, child) + override def createNativeOrcInsertIntoHiveTableExec( + cmd: InsertIntoHiveTable, + child: SparkPlan): NativeOrcInsertIntoHiveTableBase = + NativeOrcInsertIntoHiveTableExec(cmd, child) + override def createNativeParquetScanExec( basedFileScan: FileSourceScanExec): NativeParquetScanBase = NativeParquetScanExec(basedFileScan) @@ -395,6 +405,14 @@ class ShimsImpl extends Shims with Logging { metrics: Map[String, SQLMetric]): NativeParquetSinkBase = NativeParquetSinkExec(sparkSession, table, partition, child, metrics) + override def createNativeOrcSinkExec( + sparkSession: SparkSession, + table: CatalogTable, + partition: Map[String, Option[String]], + child: SparkPlan, + metrics: Map[String, SQLMetric]): NativeOrcSinkBase = + NativeOrcSinkExec(sparkSession, table, partition, child, metrics) + override def getUnderlyingBroadcast(plan: SparkPlan): BroadcastExchangeLike = { plan match { case exec: BroadcastExchangeLike => exec diff --git a/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeOrcInsertIntoHiveTableExec.scala b/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeOrcInsertIntoHiveTableExec.scala new file mode 100644 index 000000000..e364f2434 --- /dev/null +++ b/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeOrcInsertIntoHiveTableExec.scala @@ -0,0 +1,344 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.auron.plan + +import org.apache.spark.sql.Row +import org.apache.spark.sql.auron.Shims +import org.apache.spark.sql.catalyst.catalog.CatalogTable +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.hive.execution.InsertIntoHiveTable + +import org.apache.auron.sparkver + +case class NativeOrcInsertIntoHiveTableExec( + cmd: InsertIntoHiveTable, + override val child: SparkPlan) + extends NativeOrcInsertIntoHiveTableBase(cmd, child) { + + @sparkver("3.0 / 3.1 / 3.2 / 3.3") + override protected def getInsertIntoHiveTableCommand( + table: CatalogTable, + partition: Map[String, Option[String]], + query: LogicalPlan, + overwrite: Boolean, + ifPartitionNotExists: Boolean, + outputColumnNames: Seq[String], + metrics: Map[String, SQLMetric]): InsertIntoHiveTable = { + new AuronInsertIntoHiveTable30( + table, + partition, + query, + overwrite, + ifPartitionNotExists, + outputColumnNames, + metrics) + } + + @sparkver("3.4 / 3.5") + override protected def getInsertIntoHiveTableCommand( + table: CatalogTable, + partition: Map[String, Option[String]], + query: LogicalPlan, + overwrite: Boolean, + ifPartitionNotExists: Boolean, + outputColumnNames: Seq[String], + metrics: Map[String, SQLMetric]): InsertIntoHiveTable = { + new AuronInsertIntoHiveTable34( + table, + partition, + query, + overwrite, + ifPartitionNotExists, + outputColumnNames, + metrics) + } + + @sparkver("4.0 / 4.1") + override protected def getInsertIntoHiveTableCommand( + table: CatalogTable, + partition: Map[String, Option[String]], + query: LogicalPlan, + overwrite: Boolean, + ifPartitionNotExists: Boolean, + outputColumnNames: Seq[String], + metrics: Map[String, SQLMetric]): InsertIntoHiveTable = { + new AuronInsertIntoHiveTable41( + table, + partition, + query, + overwrite, + ifPartitionNotExists, + outputColumnNames, + metrics) + } + + @sparkver("3.2 / 3.3 / 3.4 / 3.5 / 4.0 / 4.1") + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = + copy(child = newChild) + + @sparkver("3.0 / 3.1") + override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan = + copy(child = newChildren.head) + + @sparkver("3.0 / 3.1 / 3.2 / 3.3") + class AuronInsertIntoHiveTable30( + table: CatalogTable, + partition: Map[String, Option[String]], + query: LogicalPlan, + overwrite: Boolean, + ifPartitionNotExists: Boolean, + outputColumnNames: Seq[String], + outerMetrics: Map[String, SQLMetric]) + extends InsertIntoHiveTable( + table, + partition, + query, + overwrite, + ifPartitionNotExists, + outputColumnNames) { + + override lazy val metrics: Map[String, SQLMetric] = outerMetrics + + override def run( + sparkSession: org.apache.spark.sql.SparkSession, + child: SparkPlan): Seq[Row] = { + val nativeOrcSink = + Shims.get.createNativeOrcSinkExec(sparkSession, table, partition, child, metrics) + super.run(sparkSession, nativeOrcSink) + } + + @sparkver("3.2 / 3.3") + override def basicWriteJobStatsTracker(hadoopConf: org.apache.hadoop.conf.Configuration) + : org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker = { + import org.apache.spark.sql.catalyst.InternalRow + import org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker + import org.apache.spark.sql.execution.datasources.BasicWriteTaskStatsTracker + import org.apache.spark.sql.execution.datasources.WriteTaskStatsTracker + import org.apache.spark.util.SerializableConfiguration + + val serializableHadoopConf = new SerializableConfiguration(hadoopConf) + new BasicWriteJobStatsTracker(serializableHadoopConf, metrics) { + override def newTaskInstance(): WriteTaskStatsTracker = { + new BasicWriteTaskStatsTracker(serializableHadoopConf.value) { + override def newRow(filePath: String, row: InternalRow): Unit = { + if (!OrcSinkTaskContext.get.isNative) { + return super.newRow(filePath, row) + } + } + + override def closeFile(filePath: String): Unit = { + if (!OrcSinkTaskContext.get.isNative) { + return super.closeFile(filePath) + } + + val outputFileStat = OrcSinkTaskContext.get.processedOutputFiles.remove() + for (_ <- 0L until outputFileStat.numRows) { + super.newRow(filePath, null) + } + super.closeFile(filePath) + } + } + } + } + } + + @sparkver("3.1") + override def basicWriteJobStatsTracker(hadoopConf: org.apache.hadoop.conf.Configuration) + : org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker = { + import org.apache.spark.sql.catalyst.InternalRow + import org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker + import org.apache.spark.sql.execution.datasources.BasicWriteTaskStats + import org.apache.spark.sql.execution.datasources.BasicWriteTaskStatsTracker + import org.apache.spark.sql.execution.datasources.WriteTaskStats + import org.apache.spark.sql.execution.datasources.WriteTaskStatsTracker + import org.apache.spark.util.SerializableConfiguration + + import scala.collection.mutable + + val serializableHadoopConf = new SerializableConfiguration(hadoopConf) + new BasicWriteJobStatsTracker(serializableHadoopConf, metrics) { + override def newTaskInstance(): WriteTaskStatsTracker = { + new BasicWriteTaskStatsTracker(serializableHadoopConf.value) { + private[this] val partitions: mutable.ArrayBuffer[InternalRow] = + mutable.ArrayBuffer.empty + + override def newPartition(partitionValues: InternalRow): Unit = { + if (!OrcSinkTaskContext.get.isNative) { + return super.newPartition(partitionValues) + } + partitions.append(partitionValues) + } + + override def newRow(row: InternalRow): Unit = { + if (!OrcSinkTaskContext.get.isNative) { + return super.newRow(row) + } + } + + override def getFinalStats(): WriteTaskStats = { + if (!OrcSinkTaskContext.get.isNative) { + return super.getFinalStats() + } + + val outputFileStat = OrcSinkTaskContext.get.processedOutputFiles.remove() + BasicWriteTaskStats( + partitions = partitions, + numFiles = 1, + numBytes = outputFileStat.numBytes, + numRows = outputFileStat.numRows) + } + } + } + } + } + + @sparkver("3.0") + override def basicWriteJobStatsTracker(hadoopConf: org.apache.hadoop.conf.Configuration) + : org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker = { + import org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker + import org.apache.spark.sql.execution.datasources.BasicWriteTaskStats + import org.apache.spark.sql.execution.datasources.BasicWriteTaskStatsTracker + import org.apache.spark.sql.execution.datasources.WriteTaskStats + import org.apache.spark.sql.execution.datasources.WriteTaskStatsTracker + import org.apache.spark.util.SerializableConfiguration + + val serializableHadoopConf = new SerializableConfiguration(hadoopConf) + new BasicWriteJobStatsTracker(serializableHadoopConf, metrics) { + override def newTaskInstance(): WriteTaskStatsTracker = { + new BasicWriteTaskStatsTracker(serializableHadoopConf.value) { + override def newRow(row: org.apache.spark.sql.catalyst.InternalRow): Unit = { + if (!OrcSinkTaskContext.get.isNative) { + return super.newRow(row) + } + } + + override def getFinalStats(): WriteTaskStats = { + if (!OrcSinkTaskContext.get.isNative) { + return super.getFinalStats() + } + + val outputFileStat = OrcSinkTaskContext.get.processedOutputFiles.remove() + BasicWriteTaskStats( + numPartitions = 1, + numFiles = 1, + numBytes = outputFileStat.numBytes, + numRows = outputFileStat.numRows) + } + } + } + } + } + } + + @sparkver("3.4 / 3.5") + class AuronInsertIntoHiveTable34( + table: CatalogTable, + partition: Map[String, Option[String]], + query: LogicalPlan, + overwrite: Boolean, + ifPartitionNotExists: Boolean, + outputColumnNames: Seq[String], + outerMetrics: Map[String, SQLMetric]) + extends { + private val insertIntoHiveTable = InsertIntoHiveTable( + table, + partition, + query, + overwrite, + ifPartitionNotExists, + outputColumnNames) + private val initPartitionColumns = insertIntoHiveTable.partitionColumns + private val initBucketSpec = insertIntoHiveTable.bucketSpec + private val initOptions = insertIntoHiveTable.options + private val initFileFormat = insertIntoHiveTable.fileFormat + private val initHiveTmpPath = insertIntoHiveTable.hiveTmpPath + + } + with InsertIntoHiveTable( + table, + partition, + query, + overwrite, + ifPartitionNotExists, + outputColumnNames, + initPartitionColumns, + initBucketSpec, + initOptions, + initFileFormat, + initHiveTmpPath) { + + override lazy val metrics: Map[String, SQLMetric] = outerMetrics + + override def run( + sparkSession: org.apache.spark.sql.SparkSession, + child: SparkPlan): Seq[Row] = { + val nativeOrcSink = + Shims.get.createNativeOrcSinkExec(sparkSession, table, partition, child, metrics) + super.run(sparkSession, nativeOrcSink) + } + } + + @sparkver("4.0 / 4.1") + class AuronInsertIntoHiveTable41( + table: CatalogTable, + partition: Map[String, Option[String]], + query: LogicalPlan, + overwrite: Boolean, + ifPartitionNotExists: Boolean, + outputColumnNames: Seq[String], + outerMetrics: Map[String, SQLMetric]) + extends { + private val insertIntoHiveTable = InsertIntoHiveTable( + table, + partition, + query, + overwrite, + ifPartitionNotExists, + outputColumnNames) + private val initPartitionColumns = insertIntoHiveTable.partitionColumns + private val initBucketSpec = insertIntoHiveTable.bucketSpec + private val initOptions = insertIntoHiveTable.options + private val initFileFormat = insertIntoHiveTable.fileFormat + private val initHiveTmpPath = insertIntoHiveTable.hiveTmpPath + + } + with InsertIntoHiveTable( + table, + partition, + query, + overwrite, + ifPartitionNotExists, + outputColumnNames, + initPartitionColumns, + initBucketSpec, + initOptions, + initFileFormat, + initHiveTmpPath) { + + override lazy val metrics: Map[String, SQLMetric] = outerMetrics + + override def run( + sparkSession: org.apache.spark.sql.classic.SparkSession, + child: SparkPlan): Seq[Row] = { + val nativeOrcSink = + Shims.get.createNativeOrcSinkExec(sparkSession, table, partition, child, metrics) + super.run(sparkSession, nativeOrcSink) + } + } +} diff --git a/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeOrcSinkExec.scala b/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeOrcSinkExec.scala new file mode 100644 index 000000000..183e4271a --- /dev/null +++ b/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeOrcSinkExec.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.auron.plan + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.catalog.CatalogTable +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.metric.SQLMetric + +import org.apache.auron.sparkver + +case class NativeOrcSinkExec( + sparkSession: SparkSession, + table: CatalogTable, + partition: Map[String, Option[String]], + override val child: SparkPlan, + override val metrics: Map[String, SQLMetric]) + extends NativeOrcSinkBase(sparkSession, table, partition, child, metrics) { + + @sparkver("3.2 / 3.3 / 3.4 / 3.5 / 4.0 / 4.1") + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = + copy(child = newChild) + + @sparkver("3.0 / 3.1") + override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan = + copy(child = newChildren.head) +} diff --git a/spark-extension-shims-spark/src/test/scala/org/apache/auron/BaseAuronHiveSQLSuite.scala b/spark-extension-shims-spark/src/test/scala/org/apache/auron/BaseAuronHiveSQLSuite.scala new file mode 100644 index 000000000..8656e2ad3 --- /dev/null +++ b/spark-extension-shims-spark/src/test/scala/org/apache/auron/BaseAuronHiveSQLSuite.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.auron + +import org.apache.spark.SparkConf + +trait BaseAuronHiveSQLSuite extends BaseAuronSQLSuite { + override protected def sparkConf: SparkConf = + super.sparkConf + .set("spark.sql.catalogImplementation", "hive") + .set( + "spark.hadoop.javax.jdo.option.ConnectionURL", + s"jdbc:derby:;databaseName=$metastoreDir/metastore_db;create=true") +} diff --git a/spark-extension-shims-spark/src/test/scala/org/apache/auron/exec/AuronExecSuite.scala b/spark-extension-shims-spark/src/test/scala/org/apache/auron/exec/AuronExecSuite.scala index 7f62dd521..e4e766993 100644 --- a/spark-extension-shims-spark/src/test/scala/org/apache/auron/exec/AuronExecSuite.scala +++ b/spark-extension-shims-spark/src/test/scala/org/apache/auron/exec/AuronExecSuite.scala @@ -127,4 +127,5 @@ class AuronExecSuite extends AuronQueryTest with BaseAuronSQLSuite { } } } + } diff --git a/spark-extension-shims-spark/src/test/scala/org/apache/auron/exec/AuronHiveExecSuite.scala b/spark-extension-shims-spark/src/test/scala/org/apache/auron/exec/AuronHiveExecSuite.scala new file mode 100644 index 000000000..57e89f57e --- /dev/null +++ b/spark-extension-shims-spark/src/test/scala/org/apache/auron/exec/AuronHiveExecSuite.scala @@ -0,0 +1,251 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.auron.exec + +import java.io.File + +import org.apache.hadoop.hive.ql.io.orc.{OrcInputFormat, OrcOutputFormat, OrcSerde} +import org.apache.spark.sql.{AuronQueryTest, Row} +import org.apache.spark.sql.auron.AuronConverters +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.execution.SQLExecution +import org.apache.spark.sql.execution.auron.plan.{NativeOrcInsertIntoHiveTableBase, NativeSortBase} +import org.apache.spark.sql.execution.command.DataWritingCommandExec +import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat +import org.apache.spark.sql.hive.execution.InsertIntoHiveTable +import org.apache.spark.sql.types.{StringType, StructField, StructType} + +import org.apache.auron.BaseAuronHiveSQLSuite + +class AuronHiveExecSuite extends AuronQueryTest with BaseAuronHiveSQLSuite { + + private def withSqlExecutionId[T](body: => T): T = { + val sparkContext = spark.sparkContext + val previousExecutionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + sparkContext.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, System.nanoTime().toString) + try { + body + } finally { + sparkContext.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, previousExecutionId) + } + } + + private def buildOrcTable( + tableName: String, + schema: org.apache.spark.sql.types.StructType, + partitionColumnNames: Seq[String] = Nil): CatalogTable = { + val tableLocation = new File(warehouseDir, s"${spark.catalog.currentDatabase}.db/$tableName") + CatalogTable( + identifier = TableIdentifier(tableName, Some(spark.catalog.currentDatabase)), + tableType = CatalogTableType.MANAGED, + storage = CatalogStorageFormat.empty.copy( + locationUri = Some(tableLocation.toURI), + inputFormat = Some(classOf[OrcInputFormat].getName), + outputFormat = Some(classOf[OrcOutputFormat].getName), + serde = Some(classOf[OrcSerde].getName)), + schema = schema, + provider = None, + partitionColumnNames = partitionColumnNames) + } + + private def buildOrcInsertExec( + table: CatalogTable, + queryDf: org.apache.spark.sql.DataFrame, + partition: Map[String, Option[String]] = Map.empty): DataWritingCommandExec = { + val analyzedQuery = queryDf.queryExecution.analyzed + val partitionColumns = table.partitionSchema.fields.toSeq.map { field => + AttributeReference(field.name, field.dataType, field.nullable, field.metadata)() + } + val outputColumnNames = queryDf.columns.toSeq + val ctor = classOf[InsertIntoHiveTable].getConstructors + .find(c => c.getParameterCount == 11 || c.getParameterCount == 6) + .getOrElse( + throw new IllegalStateException(s"Unsupported InsertIntoHiveTable constructor count: " + + classOf[InsertIntoHiveTable].getConstructors.map(_.getParameterCount).mkString(","))) + val args: Seq[Object] = + if (ctor.getParameterCount == 11) { + Seq( + table, + partition, + analyzedQuery, + Boolean.box(false), + Boolean.box(false), + outputColumnNames, + partitionColumns, + None, + Map.empty[String, String], + new OrcFileFormat, + null) + } else { + Seq( + table, + partition, + analyzedQuery, + Boolean.box(false), + Boolean.box(false), + outputColumnNames) + } + val cmd = ctor.newInstance(args: _*).asInstanceOf[InsertIntoHiveTable] + DataWritingCommandExec(cmd, queryDf.queryExecution.sparkPlan) + } + + private def createHiveOrcTable( + tableName: String, + schema: StructType, + partitionColumnNames: Seq[String]): CatalogTable = { + val table = buildOrcTable(tableName, schema, partitionColumnNames) + spark.sharedState.externalCatalog.createTable(table, ignoreIfExists = false) + spark.sessionState.catalog.getTableMetadata( + TableIdentifier(tableName, Some(spark.catalog.currentDatabase))) + } + + test("convert ORC InsertIntoHiveTable to native ORC insert") { + withSQLConf("spark.auron.enable.data.writing" -> "true") { + withTable("src_orc_insert") { + sql(""" + |create table src_orc_insert using parquet as + |select 1 as id, 'a' as v + |union all + |select 2 as id, 'b' as v + |""".stripMargin) + + val queryDf = sql("select id, v from src_orc_insert") + val exec = buildOrcInsertExec(buildOrcTable("t_orc_native", queryDf.schema), queryDf) + val converted = AuronConverters.convertDataWritingCommandExec(exec) + + assert(converted.isInstanceOf[NativeOrcInsertIntoHiveTableBase], converted.toString) + } + } + } + + test("convert ORC InsertIntoHiveTable with dynamic partitions to native ORC insert") { + withSQLConf( + "spark.auron.enable.data.writing" -> "true", + "hive.exec.dynamic.partition" -> "true", + "hive.exec.dynamic.partition.mode" -> "nonstrict") { + withTable("src_orc_insert_part") { + sql(""" + |create table src_orc_insert_part using parquet as + |select 1 as id, 'a' as v, 'p1' as part + |union all + |select 2 as id, 'b' as v, 'p2' as part + |""".stripMargin) + + val queryDf = sql("select id, v, part from src_orc_insert_part") + val exec = buildOrcInsertExec( + buildOrcTable("t_orc_native_part", queryDf.schema, Seq("part")), + queryDf, + partition = Map("part" -> None)) + val converted = AuronConverters.convertDataWritingCommandExec(exec) + + assert(converted.isInstanceOf[NativeOrcInsertIntoHiveTableBase], converted.toString) + assert(collect(converted) { case e: NativeSortBase => e }.nonEmpty, converted.toString) + } + } + } + + test( + "convert ORC InsertIntoHiveTable with static and dynamic partitions to native ORC insert") { + withSQLConf( + "spark.auron.enable.data.writing" -> "true", + "hive.exec.dynamic.partition" -> "true") { + withTable("src_orc_insert_mixed_part") { + sql(""" + |create table src_orc_insert_mixed_part using parquet as + |select 1 as id, 'a' as v, 'p1' as part + |union all + |select 2 as id, 'b' as v, 'p2' as part + |""".stripMargin) + + val queryDf = sql("select id, v, part from src_orc_insert_mixed_part") + val mixedPartitionSchema = StructType( + queryDf.schema.fields.filterNot(_.name == "part") ++ + Seq(StructField("ds", StringType, nullable = true), queryDf.schema("part"))) + val exec = buildOrcInsertExec( + buildOrcTable("t_orc_native_mixed_part", mixedPartitionSchema, Seq("ds", "part")), + queryDf, + partition = Map("ds" -> Some("2026-04-13"), "part" -> None)) + val converted = AuronConverters.convertDataWritingCommandExec(exec) + + assert(converted.isInstanceOf[NativeOrcInsertIntoHiveTableBase], converted.toString) + assert(collect(converted) { case e: NativeSortBase => e }.nonEmpty, converted.toString) + } + } + } + + test("execute native ORC InsertIntoHiveTable with static and dynamic partitions") { + withSQLConf( + "spark.auron.enable.data.writing" -> "true", + "hive.exec.dynamic.partition" -> "true") { + withTable("src_orc_insert_exec", "t_orc_native_exec") { + sql(""" + |create table src_orc_insert_exec using parquet as + |select 1 as id, 'a' as v, 'p1' as part + |union all + |select 2 as id, 'b' as v, 'p2' as part + |""".stripMargin) + + val queryDf = sql("select id, v, part from src_orc_insert_exec") + val mixedPartitionSchema = StructType( + queryDf.schema.fields.filterNot(_.name == "part") ++ + Seq(StructField("ds", StringType, nullable = true), queryDf.schema("part"))) + val targetTable = + createHiveOrcTable("t_orc_native_exec", mixedPartitionSchema, Seq("ds", "part")) + val exec = buildOrcInsertExec( + targetTable, + queryDf, + partition = Map("ds" -> Some("2026-04-13"), "part" -> None)) + val converted = AuronConverters.convertDataWritingCommandExec(exec) + val plan = stripAQEPlan(converted) + + assert( + collect(plan) { case e: NativeOrcInsertIntoHiveTableBase => e }.size == 1, + plan.toString) + assert(collect(plan) { case e: NativeSortBase => e }.nonEmpty, plan.toString) + + withSqlExecutionId { + converted.executeCollect() + } + val actualRows = spark.read + .orc(targetTable.storage.locationUri.get.toString) + .selectExpr("id", "v", "cast(ds as string) as ds", "part") + .collect() + .sortBy(row => (row.getInt(0), row.getString(3))) + .toSeq + val expectedRows = Seq(Row(1, "a", "2026-04-13", "p1"), Row(2, "b", "2026-04-13", "p2")) + assert(actualRows == expectedRows, s"actualRows=$actualRows expectedRows=$expectedRows") + } + } + } + + test("keep unsupported ORC InsertIntoHiveTable schema on Spark path") { + withTable("src_orc_insert_map") { + sql(""" + |create table src_orc_insert_map using parquet as + |select map('a', 1, 'b', 2) as m + |""".stripMargin) + + val queryDf = sql("select m from src_orc_insert_map") + val exec = buildOrcInsertExec(buildOrcTable("t_orc_native_map", queryDf.schema), queryDf) + val converted = AuronConverters.convertSparkPlan(exec) + + assert(!converted.isInstanceOf[NativeOrcInsertIntoHiveTableBase], converted.toString) + } + } +} diff --git a/spark-extension/src/main/java/org/apache/auron/spark/configuration/SparkAuronConfiguration.java b/spark-extension/src/main/java/org/apache/auron/spark/configuration/SparkAuronConfiguration.java index bc46ed312..8d3728e17 100644 --- a/spark-extension/src/main/java/org/apache/auron/spark/configuration/SparkAuronConfiguration.java +++ b/spark-extension/src/main/java/org/apache/auron/spark/configuration/SparkAuronConfiguration.java @@ -447,6 +447,18 @@ public class SparkAuronConfiguration extends AuronConfiguration { .withDescription("Enable DataWritingExec operation conversion to native Auron implementations.") .withDefaultValue(false); + public static final ConfigOption ENABLE_DATA_WRITING_PARQUET = new SQLConfOption<>(Boolean.class) + .withKey("auron.enable.data.writing.parquet") + .withCategory("Operator Supports") + .withDescription("Enable Parquet DataWritingExec operation conversion to native Auron implementations.") + .withDefaultValue(true); + + public static final ConfigOption ENABLE_DATA_WRITING_ORC = new SQLConfOption<>(Boolean.class) + .withKey("auron.enable.data.writing.orc") + .withCategory("Operator Supports") + .withDescription("Enable ORC DataWritingExec operation conversion to native Auron implementations.") + .withDefaultValue(true); + public static final ConfigOption ENABLE_SCAN_PARQUET = new SQLConfOption<>(Boolean.class) .withKey("auron.enable.scan.parquet") .withCategory("Data Sources") diff --git a/spark-extension/src/main/java/org/apache/spark/sql/execution/auron/plan/NativeOrcSinkUtils.java b/spark-extension/src/main/java/org/apache/spark/sql/execution/auron/plan/NativeOrcSinkUtils.java new file mode 100644 index 000000000..68a8f7497 --- /dev/null +++ b/spark-extension/src/main/java/org/apache/spark/sql/execution/auron/plan/NativeOrcSinkUtils.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.auron.plan; + +// for jni_bridge usage +@SuppressWarnings("unused") +public class NativeOrcSinkUtils { + public static String getTaskOutputPath() throws InterruptedException { + return OrcSinkTaskContext$.MODULE$.get().processingOutputFiles().take(); + } + + public static void completeOutput(String path, long numRows, long numBytes) { + OutputFileStat stat = new OutputFileStat(path, numRows, numBytes); + OrcSinkTaskContext$.MODULE$.get().processedOutputFiles().push(stat); + } +} diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/auron/AuronConverters.scala b/spark-extension/src/main/scala/org/apache/spark/sql/auron/AuronConverters.scala index cc12a176a..6d1367df3 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/auron/AuronConverters.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/auron/AuronConverters.scala @@ -23,6 +23,7 @@ import scala.collection.mutable import scala.jdk.CollectionConverters._ import org.apache.commons.lang3.reflect.MethodUtils +import org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat import org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat import org.apache.spark.Partition import org.apache.spark.broadcast.Broadcast @@ -57,6 +58,7 @@ import org.apache.spark.sql.execution.aggregate.SortAggregateExec import org.apache.spark.sql.execution.auron.plan.ConvertToNativeBase import org.apache.spark.sql.execution.auron.plan.NativeAggBase import org.apache.spark.sql.execution.auron.plan.NativeBroadcastExchangeBase +import org.apache.spark.sql.execution.auron.plan.NativeOrcInsertIntoHiveTableBase import org.apache.spark.sql.execution.auron.plan.NativeOrcScanBase import org.apache.spark.sql.execution.auron.plan.NativeParquetScanBase import org.apache.spark.sql.execution.auron.plan.NativeSortBase @@ -70,7 +72,17 @@ import org.apache.spark.sql.execution.window.WindowExec import org.apache.spark.sql.hive.execution.InsertIntoHiveTable import org.apache.spark.sql.hive.execution.auron.plan.NativeHiveTableScanBase import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.BinaryType +import org.apache.spark.sql.types.BooleanType +import org.apache.spark.sql.types.ByteType +import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.DoubleType +import org.apache.spark.sql.types.FloatType +import org.apache.spark.sql.types.IntegerType import org.apache.spark.sql.types.LongType +import org.apache.spark.sql.types.ShortType +import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.types.StructType import org.apache.auron.configuration.AuronConfiguration import org.apache.auron.jni.AuronAdaptor @@ -101,6 +113,9 @@ object AuronConverters extends Logging { def enableGenerate: Boolean = SparkAuronConfiguration.ENABLE_GENERATE.get() def enableLocalTableScan: Boolean = SparkAuronConfiguration.ENABLE_LOCAL_TABLE_SCAN.get() def enableDataWriting: Boolean = SparkAuronConfiguration.ENABLE_DATA_WRITING.get() + def enableDataWritingParquet: Boolean = + SparkAuronConfiguration.ENABLE_DATA_WRITING_PARQUET.get() + def enableDataWritingOrc: Boolean = SparkAuronConfiguration.ENABLE_DATA_WRITING_ORC.get() def enableScanParquet: Boolean = SparkAuronConfiguration.ENABLE_SCAN_PARQUET.get() def enableScanParquetTimestamp: Boolean = SparkAuronConfiguration.ENABLE_SCAN_PARQUET_TIMESTAMP.get() @@ -132,6 +147,16 @@ object AuronConverters extends Logging { supportedShuffleManagers.exists(name.contains) } + def isOrcWriteTypeSupported(dataType: DataType): Boolean = dataType match { + case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | + StringType | BinaryType => + true + case _ => false + } + + def isOrcWriteSchemaSupported(schema: StructType): Boolean = + schema.forall(field => isOrcWriteTypeSupported(field.dataType)) + def convertSparkPlanRecursively(exec: SparkPlan): SparkPlan = { // convert var danglingConverted: Seq[SparkPlan] = Nil @@ -1012,25 +1037,50 @@ object AuronConverters extends Logging { def convertDataWritingCommandExec(exec: DataWritingCommandExec): SparkPlan = { logDebugPlanConversion(exec) + + def isParquetInsertIntoHiveTable(cmd: InsertIntoHiveTable): Boolean = + cmd.table.storage.outputFormat.contains(classOf[MapredParquetOutputFormat].getName) + + def isOrcInsertIntoHiveTable(cmd: InsertIntoHiveTable): Boolean = + cmd.table.storage.outputFormat.contains(classOf[OrcOutputFormat].getName) + + def failWhenDataWritingDisabled(enabled: Boolean, confKey: String): Unit = { + if (!enabled) { + throw new NotImplementedError(s"Conversion disabled: $confKey=false.") + } + } + + def sortInsertChild(cmd: InsertIntoHiveTable, child: SparkPlan): SparkPlan = { + var sortedChild = convertToNative(child) + val numDynParts = cmd.partition.count(_._2.isEmpty) + val requiredOrdering = + child.output.slice(child.output.length - numDynParts, child.output.length) + if (requiredOrdering.nonEmpty && child.outputOrdering.map(_.child) != requiredOrdering) { + val rowNumExpr = StubExpr("RowNum", LongType, nullable = false) + sortedChild = Shims.get.createNativeSortExec( + requiredOrdering.map(SortOrder(_, Ascending)) ++ Seq(SortOrder(rowNumExpr, Ascending)), + global = false, + sortedChild) + } + sortedChild + } + exec match { case DataWritingCommandExec(cmd: InsertIntoHiveTable, child) - if cmd.table.storage.outputFormat.contains( - classOf[MapredParquetOutputFormat].getName) => - // add an extra SortExec to sort child with dynamic columns - // add row number to achieve stable sort - var sortedChild = convertToNative(child) - val numDynParts = cmd.partition.count(_._2.isEmpty) - val requiredOrdering = - child.output.slice(child.output.length - numDynParts, child.output.length) - if (requiredOrdering.nonEmpty && child.outputOrdering.map(_.child) != requiredOrdering) { - val rowNumExpr = StubExpr("RowNum", LongType, nullable = false) - sortedChild = Shims.get.createNativeSortExec( - requiredOrdering.map(SortOrder(_, Ascending)) ++ Seq( - SortOrder(rowNumExpr, Ascending)), - global = false, - sortedChild) + if isParquetInsertIntoHiveTable(cmd) => + failWhenDataWritingDisabled( + enableDataWritingParquet, + "spark.auron.enable.data.writing.parquet") + Shims.get.createNativeParquetInsertIntoHiveTableExec(cmd, sortInsertChild(cmd, child)) + + case DataWritingCommandExec(cmd: InsertIntoHiveTable, child) + if isOrcInsertIntoHiveTable(cmd) => + failWhenDataWritingDisabled(enableDataWritingOrc, "spark.auron.enable.data.writing.orc") + if (NativeOrcInsertIntoHiveTableBase.isSupportedWriteSchema(cmd.table, cmd.partition)) { + Shims.get.createNativeOrcInsertIntoHiveTableExec(cmd, sortInsertChild(cmd, child)) + } else { + throw new NotImplementedError("unsupported DataWritingCommandExec") } - Shims.get.createNativeParquetInsertIntoHiveTableExec(cmd, sortedChild) case _ => throw new NotImplementedError("unsupported DataWritingCommandExec") diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/auron/Shims.scala b/spark-extension/src/main/scala/org/apache/spark/sql/auron/Shims.scala index 19f98b415..c51470e86 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/auron/Shims.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/auron/Shims.scala @@ -143,6 +143,10 @@ abstract class Shims { cmd: InsertIntoHiveTable, child: SparkPlan): NativeParquetInsertIntoHiveTableBase + def createNativeOrcInsertIntoHiveTableExec( + cmd: InsertIntoHiveTable, + child: SparkPlan): NativeOrcInsertIntoHiveTableBase + def createNativeParquetScanExec(basedFileScan: FileSourceScanExec): NativeParquetScanBase def createNativeOrcScanExec(basedFileScan: FileSourceScanExec): NativeOrcScanBase @@ -196,6 +200,13 @@ abstract class Shims { child: SparkPlan, metrics: Map[String, SQLMetric]): NativeParquetSinkBase + def createNativeOrcSinkExec( + sparkSession: SparkSession, + table: CatalogTable, + partition: Map[String, Option[String]], + child: SparkPlan, + metrics: Map[String, SQLMetric]): NativeOrcSinkBase + def isNative(plan: SparkPlan): Boolean def getUnderlyingNativePlan(plan: SparkPlan): NativeSupports diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/arrowio/ArrowFFIExporter.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/arrowio/ArrowFFIExporter.scala index cc1678594..c4ccbedce 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/arrowio/ArrowFFIExporter.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/arrowio/ArrowFFIExporter.scala @@ -36,12 +36,13 @@ import org.apache.spark.sql.execution.auron.arrowio.util.ArrowUtils.CHILD_ALLOCA import org.apache.spark.sql.execution.auron.arrowio.util.ArrowUtils.ROOT_ALLOCATOR import org.apache.spark.sql.execution.auron.arrowio.util.ArrowWriter import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.auron.arrowio.AuronArrowFFIExporter import org.apache.auron.configuration.AuronConfiguration import org.apache.auron.jni.AuronAdaptor -class ArrowFFIExporter(rowIter: Iterator[InternalRow], schema: StructType) +class ArrowFFIExporter(inputIter: Iterator[Any], schema: StructType) extends AuronArrowFFIExporter with Logging { private val sparkAuronConfig: AuronConfiguration = @@ -69,6 +70,7 @@ class ArrowFFIExporter(rowIter: Iterator[InternalRow], schema: StructType) private val outputQueue: BlockingQueue[QueueState] = new ArrayBlockingQueue[QueueState](16) private val processingQueue: BlockingQueue[Unit] = new ArrayBlockingQueue[Unit](16) private var currentRoot: VectorSchemaRoot = _ + private val rowIter = new InputToRowIter(inputIter) private val outputThread = startOutputThread() def exportSchema(exportArrowSchemaPtr: Long): Unit = { @@ -149,6 +151,8 @@ class ArrowFFIExporter(rowIter: Iterator[InternalRow], schema: StructType) logDebug(s"ArrowFFIExporter-$exporterId: outputThread interrupted, exiting") outputQueue.clear() outputQueue.put(Finished(None)) + } finally { + rowIter.close() } } }) @@ -197,4 +201,77 @@ class ArrowFFIExporter(rowIter: Iterator[InternalRow], schema: StructType) logDebug(s"ArrowFFIExporter-$exporterId: close() completed") } } + + private class InputToRowIter(inputIter: Iterator[Any]) extends Iterator[InternalRow] { + private var currentBatch: ColumnarBatch = _ + private var currentBatchRowId = 0 + private var pendingRow: InternalRow = _ + + override def hasNext: Boolean = { + if (pendingRow != null) { + return true + } + + closeFinishedBatch() + if (currentBatch != null) { + return true + } + + while (inputIter.hasNext) { + inputIter.next() match { + case row: InternalRow => + pendingRow = row + return true + case batch: ColumnarBatch if batch.numRows() > 0 => + currentBatch = batch + currentBatchRowId = 0 + return true + case batch: ColumnarBatch => + batch.close() + case null => + throw new IllegalStateException( + "ArrowFFIExporter expects InternalRow or ColumnarBatch input, but got null") + case other => + throw new IllegalStateException( + s"ArrowFFIExporter expects InternalRow or ColumnarBatch input, " + + s"but got ${other.getClass.getName}") + } + } + + false + } + + override def next(): InternalRow = { + if (!hasNext) { + throw new NoSuchElementException("no more rows") + } + + if (pendingRow != null) { + val row = pendingRow + pendingRow = null + row + } else { + val row = currentBatch.getRow(currentBatchRowId) + currentBatchRowId += 1 + row + } + } + + def close(): Unit = { + if (currentBatch != null) { + currentBatch.close() + currentBatch = null + } + currentBatchRowId = 0 + pendingRow = null + } + + private def closeFinishedBatch(): Unit = { + if (currentBatch != null && currentBatchRowId >= currentBatch.numRows()) { + currentBatch.close() + currentBatch = null + currentBatchRowId = 0 + } + } + } } diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/ConvertToNativeBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/ConvertToNativeBase.scala index 0d48ad275..c37e69f7d 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/ConvertToNativeBase.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/ConvertToNativeBase.scala @@ -21,6 +21,7 @@ import java.util.UUID import scala.collection.immutable.SortedMap import org.apache.spark.OneToOneDependency +import org.apache.spark.rdd.RDD import org.apache.spark.sql.auron.NativeConverters import org.apache.spark.sql.auron.NativeHelper import org.apache.spark.sql.auron.NativeRDD @@ -61,7 +62,11 @@ abstract class ConvertToNativeBase(override val child: SparkPlan) val nativeSchema: Schema = NativeConverters.convertSchema(renamedSchema) override def doExecuteNative(): NativeRDD = { - val inputRDD = child.execute() + val inputRDD: RDD[Any] = if (child.supportsColumnar) { + child.executeColumnar().asInstanceOf[RDD[Any]] + } else { + child.execute().asInstanceOf[RDD[Any]] + } val numInputPartitions = inputRDD.getNumPartitions val nativeMetrics = SparkMetricNode(metrics, Nil) @@ -73,9 +78,12 @@ abstract class ConvertToNativeBase(override val child: SparkPlan) rddDependencies = new OneToOneDependency(inputRDD) :: Nil, Shims.get.getRDDShuffleReadFull(inputRDD), (partition, context) => { - val inputRowIter = inputRDD.compute(partition, context) + // Columnar Spark plans can cast-fail internally if we force them through execute(). + // Keep columnar children on executeColumnar() and let ArrowFFIExporter normalize + // ColumnarBatch / InternalRow input into Arrow. + val inputIter = inputRDD.compute(partition, context).asInstanceOf[Iterator[Any]] val resourceId = s"ConvertToNativeExec:${UUID.randomUUID().toString}" - JniBridge.putResource(resourceId, new ArrowFFIExporter(inputRowIter, renamedSchema)) + JniBridge.putResource(resourceId, new ArrowFFIExporter(inputIter, renamedSchema)) PhysicalPlanNode .newBuilder() diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeOrcInsertIntoHiveTableBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeOrcInsertIntoHiveTableBase.scala new file mode 100644 index 000000000..fd763365d --- /dev/null +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeOrcInsertIntoHiveTableBase.scala @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.auron.plan + +import java.util +import java.util.Locale +import java.util.Properties +import java.util.concurrent.LinkedBlockingDeque + +import scala.collection.immutable.SortedMap +import scala.collection.mutable + +import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.ql.exec.FileSinkOperator +import org.apache.hadoop.hive.ql.io.HiveOutputFormat +import org.apache.hadoop.hive.ql.io.orc.OrcSerde +import org.apache.hadoop.io.NullWritable +import org.apache.hadoop.io.Writable +import org.apache.hadoop.mapred.FileOutputFormat +import org.apache.hadoop.mapred.JobConf +import org.apache.hadoop.mapred.RecordWriter +import org.apache.hadoop.util.Progressable +import org.apache.spark.TaskContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.auron.AuronConverters +import org.apache.spark.sql.auron.NativeHelper +import org.apache.spark.sql.auron.NativeRDD +import org.apache.spark.sql.auron.NativeSupports +import org.apache.spark.sql.auron.Shims +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.catalog.CatalogTable +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.SortOrder +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.UnaryExecNode +import org.apache.spark.sql.execution.command.DataWritingCommandExec +import org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.hive.execution.InsertIntoHiveTable +import org.apache.spark.sql.types.StructType + +object NativeOrcInsertIntoHiveTableBase { + def dataSchema(table: CatalogTable, partition: Map[String, Option[String]]): StructType = + StructType(table.schema.dropRight(partition.size)) + + def isSupportedWriteSchema( + table: CatalogTable, + partition: Map[String, Option[String]]): Boolean = + AuronConverters.isOrcWriteSchemaSupported(dataSchema(table, partition)) +} + +abstract class NativeOrcInsertIntoHiveTableBase( + cmd: InsertIntoHiveTable, + override val child: SparkPlan) + extends UnaryExecNode + with NativeSupports { + + override lazy val metrics: Map[String, SQLMetric] = SortedMap[String, SQLMetric]() ++ + BasicWriteJobStatsTracker.metrics ++ + Map( + NativeHelper + .getDefaultNativeMetrics(sparkContext) + .filterKeys(Set("stage_id", "output_rows", "elapsed_compute")) + .toSeq + :+ ("io_time", SQLMetrics.createNanoTimingMetric(sparkContext, "Native.io_time")) + :+ ("bytes_written", + SQLMetrics + .createSizeMetric(sparkContext, "Native.bytes_written")): _*) + + def check(): Unit = { + val tblStorage = cmd.table.storage + val outputFormatClassName = tblStorage.outputFormat.getOrElse("").toLowerCase(Locale.ROOT) + + assert(outputFormatClassName.endsWith("orcoutputformat"), "not orc format") + assert( + NativeOrcInsertIntoHiveTableBase.isSupportedWriteSchema(cmd.table, cmd.partition), + "not supported writing ORC schema") + } + check() + + @transient + val wrapped: DataWritingCommandExec = { + val transformedTable = { + val tblStorage = cmd.table.storage + cmd.table.withNewStorage( + tblStorage.locationUri, + tblStorage.inputFormat, + outputFormat = Some(classOf[AuronOrcOutputFormat].getName), + tblStorage.compressed, + serde = Some(classOf[OrcSerde].getName), + tblStorage.properties) + } + + val transformedCmd = getInsertIntoHiveTableCommand( + transformedTable, + cmd.partition, + cmd.query, + cmd.overwrite, + cmd.ifPartitionNotExists, + cmd.outputColumnNames, + metrics) + DataWritingCommandExec(transformedCmd, child) + } + + override def output: Seq[Attribute] = wrapped.output + override def outputPartitioning: Partitioning = wrapped.outputPartitioning + override def outputOrdering: Seq[SortOrder] = wrapped.outputOrdering + override def doExecute(): RDD[InternalRow] = wrapped.execute() + + override def executeCollect(): Array[InternalRow] = wrapped.executeCollect() + override def executeTake(n: Int): Array[InternalRow] = wrapped.executeTake(n) + override def executeToIterator(): Iterator[InternalRow] = wrapped.executeToIterator() + + override def doExecuteNative(): NativeRDD = { + Shims.get.createConvertToNativeExec(wrapped).executeNative() + } + + override def nodeName: String = + s"NativeOrcInsert ${cmd.table.identifier.unquotedString}" + + protected def getInsertIntoHiveTableCommand( + table: CatalogTable, + partition: Map[String, Option[String]], + query: LogicalPlan, + overwrite: Boolean, + ifPartitionNotExists: Boolean, + outputColumnNames: Seq[String], + metrics: Map[String, SQLMetric]): InsertIntoHiveTable +} + +// A dummy output format which does not write anything but only pass output path to native OrcSinkExec. +class AuronOrcOutputFormat + extends FileOutputFormat[NullWritable, NullWritable] + with HiveOutputFormat[NullWritable, NullWritable] { + + override def getRecordWriter( + fileSystem: FileSystem, + jobConf: JobConf, + name: String, + progressable: Progressable): RecordWriter[NullWritable, NullWritable] = + throw new NotImplementedError() + + override def getHiveRecordWriter( + jobConf: JobConf, + finalOutPath: Path, + valueClass: Class[_ <: Writable], + isCompressed: Boolean, + tableProperties: Properties, + progress: Progressable): FileSinkOperator.RecordWriter = { + + new FileSinkOperator.RecordWriter { + override def write(w: Writable): Unit = { + OrcSinkTaskContext.get.processingOutputFiles.offer(finalOutPath.toString) + } + + override def close(abort: Boolean): Unit = {} + } + } +} + +class OrcSinkTaskContext { + var isNative: Boolean = false + val processingOutputFiles = new LinkedBlockingDeque[String]() + val processedOutputFiles = new util.ArrayDeque[OutputFileStat]() +} + +object OrcSinkTaskContext { + private val instances = mutable.Map[Long, OrcSinkTaskContext]() + + def get: OrcSinkTaskContext = { + val taskId = TaskContext.get.taskAttemptId() + instances.getOrElseUpdate( + taskId, { + TaskContext.get().addTaskCompletionListener(_ => instances.remove(taskId)) + new OrcSinkTaskContext + }) + } +} diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeOrcSinkBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeOrcSinkBase.scala new file mode 100644 index 000000000..2523b28e7 --- /dev/null +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeOrcSinkBase.scala @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.auron.plan + +import java.net.URI +import java.security.PrivilegedExceptionAction +import java.util.UUID + +import scala.annotation.nowarn +import scala.jdk.CollectionConverters._ + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.hive.ql.plan.TableDesc +import org.apache.hadoop.mapred.JobConf +import org.apache.hadoop.mapreduce.Job +import org.apache.spark.OneToOneDependency +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.auron.NativeConverters +import org.apache.spark.sql.auron.NativeHelper +import org.apache.spark.sql.auron.NativeRDD +import org.apache.spark.sql.auron.NativeSupports +import org.apache.spark.sql.catalyst.catalog.CatalogTable +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.SortOrder +import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.UnaryExecNode +import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.hive.auron.HiveClientHelper +import org.apache.spark.util.SerializableConfiguration + +import org.apache.auron.jni.JniBridge +import org.apache.auron.metric.SparkMetricNode +import org.apache.auron.protobuf.OrcProp +import org.apache.auron.protobuf.OrcSinkExecNode +import org.apache.auron.protobuf.PhysicalPlanNode + +abstract class NativeOrcSinkBase( + sparkSession: SparkSession, + table: CatalogTable, + partition: Map[String, Option[String]], + override val child: SparkPlan, + override val metrics: Map[String, SQLMetric]) + extends UnaryExecNode + with NativeSupports { + + private val dataSchema = NativeOrcInsertIntoHiveTableBase.dataSchema(table, partition) + + override def output: Seq[Attribute] = child.output + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override def outputOrdering: Seq[SortOrder] = child.outputOrdering + + override def doExecuteNative(): NativeRDD = { + val hiveQlTable = HiveClientHelper.toHiveTable(table) + val tableDesc = new TableDesc( + hiveQlTable.getInputFormatClass, + hiveQlTable.getOutputFormatClass, + hiveQlTable.getMetadata) + val hadoopConf = newHadoopConf(tableDesc) + val job = Job.getInstance(hadoopConf) + val orcFileFormat = new OrcFileFormat() + orcFileFormat.prepareWrite(sparkSession, job, Map(), dataSchema) + + val serializableConf = new SerializableConfiguration(job.getConfiguration) + val numDynParts = partition.count(_._2.isEmpty) + + val inputRDD = NativeHelper.executeNative(child) + val nativeMetrics = SparkMetricNode(metrics, inputRDD.metrics :: Nil) + val nativeDependencies = new OneToOneDependency(inputRDD) :: Nil + new NativeRDD( + sparkSession.sparkContext, + nativeMetrics, + inputRDD.partitions, + inputRDD.partitioner, + nativeDependencies, + inputRDD.isShuffleReadFull, + (partition, context) => { + + OrcSinkTaskContext.get.isNative = true + + val resourceId = s"NativeOrcSinkExec:${UUID.randomUUID().toString}" + JniBridge.putResource( + resourceId, + (location: String) => { + NativeHelper.currentUser.doAs(new PrivilegedExceptionAction[FileSystem] { + override def run(): FileSystem = + FileSystem.get(new URI(location), serializableConf.value) + }) + }) + + val job = Job.getInstance(new JobConf(serializableConf.value)) + val nativeProps = job.getConfiguration.asScala + .filter(_.getKey.startsWith("orc.")) + .map(entry => + OrcProp + .newBuilder() + .setKey(entry.getKey) + .setValue(entry.getValue) + .build()) + + val inputPartition = inputRDD.partitions(partition.index) + val orcSink = OrcSinkExecNode + .newBuilder() + .setInput(inputRDD.nativePlan(inputPartition, context)) + .setFsResourceId(resourceId) + .setNumDynParts(numDynParts) + .setSchema(NativeConverters.convertSchema(dataSchema)) + .addAllProp(nativeProps.asJava) + PhysicalPlanNode.newBuilder().setOrcSink(orcSink).build() + }, + friendlyName = "NativeRDD.OrcSink") + } + + @nowarn("cat=unused") // _tableDesc temporarily unused + protected def newHadoopConf(_tableDesc: TableDesc): Configuration = + sparkSession.sessionState.newHadoopConf() +}