diff --git a/src/atom/CMakeLists.txt b/src/atom/CMakeLists.txt deleted file mode 100644 index e27e0463..00000000 --- a/src/atom/CMakeLists.txt +++ /dev/null @@ -1,105 +0,0 @@ -# CMakeLists.txt for Atom -# This project is licensed under the terms of the GPL3 license. -# -# Project Name: Atom -# Description: Atom Library for all of the Element Astro Project -# Author: Max Qian -# License: GPL3 - -cmake_minimum_required(VERSION 3.20) -project(atom C CXX) - -# Versioning -set(ATOM_VERSION_MAJOR 1) -set(ATOM_VERSION_MINOR 0) -set(ATOM_VERSION_PATCH 0) -set(ATOM_SOVERSION ${ATOM_VERSION_MAJOR}) -set(ATOM_VERSION_STRING "${ATOM_VERSION_MAJOR}.${ATOM_VERSION_MINOR}.${ATOM_VERSION_PATCH}") - -# Python Support -option(ATOM_BUILD_PYTHON "Build Atom with Python support" OFF) -if(ATOM_BUILD_PYTHON) - find_package(Python COMPONENTS Interpreter Development REQUIRED) - if(PYTHON_FOUND) - message("-- Found Python ${PYTHON_VERSION_STRING}: ${PYTHON_EXECUTABLE}") - find_package(pybind11 QUIET) - if(pybind11_FOUND) - message(STATUS "Found pybind11: ${pybind11_INCLUDE_DIRS}") - else() - message(FATAL_ERROR "pybind11 not found") - endif() - else() - message(FATAL_ERROR "Python not found") - endif() -endif() - -# Subdirectories -add_subdirectory(algorithm) -add_subdirectory(async) -add_subdirectory(components) -add_subdirectory(connection) -add_subdirectory(error) -add_subdirectory(function) -add_subdirectory(io) -add_subdirectory(log) -add_subdirectory(search) -add_subdirectory(secret) -add_subdirectory(sysinfo) -add_subdirectory(system) -add_subdirectory(tests) -add_subdirectory(type) -add_subdirectory(utils) -add_subdirectory(web) - -# Sources and Headers -set(ATOM_SOURCES - log/atomlog.cpp - log/logger.cpp -) - -set(ATOM_HEADERS - log/atomlog.hpp - log/logger.hpp -) - -# Libraries -set(ATOM_LIBS - loguru - cpp_httplib - atom-function - atom-algorithm - atom-async - atom-io - atom-component - atom-type - atom-utils - atom-search - atom-web - atom-system - atom-sysinfo -) - -# Object Library -add_library(atom_object OBJECT ${ATOM_SOURCES} ${ATOM_HEADERS}) - -if(WIN32) - target_link_libraries(atom_object setupapi wsock32 ws2_32 shlwapi iphlpapi) -endif() - -target_link_libraries(atom_object ${ATOM_LIBS}) - -# Static Library -add_library(atom STATIC) -set_target_properties(atom PROPERTIES - IMPORTED_LOCATION "${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_STATIC_LIBRARY_PREFIX}atom${CMAKE_STATIC_LIBRARY_SUFFIX}" - VERSION ${ATOM_VERSION_STRING} - SOVERSION ${ATOM_SOVERSION} -) - -target_link_libraries(atom atom_object ${CMAKE_THREAD_LIBS_INIT} ${ATOM_LIBS}) - -# Install -install(TARGETS atom - DESTINATION ${CMAKE_INSTALL_LIBDIR} - COMPONENT library -) diff --git a/src/atom/algorithm/CMakeLists.txt b/src/atom/algorithm/CMakeLists.txt deleted file mode 100644 index 55cad61c..00000000 --- a/src/atom/algorithm/CMakeLists.txt +++ /dev/null @@ -1,77 +0,0 @@ -# CMakeLists.txt for Atom-Algorithm -# This project is licensed under the terms of the GPL3 license. -# -# Project Name: Atom-Algorithm -# Description: A collection of algorithms -# Author: Max Qian -# License: GPL3 - -cmake_minimum_required(VERSION 3.20) -project(atom-algorithm C CXX) - -# Sources -set(${PROJECT_NAME}_SOURCES - algorithm.cpp - base.cpp - bignumber.cpp - convolve.cpp - fnmatch.cpp - fraction.cpp - huffman.cpp - math.cpp - matrix_compress.cpp - md5.cpp - mhash.cpp - tea.cpp -) - -# Headers -set(${PROJECT_NAME}_HEADERS - algorithm.hpp - base.hpp - bignumber.hpp - convolve.hpp - fnmatch.hpp - fraction.hpp - hash.hpp - huffman.hpp - math.hpp - matrix_compress.hpp - md5.hpp - mhash.hpp - tea.hpp -) - -# Build Object Library -add_library(${PROJECT_NAME}_OBJECT OBJECT) -set_property(TARGET ${PROJECT_NAME}_OBJECT PROPERTY POSITION_INDEPENDENT_CODE 1) - -target_sources(${PROJECT_NAME}_OBJECT - PUBLIC - ${${PROJECT_NAME}_HEADERS} - PRIVATE - ${${PROJECT_NAME}_SOURCES} -) - -target_link_libraries(${PROJECT_NAME}_OBJECT ${${PROJECT_NAME}_LIBS}) - -add_library(${PROJECT_NAME} STATIC) - -target_link_libraries(${PROJECT_NAME} ${PROJECT_NAME}_OBJECT ${${PROJECT_NAME}_LIBS}) -target_link_libraries(${PROJECT_NAME} ${CMAKE_THREAD_LIBS_INIT}) -target_include_directories(${PROJECT_NAME} PUBLIC .) - -set_target_properties(${PROJECT_NAME} PROPERTIES - VERSION ${CMAKE_HYDROGEN_VERSION_STRING} - SOVERSION ${HYDROGEN_SOVERSION} - OUTPUT_NAME ${PROJECT_NAME} -) - -install(TARGETS ${PROJECT_NAME} - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} -) - -if (ATOM_BUILD_PYTHON) -pybind11_add_module(${PROJECT_NAME}-py _pybind.cpp) -target_link_libraries(${PROJECT_NAME}-py PRIVATE ${PROJECT_NAME}) -endif() diff --git a/src/atom/algorithm/algorithm.cpp b/src/atom/algorithm/algorithm.cpp deleted file mode 100644 index 903a4411..00000000 --- a/src/atom/algorithm/algorithm.cpp +++ /dev/null @@ -1,256 +0,0 @@ -#include "algorithm.hpp" - -#include "atom/log/loguru.hpp" - -#ifdef USE_OPENMP -#include -#endif - -namespace atom::algorithm { - -KMP::KMP(std::string_view pattern) { - LOG_F(INFO, "Initializing KMP with pattern: %.*s", - static_cast(pattern.size()), pattern.data()); - setPattern(pattern); -} - -auto KMP::search(std::string_view text) const -> std::vector { - std::vector occurrences; - try { - std::shared_lock lock(mutex_); - auto n = static_cast(text.length()); - auto m = static_cast(pattern_.length()); - LOG_F(INFO, "KMP searching text of length %d with pattern length %d.", - n, m); - if (m == 0) { - LOG_F(WARNING, "Empty pattern provided to KMP::search."); - return occurrences; - } - -#ifdef USE_SIMD - int i = 0; - int j = 0; - while (i <= n - m) { - __m256i text_chunk = - _mm256_loadu_si256(reinterpret_cast(&text[i])); - __m256i pattern_chunk = _mm256_loadu_si256( - reinterpret_cast(&pattern_[0])); - __m256i result = _mm256_cmpeq_epi8(text_chunk, pattern_chunk); - int mask = _mm256_movemask_epi8(result); - if (mask == 0xFFFFFFFF) { - occurrences.push_back(i); - i += m; - } else { - ++i; - } - } -#elif defined(USE_OPENMP) - std::vector local_occurrences[omp_get_max_threads()]; -#pragma omp parallel - { - int thread_num = omp_get_thread_num(); - int i = thread_num; - int j = 0; - while (i < n) { - if (text[i] == pattern_[j]) { - ++i; - ++j; - if (j == m) { - local_occurrences[thread_num].push_back(i - m); - j = failure_[j - 1]; - } - } else if (j > 0) { - j = failure_[j - 1]; - } else { - ++i; - } - } - } - for (int t = 0; t < omp_get_max_threads(); ++t) { - occurrences.insert(occurrences.end(), local_occurrences[t].begin(), - local_occurrences[t].end()); - } -#else - int i = 0; - int j = 0; - while (i < n) { - if (text[i] == pattern_[j]) { - ++i; - ++j; - if (j == m) { - occurrences.push_back(i - m); - j = failure_[j - 1]; - } - } else if (j > 0) { - j = failure_[j - 1]; - } else { - ++i; - } - } -#endif - LOG_F(INFO, "KMP search completed with {} occurrences found.", - occurrences.size()); - } catch (const std::exception& e) { - LOG_F(ERROR, "Exception in KMP::search: {}", e.what()); - throw; - } - return occurrences; -} - -void KMP::setPattern(std::string_view pattern) { - std::unique_lock lock(mutex_); - LOG_F(INFO, "Setting new pattern for KMP: %.*s", - static_cast(pattern.size()), pattern.data()); - pattern_ = pattern; - failure_ = computeFailureFunction(pattern_); -} - -auto KMP::computeFailureFunction(std::string_view pattern) -> std::vector { - LOG_F(INFO, "Computing failure function for pattern."); - auto m = static_cast(pattern.length()); - std::vector failure(m, 0); - int j = 0; - for (int i = 1; i < m; ++i) { - while (j > 0 && pattern[i] != pattern[j]) { - j = failure[j - 1]; - } - if (pattern[i] == pattern[j]) { - failure[i] = ++j; - } - } - LOG_F(INFO, "Failure function computed."); - return failure; -} - -BoyerMoore::BoyerMoore(std::string_view pattern) { - LOG_F(INFO, "Initializing BoyerMoore with pattern: %.*s", - static_cast(pattern.size()), pattern.data()); - setPattern(pattern); -} - -auto BoyerMoore::search(std::string_view text) const -> std::vector { - std::vector occurrences; - try { - std::lock_guard lock(mutex_); - auto n = static_cast(text.length()); - auto m = static_cast(pattern_.length()); - LOG_F(INFO, - "BoyerMoore searching text of length %d with pattern length %d.", - n, m); - if (m == 0) { - LOG_F(WARNING, "Empty pattern provided to BoyerMoore::search."); - return occurrences; - } - -#ifdef USE_OPENMP - std::vector local_occurrences[omp_get_max_threads()]; -#pragma omp parallel - { - int thread_num = omp_get_thread_num(); - int i = thread_num; - while (i <= n - m) { - int j = m - 1; - while (j >= 0 && pattern_[j] == text[i + j]) { - --j; - } - if (j < 0) { - local_occurrences[thread_num].push_back(i); - i += good_suffix_shift_[0]; - } else { - int badCharShift = bad_char_shift_.find(text[i + j]) != - bad_char_shift_.end() - ? bad_char_shift_.at(text[i + j]) - : m; - i += std::max(good_suffix_shift_[j + 1], - static_cast(badCharShift - m + 1 + j)); - } - } - } - for (int t = 0; t < omp_get_max_threads(); ++t) { - occurrences.insert(occurrences.end(), local_occurrences[t].begin(), - local_occurrences[t].end()); - } -#else - int i = 0; - while (i <= n - m) { - int j = m - 1; - while (j >= 0 && pattern_[j] == text[i + j]) { - --j; - } - if (j < 0) { - occurrences.push_back(i); - i += good_suffix_shift_[0]; - } else { - int badCharShift = - bad_char_shift_.find(text[i + j]) != bad_char_shift_.end() - ? bad_char_shift_.at(text[i + j]) - : m; - i += std::max(good_suffix_shift_[j + 1], - badCharShift - m + 1 + j); - } - } -#endif - LOG_F(INFO, "BoyerMoore search completed with {} occurrences found.", - occurrences.size()); - } catch (const std::exception& e) { - LOG_F(ERROR, "Exception in BoyerMoore::search: {}", e.what()); - throw; - } - return occurrences; -} - -void BoyerMoore::setPattern(std::string_view pattern) { - std::lock_guard lock(mutex_); - LOG_F(INFO, "Setting new pattern for BoyerMoore: %.*s", - static_cast(pattern.size()), pattern.data()); - pattern_ = std::string(pattern); - computeBadCharacterShift(); - computeGoodSuffixShift(); -} - -void BoyerMoore::computeBadCharacterShift() { - LOG_F(INFO, "Computing bad character shift table."); - bad_char_shift_.clear(); - for (int i = 0; i < static_cast(pattern_.length()) - 1; ++i) { - bad_char_shift_[pattern_[i]] = - static_cast(pattern_.length()) - 1 - i; - } - LOG_F(INFO, "Bad character shift table computed."); -} - -void BoyerMoore::computeGoodSuffixShift() { - LOG_F(INFO, "Computing good suffix shift table."); - auto m = static_cast(pattern_.length()); - good_suffix_shift_.resize(m + 1, m); - std::vector suffix(m + 1, 0); - suffix[m] = m + 1; - - for (int i = m; i > 0; --i) { - int j = i - 1; - while (j >= 0 && pattern_[j] != pattern_[m - 1 - (i - 1 - j)]) { - --j; - } - suffix[i - 1] = j + 1; - } - - for (int i = 0; i <= m; ++i) { - good_suffix_shift_[i] = m; - } - - for (int i = m; i > 0; --i) { - if (suffix[i - 1] == i) { - for (int j = 0; j < m - i; ++j) { - if (good_suffix_shift_[j] == m) { - good_suffix_shift_[j] = m - i; - } - } - } - } - - for (int i = 0; i < m - 1; ++i) { - good_suffix_shift_[m - suffix[i]] = m - 1 - i; - } - LOG_F(INFO, "Good suffix shift table computed."); -} - -} // namespace atom::algorithm \ No newline at end of file diff --git a/src/atom/algorithm/algorithm.hpp b/src/atom/algorithm/algorithm.hpp deleted file mode 100644 index c33b39ff..00000000 --- a/src/atom/algorithm/algorithm.hpp +++ /dev/null @@ -1,226 +0,0 @@ -/* - * algorithm.hpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2023-4-5 - -Description: A collection of algorithms for C++ - -**************************************************/ - -#ifndef ATOM_ALGORITHM_ALGORITHM_HPP -#define ATOM_ALGORITHM_ALGORITHM_HPP - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace atom::algorithm { -/** - * @brief Implements the Knuth-Morris-Pratt (KMP) string searching algorithm. - * - * This class provides methods to search for occurrences of a pattern within a - * text using the KMP algorithm, which preprocesses the pattern to achieve - * efficient string searching. - */ -class KMP { -public: - /** - * @brief Constructs a KMP object with the given pattern. - * - * @param pattern The pattern to search for in text. - */ - explicit KMP(std::string_view pattern); - - /** - * @brief Searches for occurrences of the pattern in the given text. - * - * @param text The text to search within. - * @return std::vector Vector containing positions where the pattern - * starts in the text. - */ - [[nodiscard]] auto search(std::string_view text) const -> std::vector; - - /** - * @brief Sets a new pattern for searching. - * - * @param pattern The new pattern to search for. - */ - void setPattern(std::string_view pattern); - -private: - /** - * @brief Computes the failure function (partial match table) for the given - * pattern. - * - * This function preprocesses the pattern to determine the length of the - * longest proper prefix which is also a suffix at each position in the - * pattern. - * - * @param pattern The pattern for which to compute the failure function. - * @return std::vector The computed failure function (partial match - * table). - */ - auto computeFailureFunction(std::string_view pattern) -> std::vector; - - std::string pattern_; ///< The pattern to search for. - std::vector - failure_; ///< Failure function (partial match table) for the pattern. - - mutable std::shared_mutex mutex_; ///< Mutex for thread-safe operations -}; - -/** - * @brief The BloomFilter class implements a Bloom filter data structure. - * @tparam N The size of the Bloom filter (number of bits). - */ -template -class BloomFilter { -public: - /** - * @brief Constructs a new BloomFilter object with the specified number of - * hash functions. - * @param num_hash_functions The number of hash functions to use for the - * Bloom filter. - */ - explicit BloomFilter(std::size_t num_hash_functions); - - /** - * @brief Inserts an element into the Bloom filter. - * @param element The element to insert. - */ - void insert(std::string_view element); - - /** - * @brief Checks if an element might be present in the Bloom filter. - * @param element The element to check. - * @return True if the element might be present, false otherwise. - */ - [[nodiscard]] auto contains(std::string_view element) const -> bool; - -private: - std::bitset m_bits_; /**< The bitset representing the Bloom filter. */ - std::size_t - m_num_hash_functions_; /**< The number of hash functions used. */ - - /** - * @brief Computes the hash value of an element using a specific seed. - * @param element The element to hash. - * @param seed The seed value for the hash function. - * @return The hash value of the element. - */ - auto hash(std::string_view element, std::size_t seed) const -> std::size_t; -}; - -/** - * @brief Implements the Boyer-Moore string searching algorithm. - * - * This class provides methods to search for occurrences of a pattern within a - * text using the Boyer-Moore algorithm, which preprocesses the pattern to - * achieve efficient string searching. - */ -class BoyerMoore { -public: - /** - * @brief Constructs a BoyerMoore object with the given pattern. - * - * @param pattern The pattern to search for in text. - */ - explicit BoyerMoore(std::string_view pattern); - - /** - * @brief Searches for occurrences of the pattern in the given text. - * - * @param text The text to search within. - * @return std::vector Vector containing positions where the pattern - * starts in the text. - */ - auto search(std::string_view text) const -> std::vector; - - /** - * @brief Sets a new pattern for searching. - * - * @param pattern The new pattern to search for. - */ - void setPattern(std::string_view pattern); - -private: - /** - * @brief Computes the bad character shift table for the current pattern. - * - * This table determines how far to shift the pattern relative to the text - * based on the last occurrence of a mismatched character. - */ - void computeBadCharacterShift(); - - /** - * @brief Computes the good suffix shift table for the current pattern. - * - * This table helps determine how far to shift the pattern when a mismatch - * occurs based on the occurrence of a partial match (suffix of the - * pattern). - */ - void computeGoodSuffixShift(); - - std::string pattern_; ///< The pattern to search for. - std::unordered_map - bad_char_shift_; ///< Bad character shift table. - std::vector good_suffix_shift_; ///< Good suffix shift table. - - mutable std::mutex mutex_; ///< Mutex for thread-safe operations -}; - -template -BloomFilter::BloomFilter(std::size_t num_hash_functions) - : m_num_hash_functions_(num_hash_functions) {} - -template -void BloomFilter::insert(std::string_view element) { - try { - for (std::size_t i = 0; i < m_num_hash_functions_; ++i) { - std::size_t hashValue = hash(element, i); - m_bits_.set(hashValue % N); - } - } catch (const std::exception& e) { - throw; - } -} - -template -auto BloomFilter::contains(std::string_view element) const -> bool { - try { - for (std::size_t i = 0; i < m_num_hash_functions_; ++i) { - std::size_t hashValue = hash(element, i); - if (!m_bits_.test(hashValue % N)) { - return false; - } - } - return true; - } catch (const std::exception& e) { - throw; - } -} - -template -auto BloomFilter::hash(std::string_view element, - std::size_t seed) const -> std::size_t { - std::size_t hashValue = seed; - for (char c : element) { - hashValue = hashValue * 31 + static_cast(c); - } - return hashValue; -} - -} // namespace atom::algorithm - -#endif \ No newline at end of file diff --git a/src/atom/algorithm/annealing.hpp b/src/atom/algorithm/annealing.hpp deleted file mode 100644 index 633b5cea..00000000 --- a/src/atom/algorithm/annealing.hpp +++ /dev/null @@ -1,391 +0,0 @@ -#ifndef ATOM_ALGORITHM_ANNEALING_HPP -#define ATOM_ALGORITHM_ANNEALING_HPP - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#ifdef USE_SIMD -#ifdef __x86_64__ -#include -#elif __aarch64__ -#include -#endif -#endif - -#include "atom/log/loguru.hpp" - -// Define a concept for a problem that Simulated Annealing can solve -template -concept AnnealingProblem = - requires(ProblemType problemInstance, SolutionType solutionInstance) { - { - problemInstance.energy(solutionInstance) - } -> std::convertible_to; - { - problemInstance.neighbor(solutionInstance) - } -> std::same_as; - { problemInstance.random_solution() } -> std::same_as; - }; - -// Different cooling strategies for temperature reduction -enum class AnnealingStrategy { LINEAR, EXPONENTIAL, LOGARITHMIC }; - -// Simulated Annealing algorithm implementation -template - requires AnnealingProblem -class SimulatedAnnealing { -private: - ProblemType& problem_instance_; - std::function cooling_schedule_; - int max_iterations_; - double initial_temperature_; - AnnealingStrategy cooling_strategy_; - std::function progress_callback_; - std::function stop_condition_; - std::atomic should_stop_{false}; - - std::mutex best_mutex_; - SolutionType best_solution_; - double best_energy_ = std::numeric_limits::max(); - - static constexpr int K_DEFAULT_MAX_ITERATIONS = 1000; - static constexpr double K_DEFAULT_INITIAL_TEMPERATURE = 100.0; - static constexpr double K_COOLING_RATE = 0.95; - - void optimizeThread(); - -public: - explicit SimulatedAnnealing( - ProblemType& problemInstance, - AnnealingStrategy coolingStrategy = AnnealingStrategy::EXPONENTIAL, - int maxIterations = K_DEFAULT_MAX_ITERATIONS, - double initialTemperature = K_DEFAULT_INITIAL_TEMPERATURE); - - void setCoolingSchedule(AnnealingStrategy strategy); - - void setProgressCallback( - std::function callback); - - void setStopCondition( - std::function condition); - - auto optimize(int numThreads = 1) -> SolutionType; - - [[nodiscard]] auto getBestEnergy() const -> double; -}; - -// Example TSP (Traveling Salesman Problem) implementation -class TSP { -private: - std::vector> cities_; - -public: - explicit TSP(const std::vector>& cities); - - [[nodiscard]] auto energy(const std::vector& solution) const -> double; - - [[nodiscard]] static auto neighbor(const std::vector& solution) - -> std::vector; - - [[nodiscard]] auto randomSolution() const -> std::vector; -}; - -// SimulatedAnnealing class implementation -template - requires AnnealingProblem -SimulatedAnnealing::SimulatedAnnealing( - ProblemType& problemInstance, AnnealingStrategy coolingStrategy, - int maxIterations, double initialTemperature) - : problem_instance_(problemInstance), - max_iterations_(maxIterations), - initial_temperature_(initialTemperature), - cooling_strategy_(coolingStrategy) { - LOG_F(INFO, - "SimulatedAnnealing initialized with max_iterations: {}, " - "initial_temperature: %.2f, cooling_strategy: {}", - maxIterations, initialTemperature, static_cast(coolingStrategy)); - setCoolingSchedule(coolingStrategy); -} - -template - requires AnnealingProblem -void SimulatedAnnealing::setCoolingSchedule( - AnnealingStrategy strategy) { - cooling_strategy_ = strategy; - LOG_F(INFO, "Setting cooling schedule to strategy: {}", - static_cast(strategy)); - switch (cooling_strategy_) { - case AnnealingStrategy::LINEAR: - cooling_schedule_ = [this](int iteration) { - return initial_temperature_ * - (1 - static_cast(iteration) / max_iterations_); - }; - break; - case AnnealingStrategy::EXPONENTIAL: - cooling_schedule_ = [this](int iteration) { - return initial_temperature_ * - std::pow(K_COOLING_RATE, iteration); - }; - break; - case AnnealingStrategy::LOGARITHMIC: - cooling_schedule_ = [this](int iteration) { - if (iteration == 0) - return initial_temperature_; - return initial_temperature_ / std::log(iteration + 2); - }; - break; - default: - LOG_F(WARNING, - "Unknown cooling strategy. Defaulting to EXPONENTIAL."); - cooling_schedule_ = [this](int iteration) { - return initial_temperature_ * - std::pow(K_COOLING_RATE, iteration); - }; - break; - } -} - -template - requires AnnealingProblem -void SimulatedAnnealing::setProgressCallback( - std::function callback) { - progress_callback_ = callback; - LOG_F(INFO, "Progress callback has been set."); -} - -template - requires AnnealingProblem -void SimulatedAnnealing::setStopCondition( - std::function condition) { - stop_condition_ = condition; - LOG_F(INFO, "Stop condition has been set."); -} - -template - requires AnnealingProblem -void SimulatedAnnealing::optimizeThread() { - try { - std::random_device randomDevice; - std::mt19937 generator(randomDevice()); - std::uniform_real_distribution distribution(0.0, 1.0); - - auto currentSolution = problem_instance_.random_solution(); - double currentEnergy = problem_instance_.energy(currentSolution); - LOG_F(INFO, "Thread %ld started with initial energy: {}", - std::this_thread::get_id(), currentEnergy); - - { - std::lock_guard lock(best_mutex_); - if (currentEnergy < best_energy_) { - best_solution_ = currentSolution; - best_energy_ = currentEnergy; - LOG_F(INFO, "New best energy found: {}", best_energy_); - } - } - - for (int iteration = 0; - iteration < max_iterations_ && !should_stop_.load(); ++iteration) { - double temperature = cooling_schedule_(iteration); - if (temperature <= 0) { - LOG_F(WARNING, - "Temperature has reached zero or below at iteration {}.", - iteration); - break; - } - - auto neighborSolution = problem_instance_.neighbor(currentSolution); - double neighborEnergy = problem_instance_.energy(neighborSolution); - - double energyDifference = neighborEnergy - currentEnergy; - LOG_F(INFO, - "Iteration {}: Current Energy = {}, Neighbor Energy = " - "{}, Energy Difference = {}, Temperature = {}", - iteration, currentEnergy, neighborEnergy, energyDifference, - temperature); - - if (energyDifference < 0 || - distribution(generator) < - std::exp(-energyDifference / temperature)) { - currentSolution = std::move(neighborSolution); - currentEnergy = neighborEnergy; - LOG_F(INFO, "Solution accepted at iteration {} with energy: {}", - iteration, currentEnergy); - - std::lock_guard lock(best_mutex_); - if (currentEnergy < best_energy_) { - best_solution_ = currentSolution; - best_energy_ = currentEnergy; - LOG_F(INFO, "New best energy updated to: {}", best_energy_); - } - } - - if (progress_callback_) { - try { - progress_callback_(iteration, currentEnergy, - currentSolution); - } catch (const std::exception& e) { - LOG_F(ERROR, "Exception in progress_callback_: {}", - e.what()); - } - } - - if (stop_condition_ && - stop_condition_(iteration, currentEnergy, currentSolution)) { - should_stop_.store(true); - LOG_F(INFO, "Stop condition met at iteration {}.", iteration); - break; - } - } - LOG_F(INFO, "Thread %ld completed optimization with best energy: {}", - std::this_thread::get_id(), best_energy_); - } catch (const std::exception& e) { - LOG_F(ERROR, "Exception in optimizeThread: {}", e.what()); - } -} - -template - requires AnnealingProblem -auto SimulatedAnnealing::optimize(int numThreads) - -> SolutionType { - LOG_F(INFO, "Starting optimization with {} threads.", numThreads); - if (numThreads < 1) { - LOG_F(WARNING, "Invalid number of threads ({}). Defaulting to 1.", - numThreads); - numThreads = 1; - } - - std::vector> futures; - futures.reserve(numThreads); - for (int threadIndex = 0; threadIndex < numThreads; ++threadIndex) { - futures.emplace_back( - std::async(std::launch::async, [this]() { optimizeThread(); })); - LOG_F(INFO, "Launched optimization thread {}.", threadIndex + 1); - } - - for (auto& future : futures) { - try { - future.wait(); - future.get(); - } catch (const std::exception& e) { - LOG_F(ERROR, "Exception in optimization thread: {}", e.what()); - } - } - - LOG_F(INFO, "Optimization completed with best energy: {}", best_energy_); - return best_solution_; -} - -template - requires AnnealingProblem -auto SimulatedAnnealing::getBestEnergy() const - -> double { - std::lock_guard lock(best_mutex_); - return best_energy_; -} - -// TSP class implementation -inline TSP::TSP(const std::vector>& cities) - : cities_(cities) { - LOG_F(INFO, "TSP instance created with %zu cities.", cities_.size()); -} - -inline auto TSP::energy(const std::vector& solution) const -> double { - double totalDistance = 0.0; - size_t numCities = solution.size(); - -#ifdef USE_SIMD - __m256d totalDistanceVec = _mm256_setzero_pd(); - size_t i = 0; - for (; i + 3 < numCities; i += 4) { - __m256d x1 = _mm256_set_pd( - cities_[solution[i]].first, cities_[solution[i + 1]].first, - cities_[solution[i + 2]].first, cities_[solution[i + 3]].first); - __m256d y1 = _mm256_set_pd( - cities_[solution[i]].second, cities_[solution[i + 1]].second, - cities_[solution[i + 2]].second, cities_[solution[i + 3]].second); - - __m256d x2 = - _mm256_set_pd(cities_[solution[(i + 1) % numCities]].first, - cities_[solution[(i + 2) % numCities]].first, - cities_[solution[(i + 3) % numCities]].first, - cities_[solution[(i + 4) % numCities]].first); - __m256d y2 = - _mm256_set_pd(cities_[solution[(i + 1) % numCities]].second, - cities_[solution[(i + 2) % numCities]].second, - cities_[solution[(i + 3) % numCities]].second, - cities_[solution[(i + 4) % numCities]].second); - - __m256d deltaX = _mm256_sub_pd(x1, x2); - __m256d deltaY = _mm256_sub_pd(y1, y2); - - __m256d distance = _mm256_sqrt_pd(_mm256_add_pd( - _mm256_mul_pd(deltaX, deltaX), _mm256_mul_pd(deltaY, deltaY))); - totalDistanceVec = _mm256_add_pd(totalDistanceVec, distance); - } - - // Horizontal addition to sum up the total distance in vector - double distances[4]; - _mm256_storeu_pd(distances, totalDistanceVec); - for (double d : distances) { - totalDistance += d; - } -#endif - - // Handle leftover cities that couldn't be processed in sets of 4 - for (size_t index = numCities - numCities % 4; index < numCities; ++index) { - auto [x1, y1] = cities_[solution[index]]; - auto [x2, y2] = cities_[solution[(index + 1) % numCities]]; - double deltaX = x1 - x2; - double deltaY = y1 - y2; - totalDistance += std::sqrt(deltaX * deltaX + deltaY * deltaY); - } - - LOG_F(INFO, "Computed energy (total distance): {}", totalDistance); - return totalDistance; -} - -inline auto TSP::neighbor(const std::vector& solution) - -> std::vector { - std::vector newSolution = solution; - try { - std::random_device randomDevice; - std::mt19937 generator(randomDevice()); - std::uniform_int_distribution distribution( - 0, static_cast(solution.size()) - 1); - int index1 = distribution(generator); - int index2 = distribution(generator); - std::swap(newSolution[index1], newSolution[index2]); - LOG_F(INFO, - "Generated neighbor solution by swapping indices {} and {}.", - index1, index2); - } catch (const std::exception& e) { - LOG_F(ERROR, "Exception in TSP::neighbor: {}", e.what()); - throw; - } - return newSolution; -} - -inline auto TSP::randomSolution() const -> std::vector { - std::vector solution(cities_.size()); - std::iota(solution.begin(), solution.end(), 0); - try { - std::random_device randomDevice; - std::mt19937 generator(randomDevice()); - std::ranges::shuffle(solution, generator); - LOG_F(INFO, "Generated random solution."); - } catch (const std::exception& e) { - LOG_F(ERROR, "Exception in TSP::randomSolution: {}", e.what()); - throw; - } - return solution; -} - -#endif // ATOM_ALGORITHM_ANNEALING_HPP \ No newline at end of file diff --git a/src/atom/algorithm/base.cpp b/src/atom/algorithm/base.cpp deleted file mode 100644 index 58889c40..00000000 --- a/src/atom/algorithm/base.cpp +++ /dev/null @@ -1,723 +0,0 @@ -/* - * base.cpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2023-4-5 - -Description: A collection of algorithms for C++ - -**************************************************/ - -#include "base.hpp" - -#include -#include - -#include "atom/error/exception.hpp" - -#ifdef _WIN32 -#include -#else -#include -#endif - -#ifdef USE_SIMD -#if defined(__AVX2__) || defined(USE_AVX) -#include -#define SIMD_WIDTH 32 -#elif defined(__SSE4_1__) || defined(USE_SSE) -#include -#define SIMD_WIDTH 16 -#elif defined(__ARM_NEON) || defined(USE_NEON) -#include -#define SIMD_WIDTH 16 -#endif -#endif - -#if USE_OPENCL -#include -constexpr bool HAS_OPEN_CL = true; -#else -constexpr bool HAS_OPEN_CL = false; -#endif - -namespace atom::algorithm { -namespace detail { -#if USE_OPENCL -const char* base64EncodeKernelSource = R"( - __kernel void base64EncodeKernel(__global const uchar* input, __global char* output, int size) { - int i = get_global_id(0); - if (i < size / 3) { - uchar3 in = vload3(i, input); - output[i * 4 + 0] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"[(in.s0 >> 2) & 0x3F]; - output[i * 4 + 1] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"[((in.s0 & 0x03) << 4) | ((in.s1 >> 4) & 0x0F)]; - output[i * 4 + 2] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"[((in.s1 & 0x0F) << 2) | ((in.s2 >> 6) & 0x03)]; - output[i * 4 + 3] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"[in.s2 & 0x3F]; - } - } - )"; - -const char* base64DecodeKernelSource = R"( - __kernel void base64DecodeKernel(__global const char* input, __global uchar* output, int size) { - int i = get_global_id(0); - if (i < size / 4) { - char4 in = vload4(i, input); - output[i * 3 + 0] = (uchar)((in.s0 << 2) | ((in.s1 >> 4) & 0x03)); - output[i * 3 + 1] = (uchar)(((in.s1 & 0x0F) << 4) | ((in.s2 >> 2) & 0x0F)); - output[i * 3 + 2] = (uchar)(((in.s2 & 0x03) << 6) | (in.s3 & 0x3F)); - } - } - )"; - -// OpenCL kernel for XOR encryption/decryption -const char* xorKernelSource = R"( - __kernel void xorKernel(__global const char* input, __global char* output, uchar key, int size) { - int i = get_global_id(0); - if (i < size) { - output[i] = input[i] ^ key; - } - } - )"; - -// OpenCL setup and context management -cl_context context; -cl_command_queue queue; -cl_program program; -cl_kernel base64EncodeKernel, base64DecodeKernel, xorKernel; - -void initializeOpenCL() { - // Initialize OpenCL context, compile the kernels, etc. - // Error handling omitted for brevity - cl_platform_id platform; - cl_device_id device; - clGetPlatformIDs(1, &platform, nullptr); - clGetDeviceIDs(platform, CL_DEVICE_TYPE_GPU, 1, &device, nullptr); - context = clCreateContext(nullptr, 1, &device, nullptr, nullptr, nullptr); - queue = clCreateCommandQueue(context, device, 0, nullptr); - - // Compile the kernels - const char* sources[] = {base64EncodeKernelSource, base64DecodeKernelSource, - xorKernelSource}; - program = clCreateProgramWithSource(context, 3, sources, nullptr, nullptr); - clBuildProgram(program, 1, &device, nullptr, nullptr, nullptr); - - base64EncodeKernel = clCreateKernel(program, "base64EncodeKernel", nullptr); - base64DecodeKernel = clCreateKernel(program, "base64DecodeKernel", nullptr); - xorKernel = clCreateKernel(program, "xorKernel", nullptr); -} - -void cleanupOpenCL() { - // Cleanup OpenCL resources - clReleaseKernel(base64EncodeKernel); - clReleaseKernel(base64DecodeKernel); - clReleaseKernel(xorKernel); - clReleaseProgram(program); - clReleaseCommandQueue(queue); - clReleaseContext(context); -} - -void base64EncodeOpenCL(const unsigned char* input, char* output, size_t size) { - cl_mem inputBuffer = - clCreateBuffer(context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, size, - (void*)input, nullptr); - cl_mem outputBuffer = clCreateBuffer(context, CL_MEM_WRITE_ONLY, - (size + 2) / 3 * 4, nullptr, nullptr); - - clSetKernelArg(base64EncodeKernel, 0, sizeof(cl_mem), &inputBuffer); - clSetKernelArg(base64EncodeKernel, 1, sizeof(cl_mem), &outputBuffer); - clSetKernelArg(base64EncodeKernel, 2, sizeof(int), &size); - - size_t globalWorkSize = (size + 2) / 3; - clEnqueueNDRangeKernel(queue, base64EncodeKernel, 1, nullptr, - &globalWorkSize, nullptr, 0, nullptr, nullptr); - - clEnqueueReadBuffer(queue, outputBuffer, CL_TRUE, 0, (size + 2) / 3 * 4, - output, 0, nullptr, nullptr); - - clReleaseMemObject(inputBuffer); - clReleaseMemObject(outputBuffer); -} - -void base64DecodeOpenCL(const char* input, unsigned char* output, size_t size) { - cl_mem inputBuffer = - clCreateBuffer(context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, size, - (void*)input, nullptr); - cl_mem outputBuffer = clCreateBuffer(context, CL_MEM_WRITE_ONLY, - size / 4 * 3, nullptr, nullptr); - - clSetKernelArg(base64DecodeKernel, 0, sizeof(cl_mem), &inputBuffer); - clSetKernelArg(base64DecodeKernel, 1, sizeof(cl_mem), &outputBuffer); - clSetKernelArg(base64DecodeKernel, 2, sizeof(int), &size); - - size_t globalWorkSize = size / 4; - clEnqueueNDRangeKernel(queue, base64DecodeKernel, 1, nullptr, - &globalWorkSize, nullptr, 0, nullptr, nullptr); - - clEnqueueReadBuffer(queue, outputBuffer, CL_TRUE, 0, size / 4 * 3, output, - 0, nullptr, nullptr); - - clReleaseMemObject(inputBuffer); - clReleaseMemObject(outputBuffer); -} - -void xorEncryptOpenCL(const char* input, char* output, uint8_t key, - size_t size) { - cl_mem inputBuffer = - clCreateBuffer(context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, size, - (void*)input, nullptr); - cl_mem outputBuffer = - clCreateBuffer(context, CL_MEM_WRITE_ONLY, size, nullptr, nullptr); - - clSetKernelArg(xorKernel, 0, sizeof(cl_mem), &inputBuffer); - clSetKernelArg(xorKernel, 1, sizeof(cl_mem), &outputBuffer); - clSetKernelArg(xorKernel, 2, sizeof(uint8_t), &key); - clSetKernelArg(xorKernel, 3, sizeof(int), &size); - - size_t globalWorkSize = size; - clEnqueueNDRangeKernel(queue, xorKernel, 1, nullptr, &globalWorkSize, - nullptr, 0, nullptr, nullptr); - - clEnqueueReadBuffer(queue, outputBuffer, CL_TRUE, 0, size, output, 0, - nullptr, nullptr); - - clReleaseMemObject(inputBuffer); - clReleaseMemObject(outputBuffer); -} -#endif -template -void base64Encode(InputIt begin, InputIt end, OutputIt dest) { - std::array charArray3{}; - std::array charArray4{}; - - size_t i = 0; - auto it = begin; - -#ifdef SIMD_AVAILABLE - // SIMD优化部分 - constexpr size_t simdSize = 16; // 处理16字节的输入 - std::array inputBuffer{}; - std::array outputBuffer{}; - - while (std::distance(it, end) >= simdSize) { - std::copy_n(it, simdSize, inputBuffer.begin()); - -#if defined(__x86_64__) || defined(_M_X64) - // x86 SIMD实现 - __m128i input = _mm_loadu_si128( - reinterpret_cast(inputBuffer.data())); - __m128i mask = _mm_set1_epi32(0x3F); - - __m128i result1 = _mm_srli_epi32(input, 2); - __m128i result2 = _mm_and_si128(_mm_slli_epi32(input, 4), mask); - __m128i result3 = _mm_and_si128(_mm_slli_epi32(input, 2), mask); - __m128i result4 = _mm_and_si128(input, mask); - - // 查表并存储结果 - for (int j = 0; j < 16; j += 4) { - outputBuffer[j] = BASE64_CHARS[_mm_extract_epi8(result1, j)]; - outputBuffer[j + 1] = - BASE64_CHARS[_mm_extract_epi8(result2, j + 1)]; - outputBuffer[j + 2] = - BASE64_CHARS[_mm_extract_epi8(result3, j + 2)]; - outputBuffer[j + 3] = - BASE64_CHARS[_mm_extract_epi8(result4, j + 3)]; - } -#elif defined(__ARM_NEON) - // ARM NEON实现 - uint8x16_t input = vld1q_u8(inputBuffer.data()); - uint8x16_t mask = vdupq_n_u8(0x3F); - - uint8x16_t result1 = vshrq_n_u8(input, 2); - uint8x16_t result2 = vandq_u8(vshlq_n_u8(input, 4), mask); - uint8x16_t result3 = vandq_u8(vshlq_n_u8(input, 2), mask); - uint8x16_t result4 = vandq_u8(input, mask); - - // 查表并存储结果 - for (int j = 0; j < 16; j += 4) { - outputBuffer[j] = BASE64_CHARS[vgetq_lane_u8(result1, j)]; - outputBuffer[j + 1] = BASE64_CHARS[vgetq_lane_u8(result2, j + 1)]; - outputBuffer[j + 2] = BASE64_CHARS[vgetq_lane_u8(result3, j + 2)]; - outputBuffer[j + 3] = BASE64_CHARS[vgetq_lane_u8(result4, j + 3)]; - } -#endif - - std::copy_n(outputBuffer.begin(), (simdSize / 3) * 4, dest); - std::advance(dest, (simdSize / 3) * 4); - std::advance(it, simdSize); - i += simdSize; - } -#endif - - // 处理剩余的字节(原始实现) - for (; it != end; ++it, ++i) { - charArray3[i % 3] = static_cast(*it); - if (i % 3 == 2) { - charArray4[0] = (charArray3[0] & 0xfc) >> 2; - charArray4[1] = - ((charArray3[0] & 0x03) << 4) + ((charArray3[1] & 0xf0) >> 4); - charArray4[2] = - ((charArray3[1] & 0x0f) << 2) + ((charArray3[2] & 0xc0) >> 6); - charArray4[3] = charArray3[2] & 0x3f; - - for (int j = 0; j < 4; ++j) { - *dest++ = BASE64_CHARS[charArray4[j]]; - } - } - } - - if (i % 3 != 0) { - for (size_t j = i % 3; j < 3; ++j) { - charArray3[j] = '\0'; - } - - charArray4[0] = (charArray3[0] & 0xfc) >> 2; - charArray4[1] = - ((charArray3[0] & 0x03) << 4) + ((charArray3[1] & 0xf0) >> 4); - charArray4[2] = - ((charArray3[1] & 0x0f) << 2) + ((charArray3[2] & 0xc0) >> 6); - charArray4[3] = charArray3[2] & 0x3f; - - for (size_t j = 0; j < i % 3 + 1; ++j) { - *dest++ = BASE64_CHARS[charArray4[j]]; - } - - while (i++ % 3 != 0) { - *dest++ = '='; - } - } -} - -std::array createReverseLookupTable() { - std::array table{}; - for (int i = 0; i < 64; ++i) { - table[static_cast(BASE64_CHARS[i])] = i; - } - return table; -} - -const auto REVERSE_LOOKUP = createReverseLookupTable(); - -template -void base64Decode(InputIt begin, InputIt end, OutputIt dest) { - std::array charArray4{}; - std::array charArray3{}; - - size_t i = 0; - auto it = begin; - -#ifdef SIMD_AVAILABLE - // SIMD优化部分 - constexpr size_t simdSize = 16; // 处理16字节的输入 - std::array inputBuffer{}; - std::array outputBuffer{}; - - while (std::distance(it, end) >= simdSize && - *std::next(it, simdSize - 1) != '=') { - std::copy_n(it, simdSize, inputBuffer.begin()); - -#if defined(__x86_64__) || defined(_M_X64) - // x86 SIMD实现 - __m128i input = _mm_loadu_si128( - reinterpret_cast(inputBuffer.data())); - __m128i lookup = _mm_setr_epi8( - REVERSE_LOOKUP[inputBuffer[0]], REVERSE_LOOKUP[inputBuffer[1]], - REVERSE_LOOKUP[inputBuffer[2]], REVERSE_LOOKUP[inputBuffer[3]], - REVERSE_LOOKUP[inputBuffer[4]], REVERSE_LOOKUP[inputBuffer[5]], - REVERSE_LOOKUP[inputBuffer[6]], REVERSE_LOOKUP[inputBuffer[7]], - REVERSE_LOOKUP[inputBuffer[8]], REVERSE_LOOKUP[inputBuffer[9]], - REVERSE_LOOKUP[inputBuffer[10]], REVERSE_LOOKUP[inputBuffer[11]], - REVERSE_LOOKUP[inputBuffer[12]], REVERSE_LOOKUP[inputBuffer[13]], - REVERSE_LOOKUP[inputBuffer[14]], REVERSE_LOOKUP[inputBuffer[15]]); - - __m128i merged = _mm_or_si128( - _mm_or_si128(_mm_slli_epi32(lookup, 18), - _mm_slli_epi32(_mm_and_si128(_mm_srli_epi32(lookup, 8), - _mm_set1_epi32(0x3F)), - 12)), - _mm_or_si128( - _mm_slli_epi32(_mm_and_si128(_mm_srli_epi32(lookup, 16), - _mm_set1_epi32(0x3F)), - 6), - _mm_and_si128(_mm_srli_epi32(lookup, 24), - _mm_set1_epi32(0x3F)))); - - __m128i result = - _mm_shuffle_epi8(merged, _mm_setr_epi8(2, 1, 0, 6, 5, 4, 10, 9, 8, - 14, 13, 12, -1, -1, -1, -1)); - _mm_storeu_si128(reinterpret_cast<__m128i*>(outputBuffer.data()), - result); - -#elif defined(__ARM_NEON) - // ARM NEON实现 - uint8x16_t input = vld1q_u8(inputBuffer.data()); - uint8x16_t lookup = vcreate_u8(0); - for (int j = 0; j < 16; ++j) { - lookup = vsetq_lane_u8(REVERSE_LOOKUP[inputBuffer[j]], lookup, j); - } - - uint32x4_t merged = vorrq_u32( - vorrq_u32( - vshlq_n_u32(vreinterpretq_u32_u8(lookup), 18), - vshlq_n_u32( - vandq_u32(vshrq_n_u32(vreinterpretq_u32_u8(lookup), 8), - vdupq_n_u32(0x3F)), - 12)), - vorrq_u32( - vshlq_n_u32( - vandq_u32(vshrq_n_u32(vreinterpretq_u32_u8(lookup), 16), - vdupq_n_u32(0x3F)), - 6), - vandq_u32(vshrq_n_u32(vreinterpretq_u32_u8(lookup), 24), - vdupq_n_u32(0x3F)))); - - uint8x16_t result = vqtbl1q_u8(vreinterpretq_u8_u32(merged), - vld1q_u8({2, 1, 0, 6, 5, 4, 10, 9, 8, 14, - 13, 12, 255, 255, 255, 255})); - vst1q_u8(outputBuffer.data(), result); -#endif - - std::copy_n(outputBuffer.begin(), (simdSize / 4) * 3, dest); - std::advance(dest, (simdSize / 4) * 3); - std::advance(it, simdSize); - } -#endif - - for (; it != end && *it != '='; ++it) { - charArray4[i++] = REVERSE_LOOKUP[static_cast(*it)]; - if (i == 4) { - charArray3[0] = - (charArray4[0] << 2) + ((charArray4[1] & 0x30) >> 4); - charArray3[1] = - ((charArray4[1] & 0xf) << 4) + ((charArray4[2] & 0x3c) >> 2); - charArray3[2] = ((charArray4[2] & 0x3) << 6) + charArray4[3]; - - for (i = 0; i < 3; ++i) { - *dest++ = charArray3[i]; - } - i = 0; - } - } - - if (i != 0) { - for (size_t j = i; j < 4; ++j) { - charArray4[j] = 0; - } - - charArray3[0] = (charArray4[0] << 2) + ((charArray4[1] & 0x30) >> 4); - charArray3[1] = - ((charArray4[1] & 0xf) << 4) + ((charArray4[2] & 0x3c) >> 2); - - for (size_t j = 0; j < i - 1; ++j) { - *dest++ = charArray3[j]; - } - } -} -} // namespace detail - -auto base64Encode(std::string_view bytes_to_encode) -> std::string { - std::string ret; - ret.reserve((bytes_to_encode.size() + 2) / 3 * 4); - - if (HAS_OPEN_CL) { -#if USE_OPENCL - detail::base64EncodeOpenCL( - reinterpret_cast(bytes_to_encode.data()), - ret.data(), bytes_to_encode.size()); -#endif - } else { - detail::base64Encode(bytes_to_encode.begin(), bytes_to_encode.end(), - std::back_inserter(ret)); - } - - return ret; -} - -auto base64Decode(std::string_view encoded_string) -> std::string { - std::string ret; - ret.reserve(encoded_string.size() / 4 * 3); - - if (HAS_OPEN_CL) { -#if USE_OPENCL - detail::base64DecodeOpenCL(encoded_string.data(), - reinterpret_cast(ret.data()), - encoded_string.size()); -#endif - } else { - detail::base64Decode(encoded_string.begin(), encoded_string.end(), - std::back_inserter(ret)); - } - - return ret; -} - -auto fbase64Encode(std::span input) -> std::string { - std::string output; - output.reserve((input.size() + 2) / 3 * 4); - - if (HAS_OPEN_CL) { -#if USE_OPENCL - detail::base64EncodeOpenCL(input.data(), output.data(), input.size()); -#endif - } else { - detail::base64Encode(input.begin(), input.end(), - std::back_inserter(output)); - } - - return output; -} - -auto fbase64Decode(std::span input) -> std::vector { - if (input.size() % 4 != 0) { - THROW_INVALID_ARGUMENT("Invalid base64 input length"); - } - - std::vector output; - output.reserve(input.size() / 4 * 3); - - if (HAS_OPEN_CL) { -#if USE_OPENCL - detail::base64DecodeOpenCL(input.data(), output.data(), input.size()); -#endif - } else { - detail::base64Decode(input.begin(), input.end(), - std::back_inserter(output)); - } - - return output; -} - -auto xorEncrypt(std::string_view plaintext, uint8_t key) -> std::string { - std::string ciphertext; - ciphertext.reserve(plaintext.size()); - - if (HAS_OPEN_CL) { -#if USE_OPENCL - detail::xorEncryptOpenCL(plaintext.data(), ciphertext.data(), key, - plaintext.size()); -#endif - } else { - for (char c : plaintext) { - ciphertext.push_back( - static_cast(static_cast(c) ^ key)); - } - } - - return ciphertext; -} - -auto xorDecrypt(std::string_view ciphertext, uint8_t key) -> std::string { - return xorEncrypt(ciphertext, key); -} - -constexpr std::string_view BASE32_ALPHABET = "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567"; -constexpr int BITS_PER_BYTE = 8; -constexpr int BITS_PER_BASE32_CHAR = 5; -constexpr uint32_t BASE32_MASK = 0x1F; -constexpr uint32_t BYTE_MASK = 0xFF; - -#ifdef USE_MP -#pragma omp parallel for -#endif -auto encodeBase32(const std::vector& data) -> std::string { - std::string encoded; - size_t bitCount = 0; - uint32_t buffer = 0; - -#ifdef USE_SIMD - size_t simdChunkSize = SIMD_WIDTH / BITS_PER_BYTE; // 每个SIMD块的字节数 - - for (size_t i = 0; i + simdChunkSize <= data.size(); i += simdChunkSize) { -#if defined(USE_AVX) || defined(__AVX2__) - __m256i simdData = - _mm256_loadu_si256(reinterpret_cast(&data[i])); - uint32_t simdVal = _mm256_extract_epi32(simdData, 0); // 提取低32位 -#elif defined(USE_SSE) || defined(__SSE4_1__) - __m128i simdData = - _mm_loadu_si128(reinterpret_cast(&data[i])); - uint32_t simdVal = _mm_extract_epi32(simdData, 0); // 提取低32位 -#elif defined(USE_NEON) || defined(__ARM_NEON) - uint8x16_t simdData = vld1q_u8(&data[i]); - uint32_t simdVal = vgetq_lane_u32(vreinterpretq_u32_u8(simdData), 0); -#endif - - for (int j = 0; j < 5; ++j) { - uint8_t index = - (simdVal >> (27 - j * BITS_PER_BASE32_CHAR)) & BASE32_MASK; - encoded += BASE32_ALPHABET[index]; - } - } - - for (size_t i = (data.size() / simdChunkSize) * simdChunkSize; - i < data.size(); ++i) { - buffer = (buffer << BITS_PER_BYTE) | data[i]; - bitCount += BITS_PER_BYTE; - while (bitCount >= BITS_PER_BASE32_CHAR) { - bitCount -= BITS_PER_BASE32_CHAR; - encoded += BASE32_ALPHABET[(buffer >> bitCount) & BASE32_MASK]; - } - } -#else - // 非SIMD编码流程 -#ifdef USE_MP -#pragma omp parallel for -#endif - for (uint8_t byte : data) { - buffer = (buffer << BITS_PER_BYTE) | byte; - bitCount += BITS_PER_BYTE; - while (bitCount >= BITS_PER_BASE32_CHAR) { - bitCount -= BITS_PER_BASE32_CHAR; - encoded += BASE32_ALPHABET[(buffer >> bitCount) & BASE32_MASK]; - } - } -#endif - - if (bitCount > 0) { - encoded += - BASE32_ALPHABET[(buffer << (BITS_PER_BASE32_CHAR - bitCount)) & - BASE32_MASK]; - } - - while (encoded.size() % 8 != 0) { - encoded += '='; - } - - return encoded; -} - -// 解码函数 -#ifdef USE_MP -#pragma omp parallel for -#endif -auto decodeBase32(const std::string& encoded) -> std::vector { - std::vector decoded; - size_t bitCount = 0; - uint32_t buffer = 0; - -#ifdef USE_SIMD - size_t simdChunkSize = SIMD_WIDTH / BITS_PER_BYTE; - - for (size_t i = 0; i + simdChunkSize <= encoded.size(); - i += simdChunkSize) { -#if defined(USE_AVX) || defined(__AVX2__) - __m256i simdEncoded = - _mm256_loadu_si256(reinterpret_cast(&encoded[i])); -#elif defined(USE_SSE) || defined(__SSE4_1__) - __m128i simdEncoded = - _mm_loadu_si128(reinterpret_cast(&encoded[i])); -#elif defined(USE_NEON) || defined(__ARM_NEON) - uint8x16_t simdEncoded = - vld1q_u8(reinterpret_cast(&encoded[i])); -#endif - - for (int j = 0; j < simdChunkSize; ++j) { - int idx = BASE32_ALPHABET.find(encoded[i + j]); - if (idx == std::string::npos) { - throw std::invalid_argument("无效字符在Base32编码中"); - } - buffer = (buffer << BITS_PER_BASE32_CHAR) | idx; - bitCount += BITS_PER_BASE32_CHAR; - if (bitCount >= BITS_PER_BYTE) { - bitCount -= BITS_PER_BYTE; - decoded.push_back( - static_cast((buffer >> bitCount) & BYTE_MASK)); - } - } - } -#else - for (char character : encoded) { - if (character == '=') { - break; - } - auto index = BASE32_ALPHABET.find(character); - if (index == std::string::npos) { - THROW_INVALID_ARGUMENT("Invalid character in Base32 encoding"); - } - - buffer = (buffer << BITS_PER_BASE32_CHAR) | index; - bitCount += BITS_PER_BASE32_CHAR; - if (bitCount >= BITS_PER_BYTE) { - bitCount -= BITS_PER_BYTE; - decoded.push_back( - static_cast((buffer >> bitCount) & BYTE_MASK)); - } - } -#endif - - return decoded; -} - -#ifdef USE_CL -// 读取OpenCL内核文件 -auto readKernelSource(const std::string& filename) -> std::string { - std::ifstream file(filename); - if (!file.is_open()) { - throw std::runtime_error("无法打开内核文件"); - } - std::stringstream buffer; - buffer << file.rdbuf(); - return buffer.str(); -} - -// 使用OpenCL进行Base32编码 -auto encodeBase32CL(const std::vector& data) -> std::string { - // OpenCL平台和设备初始化 - std::vector platforms; - cl::Platform::get(&platforms); - if (platforms.empty()) { - throw std::runtime_error("没有可用的OpenCL平台"); - } - - // 选择第一个平台和设备 - cl::Platform platform = platforms[0]; - std::vector devices; - platform.getDevices(CL_DEVICE_TYPE_GPU, &devices); - if (devices.empty()) { - throw std::runtime_error("没有可用的GPU设备"); - } - cl::Device device = devices[0]; - - // 创建OpenCL上下文和命令队列 - cl::Context context(device); - cl::CommandQueue queue(context, device); - - // 读取内核源代码 - std::string kernelSource = readKernelSource("base32_encode_kernel.cl"); - cl::Program::Sources sources(1, std::make_pair(kernelSource.c_str(), kernelSource.size())); - - // 构建程序 - cl::Program program(context, sources); - if (program.build({device}) != CL_SUCCESS) { - throw std::runtime_error("内核程序构建失败"); - } - - // 分配输入和输出缓冲区 - size_t dataSize = data.size(); - size_t encodedSize = ((dataSize * 8) + 4) / 5; // Base32输出大小 - - cl::Buffer inputBuffer(context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, dataSize, (void*)data.data()); - cl::Buffer outputBuffer(context, CL_MEM_WRITE_ONLY, encodedSize); - - // 设置内核参数 - cl::Kernel kernel(program, "base32_encode"); - kernel.setArg(0, inputBuffer); - kernel.setArg(1, outputBuffer); - kernel.setArg(2, static_cast(dataSize)); - - // 执行内核 - cl::NDRange global(dataSize); // 数据大小定义全局工作量 - queue.enqueueNDRangeKernel(kernel, cl::NullRange, global, cl::NullRange); - queue.finish(); - - // 读取结果 - std::vector encoded(encodedSize); - queue.enqueueReadBuffer(outputBuffer, CL_TRUE, 0, encodedSize, encoded.data()); - - // 将编码结果转成字符串 - return std::string(encoded.begin(), encoded.end()); -} -#endif -} // namespace atom::algorithm diff --git a/src/atom/algorithm/base.hpp b/src/atom/algorithm/base.hpp deleted file mode 100644 index 3985d93e..00000000 --- a/src/atom/algorithm/base.hpp +++ /dev/null @@ -1,207 +0,0 @@ -/* - * base.hpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2023-4-5 - -Description: A collection of algorithms for C++ - -**************************************************/ - -#ifndef ATOM_ALGORITHM_BASE16_HPP -#define ATOM_ALGORITHM_BASE16_HPP - -#include -#include -#include -#include - -#include "atom/type/static_string.hpp" - -#include "atom/macro.hpp" - -namespace atom::algorithm { -namespace detail { -constexpr std::string_view BASE64_CHARS = - "ABCDEFGHIJKLMNOPQRSTUVWXYZ" - "abcdefghijklmnopqrstuvwxyz" - "0123456789+/"; - -constexpr size_t BASE64_CHAR_COUNT = 64; -constexpr uint8_t MASK_6_BITS = 0x3F; -constexpr uint8_t MASK_4_BITS = 0x0F; -constexpr uint8_t MASK_2_BITS = 0x03; -constexpr uint8_t MASK_8_BITS = 0xFC; -constexpr uint8_t MASK_12_BITS = 0xF0; -constexpr uint8_t MASK_14_BITS = 0xC0; -constexpr uint8_t MASK_16_BITS = 0x30; -constexpr uint8_t MASK_18_BITS = 0x3C; -} // namespace detail - -/** - * @brief Base64编码函数 - * - * @param bytes_to_encode 待编码数据 - * @return std::string 编码后的字符串 - */ -[[nodiscard("The result of base64Encode is not used.")]] auto base64Encode( - std::string_view bytes_to_encode) -> std::string; - -/** - * @brief Base64解码函数 - * - * @param encoded_string 待解码字符串 - * @return std::vector 解码后的数据 - */ -[[nodiscard("The result of base64Decode is not used.")]] auto base64Decode( - std::string_view encoded_string) -> std::string; - -/** - * @brief Faster Base64 Encode - * - * @param input - * @return std::string - */ -auto fbase64Encode(std::span input) -> std::string; - -/** - * @brief Faster Base64 Decode - * - * @param input - * @return std::vector - */ -auto fbase64Decode(std::span input) -> std::vector; - -/** - * @brief Encrypts a string using the XOR algorithm. - * - * @param data The string to be encrypted. - * @param key The encryption key. - * @return The encrypted string. - */ -[[nodiscard("The result of xorEncrypt is not used.")]] auto xorEncrypt( - std::string_view plaintext, uint8_t key) -> std::string; - -/** - * @brief Decrypts a string using the XOR algorithm. - * - * @param data The string to be decrypted. - * @param key The decryption key. - * @return The decrypted string. - */ -[[nodiscard("The result of xorDecrypt is not used.")]] auto xorDecrypt( - std::string_view ciphertext, uint8_t key) -> std::string; - -ATOM_INLINE constexpr auto findBase64Char(char character) -> size_t { - for (size_t index = 0; index < detail::BASE64_CHAR_COUNT; ++index) { - if (detail::BASE64_CHARS[index] == character) { - return index; - } - } - return detail::BASE64_CHAR_COUNT; // Indicates not found, should not happen - // with valid input -} - -template -constexpr auto cbase64Encode(const StaticString &input) { - constexpr size_t ENCODED_SIZE = ((N + 2) / 3) * 4; - StaticString ret; - - auto addCharacter = [&](char character) constexpr { ret += character; }; - - std::array charArray3{}; - std::array charArray4{}; - - size_t index = 0; - for (auto it = input.begin(); it != input.end(); ++it, ++index) { - charArray3[index % 3] = static_cast(*it); - if (index % 3 == 2) { - charArray4[0] = (charArray3[0] & detail::MASK_8_BITS) >> 2; - charArray4[1] = ((charArray3[0] & detail::MASK_2_BITS) << 4) + - ((charArray3[1] & detail::MASK_12_BITS) >> 4); - charArray4[2] = ((charArray3[1] & detail::MASK_4_BITS) << 2) + - ((charArray3[2] & detail::MASK_14_BITS) >> 6); - charArray4[3] = charArray3[2] & detail::MASK_6_BITS; - - for (int j = 0; j < 4; ++j) { - addCharacter(detail::BASE64_CHARS[charArray4[j]]); - } - } - } - - if (index % 3 != 0) { - for (size_t j = index % 3; j < 3; ++j) { - charArray3[j] = '\0'; - } - - charArray4[0] = (charArray3[0] & detail::MASK_8_BITS) >> 2; - charArray4[1] = ((charArray3[0] & detail::MASK_2_BITS) << 4) + - ((charArray3[1] & detail::MASK_12_BITS) >> 4); - charArray4[2] = ((charArray3[1] & detail::MASK_4_BITS) << 2) + - ((charArray3[2] & detail::MASK_14_BITS) >> 6); - charArray4[3] = charArray3[2] & detail::MASK_6_BITS; - - for (size_t j = 0; j < index % 3 + 1; ++j) { - addCharacter(detail::BASE64_CHARS[charArray4[j]]); - } - - while (index++ % 3 != 0) { - addCharacter('='); - } - } - - return ret; -} - -template -constexpr auto cbase64Decode(const StaticString &input) { - constexpr size_t DECODED_SIZE = (N / 4) * 3; - StaticString ret; - - auto addCharacter = [&](char character) constexpr { ret += character; }; - - std::array charArray4{}; - std::array charArray3{}; - - size_t index = 0; - for (auto it = input.begin(); it != input.end() && *it != '='; ++it) { - charArray4[index++] = static_cast(findBase64Char(*it)); - if (index == 4) { - charArray3[0] = (charArray4[0] << 2) + - ((charArray4[1] & detail::MASK_16_BITS) >> 4); - charArray3[1] = ((charArray4[1] & detail::MASK_4_BITS) << 4) + - ((charArray4[2] & detail::MASK_18_BITS) >> 2); - charArray3[2] = - ((charArray4[2] & detail::MASK_2_BITS) << 6) + charArray4[3]; - - for (index = 0; index < 3; ++index) { - addCharacter(static_cast(charArray3[index])); - } - index = 0; - } - } - - if (index != 0) { - for (size_t j = index; j < 4; ++j) { - charArray4[j] = 0; - } - - charArray3[0] = (charArray4[0] << 2) + - ((charArray4[1] & detail::MASK_16_BITS) >> 4); - charArray3[1] = ((charArray4[1] & detail::MASK_4_BITS) << 4) + - ((charArray4[2] & detail::MASK_18_BITS) >> 2); - - for (size_t j = 0; j < index - 1; ++j) { - addCharacter(static_cast(charArray3[j])); - } - } - - return ret; -} -} // namespace atom::algorithm - -#endif diff --git a/src/atom/algorithm/bignumber.cpp b/src/atom/algorithm/bignumber.cpp deleted file mode 100644 index 82110086..00000000 --- a/src/atom/algorithm/bignumber.cpp +++ /dev/null @@ -1,300 +0,0 @@ -#include "bignumber.hpp" - -#include -#include - -#include "atom/error/exception.hpp" -#include "atom/log/loguru.hpp" - -namespace atom::algorithm { -auto BigNumber::add(const BigNumber& other) const -> BigNumber { - try { - LOG_F(INFO, "Adding {} and {}", this->numberString_, - other.numberString_); - if (isNegative() && other.isNegative()) { - LOG_F(INFO, "Both numbers are negative. Negating and adding."); - return negate().add(other.negate()).negate(); - } - if (isNegative()) { - LOG_F(INFO, "First number is negative. Performing subtraction."); - return other.subtract(abs()); - } - if (other.isNegative()) { - LOG_F(INFO, "Second number is negative. Performing subtraction."); - return subtract(other.abs()); - } - - std::string result; - int carry = 0; - int i = static_cast(numberString_.length()) - 1; - int j = static_cast(other.numberString_.length()) - 1; - - while (i >= 0 || j >= 0 || carry != 0) { - int digit1 = (i >= 0) ? numberString_[i--] - '0' : 0; - int digit2 = (j >= 0) ? other.numberString_[j--] - '0' : 0; - int sum = digit1 + digit2 + carry; - result.insert(result.begin(), '0' + (sum % 10)); - carry = sum / 10; - } - - LOG_F(INFO, "Result of addition: {}", result); - return BigNumber(result).trimLeadingZeros(); - } catch (const std::exception& e) { - LOG_F(ERROR, "Exception in BigNumber::add: {}", e.what()); - throw; - } -} - -auto BigNumber::subtract(const BigNumber& other) const -> BigNumber { - try { - LOG_F(INFO, "Subtracting {} from {}", other.numberString_, - this->numberString_); - if (isNegative() && other.isNegative()) { - LOG_F(INFO, "Both numbers are negative. Adjusting subtraction."); - return other.negate().subtract(negate()); - } - if (isNegative()) { - LOG_F( - INFO, - "First number is negative. Performing addition with negation."); - return negate().add(other).negate(); - } - if (other.isNegative()) { - LOG_F(INFO, "Second number is negative. Performing addition."); - return add(other.negate()); - } - if (*this < other) { - LOG_F(INFO, "Result will be negative."); - return other.subtract(*this).negate(); - } - - std::string result; - int carry = 0; - int i = static_cast(numberString_.length()) - 1; - int j = static_cast(other.numberString_.length()) - 1; - - while (i >= 0 || j >= 0) { - int digit1 = (i >= 0) ? numberString_[i--] - '0' : 0; - int digit2 = (j >= 0) ? other.numberString_[j--] - '0' : 0; - int diff = digit1 - digit2 - carry; - if (diff < 0) { - diff += 10; - carry = 1; - } else { - carry = 0; - } - result.insert(result.begin(), '0' + diff); - } - - LOG_F(INFO, "Result of subtraction before trimming: {}", result); - return BigNumber(result).trimLeadingZeros(); - } catch (const std::exception& e) { - LOG_F(ERROR, "Exception in BigNumber::subtract: {}", e.what()); - throw; - } -} - -auto BigNumber::multiply(const BigNumber& other) const -> BigNumber { - try { - LOG_F(INFO, "Multiplying {} and {}", this->numberString_, - other.numberString_); - if (*this == BigNumber("0") || other == BigNumber("0")) { - LOG_F(INFO, "One of the numbers is zero. Result is 0."); - return BigNumber("0"); - } - - bool resultNegative = isNegative() != other.isNegative(); - BigNumber b1 = abs(); - BigNumber b2 = other.abs(); - - std::vector result( - b1.numberString_.size() + b2.numberString_.size(), 0); - - for (int i = static_cast(b1.numberString_.size()) - 1; i >= 0; - --i) { - for (int j = static_cast(b2.numberString_.size()) - 1; j >= 0; - --j) { - int mul = - (b1.numberString_[i] - '0') * (b2.numberString_[j] - '0'); - int sum = mul + result[i + j + 1]; - - result[i + j + 1] = sum % 10; - result[i + j] += sum / 10; - } - } - - std::string resultStr; - bool started = false; - for (int num : result) { - if (!started && num == 0) - continue; - started = true; - resultStr.push_back(num + '0'); - } - - if (resultStr.empty()) { - resultStr = "0"; - } - - if (resultNegative && resultStr != "0") { - resultStr.insert(resultStr.begin(), '-'); - } - - LOG_F(INFO, "Result of multiplication: {}", resultStr); - return {resultStr}; - } catch (const std::exception& e) { - LOG_F(ERROR, "Exception in BigNumber::multiply: {}", e.what()); - throw; - } -} - -auto BigNumber::divide(const BigNumber& other) const -> BigNumber { - try { - LOG_F(INFO, "Dividing {} by {}", this->numberString_, - other.numberString_); - if (other == BigNumber("0")) { - LOG_F(ERROR, "Division by zero"); - THROW_INVALID_ARGUMENT("Division by zero"); - } - - bool resultNegative = isNegative() != other.isNegative(); - BigNumber dividend = abs(); - BigNumber divisor = other.abs(); - BigNumber quotient("0"); - BigNumber current("0"); - - for (char digit : dividend.numberString_) { - current = current.multiply(BigNumber("10")) - .add(BigNumber(std::string(1, digit))); - int count = 0; - while (current >= divisor) { - current = current.subtract(divisor); - ++count; - } - quotient = quotient.multiply(BigNumber("10")) - .add(BigNumber(std::to_string(count))); - } - - quotient = quotient.trimLeadingZeros(); - if (resultNegative && quotient != BigNumber("0")) { - quotient = quotient.negate(); - } - - LOG_F(INFO, "Result of division: {}", quotient.numberString_); - return quotient; - } catch (const std::exception& e) { - LOG_F(ERROR, "Exception in BigNumber::divide: {}", e.what()); - throw; - } -} - -auto BigNumber::pow(int exponent) const -> BigNumber { - try { - LOG_F(INFO, "Raising {} to the power of {}", this->numberString_, - exponent); - if (exponent < 0) { - LOG_F(ERROR, "Negative exponents are not supported"); - THROW_INVALID_ARGUMENT("Negative exponents are not supported"); - } - if (exponent == 0) { - return BigNumber("1"); - } - if (exponent == 1) { - return *this; - } - BigNumber result("1"); - BigNumber base = *this; - while (exponent != 0) { - if (exponent & 1) { - result = result.multiply(base); - } - exponent >>= 1; - if (exponent != 0) { - base = base.multiply(base); - } - } - LOG_F(INFO, "Result of exponentiation: {}", result.numberString_); - return result; - } catch (const std::exception& e) { - LOG_F(ERROR, "Exception in BigNumber::pow: {}", e.what()); - throw; - } -} - -auto BigNumber::trimLeadingZeros() const -> BigNumber { - try { - LOG_F(INFO, "Trimming leading zeros from {}", this->numberString_); - std::string trimmed = numberString_; - bool negative = false; - size_t start = 0; - - if (!trimmed.empty() && trimmed[0] == '-') { - negative = true; - start = 1; - } - - // Find the position of the first non-zero character - size_t nonZeroPos = trimmed.find_first_not_of('0', start); - if (nonZeroPos == std::string::npos) { - // The number is zero - return BigNumber("0"); - } - - trimmed = trimmed.substr(nonZeroPos); - if (negative) { - trimmed.insert(trimmed.begin(), '-'); - } - - LOG_F(INFO, "Trimmed number: {}", trimmed); - return {trimmed}; - } catch (const std::exception& e) { - LOG_F(ERROR, "Exception in BigNumber::trimLeadingZeros: {}", e.what()); - throw; - } -} - -auto operator>(const BigNumber& b1, const BigNumber& b2) -> bool { - try { - LOG_F(INFO, "Comparing if {} > {}", b1.numberString_, b2.numberString_); - if (b1.isNegative() || b2.isNegative()) { - if (b1.isNegative() && b2.isNegative()) { - LOG_F(INFO, "Both numbers are negative. Flipping comparison."); - return atom::algorithm::BigNumber(b2).abs() > - atom::algorithm::BigNumber(b1).abs(); - } - return b1.isNegative() < b2.isNegative(); - } - - BigNumber b1Trimmed = b1.trimLeadingZeros(); - BigNumber b2Trimmed = b2.trimLeadingZeros(); - - if (b1Trimmed.numberString_.size() != b2Trimmed.numberString_.size()) { - return b1Trimmed.numberString_.size() > - b2Trimmed.numberString_.size(); - } - return b1Trimmed.numberString_ > b2Trimmed.numberString_; - } catch (const std::exception& e) { - LOG_F(ERROR, "Exception in operator>: {}", e.what()); - throw; - } -} - -void BigNumber::validate() const { - if (numberString_.empty()) { - THROW_INVALID_ARGUMENT("Empty string is not a valid number"); - } - size_t start = 0; - if (numberString_[0] == '-') { - if (numberString_.size() == 1) { - THROW_INVALID_ARGUMENT("Invalid number format"); - } - start = 1; - } - for (size_t i = start; i < numberString_.size(); ++i) { - if (std::isdigit(numberString_[i]) == 0) { - THROW_INVALID_ARGUMENT("Invalid character in number string"); - } - } -} - -} // namespace atom::algorithm \ No newline at end of file diff --git a/src/atom/algorithm/bignumber.hpp b/src/atom/algorithm/bignumber.hpp deleted file mode 100644 index cb90b210..00000000 --- a/src/atom/algorithm/bignumber.hpp +++ /dev/null @@ -1,404 +0,0 @@ -#ifndef ATOM_ALGORITHM_BIGNUMBER_HPP -#define ATOM_ALGORITHM_BIGNUMBER_HPP - -#include -#include -#include -#include -#include -#include -#include "atom/macro.hpp" - -namespace atom::algorithm { - -/** - * @class BigNumber - * @brief A class to represent and manipulate large numbers. - */ -class BigNumber { -public: - /** - * @brief Constructs a BigNumber from a string. - * @param number The string representation of the number. - */ - BigNumber(std::string number) : numberString_(std::move(number)) { - numberString_ = trimLeadingZeros().numberString_; - validate(); - } - - /** - * @brief Constructs a BigNumber from a long long integer. - * @param number The long long integer representation of the number. - */ - BigNumber(long long number) : numberString_(std::to_string(number)) {} - - /** - * @brief Adds two BigNumber objects. - * @param other The other BigNumber to add. - * @return The result of the addition. - */ - ATOM_NODISCARD auto add(const BigNumber& other) const -> BigNumber; - - /** - * @brief Subtracts one BigNumber from another. - * @param other The other BigNumber to subtract. - * @return The result of the subtraction. - */ - ATOM_NODISCARD auto subtract(const BigNumber& other) const -> BigNumber; - - /** - * @brief Multiplies two BigNumber objects. - * @param other The other BigNumber to multiply. - * @return The result of the multiplication. - */ - ATOM_NODISCARD auto multiply(const BigNumber& other) const -> BigNumber; - - /** - * @brief Divides one BigNumber by another. - * @param other The other BigNumber to divide by. - * @return The result of the division. - */ - ATOM_NODISCARD auto divide(const BigNumber& other) const -> BigNumber; - - /** - * @brief Raises the BigNumber to the power of an exponent. - * @param exponent The exponent to raise the number to. - * @return The result of the exponentiation. - */ - ATOM_NODISCARD auto pow(int exponent) const -> BigNumber; - - /** - * @brief Gets the string representation of the BigNumber. - * @return The string representation of the number. - */ - ATOM_NODISCARD auto getString() const -> std::string { - return numberString_; - } - - /** - * @brief Sets the string representation of the BigNumber. - * @param newStr The new string representation of the number. - * @return A reference to the updated BigNumber. - */ - auto setString(const std::string& newStr) -> BigNumber { - numberString_ = newStr; - numberString_ = trimLeadingZeros().numberString_; - validate(); - return *this; - } - - /** - * @brief Negates the BigNumber. - * @return The negated BigNumber. - */ - ATOM_NODISCARD auto negate() const -> BigNumber { - if (isNegative()) { - return BigNumber(numberString_.substr(1)); - } else { - return BigNumber("-" + numberString_); - } - } - - /** - * @brief Trims leading zeros from the BigNumber. - * @return The BigNumber with leading zeros removed. - */ - ATOM_NODISCARD auto trimLeadingZeros() const -> BigNumber; - - /** - * @brief Checks if two BigNumber objects are equal. - * @param other The other BigNumber to compare with. - * @return True if the numbers are equal, false otherwise. - */ - ATOM_NODISCARD auto equals(const BigNumber& other) const -> bool { - return numberString_ == other.numberString_; - } - - /** - * @brief Checks if the BigNumber is equal to a long long integer. - * @param other The long long integer to compare with. - * @return True if the number is equal to the integer, false otherwise. - */ - ATOM_NODISCARD auto equals(const long long& other) const -> bool { - return numberString_ == std::to_string(other); - } - - /** - * @brief Checks if the BigNumber is equal to a string. - * @param other The string to compare with. - * @return True if the number is equal to the string, false otherwise. - */ - ATOM_NODISCARD auto equals(const std::string& other) const -> bool { - return numberString_ == other; - } - - /** - * @brief Gets the number of digits in the BigNumber. - * @return The number of digits. - */ - ATOM_NODISCARD auto digits() const -> unsigned int { - return numberString_.length() - (isNegative() ? 1 : 0); - } - - /** - * @brief Checks if the BigNumber is negative. - * @return True if the number is negative, false otherwise. - */ - ATOM_NODISCARD auto isNegative() const -> bool { - return !numberString_.empty() && numberString_[0] == '-'; - } - - /** - * @brief Checks if the BigNumber is positive. - * @return True if the number is positive, false otherwise. - */ - ATOM_NODISCARD auto isPositive() const -> bool { return !isNegative(); } - - /** - * @brief Checks if the BigNumber is even. - * @return True if the number is even, false otherwise. - */ - ATOM_NODISCARD auto isEven() const -> bool { - if (numberString_.empty()) - return false; - return (numberString_.back() - '0') % 2 == 0; - } - - /** - * @brief Checks if the BigNumber is odd. - * @return True if the number is odd, false otherwise. - */ - ATOM_NODISCARD auto isOdd() const -> bool { return !isEven(); } - - /** - * @brief Gets the absolute value of the BigNumber. - * @return The absolute value of the number. - */ - ATOM_NODISCARD auto abs() const -> BigNumber { - return isNegative() ? BigNumber(numberString_.substr(1)) : *this; - } - - /** - * @brief Overloads the stream insertion operator for BigNumber. - * @param os The output stream. - * @param num The BigNumber to insert into the stream. - * @return The output stream. - */ - friend auto operator<<(std::ostream& os, - const BigNumber& num) -> std::ostream& { - os << num.numberString_; - return os; - } - - /** - * @brief Overloads the addition operator for BigNumber. - * @param b1 The first BigNumber. - * @param b2 The second BigNumber. - * @return The result of the addition. - */ - friend auto operator+(const BigNumber& b1, - const BigNumber& b2) -> BigNumber { - return b1.add(b2); - } - - /** - * @brief Overloads the subtraction operator for BigNumber. - * @param b1 The first BigNumber. - * @param b2 The second BigNumber. - * @return The result of the subtraction. - */ - friend auto operator-(const BigNumber& b1, - const BigNumber& b2) -> BigNumber { - return b1.subtract(b2); - } - - /** - * @brief Overloads the multiplication operator for BigNumber. - * @param b1 The first BigNumber. - * @param b2 The second BigNumber. - * @return The result of the multiplication. - */ - friend auto operator*(const BigNumber& b1, - const BigNumber& b2) -> BigNumber { - return b1.multiply(b2); - } - - /** - * @brief Overloads the division operator for BigNumber. - * @param b1 The first BigNumber. - * @param b2 The second BigNumber. - * @return The result of the division. - */ - friend auto operator/(const BigNumber& b1, - const BigNumber& b2) -> BigNumber { - return b1.divide(b2); - } - - /** - * @brief Overloads the exponentiation operator for BigNumber. - * @param b1 The BigNumber base. - * @param b2 The exponent. - * @return The result of the exponentiation. - */ - friend auto operator^(const BigNumber& b1, int b2) -> BigNumber { - return b1.pow(b2); - } - - /** - * @brief Overloads the equality operator for BigNumber. - * @param b1 The first BigNumber. - * @param b2 The second BigNumber. - * @return True if the numbers are equal, false otherwise. - */ - friend auto operator==(const BigNumber& b1, const BigNumber& b2) -> bool { - return b1.equals(b2); - } - - /** - * @brief Overloads the greater than operator for BigNumber. - * @param b1 The first BigNumber. - * @param b2 The second BigNumber. - * @return True if the first number is greater than the second, false - * otherwise. - */ - friend auto operator>(const BigNumber& b1, const BigNumber& b2) -> bool; - - /** - * @brief Overloads the less than operator for BigNumber. - * @param b1 The first BigNumber. - * @param b2 The second BigNumber. - * @return True if the first number is less than the second, false - * otherwise. - */ - friend auto operator<(const BigNumber& b1, const BigNumber& b2) -> bool { - return !(b1 == b2) && !(b1 > b2); - } - - /** - * @brief Overloads the greater than or equal to operator for BigNumber. - * @param b1 The first BigNumber. - * @param b2 The second BigNumber. - * @return True if the first number is greater than or equal to the second, - * false otherwise. - */ - friend auto operator>=(const BigNumber& b1, const BigNumber& b2) -> bool { - return b1 > b2 || b1 == b2; - } - - /** - * @brief Overloads the less than or equal to operator for BigNumber. - * @param b1 The first BigNumber. - * @param b2 The second BigNumber. - * @return True if the first number is less than or equal to the second, - * false otherwise. - */ - friend auto operator<=(const BigNumber& b1, const BigNumber& b2) -> bool { - return b1 < b2 || b1 == b2; - } - - /** - * @brief Overloads the addition assignment operator for BigNumber. - * @param other The other BigNumber to add. - * @return A reference to the updated BigNumber. - */ - auto operator+=(const BigNumber& other) -> BigNumber& { - *this = *this + other; - return *this; - } - - /** - * @brief Overloads the subtraction assignment operator for BigNumber. - * @param other The other BigNumber to subtract. - * @return A reference to the updated BigNumber. - */ - auto operator-=(const BigNumber& other) -> BigNumber& { - *this = *this - other; - return *this; - } - - /** - * @brief Overloads the multiplication assignment operator for BigNumber. - * @param other The other BigNumber to multiply. - * @return A reference to the updated BigNumber. - */ - auto operator*=(const BigNumber& other) -> BigNumber& { - *this = *this * other; - return *this; - } - - /** - * @brief Overloads the division assignment operator for BigNumber. - * @param other The other BigNumber to divide by. - * @return A reference to the updated BigNumber. - */ - auto operator/=(const BigNumber& other) -> BigNumber& { - *this = *this / other; - return *this; - } - - /** - * @brief Overloads the prefix increment operator for BigNumber. - * @return A reference to the incremented BigNumber. - */ - auto operator++() -> BigNumber& { - *this += BigNumber("1"); - return *this; - } - - /** - * @brief Overloads the prefix decrement operator for BigNumber. - * @return A reference to the decremented BigNumber. - */ - auto operator--() -> BigNumber& { - *this -= BigNumber("1"); - return *this; - } - - /** - * @brief Overloads the postfix increment operator for BigNumber. - * @return The BigNumber before incrementing. - */ - auto operator++(int) -> BigNumber { - BigNumber temp(*this); - ++(*this); - return temp; - } - - /** - * @brief Overloads the postfix decrement operator for BigNumber. - * @return The BigNumber before decrementing. - */ - auto operator--(int) -> BigNumber { - BigNumber temp(*this); - --(*this); - return temp; - } - - /** - * @brief Overloads the subscript operator for BigNumber. - * @param index The index of the digit to access. - * @return The digit at the specified index. - */ - auto operator[](int index) const -> unsigned int { - if (index < 0 || index >= static_cast(numberString_.size())) { - throw std::out_of_range("Index out of range"); - } - if (isNegative() && index == 0) { - throw std::invalid_argument("Cannot access the negative sign"); - } - return static_cast(numberString_[index] - '0'); - } - -private: - std::string numberString_; ///< The string representation of the number. - - /** - * @brief Validates the BigNumber string. - * @throws std::invalid_argument if the string is not a valid number. - */ - void validate() const; -}; - -} // namespace atom::algorithm - -#endif // ATOM_ALGORITHM_BIGNUMBER_HPP \ No newline at end of file diff --git a/src/atom/algorithm/convolve.cpp b/src/atom/algorithm/convolve.cpp deleted file mode 100644 index dc16fb88..00000000 --- a/src/atom/algorithm/convolve.cpp +++ /dev/null @@ -1,905 +0,0 @@ -/* - * convolve.cpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2023-11-10 - -Description: Implementation of one-dimensional and two-dimensional convolution -and deconvolution with optional OpenCL support. - -**************************************************/ - -#include "convolve.hpp" - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#if USE_SIMD -#ifdef _MSC_VER -#include -#define SIMD_ALIGNED __declspec(align(32)) -#else -#include -#define SIMD_ALIGNED __attribute__((aligned(32))) -#endif - -#ifdef __AVX__ -#define SIMD_ENABLED -#define SIMD_WIDTH 4 -#elif defined(__SSE__) -#define SIMD_ENABLED -#define SIMD_WIDTH 2 -#endif -#endif - -#if USE_OPENCL -#include -#endif - -#ifdef __GNUC__ -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wsign-compare" -#elif defined(__clang__) -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wsign-compare" -#elif defined(_MSC_VER) -#pragma warning(push) -#pragma warning(disable : 4996) -#endif - -// Code that might generate warnings - -#ifdef __GNUC__ -#pragma GCC diagnostic pop -#elif defined(__clang__) -#pragma clang diagnostic pop -#elif defined(_MSC_VER) -#pragma warning(pop) -#endif - -#include "atom/error/exception.hpp" - -namespace atom::algorithm { - -// Function to convolve a 1D input with a kernel -auto convolve(const std::vector &input, - const std::vector &kernel) -> std::vector { - auto inputSize = input.size(); - auto kernelSize = kernel.size(); - auto outputSize = inputSize + kernelSize - 1; - std::vector output(outputSize, 0.0); - -#ifdef SIMD_ENABLED - const int simd_width = SIMD_WIDTH; - SIMD_ALIGNED double aligned_kernel[kernelSize]; - std::memcpy(aligned_kernel, kernel.data(), kernelSize * sizeof(double)); - - for (std::size_t i = 0; i < outputSize; i += simd_width) { - __m256d sum = _mm256_setzero_pd(); - - for (std::size_t j = 0; j < kernelSize; ++j) { - if (i >= j && (i - j + simd_width) <= inputSize) { - __m256d input_vec = _mm256_loadu_pd(&input[i - j]); - __m256d kernel_val = _mm256_set1_pd(aligned_kernel[j]); - sum = _mm256_add_pd(sum, _mm256_mul_pd(input_vec, kernel_val)); - } - } - - _mm256_storeu_pd(&output[i], sum); - } - - // Handle remaining elements - for (std::size_t i = (outputSize / simd_width) * simd_width; i < outputSize; - ++i) { - for (std::size_t j = 0; j < kernelSize; ++j) { - if (i >= j && (i - j) < inputSize) { - output[i] += input[i - j] * kernel[j]; - } - } - } -#else - // Fallback to non-SIMD version - for (std::size_t i = 0; i < outputSize; ++i) { - for (std::size_t j = 0; j < kernelSize; ++j) { - if (i >= j && (i - j) < inputSize) { - output[i] += input[i - j] * kernel[j]; - } - } - } -#endif - - return output; -} - -// Function to deconvolve a 1D input with a kernel -auto deconvolve(const std::vector &input, - const std::vector &kernel) -> std::vector { - auto inputSize = input.size(); - auto kernelSize = kernel.size(); - if (kernelSize > inputSize) { - THROW_INVALID_ARGUMENT("Kernel size cannot be larger than input size."); - } - - auto outputSize = inputSize - kernelSize + 1; - std::vector output(outputSize, 0.0); - -#ifdef SIMD_ENABLED - const int simd_width = SIMD_WIDTH; - SIMD_ALIGNED double aligned_kernel[kernelSize]; - std::memcpy(aligned_kernel, kernel.data(), kernelSize * sizeof(double)); - - for (std::size_t i = 0; i < outputSize; i += simd_width) { - __m256d sum = _mm256_setzero_pd(); - - for (std::size_t j = 0; j < kernelSize; ++j) { - __m256d input_vec = _mm256_loadu_pd(&input[i + j]); - __m256d kernel_val = _mm256_set1_pd(aligned_kernel[j]); - sum = _mm256_add_pd(sum, _mm256_mul_pd(input_vec, kernel_val)); - } - - _mm256_storeu_pd(&output[i], sum); - } - - // Handle remaining elements - for (std::size_t i = (outputSize / simd_width) * simd_width; i < outputSize; - ++i) { - for (std::size_t j = 0; j < kernelSize; ++j) { - output[i] += input[i + j] * kernel[j]; - } - } -#else - // Fallback to non-SIMD version - for (std::size_t i = 0; i < outputSize; ++i) { - for (std::size_t j = 0; j < kernelSize; ++j) { - output[i] += input[i + j] * kernel[j]; - } - } -#endif - - return output; -} - -// Helper function to extend 2D vectors -template -auto extend2D(const std::vector> &input, std::size_t newRows, - std::size_t newCols) -> std::vector> { - std::vector> extended(newRows, std::vector(newCols, 0.0)); - auto inputRows = input.size(); - auto inputCols = input[0].size(); - for (std::size_t i = 0; i < inputRows; ++i) { - for (std::size_t j = 0; j < inputCols; ++j) { - extended[i + newRows / 2][j + newCols / 2] = input[i][j]; - } - } - return extended; -} - -#if USE_OPENCL -// OpenCL initialization and helper functions -auto initializeOpenCL() -> cl_context { - cl_uint numPlatforms; - cl_platform_id platform = nullptr; - clGetPlatformIDs(1, &platform, &numPlatforms); - - cl_context_properties properties[] = {CL_CONTEXT_PLATFORM, - (cl_context_properties)platform, 0}; - - cl_int err; - cl_context context = clCreateContextFromType(properties, CL_DEVICE_TYPE_GPU, - nullptr, nullptr, &err); - if (err != CL_SUCCESS) { - THROW_RUNTIME_ERROR("Failed to create OpenCL context."); - } - return context; -} - -auto createCommandQueue(cl_context context) -> cl_command_queue { - cl_device_id device_id; - clGetDeviceIDs(nullptr, CL_DEVICE_TYPE_GPU, 1, &device_id, nullptr); - cl_int err; - cl_command_queue commandQueue = - clCreateCommandQueue(context, device_id, 0, &err); - if (err != CL_SUCCESS) { - THROW_RUNTIME_ERROR("Failed to create OpenCL command queue."); - } - return commandQueue; -} - -auto createProgram(const std::string &source, - cl_context context) -> cl_program { - const char *sourceStr = source.c_str(); - cl_int err; - cl_program program = - clCreateProgramWithSource(context, 1, &sourceStr, nullptr, &err); - if (err != CL_SUCCESS) { - THROW_RUNTIME_ERROR("Failed to create OpenCL program."); - } - return program; -} - -void checkErr(cl_int err, const char *operation) { - if (err != CL_SUCCESS) { - std::string errMsg = "OpenCL Error during operation: "; - errMsg += operation; - THROW_RUNTIME_ERROR(errMsg.c_str()); - } -} - -// OpenCL kernel code for 2D convolution -const std::string convolve2DKernelSrc = R"CLC( -__kernel void convolve2D(__global const float* input, - __global const float* kernel, - __global float* output, - const int inputRows, - const int inputCols, - const int kernelRows, - const int kernelCols) { - int row = get_global_id(0); - int col = get_global_id(1); - - int halfKernelRows = kernelRows / 2; - int halfKernelCols = kernelCols / 2; - - float sum = 0.0; - for (int i = -halfKernelRows; i <= halfKernelRows; ++i) { - for (int j = -halfKernelCols; j <= halfKernelCols; ++j) { - int x = clamp(row + i, 0, inputRows - 1); - int y = clamp(col + j, 0, inputCols - 1); - sum += input[x * inputCols + y] * kernel[(i + halfKernelRows) * kernelCols + (j + halfKernelCols)]; - } - } - output[row * inputCols + col] = sum; -} -)CLC"; - -// Function to convolve a 2D input with a 2D kernel using OpenCL -auto convolve2DOpenCL(const std::vector> &input, - const std::vector> &kernel, - int numThreads) -> std::vector> { - auto context = initializeOpenCL(); - auto queue = createCommandQueue(context); - - auto inputRows = input.size(); - auto inputCols = input[0].size(); - auto kernelRows = kernel.size(); - auto kernelCols = kernel[0].size(); - - std::vector inputFlattened(inputRows * inputCols); - std::vector kernelFlattened(kernelRows * kernelCols); - std::vector outputFlattened(inputRows * inputCols, 0.0); - - for (size_t i = 0; i < inputRows; ++i) - for (size_t j = 0; j < inputCols; ++j) - inputFlattened[i * inputCols + j] = static_cast(input[i][j]); - - for (size_t i = 0; i < kernelRows; ++i) - for (size_t j = 0; j < kernelCols; ++j) - kernelFlattened[i * kernelCols + j] = - static_cast(kernel[i][j]); - - cl_int err; - cl_mem inputBuffer = clCreateBuffer( - context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, - sizeof(float) * inputFlattened.size(), inputFlattened.data(), &err); - checkErr(err, "Creating input buffer"); - - cl_mem kernelBuffer = clCreateBuffer( - context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, - sizeof(float) * kernelFlattened.size(), kernelFlattened.data(), &err); - checkErr(err, "Creating kernel buffer"); - - cl_mem outputBuffer = - clCreateBuffer(context, CL_MEM_WRITE_ONLY, - sizeof(float) * outputFlattened.size(), nullptr, &err); - checkErr(err, "Creating output buffer"); - - cl_program program = createProgram(convolve2DKernelSrc, context); - err = clBuildProgram(program, 0, nullptr, nullptr, nullptr, nullptr); - checkErr(err, "Building program"); - - cl_kernel kernel = clCreateKernel(program, "convolve2D", &err); - checkErr(err, "Creating kernel"); - - err = clSetKernelArg(kernel, 0, sizeof(cl_mem), &inputBuffer); - err |= clSetKernelArg(kernel, 1, sizeof(cl_mem), &kernelBuffer); - err |= clSetKernelArg(kernel, 2, sizeof(cl_mem), &outputBuffer); - err |= clSetKernelArg(kernel, 3, sizeof(int), &inputRows); - err |= clSetKernelArg(kernel, 4, sizeof(int), &inputCols); - err |= clSetKernelArg(kernel, 5, sizeof(int), &kernelRows); - err |= clSetKernelArg(kernel, 6, sizeof(int), &kernelCols); - checkErr(err, "Setting kernel arguments"); - - size_t globalWorkSize[2] = {static_cast(inputRows), - static_cast(inputCols)}; - err = clEnqueueNDRangeKernel(queue, kernel, 2, nullptr, globalWorkSize, - nullptr, 0, nullptr, nullptr); - checkErr(err, "Enqueueing kernel"); - - err = clEnqueueReadBuffer(queue, outputBuffer, CL_TRUE, 0, - sizeof(float) * outputFlattened.size(), - outputFlattened.data(), 0, nullptr, nullptr); - checkErr(err, "Reading back output buffer"); - - // Convert output back to 2D vector - std::vector> output( - inputRows, std::vector(inputCols, 0.0)); - for (size_t i = 0; i < inputRows; ++i) - for (size_t j = 0; j < inputCols; ++j) - output[i][j] = - static_cast(outputFlattened[i * inputCols + j]); - - // Clean up OpenCL resources - clReleaseMemObject(inputBuffer); - clReleaseMemObject(kernelBuffer); - clReleaseMemObject(outputBuffer); - clReleaseKernel(kernel); - clReleaseProgram(program); - clReleaseCommandQueue(queue); - clReleaseContext(context); - - return output; -} -#endif - -// Function to convolve a 2D input with a 2D kernel using multithreading or -// OpenCL -auto convolve2D(const std::vector> &input, - const std::vector> &kernel, - int numThreads) -> std::vector> { -#if USE_OPENCL - return convolve2DOpenCL(input, kernel, numThreads); -#else - auto inputRows = input.size(); - auto inputCols = input[0].size(); - auto kernelRows = kernel.size(); - auto kernelCols = kernel[0].size(); - - auto extendedInput = - extend2D(input, inputRows + kernelRows - 1, inputCols + kernelCols - 1); - auto extendedKernel = extend2D(kernel, inputRows + kernelRows - 1, - inputCols + kernelCols - 1); - - std::vector> output( - inputRows, std::vector(inputCols, 0.0)); - - // Function to compute a block of the convolution using SIMD - auto computeBlock = [&](std::size_t blockStartRow, - std::size_t blockEndRow) { -#if USE_SIMD - SIMD_ALIGNED -#endif - double aligned_kernel[kernelRows * kernelCols]; - for (std::size_t i = 0; i < kernelRows; ++i) { - std::memcpy(&aligned_kernel[i * kernelCols], kernel[i].data(), - kernelCols * sizeof(double)); - } - -#ifdef SIMD_ENABLED - const int simd_width = SIMD_WIDTH; - for (std::size_t i = blockStartRow; i < blockEndRow; ++i) { - for (std::size_t j = kernelCols / 2; j < inputCols + kernelCols / 2; - j += simd_width) { - __m256d sum = _mm256_setzero_pd(); - - for (std::size_t k = 0; k < kernelRows; ++k) { - for (std::size_t colOffset = 0; colOffset < kernelCols; - ++colOffset) { - __m256d input_vec = _mm256_loadu_pd( - &extendedInput[i + k - kernelRows / 2] - [j + colOffset - kernelCols / 2]); - __m256d kernel_val = _mm256_set1_pd( - aligned_kernel[k * kernelCols + colOffset]); - sum = _mm256_add_pd( - sum, _mm256_mul_pd(input_vec, kernel_val)); - } - } - - _mm256_storeu_pd( - &output[i - kernelRows / 2][j - kernelCols / 2], sum); - } - - // Handle remaining elements - for (std::size_t j = - ((inputCols + kernelCols / 2) / simd_width) * simd_width + - kernelCols / 2; - j < inputCols + kernelCols / 2; ++j) { - double sum = 0.0; - for (std::size_t k = 0; k < kernelRows; ++k) { - for (std::size_t colOffset = 0; colOffset < kernelCols; - ++colOffset) { - sum += extendedInput[i + k - kernelRows / 2] - [j + colOffset - kernelCols / 2] * - aligned_kernel[k * kernelCols + colOffset]; - } - } - output[i - kernelRows / 2][j - kernelCols / 2] = sum; - } - } -#else - // Fallback to non-SIMD version - for (std::size_t i = blockStartRow; i < blockEndRow; ++i) { - for (std::size_t j = kernelCols / 2; j < inputCols + kernelCols / 2; - ++j) { - double sum = 0.0; - for (std::size_t k = 0; k < kernelRows; ++k) { - for (std::size_t colOffset = 0; colOffset < kernelCols; - ++colOffset) { - sum += extendedInput[i + k - kernelRows / 2] - [j + colOffset - kernelCols / 2] * - aligned_kernel[k * kernelCols + colOffset]; - } - } - output[i - kernelRows / 2][j - kernelCols / 2] = sum; - } - } -#endif - }; - - // Use multiple threads if requested - if (numThreads > 1) { - std::vector threads; - std::size_t blockSize = (inputRows + numThreads - 1) / numThreads; - std::size_t blockStartRow = kernelRows / 2; - - for (int i = 0; i < numThreads; ++i) { - std::size_t blockEndRow = - std::min(blockStartRow + blockSize, inputRows + kernelRows / 2); - threads.emplace_back(computeBlock, blockStartRow, blockEndRow); - blockStartRow = blockEndRow; - } - - for (auto &thread : threads) { - thread.join(); - } - } else { - // Single-threaded execution - computeBlock(kernelRows / 2, inputRows + kernelRows / 2); - } - - return output; -#endif -} - -// Function to deconvolve a 2D input with a 2D kernel using multithreading or -// OpenCL -auto deconvolve2D(const std::vector> &signal, - const std::vector> &kernel, - int numThreads) -> std::vector> { -#if USE_OPENCL - // Implement OpenCL support if necessary - return deconvolve2DOpenCL(signal, kernel, numThreads); -#else - int M = signal.size(); - int N = signal[0].size(); - int K = kernel.size(); - int L = kernel[0].size(); - - auto extendedSignal = extend2D(signal, M + K - 1, N + L - 1); - auto extendedKernel = extend2D(kernel, M + K - 1, N + L - 1); - - auto dfT2DWrapper = [&](const std::vector> &input) { - return dfT2D(input, - numThreads); // Assume DFT2D supports multithreading - }; - - auto x = dfT2DWrapper(extendedSignal); - auto h = dfT2DWrapper(extendedKernel); - - std::vector>> g( - M + K - 1, std::vector>(N + L - 1)); - double alpha = 0.1; // Prevent division by zero - - // SIMD-optimized computation of g -#ifdef SIMD_ENABLED - const int simd_width = SIMD_WIDTH; - __m256d alpha_vec = _mm256_set1_pd(alpha); - - for (int u = 0; u < M + K - 1; ++u) { - for (int v = 0; v < N + L - 1; v += simd_width) { - __m256d h_real = _mm256_loadu_pd(&h[u][v].real()); - __m256d h_imag = _mm256_loadu_pd(&h[u][v].imag()); - - __m256d h_abs = _mm256_sqrt_pd(_mm256_add_pd( - _mm256_mul_pd(h_real, h_real), _mm256_mul_pd(h_imag, h_imag))); - __m256d mask = _mm256_cmp_pd(h_abs, alpha_vec, _CMP_GT_OQ); - - __m256d norm = _mm256_add_pd(_mm256_mul_pd(h_real, h_real), - _mm256_mul_pd(h_imag, h_imag)); - norm = _mm256_add_pd(norm, alpha_vec); - - __m256d g_real = _mm256_div_pd(h_real, norm); - __m256d g_imag = _mm256_div_pd( - _mm256_xor_pd(h_imag, _mm256_set1_pd(-0.0)), norm); - - g_real = _mm256_blendv_pd(h_real, g_real, mask); - g_imag = _mm256_blendv_pd(h_imag, g_imag, mask); - - _mm256_storeu_pd(&g[u][v].real(), g_real); - _mm256_storeu_pd(&g[u][v].imag(), g_imag); - } - - // Handle remaining elements - for (int v = ((N + L - 1) / simd_width) * simd_width; v < N + L - 1; - ++v) { - if (std::abs(h[u][v]) > alpha) { - g[u][v] = std::conj(h[u][v]) / (std::norm(h[u][v]) + alpha); - } else { - g[u][v] = std::conj(h[u][v]); - } - } - } -#else - // Fallback to non-SIMD version - for (int u = 0; u < M + K - 1; ++u) { - for (int v = 0; v < N + L - 1; ++v) { - if (std::abs(h[u][v]) > alpha) { - g[u][v] = std::conj(h[u][v]) / (std::norm(h[u][v]) + alpha); - } else { - g[u][v] = std::conj(h[u][v]); - } - } - } -#endif - - std::vector>> Y( - M + K - 1, std::vector>(N + L - 1)); - - // SIMD-optimized computation of Y -#ifdef SIMD_ENABLED - for (int u = 0; u < M + K - 1; ++u) { - for (int v = 0; v < N + L - 1; v += simd_width) { - __m256d g_real = _mm256_loadu_pd(&g[u][v].real()); - __m256d g_imag = _mm256_loadu_pd(&g[u][v].imag()); - __m256d x_real = _mm256_loadu_pd(&x[u][v].real()); - __m256d x_imag = _mm256_loadu_pd(&x[u][v].imag()); - - __m256d y_real = _mm256_sub_pd(_mm256_mul_pd(g_real, x_real), - _mm256_mul_pd(g_imag, x_imag)); - __m256d y_imag = _mm256_add_pd(_mm256_mul_pd(g_real, x_imag), - _mm256_mul_pd(g_imag, x_real)); - - _mm256_storeu_pd(&Y[u][v].real(), y_real); - _mm256_storeu_pd(&Y[u][v].imag(), y_imag); - } - - // Handle remaining elements - for (int v = ((N + L - 1) / simd_width) * simd_width; v < N + L - 1; - ++v) { - Y[u][v] = g[u][v] * x[u][v]; - } - } -#else - // Fallback to non-SIMD version - for (int u = 0; u < M + K - 1; ++u) { - for (int v = 0; v < N + L - 1; ++v) { - Y[u][v] = g[u][v] * x[u][v]; - } - } -#endif - - auto y = idfT2D(Y, numThreads); - - std::vector> result(M, std::vector(N, 0.0)); - for (int i = 0; i < M; ++i) { - for (int j = 0; j < N; ++j) { - result[i][j] = y[i][j]; - } - } - - return result; -#endif -} - -// 2D Discrete Fourier Transform (2D DFT) -auto dfT2D(const std::vector> &signal, - int numThreads) -> std::vector>> { - const auto M = signal.size(); - const auto N = signal[0].size(); - std::vector>> X( - M, std::vector>(N, {0, 0})); - - // Lambda function to compute the DFT for a block of rows - auto computeDFT = [&M, &N, &signal, &X](int startRow, int endRow) { - for (int u = startRow; u < endRow; ++u) { - for (int v = 0; v < N; ++v) { -#if USE_SIMD - __m256d sum_real = _mm256_setzero_pd(); - __m256d sum_imag = _mm256_setzero_pd(); - for (int m = 0; m < M; ++m) { - for (int n = 0; n < N; n += SIMD_WIDTH) { - __m256d theta = _mm256_set_pd( - -2 * std::numbers::pi * - ((u * m / static_cast(M)) + - (v * (n + 3) / static_cast(N))), - -2 * std::numbers::pi * - ((u * m / static_cast(M)) + - (v * (n + 2) / static_cast(N))), - -2 * std::numbers::pi * - ((u * m / static_cast(M)) + - (v * (n + 1) / static_cast(N))), - -2 * std::numbers::pi * - ((u * m / static_cast(M)) + - (v * n / static_cast(N)))); - __m256d w_real = _mm256_cos_pd(theta); - __m256d w_imag = _mm256_sin_pd(theta); - __m256d signal_val = _mm256_loadu_pd(&signal[m][n]); - - sum_real = - _mm256_fmadd_pd(signal_val, w_real, sum_real); - sum_imag = - _mm256_fmadd_pd(signal_val, w_imag, sum_imag); - } - } - X[u][v] = std::complex(_mm256_reduce_add_pd(sum_real), - _mm256_reduce_add_pd(sum_imag)); -#else - std::complex sum(0, 0); - for (int m = 0; m < M; ++m) { - for (int n = 0; n < N; ++n) { - double theta = -2 * std::numbers::pi * - ((u * m / static_cast(M)) + - (v * n / static_cast(N))); - std::complex w(cos(theta), sin(theta)); - sum += signal[m][n] * w; - } - } - X[u][v] = sum; -#endif - } - } - }; - - // Multithreading support - if (numThreads > 1) { - std::vector threads; - auto rowsPerThread = M / numThreads; - for (int i = 0; i < numThreads; ++i) { - auto startRow = i * rowsPerThread; - auto endRow = (i == numThreads - 1) ? M : startRow + rowsPerThread; - threads.emplace_back(computeDFT, startRow, endRow); - } - for (auto &thread : threads) { - thread.join(); - } - } else { - // Single-threaded execution - computeDFT(0, M); - } - - return X; -} - -// 2D Inverse Discrete Fourier Transform (2D IDFT) -auto idfT2D(const std::vector>> &spectrum, - int numThreads) -> std::vector> { - const auto M = spectrum.size(); - const auto N = spectrum[0].size(); - std::vector> x(M, std::vector(N, 0.0)); - - // Lambda function to compute the IDFT for a block of rows - auto computeIDFT = [&M, &N, &spectrum, &x](int startRow, int endRow) { - for (int m = startRow; m < endRow; ++m) { - for (int n = 0; n < N; ++n) { -#if USE_SIMD - __m256d sum_real = _mm256_setzero_pd(); - __m256d sum_imag = _mm256_setzero_pd(); - for (int u = 0; u < M; ++u) { - for (int v = 0; v < N; v += SIMD_WIDTH) { - __m256d theta = _mm256_set_pd( - 2 * std::numbers::pi * - ((u * m / static_cast(M)) + - (v * (n + 3) / static_cast(N))), - 2 * std::numbers::pi * - ((u * m / static_cast(M)) + - (v * (n + 2) / static_cast(N))), - 2 * std::numbers::pi * - ((u * m / static_cast(M)) + - (v * (n + 1) / static_cast(N))), - 2 * std::numbers::pi * - ((u * m / static_cast(M)) + - (v * n / static_cast(N)))); - __m256d w_real = _mm256_cos_pd(theta); - __m256d w_imag = _mm256_sin_pd(theta); - __m256d spectrum_real = _mm256_loadu_pd( - reinterpret_cast(&spectrum[u][v])); - __m256d spectrum_imag = _mm256_loadu_pd( - reinterpret_cast(&spectrum[u][v]) + - 4); - - sum_real = - _mm256_fmadd_pd(spectrum_real, w_real, sum_real); - sum_imag = - _mm256_fmadd_pd(spectrum_imag, w_imag, sum_imag); - } - } - x[m][n] = (_mm256_reduce_add_pd(sum_real) + - _mm256_reduce_add_pd(sum_imag)) / - (M * N); -#else - std::complex sum(0.0, 0.0); - for (int u = 0; u < M; ++u) { - for (int v = 0; v < N; ++v) { - double theta = 2 * std::numbers::pi * - ((u * m / static_cast(M)) + - (v * n / static_cast(N))); - std::complex w(cos(theta), sin(theta)); - sum += spectrum[u][v] * w; - } - } - x[m][n] = - std::real(sum) / (M * N); // Normalize by dividing by M*N -#endif - } - } - }; - - // Multithreading support - if (numThreads > 1) { - std::vector threads; - auto rowsPerThread = M / numThreads; - for (int i = 0; i < numThreads; ++i) { - auto startRow = i * rowsPerThread; - auto endRow = (i == numThreads - 1) ? M : startRow + rowsPerThread; - threads.emplace_back(computeIDFT, startRow, endRow); - } - for (auto &thread : threads) { - thread.join(); - } - } else { - // Single-threaded execution - computeIDFT(0, M); - } - - return x; -} - -auto generateGaussianKernel(int size, - double sigma) -> std::vector> { - std::vector> kernel(size, std::vector(size)); - double sum = 0.0; - int center = size / 2; - -#if USE_SIMD - SIMD_ALIGNED double temp_buffer[SIMD_WIDTH]; - __m256d sigma_vec = _mm256_set1_pd(sigma); - __m256d two_sigma_squared = - _mm256_mul_pd(_mm256_set1_pd(2.0), _mm256_mul_pd(sigma_vec, sigma_vec)); - __m256d scale = _mm256_div_pd( - _mm256_set1_pd(1.0), - _mm256_mul_pd(_mm256_set1_pd(2 * std::numbers::pi), two_sigma_squared)); - - for (int i = 0; i < size; ++i) { - __m256d i_vec = _mm256_set1_pd(i - center); - for (int j = 0; j < size; j += SIMD_WIDTH) { - __m256d j_vec = _mm256_set_pd(j + 3 - center, j + 2 - center, - j + 1 - center, j - center); - - __m256d x_squared = _mm256_mul_pd(i_vec, i_vec); - __m256d y_squared = _mm256_mul_pd(j_vec, j_vec); - __m256d exponent = _mm256_div_pd( - _mm256_add_pd(x_squared, y_squared), two_sigma_squared); - __m256d kernel_values = _mm256_mul_pd( - scale, - exp256_pd(_mm256_mul_pd(_mm256_set1_pd(-0.5), exponent))); - - _mm256_store_pd(temp_buffer, kernel_values); - for (int k = 0; k < SIMD_WIDTH && j + k < size; ++k) { - kernel[i][j + k] = temp_buffer[k]; - sum += temp_buffer[k]; - } - } - } - - // Normalize to ensure the sum of the weights is 1 - __m256d sum_vec = _mm256_set1_pd(sum); - for (int i = 0; i < size; ++i) { - for (int j = 0; j < size; j += SIMD_WIDTH) { - __m256d kernel_values = _mm256_loadu_pd(&kernel[i][j]); - kernel_values = _mm256_div_pd(kernel_values, sum_vec); - _mm256_storeu_pd(&kernel[i][j], kernel_values); - } - } -#else - for (int i = 0; i < size; ++i) { - for (int j = 0; j < size; ++j) { - kernel[i][j] = exp(-0.5 * (pow((i - center) / sigma, 2.0) + - pow((j - center) / sigma, 2.0))) / - (2 * std::numbers::pi * sigma * sigma); - sum += kernel[i][j]; - } - } - - // Normalize to ensure the sum of the weights is 1 - for (int i = 0; i < size; ++i) { - for (int j = 0; j < size; ++j) { - kernel[i][j] /= sum; - } - } -#endif - - return kernel; -} - -auto applyGaussianFilter(const std::vector> &image, - const std::vector> &kernel) - -> std::vector> { - auto imageHeight = image.size(); - auto imageWidth = image[0].size(); - auto kernelSize = kernel.size(); - auto kernelRadius = kernelSize / 2; - std::vector> filteredImage( - imageHeight, std::vector(imageWidth, 0)); - -#if USE_SIMD - SIMD_ALIGNED double temp_buffer[SIMD_WIDTH]; - - for (auto i = 0; i < imageHeight; ++i) { - for (auto j = 0; j < imageWidth; j += SIMD_WIDTH) { - __m256d sum_vec = _mm256_setzero_pd(); - - for (auto k = -kernelRadius; k <= kernelRadius; ++k) { - for (auto l = -kernelRadius; l <= kernelRadius; ++l) { - __m256d kernel_val = _mm256_set1_pd( - kernel[kernelRadius + k][kernelRadius + l]); - - for (int m = 0; m < SIMD_WIDTH; ++m) { - auto x = std::clamp(static_cast(i + k), 0, - static_cast(imageHeight) - 1); - auto y = std::clamp(static_cast(j + l + m), 0, - static_cast(imageWidth) - 1); - temp_buffer[m] = image[x][y]; - } - - __m256d image_val = _mm256_loadu_pd(temp_buffer); - sum_vec = _mm256_add_pd( - sum_vec, _mm256_mul_pd(image_val, kernel_val)); - } - } - - _mm256_storeu_pd(temp_buffer, sum_vec); - for (int m = 0; m < SIMD_WIDTH && j + m < imageWidth; ++m) { - filteredImage[i][j + m] = temp_buffer[m]; - } - } - } -#else - for (auto i = 0; i < imageHeight; ++i) { - for (auto j = 0; j < imageWidth; ++j) { - double sum = 0.0; - for (auto k = -kernelRadius; k <= kernelRadius; ++k) { - for (auto l = -kernelRadius; l <= kernelRadius; ++l) { - auto x = std::clamp(static_cast(i + k), 0, - static_cast(imageHeight) - 1); - auto y = std::clamp(static_cast(j + l), 0, - static_cast(imageWidth) - 1); - sum += image[x][y] * - kernel[kernelRadius + k][kernelRadius + l]; - } - } - filteredImage[i][j] = sum; - } - } -#endif - return filteredImage; -} - -} // namespace atom::algorithm - -#ifdef __GNUC__ -#pragma GCC diagnostic pop -#elif defined(__clang__) -#pragma clang diagnostic pop -#elif defined(_MSC_VER) -#pragma warning(pop) -#endif diff --git a/src/atom/algorithm/convolve.hpp b/src/atom/algorithm/convolve.hpp deleted file mode 100644 index 4b8d7592..00000000 --- a/src/atom/algorithm/convolve.hpp +++ /dev/null @@ -1,133 +0,0 @@ -/* - * convolve.hpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2023-11-10 - -Description: Implementation of one-dimensional and two-dimensional convolution -and deconvolution. - -**************************************************/ - -#ifndef ATOM_ALGORITHM_CONVOLVE_HPP -#define ATOM_ALGORITHM_CONVOLVE_HPP - -#include -#include - -namespace atom::algorithm { -/** - * @brief Performs 1D convolution operation. - * - * This function convolves the input signal with the given kernel. - * - * @param input The input signal. - * @param kernel The convolution kernel. - * @return The convolved signal. - */ -[[nodiscard("The result of convolve is not used.")]] auto convolve( - const std::vector &input, - const std::vector &kernel) -> std::vector; - -/** - * @brief Performs 1D deconvolution operation. - * - * This function deconvolves the input signal with the given kernel. - * - * @param input The input signal. - * @param kernel The deconvolution kernel. - * @return The deconvolved signal. - */ -[[nodiscard("The result of deconvolve is not used.")]] auto deconvolve( - const std::vector &input, - const std::vector &kernel) -> std::vector; - -/** - * @brief Performs 2D convolution operation. - * - * This function convolves the input image with the given kernel. - * - * @param input The input image. - * @param kernel The convolution kernel. - * @param numThreads Number of threads for parallel execution (default: 1). - * @return The convolved image. - */ -[[nodiscard("The result of convolve2D is not used.")]] auto convolve2D( - const std::vector> &input, - const std::vector> &kernel, - int numThreads = 1) -> std::vector>; - -/** - * @brief Performs 2D deconvolution operation. - * - * This function deconvolves the input image with the given kernel. - * - * @param signal The input image. - * @param kernel The deconvolution kernel. - * @param numThreads Number of threads for parallel execution (default: 1). - * @return The deconvolved image. - */ -[[nodiscard("The result of deconvolve2D is not used.")]] auto deconvolve2D( - const std::vector> &signal, - const std::vector> &kernel, - int numThreads = 1) -> std::vector>; - -/** - * @brief Performs 2D Discrete Fourier Transform (DFT). - * - * This function computes the 2D DFT of the input image. - * - * @param signal The input image. - * @param numThreads Number of threads for parallel execution (default: 1). - * @return The 2D DFT spectrum. - */ -[[nodiscard("The result of DFT2D is not used.")]] auto dfT2D( - const std::vector> &signal, - int numThreads = 1) -> std::vector>>; - -/** - * @brief Performs 2D Inverse Discrete Fourier Transform (IDFT). - * - * This function computes the 2D IDFT of the input spectrum. - * - * @param spectrum The input spectrum. - * @param numThreads Number of threads for parallel execution (default: 1). - * @return The 2D IDFT image. - */ -[[nodiscard("The result of IDFT2D is not used.")]] auto idfT2D( - const std::vector>> &spectrum, - int numThreads = 1) -> std::vector>; - -/** - * @brief Generates a Gaussian kernel for 2D convolution. - * - * This function generates a Gaussian kernel for 2D convolution. - * - * @param size The size of the kernel. - * @param sigma The standard deviation of the Gaussian distribution. - * @return The generated Gaussian kernel. - */ -[[nodiscard("The result of generateGaussianKernel is not used.")]] auto -generateGaussianKernel(int size, - double sigma) -> std::vector>; - -/** - * @brief Applies a Gaussian filter to an image. - * - * This function applies a Gaussian filter to an image. - * - * @param image The input image. - * @param kernel The Gaussian kernel. - * @return The filtered image. - */ -[[nodiscard("The result of applyGaussianFilter is not used.")]] auto -applyGaussianFilter(const std::vector> &image, - const std::vector> &kernel) - -> std::vector>; -} // namespace atom::algorithm - -#endif diff --git a/src/atom/algorithm/error_calibration.hpp b/src/atom/algorithm/error_calibration.hpp deleted file mode 100644 index 80b430c3..00000000 --- a/src/atom/algorithm/error_calibration.hpp +++ /dev/null @@ -1,650 +0,0 @@ -#ifndef ATOM_ALGORITHM_ERROR_CALIBRATION_HPP -#define ATOM_ALGORITHM_ERROR_CALIBRATION_HPP - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#ifdef USE_SIMD -#ifdef __AVX__ -#include -#elif defined(__ARM_NEON) -#include -#endif -#endif - -#include "atom/error/exception.hpp" -#include "atom/log/loguru.hpp" - -namespace atom::algorithm { - -template -class AdvancedErrorCalibration { -private: - T slope_ = 1.0; - T intercept_ = 0.0; - std::optional r_squared_; - std::vector residuals_; - T mse_ = 0.0; // Mean Squared Error - T mae_ = 0.0; // Mean Absolute Error - - std::mutex metrics_mutex_; - - /** - * Calculate calibration metrics - * @param measured Vector of measured values - * @param actual Vector of actual values - */ - void calculateMetrics(const std::vector& measured, - const std::vector& actual) { - T sumSquaredError = 0.0; - T sumAbsoluteError = 0.0; - T meanActual = - std::accumulate(actual.begin(), actual.end(), T(0)) / actual.size(); - T ssTotal = 0; - T ssResidual = 0; - - residuals_.clear(); - -#ifdef USE_SIMD -#ifdef __AVX__ - // SIMD optimized loop for x86 using AVX - __m256d sumSquaredErrorVec = _mm256_setzero_pd(); - __m256d sumAbsoluteErrorVec = _mm256_setzero_pd(); - size_t i = 0; - - for (; i + 4 <= actual.size(); i += 4) { - __m256d measuredVec = _mm256_loadu_pd(&measured[i]); - __m256d actualVec = _mm256_loadu_pd(&actual[i]); - - __m256d predictedVec = _mm256_add_pd( - _mm256_mul_pd(_mm256_set1_pd(slope_), measuredVec), - _mm256_set1_pd(intercept_)); - - __m256d errorVec = _mm256_sub_pd(actualVec, predictedVec); - - sumSquaredErrorVec = _mm256_add_pd( - sumSquaredErrorVec, _mm256_mul_pd(errorVec, errorVec)); - sumAbsoluteErrorVec = - _mm256_add_pd(sumAbsoluteErrorVec, - _mm256_andnot_pd(_mm256_set1_pd(-0.0), errorVec)); - - ssTotal += std::pow(actual[i] - meanActual, 2); - ssResidual += - std::pow(_mm256_extract_pd(predictedVec, 0) - actual[i], 2); - } - - double tempSquaredError[4]; - _mm256_storeu_pd(tempSquaredError, sumSquaredErrorVec); - sumSquaredError = std::accumulate( - tempSquaredError, tempSquaredError + 4, sumSquaredError); - - double tempAbsoluteError[4]; - _mm256_storeu_pd(tempAbsoluteError, sumAbsoluteErrorVec); - sumAbsoluteError = std::accumulate( - tempAbsoluteError, tempAbsoluteError + 4, sumAbsoluteError); - -#elif defined(__ARM_NEON) - // SIMD optimized loop for ARM using NEON - float64x2_t sumSquaredErrorVec = vdupq_n_f64(0.0); - float64x2_t sumAbsoluteErrorVec = vdupq_n_f64(0.0); - size_t i = 0; - - for (; i + 2 <= actual.size(); i += 2) { - float64x2_t measuredVec = vld1q_f64(&measured[i]); - float64x2_t actualVec = vld1q_f64(&actual[i]); - - float64x2_t predictedVec = - vmlaq_n_f64(vdupq_n_f64(intercept_), measuredVec, slope_); - - float64x2_t errorVec = vsubq_f64(actualVec, predictedVec); - - sumSquaredErrorVec = - vmlaq_f64(sumSquaredErrorVec, errorVec, errorVec); - sumAbsoluteErrorVec = - vaddq_f64(sumAbsoluteErrorVec, vabsq_f64(errorVec)); - - ssTotal += std::pow(actual[i] - meanActual, 2); - ssResidual += std::pow(predictedVec[0] - actual[i], 2); - } - - double tempSquaredError[2]; - vst1q_f64(tempSquaredError, sumSquaredErrorVec); - sumSquaredError = std::accumulate( - tempSquaredError, tempSquaredError + 2, sumSquaredError); - - double tempAbsoluteError[2]; - vst1q_f64(tempAbsoluteError, sumAbsoluteErrorVec); - sumAbsoluteError = std::accumulate( - tempAbsoluteError, tempAbsoluteError + 2, sumAbsoluteError); - -#endif -#endif - - // Multithreaded computation for remaining elements - std::vector> futures; - size_t i = 0; - size_t chunk_size = 100; - for (size_t start = i; start < actual.size(); start += chunk_size) { - size_t end = std::min(start + chunk_size, actual.size()); - futures.emplace_back( - std::async(std::launch::async, [&, start, end]() { - T localSumSquared = 0.0; - T localSumAbsolute = 0.0; - T localSsTotal = 0.0; - T localSsResidual = 0.0; - std::vector localResiduals; - for (size_t j = start; j < end; ++j) { - T predicted = apply(measured[j]); - T error = actual[j] - predicted; - localResiduals.push_back(error); - - localSumSquared += error * error; - localSumAbsolute += std::abs(error); - localSsTotal += std::pow(actual[j] - meanActual, 2); - localSsResidual += std::pow(error, 2); - } - std::lock_guard lock(metrics_mutex_); - sumSquaredError += localSumSquared; - sumAbsoluteError += localSumAbsolute; - ssTotal += localSsTotal; - ssResidual += localSsResidual; - residuals_.insert(residuals_.end(), localResiduals.begin(), - localResiduals.end()); - })); - } - - for (auto& fut : futures) { - fut.get(); - } - - mse_ = sumSquaredError / actual.size(); - mae_ = sumAbsoluteError / actual.size(); - r_squared_ = 1 - (ssResidual / ssTotal); - } - - using NonlinearFunction = std::function&)>; - - /** - * Solve a system of linear equations using the Levenberg-Marquardt method - * @param x Vector of x values - * @param y Vector of y values - * @param func Nonlinear function to fit - * @param initial_params Initial guess for the parameters - * @param max_iterations Maximum number of iterations - * @param lambda Regularization parameter - * @param epsilon Convergence criterion - * @return Vector of optimized parameters - */ - auto levenbergMarquardt(const std::vector& x, const std::vector& y, - NonlinearFunction func, - std::vector initial_params, - int max_iterations = 100, T lambda = 0.01, - T epsilon = 1e-8) -> std::vector { - int n = x.size(); - int m = initial_params.size(); - std::vector params = initial_params; - std::vector prevParams(m); - std::vector> jacobian(n, std::vector(m)); - - for (int iteration = 0; iteration < max_iterations; ++iteration) { - std::vector residuals(n); - for (int i = 0; i < n; ++i) { - try { - residuals[i] = y[i] - func(x[i], params); - } catch (const std::exception& e) { - LOG_F(ERROR, "Exception in func: %s", e.what()); - throw; - } - for (int j = 0; j < m; ++j) { - T h = std::max(T(1e-6), std::abs(params[j]) * T(1e-6)); - std::vector paramsPlusH = params; - paramsPlusH[j] += h; - try { - jacobian[i][j] = - (func(x[i], paramsPlusH) - func(x[i], params)) / h; - } catch (const std::exception& e) { - LOG_F(ERROR, "Exception in jacobian computation: %s", - e.what()); - throw; - } - } - } - - std::vector> JTJ(m, std::vector(m, 0.0)); - std::vector jTr(m, 0.0); - for (int i = 0; i < m; ++i) { - for (int j = 0; j < m; ++j) { - for (int k = 0; k < n; ++k) { - JTJ[i][j] += jacobian[k][i] * jacobian[k][j]; - } - if (i == j) - JTJ[i][j] += lambda; - } - for (int k = 0; k < n; ++k) { - jTr[i] += jacobian[k][i] * residuals[k]; - } - } - - std::vector delta; - try { - delta = solveLinearSystem(JTJ, jTr); - } catch (const std::exception& e) { - LOG_F(ERROR, "Exception in solving linear system: %s", - e.what()); - throw; - } - - prevParams = params; - for (int i = 0; i < m; ++i) { - params[i] += delta[i]; - } - - T diff = 0; - for (int i = 0; i < m; ++i) { - diff += std::abs(params[i] - prevParams[i]); - } - if (diff < epsilon) { - break; - } - } - - return params; - } - - /** - * Solve a system of linear equations using Gaussian elimination - * @param A Coefficient matrix - * @param b Right-hand side vector - * @return Solution vector - */ - auto solveLinearSystem(const std::vector>& A, - const std::vector& b) -> std::vector { - int n = A.size(); - std::vector> augmented(n, std::vector(n + 1, 0.0)); - for (int i = 0; i < n; ++i) { - for (int j = 0; j < n; ++j) { - augmented[i][j] = A[i][j]; - } - augmented[i][n] = b[i]; - } - - for (int i = 0; i < n; ++i) { - // Partial pivoting - int maxRow = i; - for (int k = i + 1; k < n; ++k) { - if (std::abs(augmented[k][i]) > - std::abs(augmented[maxRow][i])) { - maxRow = k; - } - } - if (std::abs(augmented[maxRow][i]) < 1e-12) { - THROW_RUNTIME_ERROR("Matrix is singular or nearly singular."); - } - std::swap(augmented[i], augmented[maxRow]); - - // Eliminate below - for (int k = i + 1; k < n; ++k) { - T factor = augmented[k][i] / augmented[i][i]; - for (int j = i; j <= n; ++j) { - augmented[k][j] -= factor * augmented[i][j]; - } - } - } - - std::vector x(n, 0.0); - for (int i = n - 1; i >= 0; --i) { - if (std::abs(augmented[i][i]) < 1e-12) { - THROW_RUNTIME_ERROR( - "Division by zero during back substitution."); - } - x[i] = augmented[i][n]; - for (int j = i + 1; j < n; ++j) { - x[i] -= augmented[i][j] * x[j]; - } - x[i] /= augmented[i][i]; - } - - return x; - } - -public: - /** - * Linear calibration using the least squares method - * @param measured Vector of measured values - * @param actual Vector of actual values - */ - void linearCalibrate(const std::vector& measured, - const std::vector& actual) { - if (measured.size() != actual.size() || measured.empty()) { - THROW_INVALID_ARGUMENT( - "Input vectors must be non-empty and of equal size"); - } - - T sumX = std::accumulate(measured.begin(), measured.end(), T(0)); - T sumY = std::accumulate(actual.begin(), actual.end(), T(0)); - T sumXy = std::inner_product(measured.begin(), measured.end(), - actual.begin(), T(0)); - T sumXx = std::inner_product(measured.begin(), measured.end(), - measured.begin(), T(0)); - - T n = static_cast(measured.size()); - if (n * sumXx - sumX * sumX == 0) { - THROW_RUNTIME_ERROR("Division by zero in slope calculation."); - } - slope_ = (n * sumXy - sumX * sumY) / (n * sumXx - sumX * sumX); - intercept_ = (sumY - slope_ * sumX) / n; - - calculateMetrics(measured, actual); - } - - /** - * Polynomial calibration using the least squares method - * @param measured Vector of measured values - * @param actual Vector of actual values - * @param degree Degree of the polynomial - */ - void polynomialCalibrate(const std::vector& measured, - const std::vector& actual, int degree) { - if (measured.size() != actual.size() || measured.empty()) { - THROW_INVALID_ARGUMENT( - "Input vectors must be non-empty and of equal size"); - } - if (degree < 1) { - THROW_INVALID_ARGUMENT("Polynomial degree must be at least 1."); - } - - auto polyFunc = [degree](T x, const std::vector& params) -> T { - T result = 0; - for (int i = 0; i <= degree; ++i) { - result += params[i] * std::pow(x, i); - } - return result; - }; - - std::vector initialParams(degree + 1, 1.0); - auto params = - levenbergMarquardt(measured, actual, polyFunc, initialParams); - - if (params.size() < 2) { - THROW_RUNTIME_ERROR( - "Insufficient parameters returned from calibration."); - } - - slope_ = params[1]; // First-order coefficient as slope - intercept_ = params[0]; // Constant term as intercept - - calculateMetrics(measured, actual); - } - - /** - * Exponential calibration using the least squares method - * @param measured Vector of measured values - * @param actual Vector of actual values - */ - void exponentialCalibrate(const std::vector& measured, - const std::vector& actual) { - if (measured.size() != actual.size() || measured.empty()) { - THROW_INVALID_ARGUMENT( - "Input vectors must be non-empty and of equal size"); - } - if (std::any_of(actual.begin(), actual.end(), - [](T val) { return val <= 0; })) { - THROW_INVALID_ARGUMENT( - "Actual values must be positive for exponential calibration."); - } - - auto expFunc = [](T x, const std::vector& params) -> T { - return params[0] * std::exp(params[1] * x); - }; - - std::vector initialParams = {1.0, 0.1}; - auto params = - levenbergMarquardt(measured, actual, expFunc, initialParams); - - if (params.size() < 2) { - THROW_RUNTIME_ERROR( - "Insufficient parameters returned from calibration."); - } - - slope_ = params[1]; - intercept_ = params[0]; - - calculateMetrics(measured, actual); - } - - [[nodiscard]] auto apply(T value) const -> T { - return slope_ * value + intercept_; - } - - void printParameters() const { - LOG_F(INFO, "Calibration parameters: slope = {}, intercept = {}", - slope_, intercept_); - if (r_squared_.has_value()) { - LOG_F(INFO, "R-squared = {}", r_squared_.value()); - } - LOG_F(INFO, "MSE = {}, MAE = {}", mse_, mae_); - } - - [[nodiscard]] auto getResiduals() const -> std::vector { - return residuals_; - } - - void plotResiduals(const std::string& filename) const { - std::ofstream file(filename); - if (!file.is_open()) { - THROW_FAIL_TO_OPEN_FILE("Failed to open file: " + filename); - } - - file << "Index,Residual\n"; - for (size_t i = 0; i < residuals_.size(); ++i) { - file << i << "," << residuals_[i] << "\n"; - } - } - - /** - * Bootstrap confidence interval for the slope - * @param measured Vector of measured values - * @param actual Vector of actual values - * @param n_iterations Number of bootstrap iterations - * @param confidence_level Confidence level for the interval - * @return Pair of lower and upper bounds of the confidence interval - */ - auto bootstrapConfidenceInterval( - const std::vector& measured, const std::vector& actual, - int n_iterations = 1000, - double confidence_level = 0.95) -> std::pair { - if (n_iterations <= 0) { - THROW_INVALID_ARGUMENT("Number of iterations must be positive."); - } - if (confidence_level <= 0 || confidence_level >= 1) { - THROW_INVALID_ARGUMENT("Confidence level must be between 0 and 1."); - } - - std::vector bootstrapSlopes; - bootstrapSlopes.reserve(n_iterations); - std::random_device rd; - std::mt19937 gen(rd()); - std::uniform_int_distribution<> dis(0, measured.size() - 1); - - for (int i = 0; i < n_iterations; ++i) { - std::vector bootMeasured; - std::vector bootActual; - bootMeasured.reserve(measured.size()); - bootActual.reserve(actual.size()); - for (size_t j = 0; j < measured.size(); ++j) { - int idx = dis(gen); - bootMeasured.push_back(measured[idx]); - bootActual.push_back(actual[idx]); - } - - AdvancedErrorCalibration bootCalibrator; - try { - bootCalibrator.linearCalibrate(bootMeasured, bootActual); - bootstrapSlopes.push_back(bootCalibrator.getSlope()); - } catch (const std::exception& e) { - LOG_F(WARNING, "Bootstrap iteration %d failed: %s", i, - e.what()); - } - } - - if (bootstrapSlopes.empty()) { - THROW_RUNTIME_ERROR("All bootstrap iterations failed."); - } - - std::sort(bootstrapSlopes.begin(), bootstrapSlopes.end()); - int lowerIdx = static_cast((1 - confidence_level) / 2 * - bootstrapSlopes.size()); - int upperIdx = static_cast((1 + confidence_level) / 2 * - bootstrapSlopes.size()); - - lowerIdx = std::clamp(lowerIdx, 0, - static_cast(bootstrapSlopes.size()) - 1); - upperIdx = std::clamp(upperIdx, 0, - static_cast(bootstrapSlopes.size()) - 1); - - return {bootstrapSlopes[lowerIdx], bootstrapSlopes[upperIdx]}; - } - - /** - * Detect outliers using the residuals of the calibration - * @param measured Vector of measured values - * @param actual Vector of actual values - * @param threshold Threshold for outlier detection - * @return Tuple of mean residual, standard deviation, and threshold - */ - auto outlierDetection(const std::vector& measured, - const std::vector& actual, - T threshold = 2.0) -> std::tuple { - if (residuals_.empty()) { - THROW_RUNTIME_ERROR("Please call calculateMetrics() first."); - } - - T meanResidual = - std::accumulate(residuals_.begin(), residuals_.end(), T(0)) / - residuals_.size(); - T std_dev = std::sqrt( - std::accumulate(residuals_.begin(), residuals_.end(), T(0), - [meanResidual](T acc, T val) { - return acc + std::pow(val - meanResidual, 2); - }) / - residuals_.size()); - -#if ENABLE_DEBUG - std::cout << "Detected outliers:" << std::endl; - for (size_t i = 0; i < residuals_.size(); ++i) { - if (std::abs(residuals_[i] - meanResidual) > threshold * std_dev) { - std::cout << "Index: " << i << ", Measured: " << measured[i] - << ", Actual: " << actual[i] - << ", Residual: " << residuals_[i] << std::endl; - } - } -#endif - return {meanResidual, std_dev, threshold}; - } - - void crossValidation(const std::vector& measured, - const std::vector& actual, int k = 5) { - if (measured.size() != actual.size() || - measured.size() < static_cast(k)) { - THROW_INVALID_ARGUMENT( - "Input vectors must be non-empty and of size greater than k"); - } - - std::vector mseValues; - std::vector maeValues; - std::vector rSquaredValues; - - for (int i = 0; i < k; ++i) { - std::vector trainMeasured; - std::vector trainActual; - std::vector testMeasured; - std::vector testActual; - for (size_t j = 0; j < measured.size(); ++j) { - if (j % k == static_cast(i)) { - testMeasured.push_back(measured[j]); - testActual.push_back(actual[j]); - } else { - trainMeasured.push_back(measured[j]); - trainActual.push_back(actual[j]); - } - } - - AdvancedErrorCalibration cvCalibrator; - try { - cvCalibrator.linearCalibrate(trainMeasured, trainActual); - } catch (const std::exception& e) { - LOG_F(WARNING, "Cross-validation fold %d failed: %s", i, - e.what()); - continue; - } - - T foldMse = 0; - T foldMae = 0; - T foldSsTotal = 0; - T foldSsResidual = 0; - T meanTestActual = - std::accumulate(testActual.begin(), testActual.end(), T(0)) / - testActual.size(); - for (size_t j = 0; j < testMeasured.size(); ++j) { - T predicted = cvCalibrator.apply(testMeasured[j]); - T error = testActual[j] - predicted; - foldMse += error * error; - foldMae += std::abs(error); - foldSsTotal += std::pow(testActual[j] - meanTestActual, 2); - foldSsResidual += std::pow(error, 2); - } - - mseValues.push_back(foldMse / testMeasured.size()); - maeValues.push_back(foldMae / testMeasured.size()); - if (foldSsTotal != 0) { - rSquaredValues.push_back(1 - (foldSsResidual / foldSsTotal)); - } - } - - if (mseValues.empty()) { - THROW_RUNTIME_ERROR("All cross-validation folds failed."); - } - - T avgMse = std::accumulate(mseValues.begin(), mseValues.end(), T(0)) / - mseValues.size(); - T avgMae = std::accumulate(maeValues.begin(), maeValues.end(), T(0)) / - maeValues.size(); - T avgRSquared = 0; - if (!rSquaredValues.empty()) { - avgRSquared = std::accumulate(rSquaredValues.begin(), - rSquaredValues.end(), T(0)) / - rSquaredValues.size(); - } - -#if ENABLE_DEBUG - std::cout << "K-fold cross-validation results (k = " << k - << "):" << std::endl; - std::cout << "Average MSE: " << avgMse << std::endl; - std::cout << "Average MAE: " << avgMae << std::endl; - std::cout << "Average R-squared: " << avgRSquared << std::endl; -#endif - } - - [[nodiscard]] auto getSlope() const -> T { return slope_; } - [[nodiscard]] auto getIntercept() const -> T { return intercept_; } - [[nodiscard]] auto getRSquared() const -> std::optional { - return r_squared_; - } - [[nodiscard]] auto getMse() const -> T { return mse_; } - [[nodiscard]] auto getMae() const -> T { return mae_; } -}; - -} // namespace atom::algorithm - -#endif // ATOM_ALGORITHM_ERROR_CALIBRATION_HPP \ No newline at end of file diff --git a/src/atom/algorithm/fnmatch.cpp b/src/atom/algorithm/fnmatch.cpp deleted file mode 100644 index 880f99d4..00000000 --- a/src/atom/algorithm/fnmatch.cpp +++ /dev/null @@ -1,314 +0,0 @@ -/* - * fnmatch.cpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2024-5-2 - -Description: Enhanced Python-Like fnmatch for C++ - -**************************************************/ - -#include "fnmatch.hpp" - -#include -#include - -#ifdef _WIN32 -#include -#else -#include -#endif - -#include "atom/log/loguru.hpp" - -namespace atom::algorithm { - -#ifdef _WIN32 -constexpr int FNM_NOESCAPE = 0x01; -constexpr int FNM_PATHNAME = 0x02; -constexpr int FNM_PERIOD = 0x04; -constexpr int FNM_CASEFOLD = 0x08; -#endif - -auto fnmatch(std::string_view pattern, std::string_view string, - int flags) -> bool { - LOG_F(INFO, "fnmatch called with pattern: {}, string: {}, flags: {}", - pattern, string, flags); - - try { -#ifdef _WIN32 - auto p = pattern.begin(); - auto s = string.begin(); - - while (p != pattern.end() && s != string.end()) { - switch (*p) { - case '?': - LOG_F(INFO, "Wildcard '?' encountered."); - ++s; - break; - case '*': - LOG_F(INFO, "Wildcard '*' encountered."); - if (++p == pattern.end()) { - LOG_F(INFO, - "Trailing '*' matches the rest of the string."); - return true; - } - while (s != string.end()) { - if (fnmatch({p, pattern.end()}, {s, string.end()}, - flags)) { - return true; - } - ++s; - } - LOG_F(INFO, "No match found after '*'."); - return false; - case '[': { - LOG_F(INFO, "Character class '[' encountered."); - if (++p == pattern.end()) { - LOG_F(ERROR, "Unclosed '[' in pattern."); - throw FmmatchException("Unclosed '[' in pattern."); - } - bool invert = false; - if (*p == '!') { - invert = true; - LOG_F(INFO, "Inverted character class."); - ++p; - } - bool matched = false; - char last_char = 0; - while (p != pattern.end() && *p != ']') { - if (*p == '-' && last_char != 0 && - p + 1 != pattern.end() && *(p + 1) != ']') { - ++p; - if (*s >= last_char && *s <= *p) { - matched = true; - LOG_F(INFO, "Range match: {}-{}", last_char, - *p); - break; - } - } else { - if (*s == *p) { - matched = true; - LOG_F(INFO, "Exact character match: {}", *p); - break; - } - last_char = *p; - } - ++p; - } - if (p == pattern.end()) { - LOG_F(ERROR, "Unclosed '[' in pattern."); - throw FmmatchException("Unclosed '[' in pattern."); - } - if (invert) { - matched = !matched; - LOG_F(INFO, "Inversion applied to match result."); - } - if (!matched) { - LOG_F(INFO, "Character class did not match."); - return false; - } - ++s; - break; - } - case '\\': - LOG_F(INFO, "Escape character '\\' encountered."); - if (!(flags & FNM_NOESCAPE)) { - if (++p == pattern.end()) { - LOG_F(ERROR, - "Escape character '\\' at end of pattern."); - throw FmmatchException( - "Escape character '\\' at end of pattern."); - } - } - [[fallthrough]]; - default: - if ((flags & FNM_CASEFOLD) - ? (std::tolower(*p) != std::tolower(*s)) - : (*p != *s)) { - LOG_F(INFO, - "Literal character mismatch: pattern '{}' vs " - "string '{}'", - *p, *s); - return false; - } - ++s; - break; - } - ++p; - } - - if (p == pattern.end() && s == string.end()) { - LOG_F(INFO, "Full match achieved."); - return true; - } - if (p != pattern.end() && *p == '*') { - ++p; - LOG_F(INFO, "Trailing '*' allows remaining characters to match."); - } - bool result = p == pattern.end() && s == string.end(); - LOG_F(INFO, "Match result: {}", result ? "True" : "False"); - return result; -#else - LOG_F(INFO, "Using system fnmatch."); - int ret = ::fnmatch(pattern.data(), string.data(), flags); - bool result = (ret == 0); - LOG_F(INFO, "System fnmatch result: {}", result ? "True" : "False"); - return result; -#endif - } catch (const std::exception& e) { - LOG_F(ERROR, "Exception in fnmatch: {}", e.what()); - throw; // Rethrow the exception after logging - } -} - -auto filter(const std::vector& names, std::string_view pattern, - int flags) -> bool { - LOG_F(INFO, "Filter called with pattern: {} and {} names.", pattern, - names.size()); - try { - return std::ranges::any_of(names, [&](const std::string& name) { - bool match = fnmatch(pattern, name, flags); - LOG_F(INFO, "Checking if \"{}\" matches pattern \"{}\": {}", name, - pattern, match ? "Yes" : "No"); - return match; - }); - } catch (const std::exception& e) { - LOG_F(ERROR, "Exception in filter: {}", e.what()); - throw; - } -} - -auto filter(const std::vector& names, - const std::vector& patterns, - int flags) -> std::vector { - LOG_F(INFO, - "Filter called with multiple patterns: {} patterns and {} names.", - patterns.size(), names.size()); - std::vector result; - try { - for (const auto& name : names) { - bool matched = - std::ranges::any_of(patterns, [&](std::string_view pattern) { - bool match = fnmatch(pattern, name, flags); - LOG_F(INFO, "Checking if \"{}\" matches pattern \"{}\": {}", - name, pattern, match ? "Yes" : "No"); - return match; - }); - if (matched) { - LOG_F(INFO, "Name \"{}\" matches at least one pattern.", name); - result.push_back(name); - } - } - LOG_F(INFO, "Filter result contains {} matched names.", result.size()); - return result; - } catch (const std::exception& e) { - LOG_F(ERROR, "Exception in multiple patterns filter: {}", e.what()); - throw; - } -} - -auto translate(std::string_view pattern, std::string& result, - int flags) -> bool { - LOG_F(INFO, "Translating pattern: {} with flags: {}", pattern, flags); - result.clear(); - try { - for (auto it = pattern.begin(); it != pattern.end(); ++it) { - switch (*it) { - case '*': - LOG_F(INFO, "Translating '*' to '.*'"); - result += ".*"; - break; - case '?': - LOG_F(INFO, "Translating '?' to '.'"); - result += '.'; - break; - case '[': { - LOG_F(INFO, "Translating '[' to '['"); - result += '['; - if (++it == pattern.end()) { - LOG_F(ERROR, - "Unclosed '[' in pattern during translation."); - throw FnmatchException( - "Unclosed '[' in pattern during translation."); - } - if (*it == '!') { - LOG_F(INFO, - "Inverted character class during translation."); - result += '^'; - ++it; - } - if (it == pattern.end()) { - LOG_F(ERROR, - "Unclosed '[' in pattern during translation."); - throw FnmatchException( - "Unclosed '[' in pattern during translation."); - } - char lastChar = *it; - result += *it; - while (++it != pattern.end() && *it != ']') { - if (*it == '-' && it + 1 != pattern.end() && - *(it + 1) != ']') { - LOG_F(INFO, - "Translating range in character class."); - result += *it; - result += *(++it); - lastChar = *it; - } else { - result += *it; - lastChar = *it; - } - } - if (it == pattern.end()) { - LOG_F(ERROR, - "Unclosed '[' in pattern during translation."); - throw FnmatchException( - "Unclosed '[' in pattern during translation."); - } - result += ']'; - break; - } - case '\\': - LOG_F(INFO, "Translating escape character '\\' to '\\\\'"); - if ((flags & FNM_NOESCAPE) == 0) { - if (++it == pattern.end()) { - LOG_F(ERROR, - "Escape character '\\' at end of pattern " - "during translation."); - throw FnmatchException( - "Escape character '\\' at end of pattern " - "during translation."); - } - } - [[fallthrough]]; - default: - if (((flags & FNM_CASEFOLD) != 0) && - (std::isalpha(*it) != 0)) { - LOG_F(INFO, - "Translating alphabetic character with case " - "folding: {}", - *it); - result += '['; - result += static_cast(std::tolower(*it)); - result += static_cast(std::toupper(*it)); - result += ']'; - } else { - LOG_F(INFO, "Translating literal character: {}", *it); - result += *it; - } - break; - } - } - LOG_F(INFO, "Translation successful. Resulting regex: {}", result); - return true; - } catch (const std::exception& e) { - LOG_F(ERROR, "Exception in translate: {}", e.what()); - throw; - } -} - -} // namespace atom::algorithm \ No newline at end of file diff --git a/src/atom/algorithm/fnmatch.hpp b/src/atom/algorithm/fnmatch.hpp deleted file mode 100644 index 1de8e795..00000000 --- a/src/atom/algorithm/fnmatch.hpp +++ /dev/null @@ -1,110 +0,0 @@ -/* - * fnmatch.hpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2024-5-2 - -Description: Enhanced Python-Like fnmatch for C++ - -**************************************************/ - -#ifndef ATOM_SYSTEM_FNMATCH_HPP -#define ATOM_SYSTEM_FNMATCH_HPP - -#include -#include -#include -#include - -namespace atom::algorithm { - -/** - * @brief Exception class for fnmatch errors. - */ -class FnmatchException : public std::exception { -private: - std::string message_; - -public: - explicit FnmatchException(const std::string& message) : message_(message) {} - virtual const char* what() const noexcept override { - return message_.c_str(); - } -}; - -/** - * @brief Matches a string against a specified pattern. - * - * This function compares the given `string` against the specified `pattern` - * using shell-style pattern matching. The `flags` parameter can be used to - * modify the behavior of the matching process. - * - * @param pattern The pattern to match against. - * @param string The string to match. - * @param flags Optional flags to modify the matching behavior (default is 0). - * @return True if the `string` matches the `pattern`, false otherwise. - * @throws FmmatchException on invalid pattern or other matching errors. - */ -auto fnmatch(std::string_view pattern, std::string_view string, - int flags = 0) -> bool; - -/** - * @brief Filters a vector of strings based on a specified pattern. - * - * This function filters the given vector of `names` based on the specified - * `pattern` using shell-style pattern matching. The `flags` parameter can be - * used to modify the filtering behavior. - * - * @param names The vector of strings to filter. - * @param pattern The pattern to filter with. - * @param flags Optional flags to modify the filtering behavior (default is 0). - * @return True if any element of `names` matches the `pattern`, false - * otherwise. - * @throws FmmatchException on matching errors. - */ -auto filter(const std::vector& names, std::string_view pattern, - int flags = 0) -> bool; - -/** - * @brief Filters a vector of strings based on multiple patterns. - * - * This function filters the given vector of `names` based on the specified - * `patterns` using shell-style pattern matching. The `flags` parameter can be - * used to modify the filtering behavior. - * - * @param names The vector of strings to filter. - * @param patterns The vector of patterns to filter with. - * @param flags Optional flags to modify the filtering behavior (default is 0). - * @return A vector containing strings from `names` that match any pattern in - * `patterns`. - * @throws FmmatchException on matching errors. - */ -auto filter(const std::vector& names, - const std::vector& patterns, - int flags = 0) -> std::vector; - -/** - * @brief Translates a pattern into a regex string. - * - * This function translates the specified `pattern` into a regex string and - * stores the result in the `result` parameter. The `flags` parameter can be - * used to modify the translation behavior. - * - * @param pattern The pattern to translate. - * @param result A reference to a string where the translated pattern will be - * stored. - * @param flags Optional flags to modify the translation behavior (default is - * 0). - * @return True if the translation was successful, false otherwise. - * @throws FmmatchException on invalid patterns or other translation errors. - */ -auto translate(std::string_view pattern, std::string& result, - int flags = 0) -> bool; - -} // namespace atom::algorithm - -#endif // ATOM_SYSTEM_FNMATCH_HPP \ No newline at end of file diff --git a/src/atom/algorithm/fraction.cpp b/src/atom/algorithm/fraction.cpp deleted file mode 100644 index 643155b4..00000000 --- a/src/atom/algorithm/fraction.cpp +++ /dev/null @@ -1,297 +0,0 @@ -/* - * fraction.cpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2024-3-28 - -Description: Implementation of Fraction class - -**************************************************/ - -#include "fraction.hpp" - -#include -#include - -namespace atom::algorithm { - -/* ------------------------ Private Methods ------------------------ */ - -constexpr int Fraction::gcd(int a, int b) noexcept { - return (std::numeric_limits::min() != a && - std::numeric_limits::min() != b) - ? std::abs(std::gcd(a, b)) - : 1; // Prevent undefined behavior for min int -} - -void Fraction::reduce() noexcept { - if (denominator == 0) { - // Denominator check is handled in constructors and operators - return; - } - if (denominator < 0) { - numerator = -numerator; - denominator = -denominator; - } - int divisor = gcd(numerator, denominator); - numerator /= divisor; - denominator /= divisor; -} - -/* ------------------------ Arithmetic Operators ------------------------ */ - -auto Fraction::operator+=(const Fraction& other) -> Fraction& { - // Avoid overflow by using long long for intermediate calculations - long long commonDenominator = - static_cast(denominator) * other.denominator; - long long newNumerator = - static_cast(numerator) * other.denominator + - static_cast(other.numerator) * denominator; - - if (newNumerator > std::numeric_limits::max() || - newNumerator < std::numeric_limits::min() || - commonDenominator > std::numeric_limits::max() || - commonDenominator < std::numeric_limits::min()) { - throw FractionException("Integer overflow during addition."); - } - - numerator = static_cast(newNumerator); - denominator = static_cast(commonDenominator); - reduce(); - return *this; -} - -auto Fraction::operator-=(const Fraction& other) -> Fraction& { - long long commonDenominator = - static_cast(denominator) * other.denominator; - long long newNumerator = - static_cast(numerator) * other.denominator - - static_cast(other.numerator) * denominator; - - if (newNumerator > std::numeric_limits::max() || - newNumerator < std::numeric_limits::min() || - commonDenominator > std::numeric_limits::max() || - commonDenominator < std::numeric_limits::min()) { - throw FractionException("Integer overflow during subtraction."); - } - - numerator = static_cast(newNumerator); - denominator = static_cast(commonDenominator); - reduce(); - return *this; -} - -auto Fraction::operator*=(const Fraction& other) -> Fraction& { - if (other.numerator == 0) { - numerator = 0; - denominator = 1; - return *this; - } - - long long newNumerator = - static_cast(numerator) * other.numerator; - long long newDenominator = - static_cast(denominator) * other.denominator; - - if (newNumerator > std::numeric_limits::max() || - newNumerator < std::numeric_limits::min() || - newDenominator > std::numeric_limits::max() || - newDenominator < std::numeric_limits::min()) { - throw FractionException("Integer overflow during multiplication."); - } - - numerator = static_cast(newNumerator); - denominator = static_cast(newDenominator); - reduce(); - return *this; -} - -auto Fraction::operator/=(const Fraction& other) -> Fraction& { - if (other.numerator == 0) { - throw FractionException("Division by zero."); - } - - long long newNumerator = - static_cast(numerator) * other.denominator; - long long newDenominator = - static_cast(denominator) * other.numerator; - - if (newDenominator == 0) { - throw FractionException("Denominator cannot be zero after division."); - } - - if (newNumerator > std::numeric_limits::max() || - newNumerator < std::numeric_limits::min() || - newDenominator > std::numeric_limits::max() || - newDenominator < std::numeric_limits::min()) { - throw FractionException("Integer overflow during division."); - } - - numerator = static_cast(newNumerator); - denominator = static_cast(newDenominator); - if (denominator < 0) { // Handle negative denominators - numerator = -numerator; - denominator = -denominator; - } - reduce(); - return *this; -} - -/* ------------------------ Arithmetic Operators (Non-Member) - * ------------------------ */ - -auto Fraction::operator+(const Fraction& other) const -> Fraction { - Fraction result(*this); - result += other; - return result; -} - -auto Fraction::operator-(const Fraction& other) const -> Fraction { - Fraction result(*this); - result -= other; - return result; -} - -auto Fraction::operator*(const Fraction& other) const -> Fraction { - Fraction result(*this); - result *= other; - return result; -} - -auto Fraction::operator/(const Fraction& other) const -> Fraction { - Fraction result(*this); - result /= other; - return result; -} - -/* ------------------------ Comparison Operators ------------------------ */ - -#if __cplusplus >= 202002L -auto Fraction::operator<=>(const Fraction& other) const - -> std::strong_ordering { - long long lhs = static_cast(numerator) * other.denominator; - long long rhs = static_cast(other.numerator) * denominator; - if (lhs < rhs) { - return std::strong_ordering::less; - } - if (lhs > rhs) { - return std::strong_ordering::greater; - } - return std::strong_ordering::equal; -} -#endif - -auto Fraction::operator==(const Fraction& other) const -> bool { -#if __cplusplus >= 202002L - return (*this <=> other) == std::strong_ordering::equal; -#else - return (numerator == other.numerator) && (denominator == other.denominator); -#endif -} - -/* ------------------------ Type Conversion Operators ------------------------ - */ - -Fraction::operator double() const { - return static_cast(numerator) / denominator; -} - -Fraction::operator float() const { - return static_cast(numerator) / denominator; -} - -Fraction::operator int() const { return numerator / denominator; } - -/* ------------------------ Utility Methods ------------------------ */ - -auto Fraction::toString() const -> std::string { - std::ostringstream oss; - oss << numerator << '/' << denominator; - return oss.str(); -} - -auto Fraction::toDouble() const -> double { return static_cast(*this); } - -auto Fraction::invert() -> Fraction& { - if (numerator == 0) { - throw FractionException( - "Cannot invert a fraction with numerator zero."); - } - std::swap(numerator, denominator); - if (denominator < 0) { - numerator = -numerator; - denominator = -denominator; - } - reduce(); - return *this; -} - -auto Fraction::abs() const -> Fraction { - return Fraction(numerator < 0 ? -numerator : numerator, denominator); -} - -auto Fraction::isZero() const -> bool { return numerator == 0; } - -auto Fraction::isPositive() const -> bool { return numerator > 0; } - -auto Fraction::isNegative() const -> bool { return numerator < 0; } - -/* ------------------------ Friend Functions ------------------------ */ - -auto operator<<(std::ostream& os, const Fraction& f) -> std::ostream& { - os << f.toString(); - return os; -} - -auto operator>>(std::istream& is, Fraction& f) -> std::istream& { - int n = 0, d = 1; - char sep = '/'; - is >> n >> sep >> d; - if (sep != '/') { - is.setstate(std::ios::failbit); - throw FractionException( - "Invalid input format. Expected 'numerator/denominator'."); - } - if (d == 0) { - throw FractionException("Denominator cannot be zero."); - } - f.numerator = n; - f.denominator = d; - f.reduce(); - return is; -} - -/* ------------------------ Inline Utility Functions ------------------------ */ - -auto makeFraction(int value) -> Fraction { return Fraction(value, 1); } - -auto makeFraction(double value, int max_denominator) -> Fraction { - if (std::isnan(value) || std::isinf(value)) { - throw FractionException("Cannot create Fraction from NaN or Infinity."); - } - - int sign = (value < 0) ? -1 : 1; - value = std::abs(value); - int numerator = 0; - int denominator = 1; - double minError = std::numeric_limits::max(); - - for (denominator = 1; denominator <= max_denominator; ++denominator) { - numerator = static_cast(std::round(value * denominator)); - double currentError = - std::abs(value - static_cast(numerator) / denominator); - if (currentError < minError) { - minError = currentError; - } else { - break; - } - } - - return Fraction(sign * numerator, denominator); -} - -} // namespace atom::algorithm \ No newline at end of file diff --git a/src/atom/algorithm/fraction.hpp b/src/atom/algorithm/fraction.hpp deleted file mode 100644 index cd96bc92..00000000 --- a/src/atom/algorithm/fraction.hpp +++ /dev/null @@ -1,257 +0,0 @@ -/* - * fraction.hpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2024-3-28 - -Description: Implementation of Fraction class - -**************************************************/ - -#ifndef ATOM_ALGORITHM_FRACTION_HPP -#define ATOM_ALGORITHM_FRACTION_HPP - -#include -#include -#include -#include -#include - -namespace atom::algorithm { - -/** - * @brief Exception class for Fraction errors. - */ -class FractionException : public std::runtime_error { -public: - explicit FractionException(const std::string& message) - : std::runtime_error(message) {} -}; - -/** - * @brief Represents a fraction with numerator and denominator. - */ -class Fraction { -private: - int numerator; /**< The numerator of the fraction. */ - int denominator; /**< The denominator of the fraction. */ - - /** - * @brief Computes the greatest common divisor (GCD) of two numbers. - * @param a The first number. - * @param b The second number. - * @return The GCD of the two numbers. - */ - static constexpr int gcd(int a, int b) noexcept; - - /** - * @brief Reduces the fraction to its simplest form. - */ - void reduce() noexcept; - -public: - /** - * @brief Constructs a new Fraction object with the given numerator and - * denominator. - * @param n The numerator (default is 0). - * @param d The denominator (default is 1). - * @throws FractionException if the denominator is zero. - */ - explicit constexpr Fraction(int n, int d) : numerator(n), denominator(d) { - if (denominator == 0) { - throw FractionException("Denominator cannot be zero."); - } - reduce(); - } - - /** - * @brief Constructs a new Fraction object with the given integer value. - * @param value The integer value. - */ - explicit constexpr Fraction(int value) : numerator(value), denominator(1) {} - - /** - * @brief Default constructor. Initializes the fraction as 0/1. - */ - constexpr Fraction() : Fraction(0, 1) {} - - /** - * @brief Adds another fraction to this fraction. - * @param other The fraction to add. - * @return Reference to the modified fraction. - * @throws FractionException on arithmetic overflow. - */ - auto operator+=(const Fraction& other) -> Fraction&; - - /** - * @brief Subtracts another fraction from this fraction. - * @param other The fraction to subtract. - * @return Reference to the modified fraction. - * @throws FractionException on arithmetic overflow. - */ - auto operator-=(const Fraction& other) -> Fraction&; - - /** - * @brief Multiplies this fraction by another fraction. - * @param other The fraction to multiply by. - * @return Reference to the modified fraction. - * @throws FractionException if multiplication leads to zero denominator. - */ - auto operator*=(const Fraction& other) -> Fraction&; - - /** - * @brief Divides this fraction by another fraction. - * @param other The fraction to divide by. - * @return Reference to the modified fraction. - * @throws FractionException if division by zero occurs. - */ - auto operator/=(const Fraction& other) -> Fraction&; - - /** - * @brief Adds another fraction to this fraction. - * @param other The fraction to add. - * @return The result of addition. - */ - auto operator+(const Fraction& other) const -> Fraction; - - /** - * @brief Subtracts another fraction from this fraction. - * @param other The fraction to subtract. - * @return The result of subtraction. - */ - auto operator-(const Fraction& other) const -> Fraction; - - /** - * @brief Multiplies this fraction by another fraction. - * @param other The fraction to multiply by. - * @return The result of multiplication. - */ - auto operator*(const Fraction& other) const -> Fraction; - - /** - * @brief Divides this fraction by another fraction. - * @param other The fraction to divide by. - * @return The result of division. - */ - auto operator/(const Fraction& other) const -> Fraction; - -#if __cplusplus >= 202002L - /** - * @brief Compares this fraction with another fraction. - * @param other The fraction to compare with. - * @return A std::strong_ordering indicating the comparison result. - */ - auto operator<=>(const Fraction& other) const -> std::strong_ordering; -#endif - - /** - * @brief Checks if this fraction is equal to another fraction. - * @param other The fraction to compare with. - * @return True if fractions are equal, false otherwise. - */ - auto operator==(const Fraction& other) const -> bool; - - /** - * @brief Converts the fraction to a double value. - * @return The fraction as a double. - */ - explicit operator double() const; - - /** - * @brief Converts the fraction to a float value. - * @return The fraction as a float. - */ - explicit operator float() const; - - /** - * @brief Converts the fraction to an integer value. - * @return The fraction as an integer (truncates towards zero). - */ - explicit operator int() const; - - /** - * @brief Converts the fraction to a string representation. - * @return The string representation of the fraction. - */ - [[nodiscard]] auto toString() const -> std::string; - - /** - * @brief Converts the fraction to a double value. - * @return The fraction as a double. - */ - [[nodiscard]] auto toDouble() const -> double; - - /** - * @brief Inverts the fraction (reciprocal). - * @return Reference to the modified fraction. - * @throws FractionException if numerator is zero. - */ - auto invert() -> Fraction&; - - /** - * @brief Returns the absolute value of the fraction. - * @return A new Fraction representing the absolute value. - */ - [[nodiscard]] auto abs() const -> Fraction; - - /** - * @brief Checks if the fraction is zero. - * @return True if the fraction is zero, false otherwise. - */ - [[nodiscard]] auto isZero() const -> bool; - - /** - * @brief Checks if the fraction is positive. - * @return True if the fraction is positive, false otherwise. - */ - [[nodiscard]] auto isPositive() const -> bool; - - /** - * @brief Checks if the fraction is negative. - * @return True if the fraction is negative, false otherwise. - */ - [[nodiscard]] auto isNegative() const -> bool; - - /** - * @brief Outputs the fraction to the output stream. - * @param os The output stream. - * @param f The fraction to output. - * @return Reference to the output stream. - */ - friend auto operator<<(std::ostream& os, - const Fraction& f) -> std::ostream&; - - /** - * @brief Inputs the fraction from the input stream. - * @param is The input stream. - * @param f The fraction to input. - * @return Reference to the input stream. - * @throws FractionException if the input format is invalid or denominator - * is zero. - */ - friend auto operator>>(std::istream& is, Fraction& f) -> std::istream&; -}; - -/** - * @brief Creates a Fraction from an integer. - * @param value The integer value. - * @return A Fraction representing the integer. - */ -auto makeFraction(int value) -> Fraction; - -/** - * @brief Creates a Fraction from a double by approximating it. - * @param value The double value. - * @param max_denominator The maximum allowed denominator to limit the - * approximation. - * @return A Fraction approximating the double value. - */ -auto makeFraction(double value, int max_denominator = 1000000) -> Fraction; - -} // namespace atom::algorithm - -#endif // ATOM_ALGORITHM_FRACTION_HPP \ No newline at end of file diff --git a/src/atom/algorithm/hash.hpp b/src/atom/algorithm/hash.hpp deleted file mode 100644 index c78b458f..00000000 --- a/src/atom/algorithm/hash.hpp +++ /dev/null @@ -1,224 +0,0 @@ -/* - * hash.hpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2024-3-28 - -Description: A collection of optimized and enhanced hash algorithms - -**************************************************/ - -#ifndef ATOM_ALGORITHM_HASH_HPP -#define ATOM_ALGORITHM_HASH_HPP - -#include -#include -#include -#include -#include -#include -#include - -namespace atom::algorithm { - -/** - * @brief Concept for types that can be hashed. - * - * A type is Hashable if it supports hashing via std::hash and the result is - * convertible to std::size_t. - */ -template -concept Hashable = requires(T a) { - { std::hash{}(a) } -> std::convertible_to; -}; - -/** - * @brief Combines two hash values into one. - * - * This function implements the hash combining technique proposed by Boost. - * - * @param seed The initial hash value. - * @param hash The hash value to combine with the seed. - * @return std::size_t The combined hash value. - */ -inline auto hashCombine(std::size_t seed, - std::size_t hash) noexcept -> std::size_t { - // Magic number from Boost library - return seed ^ (hash + 0x9e3779b9 + (seed << 6) + (seed >> 2)); -} - -/** - * @brief Computes the hash value for a single Hashable value. - * - * @tparam T Type of the value to hash, must satisfy Hashable concept. - * @param value The value to hash. - * @return std::size_t Hash value of the input value. - */ -template -inline auto computeHash(const T& value) noexcept -> std::size_t { - return std::hash{}(value); -} - -/** - * @brief Computes the hash value for a vector of Hashable values. - * - * @tparam T Type of the elements in the vector, must satisfy Hashable concept. - * @param values The vector of values to hash. - * @return std::size_t Hash value of the vector of values. - */ -template -inline auto computeHash(const std::vector& values) noexcept -> std::size_t { - std::size_t result = 0; - for (const auto& value : values) { - result = hashCombine(result, computeHash(value)); - } - return result; -} - -/** - * @brief Computes the hash value for a tuple of Hashable values. - * - * @tparam Ts Types of the elements in the tuple, all must satisfy Hashable - * concept. - * @param tuple The tuple of values to hash. - * @return std::size_t Hash value of the tuple of values. - */ -template -inline auto computeHash(const std::tuple& tuple) noexcept - -> std::size_t { - std::size_t result = 0; - std::apply( - [&result](const Ts&... values) { - ((result = hashCombine(result, computeHash(values))), ...); - }, - tuple); - return result; -} - -/** - * @brief Computes the hash value for an array of Hashable values. - * - * @tparam T Type of the elements in the array, must satisfy Hashable concept. - * @tparam N Size of the array. - * @param array The array of values to hash. - * @return std::size_t Hash value of the array of values. - */ -template -inline auto computeHash(const std::array& array) noexcept -> std::size_t { - std::size_t result = 0; - for (const auto& value : array) { - result = hashCombine(result, computeHash(value)); - } - return result; -} - -/** - * @brief Computes the hash value for a std::pair of Hashable values. - * - * @tparam T1 Type of the first element in the pair, must satisfy Hashable - * concept. - * @tparam T2 Type of the second element in the pair, must satisfy Hashable - * concept. - * @param pair The pair of values to hash. - * @return std::size_t Hash value of the pair of values. - */ -template -inline auto computeHash(const std::pair& pair) noexcept -> std::size_t { - std::size_t seed = computeHash(pair.first); - seed = hashCombine(seed, computeHash(pair.second)); - return seed; -} - -/** - * @brief Computes the hash value for a std::optional of a Hashable value. - * - * @tparam T Type of the value inside the optional, must satisfy Hashable - * concept. - * @param opt The optional value to hash. - * @return std::size_t Hash value of the optional value. - */ -template -inline auto computeHash(const std::optional& opt) noexcept -> std::size_t { - if (opt.has_value()) { - return computeHash(*opt) + - 1; // Adding 1 to differentiate from std::nullopt - } - return 0; -} - -/** - * @brief Computes the hash value for a std::variant of Hashable types. - * - * @tparam Ts Types contained in the variant, all must satisfy Hashable concept. - * @param var The variant of values to hash. - * @return std::size_t Hash value of the variant value. - */ -template -inline auto computeHash(const std::variant& var) noexcept - -> std::size_t { - return std::visit( - [](const auto& value) -> std::size_t { return computeHash(value); }, - var); -} - -/** - * @brief Computes the hash value for a std::any value. - * - * This function attempts to hash the contained value if it is Hashable. - * If the contained type is not Hashable, it hashes the type information - * instead. - * - * @param value The std::any value to hash. - * @return std::size_t Hash value of the std::any value. - */ -inline auto computeHash(const std::any& value) noexcept -> std::size_t { - if (value.has_value()) { - const std::type_info& type = value.type(); - // Hashing the type information as a fallback - return type.hash_code(); - } - return 0; -} - -/** - * @brief Computes a hash value for a null-terminated string using FNV-1a - * algorithm. - * - * @param str Pointer to the null-terminated string to hash. - * @param basis Initial basis value for hashing. - * @return constexpr std::size_t Hash value of the string. - */ -constexpr auto hash(const char* str, - std::size_t basis = 2166136261u) noexcept -> std::size_t { - std::size_t hash = basis; - while (*str != '\0') { - hash ^= static_cast(*str); - hash *= 16777619u; - ++str; - } - return hash; -} - -/** - * @brief User-defined literal for computing hash values of string literals. - * - * Example usage: "example"_hash - * - * @param str Pointer to the string literal to hash. - * @param size Size of the string literal (unused). - * @return constexpr std::size_t Hash value of the string literal. - */ -constexpr auto operator""_hash(const char* str, - std::size_t size) noexcept -> std::size_t { - // The size parameter is not used in this implementation - static_cast(size); - return hash(str); -} - -} // namespace atom::algorithm - -#endif // ATOM_ALGORITHM_HASH_HPP \ No newline at end of file diff --git a/src/atom/algorithm/huffman.cpp b/src/atom/algorithm/huffman.cpp deleted file mode 100644 index 69bdd500..00000000 --- a/src/atom/algorithm/huffman.cpp +++ /dev/null @@ -1,276 +0,0 @@ -/* - * huffman.cpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2023-11-24 - -Description: Enhanced implementation of Huffman encoding - -**************************************************/ - -#include "huffman.hpp" -#include -#include -#include -#include -#include - -namespace atom::algorithm { - -/* ------------------------ HuffmanNode Implementation ------------------------ - */ - -HuffmanNode::HuffmanNode(unsigned char data, int frequency) - : data(data), frequency(frequency), left(nullptr), right(nullptr) {} - -/* ------------------------ Priority Queue Comparator ------------------------ - */ - -struct CompareNode { - bool operator()(const std::shared_ptr& a, - const std::shared_ptr& b) const { - return a->frequency > b->frequency; - } -}; - -/* ------------------------ createHuffmanTree ------------------------ */ - -auto createHuffmanTree(const std::unordered_map& - frequencies) -> std::shared_ptr { - if (frequencies.empty()) { - throw HuffmanException( - "Frequency map is empty. Cannot create Huffman Tree."); - } - - std::priority_queue, - std::vector>, CompareNode> - minHeap; - - // Initialize heap with leaf nodes - for (const auto& [data, freq] : frequencies) { - minHeap.push(std::make_shared(data, freq)); - } - - // Edge case: Only one unique byte - if (minHeap.size() == 1) { - auto soleNode = std::move(minHeap.top()); - minHeap.pop(); - auto parent = std::make_shared('\0', soleNode->frequency); - parent->left = std::move(soleNode); - parent->right = nullptr; - minHeap.push(std::move(parent)); - } - - // Build Huffman Tree - while (minHeap.size() > 1) { - auto left = std::move(minHeap.top()); - minHeap.pop(); - auto right = std::move(minHeap.top()); - minHeap.pop(); - - auto merged = std::make_shared( - '\0', left->frequency + right->frequency); - merged->left = std::move(left); - merged->right = std::move(right); - - minHeap.push(std::move(merged)); - } - - return minHeap.empty() ? nullptr : std::move(minHeap.top()); -} - -/* ------------------------ generateHuffmanCodes ------------------------ */ - -void generateHuffmanCodes( - const HuffmanNode* root, const std::string& code, - std::unordered_map& huffmanCodes) { - if (root == nullptr) { - throw HuffmanException( - "Cannot generate Huffman codes from a null tree."); - } - - if (!root->left && !root->right) { - if (code.empty()) { - // Edge case: Only one unique byte - huffmanCodes[root->data] = "0"; - } else { - huffmanCodes[root->data] = code; - } - return; - } - - if (root->left) { - generateHuffmanCodes(root->left.get(), code + "0", huffmanCodes); - } - - if (root->right) { - generateHuffmanCodes(root->right.get(), code + "1", huffmanCodes); - } -} - -/* ------------------------ compressData ------------------------ */ - -auto compressData(const std::vector& data, - const std::unordered_map& - huffmanCodes) -> std::string { - std::string compressedData; - compressedData.reserve(data.size() * 2); // Approximate reserve - - for (unsigned char byte : data) { - auto it = huffmanCodes.find(byte); - if (it == huffmanCodes.end()) { - throw HuffmanException( - std::string("Byte '") + std::to_string(static_cast(byte)) + - "' does not have a corresponding Huffman code."); - } - compressedData += it->second; - } - - return compressedData; -} - -/* ------------------------ decompressData ------------------------ */ - -auto decompressData(const std::string& compressedData, - const HuffmanNode* root) -> std::vector { - if (!root) { - throw HuffmanException("Huffman tree is null. Cannot decompress data."); - } - - std::vector decompressedData; - const HuffmanNode* current = root; - - for (char bit : compressedData) { - if (bit == '0') { - if (current->left) { - current = current->left.get(); - } else { - throw HuffmanException( - "Invalid compressed data. Traversed to a null left child."); - } - } else if (bit == '1') { - if (current->right) { - current = current->right.get(); - } else { - throw HuffmanException( - "Invalid compressed data. Traversed to a null right " - "child."); - } - } else { - throw HuffmanException( - "Invalid bit in compressed data. Only '0' and '1' are " - "allowed."); - } - - // If leaf node, append the data and reset to root - if (!current->left && !current->right) { - decompressedData.push_back(current->data); - current = root; - } - } - - // Edge case: compressed data does not end at a leaf node - if (current != root) { - throw HuffmanException( - "Incomplete compressed data. Did not end at a leaf node."); - } - - return decompressedData; -} - -/* ------------------------ serializeTree ------------------------ */ - -auto serializeTree(const HuffmanNode* root) -> std::string { - if (root == nullptr) { - throw HuffmanException("Cannot serialize a null Huffman tree."); - } - - std::string serialized; - std::function serializeHelper = - [&](const HuffmanNode* node) { - if (!node) { - serialized += '1'; // Marker for null - return; - } - - if (!node->left && !node->right) { - serialized += '0'; // Marker for leaf - serialized += node->data; - } else { - serialized += '2'; // Marker for internal node - serializeHelper(node->left.get()); - serializeHelper(node->right.get()); - } - }; - - serializeHelper(root); - return serialized; -} - -/* ------------------------ deserializeTree ------------------------ */ - -auto deserializeTree(const std::string& serializedTree, - size_t& index) -> std::shared_ptr { - if (index >= serializedTree.size()) { - throw HuffmanException( - "Invalid serialized tree format: Unexpected end of data."); - } - - char marker = serializedTree[index++]; - if (marker == '1') { - return nullptr; - } else if (marker == '0') { - if (index >= serializedTree.size()) { - throw HuffmanException( - "Invalid serialized tree format: Missing byte data for leaf " - "node."); - } - unsigned char data = serializedTree[index++]; - return std::make_shared( - data, 0); // Frequency is not needed for decompression - } else if (marker == '2') { - auto node = std::make_shared('\0', 0); - node->left = deserializeTree(serializedTree, index); - node->right = deserializeTree(serializedTree, index); - return node; - } else { - throw HuffmanException( - "Invalid serialized tree format: Unknown marker encountered."); - } -} - -/* ------------------------ visualizeHuffmanTree ------------------------ */ - -void visualizeHuffmanTree(const HuffmanNode* root, const std::string& indent) { - if (!root) { - std::cout << indent << "nullptr\n"; - return; - } - - if (!root->left && !root->right) { - std::cout << indent << "Leaf: '" << root->data << "'\n"; - } else { - std::cout << indent << "Internal Node (Frequency: " << root->frequency - << ")\n"; - } - - if (root->left) { - std::cout << indent << " Left:\n"; - visualizeHuffmanTree(root->left.get(), indent + " "); - } else { - std::cout << indent << " Left: nullptr\n"; - } - - if (root->right) { - std::cout << indent << " Right:\n"; - visualizeHuffmanTree(root->right.get(), indent + " "); - } else { - std::cout << indent << " Right: nullptr\n"; - } -} - -} // namespace atom::algorithm \ No newline at end of file diff --git a/src/atom/algorithm/huffman.hpp b/src/atom/algorithm/huffman.hpp deleted file mode 100644 index d4c00d2c..00000000 --- a/src/atom/algorithm/huffman.hpp +++ /dev/null @@ -1,169 +0,0 @@ -/* - * huffman.hpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2023-11-24 - -Description: Enhanced implementation of Huffman encoding - -**************************************************/ - -#ifndef ATOM_ALGORITHM_HUFFMAN_HPP -#define ATOM_ALGORITHM_HUFFMAN_HPP - -#include -#include -#include -#include -#include - -namespace atom::algorithm { - -/** - * @brief Exception class for Huffman encoding/decoding errors. - */ -class HuffmanException : public std::runtime_error { -public: - explicit HuffmanException(const std::string& message) - : std::runtime_error(message) {} -}; - -/** - * @brief Represents a node in the Huffman tree. - * - * This structure is used to construct the Huffman tree for encoding and - * decoding data based on byte frequencies. - */ -struct HuffmanNode { - unsigned char - data; ///< Byte stored in this node (used only in leaf nodes). - int frequency; ///< Frequency of the byte or sum of frequencies for - ///< internal nodes. - std::shared_ptr left; ///< Pointer to the left child node. - std::shared_ptr right; ///< Pointer to the right child node. - - /** - * @brief Constructs a new Huffman Node. - * - * @param data Byte to store in the node. - * @param frequency Frequency of the byte or combined frequency for a parent - * node. - */ - HuffmanNode(unsigned char data, int frequency); -}; - -/** - * @brief Creates a Huffman tree based on the frequency of bytes. - * - * This function builds a Huffman tree using the frequencies of bytes in - * the input data. It employs a priority queue to build the tree from the bottom - * up by merging the two least frequent nodes until only one node remains, which - * becomes the root. - * - * @param frequencies A map of bytes and their corresponding frequencies. - * @return A unique pointer to the root of the Huffman tree. - * @throws HuffmanException if the frequency map is empty. - */ -[[nodiscard]] auto createHuffmanTree( - const std::unordered_map& frequencies) - -> std::shared_ptr; - -/** - * @brief Generates Huffman codes for each byte from the Huffman tree. - * - * This function recursively traverses the Huffman tree and assigns a binary - * code to each byte. These codes are derived from the path taken to reach - * the byte: left child gives '0' and right child gives '1'. - * - * @param root Pointer to the root node of the Huffman tree. - * @param code Current Huffman code generated during the traversal. - * @param huffmanCodes A reference to a map where the byte and its - * corresponding Huffman code will be stored. - */ -void generateHuffmanCodes( - const HuffmanNode* root, const std::string& code, - std::unordered_map& huffmanCodes); - -/** - * @brief Compresses data using Huffman codes. - * - * This function converts a vector of bytes into a string of binary codes based - * on the Huffman codes provided. Each byte in the input data is replaced - * by its corresponding Huffman code. - * - * @param data The original data to compress. - * @param huffmanCodes The map of bytes to their corresponding Huffman codes. - * @return A string representing the compressed data. - * @throws HuffmanException if a byte in data does not have a corresponding - * Huffman code. - */ -[[nodiscard]] auto compressData( - const std::vector& data, - const std::unordered_map& huffmanCodes) - -> std::string; - -/** - * @brief Decompresses Huffman encoded data back to its original form. - * - * This function decodes a string of binary codes back into the original data - * using the provided Huffman tree. It traverses the Huffman tree from the root - * to the leaf nodes based on the binary string, reconstructing the original - * data. - * - * @param compressedData The Huffman encoded data. - * @param root Pointer to the root of the Huffman tree. - * @return The original decompressed data as a vector of bytes. - * @throws HuffmanException if the compressed data is invalid or the tree is - * null. - */ -[[nodiscard]] auto decompressData(const std::string& compressedData, - const HuffmanNode* root) - -> std::vector; - -/** - * @brief Serializes the Huffman tree into a binary string. - * - * This function converts the Huffman tree into a binary string representation - * which can be stored or transmitted alongside the compressed data. - * - * @param root Pointer to the root node of the Huffman tree. - * @return A binary string representing the serialized Huffman tree. - */ -[[nodiscard]] auto serializeTree(const HuffmanNode* root) -> std::string; - -/** - * @brief Deserializes the binary string back into a Huffman tree. - * - * This function reconstructs the Huffman tree from its binary string - * representation. - * - * @param serializedTree The binary string representing the serialized Huffman - * tree. - * @param index Reference to the current index in the binary string (used during - * recursion). - * @return A unique pointer to the root of the reconstructed Huffman tree. - * @throws HuffmanException if the serialized tree format is invalid. - */ -[[nodiscard]] auto deserializeTree(const std::string& serializedTree, - size_t& index) - -> std::shared_ptr; - -/** - * @brief Visualizes the Huffman tree structure. - * - * This function prints the Huffman tree in a human-readable format for - * debugging and analysis purposes. - * - * @param root Pointer to the root node of the Huffman tree. - * @param indent Current indentation level (used during recursion). - */ -void visualizeHuffmanTree(const HuffmanNode* root, - const std::string& indent = ""); - -} // namespace atom::algorithm - -#endif // ATOM_ALGORITHM_HUFFMAN_HPP \ No newline at end of file diff --git a/src/atom/algorithm/math.cpp b/src/atom/algorithm/math.cpp deleted file mode 100644 index 2578861b..00000000 --- a/src/atom/algorithm/math.cpp +++ /dev/null @@ -1,261 +0,0 @@ -/* - * mathutils.cpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2023-11-10 - -Description: Extra Math Library with SIMD support - -**************************************************/ - -#include "math.hpp" - -#include // For std::bit_width -#include // For std::sqrt -#include // For std::gcd -#ifdef _MSC_VER -#include // For std::runtime_error -#endif - -#include "atom/error/exception.hpp" - -// SIMD headers -#ifdef USE_SIMD -#if defined(__x86_64__) || defined(_M_X64) -#include -#elif defined(__ARM_NEON) -#include -#endif -#endif - -namespace atom::algorithm { - -#if defined(__GNUC__) && defined(__SIZEOF_INT128__) -auto mulDiv64(uint64_t operand, uint64_t multiplier, - uint64_t divider) -> uint64_t { - if (divider == 0) { - THROW_INVALID_ARGUMENT("Division by zero"); - } - - __uint128_t a = operand; - __uint128_t b = multiplier; - __uint128_t c = divider; - - return static_cast((a * b) / c); -} -#elif defined(_MSC_VER) -#include // For _umul128 and _BitScanReverse - -uint64_t mulDiv64(uint64_t operand, uint64_t multiplier, uint64_t divider) { - if (divider == 0) { - THROW_INVALID_ARGUMENT("Division by zero"); - } - - uint64_t highProd; - uint64_t lowProd = _umul128(operand, multiplier, &highProd); - - unsigned long shift = 63 - std::bit_width(divider - 1); - uint64_t normDiv = divider << shift; - - highProd = (highProd << shift) | (lowProd >> (64 - shift)); - lowProd <<= shift; - - uint64_t quotient; - _udiv128(highProd, lowProd, normDiv, "ient); - - return quotient; -} -#else -#error "Platform not supported for mulDiv64 function!" -#endif - -auto safeAdd(uint64_t a, uint64_t b) -> uint64_t { - uint64_t result; - if (__builtin_add_overflow(a, b, &result)) { - THROW_OVERFLOW("Overflow in addition"); - } - return result; -} - -auto safeMul(uint64_t a, uint64_t b) -> uint64_t { - uint64_t result; - if (__builtin_mul_overflow(a, b, &result)) { - THROW_OVERFLOW("Overflow in multiplication"); - } - return result; -} - -auto rotl64(uint64_t n, unsigned int c) -> uint64_t { return std::rotl(n, c); } - -auto rotr64(uint64_t n, unsigned int c) -> uint64_t { return std::rotr(n, c); } - -auto clz64(uint64_t x) -> int { - if (x == 0) { - return 64; - } - return __builtin_clzll(x); -} - -auto normalize(uint64_t x) -> uint64_t { - if (x == 0) { - return 0; - } - int n = clz64(x); - return x << n; -} - -auto safeSub(uint64_t a, uint64_t b) -> uint64_t { - uint64_t result; - if (__builtin_sub_overflow(a, b, &result)) { - THROW_UNDERFLOW("Underflow in subtraction"); - } - return result; -} - -auto safeDiv(uint64_t a, uint64_t b) -> uint64_t { - if (b == 0) { - THROW_INVALID_ARGUMENT("Division by zero"); - } - return a / b; -} - -auto bitReverse64(uint64_t n) -> uint64_t { -#ifdef USE_SIMD -#if defined(__x86_64__) || defined(_M_X64) - return _byteswap_uint64(n); -#elif defined(__ARM_NEON) - return vrev64_u8(vcreate_u8(n)); -#else - // Fallback to non-SIMD implementation -#endif -#endif - n = ((n & 0xAAAAAAAAAAAAAAAA) >> 1) | ((n & 0x5555555555555555) << 1); - n = ((n & 0xCCCCCCCCCCCCCCCC) >> 2) | ((n & 0x3333333333333333) << 2); - n = ((n & 0xF0F0F0F0F0F0F0F0) >> 4) | ((n & 0x0F0F0F0F0F0F0F0F) << 4); - n = ((n & 0xFF00FF00FF00FF00) >> 8) | ((n & 0x00FF00FF00FF00FF) << 8); - n = ((n & 0xFFFF0000FFFF0000) >> 16) | ((n & 0x0000FFFF0000FFFF) << 16); - n = (n >> 32) | (n << 32); - return n; -} - -auto approximateSqrt(uint64_t n) -> uint64_t { -#ifdef USE_SIMD -#if defined(__x86_64__) || defined(_M_X64) - return _mm_cvtsd_si64( - _mm_sqrt_sd(_mm_setzero_pd(), _mm_set_sd(static_cast(n)))); -#elif defined(__ARM_NEON) - float32x2_t x = vdup_n_f32(static_cast(n)); - float32x2_t sqrt_reciprocal = vrsqrte_f32(x); - float32x2_t result = vmul_f32(x, sqrt_reciprocal); - return static_cast(vget_lane_f32(result, 0)); -#else - // Fallback to non-SIMD implementation -#endif -#endif - if (n == 0 || n == 1) { - return n; - } - double x = n; - double y = 1; - double e = 0.000001; - while (x - y > e) { - x = (x + y) / 2; - y = n / x; - } - return static_cast(x); -} - -auto gcd64(uint64_t a, uint64_t b) -> uint64_t { return std::gcd(a, b); } - -auto lcm64(uint64_t a, uint64_t b) -> uint64_t { return a / gcd64(a, b) * b; } - -auto isPowerOfTwo(uint64_t n) -> bool { return n != 0 && (n & (n - 1)) == 0; } - -auto nextPowerOfTwo(uint64_t n) -> uint64_t { -#ifdef USE_SIMD -#if defined(__x86_64__) || defined(_M_X64) - if (n == 0) - return 1; - unsigned long index; - _BitScanReverse64(&index, n); - return 1ULL << (index + 1); -#elif defined(__ARM_NEON) - if (n == 0) - return 1; - return 1ULL << (64 - __builtin_clzll(n - 1)); -#else - // Fallback to non-SIMD implementation -#endif -#endif - if (n == 0) { - return 1; - } - --n; - n |= n >> 1; - n |= n >> 2; - n |= n >> 4; - n |= n >> 8; - n |= n >> 16; - n |= n >> 32; - return n + 1; -} - -// New SIMD-optimized functions - -#ifdef USE_SIMD - -template -void vectorAdd(const T* a, const T* b, T* result, size_t size) { -#if defined(__x86_64__) || defined(_M_X64) - for (size_t i = 0; i < size; i += N) { - __m256i va = - _mm256_loadu_si256(reinterpret_cast(a + i)); - __m256i vb = - _mm256_loadu_si256(reinterpret_cast(b + i)); - __m256i vr = _mm256_add_epi32(va, vb); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(result + i), vr); - } -#elif defined(__ARM_NEON) - for (size_t i = 0; i < size; i += N) { - int32x4_t va = vld1q_s32(reinterpret_cast(a + i)); - int32x4_t vb = vld1q_s32(reinterpret_cast(b + i)); - int32x4_t vr = vaddq_s32(va, vb); - vst1q_s32(reinterpret_cast(result + i), vr); - } -#endif -} - -template -void vectorMul(const T* a, const T* b, T* result, size_t size) { -#if defined(__x86_64__) || defined(_M_X64) - for (size_t i = 0; i < size; i += N) { - __m256i va = - _mm256_loadu_si256(reinterpret_cast(a + i)); - __m256i vb = - _mm256_loadu_si256(reinterpret_cast(b + i)); - __m256i vr = _mm256_mullo_epi32(va, vb); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(result + i), vr); - } -#elif defined(__ARM_NEON) - for (size_t i = 0; i < size; i += N) { - int32x4_t va = vld1q_s32(reinterpret_cast(a + i)); - int32x4_t vb = vld1q_s32(reinterpret_cast(b + i)); - int32x4_t vr = vmulq_s32(va, vb); - vst1q_s32(reinterpret_cast(result + i), vr); - } -#endif -} - -// Explicit instantiations for common types -template void vectorAdd(const int32_t*, const int32_t*, int32_t*, - size_t); -template void vectorMul(const int32_t*, const int32_t*, int32_t*, - size_t); - -#endif // USE_SIMD - -} // namespace atom::algorithm diff --git a/src/atom/algorithm/math.hpp b/src/atom/algorithm/math.hpp deleted file mode 100644 index 92a17f2f..00000000 --- a/src/atom/algorithm/math.hpp +++ /dev/null @@ -1,192 +0,0 @@ -/* - * math.hpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2023-11-10 - -Description: Extra Math Library - -**************************************************/ - -#ifndef ATOM_ALGORITHM_MATH_HPP -#define ATOM_ALGORITHM_MATH_HPP - -#include - -namespace atom::algorithm { -/** - * @brief Performs a 64-bit multiplication followed by division. - * - * This function calculates the result of (operant * multiplier) / divider. - * - * @param operant The first operand for multiplication. - * @param multiplier The second operand for multiplication. - * @param divider The divisor for the division operation. - * @return The result of (operant * multiplier) / divider. - */ -auto mulDiv64(uint64_t operant, uint64_t multiplier, - uint64_t divider) -> uint64_t; - -/** - * @brief Performs a safe addition operation. - * - * This function adds two unsigned 64-bit integers, handling potential overflow. - * - * @param a The first operand for addition. - * @param b The second operand for addition. - * @return The result of a + b, or 0 if there is an overflow. - */ -auto safeAdd(uint64_t a, uint64_t b) -> uint64_t; - -/** - * @brief Performs a safe multiplication operation. - * - * This function multiplies two unsigned 64-bit integers, handling potential - * overflow. - * - * @param a The first operand for multiplication. - * @param b The second operand for multiplication. - * @return The result of a * b, or 0 if there is an overflow. - */ -auto safeMul(uint64_t a, uint64_t b) -> uint64_t; - -/** - * @brief Rotates a 64-bit integer to the left. - * - * This function rotates a 64-bit integer to the left by a specified number of - * bits. - * - * @param n The 64-bit integer to rotate. - * @param c The number of bits to rotate. - * @return The rotated 64-bit integer. - */ -auto rotl64(uint64_t n, unsigned int c) -> uint64_t; - -/** - * @brief Rotates a 64-bit integer to the right. - * - * This function rotates a 64-bit integer to the right by a specified number of - * bits. - * - * @param n The 64-bit integer to rotate. - * @param c The number of bits to rotate. - * @return The rotated 64-bit integer. - */ -auto rotr64(uint64_t n, unsigned int c) -> uint64_t; - -/** - * @brief Counts the leading zeros in a 64-bit integer. - * - * This function counts the number of leading zeros in a 64-bit integer. - * - * @param x The 64-bit integer to count leading zeros in. - * @return The number of leading zeros in the 64-bit integer. - */ -auto clz64(uint64_t x) -> int; - -/** - * @brief Normalizes a 64-bit integer. - * - * This function normalizes a 64-bit integer by shifting it to the right until - * the most significant bit is set. - * - * @param x The 64-bit integer to normalize. - * @return The normalized 64-bit integer. - */ -auto normalize(uint64_t x) -> uint64_t; - -/** - * @brief Performs a safe subtraction operation. - * - * This function subtracts two unsigned 64-bit integers, handling potential - * underflow. - * - * @param a The first operand for subtraction. - * @param b The second operand for subtraction. - * @return The result of a - b, or 0 if there is an underflow. - */ -auto safeSub(uint64_t a, uint64_t b) -> uint64_t; - -/** - * @brief Performs a safe division operation. - * - * This function divides two unsigned 64-bit integers, handling potential - * division by zero. - * - * @param a The numerator for division. - * @param b The denominator for division. - * @return The result of a / b, or 0 if there is a division by zero. - */ -auto safeDiv(uint64_t a, uint64_t b) -> uint64_t; - -/** - * @brief Calculates the bitwise reverse of a 64-bit integer. - * - * This function calculates the bitwise reverse of a 64-bit integer. - * - * @param n The 64-bit integer to reverse. - * @return The bitwise reverse of the 64-bit integer. - */ -auto bitReverse64(uint64_t n) -> uint64_t; - -/** - * @brief Approximates the square root of a 64-bit integer. - * - * This function approximates the square root of a 64-bit integer using a fast - * algorithm. - * - * @param n The 64-bit integer for which to approximate the square root. - * @return The approximate square root of the 64-bit integer. - */ -auto approximateSqrt(uint64_t n) -> uint64_t; - -/** - * @brief Calculates the greatest common divisor (GCD) of two 64-bit integers. - * - * This function calculates the greatest common divisor (GCD) of two 64-bit - * integers. - * - * @param a The first 64-bit integer. - * @param b The second 64-bit integer. - * @return The greatest common divisor of the two 64-bit integers. - */ -auto gcd64(uint64_t a, uint64_t b) -> uint64_t; - -/** - * @brief Calculates the least common multiple (LCM) of two 64-bit integers. - * - * This function calculates the least common multiple (LCM) of two 64-bit - * integers. - * - * @param a The first 64-bit integer. - * @param b The second 64-bit integer. - * @return The least common multiple of the two 64-bit integers. - */ -auto lcm64(uint64_t a, uint64_t b) -> uint64_t; - -/** - * @brief Checks if a 64-bit integer is a power of two. - * - * This function checks if a 64-bit integer is a power of two. - * - * @param n The 64-bit integer to check. - * @return True if the 64-bit integer is a power of two, false otherwise. - */ -auto isPowerOfTwo(uint64_t n) -> bool; - -/** - * @brief Calculates the next power of two for a 64-bit integer. - * - * This function calculates the next power of two for a 64-bit integer. - * - * @param n The 64-bit integer for which to calculate the next power of two. - * @return The next power of two for the 64-bit integer. - */ -auto nextPowerOfTwo(uint64_t n) -> uint64_t; -} // namespace atom::algorithm - -#endif diff --git a/src/atom/algorithm/matrix.hpp b/src/atom/algorithm/matrix.hpp deleted file mode 100644 index f79976ba..00000000 --- a/src/atom/algorithm/matrix.hpp +++ /dev/null @@ -1,369 +0,0 @@ -#ifndef ATOM_ALGORITHM_MATRIX_HPP -#define ATOM_ALGORITHM_MATRIX_HPP - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "atom/error/exception.hpp" - -namespace atom::algorithm { -template -class Matrix; - -template -constexpr Matrix identity(); - -// 矩阵模板类,支持编译期矩阵计算 -template -class Matrix { -private: - std::array data_{}; - -public: - // 构造函数 - constexpr Matrix() = default; - constexpr explicit Matrix(const std::array& arr) - : data_(arr) {} - - // 访问矩阵元素 - constexpr auto operator()(std::size_t row, std::size_t col) -> T& { - return data_[row * Cols + col]; - } - - constexpr auto operator()(std::size_t row, - std::size_t col) const -> const T& { - return data_[row * Cols + col]; - } - - // 数据访问器 - auto getData() const -> const std::array& { return data_; } - - auto getData() -> std::array& { return data_; } - - // 打印矩阵 - void print(int width = 8, int precision = 2) const { - for (std::size_t i = 0; i < Rows; ++i) { - for (std::size_t j = 0; j < Cols; ++j) { - std::cout << std::setw(width) << std::fixed - << std::setprecision(precision) << (*this)(i, j) - << ' '; - } - std::cout << '\n'; - } - } - - // 矩阵的迹(对角线元素之和) - constexpr auto trace() const -> T { - static_assert(Rows == Cols, - "Trace is only defined for square matrices"); - T result = T{}; - for (std::size_t i = 0; i < Rows; ++i) { - result += (*this)(i, i); - } - return result; - } - - // Frobenius范数 - auto freseniusNorm() const -> T { - T sum = T{}; - for (const auto& elem : data_) { - sum += std::norm(elem); - } - return std::sqrt(sum); - } - - // 矩阵的最大元素 - auto maxElement() const -> T { - return *std::max_element( - data_.begin(), data_.end(), - [](const T& a, const T& b) { return std::abs(a) < std::abs(b); }); - } - - // 矩阵的最小元素 - auto minElement() const -> T { - return *std::min_element( - data_.begin(), data_.end(), - [](const T& a, const T& b) { return std::abs(a) < std::abs(b); }); - } - - // 判断矩阵是否为对称矩阵 - [[nodiscard]] auto isSymmetric() const -> bool { - static_assert(Rows == Cols, - "Symmetry is only defined for square matrices"); - for (std::size_t i = 0; i < Rows; ++i) { - for (std::size_t j = i + 1; j < Cols; ++j) { - if ((*this)(i, j) != (*this)(j, i)) { - return false; - } - } - } - return true; - } - - // 矩阵的幂运算 - auto pow(unsigned int n) const -> Matrix { - static_assert(Rows == Cols, - "Matrix power is only defined for square matrices"); - if (n == 0) { - return identity(); - } - if (n == 1) { - return *this; - } - Matrix result = *this; - for (unsigned int i = 1; i < n; ++i) { - result = result * (*this); - } - return result; - } - - // 矩阵的行列式(使用LU分解) - auto determinant() const -> T { - static_assert(Rows == Cols, - "Determinant is only defined for square matrices"); - auto [L, U] = lu_decomposition(*this); - T det = T{1}; - for (std::size_t i = 0; i < Rows; ++i) { - det *= U(i, i); - } - return det; - } - - // 矩阵的秩(使用高斯消元) - [[nodiscard]] auto rank() const -> std::size_t { - Matrix temp = *this; - std::size_t rank = 0; - for (std::size_t i = 0; i < Rows && i < Cols; ++i) { - // 找主元 - std::size_t pivot = i; - for (std::size_t j = i + 1; j < Rows; ++j) { - if (std::abs(temp(j, i)) > std::abs(temp(pivot, i))) { - pivot = j; - } - } - if (std::abs(temp(pivot, i)) < 1e-10) { - continue; - } - // 交换行 - if (pivot != i) { - for (std::size_t j = i; j < Cols; ++j) { - std::swap(temp(i, j), temp(pivot, j)); - } - } - // 消元 - for (std::size_t j = i + 1; j < Rows; ++j) { - T factor = temp(j, i) / temp(i, i); - for (std::size_t k = i; k < Cols; ++k) { - temp(j, k) -= factor * temp(i, k); - } - } - ++rank; - } - return rank; - } - - // 矩阵的条件数(使用2范数) - auto conditionNumber() const -> T { - static_assert(Rows == Cols, - "Condition number is only defined for square matrices"); - auto svd = singular_value_decomposition(*this); - return svd[0] / svd[svd.size() - 1]; - } -}; - -// 矩阵加法 -template -constexpr auto operator+(const Matrix& a, - const Matrix& b) - -> Matrix { - Matrix result{}; - for (std::size_t i = 0; i < Rows * Cols; ++i) { - result.get_data()[i] = a.get_data()[i] + b.get_data()[i]; - } - return result; -} - -// 矩阵减法 -template -constexpr auto operator-(const Matrix& a, - const Matrix& b) - -> Matrix { - Matrix result{}; - for (std::size_t i = 0; i < Rows * Cols; ++i) { - result.get_data()[i] = a.get_data()[i] - b.get_data()[i]; - } - return result; -} - -// 矩阵乘法 -template -auto operator*(const Matrix& a, - const Matrix& b) - -> Matrix { - Matrix result{}; - for (std::size_t i = 0; i < RowsA; ++i) { - for (std::size_t j = 0; j < ColsB; ++j) { - for (std::size_t k = 0; k < ColsA_RowsB; ++k) { - result(i, j) += a(i, k) * b(k, j); - } - } - } - return result; -} - -// 标量乘法(左乘和右乘) -template -constexpr auto operator*(const Matrix& m, U scalar) { - Matrix result; - for (std::size_t i = 0; i < Rows * Cols; ++i) { - result.get_data()[i] = m.get_data()[i] * scalar; - } - return result; -} - -template -constexpr auto operator*(U scalar, const Matrix& m) { - return m * scalar; -} - -// 矩阵逐元素乘法(Hadamard积) -template -constexpr auto hadamardProduct(const Matrix& a, - const Matrix& b) - -> Matrix { - Matrix result{}; - for (std::size_t i = 0; i < Rows * Cols; ++i) { - result.get_data()[i] = a.get_data()[i] * b.get_data()[i]; - } - return result; -} - -// 矩阵转置 -template -constexpr auto transpose(const Matrix& m) - -> Matrix { - Matrix result{}; - for (std::size_t i = 0; i < Rows; ++i) { - for (std::size_t j = 0; j < Cols; ++j) { - result(j, i) = m(i, j); - } - } - return result; -} - -// 创建单位矩阵 -template -constexpr auto identity() -> Matrix { - Matrix result{}; - for (std::size_t i = 0; i < Size; ++i) { - result(i, i) = T{1}; - } - return result; -} - -// 矩阵的LU分解 -template -auto luDecomposition(const Matrix& m) - -> std::pair, Matrix> { - Matrix L = identity(); - Matrix U = m; - - for (std::size_t k = 0; k < Size - 1; ++k) { - for (std::size_t i = k + 1; i < Size; ++i) { - if (std::abs(U(k, k)) < 1e-10) { - THROW_RUNTIME_ERROR( - "LU decomposition failed: division by zero"); - } - T factor = U(i, k) / U(k, k); - L(i, k) = factor; - for (std::size_t j = k; j < Size; ++j) { - U(i, j) -= factor * U(k, j); - } - } - } - - return {L, U}; -} - -// 矩阵的奇异值分解(仅返回奇异值) -template -auto singularValueDecomposition(const Matrix& m) - -> std::vector { - const std::size_t n = std::min(Rows, Cols); - Matrix mt = transpose(m); - Matrix mtm = mt * m; - - // 使用幂法计算最大特征值和对应的特征向量 - auto powerIteration = [&mtm](std::size_t max_iter = 100, T tol = 1e-10) { - std::vector v(Cols); - std::generate(v.begin(), v.end(), - []() { return static_cast(rand()) / RAND_MAX; }); - T lambdaOld = 0; - for (std::size_t iter = 0; iter < max_iter; ++iter) { - std::vector vNew(Cols); - for (std::size_t i = 0; i < Cols; ++i) { - for (std::size_t j = 0; j < Cols; ++j) { - vNew[i] += mtm(i, j) * v[j]; - } - } - T lambda = 0; - for (std::size_t i = 0; i < Cols; ++i) { - lambda += vNew[i] * v[i]; - } - T norm = std::sqrt(std::inner_product(vNew.begin(), vNew.end(), - vNew.begin(), T(0))); - for (auto& x : vNew) { - x /= norm; - } - if (std::abs(lambda - lambdaOld) < tol) { - return std::sqrt(lambda); - } - lambdaOld = lambda; - v = vNew; - } - THROW_RUNTIME_ERROR("Power iteration did not converge"); - }; - - std::vector singularValues; - for (std::size_t i = 0; i < n; ++i) { - T sigma = powerIteration(); - singularValues.push_back(sigma); - // Deflate the matrix - Matrix vvt; - for (std::size_t j = 0; j < Cols; ++j) { - for (std::size_t k = 0; k < Cols; ++k) { - vvt(j, k) = mtm(j, k) / (sigma * sigma); - } - } - mtm = mtm - vvt; - } - - std::sort(singularValues.begin(), singularValues.end(), std::greater()); - return singularValues; -} - -// 生成随机矩阵 -template -auto randomMatrix(T min = 0, T max = 1) -> Matrix { - static std::random_device rd; - static std::mt19937 gen(rd()); - std::uniform_real_distribution<> dis(min, max); - - Matrix result; - for (auto& elem : result.get_data()) { - elem = dis(gen); - } - return result; -} - -} // namespace atom::algorithm - -#endif diff --git a/src/atom/algorithm/matrix_compress.cpp b/src/atom/algorithm/matrix_compress.cpp deleted file mode 100644 index 0fc3ec7b..00000000 --- a/src/atom/algorithm/matrix_compress.cpp +++ /dev/null @@ -1,272 +0,0 @@ -#include "matrix_compress.hpp" - -#include -#include -#include -#include -#include "error/exception.hpp" - -#if USE_SIMD -#include -#endif - -namespace atom::algorithm { -auto MatrixCompressor::compress(const Matrix& matrix) -> CompressedData { - CompressedData compressed; - if (matrix.empty() || matrix[0].empty()) { - return compressed; - } - - char currentChar = matrix[0][0]; - int count = 0; - -#ifdef USE_SIMD - // 使用 SIMD 优化压缩 - for (const auto& row : matrix) { - for (size_t i = 0; i < row.size(); i += 16) { - __m128i chars = - _mm_loadu_si128(reinterpret_cast(&row[i])); - for (int j = 0; j < 16; ++j) { - char ch = reinterpret_cast(&chars)[j]; - if (ch == currentChar) { - count++; - } else { - compressed.emplace_back(currentChar, count); - currentChar = ch; - count = 1; - } - } - } - } -#else - // 常规压缩 - for (const auto& row : matrix) { - for (char ch : row) { - if (ch == currentChar) { - count++; - } else { - compressed.emplace_back(currentChar, count); - currentChar = ch; - count = 1; - } - } - } -#endif - compressed.emplace_back(currentChar, count); - - return compressed; -} - -auto MatrixCompressor::decompress(const CompressedData& compressed, int rows, - int cols) -> Matrix { - Matrix matrix(rows, std::vector(cols)); - int index = 0; - -#ifdef USE_SIMD - // 使用 SIMD 优化解压缩 - for (const auto& [ch, count] : compressed) { - __m128i chars = _mm_set1_epi8(ch); - for (int i = 0; i < count; i += 16) { - int remaining = std::min(16, count - i); - for (int j = 0; j < remaining; ++j) { - int row = index / cols; - int col = index % cols; - if (row >= rows || col >= cols) { - THROW_MATRIX_DECOMPRESS_EXCEPTION( - "Decompression error: Invalid matrix size"); - } - matrix[row][col] = reinterpret_cast(&chars)[j]; - index++; - } - } - } -#else - // 常规解压缩 - for (const auto& [ch, count] : compressed) { - for (int i = 0; i < count; ++i) { - int row = index / cols; - int col = index % cols; - if (row >= rows || col >= cols) { - THROW_MATRIX_DECOMPRESS_EXCEPTION( - "Decompression error: Invalid matrix size"); - } - matrix[row][col] = ch; - index++; - } - } -#endif - - if (index != rows * cols) { - THROW_MATRIX_DECOMPRESS_EXCEPTION( - "Decompression error: Incorrect number of elements"); - } - - return matrix; -} - -void MatrixCompressor::printMatrix(const Matrix& matrix) { - for (const auto& row : matrix) { - for (char ch : row) { - std::cout << ch << ' '; - } - std::cout << '\n'; - } -} - -auto MatrixCompressor::generateRandomMatrix( - int rows, int cols, const std::string& charset) -> Matrix { - std::random_device randomDevice; - std::mt19937 generator(randomDevice()); - std::uniform_int_distribution distribution( - 0, static_cast(charset.length()) - 1); - - Matrix matrix(rows, std::vector(cols)); - for (auto& row : matrix) { - std::ranges::generate(row.begin(), row.end(), [&]() { - return charset[distribution(generator)]; - }); - } - return matrix; -} - -void MatrixCompressor::saveCompressedToFile(const CompressedData& compressed, - const std::string& filename) { - std::ofstream file(filename, std::ios::binary); - if (!file) { - THROW_FAIL_TO_OPEN_FILE("Unable to open file for writing: " + filename); - } - - for (const auto& [ch, count] : compressed) { - file.write(reinterpret_cast(&ch), sizeof(ch)); - file.write(reinterpret_cast(&count), sizeof(count)); - } -} - -auto MatrixCompressor::loadCompressedFromFile(const std::string& filename) - -> CompressedData { - std::ifstream file(filename, std::ios::binary); - if (!file) { - THROW_FAIL_TO_OPEN_FILE("Unable to open file for reading: " + filename); - } - - CompressedData compressed; - char ch; - int count; - while (file.read(reinterpret_cast(&ch), sizeof(ch)) && - file.read(reinterpret_cast(&count), sizeof(count))) { - compressed.emplace_back(ch, count); - } - - return compressed; -} - -auto MatrixCompressor::calculateCompressionRatio( - const Matrix& original, const CompressedData& compressed) -> double { - size_t originalSize = original.size() * original[0].size() * sizeof(char); - size_t compressedSize = compressed.size() * (sizeof(char) + sizeof(int)); - return static_cast(compressedSize) / - static_cast(originalSize); -} - -auto MatrixCompressor::downsample(const Matrix& matrix, int factor) -> Matrix { - if (factor <= 0) { - THROW_INVALID_ARGUMENT("Downsampling factor must be positive"); - } - - int rows = static_cast(matrix.size()); - int cols = static_cast(matrix[0].size()); - int newRows = std::max(1, rows / factor); - int newCols = std::max(1, cols / factor); - - Matrix downsampled(newRows, std::vector(newCols)); - - for (int i = 0; i < newRows; ++i) { - for (int j = 0; j < newCols; ++j) { - // 使用简单的平均值作为降采样策略 - int sum = 0; - int count = 0; - for (int di = 0; di < factor && i * factor + di < rows; ++di) { - for (int dj = 0; dj < factor && j * factor + dj < cols; ++dj) { - sum += matrix[i * factor + di][j * factor + dj]; - count++; - } - } - downsampled[i][j] = static_cast(sum / count); - } - } - - return downsampled; -} - -auto MatrixCompressor::upsample(const Matrix& matrix, int factor) -> Matrix { - if (factor <= 0) { - THROW_INVALID_ARGUMENT("Upsampling factor must be positive"); - } - - int rows = static_cast(matrix.size()); - int cols = static_cast(matrix[0].size()); - int newRows = rows * factor; - int newCols = cols * factor; - - Matrix upsampled(newRows, std::vector(newCols)); - - for (int i = 0; i < newRows; ++i) { - for (int j = 0; j < newCols; ++j) { - // 使用最近邻插值 - upsampled[i][j] = matrix[i / factor][j / factor]; - } - } - - return upsampled; -} - -auto MatrixCompressor::calculateMSE(const Matrix& matrix1, - const Matrix& matrix2) -> double { - if (matrix1.size() != matrix2.size() || - matrix1[0].size() != matrix2[0].size()) { - THROW_INVALID_ARGUMENT("Matrices must have the same dimensions"); - } - - double mse = 0.0; - auto rows = static_cast(matrix1.size()); - auto cols = static_cast(matrix1[0].size()); - - for (int i = 0; i < rows; ++i) { - for (int j = 0; j < cols; ++j) { - double diff = static_cast(matrix1[i][j]) - - static_cast(matrix2[i][j]); - mse += diff * diff; - } - } - - return mse / (rows * cols); -} - -#if ATOM_ENABLE_DEBUG -void performanceTest(int rows, int cols) { - auto matrix = MatrixCompressor::generateRandomMatrix(rows, cols); - - auto start = std::chrono::high_resolution_clock::now(); - auto compressed = MatrixCompressor::compress(matrix); - auto end = std::chrono::high_resolution_clock::now(); - - std::chrono::duration compression_time = end - start; - - start = std::chrono::high_resolution_clock::now(); - auto decompressed = MatrixCompressor::decompress(compressed, rows, cols); - end = std::chrono::high_resolution_clock::now(); - - std::chrono::duration decompression_time = end - start; - - double compression_ratio = - MatrixCompressor::calculateCompressionRatio(matrix, compressed); - - std::cout << "Matrix size: " << rows << "x" << cols << "\n"; - std::cout << "Compression time: " << compression_time.count() << " ms\n"; - std::cout << "Decompression time: " << decompression_time.count() - << " ms\n"; - std::cout << "Compression ratio: " << compression_ratio << "\n"; - std::cout << "Compressed size: " << compressed.size() << " elements\n"; -} -#endif -} // namespace atom::algorithm diff --git a/src/atom/algorithm/matrix_compress.hpp b/src/atom/algorithm/matrix_compress.hpp deleted file mode 100644 index 26441547..00000000 --- a/src/atom/algorithm/matrix_compress.hpp +++ /dev/null @@ -1,139 +0,0 @@ -#ifndef MATRIX_COMPRESS_HPP -#define MATRIX_COMPRESS_HPP - -#include -#include - -#include "atom/error/exception.hpp" - -class MatrixCompressException : public atom::error::Exception { -public: - using atom::error::Exception::Exception; -}; - -#define THROW_MATRIX_COMPRESS_EXCEPTION(...) \ - throw MatrixCompressException(ATOM_FILE_NAME, ATOM_FILE_LINE, \ - ATOM_FUNC_NAME, __VA_ARGS__); - -#define THROW_NESTED_MATRIX_COMPRESS_EXCEPTION(...) \ - MatrixCompressException::rethrowNested(ATOM_FILE_NAME, ATOM_FILE_LINE, \ - ATOM_FUNC_NAME, __VA_ARGS__); - -class MatrixDecompressException : public atom::error::Exception { -public: - using atom::error::Exception::Exception; -}; - -#define THROW_MATRIX_DECOMPRESS_EXCEPTION(...) \ - throw MatrixDecompressException(ATOM_FILE_NAME, ATOM_FILE_LINE, \ - ATOM_FUNC_NAME, __VA_ARGS__); - -#define THROW_NESTED_MATRIX_DECOMPRESS_EXCEPTION(...) \ - MatrixDecompressException::rethrowNested(ATOM_FILE_NAME, ATOM_FILE_LINE, \ - ATOM_FUNC_NAME, __VA_ARGS__); - -namespace atom::algorithm { -/** - * @class MatrixCompressor - * @brief A class for compressing and decompressing matrices. - */ -class MatrixCompressor { -public: - using Matrix = std::vector>; - using CompressedData = std::vector>; - - /** - * @brief Compresses a matrix using run-length encoding. - * @param matrix The matrix to compress. - * @return The compressed data. - */ - static auto compress(const Matrix& matrix) -> CompressedData; - - /** - * @brief Decompresses data into a matrix. - * @param compressed The compressed data. - * @param rows The number of rows in the decompressed matrix. - * @param cols The number of columns in the decompressed matrix. - * @return The decompressed matrix. - */ - static auto decompress(const CompressedData& compressed, int rows, - int cols) -> Matrix; - - /** - * @brief Prints the matrix to the standard output. - * @param matrix The matrix to print. - */ - static void printMatrix(const Matrix& matrix); - - /** - * @brief Generates a random matrix. - * @param rows The number of rows in the matrix. - * @param cols The number of columns in the matrix. - * @param charset The set of characters to use for generating the matrix. - * @return The generated random matrix. - */ - static auto generateRandomMatrix( - int rows, int cols, const std::string& charset = "ABCD") -> Matrix; - - /** - * @brief Saves the compressed data to a file. - * @param compressed The compressed data to save. - * @param filename The name of the file to save the data to. - */ - static void saveCompressedToFile(const CompressedData& compressed, - const std::string& filename); - - /** - * @brief Loads compressed data from a file. - * @param filename The name of the file to load the data from. - * @return The loaded compressed data. - */ - static auto loadCompressedFromFile(const std::string& filename) - -> CompressedData; - - /** - * @brief Calculates the compression ratio. - * @param original The original matrix. - * @param compressed The compressed data. - * @return The compression ratio. - */ - static auto calculateCompressionRatio( - const Matrix& original, const CompressedData& compressed) -> double; - - /** - * @brief Downsamples a matrix by a given factor. - * @param matrix The matrix to downsample. - * @param factor The downsampling factor. - * @return The downsampled matrix. - */ - static auto downsample(const Matrix& matrix, int factor) -> Matrix; - - /** - * @brief Upsamples a matrix by a given factor. - * @param matrix The matrix to upsample. - * @param factor The upsampling factor. - * @return The upsampled matrix. - */ - static auto upsample(const Matrix& matrix, int factor) -> Matrix; - - /** - * @brief Calculates the mean squared error (MSE) between two matrices. - * @param matrix1 The first matrix. - * @param matrix2 The second matrix. - * @return The mean squared error. - */ - static auto calculateMSE(const Matrix& matrix1, - const Matrix& matrix2) -> double; -}; - -#if ATOM_ENABLE_DEBUG -/** - * @brief Runs a performance test on matrix compression and decompression. - * @param rows The number of rows in the test matrix. - * @param cols The number of columns in the test matrix. - */ -void performanceTest(int rows, int cols); -#endif -} // namespace atom::algorithm - -#endif // MATRIX_COMPRESS_HPP diff --git a/src/atom/algorithm/md5.cpp b/src/atom/algorithm/md5.cpp deleted file mode 100644 index cbe7741d..00000000 --- a/src/atom/algorithm/md5.cpp +++ /dev/null @@ -1,171 +0,0 @@ -/* - * md5.cpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2023-11-10 - -Description: Self implemented MD5 algorithm. - -**************************************************/ - -#include "md5.hpp" - -#include -#include -#include -#include -#include -#include - -#ifdef USE_OPENMP -#include -#endif - -namespace atom::algorithm { - -constexpr std::array T{ - 0xd76aa478, 0xe8c7b756, 0x242070db, 0xc1bdceee, 0xf57c0faf, 0x4787c62a, - 0xa8304613, 0xfd469501, 0x698098d8, 0x8b44f7af, 0xffff5bb1, 0x895cd7be, - 0x6b901122, 0xfd987193, 0xa679438e, 0x49b40821, 0xf61e2562, 0xc040b340, - 0x265e5a51, 0xe9b6c7aa, 0xd62f105d, 0x02441453, 0xd8a1e681, 0xe7d3fbc8, - 0x21e1cde6, 0xc33707d6, 0xf4d50d87, 0x455a14ed, 0xa9e3e905, 0xfcefa3f8, - 0x676f02d9, 0x8d2a4c8a, 0xfffa3942, 0x8771f681, 0x6d9d6122, 0xfde5380c, - 0xa4beea44, 0x4bdecfa9, 0xf6bb4b60, 0xbebfbc70, 0x289b7ec6, 0xeaa127fa, - 0xd4ef3085, 0x04881d05, 0xd9d4d039, 0xe6db99e5, 0x1fa27cf8, 0xc4ac5665, - 0xf4292244, 0x432aff97, 0xab9423a7, 0xfc93a039, 0x655b59c3, 0x8f0ccc92, - 0xffeff47d, 0x85845dd1, 0x6fa87e4f, 0xfe2ce6e0, 0xa3014314, 0x4e0811a1, - 0xf7537e82, 0xbd3af235, 0x2ad7d2bb, 0xeb86d391}; - -constexpr std::array s{ - 7, 12, 17, 22, 7, 12, 17, 22, 7, 12, 17, 22, 7, 12, 17, 22, - 5, 9, 14, 20, 5, 9, 14, 20, 5, 9, 14, 20, 5, 9, 14, 20, - 4, 11, 16, 23, 4, 11, 16, 23, 4, 11, 16, 23, 4, 11, 16, 23, - 6, 10, 15, 21, 6, 10, 15, 21, 6, 10, 15, 21, 6, 10, 15, 21}; - -void MD5::init() { - a_ = 0x67452301; - b_ = 0xefcdab89; - c_ = 0x98badcfe; - d_ = 0x10325476; - count_ = 0; - buffer_.clear(); -} - -void MD5::update(const std::string &input) { - auto update_length = [this](size_t length) { count_ += length * 8; }; - - update_length(input.size()); - - for (char ch : input) { - buffer_.push_back(static_cast(ch)); - if (buffer_.size() == 64) { - processBlock(buffer_.data()); - buffer_.clear(); - } - } - - // Padding - buffer_.push_back(0x80); - while (buffer_.size() < 56) { - buffer_.push_back(0x00); - } - - for (int i = 0; i < 8; ++i) { - buffer_.push_back(static_cast((count_ >> (i * 8)) & 0xff)); - } - - processBlock(buffer_.data()); -} - -auto MD5::finalize() -> std::string { - std::stringstream ss; - ss << std::hex << std::setfill('0'); - ss << std::setw(8) << std::byteswap(a_); - ss << std::setw(8) << std::byteswap(b_); - ss << std::setw(8) << std::byteswap(c_); - ss << std::setw(8) << std::byteswap(d_); - return ss.str(); -} - -void MD5::processBlock(const uint8_t *block) { - std::array M; - for (size_t i = 0; i < 16; ++i) { - M[i] = std::bit_cast( - std::array{block[i * 4], block[i * 4 + 1], - block[i * 4 + 2], block[i * 4 + 3]}); - } - - uint32_t a = a_; - uint32_t b = b_; - uint32_t c = c_; - uint32_t d = d_; - -#ifdef USE_OPENMP -#pragma omp parallel for -#endif - for (uint32_t i = 0; i < 64; ++i) { - uint32_t f, g; - if (i < 16) { - f = F(b, c, d); - g = i; - } else if (i < 32) { - f = G(b, c, d); - g = (5 * i + 1) % 16; - } else if (i < 48) { - f = H(b, c, d); - g = (3 * i + 5) % 16; - } else { - f = I(b, c, d); - g = (7 * i) % 16; - } - - uint32_t temp = d; - d = c; - c = b; - b += leftRotate(a + f + T[i] + M[g], s[i]); - a = temp; - } - -#ifdef USE_OPENMP -#pragma omp critical -#endif - { - a_ += a; - b_ += b; - c_ += c; - d_ += d; - } -} - -auto MD5::F(uint32_t x, uint32_t y, uint32_t z) -> uint32_t { - return (x & y) | (~x & z); -} - -auto MD5::G(uint32_t x, uint32_t y, uint32_t z) -> uint32_t { - return (x & z) | (y & ~z); -} - -auto MD5::H(uint32_t x, uint32_t y, uint32_t z) -> uint32_t { - return x ^ y ^ z; -} - -auto MD5::I(uint32_t x, uint32_t y, uint32_t z) -> uint32_t { - return y ^ (x | ~z); -} - -auto MD5::leftRotate(uint32_t x, uint32_t n) -> uint32_t { - return std::rotl(x, n); -} - -auto MD5::encrypt(const std::string &input) -> std::string { - MD5 md5; - md5.init(); - md5.update(input); - return md5.finalize(); -} - -} // namespace atom::algorithm diff --git a/src/atom/algorithm/md5.hpp b/src/atom/algorithm/md5.hpp deleted file mode 100644 index 05120581..00000000 --- a/src/atom/algorithm/md5.hpp +++ /dev/null @@ -1,119 +0,0 @@ -/* - * md5.hpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2023-11-10 - -Description: Self implemented MD5 algorithm. - -**************************************************/ - -#ifndef ATOM_UTILS_MD5_HPP -#define ATOM_UTILS_MD5_HPP - -#include -#include -#include - -namespace atom::algorithm { - -/** - * @class MD5 - * @brief A class that implements the MD5 hashing algorithm. - */ -class MD5 { -public: - /** - * @brief Encrypts the input string using the MD5 algorithm. - * @param input The input string to be hashed. - * @return The MD5 hash of the input string. - */ - static auto encrypt(const std::string &input) -> std::string; - -private: - /** - * @brief Initializes the MD5 context. - */ - void init(); - - /** - * @brief Updates the MD5 context with a new input string. - * @param input The input string to update the context with. - */ - void update(const std::string &input); - - /** - * @brief Finalizes the MD5 hash and returns the result. - * @return The finalized MD5 hash as a string. - */ - auto finalize() -> std::string; - - /** - * @brief Processes a 512-bit block of the input. - * @param block A pointer to the 512-bit block. - */ - void processBlock(const uint8_t *block); - - /** - * @brief MD5 auxiliary function F. - * @param x Input value. - * @param y Input value. - * @param z Input value. - * @return The result of the function. - */ - static auto F(uint32_t x, uint32_t y, uint32_t z) -> uint32_t; - - /** - * @brief MD5 auxiliary function G. - * @param x Input value. - * @param y Input value. - * @param z Input value. - * @return The result of the function. - */ - static auto G(uint32_t x, uint32_t y, uint32_t z) -> uint32_t; - - /** - * @brief MD5 auxiliary function H. - * @param x Input value. - * @param y Input value. - * @param z Input value. - * @return The result of the function. - */ - static auto H(uint32_t x, uint32_t y, uint32_t z) -> uint32_t; - - /** - * @brief MD5 auxiliary function I. - * @param x Input value. - * @param y Input value. - * @param z Input value. - * @return The result of the function. - */ - static auto I(uint32_t x, uint32_t y, uint32_t z) -> uint32_t; - - /** - * @brief Rotates the bits of x to the left by n positions. - * @param x The value to be rotated. - * @param n The number of positions to rotate. - * @return The rotated value. - */ - static auto leftRotate(uint32_t x, uint32_t n) -> uint32_t; - - /** - * @brief Reverses the byte order of a 32-bit value. - * @param x The value to reverse. - * @return The byte-reversed value. - */ - static auto reverseBytes(uint32_t x) -> uint32_t; - - uint32_t a_, b_, c_, d_; ///< MD5 state variables. - uint64_t count_; ///< Number of bits processed. - std::vector buffer_; ///< Input buffer. -}; - -} // namespace atom::algorithm - -#endif // ATOM_UTILS_MD5_HPP diff --git a/src/atom/algorithm/mhash.cpp b/src/atom/algorithm/mhash.cpp deleted file mode 100644 index 05026c3e..00000000 --- a/src/atom/algorithm/mhash.cpp +++ /dev/null @@ -1,335 +0,0 @@ -/* - * mhash.cpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2023-12-16 - -Description: Implementation of murmur3 hash and quick hash - -**************************************************/ - -#include "mhash.hpp" - -#include -#include -#include -#include -#include - -#include "atom/error/exception.hpp" -#include "atom/utils/random.hpp" - -#include -#include -#include -#include - -namespace atom::algorithm { -// Keccak state constants -constexpr size_t K_KECCAK_F_RATE = 1088; // For Keccak-256 -constexpr size_t K_ROUNDS = 24; -constexpr size_t K_STATE_SIZE = 5; -constexpr size_t K_RATE_IN_BYTES = K_KECCAK_F_RATE / 8; -constexpr uint8_t K_PADDING_BYTE = 0x06; -constexpr uint8_t K_PADDING_LAST_BYTE = 0x80; - -// Round constants for Keccak -constexpr std::array K_ROUND_CONSTANTS = { - 0x0000000000000001ULL, 0x0000000000008082ULL, 0x800000000000808aULL, - 0x8000000080008000ULL, 0x000000000000808bULL, 0x0000000080000001ULL, - 0x8000000080008081ULL, 0x8000000000008009ULL, 0x000000000000008aULL, - 0x0000000000000088ULL, 0x0000000080008009ULL, 0x000000008000000aULL, - 0x000000008000808bULL, 0x800000000000008bULL, 0x8000000000008089ULL, - 0x8000000000008003ULL, 0x8000000000008002ULL, 0x8000000000000080ULL, - 0x000000000000800aULL, 0x800000008000000aULL, 0x8000000080008081ULL, - 0x8000000000008080ULL, 0x0000000080000001ULL, 0x8000000080008008ULL}; - -// Rotation offsets -constexpr std::array, K_STATE_SIZE> - K_ROTATION_CONSTANTS = {{{0, 1, 62, 28, 27}, - {36, 44, 6, 55, 20}, - {3, 10, 43, 25, 39}, - {41, 45, 15, 21, 8}, - {18, 2, 61, 56, 14}}}; - -// Keccak state as 5x5 matrix of 64-bit integers -using StateArray = std::array, K_STATE_SIZE>; - -namespace { -#if USE_OPENCL -const char *minhashKernelSource = R"CLC( -__kernel void minhash_kernel(__global const size_t* hashes, __global size_t* signature, __global const size_t* a_values, __global const size_t* b_values, const size_t p, const size_t num_hashes, const size_t num_elements) { - int gid = get_global_id(0); - if (gid < num_hashes) { - size_t min_hash = SIZE_MAX; - size_t a = a_values[gid]; - size_t b = b_values[gid]; - for (size_t i = 0; i < num_elements; ++i) { - size_t h = (a * hashes[i] + b) % p; - if (h < min_hash) { - min_hash = h; - } - } - signature[gid] = min_hash; - } -} -)CLC"; -#endif -} // anonymous namespace - -MinHash::MinHash(size_t num_hashes) -#if USE_OPENCL - : opencl_available_(false) -#endif -{ - hash_functions_.reserve(num_hashes); - for (size_t i = 0; i < num_hashes; ++i) { - hash_functions_.emplace_back(generateHashFunction()); - } -#if USE_OPENCL - initializeOpenCL(); -#endif -} - -MinHash::~MinHash() { -#if USE_OPENCL - cleanupOpenCL(); -#endif -} - -#if USE_OPENCL -void MinHash::initializeOpenCL() { - cl_int err; - cl_platform_id platform; - cl_device_id device; - - err = clGetPlatformIDs(1, &platform, nullptr); - if (err != CL_SUCCESS) { - return; - } - - err = clGetDeviceIDs(platform, CL_DEVICE_TYPE_GPU, 1, &device, nullptr); - if (err != CL_SUCCESS) { - return; - } - - context_ = clCreateContext(nullptr, 1, &device, nullptr, nullptr, &err); - if (err != CL_SUCCESS) { - return; - } - - queue_ = clCreateCommandQueue(context_, device, 0, &err); - if (err != CL_SUCCESS) { - return; - } - - program_ = clCreateProgramWithSource(context_, 1, &minhashKernelSource, - nullptr, &err); - if (err != CL_SUCCESS) { - return; - } - - err = clBuildProgram(program_, 1, &device, nullptr, nullptr, nullptr); - if (err != CL_SUCCESS) { - return; - } - - minhash_kernel_ = clCreateKernel(program_, "minhash_kernel", &err); - if (err == CL_SUCCESS) { - opencl_available_ = true; - } -} - -void MinHash::cleanupOpenCL() { - if (opencl_available_) { - clReleaseKernel(minhash_kernel_); - clReleaseProgram(program_); - clReleaseCommandQueue(queue_); - clReleaseContext(context_); - } -} -#endif - -auto MinHash::generateHashFunction() -> HashFunction { - utils::Random> rand( - 0, std::numeric_limits::max()); - - size_t a = rand(); - size_t b = rand(); - size_t p = std::numeric_limits::max(); - - return [a, b, p](size_t x) -> size_t { return (a * x + b) % p; }; -} - -auto MinHash::jaccardIndex(const std::vector &sig1, - const std::vector &sig2) -> double { - size_t equalCount = 0; - - for (size_t i = 0; i < sig1.size(); ++i) { - if (sig1[i] == sig2[i]) { - ++equalCount; - } - } - - return static_cast(equalCount) / sig1.size(); -} - -auto hexstringFromData(const std::string &data) -> std::string { - const char *hexChars = "0123456789ABCDEF"; - std::string output; - output.reserve(data.size() * 2); // Reserve space for the hex string - - for (unsigned char byte : data) { - output.push_back(hexChars[(byte >> 4) & 0x0F]); - output.push_back(hexChars[byte & 0x0F]); - } - - return output; -} - -auto dataFromHexstring(const std::string &data) -> std::string { - if (data.size() % 2 != 0) { - THROW_INVALID_ARGUMENT("Hex string length must be even"); - } - - std::string result; - result.resize(data.size() / 2); - - size_t outputIndex = 0; - for (size_t i = 0; i < data.size(); i += 2) { - int byte = 0; - auto [ptr, ec] = - std::from_chars(data.data() + i, data.data() + i + 2, byte, 16); - - if (ec == std::errc::invalid_argument || ptr != data.data() + i + 2) { - THROW_INVALID_ARGUMENT("Invalid hex character"); - } - - result[outputIndex++] = static_cast(byte); - } - - return result; -} - -// θ step: XOR each column and then propagate changes across the state -inline void theta(StateArray &stateArray) { - std::array column, diff; - for (size_t colIndex = 0; colIndex < K_STATE_SIZE; ++colIndex) { - column[colIndex] = stateArray[colIndex][0] ^ stateArray[colIndex][1] ^ - stateArray[colIndex][2] ^ stateArray[colIndex][3] ^ - stateArray[colIndex][4]; - } - for (size_t colIndex = 0; colIndex < K_STATE_SIZE; ++colIndex) { - diff[colIndex] = column[(colIndex + 4) % K_STATE_SIZE] ^ - std::rotl(column[(colIndex + 1) % K_STATE_SIZE], 1); - for (size_t rowIndex = 0; rowIndex < K_STATE_SIZE; ++rowIndex) { - stateArray[colIndex][rowIndex] ^= diff[colIndex]; - } - } -} - -// ρ step: Rotate each bit-plane by pre-determined offsets -inline void rho(StateArray &stateArray) { - for (size_t colIndex = 0; colIndex < K_STATE_SIZE; ++colIndex) { - for (size_t rowIndex = 0; rowIndex < K_STATE_SIZE; ++rowIndex) { - stateArray[colIndex][rowIndex] = std::rotl( - stateArray[colIndex][rowIndex], - static_cast(K_ROTATION_CONSTANTS[colIndex][rowIndex])); - } - } -} - -// π step: Permute bits to new positions based on a fixed pattern -inline void pi(StateArray &stateArray) { - StateArray temp = stateArray; - for (size_t colIndex = 0; colIndex < K_STATE_SIZE; ++colIndex) { - for (size_t rowIndex = 0; rowIndex < K_STATE_SIZE; ++rowIndex) { - stateArray[colIndex][rowIndex] = - temp[(colIndex + 3 * rowIndex) % K_STATE_SIZE][colIndex]; - } - } -} - -// χ step: Non-linear step XORs data across rows, producing diffusion -inline void chi(StateArray &stateArray) { - for (size_t rowIndex = 0; rowIndex < K_STATE_SIZE; ++rowIndex) { - std::array temp = stateArray[rowIndex]; - for (size_t colIndex = 0; colIndex < K_STATE_SIZE; ++colIndex) { - stateArray[colIndex][rowIndex] ^= - (~temp[(colIndex + 1) % K_STATE_SIZE] & - temp[(colIndex + 2) % K_STATE_SIZE]); - } - } -} - -// ι step: XOR a round constant into the first state element -inline void iota(StateArray &stateArray, size_t round) { - stateArray[0][0] ^= K_ROUND_CONSTANTS[round]; -} - -// Keccak-p permutation: 24 rounds of transformations on the state -inline void keccakP(StateArray &stateArray) { - for (size_t round = 0; round < K_ROUNDS; ++round) { - theta(stateArray); - rho(stateArray); - pi(stateArray); - chi(stateArray); - iota(stateArray, round); - } -} - -// Absorb phase: XOR input into the state and permute -void absorb(StateArray &state, const uint8_t *input, size_t length) { - while (length >= K_RATE_IN_BYTES) { - for (size_t i = 0; i < K_RATE_IN_BYTES / 8; ++i) { - state[i % K_STATE_SIZE][i / K_STATE_SIZE] ^= - std::bit_cast(input + i * 8); - } - keccakP(state); - input += K_RATE_IN_BYTES; - length -= K_RATE_IN_BYTES; - } -} - -// Padding and absorbing the last block -void padAndAbsorb(StateArray &state, const uint8_t *input, size_t length) { - std::array paddedBlock = {}; - std::memcpy(paddedBlock.data(), input, length); - paddedBlock[length] = K_PADDING_BYTE; // Keccak padding - paddedBlock.back() |= K_PADDING_LAST_BYTE; // Set last bit to 1 - absorb(state, paddedBlock.data(), paddedBlock.size()); -} - -// Squeeze phase: Extract output from the state -void squeeze(StateArray &state, uint8_t *output, size_t outputLength) { - while (outputLength >= K_RATE_IN_BYTES) { - for (size_t i = 0; i < K_RATE_IN_BYTES / 8; ++i) { - std::memcpy(output + i * 8, - &state[i % K_STATE_SIZE][i / K_STATE_SIZE], 8); - } - keccakP(state); - output += K_RATE_IN_BYTES; - outputLength -= K_RATE_IN_BYTES; - } - for (size_t i = 0; i < outputLength / 8; ++i) { - std::memcpy(output + i * 8, &state[i % K_STATE_SIZE][i / K_STATE_SIZE], - 8); - } -} - -// Keccak-256 hashing function -auto keccak256(const uint8_t *input, - size_t length) -> std::array { - StateArray state = {}; - padAndAbsorb(state, input, length); - - std::array hash = {}; - squeeze(state, hash.data(), hash.size()); - return hash; -} - -} // namespace atom::algorithm diff --git a/src/atom/algorithm/mhash.hpp b/src/atom/algorithm/mhash.hpp deleted file mode 100644 index 881176c5..00000000 --- a/src/atom/algorithm/mhash.hpp +++ /dev/null @@ -1,223 +0,0 @@ -/* - * mhash.hpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2023-12-16 - -Description: Implementation of murmur3 hash and quick hash - -**************************************************/ - -#ifndef ATOM_ALGORITHM_MHASH_HPP -#define ATOM_ALGORITHM_MHASH_HPP - -#include -#include -#include -#include -#include -#include - -#if USE_OPENCL -#include -#endif - -#include "atom/macro.hpp" - -namespace atom::algorithm { -constexpr size_t K_HASH_SIZE = 32; - -/** - * @brief Converts a string to a hexadecimal string representation. - * - * @param data The input string. - * @return std::string The hexadecimal string representation. - */ -ATOM_NODISCARD auto hexstringFromData(const std::string& data) -> std::string; - -/** - * @brief Converts a hexadecimal string representation to binary data. - * - * @param data The input hexadecimal string. - * @return std::string The binary data. - * @throw std::invalid_argument If the input hexstring is not a valid - * hexadecimal string. - */ -ATOM_NODISCARD auto dataFromHexstring(const std::string& data) -> std::string; - -/** - * @brief Implements the MinHash algorithm for estimating Jaccard similarity. - * - * The MinHash algorithm generates hash signatures for sets and estimates the - * Jaccard index between sets based on these signatures. - */ -class MinHash { -public: - /** - * @brief Type definition for a hash function used in MinHash. - */ - using HashFunction = std::function; - - /** - * @brief Constructs a MinHash object with a specified number of hash - * functions. - * - * @param num_hashes The number of hash functions to use for MinHash. - */ - explicit MinHash(size_t num_hashes); - - /** - * @brief Destructor to clean up OpenCL resources. - */ - ~MinHash(); - - /** - * @brief Computes the MinHash signature (hash values) for a given set. - * - * @tparam Range Type of the range representing the set elements. - * @param set The set for which to compute the MinHash signature. - * @return std::vector MinHash signature (hash values) for the set. - */ - template - auto computeSignature(const Range& set) const -> std::vector { - std::vector signature(hash_functions_.size(), - std::numeric_limits::max()); -#if USE_OPENCL - if (opencl_available_) { - computeSignatureOpenCL(set, signature); - } else { -#endif - for (const auto& element : set) { - size_t elementHash = - std::hash{}(element); - for (size_t i = 0; i < hash_functions_.size(); ++i) { - signature[i] = - std::min(signature[i], hash_functions_[i](elementHash)); - } - } -#if USE_OPENCL - } -#endif - return signature; - } - - /** - * @brief Computes the Jaccard index between two sets based on their MinHash - * signatures. - * - * @param sig1 MinHash signature of the first set. - * @param sig2 MinHash signature of the second set. - * @return double Estimated Jaccard index between the two sets. - */ - static auto jaccardIndex(const std::vector& sig1, - const std::vector& sig2) -> double; - -private: - /** - * @brief Vector of hash functions used for MinHash. - */ - std::vector hash_functions_; - - /** - * @brief Generates a hash function suitable for MinHash. - * - * @return HashFunction Generated hash function. - */ - static auto generateHashFunction() -> HashFunction; - -#if USE_OPENCL - /** - * @brief OpenCL resources and state. - */ - cl_context context_; - cl_command_queue queue_; - cl_program program_; - cl_kernel minhash_kernel_; - bool opencl_available_; - - /** - * @brief Initializes OpenCL context and resources. - */ - void initializeOpenCL(); - - /** - * @brief Cleans up OpenCL resources. - */ - void cleanupOpenCL(); - - /** - * @brief Computes the MinHash signature using OpenCL. - * - * @tparam Range Type of the range representing the set elements. - * @param set The set for which to compute the MinHash signature. - * @param signature The vector to store the computed signature. - */ - template - void computeSignatureOpenCL(const Range& set, - std::vector& signature) const { - cl_int err; - size_t numHashes = hash_functions_.size(); - size_t numElements = set.size(); - - std::vector hashes; - hashes.reserve(numElements); - for (const auto& element : set) { - hashes.push_back(std::hash{}(element)); - } - - std::vector aValues(numHashes); - std::vector bValues(numHashes); - for (size_t i = 0; i < numHashes; ++i) { - aValues; // Use the generated hash function's "a" value - bValues; // Use the generated hash function's "b" value - } - - cl_mem hashesBuffer = - clCreateBuffer(context_, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, - numElements * sizeof(size_t), hashes.data(), &err); - cl_mem signatureBuffer = - clCreateBuffer(context_, CL_MEM_WRITE_ONLY, - numHashes * sizeof(size_t), nullptr, &err); - cl_mem aValuesBuffer = - clCreateBuffer(context_, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, - numHashes * sizeof(size_t), aValues.data(), &err); - cl_mem bValuesBuffer = - clCreateBuffer(context_, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, - numHashes * sizeof(size_t), bValues.data(), &err); - - size_t p = std::numeric_limits::max(); - - clSetKernelArg(minhash_kernel_, 0, sizeof(cl_mem), &hashesBuffer); - clSetKernelArg(minhash_kernel_, 1, sizeof(cl_mem), &signatureBuffer); - clSetKernelArg(minhash_kernel_, 2, sizeof(cl_mem), &aValuesBuffer); - clSetKernelArg(minhash_kernel_, 3, sizeof(cl_mem), &bValuesBuffer); - clSetKernelArg(minhash_kernel_, 4, sizeof(size_t), &p); - clSetKernelArg(minhash_kernel_, 5, sizeof(size_t), &numHashes); - clSetKernelArg(minhash_kernel_, 6, sizeof(size_t), &numElements); - - size_t globalWorkSize = numHashes; - clEnqueueNDRangeKernel(queue_, minhash_kernel_, 1, nullptr, - &globalWorkSize, nullptr, 0, nullptr, nullptr); - - clEnqueueReadBuffer(queue_, signatureBuffer, CL_TRUE, 0, - numHashes * sizeof(size_t), signature.data(), 0, - nullptr, nullptr); - - clReleaseMemObject(hashesBuffer); - clReleaseMemObject(signatureBuffer); - clReleaseMemObject(aValuesBuffer); - clReleaseMemObject(bValuesBuffer); - } -#endif -}; - -auto keccak256(const uint8_t *input, - size_t length) -> std::array; - -} // namespace atom::algorithm - -#endif // ATOM_ALGORITHM_MHASH_HPP diff --git a/src/atom/algorithm/perlin.hpp b/src/atom/algorithm/perlin.hpp deleted file mode 100644 index 68e423c1..00000000 --- a/src/atom/algorithm/perlin.hpp +++ /dev/null @@ -1,335 +0,0 @@ -#ifndef ATOM_ALGORITHM_PERLIN_HPP -#define ATOM_ALGORITHM_PERLIN_HPP - -#include -#include -#include -#include -#include -#include -#include - -#ifdef USE_OPENCL // 宏定义:是否启用OpenCL -#include -#endif - -namespace atom::algorithm { -class PerlinNoise { -public: - explicit PerlinNoise( - unsigned int seed = std::default_random_engine::default_seed) { - p.resize(512); - std::iota(p.begin(), p.begin() + 256, 0); - - std::default_random_engine engine(seed); - std::ranges::shuffle(std::span(p.begin(), p.begin() + 256), engine); - - std::ranges::copy(std::span(p.begin(), p.begin() + 256), - p.begin() + 256); - -#ifdef USE_OPENCL - initializeOpenCL(); -#endif - } - - ~PerlinNoise() { -#ifdef USE_OPENCL - cleanupOpenCL(); -#endif - } - - template - [[nodiscard]] auto noise(T x, T y, T z) const -> T { -#ifdef USE_OPENCL - if (opencl_available_) { - return noiseOpenCL(x, y, z); - } -#endif - return noiseCPU(x, y, z); - } - - template - [[nodiscard]] auto octaveNoise(T x, T y, T z, int octaves, - T persistence) const -> T { - T total = 0; - T frequency = 1; - T amplitude = 1; - T maxValue = 0; - - for (int i = 0; i < octaves; ++i) { - total += - noise(x * frequency, y * frequency, z * frequency) * amplitude; - maxValue += amplitude; - amplitude *= persistence; - frequency *= 2; - } - - return total / maxValue; - } - - [[nodiscard]] auto generateNoiseMap( - int width, int height, double scale, int octaves, double persistence, - double /*lacunarity*/, - int seed = std::default_random_engine::default_seed) const - -> std::vector> { - std::vector> noiseMap(height, - std::vector(width)); - std::default_random_engine prng(seed); - std::uniform_real_distribution dist(-10000, 10000); - double offsetX = dist(prng); - double offsetY = dist(prng); - - for (int y = 0; y < height; ++y) { - for (int x = 0; x < width; ++x) { - double sampleX = (x - width / 2.0 + offsetX) / scale; - double sampleY = (y - height / 2.0 + offsetY) / scale; - noiseMap[y][x] = - octaveNoise(sampleX, sampleY, 0.0, octaves, persistence); - } - } - - return noiseMap; - } - -private: - std::vector p; - -#ifdef USE_OPENCL - cl_context context_; - cl_command_queue queue_; - cl_program program_; - cl_kernel noise_kernel_; - bool opencl_available_; - - void initializeOpenCL() { - cl_int err; - cl_platform_id platform; - cl_device_id device; - - err = clGetPlatformIDs(1, &platform, nullptr); - if (err != CL_SUCCESS) { - opencl_available_ = false; - return; - } - - err = clGetDeviceIDs(platform, CL_DEVICE_TYPE_GPU, 1, &device, nullptr); - if (err != CL_SUCCESS) { - opencl_available_ = false; - return; - } - - context_ = clCreateContext(nullptr, 1, &device, nullptr, nullptr, &err); - if (err != CL_SUCCESS) { - opencl_available_ = false; - return; - } - - queue_ = clCreateCommandQueue(context_, device, 0, &err); - if (err != CL_SUCCESS) { - opencl_available_ = false; - return; - } - - const char* kernel_source = R"CLC( - __kernel void noise_kernel(__global const float* coords, - __global float* result, - __constant int* p) { - int gid = get_global_id(0); - - float x = coords[gid * 3]; - float y = coords[gid * 3 + 1]; - float z = coords[gid * 3 + 2]; - - int X = ((int)floor(x)) & 255; - int Y = ((int)floor(y)) & 255; - int Z = ((int)floor(z)) & 255; - - x -= floor(x); - y -= floor(y); - z -= floor(z); - - float u = x * x * x * (x * (x * 6 - 15) + 10); - float v = y * y * y * (y * (y * 6 - 15) + 10); - float w = z * z * z * (z * (z * 6 - 15) + 10); - - int A = p[X] + Y; - int AA = p[A] + Z; - int AB = p[A + 1] + Z; - int B = p[X + 1] + Y; - int BA = p[B] + Z; - int BB = p[B + 1] + Z; - - float res = lerp(w, - lerp(v, lerp(u, grad(p[AA], x, y, z), - grad(p[BA], x - 1, y, z)), - lerp(u, grad(p[AB], x, y - 1, z), - grad(p[BB], x - 1, y - 1, z))), - lerp(v, lerp(u, grad(p[AA + 1], x, y, z - 1), - grad(p[BA + 1], x - 1, y, z - 1)), - lerp(u, grad(p[AB + 1], x, y - 1, z - 1), - grad(p[BB + 1], x - 1, y - 1, - z - 1)))); - result[gid] = (res + 1) / 2; - } - - float lerp(float t, float a, float b) { - return a + t * (b - a); - } - - float grad(int hash, float x, float y, float z) { - int h = hash & 15; - float u = h < 8 ? x : y; - float v = h < 4 ? y : (h == 12 || h == 14 ? x : z); - return ((h & 1) == 0 ? u : -u) + ((h & 2) == 0 ? v : -v); - } - )CLC"; - - program_ = clCreateProgramWithSource(context_, 1, &kernel_source, - nullptr, &err); - if (err != CL_SUCCESS) { - opencl_available_ = false; - return; - } - - err = clBuildProgram(program_, 1, &device, nullptr, nullptr, nullptr); - if (err != CL_SUCCESS) { - opencl_available_ = false; - return; - } - - noise_kernel_ = clCreateKernel(program_, "noise_kernel", &err); - if (err != CL_SUCCESS) { - opencl_available_ = false; - return; - } - - opencl_available_ = true; - } - - void cleanupOpenCL() { - if (opencl_available_) { - clReleaseKernel(noise_kernel_); - clReleaseProgram(program_); - clReleaseCommandQueue(queue_); - clReleaseContext(context_); - } - } - - template - auto noiseOpenCL(T x, T y, T z) const -> T { - float coords[] = {static_cast(x), static_cast(y), - static_cast(z)}; - float result; - - cl_mem coords_buffer = - clCreateBuffer(context_, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, - sizeof(coords), coords, nullptr); - cl_mem result_buffer = clCreateBuffer(context_, CL_MEM_WRITE_ONLY, - sizeof(float), nullptr, nullptr); - cl_mem p_buffer = - clCreateBuffer(context_, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, - p.size() * sizeof(int), p.data(), nullptr); - - clSetKernelArg(noise_kernel_, 0, sizeof(cl_mem), &coords_buffer); - clSetKernelArg(noise_kernel_, 1, sizeof(cl_mem), &result_buffer); - clSetKernelArg(noise_kernel_, 2, sizeof(cl_mem), &p_buffer); - - size_t global_work_size = 1; - clEnqueueNDRangeKernel(queue_, noise_kernel_, 1, nullptr, - &global_work_size, nullptr, 0, nullptr, nullptr); - - clEnqueueReadBuffer(queue_, result_buffer, CL_TRUE, 0, sizeof(float), - &result, 0, nullptr, nullptr); - - clReleaseMemObject(coords_buffer); - clReleaseMemObject(result_buffer); - clReleaseMemObject(p_buffer); - - return static_cast(result); - } -#endif // USE_OPENCL - - template - [[nodiscard]] auto noiseCPU(T x, T y, T z) const -> T { - // Find unit cube containing point - int X = static_cast(std::floor(x)) & 255; - int Y = static_cast(std::floor(y)) & 255; - int Z = static_cast(std::floor(z)) & 255; - - // Find relative x, y, z of point in cube - x -= std::floor(x); - y -= std::floor(y); - z -= std::floor(z); - -// Compute fade curves for each of x, y, z -#ifdef USE_SIMD - // SIMD-based fade function calculations - __m256d xSimd = _mm256_set1_pd(x); - __m256d ySimd = _mm256_set1_pd(y); - __m256d zSimd = _mm256_set1_pd(z); - - __m256d uSimd = - _mm256_mul_pd(xSimd, _mm256_sub_pd(xSimd, _mm256_set1_pd(15))); - uSimd = _mm256_mul_pd( - uSimd, _mm256_add_pd(_mm256_set1_pd(10), - _mm256_mul_pd(xSimd, _mm256_set1_pd(6)))); - // Apply similar SIMD operations for v and w if needed - __m256d vSimd = - _mm256_mul_pd(ySimd, _mm256_sub_pd(ySimd, _mm256_set1_pd(15))); - vSimd = _mm256_mul_pd( - vSimd, _mm256_add_pd(_mm256_set1_pd(10), - _mm256_mul_pd(ySimd, _mm256_set1_pd(6)))); - __m256d wSimd = - _mm256_mul_pd(zSimd, _mm256_sub_pd(zSimd, _mm256_set1_pd(15))); - wSimd = _mm256_mul_pd( - wSimd, _mm256_add_pd(_mm256_set1_pd(10), - _mm256_mul_pd(zSimd, _mm256_set1_pd(6)))); -#else - T u = fade(x); - T v = fade(y); - T w = fade(z); -#endif - - // Hash coordinates of the 8 cube corners - int A = p[X] + Y; - int AA = p[A] + Z; - int AB = p[A + 1] + Z; - int B = p[X + 1] + Y; - int BA = p[B] + Z; - int BB = p[B + 1] + Z; - - // Add blended results from 8 corners of cube - T res = lerp( - w, - lerp(v, lerp(u, grad(p[AA], x, y, z), grad(p[BA], x - 1, y, z)), - lerp(u, grad(p[AB], x, y - 1, z), - grad(p[BB], x - 1, y - 1, z))), - lerp(v, - lerp(u, grad(p[AA + 1], x, y, z - 1), - grad(p[BA + 1], x - 1, y, z - 1)), - lerp(u, grad(p[AB + 1], x, y - 1, z - 1), - grad(p[BB + 1], x - 1, y - 1, z - 1)))); - return (res + 1) / 2; // Normalize to [0,1] - } - - static constexpr auto fade(double t) noexcept -> double { - return t * t * t * (t * (t * 6 - 15) + 10); - } - - static constexpr auto lerp(double t, double a, - double b) noexcept -> double { - return a + t * (b - a); - } - - static constexpr auto grad(int hash, double x, double y, - double z) noexcept -> double { - int h = hash & 15; - double u = h < 8 ? x : y; - double v = h < 4 ? y : (h == 12 || h == 14 ? x : z); - return ((h & 1) == 0 ? u : -u) + ((h & 2) == 0 ? v : -v); - } -}; - -} // namespace atom::algorithm - -#endif // ATOM_ALGORITHM_PERLIN_HPP diff --git a/src/atom/algorithm/sha1.cpp b/src/atom/algorithm/sha1.cpp deleted file mode 100644 index 380dd02c..00000000 --- a/src/atom/algorithm/sha1.cpp +++ /dev/null @@ -1,138 +0,0 @@ -#include "sha1.hpp" - -#include -#include -#include - -namespace atom::algorithm { -SHA1::SHA1() { reset(); } - -void SHA1::update(const uint8_t* data, size_t length) { - size_t remaining = length; - size_t offset = 0; - - while (remaining > 0) { - size_t bufferOffset = (bitCount_ / 8) % BLOCK_SIZE; - - size_t bytesToFill = BLOCK_SIZE - bufferOffset; - size_t bytesToCopy = std::min(remaining, bytesToFill); - - std::copy(data + offset, data + offset + bytesToCopy, - buffer_.data() + bufferOffset); - offset += bytesToCopy; - remaining -= bytesToCopy; - bitCount_ += bytesToCopy * BITS_PER_BYTE; - - if (bufferOffset + bytesToCopy == BLOCK_SIZE) { - processBlock(buffer_.data()); - } - } -} - -std::array SHA1::digest() { - uint64_t bitLength = bitCount_; - - // Padding - size_t bufferOffset = (bitCount_ / 8) % BLOCK_SIZE; - buffer_[bufferOffset] = PADDING_BYTE; // Append the bit '1' - - if (bufferOffset >= BLOCK_SIZE - LENGTH_SIZE) { - // Not enough space for the length, process the block - processBlock(buffer_.data()); - std::fill(buffer_.begin(), buffer_.end(), 0); - } - - // Append the length of the message - for (size_t i = 0; i < LENGTH_SIZE; ++i) { - buffer_[BLOCK_SIZE - LENGTH_SIZE + i] = - (bitLength >> (LENGTH_SIZE * BITS_PER_BYTE - i * BITS_PER_BYTE)) & - BYTE_MASK; - } - processBlock(buffer_.data()); - - // Produce the final hash value - std::array result; - for (size_t i = 0; i < HASH_SIZE; ++i) { - result[i * 4] = (hash_[i] >> 24) & BYTE_MASK; - result[i * 4 + 1] = (hash_[i] >> 16) & BYTE_MASK; - result[i * 4 + 2] = (hash_[i] >> 8) & BYTE_MASK; - result[i * 4 + 3] = hash_[i] & BYTE_MASK; - } - - return result; -} - -void SHA1::reset() { - bitCount_ = 0; - hash_.fill(0); - hash_[0] = 0x67452301; - hash_[1] = 0xEFCDAB89; - hash_[2] = 0x98BADCFE; - hash_[3] = 0x10325476; - hash_[4] = 0xC3D2E1F0; - buffer_.fill(0); -} - -void SHA1::processBlock(const uint8_t* block) { - std::array schedule{}; - for (size_t i = 0; i < 16; ++i) { - schedule[i] = (block[i * 4] << 24) | (block[i * 4 + 1] << 16) | - (block[i * 4 + 2] << 8) | block[i * 4 + 3]; - } - - for (size_t i = 16; i < SCHEDULE_SIZE; ++i) { - schedule[i] = rotateLeft(schedule[i - 3] ^ schedule[i - 8] ^ - schedule[i - 14] ^ schedule[i - 16], - 1); - } - - uint32_t a = hash_[0]; - uint32_t b = hash_[1]; - uint32_t c = hash_[2]; - uint32_t d = hash_[3]; - uint32_t e = hash_[4]; - - for (size_t i = 0; i < SCHEDULE_SIZE; ++i) { - uint32_t f; - uint32_t k; - if (i < 20) { - f = (b & c) | (~b & d); - k = 0x5A827999; - } else if (i < 40) { - f = b ^ c ^ d; - k = 0x6ED9EBA1; - } else if (i < 60) { - f = (b & c) | (b & d) | (c & d); - k = 0x8F1BBCDC; - } else { - f = b ^ c ^ d; - k = 0xCA62C1D6; - } - - uint32_t temp = rotateLeft(a, 5) + f + e + k + schedule[i]; - e = d; - d = c; - c = rotateLeft(b, 30); - b = a; - a = temp; - } - - hash_[0] += a; - hash_[1] += b; - hash_[2] += c; - hash_[3] += d; - hash_[4] += e; -} - -auto SHA1::rotateLeft(uint32_t value, size_t bits) -> uint32_t { - return (value << bits) | (value >> (WORD_SIZE - bits)); -} - -auto bytesToHex(const std::array& bytes) -> std::string { - std::ostringstream oss; - for (uint8_t byte : bytes) { - oss << std::setw(2) << std::setfill('0') << std::hex << (int)byte; - } - return oss.str(); -} -} // namespace atom::algorithm diff --git a/src/atom/algorithm/sha1.hpp b/src/atom/algorithm/sha1.hpp deleted file mode 100644 index 5a45caba..00000000 --- a/src/atom/algorithm/sha1.hpp +++ /dev/null @@ -1,41 +0,0 @@ -#ifndef ATOM_ALGORITHM_SHA1_HPP -#define ATOM_ALGORITHM_SHA1_HPP - -#include -#include -#include - -namespace atom::algorithm { -class SHA1 { -public: - SHA1(); - - void update(const uint8_t* data, size_t length); - auto digest() -> std::array; - void reset(); - - static constexpr size_t DIGEST_SIZE = 20; - -private: - void processBlock(const uint8_t* block); - static auto rotateLeft(uint32_t value, size_t bits) -> uint32_t; - - static constexpr size_t BLOCK_SIZE = 64; - static constexpr size_t HASH_SIZE = 5; - static constexpr size_t SCHEDULE_SIZE = 80; - static constexpr size_t LENGTH_SIZE = 8; - static constexpr size_t BITS_PER_BYTE = 8; - static constexpr uint8_t PADDING_BYTE = 0x80; - static constexpr uint8_t BYTE_MASK = 0xFF; - static constexpr size_t WORD_SIZE = 32; - - std::array hash_; - std::array buffer_; - uint64_t bitCount_; -}; - -auto bytesToHex(const std::array& bytes) -> std::string; - -} // namespace atom::algorithm - -#endif // ATOM_ALGORITHM_SHA1_HPP diff --git a/src/atom/algorithm/snowflake.hpp b/src/atom/algorithm/snowflake.hpp deleted file mode 100644 index d0ed1368..00000000 --- a/src/atom/algorithm/snowflake.hpp +++ /dev/null @@ -1,199 +0,0 @@ -#ifndef ATOM_ALGORITHM_SNOWFLAKE_HPP -#define ATOM_ALGORITHM_SNOWFLAKE_HPP - -#include -#include -#include -#include -#include -#include -#include -#include - -namespace atom::algorithm { - -// Custom exception classes for clearer error handling -class SnowflakeException : public std::runtime_error { -public: - explicit SnowflakeException(const std::string &message) - : std::runtime_error(message) {} -}; - -class InvalidWorkerIdException : public SnowflakeException { -public: - InvalidWorkerIdException(uint64_t worker_id, uint64_t max) - : SnowflakeException("Worker ID " + std::to_string(worker_id) + - " exceeds maximum of " + std::to_string(max)) {} -}; - -class InvalidDatacenterIdException : public SnowflakeException { -public: - InvalidDatacenterIdException(uint64_t datacenter_id, uint64_t max) - : SnowflakeException("Datacenter ID " + std::to_string(datacenter_id) + - " exceeds maximum of " + std::to_string(max)) {} -}; - -class InvalidTimestampException : public SnowflakeException { -public: - InvalidTimestampException(uint64_t timestamp) - : SnowflakeException("Timestamp " + std::to_string(timestamp) + - " is invalid or out of range.") {} -}; - -class SnowflakeNonLock { -public: - void lock() {} - void unlock() {} -}; - -template -class Snowflake { - static_assert(std::is_same_v || - std::is_same_v, - "Lock must be SnowflakeNonLock or std::mutex"); - -public: - using lock_type = Lock; - static constexpr uint64_t TWEPOCH = Twepoch; - static constexpr uint64_t WORKER_ID_BITS = 5; - static constexpr uint64_t DATACENTER_ID_BITS = 5; - static constexpr uint64_t MAX_WORKER_ID = (1ULL << WORKER_ID_BITS) - 1; - static constexpr uint64_t MAX_DATACENTER_ID = - (1ULL << DATACENTER_ID_BITS) - 1; - static constexpr uint64_t SEQUENCE_BITS = 12; - static constexpr uint64_t WORKER_ID_SHIFT = SEQUENCE_BITS; - static constexpr uint64_t DATACENTER_ID_SHIFT = - SEQUENCE_BITS + WORKER_ID_BITS; - static constexpr uint64_t TIMESTAMP_LEFT_SHIFT = - SEQUENCE_BITS + WORKER_ID_BITS + DATACENTER_ID_BITS; - static constexpr uint64_t SEQUENCE_MASK = (1ULL << SEQUENCE_BITS) - 1; - - explicit Snowflake(uint64_t worker_id = 0, uint64_t datacenter_id = 0) - : workerid_(worker_id), datacenterid_(datacenter_id) { - initialize(); - } - - Snowflake(const Snowflake &) = delete; - auto operator=(const Snowflake &) -> Snowflake & = delete; - - void init(uint64_t worker_id, uint64_t datacenter_id) { - std::lock_guard lock(lock_); - if (worker_id > MAX_WORKER_ID) { - throw InvalidWorkerIdException(worker_id, MAX_WORKER_ID); - } - if (datacenter_id > MAX_DATACENTER_ID) { - throw InvalidDatacenterIdException(datacenter_id, - MAX_DATACENTER_ID); - } - workerid_ = worker_id; - datacenterid_ = datacenter_id; - } - - [[nodiscard]] auto nextid() -> uint64_t { - std::lock_guard lock(lock_); - uint64_t timestamp = current_millis(); - if (timestamp < last_timestamp_) { - throw InvalidTimestampException(timestamp); - } - - if (last_timestamp_ == timestamp) { - sequence_ = (sequence_ + 1) & SEQUENCE_MASK; - if (sequence_ == 0) { - timestamp = wait_next_millis(last_timestamp_); - if (timestamp < last_timestamp_) { - throw InvalidTimestampException(timestamp); - } - } - } else { - sequence_ = 0; - } - - last_timestamp_ = timestamp; - - uint64_t id = ((timestamp - TWEPOCH) << TIMESTAMP_LEFT_SHIFT) | - (datacenterid_ << DATACENTER_ID_SHIFT) | - (workerid_ << WORKER_ID_SHIFT) | sequence_; - - return id ^ secret_key_; - } - - void parseId(uint64_t encrypted_id, uint64_t ×tamp, - uint64_t &datacenter_id, uint64_t &worker_id, - uint64_t &sequence) const { - uint64_t id = encrypted_id ^ secret_key_; - - timestamp = (id >> TIMESTAMP_LEFT_SHIFT) + TWEPOCH; - datacenter_id = (id >> DATACENTER_ID_SHIFT) & MAX_DATACENTER_ID; - worker_id = (id >> WORKER_ID_SHIFT) & MAX_WORKER_ID; - sequence = id & SEQUENCE_MASK; - } - - // Additional functionality: Reset the Snowflake generator - void reset() { - std::lock_guard lock(lock_); - last_timestamp_ = 0; - sequence_ = 0; - } - - // Additional functionality: Retrieve current worker ID - [[nodiscard]] auto getWorkerId() const -> uint64_t { return workerid_; } - - // Additional functionality: Retrieve current datacenter ID - [[nodiscard]] auto getDatacenterId() const -> uint64_t { - return datacenterid_; - } - -private: - uint64_t workerid_ = 0; - uint64_t datacenterid_ = 0; - uint64_t sequence_ = 0; - mutable lock_type lock_; - uint64_t secret_key_; - - std::atomic last_timestamp_{0}; - std::chrono::steady_clock::time_point start_time_point_ = - std::chrono::steady_clock::now(); - uint64_t start_millisecond_ = get_system_millis(); - - void initialize() { - std::random_device rd; - std::mt19937_64 eng(rd()); - std::uniform_int_distribution distr; - secret_key_ = distr(eng); - - if (workerid_ > MAX_WORKER_ID) { - throw InvalidWorkerIdException(workerid_, MAX_WORKER_ID); - } - if (datacenterid_ > MAX_DATACENTER_ID) { - throw InvalidDatacenterIdException(datacenterid_, - MAX_DATACENTER_ID); - } - } - - [[nodiscard]] auto get_system_millis() const -> uint64_t { - return static_cast( - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count()); - } - - [[nodiscard]] auto current_millis() const -> uint64_t { - auto now = std::chrono::steady_clock::now(); - auto diff = std::chrono::duration_cast( - now - start_time_point_) - .count(); - return start_millisecond_ + static_cast(diff); - } - - [[nodiscard]] auto wait_next_millis(uint64_t last) const -> uint64_t { - uint64_t timestamp = current_millis(); - while (timestamp <= last) { - timestamp = current_millis(); - } - return timestamp; - } -}; - -} // namespace atom::algorithm - -#endif // ATOM_ALGORITHM_SNOWFLAKE_HPP \ No newline at end of file diff --git a/src/atom/algorithm/tea.cpp b/src/atom/algorithm/tea.cpp deleted file mode 100644 index 78253976..00000000 --- a/src/atom/algorithm/tea.cpp +++ /dev/null @@ -1,179 +0,0 @@ -#include "tea.hpp" - -namespace atom::algorithm { -// Constants for TEA -constexpr uint32_t DELTA = 0x9E3779B9; -constexpr int NUM_ROUNDS = 32; -constexpr int SHIFT_4 = 4; -constexpr int SHIFT_5 = 5; -constexpr int BYTE_SHIFT = 8; -constexpr size_t MIN_ROUNDS = 6; -constexpr size_t MAX_ROUNDS = 52; -constexpr int SHIFT_3 = 3; -constexpr int SHIFT_2 = 2; -constexpr uint32_t KEY_MASK = 3; -constexpr int SHIFT_11 = 11; - -// TEA encryption function -auto teaEncrypt(uint32_t &value0, uint32_t &value1, - const std::array &key) -> void { - uint32_t sum = 0; - for (int i = 0; i < NUM_ROUNDS; ++i) { - sum += DELTA; - value0 += ((value1 << SHIFT_4) + key[0]) ^ (value1 + sum) ^ - ((value1 >> SHIFT_5) + key[1]); - value1 += ((value0 << SHIFT_4) + key[2]) ^ (value0 + sum) ^ - ((value0 >> SHIFT_5) + key[3]); - } -} - -// TEA decryption function -auto teaDecrypt(uint32_t &value0, uint32_t &value1, - const std::array &key) -> void { - uint32_t sum = DELTA * NUM_ROUNDS; - for (int i = 0; i < NUM_ROUNDS; ++i) { - value1 -= ((value0 << SHIFT_4) + key[2]) ^ (value0 + sum) ^ - ((value0 >> SHIFT_5) + key[3]); - value0 -= ((value1 << SHIFT_4) + key[0]) ^ (value1 + sum) ^ - ((value1 >> SHIFT_5) + key[1]); - sum -= DELTA; - } -} - -// Helper function to convert a byte array to a vector of uint32_t -auto toUint32Vector(const std::vector &data) -> std::vector { - size_t numElements = (data.size() + 3) / 4; - std::vector result(numElements); - - for (size_t index = 0; index < data.size(); ++index) { - result[index / 4] |= static_cast(data[index]) - << ((index % 4) * BYTE_SHIFT); - } - - return result; -} - -// Helper function to convert a vector of uint32_t back to a byte array -auto toByteArray(const std::vector &data) -> std::vector { - std::vector result(data.size() * 4); - - for (size_t index = 0; index < data.size() * 4; ++index) { - result[index] = - static_cast(data[index / 4] >> ((index % 4) * BYTE_SHIFT)); - } - - return result; -} - -// XXTEA encrypt function -auto xxteaEncrypt(const std::vector &inputData, - const std::vector &inputKey) - -> std::vector { - size_t numElements = inputData.size(); - if (numElements < 2) { - return inputData; - } - - uint32_t sum = 0; - uint32_t lastElement = inputData[numElements - 1]; - uint32_t currentElement; - size_t numRounds = MIN_ROUNDS + MAX_ROUNDS / numElements; - - std::vector result = inputData; - - for (size_t roundIndex = 0; roundIndex < numRounds; ++roundIndex) { - sum += DELTA; - uint32_t keyIndex = (sum >> SHIFT_2) & KEY_MASK; - for (size_t elementIndex = 0; elementIndex < numElements - 1; - ++elementIndex) { - currentElement = result[elementIndex + 1]; - result[elementIndex] += - ((lastElement >> SHIFT_5) ^ (currentElement << SHIFT_2)) + - ((currentElement >> SHIFT_3) ^ (lastElement << SHIFT_4)) ^ - ((sum ^ currentElement) + - (inputKey[(elementIndex & KEY_MASK) ^ keyIndex] ^ - lastElement)); - lastElement = result[elementIndex]; - } - currentElement = result[0]; - result[numElements - 1] += - ((lastElement >> SHIFT_5) ^ (currentElement << SHIFT_2)) + - ((currentElement >> SHIFT_3) ^ (lastElement << SHIFT_4)) ^ - ((sum ^ currentElement) + - (inputKey[((numElements - 1) & KEY_MASK) ^ keyIndex] ^ - lastElement)); - lastElement = result[numElements - 1]; - } - - return result; -} - -// XXTEA decrypt function -auto xxteaDecrypt(const std::vector &inputData, - const std::vector &inputKey) - -> std::vector { - size_t numElements = inputData.size(); - if (numElements < 2) { - return inputData; - } - - uint32_t sum = (MIN_ROUNDS + MAX_ROUNDS / numElements) * DELTA; - uint32_t lastElement = inputData[numElements - 1]; - uint32_t currentElement; - - std::vector result = inputData; - - for (size_t roundIndex = 0; - roundIndex < MIN_ROUNDS + MAX_ROUNDS / numElements; ++roundIndex) { - uint32_t keyIndex = (sum >> SHIFT_2) & KEY_MASK; - for (size_t elementIndex = numElements - 1; elementIndex > 0; - --elementIndex) { - lastElement = result[elementIndex - 1]; - result[elementIndex] -= - ((lastElement >> SHIFT_5) ^ (currentElement << SHIFT_2)) + - ((currentElement >> SHIFT_3) ^ (lastElement << SHIFT_4)) ^ - ((sum ^ currentElement) + - (inputKey[(elementIndex & KEY_MASK) ^ keyIndex] ^ - lastElement)); - currentElement = result[elementIndex]; - } - lastElement = result[numElements - 1]; - result[0] -= - ((lastElement >> SHIFT_5) ^ (currentElement << SHIFT_2)) + - ((currentElement >> SHIFT_3) ^ (lastElement << SHIFT_4)) ^ - ((sum ^ currentElement) + - (inputKey[((numElements - 1) & KEY_MASK) ^ keyIndex] ^ - lastElement)); - sum -= DELTA; - currentElement = result[0]; - } - - return result; -} - -// XTEA encryption function -auto xteaEncrypt(uint32_t &value0, uint32_t &value1, - const XTEAKey &key) -> void { - uint32_t sum = 0; - for (int i = 0; i < NUM_ROUNDS; ++i) { - value0 += ((value1 << SHIFT_4) ^ (value1 >> SHIFT_5)) + value1 ^ - (sum + key[sum & KEY_MASK]); - sum += DELTA; - value1 += ((value0 << SHIFT_4) ^ (value0 >> SHIFT_5)) + value0 ^ - (sum + key[(sum >> SHIFT_11) & KEY_MASK]); - } -} - -// XTEA decryption function -auto xteaDecrypt(uint32_t &value0, uint32_t &value1, - const XTEAKey &key) -> void { - uint32_t sum = DELTA * NUM_ROUNDS; - for (int i = 0; i < NUM_ROUNDS; ++i) { - value1 -= ((value0 << SHIFT_4) ^ (value0 >> SHIFT_5)) + value0 ^ - (sum + key[(sum >> SHIFT_11) & KEY_MASK]); - sum -= DELTA; - value0 -= ((value1 << SHIFT_4) ^ (value1 >> SHIFT_5)) + value1 ^ - (sum + key[sum & KEY_MASK]); - } -} -} // namespace atom::algorithm diff --git a/src/atom/algorithm/tea.hpp b/src/atom/algorithm/tea.hpp deleted file mode 100644 index 6e6fae88..00000000 --- a/src/atom/algorithm/tea.hpp +++ /dev/null @@ -1,90 +0,0 @@ -#ifndef ATOM_ALGORITHM_TEA_HPP -#define ATOM_ALGORITHM_TEA_HPP - -#include -#include -#include - -namespace atom::algorithm { -using XTEAKey = std::array; - -/** - * @brief Encrypts two 32-bit values using the TEA algorithm. - * - * @param value0 The first 32-bit value to be encrypted. - * @param value1 The second 32-bit value to be encrypted. - * @param key The 128-bit key used for encryption. - */ -auto teaEncrypt(uint32_t &value0, uint32_t &value1, - const std::array &key) -> void; - -/** - * @brief Decrypts two 32-bit values using the TEA algorithm. - * - * @param value0 The first 32-bit value to be decrypted. - * @param value1 The second 32-bit value to be decrypted. - * @param key The 128-bit key used for decryption. - */ -auto teaDecrypt(uint32_t &value0, uint32_t &value1, - const std::array &key) -> void; - -/** - * @brief Encrypts a vector of 32-bit values using the XXTEA algorithm. - * - * @param inputData The vector of 32-bit values to be encrypted. - * @param inputKey The 128-bit key used for encryption. - * @return A vector of encrypted 32-bit values. - */ -auto xxteaEncrypt(const std::vector &inputData, - const std::vector &inputKey) - -> std::vector; - -/** - * @brief Decrypts a vector of 32-bit values using the XXTEA algorithm. - * - * @param inputData The vector of 32-bit values to be decrypted. - * @param inputKey The 128-bit key used for decryption. - * @return A vector of decrypted 32-bit values. - */ -auto xxteaDecrypt(const std::vector &inputData, - const std::vector &inputKey) - -> std::vector; - -/** - * @brief Encrypts two 32-bit values using the XTEA algorithm. - * - * @param value0 The first 32-bit value to be encrypted. - * @param value1 The second 32-bit value to be encrypted. - * @param key The 128-bit key used for encryption. - */ -auto xteaEncrypt(uint32_t &value0, uint32_t &value1, - const XTEAKey &key) -> void; - -/** - * @brief Decrypts two 32-bit values using the XTEA algorithm. - * - * @param value0 The first 32-bit value to be decrypted. - * @param value1 The second 32-bit value to be decrypted. - * @param key The 128-bit key used for decryption. - */ -auto xteaDecrypt(uint32_t &value0, uint32_t &value1, - const XTEAKey &key) -> void; - -/** - * @brief Converts a byte array to a vector of 32-bit unsigned integers. - * - * @param data The byte array to be converted. - * @return A vector of 32-bit unsigned integers. - */ -auto toUint32Vector(const std::vector &data) -> std::vector; - -/** - * @brief Converts a vector of 32-bit unsigned integers back to a byte array. - * - * @param data The vector of 32-bit unsigned integers to be converted. - * @return A byte array. - */ -auto toByteArray(const std::vector &data) -> std::vector; -} // namespace atom::algorithm - -#endif diff --git a/src/atom/algorithm/weight.hpp b/src/atom/algorithm/weight.hpp deleted file mode 100644 index 05c6f3de..00000000 --- a/src/atom/algorithm/weight.hpp +++ /dev/null @@ -1,255 +0,0 @@ -#ifndef ATOM_ALGORITHM_WEIGHT_HPP -#define ATOM_ALGORITHM_WEIGHT_HPP - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "atom/error/exception.hpp" -#include "atom/function/concept.hpp" -#include "atom/utils/random.hpp" - -namespace atom::algorithm { - -template -class WeightSelector { -public: - class SelectionStrategy { - public: - virtual ~SelectionStrategy() = default; - virtual auto select(std::span cumulative_weights, - T total_weight) -> size_t = 0; - }; - - class DefaultSelectionStrategy : public SelectionStrategy { - private: - utils::Random> random_; - - public: - DefaultSelectionStrategy() : random_(0.0, 1.0) {} - - auto select(std::span cumulative_weights, - T total_weight) -> size_t override { - T randomValue = random_() * total_weight; - auto it = std::ranges::upper_bound(cumulative_weights, randomValue); - return std::distance(cumulative_weights.begin(), it); - } - }; - - class BottomHeavySelectionStrategy : public SelectionStrategy { - private: - utils::Random> random_; - - public: - BottomHeavySelectionStrategy() : random_(0.0, 1.0) {} - - auto select(std::span cumulative_weights, - T total_weight) -> size_t override { - T randomValue = std::sqrt(random_()) * total_weight; - auto it = std::ranges::upper_bound(cumulative_weights, randomValue); - return std::distance(cumulative_weights.begin(), it); - } - }; - - class RandomSelectionStrategy : public SelectionStrategy { - private: - utils::Random> - random_index_; - - public: - explicit RandomSelectionStrategy(size_t max_index) - : random_index_(0, max_index - 1) {} - - auto select(std::span /*cumulative_weights*/, - T /*total_weight*/) -> size_t override { - return random_index_(); - } - }; - - class WeightedRandomSampler { - public: - auto sample(std::span weights, - size_t n) -> std::vector { - std::vector indices(weights.size()); - std::iota(indices.begin(), indices.end(), 0); - - utils::Random> random( - weights); - std::vector results(n); - std::generate(results.begin(), results.end(), - [&]() { return random(); }); - - return results; - } - }; - -private: - std::vector weights_; - std::vector cumulative_weights_; - std::unique_ptr strategy_; - - void updateCumulativeWeights() { - cumulative_weights_.resize(weights_.size()); - std::exclusive_scan(weights_.begin(), weights_.end(), - cumulative_weights_.begin(), T{0}); - } - -public: - explicit WeightSelector(std::span input_weights, - std::unique_ptr custom_strategy = - std::make_unique()) - : weights_(input_weights.begin(), input_weights.end()), - strategy_(std::move(custom_strategy)) { - updateCumulativeWeights(); - } - - void setSelectionStrategy(std::unique_ptr new_strategy) { - strategy_ = std::move(new_strategy); - } - - auto select() -> size_t { - T totalWeight = std::reduce(weights_.begin(), weights_.end()); - if (totalWeight <= T{0}) { - THROW_RUNTIME_ERROR("Total weight must be greater than zero."); - } - return strategy_->select(cumulative_weights_, totalWeight); - } - - auto selectMultiple(size_t n) -> std::vector { - std::vector results; - results.reserve(n); - for (size_t i = 0; i < n; ++i) { - results.push_back(select()); - } - return results; - } - - void updateWeight(size_t index, T new_weight) { - if (index >= weights_.size()) { - throw std::out_of_range("Index out of range"); - } - weights_[index] = new_weight; - updateCumulativeWeights(); - } - - void addWeight(T new_weight) { - weights_.push_back(new_weight); - updateCumulativeWeights(); - } - - void removeWeight(size_t index) { - if (index >= weights_.size()) { - throw std::out_of_range("Index out of range"); - } - weights_.erase(weights_.begin() + index); - updateCumulativeWeights(); - } - - void normalizeWeights() { - T sum = std::reduce(weights_.begin(), weights_.end()); - if (sum > T{0}) { - std::ranges::transform(weights_, weights_.begin(), - [sum](T w) { return w / sum; }); - updateCumulativeWeights(); - } - } - - void applyFunctionToWeights(std::invocable auto&& func) { - std::ranges::transform(weights_, weights_.begin(), - std::forward(func)); - updateCumulativeWeights(); - } - - void batchUpdateWeights(const std::vector>& updates) { - for (const auto& [index, new_weight] : updates) { - if (index >= weights_.size()) { - throw std::out_of_range("Index out of range"); - } - weights_[index] = new_weight; - } - updateCumulativeWeights(); - } - - [[nodiscard]] auto getWeight(size_t index) const -> std::optional { - if (index >= weights_.size()) { - return std::nullopt; - } - return weights_[index]; - } - - [[nodiscard]] auto getMaxWeightIndex() const -> size_t { - return std::distance(weights_.begin(), - std::ranges::max_element(weights_)); - } - - [[nodiscard]] auto getMinWeightIndex() const -> size_t { - return std::distance(weights_.begin(), - std::ranges::min_element(weights_)); - } - - [[nodiscard]] auto size() const -> size_t { return weights_.size(); } - - [[nodiscard]] auto getWeights() const -> std::span { - return weights_; - } - - [[nodiscard]] auto getTotalWeight() const -> T { - return std::reduce(weights_.begin(), weights_.end()); - } - - void resetWeights(const std::vector& new_weights) { - weights_ = new_weights; - updateCumulativeWeights(); - } - - void scaleWeights(T factor) { - std::ranges::transform(weights_, weights_.begin(), - [factor](T w) { return w * factor; }); - updateCumulativeWeights(); - } - - [[nodiscard]] auto getAverageWeight() const -> T { - if (weights_.empty()) { - THROW_RUNTIME_ERROR("No weights available to calculate average."); - } - return getTotalWeight() / static_cast(weights_.size()); - } - - void printWeights(std::ostream& oss) const { - if (weights_.empty()) { - oss << "[]\n"; - return; - } - oss << std::format("[{:.2f}", weights_.front()); - for (auto it = weights_.begin() + 1; it != weights_.end(); ++it) { - oss << std::format(", {:.2f}", *it); - } - oss << "]\n"; - } -}; - -template -class TopHeavySelectionStrategy : public WeightSelector::SelectionStrategy { -private: - utils::Random> random_; - -public: - TopHeavySelectionStrategy() : random_(0.0, 1.0) {} - - auto select(std::span cumulative_weights, - T total_weight) -> size_t override { - T randomValue = std::pow(random_(), 2) * total_weight; - auto it = std::ranges::upper_bound(cumulative_weights, randomValue); - return std::distance(cumulative_weights.begin(), it); - } -}; - -} // namespace atom::algorithm - -#endif \ No newline at end of file diff --git a/src/atom/algorithm/xmake.lua b/src/atom/algorithm/xmake.lua deleted file mode 100644 index 45f56012..00000000 --- a/src/atom/algorithm/xmake.lua +++ /dev/null @@ -1,21 +0,0 @@ --- xmake.lua for Atom-Algorithm --- This project is licensed under the terms of the GPL3 license. --- --- Project Name: Atom-Algorithm --- Description: A collection of algorithms --- Author: Max Qian --- License: GPL3 - -package("foo") - add_deps("cmake") - set_sourcedir(path.join(os.scriptdir(), "foo")) - on_install(function (package) - local configs = {} - table.insert(configs, "-DCMAKE_BUILD_TYPE=" .. (package:debug() and "Debug" or "Release")) - table.insert(configs, "-DBUILD_SHARED_LIBS=" .. (package:config("shared") and "ON" or "OFF")) - import("package.tools.cmake").install(package, configs) - end) - on_test(function (package) - assert(package:has_cfuncs("add", {includes = "foo.h"})) - end) -package_end() diff --git a/src/atom/async/CMakeLists.txt b/src/atom/async/CMakeLists.txt deleted file mode 100644 index 58a470f8..00000000 --- a/src/atom/async/CMakeLists.txt +++ /dev/null @@ -1,73 +0,0 @@ -# CMakeLists.txt for Atom-Async -# This project is licensed under the terms of the GPL3 license. -# -# Project Name: Atom-Async -# Description: Async Implementation of Lithium Server and Driver -# Author: Max Qian -# License: GPL3 - -cmake_minimum_required(VERSION 3.20) -project(atom-async C CXX) - -# Sources -set(${PROJECT_NAME}_SOURCES - daemon.cpp - limiter.cpp - lock.cpp - timer.cpp -) - -# Headers -set(${PROJECT_NAME}_HEADERS - async.hpp - daemon.hpp - eventstack.hpp - limiter.hpp - lock.hpp - message_bus.hpp - message_queue.hpp - pool.hpp - queue.hpp - safetype.hpp - thread_wrapper.hpp - timer.hpp - trigger.hpp -) - -set(${PROJECT_NAME}_LIBS - loguru - ${CMAKE_THREAD_LIBS_INIT} -) - -# Build Object Library -add_library(${PROJECT_NAME}_OBJECT OBJECT) -set_property(TARGET ${PROJECT_NAME}_OBJECT PROPERTY POSITION_INDEPENDENT_CODE 1) - -target_sources(${PROJECT_NAME}_OBJECT - PUBLIC - ${${PROJECT_NAME}_HEADERS} - PRIVATE - ${${PROJECT_NAME}_SOURCES} -) - -target_link_libraries(${PROJECT_NAME}_OBJECT ${${PROJECT_NAME}_LIBS}) - -add_library(${PROJECT_NAME} STATIC) - -target_link_libraries(${PROJECT_NAME} ${PROJECT_NAME}_OBJECT ${${PROJECT_NAME}_LIBS}) -target_include_directories(${PROJECT_NAME} PUBLIC .) - -set_target_properties(${PROJECT_NAME} PROPERTIES - VERSION ${CMAKE_HYDROGEN_VERSION_STRING} - SOVERSION ${HYDROGEN_SOVERSION} - OUTPUT_NAME ${PROJECT_NAME} -) - -install(TARGETS ${PROJECT_NAME} - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} -) - -if (ATOM_BUILD_PYTHON) -pybind11_add_module(${PROJECT_NAME}-py _pybind.cpp) -target_link_libraries(${PROJECT_NAME}-py PRIVATE ${PROJECT_NAME}) -endif() diff --git a/src/atom/async/async.hpp b/src/atom/async/async.hpp deleted file mode 100644 index c63699ab..00000000 --- a/src/atom/async/async.hpp +++ /dev/null @@ -1,464 +0,0 @@ -/* - * async.hpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2023-11-10 - -Description: A simple but useful async worker manager - -**************************************************/ - -#ifndef ATOM_ASYNC_ASYNC_HPP -#define ATOM_ASYNC_ASYNC_HPP - -#include -#include -#include -#include -#include -#include - -#include "atom/async/future.hpp" -#include "atom/error/exception.hpp" - -class TimeoutException : public atom::error::RuntimeError { -public: - using atom::error::RuntimeError::RuntimeError; -}; - -#define THROW_TIMEOUT_EXCEPTION(...) \ - throw TimeoutException(ATOM_FILE_NAME, ATOM_FILE_LINE, ATOM_FUNC_NAME, \ - __VA_ARGS__); - -namespace atom::async { -/** - * @brief Class for performing asynchronous tasks. - * - * This class allows you to start a task asynchronously and get the result when - * it's done. It also provides functionality to cancel the task, check if it's - * done or active, validate the result, set a callback function, and set a - * timeout. - * - * @tparam ResultType The type of the result returned by the task. - */ -template -class AsyncWorker { -public: - /** - * @brief Starts the task asynchronously. - * - * @tparam Func The type of the function to be executed asynchronously. - * @tparam Args The types of the arguments to be passed to the function. - * @param func The function to be executed asynchronously. - * @param args The arguments to be passed to the function. - */ - template - void startAsync(Func &&func, Args &&...args); - - /** - * @brief Gets the result of the task. - * - * @throw std::runtime_error if the task is not valid. - * @return The result of the task. - */ - auto getResult() -> ResultType; - - /** - * @brief Cancels the task. - * - * If the task is valid, this function waits for the task to complete. - */ - void cancel(); - - /** - * @brief Checks if the task is done. - * - * @return True if the task is done, false otherwise. - */ - [[nodiscard]] auto isDone() const -> bool; - - /** - * @brief Checks if the task is active. - * - * @return True if the task is active, false otherwise. - */ - [[nodiscard]] auto isActive() const -> bool; - - /** - * @brief Validates the result of the task using a validator function. - * - * @param validator The function used to validate the result. - * @return True if the result is valid, false otherwise. - */ - auto validate(std::function validator) -> bool; - - /** - * @brief Sets a callback function to be called when the task is done. - * - * @param callback The callback function to be set. - */ - void setCallback(std::function callback); - - /** - * @brief Sets a timeout for the task. - * - * @param timeout The timeout duration. - */ - void setTimeout(std::chrono::seconds timeout); - - /** - * @brief Waits for the task to complete. - * - * If a timeout is set, this function waits until the task is done or the - * timeout is reached. If a callback function is set and the task is done, - * the callback function is called with the result. - */ - void waitForCompletion(); - -private: - std::future - task_; ///< The future representing the asynchronous task. - std::function - callback_; ///< The callback function to be called when the task is - ///< done. - std::chrono::seconds timeout_{0}; ///< The timeout duration for the task. -}; - -/** - * @brief Class for managing multiple AsyncWorker instances. - * - * This class provides functionality to create and manage multiple AsyncWorker - * instances. - * - * @tparam ResultType The type of the result returned by the tasks managed by - * this class. - */ -template -class AsyncWorkerManager { -public: - /** - * @brief Default constructor. - */ - AsyncWorkerManager() = default; - - /** - * @brief Creates a new AsyncWorker instance and starts the task - * asynchronously. - * - * @tparam Func The type of the function to be executed asynchronously. - * @tparam Args The types of the arguments to be passed to the function. - * @param func The function to be executed asynchronously. - * @param args The arguments to be passed to the function. - * @return A shared pointer to the created AsyncWorker instance. - */ - template - auto createWorker(Func &&func, Args &&...args) - -> std::shared_ptr>; - - /** - * @brief Cancels all the managed tasks. - */ - void cancelAll(); - - /** - * @brief Checks if all the managed tasks are done. - * - * @return True if all tasks are done, false otherwise. - */ - auto allDone() const -> bool; - - /** - * @brief Waits for all the managed tasks to complete. - */ - void waitForAll(); - - /** - * @brief Checks if a specific task is done. - * - * @param worker The AsyncWorker instance to check. - * @return True if the task is done, false otherwise. - */ - bool isDone(std::shared_ptr> worker) const; - - /** - * @brief Cancels a specific task. - * - * @param worker The AsyncWorker instance to cancel. - */ - void cancel(std::shared_ptr> worker); - -private: - std::vector>> - workers_; ///< The list of managed AsyncWorker instances. -}; - -/** - * @brief Gets the result of the task with a timeout. - * - * @param future The future representing the asynchronous task. - * @param timeout The timeout duration. - * @return The result of the task. - */ -template -auto getWithTimeout(std::future &future, - std::chrono::milliseconds timeout) -> ReturnType; - -template -template -void AsyncWorker::startAsync(Func &&func, Args &&...args) { - static_assert(std::is_invocable_r_v, - "Function must return a result"); - task_ = std::async(std::launch::async, std::forward(func), - std::forward(args)...); -} - -template -[[nodiscard]] auto AsyncWorker::getResult() -> ResultType { - if (!task_.valid()) { - throw std::invalid_argument("Task is not valid"); - } - return task_.get(); -} - -template -void AsyncWorker::cancel() { - if (task_.valid()) { - task_.wait(); // 等待任务完成 - } -} - -template -auto AsyncWorker::isDone() const -> bool { - return task_.valid() && (task_.wait_for(std::chrono::seconds(0)) == - std::future_status::ready); -} - -template -auto AsyncWorker::isActive() const -> bool { - return task_.valid() && (task_.wait_for(std::chrono::seconds(0)) == - std::future_status::timeout); -} - -template -auto AsyncWorker::validate( - std::function validator) -> bool { - if (!isDone()) { - } - ResultType result = getResult(); - return validator(result); -} - -template -void AsyncWorker::setCallback( - std::function callback) { - callback_ = callback; -} - -template -void AsyncWorker::setTimeout(std::chrono::seconds timeout) { - timeout_ = timeout; -} - -template -void AsyncWorker::waitForCompletion() { - if (timeout_ != std::chrono::seconds(0)) { - auto startTime = std::chrono::steady_clock::now(); - while (!isDone()) { - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - if (std::chrono::steady_clock::now() - startTime > timeout_) { - cancel(); - break; - } - } - } else { - while (!isDone()) { - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - } - } - - if (callback_ && isDone()) { - callback_(getResult()); - } -} - -template -template -[[nodiscard]] auto AsyncWorkerManager::createWorker( - Func &&func, Args &&...args) -> std::shared_ptr> { - auto worker = std::make_shared>(); - workers_.push_back(worker); - worker->StartAsync(std::forward(func), std::forward(args)...); - return worker; -} - -template -void AsyncWorkerManager::cancelAll() { - for (auto &worker : workers_) { - worker->Cancel(); - } -} - -template -auto AsyncWorkerManager::allDone() const -> bool { - return std::all_of(workers_.begin(), workers_.end(), - [](const auto &worker) { return worker->IsDone(); }); -} - -template -void AsyncWorkerManager::waitForAll() { - while (!allDone()) { - } -} - -template -auto AsyncWorkerManager::isDone( - std::shared_ptr> worker) const -> bool { - return worker->IsDone(); -} - -template -void AsyncWorkerManager::cancel( - std::shared_ptr> worker) { - worker->Cancel(); -} - -template -using EnableIfNotVoid = typename std::enable_if_t, T>; - -// Retry strategy enum for different backoff strategies -enum class BackoffStrategy { FIXED, LINEAR, EXPONENTIAL }; - -/** - * @brief Async execution with retry. - * - * @tparam Func The type of the function to be executed asynchronously. - * @tparam Args The types of the arguments to be passed to the function. - * @param func The function to be executed asynchronously. - * @param args The arguments to be passed to the function. - * @return A shared pointer to the created AsyncWorker instance. - */ -template -auto asyncRetryImpl(Func &&func, int attemptsLeft, - std::chrono::milliseconds initialDelay, - BackoffStrategy strategy, - std::chrono::milliseconds maxTotalDelay, - Callback &&callback, ExceptionHandler &&exceptionHandler, - CompleteHandler &&completeHandler, Args &&...args) -> - typename std::invoke_result_t { - using ReturnType = typename std::invoke_result_t; - - auto attempt = std::async(std::launch::async, std::forward(func), - std::forward(args)...); - - try { - if constexpr (std::is_same_v) { - attempt.get(); - callback(); - completeHandler(); - return; - } else { - auto result = attempt.get(); - callback(); - completeHandler(); - return result; - } - } catch (const std::exception &e) { - exceptionHandler(e); // Call custom exception handler - - if (attemptsLeft <= 1 || maxTotalDelay.count() <= 0) { - completeHandler(); // Invoke complete handler on final failure - throw; - } - - switch (strategy) { - case BackoffStrategy::LINEAR: - initialDelay *= 2; - break; - case BackoffStrategy::EXPONENTIAL: - initialDelay = std::chrono::milliseconds(static_cast( - initialDelay.count() * std::pow(2, (5 - attemptsLeft)))); - break; - default: - break; - } - - std::this_thread::sleep_for(initialDelay); - - // Decrease the maximum total delay by the time spent in the last - // attempt - maxTotalDelay -= initialDelay; - - return asyncRetryImpl(std::forward(func), attemptsLeft - 1, - initialDelay, strategy, maxTotalDelay, - std::forward(callback), - std::forward(exceptionHandler), - std::forward(completeHandler), - std::forward(args)...); - } -} - -template -auto asyncRetry(Func &&func, int attemptsLeft, - std::chrono::milliseconds initialDelay, - BackoffStrategy strategy, - std::chrono::milliseconds maxTotalDelay, Callback &&callback, - ExceptionHandler &&exceptionHandler, - CompleteHandler &&completeHandler, Args &&...args) - -> std::future> { - - return std::async(std::launch::async, [=]() mutable { - return asyncRetryImpl(std::forward(func), attemptsLeft, - initialDelay, strategy, maxTotalDelay, - std::forward(callback), - std::forward(exceptionHandler), - std::forward(completeHandler), - std::forward(args)...); - }); -} - -template -auto asyncRetryE(Func &&func, int attemptsLeft, - std::chrono::milliseconds initialDelay, - BackoffStrategy strategy, - std::chrono::milliseconds maxTotalDelay, Callback &&callback, - ExceptionHandler &&exceptionHandler, - CompleteHandler &&completeHandler, Args &&...args) - -> EnhancedFuture> { - using ReturnType = typename std::invoke_result_t; - - auto future = - std::async(std::launch::async, [=]() mutable { - return asyncRetryImpl( - std::forward(func), attemptsLeft, initialDelay, strategy, - maxTotalDelay, std::forward(callback), - std::forward(exceptionHandler), - std::forward(completeHandler), - std::forward(args)...); - }).share(); - - if constexpr (std::is_same_v) { - return EnhancedFuture(std::shared_future(future)); - } else { - return EnhancedFuture( - std::shared_future(future)); - } -} - -// getWithTimeout function for C++17 -template -auto getWithTimeout(std::future &future, - Duration timeout) -> EnableIfNotVoid { - if (future.wait_for(timeout) == std::future_status::ready) { - return future.get(); - } - THROW_TIMEOUT_EXCEPTION("Timeout occurred while waiting for future result"); -} -} // namespace atom::async -#endif diff --git a/src/atom/async/daemon.cpp b/src/atom/async/daemon.cpp deleted file mode 100644 index 3ca3ef74..00000000 --- a/src/atom/async/daemon.cpp +++ /dev/null @@ -1,213 +0,0 @@ -/* - * daemon.cpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2023-11-11 - -Description: Daemon process implementation for Linux and Windows. But there is -still some problems on Windows, especially the console. - -**************************************************/ - -#include "daemon.hpp" - -#include -#include -#include -#include -#include "atom/macro.hpp" - -#ifndef _WIN32 -#include -#endif - -#include "atom/log/loguru.hpp" -#include "atom/utils/time.hpp" - -constexpr int kDaemonRestartInterval = 10; -const std::string kPidFilePath = "lithium-daemon"; - -bool gIsDaemon = false; - -namespace atom::async { -auto DaemonGuard::toString() const -> std::string { - std::stringstream stringStream; - stringStream << "[DaemonGuard parentId=" << m_parentId - << " mainId=" << m_mainId << " parentStartTime=" - << utils::timeStampToString(m_parentStartTime) - << " mainStartTime=" - << utils::timeStampToString(m_mainStartTime) - << " restartCount=" << m_restartCount.load() << "]"; - return stringStream.str(); -} - -auto DaemonGuard::realStart(int /*argc*/, char **argv, - const std::function &mainCb) - -> int { -#ifdef _WIN32 - m_mainId = reinterpret_cast(static_cast(getpid())); -#else - m_mainId = getpid(); -#endif - m_mainStartTime = time(nullptr); - return mainCb(0, argv); -} - -auto DaemonGuard::realDaemon(int /*argc*/, char **argv, - const std::function &mainCb) - -> int { -#ifdef _WIN32 - // 在 Windows 平台下模拟守护进程 - FreeConsole(); - m_parentId = - reinterpret_cast(static_cast(GetCurrentProcessId())); - m_parentStartTime = time(nullptr); - while (true) { - PROCESS_INFORMATION processInfo; - STARTUPINFO startupInfo; - memset(&processInfo, 0, sizeof(processInfo)); - memset(&startupInfo, 0, sizeof(startupInfo)); - startupInfo.cb = sizeof(startupInfo); - if (!CreateProcess(nullptr, argv[0], nullptr, nullptr, FALSE, - CREATE_NEW_CONSOLE, nullptr, nullptr, &startupInfo, - &processInfo)) { - LOG_F(ERROR, "Create process failed with error code {}", - GetLastError()); - return -1; - } - WaitForSingleObject(processInfo.hProcess, INFINITE); - CloseHandle(processInfo.hProcess); - CloseHandle(processInfo.hThread); - - // 等待一段时间后重新启动子进程 - m_restartCount++; - Sleep(kDaemonRestartInterval * 1000); - } -#else - if (daemon(1, 0) == -1) { - perror("daemon"); - exit(EXIT_FAILURE); - } - - m_parentId = getpid(); - m_parentStartTime = time(nullptr); - while (true) { - pid_t pid = fork(); // 创建子进程 - if (pid == 0) { // 子进程 - m_mainId = getpid(); - m_mainStartTime = time(nullptr); - LOG_F(INFO, "daemon process start pid={}", - reinterpret_cast(getpid())); - return realStart(0, argv, mainCb); - } - if (pid < 0) { // 创建子进程失败 - LOG_F(ERROR, "fork fail return={} errno={} errstr={}", pid, errno, - strerror(errno)); - return -1; - } // 父进程 - int status = 0; - waitpid(pid, &status, 0); // 等待子进程退出 - - // 子进程异常退出 - if (status != 0) { - if (status == 9) { // SIGKILL 信号杀死子进程,不需要重新启动 - LOG_F(INFO, "daemon process killed pid={}", getpid()); - break; - } // 记录日志并重新启动子进程 - LOG_F(ERROR, "child crash pid={} status={}", pid, status); - - } else { // 正常退出,直接退出程序 - LOG_F(INFO, "daemon process exit pid={}", getpid()); - break; - } - - // 等待一段时间后重新启动子进程 - m_restartCount++; - sleep(kDaemonRestartInterval); - } -#endif - return 0; -} - -// 启动进程,如果需要创建守护进程,则先创建守护进程 -auto DaemonGuard::startDaemon(int argc, char **argv, - const std::function &mainCb, - bool isDaemon) -> int { -#ifdef _WIN32 - if (isDaemon) { - AllocConsole(); - if (!freopen("CONOUT$", "w", stdout)) { - LOG_F(ERROR, "Failed to redirect stdout"); - return -1; - } - if (!freopen("CONOUT$", "w", stderr)) { - LOG_F(ERROR, "Failed to redirect stderr"); - return -1; - } - } -#endif - - if (!isDaemon) { // 不需要创建守护进程 -#ifdef _WIN32 - m_parentId = reinterpret_cast(static_cast(getpid())); -#else - m_parentId = getpid(); -#endif - m_parentStartTime = time(nullptr); - return realStart(argc, argv, mainCb); - } - // 创建守护进程 - return realDaemon(argc, argv, mainCb); -} - -void signalHandler(int signum) { -#ifdef _WIN32 - if (signum == SIGTERM || signum == SIGINT) { - if (remove(kPidFilePath.c_str()) != 0) { - LOG_F(ERROR, "Failed to remove PID file"); - } - exit(0); - } -#else - if (signum == SIGTERM || signum == SIGINT) { - ATOM_UNREF_PARAM(remove(kPidFilePath.c_str())); - exit(0); - } -#endif -} - -void writePidFile() { - std::ofstream ofs(kPidFilePath); - if (!ofs) { - LOG_F(ERROR, "open pid file {} failed", kPidFilePath); - exit(-1); - } - ofs << getpid(); - ofs.close(); -} - -// 检查 PID 文件是否存在,并检查文件中的 PID 是否有效 -auto checkPidFile() -> bool { -#ifdef _WIN32 - // Windows 平台下不检查 PID 文件是否存在以及文件中的 PID 是否有效 - return false; -#else - struct stat st {}; - if (stat(kPidFilePath.c_str(), &st) != 0) { - return false; - } - std::ifstream ifs(kPidFilePath); - if (!ifs) { - return false; - } - pid_t pid = -1; - ifs >> pid; - ifs.close(); - return kill(pid, 0) != -1 || errno != ESRCH; -#endif -} -} // namespace atom::async diff --git a/src/atom/async/daemon.hpp b/src/atom/async/daemon.hpp deleted file mode 100644 index fda4c2f8..00000000 --- a/src/atom/async/daemon.hpp +++ /dev/null @@ -1,122 +0,0 @@ -/* - * daemon.hpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2023-11-11 - -Description: Daemon process implementation - -**************************************************/ - -#ifndef ATOM_SERVER_DAEMON_HPP -#define ATOM_SERVER_DAEMON_HPP - -#include -#include -#include -#include -#include - -#ifdef _WIN32 -#include -#else -#include -#include -#include -#endif - -namespace atom::async { -// Class for managing process information -class DaemonGuard { -public: - /** - * @brief Default constructor. - */ - DaemonGuard() = default; - - /** - * @brief Converts process information to a string. - * - * @return The process information as a string. - */ - [[nodiscard]] auto toString() const -> std::string; - - /** - * @brief Starts a child process to execute the actual task. - * - * @param argc The number of command line arguments. - * @param argv An array of command line arguments. - * @param mainCb The main callback function to be executed in the child - * process. - * @return The return value of the main callback function. - */ - auto realStart(int argc, char **argv, - const std::function &mainCb) - -> int; - - /** - * @brief Starts a child process to execute the actual task. - * - * @param argc The number of command line arguments. - * @param argv An array of command line arguments. - * @param mainCb The main callback function to be executed in the child - * process. - * @return The return value of the main callback function. - */ - auto realDaemon(int argc, char **argv, - const std::function &mainCb) - -> int; - - /** - * @brief Starts the process. If a daemon process needs to be created, it - * will create the daemon process first. - * - * @param argc The number of command line arguments. - * @param argv An array of command line arguments. - * @param mainCb The main callback function to be executed. - * @param isDaemon Determines if a daemon process should be created. - * @return The return value of the main callback function. - */ - auto startDaemon(int argc, char **argv, - const std::function &mainCb, - bool isDaemon) -> int; - -private: -#ifdef _WIN32 - HANDLE m_parentId = 0; - HANDLE m_mainId = 0; -#else - pid_t m_parentId = 0; /**< The parent process ID. */ - pid_t m_mainId = 0; /**< The child process ID. */ -#endif - time_t m_parentStartTime = 0; /**< The start time of the parent process. */ - time_t m_mainStartTime = 0; /**< The start time of the child process. */ - std::atomic m_restartCount{0}; /**< The number of restarts. */ -}; - -/** - * @brief Signal handler function. - * - * @param signum The signal number. - */ -void signalHandler(int signum); - -/** - * @brief Writes the process ID to a file. - */ -void writePidFile(); - -/** - * @brief Checks if the process ID file exists. - * - * @return True if the process ID file exists, false otherwise. - */ -auto checkPidFile() -> bool; - -} // namespace atom::async - -#endif diff --git a/src/atom/async/eventstack.hpp b/src/atom/async/eventstack.hpp deleted file mode 100644 index cd07bd78..00000000 --- a/src/atom/async/eventstack.hpp +++ /dev/null @@ -1,380 +0,0 @@ -/* - * eventstack.hpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2024-3-26 - -Description: A thread-safe stack data structure for managing events. - -**************************************************/ - -#ifndef ATOM_ASYNC_EVENTSTACK_HPP -#define ATOM_ASYNC_EVENTSTACK_HPP - -#include -#include -#include -#include -#include -#include -#include -#include - -namespace atom::async { -/** - * @brief A thread-safe stack data structure for managing events. - * - * @tparam T The type of events to store. - */ -template -class EventStack { -public: - EventStack() = default; - ~EventStack() = default; - - // Rule of five: explicitly define copy constructor, copy assignment - // operator, move constructor, and move assignment operator. - EventStack(const EventStack& other); - EventStack& operator=(const EventStack& other); - EventStack(EventStack&& other) noexcept; - EventStack& operator=(EventStack&& other) noexcept; - - /** - * @brief Pushes an event onto the stack. - * - * @param event The event to push. - */ - void pushEvent(T event); - - /** - * @brief Pops an event from the stack. - * - * @return The popped event, or std::nullopt if the stack is empty. - */ - auto popEvent() -> std::optional; - -#if ENABLE_DEBUG - /** - * @brief Prints all events in the stack. - */ - void printEvents() const; -#endif - - /** - * @brief Checks if the stack is empty. - * - * @return true if the stack is empty, false otherwise. - */ - auto isEmpty() const -> bool; - - /** - * @brief Returns the number of events in the stack. - * - * @return The number of events. - */ - auto size() const -> size_t; - - /** - * @brief Clears all events from the stack. - */ - void clearEvents(); - - /** - * @brief Returns the top event in the stack without removing it. - * - * @return The top event, or std::nullopt if the stack is empty. - */ - auto peekTopEvent() const -> std::optional; - - /** - * @brief Copies the current stack. - * - * @return A copy of the stack. - */ - auto copyStack() const -> EventStack; - - /** - * @brief Filters events based on a custom filter function. - * - * @param filterFunc The filter function. - */ - void filterEvents(std::function filterFunc); - - /** - * @brief Serializes the stack into a string. - * - * @return The serialized stack. - */ - auto serializeStack() const -> std::string; - - /** - * @brief Deserializes a string into the stack. - * - * @param serializedData The serialized stack data. - */ - void deserializeStack(std::string_view serializedData); - - /** - * @brief Removes duplicate events from the stack. - */ - void removeDuplicates(); - - /** - * @brief Sorts the events in the stack based on a custom comparison - * function. - * - * @param compareFunc The comparison function. - */ - void sortEvents(std::function compareFunc); - - /** - * @brief Reverses the order of events in the stack. - */ - void reverseEvents(); - - /** - * @brief Counts the number of events that satisfy a predicate. - * - * @param predicate The predicate function. - * @return The count of events satisfying the predicate. - */ - auto countEvents(std::function predicate) const -> size_t; - - /** - * @brief Finds the first event that satisfies a predicate. - * * - * @param predicate The predicate function. - * @return The first event satisfying the predicate, or std::nullopt if not - * found. - */ - auto findEvent(std::function predicate) const - -> std::optional; - - /** - * @brief Checks if any event in the stack satisfies a predicate. - * - * @param predicate The predicate function. - * @return true if any event satisfies the predicate, false otherwise. - */ - auto anyEvent(std::function predicate) const -> bool; - - /** - * @brief Checks if all events in the stack satisfy a predicate. - * - * @param predicate The predicate function. - * @return true if all events satisfy the predicate, false otherwise. - */ - auto allEvents(std::function predicate) const -> bool; - -private: - std::vector events_; /**< Vector to store events. */ - mutable std::shared_mutex mtx_; /**< Mutex for thread safety. */ - std::atomic eventCount_{0}; /**< Atomic counter for event count. */ -}; - -// Copy constructor -template -EventStack::EventStack(const EventStack& other) { - std::shared_lock lock(other.mtx_); - events_ = other.events_; - eventCount_.store(other.eventCount_.load()); -} - -// Copy assignment operator -template -EventStack& EventStack::operator=(const EventStack& other) { - if (this != &other) { - std::unique_lock lock1(mtx_, std::defer_lock); - std::shared_lock lock2(other.mtx_, std::defer_lock); - std::lock(lock1, lock2); - events_ = other.events_; - eventCount_.store(other.eventCount_.load()); - } - return *this; -} - -// Move constructor -template -EventStack::EventStack(EventStack&& other) noexcept { - std::unique_lock lock(other.mtx_); - events_ = std::move(other.events_); - eventCount_.store(other.eventCount_.load()); - other.eventCount_.store(0); -} - -// Move assignment operator -template -EventStack& EventStack::operator=(EventStack&& other) noexcept { - if (this != &other) { - std::unique_lock lock1(mtx_, std::defer_lock); - std::unique_lock lock2(other.mtx_, std::defer_lock); - std::lock(lock1, lock2); - events_ = std::move(other.events_); - eventCount_.store(other.eventCount_.load()); - other.eventCount_.store(0); - } - return *this; -} - -template -void EventStack::pushEvent(T event) { - std::unique_lock lock(mtx_); - events_.push_back(std::move(event)); - ++eventCount_; -} - -template -auto EventStack::popEvent() -> std::optional { - std::unique_lock lock(mtx_); - if (!events_.empty()) { - T event = std::move(events_.back()); - events_.pop_back(); - --eventCount_; - return event; - } - return std::nullopt; -} - -#if ENABLE_DEBUG -template -void EventStack::printEvents() const { - std::shared_lock lock(mtx_); - std::cout << "Events in stack:" << std::endl; - for (const T& event : events_) { - std::cout << event << std::endl; - } -} -#endif - -template -auto EventStack::isEmpty() const -> bool { - std::shared_lock lock(mtx_); - return events_.empty(); -} - -template -auto EventStack::size() const -> size_t { - return eventCount_.load(); -} - -template -void EventStack::clearEvents() { - std::unique_lock lock(mtx_); - events_.clear(); - eventCount_.store(0); -} - -template -auto EventStack::peekTopEvent() const -> std::optional { - std::shared_lock lock(mtx_); - if (!events_.empty()) { - return events_.back(); - } - return std::nullopt; -} - -template -auto EventStack::copyStack() const -> EventStack { - std::shared_lock lock(mtx_); - EventStack newStack; - newStack.events_ = events_; - newStack.eventCount_.store(eventCount_.load()); - return newStack; -} - -template -void EventStack::filterEvents(std::function filterFunc) { - std::unique_lock lock(mtx_); - events_.erase( - std::remove_if(events_.begin(), events_.end(), - [&](const T& event) { return !filterFunc(event); }), - events_.end()); - eventCount_.store(events_.size()); -} - -template -auto EventStack::serializeStack() const -> std::string { - std::shared_lock lock(mtx_); - std::string serializedStack; - serializedStack.reserve(events_.size() * - sizeof(T)); // Reserve space to improve performance - for (const T& event : events_) { - serializedStack += event + ";"; - } - return serializedStack; -} - -template -void EventStack::deserializeStack(std::string_view serializedData) { - std::unique_lock lock(mtx_); - events_.clear(); - size_t pos = 0; - size_t nextPos = 0; - while ((nextPos = serializedData.find(';', pos)) != - std::string_view::npos) { - T event = serializedData.substr(pos, nextPos - pos); - events_.push_back(std::move(event)); - pos = nextPos + 1; - } - eventCount_.store(events_.size()); -} - -template -void EventStack::removeDuplicates() { - std::unique_lock lock(mtx_); - std::sort(events_.begin(), events_.end()); - events_.erase(std::unique(events_.begin(), events_.end()), events_.end()); - eventCount_.store(events_.size()); -} - -template -void EventStack::sortEvents( - std::function compareFunc) { - std::unique_lock lock(mtx_); - std::sort(events_.begin(), events_.end(), compareFunc); -} - -template -void EventStack::reverseEvents() { - std::unique_lock lock(mtx_); - std::reverse(events_.begin(), events_.end()); -} - -template -auto EventStack::countEvents(std::function predicate) const - -> size_t { - std::shared_lock lock(mtx_); - return std::count_if(events_.begin(), events_.end(), predicate); -} - -template -auto EventStack::findEvent(std::function predicate) const - -> std::optional { - std::shared_lock lock(mtx_); - auto iterator = std::find_if(events_.begin(), events_.end(), predicate); - if (iterator != events_.end()) { - return *iterator; - } - return std::nullopt; -} - -template -auto EventStack::anyEvent(std::function predicate) const - -> bool { - std::shared_lock lock(mtx_); - return std::any_of(events_.begin(), events_.end(), predicate); -} - -template -auto EventStack::allEvents(std::function predicate) const - -> bool { - std::shared_lock lock(mtx_); - return std::all_of(events_.begin(), events_.end(), predicate); -} -} // namespace atom::async - -#endif // ATOM_ASYNC_EVENTSTACK_HPP diff --git a/src/atom/async/future.hpp b/src/atom/async/future.hpp deleted file mode 100644 index afa5ce46..00000000 --- a/src/atom/async/future.hpp +++ /dev/null @@ -1,474 +0,0 @@ -#ifndef ATOM_ASYNC_FUTURE_HPP -#define ATOM_ASYNC_FUTURE_HPP - -#include -#include -#include - -#include "atom/error/exception.hpp" - -namespace atom::async { - -/** - * @class InvalidFutureException - * @brief Exception thrown when an invalid future is encountered. - */ -class InvalidFutureException : public atom::error::RuntimeError { -public: - using atom::error::RuntimeError::RuntimeError; -}; - -/** - * @def THROW_INVALID_FUTURE_EXCEPTION - * @brief Macro to throw an InvalidFutureException with file, line, and function - * information. - */ -#define THROW_INVALID_FUTURE_EXCEPTION(...) \ - throw InvalidFutureException(ATOM_FILE_NAME, ATOM_FILE_LINE, \ - ATOM_FUNC_NAME, __VA_ARGS__); - -/** - * @def THROW_NESTED_INVALID_FUTURE_EXCEPTION - * @brief Macro to rethrow a nested InvalidFutureException with file, line, and - * function information. - */ -#define THROW_NESTED_INVALID_FUTURE_EXCEPTION(...) \ - InvalidFutureException::rethrowNested(ATOM_FILE_NAME, ATOM_FILE_LINE, \ - ATOM_FUNC_NAME, \ - "Invalid future: " __VA_ARGS__); - -/** - * @class EnhancedFuture - * @brief A template class that extends the standard future with additional - * features. - * @tparam T The type of the value that the future will hold. - */ -template -class EnhancedFuture { -public: - /** - * @brief Constructs an EnhancedFuture from a shared future. - * @param fut The shared future to wrap. - */ - explicit EnhancedFuture(std::shared_future &&fut) - : future_(std::move(fut)), cancelled_(false) {} - - explicit EnhancedFuture(const std::shared_future &fut) - : future_(fut), cancelled_(false) {} - - /** - * @brief Chains another operation to be called after the future is done. - * @tparam F The type of the function to call. - * @param func The function to call when the future is done. - * @return An EnhancedFuture for the result of the function. - */ - template - auto then(F &&func) { - using ResultType = std::invoke_result_t; - return EnhancedFuture( - std::async(std::launch::async, [fut = future_, - func = std::forward( - func)]() mutable { - if (fut.valid()) { - return func(fut.get()); - } - THROW_INVALID_FUTURE_EXCEPTION( - "Future is invalid or cancelled"); - }).share()); - } - - /** - * @brief Waits for the future with a timeout and auto-cancels if not ready. - * @param timeout The timeout duration. - * @return An optional containing the value if ready, or nullopt if timed - * out. - */ - auto waitFor(std::chrono::milliseconds timeout) -> std::optional { - if (future_.wait_for(timeout) == std::future_status::ready && - !cancelled_) { - return future_.get(); - } - cancel(); - return std::nullopt; - } - - /** - * @brief Checks if the future is done. - * @return True if the future is done, false otherwise. - */ - [[nodiscard]] auto isDone() const -> bool { - return future_.wait_for(std::chrono::milliseconds(0)) == - std::future_status::ready; - } - - /** - * @brief Sets a completion callback to be called when the future is done. - * @tparam F The type of the callback function. - * @param func The callback function to add. - */ - template - void onComplete(F &&func) { - if (!cancelled_) { - callbacks_.emplace_back(std::forward(func)); - std::async(std::launch::async, [this]() { - try { - if (future_.valid()) { - auto result = future_.get(); - for (auto &callback : callbacks_) { - callback(result); - } - } - } catch (const std::exception &e) { - } - }).get(); - } - } - - /** - * @brief Waits synchronously for the future to complete. - * @return The value of the future. - * @throws InvalidFutureException if the future is cancelled. - */ - auto wait() -> T { - if (cancelled_) { - THROW_OBJ_NOT_EXIST("Future has been cancelled"); - } - return future_.get(); - } - - template - auto catching(F &&func) { - using ResultType = T; - auto sharedFuture = std::make_shared>(future_); - return EnhancedFuture( - std::async(std::launch::async, [sharedFuture, - func = std::forward( - func)]() mutable { - try { - if (sharedFuture->valid()) { - return sharedFuture->get(); - } - THROW_INVALID_FUTURE_EXCEPTION( - "Future is invalid or cancelled"); - } catch (...) { - return func(std::current_exception()); - } - }).share()); - } - - /** - * @brief Cancels the future. - */ - void cancel() { cancelled_ = true; } - - /** - * @brief Checks if the future has been cancelled. - * @return True if the future has been cancelled, false otherwise. - */ - [[nodiscard]] auto isCancelled() const -> bool { return cancelled_; } - - /** - * @brief Gets the exception associated with the future, if any. - * @return A pointer to the exception, or nullptr if no exception. - */ - auto getException() -> std::exception_ptr { - try { - future_.get(); - } catch (...) { - return std::current_exception(); - } - return nullptr; - } - - /** - * @brief Retries the operation associated with the future. - * @tparam F The type of the function to call. - * @param func The function to call when retrying. - * @param max_retries The maximum number of retries. - * @return An EnhancedFuture for the result of the function. - */ - template - auto retry(F &&func, int max_retries) { - using ResultType = std::invoke_result_t; - return EnhancedFuture( - std::async(std::launch::async, [fut = future_, - func = std::forward(func), - max_retries]() mutable { - for (int attempt = 0; attempt < max_retries; ++attempt) { - if (fut.valid()) { - try { - return func(fut.get()); - } catch (const std::exception &e) { - if (attempt == max_retries - 1) { - throw; - } - } - } else { - THROW_UNLAWFUL_OPERATION( - "Future is invalid or cancelled"); - } - } - }).share()); - } - - auto isReady() -> bool { - return future_.wait_for(std::chrono::milliseconds(0)) == - std::future_status::ready; - } - - auto get() -> T { return future_.get(); } - -protected: - std::shared_future future_; ///< The underlying shared future. - std::vector> - callbacks_; ///< List of callbacks to be called on completion. - bool cancelled_; ///< Flag indicating if the future has been cancelled. -}; - -/** - * @class EnhancedFuture - * @brief Specialization of the EnhancedFuture class for void type. - */ -template <> -class EnhancedFuture { -public: - /** - * @brief Constructs an EnhancedFuture from a shared future. - * @param fut The shared future to wrap. - */ - explicit EnhancedFuture(std::shared_future &&fut) - : future_(std::move(fut)), cancelled_(false) {} - - explicit EnhancedFuture(const std::shared_future &fut) - : future_(fut), cancelled_(false) {} - - /** - * @brief Chains another operation to be called after the future is done. - * @tparam F The type of the function to call. - * @param func The function to call when the future is done. - * @return An EnhancedFuture for the result of the function. - */ - template - auto then(F &&func) { - using ResultType = std::invoke_result_t; - return EnhancedFuture( - std::async(std::launch::async, [fut = future_, - func = std::forward( - func)]() mutable { - if (fut.valid()) { - fut.get(); - return func(); - } - THROW_UNLAWFUL_OPERATION("Future is invalid or cancelled"); - }).share()); - } - - /** - * @brief Waits for the future with a timeout and auto-cancels if not ready. - * @param timeout The timeout duration. - * @return True if the future is ready, false otherwise. - */ - auto waitFor(std::chrono::milliseconds timeout) -> bool { - if (future_.wait_for(timeout) == std::future_status::ready && - !cancelled_) { - future_.get(); - return true; - } - cancel(); - return false; - } - - /** - * @brief Checks if the future is done. - * @return True if the future is done, false otherwise. - */ - [[nodiscard]] auto isDone() const -> bool { - return future_.wait_for(std::chrono::milliseconds(0)) == - std::future_status::ready; - } - - /** - * @brief Sets a completion callback to be called when the future is done. - * @tparam F The type of the callback function. - * @param func The callback function to add. - */ - template - void onComplete(F &&func) { - if (!cancelled_) { - callbacks_.emplace_back(std::forward(func)); - std::async(std::launch::async, [this]() { - try { - if (future_.valid()) { - future_.get(); - for (auto &callback : callbacks_) { - callback(); - } - } - } catch (const std::exception &e) { - } - }).get(); - } - } - - /** - * @brief Waits synchronously for the future to complete. - * @throws InvalidFutureException if the future is cancelled. - */ - void wait() { - if (cancelled_) { - THROW_OBJ_NOT_EXIST("Future has been cancelled"); - } - future_.get(); - } - - /** - * @brief Cancels the future. - */ - void cancel() { cancelled_ = true; } - - /** - * @brief Checks if the future has been cancelled. - * @return True if the future has been cancelled, false otherwise. - */ - [[nodiscard]] auto isCancelled() const -> bool { return cancelled_; } - - /** - * @brief Gets the exception associated with the future, if any. - * @return A pointer to the exception, or nullptr if no exception. - */ - auto getException() -> std::exception_ptr { - try { - future_.get(); - } catch (...) { - return std::current_exception(); - } - return nullptr; - } - - auto isReady() -> bool { - return future_.wait_for(std::chrono::milliseconds(0)) == - std::future_status::ready; - } - - auto get() -> void { future_.get(); } - -protected: - std::shared_future future_; ///< The underlying shared future. - std::vector> - callbacks_; ///< List of callbacks to be called on completion. - std::atomic - cancelled_; ///< Flag indicating if the future has been cancelled. -}; - -/** - * @brief Helper function to create an EnhancedFuture. - * @tparam F The type of the function to call. - * @tparam Args The types of the arguments to pass to the function. - * @param f The function to call. - * @param args The arguments to pass to the function. - * @return An EnhancedFuture for the result of the function. - */ -template -auto makeEnhancedFuture(F &&f, Args &&...args) { - using result_type = std::invoke_result_t; - return EnhancedFuture(std::async(std::launch::async, - std::forward(f), - std::forward(args)...) - .share()); -} - -/** - * @brief Helper function to get a future for a range of futures. - * @tparam InputIt The type of the input iterator. - * @param first The beginning of the range. - * @param last The end of the range. - * @param timeout An optional timeout duration. - * @return A future containing a vector of the results of the input futures. - */ -template -auto whenAll(InputIt first, InputIt last, - std::optional timeout = std::nullopt) - -> std::future< - std::vector::value_type>> { - using FutureType = typename std::iterator_traits::value_type; - using ResultType = std::vector; - - std::promise promise; - std::future resultFuture = promise.get_future(); - - // Launch an async task to wait for all the futures - auto asyncTask = std::async([promise = std::move(promise), first, last, - timeout]() mutable { - ResultType results; - try { - for (auto it = first; it != last; ++it) { - if (timeout) { - // Check each future with timeout (if specified) - if (it->wait_for(*timeout) == std::future_status::timeout) { - THROW_INVALID_ARGUMENT( - "Timeout while waiting for a future."); - } - } - results.push_back(std::move(*it)); - } - promise.set_value(std::move(results)); - } catch (const std::exception &e) { - promise.set_exception( - std::current_exception()); // Pass the exception to the future - } - }); - - // Optionally, store the future or use it if needed - asyncTask.wait(); // Wait for the async task to finish - - return resultFuture; -} - -/** - * @brief Helper to get the return type of a future. - * @tparam T The type of the future. - */ -template -using future_value_t = decltype(std::declval().get()); - -/** - * @brief Helper function for a variadic template version (when_all for futures - * as arguments). - * @tparam Futures The types of the futures. - * @param futures The futures to wait for. - * @return A future containing a tuple of the results of the input futures. - */ -template -auto whenAll(Futures &&...futures) - -> std::future...>> { - std::promise...>> promise; - std::future...>> resultFuture = - promise.get_future(); - - // Use async to wait for all futures and gather results - auto asyncTask = - std::async([promise = std::move(promise), - futures = std::make_tuple( - std::forward(futures)...)]() mutable { - try { - auto results = std::apply( - [](auto &&...fs) { - return std::make_tuple( - fs.get()...); // Wait for each future and collect - // the results - }, - futures); - promise.set_value(std::move(results)); - } catch (const std::exception &e) { - promise.set_exception(std::current_exception()); - } - }); - - asyncTask.wait(); // Wait for the async task to finish - - return resultFuture; -} - -} // namespace atom::async - -#endif // ATOM_ASYNC_FUTURE_HPP diff --git a/src/atom/async/limiter.cpp b/src/atom/async/limiter.cpp deleted file mode 100644 index 5a885694..00000000 --- a/src/atom/async/limiter.cpp +++ /dev/null @@ -1,290 +0,0 @@ -#include "limiter.hpp" - -#include "atom/log/loguru.hpp" - -namespace atom::async { -RateLimiter::Settings::Settings(size_t max_requests, - std::chrono::seconds time_window) - : maxRequests(max_requests), timeWindow(time_window) { - LOG_F(INFO, "Settings created: max_requests=%zu, time_window=%lld seconds", - max_requests, time_window.count()); -} - -// Implementation of RateLimiter constructor -RateLimiter::RateLimiter() { LOG_F(INFO, "RateLimiter created"); } - -// Implementation of Awaiter constructor -RateLimiter::Awaiter::Awaiter(RateLimiter& limiter, - const std::string& function_name) - : limiter_(limiter), function_name_(function_name) { - LOG_F(INFO, "Awaiter created for function: %s", function_name.c_str()); -} - -// Implementation of Awaiter::await_ready -auto RateLimiter::Awaiter::await_ready() -> bool { - LOG_F(INFO, "Awaiter::await_ready called for function: %s", - function_name_.c_str()); - return false; -} - -// Implementation of Awaiter::await_suspend -void RateLimiter::Awaiter::await_suspend(std::coroutine_handle<> handle) { - LOG_F(INFO, "Awaiter::await_suspend called for function: %s", - function_name_.c_str()); - std::unique_lock lock(limiter_.mutex_); - auto& settings = limiter_.settings_[function_name_]; - limiter_.cleanup(function_name_, settings.timeWindow); - if (limiter_.paused_ || - limiter_.requests_[function_name_].size() >= settings.maxRequests) { - limiter_.waiters_[function_name_].emplace_back(handle); - limiter_.rejected_requests_[function_name_]++; - LOG_F(WARNING, "Request for function %s rejected. Total rejected: %zu", - function_name_.c_str(), - limiter_.rejected_requests_[function_name_]); - } else { - limiter_.requests_[function_name_].emplace_back( - std::chrono::steady_clock::now()); - lock.unlock(); - LOG_F(INFO, "Request for function %s accepted", function_name_.c_str()); - handle.resume(); - } -} - -// Implementation of Awaiter::await_resume -void RateLimiter::Awaiter::await_resume() { - LOG_F(INFO, "Awaiter::await_resume called for function: %s", - function_name_.c_str()); -} - -// Implementation of RateLimiter::acquire -RateLimiter::Awaiter RateLimiter::acquire(const std::string& function_name) { - LOG_F(INFO, "RateLimiter::acquire called for function: %s", - function_name.c_str()); - return Awaiter(*this, function_name); -} - -// Implementation of RateLimiter::setFunctionLimit -void RateLimiter::setFunctionLimit(const std::string& function_name, - size_t max_requests, - std::chrono::seconds time_window) { - LOG_F(INFO, - "RateLimiter::setFunctionLimit called for function: %s, " - "max_requests=%zu, time_window=%lld seconds", - function_name.c_str(), max_requests, time_window.count()); - std::unique_lock lock(mutex_); - settings_[function_name] = Settings(max_requests, time_window); -} - -// Implementation of RateLimiter::pause -void RateLimiter::pause() { - LOG_F(INFO, "RateLimiter::pause called"); - std::unique_lock lock(mutex_); - paused_ = true; -} - -// Implementation of RateLimiter::resume -void RateLimiter::resume() { - LOG_F(INFO, "RateLimiter::resume called"); - std::unique_lock lock(mutex_); - paused_ = false; - processWaiters(); -} - -// Implementation of RateLimiter::printLog -void RateLimiter::printLog() { -#if ENABLE_DEBUG - LOG_F(INFO, "RateLimiter::printLog called"); - std::unique_lock lock(mutex_); - for (const auto& [function_name, timestamps] : log_) { - std::cout << "Request log for " << function_name << ":\n"; - for (const auto& timestamp : timestamps) { - std::cout << "Request at " << timestamp.time_since_epoch().count() - << std::endl; - } - } -#endif -} - -// Implementation of RateLimiter::getRejectedRequests -auto RateLimiter::getRejectedRequests(const std::string& function_name) - -> size_t { - LOG_F(INFO, "RateLimiter::getRejectedRequests called for function: %s", - function_name.c_str()); - std::unique_lock lock(mutex_); - return rejected_requests_[function_name]; -} - -// Implementation of RateLimiter::cleanup -void RateLimiter::cleanup(const std::string& function_name, - const std::chrono::seconds& time_window) { - LOG_F(INFO, - "RateLimiter::cleanup called for function: %s, time_window=%lld " - "seconds", - function_name.c_str(), time_window.count()); - auto now = std::chrono::steady_clock::now(); - auto& reqs = requests_[function_name]; - while (!reqs.empty() && now - reqs.front() > time_window) { - reqs.pop_front(); - } -} - -// Implementation of RateLimiter::processWaiters -void RateLimiter::processWaiters() { - LOG_F(INFO, "RateLimiter::processWaiters called"); - for (auto& [function_name, wait_queue] : waiters_) { - auto& settings = settings_[function_name]; - while (!wait_queue.empty() && - requests_[function_name].size() < settings.maxRequests) { - auto waiter = wait_queue.front(); - wait_queue.pop_front(); - requests_[function_name].emplace_back( - std::chrono::steady_clock::now()); - mutex_.unlock(); - LOG_F(INFO, "Resuming waiter for function: %s", - function_name.c_str()); - waiter.resume(); - mutex_.lock(); - } - } -} - -Debounce::Debounce(std::function func, std::chrono::milliseconds delay, - bool leading, - std::optional maxWait) - : func_(std::move(func)), - delay_(delay), - leading_(leading), - maxWait_(maxWait) { - LOG_F(INFO, "Debounce created: delay=%lld ms, leading=%d, maxWait=%lld ms", - delay.count(), leading, maxWait ? maxWait->count() : 0); -} - -void Debounce::operator()() { - LOG_F(INFO, "Debounce operator() called"); - auto now = std::chrono::steady_clock::now(); - std::unique_lock lock(mutex_); - - if (leading_ && !scheduled_) { - scheduled_ = true; - func_(); - ++call_count_; - } - - last_call_ = now; - if (!thread_.joinable()) { - thread_ = std::jthread([this]() { this->run(); }); - } -} - -void Debounce::cancel() { - LOG_F(INFO, "Debounce::cancel called"); - std::unique_lock lock(mutex_); - scheduled_ = false; - last_call_.reset(); -} - -void Debounce::flush() { - LOG_F(INFO, "Debounce::flush called"); - std::unique_lock lock(mutex_); - if (scheduled_) { - func_(); - ++call_count_; - scheduled_ = false; - } -} - -void Debounce::reset() { - LOG_F(INFO, "Debounce::reset called"); - std::unique_lock lock(mutex_); - last_call_.reset(); - scheduled_ = false; -} - -size_t Debounce::callCount() const { - std::unique_lock lock(mutex_); - return call_count_; -} - -void Debounce::run() { - LOG_F(INFO, "Debounce::run started"); - while (true) { - std::this_thread::sleep_for(delay_); - std::unique_lock lock(mutex_); - auto now = std::chrono::steady_clock::now(); - if (last_call_ && now - last_call_.value() >= delay_) { - if (scheduled_) { - func_(); - ++call_count_; - scheduled_ = false; - } - LOG_F(INFO, "Debounce::run finished"); - return; - } - if (maxWait_ && now - last_call_.value() >= maxWait_) { - if (scheduled_) { - func_(); - ++call_count_; - scheduled_ = false; - } - LOG_F(INFO, "Debounce::run finished"); - return; - } - } -} - -Throttle::Throttle(std::function func, - std::chrono::milliseconds interval, bool leading, - std::optional maxWait) - : func_(std::move(func)), - interval_(interval), - last_call_(std::chrono::steady_clock::now() - interval), - leading_(leading), - maxWait_(maxWait) { - LOG_F(INFO, - "Throttle created: interval=%lld ms, leading=%d, maxWait=%lld ms", - interval.count(), leading, maxWait ? maxWait->count() : 0); -} - -void Throttle::operator()() { - LOG_F(INFO, "Throttle operator() called"); - auto now = std::chrono::steady_clock::now(); - std::unique_lock lock(mutex_); - - if (leading_ && !called_) { - called_ = true; - func_(); - last_call_ = now; - ++call_count_; - return; - } - - if (now - last_call_ >= interval_) { - last_call_ = now; - func_(); - ++call_count_; - } else if (maxWait_ && (now - last_call_ >= maxWait_)) { - last_call_ = now; - func_(); - ++call_count_; - } -} - -void Throttle::cancel() { - LOG_F(INFO, "Throttle::cancel called"); - std::unique_lock lock(mutex_); - called_ = false; -} - -void Throttle::reset() { - LOG_F(INFO, "Throttle::reset called"); - std::unique_lock lock(mutex_); - last_call_ = std::chrono::steady_clock::now() - interval_; - called_ = false; -} - -auto Throttle::callCount() const -> size_t { - std::unique_lock lock(mutex_); - return call_count_; -} - -} // namespace atom::async diff --git a/src/atom/async/limiter.hpp b/src/atom/async/limiter.hpp deleted file mode 100644 index 24c6ea52..00000000 --- a/src/atom/async/limiter.hpp +++ /dev/null @@ -1,329 +0,0 @@ -#ifndef ATOM_ASYNC_LIMITER_HPP -#define ATOM_ASYNC_LIMITER_HPP - -#include -#include -#include -#include -#include -#include -#include -#include - -namespace atom::async { -/** - * @brief A rate limiter class to control the rate of function executions. - */ -class RateLimiter { -public: - /** - * @brief Settings for the rate limiter. - */ - struct Settings { - size_t maxRequests; ///< Maximum number of requests allowed in the time - ///< window. - std::chrono::seconds - timeWindow; ///< The time window in which maxRequests are allowed. - - /** - * @brief Constructor for Settings. - * @param max_requests Maximum number of requests. - * @param time_window Duration of the time window. - */ - explicit Settings( - size_t max_requests = 5, - std::chrono::seconds time_window = std::chrono::seconds(1)); - }; - - /** - * @brief Constructor for RateLimiter. - */ - RateLimiter(); - - /** - * @brief Awaiter class for handling coroutines. - */ - class Awaiter { - public: - /** - * @brief Constructor for Awaiter. - * @param limiter Reference to the rate limiter. - * @param function_name Name of the function to be rate-limited. - */ - Awaiter(RateLimiter& limiter, const std::string& function_name); - - /** - * @brief Checks if the awaiter is ready. - * @return Always returns false. - */ - auto await_ready() -> bool; - - /** - * @brief Suspends the coroutine. - * @param handle Coroutine handle. - */ - void await_suspend(std::coroutine_handle<> handle); - - /** - * @brief Resumes the coroutine. - */ - void await_resume(); - - private: - RateLimiter& limiter_; - std::string function_name_; - }; - - /** - * @brief Acquires the rate limiter for a specific function. - * @param function_name Name of the function to be rate-limited. - * @return An Awaiter object. - */ - Awaiter acquire(const std::string& function_name); - - /** - * @brief Sets the rate limit for a specific function. - * @param function_name Name of the function to be rate-limited. - * @param max_requests Maximum number of requests allowed. - * @param time_window Duration of the time window. - */ - void setFunctionLimit(const std::string& function_name, size_t max_requests, - std::chrono::seconds time_window); - - /** - * @brief Pauses the rate limiter. - */ - void pause(); - - /** - * @brief Resumes the rate limiter. - */ - void resume(); - - /** - * @brief Prints the log of requests. - */ - void printLog(); - - /** - * @brief Gets the number of rejected requests for a specific function. - * @param function_name Name of the function. - * @return Number of rejected requests. - */ - auto getRejectedRequests(const std::string& function_name) -> size_t; - -#if !defined(TEST_F) && !defined(TEST) -private: -#endif - /** - * @brief Cleans up old requests outside the time window. - * @param function_name Name of the function. - * @param time_window Duration of the time window. - */ - void cleanup(const std::string& function_name, - const std::chrono::seconds& time_window); - - /** - * @brief Processes waiting coroutines. - */ - void processWaiters(); - - std::unordered_map settings_; - std::unordered_map> - requests_; - std::unordered_map>> - waiters_; - std::unordered_map> - log_; - std::unordered_map rejected_requests_; - bool paused_ = false; - std::mutex mutex_; -}; - -/** - * @class Debounce - * @brief A class that implements a debouncing mechanism for function calls. - * - * The `Debounce` class ensures that the given function is not invoked more - * frequently than a specified delay interval. It postpones the function call - * until the delay has elapsed since the last call. If a new call occurs before - * the delay expires, the previous call is canceled and the delay starts over. - * This is useful for situations where you want to limit the rate of function - * invocations, such as handling user input events. - */ -class Debounce { -public: - /** - * @brief Constructs a Debounce object. - * - * @param func The function to be debounced. - * @param delay The time delay to wait before invoking the function. - * @param leading If true, the function will be invoked immediately on the - * first call and then debounced for subsequent calls. If false, the - * function will be debounced and invoked only after the delay has passed - * since the last call. - * @param maxWait Optional maximum wait time before invoking the function if - * it has been called frequently. If not provided, there is no maximum wait - * time. - */ - Debounce(std::function func, std::chrono::milliseconds delay, - bool leading = false, - std::optional maxWait = std::nullopt); - - /** - * @brief Invokes the debounced function if the delay has elapsed since the - * last call. - * - * This method schedules the function call if the delay period has passed - * since the last call. If the leading flag is set, the function will be - * called immediately on the first call. Subsequent calls will reset the - * delay timer. - */ - void operator()(); - - /** - * @brief Cancels any pending function calls. - * - * This method cancels any pending invocation of the function that is - * scheduled to occur based on the debouncing mechanism. - */ - void cancel(); - - /** - * @brief Immediately invokes the function if it is scheduled to be called. - * - * This method flushes any pending function calls, ensuring the function is - * called immediately. - */ - void flush(); - - /** - * @brief Resets the debouncer, clearing any pending function call and - * timer. - * - * This method resets the internal state of the debouncer, allowing it to - * start fresh and schedule new function calls based on the debounce delay. - */ - void reset(); - - /** - * @brief Returns the number of times the function has been invoked. - * - * @return The count of function invocations. - */ - size_t callCount() const; - -private: - /** - * @brief Runs the function in a separate thread after the debounce delay. - * - * This method is used internally to handle the scheduling and execution of - * the function after the specified delay. - */ - void run(); - - std::function func_; ///< The function to be debounced. - std::chrono::milliseconds - delay_; ///< The time delay before invoking the function. - std::optional - last_call_; ///< The timestamp of the last call. - std::jthread thread_; ///< A thread used to handle delayed function calls. - mutable std::mutex - mutex_; ///< Mutex to protect concurrent access to internal state. - bool leading_; ///< Indicates if the function should be called immediately - ///< upon the first call. - bool scheduled_ = - false; ///< Flag to track if the function is scheduled for execution. - std::optional - maxWait_; ///< Optional maximum wait time before invocation. - size_t - call_count_{}; ///< Counter to keep track of function call invocations. -}; - -/** - * @class Throttle - * @brief A class that provides throttling for function calls, ensuring they are - * not invoked more frequently than a specified interval. - * - * This class is useful for rate-limiting function calls. It ensures that the - * given function is not called more frequently than the specified interval. - * Additionally, it can be configured to either throttle function calls to be - * executed at most once per interval or to execute the function immediately - * upon the first call and then throttle subsequent calls. - */ -class Throttle { -public: - /** - * @brief Constructs a Throttle object. - * - * @param func The function to be throttled. - * @param interval The minimum time interval between calls to the function. - * @param leading If true, the function will be called immediately upon the - * first call, then throttled. If false, the function will be throttled and - * called at most once per interval. - * @param maxWait Optional maximum wait time before invoking the function if - * it has been called frequently. If not provided, there is no maximum wait - * time. - */ - Throttle(std::function func, std::chrono::milliseconds interval, - bool leading = false, - std::optional maxWait = std::nullopt); - - /** - * @brief Invokes the throttled function if the interval has elapsed. - * - * This method will check if enough time has passed since the last function - * call. If so, it will invoke the function and update the last call - * timestamp. If the function is being invoked immediately as per the - * leading configuration, it will be executed at once, and subsequent calls - * will be throttled. - */ - void operator()(); - - /** - * @brief Cancels any pending function calls. - * - * This method cancels any pending function invocations that are scheduled - * to occur based on the throttling mechanism. - */ - void cancel(); - - /** - * @brief Resets the throttle, clearing the last call timestamp and allowing - * the function to be invoked immediately if required. - * - * This method can be used to reset the throttle state, allowing the - * function to be called immediately if the leading flag is set or to reset - * the interval for subsequent function calls. - */ - void reset(); - - /** - * @brief Returns the number of times the function has been called. - * - * @return The count of function invocations. - */ - auto callCount() const -> size_t; - -private: - std::function func_; ///< The function to be throttled. - std::chrono::milliseconds - interval_; ///< The time interval between allowed function calls. - std::chrono::steady_clock::time_point - last_call_; ///< The timestamp of the last function call. - mutable std::mutex - mutex_; ///< Mutex to protect concurrent access to internal state. - bool leading_; ///< Indicates if the function should be called immediately - ///< upon first call. - bool called_ = false; ///< Flag to track if the function has been called. - std::optional - maxWait_; ///< Optional maximum wait time before invocation. - size_t - call_count_{}; ///< Counter to keep track of function call invocations. -}; - -} // namespace atom::async - -#endif diff --git a/src/atom/async/lock.cpp b/src/atom/async/lock.cpp deleted file mode 100644 index f03056ef..00000000 --- a/src/atom/async/lock.cpp +++ /dev/null @@ -1,49 +0,0 @@ -/* - * lock.cpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2024-2-13 - -Description: Some useful spinlock implementations - -**************************************************/ - -#include "lock.hpp" - -namespace atom::async { -void Spinlock::lock() { - while (flag_.test_and_set(std::memory_order_acquire)) { - cpu_relax(); - } -} - -auto Spinlock::tryLock() -> bool { - return !flag_.test_and_set(std::memory_order_acquire); -} - -void Spinlock::unlock() { flag_.clear(std::memory_order_release); } - -auto TicketSpinlock::lock() -> uint64_t { - const auto TICKET = ticket_.fetch_add(1, std::memory_order_acq_rel); - while (serving_.load(std::memory_order_acquire) != TICKET) { - cpu_relax(); - } - return TICKET; -} - -void TicketSpinlock::unlock(uint64_t TICKET) { - serving_.store(TICKET + 1, std::memory_order_release); -} - -void UnfairSpinlock::lock() { - while (flag_.test_and_set(std::memory_order_acquire)) { - cpu_relax(); - } -} - -void UnfairSpinlock::unlock() { flag_.clear(std::memory_order_release); } -} // namespace atom::async diff --git a/src/atom/async/lock.hpp b/src/atom/async/lock.hpp deleted file mode 100644 index a6073fb2..00000000 --- a/src/atom/async/lock.hpp +++ /dev/null @@ -1,214 +0,0 @@ -/* - * lock.hpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2024-2-13 - -Description: Some useful spinlock implementations - -**************************************************/ - -#ifndef ATOM_ASYNC_LOCK_HPP -#define ATOM_ASYNC_LOCK_HPP - -#include - -#include "atom/type/noncopyable.hpp" - -namespace atom::async { - -// Pause instruction to prevent excess processor bus usage -#if defined(_MSC_VER) -#define cpu_relax() std::this_thread::yield() -#elif defined(__i386__) || defined(__x86_64__) -#define cpu_relax() asm volatile("pause\n" : : : "memory") -#elif defined(__aarch64__) -#define cpu_relax() asm volatile("yield\n" : : : "memory") -#elif defined(__arm__) -#define cpu_relax() asm volatile("nop\n" : : : "memory") -#else -#error "Unknown architecture, CPU relax code required" -#endif - -/** - * @brief A simple spinlock implementation using atomic_flag. - */ -class Spinlock : public NonCopyable { - std::atomic_flag flag_ = ATOMIC_FLAG_INIT; - -public: - /** - * @brief Default constructor. - */ - Spinlock() = default; - - /** - * @brief Acquires the lock. - */ - void lock(); - - /** - * @brief Releases the lock. - */ - void unlock(); - - /** - * @brief Tries to acquire the lock. - * - * @return true if the lock was acquired, false otherwise. - */ - auto tryLock() -> bool; -}; - -/** - * @brief A ticket spinlock implementation using atomic operations. - */ -class TicketSpinlock : public NonCopyable { - std::atomic ticket_{0}; - std::atomic serving_{0}; - -public: - TicketSpinlock() = default; - /** - * @brief Lock guard for TicketSpinlock. - */ - class LockGuard { - TicketSpinlock &spinlock_; - const uint64_t TICKET; - - public: - /** - * @brief Constructs the lock guard and acquires the lock. - * - * @param spinlock The TicketSpinlock to guard. - */ - explicit LockGuard(TicketSpinlock &spinlock) - : spinlock_(spinlock), TICKET(spinlock_.lock()) {} - - /** - * @brief Destructs the lock guard and releases the lock. - */ - ~LockGuard() { spinlock_.unlock(TICKET); } - }; - - using scoped_lock = LockGuard; - - /** - * @brief Acquires the lock and returns the ticket number. - * - * @return The acquired ticket number. - */ - auto lock() -> uint64_t; - - /** - * @brief Releases the lock given a specific ticket number. - * - * @param ticket The ticket number to release. - */ - void unlock(uint64_t TICKET); -}; - -/** - * @brief An unfair spinlock implementation using atomic_flag. - */ -class UnfairSpinlock : public NonCopyable { - std::atomic_flag flag_ = ATOMIC_FLAG_INIT; - -public: - UnfairSpinlock() = default; - /** - * @brief Acquires the lock. - */ - void lock(); - - /** - * @brief Releases the lock. - */ - void unlock(); -}; - -/** - * @brief Scoped lock for any type of spinlock. - * - * @tparam Mutex Type of the spinlock (e.g., Spinlock, TicketSpinlock, - * UnfairSpinlock). - */ -template -class ScopedLock { - Mutex &mutex_; - -public: - /** - * @brief Constructs the scoped lock and acquires the lock on the provided - * mutex. - * - * @param mutex The mutex to lock. - */ - explicit ScopedLock(Mutex &mutex) : mutex_(mutex) { mutex_.lock(); } - - /** - * @brief Destructs the scoped lock and releases the lock. - */ - ~ScopedLock() { mutex_.unlock(); } - - ScopedLock(const ScopedLock &) = delete; - ScopedLock &operator=(const ScopedLock &) = delete; -}; - -/** - * @brief Scoped lock for TicketSpinlock. - * - * @tparam Mutex Type of the spinlock (i.e., TicketSpinlock). - */ -template -class ScopedTicketLock : public NonCopyable { - Mutex &mutex_; - const uint64_t TICKET; - -public: - /** - * @brief Constructs the scoped lock and acquires the lock on the provided - * mutex. - * - * @param mutex The mutex to lock. - */ - explicit ScopedTicketLock(Mutex &mutex) - : mutex_(mutex), TICKET(mutex_.lock()) {} - - /** - * @brief Destructs the scoped lock and releases the lock. - */ - ~ScopedTicketLock() { mutex_.unlock(TICKET); } -}; - -/** - * @brief Scoped lock for UnfairSpinlock. - * - * @tparam Mutex Type of the spinlock (i.e., UnfairSpinlock). - */ -template -class ScopedUnfairLock : public NonCopyable { - Mutex &mutex_; - -public: - /** - * @brief Constructs the scoped lock and acquires the lock on the provided - * mutex. - * - * @param mutex The mutex to lock. - */ - explicit ScopedUnfairLock(Mutex &mutex) : mutex_(mutex) { mutex_.lock(); } - - /** - * @brief Destructs the scoped lock and releases the lock. - */ - ~ScopedUnfairLock() { mutex_.unlock(); } -}; - -} // namespace atom::async - -#endif diff --git a/src/atom/async/message_bus.hpp b/src/atom/async/message_bus.hpp deleted file mode 100644 index 74de9b97..00000000 --- a/src/atom/async/message_bus.hpp +++ /dev/null @@ -1,404 +0,0 @@ -/* - * message_bus.hpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2023-7-23 - -Description: Main Message Bus with Asio support and additional features - -**************************************************/ - -#ifndef ATOM_ASYNC_MESSAGE_BUS_HPP -#define ATOM_ASYNC_MESSAGE_BUS_HPP - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "atom/macro.hpp" - -namespace atom::async { - -/** - * @brief The MessageBus class provides a message bus system with Asio support. - */ -class MessageBus { -public: - using Token = std::size_t; - static constexpr std::size_t K_MAX_HISTORY_SIZE = - 100; ///< Maximum number of messages to keep in history. - - /** - * @brief Constructs a MessageBus with the given io_context. - * @param io_context The Asio io_context to use for asynchronous operations. - */ - explicit MessageBus(asio::io_context& io_context) - : io_context_(io_context) {} - - /** - * @brief Creates a shared instance of MessageBus. - * @param io_context The Asio io_context to use for asynchronous operations. - * @return A shared pointer to the created MessageBus instance. - */ - static auto createShared(asio::io_context& io_context) - -> std::shared_ptr { - return std::make_shared(io_context); - } - - /** - * @brief Publishes a message to the bus, optionally with a delay. - * @tparam MessageType The type of the message. - * @param name The name of the message. - * @param message The message to publish. - * @param delay Optional delay before publishing the message. - */ - template - void publish( - const std::string& name, const MessageType& message, - std::optional delay = std::nullopt) { - auto publishTask = [this, name, message]() { - std::shared_lock lock(mutex_); - std::unordered_set - calledSubscribers; // Track called subscribers - - // Publish to directly matching subscribers - publishToSubscribers(name, message, calledSubscribers); - - // Publish to namespace matching subscribers - for (const auto& namespaceName : namespaces_) { - if (name.find(namespaceName + ".") == - 0) { // Namespace match must start with namespaceName + dot - publishToSubscribers(namespaceName, message, - calledSubscribers); - } - } - - // Record the message in history - recordMessageHistory(name, message); - - // 记录日志 - std::cout << "[MessageBus] Published message: " << name - << std::endl; - }; - - if (delay) { - // Use Asio's steady_timer for delayed publishing - auto timer = - std::make_shared(io_context_, *delay); - timer->async_wait( - [timer, publishTask](const asio::error_code& errorCode) { - if (!errorCode) { - publishTask(); - } - }); - } else { - // Immediately publish asynchronously using asio::post - asio::post(io_context_, publishTask); - } - } - - /** - * @brief Publishes a message to all subscribers globally. - * @tparam MessageType The type of the message. - * @param message The message to publish. - */ - template - void publishGlobal(const MessageType& message) { - std::shared_lock lock(mutex_); - for (const auto& [type, subscribersMap] : subscribers_) { - for (const auto& [name, subscribersList] : subscribersMap) { - publish(name, message); - } - } - } - - /** - * @brief Subscribes to a message. - * @tparam MessageType The type of the message. - * @param name The name of the message or namespace (supports wildcard). - * @param handler The handler function to call when the message is received. - * @param async Whether to call the handler asynchronously. - * @param once Whether to unsubscribe after the first message is received. - * @param filter Optional filter function to determine whether to call the - * handler. - * @return A token representing the subscription. - */ - template - auto subscribe( - const std::string& name, - std::function handler, bool async = true, - bool once = false, - std::function filter = - [](const MessageType&) { return true; }) -> Token { - std::unique_lock lock(mutex_); - Token token = nextToken_++; - subscribers_[std::type_index(typeid(MessageType))][name].emplace_back( - Subscriber{[handler = std::move(handler)](const std::any& msg) { - handler(std::any_cast(msg)); - }, - async, once, - [filter = std::move(filter)](const std::any& msg) { - return filter( - std::any_cast(msg)); - }, - token}); - namespaces_.insert(extractNamespace(name)); // Record namespace - std::cout << "[MessageBus] Subscribed to: " << name - << " with token: " << token << std::endl; - return token; - } - - /** - * @brief Unsubscribes from a message using the given token. - * @tparam MessageType The type of the message. - * @param token The token representing the subscription. - */ - template - void unsubscribe(Token token) { - std::unique_lock lock(mutex_); - auto iterator = subscribers_.find(std::type_index(typeid(MessageType))); - if (iterator != subscribers_.end()) { - for (auto& [name, subscribersList] : iterator->second) { - removeSubscription(subscribersList, token); - } - } - std::cout << "[MessageBus] Unsubscribed token: " << token << std::endl; - } - - /** - * @brief Unsubscribes all handlers for a given message name or namespace. - * @tparam MessageType The type of the message. - * @param name The name of the message or namespace. - */ - template - void unsubscribeAll(const std::string& name) { - std::unique_lock lock(mutex_); - auto iterator = subscribers_.find(std::type_index(typeid(MessageType))); - if (iterator != subscribers_.end()) { - auto nameIterator = iterator->second.find(name); - if (nameIterator != iterator->second.end()) { - size_t count = nameIterator->second.size(); - iterator->second.erase(nameIterator); - std::cout << "[MessageBus] Unsubscribed all handlers for: " - << name << " (" << count << " subscribers)" - << std::endl; - } - } - } - - /** - * @brief Gets the number of subscribers for a given message name or - * namespace. - * @tparam MessageType The type of the message. - * @param name The name of the message or namespace. - * @return The number of subscribers. - */ - template - auto getSubscriberCount(const std::string& name) -> std::size_t { - std::shared_lock lock(mutex_); - auto iterator = subscribers_.find(std::type_index(typeid(MessageType))); - if (iterator != subscribers_.end()) { - auto nameIterator = iterator->second.find(name); - if (nameIterator != iterator->second.end()) { - return nameIterator->second.size(); - } - } - return 0; - } - - /** - * @brief Checks if there are any subscribers for a given message name or - * namespace. - * @tparam MessageType The type of the message. - * @param name The name of the message or namespace. - * @return True if there are subscribers, false otherwise. - */ - template - auto hasSubscriber(const std::string& name) -> bool { - std::shared_lock lock(mutex_); - auto iterator = subscribers_.find(std::type_index(typeid(MessageType))); - if (iterator != subscribers_.end()) { - auto nameIterator = iterator->second.find(name); - return nameIterator != iterator->second.end() && - !nameIterator->second.empty(); - } - return false; - } - - /** - * @brief Clears all subscribers. - */ - void clearAllSubscribers() { - std::unique_lock lock(mutex_); - subscribers_.clear(); - namespaces_.clear(); - std::cout << "[MessageBus] Cleared all subscribers." << std::endl; - } - - /** - * @brief Gets the list of active namespaces. - * @return A vector of active namespace names. - */ - auto getActiveNamespaces() const -> std::vector { - std::shared_lock lock(mutex_); - return {namespaces_.begin(), namespaces_.end()}; - } - - /** - * @brief Gets the message history for a given message name. - * @tparam MessageType The type of the message. - * @param name The name of the message. - * @return A vector of messages. - */ - template - auto getMessageHistory(const std::string& name, - std::size_t count = K_MAX_HISTORY_SIZE) const - -> std::vector { - std::shared_lock lock(mutex_); - auto iterator = - messageHistory_.find(std::type_index(typeid(MessageType))); - if (iterator != messageHistory_.end()) { - auto nameIterator = iterator->second.find(name); - if (nameIterator != iterator->second.end()) { - std::vector history; - std::size_t start = (nameIterator->second.size() > count) - ? nameIterator->second.size() - count - : 0; - for (std::size_t i = start; i < nameIterator->second.size(); - ++i) { - history.emplace_back( - std::any_cast(nameIterator->second[i])); - } - return history; - } - } - return {}; - } - -private: - struct Subscriber { - std::function - handler; ///< The handler function. - bool async; ///< Whether to call the handler asynchronously. - bool once; ///< Whether to unsubscribe after the first message. - std::function filter; ///< The filter function. - Token token; ///< The subscription token. - } ATOM_ALIGNAS(64); - - /** - * @brief Publishes a message to the subscribers. - * @tparam MessageType The type of the message. - * @param name The name of the message. - * @param message The message to publish. - * @param calledSubscribers The set of already called subscribers. - */ - template - void publishToSubscribers(const std::string& name, - const MessageType& message, - std::unordered_set& calledSubscribers) { - auto iterator = subscribers_.find(std::type_index(typeid(MessageType))); - if (iterator != subscribers_.end()) { - auto nameIterator = iterator->second.find(name); - if (nameIterator != iterator->second.end()) { - auto& subscribersList = nameIterator->second; - for (auto it = subscribersList.begin(); - it != subscribersList.end();) { - if (it->filter(message) && - calledSubscribers.insert(it->token).second) { - auto handler = [handlerFunc = it->handler, message]() { - std::any msg = message; - handlerFunc(msg); - }; - if (it->async) { - asio::post(io_context_, handler); - } else { - handler(); - } - if (it->once) { - it = subscribersList.erase(it); - continue; - } - } - ++it; - } - } - } - } - - /** - * @brief Removes a subscription from the list. - * @param subscribersList The list of subscribers. - * @param token The token representing the subscription. - */ - static void removeSubscription(std::vector& subscribersList, - Token token) { - subscribersList.erase( - std::remove_if( - subscribersList.begin(), subscribersList.end(), - [token](const Subscriber& sub) { return sub.token == token; }), - subscribersList.end()); - } - - /** - * @brief Records a message in the history. - * @tparam MessageType The type of the message. - * @param name The name of the message. - * @param message The message to record. - */ - template - void recordMessageHistory(const std::string& name, - const MessageType& message) { - auto& history = - messageHistory_[std::type_index(typeid(MessageType))][name]; - history.emplace_back(message); - if (history.size() > K_MAX_HISTORY_SIZE) { - history.erase(history.begin()); - } - } - - /** - * @brief Extracts the namespace from the message name. - * @param name The message name. - * @return The namespace part of the name. - */ - std::string extractNamespace(const std::string& name) const { - auto pos = name.find('.'); - if (pos != std::string::npos) { - return name.substr(0, pos); - } - return name; - } - - std::unordered_map>> - subscribers_; ///< Map of subscribers. - std::unordered_map>> - messageHistory_; ///< Map of message history. - std::unordered_set namespaces_; ///< Set of namespaces. - mutable std::shared_mutex mutex_; ///< Mutex for thread safety. - Token nextToken_ = 0; ///< Next token value. - - asio::io_context& - io_context_; ///< Asio io_context for asynchronous operations. -}; - -} // namespace atom::async - -#endif // ATOM_ASYNC_MESSAGE_BUS_HPP diff --git a/src/atom/async/message_queue.hpp b/src/atom/async/message_queue.hpp deleted file mode 100644 index 45ddfb77..00000000 --- a/src/atom/async/message_queue.hpp +++ /dev/null @@ -1,301 +0,0 @@ -/* - * message_queue.hpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -#ifndef ATOM_ASYNC_MESSAGE_QUEUE_HPP -#define ATOM_ASYNC_MESSAGE_QUEUE_HPP - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace atom::async { - -/** - * @brief A message queue that allows subscribers to receive messages of type T. - * - * @tparam T The type of messages that can be published and subscribed to. - */ -template -class MessageQueue { -public: - using CallbackType = std::function; - using FilterType = std::function; - - /** - * @brief Constructs a MessageQueue with the given io_context. - * @param ioContext The Asio io_context to use for asynchronous operations. - */ - explicit MessageQueue(asio::io_context& ioContext) - : ioContext_(ioContext) {} - - /** - * @brief Subscribe to messages with a callback and optional filter and - * timeout. - * - * @param callback The callback function to be called when a new message is - * received. - * @param subscriberName The name of the subscriber. - * @param priority The priority of the subscriber. Higher priority receives - * messages first. - * @param filter An optional filter to only receive messages that match the - * criteria. - * @param timeout The maximum time allowed for the subscriber to process a - * message. - */ - void subscribe( - CallbackType callback, const std::string& subscriberName, - int priority = 0, FilterType filter = nullptr, - std::chrono::milliseconds timeout = std::chrono::milliseconds::zero()); - - /** - * @brief Unsubscribe from messages using the given callback. - * - * @param callback The callback function used during subscription. - */ - void unsubscribe(CallbackType callback); - - /** - * @brief Publish a message to the queue, with an optional priority. - * - * @param message The message to publish. - * @param priority The priority of the message, higher priority messages are - * handled first. - */ - void publish(const T& message, int priority = 0); - - /** - * @brief Start processing messages in the queue. - */ - void startProcessing(); - - /** - * @brief Stop processing messages in the queue. - */ - void stopProcessing(); - - /** - * @brief Get the number of messages currently in the queue. - * @return The number of messages in the queue. - */ - auto getMessageCount() const -> size_t; - - /** - * @brief Get the number of subscribers currently subscribed to the queue. - * @return The number of subscribers. - */ - auto getSubscriberCount() const -> size_t; - - /** - * @brief Cancel specific messages that meet a given condition. - * - * @param cancelCondition The condition to cancel certain messages. - */ - void cancelMessages(std::function cancelCondition); - -private: - struct Subscriber { - std::string name; - CallbackType callback; - int priority; - FilterType filter; - std::chrono::milliseconds timeout; - - Subscriber(std::string name, const CallbackType& callback, int priority, - FilterType filter, std::chrono::milliseconds timeout) - : name(std::move(name)), - callback(callback), - priority(priority), - filter(filter), - timeout(timeout) {} - - auto operator<(const Subscriber& other) const -> bool { - return priority > other.priority; - } - }; - - struct Message { - T data; - int priority; - - Message(T data, int priority) - : data(std::move(data)), priority(priority) {} - - auto operator<(const Message& other) const -> bool { - return priority > other.priority; - } - }; - - std::deque m_messages_; - std::vector m_subscribers_; - mutable std::mutex m_mutex_; - std::condition_variable m_condition_; - std::atomic m_isRunning_{true}; - asio::io_context& ioContext_; - - /** - * @brief Process messages in the queue. - */ - void processMessages(); - - /** - * @brief Apply the filter to a message for a given subscriber. - * @param subscriber The subscriber to apply the filter for. - * @param message The message to filter. - * @return True if the message passes the filter, false otherwise. - */ - bool applyFilter(const Subscriber& subscriber, const T& message); - - /** - * @brief Handle the timeout for a given subscriber and message. - * @param subscriber The subscriber to handle the timeout for. - * @param message The message to process. - * @return True if the message was processed within the timeout, false - * otherwise. - */ - bool handleTimeout(const Subscriber& subscriber, const T& message); -}; - -template -void MessageQueue::subscribe(CallbackType callback, - const std::string& subscriberName, int priority, - FilterType filter, - std::chrono::milliseconds timeout) { - std::lock_guard lock(m_mutex_); - m_subscribers_.emplace_back(subscriberName, callback, priority, filter, - timeout); - std::ranges::sort(m_subscribers_, std::greater{}); -} - -template -void MessageQueue::unsubscribe(CallbackType callback) { - std::lock_guard lock(m_mutex_); - auto iterator = std::ranges::remove_if( - m_subscribers_, [&callback](const auto& subscriber) { - return subscriber.callback.target_type() == callback.target_type(); - }); - m_subscribers_.erase(iterator.begin(), iterator.end()); -} - -template -void MessageQueue::publish(const T& message, int priority) { - { - std::lock_guard lock(m_mutex_); - m_messages_.emplace_back(message, priority); - } - ioContext_.post([this]() { processMessages(); }); -} - -template -void MessageQueue::startProcessing() { - m_isRunning_.store(true); - ioContext_.run(); -} - -template -void MessageQueue::stopProcessing() { - m_isRunning_.store(false); - ioContext_.stop(); -} - -template -auto MessageQueue::getMessageCount() const -> size_t { - std::lock_guard lock(m_mutex_); - return m_messages_.size(); -} - -template -auto MessageQueue::getSubscriberCount() const -> size_t { - std::lock_guard lock(m_mutex_); - return m_subscribers_.size(); -} - -template -void MessageQueue::cancelMessages( - std::function cancelCondition) { - std::lock_guard lock(m_mutex_); - auto iterator = std::remove_if(m_messages_.begin(), m_messages_.end(), - [&cancelCondition](const auto& msg) { - return cancelCondition(msg.data); - }); - m_messages_.erase(iterator, m_messages_.end()); -} - -template -bool MessageQueue::applyFilter(const Subscriber& subscriber, - const T& message) { - if (!subscriber.filter) { - return true; - } - return subscriber.filter(message); -} - -template -bool MessageQueue::handleTimeout(const Subscriber& subscriber, - const T& message) { - if (subscriber.timeout == std::chrono::milliseconds::zero()) { - subscriber.callback(message); - return true; - } - - std::packaged_task task( - [&subscriber, &message]() { subscriber.callback(message); }); - auto future = task.get_future(); - asio::post(ioContext_, std::move(task)); - - if (future.wait_for(subscriber.timeout) == std::future_status::timeout) { - return false; // Timeout occurred. - } - - return true; // Process completed within timeout. -} - -template -void MessageQueue::processMessages() { - while (m_isRunning_.load()) { - std::optional message; - - { - std::lock_guard lock(m_mutex_); - if (m_messages_.empty()) { - return; - } - message = std::move(m_messages_.front()); - m_messages_.pop_front(); - } - - if (message) { - std::vector subscribersCopy; - - { - std::lock_guard lock(m_mutex_); - subscribersCopy.reserve(m_subscribers_.size()); - for (const auto& subscriber : m_subscribers_) { - subscribersCopy.emplace_back(subscriber); - } - } - - for (const auto& subscriber : subscribersCopy) { - if (applyFilter(subscriber, message->data)) { - handleTimeout(subscriber, message->data); - } - } - } - } -} - -} // namespace atom::async - -#endif // ATOM_ASYNC_MESSAGE_QUEUE_HPP diff --git a/src/atom/async/packaged_task.hpp b/src/atom/async/packaged_task.hpp deleted file mode 100644 index 5f378791..00000000 --- a/src/atom/async/packaged_task.hpp +++ /dev/null @@ -1,231 +0,0 @@ -#ifndef ATOM_ASYNC_PACKAGED_TASK_HPP -#define ATOM_ASYNC_PACKAGED_TASK_HPP - -#include -#include -#include -#include - -#include "atom/async/future.hpp" - -namespace atom::async { - -/** - * @class InvalidPackagedTaskException - * @brief Exception thrown when an invalid packaged task is encountered. - */ -class InvalidPackagedTaskException : public atom::error::RuntimeError { -public: - using atom::error::RuntimeError::RuntimeError; -}; - -/** - * @def THROW_INVALID_PACKAGED_TASK_EXCEPTION - * @brief Macro to throw an InvalidPackagedTaskException with file, line, and - * function information. - */ -#define THROW_INVALID_PACKAGED_TASK_EXCEPTION(...) \ - throw InvalidPackagedTaskException(ATOM_FILE_NAME, ATOM_FILE_LINE, \ - ATOM_FUNC_NAME, __VA_ARGS__); - -/** - * @def THROW_NESTED_INVALID_PACKAGED_TASK_EXCEPTION - * @brief Macro to rethrow a nested InvalidPackagedTaskException with file, - * line, and function information. - */ -#define THROW_NESTED_INVALID_PACKAGED_TASK_EXCEPTION(...) \ - InvalidPackagedTaskException::rethrowNested( \ - ATOM_FILE_NAME, ATOM_FILE_LINE, ATOM_FUNC_NAME, \ - "Invalid packaged task: " __VA_ARGS__); - -/** - * @class EnhancedPackagedTask - * @brief A template class that extends the standard packaged task with - * additional features. - * @tparam ResultType The type of the result that the task will produce. - * @tparam Args The types of the arguments that the task will accept. - */ -template -class EnhancedPackagedTask { -public: - using TaskType = std::function; - - /** - * @brief Constructs an EnhancedPackagedTask with the given task. - * @param task The task to be executed. - */ - explicit EnhancedPackagedTask(TaskType task) - : task_(std::move(task)), cancelled_(false) { - promise_ = std::promise(); - future_ = promise_.get_future().share(); - } - - /** - * @brief Gets the enhanced future associated with this task. - * @return An EnhancedFuture object. - */ - EnhancedFuture getEnhancedFuture() { - return EnhancedFuture(std::move(future_)); - } - - /** - * @brief Executes the task with the given arguments. - * @param args The arguments to pass to the task. - */ - void operator()(Args... args) { - if (cancelled_) { - promise_.set_exception(std::make_exception_ptr( - std::runtime_error("Task has been cancelled"))); - return; - } - - try { - if (task_) { - ResultType result = task_(std::forward(args)...); - promise_.set_value(result); - runCallbacks(result); - } - } catch (...) { - promise_.set_exception(std::current_exception()); - } - } - - /** - * @brief Adds a callback to be called upon task completion. - * @tparam F The type of the callback function. - * @param func The callback function to add. - */ - template - void onComplete(F &&func) { - callbacks_.emplace_back(std::forward(func)); - } - - /** - * @brief Cancels the task. - */ - void cancel() { cancelled_ = true; } - - /** - * @brief Checks if the task is cancelled. - * @return True if the task is cancelled, false otherwise. - */ - [[nodiscard]] bool isCancelled() const { return cancelled_; } - -protected: - TaskType task_; ///< The task to be executed. - std::promise - promise_; ///< The promise associated with the task. - std::shared_future - future_; ///< The shared future associated with the task. - std::vector> - callbacks_; ///< List of callbacks to be called on completion. - std::atomic - cancelled_; ///< Flag indicating if the task has been cancelled. - -private: - /** - * @brief Runs all the registered callbacks with the given result. - * @param result The result to pass to the callbacks. - */ - void runCallbacks(ResultType result) { - for (auto &callback : callbacks_) { - callback(result); - } - } -}; - -/** - * @class EnhancedPackagedTask - * @brief Specialization of the EnhancedPackagedTask class for void result type. - * @tparam Args The types of the arguments that the task will accept. - */ -template -class EnhancedPackagedTask { -public: - using TaskType = std::function; - - /** - * @brief Constructs an EnhancedPackagedTask with the given task. - * @param task The task to be executed. - */ - explicit EnhancedPackagedTask(TaskType task) - : task_(std::move(task)), cancelled_(false) { - promise_ = std::promise(); - future_ = promise_.get_future().share(); - } - - /** - * @brief Gets the enhanced future associated with this task. - * @return An EnhancedFuture object. - */ - EnhancedFuture getEnhancedFuture() { - return EnhancedFuture(std::move(future_)); - } - - /** - * @brief Executes the task with the given arguments. - * @param args The arguments to pass to the task. - */ - void operator()(Args... args) { - if (cancelled_) { - promise_.set_exception(std::make_exception_ptr( - std::runtime_error("Task has been cancelled"))); - return; - } - - try { - if (task_) { - task_(std::forward(args)...); - promise_.set_value(); - runCallbacks(); - } - } catch (...) { - promise_.set_exception(std::current_exception()); - } - } - - /** - * @brief Adds a callback to be called upon task completion. - * @tparam F The type of the callback function. - * @param func The callback function to add. - */ - template - void onComplete(F &&func) { - callbacks_.emplace_back(std::forward(func)); - } - - /** - * @brief Cancels the task. - */ - void cancel() { cancelled_ = true; } - - /** - * @brief Checks if the task is cancelled. - * @return True if the task is cancelled, false otherwise. - */ - [[nodiscard]] bool isCancelled() const { return cancelled_; } - -protected: - TaskType task_; ///< The task to be executed. - std::promise promise_; ///< The promise associated with the task. - std::shared_future - future_; ///< The shared future associated with the task. - std::vector> - callbacks_; ///< List of callbacks to be called on completion. - std::atomic - cancelled_; ///< Flag indicating if the task has been cancelled. - -private: - /** - * @brief Runs all the registered callbacks. - */ - void runCallbacks() { - for (auto &callback : callbacks_) { - callback(); - } - } -}; - -} // namespace atom::async - -#endif // ATOM_ASYNC_PACKAGED_TASK_HPP diff --git a/src/atom/async/pool.hpp b/src/atom/async/pool.hpp deleted file mode 100644 index 249be0cc..00000000 --- a/src/atom/async/pool.hpp +++ /dev/null @@ -1,394 +0,0 @@ -/* - * pool.hpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2024-2-13 - -Description: A very simple thread pool for preload - -**************************************************/ - -#ifndef ATOM_ASYNC_POOL_HPP -#define ATOM_ASYNC_POOL_HPP - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "atom/macro.hpp" -#ifdef __has_include -#if __has_include() -#include -#endif -#endif - -namespace atom::async { -/** - * @brief Simple concept for the Lockable and Basic Lockable types as defined by - * the C++ standard. - * @details See https://en.cppreference.com/w/cpp/named_req/Lockable and - * https://en.cppreference.com/w/cpp/named_req/BasicLockable for details. - */ -template -concept is_lockable = requires(Lock&& lock) { - lock.lock(); - lock.unlock(); - { lock.try_lock() } -> std::convertible_to; -}; - -template - requires is_lockable -class ThreadSafeQueue { -public: - using value_type = T; - using size_type = typename std::deque::size_type; - - ThreadSafeQueue() = default; - - // Copy constructor - ThreadSafeQueue(const ThreadSafeQueue& other) { - std::scoped_lock lock(other.mutex_); - data_ = other.data_; - } - - // Copy assignment operator - auto operator=(const ThreadSafeQueue& other) -> ThreadSafeQueue& { - if (this != &other) { - std::scoped_lock lockThis(mutex_, std::defer_lock); - std::scoped_lock lockOther(other.mutex_, std::defer_lock); - std::lock(lockThis, lockOther); - data_ = other.data_; - } - return *this; - } - - // Move constructor - ThreadSafeQueue(ThreadSafeQueue&& other) noexcept { - std::scoped_lock lock(other.mutex_); - data_ = std::move(other.data_); - } - - // Move assignment operator - auto operator=(ThreadSafeQueue&& other) noexcept -> ThreadSafeQueue& { - if (this != &other) { - std::scoped_lock lockThis(mutex_, std::defer_lock); - std::scoped_lock lockOther(other.mutex_, std::defer_lock); - std::lock(lockThis, lockOther); - data_ = std::move(other.data_); - } - return *this; - } - - void pushBack(T&& value) { - std::scoped_lock lock(mutex_); - data_.push_back(std::forward(value)); - } - - void pushFront(T&& value) { - std::scoped_lock lock(mutex_); - data_.push_front(std::forward(value)); - } - - [[nodiscard]] auto empty() const -> bool { - std::scoped_lock lock(mutex_); - return data_.empty(); - } - - [[nodiscard]] auto size() const -> size_type { - std::scoped_lock lock(mutex_); - return data_.size(); - } - - [[nodiscard]] auto popFront() -> std::optional { - std::scoped_lock lock(mutex_); - if (data_.empty()) { - return std::nullopt; - } - - auto front = std::move(data_.front()); - data_.pop_front(); - return front; - } - - [[nodiscard]] auto popBack() -> std::optional { - std::scoped_lock lock(mutex_); - if (data_.empty()) { - return std::nullopt; - } - - auto back = std::move(data_.back()); - data_.pop_back(); - return back; - } - - [[nodiscard]] auto steal() -> std::optional { - std::scoped_lock lock(mutex_); - if (data_.empty()) { - return std::nullopt; - } - - auto back = std::move(data_.back()); - data_.pop_back(); - return back; - } - - void rotateToFront(const T& item) { - std::scoped_lock lock(mutex_); - auto iter = std::find(data_.begin(), data_.end(), item); - - if (iter != data_.end()) { - std::ignore = data_.erase(iter); - } - - data_.push_front(item); - } - - [[nodiscard]] auto copyFrontAndRotateToBack() -> std::optional { - std::scoped_lock lock(mutex_); - - if (data_.empty()) { - return std::nullopt; - } - - auto front = data_.front(); - data_.pop_front(); - - data_.push_back(front); - - return front; - } - - void clear() { - std::scoped_lock lock(mutex_); - data_.clear(); - } - -private: - std::deque data_; - mutable Lock mutex_; -}; - -namespace details { -#ifdef __cpp_lib_move_only_function -using default_function_type = std::move_only_function; -#else -using default_function_type = std::function; -#endif -} // namespace details - -template - requires std::invocable && - std::is_same_v> -class ThreadPool { -public: - template < - typename InitializationFunction = std::function> - requires std::invocable && - std::is_same_v> - explicit ThreadPool( - const unsigned int& number_of_threads = - std::thread::hardware_concurrency(), - InitializationFunction init = [](std::size_t) {}) - : tasks_(number_of_threads) { - std::size_t currentId = 0; - for (std::size_t i = 0; i < number_of_threads; ++i) { - priority_queue_.pushBack(std::move(currentId)); - try { - threads_.emplace_back([&, threadId = currentId, - init](const std::stop_token& stop_tok) { - try { - std::invoke(init, threadId); - } catch (...) { - } - - do { - tasks_[threadId].signal.acquire(); - - do { - while (auto task = - tasks_[threadId].tasks.popFront()) { - unassigned_tasks_.fetch_sub( - 1, std::memory_order_release); - std::invoke(std::move(task.value())); - in_flight_tasks_.fetch_sub( - 1, std::memory_order_release); - } - - for (std::size_t j = 1; j < tasks_.size(); ++j) { - const std::size_t INDEX = - (threadId + j) % tasks_.size(); - if (auto task = tasks_[INDEX].tasks.steal()) { - unassigned_tasks_.fetch_sub( - 1, std::memory_order_release); - std::invoke(std::move(task.value())); - in_flight_tasks_.fetch_sub( - 1, std::memory_order_release); - break; - } - } - } while (unassigned_tasks_.load( - std::memory_order_acquire) > 0); - - priority_queue_.rotateToFront(threadId); - - if (in_flight_tasks_.load(std::memory_order_acquire) == - 0) { - threads_complete_signal_.store( - true, std::memory_order_release); - threads_complete_signal_.notify_one(); - } - - } while (!stop_tok.stop_requested()); - }); - ++currentId; - - } catch (...) { - tasks_.pop_back(); - std::ignore = priority_queue_.popBack(); - } - } - } - - ~ThreadPool() { - waitForTasks(); - - for (auto& thread : threads_) { - thread.request_stop(); - } - - for (auto& task : tasks_) { - task.signal.release(); - } - - for (auto& thread : threads_) { - thread.join(); - } - } - - ThreadPool(const ThreadPool&) = delete; - auto operator=(const ThreadPool&) -> ThreadPool& = delete; - - // Define move constructor and move assignment operator - ThreadPool(ThreadPool&& other) noexcept = default; - auto operator=(ThreadPool&& other) noexcept -> ThreadPool& = default; - - template > - requires std::invocable - [[nodiscard]] auto enqueue(Function func, - Args... args) -> std::future { -#ifdef __cpp_lib_move_only_function - std::promise promise; - auto future = promise.get_future(); - auto task = [func = std::move(func), ... largs = std::move(args), - promise = std::move(promise)]() mutable { - try { - if constexpr (std::is_same_v) { - func(largs...); - promise.set_value(); - } else { - promise.set_value(func(largs...)); - } - } catch (...) { - promise.set_exception(std::current_exception()); - } - }; - enqueueTask(std::move(task)); - return future; -#else - auto shared_promise = std::make_shared>(); - auto task = [func = std::move(func), ... largs = std::move(args), - promise = shared_promise]() { - try { - if constexpr (std::is_same_v) { - func(largs...); - promise->set_value(); - } else { - promise->set_value(func(largs...)); - } - } catch (...) { - promise->set_exception(std::current_exception()); - } - }; - - auto future = shared_promise->get_future(); - enqueue_task(std::move(task)); - return future; -#endif - } - - template - requires std::invocable - void enqueueDetach(Function&& func, Args&&... args) { - enqueueTask([func = std::forward(func), - ... largs = std::forward(args)]() mutable { - try { - if constexpr (std::is_same_v>) { - std::invoke(func, largs...); - } else { - std::ignore = std::invoke(func, largs...); - } - } catch (...) { - } - }); - } - - [[nodiscard]] auto size() const -> std::size_t { return threads_.size(); } - - void waitForTasks() { - if (in_flight_tasks_.load(std::memory_order_acquire) > 0) { - threads_complete_signal_.wait(false); - } - } - -private: - template - void enqueueTask(Function&& func) { - auto iOpt = priority_queue_.copyFrontAndRotateToBack(); - if (!iOpt.has_value()) { - return; - } - auto index = *(iOpt); - - unassigned_tasks_.fetch_add(1, std::memory_order_release); - const auto PREV_IN_FLIGHT = - in_flight_tasks_.fetch_add(1, std::memory_order_release); - - if (PREV_IN_FLIGHT == 0) { - threads_complete_signal_.store(false, std::memory_order_release); - } - - tasks_[index].tasks.pushBack(std::forward(func)); - tasks_[index].signal.release(); - } - - struct TaskItem { - atom::async::ThreadSafeQueue tasks{}; - std::binary_semaphore signal{0}; - } ATOM_ALIGNAS(128); - - std::vector threads_; - std::deque tasks_; - atom::async::ThreadSafeQueue priority_queue_; - std::atomic_int_fast64_t unassigned_tasks_{0}, in_flight_tasks_{0}; - std::atomic_bool threads_complete_signal_{false}; -}; -} // namespace atom::async - -#endif // ATOM_ASYNC_POOL_HPP diff --git a/src/atom/async/promise.hpp b/src/atom/async/promise.hpp deleted file mode 100644 index 044bcfe7..00000000 --- a/src/atom/async/promise.hpp +++ /dev/null @@ -1,332 +0,0 @@ -#ifndef ATOM_ASYNC_PROMISE_HPP -#define ATOM_ASYNC_PROMISE_HPP - -#include -#include -#include -#include -#include - -#include "atom/async/future.hpp" - -namespace atom::async { - -/** - * @class PromiseCancelledException - * @brief Exception thrown when a promise is cancelled. - */ -class PromiseCancelledException : public atom::error::RuntimeError { -public: - using atom::error::RuntimeError::RuntimeError; -}; - -/** - * @def THROW_PROMISE_CANCELLED_EXCEPTION - * @brief Macro to throw a PromiseCancelledException with file, line, and - * function information. - */ -#define THROW_PROMISE_CANCELLED_EXCEPTION(...) \ - throw PromiseCancelledException(ATOM_FILE_NAME, ATOM_FILE_LINE, \ - ATOM_FUNC_NAME, __VA_ARGS__); - -/** - * @def THROW_NESTED_PROMISE_CANCELLED_EXCEPTION - * @brief Macro to rethrow a nested PromiseCancelledException with file, line, - * and function information. - */ -#define THROW_NESTED_PROMISE_CANCELLED_EXCEPTION(...) \ - PromiseCancelledException::rethrowNested( \ - ATOM_FILE_NAME, ATOM_FILE_LINE, ATOM_FUNC_NAME, \ - "Promise cancelled: " __VA_ARGS__); - -/** - * @class EnhancedPromise - * @brief A template class that extends the standard promise with additional - * features. - * @tparam T The type of the value that the promise will hold. - */ -template -class EnhancedPromise { -public: - /** - * @brief Constructor that initializes the promise and shared future. - */ - EnhancedPromise(); - - /** - * @brief Gets the enhanced future associated with this promise. - * @return An EnhancedFuture object. - */ - auto getEnhancedFuture() -> EnhancedFuture; - - /** - * @brief Sets the value of the promise. - * @param value The value to set. - * @throws PromiseCancelledException if the promise has been cancelled. - */ - void setValue(T value); - - /** - * @brief Sets an exception for the promise. - * @param exception The exception to set. - * @throws PromiseCancelledException if the promise has been cancelled. - */ - void setException(std::exception_ptr exception); - - /** - * @brief Adds a callback to be called when the promise is completed. - * @tparam F The type of the callback function. - * @param func The callback function to add. - */ - template - void onComplete(F &&func); - - /** - * @brief Cancels the promise. - */ - void cancel(); - - /** - * @brief Checks if the promise has been cancelled. - * @return True if the promise has been cancelled, false otherwise. - */ - [[nodiscard]] auto isCancelled() const -> bool; - - /** - * @brief Gets the shared future associated with this promise. - * @return A shared future object. - */ - auto getFuture() -> std::shared_future; - -private: - /** - * @brief Runs all the registered callbacks. - */ - void runCallbacks(); - - std::promise promise_; ///< The underlying promise object. - std::shared_future - future_; ///< The shared future associated with the promise. - std::vector> - callbacks_; ///< List of callbacks to be called on completion. - std::atomic - cancelled_; ///< Flag indicating if the promise has been cancelled. -}; - -/** - * @class EnhancedPromise - * @brief Specialization of the EnhancedPromise class for void type. - */ -template <> -class EnhancedPromise { -public: - /** - * @brief Constructor that initializes the promise and shared future. - */ - EnhancedPromise(); - - /** - * @brief Gets the enhanced future associated with this promise. - * @return An EnhancedFuture object. - */ - auto getEnhancedFuture() -> EnhancedFuture; - - /** - * @brief Sets the value of the promise. - * @throws PromiseCancelledException if the promise has been cancelled. - */ - void setValue(); - - /** - * @brief Sets an exception for the promise. - * @param exception The exception to set. - * @throws PromiseCancelledException if the promise has been cancelled. - */ - void setException(std::exception_ptr exception); - - /** - * @brief Adds a callback to be called when the promise is completed. - * @tparam F The type of the callback function. - * @param func The callback function to add. - */ - template - void onComplete(F &&func); - - /** - * @brief Cancels the promise. - */ - void cancel(); - - /** - * @brief Checks if the promise has been cancelled. - * @return True if the promise has been cancelled, false otherwise. - */ - [[nodiscard]] auto isCancelled() const -> bool; - - /** - * @brief Gets the shared future associated with this promise. - * @return A shared future object. - */ - auto getFuture() -> std::shared_future { return future_; } - -private: - /** - * @brief Runs all the registered callbacks. - */ - void runCallbacks(); - - std::promise promise_; ///< The underlying promise object. - std::shared_future - future_; ///< The shared future associated with the promise. - std::vector> - callbacks_; ///< List of callbacks to be called on completion. - std::atomic - cancelled_; ///< Flag indicating if the promise has been cancelled. -}; - -template -EnhancedPromise::EnhancedPromise() - : future_(promise_.get_future().share()), cancelled_(false) {} - -template -auto EnhancedPromise::getEnhancedFuture() -> EnhancedFuture { - return EnhancedFuture(future_); -} - -template -void EnhancedPromise::setValue(T value) { - if (isCancelled()) { - THROW_PROMISE_CANCELLED_EXCEPTION( - "Cannot set value, promise was cancelled."); - } - promise_.set_value(value); - runCallbacks(); // Execute callbacks -} - -template -void EnhancedPromise::setException(std::exception_ptr exception) { - if (isCancelled()) { - THROW_PROMISE_CANCELLED_EXCEPTION( - "Cannot set exception, promise was cancelled."); - } - promise_.set_exception(exception); - runCallbacks(); // Execute callbacks -} - -template -template -void EnhancedPromise::onComplete(F &&func) { - if (isCancelled()) { - return; // No callbacks should be added if the promise is cancelled - } - callbacks_.emplace_back(std::forward(func)); - - // If the promise is already set, run the callback immediately - if (future_.valid() && future_.wait_for(std::chrono::seconds(0)) == - std::future_status::ready) { - runCallbacks(); - } -} - -template -void EnhancedPromise::cancel() { - cancelled_ = true; -} - -template -auto EnhancedPromise::isCancelled() const -> bool { - return cancelled_.load(); -} - -template -auto EnhancedPromise::getFuture() -> std::shared_future { - return future_; -} - -template -void EnhancedPromise::runCallbacks() { - if (isCancelled()) { - return; - } - if (future_.valid() && future_.wait_for(std::chrono::seconds(0)) == - std::future_status::ready) { - try { - T value = - future_.get(); // Get the value and pass it to the callbacks - for (auto &callback : callbacks_) { - callback(value); - } - } catch (...) { - // Handle the case where the future contains an exception. - // We don't invoke callbacks in this case. - } - } -} - -EnhancedPromise::EnhancedPromise() - : future_(promise_.get_future().share()), cancelled_(false) {} - -auto EnhancedPromise::getEnhancedFuture() -> EnhancedFuture { - return EnhancedFuture(future_); -} - -void EnhancedPromise::setValue() { - if (isCancelled()) { - THROW_PROMISE_CANCELLED_EXCEPTION( - "Cannot set value, promise was cancelled."); - } - promise_.set_value(); - runCallbacks(); // Execute callbacks -} - -void EnhancedPromise::setException(std::exception_ptr exception) { - if (isCancelled()) { - THROW_PROMISE_CANCELLED_EXCEPTION( - "Cannot set exception, promise was cancelled."); - } - promise_.set_exception(exception); - runCallbacks(); // Execute callbacks -} - -template -void EnhancedPromise::onComplete(F &&func) { - if (isCancelled()) { - return; // No callbacks should be added if the promise is cancelled - } - callbacks_.emplace_back(std::forward(func)); - - // If the promise is already set, run the callback immediately - if (future_.valid() && future_.wait_for(std::chrono::seconds(0)) == - std::future_status::ready) { - runCallbacks(); - } -} - -void EnhancedPromise::cancel() { cancelled_ = true; } - -auto EnhancedPromise::isCancelled() const -> bool { - return cancelled_.load(); -} - -void EnhancedPromise::runCallbacks() { - if (isCancelled()) { - return; - } - if (future_.valid() && future_.wait_for(std::chrono::seconds(0)) == - std::future_status::ready) { - try { - future_.get(); // Get the value and execute callbacks (for void, - // there's no value to pass) - for (auto &callback : callbacks_) { - callback(); - } - } catch (...) { - // Handle the case where the future contains an exception. - // We don't invoke callbacks in this case. - } - } -} - -} // namespace atom::async - -#endif // ATOM_ASYNC_PROMISE_HPP diff --git a/src/atom/async/queue.hpp b/src/atom/async/queue.hpp deleted file mode 100644 index 247ffc37..00000000 --- a/src/atom/async/queue.hpp +++ /dev/null @@ -1,317 +0,0 @@ -/* - * queue.hpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2024-2-13 - -Description: A simple thread safe queue - -**************************************************/ - -#ifndef ATOM_ASYNC_QUEUE_HPP -#define ATOM_ASYNC_QUEUE_HPP - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace atom::async { -template -class ThreadSafeQueue { -public: - ThreadSafeQueue() = default; - - void put(T element) { - { - std::lock_guard lock(m_mutex_); - m_queue_.push(std::move(element)); - } - m_conditionVariable_.notify_one(); - } - - auto take() -> std::optional { - std::unique_lock lock(m_mutex_); - m_conditionVariable_.wait( - lock, [this] { return m_mustReturnNullptr_ || !m_queue_.empty(); }); - - if (m_mustReturnNullptr_) { - return std::nullopt; - } - - T ret = std::move(m_queue_.front()); - m_queue_.pop(); - - return ret; - } - - auto destroy() -> std::queue { - { - std::lock_guard lock(m_mutex_); - m_mustReturnNullptr_ = true; - } - m_conditionVariable_.notify_all(); - - std::queue result; - { - std::lock_guard lock(m_mutex_); - std::swap(result, m_queue_); - } - return result; - } - - [[nodiscard]] auto size() const -> size_t { - std::lock_guard lock(m_mutex_); - return m_queue_.size(); - } - - [[nodiscard]] auto empty() const -> bool { - std::lock_guard lock(m_mutex_); - return m_queue_.empty(); - } - - void clear() { - std::lock_guard lock(m_mutex_); - std::queue empty; - std::swap(m_queue_, empty); - } - - auto front() -> std::optional { - std::lock_guard lock(m_mutex_); - if (m_queue_.empty()) { - return std::nullopt; - } - return m_queue_.front(); - } - - auto back() -> std::optional { - std::lock_guard lock(m_mutex_); - if (m_queue_.empty()) { - return std::nullopt; - } - return m_queue_.back(); - } - - template - void emplace(Args&&... args) { - { - std::lock_guard lock(m_mutex_); - m_queue_.emplace(std::forward(args)...); - } - m_conditionVariable_.notify_one(); - } - - template Predicate> - auto waitFor(Predicate predicate) -> std::optional { - std::unique_lock lock(m_mutex_); - m_conditionVariable_.wait(lock, [this, &predicate] { - return m_mustReturnNullptr_ || - (!m_queue_.empty() && predicate(m_queue_.front())); - }); - - if (m_mustReturnNullptr_) - return std::nullopt; - - T ret = std::move(m_queue_.front()); - m_queue_.pop(); - - return ret; - } - - void waitUntilEmpty() { - std::unique_lock lock(m_mutex_); - m_conditionVariable_.wait( - lock, [this] { return m_mustReturnNullptr_ || m_queue_.empty(); }); - } - - template UnaryPredicate> - auto extractIf(UnaryPredicate pred) -> std::vector { - std::vector result; - { - std::lock_guard lock(m_mutex_); - std::queue remaining; - while (!m_queue_.empty()) { - T& item = m_queue_.front(); - if (pred(item)) { - result.push_back(std::move(item)); - } else { - remaining.push(std::move(item)); - } - m_queue_.pop(); - } - std::swap(m_queue_, remaining); - } - return result; - } - - template - requires std::is_invocable_r_v - void sort(Compare comp) { - std::lock_guard lock(m_mutex_); - std::vector temp; - temp.reserve(m_queue_.size()); - while (!m_queue_.empty()) { - temp.push_back(std::move(m_queue_.front())); - m_queue_.pop(); - } - std::sort(temp.begin(), temp.end(), comp); - for (auto& elem : temp) { - m_queue_.push(std::move(elem)); - } - } - - template - auto transform(std::function func) - -> std::shared_ptr> { - std::shared_ptr> resultQueue; - { - std::lock_guard lock(m_mutex_); - std::vector original; - original.reserve(m_queue_.size()); - - while (!m_queue_.empty()) { - original.push_back(std::move(m_queue_.front())); - m_queue_.pop(); - } - - std::vector transformed(original.size()); - std::transform(original.begin(), original.end(), - transformed.begin(), func); - - for (auto& item : transformed) { - resultQueue->put(std::move(item)); - } - } - return resultQueue; - } - - template - auto groupBy(std::function func) - -> std::vector>> { - std::unordered_map>> - resultMap; - { - std::lock_guard lock(m_mutex_); - while (!m_queue_.empty()) { - T item = std::move(m_queue_.front()); - m_queue_.pop(); - GroupKey key = func(item); - if (!resultMap.contains(key)) { - resultMap[key] = std::make_shared>(); - } - resultMap[key]->put(std::move(item)); - } - } - - std::vector>> resultQueues; - resultQueues.reserve(resultMap.size()); - for (auto& [_, queue_ptr] : resultMap) { - resultQueues.push_back(queue_ptr); - } - - return resultQueues; - } - - auto toVector() const -> std::vector { - std::lock_guard lock(m_mutex_); - return std::vector(m_queue_.front(), m_queue_.back()); - } - - template - requires std::is_invocable_r_v - void forEach(Func func, bool parallel = false) { - std::lock_guard lock(m_mutex_); - if (parallel) { - std::vector vec; - vec.reserve(m_queue_.size()); - while (!m_queue_.empty()) { - vec.push_back(std::move(m_queue_.front())); - m_queue_.pop(); - } - -#pragma omp parallel for - for (size_t i = 0; i < vec.size(); ++i) { - func(vec[i]); - } - - for (auto& item : vec) { - m_queue_.push(std::move(item)); - } - } else { - std::queue tempQueue; - while (!m_queue_.empty()) { - T& item = m_queue_.front(); - func(item); - tempQueue.push(std::move(item)); - m_queue_.pop(); - } - m_queue_ = std::move(tempQueue); - } - } - - auto tryTake() -> std::optional { - std::lock_guard lock(m_mutex_); - if (m_queue_.empty()) { - return std::nullopt; - } - T ret = std::move(m_queue_.front()); - m_queue_.pop(); - return ret; - } - - template - auto takeFor(const std::chrono::duration& timeout) - -> std::optional { - std::unique_lock lock(m_mutex_); - if (m_conditionVariable_.wait_for(lock, timeout, [this] { - return !m_queue_.empty() || m_mustReturnNullptr_; - })) { - if (m_mustReturnNullptr_) { - return std::nullopt; - } - T ret = std::move(m_queue_.front()); - m_queue_.pop(); - return ret; - } - return std::nullopt; - } - - template - auto takeUntil(const std::chrono::time_point& timeout_time) - -> std::optional { - std::unique_lock lock(m_mutex_); - if (m_conditionVariable_.wait_until(lock, timeout_time, [this] { - return !m_queue_.empty() || m_mustReturnNullptr_; - })) { - if (m_mustReturnNullptr_) { - return std::nullopt; - } - T ret = std::move(m_queue_.front()); - m_queue_.pop(); - return ret; - } - return std::nullopt; - } - -private: - std::queue m_queue_; - mutable std::mutex m_mutex_; - std::condition_variable m_conditionVariable_; - std::atomic m_mustReturnNullptr_{false}; -}; - -} // namespace atom::async - -#endif // ATOM_ASYNC_QUEUE_HPP diff --git a/src/atom/async/safetype.hpp b/src/atom/async/safetype.hpp deleted file mode 100644 index 28e0199f..00000000 --- a/src/atom/async/safetype.hpp +++ /dev/null @@ -1,690 +0,0 @@ -#ifndef ATOM_ASYNC_SAFETYPE_HPP -#define ATOM_ASYNC_SAFETYPE_HPP - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "atom/error/exception.hpp" - -namespace atom::async { -/** - * @brief A lock-free stack implementation suitable for concurrent use. - * - * @tparam T Type of elements stored in the stack. - */ -template -class LockFreeStack { -private: - struct Node { - T value; ///< The stored value of type T. - std::atomic next = - nullptr; ///< Pointer to the next node in the stack. - - /** - * @brief Construct a new Node object. - * - * @param value_ The value to store in the node. - */ - explicit Node(T value_); - }; - - std::atomic head_; ///< Atomic pointer to the top of the stack. - std::atomic approximateSize_ = - 0; ///< An approximate count of the stack's elements. - -public: - /** - * @brief Construct a new Lock Free Stack object. - */ - LockFreeStack(); - - /** - * @brief Destroy the Lock Free Stack object. - */ - ~LockFreeStack(); - - /** - * @brief Pushes a value onto the stack. Thread-safe. - * - * @param value The value to push onto the stack. - */ - void push(const T& value); - - /** - * @brief Pushes a value onto the stack using move semantics. Thread-safe. - * - * @param value The value to move onto the stack. - */ - void push(T&& value); - - /** - * @brief Attempts to pop the top value off the stack. Thread-safe. - * - * @return std::optional The popped value if stack is not empty, - * otherwise nullopt. - */ - auto pop() -> std::optional; - - /** - * @brief Get the top value of the stack without removing it. Thread-safe. - * - * @return std::optional The top value if stack is not empty, otherwise - * nullopt. - */ - auto top() const -> std::optional; - - /** - * @brief Check if the stack is empty. Thread-safe. - * - * @return true If the stack is empty. - * @return false If the stack has one or more elements. - */ - [[nodiscard]] auto empty() const -> bool; - - /** - * @brief Get the approximate size of the stack. Thread-safe. - * - * @return int The approximate number of elements in the stack. - */ - [[nodiscard]] auto size() const -> int; -}; - -template -class LockFreeHashTable { -private: - struct Node { - Key key; - Value value; - std::atomic next; - - Node(Key k, Value v) : key(k), value(v), next(nullptr) {} - }; - - struct Bucket { - std::atomic head; - - Bucket() : head(nullptr) {} - - ~Bucket() { - Node* node = head.load(); - while (node) { - Node* next = node->next.load(); - delete node; - node = next; - } - } - - auto find(const Key& key) const -> std::optional { - Node* node = head.load(); - while (node) { - if (node->key == key) { - return node->value; - } - node = node->next.load(); - } - return std::nullopt; - } - - void insert(const Key& key, const Value& value) { - Node* newNode = new Node(key, value); - newNode->next = head.load(); - Node* expected = newNode->next.load(); - while (!head.compare_exchange_weak(expected, newNode)) { - newNode->next = expected; - } - } - - void erase(const Key& key) { - Node* node = head.load(); - Node* prev = nullptr; - while (node) { - if (node->key == key) { - Node* next = node->next.load(); - if (prev) { - prev->next.compare_exchange_strong(node, next); - } else { - head.compare_exchange_strong(node, next); - } - delete node; - return; - } - prev = node; - node = node->next.load(); - } - } - }; - - std::vector> buckets_; - std::hash hasher_; - - auto getBucket(const Key& key) const -> Bucket& { - auto bucketIndex = hasher_(key) % buckets_.size(); - return *buckets_[bucketIndex]; - } - -public: - explicit LockFreeHashTable(size_t num_buckets = 16) - : buckets_(num_buckets) { - for (size_t i = 0; i < num_buckets; ++i) { - buckets_[i] = std::make_unique(); - } - } - - auto find(const Key& key) const -> std::optional { - return getBucket(key).find(key); - } - - void insert(const Key& key, const Value& value) { - getBucket(key).insert(key, value); - } - - void erase(const Key& key) { getBucket(key).erase(key); } - - [[nodiscard]] auto empty() const -> bool { - for (const auto& bucket : buckets_) { - if (bucket->head.load() != nullptr) { - return false; - } - } - return true; - } - - [[nodiscard]] auto size() const -> size_t { - size_t totalSize = 0; - for (const auto& bucket : buckets_) { - Node* node = bucket->head.load(); - while (node) { - ++totalSize; - node = node->next.load(); - } - } - return totalSize; - } - - void clear() { - for (const auto& bucket : buckets_) { - Node* node = bucket->head.load(); - while (node) { - Node* next = node->next.load(); - delete node; - node = next; - } - bucket->head.store(nullptr); - } - } - - auto operator[](const Key& key) -> Value& { - auto& bucket = getBucket(key); - auto value = bucket.find(key); - if (value) { - return *value; - } - insert(key, Value()); - return *find(key); - } - - // 迭代器类 - class Iterator { - public: - using iterator_category = std::forward_iterator_tag; - using value_type = std::pair; - using difference_type = std::ptrdiff_t; - using pointer = value_type*; - using reference = value_type&; - - Iterator( - typename std::vector>::iterator bucket_iter, - typename std::vector>::iterator bucket_end, - Node* node) - : bucket_iter_(bucket_iter), bucket_end_(bucket_end), node_(node) { - advancePastEmptyBuckets(); - } - - auto operator++() -> Iterator& { - if (node_) { - node_ = node_->next.load(); - if (!node_) { - ++bucket_iter_; - advancePastEmptyBuckets(); - } - } - return *this; - } - - auto operator++(int) -> Iterator { - Iterator tmp = *this; - ++(*this); - return tmp; - } - - auto operator==(const Iterator& other) const -> bool { - return bucket_iter_ == other.bucket_iter_ && node_ == other.node_; - } - - auto operator!=(const Iterator& other) const -> bool { - return !(*this == other); - } - - auto operator*() const -> reference { - return *reinterpret_cast(node_); - } - - auto operator->() const -> pointer { - return reinterpret_cast(node_); - } - - private: - void advancePastEmptyBuckets() { - while (bucket_iter_ != bucket_end_ && !node_) { - node_ = (*bucket_iter_)->head.load(); - if (!node_) { - ++bucket_iter_; - } - } - } - - typename std::vector>::iterator bucket_iter_; - typename std::vector>::iterator bucket_end_; - Node* node_; - }; - - auto begin() -> Iterator { - auto bucketIter = buckets_.begin(); - auto bucketEnd = buckets_.end(); - Node* node = - bucketIter != bucketEnd ? (*bucketIter)->head.load() : nullptr; - return Iterator(bucketIter, bucketEnd, node); - } - - auto end() -> Iterator { - return Iterator(buckets_.end(), buckets_.end(), nullptr); - } -}; - -template -class ThreadSafeVector { - std::atomic data_; - std::atomic capacity_; - std::atomic size_; - mutable std::shared_mutex resize_mutex_; - - void resize() { - std::unique_lock lock(resize_mutex_); - - size_t oldCapacity = capacity_.load(std::memory_order_relaxed); - size_t newCapacity = oldCapacity * 2; - T* newData = new T[newCapacity]; - - for (size_t i = 0; i < size_.load(std::memory_order_relaxed); ++i) { - newData[i] = std::move(data_.load(std::memory_order_relaxed)[i]); - } - - T* oldData = data_.exchange(newData, std::memory_order_acq_rel); - capacity_.store(newCapacity, std::memory_order_release); - - delete[] oldData; - } - -public: - explicit ThreadSafeVector(size_t initial_capacity = 16) - : data_(new T[initial_capacity]), - capacity_(initial_capacity), - size_(0) {} - - ~ThreadSafeVector() { delete[] data_.load(std::memory_order_relaxed); } - - void pushBack(const T& value) { - size_t currentSize = size_.load(std::memory_order_relaxed); - while (true) { - if (currentSize < capacity_.load(std::memory_order_relaxed)) { - if (size_.compare_exchange_weak(currentSize, currentSize + 1, - std::memory_order_acq_rel)) { - data_.load(std::memory_order_relaxed)[currentSize] = value; - return; - } - } else { - resize(); - } - currentSize = size_.load(std::memory_order_relaxed); - } - } - - void pushBack(T&& value) { - size_t currentSize = size_.load(std::memory_order_relaxed); - while (true) { - if (currentSize < capacity_.load(std::memory_order_relaxed)) { - if (size_.compare_exchange_weak(currentSize, currentSize + 1, - std::memory_order_acq_rel)) { - data_.load(std::memory_order_relaxed)[currentSize] = - std::move(value); - return; - } - } else { - resize(); - } - currentSize = size_.load(std::memory_order_relaxed); - } - } - - auto popBack() -> std::optional { - size_t currentSize = size_.load(std::memory_order_relaxed); - while (currentSize > 0) { - if (size_.compare_exchange_weak(currentSize, currentSize - 1, - std::memory_order_acq_rel)) { - return data_.load(std::memory_order_relaxed)[currentSize - 1]; - } - currentSize = size_.load(std::memory_order_relaxed); - } - return std::nullopt; - } - - auto at(size_t index) const -> std::optional { - if (index >= size_.load(std::memory_order_relaxed)) { - return std::nullopt; - } - return data_.load(std::memory_order_relaxed)[index]; - } - - auto empty() const -> bool { - return size_.load(std::memory_order_relaxed) == 0; - } - - auto getSize() const -> size_t { - return size_.load(std::memory_order_relaxed); - } - - auto getCapacity() const -> size_t { - return capacity_.load(std::memory_order_relaxed); - } - - void clear() { size_.store(0, std::memory_order_relaxed); } - - void shrinkToFit() { - std::unique_lock lock(resize_mutex_); - - size_t currentSize = size_.load(std::memory_order_relaxed); - T* newData = new T[currentSize]; - - for (size_t i = 0; i < currentSize; ++i) { - newData[i] = std::move(data_.load(std::memory_order_relaxed)[i]); - } - - T* oldData = data_.exchange(newData, std::memory_order_acq_rel); - capacity_.store(currentSize, std::memory_order_release); - - delete[] oldData; - } - - auto front() const -> T { - if (empty()) { - THROW_OUT_OF_RANGE("Vector is empty"); - } - return data_.load(std::memory_order_relaxed)[0]; - } - - auto back() const -> T { - if (empty()) { - THROW_OUT_OF_RANGE("Vector is empty"); - } - return data_.load( - std::memory_order_relaxed)[size_.load(std::memory_order_relaxed) - - 1]; - } - - auto operator[](size_t index) const -> T { - if (index >= size_.load(std::memory_order_relaxed)) { - THROW_OUT_OF_RANGE("Index out of range"); - } - return data_.load(std::memory_order_relaxed)[index]; - } -}; - -template -class LockFreeList { -private: - struct Node { - std::shared_ptr value; - std::atomic next; - explicit Node(T val) : value(std::make_shared(val)), next(nullptr) {} - }; - - std::atomic head_; - - // Hazard pointers structure - struct HazardPointer { - std::atomic id; - std::atomic pointer; - }; - - static const int MAX_HAZARD_POINTERS = 100; - HazardPointer hazard_pointers_[MAX_HAZARD_POINTERS]; - - // Get hazard pointer for current thread - auto getHazardPointerForCurrentThread() -> std::atomic& { - std::thread::id thisId = std::this_thread::get_id(); - for (auto& hazardPointer : hazard_pointers_) { - std::thread::id oldId; - if (hazardPointer.id.compare_exchange_strong(oldId, thisId)) { - return hazardPointer.pointer; - } - if (hazardPointer.id == thisId) { - return hazardPointer.pointer; - } - } - THROW_RUNTIME_ERROR("No hazard pointers available"); - } - - // Reclaim list - void reclaimLater(Node* node) { - retired_nodes_.push_back(node); - if (retired_nodes_.size() >= MAX_HAZARD_POINTERS) { - doReclamation(); - } - } - - // Reclaim retired nodes - void doReclamation() { - std::vector toReclaim; - for (Node* node : retired_nodes_) { - if (!isHazard(node)) { - toReclaim.push_back(node); - } - } - retired_nodes_.clear(); - for (Node* node : toReclaim) { - delete node; - } - } - - // Check if node is a hazard - auto isHazard(Node* node) -> bool { - for (auto& hazardPointer : hazard_pointers_) { - if (hazardPointer.pointer.load() == node) { - return true; - } - } - return false; - } - - std::vector retired_nodes_; - -public: - LockFreeList() : head_(nullptr) {} - - ~LockFreeList() { - while (head_.load()) { - Node* oldHead = head_.load(); - head_.store(oldHead->next); - delete oldHead; - } - } - - void pushFront(T value) { - Node* newNode = new Node(value); - newNode->next = head_.load(); - while (!head_.compare_exchange_weak(newNode->next, newNode)) { - } - } - - auto popFront() -> std::optional { - std::atomic& hazardPointer = getHazardPointerForCurrentThread(); - Node* oldHead = head_.load(); - do { - Node* temp; - do { - temp = oldHead; - hazardPointer.store(oldHead); - oldHead = head_.load(); - } while (oldHead != temp); - if (!oldHead) { - hazardPointer.store(nullptr); - return std::nullopt; - } - } while (!head_.compare_exchange_strong(oldHead, oldHead->next)); - hazardPointer.store(nullptr); - std::shared_ptr res = oldHead->value; - if (res.use_count() == 1) { - reclaimLater(oldHead); - } - return *res; - } - - [[nodiscard]] auto empty() const -> bool { return head_.load() == nullptr; } - - class Iterator { - public: - using iterator_category = std::forward_iterator_tag; - using value_type = T; - using difference_type = std::ptrdiff_t; - using pointer = T*; - using reference = T&; - - Iterator(Node* node, LockFreeList* list) : node_(node), list_(list) {} - - auto operator++() -> Iterator& { - if (node_) { - node_ = node_->next.load(); - } - return *this; - } - - auto operator++(int) -> Iterator { - Iterator tmp = *this; - ++(*this); - return tmp; - } - - auto operator==(const Iterator& other) const -> bool { - return node_ == other.node_; - } - - auto operator!=(const Iterator& other) const -> bool { - return node_ != other.node_; - } - - auto operator*() const -> reference { return *(node_->value); } - - auto operator->() const -> pointer { return node_->value.get(); } - - private: - Node* node_; - LockFreeList* list_; - }; - - auto begin() -> Iterator { return Iterator(head_.load(), this); } - - auto end() -> Iterator { return Iterator(nullptr, this); } -}; - -template -LockFreeStack::Node::Node(T value_) : value(std::move(value_)) {} - -// 构造函数 -template -LockFreeStack::LockFreeStack() : head_(nullptr) {} - -// 析构函数 -template -LockFreeStack::~LockFreeStack() { - while (auto node = head_.load(std::memory_order_relaxed)) { - head_.store(node->next.load(std::memory_order_relaxed), - std::memory_order_relaxed); - delete node; - } -} - -// push 常量左值引用 -template -void LockFreeStack::push(const T& value) { - auto newNode = new Node(value); - newNode->next = head_.load(std::memory_order_relaxed); - Node* expected = newNode->next.load(std::memory_order_relaxed); - while (!head_.compare_exchange_weak(expected, newNode, - std::memory_order_release, - std::memory_order_relaxed)) { - newNode->next = expected; - } - approximateSize_.fetch_add(1, std::memory_order_relaxed); -} - -// push 右值引用 -template -void LockFreeStack::push(T&& value) { - auto newNode = new Node(std::move(value)); - newNode->next = head_.load(std::memory_order_relaxed); - Node* expected = newNode->next.load(std::memory_order_relaxed); - while (!head_.compare_exchange_weak(expected, newNode, - std::memory_order_release, - std::memory_order_relaxed)) { - newNode->next = expected; - } - approximateSize_.fetch_add(1, std::memory_order_relaxed); -} - -// pop -template -auto LockFreeStack::pop() -> std::optional { - Node* oldHead = head_.load(std::memory_order_relaxed); - while (oldHead && !head_.compare_exchange_weak(oldHead, oldHead->next, - std::memory_order_acquire, - std::memory_order_relaxed)) { - } - if (oldHead) { - T value = std::move(oldHead->value); - delete oldHead; - approximateSize_.fetch_sub(1, std::memory_order_relaxed); - return value; - } - return std::nullopt; -} - -// top -template -auto LockFreeStack::top() const -> std::optional { - Node* topNode = head_.load(std::memory_order_relaxed); - if (head_.load(std::memory_order_relaxed)) { - return std::optional(topNode->value); - } - return std::nullopt; -} - -// empty -template -auto LockFreeStack::empty() const -> bool { - return head_.load(std::memory_order_relaxed) == nullptr; -} - -// size -template -auto LockFreeStack::size() const -> int { - return approximateSize_.load(std::memory_order_relaxed); -} -} // namespace atom::async - -#endif // ATOM_ASYNC_SAFETYPE_HPP diff --git a/src/atom/async/slot.hpp b/src/atom/async/slot.hpp deleted file mode 100644 index 8fd29c6f..00000000 --- a/src/atom/async/slot.hpp +++ /dev/null @@ -1,569 +0,0 @@ -#ifndef ATOM_ASYNC_SIGNAL_HPP -#define ATOM_ASYNC_SIGNAL_HPP - -#include -#include -#include -#include -#include -#include -#include - -namespace atom::async { - -/** - * @brief A signal class that allows connecting, disconnecting, and emitting - * slots. - * - * @tparam Args The argument types for the slots. - */ -template -class Signal { -public: - using SlotType = std::function; - - /** - * @brief Connect a slot to the signal. - * - * @param slot The slot to connect. - */ - void connect(SlotType slot) { - std::lock_guard lock(mutex_); - slots_.push_back(std::move(slot)); - } - - /** - * @brief Disconnect a slot from the signal. - * - * @param slot The slot to disconnect. - */ - void disconnect(const SlotType& slot) { - std::lock_guard lock(mutex_); - slots_.erase(std::remove_if(slots_.begin(), slots_.end(), - [&](const SlotType& s) { - return s.target_type() == - slot.target_type(); - }), - slots_.end()); - } - - /** - * @brief Emit the signal, calling all connected slots. - * - * @param args The arguments to pass to the slots. - */ - void emit(Args... args) { - std::lock_guard lock(mutex_); - for (const auto& slot : slots_) { - slot(args...); - } - } - -private: - std::vector slots_; - std::mutex mutex_; -}; - -/** - * @brief A signal class that allows asynchronous slot execution. - * - * @tparam Args The argument types for the slots. - */ -template -class AsyncSignal { -public: - using SlotType = std::function; - - /** - * @brief Connect a slot to the signal. - * - * @param slot The slot to connect. - */ - void connect(SlotType slot) { - std::lock_guard lock(mutex_); - slots_.push_back(std::move(slot)); - } - - /** - * @brief Disconnect a slot from the signal. - * - * @param slot The slot to disconnect. - */ - void disconnect(const SlotType& slot) { - std::lock_guard lock(mutex_); - slots_.erase(std::remove_if(slots_.begin(), slots_.end(), - [&](const SlotType& s) { - return s.target_type() == - slot.target_type(); - }), - slots_.end()); - } - - /** - * @brief Emit the signal asynchronously, calling all connected slots. - * - * @param args The arguments to pass to the slots. - */ - void emit(Args... args) { - std::vector> futures; - { - std::lock_guard lock(mutex_); - for (const auto& slot : slots_) { - futures.push_back( - std::async(std::launch::async, slot, args...)); - } - } - for (auto& future : futures) { - future.get(); - } - } - -private: - std::vector slots_; - std::mutex mutex_; -}; - -/** - * @brief A signal class that allows automatic disconnection of slots. - * - * @tparam Args The argument types for the slots. - */ -template -class AutoDisconnectSignal { -public: - using SlotType = std::function; - - /** - * @brief Connect a slot to the signal and return its unique ID. - * - * @param slot The slot to connect. - * @return int The unique ID of the connected slot. - */ - auto connect(SlotType slot) -> int { - std::lock_guard lock(mutex_); - auto id = nextId_++; - slots_.emplace(id, std::move(slot)); - return id; - } - - /** - * @brief Disconnect a slot from the signal using its unique ID. - * - * @param id The unique ID of the slot to disconnect. - */ - void disconnect(int id) { - std::lock_guard lock(mutex_); - slots_.erase(id); - } - - /** - * @brief Emit the signal, calling all connected slots. - * - * @param args The arguments to pass to the slots. - */ - void emit(Args... args) { - std::lock_guard lock(mutex_); - for (const auto& [id, slot] : slots_) { - slot(args...); - } - } - -private: - std::map slots_; - std::mutex mutex_; - int nextId_ = 0; -}; - -/** - * @brief A signal class that allows chaining of signals. - * - * @tparam Args The argument types for the slots. - */ -template -class ChainedSignal { -public: - using SlotType = std::function; - - /** - * @brief Connect a slot to the signal. - * - * @param slot The slot to connect. - */ - void connect(SlotType slot) { - std::lock_guard lock(mutex_); - slots_.push_back(std::move(slot)); - } - - /** - * @brief Add a chained signal to be emitted after this signal. - * - * @param nextSignal The next signal to chain. - */ - void addChain(ChainedSignal& nextSignal) { - std::lock_guard lock(mutex_); - chains_.push_back(&nextSignal); - } - - /** - * @brief Emit the signal, calling all connected slots and chained signals. - * - * @param args The arguments to pass to the slots. - */ - void emit(Args... args) { - std::lock_guard lock(mutex_); - for (const auto& slot : slots_) { - slot(args...); - } - for (auto& chain : chains_) { - chain->emit(args...); - } - } - -private: - std::vector slots_; - std::vector*> chains_; - std::mutex mutex_; -}; - -/** - * @brief A signal class that allows connecting, disconnecting, and emitting - * slots. - * - * @tparam Args The argument types for the slots. - */ -template -class TemplateSignal { -public: - using SlotType = std::function; - - /** - * @brief Connect a slot to the signal. - * - * @param slot The slot to connect. - */ - void connect(SlotType slot) { - std::lock_guard lock(mutex_); - slots_.push_back(std::move(slot)); - } - - /** - * @brief Disconnect a slot from the signal. - * - * @param slot The slot to disconnect. - */ - void disconnect(const SlotType& slot) { - std::lock_guard lock(mutex_); - slots_.erase(std::remove_if(slots_.begin(), slots_.end(), - [&](const SlotType& s) { - return s.target_type() == - slot.target_type(); - }), - slots_.end()); - } - - /** - * @brief Emit the signal, calling all connected slots. - * - * @param args The arguments to pass to the slots. - */ - void emit(Args... args) { - std::lock_guard lock(mutex_); - for (const auto& slot : slots_) { - slot(args...); - } - } - -private: - std::vector slots_; - std::mutex mutex_; -}; - -/** - * @brief A signal class that ensures thread-safe slot execution. - * - * @tparam Args The argument types for the slots. - */ -template -class ThreadSafeSignal { -public: - using SlotType = std::function; - - /** - * @brief Connect a slot to the signal. - * - * @param slot The slot to connect. - */ - void connect(SlotType slot) { - std::lock_guard lock(mutex_); - slots_.push_back(std::move(slot)); - } - - /** - * @brief Disconnect a slot from the signal. - * - * @param slot The slot to disconnect. - */ - void disconnect(const SlotType& slot) { - std::lock_guard lock(mutex_); - slots_.erase(std::remove_if(slots_.begin(), slots_.end(), - [&](const SlotType& s) { - return s.target_type() == - slot.target_type(); - }), - slots_.end()); - } - - /** - * @brief Emit the signal, calling all connected slots in a thread-safe - * manner. - * - * @param args The arguments to pass to the slots. - */ - void emit(Args... args) { - std::vector> tasks; - { - std::lock_guard lock(mutex_); - for (const auto& slot : slots_) { - tasks.emplace_back([slot, args...]() { slot(args...); }); - } - } - for (auto& task : tasks) { - std::async(std::launch::async, task).get(); - } - } - -private: - std::vector slots_; - std::mutex mutex_; -}; - -/** - * @brief A signal class that allows broadcasting to chained signals. - * - * @tparam Args The argument types for the slots. - */ -template -class BroadcastSignal { -public: - using SlotType = std::function; - - /** - * @brief Connect a slot to the signal. - * - * @param slot The slot to connect. - */ - void connect(SlotType slot) { - std::lock_guard lock(mutex_); - slots_.push_back(std::move(slot)); - } - - /** - * @brief Disconnect a slot from the signal. - * - * @param slot The slot to disconnect. - */ - void disconnect(const SlotType& slot) { - std::lock_guard lock(mutex_); - slots_.erase(std::remove_if(slots_.begin(), slots_.end(), - [&](const SlotType& s) { - return s.target_type() == - slot.target_type(); - }), - slots_.end()); - } - - /** - * @brief Emit the signal, calling all connected slots and chained signals. - * - * @param args The arguments to pass to the slots. - */ - void emit(Args... args) { - std::lock_guard lock(mutex_); - for (const auto& slot : slots_) { - slot(args...); - } - for (const auto& signal : chainedSignals_) { - signal->emit(args...); - } - } - - /** - * @brief Add a chained signal to be emitted after this signal. - * - * @param signal The next signal to chain. - */ - void addChain(BroadcastSignal& signal) { - std::lock_guard lock(mutex_); - chainedSignals_.push_back(&signal); - } - -private: - std::vector slots_; - std::vector*> chainedSignals_; - std::mutex mutex_; -}; - -/** - * @brief A signal class that limits the number of times it can be emitted. - * - * @tparam Args The argument types for the slots. - */ -template -class LimitedSignal { -public: - using SlotType = std::function; - - /** - * @brief Construct a new Limited Signal object. - * - * @param maxCalls The maximum number of times the signal can be emitted. - */ - explicit LimitedSignal(size_t maxCalls) : maxCalls_(maxCalls) {} - - /** - * @brief Connect a slot to the signal. - * - * @param slot The slot to connect. - */ - void connect(SlotType slot) { - std::lock_guard lock(mutex_); - slots_.push_back(std::move(slot)); - } - - /** - * @brief Disconnect a slot from the signal. - * - * @param slot The slot to disconnect. - */ - void disconnect(const SlotType& slot) { - std::lock_guard lock(mutex_); - slots_.erase(std::remove_if(slots_.begin(), slots_.end(), - [&](const SlotType& s) { - return s.target_type() == - slot.target_type(); - }), - slots_.end()); - } - - /** - * @brief Emit the signal, calling all connected slots up to the maximum - * number of calls. - * - * @param args The arguments to pass to the slots. - */ - void emit(Args... args) { - std::lock_guard lock(mutex_); - if (callCount_ >= maxCalls_) { - return; - } - for (const auto& slot : slots_) { - slot(args...); - } - ++callCount_; - } - -private: - std::vector slots_; - size_t maxCalls_; - size_t callCount_{}; - std::mutex mutex_; -}; - -/** - * @brief A signal class that allows dynamic slot management. - * - * @tparam Args The argument types for the slots. - */ -template -class DynamicSignal { -public: - using SlotType = std::function; - - /** - * @brief Connect a slot to the signal. - * - * @param slot The slot to connect. - */ - void connect(SlotType slot) { - std::lock_guard lock(mutex_); - slots_.push_back(std::move(slot)); - } - - /** - * @brief Disconnect a slot from the signal. - * - * @param slot The slot to disconnect. - */ - void disconnect(const SlotType& slot) { - std::lock_guard lock(mutex_); - slots_.erase(std::remove_if(slots_.begin(), slots_.end(), - [&](const SlotType& s) { - return s.target_type() == - slot.target_type(); - }), - slots_.end()); - } - - /** - * @brief Emit the signal, calling all connected slots. - * - * @param args The arguments to pass to the slots. - */ - void emit(Args... args) { - std::lock_guard lock(mutex_); - for (const auto& slot : slots_) { - slot(args...); - } - } - -private: - std::vector slots_; - std::mutex mutex_; -}; - -/** - * @brief A signal class that allows scoped slot management. - * - * @tparam Args The argument types for the slots. - */ -template -class ScopedSignal { -public: - using SlotType = std::function; - - /** - * @brief Connect a slot to the signal using a shared pointer. - * - * @param slotPtr The shared pointer to the slot to connect. - */ - void connect(std::shared_ptr slotPtr) { - std::lock_guard lock(mutex_); - slots_.push_back(slotPtr); - } - - /** - * @brief Emit the signal, calling all connected slots. - * - * @param args The arguments to pass to the slots. - */ - void emit(Args... args) { - std::lock_guard lock(mutex_); - auto it = slots_.begin(); - while (it != slots_.end()) { - if (auto slot = *it; slot) { - (*slot)(args...); - ++it; - } else { - it = slots_.erase(it); - } - } - } - -private: - std::vector> slots_; - std::mutex mutex_; -}; - -} // namespace atom::async - -#endif diff --git a/src/atom/async/thread_wrapper.hpp b/src/atom/async/thread_wrapper.hpp deleted file mode 100644 index 3c3fd8c9..00000000 --- a/src/atom/async/thread_wrapper.hpp +++ /dev/null @@ -1,140 +0,0 @@ -/* - * thread_wrapper.hpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2024-2-13 - -Description: A simple wrapper of std::jthread - -**************************************************/ - -#ifndef ATOM_ASYNC_THREAD_WRAPPER_HPP -#define ATOM_ASYNC_THREAD_WRAPPER_HPP - -#include -#include -#include -#include -#include "type/noncopyable.hpp" - -namespace atom::async { -/** - * @brief A wrapper class for managing a C++20 jthread. - * - * This class provides a convenient interface for managing a C++20 jthread, - * allowing for starting, stopping, and joining threads easily. - */ -class Thread : public NonCopyable { -public: - /** - * @brief Default constructor. - */ - Thread() = default; - - /** - * @brief Starts a new thread with the specified callable object and - * arguments. - * - * If the callable object is invocable with a std::stop_token and the - * provided arguments, it will be invoked with a std::stop_token as the - * first argument. Otherwise, it will be invoked with the provided - * arguments. - * - * @tparam Callable The type of the callable object. - * @tparam Args The types of the arguments. - * @param func The callable object to execute in the new thread. - * @param args The arguments to pass to the callable object. - */ - template - void start(Callable&& func, Args&&... args) { - thread_ = - std::jthread([func = std::forward(func), - ... args = std::forward(args), this]() mutable { - if constexpr (std::is_invocable_v) { - func(std::stop_token(thread_.get_stop_token()), - std::move(args)...); - } else { - func(std::move(args)...); - } - }); - } - - /** - * @brief Requests the thread to stop execution. - */ - void requestStop() { thread_.request_stop(); } - - /** - * @brief Waits for the thread to finish execution. - */ - void join() { thread_.join(); } - - /** - * @brief Checks if the thread is currently running. - * @return True if the thread is running, false otherwise. - */ - [[nodiscard]] auto running() const noexcept -> bool { - return thread_.joinable(); - } - - /** - * @brief Swaps the content of this Thread object with another Thread - * object. - * @param other The Thread object to swap with. - */ - void swap(Thread& other) noexcept { thread_.swap(other.thread_); } - - /** - * @brief Gets the underlying std::jthread object. - * @return Reference to the underlying std::jthread object. - */ - [[nodiscard]] auto getThread() noexcept -> std::jthread& { return thread_; } - - /** - * @brief Gets the underlying std::jthread object (const version). - * @return Constant reference to the underlying std::jthread object. - */ - [[nodiscard]] auto getThread() const noexcept -> const std::jthread& { - return thread_; - } - - /** - * @brief Gets the ID of the thread. - * @return The ID of the thread. - */ - [[nodiscard]] auto getId() const noexcept -> std::thread::id { - return thread_.get_id(); - } - - /** - * @brief Gets the underlying std::stop_source object. - * @return The underlying std::stop_source object. - */ - [[nodiscard]] auto getStopSource() noexcept -> std::stop_source { - return thread_.get_stop_source(); - } - - /** - * @brief Gets the underlying std::stop_token object. - * @return The underlying std::stop_token object. - */ - [[nodiscard]] auto getStopToken() const noexcept -> std::stop_token { - return thread_.get_stop_token(); - } - - /** - * @brief Default destructor. - */ - ~Thread() = default; - -private: - std::jthread thread_; ///< The underlying jthread object. -}; -} // namespace atom::async - -#endif diff --git a/src/atom/async/threadlocal.hpp b/src/atom/async/threadlocal.hpp deleted file mode 100644 index c7746b82..00000000 --- a/src/atom/async/threadlocal.hpp +++ /dev/null @@ -1,262 +0,0 @@ -/* - * threadlocal.hpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2024-4-16 - -Description: ThreadLocal - -**************************************************/ - -#ifndef ATOM_ASYNC_THREADLOCAL_HPP -#define ATOM_ASYNC_THREADLOCAL_HPP - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "atom/type/noncopyable.hpp" - -namespace atom::async { -/** - * @brief A thread-local storage class template that provides thread-specific - * storage for objects of type T. - * - * This class allows each thread to maintain its own independent instance of T, - * with optional initialization and a variety of access methods. It is not - * copyable to ensure each instance is unique per thread. - * - * @tparam T The type of the value that will be stored in thread-local storage. - */ -template -class ThreadLocal : public NonCopyable { -public: - using InitializerFn = - std::function; ///< Type definition for the initializer function. - - /** - * @brief Default constructor for ThreadLocal. - * - * Initializes an instance of ThreadLocal without an initializer. - */ - ThreadLocal() = default; - - /** - * @brief Constructs a ThreadLocal instance with an initializer function. - * - * @param initializer A function that is called to initialize the value the - * first time it is accessed. - */ - explicit ThreadLocal(InitializerFn initializer); - - // Move constructor - ThreadLocal(ThreadLocal&&) noexcept = default; - - /** - * @brief Move assignment operator. - * - * @param other The ThreadLocal instance to move from. - * @return A reference to this instance after the move. - */ - auto operator=(ThreadLocal&&) noexcept -> ThreadLocal& = default; - - /** - * @brief Retrieves the thread-local value. - * - * If the value has not been initialized for the current thread, the - * initializer function is called to create it. - * - * @return A reference to the thread-local value of type T. - */ - auto get() -> T&; - - /** - * @brief Access the thread-local value using the arrow operator. - * - * @return A pointer to the thread-local value of type T. - */ - auto operator->() -> T*; - - /** - * @brief Access the thread-local value using the arrow operator (const - * version). - * - * @return A pointer to the thread-local value of type T (const version). - */ - auto operator->() const -> const T*; - - /** - * @brief Dereference the thread-local value. - * - * @return A reference to the thread-local value of type T. - */ - auto operator*() -> T&; - - /** - * @brief Dereference the thread-local value (const version). - * - * @return A const reference to the thread-local value of type T. - */ - auto operator*() const -> const T&; - - /** - * @brief Resets the value in thread-local storage. - * - * If a value is provided, it will be set to the thread-local value. If no - * value is provided, the thread-local value will be reset to its default - * constructed value. - * - * @param value The value to set; the default is T(), which is the default - * constructed value of T. - */ - void reset(T value = T()); - - /** - * @brief Checks if the current thread has a value. - * - * @return true if the current thread has an initialized value, otherwise - * false. - */ - auto hasValue() const -> bool; - - /** - * @brief Retrieves a pointer to the thread-local value. - * - * If the value has not been initialized, this will return a nullptr. - * - * @return A pointer to the thread-local value of type T. - */ - auto getPointer() -> T*; - - /** - * @brief Retrieves a pointer to the thread-local value (const version). - * - * @return A const pointer to the thread-local value of type T. - */ - auto getPointer() const -> const T*; - - /** - * @brief Executes a function for each thread-local value. - * - * This allows the caller to provide a function that will be called with the - * value of type T for each thread that has an initialized value. - * - * @tparam Func A callable type (e.g., a lambda or a function pointer) that - * takes a reference to T. - * @param func The function to execute for each thread-local value. - */ - template Func> - void forEach(Func&& func); - - /** - * @brief Clears the thread-local storage for the current thread. - * - * This will remove the value associated with the current thread. - */ - void clear(); - -private: - InitializerFn initializer_; ///< The function used to initialize T. - mutable std::shared_mutex mutex_; ///< Mutex for thread-safe access. - std::unordered_map> - values_; ///< Store values by thread ID. -}; - -template -ThreadLocal::ThreadLocal(InitializerFn initializer) - : initializer_(std::move(initializer)) {} - -template -auto ThreadLocal::get() -> T& { - auto tid = std::this_thread::get_id(); - std::unique_lock lock(mutex_); - auto [it, inserted] = values_.try_emplace(tid); - if (inserted && initializer_) { - it->second = std::make_optional(initializer_()); - } - lock.unlock(); - return it->second.value(); -} - -template -auto ThreadLocal::operator->() -> T* { - return &get(); -} - -template -auto ThreadLocal::operator->() const -> const T* { - return &get(); -} - -template -auto ThreadLocal::operator*() -> T& { - return get(); -} - -template -auto ThreadLocal::operator*() const -> const T& { - return get(); -} - -template -void ThreadLocal::reset(T value) { - auto tid = std::this_thread::get_id(); - std::unique_lock lock(mutex_); - values_[tid] = std::make_optional(std::move(value)); -} - -template -auto ThreadLocal::hasValue() const -> bool { - auto tid = std::this_thread::get_id(); - std::shared_lock lock(mutex_); - auto it = values_.find(tid); - return it != values_.end() && it->second.has_value(); -} - -template -auto ThreadLocal::getPointer() -> T* { - auto tid = std::this_thread::get_id(); - std::shared_lock lock(mutex_); - auto it = values_.find(tid); - return it != values_.end() && it->second.has_value() ? &it->second.value() - : nullptr; -} - -template -auto ThreadLocal::getPointer() const -> const T* { - auto tid = std::this_thread::get_id(); - std::shared_lock lock(mutex_); - auto it = values_.find(tid); - return it != values_.end() && it->second.has_value() ? &it->second.value() - : nullptr; -} - -template -template Func> -void ThreadLocal::forEach(Func&& func) { - std::unique_lock lock(mutex_); - for (auto& [tid, value_opt] : values_) { - if (value_opt.has_value()) { - func(value_opt.value()); - } - } -} - -template -void ThreadLocal::clear() { - std::unique_lock lock(mutex_); - values_.clear(); -} - -} // namespace atom::async - -#endif // ATOM_ASYNC_THREADLOCAL_HPP diff --git a/src/atom/async/timer.cpp b/src/atom/async/timer.cpp deleted file mode 100644 index f35b56aa..00000000 --- a/src/atom/async/timer.cpp +++ /dev/null @@ -1,131 +0,0 @@ -/* - * timer.cpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2023-12-14 - -Description: Timer class for C++ - -**************************************************/ - -#include "timer.hpp" -#include -#include "error/exception.hpp" - -namespace atom::async { -TimerTask::TimerTask(std::function func, unsigned int delay, - int repeatCount, int priority) - : m_func(func), - m_delay(delay), - m_repeatCount(repeatCount), - m_priority(priority) { - m_nextExecutionTime = - std::chrono::steady_clock::now() + std::chrono::milliseconds(m_delay); -} - -auto TimerTask::operator<(const TimerTask &other) const -> bool { - if (m_priority != other.m_priority) { - return m_priority > other.m_priority; - } - return m_nextExecutionTime > other.m_nextExecutionTime; -} - -void TimerTask::run() { - try { - m_func(); - } catch (const std::exception &e) { - THROW_RUNTIME_ERROR("Failed to run timer task: ", e.what()); - } - if (m_repeatCount > 0) { - --m_repeatCount; - if (m_repeatCount > 0) { - m_nextExecutionTime = std::chrono::steady_clock::now() + - std::chrono::milliseconds(m_delay); - } - } -} - -std::chrono::steady_clock::time_point TimerTask::getNextExecutionTime() const { - return m_nextExecutionTime; -} - -Timer::Timer() : m_stop(false), m_paused(false) { - m_thread = std::thread(&Timer::run, this); -} - -Timer::~Timer() { - stop(); - if (m_thread.joinable()) { - m_thread.join(); - } -} - -void Timer::cancelAllTasks() { - std::unique_lock lock(m_mutex); - m_taskQueue = std::priority_queue(); - m_cond.notify_all(); -} - -void Timer::pause() { m_paused = true; } - -void Timer::resume() { - m_paused = false; - m_cond.notify_all(); -} - -void Timer::stop() { - m_stop = true; - m_cond.notify_all(); -} - -auto Timer::now() const -> std::chrono::steady_clock::time_point { - return std::chrono::steady_clock::now(); -} - -void Timer::run() { - while (!m_stop) { - std::unique_lock lock(m_mutex); - while (!m_stop && m_paused && m_taskQueue.empty()) { - m_cond.wait(lock, [&]() { - return m_stop || !m_paused || !m_taskQueue.empty(); - }); - } - if (m_stop) { - break; - } - if (!m_taskQueue.empty()) { - TimerTask task = m_taskQueue.top(); - if (std::chrono::steady_clock::now() >= - task.getNextExecutionTime()) { - m_taskQueue.pop(); - lock.unlock(); - task.run(); - if (task.m_repeatCount > 0) { - std::unique_lock lock(m_mutex); - m_taskQueue.emplace(task.m_func, task.m_delay, - task.m_repeatCount, task.m_priority); - } - if (m_callback) { - m_callback(); - } - } else { - m_cond.wait_until(lock, task.getNextExecutionTime()); - } - } - } -} - -auto Timer::getTaskCount() const -> size_t { - std::unique_lock lock(m_mutex); - return m_taskQueue.size(); -} - -void Timer::wait() { - std::unique_lock lock(m_mutex); - m_cond.wait(lock, [&]() { return m_taskQueue.empty(); }); -} -} // namespace atom::async diff --git a/src/atom/async/timer.hpp b/src/atom/async/timer.hpp deleted file mode 100644 index ab5c3a6b..00000000 --- a/src/atom/async/timer.hpp +++ /dev/null @@ -1,244 +0,0 @@ -/* - * timer.hpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2023-12-14 - -Description: Timer class for C++ - -**************************************************/ - -#ifndef ATOM_ASYNC_TIMER_HPP -#define ATOM_ASYNC_TIMER_HPP - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "future.hpp" - -namespace atom::async { -/** - * @brief Represents a task to be scheduled and executed by the Timer. - */ -class TimerTask { -public: - /** - * @brief Constructor for TimerTask. - * - * @param func The function to be executed when the task runs. - * @param delay The delay in milliseconds before the first execution. - * @param repeatCount The number of times the task should be repeated. -1 - * for infinite repetition. - * @param priority The priority of the task. - */ - explicit TimerTask(std::function func, unsigned int delay, - int repeatCount, int priority); - - /** - * @brief Comparison operator for comparing two TimerTask objects based on - * their next execution time. - * - * @param other Another TimerTask object to compare to. - * @return True if this task's next execution time is earlier than the other - * task's next execution time. - */ - auto operator<(const TimerTask &other) const -> bool; - - /** - * @brief Executes the task's associated function. - */ - void run(); - - /** - * @brief Get the next scheduled execution time of the task. - * - * @return The steady clock time point representing the next execution time. - */ - auto getNextExecutionTime() const -> std::chrono::steady_clock::time_point; - - std::function m_func; ///< The function to be executed. - unsigned int m_delay; ///< The delay before the first execution. - int m_repeatCount; ///< The number of repetitions remaining. - int m_priority; ///< The priority of the task. - std::chrono::steady_clock::time_point - m_nextExecutionTime; ///< The next execution time. -}; - -/** - * @brief Represents a timer for scheduling and executing tasks. - */ -class Timer { -public: - /** - * @brief Constructor for Timer. - */ - Timer(); - - /** - * @brief Destructor for Timer. - */ - ~Timer(); - - /** - * @brief Schedules a task to be executed once after a specified delay. - * - * @tparam Function The type of the function to be executed. - * @tparam Args The types of the arguments to be passed to the function. - * @param func The function to be executed. - * @param delay The delay in milliseconds before the function is executed. - * @param args The arguments to be passed to the function. - * @return An EnhancedFuture representing the result of the function - * execution. - */ - template - [[nodiscard]] auto setTimeout(Function &&func, unsigned int delay, - Args &&...args) - -> EnhancedFuture::type>; - - /** - * @brief Schedules a task to be executed repeatedly at a specified - * interval. - * - * @tparam Function The type of the function to be executed. - * @tparam Args The types of the arguments to be passed to the function. - * @param func The function to be executed. - * @param interval The interval in milliseconds between executions. - * @param repeatCount The number of times the function should be repeated. - * -1 for infinite repetition. - * @param priority The priority of the task. - * @param args The arguments to be passed to the function. - */ - template - void setInterval(Function &&func, unsigned int interval, int repeatCount, - int priority, Args &&...args); - - [[nodiscard]] auto now() const -> std::chrono::steady_clock::time_point; - - /** - * @brief Cancels all scheduled tasks. - */ - void cancelAllTasks(); - - /** - * @brief Pauses the execution of scheduled tasks. - */ - void pause(); - - /** - * @brief Resumes the execution of scheduled tasks after pausing. - */ - void resume(); - - /** - * @brief Stops the timer and cancels all tasks. - */ - void stop(); - - /** - * @brief Blocks the calling thread until all tasks are completed. - */ - void wait(); - - /** - * @brief Sets a callback function to be called when a task is executed. - * - * @tparam Function The type of the callback function. - * @param func The callback function to be set. - */ - template - void setCallback(Function &&func); - - [[nodiscard]] auto getTaskCount() const -> size_t; - -private: - /** - * @brief Adds a task to the task queue. - * - * @tparam Function The type of the function to be executed. - * @tparam Args The types of the arguments to be passed to the function. - * @param func The function to be executed. - * @param delay The delay in milliseconds before the function is executed. - * @param repeatCount The number of repetitions remaining. - * @param priority The priority of the task. - * @param args The arguments to be passed to the function. - * @return An EnhancedFuture representing the result of the function - * execution. - */ - template - auto addTask(Function &&func, unsigned int delay, int repeatCount, - int priority, Args &&...args) - -> EnhancedFuture::type>; - - /** - * @brief Main execution loop for processing and running tasks. - */ - void run(); - -#if _cplusplus >= 202203L - std::jthread - m_thread; ///< The thread for running the timer loop (C++20 or later). -#else - std::thread m_thread; ///< The thread for running the timer loop. -#endif - std::priority_queue - m_taskQueue; ///< The priority queue for scheduled tasks. - mutable std::mutex m_mutex; ///< Mutex for thread synchronization. - std::condition_variable - m_cond; ///< Condition variable for thread synchronization. - std::function m_callback; ///< The callback function to be called - ///< when a task is executed. - bool m_stop; ///< Flag indicating whether the timer should stop. - bool m_paused; ///< Flag indicating whether the timer is paused. -}; - -template -auto Timer::setTimeout(Function &&func, unsigned int delay, Args &&...args) - -> EnhancedFuture::type> { - using ReturnType = typename std::result_of::type; - auto task = std::make_shared>( - std::bind(std::forward(func), std::forward(args)...)); - std::future result = task->get_future(); - std::unique_lock lock(m_mutex); - m_taskQueue.emplace([task]() { (*task)(); }, delay, 1, 0); - m_cond.notify_all(); - return EnhancedFuture(std::move(result).share()); -} - -template -void Timer::setInterval(Function &&func, unsigned int interval, int repeatCount, - int priority, Args &&...args) { - addTask(std::forward(func), interval, repeatCount, priority, - std::forward(args)...); -} - -template -auto Timer::addTask(Function &&func, unsigned int delay, int repeatCount, - int priority, Args &&...args) - -> EnhancedFuture::type> { - using ReturnType = typename std::result_of::type; - auto task = std::make_shared>( - std::bind(std::forward(func), std::forward(args)...)); - std::future result = task->get_future(); - std::unique_lock lock(m_mutex); - m_taskQueue.emplace([task]() { (*task)(); }, delay, repeatCount, priority); - m_cond.notify_all(); - return EnhancedFuture(std::move(result).share()); -} - -template -void Timer::setCallback(Function &&func) { - m_callback = std::forward(func); -} -} // namespace atom::async - -#endif \ No newline at end of file diff --git a/src/atom/async/trigger.hpp b/src/atom/async/trigger.hpp deleted file mode 100644 index 7719855e..00000000 --- a/src/atom/async/trigger.hpp +++ /dev/null @@ -1,236 +0,0 @@ -/* - * trigger.hpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2023-12-14 - -Description: Trigger class for C++ - -**************************************************/ - -#ifndef ATOM_ASYNC_TRIGGER_HPP -#define ATOM_ASYNC_TRIGGER_HPP - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace atom::async { - -/** - * @brief Concept to check if a type can be invoked with a given parameter type. - * - * This concept checks if a std::function taking a parameter of type ParamType - * is invocable with an instance of ParamType. - * - * @tparam ParamType The parameter type to check for. - */ -template -concept CallableWithParam = requires(ParamType p) { - std::invoke(std::declval>(), p); -}; - -/** - * @brief A class for handling event-driven callbacks with parameter support. - * - * This class allows users to register, unregister, and trigger callbacks for - * different events, providing a mechanism to manage callbacks with priorities - * and delays. - * - * @tparam ParamType The type of parameter to be passed to the callbacks. - */ -template - requires CallableWithParam -class Trigger { -public: - using Callback = std::function; ///< Type alias for the - ///< callback function. - - /// Enumeration for callback priority levels. - enum class CallbackPriority { High, Normal, Low }; - - /** - * @brief Registers a callback for a specified event. - * - * @param event The name of the event for which the callback is registered. - * @param callback The callback function to be executed when the event is - * triggered. - * @param priority The priority level of the callback (default is Normal). - */ - void registerCallback(const std::string& event, Callback callback, - CallbackPriority priority = CallbackPriority::Normal); - - /** - * @brief Unregisters a callback for a specified event. - * - * @param event The name of the event from which the callback is - * unregistered. - * @param callback The callback function to be removed. - * - * If the callback is not registered for the event, no action is taken. - */ - void unregisterCallback(const std::string& event, Callback callback); - - /** - * @brief Triggers the callbacks associated with a specified event. - * - * @param event The name of the event to trigger. - * @param param The parameter to be passed to the callbacks. - * - * All callbacks registered for the event are executed with the provided - * parameter. - */ - void trigger(const std::string& event, const ParamType& param); - - /** - * @brief Schedules a trigger for a specified event after a delay. - * - * @param event The name of the event to trigger. - * @param param The parameter to be passed to the callbacks. - * @param delay The delay after which to trigger the event, specified in - * milliseconds. - */ - void scheduleTrigger(const std::string& event, const ParamType& param, - std::chrono::milliseconds delay); - - /** - * @brief Schedules an asynchronous trigger for a specified event. - * - * @param event The name of the event to trigger. - * @param param The parameter to be passed to the callbacks. - * @return A future representing the ongoing operation to trigger the event. - */ - auto scheduleAsyncTrigger(const std::string& event, - const ParamType& param) -> std::future; - - /** - * @brief Cancels the scheduled trigger for a specified event. - * - * @param event The name of the event for which to cancel the trigger. - * - * This will prevent the execution of any scheduled callbacks for the event. - */ - void cancelTrigger(const std::string& event); - - /** - * @brief Cancels all scheduled triggers. - * - * This method clears all scheduled callbacks for any events. - */ - void cancelAllTriggers(); - -private: - std::mutex m_mutex_; ///< Mutex for thread-safe access to the internal - ///< callback structures. - std::unordered_map>> - m_callbacks_; ///< Map of events to their callbacks and priorities. -}; - -template - requires CallableWithParam -void Trigger::registerCallback(const std::string& event, - Callback callback, - CallbackPriority priority) { - std::scoped_lock lock(m_mutex_); - auto& callbacks = m_callbacks_[event]; - if (auto pos = std::ranges::find_if( - callbacks, - [&callback](const auto& cb) { - return cb.second.target_type() == callback.target_type() && - cb.second.template target() == - callback.template target(); - }); - pos != callbacks.end()) { - pos->first = priority; - } else { - callbacks.emplace_back(priority, callback); - } -} - -template - requires CallableWithParam -void Trigger::unregisterCallback(const std::string& event, - Callback callback) { - std::scoped_lock lock(m_mutex_); - auto& callbacks = m_callbacks_[event]; - std::erase_if(callbacks, [&callback](const auto& cb) { - return cb.second.target_type() == callback.target_type() && - cb.second.template target() == - callback.template target(); - }); -} - -template - requires CallableWithParam -void Trigger::trigger(const std::string& event, - const ParamType& param) { - std::scoped_lock lock(m_mutex_); - auto& callbacks = m_callbacks_[event]; - std::ranges::sort(callbacks, [](const auto& cb1, const auto& cb2) { - return static_cast(cb1.first) > static_cast(cb2.first); - }); - for (auto& [priority, callback] : callbacks) { - try { - callback(param); - } catch (...) { - // Swallow exceptions in callbacks - } - } -} - -template - requires CallableWithParam -void Trigger::scheduleTrigger(const std::string& event, - const ParamType& param, - std::chrono::milliseconds delay) { - std::jthread([this, event, param, delay]() { - std::this_thread::sleep_for(delay); - trigger(event, param); - }).detach(); -} - -template - requires CallableWithParam -auto Trigger::scheduleAsyncTrigger( - const std::string& event, const ParamType& param) -> std::future { - auto promise = std::make_shared>(); - auto future = promise->get_future(); - std::jthread([this, event, param, promise]() mutable { - try { - trigger(event, param); - promise->set_value(); - } catch (...) { - promise->set_exception(std::current_exception()); - } - }).detach(); - return future; -} - -template - requires CallableWithParam -void Trigger::cancelTrigger(const std::string& event) { - std::scoped_lock lock(m_mutex_); - m_callbacks_.erase(event); -} - -template - requires CallableWithParam -void Trigger::cancelAllTriggers() { - std::scoped_lock lock(m_mutex_); - m_callbacks_.clear(); -} - -} // namespace atom::async - -#endif // ATOM_ASYNC_TRIGGER_HPP diff --git a/src/atom/async/xmake.lua b/src/atom/async/xmake.lua deleted file mode 100644 index 6af35683..00000000 --- a/src/atom/async/xmake.lua +++ /dev/null @@ -1,48 +0,0 @@ --- xmake.lua for Atom-Async --- This project is licensed under the terms of the GPL3 license. --- --- Project Name: Atom-Async --- Description: Async Implementation of Lithium Server and Driver --- Author: Max Qian --- License: GPL3 - -add_rules("mode.debug", "mode.release") - --- Set project name -set_project("atom-async") - --- Set languages -set_languages("cxx17") - --- Set source files -add_files("lock.cpp", "timer.cpp") - --- Set header files -add_headerfiles("*.hpp", "*.inl") - --- Set link libraries -add_linkdirs("path/to/loguru/library") -- Replace with actual path to loguru library -add_links("loguru") - --- Build static library -target("atom-async") - set_kind("static") - add_deps("atom-async-object") - add_files("lock.cpp", "timer.cpp") - add_headerfiles("*.hpp", "*.inl") - add_includedirs(".") - add_linkdirs(".") - add_links("loguru") - --- Build object library -target("atom-async-object") - set_kind("object") - add_files("lock.cpp", "timer.cpp") - add_headerfiles("*.hpp", "*.inl") - add_includedirs(".") - add_linkdirs(".") - add_links("loguru") - --- Install target -set_configvar("xmake", "installdir", "/path/to/installation/directory") -- Replace with actual installation directory -add_installfiles("build/lib/*.a", {prefixdir = "lib"}) diff --git a/src/atom/components/CMakeLists.txt b/src/atom/components/CMakeLists.txt deleted file mode 100644 index 24a7f51f..00000000 --- a/src/atom/components/CMakeLists.txt +++ /dev/null @@ -1,56 +0,0 @@ -# CMakeLists.txt for Atom-Component -# This project adheres to the GPL3 license. -# -# Project Details: -# Name: Atom-Component -# Description: Central component library for the Atom framework -# Author: Max Qian -# License: GPL3 - -cmake_minimum_required(VERSION 3.20) -project(atom-component LANGUAGES C CXX) -# Source files with project-specific prefix -set(${PROJECT_NAME}_SOURCES - component.cpp - dispatch.cpp - registry.cpp - var.cpp -) - -set(${PROJECT_NAME}_HEADERS - component.hpp - dispatch.hpp - types.hpp - var.hpp -) - -# Dependencies -set(${PROJECT_NAME}_LIBS - loguru - atom-error - atom-type - atom-utils -) - -# Include directories -include_directories(.) - -set(CMAKE_POSITION_INDEPENDENT_CODE ON) - -# Object library for headers and sources with project prefix -add_library(${PROJECT_NAME}_OBJECT OBJECT ${${PROJECT_NAME}_HEADERS} ${${PROJECT_NAME}_SOURCES}) -# set_target_properties(${PROJECT_NAME}_OBJECT PROPERTIES LINKER_LANGUAGE CXX) - -# Static library target -add_library(${PROJECT_NAME} SHARED $) - -# Set project properties and definitions -# set_property(TARGET ${PROJECT_NAME} PROPERTY POSITION_INDEPENDENT_CODE ON) - -# Link dependencies to the main target -target_link_libraries(${PROJECT_NAME} PRIVATE ${${PROJECT_NAME}_LIBS} ${CMAKE_THREAD_LIBS_INIT}) - -# Install rules -install(TARGETS ${PROJECT_NAME} - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} -) diff --git a/src/atom/components/component.cpp b/src/atom/components/component.cpp deleted file mode 100644 index d3575554..00000000 --- a/src/atom/components/component.cpp +++ /dev/null @@ -1,241 +0,0 @@ -/* - * component.cpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2023-12-26 - -Description: Basic Component Definition - -**************************************************/ - -#include "component.hpp" - -#include "atom/log/loguru.hpp" - -Component::Component(std::string name) : m_name_(std::move(name)) { - LOG_F(INFO, "Component created: {}", m_name_); -} - -auto Component::getInstance() const -> std::weak_ptr { - LOG_SCOPE_FUNCTION(INFO); - return shared_from_this(); -} - -auto Component::initialize() -> bool { - LOG_SCOPE_FUNCTION(INFO); - LOG_F(INFO, "Initializing component: {}", m_name_); - return true; -} - -auto Component::destroy() -> bool { - LOG_SCOPE_FUNCTION(INFO); - LOG_F(INFO, "Destroying component: {}", m_name_); - return true; -} - -auto Component::getName() const -> std::string { - LOG_SCOPE_FUNCTION(INFO); - return m_name_; -} - -auto Component::getTypeInfo() const -> atom::meta::TypeInfo { - LOG_SCOPE_FUNCTION(INFO); - return m_typeInfo_; -} - -void Component::setTypeInfo(atom::meta::TypeInfo typeInfo) { - LOG_SCOPE_FUNCTION(INFO); - m_typeInfo_ = typeInfo; -} - -void Component::addAlias(const std::string& name, - const std::string& alias) const { - LOG_SCOPE_FUNCTION(INFO); - LOG_F(INFO, "Adding alias '{}' for command '{}'", alias, name); - m_CommandDispatcher_->addAlias(name, alias); -} - -void Component::addGroup(const std::string& name, - const std::string& group) const { - LOG_SCOPE_FUNCTION(INFO); - LOG_F(INFO, "Adding command '{}' to group '{}'", name, group); - m_CommandDispatcher_->addGroup(name, group); -} - -void Component::setTimeout(const std::string& name, - std::chrono::milliseconds timeout) const { - LOG_SCOPE_FUNCTION(INFO); - LOG_F(INFO, "Setting timeout for command '{}': {} ms", name, - timeout.count()); - m_CommandDispatcher_->setTimeout(name, timeout); -} - -void Component::removeCommand(const std::string& name) const { - LOG_SCOPE_FUNCTION(INFO); - LOG_F(INFO, "Removing command '{}'", name); - m_CommandDispatcher_->removeCommand(name); -} - -auto Component::getCommandsInGroup(const std::string& group) const - -> std::vector { - LOG_SCOPE_FUNCTION(INFO); - return m_CommandDispatcher_->getCommandsInGroup(group); -} - -auto Component::getCommandDescription(const std::string& name) const - -> std::string { - LOG_SCOPE_FUNCTION(INFO); - return m_CommandDispatcher_->getCommandDescription(name); -} - -#if ENABLE_FASTHASH -emhash::HashSet Component::getCommandAliases( - const std::string& name) const -#else -auto Component::getCommandAliases(const std::string& name) const - -> std::unordered_set -#endif -{ - LOG_SCOPE_FUNCTION(INFO); - return m_CommandDispatcher_->getCommandAliases(name); -} - -auto Component::getCommandArgAndReturnType(const std::string& name) - -> std::pair, std::string> { - LOG_SCOPE_FUNCTION(INFO); - return m_CommandDispatcher_->getCommandArgAndReturnType(name); -} - -auto Component::getNeededComponents() -> std::vector { - LOG_SCOPE_FUNCTION(INFO); - return {}; -} - -void Component::addOtherComponent(const std::string& name, - const std::weak_ptr& component) { - LOG_SCOPE_FUNCTION(INFO); - if (m_OtherComponents_.contains(name)) { - LOG_F(ERROR, "Component '{}' already exists", name); - THROW_OBJ_ALREADY_EXIST(name); - } - LOG_F(INFO, "Adding other component '{}'", name); - m_OtherComponents_[name] = component; -} - -void Component::removeOtherComponent(const std::string& name) { - LOG_SCOPE_FUNCTION(INFO); - LOG_F(INFO, "Removing other component '{}'", name); - m_OtherComponents_.erase(name); -} - -void Component::clearOtherComponents() { - LOG_SCOPE_FUNCTION(INFO); - LOG_F(INFO, "Clearing all other components"); - m_OtherComponents_.clear(); -} - -auto Component::getOtherComponent(const std::string& name) - -> std::weak_ptr { - LOG_SCOPE_FUNCTION(INFO); - if (m_OtherComponents_.contains(name)) { - return m_OtherComponents_[name]; - } - return {}; -} - -bool Component::has(const std::string& name) const { - LOG_SCOPE_FUNCTION(INFO); - return m_CommandDispatcher_->has(name); -} - -bool Component::hasType(std::string_view name) const { - LOG_SCOPE_FUNCTION(INFO); - if (auto it = m_classes_.find(name); it != m_classes_.end()) { - return true; - } - return false; -} - -auto Component::getAllCommands() const -> std::vector { - LOG_SCOPE_FUNCTION(INFO); - if (m_CommandDispatcher_ == nullptr) { - LOG_F(ERROR, "Component command dispatch is not initialized"); - THROW_OBJ_UNINITIALIZED( - "Component command dispatch is not initialized"); - } - return m_CommandDispatcher_->getAllCommands(); -} - -auto Component::getRegisteredTypes() const -> std::vector { - LOG_SCOPE_FUNCTION(INFO); - return m_TypeCaster_->getRegisteredTypes(); -} - -auto Component::runCommand(const std::string& name, - const std::vector& args) -> std::any { - LOG_SCOPE_FUNCTION(INFO); - auto cmd = getAllCommands(); - - if (auto it = std::ranges::find(cmd, name); it != cmd.end()) { - LOG_F(INFO, "Running command '{}'", name); - return m_CommandDispatcher_->dispatch(name, args); - } - for (const auto& [key, value] : m_OtherComponents_) { - if (!value.expired() && value.lock()->has(name)) { - LOG_F(INFO, "Running command '{}' in other component '{}'", name, - key); - return value.lock()->dispatch(name, args); - } - LOG_F(ERROR, "Component '{}' has expired", key); - m_OtherComponents_.erase(key); - } - - LOG_F(ERROR, "Command '{}' not found", name); - THROW_EXCEPTION("Component ", name, " not found"); -} - -void Component::doc(const std::string& description) { - LOG_SCOPE_FUNCTION(INFO); - m_doc_ = description; -} - -auto Component::getDoc() const -> std::string { - LOG_SCOPE_FUNCTION(INFO); - return m_doc_; -} - -void Component::defClassConversion( - const std::shared_ptr& conversion) { - LOG_SCOPE_FUNCTION(INFO); - m_TypeConverter_->addConversion(conversion); -} - -auto Component::hasVariable(const std::string& name) const -> bool { - LOG_SCOPE_FUNCTION(INFO); - return m_VariableManager_->has(name); -} - -auto Component::getVariableDescription(const std::string& name) const - -> std::string { - LOG_SCOPE_FUNCTION(INFO); - return m_VariableManager_->getDescription(name); -} - -auto Component::getVariableAlias(const std::string& name) const -> std::string { - LOG_SCOPE_FUNCTION(INFO); - return m_VariableManager_->getAlias(name); -} - -auto Component::getVariableGroup(const std::string& name) const -> std::string { - LOG_SCOPE_FUNCTION(INFO); - return m_VariableManager_->getGroup(name); -} - -auto Component::getVariableNames() const -> std::vector { - LOG_SCOPE_FUNCTION(INFO); - return m_VariableManager_->getAllVariables(); -} diff --git a/src/atom/components/component.hpp b/src/atom/components/component.hpp deleted file mode 100644 index 68fc094f..00000000 --- a/src/atom/components/component.hpp +++ /dev/null @@ -1,756 +0,0 @@ -/* - * component.hpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2023-12-26 - -Description: Basic Component Definition - -**************************************************/ - -#ifndef ATOM_COMPONENT_HPP -#define ATOM_COMPONENT_HPP - -#include -#include -#include -#include -#include - -#include "dispatch.hpp" -#include "module_macro.hpp" -#include "var.hpp" - -#include "atom/function/concept.hpp" -#include "atom/function/constructor.hpp" -#include "atom/function/conversion.hpp" -#include "atom/function/func_traits.hpp" -#include "atom/function/type_caster.hpp" -#include "atom/function/type_info.hpp" -#include "atom/log/loguru.hpp" -#include "atom/type/pointer.hpp" - -class Component : public std::enable_shared_from_this { -public: - /** - * @brief Type definition for initialization function. - */ - using InitFunc = std::function; - - /** - * @brief Type definition for cleanup function. - */ - using CleanupFunc = std::function; - - /** - * @brief Constructs a new Component object. - */ - explicit Component(std::string name); - - /** - * @brief Destroys the Component object. - */ - virtual ~Component() = default; - - // ------------------------------------------------------------------- - // Inject methods - // ------------------------------------------------------------------- - - auto getInstance() const -> std::weak_ptr; - - auto getSharedInstance() -> std::shared_ptr { - return shared_from_this(); - } - - // ------------------------------------------------------------------- - // Common methods - // ------------------------------------------------------------------- - - /** - * @brief Initializes the plugin. - * - * @return True if the plugin was initialized successfully, false otherwise. - * @note This function is called by the server when the plugin is loaded. - * @note This function should be overridden by the plugin. - */ - virtual auto initialize() -> bool; - - /** - * @brief Destroys the plugin. - * - * @return True if the plugin was destroyed successfully, false otherwise. - * @note This function is called by the server when the plugin is unloaded. - * @note This function should be overridden by the plugin. - * @note The plugin should not be used after this function is called. - * @note This is for the plugin to release any resources it has allocated. - */ - virtual auto destroy() -> bool; - - /** - * @brief Gets the name of the plugin. - * - * @return The name of the plugin. - */ - auto getName() const -> std::string; - - /** - * @brief Gets the type information of the plugin. - * - * @return The type information of the plugin. - */ - auto getTypeInfo() const -> atom::meta::TypeInfo; - - /** - * @brief Sets the type information of the plugin. - * - * @param typeInfo The type information of the plugin. - */ - void setTypeInfo(atom::meta::TypeInfo typeInfo); - - // ------------------------------------------------------------------- - // Variable methods - // ------------------------------------------------------------------- - - /** - * @brief Adds a variable to the component. - * @param name The name of the variable. - * @param initialValue The initial value of the variable. - * @param description The description of the variable. - * @param alias The alias of the variable. - * @param group The group of the variable. - */ - template - void addVariable(const std::string& name, T initialValue, - const std::string& description = "", - const std::string& alias = "", - const std::string& group = "") { - m_VariableManager_->addVariable(name, initialValue, description, alias, - group); - } - - /** - * @brief Sets the range of a variable. - * @param name The name of the variable. - * @param min The minimum value of the variable. - * @param max The maximum value of the variable. - */ - template - void setRange(const std::string& name, T min, T max) { - m_VariableManager_->setRange(name, min, max); - } - - /** - * @brief Sets the options of a variable. - * @param name The name of the variable. - * @param options The options of the variable. - */ - void setStringOptions(const std::string& name, - const std::vector& options) { - m_VariableManager_->setStringOptions(name, options); - } - - /** - * @brief Gets a variable by name. - * @param name The name of the variable. - * @return A shared pointer to the variable. - */ - template - auto getVariable(const std::string& name) -> std::shared_ptr> { - return m_VariableManager_->getVariable(name); - } - - /** - * @brief Gets a variable by name. - * @param name The name of the variable. - * @return A shared pointer to the variable. - */ - [[nodiscard]] auto hasVariable(const std::string& name) const -> bool; - - /** - * @brief Sets the value of a variable. - * @param name The name of the variable. - * @param newValue The new value of the variable. - * @note const char * is not equivalent to std::string, please use - * std::string - */ - template - void setValue(const std::string& name, T newValue) { - m_VariableManager_->setValue(name, newValue); - } - - /** - * @brief Gets the value of a variable. - * @param name The name of the variable. - * @return The value of the variable. - */ - auto getVariableNames() const -> std::vector; - - /** - * @brief Gets the description of a variable. - * @param name The name of the variable. - * @return The description of the variable. - */ - auto getVariableDescription(const std::string& name) const -> std::string; - - /** - * @brief Gets the alias of a variable. - * @param name The name of the variable. - * @return The alias of the variable. - */ - auto getVariableAlias(const std::string& name) const -> std::string; - - /** - * @brief Gets the group of a variable. - * @param name The name of the variable. - * @return The group of the variable. - */ - auto getVariableGroup(const std::string& name) const -> std::string; - - // ------------------------------------------------------------------- - // Function methods - // ------------------------------------------------------------------- - - void doc(const std::string& description); - - auto getDoc() const -> std::string; - - // ------------------------------------------------------------------- - // No Class - // ------------------------------------------------------------------- - - template - void def(const std::string& name, Callable&& func, - const std::string& group = "", - const std::string& description = ""); - - template - void def(const std::string& name, Ret (*func)(), - const std::string& group = "", - const std::string& description = ""); - - template - void def(const std::string& name, Ret (*func)(Args...), - const std::string& group = "", - const std::string& description = ""); - - // ------------------------------------------------------------------- - // Without instance - // ------------------------------------------------------------------- - -#define DEF_MEMBER_FUNC(cv_qualifier) \ - template \ - void def( \ - const std::string& name, Ret (Class::*func)(Args...) cv_qualifier, \ - const std::string& group = "", const std::string& description = ""); - - DEF_MEMBER_FUNC() // Non-const, non-volatile - DEF_MEMBER_FUNC(const) // Const - DEF_MEMBER_FUNC(volatile) // Volatile - DEF_MEMBER_FUNC(const volatile) // Const volatile - DEF_MEMBER_FUNC(noexcept) - DEF_MEMBER_FUNC(const noexcept) - DEF_MEMBER_FUNC(const volatile noexcept) - - template - void def(const std::string& name, VarType Class::*var, - const std::string& group = "", - const std::string& description = ""); - - // ------------------------------------------------------------------- - // With instance - // ------------------------------------------------------------------- - - template - requires Pointer || SmartPointer || - std::is_same_v> - void def(const std::string& name, Ret (Class::*func)(), - const InstanceType& instance, const std::string& group = "", - const std::string& description = ""); - -#define DEF_MEMBER_FUNC_WITH_INSTANCE(cv_qualifier) \ - template \ - requires Pointer || SmartPointer || \ - std::is_same_v> \ - void def(const std::string& name, \ - Ret (Class::*func)(Args...) cv_qualifier, \ - const InstanceType& instance, const std::string& group = "", \ - const std::string& description = ""); - - DEF_MEMBER_FUNC_WITH_INSTANCE() - DEF_MEMBER_FUNC_WITH_INSTANCE(const) - DEF_MEMBER_FUNC_WITH_INSTANCE(volatile) - DEF_MEMBER_FUNC_WITH_INSTANCE(const volatile) - DEF_MEMBER_FUNC_WITH_INSTANCE(noexcept) - DEF_MEMBER_FUNC_WITH_INSTANCE(const noexcept) - DEF_MEMBER_FUNC_WITH_INSTANCE(const volatile noexcept) - - template - requires Pointer || SmartPointer || - std::is_same_v> - void def(const std::string& name, MemberType Class::*var, - const InstanceType& instance, const std::string& group = "", - const std::string& description = ""); - - template - requires Pointer || SmartPointer || - std::is_same_v> - void def(const std::string& name, const MemberType Class::*var, - const InstanceType& instance, const std::string& group = "", - const std::string& description = ""); - - template - requires Pointer || SmartPointer || - std::is_same_v> - void def(const std::string& name, Ret (Class::*getter)() const, - void (Class::*setter)(Ret), const InstanceType& instance, - const std::string& group, const std::string& description); - - // Register a static member variable - template - void def(const std::string& name, MemberType* var, - const std::string& group = "", - const std::string& description = ""); - - // Register a const & static member variable - template - void def(const std::string& name, const MemberType* var, - const std::string& group = "", - const std::string& description = ""); - - template - void def(const std::string& name, const std::string& group = "", - const std::string& description = ""); - - template - void def(const std::string& name, const std::string& group = "", - const std::string& description = ""); - - template - void defConstructor(const std::string& name, const std::string& group = "", - const std::string& description = ""); - - template - void defDefaultConstructor(const std::string& name, - const std::string& group = "", - const std::string& description = ""); - - template - void defType(std::string_view name, const std::string& group = "", - const std::string& description = ""); - - template - void defEnum(const std::string& name, - const std::unordered_map& enumMap); - - template - void defConversion(std::function func); - - template - void defBaseClass(); - - void defClassConversion( - const std::shared_ptr& conversion); - - void addAlias(const std::string& name, const std::string& alias) const; - - void addGroup(const std::string& name, const std::string& group) const; - - void setTimeout(const std::string& name, - std::chrono::milliseconds timeout) const; - - template - auto dispatch(const std::string& name, Args&&... args) -> std::any { - return m_CommandDispatcher_->dispatch(name, - std::forward(args)...); - } - - auto dispatch(const std::string& name, - const std::vector& args) const -> std::any { - return m_CommandDispatcher_->dispatch(name, args); - } - - [[nodiscard]] auto has(const std::string& name) const -> bool; - - [[nodiscard]] auto hasType(std::string_view name) const -> bool; - - template - [[nodiscard]] auto hasConversion() const -> bool; - - void removeCommand(const std::string& name) const; - - auto getCommandsInGroup(const std::string& group) const - -> std::vector; - - auto getCommandDescription(const std::string& name) const -> std::string; - - auto getCommandArgAndReturnType(const std::string& name) - -> std::pair, std::string>; - -#if ENABLE_FASTHASH - emhash::HashSet getCommandAliases( - const std::string& name) const; -#else - auto getCommandAliases(const std::string& name) const - -> std::unordered_set; -#endif - - auto getAllCommands() const -> std::vector; - - auto getRegisteredTypes() const -> std::vector; - - // ------------------------------------------------------------------- - // Other Components methods - // ------------------------------------------------------------------- - /** - * @note This method is not thread-safe. And we must make sure the pointer - * is valid. The PointerSentinel will help you to avoid this problem. We - * will directly get the std::weak_ptr from the pointer. - */ - - /** - * @return The names of the components that are needed by this component. - * @note This will be called when the component is initialized. - */ - static auto getNeededComponents() -> std::vector; - - void addOtherComponent(const std::string& name, - const std::weak_ptr& component); - - void removeOtherComponent(const std::string& name); - - void clearOtherComponents(); - - auto getOtherComponent(const std::string& name) -> std::weak_ptr; - - auto runCommand(const std::string& name, - const std::vector& args) -> std::any; - - InitFunc initFunc; /**< The initialization function for the component. */ - CleanupFunc cleanupFunc; /**< The cleanup function for the component. */ - -private: - std::string m_name_; - std::string m_doc_; - std::string m_configPath_; - std::string m_infoPath_; - atom::meta::TypeInfo m_typeInfo_{atom::meta::userType()}; - std::unordered_map m_classes_; - - ///< managing commands. - std::shared_ptr m_VariableManager_{ - std::make_shared()}; ///< The variable registry for - ///< managing variables. - - std::unordered_map> - m_OtherComponents_; - - std::shared_ptr m_TypeCaster_{ - atom::meta::TypeCaster::createShared()}; - std::shared_ptr m_TypeConverter_{ - atom::meta::TypeConversions::createShared()}; - - std::shared_ptr m_CommandDispatcher_{ - std::make_shared( - m_TypeCaster_)}; ///< The command dispatcher for -}; - -template -auto Component::hasConversion() const -> bool { - if constexpr (std::is_same_v) { - return true; - } - return m_TypeConverter_->canConvert( - atom::meta::userType(), - atom::meta::userType()); -} - -template -void Component::defType(std::string_view name, - [[maybe_unused]] const std::string& group, - [[maybe_unused]] const std::string& description) { - m_classes_[name] = atom::meta::userType(); - m_TypeCaster_->registerType(std::string(name)); -} - -template -void Component::defConversion(std::function func) { - static_assert(!std::is_same_v, - "SourceType and DestinationType must be not the same"); - m_TypeCaster_->registerConversion(func); -} - -template -void Component::defBaseClass() { - static_assert(std::is_base_of_v, - "Derived must be derived from Base"); - m_TypeConverter_->addBaseClass(); -} - -template -void Component::def(const std::string& name, Callable&& func, - const std::string& group, const std::string& description) { - using Traits = atom::meta::FunctionTraits>; - using ReturnType = typename Traits::return_type; - static_assert(Traits::arity <= 8, "Too many arguments"); -// clang-format off - #include "component.template" -// clang-format on -} - -template -void Component::def(const std::string& name, Ret (*func)(), - const std::string& group, const std::string& description) { - m_CommandDispatcher_->def(name, group, description, - std::function(func)); -} - -template -void Component::def(const std::string& name, Ret (*func)(Args...), - const std::string& group, const std::string& description) { - m_CommandDispatcher_->def(name, group, description, - std::function([func](Args... args) { - return func(std::forward(args)...); - })); -} - -#define DEF_MEMBER_FUNC_IMPL(cv_qualifier) \ - template \ - void Component::def( \ - const std::string& name, Ret (Class::*func)(Args...) cv_qualifier, \ - const std::string& group, const std::string& description) { \ - auto boundFunc = atom::meta::bindMemberFunction(func); \ - m_CommandDispatcher_->def( \ - name, group, description, \ - std::function, Args...)>( \ - [boundFunc](std::reference_wrapper instance, \ - Args... args) -> Ret { \ - return boundFunc(instance, std::forward(args)...); \ - })); \ - } - -DEF_MEMBER_FUNC_IMPL() -DEF_MEMBER_FUNC_IMPL(const) -DEF_MEMBER_FUNC_IMPL(volatile) -DEF_MEMBER_FUNC_IMPL(const volatile) -DEF_MEMBER_FUNC_IMPL(noexcept) -DEF_MEMBER_FUNC_IMPL(const noexcept) -DEF_MEMBER_FUNC_IMPL(const volatile noexcept) - -template - requires Pointer || SmartPointer || - std::is_same_v> -void Component::def(const std::string& name, Ret (Class::*func)(), - const InstanceType& instance, const std::string& group, - const std::string& description) { - m_CommandDispatcher_->def(name, group, description, - std::function([instance, func]() { - return std::invoke(func, instance.get()); - })); -} - -template - requires Pointer || SmartPointer || - std::is_same_v> -void Component::def(const std::string& name, Ret (Class::*func)(Args...), - const InstanceType& instance, const std::string& group, - const std::string& description) { - m_CommandDispatcher_->def( - name, group, description, - std::function([instance, func](Args... args) { - return std::invoke(func, instance.get(), - std::forward(args)...); - })); -} - -template - requires Pointer || SmartPointer || - std::is_same_v> -void Component::def(const std::string& name, Ret (Class::*func)(Args...) const, - const InstanceType& instance, const std::string& group, - const std::string& description) { - if constexpr (std::is_same_v>) { - m_CommandDispatcher_->def( - name, group, description, - std::function([&instance, func](Args... args) { - return std::invoke(func, instance.get(), - std::forward(args)...); - })); - - } else if constexpr (SmartPointer || - std::is_same_v>) { - m_CommandDispatcher_->def( - name, group, description, - std::function([instance, func](Args... args) { - return std::invoke(func, instance.get(), - std::forward(args)...); - })); - } else { - m_CommandDispatcher_->def( - name, group, description, - std::function([instance, func](Args... args) { - return std::invoke(func, instance, std::forward(args)...); - })); - } -} - -template - requires Pointer || SmartPointer || - std::is_same_v> -void Component::def(const std::string& name, - Ret (Class::*func)(Args...) noexcept, - const InstanceType& instance, const std::string& group, - const std::string& description) { - m_CommandDispatcher_->def( - name, group, description, - std::function([instance, func](Args... args) { - return std::invoke(func, instance.get(), - std::forward(args)...); - })); -} - -template - requires Pointer || SmartPointer || - std::is_same_v> -void Component::def(const std::string& name, - Ret (Class::*func)(Args...) const noexcept, - const InstanceType& instance, const std::string& group, - const std::string& description) { - m_CommandDispatcher_->def( - name, group, description, - std::function([instance, func](Args... args) { - return std::invoke(func, instance.get(), - std::forward(args)...); - })); -} - -template - requires Pointer || SmartPointer || - std::is_same_v> -void Component::def(const std::string& name, MemberType Class::*member_var, - const InstanceType& instance, const std::string& group, - const std::string& description) { - m_CommandDispatcher_->def( - "get_" + name, group, "Get " + description, - std::function([instance, member_var]() { - return atom::meta::bindMemberVariable(member_var)(*instance); - })); - m_CommandDispatcher_->def( - "set_" + name, group, "Set " + description, - std::function( - [instance, member_var](MemberType value) { - atom::meta::bindMemberVariable(member_var)(*instance) = value; - })); -} - -template - requires Pointer || SmartPointer || - std::is_same_v> -void Component::def(const std::string& name, - const MemberType Class::*member_var, - const InstanceType& instance, const std::string& group, - const std::string& description) { - m_CommandDispatcher_->def( - "get_" + name, group, "Get " + description, - std::function([instance, member_var]() { - return atom::meta::bindMemberVariable(member_var)(*instance); - })); -} - -template - requires Pointer || SmartPointer || - std::is_same_v> -void Component::def(const std::string& name, Ret (Class::*getter)() const, - void (Class::*setter)(Ret), const InstanceType& instance, - const std::string& group, const std::string& description) { - m_CommandDispatcher_->def("get_" + name, group, "Get " + description, - std::function([instance, getter]() { - return std::invoke(getter, instance.get()); - })); - m_CommandDispatcher_->def( - "set_" + name, group, "Set " + description, - std::function([instance, setter](Ret value) { - std::invoke(setter, instance.get(), value); - })); -} - -template -void Component::def(const std::string& name, MemberType* member_var, - const std::string& group, const std::string& description) { - m_CommandDispatcher_->def( - name, group, description, - std::function( - [member_var]() -> MemberType& { return *member_var; })); -} - -template -void Component::def(const std::string& name, const MemberType* member_var, - const std::string& group, const std::string& description) { - m_CommandDispatcher_->def( - name, group, description, - std::function( - [member_var]() -> const MemberType& { return *member_var; })); -} - -template -void Component::defConstructor(const std::string& name, - const std::string& group, - const std::string& description) { - m_CommandDispatcher_->def(name, group, description, - std::function(Args...)>( - atom::meta::constructor())); -} - -template -void Component::defEnum( - const std::string& name, - const std::unordered_map& enumMap) { - m_TypeCaster_->registerType(std::string(name)); - - for (const auto& [key, value] : enumMap) { - m_TypeCaster_->registerEnumValue(name, key, value); - } - - defConversion( - [this, name](const std::any& enumValue) -> std::any { - const EnumType& value = std::any_cast(enumValue); - return m_TypeCaster_->enumToString(value, name); - }); - - defConversion( - [this, name](const std::any& strValue) -> std::any { - const std::string& value = std::any_cast(strValue); - return m_TypeCaster_->stringToEnum(value, name); - }); -} - -template -void Component::defDefaultConstructor(const std::string& name, - const std::string& group, - const std::string& description) { - m_CommandDispatcher_->def( - name, group, description, - std::function()>([]() -> std::shared_ptr { - return std::make_shared(); - })); -} - -template -void Component::def(const std::string& name, const std::string& group, - const std::string& description) { - auto constructor = atom::meta::defaultConstructor(); - def(name, constructor, group, description); -} - -template -void Component::def(const std::string& name, const std::string& group, - const std::string& description) { - auto constructor_ = atom::meta::constructor(); - def(name, constructor_, group, description); -} - -#endif diff --git a/src/atom/components/component.template b/src/atom/components/component.template deleted file mode 100644 index 2df33d58..00000000 --- a/src/atom/components/component.template +++ /dev/null @@ -1,98 +0,0 @@ -if constexpr (Traits::arity == 0) { - - m_CommandDispatcher_->def( - name, group, description, - std::function(std::forward(func))); -} - - -if constexpr (Traits::arity == 1) { - using ArgType_0 = typename Traits::template argument_t<0>; - m_CommandDispatcher_->def( - name, group, description, - std::function(std::forward(func))); -} - - -if constexpr (Traits::arity == 2) { - using ArgType_0 = typename Traits::template argument_t<0>; - using ArgType_1 = typename Traits::template argument_t<1>; - m_CommandDispatcher_->def( - name, group, description, - std::function(std::forward(func))); -} - - -if constexpr (Traits::arity == 3) { - using ArgType_0 = typename Traits::template argument_t<0>; - using ArgType_1 = typename Traits::template argument_t<1>; - using ArgType_2 = typename Traits::template argument_t<2>; - m_CommandDispatcher_->def( - name, group, description, - std::function(std::forward(func))); -} - - -if constexpr (Traits::arity == 4) { - using ArgType_0 = typename Traits::template argument_t<0>; - using ArgType_1 = typename Traits::template argument_t<1>; - using ArgType_2 = typename Traits::template argument_t<2>; - using ArgType_3 = typename Traits::template argument_t<3>; - m_CommandDispatcher_->def( - name, group, description, - std::function(std::forward(func))); -} - - -if constexpr (Traits::arity == 5) { - using ArgType_0 = typename Traits::template argument_t<0>; - using ArgType_1 = typename Traits::template argument_t<1>; - using ArgType_2 = typename Traits::template argument_t<2>; - using ArgType_3 = typename Traits::template argument_t<3>; - using ArgType_4 = typename Traits::template argument_t<4>; - m_CommandDispatcher_->def( - name, group, description, - std::function(std::forward(func))); -} - - -if constexpr (Traits::arity == 6) { - using ArgType_0 = typename Traits::template argument_t<0>; - using ArgType_1 = typename Traits::template argument_t<1>; - using ArgType_2 = typename Traits::template argument_t<2>; - using ArgType_3 = typename Traits::template argument_t<3>; - using ArgType_4 = typename Traits::template argument_t<4>; - using ArgType_5 = typename Traits::template argument_t<5>; - m_CommandDispatcher_->def( - name, group, description, - std::function(std::forward(func))); -} - - -if constexpr (Traits::arity == 7) { - using ArgType_0 = typename Traits::template argument_t<0>; - using ArgType_1 = typename Traits::template argument_t<1>; - using ArgType_2 = typename Traits::template argument_t<2>; - using ArgType_3 = typename Traits::template argument_t<3>; - using ArgType_4 = typename Traits::template argument_t<4>; - using ArgType_5 = typename Traits::template argument_t<5>; - using ArgType_6 = typename Traits::template argument_t<6>; - m_CommandDispatcher_->def( - name, group, description, - std::function(std::forward(func))); -} - - -if constexpr (Traits::arity == 8) { - using ArgType_0 = typename Traits::template argument_t<0>; - using ArgType_1 = typename Traits::template argument_t<1>; - using ArgType_2 = typename Traits::template argument_t<2>; - using ArgType_3 = typename Traits::template argument_t<3>; - using ArgType_4 = typename Traits::template argument_t<4>; - using ArgType_5 = typename Traits::template argument_t<5>; - using ArgType_6 = typename Traits::template argument_t<6>; - using ArgType_7 = typename Traits::template argument_t<7>; - m_CommandDispatcher_->def( - name, group, description, - std::function(std::forward(func))); -} diff --git a/src/atom/components/dispatch.cpp b/src/atom/components/dispatch.cpp deleted file mode 100644 index 50e7b3c6..00000000 --- a/src/atom/components/dispatch.cpp +++ /dev/null @@ -1,286 +0,0 @@ -#include "dispatch.hpp" - -#include "atom/log/loguru.hpp" -#include "atom/utils/to_string.hpp" - -void CommandDispatcher::checkPrecondition(const Command& cmd, - const std::string& name) { - LOG_SCOPE_FUNCTION(INFO); - if (!cmd.precondition.has_value()) { - LOG_F(INFO, "No precondition for command: {}", name); - return; - } - try { - std::invoke(cmd.precondition.value()); - LOG_F(INFO, "Precondition for command '{}' passed.", name); - } catch (const std::bad_function_call& e) { - LOG_F(INFO, "Bad precondition function invoke for command '{}': {}", - name, e.what()); - } catch (const std::bad_optional_access& e) { - LOG_F(INFO, "Bad precondition function access for command '{}': {}", - name, e.what()); - } catch (const std::exception& e) { - LOG_F(ERROR, "Precondition for command '{}' failed: {}", name, - e.what()); - THROW_DISPATCH_EXCEPTION("Precondition failed for command '{}': {}", - name, e.what()); - } -} - -void CommandDispatcher::checkPostcondition(const Command& cmd, - const std::string& name) { - LOG_SCOPE_FUNCTION(INFO); - if (!cmd.postcondition.has_value()) { - LOG_F(INFO, "No postcondition for command: {}", name); - return; - } - try { - std::invoke(cmd.postcondition.value()); - LOG_F(INFO, "Postcondition for command '{}' passed.", name); - } catch (const std::bad_function_call& e) { - LOG_F(INFO, "Bad postcondition function invoke for command '{}': {}", - name, e.what()); - } catch (const std::bad_optional_access& e) { - LOG_F(INFO, "Bad postcondition function access for command '{}': {}", - name, e.what()); - } catch (const std::exception& e) { - LOG_F(ERROR, "Postcondition for command '{}' failed: {}", name, - e.what()); - THROW_DISPATCH_EXCEPTION("Postcondition failed for command '{}': {}", - name, e.what()); - } -} - -auto CommandDispatcher::executeCommand( - const Command& cmd, const std::string& name, - const std::vector& args) -> std::any { - LOG_SCOPE_FUNCTION(INFO); - if (auto timeoutIt = timeoutMap_.find(name); - timeoutIt != timeoutMap_.end()) { - LOG_F(INFO, "Executing command '{}' with timeout.", name); - return executeWithTimeout(cmd, name, args, timeoutIt->second); - } - LOG_F(INFO, "Executing command '{}' without timeout.", name); - return executeWithoutTimeout(cmd, name, args); -} - -auto CommandDispatcher::executeWithTimeout( - const Command& cmd, const std::string& name, - const std::vector& args, - const std::chrono::duration& timeout) -> std::any { - LOG_SCOPE_FUNCTION(INFO); - auto future = std::async(std::launch::async, - [&]() { return executeFunctions(cmd, args); }); - - if (future.wait_for(timeout) == std::future_status::timeout) { - LOG_F(ERROR, "Command '{}' timed out.", name); - THROW_DISPATCH_TIMEOUT("Command '{}' timed out.", name); - } - - return future.get(); -} - -auto CommandDispatcher::executeWithoutTimeout( - const Command& cmd, const std::string& name, - const std::vector& args) -> std::any { - LOG_SCOPE_FUNCTION(INFO); - if (!args.empty()) { - if (args.size() == 1 && - args[0].type() == typeid(std::vector)) { - LOG_F(INFO, "Executing command '{}' with nested arguments.", name); - return executeFunctions( - cmd, std::any_cast>(args[0])); - } - } - - LOG_F(INFO, "Executing command '{}' with arguments.", name); - return executeFunctions(cmd, args); -} - -auto CommandDispatcher::executeFunctions( - const Command& cmd, const std::vector& args) -> std::any { - LOG_SCOPE_FUNCTION(INFO); - if (std::string funcHash = computeFunctionHash(args); - cmd.hash == funcHash) { - try { - LOG_F(INFO, "Executing function for command with hash: {}", - funcHash); - return std::invoke(cmd.func, args); - } catch (const std::bad_any_cast&) { - LOG_F(ERROR, "Failed to call function for command with hash: {}", - funcHash); - THROW_DISPATCH_EXCEPTION( - "Failed to call function for command with hash {}", funcHash); - } - } - - LOG_F(ERROR, "No matching overload found for command"); - THROW_INVALID_ARGUMENT("No matching overload found for command "); -} - -auto CommandDispatcher::computeFunctionHash(const std::vector& args) - -> std::string { - LOG_SCOPE_FUNCTION(INFO); - std::vector argTypes; - argTypes.reserve(args.size()); - for (const auto& arg : args) { - argTypes.emplace_back( - atom::meta::DemangleHelper::demangle(arg.type().name())); - } - auto hash = atom::utils::toString(atom::algorithm::computeHash(argTypes)); - LOG_F(INFO, "Computed function hash: {}", hash); - return hash; -} - -auto CommandDispatcher::has(const std::string& name) const -> bool { - LOG_SCOPE_FUNCTION(INFO); - if (commands_.find(name) != commands_.end()) { - LOG_F(INFO, "Command '{}' found.", name); - return true; - } - for (const auto& command : commands_) { - if (command.second.aliases.find(name) != command.second.aliases.end()) { - LOG_F(INFO, "Alias '{}' found for command '{}'.", name, - command.first); - return true; - } - } - LOG_F(INFO, "Command '{}' not found.", name); - return false; -} - -void CommandDispatcher::addAlias(const std::string& name, - const std::string& alias) { - LOG_SCOPE_FUNCTION(INFO); - auto it = commands_.find(name); - if (it != commands_.end()) { - it->second.aliases.insert(alias); - commands_[alias] = it->second; - groupMap_[alias] = groupMap_[name]; - LOG_F(INFO, "Alias '{}' added for command '{}'.", alias, name); - } else { - LOG_F(WARNING, "Command '{}' not found. Alias '{}' not added.", name, - alias); - } -} - -void CommandDispatcher::addGroup(const std::string& name, - const std::string& group) { - LOG_SCOPE_FUNCTION(INFO); - groupMap_[name] = group; - LOG_F(INFO, "Command '{}' added to group '{}'.", name, group); -} - -void CommandDispatcher::setTimeout(const std::string& name, - std::chrono::milliseconds timeout) { - LOG_SCOPE_FUNCTION(INFO); - timeoutMap_[name] = timeout; - LOG_F(INFO, "Timeout set for command '{}': {} ms.", name, timeout.count()); -} - -void CommandDispatcher::removeCommand(const std::string& name) { - LOG_SCOPE_FUNCTION(INFO); - commands_.erase(name); - groupMap_.erase(name); - timeoutMap_.erase(name); - LOG_F(INFO, "Command '{}' removed.", name); -} - -auto CommandDispatcher::getCommandsInGroup(const std::string& group) const - -> std::vector { - LOG_SCOPE_FUNCTION(INFO); - std::vector result; - for (const auto& pair : groupMap_) { - if (pair.second == group) { - result.push_back(pair.first); - } - } - LOG_F(INFO, "Commands in group '{}': {}", group, - atom::utils::toString(result)); - return result; -} - -auto CommandDispatcher::getCommandDescription(const std::string& name) const - -> std::string { - LOG_SCOPE_FUNCTION(INFO); - auto it = commands_.find(name); - if (it != commands_.end()) { - LOG_F(INFO, "Description for command '{}': {}", name, - it->second.description); - return it->second.description; - } - LOG_F(INFO, "No description found for command '{}'.", name); - return ""; -} - -auto CommandDispatcher::getCommandAliases(const std::string& name) const - -> std::unordered_set { - LOG_SCOPE_FUNCTION(INFO); - auto it = commands_.find(name); - if (it != commands_.end()) { - LOG_F(INFO, "Aliases for command '{}': {}", name, - atom::utils::toString(it->second.aliases)); - return it->second.aliases; - } - LOG_F(INFO, "No aliases found for command '{}'.", name); - return {}; -} - -auto CommandDispatcher::dispatch( - const std::string& name, const std::vector& args) -> std::any { - LOG_SCOPE_FUNCTION(INFO); - LOG_F(INFO, "Dispatching command '{}'.", name); - return dispatchHelper(name, args); -} - -auto CommandDispatcher::dispatch(const std::string& name, - const atom::meta::FunctionParams& params) - -> std::any { - LOG_SCOPE_FUNCTION(INFO); - LOG_F(INFO, "Dispatching command '{}' with FunctionParams.", name); - return dispatchHelper(name, params.toAnyVector()); -} - -auto CommandDispatcher::getAllCommands() const -> std::vector { - LOG_SCOPE_FUNCTION(INFO); - std::vector result; - result.reserve(commands_.size()); - for (const auto& pair : commands_) { - result.push_back(pair.first); - } - for (const auto& command : commands_) { - for (const auto& alias : command.second.aliases) { - result.push_back(alias); - } - } - auto it = std::unique(result.begin(), result.end()); - result.erase(it, result.end()); - LOG_F(INFO, "All commands: {}", atom::utils::toString(result)); - return result; -} - -namespace atom::utils { - auto toString(const std::vector &arg) -> std::string{ - std::string result; - for (const auto& a : arg) { - result += a.getName() + " "; - } - return result; - } -} - -auto CommandDispatcher::getCommandArgAndReturnType(const std::string& name) - -> std::pair, std::string> { - LOG_SCOPE_FUNCTION(INFO); - auto it = commands_.find(name); - if (it != commands_.end()) { - LOG_F(INFO, - "Argument and return types for command '{}': args = [{}], return " - "= {}", - name, atom::utils::toString(it->second.argTypes), - it->second.returnType); - return {it->second.argTypes, it->second.returnType}; - } - LOG_F(INFO, "No argument and return types found for command '{}'.", name); - return {{}, ""}; -} diff --git a/src/atom/components/dispatch.hpp b/src/atom/components/dispatch.hpp deleted file mode 100644 index f1943f84..00000000 --- a/src/atom/components/dispatch.hpp +++ /dev/null @@ -1,487 +0,0 @@ -#ifndef ATOM_COMMAND_DISPATCH_HPP -#define ATOM_COMMAND_DISPATCH_HPP - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#if ENABLE_FASTHASH -#include "emhash/hash_set8.hpp" -#include "emhash/hash_table8.hpp" -#else -#include -#include -#endif - -#include "atom/error/exception.hpp" -#include "atom/function/proxy.hpp" -#include "atom/function/type_caster.hpp" -#include "atom/type/json.hpp" - -#include "atom/macro.hpp" - -using json = nlohmann::json; - -// ------------------------------------------------------------------- -// Command Exception -// ------------------------------------------------------------------- - -class DispatchException : public atom::error::Exception { -public: - using atom::error::Exception::Exception; -}; - -#define THROW_DISPATCH_EXCEPTION(...) \ - throw DispatchException(ATOM_FILE_NAME, ATOM_FILE_LINE, ATOM_FUNC_NAME, \ - __VA_ARGS__); - -class DispatchTimeout : public atom::error::Exception { -public: - using atom::error::Exception::Exception; -}; - -#define THROW_DISPATCH_TIMEOUT(...) \ - throw DispatchTimeout(ATOM_FILE_NAME, ATOM_FILE_LINE, ATOM_FUNC_NAME, \ - __VA_ARGS__); - -// ------------------------------------------------------------------- -// Command Dispatcher -// ------------------------------------------------------------------- - -/** - * @brief Manages and dispatches commands. - */ -class CommandDispatcher { -public: - /** - * @brief Constructs a CommandDispatcher with a TypeCaster. - * @param typeCaster A weak pointer to a TypeCaster. - */ - explicit CommandDispatcher(std::weak_ptr typeCaster) - : typeCaster_(std::move(typeCaster)) {} - - /** - * @brief Defines a command. - * @tparam Ret The return type of the command function. - * @tparam Args The argument types of the command function. - * @param name The name of the command. - * @param group The group of the command. - * @param description The description of the command. - * @param func The command function. - * @param precondition An optional precondition function. - * @param postcondition An optional postcondition function. - * @param arg_info Information about the arguments. - */ - template - void def(const std::string& name, const std::string& group, - const std::string& description, std::function func, - std::optional> precondition = std::nullopt, - std::optional> postcondition = std::nullopt, - std::vector arg_info = {}, bool isTimed = false); - - /** - * @brief Checks if a command exists. - * @param name The name of the command. - * @return True if the command exists, false otherwise. - */ - [[nodiscard]] auto has(const std::string& name) const -> bool; - - /** - * @brief Adds an alias for a command. - * @param name The name of the command. - * @param alias The alias for the command. - */ - void addAlias(const std::string& name, const std::string& alias); - - /** - * @brief Adds a command to a group. - * @param name The name of the command. - * @param group The group of the command. - */ - void addGroup(const std::string& name, const std::string& group); - - /** - * @brief Sets a timeout for a command. - * @param name The name of the command. - * @param timeout The timeout duration. - */ - void setTimeout(const std::string& name, std::chrono::milliseconds timeout); - - /** - * @brief Dispatches a command with arguments. - * @tparam Args The argument types. - * @param name The name of the command. - * @param args The arguments for the command. - * @return The result of the command execution. - */ - template - auto dispatch(const std::string& name, Args&&... args) -> std::any; - - /** - * @brief Dispatches a command with a vector of arguments. - * @param name The name of the command. - * @param args The arguments for the command. - * @return The result of the command execution. - */ - auto dispatch(const std::string& name, - const std::vector& args) -> std::any; - - /** - * @brief Dispatches a command with function parameters. - * @param name The name of the command. - * @param params The function parameters. - * @return The result of the command execution. - */ - auto dispatch(const std::string& name, - const atom::meta::FunctionParams& params) -> std::any; - - /** - * @brief Removes a command. - * @param name The name of the command. - */ - void removeCommand(const std::string& name); - - /** - * @brief Gets the commands in a group. - * @param group The group name. - * @return A vector of command names in the group. - */ - [[nodiscard]] auto getCommandsInGroup(const std::string& group) const - -> std::vector; - - /** - * @brief Gets the description of a command. - * @param name The name of the command. - * @return The description of the command. - */ - [[nodiscard]] auto getCommandDescription(const std::string& name) const - -> std::string; - -#if ENABLE_FASTHASH - /** - * @brief Gets the aliases of a command. - * @param name The name of the command. - * @return A set of aliases for the command. - */ - emhash::HashSet getCommandAliases( - const std::string& name) const; -#else - /** - * @brief Gets the aliases of a command. - * @param name The name of the command. - * @return A set of aliases for the command. - */ - [[nodiscard]] auto getCommandAliases(const std::string& name) const - -> std::unordered_set; -#endif - - /** - * @brief Gets all commands. - * @return A vector of all command names. - */ - [[nodiscard]] auto getAllCommands() const -> std::vector; - - [[nodiscard]] auto getCommandArgAndReturnType(const std::string& name) - -> std::pair, std::string>; - - struct Command { - std::function&)> func; - std::string returnType; - std::vector argTypes; - std::string hash; - std::string description; -#if ENABLE_FASTHASH - emhash::HashSet aliases; -#else - std::unordered_set aliases; -#endif - std::optional> precondition; - std::optional> postcondition; - } ATOM_ALIGNAS(128); - -private: - /** - * @brief Helper function to dispatch a command. - * @tparam ArgsType The type of the arguments. - * @param name The name of the command. - * @param args The arguments for the command. - * @return The result of the command execution. - */ - template - auto dispatchHelper(const std::string& name, - const ArgsType& args) -> std::any; - - /** - * @brief Converts a tuple to a vector of arguments. - * @tparam Args The types of the arguments. - * @param tuple The tuple of arguments. - * @return A vector of arguments. - */ - template - auto convertToArgsVector(std::tuple&& tuple) - -> std::vector; - - /** - * @brief Finds a command by name. - * @param name The name of the command. - * @return An iterator to the command. - */ - auto findCommand(const std::string& name); - - /** - * @brief Completes the arguments for a command. - * @tparam ArgsType The type of the arguments. - * @param cmd The command. - * @param args The arguments for the command. - * @return A vector of completed arguments. - */ - template - auto completeArgs(const Command& cmd, - const ArgsType& args) -> std::vector; - - /** - * @brief Checks the precondition of a command. - * @param cmd The command. - * @param name The name of the command. - */ - static void checkPrecondition(const Command& cmd, const std::string& name); - - /** - * @brief Checks the postcondition of a command. - * @param cmd The command. - * @param name The name of the command. - */ - static void checkPostcondition(const Command& cmd, const std::string& name); - - /** - * @brief Executes a command. - * @param cmd The command. - * @param name The name of the command. - * @param args The arguments for the command. - * @return The result of the command execution. - */ - auto executeCommand(const Command& cmd, const std::string& name, - const std::vector& args) -> std::any; - - /** - * @brief Executes a command with a timeout. - * @param cmd The command. - * @param name The name of the command. - * @param args The arguments for the command. - * @param timeout The timeout duration. - * @return The result of the command execution. - */ - static auto executeWithTimeout(const Command& cmd, const std::string& name, - const std::vector& args, - const std::chrono::duration& timeout) - -> std::any; - - /** - * @brief Executes a command without a timeout. - * @param cmd The command. - * @param name The name of the command. - * @param args The arguments for the command. - * @return The result of the command execution. - */ - static auto executeWithoutTimeout(const Command& cmd, const std::string& name, - const std::vector& args) -> std::any; - - /** - * @brief Executes the functions of a command. - * @param cmd The command. - * @param args The arguments for the command. - * @return The result of the command execution. - */ - static auto executeFunctions(const Command& cmd, - const std::vector& args) -> std::any; - - /** - * @brief Computes the hash of the function arguments. - * @param args The arguments for the command. - * @return The hash of the function arguments. - */ - static auto computeFunctionHash(const std::vector& args) -> std::string; - -#if ENABLE_FASTHASH - emhash8::HashMap commands; - emhash8::HashMap groupMap; - emhash8::HashMap timeoutMap; -#else - std::unordered_map commands_; - std::unordered_map groupMap_; - std::unordered_map timeoutMap_; -#endif - - std::weak_ptr typeCaster_; -}; - -inline void to_json(json& j, const CommandDispatcher::Command& cmd) { - j = json{ - {"returnType", cmd.returnType}, - {"argTypes", cmd.argTypes}, - {"hash", cmd.hash}, - {"description", cmd.description}, - {"aliases", cmd.aliases} - }; - - if (cmd.precondition) { - j["precondition"] = true; - } else { - j["precondition"] = false; - } - - if (cmd.postcondition) { - j["postcondition"] = true; - } else { - j["postcondition"] = false; - } -} - -inline void from_json(const json& j, CommandDispatcher::Command& cmd) { - j.at("returnType").get_to(cmd.returnType); - j.at("argTypes").get_to(cmd.argTypes); - j.at("hash").get_to(cmd.hash); - j.at("description").get_to(cmd.description); - j.at("aliases").get_to(cmd.aliases); - - if (j.at("precondition").get()) { - cmd.precondition = []() { return true; }; // Placeholder function - } else { - cmd.precondition.reset(); - } - - if (j.at("postcondition").get()) { - cmd.postcondition = []() {}; // Placeholder function - } else { - cmd.postcondition.reset(); - } -} - -ATOM_INLINE auto CommandDispatcher::findCommand(const std::string& name) { - auto it = commands_.find(name); - if (it == commands_.end()) { - for (const auto& [cmdName, cmd] : commands_) { - if (std::ranges::find(cmd.aliases.begin(), cmd.aliases.end(), name) != - cmd.aliases.end()) { -#if ENABLE_DEBUG - std::cout << "Command '" << name - << "' not found, did you mean '" << cmdName << "'?\n"; -#endif - return commands_.find(cmdName); - } - } - } - return it; -} - -template -void CommandDispatcher::def(const std::string& name, const std::string& group, - const std::string& description, - std::function func, - std::optional> precondition, - std::optional> postcondition, - std::vector arg_info, - bool isTimed) { - std::function&)> wrappedFunc; - atom::meta::FunctionInfo info; - if (isTimed) { - // TODO: Custom timeout duration for each command - auto _func = atom::meta::TimerProxyFunction(std::move(func)); - info = _func.getFunctionInfo(); - wrappedFunc = - [_func](const std::vector& args) mutable -> std::any { - std::chrono::milliseconds defaultTimeout(1000); - return _func(args, defaultTimeout); - }; - } else { - auto _func = atom::meta::ProxyFunction(std::move(func)); - wrappedFunc = - [_func](const std::vector& args) mutable -> std::any { - return _func(args); - }; - } - Command cmd{{std::move(wrappedFunc)}, - {info.getReturnType()}, - arg_info, - {info.getHash()}, - description, - {}, - std::move(precondition), - std::move(postcondition)}; - commands_[name] = std::move(cmd); - groupMap_[name] = group; -} - -template -auto CommandDispatcher::dispatch(const std::string& name, - Args&&... args) -> std::any { - auto argsTuple = std::make_tuple(std::forward(args)...); - auto argsVec = convertToArgsVector(std::move(argsTuple)); - return dispatchHelper(name, argsVec); -} - -template -auto CommandDispatcher::convertToArgsVector(std::tuple&& tuple) - -> std::vector { - std::vector argsVec; - argsVec.reserve(sizeof...(Args)); - std::apply( - [&argsVec](auto&&... args) { - ((argsVec.emplace_back(std::forward(args))), ...); - }, - std::move(tuple)); - return argsVec; -} - -template -auto CommandDispatcher::dispatchHelper(const std::string& name, - const ArgsType& args) -> std::any { - auto it = findCommand(name); - if (it == commands_.end()) { - THROW_INVALID_ARGUMENT("Unknown command: " + name); - } - - const auto& cmd = it->second; - std::vector fullArgs; - fullArgs = completeArgs(cmd, args); - - if constexpr (std::is_same_v>) { - auto it1 = args.begin(); - auto it2 = cmd.argTypes.begin(); - for (; it1 != args.end() && it2 != cmd.argTypes.end(); ++it1, ++it2) { - } - } - - checkPrecondition(cmd, name); - - auto result = executeCommand(cmd, name, fullArgs); - - checkPostcondition(cmd, name); - - return result; -} - -template -auto CommandDispatcher::completeArgs(const Command& cmd, const ArgsType& args) - -> std::vector { - std::vector fullArgs(args.begin(), args.end()); - for (size_t i = args.size(); i < cmd.argTypes.size(); ++i) { - if (cmd.argTypes[i].getDefaultValue()) { - fullArgs.push_back(cmd.argTypes[i].getDefaultValue().value()); - } else { - THROW_INVALID_ARGUMENT("Missing argument: " + - cmd.argTypes[i].getName()); - } - } - return fullArgs; -} - -#endif diff --git a/src/atom/components/module_macro.hpp b/src/atom/components/module_macro.hpp deleted file mode 100644 index f8279763..00000000 --- a/src/atom/components/module_macro.hpp +++ /dev/null @@ -1,155 +0,0 @@ -// Helper macros for registering initializers, dependencies, and modules -#ifndef REGISTER_INITIALIZER -#define REGISTER_INITIALIZER(name, init_func, cleanup_func) \ - namespace { \ - struct Initializer_##name { \ - Initializer_##name() { \ - LOG_F(INFO, "Registering initializer: {}", #name); \ - Registry::instance().addInitializer(#name, init_func, \ - cleanup_func); \ - } \ - }; \ - static Initializer_##name initializer_##name; \ - } -#endif - -#ifndef REGISTER_DEPENDENCY -#define REGISTER_DEPENDENCY(name, dependency) \ - namespace { \ - struct Dependency_##name { \ - Dependency_##name() { \ - LOG_F(INFO, "Registering dependency: {} -> {}", #name, \ - #dependency); \ - Registry::instance().addDependency(#name, #dependency); \ - } \ - }; \ - static Dependency_##name dependency_##name; \ - } -#endif - -// Nested macro for module initialization -#ifndef ATOM_MODULE_INIT -#define ATOM_MODULE_INIT(module_name, init_func) \ - namespace module_name { \ - struct ModuleManager { \ - static void init() { \ - LOG_F(INFO, "Initializing module: {}", #module_name); \ - Registry::instance().registerModule(#module_name, init_func); \ - Registry::instance().addInitializer(#module_name, init_func); \ - Registry::instance().initializeAll(); \ - } \ - static void cleanup() { \ - static std::once_flag flag; \ - std::call_once(flag, []() { \ - LOG_F(INFO, "Cleaning up module: {}", #module_name); \ - Registry::instance().cleanupAll(); \ - }); \ - } \ - }; \ - } -#endif - -// Macro for dynamic library module -#ifndef ATOM_MODULE -#define ATOM_MODULE(module_name, init_func) \ - ATOM_MODULE_INIT(module_name, init_func) \ - extern "C" void module_name##_initialize_registry() { \ - LOG_F(INFO, "Initializing registry for module: {}", #module_name); \ - module_name::ModuleManager::init(); \ - LOG_F(INFO, "Initialized registry for module: {}", #module_name); \ - } \ - extern "C" void module_name##_cleanup_registry() { \ - LOG_F(INFO, "Cleaning up registry for module: {}", #module_name); \ - module_name::ModuleManager::cleanup(); \ - LOG_F(INFO, "Cleaned up registry for module: {}", #module_name); \ - } \ - extern "C" auto module_name##_getInstance()->std::shared_ptr { \ - LOG_F(INFO, "Getting instance of module: {}", #module_name); \ - return Registry::instance().getComponent(#module_name); \ - } -#endif - -// Macro for embedded module -#ifndef ATOM_EMBED_MODULE -#define ATOM_EMBED_MODULE(module_name, init_func) \ - ATOM_MODULE_INIT(module_name, init_func) \ - namespace module_name { \ - inline std::optional init_flag; \ - struct ModuleInitializer { \ - ModuleInitializer() { \ - if (!init_flag.has_value()) { \ - LOG_F(INFO, "Embedding module: {}", #module_name); \ - init_flag.emplace(); \ - Registry::instance().registerModule(#module_name, init_func); \ - Registry::instance().addInitializer(#module_name, init_func); \ - } \ - } \ - ~ModuleInitializer() { \ - if (init_flag.has_value()) { \ - LOG_F(INFO, "Cleaning up embedded module: {}", #module_name); \ - init_flag.reset(); \ - } \ - } \ - }; \ - static ModuleInitializer module_initializer; \ - } -#endif - -// Macro for dynamic library module with test support -// Max: This means that the module is a dynamic library that can be loaded at -// runtime. -// And the test function should hava a signature like void -// test_func(std::shared_ptr instance). -#ifndef ATOM_MODULE_TEST -#define ATOM_MODULE_TEST(module_name, test_func) \ - extern "C" void module_name##_run_tests() { \ - LOG_F(INFO, "Running tests for module: {}", #module_name); \ - try { \ - test_func(module_name##_getInstance()); \ - } catch (const atom::error::ObjectNotExist& e) { \ - LOG_F(ERROR, "{} not found", #module_name); \ - } catch (const std::exception& e) { \ - LOG_F(ERROR, "Exception thrown: {} in {}'s tests", e.what(), \ - #module_name); \ - } \ - LOG_F(INFO, "Finished running tests for module: {}", #module_name); \ - } -#endif - -// Macro for embedded module with test support -#ifndef ATOM_EMBED_MODULE_TEST -#define ATOM_EMBED_MODULE_TEST(module_name, init_func, test_func) \ - ATOM_MODULE_INIT(module_name, init_func) \ - namespace module_name { \ - inline std::optional init_flag; \ - struct ModuleInitializer { \ - ModuleInitializer() { \ - if (!init_flag.has_value()) { \ - LOG_F(INFO, "Embedding module: {}", #module_name); \ - init_flag.emplace(); \ - Registry::instance().registerModule(#module_name, init_func); \ - Registry::instance().addInitializer(#module_name, init_func); \ - } \ - } \ - ~ModuleInitializer() { \ - if (init_flag.has_value()) { \ - LOG_F(INFO, "Cleaning up embedded module: {}", #module_name); \ - init_flag.reset(); \ - } \ - } \ - }; \ - static ModuleInitializer module_initializer; \ - } \ - extern "C" void run_tests() { \ - LOG_F(INFO, "Running tests for module: {}", #module_name); \ - try { \ - test_func(module_name::getInstance()); \ - } catch (const atom::error::ObjectNotExist& e) { \ - LOG_F(ERROR, "{} not found", #module_name); \ - } catch (const std::exception& e) { \ - LOG_F(ERROR, "Exception thrown: {} in {}'s tests", e.what(), \ - #module_name); \ - } \ - LOG_F(INFO, "Finished running tests for module: {}", #module_name); \ - } -#endif diff --git a/src/atom/components/package.hpp b/src/atom/components/package.hpp deleted file mode 100644 index 61e451f3..00000000 --- a/src/atom/components/package.hpp +++ /dev/null @@ -1,215 +0,0 @@ -#ifndef ATOM_COMPONENTS_PACKAGE_HPP -#define ATOM_COMPONENTS_PACKAGE_HPP - -#include -#include -#include -#include -#include -#include - -// Constants -constexpr size_t ALIGNMENT = 64; -constexpr size_t MAX_ELEMENTS = 10; - -// Enum for JSON value types -enum class JsonValueType { STRING, OBJECT, ARRAY, NUMBER, BOOLEAN, UNKNOWN }; - -// Structure for JSON key-value pair -struct alignas(ALIGNMENT) JsonKeyValue { - std::string_view key; - std::string_view value; - JsonValueType type; -}; - -// Helper function to compare strings -constexpr auto Equals(std::string_view str1, std::string_view str2) -> bool { - return str1 == str2; -} - -// Trim leading and trailing spaces -constexpr auto Trim(std::string_view str) -> std::string_view { - auto isSpace = [](char character) { - return character == ' ' || character == '\n' || character == '\t'; - }; - str.remove_prefix(std::ranges::find_if_not(str, isSpace) - str.begin()); - str.remove_suffix( - std::ranges::find_if_not(str | std::views::reverse, isSpace).base() - - str.end()); - return str; -} - -// Remove ']' from a string -constexpr auto RemoveBrackets(std::string_view str) -> std::string_view { - size_t position = 0; - while ((position = str.find(']')) != std::string_view::npos) { - str.remove_suffix(str.size() - position); - } - while ((position = str.find('{')) != std::string_view::npos || - (position = str.find('}')) != std::string_view::npos) { - str.remove_suffix(str.size() - position); - } - return str; -} - -// Parse key-value pair from a JSON line -constexpr auto ParseKeyValue(std::string_view jsonLine) - -> std::pair { - auto colonPosition = jsonLine.find(':'); - if (colonPosition == std::string_view::npos) { - return {JsonKeyValue{}, "Invalid JSON line: no colon found"}; - } - - std::string_view key = Trim(jsonLine.substr(0, colonPosition)); - std::string_view value = Trim(jsonLine.substr(colonPosition + 1)); - - // Remove quotes from key - if (key.front() == '"' && key.back() == '"') { - key.remove_prefix(1); - key.remove_suffix(1); - } - - JsonValueType type = JsonValueType::STRING; - - // Detect value type (OBJECT, ARRAY, STRING, NUMBER, BOOLEAN) - if (value.front() == '{' && value.back() == '}') { - type = JsonValueType::OBJECT; - } else if (value.front() == '[' && value.back() == ']') { - type = JsonValueType::ARRAY; - } else if (value.front() == '"' && value.back() == '"') { - value.remove_prefix(1); - value.remove_suffix(1); - } else if (value == "true" || value == "false") { - type = JsonValueType::BOOLEAN; - } else if (std::all_of(value.begin(), value.end(), ::isdigit)) { - type = JsonValueType::NUMBER; - } else { - type = JsonValueType::UNKNOWN; - } - - return {JsonKeyValue{key, value, type}, ""}; -} - -// Parse a JSON array -constexpr auto ParseArray(std::string_view arrayString) - -> std::pair, std::string> { - std::array result = {}; - size_t currentPosition = 0; - size_t elementIndex = 0; - - // Remove square brackets - arrayString.remove_prefix(1); - arrayString.remove_suffix(1); - - while (currentPosition < arrayString.size() && - elementIndex < result.size()) { - size_t nextCommaPosition = arrayString.find(',', currentPosition); - if (nextCommaPosition == std::string_view::npos) { - nextCommaPosition = arrayString.size(); - } - - std::string_view element = Trim(arrayString.substr( - currentPosition, nextCommaPosition - currentPosition)); - if (element.front() == '"' && element.back() == '"') { - element.remove_prefix(1); - element.remove_suffix(1); - } - result[elementIndex++] = element; - - currentPosition = nextCommaPosition + 1; - } - - return {result, ""}; -} - -// Parse a JSON object -constexpr auto ParseObject(std::string_view objectString) - -> std::pair, std::string> { - std::array result = {}; - size_t currentPosition = 0; - size_t lineIndex = 0; - - // Remove curly braces - objectString.remove_prefix(1); - objectString.remove_suffix(1); - - while (currentPosition < objectString.size() && lineIndex < result.size()) { - size_t nextCommaPosition = objectString.find(',', currentPosition); - if (nextCommaPosition == std::string_view::npos) { - nextCommaPosition = objectString.size(); - } - - std::string_view line = Trim(objectString.substr( - currentPosition, nextCommaPosition - currentPosition)); - if (!line.empty() && absl::StrContains(line, ':')) { - auto [kv, error] = ParseKeyValue(line); - if (!error.empty()) { - return {result, error}; - } - result[lineIndex++] = kv; - } - - currentPosition = nextCommaPosition + 1; - } - - return {result, ""}; -} - -// Parse the entire JSON document -constexpr auto ParseJson(std::string_view json) - -> std::pair, std::string> { - std::array result = {}; - size_t currentPosition = 0; - size_t lineIndex = 0; - - while (currentPosition < json.size() && lineIndex < result.size()) { - size_t nextLinePosition = json.find('\n', currentPosition); - if (nextLinePosition == std::string_view::npos) { - nextLinePosition = json.size(); - } - - std::string_view line = Trim( - json.substr(currentPosition, nextLinePosition - currentPosition)); - if (!line.empty() && absl::StrContains(line, ':')) { - auto [kv, error] = ParseKeyValue(line); - if (!error.empty()) { - return {result, error}; - } - - // If it's an array, parse it - if (kv.type == JsonValueType::ARRAY) { - auto [arrayValues, arrayError] = ParseArray(kv.value); - if (!arrayError.empty()) { - return {result, arrayError}; - } - // Handle array values if needed - } - // If it's an object, parse it - else if (kv.type == JsonValueType::OBJECT) { - auto [objectValues, objectError] = ParseObject(kv.value); - if (!objectError.empty()) { - return {result, objectError}; - } - // Handle object values if needed - } - - result[lineIndex++] = kv; - } - - currentPosition = nextLinePosition + 1; - } - - return {result, ""}; -} - -// Split array elements and get internal values -constexpr void SplitArrayElements( - const std::span& arrayElements) { - for (const auto& element : arrayElements) { - if (!element.empty()) { - std::cout << " Array Element: " << element << '\n'; - } - } -} - -#endif // ATOM_COMPONENTS_PACKAGE_HPP diff --git a/src/atom/components/registry.cpp b/src/atom/components/registry.cpp deleted file mode 100644 index 2b183db3..00000000 --- a/src/atom/components/registry.cpp +++ /dev/null @@ -1,199 +0,0 @@ -/* - * registry.cpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2024-3-1 - -Description: Registry Pattern Implementation - -**************************************************/ - -#include "registry.hpp" - -#include "atom/log/loguru.hpp" - -auto Registry::instance() -> Registry& { - static Registry instance; - return instance; -} - -void Registry::registerModule(const std::string& name, - Component::InitFunc init_func) { - std::scoped_lock lock(mutex_); - LOG_F(INFO, "Registering module: {}", name); - module_initializers_[name] = std::move(init_func); -} - -void Registry::addInitializer(const std::string& name, - Component::InitFunc init_func, - Component::CleanupFunc cleanup_func) { - std::scoped_lock lock(mutex_); - if (initializers_.contains(name)) { - return; - } - initializers_[name] = std::make_shared(name); - initializers_[name]->initFunc = std::move(init_func); - initializers_[name]->cleanupFunc = std::move(cleanup_func); - initialized_[name] = false; -} - -void Registry::addDependency(const std::string& name, - const std::string& dependency) { - std::unique_lock lock(mutex_); - if (hasCircularDependency(name, dependency)) { - THROW_RUNTIME_ERROR("Circular dependency detected: " + name + " -> " + - dependency); - } - dependencies_[name].insert(dependency); -} - -void Registry::initializeAll() { - std::unique_lock lock(mutex_); - LOG_F(INFO, "Initializing all components"); - determineInitializationOrder(); - for (const auto& name : initializationOrder_) { - std::unordered_set initStack; - LOG_F(INFO, "Initializing component: {}", name); - initializeComponent(name, initStack); - } -} - -void Registry::cleanupAll() { - std::unique_lock lock(mutex_); - for (const auto& name : std::ranges::reverse_view(initializationOrder_)) { - if (initializers_[name]->cleanupFunc && initialized_[name]) { - initializers_[name]->cleanupFunc(); - initialized_[name] = false; - } - } -} - -auto Registry::isInitialized(const std::string& name) const -> bool { - std::shared_lock lock(mutex_); - auto it = initialized_.find(name); - return it != initialized_.end() && it->second; -} - -void Registry::reinitializeComponent(const std::string& name) { - std::scoped_lock lock(mutex_); - if (initialized_[name]) { - if (auto it = initializers_.find(name); - it != initializers_.end() && it->second->cleanupFunc) { - it->second->cleanupFunc(); - } - } - auto it = module_initializers_.find(name); - if (it != module_initializers_.end()) { - auto component = std::make_shared(name); - it->second(*component); - initializers_[name] = component; - initialized_[name] = true; - } -} - -auto Registry::getComponent(const std::string& name) const - -> std::shared_ptr { - std::shared_lock lock(mutex_); - if (!initializers_.contains(name)) { - THROW_OBJ_NOT_EXIST("Component not registered: " + name); - } - return initializers_.at(name); -} - -auto Registry::getAllComponents() const - -> std::vector> { - std::shared_lock lock(mutex_); - std::vector> components; - for (const auto& pair : initializers_) { - if (pair.second) { - components.push_back(pair.second); - } - } - return components; -} - -auto Registry::getAllComponentNames() const -> std::vector { - std::shared_lock lock(mutex_); - std::vector names; - names.reserve(initializers_.size()); - for (const auto& pair : initializers_) { - names.push_back(pair.first); - } - return names; -} - -void Registry::removeComponent(const std::string& name) { - std::scoped_lock lock(mutex_); - if (initializers_.contains(name)) { - if (initialized_[name] && initializers_[name]->cleanupFunc) { - initializers_[name]->cleanupFunc(); - } - initializers_.erase(name); - initialized_.erase(name); - dependencies_.erase(name); - initializationOrder_.erase( - std::remove(initializationOrder_.begin(), - initializationOrder_.end(), name), - initializationOrder_.end()); - } -} - -bool Registry::hasCircularDependency(const std::string& name, - const std::string& dependency) { - if (dependencies_[dependency].contains(name)) { - return true; - } - for (const auto& dep : dependencies_[dependency]) { - if (hasCircularDependency(name, dep)) { - return true; - } - } - return false; -} - -void Registry::initializeComponent( - const std::string& name, std::unordered_set& init_stack) { - if (initialized_[name]) { - if (init_stack.contains(name)) { - THROW_RUNTIME_ERROR( - "Circular dependency detected while initializing component " - "'{}'", - name); - } - return; - } - if (init_stack.contains(name)) { - THROW_RUNTIME_ERROR( - "Circular dependency detected while initializing: " + name); - } - init_stack.insert(name); - for (const auto& dep : dependencies_[name]) { - initializeComponent(dep, init_stack); - } - if (initializers_[name]->initFunc) { - initializers_[name]->initFunc(*initializers_[name]); - } - initialized_[name] = true; - init_stack.erase(name); -} - -void Registry::determineInitializationOrder() { - std::unordered_set visited; - std::function visit = - [&](const std::string& name) { - if (!visited.contains(name)) { - visited.insert(name); - for (const auto& dep : dependencies_[name]) { - visit(dep); - } - initializationOrder_.push_back(name); - } - }; - for (const auto& pair : initializers_) { - visit(pair.first); - } -} diff --git a/src/atom/components/registry.hpp b/src/atom/components/registry.hpp deleted file mode 100644 index 13c5948b..00000000 --- a/src/atom/components/registry.hpp +++ /dev/null @@ -1,151 +0,0 @@ -/* - * registry.hpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2024-3-1 - -Description: Registry Pattern - -**************************************************/ - -#ifndef ATOM_COMPONENT_REGISTRY_HPP -#define ATOM_COMPONENT_REGISTRY_HPP - -#include -#include -#include -#include -#include -#include - -#include "component.hpp" - -/** - * @class Registry - * @brief Manages initialization and cleanup of components in a registry - * pattern. - */ -class Registry { -public: - /** - * @brief Gets the singleton instance of the Registry. - * @return Reference to the singleton instance of the Registry. - */ - static auto instance() -> Registry&; - - /** - * @brief Registers a module's initialization function. - * @param name The name of the module. - * @param init_func The initialization function for the module. - */ - void registerModule(const std::string& name, Component::InitFunc init_func); - - /** - * @brief Adds an initializer function for a component to the registry. - * @param name The name of the component. - * @param init_func The initialization function for the component. - * @param cleanup_func The cleanup function for the component (optional). - */ - void addInitializer(const std::string& name, Component::InitFunc init_func, - Component::CleanupFunc cleanup_func = nullptr); - - /** - * @brief Adds a dependency between two components. - * @param name The name of the component. - * @param dependency The name of the component's dependency. - */ - void addDependency(const std::string& name, const std::string& dependency); - - /** - * @brief Initializes all components in the registry. - */ - void initializeAll(); - - /** - * @brief Cleans up all components in the registry. - */ - void cleanupAll(); - - /** - * @brief Checks if a component is initialized. - * @param name The name of the component to check. - * @return True if the component is initialized, false otherwise. - */ - auto isInitialized(const std::string& name) const -> bool; - - /** - * @brief Reinitializes a component in the registry. - * @param name The name of the component to reinitialize. - */ - void reinitializeComponent(const std::string& name); - - /** - * @brief Gets a component by name. - * @param name The name of the component. - * @return Shared pointer to the component. - */ - auto getComponent(const std::string& name) const - -> std::shared_ptr; - - /** - * @brief Gets all components. - * @return Vector of shared pointers to all components. - */ - auto getAllComponents() const -> std::vector>; - - /** - * @brief Gets the names of all components. - * @return Vector of all component names. - */ - auto getAllComponentNames() const -> std::vector; - - /** - * @brief Removes a component from the registry. - * @param name The name of the component to remove. - */ - void removeComponent(const std::string& name); - -private: - /** - * @brief Private constructor to prevent instantiation. - */ - Registry() = default; - - std::unordered_map> initializers_; - std::unordered_map> - dependencies_; - std::unordered_map initialized_; - std::vector initializationOrder_; - std::unordered_map module_initializers_; - mutable std::shared_mutex mutex_; - - /** - * @brief Checks if adding a dependency creates a circular dependency. - * @param name The name of the component. - * @param dependency The name of the dependency. - * @return True if adding the dependency creates a circular dependency, - * false otherwise. - */ - bool hasCircularDependency(const std::string& name, - const std::string& dependency); - - /** - * @brief Initializes a component and its dependencies recursively. - * @param name The name of the component to initialize. - * @param init_stack Stack to keep track of components being initialized to - * detect circular dependencies. - */ - void initializeComponent(const std::string& name, - std::unordered_set& init_stack); - - /** - * @brief Determines the order of initialization based on dependencies. - */ - void determineInitializationOrder(); -}; - -#endif // ATOM_COMPONENT_REGISTRY_HPP diff --git a/src/atom/components/types.hpp b/src/atom/components/types.hpp deleted file mode 100644 index ed8ad7a9..00000000 --- a/src/atom/components/types.hpp +++ /dev/null @@ -1,46 +0,0 @@ -/* - * types.hpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2024-3-1 - -Description: Basic Component Types Definition and Some Utilities - -**************************************************/ - -#ifndef ATOM_COMPONENT_TYPES_HPP -#define ATOM_COMPONENT_TYPES_HPP - -#include "atom/function/enum.hpp" - -enum class ComponentType { - NONE, - SHARED, - SHARED_INJECTED, - SCRIPT, - EXECUTABLE, - TASK, - LAST_ENUM_VALUE -}; - -template <> -struct EnumTraits { - static constexpr std::array VALUES = { - ComponentType::NONE, - ComponentType::SHARED, - ComponentType::SHARED_INJECTED, - ComponentType::SCRIPT, - ComponentType::EXECUTABLE, - ComponentType::TASK, - ComponentType::LAST_ENUM_VALUE}; - - static constexpr std::array NAMES = { - "NONE", "SHARED", "SHARED_INJECTED", "SCRIPT", - "EXECUTABLE", "TASK", "LAST_ENUM_VALUE"}; -}; - -#endif // ATOM_COMPONENT_TYPES_HPP diff --git a/src/atom/components/var.cpp b/src/atom/components/var.cpp deleted file mode 100644 index 82abe757..00000000 --- a/src/atom/components/var.cpp +++ /dev/null @@ -1,77 +0,0 @@ -#include "var.hpp" - -void VariableManager::setStringOptions(const std::string& name, - std::span options) { - LOG_F(INFO, "Setting string options for variable: {}", name); - if (auto variable = getVariable(name)) { - stringOptions_[name] = - std::vector(options.begin(), options.end()); - } -} - -void VariableManager::setValue(const std::string& name, const char* newValue) { - LOG_F(INFO, "Setting value for variable: {}", name); - setValue(name, std::string(newValue)); -} - -auto VariableManager::has(const std::string& name) const -> bool { - LOG_F(INFO, "Checking if variable exists: {}", name); - return variables_.contains(name); -} - -auto VariableManager::getDescription(const std::string& name) const - -> std::string { - LOG_F(INFO, "Getting description for variable: {}", name); - if (auto it = variables_.find(name); it != variables_.end()) { - return it->second.description; - } - for (const auto& [key, value] : variables_) { - if (value.alias == name) { - return value.description; - } - } - return ""; -} - -auto VariableManager::getAlias(const std::string& name) const -> std::string { - LOG_F(INFO, "Getting alias for variable: {}", name); - if (auto it = variables_.find(name); it != variables_.end()) { - return it->second.alias; - } - for (const auto& [key, value] : variables_) { - if (value.alias == name) { - return key; - } - } - return ""; -} - -auto VariableManager::getGroup(const std::string& name) const -> std::string { - LOG_F(INFO, "Getting group for variable: {}", name); - if (auto it = variables_.find(name); it != variables_.end()) { - return it->second.group; - } - for (const auto& [key, value] : variables_) { - if (value.alias == name) { - return value.group; - } - } - return ""; -} - -void VariableManager::removeVariable(const std::string& name) { - LOG_F(INFO, "Removing variable: {}", name); - variables_.erase(name); - ranges_.erase(name); - stringOptions_.erase(name); -} - -auto VariableManager::getAllVariables() const -> std::vector { - LOG_F(INFO, "Getting all variables"); - std::vector variableNames; - variableNames.reserve(variables_.size()); - for (const auto& [name, _] : variables_) { - variableNames.push_back(name); - } - return variableNames; -} diff --git a/src/atom/components/var.hpp b/src/atom/components/var.hpp deleted file mode 100644 index 63085872..00000000 --- a/src/atom/components/var.hpp +++ /dev/null @@ -1,183 +0,0 @@ -/* - * var.hpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2024-3-1 - -Description: Variable Manager - -**************************************************/ - -#ifndef ATOM_COMPONENT_VAR_HPP -#define ATOM_COMPONENT_VAR_HPP - -#include -#include -#include -#include -#include -#include - -#include "atom/macro.hpp" - -#if ENABLE_FASTHASH -#include "emhash/hash_table8.hpp" -#else -#include -#endif - -#include "atom/error/exception.hpp" -#include "atom/log/loguru.hpp" -#include "atom/type/trackable.hpp" - -class VariableManager { -public: - template - requires std::is_copy_constructible_v - void addVariable(const std::string& name, T initialValue, - const std::string& description = "", - const std::string& alias = "", - const std::string& group = ""); - - template - requires std::is_copy_constructible_v - void addVariable(const std::string& name, T C::*memberPointer, C& instance, - const std::string& description = "", - const std::string& alias = "", - const std::string& group = ""); - - template - requires std::is_arithmetic_v - void setRange(const std::string& name, T min, T max); - - void setStringOptions(const std::string& name, - std::span options); - - template - [[nodiscard]] auto getVariable(const std::string& name) - -> std::shared_ptr>; - - void setValue(const std::string& name, const char* newValue); - - template - void setValue(const std::string& name, T newValue); - - [[nodiscard]] bool has(const std::string& name) const; - - [[nodiscard]] std::string getDescription(const std::string& name) const; - - [[nodiscard]] std::string getAlias(const std::string& name) const; - - [[nodiscard]] std::string getGroup(const std::string& name) const; - - // New functionalities - void removeVariable(const std::string& name); - [[nodiscard]] std::vector getAllVariables() const; - -private: - struct VariableInfo { - std::any variable; - std::string description; - std::string alias; - std::string group; - } ATOM_ALIGNAS(128); - -#if ENABLE_FASTHASH - emhash8::HashMap variables_; - emhash8::HashMap ranges_; - emhash8::HashMap> stringOptions_; -#else - std::unordered_map variables_; - std::unordered_map ranges_; - std::unordered_map> stringOptions_; -#endif -}; - -template - requires std::is_copy_constructible_v -void VariableManager::addVariable(const std::string& name, T initialValue, - const std::string& description, - const std::string& alias, - const std::string& group) { - LOG_F(INFO, "Adding variable: {}", name); - auto variable = std::make_shared>(std::move(initialValue)); - variables_[name] = {std::move(variable), description, alias, group}; -} - -template - requires std::is_copy_constructible_v -void VariableManager::addVariable(const std::string& name, T C::*memberPointer, - C& instance, const std::string& description, - const std::string& alias, - const std::string& group) { - LOG_F(INFO, "Adding variable with member pointer: {}", name); - auto variable = std::make_shared>(instance.*memberPointer); - variable->setOnChangeCallback( - [&instance, memberPointer](const T& newValue) { - instance.*memberPointer = newValue; - }); - variables_[name] = {std::move(variable), description, alias, group}; -} - -template - requires std::is_arithmetic_v -void VariableManager::setRange(const std::string& name, T min, T max) { - LOG_F(INFO, "Setting range for variable: {}", name); - if (auto variable = getVariable(name)) { - ranges_[name] = std::make_pair(std::move(min), std::move(max)); - } -} - -template -[[nodiscard]] auto VariableManager::getVariable(const std::string& name) - -> std::shared_ptr> { - LOG_F(INFO, "Getting variable: {}", name); - if (auto it = variables_.find(name); it != variables_.end()) { - try { - return std::any_cast>>( - it->second.variable); - } catch (const std::bad_any_cast& e) { - LOG_F(ERROR, "Type mismatch for variable: {}", name); - THROW_INVALID_ARGUMENT("Type mismatch: ", name); - } - } - return nullptr; -} - -template -void VariableManager::setValue(const std::string& name, T newValue) { - LOG_F(INFO, "Setting value for variable: {}", name); - if (auto variable = getVariable(name)) { - if constexpr (std::is_arithmetic_v) { - if (ranges_.contains(name)) { - auto [min, max] = std::any_cast>(ranges_[name]); - if (newValue < min || newValue > max) { - LOG_F(ERROR, "Value out of range for variable: {}", - name); - THROW_OUT_OF_RANGE("Value out of range"); - } - } - } else if constexpr (std::is_same_v || - std::is_same_v) { - if (stringOptions_.contains(name)) { - auto& options = stringOptions_[name]; - if (std::ranges::find(options.begin(), options.end(), newValue) == - options.end()) { - LOG_F(ERROR, "Invalid string option for variable: {}", - name); - THROW_INVALID_ARGUMENT("Invalid string option"); - } - } - } - *variable = std::move(newValue); - } else { - LOG_F(ERROR, "Variable not found: {}", name); - THROW_OBJ_NOT_EXIST("Variable not found"); - } -} - -#endif // ATOM_COMPONENT_VAR_HPP diff --git a/src/atom/components/xmake.lua b/src/atom/components/xmake.lua deleted file mode 100644 index e18b8b9e..00000000 --- a/src/atom/components/xmake.lua +++ /dev/null @@ -1,58 +0,0 @@ -set_project("atom-component") -set_version("1.0.0") - --- Set the C++ standard -set_languages("cxx20") - --- Add required packages -add_requires("loguru") - --- Define libraries -local atom_component_libs = { - "atom-error", - "atom-type", - "atom-utils" -} - -local atom_component_packages = { - "loguru", - "pthread" -} - --- Source files -local source_files = { - "registry.cpp" -} - --- Header files -local header_files = { - "component.hpp", - "dispatch.hpp", - "types.hpp", - "var.hpp" -} - --- Object Library -target("atom-component_object") - set_kind("object") - add_files(table.unpack(source_files)) - add_headerfiles(table.unpack(header_files)) - add_deps(table.unpack(atom_component_libs)) - add_packages(table.unpack(atom_component_packages)) -target_end() - --- Static Library -target("atom-component") - set_kind("static") - add_deps("atom-component_object") - add_files(table.unpack(source_files)) - add_headerfiles(table.unpack(header_files)) - add_packages(table.unpack(atom_component_libs)) - add_includedirs(".") - set_targetdir("$(buildir)/lib") - set_installdir("$(installdir)/lib") - set_version("1.0.0", {build = "%Y%m%d%H%M"}) - on_install(function (target) - os.cp(target:targetfile(), path.join(target:installdir(), "lib")) - end) -target_end() diff --git a/src/atom/connection/CMakeLists.txt b/src/atom/connection/CMakeLists.txt deleted file mode 100644 index f0d6b718..00000000 --- a/src/atom/connection/CMakeLists.txt +++ /dev/null @@ -1,98 +0,0 @@ -# CMakeLists.txt for Atom-Connection -# This project is licensed under the terms of the GPL3 license. -# -# Project Name: Atom-Connection -# Description: Connection Between Lithium Drivers, TCP and IPC -# Author: Max Qian -# License: GPL3 - -cmake_minimum_required(VERSION 3.20) -project(atom-connection C CXX) - -# Sources -list(APPEND ${PROJECT_NAME}_SOURCES - fifoclient.cpp - fifoserver.cpp - sockethub.cpp - tcpclient.cpp - udpclient.cpp - udpserver.cpp -) - -# Headers -list(APPEND ${PROJECT_NAME}_HEADERS - fifoclient.hpp - fifoserver.hpp - sockethub.hpp - tcpclient.hpp - udpclient.hpp - udpserver.hpp -) - -if (ENABLE_LIBSSH) -list(APPEND ${PROJECT_NAME}_SOURCES - sshclient.cpp - sshserver.cpp -) -list(APPEND ${PROJECT_NAME}_HEADERS - sshclient.hpp - sshserver.hpp -) -endif() - -set(${PROJECT_NAME}_LIBS - loguru - ${CMAKE_THREAD_LIBS_INIT} -) - -if (WIN32) -list(APPEND ${PROJECT_NAME}_LIBS - ws2_32 -) -endif() - -if (ENABLE_SSH) -find_package(LibSSH REQUIRED) -list(APPEND ${PROJECT_NAME}_LIBS - ${LIBSSH_LIBRARIES} -) -link_directories(${LIBSSH_LIBRARY_DIRS}) -endif() - -# Build Object Library -add_library(${PROJECT_NAME}_OBJECT OBJECT) -set_property(TARGET ${PROJECT_NAME}_OBJECT PROPERTY POSITION_INDEPENDENT_CODE 1) - -target_sources(${PROJECT_NAME}_OBJECT - PUBLIC - ${${PROJECT_NAME}_HEADERS} - PRIVATE - ${${PROJECT_NAME}_SOURCES} -) - -target_link_libraries(${PROJECT_NAME}_OBJECT ${${PROJECT_NAME}_LIBS}) - -add_library(${PROJECT_NAME} STATIC) - -target_link_libraries(${PROJECT_NAME} ${PROJECT_NAME}_OBJECT ${${PROJECT_NAME}_LIBS}) -target_include_directories(${PROJECT_NAME} PUBLIC .) - - - -set_target_properties(${PROJECT_NAME} PROPERTIES - VERSION ${CMAKE_HYDROGEN_VERSION_STRING} - SOVERSION ${HYDROGEN_SOVERSION} - OUTPUT_NAME ${PROJECT_NAME} -) - -install(TARGETS ${PROJECT_NAME} - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} -) - -if (ATOM_BUILD_PYTHON) -pybind11_add_module(${PROJECT_NAME}-py _pybind.cpp) -target_link_libraries(${PROJECT_NAME}-py PRIVATE ${PROJECT_NAME}) -if (WIN32) -target_link_libraries(${PROJECT_NAME}-py PRIVATE ws2_32) -endif() -endif() diff --git a/src/atom/connection/async_fifoclient.cpp b/src/atom/connection/async_fifoclient.cpp deleted file mode 100644 index 06c9c6a8..00000000 --- a/src/atom/connection/async_fifoclient.cpp +++ /dev/null @@ -1,204 +0,0 @@ -#include "async_fifoclient.hpp" - -#include -#include -#include -#include - -#ifdef _WIN32 -#include -#else -#include -#include -#include -#include -#endif - -namespace atom::async::connection { - -struct FifoClient::Impl { - asio::io_context io_context; -#ifdef _WIN32 - HANDLE fifoHandle{nullptr}; -#else - int fifoFd{-1}; -#endif - std::string fifoPath; - asio::steady_timer timer; - - Impl(std::string_view path) : fifoPath(path), timer(io_context) { - openFifo(); - } - - ~Impl() { close(); } - - void openFifo() { -#ifdef _WIN32 - fifoHandle = - CreateFileA(fifoPath.c_str(), GENERIC_READ | GENERIC_WRITE, 0, - nullptr, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, nullptr); - if (fifoHandle == INVALID_HANDLE_VALUE) { - throw std::runtime_error("Failed to open FIFO pipe"); - } -#else - if (mkfifo(fifoPath.c_str(), 0666) == -1 && errno != EEXIST) { - throw std::system_error(errno, std::generic_category(), - "Failed to create FIFO"); - } - fifoFd = open(fifoPath.c_str(), O_RDWR | O_NONBLOCK); - if (fifoFd == -1) { - throw std::system_error(errno, std::generic_category(), - "Failed to open FIFO pipe"); - } -#endif - } - - bool isOpen() const { -#ifdef _WIN32 - return fifoHandle != INVALID_HANDLE_VALUE; -#else - return fifoFd != -1; -#endif - } - - void close() { -#ifdef _WIN32 - if (isOpen()) { - CloseHandle(fifoHandle); - fifoHandle = INVALID_HANDLE_VALUE; - } -#else - if (isOpen()) { - ::close(fifoFd); - fifoFd = -1; - } -#endif - } - - bool write(std::string_view data, - const std::optional& timeout) { - if (!isOpen()) - return false; - - // Convert data to buffer - std::vector buffer(data.begin(), data.end()); - buffer.push_back('\0'); - -#ifdef _WIN32 - // Windows specific writing logic - DWORD bytesWritten; - if (timeout) { - timer.expires_after(*timeout); - timer.async_wait( - [this, &buffer, &bytesWritten](const asio::error_code&) { - WriteFile(fifoHandle, buffer.data(), - static_cast(buffer.size()), &bytesWritten, - nullptr); - }); - } else { - return WriteFile(fifoHandle, buffer.data(), - static_cast(buffer.size()), &bytesWritten, - nullptr) != 0; - } - io_context.run(); - io_context.reset(); - return true; -#else - if (timeout) { - fd_set writeFds; - FD_ZERO(&writeFds); - FD_SET(fifoFd, &writeFds); - timeval tv{}; - tv.tv_sec = timeout->count() / 1000; - tv.tv_usec = (timeout->count() % 1000) * 1000; - int result = select(fifoFd + 1, nullptr, &writeFds, nullptr, &tv); - if (result > 0) { - return ::write(fifoFd, buffer.data(), buffer.size()) != -1; - } - return false; - } else { - return ::write(fifoFd, buffer.data(), buffer.size()) != -1; - } -#endif - } - - std::optional read( - const std::optional& timeout) { - if (!isOpen()) - return std::nullopt; - - std::string data; - char buffer[1024]; - -#ifdef _WIN32 - // Windows specific reading logic - DWORD bytesRead; - if (timeout) { - timer.expires_after(*timeout); - timer.async_wait( - [this, &data, &buffer, &bytesRead](const asio::error_code&) { - if (ReadFile(fifoHandle, buffer, sizeof(buffer) - 1, - &bytesRead, nullptr) && - bytesRead > 0) { - buffer[bytesRead] = '\0'; - data += buffer; - } - }); - } else { - while (ReadFile(fifoHandle, buffer, sizeof(buffer) - 1, &bytesRead, - nullptr) && - bytesRead > 0) { - buffer[bytesRead] = '\0'; - data += buffer; - } - } -#else - if (timeout) { - fd_set readFds; - FD_ZERO(&readFds); - FD_SET(fifoFd, &readFds); - timeval tv{}; - tv.tv_sec = timeout->count() / 1000; - tv.tv_usec = (timeout->count() % 1000) * 1000; - int result = select(fifoFd + 1, &readFds, nullptr, nullptr, &tv); - if (result > 0) { - ssize_t bytesRead = ::read(fifoFd, buffer, sizeof(buffer) - 1); - if (bytesRead > 0) { - buffer[bytesRead] = '\0'; - data += buffer; - } - } - } else { - ssize_t bytesRead; - while ((bytesRead = ::read(fifoFd, buffer, sizeof(buffer) - 1)) > - 0) { - buffer[bytesRead] = '\0'; - data += buffer; - } - } -#endif - - return data.empty() ? std::nullopt : std::make_optional(data); - } -}; - -FifoClient::FifoClient(std::string fifoPath) - : m_impl(std::make_unique(fifoPath)) {} - -FifoClient::~FifoClient() = default; - -bool FifoClient::write(std::string_view data, - std::optional timeout) { - return m_impl->write(data, timeout); -} - -std::optional FifoClient::read( - std::optional timeout) { - return m_impl->read(timeout); -} - -bool FifoClient::isOpen() const { return m_impl->isOpen(); } - -void FifoClient::close() { m_impl->close(); } - -} // namespace atom::connection diff --git a/src/atom/connection/async_fifoclient.hpp b/src/atom/connection/async_fifoclient.hpp deleted file mode 100644 index 1030b92f..00000000 --- a/src/atom/connection/async_fifoclient.hpp +++ /dev/null @@ -1,72 +0,0 @@ -#ifndef ATOM_CONNECTION_ASYNC_FIFOCLIENT_HPP -#define ATOM_CONNECTION_ASYNC_FIFOCLIENT_HPP - -#include -#include -#include -#include -#include - -namespace atom::async::connection { - -/** - * @brief A class for interacting with a FIFO (First In, First Out) pipe. - * - * This class provides methods to read from and write to a FIFO pipe, - * handling timeouts and ensuring proper resource management. - */ -class FifoClient { -public: - /** - * @brief Constructs a FifoClient with the specified FIFO path. - * - * @param fifoPath The path to the FIFO file to be used for communication. - */ - explicit FifoClient(std::string fifoPath); - - /** - * @brief Destroys the FifoClient and closes the FIFO if it is open. - */ - ~FifoClient(); - - /** - * @brief Writes data to the FIFO. - * - * @param data The data to be written to the FIFO, as a string view. - * @param timeout Optional timeout for the write operation, in milliseconds. - * @return true if the data was successfully written, false if there was an - * error. - */ - auto write(std::string_view data, - std::optional timeout = std::nullopt) - -> bool; - - /** - * @brief Reads data from the FIFO. - * - * @param timeout Optional timeout for the read operation, in milliseconds. - * @return An optional string containing the data read from the FIFO. - */ - auto read(std::optional timeout = std::nullopt) - -> std::optional; - - /** - * @brief Checks if the FIFO is currently open. - * - * @return true if the FIFO is open, false otherwise. - */ - [[nodiscard]] auto isOpen() const -> bool; - - /** - * @brief Closes the FIFO. - */ - void close(); - -private: - struct Impl; ///< Forward declaration of the implementation details - std::unique_ptr m_impl; ///< Pointer to the implementation -}; - -} // namespace atom::connection - -#endif // ATOM_CONNECTION_ASYNC_FIFOCLIENT_HPP diff --git a/src/atom/connection/async_fifoserver.cpp b/src/atom/connection/async_fifoserver.cpp deleted file mode 100644 index eff8f4b0..00000000 --- a/src/atom/connection/async_fifoserver.cpp +++ /dev/null @@ -1,108 +0,0 @@ -/* - * fifoserver.cpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2023-6-1 - -Description: FIFO Server - -*************************************************/ - -#include "async_fifoserver.hpp" - -#include -#include -#include - -namespace atom::async::connection { - -class FifoServer::Impl { -public: - explicit Impl(std::string_view fifo_path) - : fifo_path_(fifo_path), io_context_(), fifo_stream_(io_context_) { -#if __APPLE__ || __linux__ - // Create FIFO if it doesn't exist - if (!std::filesystem::exists(fifo_path_)) { - mkfifo(fifo_path_.c_str(), 0666); - } -#endif - } - - ~Impl() { - stop(); -#if __APPLE__ || __linux__ - std::filesystem::remove(fifo_path_); -#endif - } - - void start() { - if (!isRunning()) { - running_ = true; - io_thread_ = std::thread([this]() { io_context_.run(); }); - acceptConnection(); - } - } - - void stop() { - if (isRunning()) { - running_ = false; - io_context_.stop(); - if (io_thread_.joinable()) { - io_thread_.join(); - } - } - } - - [[nodiscard]] bool isRunning() const { return running_; } - -private: - void acceptConnection() { -#if __APPLE__ || __linux__ - fifo_stream_.assign(open(fifo_path_.c_str(), O_RDWR | O_NONBLOCK)); - readMessage(); -#endif - } - - void readMessage() { -#if __APPLE__ || __linux__ - asio::async_read_until( - fifo_stream_, asio::dynamic_buffer(buffer_), '\n', - [this](std::error_code ec, std::size_t length) { - if (!ec) { - std::string message(buffer_.substr(0, length)); - buffer_.erase(0, length); - std::cout << "Received message: " << message << std::endl; - readMessage(); // Continue reading - } - }); -#endif - } - - std::string fifo_path_; - asio::io_context io_context_; -#ifdef _WIN32 - asio::windows::stream_handle fifo_stream_; -#else - asio::posix::stream_descriptor fifo_stream_; -#endif - std::thread io_thread_; - std::string buffer_; - bool running_ = false; -}; - -FifoServer::FifoServer(std::string_view fifo_path) - : impl_(std::make_unique(fifo_path)) {} - -FifoServer::~FifoServer() = default; - -void FifoServer::start() { impl_->start(); } - -void FifoServer::stop() { impl_->stop(); } - -bool FifoServer::isRunning() const { return impl_->isRunning(); } - -} // namespace atom::async::connection diff --git a/src/atom/connection/async_fifoserver.hpp b/src/atom/connection/async_fifoserver.hpp deleted file mode 100644 index 2935872e..00000000 --- a/src/atom/connection/async_fifoserver.hpp +++ /dev/null @@ -1,64 +0,0 @@ -/* - * fifoserver.hpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2023-6-1 - -Description: FIFO Server - -*************************************************/ - -#ifndef ATOM_CONNECTION_ASYNC_FIFOSERVER_HPP -#define ATOM_CONNECTION_ASYNC_FIFOSERVER_HPP - -#include -#include - -namespace atom::async::connection { - -/** - * @brief A class representing a server for handling FIFO messages. - */ -class FifoServer { -public: - /** - * @brief Constructs a new FifoServer object. - * - * @param fifo_path The path to the FIFO pipe. - */ - explicit FifoServer(std::string_view fifo_path); - - /** - * @brief Destroys the FifoServer object. - */ - ~FifoServer(); - - /** - * @brief Starts the server to listen for messages. - */ - void start(); - - /** - * @brief Stops the server. - */ - void stop(); - - /** - * @brief Checks if the server is running. - * - * @return True if the server is running, false otherwise. - */ - [[nodiscard]] bool isRunning() const; - -private: - class Impl; - std::unique_ptr impl_; -}; - -} // namespace atom::async::connection - -#endif // ATOM_CONNECTION_ASYNC_FIFOSERVER_HPP diff --git a/src/atom/connection/async_sockethub.cpp b/src/atom/connection/async_sockethub.cpp deleted file mode 100644 index fd096f9b..00000000 --- a/src/atom/connection/async_sockethub.cpp +++ /dev/null @@ -1,259 +0,0 @@ -#include "async_sockethub.hpp" -#include -#include -#include - -namespace atom::async::connection { - -class SocketHub::Impl { -public: - Impl(bool use_ssl) - : io_context_(), - acceptor_(io_context_), - ssl_context_(asio::ssl::context::sslv23), - use_ssl_(use_ssl), - is_running_(false) {} - - void start(int port); - void stop(); - - void addHandler( - const std::function& handler); - void addConnectHandler(const std::function& handler); - void addDisconnectHandler(const std::function& handler); - - void broadcastMessage(const std::string& message); - void sendMessageToClient(size_t client_id, const std::string& message); - - [[nodiscard]] auto isRunning() const -> bool; - -private: - void doAccept(); - void handleNewConnection(std::shared_ptr socket); - void doRead(std::shared_ptr socket); - void handleIncomingMessage(const std::string& message, - std::shared_ptr socket); - void handleDisconnect(std::shared_ptr socket); - void disconnectAllClients(); - size_t getClientId(const std::shared_ptr& socket); - void log(const std::string& message); - - asio::io_context io_context_; - asio::ip::tcp::acceptor acceptor_; - asio::ssl::context ssl_context_; - bool use_ssl_; - bool is_running_; - std::unordered_map> clients_; - std::mutex client_mutex_; - std::vector> handlers_; - std::mutex handler_mutex_; - std::vector> connect_handlers_; - std::mutex connect_handler_mutex_; - std::vector> disconnect_handlers_; - std::mutex disconnect_handler_mutex_; - size_t next_client_id_ = 1; - std::thread io_thread_; -}; - -SocketHub::SocketHub(bool use_ssl) : impl_(std::make_unique(use_ssl)) {} - -SocketHub::~SocketHub() = default; - -void SocketHub::start(int port) { impl_->start(port); } - -void SocketHub::stop() { impl_->stop(); } - -void SocketHub::addHandler( - const std::function& handler) { - impl_->addHandler(handler); -} - -void SocketHub::addConnectHandler(const std::function& handler) { - impl_->addConnectHandler(handler); -} - -void SocketHub::addDisconnectHandler( - const std::function& handler) { - impl_->addDisconnectHandler(handler); -} - -void SocketHub::broadcastMessage(const std::string& message) { - impl_->broadcastMessage(message); -} - -void SocketHub::sendMessageToClient(size_t client_id, - const std::string& message) { - impl_->sendMessageToClient(client_id, message); -} - -auto SocketHub::isRunning() const -> bool { return impl_->isRunning(); } - -// Definitions for Impl -void SocketHub::Impl::start(int port) { - asio::ip::tcp::endpoint endpoint(asio::ip::tcp::v4(), port); - acceptor_.open(endpoint.protocol()); - acceptor_.set_option(asio::ip::tcp::acceptor::reuse_address(true)); - acceptor_.bind(endpoint); - acceptor_.listen(); - - is_running_ = true; - doAccept(); - - io_thread_ = std::thread([this]() { io_context_.run(); }); - log("SocketHub started."); -} - -void SocketHub::Impl::stop() { - if (is_running_) { - is_running_ = false; - io_context_.stop(); - disconnectAllClients(); - if (io_thread_.joinable()) - io_thread_.join(); - log("SocketHub stopped."); - } -} - -void SocketHub::Impl::addHandler( - const std::function& handler) { - std::lock_guard lock(handler_mutex_); - handlers_.push_back(handler); -} - -void SocketHub::Impl::addConnectHandler( - const std::function& handler) { - std::lock_guard lock(connect_handler_mutex_); - connect_handlers_.push_back(handler); -} - -void SocketHub::Impl::addDisconnectHandler( - const std::function& handler) { - std::lock_guard lock(disconnect_handler_mutex_); - disconnect_handlers_.push_back(handler); -} - -void SocketHub::Impl::broadcastMessage(const std::string& message) { - std::lock_guard lock(client_mutex_); - for (const auto& [id, socket] : clients_) { - asio::async_write(*socket, asio::buffer(message), - [](std::error_code ec, std::size_t) { - if (ec) { - std::cerr - << "Broadcast error: " << ec.message() - << std::endl; - } - }); - } - log("Broadcasted message: " + message); -} - -void SocketHub::Impl::sendMessageToClient(size_t client_id, - const std::string& message) { - std::lock_guard lock(client_mutex_); - auto it = clients_.find(client_id); - if (it != clients_.end()) { - asio::async_write(*it->second, asio::buffer(message), - [](std::error_code ec, std::size_t) { - if (ec) { - std::cerr << "Send error: " << ec.message() - << std::endl; - } - }); - log("Sent message to client " + std::to_string(client_id) + ": " + - message); - } -} - -[[nodiscard]] auto SocketHub::Impl::isRunning() const -> bool { - return is_running_; -} - -// Private members and methods -void SocketHub::Impl::doAccept() { - auto socket = std::make_shared(io_context_); - acceptor_.async_accept(*socket, [this, socket](std::error_code ec) { - if (!ec) { - handleNewConnection(socket); - doRead(socket); - log("New client connected."); - } - if (is_running_) { - doAccept(); - } - }); -} - -void SocketHub::Impl::handleNewConnection( - std::shared_ptr socket) { - std::lock_guard lock(client_mutex_); - size_t client_id = next_client_id_++; - clients_[client_id] = socket; - for (const auto& handler : connect_handlers_) { - handler(client_id); - } -} - -void SocketHub::Impl::doRead(std::shared_ptr socket) { - auto buffer = std::make_shared>(1024); - socket->async_read_some( - asio::buffer(*buffer), - [this, socket, buffer](std::error_code ec, std::size_t length) { - if (!ec) { - std::string message(buffer->data(), length); - handleIncomingMessage(message, socket); - doRead(socket); - } else { - handleDisconnect(socket); - } - }); -} - -void SocketHub::Impl::handleIncomingMessage( - const std::string& message, std::shared_ptr socket) { - size_t client_id = getClientId(socket); - std::lock_guard lock(handler_mutex_); - for (const auto& handler : handlers_) { - handler(message, client_id); - } - log("Received message from client " + std::to_string(client_id) + ": " + - message); -} - -void SocketHub::Impl::handleDisconnect( - std::shared_ptr socket) { - size_t client_id = getClientId(socket); - { - std::lock_guard lock(client_mutex_); - clients_.erase(client_id); - } - for (const auto& handler : disconnect_handlers_) { - handler(client_id); - } - log("Client " + std::to_string(client_id) + " disconnected."); -} - -void SocketHub::Impl::disconnectAllClients() { - std::lock_guard lock(client_mutex_); - for (auto& [id, socket] : clients_) { - socket->close(); - } - clients_.clear(); -} - -size_t SocketHub::Impl::getClientId( - const std::shared_ptr& socket) { - std::lock_guard lock(client_mutex_); - for (const auto& [id, sock] : clients_) { - if (sock == socket) { - return id; - } - } - return 0; // Should not happen unless the socket is not tracked (edge case) -} - -void SocketHub::Impl::log(const std::string& message) { - // Simple console logging - std::cout << "[SocketHub] " << message << std::endl; -} - -} // namespace atom::async::connection diff --git a/src/atom/connection/async_sockethub.hpp b/src/atom/connection/async_sockethub.hpp deleted file mode 100644 index 2ff861ca..00000000 --- a/src/atom/connection/async_sockethub.hpp +++ /dev/null @@ -1,37 +0,0 @@ -#ifndef ATOM_CONNECTION_ASYNC_SOCKETHUB_HPP -#define ATOM_CONNECTION_ASYNC_SOCKETHUB_HPP - -#include -#include -#include -#include -#include - -namespace atom::async::connection { - -class SocketHub { -public: - SocketHub(bool use_ssl = false); - ~SocketHub(); - - void start(int port); - void stop(); - - void addHandler( - const std::function& handler); - void addConnectHandler(const std::function& handler); - void addDisconnectHandler(const std::function& handler); - - void broadcastMessage(const std::string& message); - void sendMessageToClient(size_t client_id, const std::string& message); - - [[nodiscard]] auto isRunning() const -> bool; - -private: - class Impl; - std::unique_ptr impl_; -}; - -} // namespace atom::async::connection - -#endif // ATOM_CONNECTION_ASYNC_SOCKETHUB_HPP diff --git a/src/atom/connection/async_tcpclient.cpp b/src/atom/connection/async_tcpclient.cpp deleted file mode 100644 index 231c87a5..00000000 --- a/src/atom/connection/async_tcpclient.cpp +++ /dev/null @@ -1,313 +0,0 @@ -#include "async_tcpclient.hpp" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace atom::async::connection { - -class TcpClient::Impl { -public: - Impl(bool use_ssl) - : io_context_(), - ssl_context_(asio::ssl::context::sslv23), - socket_(use_ssl ? (asio::ip::tcp::socket(io_context_)) - : (ssl_socket_t(io_context_, ssl_context_))), - use_ssl_(use_ssl), - connected_(false), - reconnect_attempts_(0), - heartbeat_interval_(5000), - total_bytes_sent_(0), - total_bytes_received_(0) { - if (use_ssl_) { - ssl_context_.set_verify_mode(asio::ssl::verify_peer); - } - } - - ~Impl() { disconnect(); } - - bool connect( - const std::string& host, int port, - std::chrono::milliseconds timeout = std::chrono::milliseconds::zero()) { - last_host_ = host; - last_port_ = port; - - try { - asio::ip::tcp::resolver resolver(io_context_); - auto endpoints = resolver.resolve(host, std::to_string(port)); - - asio::error_code ec; - asio::connect(socket_, endpoints, ec); - - if (ec) { - logError(ec.message()); - return false; - } - - if (use_ssl_) { - asio::error_code ssl_ec; - socket_.handshake(asio::ssl::stream_base::client, ssl_ec); - if (ssl_ec) { - logError(ssl_ec.message()); - return false; - } - } - - connected_ = true; - if (on_connected_) - on_connected_(); - - startReceiving(1024); - startHeartbeat(); - - io_thread_ = std::thread([this]() { io_context_.run(); }); - - logInfo("Connected to server."); - return true; - } catch (const std::exception& e) { - logError(e.what()); - return false; - } - } - - void disconnect() { - if (connected_) { - if (use_ssl_) { - socket_.lowest_layer().close(); - } else { - socket_.lowest_layer().close(); - } - connected_ = false; - if (on_disconnected_) - on_disconnected_(); - logInfo("Disconnected from server."); - } - - if (io_thread_.joinable()) { - io_context_.stop(); - io_thread_.join(); - } - } - - void enableReconnection(int attempts) { reconnect_attempts_ = attempts; } - - void setHeartbeatInterval(std::chrono::milliseconds interval) { - heartbeat_interval_ = interval; - } - - bool send(const std::vector& data) { - if (!connected_) { - logError("Not connected to any server."); - return false; - } - - try { - auto bytes_written = asio::write(socket_, asio::buffer(data)); - total_bytes_sent_ += bytes_written; - logInfo("Sent data of size: " + std::to_string(bytes_written)); - return true; - } catch (const std::exception& e) { - logError(e.what()); - return false; - } - } - - std::future> receive( - size_t size, - std::chrono::milliseconds timeout = std::chrono::milliseconds::zero()) { - return std::async(std::launch::async, [=, this]() { - std::vector data(size); - try { - auto bytes_read = asio::read(socket_, asio::buffer(data, size)); - total_bytes_received_ += bytes_read; - logInfo("Received data of size: " + std::to_string(bytes_read)); - return data; - } catch (const std::exception& e) { - logError(e.what()); - } - return data; - }); - } - - [[nodiscard]] bool isConnected() const { return connected_; } - - [[nodiscard]] std::string getErrorMessage() const { return last_error_; } - - void setOnConnectedCallback(const OnConnectedCallback& callback) { - on_connected_ = callback; - } - - void setOnDisconnectedCallback(const OnDisconnectedCallback& callback) { - on_disconnected_ = callback; - } - - void setOnDataReceivedCallback(const OnDataReceivedCallback& callback) { - on_data_received_ = callback; - } - - void setOnErrorCallback(const OnErrorCallback& callback) { - on_error_ = callback; - } - -private: - using ssl_socket_t = asio::ssl::stream; - - void startReceiving(size_t bufferSize) { - receive_buffer_.resize(bufferSize); - doReceive(); - } - - void doReceive() { - socket_.async_read_some( - asio::buffer(receive_buffer_), - [this](std::error_code ec, std::size_t length) { - if (!ec) { - total_bytes_received_ += length; - if (on_data_received_) { - on_data_received_(std::vector( - receive_buffer_.begin(), - receive_buffer_.begin() + length)); - } - doReceive(); - } else { - handleDisconnect(ec.message()); - } - }); - } - - void startHeartbeat() { - heartbeat_timer_.expires_after(heartbeat_interval_); - heartbeat_timer_.async_wait([this](const std::error_code& ec) { - if (!ec && connected_) { - send(std::vector{'P'}); // Example ping message - startHeartbeat(); // Re-schedule the heartbeat - } - }); - } - - void handleDisconnect(const std::string& error) { - connected_ = false; - if (on_disconnected_) - on_disconnected_(); - - logError("Disconnected due to: " + error); - - reconnect(); - } - - void reconnect() { - int attempts = 0; - while (attempts < reconnect_attempts_ && !connected_) { - attempts++; - if (connect(last_host_, last_port_)) { - logInfo("Reconnected after " + std::to_string(attempts) + - " attempts."); - return; - } - - std::this_thread::sleep_for(std::chrono::seconds(1) * - attempts); - } - - if (!connected_ && on_error_) { - on_error_("Reconnection failed after " + - std::to_string(reconnect_attempts_) + " attempts."); - } - } - - void logInfo(const std::string& message) { - std::cout << "[INFO] " << message << std::endl; - } - - void logError(const std::string& message) { - std::cerr << "[ERROR] " << message << std::endl; - last_error_ = message; - } - - asio::io_context io_context_; - asio::ssl::context ssl_context_; - ssl_socket_t socket_; - asio::steady_timer heartbeat_timer_{io_context_}; - std::thread io_thread_; - - bool use_ssl_; - bool connected_; - std::string last_error_; - std::vector receive_buffer_; - - std::string last_host_; - int last_port_; - - OnConnectedCallback on_connected_; - OnDisconnectedCallback on_disconnected_; - OnDataReceivedCallback on_data_received_; - OnErrorCallback on_error_; - - int reconnect_attempts_; - std::chrono::milliseconds heartbeat_interval_; - - std::atomic total_bytes_sent_; - std::atomic total_bytes_received_; -}; - -TcpClient::TcpClient(bool use_ssl) : impl_(std::make_unique(use_ssl)) {} - -TcpClient::~TcpClient() = default; - -bool TcpClient::connect(const std::string& host, int port, - std::chrono::milliseconds timeout) { - return impl_->connect(host, port, timeout); -} - -void TcpClient::disconnect() { impl_->disconnect(); } - -void TcpClient::enableReconnection(int attempts) { - impl_->enableReconnection(attempts); -} - -void TcpClient::setHeartbeatInterval(std::chrono::milliseconds interval) { - impl_->setHeartbeatInterval(interval); -} - -bool TcpClient::send(const std::vector& data) { - return impl_->send(data); -} - -std::future> TcpClient::receive( - size_t size, std::chrono::milliseconds timeout) { - return impl_->receive(size, timeout); -} - -bool TcpClient::isConnected() const { return impl_->isConnected(); } - -std::string TcpClient::getErrorMessage() const { - return impl_->getErrorMessage(); -} - -void TcpClient::setOnConnectedCallback(const OnConnectedCallback& callback) { - impl_->setOnConnectedCallback(callback); -} - -void TcpClient::setOnDisconnectedCallback( - const OnDisconnectedCallback& callback) { - impl_->setOnDisconnectedCallback(callback); -} - -void TcpClient::setOnDataReceivedCallback( - const OnDataReceivedCallback& callback) { - impl_->setOnDataReceivedCallback(callback); -} - -void TcpClient::setOnErrorCallback(const OnErrorCallback& callback) { - impl_->setOnErrorCallback(callback); -} - -} // namespace atom::async::connection diff --git a/src/atom/connection/async_tcpclient.hpp b/src/atom/connection/async_tcpclient.hpp deleted file mode 100644 index e8634c6b..00000000 --- a/src/atom/connection/async_tcpclient.hpp +++ /dev/null @@ -1,54 +0,0 @@ -#ifndef ATOM_CONNECTION_ASYNC_TCPCLIENT_HPP -#define ATOM_CONNECTION_ASYNC_TCPCLIENT_HPP - -#include -#include -#include -#include -#include -#include - -namespace atom::async::connection { - -class TcpClient { -public: - using OnConnectedCallback = std::function; - using OnDisconnectedCallback = std::function; - using OnDataReceivedCallback = - std::function&)>; - using OnErrorCallback = std::function; - - TcpClient(bool use_ssl = false); - ~TcpClient(); - - bool connect( - const std::string& host, int port, - std::chrono::milliseconds timeout = std::chrono::milliseconds::zero()); - - void disconnect(); - - void enableReconnection(int attempts); - void setHeartbeatInterval(std::chrono::milliseconds interval); - - bool send(const std::vector& data); - - std::future> receive( - size_t size, - std::chrono::milliseconds timeout = std::chrono::milliseconds::zero()); - - [[nodiscard]] bool isConnected() const; - [[nodiscard]] std::string getErrorMessage() const; - - void setOnConnectedCallback(const OnConnectedCallback& callback); - void setOnDisconnectedCallback(const OnDisconnectedCallback& callback); - void setOnDataReceivedCallback(const OnDataReceivedCallback& callback); - void setOnErrorCallback(const OnErrorCallback& callback); - -private: - class Impl; - std::unique_ptr impl_; -}; - -} // namespace atom::async::connection - -#endif // ATOM_CONNECTION_ASYNC_TCPCLIENT_HPP diff --git a/src/atom/connection/async_udpclient.cpp b/src/atom/connection/async_udpclient.cpp deleted file mode 100644 index b2106700..00000000 --- a/src/atom/connection/async_udpclient.cpp +++ /dev/null @@ -1,142 +0,0 @@ -#include "async_udpclient.hpp" - -#include -#include -#include - -namespace atom::async::connection { - -class UdpClient::Impl { -public: - Impl() : io_context_(), socket_(io_context_), is_receiving_(false) {} - - bool bind(int port) { - try { - asio::ip::udp::endpoint endpoint(asio::ip::udp::v4(), port); - socket_.open(endpoint.protocol()); - socket_.bind(endpoint); - return true; - } catch (...) { - return false; - } - } - - bool send(const std::string& host, int port, - const std::vector& data) { - try { - asio::ip::udp::resolver resolver(io_context_); - asio::ip::udp::endpoint destination = - *resolver.resolve(host, std::to_string(port)).begin(); - socket_.send_to(asio::buffer(data), destination); - return true; - } catch (...) { - return false; - } - } - - std::vector receive(size_t size, std::string& remoteHost, - int& remotePort, - std::chrono::milliseconds timeout) { - std::vector data(size); - asio::ip::udp::endpoint senderEndpoint; - asio::error_code ec; - socket_.receive_from(asio::buffer(data), senderEndpoint, 0, ec); - if (!ec) { - remoteHost = senderEndpoint.address().to_string(); - remotePort = senderEndpoint.port(); - return data; - } - return {}; - } - - void setOnDataReceivedCallback(const OnDataReceivedCallback& callback) { - onDataReceivedCallback_ = callback; - } - - void setOnErrorCallback(const OnErrorCallback& callback) { - onErrorCallback_ = callback; - } - - void startReceiving(size_t bufferSize) { - is_receiving_ = true; - receive_buffer_.resize(bufferSize); - doReceive(); - receive_thread_ = std::thread([this] { io_context_.run(); }); - } - - void stopReceiving() { - is_receiving_ = false; - socket_.close(); - if (receive_thread_.joinable()) { - receive_thread_.join(); - } - } - -private: - void doReceive() { - if (!is_receiving_) - return; - - socket_.async_receive_from( - asio::buffer(receive_buffer_), remote_endpoint_, - [this](std::error_code ec, std::size_t bytes_recvd) { - if (!ec && bytes_recvd > 0) { - if (onDataReceivedCallback_) { - auto data = std::vector( - receive_buffer_.begin(), - receive_buffer_.begin() + bytes_recvd); - onDataReceivedCallback_( - data, remote_endpoint_.address().to_string(), - remote_endpoint_.port()); - } - doReceive(); - } else { - if (onErrorCallback_) { - onErrorCallback_("Receive error"); - } - } - }); - } - - asio::io_context io_context_; - asio::ip::udp::socket socket_; - asio::ip::udp::endpoint remote_endpoint_; - std::vector receive_buffer_; - std::thread receive_thread_; - bool is_receiving_; - OnDataReceivedCallback onDataReceivedCallback_; - OnErrorCallback onErrorCallback_; -}; - -UdpClient::UdpClient() : impl_(std::make_unique()) {} -UdpClient::~UdpClient() = default; - -bool UdpClient::bind(int port) { return impl_->bind(port); } - -bool UdpClient::send(const std::string& host, int port, - const std::vector& data) { - return impl_->send(host, port, data); -} - -std::vector UdpClient::receive(size_t size, std::string& remoteHost, - int& remotePort, - std::chrono::milliseconds timeout) { - return impl_->receive(size, remoteHost, remotePort, timeout); -} - -void UdpClient::setOnDataReceivedCallback( - const OnDataReceivedCallback& callback) { - impl_->setOnDataReceivedCallback(callback); -} - -void UdpClient::setOnErrorCallback(const OnErrorCallback& callback) { - impl_->setOnErrorCallback(callback); -} - -void UdpClient::startReceiving(size_t bufferSize) { - impl_->startReceiving(bufferSize); -} - -void UdpClient::stopReceiving() { impl_->stopReceiving(); } - -} // namespace atom::async::connection diff --git a/src/atom/connection/async_udpclient.hpp b/src/atom/connection/async_udpclient.hpp deleted file mode 100644 index 1a2eb20a..00000000 --- a/src/atom/connection/async_udpclient.hpp +++ /dev/null @@ -1,58 +0,0 @@ -/* - * udpclient.hpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* -Date: 2024-5-24 -Description: UDP Client Class -*************************************************/ - -#ifndef ATOM_CONNECTION_ASYNC_UDPCLIENT_HPP -#define ATOM_CONNECTION_ASYNC_UDPCLIENT_HPP - -#include -#include -#include -#include -#include -#include - -namespace atom::async::connection { - -/** - * @class UdpClient - * @brief Represents a UDP client for sending and receiving datagrams. - */ -class UdpClient { -public: - using OnDataReceivedCallback = - std::function&, const std::string&, int)>; - using OnErrorCallback = std::function; - - UdpClient(); - ~UdpClient(); - - UdpClient(const UdpClient&) = delete; - UdpClient& operator=(const UdpClient&) = delete; - - bool bind(int port); - bool send(const std::string& host, int port, const std::vector& data); - std::vector receive( - size_t size, std::string& remoteHost, int& remotePort, - std::chrono::milliseconds timeout = std::chrono::milliseconds::zero()); - - void setOnDataReceivedCallback(const OnDataReceivedCallback& callback); - void setOnErrorCallback(const OnErrorCallback& callback); - - void startReceiving(size_t bufferSize); - void stopReceiving(); - -private: - class Impl; - std::unique_ptr impl_; -}; - -} // namespace atom::async::connection -#endif // ATOM_CONNECTION_ASYNC_UDPCLIENT_HPP diff --git a/src/atom/connection/async_udpserver.cpp b/src/atom/connection/async_udpserver.cpp deleted file mode 100644 index 0d7b8d37..00000000 --- a/src/atom/connection/async_udpserver.cpp +++ /dev/null @@ -1,153 +0,0 @@ -/* - * udp_server.cpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2024-1-4 - -Description: A simple Asio-based UDP server. - -*************************************************/ - -#include "async_udpserver.hpp" - -#include -#include -#include -#include - - -namespace atom::async::connection { - -constexpr std::size_t BUFFER_SIZE = 1024; - -class UdpSocketHub::Impl { -public: - Impl() : socket_(io_context_), running_(false), data_{} {} - - ~Impl() { stop(); } - - Impl(const Impl&) = delete; - Impl& operator=(const Impl&) = delete; - Impl(Impl&&) = delete; - Impl& operator=(Impl&&) = delete; - - void start(unsigned short port) { - if (running_) { - return; - } - - asio::ip::udp::endpoint endpoint(asio::ip::udp::v4(), port); - socket_.open(endpoint.protocol()); - socket_.bind(endpoint); - - running_ = true; - doReceive(); - - io_thread_ = std::thread([this] { io_context_.run(); }); - } - - void stop() { - if (!running_) { - return; - } - - running_ = false; - socket_.close(); - io_context_.stop(); - - if (io_thread_.joinable()) { - io_thread_.join(); - } - } - - [[nodiscard]] auto isRunning() const -> bool { return running_; } - - void addMessageHandler(MessageHandler handler) { - handlers_.push_back(std::move(handler)); - } - - void removeMessageHandler(MessageHandler handler) { - handlers_.erase( - std::remove_if( - handlers_.begin(), handlers_.end(), - [&](const MessageHandler& handlerToRemove) { - return handler.target() == - handlerToRemove.target(); - }), - handlers_.end()); - } - - void sendTo(const std::string& message, const std::string& ipAddress, - unsigned short port) { - if (!running_) { - std::cerr << "Server is not running." << std::endl; - return; - } - - asio::ip::udp::endpoint endpoint(asio::ip::make_address(ipAddress), - port); - socket_.async_send_to( - asio::buffer(message), endpoint, - [](std::error_code /*errorCode*/, std::size_t /*bytesSent*/) {}); - } - -private: - void doReceive() { - socket_.async_receive_from( - asio::buffer(data_), senderEndpoint_, - [this](std::error_code errorCode, std::size_t bytesReceived) { - if (!errorCode && bytesReceived > 0) { - std::string message(data_.data(), bytesReceived); - std::string senderIp = - senderEndpoint_.address().to_string(); - unsigned short senderPort = senderEndpoint_.port(); - - for (const auto& handler : handlers_) { - handler(message, senderIp, senderPort); - } - doReceive(); - } - }); - } - - asio::io_context io_context_; - asio::ip::udp::socket socket_; - asio::ip::udp::endpoint senderEndpoint_; - std::array data_; - std::vector handlers_; - std::thread io_thread_; - bool running_ = false; -}; - -UdpSocketHub::UdpSocketHub() : impl_(std::make_unique()) {} - -UdpSocketHub::~UdpSocketHub() = default; - -void UdpSocketHub::start(unsigned short port) { impl_->start(port); } - -void UdpSocketHub::stop() { impl_->stop(); } - -auto UdpSocketHub::isRunning() const -> bool { return impl_->isRunning(); } - -void UdpSocketHub::addMessageHandler(MessageHandler handler) { - impl_->addMessageHandler(std::move(handler)); -} - -void UdpSocketHub::removeMessageHandler(MessageHandler handler) { - impl_->removeMessageHandler(std::move(handler)); -} - -void UdpSocketHub::sendTo(const std::string& message, - const std::string& ipAddress, unsigned short port) { - impl_->sendTo(message, ipAddress, port); -} - -} // namespace atom::connection diff --git a/src/atom/connection/async_udpserver.hpp b/src/atom/connection/async_udpserver.hpp deleted file mode 100644 index a3e720cd..00000000 --- a/src/atom/connection/async_udpserver.hpp +++ /dev/null @@ -1,55 +0,0 @@ -/* - * udp_server.hpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2024-1-4 - -Description: A simple Asio-based UDP server. - -*************************************************/ - -#ifndef ATOM_CONNECTION_ASYNC_UDPSERVER_HPP -#define ATOM_CONNECTION_ASYNC_UDPSERVER_HPP - -#include -#include -#include - -namespace atom::async::connection { -/** - * @class UdpSocketHub - * @brief Represents a hub for managing UDP sockets and message handling using - * Asio. - */ -class UdpSocketHub { -public: - using MessageHandler = std::function; - - UdpSocketHub(); - ~UdpSocketHub(); - - UdpSocketHub(const UdpSocketHub&) = delete; - UdpSocketHub& operator=(const UdpSocketHub&) = delete; - - void start(unsigned short port); - void stop(); - bool isRunning() const; - - void addMessageHandler(MessageHandler handler); - void removeMessageHandler(MessageHandler handler); - void sendTo(const std::string& message, const std::string& ip, - unsigned short port); - -private: - class Impl; - std::unique_ptr impl_; -}; - -} // namespace atom::connection - -#endif diff --git a/src/atom/connection/fifoclient.cpp b/src/atom/connection/fifoclient.cpp deleted file mode 100644 index c85f6c16..00000000 --- a/src/atom/connection/fifoclient.cpp +++ /dev/null @@ -1,195 +0,0 @@ -/* - * fifoclient.cpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2023-6-1 - -Description: FIFO Client - -*************************************************/ - -#include "fifoclient.hpp" - -#include -#include - -#ifdef _WIN32 -#include -#else -#include -#include -#include -#include -#endif - -namespace atom::connection { -struct FifoClient::Impl { -#ifdef _WIN32 - HANDLE fifoHandle{nullptr}; -#else - int fifoFd{-1}; -#endif - std::string fifoPath; - -#ifdef _WIN32 - Impl(std::string_view path) : fifoPath(path) { - fifoHandle = - CreateFileA(fifoPath.c_str(), GENERIC_READ | GENERIC_WRITE, 0, - nullptr, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, nullptr); - if (fifoHandle == INVALID_HANDLE_VALUE) - throw std::runtime_error("Failed to open FIFO pipe"); - } -#else - Impl(std::string_view path) : fifoPath(path) { - fifoFd = open(fifoPath.c_str(), O_RDWR | O_NONBLOCK); - if (fifoFd == -1) - throw std::system_error(errno, std::generic_category(), - "Failed to open FIFO pipe"); - } -#endif - - ~Impl() { close(); } - - bool isOpen() const { -#ifdef _WIN32 - return fifoHandle != INVALID_HANDLE_VALUE; -#else - return fifoFd != -1; -#endif - } - - void close() { -#ifdef _WIN32 - if (isOpen()) { - CloseHandle(fifoHandle); - fifoHandle = INVALID_HANDLE_VALUE; - } -#else - if (isOpen()) { - ::close(fifoFd); - fifoFd = -1; - } -#endif - } - - bool write(std::string_view data, - std::optional timeout) { - std::vector buffer(data.begin(), data.end()); - buffer.push_back('\0'); - -#ifdef _WIN32 - DWORD bytesWritten; - if (timeout) { - COMMTIMEOUTS timeouts{}; - timeouts.WriteTotalTimeoutConstant = - static_cast(timeout->count()); - SetCommTimeouts(fifoHandle, &timeouts); - bool success = WriteFile(fifoHandle, buffer.data(), - static_cast(buffer.size()), - &bytesWritten, nullptr) != 0; - timeouts.WriteTotalTimeoutConstant = 0; - SetCommTimeouts(fifoHandle, &timeouts); - return success; - } - return WriteFile(fifoHandle, buffer.data(), - static_cast(buffer.size()), &bytesWritten, - nullptr) != 0; -#else - if (!timeout) { - return ::write(fifoFd, buffer.data(), buffer.size()) != -1; - } else { - fd_set writeFds; - FD_ZERO(&writeFds); - FD_SET(fifoFd, &writeFds); - timeval tv{}; - tv.tv_sec = timeout->count() / 1000; - tv.tv_usec = (timeout->count() % 1000) * 1000; - int result = select(fifoFd + 1, nullptr, &writeFds, nullptr, &tv); - if (result <= 0) - return false; - return ::write(fifoFd, buffer.data(), buffer.size()) != -1; - } -#endif - } - - std::optional read( - std::optional timeout) { - std::string data; - char buffer[1024]; - -#ifdef _WIN32 - DWORD bytesRead; - if (timeout) { - COMMTIMEOUTS timeouts{}; - timeouts.ReadTotalTimeoutConstant = - static_cast(timeout->count()); - SetCommTimeouts(fifoHandle, &timeouts); - if (ReadFile(fifoHandle, buffer, sizeof(buffer) - 1, &bytesRead, - nullptr) && - bytesRead > 0) { - buffer[bytesRead] = '\0'; - data += buffer; - } - timeouts.ReadTotalTimeoutConstant = 0; - SetCommTimeouts(fifoHandle, &timeouts); - } else { - while (ReadFile(fifoHandle, buffer, sizeof(buffer) - 1, &bytesRead, - nullptr) && - bytesRead > 0) { - buffer[bytesRead] = '\0'; - data += buffer; - } - } -#else - if (!timeout) { - ssize_t bytesRead; - while ((bytesRead = ::read(fifoFd, buffer, sizeof(buffer) - 1)) > - 0) { - buffer[bytesRead] = '\0'; - data += buffer; - } - } else { - fd_set readFds; - FD_ZERO(&readFds); - FD_SET(fifoFd, &readFds); - timeval tv{}; - tv.tv_sec = timeout->count() / 1000; - tv.tv_usec = (timeout->count() % 1000) * 1000; - int result = select(fifoFd + 1, &readFds, nullptr, nullptr, &tv); - if (result > 0) { - ssize_t bytesRead = ::read(fifoFd, buffer, sizeof(buffer) - 1); - if (bytesRead > 0) { - buffer[bytesRead] = '\0'; - data += buffer; - } - } - } -#endif - - return data.empty() ? std::nullopt : std::make_optional(data); - } -}; - -FifoClient::FifoClient(std::string fifoPath) - : m_impl(std::make_unique(fifoPath)) {} -FifoClient::~FifoClient() = default; - -bool FifoClient::write(std::string_view data, - std::optional timeout) { - return m_impl->write(data, timeout); -} - -std::optional FifoClient::read( - std::optional timeout) { - return m_impl->read(timeout); -} - -bool FifoClient::isOpen() const { return m_impl->isOpen(); } - -void FifoClient::close() { m_impl->close(); } - -} // namespace atom::connection diff --git a/src/atom/connection/fifoclient.hpp b/src/atom/connection/fifoclient.hpp deleted file mode 100644 index bfef84b6..00000000 --- a/src/atom/connection/fifoclient.hpp +++ /dev/null @@ -1,111 +0,0 @@ -/* - * fifoclient.hpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2023-6-1 - -Description: FIFO Client - -*************************************************/ - -#ifndef ATOM_CONNECTION_FIFOCLIENT_HPP -#define ATOM_CONNECTION_FIFOCLIENT_HPP - -#include -#include -#include -#include -#include - -namespace atom::connection { -/** - * @brief A class for interacting with a FIFO (First In, First Out) pipe. - * - * This class provides methods to read from and write to a FIFO pipe, - * handling timeouts and ensuring proper resource management. - */ -class FifoClient { -public: - /** - * @brief Constructs a FifoClient with the specified FIFO path. - * - * @param fifoPath The path to the FIFO file to be used for communication. - * - * This constructor opens the FIFO and prepares the client for - * reading and writing operations. - */ - explicit FifoClient(std::string fifoPath); - - /** - * @brief Destroys the FifoClient and closes the FIFO if it is open. - * - * This destructor ensures that all resources are released and the FIFO - * is properly closed to avoid resource leaks. - */ - ~FifoClient(); - - /** - * @brief Writes data to the FIFO. - * - * @param data The data to be written to the FIFO, as a string view. - * @param timeout Optional timeout for the write operation, in milliseconds. - * If not provided, the default is no timeout. - * @return true if the data was successfully written, false if there was an - * error. - * - * This method will attempt to write the specified data to the FIFO. - * If a timeout is specified, the operation will fail if it cannot complete - * within the given duration. - */ - auto write(std::string_view data, - std::optional timeout = std::nullopt) - -> bool; - - /** - * @brief Reads data from the FIFO. - * - * @param timeout Optional timeout for the read operation, in milliseconds. - * If not provided, the default is no timeout. - * @return An optional string containing the data read from the FIFO. - * If there is an error or no data is available, returns - * std::nullopt. - * - * This method will read data from the FIFO. If a timeout is specified, - * it will return std::nullopt if the operation cannot complete within the - * specified time. - */ - auto read(std::optional timeout = std::nullopt) - -> std::optional; - - /** - * @brief Checks if the FIFO is currently open. - * - * @return true if the FIFO is open, false otherwise. - * - * This method can be used to determine if the FIFO client is ready for - * operations. - */ - [[nodiscard]] auto isOpen() const -> bool; - - /** - * @brief Closes the FIFO. - * - * This method closes the FIFO and releases any associated resources. - * It is good practice to call this when you are done using the FIFO - * to ensure proper cleanup. - */ - void close(); - -private: - struct Impl; ///< Forward declaration of the implementation details. - std::unique_ptr m_impl; ///< Pointer to the implementation, using - ///< PImpl idiom for encapsulation. -}; - -} // namespace atom::connection - -#endif // ATOM_CONNECTION_FIFOCLIENT_HPP diff --git a/src/atom/connection/fifoserver.cpp b/src/atom/connection/fifoserver.cpp deleted file mode 100644 index cbcd3927..00000000 --- a/src/atom/connection/fifoserver.cpp +++ /dev/null @@ -1,145 +0,0 @@ -/* - * fifoserver.cpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2023-6-1 - -Description: FIFO Server - -*************************************************/ - -#include "fifoserver.hpp" - -#include -#include -#include -#include -#include -#include - -#ifdef _WIN32 -#include -#else -#include -#include -#include -#include -#endif - -namespace atom::connection { - -class FIFOServer::Impl { -public: - explicit Impl(std::string_view fifo_path) - : fifo_path_(fifo_path), stop_server_(false) { - // 创建 FIFO 文件 -#ifdef _WIN32 - CreateNamedPipeA(fifo_path_.c_str(), PIPE_ACCESS_DUPLEX, - PIPE_TYPE_MESSAGE | PIPE_READMODE_MESSAGE | PIPE_WAIT, - PIPE_UNLIMITED_INSTANCES, 4096, 4096, 0, NULL); -#elif __APPLE__ || __linux__ - mkfifo(fifo_path_.c_str(), 0666); -#endif - } - - ~Impl() { - stop(); - // 删除 FIFO 文件 -#ifdef _WIN32 - DeleteFileA(fifo_path_.c_str()); -#elif __APPLE__ || __linux__ - std::filesystem::remove(fifo_path_); -#endif - } - - void sendMessage(std::string message) { - { - std::scoped_lock lock(queue_mutex_); - message_queue_.emplace(std::move(message)); - } - message_cv_.notify_one(); - } - - void start() { - if (!server_thread_.joinable()) { - stop_server_ = false; - server_thread_ = std::jthread([this] { serverLoop(); }); - } - } - - void stop() { - if (server_thread_.joinable()) { - stop_server_ = true; - message_cv_.notify_one(); - server_thread_.join(); - } - } - - [[nodiscard]] bool isRunning() const { return server_thread_.joinable(); } - -private: - void serverLoop() { - while (!stop_server_) { - std::string message; - { - std::unique_lock lock(queue_mutex_); - message_cv_.wait(lock, [this] { - return stop_server_ || !message_queue_.empty(); - }); - if (stop_server_ && message_queue_.empty()) { - break; - } - if (!message_queue_.empty()) { - message = std::move(message_queue_.front()); - message_queue_.pop(); - } - } - -#ifdef _WIN32 - HANDLE pipe = CreateFileA(fifo_path_.c_str(), GENERIC_WRITE, 0, - NULL, OPEN_EXISTING, 0, NULL); - if (pipe != INVALID_HANDLE_VALUE) { - DWORD bytes_written; - WriteFile(pipe, message.c_str(), - static_cast(message.length()), &bytes_written, - NULL); - CloseHandle(pipe); - } -#elif __APPLE__ || __linux__ - int fd = open(fifo_path_.c_str(), O_WRONLY); - if (fd != -1) { - write(fd, message.c_str(), message.length()); - close(fd); - } -#endif - } - } - - std::string fifo_path_; - std::jthread server_thread_; - std::atomic_bool stop_server_; - std::queue message_queue_; - std::mutex queue_mutex_; - std::condition_variable message_cv_; -}; - -FIFOServer::FIFOServer(std::string_view fifo_path) - : impl_(std::make_unique(fifo_path)) {} - -FIFOServer::~FIFOServer() = default; - -void FIFOServer::sendMessage(std::string message) { - impl_->sendMessage(std::move(message)); -} - -void FIFOServer::start() { impl_->start(); } - -void FIFOServer::stop() { impl_->stop(); } - -bool FIFOServer::isRunning() const { return impl_->isRunning(); } - -} // namespace atom::connection diff --git a/src/atom/connection/fifoserver.hpp b/src/atom/connection/fifoserver.hpp deleted file mode 100644 index 2b71e241..00000000 --- a/src/atom/connection/fifoserver.hpp +++ /dev/null @@ -1,71 +0,0 @@ -/* - * fifoserver.hpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2023-6-1 - -Description: FIFO Server - -*************************************************/ - -#ifndef ATOM_CONNECTION_FIFOSERVER_HPP -#define ATOM_CONNECTION_FIFOSERVER_HPP - -#include -#include - -namespace atom::connection { - -/** - * @brief A class representing a server for handling FIFO messages. - */ -class FIFOServer { -public: - /** - * @brief Constructs a new FIFOServer object. - * - * @param fifo_path The path to the FIFO pipe. - */ - explicit FIFOServer(std::string_view fifo_path); - - /** - * @brief Destroys the FIFOServer object. - */ - ~FIFOServer(); - - /** - * @brief Sends a message through the FIFO pipe. - * - * @param message The message to be sent. - */ - void sendMessage(std::string message); - - /** - * @brief Starts the server. - */ - void start(); - - /** - * @brief Stops the server. - */ - void stop(); - - /** - * @brief Checks if the server is running. - * - * @return True if the server is running, false otherwise. - */ - [[nodiscard]] bool isRunning() const; - -private: - class Impl; - std::unique_ptr impl_; -}; - -} // namespace atom::connection - -#endif // ATOM_CONNECTION_FIFOSERVER_HPP diff --git a/src/atom/connection/sockethub.cpp b/src/atom/connection/sockethub.cpp deleted file mode 100644 index 0c942b47..00000000 --- a/src/atom/connection/sockethub.cpp +++ /dev/null @@ -1,356 +0,0 @@ -/* - * sockethub.cpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2023-6-1 - -Description: SocketHub类用于管理socket连接的类。 - -*************************************************/ - -#include "sockethub.hpp" - -#include -#include -#include -#include -#include -#include -#include - -#include "atom/log/loguru.hpp" - -#ifdef _WIN32 -#include -#include -#pragma comment(lib, "ws2_32.lib") -#else -#include -#include -#include -#include -#endif - -namespace atom::connection { - -class SocketHubImpl { -public: - SocketHubImpl() - : running_(false), - serverSocket(-1) -#ifdef __linux__ - , - epoll_fd(-1) -#endif - { - } - - ~SocketHubImpl() { stop(); } - - void start(int port); - void stop(); - void addHandler(std::function handler); - [[nodiscard]] auto isRunning() const -> bool; - -private: - static const int maxConnections = 10; - std::atomic running_; -#ifdef _WIN32 - SOCKET serverSocket; - std::vector clients; -#else - int serverSocket; - std::vector clients; - int epoll_fd; -#endif - std::map clientThreads_; - std::mutex clientMutex; -#if __cplusplus >= 202002L - std::jthread acceptThread; -#else - std::unique_ptr acceptThread; -#endif - - std::function handler; - - bool initWinsock(); - void cleanupWinsock(); -#ifdef _WIN32 - void closeSocket(SOCKET socket); -#else - void closeSocket(int socket); -#endif - void acceptConnections(); -#ifdef _WIN32 - void handleClientMessages(SOCKET clientSocket); -#else - void handleClientMessages(int clientSocket); -#endif - void cleanupSocket(); -}; - -SocketHub::SocketHub() : impl_(std::make_unique()) {} - -SocketHub::~SocketHub() = default; - -void SocketHub::start(int port) { impl_->start(port); } - -void SocketHub::stop() { impl_->stop(); } - -void SocketHub::addHandler(std::function handler) { - impl_->addHandler(std::move(handler)); -} - -auto SocketHub::isRunning() const -> bool { return impl_->isRunning(); } - -void SocketHubImpl::start(int port) { - if (running_.load()) { - LOG_F(WARNING, "SocketHub is already running."); - return; - } - - if (!initWinsock()) { - return; - } - - serverSocket = socket(AF_INET, SOCK_STREAM, 0); -#ifdef _WIN32 - if (serverSocket == INVALID_SOCKET) -#else - if (serverSocket < 0) -#endif - { - LOG_F(ERROR, "Failed to create server socket."); - cleanupWinsock(); - return; - } - - sockaddr_in serverAddress{}; - serverAddress.sin_family = AF_INET; - serverAddress.sin_addr.s_addr = INADDR_ANY; - serverAddress.sin_port = htons(port); - -#ifdef _WIN32 - if (bind(serverSocket, reinterpret_cast(&serverAddress), - sizeof(serverAddress)) == SOCKET_ERROR) -#else - if (bind(serverSocket, reinterpret_cast(&serverAddress), - sizeof(serverAddress)) < 0) -#endif - { - LOG_F(ERROR, "Failed to bind server socket."); - cleanupSocket(); - return; - } - -#ifdef _WIN32 - if (listen(serverSocket, maxConnections) == SOCKET_ERROR) -#else - if (listen(serverSocket, maxConnections) < 0) -#endif - { - LOG_F(ERROR, "Failed to listen on server socket."); - cleanupSocket(); - return; - } - -#ifdef __linux__ - epoll_fd = epoll_create1(0); - if (epoll_fd == -1) { - LOG_F(ERROR, "Failed to create epoll file descriptor."); - cleanupSocket(); - return; - } - - struct epoll_event event; - event.events = EPOLLIN; - event.data.fd = serverSocket; - if (epoll_ctl(epoll_fd, EPOLL_CTL_ADD, serverSocket, &event) == -1) { - LOG_F(ERROR, "Failed to add server socket to epoll."); - cleanupSocket(); - return; - } -#endif - - running_.store(true); - DLOG_F(INFO, "SocketHub started on port {}", port); - -#if __cplusplus >= 202002L - acceptThread = std::jthread(&SocketHubImpl::acceptConnections, this); -#else - acceptThread = - std::make_unique(&SocketHubImpl::acceptConnections, this); -#endif -} - -void SocketHubImpl::stop() { - if (!running_.load()) { - LOG_F(WARNING, "SocketHub is not running."); - return; - } - - running_.store(false); - - if (acceptThread.joinable()) { - acceptThread.join(); - } - - cleanupSocket(); - cleanupWinsock(); - DLOG_F(INFO, "SocketHub stopped."); -} - -void SocketHubImpl::addHandler(std::function handler) { - this->handler = std::move(handler); -} - -bool SocketHubImpl::initWinsock() { -#ifdef _WIN32 - WSADATA wsaData; - if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) { - LOG_F(ERROR, "Failed to initialize Winsock."); - return false; - } -#endif - return true; -} - -void SocketHubImpl::cleanupWinsock() { -#ifdef _WIN32 - WSACleanup(); -#endif -} - -#ifdef _WIN32 -void SocketHubImpl::closeSocket(SOCKET socket) { closesocket(socket); } -#else -void SocketHubImpl::closeSocket(int socket) { close(socket); } -#endif - -void SocketHubImpl::acceptConnections() { -#ifdef __linux__ - struct epoll_event events[maxConnections]; - while (running_.load()) { - int n = epoll_wait(epoll_fd, events, maxConnections, -1); - for (int i = 0; i < n; i++) { - if (events[i].data.fd == serverSocket) { - sockaddr_in clientAddress{}; - socklen_t clientAddressLength = sizeof(clientAddress); - int clientSocket = accept( - serverSocket, reinterpret_cast(&clientAddress), - &clientAddressLength); - - if (clientSocket < 0) { - if (running_.load()) { - LOG_F(ERROR, "Failed to accept client connection."); - } - continue; - } - - struct epoll_event event; - event.events = EPOLLIN | EPOLLET; - event.data.fd = clientSocket; - if (epoll_ctl(epoll_fd, EPOLL_CTL_ADD, clientSocket, &event) == - -1) { - LOG_F(ERROR, "Failed to add client socket to epoll."); - closeSocket(clientSocket); - continue; - } - - std::scoped_lock lock(clientMutex); - clients.push_back(clientSocket); - - clientThreads_[clientSocket] = std::jthread( - &SocketHubImpl::handleClientMessages, this, clientSocket); - } else { - handleClientMessages(events[i].data.fd); - } - } - } -#else - while (running_.load()) { - sockaddr_in clientAddress{}; - socklen_t clientAddressLength = sizeof(clientAddress); - - SOCKET clientSocket = - accept(serverSocket, reinterpret_cast(&clientAddress), - &clientAddressLength); - if (clientSocket == INVALID_SOCKET) { - if (running_.load()) { - LOG_F(ERROR, "Failed to accept client connection."); - } - continue; - } - - std::scoped_lock lock(clientMutex); - clients.push_back(clientSocket); - - std::jthread(&SocketHubImpl::handleClientMessages, this, clientSocket) - .detach(); - } -#endif -} - -#ifdef _WIN32 -void SocketHubImpl::handleClientMessages(SOCKET clientSocket) { -#else -void SocketHubImpl::handleClientMessages(int clientSocket) { -#endif - char buffer[1024]; - while (running_.load()) { - memset(buffer, 0, sizeof(buffer)); - int bytesRead = recv(clientSocket, buffer, sizeof(buffer), 0); - if (bytesRead <= 0) { - { - std::scoped_lock lock(clientMutex); - closeSocket(clientSocket); - clients.erase( - std::remove(clients.begin(), clients.end(), clientSocket), - clients.end()); - } -#ifdef __linux__ - clientThreads_.erase(clientSocket); -#endif - break; - } - - std::string message(buffer, bytesRead); - if (handler) { - handler(message); - } - } -} - -void SocketHubImpl::cleanupSocket() { - { - std::scoped_lock lock(clientMutex); - for (const auto &client : clients) { - closeSocket(client); - } - clients.clear(); - } - - closeSocket(serverSocket); - -#ifdef __linux__ - if (epoll_fd != -1) { - close(epoll_fd); - epoll_fd = -1; - } -#endif - - for (auto &pair : clientThreads_) { - if (pair.second.joinable()) { - pair.second.join(); - } - } - clientThreads_.clear(); -} - -auto SocketHubImpl::isRunning() const -> bool { return running_.load(); } - -} // namespace atom::connection diff --git a/src/atom/connection/sockethub.hpp b/src/atom/connection/sockethub.hpp deleted file mode 100644 index 1f3b2f32..00000000 --- a/src/atom/connection/sockethub.hpp +++ /dev/null @@ -1,96 +0,0 @@ -/* - * sockethub.hpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2023-6-1 - -Description: SocketHub class for managing socket connections. - -*************************************************/ - -#ifndef ATOM_CONNECTION_SOCKETHUB_HPP -#define ATOM_CONNECTION_SOCKETHUB_HPP - -#include -#include -#include - -namespace atom::connection { - -class SocketHubImpl; - -/** - * @class SocketHub - * @brief Manages socket connections. - * - * The SocketHub class is responsible for managing socket connections. - * It provides functionality to start and stop the socket service, and - * handles multiple client connections. For each client, it spawns a - * thread to handle incoming messages. The class allows for adding - * custom message handlers that are called when a message is received - * from a client. - */ -class SocketHub { -public: - /** - * @brief Constructs a SocketHub instance. - */ - SocketHub(); - - /** - * @brief Destroys the SocketHub instance. - * - * Cleans up resources and stops any ongoing socket operations. - */ - ~SocketHub(); - - /** - * @brief Starts the socket service. - * @param port The port number on which the socket service will listen. - * - * Initializes the socket service and starts listening for incoming - * connections on the specified port. It spawns threads to handle - * each connected client. - */ - void start(int port); - - /** - * @brief Stops the socket service. - * - * Shuts down the socket service, closes all client connections, - * and stops any running threads associated with handling client - * messages. - */ - void stop(); - - /** - * @brief Adds a message handler. - * @param handler A function to handle incoming messages from clients. - * - * The provided handler function will be called with the received - * message as a string parameter. Multiple handlers can be added - * and will be called in the order they are added. - */ - void addHandler(std::function handler); - - /** - * @brief Checks if the socket service is currently running. - * @return True if the socket service is running, false otherwise. - * - * This method returns the status of the socket service, indicating - * whether it is currently active and listening for connections. - */ - [[nodiscard]] auto isRunning() const -> bool; - -private: - std::unique_ptr - impl_; ///< Pointer to the implementation details of SocketHub. -}; - -} // namespace atom::connection - -#endif // ATOM_CONNECTION_SOCKETHUB_HPP diff --git a/src/atom/connection/sshclient.cpp b/src/atom/connection/sshclient.cpp deleted file mode 100644 index f634c464..00000000 --- a/src/atom/connection/sshclient.cpp +++ /dev/null @@ -1,308 +0,0 @@ -/* - * sshclient.cpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2023-6-1 - -Description: SSH Client - -*************************************************/ - -#include "sshclient.hpp" - -#include - -#include "atom/error/exception.hpp" - -namespace fs = std::filesystem; - -namespace atom::connection { -SSHClient::SSHClient(const std::string &host, int port) - : host_(host), port_(port), ssh_session_(nullptr), sftp_session_(nullptr) {} - -SSHClient::~SSHClient() { - if (sftp_session_) { - sftp_free(sftp_session_); - } - if (ssh_session_) { - ssh_disconnect(ssh_session_); - ssh_free(ssh_session_); - } -} - -void SSHClient::connect(const std::string &username, - const std::string &password, int timeout) { - ssh_session_ = ssh_new(); - if (!ssh_session_) { - THROW_RUNTIME_ERROR("Failed to create SSH session."); - } - - ssh_options_set(ssh_session_, SSH_OPTIONS_HOST, host_.c_str()); - ssh_options_set(ssh_session_, SSH_OPTIONS_PORT, &port_); - ssh_options_set(ssh_session_, SSH_OPTIONS_USER, username.c_str()); - ssh_options_set(ssh_session_, SSH_OPTIONS_TIMEOUT, &timeout); - - int rc = ssh_connect(ssh_session_); - if (rc != SSH_OK) { - THROW_RUNTIME_ERROR("Failed to connect to SSH server: " + - std::string(ssh_get_error(ssh_session_))); - } - - rc = ssh_userauth_password(ssh_session_, nullptr, password.c_str()); - if (rc != SSH_AUTH_SUCCESS) { - THROW_RUNTIME_ERROR("Failed to authenticate with SSH server: " + - std::string(ssh_get_error(ssh_session_))); - } - - sftp_session_ = sftp_new(ssh_session_); - if (!sftp_session_) { - THROW_RUNTIME_ERROR("Failed to create SFTP session."); - } - - rc = sftp_init(sftp_session_); - if (rc != SSH_OK) { - THROW_RUNTIME_ERROR("Failed to initialize SFTP session: " + - std::string(ssh_get_error(ssh_session_))); - } -} - -bool SSHClient::isConnected() const { - return (ssh_session_ != nullptr && sftp_session_ != nullptr); -} - -void SSHClient::disconnect() { - if (sftp_session_) { - sftp_free(sftp_session_); - sftp_session_ = nullptr; - } - if (ssh_session_) { - ssh_disconnect(ssh_session_); - ssh_free(ssh_session_); - ssh_session_ = nullptr; - } -} - -void SSHClient::executeCommand(const std::string &command, - std::vector &output) { - ssh_channel channel = ssh_channel_new(ssh_session_); - if (!channel) { - THROW_RUNTIME_ERROR("Failed to create SSH channel."); - } - - int rc = ssh_channel_open_session(channel); - if (rc != SSH_OK) { - ssh_channel_free(channel); - THROW_RUNTIME_ERROR("Failed to open SSH channel: " + - std::string(ssh_get_error(ssh_session_))); - } - - rc = ssh_channel_request_exec(channel, command.c_str()); - if (rc != SSH_OK) { - ssh_channel_close(channel); - ssh_channel_free(channel); - THROW_RUNTIME_ERROR("Failed to execute command: " + - std::string(ssh_get_error(ssh_session_))); - } - - char buffer[256]; - int nbytes = 0; - while ((nbytes = ssh_channel_read(channel, buffer, sizeof(buffer), 0)) > - 0) { - output.emplace_back(buffer, nbytes); - } - - if (nbytes < 0) { - ssh_channel_close(channel); - ssh_channel_free(channel); - THROW_RUNTIME_ERROR("Failed to read command output: " + - std::string(ssh_get_error(ssh_session_))); - } - - ssh_channel_send_eof(channel); - ssh_channel_close(channel); - ssh_channel_free(channel); -} - -void SSHClient::executeCommands(const std::vector &commands, - std::vector> &output) { - ssh_channel channel = ssh_channel_new(ssh_session_); - if (!channel) { - THROW_RUNTIME_ERROR("Failed to create SSH channel."); - } - - int rc = ssh_channel_open_session(channel); - if (rc != SSH_OK) { - ssh_channel_free(channel); - THROW_RUNTIME_ERROR("Failed to open SSH channel: " + - std::string(ssh_get_error(ssh_session_))); - } - - for (const auto &cmd : commands) { - rc = ssh_channel_request_exec(channel, cmd.c_str()); - if (rc != SSH_OK) { - ssh_channel_close(channel); - ssh_channel_free(channel); - THROW_RUNTIME_ERROR("Failed to execute command: " + - std::string(ssh_get_error(ssh_session_))); - } - - std::vector cmd_output; - char buffer[256]; - int nbytes = 0; - while ((nbytes = ssh_channel_read(channel, buffer, sizeof(buffer), 0)) > - 0) { - cmd_output.emplace_back(buffer, nbytes); - } - - if (nbytes < 0) { - ssh_channel_close(channel); - ssh_channel_free(channel); - THROW_RUNTIME_ERROR("Failed to read command output: " + - std::string(ssh_get_error(ssh_session_))); - } - - ssh_channel_send_eof(channel); - output.push_back(std::move(cmd_output)); - } - - ssh_channel_close(channel); - ssh_channel_free(channel); -} - -bool SSHClient::fileExists(const std::string &remote_path) const { - sftp_attributes attrs = sftp_stat(sftp_session_, remote_path.c_str()); - if (attrs) { - sftp_attributes_free(attrs); - return true; - } else { - return false; - } -} - -void SSHClient::createDirectory(const std::string &remote_path, int mode) { - int rc = sftp_mkdir(sftp_session_, remote_path.c_str(), mode); - if (rc != SSH_OK) { - THROW_RUNTIME_ERROR("Failed to create remote directory: " + - remote_path); - } -} - -void SSHClient::removeFile(const std::string &remote_path) { - int rc = sftp_unlink(sftp_session_, remote_path.c_str()); - if (rc != SSH_OK) { - THROW_RUNTIME_ERROR("Failed to remove remote file: " + remote_path); - } -} - -void SSHClient::removeDirectory(const std::string &remote_path) { - int rc = sftp_rmdir(sftp_session_, remote_path.c_str()); - if (rc != SSH_OK) { - THROW_RUNTIME_ERROR("Failed to remove remote directory: " + - remote_path); - } -} - -std::vector SSHClient::listDirectory( - const std::string &remote_path) const { - std::vector file_list; - sftp_dir dir = sftp_opendir(sftp_session_, remote_path.c_str()); - if (dir) { - sftp_attributes attributes; - while ((attributes = sftp_readdir(sftp_session_, dir)) != NULL) { - file_list.push_back(attributes->name); - sftp_attributes_free(attributes); - } - sftp_closedir(dir); - } - return file_list; -} - -void SSHClient::rename(const std::string &old_path, - const std::string &new_path) { - int rc = sftp_rename(sftp_session_, old_path.c_str(), new_path.c_str()); - if (rc != SSH_OK) { - THROW_RUNTIME_ERROR("Failed to rename remote file or directory: " + - old_path + " to " + new_path); - } -} - -void SSHClient::getFileInfo(const std::string &remote_path, - sftp_attributes &attrs) { - attrs = sftp_stat(sftp_session_, remote_path.c_str()); - if (!attrs) { - THROW_RUNTIME_ERROR("Failed to get file info for remote path: " + - remote_path); - } -} - -void SSHClient::downloadFile(const std::string &remote_path, - const std::string &local_path) { - sftp_file file = - sftp_open(sftp_session_, remote_path.c_str(), OFN_READONLY, 0); - if (!file) { - THROW_RUNTIME_ERROR("Failed to open remote file for download: " + - remote_path); - } - - FILE *fp = fopen(local_path.c_str(), "wb"); - if (!fp) { - sftp_close(file); - THROW_RUNTIME_ERROR("Failed to open local file for download: " + - local_path); - } - - char buffer[256]; - int nbytes = 0; - while ((nbytes = sftp_read(file, buffer, sizeof(buffer))) > 0) { - fwrite(buffer, 1, nbytes, fp); - } - - fclose(fp); - sftp_close(file); -} - -void SSHClient::uploadFile(const std::string &local_path, - const std::string &remote_path) { - sftp_file file = - sftp_open(sftp_session_, remote_path.c_str(), OF_CREATE, OF_WRITE); - if (!file) { - THROW_RUNTIME_ERROR("Failed to open remote file for upload: " + - remote_path); - } - - FILE *fp = fopen(local_path.c_str(), "rb"); - if (!fp) { - sftp_close(file); - THROW_RUNTIME_ERROR("Failed to open local file for upload: " + - local_path); - } - - char buffer[256]; - int nbytes = 0; - while ((nbytes = fread(buffer, 1, sizeof(buffer), fp)) > 0) { - sftp_write(file, buffer, nbytes); - } - - fclose(fp); - sftp_close(file); -} - -void SSHClient::uploadDirectory(const std::string &local_path, - const std::string &remote_path) { - for (const auto &entry : fs::recursive_directory_iterator(local_path)) { - const auto &path = entry.path(); - auto relativePath = fs::relative(path, local_path); - auto remoteFilePath = remote_path + "/" + relativePath.string(); - - if (entry.is_directory()) { - createDirectory(remoteFilePath); - } else if (entry.is_regular_file()) { - uploadFile(path.string(), remoteFilePath); - } - } -} -} // namespace atom::connection diff --git a/src/atom/connection/sshclient.hpp b/src/atom/connection/sshclient.hpp deleted file mode 100644 index 3bd17e8b..00000000 --- a/src/atom/connection/sshclient.hpp +++ /dev/null @@ -1,195 +0,0 @@ -/* - * sshclient.hpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2023-6-1 - -Description: SSH Client - -*************************************************/ - -#ifndef ATOM_CONNECTION_SSHCLIENT_HPP -#define ATOM_CONNECTION_SSHCLIENT_HPP - -#include -#include -#include -#include -#include -#include - -#if __has_include() -#include -#include - -namespace atom::connection { - -constexpr int DEFAULT_SSH_PORT = 22; -constexpr int DEFAULT_TIMEOUT = 10; -constexpr int DEFAULT_MODE = S_NORMAL; - -/** - * @class SSHClient - * @brief A class for SSH client connection and file operations. - */ -class SSHClient { -public: - /** - * @brief Constructor. - * @param host The hostname or IP address of the SSH server. - * @param port The port number of the SSH server. Default is 22. - */ - explicit SSHClient(const std::string &host, int port = DEFAULT_SSH_PORT); - - /** - * @brief Destructor. - */ - ~SSHClient(); - - // Copy constructor - SSHClient(const SSHClient &other) = default; - - // Copy assignment operator - auto operator=(const SSHClient &other) -> SSHClient & = default; - - // Move constructor - SSHClient(SSHClient &&other) noexcept = default; - - // Move assignment operator - auto operator=(SSHClient &&other) noexcept -> SSHClient & = default; - - /** - * @brief Connects to the SSH server. - * @param username The username for authentication. - * @param password The password for authentication. - * @param timeout The connection timeout in seconds. Default is 10 seconds. - * @throws std::runtime_error if connection or authentication fails. - */ - void connect(const std::string &username, const std::string &password, - int timeout = DEFAULT_TIMEOUT); - - /** - * @brief Checks if the SSH client is connected to the server. - * @return true if connected, false otherwise. - */ - [[nodiscard]] auto isConnected() const -> bool; - - /** - * @brief Disconnects from the SSH server. - */ - void disconnect(); - - /** - * @brief Executes a single command on the SSH server. - * @param command The command to execute. - * @param output Output vector to store the command output. - * @throws std::runtime_error if command execution fails. - */ - void executeCommand(const std::string &command, - std::vector &output); - - /** - * @brief Executes multiple commands on the SSH server. - * @param commands Vector of commands to execute. - * @param output Vector of vectors to store the command outputs. - * @throws std::runtime_error if any command execution fails. - */ - void executeCommands(const std::vector &commands, - std::vector> &output); - - /** - * @brief Checks if a file exists on the remote server. - * @param remote_path The path of the remote file. - * @return true if the file exists, false otherwise. - */ - [[nodiscard]] auto fileExists(const std::string &remote_path) const -> bool; - - /** - * @brief Creates a directory on the remote server. - * @param remote_path The path of the remote directory. - * @param mode The permissions of the directory. Default is S_NORMAL. - * @throws std::runtime_error if directory creation fails. - */ - void createDirectory(const std::string &remote_path, - int mode = DEFAULT_MODE); - - /** - * @brief Removes a file from the remote server. - * @param remote_path The path of the remote file. - * @throws std::runtime_error if file removal fails. - */ - void removeFile(const std::string &remote_path); - - /** - * @brief Removes a directory from the remote server. - * @param remote_path The path of the remote directory. - * @throws std::runtime_error if directory removal fails. - */ - void removeDirectory(const std::string &remote_path); - - /** - * @brief Lists the contents of a directory on the remote server. - * @param remote_path The path of the remote directory. - * @return Vector of strings containing the names of the directory contents. - * @throws std::runtime_error if listing directory fails. - */ - auto listDirectory(const std::string &remote_path) const - -> std::vector; - - /** - * @brief Renames a file or directory on the remote server. - * @param old_path The current path of the remote file or directory. - * @param new_path The new path of the remote file or directory. - * @throws std::runtime_error if renaming fails. - */ - void rename(const std::string &old_path, const std::string &new_path); - - /** - * @brief Retrieves file information for a remote file. - * @param remote_path The path of the remote file. - * @param attrs Attribute struct to store the file information. - * @throws std::runtime_error if getting file information fails. - */ - void getFileInfo(const std::string &remote_path, sftp_attributes &attrs); - - /** - * @brief Downloads a file from the remote server. - * @param remote_path The path of the remote file. - * @param local_path The path of the local destination file. - * @throws std::runtime_error if file download fails. - */ - void downloadFile(const std::string &remote_path, - const std::string &local_path); - - /** - * @brief Uploads a file to the remote server. - * @param local_path The path of the local source file. - * @param remote_path The path of the remote destination file. - * @throws std::runtime_error if file upload fails. - */ - void uploadFile(const std::string &local_path, - const std::string &remote_path); - - /** - * @brief Uploads a directory to the remote server. - * @param local_path The path of the local source directory. - * @param remote_path The path of the remote destination directory. - * @throws std::runtime_error if directory upload fails. - */ - void uploadDirectory(const std::string &local_path, - const std::string &remote_path); - -private: - std::string host_; - int port_; - ssh_session ssh_session_; - sftp_session sftp_session_; -}; -} // namespace atom::connection -#endif - -#endif diff --git a/src/atom/connection/sshserver.cpp b/src/atom/connection/sshserver.cpp deleted file mode 100644 index 07d869ae..00000000 --- a/src/atom/connection/sshserver.cpp +++ /dev/null @@ -1,296 +0,0 @@ -/* - * sshserver.cpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2024-5-24 - -Description: SSH Server - -*************************************************/ - -#include "sshserver.hpp" - -#include -#include -#include -#include -#include -#include -#include - -#ifdef _WIN32 -#include -#include -#else -#include -#include -#endif - -#include "atom/error/exception.hpp" - -namespace atom::connection { -class SshServer::Impl { -public: - explicit Impl(const std::filesystem::path& configFile) - : configFile_(configFile) { - loadConfig(); - } - - void start() { - if (isRunning()) { - THROW_RUNTIME_ERROR("SSH server is already running"); - } - - saveConfig(); - -#ifdef _WIN32 - std::string command = - "start /b sshd -f \"" + configFile_.string() + "\""; - system(command.c_str()); -#else - std::string command = - "/usr/sbin/sshd -f \"" + configFile_.string() + "\" -D &"; - system(command.c_str()); -#endif - } - - void stop() { - if (!isRunning()) { - THROW_RUNTIME_ERROR("SSH server is not running"); - } - -#ifdef _WIN32 - system("taskkill /F /IM sshd.exe > nul"); -#else - system("pkill -f sshd"); -#endif - } - - bool isRunning() const { -#ifdef _WIN32 - HANDLE snapshot = CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0); - if (snapshot == INVALID_HANDLE_VALUE) { - return false; - } - - PROCESSENTRY32 entry{}; - entry.dwSize = sizeof(entry); - - if (!Process32First(snapshot, &entry)) { - CloseHandle(snapshot); - return false; - } - - do { - if (_stricmp(entry.szExeFile, "sshd.exe") == 0) { - CloseHandle(snapshot); - return true; - } - } while (Process32Next(snapshot, &entry)); - - CloseHandle(snapshot); - return false; -#else - std::array buffer{}; - std::string result; - std::unique_ptr pipe(popen("pgrep sshd", "r"), - pclose); - if (!pipe) { - THROW_RUNTIME_ERROR("Failed to execute pgrep command"); - } - while (fgets(buffer.data(), buffer.size(), pipe.get()) != nullptr) { - result += buffer.data(); - } - return !result.empty(); -#endif - } - - void setPort(int port) { port_ = port; } - - int getPort() const { return port_; } - - void setListenAddress(const std::string& address) { - listenAddress_ = address; - } - - std::string getListenAddress() const { return listenAddress_; } - - void setHostKey(const std::filesystem::path& keyFile) { - hostKey_ = keyFile; - } - - std::filesystem::path getHostKey() const { return hostKey_; } - - void setAuthorizedKeys(const std::vector& keyFiles) { - authorizedKeys_ = keyFiles; - } - - std::vector getAuthorizedKeys() const { - return authorizedKeys_; - } - - void allowRootLogin(bool allow) { allowRootLogin_ = allow; } - - bool isRootLoginAllowed() const { return allowRootLogin_; } - - void setPasswordAuthentication(bool enable) { - passwordAuthentication_ = enable; - } - - bool isPasswordAuthenticationEnabled() const { - return passwordAuthentication_; - } - - void setSubsystem(const std::string& name, const std::string& command) { - subsystems_[name] = command; - } - - void removeSubsystem(const std::string& name) { subsystems_.erase(name); } - - std::string getSubsystem(const std::string& name) const { - auto it = subsystems_.find(name); - if (it != subsystems_.end()) { - return it->second; - } - return {}; - } - -private: - void loadConfig() { - std::ifstream file(configFile_); - if (!file) { - THROW_RUNTIME_ERROR( - "Failed to open SSH server configuration file"); - } - - std::string line; - while (std::getline(file, line)) { - std::istringstream iss(line); - std::string key, value; - if (std::getline(iss, key, ' ') && std::getline(iss, value)) { - if (key == "Port") { - port_ = std::stoi(value); - } else if (key == "ListenAddress") { - listenAddress_ = value; - } else if (key == "HostKey") { - hostKey_ = value; - } else if (key == "AuthorizedKeysFile") { - authorizedKeys_.push_back(value); - } else if (key == "PermitRootLogin") { - allowRootLogin_ = (value == "yes"); - } else if (key == "PasswordAuthentication") { - passwordAuthentication_ = (value == "yes"); - } else if (key == "Subsystem") { - std::istringstream subsystemIss(value); - std::string subsystemName, subsystemCommand; - if (std::getline(subsystemIss, subsystemName, ' ') && - std::getline(subsystemIss, subsystemCommand)) { - subsystems_[subsystemName] = subsystemCommand; - } - } - } - } - } - - void saveConfig() { - std::ofstream file(configFile_); - if (!file) { - THROW_RUNTIME_ERROR( - "Failed to save SSH server configuration file"); - } - - file << "Port " << port_ << '\n'; - file << "ListenAddress " << listenAddress_ << '\n'; - file << "HostKey " << hostKey_.string() << '\n'; - for (const auto& keyFile : authorizedKeys_) { - file << "AuthorizedKeysFile " << keyFile.string() << '\n'; - } - file << "PermitRootLogin " << (allowRootLogin_ ? "yes" : "no") << '\n'; - file << "PasswordAuthentication " - << (passwordAuthentication_ ? "yes" : "no") << '\n'; - for (const auto& [name, command] : subsystems_) { - file << "Subsystem " << name << " " << command << '\n'; - } - } - - std::filesystem::path configFile_; - int port_ = 22; - std::string listenAddress_ = "0.0.0.0"; - std::filesystem::path hostKey_; - std::vector authorizedKeys_; - bool allowRootLogin_ = false; - bool passwordAuthentication_ = false; - std::unordered_map subsystems_; -}; - -SshServer::SshServer(const std::filesystem::path& configFile) - : impl_(std::make_unique(configFile)) {} - -SshServer::~SshServer() = default; - -void SshServer::start() { impl_->start(); } - -void SshServer::stop() { impl_->stop(); } - -bool SshServer::isRunning() const { return impl_->isRunning(); } - -void SshServer::setPort(int port) { impl_->setPort(port); } - -int SshServer::getPort() const { return impl_->getPort(); } - -void SshServer::setListenAddress(const std::string& address) { - impl_->setListenAddress(address); -} - -std::string SshServer::getListenAddress() const { - return impl_->getListenAddress(); -} - -void SshServer::setHostKey(const std::filesystem::path& keyFile) { - impl_->setHostKey(keyFile); -} - -std::filesystem::path SshServer::getHostKey() const { - return impl_->getHostKey(); -} - -void SshServer::setAuthorizedKeys( - const std::vector& keyFiles) { - impl_->setAuthorizedKeys(keyFiles); -} - -std::vector SshServer::getAuthorizedKeys() const { - return impl_->getAuthorizedKeys(); -} - -void SshServer::allowRootLogin(bool allow) { impl_->allowRootLogin(allow); } - -bool SshServer::isRootLoginAllowed() const { - return impl_->isRootLoginAllowed(); -} - -void SshServer::setPasswordAuthentication(bool enable) { - impl_->setPasswordAuthentication(enable); -} - -bool SshServer::isPasswordAuthenticationEnabled() const { - return impl_->isPasswordAuthenticationEnabled(); -} - -void SshServer::setSubsystem(const std::string& name, - const std::string& command) { - impl_->setSubsystem(name, command); -} - -void SshServer::removeSubsystem(const std::string& name) { - impl_->removeSubsystem(name); -} - -std::string SshServer::getSubsystem(const std::string& name) const { - return impl_->getSubsystem(name); -} -} // namespace atom::connection diff --git a/src/atom/connection/sshserver.hpp b/src/atom/connection/sshserver.hpp deleted file mode 100644 index b36ea71e..00000000 --- a/src/atom/connection/sshserver.hpp +++ /dev/null @@ -1,216 +0,0 @@ -/* - * sshserver.hpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2024-5-24 - -Description: SSH Server - -*************************************************/ - -#ifndef ATOM_CONNECTION_SSHSERVER_HPP -#define ATOM_CONNECTION_SSHSERVER_HPP - -#include -#include -#include -#include - -#include "atom/type/noncopyable.hpp" - -#include "atom/macro.hpp" - -namespace atom::connection { -/** - * @class SshServer - * @brief Represents an SSH server for handling secure shell connections. - * - * This class provides methods to configure and manage an SSH server, handling - * client connections and user authentication through various methods including - * public key and password authentication. - */ -class SshServer : public NonCopyable { -public: - /** - * @brief Constructor for SshServer. - * - * Initializes the SSH server with a specified configuration file. - * - * @param configFile The path to the configuration file for the SSH server. - */ - explicit SshServer(const std::filesystem::path& configFile); - - /** - * @brief Destructor for SshServer. - * - * Cleans up resources used by the SSH server. - */ - ~SshServer() override; - - /** - * @brief Starts the SSH server. - * - * This method will begin listening for incoming connections on the - * configured port and address. - */ - void start(); - - /** - * @brief Stops the SSH server. - * - * This method will stop the server from accepting new connections and - * cleanly shut down any existing connections. - */ - void stop(); - - /** - * @brief Checks if the SSH server is currently running. - * - * @return true if the server is running, false otherwise. - */ - ATOM_NODISCARD auto isRunning() const -> bool; - - /** - * @brief Sets the port on which the SSH server listens for connections. - * - * @param port The port number to listen on. - * - * This method updates the server's listening port to the specified value. - */ - void setPort(int port); - - /** - * @brief Gets the port on which the SSH server is listening. - * - * @return The current listening port. - */ - ATOM_NODISCARD auto getPort() const -> int; - - /** - * @brief Sets the address on which the SSH server listens for connections. - * - * @param address The IP address or hostname for listening. - * - * The server will bind to this address, allowing connections from it. - */ - void setListenAddress(const std::string& address); - - /** - * @brief Gets the address on which the SSH server is listening. - * - * @return The current listening address as a string. - */ - ATOM_NODISCARD auto getListenAddress() const -> std::string; - - /** - * @brief Sets the host key file used for SSH connections. - * - * @param keyFile The path to the host key file. - * - * The host key is used to establish the identity of the server, - * enabling secure communication with clients. - */ - void setHostKey(const std::filesystem::path& keyFile); - - /** - * @brief Gets the path to the host key file. - * - * @return The current host key file path. - */ - ATOM_NODISCARD auto getHostKey() const -> std::filesystem::path; - - /** - * @brief Sets the list of authorized public key files for user - * authentication. - * - * @param keyFiles A vector of paths to public key files. - * - * This method updates the SSH server to allow authentication using the - * specified public keys. - */ - void setAuthorizedKeys(const std::vector& keyFiles); - - /** - * @brief Gets the list of authorized public key files. - * - * @return A vector of paths to authorized public key files. - */ - ATOM_NODISCARD auto getAuthorizedKeys() const - -> std::vector; - - /** - * @brief Enables or disables root login to the SSH server. - * - * @param allow true to permit root login, false to deny it. - * - * This method must be configured with caution, as enabling root login - * can pose a security risk. - */ - void allowRootLogin(bool allow); - - /** - * @brief Checks if root login is allowed. - * - * @return true if root login is permitted, false otherwise. - */ - ATOM_NODISCARD auto isRootLoginAllowed() const -> bool; - - /** - * @brief Enables or disables password authentication for the SSH server. - * - * @param enable true to enable password authentication, false to disable - * it. - */ - void setPasswordAuthentication(bool enable); - - /** - * @brief Checks if password authentication is enabled. - * - * @return true if password authentication is enabled, false otherwise. - */ - ATOM_NODISCARD auto isPasswordAuthenticationEnabled() const -> bool; - - /** - * @brief Sets a subsystem for handling a specific command. - * - * @param name The name of the subsystem. - * @param command The command that the subsystem will execute. - * - * This allows for additional functionality to be added to the SSH server, - * such as file transfers or other custom commands. - */ - void setSubsystem(const std::string& name, const std::string& command); - - /** - * @brief Removes a previously set subsystem by name. - * - * @param name The name of the subsystem to remove. - * - * After this method is called, the subsystem will no longer be available. - */ - void removeSubsystem(const std::string& name); - - /** - * @brief Gets the command associated with a subsystem by name. - * - * @param name The name of the subsystem. - * @return The command associated with the subsystem. - * - * If the subsystem does not exist, an empty string may be returned. - */ - ATOM_NODISCARD auto getSubsystem(const std::string& name) const - -> std::string; - -private: - class Impl; ///< Forward declaration of the implementation class. - std::unique_ptr impl_; ///< Pointer to the implementation object - ///< holding the core functionalities. -}; - -} // namespace atom::connection - -#endif // ATOM_CONNECTION_SSHSERVER_HPP diff --git a/src/atom/connection/tcpclient.cpp b/src/atom/connection/tcpclient.cpp deleted file mode 100644 index da33e119..00000000 --- a/src/atom/connection/tcpclient.cpp +++ /dev/null @@ -1,300 +0,0 @@ -/* - * tcpclient.cpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2024-5-24 - -Description: TCP Client Class - -*************************************************/ - -#include "tcpclient.hpp" - -#include -#include -#include -#include - -#ifdef _WIN32 -#include -#include -#pragma comment(lib, "ws2_32.lib") -#else -#include -#include -#include -#include -#include -#endif - -#include "atom/error/exception.hpp" - -namespace atom::connection { -class TcpClient::Impl { -public: - Impl() { -#ifdef _WIN32 - WSADATA wsaData; - int result = WSAStartup(MAKEWORD(2, 2), &wsaData); - if (result != 0) { - THROW_RUNTIME_ERROR("WSAStartup failed"); - } -#endif - socket_ = socket(AF_INET, SOCK_STREAM, 0); - if (socket_ < 0) { - THROW_RUNTIME_ERROR("Socket creation failed"); - } - -#ifdef __linux__ - epoll_fd_ = epoll_create1(0); - if (epoll_fd_ == -1) { - THROW_RUNTIME_ERROR("Failed to create epoll file descriptor"); - } -#endif - } - - ~Impl() { - disconnect(); -#ifdef _WIN32 - WSACleanup(); -#endif -#ifdef __linux__ - close(epoll_fd_); -#endif - } - - bool connect(const std::string& host, int port, - std::chrono::milliseconds timeout) { - struct hostent* server = gethostbyname(host.c_str()); - if (server == nullptr) { - errorMessage_ = "Host not found"; - return false; - } - - struct sockaddr_in serverAddress {}; - serverAddress.sin_family = AF_INET; - std::memcpy(&serverAddress.sin_addr.s_addr, server->h_addr, - server->h_length); - serverAddress.sin_port = htons(port); - - if (timeout > std::chrono::milliseconds::zero()) { -#ifdef _WIN32 - DWORD tv = timeout.count(); - setsockopt(socket_, SOL_SOCKET, SO_RCVTIMEO, - reinterpret_cast(&tv), sizeof(tv)); - setsockopt(socket_, SOL_SOCKET, SO_SNDTIMEO, - reinterpret_cast(&tv), sizeof(tv)); -#else - struct timeval tv; - tv.tv_sec = timeout.count() / 1000; - tv.tv_usec = (timeout.count() % 1000) * 1000; - setsockopt(socket_, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)); - setsockopt(socket_, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)); -#endif - } - - if (::connect(socket_, - reinterpret_cast(&serverAddress), - sizeof(serverAddress)) < 0) { - errorMessage_ = "Connection failed"; - return false; - } - - connected_ = true; - -#ifdef __linux__ - struct epoll_event event; - event.events = EPOLLIN | EPOLLOUT; - event.data.fd = socket_; - if (epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, socket_, &event) == -1) { - errorMessage_ = "Failed to add file descriptor to epoll"; - return false; - } -#endif - - return true; - } - - void disconnect() { - if (connected_) { -#ifdef _WIN32 - closesocket(socket_); -#else - close(socket_); -#endif - connected_ = false; - } - } - - bool send(const std::vector& data) { - if (!connected_) { - errorMessage_ = "Not connected"; - return false; - } - - if (::send(socket_, data.data(), data.size(), 0) < 0) { - errorMessage_ = "Send failed"; - return false; - } - - return true; - } - - std::future> receive( - size_t size, - std::chrono::milliseconds timeout = std::chrono::milliseconds::zero()) { - return std::async(std::launch::async, [this, size, timeout] { - if (timeout > std::chrono::milliseconds::zero()) { -#ifdef _WIN32 - DWORD tv = timeout.count(); - setsockopt(socket_, SOL_SOCKET, SO_RCVTIMEO, - reinterpret_cast(&tv), sizeof(tv)); -#else - struct timeval tv; - tv.tv_sec = timeout.count() / 1000; - tv.tv_usec = (timeout.count() % 1000) * 1000; - setsockopt(socket_, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)); -#endif - } - - std::vector data(size); - ssize_t bytesRead = ::recv(socket_, data.data(), size, 0); - if (bytesRead < 0) { - errorMessage_ = "Receive failed"; - return std::vector{}; - } - data.resize(bytesRead); - return data; - }); - } - - [[nodiscard]] bool isConnected() const { return connected_; } - - [[nodiscard]] std::string getErrorMessage() const { return errorMessage_; } - - void setOnConnectedCallback(const OnConnectedCallback& callback) { - onConnectedCallback_ = callback; - } - - void setOnDisconnectedCallback(const OnDisconnectedCallback& callback) { - onDisconnectedCallback_ = callback; - } - - void setOnDataReceivedCallback(const OnDataReceivedCallback& callback) { - onDataReceivedCallback_ = callback; - } - - void setOnErrorCallback(const OnErrorCallback& callback) { - onErrorCallback_ = callback; - } - - void startReceiving(size_t bufferSize) { - stopReceiving(); - receivingThread_ = std::thread(&Impl::receivingLoop, this, bufferSize); - } - - void stopReceiving() { - if (receivingThread_.joinable()) { - receivingStopped_ = true; - receivingThread_.join(); - receivingStopped_ = false; - } - } - -private: - void receivingLoop(size_t bufferSize) { -#ifdef __linux__ - struct epoll_event events[10]; -#endif - while (!receivingStopped_) { -#ifdef __linux__ - int n = epoll_wait(epoll_fd_, events, 10, -1); - for (int i = 0; i < n; i++) { - if (events[i].events & EPOLLIN) { - std::vector data = receive(bufferSize).get(); - if (!data.empty() && onDataReceivedCallback_) { - onDataReceivedCallback_(data); - } - } - } -#else - std::vector data = receive(bufferSize).get(); - if (!data.empty() && onDataReceivedCallback_) { - onDataReceivedCallback_(data); - } -#endif - } - } - -#ifdef _WIN32 - SOCKET socket_; -#else - int socket_; - int epoll_fd_; -#endif - bool connected_ = false; - std::string errorMessage_; - - OnConnectedCallback onConnectedCallback_; - OnDisconnectedCallback onDisconnectedCallback_; - OnDataReceivedCallback onDataReceivedCallback_; - OnErrorCallback onErrorCallback_; - - std::thread receivingThread_; - bool receivingStopped_ = false; -}; - -TcpClient::TcpClient() : impl_(std::make_unique()) {} - -TcpClient::~TcpClient() = default; - -bool TcpClient::connect(const std::string& host, int port, - std::chrono::milliseconds timeout) { - return impl_->connect(host, port, timeout); -} - -void TcpClient::disconnect() { impl_->disconnect(); } - -bool TcpClient::send(const std::vector& data) { - return impl_->send(data); -} - -std::future> TcpClient::receive( - size_t size, std::chrono::milliseconds timeout) { - return impl_->receive(size, timeout); -} - -bool TcpClient::isConnected() const { return impl_->isConnected(); } - -std::string TcpClient::getErrorMessage() const { - return impl_->getErrorMessage(); -} - -void TcpClient::setOnConnectedCallback(const OnConnectedCallback& callback) { - impl_->setOnConnectedCallback(callback); -} - -void TcpClient::setOnDisconnectedCallback( - const OnDisconnectedCallback& callback) { - impl_->setOnDisconnectedCallback(callback); -} - -void TcpClient::setOnDataReceivedCallback( - const OnDataReceivedCallback& callback) { - impl_->setOnDataReceivedCallback(callback); -} - -void TcpClient::setOnErrorCallback(const OnErrorCallback& callback) { - impl_->setOnErrorCallback(callback); -} - -void TcpClient::startReceiving(size_t bufferSize) { - impl_->startReceiving(bufferSize); -} - -void TcpClient::stopReceiving() { impl_->stopReceiving(); } -} // namespace atom::connection diff --git a/src/atom/connection/tcpclient.hpp b/src/atom/connection/tcpclient.hpp deleted file mode 100644 index 580705b9..00000000 --- a/src/atom/connection/tcpclient.hpp +++ /dev/null @@ -1,147 +0,0 @@ -/* - * tcpclient.hpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2024-5-24 - -Description: TCP Client Class - -*************************************************/ - -#ifndef ATOM_CONNECTION_TCPCLIENT_HPP -#define ATOM_CONNECTION_TCPCLIENT_HPP - -#include -#include -#include -#include -#include -#include - -#include "atom/type/noncopyable.hpp" - -namespace atom::connection { -/** - * @class TcpClient - * @brief Represents a TCP client for connecting to a server and - * sending/receiving data. - */ -class TcpClient : public NonCopyable { -public: - using OnConnectedCallback = - std::function; /**< Type definition for connected callback - function. */ - using OnDisconnectedCallback = - std::function; /**< Type definition for disconnected callback - function. */ - using OnDataReceivedCallback = std::function&)>; /**< Type definition for data received - callback function. */ - using OnErrorCallback = - std::function; /**< Type definition for error - callback function. */ - - /** - * @brief Constructor. - */ - TcpClient(); - - /** - * @brief Destructor. - */ - ~TcpClient() override; - - /** - * @brief Connects to a TCP server. - * @param host The hostname or IP address of the server. - * @param port The port number of the server. - * @param timeout The connection timeout duration. - * @return True if the connection is successful, false otherwise. - */ - auto connect(const std::string& host, int port, - std::chrono::milliseconds timeout = - std::chrono::milliseconds::zero()) -> bool; - - /** - * @brief Disconnects from the server. - */ - void disconnect(); - - /** - * @brief Sends data to the server. - * @param data The data to be sent. - * @return True if the data is sent successfully, false otherwise. - */ - auto send(const std::vector& data) -> bool; - - /** - * @brief Receives data from the server. - * @param size The number of bytes to receive. - * @param timeout The receive timeout duration. - * @return The received data. - */ - auto receive(size_t size, std::chrono::milliseconds timeout = - std::chrono::milliseconds::zero()) - -> std::future>; - - /** - * @brief Checks if the client is connected to the server. - * @return True if connected, false otherwise. - */ - [[nodiscard]] auto isConnected() const -> bool; - - /** - * @brief Gets the error message in case of any error. - * @return The error message. - */ - [[nodiscard]] auto getErrorMessage() const -> std::string; - - /** - * @brief Sets the callback function to be called when connected to the - * server. - * @param callback The callback function. - */ - void setOnConnectedCallback(const OnConnectedCallback& callback); - - /** - * @brief Sets the callback function to be called when disconnected from the - * server. - * @param callback The callback function. - */ - void setOnDisconnectedCallback(const OnDisconnectedCallback& callback); - - /** - * @brief Sets the callback function to be called when data is received from - * the server. - * @param callback The callback function. - */ - void setOnDataReceivedCallback(const OnDataReceivedCallback& callback); - - /** - * @brief Sets the callback function to be called when an error occurs. - * @param callback The callback function. - */ - void setOnErrorCallback(const OnErrorCallback& callback); - - /** - * @brief Starts receiving data from the server. - * @param bufferSize The size of the receive buffer. - */ - void startReceiving(size_t bufferSize); - - /** - * @brief Stops receiving data from the server. - */ - void stopReceiving(); - -private: - class Impl; /**< Forward declaration of the implementation class. */ - std::unique_ptr impl_; /**< Pointer to the implementation object. */ -}; -} // namespace atom::connection - -#endif // ATOM_CONNECTION_TCPCLIENT_HPP diff --git a/src/atom/connection/ttybase.cpp b/src/atom/connection/ttybase.cpp deleted file mode 100644 index 0f5924a7..00000000 --- a/src/atom/connection/ttybase.cpp +++ /dev/null @@ -1,696 +0,0 @@ -#include "ttybase.hpp" - -#if defined(_WIN32) || defined(_WIN64) -#include -#else -#include -#include -#include -#include -#endif - -#include "atom/log/loguru.hpp" - -TTYBase::~TTYBase() { - if (m_PortFD != -1) { - disconnect(); - } -} - -TTYBase::TTYResponse TTYBase::checkTimeout(uint8_t timeout) { -#ifdef _WIN32 - // Windows specific implementation - COMMTIMEOUTS timeouts = {0}; - timeouts.ReadIntervalTimeout = timeout; - timeouts.ReadTotalTimeoutConstant = timeout * 1000; - timeouts.ReadTotalTimeoutMultiplier = 0; - timeouts.WriteTotalTimeoutConstant = timeout * 1000; - timeouts.WriteTotalTimeoutMultiplier = 0; - - if (!SetCommTimeouts(reinterpret_cast(m_PortFD), &timeouts)) - return TTYResponse::Errno; - - return TTYResponse::OK; -#else - if (m_PortFD == -1) { - return TTYResponse::Errno; - } - - struct timeval tv; - fd_set readout; - int retval; - - FD_ZERO(&readout); - FD_SET(m_PortFD, &readout); - - tv.tv_sec = timeout; - tv.tv_usec = 0; - - retval = select(m_PortFD + 1, &readout, nullptr, nullptr, &tv); - - if (retval > 0) { - return TTYResponse::OK; - } - if (retval == -1) { - return TTYResponse::SelectError; - } - return TTYResponse::Timeout; -#endif -} - -TTYBase::TTYResponse TTYBase::write(const uint8_t* buffer, uint32_t nbytes, - uint32_t& nbytesWritten) { - if (m_PortFD == -1) - return TTYResponse::Errno; - -#ifdef _WIN32 - // Windows specific write implementation - DWORD bytesWritten; - if (!WriteFile(reinterpret_cast(m_PortFD), buffer, nbytes, - &bytesWritten, nullptr)) - return TTYResponse::WriteError; - - nbytesWritten = bytesWritten; - return TTYResponse::OK; -#else - int bytesW = 0; - nbytesWritten = 0; - - while (nbytes > 0) { - bytesW = ::write(m_PortFD, buffer + nbytesWritten, nbytes); - - if (bytesW < 0) { - return TTYResponse::WriteError; - } - - nbytesWritten += bytesW; - nbytes -= bytesW; - } - - return TTYResponse::OK; -#endif -} - -TTYBase::TTYResponse TTYBase::writeString(std::string_view string, - uint32_t& nbytesWritten) { - return write(reinterpret_cast(string.data()), string.size(), - nbytesWritten); -} - -TTYBase::TTYResponse TTYBase::read(uint8_t* buffer, uint32_t nbytes, - uint8_t timeout, uint32_t& nbytesRead) { - if (m_PortFD == -1) { - return TTYResponse::Errno; - } - -#ifdef _WIN32 - // Windows specific read implementation - DWORD bytesRead; - if (!ReadFile(reinterpret_cast(m_PortFD), buffer, nbytes, - &bytesRead, nullptr)) - return TTYResponse::ReadError; - - nbytesRead = bytesRead; - return TTYResponse::OK; -#else - uint32_t numBytesToRead = nbytes; - int bytesRead = 0; - TTYResponse timeoutResponse = TTYResponse::OK; - nbytesRead = 0; - - while (numBytesToRead > 0) { - if ((timeoutResponse = checkTimeout(timeout)) != TTYResponse::OK) { - return timeoutResponse; - } - - bytesRead = ::read(m_PortFD, buffer + nbytesRead, numBytesToRead); - - if (bytesRead < 0) { - return TTYResponse::ReadError; - } - - nbytesRead += bytesRead; - numBytesToRead -= bytesRead; - } - - return TTYResponse::OK; -#endif -} - -TTYBase::TTYResponse TTYBase::readSection(uint8_t* buffer, uint32_t nsize, - uint8_t stopByte, uint8_t timeout, - uint32_t& nbytesRead) { - if (m_PortFD == -1) { - return TTYResponse::Errno; - } - - nbytesRead = 0; - memset(buffer, 0, nsize); - - while (nbytesRead < nsize) { - if (auto timeoutResponse = checkTimeout(timeout); - timeoutResponse != TTYResponse::OK) { - return timeoutResponse; - } - - uint8_t readChar; - int bytesRead = ::read(m_PortFD, &readChar, 1); - - if (bytesRead < 0) { - return TTYResponse::ReadError; - } - - buffer[nbytesRead++] = readChar; - - if (readChar == stopByte) { - return TTYResponse::OK; - } - } - - return TTYResponse::Overflow; -} - -TTYBase::TTYResponse TTYBase::connect(std::string_view device, uint32_t bitRate, - uint8_t wordSize, uint8_t parity, - uint8_t stopBits) { -#ifdef _WIN32 - // Windows specific implementation - HANDLE hSerial = - CreateFile(device.data(), GENERIC_READ | GENERIC_WRITE, 0, nullptr, - OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, nullptr); - if (hSerial == INVALID_HANDLE_VALUE) - return TTYResponse::PortFailure; - - DCB dcbSerialParams = {0}; - dcbSerialParams.DCBlength = sizeof(dcbSerialParams); - - if (!GetCommState(hSerial, &dcbSerialParams)) { - CloseHandle(hSerial); - return TTYResponse::PortFailure; - } - - dcbSerialParams.BaudRate = bitRate; - dcbSerialParams.ByteSize = wordSize; - dcbSerialParams.StopBits = (stopBits == 1) ? ONESTOPBIT : TWOSTOPBITS; - dcbSerialParams.Parity = parity; - - if (!SetCommState(hSerial, &dcbSerialParams)) { - CloseHandle(hSerial); - return TTYResponse::PortFailure; - } - - m_PortFD = reinterpret_cast(hSerial); - return TTYResponse::OK; -#elif defined(BSD) && !defined(__GNU__) - int t_fd = -1; - int bps; - int handshake; - struct termios tty_setting; - - // Open the serial port read/write, with no controlling terminal, and don't - // wait for a connection. The O_NONBLOCK flag also causes subsequent I/O on - // the device to be non-blocking. See open(2) ("man 2 open") for details. - - t_fd = open(device, O_RDWR | O_NOCTTY | O_NONBLOCK); - if (t_fd == -1) { - DEBUGFDEVICE(m_DriverName, m_DebugChannel, - "Error opening serial port (%s) - %s(%d).", device, - strerror(errno), errno); - goto error; - } - - // Note that open() follows POSIX semantics: multiple open() calls to the - // same file will succeed unless the TIOCEXCL ioctl is issued. This will - // prevent additional opens except by root-owned processes. See tty(4) ("man - // 4 tty") and ioctl(2) ("man 2 ioctl") for details. - - if (ioctl(t_fd, TIOCEXCL) == -1) { - DEBUGFDEVICE(m_DriverName, m_DebugChannel, - "Error setting TIOCEXCL on %s - %s(%d).", device, - strerror(errno), errno); - goto error; - } - - // Now that the device is open, clear the O_NONBLOCK flag so subsequent I/O - // will block. See fcntl(2) ("man 2 fcntl") for details. - - if (fcntl(t_fd, F_SETFL, 0) == -1) { - DEBUGFDEVICE(m_DriverName, m_DebugChannel, - "Error clearing O_NONBLOCK %s - %s(%d).", device, - strerror(errno), errno); - goto error; - } - - // Get the current options and save them so we can restore the default - // settings later. - if (tcgetattr(t_fd, &tty_setting) == -1) { - DEBUGFDEVICE(m_DriverName, m_DebugChannel, - "Error getting tty attributes %s - %s(%d).", device, - strerror(errno), errno); - goto error; - } - - // Set raw input (non-canonical) mode, with reads blocking until either a - // single character has been received or a one second timeout expires. See - // tcsetattr(4) ("man 4 tcsetattr") and termios(4) ("man 4 termios") for - // details. - - cfmakeraw(&tty_setting); - tty_setting.c_cc[VMIN] = 1; - tty_setting.c_cc[VTIME] = 10; - - // The baud rate, word length, and handshake options can be set as follows: - switch (bit_rate) { - case 0: - bps = B0; - break; - case 50: - bps = B50; - break; - case 75: - bps = B75; - break; - case 110: - bps = B110; - break; - case 134: - bps = B134; - break; - case 150: - bps = B150; - break; - case 200: - bps = B200; - break; - case 300: - bps = B300; - break; - case 600: - bps = B600; - break; - case 1200: - bps = B1200; - break; - case 1800: - bps = B1800; - break; - case 2400: - bps = B2400; - break; - case 4800: - bps = B4800; - break; - case 9600: - bps = B9600; - break; - case 19200: - bps = B19200; - break; - case 38400: - bps = B38400; - break; - case 57600: - bps = B57600; - break; - case 115200: - bps = B115200; - break; - case 230400: - bps = B230400; - break; - default: - DEBUGFDEVICE(m_DriverName, m_DebugChannel, - "connect: %d is not a valid bit rate.", bit_rate); - return TTY_PARAM_ERROR; - } - - cfsetspeed(&tty_setting, bps); // Set baud rate - /* word size */ - switch (word_size) { - case 5: - tty_setting.c_cflag |= CS5; - break; - case 6: - tty_setting.c_cflag |= CS6; - break; - case 7: - tty_setting.c_cflag |= CS7; - break; - case 8: - tty_setting.c_cflag |= CS8; - break; - default: - DEBUGFDEVICE(m_DriverName, m_DebugChannel, - "connect: %d is not a valid data bit count.", - word_size); - return TTY_PARAM_ERROR; - } - - /* parity */ - switch (parity) { - case PARITY_NONE: - break; - case PARITY_EVEN: - tty_setting.c_cflag |= PARENB; - break; - case PARITY_ODD: - tty_setting.c_cflag |= PARENB | PARODD; - break; - default: - DEBUGFDEVICE(m_DriverName, m_DebugChannel, - "connect: %d is not a valid parity selection value.", - parity); - return TTY_PARAM_ERROR; - } - - /* stop_bits */ - switch (stop_bits) { - case 1: - break; - case 2: - tty_setting.c_cflag |= CSTOPB; - break; - default: - DEBUGFDEVICE(m_DriverName, m_DebugChannel, - "connect: %d is not a valid stop bit count.", - stop_bits); - return TTY_PARAM_ERROR; - } - -#if defined(MAC_OS_X_VERSION_10_4) && \ - (MAC_OS_X_VERSION_MIN_REQUIRED >= MAC_OS_X_VERSION_10_4) - // Starting with Tiger, the IOSSIOSPEED ioctl can be used to set arbitrary - // baud rates other than those specified by POSIX. The driver for the - // underlying serial hardware ultimately determines which baud rates can be - // used. This ioctl sets both the input and output speed. - - speed_t speed = 14400; // Set 14400 baud - if (ioctl(t_fd, IOSSIOSPEED, &speed) == -1) { - IDLog("Error calling ioctl(..., IOSSIOSPEED, ...) - %s(%d).\n", - strerror(errno), errno); - } -#endif - - // Cause the new options to take effect immediately. - if (tcsetattr(t_fd, TCSANOW, &tty_setting) == -1) { - DEBUGFDEVICE(m_DriverName, m_DebugChannel, - "Error setting tty attributes %s - %s(%d).", device, - strerror(errno), errno); - goto error; - } - - // To set the modem handshake lines, use the following ioctls. - // See tty(4) ("man 4 tty") and ioctl(2) ("man 2 ioctl") for details. - - if (ioctl(t_fd, TIOCSDTR) == -1) // Assert Data Terminal Ready (DTR) - { - DEBUGFDEVICE(m_DriverName, m_DebugChannel, - "Error asserting DTR %s - %s(%d).", device, - strerror(errno), errno); - } - - if (ioctl(t_fd, TIOCCDTR) == -1) // Clear Data Terminal Ready (DTR) - { - DEBUGFDEVICE(m_DriverName, m_DebugChannel, - "Error clearing DTR %s - %s(%d).", device, strerror(errno), - errno); - } - - handshake = TIOCM_DTR | TIOCM_RTS | TIOCM_CTS | TIOCM_DSR; - if (ioctl(t_fd, TIOCMSET, &handshake) == -1) - // Set the modem lines depending on the bits set in handshake - { - DEBUGFDEVICE(m_DriverName, m_DebugChannel, - "Error setting handshake lines %s - %s(%d).", device, - strerror(errno), errno); - } - - // To read the state of the modem lines, use the following ioctl. - // See tty(4) ("man 4 tty") and ioctl(2) ("man 2 ioctl") for details. - - if (ioctl(t_fd, TIOCMGET, &handshake) == -1) - // Store the state of the modem lines in handshake - { - DEBUGFDEVICE(m_DriverName, m_DebugChannel, - "Error getting handshake lines %s - %s(%d).", device, - strerror(errno), errno); - } - - DEBUGFDEVICE(m_DriverName, m_DebugChannel, - "Handshake lines currently set to %d", handshake); - -#if defined(MAC_OS_X_VERSION_10_3) && \ - (MAC_OS_X_VERSION_MIN_REQUIRED >= MAC_OS_X_VERSION_10_3) - unsigned long mics = 1UL; - - // Set the receive latency in microseconds. Serial drivers use this value to - // determine how often to dequeue characters received by the hardware. Most - // applications don't need to set this value: if an app reads lines of - // characters, the app can't do anything until the line termination - // character has been received anyway. The most common applications which - // are sensitive to read latency are MIDI and IrDA applications. - - if (ioctl(t_fd, IOSSDATALAT, &mics) == -1) { - // set latency to 1 microsecond - DEBUGFDEVICE(m_DriverName, m_DebugChannel, - "Error setting read latency %s - %s(%d).\n", device, - strerror(errno), errno); - goto error; - } -#endif - - m_PortFD = t_fd; - /* return success */ - return TTY_OK; - - // Failure path -error: - if (t_fd != -1) { - close(t_fd); - m_PortFD = -1; - } - - return TTY_PORT_FAILURE; -#else - int tFd = open(device.data(), O_RDWR | O_NOCTTY); - if (tFd == -1) { - LOG_F(ERROR, "Error opening {}: {}", device.data(), strerror(errno)); - m_PortFD = -1; - return TTYResponse::PortFailure; - } - - termios ttySetting{}; - if (tcgetattr(tFd, &ttySetting) == -1) { - LOG_F(ERROR, "Error getting {} tty attributes: {}", device.data(), - strerror(errno)); - return TTYResponse::PortFailure; - } - - int bps; - switch (bitRate) { - case 0: - bps = B0; - break; - case 50: - bps = B50; - break; - case 75: - bps = B75; - break; - case 110: - bps = B110; - break; - case 134: - bps = B134; - break; - case 150: - bps = B150; - break; - case 200: - bps = B200; - break; - case 300: - bps = B300; - break; - case 600: - bps = B600; - break; - case 1200: - bps = B1200; - break; - case 1800: - bps = B1800; - break; - case 2400: - bps = B2400; - break; - case 4800: - bps = B4800; - break; - case 9600: - bps = B9600; - break; - case 19200: - bps = B19200; - break; - case 38400: - bps = B38400; - break; - case 57600: - bps = B57600; - break; - case 115200: - bps = B115200; - break; - case 230400: - bps = B230400; - break; - default: - LOG_F(ERROR, "connect: {} is not a valid bit rate.", bitRate); - return TTYResponse::ParamError; - } - - // Set baud rate - if ((cfsetispeed(&ttySetting, bps) < 0) || - (cfsetospeed(&ttySetting, bps) < 0)) { - LOG_F(ERROR, "connect: failed setting bit rate."); - return TTYResponse::PortFailure; - } - - ttySetting.c_cflag &= ~(CSIZE | CSTOPB | PARENB | PARODD | HUPCL | CRTSCTS); - ttySetting.c_cflag |= (CLOCAL | CREAD); - - // Set word size - switch (wordSize) { - case 5: - ttySetting.c_cflag |= CS5; - break; - case 6: - ttySetting.c_cflag |= CS6; - break; - case 7: - ttySetting.c_cflag |= CS7; - break; - case 8: - ttySetting.c_cflag |= CS8; - break; - default: - LOG_F(ERROR, "connect: {} is not a valid data bit count.", - wordSize); - return TTYResponse::ParamError; - } - - // Set parity - if (parity == 1) { - ttySetting.c_cflag |= PARENB; - } else if (parity == 2) { - ttySetting.c_cflag |= PARENB | PARODD; - } else { - LOG_F(ERROR, "connect: {} is not a valid parity setting.", parity); - return TTYResponse::ParamError; - } - - // Set stop bits - if (stopBits == 2) { - ttySetting.c_cflag |= CSTOPB; - } else if (stopBits != 1) { - LOG_F(ERROR, "connect: {} is not a valid stop bit count.", stopBits); - return TTYResponse::ParamError; - } - - /* Ignore bytes with parity errors and make terminal raw and dumb.*/ - ttySetting.c_iflag &= - ~(PARMRK | ISTRIP | IGNCR | ICRNL | INLCR | IXOFF | IXON | IXANY); - ttySetting.c_iflag |= INPCK | IGNPAR | IGNBRK; - - /* Raw output.*/ - ttySetting.c_oflag &= ~(OPOST | ONLCR); - - /* Local Modes - Don't echo characters. Don't generate signals. - Don't process any characters.*/ - ttySetting.c_lflag &= - ~(ICANON | ECHO | ECHOE | ISIG | IEXTEN | NOFLSH | TOSTOP); - ttySetting.c_lflag |= NOFLSH; - - /* blocking read until 1 char arrives */ - ttySetting.c_cc[VMIN] = 1; - ttySetting.c_cc[VTIME] = 0; - - tcflush(tFd, TCIOFLUSH); - - // Set raw input mode (non-canonical) - cfmakeraw(&ttySetting); - - // Set the new attributes for the port - if (tcsetattr(tFd, TCSANOW, &ttySetting) != 0) { - close(tFd); - return TTYResponse::PortFailure; - } - - m_PortFD = tFd; - return TTYResponse::OK; -#endif -} - -TTYBase::TTYResponse TTYBase::disconnect() { - if (m_PortFD == -1) { - return TTYResponse::Errno; - } - -#ifdef _WIN32 - // Windows specific disconnection - if (!CloseHandle(reinterpret_cast(m_PortFD))) - return TTYResponse::Errno; - - m_PortFD = -1; - return TTYResponse::OK; -#else - if (tcflush(m_PortFD, TCIOFLUSH) != 0 || close(m_PortFD) != 0) { - return TTYResponse::Errno; - } - - m_PortFD = -1; - return TTYResponse::OK; -#endif -} - -void TTYBase::setDebug(bool enabled) { - m_Debug = enabled; - if (m_Debug) - LOG_F(INFO, "Debugging enabled."); - else - LOG_F(INFO, "Debugging disabled."); -} - -std::string TTYBase::getErrorMessage(TTYResponse code) const { - switch (code) { - case TTYResponse::OK: - return "No Error"; - case TTYResponse::ReadError: - return "Read Error: " + std::string(strerror(errno)); - case TTYResponse::WriteError: - return "Write Error: " + std::string(strerror(errno)); - case TTYResponse::SelectError: - return "Select Error: " + std::string(strerror(errno)); - case TTYResponse::Timeout: - return "Timeout Error"; - case TTYResponse::PortFailure: - if (errno == EACCES) { - return "Port failure: Access denied. Try adding your user to " - "the dialout group and restart (sudo adduser $USER " - "dialout)"; - } else { - return "Port failure: " + std::string(strerror(errno)) + - ". Check if the device is connected to this port."; - } - case TTYResponse::ParamError: - return "Parameter Error"; - case TTYResponse::Errno: - return "Error: " + std::string(strerror(errno)); - case TTYResponse::Overflow: - return "Read Overflow Error"; - default: - return "Unknown Error"; - } -} diff --git a/src/atom/connection/ttybase.hpp b/src/atom/connection/ttybase.hpp deleted file mode 100644 index 93fced49..00000000 --- a/src/atom/connection/ttybase.hpp +++ /dev/null @@ -1,165 +0,0 @@ -#ifndef ATOM_CONNECTION_TTYBASE_HPP -#define ATOM_CONNECTION_TTYBASE_HPP - -#include -#include -#include -#include -#include -#include -#include -#include - -// Windows specific includes -#ifdef _WIN32 -#include -#undef min -#undef max -#endif - -/** - * @class TTYBase - * @brief Provides a base class for handling TTY (Teletypewriter) connections. - * - * This class serves as an interface for reading from and writing to TTY - * devices, handling various responses and errors associated with the - * communication. - */ -class TTYBase { -public: - /** - * @enum TTYResponse - * @brief Enumerates possible responses from TTY operations. - */ - enum class TTYResponse { - OK = 0, ///< Operation completed successfully. - ReadError = -1, ///< Error occurred while reading from the TTY. - WriteError = -2, ///< Error occurred while writing to the TTY. - SelectError = -3, ///< Error occurred while selecting the TTY device. - Timeout = -4, ///< Operation timed out. - PortFailure = -5, ///< Failed to connect to the TTY port. - ParamError = -6, ///< Invalid parameters provided to a function. - Errno = -7, ///< An error occurred as indicated by errno. - Overflow = -8 ///< Buffer overflow occurred during an operation. - }; - - /** - * @brief Constructs a TTYBase instance with the specified driver name. - * - * @param driverName The name of the TTY driver to be used. - */ - explicit TTYBase(std::string_view driverName) : m_DriverName(driverName) {} - - /** - * @brief Destructor for TTYBase. - * - * Cleans up resources associated with the TTY connection. - */ - virtual ~TTYBase(); - - /** - * @brief Reads data from the TTY device. - * - * @param buffer Pointer to the buffer where read data will be stored. - * @param nbytes The number of bytes to read from the TTY. - * @param timeout Timeout duration for the read operation in seconds. - * @param nbytesRead Reference to store the actual number of bytes read. - * @return TTYResponse indicating the result of the read operation. - */ - TTYResponse read(uint8_t* buffer, uint32_t nbytes, uint8_t timeout, - uint32_t& nbytesRead); - - /** - * @brief Reads a section of data from the TTY until a stop byte is - * encountered. - * - * @param buffer Pointer to the buffer where read data will be stored. - * @param nsize The maximum number of bytes to read. - * @param stopByte The byte value that will stop the reading. - * @param timeout Timeout duration for the read operation in seconds. - * @param nbytesRead Reference to store the actual number of bytes read. - * @return TTYResponse indicating the result of the read operation. - */ - TTYResponse readSection(uint8_t* buffer, uint32_t nsize, uint8_t stopByte, - uint8_t timeout, uint32_t& nbytesRead); - - /** - * @brief Writes data to the TTY device. - * - * @param buffer Pointer to the data to be written. - * @param nbytes The number of bytes to write to the TTY. - * @param nbytesWritten Reference to store the actual number of bytes - * written. - * @return TTYResponse indicating the result of the write operation. - */ - TTYResponse write(const uint8_t* buffer, uint32_t nbytes, - uint32_t& nbytesWritten); - - /** - * @brief Writes a string to the TTY device. - * - * @param string The string to be written to the TTY. - * @param nbytesWritten Reference to store the actual number of bytes - * written. - * @return TTYResponse indicating the result of the write operation. - */ - TTYResponse writeString(std::string_view string, uint32_t& nbytesWritten); - - /** - * @brief Connects to the specified TTY device. - * - * @param device The device name or path to connect to. - * @param bitRate The baud rate for the connection. - * @param wordSize The data size (in bits) of each character. - * @param parity The parity checking mode (e.g. none, odd, even). - * @param stopBits The number of stop bits to use in communication. - * @return TTYResponse indicating the result of the connection attempt. - */ - TTYResponse connect(std::string_view device, uint32_t bitRate, - uint8_t wordSize, uint8_t parity, uint8_t stopBits); - - /** - * @brief Disconnects from the TTY device. - * - * @return TTYResponse indicating the result of the disconnection. - */ - TTYResponse disconnect(); - - /** - * @brief Enables or disables debugging information. - * - * @param enabled true to enable debugging, false to disable it. - */ - void setDebug(bool enabled); - - /** - * @brief Retrieves an error message corresponding to a given TTYResponse - * code. - * - * @param code The TTYResponse code for which to get the error message. - * @return A string containing the error message. - */ - std::string getErrorMessage(TTYResponse code) const; - - /** - * @brief Gets the file descriptor for the TTY port. - * - * @return The integer file descriptor for the TTY port. - */ - int getPortFD() const { return m_PortFD; } - -private: - /** - * @brief Checks for timeouts. - * - * @param timeout The timeout duration to check. - * @return TTYResponse indicating the result of the timeout check. - */ - TTYResponse checkTimeout(uint8_t timeout); - - int m_PortFD{-1}; ///< File descriptor for the TTY port. - bool m_Debug{false}; ///< Flag indicating whether debugging is enabled. - std::string_view m_DriverName; ///< The name of the driver for this TTY. -}; - -#endif diff --git a/src/atom/connection/udpclient.cpp b/src/atom/connection/udpclient.cpp deleted file mode 100644 index 17b1f62a..00000000 --- a/src/atom/connection/udpclient.cpp +++ /dev/null @@ -1,237 +0,0 @@ -/* - * udpclient.cpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2024-5-24 - -Description: UDP Client Class - -*************************************************/ - -#include "udpclient.hpp" -#include -#include -#include - -#ifdef _WIN32 -#include -#include -#pragma comment(lib, "ws2_32.lib") -#include -#else -#include -#include -#include -#include -#include -#endif - -#include "atom/error/exception.hpp" - -namespace atom::connection { -class UdpClient::Impl { -public: - Impl() { -#ifdef _WIN32 - WSADATA wsaData; - int result = WSAStartup(MAKEWORD(2, 2), &wsaData); - if (result != 0) { - THROW_RUNTIME_ERROR("WSAStartup failed"); - } -#endif - socket_ = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP); - if (socket_ < 0) { - THROW_RUNTIME_ERROR("Socket creation failed"); - } -#ifdef __linux__ - epoll_fd_ = epoll_create1(0); - if (epoll_fd_ == -1) { - THROW_RUNTIME_ERROR("Epoll creation failed"); - } -#endif - } - - ~Impl() { - stopReceiving(); -#ifdef _WIN32 - closesocket(socket_); - WSACleanup(); -#else - close(socket_); - close(epoll_fd_); -#endif - } - - bool bind(int port) { - struct sockaddr_in address {}; - address.sin_family = AF_INET; - address.sin_addr.s_addr = INADDR_ANY; - address.sin_port = htons(port); - - if (::bind(socket_, reinterpret_cast(&address), - sizeof(address)) < 0) { - errorMessage_ = "Bind failed"; - return false; - } - - return true; - } - - bool send(const std::string& host, int port, - const std::vector& data) { - struct hostent* server = gethostbyname(host.c_str()); - if (server == nullptr) { - errorMessage_ = "Host not found"; - return false; - } - - struct sockaddr_in address {}; - address.sin_family = AF_INET; - std::memcpy(&address.sin_addr.s_addr, server->h_addr, server->h_length); - address.sin_port = htons(port); - - if (sendto(socket_, data.data(), data.size(), 0, - reinterpret_cast(&address), - sizeof(address)) < 0) { - errorMessage_ = "Send failed"; - return false; - } - - return true; - } - - std::vector receive( - size_t size, std::string& remoteHost, int& remotePort, - std::chrono::milliseconds timeout = std::chrono::milliseconds::zero()) { - if (timeout > std::chrono::milliseconds::zero()) { -#ifdef _WIN32 - DWORD timeout_ms = static_cast(timeout.count()); - setsockopt(socket_, SOL_SOCKET, SO_RCVTIMEO, - reinterpret_cast(&timeout_ms), - sizeof(timeout_ms)); -#else - struct epoll_event event; - event.events = EPOLLIN; - event.data.fd = socket_; - if (epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, socket_, &event) == -1) { - errorMessage_ = "Epoll control failed"; - return {}; - } - - struct epoll_event events[1]; - int nfds = epoll_wait(epoll_fd_, events, 1, timeout.count()); - if (nfds == 0) { - errorMessage_ = "Receive timeout"; - return {}; - } else if (nfds == -1) { - errorMessage_ = "Epoll wait failed"; - return {}; - } -#endif - } - - std::vector data(size); - struct sockaddr_in clientAddress {}; - socklen_t clientAddressLength = sizeof(clientAddress); - - ssize_t bytesRead = - recvfrom(socket_, data.data(), size, 0, - reinterpret_cast(&clientAddress), - &clientAddressLength); - if (bytesRead < 0) { - errorMessage_ = "Receive failed"; - return {}; - } - - data.resize(bytesRead); - remoteHost = inet_ntoa(clientAddress.sin_addr); - remotePort = ntohs(clientAddress.sin_port); - - return data; - } - - void setOnDataReceivedCallback(const OnDataReceivedCallback& callback) { - onDataReceivedCallback_ = callback; - } - - void setOnErrorCallback(const OnErrorCallback& callback) { - onErrorCallback_ = callback; - } - - void startReceiving(size_t bufferSize) { - stopReceiving(); - receivingThread_ = std::jthread(&Impl::receivingLoop, this, bufferSize); - } - - void stopReceiving() { - if (receivingThread_.joinable()) { - receivingStopped_ = true; - receivingThread_.join(); - receivingStopped_ = false; - } - } - -private: - void receivingLoop(size_t bufferSize) { - while (!receivingStopped_) { - std::string remoteHost; - int remotePort; - std::vector data = - receive(bufferSize, remoteHost, remotePort); - if (!data.empty() && onDataReceivedCallback_) { - onDataReceivedCallback_(data, remoteHost, remotePort); - } - } - } - -#ifdef _WIN32 - SOCKET socket_; -#else - int socket_; - int epoll_fd_; -#endif - std::string errorMessage_; - - OnDataReceivedCallback onDataReceivedCallback_; - OnErrorCallback onErrorCallback_; - - std::jthread receivingThread_; - std::atomic receivingStopped_ = false; -}; - -UdpClient::UdpClient() : impl_(std::make_unique()) {} - -UdpClient::~UdpClient() = default; - -bool UdpClient::bind(int port) { return impl_->bind(port); } - -bool UdpClient::send(const std::string& host, int port, - const std::vector& data) { - return impl_->send(host, port, data); -} - -std::vector UdpClient::receive(size_t size, std::string& remoteHost, - int& remotePort, - std::chrono::milliseconds timeout) { - return impl_->receive(size, remoteHost, remotePort, timeout); -} - -void UdpClient::setOnDataReceivedCallback( - const OnDataReceivedCallback& callback) { - impl_->setOnDataReceivedCallback(callback); -} - -void UdpClient::setOnErrorCallback(const OnErrorCallback& callback) { - impl_->setOnErrorCallback(callback); -} - -void UdpClient::startReceiving(size_t bufferSize) { - impl_->startReceiving(bufferSize); -} - -void UdpClient::stopReceiving() { impl_->stopReceiving(); } -} // namespace atom::connection diff --git a/src/atom/connection/udpclient.hpp b/src/atom/connection/udpclient.hpp deleted file mode 100644 index 42fa90f3..00000000 --- a/src/atom/connection/udpclient.hpp +++ /dev/null @@ -1,114 +0,0 @@ -/* - * udpclient.hpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2024-5-24 - -Description: UDP Client Class - -*************************************************/ - -#ifndef ATOM_CONNECTION_UDPCLIENT_HPP -#define ATOM_CONNECTION_UDPCLIENT_HPP - -#include -#include -#include -#include -#include - -namespace atom::connection { -/** - * @class UdpClient - * @brief Represents a UDP client for sending and receiving datagrams. - */ -class UdpClient { -public: - using OnDataReceivedCallback = std::function&, const std::string&, - int)>; /**< Type definition for data received callback function. */ - using OnErrorCallback = - std::function; /**< Type definition for error - callback function. */ - - /** - * @brief Constructor. - */ - UdpClient(); - - /** - * @brief Destructor. - */ - ~UdpClient(); - - /** - * @brief Deleted copy constructor to prevent copying. - */ - UdpClient(const UdpClient&) = delete; - - /** - * @brief Deleted copy assignment operator to prevent copying. - */ - UdpClient& operator=(const UdpClient&) = delete; - - /** - * @brief Binds the client to a specific port for receiving data. - * @param port The port number to bind to. - * @return True if the binding is successful, false otherwise. - */ - bool bind(int port); - - /** - * @brief Sends data to a specified host and port. - * @param host The destination host address. - * @param port The destination port number. - * @param data The data to be sent. - * @return True if the data is sent successfully, false otherwise. - */ - bool send(const std::string& host, int port, const std::vector& data); - - /** - * @brief Receives data from a remote host. - * @param size The number of bytes to receive. - * @param remoteHost The hostname or IP address of the remote host. - * @param remotePort The port number of the remote host. - * @param timeout The receive timeout duration. - * @return The received data. - */ - std::vector receive( - size_t size, std::string& remoteHost, int& remotePort, - std::chrono::milliseconds timeout = std::chrono::milliseconds::zero()); - - /** - * @brief Sets the callback function to be called when data is received. - * @param callback The callback function. - */ - void setOnDataReceivedCallback(const OnDataReceivedCallback& callback); - - /** - * @brief Sets the callback function to be called when an error occurs. - * @param callback The callback function. - */ - void setOnErrorCallback(const OnErrorCallback& callback); - - /** - * @brief Starts receiving data asynchronously. - * @param bufferSize The size of the receive buffer. - */ - void startReceiving(size_t bufferSize); - - /** - * @brief Stops receiving data. - */ - void stopReceiving(); - -private: - class Impl; /**< Forward declaration of the implementation class. */ - std::unique_ptr impl_; /**< Pointer to the implementation object. */ -}; -} // namespace atom::connection -#endif // ATOM_CONNECTION_UDPCLIENT_HPP diff --git a/src/atom/connection/udpserver.cpp b/src/atom/connection/udpserver.cpp deleted file mode 100644 index f6ca28df..00000000 --- a/src/atom/connection/udpserver.cpp +++ /dev/null @@ -1,212 +0,0 @@ -/* - * udp_server.cpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2024-1-4 - -Description: A simple UDP server. - -*************************************************/ - -#include "udpserver.hpp" - -#include -#include -#include - -#ifdef _WIN32 -#include -#include -#pragma comment(lib, "Ws2_32.lib") -#else -#include -#include -#include -#include -#include -#endif - -#include "atom/log/loguru.hpp" - -namespace atom::connection { -class UdpSocketHub::Impl { -public: - Impl() : running_(false), socket_(-1) {} // Use -1 for Linux - - ~Impl() { stop(); } - - void start(int port) { - if (running_.load()) { - return; - } - - if (!initNetworking()) { - LOG_F(ERROR, "Networking initialization failed."); - return; - } - - socket_ = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP); - if (socket_ == -1) { // Use -1 for Linux - LOG_F(ERROR, "Failed to create socket."); - cleanupNetworking(); - return; - } - - sockaddr_in serverAddr{}; - serverAddr.sin_family = AF_INET; - serverAddr.sin_port = htons(port); - serverAddr.sin_addr.s_addr = INADDR_ANY; - - if (bind(socket_, reinterpret_cast(&serverAddr), - sizeof(serverAddr)) < 0) { // Use < 0 for Linux - LOG_F(ERROR, "Bind failed with error."); - closeSocket(); - cleanupNetworking(); - return; - } - - running_.store(true); - receiverThread_ = std::jthread([this] { receiveMessages(); }); - } - - void stop() { - if (!running_.load()) { - return; - } - - running_.store(false); - closeSocket(); - cleanupNetworking(); - - if (receiverThread_.joinable()) { - receiverThread_.join(); - } - } - - bool isRunning() const { return running_.load(); } - - void addMessageHandler(MessageHandler handler) { - std::scoped_lock lock(handlersMutex_); - handlers_.push_back(std::move(handler)); - } - - void removeMessageHandler(MessageHandler handler) { - std::scoped_lock lock(handlersMutex_); - auto it = std::find_if( - handlers_.begin(), handlers_.end(), - [&handler](const MessageHandler& h) { - return handler.target_type() == h.target_type() && - handler.target() == - h.target(); - }); - if (it != handlers_.end()) { - handlers_.erase(it); - } - } - - void sendTo(const std::string& message, const std::string& ip, int port) { - if (!running_.load()) { - LOG_F(ERROR, "Server is not running."); - return; - } - - sockaddr_in targetAddr{}; - targetAddr.sin_family = AF_INET; - targetAddr.sin_port = htons(port); - inet_pton(AF_INET, ip.c_str(), &targetAddr.sin_addr); - - if (sendto(socket_, message.data(), message.size(), 0, - reinterpret_cast(&targetAddr), - sizeof(targetAddr)) < 0) { // Use < 0 for Linux - LOG_F(ERROR, "Failed to send message."); - } - } - -private: - bool initNetworking() { -#ifdef _WIN32 - WSADATA wsaData; - return WSAStartup(MAKEWORD(2, 2), &wsaData) == 0; -#else - return true; // On Linux, no initialization needed -#endif - } - - void cleanupNetworking() { -#ifdef _WIN32 - WSACleanup(); -#endif - } - - void closeSocket() { -#ifdef _WIN32 - closesocket(socket_); -#else - if (socket_ != -1) { - close(socket_); - } -#endif - socket_ = -1; // Use -1 for Linux - } - - void receiveMessages() { - char buffer[1024]; - sockaddr_in clientAddr{}; - socklen_t clientAddrSize = sizeof(clientAddr); - - while (running_.load()) { - const auto bytesReceived = recvfrom( - socket_, buffer, sizeof(buffer), 0, - reinterpret_cast(&clientAddr), &clientAddrSize); - if (bytesReceived < 0) { // Use < 0 for Linux - LOG_F(ERROR, "recvfrom failed with error."); - continue; - } - - std::string message(buffer, bytesReceived); - std::string clientIp = inet_ntoa(clientAddr.sin_addr); - int clientPort = ntohs(clientAddr.sin_port); - - std::scoped_lock lock(handlersMutex_); - for (const auto& handler : handlers_) { - handler(message, clientIp, clientPort); - } - } - } - - std::atomic running_; - int socket_; // Use int for Linux - std::jthread receiverThread_; - std::vector handlers_; - std::mutex handlersMutex_; -}; - -UdpSocketHub::UdpSocketHub() : impl_(std::make_unique()) {} - -UdpSocketHub::~UdpSocketHub() = default; - -void UdpSocketHub::start(int port) { impl_->start(port); } - -void UdpSocketHub::stop() { impl_->stop(); } - -bool UdpSocketHub::isRunning() const { return impl_->isRunning(); } - -void UdpSocketHub::addMessageHandler(MessageHandler handler) { - impl_->addMessageHandler(std::move(handler)); -} - -void UdpSocketHub::removeMessageHandler(MessageHandler handler) { - impl_->removeMessageHandler(std::move(handler)); -} - -void UdpSocketHub::sendTo(const std::string& message, const std::string& ip, - int port) { - impl_->sendTo(message, ip, port); -} -} // namespace atom::connection diff --git a/src/atom/connection/udpserver.hpp b/src/atom/connection/udpserver.hpp deleted file mode 100644 index e3114a99..00000000 --- a/src/atom/connection/udpserver.hpp +++ /dev/null @@ -1,97 +0,0 @@ -/* - * udp_server.hpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2024-1-4 - -Description: A simple UDP server. - -*************************************************/ - -#ifndef ATOM_CONNECTION_UDP_HPP -#define ATOM_CONNECTION_UDP_HPP - -#include -#include -#include - -namespace atom::connection { -/** - * @class UdpSocketHub - * @brief Represents a hub for managing UDP sockets and message handling. - */ -class UdpSocketHub { -public: - /** - * @brief Type definition for message handler function. - * @param message The message received. - * @param ip The IP address of the sender. - * @param port The port of the sender. - */ - using MessageHandler = - std::function; - - /** - * @brief Constructor. - */ - UdpSocketHub(); - - /** - * @brief Destructor. - */ - ~UdpSocketHub(); - - UdpSocketHub(const UdpSocketHub&) = - delete; /**< Deleted copy constructor to prevent copying. */ - UdpSocketHub& operator=(const UdpSocketHub&) = - delete; /**< Deleted copy assignment operator to prevent copying. */ - - /** - * @brief Starts the UDP socket hub and binds it to the specified port. - * @param port The port on which the UDP socket hub will listen for incoming - * messages. - */ - void start(int port); - - /** - * @brief Stops the UDP socket hub. - */ - void stop(); - - /** - * @brief Checks if the UDP socket hub is currently running. - * @return True if the UDP socket hub is running, false otherwise. - */ - bool isRunning() const; - - /** - * @brief Adds a message handler function to the UDP socket hub. - * @param handler The message handler function to add. - */ - void addMessageHandler(MessageHandler handler); - - /** - * @brief Removes a message handler function from the UDP socket hub. - * @param handler The message handler function to remove. - */ - void removeMessageHandler(MessageHandler handler); - - /** - * @brief Sends a message to the specified IP address and port. - * @param message The message to send. - * @param ip The IP address of the recipient. - * @param port The port of the recipient. - */ - void sendTo(const std::string& message, const std::string& ip, int port); - -private: - class Impl; /**< Forward declaration of the implementation class. */ - std::unique_ptr impl_; /**< Pointer to the implementation object. */ -}; -} // namespace atom::connection - -#endif diff --git a/src/atom/connection/xmake.lua b/src/atom/connection/xmake.lua deleted file mode 100644 index 852da1b2..00000000 --- a/src/atom/connection/xmake.lua +++ /dev/null @@ -1,69 +0,0 @@ --- 设置项目信息 -set_project("atom-connection") -set_version("1.0.0") -set_description("Connection Between Lithium Drivers, TCP and IPC") -set_license("GPL3") - --- 添加构建模式 -add_rules("mode.debug", "mode.release") - --- 设置构建选项 -option("enable_ssh") - set_default(false) - set_showmenu(true) - set_description("Enable SSH support") -option_end() - -option("enable_libssh") - set_default(false) - set_showmenu(true) - set_description("Enable LibSSH support") -option_end() - -option("enable_python") - set_default(false) - set_showmenu(true) - set_description("Enable Python bindings") -option_end() - --- 设置构建目标 -target("atom-connection") - set_kind("static") - add_files("*.cpp") - add_headerfiles("*.hpp") - add_packages("loguru") - if is_plat("windows") then - add_syslinks("ws2_32") - end - if has_config("enable_ssh") then - add_packages("libssh") - end - if has_config("enable_libssh") then - add_files("sshclient.cpp") - add_headerfiles("sshclient.hpp") - end - if has_config("enable_python") then - add_rules("python.pybind11_module") - add_files("_pybind.cpp") - add_deps("python") - end - --- 安装目标文件 -target("install") - set_kind("phony") - add_deps("atom-connection") - on_install(function (target) - import("package.tools.install") - local installx = package.tools.install - installx.static("atom-connection", {destdir = "/usr/local/lib"}) - end) - --- 构建项目 -target("build") - set_kind("phony") - add_deps("atom-connection") - --- 清理构建产物 -target("clean") - set_kind("phony") - add_rules("utils.clean.clean") diff --git a/src/atom/error/CMakeLists.txt b/src/atom/error/CMakeLists.txt deleted file mode 100644 index a8ccf9f2..00000000 --- a/src/atom/error/CMakeLists.txt +++ /dev/null @@ -1,55 +0,0 @@ -# CMakeLists.txt for Atom-Error -# This project is licensed under the terms of the GPL3 license. -# -# Project Name: Atom-Error -# Description: Atom Error Library -# Author: Max Qian -# License: GPL3 - -cmake_minimum_required(VERSION 3.20) -project(atom-error C CXX) - -list(APPEND ${PROJECT_NAME}_SOURCES - exception.cpp - stacktrace.cpp -) - -# Headers -list(APPEND ${PROJECT_NAME}_HEADERS - error_code.hpp - stacktrace.hpp -) - -list(APPEND ${PROJECT_NAME}_LIBS - loguru -) - -if (LINUX) -list (APPEND ${PROJECT_NAME}_LIBS - dl -) -endif() - -# Build Object Library -add_library(${PROJECT_NAME}_OBJECT OBJECT) -set_property(TARGET ${PROJECT_NAME}_OBJECT PROPERTY POSITION_INDEPENDENT_CODE 1) - -target_link_libraries(${PROJECT_NAME}_OBJECT ${${PROJECT_NAME}_LIBS}) - -target_sources(${PROJECT_NAME}_OBJECT - PUBLIC - ${${PROJECT_NAME}_HEADERS} - PRIVATE - ${${PROJECT_NAME}_SOURCES} -) - -target_link_libraries(${PROJECT_NAME}_OBJECT ${${PROJECT_NAME}_LIBS}) - -add_library(${PROJECT_NAME} SHARED) - -target_link_libraries(${PROJECT_NAME} ${PROJECT_NAME}_OBJECT ${${PROJECT_NAME}_LIBS}) -target_include_directories(${PROJECT_NAME} PUBLIC .) - -install(TARGETS ${PROJECT_NAME} - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} -) diff --git a/src/atom/error/error_code.hpp b/src/atom/error/error_code.hpp deleted file mode 100644 index 8a9e012f..00000000 --- a/src/atom/error/error_code.hpp +++ /dev/null @@ -1,195 +0,0 @@ -/* - * error_code.hpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2023-8-10 - -Description: All of the error code - -**************************************************/ - -#ifndef ATOM_ERROR_CODE_HPP -#define ATOM_ERROR_CODE_HPP - -// 基础错误码(可选) -enum class ErrorCodeBase { - Success = 0, // 成功 - Failed = 1, // 失败 - Cancelled = 2, // 操作被取消 -}; - -// 文件操作错误 -enum class FileError : int { - None = static_cast(ErrorCodeBase::Success), - NotFound = 100, // 文件未找到 - OpenError = 101, // 无法打开 - AccessDenied = 102, // 访问被拒绝 - ReadError = 103, // 读取错误 - WriteError = 104, // 写入错误 - PermissionDenied = 105, // 权限被拒绝 - ParseError = 106, // 解析错误 - InvalidPath = 107, // 无效路径 - FileExists = 108, // 文件已存在 - DirectoryNotEmpty = 109, // 目录非空 - TooManyOpenFiles = 110, // 打开的文件过多 - DiskFull = 111, // 磁盘已满 - LoadError = 112, // 动态库加载错误 - UnLoadError = 113, // 动态卸载错误 - LockError = 114, // 文件锁错误 - FormatError = 115, // 文件格式错误 - PathTooLong = 116, // 路径过长 - FileCorrupted = 117, // 文件损坏 - UnsupportedFormat = 118, // 不支持的文件格式 -}; - -// 设备错误 -enum class DeviceError : int { - None = static_cast(ErrorCodeBase::Success), - NotSpecific = 200, - NotFound = 201, // 设备未找到 - NotSupported = 202, // 不支持的设备 - NotConnected = 203, // 设备未连接 - MissingValue = 204, // 缺少必要的值 - InvalidValue = 205, // 无效的值 - Busy = 206, // 设备忙 - - // 相机特有错误 - ExposureError = 210, - GainError = 211, - OffsetError = 212, - ISOError = 213, - CoolingError = 214, - - // 望远镜特有错误 - GotoError = 220, - ParkError = 221, - UnParkError = 222, - ParkedError = 223, - HomeError = 224, - - InitializationError = 230, // 初始化错误 - ResourceExhausted = 231, // 资源耗尽 - FirmwareUpdateFailed = 232, // 固件更新失败 - CalibrationError = 233, // 校准错误 - Overheating = 234, // 设备过热 - PowerFailure = 235, // 电源故障 -}; - -// 网络错误 -enum class NetworkError : int { - None = static_cast(ErrorCodeBase::Success), - ConnectionLost = 400, // 网络连接丢失 - ConnectionRefused = 401, // 连接被拒绝 - DNSLookupFailed = 402, // DNS查询失败 - ProtocolError = 403, // 协议错误 - SSLHandshakeFailed = 404, // SSL握手失败 - AddressInUse = 405, // 地址已在使用 - AddressNotAvailable = 406, // 地址不可用 - NetworkDown = 407, // 网络已关闭 - HostUnreachable = 408, // 主机不可达 - MessageTooLarge = 409, // 消息过大 - BufferOverflow = 410, // 缓冲区溢出 - TimeoutError = 411, // 网络超时 - BandwidthExceeded = 412, // 带宽超限 - NetworkCongested = 413, // 网络拥塞 -}; - -// 数据库错误 -enum class DatabaseError : int { - None = static_cast(ErrorCodeBase::Success), - ConnectionFailed = 500, // 数据库连接失败 - QueryFailed = 501, // 查询失败 - TransactionFailed = 502, // 事务失败 - IntegrityConstraintViolation = 503, // 违反完整性约束 - NoSuchTable = 504, // 表不存在 - DuplicateEntry = 505, // 重复条目 - DataTooLong = 506, // 数据过长 - DataTruncated = 507, // 数据被截断 - Deadlock = 508, // 死锁 - LockTimeout = 509, // 锁超时 - IndexOutOfBounds = 510, // 索引越界 - ConnectionTimeout = 511, // 连接超时 - InvalidQuery = 512, // 无效查询 -}; - -// 内存管理错误 -enum class MemoryError : int { - None = static_cast(ErrorCodeBase::Success), - AllocationFailed = 600, // 内存分配失败 - OutOfMemory = 601, // 内存不足 - AccessViolation = 602, // 内存访问违例 - BufferOverflow = 603, // 缓冲区溢出 - DoubleFree = 604, // 双重释放 - InvalidPointer = 605, // 无效指针 - MemoryLeak = 606, // 内存泄漏 - StackOverflow = 607, // 栈溢出 - CorruptedHeap = 608, // 堆损坏 -}; - -// 用户输入错误 -enum class UserInputError : int { - None = static_cast(ErrorCodeBase::Success), - InvalidInput = 700, // 无效输入 - OutOfRange = 701, // 输入值超出范围 - MissingInput = 702, // 缺少输入 - FormatError = 703, // 输入格式错误 - UnsupportedType = 704, // 不支持的输入类型 - InputTooLong = 705, // 输入过长 - InputTooShort = 706, // 输入过短 - InvalidCharacter = 707, // 无效字符 -}; - -// 配置错误 -enum class ConfigError : int { - None = static_cast(ErrorCodeBase::Success), - MissingConfig = 800, // 缺少配置文件 - InvalidConfig = 801, // 无效的配置 - ConfigParseError = 802, // 配置解析错误 - UnsupportedConfig = 803, // 不支持的配置 - ConfigConflict = 804, // 配置冲突 - InvalidOption = 805, // 无效选项 - ConfigNotSaved = 806, // 配置未保存 - ConfigLocked = 807, // 配置被锁定 -}; - -// 进程和线程错误 -enum class ProcessError : int { - None = static_cast(ErrorCodeBase::Success), - ProcessNotFound = 900, // 进程未找到 - ProcessFailed = 901, // 进程失败 - ThreadCreationFailed = 902, // 线程创建失败 - ThreadJoinFailed = 903, // 线程合并失败 - ThreadTimeout = 904, // 线程超时 - DeadlockDetected = 905, // 检测到死锁 - ProcessTerminated = 906, // 进程被终止 - InvalidProcessState = 907, // 无效的进程状态 - InsufficientResources = 908, // 资源不足 - InvalidThreadPriority = 909, // 无效的线程优先级 -}; - -// 服务器错误 -enum class ServerError : int { - None = static_cast(ErrorCodeBase::Success), - InvalidParameters = 300, // 无效参数 - InvalidFormat = 301, // 无效格式 - MissingParameters = 302, // 缺少参数 - RunFailed = 303, // 运行失败 - UnknownError = 310, // 未知错误 - UnknownCommand = 311, // 未知命令 - UnknownDevice = 312, // 未知设备 - UnknownDeviceType = 313, // 未知设备类型 - UnknownDeviceName = 314, // 未知设备名称 - UnknownDeviceID = 315, // 未知设备ID - NetworkError = 320, // 网络错误 - TimeoutError = 321, // 请求超时 - AuthenticationError = 322, // 认证失败 - PermissionDenied = 323, // 权限被拒绝 - ServerOverload = 324, // 服务器过载 - MaintenanceMode = 325, // 维护模式 -}; - -#endif diff --git a/src/atom/error/exception.cpp b/src/atom/error/exception.cpp deleted file mode 100644 index 9bfc0ab8..00000000 --- a/src/atom/error/exception.cpp +++ /dev/null @@ -1,55 +0,0 @@ -/* - * exception.cpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2023-11-10 - -Description: Better Exception Library - -**************************************************/ - -#include "exception.hpp" - -#include -#include -#include - -#if ENABLE_CPPTRACE -#include -#endif -#ifdef ENABLE_BOOST_STACKTRACE -#include -#endif - -namespace atom::error { -auto Exception::what() const noexcept -> const char* { - if (full_message_.empty()) { - std::ostringstream oss; - oss << "Exception at " << file_ << ":" << line_ << " in " << func_ - << "()"; - oss << " (thread " << thread_id_ << ")"; - oss << "\n\tMessage: " << message_; -#if ENABLE_CPPTRACE - oss << "\n\tStack trace:\n" - << cpptrace::generate() -#elif defined(ENABLE_BOOST_STACKTRACE) - full_message_ += std::format( - "\n\tStack trace:\n{}", boost::stacktrace::to_string(stack_trace_)); -#else - oss << "\n\tStack trace:\n" << stack_trace_.toString(); -#endif - full_message_ = oss.str(); - } - return full_message_.c_str(); -} - -auto Exception::getFile() const -> std::string { return file_; } -auto Exception::getLine() const -> int { return line_; } -auto Exception::getFunction() const -> std::string { return func_; } -auto Exception::getMessage() const -> std::string { return message_; } -auto Exception::getThreadId() const -> std::thread::id { return thread_id_; } -} // namespace atom::error diff --git a/src/atom/error/exception.hpp b/src/atom/error/exception.hpp deleted file mode 100644 index 65e07bbc..00000000 --- a/src/atom/error/exception.hpp +++ /dev/null @@ -1,564 +0,0 @@ -/* - * exception.hpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2023-11-10 - -Description: Better Exception Library - -**************************************************/ - -#ifndef ATOM_ERROR_EXCEPTION_HPP -#define ATOM_ERROR_EXCEPTION_HPP - -#include -#include -#include -#include - -#include "atom/macro.hpp" -#include "stacktrace.hpp" - -namespace atom::error { - -/** - * @brief Custom exception class with detailed information about the error. - */ -class Exception : public std::exception { -public: - /** - * @brief Constructs an Exception object. - * @param file The file where the exception occurred. - * @param line The line number in the file where the exception occurred. - * @param func The function where the exception occurred. - * @param args Additional arguments to provide context for the exception. - */ - template - Exception(const char *file, int line, const char *func, Args &&...args) - : file_(file), line_(line), func_(func) { - std::ostringstream oss; - ((oss << std::forward(args)), ...); - message_ = oss.str(); - } - - template - static void rethrowNested(Args &&...args) { - try { - throw; // 捕获当前异常 - } catch (...) { - std::throw_with_nested(Exception(std::forward(args)...)); - } - } - - /** - * @brief Returns a C-style string describing the exception. - * @return A pointer to a string describing the exception. - */ - auto what() const ATOM_NOEXCEPT -> const char * override; - - /** - * @brief Gets the file where the exception occurred. - * @return The file where the exception occurred. - */ - auto getFile() const -> std::string; - - /** - * @brief Gets the line number where the exception occurred. - * @return The line number where the exception occurred. - */ - auto getLine() const -> int; - - /** - * @brief Gets the function where the exception occurred. - * @return The function where the exception occurred. - */ - auto getFunction() const -> std::string; - - /** - * @brief Gets the message associated with the exception. - * @return The message associated with the exception. - */ - auto getMessage() const -> std::string; - - /** - * @brief Gets the ID of the thread where the exception occurred. - * @return The ID of the thread where the exception occurred. - */ - auto getThreadId() const -> std::thread::id; - -private: - std::string file_; /**< The file where the exception occurred. */ - int line_; /**< The line number in the file where the exception occurred. */ - std::string func_; /**< The function where the exception occurred. */ - std::string message_; /**< The message associated with the exception. */ - mutable std::string - full_message_; /**< The full message including additional context. */ - std::thread::id thread_id_ = - std::this_thread::get_id(); /**< The ID of the thread where the - exception occurred. */ - StackTrace stack_trace_; -}; - -// System error exception class -class SystemErrorException : public Exception { -public: - SystemErrorException(const char *file, int line, const char *func, - int err_code, std::string msg) - : Exception(file, line, func, msg), - error_code_(err_code), - error_message_( - std::error_code(err_code, std::generic_category()).message()) {} - - const char *what() const noexcept override { - if (what_message_.empty()) { - what_message_ = "System error [" + std::to_string(error_code_) + - "]: " + error_message_ + "\n" + Exception::what(); - } - return what_message_.c_str(); - } - -private: - int error_code_; - std::string error_message_; - mutable std::string what_message_; -}; - -// Nested exception handling -class NestedException : public Exception { -public: - explicit NestedException(const char *file, int line, const char *func, - std::exception_ptr ptr) - : Exception(file, line, func), exception_ptr_(std::move(ptr)) {} - - const char *what() const noexcept override { - if (what_message_.empty()) { - try { - std::rethrow_exception(exception_ptr_); - } catch (const std::exception &e) { - what_message_ = "Nested exception: " + std::string(e.what()); - } catch (...) { - what_message_ = "Nested unknown exception"; - } - } - return what_message_.c_str(); - } - -private: - std::exception_ptr exception_ptr_; - mutable std::string what_message_; -}; - -#define THROW_EXCEPTION(...) \ - throw atom::error::Exception(ATOM_FILE_NAME, ATOM_FILE_LINE, \ - ATOM_FUNC_NAME, __VA_ARGS__) - -#define THROW_NESTED_EXCEPTION(...) \ - atom::error::Exception::rethrowNested(ATOM_FILE_NAME, ATOM_FILE_LINE, \ - ATOM_FUNC_NAME, __VA_ARGS__) - -#define THROW_SYSTEM_ERROR(error_code, ...) \ - static_assert(std::is_integral::value, \ - "Error code must be an integral type"); \ - static_assert(error_code != 0, "Error code must be non-zero"); \ - throw atom::error::SystemErrorException(ATOM_FILE_NAME, ATOM_FILE_LINE, \ - ATOM_FUNC_NAME, error_code, \ - __VA_ARGS__) - -// ------------------------------------------------------------------- -// Common -// ------------------------------------------------------------------- - -class RuntimeError : public Exception { -public: - using Exception::Exception; -}; - -namespace internal { -template -struct are_all_printable; - -// Base case: Empty parameter pack is printable -template <> -struct are_all_printable<> { - static constexpr bool value = true; -}; - -// Recursive case: Check if the first argument is printable and recursively -// check the rest -template -struct are_all_printable { - // Check if std::ostream can output the type - static constexpr bool value = - std::is_convertible() - << std::declval()), - std::ostream &>::value && - are_all_printable::value; -}; -} // namespace internal - -#define THROW_RUNTIME_ERROR(...) \ - throw atom::error::RuntimeError(ATOM_FILE_NAME, ATOM_FILE_LINE, \ - ATOM_FUNC_NAME, __VA_ARGS__) - -#define THROW_NESTED_RUNTIME_ERROR(...) \ - atom::error::RuntimeError::rethrowNested(ATOM_FILE_NAME, ATOM_FILE_LINE, \ - ATOM_FUNC_NAME, __VA_ARGS__) - -class LogicError : public Exception { -public: - using Exception::Exception; -}; - -#define THROW_LOGIC_ERROR(...) \ - throw atom::error::LogicError(ATOM_FILE_NAME, ATOM_FILE_LINE, \ - ATOM_FUNC_NAME, __VA_ARGS__) - -class UnlawfulOperation : public Exception { -public: - using Exception::Exception; -}; - -#define THROW_UNLAWFUL_OPERATION(...) \ - throw atom::error::UnlawfulOperation(ATOM_FILE_NAME, ATOM_FILE_LINE, \ - ATOM_FUNC_NAME, __VA_ARGS__) - -class OutOfRange : public Exception { -public: - using Exception::Exception; -}; - -#define THROW_OUT_OF_RANGE(...) \ - throw atom::error::OutOfRange(ATOM_FILE_NAME, ATOM_FILE_LINE, \ - ATOM_FUNC_NAME, __VA_ARGS__); - -class OverflowException : public Exception { -public: - using Exception::Exception; -}; - -#define THROW_OVERFLOW(...) \ - throw atom::error::OverflowException(ATOM_FILE_NAME, ATOM_FILE_LINE, \ - ATOM_FUNC_NAME, __VA_ARGS__); - -class UnderflowException : public Exception { -public: - using Exception::Exception; -}; - -#define THROW_UNDERFLOW(...) \ - throw atom::error::UnderflowException(ATOM_FILE_NAME, ATOM_FILE_LINE, \ - ATOM_FUNC_NAME, __VA_ARGS__); - -class Unkown : public Exception { -public: - using Exception::Exception; -}; - -#define THROW_UNKOWN(...) \ - throw atom::error::Unkown(ATOM_FILE_NAME, ATOM_FILE_LINE, ATOM_FUNC_NAME, \ - __VA_ARGS__); - -// ------------------------------------------------------------------- -// Object -// ------------------------------------------------------------------- - -class ObjectAlreadyExist : public Exception { -public: - using Exception::Exception; -}; - -#define THROW_OBJ_ALREADY_EXIST(...) \ - throw atom::error::ObjectAlreadyExist(ATOM_FILE_NAME, ATOM_FILE_LINE, \ - ATOM_FUNC_NAME, __VA_ARGS__) - -class ObjectAlreadyInitialized : public Exception { -public: - using Exception::Exception; -}; - -#define THROW_OBJ_ALREADY_INITIALIZED(...) \ - throw atom::error::ObjectAlreadyInitialized( \ - ATOM_FILE_NAME, ATOM_FILE_LINE, ATOM_FUNC_NAME, __VA_ARGS__) - -class ObjectNotExist : public Exception { -public: - using Exception::Exception; -}; - -#define THROW_OBJ_NOT_EXIST(...) \ - throw atom::error::ObjectNotExist(ATOM_FILE_NAME, ATOM_FILE_LINE, \ - ATOM_FUNC_NAME, __VA_ARGS__) - -class ObjectUninitialized : public Exception { -public: - using Exception::Exception; -}; - -class SystemCollapse : public Exception { -public: - using Exception::Exception; -}; - -#define THROW_SYSTEM_COLLAPSE(...) \ - throw atom::error::SystemCollapse(ATOM_FILE_NAME, ATOM_FILE_LINE, \ - ATOM_FUNC_NAME, __VA_ARGS__) - -class NullPointer : public Exception { -public: - using Exception::Exception; -}; - -#define THROW_NULL_POINTER(...) \ - throw atom::error::NullPointer(ATOM_FILE_NAME, ATOM_FILE_LINE, \ - ATOM_FUNC_NAME, __VA_ARGS__) - -class NotFound : public Exception { -public: - using Exception::Exception; -}; - -#define THROW_NOT_FOUND(...) \ - throw atom::error::NotFound(ATOM_FILE_NAME, ATOM_FILE_LINE, \ - ATOM_FUNC_NAME, __VA_ARGS__) - -// ------------------------------------------------------------------- -// Argument -// ------------------------------------------------------------------- - -#define THROW_OBJ_UNINITIALIZED(...) \ - throw atom::error::ObjectUninitialized(ATOM_FILE_NAME, ATOM_FILE_LINE, \ - ATOM_FUNC_NAME, __VA_ARGS__) - -class WrongArgument : public Exception { -public: - using Exception::Exception; -}; - -#define THROW_WRONG_ARGUMENT(...) \ - throw atom::error::WrongArgument(ATOM_FILE_NAME, ATOM_FILE_LINE, \ - ATOM_FUNC_NAME, __VA_ARGS__) - -class InvalidArgument : public Exception { -public: - using Exception::Exception; -}; - -#define THROW_INVALID_ARGUMENT(...) \ - throw atom::error::InvalidArgument(ATOM_FILE_NAME, ATOM_FILE_LINE, \ - ATOM_FUNC_NAME, __VA_ARGS__) - -class MissingArgument : public Exception { -public: - using Exception::Exception; -}; - -#define THROW_MISSING_ARGUMENT(...) \ - throw atom::error::MissingArgument(ATOM_FILE_NAME, ATOM_FILE_LINE, \ - ATOM_FUNC_NAME, __VA_ARGS__) - -// ------------------------------------------------------------------- -// File -// ------------------------------------------------------------------- - -class FileNotFound : public Exception { -public: - using Exception::Exception; -}; - -#define THROW_FILE_NOT_FOUND(...) \ - throw atom::error::FileNotFound(ATOM_FILE_NAME, ATOM_FILE_LINE, \ - ATOM_FUNC_NAME, __VA_ARGS__) - -class FileNotReadable : public Exception { -public: - using Exception::Exception; -}; - -#define THROW_FILE_NOT_READABLE(...) \ - throw atom::error::FileNotReadable(ATOM_FILE_NAME, ATOM_FILE_LINE, \ - ATOM_FUNC_NAME, __VA_ARGS__) - -class FileNotWritable : public Exception { -public: - using Exception::Exception; -}; - -#define THROW_FILE_NOT_WRITABLE(...) \ - throw atom::error::FileNotWritable(ATOM_FILE_NAME, ATOM_FILE_LINE, \ - ATOM_FUNC_NAME, __VA_ARGS__) - -class FailToOpenFile : public Exception { -public: - using Exception::Exception; -}; - -#define THROW_FAIL_TO_OPEN_FILE(...) \ - throw atom::error::FailToOpenFile(ATOM_FILE_NAME, ATOM_FILE_LINE, \ - ATOM_FUNC_NAME, __VA_ARGS__) - -class FailToCloseFile : public Exception { -public: - using Exception::Exception; -}; - -#define THROW_FAIL_TO_CLOSE_FILE(...) \ - throw atom::error::FailToCloseFile(ATOM_FILE_NAME, ATOM_FILE_LINE, \ - ATOM_FUNC_NAME, __VA_ARGS__) - -class FailToCreateFile : public Exception { -public: - using Exception::Exception; -}; - -#define THROW_FAIL_TO_CREATE_FILE(...) \ - throw atom::error::FailToCreateFile(ATOM_FILE_NAME, ATOM_FILE_LINE, \ - ATOM_FUNC_NAME, __VA_ARGS__) - -class FailToDeleteFile : public Exception { -public: - using Exception::Exception; -}; - -#define THROW_FAIL_TO_DELETE_FILE(...) \ - throw atom::error::FailToDeleteFile(ATOM_FILE_NAME, ATOM_FILE_LINE, \ - ATOM_FUNC_NAME, __VA_ARGS__) - -class FailToCopyFile : public Exception { -public: - using Exception::Exception; -}; - -#define THROW_FAIL_TO_COPY_FILE(...) \ - throw atom::error::FailToCopyFile(ATOM_FILE_NAME, ATOM_FILE_LINE, \ - ATOM_FUNC_NAME, __VA_ARGS__) - -class FailToMoveFile : public Exception { -public: - using Exception::Exception; -}; - -#define THROW_FAIL_TO_MOVE_FILE(...) \ - throw atom::error::FailToMoveFile(ATOM_FILE_NAME, ATOM_FILE_LINE, \ - ATOM_FUNC_NAME, __VA_ARGS__) - -class FailToReadFile : public Exception { -public: - using Exception::Exception; -}; - -#define THROW_FAIL_TO_READ_FILE(...) \ - throw atom::error::FailToReadFile(ATOM_FILE_NAME, ATOM_FILE_LINE, \ - ATOM_FUNC_NAME, __VA_ARGS__) - -class FailToWriteFile : public Exception { -public: - using Exception::Exception; -}; - -#define THROW_FAIL_TO_WRITE_FILE(...) \ - throw atom::error::FailToWriteFile(ATOM_FILE_NAME, ATOM_FILE_LINE, \ - ATOM_FUNC_NAME, __VA_ARGS__) - -// ------------------------------------------------------------------- -// Dynamic Library -// ------------------------------------------------------------------- - -class FailToLoadDll : public Exception { -public: - using Exception::Exception; -}; - -#define THROW_FAIL_TO_LOAD_DLL(...) \ - throw atom::error::FailToLoadDll(ATOM_FILE_NAME, ATOM_FILE_LINE, \ - ATOM_FUNC_NAME, __VA_ARGS__) - -class FailToUnloadDll : public Exception { -public: - using Exception::Exception; -}; - -#define THROW_FAIL_TO_UNLOAD_DLL(...) \ - throw atom::error::FailToUnloadDll(ATOM_FILE_NAME, ATOM_FILE_LINE, \ - ATOM_FUNC_NAME, __VA_ARGS__) - -class FailToLoadSymbol : public Exception { -public: - using Exception::Exception; -}; - -#define THROW_FAIL_TO_LOAD_SYMBOL(...) \ - throw atom::error::FailToLoadSymbol(ATOM_FILE_NAME, ATOM_FILE_LINE, \ - ATOM_FUNC_NAME, __VA_ARGS__) - -// ------------------------------------------------------------------- -// Proccess Library -// ------------------------------------------------------------------- - -class FailToCreateProcess : public Exception { -public: - using Exception::Exception; -}; - -#define THROW_FAIL_TO_CREATE_PROCESS(...) \ - throw atom::error::FailToCreateProcess(ATOM_FILE_NAME, ATOM_FILE_LINE, \ - ATOM_FUNC_NAME, __VA_ARGS__) - -class FailToTerminateProcess : public Exception { -public: - using Exception::Exception; -}; - -#define THROW_FAIL_TO_TERMINATE_PROCESS(...) \ - throw atom::error::FailToTerminateProcess(ATOM_FILE_NAME, ATOM_FILE_LINE, \ - ATOM_FUNC_NAME, __VA_ARGS__) - -// ------------------------------------------------------------------- -// JSON Error -// ------------------------------------------------------------------- - -class JsonParseError : public Exception { -public: - using Exception::Exception; -}; - -#define THROW_JSON_PARSE_ERROR(...) \ - throw atom::error::JsonParseError(ATOM_FILE_NAME, ATOM_FILE_LINE, \ - ATOM_FUNC_NAME, __VA_ARGS__) - -class JsonValueError : public Exception { -public: - using Exception::Exception; -}; - -#define THROW_JSON_VALUE_ERROR(...) \ - throw atom::error::JsonValueError(ATOM_FILE_NAME, ATOM_FILE_LINE, \ - ATOM_FUNC_NAME, __VA_ARGS__) - -// ------------------------------------------------------------------- -// Network Error -// ------------------------------------------------------------------- - -class CurlInitializationError : public Exception { -public: - using Exception::Exception; -}; - -#define THROW_CURL_INITIALIZATION_ERROR(...) \ - throw atom::error::CurlInitializationError(ATOM_FILE_NAME, ATOM_FILE_LINE, \ - ATOM_FUNC_NAME, __VA_ARGS__) - -class CurlRuntimeError : public Exception { -public: - using Exception::Exception; -}; - -#define THROW_CURL_RUNTIME_ERROR(...) \ - throw atom::error::CurlRuntimeError(ATOM_FILE_NAME, ATOM_FILE_LINE, \ - ATOM_FUNC_NAME, __VA_ARGS__) -} // namespace atom::error - -#endif diff --git a/src/atom/error/stacktrace.cpp b/src/atom/error/stacktrace.cpp deleted file mode 100644 index 45f9939b..00000000 --- a/src/atom/error/stacktrace.cpp +++ /dev/null @@ -1,147 +0,0 @@ -/* - * stacktrace.cpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2023-11-10 - -Description: Enhanced StackTrace with more details - -**************************************************/ - -#include "stacktrace.hpp" -#include "atom/function/abi.hpp" - -#include -#include -#include -#include - -#ifdef _WIN32 -// clang-format off -#include -#include -// clang-format on -#if !defined(__MINGW32__) && !defined(__MINGW64__) -#pragma comment(lib, "dbghelp.lib") -#endif -#elif defined(__APPLE__) || defined(__linux__) -#include -#include -#include -#endif - -namespace atom::error { - -namespace { -#if defined(__linux__) || defined(__APPLE__) -auto processString(const std::string& input) -> std::string { - size_t startIndex = input.find("_Z"); - if (startIndex == std::string::npos) { - return input; - } - size_t endIndex = input.find('+', startIndex); - if (endIndex == std::string::npos) { - return input; - } - std::string abiName = input.substr(startIndex, endIndex - startIndex); - abiName = meta::DemangleHelper::demangle(abiName); - std::string result = input; - result.replace(startIndex, endIndex - startIndex, abiName); - return result; -} -#endif - -auto prettifyStacktrace(const std::string& input) -> std::string { - std::string output = input; - static const std::vector> REPLACEMENTS = - {{"std::__1::", "std::"}, - {"__thiscall ", ""}, - {"__cdecl ", ""}, - {", std::allocator<[^<>]+>", ""}}; - - for (const auto& [from, to] : REPLACEMENTS) { - output = std::regex_replace(output, std::regex(from), to); - } - - // Clean up spaces in template arguments - output = - std::regex_replace(output, std::regex(R"(<\s*([^<> ]+)\s*>)"), "<$1>"); - - return output; -} - -} // unnamed namespace - -StackTrace::StackTrace() { capture(); } - -auto StackTrace::toString() const -> std::string { - std::ostringstream oss; - -#ifdef _WIN32 - auto* symbol = reinterpret_cast( - calloc(sizeof(SYMBOL_INFO) + 256 * sizeof(char), 1)); - symbol->MaxNameLen = 255; - symbol->SizeOfStruct = sizeof(SYMBOL_INFO); - - for (void* frame : frames_) { - DWORD64 displacement = 0; - if (SymFromAddr(GetCurrentProcess(), reinterpret_cast(frame), - &displacement, symbol) != 0) { - std::string symbolName = symbol->Name; - oss << "\t\t" << meta::DemangleHelper::demangle("_" + symbolName) - << " - 0x" << std::hex << symbol->Address << "\n"; - } - } - free(symbol); - -#elif defined(__APPLE__) || defined(__linux__) - for (int i = 0; i < num_frames_; ++i) { - Dl_info info; - if (dladdr(frames_[i], &info) && info.dli_sname) { - std::string symbol_name = - meta::DemangleHelper::demangle(info.dli_sname); - oss << "\t\t" << symbol_name << " (" << info.dli_fname << ")\n"; - } else { - std::string_view symbol(symbols_.get()[i]); - oss << "\t\t" << processString(std::string(symbol)) << "\n"; - } - } - -#else - oss << "\t\tStack trace not available on this platform.\n"; -#endif - - return prettifyStacktrace(oss.str()); -} - -void StackTrace::capture() { -#ifdef _WIN32 - constexpr int max_frames = 64; - frames_.resize(max_frames); - SymInitialize(GetCurrentProcess(), nullptr, TRUE); - - void* framePtrs[max_frames]; - WORD capturedFrames = - CaptureStackBackTrace(0, max_frames, framePtrs, nullptr); - - frames_.resize(capturedFrames); - std::copy_n(framePtrs, capturedFrames, frames_.begin()); - -#elif defined(__APPLE__) || defined(__linux__) - constexpr int MAX_FRAMES = 64; - void* framePtrs[MAX_FRAMES]; - - num_frames_ = backtrace(framePtrs, MAX_FRAMES); - symbols_.reset(backtrace_symbols(framePtrs, num_frames_)); - frames_.assign(framePtrs, framePtrs + num_frames_); - -#else - num_frames_ = 0; -#endif -} - -} // namespace atom::error diff --git a/src/atom/error/stacktrace.hpp b/src/atom/error/stacktrace.hpp deleted file mode 100644 index 493a018a..00000000 --- a/src/atom/error/stacktrace.hpp +++ /dev/null @@ -1,70 +0,0 @@ -/* - * stacktrace.hpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2023-11-10 - -Description: Enhanced StackTrace with more details - -**************************************************/ - -#ifndef ATOM_ERROR_STACKTRACE_HPP -#define ATOM_ERROR_STACKTRACE_HPP - -#include -#include -#include - -namespace atom::error { - -/** - * @brief Class for capturing and representing a stack trace. - * - * This class captures the stack trace of the current - * execution context and represents it as a string, including - * file names, line numbers, and symbols if available. - */ -class StackTrace { -public: - /** - * @brief Default constructor. - * - * Constructs a StackTrace object and captures the current stack trace. - */ - StackTrace(); - - /** - * @brief Get the string representation of the stack trace. - * - * @return A string representing the captured stack trace. - */ - [[nodiscard]] auto toString() const -> std::string; - -private: - /** - * @brief Capture the current stack trace. - * - * This method captures the current stack trace based on the operating - * system. - */ - void capture(); - -#ifdef _WIN32 - std::vector frames_; /**< Vector to store stack frames on Windows. */ -#elif defined(__APPLE__) || defined(__linux__) - std::unique_ptr symbols_{ - nullptr, - &free}; /**< Pointer to store stack symbols on macOS or Linux. */ - std::vector - frames_; /**< Vector to store raw stack frame pointers. */ - int num_frames_ = 0; /**< Number of stack frames captured. */ -#endif -}; - -} // namespace atom::error - -#endif diff --git a/src/atom/error/xmake.lua b/src/atom/error/xmake.lua deleted file mode 100644 index 36ae0f56..00000000 --- a/src/atom/error/xmake.lua +++ /dev/null @@ -1,64 +0,0 @@ -set_project("atom-error") -set_version("1.0.0") -set_xmakever("2.5.1") - --- Set the C++ standard -set_languages("cxx20") - --- Add required packages -add_requires("loguru") - --- Define libraries -local atom_error_libs = { - "atom-utils" -} - -local project_packages = { - "loguru", - "dl" -} - --- Source files -local source_files = { - "error_stack.cpp", - "exception.cpp", - "stacktrace.cpp" -} - --- Header files -local header_files = { - "error_code.hpp", - "error_stack.hpp", - "stacktrace.hpp" -} - --- Object Library -target("atom-error_object") - set_kind("object") - add_files(table.unpack(source_files)) - add_headerfiles(table.unpack(header_files)) - add_packages("loguru") - if is_plat("linux") then - add_syslinks("dl") - end -target_end() - --- Static Library -target("atom-error") - set_kind("static") - add_deps("atom-error_object") - add_files(table.unpack(source_files)) - add_headerfiles(table.unpack(header_files)) - add_packages("loguru") - add_deps("atom-utils") - if is_plat("linux") then - add_syslinks("dl") - end - add_includedirs(".") - set_targetdir("$(buildir)/lib") - set_installdir("$(installdir)/lib") - set_version("1.0.0", {build = "%Y%m%d%H%M"}) - on_install(function (target) - os.cp(target:targetfile(), path.join(target:installdir(), "lib")) - end) -target_end() diff --git a/src/atom/extra/beast/http.cpp b/src/atom/extra/beast/http.cpp deleted file mode 100644 index f74d7c94..00000000 --- a/src/atom/extra/beast/http.cpp +++ /dev/null @@ -1,64 +0,0 @@ -#include "http.hpp" - -#include - -HttpClient::HttpClient(net::io_context& ioc) - : resolver_(net::make_strand(ioc)), stream_(net::make_strand(ioc)) {} - -void HttpClient::setDefaultHeader(const std::string& key, - const std::string& value) { - default_headers_[key] = value; -} - -void HttpClient::setTimeout(std::chrono::seconds timeout) { - timeout_ = timeout; -} - -auto HttpClient::uploadFile( - const std::string& host, const std::string& port, const std::string& target, - const std::string& filepath, - const std::string& field_name) -> http::response { - std::ifstream file(filepath, std::ios::binary); - if (!file) { - throw std::runtime_error("Failed to open file: " + filepath); - } - std::string fileContent((std::istreambuf_iterator(file)), - std::istreambuf_iterator()); - - std::string boundary = - "-------------------------" + std::to_string(std::time(nullptr)); - - std::string body = "--" + boundary + "\r\n"; - body += "Content-Disposition: form-data; name=\"" + field_name + - "\"; filename=\"" + - std::filesystem::path(filepath).filename().string() + "\"\r\n"; - body += "Content-Type: application/octet-stream\r\n\r\n"; - body += fileContent + "\r\n"; - body += "--" + boundary + "--\r\n"; - - std::string contentType = "multipart/form-data; boundary=" + boundary; - - return request(http::verb::post, host, port, target, 11, contentType, body); -} - -void HttpClient::downloadFile(const std::string& host, const std::string& port, - const std::string& target, - const std::string& filepath) { - auto res = request(http::verb::get, host, port, target); - std::ofstream outFile(filepath, std::ios::binary); - outFile << res.body(); -} - -void HttpClient::runWithThreadPool(size_t num_threads) { - net::thread_pool pool(num_threads); - - for (size_t i = 0; i < num_threads; ++i) { - net::post(pool, [this] { - // Example task: send a request in a thread from the pool - auto res = request(http::verb::get, "example.com", "80", "/"); - std::cout << "Response in thread pool: " << res << std::endl; - }); - } - - pool.join(); // Wait for all threads to finish -} diff --git a/src/atom/extra/beast/http.hpp b/src/atom/extra/beast/http.hpp deleted file mode 100644 index 6049ef50..00000000 --- a/src/atom/extra/beast/http.hpp +++ /dev/null @@ -1,484 +0,0 @@ -#ifndef HTTP_CLIENT_HPP -#define HTTP_CLIENT_HPP - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace beast = boost::beast; -namespace http = beast::http; -namespace net = boost::asio; -using tcp = boost::asio::ip::tcp; -using json = nlohmann::json; - -class HttpClient { -public: - /** - * @brief Constructs an HttpClient with the given I/O context. - * @param ioc The I/O context to use for asynchronous operations. - */ - explicit HttpClient(net::io_context& ioc); - - /** - * @brief Sets a default header for all requests. - * @param key The header key. - * @param value The header value. - */ - void setDefaultHeader(const std::string& key, const std::string& value); - - /** - * @brief Sets the timeout duration for the HTTP operations. - * @param timeout The timeout duration in seconds. - */ - void setTimeout(std::chrono::seconds timeout); - - /** - * @brief Sends a synchronous HTTP request. - * @tparam Body The type of the request body. - * @param method The HTTP method (verb). - * @param host The server host. - * @param port The server port. - * @param target The target URI. - * @param version The HTTP version (default is 11). - * @param content_type The content type of the request body. - * @param body The request body. - * @param headers Additional headers to include in the request. - * @return The HTTP response. - */ - template - auto request(http::verb method, const std::string& host, - const std::string& port, const std::string& target, - int version = 11, const std::string& content_type = "", - const std::string& body = "", - const std::unordered_map& headers = - {}) -> http::response; - - /** - * @brief Sends an asynchronous HTTP request. - * @tparam Body The type of the request body. - * @tparam ResponseHandler The type of the handler to call when the - * operation completes. - * @param method The HTTP method (verb). - * @param host The server host. - * @param port The server port. - * @param target The target URI. - * @param handler The handler to call when the operation completes. - * @param version The HTTP version (default is 11). - * @param content_type The content type of the request body. - * @param body The request body. - * @param headers Additional headers to include in the request. - */ - template - void asyncRequest( - http::verb method, const std::string& host, const std::string& port, - const std::string& target, ResponseHandler&& handler, int version = 11, - const std::string& content_type = "", const std::string& body = "", - const std::unordered_map& headers = {}); - - /** - * @brief Sends a synchronous HTTP request with a JSON body and returns a - * JSON response. - * @param method The HTTP method (verb). - * @param host The server host. - * @param port The server port. - * @param target The target URI. - * @param json_body The JSON body of the request. - * @param headers Additional headers to include in the request. - * @return The JSON response. - */ - auto jsonRequest(http::verb method, const std::string& host, - const std::string& port, const std::string& target, - const json& json_body = {}, - const std::unordered_map& - headers = {}) -> json; - - /** - * @brief Sends an asynchronous HTTP request with a JSON body and returns a - * JSON response. - * @tparam ResponseHandler The type of the handler to call when the - * operation completes. - * @param method The HTTP method (verb). - * @param host The server host. - * @param port The server port. - * @param target The target URI. - * @param handler The handler to call when the operation completes. - * @param json_body The JSON body of the request. - * @param headers Additional headers to include in the request. - */ - template - void asyncJsonRequest( - http::verb method, const std::string& host, const std::string& port, - const std::string& target, ResponseHandler&& handler, - const json& json_body = {}, - const std::unordered_map& headers = {}); - - /** - * @brief Uploads a file to the server. - * @param host The server host. - * @param port The server port. - * @param target The target URI. - * @param filepath The path to the file to upload. - * @param field_name The field name for the file (default is "file"). - * @return The HTTP response. - */ - auto uploadFile(const std::string& host, const std::string& port, - const std::string& target, const std::string& filepath, - const std::string& field_name = "file") - -> http::response; - - /** - * @brief Downloads a file from the server. - * @param host The server host. - * @param port The server port. - * @param target The target URI. - * @param filepath The path to save the downloaded file. - */ - void downloadFile(const std::string& host, const std::string& port, - const std::string& target, const std::string& filepath); - - /** - * @brief Sends a synchronous HTTP request with retry logic. - * @tparam Body The type of the request body. - * @param method The HTTP method (verb). - * @param host The server host. - * @param port The server port. - * @param target The target URI. - * @param retry_count The number of retry attempts (default is 3). - * @param version The HTTP version (default is 11). - * @param content_type The content type of the request body. - * @param body The request body. - * @param headers Additional headers to include in the request. - * @return The HTTP response. - */ - template - auto requestWithRetry( - http::verb method, const std::string& host, const std::string& port, - const std::string& target, int retry_count = 3, int version = 11, - const std::string& content_type = "", const std::string& body = "", - const std::unordered_map& headers = {}) - -> http::response; - - /** - * @brief Sends multiple synchronous HTTP requests in a batch. - * @tparam Body The type of the request body. - * @param requests A vector of tuples containing the HTTP method, host, - * port, and target for each request. - * @param headers Additional headers to include in each request. - * @return A vector of HTTP responses. - */ - template - std::vector> batchRequest( - const std::vector>& requests, - const std::unordered_map& headers = {}); - - /** - * @brief Sends multiple asynchronous HTTP requests in a batch. - * @tparam ResponseHandler The type of the handler to call when the - * operation completes. - * @param requests A vector of tuples containing the HTTP method, host, - * port, and target for each request. - * @param handler The handler to call when the operation completes. - * @param headers Additional headers to include in each request. - */ - template - void asyncBatchRequest( - const std::vector>& requests, - ResponseHandler&& handler, - const std::unordered_map& headers = {}); - - /** - * @brief Runs the I/O context with a thread pool. - * @param num_threads The number of threads in the pool. - */ - void runWithThreadPool(size_t num_threads); - - /** - * @brief Asynchronously downloads a file from the server. - * @tparam ResponseHandler The type of the handler to call when the - * operation completes. - * @param host The server host. - * @param port The server port. - * @param target The target URI. - * @param filepath The path to save the downloaded file. - * @param handler The handler to call when the operation completes. - */ - template - void asyncDownloadFile(const std::string& host, const std::string& port, - const std::string& target, - const std::string& filepath, - ResponseHandler&& handler); - -private: - tcp::resolver resolver_; ///< The resolver for DNS lookups. - beast::tcp_stream stream_; ///< The TCP stream for HTTP communication. - std::unordered_map - default_headers_; ///< Default headers for all requests. - std::chrono::seconds timeout_{ - 30}; ///< The timeout duration for HTTP operations. -}; - -template -auto HttpClient::request(http::verb method, const std::string& host, - const std::string& port, const std::string& target, - int version, const std::string& content_type, - const std::string& body, - const std::unordered_map& - headers) -> http::response { - http::request req{method, target, version}; - req.set(http::field::host, host); - req.set(http::field::user_agent, BOOST_BEAST_VERSION_STRING); - - for (const auto& [key, value] : default_headers_) { - req.set(key, value); - } - - for (const auto& [key, value] : headers) { - req.set(key, value); - } - - if (!content_type.empty()) { - req.set(http::field::content_type, content_type); - } - - if (!body.empty()) { - req.body() = body; - req.prepare_payload(); - } - - auto const results = resolver_.resolve(host, port); - stream_.connect(results); - - stream_.expires_after(timeout_); - - http::write(stream_, req); - - beast::flat_buffer buffer; - http::response res; - http::read(stream_, buffer, res); - - beast::error_code ec; - stream_.socket().shutdown(tcp::socket::shutdown_both, ec); - - return res; -} - -template -void HttpClient::asyncRequest( - http::verb method, const std::string& host, const std::string& port, - const std::string& target, ResponseHandler&& handler, int version, - const std::string& content_type, const std::string& body, - const std::unordered_map& headers) { - auto req = std::make_shared>( - method, target, version); - req->set(http::field::host, host); - req->set(http::field::user_agent, BOOST_BEAST_VERSION_STRING); - - for (const auto& [key, value] : default_headers_) { - req->set(key, value); - } - - for (const auto& [key, value] : headers) { - req->set(key, value); - } - - if (!content_type.empty()) { - req->set(http::field::content_type, content_type); - } - - if (!body.empty()) { - req->body() = body; - req->prepare_payload(); - } - - resolver_.async_resolve( - host, port, - [this, req, handler = std::forward(handler)]( - beast::error_code ec, tcp::resolver::results_type results) { - if (ec) { - return handler(ec, {}); - } - - stream_.async_connect( - results, [this, req, handler = std::move(handler)]( - beast::error_code ec, - tcp::resolver::results_type::endpoint_type) { - if (ec) { - return handler(ec, {}); - } - - stream_.expires_after(timeout_); - - http::async_write( - stream_, *req, - [this, req, handler = std::move(handler)]( - beast::error_code ec, std::size_t) { - if (ec) { - return handler(ec, {}); - } - - auto res = std::make_shared>(); - auto buffer = - std::make_shared(); - - http::async_read( - stream_, *buffer, *res, - [this, res, buffer, - handler = std::move(handler)]( - beast::error_code ec, std::size_t) { - stream_.socket().shutdown( - tcp::socket::shutdown_both, ec); - handler(ec, std::move(*res)); - }); - }); - }); - }); -} - -template -void HttpClient::asyncJsonRequest( - http::verb method, const std::string& host, const std::string& port, - const std::string& target, ResponseHandler&& handler, const json& json_body, - const std::unordered_map& headers) { - asyncRequest( - method, host, port, target, - [handler = std::forward(handler)]( - beast::error_code ec, http::response res) { - if (ec) { - handler(ec, {}); - } else { - try { - auto jv = json::parse(res.body()); - handler({}, std::move(jv)); - } catch (const json::parse_error& e) { - handler(beast::error_code{e.id, beast::generic_category()}, - {}); - } - } - }, - 11, "application/json", json_body.empty() ? "" : json_body.dump(), - headers); -} - -template -auto HttpClient::requestWithRetry( - http::verb method, const std::string& host, const std::string& port, - const std::string& target, int retry_count, int version, - const std::string& content_type, const std::string& body, - const std::unordered_map& headers) - -> http::response { - beast::error_code ec; - http::response response; - for (int attempt = 0; attempt < retry_count; ++attempt) { - try { - response = request(method, host, port, target, version, - content_type, body, headers); - // If no exception was thrown, return the response - return response; - } catch (const beast::system_error& e) { - ec = e.code(); - std::cerr << "Request attempt " << (attempt + 1) - << " failed: " << ec.message() << std::endl; - if (attempt + 1 == retry_count) { - throw; // Throw the exception if this was the last retry - } - } - } - return response; -} - -template -std::vector> HttpClient::batchRequest( - const std::vector>& requests, - const std::unordered_map& headers) { - std::vector> responses; - for (const auto& [method, host, port, target] : requests) { - try { - responses.push_back( - request(method, host, port, target, 11, "", "", headers)); - } catch (const std::exception& e) { - std::cerr << "Batch request failed for " << target << ": " - << e.what() << std::endl; - // Push an empty response if an exception occurs (or handle as - // needed) - responses.emplace_back(); - } - } - return responses; -} - -template -void HttpClient::asyncBatchRequest( - const std::vector>& requests, - ResponseHandler&& handler, - const std::unordered_map& headers) { - auto responses = - std::make_shared>>(); - auto remaining = std::make_shared>(requests.size()); - - for (const auto& [method, host, port, target] : requests) { - asyncRequest( - method, host, port, target, - [handler, responses, remaining]( - beast::error_code ec, http::response res) { - if (ec) { - std::cerr << "Error during batch request: " << ec.message() - << std::endl; - responses - ->emplace_back(); // Empty response in case of error - } else { - responses->emplace_back(std::move(res)); - } - - if (--(*remaining) == 0) { - handler(*responses); - } - }, - 11, "", "", headers); - } -} - -template -void HttpClient::asyncDownloadFile(const std::string& host, - const std::string& port, - const std::string& target, - const std::string& filepath, - ResponseHandler&& handler) { - asyncRequest( - http::verb::get, host, port, target, - [filepath, handler = std::forward(handler)]( - beast::error_code ec, http::response res) { - if (ec) { - handler(ec, false); - } else { - std::ofstream outFile(filepath, std::ios::binary); - if (!outFile) { - std::cerr << "Failed to open file for writing: " << filepath - << std::endl; - handler(beast::error_code{}, false); - return; - } - outFile << res.body(); - handler({}, true); // Download successful - } - }); -} - -#endif // HTTP_CLIENT_HPP diff --git a/src/atom/extra/beast/ws.cpp b/src/atom/extra/beast/ws.cpp deleted file mode 100644 index 550cba98..00000000 --- a/src/atom/extra/beast/ws.cpp +++ /dev/null @@ -1,79 +0,0 @@ -#include "ws.hpp" - -#if __has_include("atom/log/loguru.hpp") -#include "atom/log/loguru.hpp" -#else -#include -#endif - -WSClient::WSClient(net::io_context& ioc) - : resolver_(net::make_strand(ioc)), - ws_(net::make_strand(ioc)), - ping_timer_(ioc) {} - -void WSClient::setTimeout(std::chrono::seconds timeout) { timeout_ = timeout; } - -void WSClient::setReconnectOptions(int retries, std::chrono::seconds interval) { - max_retries_ = retries; - reconnect_interval_ = interval; -} - -void WSClient::setPingInterval(std::chrono::seconds interval) { - ping_interval_ = interval; -} - -void WSClient::connect(const std::string& host, const std::string& port) { - auto const results = resolver_.resolve(host, port); - beast::get_lowest_layer(ws_).connect(results->endpoint()); - ws_.handshake(host, "/"); - startPing(); -} - -void WSClient::send(const std::string& message) { - ws_.write(net::buffer(message)); -} - -std::string WSClient::receive() { - beast::flat_buffer buffer; - ws_.read(buffer); - return beast::buffers_to_string(buffer.data()); -} - -void WSClient::close() { ws_.close(websocket::close_code::normal); } - -void WSClient::startPing() { - if (ping_interval_.count() > 0) { - ping_timer_.expires_after(ping_interval_); - ping_timer_.async_wait([this](beast::error_code ec) { - if (!ec) { - ws_.async_ping({}, [this](beast::error_code ec) { - if (!ec) { - startPing(); - } - }); - } - }); - } -} - -template -void WSClient::handleConnectError(beast::error_code ec, - ConnectHandler&& handler) { - if (retry_count_ < max_retries_) { - ++retry_count_; - LOG_F(ERROR, "Failed to connect: {}. Retrying in {} seconds...", - ec.message(), reconnect_interval_.count()); - ws_.next_layer().close(); - ping_timer_.expires_after(reconnect_interval_); - ping_timer_.async_wait([this, handler = std::forward( - handler)](beast::error_code ec) { - if (!ec) { - asyncConnect("example.com", "80", - std::forward(handler)); - } - }); - } else { - LOG_F(ERROR, "Failed to connect: {}. Giving up.", ec.message()); - handler(ec); - } -} diff --git a/src/atom/extra/beast/ws.hpp b/src/atom/extra/beast/ws.hpp deleted file mode 100644 index 1aa9ac23..00000000 --- a/src/atom/extra/beast/ws.hpp +++ /dev/null @@ -1,246 +0,0 @@ -#ifndef WS_CLIENT_HPP -#define WS_CLIENT_HPP - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace beast = boost::beast; -namespace net = boost::asio; -namespace websocket = beast::websocket; -using tcp = boost::asio::ip::tcp; -using json = nlohmann::json; - -/** - * @class WSClient - * @brief A WebSocket client class for managing WebSocket connections and - * communication. - */ -class WSClient { -public: - /** - * @brief Constructs a WSClient with the given I/O context. - * @param ioc The I/O context to use for asynchronous operations. - */ - explicit WSClient(net::io_context& ioc); - - /** - * @brief Sets the timeout duration for the WebSocket operations. - * @param timeout The timeout duration in seconds. - */ - void setTimeout(std::chrono::seconds timeout); - - /** - * @brief Sets the reconnection options. - * @param retries The number of reconnection attempts. - * @param interval The interval between reconnection attempts in seconds. - */ - void setReconnectOptions(int retries, std::chrono::seconds interval); - - /** - * @brief Sets the interval for sending ping messages. - * @param interval The ping interval in seconds. - */ - void setPingInterval(std::chrono::seconds interval); - - /** - * @brief Connects to the WebSocket server. - * @param host The server host. - * @param port The server port. - */ - void connect(const std::string& host, const std::string& port); - - /** - * @brief Sends a message to the WebSocket server. - * @param message The message to send. - */ - void send(const std::string& message); - - /** - * @brief Receives a message from the WebSocket server. - * @return The received message. - */ - std::string receive(); - - /** - * @brief Closes the WebSocket connection. - */ - void close(); - - /** - * @brief Asynchronously connects to the WebSocket server. - * @tparam ConnectHandler The type of the handler to call when the operation - * completes. - * @param host The server host. - * @param port The server port. - * @param handler The handler to call when the operation completes. - */ - template - void asyncConnect(const std::string& host, const std::string& port, - ConnectHandler&& handler); - - /** - * @brief Asynchronously sends a message to the WebSocket server. - * @tparam WriteHandler The type of the handler to call when the operation - * completes. - * @param message The message to send. - * @param handler The handler to call when the operation completes. - */ - template - void asyncSend(const std::string& message, WriteHandler&& handler); - - /** - * @brief Asynchronously receives a message from the WebSocket server. - * @tparam ReadHandler The type of the handler to call when the operation - * completes. - * @param handler The handler to call when the operation completes. - */ - template - void asyncReceive(ReadHandler&& handler); - - /** - * @brief Asynchronously closes the WebSocket connection. - * @tparam CloseHandler The type of the handler to call when the operation - * completes. - * @param handler The handler to call when the operation completes. - */ - template - void asyncClose(CloseHandler&& handler); - - /** - * @brief Asynchronously sends a JSON object to the WebSocket server. - * @param jdata The JSON object to send. - * @param handler The handler to call when the operation completes. - */ - void asyncSendJson( - const json& jdata, - std::function handler); - - /** - * @brief Asynchronously receives a JSON object from the WebSocket server. - * @tparam JsonHandler The type of the handler to call when the operation - * completes. - * @param handler The handler to call when the operation completes. - */ - template - void asyncReceiveJson(JsonHandler&& handler); - -private: - /** - * @brief Starts the ping timer to send periodic ping messages. - */ - void startPing(); - - /** - * @brief Handles connection errors and retries if necessary. - * @tparam ConnectHandler The type of the handler to call when the operation - * completes. - * @param ec The error code. - * @param handler The handler to call when the operation completes. - */ - template - void handleConnectError(beast::error_code ec, ConnectHandler&& handler); - - tcp::resolver resolver_; ///< The resolver for DNS lookups. - websocket::stream ws_; ///< The WebSocket stream. - net::steady_timer ping_timer_; ///< The timer for sending ping messages. - std::chrono::seconds timeout_{ - 30}; ///< The timeout duration for WebSocket operations. - std::chrono::seconds ping_interval_{ - 10}; ///< The interval for sending ping messages. - std::chrono::seconds reconnect_interval_{ - 5}; ///< The interval between reconnection attempts. - int max_retries_ = 3; ///< The maximum number of reconnection attempts. - int retry_count_ = 0; ///< The current number of reconnection attempts. -}; - -template -void WSClient::asyncConnect(const std::string& host, const std::string& port, - ConnectHandler&& handler) { - retry_count_ = 0; - resolver_.async_resolve( - host, port, - [this, handler = std::forward(handler)]( - beast::error_code ec, tcp::resolver::results_type results) { - if (ec) { - handleConnectError(ec, handler); - return; - } - - beast::get_lowest_layer(ws_).async_connect( - results, [this, handler = std::move(handler), results]( - beast::error_code ec, - tcp::resolver::results_type::endpoint_type) { - if (ec) { - handleConnectError(ec, handler); - return; - } - - ws_.async_handshake(results->host_name(), "/", - [this, handler = std::move(handler)]( - beast::error_code ec) { - if (!ec) { - startPing(); - } - handler(ec); - }); - }); - }); -} - -template -void WSClient::asyncSend(const std::string& message, WriteHandler&& handler) { - ws_.async_write(net::buffer(message), - [handler = std::forward(handler)]( - beast::error_code ec, std::size_t bytes_transferred) { - handler(ec, bytes_transferred); - }); -} - -template -void WSClient::asyncReceive(ReadHandler&& handler) { - auto buffer = std::make_shared(); - ws_.async_read( - *buffer, [buffer, handler = std::forward(handler)]( - beast::error_code ec, std::size_t bytes_transferred) { - if (ec) { - handler(ec, ""); - } else { - handler(ec, beast::buffers_to_string(buffer->data())); - } - }); -} - -template -void WSClient::asyncClose(CloseHandler&& handler) { - ws_.async_close(websocket::close_code::normal, - [handler = std::forward(handler)]( - beast::error_code ec) { handler(ec); }); -} - -template -void WSClient::asyncReceiveJson(JsonHandler&& handler) { - asyncReceive([handler = std::forward(handler)]( - beast::error_code ec, const std::string& message) { - if (ec) { - handler(ec, {}); - } else { - try { - auto jdata = json::parse(message); - handler(ec, jdata); - } catch (const json::parse_error&) { - handler(beast::error_code{}, {}); - } - } - }); -} - -#endif // WS_CLIENT_HPP diff --git a/src/atom/extra/boost/charconv.hpp b/src/atom/extra/boost/charconv.hpp deleted file mode 100644 index db5d294e..00000000 --- a/src/atom/extra/boost/charconv.hpp +++ /dev/null @@ -1,280 +0,0 @@ -#ifndef ATOM_EXTRA_BOOST_CHARCONV_HPP -#define ATOM_EXTRA_BOOST_CHARCONV_HPP - -#if __has_include() -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace atom::extra::boost { - -// Constants definition -constexpr int ALIGNMENT = 16; -constexpr int DEFAULT_BASE = 10; -constexpr size_t BUFFER_SIZE = 128; - -/** - * @brief Enum class representing different number formats. - */ -enum class NumberFormat { GENERAL, SCIENTIFIC, FIXED, HEX }; - -/** - * @brief Struct for specifying format options for number conversion. - */ -struct alignas(ALIGNMENT) FormatOptions { - NumberFormat format = NumberFormat::GENERAL; ///< The number format. - std::optional precision = - std::nullopt; ///< The precision for floating-point numbers. - bool uppercase = false; ///< Whether to use uppercase letters. - char thousandsSeparator = - '\0'; ///< The character to use as a thousands separator. -}; - -/** - * @brief Class for converting numbers to and from strings using Boost.CharConv. - */ -class BoostCharConv { -public: - /** - * @brief Converts an integer to a string. - * @tparam T The type of the integer. - * @param value The integer value to convert. - * @param base The base for the conversion (default is 10). - * @param options The format options for the conversion. - * @return The converted string. - * @throws std::runtime_error if the conversion fails. - */ - template - static auto intToString(T value, int base = DEFAULT_BASE, - const FormatOptions& options = {}) -> std::string { - static_assert(std::is_integral_v, - "intToString only works with integral types"); - - std::array buffer{}; // Buffer for conversion - auto result = - std::to_chars(buffer.data(), buffer.data() + buffer.size(), value, - base); // Standard to_chars - - if (result.ec == std::errc()) { - std::string str(buffer.data(), result.ptr); - if (options.thousandsSeparator != '\0') { - str = addThousandsSeparator(str, options.thousandsSeparator); - } - return options.uppercase ? toUpper(str) : str; - } - throw std::runtime_error("Int to string conversion failed: " + - std::make_error_code(result.ec).message()); - } - - /** - * @brief Converts a floating-point number to a string. - * @tparam T The type of the floating-point number. - * @param value The floating-point value to convert. - * @param options The format options for the conversion. - * @return The converted string. - * @throws std::runtime_error if the conversion fails. - */ - template - static auto floatToString(T value, const FormatOptions& options = {}) - -> std::string { - std::array buffer{}; - auto format = getFloatFormat(options.format); - auto result = options.precision - ? ::boost::charconv::to_chars( - buffer.data(), buffer.data() + buffer.size(), - value, format, *options.precision) - : ::boost::charconv::to_chars( - buffer.data(), buffer.data() + buffer.size(), - value, format); - if (result.ec == std::errc()) { - std::string str(buffer.data(), result.ptr); - if (options.thousandsSeparator != '\0') { - str = addThousandsSeparator(str, options.thousandsSeparator); - } - return options.uppercase ? toUpper(str) : str; - } - throw std::runtime_error("Float to string conversion failed: " + - std::make_error_code(result.ec).message()); - } - - /** - * @brief Converts a string to an integer. - * @tparam T The type of the integer. - * @param str The string to convert. - * @param base The base for the conversion (default is 10). - * @return The converted integer. - * @throws std::runtime_error if the conversion fails. - */ - template - static auto stringToInt(const std::string& str, - int base = DEFAULT_BASE) -> T { - T value; - auto result = ::boost::charconv::from_chars( - str.data(), str.data() + str.size(), value, base); - if (result.ec == std::errc() && result.ptr == str.data() + str.size()) { - return value; - } - throw std::runtime_error("String to int conversion failed: " + - std::make_error_code(result.ec).message()); - } - - /** - * @brief Converts a string to a floating-point number. - * @tparam T The type of the floating-point number. - * @param str The string to convert. - * @return The converted floating-point number. - * @throws std::runtime_error if the conversion fails. - */ - template - static auto stringToFloat(const std::string& str) -> T { - T value; - auto result = ::boost::charconv::from_chars( - str.data(), str.data() + str.size(), value); - if (result.ec == std::errc() && result.ptr == str.data() + str.size()) { - return value; - } - throw std::runtime_error("String to float conversion failed: " + - std::make_error_code(result.ec).message()); - } - - /** - * @brief Converts a value to a string using the appropriate conversion - * function. - * @tparam T The type of the value. - * @param value The value to convert. - * @param options The format options for the conversion. - * @return The converted string. - */ - template - static auto toString(T value, - const FormatOptions& options = {}) -> std::string { - if constexpr (std::is_integral_v) { - return intToString(value, DEFAULT_BASE, options); - } else if constexpr (std::is_floating_point_v) { - return floatToString(value, options); - } else { - static_assert(ALWAYS_FALSE, "Unsupported type for toString"); - } - } - - /** - * @brief Converts a string to a value using the appropriate conversion - * function. - * @tparam T The type of the value. - * @param str The string to convert. - * @param base The base for the conversion (default is 10). - * @return The converted value. - */ - template - static auto fromString(const std::string& str, - int base = DEFAULT_BASE) -> T { - if constexpr (std::is_integral_v) { - return stringToInt(str, base); - } else if constexpr (std::is_floating_point_v) { - return stringToFloat(str); - } else { - static_assert(ALWAYS_FALSE, "Unsupported type for fromString"); - } - } - - /** - * @brief Converts special floating-point values (NaN, Inf) to strings. - * @tparam T The type of the floating-point value. - * @param value The floating-point value to convert. - * @return The converted string. - */ - template - static auto specialValueToString(T value) -> std::string { - if (std::isnan(value)) { - return "NaN"; - } - if (std::isinf(value)) { - return value > 0 ? "Inf" : "-Inf"; - } - return toString(value); - } - -private: - template - static constexpr bool ALWAYS_FALSE = false; - - /** - * @brief Gets the Boost.CharConv format for floating-point numbers. - * @param format The number format. - * @return The Boost.CharConv format. - */ - static auto getFloatFormat(NumberFormat format) - -> ::boost::charconv::chars_format { - switch (format) { - case NumberFormat::SCIENTIFIC: - return ::boost::charconv::chars_format::scientific; - case NumberFormat::FIXED: - return ::boost::charconv::chars_format::fixed; - case NumberFormat::HEX: - return ::boost::charconv::chars_format::hex; - default: - return ::boost::charconv::chars_format::general; - } - } - - /** - * @brief Gets the Boost.CharConv format for integer numbers. - * @param format The number format. - * @return The Boost.CharConv format. - */ - static auto getIntegerFormat(NumberFormat format) - -> ::boost::charconv::chars_format { - return (format == NumberFormat::HEX) - ? ::boost::charconv::chars_format::hex - : ::boost::charconv::chars_format::general; - } - - /** - * @brief Adds a thousands separator to a string. - * @param str The string to modify. - * @param separator The character to use as a thousands separator. - * @return The modified string with thousands separators. - */ - static auto addThousandsSeparator(const std::string& str, - char separator) -> std::string { - std::string result; - int count = 0; - bool pastDecimalPoint = false; - for (char it : std::ranges::reverse_view(str)) { - if (it == '.') { - pastDecimalPoint = true; - } - if (!pastDecimalPoint && count > 0 && count % 3 == 0) { - result.push_back(separator); - } - result.push_back(it); - count++; - } - std::reverse(result.begin(), result.end()); - return result; - } - - /** - * @brief Converts a string to uppercase. - * @param str The string to convert. - * @return The converted uppercase string. - */ - static auto toUpper(std::string str) -> std::string { - for (char& character : str) { - character = std::toupper(character); - } - return str; - } -}; - -} // namespace atom::extra::boost - -#endif - -#endif // ATOM_EXTRA_BOOST_CHARCONV_HPP diff --git a/src/atom/extra/boost/locale.hpp b/src/atom/extra/boost/locale.hpp deleted file mode 100644 index ed4e90cd..00000000 --- a/src/atom/extra/boost/locale.hpp +++ /dev/null @@ -1,228 +0,0 @@ -#ifndef ATOM_EXTRA_BOOST_LOCALE_HPP -#define ATOM_EXTRA_BOOST_LOCALE_HPP - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace atom::extra::boost { - -/** - * @brief A wrapper class for Boost.Locale functionalities. - * - * This class provides various utilities for string conversion, Unicode - * normalization, tokenization, translation, case conversion, collation, date - * and time formatting, number formatting, currency formatting, and regex - * replacement using Boost.Locale. - */ -class LocaleWrapper { -public: - /** - * @brief Constructs a LocaleWrapper object with the specified locale. - * @param localeName The name of the locale to use. If empty, the global - * locale is used. - */ - explicit LocaleWrapper(const std::string& localeName = "") { - ::boost::locale::generator gen; - std::locale::global(gen(localeName)); - locale_ = std::locale(); - } - - /** - * @brief Converts a string to UTF-8 encoding. - * @param str The string to convert. - * @param fromCharset The original character set of the string. - * @return The UTF-8 encoded string. - */ - static auto toUtf8(const std::string& str, - const std::string& fromCharset) -> std::string { - return ::boost::locale::conv::to_utf(str, fromCharset); - } - - /** - * @brief Converts a UTF-8 encoded string to another character set. - * @param str The UTF-8 encoded string to convert. - * @param toCharset The target character set. - * @return The converted string. - */ - static auto fromUtf8(const std::string& str, - const std::string& toCharset) -> std::string { - return ::boost::locale::conv::from_utf(str, toCharset); - } - - /** - * @brief Normalizes a Unicode string. - * @param str The string to normalize. - * @param norm The normalization form to use (default is NFC). - * @return The normalized string. - */ - static auto normalize(const std::string& str, - ::boost::locale::norm_type norm = - ::boost::locale::norm_default) -> std::string { - return ::boost::locale::normalize(str, norm); - } - - /** - * @brief Tokenizes a string into words. - * @param str The string to tokenize. - * @param localeName The name of the locale to use for tokenization. - * @return A vector of tokens. - */ - static auto tokenize(const std::string& str, - const std::string& localeName = "") - -> std::vector { - ::boost::locale::generator gen; - std::locale loc = gen(localeName); - ::boost::locale::boundary::ssegment_index map( - ::boost::locale::boundary::word, str.begin(), str.end(), loc); - std::vector tokens; -#pragma unroll - for (const auto& token : map) { - tokens.push_back(token.str()); - } - return tokens; - } - - /** - * @brief Translates a string to the specified locale. - * @param str The string to translate. - * @param domain The domain for the translation (not used in this - * implementation). - * @param localeName The name of the locale to use for translation. - * @return The translated string. - */ - static auto translate(const std::string& str, const std::string& /*domain*/, - const std::string& localeName = "") -> std::string { - ::boost::locale::generator gen; - std::locale loc = gen(localeName); - return ::boost::locale::translate(str).str(loc); - } - - /** - * @brief Converts a string to uppercase. - * @param str The string to convert. - * @return The uppercase string. - */ - [[nodiscard]] auto toUpper(const std::string& str) const -> std::string { - return ::boost::locale::to_upper(str, locale_); - } - - /** - * @brief Converts a string to lowercase. - * @param str The string to convert. - * @return The lowercase string. - */ - [[nodiscard]] auto toLower(const std::string& str) const -> std::string { - return ::boost::locale::to_lower(str, locale_); - } - - /** - * @brief Converts a string to title case. - * @param str The string to convert. - * @return The title case string. - */ - [[nodiscard]] auto toTitle(const std::string& str) const -> std::string { - return ::boost::locale::to_title(str, locale_); - } - - /** - * @brief Compares two strings using locale-specific collation rules. - * @param str1 The first string to compare. - * @param str2 The second string to compare. - * @return An integer less than, equal to, or greater than zero if str1 is - * found, respectively, to be less than, to match, or be greater than str2. - */ - [[nodiscard]] auto compare(const std::string& str1, - const std::string& str2) const -> int { - return static_cast(::boost::locale::comparator< - char, ::boost::locale::collator_base::primary>( - locale_)(str1, str2)); - } - - /** - * @brief Formats a date and time according to the specified format. - * @param dateTime The date and time to format. - * @param format The format string. - * @return The formatted date and time string. - */ - [[nodiscard]] static auto formatDate( - const ::boost::posix_time::ptime& dateTime, - const std::string& format) -> std::string { - std::ostringstream oss; - oss.imbue(std::locale()); - oss << ::boost::locale::format(format) % dateTime; - return oss.str(); - } - - /** - * @brief Formats a number with the specified precision. - * @param number The number to format. - * @param precision The number of decimal places. - * @return The formatted number string. - */ - [[nodiscard]] static auto formatNumber(double number, - int precision = 2) -> std::string { - std::ostringstream oss; - oss.imbue(std::locale()); - oss << std::fixed << std::setprecision(precision) << number; - return oss.str(); - } - - /** - * @brief Formats a currency amount. - * @param amount The amount to format. - * @param currency The currency code. - * @return The formatted currency string. - */ - [[nodiscard]] static auto formatCurrency( - double amount, const std::string& currency) -> std::string { - std::ostringstream oss; - oss.imbue(std::locale()); - oss << ::boost::locale::as::currency << currency << amount; - return oss.str(); - } - - /** - * @brief Replaces occurrences of a regex pattern in a string with a format - * string. - * @param str The string to search. - * @param regex The regex pattern to search for. - * @param format The format string to replace with. - * @return The resulting string after replacements. - */ - [[nodiscard]] static auto regexReplace( - const std::string& str, const ::boost::regex& regex, - const std::string& format) -> std::string { - return ::boost::regex_replace( - str, regex, format, ::boost::match_default | ::boost::format_all); - } - - /** - * @brief Formats a string with named arguments. - * @tparam Args The types of the arguments. - * @param formatString The format string. - * @param args The arguments to format. - * @return The formatted string. - */ - template - [[nodiscard]] auto format(const std::string& formatString, - Args&&... args) const -> std::string { - return (::boost::locale::format(formatString) % ... % args) - .str(locale_); - } - -private: - std::locale locale_; ///< The locale used for various operations. - static constexpr std::size_t K_BUFFER_SIZE = - 4096; ///< Buffer size for internal operations. -}; - -} // namespace atom::extra::boost - -#endif // ATOM_EXTRA_BOOST_LOCALE_HPP diff --git a/src/atom/extra/boost/math.hpp b/src/atom/extra/boost/math.hpp deleted file mode 100644 index 8c8202fd..00000000 --- a/src/atom/extra/boost/math.hpp +++ /dev/null @@ -1,601 +0,0 @@ -#ifndef ATOM_EXTRA_BOOST_MATH_HPP -#define ATOM_EXTRA_BOOST_MATH_HPP - -#include -#include -#include -#include - -#include -#include -#include - -#include -#include -#include -#include - -namespace atom::extra::boost { - -/** - * @brief Concept to check if a type is numeric. - * @tparam T The type to check. - */ -template -concept Numeric = std::is_arithmetic_v; - -/** - * @brief Wrapper class for special mathematical functions. - * @tparam T The numeric type. - */ -template -class SpecialFunctions { -public: - /** - * @brief Computes the beta function. - * @param alpha The alpha parameter. - * @param beta The beta parameter. - * @return The value of the beta function. - */ - static auto beta(T alpha, T beta) -> T { - return ::boost::math::beta(alpha, beta); - } - - /** - * @brief Computes the gamma function. - * @param value The input value. - * @return The value of the gamma function. - */ - static auto gamma(T value) -> T { return ::boost::math::tgamma(value); } - - /** - * @brief Computes the digamma function. - * @param value The input value. - * @return The value of the digamma function. - */ - static auto digamma(T value) -> T { return ::boost::math::digamma(value); } - - /** - * @brief Computes the error function. - * @param value The input value. - * @return The value of the error function. - */ - static auto erf(T value) -> T { return ::boost::math::erf(value); } - - /** - * @brief Computes the Bessel function of the first kind. - * @param order The order of the Bessel function. - * @param value The input value. - * @return The value of the Bessel function. - */ - static auto besselJ(int order, T value) -> T { - return ::boost::math::cyl_bessel_j(order, value); - } - - /** - * @brief Computes the Legendre polynomial. - * @param order The order of the polynomial. - * @param value The input value. - * @return The value of the Legendre polynomial. - */ - static auto legendreP(int order, T value) -> T { - return ::boost::math::legendre_p(order, value); - } -}; - -/** - * @brief Wrapper class for statistical functions. - * @tparam T The numeric type. - */ -template -class Statistics { -public: - /** - * @brief Computes the mean of a dataset. - * @param data The input dataset. - * @return The mean of the dataset. - */ - static auto mean(const std::vector& data) -> T { - return ::boost::math::statistics::mean(data); - } - - /** - * @brief Computes the variance of a dataset. - * @param data The input dataset. - * @return The variance of the dataset. - */ - static auto variance(const std::vector& data) -> T { - return ::boost::math::statistics::variance(data); - } - - /** - * @brief Computes the skewness of a dataset. - * @param data The input dataset. - * @return The skewness of the dataset. - */ - static auto skewness(const std::vector& data) -> T { - return ::boost::math::statistics::skewness(data); - } - - /** - * @brief Computes the kurtosis of a dataset. - * @param data The input dataset. - * @return The kurtosis of the dataset. - */ - static auto kurtosis(const std::vector& data) -> T { - return ::boost::math::statistics::kurtosis(data); - } -}; - -/** - * @brief Wrapper class for probability distributions. - * @tparam T The numeric type. - */ -template -class Distributions { -public: - /** - * @brief Wrapper class for normal distribution. - */ - class NormalDistribution { - private: - ::boost::math::normal_distribution distribution; - - public: - /** - * @brief Constructs a normal distribution with given mean and standard - * deviation. - * @param mean The mean of the distribution. - * @param stddev The standard deviation of the distribution. - */ - NormalDistribution(T mean, T stddev) : distribution(mean, stddev) {} - - /** - * @brief Computes the probability density function (PDF) at a given - * value. - * @param value The input value. - * @return The PDF value. - */ - [[nodiscard]] auto pdf(T value) const -> T { - return ::boost::math::pdf(distribution, value); - } - - /** - * @brief Computes the cumulative distribution function (CDF) at a given - * value. - * @param value The input value. - * @return The CDF value. - */ - [[nodiscard]] auto cdf(T value) const -> T { - return ::boost::math::cdf(distribution, value); - } - - /** - * @brief Computes the quantile (inverse CDF) at a given probability. - * @param probability The input probability. - * @return The quantile value. - */ - [[nodiscard]] auto quantile(T probability) const -> T { - return ::boost::math::quantile(distribution, probability); - } - }; - - /** - * @brief Wrapper class for Student's t-distribution. - */ - class StudentTDistribution { - private: - ::boost::math::students_t_distribution distribution; - - public: - /** - * @brief Constructs a Student's t-distribution with given degrees of - * freedom. - * @param degreesOfFreedom The degrees of freedom. - */ - explicit StudentTDistribution(T degreesOfFreedom) - : distribution(degreesOfFreedom) {} - - /** - * @brief Computes the probability density function (PDF) at a given - * value. - * @param value The input value. - * @return The PDF value. - */ - [[nodiscard]] auto pdf(T value) const -> T { - return ::boost::math::pdf(distribution, value); - } - - /** - * @brief Computes the cumulative distribution function (CDF) at a given - * value. - * @param value The input value. - * @return The CDF value. - */ - [[nodiscard]] auto cdf(T value) const -> T { - return ::boost::math::cdf(distribution, value); - } - - /** - * @brief Computes the quantile (inverse CDF) at a given probability. - * @param probability The input probability. - * @return The quantile value. - */ - [[nodiscard]] auto quantile(T probability) const -> T { - return ::boost::math::quantile(distribution, probability); - } - }; - - /** - * @brief Wrapper class for Poisson distribution. - */ - class PoissonDistribution { - private: - ::boost::math::poisson_distribution distribution; - - public: - /** - * @brief Constructs a Poisson distribution with given mean. - * @param mean The mean of the distribution. - */ - explicit PoissonDistribution(T mean) : distribution(mean) {} - - /** - * @brief Computes the probability density function (PDF) at a given - * value. - * @param value The input value. - * @return The PDF value. - */ - [[nodiscard]] auto pdf(T value) const -> T { - return ::boost::math::pdf(distribution, value); - } - - /** - * @brief Computes the cumulative distribution function (CDF) at a given - * value. - * @param value The input value. - * @return The CDF value. - */ - [[nodiscard]] auto cdf(T value) const -> T { - return ::boost::math::cdf(distribution, value); - } - }; - - /** - * @brief Wrapper class for exponential distribution. - */ - class ExponentialDistribution { - private: - ::boost::math::exponential_distribution distribution; - - public: - /** - * @brief Constructs an exponential distribution with given rate - * parameter. - * @param lambda The rate parameter. - */ - explicit ExponentialDistribution(T lambda) : distribution(lambda) {} - - /** - * @brief Computes the probability density function (PDF) at a given - * value. - * @param value The input value. - * @return The PDF value. - */ - [[nodiscard]] auto pdf(T value) const -> T { - return ::boost::math::pdf(distribution, value); - } - - /** - * @brief Computes the cumulative distribution function (CDF) at a given - * value. - * @param value The input value. - * @return The CDF value. - */ - [[nodiscard]] auto cdf(T value) const -> T { - return ::boost::math::cdf(distribution, value); - } - }; -}; - -/** - * @brief Wrapper class for numerical integration methods. - * @tparam T The numeric type. - */ -template -class NumericalIntegration { -public: - /** - * @brief Computes the integral of a function using the trapezoidal rule. - * @param func The function to integrate. - * @param start The start of the integration interval. - * @param end The end of the integration interval. - * @return The computed integral. - */ - static auto trapezoidal(std::function func, T start, T end) -> T { - return ::boost::math::quadrature::trapezoidal(func, start, end); - } -}; - -/** - * @brief Computes the factorial of a number using constexpr if for compile-time - * optimization. - * @tparam T The numeric type. - * @param number The input number. - * @return The factorial of the number. - */ -template -constexpr auto factorial(T number) -> T { - if constexpr (std::is_integral_v) { - if (number == 0 || number == 1) { - return 1; - } - return number * factorial(number - 1); - } else { - return std::tgamma(number + 1); - } -} - -/** - * @brief Transforms a range of data using a given function. - * @tparam Range The type of the input range. - * @tparam Func The type of the transformation function. - * @param range The input range. - * @param func The transformation function. - * @return A transformed view of the input range. - */ -template -auto transformRange(Range&& range, Func func) { - return std::ranges::transform_view(std::forward(range), func); -} - -/** - * @brief Wrapper class for optimization methods. - * @tparam T The numeric type. - */ -template -class Optimization { -public: - /** - * @brief Performs one-dimensional golden section search to find the minimum - * of a function. - * @param func The function to minimize. - * @param start The start of the search interval. - * @param end The end of the search interval. - * @param tolerance The tolerance for convergence. - * @return The point where the function attains its minimum. - */ - static auto goldenSectionSearch(std::function func, T start, T end, - T tolerance) -> T { - const T goldenRatio = 0.618033988749895; - T pointC = end - goldenRatio * (end - start); - T pointD = start + goldenRatio * (end - start); - - while (std::abs(pointC - pointD) > tolerance) { - if (func(pointC) < func(pointD)) { - end = pointD; - } else { - start = pointC; - } - pointC = end - goldenRatio * (end - start); - pointD = start + goldenRatio * (end - start); - } - - return (start + end) / 2; - } - - /** - * @brief Performs Newton-Raphson method to find the root of a function. - * @param func The function whose root is to be found. - * @param derivativeFunc The derivative of the function. - * @param initialGuess The initial guess for the root. - * @param tolerance The tolerance for convergence. - * @param maxIterations The maximum number of iterations. - * @return The root of the function. - * @throws std::runtime_error If the derivative is zero or maximum - * iterations are reached without convergence. - */ - static auto newtonRaphson(std::function func, - std::function derivativeFunc, - T initialGuess, T tolerance, - int maxIterations) -> T { - T currentGuess = initialGuess; - for (int i = 0; i < maxIterations; ++i) { - T funcValue = func(currentGuess); - if (std::abs(funcValue) < tolerance) { - return currentGuess; - } - T derivativeValue = derivativeFunc(currentGuess); - if (derivativeValue == 0) { - throw std::runtime_error( - "Derivative is zero. Cannot continue."); - } - currentGuess = currentGuess - funcValue / derivativeValue; - } - throw std::runtime_error("Max iterations reached without convergence."); - } -}; - -/** - * @brief Wrapper class for linear algebra operations. - * @tparam T The numeric type. - */ -template -class LinearAlgebra { -public: - using Matrix = ::boost::numeric::ublas::matrix; - using Vector = ::boost::numeric::ublas::vector; - - /** - * @brief Solves a linear system of equations Ax = b. - * @param matrix The matrix A. - * @param vector The vector b. - * @return The solution vector x. - */ - static auto solveLinearSystem(const Matrix& matrix, - const Vector& vector) -> Vector { - ::boost::numeric::ublas::permutation_matrix - permutationMatrix(matrix.size1()); - Matrix matrixCopy = matrix; - ::boost::numeric::ublas::lu_factorize(matrixCopy, permutationMatrix); - Vector solution = vector; - ::boost::numeric::ublas::lu_substitute(matrixCopy, permutationMatrix, - solution); - return solution; - } - - /** - * @brief Computes the determinant of a matrix. - * @param matrix The input matrix. - * @return The determinant of the matrix. - */ - static auto determinant(const Matrix& matrix) -> T { - Matrix matrixCopy = matrix; - ::boost::numeric::ublas::permutation_matrix - permutationMatrix(matrix.size1()); - ::boost::numeric::ublas::lu_factorize(matrixCopy, permutationMatrix); - T determinantValue = 1.0; - for (std::size_t i = 0; i < matrix.size1(); ++i) { - determinantValue *= matrixCopy(i, i); - } - return determinantValue * (permutationMatrix.size() % 2 == 1 ? -1 : 1); - } - - /** - * @brief Multiplies two matrices. - * @param matrix1 The first matrix. - * @param matrix2 The second matrix. - * @return The product of the two matrices. - */ - static auto multiply(const Matrix& matrix1, - const Matrix& matrix2) -> Matrix { - return ::boost::numeric::ublas::prod(matrix1, matrix2); - } - - /** - * @brief Computes the transpose of a matrix. - * @param matrix The input matrix. - * @return The transpose of the matrix. - */ - static auto transpose(const Matrix& matrix) -> Matrix { - return ::boost::numeric::ublas::trans(matrix); - } -}; - -/** - * @brief Wrapper class for solving ordinary differential equations (ODEs). - * @tparam T The numeric type. - */ -template -class ODESolver { -public: - using State = std::vector; - using SystemFunction = std::function; - - /** - * @brief Solves an ODE using the 4th order Runge-Kutta method. - * @param system The system function defining the ODE. - * @param initialState The initial state of the system. - * @param startTime The start time. - * @param endTime The end time. - * @param stepSize The step size. - * @return A vector of states representing the solution. - */ - static auto rungeKutta4(SystemFunction system, State initialState, - T startTime, T endTime, - T stepSize) -> std::vector { - std::vector solution; - ::boost::numeric::odeint::runge_kutta4 stepper; - ::boost::numeric::odeint::integrate_const( - stepper, system, initialState, startTime, endTime, stepSize, - [&solution](const State& state, T) { solution.push_back(state); }); - return solution; - } -}; - -/** - * @brief Wrapper class for financial mathematics functions. - * @tparam T The numeric type. - */ -template -class FinancialMath { -public: - /** - * @brief Computes the price of a European call option using the - * Black-Scholes formula. - * @param stockPrice The current stock price. - * @param strikePrice The strike price of the option. - * @param riskFreeRate The risk-free interest rate. - * @param volatility The volatility of the stock. - * @param timeToMaturity The time to maturity of the option. - * @return The price of the European call option. - */ - static auto blackScholesCall(T stockPrice, T strikePrice, T riskFreeRate, - T volatility, T timeToMaturity) -> T { - T d1 = - (std::log(stockPrice / strikePrice) + - (riskFreeRate + 0.5 * volatility * volatility) * timeToMaturity) / - (volatility * std::sqrt(timeToMaturity)); - T d2 = d1 - volatility * std::sqrt(timeToMaturity); - return stockPrice * ::boost::math::cdf( - ::boost::math::normal_distribution(), d1) - - strikePrice * std::exp(-riskFreeRate * timeToMaturity) * - ::boost::math::cdf(::boost::math::normal_distribution(), - d2); - } - - /** - * @brief Computes the modified duration of a bond. - * @param yield The yield to maturity. - * @param couponRate The coupon rate of the bond. - * @param faceValue The face value of the bond. - * @param periods The number of periods. - * @return The modified duration of the bond. - */ - static auto modifiedDuration(T yield, T couponRate, T faceValue, - int periods) -> T { - T periodYield = yield / periods; - T couponPayment = couponRate * faceValue / periods; - T numPeriods = static_cast(periods); - T presentValue = 0; - T weightedPresentValue = 0; - for (int i = 1; i <= periods; ++i) { - T discountFactor = std::pow(1 + periodYield, -i); - presentValue += couponPayment * discountFactor; - weightedPresentValue += i * couponPayment * discountFactor; - } - presentValue += faceValue * std::pow(1 + periodYield, -numPeriods); - weightedPresentValue += - numPeriods * faceValue * std::pow(1 + periodYield, -numPeriods); - return (weightedPresentValue / presentValue) / (1 + periodYield); - } - - // 计算债券价格 - static auto bondPrice(T yield, T couponRate, T faceValue, - int periods) -> T { - T periodYield = yield / periods; - T couponPayment = couponRate * faceValue / periods; - T presentValue = 0; - for (int i = 1; i <= periods; ++i) { - presentValue += couponPayment * std::pow(1 + periodYield, -i); - } - presentValue += faceValue * std::pow(1 + periodYield, -periods); - return presentValue; - } - - // 计算期权的隐含波动率 - static auto impliedVolatility(T marketPrice, T stockPrice, T strikePrice, - T riskFreeRate, T timeToMaturity) -> T { - auto objectiveFunction = [&](T volatility) { - return blackScholesCall(stockPrice, strikePrice, riskFreeRate, - volatility, timeToMaturity) - - marketPrice; - }; - return Optimization::newtonRaphson( - objectiveFunction, [](T) { return 1; }, 0.2, 1e-6, 100); - } -}; - -} // namespace atom::extra::boost - -#endif diff --git a/src/atom/extra/boost/regex.hpp b/src/atom/extra/boost/regex.hpp deleted file mode 100644 index 045af020..00000000 --- a/src/atom/extra/boost/regex.hpp +++ /dev/null @@ -1,315 +0,0 @@ -#ifndef ATOM_EXTRA_BOOST_REGEX_HPP -#define ATOM_EXTRA_BOOST_REGEX_HPP - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace atom::extra::boost { - -/** - * @class RegexWrapper - * @brief A wrapper class for Boost.Regex providing various regex operations. - */ -class RegexWrapper { -public: - /** - * @brief Constructs a RegexWrapper with the given pattern and flags. - * @param pattern The regex pattern. - * @param flags The regex syntax option flags. - */ - explicit RegexWrapper(std::string_view pattern, - ::boost::regex_constants::syntax_option_type flags = - ::boost::regex_constants::normal) - : regex_(pattern.data(), flags) {} - - /** - * @brief Matches the given string against the regex pattern. - * @tparam T The type of the input string, convertible to std::string_view. - * @param str The input string to match. - * @return True if the string matches the pattern, false otherwise. - */ - template - requires std::convertible_to - auto match(const T& str) const -> bool { - return ::boost::regex_match(std::string_view(str).begin(), - std::string_view(str).end(), regex_); - } - - /** - * @brief Searches the given string for the first match of the regex - * pattern. - * @tparam T The type of the input string, convertible to std::string_view. - * @param str The input string to search. - * @return An optional containing the first match if found, std::nullopt - * otherwise. - */ - template - requires std::convertible_to - auto search(const T& str) const -> std::optional { - ::boost::smatch what; - if (::boost::regex_search(std::string(str), what, regex_)) { - return what.str(); - } - return std::nullopt; - } - - /** - * @brief Searches the given string for all matches of the regex pattern. - * @tparam T The type of the input string, convertible to std::string_view. - * @param str The input string to search. - * @return A vector containing all matches found. - */ - template - requires std::convertible_to - auto searchAll(const T& str) const -> std::vector { - std::vector results; - std::string s(str); - ::boost::sregex_iterator iter(s.begin(), s.end(), regex_); - ::boost::sregex_iterator end; - for (; iter != end; ++iter) { - results.push_back(iter->str()); - } - return results; - } - - /** - * @brief Replaces all matches of the regex pattern in the given string with - * the replacement string. - * @tparam T The type of the input string, convertible to std::string_view. - * @tparam U The type of the replacement string, convertible to - * std::string_view. - * @param str The input string. - * @param replacement The replacement string. - * @return A new string with all matches replaced. - */ - template - requires std::convertible_to && - std::convertible_to - auto replace(const T& str, const U& replacement) const -> std::string { - return ::boost::regex_replace(std::string(str), regex_, - std::string(replacement)); - } - - /** - * @brief Splits the given string by the regex pattern. - * @tparam T The type of the input string, convertible to std::string_view. - * @param str The input string to split. - * @return A vector containing the split parts of the string. - */ - template - requires std::convertible_to - auto split(const T& str) const -> std::vector { - std::vector results; - std::string s(str); - ::boost::sregex_token_iterator iter(s.begin(), s.end(), regex_, -1); - ::boost::sregex_token_iterator end; - for (; iter != end; ++iter) { - results.push_back(*iter); - } - return results; - } - - /** - * @brief Matches the given string and returns the groups of each match. - * @tparam T The type of the input string, convertible to std::string_view. - * @param str The input string to match. - * @return A vector of pairs, each containing the full match and a vector of - * groups. - */ - template - requires std::convertible_to - auto matchGroups(const T& str) const - -> std::vector>> { - std::vector>> results; - ::boost::smatch what; - std::string s(str); - std::string::const_iterator start = s.begin(); - std::string::const_iterator end = s.end(); - while (::boost::regex_search(start, end, what, regex_)) { - std::vector groups; - for (size_t i = 1; i < what.size(); ++i) { - groups.push_back(what[i].str()); - } - results.emplace_back(what[0].str(), std::move(groups)); - start = what[0].second; - } - return results; - } - - /** - * @brief Applies a function to each match of the regex pattern in the given - * string. - * @tparam T The type of the input string, convertible to std::string_view. - * @tparam Func The type of the function to apply. - * @param str The input string. - * @param func The function to apply to each match. - */ - template - requires std::convertible_to && - std::invocable - void forEachMatch(const T& str, Func&& func) const { - std::string s(str); - ::boost::sregex_iterator iter(s.begin(), s.end(), regex_); - ::boost::sregex_iterator end; - for (; iter != end; ++iter) { - func(*iter); - } - } - - /** - * @brief Gets the regex pattern as a string. - * @return The regex pattern. - */ - [[nodiscard]] auto getPattern() const -> std::string { - return regex_.str(); - } - - /** - * @brief Sets a new regex pattern with optional flags. - * @param pattern The new regex pattern. - * @param flags The regex syntax option flags. - */ - void setPattern(std::string_view pattern, - ::boost::regex_constants::syntax_option_type flags = - ::boost::regex_constants::normal) { - regex_.assign(pattern.data(), flags); - } - - /** - * @brief Matches the given string and returns the named captures. - * @tparam T The type of the input string, convertible to std::string_view. - * @param str The input string to match. - * @return A map of named captures. - */ - template - requires std::convertible_to - auto namedCaptures(const T& str) const - -> std::map { - std::map result; - ::boost::smatch what; - if (::boost::regex_match(std::string(str), what, regex_)) { - for (size_t i = 1; i <= regex_.mark_count(); ++i) { - result[std::to_string(i)] = what[i].str(); - } - } - return result; - } - - /** - * @brief Checks if the given string is a valid match for the regex pattern. - * @tparam T The type of the input string, convertible to std::string_view. - * @param str The input string to check. - * @return True if the string is a valid match, false otherwise. - */ - template - requires std::convertible_to - auto isValid(const T& str) const -> bool { - try { - ::boost::regex_match(std::string_view(str).begin(), - std::string_view(str).end(), regex_); - return true; - } catch (const ::boost::regex_error&) { - return false; - } - } - - /** - * @brief Replaces all matches of the regex pattern in the given string - * using a callback function. - * @tparam T The type of the input string, convertible to std::string_view. - * @param str The input string. - * @param callback The callback function to generate replacements. - * @return A new string with all matches replaced by the callback results. - */ - template - requires std::convertible_to - auto replaceCallback( - const T& str, - const std::function& callback) - const -> std::string { - std::string result = std::string(str); - ::boost::sregex_iterator iter(result.begin(), result.end(), regex_); - ::boost::sregex_iterator end; - - std::vector> - replacements; - while (iter != end) { - const ::boost::smatch& match = *iter; - std::string replacement = callback(match); - replacements.emplace_back(match.position(), std::move(replacement)); - ++iter; - } - - for (auto iter = replacements.rbegin(); iter != replacements.rend(); - ++iter) { - result.replace(iter->first, iter->second.length(), iter->second); - } - - return result; - } - - /** - * @brief Escapes special characters in the given string for use in a regex - * pattern. - * @param str The input string to escape. - * @return The escaped string. - */ - [[nodiscard]] static auto escapeString(const std::string& str) - -> std::string { - return ::boost::regex_replace( - str, ::boost::regex(R"([.^$|()\[\]{}*+?\\])"), R"(\\&)", - ::boost::regex_constants::match_default | - ::boost::regex_constants::format_sed); - } - - /** - * @brief Benchmarks the match operation for the given string over a number - * of iterations. - * @tparam T The type of the input string, convertible to std::string_view. - * @param str The input string to match. - * @param iterations The number of iterations to run the benchmark. - * @return The average time per match operation in nanoseconds. - */ - template - requires std::convertible_to - auto benchmarkMatch(const T& str, int iterations = 1000) const - -> std::chrono::nanoseconds { - auto start = std::chrono::high_resolution_clock::now(); - for (int i = 0; i < iterations; ++i) { - ::boost::regex_match(std::string_view(str).begin(), - std::string_view(str).end(), regex_); - } - auto end = std::chrono::high_resolution_clock::now(); - return std::chrono::duration_cast(end - - start) / - iterations; - } - - /** - * @brief Checks if the given regex pattern is valid. - * @param pattern The regex pattern to check. - * @return True if the pattern is valid, false otherwise. - */ - static auto isValidRegex(const std::string& pattern) -> bool { - try { - ::boost::regex test(pattern); - return true; - } catch (const ::boost::regex_error&) { - return false; - } - } - -private: - ::boost::regex regex_; ///< The Boost.Regex object. -}; - -} // namespace atom::extra::boost - -#endif diff --git a/src/atom/extra/boost/system.hpp b/src/atom/extra/boost/system.hpp deleted file mode 100644 index 35ec7e3d..00000000 --- a/src/atom/extra/boost/system.hpp +++ /dev/null @@ -1,315 +0,0 @@ -#ifndef ATOM_EXTRA_BOOST_SYSTEM_HPP -#define ATOM_EXTRA_BOOST_SYSTEM_HPP - -#if __has_include() -#include -#endif -#include -#include - -#include -#include -#include - -namespace atom::extra::boost { - -/** - * @class Error - * @brief A wrapper class for Boost.System error codes. - */ -class Error { -public: - /** - * @brief Default constructor. - */ - Error() noexcept = default; - - /** - * @brief Constructs an Error from a Boost.System error code. - * @param error_code The Boost.System error code. - */ - explicit Error(const ::boost::system::error_code& error_code) noexcept - : m_ec_(error_code) {} - - /** - * @brief Constructs an Error from an error value and category. - * @param error_value The error value. - * @param error_category The error category. - */ - Error(int error_value, - const ::boost::system::error_category& error_category) noexcept - : m_ec_(error_value, error_category) {} - - /** - * @brief Gets the error value. - * @return The error value. - */ - [[nodiscard]] auto value() const noexcept -> int { return m_ec_.value(); } - - /** - * @brief Gets the error category. - * @return The error category. - */ - [[nodiscard]] auto category() const noexcept - -> const ::boost::system::error_category& { - return m_ec_.category(); - } - - /** - * @brief Gets the error message. - * @return The error message. - */ - [[nodiscard]] auto message() const -> std::string { - return m_ec_.message(); - } - - /** - * @brief Checks if the error code is valid. - * @return True if the error code is valid, false otherwise. - */ - [[nodiscard]] explicit operator bool() const noexcept { - return static_cast(m_ec_); - } - - /** - * @brief Converts to a Boost.System error code. - * @return The Boost.System error code. - */ - [[nodiscard]] auto toBoostErrorCode() const noexcept - -> ::boost::system::error_code { - return m_ec_; - } - - /** - * @brief Equality operator. - * @param other The other Error to compare. - * @return True if the errors are equal, false otherwise. - */ - [[nodiscard]] auto operator==(const Error& other) const noexcept -> bool { - return m_ec_ == other.m_ec_; - } - - /** - * @brief Inequality operator. - * @param other The other Error to compare. - * @return True if the errors are not equal, false otherwise. - */ - [[nodiscard]] auto operator!=(const Error& other) const noexcept -> bool { - return !(*this == other); - } - -private: - ::boost::system::error_code m_ec_; ///< The Boost.System error code. -}; - -/** - * @class Exception - * @brief A custom exception class for handling errors. - */ -class Exception : public std::system_error { -public: - /** - * @brief Constructs an Exception from an Error. - * @param error The Error object. - */ - explicit Exception(const Error& error) - : std::system_error(error.value(), error.category(), error.message()) {} - - /** - * @brief Gets the associated Error. - * @return The associated Error. - */ - [[nodiscard]] auto error() const noexcept -> Error { - return Error(::boost::system::error_code( - code().value(), ::boost::system::generic_category())); - } -}; - -/** - * @class Result - * @brief A class template for handling results with potential errors. - * @tparam T The type of the result value. - */ -template -class Result { -public: - using value_type = T; ///< The type of the result value. - - /** - * @brief Constructs a Result with a value. - * @param value The result value. - */ - explicit Result(T value) : m_value_(std::move(value)) {} - - /** - * @brief Constructs a Result with an Error. - * @param error The Error object. - */ - explicit Result(Error error) : m_error_(error) {} - - /** - * @brief Checks if the Result has a value. - * @return True if the Result has a value, false otherwise. - */ - [[nodiscard]] auto hasValue() const noexcept -> bool { return !m_error_; } - - /** - * @brief Gets the result value. - * @return The result value. - * @throws Exception if there is an error. - */ - [[nodiscard]] auto value() const& -> const T& { - if (!hasValue()) { - throw Exception(m_error_); - } - return *m_value_; - } - - /** - * @brief Gets the result value. - * @return The result value. - * @throws Exception if there is an error. - */ - [[nodiscard]] auto value() && -> T&& { - if (!hasValue()) { - throw Exception(m_error_); - } - return std::move(*m_value_); - } - - /** - * @brief Gets the associated Error. - * @return The associated Error. - */ - [[nodiscard]] auto error() const& noexcept -> const Error& { - return m_error_; - } - - /** - * @brief Gets the associated Error. - * @return The associated Error. - */ - [[nodiscard]] auto error() && noexcept -> Error { return m_error_; } - - /** - * @brief Checks if the Result has a value. - * @return True if the Result has a value, false otherwise. - */ - [[nodiscard]] explicit operator bool() const noexcept { return hasValue(); } - - /** - * @brief Gets the result value or a default value. - * @tparam U The type of the default value. - * @param default_value The default value. - * @return The result value or the default value. - */ - template - auto valueOr(U&& default_value) const& -> T { - return hasValue() ? value() - : static_cast(std::forward(default_value)); - } - - /** - * @brief Applies a function to the result value if it exists. - * @tparam F The type of the function. - * @param func The function to apply. - * @return A new Result with the function applied. - */ - template - auto map(F&& func) const -> Result> { - if (hasValue()) { - return Result>(func(*m_value_)); - } - return Result>(Error(m_error_)); - } - - /** - * @brief Applies a function to the result value if it exists. - * @tparam F The type of the function. - * @param func The function to apply. - * @return The result of the function. - */ - template - auto andThen(F&& func) const -> std::invoke_result_t { - if (hasValue()) { - return func(*m_value_); - } - return std::invoke_result_t(Error(m_error_)); - } - -private: - std::optional m_value_; ///< The result value. - Error m_error_; ///< The associated Error. -}; - -/** - * @class Result - * @brief A specialization of the Result class for void type. - */ -template <> -class Result { -public: - /** - * @brief Default constructor. - */ - Result() = default; - - /** - * @brief Constructs a Result with an Error. - * @param error The Error object. - */ - explicit Result(Error error) : m_error_(error) {} - - /** - * @brief Checks if the Result has a value. - * @return True if the Result has a value, false otherwise. - */ - [[nodiscard]] auto hasValue() const noexcept -> bool { return !m_error_; } - - /** - * @brief Gets the associated Error. - * @return The associated Error. - */ - [[nodiscard]] auto error() const& noexcept -> const Error& { - return m_error_; - } - - /** - * @brief Gets the associated Error. - * @return The associated Error. - */ - [[nodiscard]] auto error() && noexcept -> Error { return m_error_; } - - /** - * @brief Checks if the Result has a value. - * @return True if the Result has a value, false otherwise. - */ - [[nodiscard]] explicit operator bool() const noexcept { return hasValue(); } - -private: - Error m_error_; ///< The associated Error. -}; - -/** - * @brief Creates a Result from a function. - * @tparam F The type of the function. - * @param func The function to execute. - * @return A Result with the function's return value or an Error. - */ -template -auto makeResult(F&& func) -> Result> { - using return_type = std::invoke_result_t; - try { - return Result(func()); - } catch (const Exception& e) { - return Result(e.error()); - } catch (const std::exception&) { - return Result( - Error(::boost::system::errc::invalid_argument, - ::boost::system::generic_category())); - } -} - -} // namespace atom::extra::boost - -#endif // ATOM_EXTRA_BOOST_SYSTEM_HPP diff --git a/src/atom/extra/boost/uuid.hpp b/src/atom/extra/boost/uuid.hpp deleted file mode 100644 index 709c6ec3..00000000 --- a/src/atom/extra/boost/uuid.hpp +++ /dev/null @@ -1,300 +0,0 @@ -#ifndef ATOM_EXTRA_BOOST_UUID_HPP -#define ATOM_EXTRA_BOOST_UUID_HPP - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace atom::extra::boost { -constexpr size_t UUID_SIZE = 16; -constexpr int BASE64_RESERVE_SIZE = 22; -constexpr int SHIFT_40 = 40; -constexpr int SHIFT_32 = 32; -constexpr int SHIFT_24 = 24; -constexpr int SHIFT_16 = 16; -constexpr int SHIFT_8 = 8; -constexpr int BASE64_MASK = 63; -constexpr int BASE64_SHIFT_18 = 18; -constexpr int BASE64_SHIFT_12 = 12; -constexpr int BASE64_SHIFT_6 = 6; -constexpr uint64_t TIMESTAMP_DIVISOR = 10000000; -constexpr uint64_t UUID_EPOCH = 0x01B21DD213814000L; - -/** - * @class UUID - * @brief A wrapper class for Boost.UUID providing various UUID operations. - */ -class UUID { -private: - ::boost::uuids::uuid uuid_; ///< The Boost.UUID object. - -public: - /** - * @brief Default constructor that generates a random UUID (v4). - */ - UUID() : uuid_(::boost::uuids::random_generator()()) {} - - /** - * @brief Constructs a UUID from a string representation. - * @param str The string representation of the UUID. - */ - explicit UUID(const std::string& str) - : uuid_(::boost::uuids::string_generator()(str)) {} - - /** - * @brief Constructs a UUID from a Boost.UUID object. - * @param uuid The Boost.UUID object. - */ - explicit UUID(const ::boost::uuids::uuid& uuid) : uuid_(uuid) {} - - /** - * @brief Converts the UUID to a string representation. - * @return The string representation of the UUID. - */ - [[nodiscard]] auto toString() const -> std::string { - return ::boost::uuids::to_string(uuid_); - } - - /** - * @brief Checks if the UUID is nil (all zeros). - * @return True if the UUID is nil, false otherwise. - */ - [[nodiscard]] auto isNil() const -> bool { return uuid_.is_nil(); } - - /** - * @brief Compares this UUID with another UUID. - * @param other The other UUID to compare. - * @return The result of the comparison. - */ - auto operator<=>(const UUID& other) const -> std::strong_ordering { - if (uuid_ < other.uuid_) { - return std::strong_ordering::less; - } - if (uuid_ > other.uuid_) { - return std::strong_ordering::greater; - } - return std::strong_ordering::equal; - } - - /** - * @brief Checks if this UUID is equal to another UUID. - * @param other The other UUID to compare. - * @return True if the UUIDs are equal, false otherwise. - */ - auto operator==(const UUID& other) const -> bool { - return uuid_ == other.uuid_; - } - - /** - * @brief Formats the UUID as a string enclosed in curly braces. - * @return The formatted string. - */ - [[nodiscard]] auto format() const -> std::string { - return std::format("{{{}}}", toString()); - } - - /** - * @brief Converts the UUID to a vector of bytes. - * @return The vector of bytes representing the UUID. - */ - [[nodiscard]] auto toBytes() const -> std::vector { - return {uuid_.begin(), uuid_.end()}; - } - - /** - * @brief Constructs a UUID from a span of bytes. - * @param bytes The span of bytes. - * @return The constructed UUID. - * @throws std::invalid_argument if the span size is not 16 bytes. - */ - static auto fromBytes(const std::span& bytes) -> UUID { - if (bytes.size() != UUID_SIZE) { - throw std::invalid_argument("UUID must be exactly 16 bytes"); - } - ::boost::uuids::uuid uuid; - std::copy(bytes.begin(), bytes.end(), uuid.begin()); - return UUID(uuid); - } - - /** - * @brief Converts the UUID to a 64-bit unsigned integer. - * @return The 64-bit unsigned integer representation of the UUID. - */ - [[nodiscard]] auto toUint64() const -> uint64_t { - return ::boost::lexical_cast(uuid_); - } - - /** - * @brief Gets the DNS namespace UUID. - * @return The DNS namespace UUID. - */ - static auto namespaceDNS() -> UUID { - return UUID(::boost::uuids::ns::dns()); - } - - /** - * @brief Gets the URL namespace UUID. - * @return The URL namespace UUID. - */ - static auto namespaceURL() -> UUID { - return UUID(::boost::uuids::ns::url()); - } - - /** - * @brief Gets the OID namespace UUID. - * @return The OID namespace UUID. - */ - static auto namespaceOID() -> UUID { - return UUID(::boost::uuids::ns::oid()); - } - - /** - * @brief Generates a version 3 (MD5) UUID based on a namespace UUID and a - * name. - * @param namespace_uuid The namespace UUID. - * @param name The name. - * @return The generated UUID. - */ - static auto v3(const UUID& namespace_uuid, - const std::string& name) -> UUID { - return UUID(::boost::uuids::name_generator(namespace_uuid.uuid_)(name)); - } - - /** - * @brief Generates a version 5 (SHA-1) UUID based on a namespace UUID and a - * name. - * @param namespace_uuid The namespace UUID. - * @param name The name. - * @return The generated UUID. - */ - static auto v5(const UUID& namespace_uuid, - const std::string& name) -> UUID { - ::boost::uuids::name_generator_sha1 gen(namespace_uuid.uuid_); - return UUID(gen(name)); - } - - /** - * @brief Gets the version of the UUID. - * @return The version of the UUID. - */ - [[nodiscard]] auto version() const -> int { return uuid_.version(); } - - /** - * @brief Gets the variant of the UUID. - * @return The variant of the UUID. - */ - [[nodiscard]] auto variant() const -> int { return uuid_.variant(); } - - /** - * @brief Generates a version 1 (timestamp-based) UUID. - * @return The generated UUID. - */ - [[nodiscard]] static auto v1() -> UUID { - static ::boost::uuids::basic_random_generator gen; - return UUID(gen()); - } - - /** - * @brief Generates a version 4 (random) UUID. - * @return The generated UUID. - */ - [[nodiscard]] static auto v4() -> UUID { - return {}; // Default constructor already generates v4 UUID - } - - /** - * @brief Converts the UUID to a Base64 string representation. - * @return The Base64 string representation of the UUID. - */ - [[nodiscard]] auto toBase64() const -> std::string { - static const char* basE64Chars = - "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; - std::string result; - result.reserve(BASE64_RESERVE_SIZE); - - auto bytes = toBytes(); - for (size_t i = 0; i < bytes.size(); i += 3) { - uint32_t num = - (bytes[i] << SHIFT_16) | - (i + 1 < bytes.size() ? bytes[i + 1] << SHIFT_8 : 0) | - (i + 2 < bytes.size() ? bytes[i + 2] : 0); - result += basE64Chars[(num >> BASE64_SHIFT_18) & BASE64_MASK]; - result += basE64Chars[(num >> BASE64_SHIFT_12) & BASE64_MASK]; - result += basE64Chars[(num >> BASE64_SHIFT_6) & BASE64_MASK]; - result += basE64Chars[num & BASE64_MASK]; - } - result.resize(BASE64_RESERVE_SIZE); // Remove padding - return result; - } - - /** - * @brief Gets the timestamp from a version 1 UUID. - * @return The timestamp as a std::chrono::system_clock::time_point. - * @throws std::runtime_error if the UUID is not version 1. - */ - [[nodiscard]] auto getTimestamp() const - -> std::chrono::system_clock::time_point { - if (version() != 1) { - throw std::runtime_error( - "Timestamp is only available for version 1 UUIDs"); - } - uint64_t timestamp = ((uint64_t)uuid_.data[6] << SHIFT_40) | - ((uint64_t)uuid_.data[7] << SHIFT_32) | - ((uint64_t)uuid_.data[4] << SHIFT_24) | - ((uint64_t)uuid_.data[5] << SHIFT_16) | - ((uint64_t)uuid_.data[0] << SHIFT_8) | - (uint64_t)uuid_.data[1]; - return std::chrono::system_clock::from_time_t(static_cast( - timestamp / TIMESTAMP_DIVISOR - UUID_EPOCH / TIMESTAMP_DIVISOR)); - } - - /** - * @brief Hash function for UUIDs. - * @tparam H The hash function type. - * @param h The hash function. - * @param uuid The UUID to hash. - * @return The hash value. - */ - template - friend auto abslHashValue(H h, const UUID& uuid) -> H { - return H::combine(std::move(h), uuid.uuid_); - } - - /** - * @brief Gets the underlying Boost.UUID object. - * @return The Boost.UUID object. - */ - [[nodiscard]] auto getUUID() const -> const ::boost::uuids::uuid& { - return uuid_; - } -}; -} // namespace atom::extra::boost - -namespace std { -/** - * @brief Specialization of std::hash for UUID. - */ -template <> -struct hash { - /** - * @brief Hash function for UUIDs. - * @param uuid The UUID to hash. - * @return The hash value. - */ - auto operator()(const atom::extra::boost::UUID& uuid) const -> size_t { - return ::boost::hash<::boost::uuids::uuid>()(uuid.getUUID()); - } -}; -} // namespace std - -#endif diff --git a/src/atom/extra/inicpp/common.hpp b/src/atom/extra/inicpp/common.hpp deleted file mode 100644 index 0a14b8d5..00000000 --- a/src/atom/extra/inicpp/common.hpp +++ /dev/null @@ -1,101 +0,0 @@ -#ifndef ATOM_EXTRA_INICPP_COMMON_HPP -#define ATOM_EXTRA_INICPP_COMMON_HPP - -#include -#include -#include -#include -#include - -#include "atom/macro.hpp" - -namespace inicpp { - -/** - * @brief Returns a string view of whitespace characters. - * @return A string view containing whitespace characters. - */ -ATOM_CONSTEXPR auto whitespaces() -> std::string_view { return " \t\n\r\f\v"; } - -/** - * @brief Returns a string view of indent characters. - * @return A string view containing indent characters. - */ -ATOM_CONSTEXPR auto indents() -> std::string_view { return " \t"; } - -/** - * @brief Trims leading and trailing whitespace from a string. - * @param str The string to trim. - */ -ATOM_INLINE void trim(std::string &str) { - auto first = str.find_first_not_of(whitespaces()); - auto last = str.find_last_not_of(whitespaces()); - - if (first == std::string::npos || last == std::string::npos) { - str.clear(); - } else { - str = str.substr(first, last - first + 1); - } -} - -/** - * @brief Converts a string view to a long integer. - * @param value The string view to convert. - * @return An optional containing the converted long integer, or std::nullopt if - * conversion fails. - */ -ATOM_INLINE auto strToLong(std::string_view value) -> std::optional { - long result; - auto [ptr, ec] = - std::from_chars(value.data(), value.data() + value.size(), result); - if (ec == std::errc()) { - return result; - } - return std::nullopt; -} - -/** - * @brief Converts a string view to an unsigned long integer. - * @param value The string view to convert. - * @return An optional containing the converted unsigned long integer, or - * std::nullopt if conversion fails. - */ -ATOM_INLINE auto strToULong(std::string_view value) - -> std::optional { - unsigned long result; - auto [ptr, ec] = - std::from_chars(value.data(), value.data() + value.size(), result); - if (ec == std::errc()) { - return result; - } - return std::nullopt; -} - -/** - * @struct StringInsensitiveLess - * @brief A comparator for case-insensitive string comparison. - */ -struct StringInsensitiveLess { - /** - * @brief Compares two strings in a case-insensitive manner. - * @param lhs The left-hand side string view. - * @param rhs The right-hand side string view. - * @return True if lhs is less than rhs, false otherwise. - */ - auto operator()(std::string_view lhs, std::string_view rhs) const -> bool { - auto tolower = [](unsigned char ctx) { return std::tolower(ctx); }; - - auto lhsRange = std::ranges::subrange(lhs.begin(), lhs.end()); - auto rhsRange = std::ranges::subrange(rhs.begin(), rhs.end()); - - return std::ranges::lexicographical_compare( - lhsRange, rhsRange, - [tolower](unsigned char first, unsigned char second) { - return tolower(first) < tolower(second); - }); - } -}; - -} // namespace inicpp - -#endif // ATOM_EXTRA_INICPP_COMMON_HPP diff --git a/src/atom/extra/inicpp/convert.hpp b/src/atom/extra/inicpp/convert.hpp deleted file mode 100644 index acfe7154..00000000 --- a/src/atom/extra/inicpp/convert.hpp +++ /dev/null @@ -1,427 +0,0 @@ -#ifndef ATOM_EXTRA_INICPP_CONVERT_HPP -#define ATOM_EXTRA_INICPP_CONVERT_HPP - -#include -#include -#include "common.hpp" - -namespace inicpp { - -/** - * @brief Template structure for converting between types and strings. - * @tparam T The type to convert. - */ -template -struct Convert {}; - -/** - * @brief Specialization of Convert for bool type. - */ -template <> -struct Convert { - /** - * @brief Decodes a string view to a bool. - * @param value The string view to decode. - * @param result The resulting bool. - * @throws std::invalid_argument if the string is not "TRUE" or "FALSE". - */ - void decode(std::string_view value, bool &result) { - std::string str(value); - std::ranges::transform(str, str.begin(), [](char c) { - return static_cast(::toupper(c)); - }); - - if (str == "TRUE") - result = true; - else if (str == "FALSE") - result = false; - else - throw std::invalid_argument("field is not a bool"); - } - - /** - * @brief Encodes a bool to a string. - * @param value The bool to encode. - * @param result The resulting string. - */ - void encode(const bool value, std::string &result) { - result = value ? "true" : "false"; - } -}; - -/** - * @brief Specialization of Convert for char type. - */ -template <> -struct Convert { - /** - * @brief Decodes a string view to a char. - * @param value The string view to decode. - * @param result The resulting char. - * @throws std::invalid_argument if the string is empty. - */ - void decode(std::string_view value, char &result) { - if (value.empty()) - throw std::invalid_argument("field is empty"); - result = value.front(); - } - - /** - * @brief Encodes a char to a string. - * @param value The char to encode. - * @param result The resulting string. - */ - void encode(const char value, std::string &result) { result = value; } -}; - -/** - * @brief Specialization of Convert for unsigned char type. - */ -template <> -struct Convert { - /** - * @brief Decodes a string view to an unsigned char. - * @param value The string view to decode. - * @param result The resulting unsigned char. - * @throws std::invalid_argument if the string is empty. - */ - void decode(std::string_view value, unsigned char &result) { - if (value.empty()) - throw std::invalid_argument("field is empty"); - result = value.front(); - } - - /** - * @brief Encodes an unsigned char to a string. - * @param value The unsigned char to encode. - * @param result The resulting string. - */ - void encode(const unsigned char value, std::string &result) { - result = value; - } -}; - -/** - * @brief Specialization of Convert for short type. - */ -template <> -struct Convert { - /** - * @brief Decodes a string view to a short. - * @param value The string view to decode. - * @param result The resulting short. - * @throws std::invalid_argument if the string cannot be converted to a - * short. - */ - void decode(std::string_view value, short &result) { - if (auto tmp = strToLong(value); tmp.has_value()) - result = static_cast(tmp.value()); - else - throw std::invalid_argument("field is not a short"); - } - - /** - * @brief Encodes a short to a string. - * @param value The short to encode. - * @param result The resulting string. - */ - void encode(const short value, std::string &result) { - result = std::to_string(value); - } -}; - -/** - * @brief Specialization of Convert for unsigned short type. - */ -template <> -struct Convert { - /** - * @brief Decodes a string view to an unsigned short. - * @param value The string view to decode. - * @param result The resulting unsigned short. - * @throws std::invalid_argument if the string cannot be converted to an - * unsigned short. - */ - void decode(std::string_view value, unsigned short &result) { - if (auto tmp = strToULong(value); tmp.has_value()) - result = static_cast(tmp.value()); - else - throw std::invalid_argument("field is not an unsigned short"); - } - - /** - * @brief Encodes an unsigned short to a string. - * @param value The unsigned short to encode. - * @param result The resulting string. - */ - void encode(const unsigned short value, std::string &result) { - result = std::to_string(value); - } -}; - -/** - * @brief Specialization of Convert for int type. - */ -template <> -struct Convert { - /** - * @brief Decodes a string view to an int. - * @param value The string view to decode. - * @param result The resulting int. - * @throws std::invalid_argument if the string cannot be converted to an - * int. - */ - void decode(std::string_view value, int &result) { - if (auto tmp = strToLong(value); tmp.has_value()) - result = static_cast(tmp.value()); - else - throw std::invalid_argument("field is not an int"); - } - - /** - * @brief Encodes an int to a string. - * @param value The int to encode. - * @param result The resulting string. - */ - void encode(const int value, std::string &result) { - result = std::to_string(value); - } -}; - -/** - * @brief Specialization of Convert for unsigned int type. - */ -template <> -struct Convert { - /** - * @brief Decodes a string view to an unsigned int. - * @param value The string view to decode. - * @param result The resulting unsigned int. - * @throws std::invalid_argument if the string cannot be converted to an - * unsigned int. - */ - void decode(std::string_view value, unsigned int &result) { - if (auto tmp = strToULong(value); tmp.has_value()) - result = static_cast(tmp.value()); - else - throw std::invalid_argument("field is not an unsigned int"); - } - - /** - * @brief Encodes an unsigned int to a string. - * @param value The unsigned int to encode. - * @param result The resulting string. - */ - void encode(const unsigned int value, std::string &result) { - result = std::to_string(value); - } -}; - -/** - * @brief Specialization of Convert for long type. - */ -template <> -struct Convert { - /** - * @brief Decodes a string view to a long. - * @param value The string view to decode. - * @param result The resulting long. - * @throws std::invalid_argument if the string cannot be converted to a - * long. - */ - void decode(std::string_view value, long &result) { - if (auto tmp = strToLong(value); tmp.has_value()) - result = tmp.value(); - else - throw std::invalid_argument("field is not a long"); - } - - /** - * @brief Encodes a long to a string. - * @param value The long to encode. - * @param result The resulting string. - */ - void encode(const long value, std::string &result) { - result = std::to_string(value); - } -}; - -/** - * @brief Specialization of Convert for unsigned long type. - */ -template <> -struct Convert { - /** - * @brief Decodes a string view to an unsigned long. - * @param value The string view to decode. - * @param result The resulting unsigned long. - * @throws std::invalid_argument if the string cannot be converted to an - * unsigned long. - */ - void decode(std::string_view value, unsigned long &result) { - if (auto tmp = strToULong(value); tmp.has_value()) - result = tmp.value(); - else - throw std::invalid_argument("field is not an unsigned long"); - } - - /** - * @brief Encodes an unsigned long to a string. - * @param value The unsigned long to encode. - * @param result The resulting string. - */ - void encode(const unsigned long value, std::string &result) { - result = std::to_string(value); - } -}; - -/** - * @brief Specialization of Convert for double type. - */ -template <> -struct Convert { - /** - * @brief Decodes a string view to a double. - * @param value The string view to decode. - * @param result The resulting double. - */ - void decode(std::string_view value, double &result) { - result = std::stod(std::string(value)); - } - - /** - * @brief Encodes a double to a string. - * @param value The double to encode. - * @param result The resulting string. - */ - void encode(const double value, std::string &result) { - result = std::to_string(value); - } -}; - -/** - * @brief Specialization of Convert for float type. - */ -template <> -struct Convert { - /** - * @brief Decodes a string view to a float. - * @param value The string view to decode. - * @param result The resulting float. - */ - void decode(std::string_view value, float &result) { - result = std::stof(std::string(value)); - } - - /** - * @brief Encodes a float to a string. - * @param value The float to encode. - * @param result The resulting string. - */ - void encode(const float value, std::string &result) { - result = std::to_string(value); - } -}; - -/** - * @brief Specialization of Convert for std::string type. - */ -template <> -struct Convert { - /** - * @brief Decodes a string view to a std::string. - * @param value The string view to decode. - * @param result The resulting std::string. - */ - void decode(std::string_view value, std::string &result) { result = value; } - - /** - * @brief Encodes a std::string to a string. - * @param value The std::string to encode. - * @param result The resulting string. - */ - void encode(const std::string &value, std::string &result) { - result = value; - } -}; - -#ifdef __cpp_lib_string_view -/** - * @brief Specialization of Convert for std::string_view type. - */ -template <> -struct Convert { - /** - * @brief Decodes a string view to a std::string_view. - * @param value The string view to decode. - * @param result The resulting std::string_view. - */ - void decode(std::string_view value, std::string_view &result) { - result = value; - } - - /** - * @brief Encodes a std::string_view to a string. - * @param value The std::string_view to encode. - * @param result The resulting string. - */ - void encode(std::string_view value, std::string &result) { result = value; } -}; -#endif - -/** - * @brief Specialization of Convert for const char* type. - */ -template <> -struct Convert { - /** - * @brief Encodes a const char* to a string. - * @param value The const char* to encode. - * @param result The resulting string. - */ - void encode(const char *const &value, std::string &result) { - result = value; - } - - /** - * @brief Decodes a string view to a const char*. - * @param value The string view to decode. - * @param result The resulting const char*. - */ - void decode(std::string_view value, const char *&result) { - result = value.data(); - } -}; - -/** - * @brief Specialization of Convert for char arrays. - * @tparam N The size of the char array. - */ -template -struct Convert { - /** - * @brief Decodes a string to a char array. - * @param value The string to decode. - * @param result The resulting char array. - * @throws std::invalid_argument if the string is too large for the char - * array. - */ - void decode(const std::string &value, char (&result)[N]) { - if (value.size() >= N) - throw std::invalid_argument( - "field value is too large for the char array"); - std::copy(value.begin(), value.end(), result); - result[value.size()] = '\0'; // Null-terminate the char array - } - - /** - * @brief Encodes a char array to a string. - * @param value The char array to encode. - * @param result The resulting string. - */ - void encode(const char (&value)[N], std::string &result) { result = value; } -}; - -} // namespace inicpp - -#endif // ATOM_EXTRA_INICPP_CONVERT_HPP diff --git a/src/atom/extra/inicpp/field.hpp b/src/atom/extra/inicpp/field.hpp deleted file mode 100644 index f3148b70..00000000 --- a/src/atom/extra/inicpp/field.hpp +++ /dev/null @@ -1,41 +0,0 @@ -#ifndef ATOM_EXTRA_INICPP_INIFIELD_HPP -#define ATOM_EXTRA_INICPP_INIFIELD_HPP - -#include "convert.hpp" - -#include -#include - -namespace inicpp { - -class IniField { -private: - std::string value_; - -public: - IniField() = default; - explicit IniField(std::string value) : value_(std::move(value)) {} - IniField(const IniField &field) = default; - ~IniField() = default; - - template - T as() const { - Convert conv; - T result; - conv.decode(value_, result); - return result; - } - - template - IniField &operator=(const T &value) { - Convert conv; - conv.encode(value, value_); - return *this; - } - - IniField &operator=(const IniField &field) = default; -}; - -} // namespace inicpp - -#endif // ATOM_EXTRA_INICPP_INIFIELD_HPP diff --git a/src/atom/extra/inicpp/file.hpp b/src/atom/extra/inicpp/file.hpp deleted file mode 100644 index 43974d46..00000000 --- a/src/atom/extra/inicpp/file.hpp +++ /dev/null @@ -1,265 +0,0 @@ -#ifndef ATOM_EXTRA_INICPP_INIFILE_HPP -#define ATOM_EXTRA_INICPP_INIFILE_HPP - -#include -#include -#include -#include "section.hpp" - -#include "atom/error/exception.hpp" - -namespace inicpp { - -/** - * @class IniFileBase - * @brief A class for handling INI files with customizable comparison. - * @tparam Comparator The comparator type for section names. - */ -template -class IniFileBase - : public std::map, Comparator> { -private: - char fieldSep_ = '='; ///< The character used to separate fields. - char esc_ = '\\'; ///< The escape character. - std::vector commentPrefixes_ = { - "#", ";"}; ///< The prefixes for comments. - bool multiLineValues_ = false; ///< Flag to enable multi-line values. - bool overwriteDuplicateFields_ = - true; ///< Flag to allow overwriting duplicate fields. - - /** - * @brief Erases comments from a line. - * @param str The line to process. - * @param startpos The position to start searching for comments. - */ - void eraseComment(std::string &str, std::string::size_type startpos = 0) { - for (const auto &commentPrefix : commentPrefixes_) { - auto pos = str.find(commentPrefix, startpos); - if (pos != std::string::npos) { - // Check for escaped comment - if (pos > 0 && str[pos - 1] == esc_) { - str.erase(pos - 1, 1); - continue; - } - str.erase(pos); - } - } - } - - /** - * @brief Writes a string to an output stream with escaping. - * @param oss The output stream. - * @param str The string to write. - */ - void writeEscaped(std::ostream &oss, const std::string &str) const { - for (size_t i = 0; i < str.length(); ++i) { - auto prefixpos = std::ranges::find_if( - commentPrefixes_, [&](const std::string &prefix) { - return str.find(prefix, i) == i; - }); - - if (prefixpos != commentPrefixes_.end()) { - oss.put(esc_); - oss.write(prefixpos->c_str(), prefixpos->size()); - i += prefixpos->size() - 1; - } else if (multiLineValues_ && str[i] == '\n') { - oss.write("\n\t", 2); - } else { - oss.put(str[i]); - } - } - } - -public: - /** - * @brief Default constructor. - */ - IniFileBase() = default; - - /** - * @brief Constructs an IniFileBase from a file. - * @param filename The path to the INI file. - */ - explicit IniFileBase(const std::string &filename) { load(filename); } - - /** - * @brief Constructs an IniFileBase from an input stream. - * @param iss The input stream. - */ - explicit IniFileBase(std::istream &iss) { decode(iss); } - - /** - * @brief Destructor. - */ - ~IniFileBase() = default; - - /** - * @brief Sets the field separator character. - * @param sep The field separator character. - */ - void setFieldSep(char sep) { fieldSep_ = sep; } - - /** - * @brief Sets the comment prefixes. - * @param commentPrefixes The vector of comment prefixes. - */ - void setCommentPrefixes(const std::vector &commentPrefixes) { - commentPrefixes_ = commentPrefixes; - } - - /** - * @brief Sets the escape character. - * @param esc The escape character. - */ - void setEscapeChar(char esc) { esc_ = esc; } - - /** - * @brief Enables or disables multi-line values. - * @param enable True to enable multi-line values, false to disable. - */ - void setMultiLineValues(bool enable) { multiLineValues_ = enable; } - - /** - * @brief Allows or disallows overwriting duplicate fields. - * @param allowed True to allow overwriting, false to disallow. - */ - void allowOverwriteDuplicateFields(bool allowed) { - overwriteDuplicateFields_ = allowed; - } - - /** - * @brief Decodes an INI file from an input stream. - * @param iss The input stream. - */ - void decode(std::istream &iss) { - this->clear(); - std::string line; - IniSectionBase *currentSection = nullptr; - std::string multiLineValueFieldName; - - int lineNo = 0; - while (std::getline(iss, line)) { - ++lineNo; - eraseComment(line); - bool hasIndent = line.find_first_not_of(indents()) != 0; - trim(line); - - if (line.empty()) { - continue; - } - - if (line.front() == '[') { - // Section line - auto pos = line.find(']'); - if (pos == std::string::npos) { - THROW_LOGIC_ERROR("Section not closed at line " + - std::to_string(lineNo)); - } - if (pos == 1) { - THROW_LOGIC_ERROR("Empty section name at line " + - std::to_string(lineNo)); - } - - std::string secName = line.substr(1, pos - 1); - currentSection = &(*this)[secName]; - multiLineValueFieldName.clear(); - } else { - if (!currentSection) - THROW_LOGIC_ERROR("Field without section at line " + - std::to_string(lineNo)); - - auto pos = line.find(fieldSep_); - if (multiLineValues_ && hasIndent && - !multiLineValueFieldName.empty()) { - (*currentSection)[multiLineValueFieldName] = - (*currentSection)[multiLineValueFieldName] - .template as() + - "\n" + line; - } else if (pos == std::string::npos) { - THROW_LOGIC_ERROR("Field separator missing at line " + - std::to_string(lineNo)); - } else { - std::string name = line.substr(0, pos); - trim(name); - - if (!overwriteDuplicateFields_ && - currentSection->count(name)) { - THROW_LOGIC_ERROR("Duplicate field at line " + - std::to_string(lineNo)); - } - - std::string value = line.substr(pos + 1); - trim(value); - (*currentSection)[name] = value; - - multiLineValueFieldName = name; - } - } - } - } - - /** - * @brief Decodes an INI file from a string. - * @param content The string content of the INI file. - */ - void decode(const std::string &content) { - std::istringstream ss(content); - decode(ss); - } - - /** - * @brief Loads and decodes an INI file from a file path. - * @param fileName The path to the INI file. - */ - void load(const std::string &fileName) { - std::ifstream iss(fileName); - if (!iss.is_open()) { - THROW_FAIL_TO_OPEN_FILE("Unable to open file " + fileName); - } - decode(iss); - } - - /** - * @brief Encodes the INI file to an output stream. - * @param oss The output stream. - */ - void encode(std::ostream &oss) const { - for (const auto §ionPair : *this) { - oss << '[' << sectionPair.first << "]\n"; - for (const auto &fieldPair : sectionPair.second) { - oss << fieldPair.first << fieldSep_ - << fieldPair.second.template as() << "\n"; - } - } - } - - /** - * @brief Encodes the INI file to a string and returns it. - * @return The encoded INI file as a string. - */ - [[nodiscard]] auto encode() const -> std::string { - std::ostringstream sss; - encode(sss); - return sss.str(); - } - - /** - * @brief Saves the INI file to a given file path. - * @param fileName The path to the file. - */ - void save(const std::string &fileName) const { - std::ofstream oss(fileName); - if (!oss.is_open()) { - THROW_FAIL_TO_OPEN_FILE("Unable to open file " + fileName); - } - encode(oss); - } -}; - -using IniFile = IniFileBase>; ///< Case-sensitive INI file. -using IniFileCaseInsensitive = - IniFileBase; ///< Case-insensitive INI file. - -} // namespace inicpp - -#endif // ATOM_EXTRA_INICPP_INIFILE_HPP diff --git a/src/atom/extra/inicpp/inicpp.hpp b/src/atom/extra/inicpp/inicpp.hpp deleted file mode 100644 index 539ea15c..00000000 --- a/src/atom/extra/inicpp/inicpp.hpp +++ /dev/null @@ -1,10 +0,0 @@ -#ifndef ATOM_EXTRA_INICPP_HPP -#define ATOM_EXTRA_INICPP_HPP - -#include "common.hpp" -#include "convert.hpp" -#include "field.hpp" -#include "file.hpp" -#include "section.hpp" - -#endif // ATOM_EXTRA_INICPP_HPP diff --git a/src/atom/extra/inicpp/section.hpp b/src/atom/extra/inicpp/section.hpp deleted file mode 100644 index 4a548f45..00000000 --- a/src/atom/extra/inicpp/section.hpp +++ /dev/null @@ -1,23 +0,0 @@ -#ifndef ATOM_EXTRA_INICPP_INISECTION_HPP -#define ATOM_EXTRA_INICPP_INISECTION_HPP - -#include -#include - -#include "field.hpp" - -namespace inicpp { - -template -class IniSectionBase : public std::map { -public: - IniSectionBase() = default; - ~IniSectionBase() = default; -}; - -using IniSection = IniSectionBase>; -using IniSectionCaseInsensitive = IniSectionBase; - -} // namespace inicpp - -#endif // ATOM_EXTRA_INICPP_INISECTION_HPP diff --git a/src/atom/extra/injection/all.hpp b/src/atom/extra/injection/all.hpp deleted file mode 100644 index 83a941b3..00000000 --- a/src/atom/extra/injection/all.hpp +++ /dev/null @@ -1,7 +0,0 @@ -#pragma once - -#include "common.hpp" -#include "inject.hpp" -#include "resolver.hpp" -#include "binding.hpp" -#include "container.hpp" diff --git a/src/atom/extra/injection/binding.hpp b/src/atom/extra/injection/binding.hpp deleted file mode 100644 index 7069cacd..00000000 --- a/src/atom/extra/injection/binding.hpp +++ /dev/null @@ -1,146 +0,0 @@ -#pragma once - -#include "common.hpp" -#include "resolver.hpp" - -namespace atom::extra { - -/** - * @class BindingScope - * @brief A class template for managing the lifecycle of bindings. - * @tparam T The type of the binding. - * @tparam SymbolTypes The symbol types associated with the binding. - */ -template -class BindingScope { -public: - /** - * @brief Sets the binding to transient scope. - */ - void inTransientScope() { lifecycle_ = Lifecycle::Transient; } - - /** - * @brief Sets the binding to singleton scope. - */ - void inSingletonScope() { - lifecycle_ = Lifecycle::Singleton; - resolver_ = - std::make_shared>(resolver_); - } - - /** - * @brief Sets the binding to request scope. - */ - void inRequestScope() { lifecycle_ = Lifecycle::Request; } - -protected: - ResolverPtr - resolver_; ///< The resolver for the binding. - Lifecycle lifecycle_ = - Lifecycle::Transient; ///< The lifecycle of the binding. -}; - -/** - * @class BindingTo - * @brief A class template for binding to specific values or factories. - * @tparam T The type of the binding. - * @tparam SymbolTypes The symbol types associated with the binding. - */ -template -class BindingTo : public BindingScope { -public: - /** - * @brief Binds to a constant value. - * @param value The constant value to bind. - */ - void toConstantValue(T&& value) { - this->resolver_ = std::make_shared>( - std::forward(value)); - } - - /** - * @brief Binds to a dynamic value generated by a factory. - * @param factory The factory to generate the dynamic value. - * @return A reference to the BindingScope. - */ - BindingScope& toDynamicValue( - Factory&& factory) { - this->resolver_ = std::make_shared>( - std::move(factory)); - return *this; - } - - /** - * @brief Binds to another type. - * @tparam U The type to bind to. - * @return A reference to the BindingScope. - */ - template - BindingScope& to() { - this->resolver_ = - std::make_shared>(); - return *this; - } -}; - -/** - * @class Binding - * @brief A class template for managing bindings and resolving values. - * @tparam T The type of the binding. - * @tparam SymbolTypes The symbol types associated with the binding. - */ -template -class Binding : public BindingTo { -public: - /** - * @brief Resolves the value of the binding. - * @param context The context for resolving the value. - * @return The resolved value. - * @throws exceptions::ResolutionException if the resolver is not found. - */ - typename T::value resolve(const Context& context) { - if (!this->resolver_) { - throw exceptions::ResolutionException( - "atom::extra::Resolver not found. Malformed binding."); - } - return this->resolver_->resolve(context); - } - - /** - * @brief Adds a tag to the binding. - * @param tag The tag to add. - */ - void when(const Tag& tag) { tags_.push_back(tag); } - - /** - * @brief Sets the target name for the binding. - * @param name The target name. - */ - void whenTargetNamed(const std::string& name) { targetName_ = name; } - - /** - * @brief Checks if the binding matches a given tag. - * @param tag The tag to check. - * @return True if the binding matches the tag, false otherwise. - */ - bool matchesTag(const Tag& tag) const { - return std::find_if(tags_.begin(), tags_.end(), [&](const Tag& t) { - return t.name == tag.name; - }) != tags_.end(); - } - - /** - * @brief Checks if the binding matches a given target name. - * @param name The target name to check. - * @return True if the binding matches the target name, false otherwise. - */ - bool matchesTargetName(const std::string& name) const { - return targetName_ == name; - } - -private: - std::vector tags_; ///< The tags associated with the binding. - std::string targetName_; ///< The target name for the binding. -}; - -} // namespace atom::extra diff --git a/src/atom/extra/injection/common.hpp b/src/atom/extra/injection/common.hpp deleted file mode 100644 index 33e3d079..00000000 --- a/src/atom/extra/injection/common.hpp +++ /dev/null @@ -1,150 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -namespace atom::extra { - -// Forward declarations -template -class Container; - -template -struct Context { - Container& container; -}; - -// Concepts - -/** - * @brief Concept to check if a type is symbolic. - * @tparam T The type to check. - */ -template -concept Symbolic = requires { typename T::value; }; - -/** - * @brief Concept to check if a type is injectable. - * @tparam T The type to check. - */ -template -concept Injectable = requires { - { - T::template resolve(std::declval&>()) - } -> std::convertible_to>; -}; - -// Symbol - -/** - * @brief A struct representing a symbol for an interface. - * @tparam Interface The interface type. - */ -template -struct Symbol { - static_assert(!std::is_abstract_v, - "atom::extra::Container cannot bind/get abstract class value " - "(use a smart pointer instead)."); - using value = Interface; -}; - -// Factory - -/** - * @brief A type alias for a factory function. - * @tparam T The type to produce. - * @tparam SymbolTypes The symbol types associated with the factory. - */ -template -using Factory = std::function&)>; - -// Exceptions - -namespace exceptions { - -/** - * @brief Exception thrown when resolution fails. - */ -struct ResolutionException : public std::runtime_error { - using std::runtime_error::runtime_error; -}; - -} // namespace exceptions - -// Lifecycle - -/** - * @brief Enum representing the lifecycle of a binding. - */ -enum class Lifecycle { - Transient, ///< The binding is created anew each time. - Singleton, ///< The binding is created once and shared. - Request ///< The binding is created once per request. -}; - -// Tag - -/** - * @brief A struct representing a tag for a binding. - */ -struct Tag { - std::string name; ///< The name of the tag. - explicit Tag(std::string tag_name) : name(std::move(tag_name)) {} -}; - -// Named - -/** - * @brief A struct representing a named binding. - * @tparam T The type of the binding. - */ -template -struct Named { - std::string name; ///< The name of the binding. - using value = T; ///< The type of the binding. - explicit Named(std::string binding_name) : name(std::move(binding_name)) {} -}; - -// Multi - -/** - * @brief A struct representing a multi-binding. - * @tparam T The type of the binding. - */ -template -struct Multi { - using value = std::vector; ///< The type of the multi-binding. -}; - -// Lazy - -/** - * @brief A class representing a lazy binding. - * @tparam T The type of the binding. - */ -template -class Lazy { -public: - /** - * @brief Constructs a Lazy binding with a factory function. - * @param factory The factory function to produce the binding. - */ - explicit Lazy(std::function factory) : factory_(std::move(factory)) {} - - /** - * @brief Gets the value of the binding. - * @return The value of the binding. - */ - T get() const { return factory_(); } - -private: - std::function - factory_; ///< The factory function to produce the binding. -}; - -} // namespace atom::extra diff --git a/src/atom/extra/injection/container.hpp b/src/atom/extra/injection/container.hpp deleted file mode 100644 index 5c3b2967..00000000 --- a/src/atom/extra/injection/container.hpp +++ /dev/null @@ -1,134 +0,0 @@ -#pragma once - -#include -#include "binding.hpp" -#include "common.hpp" - -namespace atom::extra { - -/** - * @class Container - * @brief A dependency injection container for managing bindings and resolving - * dependencies. - * @tparam SymbolTypes The symbol types associated with the container. - */ -template -class Container { -public: - using BindingMap = - std::tuple...>; ///< The map of - ///< bindings. - - /** - * @brief Binds a symbol to a value or factory. - * @tparam T The symbol type to bind. - * @return A reference to the BindingTo object for further configuration. - */ - template - BindingTo& bind() { - static_assert((std::is_same_v || ...), - "atom::extra::Container symbol not registered"); - return std::get>(bindings_); - } - - /** - * @brief Resolves a value for a given symbol. - * @tparam T The symbol type to resolve. - * @return The resolved value. - */ - template - typename T::value get() { - return get(Tag("")); - } - - /** - * @brief Resolves a value for a given symbol and tag. - * @tparam T The symbol type to resolve. - * @param tag The tag to match. - * @return The resolved value. - * @throws exceptions::ResolutionException if no matching binding is found. - */ - template - typename T::value get(const Tag& tag) { - static_assert((std::is_same_v || ...), - "atom::extra::Container symbol not registered"); - auto& binding = std::get>(bindings_); - if (binding.matchesTag(tag)) { - return binding.resolve(context_); - } - throw exceptions::ResolutionException( - "No matching binding found for the given tag."); - } - - /** - * @brief Resolves a value for a given symbol and name. - * @tparam T The symbol type to resolve. - * @param name The name to match. - * @return The resolved value. - * @throws exceptions::ResolutionException if no matching binding is found. - */ - template - typename T::value getNamed(const std::string& name) { - static_assert((std::is_same_v || ...), - "atom::extra::Container symbol not registered"); - auto& binding = std::get>(bindings_); - if (binding.matchesTargetName(name)) { - return binding.resolve(context_); - } - throw exceptions::ResolutionException( - "No matching binding found for the given name."); - } - - /** - * @brief Resolves all values for a given symbol. - * @tparam T The symbol type to resolve. - * @return A vector of resolved values. - */ - template - std::vector getAll() { - static_assert((std::is_same_v || ...), - "atom::extra::Container symbol not registered"); - std::vector result; - auto& binding = std::get>(bindings_); - result.push_back(binding.resolve(context_)); - return result; - } - - /** - * @brief Checks if a binding exists for a given symbol. - * @tparam T The symbol type to check. - * @return True if a binding exists, false otherwise. - */ - template - bool hasBinding() const { - return std::get>(bindings_).resolver_ != - nullptr; - } - - /** - * @brief Unbinds a symbol, removing its binding. - * @tparam T The symbol type to unbind. - */ - template - void unbind() { - std::get>(bindings_).resolver_.reset(); - } - - /** - * @brief Creates a child container that inherits bindings from the parent. - * @return A unique pointer to the child container. - */ - std::unique_ptr createChildContainer() { - auto child = std::make_unique(); - child->parent_ = this; - return child; - } - -private: - BindingMap bindings_; ///< The map of bindings. - Context context_{ - *this}; ///< The context for resolving dependencies. - Container* parent_ = nullptr; ///< The parent container, if any. -}; - -} // namespace atom::extra diff --git a/src/atom/extra/injection/inject.hpp b/src/atom/extra/injection/inject.hpp deleted file mode 100644 index 13c6236b..00000000 --- a/src/atom/extra/injection/inject.hpp +++ /dev/null @@ -1,18 +0,0 @@ -#pragma once - -#include "common.hpp" - -namespace atom::extra { - -template -struct Inject { - template - static auto resolve(const Context& context) { - return std::make_tuple(context.container.template get()...); - } -}; - -template > -struct InjectableA : Inject {}; - -} // namespace atom::extra diff --git a/src/atom/extra/injection/resolver.hpp b/src/atom/extra/injection/resolver.hpp deleted file mode 100644 index 3478be53..00000000 --- a/src/atom/extra/injection/resolver.hpp +++ /dev/null @@ -1,212 +0,0 @@ -#pragma once - -#include -#include "common.hpp" -#include "inject.hpp" - -namespace atom::extra { - -/** - * @class Resolver - * @brief An abstract base class for resolving dependencies. - * @tparam T The type of the dependency. - * @tparam SymbolTypes The symbol types associated with the resolver. - */ -template -class Resolver { -public: - /** - * @brief Virtual destructor. - */ - virtual ~Resolver() = default; - - /** - * @brief Resolves the dependency. - * @param context The context for resolving the dependency. - * @return The resolved dependency. - */ - virtual T resolve(const Context& context) = 0; -}; - -/** - * @brief A type alias for a shared pointer to a Resolver. - * @tparam T The type of the dependency. - * @tparam SymbolTypes The symbol types associated with the resolver. - */ -template -using ResolverPtr = std::shared_ptr>; - -/** - * @class ConstantResolver - * @brief A resolver that returns a constant value. - * @tparam T The type of the dependency. - * @tparam SymbolTypes The symbol types associated with the resolver. - */ -template -class ConstantResolver : public Resolver { -public: - /** - * @brief Constructs a ConstantResolver with a constant value. - * @param value The constant value to return. - */ - explicit ConstantResolver(T value) : value_(std::move(value)) {} - - /** - * @brief Resolves the dependency by returning the constant value. - * @param context The context for resolving the dependency. - * @return The constant value. - */ - T resolve(const Context&) override { return value_; } - -private: - T value_; ///< The constant value. -}; - -/** - * @class DynamicResolver - * @brief A resolver that returns a dynamic value generated by a factory. - * @tparam T The type of the dependency. - * @tparam SymbolTypes The symbol types associated with the resolver. - */ -template -class DynamicResolver : public Resolver { -public: - /** - * @brief Constructs a DynamicResolver with a factory function. - * @param factory The factory function to generate the dynamic value. - */ - explicit DynamicResolver(Factory factory) - : factory_(std::move(factory)) {} - - /** - * @brief Resolves the dependency by calling the factory function. - * @param context The context for resolving the dependency. - * @return The dynamic value generated by the factory. - */ - T resolve(const Context& context) override { - return factory_(context); - } - -private: - Factory factory_; ///< The factory function. -}; - -/** - * @class AutoResolver - * @brief A resolver that automatically resolves dependencies for a type. - * @tparam T The type of the dependency. - * @tparam U The type to instantiate. - * @tparam SymbolTypes The symbol types associated with the resolver. - */ -template -class AutoResolver : public Resolver { -public: - /** - * @brief Resolves the dependency by automatically instantiating the type. - * @param context The context for resolving the dependency. - * @return The instantiated type. - */ - T resolve(const Context& context) override { - return std::make_from_tuple( - InjectableA::template resolve(context)); - } -}; - -/** - * @class AutoResolver, U, SymbolTypes...> - * @brief A resolver that automatically resolves dependencies for a unique - * pointer type. - * @tparam T The type of the dependency. - * @tparam U The type to instantiate. - * @tparam SymbolTypes The symbol types associated with the resolver. - */ -template -class AutoResolver, U, SymbolTypes...> - : public Resolver, SymbolTypes...> { -public: - /** - * @brief Resolves the dependency by automatically instantiating the type as - * a unique pointer. - * @param context The context for resolving the dependency. - * @return The instantiated type as a unique pointer. - */ - std::unique_ptr resolve( - const Context& context) override { - return std::apply( - [](auto&&... deps) { - return std::make_unique( - std::forward(deps)...); - }, - InjectableA::template resolve(context)); - } -}; - -/** - * @class AutoResolver, U, SymbolTypes...> - * @brief A resolver that automatically resolves dependencies for a shared - * pointer type. - * @tparam T The type of the dependency. - * @tparam U The type to instantiate. - * @tparam SymbolTypes The symbol types associated with the resolver. - */ -template -class AutoResolver, U, SymbolTypes...> - : public Resolver, SymbolTypes...> { -public: - /** - * @brief Resolves the dependency by automatically instantiating the type as - * a shared pointer. - * @param context The context for resolving the dependency. - * @return The instantiated type as a shared pointer. - */ - std::shared_ptr resolve( - const Context& context) override { - return std::apply( - [](auto&&... deps) { - return std::make_shared( - std::forward(deps)...); - }, - InjectableA::template resolve(context)); - } -}; - -/** - * @class CachedResolver - * @brief A resolver that caches the resolved value. - * @tparam T The type of the dependency. - * @tparam SymbolTypes The symbol types associated with the resolver. - */ -template -class CachedResolver : public Resolver { - static_assert( - std::is_copy_constructible_v, - "atom::extra::CachedResolver requires a copy constructor. Are " - "you caching a unique_ptr?"); - -public: - /** - * @brief Constructs a CachedResolver with a parent resolver. - * @param parent The parent resolver to cache the value from. - */ - explicit CachedResolver(ResolverPtr parent) - : parent_(std::move(parent)) {} - - /** - * @brief Resolves the dependency by returning the cached value or resolving - * it from the parent. - * @param context The context for resolving the dependency. - * @return The cached value or the resolved value from the parent. - */ - T resolve(const Context& context) override { - if (!cached_.has_value()) { - cached_ = parent_->resolve(context); - } - return cached_.value(); - } - -private: - std::optional cached_; ///< The cached value. - ResolverPtr parent_; ///< The parent resolver. -}; - -} // namespace atom::extra diff --git a/src/atom/function/CMakeLists.txt b/src/atom/function/CMakeLists.txt deleted file mode 100644 index f9d558dc..00000000 --- a/src/atom/function/CMakeLists.txt +++ /dev/null @@ -1,46 +0,0 @@ -# CMakeLists.txt for Atom-Function -# This project is licensed under the terms of the GPL3 license. -# -# Project Name: Atom-Function -# Description: a library for meta programming in C++ -# Author: Max Qian -# License: GPL3 - -cmake_minimum_required(VERSION 3.20) -project(atom-function C CXX) - -list(APPEND ${PROJECT_NAME}_SOURCES - global_ptr.cpp -) - -# Headers -list(APPEND ${PROJECT_NAME}_HEADERS - global_ptr.hpp -) - -list(APPEND ${PROJECT_NAME}_LIBS -) - -# Build Object Library -add_library(${PROJECT_NAME}_OBJECT OBJECT) -set_property(TARGET ${PROJECT_NAME}_OBJECT PROPERTY POSITION_INDEPENDENT_CODE 1) - -target_link_libraries(${PROJECT_NAME}_OBJECT ${${PROJECT_NAME}_LIBS}) - -target_sources(${PROJECT_NAME}_OBJECT - PUBLIC - ${${PROJECT_NAME}_HEADERS} - PRIVATE - ${${PROJECT_NAME}_SOURCES} -) - -target_link_libraries(${PROJECT_NAME}_OBJECT ${${PROJECT_NAME}_LIBS}) - -add_library(${PROJECT_NAME} STATIC) - -target_link_libraries(${PROJECT_NAME} ${PROJECT_NAME}_OBJECT ${${PROJECT_NAME}_LIBS}) -target_include_directories(${PROJECT_NAME} PUBLIC .) - -install(TARGETS ${PROJECT_NAME} - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} -) diff --git a/src/atom/function/abi.hpp b/src/atom/function/abi.hpp deleted file mode 100644 index a0a67dbf..00000000 --- a/src/atom/function/abi.hpp +++ /dev/null @@ -1,236 +0,0 @@ -/*! - * \file abi.hpp - * \brief A simple C++ ABI wrapper - * \author Max Qian - * \date 2024-5-25 - * \copyright Copyright (C) 2023-2024 Max Qian - */ - -#ifndef ATOM_META_ABI_HPP -#define ATOM_META_ABI_HPP - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#ifdef _MSC_VER -// clang-format off -#include -#include -#pragma comment(lib, "dbghelp.lib") -// clang-format on -#else -#include -#endif - -#if ENABLE_DEBUG -#include -#include -#endif - -namespace atom::meta { - -constexpr std::size_t BUFFER_SIZE = 1024; - -class DemangleHelper { -public: - template - static auto demangleType() -> std::string { - return demangleInternal(typeid(T).name()); - } - - template - static auto demangleType(const T& instance) -> std::string { - return demangleInternal(typeid(instance).name()); - } - - static auto demangle(std::string_view mangled_name, - const std::optional& location = - std::nullopt) -> std::string { - std::string demangled = demangleInternal(mangled_name); - - if (location) { - demangled += " ("; - demangled += location->file_name(); - demangled += ":"; - demangled += std::to_string(location->line()); - demangled += ")"; - } - - return demangled; - } - - static auto demangleMany( - const std::vector& mangled_names, - const std::optional& location = std::nullopt) - -> std::vector { - std::vector demangledNames; - demangledNames.reserve(mangled_names.size()); - - for (const auto& name : mangled_names) { - demangledNames.push_back(demangle(name, location)); - } - - return demangledNames; - } - -#if ENABLE_DEBUG - static auto visualize(const std::string& demangled_name) -> std::string { - return visualizeType(demangled_name); - } -#endif - -private: - static auto demangleInternal(std::string_view mangled_name) -> std::string { - static std::unordered_map cache; - if (auto it = cache.find(mangled_name); it != cache.end()) { - return it->second; - } - -#ifdef _MSC_VER - std::array buffer; - DWORD length = UnDecorateSymbolName(mangled_name.data(), buffer.data(), - buffer.size(), UNDNAME_COMPLETE); - - std::string demangled = (length > 0) - ? std::string(buffer.data(), length) - : std::string(mangled_name); -#else - int status = -1; - std::unique_ptr demangledName( - abi::__cxa_demangle(mangled_name.data(), nullptr, nullptr, &status), - std::free); - - std::string demangled = (status == 0) ? std::string(demangledName.get()) - : std::string(mangled_name); -#endif - - cache[mangled_name] = demangled; - return demangled; - } - -#if ENABLE_DEBUG - static auto visualizeType(const std::string& type_name, - int indent_level = 0) -> std::string { - std::string indent(static_cast(indent_level) * 4, - ' '); // 4 spaces per indent level - std::string result; - - // Regular expressions for parsing - std::regex templateRegex(R"((\w+)<(.*)>)"); - std::regex functionRegex(R"(\((.*)\)\s*->\s*(.*))"); - std::regex ptrRegex(R"((.+)\s*\*\s*)"); - std::regex refRegex(R"((.+)\s*&\s*)"); - std::regex constRegex(R"((const\s+)(.+))"); - std::regex arrayRegex(R"((.+)\s*\[(\d+)\])"); - std::smatch match; - - if (std::regex_match(type_name, match, templateRegex)) { - // Template type - result += indent + "`-- " + match[1].str() + " [template]\n"; - std::string params = match[2].str(); - result += visualizeTemplateParams(params, indent_level + 1); - } else if (std::regex_match(type_name, match, functionRegex)) { - // Function type - result += indent + "`-- function\n"; - std::string params = match[1].str(); - std::string returnType = match[2].str(); - result += visualizeFunctionParams(params, indent_level + 1); - result += indent + " `-- R: " + - visualizeType(returnType, indent_level + 1) - .substr(indent.size() + 4); - } else if (std::regex_match(type_name, match, ptrRegex)) { - // Pointer type - result += indent + "`-- ptr\n"; - result += visualizeType(match[1].str(), indent_level + 1); - } else if (std::regex_match(type_name, match, refRegex)) { - // Reference type - result += indent + "`-- ref\n"; - result += visualizeType(match[1].str(), indent_level + 1); - } else if (std::regex_match(type_name, match, constRegex)) { - // Const type - result += indent + "`-- const\n"; - result += visualizeType(match[2].str(), indent_level + 1); - } else if (std::regex_match(type_name, match, arrayRegex)) { - // Array type - result += indent + "`-- array [N = " + match[2].str() + "]\n"; - result += visualizeType(match[1].str(), indent_level + 1); - } else { - // Simple type - result += indent + "`-- " + type_name + "\n"; - } - - return result; - } - - std::string visualizeTemplateParams(const std::string& params, - int indent_level) { - std::string indent(static_cast(indent_level) * 4, ' '); - std::string result; - int paramIndex = 0; - - size_t start = 0; - int angleBrackets = 0; - - for (size_t i = 0; i < params.size(); ++i) { - if (params[i] == '<') { - ++angleBrackets; - } else if (params[i] == '>') { - --angleBrackets; - } else if (params[i] == ',' && angleBrackets == 0) { - result += indent + "├── " + std::to_string(paramIndex++) + ": "; - result += visualizeType(params.substr(start, i - start), - indent_level + 1); - start = i + 1; - } - } - - result += indent + "└── " + std::to_string(paramIndex) + ": "; - result += visualizeType(params.substr(start), indent_level + 1); - - return result; - } - - static auto visualizeFunctionParams(const std::string& params, - int indent_level) -> std::string { - std::string indent(static_cast(indent_level) * 4, ' '); - std::string result; - int paramIndex = 0; - - size_t start = 0; - size_t end = 0; - int angleBrackets = 0; - - for (size_t i = 0; i < params.size(); ++i) { - if (params[i] == '<') { - ++angleBrackets; - } else if (params[i] == '>') { - --angleBrackets; - } else if (params[i] == ',' && angleBrackets == 0) { - end = i; - result += indent + "|-- " + std::to_string(paramIndex++) + - ": " + - visualizeType(params.substr(start, end - start), - indent_level + 1) - .substr(indent.size() + 4); - start = i + 1; - } - } - - result += indent + "|-- " + std::to_string(paramIndex++) + ": " + - visualizeType(params.substr(start), indent_level + 1) - .substr(indent.size() + 4); - - return result; - } -#endif -}; -} // namespace atom::meta - -#endif // ATOM_META_ABI_HPP diff --git a/src/atom/function/any.hpp b/src/atom/function/any.hpp deleted file mode 100644 index 405368a2..00000000 --- a/src/atom/function/any.hpp +++ /dev/null @@ -1,581 +0,0 @@ -/*! - * \file any.hpp - * \brief Enhanced BoxedValue using C++20 features - * \author Max Qian - * \date 2023-12-28 - * \copyright Copyright (C) 2023-2024 Max Qian - */ - -#ifndef ATOM_META_ANY_HPP -#define ATOM_META_ANY_HPP - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "atom/error/exception.hpp" -#include "atom/function/proxy.hpp" -#include "atom/macro.hpp" -#include "type_info.hpp" - -namespace atom::meta { - -/*! - * \class BoxedValue - * \brief A class that encapsulates a value of any type with additional - * metadata. - */ -class BoxedValue { -public: - /*! - * \struct VoidType - * \brief A placeholder type representing void. - */ - struct VoidType {}; - -private: - /*! - * \struct Data - * \brief Internal data structure to hold the value and its metadata. - */ - struct ATOM_ALIGNAS(128) Data { - std::any mObj; ///< The encapsulated value. - TypeInfo mTypeInfo; ///< Type information of the value. - std::shared_ptr>> - mAttrs; ///< Attributes associated with the value. - bool mIsRef = false; ///< Indicates if the value is a reference. - bool mReturnValue = - false; ///< Indicates if the value is a return value. - bool mReadonly = false; ///< Indicates if the value is read-only. - const void* mConstDataPtr = nullptr; ///< Pointer to the constant data. - std::chrono::time_point - mCreationTime; ///< Creation time. - std::chrono::time_point - mModificationTime; ///< Modification time. - int mAccessCount = 0; ///< Access count. - - /*! - * \brief Constructor for non-void types. - * \tparam T The type of the value. - * \param obj The value to be encapsulated. - * \param is_ref Indicates if the value is a reference. - * \param return_value Indicates if the value is a return value. - * \param readonly Indicates if the value is read-only. - */ - template - requires(!std::is_same_v, VoidType>) - Data(T&& obj, bool is_ref, bool return_value, bool readonly) - : mObj(std::forward(obj)), - mTypeInfo(userType>()), - mAttrs(nullptr), - mIsRef(is_ref), - mReturnValue(return_value), - mReadonly(readonly), - mConstDataPtr( - std::is_const_v> ? &obj : nullptr), - mCreationTime(std::chrono::system_clock::now()), - mModificationTime(std::chrono::system_clock::now()) {} - - /*! - * \brief Constructor for void type. - * \tparam T The type of the value. - * \param obj The value to be encapsulated. - * \param is_ref Indicates if the value is a reference. - * \param return_value Indicates if the value is a return value. - * \param readonly Indicates if the value is read-only. - */ - template - requires(std::is_same_v, VoidType>) - Data([[maybe_unused]] T&& obj, bool is_ref, bool return_value, - bool readonly) - : mTypeInfo(userType>()), - mAttrs(nullptr), - mIsRef(is_ref), - mReturnValue(return_value), - mReadonly(readonly), - mCreationTime(std::chrono::system_clock::now()), - mModificationTime(std::chrono::system_clock::now()) {} - }; - - std::shared_ptr m_data_; ///< Shared pointer to the internal data. - mutable std::shared_mutex m_mutex_; ///< Mutex for thread-safe access. - -public: - /*! - * \brief Constructor for any type. - * \tparam T The type of the value. - * \param value The value to be encapsulated. - * \param return_value Indicates if the value is a return value. - * \param readonly Indicates if the value is read-only. - */ - // clang-tidy: disable=hicpp-explicit-constructor - template - requires(!std::same_as>) - BoxedValue(T&& value, bool return_value = false, bool readonly = false) - : m_data_(std::make_shared( - std::forward(value), - std::is_reference_v || - std::is_same_v< - std::decay_t, - std::reference_wrapper>>, - return_value, readonly)) { - if constexpr (std::is_same_v< - std::decay_t, - std::reference_wrapper>>) { - m_data_->mIsRef = true; - } - } - - /*! - * \brief Default constructor for VoidType. - */ - BoxedValue() - : m_data_(std::make_shared(VoidType{}, false, false, false)) {} - - /*! - * \brief Constructor with shared data pointer. - * \param data Shared pointer to the internal data. - */ - BoxedValue(std::shared_ptr data) : m_data_(std::move(data)) {} - - /*! - * \brief Copy constructor. - * \param other The other BoxedValue to copy from. - */ - BoxedValue(const BoxedValue& other) { - std::shared_lock lock(other.m_mutex_); - if (other.m_data_) { - m_data_ = std::make_shared(*other.m_data_); - } else { - m_data_ = nullptr; - } - } - - /*! - * \brief Move constructor. - * \param other The other BoxedValue to move from. - */ - BoxedValue(BoxedValue&& other) noexcept { - std::unique_lock lock(other.m_mutex_); - m_data_ = std::move(other.m_data_); - other.m_data_ = nullptr; - } - - /*! - * \brief Copy assignment operator. - * \param other The other BoxedValue to copy from. - * \return Reference to this BoxedValue. - */ - auto operator=(const BoxedValue& other) -> BoxedValue& { - if (this != &other) { - std::unique_lock lock(m_mutex_); - std::shared_lock otherLock(other.m_mutex_); - m_data_ = std::make_shared(*other.m_data_); - } - return *this; - } - - /*! - * \brief Move assignment operator. - * \param other The other BoxedValue to move from. - * \return Reference to this BoxedValue. - */ - auto operator=(BoxedValue&& other) noexcept -> BoxedValue& { - if (this != &other) { - std::unique_lock lock(m_mutex_); - std::unique_lock otherLock(other.m_mutex_); - m_data_ = std::move(other.m_data_); - } - return *this; - } - - /*! - * \brief Assignment operator for any type. - * \tparam T The type of the value. - * \param value The value to be assigned. - * \return Reference to this BoxedValue. - */ - template - requires(!std::same_as>) - auto operator=(T&& value) -> BoxedValue& { - std::unique_lock lock(m_mutex_); - m_data_->mObj = std::forward(value); - m_data_->mTypeInfo = userType(); - m_data_->mModificationTime = std::chrono::system_clock::now(); - return *this; - } - - /*! - * \brief Assignment operator for constant values. - * \tparam T The type of the value. - * \param value The constant value to be assigned. - * \return Reference to this BoxedValue. - */ - template - auto operator=(const T& value) -> BoxedValue& { - std::unique_lock lock(m_mutex_); - m_data_->mObj = value; - m_data_->mTypeInfo = userType(); - m_data_->mReadonly = true; - m_data_->mModificationTime = std::chrono::system_clock::now(); - return *this; - } - - /*! - * \brief Constructor for constant values. - * \tparam T The type of the value. - * \param value The constant value to be encapsulated. - */ - template - BoxedValue(const T& value) - : m_data_(std::make_shared(value, false, false, true)) {} - - /*! - * \brief Swap function. - * \param rhs The other BoxedValue to swap with. - */ - void swap(BoxedValue& rhs) noexcept { - if (this != &rhs) { - std::scoped_lock lock(m_mutex_, rhs.m_mutex_); - std::swap(m_data_, rhs.m_data_); - } - } - - template - auto isType() const -> bool { - std::shared_lock lock(m_mutex_); - return m_data_->mTypeInfo == userType(); - } - - /*! - * \brief Check if the value is undefined. - * \return True if the value is undefined, false otherwise. - */ - [[nodiscard]] auto isUndef() const noexcept -> bool { - std::shared_lock lock(m_mutex_); - return !m_data_ || m_data_->mObj.type() == typeid(VoidType) || - !m_data_->mObj.has_value(); - } - - /*! - * \brief Check if the value is constant. - * \return True if the value is constant, false otherwise. - */ - [[nodiscard]] auto isConst() const noexcept -> bool { - std::shared_lock lock(m_mutex_); - return m_data_->mTypeInfo.isConst(); - } - - /*! - * \brief Check if the value is of a specific type. - * \param type_info The type information to check against. - * \return True if the value is of the specified type, false otherwise. - */ - [[nodiscard]] auto isType(const TypeInfo& type_info) const noexcept - -> bool { - std::shared_lock lock(m_mutex_); - return m_data_->mTypeInfo == type_info; - } - - /*! - * \brief Check if the value is a reference. - * \return True if the value is a reference, false otherwise. - */ - [[nodiscard]] auto isRef() const noexcept -> bool { - std::shared_lock lock(m_mutex_); - return m_data_->mIsRef; - } - - /*! - * \brief Check if the value is a return value. - * \return True if the value is a return value, false otherwise. - */ - [[nodiscard]] auto isReturnValue() const noexcept -> bool { - std::shared_lock lock(m_mutex_); - return m_data_->mReturnValue; - } - - /*! - * \brief Reset the return value flag. - */ - void resetReturnValue() noexcept { - std::unique_lock lock(m_mutex_); - m_data_->mReturnValue = false; - } - - /*! - * \brief Check if the value is read-only. - * \return True if the value is read-only, false otherwise. - */ - [[nodiscard]] auto isReadonly() const noexcept -> bool { - std::shared_lock lock(m_mutex_); - return m_data_->mReadonly; - } - - /*! - * \brief Check if the value is a constant data pointer. - * \return True if the value is a constant data pointer, false - * otherwise. - */ - [[nodiscard]] auto isConstDataPtr() const noexcept -> bool { - std::shared_lock lock(m_mutex_); - return m_data_->mConstDataPtr != nullptr; - } - - /*! - * \brief Get the encapsulated value. - * \return The encapsulated value. - */ - [[nodiscard]] auto get() const noexcept -> const std::any& { - std::shared_lock lock(m_mutex_); - m_data_->mAccessCount++; - return m_data_->mObj; - } - - /*! - * \brief Get the type information of the value. - * \return The type information of the value. - */ - [[nodiscard]] auto getTypeInfo() const noexcept -> const TypeInfo& { - std::shared_lock lock(m_mutex_); - return m_data_->mTypeInfo; - } - - /*! - * \brief Set an attribute. - * \param name The name of the attribute. - * \param value The value of the attribute. - * \return Reference to this BoxedValue. - */ - auto setAttr(const std::string& name, - const BoxedValue& value) -> BoxedValue& { - std::unique_lock lock(m_mutex_); - if (!m_data_->mAttrs) { - m_data_->mAttrs = std::make_shared< - std::unordered_map>>(); - } - (*m_data_->mAttrs)[name] = value.m_data_; - m_data_->mModificationTime = std::chrono::system_clock::now(); - return *this; - } - - /*! - * \brief Get an attribute. - * \param name The name of the attribute. - * \return The value of the attribute. - */ - [[nodiscard]] auto getAttr(const std::string& name) const -> BoxedValue { - std::shared_lock lock(m_mutex_); - if (m_data_->mAttrs) { - if (auto iter = m_data_->mAttrs->find(name); - iter != m_data_->mAttrs->end()) { - return BoxedValue(iter->second); - } - } - return {}; // Undefined BoxedValue - } - - /*! - * \brief List all attributes. - * \return A vector of attribute names. - */ - [[nodiscard]] auto listAttrs() const -> std::vector { - std::shared_lock lock(m_mutex_); - std::vector attrs; - if (m_data_->mAttrs) { - attrs.reserve(m_data_->mAttrs->size()); - for (const auto& entry : *m_data_->mAttrs) { - attrs.push_back(entry.first); - } - } - return attrs; - } - - /*! - * \brief Check if an attribute exists. - * \param name The name of the attribute. - * \return True if the attribute exists, false otherwise. - */ - [[nodiscard]] auto hasAttr(const std::string& name) const -> bool { - std::shared_lock lock(m_mutex_); - return m_data_->mAttrs && - m_data_->mAttrs->find(name) != m_data_->mAttrs->end(); - } - - /*! - * \brief Remove an attribute. - * \param name The name of the attribute. - */ - void removeAttr(const std::string& name) { - std::unique_lock lock(m_mutex_); - if (m_data_->mAttrs) { - m_data_->mAttrs->erase(name); - m_data_->mModificationTime = std::chrono::system_clock::now(); - } - } - - /*! - * \brief Check if the BoxedValue is null (i.e., contains an unset - * value). \return True if the BoxedValue is null, false otherwise. - */ - [[nodiscard]] auto isNull() const noexcept -> bool { - std::shared_lock lock(m_mutex_); - return !m_data_->mObj.has_value(); - } - - /*! - * \brief Get the pointer to the contained data. - * \return Pointer to the contained data. - */ - [[nodiscard]] auto getPtr() const noexcept -> void* { - std::shared_lock lock(m_mutex_); - return const_cast(m_data_->mConstDataPtr); - } - - /*! - * \brief Try to cast the internal value to a specified type. - * \tparam T The type to cast to. - * \return An optional containing the casted value if successful, - * std::nullopt otherwise. - */ - template - [[nodiscard]] auto tryCast() const noexcept -> std::optional { - std::shared_lock lock(m_mutex_); - try { - if constexpr (std::is_reference_v) { - if (m_data_->mObj.type() == - typeid( - std::reference_wrapper>)) { - return std::any_cast>>(m_data_->mObj) - .get(); - } - } - if (m_data_->mObj.type() == typeid(std::reference_wrapper)) { - return std::any_cast>(m_data_->mObj) - .get(); - } - if (isConst() || isReadonly()) { - using constT = std::add_const_t; - return std::any_cast(m_data_->mObj); - } - return std::any_cast(m_data_->mObj); - } catch (const std::bad_any_cast&) { - return std::nullopt; - } - } - - /*! - * \brief Check if the internal value can be cast to a specified type. - * \tparam T The type to check. - * \return True if the value can be cast to the specified type, false - * otherwise. - */ - template - [[nodiscard]] auto canCast() const noexcept -> bool { - std::shared_lock lock(m_mutex_); - try { - if constexpr (std::is_reference_v) { - return m_data_->mObj.type() == - typeid( - std::reference_wrapper>); - } else { - std::any_cast(m_data_->mObj); - return true; - } - } catch (const std::bad_any_cast&) { - return false; - } - } - - /*! - * \brief Get a debug string representation of the BoxedValue. - * \return A string representing the BoxedValue. - */ - [[nodiscard]] auto debugString() const -> std::string { - std::ostringstream oss; - oss << "BoxedValue<" << m_data_->mTypeInfo.name() << ">: "; - std::shared_lock lock(m_mutex_); - if (auto* intPtr = std::any_cast(&m_data_->mObj)) { - oss << *intPtr; - } else if (auto* doublePtr = std::any_cast(&m_data_->mObj)) { - oss << *doublePtr; - } else if (auto* strPtr = std::any_cast(&m_data_->mObj)) { - oss << *strPtr; - } else { - oss << "unknown type"; - } - return oss.str(); - } - - /*! - * \brief Destructor. - */ - ~BoxedValue() = default; -}; - -/*! - * \brief Helper function to create a BoxedValue instance. - * \tparam T The type of the value. - * \param value The value to be encapsulated. - * \return A BoxedValue instance. - */ -template -auto var(T&& value) -> BoxedValue { - using DecayedType = std::decay_t; - constexpr bool IS_REF_WRAPPER = - std::is_same_v>>; - return BoxedValue(std::forward(value), IS_REF_WRAPPER, false); -} - -/*! - * \brief Helper function to create a constant BoxedValue instance. - * \tparam T The type of the value. - * \param value The constant value to be encapsulated. - * \return A BoxedValue instance. - */ -template -auto constVar(const T& value) -> BoxedValue { - using DecayedType = std::decay_t; - constexpr bool IS_REF_WRAPPER = - std::is_same_v>>; - return BoxedValue(std::cref(value), IS_REF_WRAPPER, true); -} - -inline auto voidVar() -> BoxedValue { return {}; } - -/*! - * \brief Helper function to create a BoxedValue instance with additional - * options. \tparam T The type of the value. \param value The value to be - * encapsulated. \param is_return_value Indicates if the value is a return - * value. \param readonly Indicates if the value is read-only. \return A - * BoxedValue instance. - */ -template -auto makeBoxedValue(T&& value, bool is_return_value = false, - bool readonly = false) -> BoxedValue { - if constexpr (std::is_reference_v) { - return BoxedValue(std::ref(value), is_return_value, readonly); - } else { - return BoxedValue(std::forward(value), is_return_value, readonly); - } -} - -} // namespace atom::meta - -#endif // ATOM_META_ANY_HPP diff --git a/src/atom/function/anymeta.hpp b/src/atom/function/anymeta.hpp deleted file mode 100644 index e0d3e5c7..00000000 --- a/src/atom/function/anymeta.hpp +++ /dev/null @@ -1,318 +0,0 @@ -/*! - * \file anymeta.hpp - * \brief Enhanced Type Metadata with Dynamic Reflection, Method Overloads, and - * Event System \author Max Qian \date 2023-12-28 \copyright - * Copyright (C) 2023-2024 Max Qian - */ - -#ifndef ATOM_META_ANYMETA_HPP -#define ATOM_META_ANYMETA_HPP - -#include "any.hpp" -#include "type_info.hpp" - -#include -#include -#include -#include -#include -#include -#include - -#include "atom/error/exception.hpp" - -#include "atom/macro.hpp" - -namespace atom::meta { -class TypeMetadata { -public: - using MethodFunction = std::function)>; - using GetterFunction = std::function; - using SetterFunction = std::function; - using ConstructorFunction = - std::function)>; - using EventCallback = - std::function&)>; - - struct ATOM_ALIGNAS(64) Property { - GetterFunction getter; - SetterFunction setter; - BoxedValue default_value; - std::string description; - }; - - struct ATOM_ALIGNAS(32) Event { - std::vector> - listeners; // Pair of priority and callback - std::string description; - }; - -private: - std::unordered_map> - m_methods_; // Supports overloaded methods - std::unordered_map m_properties_; - std::unordered_map> - m_constructors_; - std::unordered_map m_events_; - -public: - // Add overloaded method to type metadata - void addMethod(const std::string& name, MethodFunction method) { - m_methods_[name].push_back(std::move(method)); - } - - // Remove method by name - void removeMethod(const std::string& name) { m_methods_.erase(name); } - - // Add property (getter and setter) to type metadata - void addProperty(const std::string& name, GetterFunction getter, - SetterFunction setter, BoxedValue default_value = {}, - const std::string& description = "") { - m_properties_[name] = {std::move(getter), std::move(setter), - std::move(default_value), description}; - } - - // Remove property by name - void removeProperty(const std::string& name) { m_properties_.erase(name); } - - // Add constructor to type metadata with an associated type name - void addConstructor(const std::string& type_name, - ConstructorFunction constructor) { - m_constructors_[type_name].push_back(std::move(constructor)); - } - - // Add event to type metadata - void addEvent(const std::string& event_name, - const std::string& description = "") { - m_events_[event_name].description = - description; // Creates an empty event with description - } - - // Remove event by name - void removeEvent(const std::string& event_name) { - m_events_.erase(event_name); - } - - // Add event listener to a specific event with priority - void addEventListener(const std::string& event_name, EventCallback callback, - int priority = 0) { - m_events_[event_name].listeners.emplace_back(priority, - std::move(callback)); - std::sort(m_events_[event_name].listeners.begin(), - m_events_[event_name].listeners.end(), - [](const auto& a, const auto& b) { - return a.first > b.first; // Higher priority first - }); - } - - // Fire event and notify listeners - void fireEvent(BoxedValue& obj, const std::string& event_name, - const std::vector& args) const { - if (auto eventIter = m_events_.find(event_name); - eventIter != m_events_.end()) { - for (const auto& [priority, listener] : - eventIter->second.listeners) { - listener(obj, args); - } - } else { - std::cerr << "Event " << event_name << " not found." << std::endl; - } - } - - // Retrieve all overloaded methods by name - [[nodiscard]] auto getMethods(const std::string& name) const - -> std::optional*> { - if (auto methodIter = m_methods_.find(name); - methodIter != m_methods_.end()) { - return &methodIter->second; - } - return std::nullopt; - } - - // Retrieve property by name - [[nodiscard]] auto getProperty(const std::string& name) const - -> std::optional { - if (auto propertyIter = m_properties_.find(name); - propertyIter != m_properties_.end()) { - return propertyIter->second; - } - return std::nullopt; - } - - // Retrieve constructor by index (defaults to the first constructor) - [[nodiscard]] auto getConstructor(const std::string& type_name, - size_t index = 0) const - -> std::optional { - if (auto constructorIter = m_constructors_.find(type_name); - constructorIter != m_constructors_.end()) { - if (index < constructorIter->second.size()) { - return constructorIter->second[index]; - } - } - return std::nullopt; - } - - // Retrieve event by name - [[nodiscard]] auto getEvent(const std::string& name) const - -> std::optional { - if (auto eventIter = m_events_.find(name); - eventIter != m_events_.end()) { - return &eventIter->second; - } - return std::nullopt; - } -}; - -class TypeRegistry { -private: - std::unordered_map m_registry_; - mutable std::shared_mutex m_mutex_; - -public: - // Singleton pattern to retrieve the global type registry - static auto instance() -> TypeRegistry& { - static TypeRegistry registry; - return registry; - } - - // Register a type and its metadata - void registerType(const std::string& name, TypeMetadata metadata) { - std::unique_lock lock(m_mutex_); - m_registry_[name] = std::move(metadata); - } - - // Retrieve metadata for a registered type - [[nodiscard]] auto getMetadata(const std::string& name) const - -> std::optional { - std::shared_lock lock(m_mutex_); - if (auto registryIter = m_registry_.find(name); - registryIter != m_registry_.end()) { - return registryIter->second; - } - return std::nullopt; - } -}; - -// Helper function to dynamically call overloaded methods on BoxedValue objects -inline auto callMethod(BoxedValue& obj, const std::string& method_name, - std::vector args) -> BoxedValue { - if (auto metadata = - TypeRegistry::instance().getMetadata(obj.getTypeInfo().name()); - metadata) { - if (auto methods = metadata->getMethods(method_name); methods) { - for (const auto& method : **methods) { - // TODO: FIX ME - 参数类型匹配逻辑: - // 确保传入的参数与方法期望的参数类型一致 - /* - auto argTypesMatch = true; - for (size_t i = 0; i < args.size(); ++i) { - if (args[i].getTypeInfo() != method.argument_type(i)) { - argTypesMatch = false; - break; - } - } - */ - // if (argTypesMatch) { - return method(args); - //} - } - } - } - THROW_NOT_FOUND("Method not found or no matching overload found"); -} - -// Helper function to dynamically get properties from BoxedValue objects -inline auto getProperty(const BoxedValue& obj, - const std::string& property_name) -> BoxedValue { - if (auto metadata = - TypeRegistry::instance().getMetadata(obj.getTypeInfo().name()); - metadata) { - if (auto property = metadata->getProperty(property_name); property) { - return (*property).getter( - obj); // 修复后的代码,正确调用 getter 函数 - } - } - THROW_NOT_FOUND("Property not found"); -} - -// Helper function to dynamically set properties on BoxedValue objects -inline void setProperty(BoxedValue& obj, const std::string& property_name, - const BoxedValue& value) { - if (auto metadata = - TypeRegistry::instance().getMetadata(obj.getTypeInfo().name()); - metadata) { - if (auto property = metadata->getProperty(property_name); property) { - property->setter(obj, value); - return; - } - } - THROW_NOT_FOUND("Property not found"); -} - -// Helper function to fire events on BoxedValue objects -inline void fireEvent(BoxedValue& obj, const std::string& event_name, - const std::vector& args) { - if (auto metadata = - TypeRegistry::instance().getMetadata(obj.getTypeInfo().name()); - metadata) { - metadata->fireEvent(obj, event_name, args); - } else { - std::cerr << "Event not found." << std::endl; - } -} - -// Factory function to dynamically construct an object by type name -inline auto createInstance(const std::string& type_name, - std::vector args) -> BoxedValue { - if (auto metadata = TypeRegistry::instance().getMetadata(type_name); - metadata) { - if (auto constructor = metadata->getConstructor(type_name); - constructor) { - return (*constructor)(std::move(args)); - } - } - THROW_NOT_FOUND("Constructor not found"); -} - -// Reflective registration of types, methods, properties, and events leveraging -// C++20 features -template -class TypeRegistrar { -public: - // Register a type with metadata - static void registerType(const std::string& type_name) { - TypeMetadata metadata; - - // Register default constructor - metadata.addConstructor( - type_name, [](std::vector args) -> BoxedValue { - if (args.empty()) { - return BoxedValue(T{}); // Default constructor - } - return BoxedValue{}; // Placeholder for more complex - // constructors - }); - - // Register events - metadata.addEvent("onCreate", "Triggered when an object is created"); - metadata.addEvent("onDestroy", "Triggered when an object is destroyed"); - - // Add methods, properties, events dynamically as needed - metadata.addMethod( - "print", [](std::vector args) -> BoxedValue { - if (!args.empty()) { - std::cout << "Method print called with value: " - << args[0].debugString() << std::endl; - return BoxedValue{}; - } - return BoxedValue{}; - }); - - // Register type in the global registry - TypeRegistry::instance().registerType(type_name, std::move(metadata)); - } -}; - -} // namespace atom::meta - -#endif // ATOM_META_ANYMETA_HPP diff --git a/src/atom/function/bind_first.hpp b/src/atom/function/bind_first.hpp deleted file mode 100644 index c01dcb16..00000000 --- a/src/atom/function/bind_first.hpp +++ /dev/null @@ -1,153 +0,0 @@ -/*! - * \file bind_first.hpp - * \brief An easy way to bind a function to an object - * \author Max Qian - * \date 2024-03-01 - * \copyright Copyright (C) 2023-2024 Max Qian - */ - -#ifndef ATOM_META_BIND_FIRST_HPP -#define ATOM_META_BIND_FIRST_HPP - -#include -#include -#include -#include - -namespace atom::meta { -template -constexpr auto getPointer(T *ptr) noexcept -> T * { - return ptr; -} - -template -auto getPointer(const std::reference_wrapper &ref) noexcept -> T * { - return &ref.get(); -} - -template -constexpr auto getPointer(const T &ref) noexcept -> const T * { - return &ref; -} - -template -constexpr auto removeConstPointer(const T *ptr) noexcept -> T * { - return const_cast(ptr); -} - -template -concept invocable = std::is_invocable_v; - -template -concept nothrow_invocable = std::is_nothrow_invocable_v; - -template -constexpr bool IS_INVOCABLE_V = invocable; - -template -constexpr bool IS_NOTHROW_INVOCABLE_V = std::is_nothrow_invocable_v; - -template -constexpr auto bindFirst(Ret (*func)(P1, Param...), O &&object) - requires invocable -{ - return [func, object = std::forward(object)](Param... param) -> Ret { - return func(object, std::forward(param)...); - }; -} - -template -constexpr auto bindFirst(Ret (Class::*func)(Param...), O &&object) - requires invocable -{ - return [func, object = std::forward(object)](Param... param) -> Ret { - return (removeConstPointer(getPointer(object))->*func)( - std::forward(param)...); - }; -} - -template -constexpr auto bindFirst(Ret (Class::*func)(Param...) const, O &&object) - requires invocable -{ - return [func, object = std::forward(object)](Param... param) -> Ret { - return (getPointer(object)->*func)(std::forward(param)...); - }; -} - -template -auto bindFirst(const std::function &func, O &&object) - requires invocable, O, Param...> -{ - return [func, object = std::forward(object)](Param... param) -> Ret { - return func(object, std::forward(param)...); - }; -} - -template -constexpr auto bindFirst(const F &funcObj, O &&object, - Ret (Class::*func)(P1, Param...) const) - requires invocable -{ - return [funcObj, object = std::forward(object), - func](Param... param) -> Ret { - return (funcObj.*func)(object, std::forward(param)...); - }; -} - -template -constexpr auto bindFirst(const F &func, O &&object) - requires invocable -{ - return bindFirst(func, std::forward(object), &F::operator()); -} - -template -constexpr auto bindFirst(F &&func, O &&object) - requires std::invocable -{ - return [func = std::forward(func), object = std::forward(object)]( - auto &&...param) -> decltype(auto) { - return std::invoke(func, object, - std::forward(param)...); - }; -} - -template -constexpr auto bindMember(T Class::*member, O &&object) noexcept { - return [member, object = std::forward(object)]() -> T & { - return removeConstPointer(getPointer(object))->*member; - }; -} - -template -constexpr auto bindStatic(Ret (*func)(Param...)) noexcept { - return [func](Param... param) -> Ret { - return func(std::forward(param)...); - }; -} - -template -auto asyncBindFirst(F &&func, Args &&...args) { - return std::async(std::launch::async, std::forward(func), - std::forward(args)...); -} - -template -constexpr auto bindFirstWithExceptionHandling(Ret (*func)(P1, Param...), - O &&object) - requires invocable -{ - return [func, object = std::forward(object)](Param... param) -> Ret { - try { - return func(object, std::forward(param)...); - } catch (const std::exception &e) { - throw; - } - }; -} - -} // namespace atom::meta - -#endif // ATOM_META_BIND_FIRST_HPP diff --git a/src/atom/function/concept.hpp b/src/atom/function/concept.hpp deleted file mode 100644 index feda85cc..00000000 --- a/src/atom/function/concept.hpp +++ /dev/null @@ -1,361 +0,0 @@ -/*! - * \file concept.hpp - * \brief C++ Concepts - * \author Max Qian - * \date 2024-03-01 - * \copyright Copyright (C) 2023-2024 Max Qian - */ - -#ifndef ATOM_META_CONCEPT_HPP -#define ATOM_META_CONCEPT_HPP - -#include -#include -#include -#include -#include -#include -#include -#include - -#if __cplusplus < 202002L -#error "C++20 is required for this library" -#endif - -// ----------------------------------------------------------------------------- -// Function Concepts -// ----------------------------------------------------------------------------- - -template -concept Invocable = requires(F func, Args&&... args) { - { std::invoke(func, std::forward(args)...) }; -}; - -template -concept InvocableR = requires(F func, Args&&... args) { - { - std::invoke(func, std::forward(args)...) - } -> std::convertible_to; -}; - -template -concept NothrowInvocable = requires(F func, Args&&... args) { - { std::invoke(func, std::forward(args)...) } noexcept; -}; - -template -concept NothrowInvocableR = requires(F func, Args&&... args) { - { - std::invoke(func, std::forward(args)...) - } noexcept -> std::convertible_to; -}; - -template -concept FunctionPointer = std::is_function_v>; - -template -concept Callable = requires(T obj) { - { std::function{std::declval()} }; -}; - -template -concept CallableReturns = std::is_invocable_r_v; - -template -concept CallableNoexcept = requires(T obj, Args&&... args) { - { obj(std::forward(args)...) } noexcept; -}; - -template -concept StdFunction = requires { - typename T::result_type; - requires std::is_same_v< - T, std::function>; -}; - -// ----------------------------------------------------------------------------- -// Object Concepts -// ----------------------------------------------------------------------------- - -template -concept Relocatable = requires(T obj) { - { std::is_nothrow_move_constructible_v } -> std::convertible_to; - { std::is_nothrow_move_assignable_v } -> std::convertible_to; -}; - -template -concept DefaultConstructible = requires(T obj) { - { T() } -> std::same_as; -}; - -template -concept CopyConstructible = requires(T obj) { - { T(obj) } -> std::same_as; -}; - -template -concept CopyAssignable = requires(T obj) { - { obj = obj } -> std::same_as; -}; - -template -concept MoveAssignable = requires(T obj) { - { obj = std::move(obj) } -> std::same_as; -}; - -template -concept EqualityComparable = requires(T obj) { - { obj == obj } -> std::convertible_to; - { obj != obj } -> std::convertible_to; -}; - -template -concept LessThanComparable = requires(T obj) { - { obj < obj } -> std::convertible_to; -}; - -template -concept Hashable = requires(T obj) { - { std::hash{}(obj) } -> std::convertible_to; -}; - -template -concept Swappable = requires(T obj) { std::swap(obj, obj); }; - -template -concept Copyable = - std::is_copy_constructible_v && std::is_copy_assignable_v; - -template -concept Destructible = requires(T obj) { - { obj.~T() } -> std::same_as; -}; - -// ----------------------------------------------------------------------------- -// Type Concepts -// ----------------------------------------------------------------------------- - -template -concept Arithmetic = std::is_arithmetic_v; - -template -concept Integral = std::is_integral_v; - -template -concept FloatingPoint = std::is_floating_point_v; - -template -concept SignedInteger = std::is_integral_v && std::is_signed_v; - -template -concept UnsignedInteger = std::is_integral_v && std::is_unsigned_v; - -template -concept Number = Arithmetic || Integral || FloatingPoint; - -#if __has_include() -#include -template -concept ComplexNumber = requires(T obj) { - typename T::value_type; - requires std::is_same_v>; -}; -#endif - -template -concept Char = std::is_same_v; - -template -concept WChar = std::is_same_v; - -template -concept Char16 = std::is_same_v; - -template -concept Char32 = std::is_same_v; - -template -concept AnyChar = Char || WChar || Char16 || Char32; - -template -concept StringType = - std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v; - -template -concept IsBuiltIn = std::is_fundamental_v || StringType; - -template -concept Enum = std::is_enum_v; - -template -concept Pointer = std::is_pointer_v; - -template -concept UniquePointer = requires(T obj) { - requires std::is_same_v>; -}; - -template -concept SharedPointer = requires(T obj) { - requires std::is_same_v>; -}; - -template -concept WeakPointer = requires(T obj) { - requires std::is_same_v>; -}; - -template -concept SmartPointer = UniquePointer || SharedPointer || WeakPointer; - -template -concept Reference = std::is_reference_v; - -template -concept LvalueReference = std::is_lvalue_reference_v; - -template -concept RvalueReference = std::is_rvalue_reference_v; - -template -concept Const = std::is_const_v>; - -template -concept Trivial = std::is_trivial_v; - -template -concept TriviallyConstructible = std::is_trivially_constructible_v; - -template -concept TriviallyCopyable = - std::is_trivially_copyable_v && std::is_standard_layout_v; - -// ----------------------------------------------------------------------------- -// Container Concepts -// ----------------------------------------------------------------------------- - -#if __has_include() -#include - -template -concept Iterable = requires(T obj) { - { obj.begin() } -> std::forward_iterator; - { obj.end() } -> std::forward_iterator; -}; - -template -concept Container = requires(T obj) { - { obj.size() } -> std::convertible_to; - requires Iterable; -}; - -template -concept StringContainer = requires(T obj) { - typename T::value_type; - requires StringType || Char; - { obj.push_back(std::declval()) }; -}; - -template -concept NumberContainer = requires(T obj) { - typename T::value_type; - requires Number; - { obj.push_back(std::declval()) }; -}; - -template -concept AssociativeContainer = requires(T obj) { - typename T::key_type; - typename T::mapped_type; - requires Container; -}; - -template -concept Iterator = requires(T iter) { - { - *iter - } -> std::convertible_to::value_type>; - { ++iter } -> std::same_as; - { iter++ } -> std::convertible_to; -}; - -template -concept NotSequenceContainer = - !std::is_same_v> && - !std::is_same_v> && - !std::is_same_v>; - -template -concept NotAssociativeOrSequenceContainer = - !std::is_same_v> && - !std::is_same_v< - T, std::unordered_map> && - !std::is_same_v< - T, std::multimap> && - !std::is_same_v> && - !NotSequenceContainer; - -template -concept String = NotSequenceContainer && requires(T obj) { - { obj.size() } -> std::convertible_to; - { obj.empty() } -> std::convertible_to; - { obj.begin() } -> std::convertible_to; - { obj.end() } -> std::convertible_to; -}; - -// ----------------------------------------------------------------------------- -// Multi-threading Concepts -// ----------------------------------------------------------------------------- - -template -concept Lockable = requires(T obj) { - { obj.lock() } -> std::same_as; - { obj.unlock() } -> std::same_as; -}; - -template -concept SharedLockable = requires(T obj) { - { obj.lock_shared() } -> std::same_as; - { obj.unlock_shared() } -> std::same_as; -}; - -template -concept Mutex = Lockable && requires(T obj) { - { obj.try_lock() } -> std::same_as; -}; - -template -concept SharedMutex = SharedLockable && requires(T obj) { - { obj.try_lock_shared() } -> std::same_as; -}; - -// ----------------------------------------------------------------------------- -// Asynchronous Concepts -// ----------------------------------------------------------------------------- - -template -concept Future = requires(T obj) { - { obj.get() } -> std::same_as; - { obj.wait() } -> std::same_as; -}; - -template -concept Promise = requires(T obj) { - { - obj.set_value(std::declval()) - } -> std::same_as; - { - obj.set_exception(std::declval()) - } -> std::same_as; -}; - -template -concept AsyncResult = Future || Promise; - -#endif - -#endif diff --git a/src/atom/function/constructor.hpp b/src/atom/function/constructor.hpp deleted file mode 100644 index a15df8f9..00000000 --- a/src/atom/function/constructor.hpp +++ /dev/null @@ -1,245 +0,0 @@ -/*! - * \file constructors.hpp - * \brief C++ Function Constructors - * \author Max Qian - * \date 2024-03-01 - * \copyright Copyright (C) 2023-2024 Max Qian - */ - -#ifndef ATOM_META_CONSTRUCTOR_HPP -#define ATOM_META_CONSTRUCTOR_HPP - -#include -#include -#include - -#include "atom/error/exception.hpp" -#include "func_traits.hpp" - -namespace atom::meta { - -/*! - * \brief Binds a member function to an object. - * \tparam MemberFunc Type of the member function. - * \tparam ClassType Type of the class. - * \param member_func Pointer to the member function. - * \return A lambda that binds the member function to an object. - */ -template -auto bindMemberFunction(MemberFunc ClassType::*member_func) { - return [member_func](ClassType &obj, auto &&...params) { - if constexpr (FunctionTraits::is_const_member_function) { - return (std::as_const(obj).* - member_func)(std::forward(params)...); - } else { - return (obj.* - member_func)(std::forward(params)...); - } - }; -} - -/*! - * \brief Binds a static function. - * \tparam Func Type of the function. - * \param func The static function. - * \return The static function itself. - */ -template -auto bindStaticFunction(Func func) { - return func; -} - -/*! - * \brief Binds a member variable to an object. - * \tparam MemberType Type of the member variable. - * \tparam ClassType Type of the class. - * \param member_var Pointer to the member variable. - * \return A lambda that binds the member variable to an object. - */ -template -auto bindMemberVariable(MemberType ClassType::*member_var) { - return [member_var](ClassType &instance) -> MemberType & { - return instance.*member_var; - }; -} - -/*! - * \brief Builds a shared constructor for a class. - * \tparam Class Type of the class. - * \tparam Params Types of the constructor parameters. - * \param unused Unused parameter to deduce types. - * \return A lambda that constructs a shared pointer to the class. - */ -template -auto buildSharedConstructor(Class (* /*unused*/)(Params...)) { - return [](auto &&...params) { - return std::make_shared( - std::forward(params)...); - }; -} - -/*! - * \brief Builds a copy constructor for a class. - * \tparam Class Type of the class. - * \tparam Params Types of the constructor parameters. - * \param unused Unused parameter to deduce types. - * \return A lambda that constructs an instance of the class. - */ -template -auto buildCopyConstructor(Class (* /*unused*/)(Params...)) { - return [](auto &&...params) { - return Class(std::forward(params)...); - }; -} - -/*! - * \brief Builds a plain constructor for a class. - * \tparam Class Type of the class. - * \tparam Params Types of the constructor parameters. - * \param unused Unused parameter to deduce types. - * \return A lambda that constructs an instance of the class. - */ -template -auto buildPlainConstructor(Class (* /*unused*/)(Params...)) { - return [](auto &&...params) { - return Class(std::forward(params)...); - }; -} - -/*! - * \brief Builds a constructor for a class with specified arguments. - * \tparam Class Type of the class. - * \tparam Args Types of the constructor arguments. - * \return A lambda that constructs a shared pointer to the class. - */ -template -auto buildConstructor() { - return [](Args... args) -> std::shared_ptr { - return std::make_shared(std::forward(args)...); - }; -} - -/*! - * \brief Builds a default constructor for a class. - * \tparam Class Type of the class. - * \return A lambda that constructs an instance of the class. - */ -template -auto buildDefaultConstructor() { - return []() { return Class(); }; -} - -/*! - * \brief Constructs an instance of a class based on its traits. - * \tparam T Type of the function. - * \return A lambda that constructs an instance of the class. - */ -template -auto constructor() { - T *func = nullptr; - using ClassType = typename FunctionTraits::class_type; - - if constexpr (!std::is_copy_constructible_v) { - return buildSharedConstructor(func); - } else { - return buildCopyConstructor(func); - } -} - -/*! - * \brief Constructs an instance of a class with specified arguments. - * \tparam Class Type of the class. - * \tparam Args Types of the constructor arguments. - * \return A lambda that constructs a shared pointer to the class. - */ -template -auto constructor() { - return buildConstructor(); -} - -/*! - * \brief Constructs an instance of a class using the default constructor. - * \tparam Class Type of the class. - * \return A lambda that constructs an instance of the class. - * \throws Exception if the class is not default constructible. - */ -template -auto defaultConstructor() { - if constexpr (std::is_default_constructible_v) { - return buildDefaultConstructor(); - } else { - THROW_NOT_FOUND("Class is not default constructible"); - } -} - -/*! - * \brief Constructs an instance of a class using a move constructor. - * \tparam Class Type of the class. - * \return A lambda that constructs an instance of the class using a move - * constructor. - */ -template -auto buildMoveConstructor() { - return [](Class &&instance) { return Class(std::move(instance)); }; -} - -/*! - * \brief Constructs an instance of a class using an initializer list. - * \tparam Class Type of the class. - * \tparam T Type of the elements in the initializer list. - * \return A lambda that constructs an instance of the class using an - * initializer list. - */ -template -auto buildInitializerListConstructor() { - return [](std::initializer_list init_list) { return Class(init_list); }; -} - -/*! - * \brief Constructs an instance of a class asynchronously. - * \tparam Class Type of the class. - * \tparam Args Types of the constructor arguments. - * \return A future that constructs an instance of the class. - */ -template -auto asyncConstructor() { - return [](Args... args) -> std::future> { - return std::async( - std::launch::async, - [](Args... args) { - return std::make_shared(std::forward(args)...); - }, - std::forward(args)...); - }; -} - -/*! - * \brief Constructs a singleton instance of a class. - * \tparam Class Type of the class. - * \return A lambda that constructs a singleton instance of the class. - */ -template -auto singletonConstructor() { - return []() -> std::shared_ptr { - static std::shared_ptr instance = std::make_shared(); - return instance; - }; -} - -/*! - * \brief Constructs an instance of a class using a custom constructor. - * \tparam Class Type of the class. - * \tparam CustomConstructor Type of the custom constructor. - * \return A lambda that constructs an instance of the class using the custom - * constructor. - */ -template -auto customConstructor(CustomConstructor custom_constructor) { - return [custom_constructor](auto &&...args) { - return custom_constructor(std::forward(args)...); - }; -} - -} // namespace atom::meta - -#endif // ATOM_META_CONSTRUCTOR_HPP diff --git a/src/atom/function/conversion.hpp b/src/atom/function/conversion.hpp deleted file mode 100644 index 944a1bbc..00000000 --- a/src/atom/function/conversion.hpp +++ /dev/null @@ -1,513 +0,0 @@ -#ifndef ATOM_META_CONVERSION_HPP -#define ATOM_META_CONVERSION_HPP - -#include -#include -#include -#include -#include -#include "atom/macro.hpp" - -#if ENABLE_FASTHASH -#include "emhash/hash_table8.hpp" -#else -#include -#endif - -#include "atom/error/exception.hpp" -#include "type_info.hpp" - -namespace atom::meta { - -class BadConversionException : public error::RuntimeError { - using atom::error::RuntimeError::RuntimeError; -}; - -#define THROW_CONVERSION_ERROR(...) \ - throw BadConversionException(ATOM_FILE_NAME, ATOM_FILE_LINE, \ - ATOM_FUNC_NAME, __VA_ARGS__) - -class TypeConversionBase { -public: - ATOM_NODISCARD virtual auto convert(const std::any& from) const - -> std::any = 0; - ATOM_NODISCARD virtual auto convertDown(const std::any& toAny) const - -> std::any = 0; - - ATOM_NODISCARD virtual auto to() const ATOM_NOEXCEPT -> const TypeInfo& { - return toType; - } - ATOM_NODISCARD virtual auto from() const ATOM_NOEXCEPT -> const TypeInfo& { - return fromType; - } - - ATOM_NODISCARD auto getFromType() const ATOM_NOEXCEPT -> const TypeInfo& { - return fromType; - } - - ATOM_NODISCARD auto getToType() const ATOM_NOEXCEPT -> const TypeInfo& { - return toType; - } - - ATOM_NODISCARD virtual auto bidir() const ATOM_NOEXCEPT -> bool { - return true; - } - - virtual ~TypeConversionBase() = default; - - TypeConversionBase(const TypeConversionBase&) = default; - TypeConversionBase& operator=(const TypeConversionBase&) = default; - TypeConversionBase(TypeConversionBase&&) = default; - TypeConversionBase& operator=(TypeConversionBase&&) = default; - -protected: - TypeConversionBase(const TypeInfo& toTypeInfo, const TypeInfo& fromTypeInfo) - : toType(toTypeInfo), fromType(fromTypeInfo) {} - - TypeInfo toType; - TypeInfo fromType; -}; - -template -class StaticConversion : public TypeConversionBase { -public: - StaticConversion() : TypeConversionBase(userType(), userType()) {} - - ATOM_NODISCARD auto convert(const std::any& from) const - -> std::any override { - // Pointer types static conversion (upcasting) - try { - if constexpr (std::is_pointer_v && std::is_pointer_v) { - auto fromPtr = std::any_cast(from); - return std::any(static_cast(fromPtr)); - } - // Reference types static conversion (upcasting) - else if constexpr (std::is_reference_v && - std::is_reference_v) { - auto& fromRef = std::any_cast(from); - return std::any(static_cast(fromRef)); - - } else { - THROW_CONVERSION_ERROR("Failed to convert ", fromType.name(), - " to ", toType.name()); - } - } catch (const std::bad_cast&) { - THROW_CONVERSION_ERROR("Failed to convert ", fromType.name(), - " to ", toType.name()); - } - } - - ATOM_NODISCARD auto convertDown(const std::any& toAny) const - -> std::any override { - // Pointer types static conversion (downcasting) - try { - if constexpr (std::is_pointer_v && std::is_pointer_v) { - auto toPtr = std::any_cast(toAny); - return std::any(static_cast(toPtr)); - } - // Reference types static conversion (downcasting) - else if constexpr (std::is_reference_v && - std::is_reference_v) { - auto& toRef = std::any_cast(toAny); - return std::any(static_cast(toRef)); - - } else { - THROW_CONVERSION_ERROR("Failed to convert ", toType.name(), - " to ", fromType.name()); - } - } catch (const std::bad_cast&) { - THROW_CONVERSION_ERROR("Failed to convert ", toType.name(), " to ", - fromType.name()); - } - } -}; - -template -class DynamicConversion : public TypeConversionBase { -public: - DynamicConversion() - : TypeConversionBase(userType(), userType()) {} - - ATOM_NODISCARD auto convert(const std::any& from) const - -> std::any override { - // Pointer types dynamic conversion - if constexpr (std::is_pointer_v && std::is_pointer_v) { - auto fromPtr = std::any_cast(from); - auto convertedPtr = dynamic_cast(fromPtr); - if (!convertedPtr && fromPtr != nullptr) { - throw std::bad_cast(); - } - return std::any(convertedPtr); - } - // Reference types dynamic conversion - else if constexpr (std::is_reference_v && - std::is_reference_v) { - try { - auto& fromRef = std::any_cast(from); - return std::any(dynamic_cast(fromRef)); - } catch (const std::bad_cast&) { - THROW_CONVERSION_ERROR("Failed to convert ", fromType.name(), - " to ", toType.name()); - } - } else { - THROW_CONVERSION_ERROR("Failed to convert ", fromType.name(), - " to ", toType.name()); - } - } - - ATOM_NODISCARD auto convertDown(const std::any& toAny) const - -> std::any override { - // Pointer types dynamic conversion - if constexpr (std::is_pointer_v && std::is_pointer_v) { - auto toPtr = std::any_cast(toAny); - auto convertedPtr = dynamic_cast(toPtr); - if (!convertedPtr && toPtr != nullptr) { - throw std::bad_cast(); - } - return std::any(convertedPtr); - } - // Reference types dynamic conversion - else if constexpr (std::is_reference_v && - std::is_reference_v) { - try { - auto& toRef = std::any_cast(toAny); - return std::any(dynamic_cast(toRef)); - } catch (const std::bad_cast&) { - THROW_CONVERSION_ERROR("Failed to convert ", toType.name(), - " to ", fromType.name()); - } - } else { - THROW_CONVERSION_ERROR("Failed to convert ", toType.name(), " to ", - fromType.name()); - } - } -}; - -template -auto baseClass() -> std::shared_ptr { - if constexpr (std::is_polymorphic_v && - std::is_polymorphic_v) { - return std::make_shared>(); - } else { - return std::make_shared>(); - } -} - -// Specialized conversion for std::vector -template -class VectorConversion : public TypeConversionBase { -public: - VectorConversion() - : TypeConversionBase(userType>(), - userType>()) {} - - [[nodiscard]] auto convert(const std::any& from) const - -> std::any override { - try { - const auto& fromVec = std::any_cast&>(from); - std::vector toVec; - toVec.reserve(fromVec.size()); - - for (const auto& elem : fromVec) { - // Convert each element using dynamic cast - auto convertedElem = - std::dynamic_pointer_cast(elem); - if (!convertedElem) { - throw std::bad_cast(); - } - toVec.push_back(convertedElem); - } - - return std::any(toVec); - } catch (const std::bad_any_cast&) { - THROW_CONVERSION_ERROR("Failed to convert ", fromType.name(), - " to ", toType.name()); - } - } - - ATOM_NODISCARD auto convertDown(const std::any& toAny) const - -> std::any override { - try { - const auto& toVec = std::any_cast&>(toAny); - std::vector fromVec; - fromVec.reserve(toVec.size()); - - for (const auto& elem : toVec) { - // Convert each element using dynamic cast - auto convertedElem = - std::dynamic_pointer_cast( - elem); - if (!convertedElem) { - throw std::bad_cast(); - } - fromVec.push_back(convertedElem); - } - - return std::any(fromVec); - } catch (const std::bad_any_cast&) { - THROW_CONVERSION_ERROR("Failed to convert ", toType.name(), " to ", - fromType.name()); - } - } -}; - -template