stencil.scala // Jump To …

Sliding Stencil

In this tutorial, we show a solution to Shonan Challenge 3.3 Stencil.

For explanation of this solution, see the group discussion.

For further reference, see:

Outline:

package scala.lms.tutorial

import scala.lms.common.ForwardTransformer
import scala.reflect.SourceContext

Stencil Challenge

We implement the challenge using an almost regular staged loop, except that it uses the sliding staging-time abstraction.

How does it work? We rely on the fact that LMS performs common subexpression elimination internally, and we insert variables for values that are re-used from one loop iteration to the next.

We compute the code for loop iterations i and i + 1. For iteration i + 1, we substitute i + 1 for i.

We peel the first loop iteration to initialize the variables.

And then generate a modified loop that executes the remaining iterations. The loop body corresponds to the (i+1) iteration above plus the read and write statements that maintain the variables.

So with a twist on our general CSE facility, we implement CSE across loop iterations. The approach also works for window sizes larger than 1; we just need to unroll more iterations of the loop. In the challenge example, if we look at iteration i + 2, we find that it does not reuse any memory access operations from iteration i, so a larger window is not necessary in this case.

/** Staged stencil from the Shonan 3.3 challenge, iterating with the `sliding` loop abstraction. */
trait Stencil extends Sliding {
  /**
   * Stages the stencil over `v`.
   *
   * @param v staged input array
   * @return a freshly allocated staged array of the same length; the loop assigns
   *         `output(i)` for `i` in `2 until len-2`
   */
  def snippet(v: Rep[Array[Double]]): Rep[Array[Double]] = {
    val len = v.length
    val output = NewArray[Double](len)

    // Building blocks of the stencil; each one is re-staged on every call, so any
    // sharing between calls comes from the staging framework, not from these defs.
    def a(j: Rep[Int]) = v(j)
    def w1(j: Rep[Int]) = a(j) * a(j+1)
    def wm(j: Rep[Int]) = a(j) - w1(j) + w1(j-1)
    def w2(j: Rep[Int]) = wm(j) * wm(j+1)
    def b(j: Rep[Int]) = wm(j) - w2(j) + w2(j-1)

    (2 until len-2).sliding.foreach { i => output(i) = b(i) }
    output
  }
}
/** Driver for the [[Stencil]] snippet; abstract because `sliding` is supplied by mixing in one of the `*Exp` traits. */
abstract class StencilDriver extends DslDriver[Array[Double],Array[Double]] with Stencil
/** Checks the code generated for each sliding strategy and that all strategies agree on a sample input. */
class StencilTest extends TutorialFunSuite {
  val under = "stencil"

  test("stencil without sliding") {
    check("0", (new StencilDriver with NoSlidingExp).code)
  }

  test("stencil with sliding") {
    check("1", (new StencilDriver with SlidingExp).code)
  }

  test("stencil with multi sliding") {
    check("2", (new StencilDriver with SlidingMultiExp).code)
  }

  test("stencil equal") {
    val baseline = new StencilDriver with NoSlidingExp
    val single = new StencilDriver with SlidingExp
    val multi = new StencilDriver with SlidingMultiExp
    // 1.0, 2.0, ..., 10.0
    val input = Array.tabulate(10)(k => (k + 1).toDouble)
    // Compare results through their comma-separated rendering.
    def render(driver: StencilDriver): String = driver.eval(input).mkString(",")
    assert(render(baseline) == render(single))
    assert(render(baseline) == render(multi))
  }
}

Generated code without sliding

