Skip to content

Commit

Permalink
feat: allow tls (#15)
Browse files Browse the repository at this point in the history
  • Loading branch information
gadomski authored Dec 18, 2024
1 parent 7ff0791 commit cdef6f5
Show file tree
Hide file tree
Showing 4 changed files with 216 additions and 10 deletions.
199 changes: 199 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ crate-type = ["cdylib"]
bb8 = "0.8.6"
bb8-postgres = "0.8.1"
geojson = "0.24.1"
pgstac = { version = "0.2.2", git = "https://github.com/stac-utils/stac-rs" }
pgstac = { version = "0.2.2", git = "https://github.com/stac-utils/stac-rs", features = [
"tls",
] }
pyo3 = "0.23.2"
pyo3-async-runtimes = { version = "0.23.0", features = [
"tokio",
Expand Down
8 changes: 4 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

use bb8::{Pool, RunError};
use bb8_postgres::PostgresConnectionManager;
use pgstac::Pgstac;
use pgstac::{make_unverified_tls, MakeRustlsConnect, Pgstac};
use pyo3::{
create_exception,
exceptions::{PyException, PyValueError},
Expand All @@ -18,7 +18,7 @@ use tokio_postgres::{Config, NoTls};
create_exception!(pgstacrs, PgstacError, PyException);
create_exception!(pgstacrs, StacError, PyException);

type PgstacPool = Pool<PostgresConnectionManager<NoTls>>;
type PgstacPool = Pool<PostgresConnectionManager<MakeRustlsConnect>>;

#[derive(Debug, Error)]
enum Error {
Expand Down Expand Up @@ -68,7 +68,7 @@ impl Client {
let config: Config = params
.parse()
.map_err(|err: <Config as FromStr>::Err| PyValueError::new_err(err.to_string()))?;
let manager = PostgresConnectionManager::new(config.clone(), NoTls);
let manager = PostgresConnectionManager::new(config.clone(), make_unverified_tls());
pyo3_async_runtimes::tokio::future_into_py(py, async move {
{
// Quick connection to get better errors, bb8 will just time out
Expand Down Expand Up @@ -308,7 +308,7 @@ impl Client {
fn run<'a, F, T>(
&self,
py: Python<'a>,
f: impl FnOnce(Pool<PostgresConnectionManager<NoTls>>) -> F + Send + 'static,
f: impl FnOnce(Pool<PostgresConnectionManager<MakeRustlsConnect>>) -> F + Send + 'static,
) -> PyResult<Bound<'a, PyAny>>
where
F: Future<Output = Result<T>> + Send,
Expand Down
15 changes: 10 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
from pathlib import Path
from typing import Any, AsyncIterator, Iterator, cast
from typing import Any, Iterator, cast

import pytest
from pgstacrs import Client
Expand Down Expand Up @@ -38,7 +38,7 @@ def pgstac(


@pytest.fixture
async def client(pgstac: PostgreSQLExecutor) -> AsyncIterator[Client]:
def database_janitor(pgstac: PostgreSQLExecutor) -> Iterator[DatabaseJanitor]:
with DatabaseJanitor(
user=pgstac.user,
host=pgstac.host,
Expand All @@ -48,9 +48,14 @@ async def client(pgstac: PostgreSQLExecutor) -> AsyncIterator[Client]:
dbname="pypgstac_test",
template_dbname=pgstac.template_dbname,
) as database_janitor:
yield await Client.open(
f"user={database_janitor.user} host={database_janitor.host} port={database_janitor.port} dbname={database_janitor.dbname} password={database_janitor.password}"
)
yield database_janitor


@pytest.fixture
async def client(database_janitor: DatabaseJanitor) -> Client:
return await Client.open(
f"user={database_janitor.user} host={database_janitor.host} port={database_janitor.port} dbname={database_janitor.dbname} password={database_janitor.password}"
)


@pytest.fixture
Expand Down

0 comments on commit cdef6f5

Please sign in to comment.