Skip to content

Commit

Permalink
Add DropExtrema transform (#201)
Browse files Browse the repository at this point in the history
* Add 'DropExtrema' transform

* Add tests

* Add to docs

* Fix typo

* Update test/transforms.jl
  • Loading branch information
eliascarv authored Sep 21, 2023
1 parent 71cda24 commit 0f56276
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 0 deletions.
6 changes: 6 additions & 0 deletions docs/src/transforms.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ Filter
DropMissing
```

## DropExtrema

```@docs
DropExtrema
```

## Map

```@docs
Expand Down
1 change: 1 addition & 0 deletions src/TableTransforms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ export
Sample,
Filter,
DropMissing,
DropExtrema,
Map,
Replace,
Coalesce,
Expand Down
1 change: 1 addition & 0 deletions src/transforms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ include("transforms/stdnames.jl")
include("transforms/sort.jl")
include("transforms/sample.jl")
include("transforms/filter.jl")
include("transforms/dropextrema.jl")
include("transforms/map.jl")
include("transforms/replace.jl")
include("transforms/coalesce.jl")
Expand Down
60 changes: 60 additions & 0 deletions src/transforms/dropextrema.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# ------------------------------------------------------------------
# Licensed under the MIT License. See LICENSE in the project root.
# ------------------------------------------------------------------

"""
    DropExtrema(col; low=0.25, high=0.75)

Drops the rows where the values in the column `col` are outside the interval
`[quantile(col, low), quantile(col, high)]`.

# Examples

```julia
DropExtrema(1)
DropExtrema(:a, low=0.2, high=0.8)
DropExtrema("a", low=0.3, high=0.7)
```
"""
struct DropExtrema{S<:ColSpec,T} <: StatelessFeatureTransform
  # normalized column specification produced by `colspec(col)`
  colspec::S
  # lower quantile probability
  low::T
  # upper quantile probability (same type as `low`)
  high::T

  # Inner constructor: validates the probabilities and normalizes `col`.
  # Throws `AssertionError` unless `0 <= low <= high <= 1`.
  function DropExtrema(col::Col, low::T, high::T) where {T}
    # Kept as `@assert` so that invalid quantiles raise `AssertionError`,
    # which is the exception type existing callers rely on.
    @assert 0 <= low <= high <= 1 "invalid quantiles"
    cs = colspec(col)
    new{typeof(cs),T}(cs, low, high)
  end
end

# Positional form: bring `low` and `high` to a common type, then
# forward to the constructor that takes both with the same type.
function DropExtrema(col::Col, low, high)
  lo, hi = promote(low, high)
  return DropExtrema(col, lo, hi)
end

# Keyword form: `low` defaults to 0.25 and `high` to 0.75.
function DropExtrema(col::Col; low=0.25, high=0.75)
  return DropExtrema(col, low, high)
end

# Declares `DropExtrema` as a revertible transform.
function isrevertible(::Type{<:DropExtrema})
  return true
end

# Build the row filter that keeps only the values of the selected column lying
# inside `[quantile(x, low), quantile(x, high)]`.
#
# Returns a tuple `(ftrans, fprep)`: the `Filter` transform and the result of
# preprocessing it on `table`, so that `applyfeat` can delegate to the filter.
function preprocess(transform::DropExtrema, table)
  cols = Tables.columns(table)
  names = Tables.columnnames(cols)
  # only the first name selected by the column spec is used
  sname = choose(transform.colspec, names) |> first

  x = Tables.getcolumn(cols, sname)
  # Pass the probabilities through unchanged: converting them to `eltype(x)`
  # throws `InexactError` for integer columns (e.g. `convert(Int, 0.25)`).
  xl, xh = quantile(x, (transform.low, transform.high))

  # keep rows whose value is within the closed interval [xl, xh]
  ftrans = Filter(row -> xl <= row[sname] <= xh)
  fprep = preprocess(ftrans, table)
  ftrans, fprep
end

# Apply the filter built in `preprocess` to `feat`.
#
# `prep` is the `(filter, filter_prep)` tuple returned by `preprocess`.
# Returns the filtered features and a cache `(filter, filter_cache)` that
# `revertfeat` consumes.
function applyfeat(::DropExtrema, feat, prep)
  filt, filtprep = prep
  result = applyfeat(filt, feat, filtprep)
  kept, filtcache = result
  return kept, (filt, filtcache)
end

# Undo the transform by delegating to the wrapped filter.
#
# `fcache` is the `(filter, filter_cache)` tuple produced by `applyfeat`.
function revertfeat(::DropExtrema, newfeat, fcache)
  filt = first(fcache)
  filtcache = last(fcache)
  return revertfeat(filt, newfeat, filtcache)
end
1 change: 1 addition & 0 deletions test/transforms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ transformfiles = [
"sort.jl",
"sample.jl",
"filter.jl",
"dropextrema.jl",
"map.jl",
"replace.jl",
"coalesce.jl",
Expand Down
47 changes: 47 additions & 0 deletions test/transforms/dropextrema.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
@testset "DropExtrema" begin
  @test isrevertible(DropExtrema(:a))

  # fixture: six columns with ten values each
  a = [6.9, 9.0, 7.8, 0.0, 5.1, 4.8, 1.1, 8.0, 5.4, 7.9]
  b = [7.7, 4.2, 6.3, 1.4, 4.4, 0.5, 3.0, 6.1, 1.9, 1.5]
  c = [6.1, 7.7, 5.7, 2.8, 2.8, 6.7, 8.4, 5.0, 8.9, 1.0]
  d = [1.0, 2.8, 6.2, 1.9, 8.1, 6.2, 4.0, 6.9, 4.1, 1.4]
  e = [1.5, 8.9, 4.1, 1.6, 5.9, 1.3, 4.9, 3.5, 2.4, 6.3]
  f = [1.9, 2.1, 9.0, 6.2, 1.3, 8.9, 6.2, 3.8, 5.1, 2.3]
  tbl = Table(; a, b, c, d, e, f)

  # column given by index, default quantiles
  transform = DropExtrema(1)
  newtbl, cache = apply(transform, tbl)
  @test newtbl.a == [6.9, 7.8, 5.1, 5.4]
  @test newtbl.b == [7.7, 6.3, 4.4, 1.9]
  @test newtbl.c == [6.1, 5.7, 2.8, 8.9]
  @test newtbl.d == [1.0, 6.2, 8.1, 4.1]
  @test newtbl.e == [1.5, 4.1, 5.9, 2.4]
  @test newtbl.f == [1.9, 9.0, 1.3, 5.1]
  reverted = revert(transform, newtbl, cache)
  @test tbl == reverted

  # column given by symbol, custom quantiles
  transform = DropExtrema(:c, low=0.3, high=0.7)
  newtbl, cache = apply(transform, tbl)
  @test newtbl.a == [6.9, 7.8, 4.8, 8.0]
  @test newtbl.b == [7.7, 6.3, 0.5, 6.1]
  @test newtbl.c == [6.1, 5.7, 6.7, 5.0]
  @test newtbl.d == [1.0, 6.2, 6.2, 6.9]
  @test newtbl.e == [1.5, 4.1, 1.3, 3.5]
  @test newtbl.f == [1.9, 9.0, 8.9, 3.8]
  reverted = revert(transform, newtbl, cache)
  @test tbl == reverted

  # column given by string, custom quantiles
  transform = DropExtrema("e", low=0.2, high=0.8)
  newtbl, cache = apply(transform, tbl)
  @test newtbl.a == [7.8, 0.0, 5.1, 1.1, 8.0, 5.4]
  @test newtbl.b == [6.3, 1.4, 4.4, 3.0, 6.1, 1.9]
  @test newtbl.c == [5.7, 2.8, 2.8, 8.4, 5.0, 8.9]
  @test newtbl.d == [6.2, 1.9, 8.1, 4.0, 6.9, 4.1]
  @test newtbl.e == [4.1, 1.6, 5.9, 4.9, 3.5, 2.4]
  @test newtbl.f == [9.0, 6.2, 1.3, 6.2, 3.8, 5.1]
  reverted = revert(transform, newtbl, cache)
  @test tbl == reverted

  # throws
  @test_throws AssertionError DropExtrema(:a, low=0, high=1.4)
end

0 comments on commit 0f56276

Please sign in to comment.