
Getting Started


Staging and LMS

What is staging? The idea behind staging is to delay the computation of certain expressions by generating code that computes them later. The benefit is abstraction without regret: high-level programming abstractions are used once, to structure the generated code, instead of being paid for on every execution.
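To make the idea concrete before introducing LMS, here is a minimal sketch in plain Scala that does not use LMS at all. The names StagingSketch, powerCode and pow3 are hypothetical, and code is represented naively as a string. The recursion on n runs in the first stage; only multiplications remain in the text that a second stage would compile and run.

```scala
object StagingSketch {
  // Stage 1: runs now and returns the *source text* of an expression
  // computing base^n. The recursion on n does not appear in the output.
  def powerCode(base: String, n: Int): String =
    if (n == 0) "1"
    else s"($base * ${powerCode(base, n - 1)})"

  // Stage 2 would compile and run this text; here it is only built.
  val generated: String =
    s"def pow3(x: Int): Int = ${powerCode("x", 3)}"
}
```

Here StagingSketch.generated is the string `def pow3(x: Int): Int = (x * (x * (x * 1)))`. Strings give no type safety and no sharing of subexpressions, which is part of what a typed approach such as LMS addresses.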

Lightweight Modular Staging (LMS) is a staging technique driven by types. In addition to the staging aspect, the technique is lightweight, because it is purely library based. It is also modular, because features can be mixed and matched and the framework is easy to extend.

My First Multi-Stage Program

This tutorial is a literate Scala file. We invite you to clone the GitHub repo and play with the code as you follow along.

package scala.lms.tutorial

import scala.lms.common._

class GettingStartedTest extends TutorialFunSuite {
  val under = "dslapi"


Rep[T] vs T

In LMS, Rep[T] represents a delayed computation of type T. Thus, during staging, an expression of type Rep[T] becomes part of the generated code, while an expression of bare type T is evaluated immediately and its value appears as a constant in the generated code. For core Scala features, adding Rep types should be enough to build a program generator, as we will see later.

  test("1") {
    val snippet = new DslDriver[Int,Int] {
      def snippet(x: Rep[Int]) = {

        def compute(b: Boolean): Rep[Int] = {
          // the if is executed in the first stage
          if (b) 1 else x
        }
        compute(true)+compute(1==1)

      }
    }
    check("1", snippet.code)
    assert(snippet.eval(0) === 2)
  }
/*****************************************
Emitting Generated Code
*******************************************/
class Snippet extends ((Int)=>(Int)) {
  def apply(x0:Int): Int = {
    2
  }
}
/*****************************************
End of Generated Code
*******************************************/

Contrast the snippet above, where b is a Boolean, with the snippet below, where b is a Rep[Boolean]. When b is a Boolean, the expression if (b) 1 else x is executed at staging time. When b is a Rep[Boolean], it is delayed, and code is generated for the if expression: the actual value of b is not known at staging time, only when the generated code is executed.

  test("2") {
    val snippet = new  DslDriver[Int,Int] {
      def snippet(x: Rep[Int]) = {

        def compute(b: Rep[Boolean]): Rep[Int] = {
          // the if is deferred to the second stage
          if (b) 1 else x
        }
        compute(x==1)

      }
    }
    check("2", snippet.code)
    assert(snippet.eval(2) === 2)
  }
/*****************************************
Emitting Generated Code
*******************************************/
class Snippet extends ((Int)=>(Int)) {
  def apply(x0:Int): Int = {
    val x1 = x0 == 1
    val x2 = if (x1) {
      1
    } else {
      x0
    }
    x2
  }
}
/*****************************************
End of Generated Code
*******************************************/
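The contrast between the two snippets can be mimicked with a toy expression type. This is only a sketch under simplifying assumptions: Exp is a hypothetical, untyped stand-in for Rep[T], and show plays the role of the code generator. None of these names are part of LMS.

```scala
object RepSketch {
  // Hypothetical stand-in for Rep[T]: a tree describing code to run later.
  sealed trait Exp
  final case class Const(value: Int) extends Exp
  final case class Sym(name: String) extends Exp
  final case class Equal(lhs: Exp, rhs: Exp) extends Exp
  final case class If(cond: Exp, thenp: Exp, elsep: Exp) extends Exp

  // b is a plain Boolean: the branch is chosen now, no If node is built.
  def computeNow(b: Boolean, x: Exp): Exp = if (b) Const(1) else x

  // b is a delayed expression: all we can do is record the conditional.
  def computeLater(b: Exp, x: Exp): Exp = If(b, Const(1), x)

  // Pretty-printer playing the role of the code generator.
  def show(e: Exp): String = e match {
    case Const(v)    => v.toString
    case Sym(n)      => n
    case Equal(l, r) => s"${show(l)} == ${show(r)}"
    case If(c, t, f) => s"if (${show(c)}) ${show(t)} else ${show(f)}"
  }
}
```

With x0 as the input symbol, show(computeNow(true, Sym("x0"))) gives `1`, while show(computeLater(Equal(Sym("x0"), Const(1)), Sym("x0"))) gives `if (x0 == 1) 1 else x0`. In LMS the same source text, if (b) 1 else x, covers both cases, and the type of b selects between them.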

Rep[A => B] vs Rep[A]=>Rep[B]

In the previous snippets, we can already see some abstraction without regret: the compute function exists only at staging time and is not part of the generated code. More generally, we can freely use abstractions to structure and compose the staged program, and these abstractions leave no trace in the generated code when their type is a bare T as opposed to a Rep[T]. In the second snippet, compute has the type Rep[Boolean] => Rep[Int], not Rep[Boolean => Int]: its type already tells us that the function is known at staging time.

Similarly, below, the recursive power function and the helper square function exist only at staging time.

  test("power") {
    val snippet = new DslDriver[Int,Int] {
      def square(x: Rep[Int]): Rep[Int] = x*x

      def power(b: Rep[Int], n: Int): Rep[Int] =
        if (n == 0) 1
        else if (n % 2 == 0) square(power(b, n/2))
        else b * power(b, n-1)

      def snippet(b: Rep[Int]) =
        power(b, 7)

    }
    check("power", snippet.code)
    assert(snippet.eval(2) === 128)
  }

Because of common subexpression elimination, we get reuse of the square argument for free.

/*****************************************
Emitting Generated Code
*******************************************/
class Snippet extends ((Int)=>(Int)) {
  def apply(x0:Int): Int = {
    val x1 = x0 * x0
    val x2 = x0 * x1
    val x3 = x2 * x2
    val x4 = x0 * x3
    x4
  }
}
/*****************************************
End of Generated Code
*******************************************/
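The shape of this output can be reproduced with a small sketch of a name-binding builder. It is not how LMS is implemented. Builder, times and lines are hypothetical names, and the rule that x * 1 simplifies to x is an assumption chosen to match the listing above. Each product is bound to a fresh name, and a table returns the existing name when the same product is requested again.

```scala
import scala.collection.mutable

object CseSketch {
  // Hypothetical mini-IR: every product gets a fresh name, and a table
  // remembers which product already has one.
  final class Builder {
    private val known = mutable.LinkedHashMap.empty[(String, String), String]
    private var counter = 0

    // Returns the name bound to a * b, creating it only if it is new.
    def times(a: String, b: String): String =
      known.getOrElseUpdate((a, b), { counter += 1; s"x$counter" })

    // The bindings emitted so far, in creation order.
    def lines: List[String] =
      known.toList.map { case ((a, b), name) => s"val $name = $a * $b" }
  }

  // square receives a *name*, so its argument is computed only once.
  def square(b: Builder, x: String): String = b.times(x, x)

  def power(b: Builder, base: String, n: Int): String =
    if (n == 0) "1"
    else if (n % 2 == 0) square(b, power(b, base, n / 2))
    else {
      val rest = power(b, base, n - 1)
      if (rest == "1") base else b.times(base, rest) // assumed: x * 1 is x
    }
}
```

Calling CseSketch.power(builder, "x0", 7) on a fresh builder returns x4, and builder.lines yields the same four bindings as the generated code above. Asking for times("x0", "x0") again afterwards returns x1 without adding a new binding.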

We could also create a generated square function, of type Rep[Int=>Int] instead of Rep[Int]=>Rep[Int].

  test("power with fun square") {
    val snippet = new DslDriver[Int,Int] {
      def square: Rep[Int=>Int] = fun {x => x*x}

      def power(b: Rep[Int], n: Int): Rep[Int] =
        if (n == 0) 1
        else if (n % 2 == 0) square(power(b, n/2))
        else b * power(b, n-1)

      def snippet(b: Rep[Int]) =
        power(b, 7)

    }
    check("powerfunsquare", snippet.code)
    assert(snippet.eval(2) === 128)
  }

The code we get is in fact slightly less efficient, because each use of square is now a call to the generated function rather than an inlined multiplication.

/*****************************************
Emitting Generated Code
*******************************************/
class Snippet extends ((Int)=>(Int)) {
  def apply(x0:Int): Int = {
    val x1 = {x2: (Int) =>
      val x3 = x2 * x2
      x3: Int
    }
    val x4 = x1(x0)
    val x5 = x0 * x4
    val x6 = x1(x5)
    val x7 = x0 * x6
    x7
  }
}
/*****************************************
End of Generated Code
*******************************************/
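The difference between the two function types can be summarized with a string-based sketch. This is only an illustration, assuming a hypothetical Code alias in place of Rep[Int]; unlike LMS, the inlining version below duplicates the text of its argument instead of binding it to a name.

```scala
object FunctionSketch {
  type Code = String // hypothetical stand-in for Rep[Int]

  // Analogue of Rep[Int] => Rep[Int]: an ordinary Scala function over code
  // fragments. Calling it pastes its body into the output.
  def squareInline(x: Code): Code = s"($x * $x)"

  // Analogue of Rep[Int => Int]: the function itself is generated code,
  // so callers can only emit an application of its name.
  val squareDef: Code = "val square = (y: Int) => y * y"
  def squareCall(x: Code): Code = s"square($x)"

  val inlined: Code = squareInline("x0")
  val called: Code = s"$squareDef; ${squareCall("x0")}"
}
```

Here inlined is `(x0 * x0)`, with no trace of a function, while called is `val square = (y: Int) => y * y; square(x0)`, where the function definition and its call both survive into the generated code.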

Rep[Range] vs Range

Loops can be unrolled in the first stage, or be generated as loops in the second stage, driven by the type of the range they iterate over.

  test("range1") {
    val snippet = new DslDriver[Int,Unit] {
      def snippet(x: Rep[Int]) = comment("for", verbose = false) {

        for (i <- (0 until 3): Range) {
          println(i)
        }

      }
    }
    check("range1", snippet.code)
  }
      // generated code
      val x1 = println(0)
      val x2 = println(1)
      val x3 = println(2)
      ()
  test("range2") {
    val snippet = new DslDriver[Int,Unit] {
      def snippet(x: Rep[Int]) = comment("for", verbose = false) {

        for (i <- (0 until x): Rep[Range]) {
          println(i)
        }

      }
    }
    check("range2", snippet.code)
  }
}
      // generated code
      var x2 : Int = 0
      val x5 = while (x2 < x0) {
        val x3 = println(x2)
        x2 = x2 + 1
      }
      x5
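The two outcomes can be imitated in plain Scala with a string-based sketch. The names LoopSketch, unrolled and generatedLoop are hypothetical, and the emitted text only approximates what LMS prints. When the bound is a first-stage Int, the Scala loop runs during generation and emits one statement per iteration. When the bound is only a name, a single loop is emitted.

```scala
object LoopSketch {
  // Bound known in the first stage: the Scala loop runs now and emits
  // one statement per iteration.
  def unrolled(n: Int): List[String] =
    (0 until n).toList.map(i => s"println($i)")

  // Bound known only in the second stage: emit one loop over the name.
  def generatedLoop(bound: String): List[String] = List(
    "var i = 0",
    s"while (i < $bound) {",
    "  println(i)",
    "  i = i + 1",
    "}"
  )
}
```

LoopSketch.unrolled(3) yields the three println statements of the first listing, and LoopSketch.generatedLoop("x0") yields a while loop bounded by x0, as in the second.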

What's next?

Go back to the tutorial index or continue with the Shonan Challenge.
