Skip to content

Commit

Permalink
AK: Make Function a little bit constexpr
Browse files Browse the repository at this point in the history
Just enough to make it possible to pass a lambda to a constexpr function
taking an AK::Function parameter and have that constexpr function call
the passed-in AK::Function.

The main thing is that bit_cast<>s of pointers aren't ok in constexpr
functions, and neither is placement new. So add a union with stronger
typing for the constexpr case.

The body of the `if (is_constexpr_evaluated())` in
`init_with_callable()` is identical to the `if constexpr` right
after it. But `if constexpr (is_constexpr_evaluated())` always
evaluates `is_constexpr_evaluated()` in a constexpr context, so
this can't just add ` || is_constexpr_evaluated()` to that
`if constexpr`.
  • Loading branch information
nico committed Dec 20, 2024
1 parent d8eb3cf commit d5e7da4
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 29 deletions.
65 changes: 38 additions & 27 deletions AK/Function.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class Function<Out(In...)> {
{
}

~Function()
constexpr ~Function()
{
clear(false);
}
Expand All @@ -94,14 +94,14 @@ class Function<Out(In...)> {
}

template<typename CallableType>
Function(CallableType&& callable)
constexpr Function(CallableType&& callable)
requires((IsFunctionObject<CallableType> && IsCallableWithArguments<CallableType, Out, In...> && !IsSame<RemoveCVReference<CallableType>, Function>))
{
init_with_callable(forward<CallableType>(callable), CallableKind::FunctionObject);
}

template<typename FunctionType>
Function(FunctionType f)
constexpr Function(FunctionType f)
requires((IsFunctionPointer<FunctionType> && IsCallableWithArguments<RemovePointer<FunctionType>, Out, In...> && !IsSame<RemoveCVReference<FunctionType>, Function>))
{
init_with_callable(move(f), CallableKind::FunctionPointer);
Expand All @@ -113,10 +113,11 @@ class Function<Out(In...)> {
}

// Note: Despite this method being const, a mutable lambda _may_ modify its own captures.
Out operator()(In... in) const
constexpr Out operator()(In... in) const
{
auto* wrapper = callable_wrapper();
VERIFY(wrapper);
if (!is_constant_evaluated())
VERIFY(wrapper);
++m_call_nesting_level;
ScopeGuard guard([this] {
if (--m_call_nesting_level == 0 && m_deferred_clear)
Expand All @@ -128,7 +129,7 @@ class Function<Out(In...)> {
explicit operator bool() const { return !!callable_wrapper(); }

template<typename CallableType>
Function& operator=(CallableType&& callable)
constexpr Function& operator=(CallableType&& callable)
requires((IsFunctionObject<CallableType> && IsCallableWithArguments<CallableType, Out, In...>))
{
clear();
Expand All @@ -137,7 +138,7 @@ class Function<Out(In...)> {
}

template<typename FunctionType>
Function& operator=(FunctionType f)
constexpr Function& operator=(FunctionType f)
requires((IsFunctionPointer<FunctionType> && IsCallableWithArguments<RemovePointer<FunctionType>, Out, In...>))
{
clear();
Expand Down Expand Up @@ -171,8 +172,8 @@ class Function<Out(In...)> {
public:
virtual ~CallableWrapperBase() = default;
// Note: This is not const to allow storing mutable lambdas.
virtual Out call(In...) = 0;
virtual void destroy() = 0;
virtual constexpr Out call(In...) = 0;
virtual constexpr void destroy() = 0;
virtual void init_and_swap(u8*, size_t) = 0;
};

Expand All @@ -182,17 +183,17 @@ class Function<Out(In...)> {
AK_MAKE_NONCOPYABLE(CallableWrapper);

public:
explicit CallableWrapper(CallableType&& callable)
explicit constexpr CallableWrapper(CallableType&& callable)
: m_callable(move(callable))
{
}

Out call(In... in) final override
Out constexpr call(In... in) final override
{
return m_callable(forward<In>(in)...);
}

void destroy() final override
void constexpr destroy() final override
{
delete this;
}
Expand All @@ -214,21 +215,21 @@ class Function<Out(In...)> {
Outline,
};

CallableWrapperBase* callable_wrapper() const
constexpr CallableWrapperBase* callable_wrapper() const
{
switch (m_kind) {
case FunctionKind::NullPointer:
return nullptr;
case FunctionKind::Inline:
return bit_cast<CallableWrapperBase*>(&m_storage);
case FunctionKind::Outline:
return *bit_cast<CallableWrapperBase**>(&m_storage);
return m_storage.wrapper;
default:
VERIFY_NOT_REACHED();
}
}

void clear(bool may_defer = true)
constexpr void clear(bool may_defer = true)
{
bool called_from_inside_function = m_call_nesting_level > 0;
// NOTE: This VERIFY could fail because a Function is destroyed from within itself.
Expand All @@ -250,7 +251,7 @@ class Function<Out(In...)> {
}

template<typename Callable>
void init_with_callable(Callable&& callable, CallableKind callable_kind)
constexpr void init_with_callable(Callable&& callable, CallableKind callable_kind)
{
if constexpr (alignof(Callable) > ExcessiveAlignmentThreshold && !AccommodateExcessiveAlignmentRequirements) {
static_assert(
Expand All @@ -259,20 +260,27 @@ class Function<Out(In...)> {
"check your capture list if it is a lambda expression, "
"and make sure your callable object is not excessively aligned.");
}
VERIFY(m_call_nesting_level == 0);
if (!is_constant_evaluated())
VERIFY(m_call_nesting_level == 0);
using WrapperType = CallableWrapper<Callable>;
#ifndef KERNEL
if constexpr (alignof(Callable) > inline_alignment || sizeof(WrapperType) > inline_capacity) {
*bit_cast<CallableWrapperBase**>(&m_storage) = new WrapperType(forward<Callable>(callable));
if (is_constant_evaluated()) {
m_storage.wrapper = new WrapperType(forward<Callable>(callable));
m_kind = FunctionKind::Outline;
} else {
#ifndef KERNEL
if constexpr (alignof(Callable) > inline_alignment || sizeof(WrapperType) > inline_capacity) {
m_storage.wrapper = new WrapperType(forward<Callable>(callable));
m_kind = FunctionKind::Outline;
} else {
#endif
static_assert(sizeof(WrapperType) <= inline_capacity);
new (m_storage) WrapperType(forward<Callable>(callable));
m_kind = FunctionKind::Inline;
static_assert(sizeof(WrapperType) <= inline_capacity);
new (m_storage.storage) WrapperType(forward<Callable>(callable));
m_kind = FunctionKind::Inline;
#ifndef KERNEL
}
}
#endif
}

if (callable_kind == CallableKind::FunctionObject)
m_size = sizeof(WrapperType);
else
Expand All @@ -288,11 +296,11 @@ class Function<Out(In...)> {
case FunctionKind::NullPointer:
break;
case FunctionKind::Inline:
other_wrapper->init_and_swap(m_storage, inline_capacity);
other_wrapper->init_and_swap(m_storage.storage, inline_capacity);
m_kind = FunctionKind::Inline;
break;
case FunctionKind::Outline:
*bit_cast<CallableWrapperBase**>(&m_storage) = other_wrapper;
m_storage.wrapper = other_wrapper;
m_kind = FunctionKind::Outline;
break;
default:
Expand All @@ -315,7 +323,10 @@ class Function<Out(In...)> {
static constexpr size_t inline_capacity = 6 * sizeof(void*);
#endif

alignas(inline_alignment) u8 m_storage[inline_capacity];
alignas(inline_alignment) union {
u8 storage[inline_capacity];
CallableWrapperBase* wrapper;
} m_storage;
};

}
Expand Down
4 changes: 2 additions & 2 deletions AK/ScopeGuard.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ namespace AK {
template<typename Callback>
class ScopeGuard {
public:
ScopeGuard(Callback callback)
constexpr ScopeGuard(Callback callback)
: m_callback(move(callback))
{
}

~ScopeGuard()
constexpr ~ScopeGuard()
{
m_callback();
}
Expand Down
1 change: 1 addition & 0 deletions Tests/AK/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ set(AK_TEST_SOURCES
TestFloatingPointParsing.cpp
TestFlyString.cpp
TestFormat.cpp
TestFunction.cpp
TestFuzzyMatch.cpp
TestGeneratorAK.cpp
TestGenericLexer.cpp
Expand Down
14 changes: 14 additions & 0 deletions Tests/AK/TestFunction.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
/*
* Copyright (c) 2024, Nico Weber <[email protected]>
*
* SPDX-License-Identifier: BSD-2-Clause
*/

#include <AK/Function.h>

constexpr int const_call(Function<int(int)> f, int i)
{
return f(i);
}

constinit int i = const_call([](int i) { return i; }, 4);

0 comments on commit d5e7da4

Please sign in to comment.