Add missing fanout SCollection API (#5497)
RustedBones committed Sep 19, 2024
1 parent 965a4ab commit f36016b
Showing 5 changed files with 161 additions and 6 deletions.
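As context for the diff below, here is a minimal usage sketch of the API surface this commit adds. It mirrors the new tests; the pipeline, values, and fanout sizes are illustrative and assume a standard scio setup, not code from this commit.

```scala
import com.spotify.scio._
import org.joda.time.Instant

object FanoutExample {
  def main(cmdlineArgs: Array[String]): Unit = {
    val sc = ScioContext()

    // Global combines with a fixed fan out of 10 intermediate combine steps.
    val nums = sc.parallelize(1 to 100)
    val minimum = nums.withFanout(10).min  // SCollection[Int]
    val maximum = nums.withFanout(10).max
    val average = nums.withFanout(10).mean // SCollection[Double]

    // latest keeps the value with the greatest timestamp, so elements need timestamps.
    val newest = sc
      .parallelize(1L to 100L)
      .timestampBy(Instant.ofEpochMilli)
      .withFanout(10)
      .latest

    // Per-key variants, pre-combining "hot" keys before the final per-key combine.
    val pairs = sc.parallelize(List(("a", 1), ("b", 2)) ++ (1 to 100).map(("c", _)))
    val minPerKey = pairs.withHotKeyFanout(10).minByKey
    val meanPerKey = pairs.withHotKeyFanout(_.hashCode).meanByKey // fan out chosen per key

    sc.run()
  }
}
```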
@@ -31,6 +31,9 @@ private[scio] object TupleFunctions {
def klToTuple[K](kv: KV[K, java.lang.Long]): (K, Long) =
(kv.getKey, kv.getValue)

def kdToTuple[K](kv: KV[K, java.lang.Double]): (K, Double) =
(kv.getKey, kv.getValue)

def kvIterableToTuple[K, V](kv: KV[K, JIterable[V]]): (K, Iterable[V]) =
(kv.getKey, kv.getValue.asScala)

@@ -21,11 +21,10 @@ import com.spotify.scio.ScioContext
import com.spotify.scio.util.Functions
import com.spotify.scio.coders.Coder
import com.twitter.algebird.{Aggregator, Monoid, MonoidAggregator, Semigroup}
import org.apache.beam.sdk.transforms.{Combine, Top}
import org.apache.beam.sdk.transforms.{Combine, Latest, Mean, Reify, Top}
import org.apache.beam.sdk.values.PCollection

import java.lang.{Iterable => JIterable}

import java.lang.{Double => JDouble, Iterable => JIterable}
import scala.jdk.CollectionConverters._

/**
@@ -116,6 +115,32 @@ class SCollectionWithFanout[T] private[values] (coll: SCollection[T], fanout: Int
)
}

/** [[SCollection.min]] with fan out. */
def min(implicit ord: Ordering[T]): SCollection[T] =
this.reduce(ord.min)

/** [[SCollection.max]] with fan out. */
def max(implicit ord: Ordering[T]): SCollection[T] =
this.reduce(ord.max)

/** [[SCollection.mean]] with fan out. */
def mean(implicit ev: Numeric[T]): SCollection[Double] = {
val e = ev // defeat closure
coll.transform { in =>
in.map[JDouble](e.toDouble)
.pApply(Mean.globally().withFanout(fanout))
.asInstanceOf[SCollection[Double]]
}
}

/** [[SCollection.latest]] with fan out. */
def latest: SCollection[T] = {
coll.transform { in =>
in.pApply("Reify Timestamps", Reify.timestamps[T]())
.pApply("Latest Value", Combine.globally(Latest.combineFn[T]()).withFanout(fanout))
}
}

/** [[SCollection.top]] with fan out. */
def top(num: Int)(implicit ord: Ordering[T]): SCollection[Iterable[T]] = {
coll.transform { in =>
@@ -24,7 +24,17 @@ import com.spotify.scio.util.TupleFunctions._
import com.twitter.algebird.{Aggregator, Monoid, MonoidAggregator, Semigroup}
import org.apache.beam.sdk.transforms.Combine.PerKeyWithHotKeyFanout
import org.apache.beam.sdk.transforms.Top.TopCombineFn
import org.apache.beam.sdk.transforms.{Combine, SerializableFunction}
import org.apache.beam.sdk.transforms.{
Combine,
Latest,
Mean,
PTransform,
Reify,
SerializableFunction
}
import org.apache.beam.sdk.values.{KV, PCollection}

import java.lang.{Double => JDouble}

/**
* An enhanced SCollection that uses an intermediate node to combine "hot" keys partially before
@@ -142,6 +152,34 @@ class SCollectionWithHotKeyFanout[K, V] private[values] (
self.applyPerKey(withFanout(Combine.perKey(Functions.reduceFn(context, sg))))(kvToTuple)
}

/** [[SCollection.min]] with hot key fan out. */
def minByKey(implicit ord: Ordering[V]): SCollection[(K, V)] =
self.reduceByKey(ord.min)

/** [[SCollection.max]] with hot key fan out. */
def maxByKey(implicit ord: Ordering[V]): SCollection[(K, V)] =
self.reduceByKey(ord.max)

/** [[SCollection.mean]] with hot key fan out. */
def meanByKey(implicit ev: Numeric[V]): SCollection[(K, Double)] = {
val e = ev // defeat closure
self.self.transform { in =>
in.mapValues[JDouble](e.toDouble).applyPerKey(Mean.perKey[K, JDouble]())(kdToTuple)
}
}

/** [[SCollection.latest]] with hot key fan out. */
def latestByKey: SCollection[(K, V)] = {
self.applyPerKey(new PTransform[PCollection[KV[K, V]], PCollection[KV[K, V]]]() {
override def expand(input: PCollection[KV[K, V]]): PCollection[KV[K, V]] = {
input
.apply("Reify Timestamps", Reify.timestampsInValue[K, V])
.apply("Latest Value", withFanout(Combine.perKey(Latest.combineFn[V]())))
.setCoder(input.getCoder)
}
})(kvToTuple)
}

/** [[PairSCollectionFunctions.topByKey]] with hot key fanout. */
def topByKey(num: Int)(implicit ord: Ordering[V]): SCollection[(K, Iterable[V])] =
self.applyPerKey(withFanout(Combine.perKey(new TopCombineFn[V, Ordering[V]](num, ord))))(
@@ -19,6 +19,7 @@ package com.spotify.scio.values

import com.twitter.algebird.{Aggregator, Semigroup}
import com.spotify.scio.coders.Coder
import org.joda.time.Instant

class SCollectionWithFanoutTest extends NamedTransformSpec {
"SCollectionWithFanout" should "support aggregate()" in {
@@ -60,7 +61,7 @@ class SCollectionWithFanoutTest extends NamedTransformSpec {
}
}

it should "support sum()" in {
it should "support sum" in {
runWithContext { sc =>
def sum[T: Coder: Semigroup](elems: T*): SCollection[T] =
sc.parallelize(elems).withFanout(10).sum
@@ -72,6 +73,51 @@ }
}
}

it should "support min" in {
runWithContext { sc =>
def min[T: Coder: Ordering](elems: T*): SCollection[T] =
sc.parallelize(elems).withFanout(10).min
min(1, 2, 3) should containSingleValue(1)
min(1L, 2L, 3L) should containSingleValue(1L)
min(1f, 2f, 3f) should containSingleValue(1f)
min(1.0, 2.0, 3.0) should containSingleValue(1.0)
min(1 to 100: _*) should containSingleValue(1)
}
}

it should "support max" in {
runWithContext { sc =>
def max[T: Coder: Ordering](elems: T*): SCollection[T] =
sc.parallelize(elems).withFanout(10).max
max(1, 2, 3) should containSingleValue(3)
max(1L, 2L, 3L) should containSingleValue(3L)
max(1f, 2f, 3f) should containSingleValue(3f)
max(1.0, 2.0, 3.0) should containSingleValue(3.0)
max(1 to 100: _*) should containSingleValue(100)
}
}

it should "support mean" in {
runWithContext { sc =>
def mean[T: Coder: Numeric](elems: T*): SCollection[Double] =
sc.parallelize(elems).withFanout(10).mean
mean(1, 2, 3) should containSingleValue(2.0)
mean(1L, 2L, 3L) should containSingleValue(2.0)
mean(1f, 2f, 3f) should containSingleValue(2.0)
mean(1.0, 2.0, 3.0) should containSingleValue(2.0)
mean(0 to 100: _*) should containSingleValue(50.0)
}
}

it should "support latest" in {
runWithContext { sc =>
def latest(elems: Long*): SCollection[Long] =
sc.parallelize(elems).timestampBy(Instant.ofEpochMilli).withFanout(10).latest
latest(1L, 2L, 3L) should containSingleValue(3L)
latest(1L to 100L: _*) should containSingleValue(100L)
}
}

it should "support top()" in {
runWithContext { sc =>
def top3[T: Ordering: Coder](elems: T*): SCollection[Iterable[T]] =
@@ -18,6 +18,7 @@
package com.spotify.scio.values

import com.twitter.algebird.Aggregator
import org.joda.time.Instant

class SCollectionWithHotKeyFanoutTest extends NamedTransformSpec {
"SCollectionWithHotKeyFanout" should "support aggregateByKey()" in {
@@ -83,7 +84,7 @@ class SCollectionWithHotKeyFanoutTest extends NamedTransformSpec {
}
}

it should "support sumByKey()" in {
it should "support sumByKey" in {
runWithContext { sc =>
val p = sc.parallelize(List(("a", 1), ("b", 2), ("b", 2)) ++ (1 to 100).map(("c", _)))
val r1 = p.withHotKeyFanout(10).sumByKey
@@ -93,6 +94,48 @@ }
}
}

it should "support minByKey" in {
runWithContext { sc =>
val p = sc.parallelize(List(("a", 1), ("b", 2), ("b", 3)) ++ (1 to 100).map(("c", _)))
val r1 = p.withHotKeyFanout(10).minByKey
val r2 = p.withHotKeyFanout(_.hashCode).minByKey
r1 should containInAnyOrder(Seq(("a", 1), ("b", 2), ("c", 1)))
r2 should containInAnyOrder(Seq(("a", 1), ("b", 2), ("c", 1)))
}
}

it should "support maxByKey" in {
runWithContext { sc =>
val p = sc.parallelize(List(("a", 1), ("b", 2), ("b", 3)) ++ (1 to 100).map(("c", _)))
val r1 = p.withHotKeyFanout(10).maxByKey
val r2 = p.withHotKeyFanout(_.hashCode).maxByKey
r1 should containInAnyOrder(Seq(("a", 1), ("b", 3), ("c", 100)))
r2 should containInAnyOrder(Seq(("a", 1), ("b", 3), ("c", 100)))
}
}

it should "support meanByKey" in {
runWithContext { sc =>
val p = sc.parallelize(List(("a", 1), ("b", 2), ("b", 3)) ++ (0 to 100).map(("c", _)))
val r1 = p.withHotKeyFanout(10).meanByKey
val r2 = p.withHotKeyFanout(_.hashCode).meanByKey
r1 should containInAnyOrder(Seq(("a", 1.0), ("b", 2.5), ("c", 50.0)))
r2 should containInAnyOrder(Seq(("a", 1.0), ("b", 2.5), ("c", 50.0)))
}
}

it should "support latestByKey" in {
runWithContext { sc =>
val p = sc
.parallelize(List(("a", 1L), ("b", 2L), ("b", 3L)) ++ (1L to 100L).map(("c", _)))
.timestampBy { case (_, v) => Instant.ofEpochMilli(v) }
val r1 = p.withHotKeyFanout(10).latestByKey
val r2 = p.withHotKeyFanout(_.hashCode).latestByKey
r1 should containInAnyOrder(Seq(("a", 1L), ("b", 3L), ("c", 100L)))
r2 should containInAnyOrder(Seq(("a", 1L), ("b", 3L), ("c", 100L)))
}
}

it should "support topByKey()" in {
runWithContext { sc =>
val p = sc.parallelize(