Commit 0cc7b88

[SPARK-5536] replace old ALS implementation by the new one

The only issue is that `analyzeBlocks` is removed, which was marked as a developer API. I didn't change the other tests in ALSSuite under `spark.mllib`, so they continue to verify that the new implementation is correct.

CC: srowen coderxiang

Author: Xiangrui Meng <[email protected]>

Closes apache#4321 from mengxr/SPARK-5536 and squashes the following commits:

5a3cee8 [Xiangrui Meng] update python tests that are too strict
e840acf [Xiangrui Meng] ignore scala style check for ALS.train
e9a721c [Xiangrui Meng] update mima excludes
9ee6a36 [Xiangrui Meng] merge master
9a8aeac [Xiangrui Meng] update tests
d8c3271 [Xiangrui Meng] remove analyzeBlocks
d68eee7 [Xiangrui Meng] add checkpoint to new ALS
22a56f8 [Xiangrui Meng] wrap old ALS
c387dff [Xiangrui Meng] support random seed
3bdf24b [Xiangrui Meng] make storage level configurable in the new ALS
1 parent b8ebebe commit 0cc7b88
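
Note: the hunks below add three parameters to the developer-level ALS.train: configurable storage levels for intermediate and final RDDs, plus a random seed. A minimal sketch of a call site, assuming the method is accessible from the caller's package; the context setup and toy ratings are illustrative, not from this commit:

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.ml.recommendation.ALS
    import org.apache.spark.ml.recommendation.ALS.Rating
    import org.apache.spark.storage.StorageLevel

    val sc = new SparkContext(new SparkConf().setAppName("als-sketch").setMaster("local[*]"))

    // Hypothetical toy ratings; real callers would load these from data.
    val ratings = sc.parallelize(Seq(
      Rating(0, 10, 4.0f), Rating(0, 20, 3.0f), Rating(1, 10, 5.0f)))

    // The three parameters added by this commit, shown with their new defaults.
    val (userFactors, itemFactors) = ALS.train(
      ratings,
      rank = 10,
      maxIter = 10,
      intermediateRDDStorageLevel = StorageLevel.MEMORY_AND_DISK,
      finalRDDStorageLevel = StorageLevel.MEMORY_AND_DISK,
      seed = 42L)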

File tree

6 files changed: +90 −622 lines changed

mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala

+45 −24
@@ -22,6 +22,7 @@ import java.{util => ju}
 import scala.collection.mutable
 import scala.reflect.ClassTag
 import scala.util.Sorting
+import scala.util.hashing.byteswap64
 
 import com.github.fommil.netlib.BLAS.{getInstance => blas}
 import com.github.fommil.netlib.LAPACK.{getInstance => lapack}
@@ -37,6 +38,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.Dsl._
 import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructField, StructType}
+import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.Utils
 import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter}
 import org.apache.spark.util.random.XORShiftRandom
@@ -412,7 +414,7 @@ object ALS extends Logging {
   /**
    * Implementation of the ALS algorithm.
    */
-  def train[ID: ClassTag](
+  def train[ID: ClassTag]( // scalastyle:ignore
       ratings: RDD[Rating[ID]],
       rank: Int = 10,
       numUserBlocks: Int = 10,
@@ -421,34 +423,47 @@ object ALS extends Logging {
       regParam: Double = 1.0,
       implicitPrefs: Boolean = false,
       alpha: Double = 1.0,
-      nonnegative: Boolean = false)(
+      nonnegative: Boolean = false,
+      intermediateRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK,
+      finalRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK,
+      seed: Long = 0L)(
       implicit ord: Ordering[ID]): (RDD[(ID, Array[Float])], RDD[(ID, Array[Float])]) = {
+    require(intermediateRDDStorageLevel != StorageLevel.NONE,
+      "ALS is not designed to run without persisting intermediate RDDs.")
+    val sc = ratings.sparkContext
     val userPart = new HashPartitioner(numUserBlocks)
     val itemPart = new HashPartitioner(numItemBlocks)
     val userLocalIndexEncoder = new LocalIndexEncoder(userPart.numPartitions)
     val itemLocalIndexEncoder = new LocalIndexEncoder(itemPart.numPartitions)
     val solver = if (nonnegative) new NNLSSolver else new CholeskySolver
-    val blockRatings = partitionRatings(ratings, userPart, itemPart).cache()
-    val (userInBlocks, userOutBlocks) = makeBlocks("user", blockRatings, userPart, itemPart)
+    val blockRatings = partitionRatings(ratings, userPart, itemPart)
+      .persist(intermediateRDDStorageLevel)
+    val (userInBlocks, userOutBlocks) =
+      makeBlocks("user", blockRatings, userPart, itemPart, intermediateRDDStorageLevel)
     // materialize blockRatings and user blocks
     userOutBlocks.count()
     val swappedBlockRatings = blockRatings.map {
       case ((userBlockId, itemBlockId), RatingBlock(userIds, itemIds, localRatings)) =>
         ((itemBlockId, userBlockId), RatingBlock(itemIds, userIds, localRatings))
     }
-    val (itemInBlocks, itemOutBlocks) = makeBlocks("item", swappedBlockRatings, itemPart, userPart)
+    val (itemInBlocks, itemOutBlocks) =
+      makeBlocks("item", swappedBlockRatings, itemPart, userPart, intermediateRDDStorageLevel)
     // materialize item blocks
     itemOutBlocks.count()
-    var userFactors = initialize(userInBlocks, rank)
-    var itemFactors = initialize(itemInBlocks, rank)
+    val seedGen = new XORShiftRandom(seed)
+    var userFactors = initialize(userInBlocks, rank, seedGen.nextLong())
+    var itemFactors = initialize(itemInBlocks, rank, seedGen.nextLong())
     if (implicitPrefs) {
       for (iter <- 1 to maxIter) {
-        userFactors.setName(s"userFactors-$iter").persist()
+        userFactors.setName(s"userFactors-$iter").persist(intermediateRDDStorageLevel)
         val previousItemFactors = itemFactors
         itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam,
           userLocalIndexEncoder, implicitPrefs, alpha, solver)
         previousItemFactors.unpersist()
-        itemFactors.setName(s"itemFactors-$iter").persist()
+        if (sc.checkpointDir.isDefined && (iter % 3 == 0)) {
+          itemFactors.checkpoint()
+        }
+        itemFactors.setName(s"itemFactors-$iter").persist(intermediateRDDStorageLevel)
        val previousUserFactors = userFactors
         userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam,
           itemLocalIndexEncoder, implicitPrefs, alpha, solver)
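
A caveat on the checkpointing added in this hunk: the itemFactors.checkpoint() call is guarded by sc.checkpointDir.isDefined, so it does nothing unless the caller has set a checkpoint directory. A minimal sketch, reusing the sc from the earlier example (the path is an assumption):

    // Without this, sc.checkpointDir is None and the `iter % 3 == 0`
    // checkpoint branch above never fires.
    sc.setCheckpointDir("/tmp/spark-checkpoints")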
@@ -467,21 +482,23 @@ object ALS extends Logging {
       .join(userFactors)
       .values
       .setName("userFactors")
-      .cache()
-    userIdAndFactors.count()
-    itemFactors.unpersist()
+      .persist(finalRDDStorageLevel)
     val itemIdAndFactors = itemInBlocks
       .mapValues(_.srcIds)
       .join(itemFactors)
       .values
       .setName("itemFactors")
-      .cache()
-    itemIdAndFactors.count()
-    userInBlocks.unpersist()
-    userOutBlocks.unpersist()
-    itemInBlocks.unpersist()
-    itemOutBlocks.unpersist()
-    blockRatings.unpersist()
+      .persist(finalRDDStorageLevel)
+    if (finalRDDStorageLevel != StorageLevel.NONE) {
+      userIdAndFactors.count()
+      itemFactors.unpersist()
+      itemIdAndFactors.count()
+      userInBlocks.unpersist()
+      userOutBlocks.unpersist()
+      itemInBlocks.unpersist()
+      itemOutBlocks.unpersist()
+      blockRatings.unpersist()
+    }
     val userOutput = userIdAndFactors.flatMap { case (ids, factors) =>
       ids.view.zip(factors)
     }
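
For reference on the hunk above: for RDDs, cache() is shorthand for persist(StorageLevel.MEMORY_ONLY), so switching to persist(finalRDDStorageLevel) generalizes the old behavior, makes MEMORY_AND_DISK the default, and lets StorageLevel.NONE skip the eager count()/unpersist() cleanup. A small sketch of the pattern, reusing the earlier sc:

    import org.apache.spark.storage.StorageLevel

    // Stand-in for userIdAndFactors; cache() would pin it in memory only,
    // while persist() lets the caller choose the storage level.
    val factors = sc.parallelize(1 to 100).map(i => (i, Array.fill(10)(i.toFloat)))
    factors.persist(StorageLevel.MEMORY_AND_DISK)
    factors.count() // materialize so upstream intermediate RDDs can be unpersisted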
@@ -546,14 +563,15 @@ object ALS extends Logging {
    */
   private def initialize[ID](
       inBlocks: RDD[(Int, InBlock[ID])],
-      rank: Int): RDD[(Int, FactorBlock)] = {
+      rank: Int,
+      seed: Long): RDD[(Int, FactorBlock)] = {
     // Choose a unit vector uniformly at random from the unit sphere, but from the
     // "first quadrant" where all elements are nonnegative. This can be done by choosing
     // elements distributed as Normal(0,1) and taking the absolute value, and then normalizing.
     // This appears to create factorizations that have a slightly better reconstruction
     // (<1%) compared picking elements uniformly at random in [0,1].
     inBlocks.map { case (srcBlockId, inBlock) =>
-      val random = new XORShiftRandom(srcBlockId)
+      val random = new XORShiftRandom(byteswap64(seed ^ srcBlockId))
       val factors = Array.fill(inBlock.srcIds.length) {
         val factor = Array.fill(rank)(random.nextGaussian().toFloat)
         val nrm = blas.snrm2(rank, factor, 1)
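
The seeding above mixes the global seed with the block ID so each block draws from a decorrelated stream while the whole run stays reproducible. A sketch of the same idea, with scala.util.Random standing in for Spark's internal XORShiftRandom (an assumption made for illustration):

    import scala.util.Random
    import scala.util.hashing.byteswap64

    // Derive a distinct but deterministic seed for each block from one global seed.
    val globalSeed = 42L
    val blockSeeds = (0 until 4).map(blockId => byteswap64(globalSeed ^ blockId))

    // The same global seed yields the same per-block draws on every run.
    val firstDraws = blockSeeds.map(s => new Random(s).nextGaussian())
    println(firstDraws.mkString(", "))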
@@ -877,7 +895,8 @@ object ALS extends Logging {
       prefix: String,
       ratingBlocks: RDD[((Int, Int), RatingBlock[ID])],
       srcPart: Partitioner,
-      dstPart: Partitioner)(
+      dstPart: Partitioner,
+      storageLevel: StorageLevel)(
       implicit srcOrd: Ordering[ID]): (RDD[(Int, InBlock[ID])], RDD[(Int, OutBlock)]) = {
     val inBlocks = ratingBlocks.map {
       case ((srcBlockId, dstBlockId), RatingBlock(srcIds, dstIds, ratings)) =>
@@ -914,7 +933,8 @@ object ALS extends Logging {
         builder.add(dstBlockId, srcIds, dstLocalIndices, ratings)
       }
       builder.build().compress()
-    }.setName(prefix + "InBlocks").cache()
+    }.setName(prefix + "InBlocks")
+      .persist(storageLevel)
     val outBlocks = inBlocks.mapValues { case InBlock(srcIds, dstPtrs, dstEncodedIndices, _) =>
       val encoder = new LocalIndexEncoder(dstPart.numPartitions)
       val activeIds = Array.fill(dstPart.numPartitions)(mutable.ArrayBuilder.make[Int])
@@ -936,7 +956,8 @@ object ALS extends Logging {
       activeIds.map { x =>
         x.result()
       }
-    }.setName(prefix + "OutBlocks").cache()
+    }.setName(prefix + "OutBlocks")
+      .persist(storageLevel)
     (inBlocks, outBlocks)
   }
 
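Finally, the intended payoff of threading the seed through initialize: the same seed should reproduce the same factorization. A hedged check, reusing ratings from the first sketch; exact equality across runs is an assumption about the solver's determinism, not something this diff asserts:

    val (u1, _) = ALS.train(ratings, rank = 10, seed = 7L)
    val (u2, _) = ALS.train(ratings, rank = 10, seed = 7L)
    val a = u1.collect().sortBy(_._1).map(_._2.toVector)
    val b = u2.collect().sortBy(_._1).map(_._2.toVector)
    assert(a.sameElements(b))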