SparkRDD缓存

时间:2017-09-17 16:00:00

RDD缓存是Spark的一个重要特性,也是Spark速度快的原因之一,RDD在内存持久化或缓存之后,每一个节点都将把计算的分区结果留在内存中,并再对RDD进行其他的Action动作重用,这样后续的动作就会更快;
查看StorageLevel可以看到缓存的级别

/**
 * Various [[org.apache.spark.storage.StorageLevel]] defined and utility functions for creating
 * new storage levels.
 */
object StorageLevel {
  val NONE = new StorageLevel(false, false, false, false)
  val DISK_ONLY = new StorageLevel(true, false, false, false)
  val DISK_ONLY_2 = new StorageLevel(true, false, false, false, 2)
  val MEMORY_ONLY = new StorageLevel(false, true, false, true)
  val MEMORY_ONLY_2 = new StorageLevel(false, true, false, true, 2)
  val MEMORY_ONLY_SER = new StorageLevel(false, true, false, false)
  val MEMORY_ONLY_SER_2 = new StorageLevel(false, true, false, false, 2)
  val MEMORY_AND_DISK = new StorageLevel(true, true, false, true)
  val MEMORY_AND_DISK_2 = new StorageLevel(true, true, false, true, 2)
  val MEMORY_AND_DISK_SER = new StorageLevel(true, true, false, false)
  val MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, false, 2)
  val OFF_HEAP = new StorageLevel(true, true, true, false, 1)
...
通过persist()和cache()方法可以对RDD进行缓存或持久化,查看他们的源码如下

  /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
  def persist(): this.type = persist(StorageLevel.MEMORY_ONLY)

  /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
  def cache(): this.type = persist()
可以看出cache其实就是调用persist默认的内存级别进行缓存,/* Persist this RDD with the default storage level (MEMORY_ONLY). /,就是说cache其实是一个快捷方法,实际上还是persist()为主,persist是可以传入根据需要的StorageLevel进行缓存的

/**
   * Set this RDD's storage level to persist its values across operations after the first time
   * it is computed. This can only be used to assign a new storage level if the RDD does not
   * have a storage level set yet. Local checkpointing is an exception.
   */
  def persist(newLevel: StorageLevel): this.type = {
    if (isLocallyCheckpointed) {
      // This means the user previously called localCheckpoint(), which should have already
      // marked this RDD for persisting. Here we should override the old storage level with
      // one that is explicitly requested by the user (after adapting it to use disk).
      persist(LocalRDDCheckpointData.transformStorageLevel(newLevel), allowOverride = true)
    } else {
      persist(newLevel, allowOverride = false)
    }
  }
rdd2.persist(StorageLevel.DISK_ONLY)
对于如rd1->rd2->rd3,如果对rd2进行缓存的话,那么在执行rd3计算时就不会再进行rd1->rd2,如下中对rd2进行缓存了,那么在执行rd2.collect和 rd3=rd2.map(f=>(f._1+f._2))时就不会在进行rd2有关的依赖计算了,速度也得到了很大的提升


scala> val rd1=sc.makeRDD((1 to 20),4)
rd1: org.apache.spark.rdd.RDD[Int] = ParallelCollectionRDD[10] at makeRDD at :24

scala> val rd2=rd1.map(f=>(f,f*f))
rd2: org.apache.spark.rdd.RDD[(Int, Int)] = MapPartitionsRDD[12] at map at :26

scala> rd2.cache
res13: rd2.type = MapPartitionsRDD[12] at map at :26

scala> rd2.collect
res10: Array[(Int, Int)] = Array((1,1), (2,4), (3,9), (4,16), (5,25), (6,36), (7,49), (8,64), (9,81), (10,100), (11,121), (12,144), (13,169), (14,196), (15,225), (16,256), (17,289), (18,324), (19,361), (20,400))

scala> val rd3=rd2.map(f=>(f._1+f._2))
rd3: org.apache.spark.rdd.RDD[Int] = MapPartitionsRDD[14] at map at :28

scala> rd3.collect
res12: Array[Int] = Array(2, 6, 12, 20, 30, 42, 56, 72, 90, 110, 132, 156, 182, 210, 240, 272, 306, 342, 380, 420)
RDD的缓存有可能会造成数据丢失,或者存储于内存中的数据由于内存不足而被删除,RDD的容错机制保证缓存了数据及时丢失也能保证还能正确计算,RDD的各个Partition是相对独立的,只需要重新计算丢失的部分即可,并不需要重新建计算所有的分区
RDD迭代iterator中可以看到如果存储级别为空则直接进行计算,否则去检查点检查是否计算还是从缓存中拿

  /**
   * Internal method to this RDD; will read from cache if applicable, or otherwise compute it.
   * This should ''not'' be called by users directly, but is available for implementors of custom
   * subclasses of RDD.
   */
  final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
    if (storageLevel != StorageLevel.NONE) {
      getOrCompute(split, context)
    } else {
      computeOrReadCheckpoint(split, context)
    }
  }
private[spark] def getOrCompute(partition: Partition, context: TaskContext): Iterator[T] = {
    val blockId = RDDBlockId(id, partition.index)
    var readCachedBlock = true
    // This method is called on executors, so we need call SparkEnv.get instead of sc.env.
    SparkEnv.get.blockManager.getOrElseUpdate(blockId, storageLevel, elementClassTag, () => {
      readCachedBlock = false
      computeOrReadCheckpoint(partition, context)
    }) match {
      case Left(blockResult) =>
        if (readCachedBlock) {
          val existingMetrics = context.taskMetrics().inputMetrics
          existingMetrics.incBytesRead(blockResult.bytes)
          new InterruptibleIterator[T](context, blockResult.data.asInstanceOf[Iterator[T]]) {
            override def next(): T = {
              existingMetrics.incRecordsRead(1)
              delegate.next()
            }
          }
        } else {
          new InterruptibleIterator(context, blockResult.data.asInstanceOf[Iterator[T]])
        }
      case Right(iter) =>
        new InterruptibleIterator(context, iter.asInstanceOf[Iterator[T]])
    }
  }