Skip to content

Commit

Permalink
fix: Don't panic on unsupported data insert with Postgres (#211)
Browse files Browse the repository at this point in the history
  • Loading branch information
peasee authored Dec 31, 2024
1 parent 9585717 commit dc06251
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 39 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ insta = { version = "1.40.0", features = ["filters"] }

[features]
mysql = ["dep:mysql_async", "dep:async-stream"]
postgres = ["dep:tokio-postgres", "dep:uuid", "dep:postgres-native-tls", "dep:bb8", "dep:bb8-postgres", "dep:native-tls", "dep:pem", "dep:async-stream"]
postgres = ["dep:tokio-postgres", "dep:uuid", "dep:postgres-native-tls", "dep:bb8", "dep:bb8-postgres", "dep:native-tls", "dep:pem", "dep:async-stream", "dep:arrow-schema"]
sqlite = ["dep:rusqlite", "dep:tokio-rusqlite", "dep:arrow-schema"]
duckdb = ["dep:duckdb", "dep:r2d2", "dep:uuid", "dep:dyn-clone", "dep:async-stream", "dep:arrow-schema"]
flight = [
Expand Down
22 changes: 14 additions & 8 deletions src/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ use crate::sql::sql_provider_datafusion::{
expr::{self, Engine},
SqlTable,
};
use crate::util::schema::SchemaValidator;
use crate::InvalidTypeAction;
use arrow::{
array::RecordBatch,
datatypes::{Schema, SchemaRef},
Expand Down Expand Up @@ -122,7 +124,9 @@ pub enum Error {
#[snafu(display("Error parsing on_conflict: {source}"))]
UnableToParseOnConflict { source: on_conflict::Error },

#[snafu(display("Failed to create '{table_name}': creating a table with a schema is not supported"))]
#[snafu(display(
"Failed to create '{table_name}': creating a table with a schema is not supported"
))]
TableWithSchemaCreationNotSupported { table_name: String },
}

