Skip to content

Commit

Permalink
Bring ListT
Browse files Browse the repository at this point in the history
  • Loading branch information
gusty committed Jul 11, 2022
2 parents 038859e + cdb81a3 commit 2f5c4b9
Show file tree
Hide file tree
Showing 3 changed files with 185 additions and 44 deletions.
177 changes: 133 additions & 44 deletions src/FSharpPlus/Data/List.fs
Original file line number Diff line number Diff line change
Expand Up @@ -36,68 +36,157 @@ open FSharpPlus.Control

/// Monad Transformer for list<'T>
[<Struct>]
type ListT<'``monad<list<'t>>``> = ListT of '``monad<list<'t>>``
type ListT<'``monad<'t>``> = ListT of obj
type ListTNode<'``monad<'t>``,'t> = Nil | Cons of 't * ListT<'``monad<'t>``>

/// Basic operations on ListT
[<RequireQualifiedAccess>]
module ListT =
let run (ListT m) = m : '``Monad<list<'T>>``

/// Embed a Monad<'T> into a ListT<'Monad<list<'T>>>
let inline lift (x: '``Monad<'T>``) : ListT<'``Monad<list<'T>>``> =
if opaqueId false then x |> liftM List.singleton |> ListT
else x |> map List.singleton |> ListT

let inline internal sequence ms =
let k m m' = m >>= fun (x: 'a) -> m' >>= fun xs -> (result: list<'a> -> 'M) (x::xs)
List.foldBack k ms ((result :list<'a> -> 'M) [])

let inline internal mapM f as' = sequence (List.map f as')