/*****************************************
Emitting Generated Code
*******************************************/
class Snippet extends ((Array[Double])=>(Array[Double])) {
  def apply(x0:Array[Double]): Array[Double] = {
    val x1 = x0.length
    val x2 = new Array[Double](x1)
    val x3 = x1 - 2
    var x5 : Int = 2
    val x40 = while (x5 < x3) {
      val x6 = x0(x5)
      val x7 = x5 + 1
      val x8 = x0(x7)
      val x9 = x6 * x8
      val x10 = x6 - x9
      val x11 = x5 - 1
      val x12 = x0(x11)
      val x13 = x11 + 1
      val x14 = x0(x13)
      val x15 = x12 * x14
      val x16 = x10 + x15
      val x17 = x7 + 1
      val x18 = x0(x17)
      val x19 = x8 * x18
      val x20 = x8 - x19
      val x21 = x20 + x9
      val x22 = x16 * x21
      val x23 = x16 - x22
      val x24 = x12 - x15
      val x25 = x11 - 1
      val x26 = x0(x25)
      val x27 = x25 + 1
      val x28 = x0(x27)
      val x29 = x26 * x28
      val x30 = x24 + x29
      val x31 = x13 + 1
      val x32 = x0(x31)
      val x33 = x14 * x32
      val x34 = x14 - x33
      val x35 = x34 + x15
      val x36 = x30 * x35
      val x37 = x23 + x36
      val x38 = x2(x5) = x37
      x5 = x5 + 1
    }
    x2
  }
}
/*****************************************
End of Generated Code
*******************************************/

Generated code with sliding

/*****************************************
Emitting Generated Code
*******************************************/
class Snippet extends ((Array[Double])=>(Array[Double])) {
  def apply(x0:Array[Double]): Array[Double] = {
    val x1 = x0.length
    val x2 = new Array[Double](x1)
    val x3 = x1 - 2
    val x48 = x3 > 2
    val x105 = if (x48) {
      val x49 = x0(2)
      val x50 = x0(3)
      val x51 = x49 * x50
      val x52 = x49 - x51
      val x53 = x0(1)
      val x54 = x53 * x49
      val x55 = x52 + x54
      val x56 = x0(4)
      val x57 = x50 * x56
      val x58 = x50 - x57
      val x59 = x58 + x51
      val x60 = x55 * x59
      val x61 = x55 - x60
      val x62 = x53 - x54
      val x63 = x0(0)
      val x64 = x63 * x53
      val x65 = x62 + x64
      val x66 = x65 * x55
      val x67 = x61 + x66
      val x68 = x2(2) = x67
      var x69: Double = x56
      var x70: Double = x57
      var x71: Double = x59
      var x72: Double = x60
      var x73: Int = 3
      var x74: Int = 4
      var x76 : Int = 3
      val x103 = while (x76 < x3) {
        // variable reads
        val x78 = x69
        val x79 = x70
        val x80 = x71
        val x81 = x72
        val x82 = x73
        val x83 = x74
        // computation
        val x86 = x76 + 2
        val x87 = x0(x86)
        val x88 = x78 * x87
        val x89 = x78 - x88
        val x90 = x89 + x79
        val x91 = x80 * x90
        val x92 = x80 - x91
        val x93 = x92 + x81
        val x94 = x2(x82) = x93
        // variable writes
        x69 = x87
        x70 = x88
        x71 = x90
        x72 = x91
        x73 = x83
        x74 = x86
        x76 = x76 + 1
      }
      x103
    } else {
      ()
    }
    x2
  }
}
/*****************************************
End of Generated Code
*******************************************/

Generated code with multi sliding

