Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support sortedTake in beam runner #1949

Merged
merged 3 commits into from
Sep 26, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ import com.twitter.scalding.typed._
import com.twitter.scalding.typed.functions.{
FilterKeysToFilter,
FlatMapValuesToFlatMap,
MapValuesToMap
MapValuesToMap,
ScaldingPriorityQueueMonoid
}

object BeamPlanner {
Expand Down Expand Up @@ -65,7 +66,12 @@ object BeamPlanner {
config.getMapSideAggregationThreshold match {
case None => op
case Some(count) =>
op.mapSideAggregator(count, sg)
// Semigroup is invariant on T. We cannot pattern match as it is a Semigroup[PriorityQueue[T]]
if (sg.isInstanceOf[ScaldingPriorityQueueMonoid[_]]) {
op
} else {
op.mapSideAggregator(count, sg)
}
}
case (ReduceStepPipe(ir @ IdentityReduce(_, _, _, _, _)), rec) =>
def go[K, V1, V2](ir: IdentityReduce[K, V1, V2]): BeamOp[(K, V2)] = {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
package com.twitter.scalding.beam_backend

import com.twitter.algebird.Semigroup
import com.twitter.algebird.mutable.PriorityQueueMonoid
import com.twitter.scalding.Config
import com.twitter.scalding.beam_backend.BeamFunctions._
import com.twitter.scalding.typed.functions.ComposedFunctions.ComposedMapGroup
import com.twitter.scalding.typed.functions.{EmptyGuard, MapValueStream, SumAll}
import com.twitter.scalding.typed.functions.{
EmptyGuard,
MapValueStream,
ScaldingPriorityQueueMonoid,
SumAll
}
import com.twitter.scalding.typed.{CoGrouped, TypedSource}
import java.lang
import java.util.PriorityQueue
import java.util.{Comparator, PriorityQueue}
import org.apache.beam.sdk.Pipeline
import org.apache.beam.sdk.coders.{Coder, IterableCoder, KvCoder}
import org.apache.beam.sdk.transforms.DoFn.ProcessElement
import org.apache.beam.sdk.transforms.Top.TopCombineFn
import org.apache.beam.sdk.transforms._
import org.apache.beam.sdk.transforms.join.{
CoGbkResult,
Expand Down Expand Up @@ -52,6 +57,10 @@ sealed abstract class BeamOp[+A] {
parDo(FlatMapFn(f))
}

private final case class SerializableComparator[T](comp: Comparator[T]) extends Comparator[T] {
nownikhil marked this conversation as resolved.
Show resolved Hide resolved
override def compare(o1: T, o2: T): Int = comp.compare(o1, o2)
}

object BeamOp extends Serializable {
implicit private def fakeClassTag[A]: ClassTag[A] = ClassTag(classOf[AnyRef]).asInstanceOf[ClassTag[A]]

Expand All @@ -61,19 +70,24 @@ object BeamOp extends Serializable {
)(implicit ordK: Ordering[K], kryoCoder: KryoCoder): PCollection[KV[K, java.lang.Iterable[U]]] = {
reduceFn match {
case ComposedMapGroup(f, g) => planMapGroup(planMapGroup(pcoll, f), g)
case EmptyGuard(MapValueStream(SumAll(pqm: PriorityQueueMonoid[V]))) =>
pcoll.apply(MapElements.via(
new SimpleFunction[KV[K, java.lang.Iterable[V]], KV[K, java.lang.Iterable[U]]]() {
override def apply(input: KV[K, lang.Iterable[V]]): KV[K, java.lang.Iterable[U]] = {
// We are not using plus method defined in PriorityQueueMonoid as it is mutating
// input Priority Queues. We create a new PQ from the individual ones.
// We didn't use Top PTransformation in beam as it is not needed, also
// we cannot access `max` defined in PQ monoid.
val flattenedValues = input.getValue.asScala.flatMap { value =>
value.asInstanceOf[PriorityQueue[V]].iterator().asScala
}
val mergedPQ = pqm.build(flattenedValues)
KV.of(input.getKey, Iterable(mergedPQ.asInstanceOf[U]).asJava)
case EmptyGuard(MapValueStream(SumAll(pqm: ScaldingPriorityQueueMonoid[v]))) =>
val vCollection = pcoll.asInstanceOf[PCollection[KV[K, java.lang.Iterable[PriorityQueue[v]]]]]

vCollection.apply(MapElements.via(
new SimpleFunction[KV[K, java.lang.Iterable[PriorityQueue[v]]], KV[K, java.lang.Iterable[U]]]() {
override def apply(input: KV[K, lang.Iterable[PriorityQueue[v]]]): KV[K, java.lang.Iterable[U]] = {

val topCombineFn = new TopCombineFn[v, SerializableComparator[v]](
nownikhil marked this conversation as resolved.
Show resolved Hide resolved
pqm.count,
SerializableComparator[v](pqm.ordering.reverse)
)

@inline def flattenedValues: Stream[v] =
input.getValue.asScala.toStream.flatMap(_.asScala.toStream)

val outputs: java.util.List[v] = topCombineFn.apply(flattenedValues.asJava)
val pqs = pqm.build(outputs.asScala)
nownikhil marked this conversation as resolved.
Show resolved Hide resolved
KV.of(input.getKey, Iterable(pqs.asInstanceOf[U]).asJava)
}
})
).setCoder(KvCoder.of(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package com.twitter.scalding.beam_backend

import com.twitter.algebird.mutable.PriorityQueueMonoid
import com.twitter.algebird.{AveragedValue, Semigroup}
import com.twitter.scalding.{Config, TextLine, TypedPipe}
import java.io.File
import java.nio.file.Paths
import java.util.PriorityQueue
import org.apache.beam.sdk.options.{PipelineOptions, PipelineOptionsFactory}
import org.scalatest.{BeforeAndAfter, FunSuite}
import scala.io.Source
Expand Down Expand Up @@ -113,6 +115,18 @@ class BeamBackendTests extends FunSuite with BeforeAndAfter {
)
}

test("bufferedTake"){
beamMatchesSeq(
TypedPipe
.from(1 to 50)
.groupAll
.bufferedTake(100)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just a note: this is going to be really bad in a real job without map-side aggregation: the key is Unit so there is only one key, so this would have each mapper send 100, then have the reducers pick 100 of those.

But with no mapside aggregation, all the data will be sent to the reducers, and they will throw away all but 100.

But we can add an issue and come back and address this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Opened a ticket for this.
#1952

.map(_._2),
1 to 50,
Config(Map("cascading.aggregateby.threshold" -> "100"))
)
}

test("SumByLocalKeys"){
beamMatchesSeq(
TypedPipe
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import com.twitter.algebird.{
Aggregator
}

import com.twitter.algebird.mutable.PriorityQueueMonoid
import com.twitter.scalding.typed.functions.ScaldingPriorityQueueMonoid

import java.util.PriorityQueue

Expand Down Expand Up @@ -391,7 +391,7 @@ trait ReduceOperations[+Self <: ReduceOperations[Self]] extends java.io.Serializ
def sortedTake[T](f: (Fields, Fields), k: Int)(implicit conv: TupleConverter[T], ord: Ordering[T]): Self = {

assert(f._2.size == 1, "output field size must be 1")
implicit val mon: PriorityQueueMonoid[T] = new PriorityQueueMonoid[T](k)
implicit val mon: ScaldingPriorityQueueMonoid[T] = new ScaldingPriorityQueueMonoid[T](k)
mapPlusMap(f) { (tup: T) => mon.build(tup) } {
(lout: PriorityQueue[T]) => lout.iterator.asScala.toList.sorted
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ limitations under the License.
package com.twitter.scalding.typed

import com.twitter.algebird.Semigroup
import com.twitter.algebird.mutable.PriorityQueueMonoid
import com.twitter.scalding.typed.functions._
import com.twitter.scalding.typed.functions.ComposedFunctions.ComposedMapGroup
import scala.collection.JavaConverters._
Expand Down Expand Up @@ -659,7 +658,7 @@ final case class UnsortedIdentityReduce[K, V1, V2](
// If you care which items you take, you should sort by a random number
// or the value itself.
val fakeOrdering: Ordering[V1] = Ordering.by { v: V1 => v.hashCode }
implicit val mon: PriorityQueueMonoid[V1] = new PriorityQueueMonoid[V1](n)(fakeOrdering)
implicit val mon: ScaldingPriorityQueueMonoid[V1] = new ScaldingPriorityQueueMonoid[V1](n)(fakeOrdering)
// Do the heap-sort on the mappers:
val pretake: TypedPipe[(K, V1)] = mapped.mapValues { v: V1 => mon.build(v) }
.sumByLocalKeys
Expand Down Expand Up @@ -745,7 +744,7 @@ final case class IdentityValueSortedReduce[K, V1, V2](
// This means don't take anything, which is legal, but strange
filterKeys(Constant(false))
} else {
implicit val mon: PriorityQueueMonoid[V1] = new PriorityQueueMonoid[V1](n)(valueSort)
implicit val mon: ScaldingPriorityQueueMonoid[V1] = new ScaldingPriorityQueueMonoid[V1](n)(valueSort)
// Do the heap-sort on the mappers:
val pretake: TypedPipe[(K, V1)] = mapped.mapValues { v: V1 => mon.build(v) }
.sumByLocalKeys
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import java.io.Serializable
import scala.collection.JavaConverters._

import com.twitter.algebird.{ Fold, Semigroup, Ring, Aggregator }
import com.twitter.algebird.mutable.PriorityQueueMonoid

import com.twitter.scalding.typed.functions._

Expand Down Expand Up @@ -79,7 +78,7 @@ trait KeyedListLike[K, +T, +This[K, +T] <: KeyedListLike[K, T, This]] extends Se
// If you care which items you take, you should sort by a random number
// or the value itself.
val fakeOrdering: Ordering[T] = Ordering.by { v: T => v.hashCode }
implicit val mon = new PriorityQueueMonoid(n)(fakeOrdering)
implicit val mon = new ScaldingPriorityQueueMonoid(n)(fakeOrdering)
mapValues(mon.build(_))
// Do the heap-sort on the mappers:
.sum
Expand Down Expand Up @@ -213,7 +212,7 @@ trait KeyedListLike[K, +T, +This[K, +T] <: KeyedListLike[K, T, This]] extends Se
* to fit in memory.
*/
def sortedTake[U >: T](k: Int)(implicit ord: Ordering[U]): This[K, Seq[U]] = {
val mon = new PriorityQueueMonoid[U](k)(ord)
val mon = new ScaldingPriorityQueueMonoid[U](k)(ord)
mapValues(mon.build(_))
.sum(mon) // results in a PriorityQueue
// scala can't infer the type, possibly due to the view bound on TypedPipe
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package com.twitter.scalding.typed.functions

import com.twitter.algebird.mutable.PriorityQueueMonoid

class ScaldingPriorityQueueMonoid[K](
val count: Int
)(implicit val ordering: Ordering[K]) extends PriorityQueueMonoid[K](count)(ordering)