Skip to content

Commit

Permalink
feat: allow altering prepared execution for Query (#2065)
Browse files Browse the repository at this point in the history
  • Loading branch information
ulfryk committed Dec 18, 2024
1 parent 91effe2 commit 592f967
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 30 deletions.
27 changes: 4 additions & 23 deletions modules/core/src/main/scala/doobie/hi/connection.scala
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ object connection {
)

def executionWithResultSet[A](
prepared: PreparedExecutionWithResultSet[A],
prepared: PreparedExecution[A],
loggingInfo: LoggingInfo
): ConnectionIO[A] = executeWithResultSet(
prepared.create,
Expand Down Expand Up @@ -128,7 +128,7 @@ object connection {
)

def executeWithoutResultSet[A](
prepared: PreparedExecutionWithoutResultSet[A],
prepared: PreparedExecutionWithoutProcessStep[A],
loggingInfo: LoggingInfo
): ConnectionIO[A] =
executeWithoutResultSet(
Expand Down Expand Up @@ -248,18 +248,6 @@ object connection {
} yield ele
}

def stream[A: Read](
prepared: PreparedExecutionStream,
loggingInfo: LoggingInfo
): Stream[ConnectionIO, A] =
stream[A](
prepared.create,
prepared.prep,
prepared.exec,
prepared.chunkSize,
loggingInfo
)

// Old implementation, used by deprecated methods
private def liftStream[A: Read](
chunkSize: Int,
Expand Down Expand Up @@ -578,23 +566,16 @@ object connection {
// getMetaData(IFDMD.getTypeInfo.flatMap(IFDMD.embed(_, HRS.list[(String, JdbcType)].map(_.toMap))))
// }

final case class PreparedExecutionWithResultSet[A](
final case class PreparedExecution[A](
create: ConnectionIO[PreparedStatement],
prep: PreparedStatementIO[Unit],
exec: PreparedStatementIO[ResultSet],
process: ResultSetIO[A]
)

final case class PreparedExecutionWithoutResultSet[A](
final case class PreparedExecutionWithoutProcessStep[A](
create: ConnectionIO[PreparedStatement],
prep: PreparedStatementIO[Unit],
exec: PreparedStatementIO[A]
)

final case class PreparedExecutionStream(
create: ConnectionIO[PreparedStatement],
prep: PreparedStatementIO[Unit],
exec: PreparedStatementIO[ResultSet],
chunkSize: Int
)
}
53 changes: 46 additions & 7 deletions modules/core/src/main/scala/doobie/util/query.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import doobie.free.connection.ConnectionIO
import doobie.free.preparedstatement.PreparedStatementIO
import doobie.free.resultset.ResultSetIO
import doobie.free.{connection as IFC, preparedstatement as IFPS}
import doobie.hi.connection.PreparedExecutionWithResultSet
import doobie.hi.connection.PreparedExecution
import doobie.hi.{connection as IHC, preparedstatement as IHPS, resultset as IHRS}
import doobie.util.MultiVersionTypeSupport.=:=
import doobie.util.analysis.Analysis
Expand Down Expand Up @@ -105,6 +105,15 @@ object query {
toConnectionIO(a, IHRS.build[F, B])
}

/** Just like `to` but allowing to alter `PreparedExecution`.
*/
def toAlteringExecution[F[_]](
a: A,
fn: PreparedExecution[F[B]] => PreparedExecution[F[B]]
)(implicit f: FactoryCompat[B, F[B]]): ConnectionIO[F[B]] = {
toConnectionIOAlteringExecution(a, IHRS.build[F, B], fn)
}

/** Apply the argument `a` to construct a program in `[[doobie.free.connection.ConnectionIO ConnectionIO]]` yielding
* an `Map[(K, V)]` accumulated via the provided `CanBuildFrom`. This is the fastest way to accumulate a
* collection. this function can call only when B is (K, V).
Expand All @@ -115,7 +124,10 @@ object query {

/** Just like `toMap` but allowing to alter `PreparedExecution`.
*/
def toMapAlteringExecution[K, V](a: A, fn: PreparedExecutionUpdate[Map[K, V]])(implicit
def toMapAlteringExecution[K, V](
a: A,
fn: PreparedExecution[Map[K, V]] => PreparedExecution[Map[K, V]]
)(implicit
ev: B =:= (K, V),
f: FactoryCompat[(K, V), Map[K, V]]
): ConnectionIO[Map[K, V]] =
Expand All @@ -128,20 +140,41 @@ object query {
def accumulate[F[_]: Alternative](a: A): ConnectionIO[F[B]] =
toConnectionIO(a, IHRS.accumulate[F, B])

/** Just like `accumulate` but allowing to alter `PreparedExecution`.
*/
def accumulateAlteringExecution[F[_]: Alternative](
a: A,
fn: PreparedExecution[F[B]] => PreparedExecution[F[B]]
): ConnectionIO[F[B]] =
toConnectionIOAlteringExecution(a, IHRS.accumulate[F, B], fn)

/** Apply the argument `a` to construct a program in `[[doobie.free.connection.ConnectionIO ConnectionIO]]` yielding
* a unique `B` and raising an exception if the resultset does not have exactly one row. See also `option`.
* @group Results
*/
def unique(a: A): ConnectionIO[B] =
toConnectionIO(a, IHRS.getUnique[B])

/** Just like `unique` but allowing to alter `PreparedExecution`.
*/
def uniqueAlteringExecution(a: A, fn: PreparedExecution[B] => PreparedExecution[B]): ConnectionIO[B] =
toConnectionIOAlteringExecution(a, IHRS.getUnique[B], fn)

/** Apply the argument `a` to construct a program in `[[doobie.free.connection.ConnectionIO ConnectionIO]]` yielding
* an optional `B` and raising an exception if the resultset has more than one row. See also `unique`.
* @group Results
*/
def option(a: A): ConnectionIO[Option[B]] =
toConnectionIO(a, IHRS.getOption[B])

/** Just like `option` but allowing to alter `PreparedExecution`.
*/
def optionAlteringExecution(
a: A,
fn: PreparedExecution[Option[B]] => PreparedExecution[Option[B]]
): ConnectionIO[Option[B]] =
toConnectionIOAlteringExecution(a, IHRS.getOption[B], fn)

/** Apply the argument `a` to construct a program in `[[doobie.free.connection.ConnectionIO ConnectionIO]]` yielding
* an `NonEmptyList[B]` and raising an exception if the resultset does not have at least one row. See also
* `unique`.
Expand All @@ -150,18 +183,26 @@ object query {
def nel(a: A): ConnectionIO[NonEmptyList[B]] =
toConnectionIO(a, IHRS.nel[B])

/** Just like `nel` but allowing to alter `PreparedExecution`.
*/
def nelAlteringExecution(
a: A,
fn: PreparedExecution[NonEmptyList[B]] => PreparedExecution[NonEmptyList[B]]
): ConnectionIO[NonEmptyList[B]] =
toConnectionIOAlteringExecution(a, IHRS.nel[B], fn)

private def toConnectionIO[C](a: A, rsio: ResultSetIO[C]): ConnectionIO[C] =
IHC.executionWithResultSet(preparedExecution(sql, a, rsio), mkLoggingInfo(a))

private def toConnectionIOAlteringExecution[C](
a: A,
rsio: ResultSetIO[C],
fn: PreparedExecutionUpdate[C]
fn: PreparedExecution[C] => PreparedExecution[C]
): ConnectionIO[C] =
IHC.executionWithResultSet(fn(preparedExecution(sql, a, rsio)), mkLoggingInfo(a))

private def preparedExecution[C](sql: String, a: A, rsio: ResultSetIO[C]): PreparedExecutionWithResultSet[C] =
PreparedExecutionWithResultSet(
private def preparedExecution[C](sql: String, a: A, rsio: ResultSetIO[C]): PreparedExecution[C] =
PreparedExecution(
create = IFC.prepareStatement(sql),
prep = IHPS.set(a),
exec = IFPS.executeQuery,
Expand Down Expand Up @@ -264,8 +305,6 @@ object query {

}

type PreparedExecutionUpdate[A] = PreparedExecutionWithResultSet[A] => PreparedExecutionWithResultSet[A]

/** An abstract query closed over its input arguments and yielding values of type `B`, without a specified
* disposition. Methods provided on `[[Query0]]` allow the query to be interpreted as a stream or program in
* `CollectionIO`.
Expand Down
70 changes: 70 additions & 0 deletions modules/core/src/test/scala/doobie/util/QuerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,20 @@ class QuerySuite extends munit.FunSuite {
assertEquals(q.contramap[Int](n => "bar" * n).to[List](1).transact(xa).unsafeRunSync(), Nil)
}

test("Query toAlteringExecution (result set operations)") {
var didRun = false

pairQuery.toAlteringExecution[List](
"x",
{ preparedExec =>
val process = IHRS.delay { didRun = true } *> preparedExec.process
preparedExec.copy(process = process)
})
.transact(xa).unsafeRunSync()

assert(didRun)
}

test("Query toMapAlteringExecution (result set operations)") {
var didRun = false

Expand All @@ -77,6 +91,62 @@ class QuerySuite extends munit.FunSuite {
assert(didRun)
}

test("Query accumulateAlteringExecution (result set operations)") {
var didRun = false

pairQuery.accumulateAlteringExecution[List](
"x",
{ preparedExec =>
val process = IHRS.delay { didRun = true } *> preparedExec.process
preparedExec.copy(process = process)
})
.transact(xa).unsafeRunSync()

assert(didRun)
}

test("Query uniqueAlteringExecution (result set operations)") {
var didRun = false

pairQuery.uniqueAlteringExecution(
"foo",
{ preparedExec =>
val process = IHRS.delay { didRun = true } *> preparedExec.process
preparedExec.copy(process = process)
})
.transact(xa).unsafeRunSync()

assert(didRun)
}

test("Query optionAlteringExecution (result set operations)") {
var didRun = false

pairQuery.optionAlteringExecution(
"x",
{ preparedExec =>
val process = IHRS.delay { didRun = true } *> preparedExec.process
preparedExec.copy(process = process)
})
.transact(xa).unsafeRunSync()

assert(didRun)
}

test("Query nelAlteringExecution (result set operations)") {
var didRun = false

pairQuery.nelAlteringExecution(
"foo",
{ preparedExec =>
val process = IHRS.delay { didRun = true } *> preparedExec.process
preparedExec.copy(process = process)
})
.transact(xa).unsafeRunSync()

assert(didRun)
}

test("Query0 from Query (non-empty) to") {
assertEquals(q.toQuery0("foo").to[List].transact(xa).unsafeRunSync(), List(123))
}
Expand Down

0 comments on commit 592f967

Please sign in to comment.