/*****************************************
Emitting Generated Code
*******************************************/
class Snippet extends ((Array[Double])=>(Array[Double])) {
  def apply(x0:Array[Double]): Array[Double] = {
    val x1 = x0.length
    val x2 = new Array[Double](x1)
    val x3 = x1 - 2
    val x57 = x3 > 2
    val x114 = if (x57) {
      val x58 = x0(2)
      val x59 = x0(3)
      val x60 = x58 * x59
      val x61 = x58 - x60
      val x62 = x0(1)
      val x63 = x62 * x58
      val x64 = x61 + x63
      val x65 = x0(4)
      val x66 = x59 * x65
      val x67 = x59 - x66
      val x68 = x67 + x60
      val x69 = x64 * x68
      val x70 = x64 - x69
      val x71 = x62 - x63
      val x72 = x0(0)
      val x73 = x72 * x62
      val x74 = x71 + x73
      val x75 = x74 * x64
      val x76 = x70 + x75
      val x77 = x2(2) = x76
      var x78: Double = x65
      var x79: Double = x66
      var x80: Double = x68
      var x81: Double = x69
      var x82: Int = 3
      var x83: Int = 4
      var x85 : Int = 3
      val x112 = while (x85 < x3) {
        // variable reads
        val x87 = x78
        val x88 = x79
        val x89 = x80
        val x90 = x81
        val x91 = x82
        val x92 = x83
        // computation
        val x95 = x85 + 2
        val x96 = x0(x95)
        val x97 = x87 * x96
        val x98 = x87 - x97
        val x99 = x98 + x88
        val x100 = x89 * x99
        val x101 = x89 - x100
        val x102 = x101 + x90
        val x103 = x2(x91) = x102
        // variable writes
        x78 = x96
        x79 = x97
        x80 = x99
        x81 = x100
        x82 = x92
        x83 = x95
        x85 = x85 + 1
      }
      x112
    } else {
      ()
    }
    x2
  }
}
/*****************************************
End of Generated Code
*******************************************/

Note that for the examples here, the multi sliding does not improve upon the number of shared variables.

Infrastructure

/** Staging-time interface for loops whose implementation may share work between iterations. */
trait Sliding extends Dsl {
  /**
   * Allocates an array of `n` elements and fills element `i` with `f(i)` using [[sliding]].
   *
   * @param n staged number of elements
   * @param f staged element function
   * @return the staged array that was filled
   */
  def infix_sliding[T:Typ](n: Rep[Int], f: Rep[Int] => Rep[T]): Rep[Array[T]] = {
    val a = NewArray[T](n)
    sliding(0,n)(i => a(i) = f(i))
    a
  }

  /**
   * Adapter that lets a staged range drive a `for` loop through [[sliding]].
   * A named class is used instead of an anonymous structural type so that `foreach`
   * is an ordinary method call rather than a reflective one.
   */
  class SlidingRange(r: Rep[Range]) {
    /** Runs `f` for every index from `r.start` until `r.end` via [[sliding]]. */
    def foreach(f: Rep[Int] => Rep[Unit]): Rep[Unit] =
      sliding(r.start, r.end)(f)
  }

  /** Enables `for (i <- range.sliding) { ... }` on a staged range. */
  def infix_sliding(r: Rep[Range]): SlidingRange = new SlidingRange(r)

  /**
   * Runs `f` for every index from `start` until `end`; implementations decide
   * whether and how values are shared between iterations.
   */
  def sliding(start: Rep[Int], end: Rep[Int])(f: Rep[Int] => Rep[Unit]): Rep[Unit]
}

No Sliding (Baseline)

Not actually sliding – just to have a baseline reference.

/** Baseline implementation: a plain staged range loop with no sharing between iterations. */
trait NoSlidingExp extends DslExp with Sliding {
  /** Applies `f` to every index from `start` until `end`. */
  def sliding(start: Rep[Int], end: Rep[Int])(f: Rep[Int] => Rep[Unit]): Rep[Unit] =
    for (i <- start until end) f(i)
}

Sliding

/**
 * `sliding` implementation that carries values from one loop iteration to the next
 * in variables: the first iteration is peeled to initialise them, and the remaining
 * iterations read, recompute and write them.
 */
trait SlidingExp extends DslExp with Sliding {

We use an LMS transformer to evaluate the loop body across various iterations and transformations.

  // Forward transformer tied to this IR instance. findOverlap and sliding use it
  // (withSubstScope / traverseStm) to replay recorded statements under a substitution.
  object trans extends ForwardTransformer {
    val IR: SlidingExp.this.type = SlidingExp.this
  }
  /** Prints `x` to standard output, prefixed with "sliding log: ". */
  def log(x: Any): Unit =
    System.out.println(s"sliding log: $x")

Some arithmetic rewrites to maximize sliding sharing.

