Skip to content

Commit

Permalink
Generalize get_dimensions_type
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Dec 10, 2023
1 parent 8b8d841 commit 888c4b9
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 7 deletions.
53 changes: 47 additions & 6 deletions src/InterfaceDynamicQuantities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,12 @@ Recursively finds the dimension type from an array, or,
if no quantity is found, returns the default type.
"""
function get_dimensions_type(A::AbstractArray, default::Type{D}) where {D}
for a in A
# Look through columns for any dimensions (so we can return the correct type)
if typeof(a) <: UnionAbstractQuantity
return dim_type(a)
end
i = findfirst(a -> isa(a, UnionAbstractQuantity), A)
if i === nothing
return D
else
return typeof(dimension(A[i]))
end
return D
end
function get_dimensions_type(
::AbstractArray{Q}, default::Type
Expand All @@ -89,4 +88,46 @@ function get_dimensions_type(_, default::Type{D}) where {D}
return D
end

# Shortcut for basic numeric types
function get_dimensions_type(
::AbstractArray{
<:Union{
Bool,
Int8,
UInt8,
Int16,
UInt16,
Int32,
UInt32,
Int64,
UInt64,
Int128,
UInt128,
Float16,
Float32,
Float64,
BigFloat,
BigInt,
ComplexF16,
ComplexF32,
ComplexF64,
Complex{BigFloat},
Rational{Int8},
Rational{UInt8},
Rational{Int16},
Rational{UInt16},
Rational{Int32},
Rational{UInt32},
Rational{Int64},
Rational{UInt64},
Rational{Int128},
Rational{UInt128},
Rational{BigInt},
},
},
default::Type{D},
) where {D}
return D
end

end
7 changes: 6 additions & 1 deletion test/test_units.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using SymbolicRegression
using SymbolicRegression.InterfaceDynamicQuantitiesModule: get_units
using SymbolicRegression.InterfaceDynamicQuantitiesModule: get_units, get_dimensions_type
using SymbolicRegression.MLJInterfaceModule: unwrap_units_single
using SymbolicRegression.DimensionalAnalysisModule:
violates_dimensional_constraints, @maybe_return_call, WildcardQuantity
Expand Down Expand Up @@ -336,4 +336,9 @@ end
_, test_dims = unwrap_units_single(Xm_t, Dimensions)
@test test_dims == dimension.([u"1", u"m", u"m/s"])
@test_skip @inferred unwrap_units_single(Xm_t, Dimensions)

# Another edge case
## Should be able to pull it out from array:
@test get_dimensions_type(Number[1.0, us"1"], Dimensions) <: SymbolicDimensions
@test get_dimensions_type(Number[1.0, 1.0], Dimensions) <: Dimensions
end

0 comments on commit 888c4b9

Please sign in to comment.