let inline bind (f: 'T-> ListT<'``Monad<list<'U>``>) (ListT m: ListT<'``Monad<list<'T>``>) = (ListT (m >>= mapM (run << f) >>= ((List.concat: list<_>->_) >> result)))
let inline apply (ListT f: ListT<'``Monad<list<('T -> 'U)>``>) (ListT x: ListT<'``Monad<list<'T>``>) = ListT (map List.apply f <*> x) : ListT<'``Monad<list<'U>``>
let inline lift2 (f: 'T->'U->'V) (ListT x: ListT<'``Monad<list<'T>``>) (ListT y: ListT<'``Monad<list<'U>``>) = ListT (lift2 (List.lift2 f) x y) : ListT<'``Monad<list<'V>``>
let inline lift3 (f: 'T->'U->'V->'W) (ListT x: ListT<'``Monad<list<'T>``>) (ListT y: ListT<'``Monad<list<'U>``>) (ListT z: ListT<'``Monad<list<'V>``>) = ListT (lift3 (List.lift3 f) x y z) : ListT<'``Monad<list<'W>``>
let inline map (f: 'T->'U) (ListT m: ListT<'``Monad<list<'T>``>) = ListT (map (List.map f) m) : ListT<'``Monad<list<'U>``>

type ListT<'``monad<list<'t>>``> with

static member inline Return (x: 'T) = [x] |> result |> ListT : ListT<'``Monad<list<'T>``>
let inline internal wrap (mit: 'mit) =
let _mnil = (result Unchecked.defaultof<'t> : 'mt) >>= fun (_:'t) -> (result ListTNode<'mt,'t>.Nil ) : 'mit
ListT mit : ListT<'mt>

let inline internal unwrap (ListT mit : ListT<'mt>) =
let _mnil = (result Unchecked.defaultof<'t> : 'mt) >>= fun (_:'t) -> (result ListTNode<'mt,'t>.Nil ) : 'mit
unbox mit : 'mit

let inline empty () = wrap ((result ListTNode<'mt,'t>.Nil) : 'mit) : ListT<'mt>

/// Concatenates the elements of two lists
let inline concat l1 l2 =
let rec loop (l1: ListT<'mt>) (lst2: ListT<'mt>) =
let (l1, l2) = unwrap l1, unwrap lst2
ListT (l1 >>= function Nil -> l2 | Cons (x: 't, xs) -> ((result (Cons (x, loop xs lst2))) : 'mit))
loop l1 l2 : ListT<'mt>

let inline bind f (source: ListT<'mt>) : ListT<'mu> =
let _mnil = (result Unchecked.defaultof<'t> : 'mt) >>= fun (_: 't) -> (result Unchecked.defaultof<'u>) : 'mu
let rec loop f input =
ListT (
(unwrap input : 'mit) >>= function
| Nil -> result <| (Nil : ListTNode<'mu,'u>) : 'miu
| Cons (h:'t, t: ListT<'mt>) ->
let res = concat (f h: ListT<'mu>) (loop f t)
unwrap res : 'miu)
loop f source : ListT<'mu>

let inline unfold (f:'State -> '``M<('T * 'State) option>``) (s:'State) : ListT<'MT> =
let rec loop f s = f s |> map (function
| Some (a, s) -> Cons(a, loop f s)
| None -> Nil) |> wrap
loop f s

let inline map f (input : ListT<'mt>) : ListT<'mu> =
let rec collect f (input : ListT<'mt>) : ListT<'mu> =
wrap (
(unwrap input : 'mit) >>= function
| Nil -> result <| (Nil : ListTNode<'mu,'u>) : 'miu
| Cons (h: 't, t: ListT<'mt>) ->
let ( res) = Cons (f h, collect f t)
result res : 'miu)
collect f (input: ListT<'mt>) : ListT<'mu>

let inline singleton (v: 't) =
let mresult x = result x
let _mnil = (result Unchecked.defaultof<'t> : 'mt) >>= konst (mresult ListTNode<'mt,'t>.Nil ) : 'mit
wrap ((mresult <| ListTNode<'mt,'t>.Cons (v, (wrap (mresult ListTNode<'mt,'t>.Nil): ListT<'mt> ))) : 'mit) : ListT<'mt>

let inline apply f x = bind (fun (x1: _) -> bind (fun x2 -> singleton (x1 x2)) x) f

let inline append (head: 't) tail = wrap ((result <| ListTNode<'mt,'t>.Cons (head, (tail: ListT<'mt> ))) : 'mit) : ListT<'mt>

let inline head (x : ListT<'mt>) =
unwrap x >>= function
| Nil -> failwith "empty list"
| Cons (head, _) -> result head : 'mt

let inline tail (x: ListT<'mt>) : ListT<'mt> =
(unwrap x >>= function
| Nil -> failwith "empty list"
| Cons (_: 't, tail) -> unwrap tail) |> wrap

let inline iterM (action: 'T -> '``M<unit>``) (lst: ListT<'MT>) : '``M<unit>`` =
let rec loop lst action =
unwrap lst >>= function
| Nil -> result ()
| Cons (h, t) -> action h >>= (fun () -> loop t action)
loop lst action

let inline iter (action: 'T -> unit) (lst: ListT<'MT>) : '``M<unit>`` = iterM (action >> result) lst

let inline lift (x: '``Monad<'T>``) = wrap (x >>= (result << (fun x -> Cons (x, empty () )))) : ListT<'``Monad<'T>``>

let inline take count (input : ListT<'MT>) : ListT<'MT> =
let rec loop count (input : ListT<'MT>) : ListT<'MT> = wrap <| monad {
if count > 0 then
let! v = unwrap input
match v with
| Cons (h, t) -> return Cons (h, loop (count - 1) t)
| Nil -> return Nil
else return Nil }
loop count (input: ListT<'MT>)

let inline filterM (f: 'T -> '``M<bool>``) (input: ListT<'MT>) : ListT<'MT> =
input |> bind (fun v -> lift (f v) |> bind (fun b -> if b then singleton v else empty ()))

let inline filter f (input: ListT<'MT>) : ListT<'MT> = filterM (f >> result) input

let inline run (lst: ListT<'MT>) : '``Monad<list<'T>>`` =
let rec loop acc x = unwrap x >>= function
| Nil -> result (List.rev acc)
| Cons (x, xs) -> loop (x::acc) xs
loop [] lst



[<AutoOpen>]
module ListTPrimitives =
let inline listT (al: '``Monad<list<'T>>``) : ListT<'``Monad<'T>``> =
ListT.unfold (fun i -> map (fun (lst:list<_>) -> if lst.Length > i then Some (lst.[i], i+1) else None) al) 0

// let inline lift2 (f: 'T->'U->'V) (ListT x: ListT<'``Monad<list<'T>``>) (ListT y: ListT<'``Monad<list<'U>``>) = ListT (lift2 (List.lift2 f) x y) : ListT<'``Monad<list<'V>``>
// let inline lift3 (f: 'T->'U->'V->'W) (ListT x: ListT<'``Monad<list<'T>``>) (ListT y: ListT<'``Monad<list<'U>``>) (ListT z: ListT<'``Monad<list<'V>``>) = ListT (lift3 (List.lift3 f) x y z) : ListT<'``Monad<list<'W>``>


type ListT<'``monad<'t>``> with
static member inline Return (x: 'T) = ListT.singleton x : ListT<'M>

[<EditorBrowsable(EditorBrowsableState.Never)>]
static member inline Map (x: ListT<'``Monad<list<'T>``>, f: 'T->'U) = ListT.map f x : ListT<'``Monad<list<'U>``>
static member inline Map (x, f) = ListT.map f x

[<EditorBrowsable(EditorBrowsableState.Never)>]
static member inline Lift2 (f: 'T->'U->'V, x: ListT<'``Monad<list<'T>``>, y: ListT<'``Monad<list<'U>``>) = ListT.lift2 f x y : ListT<'``Monad<list<'V>``>
// [<EditorBrowsable(EditorBrowsableState.Never)>]
// static member inline Lift2 (f: 'T->'U->'V, x: ListT<'``Monad<list<'T>``>, y: ListT<'``Monad<list<'U>``>) = ListT.lift2 f x y : ListT<'``Monad<list<'V>``>

[<EditorBrowsable(EditorBrowsableState.Never)>]
static member inline Lift3 (f: 'T->'U->'V->'W, x: ListT<'``Monad<list<'T>``>, y: ListT<'``Monad<list<'U>``>, z: ListT<'``Monad<list<'V>``>) = ListT.lift3 f x y z : ListT<'``Monad<list<'W>``>
// [<EditorBrowsable(EditorBrowsableState.Never)>]
// static member inline Lift3 (f: 'T->'U->'V->'W, x: ListT<'``Monad<list<'T>``>, y: ListT<'``Monad<list<'U>``>, z: ListT<'``Monad<list<'V>``>) = ListT.lift3 f x y z : ListT<'``Monad<list<'W>``>

static member inline (<*>) (f: ListT<'``Monad<list<('T -> 'U)>``>, x: ListT<'``Monad<list<'T>``>) = ListT.apply f x : ListT<'``Monad<list<'U>``>
static member inline (>>=) (x: ListT<'``Monad<list<'T>``>, f: 'T -> ListT<'``Monad<list<'U>``>) = ListT.bind f x
static member inline (<*>) (f, x) = ListT.apply f x

static member inline get_Empty () = ListT <| result [] : ListT<'``MonadPlus<list<'T>``>
static member inline (<|>) (ListT x, ListT y) = ListT (x >>= (fun a -> y >>= (fun b -> result (a @ b)))) : ListT<'``MonadPlus<list<'T>``>
static member inline (>>=) (x, f) = ListT.bind f x
static member inline get_Empty () = ListT.empty ()
static member inline (<|>) (x, y) = ListT.concat x y

static member inline TryWith (source: ListT<'``Monad<list<'T>>``>, f: exn -> ListT<'``Monad<list<'T>>``>) = ListT (TryWith.Invoke (ListT.run source) (ListT.run << f))
static member inline TryFinally (computation: ListT<'``Monad<list<'T>>``>, f) = ListT (TryFinally.Invoke (ListT.run computation) f)
static member inline Using (resource, f: _ -> ListT<'``Monad<list<'T>>``>) = ListT (Using.Invoke resource (ListT.run << f))
static member inline Delay (body : unit -> ListT<'``Monad<list<'T>>``>) = ListT (Delay.Invoke (fun _ -> ListT.run (body ()))) : ListT<'``Monad<list<'T>>``>
static member inline TryWith (source: ListT<'``Monad<'T>``>, f: exn -> ListT<'``Monad<'T>``>) = ListT (TryWith.Invoke (ListT.unwrap source) (ListT.unwrap << f))
static member inline TryFinally (computation: ListT<'``Monad<'T>``>, f) = ListT (TryFinally.Invoke (ListT.unwrap computation) f)
static member inline Using (resource, f: _ -> ListT<'``Monad<'T>``>) = ListT (Using.Invoke resource (ListT.unwrap << f))
static member inline Delay (body : unit -> ListT<'``Monad<'T>``>) = ListT (Delay.Invoke (fun _ -> ListT.unwrap (body ()))) : ListT<'``Monad<'T>``>

[<EditorBrowsable(EditorBrowsableState.Never)>]
static member inline Lift (x: '``Monad<'T>``) : ListT<'``Monad<list<'T>>``> = ListT.lift x
static member inline Lift (x: '``Monad<'T>``) = ListT.lift x : ListT<'``Monad<'T>``>

static member inline LiftAsync (x: Async<'T>) = ListT.lift (liftAsync x) : ListT<'``MonadAsync<'T>``>
static member inline LiftAsync (x: Async<'T>) = lift (liftAsync x) : '``ListT<'MonadAsync<'T>>``

static member inline Throw (x: 'E) = x |> throw |> ListT.lift
static member inline Throw (x: 'E) = x |> throw |> lift
static member inline Catch (m: ListT<'``MonadError<'E1,'T>``>, h: 'E1 -> ListT<'``MonadError<'E2,'T>``>) = ListT ((fun v h -> Catch.Invoke v h) (ListT.run m) (ListT.run << h)) : ListT<'``MonadError<'E2,'T>``>

static member inline CallCC (f: (('T -> ListT<'``MonadCont<'R,list<'U>>``>) -> _)) = ListT (callCC <| fun c -> ListT.run (f (ListT << c << List.singleton))) : ListT<'``MonadCont<'R, list<'T>>``>

static member inline get_Get () = ListT.lift get : ListT<'``MonadState<'S,'S>``>
static member inline Put (x: 'S) = x |> put |> ListT.lift : ListT<'``MonadState<unit,'S>``>
static member inline get_Get () = lift get : '``ListT<'MonadState<'S,'S>>``
static member inline Put (x: 'T) = x |> put |> lift : '``ListT<'MonadState<unit,'S>>``

static member inline get_Ask () = ListT.lift ask : ListT<'``MonadReader<'R, list<'R>>``>
static member inline Local (ListT (m: '``MonadReader<'R2,'T>``), f: 'R1->'R2) = ListT (local f m)
static member inline get_Ask () = lift ask : '``ListT<'MonadReader<'R, list<'R>>>``
static member inline Local (m: ListT<'``MonadReader<'R2,'T>``>, f: 'R1->'R2) = listT (local f (ListT.run m))

static member inline Take (lst, c, _: Take) = ListT.take c lst

#endif
#endif
1 change: 1 addition & 0 deletions tests/FSharpPlus.Tests/FSharpPlus.Tests.fsproj
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
<Compile Include="Validations.fs" />
<Compile Include="Task.fs" />
<Compile Include="Free.fs" />
<Compile Include="ListT.fs" />
<Compile Include="ComputationExpressions.fs" />
<Compile Include="Lens.fs" />
<Compile Include="Extensions.fs" />
Expand Down
51 changes: 51 additions & 0 deletions tests/FSharpPlus.Tests/ListT.fs
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
module FSharpPlus.Tests.ListT

open System
open FSharpPlus
open FSharpPlus.Data
open NUnit.Framework
open FsCheck
open Helpers
open System.Collections.Generic
open System.Threading.Tasks

module BasicTests =
[<Test>]
let wrap_unwrap () =
let c = listT (async.Return (['a'..'g']))
let res = c |> ListT.run |> listT |> ListT.run |> extract
let exp = c |> ListT.run |> extract
CollectionAssert.AreEqual (res, exp)

[<Test>]
let infiniteLists () =
let (infinite: ListT<Lazy<_>>) = ListT.unfold (fun x -> monad { return (Some (x, x + 1) ) }) 0
let finite = take 12 infinite
let res = finite <|> infinite
CollectionAssert.AreEqual (res |> take 13 |> ListT.run |> extract, [0;1;2;3;4;5;6;7;8;9;10;11;0])

// Compile tests
let binds () =
let res1 = listT [| [1..4] |] >>= fun x -> listT [| [x * 2] |]
let res2 = listT (Task.FromResult [1..4]) >>= (fun x -> listT (Task.FromResult [x * 2]))
let res3 = listT (ResizeArray [ [1..4] ]) >>= (fun x -> listT (ResizeArray [ [x * 2] ]))
let res4 = listT (lazy [1..4]) >>= (fun x -> listT (lazy ( [x * 2])))
let (res5: ListT<_ seq>) = listT (seq [ [1..4] ]) >>= (fun x -> listT (seq [ [x * 2] ]))
() // Note: seq needs type annotation.

let bind_for_ideantity () =
let res = listT (Identity [1..4]) >>= fun x -> listT (Identity [x * 2])
()

let computation_expressions () =
let oneTwoThree : ListT<_> = monad.plus {
do! lift <| Async.Sleep 10
yield 1
do! lift <| Async.Sleep 50
yield 2
yield 3}
()

let applicative_with_options () =
let x = (+) <!> listT None <*> listT (Some [1;2;3;4])
() // It doesn't work with asyncs

0 comments on commit 2f5c4b9

Please sign in to comment.