diff --git a/src/main/scala/com/redislabs/provider/redis/redisFunctions.scala b/src/main/scala/com/redislabs/provider/redis/redisFunctions.scala
index b8c17706..dd1f8e68 100644
--- a/src/main/scala/com/redislabs/provider/redis/redisFunctions.scala
+++ b/src/main/scala/com/redislabs/provider/redis/redisFunctions.scala
@@ -5,6 +5,7 @@ import com.redislabs.provider.redis.util.ConnectionUtils.withConnection
 import com.redislabs.provider.redis.util.PipelineUtils._
 import org.apache.spark.SparkContext
 import org.apache.spark.rdd.RDD
+import scala.collection.JavaConversions.mapAsJavaMap
 
 /**
@@ -253,6 +254,19 @@ class RedisContext(@transient val sc: SparkContext) extends Serializable {
     kvs.foreachPartition(partition => setKVs(partition, ttl, redisConfig, readWriteConfig))
   }
 
+  /**
+   * Write an RDD of binary key/value pairs as Redis string values.
+   *
+   * @param kvs Pair RDD of K/V as byte arrays
+   * @param ttl time to live in seconds; values <= 0 disable expiry
+   */
+  def toRedisByteKV(kvs: RDD[(Array[Byte], Array[Byte])], ttl: Int = 0)
+                   (implicit
+                    redisConfig: RedisConfig = RedisConfig.fromSparkConf(sc.getConf),
+                    readWriteConfig: ReadWriteConfig = ReadWriteConfig.fromSparkConf(sc.getConf)) {
+    kvs.foreachPartition(partition => setByteKVs(partition, ttl, redisConfig, readWriteConfig))
+  }
+
   /**
    * @param kvs Pair RDD of K/V
    * @param hashName target hash's name which hold all the kvs
@@ -408,6 +422,30 @@ object RedisContext extends Serializable {
     }
   }
 
+  /**
+   * Save KVs as byte arrays.
+   *
+   * @param arr k/v pairs to be saved in the target host as raw bytes
+   * @param ttl time to live in seconds; values <= 0 disable expiry
+   */
+  def setByteKVs(arr: Iterator[(Array[Byte], Array[Byte])], ttl: Int, redisConfig: RedisConfig,
+                 readWriteConfig: ReadWriteConfig) {
+    implicit val rwConf: ReadWriteConfig = readWriteConfig
+    // Group each pair by the node that owns its key, then write every group through one pipelined connection.
+    arr.map(kv => (redisConfig.getHost(kv._1), kv)).toArray.groupBy(_._1).
+      mapValues(a => a.map(p => p._2)).foreach { x =>
+      val conn = x._1.endpoint.connect()
+      foreachWithPipeline(conn, x._2) { case (pipeline, (k, v)) =>
+        if (ttl <= 0) {
+          pipeline.set(k, v)
+        }
+        else {
+          pipeline.setex(k, ttl.toLong, v)
+        }
+      }
+      conn.close()
+    }
+  }
 
   /**
    * @param hashName
diff --git a/src/test/scala/com/redislabs/provider/redis/rdd/RedisRddExtraSuite.scala b/src/test/scala/com/redislabs/provider/redis/rdd/RedisRddExtraSuite.scala
index 17102052..ee39abd0 100644
--- a/src/test/scala/com/redislabs/provider/redis/rdd/RedisRddExtraSuite.scala
+++ b/src/test/scala/com/redislabs/provider/redis/rdd/RedisRddExtraSuite.scala
@@ -73,6 +73,17 @@ trait RedisRddExtraSuite extends SparkRedisSuite with Keys with Matchers {
     verifyHash("hash2", map2)
   }
 
+  test("toRedisByteKV") {
+    val binaryKeyValue1 = ("binary-key1".getBytes, "binary-value1".getBytes)
+    val binaryKeyValue2 = ("binary-key2".getBytes, "binary-value2".getBytes)
+    val keyValueBytes = Seq(binaryKeyValue1, binaryKeyValue2)
+    val rdd = sc.parallelize(keyValueBytes)
+    sc.toRedisByteKV(rdd)
+
+    verifyBytes(binaryKeyValue1._1, binaryKeyValue1._2)
+    verifyBytes(binaryKeyValue2._1, binaryKeyValue2._2)
+  }
+
   test("connection fails with incorrect user/pass") {
     assertThrows[JedisConnectionException] {
       new RedisConfig(RedisEndpoint(
@@ -111,5 +122,10 @@ trait RedisRddExtraSuite extends SparkRedisSuite with Keys with Matchers {
       conn.hgetAll(hash).asScala should be(vals)
     }
   }
-
+
+  def verifyBytes(key: Array[Byte], value: Array[Byte]): Unit = {
+    withConnection(redisConfig.getHost(key).endpoint.connect()) { conn =>
+      conn.get(key) should be(value)
+    }
+  }
 }
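
Usage note (not part of the diff): a minimal driver sketch showing how the new `toRedisByteKV` API would be called. The app name, master, host/port settings, key names, and payload bytes are illustrative; the sketch assumes the implicit `SparkContext => RedisContext` conversion from the `com.redislabs.provider.redis` package object, as used elsewhere in the library.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import com.redislabs.provider.redis._

object ByteKvExample {
  def main(args: Array[String]): Unit = {
    // Hypothetical local setup; a real deployment would point these at your Redis instance.
    val conf = new SparkConf()
      .setAppName("byte-kv-example")
      .setMaster("local[*]")
      .set("spark.redis.host", "localhost")
      .set("spark.redis.port", "6379")
    val sc = new SparkContext(conf)

    // Keys and values are raw byte arrays, e.g. pre-serialized payloads.
    val pairs = sc.parallelize(Seq(
      ("user:1".getBytes("UTF-8"), Array[Byte](1, 2, 3)),
      ("user:2".getBytes("UTF-8"), Array[Byte](4, 5))
    ))

    sc.toRedisByteKV(pairs)           // plain SET, no expiry (ttl defaults to 0)
    sc.toRedisByteKV(pairs, ttl = 60) // SETEX with a 60-second TTL

    sc.stop()
  }
}
```

On the design: `setByteKVs` groups pairs by their owning node via `redisConfig.getHost` before writing, so each partition opens at most one connection per node and batches commands through `foreachWithPipeline`, keeping round trips low on both standalone and cluster deployments.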