diff --git a/AK/Function.h b/AK/Function.h index d3732cdb2a7a99..5d71d087e25a72 100644 --- a/AK/Function.h +++ b/AK/Function.h @@ -79,7 +79,7 @@ class Function { { } - ~Function() + constexpr ~Function() { clear(false); } @@ -94,14 +94,14 @@ class Function { } template - Function(CallableType&& callable) + constexpr Function(CallableType&& callable) requires((IsFunctionObject && IsCallableWithArguments && !IsSame, Function>)) { init_with_callable(forward(callable), CallableKind::FunctionObject); } template - Function(FunctionType f) + constexpr Function(FunctionType f) requires((IsFunctionPointer && IsCallableWithArguments, Out, In...> && !IsSame, Function>)) { init_with_callable(move(f), CallableKind::FunctionPointer); @@ -113,10 +113,11 @@ class Function { } // Note: Despite this method being const, a mutable lambda _may_ modify its own captures. - Out operator()(In... in) const + constexpr Out operator()(In... in) const { auto* wrapper = callable_wrapper(); - VERIFY(wrapper); + if (!is_constant_evaluated()) + VERIFY(wrapper); ++m_call_nesting_level; ScopeGuard guard([this] { if (--m_call_nesting_level == 0 && m_deferred_clear) @@ -128,7 +129,7 @@ class Function { explicit operator bool() const { return !!callable_wrapper(); } template - Function& operator=(CallableType&& callable) + constexpr Function& operator=(CallableType&& callable) requires((IsFunctionObject && IsCallableWithArguments)) { clear(); @@ -137,7 +138,7 @@ class Function { } template - Function& operator=(FunctionType f) + constexpr Function& operator=(FunctionType f) requires((IsFunctionPointer && IsCallableWithArguments, Out, In...>)) { clear(); @@ -171,8 +172,8 @@ class Function { public: virtual ~CallableWrapperBase() = default; // Note: This is not const to allow storing mutable lambdas. - virtual Out call(In...) = 0; - virtual void destroy() = 0; + virtual constexpr Out call(In...) = 0; + virtual constexpr void destroy() = 0; virtual void init_and_swap(u8*, size_t) = 0; }; @@ -182,17 +183,17 @@ class Function { AK_MAKE_NONCOPYABLE(CallableWrapper); public: - explicit CallableWrapper(CallableType&& callable) + explicit constexpr CallableWrapper(CallableType&& callable) : m_callable(move(callable)) { } - Out call(In... in) final override + Out constexpr call(In... in) final override { return m_callable(forward(in)...); } - void destroy() final override + void constexpr destroy() final override { delete this; } @@ -214,7 +215,7 @@ class Function { Outline, }; - CallableWrapperBase* callable_wrapper() const + constexpr CallableWrapperBase* callable_wrapper() const { switch (m_kind) { case FunctionKind::NullPointer: @@ -222,13 +223,13 @@ class Function { case FunctionKind::Inline: return bit_cast(&m_storage); case FunctionKind::Outline: - return *bit_cast(&m_storage); + return m_storage.wrapper; default: VERIFY_NOT_REACHED(); } } - void clear(bool may_defer = true) + constexpr void clear(bool may_defer = true) { bool called_from_inside_function = m_call_nesting_level > 0; // NOTE: This VERIFY could fail because a Function is destroyed from within itself. @@ -250,7 +251,7 @@ class Function { } template - void init_with_callable(Callable&& callable, CallableKind callable_kind) + constexpr void init_with_callable(Callable&& callable, CallableKind callable_kind) { if constexpr (alignof(Callable) > ExcessiveAlignmentThreshold && !AccommodateExcessiveAlignmentRequirements) { static_assert( @@ -259,20 +260,27 @@ class Function { "check your capture list if it is a lambda expression, " "and make sure your callable object is not excessively aligned."); } - VERIFY(m_call_nesting_level == 0); + if (!is_constant_evaluated()) + VERIFY(m_call_nesting_level == 0); using WrapperType = CallableWrapper; -#ifndef KERNEL - if constexpr (alignof(Callable) > inline_alignment || sizeof(WrapperType) > inline_capacity) { - *bit_cast(&m_storage) = new WrapperType(forward(callable)); + if (is_constant_evaluated()) { + m_storage.wrapper = new WrapperType(forward(callable)); m_kind = FunctionKind::Outline; } else { +#ifndef KERNEL + if constexpr (alignof(Callable) > inline_alignment || sizeof(WrapperType) > inline_capacity) { + m_storage.wrapper = new WrapperType(forward(callable)); + m_kind = FunctionKind::Outline; + } else { #endif - static_assert(sizeof(WrapperType) <= inline_capacity); - new (m_storage) WrapperType(forward(callable)); - m_kind = FunctionKind::Inline; + static_assert(sizeof(WrapperType) <= inline_capacity); + new (m_storage.storage) WrapperType(forward(callable)); + m_kind = FunctionKind::Inline; #ifndef KERNEL - } + } #endif + } + if (callable_kind == CallableKind::FunctionObject) m_size = sizeof(WrapperType); else @@ -288,11 +296,11 @@ class Function { case FunctionKind::NullPointer: break; case FunctionKind::Inline: - other_wrapper->init_and_swap(m_storage, inline_capacity); + other_wrapper->init_and_swap(m_storage.storage, inline_capacity); m_kind = FunctionKind::Inline; break; case FunctionKind::Outline: - *bit_cast(&m_storage) = other_wrapper; + m_storage.wrapper = other_wrapper; m_kind = FunctionKind::Outline; break; default: @@ -315,7 +323,10 @@ class Function { static constexpr size_t inline_capacity = 6 * sizeof(void*); #endif - alignas(inline_alignment) u8 m_storage[inline_capacity]; + alignas(inline_alignment) union { + u8 storage[inline_capacity]; + CallableWrapperBase* wrapper; + } m_storage; }; } diff --git a/AK/ScopeGuard.h b/AK/ScopeGuard.h index 2897203465713a..123f1ee75cc385 100644 --- a/AK/ScopeGuard.h +++ b/AK/ScopeGuard.h @@ -13,12 +13,12 @@ namespace AK { template class ScopeGuard { public: - ScopeGuard(Callback callback) + constexpr ScopeGuard(Callback callback) : m_callback(move(callback)) { } - ~ScopeGuard() + constexpr ~ScopeGuard() { m_callback(); } diff --git a/Tests/AK/CMakeLists.txt b/Tests/AK/CMakeLists.txt index f2fc3f1a94a64d..004bcb6f4cd16d 100644 --- a/Tests/AK/CMakeLists.txt +++ b/Tests/AK/CMakeLists.txt @@ -35,6 +35,7 @@ set(AK_TEST_SOURCES TestFloatingPointParsing.cpp TestFlyString.cpp TestFormat.cpp + TestFunction.cpp TestFuzzyMatch.cpp TestGeneratorAK.cpp TestGenericLexer.cpp diff --git a/Tests/AK/TestFunction.cpp b/Tests/AK/TestFunction.cpp new file mode 100644 index 00000000000000..f1fd6367ff000c --- /dev/null +++ b/Tests/AK/TestFunction.cpp @@ -0,0 +1,14 @@ +/* + * Copyright (c) 2024, Nico Weber + * + * SPDX-License-Identifier: BSD-2-Clause + */ + +#include + +constexpr int const_call(Function f, int i) +{ + return f(i); +} + +constinit int i = const_call([](int i) { return i; }, 4);