  /**
   * Staged integer addition with rewrites that fold constant offsets together, so
   * that index expressions from different iterations take the same form.
   *
   * @param lhs left operand
   * @param rhs right operand
   * @return the rewritten sum, or the default `int_plus` result when no rewrite applies
   */
  override def int_plus(lhs: Exp[Int], rhs: Exp[Int])(implicit pos: SourceContext): Exp[Int] =
    ((lhs,rhs) match {
      // (x+y)+z --> x+(y+z)
      case (Def(IntPlus(x:Exp[Int],Const(y:Int))), Const(z:Int)) => int_plus(x, unit(y+z))
      // (x-y)+z --> x-(y-z)
      case (Def(IntMinus(x:Exp[Int],Const(y:Int))), Const(z:Int)) => int_minus(x, unit(y-z))
      // x+(-c) --> x-c. Int.MinValue is excluded because -Int.MinValue == Int.MinValue:
      // the negated constant would still be negative and int_minus would rewrite it
      // straight back into int_plus, recursing without end.
      case (x: Exp[Int], Const(z:Int)) if z < 0 && z != Int.MinValue => int_minus(x, unit(-z))
      case _ => super.int_plus(lhs,rhs)
    }).asInstanceOf[Exp[Int]]

  /**
   * Staged integer subtraction with rewrites that fold constant offsets together, so
   * that index expressions from different iterations take the same form.
   *
   * @param lhs left operand
   * @param rhs right operand
   * @return the rewritten difference, or the default `int_minus` result when no rewrite applies
   */
  override def int_minus(lhs: Exp[Int], rhs: Exp[Int])(implicit pos: SourceContext): Exp[Int] =
    ((lhs,rhs) match {
      // (x-y)-z --> x-(y+z)
      case (Def(IntMinus(x:Exp[Int],Const(y:Int))), Const(z:Int)) => int_minus(x, unit(y+z))
      // (x+y)-z --> x+(y-z)
      case (Def(IntPlus(x:Exp[Int],Const(y:Int))), Const(z:Int)) => int_plus(x, unit(y-z))
      // x-(-c) --> x+c. Int.MinValue is excluded because -Int.MinValue == Int.MinValue:
      // the negated constant would still be negative and int_plus would rewrite it
      // straight back into int_minus, recursing without end.
      case (x: Exp[Int], Const(z:Int)) if z < 0 && z != Int.MinValue => int_plus(x, unit(-z))
      case _ => super.int_minus(lhs,rhs)
    }).asInstanceOf[Exp[Int]]

  // Substitution: maps an expression to the expression that replaces it.
  type Subst = scala.collection.immutable.Map[Exp[Any],Exp[Any]]
  // Result of evaluating a loop body together with the statements it produced.
  type Doublet = (Rep[Unit], List[Stm])
  // Like Doublet, plus the substitution that was in effect after the evaluation.
  type Triplet = (Rep[Unit], List[Stm], Subst)

Find the overlapping symbols: defined by f(i), used by f(i+1), f(i+2), … We create a helper function so that we can generalize to a fixpoint calculation in the subclass below.

