
Gradient Descent

Computation graphs are visual representations of mathematical operations that go from inputs to outputs. They're particularly useful in deep learning to organize calculations. For example, consider the function:

J(a, b, c) = 3(a + b \cdot c)

In this function, we have intermediate variables:

u = b \cdot c, \quad v = a + u

And the final output:

J = 3 \cdot v

We can represent this as:

Computation Graph
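As a minimal sketch of the forward pass through this graph (plain Python; the function name and the concrete input values are illustrative, not from the text):

```python
def forward(a, b, c):
    """Forward pass through the computation graph J(a, b, c) = 3(a + b * c)."""
    u = b * c   # first intermediate node
    v = a + u   # second intermediate node
    J = 3 * v   # output node
    return J, u, v

# For example, a = 5, b = 3, c = 2 gives u = 6, v = 11, J = 33.
J, u, v = forward(5, 3, 2)
print(J, u, v)  # 33 11 6
```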

Derivatives with a Computation Graph

In a computation graph, we often want to calculate how changes in the inputs affect the output. This is where the chain rule is key. It states that for two functions f and g, the derivative of their composition (f \circ g)(x) = f(g(x)) with respect to x is:

(f \circ g)'(x) = f'(g(x)) \cdot g'(x)

In terms of variables, if y = f(u) and u = g(x), then the derivative of y with respect to x is given by:

\frac{dy}{dx} = \frac{dy}{du} \cdot \frac{du}{dx}

Here, \frac{dy}{du} is evaluated at the point u = g(x).
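For a quick concrete example (not tied to the graph above): if y = f(u) = 3u and u = g(x) = x^2, then

\frac{dy}{dx} = \frac{dy}{du} \cdot \frac{du}{dx} = 3 \cdot 2x = 6x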

In the context of a computation graph, we calculate derivatives from the output back to the inputs. This right-to-left pass, which is more natural for computing derivatives, is known as backpropagation in machine learning. The notation d\text{var} represents the derivative of the final output with respect to various intermediate quantities.

Backpropagation Example

Take the visual representation of the computation graph with the forward pass above. For the backward pass, we want to compute the derivatives of the final output J with respect to the inputs a, b, c, and along the way we will need them in terms of the intermediate quantities u, v.

  1. Compute the intermediate derivatives dJ/dv and dJ/du, which are the derivatives of the final output J with respect to v and u.
\frac{dJ}{dv} = 3, \quad \frac{dJ}{du} = \frac{dJ}{dv} \cdot \frac{dv}{du} = 3 \cdot 1 = 3
  2. Compute dJ/da, which is the derivative of the final output J with respect to a.
\frac{dJ}{da} = \frac{dJ}{dv} \cdot \frac{dv}{da} = 3 \cdot 1 = 3
  3. Compute dJ/db, which is the derivative of the final output J with respect to b.
\frac{dJ}{db} = \frac{dJ}{du} \cdot \frac{du}{db} = 3 \cdot c = 3c
  4. Compute dJ/dc, which is the derivative of the final output J with respect to c.
\frac{dJ}{dc} = \frac{dJ}{du} \cdot \frac{du}{dc} = 3 \cdot b = 3b

We compute these derivatives going from right to left, applying the chain rule at each step. These gradients tell us how much a change in a certain variable affects the final output. This is crucial for algorithms like gradient descent, where we need to adjust parameters to minimize a cost function.
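The same right-to-left pass can be written out as a short sketch in plain Python (function name and input values are illustrative):

```python
def backward(a, b, c):
    """Backward pass for J(a, b, c) = 3(a + b * c), applying the chain rule right to left."""
    # Forward pass, so the intermediate values are available
    u = b * c
    v = a + u
    J = 3 * v

    # Backward pass
    dJ_dv = 3            # J = 3v
    dJ_du = dJ_dv * 1    # v = a + u  ->  dv/du = 1
    dJ_da = dJ_dv * 1    # v = a + u  ->  dv/da = 1
    dJ_db = dJ_du * c    # u = b * c  ->  du/db = c
    dJ_dc = dJ_du * b    # u = b * c  ->  du/dc = b
    return dJ_da, dJ_db, dJ_dc

print(backward(5, 3, 2))  # (3, 6, 9), i.e. (3, 3c, 3b)
```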

Logistic Regression Gradient Descent

Similarly, take the derivatives for the gradient descent example with one sample and two features:

Computation Graph Logistic Regression

  • z is the linear combination of inputs and weights plus the bias.
  • a is the activation computed using the sigmoid function \sigma(z).
  • \mathcal{L}(a, y) is the loss function, which is the binary cross-entropy.
  • The derivatives \frac{d\mathcal{L}}{dw_1}, \frac{d\mathcal{L}}{dw_2}, and \frac{d\mathcal{L}}{db} describe how the loss function changes with respect to each parameter.
  • The update rules define how to adjust the parameters w_1, w_2, and b in the direction that minimizes the loss, using a learning rate \alpha.

Now for the derivatives of the backward pass. First, the derivative of the binary cross-entropy loss function:

\frac{d\mathcal{L}(a, y)}{da} = -\left(\frac{y}{a} - \frac{1 - y}{1 - a}\right)

For the derivative of the sigmoid function:

\frac{da}{dz} = a \cdot (1 - a)

For z we have:

\frac{d\mathcal{L}}{dz} = \frac{d\mathcal{L}}{da} \cdot \frac{da}{dz} = -\left(\frac{y}{a} - \frac{1 - y}{1 - a}\right) \cdot a(1 - a) = a - y

Now for the partial derivatives with respect to weights and bias:

\frac{d\mathcal{L}}{dw_1} = x_1 \cdot \frac{d\mathcal{L}}{dz}, \quad \frac{d\mathcal{L}}{dw_2} = x_2 \cdot \frac{d\mathcal{L}}{dz}, \quad \frac{d\mathcal{L}}{db} = \frac{d\mathcal{L}}{dz}

And now the update rules for the weights and bias:

w_1 := w_1 - \alpha \frac{d\mathcal{L}}{dw_1}, \quad w_2 := w_2 - \alpha \frac{d\mathcal{L}}{dw_2}, \quad b := b - \alpha \frac{d\mathcal{L}}{db}
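Putting the forward pass, backward pass, and update rules together, one gradient descent step for a single example might look like the following sketch (assuming NumPy; the function name and learning rate value are illustrative):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def gradient_step(x1, x2, y, w1, w2, b, alpha=0.1):
    """One gradient descent step for a single example with two features."""
    # Forward pass
    z = w1 * x1 + w2 * x2 + b
    a = sigmoid(z)

    # Backward pass
    dz = a - y        # dL/dz
    dw1 = x1 * dz     # dL/dw1
    dw2 = x2 * dz     # dL/dw2
    db = dz           # dL/db

    # Update rules
    w1 = w1 - alpha * dw1
    w2 = w2 - alpha * dw2
    b = b - alpha * db
    return w1, w2, b
```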

Gradient Descent on m Examples

Take the cost function J for logistic regression, which is the average of the binary cross-entropy loss over all m training examples:

\begin{align*}
J(w, b) &= \frac{1}{m} \sum_{i=1}^{m} \mathcal{L}(\hat{y}^{(i)}, y^{(i)}) \\
        &= -\frac{1}{m} \sum_{i=1}^{m} \left( y^{(i)} \log(\hat{y}^{(i)}) + (1 - y^{(i)}) \log(1 - \hat{y}^{(i)}) \right)
\end{align*}

We can also calculate the average gradient across all examples. The derivative of J with respect to w_j is:

\frac{\partial J(w,b)}{\partial w_j} = \frac{1}{m} \sum_{i=1}^{m} \frac{\partial}{\partial w_j} \mathcal{L}(a^{(i)}, y^{(i)})

Once we've computed the derivatives of the cost function J with respect to each parameter, we update the parameters w and b to minimize J.
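A vectorized sketch of one such step over all m examples (assuming NumPy, with X of shape (m, n_features) and y of shape (m,); the function name and learning rate value are illustrative):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def gradient_descent_step(X, y, w, b, alpha=0.1):
    """One gradient descent step over m examples, using the averaged gradients."""
    m = X.shape[0]

    # Forward pass on all examples at once
    z = X @ w + b              # shape (m,)
    a = sigmoid(z)

    # Cost: average binary cross-entropy
    J = -np.mean(y * np.log(a) + (1 - y) * np.log(1 - a))

    # Gradients averaged over the m examples
    dz = a - y                 # shape (m,)
    dw = (X.T @ dz) / m        # shape (n_features,)
    db = np.mean(dz)

    # Update parameters
    w = w - alpha * dw
    b = b - alpha * db
    return w, b, J
```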