mirror of
https://github.com/boostorg/math.git
synced 2026-02-24 16:12:15 +00:00
414 lines
54 KiB
HTML
414 lines
54 KiB
HTML
<html>
|
|
<head>
|
|
<meta charset="UTF-8">
|
|
<title>Gradient Descent</title>
|
|
<link rel="stylesheet" href="../../math.css" type="text/css">
|
|
<meta name="generator" content="DocBook XSL Stylesheets Vsnapshot">
|
|
<link rel="home" href="../../index.html" title="Math Toolkit 4.2.1">
|
|
<link rel="up" href="../gd_opt.html" title="Gradient Based Optimizers">
|
|
<link rel="prev" href="introduction.html" title="Introduction">
|
|
<link rel="next" href="nesterov.html" title="Nesterov Accelerated Gradient Descent">
|
|
<meta name="viewport" content="width=device-width, initial-scale=1">
|
|
</head>
|
|
<body bgcolor="white" text="black" link="#0000FF" vlink="#840084" alink="#0000FF">
|
|
<table cellpadding="2" width="100%"><tr>
|
|
<td valign="top"><img alt="Boost C++ Libraries" width="277" height="86" src="../../../../../../boost.png"></td>
|
|
<td align="center"><a href="../../../../../../index.html">Home</a></td>
|
|
<td align="center"><a href="../../../../../../libs/libraries.htm">Libraries</a></td>
|
|
<td align="center"><a href="http://www.boost.org/users/people.html">People</a></td>
|
|
<td align="center"><a href="http://www.boost.org/users/faq.html">FAQ</a></td>
|
|
<td align="center"><a href="../../../../../../more/index.htm">More</a></td>
|
|
</tr></table>
|
|
<hr>
|
|
<div class="spirit-nav">
|
|
<a accesskey="p" href="introduction.html"><img src="../../../../../../doc/src/images/prev.png" alt="Prev"></a><a accesskey="u" href="../gd_opt.html"><img src="../../../../../../doc/src/images/up.png" alt="Up"></a><a accesskey="h" href="../../index.html"><img src="../../../../../../doc/src/images/home.png" alt="Home"></a><a accesskey="n" href="nesterov.html"><img src="../../../../../../doc/src/images/next.png" alt="Next"></a>
|
|
</div>
|
|
<div class="section">
|
|
<div class="titlepage"><div><div><h3 class="title">
|
|
<a name="math_toolkit.gd_opt.gradient_descent"></a><a class="link" href="gradient_descent.html" title="Gradient Descent">Gradient Descent</a>
|
|
</h3></div></div></div>
|
|
<h5>
|
|
<a name="math_toolkit.gd_opt.gradient_descent.h0"></a>
|
|
<span class="phrase"><a name="math_toolkit.gd_opt.gradient_descent.synopsis"></a></span><a class="link" href="gradient_descent.html#math_toolkit.gd_opt.gradient_descent.synopsis">Synopsis</a>
|
|
</h5>
|
|
<pre class="programlisting"><span class="preprocessor">#include</span> <span class="special"><</span><span class="identifier">boost</span><span class="special">/</span><span class="identifier">math</span><span class="special">/</span><span class="identifier">optimization</span><span class="special">/</span><span class="identifier">gradient_descent</span><span class="special">.</span><span class="identifier">hpp</span><span class="special">></span>
|
|
|
|
<span class="keyword">template</span><span class="special"><</span><span class="keyword">typename</span> <span class="identifier">ArgumentContainer</span><span class="special">,</span>
|
|
<span class="keyword">typename</span> <span class="identifier">RealType</span><span class="special">,</span>
|
|
<span class="keyword">class</span> <span class="identifier">Objective</span><span class="special">,</span>
|
|
<span class="keyword">class</span> <span class="identifier">InitializationPolicy</span><span class="special">,</span>
|
|
<span class="keyword">class</span> <span class="identifier">ObjectiveEvalPolicy</span><span class="special">,</span>
|
|
<span class="keyword">class</span> <span class="identifier">GradEvalPolicy</span><span class="special">></span>
|
|
<span class="keyword">class</span> <span class="identifier">gradient_descent</span> <span class="special">{</span>
|
|
<span class="keyword">public</span><span class="special">:</span>
|
|
<span class="keyword">void</span> <span class="identifier">step</span><span class="special">();</span>
|
|
<span class="special">};</span>
|
|
|
|
<span class="comment">/* Convenience overloads */</span>
|
|
<span class="comment">/* make gradient descent by providing
|
|
** objective function
|
|
** variables to optimize over
|
|
** optionally learning rate
|
|
*
|
|
* requires that code is written using boost::math::differentiation::rvar
|
|
*/</span>
|
|
<span class="keyword">template</span><span class="special"><</span><span class="keyword">class</span> <span class="identifier">Objective</span><span class="special">,</span> <span class="keyword">typename</span> <span class="identifier">ArgumentContainer</span><span class="special">,</span> <span class="keyword">typename</span> <span class="identifier">RealType</span><span class="special">></span>
|
|
<span class="keyword">auto</span> <span class="identifier">make_gradient_descent</span><span class="special">(</span><span class="identifier">Objective</span><span class="special">&&</span> <span class="identifier">obj</span><span class="special">,</span> <span class="identifier">ArgumentContainer</span><span class="special">&</span> <span class="identifier">x</span><span class="special">,</span> <span class="identifier">RealType</span> <span class="identifier">lr</span> <span class="special">=</span> <span class="identifier">RealType</span><span class="special">{</span> <span class="number">0.01</span> <span class="special">});</span>
|
|
|
|
<span class="comment">/* make gradient descent by providing
|
|
** objective function
|
|
** variables to optimize over
|
|
** learning rate (not optional)
|
|
** initialization policy
|
|
*
|
|
* requires that code is written using boost::math::differentiation::rvar
|
|
*/</span>
|
|
|
|
<span class="keyword">template</span><span class="special"><</span><span class="keyword">class</span> <span class="identifier">Objective</span><span class="special">,</span> <span class="keyword">typename</span> <span class="identifier">ArgumentContainer</span><span class="special">,</span> <span class="keyword">typename</span> <span class="identifier">RealType</span><span class="special">,</span> <span class="keyword">class</span> <span class="identifier">InitializationPolicy</span><span class="special">></span>
|
|
<span class="keyword">auto</span> <span class="identifier">make_gradient_descent</span><span class="special">(</span><span class="identifier">Objective</span><span class="special">&&</span> <span class="identifier">obj</span><span class="special">,</span>
|
|
<span class="identifier">ArgumentContainer</span><span class="special">&</span> <span class="identifier">x</span><span class="special">,</span>
|
|
<span class="identifier">RealType</span> <span class="identifier">lr</span><span class="special">,</span>
|
|
<span class="identifier">InitializationPolicy</span><span class="special">&&</span> <span class="identifier">ip</span><span class="special">);</span>
|
|
<span class="comment">/* make gradient descent by providing
|
|
** objective function
|
|
** variables to optimize over
|
|
** learning rate (not optional)
|
|
** variable initialization policy
|
|
** objective evaluation policy
|
|
** gradient evaluation policy
|
|
*
|
|
* code does not have to use boost::math::differentiation::rvar
|
|
*/</span>
|
|
<span class="keyword">template</span><span class="special"><</span><span class="keyword">typename</span> <span class="identifier">ArgumentContainer</span><span class="special">,</span>
|
|
<span class="keyword">typename</span> <span class="identifier">RealType</span><span class="special">,</span>
|
|
<span class="keyword">class</span> <span class="identifier">Objective</span><span class="special">,</span>
|
|
<span class="keyword">class</span> <span class="identifier">InitializationPolicy</span><span class="special">,</span>
|
|
<span class="keyword">class</span> <span class="identifier">ObjectiveEvalPolicy</span><span class="special">,</span>
|
|
<span class="keyword">class</span> <span class="identifier">GradEvalPolicy</span><span class="special">></span>
|
|
<span class="keyword">auto</span> <span class="identifier">make_gradient_descent</span><span class="special">(</span><span class="identifier">Objective</span><span class="special">&&</span> <span class="identifier">obj</span><span class="special">,</span>
|
|
<span class="identifier">ArgumentContainer</span><span class="special">&</span> <span class="identifier">x</span><span class="special">,</span>
|
|
<span class="identifier">RealType</span><span class="special">&</span> <span class="identifier">lr</span><span class="special">,</span>
|
|
<span class="identifier">InitializationPolicy</span><span class="special">&&</span> <span class="identifier">ip</span><span class="special">,</span>
|
|
<span class="identifier">ObjectiveEvalPolicy</span><span class="special">&&</span> <span class="identifier">oep</span><span class="special">,</span>
|
|
<span class="identifier">GradEvalPolicy</span><span class="special">&&</span> <span class="identifier">gep</span><span class="special">);</span>
|
|
</pre>
|
|
<p>
|
|
Gradient descent iteratively updates parameters <code class="computeroutput"><span class="identifier">x</span></code>
|
|
in the direction opposite to the gradient of the objective function (minimizing
|
|
the objective).
|
|
</p>
|
|
<pre class="programlisting"><span class="identifier">x</span><span class="special">[</span><span class="identifier">i</span><span class="special">]</span> <span class="special">-=</span> <span class="identifier">lr</span> <span class="special">*</span> <span class="identifier">g</span><span class="special">[</span><span class="identifier">i</span><span class="special">]</span>
|
|
</pre>
|
|
<p>
|
|
where <code class="computeroutput"><span class="identifier">lr</span></code> is a user defined
|
|
learning rate. For a more complete description of the theoretical principle
|
|
check <a href="https://en.wikipedia.org/wiki/Gradient_descent" target="_top">the Wikipedia
|
|
page</a>
|
|
</p>
|
|
<p>
|
|
The implementation delegates: - the initialization of differentiable variables
|
|
to an initialization policy - objective evaluation to an objective evaluation
|
|
policy - the gradient computation to a gradient evaluation policy - the parameter
|
|
updates to an update policy
|
|
</p>
|
|
<p>
|
|
The interface is intended to be PyTorch-like, where an optimizer object is
|
|
constructed and progressed with a <code class="computeroutput"><span class="identifier">step</span><span class="special">()</span></code> method. A helper <code class="computeroutput"><span class="identifier">minimize</span></code>
|
|
method is also provided.
|
|
</p>
|
|
<h5>
|
|
<a name="math_toolkit.gd_opt.gradient_descent.h1"></a>
|
|
<span class="phrase"><a name="math_toolkit.gd_opt.gradient_descent.parameters"></a></span><a class="link" href="gradient_descent.html#math_toolkit.gd_opt.gradient_descent.parameters">Parameters</a>
|
|
</h5>
|
|
<div class="itemizedlist"><ul class="itemizedlist" style="list-style-type: disc; ">
|
|
<li class="listitem">
|
|
<code class="computeroutput"><span class="identifier">Objective</span><span class="special">&&</span>
|
|
<span class="identifier">obj</span></code> : objective function to
|
|
minimize
|
|
</li>
|
|
<li class="listitem">
|
|
<code class="computeroutput"><span class="identifier">ArgumentContainer</span><span class="special">&</span>
|
|
<span class="identifier">x</span></code> : variables to optimize over
|
|
</li>
|
|
<li class="listitem">
|
|
<code class="computeroutput"><span class="identifier">RealType</span><span class="special">&</span>
|
|
<span class="identifier">lr</span></code> : learning rate. A larger
|
|
value takes larger steps during descent, leading to faster, but more
|
|
unstable convergence. Conversely, small values are more stable but take
|
|
longer to converge.
|
|
</li>
|
|
<li class="listitem">
|
|
<code class="computeroutput"><span class="identifier">InitializationPolicy</span><span class="special">&&</span> <span class="identifier">ip</span></code>
|
|
: Initialization policy for <code class="computeroutput"><span class="identifier">ArgumentContainer</span></code>,
|
|
or the initial guess. By default it is set to <code class="computeroutput"><span class="identifier">tape_initializer_rvar</span><span class="special"><</span><span class="identifier">RealType</span><span class="special">></span></code> which lets the user provide the "initial
|
|
guess" by setting the values of <code class="computeroutput"><span class="identifier">x</span></code>
|
|
manually. For more info check the Policies section.
|
|
</li>
|
|
<li class="listitem">
|
|
<code class="computeroutput"><span class="identifier">ObjectiveEvalPolicy</span><span class="special">&&</span>
|
|
<span class="identifier">oep</span></code> : tells the optimizer how
|
|
to evaluate the objective function. By default <code class="computeroutput"><span class="identifier">reverse_mode_function_eval_policy</span><span class="special"><</span><span class="identifier">RealType</span><span class="special">></span></code>.
|
|
</li>
|
|
<li class="listitem">
|
|
<code class="computeroutput"><span class="identifier">GradEvalPolicy</span><span class="special">&&</span>
|
|
<span class="identifier">gep</span></code> : tells the optimizer how
|
|
to evaluate the gradient of the objective function. By default <code class="computeroutput"><span class="identifier">reverse_mode_gradient_evaluation_policy</span><span class="special"><</span><span class="identifier">RealType</span><span class="special">></span></code>
|
|
</li>
|
|
</ul></div>
|
|
<h5>
|
|
<a name="math_toolkit.gd_opt.gradient_descent.h2"></a>
|
|
<span class="phrase"><a name="math_toolkit.gd_opt.gradient_descent.example_using_a_manual_optimizat"></a></span><a class="link" href="gradient_descent.html#math_toolkit.gd_opt.gradient_descent.example_using_a_manual_optimizat">Example
|
|
using a manual optimization</a>
|
|
</h5>
|
|
<p>
|
|
In this section we will present an example for finding optimal configurations
|
|
of electrically charged particles confined to a <code class="computeroutput"><span class="identifier">R</span>
|
|
<span class="special">=</span> <span class="number">1</span></code>
|
|
sphere. This problem is also known as the <a href="https://en.wikipedia.org/wiki/Thomson_problem" target="_top">Thomson
|
|
problem</a>. In summary, we are looking for the configuration of an N-electron
|
|
system subject to the Coulomb potential confined to the S<sup>2</sup> sphere. The
|
|
Coulomb potential is given by:
|
|
</p>
|
|
<div class="blockquote"><blockquote class="blockquote"><div class="blockquote"><blockquote class="blockquote"><p>
|
|
<span class="inlinemediaobject"><img src="../../../equations/autodiff/thomson_potential.svg" alt="Coulomb potential energy of the Thomson problem"></span>
|
|
</p></blockquote></div></blockquote></div>
|
|
<p>
|
|
The code below manually minimizes the above potential energy function for
|
|
N particles over their two angular positions.
|
|
</p>
|
|
<pre class="programlisting"><span class="preprocessor">#include</span> <span class="special"><</span><span class="identifier">boost</span><span class="special">/</span><span class="identifier">math</span><span class="special">/</span><span class="identifier">differentiation</span><span class="special">/</span><span class="identifier">autodiff_reverse</span><span class="special">.</span><span class="identifier">hpp</span><span class="special">></span>
|
|
<span class="preprocessor">#include</span> <span class="special"><</span><span class="identifier">boost</span><span class="special">/</span><span class="identifier">math</span><span class="special">/</span><span class="identifier">optimization</span><span class="special">/</span><span class="identifier">gradient_descent</span><span class="special">.</span><span class="identifier">hpp</span><span class="special">></span>
|
|
<span class="preprocessor">#include</span> <span class="special"><</span><span class="identifier">boost</span><span class="special">/</span><span class="identifier">math</span><span class="special">/</span><span class="identifier">optimization</span><span class="special">/</span><span class="identifier">minimizer</span><span class="special">.</span><span class="identifier">hpp</span><span class="special">></span>
|
|
<span class="preprocessor">#include</span> <span class="special"><</span><span class="identifier">cmath</span><span class="special">></span>
|
|
<span class="preprocessor">#include</span> <span class="special"><</span><span class="identifier">fstream</span><span class="special">></span>
|
|
<span class="preprocessor">#include</span> <span class="special"><</span><span class="identifier">iostream</span><span class="special">></span>
|
|
<span class="preprocessor">#include</span> <span class="special"><</span><span class="identifier">random</span><span class="special">></span>
|
|
<span class="preprocessor">#include</span> <span class="special"><</span><span class="identifier">string</span><span class="special">></span>
|
|
<span class="keyword">namespace</span> <span class="identifier">rdiff</span> <span class="special">=</span> <span class="identifier">boost</span><span class="special">::</span><span class="identifier">math</span><span class="special">::</span><span class="identifier">differentiation</span><span class="special">::</span><span class="identifier">reverse_mode</span><span class="special">;</span>
|
|
<span class="keyword">namespace</span> <span class="identifier">bopt</span> <span class="special">=</span> <span class="identifier">boost</span><span class="special">::</span><span class="identifier">math</span><span class="special">::</span><span class="identifier">optimization</span><span class="special">;</span>
|
|
<span class="keyword">double</span> <span class="identifier">random_double</span><span class="special">(</span><span class="keyword">double</span> <span class="identifier">min</span> <span class="special">=</span> <span class="number">0.0</span><span class="special">,</span> <span class="keyword">double</span> <span class="identifier">max</span> <span class="special">=</span> <span class="number">1.0</span><span class="special">)</span>
|
|
<span class="special">{</span>
|
|
<span class="keyword">static</span> <span class="keyword">thread_local</span> <span class="identifier">std</span><span class="special">::</span><span class="identifier">mt19937</span> <span class="identifier">rng</span><span class="special">{</span><span class="identifier">std</span><span class="special">::</span><span class="identifier">random_device</span><span class="special">{}()};</span>
|
|
<span class="identifier">std</span><span class="special">::</span><span class="identifier">uniform_real_distribution</span><span class="special"><</span><span class="keyword">double</span><span class="special">></span> <span class="identifier">dist</span><span class="special">(</span><span class="identifier">min</span><span class="special">,</span> <span class="identifier">max</span><span class="special">);</span>
|
|
<span class="keyword">return</span> <span class="identifier">dist</span><span class="special">(</span><span class="identifier">rng</span><span class="special">);</span>
|
|
<span class="special">}</span>
|
|
|
|
<span class="keyword">template</span><span class="special"><</span><span class="keyword">typename</span> <span class="identifier">S</span><span class="special">></span>
|
|
<span class="keyword">struct</span> <span class="identifier">vec3</span>
|
|
<span class="special">{</span>
|
|
<span class="comment">/**
|
|
* @brief R^3 coordinates of particle on Thomson Sphere
|
|
*/</span>
|
|
<span class="identifier">S</span> <span class="identifier">x</span><span class="special">,</span> <span class="identifier">y</span><span class="special">,</span> <span class="identifier">z</span><span class="special">;</span>
|
|
<span class="special">};</span>
|
|
|
|
<span class="keyword">template</span><span class="special"><</span><span class="keyword">class</span> <span class="identifier">S</span><span class="special">></span>
|
|
<span class="keyword">static</span> <span class="keyword">inline</span> <span class="identifier">vec3</span><span class="special"><</span><span class="identifier">S</span><span class="special">></span> <span class="identifier">sph_to_xyz</span><span class="special">(</span><span class="keyword">const</span> <span class="identifier">S</span><span class="special">&</span> <span class="identifier">theta</span><span class="special">,</span> <span class="keyword">const</span> <span class="identifier">S</span><span class="special">&</span> <span class="identifier">phi</span><span class="special">)</span>
|
|
<span class="special">{</span>
|
|
<span class="comment">/**
|
|
* convenience overload to convert from [theta,phi] -> x, y, z
|
|
*/</span>
|
|
<span class="keyword">return</span> <span class="special">{</span><span class="identifier">sin</span><span class="special">(</span><span class="identifier">theta</span><span class="special">)</span> <span class="special">*</span> <span class="identifier">cos</span><span class="special">(</span><span class="identifier">phi</span><span class="special">),</span> <span class="identifier">sin</span><span class="special">(</span><span class="identifier">theta</span><span class="special">)</span> <span class="special">*</span> <span class="identifier">sin</span><span class="special">(</span><span class="identifier">phi</span><span class="special">),</span> <span class="identifier">cos</span><span class="special">(</span><span class="identifier">theta</span><span class="special">)};</span>
|
|
<span class="special">}</span>
|
|
|
|
<span class="keyword">template</span><span class="special"><</span><span class="keyword">typename</span> <span class="identifier">T</span><span class="special">></span>
|
|
<span class="identifier">T</span> <span class="identifier">thomson_energy</span><span class="special">(</span><span class="identifier">std</span><span class="special">::</span><span class="identifier">vector</span><span class="special"><</span><span class="identifier">T</span><span class="special">>&</span> <span class="identifier">r</span><span class="special">)</span>
|
|
<span class="special">{</span>
|
|
<span class="keyword">const</span> <span class="identifier">size_t</span> <span class="identifier">N</span> <span class="special">=</span> <span class="identifier">r</span><span class="special">.</span><span class="identifier">size</span><span class="special">()</span> <span class="special">/</span> <span class="number">2</span><span class="special">;</span>
|
|
<span class="keyword">const</span> <span class="identifier">T</span> <span class="identifier">tiny</span> <span class="special">=</span> <span class="identifier">T</span><span class="special">(</span><span class="number">1e-12</span><span class="special">);</span>
|
|
|
|
<span class="identifier">T</span> <span class="identifier">E</span> <span class="special">=</span> <span class="number">0</span><span class="special">;</span>
|
|
<span class="keyword">for</span> <span class="special">(</span><span class="identifier">size_t</span> <span class="identifier">i</span> <span class="special">=</span> <span class="number">0</span><span class="special">;</span> <span class="identifier">i</span> <span class="special"><</span> <span class="identifier">N</span><span class="special">;</span> <span class="special">++</span><span class="identifier">i</span><span class="special">)</span> <span class="special">{</span>
|
|
<span class="keyword">const</span> <span class="identifier">T</span><span class="special">&</span> <span class="identifier">theta_i</span> <span class="special">=</span> <span class="identifier">r</span><span class="special">[</span><span class="number">2</span> <span class="special">*</span> <span class="identifier">i</span> <span class="special">+</span> <span class="number">0</span><span class="special">];</span>
|
|
<span class="keyword">const</span> <span class="identifier">T</span><span class="special">&</span> <span class="identifier">phi_i</span> <span class="special">=</span> <span class="identifier">r</span><span class="special">[</span><span class="number">2</span> <span class="special">*</span> <span class="identifier">i</span> <span class="special">+</span> <span class="number">1</span><span class="special">];</span>
|
|
<span class="keyword">auto</span> <span class="identifier">ri</span> <span class="special">=</span> <span class="identifier">sph_to_xyz</span><span class="special">(</span><span class="identifier">theta_i</span><span class="special">,</span> <span class="identifier">phi_i</span><span class="special">);</span>
|
|
|
|
<span class="keyword">for</span> <span class="special">(</span><span class="identifier">size_t</span> <span class="identifier">j</span> <span class="special">=</span> <span class="identifier">i</span> <span class="special">+</span> <span class="number">1</span><span class="special">;</span> <span class="identifier">j</span> <span class="special"><</span> <span class="identifier">N</span><span class="special">;</span> <span class="special">++</span><span class="identifier">j</span><span class="special">)</span> <span class="special">{</span>
|
|
<span class="keyword">const</span> <span class="identifier">T</span><span class="special">&</span> <span class="identifier">theta_j</span> <span class="special">=</span> <span class="identifier">r</span><span class="special">[</span><span class="number">2</span> <span class="special">*</span> <span class="identifier">j</span> <span class="special">+</span> <span class="number">0</span><span class="special">];</span>
|
|
<span class="keyword">const</span> <span class="identifier">T</span><span class="special">&</span> <span class="identifier">phi_j</span> <span class="special">=</span> <span class="identifier">r</span><span class="special">[</span><span class="number">2</span> <span class="special">*</span> <span class="identifier">j</span> <span class="special">+</span> <span class="number">1</span><span class="special">];</span>
|
|
<span class="keyword">auto</span> <span class="identifier">rj</span> <span class="special">=</span> <span class="identifier">sph_to_xyz</span><span class="special">(</span><span class="identifier">theta_j</span><span class="special">,</span> <span class="identifier">phi_j</span><span class="special">);</span>
|
|
|
|
<span class="identifier">T</span> <span class="identifier">dx</span> <span class="special">=</span> <span class="identifier">ri</span><span class="special">.</span><span class="identifier">x</span> <span class="special">-</span> <span class="identifier">rj</span><span class="special">.</span><span class="identifier">x</span><span class="special">;</span>
|
|
<span class="identifier">T</span> <span class="identifier">dy</span> <span class="special">=</span> <span class="identifier">ri</span><span class="special">.</span><span class="identifier">y</span> <span class="special">-</span> <span class="identifier">rj</span><span class="special">.</span><span class="identifier">y</span><span class="special">;</span>
|
|
<span class="identifier">T</span> <span class="identifier">dz</span> <span class="special">=</span> <span class="identifier">ri</span><span class="special">.</span><span class="identifier">z</span> <span class="special">-</span> <span class="identifier">rj</span><span class="special">.</span><span class="identifier">z</span><span class="special">;</span>
|
|
|
|
<span class="identifier">T</span> <span class="identifier">d2</span> <span class="special">=</span> <span class="identifier">dx</span> <span class="special">*</span> <span class="identifier">dx</span> <span class="special">+</span> <span class="identifier">dy</span> <span class="special">*</span> <span class="identifier">dy</span> <span class="special">+</span> <span class="identifier">dz</span> <span class="special">*</span> <span class="identifier">dz</span> <span class="special">+</span> <span class="identifier">tiny</span><span class="special">;</span>
|
|
<span class="identifier">E</span> <span class="special">+=</span> <span class="number">1.0</span> <span class="special">/</span> <span class="identifier">sqrt</span><span class="special">(</span><span class="identifier">d2</span><span class="special">);</span>
|
|
<span class="special">}</span>
|
|
<span class="special">}</span>
|
|
<span class="keyword">return</span> <span class="identifier">E</span><span class="special">;</span>
|
|
<span class="special">}</span>
|
|
|
|
<span class="keyword">template</span><span class="special"><</span><span class="keyword">class</span> <span class="identifier">T</span><span class="special">></span>
|
|
<span class="identifier">std</span><span class="special">::</span><span class="identifier">vector</span><span class="special"><</span><span class="identifier">rdiff</span><span class="special">::</span><span class="identifier">rvar</span><span class="special"><</span><span class="identifier">T</span><span class="special">,</span> <span class="number">1</span><span class="special">>></span> <span class="identifier">init_theta_phi_uniform</span><span class="special">(</span><span class="identifier">size_t</span> <span class="identifier">N</span><span class="special">,</span> <span class="keyword">unsigned</span> <span class="identifier">seed</span> <span class="special">=</span> <span class="number">12345</span><span class="special">)</span>
|
|
<span class="special">{</span>
|
|
<span class="keyword">const</span> <span class="identifier">T</span> <span class="identifier">pi</span> <span class="special">=</span> <span class="identifier">T</span><span class="special">(</span><span class="number">3.1415926535897932384626433832795</span><span class="special">);</span>
|
|
|
|
<span class="identifier">std</span><span class="special">::</span><span class="identifier">mt19937</span> <span class="identifier">rng</span><span class="special">(</span><span class="identifier">seed</span><span class="special">);</span>
|
|
<span class="identifier">std</span><span class="special">::</span><span class="identifier">uniform_real_distribution</span><span class="special"><</span><span class="identifier">T</span><span class="special">></span> <span class="identifier">unif01</span><span class="special">(</span><span class="identifier">T</span><span class="special">(</span><span class="number">0</span><span class="special">),</span> <span class="identifier">T</span><span class="special">(</span><span class="number">1</span><span class="special">));</span>
|
|
<span class="identifier">std</span><span class="special">::</span><span class="identifier">uniform_real_distribution</span><span class="special"><</span><span class="identifier">T</span><span class="special">></span> <span class="identifier">unifm11</span><span class="special">(</span><span class="identifier">T</span><span class="special">(-</span><span class="number">1</span><span class="special">),</span> <span class="identifier">T</span><span class="special">(</span><span class="number">1</span><span class="special">));</span>
|
|
|
|
<span class="identifier">std</span><span class="special">::</span><span class="identifier">vector</span><span class="special"><</span><span class="identifier">rdiff</span><span class="special">::</span><span class="identifier">rvar</span><span class="special"><</span><span class="identifier">T</span><span class="special">,</span> <span class="number">1</span><span class="special">>></span> <span class="identifier">u</span><span class="special">;</span>
|
|
<span class="identifier">u</span><span class="special">.</span><span class="identifier">reserve</span><span class="special">(</span><span class="number">2</span> <span class="special">*</span> <span class="identifier">N</span><span class="special">);</span>
|
|
|
|
<span class="keyword">for</span> <span class="special">(</span><span class="identifier">size_t</span> <span class="identifier">i</span> <span class="special">=</span> <span class="number">0</span><span class="special">;</span> <span class="identifier">i</span> <span class="special"><</span> <span class="identifier">N</span><span class="special">;</span> <span class="special">++</span><span class="identifier">i</span><span class="special">)</span> <span class="special">{</span>
|
|
<span class="identifier">T</span> <span class="identifier">z</span> <span class="special">=</span> <span class="identifier">unifm11</span><span class="special">(</span><span class="identifier">rng</span><span class="special">);</span>
|
|
<span class="identifier">T</span> <span class="identifier">phi</span> <span class="special">=</span> <span class="special">(</span><span class="identifier">T</span><span class="special">(</span><span class="number">2</span><span class="special">)</span> <span class="special">*</span> <span class="identifier">pi</span><span class="special">)</span> <span class="special">*</span> <span class="identifier">unif01</span><span class="special">(</span><span class="identifier">rng</span><span class="special">)</span> <span class="special">-</span> <span class="identifier">pi</span><span class="special">;</span>
|
|
<span class="identifier">T</span> <span class="identifier">theta</span> <span class="special">=</span> <span class="identifier">std</span><span class="special">::</span><span class="identifier">acos</span><span class="special">(</span><span class="identifier">z</span><span class="special">);</span>
|
|
|
|
<span class="identifier">u</span><span class="special">.</span><span class="identifier">emplace_back</span><span class="special">(</span><span class="identifier">theta</span><span class="special">);</span>
|
|
<span class="identifier">u</span><span class="special">.</span><span class="identifier">emplace_back</span><span class="special">(</span><span class="identifier">phi</span><span class="special">);</span>
|
|
<span class="special">}</span>
|
|
<span class="keyword">return</span> <span class="identifier">u</span><span class="special">;</span>
|
|
<span class="special">}</span>
|
|
|
|
<span class="keyword">int</span> <span class="identifier">main</span><span class="special">(</span><span class="keyword">int</span> <span class="identifier">argc</span><span class="special">,</span> <span class="keyword">char</span><span class="special">*</span> <span class="identifier">argv</span><span class="special">[])</span>
|
|
<span class="special">{</span>
|
|
<span class="keyword">if</span> <span class="special">(</span><span class="identifier">argc</span> <span class="special">!=</span> <span class="number">2</span><span class="special">)</span> <span class="special">{</span>
|
|
<span class="identifier">std</span><span class="special">::</span><span class="identifier">cerr</span> <span class="special"><<</span> <span class="string">"Usage: "</span> <span class="special"><<</span> <span class="identifier">argv</span><span class="special">[</span><span class="number">0</span><span class="special">]</span> <span class="special"><<</span> <span class="string">" <N>\n"</span><span class="special">;</span>
|
|
<span class="keyword">return</span> <span class="number">1</span><span class="special">;</span>
|
|
<span class="special">}</span>
|
|
|
|
<span class="keyword">const</span> <span class="keyword">int</span> <span class="identifier">N</span> <span class="special">=</span> <span class="identifier">std</span><span class="special">::</span><span class="identifier">stoi</span><span class="special">(</span><span class="identifier">argv</span><span class="special">[</span><span class="number">1</span><span class="special">]);</span>
|
|
<span class="keyword">const</span> <span class="keyword">int</span> <span class="identifier">NSTEPS</span> <span class="special">=</span> <span class="number">100000</span><span class="special">;</span>
|
|
<span class="keyword">const</span> <span class="keyword">double</span> <span class="identifier">lr</span> <span class="special">=</span> <span class="number">1e-3</span><span class="special">;</span>
|
|
|
|
<span class="keyword">auto</span> <span class="identifier">u_ad</span> <span class="special">=</span> <span class="identifier">init_theta_phi_uniform</span><span class="special"><</span><span class="keyword">double</span><span class="special">>(</span><span class="identifier">N</span><span class="special">);</span>
|
|
|
|
<span class="keyword">auto</span> <span class="identifier">gdopt</span> <span class="special">=</span> <span class="identifier">bopt</span><span class="special">::</span><span class="identifier">make_gradient_descent</span><span class="special">(&</span><span class="identifier">thomson_energy</span><span class="special"><</span><span class="identifier">rdiff</span><span class="special">::</span><span class="identifier">rvar</span><span class="special"><</span><span class="keyword">double</span><span class="special">,</span> <span class="number">1</span><span class="special">>>,</span> <span class="identifier">u_ad</span><span class="special">,</span> <span class="identifier">lr</span><span class="special">);</span>
|
|
|
|
<span class="comment">// filenames</span>
|
|
<span class="identifier">std</span><span class="special">::</span><span class="identifier">string</span> <span class="identifier">pos_filename</span> <span class="special">=</span> <span class="string">"thomson_"</span> <span class="special">+</span> <span class="identifier">std</span><span class="special">::</span><span class="identifier">to_string</span><span class="special">(</span><span class="identifier">N</span><span class="special">)</span> <span class="special">+</span> <span class="string">".csv"</span><span class="special">;</span>
|
|
<span class="identifier">std</span><span class="special">::</span><span class="identifier">string</span> <span class="identifier">energy_filename</span> <span class="special">=</span> <span class="string">"energy_"</span> <span class="special">+</span> <span class="identifier">std</span><span class="special">::</span><span class="identifier">to_string</span><span class="special">(</span><span class="identifier">N</span><span class="special">)</span> <span class="special">+</span> <span class="string">".csv"</span><span class="special">;</span>
|
|
|
|
<span class="identifier">std</span><span class="special">::</span><span class="identifier">ofstream</span> <span class="identifier">pos_out</span><span class="special">(</span><span class="identifier">pos_filename</span><span class="special">);</span>
|
|
<span class="identifier">std</span><span class="special">::</span><span class="identifier">ofstream</span> <span class="identifier">energy_out</span><span class="special">(</span><span class="identifier">energy_filename</span><span class="special">);</span>
|
|
|
|
<span class="identifier">pos_out</span> <span class="special"><<</span> <span class="string">"step,particle,x,y,z\n"</span><span class="special">;</span>
|
|
<span class="identifier">energy_out</span> <span class="special"><<</span> <span class="string">"step,energy\n"</span><span class="special">;</span>
|
|
|
|
<span class="keyword">for</span> <span class="special">(</span><span class="keyword">int</span> <span class="identifier">step</span> <span class="special">=</span> <span class="number">0</span><span class="special">;</span> <span class="identifier">step</span> <span class="special"><</span> <span class="identifier">NSTEPS</span><span class="special">;</span> <span class="special">++</span><span class="identifier">step</span><span class="special">)</span> <span class="special">{</span>
|
|
<span class="identifier">gdopt</span><span class="special">.</span><span class="identifier">step</span><span class="special">();</span>
|
|
<span class="keyword">for</span> <span class="special">(</span><span class="keyword">int</span> <span class="identifier">pi</span> <span class="special">=</span> <span class="number">0</span><span class="special">;</span> <span class="identifier">pi</span> <span class="special"><</span> <span class="identifier">N</span><span class="special">;</span> <span class="special">++</span><span class="identifier">pi</span><span class="special">)</span> <span class="special">{</span>
|
|
<span class="keyword">double</span> <span class="identifier">theta</span> <span class="special">=</span> <span class="identifier">u_ad</span><span class="special">[</span><span class="number">2</span> <span class="special">*</span> <span class="identifier">pi</span> <span class="special">+</span> <span class="number">0</span><span class="special">].</span><span class="identifier">item</span><span class="special">();</span>
|
|
<span class="keyword">double</span> <span class="identifier">phi</span> <span class="special">=</span> <span class="identifier">u_ad</span><span class="special">[</span><span class="number">2</span> <span class="special">*</span> <span class="identifier">pi</span> <span class="special">+</span> <span class="number">1</span><span class="special">].</span><span class="identifier">item</span><span class="special">();</span>
|
|
<span class="keyword">auto</span> <span class="identifier">r</span> <span class="special">=</span> <span class="identifier">sph_to_xyz</span><span class="special">(</span><span class="identifier">theta</span><span class="special">,</span> <span class="identifier">phi</span><span class="special">);</span>
|
|
<span class="identifier">pos_out</span> <span class="special"><<</span> <span class="identifier">step</span> <span class="special"><<</span> <span class="string">","</span> <span class="special"><<</span> <span class="identifier">pi</span> <span class="special"><<</span> <span class="string">","</span> <span class="special"><<</span> <span class="identifier">r</span><span class="special">.</span><span class="identifier">x</span> <span class="special"><<</span> <span class="string">","</span> <span class="special"><<</span> <span class="identifier">r</span><span class="special">.</span><span class="identifier">y</span> <span class="special"><<</span> <span class="string">","</span> <span class="special"><<</span> <span class="identifier">r</span><span class="special">.</span><span class="identifier">z</span> <span class="special"><<</span> <span class="string">"\n"</span><span class="special">;</span>
|
|
<span class="special">}</span>
|
|
<span class="keyword">auto</span> <span class="identifier">E</span> <span class="special">=</span> <span class="identifier">gdopt</span><span class="special">.</span><span class="identifier">objective_value</span><span class="special">();</span>
|
|
<span class="identifier">energy_out</span> <span class="special"><<</span> <span class="identifier">step</span> <span class="special"><<</span> <span class="string">","</span> <span class="special"><<</span> <span class="identifier">E</span> <span class="special"><<</span> <span class="string">"\n"</span><span class="special">;</span>
|
|
<span class="special">}</span>
|
|
|
|
<span class="identifier">pos_out</span><span class="special">.</span><span class="identifier">close</span><span class="special">();</span>
|
|
<span class="identifier">energy_out</span><span class="special">.</span><span class="identifier">close</span><span class="special">();</span>
|
|
|
|
<span class="keyword">return</span> <span class="number">0</span><span class="special">;</span>
|
|
<span class="special">}</span>
|
|
</pre>
|
|
<p>
|
|
The variable
|
|
</p>
|
|
<pre class="programlisting"><span class="keyword">const</span> <span class="keyword">int</span> <span class="identifier">N</span> <span class="special">=</span> <span class="identifier">std</span><span class="special">::</span><span class="identifier">stoi</span><span class="special">(</span><span class="identifier">argv</span><span class="special">[</span><span class="number">1</span><span class="special">]);</span></pre>
|
|
<p>
|
|
is the number of particles read from the command line
|
|
</p>
|
|
<pre class="programlisting"><span class="keyword">const</span> <span class="keyword">int</span> <span class="identifier">NSTEPS</span> <span class="special">=</span> <span class="number">100000</span><span class="special">;</span></pre>
|
|
<p>
|
|
is the number of optimization steps
|
|
</p>
|
|
<pre class="programlisting"><span class="keyword">const</span> <span class="keyword">double</span> <span class="identifier">lr</span> <span class="special">=</span> <span class="number">1e-3</span><span class="special">;</span></pre>
|
|
<p>
|
|
is the optimizer learning rate. Using the code the way it's written, the optimizer
|
|
runs for 100000 steps. Running the program with
|
|
</p>
|
|
<pre class="programlisting"><span class="special">./</span><span class="identifier">thomson_sphere</span> <span class="identifier">N</span>
|
|
</pre>
|
|
<p>
|
|
optimizes the N particle system. Below is a plot of several optimal configurations
|
|
for N=2,...,8 particles.
|
|
</p>
|
|
<div class="blockquote"><blockquote class="blockquote"><div class="blockquote"><blockquote class="blockquote"><p>
|
|
<span class="inlinemediaobject"><img src="../../../graphs/gradient_based_optimizers/thomson_sphere_2to8.svg" alt="Optimal Thomson problem configurations for N=2 to 8 particles"></span>
|
|
</p></blockquote></div></blockquote></div>
|
|
<p>
|
|
Below is a plot of the final energy of the system, and its deviation from
|
|
the theoretically predicted values. The table of theoretical energy values
|
|
for the problem is from <a href="https://en.wikipedia.org/wiki/Thomson_problem" target="_top">Wikipedia</a>.
|
|
</p>
|
|
<div class="blockquote"><blockquote class="blockquote"><div class="blockquote"><blockquote class="blockquote"><p>
|
|
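<span class="inlinemediaobject"><img src="../../../graphs/gradient_based_optimizers/thomson_energy_error_gradient_descent.svg" alt="Final energy of the system and its deviation from the theoretically predicted values"></span>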
|
|
</p></blockquote></div></blockquote></div>
|
|
<h5>
|
|
<a name="math_toolkit.gd_opt.gradient_descent.h3"></a>
|
|
<span class="phrase"><a name="math_toolkit.gd_opt.gradient_descent.example_using_minimize"></a></span><a class="link" href="gradient_descent.html#math_toolkit.gd_opt.gradient_descent.example_using_minimize">Example
|
|
using minimize</a>
|
|
</h5>
|
|
<p>
|
|
Often, we don't want to actually implement our own stepping function, i.e.
|
|
we care about certain convergence criteria. In the above example, we need
|
|
to include the minimizer.hpp header:
|
|
</p>
|
|
<pre class="programlisting"><span class="preprocessor">#include</span> <span class="special"><</span><span class="identifier">boost</span><span class="special">/</span><span class="identifier">math</span><span class="special">/</span><span class="identifier">optimization</span><span class="special">/</span><span class="identifier">minimizer</span><span class="special">.</span><span class="identifier">hpp</span><span class="special">></span>
|
|
</pre>
|
|
<p>
|
|
and replace the optimization loop:
|
|
</p>
|
|
<pre class="programlisting"><span class="keyword">for</span> <span class="special">(</span><span class="keyword">int</span> <span class="identifier">step</span> <span class="special">=</span> <span class="number">0</span><span class="special">;</span> <span class="identifier">step</span> <span class="special"><</span> <span class="identifier">NSTEPS</span><span class="special">;</span> <span class="special">++</span><span class="identifier">step</span><span class="special">)</span> <span class="special">{</span>
|
|
<span class="identifier">gdopt</span><span class="special">.</span><span class="identifier">step</span><span class="special">();</span>
|
|
<span class="keyword">for</span> <span class="special">(</span><span class="keyword">int</span> <span class="identifier">pi</span> <span class="special">=</span> <span class="number">0</span><span class="special">;</span> <span class="identifier">pi</span> <span class="special"><</span> <span class="identifier">N</span><span class="special">;</span> <span class="special">++</span><span class="identifier">pi</span><span class="special">)</span> <span class="special">{</span>
|
|
<span class="keyword">double</span> <span class="identifier">theta</span> <span class="special">=</span> <span class="identifier">u_ad</span><span class="special">[</span><span class="number">2</span> <span class="special">*</span> <span class="identifier">pi</span> <span class="special">+</span> <span class="number">0</span><span class="special">].</span><span class="identifier">item</span><span class="special">();</span>
|
|
<span class="keyword">double</span> <span class="identifier">phi</span> <span class="special">=</span> <span class="identifier">u_ad</span><span class="special">[</span><span class="number">2</span> <span class="special">*</span> <span class="identifier">pi</span> <span class="special">+</span> <span class="number">1</span><span class="special">].</span><span class="identifier">item</span><span class="special">();</span>
|
|
<span class="keyword">auto</span> <span class="identifier">r</span> <span class="special">=</span> <span class="identifier">sph_to_xyz</span><span class="special">(</span><span class="identifier">theta</span><span class="special">,</span> <span class="identifier">phi</span><span class="special">);</span>
|
|
<span class="identifier">pos_out</span> <span class="special"><<</span> <span class="identifier">step</span> <span class="special"><<</span> <span class="string">","</span> <span class="special"><<</span> <span class="identifier">pi</span> <span class="special"><<</span> <span class="string">","</span> <span class="special"><<</span> <span class="identifier">r</span><span class="special">.</span><span class="identifier">x</span> <span class="special"><<</span> <span class="string">","</span> <span class="special"><<</span> <span class="identifier">r</span><span class="special">.</span><span class="identifier">y</span> <span class="special"><<</span> <span class="string">","</span> <span class="special"><<</span> <span class="identifier">r</span><span class="special">.</span><span class="identifier">z</span> <span class="special"><<</span> <span class="string">"\n"</span><span class="special">;</span>
|
|
<span class="special">}</span>
|
|
<span class="keyword">auto</span> <span class="identifier">E</span> <span class="special">=</span> <span class="identifier">gdopt</span><span class="special">.</span><span class="identifier">objective_value</span><span class="special">();</span>
|
|
<span class="identifier">energy_out</span> <span class="special"><<</span> <span class="identifier">step</span> <span class="special"><<</span> <span class="string">","</span> <span class="special"><<</span> <span class="identifier">E</span> <span class="special"><<</span> <span class="string">"\n"</span><span class="special">;</span>
|
|
<span class="special">}</span>
|
|
</pre>
|
|
<p>
|
|
with
|
|
</p>
|
|
<pre class="programlisting"><span class="keyword">auto</span> <span class="identifier">result</span> <span class="special">=</span> <span class="identifier">minimize</span><span class="special">(</span><span class="identifier">gdopt</span><span class="special">);</span>
|
|
</pre>
|
|
<p>
|
|
minimize returns a
|
|
</p>
|
|
<pre class="programlisting"><span class="identifier">optimization_result</span><span class="special"><</span><span class="keyword">typename</span> <span class="identifier">Optimizer</span><span class="special">::</span><span class="identifier">real_type_t</span><span class="special">></span></pre>
|
|
<p>
|
|
, a struct with the following fields:
|
|
</p>
|
|
<pre class="programlisting"><span class="identifier">size_t</span> <span class="identifier">num_iter</span><span class="special">;</span>
|
|
<span class="identifier">RealType</span> <span class="identifier">objective_value</span><span class="special">;</span>
|
|
<span class="identifier">std</span><span class="special">::</span><span class="identifier">vector</span><span class="special"><</span><span class="identifier">RealType</span><span class="special">></span> <span class="identifier">objective_history</span><span class="special">;</span>
|
|
<span class="keyword">bool</span> <span class="identifier">converged</span><span class="special">;</span>
|
|
</pre>
|
|
<p>
|
|
where <code class="computeroutput"><span class="identifier">num_iter</span></code> is the number
|
|
of iterations the optimizer went through, <code class="computeroutput"><span class="identifier">objective_value</span></code>
|
|
is the final objective value, <code class="computeroutput"><span class="identifier">objective_history</span></code>
|
|
are the intermediate objective values, and <code class="computeroutput"><span class="identifier">converged</span></code>
|
|
is whether the convergence criterion was satisfied. By default, <code class="computeroutput"><span class="identifier">minimize</span><span class="special">(</span><span class="identifier">optimizer</span><span class="special">)</span></code>
|
|
uses a gradient norm convergence criterion. If norm(gradient_vector) <
|
|
1e-3, the criterion is satisfied. Maximum number of iterations is set at
|
|
100000. For more info on how to use <code class="computeroutput"><span class="identifier">minimize</span></code>
|
|
check the minimize docs. With default parameters, gradient descent solves
|
|
the <code class="computeroutput"><span class="identifier">N</span><span class="special">=</span><span class="number">2</span></code> problem in <code class="computeroutput"><span class="number">93799</span></code>
|
|
steps.
|
|
</p>
|
|
</div>
|
|
<div class="copyright-footer">Copyright © 2006-2021 Nikhar Agrawal, Anton Bikineev, Matthew Borland,
|
|
Paul A. Bristow, Marco Guazzone, Christopher Kormanyos, Hubert Holin, Bruno
|
|
Lalande, John Maddock, Evan Miller, Jeremy Murphy, Matthew Pulver, Johan Råde,
|
|
Gautam Sewani, Benjamin Sobotta, Nicholas Thompson, Thijs van den Berg, Daryle
|
|
Walker, Xiaogang Zhang, and Maksym Zhelyeznyakov<p>
|
|
Distributed under the Boost Software License, Version 1.0. (See accompanying
|
|
file LICENSE_1_0.txt or copy at <a href="http://www.boost.org/LICENSE_1_0.txt" target="_top">http://www.boost.org/LICENSE_1_0.txt</a>)
|
|
</p>
|
|
</div>
|
|
<hr>
|
|
<div class="spirit-nav">
|
|
<a accesskey="p" href="introduction.html"><img src="../../../../../../doc/src/images/prev.png" alt="Prev"></a><a accesskey="u" href="../gd_opt.html"><img src="../../../../../../doc/src/images/up.png" alt="Up"></a><a accesskey="h" href="../../index.html"><img src="../../../../../../doc/src/images/home.png" alt="Home"></a><a accesskey="n" href="nesterov.html"><img src="../../../../../../doc/src/images/next.png" alt="Next"></a>
|
|
</div>
|
|
</body>
|
|
</html>
|