mirror of
https://github.com/boostorg/math.git
synced 2026-02-25 04:22:15 +00:00
394 lines
50 KiB
HTML
394 lines
50 KiB
HTML
<html>
|
|
<head>
|
|
<meta charset="UTF-8">
|
|
<title>Nesterov Accelerated Gradient Descent</title>
|
|
<link rel="stylesheet" href="../../math.css" type="text/css">
|
|
<meta name="generator" content="DocBook XSL Stylesheets Vsnapshot">
|
|
<link rel="home" href="../../index.html" title="Math Toolkit 4.2.1">
|
|
<link rel="up" href="../gd_opt.html" title="Gradient Based Optimizers">
|
|
<link rel="prev" href="gradient_descent.html" title="Gradient Descent">
|
|
<link rel="next" href="lbfgs.html" title="L-BFGS">
|
|
<meta name="viewport" content="width=device-width, initial-scale=1">
|
|
</head>
|
|
<body bgcolor="white" text="black" link="#0000FF" vlink="#840084" alink="#0000FF">
|
|
<table cellpadding="2" width="100%"><tr>
|
|
<td valign="top"><img alt="Boost C++ Libraries" width="277" height="86" src="../../../../../../boost.png"></td>
|
|
<td align="center"><a href="../../../../../../index.html">Home</a></td>
|
|
<td align="center"><a href="../../../../../../libs/libraries.htm">Libraries</a></td>
|
|
<td align="center"><a href="http://www.boost.org/users/people.html">People</a></td>
|
|
<td align="center"><a href="http://www.boost.org/users/faq.html">FAQ</a></td>
|
|
<td align="center"><a href="../../../../../../more/index.htm">More</a></td>
|
|
</tr></table>
|
|
<hr>
|
|
<div class="spirit-nav">
|
|
<a accesskey="p" href="gradient_descent.html"><img src="../../../../../../doc/src/images/prev.png" alt="Prev"></a><a accesskey="u" href="../gd_opt.html"><img src="../../../../../../doc/src/images/up.png" alt="Up"></a><a accesskey="h" href="../../index.html"><img src="../../../../../../doc/src/images/home.png" alt="Home"></a><a accesskey="n" href="lbfgs.html"><img src="../../../../../../doc/src/images/next.png" alt="Next"></a>
|
|
</div>
|
|
<div class="section">
|
|
<div class="titlepage"><div><div><h3 class="title">
|
|
<a name="math_toolkit.gd_opt.nesterov"></a><a class="link" href="nesterov.html" title="Nesterov Accelerated Gradient Descent">Nesterov Accelerated Gradient
|
|
Descent</a>
|
|
</h3></div></div></div>
|
|
<h5>
|
|
<a name="math_toolkit.gd_opt.nesterov.h0"></a>
|
|
<span class="phrase"><a name="math_toolkit.gd_opt.nesterov.synopsis"></a></span><a class="link" href="nesterov.html#math_toolkit.gd_opt.nesterov.synopsis">Synopsis</a>
|
|
</h5>
|
|
<pre class="programlisting"><span class="preprocessor">#include</span> <span class="special"><</span><span class="identifier">boost</span><span class="special">/</span><span class="identifier">math</span><span class="special">/</span><span class="identifier">optimization</span><span class="special">/</span><span class="identifier">nesterov</span><span class="special">.</span><span class="identifier">hpp</span><span class="special">></span>
|
|
|
|
<span class="keyword">namespace</span> <span class="identifier">boost</span> <span class="special">{</span>
|
|
<span class="keyword">namespace</span> <span class="identifier">math</span> <span class="special">{</span>
|
|
<span class="keyword">namespace</span> <span class="identifier">optimization</span> <span class="special">{</span>
|
|
|
|
<span class="keyword">namespace</span> <span class="identifier">rdiff</span> <span class="special">=</span> <span class="identifier">boost</span><span class="special">::</span><span class="identifier">math</span><span class="special">::</span><span class="identifier">differentiation</span><span class="special">::</span><span class="identifier">reverse_mode</span><span class="special">;</span>
|
|
|
|
<span class="comment">/**
|
|
* @brief The nesterov_accelerated_gradient class
|
|
*
|
|
* https://jlmelville.github.io/mize/nesterov.html
|
|
*/</span>
|
|
<span class="keyword">template</span><span class="special"><</span><span class="keyword">typename</span> <span class="identifier">ArgumentContainer</span><span class="special">,</span>
|
|
<span class="keyword">typename</span> <span class="identifier">RealType</span><span class="special">,</span>
|
|
<span class="keyword">class</span> <span class="identifier">Objective</span><span class="special">,</span>
|
|
<span class="keyword">class</span> <span class="identifier">InitializationPolicy</span><span class="special">,</span>
|
|
<span class="keyword">class</span> <span class="identifier">ObjectiveEvalPolicy</span><span class="special">,</span>
|
|
<span class="keyword">class</span> <span class="identifier">GradEvalPolicy</span><span class="special">></span>
|
|
<span class="keyword">class</span> <span class="identifier">nesterov_accelerated_gradient</span>
|
|
<span class="special">:</span> <span class="keyword">public</span> <span class="identifier">abstract_optimizer</span><span class="special"><</span>
|
|
<span class="identifier">ArgumentContainer</span><span class="special">,</span>
|
|
<span class="identifier">RealType</span><span class="special">,</span>
|
|
<span class="identifier">Objective</span><span class="special">,</span>
|
|
<span class="identifier">InitializationPolicy</span><span class="special">,</span>
|
|
<span class="identifier">ObjectiveEvalPolicy</span><span class="special">,</span>
|
|
<span class="identifier">GradEvalPolicy</span><span class="special">,</span>
|
|
<span class="identifier">nesterov_update_policy</span><span class="special"><</span><span class="identifier">RealType</span><span class="special">>,</span>
|
|
<span class="identifier">nesterov_accelerated_gradient</span><span class="special"><</span><span class="identifier">ArgumentContainer</span><span class="special">,</span>
|
|
<span class="identifier">RealType</span><span class="special">,</span>
|
|
<span class="identifier">Objective</span><span class="special">,</span>
|
|
<span class="identifier">InitializationPolicy</span><span class="special">,</span>
|
|
<span class="identifier">ObjectiveEvalPolicy</span><span class="special">,</span>
|
|
<span class="identifier">GradEvalPolicy</span><span class="special">>></span>
|
|
<span class="special">{</span>
|
|
<span class="keyword">public</span><span class="special">:</span>
|
|
<span class="identifier">nesterov_accelerated_gradient</span><span class="special">(</span><span class="identifier">Objective</span><span class="special">&&</span> <span class="identifier">objective</span><span class="special">,</span>
|
|
<span class="identifier">ArgumentContainer</span><span class="special">&</span> <span class="identifier">x</span><span class="special">,</span>
|
|
<span class="identifier">InitializationPolicy</span><span class="special">&&</span> <span class="identifier">ip</span><span class="special">,</span>
|
|
<span class="identifier">ObjectiveEvalPolicy</span><span class="special">&&</span> <span class="identifier">oep</span><span class="special">,</span>
|
|
<span class="identifier">GradEvalPolicy</span><span class="special">&&</span> <span class="identifier">gep</span><span class="special">,</span>
|
|
<span class="identifier">nesterov_update_policy</span><span class="special"><</span><span class="identifier">RealType</span><span class="special">>&&</span> <span class="identifier">up</span><span class="special">);</span>
|
|
<span class="keyword">void</span> <span class="identifier">step</span><span class="special">();</span>
|
|
<span class="special">};</span>
|
|
|
|
<span class="comment">/* Convenience overloads */</span>
|
|
<span class="comment">/* make nesterov accelerated gradient descent by providing
|
|
* - obj: objective function to minimize
|
|
* - x: variables to optimize over (updated in place)
|
|
* Optional:
|
|
* - lr: learning rate / step size (typical: 1e-4 .. 1e-1 depending on scaling)
|
|
* - mu: momentum coefficient in [0, 1) (typical: 0.8 .. 0.99)
|
|
*/</span>
|
|
|
|
<span class="keyword">template</span><span class="special"><</span><span class="keyword">class</span> <span class="identifier">Objective</span><span class="special">,</span> <span class="keyword">typename</span> <span class="identifier">ArgumentContainer</span><span class="special">,</span> <span class="keyword">typename</span> <span class="identifier">RealType</span><span class="special">></span>
|
|
<span class="keyword">auto</span> <span class="identifier">make_nag</span><span class="special">(</span><span class="identifier">Objective</span><span class="special">&&</span> <span class="identifier">obj</span><span class="special">,</span>
|
|
<span class="identifier">ArgumentContainer</span><span class="special">&</span> <span class="identifier">x</span><span class="special">,</span>
|
|
<span class="identifier">RealType</span> <span class="identifier">lr</span> <span class="special">=</span> <span class="identifier">RealType</span><span class="special">{</span> <span class="number">0.01</span> <span class="special">},</span>
|
|
<span class="identifier">RealType</span> <span class="identifier">mu</span> <span class="special">=</span> <span class="identifier">RealType</span><span class="special">{</span> <span class="number">0.95</span> <span class="special">});</span>
|
|
|
|
<span class="comment">/* provide initialization policy
|
|
* lr and mu are no longer optional
|
|
*/</span>
|
|
<span class="keyword">template</span><span class="special"><</span><span class="keyword">class</span> <span class="identifier">Objective</span><span class="special">,</span>
|
|
<span class="keyword">typename</span> <span class="identifier">ArgumentContainer</span><span class="special">,</span>
|
|
<span class="keyword">typename</span> <span class="identifier">RealType</span><span class="special">,</span>
|
|
<span class="keyword">class</span> <span class="identifier">InitializationPolicy</span><span class="special">></span>
|
|
<span class="keyword">auto</span> <span class="identifier">make_nag</span><span class="special">(</span><span class="identifier">Objective</span><span class="special">&&</span> <span class="identifier">obj</span><span class="special">,</span>
|
|
<span class="identifier">ArgumentContainer</span><span class="special">&</span> <span class="identifier">x</span><span class="special">,</span>
|
|
<span class="identifier">RealType</span> <span class="identifier">lr</span><span class="special">,</span>
|
|
<span class="identifier">RealType</span> <span class="identifier">mu</span><span class="special">,</span>
|
|
<span class="identifier">InitializationPolicy</span><span class="special">&&</span> <span class="identifier">ip</span><span class="special">);</span>
|
|
|
|
<span class="comment">/* provide
|
|
* initialization policy
|
|
* objective evaluation policy
|
|
* gradient evaluation policy
|
|
*/</span>
|
|
<span class="keyword">template</span><span class="special"><</span><span class="keyword">typename</span> <span class="identifier">ArgumentContainer</span><span class="special">,</span>
|
|
<span class="keyword">typename</span> <span class="identifier">RealType</span><span class="special">,</span>
|
|
<span class="keyword">class</span> <span class="identifier">Objective</span><span class="special">,</span>
|
|
<span class="keyword">class</span> <span class="identifier">InitializationPolicy</span><span class="special">,</span>
|
|
<span class="keyword">class</span> <span class="identifier">ObjectiveEvalPolicy</span><span class="special">,</span>
|
|
<span class="keyword">class</span> <span class="identifier">GradEvalPolicy</span><span class="special">></span>
|
|
<span class="keyword">auto</span> <span class="identifier">make_nag</span><span class="special">(</span><span class="identifier">Objective</span><span class="special">&&</span> <span class="identifier">obj</span><span class="special">,</span>
|
|
<span class="identifier">ArgumentContainer</span><span class="special">&</span> <span class="identifier">x</span><span class="special">,</span>
|
|
<span class="identifier">RealType</span> <span class="identifier">lr</span><span class="special">,</span>
|
|
<span class="identifier">RealType</span> <span class="identifier">mu</span><span class="special">,</span>
|
|
<span class="identifier">InitializationPolicy</span><span class="special">&&</span> <span class="identifier">ip</span><span class="special">,</span>
|
|
<span class="identifier">ObjectiveEvalPolicy</span><span class="special">&&</span> <span class="identifier">oep</span><span class="special">,</span>
|
|
<span class="identifier">GradEvalPolicy</span><span class="special">&&</span> <span class="identifier">gep</span><span class="special">);</span>
|
|
|
|
<span class="special">}</span> <span class="comment">// namespace optimization</span>
|
|
<span class="special">}</span> <span class="comment">// namespace math</span>
|
|
<span class="special">}</span> <span class="comment">// namespace boost</span>
|
|
</pre>
|
|
<p>
|
|
Nesterov accelerated gradient (NAG) is a first-order optimizer that augments
|
|
gradient descent with a momentum term and evaluates the gradient at a "lookahead"
|
|
point. In practice this often improves convergence speed compared to vanilla
|
|
gradient descent, especially in narrow valleys and ill-conditioned problems.
|
|
</p>
|
|
<h5>
|
|
<a name="math_toolkit.gd_opt.nesterov.h1"></a>
|
|
<span class="phrase"><a name="math_toolkit.gd_opt.nesterov.algorithm"></a></span><a class="link" href="nesterov.html#math_toolkit.gd_opt.nesterov.algorithm">Algorithm</a>
|
|
</h5>
|
|
<p>
|
|
NAG maintains a "velocity" vector v (same shape as x). At iteration
|
|
k it performs:
|
|
</p>
|
|
<pre class="programlisting"><span class="identifier">v</span> <span class="special">=</span> <span class="identifier">mu</span> <span class="special">*</span> <span class="identifier">v</span> <span class="special">-</span> <span class="identifier">lr</span> <span class="special">*</span> <span class="identifier">g</span><span class="special">;</span>
|
|
<span class="identifier">x</span> <span class="special">+=</span> <span class="special">-</span><span class="identifier">mu</span> <span class="special">*</span> <span class="identifier">v_prev</span> <span class="special">+</span> <span class="special">(</span><span class="number">1</span> <span class="special">+</span> <span class="identifier">mu</span><span class="special">)</span> <span class="special">*</span><span class="identifier">v</span>
|
|
</pre>
|
|
<p>
|
|
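where g is the gradient of the objective evaluated at the current x (in this formulation x itself stores the lookahead point, so no separate lookahead evaluation is needed), v_prev is the value of v before the first update above, and: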
|
|
</p>
|
|
<p>
|
|
lr is the learning rate / step size
|
|
</p>
|
|
<p>
|
|
mu is the momentum coefficient in [0, 1) (typically close to 1)
|
|
</p>
|
|
<p>
|
|
Setting mu = 0 reduces NAG to standard gradient descent.
|
|
</p>
|
|
<h5>
|
|
<a name="math_toolkit.gd_opt.nesterov.h2"></a>
|
|
<span class="phrase"><a name="math_toolkit.gd_opt.nesterov.parameters"></a></span><a class="link" href="nesterov.html#math_toolkit.gd_opt.nesterov.parameters">Parameters</a>
|
|
</h5>
|
|
<div class="itemizedlist"><ul class="itemizedlist" style="list-style-type: disc; ">
|
|
<li class="listitem">
|
|
<code class="computeroutput"><span class="identifier">Objective</span><span class="special">&&</span>
|
|
<span class="identifier">obj</span></code> : objective function to
|
|
minimize.
|
|
</li>
|
|
<li class="listitem">
|
|
<code class="computeroutput"><span class="identifier">ArgumentContainer</span><span class="special">&</span>
|
|
<span class="identifier">x</span></code> : variables to optimize over.
|
|
Updated in place.
|
|
</li>
|
|
<li class="listitem">
|
|
<code class="computeroutput"><span class="identifier">RealType</span> <span class="identifier">lr</span></code>
|
|
: learning rate. Larger values take larger steps (faster but potentially
|
|
unstable). Smaller values are more stable but converge more slowly.
|
|
</li>
|
|
<li class="listitem">
|
|
<code class="computeroutput"><span class="identifier">RealType</span> <span class="identifier">mu</span></code>
|
|
: momentum coefficient in <code class="computeroutput"><span class="special">[</span><span class="number">0</span><span class="special">,</span><span class="number">1</span><span class="special">)</span></code>. Higher values, e.g. 0.9 to 0.99, typically
|
|
accelerate convergence but may require a smaller <code class="computeroutput"><span class="identifier">lr</span></code>.
|
|
</li>
|
|
<li class="listitem">
|
|
<code class="computeroutput"><span class="identifier">InitializationPolicy</span><span class="special">&&</span> <span class="identifier">ip</span></code>
|
|
: initialization policy for optimizer state and variables. Users may
|
|
supply a custom initialization policy to control how the argument container
|
|
and any AD-specific runtime state (e.g. reverse-mode tape attachment/reset)
|
|
are initialized. By default, the optimizer uses the same initialization
|
|
as gradient descent: it takes the user-provided initial values in x and
|
|
initializes the internal momentum/velocity state to zero. Custom initialization
|
|
policies are useful for randomized starts, non-rvar AD types, or when
|
|
gradients are supplied externally. See the reverse-mode autodiff policy
|
|
documentation for the structure a custom initialization policy must follow.
|
|
</li>
|
|
<li class="listitem">
|
|
<code class="computeroutput"><span class="identifier">ObjectiveEvalPolicy</span><span class="special">&&</span>
|
|
<span class="identifier">oep</span></code> : objective evaluation
|
|
policy. Defaults to <code class="computeroutput"><span class="identifier">reverse_mode_function_eval_policy</span><span class="special"><</span><span class="identifier">RealType</span><span class="special">></span></code>.
|
|
</li>
|
|
<li class="listitem">
|
|
<code class="computeroutput"><span class="identifier">GradEvalPolicy</span><span class="special">&&</span>
|
|
<span class="identifier">gep</span></code> : gradient evaluation policy.
|
|
Defaults to <code class="computeroutput"><span class="identifier">reverse_mode_gradient_evaluation_policy</span><span class="special"><</span><span class="identifier">RealType</span><span class="special">></span></code>.
|
|
</li>
|
|
</ul></div>
|
|
<h5>
|
|
<a name="math_toolkit.gd_opt.nesterov.h3"></a>
|
|
<span class="phrase"><a name="math_toolkit.gd_opt.nesterov.notes"></a></span><a class="link" href="nesterov.html#math_toolkit.gd_opt.nesterov.notes">Notes</a>
|
|
</h5>
|
|
<div class="itemizedlist"><ul class="itemizedlist" style="list-style-type: disc; ">
|
|
<li class="listitem">
|
|
NAG uses the same policy-based design as gradient descent: initialization,
|
|
objective evaluation, and gradient evaluation can be customized independently.
|
|
</li>
|
|
<li class="listitem">
|
|
When using reverse-mode AD <code class="computeroutput"><span class="identifier">rvar</span></code>,
|
|
the objective should be written in terms of AD variables so gradients
|
|
can be obtained automatically by the default gradient evaluation policy.
|
|
</li>
|
|
<li class="listitem">
|
|
Typical tuning: start with <code class="computeroutput"><span class="identifier">mu</span>
|
|
<span class="special">=</span> <span class="number">0.9</span></code>
|
|
or <code class="computeroutput"><span class="number">0.95</span></code>; if the objective
|
|
oscillates or diverges, reduce <code class="computeroutput"><span class="identifier">lr</span></code>
|
|
(or slightly reduce <code class="computeroutput"><span class="identifier">mu</span></code>).
|
|
</li>
|
|
</ul></div>
|
|
<h5>
|
|
<a name="math_toolkit.gd_opt.nesterov.h4"></a>
|
|
<span class="phrase"><a name="math_toolkit.gd_opt.nesterov.example_thomson_sphere"></a></span><a class="link" href="nesterov.html#math_toolkit.gd_opt.nesterov.example_thomson_sphere">Example
|
|
: Thomson Sphere</a>
|
|
</h5>
|
|
<pre class="programlisting"><span class="preprocessor">#include</span> <span class="special"><</span><span class="identifier">boost</span><span class="special">/</span><span class="identifier">math</span><span class="special">/</span><span class="identifier">differentiation</span><span class="special">/</span><span class="identifier">autodiff_reverse</span><span class="special">.</span><span class="identifier">hpp</span><span class="special">></span>
|
|
<span class="preprocessor">#include</span> <span class="special"><</span><span class="identifier">boost</span><span class="special">/</span><span class="identifier">math</span><span class="special">/</span><span class="identifier">optimization</span><span class="special">/</span><span class="identifier">gradient_descent</span><span class="special">.</span><span class="identifier">hpp</span><span class="special">></span>
|
|
<span class="preprocessor">#include</span> <span class="special"><</span><span class="identifier">boost</span><span class="special">/</span><span class="identifier">math</span><span class="special">/</span><span class="identifier">optimization</span><span class="special">/</span><span class="identifier">minimizer</span><span class="special">.</span><span class="identifier">hpp</span><span class="special">></span>
|
|
<span class="preprocessor">#include</span> <span class="special"><</span><span class="identifier">cmath</span><span class="special">></span>
|
|
<span class="preprocessor">#include</span> <span class="special"><</span><span class="identifier">fstream</span><span class="special">></span>
|
|
<span class="preprocessor">#include</span> <span class="special"><</span><span class="identifier">iostream</span><span class="special">></span>
|
|
<span class="preprocessor">#include</span> <span class="special"><</span><span class="identifier">random</span><span class="special">></span>
|
|
<span class="preprocessor">#include</span> <span class="special"><</span><span class="identifier">string</span><span class="special">></span>
|
|
<span class="keyword">namespace</span> <span class="identifier">rdiff</span> <span class="special">=</span> <span class="identifier">boost</span><span class="special">::</span><span class="identifier">math</span><span class="special">::</span><span class="identifier">differentiation</span><span class="special">::</span><span class="identifier">reverse_mode</span><span class="special">;</span>
|
|
<span class="keyword">namespace</span> <span class="identifier">bopt</span> <span class="special">=</span> <span class="identifier">boost</span><span class="special">::</span><span class="identifier">math</span><span class="special">::</span><span class="identifier">optimization</span><span class="special">;</span>
|
|
<span class="keyword">double</span> <span class="identifier">random_double</span><span class="special">(</span><span class="keyword">double</span> <span class="identifier">min</span> <span class="special">=</span> <span class="number">0.0</span><span class="special">,</span> <span class="keyword">double</span> <span class="identifier">max</span> <span class="special">=</span> <span class="number">1.0</span><span class="special">)</span>
|
|
<span class="special">{</span>
|
|
<span class="keyword">static</span> <span class="keyword">thread_local</span> <span class="identifier">std</span><span class="special">::</span><span class="identifier">mt19937</span> <span class="identifier">rng</span><span class="special">{</span><span class="identifier">std</span><span class="special">::</span><span class="identifier">random_device</span><span class="special">{}()};</span>
|
|
<span class="identifier">std</span><span class="special">::</span><span class="identifier">uniform_real_distribution</span><span class="special"><</span><span class="keyword">double</span><span class="special">></span> <span class="identifier">dist</span><span class="special">(</span><span class="identifier">min</span><span class="special">,</span> <span class="identifier">max</span><span class="special">);</span>
|
|
<span class="keyword">return</span> <span class="identifier">dist</span><span class="special">(</span><span class="identifier">rng</span><span class="special">);</span>
|
|
<span class="special">}</span>
|
|
|
|
<span class="keyword">template</span><span class="special"><</span><span class="keyword">typename</span> <span class="identifier">S</span><span class="special">></span>
|
|
<span class="keyword">struct</span> <span class="identifier">vec3</span>
|
|
<span class="special">{</span>
|
|
<span class="comment">/**
|
|
* @brief R^3 coordinates of particle on Thomson Sphere
|
|
*/</span>
|
|
<span class="identifier">S</span> <span class="identifier">x</span><span class="special">,</span> <span class="identifier">y</span><span class="special">,</span> <span class="identifier">z</span><span class="special">;</span>
|
|
<span class="special">};</span>
|
|
|
|
<span class="keyword">template</span><span class="special"><</span><span class="keyword">class</span> <span class="identifier">S</span><span class="special">></span>
|
|
<span class="keyword">static</span> <span class="keyword">inline</span> <span class="identifier">vec3</span><span class="special"><</span><span class="identifier">S</span><span class="special">></span> <span class="identifier">sph_to_xyz</span><span class="special">(</span><span class="keyword">const</span> <span class="identifier">S</span><span class="special">&</span> <span class="identifier">theta</span><span class="special">,</span> <span class="keyword">const</span> <span class="identifier">S</span><span class="special">&</span> <span class="identifier">phi</span><span class="special">)</span>
|
|
<span class="special">{</span>
|
|
<span class="comment">/**
|
|
* convenience helper: convert spherical angles [theta, phi] to Cartesian x, y, z on the unit sphere
|
|
*/</span>
|
|
<span class="keyword">return</span> <span class="special">{</span><span class="identifier">sin</span><span class="special">(</span><span class="identifier">theta</span><span class="special">)</span> <span class="special">*</span> <span class="identifier">cos</span><span class="special">(</span><span class="identifier">phi</span><span class="special">),</span> <span class="identifier">sin</span><span class="special">(</span><span class="identifier">theta</span><span class="special">)</span> <span class="special">*</span> <span class="identifier">sin</span><span class="special">(</span><span class="identifier">phi</span><span class="special">),</span> <span class="identifier">cos</span><span class="special">(</span><span class="identifier">theta</span><span class="special">)};</span>
|
|
<span class="special">}</span>
|
|
|
|
<span class="keyword">template</span><span class="special"><</span><span class="keyword">typename</span> <span class="identifier">T</span><span class="special">></span>
|
|
<span class="identifier">T</span> <span class="identifier">thomson_energy</span><span class="special">(</span><span class="identifier">std</span><span class="special">::</span><span class="identifier">vector</span><span class="special"><</span><span class="identifier">T</span><span class="special">>&</span> <span class="identifier">r</span><span class="special">)</span>
|
|
<span class="special">{</span>
|
|
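<span class="comment">/* Coulomb potential energy: sum of 1/|r_i - r_j| over all distinct particle pairs (tiny guards against division by zero)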
|
|
*/</span>
|
|
<span class="keyword">const</span> <span class="identifier">size_t</span> <span class="identifier">N</span> <span class="special">=</span> <span class="identifier">r</span><span class="special">.</span><span class="identifier">size</span><span class="special">()</span> <span class="special">/</span> <span class="number">2</span><span class="special">;</span>
|
|
<span class="keyword">const</span> <span class="identifier">T</span> <span class="identifier">tiny</span> <span class="special">=</span> <span class="identifier">T</span><span class="special">(</span><span class="number">1e-12</span><span class="special">);</span>
|
|
|
|
<span class="identifier">T</span> <span class="identifier">E</span> <span class="special">=</span> <span class="number">0</span><span class="special">;</span>
|
|
<span class="keyword">for</span> <span class="special">(</span><span class="identifier">size_t</span> <span class="identifier">i</span> <span class="special">=</span> <span class="number">0</span><span class="special">;</span> <span class="identifier">i</span> <span class="special"><</span> <span class="identifier">N</span><span class="special">;</span> <span class="special">++</span><span class="identifier">i</span><span class="special">)</span> <span class="special">{</span>
|
|
<span class="keyword">const</span> <span class="identifier">T</span><span class="special">&</span> <span class="identifier">theta_i</span> <span class="special">=</span> <span class="identifier">r</span><span class="special">[</span><span class="number">2</span> <span class="special">*</span> <span class="identifier">i</span> <span class="special">+</span> <span class="number">0</span><span class="special">];</span>
|
|
<span class="keyword">const</span> <span class="identifier">T</span><span class="special">&</span> <span class="identifier">phi_i</span> <span class="special">=</span> <span class="identifier">r</span><span class="special">[</span><span class="number">2</span> <span class="special">*</span> <span class="identifier">i</span> <span class="special">+</span> <span class="number">1</span><span class="special">];</span>
|
|
<span class="keyword">auto</span> <span class="identifier">ri</span> <span class="special">=</span> <span class="identifier">sph_to_xyz</span><span class="special">(</span><span class="identifier">theta_i</span><span class="special">,</span> <span class="identifier">phi_i</span><span class="special">);</span>
|
|
|
|
<span class="keyword">for</span> <span class="special">(</span><span class="identifier">size_t</span> <span class="identifier">j</span> <span class="special">=</span> <span class="identifier">i</span> <span class="special">+</span> <span class="number">1</span><span class="special">;</span> <span class="identifier">j</span> <span class="special"><</span> <span class="identifier">N</span><span class="special">;</span> <span class="special">++</span><span class="identifier">j</span><span class="special">)</span> <span class="special">{</span>
|
|
<span class="keyword">const</span> <span class="identifier">T</span><span class="special">&</span> <span class="identifier">theta_j</span> <span class="special">=</span> <span class="identifier">r</span><span class="special">[</span><span class="number">2</span> <span class="special">*</span> <span class="identifier">j</span> <span class="special">+</span> <span class="number">0</span><span class="special">];</span>
|
|
<span class="keyword">const</span> <span class="identifier">T</span><span class="special">&</span> <span class="identifier">phi_j</span> <span class="special">=</span> <span class="identifier">r</span><span class="special">[</span><span class="number">2</span> <span class="special">*</span> <span class="identifier">j</span> <span class="special">+</span> <span class="number">1</span><span class="special">];</span>
|
|
<span class="keyword">auto</span> <span class="identifier">rj</span> <span class="special">=</span> <span class="identifier">sph_to_xyz</span><span class="special">(</span><span class="identifier">theta_j</span><span class="special">,</span> <span class="identifier">phi_j</span><span class="special">);</span>
|
|
|
|
<span class="identifier">T</span> <span class="identifier">dx</span> <span class="special">=</span> <span class="identifier">ri</span><span class="special">.</span><span class="identifier">x</span> <span class="special">-</span> <span class="identifier">rj</span><span class="special">.</span><span class="identifier">x</span><span class="special">;</span>
|
|
<span class="identifier">T</span> <span class="identifier">dy</span> <span class="special">=</span> <span class="identifier">ri</span><span class="special">.</span><span class="identifier">y</span> <span class="special">-</span> <span class="identifier">rj</span><span class="special">.</span><span class="identifier">y</span><span class="special">;</span>
|
|
<span class="identifier">T</span> <span class="identifier">dz</span> <span class="special">=</span> <span class="identifier">ri</span><span class="special">.</span><span class="identifier">z</span> <span class="special">-</span> <span class="identifier">rj</span><span class="special">.</span><span class="identifier">z</span><span class="special">;</span>
|
|
|
|
<span class="identifier">T</span> <span class="identifier">d2</span> <span class="special">=</span> <span class="identifier">dx</span> <span class="special">*</span> <span class="identifier">dx</span> <span class="special">+</span> <span class="identifier">dy</span> <span class="special">*</span> <span class="identifier">dy</span> <span class="special">+</span> <span class="identifier">dz</span> <span class="special">*</span> <span class="identifier">dz</span> <span class="special">+</span> <span class="identifier">tiny</span><span class="special">;</span>
|
|
<span class="identifier">E</span> <span class="special">+=</span> <span class="number">1.0</span> <span class="special">/</span> <span class="identifier">sqrt</span><span class="special">(</span><span class="identifier">d2</span><span class="special">);</span>
|
|
<span class="special">}</span>
|
|
<span class="special">}</span>
|
|
<span class="keyword">return</span> <span class="identifier">E</span><span class="special">;</span>
|
|
<span class="special">}</span>
|
|
|
|
<span class="keyword">template</span><span class="special"><</span><span class="keyword">class</span> <span class="identifier">T</span><span class="special">></span>
|
|
<span class="identifier">std</span><span class="special">::</span><span class="identifier">vector</span><span class="special"><</span><span class="identifier">rdiff</span><span class="special">::</span><span class="identifier">rvar</span><span class="special"><</span><span class="identifier">T</span><span class="special">,</span> <span class="number">1</span><span class="special">>></span> <span class="identifier">init_theta_phi_uniform</span><span class="special">(</span><span class="identifier">size_t</span> <span class="identifier">N</span><span class="special">,</span> <span class="keyword">unsigned</span> <span class="identifier">seed</span> <span class="special">=</span> <span class="number">12345</span><span class="special">)</span>
|
|
<span class="special">{</span>
|
|
<span class="keyword">const</span> <span class="identifier">T</span> <span class="identifier">pi</span> <span class="special">=</span> <span class="identifier">T</span><span class="special">(</span><span class="number">3.1415926535897932384626433832795</span><span class="special">);</span>
|
|
|
|
<span class="identifier">std</span><span class="special">::</span><span class="identifier">mt19937</span> <span class="identifier">rng</span><span class="special">(</span><span class="identifier">seed</span><span class="special">);</span>
|
|
<span class="identifier">std</span><span class="special">::</span><span class="identifier">uniform_real_distribution</span><span class="special"><</span><span class="identifier">T</span><span class="special">></span> <span class="identifier">unif01</span><span class="special">(</span><span class="identifier">T</span><span class="special">(</span><span class="number">0</span><span class="special">),</span> <span class="identifier">T</span><span class="special">(</span><span class="number">1</span><span class="special">));</span>
|
|
<span class="identifier">std</span><span class="special">::</span><span class="identifier">uniform_real_distribution</span><span class="special"><</span><span class="identifier">T</span><span class="special">></span> <span class="identifier">unifm11</span><span class="special">(</span><span class="identifier">T</span><span class="special">(-</span><span class="number">1</span><span class="special">),</span> <span class="identifier">T</span><span class="special">(</span><span class="number">1</span><span class="special">));</span>
|
|
|
|
<span class="identifier">std</span><span class="special">::</span><span class="identifier">vector</span><span class="special"><</span><span class="identifier">rdiff</span><span class="special">::</span><span class="identifier">rvar</span><span class="special"><</span><span class="identifier">T</span><span class="special">,</span> <span class="number">1</span><span class="special">>></span> <span class="identifier">u</span><span class="special">;</span>
|
|
<span class="identifier">u</span><span class="special">.</span><span class="identifier">reserve</span><span class="special">(</span><span class="number">2</span> <span class="special">*</span> <span class="identifier">N</span><span class="special">);</span>
|
|
|
|
<span class="keyword">for</span> <span class="special">(</span><span class="identifier">size_t</span> <span class="identifier">i</span> <span class="special">=</span> <span class="number">0</span><span class="special">;</span> <span class="identifier">i</span> <span class="special"><</span> <span class="identifier">N</span><span class="special">;</span> <span class="special">++</span><span class="identifier">i</span><span class="special">)</span> <span class="special">{</span>
|
|
<span class="identifier">T</span> <span class="identifier">z</span> <span class="special">=</span> <span class="identifier">unifm11</span><span class="special">(</span><span class="identifier">rng</span><span class="special">);</span>
|
|
<span class="identifier">T</span> <span class="identifier">phi</span> <span class="special">=</span> <span class="special">(</span><span class="identifier">T</span><span class="special">(</span><span class="number">2</span><span class="special">)</span> <span class="special">*</span> <span class="identifier">pi</span><span class="special">)</span> <span class="special">*</span> <span class="identifier">unif01</span><span class="special">(</span><span class="identifier">rng</span><span class="special">)</span> <span class="special">-</span> <span class="identifier">pi</span><span class="special">;</span>
|
|
<span class="identifier">T</span> <span class="identifier">theta</span> <span class="special">=</span> <span class="identifier">std</span><span class="special">::</span><span class="identifier">acos</span><span class="special">(</span><span class="identifier">z</span><span class="special">);</span>
|
|
|
|
<span class="identifier">u</span><span class="special">.</span><span class="identifier">emplace_back</span><span class="special">(</span><span class="identifier">theta</span><span class="special">);</span>
|
|
<span class="identifier">u</span><span class="special">.</span><span class="identifier">emplace_back</span><span class="special">(</span><span class="identifier">phi</span><span class="special">);</span>
|
|
<span class="special">}</span>
|
|
<span class="keyword">return</span> <span class="identifier">u</span><span class="special">;</span>
|
|
<span class="special">}</span>
|
|
|
|
<span class="keyword">int</span> <span class="identifier">main</span><span class="special">(</span><span class="keyword">int</span> <span class="identifier">argc</span><span class="special">,</span> <span class="keyword">char</span><span class="special">*</span> <span class="identifier">argv</span><span class="special">[])</span>
|
|
<span class="special">{</span>
|
|
|
|
<span class="keyword">if</span> <span class="special">(</span><span class="identifier">argc</span> <span class="special">!=</span> <span class="number">2</span><span class="special">)</span> <span class="special">{</span>
|
|
<span class="identifier">std</span><span class="special">::</span><span class="identifier">cerr</span> <span class="special"><<</span> <span class="string">"Usage: "</span> <span class="special"><<</span> <span class="identifier">argv</span><span class="special">[</span><span class="number">0</span><span class="special">]</span> <span class="special"><<</span> <span class="string">" <N>\n"</span><span class="special">;</span>
|
|
<span class="keyword">return</span> <span class="number">1</span><span class="special">;</span>
|
|
<span class="special">}</span>
|
|
|
|
<span class="keyword">const</span> <span class="keyword">int</span> <span class="identifier">N</span> <span class="special">=</span> <span class="identifier">std</span><span class="special">::</span><span class="identifier">stoi</span><span class="special">(</span><span class="identifier">argv</span><span class="special">[</span><span class="number">1</span><span class="special">]);</span>
|
|
<span class="keyword">const</span> <span class="keyword">double</span> <span class="identifier">lr</span> <span class="special">=</span> <span class="number">1e-3</span><span class="special">;</span>
|
|
<span class="keyword">const</span> <span class="keyword">double</span> <span class="identifier">mu</span> <span class="special">=</span> <span class="number">0.95</span><span class="special">;</span>
|
|
<span class="keyword">auto</span> <span class="identifier">u_ad</span> <span class="special">=</span> <span class="identifier">init_theta_phi_uniform</span><span class="special"><</span><span class="keyword">double</span><span class="special">>(</span><span class="identifier">N</span><span class="special">);</span>
|
|
|
|
<span class="keyword">auto</span> <span class="identifier">nesterov_opt</span> <span class="special">=</span> <span class="identifier">bopt</span><span class="special">::</span><span class="identifier">make_nag</span><span class="special">(&</span><span class="identifier">thomson_energy</span><span class="special"><</span><span class="identifier">rdiff</span><span class="special">::</span><span class="identifier">rvar</span><span class="special"><</span><span class="keyword">double</span><span class="special">,</span> <span class="number">1</span><span class="special">>>,</span> <span class="identifier">u_ad</span><span class="special">,</span> <span class="identifier">lr</span><span class="special">,</span> <span class="identifier">mu</span><span class="special">);</span>
|
|
|
|
<span class="comment">// filenames</span>
|
|
<span class="identifier">std</span><span class="special">::</span><span class="identifier">string</span> <span class="identifier">pos_filename</span> <span class="special">=</span> <span class="string">"thomson_"</span> <span class="special">+</span> <span class="identifier">std</span><span class="special">::</span><span class="identifier">to_string</span><span class="special">(</span><span class="identifier">N</span><span class="special">)</span> <span class="special">+</span> <span class="string">".csv"</span><span class="special">;</span>
|
|
<span class="identifier">std</span><span class="special">::</span><span class="identifier">string</span> <span class="identifier">energy_filename</span> <span class="special">=</span> <span class="string">"nesterov_energy_"</span> <span class="special">+</span> <span class="identifier">std</span><span class="special">::</span><span class="identifier">to_string</span><span class="special">(</span><span class="identifier">N</span><span class="special">)</span> <span class="special">+</span> <span class="string">".csv"</span><span class="special">;</span>
|
|
|
|
<span class="identifier">std</span><span class="special">::</span><span class="identifier">ofstream</span> <span class="identifier">pos_out</span><span class="special">(</span><span class="identifier">pos_filename</span><span class="special">);</span>
|
|
<span class="identifier">std</span><span class="special">::</span><span class="identifier">ofstream</span> <span class="identifier">energy_out</span><span class="special">(</span><span class="identifier">energy_filename</span><span class="special">);</span>
|
|
|
|
<span class="identifier">energy_out</span> <span class="special"><<</span> <span class="string">"step,energy\n"</span><span class="special">;</span>
|
|
|
|
<span class="keyword">auto</span> <span class="identifier">result</span> <span class="special">=</span> <span class="identifier">minimize</span><span class="special">(</span><span class="identifier">nesterov_opt</span><span class="special">);</span>
|
|
<span class="keyword">for</span> <span class="special">(</span><span class="keyword">int</span> <span class="identifier">pi</span> <span class="special">=</span> <span class="number">0</span><span class="special">;</span> <span class="identifier">pi</span> <span class="special"><</span> <span class="identifier">N</span><span class="special">;</span> <span class="special">++</span><span class="identifier">pi</span><span class="special">)</span> <span class="special">{</span>
|
|
<span class="keyword">double</span> <span class="identifier">theta</span> <span class="special">=</span> <span class="identifier">u_ad</span><span class="special">[</span><span class="number">2</span> <span class="special">*</span> <span class="identifier">pi</span> <span class="special">+</span> <span class="number">0</span><span class="special">].</span><span class="identifier">item</span><span class="special">();</span>
|
|
<span class="keyword">double</span> <span class="identifier">phi</span> <span class="special">=</span> <span class="identifier">u_ad</span><span class="special">[</span><span class="number">2</span> <span class="special">*</span> <span class="identifier">pi</span> <span class="special">+</span> <span class="number">1</span><span class="special">].</span><span class="identifier">item</span><span class="special">();</span>
|
|
<span class="keyword">auto</span> <span class="identifier">r</span> <span class="special">=</span> <span class="identifier">sph_to_xyz</span><span class="special">(</span><span class="identifier">theta</span><span class="special">,</span> <span class="identifier">phi</span><span class="special">);</span>
|
|
<span class="identifier">pos_out</span> <span class="special"><<</span> <span class="identifier">pi</span> <span class="special"><<</span> <span class="string">","</span> <span class="special"><<</span> <span class="identifier">r</span><span class="special">.</span><span class="identifier">x</span> <span class="special"><<</span> <span class="string">","</span> <span class="special"><<</span> <span class="identifier">r</span><span class="special">.</span><span class="identifier">y</span> <span class="special"><<</span> <span class="string">","</span> <span class="special"><<</span> <span class="identifier">r</span><span class="special">.</span><span class="identifier">z</span> <span class="special"><<</span> <span class="string">"\n"</span><span class="special">;</span>
|
|
<span class="special">}</span>
|
|
<span class="keyword">auto</span> <span class="identifier">E</span> <span class="special">=</span> <span class="identifier">nesterov_opt</span><span class="special">.</span><span class="identifier">objective_value</span><span class="special">();</span>
|
|
<span class="keyword">int</span> <span class="identifier">i</span> <span class="special">=</span> <span class="number">0</span><span class="special">;</span>
|
|
<span class="keyword">for</span><span class="special">(</span><span class="keyword">auto</span><span class="special">&</span> <span class="identifier">obj_hist</span> <span class="special">:</span> <span class="identifier">result</span><span class="special">.</span><span class="identifier">objective_history</span><span class="special">)</span>
|
|
<span class="special">{</span>
|
|
<span class="identifier">energy_out</span> <span class="special"><<</span> <span class="identifier">i</span> <span class="special"><<</span> <span class="string">","</span> <span class="special"><<</span> <span class="identifier">obj_hist</span> <span class="special"><<</span> <span class="string">"\n"</span><span class="special">;</span>
|
|
<span class="special">++</span><span class="identifier">i</span><span class="special">;</span>
|
|
<span class="special">}</span>
|
|
<span class="identifier">energy_out</span> <span class="special"><<</span> <span class="string">","</span> <span class="special"><<</span> <span class="identifier">E</span> <span class="special"><<</span> <span class="string">"\n"</span><span class="special">;</span>
|
|
|
|
<span class="identifier">pos_out</span><span class="special">.</span><span class="identifier">close</span><span class="special">();</span>
|
|
<span class="identifier">energy_out</span><span class="special">.</span><span class="identifier">close</span><span class="special">();</span>
|
|
|
|
<span class="keyword">return</span> <span class="number">0</span><span class="special">;</span>
|
|
<span class="special">}</span>
|
|
</pre>
|
|
<p>
|
|
        The Nesterov version of this problem converges much faster than regular gradient
|
|
descent, in only <code class="computeroutput"><span class="number">4663</span></code> iterations
|
|
        with default parameters, versus the <code class="computeroutput"><span class="number">93799</span></code>
|
|
iterations required by gradient descent.
|
|
</p>
|
|
<div class="blockquote"><blockquote class="blockquote"><div class="blockquote"><blockquote class="blockquote"><p>
|
|
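          <span class="inlinemediaobject"><img src="../../../graphs/gradient_based_optimizers/nag_to_gd_comparison.svg" alt="Comparison of convergence of Nesterov accelerated gradient descent and plain gradient descent on the Thomson problem"></span>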
|
|
</p></blockquote></div></blockquote></div>
|
|
</div>
|
|
<div class="copyright-footer">Copyright © 2006-2021 Nikhar Agrawal, Anton Bikineev, Matthew Borland,
|
|
Paul A. Bristow, Marco Guazzone, Christopher Kormanyos, Hubert Holin, Bruno
|
|
Lalande, John Maddock, Evan Miller, Jeremy Murphy, Matthew Pulver, Johan Råde,
|
|
Gautam Sewani, Benjamin Sobotta, Nicholas Thompson, Thijs van den Berg, Daryle
|
|
Walker, Xiaogang Zhang, and Maksym Zhelyeznyakov<p>
|
|
Distributed under the Boost Software License, Version 1.0. (See accompanying
|
|
file LICENSE_1_0.txt or copy at <a href="http://www.boost.org/LICENSE_1_0.txt" target="_top">http://www.boost.org/LICENSE_1_0.txt</a>)
|
|
</p>
|
|
</div>
|
|
<hr>
|
|
<div class="spirit-nav">
|
|
<a accesskey="p" href="gradient_descent.html"><img src="../../../../../../doc/src/images/prev.png" alt="Prev"></a><a accesskey="u" href="../gd_opt.html"><img src="../../../../../../doc/src/images/up.png" alt="Up"></a><a accesskey="h" href="../../index.html"><img src="../../../../../../doc/src/images/home.png" alt="Home"></a><a accesskey="n" href="lbfgs.html"><img src="../../../../../../doc/src/images/next.png" alt="Next"></a>
|
|
</div>
|
|
</body>
|
|
</html>
|