Awesome SILO seminar this week by Suriya Gunasekar of TTI Chicago. Here’s the idea, as I understand it. In a classical optimization problem, like linear regression, you are trying to solve a problem which typically has no solution (draw a line that passes through every point in this cloud!) and the challenge is to find the best approximate solution. Algebraically speaking: you might be asked to solve

Ax = b

for x; but since b may not be in the image of the linear transformation A, you settle for minimizing

||Ax – b||

in whatever norm you like (L^2 for standard linear regression.)
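To make the notation concrete, here’s a minimal NumPy sketch (my own toy setup, nothing from the talk): A is a tall matrix, so Ax = b has no exact solution, and np.linalg.lstsq returns the x minimizing ||Ax – b|| in the L^2 norm.

```python
import numpy as np

rng = np.random.default_rng(0)
A = np.column_stack([np.ones(100), rng.uniform(0, 10, 100)])  # intercept + slope columns
x_true = np.array([1.0, 2.0])
b = A @ x_true + rng.normal(0, 1, 100)    # noisy observations: no line hits every point

x_hat, residual, rank, _ = np.linalg.lstsq(A, b, rcond=None)
print(x_hat)                              # close to [1, 2]
print(np.linalg.norm(A @ x_hat - b))      # > 0: the best line still misses the cloud
```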
In many modern optimization problems, on the other hand, the problem you’re trying to solve may have a lot more degrees of freedom. Maybe you’re setting up an RNN with lots and lots and lots of parameters. Or maybe, to bring this down to earth, you’re trying to pass a curve through lots of points but the curve is allowed to have very high degree. This has the advantage that you can definitely find a curve that passes through all the points. But it also has the disadvantage that you can definitely find a curve that passes through all the points. You are likely to overfit! Your wildly wiggly curve, engineered to exactly fit the data you trained on, is unlikely to generalize well to future data.
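Here’s that double-edged sword in action (again a toy example of my own, not from the talk): a degree-15 polynomial through 16 noisy samples of a sine curve nails every training point and can swing wildly in between.

```python
import numpy as np

rng = np.random.default_rng(1)
xs = np.linspace(0, 1, 16)
ys = np.sin(2 * np.pi * xs) + rng.normal(0, 0.1, 16)

wiggly = np.polynomial.Polynomial.fit(xs, ys, deg=15)   # enough degrees of freedom to interpolate
grid = np.linspace(0, 1, 400)

print(np.max(np.abs(wiggly(xs) - ys)))   # essentially zero: perfect fit on the training points
print(np.max(np.abs(wiggly(grid))))      # expect large excursions between points (Runge-style wiggles)
```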
Everybody knows about this problem, everybody knows to worry about it. But here’s the thing. A lot of modern problems are of this form, and yet the optima we find on training data often do generalize pretty well to test data! Why?
Make this more formal. Let’s say for the sake of argument you’re trying to learn a real-valued function F, which you hypothesize is drawn from some giant space X. (Not necessarily a vector space, just any old space.) You have N training pairs (x_i, y_i), and a good choice for F might be one such that F(x_i) = y_i. So you might try to find F such that

F(x_i) = y_i

for all i. But if X is big enough, there will be a whole space of functions F which do the trick! The solution set to

F(x_i) = y_i, i = 1, …, N

will be some big subspace F_{x,y} of X. How do you know which of these F’s to pick?
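A toy linear version of that “whole space of functions” (my notation, not Gunasekar’s): give a linear model more parameters than training points, and any particular solution plus anything in the nullspace of the data matrix fits the training pairs exactly.

```python
import numpy as np

rng = np.random.default_rng(2)
A = rng.normal(size=(5, 20))   # 5 training points, 20 parameters
y = rng.normal(size=5)

w_particular = np.linalg.pinv(A) @ y        # one exact solution
nullspace = np.linalg.svd(A)[2][5:].T       # basis for the 15-dimensional nullspace of A
w_other = w_particular + nullspace @ rng.normal(size=15)

print(np.allclose(A @ w_particular, y))     # True
print(np.allclose(A @ w_other, y))          # True: a genuinely different F, same perfect fit
```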
One popular way is to regularize; you decide that some elements of X are just better than others, and choose the point of F_{x,y} that optimizes that objective. For instance, if you’re curve-fitting, you might try to find, among those curves passing through your N points, the least wiggly one (e.g. the one with the least total curvature.) Or you might optimize for some combination of hitting the points and non-wiggliness, arriving at a compromise curve that wiggles only mildly and still passes near most of the points. (The ultimate version of this strategy would be to retreat all the way back to linear regression.)
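In the toy linear setup above, the most standard regularizer of all is an L^2 penalty on the parameters, a stand-in for “non-wiggliness” rather than necessarily the regularizer the talk has in mind; tiny lam essentially interpolates, large lam retreats toward the boring flat answer.

```python
import numpy as np

rng = np.random.default_rng(3)
A = rng.normal(size=(5, 20))
y = rng.normal(size=5)

def ridge(A, y, lam):
    """Minimize ||A w - y||^2 + lam * ||w||^2 (closed form)."""
    return np.linalg.solve(A.T @ A + lam * np.eye(A.shape[1]), A.T @ y)

for lam in [1e-8, 1e-2, 10.0]:
    w = ridge(A, y, lam)
    print(lam, np.linalg.norm(A @ w - y), np.linalg.norm(w))  # fit vs. size trade-off
```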
But it’s not obvious what regularization objective to choose, and maybe trying to optimize that objective is yet another hard computational problem, and so on and so on. What’s really surprising is that something much simpler often works pretty well. Namely: how would you find F such that F(x) = y in the first place? You would choose some random F in X, then do some version of gradient descent. Find the direction in the tangent space to X at F along which the loss decreases most steeply, perturb F a bit in that direction, lather, rinse, repeat.
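In the toy linear case that recipe is only a few lines; this is a sketch of plain vanilla gradient descent on the squared loss, with no regularizer anywhere in sight.

```python
import numpy as np

rng = np.random.default_rng(4)
A = rng.normal(size=(5, 20))
y = rng.normal(size=5)

w = 0.01 * rng.normal(size=20)      # some random starting point
lr = 0.01
for _ in range(20000):
    grad = A.T @ (A @ w - y)        # gradient of 0.5 * ||A w - y||^2
    w -= lr * grad                  # perturb a bit downhill; lather, rinse, repeat

print(np.linalg.norm(A @ w - y))    # ~0: we landed somewhere on the solution set F_{x,y}
```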
If this process converges, it ought to get you somewhere on the solution space F_{x,y}. But where? And this is really what Gunasekar’s work is about. Even if your starting F is distributed broadly, the distribution of the spot where gradient descent “lands” on F_{x,y} can be much more sharply focused. In some cases, it’s concentrated on a single point! The “likely targets of gradient descent” seem to generalize better to test data, and in some cases Gunasekar et al. can prove gradient descent likes to find the points on F_{x,y} which optimize some regularizer.
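The cleanest instance of this I know is classical, and far simpler than the settings Gunasekar et al. actually analyze: for underdetermined least squares, gradient descent started at zero converges to the minimum-L^2-norm solution. It behaves exactly as if you had regularized by ||w||, even though you never asked it to.

```python
import numpy as np

rng = np.random.default_rng(5)
A = rng.normal(size=(5, 20))
y = rng.normal(size=5)

w = np.zeros(20)                        # start at the origin
for _ in range(50000):
    w -= 0.01 * A.T @ (A @ w - y)       # plain gradient descent, no regularizer

w_min_norm = np.linalg.pinv(A) @ y      # the minimum-L^2-norm interpolant
print(np.allclose(w, w_min_norm))       # True: descent "lands" on the min-norm point
```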
I was really struck by this outlook. I have tended to think of function learning as a problem of optimization; how can you effectively minimize the training loss ||F(x) – y||? But Gunasekar asks us instead to think about the much richer mathematical structure of the dynamical system of gradient descent on X guided by the loss function. (Or I should say dynamical systems; gradient descent comes in many flavors.)
The dynamical system has a lot more stuff in it! Think about iterating a function; knowing the fixed points is one thing, but knowing which fixed points are stable and which aren’t, and knowing which stable points have big basins of attraction, tells you way more.
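A tiny example of that extra information: iterate f(x) = x^2. Both 0 and 1 are fixed points, but only 0 is stable, and its basin of attraction is exactly the interval (-1, 1); where you start determines everything about where you end up.

```python
def f(x):
    return x * x          # fixed points at 0 and 1

for x0 in [0.5, -0.9, 1.0, 1.1]:
    x = x0
    for _ in range(50):
        x = f(x)
    print(x0, "->", x)    # 0.5 and -0.9 fall into 0's basin; 1.0 sits on the unstable point; 1.1 blows up
```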
What’s more, the dynamical system formulation is much more natural for learning problems as they are so often encountered in life, with streaming rather than static training data. If you are constantly observing more pairs (x_i,y_i), you don’t want to have to start over every second and optimize a new loss function! But if you take the primary object of study to be, not the loss function, but the dynamical system on the hypothesis space X, new data is no problem; your gradient is just a longer and longer sum with each timestep (or you exponentially deweight the older data, whatever you want my friend, the world is yours.)
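Here’s what that looks like in the toy linear case (made-up streaming data, my own sketch): each incoming pair contributes one more gradient nudge to the same ongoing dynamical system, and nothing ever has to be restarted.

```python
import numpy as np

rng = np.random.default_rng(6)
w_true = rng.normal(size=20)     # the function the stream is secretly drawn from
w = np.zeros(20)
lr = 0.01

for t in range(20000):
    x_t = rng.normal(size=20)               # a new training pair arrives
    y_t = x_t @ w_true
    grad_t = x_t * (x_t @ w - y_t)          # gradient contribution from just this pair
    w -= lr * grad_t                        # same descent step as always, no restart
    # (to deweight old data instead, keep a running average of gradients,
    #  e.g. g = 0.99 * g + 0.01 * grad_t, and step against g)

print(np.linalg.norm(w - w_true))           # small: the stream has taught us w_true
```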
Anyway. Loved this talk. Maybe this dynamical framework is the way other people are already accustomed to thinking of it, but it was news to me.
Slides for a talk of Gunasekar’s similar to the one she gave here
“Characterizing Implicit Bias in terms of Optimization Geometry” (2018)
“Convergence of Gradient Descent on Separable Data” (2018)
A little googling for gradient descent and dynamical systems shows me that, unsurprisingly, Ben Recht is on this train.