In this tutorial, we show a solution to Shonan Challenge 3.3 Stencil.
For explanation of this solution, see the group discussion.
For further reference, see:
Outline:
package scala.lms.tutorial
import scala.lms.common.ForwardTransformer
import scala.reflect.SourceContext
We implement the challenge using an almost regular staged loop, except that it uses the `sliding` staging-time abstraction.
How does it work? We rely on the fact that LMS performs common subexpression elimination internally, and we insert variables for values that are re-used from one loop iteration to the next.
We compute the code for loop iterations i and i + 1. For iteration i + 1, we substitute i + 1 for i.
We peel the first loop iteration to initialize the variables.
Then we generate a modified loop that executes the remaining iterations. The loop body corresponds to iteration i + 1 above, plus the read and write statements that maintain the variables.
So with a twist on our general CSE facility, we implement CSE across loop iterations. The approach also works for window sizes larger than 1; we just need to unroll more iterations of the loop. In the challenge example, if we look at iteration i + 2, we find that it does not reuse any memory access operations from iteration i, so a larger window is not necessary in this case.
/** Shonan Challenge 3.3: a two-stage stencil over an array of doubles.
 *
 *  b(i) is built from the intermediate stages w1, wm and w2 and is written for every
 *  i in [2, n-2), so each output element reads v(i-2) .. v(i+2).
 *  The loop goes through `sliding`, so the mixed-in implementation decides whether
 *  values are shared between consecutive iterations.
 */
trait Stencil extends Sliding {
  def snippet(v: Rep[Array[Double]]): Rep[Array[Double]] = {
    val n = v.length
    val out = NewArray[Double](n)
    // Stage-time helpers: every call re-emits the array reads and arithmetic below;
    // sharing is left to LMS's CSE (and, across iterations, to `sliding`).
    def a(k: Rep[Int]) = v(k)
    def w1(k: Rep[Int]) = a(k) * a(k+1)
    def wm(k: Rep[Int]) = a(k) - w1(k) + w1(k-1)
    def w2(k: Rep[Int]) = wm(k) * wm(k+1)
    def b(k: Rep[Int]) = wm(k) - w2(k) + w2(k-1)
    (2 until n-2).sliding.foreach { i => out(i) = b(i) }
    out
  }
}
// Common driver for the stencil tests; each test mixes in one `sliding` implementation
// (NoSlidingExp, SlidingExp or SlidingMultiExp).
abstract class StencilDriver extends DslDriver[Array[Double],Array[Double]] with Stencil
/** Golden-file checks for the generated stencil code, plus a check that all three
 *  `sliding` implementations compute the same result on a sample input.
 */
class StencilTest extends TutorialFunSuite {
  val under = "stencil"

  test("stencil without sliding") {
    check("0", (new StencilDriver with NoSlidingExp).code)
  }

  test("stencil with sliding") {
    check("1", (new StencilDriver with SlidingExp).code)
  }

  test("stencil with multi sliding") {
    check("2", (new StencilDriver with SlidingMultiExp).code)
  }

