Talk: Differentiation of black-box combinatorial solvers

Michal Rolinek
Combinatorics, Machine learning

The goal is to merge combinatorial optimization and deep learning.

Make use of strong, battle-tested optimization methods. Some of them can find near-optimal solutions to NP-hard problems in roughly quadratic time.

The goal is to cover many combinatorial problems (TSP, multicut, etc.) with a method that is:

  • fast backward pass
  • theoretically sound
  • easy to use

However, the goal is not to take a combinatorial problem and simply relax it to make it differentiable, because relaxation often comes at a huge price.

Many believe that extending classical deep learning with combinatorics is essential for AI.


Pipeline: get an input -> learn a representation with deep learning -> feed it into an existing solver -> optionally pass the solver's output to further deep learning layers.

A solver is a function that takes continuous inputs and returns a discrete output (TSP: node coordinates -> shortest tour).
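To make the "continuous in, discrete out" view concrete, here is a minimal sketch assuming a brute-force TSP solver over 2-D node coordinates. The function name `tsp_solver` is hypothetical, and brute force only works for a handful of nodes; a real pipeline would call an established solver instead.

```python
import itertools
import math

def tsp_solver(coords):
    """Hypothetical brute-force TSP: continuous coordinates in, discrete tour out."""
    n = len(coords)

    def tour_length(tour):
        # Sum of Euclidean edge lengths, closing the cycle back to the start node.
        return sum(math.dist(coords[tour[i]], coords[tour[(i + 1) % n]])
                   for i in range(n))

    # Fix node 0 as the start so rotations of the same tour are not enumerated.
    # min() keeps the first tour found among equally short ones.
    return min(((0,) + p for p in itertools.permutations(range(1, n))),
               key=tour_length)

square = [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)]
print(tsp_solver(square))  # (0, 1, 2, 3)
```

The output is a permutation of node indices, so moving a node by a tiny amount typically returns exactly the same tour.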

Many classical problems fall into this category: the solver minimizes a linear objective, \(y(w) = \arg\min_{y \in Y} w \cdot y\) over a discrete set \(Y\) of feasible solutions, yet even with a linear objective each solver call is usually expensive. The main difficulty is the gradient of this black-box optimizer. Contrary to general opinion, such solvers are differentiable almost everywhere: the map \(w \mapsto y(w)\) is piecewise constant. The problem is that its gradient is almost always 0, so estimating the “real” gradient doesn’t help at all.
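A minimal sketch of a solver with a linear objective, assuming the feasible set \(Y\) is all 0/1 indicator vectors with exactly `k` ones (a hypothetical toy problem; `linear_solver` and `k` are illustrative names). Nudging each weight slightly leaves the output unchanged, so the Jacobian of \(w \mapsto y(w)\) is zero at this point.

```python
import itertools

def linear_solver(w, k=2):
    """Return the 0/1 vector y with exactly k ones that minimizes w . y (brute force)."""
    n = len(w)
    best, best_cost = None, float("inf")
    for subset in itertools.combinations(range(n), k):
        y = tuple(1 if i in subset else 0 for i in range(n))
        cost = sum(wi * yi for wi, yi in zip(w, y))
        if cost < best_cost:  # strict comparison: ties keep the first subset found
            best, best_cost = y, cost
    return best

w = [0.3, -1.0, 0.5, -0.2]
y = linear_solver(w)

# Perturb one coordinate at a time by +/- eps and re-solve.
eps = 1e-3
unchanged = all(
    linear_solver([wj + (s * eps if j == i else 0.0) for j, wj in enumerate(w)]) == y
    for i in range(len(w))
    for s in (1, -1)
)
print(y, unchanged)  # (0, 1, 0, 1) True
```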

Some non-solutions:

  • Sample finite differences
  • Apply smoothing
  • Use some zero-order method?

But these methods need a lot of samples to be accurate, and each sample is potentially very expensive because it requires a solver call.
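A rough sketch of why sampling is costly, assuming forward finite differences on a hypothetical downstream loss (squared distance of \(y\) to a fixed target) and a toy k-subset solver as a stand-in for an expensive black box. One gradient estimate costs \(N + 1\) solver calls for an \(N\)-dimensional \(w\), and with a small step the estimate is still exactly zero.

```python
import itertools

def counted(fn):
    """Wrap fn and count how often it is called (each call = one expensive solve)."""
    def wrapper(*args, **kwargs):
        wrapper.calls += 1
        return fn(*args, **kwargs)
    wrapper.calls = 0
    return wrapper

@counted
def linear_solver(w, k=2):
    """Toy stand-in for a black-box solver: best k-subset under the linear cost w . y."""
    n = len(w)
    best, best_cost = None, float("inf")
    for subset in itertools.combinations(range(n), k):
        y = tuple(1 if i in subset else 0 for i in range(n))
        cost = sum(wi * yi for wi, yi in zip(w, y))
        if cost < best_cost:
            best, best_cost = y, cost
    return best

TARGET = (1, 1, 0, 0)  # hypothetical ground-truth solution

def loss(w):
    # Squared distance between the solver output and the target.
    y = linear_solver(w)
    return sum((yi - ti) ** 2 for yi, ti in zip(y, TARGET))

def finite_difference_grad(f, w, eps):
    """Forward differences: one baseline evaluation plus one per coordinate."""
    base = f(w)
    grad = []
    for i in range(len(w)):
        shifted = list(w)
        shifted[i] += eps
        grad.append((f(shifted) - base) / eps)
    return grad

w = [0.3, -1.0, 0.5, -0.2]
print(finite_difference_grad(loss, w, eps=1e-3), linear_solver.calls)
# [0.0, 0.0, 0.0, 0.0] 5
```

A larger step would eventually cross into a neighbouring constant region and give a nonzero value, but then the estimate is crude and still needs the same number of solver calls.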

\(x \rightarrow \dots \rightarrow w \rightarrow y \rightarrow \dots \rightarrow L\). Consider \(L(w)\), the loss as a function of the solver input \(w\); sometimes a generic function \(f(w)\) is used in its place.

Interpolate \(L(w)\) to \(L^\lambda(w)\), where \(\lambda\) controls the “locality” of the interpolation. The interpolation is implicit: it is never constructed explicitly, and the gradient of the interpolated function can be obtained with a single additional solver call. This exploits the fact that the solver minimizes a linear objective.
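As far as I recall the published method, the interpolation is realized entirely in the backward pass: perturb the solver input in the direction of the incoming gradient, solve once more, and take a scaled difference of the two solutions. In the notation of these notes, with \(y(w) = \arg\min_{y \in Y} w \cdot y\):

```latex
w' = w + \lambda \, \frac{dL}{dy}\Big|_{y = y(w)}, \qquad
y_\lambda = y(w'), \qquad
\frac{dL^\lambda}{dw} = -\frac{1}{\lambda}\,\bigl(y(w) - y_\lambda\bigr)
```

The only extra cost is the solver call that produces \(y_\lambda\); a larger \(\lambda\) makes the interpolation less local.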

Backward pass: the input is \(dL/dy\) and the output is \(dL/dw\).
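A minimal pure-Python sketch of such a backward pass, assuming the rule \(w' = w + \lambda\, dL/dy\), \(y_\lambda = y(w')\), \(dL/dw = -(y - y_\lambda)/\lambda\) and a toy k-subset solver as the black box. The names `blackbox_backward` and `linear_solver` and the squared-error loss are illustrative assumptions, not an existing API.

```python
import itertools

def linear_solver(w, k=2):
    """Toy black box: the 0/1 vector with exactly k ones minimizing w . y."""
    n = len(w)
    best, best_cost = None, float("inf")
    for subset in itertools.combinations(range(n), k):
        y = tuple(1 if i in subset else 0 for i in range(n))
        cost = sum(wi * yi for wi, yi in zip(w, y))
        if cost < best_cost:
            best, best_cost = y, cost
    return best

def blackbox_backward(solver, w, y, grad_y, lam):
    """Map dL/dy to dL/dw with one extra solver call.

    w: solver input from the forward pass, y: its output solver(w),
    grad_y: incoming gradient dL/dy, lam: interpolation strength lambda.
    """
    # Perturb the input in the direction of the incoming gradient and re-solve.
    w_prime = [wi + lam * gi for wi, gi in zip(w, grad_y)]
    y_lam = solver(w_prime)
    # -(y - y_lam) / lam, written as (y_lam - y) / lam
    return [(yl - yi) / lam for yi, yl in zip(y, y_lam)]

target = (1, 1, 0, 0)  # hypothetical ground-truth solution
w = [0.3, -1.0, 0.5, -0.2]
y = linear_solver(w)                                   # forward pass
grad_y = [2 * (yi - ti) for yi, ti in zip(y, target)]  # dL/dy of sum((y - t)^2)
grad_w = blackbox_backward(linear_solver, w, y, grad_y, lam=1.0)
print(y, grad_w)  # (0, 1, 0, 1) [1.0, 0.0, 0.0, -1.0]
```

A gradient-descent step \(w \leftarrow w - \eta\, dL/dw\) lowers \(w_0\) and raises \(w_3\), pushing the solver toward selecting index 0 instead of index 3, which is what the target asks for.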

\(\lambda\) shifts the “islands” of constant value, and a linear slope appears in between them.
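To see the islands-and-slope picture numerically, here is a hypothetical 1-D slice, assuming a two-element argmin solver, inputs \(w = (t, 0)\), and a squared-error loss against the target \((0, 1)\). The true loss is constant on each side of \(t = 0\), so its gradient is zero wherever it exists. The \(\lambda\)-gradient (perturb by \(\lambda\, dL/dy\), re-solve, take the scaled difference) is nonzero only on a band next to the jump, which for this toy is \(-4\lambda < t \le 0\).

```python
def solver(w):
    """Toy solver: one-hot indicator of the smallest weight (ties go to the lower index)."""
    best = min(range(len(w)), key=lambda i: w[i])
    return tuple(1 if i == best else 0 for i in range(len(w)))

TARGET = (0, 1)  # hypothetical target: we want index 1 to be selected

def loss_grad(y):
    # dL/dy for L(y) = sum((y_i - target_i)^2)
    return [2 * (yi - ti) for yi, ti in zip(y, TARGET)]

def grad_t(t, lam):
    """Lambda-gradient with respect to w_0 along the slice w = (t, 0.0)."""
    w = [t, 0.0]
    y = solver(w)
    w_prime = [wi + lam * gi for wi, gi in zip(w, loss_grad(y))]
    y_lam = solver(w_prime)
    return (y_lam[0] - y[0]) / lam

for lam in (0.1, 0.5):
    print(lam, [grad_t(t, lam) for t in (-1.0, -0.3, -0.05, 0.5)])
# 0.1 [0.0, -10.0, -10.0, 0.0]
# 0.5 [-2.0, -2.0, -2.0, 0.0]
```

A larger \(\lambda\) widens the band and flattens the slope (here \(-1/\lambda\)), trading locality for a gradient that is nonzero over a wider range.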
