diff --git a/src/deca/ThDEC.jl b/src/deca/ThDEC.jl index 00813ff..57aab47 100644 --- a/src/deca/ThDEC.jl +++ b/src/deca/ThDEC.jl @@ -24,8 +24,12 @@ export DECQuantity # this ensures symtype doesn't recurse endlessly SymbolicUtils.symtype(::Type{S}) where S<:DECQuantity = S -struct Scalar <: DECQuantity end -export Scalar +abstract type AbstractScalar <: DECQuantity end + +struct Scalar <: AbstractScalar end +struct Parameter <: AbstractScalar end +struct ConstScalar <: AbstractScalar end +export Scalar, Parameter, ConstScalar struct FormParams dim::Int @@ -107,7 +111,7 @@ end export PatFormDim @active PatScalar(T) begin - if T <: Scalar + if T <: AbstractScalar Some(T) end end @@ -225,7 +229,9 @@ abstract type SortError <: Exception end # struct WedgeDimError <: SortError end -Base.nameof(s::Scalar) = :Constant +Base.nameof(s::ConstScalar) = :ConstScalar +Base.nameof(s::Parameter) = :Parameter +Base.nameof(s::Scalar) = :Scalar function Base.nameof(f::Form; with_dim_parameter=false) dual = isdual(f) ? "Dual" : "" @@ -269,7 +275,9 @@ end function SymbolicUtils.symtype(::Type{<:Quantity}, qty::Symbol, space::Symbol) @match qty begin - :Scalar || :Constant => Scalar + :Scalar => Scalar + :ConstScalar => ConstScalar + :Parameter => Parameter :Form0 => PrimalForm{0, space, 1} :Form1 => PrimalForm{1, space, 1} :Form2 => PrimalForm{2, space, 1} diff --git a/test/decasymbolic.jl b/test/decasymbolic.jl index bd82387..3a0cb92 100644 --- a/test/decasymbolic.jl +++ b/test/decasymbolic.jl @@ -7,6 +7,7 @@ using SymbolicUtils: symtype, promote_symtype, Symbolic using MLStyle # load up some variable variables and expressions +c, t = @syms c::ConstScalar t::Parameter a, b = @syms a::Scalar b::Scalar u, v = @syms u::PrimalForm{0, :X, 2} du::PrimalForm{1, :X, 2} ω, η = @syms ω::PrimalForm{1, :X, 2} η::DualForm{2, :X, 2} @@ -14,7 +15,9 @@ u, v = @syms u::PrimalForm{0, :X, 2} du::PrimalForm{1, :X, 2} # TODO would be nice to pass the space globally to avoid duplication @testset "Term Construction" begin - + + @test symtype(c) == ConstScalar + @test symtype(t) == Parameter @test symtype(a) == Scalar @test symtype(u) == PrimalForm{0, :X, 2} @test symtype(ω) == PrimalForm{1, :X, 2} @@ -22,6 +25,10 @@ u, v = @syms u::PrimalForm{0, :X, 2} du::PrimalForm{1, :X, 2} @test symtype(ϕ) == PrimalVF{:X, 2} @test symtype(ψ) == DualVF{:X, 2} + @test symtype(c + t) == Scalar + @test symtype(t + t) == Scalar + @test symtype(c + c) == Scalar + @test symtype(u ∧ ω) == PrimalForm{1, :X, 2} @test symtype(ω ∧ ω) == PrimalForm{2, :X, 2} # @test_throws ThDEC.SortError ThDEC.♯(u) @@ -30,6 +37,8 @@ u, v = @syms u::PrimalForm{0, :X, 2} du::PrimalForm{1, :X, 2} # test unary operator conversion to decaexpr @test Term(1) == Lit(Symbol("1")) @test Term(a) == Var(:a) + @test Term(c) == Var(:c) + @test Term(t) == Var(:t) @test Term(∂ₜ(u)) == Tan(Var(:u)) @test Term(★(ω)) == App1(:★₁, Var(:ω))