  test("stencil equal") {
    val input = Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0)
    // compare rendered outputs so that a mismatch shows both arrays in the failure message
    def render(driver: StencilDriver): String = driver.eval(input).mkString(",")
    val baseline = render(new StencilDriver with NoSlidingExp)
    assert(baseline == render(new StencilDriver with SlidingExp))
    assert(baseline == render(new StencilDriver with SlidingMultiExp))
  }
}
/*****************************************
Emitting Generated Code
*******************************************/
class Snippet extends ((Array[Double])=>(Array[Double])) {
def apply(x0:Array[Double]): Array[Double] = {
val x1 = x0.length
val x2 = new Array[Double](x1)
val x3 = x1 - 2
var x5 : Int = 2
val x40 = while (x5 < x3) {
val x6 = x0(x5)
val x7 = x5 + 1
val x8 = x0(x7)
val x9 = x6 * x8
val x10 = x6 - x9
val x11 = x5 - 1
val x12 = x0(x11)
val x13 = x11 + 1
val x14 = x0(x13)
val x15 = x12 * x14
val x16 = x10 + x15
val x17 = x7 + 1
val x18 = x0(x17)
val x19 = x8 * x18
val x20 = x8 - x19
val x21 = x20 + x9
val x22 = x16 * x21
val x23 = x16 - x22
val x24 = x12 - x15
val x25 = x11 - 1
val x26 = x0(x25)
val x27 = x25 + 1
val x28 = x0(x27)
val x29 = x26 * x28
val x30 = x24 + x29
val x31 = x13 + 1
val x32 = x0(x31)
val x33 = x14 * x32
val x34 = x14 - x33
val x35 = x34 + x15
val x36 = x30 * x35
val x37 = x23 + x36
val x38 = x2(x5) = x37
x5 = x5 + 1
}
x2
}
}
/*****************************************
End of Generated Code
*******************************************/
/*****************************************
Emitting Generated Code
*******************************************/
class Snippet extends ((Array[Double])=>(Array[Double])) {
def apply(x0:Array[Double]): Array[Double] = {
val x1 = x0.length
val x2 = new Array[Double](x1)
val x3 = x1 - 2
val x48 = x3 > 2
val x105 = if (x48) {
val x49 = x0(2)
val x50 = x0(3)
val x51 = x49 * x50
val x52 = x49 - x51
val x53 = x0(1)
val x54 = x53 * x49
val x55 = x52 + x54
val x56 = x0(4)
val x57 = x50 * x56
val x58 = x50 - x57
val x59 = x58 + x51
val x60 = x55 * x59
val x61 = x55 - x60
val x62 = x53 - x54
val x63 = x0(0)
val x64 = x63 * x53
val x65 = x62 + x64
val x66 = x65 * x55
val x67 = x61 + x66
val x68 = x2(2) = x67
var x69: Double = x56
var x70: Double = x57
var x71: Double = x59
var x72: Double = x60
var x73: Int = 3
var x74: Int = 4
var x76 : Int = 3
val x103 = while (x76 < x3) {
// variable reads
val x78 = x69
val x79 = x70
val x80 = x71
val x81 = x72
val x82 = x73
val x83 = x74
// computation
val x86 = x76 + 2
val x87 = x0(x86)
val x88 = x78 * x87
val x89 = x78 - x88
val x90 = x89 + x79
val x91 = x80 * x90
val x92 = x80 - x91
val x93 = x92 + x81
val x94 = x2(x82) = x93
// variable writes
x69 = x87
x70 = x88
x71 = x90
x72 = x91
x73 = x83
x74 = x86
x76 = x76 + 1
}
x103
} else {
()
}
x2
}
}
/*****************************************
End of Generated Code
*******************************************/
/*****************************************
Emitting Generated Code
*******************************************/
class Snippet extends ((Array[Double])=>(Array[Double])) {
def apply(x0:Array[Double]): Array[Double] = {
val x1 = x0.length
val x2 = new Array[Double](x1)
val x3 = x1 - 2
val x57 = x3 > 2
val x114 = if (x57) {
val x58 = x0(2)
val x59 = x0(3)
val x60 = x58 * x59
val x61 = x58 - x60
val x62 = x0(1)
val x63 = x62 * x58
val x64 = x61 + x63
val x65 = x0(4)
val x66 = x59 * x65
val x67 = x59 - x66
val x68 = x67 + x60
val x69 = x64 * x68
val x70 = x64 - x69
val x71 = x62 - x63
val x72 = x0(0)
val x73 = x72 * x62
val x74 = x71 + x73
val x75 = x74 * x64
val x76 = x70 + x75
val x77 = x2(2) = x76
var x78: Double = x65
var x79: Double = x66
var x80: Double = x68
var x81: Double = x69
var x82: Int = 3
var x83: Int = 4
var x85 : Int = 3
val x112 = while (x85 < x3) {
// variable reads
val x87 = x78
val x88 = x79
val x89 = x80
val x90 = x81
val x91 = x82
val x92 = x83
// computation
val x95 = x85 + 2
val x96 = x0(x95)
val x97 = x87 * x96
val x98 = x87 - x97
val x99 = x98 + x88
val x100 = x89 * x99
val x101 = x89 - x100
val x102 = x101 + x90
val x103 = x2(x91) = x102
// variable writes
x78 = x96
x79 = x97
x80 = x99
x81 = x100
x82 = x92
x83 = x95
x85 = x85 + 1
}
x112
} else {
()
}
x2
}
}
/*****************************************
End of Generated Code
*******************************************/
Note that for the examples here, the multi-sliding variant does not improve upon the number of shared variables.
/** Interface of the `sliding` staging-time loop abstraction.
 *
 *  `n sliding f` builds an array of length n with a(i) = f(i), and
 *  `for (i <- r.sliding) body` runs `body` over the range r. Both forms go through the
 *  abstract `sliding`, so the mixed-in implementation decides how the loop is staged
 *  (e.g. NoSlidingExp, SlidingExp, SlidingMultiExp).
 */
