Skip to content

Commit

Permalink
Deploying to gh-pages from @ 800ebd9 🚀
Browse files Browse the repository at this point in the history
  • Loading branch information
asogaard committed Dec 1, 2023
1 parent 2094390 commit 36a395f
Show file tree
Hide file tree
Showing 6 changed files with 271 additions and 5 deletions.
114 changes: 112 additions & 2 deletions _modules/graphnet/training/callbacks.html
Original file line number Diff line number Diff line change
Expand Up @@ -325,20 +325,26 @@ <h1 id="modules-graphnet-training-callbacks--page-root">Source code for graphnet
<span></span><span class="sd">"""Callback class(es) for using during model training."""</span>

<span class="kn">import</span> <span class="nn">logging</span>
<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">List</span>
<span class="kn">import</span> <span class="nn">os</span>
<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">TYPE_CHECKING</span><span class="p">,</span> <span class="n">Any</span><span class="p">,</span> <span class="n">Optional</span>
<span class="kn">import</span> <span class="nn">warnings</span>

<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">from</span> <span class="nn">tqdm.std</span> <span class="kn">import</span> <span class="n">Bar</span>

<span class="kn">from</span> <span class="nn">pytorch_lightning</span> <span class="kn">import</span> <span class="n">LightningModule</span><span class="p">,</span> <span class="n">Trainer</span>
<span class="kn">from</span> <span class="nn">pytorch_lightning.callbacks</span> <span class="kn">import</span> <span class="n">TQDMProgressBar</span>
<span class="kn">from</span> <span class="nn">pytorch_lightning.callbacks</span> <span class="kn">import</span> <span class="n">TQDMProgressBar</span><span class="p">,</span> <span class="n">EarlyStopping</span>
<span class="kn">from</span> <span class="nn">pytorch_lightning.utilities</span> <span class="kn">import</span> <span class="n">rank_zero_only</span>
<span class="kn">from</span> <span class="nn">torch.optim</span> <span class="kn">import</span> <span class="n">Optimizer</span>
<span class="kn">from</span> <span class="nn">torch.optim.lr_scheduler</span> <span class="kn">import</span> <span class="n">_LRScheduler</span>

<span class="kn">from</span> <span class="nn">graphnet.utilities.logging</span> <span class="kn">import</span> <span class="n">Logger</span>

<span class="k">if</span> <span class="n">TYPE_CHECKING</span><span class="p">:</span>
<span class="kn">from</span> <span class="nn">graphnet.models</span> <span class="kn">import</span> <span class="n">Model</span>
<span class="kn">import</span> <span class="nn">pytorch_lightning</span> <span class="k">as</span> <span class="nn">pl</span>


<div class="viewcode-block" id="PiecewiseLinearLR">
<a class="viewcode-back" href="../../../api/graphnet.training.callbacks.html#graphnet.training.callbacks.PiecewiseLinearLR">[docs]</a>
Expand Down Expand Up @@ -506,6 +512,110 @@ <h1 id="modules-graphnet-training-callbacks--page-root">Source code for graphnet
<span class="n">h</span><span class="o">.</span><span class="n">setLevel</span><span class="p">(</span><span class="n">level</span><span class="p">)</span></div>
</div>



<div class="viewcode-block" id="GraphnetEarlyStopping">
<a class="viewcode-back" href="../../../api/graphnet.training.callbacks.html#graphnet.training.callbacks.GraphnetEarlyStopping">[docs]</a>
<span class="k">class</span> <span class="nc">GraphnetEarlyStopping</span><span class="p">(</span><span class="n">EarlyStopping</span><span class="p">):</span>
<span class="w"> </span><span class="sd">"""Early stopping callback for graphnet."""</span>

<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">save_dir</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="w"> </span><span class="sd">"""Construct `GraphnetEarlyStopping` Callback.</span>

<span class="sd"> Args:</span>
<span class="sd"> save_dir: Path to directory to save best model and config.</span>
<span class="sd"> **kwargs: Keyword arguments to pass to `EarlyStopping`. See</span>
<span class="sd"> `pytorch_lightning.callbacks.EarlyStopping` for details.</span>
<span class="sd"> """</span>
<span class="bp">self</span><span class="o">.</span><span class="n">save_dir</span> <span class="o">=</span> <span class="n">save_dir</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>

<div class="viewcode-block" id="GraphnetEarlyStopping.setup">
<a class="viewcode-back" href="../../../api/graphnet.training.callbacks.html#graphnet.training.callbacks.GraphnetEarlyStopping.setup">[docs]</a>
<span class="k">def</span> <span class="nf">setup</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="n">trainer</span><span class="p">:</span> <span class="s2">"pl.Trainer"</span><span class="p">,</span>
<span class="n">graphnet_model</span><span class="p">:</span> <span class="s2">"Model"</span><span class="p">,</span>
<span class="n">stage</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="w"> </span><span class="sd">"""Call at setup stage of training.</span>

