diff --git a/mlrose/opt_probs.py b/mlrose/opt_probs.py index 7194bdc0..28300034 100644 --- a/mlrose/opt_probs.py +++ b/mlrose/opt_probs.py @@ -104,16 +104,21 @@ def eval_mate_probs(self): Calculate the probability of each member of the population reproducing. """ pop_fitness = np.copy(self.pop_fitness) + sum_fitness = np.sum(pop_fitness) # Set -1*inf values to 0 to avoid dividing by sum of infinity. # This forces mate_probs for these pop members to 0. - pop_fitness[pop_fitness == -1.0*np.inf] = 0 - - if np.sum(pop_fitness) == 0: - self.mate_probs = np.ones(len(pop_fitness)) \ - / len(pop_fitness) + pop_fitness[pop_fitness == -1.0 * np.inf] = 0 + + if sum_fitness == 0: + self.mate_probs = np.ones(len(pop_fitness)) / len(pop_fitness) + elif self.maximize == 1: + self.mate_probs = pop_fitness / sum_fitness + # creates mate probability if fitness is negative + # if fitness 0 mate probability will also be 0 else: - self.mate_probs = pop_fitness/np.sum(pop_fitness) + pop_fitness = [0 if x == 0 else sum_fitness / x for x in pop_fitness] + self.mate_probs = pop_fitness / np.sum(pop_fitness) def get_fitness(self): """ Return the fitness of the current state vector. diff --git a/tests/test_opt_probs.py b/tests/test_opt_probs.py index f50fa5df..2e7e2c59 100644 --- a/tests/test_opt_probs.py +++ b/tests/test_opt_probs.py @@ -188,7 +188,7 @@ def test_eval_fitness_min(): assert fitness == -10 @staticmethod - def test_eval_mate_probs(): + def test_eval_mate_probs_max(): """Test eval_mate_probs method""" problem = OptProb(5, OneMax(), maximize=True) @@ -206,6 +206,25 @@ def test_eval_mate_probs(): assert np.allclose(problem.get_mate_probs(), probs, atol=0.00001) + @staticmethod + def test_eval_mate_probs_min(): + """Test eval_mate_probs method""" + + problem = OptProb(5, OneMax(), maximize=False) + pop = np.array([[0, 0, 0, 0, 1], + [1, 0, 1, 0, 1], + [1, 1, 1, 1, 0], + [1, 0, 0, 0, 1], + [0, 0, 0, 0, 0], + [1, 1, 1, 1, 1]]) + + problem.set_population(pop) + problem.eval_mate_probs() + + probs = np.array([0.4379562, 0.1459854, 0.10948905, 0.2189781, 0, 0.08759124]) + + assert np.allclose(problem.get_mate_probs(), probs, atol=0.00001) + @staticmethod def test_eval_mate_probs_all_zero(): """Test eval_mate_probs method when all states have zero fitness"""