@@ -22,6 +22,7 @@ import java.{util => ju}
 import scala.collection.mutable
 import scala.reflect.ClassTag
 import scala.util.Sorting
+import scala.util.hashing.byteswap64
 
 import com.github.fommil.netlib.BLAS.{getInstance => blas}
 import com.github.fommil.netlib.LAPACK.{getInstance => lapack}
@@ -37,6 +38,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.Dsl._
 import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructField, StructType}
+import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.Utils
 import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter}
 import org.apache.spark.util.random.XORShiftRandom
@@ -412,7 +414,7 @@ object ALS extends Logging {
   /**
    * Implementation of the ALS algorithm.
    */
-  def train[ID: ClassTag](
+  def train[ID: ClassTag]( // scalastyle:ignore
       ratings: RDD[Rating[ID]],
       rank: Int = 10,
       numUserBlocks: Int = 10,
@@ -421,34 +423,47 @@ object ALS extends Logging {
       regParam: Double = 1.0,
       implicitPrefs: Boolean = false,
       alpha: Double = 1.0,
-      nonnegative: Boolean = false)(
+      nonnegative: Boolean = false,
+      intermediateRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK,
+      finalRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK,
+      seed: Long = 0L)(
       implicit ord: Ordering[ID]): (RDD[(ID, Array[Float])], RDD[(ID, Array[Float])]) = {
+    require(intermediateRDDStorageLevel != StorageLevel.NONE,
+      "ALS is not designed to run without persisting intermediate RDDs.")
+    val sc = ratings.sparkContext
     val userPart = new HashPartitioner(numUserBlocks)
     val itemPart = new HashPartitioner(numItemBlocks)
     val userLocalIndexEncoder = new LocalIndexEncoder(userPart.numPartitions)
     val itemLocalIndexEncoder = new LocalIndexEncoder(itemPart.numPartitions)
     val solver = if (nonnegative) new NNLSSolver else new CholeskySolver
-    val blockRatings = partitionRatings(ratings, userPart, itemPart).cache()
-    val (userInBlocks, userOutBlocks) = makeBlocks("user", blockRatings, userPart, itemPart)
+    val blockRatings = partitionRatings(ratings, userPart, itemPart)
+      .persist(intermediateRDDStorageLevel)
+    val (userInBlocks, userOutBlocks) =
+      makeBlocks("user", blockRatings, userPart, itemPart, intermediateRDDStorageLevel)
     // materialize blockRatings and user blocks
     userOutBlocks.count()
     val swappedBlockRatings = blockRatings.map {
       case ((userBlockId, itemBlockId), RatingBlock(userIds, itemIds, localRatings)) =>
         ((itemBlockId, userBlockId), RatingBlock(itemIds, userIds, localRatings))
     }
-    val (itemInBlocks, itemOutBlocks) = makeBlocks("item", swappedBlockRatings, itemPart, userPart)
+    val (itemInBlocks, itemOutBlocks) =
+      makeBlocks("item", swappedBlockRatings, itemPart, userPart, intermediateRDDStorageLevel)
     // materialize item blocks
     itemOutBlocks.count()
-    var userFactors = initialize(userInBlocks, rank)
-    var itemFactors = initialize(itemInBlocks, rank)
+    val seedGen = new XORShiftRandom(seed)
+    var userFactors = initialize(userInBlocks, rank, seedGen.nextLong())
+    var itemFactors = initialize(itemInBlocks, rank, seedGen.nextLong())
     if (implicitPrefs) {
       for (iter <- 1 to maxIter) {
-        userFactors.setName(s"userFactors-$iter").persist()
+        userFactors.setName(s"userFactors-$iter").persist(intermediateRDDStorageLevel)
         val previousItemFactors = itemFactors
         itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam,
           userLocalIndexEncoder, implicitPrefs, alpha, solver)
         previousItemFactors.unpersist()
-        itemFactors.setName(s"itemFactors-$iter").persist()
+        if (sc.checkpointDir.isDefined && (iter % 3 == 0)) {
+          itemFactors.checkpoint()
+        }
+        itemFactors.setName(s"itemFactors-$iter").persist(intermediateRDDStorageLevel)
         val previousUserFactors = userFactors
         userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam,
           itemLocalIndexEncoder, implicitPrefs, alpha, solver)
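Note on the new `iter % 3 == 0` branch above: when a checkpoint directory is configured, it truncates the factor RDD's lineage every three iterations, so long implicit-feedback runs don't accumulate an ever-deeper dependency chain. A minimal, self-contained sketch of the same idiom (the app setup and the `/tmp/checkpoints` path are assumptions, not part of this patch):

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.storage.StorageLevel

object CheckpointCadenceSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("sketch"))
    sc.setCheckpointDir("/tmp/checkpoints") // assumed location
    var factors = sc.parallelize(1 to 1000).map(_.toDouble)
    for (iter <- 1 to 10) {
      factors = factors.map(_ * 0.9).persist(StorageLevel.MEMORY_AND_DISK)
      if (sc.getCheckpointDir.isDefined && iter % 3 == 0) {
        factors.checkpoint() // cuts lineage once the RDD is materialized
      }
      factors.count() // materialization triggers the pending checkpoint
    }
    sc.stop()
  }
}
```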
@@ -467,21 +482,23 @@ object ALS extends Logging {
       .join(userFactors)
       .values
       .setName("userFactors")
-      .cache()
-    userIdAndFactors.count()
-    itemFactors.unpersist()
+      .persist(finalRDDStorageLevel)
     val itemIdAndFactors = itemInBlocks
       .mapValues(_.srcIds)
       .join(itemFactors)
       .values
       .setName("itemFactors")
-      .cache()
-    itemIdAndFactors.count()
-    userInBlocks.unpersist()
-    userOutBlocks.unpersist()
-    itemInBlocks.unpersist()
-    itemOutBlocks.unpersist()
-    blockRatings.unpersist()
+      .persist(finalRDDStorageLevel)
+    if (finalRDDStorageLevel != StorageLevel.NONE) {
+      userIdAndFactors.count()
+      itemFactors.unpersist()
+      itemIdAndFactors.count()
+      userInBlocks.unpersist()
+      userOutBlocks.unpersist()
+      itemInBlocks.unpersist()
+      itemOutBlocks.unpersist()
+      blockRatings.unpersist()
+    }
     val userOutput = userIdAndFactors.flatMap { case (ids, factors) =>
       ids.view.zip(factors)
     }
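The `count()` calls in this hunk exist only to force the final factor RDDs to materialize while their inputs are still cached, after which the intermediates can be released. When `finalRDDStorageLevel` is `NONE` the whole block is skipped, since eagerly computing an unpersisted RDD would be wasted work. A stripped-down sketch of that materialize-then-unpersist idiom (assuming a live SparkContext `sc`):

```scala
import org.apache.spark.storage.StorageLevel

val parent = sc.parallelize(1 to 100).persist(StorageLevel.MEMORY_AND_DISK)
val child = parent.map(_ * 2).persist(StorageLevel.MEMORY_AND_DISK)
child.count()      // compute child while parent's blocks are still cached
parent.unpersist() // now safe: child no longer recomputes from parent
```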
@@ -546,14 +563,15 @@ object ALS extends Logging {
    */
   private def initialize[ID](
       inBlocks: RDD[(Int, InBlock[ID])],
-      rank: Int): RDD[(Int, FactorBlock)] = {
+      rank: Int,
+      seed: Long): RDD[(Int, FactorBlock)] = {
     // Choose a unit vector uniformly at random from the unit sphere, but from the
     // "first quadrant" where all elements are nonnegative. This can be done by choosing
     // elements distributed as Normal(0,1) and taking the absolute value, and then normalizing.
     // This appears to create factorizations that have a slightly better reconstruction
     // (<1%) compared picking elements uniformly at random in [0,1].
     inBlocks.map { case (srcBlockId, inBlock) =>
-      val random = new XORShiftRandom(srcBlockId)
+      val random = new XORShiftRandom(byteswap64(seed ^ srcBlockId))
       val factors = Array.fill(inBlock.srcIds.length) {
         val factor = Array.fill(rank)(random.nextGaussian().toFloat)
         val nrm = blas.snrm2(rank, factor, 1)
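Mixing the run-level `seed` into each block's RNG via `byteswap64` (an avalanching long hash from `scala.util.hashing`) keeps runs reproducible for a fixed seed while ensuring that adjacent block ids, which differ only in a few low bits, still produce well-separated random streams. A quick illustration (the seed value is arbitrary):

```scala
import scala.util.hashing.byteswap64

val seed = 17L
// Adjacent block ids map to thoroughly scrambled per-block seeds.
val blockSeeds = (0 until 4).map(blockId => byteswap64(seed ^ blockId))
blockSeeds.foreach(s => println(s.toHexString))
```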
@@ -877,7 +895,8 @@ object ALS extends Logging {
       prefix: String,
       ratingBlocks: RDD[((Int, Int), RatingBlock[ID])],
       srcPart: Partitioner,
-      dstPart: Partitioner)(
+      dstPart: Partitioner,
+      storageLevel: StorageLevel)(
       implicit srcOrd: Ordering[ID]): (RDD[(Int, InBlock[ID])], RDD[(Int, OutBlock)]) = {
     val inBlocks = ratingBlocks.map {
       case ((srcBlockId, dstBlockId), RatingBlock(srcIds, dstIds, ratings)) =>
@@ -914,7 +933,8 @@ object ALS extends Logging {
          builder.add(dstBlockId, srcIds, dstLocalIndices, ratings)
        }
        builder.build().compress()
-    }.setName(prefix + "InBlocks").cache()
+    }.setName(prefix + "InBlocks")
+      .persist(storageLevel)
     val outBlocks = inBlocks.mapValues { case InBlock(srcIds, dstPtrs, dstEncodedIndices, _) =>
       val encoder = new LocalIndexEncoder(dstPart.numPartitions)
       val activeIds = Array.fill(dstPart.numPartitions)(mutable.ArrayBuilder.make[Int])
@@ -936,7 +956,8 @@ object ALS extends Logging {
       activeIds.map { x =>
         x.result()
       }
-    }.setName(prefix + "OutBlocks").cache()
+    }.setName(prefix + "OutBlocks")
+      .persist(storageLevel)
     (inBlocks, outBlocks)
   }
 
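For context, a hypothetical caller (not part of this patch) exercising the three new parameters, assuming a SparkContext `sc` is in scope: memory-only intermediates, no persistence for the returned factors, and a fixed seed for reproducible runs.

```scala
import org.apache.spark.ml.recommendation.ALS
import org.apache.spark.ml.recommendation.ALS.Rating
import org.apache.spark.storage.StorageLevel

val ratings = sc.parallelize(Seq(
  Rating(0, 0, 4.0f), Rating(0, 1, 2.0f), Rating(1, 1, 3.0f)))
val (userFactors, itemFactors) = ALS.train(
  ratings,
  rank = 8,
  maxIter = 5,
  intermediateRDDStorageLevel = StorageLevel.MEMORY_ONLY,
  finalRDDStorageLevel = StorageLevel.NONE,
  seed = 42L)
```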