diff --git a/docs/src/transforms.md b/docs/src/transforms.md
index aec2dab4..890e997d 100644
--- a/docs/src/transforms.md
+++ b/docs/src/transforms.md
@@ -50,6 +50,12 @@ Filter
 DropMissing
 ```
 
+## DropExtrema
+
+```@docs
+DropExtrema
+```
+
 ## Map
 
 ```@docs
diff --git a/src/TableTransforms.jl b/src/TableTransforms.jl
index 82810dbd..0c553a82 100644
--- a/src/TableTransforms.jl
+++ b/src/TableTransforms.jl
@@ -53,6 +53,7 @@ export
   Sample,
   Filter,
   DropMissing,
+  DropExtrema,
   Map,
   Replace,
   Coalesce,
diff --git a/src/transforms.jl b/src/transforms.jl
index 7e546a3d..a332dd5e 100644
--- a/src/transforms.jl
+++ b/src/transforms.jl
@@ -270,6 +270,7 @@ include("transforms/stdnames.jl")
 include("transforms/sort.jl")
 include("transforms/sample.jl")
 include("transforms/filter.jl")
+include("transforms/dropextrema.jl")
 include("transforms/map.jl")
 include("transforms/replace.jl")
 include("transforms/coalesce.jl")
diff --git a/src/transforms/dropextrema.jl b/src/transforms/dropextrema.jl
new file mode 100644
index 00000000..4b95f2c7
--- /dev/null
+++ b/src/transforms/dropextrema.jl
@@ -0,0 +1,68 @@
+# ------------------------------------------------------------------
+# Licensed under the MIT License. See LICENSE in the project root.
+# ------------------------------------------------------------------
+
+"""
+    DropExtrema(col; low=0.25, high=0.75)
+
+Drops the rows where the values in the column `col` are outside the interval
+`[quantile(col, low), quantile(col, high)]`.
+
+The probabilities must satisfy `0 ≤ low ≤ high ≤ 1`, otherwise an
+`AssertionError` is thrown.
+
+# Examples
+
+```julia
+DropExtrema(1)
+DropExtrema(:a, low=0.2, high=0.8)
+DropExtrema("a", low=0.3, high=0.7)
+```
+"""
+struct DropExtrema{S<:ColSpec,T} <: StatelessFeatureTransform
+  colspec::S
+  low::T
+  high::T
+
+  function DropExtrema(col::Col, low::T, high::T) where {T}
+    @assert 0 ≤ low ≤ high ≤ 1 "invalid quantiles"
+    cs = colspec(col)
+    new{typeof(cs),T}(cs, low, high)
+  end
+end
+
+# promote mixed probability types (e.g. `low=0`, `high=0.5`) to a common type
+DropExtrema(col::Col, low, high) = DropExtrema(col, promote(low, high)...)
+DropExtrema(col::Col; low=0.25, high=0.75) = DropExtrema(col, low, high)
+
+isrevertible(::Type{<:DropExtrema}) = true
+
+# compute the quantile limits of the selected column and delegate
+# the row selection to a `Filter` transform built from those limits
+function preprocess(transform::DropExtrema, table)
+  cols = Tables.columns(table)
+  names = Tables.columnnames(cols)
+  sname = choose(transform.colspec, names) |> first
+
+  # the probabilities are passed to `quantile` as they are: converting
+  # them to `eltype(x)` throws `InexactError` for integer columns
+  x = Tables.getcolumn(cols, sname)
+  xl, xh = quantile(x, (transform.low, transform.high))
+
+  ftrans = Filter(row -> xl ≤ row[sname] ≤ xh)
+  fprep = preprocess(ftrans, table)
+  ftrans, fprep
+end
+
+# apply the inner `Filter` and return it with its cache for `revertfeat`
+function applyfeat(::DropExtrema, feat, prep)
+  ftrans, fprep = prep
+  newfeat, ffcache = applyfeat(ftrans, feat, fprep)
+  newfeat, (ftrans, ffcache)
+end
+
+# revert by delegating to the inner `Filter` with its own cache
+function revertfeat(::DropExtrema, newfeat, fcache)
+  ftrans, ffcache = fcache
+  revertfeat(ftrans, newfeat, ffcache)
+end
diff --git a/test/transforms.jl b/test/transforms.jl
index b27c70fa..b80004d7 100644
--- a/test/transforms.jl
+++ b/test/transforms.jl
@@ -5,6 +5,7 @@ transformfiles = [
   "sort.jl",
   "sample.jl",
   "filter.jl",
+  "dropextrema.jl",
   "map.jl",
   "replace.jl",
   "coalesce.jl",
diff --git a/test/transforms/dropextrema.jl b/test/transforms/dropextrema.jl
new file mode 100644
index 00000000..1f78f739
--- /dev/null
+++ b/test/transforms/dropextrema.jl
@@ -0,0 +1,55 @@
+@testset "DropExtrema" begin
+  @test isrevertible(DropExtrema(:a))
+
+  a = [6.9, 9.0, 7.8, 0.0, 5.1, 4.8, 1.1, 8.0, 5.4, 7.9]
+  b = [7.7, 4.2, 6.3, 1.4, 4.4, 0.5, 3.0, 6.1, 1.9, 1.5]
+  c = [6.1, 7.7, 5.7, 2.8, 2.8, 6.7, 8.4, 5.0, 8.9, 1.0]
+  d = [1.0, 2.8, 6.2, 1.9, 8.1, 6.2, 4.0, 6.9, 4.1, 1.4]
+  e = [1.5, 8.9, 4.1, 1.6, 5.9, 1.3, 4.9, 3.5, 2.4, 6.3]
+  f = [1.9, 2.1, 9.0, 6.2, 1.3, 8.9, 6.2, 3.8, 5.1, 2.3]
+  t = Table(; a, b, c, d, e, f)
+
+  T = DropExtrema(1)
+  n, c = apply(T, t)
+  @test n.a == [6.9, 7.8, 5.1, 5.4]
+  @test n.b == [7.7, 6.3, 4.4, 1.9]
+  @test n.c == [6.1, 5.7, 2.8, 8.9]
+  @test n.d == [1.0, 6.2, 8.1, 4.1]
+  @test n.e == [1.5, 4.1, 5.9, 2.4]
+  @test n.f == [1.9, 9.0, 1.3, 5.1]
+  tₒ = revert(T, n, c)
+  @test t == tₒ
+
+  T = DropExtrema(:c, low=0.3, high=0.7)
+  n, c = apply(T, t)
+  @test n.a == [6.9, 7.8, 4.8, 8.0]
+  @test n.b == [7.7, 6.3, 0.5, 6.1]
+  @test n.c == [6.1, 5.7, 6.7, 5.0]
+  @test n.d == [1.0, 6.2, 6.2, 6.9]
+  @test n.e == [1.5, 4.1, 1.3, 3.5]
+  @test n.f == [1.9, 9.0, 8.9, 3.8]
+  tₒ = revert(T, n, c)
+  @test t == tₒ
+
+  T = DropExtrema("e", low=0.2, high=0.8)
+  n, c = apply(T, t)
+  @test n.a == [7.8, 0.0, 5.1, 1.1, 8.0, 5.4]
+  @test n.b == [6.3, 1.4, 4.4, 3.0, 6.1, 1.9]
+  @test n.c == [5.7, 2.8, 2.8, 8.4, 5.0, 8.9]
+  @test n.d == [6.2, 1.9, 8.1, 4.0, 6.9, 4.1]
+  @test n.e == [4.1, 1.6, 5.9, 4.9, 3.5, 2.4]
+  @test n.f == [9.0, 6.2, 1.3, 6.2, 3.8, 5.1]
+  tₒ = revert(T, n, c)
+  @test t == tₒ
+
+  # integer column (quantile limits are 3.25 and 7.75)
+  t = Table(a=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
+  T = DropExtrema(:a)
+  n, c = apply(T, t)
+  @test n.a == [4, 5, 6, 7]
+  tₒ = revert(T, n, c)
+  @test t == tₒ
+
+  # throws
+  @test_throws AssertionError DropExtrema(:a, low=0, high=1.4)
+end