diff --git a/src/main/scala/StreamExtensions.scala b/src/main/scala/StreamExtensions.scala new file mode 100644 index 0000000..2ad2dbf --- /dev/null +++ b/src/main/scala/StreamExtensions.scala @@ -0,0 +1,20 @@ +package org.cvogt.scala.collection + +class StreamExtensions[A](val coll: Stream[A]) extends AnyVal{ + def distinctBy[B](toKey: A => B): Stream[A]= { + def loop(seen: Set[B], rest: Stream[A]): Stream[A] = { + if (rest.isEmpty) { + Stream.Empty + } else { + val elem = toKey(rest.head) + if (seen(elem)) { + loop(seen, rest.tail) + } else { + Stream.cons(rest.head, loop(seen + elem, rest.tail)) + } + } + } + + loop(Set(), coll) + } +} diff --git a/src/main/scala/TraversableLikeExtensions.scala b/src/main/scala/TraversableLikeExtensions.scala new file mode 100644 index 0000000..1ec39ea --- /dev/null +++ b/src/main/scala/TraversableLikeExtensions.scala @@ -0,0 +1,61 @@ +package org.cvogt.scala.collection + +import scala.collection._ +import scala.collection.generic.CanBuildFrom +import scala.collection.mutable.Builder + +class TraversableLikeExtensions[A, Repr](val coll: TraversableLike[A, Repr]) extends AnyVal{ + /** Eliminates duplicates based on the given key function. + There is no guarantee which elements stay in case two elements result in the same key. + @param toKey maps elements to a key, which is used for comparison*/ + def distinctBy[B, That](toKey: A => B)(implicit bf: CanBuildFrom[Repr, A, That]): That = { + val builder = bf(coll.repr) + val keys = mutable.Set[B]() + for(element <- coll){ + val key = toKey(element) + if (!keys(key)) { + builder += element + keys += key + } + } + builder.result() + } + + /** Groups elements given an equivalence function. + @param equivalent symmetric comparison function which tests whether the two arguments are considered equivalent. 
*/ + def groupWith[That](equivalent: (A,A) => Boolean)(implicit bf: CanBuildFrom[Repr, A, That]): Seq[That] = { + var l = List[(A, Builder[A, That])]() + for (elem <- coll) { + val b = l.find{ + case (sample, group) => equivalent(elem,sample) + }.map(_._2).getOrElse{ + val bldr = bf(coll.repr) + l = (elem, bldr) +: l + bldr + } + b += elem + } + val b = Vector.newBuilder[That] + for ((k, v) <- l.reverse) + b += v.result + b.result + } + + /** Eliminates duplicates based on the given equivalence function. + There is no guarantee which elements stay in case two elements are considered equivalent. + This has runtime O(n^2). + @param equivalent symmetric comparison function which tests whether the two arguments are considered equivalent. */ + def distinctWith[That](equivalent: (A,A) => Boolean)(implicit bf: CanBuildFrom[Repr, A, That]): That = { + var l = List[A]() + val b = bf(coll.repr) + for (elem <- coll) { + l.find{ + case first => equivalent(elem,first) + }.getOrElse{ + l = elem +: l + b += elem + } + } + b.result + } +} diff --git a/src/main/scala/collection.scala b/src/main/scala/collection.scala index 660c8ad..655f8ee 100644 --- a/src/main/scala/collection.scala +++ b/src/main/scala/collection.scala @@ -4,67 +4,22 @@ import scala.collection.generic.CanBuildFrom import scala.annotation.tailrec import scala.collection.mutable.Builder - -object `package`{ +sealed abstract class LowPriorityCollectionImplicits { + + implicit def ToTraversableLikeExtensions[A, Repr](coll: TraversableLike[A, Repr]): TraversableLikeExtensions[A, Repr] = + new TraversableLikeExtensions[A, Repr](coll) + +} + +object `package` extends LowPriorityCollectionImplicits { implicit class SeqLikeExtensions[A, Repr](val coll: SeqLike[A, Repr]) extends AnyVal{ /** type-safe contains check */ def containsTyped(t: A) = coll.contains(t) } - implicit class TraversableLikeExtensions[A, Repr](val coll: TraversableLike[A, Repr]) extends AnyVal{ - /** Eliminates duplicates based on the given key function. 
- There is no guarantee which elements stay in case element two elements result in the same key. - @param toKey maps elements to a key, which is used for comparison*/ - def distinctBy[B, That](toKey: A => B)(implicit bf: CanBuildFrom[Repr, A, That]): That = { - val builder = bf(coll.repr) - val keys = mutable.Set[B]() - for(element <- coll){ - val key = toKey(element) - if (!keys(key)) { - builder += element - keys += key - } - } - builder.result() - } - /** Groups elements given an equivalence function. - @param symmetric comparison function which tests whether the two arguments are considered equivalent. */ - def groupWith[That](equivalent: (A,A) => Boolean)(implicit bf: CanBuildFrom[Repr, A, That]): Seq[That] = { - var l = List[(A, Builder[A, That])]() - for (elem <- coll) { - val b = l.find{ - case (sample, group) => equivalent(elem,sample) - }.map(_._2).getOrElse{ - val bldr = bf(coll.repr) - l = (elem, bldr) +: l - bldr - } - b += elem - } - val b = Vector.newBuilder[That] - for ((k, v) <- l.reverse) - b += v.result - b.result - } + implicit def ToStreamExtensions[A](coll: Stream[A]): StreamExtensions[A] = + new StreamExtensions[A](coll) - /** Eliminates duplicates based on the given equivalence function. - There is no guarantee which elements stay in case element two elements are considered equivalent. - this has runtime O(n^2) - @param symmetric comparison function which tests whether the two arguments are considered equivalent. 
*/ - def distinctWith[That](equivalent: (A,A) => Boolean)(implicit bf: CanBuildFrom[Repr, A, That]): That = { - var l = List[A]() - val b = bf(coll.repr) - for (elem <- coll) { - l.find{ - case first => equivalent(elem,first) - }.getOrElse{ - l = elem +: l - b += elem - } - } - b.result - } - } implicit class TraversableExtensions[A](val coll: Traversable[A]) extends AnyVal{ /** tests weather the given collection has duplicates using default equality == */ def containsDuplicates = coll.groupBy(identity).exists(_._2.size > 1) diff --git a/src/test/scala/collection.scala b/src/test/scala/collection.scala index 2dd61db..3c68a87 100644 --- a/src/test/scala/collection.scala +++ b/src/test/scala/collection.scala @@ -33,6 +33,19 @@ class CollectionTest extends FunSuite{ val ps4 = ps.toSet.distinctBy(_.age) val ps5: Set[Person] = ps4 } + test("Stream#distinctBy"){ + val ps = Stream( + chris, + marcos + ) + assert(2 === ps.distinct.size) + assert(1 === ps.distinctBy(_.age).size) + assert(2 === ps.distinctBy(identity).size) + assert(2 === ps.distinctBy(_.name).size) + + // https://github.com/cvogt/scala-extensions/issues/5 + assert(Stream.from(1).distinctBy(identity).take(5) === Stream(1, 2, 3, 4, 5)) + } test("groupWith"){ val ps = List(chris,marcos) assert(