  /**
   * Determines which symbols defined by f(i) are used again by later iterations.
   *
   * @param i symbol standing for the loop index
   * @param f loop body
   * @return the overlapping symbols, the (result, statements) of f(i), and the
   *         (result, statements, substitution) of f(i+1)
   */
  def findOverlap(i: Sym[Int], f: Rep[Int] => Rep[Unit]): (List[Sym[Any]], Doublet, Triplet) = {

We evaluate loop contents f(i), then f(i+1), then f(i+2). For evaluation, we use the reified sub graph for f(i), but we substitute i+1 or i+2 for i. Each time, we reflect the statements from the previous iteration.

    // Remember the current context; it is reassigned inside every nested scope
    // and restored once the analysis is done.
    val save = context
    // evaluate loop contents f(i)
    val (r0,stms0) = reifySubGraph(f(i))
    // The outer reifySubGraph only provides a scope: its own statement list is discarded (_).
    val (((r1,stms1,subst1),(r2,stms2,subst2)), _) = reifySubGraph {
      // Re-introduce the statements of f(i) before replaying the body for i+1.
      reflectSubGraph(stms0)
      context = save
      // evaluate loop contents f(i+1)
      // Replays the recorded statements of f(i) with i mapped to i+1, and captures
      // the transformed result together with the substitution built up on the way.
      val ((r1,subst1),stms1) = reifySubGraph(trans.withSubstScope(i -> (i+1)) {
        stms0.foreach(s=>trans.traverseStm(s))
        (trans(r0), trans.subst)
      })
      val ((r2,stms2,subst2), _) = reifySubGraph {
        // Same again, one level deeper: the statements of f(i+1) are reflected first.
        reflectSubGraph(stms1)
        context = save
        // evaluate loop contents f(i+2)
        val ((r2,subst2),stms2) = reifySubGraph(trans.withSubstScope(i -> (i+2)) {
          stms0.foreach(s=>trans.traverseStm(s))
          (trans(r0), trans.subst)
        })
        (r2,stms2,subst2)
      }
      ((r1,stms1,subst1), (r2,stms2,subst2))
    }
    context = save

    // All symbols defined by the statements of f(i).
    val defs = stms0.flatMap(_.lhs)

Here, we find the overlap symbols: defined by f(i), used by f(i+1) and f(i+2).

    // For every statement of f(i+1) (resp. f(i+2)), keep the symbols its definition
    // references that were defined by f(i); duplicates are removed.
    val overlap01 = stms1.flatMap { case TP(s,d) => syms(d) filter (defs contains _) }.distinct
    val overlap02 = stms2.flatMap { case TP(s,d) => syms(d) filter (defs contains _) }.distinct

We don't have a fixed point yet if there are still overlapping symbols in the last iteration. We will fix this in the subclass below.

    // Only the f(i+1) triplet is returned, so reuse detected in f(i+2) is just reported.
    if (overlap02.nonEmpty)
      log("Overlap beyond a single loop iteration is ignored, since not yet implemented.")

    // Union of both overlap lists, without duplicates.
    val overlap0 = (overlap01++overlap02).distinct
    (overlap0, (r0, stms0), (r1, stms1, subst1))
  }

  // Sliding loop: peels the first iteration, then emits a loop over the remaining
  // indices in which the overlap symbols are carried in variables.
  def sliding(start: Rep[Int], end: Rep[Int])(f: Rep[Int] => Rep[Unit]): Rep[Unit] = {
    // Fresh symbol standing for the loop index while the body is analysed.
    val i = fresh[Int]
    val (overlap0, (r0, stms0), (r1, stms1, subst1)) = findOverlap(i, f)
    // The overlap symbols mapped through the substitution obtained for f(i+1).
    val overlap1 = overlap0 map subst1

Build a variable for each overlap symbol. Initialize the variables by peeling the first loop iteration.

    // Staged guard: the peeled iteration and the loop are emitted inside this conditional.
    if (end > start) {
      // Peeled first iteration: replay the statements of f(i) with i replaced by start.
      // NOTE(review): rX is not referenced afterwards.
      val (rX,substX) = trans.withSubstScope(i -> start) {
        stms0.foreach(s=>trans.traverseStm(s))
        (trans(r0), trans.subst)
      }
      // One variable per overlap symbol, initialised with that symbol's image under
      // the peeled iteration's substitution.
      val vars = overlap0.map{x => var_new(substX(x))(x.tp,x.pos.head)}

Now generate the loop:

      // Remaining iterations: j runs from start+1 until end.
      for (j <- (start + unit(1)) until end) {
        generate_comment("variable reads")
        // Pair every overlap symbol with a read of its variable.
        val reads = (overlap0 zip vars).map(p => (p._1, readVar(p._2)(p._1.tp,p._1.pos.head)))
        generate_comment("computation")
        // Replay the f(i+1) statements with the overlap symbols replaced by the reads
        // and i replaced by j-1, so that the replayed i+1 corresponds to j.
        val (ri, substY1: Subst) = trans.withSubstScope((reads:+(i->(j-unit(1)))): _*) {
          stms1.foreach(s=>trans.traverseStm(s))
          (trans(r1), trans.subst)
        }
        generate_comment("variable writes")
        // Assign each variable the value its overlap1 symbol took in this replay.
        // NOTE(review): `ri` and `writes` are not referenced afterwards.
        val writes = (overlap1 zip vars).map{p =>
          (p._1, var_assign(p._2, substY1(p._1))(p._1.tp,p._1.pos.head))}
      }
    }
  }
}

