Skip to content

Commit

Permalink
Merge pull request #276 from foxtran/feature/swap_operands
Browse files Browse the repository at this point in the history
Implement swap operands for binary ops
  • Loading branch information
MilesCranmer authored Jan 3, 2024
2 parents b771ea1 + acd585e commit e3a7ccc
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 0 deletions.
12 changes: 12 additions & 0 deletions src/Mutate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import ..MutationFunctionsModule:
gen_random_tree_fixed_size,
mutate_constant,
mutate_operator,
swap_operands,
append_random_op,
prepend_random_op,
insert_random_op,
Expand All @@ -28,6 +29,7 @@ function condition_mutation_weights!(
# If equation is too small, don't delete operators
# or simplify
weights.mutate_operator = 0.0
weights.swap_operands = 0.0
weights.delete_node = 0.0
weights.simplify = 0.0
if !member.tree.constant
Expand All @@ -37,6 +39,11 @@ function condition_mutation_weights!(
return nothing
end

if !any(node -> node.degree == 2, member.tree)
# swap is implemented only for binary ops
weights.swap_operands = 0.0
end

#More constants => more likely to do constant mutation
n_constants = count_constants(member.tree)
weights.mutate_constant *= min(8, n_constants) / 8.0
Expand Down Expand Up @@ -110,6 +117,11 @@ function next_generation(
is_success_always_possible = true
# Can always mutate to the same operator

elseif mutation_choice == :swap_operands
tree = swap_operands(tree, options)
@recorder tmp_recorder["type"] = "swap_operands"
is_success_always_possible = true

elseif mutation_choice == :add_node
if rand() < 0.5
tree = append_random_op(tree, options, nfeatures)
Expand Down
10 changes: 10 additions & 0 deletions src/MutationFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,16 @@ function random_node(tree::Node{T}; filter::F=Returns(true))::Node{T} where {T,F
return chosen_node[]
end

# Swap operands in binary operator for ops like pow and divide
function swap_operands(tree::Node{T}, options::Options)::Node{T} where {T}
if !any(node -> node.degree == 2, tree)
return tree
end
node = random_node(tree; filter=t -> t.degree == 2)
node.l, node.r = node.r, node.l
return tree
end

# Randomly convert an operator into another one (binary->binary;
# unary->unary)
function mutate_operator(tree::Node{T}, options::Options)::Node{T} where {T}
Expand Down
6 changes: 6 additions & 0 deletions src/OptionsStruct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import LossFunctions: SupervisedLoss
mutable struct MutationWeights
mutate_constant::Float64
mutate_operator::Float64
swap_operands::Float64
add_node::Float64
insert_node::Float64
delete_node::Float64
Expand All @@ -27,6 +28,7 @@ will be normalized to sum to 1.0 after initialization.
# Arguments
- `mutate_constant::Float64`: How often to mutate a constant.
- `mutate_operator::Float64`: How often to mutate an operator.
- `swap_operands::Float64`: How often to swap operands in binary operators.
- `add_node::Float64`: How often to append a node to the tree.
- `insert_node::Float64`: How often to insert a node into the tree.
- `delete_node::Float64`: How often to delete a node from the tree.
Expand All @@ -40,6 +42,7 @@ will be normalized to sum to 1.0 after initialization.
function MutationWeights(;
mutate_constant=0.048,
mutate_operator=0.47,
swap_operands=0.0,
add_node=0.79,
insert_node=5.1,
delete_node=1.7,
Expand All @@ -51,6 +54,7 @@ function MutationWeights(;
return MutationWeights(
mutate_constant,
mutate_operator,
swap_operands,
add_node,
insert_node,
delete_node,
Expand All @@ -66,6 +70,7 @@ function Base.convert(::Type{Vector}, w::MutationWeights)::Vector{Float64}
return [
w.mutate_constant,
w.mutate_operator,
w.swap_operands,
w.add_node,
w.insert_node,
w.delete_node,
Expand All @@ -81,6 +86,7 @@ function Base.copy(weights::MutationWeights)
return MutationWeights(
weights.mutate_constant,
weights.mutate_operator,
weights.swap_operands,
weights.add_node,
weights.insert_node,
weights.delete_node,
Expand Down
1 change: 1 addition & 0 deletions src/precompile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ function do_precompilation(::Val{mode}) where {mode}
mutation_weights=MutationWeights(;
mutate_constant=1.0,
mutate_operator=1.0,
swap_operands=1.0,
add_node=1.0,
insert_node=1.0,
delete_node=1.0,
Expand Down

0 comments on commit e3a7ccc

Please sign in to comment.