diff --git a/desdeo/emo/operators/mutation.py b/desdeo/emo/operators/mutation.py index d625a88..5755f92 100644 --- a/desdeo/emo/operators/mutation.py +++ b/desdeo/emo/operators/mutation.py @@ -3,6 +3,7 @@ Various evolutionary operators for mutation in multiobjective optimization are defined here. """ +import copy from abc import abstractmethod from collections.abc import Sequence @@ -207,3 +208,133 @@ def state(self) -> Sequence[Message]: value=self.distribution_index, ), ] + + +class BinaryFlipMutation(BaseMutation): + """Implements the bit flip mutation operator for binary variables. + + The binary flip mutation will mutate each binary decision variable, + by flipping it (0 to 1, 1 to 0) with a provided probability. + """ + + @property + def provided_topics(self) -> dict[int, Sequence[MutationMessageTopics]]: + """The message topics provided by the mutation operator.""" + return { + 0: [], + 1: [ + MutationMessageTopics.MUTATION_PROBABILITY, + ], + 2: [ + MutationMessageTopics.MUTATION_PROBABILITY, + MutationMessageTopics.OFFSPRING_ORIGINAL, + MutationMessageTopics.OFFSPRINGS, + ], + } + + @property + def interested_topics(self): + """The message topics that the mutation operator is interested in.""" + return [] + + def __init__( + self, + *, + problem: Problem, + seed: int, + mutation_probability: float | None = None, + **kwargs, + ): + """Initialize a binary flip mutation operator. + + Args: + problem (Problem): The problem object. + seed (int): The seed for the random number generator. + mutation_probability (float | None, optional): The probability of mutation. If None, + the probability will be set to be 1/n, where n is the number of decision variables + in the problem. Defaults to None. + kwargs: Additional keyword arguments. These are passed to the Subscriber class. At the very least, the + publisher must be passed. See the Subscriber class for more information. + """ + super().__init__(problem, **kwargs) + + if self.variable_combination != VariableDomainTypeEnum.binary: + raise ValueError("This mutation operator only works with binary variables.") + if mutation_probability is None: + self.mutation_probability = 1 / len(self.variable_symbols) + else: + self.mutation_probability = mutation_probability + + self.rng = np.random.default_rng(seed) + self.seed = seed + self.offspring_original: pl.DataFrame + self.parents: pl.DataFrame + self.offspring: pl.DataFrame + + def do(self, offsprings: pl.DataFrame, parents: pl.DataFrame) -> pl.DataFrame: + """Perform the binary flip mutation operation. + + Args: + offsprings (pl.DataFrame): the offspring population to mutate. + parents (pl.DataFrame): the parent population from which the offspring + was generated (via crossover). Not used in the mutation operator. + + Returns: + pl.DataFrame: the offspring resulting from the mutation. + """ + self.offspring_original = copy.copy(offsprings) + self.parents = parents # Not used, but kept for consistency + offspring = offsprings.to_numpy().astype(dtype=np.bool) + + # create a boolean mask based on the mutation probability + flip_mask = self.rng.random(offspring.shape) < self.mutation_probability + + # using XOR (^), flip the bits in the offspring when the mask is True + # otherwise leave the bit's value as it is + offspring = offspring ^ flip_mask + + self.offspring = pl.from_numpy(offspring, schema=self.variable_symbols).select(pl.all()).cast(pl.Float64) + self.notify() + + return self.offspring + + def update(self, *_, **__): + """Do nothing. This is just the basic polynomial mutation operator.""" + + def state(self) -> Sequence[Message]: + """Return the state of the mutation operator.""" + if self.offspring_original is None or self.offspring is None: + return [] + if self.verbosity == 0: + return [] + if self.verbosity == 1: + return [ + FloatMessage( + topic=MutationMessageTopics.MUTATION_PROBABILITY, + source=self.__class__.__name__, + value=self.mutation_probability, + ), + ] + # verbosity == 2 + return [ + PolarsDataFrameMessage( + topic=MutationMessageTopics.OFFSPRING_ORIGINAL, + source=self.__class__.__name__, + value=self.offspring_original, + ), + PolarsDataFrameMessage( + topic=MutationMessageTopics.PARENTS, + source=self.__class__.__name__, + value=self.parents, + ), + PolarsDataFrameMessage( + topic=MutationMessageTopics.OFFSPRINGS, + source=self.__class__.__name__, + value=self.offspring, + ), + FloatMessage( + topic=MutationMessageTopics.MUTATION_PROBABILITY, + source=self.__class__.__name__, + value=self.mutation_probability, + ), + ] diff --git a/desdeo/problem/schema.py b/desdeo/problem/schema.py index 5204dac..1c568cc 100644 --- a/desdeo/problem/schema.py +++ b/desdeo/problem/schema.py @@ -217,6 +217,8 @@ class VariableDomainTypeEnum(str, Enum): continuous = "continuous" """All variables are real valued.""" + binary = "binary" + """All variables are binary valued.""" integer = "integer" """All variables are integer or binary valued.""" mixed = "mixed" @@ -1351,6 +1353,10 @@ def variable_domain(self) -> VariableDomainTypeEnum: # all variables are real valued -> continuous problem return VariableDomainTypeEnum.continuous + if all(t == VariableTypeEnum.binary for t in variable_types): + # all variables are binary valued -> binary problem + return VariableDomainTypeEnum.binary + if all(t in [VariableTypeEnum.integer, VariableTypeEnum.binary] for t in variable_types): # all variables are integer or binary -> integer problem return VariableDomainTypeEnum.integer diff --git a/tests/test_ea.py b/tests/test_ea.py index dbc1480..ba8f373 100644 --- a/tests/test_ea.py +++ b/tests/test_ea.py @@ -13,7 +13,7 @@ from desdeo.emo.operators.crossover import SimulatedBinaryCrossover, SinglePointBinaryCrossover from desdeo.emo.operators.evaluator import EMOEvaluator from desdeo.emo.operators.generator import LHSGenerator, RandomGenerator -from desdeo.emo.operators.mutation import BoundedPolynomialMutation +from desdeo.emo.operators.mutation import BinaryFlipMutation, BoundedPolynomialMutation from desdeo.emo.operators.selection import ParameterAdaptationStrategy, ReferenceVectorOptions, RVEASelector from desdeo.emo.operators.termination import MaxEvaluationsTerminator from desdeo.problem.testproblems import dtlz2, simple_knapsack, simple_knapsack_vectors @@ -246,3 +246,50 @@ def test_single_point_binary_crossover(): with npt.assert_raises(AssertionError): npt.assert_allclose(population, result) + + +@pytest.mark.ea +def test_binary_flip_mutation(): + """Test whether the binary flip mutation operator works as intended.""" + publisher = Publisher() + + problem = simple_knapsack() + + # default mutation probability + mutation = BinaryFlipMutation(problem=problem, publisher=publisher, seed=0) + num_vars = len(mutation.variable_symbols) + + population = pl.DataFrame( + np.ones((10, num_vars)), + schema=mutation.variable_symbols, + ) + + result = mutation.do(offsprings=population, parents=population) + + assert result.shape == (len(population), num_vars) + + with npt.assert_raises(AssertionError): + npt.assert_allclose(population, result) + + assert 1.0 in result.to_numpy() + assert 0.0 in result.to_numpy() + + # all bits should flip + mutation = BinaryFlipMutation(problem=problem, publisher=publisher, seed=0, mutation_probability=1.0) + num_vars = len(mutation.variable_symbols) + + result = mutation.do(offsprings=population, parents=population) + + assert result.shape == (len(population), num_vars) + + npt.assert_allclose(np.zeros((10, num_vars)), result) + + # no bit should flip + mutation = BinaryFlipMutation(problem=problem, publisher=publisher, seed=0, mutation_probability=0) + num_vars = len(mutation.variable_symbols) + + result = mutation.do(offsprings=population, parents=population) + + assert result.shape == (len(population), num_vars) + + npt.assert_allclose(np.ones((10, num_vars)), result) diff --git a/tests/test_problem_schema.py b/tests/test_problem_schema.py index 85009d5..81fa7cd 100644 --- a/tests/test_problem_schema.py +++ b/tests/test_problem_schema.py @@ -521,7 +521,7 @@ def test_variable_domain(): integer_problem = simple_knapsack() - assert integer_problem.variable_domain == VariableDomainTypeEnum.integer + assert integer_problem.variable_domain == VariableDomainTypeEnum.binary @pytest.mark.schema