trait Sliding extends Dsl {
  /** Allocates an array of length n and fills a(i) = f(i) for i in [0, n) via `sliding`. */
  def infix_sliding[T:Typ](n: Rep[Int], f: Rep[Int] => Rep[T]): Rep[Array[T]] = {
    val a = NewArray[T](n)
    sliding(0,n)(i => a(i) = f(i))
    a
  }

  /** Result of `r.sliding`: a range whose `foreach` is routed through `sliding`.
   *
   *  This is a named class rather than the structural type `new { def foreach ... }`.
   *  A call to `foreach` on a structural type is dispatched via Java reflection (and needs
   *  the `reflectiveCalls` language feature); here it is a plain method call.
   */
  final class SlidingRange(r: Rep[Range]) {
    def foreach(f: Rep[Int] => Rep[Unit]): Rep[Unit] =
      sliding(r.start, r.end)(f)
  }

  /** Enables `for (i <- r.sliding) ...`. */
  def infix_sliding(r: Rep[Range]): SlidingRange = new SlidingRange(r)

  /** Runs f(i) for every i in [start, end); implementations choose how the loop is staged. */
  def sliding(start: Rep[Int], end: Rep[Int])(f: Rep[Int] => Rep[Unit]): Rep[Unit]
}
/** Baseline implementation: `sliding` is an ordinary staged range loop that calls f once
 *  per index, with no sharing of values between iterations.
 */