Multi sliding

/**
 * Variant of [[SlidingExp]] whose `findOverlap` keeps examining further iterations
 * (i+1, i+2, ...) until an iteration no longer uses any symbol defined by f(i).
 */
trait SlidingMultiExp extends SlidingExp with DslExp with Sliding {

Find overlapping symbols between iterations until there are no new ones.

  /**
   * Fixpoint version of `findOverlap`: replays the loop body for i+1, i+2, ... and
   * accumulates the symbols of f(i) that each replay uses, stopping at the first
   * replay that uses none.
   *
   * @param i symbol standing for the loop index
   * @param f loop body
   * @return the accumulated overlap symbols, the (result, statements) of f(i), and
   *         the (result, statements, substitution) of f(i+1)
   */
  override def findOverlap(i: Sym[Int], f: Rep[Int] => Rep[Unit]) = {
    val save = context
    val (r0,stms0) = reifySubGraph(f(i))
    // All symbols defined by the statements of f(i).
    val defs = stms0.flatMap(_.lhs)

    // Replays the body for i+n inside a nested scope. `last` holds the statements of the
    // previous replay, `acc` the triplets found so far (newest first), and `overlap`
    // the overlap symbols collected so far.
    def step(n: Int, last: List[Stm], acc: List[Triplet], overlap: List[Sym[Any]]): (Triplet, List[Sym[Any]]) = {
      val (res: (Triplet, List[Sym[Any]]), _) = reifySubGraph {
        reflectSubGraph(last)
        context = save
        val ((ri, substi), stmsi) = reifySubGraph(trans.withSubstScope(i -> (i+n)) {
          stms0.foreach(s => trans.traverseStm(s))
          (trans(r0), trans.subst)
        })
        val current: Triplet = (ri, stmsi, substi)

        val overlapi = stmsi.flatMap{ case TP(s,d) => syms(d) filter (defs contains _)}.distinct
        if (overlapi.nonEmpty)
          step(n+1, stmsi, current::acc, (overlap++overlapi).distinct)
        else {
          log("stopping at "+n)
          // acc is built by prepending, so its last element is the f(i+1) triplet.
          // If even f(i+1) has no overlap, acc is still empty and `current` is the
          // f(i+1) triplet itself, so return that instead of failing on an empty list.
          (acc.lastOption.getOrElse(current), overlap)
        }
      }
      res
    }
    val ((r1, stms1, subst1), overlap0) = step(1, stms0, Nil, Nil)
    context = save
    (overlap0, (r0, stms0), (r1, stms1, subst1))
  }
}

Warmup

/** Warm-up example: element `i` of the result is `(2*i+3) + (2*(i+1)+3)`. */
trait SlidingWarmup extends Sliding {
  /**
   * @param n staged number of elements
   * @return staged array of `n` elements built through `sliding`
   */
  def snippet(n: Rep[Int]): Rep[Array[Int]] = {
    def affine(k: Rep[Int]) = 2*k+3
    n sliding { idx =>
      // Staged in the same left-to-right order as a single `affine(idx) + affine(idx+1)`.
      val here = affine(idx)
      val next = affine(idx+1)
      here + next
    }
  }
}
/** Driver for the [[SlidingWarmup]] snippet; abstract because `sliding` is supplied by mixing in one of the `*Exp` traits. */
abstract class SlidingWarmupDriver extends DslDriver[Int,Array[Int]] with SlidingWarmup
/** Checks the code generated for the warm-up snippet and that all sliding strategies agree. */
class SlidingWarmupTest extends TutorialFunSuite {
  val under = "sliding"

