diff --git a/recirq/optimize/_util.py b/recirq/optimize/_util.py new file mode 100644 index 00000000..e242bd0c --- /dev/null +++ b/recirq/optimize/_util.py @@ -0,0 +1,64 @@ +# Copyright 2021 Google +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# `wrap_function` is re-used under the following license: + +# Copyright (c) 2001-2002 Enthought, Inc. 2003-2019, SciPy Developers. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following +# disclaimer in the documentation and/or other materials provided +# with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. 
# IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


def wrap_function(function, args):
    """Wrap `function` so that calls are tallied and `args` is appended.

    This helper was available in scipy before version 1.7. It was apparently
    never treated as public API, since it was removed without notice, so a
    copy is kept here.

    Args:
        function: The callable to wrap, or None.
        args: Extra positional arguments concatenated (with `+`) after the
            arguments supplied at call time.

    Returns:
        A pair `(ncalls, wrapped)`. `ncalls` is a one-element list whose
        single entry is incremented on every invocation of `wrapped`.
        `wrapped` forwards to `function`, or is None when `function` is None.
    """
    # A one-element list is used so the closure can mutate the counter and
    # the caller can observe the updates through the same object.
    call_counter = [0]
    wrapped = None

    if function is not None:

        def function_wrapper(*wrapper_args):
            # Count the call before forwarding, so failed calls are counted.
            call_counter[0] = call_counter[0] + 1
            full_args = wrapper_args + args
            return function(*full_args)

        wrapped = function_wrapper

    return call_counter, wrapped
import value +from recirq.optimize._util import wrap_function + if TYPE_CHECKING: import cirq @@ -98,6 +99,7 @@ def value(self, t): return self.learning_rate * self.decay_rate ** m + def _adam_update( grad: np.ndarray, x: np.ndarray, @@ -167,7 +169,6 @@ def model_policy_gradient( known_values: Optional[Tuple[List[np.ndarray], List[float]]] = None, max_evaluations: Optional[int] = None ) -> scipy.optimize.OptimizeResult: - """Model policy gradient algorithm for black-box optimization. The idea of this algorithm is to perform policy gradient, but estimate