Skip to content

Commit

Permalink
Fix callback compatibility for trainable_model (#878)
Browse files Browse the repository at this point in the history
* Fix callback compatibility for trainable_model

* Update releasenotes/notes/fix-callback-trainable_model-ca52060260a3e466.yaml

Co-authored-by: Edoardo Altamura <[email protected]>

* Update tests with callbacks

* Disable pylint unused arguments in tests

---------

Co-authored-by: Edoardo Altamura <[email protected]>
(cherry picked from commit 02ee6a7)
  • Loading branch information
OkuyanBoga authored and mergify[bot] committed Dec 16, 2024
1 parent 6a8cc57 commit c0ffa55
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 10 deletions.
5 changes: 3 additions & 2 deletions qiskit_machine_learning/algorithms/trainable_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from .objective_functions import ObjectiveFunction
from .serializable_model import SerializableModelMixin
from ..optimizers import Optimizer, SLSQP, OptimizerResult, Minimizer
from ..optimizers import Optimizer, SciPyOptimizer, SLSQP, OptimizerResult, Minimizer
from ..utils import algorithm_globals
from ..neural_networks import NeuralNetwork
from ..utils.loss_functions import (
Expand Down Expand Up @@ -269,7 +269,8 @@ def _get_objective(

def objective(objective_weights):
objective_value = function.objective(objective_weights)
self._callback(objective_weights, objective_value)
if isinstance(self._optimizer, SciPyOptimizer):
self._callback(objective_weights, objective_value)
return objective_value

return objective
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
fixes:
- |
Fixed a compatibility issue for different `callback` functions when they are
used in :class:`~qiskit_machine_learning.algorithms.trainable_model` based algorithms.
9 changes: 5 additions & 4 deletions test/algorithms/classifiers/test_neural_network_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,17 +521,18 @@ def test_callback_setter(self):

loss_history = []

def store_loss(_, loss):
loss_history.append(loss)
def store_loss(nfev, x_next, loss, update, is_accepted): # pylint: disable=unused-argument
if is_accepted:
loss_history.append(loss)

# use setter for the callback instead of providing in the initialize method
classifier.callback = store_loss
classifier.optimizer.callback = store_loss

features = np.array([[0, 0], [1, 1]])
labels = np.array([0, 1])
classifier.fit(features, labels)

self.assertEqual(len(loss_history), 3)
self.assertEqual(len(loss_history), 1)


if __name__ == "__main__":
Expand Down
10 changes: 6 additions & 4 deletions test/algorithms/regressors/test_neural_network_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,17 +239,19 @@ def test_callback_setter(self):

loss_history = []

def store_loss(_, loss):
loss_history.append(loss)
def store_loss(nfev, x_next, loss, update, is_accepted): # pylint: disable=unused-argument

if is_accepted:
loss_history.append(loss)

# use setter for the callback instead of providing in the initialize method
regressor.callback = store_loss
regressor.optimizer.callback = store_loss

features = np.array([[0, 0], [0.1, 0.1], [0.4, 0.4], [1, 1]])
labels = np.array([0, 0.1, 0.4, 1])
regressor.fit(features, labels)

self.assertEqual(len(loss_history), 3)
self.assertEqual(len(loss_history), 1)


if __name__ == "__main__":
Expand Down

0 comments on commit c0ffa55

Please sign in to comment.