diff --git a/fn/result.go b/fn/result.go index c8de5aa6b..5be5362ba 100644 --- a/fn/result.go +++ b/fn/result.go @@ -64,20 +64,37 @@ func (r Result[T]) IsErr() bool { return r.IsRight() } -// Map applies a function to the success value if it exists. +// Map applies an endomorphic function to the success value if it exists. +// +// Deprecated: Use MapOk instead. func (r Result[T]) Map(f func(T) T) Result[T] { return Result[T]{ MapLeft[T, error](f)(r.Either), } } -// MapErr applies a function to the error value if it exists. +// MapOk applies an endomorphic function to the success value if it exists. +func (r Result[T]) MapOk(f func(T) T) Result[T] { + return Result[T]{ + MapLeft[T, error](f)(r.Either), + } +} + +// MapErr applies an endomorphic function to the error value if it exists. func (r Result[T]) MapErr(f func(error) error) Result[T] { return Result[T]{ MapRight[T](f)(r.Either), } } +// MapOk applies a non-endomorphic function to the success value if it exists +// and returns a Result of the new type. +func MapOk[A, B any](f func(A) B) func(Result[A]) Result[B] { + return func(r Result[A]) Result[B] { + return Result[B]{MapLeft[A, error](f)(r.Either)} + } +} + // Option returns the success value as an Option. // // Deprecated: Use OkToSome instead. @@ -137,8 +154,22 @@ func (r Result[T]) UnwrapOrFail(t *testing.T) T { return r.left } -// FlatMap applies a function that returns a Result to the success value if it -// exists. +// FlattenResult takes a nested Result and joins the two functor layers into +// one. +func FlattenResult[A any](r Result[Result[A]]) Result[A] { + if r.IsErr() { + return Err[A](r.right) + } + + if r.left.IsErr() { + return Err[A](r.left.right) + } + + return r.left +} + +// FlatMap applies a kleisli endomorphic function that returns a Result to the +// success value if it exists. func (r Result[T]) FlatMap(f func(T) Result[T]) Result[T] { if r.IsOk() { return r diff --git a/fn/result_test.go b/fn/result_test.go index fd21e4443..bfbe8d9bd 100644 --- a/fn/result_test.go +++ b/fn/result_test.go @@ -2,7 +2,9 @@ package fn import ( "errors" + "fmt" "testing" + "testing/quick" "github.com/stretchr/testify/require" ) @@ -17,3 +19,29 @@ func TestOkToSome(t *testing.T) { t, Err[uint8](errors.New("err")).OkToSome(), None[uint8](), ) } + +func TestMapOk(t *testing.T) { + inc := func(i int) int { + return i + 1 + } + f := func(i int) bool { + ok := Ok(i) + return MapOk(inc)(ok) == Ok(inc(i)) + } + + require.NoError(t, quick.Check(f, nil)) +} + +func TestFlattenResult(t *testing.T) { + f := func(i int) bool { + e := fmt.Errorf("error") + + x := FlattenResult(Ok(Ok(i))) == Ok(i) + y := FlattenResult(Ok(Err[int](e))) == Err[int](e) + z := FlattenResult(Err[Result[int]](e)) == Err[int](e) + + return x && y && z + } + + require.NoError(t, quick.Check(f, nil)) +}