  test("warmup without sliding") {
    check("0", (new SlidingWarmupDriver with NoSlidingExp).code)
  }

  test("warmup with sliding") {
    check("1", (new SlidingWarmupDriver with SlidingExp).code)
  }

  test("warmup with multi sliding") {
    check("1", (new SlidingWarmupDriver with SlidingMultiExp).code) // same as single sliding
  }

  test("warmup equal") {
    val baseline = new SlidingWarmupDriver with NoSlidingExp
    val single = new SlidingWarmupDriver with SlidingExp
    val multi = new SlidingWarmupDriver with SlidingMultiExp
    val input = 5
    // Compare results through their comma-separated rendering.
    def render(driver: SlidingWarmupDriver): String = driver.eval(input).mkString(",")
    assert(render(baseline) == render(single))
    assert(render(baseline) == render(multi))
  }
}

Generated code without sliding

/*****************************************
Emitting Generated Code
*******************************************/
class Snippet extends ((Int)=>(Array[Int])) {
  def apply(x0:Int): Array[Int] = {
    val x1 = new Array[Int](x0)
    var x3 : Int = 0
    val x12 = while (x3 < x0) {
      val x4 = 2 * x3
      val x5 = x4 + 3
      val x6 = x3 + 1
      val x7 = 2 * x6
      val x8 = x7 + 3
      val x9 = x5 + x8
      val x10 = x1(x3) = x9
      x3 = x3 + 1
    }
    x1
  }
}
/*****************************************
End of Generated Code
*******************************************/

Generated code with sliding

/*****************************************
Emitting Generated Code
*******************************************/
class Snippet extends ((Int)=>(Array[Int])) {
  def apply(x0:Int): Array[Int] = {
    val x1 = new Array[Int](x0)
    val x20 = x0 > 0
    val x42 = if (x20) {
      val x21 = x1(0) = 8
      var x22: Int = 5
      var x23: Int = 1
      var x25 : Int = 1
      val x40 = while (x25 < x0) {
        // variable reads
        val x27 = x22
        val x28 = x23
        // computation
        val x31 = x25 + 1
        val x32 = 2 * x31
        val x33 = x32 + 3
        val x34 = x27 + x33
        val x35 = x1(x28) = x34
        // variable writes
        x22 = x33
        x23 = x31
        x25 = x25 + 1
      }
      x40
    } else {
      ()
    }
    x1
  }
}
/*****************************************
End of Generated Code
*******************************************/

Generated code with multi sliding

/*****************************************
Emitting Generated Code
*******************************************/
class Snippet extends ((Int)=>(Array[Int])) {
  def apply(x0:Int): Array[Int] = {
    val x1 = new Array[Int](x0)
    val x20 = x0 > 0
    val x46 = if (x20) {
      val x21 = x1(0) = 8
      var x22: Int = 5
      var x23: Int = 1
      var x25 : Int = 0
      val x44 = while (x25 < x0) {
        // unrolled for k=0
        val x27 = x25 + 1
        val x28 = x27 < x0
        val x42 = if (x28) {
          // variable reads
          val x30 = x22
          val x31 = x23
          // computation
          val x33 = x25 + 2
          val x34 = 2 * x33
          val x35 = x34 + 3
          val x36 = x30 + x35
          val x37 = x1(x31) = x36
          // variable writes
          x22 = x35
          x23 = x33
          ()
        } else {
          ()
        }
        x25 = x25 + 1
      }
      x44
    } else {
      ()
    }
    x1
  }
}
/*****************************************
End of Generated Code
*******************************************/

Comments? Suggestions for improvement? View this file on GitHub.