Shallow neural networks as Wasserstein gradient flows

From Optimal Transport Wiki


Motivation

Artificial neural networks (ANNs) consist of layers of artificial "neurons" which take in information from the previous layer and output information to neurons in the next layer. Gradient descent is a common method for updating the weights of each neuron based on training data. In practice every layer of a neural network has only finitely many neurons. However, when developing a theory that explains how ANNs work, it is useful to consider a layer with infinitely many neurons. In particular, from this viewpoint the process of updating the neuron weights of a shallow neural network can be described by a Wasserstein gradient flow.


Single Layer Neural Networks

See also: Mathematics of Artificial Neural Networks


Discrete Formulation

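Let us introduce the mathematical framework and notation for a neural network with a single hidden layer.[1] Let $\Omega \subseteq \mathbb{R}^d$ be open. The set $\Omega$ represents the space of inputs into the network. There is some unknown function $f : \Omega \to \mathbb{R}$ which we would like to approximate. Let $n$ be the number of neurons in the hidden layer. Define $f_n : \Omega \times \mathbb{R}^n \times \Theta^n \to \mathbb{R}$ by

$$f_n(x, c, \theta) = \frac{1}{n} \sum_{i=1}^{n} c_i \, \sigma(x; \theta_i),$$

where $\sigma : \Omega \times \Theta \to \mathbb{R}$ is a fixed activation function and $\Theta$ is a space of possible parameters $\theta_i$. The goal is to use training data to repeatedly update the weights $c = (c_1, \dots, c_n)$ and $\theta = (\theta_1, \dots, \theta_n)$ based on how close $f_n$ is to the function $f$. More concretely, we want to find $(c, \theta)$ that minimizes the loss function

$$F(c, \theta) = \frac{1}{2} \int_{\Omega} \left| f(x) - f_n(x, c, \theta) \right|^2 \, dx.$$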

A standard way to choose and update the weights is to start with a random choice of weights and perform gradient descent on these parameters. Unfortunately, this minimization problem is non-convex, so gradient descent is not guaranteed to reach a global minimizer. In practice, however, neural networks turn out to be surprisingly good at finding (near-)minimizers. A better-behaved minimization problem, which may provide insight into how neural networks work, arises from a model with infinitely many neurons.
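Below is a minimal sketch of the discrete model and of plain gradient descent on $(c, \theta)$, using only the Python standard library. It rests on several choices that are not in the article:

- a one-dimensional input space $\Omega = (-1, 1)$;
- the hypothetical activation $\sigma(x; \theta) = \tanh(\theta x)$;
- an average over sample points in place of the integral in the loss.

```python
import math
import random


def sigma(x, theta):
    # Hypothetical activation sigma(x; theta) = tanh(theta * x) for scalar input x.
    return math.tanh(theta * x)


def f_n(x, c, theta):
    # Discrete network f_n(x, c, theta) = (1/n) * sum_i c_i * sigma(x; theta_i).
    return sum(ci * sigma(x, ti) for ci, ti in zip(c, theta)) / len(c)


def loss(xs, target, c, theta):
    # (1/2) * integral over Omega of |f(x) - f_n(x)|^2 dx, approximated by an average over xs.
    return 0.5 * sum((target(x) - f_n(x, c, theta)) ** 2 for x in xs) / len(xs)


def gradient_descent(xs, target, n=20, steps=300, lr=0.3, seed=0):
    rng = random.Random(seed)
    # Random initialization of the weights c_i and the parameters theta_i.
    c = [rng.gauss(0.0, 1.0) for _ in range(n)]
    theta = [rng.gauss(0.0, 1.0) for _ in range(n)]
    m = len(xs)
    history = [loss(xs, target, c, theta)]
    for _ in range(steps):
        residuals = [f_n(x, c, theta) - target(x) for x in xs]
        grad_c = [0.0] * n
        grad_theta = [0.0] * n
        for i in range(n):
            for x, r in zip(xs, residuals):
                s = sigma(x, theta[i])
                # Partial derivatives of the loss, each missing the common factor 1/n.
                grad_c[i] += r * s / m
                grad_theta[i] += r * c[i] * (1.0 - s * s) * x / m
        # Dropping the 1/n factor is the same as using step size lr * n on the true
        # gradient (mean-field time scaling); c and theta are updated simultaneously.
        c = [ci - lr * g for ci, g in zip(c, grad_c)]
        theta = [ti - lr * g for ti, g in zip(theta, grad_theta)]
        history.append(loss(xs, target, c, theta))
    return c, theta, history


if __name__ == "__main__":
    xs = [-1.0 + 2.0 * (j + 0.5) / 40 for j in range(40)]  # sample points in Omega = (-1, 1)
    c, theta, history = gradient_descent(xs, lambda x: math.sin(math.pi * x))
    print(f"loss: {history[0]:.4f} -> {history[-1]:.4f}")
```

The target $f(x) = \sin(\pi x)$ in the usage block is also only an illustration. Since the problem is non-convex in $(c, \theta)$, different random seeds may end at different loss values.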


Continuous Formulation

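For the continuous formulation (i.e. when $n \to \infty$), we rephrase the above mathematical framework. In this case, it no longer makes sense to look for weights $(c, \theta)$ that minimize the loss function. We instead look for a probability measure $\mu \in \mathcal{P}(\mathbb{R} \times \Theta)$ such that

$$f_\mu(x) = \int_{\mathbb{R} \times \Theta} c \, \sigma(x; \theta) \, d\mu(c, \theta)$$

minimizes the loss function

$$F(\mu) = \frac{1}{2} \int_{\Omega} \left| f(x) - f_\mu(x) \right|^2 \, dx.$$

Here $\sigma(\cdot\,; \theta)$ is an activation function with parameter $\theta$, as in the discrete case.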
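One way to see the limit $n \to \infty$ is to sample $n$ neurons i.i.d. from $\mu$ and compare $f_n$ with $f_\mu$. The sketch below assumes the hypothetical choices $\sigma(x; \theta) = \cos(\theta x)$ with scalar $x$ and $\mu = \delta_1 \otimes \mathcal{N}(0, 1)$, so every $c_i = 1$ and every $\theta_i$ is standard Gaussian. For these choices $f_\mu(x) = e^{-x^2/2}$, which is the characteristic function of the standard Gaussian.

```python
import math
import random


def f_mu_exact(x):
    # Hypothetical example: mu = delta_1 (x) N(0, 1) over (c, theta), sigma(x; theta) = cos(theta * x).
    # Then f_mu(x) = E[cos(theta * x)] = exp(-x^2 / 2).
    return math.exp(-x * x / 2.0)


def f_n_sampled(x, n, seed=0):
    # Draw n neurons (c_i, theta_i) i.i.d. from mu (so c_i = 1) and evaluate
    # f_n(x) = (1/n) * sum_i c_i * cos(theta_i * x).
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        c_i, theta_i = 1.0, rng.gauss(0.0, 1.0)
        total += c_i * math.cos(theta_i * x)
    return total / n


if __name__ == "__main__":
    for n in (10, 1_000, 100_000):
        err = abs(f_n_sampled(1.0, n) - f_mu_exact(1.0))
        print(f"n = {n:>7}: |f_n(1) - f_mu(1)| = {err:.5f}")
```

By the law of large numbers, the error at a fixed $x$ decays roughly like $1/\sqrt{n}$.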

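Note that if we restrict $\mu$ to probability measures of the form $\mu_n = \frac{1}{n} \sum_{i=1}^{n} \delta_{(c_i, \theta_i)}$, then $f_{\mu_n} = f_n(\cdot, c, \theta)$. Hence the above minimization problem contains the case with finitely many neurons as a special case.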
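The identity $f_{\mu_n} = f_n$ can be checked numerically by representing a discrete measure as a list of weighted atoms. This is a sketch: the activation $\sigma(x; \theta) = \tanh(\theta x)$ and the helper names are hypothetical.

```python
import math


def sigma(x, theta):
    # Hypothetical activation sigma(x; theta) = tanh(theta * x).
    return math.tanh(theta * x)


def f_mu(x, atoms):
    # f_mu(x) = integral of c * sigma(x; theta) d mu(c, theta) for a discrete measure
    # mu = sum_k w_k * delta_{(c_k, theta_k)}, given as a list of (w_k, c_k, theta_k).
    return sum(w * c * sigma(x, t) for w, c, t in atoms)


def f_n(x, c, theta):
    # Finite network (1/n) * sum_i c_i * sigma(x; theta_i).
    return sum(ci * sigma(x, ti) for ci, ti in zip(c, theta)) / len(c)


def empirical_measure(c, theta):
    # mu_n = (1/n) * sum_i delta_{(c_i, theta_i)}.
    n = len(c)
    return [(1.0 / n, ci, ti) for ci, ti in zip(c, theta)]
```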

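To avoid overfitting the network to the training data, a potential term is added to the loss function. For the remainder of this article, we define the loss function to be

$$F(\mu) = \frac{1}{2} \int_{\Omega} \left| f(x) - f_\mu(x) \right|^2 \, dx + \int_{\mathbb{R} \times \Theta} V(c, \theta) \, d\mu(c, \theta)$$

for a convex potential function $V : \mathbb{R} \times \Theta \to \mathbb{R}$. A common choice is a quadratic penalty such as $V(c, \theta) = \lambda \left( |c|^2 + |\theta|^2 \right)$ with $\lambda > 0$. Since $\mu \mapsto f_\mu$ is linear and the potential term is linear in $\mu$, $F$ is convex along linear interpolations $(1 - t)\mu_0 + t\mu_1$. This is in contrast to the loss in the finite-neuron case, which is not convex in $(c, \theta)$.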
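Convexity along linear interpolations can be checked on discrete measures. The mixture $(1 - t)\mu_0 + t\mu_1$ is again a discrete measure: its atoms are those of $\mu_0$ and $\mu_1$, with rescaled weights. The sketch below makes three hypothetical choices:

- the activation $\tanh(\theta x)$;
- the potential $V(c, \theta) = \lambda(c^2 + \theta^2)$ with an illustrative value of $\lambda$;
- a sample-average approximation of the integral over $\Omega$.

The discretized data term is still a convex quadratic in $\mu$, so the inequality should hold up to rounding.

```python
import math

LAM = 0.01  # hypothetical regularization strength lambda


def sigma(x, theta):
    # Hypothetical activation sigma(x; theta) = tanh(theta * x).
    return math.tanh(theta * x)


def V(c, theta):
    # Hypothetical quadratic potential V(c, theta) = lambda * (c^2 + theta^2).
    return LAM * (c * c + theta * theta)


def f_mu(x, atoms):
    # atoms: list of (weight, c, theta) describing a discrete measure mu.
    return sum(w * c * sigma(x, t) for w, c, t in atoms)


def F(atoms, xs, target):
    # Data term: (1/2) * integral over Omega of |f - f_mu|^2 dx, approximated by an average over xs.
    data = 0.5 * sum((target(x) - f_mu(x, atoms)) ** 2 for x in xs) / len(xs)
    # Potential term: integral of V d mu, exact for a discrete measure.
    potential = sum(w * V(c, t) for w, c, t in atoms)
    return data + potential


def interpolate(atoms0, atoms1, t):
    # Linear interpolation (1 - t) * mu0 + t * mu1: rescale the weights and concatenate the atoms.
    return [((1 - t) * w, c, th) for w, c, th in atoms0] + [(t * w, c, th) for w, c, th in atoms1]
```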

Gradient Flow

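In $\mathbb{R}^m$, the gradient flow of a differentiable function $g : \mathbb{R}^m \to \mathbb{R}$ starting at a point $x_0 \in \mathbb{R}^m$ is a curve $x : [0, \infty) \to \mathbb{R}^m$ satisfying the differential equation

$$x'(t) = -\nabla g(x(t)), \qquad x(0) = x_0,$$

where $\nabla g$ is the gradient of $g$.[2]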
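Below is a minimal sketch of a finite-dimensional gradient flow: explicit Euler with step $\tau$ applied to the hypothetical function $g(x, y) = \frac{1}{2}(x^2 + 10y^2)$. Its exact flow is $x(t) = x_0 e^{-t}$, $y(t) = y_0 e^{-10t}$. With the default step $\tau = 0.01$, each step multiplies both coordinates by factors in $(0, 1)$, so $g$ decreases along the iterates.

```python
import math


def g(p):
    # Hypothetical example g(x, y) = (x^2 + 10 y^2) / 2.
    x, y = p
    return 0.5 * (x * x + 10.0 * y * y)


def grad_g(p):
    x, y = p
    return (x, 10.0 * y)


def gradient_flow_euler(p0, tau=0.01, steps=500):
    # Explicit Euler discretization of x'(t) = -grad g(x(t)), x(0) = p0:
    # x_{k+1} = x_k - tau * grad g(x_k), which approximates x(k * tau).
    path = [p0]
    for _ in range(steps):
        gx, gy = grad_g(path[-1])
        path.append((path[-1][0] - tau * gx, path[-1][1] - tau * gy))
    return path


def exact_flow(p0, t):
    # Exact gradient flow of g: each coordinate decays exponentially.
    return (p0[0] * math.exp(-t), p0[1] * math.exp(-10.0 * t))
```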

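Crucially, the gradient flow moves in the direction in which the value of $g$ decreases fastest. We would like to use this property of gradient flows in our setting with the functional $F$. However, it is not immediately clear how to do this: $F$ is defined on the space of probability measures rather than on $\mathbb{R}^m$, so the usual gradient is not defined. Before we generalize the notion of gradient flow, note that if $g$ is convex and differentiable, the differential equation above can be rewritten as

$$-x'(t) \in \partial g(x(t)),$$

since $\partial g(x) = \{\nabla g(x)\}$ for such $g$. Recall that the subdifferential of a convex function $g$ at $x$ is the set

$$\partial g(x) = \left\{ v \in \mathbb{R}^m : g(y) \ge g(x) + \langle v, y - x \rangle \text{ for all } y \in \mathbb{R}^m \right\}.$$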
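A standard example where the subdifferential is genuinely set-valued is $g(x) = |x|$ on $\mathbb{R}$, which is convex but not differentiable at $0$:

```latex
g(x) = |x|, \qquad
\partial g(x) =
\begin{cases}
\{-1\}, & x < 0, \\
[-1, 1], & x = 0, \\
\{+1\}, & x > 0.
\end{cases}
```

In particular, $0 \in \partial g(0)$, reflecting the fact that $0$ is the minimizer of $g$.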

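Motivated by this, and using the notion of a subdifferential, we can define the gradient flow of a convex and lower semi-continuous functional $G : H \to (-\infty, +\infty]$ on a Hilbert space $H$. An absolutely continuous curve $x : [0, \infty) \to H$ is a gradient flow for $G$ starting at $x_0 \in H$ if

$$x(0) = x_0$$

and for almost every $t > 0$

$$-x'(t) \in \partial G(x(t)),$$

where $\partial G(x) = \{ v \in H : G(y) \ge G(x) + \langle v, y - x \rangle \text{ for all } y \in H \}$ is the subdifferential of $G$.[3]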
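Consider $G(x) = |x|$ on $H = \mathbb{R}$, whose subdifferential at $0$ is $[-1, 1]$. Its gradient flow starting at $x_0$ is $x(t) = \operatorname{sign}(x_0)\max(|x_0| - t, 0)$. The flow moves toward $0$ at unit speed, where $\partial G = \{\pm 1\}$, and then stays at $0$; this is allowed because $0 \in \partial G(0)$.

The sketch below approximates this flow with the implicit Euler (proximal point) scheme $x_{k+1} = x_k - \tau v$, $v \in \partial G(x_{k+1})$. For $|x|$ the scheme reduces to soft-thresholding. This discretization is chosen here for illustration and is not taken from the article.

```python
import math


def prox_abs(x, tau):
    # Proximal map of G(y) = |y|: argmin_y |y| + (y - x)^2 / (2 tau), i.e. soft-thresholding.
    return math.copysign(max(abs(x) - tau, 0.0), x)


def implicit_euler_abs(x0, tau, steps):
    # Implicit Euler: x_{k+1} - x_k in -tau * dG(x_{k+1}), equivalently x_{k+1} = prox_abs(x_k, tau).
    xs = [x0]
    for _ in range(steps):
        xs.append(prox_abs(xs[-1], tau))
    return xs


def exact_flow_abs(x0, t):
    # Gradient flow of |x| from x0: move toward 0 at unit speed, then stay at 0.
    return math.copysign(max(abs(x0) - t, 0.0), x0)
```

For $|x|$, the implicit scheme reproduces the exact flow at the times $t = k\tau$, up to floating-point rounding. An explicit Euler step would instead keep overshooting $0$, because $|x|$ has no gradient there.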


Wasserstein Distance

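Let us first define the $p$th Wasserstein distance between two probability measures.[4] This can be defined for probability measures on any separable metric space $(X, d)$. Let $p \ge 1$ and let $\mu, \nu$ be probability measures on $X$ with finite $p$th moments. Let $\Pi(\mu, \nu)$ denote the space of transport plans from $\mu$ to $\nu$, i.e. probability measures on $X \times X$ with first marginal $\mu$ and second marginal $\nu$. Define the $p$th Wasserstein distance from $\mu$ to $\nu$ to be

$$W_p(\mu, \nu) = \left( \inf_{\gamma \in \Pi(\mu, \nu)} \int_{X \times X} d(x, y)^p \, d\gamma(x, y) \right)^{1/p},$$

where $d(x, y)$ is the distance between $x$ and $y$ in the metric space $X$. In this context, $X$ is just a subset of $\mathbb{R}^k$ for some $k$ with the Euclidean metric. Often $p = 2$.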
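Take uniform empirical measures $\mu = \frac{1}{n}\sum_i \delta_{x_i}$ and $\nu = \frac{1}{n}\sum_j \delta_{y_j}$ with the same number of atoms. For these, the infimum is attained at a transport plan induced by a permutation, a consequence of the Birkhoff-von Neumann theorem, so $W_p$ can be computed by brute force for tiny $n$. The sketch assumes points in $\mathbb{R}^k$ given as tuples and the Euclidean metric. For anything larger, a linear-programming or assignment solver would be used instead.

```python
import itertools
import math


def wasserstein_uniform(xs, ys, p=2.0):
    # W_p between mu = (1/n) sum_i delta_{x_i} and nu = (1/n) sum_j delta_{y_j} in R^k
    # (Euclidean metric). An optimal plan can be taken to be a permutation, so we
    # brute-force all n! matchings; only practical for very small n.
    n = len(xs)
    if n != len(ys):
        raise ValueError("both measures must have the same number of atoms")
    best = math.inf
    for perm in itertools.permutations(range(n)):
        cost = sum(math.dist(xs[i], ys[perm[i]]) ** p for i in range(n)) / n
        best = min(best, cost)
    return best ** (1.0 / p)
```

In one dimension the result agrees with matching the sorted atoms, since the monotone coupling is optimal for $p \ge 1$.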

First Variation

Wasserstein Subdifferential

Wasserstein Gradient Flow

Main Results

Consistency Between Infinite and Finite Cases

References