The Multivariable Chain Rule


In this post we learn how to apply the chain rule to vector-valued functions with multiple variables.

We’ve seen how to apply the chain rule to functions of a single real variable. Now we can extend this concept to higher dimensions.

How Does the Multivariable Chain Rule Work?

Remember that the chain rule helps us differentiate nested functions.
If we have a function f of multiple variables x and y, which are themselves functions of another variable r, we can calculate the total derivative of f with respect to r.

\frac{d}{dr}f(x(r), y(r)) = 
\frac{\partial f}{\partial x} \frac{d x}{d r} + 
\frac{\partial f}{\partial y} \frac{d y}{d r} 
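Before we move to vector notation, here is a minimal numerical sanity check of this formula in Python. The functions f, x, and y below are hypothetical, chosen just for illustration:

```python
import math

# Hypothetical example functions, chosen just for this check.
f = lambda x, y: x**2 * y      # f(x, y)
x = lambda r: math.sin(r)      # x(r)
y = lambda r: r**3             # y(r)

r, h = 1.3, 1e-6
xr, yr = x(r), y(r)

# Right-hand side: partial outer derivatives times nested derivatives.
# ∂f/∂x = 2xy, ∂f/∂y = x^2, dx/dr = cos(r), dy/dr = 3r^2
chain = (2 * xr * yr) * math.cos(r) + (xr**2) * (3 * r**2)

# Left-hand side: central finite difference of f(x(r), y(r)).
numeric = (f(x(r + h), y(r + h)) - f(x(r - h), y(r - h))) / (2 * h)

print(chain, numeric)  # the two values agree up to floating-point error
```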

As we’ve seen when constructing the Jacobian matrix, rather than treating x(r) and y(r) as separate functions, we can collect them in a single vector.

\vec{v} = \begin{bmatrix}
    x(r) \\
    y(r)
  \end{bmatrix}

Note that I am writing v with this tiny arrow on top to distinguish it from the non-vector variables.

Accordingly, we can also write the derivatives as vectors.

Partial outer derivatives of f with respect to x and y:

\frac{\partial f}{\partial (x,y)} =
\begin{bmatrix}
    \frac{\partial f}{\partial x}  \\
    \frac{\partial f}{\partial y}
  \end{bmatrix}

Nested derivatives of x and y with respect to r:

\frac{d (x,y)}{d r} =
\begin{bmatrix}
    \frac{d x}{d r}  \\
    \frac{d y}{d r}
  \end{bmatrix}

Now we can write the total derivative of f with respect to the nested variable r as a dot product of the two vectors.

\frac{df}{dr} = \frac{\partial f}{\partial (x,y)} \cdot \frac{d (x,y)}{d r} =
\begin{bmatrix}
    \frac{\partial f}{\partial x}  \\
    \frac{\partial f}{\partial y}
  \end{bmatrix}
\cdot
\begin{bmatrix}
    \frac{d x}{d r}   \\
    \frac{d y}{d r} 
  \end{bmatrix}
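In code, this is an ordinary dot product. Here is a short sketch with numpy, reusing the hypothetical example from the check above (the names grad_f and d_inner are mine, not standard notation):

```python
import numpy as np

# Same hypothetical example: f(x, y) = x^2 * y, x(r) = sin(r), y(r) = r^3.
r = 1.3
x, y = np.sin(r), r**3

grad_f = np.array([2 * x * y, x**2])       # [∂f/∂x, ∂f/∂y]
d_inner = np.array([np.cos(r), 3 * r**2])  # [dx/dr, dy/dr]

df_dr = np.dot(grad_f, d_inner)            # total derivative df/dr
print(df_dr)
```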

Example

f(x,y) = 2x^2 + 3y \\
x(r) = r^2 - 1 \\
y(r) = 2r^2+3

Let’s first calculate the partial derivatives of f with respect to x and y, as well as the derivatives of x and y with respect to r.

    \frac{\partial f}{\partial x}  = 4x = 4r^2-4 \\

    \frac{\partial f}{\partial y} =3\\

    \frac{d x}{d r} = 2r  \\
    \frac{d y}{d r} =4r

Let’s write them in vector format as a dot product and multiply out.

\begin{bmatrix}
    \frac{\partial f}{\partial x}  \\
    \frac{\partial f}{\partial y}
  \end{bmatrix}
\cdot
\begin{bmatrix}
    \frac{d x}{d r}   \\
    \frac{d y}{d r} 
  \end{bmatrix}
=

\begin{bmatrix}
    4r^2-4  \\
    3
  \end{bmatrix}
\cdot
\begin{bmatrix}
   2r   \\
   4r
  \end{bmatrix} 
= (4r^2-4) \cdot 2r + 3 \cdot 4r = 8r^3 - 8r + 12r = 8r^3 + 4r
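As a sanity check, we can let sympy carry out the same chain-rule computation symbolically. A minimal sketch, assuming sympy is installed:

```python
import sympy as sp

r, xs, ys = sp.symbols('r x y')

# The example from this post.
x = r**2 - 1
y = 2 * r**2 + 3
f = 2 * xs**2 + 3 * ys

# Chain rule: ∂f/∂x * dx/dr + ∂f/∂y * dy/dr.
df_dr = (sp.diff(f, xs).subs(xs, x) * sp.diff(x, r)
         + sp.diff(f, ys).subs(ys, y) * sp.diff(y, r))

print(sp.expand(df_dr))  # 8*r**3 + 4*r
```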

Alternatively, we can eliminate x and y from the start by substituting the appropriate terms of r.

f(r) = 2(r^2-1)^2 + 3(2r^2+3)\\
= 2r^4-4r^2+2+6r^2+9\\
= 2r^4+2r^2+11

Now we can simply differentiate, which gives us the following.

\frac{df}{dr} = 8r^3 + 4r
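The same check works for the substitution route, again assuming sympy:

```python
import sympy as sp

r = sp.symbols('r')

# Eliminate x and y first, then differentiate directly.
f_of_r = 2 * (r**2 - 1)**2 + 3 * (2 * r**2 + 3)
print(sp.diff(sp.expand(f_of_r), r))  # 8*r**3 + 4*r, matching the chain rule
```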

It resolves to the same expression as when we applied the chain rule. In this simple case, it is probably faster to use the second method. But once you are dealing with many nested variables, the chain rule is the better, more scalable approach.

With the multivariable chain rule, the Jacobian, and the Hessian under our belt, we have all the conceptual tools to understand how neural networks work.

This post is part of a series on Calculus for Machine Learning. To read the other posts, go to the index.