<span class="sd"> Args:</span>
<span class="sd"> trainer: The trainer.</span>
<span class="sd"> graphnet_model: The model.</span>
<span class="sd"> stage: The stage of training.</span>
<span class="sd"> """</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">setup</span><span class="p">(</span><span class="n">trainer</span><span class="p">,</span> <span class="n">graphnet_model</span><span class="p">,</span> <span class="n">stage</span><span class="p">)</span>
<span class="n">os</span><span class="o">.</span><span class="n">makedirs</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">save_dir</span><span class="p">,</span> <span class="n">exist_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">graphnet_model</span><span class="o">.</span><span class="n">save_config</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">save_dir</span><span class="p">,</span> <span class="s2">"config.yml"</span><span class="p">))</span></div>


<div class="viewcode-block" id="GraphnetEarlyStopping.on_train_epoch_end">
<a class="viewcode-back" href="../../../api/graphnet.training.callbacks.html#graphnet.training.callbacks.GraphnetEarlyStopping.on_train_epoch_end">[docs]</a>
<span class="k">def</span> <span class="nf">on_train_epoch_end</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span> <span class="n">trainer</span><span class="p">:</span> <span class="s2">"pl.Trainer"</span><span class="p">,</span> <span class="n">graphnet_model</span><span class="p">:</span> <span class="s2">"Model"</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="w"> </span><span class="sd">"""Call after each train epoch.</span>

<span class="sd"> Args:</span>
<span class="sd"> trainer: Trainer object.</span>
<span class="sd"> graphnet_model: Graphnet Model.</span>

<span class="sd"> Returns: None.</span>
<span class="sd"> """</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">_check_on_train_epoch_end</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">_should_skip_check</span><span class="p">(</span>
<span class="n">trainer</span>
<span class="p">):</span>
<span class="k">return</span>
<span class="n">current_best</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">best_score</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_run_early_stopping_check</span><span class="p">(</span><span class="n">trainer</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">best_score</span> <span class="o">!=</span> <span class="n">current_best</span><span class="p">:</span>
<span class="n">graphnet_model</span><span class="o">.</span><span class="n">save_state_dict</span><span class="p">(</span>
<span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">save_dir</span><span class="p">,</span> <span class="s2">"best_model.pth"</span><span class="p">)</span>
<span class="p">)</span></div>


<div class="viewcode-block" id="GraphnetEarlyStopping.on_validation_end">
<a class="viewcode-back" href="../../../api/graphnet.training.callbacks.html#graphnet.training.callbacks.GraphnetEarlyStopping.on_validation_end">[docs]</a>
<span class="k">def</span> <span class="nf">on_validation_end</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span> <span class="n">trainer</span><span class="p">:</span> <span class="s2">"pl.Trainer"</span><span class="p">,</span> <span class="n">graphnet_model</span><span class="p">:</span> <span class="s2">"Model"</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="w"> </span><span class="sd">"""Call after each validation epoch.</span>

<span class="sd"> Args:</span>
<span class="sd"> trainer: Trainer object.</span>
<span class="sd"> graphnet_model: Graphnet Model.</span>

<span class="sd"> Returns: None.</span>
<span class="sd"> """</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_check_on_train_epoch_end</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">_should_skip_check</span><span class="p">(</span><span class="n">trainer</span><span class="p">):</span>
<span class="k">return</span>
<span class="n">current_best</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">best_score</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_run_early_stopping_check</span><span class="p">(</span><span class="n">trainer</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">best_score</span> <span class="o">!=</span> <span class="n">current_best</span><span class="p">:</span>
<span class="n">graphnet_model</span><span class="o">.</span><span class="n">save_state_dict</span><span class="p">(</span>
<span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">save_dir</span><span class="p">,</span> <span class="s2">"best_model.pth"</span><span class="p">)</span>
<span class="p">)</span></div>


<div class="viewcode-block" id="GraphnetEarlyStopping.on_fit_end">
<a class="viewcode-back" href="../../../api/graphnet.training.callbacks.html#graphnet.training.callbacks.GraphnetEarlyStopping.on_fit_end">[docs]</a>
<span class="k">def</span> <span class="nf">on_fit_end</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span> <span class="n">trainer</span><span class="p">:</span> <span class="s2">"pl.Trainer"</span><span class="p">,</span> <span class="n">graphnet_model</span><span class="p">:</span> <span class="s2">"Model"</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="w"> </span><span class="sd">"""Call at the end of training.</span>

<span class="sd"> Args:</span>
<span class="sd"> trainer: Trainer object.</span>
<span class="sd"> graphnet_model: Graphnet Model.</span>

<span class="sd"> Returns: None.</span>
<span class="sd"> """</span>
<span class="n">graphnet_model</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span>
<span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">save_dir</span><span class="p">,</span> <span class="s2">"best_model.pth"</span><span class="p">)</span>
<span class="p">)</span></div>
</div>

</pre></div>

</article>
Expand Down
Loading

0 comments on commit 36a395f

Please sign in to comment.