2
0
mirror of https://github.com/boostorg/math.git synced 2026-02-25 04:22:15 +00:00
Files
math/doc/html/math_toolkit/gd_opt/nesterov.html
2026-01-28 22:12:08 -05:00

389 lines
50 KiB
HTML

<html>
<head>
<meta charset="UTF-8">
<title>Nesterov Accelerated Gradient Descent</title>
<link rel="stylesheet" href="../../math.css" type="text/css">
<meta name="generator" content="DocBook XSL Stylesheets Vsnapshot">
<link rel="home" href="../../index.html" title="Math Toolkit 4.2.1">
<link rel="up" href="../gd_opt.html" title="Gradient Based Optimizers">
<link rel="prev" href="gradient_descent.html" title="Gradient Descent">
<link rel="next" href="lbfgs.html" title="L-BFGS">
<meta name="viewport" content="width=device-width, initial-scale=1">
</head>
<body bgcolor="white" text="black" link="#0000FF" vlink="#840084" alink="#0000FF">
<table cellpadding="2" width="100%"><tr>
<td valign="top"><img alt="Boost C++ Libraries" width="277" height="86" src="../../../../../../boost.png"></td>
<td align="center"><a href="../../../../../../index.html">Home</a></td>
<td align="center"><a href="../../../../../../libs/libraries.htm">Libraries</a></td>
<td align="center"><a href="http://www.boost.org/users/people.html">People</a></td>
<td align="center"><a href="http://www.boost.org/users/faq.html">FAQ</a></td>
<td align="center"><a href="../../../../../../more/index.htm">More</a></td>
</tr></table>
<hr>
<div class="spirit-nav">
<a accesskey="p" href="gradient_descent.html"><img src="../../../../../../doc/src/images/prev.png" alt="Prev"></a><a accesskey="u" href="../gd_opt.html"><img src="../../../../../../doc/src/images/up.png" alt="Up"></a><a accesskey="h" href="../../index.html"><img src="../../../../../../doc/src/images/home.png" alt="Home"></a><a accesskey="n" href="lbfgs.html"><img src="../../../../../../doc/src/images/next.png" alt="Next"></a>
</div>
<div class="section">
<div class="titlepage"><div><div><h3 class="title">
<a name="math_toolkit.gd_opt.nesterov"></a><a class="link" href="nesterov.html" title="Nesterov Accelerated Gradient Descent">Nesterov Accelerated Gradient
Descent</a>
</h3></div></div></div>
<h5>
<a name="math_toolkit.gd_opt.nesterov.h0"></a>
<span class="phrase"><a name="math_toolkit.gd_opt.nesterov.synopsis"></a></span><a class="link" href="nesterov.html#math_toolkit.gd_opt.nesterov.synopsis">Synopsis</a>
</h5>
<pre class="programlisting"><span class="preprocessor">#include</span> <span class="special">&lt;</span><span class="identifier">boost</span><span class="special">/</span><span class="identifier">math</span><span class="special">/</span><span class="identifier">optimization</span><span class="special">/</span><span class="identifier">nesterov</span><span class="special">.</span><span class="identifier">hpp</span><span class="special">&gt;</span>
<span class="keyword">namespace</span> <span class="identifier">boost</span> <span class="special">{</span>
<span class="keyword">namespace</span> <span class="identifier">math</span> <span class="special">{</span>
<span class="keyword">namespace</span> <span class="identifier">optimization</span> <span class="special">{</span>
<span class="keyword">namespace</span> <span class="identifier">rdiff</span> <span class="special">=</span> <span class="identifier">boost</span><span class="special">::</span><span class="identifier">math</span><span class="special">::</span><span class="identifier">differentiation</span><span class="special">::</span><span class="identifier">reverse_mode</span><span class="special">;</span>
<span class="comment">/**
* @brief The nesterov_accelerated_gradient class
*
* https://jlmelville.github.io/mize/nesterov.html
*/</span>
<span class="keyword">template</span><span class="special">&lt;</span><span class="keyword">typename</span> <span class="identifier">ArgumentContainer</span><span class="special">,</span>
<span class="keyword">typename</span> <span class="identifier">RealType</span><span class="special">,</span>
<span class="keyword">class</span> <span class="identifier">Objective</span><span class="special">,</span>
<span class="keyword">class</span> <span class="identifier">InitializationPolicy</span><span class="special">,</span>
<span class="keyword">class</span> <span class="identifier">ObjectiveEvalPolicy</span><span class="special">,</span>
<span class="keyword">class</span> <span class="identifier">GradEvalPolicy</span><span class="special">&gt;</span>
<span class="keyword">class</span> <span class="identifier">nesterov_accelerated_gradient</span>
<span class="special">:</span> <span class="keyword">public</span> <span class="identifier">abstract_optimizer</span><span class="special">&lt;</span>
<span class="identifier">ArgumentContainer</span><span class="special">,</span>
<span class="identifier">RealType</span><span class="special">,</span>
<span class="identifier">Objective</span><span class="special">,</span>
<span class="identifier">InitializationPolicy</span><span class="special">,</span>
<span class="identifier">ObjectiveEvalPolicy</span><span class="special">,</span>
<span class="identifier">GradEvalPolicy</span><span class="special">,</span>
<span class="identifier">nesterov_update_policy</span><span class="special">&lt;</span><span class="identifier">RealType</span><span class="special">&gt;,</span>
<span class="identifier">nesterov_accelerated_gradient</span><span class="special">&lt;</span><span class="identifier">ArgumentContainer</span><span class="special">,</span>
<span class="identifier">RealType</span><span class="special">,</span>
<span class="identifier">Objective</span><span class="special">,</span>
<span class="identifier">InitializationPolicy</span><span class="special">,</span>
<span class="identifier">ObjectiveEvalPolicy</span><span class="special">,</span>
<span class="identifier">GradEvalPolicy</span><span class="special">&gt;&gt;</span>
<span class="special">{</span>
<span class="keyword">public</span><span class="special">:</span>
<span class="identifier">nesterov_accelerated_gradient</span><span class="special">(</span><span class="identifier">Objective</span><span class="special">&amp;&amp;</span> <span class="identifier">objective</span><span class="special">,</span>
<span class="identifier">ArgumentContainer</span><span class="special">&amp;</span> <span class="identifier">x</span><span class="special">,</span>
<span class="identifier">InitializationPolicy</span><span class="special">&amp;&amp;</span> <span class="identifier">ip</span><span class="special">,</span>
<span class="identifier">ObjectiveEvalPolicy</span><span class="special">&amp;&amp;</span> <span class="identifier">oep</span><span class="special">,</span>
<span class="identifier">GradEvalPolicy</span><span class="special">&amp;&amp;</span> <span class="identifier">gep</span><span class="special">,</span>
<span class="identifier">nesterov_update_policy</span><span class="special">&lt;</span><span class="identifier">RealType</span><span class="special">&gt;&amp;&amp;</span> <span class="identifier">up</span><span class="special">);</span>
<span class="keyword">void</span> <span class="identifier">step</span><span class="special">();</span>
<span class="special">};</span>
<span class="comment">/* Convenience overloads */</span>
<span class="comment">/* make nesterov accelerated gradient descent by providing
** objective function
** variables to optimize over
** Optionally
* - lr: learning rate / step size (typical: 1e-4 .. 1e-1 depending on scaling)
* - mu: momentum coefficient in [0, 1) (typical: 0.8 .. 0.99)
*/</span>
<span class="keyword">template</span><span class="special">&lt;</span><span class="keyword">class</span> <span class="identifier">Objective</span><span class="special">,</span> <span class="keyword">typename</span> <span class="identifier">ArgumentContainer</span><span class="special">,</span> <span class="keyword">typename</span> <span class="identifier">RealType</span><span class="special">&gt;</span>
<span class="keyword">auto</span> <span class="identifier">make_nag</span><span class="special">(</span><span class="identifier">Objective</span><span class="special">&amp;&amp;</span> <span class="identifier">obj</span><span class="special">,</span>
<span class="identifier">ArgumentContainer</span><span class="special">&amp;</span> <span class="identifier">x</span><span class="special">,</span>
<span class="identifier">RealType</span> <span class="identifier">lr</span> <span class="special">=</span> <span class="identifier">RealType</span><span class="special">{</span> <span class="number">0.01</span> <span class="special">},</span>
<span class="identifier">RealType</span> <span class="identifier">mu</span> <span class="special">=</span> <span class="identifier">RealType</span><span class="special">{</span> <span class="number">0.95</span> <span class="special">});</span>
<span class="comment">/* provide initialization policy
* lr, and mu no longer optional
*/</span>
<span class="keyword">template</span><span class="special">&lt;</span><span class="keyword">class</span> <span class="identifier">Objective</span><span class="special">,</span>
<span class="keyword">typename</span> <span class="identifier">ArgumentContainer</span><span class="special">,</span>
<span class="keyword">typename</span> <span class="identifier">RealType</span><span class="special">,</span>
<span class="keyword">class</span> <span class="identifier">InitializationPolicy</span><span class="special">&gt;</span>
<span class="keyword">auto</span> <span class="identifier">make_nag</span><span class="special">(</span><span class="identifier">Objective</span><span class="special">&amp;&amp;</span> <span class="identifier">obj</span><span class="special">,</span>
<span class="identifier">ArgumentContainer</span><span class="special">&amp;</span> <span class="identifier">x</span><span class="special">,</span>
<span class="identifier">RealType</span> <span class="identifier">lr</span><span class="special">,</span>
<span class="identifier">RealType</span> <span class="identifier">mu</span><span class="special">,</span>
<span class="identifier">InitializationPolicy</span><span class="special">&amp;&amp;</span> <span class="identifier">ip</span><span class="special">);</span>
<span class="comment">/* provide
* initialization policy
* objective evaluation policy
* gradient evaluation policy
*/</span>
<span class="keyword">template</span><span class="special">&lt;</span><span class="keyword">typename</span> <span class="identifier">ArgumentContainer</span><span class="special">,</span>
<span class="keyword">typename</span> <span class="identifier">RealType</span><span class="special">,</span>
<span class="keyword">class</span> <span class="identifier">Objective</span><span class="special">,</span>
<span class="keyword">class</span> <span class="identifier">InitializationPolicy</span><span class="special">,</span>
<span class="keyword">class</span> <span class="identifier">ObjectiveEvalPolicy</span><span class="special">,</span>
<span class="keyword">class</span> <span class="identifier">GradEvalPolicy</span><span class="special">&gt;</span>
<span class="keyword">auto</span> <span class="identifier">make_nag</span><span class="special">(</span><span class="identifier">Objective</span><span class="special">&amp;&amp;</span> <span class="identifier">obj</span><span class="special">,</span>
<span class="identifier">ArgumentContainer</span><span class="special">&amp;</span> <span class="identifier">x</span><span class="special">,</span>
<span class="identifier">RealType</span> <span class="identifier">lr</span><span class="special">,</span>
<span class="identifier">RealType</span> <span class="identifier">mu</span><span class="special">,</span>
<span class="identifier">InitializationPolicy</span><span class="special">&amp;&amp;</span> <span class="identifier">ip</span><span class="special">,</span>
<span class="identifier">ObjectiveEvalPolicy</span><span class="special">&amp;&amp;</span> <span class="identifier">oep</span><span class="special">,</span>
<span class="identifier">GradEvalPolicy</span><span class="special">&amp;&amp;</span> <span class="identifier">gep</span><span class="special">);</span>
<span class="special">}</span> <span class="comment">// namespace optimization</span>
<span class="special">}</span> <span class="comment">// namespace math</span>
<span class="special">}</span> <span class="comment">// namespace boost</span>
</pre>
<p>
Nesterov accelerated gradient (NAG) is a first-order optimizer that augments
gradient descent with a momentum term and evaluates the gradient at a "lookahead"
point. In practice this often improves convergence speed compared to vanilla
gradient descent, especially in narrow valleys and ill-conditioned problems.
</p>
<h5>
<a name="math_toolkit.gd_opt.nesterov.h1"></a>
<span class="phrase"><a name="math_toolkit.gd_opt.nesterov.algorithm"></a></span><a class="link" href="nesterov.html#math_toolkit.gd_opt.nesterov.algorithm">Algorithm</a>
</h5>
<p>
NAG maintains a "velocity" vector v (same shape as x). At iteration
k it performs:
</p>
<pre class="programlisting"><span class="identifier">v</span> <span class="special">=</span> <span class="identifier">mu</span> <span class="special">*</span> <span class="identifier">v</span> <span class="special">-</span> <span class="identifier">lr</span> <span class="special">*</span> <span class="identifier">g</span><span class="special">;</span>
<span class="identifier">x</span> <span class="special">+=</span> <span class="special">-</span><span class="identifier">mu</span> <span class="special">*</span> <span class="identifier">v_prev</span> <span class="special">+</span> <span class="special">(</span><span class="number">1</span> <span class="special">+</span> <span class="identifier">mu</span><span class="special">)</span> <span class="special">*</span><span class="identifier">v</span>
</pre>
<p>
where g is the current gradient of the objective, v_prev is the value of
v before this update, and:
</p>
<p>
lr is the learning rate / step size
</p>
<p>
mu is the momentum coefficient (typically close to 1)
</p>
<p>
Setting mu = 0 reduces NAG to standard gradient descent.
</p>
<h5>
<a name="math_toolkit.gd_opt.nesterov.h2"></a>
<span class="phrase"><a name="math_toolkit.gd_opt.nesterov.parameters"></a></span><a class="link" href="nesterov.html#math_toolkit.gd_opt.nesterov.parameters">Parameters</a>
</h5>
<div class="itemizedlist"><ul class="itemizedlist" style="list-style-type: disc; ">
<li class="listitem">
<code class="computeroutput"><span class="identifier">Objective</span><span class="special">&amp;&amp;</span>
<span class="identifier">obj</span></code> : objective function to
minimize.
</li>
<li class="listitem">
<code class="computeroutput"><span class="identifier">ArgumentContainer</span><span class="special">&amp;</span>
<span class="identifier">x</span></code> : variables to optimize over.
Updated in place.
</li>
<li class="listitem">
<code class="computeroutput"><span class="identifier">RealType</span> <span class="identifier">lr</span></code>
: learning rate. Larger values take larger steps (faster but potentially
unsable). Smaller values are more stable but converge more slowly.
</li>
<li class="listitem">
<code class="computeroutput"><span class="identifier">RealType</span> <span class="identifier">mu</span></code>
: momentum coefficient in <code class="computeroutput"><span class="special">[</span><span class="number">0</span><span class="special">,</span><span class="number">1</span><span class="special">)</span></code>. Higher values, e.g. 0.9 to 0.99, typically
accelerate convergence but may require a smaller <code class="computeroutput"><span class="identifier">lr</span></code>
</li>
<li class="listitem">
<code class="computeroutput"><span class="identifier">InitializationPolicy</span><span class="special">&amp;&amp;</span> <span class="identifier">ip</span></code>
: initialization policy for the optimizer state and variables. For NAG,
this also initializes the internal momentum/velocity state. By default
the optimizer uses the same initializer as gradient descent and initializes
velocity to zero.
</li>
<li class="listitem">
<code class="computeroutput"><span class="identifier">ObjectiveEvalPolicy</span><span class="special">&amp;&amp;</span>
<span class="identifier">oep</span></code> : objective evaluation
policy. By default <code class="computeroutput"><span class="identifier">reverse_mode_function_eval_policy</span><span class="special">&lt;</span><span class="identifier">RealType</span><span class="special">&gt;</span></code>
</li>
<li class="listitem">
<code class="computeroutput"><span class="identifier">GradEvalPolicy</span><span class="special">&amp;&amp;</span>
<span class="identifier">gep</span></code> : gradient evaluation policy.
By default <code class="computeroutput"><span class="identifier">reverse_mode_gradient_evaluation_policy</span><span class="special">&lt;</span><span class="identifier">RealType</span><span class="special">&gt;</span></code>
</li>
</ul></div>
<h5>
<a name="math_toolkit.gd_opt.nesterov.h3"></a>
<span class="phrase"><a name="math_toolkit.gd_opt.nesterov.notes"></a></span><a class="link" href="nesterov.html#math_toolkit.gd_opt.nesterov.notes">Notes</a>
</h5>
<div class="itemizedlist"><ul class="itemizedlist" style="list-style-type: disc; ">
<li class="listitem">
NAG uses the same policy-based design as gradient descent: initialization,
objective evaluation, and gradient evaluation can be customized independently.
</li>
<li class="listitem">
When using reverse-mode AD <code class="computeroutput"><span class="identifier">rvar</span></code>,
the objective should be written in terms of AD variables so gradients
can be obtained automatically by the default gradient evaluation policy.
</li>
<li class="listitem">
Typical tuning: start with <code class="computeroutput"><span class="identifier">mu</span>
<span class="special">=</span> <span class="number">0.9</span></code>
or <code class="computeroutput"><span class="number">0.95</span></code>; if the objective
oscillates or diverges, reduce <code class="computeroutput"><span class="identifier">lr</span></code>
(or slightly reduce <code class="computeroutput"><span class="identifier">mu</span></code>).
</li>
</ul></div>
<h5>
<a name="math_toolkit.gd_opt.nesterov.h4"></a>
<span class="phrase"><a name="math_toolkit.gd_opt.nesterov.example_thomson_sphere"></a></span><a class="link" href="nesterov.html#math_toolkit.gd_opt.nesterov.example_thomson_sphere">Example
: Thomson Sphere</a>
</h5>
<pre class="programlisting"><span class="preprocessor">#include</span> <span class="special">&lt;</span><span class="identifier">boost</span><span class="special">/</span><span class="identifier">math</span><span class="special">/</span><span class="identifier">differentiation</span><span class="special">/</span><span class="identifier">autodiff_reverse</span><span class="special">.</span><span class="identifier">hpp</span><span class="special">&gt;</span>
<span class="preprocessor">#include</span> <span class="special">&lt;</span><span class="identifier">boost</span><span class="special">/</span><span class="identifier">math</span><span class="special">/</span><span class="identifier">optimization</span><span class="special">/</span><span class="identifier">gradient_descent</span><span class="special">.</span><span class="identifier">hpp</span><span class="special">&gt;</span>
<span class="preprocessor">#include</span> <span class="special">&lt;</span><span class="identifier">boost</span><span class="special">/</span><span class="identifier">math</span><span class="special">/</span><span class="identifier">optimization</span><span class="special">/</span><span class="identifier">minimizer</span><span class="special">.</span><span class="identifier">hpp</span><span class="special">&gt;</span>
<span class="preprocessor">#include</span> <span class="special">&lt;</span><span class="identifier">cmath</span><span class="special">&gt;</span>
<span class="preprocessor">#include</span> <span class="special">&lt;</span><span class="identifier">fstream</span><span class="special">&gt;</span>
<span class="preprocessor">#include</span> <span class="special">&lt;</span><span class="identifier">iostream</span><span class="special">&gt;</span>
<span class="preprocessor">#include</span> <span class="special">&lt;</span><span class="identifier">random</span><span class="special">&gt;</span>
<span class="preprocessor">#include</span> <span class="special">&lt;</span><span class="identifier">string</span><span class="special">&gt;</span>
<span class="keyword">namespace</span> <span class="identifier">rdiff</span> <span class="special">=</span> <span class="identifier">boost</span><span class="special">::</span><span class="identifier">math</span><span class="special">::</span><span class="identifier">differentiation</span><span class="special">::</span><span class="identifier">reverse_mode</span><span class="special">;</span>
<span class="keyword">namespace</span> <span class="identifier">bopt</span> <span class="special">=</span> <span class="identifier">boost</span><span class="special">::</span><span class="identifier">math</span><span class="special">::</span><span class="identifier">optimization</span><span class="special">;</span>
<span class="keyword">double</span> <span class="identifier">random_double</span><span class="special">(</span><span class="keyword">double</span> <span class="identifier">min</span> <span class="special">=</span> <span class="number">0.0</span><span class="special">,</span> <span class="keyword">double</span> <span class="identifier">max</span> <span class="special">=</span> <span class="number">1.0</span><span class="special">)</span>
<span class="special">{</span>
<span class="keyword">static</span> <span class="keyword">thread_local</span> <span class="identifier">std</span><span class="special">::</span><span class="identifier">mt19937</span> <span class="identifier">rng</span><span class="special">{</span><span class="identifier">std</span><span class="special">::</span><span class="identifier">random_device</span><span class="special">{}()};</span>
<span class="identifier">std</span><span class="special">::</span><span class="identifier">uniform_real_distribution</span><span class="special">&lt;</span><span class="keyword">double</span><span class="special">&gt;</span> <span class="identifier">dist</span><span class="special">(</span><span class="identifier">min</span><span class="special">,</span> <span class="identifier">max</span><span class="special">);</span>
<span class="keyword">return</span> <span class="identifier">dist</span><span class="special">(</span><span class="identifier">rng</span><span class="special">);</span>
<span class="special">}</span>
<span class="keyword">template</span><span class="special">&lt;</span><span class="keyword">typename</span> <span class="identifier">S</span><span class="special">&gt;</span>
<span class="keyword">struct</span> <span class="identifier">vec3</span>
<span class="special">{</span>
<span class="comment">/**
* @brief R^3 coordinates of particle on Thomson Sphere
*/</span>
<span class="identifier">S</span> <span class="identifier">x</span><span class="special">,</span> <span class="identifier">y</span><span class="special">,</span> <span class="identifier">z</span><span class="special">;</span>
<span class="special">};</span>
<span class="keyword">template</span><span class="special">&lt;</span><span class="keyword">class</span> <span class="identifier">S</span><span class="special">&gt;</span>
<span class="keyword">static</span> <span class="keyword">inline</span> <span class="identifier">vec3</span><span class="special">&lt;</span><span class="identifier">S</span><span class="special">&gt;</span> <span class="identifier">sph_to_xyz</span><span class="special">(</span><span class="keyword">const</span> <span class="identifier">S</span><span class="special">&amp;</span> <span class="identifier">theta</span><span class="special">,</span> <span class="keyword">const</span> <span class="identifier">S</span><span class="special">&amp;</span> <span class="identifier">phi</span><span class="special">)</span>
<span class="special">{</span>
<span class="comment">/**
* convenience overload to convert from [theta,phi] -&gt; x, y, z
*/</span>
<span class="keyword">return</span> <span class="special">{</span><span class="identifier">sin</span><span class="special">(</span><span class="identifier">theta</span><span class="special">)</span> <span class="special">*</span> <span class="identifier">cos</span><span class="special">(</span><span class="identifier">phi</span><span class="special">),</span> <span class="identifier">sin</span><span class="special">(</span><span class="identifier">theta</span><span class="special">)</span> <span class="special">*</span> <span class="identifier">sin</span><span class="special">(</span><span class="identifier">phi</span><span class="special">),</span> <span class="identifier">cos</span><span class="special">(</span><span class="identifier">theta</span><span class="special">)};</span>
<span class="special">}</span>
<span class="keyword">template</span><span class="special">&lt;</span><span class="keyword">typename</span> <span class="identifier">T</span><span class="special">&gt;</span>
<span class="identifier">T</span> <span class="identifier">thomson_energy</span><span class="special">(</span><span class="identifier">std</span><span class="special">::</span><span class="identifier">vector</span><span class="special">&lt;</span><span class="identifier">T</span><span class="special">&gt;&amp;</span> <span class="identifier">r</span><span class="special">)</span>
<span class="special">{</span>
<span class="comment">/* Coulomb energy: sum over all pairs of 1 / distance
*/</span>
<span class="keyword">const</span> <span class="identifier">size_t</span> <span class="identifier">N</span> <span class="special">=</span> <span class="identifier">r</span><span class="special">.</span><span class="identifier">size</span><span class="special">()</span> <span class="special">/</span> <span class="number">2</span><span class="special">;</span>
<span class="keyword">const</span> <span class="identifier">T</span> <span class="identifier">tiny</span> <span class="special">=</span> <span class="identifier">T</span><span class="special">(</span><span class="number">1e-12</span><span class="special">);</span>
<span class="identifier">T</span> <span class="identifier">E</span> <span class="special">=</span> <span class="number">0</span><span class="special">;</span>
<span class="keyword">for</span> <span class="special">(</span><span class="identifier">size_t</span> <span class="identifier">i</span> <span class="special">=</span> <span class="number">0</span><span class="special">;</span> <span class="identifier">i</span> <span class="special">&lt;</span> <span class="identifier">N</span><span class="special">;</span> <span class="special">++</span><span class="identifier">i</span><span class="special">)</span> <span class="special">{</span>
<span class="keyword">const</span> <span class="identifier">T</span><span class="special">&amp;</span> <span class="identifier">theta_i</span> <span class="special">=</span> <span class="identifier">r</span><span class="special">[</span><span class="number">2</span> <span class="special">*</span> <span class="identifier">i</span> <span class="special">+</span> <span class="number">0</span><span class="special">];</span>
<span class="keyword">const</span> <span class="identifier">T</span><span class="special">&amp;</span> <span class="identifier">phi_i</span> <span class="special">=</span> <span class="identifier">r</span><span class="special">[</span><span class="number">2</span> <span class="special">*</span> <span class="identifier">i</span> <span class="special">+</span> <span class="number">1</span><span class="special">];</span>
<span class="keyword">auto</span> <span class="identifier">ri</span> <span class="special">=</span> <span class="identifier">sph_to_xyz</span><span class="special">(</span><span class="identifier">theta_i</span><span class="special">,</span> <span class="identifier">phi_i</span><span class="special">);</span>
<span class="keyword">for</span> <span class="special">(</span><span class="identifier">size_t</span> <span class="identifier">j</span> <span class="special">=</span> <span class="identifier">i</span> <span class="special">+</span> <span class="number">1</span><span class="special">;</span> <span class="identifier">j</span> <span class="special">&lt;</span> <span class="identifier">N</span><span class="special">;</span> <span class="special">++</span><span class="identifier">j</span><span class="special">)</span> <span class="special">{</span>
<span class="keyword">const</span> <span class="identifier">T</span><span class="special">&amp;</span> <span class="identifier">theta_j</span> <span class="special">=</span> <span class="identifier">r</span><span class="special">[</span><span class="number">2</span> <span class="special">*</span> <span class="identifier">j</span> <span class="special">+</span> <span class="number">0</span><span class="special">];</span>
<span class="keyword">const</span> <span class="identifier">T</span><span class="special">&amp;</span> <span class="identifier">phi_j</span> <span class="special">=</span> <span class="identifier">r</span><span class="special">[</span><span class="number">2</span> <span class="special">*</span> <span class="identifier">j</span> <span class="special">+</span> <span class="number">1</span><span class="special">];</span>
<span class="keyword">auto</span> <span class="identifier">rj</span> <span class="special">=</span> <span class="identifier">sph_to_xyz</span><span class="special">(</span><span class="identifier">theta_j</span><span class="special">,</span> <span class="identifier">phi_j</span><span class="special">);</span>
<span class="identifier">T</span> <span class="identifier">dx</span> <span class="special">=</span> <span class="identifier">ri</span><span class="special">.</span><span class="identifier">x</span> <span class="special">-</span> <span class="identifier">rj</span><span class="special">.</span><span class="identifier">x</span><span class="special">;</span>
<span class="identifier">T</span> <span class="identifier">dy</span> <span class="special">=</span> <span class="identifier">ri</span><span class="special">.</span><span class="identifier">y</span> <span class="special">-</span> <span class="identifier">rj</span><span class="special">.</span><span class="identifier">y</span><span class="special">;</span>
<span class="identifier">T</span> <span class="identifier">dz</span> <span class="special">=</span> <span class="identifier">ri</span><span class="special">.</span><span class="identifier">z</span> <span class="special">-</span> <span class="identifier">rj</span><span class="special">.</span><span class="identifier">z</span><span class="special">;</span>
<span class="identifier">T</span> <span class="identifier">d2</span> <span class="special">=</span> <span class="identifier">dx</span> <span class="special">*</span> <span class="identifier">dx</span> <span class="special">+</span> <span class="identifier">dy</span> <span class="special">*</span> <span class="identifier">dy</span> <span class="special">+</span> <span class="identifier">dz</span> <span class="special">*</span> <span class="identifier">dz</span> <span class="special">+</span> <span class="identifier">tiny</span><span class="special">;</span>
<span class="identifier">E</span> <span class="special">+=</span> <span class="number">1.0</span> <span class="special">/</span> <span class="identifier">sqrt</span><span class="special">(</span><span class="identifier">d2</span><span class="special">);</span>
<span class="special">}</span>
<span class="special">}</span>
<span class="keyword">return</span> <span class="identifier">E</span><span class="special">;</span>
<span class="special">}</span>
<span class="keyword">template</span><span class="special">&lt;</span><span class="keyword">class</span> <span class="identifier">T</span><span class="special">&gt;</span>
<span class="identifier">std</span><span class="special">::</span><span class="identifier">vector</span><span class="special">&lt;</span><span class="identifier">rdiff</span><span class="special">::</span><span class="identifier">rvar</span><span class="special">&lt;</span><span class="identifier">T</span><span class="special">,</span> <span class="number">1</span><span class="special">&gt;&gt;</span> <span class="identifier">init_theta_phi_uniform</span><span class="special">(</span><span class="identifier">size_t</span> <span class="identifier">N</span><span class="special">,</span> <span class="keyword">unsigned</span> <span class="identifier">seed</span> <span class="special">=</span> <span class="number">12345</span><span class="special">)</span>
<span class="special">{</span>
<span class="keyword">const</span> <span class="identifier">T</span> <span class="identifier">pi</span> <span class="special">=</span> <span class="identifier">T</span><span class="special">(</span><span class="number">3.1415926535897932384626433832795</span><span class="special">);</span>
<span class="identifier">std</span><span class="special">::</span><span class="identifier">mt19937</span> <span class="identifier">rng</span><span class="special">(</span><span class="identifier">seed</span><span class="special">);</span>
<span class="identifier">std</span><span class="special">::</span><span class="identifier">uniform_real_distribution</span><span class="special">&lt;</span><span class="identifier">T</span><span class="special">&gt;</span> <span class="identifier">unif01</span><span class="special">(</span><span class="identifier">T</span><span class="special">(</span><span class="number">0</span><span class="special">),</span> <span class="identifier">T</span><span class="special">(</span><span class="number">1</span><span class="special">));</span>
<span class="identifier">std</span><span class="special">::</span><span class="identifier">uniform_real_distribution</span><span class="special">&lt;</span><span class="identifier">T</span><span class="special">&gt;</span> <span class="identifier">unifm11</span><span class="special">(</span><span class="identifier">T</span><span class="special">(-</span><span class="number">1</span><span class="special">),</span> <span class="identifier">T</span><span class="special">(</span><span class="number">1</span><span class="special">));</span>
<span class="identifier">std</span><span class="special">::</span><span class="identifier">vector</span><span class="special">&lt;</span><span class="identifier">rdiff</span><span class="special">::</span><span class="identifier">rvar</span><span class="special">&lt;</span><span class="identifier">T</span><span class="special">,</span> <span class="number">1</span><span class="special">&gt;&gt;</span> <span class="identifier">u</span><span class="special">;</span>
<span class="identifier">u</span><span class="special">.</span><span class="identifier">reserve</span><span class="special">(</span><span class="number">2</span> <span class="special">*</span> <span class="identifier">N</span><span class="special">);</span>
<span class="keyword">for</span> <span class="special">(</span><span class="identifier">size_t</span> <span class="identifier">i</span> <span class="special">=</span> <span class="number">0</span><span class="special">;</span> <span class="identifier">i</span> <span class="special">&lt;</span> <span class="identifier">N</span><span class="special">;</span> <span class="special">++</span><span class="identifier">i</span><span class="special">)</span> <span class="special">{</span>
<span class="identifier">T</span> <span class="identifier">z</span> <span class="special">=</span> <span class="identifier">unifm11</span><span class="special">(</span><span class="identifier">rng</span><span class="special">);</span>
<span class="identifier">T</span> <span class="identifier">phi</span> <span class="special">=</span> <span class="special">(</span><span class="identifier">T</span><span class="special">(</span><span class="number">2</span><span class="special">)</span> <span class="special">*</span> <span class="identifier">pi</span><span class="special">)</span> <span class="special">*</span> <span class="identifier">unif01</span><span class="special">(</span><span class="identifier">rng</span><span class="special">)</span> <span class="special">-</span> <span class="identifier">pi</span><span class="special">;</span>
<span class="identifier">T</span> <span class="identifier">theta</span> <span class="special">=</span> <span class="identifier">std</span><span class="special">::</span><span class="identifier">acos</span><span class="special">(</span><span class="identifier">z</span><span class="special">);</span>
<span class="identifier">u</span><span class="special">.</span><span class="identifier">emplace_back</span><span class="special">(</span><span class="identifier">theta</span><span class="special">);</span>
<span class="identifier">u</span><span class="special">.</span><span class="identifier">emplace_back</span><span class="special">(</span><span class="identifier">phi</span><span class="special">);</span>
<span class="special">}</span>
<span class="keyword">return</span> <span class="identifier">u</span><span class="special">;</span>
<span class="special">}</span>
<span class="keyword">int</span> <span class="identifier">main</span><span class="special">(</span><span class="keyword">int</span> <span class="identifier">argc</span><span class="special">,</span> <span class="keyword">char</span><span class="special">*</span> <span class="identifier">argv</span><span class="special">[])</span>
<span class="special">{</span>
<span class="keyword">if</span> <span class="special">(</span><span class="identifier">argc</span> <span class="special">!=</span> <span class="number">2</span><span class="special">)</span> <span class="special">{</span>
<span class="identifier">std</span><span class="special">::</span><span class="identifier">cerr</span> <span class="special">&lt;&lt;</span> <span class="string">"Usage: "</span> <span class="special">&lt;&lt;</span> <span class="identifier">argv</span><span class="special">[</span><span class="number">0</span><span class="special">]</span> <span class="special">&lt;&lt;</span> <span class="string">" &lt;N&gt;\n"</span><span class="special">;</span>
<span class="keyword">return</span> <span class="number">1</span><span class="special">;</span>
<span class="special">}</span>
<span class="keyword">const</span> <span class="keyword">int</span> <span class="identifier">N</span> <span class="special">=</span> <span class="identifier">std</span><span class="special">::</span><span class="identifier">stoi</span><span class="special">(</span><span class="identifier">argv</span><span class="special">[</span><span class="number">1</span><span class="special">]);</span>
<span class="keyword">const</span> <span class="keyword">double</span> <span class="identifier">lr</span> <span class="special">=</span> <span class="number">1e-3</span><span class="special">;</span>
<span class="keyword">const</span> <span class="keyword">double</span> <span class="identifier">mu</span> <span class="special">=</span> <span class="number">0.95</span><span class="special">;</span>
<span class="keyword">auto</span> <span class="identifier">u_ad</span> <span class="special">=</span> <span class="identifier">init_theta_phi_uniform</span><span class="special">&lt;</span><span class="keyword">double</span><span class="special">&gt;(</span><span class="identifier">N</span><span class="special">);</span>
<span class="keyword">auto</span> <span class="identifier">nesterov_opt</span> <span class="special">=</span> <span class="identifier">bopt</span><span class="special">::</span><span class="identifier">make_nag</span><span class="special">(&amp;</span><span class="identifier">thomson_energy</span><span class="special">&lt;</span><span class="identifier">rdiff</span><span class="special">::</span><span class="identifier">rvar</span><span class="special">&lt;</span><span class="keyword">double</span><span class="special">,</span> <span class="number">1</span><span class="special">&gt;&gt;,</span> <span class="identifier">u_ad</span><span class="special">,</span> <span class="identifier">lr</span><span class="special">,</span> <span class="identifier">mu</span><span class="special">);</span>
<span class="comment">// filenames</span>
<span class="identifier">std</span><span class="special">::</span><span class="identifier">string</span> <span class="identifier">pos_filename</span> <span class="special">=</span> <span class="string">"thomson_"</span> <span class="special">+</span> <span class="identifier">std</span><span class="special">::</span><span class="identifier">to_string</span><span class="special">(</span><span class="identifier">N</span><span class="special">)</span> <span class="special">+</span> <span class="string">".csv"</span><span class="special">;</span>
<span class="identifier">std</span><span class="special">::</span><span class="identifier">string</span> <span class="identifier">energy_filename</span> <span class="special">=</span> <span class="string">"nesterov_energy_"</span> <span class="special">+</span> <span class="identifier">std</span><span class="special">::</span><span class="identifier">to_string</span><span class="special">(</span><span class="identifier">N</span><span class="special">)</span> <span class="special">+</span> <span class="string">".csv"</span><span class="special">;</span>
<span class="identifier">std</span><span class="special">::</span><span class="identifier">ofstream</span> <span class="identifier">pos_out</span><span class="special">(</span><span class="identifier">pos_filename</span><span class="special">);</span>
<span class="identifier">std</span><span class="special">::</span><span class="identifier">ofstream</span> <span class="identifier">energy_out</span><span class="special">(</span><span class="identifier">energy_filename</span><span class="special">);</span>
<span class="identifier">energy_out</span> <span class="special">&lt;&lt;</span> <span class="string">"step,energy\n"</span><span class="special">;</span>
<span class="keyword">auto</span> <span class="identifier">result</span> <span class="special">=</span> <span class="identifier">minimize</span><span class="special">(</span><span class="identifier">nesterov_opt</span><span class="special">);</span>
<span class="keyword">for</span> <span class="special">(</span><span class="keyword">int</span> <span class="identifier">pi</span> <span class="special">=</span> <span class="number">0</span><span class="special">;</span> <span class="identifier">pi</span> <span class="special">&lt;</span> <span class="identifier">N</span><span class="special">;</span> <span class="special">++</span><span class="identifier">pi</span><span class="special">)</span> <span class="special">{</span>
<span class="keyword">double</span> <span class="identifier">theta</span> <span class="special">=</span> <span class="identifier">u_ad</span><span class="special">[</span><span class="number">2</span> <span class="special">*</span> <span class="identifier">pi</span> <span class="special">+</span> <span class="number">0</span><span class="special">].</span><span class="identifier">item</span><span class="special">();</span>
<span class="keyword">double</span> <span class="identifier">phi</span> <span class="special">=</span> <span class="identifier">u_ad</span><span class="special">[</span><span class="number">2</span> <span class="special">*</span> <span class="identifier">pi</span> <span class="special">+</span> <span class="number">1</span><span class="special">].</span><span class="identifier">item</span><span class="special">();</span>
<span class="keyword">auto</span> <span class="identifier">r</span> <span class="special">=</span> <span class="identifier">sph_to_xyz</span><span class="special">(</span><span class="identifier">theta</span><span class="special">,</span> <span class="identifier">phi</span><span class="special">);</span>
<span class="identifier">pos_out</span> <span class="special">&lt;&lt;</span> <span class="identifier">pi</span> <span class="special">&lt;&lt;</span> <span class="string">","</span> <span class="special">&lt;&lt;</span> <span class="identifier">r</span><span class="special">.</span><span class="identifier">x</span> <span class="special">&lt;&lt;</span> <span class="string">","</span> <span class="special">&lt;&lt;</span> <span class="identifier">r</span><span class="special">.</span><span class="identifier">y</span> <span class="special">&lt;&lt;</span> <span class="string">","</span> <span class="special">&lt;&lt;</span> <span class="identifier">r</span><span class="special">.</span><span class="identifier">z</span> <span class="special">&lt;&lt;</span> <span class="string">"\n"</span><span class="special">;</span>
<span class="special">}</span>
<span class="keyword">auto</span> <span class="identifier">E</span> <span class="special">=</span> <span class="identifier">nesterov_opt</span><span class="special">.</span><span class="identifier">objective_value</span><span class="special">();</span>
<span class="keyword">int</span> <span class="identifier">i</span> <span class="special">=</span> <span class="number">0</span><span class="special">;</span>
<span class="keyword">for</span><span class="special">(</span><span class="keyword">auto</span><span class="special">&amp;</span> <span class="identifier">obj_hist</span> <span class="special">:</span> <span class="identifier">result</span><span class="special">.</span><span class="identifier">objective_history</span><span class="special">)</span>
<span class="special">{</span>
<span class="identifier">energy_out</span> <span class="special">&lt;&lt;</span> <span class="identifier">i</span> <span class="special">&lt;&lt;</span> <span class="string">","</span> <span class="special">&lt;&lt;</span> <span class="identifier">obj_hist</span> <span class="special">&lt;&lt;</span> <span class="string">"\n"</span><span class="special">;</span>
<span class="special">++</span><span class="identifier">i</span><span class="special">;</span>
<span class="special">}</span>
<span class="identifier">energy_out</span> <span class="special">&lt;&lt;</span> <span class="string">","</span> <span class="special">&lt;&lt;</span> <span class="identifier">E</span> <span class="special">&lt;&lt;</span> <span class="string">"\n"</span><span class="special">;</span>
<span class="identifier">pos_out</span><span class="special">.</span><span class="identifier">close</span><span class="special">();</span>
<span class="identifier">energy_out</span><span class="special">.</span><span class="identifier">close</span><span class="special">();</span>
<span class="keyword">return</span> <span class="number">0</span><span class="special">;</span>
<span class="special">}</span>
</pre>
<p>
The Nesterov version of this problem converges much faster than regular gradient
descent, in only <code class="computeroutput"><span class="number">4663</span></code> iterations
with default parameters, versus the <code class="computeroutput"><span class="number">93799</span></code>
iterations required by gradient descent.
</p>
<div class="blockquote"><blockquote class="blockquote"><div class="blockquote"><blockquote class="blockquote"><p>
<span class="inlinemediaobject"><img src="../../../graphs/gradient_based_optimizers/nag_to_gd_comparison.svg" alt="Nesterov accelerated gradient versus gradient descent comparison"></span>
</p></blockquote></div></blockquote></div>
</div>
<div class="copyright-footer">Copyright © 2006-2021 Nikhar Agrawal, Anton Bikineev, Matthew Borland,
Paul A. Bristow, Marco Guazzone, Christopher Kormanyos, Hubert Holin, Bruno
Lalande, John Maddock, Evan Miller, Jeremy Murphy, Matthew Pulver, Johan Råde,
Gautam Sewani, Benjamin Sobotta, Nicholas Thompson, Thijs van den Berg, Daryle
Walker, Xiaogang Zhang, and Maksym Zhelyeznyakov<p>
Distributed under the Boost Software License, Version 1.0. (See accompanying
file LICENSE_1_0.txt or copy at <a href="http://www.boost.org/LICENSE_1_0.txt" target="_top">http://www.boost.org/LICENSE_1_0.txt</a>)
</p>
</div>
<hr>
<div class="spirit-nav">
<a accesskey="p" href="gradient_descent.html"><img src="../../../../../../doc/src/images/prev.png" alt="Prev"></a><a accesskey="u" href="../gd_opt.html"><img src="../../../../../../doc/src/images/up.png" alt="Up"></a><a accesskey="h" href="../../index.html"><img src="../../../../../../doc/src/images/home.png" alt="Home"></a><a accesskey="n" href="lbfgs.html"><img src="../../../../../../doc/src/images/next.png" alt="Next"></a>
</div>
</body>
</html>