trait NoSlidingExp extends DslExp with Sliding {
  def sliding(start: Rep[Int], end: Rep[Int])(f: Rep[Int] => Rep[Unit]): Rep[Unit] =
    for (k <- start until end) f(k)
}
// Implementation of `sliding` that carries values computed in iteration i and reused by
// iteration i+1 across iterations in variables, instead of recomputing them.
trait SlidingExp extends DslExp with Sliding {
We use an LMS transformer to evaluate the loop body across various iterations and transformations.
// Transformer that mirrors the reified loop body under substitutions such as i -> i+1;
// its IR is this trait instance itself.
object trans extends ForwardTransformer {
val IR: SlidingExp.this.type = SlidingExp.this
}
/** Writes one diagnostic line, prefixed with "sliding log: ", directly to System.out. */
def log(x: Any): Unit =
  System.out.println(s"sliding log: $x")
Some arithmetic rewrites to maximize sliding sharing.
/** Reassociates additions with constant right operands, so that index expressions such as
 *  (i+1)+1 and i+2 become the same node and can be shared between iterations.
 */
override def int_plus(lhs: Exp[Int], rhs: Exp[Int])(implicit pos: SourceContext): Exp[Int] =
  ((lhs,rhs) match {
    // (x+y)+z --> x+(y+z)
    case (Def(IntPlus(x:Exp[Int],Const(y:Int))), Const(z:Int)) => int_plus(x, unit(y+z))
    // (x-y)+z --> x-(y-z)
    case (Def(IntMinus(x:Exp[Int],Const(y:Int))), Const(z:Int)) => int_minus(x, unit(y-z))
    // x+(-z) --> x-z. Int.MinValue is excluded: -Int.MinValue == Int.MinValue, so the
    // rewrite would hand int_minus a negative constant again. That rewrite sends it straight
    // back here, and the two would recurse forever.
    case (x: Exp[Int], Const(z:Int)) if Int.MinValue < z && z < 0 => int_minus(x, unit(-z))
    case _ => super.int_plus(lhs,rhs)
  }).asInstanceOf[Exp[Int]]
/** Reassociates subtractions with constant right operands, mirroring `int_plus`, so that
 *  equivalent index expressions such as (i-1)-1 and i-2 become the same node.
 */
override def int_minus(lhs: Exp[Int], rhs: Exp[Int])(implicit pos: SourceContext): Exp[Int] =
  ((lhs,rhs) match {
    // (x-y)-z --> x-(y+z)
    case (Def(IntMinus(x:Exp[Int],Const(y:Int))), Const(z:Int)) => int_minus(x, unit(y+z))
    // (x+y)-z --> x+(y-z)
    case (Def(IntPlus(x:Exp[Int],Const(y:Int))), Const(z:Int)) => int_plus(x, unit(y-z))
    // x-(-z) --> x+z. Int.MinValue is excluded: -Int.MinValue == Int.MinValue, so the
    // rewrite would hand int_plus a negative constant again. That rewrite sends it straight
    // back here, and the two would recurse forever.
    case (x: Exp[Int], Const(z:Int)) if Int.MinValue < z && z < 0 => int_plus(x, unit(-z))
    case _ => super.int_minus(lhs,rhs)
  }).asInstanceOf[Exp[Int]]
// Substitution from original symbols to their replacements (the type of trans.subst as used below).
type Subst = scala.collection.immutable.Map[Exp[Any],Exp[Any]]
// A reified loop body: its result and its statements.
type Doublet = (Rep[Unit], List[Stm])
// A reified, mirrored loop body: result, statements, and the substitution that produced it.
type Triplet = (Rep[Unit], List[Stm], Subst)
Find the overlapping symbols: defined by f(i), used by f(i+1), f(i+2), … We create a helper function so that we can generalize to a fixpoint calculation in the subclass below.
/** Finds the overlap symbols: those defined by f(i) that are used by f(i+1) or f(i+2).
 *
 *  @param i symbolic loop index that f is applied to
 *  @param f loop body
 *  @return (overlap symbols, f(i) reified as (result, statements),
 *          f(i+1) reified as (result, statements, substitution applied to f(i)'s symbols))
 */
def findOverlap(i: Sym[Int], f: Rep[Int] => Rep[Unit]): (List[Sym[Any]], Doublet, Triplet) = {
We evaluate the loop contents f(i), then f(i+1), then f(i+2). To evaluate f(i+1) and f(i+2), we reuse the reified subgraph for f(i), but substitute i+1 or i+2 for i. Each time, we first reflect the statements from the previous iteration, so that common subexpressions can be shared.
// remember the current `context`; it is restored after each nested reification below
val save = context
// evaluate loop contents f(i)
val (r0,stms0) = reifySubGraph(f(i))
// Only the results, statement lists and substitutions of f(i+1) and f(i+2) are kept;
// the statement list of this outer reification is dropped via `_`.
val (((r1,stms1,subst1),(r2,stms2,subst2)), _) = reifySubGraph {
// make f(i)'s statements visible so that mirroring f(i+1) can share (CSE) them
reflectSubGraph(stms0)
context = save
// evaluate loop contents f(i+1): mirror f(i)'s statements under i -> i+1;
// subst1 records where each of f(i)'s symbols was mapped
val ((r1,subst1),stms1) = reifySubGraph(trans.withSubstScope(i -> (i+1)) {
stms0.foreach(s=>trans.traverseStm(s))
(trans(r0), trans.subst)
})
// one level deeper: also reflect f(i+1)'s statements (inside the scope where f(i)'s
// were reflected), then mirror f(i+2)
val ((r2,stms2,subst2), _) = reifySubGraph {
reflectSubGraph(stms1)
context = save
// evaluate loop contents f(i+2)
val ((r2,subst2),stms2) = reifySubGraph(trans.withSubstScope(i -> (i+2)) {
stms0.foreach(s=>trans.traverseStm(s))
(trans(r0), trans.subst)
})
(r2,stms2,subst2)
}
((r1,stms1,subst1), (r2,stms2,subst2))
}
context = save
// all symbols defined by f(i)
val defs = stms0.flatMap(_.lhs)
Here, we find the overlap symbols: defined by f(i), used by f(i+1) and f(i+2).
// symbols defined by f(i) that occur as operands in f(i+1)'s (resp. f(i+2)'s) statements
val overlap01 = stms1.flatMap { case TP(s,d) => syms(d) filter (defs contains _) }.distinct
val overlap02 = stms2.flatMap { case TP(s,d) => syms(d) filter (defs contains _) }.distinct
We have not reached a fixed point yet if there are still overlapping symbols in the last iteration. We will fix this in the subclass below.
// NOTE(review): despite this message, overlap02 is still merged into overlap0 below.
if (overlap02.nonEmpty)
log("Overlap beyond a single loop iteration is ignored, since not yet implemented.")
// overlap01 first, then symbols that only f(i+2) reuses
val overlap0 = (overlap01++overlap02).distinct
// f(i+2)'s reification served only the analysis; callers get f(i) and f(i+1)
(overlap0, (r0, stms0), (r1, stms1, subst1))
}
// Staged loop over [start, end) that keeps the values findOverlap reports as reused
// between consecutive iterations in variables, instead of recomputing them.
def sliding(start: Rep[Int], end: Rep[Int])(f: Rep[Int] => Rep[Unit]): Rep[Unit] = {
// placeholder loop index; later substituted by `start` (peeled iteration) and `j-1` (loop)
val i = fresh[Int]
val (overlap0, (r0, stms0), (r1, stms1, subst1)) = findOverlap(i, f)
// each overlap symbol's counterpart in f(i+1): the value its variable must hold next time
val overlap1 = overlap0 map subst1
Build a variable for each overlap symbol. Initialize the variables by peeling the first loop iteration.
// staged guard: the peeled iteration and the loop are emitted under a runtime end > start check
if (end > start) {
// peeled first iteration: mirror f(i) with i := start; substX maps each f(i) symbol
// to its value in this iteration
val (rX,substX) = trans.withSubstScope(i -> start) {
stms0.foreach(s=>trans.traverseStm(s))
(trans(r0), trans.subst)
}
// one variable per overlap symbol, initialised from the peeled iteration
// NOTE(review): x.pos.head assumes each overlap symbol has at least one source position
val vars = overlap0.map{x => var_new(substX(x))(x.tp,x.pos.head)}
Now generate the loop:
// remaining iterations j = start+1 until end; each emits f(i+1) with i := j-1, i.e. f(j)
for (j <- (start + unit(1)) until end) {
generate_comment("variable reads")
// read every variable; it stands in for its overlap symbol in the mirrored body
val reads = (overlap0 zip vars).map(p => (p._1, readVar(p._2)(p._1.tp,p._1.pos.head)))
generate_comment("computation")
// mirror f(i+1) with the overlap symbols replaced by the reads and i replaced by j-1
val (ri, substY1: Subst) = trans.withSubstScope((reads:+(i->(j-unit(1)))): _*) {
stms1.foreach(s=>trans.traverseStm(s))
(trans(r1), trans.subst)
}
generate_comment("variable writes")
// store each overlap symbol's f(i+1) counterpart from this iteration for the next one
val writes = (overlap1 zip vars).map{p =>
(p._1, var_assign(p._2, substY1(p._1))(p._1.tp,p._1.pos.head))}
} // end for
} // end if (end > start)
} // end sliding
} // end trait SlidingExp
// Variant of SlidingExp that overrides findOverlap to examine f(i+1), f(i+2), ... for reuse
// of f(i)'s symbols, rather than only f(i+1) and f(i+2).
trait SlidingMultiExp extends SlidingExp with DslExp with Sliding {
Find overlapping symbols between iterations until there are no new ones.
/** Finds the symbols defined by f(i) that are reused by any of f(i+1), f(i+2), ...
 *
 *  Each step n mirrors f(i) under i -> i+n, with the statements of all earlier iterations
 *  reflected, and collects the f(i) symbols the mirrored body uses. The search stops at
 *  the first step that contributes no new symbol. The overlap set only grows and is
 *  bounded by f(i)'s definitions, so this always terminates. That includes a loop-invariant
 *  value of f(i) that every later iteration reuses, which the previous "stop when empty"
 *  rule would recurse on forever.
 *
 *  @return (overlap symbols, f(i) reified as (result, statements),
 *          f(i+1) reified as (result, statements, substitution))
 */
override def findOverlap(i: Sym[Int], f: Rep[Int] => Rep[Unit]) = {
  val save = context
  // evaluate loop contents f(i)
  val (r0,stms0) = reifySubGraph(f(i))
  val defs = stms0.flatMap(_.lhs)
  // acc holds the mirrored iterations seen so far, most recent first;
  // overlap holds the distinct overlap symbols found so far, in discovery order
  def step(n: Int, last: List[Stm], acc: List[Triplet], overlap: List[Sym[Any]]): (Triplet, List[Sym[Any]]) = {
    val (res: (Triplet, List[Sym[Any]]), _) = reifySubGraph {
      reflectSubGraph(last)
      context = save
      // evaluate loop contents f(i+n)
      val ((ri, substi), stmsi) = reifySubGraph(trans.withSubstScope(i -> (i+n)) {
        stms0.foreach(s => trans.traverseStm(s))
        (trans(r0), trans.subst)
      })
      val current: Triplet = (ri, stmsi, substi)
      val overlapi = stmsi.flatMap{ case TP(s,d) => syms(d) filter (defs contains _)}.distinct
      val newOverlap = overlapi filterNot (overlap contains _)
      if (newOverlap.nonEmpty)
        step(n+1, stmsi, current::acc, overlap ++ newOverlap)
      else {
        log("stopping at "+n)
        // acc is built by prepending, so its last element is the f(i+1) triplet. acc is
        // empty when f(i+1) reuses nothing from f(i); then `current` is that triplet
        // (previously acc.last threw on Nil in this case).
        (acc.lastOption.getOrElse(current), overlap)
      }
    }
    res
  }
  val ((r1, stms1, subst1), overlap0) = step(1, stms0, Nil, Nil)
  context = save
  (overlap0, (r0, stms0), (r1, stms1, subst1))
}
}
/** Warm-up example for `sliding`: result(k) = g(k) + g(k+1) with g(k) = 2*k + 3,
 *  for k in [0, n). The g(k+1) of one iteration is the g(k) of the next, which is
 *  the value a sliding implementation can carry over instead of recomputing.
 */
