Tuesday, February 28, 2017

Implementation notes on my auto differentiation for weight sharing

Update: My Neural Network Posts

  • Basic Auto Differentiation: Post
  • Basic Auto Differentiation with Weight Sharing: Post
  • Auto Differentiation Code from Book: Post
  • Refactoring The Ideas: Post
In my last post I wrote about building a very crude homemade auto differentiation engine. As a side note I mentioned that if a node receives multiple incoming gradients we have to sum them, and I provided an option that sums the gradients and immediately pushes the result further downstream. However, in order to implement proper back propagation we have to wait until all incoming gradients have arrived and been summed up before the gradient is propagated further. This is especially important for weight sharing, which is needed to implement recurrent neural networks such as LSTMs [1] or more complex models such as the variational auto encoder [2].
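As a tiny sanity check of that rule, the following throwaway Scala snippet (independent of the engine, all names made up for illustration) compares the summed branch gradients of a shared value against a numerical derivative:

object SharedGradientCheck extends App {
  // if a value w is used in two places, f(w) = w * x1 + w * x2,
  // then df/dw = x1 + x2, i.e. the gradients of both branches are summed
  val (x1, x2) = (2.0, 3.0)
  def f(w: Double): Double = w * x1 + w * x2

  val eps       = 1e-6
  val numerical = (f(1.0 + eps) - f(1.0 - eps)) / (2 * eps) // central difference, ~ 5.0
  val summed    = x1 + x2                                   // branch gradient x1 plus branch gradient x2
  println(s"numerical: $numerical, summed: $summed")
}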

If you are interested in the full code with the examples, I uploaded a Scala version to Github.

Implementing the base computation node

The forward pass is the same as in the previous posts: each node simply computes
its activation from its child nodes' forward computations. In this case the
abstract class already holds the input nodes. The backward pass follows the
same idea, too. Based on the gradients from higher nodes and the local gradient
in each node, we push the error backwards through our computation. The only
difference is that I added a wait list which saves the names of all parent
nodes. Once all parents have sent their gradients, we sum them up and push the
error down. Furthermore, I added a recursive function that plots the compute
graph in dot format. The downside is that we now have to register all parents
at each node. During graph plotting the parent and wait relations are also
checked.


import breeze.linalg.DenseMatrix
import scala.collection.mutable

abstract class Node(val name: String, inputs: Vector[Node]) {
  /**
    * Names of parents who have to send a gradient before backwards
    * pass continues
    */
  val WaitFor   = mutable.Set.empty[String]

  /**
    * All parents gradients
    */
  val Gradients = mutable.HashMap.empty[String, DenseMatrix[Double]]

  /**
    * Set value of forwards computation
    */
  protected var Output: Option[DenseMatrix[Double]] = None

  /**
    * Should attach node type to name
    */
  def typename: String

  /**
    * Recursively plot parent child relationships in dot format.
    * Also output error if the waitlist does not include the expected parent
    */
  def graph(parent: String): List[String] = {
    if(!WaitFor.contains(parent) && parent.size > 1) throw new Exception(s"${parent} ${typename}_${name} connection not found")
    if(inputs.size == 0) List.empty[String]
    else {
      (for (in <- inputs) yield (s"${in.typename}_${in.name} -> ${typename}_${name};" :: in.graph(name))).flatten.toList
    }
  }

  def compute(x: Vector[DenseMatrix[Double]]): DenseMatrix[Double]

  def distribute(gradient: DenseMatrix[Double]): Unit

  def gradient: DenseMatrix[Double] = Gradients.map(_._2).reduce(_ + _)

  def register(parent: String): Unit = WaitFor += parent

  def reset: Unit = {
    Gradients.clear()
    Output = None
    inputs.foreach(_.reset)
  }

  def fwd: DenseMatrix[Double] = Output match {
    case Some(x) => x
    case None => {
      val x = compute(inputs.map(_.fwd))
      Output = Some(x)
      x
    }
  }

  def ready = (WaitFor & Gradients.keySet) == WaitFor

  def bwd(parent: String, grad: DenseMatrix[Double]): Unit = {
    Gradients += parent -> grad
    // only once every registered parent has sent its gradient do we
    // push the summed gradient further down
    if (ready) distribute(gradient)
  }

}


The rest of the computations such as multiply, add and the dot product stay the same.
Now we are actually able to build models with weight sharing.
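To make the registration and wait list mechanics concrete, here is a sketch of how an element-wise addition node could look on top of this base class. The actual Add, Mul and Inner nodes live in the repository and may differ in detail:

// Sketch of an element-wise addition node built against the Node base class above.
class AddSketch(name: String, a: Node, b: Node) extends Node(name, Vector(a, b)) {

  def typename: String = "Add"

  // forward: element-wise sum of the two children's outputs
  def compute(x: Vector[DenseMatrix[Double]]): DenseMatrix[Double] = x(0) + x(1)

  // backward: addition passes the (already summed) gradient through unchanged;
  // as in the LSTM below, the caller registers this node as a parent of both
  // children, e.g. a.register(name) and b.register(name)
  def distribute(gradient: DenseMatrix[Double]): Unit = {
    a.bwd(name, gradient)
    b.bwd(name, gradient)
  }
}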

A simple LSTM

Now we can implement a simple LSTM node. An LSTM is a recurrent node
that holds its state over multiple time steps in a cell. In the LSTM node
this cell is updated based on the current input, the previous hidden state and the previous cell state.
An LSTM node is wired in the following way.
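Concretely, the node implements the standard LSTM equations [1], where $[x_t, h_{t-1}]$ is the concatenation of the current input and the previous hidden state (the Concat node in the code below), $\sigma$ is the logistic sigmoid and $\odot$ the element-wise product:

\begin{aligned}
i_t &= \sigma(W_i \, [x_t, h_{t-1}] + b_i) \\
o_t &= \sigma(W_o \, [x_t, h_{t-1}] + b_o) \\
f_t &= \sigma(W_f \, [x_t, h_{t-1}] + b_f) \\
\tilde{c}_t &= \tanh(W_c \, [x_t, h_{t-1}] + b_c) \\
c_t &= i_t \odot \tilde{c}_t + f_t \odot c_{t-1} \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}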



This wiring translates into the following code and the graph below.

class LSTM(val in: Node, val ht: Node, val ct: Node, wf: Tensor, bf: Tensor, wi: Tensor, bi: Tensor, wo: Tensor, bo: Tensor, wc: Tensor, bc: Tensor) {

  // concatenate the current input with the previous hidden state
  val x = new Concat(in, ht)

  // input, output and forget gates plus the cell candidate, all computed from the concatenation
  val input  = new FeedForwardSigmoid(x, wi, bi)
  val output = new FeedForwardSigmoid(x, wo, bo)
  val forget = new FeedForwardSigmoid(x, wf, bf)
  val cell   = new FeedForwardTanH(x, wc, bc)

  // new cell state: forget part of the old cell state and add the gated candidate
  val forgetCell = new Mul(forget.Layer, ct)
  val inCell     = new Mul(input.Layer, cell.Layer)
  val cellState  = new Add(inCell, forgetCell)
  // new hidden state: squashed cell state gated by the output gate
  val cellAct    = new Tanh(cellState)
  val hidden     = new Mul(output.Layer, cellAct)

  // register every node with its parents so the wait lists are complete,
  // then expose the (cell state, hidden state) pair
  final val Layer = {
    in.register(x.name)
    ht.register(x.name)

    forget.Layer.register(forgetCell.name)
    ct.register(forgetCell.name)

    input.Layer.register(inCell.name)
    cell.Layer.register(inCell.name)

    inCell.register(cellState.name)
    forgetCell.register(cellState.name)

    cellState.register(cellAct.name)

    output.Layer.register(hidden.name)
    cellAct.register(hidden.name)

    (cellState, hidden)
  }

  /**
    * Plain stochastic gradient descent step. Each weight's gradient is the
    * sum of the gradients it collected from every time step it was used in.
    */
  def update(rate: Double): Unit = {
      wf := wf.fwd - ((rate * wf.gradient))
      bf := bf.fwd - ((rate * bf.gradient))
      wi := wi.fwd - ((rate * wi.gradient))
      bi := bi.fwd - ((rate * bi.gradient))
      wc := wc.fwd - ((rate * wc.gradient))
      bc := bc.fwd - ((rate * bc.gradient))
      wo := wo.fwd - ((rate * wo.gradient))
      bo := bo.fwd - ((rate * bo.gradient))
  }
}






Then, for some given input, we can unroll multiple cells and run regular stochastic gradient descent.
Given a list of inputs representing a time series, we can unroll the LSTM. Notice that we reuse the same weights for each time step. In other words, if we unroll for 5 time steps, every weight collects 5 gradients, which are summed before the weight is updated.


  // h and c are the initial hidden and cell state nodes, wl and bl the weights of the
  // classification layer; they and the LSTM weights come from the enclosing scope
  def unroll(in: List[Node]): (LSTM, List[FeedForwardTanH]) =
    if(in.size == 1) {
      val lstm = new LSTM(in.head, h, c, wf, bf, wi, bi, wo, bo, wc, bc)
      val classify  = new FeedForwardTanH(lstm.Layer._2, wl, bl)
      (lstm, List(classify))
    } else {
      val (unrolled, out) = unroll(in.tail)
      val lstm = new LSTM(in.head, unrolled.Layer._2, unrolled.Layer._1, wf, bf, wi, bi, wo, bo, wc, bc)
      val classify  = new FeedForwardTanH(lstm.Layer._2, wl, bl)
      (lstm, classify :: out)
    }
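The full training loop is on Github; the following is only a rough sketch of what one stochastic gradient descent step over an unrolled sequence could look like. The names inputs and targets, the "loss" key and the squared error gradient are assumptions for illustration, not the exact code the results below were produced with.

val (lstm, outputs) = unroll(inputs)            // inputs: the time series as a List[Node]
outputs.zip(targets).foreach { case (out, target) =>
  val prediction = out.Layer.fwd                // forward pass (cached outputs are reused)
  val grad       = prediction - target          // gradient of a squared error loss
  out.Layer.bwd("loss", grad)                   // inject the error at the top node
}
lstm.update(0.01)                               // the shared weights see their summed gradients
outputs.foreach(_.Layer.reset)                  // clear outputs and gradients for the next example

The targets have to be in the same order as the outputs returned by unroll, and the classification weights wl and bl would be updated in the same way as the weights in update above.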


I trained the network with step size 5 on a trigonometric function and then generated 100 points
from an initial 5 points. Here is the result.







This looks promising I guess. So now we can generate some functions.


Variational Auto Encoders

The next model is a variational auto encoder [2]. Here we do not connect the encoder
directly to the decoder. Instead, the encoder produces the mean and the log variance of a
Gaussian distribution over the latent code. We then draw a sample from a multivariate Gaussian
with zero mean and unit variance, scale it by the standard deviation and shift it by the mean
(the reparameterisation trick), and push the result through the decoder to reconstruct the input.

val enc   = new FeedForwardTanH(in, wenc, benc)      // encoder hidden layer
val mean1 = new Inner(enc.Layer, wmean)
val mean  = new Add(mean1, bmean)                    // mean of the latent Gaussian
val sigm1 = new Inner(enc.Layer, wsig)
val sigm2 = new Add(sigm1, bsig)                     // log variance of the latent Gaussian
val sigm  = new Mul(sigm2, half)                     // times the constant 0.5: log standard deviation
val d1    = new Exp(sigm)                            // standard deviation
val d2    = new Mul(sample, d1)                      // scale the standard normal sample
val dist  = new Add(mean, d2)                        // shift by the mean: the latent code
val dec   = new FeedForwardSigmoid(dist, wdec, bdec) // decoder reconstructs the input
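In formula form, and assuming half is a constant $0.5$ and sample a draw from a standard multivariate Gaussian, the graph above computes the reparameterised latent code

z = \mu + \epsilon \odot \exp\!\left(\tfrac{1}{2} \log \sigma^2\right), \qquad \epsilon \sim \mathcal{N}(0, I)

where mean plays the role of $\mu$ and sigm2 the role of $\log \sigma^2$. Writing the sampling step this way keeps the whole graph differentiable with respect to the encoder weights [2].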


Or, easier to read, as the final compute graph:





I ran only a very few iterations on 200 random MNIST images. If we push uniform samples
through the decoder, we get images like these, which also look promising.
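For generating images only the decoder half of the graph is needed. The sketch below uses a hypothetical Constant leaf node and a hypothetical latentDim; the repository has its own leaf nodes for this purpose:

// Hypothetical leaf node that simply returns a fixed matrix.
class Constant(name: String, value: DenseMatrix[Double]) extends Node(name, Vector.empty) {
  def typename: String = "Constant"
  def compute(x: Vector[DenseMatrix[Double]]): DenseMatrix[Double] = value
  def distribute(gradient: DenseMatrix[Double]): Unit = ()  // a leaf has nothing to push down
}

// Push one uniform sample through the trained decoder to generate an image.
val z     = new Constant("z", DenseMatrix.rand(latentDim, 1)) // latentDim: size of the latent code
val gen   = new FeedForwardSigmoid(z, wdec, bdec)
val image = gen.Layer.fwd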



Most of the results are not amazing. The speed of the tool is not at all comparable to TensorFlow,
and in turn I do not process a lot of data. Also, my hidden layers are very small. Furthermore, the only optimisation I use is plain stochastic gradient descent. But this post is about differentiation and weight sharing, not about learning awesome models ... I guess. Again, all the code is on Github. In the next post I will go into developing a more performant variant of this algorithm.

Furthermore, the new Deep Learning book [3] describes the backpropagation algorithm in the general case as a dynamic programming problem that fills a table. However, I think my method of directly coding the graph is a little easier to explain.

References

[1] Hochreiter, Schmidhuber: "Long Short-Term Memory", Neural Computation, 1997
[2] Doersch: "Tutorial on Variational Autoencoders", https://arxiv.org/abs/1606.05908
[3] Goodfellow, Bengio, Courville: "Deep Learning", MIT Press, 2016