Expand Down Expand Up @@ -207,12 +211,12 @@ impl TableProviderFactory for PostgresTableProviderFactory {
_state: &dyn Session,
cmd: &CreateExternalTable,
) -> DataFusionResult<Arc<dyn TableProvider>> {

if cmd.name.schema().is_some() {
TableWithSchemaCreationNotSupportedSnafu {
table_name: cmd.name.to_string(),
}
.fail().map_err(to_datafusion_error)?;
.fail()
.map_err(to_datafusion_error)?;
}

let name = cmd.name.clone();
Expand Down Expand Up @@ -259,7 +263,10 @@ impl TableProviderFactory for PostgresTableProviderFactory {
.map_err(to_datafusion_error)?,
);

let schema = Arc::new(schema);
let schema: SchemaRef = Arc::new(schema);
PostgresConnection::handle_unsupported_schema(&schema, InvalidTypeAction::default())
.map_err(|e| DataFusionError::External(e.into()))?;

let postgres = Postgres::new(
name.clone(),
Arc::clone(&pool),
Expand Down Expand Up @@ -424,8 +431,7 @@ impl Postgres {
batch: RecordBatch,
on_conflict: Option<OnConflict>,
) -> Result<()> {
let insert_table_builder =
InsertBuilder::new(&self.table, vec![batch]);
let insert_table_builder = InsertBuilder::new(&self.table, vec![batch]);

let sea_query_on_conflict =
on_conflict.map(|oc| oc.build_sea_query_on_conflict(&self.schema));
Expand Down Expand Up @@ -460,8 +466,8 @@ impl Postgres {
transaction: &Transaction<'_>,
primary_keys: Vec<String>,
) -> Result<()> {
let create_table_statement = CreateTableBuilder::new(schema, self.table.table())
.primary_keys(primary_keys);
let create_table_statement =
CreateTableBuilder::new(schema, self.table.table()).primary_keys(primary_keys);
let create_stmts = create_table_statement.build_postgres();

for create_stmt in create_stmts {
Expand Down
60 changes: 31 additions & 29 deletions src/sql/db_connection_pool/dbconnection/postgresconn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@ use std::sync::Arc;

use crate::sql::arrow_sql_gen::postgres::rows_to_arrow;
use crate::sql::arrow_sql_gen::postgres::schema::pg_data_type_to_arrow_type;
use crate::util::handle_invalid_type_error;
use crate::util::schema::SchemaValidator;
use arrow::datatypes::Field;
use arrow::datatypes::Schema;
use arrow::datatypes::SchemaRef;
use arrow_schema::DataType;
use async_stream::stream;
use bb8_postgres::tokio_postgres::types::ToSql;
use bb8_postgres::PostgresConnectionManager;
Expand All @@ -25,7 +28,7 @@ use super::AsyncDbConnection;
use super::DbConnection;
use super::Result;

const SCHEMA_QUERY: &str = r#"
const SCHEMA_QUERY: &str = r"
WITH custom_type_details AS (
SELECT
t.typname,
Expand Down Expand Up @@ -109,7 +112,7 @@ c.table_schema = $1
AND c.table_name = $2
ORDER BY
c.ordinal_position;
"#;
";

#[derive(Debug, Snafu)]
pub enum PostgresError {
Expand All @@ -131,6 +134,21 @@ pub struct PostgresConnection {
invalid_type_action: InvalidTypeAction,
}

/// Schema validation rules for [`PostgresConnection`].
///
/// The only Arrow data type rejected here is `DataType::Map`; every other
/// type is reported as valid.
impl SchemaValidator for PostgresConnection {
    type Error = super::Error;

    /// Returns `true` unless `data_type` is an Arrow `Map`.
    fn is_data_type_valid(data_type: &DataType) -> bool {
        let is_map = matches!(data_type, DataType::Map(..));
        !is_map
    }

    /// Builds the `UnsupportedDataType` error for the given field, recording
    /// the textual form of the offending data type alongside the field name.
    fn invalid_type_error(data_type: &DataType, field_name: &str) -> Self::Error {
        let data_type = data_type.to_string();
        let field_name = field_name.to_owned();
        super::Error::UnsupportedDataType {
            data_type,
            field_name,
        }
    }
}

impl<'a>
DbConnection<
bb8::PooledConnection<'static, PostgresConnectionManager<MakeTlsConnector>>,
Expand Down Expand Up @@ -212,13 +230,18 @@ impl<'a>
let nullable_str = row.get::<usize, String>(2);
let nullable = nullable_str == "YES";
let type_details = row.get::<usize, Option<serde_json::Value>>(3);
let arrow_type = match pg_data_type_to_arrow_type(&pg_type, type_details) {
Ok(arrow_type) => arrow_type,
Err(_) => {
handle_unsupported_data_type(&pg_type, &column_name, self.invalid_type_action)?;
continue;
}
let Ok(arrow_type) = pg_data_type_to_arrow_type(&pg_type, type_details) else {
handle_invalid_type_error(
self.invalid_type_action,
super::Error::UnsupportedDataType {
data_type: pg_type.to_string(),
field_name: column_name.to_string(),
},
)?;

continue;
};

fields.push(Field::new(column_name, arrow_type, nullable));
}

Expand Down Expand Up @@ -292,24 +315,3 @@ impl PostgresConnection {
self
}
}

fn handle_unsupported_data_type(
data_type: &str,
field_name: &str,
invalid_type_action: InvalidTypeAction,
) -> Result<(), super::Error> {
let error = super::Error::UnsupportedDataType {
data_type: data_type.to_string(),
field_name: field_name.to_string(),
};
match invalid_type_action {
InvalidTypeAction::Error => {
return Err(error);
}
InvalidTypeAction::Warn => {
tracing::warn!("{error}");
}
InvalidTypeAction::Ignore => {}
}
Ok(())
}
2 changes: 1 addition & 1 deletion src/util/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ pub mod ns_lookup;
pub mod on_conflict;
pub mod retriable_error;

#[cfg(any(feature = "sqlite", feature = "duckdb"))]
#[cfg(any(feature = "sqlite", feature = "duckdb", feature = "postgres"))]
pub mod schema;
pub mod secrets;
pub mod test;
Expand Down

0 comments on commit dc06251

Please sign in to comment.