trait SlidingWarmup extends Sliding {
  def snippet(n: Rep[Int]): Rep[Array[Int]] = {
    def g(k: Rep[Int]) = 2*k+3
    n.sliding { k => g(k) + g(k+1) }
  }
}
// Common driver for the warm-up tests; each test mixes in one `sliding` implementation
// (NoSlidingExp, SlidingExp or SlidingMultiExp).
abstract class SlidingWarmupDriver extends DslDriver[Int,Array[Int]] with SlidingWarmup
/** Golden-file checks for the warm-up snippet, plus a check that all three `sliding`
 *  implementations compute the same array.
 */
class SlidingWarmupTest extends TutorialFunSuite {
  val under = "sliding"

  test("warmup without sliding") {
    check("0", (new SlidingWarmupDriver with NoSlidingExp).code)
  }

  test("warmup with sliding") {
    check("1", (new SlidingWarmupDriver with SlidingExp).code)
  }

  test("warmup with multi sliding") {
    // expected to generate exactly the same code as single sliding
    check("1", (new SlidingWarmupDriver with SlidingMultiExp).code)
  }

  test("warmup equal") {
    val input = 5
    def render(driver: SlidingWarmupDriver): String = driver.eval(input).mkString(",")
    val baseline = render(new SlidingWarmupDriver with NoSlidingExp)
    assert(baseline == render(new SlidingWarmupDriver with SlidingExp))
    assert(baseline == render(new SlidingWarmupDriver with SlidingMultiExp))
  }
}
/*****************************************
Emitting Generated Code
*******************************************/
class Snippet extends ((Int)=>(Array[Int])) {
def apply(x0:Int): Array[Int] = {
val x1 = new Array[Int](x0)
var x3 : Int = 0
val x12 = while (x3 < x0) {
val x4 = 2 * x3
val x5 = x4 + 3
val x6 = x3 + 1
val x7 = 2 * x6
val x8 = x7 + 3
val x9 = x5 + x8
val x10 = x1(x3) = x9
x3 = x3 + 1
}
x1
}
}
/*****************************************
End of Generated Code
*******************************************/
/*****************************************
Emitting Generated Code
*******************************************/
class Snippet extends ((Int)=>(Array[Int])) {
def apply(x0:Int): Array[Int] = {
val x1 = new Array[Int](x0)
val x20 = x0 > 0
val x42 = if (x20) {
val x21 = x1(0) = 8
var x22: Int = 5
var x23: Int = 1
var x25 : Int = 1
val x40 = while (x25 < x0) {
// variable reads
val x27 = x22
val x28 = x23
// computation
val x31 = x25 + 1
val x32 = 2 * x31
val x33 = x32 + 3
val x34 = x27 + x33
val x35 = x1(x28) = x34
// variable writes
x22 = x33
x23 = x31
x25 = x25 + 1
}
x40
} else {
()
}
x1
}
}
/*****************************************
End of Generated Code
*******************************************/
/*****************************************
Emitting Generated Code
*******************************************/
class Snippet extends ((Int)=>(Array[Int])) {
def apply(x0:Int): Array[Int] = {
val x1 = new Array[Int](x0)
val x20 = x0 > 0
val x46 = if (x20) {
val x21 = x1(0) = 8
var x22: Int = 5
var x23: Int = 1
var x25 : Int = 0
val x44 = while (x25 < x0) {
// unrolled for k=0
val x27 = x25 + 1
val x28 = x27 < x0
val x42 = if (x28) {
// variable reads
val x30 = x22
val x31 = x23
// computation
val x33 = x25 + 2
val x34 = 2 * x33
val x35 = x34 + 3
val x36 = x30 + x35
val x37 = x1(x31) = x36
// variable writes
x22 = x35
x23 = x33
()
} else {
()
}
x25 = x25 + 1
}
x44
} else {
()
}
x1
}
}
/*****************************************
End of Generated Code
*******************************************/
Comments? Suggestions for improvement? View this file on GitHub.