diff --git a/CMakeLists.txt b/CMakeLists.txt
index 5ffc8e14..8213984a 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -60,6 +60,51 @@ LIST(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake_modules/")
 LIST(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/../cmake_modules/")
 include(cmake_modules/compiler_options.cmake)
 
+if (USE_CONAN)
+include(${CMAKE_BINARY_DIR}/conanbuildinfo.cmake OPTIONAL) # Optional inclusion
+
+# Check whether Conan is installed
+find_program(CONAN_CMD conan)
+if(NOT CONAN_CMD)
+    message(FATAL_ERROR "Conan is not installed. Please install Conan (pip install conan).")
+endif()
+
+# Check whether the default Conan profile exists
+execute_process(
+    COMMAND ${CONAN_CMD} config home
+    OUTPUT_VARIABLE CONAN_HOME
+    OUTPUT_STRIP_TRAILING_WHITESPACE
+)
+
+set(CONAN_DEFAULT_PROFILE "${CONAN_HOME}/profiles/default")
+
+if(NOT EXISTS "${CONAN_DEFAULT_PROFILE}")
+    message(STATUS "Conan default profile not found. Creating a new profile based on platform.")
+    # Create a default profile based on the operating system
+    if(WIN32)
+        execute_process(COMMAND ${CONAN_CMD} profile detect --force)
+    elseif(UNIX)
+        execute_process(COMMAND ${CONAN_CMD} profile detect --force)
+    else()
+        message(FATAL_ERROR "Unsupported platform for Conan profile detection.")
+    endif()
+endif()
+
+# If conanbuildinfo.cmake does not exist, run conan install
+if(NOT EXISTS "${CMAKE_BINARY_DIR}/conanbuildinfo.cmake")
+    message(STATUS "Running Conan install...")
+    execute_process(
+        COMMAND ${CONAN_CMD} install ${CMAKE_SOURCE_DIR} --build=missing
+        RESULT_VARIABLE result
+    )
+    if(result)
+        message(FATAL_ERROR "Conan install failed with error code: ${result}")
+    endif()
+    include(${CMAKE_BINARY_DIR}/conanbuildinfo.cmake)
+    conan_basic_setup()
+endif()
+endif()
+
 # Include directories
 include_directories(${CMAKE_CURRENT_BINARY_DIR})
 include_directories(${CMAKE_SOURCE_DIR}/libs/)
diff --git a/LICENSE b/LICENSE
index f288702d..0ad25db4 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,5 +1,5 @@
- GNU GENERAL PUBLIC LICENSE
- Version 3, 29 June 2007
+ GNU AFFERO GENERAL PUBLIC LICENSE
+ Version 3, 19 November 2007
 
 Copyright (C) 2007 Free Software Foundation, Inc.
 Everyone is permitted to copy and distribute verbatim copies
@@ -7,17 +7,15 @@
 
 Preamble
 
-  The GNU General Public License is a free, copyleft license for
-software and other kinds of works.
+  The GNU Affero General Public License is a free, copyleft license for
+software and other kinds of works, specifically designed to ensure
+cooperation with the community in the case of network server software.
 
   The licenses for most software and other practical works are designed
 to take away your freedom to share and change the works.  By contrast,
-the GNU General Public License is intended to guarantee your freedom to
+our General Public Licenses are intended to guarantee your freedom to
 share and change all versions of a program--to make sure it remains free
-software for all its users.  We, the Free Software Foundation, use the
-GNU General Public License for most of our software; it applies also to
-any other work released this way by its authors.  You can apply it to
-your programs, too.
+software for all its users.
 
   When we speak of free software, we are referring to freedom, not
 price.  Our General Public Licenses are designed to make sure that you
@@ -26,44 +24,34 @@ them if you wish), that you receive source code or can get it if you
 want it, that you can change the software or use pieces of it in new
 free programs, and that you know you can do these things.
- To protect your rights, we need to prevent others from denying you -these rights or asking you to surrender the rights. Therefore, you have -certain responsibilities if you distribute copies of the software, or if -you modify it: responsibilities to respect the freedom of others. - - For example, if you distribute copies of such a program, whether -gratis or for a fee, you must pass on to the recipients the same -freedoms that you received. You must make sure that they, too, receive -or can get the source code. And you must show them these terms so they -know their rights. - - Developers that use the GNU GPL protect your rights with two steps: -(1) assert copyright on the software, and (2) offer you this License -giving you legal permission to copy, distribute and/or modify it. - - For the developers' and authors' protection, the GPL clearly explains -that there is no warranty for this free software. For both users' and -authors' sake, the GPL requires that modified versions be marked as -changed, so that their problems will not be attributed erroneously to -authors of previous versions. - - Some devices are designed to deny users access to install or run -modified versions of the software inside them, although the manufacturer -can do so. This is fundamentally incompatible with the aim of -protecting users' freedom to change the software. The systematic -pattern of such abuse occurs in the area of products for individuals to -use, which is precisely where it is most unacceptable. Therefore, we -have designed this version of the GPL to prohibit the practice for those -products. If such problems arise substantially in other domains, we -stand ready to extend this provision to those domains in future versions -of the GPL, as needed to protect the freedom of users. - - Finally, every program is threatened constantly by software patents. -States should not allow patents to restrict development and use of -software on general-purpose computers, but in those that do, we wish to -avoid the special danger that patents applied to a free program could -make it effectively proprietary. To prevent this, the GPL assures that -patents cannot be used to render the program non-free. + Developers that use our General Public Licenses protect your rights +with two steps: (1) assert copyright on the software, and (2) offer +you this License which gives you legal permission to copy, distribute +and/or modify the software. + + A secondary benefit of defending all users' freedom is that +improvements made in alternate versions of the program, if they +receive widespread use, become available for other developers to +incorporate. Many developers of free software are heartened and +encouraged by the resulting cooperation. However, in the case of +software used on network servers, this result may fail to come about. +The GNU General Public License permits making a modified version and +letting the public access it on a server without ever releasing its +source code to the public. + + The GNU Affero General Public License is designed specifically to +ensure that, in such cases, the modified source code becomes available +to the community. It requires the operator of a network server to +provide the source code of the modified version running there to the +users of that server. Therefore, public use of a modified version, on +a publicly accessible server, gives the public access to the source +code of the modified version. 
+ + An older license, called the Affero General Public License and +published by Affero, was designed to accomplish similar goals. This is +a different license, not a version of the Affero GPL, but Affero has +released a new version of the Affero GPL which permits relicensing under +this license. The precise terms and conditions for copying, distribution and modification follow. @@ -72,7 +60,7 @@ modification follow. 0. Definitions. - "This License" refers to version 3 of the GNU General Public License. + "This License" refers to version 3 of the GNU Affero General Public License. "Copyright" also means copyright-like laws that apply to other kinds of works, such as semiconductor masks. @@ -549,35 +537,45 @@ to collect a royalty for further conveying from those to whom you convey the Program, the only way you could satisfy both those terms and this License would be to refrain entirely from conveying the Program. - 13. Use with the GNU Affero General Public License. + 13. Remote Network Interaction; Use with the GNU General Public License. + + Notwithstanding any other provision of this License, if you modify the +Program, your modified version must prominently offer all users +interacting with it remotely through a computer network (if your version +supports such interaction) an opportunity to receive the Corresponding +Source of your version by providing access to the Corresponding Source +from a network server at no charge, through some standard or customary +means of facilitating copying of software. This Corresponding Source +shall include the Corresponding Source for any work covered by version 3 +of the GNU General Public License that is incorporated pursuant to the +following paragraph. Notwithstanding any other provision of this License, you have permission to link or combine any covered work with a work licensed -under version 3 of the GNU Affero General Public License into a single +under version 3 of the GNU General Public License into a single combined work, and to convey the resulting work. The terms of this License will continue to apply to the part which is the covered work, -but the special requirements of the GNU Affero General Public License, -section 13, concerning interaction through a network will apply to the -combination as such. +but the work with which it is combined will remain governed by version +3 of the GNU General Public License. 14. Revised Versions of this License. The Free Software Foundation may publish revised and/or new versions of -the GNU General Public License from time to time. Such new versions will -be similar in spirit to the present version, but may differ in detail to +the GNU Affero General Public License from time to time. Such new versions +will be similar in spirit to the present version, but may differ in detail to address new problems or concerns. Each version is given a distinguishing version number. If the -Program specifies that a certain numbered version of the GNU General +Program specifies that a certain numbered version of the GNU Affero General Public License "or any later version" applies to it, you have the option of following the terms and conditions either of that numbered version or of any later version published by the Free Software Foundation. If the Program does not specify a version number of the -GNU General Public License, you may choose any version ever published +GNU Affero General Public License, you may choose any version ever published by the Free Software Foundation. 
If the Program specifies that a proxy can decide which future -versions of the GNU General Public License can be used, that proxy's +versions of the GNU Affero General Public License can be used, that proxy's public statement of acceptance of a version permanently authorizes you to choose that version for the Program. @@ -635,40 +633,29 @@ the "copyright" line and a pointer to where the full notice is found. Copyright (C) This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or + it under the terms of the GNU Affero General Public License as published + by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. + GNU Affero General Public License for more details. - You should have received a copy of the GNU General Public License + You should have received a copy of the GNU Affero General Public License along with this program. If not, see . Also add information on how to contact you by electronic and paper mail. - If the program does terminal interaction, make it output a short -notice like this when it starts in an interactive mode: - - Copyright (C) - This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. - This is free software, and you are welcome to redistribute it - under certain conditions; type `show c' for details. - -The hypothetical commands `show w' and `show c' should show the appropriate -parts of the General Public License. Of course, your program's commands -might be different; for a GUI interface, you would use an "about box". + If your software can interact with users remotely through a computer +network, you should also make sure that it provides a way for users to +get its source. For example, if your program is a web application, its +interface could display a "Source" link that leads users to an archive +of the code. There are many ways you could offer source, and different +solutions will be better for different programs; see section 13 for the +specific requirements. You should also get your employer (if you work as a programmer) or school, if any, to sign a "copyright disclaimer" for the program, if necessary. -For more information on this, and how to apply and follow the GNU GPL, see +For more information on this, and how to apply and follow the GNU AGPL, see . - - The GNU General Public License does not permit incorporating your program -into proprietary programs. If your program is a subroutine library, you -may consider it more useful to permit linking proprietary applications with -the library. If this is what you want to do, use the GNU Lesser General -Public License instead of this License. But first, please read -. 
diff --git a/conanfile.py b/conanfile.py deleted file mode 100644 index a2462b84..00000000 --- a/conanfile.py +++ /dev/null @@ -1,52 +0,0 @@ -from conan import ConanFile -from conan.tools.cmake import CMakeToolchain, CMakeDeps, cmake_layout - - -class ConanFile(ConanFile): - name = "lithium" - version = "1.0.0" - - # Binary configuration - settings = "os", "compiler", "build_type", "arch" - - # Sources are located in the same place as this recipe, copy them to the recipe - exports_sources = "CMakeLists.txt", "src/*" - - requires = [ - "argparse/3.0", - "cpp-httplib/0.15.3", - "openssl/3.2.1", - "zlib/1.3.1", - "oatpp/1.3.0", - "oatpp-websocket/1.3.0", - "oatpp-openssl/1.3.0", - "oatpp-swagger/1.3.0", - "loguru/cci.20230406", - "cfitsio/4.3.1", - "tinyxml2/10.0.0", - "cpython/3.12.2", - "fmt/10.2.1", - "opencv/4.9.0" - ] - - def layout(self): - cmake_layout(self) - - def generate(self): - tc = CMakeToolchain(self) - tc.generate() - - deps = CMakeDeps(self) - deps.generate() - - def build(self): - cmake = CMake(self) - cmake.configure() - cmake.build() - - def package(self): - cmake = CMake(self) - cmake.install() - - def package_info(self): - self.cpp_info.libs = ["my_project"] diff --git a/conanfile.txt b/conanfile.txt index 1d181bcc..3b38b6ab 100644 --- a/conanfile.txt +++ b/conanfile.txt @@ -1,18 +1,9 @@ [requires] -argparse/3.0 cfitsio/4.3.1 cpython/3.12.2 -cpp-httplib/0.15.3 -fmt/10.2.1 -loguru/cci.20230406 -oatpp/1.3.0 -oatpp-websocket/1.3.0 -oatpp-openssl/1.3.0 -oatpp-swagger/1.3.0 opencv/4.9.0 openssl/3.2.1 pybind11/2.12.0 -pybind11_json/0.2.13 tinyxml2/10.0.0 zlib/1.3.1 diff --git a/config/addon/compiler.json b/config/addon/compiler.json new file mode 100644 index 00000000..36fc3db3 --- /dev/null +++ b/config/addon/compiler.json @@ -0,0 +1,24 @@ +{ + "compiler": "clang++", + "optimization_level": "-O3", + "cplus_version": "-std=c++20", + "warnings": "-Wall -Wextra -Wpedantic", + "include_flag": "-I./include -I./external/include", + "output_flag": "-o output", + "debug_info": "-g", + "sanitizers": "-fsanitize=address,undefined", + "extra_flags": [ + "-fno-exceptions", + "-fno-rtti", + "-march=native" + ], + "defines": [ + "-DNDEBUG", + "-DUSE_OPENMP" + ], + "libraries": [ + "-lpthread", + "-lm", + "-lrt" + ] +} diff --git a/doc/extra/day01.md b/doc/extra/day01.md deleted file mode 100644 index edc6caf4..00000000 --- a/doc/extra/day01.md +++ /dev/null @@ -1,209 +0,0 @@ -# 从零入门 C 语言: Day1 - "Hello, World!"原来也这么复杂 - -大家好啊,欢迎来到 C 语言的世界!作为一个入门者,你即将踏上一段有趣的编程旅程。而我们要做的第一件事,就是写出一个无比经典且简单的程序——"Hello, World!"。准备好了吗?带上你的大脑和电脑(如果是手机超人也可以),我们要出发咯! - -## 最难的问题:Hello World - -程序员调查表明,Hello, World 劝退很多新人,如果你能越过这道坎,那么你已经击败了 50%的萌新了 awa - -```c -#include - -int main() { - printf("Hello World!\n"); - return 0; -} -``` - -## 冒险启程 - -### 引入主角:`#include ` - -一切编程冒险的开始,往往需要介绍主角。`#include `,就像是你打开了大门,邀请了 C 语言的图书管理员进来。这个图书管理员叫做 stdio.h,全名是“标准输入输出头文件”,它会帮你处理所有输入输出的操作。 - -- 问题:为什么要请这个图书管理员? -- 答案:因为你马上要用到它的宝贝——printf 函数,来把“Hello, World!”打印到屏幕上。如果没有它,printf 就无家可归,根本用不了。 - -### 主办方出场:`int main()` - -然后,我们的主办方登场了:int main()。这就是你程序的“大老板”,相当于所有代码的起点,所有程序都必须从这里开始执行。主办方全程监督你的程序,不管你有多少代码,都得从这里开始运行。 - -- 问题:为什么是 int 类型? -- 答案:因为在 C 语言里,main 函数必须给操作系统打个招呼,告诉它程序运行完后是成功了,还是出问题了。int 类型表示返回一个整数,0 表示“一切顺利,没毛病”,其他数字则意味着“嗯,有点小问题”。 - -### 打印操作:`printf("Hello, World!\n");` - -接下来就是本次冒险的重头戏了!`printf("Hello, World!\n");`。`printf` 函数就像是一个巨大的喇叭,它会大声宣布:“Hello, World!”。 - -- 问题:为什么有个\n? 
-- dark 案:这个神秘的\n 是一个控制符,表示“换行”。就好像你在纸上写字,写完一句话,换行再写下一句。如果没有它,你的输出就会一行连着一行,显得有点凌乱。 - -### 礼貌的告别:`return 0;` - -最后,主办方 main 函数礼貌地对操作系统说:“感谢光临,一切顺利!”通过 return 0;这句话,你告诉操作系统,程序没有出错,可以放心关闭。这是一种好习惯,因为虽然你不告诉它也能运行,但你得有礼貌嘛,是吧? - -### 编译与运行:是时候展示真正的实力了 - -写完代码可不意味着工作结束了!你还得把这段文字转化为计算机能懂的“机器语言”。这是通过编译器来完成的。 - -#### 编译命令 - -##### GCC - -```bash -gcc helloworld.c -o helloworld -``` - -这句话告诉编译器:“把 helloworld.c 这个文件编译成名为 helloworld 的可执行文件。” - -#### 运行程序 - -```bash -./helloworld -``` - -##### MSVC - -直接点编译和运行按钮 - -这时,计算机会乖乖地在屏幕上显示:“Hello, World!” - -## 拆包 - -### 头文件 - -#### 什么是头文件? - -头文件其实是一些已经写好的 C 语言代码的集合,通常以.h 结尾。你可以把它理解为一份“说明书”,当你需要用某个功能时,头文件就告诉编译器如何找到和使用这个功能。 - -举个栗子: - -- `stdio.h`:包含标准输入输出函数,比如`printf`、`scanf`等。 -- `stdlib.h`:包含内存分配、程序控制、随机数等函数。 -- `math.h`:包含数学函数,比如`sin`、`cos`、`sqrt`等。 - -#### 如何引入头文件? - -我们使用预处理指令`#include`来引入头文件。举个例子: - -```c -#include -``` - -这个指令告诉编译器:在编译这段代码时,先把 stdio.h 的内容包含进来,以便程序中可以使用它的功能. - -#### 系统头文件和自定义头文件 - -- 系统头文件:放在尖括号< >里面,编译器知道去系统目录找这些文件。例如: - -```c -#include -``` - -- 自定义头文件:如果你自己写了一个头文件,需要用双引号" "包含,编译器会在你的项目目录里找。例如: - -```c -#include "myheader.h" -``` - -### 主函数 - -```c -int main(int argc, char *argv[]) { - // 使用命令行参数的代码 - return 0; -} -``` - -- 参数: argc:表示命令行参数的个数。argv[]:是一个字符串数组,存储命令行传入的参数。**一般情况下可以不写** -- 返回值类型 int:main()函数返回一个整数类型的数据。这个返回值通常用来告诉操作系统程序的执行结果。返回 0 意味着程序成功执行,非零值则表示程序有问题。 -- 函数体:花括号{}里面就是你程序的逻辑。你可以把所有需要执行的代码都写在这里。 -- `return 0;`:return 语句告诉操作系统程序执行完毕了。返回 0 表示成功。虽然对于小程序可以不写 return 0;,但这是个好习惯。 - -**C++用户请注意,写 `return 0;`并不是好习惯,并不被推荐!** - -### 输出 - -`printf()`函数是 C 语言中最常用的输出函数,它的作用是将数据打印到标准输出(通常是屏幕)。`printf()`函数由`stdio.h`头文件提供。 - -#### printf()的基本用法 - -```c -printf("Hello, World!\n"); -``` - -这个函数会将双引号内的字符串输出到屏幕上。printf()可以输出不同类型的数据,具体的输出格式由格式化占位符决定。 - -#### 常用的格式化占位符 - -- %d:输出整数(int 类型)。 -- %f:输出浮点数(float、double 类型)。 -- %c:输出单个字符。 -- %s:输出字符串(字符数组)。 - -```c -int age = 25; -float height = 5.9; -printf("I am %d years old and %.1f feet tall.\n", age, height); -``` - -```c -I am 25 years old and 5.9 feet tall. -``` - -#### 转义字符 - -转义字符是用于表示某些特殊字符的,例如换行、制表符等。 - -- \n:换行符。 -- \t:制表符,相当于插入了一个“跳格”。 -- \\:输出反斜杠\本身。 - -```c -printf("Hello,\nWorld!\tTabbed!\n"); -``` - -```text -Hello, -World! Tabbed! -``` - -#### 关于 puts(另一种输出) - -```c -puts("This line is printed using puts(), and it automatically adds a newline."); -``` - -**请注意 puts 并不被推荐使用,完全不如 printf,以下是不使用的原因** - -- 自动添加换行符: `puts()`函数的一个显著特征是它会自动添加换行符。这在某些简单的场景下很方便,但在更复杂的程序中,往往并不符合预期。 -- 无法处理格式化输出: `puts()`只能打印单一的字符串,并且不支持格式化输出(如插入变量、控制输出精度等)。 -- 缺少错误处理: `puts()`在失败时通常返回一个非负值(EOF 表示错误),但它并不会提供详细的错误信息。 - -## 完整的代码 - -[在线编译器代码测试](https://godbolt.org/z/39ao5df4q) - -```c -#include -#include - -int main() { - // 1. 使用 printf 输出到屏幕 - printf("Hello, World! This is printed using printf().\n"); - - // 2. 使用 puts 输出到屏幕 - puts("This line is printed using puts(), and it automatically adds a newline."); - - // 3. 使用 putchar 输出单个字符到屏幕 - putchar('A'); - putchar('\n'); - return 0; -} -``` - -## 总结 - -“Hello, World!”虽然只是个简单的小程序,但它背后涉及了很多 C 语言的基本概念:头文件、函数、语句、控制符、返回值、编译运行……别看它简单,这可是每个程序员的“开门钥匙”。 -所以,恭喜你!迈出了编程世界的第一步。只要记住,每一个复杂的程序都源于这种小小的 printf,就像每一篇大作都是从一个字开始的! 
-从此,编程之路正式开启,未来你还会学到更多神奇的东西,比如条件判断、循环、指针等等。不过别急,慢慢来,一切都会明朗。希望你编程愉快,`bug 少少,代码美美!` diff --git a/doc/extra/day02.md b/doc/extra/day02.md deleted file mode 100644 index 098ef854..00000000 --- a/doc/extra/day02.md +++ /dev/null @@ -1,412 +0,0 @@ -# 从零入门 C 语言: Day2 - 不一样的选择却殊途同归 - -分支语句在 C 语言中就像是程序的导航系统,告诉代码在不同的条件下该去哪里。这些语句帮助你的程序在不同的情况下选择不同的路线,就像你开车时选择不同的道路以避开交通堵塞。通过这些分支,程序能根据不同的输入或状态选择最合适的代码块,从而提升其灵活性和适应性。C 语言中的分支语句有点像你在饭店点餐时的选项:你可以选择“if”你喜欢这道菜,或者“if-else”你更喜欢另一道菜。还有“else-if”结构,就像在菜单上翻找更多的选择,最后是“switch”语句,就像你在自助餐厅里随机挑选自己想吃的美食。本文将以一种轻松愉快的方式,详细介绍这些分支语句的语法、用法以及注意事项,确保你在编程的旅途中不会迷路。 -~~奇怪的比喻增加了~~ - -## `if` 语句 - -if 语句是 C 语言中最基础的条件分支控制结构,主要用于判断某个条件是否为真,如果为真则执行特定的代码块。if 语句的条件表达式通常是一个逻辑表达式,其值为 true 或 false。 - -### 语法 - -```c -if (条件表达式) { -// 条件为真时执行的代码块 -} -``` - -- 条件表达式:通常为一个布尔表达式。如果条件成立(即结果为 true),则执行大括号内的代码块;否则跳过该代码块。 -- 代码块:大括号 {} 内的代码仅在条件为真时执行。 - -### 举个简单的例子 - -```c -#include - -int main() { - int x = 5; - - if (x > 0) { - printf("x is positive.\n"); - } - - return 0; - -} -``` - -`if (x > 0)` 判断变量 x 是否大于 0。因为 x 的值为 5,条件成立,所以会输出 `x is positive.`。 - -## `if-else` 语句 - -if-else 语句是 if 语句的扩展,允许程序在条件不成立时执行另一段代码。即当条件为真时执行 if 分支的代码;当条件为假(false)时,执行 else 分支的代码。 - -### 语法格式 - -```c -if (条件表达式) { -// 条件为真时执行的代码块 -} else { -// 条件为假时执行的代码块 -} -``` - -### 再来个例子 - -```c -#include - -int main() { - int x = -3; - - if (x > 0) { - printf("x is positive.\n"); - } else { - printf("x is not positive.\n"); - } - - return 0; - -} -``` - -在这个例子中,变量 x 的值为-3,因此 `if (x > 0)` 条件不成立,程序会执行 else 分支,输出 `x is not positive.`。 - -## `else if` 结构 - -else if 结构用于处理多个条件判断。它可以在一个 if-else 语句的基础上,增加多个条件判断,直到找到一个为真的条件。如果某个条件成立,则执行该条件对应的代码块,之后的分支将不会执行。 - -### 语法格式 - -```c -if (条件表达式 1) { -// 条件 1 为真时执行的代码块 -} else if (条件表达式 2) { -// 条件 2 为真时执行的代码块 -} else { -// 当所有条件都不成立时执行的代码块 -} -``` - -### 还需要一个例子 - -```c -#include - -int main() { - int x = 0; - if (x > 0) { - printf("x is positive.\n"); - } else if (x < 0) { - printf("x is negative.\n"); - } else { - printf("x is zero.\n"); - } - return 0; -} -``` - -程序首先判断 x > 0 是否成立,如果成立则执行第一个分支。如果不成立,继续判断 x < 0,如果也不成立,则执行 else 分支。在这个例子中,x 的值为 0,因此会输出 `x is zero.`。 - -## `switch` 语句 - -switch 语句是一种多分支的选择结构,常用于处理离散值的判断。它通过匹配一个表达式的值,执行与该值对应的代码块。switch 语句常用于替代多个 if-else if 的条件判断,使得代码更简洁易读。 - -### 语法格式 - -```c -switch (表达式) { -case 常量 1: -// 当表达式的值等于常量 1 时执行 -break; -case 常量 2: -// 当表达式的值等于常量 2 时执行 -break; -// 可以有多个 case 分支 -default: -// 当所有 case 都不匹配时执行 -} -``` - -- 表达式:通常是一个整型表达式或字符表达式。 -- case:用于匹配表达式的值。当表达式的值等于某个 case 分支的常量时,执行该分支的代码。 -- break:用于跳出 switch 语句。如果不加 break,程序会继续执行下一个 case 的代码(即使条件不匹配),这称为"fall-through"现象。 -- default:当所有 case 分支都不匹配时,执行 default 分支的代码。 - -### 又来一个例子 - -```c -#include - -int main() { - int day = 3; - switch (day) { - case 1: - printf("Monday\n"); - break; - case 2: - printf("Tuesday\n"); - break; - case 3: - printf("Wednesday\n"); - break; - case 4: - printf("Thursday\n"); - break; - case 5: - printf("Friday\n"); - break; - default: - printf("Invalid day\n"); - } - return 0; -} -``` - -在这个例子中,day 的值为 3,因此程序会执行 case 3 的代码,输出“Wednesday”。如果没有 break 语句,程序会继续执行下一个 case 的代码,直到遇到 break 或结束 switch 语句。 - -## 嵌套的分支语句 - -分支语句可以互相嵌套,以处理更加复杂的条件判断。常见的嵌套方式是在 if 语句内部嵌套另一个 if 或 switch 语句。 - -### if里套if - -```c -#include - -int main() { - int x = 10; - - if (x > 0) { - if (x % 2 == 0) { - printf("x is positive and even.\n"); - } else { - printf("x is positive and odd.\n"); - } - } else { - printf("x is not positive.\n"); - } - - return 0; - -} -``` - -外层 if 语句判断 x 是否为正数,内层 if 语句进一步判断它的奇偶性。在这个例子中,x 的值为 10,因此会输出 `x is positive and even.`。 - -### switch里套switch - -```c -#include - -int 
main() { - int category = 1; - int type = 2; - - switch (category) { - case 1: - switch (type) { - case 1: - printf("Category 1, Type 1\n"); - break; - case 2: - printf("Category 1, Type 2\n"); - break; - default: - printf("Unknown type in category 1\n"); - } - break; - case 2: - printf("Category 2\n"); - break; - default: - printf("Unknown category\n"); - } - return 0; -} -``` - -这里我们嵌套了两个 switch 语句,分别处理类别和类别下的类型。category 为 1,type 为 2,因此会输出 `Category 1, Type 2`。 - -## 注意 - -### if 语句中的空语句 - -在使用 if 语句时,建议总是使用大括号{}包围代码块,即使代码块中只有一条语句。这可以防止某些情况下由于缩进或其他原因导致的逻辑错误。 - -典中典:未使用大括号的潜在问题 - -```c -#include - -int main() { - int x = 5; - - if (x > 0) - printf("x is positive.\n"); - printf("This is outside the if.\n"); // 实际上总是会执行 - - return 0; - -} -``` - -```txt -x is positive. -This is outside the if. -``` - -由于没有使用大括号,第二个 printf 语句并不属于 if 语句的条件判断部分。它将始终执行,可能导致不符合预期的程序行为。 - -改进示例 - -```c -#include - -int main() { -int x = 5; - - if (x > 0) { - printf("x is positive.\n"); - printf("This is inside the if.\n"); // 只有条件成立时才执行 - } - - return 0; - -} -``` - -x is positive. -This is inside the if. - -### switch 语句中的“fall-through”现象 - -switch 语句中,如果没有 break 语句,会发生“fall-through”现象,导致程序继续执行下一个 case 的代码块,即使表达式不匹配下一个 case。 - -```c -#include - -int main() { - int x = 2; - - switch (x) { - case 1: - printf("One\n"); - case 2: - printf("Two\n"); // 执行此 case 后会继续执行 case 3 的代码 - case 3: - printf("Three\n"); - break; - default: - printf("Unknown\n"); - } - - return 0; - -} -``` - -```txt -Two -Three -``` - -改进示例 - -```c -#include - -int main() { - int x = 2; - - switch (x) { - case 1: - printf("One\n"); - break; - case 2: - printf("Two\n"); - break; - case 3: - printf("Three\n"); - break; - default: - printf("Unknown\n"); - } - - return 0; - -} -``` - -```txt -Two -``` - -## 来一点高级技巧 - -### 使用三元运算符简化条件判断 - -C 语言提供了三元运算符?:,用于简化简单的条件判断。三元运算符的语法如下: - -```c -条件 ? 表达式 1 : 表达式 2; -``` - -当条件为真时,返回表达式 1 的值;否则返回表达式 2 的值。它通常用于替代简单的 if-else 结构。 - -```c -#include - -int main() { - int x = 5; - const char result = (x > 0) ? 
"positive" : "non-positive"; - - printf("x is %s.\n", result); - return 0; - -} -``` - -### 枚举与 switch 语句的结合 - -当处理一组相关的常量时,可以使用 enum 枚举类型结合 switch 语句,这样代码更加清晰,并减少了使用魔术数字的风险。 -示例:使用枚举和 switch - -```c -#include - -enum Days { MONDAY, TUESDAY, WEDNESDAY, THURSDAY, FRIDAY }; - -int main() { - enum Days today = WEDNESDAY; - - switch (today) { - case MONDAY: - printf("Today is Monday.\n"); - break; - case TUESDAY: - printf("Today is Tuesday.\n"); - break; - case WEDNESDAY: - printf("Today is Wednesday.\n"); - break; - case THURSDAY: - printf("Today is Thursday.\n"); - break; - case FRIDAY: - printf("Today is Friday.\n"); - break; - default: - printf("Unknown day.\n"); - } - - return 0; - -} -``` - -### 总结 - -通过合理使用分支语句,程序可以根据输入和条件灵活调整执行路径,完成更为复杂的任务。在编写代码时,牢记代码的可读性和可维护性,尽量简化逻辑和控制流。 - -**生命中最重要的不是我们选择了什么,而是我们如何对待我们所选择的** diff --git a/doc/extra/day03.md b/doc/extra/day03.md deleted file mode 100644 index 9aacf056..00000000 --- a/doc/extra/day03.md +++ /dev/null @@ -1,232 +0,0 @@ -# 从零入门 C 语言: Day3 - 重复的事情说一次就够了 - -今天的内容将为你提供一个清晰的起点,帮助你掌握循环语句这一编程中的重要概念。循环语句让我们能够高效地重复执行任务,这不仅是编程的基本能力之一,更是编写灵活、动态程序的关键。 -准备好了吗?让我们一起进入 C 语言的世界,揭开循环语句的神秘面纱,开始这场充满逻辑与创造力的冒险吧!**C 语言,启动!** - -## `for` 循环 - -for 循环是最常见的循环之一。最经典形状应该是长下面这样的: - -```c -for (初始化; 条件; 更新) { - // 循环体 -} -``` - -### 具体的执行步骤 - -- 初始化:在循环开始时执行一次。 -- 条件:每次循环前检查,如果条件为真(非零),继续执行循环体;为假(零)则退出循环。 -- 更新:每次循环结束后执行一次,通常用于更新循环变量。 - -### 举个栗子 - -```c -#include - -int main() { - int i; - for (i = 0; i < 5; i++) { - printf("i = %d\n", i); - } - return 0; -} -``` - -这个例子中,i 从 0 开始,每次循环加 1,直到 i 不小于 5。输出结果是 i = 0 到 i = 4。 - -### 注意点 - -for 循环适用于已知循环次数的情况。你可以在 for 循环中省略某些部分,比如初始化或更新,但分号必须保留。 - -### 特殊的 `for(;;)` - -for (;;) 是一种特殊的 for 循环结构,它表示一个无限循环。 - -#### 具体剖析 - -按照刚才对 for 结构的分析,我们不难发现,在 for (;;) 中,三个部分都被省略了: - -- 初始化:没有初始化变量。 -- 条件:没有明确条件,C/C++ 默认为真(true),因此会导致无限循环。 -- 迭代:没有迭代语句。 - -由于没有条件限制,所以当遇到 for (;;) 时,程序会无限地执行其循环体,直到发生某种中断,如 break、return 或外部干预(例如终止程序) - -#### 来个栗子 - -```c -#include - -int main() { - int count = 0; - - for (;;) { // 无限循环 - printf("Count: %d\n", count); - count++; - - if (count >= 5) { - break; // 当 count 达到 5 时退出循环 - } - } - - return 0; -} -``` - -## while 循环 - -while 循环是在条件为真时反复执行的循环。 - -```c -while (条件) { -// 循环体 -} -``` - -在每次循环前检查。如果条件为真(非零),执行循环体;否则退出循环。 - -### 还是需要一个例子 - -```c -#include - -int main() { - int i = 0; - while (i < 5) { - printf("i = %d\n", i); - i++; - } - return 0; -} -``` - -这个例子与 for 循环的例子类似,i 从 0 开始,每次循环加 1,直到 i 不小于 5 时退出。 - -### 注意点 - -如果条件在循环内没有被修改,可能导致死循环(循环永远不会结束)。但是有的时候我们也会主动创造一些死循环的情况,比如说像下面这种,条件始终为真,那么如果内部不退出或者杀死进程的话,会不停运行循环体。 - -```c -while (true) { - // 循环体 -} - -``` - -## do-while 循环 - -do-while 循环与 while 循环类似,但它至少执行一次循环体,然后再检查条件。 - -```c -do { -// 循环体 -} while (条件); -``` - -首先执行一次循环体,然后检查条件。如果条件为真,继续循环;否则退出。 - -### Example - -```c -#include - -int main() { - int i = 0; - do { - printf("i = %d\n", i); - i++; - } while (i < 5); - return 0; -} -``` - -这里 i 依旧从 0 开始,循环体会先执行一次,然后再检查条件。 - -### 注意点 - -由于 do-while 循环是先执行再判断条件,因此即使条件一开始为假,循环体也会执行一次。 - -## 经典循环应用 - -### 求和算法 - -计算从 1 到 100 的所有整数之和。 - -```c -#include - -int main() { - int sum = 0; - int i; - - for (i = 1; i <= 100; i++) { - sum += i; - } - - printf("1到100的和是: %d\n", sum); - return 0; - -} -``` - -### 计算阶乘 - -计算一个数的阶乘,比如 5! 
= 5 × 4 × 3 × 2 × 1。 - -```c -#include - -int main() { - int n = 5; // 要计算的阶乘数 - int factorial = 1; - int i; - - for (i = 1; i <= n; i++) { - factorial *= i; - } - - printf("%d 的阶乘是: %d\n", n, factorial); - return 0; - -} -``` - -### 斐波那契数列 - -生成斐波那契数列的前 10 个数字。 - -```c -#include - -int main() { - int n = 10; - int first = 0, second = 1, next; - int i; - - printf("斐波那契数列: "); - - for (i = 0; i < n; i++) { - if (i <= 1) { - next = i; - } else { - next = first + second; - first = second; - second = next; - } - printf("%d ", next); - } - - printf("\n"); - return 0; - -} -``` - -## 总结 - -- for 循环:适用于已知循环次数的情况。 -- while 循环:适用于条件驱动的循环。 -- do-while 循环:至少执行一次的条件循环。 - -**人生就像一场旅行,不要重复走以前的路,而要勇敢地探索未知的方向。** diff --git a/doc/extra/day04.md b/doc/extra/day04.md deleted file mode 100644 index 3444ed59..00000000 --- a/doc/extra/day04.md +++ /dev/null @@ -1,833 +0,0 @@ -# 从零入门 C 语言: Day4 - 你是个什么东西 - -## 引入 - -学习 C 语言时,理解类型系统和类型转换是非常重要的。它们帮助你掌握如何存储和操作不同类型的数据。在接下来的讲解中,我会以简单生动的方式带你一探 C 语言类型系统的奥秘,并通过示例加深理解。 - -## 类型系统 - -C 语言是一门强类型语言,这意味着每一个变量在使用之前都必须有一个确定的类型。数据类型决定了变量可以存储什么样的数据,以及在内存中占据多少空间。 - -### 基本数据类型 - -#### 整型(Integer Types) - -- int:标准整型,通常占用 4 字节内存。 -- short:短整型,通常占用 2 字节。 -- long:长整型,通常占用 4 或 8 字节。 -- long long:更长的整型,通常占用 8 字节。 - -```c -int age = 25; -short year = 2022; -long population = 8000000L; -long long distance = 12345678901234LL; -``` - -#### 字符型(Character Type) - -char:用于存储单个字符,占用 1 字节内存。 - -```c -char letter = 'A'; -``` - -#### 浮点型(Floating Point Types) - -- float:单精度浮点型,通常占用 4 字节内存。 -- double:双精度浮点型,通常占用 8 字节内存。 -- long double:扩展精度浮点型,通常占用 12 或 16 字节。 - -```c -float pi = 3.14f; -double e = 2.718281828459045; -``` - -#### 枚举类型(Enumerated Types) - -枚举类型是一种用户定义的类型,用于定义一组命名的整型常量。主要是为了更好的描述数据,并简化代码。 - -```c -enum Day { SUNDAY, MONDAY, TUESDAY, WEDNESDAY, THURSDAY, FRIDAY, SATURDAY }; -enum Day today = WEDNESDAY; -``` - -#### void 类型 - -void 类型表示“无类型”,通常用于函数返回类型表示该函数不返回任何值。**注意:部分教程说函数返回值 void 可以省略并不是好习惯!** - -```c -void sayHello() { - printf("Hello, World!\n"); -} -``` - -#### 指针类型 - -指针是 C 语言中的重要类型,用于存储内存地址,也是 c 语言入门学习的噩梦之一。指针的类型决定了它指向的变量类型。 - -```c -int x = 10; -int *ptr = &x; // ptr 是一个指向整数的指针,存储 x 的地址 -``` - -### 类型转换 - -在 C 语言中,有时候需要将一种类型的数据转换为另一种类型的数据。类型转换可以分为`隐式类型转换`和`显式类型转换`。 - -#### 隐式类型转换 - -当你在表达式中混合使用不同类型的变量时,编译器会自动将它们转换为相同的类型,以避免数据损失。这种转换称为隐式类型转换。说白了就是编译器帮你自动添加了一些代码,完成了这个任务,不需要你来操心! 
- -```c -int a = 5; -double b = 3.2; -double result = a + b; // a 被隐式转换为 double 类型 -``` - -但是需要注意的是,并不是所有同类类型都可以隐式转换!编译器通常会将“窄”类型(如 int)转换为“宽”类型(如 double),以确保精度。 - -#### 显式类型转换(强制类型转换) - -有时候,常规的隐式转换已经不能满足需求,你需要手动将一种类型的数据转换为另一种类型,这被称为显式类型转换。具体的语法形式为: - -```c -new_type variable = (new_type)expression; -``` - -```c -double pi = 3.14159; -int truncated_pi = (int)pi; // 将 double 转换为 int,结果为3 -``` - -这时候你可能发现了,pi 的小数部分被抹去了,因此强制转换可能会导致数据丢失,所以使用强制转换时要小心,确保了解数据的潜在变化。 - -### 来点例子 - -- 整数到浮点数的转换 - -```c -#include - -int main() { - int apples = 10; - double price = 1.5; - double total_cost = apples * price; // 隐式类型转换,apples 被转换为 double - printf("Total cost: $%.2f\n", total_cost); - return 0; -} -``` - -在这个例子中,`apples` 是一个整型,`price` 是一个浮点型。在计算`total_cost`时,`apples` 被隐式转换为 **double** 类型以进行浮点数乘法运算。 - -- 强制类型转换 - -```c -#include - -int main() { - double average = 85.6; - int rounded_average = (int)average; // 强制转换为 int,结果为 85 - printf("Rounded average: %d\n", rounded_average); - return 0; -} -``` - -`average` 是一个`double`类型。在将其转换为`int`时,强制转换 **截断了小数部分** ,使得`rounded_average`变为 85。 - -- 指针类型转换 - -```c -#include - -int main() { - int a = 42; - void *ptr = &a; // 将 int 指针转换为 void 指针 - int *int_ptr = (int *)ptr; // 强制转换回 int 指针 - printf("Value of a: %d\n", *int_ptr); - return 0; -} -``` - -这里,`ptr` 原本是一个`void`指针,可以指向任何类型的变量。我们将它强制转换回 int 指针,以正确地访问变量 a 的值。 - -### 陷阱和注意点 - -- 溢出问题:在类型转换时,特别是从大类型转换为小类型时,要注意溢出问题。例如,将一个 long 值强制转换为 short 时,如果 long 值超出了 short 的范围,可能会得到一个错误的结果。 - -```c -long big_number = 1234567890L; -short small_number = (short)big_number; // 可能导致溢出,结果不可预测 -``` - -- 截断问题:当从浮点类型转换为整数类型时,小数部分会被截断。要确保这种截断是你想要的结果。 - -```c -double pi = 3.14159; -int truncated_pi = (int)pi; // 结果为 3,截断了小数部分 -``` - -- 指针类型的转换:指针的类型转换需要特别小心。如果你将一个指向某种类型的指针转换为另一种类型的指针,可能会导致未定义行为。 - -```c -int a = 10; -double *ptr = (double *)&a; // 可能导致指针错误访问 -``` - -## 数组 - -刚刚我们讲解的都是单个变量,那么如果我要存储一组类型相同的变量,难道只能够定义 n 个变量吗?显然不用,这个时候就要请出数组了! 
- -数组是 C 语言中一个非常重要的概念,它让我们可以处理一组相同类型的数据。数组在内存中是连续存储的,可以高效地管理和访问大量数据。 - -### 一维数组 - -#### 定义和初始化 - -一维数组是最基本的数组类型,存储一组相同类型的数据。 - -```c -int arr[5]; // 定义一个包含5个整型元素的数组,中括号内的数字即为数组大小 -``` - -你可以在定义数组时进行初始化: - -```c -int arr[5] = {1, 2, 3, 4, 5}; // 初始化一个包含5个元素的数组 -``` - -如果你不指定数组的大小,编译器会根据初始化列表自动推导大小: - -```c -int arr[] = {1, 2, 3}; // 自动推导数组大小为3 -``` - -需要注意的是,如果你部分初始化数组,那么未初始化的元素会被设置为 0。 - -```c -int arr[5] = {1, 2}; // 剩余元素自动初始化为0 -``` - -#### 访问元素 - -这里很明确,我们可以通过数组的索引来访问和修改元素。这里有一个大坑, **索引从 0 开始** ,初学者往往会习惯性认为索引从 1 开始而导致错误。 - -```c -int x = arr[2]; // 获取数组的第三个元素,值为3 -arr[0] = 10; // 将第一个元素的值改为10 -``` - -### 多维数组 - -多维数组用于存储表格或矩阵形式的数据。其中,二维数组是最常见的多维数组。 - -#### 定义&初始化 - -二维数组可以看作是“数组的数组”,即每个元素本身又是一个数组。来个简单的例子: - -```c -int matrix[3][4]; // 定义一个3行4列的二维数组 -``` - -同理,二维数组的初始化可以通过嵌套的花括号来完成: - -```c -int matrix[3][4] = { - {1, 2, 3, 4}, - {5, 6, 7, 8}, - {9, 10, 11, 12} -}; -``` - -当然,如果你不明确行列也是可以的,编译器会自动处理行和列的关联,就像下面这样: - -```c -int matrix[3][4] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; -``` - -#### 访问 - -由访问一维数组需要一个索引可以推断,访问二维数组的元素需要两个索引,分别表示行和列。 - -```c -int value = matrix[1][2]; // 获取第二行第三列的值,值为7 -matrix[2][3] = 15; // 将第三行第四列的值设为15 -``` - -多维数组的元素在内存中是连续存储的,**行优先(row-major)** 顺序。即第一行的所有元素存储在内存中,然后是第二行,依此类推。 - -### 指针数组和数组指针 - -#### 指针数组 - -指针数组是一个数组,其中的每个元素都是指针。它通常用于存储字符串数组或其他需要动态分配内存的数据。 - -```c -char *strArr[3]; // 定义一个字符指针数组 -strArr[0] = "Hello"; -strArr[1] = "World"; -strArr[2] = "C"; -``` - -你可以通过数组索引访问这些字符串: - -```c -printf("%s\n", strArr[1]); // 输出 "World" -``` - -#### 数组指针 - -数组指针是指向数组的指针。它允许我们通过指针操作数组。 - -```c -int arr[5] = {1, 2, 3, 4, 5}; -int (*p)[5] = &arr; // p 是一个指向包含5个整数的数组的指针 -``` - -你可以通过指针访问数组元素: - -```c -int x = (*p)[2]; // 获取第三个元素,值为3 -``` - -#### 指针运算 - -由于数组名实际上是指向数组首元素的指针,因此可以进行指针运算: - -```c -int arr[5] = {1, 2, 3, 4, 5}; -int *p = arr; - -printf("%d\n", *(p + 2)); // 输出3,即arr[2] -``` - -p + 2 表示指针 p 向后移动两个元素,然后通过\*解引用访问该元素的值。 - -### 二级数组(指针的数组) - -二级数组通常用于表示指针的数组,特别是在处理二维数组或指针数组时。 - -#### 定义 and 用法 - -一个简单的二级数组是一个指向指针的指针(char **或 int **)。这种结构可以用于 **动态分配** 二维数组或处理字符指针数组。 - -```c -char *lines[] = {"line1", "line2", "line3"}; -char **p = lines; // p 是一个指向字符指针的指针 -``` - -你可以使用二级指针访问或修改指针数组的内容: - -```c -printf("%s\n", p[1]); // 输出 "line2" -p[2] = "new line3"; // 修改第三个指针所指向的字符串 -``` - -#### 动态分配 - -二级指针非常适合动态分配二维数组,下面是一个典型例子: - -```c -#include -#include - -int main() { - int rows = 3; // 矩阵的行数 - int cols = 4; // 矩阵的列数 - - // 动态分配一个指向指针的数组,用于存储每一行的指针 - int **matrix = (int **)malloc(rows * sizeof(int *)); - - // 为每一行动态分配列数的内存 - for (int i = 0; i < rows; i++) { - matrix[i] = (int *)malloc(cols * sizeof(int)); - } - - // 初始化矩阵 - for (int i = 0; i < rows; i++) { - for (int j = 0; j < cols; j++) { - matrix[i][j] = i * cols + j; // 将矩阵元素赋值为它的线性索引 - } - } - - // 打印矩阵 - printf("矩阵内容:\n"); - for (int i = 0; i < rows; i++) { - for (int j = 0; j < cols; j++) { - printf("%d ", matrix[i][j]); // 打印每个元素 - } - printf("\n"); // 换行 - } - - // 释放内存 - for (int i = 0; i < rows; i++) { - free(matrix[i]); // 释放每一行的内存 - } - free(matrix); // 释放存储行指针的数组 - - return 0; // 返回成功状态 -} -``` - -```txt -0 1 2 3 -4 5 6 7 -8 9 10 11 -``` - -### 柔性数组(Flexible Array Members) - -柔性数组成员是一种 C99 标准引入的高级特性,允许在结构体中定义一个可变长度的数组。这种数组没有固定的大小,需要通过动态内存分配来使用。 - -#### 定义&使用 - -```c -#include -#include - -struct FlexibleArray { - int length; - int array[]; // 柔性数组成员 -}; - -int main() { - int n = 5; - - // 动态分配内存,包含结构体和数组的总大小 - struct FlexibleArray *fa = malloc(sizeof(struct FlexibleArray) + n * sizeof(int)); - - fa->length = n; - for (int i = 0; i < n; i++) { - fa->array[i] = i * 2; // 初始化数组 - } - - // 打印数组 - for (int i = 
0; i < fa->length; i++) { - printf("%d ", fa->array[i]); - } - printf("\n"); - - free(fa); // 释放内存 - - return 0; -} -``` - -```txt -0 2 4 6 8 -``` - -需要格外注意的是,柔性数组成员 array[]在定义时不指定大小,内存分配时根据需要来决定数组的实际大小! - -### 静态数组和动态数组 - -其实,上面的很多例子中已经出现了这两个概念,那么下面就是具体的讲解。 - -#### 静态数组 - -静态数组是在 **编译时** 分配内存的,数组的大小在程序编译时就已经确定,并且在程序的生命周期内保持不变。静态数组通常分配在栈内存中,大小是 **固定的**。 - -```c -int static_array[10]; // 定义一个大小为10的静态数组 -``` - -优点: - -- 访问速度快,因为它们位于栈内存中。 -- 不需要手动管理内存,编译器会自动处理内存的分配和释放。 - -缺点: - -- 大小固定,一旦定义,无法在程序运行时更改。 -- 如果数组很大,可能导致栈溢出,特别是在嵌套调用较深的情况下。 - -#### 动态数组 - -动态数组是在程序运行时根据需要动态分配内存的。它的大小可以在运行时确定,并且可以在程序的不同阶段分配和释放内存。动态数组通常分配在堆内存中。 - -```c -#include -#include - -int main() { - int n = 10; - int *dynamic_array = (int *)malloc(n * sizeof(int)); // 动态分配一个大小为n的数组 - - if (dynamic_array == NULL) { - printf("Memory allocation failed!\n"); - return 1; - } - - // 使用数组 - for (int i = 0; i < n; i++) { - dynamic_array[i] = i * 2; - } - - // 打印数组 - for (int i = 0; i < n; i++) { - printf("%d ", dynamic_array[i]); - } - printf("\n"); - - // 释放内存 - free(dynamic_array); - - return 0; -} -``` - -优点: - -- 数组的大小可以在运行时动态调整。 -- 适合处理需要在运行时确定大小的大型数据集。 - -缺点: - -- 需要手动管理内存,必须确保在使用完后释放内存(使用 free 函数),否则会导致内存泄漏。 -- 动态内存分配相比静态内存分配速度稍慢,因为它涉及到系统调用。 - -### 函数中的数组参数 - -当数组作为参数传递给函数时,它实际上是将数组的指针传递给函数。因此,函数内对数组的任何修改会影响到原数组。 - -```c -void modifyArray(int *arr, int size) { - for (int i = 0; i < size; i++) { - arr[i] = arr[i] * 2; - } -} - -int main() { - int arr[5] = {1, 2, 3, 4, 5}; - modifyArray(arr, 5); - - for (int i = 0; i < 5; i++) { - printf("%d ", arr[i]); // 输出2 4 6 8 10 - } - return 0; -} -``` - -### 复杂一点的特性 - -#### 动态调整数组大小:realloc - -动态数组的大小可以在运行时通过 `realloc`函数调整。这允许你在程序运行期间根据需要扩展或缩小数组的大小。 - -```c -#include -#include - -int main() { - int *arr = (int *)malloc(5 * sizeof(int)); - - for (int i = 0; i < 5; i++) { - arr[i] = i; - } - - // 调整数组大小 - arr = (int *)realloc(arr, 10 * sizeof(int)); - - for (int i = 5; i < 10; i++) { - arr[i] = i * 2; - } - - // 打印数组 - for (int i = 0; i < 10; i++) { - printf("%d ", arr[i]); - } - printf("\n"); - - free(arr); - return 0; -} -``` - -需要注意,realloc 可能会移动数组到新的内存位置,因此返回的指针可能不同于原来的指针。如果 realloc 失败,返回 NULL,旧的内存保持不变。 - -### 数组中的注意点 - -- 数组越界 - -数组越界是 C 语言中常见且危险的错误,访问数组边界之外的内存可能导致未定义行为或程序崩溃。通常情况下你会获得名为 Segmentation Fault 的错误,也就是臭名昭著的段错误! 
- -```c -int arr[5] = {1, 2, 3, 4, 5}; -int value = arr[5]; // 错误:越界访问,而且是由最容易犯得记错索引导致的 -``` - -- 指针和数组的关系 - -数组名在大多数表达式中会被转换为指向其第一个元素的指针。例如: - -```c -int arr[5] = {1, 2, 3, 4, 5}; -int *p = arr; // p 指向 arr[0] -``` - -但是,数组名并不是指针,它是一个常量指针,不能被修改。 - -## 字符串 - -很多时候我们需要让变量能够存储一个句子,那么就需要用到字符串了。C 语言中的字符串体系是一个非常重要的概念,因为它涉及到如何存储、处理和操作文本数据。在 C 语言中,字符串并不像在其他高级语言中那样是一个独立的数据类型,而是一个字符数组。 - -### 表示字符串 - -在 C 语言中,字符串是由字符数组表示的,且以 **空字符(\0)** 结尾。空字符标志着字符串的结束, **因此它的长度比实际字符数多一位**。 - -### 定义与初始化 - -```c -char str1[] = "Hello, World!"; -``` - -这里我们并没有添加\0,但是仍然是正确的,因为编译器会自动添加\0,因此 str1 的实际大小是 14 个字符(包括\0)。这种初始化方式是最常用的,也是最安全的,因为它 **自动处理了字符串的结束标志**。 - -当然了,如果你不嫌麻烦,也可以逐字符初始化一个字符串,虽然这有点呆: - -```c -char str2[] = {'H', 'e', 'l', 'l', 'o', '\0'}; -``` - -不过这种方式也更灵活,但方便与灵活不可兼得,这样容易出错,尤其是在手动忘记添加\0 时,当场爆炸。 - -你还可以定义一个指向字符串的指针: - -```c -char *str3 = "Hello, World!"; -``` - -这里 str3 是一个指向字符串常量的指针。需要注意的是,这种方式定义的字符串通常存储在只读内存区,因此不能修改它的内容,可以加上 const 修饰符表明这是一个常量。 - -### 常见操作 - -C 语言标准库提供了一组函数来处理字符串。这些函数大多在``头文件中定义。需要注意的是,c 语言提供的很多操作函数其实是不安全的,在 - -#### `strlen`:获取字符串长度 - -`strlen 用于获取字符串的长度 **(不包括末尾的\0)**。 - -```c -#include -#include - -int main() { - char str[] = "Hello, World!"; - int length = strlen(str); - printf("Length of the string: %d\n", length); - return 0; -} -``` - -```txt -Length of the string: 13 -``` - -#### `strcpy`:复制字符串 - -strcpy 函数将一个字符串复制到另一个字符串中。 - -```c -#include -#include - -int main() { - char src[] = "Hello"; - char dest[20]; // 确保目标数组足够大,不然会溢出 - strcpy(dest, src); - printf("Copied string: %s\n", dest); - return 0; -} -``` - -```txt -Copied string: Hello -``` - -#### `strcat`:连接字符串 - -strcat 函数将两个字符串连接在一起。 - -```c -#include -#include - -int main() { - char str1[20] = "Hello"; - char str2[] = ", World!"; - strcat(str1, str2); // 确保目标数组有足够的空间存储连接后的字符串,包括空字符\0。 - printf("Concatenated string: %s\n", str1); - return 0; -} -``` - -```txt -Concatenated string: Hello, World! -``` - -#### `strcmp`:比较字符串 - -strcmp 函数用于比较两个字符串的字典顺序。 - -```c -#include -#include - -int main() { - char str1[] = "Apple"; - char str2[] = "Banana"; - int result = strcmp(str1, str2); - - if (result < 0) { - printf("str1 is less than str2\n"); - } else if (result > 0) { - printf("str1 is greater than str2\n"); - } else { - printf("str1 is equal to str2\n"); - } - return 0; -} -``` - -```txt -str1 is less than str2 -``` - -strcmp 返回一个整数值,如果第一个字符串在字典顺序上小于、等于或大于第二个字符串,则分别返回负值、0 或正值。 - -#### `strchr` 和 `strstr`:查找字符或子串 - -strchr 用于查找字符串中某个字符的第一次出现。 -strstr 用于查找一个字符串中第一次出现的子串。 - -```c -#include -#include - -int main() { - char str[] = "Hello, World!"; - char *pos = strchr(str, 'W'); - - if (pos != NULL) { - printf("Character found at position: %ld\n", pos - str); - } else { - printf("Character not found.\n"); - } - - char *substr = strstr(str, "World"); - - if (substr != NULL) { - printf("Substring found: %s\n", substr); - } else { - printf("Substring not found.\n"); - } - - return 0; -} -``` - -```txt -Character found at position: 7 -Substring found: World! -``` - -### 字符串操作中的注意点 - -很多前面都有,但还是反复强调! 
- -#### 字符数组大小 - -在定义字符数组时,一定要确保数组大小足够容纳字符串和空字符\0。例如: - -```c -char str[5] = "Hello"; // 错误:数组长度不足 -``` - -这个例子会导致缓冲区溢出,因为"Hello"需要 6 个字符空间(包括\0)。 - -#### 字符串常量和可变性 - -字符串常量(例如"Hello")通常存储在只读内存中,因此不能通过指针修改它们。如果你尝试修改一个字符串常量,会导致运行时错误: - -```c -char *str = "Hello"; -str[0] = 'h'; // 错误:可能导致程序崩溃 -``` - -#### 缓冲区溢出 - -字符串操作时的缓冲区溢出是 C 语言中非常常见且危险的错误。操作字符串时务必确保目标缓冲区的大小足够。例如: - -```c -char str1[10] = "Hello"; -char str2[] = "World!"; -strcat(str1, str2); // 错误:str1 缓冲区溢出 -``` - -这种错误会导致未定义行为,甚至程序崩溃。 - -### 来几个经典的例子 - -#### 反转字符串 - -```c -#include -#include - -void reverse(char str[]) { - int n = strlen(str); - for (int i = 0; i < n / 2; i++) { - char temp = str[i]; - str[i] = str[n - i - 1]; - str[n - i - 1] = temp; - } -} - -int main() { - char str[] = "Hello, World!"; - reverse(str); - printf("Reversed string: %s\n", str); - return 0; -} -``` - -```txt -Reversed string: !dlroW ,olleH -``` - -#### 检查回文字符串 - -```c -#include -#include - -int isPalindrome(char str[]) { - int n = strlen(str); - for (int i = 0; i < n / 2; i++) { - if (str[i] != str[n - i - 1]) { - return 0; - } - } - return 1; -} - -int main() { - char str[] = "madam"; - if (isPalindrome(str)) { - printf("The string is a palindrome.\n"); - } else { - printf("The string is not a palindrome.\n"); - } - return 0; -} -``` - -```txt -The string is a palindrome. -``` - -## 总结 - -今天的内容很多,需要好好消化,希望你能够理解并应用到你的实际编程项目中。 - -**人生是一条不断探索的旅程,充满了选择与变化,体验着喜悦与挑战,最终形成我们独特的故事。** diff --git a/doc/extra/day05.md b/doc/extra/day05.md deleted file mode 100644 index 88ce30f7..00000000 --- a/doc/extra/day05.md +++ /dev/null @@ -1,512 +0,0 @@ -# 从零入门 C 语言: Day5 - 标准输入 - -## 引入 - -在 C 语言中,输入是从用户或文件等外部源获取数据并存储到程序变量中的过程。C 语言提供了多种方式来获取输入数据,包括标准输入函数、文件输入函数以及低级别的系统输入函数。 - -## 标准输入函数 - -### `scanf` - -`scanf`是 C 语言中最常用的标准输入函数,它允许从标准输入(通常是键盘)中读取格式化的数据,并将这些数据存储到变量中。 - -```c -int scanf(const char *format, ...); -``` - -- format:指定要读取的输入数据类型的格式字符串(例如"%d"表示整数,"%f"表示浮点数)。 -- 返回值:返回成功读取的变量数量。如果读取失败,返回值为 EOF。 - -举个例子: - -```c -#include - -int main() { - int num; - float f; - - printf("Enter an integer and a float: "); - scanf("%d %f", &num, &f); - - printf("You entered: %d and %.2f\n", num, f); - return 0; -} -``` - -需要注意的是: - -- scanf 需要传入变量的地址(使用&符号),因为它需要修改这些变量的值。 -- scanf 会忽略输入数据中的空格、换行符和制表符,但可以用空格等分隔符来读取多项数据。 -- 如果输入的格式与指定的格式不匹配,可能导致读取失败或数据错误。 - -### `fscanf` - -fscanf 与 scanf 类似,但它是从文件流中读取格式化数据。 - -```c -int fscanf(FILE *stream, const char *format, ...); -``` - -- stream:文件流指针,指定要读取数据的文件。 -- 其他参数与 scanf 相同。 - -来个基础示例 - -```c -#include - -int main() { - FILE *file = fopen("input.txt", "r"); - if (file == NULL) { - printf("Error opening file!\n"); - return 1; - } - - int num; - fscanf(file, "%d", &num); - printf("Number from file: %d\n", num); - - fclose(file); - return 0; -} -``` - -切记,使用 fscanf 时,确保文件已成功打开,并在使用完毕后 **正确关闭文件**。 - -### `sscanf` - -从名字可以看出,sscanf 从字符串中读取格式化数据,而不是从标准输入或文件。 - -```c -int sscanf(const char *str, const char *format, ...); -``` - -- str:要读取数据的字符串。 -- 其他参数与 scanf 相同。 - -```c -#include - -int main() { - char input[] = "42 3.14"; - int num; - float f; - - sscanf(input, "%d %f", &num, &f); - printf("Parsed values: %d and %.2f\n", num, f); - - return 0; -} -``` - -## 字符输入函数 - -### `getchar` - -getchar 从标准输入中读取一个字符。它是一个简单的字符输入函数,通常用于逐字符读取输入。 - -```c -int getchar(void); -``` - -- 返回值:返回读取的字符(作为 int 类型)。如果遇到输入结束(EOF),返回 EOF。 - -```c -#include - -int main() { - char c; - - printf("Enter a character: "); - c = getchar(); - - printf("You entered: %c\n", c); - return 0; -} -``` - -需要格外注意的是,getchar 不会跳过空格和换行符,它会逐字符读取每一个输入字符。由于返回值为 int,你需要将其转换为 char 类型来使用。 - 
-### `fgetc` - -fgetc 与 getchar 类似,但它用于从文件中读取一个字符。 - -```c -int fgetc(FILE *stream); -``` - -- stream:要从中读取字符的文件流指针。 -- 返回值:返回读取的字符(作为 int 类型)。如果遇到文件结束或错误,返回 EOF。 - -```c -#include - -int main() { - FILE *file = fopen("input.txt", "r"); - if (file == NULL) { - printf("Error opening file!\n"); - return 1; - } - - char c; - while ((c = fgetc(file)) != EOF) { - putchar(c); // 打印读取的字符 - } - - fclose(file); - return 0; -} -``` - -注意点: - -- fgetc 用于从文件流中逐字符读取数据,适用于处理文件内容的逐行或逐字符处理。 -- 读取时会自动移动文件指针,逐字符读取下一个字符。 - -### `getc` - -getc 与 fgetc 功能相同,但可能实现方式略有不同。getc 通常用于从文件流中读取字符。 - -```c -int getc(FILE *stream); -``` - -- stream:文件流指针。 -- 返回值:返回读取的字符(作为 int 类型),或在遇到 EOF 时返回 EOF。 - -```c -#include - -int main() { - FILE *file = fopen("input.txt", "r"); - if (file == NULL) { - printf("Error opening file!\n"); - return 1; - } - - char c; - while ((c = getc(file)) != EOF) { - putchar(c); // 打印读取的字符 - } - - fclose(file); - return 0; -} -``` - -### `ungetc` - -ungetc 将字符“放回”到输入流中,使其成为下一个要读取的字符。这在某些解析场景中非常有用。 - -```c -int ungetc(int c, FILE *stream); -``` - -- c:要放回的字符。 -- stream:文件流指针。 -- 返回值:返回放回的字符,若失败返回 EOF。 - -```c -#include - -int main() { - int c; - FILE *file = fopen("input.txt", "r"); - - if (file == NULL) { - printf("Error opening file!\n"); - return 1; - } - - c = fgetc(file); - if (c != EOF) { - ungetc(c, file); // 将字符放回流中 - c = fgetc(file); // 再次读取同一个字符 - printf("Character read again: %c\n", c); - } - - fclose(file); - return 0; -} -``` - -ungetc 最多只能将一个字符放回流中,放回后可以再次读取同一个字符。 -在使用 ungetc 之前,**确保放回的字符与流的状态一致**。 - -## 行输入函数 - -### `gets`(不推荐) - -gets 从标准输入读取一行字符串,直到遇到换行符或文件结束符。由于存在缓冲区溢出的风险,**gets 已被 C11 标准弃用**。 - -```c -char *gets(char *str); -``` - -- str:指向接收输入字符串的缓冲区指针。 -- 返回值:返回输入字符串指针(str),如遇到 EOF 则返回 NULL。 - -Outdated!!! - -```c -#include - -int main() { - char str[100]; - - printf("Enter a string: "); - gets(str); - - printf("You entered: %s\n", str); - return 0; -} -``` - -由于 gets 不检查输入是否超过缓冲区大小,可能导致缓冲区溢出并引发安全问题,强烈建议不要使用 gets,应使用更安全的 fgets。 - -### `fgets` - -fgets 是读取一行输入的推荐方法,它可以避免缓冲区溢出问题。 - -```c -char *fgets(char *str, int n, FILE *stream); -``` - -- str:指向接收输入字符串的缓冲区指针。 -- n:要读取的最大字符数(包括终止符)。 -- stream:文件流指针,stdin 表示标准输入。 -- 返回值:返回 str,如果遇到 EOF 或发生错误,返回 NULL。 - -```c -#include - -int main() { - char str[100]; - - printf("Enter a string: "); - if (fgets(str, 100, stdin) != NULL) { - printf("You entered: %s", str); - } else { - printf("Error reading input.\n"); - } - - return 0; -} -``` - -注意,fgets 会保留输入的换行符\n,如果你不想保留换行符,可以手动去掉它: - -```c -str[strcspn(str, "\n")] = '\0'; -``` - -fgets 是 **安全** 的,适合用于读取输入行或处理来自文件的文本行。 - -## 低级别输入函数 - -C 语言还提供了一些低级别的输入函数,这些函数直接与操作系统交互,通常用于更底层的输入/输出操作。 - -### `read` - -read 是 UNIX 系统调用,用于从文件描述符读取原始数据。它允许对低级别的文件操作进行更多控制,通常在系统编程中使用。 - -```c -ssize_t read(int fd, void *buf, size_t count); -``` - -- fd:文件描述符,表示从哪个文件或输入源读取数据。 -- buf:指向接收读取数据的缓冲区指针。 -- count:要读取的最大字节数。 -- 返回值:返回读取的字节数,如果遇到 EOF 返回 0,如遇错误返回-1。 - -```c -#include -#include -#include - -int main() { - char buffer[100]; - int fd = open("input.txt", O_RDONLY); - - if (fd == -1) { - perror("Error opening file"); - return 1; - } - - ssize_t bytesRead = read(fd, buffer, sizeof(buffer) - 1); - - if (bytesRead == -1) { - perror("Error reading file"); - close(fd); - return 1; - } - - buffer[bytesRead] = '\0'; // 确保字符串以空字符结束 - printf("Read from file: %s\n", buffer); - - close(fd); - return 0; -} -``` - -贴近操作系统是有代价的:不会自动处理文本行结束符或进行缓冲区管理,因此需要手动管理输入数据。 - -### `getch` 和 `getche` - -这两个函数用于从标准输入读取单个字符,并且不需要按下回车键即可输入。它们通常用于处理键盘输入,特别是在控制台应用程序中。 - -- getch:读取单个字符,不显示在屏幕上。 -- getche:读取单个字符,并显示在屏幕上。 - -这些函数不是标准 C 
库的一部分,通常在 Windows 环境中由`conio.h`提供。 - -```c -int getch(void); -int getche(void); -``` - -```c -#include -#include - -int main() { - char c; - - printf("Press a key: "); - c = getch(); // 使用 getch() 获取字符,但不显示 - printf("\nYou pressed: %c\n", c); - - printf("Press another key: "); - c = getche(); // 使用 getche() 获取字符,并显示 - printf("\nYou pressed: %c\n", c); - - return 0; -} -``` - -**Windows Only!** - -## 文件输入函数 - -C 语言还提供了一些用于从文件中读取数据的函数。 - -### `fread` - -fread 用于从文件中读取原始数据块,它通常用于二进制文件的读取。 - -```c -size_t fread(void *ptr, size_t size, size_t nmemb, FILE *stream); -``` - -- ptr:指向接收读取数据的缓冲区指针。 -- size:每个数据块的大小(字节数)。 -- nmemb:要读取的块的数量。 -- stream:文件流指针。 -- 返回值:返回成功读取的块数。 - -```c -#include - -int main() { - FILE *file = fopen("data.bin", "rb"); - if (file == NULL) { - printf("Error opening file!\n"); - return 1; - } - - int data[5]; - size_t bytesRead = fread(data, sizeof(int), 5, file); - - for (size_t i = 0; i < bytesRead; i++) { - printf("data[%zu] = %d\n", i, data[i]); - } - - fclose(file); - return 0; -} -``` - -- fread 直接读取原始数据块,而不是格式化的数据。适用于读取二进制文件或自定义文件格式。 -- 确保缓冲区足够大以容纳读取的数据,并检查返回值以确定是否成功读取。 - -### `getline` - -getline 用于从文件或标准输入中读取整行数据,包括换行符。它动态分配缓冲区以适应读取的数据行大小。 - -```c -ssize_t getline(char **lineptr, size_t *n, FILE *stream); -``` - -- lineptr:指向缓冲区的指针指针。如果指针指向 NULL,getline 会自动分配缓冲区。 -- n:指向缓冲区大小的指针。 -- stream:文件流指针。 -- 返回值:返回读取的字符数(包括换行符),如果遇到 EOF 返回-1。 - -```c -#include -#include - -int main() { - char *line = NULL; - size_t len = 0; - ssize_t read; - - FILE *file = fopen("input.txt", "r"); - if (file == NULL) { - printf("Error opening file!\n"); - return 1; - } - - while ((read = getline(&line, &len, file)) != -1) { - printf("Retrieved line of length %zu: %s", read, line); - } - - free(line); - fclose(file); - return 0; -} -``` - -**getline 是 POSIX 标准的扩展函数** - -## 安全输入(以`scanf_s`为例) - -scanf_s 是 scanf 函数的安全版本,它通过增加对输入长度的控制,来避免输入数据超过目标缓冲区的大小,从而降低缓冲区溢出的风险。 - -```c -int scanf_s(const char *format, ...); -``` - -- format:格式字符串,与 scanf 的格式字符串类似,用于指定输入的数据类型。 -- ...:可变参数,指定要存储输入数据的变量地址。对于某些类型的输入(如字符串或字符数组),需要额外指定目标缓冲区的大小。 - -```c -#include - -int main() { - char buffer[10]; - int num; - - printf("Enter a number and a string: "); - scanf_s("%d %9s", &num, buffer, sizeof(buffer)); - - printf("You entered: %d and %s\n", num, buffer); - return 0; -} -``` - -在这个示例中,%d 用于读取一个整数,%9s 用于读取最多 9 个字符的字符串。字符串后面的 sizeof(buffer) 指定了缓冲区的大小,以确保不会读入超过缓冲区容量的字符串。 -scanf_s 在读取字符串时要求提供额外的参数,指定目标缓冲区的大小。这个参数是 **必需** 的。 - -## 总结 - -C 语言提供了多种输入方式来满足不同的需求,从简单的标准输入函数到复杂的文件输入函数,每种方式都有其适用场景和使用注意点: - -- 标准输入函数:scanf、fscanf、sscanf 用于读取格式化的数据,适合从标准输入或文件中获取结构化数据。 -- 字符输入函数:getchar、fgetc、getc 用于逐字符读取输入,适合处理逐字符输入或文件内容。 -- 行输入函数:fgets 和 getline 用于读取整行数据,fgets 安全性高,getline 适合处理动态长度行。 -- 低级别输入函数:read、getch、getche 用于系统级编程或控制台输入,适合需要直接与硬件或操作系统交互的场景。 -- 文件输入函数:fread 用于读取二进制数据块,适合处理二进制文件或自定义格式的文件。 - -掌握这些输入函数,可以让你在 C 语言编程中灵活地处理各种输入数据,满足不同的编程需要 AwA diff --git a/doc/extra/rrii.md b/doc/extra/rrii.md deleted file mode 100644 index 37e88f86..00000000 --- a/doc/extra/rrii.md +++ /dev/null @@ -1,263 +0,0 @@ -# 临时释放锁:Antilock 模式 (RRII) - -## 前言 - -在多线程编程中,正确管理线程同步是确保程序稳定性和性能的关键。C++ 提供了多种工具来帮助开发者实现线程同步,例如 `std::mutex`、`std::lock_guard` 和 `std::unique_lock` 等。这些工具虽然强大,但在某些复杂场景下,可能需要更灵活的锁管理方式。比如说有的时候我在加锁后,在局部代码中需要释放锁,然后后续运行又需要加锁,这个时候我们虽然可以通过`unlock`和`lock`组合完成,但是代码变长后难免会出现遗忘的情况,从而产生错误。那么,本文将介绍一种名为“Antilock”的模式,它能够在需要时暂时释放锁,并在操作完成后自动重新获取锁,从而避免潜在的死锁和遗忘问题。 - -本文的理念来自 [Antilock 模式](https://devblogs.microsoft.com/oldnewthing/20240814-00/?p=110129) - -## 概念 - -### 互斥锁与 RAII 模式 - -在 C++ 中,互斥锁(mutex)用于保护共享资源,防止多个线程同时访问而导致数据不一致。为了简化锁的管理,C++ 标准库引入了 
RAII(Resource Acquisition Is Initialization)模式。RAII 模式通过在对象的构造函数中获取资源,在析构函数中释放资源,确保资源管理的安全性。 -例如,`std::lock_guard` 就是一个常用的 RAII 类型,它会在构造时自动锁定互斥锁,并在析构时自动解锁: - -```cpp -std::mutex mtx; -{ - std::lock_guard guard(mtx); - // 这里的代码块在锁的保护下执行 -} // 离开作用域时,锁自动释放 -``` - -### Antilock 模式 - -Antilock 模式是一种反作用的锁管理策略,它的作用是暂时释放一个已经锁定的互斥锁,并在特定操作完成后重新获取该锁。这种模式特别适合需要在锁的保护下执行部分操作,但又需要在某些时候释放锁以避免死锁的场景。 - -那么这种机制是否可以称为`RRII`呢?即`Resource Release Is Initialization`! - -## 实现 - -### 基本实现 - -我们首先定义一个简单的 Antilock 模板类,它可以接受一个互斥锁对象,在构造时解锁该互斥锁,并在析构时重新锁定它: - -```cpp -template -class Antilock { -public: - Antilock() = default; - - explicit Antilock(Mutex& mutex) - : m_mutex(std::addressof(mutex)) { - if (m_mutex) { // 这里做一个检查如果互斥锁有效 - m_mutex->unlock(); - std::cout << "[Antilock] Lock released.\n"; - } - } - - ~Antilock() { - if (m_mutex) { - m_mutex->lock(); - std::cout << "[Antilock] Lock reacquired.\n"; - } - } - -private: - Mutex* m_mutex = nullptr; // 指向互斥锁的指针 -}; -``` - -这个简单的 `Antilock` 类在构造时解锁传入的 `Mutex` 对象,并在析构时重新锁定它。通过这种方式,可以安全地在一个作用域内暂时释放锁,执行其他操作。 - -### 支持 Guard 的扩展 - -在某些情况下,我们可能不希望 `Antilock` 直接操作互斥锁,而是通过一个管理锁的 `Guard` 对象来间接操作锁。例如,我们可以使用 `std::unique_lock` 作为 Guard,这样可以更加灵活。 - -下面是一个简单的例子: - -```cpp -template -class Antilock { -public: -Antilock() = default; - - explicit Antilock(Guard& guard) - : m_mutex(guard.mutex()) { // 使用 Guard 的 mutex 方法获取互斥锁 - if (m_mutex) { - m_mutex->unlock(); - std::cout << "[Antilock] Lock released.\n"; - } - } - - ~Antilock() { - if (m_mutex) { - m_mutex->lock(); - std::cout << "[Antilock] Lock reacquired.\n"; - } - } - -private: - typename Guard::mutex_type* m_mutex = nullptr; -}; -``` - -在这个版本中,`Antilock` 使用 `Guard` 对象的 `mutex()` 方法来获取互斥锁的指针,从而实现对锁的间接管理。 - -## 使用场景 - -### 调用外部操作 - -假设我们有一个共享资源需要在多线程环境下访问,同时我们需要在访问资源的过程中调用一个外部操作。由于外部操作可能会导致长时间等待或死锁,我们希望在执行外部操作时释放锁,操作完成后重新获取锁。 -以下是一个使用 `Antilock` 模式的示例: - -```cpp -class Resource { -public: - void DoSomething() { - std::unique_lock lock(m_mutex); - std::cout << "[Resource] Doing something under lock.\n"; - // 临时释放锁,执行外部操作 - Antilock> antilock(lock); - m_data = ExternalOperation(); - std::cout << "[Resource] Finished doing something.\n"; - } - -private: - int ExternalOperation() { - std::this_thread::sleep_for(std::chrono::seconds(1)); // 模拟耗时操作 - std::cout << "[External] External operation completed.\n"; - return 42; // 返回一个结果 - } - - std::mutex m_mutex; // 保护共享数据的互斥锁 - int m_data = 0; // 共享数据 -}; -``` - -在这个示例中,`DoSomething` 方法使用 `std::unique_lock` 锁定互斥锁,然后使用 `Antilock` 在执行外部操作时释放锁。外部操作完成后,锁会被自动重新获取。 - -### 多线程 - -为了更好地展示 `Antilock` 模式的优势,我们可以创建多个线程同时访问共享资源 - -```cpp -int main() { -Resource resource; - - // 创建多个线程来访问共享资源 - std::vector threads; - for (int i = 0; i < 5; ++i) { - threads.emplace_back(&Resource::DoSomething, &resource); - } - - // 等待所有线程完成 - for (auto& thread : threads) { - thread.join(); - } - - return 0; - -} -``` - -在这个例子中,我们创建了多个线程,每个线程都会调用 Resource 对象的 DoSomething 方法。Antilock 模式确保了在外部操作期间,锁会被暂时释放,从而避免了可能的死锁。 - -## 完整例子 - -[CE链接](https://godbolt.org/z/7nEqT8n5M) - -```cpp -#include -#include -#include -#include -#include -#include -#include -#include - -// 定义支持 Guard 的 Antilock 模板类 -template -class Antilock { -public: - Antilock() = default; - - explicit Antilock(Guard& guard) - : m_mutex(guard.mutex()) { // 使用 Guard 的 mutex 方法获取互斥锁 - if (m_mutex) { - m_mutex->unlock(); - std::cout << "[Antilock] Lock released.\n"; - } - } - - ~Antilock() { - if (m_mutex) { - m_mutex->lock(); - std::cout << "[Antilock] Lock reacquired.\n"; - } - } - -private: - typename Guard::mutex_type* m_mutex = nullptr; // 指向互斥锁的指针 -}; - 
-// 模拟一个资源类,包含一个互斥锁和条件变量来保护和协调共享数据 -class Resource { -public: - void DoSomething() { - std::unique_lock lock(m_mutex); - - // 使用条件变量等待某些条件满足 - m_condVar.wait(lock, [this] { return m_data.load() == 0; }); - - std::cout << "[Resource] Doing something under lock.\n"; - - try { - // 临时释放锁,执行外部操作 - Antilock> antilock(lock); - m_data.store(ExternalOperation()); - } catch (const std::exception& ex) { - std::cerr << "[Error] Exception caught: " << ex.what() << "\n"; - // 处理异常的情况下,需要确保锁状态的一致性 - } - - std::cout << "[Resource] Finished doing something.\n"; - m_condVar.notify_all(); // 通知其他等待的线程 - } - -private: - int ExternalOperation() { - std::this_thread::sleep_for(std::chrono::seconds(1)); // 模拟耗时操作 - std::cout << "[External] External operation completed.\n"; - return 42; // 返回一个结果 - } - - std::mutex m_mutex; // 保护共享数据的互斥锁 - std::condition_variable m_condVar; // 用于线程同步的条件变量 - std::atomic m_data{0}; // 线程安全的共享数据 -}; - -// 主函数 -int main() { - Resource resource; - - // 创建多个线程来访问共享资源 - std::vector threads; - for (int i = 0; i < 5; ++i) { - threads.emplace_back(&Resource::DoSomething, &resource); - } - - // 等待所有线程完成 - for (auto& thread : threads) { - thread.join(); - } - - return 0; -} -``` - -## 总结 - -### 优点 - -- 灵活性:Antilock 模式允许在需要时暂时释放锁,从而执行一些可能导致死锁的操作。 -- 自动管理:通过 RAII 模式,Antilock 在作用域结束时自动重新获取锁,无需手动管理锁的状态。 - -### 注意事项 - -- 锁状态的复杂性:使用 Antilock 可能会增加锁状态的复杂性,尤其是在嵌套锁定或递归锁定的场景中。 -- 异常处理:在实现 Antilock 模式时,确保在出现异常的情况下,锁能够被正确管理,以避免锁定状态不一致的问题。 - -通过本文的介绍,希望大家能对 Antilock 模式有了更深入的了解,并能够在实际项目中应用这一模式来解决复杂的锁管理问题。如果有问题,欢迎大家在评论区中指出,新人发文,求大佬们轻喷! diff --git a/doc/loading.drawio b/doc/loading.drawio index 5bc0f8c9..9a3e311f 100644 --- a/doc/loading.drawio +++ b/doc/loading.drawio @@ -1,6 +1,6 @@ - + @@ -692,6 +692,68 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/doc/task/camera/smart_epxosure.md b/doc/task/camera/smart_epxosure.md index e974dfc8..c5581f4c 100644 --- a/doc/task/camera/smart_epxosure.md +++ b/doc/task/camera/smart_epxosure.md @@ -1,43 +1,97 @@ -# SmartExposure Class Overview +# `SmartExposure` 类概述 + +`SmartExposure` 类是 N.I.N.A.(Nighttime Imaging 'N' Astronomy)软件中的关键组件,专门用于处理成像序列。以下是其逻辑、事件处理、验证、克隆等方面的详细描述。 The `SmartExposure` class is a key component within the N.I.N.A. (Nighttime Imaging 'N' Astronomy) software, designed to handle imaging sequences. Below is a detailed breakdown of its logic, event handling, validation, cloning, and more. +--- + +## 关键功能 + ## Key Functionalities -- **Deserialization Initialization** +### **反序列化初始化** + +### **Deserialization Initialization** + +- 在反序列化过程中清除所有序列项、条件和触发器,以确保一个干净的开始。 +- Clears all sequence items, conditions, and triggers during the deserialization process to ensure a clean start. + +--- + +### **构造函数** + +### **Constructor** + +- 接受依赖项,如 `profileService`、`cameraMediator` 和 `filterWheelMediator`。 +- Initializes with dependencies like `profileService`, `cameraMediator`, and `filterWheelMediator`. + +- 初始化核心组件,包括 `SwitchFilter`、`TakeExposure`、`LoopCondition` 和 `DitherAfterExposures`。 +- Initializes core components including `SwitchFilter`, `TakeExposure`, `LoopCondition`, and `DitherAfterExposures`. + +--- + +### **事件处理** + +### **Event Handling** + +- 监控 `SwitchFilter` 属性的变化。如果滤镜在序列创建或执行过程中发生变化,则进度会被重置。 +- Monitors changes to the `SwitchFilter` property. If the filter changes during sequence creation or execution, the progress is reset. - - Clears all sequence items, conditions, and triggers during the deserialization process to ensure a clean start. 
+--- -- **Constructor** +### **错误处理** - - Accepts dependencies like `profileService`, `cameraMediator`, and `filterWheelMediator`. - - Initializes core components including `SwitchFilter`, `TakeExposure`, `LoopCondition`, and `DitherAfterExposures`. +### **Error Handling** -- **Event Handling** +- 包含管理错误行为的属性 (`ErrorBehavior`) 和重试尝试次数 (`Attempts`),确保在执行过程中进行稳健的错误管理。 +- Incorporates properties for managing error behavior (`ErrorBehavior`) and retry attempts (`Attempts`), ensuring robust error management during execution. - - Monitors changes to the `SwitchFilter` property. If the filter changes during sequence creation or execution, the progress is reset. +--- -- **Error Handling** +### **验证** - - Incorporates properties for managing error behavior (`ErrorBehavior`) and retry attempts (`Attempts`), ensuring robust error management during execution. +### **Validation** -- **Validation** +- `Validate` 方法检查所有内部组件的配置,并返回一个布尔值,指示该序列是否可以执行。 +- The `Validate` method checks the configuration of all internal components and returns a boolean indicating whether the sequence is valid for execution. - - The `Validate` method checks the configuration of all internal components and returns a boolean indicating whether the sequence is valid for execution. +--- -- **Cloning** +### **克隆** - - Provides a `Clone` method to create a deep copy of the `SmartExposure` instance, including all its associated components. +### **Cloning** -- **Duration Estimation** +- 提供 `Clone` 方法来创建 `SmartExposure` 实例的深度拷贝,包括所有关联的组件。 +- Provides a `Clone` method to create a deep copy of the `SmartExposure` instance, including all its associated components. - - Calculates the estimated duration for the sequence based on the exposure settings. +--- -- **Interrupt Handling** - - If an interruption occurs, it is rerouted to the parent sequence for consistent behavior. +### **持续时间估算** + +### **Duration Estimation** + +- 根据曝光设置计算序列的预计持续时间。 +- Calculates the estimated duration for the sequence based on the exposure settings. + +--- + +### **中断处理** + +### **Interrupt Handling** + +- 如果发生中断,它会被重新引导到父序列以保持一致的行为。 +- If an interruption occurs, it is rerouted to the parent sequence for consistent behavior. + +--- + +## 流程图 ## Flowchart +以下是 `SmartExposure` 类执行过程中关键步骤的流程图。 +Below is a flowchart outlining the key steps in the execution process of the `SmartExposure` class. + ```mermaid graph TD; A[Start] --> B{OnDeserializing}; @@ -69,3 +123,17 @@ graph TD; T --> U[Interrupt Handling]; U --> V[End]; ``` + +--- + +### 解释 + +### Explanation + +1. **OnDeserializing**: 在反序列化过程中,清除所有项目、条件和触发器,确保干净的状态。 +2. **Constructor Called**: 构造函数初始化各个核心组件。 +3. **Event Handling**: 监听属性变化,如果 `SwitchFilter` 改变,进度将被重置。 +4. **Validate**: 验证序列的有效性,确保其可以执行。 +5. **Cloning**: 创建 `SmartExposure` 的深度克隆副本。 +6. **Estimate Duration**: 计算执行过程的预计时间。 +7. **Interrupt Handling**: 如果发生中断,将其引导至父序列。 diff --git a/doc/task/guider/dither.md b/doc/task/guider/dither.md index 278a14be..981c5919 100644 --- a/doc/task/guider/dither.md +++ b/doc/task/guider/dither.md @@ -1,12 +1,30 @@ -# Dither +# `Dither` 类 -The `Dither` class in the N.I.N.A. application is designed to execute the dithering process during an astronomical imaging session. Dithering is a technique used to slightly move the telescope between exposures to reduce the impact of fixed pattern noise in the final stacked image. This class is essential for achieving higher-quality images by ensuring that noise patterns do not align across exposures. 
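Both the intro removed above and the bilingual version added below describe the same technique. For readers new to it, the following is a minimal illustrative sketch (in Python, not N.I.N.A. code) of the offset pattern dithering produces between exposures; the function name, pixel units, and offset range are invented for illustration.

```python
import random

def dither_offsets(n_exposures: int, max_offset_px: float = 5.0, seed: int = 42):
    """Generate one small random (x, y) offset per exposure (illustrative only)."""
    rng = random.Random(seed)
    return [
        (rng.uniform(-max_offset_px, max_offset_px),
         rng.uniform(-max_offset_px, max_offset_px))
        for _ in range(n_exposures)
    ]

# Each frame lands on slightly different sensor pixels, so after the frames are
# re-aligned on the stars and stacked, hot pixels and other fixed-pattern noise
# are scattered instead of adding up in the same place.
for i, (dx, dy) in enumerate(dither_offsets(5), start=1):
    print(f"Exposure {i}: offset ({dx:+.2f}, {dy:+.2f}) px, settle, then expose")
```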
+`Dither` 类在 N.I.N.A.(Nighttime Imaging 'N' Astronomy)应用程序中用于在天文摄影会话期间执行抖动过程。抖动是一种技术,通过在曝光之间轻微移动望远镜来减少固定模式噪声对最终堆叠图像的影响。该类确保抖动过程的执行,从而提高图像质量,防止噪声模式在多个曝光中对齐。 + +The `Dither` class in the N.I.N.A. (Nighttime Imaging 'N' Astronomy) application is designed to execute the dithering process during an astronomical imaging session. Dithering is a technique used to slightly move the telescope between exposures to reduce the impact of fixed pattern noise in the final stacked image. This class is essential for achieving higher-quality images by ensuring that noise patterns do not align across exposures. + +--- + +## 类概述 ## Class Overview +### 命名空间 + ### Namespace +- **命名空间:** `NINA.Sequencer.SequenceItem.Guider` - **Namespace:** `NINA.Sequencer.SequenceItem.Guider` + +- **依赖项:** + + - `NINA.Core.Model` + - `NINA.Sequencer.Validations` + - `NINA.Equipment.Interfaces.Mediator` + - `NINA.Core.Locale` + - `NINA.Profile.Interfaces` + - **Dependencies:** - `NINA.Core.Model` - `NINA.Sequencer.Validations` @@ -14,6 +32,10 @@ The `Dither` class in the N.I.N.A. application is designed to execute the dither - `NINA.Core.Locale` - `NINA.Profile.Interfaces` +--- + +### 类声明 + ### Class Declaration ```csharp @@ -26,14 +48,29 @@ The `Dither` class in the N.I.N.A. application is designed to execute the dither public class Dither : SequenceItem, IValidatable ``` +--- + +### 类属性 + ### Class Properties +- **guiderMediator**: 负责与引导硬件的通信,以执行抖动命令。 - **guiderMediator**: Handles communication with the guider hardware to execute the dithering commands. + +- **profileService**: 提供对活动配置文件设置的访问,包括引导器配置。 - **profileService**: Provides access to the active profile settings, including guider configuration. + +- **Issues**: 在类验证过程中发现的问题列表,尤其是与引导器连接状态相关的问题。 - **Issues**: A list of issues found during the validation of the class, particularly related to the connection status of the guider. +--- + +### 构造函数 + ### Constructor +构造函数初始化 `Dither` 类,并依赖于引导器中介器和配置文件服务,确保在执行抖动过程中所需的组件可用。 + The constructor initializes the `Dither` class with dependencies on the guider mediator and profile service, ensuring the necessary components are available for executing the dithering process. ```csharp @@ -41,17 +78,35 @@ The constructor initializes the `Dither` class with dependencies on the guider m public Dither(IGuiderMediator guiderMediator, IProfileService profileService) ``` +--- + +### 关键方法 + ### Key Methods +- **Execute(IProgress progress, CancellationToken token)**: 该方法通过向引导器发出抖动命令来执行抖动过程。 - **Execute(IProgress progress, CancellationToken token)**: This method executes the dithering process by issuing the dither command to the guider. + +- **Validate()**: 在尝试抖动之前,检查引导器是否连接并正常工作。 - **Validate()**: Checks if the guider is connected and operational before attempting to dither. + +- **AfterParentChanged()**: 当父级序列项发生更改时,验证引导器的连接状态。 - **AfterParentChanged()**: Validates the connection status of the guider when the parent sequence item changes. + +- **GetEstimatedDuration()**: 返回抖动过程的预计持续时间,基于配置文件中的引导器设置。 - **GetEstimatedDuration()**: Returns an estimated duration for the dithering process, based on the profile's guider settings. + +- **Clone()**: 创建 `Dither` 对象的深拷贝。 - **Clone()**: Creates a deep copy of the `Dither` object. +--- + +### 流程图:执行过程 + ### Flowchart: Execution Process -Below is a flowchart that outlines the key steps in the `Execute` method of the `Dither` class. +以下是 `Dither` 类中 `Execute` 方法的关键步骤流程图。 +Below is a flowchart outlining the key steps in the `Execute` method of the `Dither` class. 
```mermaid flowchart TD @@ -63,29 +118,70 @@ flowchart TD F --> G[End Execution] ``` +--- + +### 流程图解释 + ### Flowchart Explanation +1. **引导器是否连接?**:首先检查引导器是否已连接且准备执行命令。 + + - **否:** 如果引导器未连接,会记录一个连接问题,并中止执行。 + - **是:** 如果引导器已连接,继续执行下一步。 + 1. **Is Guider Connected?:** The process begins by checking whether the guider is connected and ready to execute commands. + - **No:** If the guider is not connected, an issue is logged indicating the connection problem, and the execution is aborted. - **Yes:** If the guider is connected, the process continues. -2. **Issue Dither Command to Guider:** The `Dither` class sends a command to the guider to perform the dithering operation. -3. **Await Dither Completion:** The system waits for the guider to complete the dithering process. -4. **End Execution:** The dithering process concludes, and control is returned to the sequence executor. + +1. **向引导器发出抖动命令**:`Dither` 类向引导器发送命令,执行抖动操作。 +1. **等待抖动完成**:系统等待引导器完成抖动过程。 +1. **结束执行**:抖动过程结束后,控制返回给序列执行器。 + +1. **Issue Dither Command to Guider:** The `Dither` class sends a command to the guider to perform the dithering operation. +1. **Await Dither Completion:** The system waits for the guider to complete the dithering process. +1. **End Execution:** The dithering process concludes, and control is returned to the sequence executor. + +--- + +### 方法详细描述 ### Detailed Method Descriptions +#### `Execute` 方法 + #### `Execute` Method +`Execute` 方法负责向引导器发出抖动命令。它依赖于 `guiderMediator` 来与引导硬件通信。该方法确保在尝试抖动之前引导器已连接并准备就绪,从而防止运行时错误。 + The `Execute` method is responsible for issuing the dither command to the guider. It relies on the `guiderMediator` to communicate with the guider hardware. The method ensures that the guider is connected and ready before attempting to dither, thus preventing runtime errors. +--- + +#### `Validate` 方法 + #### `Validate` Method +`Validate` 方法检查引导器的连接状态。如果引导器未连接,方法会在 `Issues` 列表中添加问题,用户可以查看并解决问题。此验证步骤对于确保抖动过程能够无误执行至关重要。 + The `Validate` method checks the connection status of the guider. If the guider is not connected, the method adds an issue to the `Issues` list, which can be reviewed by the user to troubleshoot the problem. This validation step is critical for ensuring that the dithering process can be executed without errors. +--- + +#### `AfterParentChanged` 方法 + #### `AfterParentChanged` Method +每当父级序列项发生更改时,都会调用 `AfterParentChanged` 方法。此方法会重新验证 `Dither` 项,以确保考虑到序列上下文中的任何变化,尤其是设备连接性方面的变化。 + The `AfterParentChanged` method is called whenever the parent sequence item is changed. It triggers a re-validation of the `Dither` item to ensure that any changes in the sequence context are taken into account, particularly in terms of equipment connectivity. +--- + +#### `GetEstimatedDuration` 方法 + #### `GetEstimatedDuration` Method +`GetEstimatedDuration` 方法返回完成抖动过程的预计时间。该估算基于活动配置文件中引导器设置的 `SettleTimeout` 值,提供抖动所需时间的大致时间线。 + The `GetEstimatedDuration` method returns the estimated time required to complete the dithering process. This estimate is based on the `SettleTimeout` value specified in the active profile's guider settings, providing a rough timeline for how long the dithering will take. diff --git a/doc/task/guider/start_guiding.md b/doc/task/guider/start_guiding.md index e404da0a..374ae30f 100644 --- a/doc/task/guider/start_guiding.md +++ b/doc/task/guider/start_guiding.md @@ -1,18 +1,39 @@ -# StartGuiding +# `StartGuiding` 类 -The `StartGuiding` class in the N.I.N.A. application is designed to initiate the guiding process during an astronomical imaging session. 
Guiding is a critical process in astrophotography, where a separate guider camera or system is used to keep the telescope precisely aligned with the target object. This class ensures that the guiding process starts correctly and optionally forces a new calibration. +`StartGuiding` 类在 N.I.N.A.(Nighttime Imaging 'N' Astronomy)应用程序中用于启动引导过程,这是天文摄影会话中的关键步骤。在天文摄影中,使用一个独立的引导相机或系统来保持望远镜与目标物体的精确对齐。该类确保引导过程正确启动,并且可以选择强制进行新的校准。 + +The `StartGuiding` class in the N.I.N.A. (Nighttime Imaging 'N' Astronomy) application is used to initiate the guiding process, which is a critical step in an astronomical imaging session. In astrophotography, a separate guiding camera or system is used to keep the telescope precisely aligned with the target object. This class ensures that the guiding process starts correctly and optionally forces a new calibration. + +--- + +## 类概述 ## Class Overview +### 命名空间 + ### Namespace +- **命名空间:** `NINA.Sequencer.SequenceItem.Guider` - **Namespace:** `NINA.Sequencer.SequenceItem.Guider` + +- **依赖项:** + + - `NINA.Core.Model` + - `NINA.Sequencer.Validations` + - `NINA.Equipment.Interfaces.Mediator` + - `NINA.Core.Locale` + - **Dependencies:** - `NINA.Core.Model` - `NINA.Sequencer.Validations` - `NINA.Equipment.Interfaces.Mediator` - `NINA.Core.Locale` +--- + +### 类声明 + ### Class Declaration ```csharp @@ -25,31 +46,62 @@ The `StartGuiding` class in the N.I.N.A. application is designed to initiate the public class StartGuiding : SequenceItem, IValidatable ``` +--- + +### 类属性 + ### Class Properties -- **guiderMediator**: Manages communication with the guider hardware, handling the start of the guiding process. -- **ForceCalibration**: A boolean flag indicating whether the guiding process should force a new calibration before starting. -- **Issues**: A list of issues identified during validation, particularly related to the guider’s connection status and calibration capability. +- **guiderMediator**: 管理与引导硬件的通信,处理引导过程的启动。 +- **guiderMediator**: Manages communication with the guiding hardware, handling the start of the guiding process. + +- **ForceCalibration**: 一个布尔值标志,指示是否应在启动引导之前强制进行新的校准。 +- **ForceCalibration**: A boolean flag indicating whether a new calibration should be forced before starting the guiding process. + +- **Issues**: 在验证过程中发现的问题列表,尤其与引导器的连接状态和校准能力相关。 +- **Issues**: A list of issues identified during validation, particularly related to the guider's connection status and calibration capability. + +--- + +### 构造函数 ### Constructor -The constructor initializes the `StartGuiding` class by setting up the connection with the guider mediator. This ensures that the class can interact with the guider system when executing the guiding process. +构造函数初始化 `StartGuiding` 类,并通过 `guiderMediator` 与引导设备建立连接。这确保了该类能够在执行引导过程中与引导系统交互。 + +The constructor initializes the `StartGuiding` class and establishes a connection with the guider hardware via the `guiderMediator`. This ensures that the class can interact with the guider system during the guiding process. ```csharp [ImportingConstructor] public StartGuiding(IGuiderMediator guiderMediator) ``` +--- + +### 关键方法 + ### Key Methods +- **Execute(IProgress progress, CancellationToken token)**: 启动引导过程,可选择强制执行新的校准。如果过程失败,将抛出异常。 - **Execute(IProgress progress, CancellationToken token)**: Starts the guiding process, optionally forcing a new calibration. If the process fails, an exception is thrown. 
-- **Validate()**: Validates the connection to the guider and checks if forced calibration is possible, updating the `Issues` list if any problems are detected. + +- **Validate()**: 验证引导器的连接,并检查是否可以强制校准,若发现问题则更新 `Issues` 列表。 +- **Validate()**: Validates the guider connection and checks whether forced calibration is possible. It updates the `Issues` list if any problems are detected. + +- **AfterParentChanged()**: 每当父级序列项发生变化时,重新验证引导器的连接和能力。 - **AfterParentChanged()**: Re-validates the guider connection and capabilities whenever the parent sequence item changes. + +- **Clone()**: 创建 `StartGuiding` 对象的副本,保留其属性和元数据。 - **Clone()**: Creates a copy of the `StartGuiding` object, preserving its properties and metadata. +--- + +### 流程图:执行过程 + ### Flowchart: Execution Process -Below is a flowchart that outlines the key steps in the `Execute` method of the `StartGuiding` class. +以下是 `StartGuiding` 类中 `Execute` 方法的关键步骤流程图。 +Below is a flowchart outlining the key steps in the `Execute` method of the `StartGuiding` class. ```mermaid flowchart TD @@ -67,37 +119,75 @@ flowchart TD J -->|Yes| L[End Execution] ``` +--- + +### 流程图解释 + ### Flowchart Explanation -1. **Is Guider Connected?**: The process begins by verifying that the guider is connected and ready. +1. **引导器是否连接?**:首先检查引导器是否已连接且准备就绪。 + + - **否:** 如果引导器未连接,抛出异常并终止过程。 + - **是:** 如果已连接,继续执行下一步。 + +1. **Is Guider Connected?**: The process begins by verifying whether the guider is connected and ready. + - **No:** If the guider is not connected, an exception is thrown, aborting the process. - **Yes:** If connected, the process continues to the next step. -2. **Force Calibration?**: Checks if the guiding process should force a new calibration. - - **Yes:** If calibration is required, the system checks if the guider can clear its current calibration. + +1. **是否强制校准?**:检查引导过程中是否需要强制执行新校准。 + + - **是:** 如果需要校准,系统将检查引导器是否可以清除当前校准数据。 + - **否:** 如果不强制校准,则直接开始引导。 + +1. **Force Calibration?**: Checks whether a new calibration should be forced during the guiding process. + + - **Yes:** If calibration is required, the system checks if the guider can clear its current calibration data. - **No:** If no calibration is forced, guiding begins without recalibration. -3. **Can Guider Clear Calibration?**: If calibration is forced, this step checks whether the guider can clear its existing calibration. + +1. **引导器能否清除校准?**:如果需要校准,此步骤检查引导器是否能清除现有校准数据。 + + - **否:** 如果不能清除校准,抛出异常。 + - **是:** 如果可以,系统开始带有校准的引导。 + +1. **Can Guider Clear Calibration?**: If calibration is required, this step checks whether the guider can clear existing calibration data. + - **No:** If the guider cannot clear calibration, an exception is thrown. - - **Yes:** If it can, the system proceeds to start guiding with calibration. -4. **Start Guiding**: The guiding process is initiated, either with or without calibration based on the previous steps. -5. **Await Guiding Completion**: The system waits for the guiding process to either succeed or fail. -6. **Guiding Started Successfully?**: A final check to confirm that guiding has started successfully. + - **Yes:** If possible, the system proceeds to start guiding with calibration. + +1. **开始引导**:根据前面的步骤,启动引导过程,无论是否带有校准。 +1. **等待引导完成**:系统等待引导过程的成功或失败。 +1. **引导是否成功启动?**:最后检查引导是否成功启动。 + + - **否:** 如果引导失败,抛出异常。 + - **是:** 如果成功,则执行完成。 + +1. **Start Guiding**: The guiding process is initiated, either with or without calibration based on previous steps. +1. **Await Guiding Completion**: The system waits for the guiding process to either succeed or fail. +1. 
**Guiding Started Successfully?**: A final check is made to confirm whether guiding has started successfully. - **No:** If guiding fails, an exception is thrown. - **Yes:** If successful, the process completes. +--- + +### 方法详细描述 + ### Detailed Method Descriptions +#### `Execute` 方法 + #### `Execute` Method -The `Execute` method is the core of the `StartGuiding` class. It handles starting the guiding process, including the optional step of forcing a new calibration. The method uses the `guiderMediator` to interact with the guider hardware, ensuring the process is executed correctly. If any issues arise—such as a failure to start guiding or the inability to clear calibration—an exception is thrown to halt the sequence. +`Execute` 方法是 `StartGuiding` 类的核心,它负责启动引导过程,包括选择性地强制执行新校准。该方法通过 `guiderMediator` 与引导硬件交互,确保过程正确执行。如果在引导启动或清除校准时出现问题,会抛出异常以中断序列。 -#### `Validate` Method +The `Execute` method is the core of the `StartGuiding` class. It is responsible for starting the guiding process, including optionally forcing a new calibration. The method interacts with the guider hardware through `guiderMediator` to ensure the process is executed correctly. If any issues arise—such as failing to start guiding or being unable to clear calibration—an exception is thrown to halt the sequence. -The `Validate` method checks the readiness of the guiding system. It ensures that the guider is connected and, if necessary, verifies that it can clear calibration data before starting a new calibration. The results of this validation are stored in the `Issues` list, which provides a way to identify and troubleshoot potential problems before executing the guiding process. +--- -#### `AfterParentChanged` Method +#### `Validate` 方法 -The `AfterParentChanged` method is invoked whenever the parent sequence item changes. This triggers a re-validation of the `StartGuiding` class to ensure that any contextual changes—such as different equipment or settings—are taken into account. It helps maintain the integrity of the sequence by ensuring that guiding can be started successfully in the new context. +#### `Validate` Method -#### `Clone` Method +`Validate` 方法检查引导系统的就绪状态。它确保引导器已连接,并在必要时验证引导器能否清除校准数据。验证结果存储在 `Issues` 列表中,以便在执行引导过程之前识别和解决潜在问题。 -The `Clone` method creates a new instance of the `StartGuiding` class with the same properties and metadata as the original. This allows the guiding process to be reused or repeated in different parts of a sequence without needing to manually reconfigure each instance. +The `Validate` method checks the readiness of the guiding system. It ensures that the guider is connected and, if necessary, verifies that it diff --git a/doc/task/guider/stop_guiding.md b/doc/task/guider/stop_guiding.md index 5092c7bd..323f5619 100644 --- a/doc/task/guider/stop_guiding.md +++ b/doc/task/guider/stop_guiding.md @@ -1,18 +1,39 @@ -# StopGuiding +# `StopGuiding` 类 -The `StopGuiding` class in the N.I.N.A. application is responsible for halting the guiding process during an astronomical imaging session. Guiding is a critical aspect of astrophotography, where a guiding camera or system keeps the telescope aligned with the target object. This class ensures that the guiding process stops correctly and validates that the guiding system is properly connected before attempting to stop guiding. +`StopGuiding` 类在 N.I.N.A.(Nighttime Imaging 'N' Astronomy)应用程序中负责在天文摄影会话期间停止引导过程。引导是天文摄影的关键步骤,其中引导相机或系统用于保持望远镜与目标物体的对齐。该类确保引导过程正确停止,并在尝试停止引导之前验证引导系统是否正确连接。 + +The `StopGuiding` class in the N.I.N.A. 
(Nighttime Imaging 'N' Astronomy) application is responsible for halting the guiding process during an astronomical imaging session. Guiding is a critical aspect of astrophotography, where a guiding camera or system keeps the telescope aligned with the target object. This class ensures that the guiding process stops correctly and validates that the guiding system is properly connected before attempting to stop guiding. + +--- + +## 类概述 ## Class Overview +### 命名空间 + ### Namespace +- **命名空间:** `NINA.Sequencer.SequenceItem.Guider` - **Namespace:** `NINA.Sequencer.SequenceItem.Guider` + +- **依赖项:** + + - `NINA.Core.Model` + - `NINA.Sequencer.Validations` + - `NINA.Equipment.Interfaces.Mediator` + - `NINA.Core.Locale` + - **Dependencies:** - `NINA.Core.Model` - `NINA.Sequencer.Validations` - `NINA.Equipment.Interfaces.Mediator` - `NINA.Core.Locale` +--- + +### 类声明 + ### Class Declaration ```csharp @@ -25,13 +46,26 @@ The `StopGuiding` class in the N.I.N.A. application is responsible for halting t public class StopGuiding : SequenceItem, IValidatable ``` +--- + +### 类属性 + ### Class Properties +- **guiderMediator**: 管理与引导硬件的通信,专门处理停止引导过程。 - **guiderMediator**: An interface that handles communication with the guider hardware, specifically managing the stop guiding process. + +- **Issues**: 在验证引导器连接状态时识别的问题列表。 - **Issues**: A list of issues that are identified during the validation of the guider's connection status. +--- + +### 构造函数 + ### Constructor +构造函数初始化 `StopGuiding` 类,并通过 `guiderMediator` 与引导设备建立连接,确保类能够与引导系统交互以停止引导过程。 + The constructor initializes the `StopGuiding` class by setting up the connection with the guider mediator, ensuring that the class can interact with the guider system to stop the guiding process. ```csharp @@ -39,16 +73,32 @@ The constructor initializes the `StopGuiding` class by setting up the connection public StopGuiding(IGuiderMediator guiderMediator) ``` +--- + +### 关键方法 + ### Key Methods +- **Execute(IProgress progress, CancellationToken token)**: 使用 `guiderMediator` 停止引导过程。如果出现任何问题,可能会抛出异常。 - **Execute(IProgress progress, CancellationToken token)**: Stops the guiding process using the `guiderMediator`. If any issues occur, an exception may be thrown. + +- **Validate()**: 在尝试停止引导之前,确保引导系统已连接。如果发现问题,更新 `Issues` 列表。 - **Validate()**: Ensures that the guiding system is connected before attempting to stop guiding. Updates the `Issues` list if any problems are found. + +- **AfterParentChanged()**: 每当父级序列项发生变化时,重新验证引导器的连接状态。 - **AfterParentChanged()**: Re-validates the guider connection whenever the parent sequence item changes. + +- **Clone()**: 创建 `StopGuiding` 对象的新实例,保留其属性和元数据。 - **Clone()**: Creates a new instance of the `StopGuiding` object, preserving its properties and metadata. +--- + +### 流程图:执行过程 + ### Flowchart: Execution Process -Below is a flowchart that outlines the key steps in the `Execute` method of the `StopGuiding` class. +以下是 `StopGuiding` 类中 `Execute` 方法的关键步骤流程图。 +Below is a flowchart outlining the key steps in the `Execute` method of the `StopGuiding` class. ```mermaid flowchart TD @@ -59,29 +109,70 @@ flowchart TD E --> F[End Execution] ``` +--- + +### 流程图解释 + ### Flowchart Explanation +1. **引导器是否连接?**:首先检查引导器是否已连接且准备就绪。 + + - **否:** 如果引导器未连接,抛出异常并终止过程。 + - **是:** 如果已连接,继续执行停止引导过程。 + 1. **Is Guider Connected?**: The process begins by verifying that the guider is connected and ready. + - **No:** If the guider is not connected, an exception is thrown, aborting the process. - **Yes:** If connected, the process continues to stop the guiding process. -2. 
**Stop Guiding Process**: The guider is instructed to stop the guiding process. -3. **Await Stopping Completion**: The system waits for the guiding process to stop completely. -4. **End Execution**: The method completes execution after successfully stopping the guider. + +1. **停止引导过程**:引导器被指示停止引导过程。 +1. **等待停止完成**:系统等待引导过程完全停止。 +1. **执行结束**:引导成功停止后,方法执行完成。 + +1. **Stop Guiding Process**: The guider is instructed to stop the guiding process. +1. **Await Stopping Completion**: The system waits for the guiding process to stop completely. +1. **End Execution**: The method completes execution after successfully stopping the guider. + +--- + +### 方法详细描述 ### Detailed Method Descriptions +#### `Execute` 方法 + #### `Execute` Method +`Execute` 方法是 `StopGuiding` 类的主要功能。它使用 `guiderMediator` 与引导硬件交互,发送停止引导的命令。该方法确保引导过程成功停止,如果出现任何问题,会适当地处理,可能会抛出异常。 + The `Execute` method is the primary function of the `StopGuiding` class. It uses the `guiderMediator` to interact with the guider hardware, sending the command to stop guiding. The method ensures that the guiding process halts successfully, and if any issues arise, it handles them appropriately, potentially throwing an exception. +--- + +#### `Validate` 方法 + #### `Validate` Method +`Validate` 方法检查引导系统是否正确连接,以便允许停止引导过程。它会更新 `Issues` 列表,记录遇到的任何问题,例如引导器断开连接。此验证对于在执行过程中防止错误至关重要。 + The `Validate` method checks that the guiding system is properly connected before allowing the guiding process to be stopped. It updates the `Issues` list with any problems that it encounters, such as the guider being disconnected. This validation is crucial to prevent errors during execution. +--- + +#### `AfterParentChanged` 方法 + #### `AfterParentChanged` Method +每当父级序列项发生变化时,都会调用 `AfterParentChanged` 方法。这会触发对 `StopGuiding` 类的重新验证,以确保任何上下文变化(如不同的设备或设置)都被考虑到。这有助于确保序列的可靠性,确认在新上下文中停止引导仍然合适。 + The `AfterParentChanged` method is called whenever the parent sequence item changes. This triggers a re-validation of the `StopGuiding` class to ensure that any contextual changes—such as different equipment or settings—are accounted for. This helps maintain the reliability of the sequence by confirming that stopping guiding is still appropriate in the new context. +--- + +#### `Clone` 方法 + #### `Clone` Method +`Clone` 方法创建 `StopGuiding` 类的新实例,保留所有属性和元数据。这在序列的不同部分重复停止引导过程时非常有用,而无需手动配置每个实例。 + The `Clone` method creates a new instance of the `StopGuiding` class, preserving all properties and metadata from the original instance. This is useful for repeating the stop guiding process in different parts of a sequence without manually configuring each instance. diff --git a/driverlibs b/driverlibs new file mode 160000 index 00000000..339f65c3 --- /dev/null +++ b/driverlibs @@ -0,0 +1 @@ +Subproject commit 339f65c35c2398b0cf2a7153123c5e5b2fa9c66e diff --git a/libs b/libs index f2b60c27..f1082244 160000 --- a/libs +++ b/libs @@ -1 +1 @@ -Subproject commit f2b60c276ef6fa16ce1b820e5062063c5f2416be +Subproject commit f1082244ea563500599fcc1726aadfb3e138c387 diff --git a/modules/lithium.pytools/tools/cbuilder.py b/modules/lithium.pytools/tools/cbuilder.py index e68654a1..2805c10c 100644 --- a/modules/lithium.pytools/tools/cbuilder.py +++ b/modules/lithium.pytools/tools/cbuilder.py @@ -25,61 +25,132 @@ import argparse import subprocess import sys +import os from pathlib import Path from typing import Literal, List, Optional -class CMakeBuilder: + +class BuildHelperBase: """ - CMakeBuilder is a utility class to handle building projects using CMake. + Base class for build helpers providing shared functionality. 
Args: source_dir (Path): Path to the source directory. build_dir (Path): Path to the build directory. - generator (Literal["Ninja", "Unix Makefiles"]): CMake generator to use. - build_type (Literal["Debug", "Release"]): Build type (Debug or Release). - install_prefix (Path): Installation prefix directory. - cmake_options (Optional[List[str]]): Additional options for CMake. + install_prefix (Path): Directory prefix where the project will be installed. + options (Optional[List[str]]): Additional options for the build system. + env_vars (Optional[dict]): Environment variables to set for the build process. + verbose (bool): Flag to enable verbose output during command execution. Methods: - run_command: Helper function to run shell commands. - configure: Configure the CMake build system. - build: Build the project. - install: Install the project. - clean: Clean the build directory. - test: Run CTest for the project. - generate_docs: Generate documentation if documentation target is available. + run_command: Executes shell commands with optional environment variables and verbosity. + clean: Cleans the build directory by removing all files and subdirectories. """ + def __init__( self, source_dir: Path, build_dir: Path, - generator: Literal["Ninja", "Unix Makefiles"] = "Ninja", - build_type: Literal["Debug", "Release"] = "Debug", - install_prefix: Path = None, # type: ignore - cmake_options: Optional[List[str]] = None, + install_prefix: Path = None, # type: ignore + options: Optional[List[str]] = None, + env_vars: Optional[dict] = None, + verbose: bool = False, ): self.source_dir = source_dir self.build_dir = build_dir - self.generator = generator - self.build_type = build_type self.install_prefix = install_prefix or build_dir / "install" - self.cmake_options = cmake_options or [] + self.options = options or [] + self.env_vars = env_vars or {} + self.verbose = verbose def run_command(self, *cmd: str): """ - Helper function to run shell commands. + Helper function to run shell commands with optional environment variables and verbosity. Args: - cmd (str): The command and its arguments to run. + cmd (str): The command and its arguments to run as separate strings. + + Raises: + SystemExit: Exits with the command's return code if it fails. """ print(f"Running: {' '.join(cmd)}") - result = subprocess.run(cmd, check=True, capture_output=True, text=True) - print(result.stdout) - if result.stderr: - print(result.stderr, file=sys.stderr) + env = os.environ.copy() + env.update(self.env_vars) + try: + result = subprocess.run( + cmd, check=True, capture_output=True, text=True, env=env) + if self.verbose or result.returncode != 0: + print(result.stdout) + if result.stderr: + print(result.stderr, file=sys.stderr) + except subprocess.CalledProcessError as e: + print(f"Error running command: {e}", file=sys.stderr) + sys.exit(e.returncode) + + def clean(self): + """ + Cleans the build directory by removing all files and subdirectories. + + This function ensures that the build directory is completely cleaned, making it ready + for a fresh build by removing all existing files and directories inside the build path. + """ + if self.build_dir.exists(): + for item in self.build_dir.iterdir(): + if item.is_dir(): + self.run_command("rm", "-rf", str(item)) + else: + item.unlink() + print(f"Cleaned: {self.build_dir}") + + +class CMakeBuilder(BuildHelperBase): + """ + CMakeBuilder is a utility class to handle building projects using CMake. + + Args: + source_dir (Path): Path to the source directory. 
+ build_dir (Path): Path to the build directory. + generator (Literal["Ninja", "Unix Makefiles"]): The CMake generator to use (e.g., Ninja or Unix Makefiles). + build_type (Literal["Debug", "Release"]): Type of build (Debug or Release). + install_prefix (Path): Directory prefix where the project will be installed. + cmake_options (Optional[List[str]]): Additional options for CMake. + env_vars (Optional[dict]): Environment variables to set for the build process. + verbose (bool): Flag to enable verbose output during command execution. + parallel (int): Number of parallel jobs to use for building. + + Methods: + configure: Configures the CMake build system by generating build files. + build: Builds the project, optionally specifying a target. + install: Installs the project to the specified prefix. + test: Runs tests using CTest with detailed output on failure. + generate_docs: Generates documentation using the specified documentation target. + """ + + def __init__( + self, + source_dir: Path, + build_dir: Path, + generator: Literal["Ninja", "Unix Makefiles"] = "Ninja", + build_type: Literal["Debug", "Release"] = "Debug", + install_prefix: Path = None, # type: ignore + cmake_options: Optional[List[str]] = None, + env_vars: Optional[dict] = None, + verbose: bool = False, + parallel: int = 4, + ): + super().__init__(source_dir, build_dir, + install_prefix, cmake_options, env_vars, verbose) + self.generator = generator + self.build_type = build_type + self.parallel = parallel def configure(self): - """Configure the CMake build system.""" + """ + Configures the CMake build system. + + This function generates the necessary build files using CMake based on the + specified generator, build type, and additional CMake options provided. + """ self.build_dir.mkdir(parents=True, exist_ok=True) cmake_args = [ "cmake", @@ -87,96 +158,100 @@ def configure(self): f"-DCMAKE_BUILD_TYPE={self.build_type}", f"-DCMAKE_INSTALL_PREFIX={self.install_prefix}", str(self.source_dir), - ] + self.cmake_options + ] + self.options self.run_command(*cmake_args) def build(self, target: str = ""): """ - Build the project. + Builds the project using CMake. Args: - target (str): Specific build target to build. + target (str): Specific build target to build (optional). If not specified, the default target is built. + + This function uses the CMake build command, supporting parallel jobs to speed up the build process. """ - build_cmd = ["cmake", "--build", str(self.build_dir)] + build_cmd = ["cmake", "--build", + str(self.build_dir), "--parallel", str(self.parallel)] if target: build_cmd += ["--target", target] self.run_command(*build_cmd) def install(self): - """Install the project.""" - self.run_command("cmake", "--install", str(self.build_dir)) + """ + Installs the project to the specified prefix. - def clean(self): - """Clean the build directory.""" - if self.build_dir.exists(): - for item in self.build_dir.iterdir(): - if item.is_dir(): - self.run_command("rm", "-rf", str(item)) - else: - item.unlink() + This function runs the CMake install command, which installs the built + artifacts to the directory specified by the install prefix. + """ + self.run_command("cmake", "--install", str(self.build_dir)) def test(self): - """Run CTest for the project.""" - self.run_command("ctest", "--output-on-failure", "-C", self.build_type, "-S", str(self.build_dir)) + """ + Runs tests using CTest with detailed output on failure. 
+ + This function runs CTest to execute the project's tests, providing detailed + output if any tests fail, making it easier to diagnose issues. + """ + self.run_command("ctest", "--output-on-failure", "-C", + self.build_type, "-S", str(self.build_dir)) def generate_docs(self, doc_target: str = "doc"): """ - Generate documentation if documentation target is available. + Generates documentation if the specified documentation target is available. Args: - doc_target (str): Documentation target to build. + doc_target (str): The documentation target to build (default is 'doc'). + + This function builds the specified documentation target using the CMake build command. """ self.build(doc_target) -class MesonBuilder: + +class MesonBuilder(BuildHelperBase): """ MesonBuilder is a utility class to handle building projects using Meson. Args: source_dir (Path): Path to the source directory. build_dir (Path): Path to the build directory. - build_type (Literal["debug", "release"]): Build type (debug or release). - install_prefix (Path): Installation prefix directory. + build_type (Literal["debug", "release"]): Type of build (debug or release). + install_prefix (Path): Directory prefix where the project will be installed. meson_options (Optional[List[str]]): Additional options for Meson. + env_vars (Optional[dict]): Environment variables to set for the build process. + verbose (bool): Flag to enable verbose output during command execution. + parallel (int): Number of parallel jobs to use for building. Methods: - run_command: Helper function to run shell commands. - configure: Configure the Meson build system. - build: Build the project. - install: Install the project. - clean: Clean the build directory. - test: Run Meson tests for the project. - generate_docs: Generate documentation if documentation target is available. + configure: Configures the Meson build system by generating build files. + build: Builds the project, optionally specifying a target. + install: Installs the project to the specified prefix. + test: Runs tests using Meson, with error logs printed on failures. + generate_docs: Generates documentation using the specified documentation target. """ + def __init__( self, source_dir: Path, build_dir: Path, build_type: Literal["debug", "release"] = "debug", - install_prefix: Path = None, # type: ignore + install_prefix: Path = None, # type: ignore meson_options: Optional[List[str]] = None, + env_vars: Optional[dict] = None, + verbose: bool = False, + parallel: int = 4, ): - self.source_dir = source_dir - self.build_dir = build_dir + super().__init__(source_dir, build_dir, + install_prefix, meson_options, env_vars, verbose) self.build_type = build_type - self.install_prefix = install_prefix or build_dir / "install" - self.meson_options = meson_options or [] + self.parallel = parallel - def run_command(self, *cmd: str): + def configure(self): """ - Helper function to run shell commands. + Configures the Meson build system. - Args: - cmd (str): The command and its arguments to run. + This function sets up the Meson build system, generating necessary build files based on + the specified build type and additional Meson options provided. 
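As a quick orientation for the refactored builder classes, here is a usage sketch of `MesonBuilder` that relies only on the constructor arguments and methods visible in this diff; the paths, Meson option, and compiler environment variables are placeholders, and the import assumes `cbuilder.py` is on the Python path.

```python
# Usage sketch only; paths, options and environment variables are placeholders.
from pathlib import Path

from cbuilder import MesonBuilder  # assumes modules/lithium.pytools/tools is importable

builder = MesonBuilder(
    source_dir=Path(".").resolve(),
    build_dir=Path("build").resolve(),
    build_type="release",
    meson_options=["-Dwarning_level=3"],        # hypothetical extra option
    env_vars={"CC": "clang", "CXX": "clang++"},
    verbose=True,
    parallel=8,
)

builder.configure()  # generate build files in ./build with --buildtype=release
builder.build()      # meson compile -C build -j8
builder.test()       # meson test -C build --print-errorlogs
builder.install()    # meson install -C build
```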
""" - print(f"Running: {' '.join(cmd)}") - result = subprocess.run(cmd, check=True, capture_output=True, text=True) - print(result.stdout) - if result.stderr: - print(result.stderr, file=sys.stderr) - - def configure(self): - """Configure the Meson build system。""" self.build_dir.mkdir(parents=True, exist_ok=True) meson_args = [ "meson", @@ -185,84 +260,117 @@ def configure(self): str(self.source_dir), f"--buildtype={self.build_type}", f"--prefix={self.install_prefix}", - ] + self.meson_options + ] + self.options self.run_command(*meson_args) def build(self, target: str = ""): """ - Build the project. + Builds the project using Meson. Args: - target (str): Specific target to build. + target (str): Specific target to build (optional). If not specified, the default target is built. + + This function compiles the project using Meson's compile command, with support for parallel jobs + to speed up the build process. """ - build_cmd = ["meson", "compile", "-C", str(self.build_dir)] + build_cmd = ["meson", "compile", "-C", + str(self.build_dir), f"-j{self.parallel}"] if target: build_cmd += ["--target", target] self.run_command(*build_cmd) def install(self): - """Install the project。""" - self.run_command("meson", "install", "-C", str(self.build_dir)) + """ + Installs the project to the specified prefix. - def clean(self): - """Clean the build directory。""" - if self.build_dir.exists(): - for item in self.build_dir.iterdir(): - if item.is_dir(): - self.run_command("rm", "-rf", str(item)) - else: - item.unlink() + This function runs the Meson install command, which installs the built + artifacts to the directory specified by the install prefix. + """ + self.run_command("meson", "install", "-C", str(self.build_dir)) def test(self): - """Run Meson tests for the project。""" - self.run_command("meson", "test", "-C", str(self.build_dir), "--print-errorlogs") + """ + Runs tests using Meson, with error logs printed on failures. + + This function runs Meson tests, displaying error logs for any failed tests + to provide detailed feedback and aid in debugging. + """ + self.run_command("meson", "test", "-C", + str(self.build_dir), "--print-errorlogs") def generate_docs(self, doc_target: str = "doc"): """ - Generate documentation if documentation target is available. + Generates documentation if the specified documentation target is available. Args: - doc_target (str): Documentation target to build. + doc_target (str): The documentation target to build (default is 'doc'). + + This function builds the specified documentation target using the Meson build system. """ self.build(doc_target) + def main(): """ Main function to run the build system helper. + + This function parses command-line arguments to determine the build system (CMake or Meson), + source and build directories, build options, and actions (clean, build, install, test, generate docs). + It then initializes the appropriate builder class and performs the requested operations. + + Command-line Arguments: + --source_dir: Specifies the source directory of the project. + --build_dir: Specifies the build directory where build files and artifacts will be generated. + --builder: Specifies the build system to use ('cmake' or 'meson'). + --generator: Specifies the CMake generator (e.g., Ninja, Unix Makefiles) if using CMake. + --build_type: Specifies the build type ('Debug', 'Release', 'debug', 'release'). + --target: Specifies a specific build target to build. + --install: Flag to indicate that the project should be installed after building. 
+ --clean: Flag to indicate that the build directory should be cleaned before building. + --test: Flag to indicate that tests should be run after building. + --cmake_options: Additional options for CMake. + --meson_options: Additional options for Meson. + --generate_docs: Flag to indicate that documentation should be generated. + --env: Environment variables to set during the build process. + --verbose: Enables verbose output during command execution. + --parallel: Number of parallel jobs to use for building. """ parser = argparse.ArgumentParser(description="Build System Python Builder") + parser.add_argument("--source_dir", type=Path, + default=Path(".").resolve(), help="Source directory") + parser.add_argument("--build_dir", type=Path, + default=Path("build").resolve(), help="Build directory") parser.add_argument( - "--source_dir", type=Path, default=Path(".").resolve(), help="Source directory" - ) + "--builder", choices=["cmake", "meson"], required=True, help="Choose the build system") parser.add_argument( - "--build_dir", - type=Path, - default=Path("build").resolve(), - help="Build directory", - ) - parser.add_argument("--builder", choices=["cmake", "meson"], required=True, help="Choose the build system") - parser.add_argument("--generator", choices=["Ninja", "Unix Makefiles"], default="Ninja") - parser.add_argument("--build_type", choices=["Debug", "Release", "debug", "release"], default="Debug") - parser.add_argument("--target", default="") - parser.add_argument("--install", action="store_true", help="Install the project") - parser.add_argument("--clean", action="store_true", help="Clean the build directory") + "--generator", choices=["Ninja", "Unix Makefiles"], default="Ninja", help="CMake generator to use") + parser.add_argument("--build_type", choices=[ + "Debug", "Release", "debug", "release"], default="Debug", help="Build type") + parser.add_argument("--target", default="", help="Specify a build target") + parser.add_argument("--install", action="store_true", + help="Install the project") + parser.add_argument("--clean", action="store_true", + help="Clean the build directory") parser.add_argument("--test", action="store_true", help="Run the tests") - parser.add_argument( - "--cmake_options", - nargs="*", - default=[], - help="Custom CMake options (e.g. -DVAR=VALUE)", - ) - parser.add_argument( - "--meson_options", - nargs="*", - default=[], - help="Custom Meson options (e.g. -Dvar=value)", - ) - parser.add_argument("--generate_docs", action="store_true", help="Generate documentation") + parser.add_argument("--cmake_options", nargs="*", default=[], + help="Custom CMake options (e.g. -DVAR=VALUE)") + parser.add_argument("--meson_options", nargs="*", default=[], + help="Custom Meson options (e.g. -Dvar=value)") + parser.add_argument("--generate_docs", action="store_true", + help="Generate documentation") + parser.add_argument("--env", nargs="*", default=[], + help="Set environment variables (e.g. 
VAR=value)") + parser.add_argument("--verbose", action="store_true", + help="Enable verbose output") + parser.add_argument("--parallel", type=int, default=4, + help="Number of parallel jobs for building") args = parser.parse_args() + # Parse environment variables from the command line + env_vars = {var.split("=")[0]: var.split("=")[1] for var in args.env} + + # Initialize the appropriate builder based on the specified build system if args.builder == "cmake": builder = CMakeBuilder( source_dir=args.source_dir, @@ -270,6 +378,9 @@ def main(): generator=args.generator, build_type=args.build_type, cmake_options=args.cmake_options, + env_vars=env_vars, + verbose=args.verbose, + parallel=args.parallel, ) elif args.builder == "meson": builder = MesonBuilder( @@ -277,22 +388,33 @@ def main(): build_dir=args.build_dir, build_type=args.build_type, meson_options=args.meson_options, + env_vars=env_vars, + verbose=args.verbose, + parallel=args.parallel, ) + # Perform cleaning if requested if args.clean: builder.clean() + # Configure the build system builder.configure() + + # Build the project with the specified target builder.build(args.target) + # Install the project if the install flag is set if args.install: builder.install() + # Run tests if the test flag is set if args.test: builder.test() + # Generate documentation if the generate_docs flag is set if args.generate_docs: builder.generate_docs() + if __name__ == "__main__": main() diff --git a/modules/lithium.pytools/tools/cmake_generator.py b/modules/lithium.pytools/tools/cmake_generator.py new file mode 100644 index 00000000..2fdc7c3b --- /dev/null +++ b/modules/lithium.pytools/tools/cmake_generator.py @@ -0,0 +1,276 @@ +#!/usr/bin/env python3 +""" +This script automates the generation of CMake build configuration files for C++ projects. It supports multi-directory +project structures, the creation of custom FindXXX.cmake files for third-party libraries, and the use of JSON configuration +files for specifying project settings. + +Key features: +1. Multi-directory support with separate CMakeLists.txt for each subdirectory. +2. Custom FindXXX.cmake generation for locating third-party libraries. +3. JSON-based project configuration to streamline the CMakeLists.txt generation process. +4. Customizable compiler flags, linker flags, and dependencies. +""" +import argparse +import json +from pathlib import Path +from dataclasses import dataclass, field +import platform + + +@dataclass +class ProjectConfig: + """ + Dataclass to hold the project configuration. + + Attributes: + project_name (str): The name of the project. + version (str): The version of the project. + cpp_standard (str): The C++ standard to use (default is C++11). + executable (bool): Whether to generate an executable target. + static_library (bool): Whether to generate a static library. + shared_library (bool): Whether to generate a shared library. + enable_testing (bool): Whether to enable testing with CMake's `enable_testing()`. + include_dirs (list): List of directories to include in the project. + sources (str): Glob pattern to specify source files (default is `src/*.cpp`). + compiler_flags (list): List of compiler flags (e.g., `-O3`, `-Wall`). + linker_flags (list): List of linker flags (e.g., `-lpthread`). + dependencies (list): List of third-party dependencies. + subdirs (list): List of subdirectories for multi-directory project structure. + install_path (str): Custom installation path (default is `bin`). 
+ test_framework (str): The test framework to be used (optional, e.g., `GoogleTest`). + """ + project_name: str + version: str = "1.0" + cpp_standard: str = "11" + executable: bool = True + static_library: bool = False + shared_library: bool = False + enable_testing: bool = False + include_dirs: list = field(default_factory=list) + sources: str = "src/*.cpp" + compiler_flags: list = field(default_factory=list) + linker_flags: list = field(default_factory=list) + dependencies: list = field(default_factory=list) + subdirs: list = field(default_factory=list) + install_path: str = "bin" + test_framework: str = None # Optional: e.g., GoogleTest + + +def detect_os() -> str: + """ + Detects the current operating system and returns a suitable CMake system name. + + Returns: + str: The appropriate CMake system name for the current OS (Windows, Darwin for macOS, Linux). + """ + current_os = platform.system() + if current_os == "Windows": + return "set(CMAKE_SYSTEM_NAME Windows)\n" + elif current_os == "Darwin": + return "set(CMAKE_SYSTEM_NAME Darwin)\n" + elif current_os == "Linux": + return "set(CMAKE_SYSTEM_NAME Linux)\n" + return "" + + +def generate_cmake(config: ProjectConfig) -> str: + """ + Generates the content of the main CMakeLists.txt based on the provided project configuration. + + Args: + config (ProjectConfig): The project configuration object containing all settings. + + Returns: + str: The content of the generated CMakeLists.txt file. + """ + cmake_template = f"""cmake_minimum_required(VERSION 3.15) + +# Project name and version +project({config.project_name} VERSION {config.version}) + +# Set C++ standard +set(CMAKE_CXX_STANDARD {config.cpp_standard}) +set(CMAKE_CXX_STANDARD_REQUIRED True) + +# OS-specific settings +{detect_os()} + +# Source files +file(GLOB_RECURSE SOURCES "{config.sources}") + +# Include directories +""" + if config.include_dirs: + for include_dir in config.include_dirs: + cmake_template += f'include_directories("{include_dir}")\n' + + # Compiler flags + if config.compiler_flags: + cmake_template += "add_compile_options(" + cmake_template += " ".join(config.compiler_flags) + ")\n" + + # Linker flags + if config.linker_flags: + cmake_template += "add_link_options(" + cmake_template += " ".join(config.linker_flags) + ")\n" + + # Dependencies (find_package) + if config.dependencies: + for dep in config.dependencies: + cmake_template += f'find_package({dep} REQUIRED)\n' + + # Subdirectory handling for multi-directory support + if config.subdirs: + for subdir in config.subdirs: + cmake_template += f'add_subdirectory({subdir})\n' + + # Create targets: executable or library + if config.executable: + cmake_template += f'add_executable({ + config.project_name} ${{{{SOURCES}}}})\n' + elif config.static_library: + cmake_template += f'add_library({ + config.project_name} STATIC ${{{{SOURCES}}}})\n' + elif config.shared_library: + cmake_template += f'add_library({ + config.project_name} SHARED ${{{{SOURCES}}}})\n' + + # Testing support + if config.enable_testing: + cmake_template += """ +# Enable testing +enable_testing() +add_subdirectory(tests) +""" + # Custom install path + cmake_template += f""" +# Installation rule +install(TARGETS {config.project_name} DESTINATION {config.install_path}) +""" + + return cmake_template + + +def generate_find_cmake(dependency_name: str) -> str: + """ + Generates a FindXXX.cmake template for a specified third-party dependency. + + Args: + dependency_name (str): The name of the third-party library (e.g., Boost, OpenCV). 
+ + Returns: + str: The content of the FindXXX.cmake file to locate the library. + """ + return f"""# Find{dependency_name}.cmake - Find {dependency_name} library + +# Locate the {dependency_name} library and headers +find_path({dependency_name}_INCLUDE_DIR + NAMES {dependency_name}.h + PATHS /usr/local/include /usr/include +) + +find_library({dependency_name}_LIBRARY + NAMES {dependency_name} + PATHS /usr/local/lib /usr/lib +) + +if({dependency_name}_INCLUDE_DIR AND {dependency_name}_LIBRARY) + set({dependency_name}_FOUND TRUE) + message(STATUS "{dependency_name} found") +else() + set({dependency_name}_FOUND FALSE) + message(FATAL_ERROR "{dependency_name} not found") +endif() + +# Mark variables as advanced +mark_as_advanced({dependency_name}_INCLUDE_DIR {dependency_name}_LIBRARY) +""" + + +def save_file(content: str, directory: str = ".", filename: str = "CMakeLists.txt"): + """ + Saves the provided content to a file. + + Args: + content (str): The content to save to the file. + directory (str): The directory where the file should be saved. + filename (str): The name of the file (default is CMakeLists.txt). + """ + directory_path = Path(directory) + directory_path.mkdir(parents=True, exist_ok=True) + + file_path = directory_path / filename + file_path.write_text(content, encoding='utf-8') + + print(f"{filename} generated in {file_path}") + + +def generate_from_json(json_file: str) -> ProjectConfig: + """ + Reads project configuration from a JSON file and converts it into a ProjectConfig object. + + Args: + json_file (str): Path to the JSON configuration file. + + Returns: + ProjectConfig: A project configuration object with settings parsed from the JSON file. + """ + with open(json_file, "r", encoding="utf-8") as file: + data = json.load(file) + + return ProjectConfig( + project_name=data.get("project_name", "MyProject"), + version=data.get("version", "1.0"), + cpp_standard=data.get("cpp_standard", "11"), + executable=data.get("executable", True), + static_library=data.get("static_library", False), + shared_library=data.get("shared_library", False), + enable_testing=data.get("enable_testing", False), + include_dirs=data.get("include_dirs", []), + sources=data.get("sources", "src/*.cpp"), + compiler_flags=data.get("compiler_flags", []), + linker_flags=data.get("linker_flags", []), + dependencies=data.get("dependencies", []), + subdirs=data.get("subdirs", []), + install_path=data.get("install_path", "bin"), + test_framework=data.get("test_framework", None) + ) + + +def parse_arguments(): + """ + Parses command-line arguments to either generate CMakeLists.txt files, FindXXX.cmake files, or handle JSON input. + + Returns: + argparse.Namespace: Parsed command-line arguments. 
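To make the `--json` path described above concrete, here is a sketch of a minimal configuration file and the programmatic equivalent of what the `__main__` block below does. The field names come from `ProjectConfig` earlier in this file; the project name, dependency, and file names are invented, and the import assumes `cmake_generator.py` is on the Python path.

```python
# Illustrative only: write a minimal JSON config, then generate CMakeLists.txt
# and a Find<Dep>.cmake file with this module's own helpers.
import json

from cmake_generator import (generate_cmake, generate_find_cmake,
                             generate_from_json, save_file)

config_data = {
    "project_name": "MyApp",            # hypothetical project
    "version": "0.1",
    "cpp_standard": "17",
    "executable": True,
    "include_dirs": ["include"],
    "sources": "src/*.cpp",
    "compiler_flags": ["-O2", "-Wall"],
    "dependencies": ["Boost"],
    "install_path": "bin",
}
with open("project.json", "w", encoding="utf-8") as f:
    json.dump(config_data, f, indent=4)

config = generate_from_json("project.json")   # -> ProjectConfig
save_file(generate_cmake(config))             # writes ./CMakeLists.txt
for dep in config.dependencies:               # writes cmake/FindBoost.cmake
    save_file(generate_find_cmake(dep), directory="cmake", filename=f"Find{dep}.cmake")
```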
+ """ + parser = argparse.ArgumentParser( + description="Generate a CMake template for C++ projects.") + parser.add_argument( + "--json", type=str, help="Path to JSON config file to generate CMakeLists.txt.") + parser.add_argument("--find-package", type=str, + help="Generate a FindXXX.cmake file for a specified library.") + + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_arguments() + + if args.json: + # Generate CMakeLists.txt from JSON configuration + config = generate_from_json(args.json) + cmake_content = generate_cmake(config) + save_file(cmake_content) + + # Generate FindXXX.cmake files for each dependency specified in the JSON file + for dep in config.dependencies: + find_cmake_content = generate_find_cmake(dep) + save_file(find_cmake_content, directory="cmake", + filename=f"Find{dep}.cmake") + + if args.find_package: + # Generate a single FindXXX.cmake file for the specified dependency + find_cmake_content = generate_find_cmake(args.find_package) + save_file(find_cmake_content, directory="cmake", + filename=f"Find{args.find_package}.cmake") diff --git a/modules/lithium.pytools/tools/compiler_parser.py b/modules/lithium.pytools/tools/compiler_parser.py index de8efc53..b246a30a 100644 --- a/modules/lithium.pytools/tools/compiler_parser.py +++ b/modules/lithium.pytools/tools/compiler_parser.py @@ -1,12 +1,16 @@ """ This module contains functions for parsing compiler output and converting it to JSON, CSV, or XML format. """ + +from pathlib import Path import re import json import csv import argparse +import os import xml.etree.ElementTree as ET -from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import ThreadPoolExecutor, as_completed +from termcolor import colored def parse_gcc_clang_output(output): @@ -32,7 +36,8 @@ def parse_gcc_clang_output(output): "file": match[0], "line": int(match[1]), "column": int(match[2]), - "message": match[4].strip() + "message": match[4].strip(), + "severity": match[3].lower(), } if match[3].lower() == 'error': results["errors"].append(entry) @@ -70,7 +75,8 @@ def parse_msvc_output(output): "file": match[0], "line": int(match[1]), "code": match[3], - "message": match[4].strip() + "message": match[4].strip(), + "severity": match[2].lower(), } if match[2].lower() == 'error': results["errors"].append(entry) @@ -107,7 +113,8 @@ def parse_cmake_output(output): entry = { "file": match[0], "line": int(match[1]), - "message": match[3].strip() + "message": match[3].strip(), + "severity": match[2].lower(), } if match[2].lower() == 'error': results["errors"].append(entry) @@ -152,7 +159,8 @@ def write_to_csv(data, output_path): output_path (str): The path to the output CSV file. """ with open(output_path, 'w', newline='', encoding="utf-8") as csvfile: - fieldnames = ['file', 'line', 'column', 'type', 'code', 'message'] + fieldnames = ['file', 'line', 'column', + 'type', 'code', 'message', 'severity'] writer = csv.DictWriter(csvfile, fieldnames=fieldnames) writer.writeheader() for entry in data: @@ -198,6 +206,25 @@ def process_file(compiler, file_path): } +def colorize_output(entries): + """ + Prints compiler results with colorized output in the console. + + Args: + entries (list): A list of parsed compiler entries. 
+    """
+    for entry in entries:
+        if entry['type'] == 'errors':
+            print(colored(
+                f"Error in {entry['file']}:{entry['line']} - {entry['message']}", 'red'))
+        elif entry['type'] == 'warnings':
+            print(colored(
+                f"Warning in {entry['file']}:{entry['line']} - {entry['message']}", 'yellow'))
+        else:
+            print(colored(
+                f"Info in {entry['file']}:{entry['line']} - {entry['message']}", 'blue'))
+
+
 def main():
     """
     Main function to parse compiler output and convert to JSON, CSV, or XML format.
@@ -213,50 +240,78 @@ def main():
     parser.add_argument(
         '--output-file', default='output.json', help="Output file name.")
     parser.add_argument(
-        '--filter', choices=['error', 'warning', 'info'], help="Filter by message type.")
+        '--output-dir', default='.', help="Output directory.")
+    parser.add_argument(
+        '--filter', nargs='*', choices=['error', 'warning', 'info'], help="Filter by message types.")
+    parser.add_argument('--stats', action='store_true',
+                        help="Include statistics in the output.")
     parser.add_argument(
-        '--stats', action='store_true', help="Include statistics in the output.")
+        '--concurrency', type=int, default=4, help="Number of concurrent threads for processing files.")

     args = parser.parse_args()

+    # Prepare the output directory
+    output_dir = Path(args.output_dir).resolve()
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    # Initialize results list
     all_results = []

-    with ThreadPoolExecutor() as executor:
-        futures = [executor.submit(process_file, args.compiler, file_path)
-                   for file_path in args.file_paths]
-        for future in futures:
-            all_results.append(future.result())
+    # Use ThreadPoolExecutor for concurrent processing of files
+    with ThreadPoolExecutor(max_workers=args.concurrency) as executor:
+        futures = {executor.submit(
+            process_file, args.compiler, file_path): file_path for file_path in args.file_paths}
+
+        for future in as_completed(futures):
+            try:
+                result = future.result()
+                all_results.append(result)
+            except Exception as e:
+                print(colored(
+                    f"Error processing {futures[future]}: {e}", 'red'))
+
+    # Flatten results for output processing
     flattened_results = []
     for result in all_results:
-        for key, entries in result['results'].items():
+        for severity, entries in result['results'].items():
             for entry in entries:
-                entry['type'] = key
+                entry['type'] = severity
                 entry['file'] = result['file']
                 flattened_results.append(entry)

+    # Apply filtering if specified (--filter uses singular names; bucket keys are plural)
     if args.filter:
         flattened_results = [
-            entry for entry in flattened_results if entry['type'] == args.filter]
+            entry for entry in flattened_results if entry['type'].rstrip('s') in args.filter]

+    # Calculate statistics if requested
     if args.stats:
         stats = {
+            "total": len(flattened_results),
             "errors": sum(1 for entry in flattened_results if entry['type'] == 'errors'),
             "warnings": sum(1 for entry in flattened_results if entry['type'] == 'warnings'),
             "info": sum(1 for entry in flattened_results if entry['type'] == 'info'),
         }
-        print(f"Statistics: {json.dumps(stats, indent=4)}")
+        print(f"Statistics:\n{json.dumps(stats, indent=4)}")
+
+    # Output results to the specified format
+    output_file_path = output_dir / args.output_file

     if args.output_format == 'json':
         json_output = json.dumps(flattened_results, indent=4)
-        with open(args.output_file, 'w', encoding="utf-8") as json_file:
+        with open(output_file_path, 'w', encoding="utf-8") as json_file:
             json_file.write(json_output)
-        print(f"JSON output saved to {args.output_file}")
+        print(f"JSON output saved to {output_file_path}")
     elif args.output_format == 'csv':
-        write_to_csv(flattened_results,
args.output_file) - print(f"CSV output saved to {args.output_file}") + write_to_csv(flattened_results, output_file_path) + print(f"CSV output saved to {output_file_path}") elif args.output_format == 'xml': - write_to_xml(flattened_results, args.output_file) - print(f"XML output saved to {args.output_file}") + write_to_xml(flattened_results, output_file_path) + print(f"XML output saved to {output_file_path}") + + # Optional: Print colorized output to the console + print("\nColorized Output:") + colorize_output(flattened_results) if __name__ == "__main__": diff --git a/modules/lithium.pytools/tools/video_editor.py b/modules/lithium.pytools/tools/video_editor.py new file mode 100644 index 00000000..3ff25808 --- /dev/null +++ b/modules/lithium.pytools/tools/video_editor.py @@ -0,0 +1,264 @@ +import argparse +import os +import numpy as np +from moviepy.editor import VideoFileClip, AudioFileClip, CompositeAudioClip +from moviepy.audio.fx.all import volumex, audio_fadein, audio_fadeout +from pydub import AudioSegment +import matplotlib.pyplot as plt +from scipy.io import wavfile +from scipy.signal import spectrogram, wiener + + +def process_audio(video_path, audio_output_path, output_format='mp3', speed=1.0, volume=1.0, + start_time=None, end_time=None, fade_in=0, fade_out=0, reverse=False, normalize=False, noise_reduce=False): + """ + Processes audio from a video file: extracts, modifies speed, volume, applies fade effects, reverses, and saves the audio. + + Parameters: + video_path (str): Path to the input video file. + audio_output_path (str): Path to save the processed audio file. + output_format (str): Format of the output audio file ('mp3' or 'wav'). + speed (float): Speed factor to adjust the audio playback. + volume (float): Volume multiplier for the audio. + start_time (float): Start time (in seconds) to trim the audio. + end_time (float): End time (in seconds) to trim the audio. + fade_in (float): Duration (in seconds) for the fade-in effect. + fade_out (float): Duration (in seconds) for the fade-out effect. + reverse (bool): Whether to reverse the audio. + normalize (bool): Normalize the audio volume to -1 to 1 range. + noise_reduce (bool): Apply basic noise reduction using Wiener filter. 
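+
+    Example (illustrative paths and values):
+        process_audio("input.mp4", "output.mp3", speed=1.25, volume=0.8,
+                      start_time=5.0, end_time=60.0, fade_in=2, fade_out=2)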
+ """ + # Load the video file + video = VideoFileClip(video_path) + + # Extract audio from the video + audio = video.audio + + # Adjust audio speed + if speed != 1.0: + audio = audio.speedx(speed) + + # Adjust audio volume + if volume != 1.0: + audio = audio.fx(volumex, volume) + + # Trim audio (if start and/or end times are specified) + if start_time is not None or end_time is not None: + audio = audio.subclip(start_time, end_time) + + # Apply fade-in and fade-out effects + if fade_in > 0: + audio = audio.fx(audio_fadein, duration=fade_in) + if fade_out > 0: + audio = audio.fx(audio_fadeout, duration=fade_out) + + # Reverse the audio + if reverse: + audio = audio.fx(lambda clip: clip.fl_time( + lambda t: clip.duration - t)) + + # Normalize the audio + if normalize: + audio_array = np.array(audio.to_soundarray()) + max_amplitude = np.max(np.abs(audio_array)) + audio_array = audio_array / max_amplitude + audio = audio.set_audio_array(audio_array) + + # Apply noise reduction + if noise_reduce: + sample_rate, data = wavfile.read(audio_output_path) + data = wiener(data) + wavfile.write(audio_output_path, sample_rate, data) + + # Save the processed audio + audio.write_audiofile( + audio_output_path, codec='libmp3lame' if output_format == 'mp3' else 'pcm_s16le') + + # Close the video file + video.close() + + +def batch_process(input_dir, output_dir, **kwargs): + """ + Processes all video files in a given directory in batch mode. + + Parameters: + input_dir (str): Directory containing the input video files. + output_dir (str): Directory where the processed audio files will be saved. + kwargs: Additional arguments passed to the audio processing function. + """ + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + for filename in os.listdir(input_dir): + if filename.endswith(('.mp4', '.avi', '.mov', '.mkv')): + input_path = os.path.join(input_dir, filename) + output_path = os.path.join(output_dir, os.path.splitext( + filename)[0] + '.' + kwargs.get('output_format', 'mp3')) + try: + process_audio(input_path, output_path, **kwargs) + print(f"Processed: {filename}") + except Exception as e: + print(f"Error processing {filename}: {e}") + + +def mix_audio(audio_files, output_file, volumes=None): + """ + Mixes multiple audio files into one. + + Parameters: + audio_files (list): List of paths to audio files to be mixed. + output_file (str): Path to save the mixed audio file. + volumes (list): List of volume multipliers for each audio file. + """ + if volumes is None: + volumes = [1.0] * len(audio_files) + + audio_clips = [AudioFileClip(f).volumex(v) + for f, v in zip(audio_files, volumes)] + final_audio = CompositeAudioClip(audio_clips) + final_audio.write_audiofile(output_file) + + +def split_audio(input_file, output_prefix, segment_length): + """ + Splits an audio file into segments of specified length. + + Parameters: + input_file (str): Path to the input audio file. + output_prefix (str): Prefix for the output segment files. + segment_length (float): Length of each segment in seconds. + """ + audio = AudioSegment.from_file(input_file) + length_ms = len(audio) + segment_length_ms = segment_length * 1000 + + for i, start in enumerate(range(0, length_ms, segment_length_ms)): + end = start + segment_length_ms + segment = audio[start:end] + segment.export(f"{output_prefix}_{i+1}.mp3", format="mp3") + + +def convert_audio_format(input_file, output_file): + """ + Converts an audio file to a different format. + + Parameters: + input_file (str): Path to the input audio file. 
+ output_file (str): Path to the output file with the desired format. + """ + audio = AudioSegment.from_file(input_file) + export_format = os.path.splitext(output_file)[1][1:] + audio.export(output_file, format=export_format) + + +def visualize_audio(input_file, output_file): + """ + Visualizes an audio file as a waveform and spectrogram. + + Parameters: + input_file (str): Path to the input audio file (WAV format). + output_file (str): Path to save the visualization image. + """ + sample_rate, data = wavfile.read(input_file) + + plt.figure(figsize=(12, 8)) + + # Plot waveform + plt.subplot(2, 1, 1) + plt.plot(np.arange(len(data)) / sample_rate, data) + plt.title('Audio Waveform') + plt.xlabel('Time (seconds)') + plt.ylabel('Amplitude') + + # Plot spectrogram + plt.subplot(2, 1, 2) + frequencies, times, Sxx = spectrogram(data, sample_rate) + plt.pcolormesh(times, frequencies, 10 * np.log10(Sxx)) + plt.title('Spectrogram') + plt.xlabel('Time (seconds)') + plt.ylabel('Frequency (Hz)') + plt.colorbar(label='Intensity (dB)') + + plt.tight_layout() + plt.savefig(output_file) + plt.close() + + +def main(): + parser = argparse.ArgumentParser( + description="Full-featured audio processing tool.") + parser.add_argument( + "input", help="Path to the input video/audio file or directory containing files.") + parser.add_argument( + "output", help="Path to save the output audio file or directory.") + parser.add_argument( + "--format", choices=['mp3', 'wav'], default='mp3', help="Output audio format (default: mp3)") + parser.add_argument("--speed", type=float, default=1.0, + help="Adjust audio speed (default: 1.0)") + parser.add_argument("--volume", type=float, default=1.0, + help="Adjust audio volume (default: 1.0)") + parser.add_argument("--start", type=float, + help="Start time of the audio (seconds)") + parser.add_argument("--end", type=float, + help="End time of the audio (seconds)") + parser.add_argument("--fade-in", type=float, default=0, + help="Fade-in duration (seconds)") + parser.add_argument("--fade-out", type=float, default=0, + help="Fade-out duration (seconds)") + parser.add_argument("--reverse", action="store_true", + help="Reverse the audio") + parser.add_argument("--normalize", action="store_true", + help="Normalize the audio volume") + parser.add_argument("--noise-reduce", action="store_true", + help="Apply basic noise reduction") + parser.add_argument("--batch", action="store_true", + help="Batch process all videos in a directory") + parser.add_argument("--mix", nargs='+', help="Mix multiple audio files") + parser.add_argument("--mix-volumes", nargs='+', type=float, + help="Volumes for each audio file during mix") + parser.add_argument("--split", type=float, + help="Split audio into segments of specified length (seconds)") + parser.add_argument( + "--convert", help="Convert audio to a different format") + parser.add_argument("--visualize", action="store_true", + help="Generate a visualization of the audio (waveform and spectrogram)") + + args = parser.parse_args() + + # Mix multiple audio files + if args.mix: + mix_audio(args.mix, args.output, args.mix_volumes) + # Split audio into segments + elif args.split: + split_audio(args.input, args.output, args.split) + # Convert audio format + elif args.convert: + convert_audio_format(args.input, args.output) + # Visualize the audio as waveform and spectrogram + elif args.visualize: + visualize_audio(args.input, args.output) + # Batch process all video files in a directory + elif args.batch: + batch_process( + args.input, args.output, + 
output_format=args.format, speed=args.speed, volume=args.volume, + start_time=args.start, end_time=args.end, + fade_in=args.fade_in, fade_out=args.fade_out, reverse=args.reverse, + normalize=args.normalize, noise_reduce=args.noise_reduce + ) + # Process a single video or audio file + else: + process_audio( + args.input, args.output, + output_format=args.format, speed=args.speed, volume=args.volume, + start_time=args.start, end_time=args.end, + fade_in=args.fade_in, fade_out=args.fade_out, reverse=args.reverse, + normalize=args.normalize, noise_reduce=args.noise_reduce + ) + + print("Processing completed successfully.") + + +if __name__ == "__main__": + main() diff --git a/package/package.sh b/package/package.sh new file mode 100644 index 00000000..9fd28ccf --- /dev/null +++ b/package/package.sh @@ -0,0 +1,164 @@ +#!/bin/bash + +# 脚本:CMake MinGW项目打包脚本 +# Script: CMake MinGW Project Packaging Script +# 描述:这个脚本用于构建、安装和打包基于CMake的MinGW项目。 +# Description: This script is used to build, install, and package CMake-based MinGW projects. +# 作者:Max Qian +# Author: Max Qian +# 版本:1.2 +# Version: 1.2 +# 使用方法:./package.sh [clean|build|package] +# Usage: ./package.sh [clean|build|package] + +# ===== 变量设置 / Variable Settings ===== +# 项目名称 / Project name +PROJECT_NAME="MyProject" +# 构建目录 / Build directory +BUILD_DIR="build" +# 安装目录 / Install directory +INSTALL_DIR="install" +# 打包目录 / Package directory +PACKAGE_DIR="package" +# 版本号(从git获取,如果失败则使用默认值) +# Version number (obtained from git, use default if failed) +VERSION=$(git describe --tags --always --dirty 2>/dev/null || echo "1.0.0") +# 最终生成的压缩包名称 / Name of the final generated archive +ARCHIVE_NAME="${PROJECT_NAME}-${VERSION}-win64.zip" + +# ===== 颜色代码 / Color Codes ===== +# 用于美化控制台输出 / Used to beautify console output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' # No Color / 无颜色 + +# ===== 辅助函数 / Helper Functions ===== +# 输出信息日志 / Output information log +log_info() { + echo -e "${GREEN}[INFO] $1${NC}" +} + +# 输出警告日志 / Output warning log +log_warn() { + echo -e "${YELLOW}[WARN] $1${NC}" +} + +# 输出错误日志 / Output error log +log_error() { + echo -e "${RED}[ERROR] $1${NC}" +} + +# 检查命令是否存在 / Check if a command exists +check_command() { + if ! command -v $1 &> /dev/null; then + log_error "$1 could not be found. Please install it and try again." + log_error "$1 未找到。请安装后重试。" + exit 1 + fi +} + +# ===== 主要功能函数 / Main Function ===== +# 清理函数:删除之前的构建产物 +# Clean function: Remove previous build artifacts +clean() { + log_info "Cleaning previous build artifacts..." + log_info "清理之前的构建产物..." + rm -rf $BUILD_DIR $INSTALL_DIR $PACKAGE_DIR +} + +# 构建函数:配置CMake,编译项目,并安装 +# Build function: Configure CMake, compile the project, and install +build() { + log_info "Configuring project with CMake..." + log_info "使用CMake配置项目..." + mkdir -p $BUILD_DIR + cd $BUILD_DIR + # 使用MinGW Makefiles生成器,设置安装前缀和构建类型 + # Use MinGW Makefiles generator, set install prefix and build type + cmake -G "MinGW Makefiles" -DCMAKE_INSTALL_PREFIX=../$INSTALL_DIR -DCMAKE_BUILD_TYPE=Release .. + if [ $? -ne 0 ]; then + log_error "CMake configuration failed." + log_error "CMake配置失败。" + exit 1 + fi + + log_info "Building project..." + log_info "构建项目..." + # 使用CMake构建项目 / Build the project using CMake + cmake --build . --config Release + if [ $? -ne 0 ]; then + log_error "Build failed." + log_error "构建失败。" + exit 1 + fi + + log_info "Installing project..." + log_info "安装项目..." + # 安装项目到指定目录 / Install the project to the specified directory + cmake --install . + if [ $? 
-ne 0 ]; then + log_error "Installation failed." + log_error "安装失败。" + exit 1 + fi + + cd .. +} + +# 打包函数:复制文件,处理依赖,创建压缩包 +# Package function: Copy files, handle dependencies, create archive +package() { + log_info "Packaging project..." + log_info "打包项目..." + mkdir -p $PACKAGE_DIR + # 复制安装的文件到打包目录 / Copy installed files to the package directory + cp -R $INSTALL_DIR/* $PACKAGE_DIR/ + + log_info "Copying dependencies..." + log_info "复制依赖项..." + # 复制所有可执行文件的MinGW依赖 / Copy MinGW dependencies for all executables + for file in $PACKAGE_DIR/bin/*.exe; do + ldd "$file" | grep mingw | awk '{print $3}' | xargs -I '{}' cp '{}' $PACKAGE_DIR/bin/ + done + + log_info "Creating archive..." + log_info "创建压缩包..." + # 使用7-Zip创建ZIP压缩包 / Create ZIP archive using 7-Zip + 7z a -tzip $ARCHIVE_NAME $PACKAGE_DIR/* + + log_info "Packaging complete. Archive created: $ARCHIVE_NAME" + log_info "打包完成。创建的压缩包:$ARCHIVE_NAME" +} + +# 主函数:执行完整的构建和打包流程 +# Main function: Execute the complete build and package process +main() { + clean + build + package +} + +# ===== 脚本入口 / Script Entry ===== +# 检查必要的命令是否存在 / Check if necessary commands exist +check_command cmake +check_command mingw32-make +check_command git +check_command 7z + +# 根据命令行参数执行相应的功能 +# Execute corresponding functionality based on command line arguments +case "$1" in + clean) + clean + ;; + build) + build + ;; + package) + package + ;; + *) + main + ;; +esac diff --git a/pysrc/app/command_dispatcher.py b/pysrc/app/command_dispatcher.py new file mode 100644 index 00000000..248f5f2b --- /dev/null +++ b/pysrc/app/command_dispatcher.py @@ -0,0 +1,131 @@ +# app/command_dispatcher.py +from typing import Dict, Any, Callable, Awaitable, List, Optional +from loguru import logger +import inspect +import asyncio +from datetime import datetime, timedelta + + +class Command: + def __init__(self, func: Callable, name: str, description: str, rate_limit: Optional[int] = None): + self.func = func + self.name = name + self.description = description + self.rate_limit = rate_limit + self.last_called: Dict[str, datetime] = {} + + +class CommandDispatcher: + def __init__(self): + self.commands: Dict[str, Command] = {} + self.middlewares: List[Callable] = [] + self.default_command: Optional[Command] = None + + def register(self, name: str, description: str = "", rate_limit: Optional[int] = None): + def decorator(f: Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]]]): + self.commands[name] = Command(f, name, description, rate_limit) + return f + return decorator + + def set_default(self, f: Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]]]): + self.default_command = Command(f, "default", "Default command handler") + + def add_middleware(self, middleware: Callable): + self.middlewares.append(middleware) + + async def dispatch(self, command_name: str, params: Dict[str, Any], user_id: str) -> Dict[str, Any]: + command = self.commands.get(command_name, self.default_command) + + if command is None: + return {"error": f"Unknown command: {command_name}"} + + # Apply rate limiting + if command.rate_limit: + last_called = command.last_called.get(user_id) + if last_called and datetime.now() - last_called < timedelta(seconds=command.rate_limit): + return {"error": f"Rate limit exceeded for command: {command_name}"} + command.last_called[user_id] = datetime.now() + + # Apply middlewares + for middleware in self.middlewares: + params = await middleware(command_name, params, user_id) + + try: + # Check if the command function is asynchronous + if inspect.iscoroutinefunction(command.func): + 
result = await command.func(params) + else: + # If it's not async, run it in an executor + loop = asyncio.get_event_loop() + result = await loop.run_in_executor(None, command.func, params) + + return result + except Exception as e: + logger.error(f"Error executing command {command_name}: {e}") + return {"error": f"Command execution failed: {str(e)}"} + + def get_command_list(self) -> List[Dict[str, Any]]: + return [{"name": cmd.name, "description": cmd.description} for cmd in self.commands.values()] + + async def batch_dispatch(self, commands: List[Dict[str, Any]], user_id: str) -> List[Dict[str, Any]]: + results = [] + for cmd in commands: + result = await self.dispatch(cmd.get("command"), cmd.get("params", {}), user_id) + results.append(result) + return results + +# Example usage: + + +dispatcher = CommandDispatcher() + + +@dispatcher.register("echo", "Echoes the input message", rate_limit=5) +async def echo_command(params: Dict[str, Any]) -> Dict[str, Any]: + return {"result": params.get("message", "No message provided")} + + +@dispatcher.register("add", "Adds two numbers") +def add_command(params: Dict[str, Any]) -> Dict[str, Any]: + a = params.get("a", 0) + b = params.get("b", 0) + return {"result": a + b} + + +@dispatcher.set_default +async def default_command(params: Dict[str, Any]) -> Dict[str, Any]: + return {"result": "Unknown command. Use 'help' to see available commands."} + + +async def log_middleware(command: str, params: Dict[str, Any], user_id: str) -> Dict[str, Any]: + logger.info(f"User {user_id} is executing command: {command}") + return params + +dispatcher.add_middleware(log_middleware) + +# Usage example: + + +async def main(): + result = await dispatcher.dispatch("echo", {"message": "Hello, World!"}, "user1") + print(result) # {"result": "Hello, World!"} + + result = await dispatcher.dispatch("add", {"a": 5, "b": 3}, "user1") + print(result) # {"result": 8} + + result = await dispatcher.dispatch("unknown", {}, "user1") + # {"result": "Unknown command. 
Use 'help' to see available commands."} + print(result) + + command_list = dispatcher.get_command_list() + # [{"name": "echo", "description": "Echoes the input message"}, {"name": "add", "description": "Adds two numbers"}] + print(command_list) + + batch_results = await dispatcher.batch_dispatch([ + {"command": "echo", "params": {"message": "First command"}}, + {"command": "add", "params": {"a": 10, "b": 20}} + ], "user1") + print(batch_results) # [{"result": "First command"}, {"result": 30}] + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/pysrc/app/dependence.py b/pysrc/app/dependence.py new file mode 100644 index 00000000..81e3e67d --- /dev/null +++ b/pysrc/app/dependence.py @@ -0,0 +1,16 @@ +from fastapi import Depends, HTTPException, status +from fastapi.security import HTTPBasic, HTTPBasicCredentials +from config.config import config + +security = HTTPBasic() + +def get_current_username(credentials: HTTPBasicCredentials = Depends(security)): + correct_username = credentials.username == config.auth_username + correct_password = credentials.password == config.auth_password + if not (correct_username and correct_password): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Incorrect username or password", + headers={"WWW-Authenticate": "Basic"}, + ) + return credentials.username diff --git a/pysrc/image/adaptive_stretch/__init__.py b/pysrc/image/adaptive_stretch/__init__.py new file mode 100644 index 00000000..6c543b23 --- /dev/null +++ b/pysrc/image/adaptive_stretch/__init__.py @@ -0,0 +1,2 @@ +from .stretch import AdaptiveStretch +from .preview import apply_real_time_preview diff --git a/pysrc/image/adaptive_stretch/preview.py b/pysrc/image/adaptive_stretch/preview.py new file mode 100644 index 00000000..522a740e --- /dev/null +++ b/pysrc/image/adaptive_stretch/preview.py @@ -0,0 +1,26 @@ +import matplotlib.pyplot as plt +import cv2 +import numpy as np +from .stretch import AdaptiveStretch +from typing import Optional, Tuple + +def apply_real_time_preview(image: np.ndarray, noise_threshold: float = 1e-4, contrast_protection: Optional[float] = None, max_curve_points: int = 106, roi: Optional[Tuple[int, int, int, int]] = None): + """ + Simulate real-time preview by iteratively applying the adaptive stretch + transformation and displaying the result. + + :param image: Input image as a numpy array (grayscale or color). + :param noise_threshold: Threshold for treating brightness differences as noise. + :param contrast_protection: Optional contrast protection parameter. + :param max_curve_points: Maximum points for the transformation curve. + :param roi: Tuple (x, y, width, height) defining the region of interest. 
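+
+    Example (illustrative usage; `image` is any grayscale or BGR numpy array):
+        apply_real_time_preview(image, noise_threshold=1e-3,
+                                contrast_protection=0.5,
+                                roi=(100, 100, 256, 256))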
+ """ + adaptive_stretch = AdaptiveStretch(noise_threshold, contrast_protection, max_curve_points) + preview_image = adaptive_stretch.stretch(image, roi) + + if len(preview_image.shape) == 3: + preview_image = cv2.cvtColor(preview_image, cv2.COLOR_BGR2RGB) + + plt.imshow(preview_image, cmap='gray' if len(preview_image.shape) == 2 else None) + plt.title(f"Noise Threshold: {noise_threshold}, Contrast Protection: {contrast_protection}") + plt.show() diff --git a/pysrc/image/adaptive_stretch/stretch.py b/pysrc/image/adaptive_stretch/stretch.py new file mode 100644 index 00000000..c8483be6 --- /dev/null +++ b/pysrc/image/adaptive_stretch/stretch.py @@ -0,0 +1,97 @@ +import numpy as np +import cv2 +from typing import Optional, Tuple + + +class AdaptiveStretch: + def __init__(self, noise_threshold: float = 1e-4, contrast_protection: Optional[float] = None, max_curve_points: int = 106): + """ + Initialize the AdaptiveStretch object with specific parameters. + + :param noise_threshold: Threshold for treating brightness differences as noise. + :param contrast_protection: Optional contrast protection parameter. + :param max_curve_points: Maximum points for the transformation curve. + """ + self.noise_threshold = noise_threshold + self.contrast_protection = contrast_protection + self.max_curve_points = max_curve_points + + def compute_brightness_diff(self, image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """ + Compute brightness differences between adjacent pixels. + Returns matrices of differences along the x and y axes. + + :param image: Input image as a numpy array (grayscale). + :return: Tuple of differences along x and y axes. + """ + diff_x = np.diff(image, axis=1) # differences between columns + diff_y = np.diff(image, axis=0) # differences between rows + + # Pad the differences to match the original image size + diff_x = np.pad(diff_x, ((0, 0), (0, 1)), mode='constant') + diff_y = np.pad(diff_y, ((0, 1), (0, 0)), mode='constant') + + return diff_x, diff_y + + def stretch(self, image: np.ndarray, roi: Optional[Tuple[int, int, int, int]] = None) -> np.ndarray: + """ + Apply the AdaptiveStretch transformation to the image. + + :param image: Input image as a numpy array (grayscale or color). + :param roi: Tuple (x, y, width, height) defining the region of interest. + :return: Stretched image. 
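+
+        Example (illustrative usage):
+            stretcher = AdaptiveStretch(noise_threshold=1e-3, contrast_protection=0.5)
+            stretched = stretcher.stretch(image, roi=(0, 0, 512, 512))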
+ """ + if len(image.shape) == 3: + channels = cv2.split(image) + else: + channels = [image] + + stretched_channels = [] + + for channel in channels: + # Normalize the channel to the range [0, 1] + channel = channel.astype(np.float32) / 255.0 + + if roi is not None: + x, y, w, h = roi + channel_roi = channel[y:y+h, x:x+w] + else: + channel_roi = channel + + diff_x, diff_y = self.compute_brightness_diff(channel_roi) + + positive_forces = np.maximum(diff_x, 0) + np.maximum(diff_y, 0) + negative_forces = np.minimum(diff_x, 0) + np.minimum(diff_y, 0) + + positive_forces[positive_forces < self.noise_threshold] = 0 + negative_forces[negative_forces > -self.noise_threshold] = 0 + + transformation_curve = positive_forces + negative_forces + + if self.contrast_protection is not None: + transformation_curve = np.clip( + transformation_curve, -self.contrast_protection, self.contrast_protection) + + resampled_curve = cv2.resize( + transformation_curve, (self.max_curve_points, 1), interpolation=cv2.INTER_LINEAR) + + interpolated_curve = cv2.resize( + resampled_curve, (channel_roi.shape[1], channel_roi.shape[0]), interpolation=cv2.INTER_LINEAR) + + stretched_channel = channel_roi + interpolated_curve + + stretched_channel = np.clip( + stretched_channel * 255, 0, 255).astype(np.uint8) + + if roi is not None: + channel[y:y+h, x:x+w] = stretched_channel + stretched_channel = channel + + stretched_channels.append(stretched_channel) + + if len(stretched_channels) > 1: + stretched_image = cv2.merge(stretched_channels) + else: + stretched_image = stretched_channels[0] + + return stretched_image diff --git a/pysrc/image/api/strecth_count.py b/pysrc/image/api/strecth_count.py new file mode 100644 index 00000000..0e05b936 --- /dev/null +++ b/pysrc/image/api/strecth_count.py @@ -0,0 +1,200 @@ +from pathlib import Path +from typing import Tuple, Dict, Optional, List +import json +import numpy as np +import cv2 +from astropy.io import fits +from scipy import ndimage + + +def debayer_image(img: np.ndarray, bayer_pattern: Optional[str] = None) -> np.ndarray: + bayer_patterns = { + "RGGB": cv2.COLOR_BAYER_RGGB2BGR, + "GBRG": cv2.COLOR_BAYER_GBRG2BGR, + "BGGR": cv2.COLOR_BAYER_BGGR2BGR, + "GRBG": cv2.COLOR_BAYER_GRBG2BGR + } + return cv2.cvtColor(img, bayer_patterns.get(bayer_pattern, cv2.COLOR_BAYER_RGGB2BGR)) + + +def resize_image(img: np.ndarray, target_size: int) -> np.ndarray: + scale = min(target_size / max(img.shape[:2]), 1) + if scale < 1: + return cv2.resize(img, None, fx=scale, fy=scale, interpolation=cv2.INTER_AREA) + return img + + +def normalize_image(img: np.ndarray) -> np.ndarray: + if not np.allclose(img, img.astype(np.uint8)): + img = cv2.normalize(img, None, alpha=0, beta=255, + norm_type=cv2.NORM_MINMAX) + return img.astype(np.uint8) + + +def stretch_image(img: np.ndarray, is_color: bool) -> np.ndarray: + if is_color: + return ComputeAndStretch_ThreeChannels(img, True) + return ComputeStretch_OneChannels(img, True) + + +def detect_stars(img: np.ndarray, remove_hotpixel: bool, remove_noise: bool, do_star_mark: bool, mark_img: Optional[np.ndarray] = None) -> Tuple[np.ndarray, float, float, Dict[str, float]]: + return StarDetectAndHfr(img, remove_hotpixel, remove_noise, do_star_mark, down_sample_mean_std=True, mark_img=mark_img) + +# New functions for enhanced image processing + + +def apply_gaussian_blur(img: np.ndarray, kernel_size: int = 5) -> np.ndarray: + """Apply Gaussian blur to reduce noise.""" + return cv2.GaussianBlur(img, (kernel_size, kernel_size), 0) + + +def apply_unsharp_mask(img: 
np.ndarray, kernel_size: int = 5, amount: float = 1.5) -> np.ndarray: + """Apply unsharp mask to enhance image details.""" + blurred = cv2.GaussianBlur(img, (kernel_size, kernel_size), 0) + return cv2.addWeighted(img, amount + 1, blurred, -amount, 0) + + +def equalize_histogram(img: np.ndarray) -> np.ndarray: + """Apply histogram equalization to improve contrast.""" + if len(img.shape) == 2: + return cv2.equalizeHist(img) + else: + ycrcb = cv2.cvtColor(img, cv2.COLOR_BGR2YCrCb) + ycrcb[:, :, 0] = cv2.equalizeHist(ycrcb[:, :, 0]) + return cv2.cvtColor(ycrcb, cv2.COLOR_YCrCb2BGR) + + +def remove_hot_pixels(img: np.ndarray, threshold: float = 3.0) -> np.ndarray: + """Remove hot pixels using median filter and thresholding.""" + median = ndimage.median_filter(img, size=3) + diff = np.abs(img - median) + mask = diff > (threshold * np.std(diff)) + img[mask] = median[mask] + return img + + +def adjust_gamma(img: np.ndarray, gamma: float = 1.0) -> np.ndarray: + """Adjust image gamma.""" + inv_gamma = 1.0 / gamma + table = np.array([((i / 255.0) ** inv_gamma) * + 255 for i in np.arange(0, 256)]).astype("uint8") + return cv2.LUT(img, table) + + +def apply_clahe(img: np.ndarray, clip_limit: float = 2.0, tile_grid_size: Tuple[int, int] = (8, 8)) -> np.ndarray: + """Apply Contrast Limited Adaptive Histogram Equalization (CLAHE).""" + if len(img.shape) == 2: + clahe = cv2.createCLAHE(clipLimit=clip_limit, + tileGridSize=tile_grid_size) + return clahe.apply(img) + else: + lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB) + l, a, b = cv2.split(lab) + clahe = cv2.createCLAHE(clipLimit=clip_limit, + tileGridSize=tile_grid_size) + cl = clahe.apply(l) + limg = cv2.merge((cl, a, b)) + return cv2.cvtColor(limg, cv2.COLOR_LAB2BGR) + + +def denoise_image(img: np.ndarray, h: float = 10) -> np.ndarray: + """Apply Non-Local Means Denoising.""" + return cv2.fastNlMeansDenoisingColored(img, None, h, h, 7, 21) + + +def enhance_image(img: np.ndarray, config: Dict[str, bool]) -> np.ndarray: + """Apply a series of image enhancements based on the configuration.""" + if config.get('remove_hot_pixels', False): + img = remove_hot_pixels(img) + if config.get('denoise', False): + img = denoise_image(img) + if config.get('equalize_histogram', False): + img = equalize_histogram(img) + if config.get('apply_clahe', False): + img = apply_clahe(img) + if config.get('unsharp_mask', False): + img = apply_unsharp_mask(img) + if config.get('adjust_gamma', False): + img = adjust_gamma(img, gamma=config.get('gamma', 1.0)) + return img + + +def process_image(filepath: Path, config: Dict[str, bool], resize_size: int = 2048) -> Tuple[Optional[np.ndarray], Dict[str, float]]: + try: + img, header = fits.getdata(filepath, header=True) + except Exception: + return None, {"star_count": -1, "average_hfr": -1, "max_star": -1, "min_star": -1, "average_star": -1} + + is_color = 'BAYERPAT' in header + if is_color: + img = debayer_image(img, header['BAYERPAT']) + + img = resize_image(img, resize_size) + img = normalize_image(img) + + # Apply image enhancements + img = enhance_image(img, config) + + if config['do_stretch']: + img = stretch_image(img, is_color) + + if config['do_star_count']: + img, star_count, avg_hfr, area_range = detect_stars( + img, config['remove_hotpixel'], config['remove_noise'], config['do_star_mark']) + return img, { + "star_count": float(star_count), + "average_hfr": avg_hfr, + "max_star": area_range['max'], + "min_star": area_range['min'], + "average_star": area_range['average'] + } + + return img, {"star_count": -1, "average_hfr": -1, 
"max_star": -1, "min_star": -1, "average_star": -1} + + +def ImageStretchAndStarCount_Optim(filepath: Path, config: Dict[str, bool], resize_size: int = 2048, + jpg_file: Optional[Path] = None, star_file: Optional[Path] = None) -> Tuple[Optional[np.ndarray], Dict[str, float]]: + img, result = process_image(filepath, config, resize_size) + + if jpg_file and img is not None: + cv2.imwrite(str(jpg_file), img) + if star_file: + with star_file.open('w') as f: + json.dump(result, f) + + return img, result + + +def StreamingDebayerAndStretch(fits_data: bytearray, width: int, height: int, config: Dict[str, bool], + resize_size: int = 2048, bayer_type: Optional[str] = None) -> Optional[np.ndarray]: + img = np.frombuffer(fits_data, dtype=np.uint16).reshape(height, width) + + if bayer_type: + img = debayer_image(img, bayer_type) + + img = resize_image(img, resize_size) + img = normalize_image(img) + + # Apply image enhancements + img = enhance_image(img, config) + + if config.get('do_stretch', True): + img = stretch_image(img, len(img.shape) == 3) + + return img + + +def StreamingDebayer(fits_data: bytearray, width: int, height: int, config: Dict[str, bool], + resize_size: int = 2048, bayer_type: Optional[str] = None) -> Optional[np.ndarray]: + img = np.frombuffer(fits_data, dtype=np.uint16).reshape(height, width) + + if bayer_type: + img = debayer_image(img, bayer_type) + + img = resize_image(img, resize_size) + img = normalize_image(img) + + # Apply image enhancements + img = enhance_image(img, config) + + return img diff --git a/pysrc/image/astroalign/astroalign.py b/pysrc/image/astroalign/astroalign.py new file mode 100644 index 00000000..2e52a688 --- /dev/null +++ b/pysrc/image/astroalign/astroalign.py @@ -0,0 +1,553 @@ +""" +Module for starfield image registration using triangle-based invariants. + +This module provides functions to estimate and apply geometric transformations between +two sets of starfield images. The core technique relies on computing invariant features +(triangle invariants) from sets of nearest neighbors of stars in both images. The +transformation between the two images is then estimated using these invariant features. + +The key functionalities include: +- Estimating transformation between two images +- Applying the estimated transformation to align images +- Extracting source positions (stars) from images +- RANSAC algorithm for robust model estimation + +Author: Max Qian +Version: 2.6.0 (lithium) +""" + +__version__ = "2.6.0" + +__all__ = [ + "MIN_MATCHES_FRACTION", + "MaxIterError", + "NUM_NEAREST_NEIGHBORS", + "PIXEL_TOL", + "apply_transform", + "estimate_transform", + "find_transform", + "matrix_transform", + "register", +] + +from collections import Counter +from functools import partial +from itertools import combinations +from typing import Any, Tuple, Union + +import sep_pjw as sep +import numpy as np +from numpy.typing import NDArray +from scipy.spatial import KDTree +from skimage.transform import estimate_transform, matrix_transform, warp + +try: + import bottleneck as bn +except ImportError: + HAS_BOTTLENECK = False +else: + HAS_BOTTLENECK = True + +PIXEL_TOL = 2 +"""The pixel distance tolerance to assume two invariant points are the same. + +Default: 2 +""" + +MIN_MATCHES_FRACTION = 0.8 +"""The minimum fraction of triangle matches to accept a transformation. + +If the minimum fraction yields more than 10 triangles, 10 is used instead. 
+ +Default: 0.8 +""" + +NUM_NEAREST_NEIGHBORS = 5 +""" +The number of nearest neighbors of a given star (including itself) to construct +the triangle invariants. + +Default: 5 +""" + +_default_median = bn.nanmedian if HAS_BOTTLENECK else np.nanmedian # pragma: no cover +""" +Default median function when/if optional bottleneck is available +""" + +_default_average = bn.nanmean if HAS_BOTTLENECK else np.nanmean # pragma: no cover +""" +Default mean function when/if optional bottleneck is available +""" + + +def _invariantfeatures(x1: NDArray, x2: NDArray, x3: NDArray) -> list[float]: + """ + Given 3 points x1, x2, x3, return the invariant features for the set. + + Invariant features are ratios of side lengths of the triangles formed by these points, + sorted by size. These features are scale-invariant and can be used to compare star + positions between images. + + Args: + x1, x2, x3: 2D coordinates of points in the source image. + + Returns: + List containing two invariant feature values derived from the triangle side ratios. + """ + sides = np.sort( + [np.linalg.norm(x1 - x2), np.linalg.norm(x2 - x3), np.linalg.norm(x1 - x3)]) + return [sides[2] / sides[1], sides[1] / sides[0]] + + +def _arrangetriplet(sources: NDArray, vertex_indices: tuple[int, int, int]) -> NDArray: + """ + Reorder the given triplet of vertex indices according to the length of their sides. + + This function returns the indices in a consistent order based on the triangle's side + lengths. It ensures that the triangle invariants are consistently computed. + + Args: + sources: Array of source star positions. + vertex_indices: Indices of the three vertices that form the triangle. + + Returns: + Reordered array of vertex indices based on side lengths. + """ + ind1, ind2, ind3 = vertex_indices + x1, x2, x3 = sources[vertex_indices] + + side_ind = np.array([(ind1, ind2), (ind2, ind3), (ind3, ind1)]) + side_lengths = [np.linalg.norm( + x1 - x2), np.linalg.norm(x2 - x3), np.linalg.norm(x3 - x1)] + l1_ind, l2_ind, l3_ind = np.argsort(side_lengths) + + count = Counter(side_ind[[l1_ind, l2_ind]].flatten()) + a = count.most_common(1)[0][0] + count = Counter(side_ind[[l2_ind, l3_ind]].flatten()) + b = count.most_common(1)[0][0] + count = Counter(side_ind[[l3_ind, l1_ind]].flatten()) + c = count.most_common(1)[0][0] + + return np.array([a, b, c]) + + +def _generate_invariants(sources: NDArray) -> Tuple[NDArray, NDArray]: + """ + Generate invariant features and the corresponding triangles from a set of source points. + + This function constructs triangles from the nearest neighbors of each source point and + calculates their invariant features. The invariants are used for matching between images. + + Args: + sources: Array of source star positions. + + Returns: + A tuple containing the unique invariant features and the corresponding triangle vertices. 
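+
+    Example: three control points forming an equilateral triangle have equal
+    sorted side lengths, so _invariantfeatures returns [1.0, 1.0]; any similar
+    (scaled or rotated) triangle maps to the same pair, which is what makes
+    these features usable for matching across images.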
+ """ + arrange = partial(_arrangetriplet, sources=sources) + + inv = [] + triang_vrtx = [] + coordtree = KDTree(sources) + knn = min(len(sources), NUM_NEAREST_NEIGHBORS) + for asrc in sources: + _, indx = coordtree.query(asrc, knn) + + all_asterism_triang = [arrange(vertex_indices=list(cmb)) + for cmb in combinations(indx, 3)] + triang_vrtx.extend(all_asterism_triang) + + inv.extend([_invariantfeatures(*sources[triplet]) + for triplet in all_asterism_triang]) + + uniq_ind = [pos for (pos, elem) in enumerate(inv) + if elem not in inv[pos + 1:]] + inv_uniq = np.array(inv)[uniq_ind] + triang_vrtx_uniq = np.array(triang_vrtx)[uniq_ind] + + return inv_uniq, triang_vrtx_uniq + + +class _MatchTransform: + """ + A class to manage the fitting of a geometric transformation using matched invariant points. + + This class handles the estimation of the 2D similarity transformation between + two sets of points, and computes errors between estimated and actual points. + """ + + def __init__(self, source: NDArray, target: NDArray): + """ + Initialize the transformation model with source and target control points. + + Args: + source: Source control points. + target: Target control points. + """ + self.source = source + self.target = target + + def fit(self, data: NDArray) -> Any: + """ + Estimate the best 2D similarity transformation from the matched points in data. + + Args: + data: Matched point pairs from source and target. + + Returns: + A similarity transform object. + """ + d1, d2, d3 = data.shape + s, d = data.reshape(d1 * d2, d3).T + return estimate_transform("similarity", self.source[s], self.target[d]) + + def get_error(self, data: NDArray, approx_t: Any) -> NDArray: + """ + Calculate the maximum residual error for the matched points given the estimated transform. + + Args: + data: Matched point pairs. + approx_t: Estimated transformation model. + + Returns: + Maximum residual error for each set of matched points. + """ + d1, d2, d3 = data.shape + s, d = data.reshape(d1 * d2, d3).T + resid = approx_t.residuals( + self.source[s], self.target[d]).reshape(d1, d2) + return resid.max(axis=1) + + +def _data(image: Union[NDArray, Any]) -> NDArray: + """ + Retrieve the underlying 2D pixel data from the image. + + Args: + image: The input image. + + Returns: + The pixel data as a 2D NumPy array. + """ + if hasattr(image, "data") and isinstance(image.data, np.ndarray): + return image.data + return np.asarray(image) + + +def _mask(image: Union[NDArray, Any]) -> Union[NDArray, None]: + """ + Retrieve the mask from the image, if available. + + Args: + image: The input image. + + Returns: + The mask as a 2D NumPy array, or None if no mask is present. + """ + if hasattr(image, "mask"): + thenp_mask = np.asarray(image.mask) + return thenp_mask if thenp_mask.ndim == 2 else np.logical_or.reduce(thenp_mask, axis=-1) + return None + + +def _bw(image: NDArray) -> NDArray: + """ + Convert the input image to a 2D grayscale image. + + Args: + image: Input image, possibly with multiple channels. + + Returns: + Grayscale 2D NumPy array. + """ + return image if image.ndim == 2 else _default_average(image, axis=-1) + + +def _shape(image: NDArray) -> tuple[int, int]: + """ + Get the shape of the image, ignoring channels. + + Args: + image: Input image. + + Returns: + A tuple representing the 2D shape (height, width) of the image. 
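+
+    Example: both a (1024, 768) grayscale array and a (1024, 768, 3) color
+    array yield (1024, 768).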
+ """ + return image.shape if image.ndim == 2 else image.shape[:2] + + +def find_transform( + source: Union[NDArray, Any], + target: Union[NDArray, Any], + max_control_points: int = 50, + detection_sigma: int = 5, + min_area: int = 5 +) -> Tuple[Any, Tuple[NDArray, NDArray]]: + """ + Estimate the geometric transformation between the source and target images. + + This function identifies control points (stars) in both the source and target images, + computes their triangle-based invariant features, and finds the best transformation + to align the source to the target using a RANSAC-based method. + + Args: + source: The source image to be transformed. + target: The target image to match the source to. + max_control_points: Maximum number of control points to use for transformation. + detection_sigma: Sigma threshold for detecting control points. + min_area: Minimum area for detecting sources in the image. + + Returns: + A tuple containing the estimated transformation object and a tuple of matched control points + from the source and target images. + + Raises: + ValueError: If fewer than 3 control points are found in either image. + TypeError: If the input type of source or target is unsupported. + """ + try: + source_controlp = (np.array(source)[:max_control_points] if len(_data(source)[0]) == 2 + else _find_sources(_bw(_data(source)), detection_sigma, min_area, _mask(source))[:max_control_points]) + except Exception: + raise TypeError("Input type for source not supported.") + + try: + target_controlp = (np.array(target)[:max_control_points] if len(_data(target)[0]) == 2 + else _find_sources(_bw(_data(target)), detection_sigma, min_area, _mask(target))[:max_control_points]) + except Exception: + raise TypeError("Input type for target not supported.") + + if len(source_controlp) < 3 or len(target_controlp) < 3: + raise ValueError( + "Reference stars in source or target image are less than the minimum value (3).") + + source_invariants, source_asterisms = _generate_invariants(source_controlp) + target_invariants, target_asterisms = _generate_invariants(target_controlp) + + source_tree = KDTree(source_invariants) + target_tree = KDTree(target_invariants) + + matches_list = source_tree.query_ball_tree(target_tree, r=0.1) + + matches = [list(zip(t1, t2)) for t1, t2_list in zip( + source_asterisms, matches_list) for t2 in target_asterisms[t2_list]] + matches = np.array(matches) + + inv_model = _MatchTransform(source_controlp, target_controlp) + n_invariants = len(matches) + min_matches = max(1, min(10, int(n_invariants * MIN_MATCHES_FRACTION))) + + if (len(source_controlp) == 3 or len(target_controlp) == 3) and len(matches) == 1: + best_t = inv_model.fit(matches) + inlier_ind = np.arange(len(matches)) + else: + best_t, inlier_ind = _ransac( + matches, inv_model, PIXEL_TOL, min_matches) + + triangle_inliers = matches[inlier_ind] + inl_arr = triangle_inliers.reshape(-1, 2) + inl_unique = set(map(tuple, inl_arr)) + + inl_dict = {} + for s_i, t_i in inl_unique: + s_vertex = source_controlp[s_i] + t_vertex = target_controlp[t_i] + t_vertex_pred = matrix_transform(s_vertex, best_t.params) + error = np.linalg.norm(t_vertex_pred - t_vertex) + + if s_i not in inl_dict or error < inl_dict[s_i][1]: + inl_dict[s_i] = (t_i, error) + + inl_arr_unique = np.array([[s_i, t_i] + for s_i, (t_i, _) in inl_dict.items()]) + s, d = inl_arr_unique.T + + return best_t, (source_controlp[s], target_controlp[d]) + + +def apply_transform( + transform: Any, + source: Union[NDArray, Any], + target: Union[NDArray, Any], + 
fill_value: Union[float, None] = None, + propagate_mask: bool = False +) -> Tuple[NDArray, NDArray]: + """ + Apply the estimated transformation to align the source image to the target image. + + The transformation is applied to the source image, and an optional mask is propagated + if requested. The function returns the aligned source image and a binary footprint + of the transformed region. + + Args: + transform: The transformation to apply. + source: The source image to be transformed. + target: The target image to align the source to. + fill_value: Value to fill in regions outside the source image after transformation. + propagate_mask: Whether to propagate the source mask after transformation. + + Returns: + A tuple containing the aligned source image and the transformation footprint. + """ + source_data = _data(source) + target_shape = _data(target).shape + + aligned_image = warp( + source_data, + inverse_map=transform.inverse, + output_shape=target_shape, + order=3, + mode="constant", + cval=_default_median(source_data), + clip=True, + preserve_range=True, + ) + + footprint = warp( + np.zeros(_shape(source_data), dtype="float32"), + inverse_map=transform.inverse, + output_shape=target_shape, + cval=1.0, + ) + footprint = footprint > 0.4 + + source_mask = _mask(source) + if source_mask is not None and propagate_mask: + if source_mask.shape == source_data.shape: + source_mask_rot = warp( + source_mask.astype("float32"), + inverse_map=transform.inverse, + output_shape=target_shape, + cval=1.0, + ) + source_mask_rot = source_mask_rot > 0.4 + footprint |= source_mask_rot + + if fill_value is not None: + aligned_image[footprint] = fill_value + + return aligned_image, footprint + + +def register( + source: Union[NDArray, Any], + target: Union[NDArray, Any], + fill_value: Union[float, None] = None, + propagate_mask: bool = False, + max_control_points: int = 50, + detection_sigma: int = 5, + min_area: int = 5 +) -> Tuple[NDArray, NDArray]: + """ + Register and align the source image to the target image using triangle invariants. + + This function estimates the transformation between the source and target images, + applies the transformation, and returns the aligned source image along with + the transformation footprint. + + Args: + source: The source image to be aligned. + target: The target image for alignment. + fill_value: Value to fill in regions outside the source image after transformation. + propagate_mask: Whether to propagate the source mask after transformation. + max_control_points: Maximum number of control points to use for transformation. + detection_sigma: Sigma threshold for detecting control points. + min_area: Minimum area for detecting sources in the image. + + Returns: + A tuple containing the aligned source image and the transformation footprint. + """ + t, _ = find_transform( + source=source, + target=target, + max_control_points=max_control_points, + detection_sigma=detection_sigma, + min_area=min_area, + ) + return apply_transform(t, source, target, fill_value, propagate_mask) + + +def _find_sources(img: NDArray, detection_sigma: int = 5, min_area: int = 5, mask: Union[NDArray, None] = None) -> NDArray: + """ + Detect bright sources (e.g., stars) in the image using SEP (Source Extractor). + + This function returns the coordinates of sources sorted by brightness. + + Args: + img: The input image in which to detect sources. + detection_sigma: Sigma threshold for source detection. + min_area: Minimum area for detecting sources. 
+ mask: Optional mask for ignoring certain parts of the image. + + Returns: + A NumPy array of detected source coordinates (x, y), sorted by brightness. + """ + image = img.astype("float32") + bkg = sep.Background(image, mask=mask) + thresh = detection_sigma * bkg.globalrms + sources = sep.extract(image - bkg.back(), thresh, + minarea=min_area, mask=mask) + sources.sort(order="flux") + return np.array([[asrc["x"], asrc["y"]] for asrc in sources[::-1]]) + + +class MaxIterError(RuntimeError): + """ + Custom error raised if the maximum number of iterations is reached during the RANSAC process. + + This exception indicates that the RANSAC algorithm has exhausted all possible + matching triangles without finding an acceptable transformation. + """ + pass + + +def _ransac(data: NDArray, model: Any, thresh: float, min_matches: int) -> Tuple[Any, NDArray]: + """ + Fit a model to data using the RANSAC (Random Sample Consensus) algorithm. + + This robust method estimates the transformation model by iteratively fitting + to subsets of data and discarding outliers. + + Args: + data: Matched point pairs. + model: The transformation model to fit. + thresh: Error threshold to consider a data point as an inlier. + min_matches: Minimum number of inliers required to accept the model. + + Returns: + A tuple containing the best-fit model and the indices of inliers. + + Raises: + MaxIterError: If the maximum number of iterations is reached without finding + an acceptable transformation. + """ + n_data = data.shape[0] + all_idxs = np.arange(n_data) + np.random.default_rng().shuffle(all_idxs) + + for iter_i in range(n_data): + maybe_idxs = all_idxs[iter_i:iter_i + 1] + test_idxs = np.concatenate([all_idxs[:iter_i], all_idxs[iter_i + 1:]]) + maybeinliers = data[maybe_idxs, :] + test_points = data[test_idxs, :] + maybemodel = model.fit(maybeinliers) + test_err = model.get_error(test_points, maybemodel) + also_idxs = test_idxs[test_err < thresh] + alsoinliers = data[also_idxs, :] + if len(alsoinliers) >= min_matches: + good_data = np.concatenate((maybeinliers, alsoinliers)) + good_fit = model.fit(good_data) + break + else: + raise MaxIterError( + "List of matching triangles exhausted before an acceptable transformation was found") + + better_fit = good_fit + for _ in range(3): + test_err = model.get_error(data, better_fit) + better_inlier_idxs = np.arange(n_data)[test_err < thresh] + better_data = data[better_inlier_idxs] + better_fit = model.fit(better_data) + + return better_fit, better_inlier_idxs diff --git a/pysrc/image/auto_histogram/__init__.py b/pysrc/image/auto_histogram/__init__.py new file mode 100644 index 00000000..96186445 --- /dev/null +++ b/pysrc/image/auto_histogram/__init__.py @@ -0,0 +1,2 @@ +from .histogram import auto_histogram +from .utils import save_image, load_image diff --git a/pysrc/image/auto_histogram/histogram.py b/pysrc/image/auto_histogram/histogram.py new file mode 100644 index 00000000..a402b7ff --- /dev/null +++ b/pysrc/image/auto_histogram/histogram.py @@ -0,0 +1,111 @@ +import cv2 +import numpy as np +from typing import Optional, List, Tuple +from .utils import save_image, load_image + + +def auto_histogram(image: Optional[np.ndarray] = None, clip_shadow: float = 0.01, clip_highlight: float = 0.01, + target_median: int = 128, method: str = 'gamma', apply_clahe: bool = False, + clahe_clip_limit: float = 2.0, clahe_tile_grid_size: Tuple[int, int] = (8, 8), + apply_noise_reduction: bool = False, noise_reduction_method: str = 'median', + apply_sharpening: bool = False, 
sharpening_strength: float = 1.0, + batch_process: bool = False, file_list: Optional[List[str]] = None) -> Optional[List[np.ndarray]]: + """ + Apply automated histogram transformations and enhancements to an image or a batch of images. + + Parameters: + - image: Input image, grayscale or RGB. + - clip_shadow: Percentage of shadow pixels to clip. + - clip_highlight: Percentage of highlight pixels to clip. + - target_median: Target median value for histogram stretching. + - method: Stretching method ('gamma', 'logarithmic', 'mtf'). + - apply_clahe: Apply CLAHE (Contrast Limited Adaptive Histogram Equalization). + - clahe_clip_limit: CLAHE clip limit. + - clahe_tile_grid_size: CLAHE grid size. + - apply_noise_reduction: Apply noise reduction. + - noise_reduction_method: Noise reduction method ('median', 'gaussian'). + - apply_sharpening: Apply image sharpening. + - sharpening_strength: Strength of sharpening. + - batch_process: Enable batch processing mode. + - file_list: List of file paths for batch processing. + + Returns: + - Processed image or list of processed images. + """ + def histogram_clipping(image: np.ndarray, clip_shadow: float, clip_highlight: float) -> np.ndarray: + flat = image.flatten() + low_val = np.percentile(flat, clip_shadow * 100) + high_val = np.percentile(flat, 100 - clip_highlight * 100) + return np.clip(image, low_val, high_val) + + def gamma_transformation(image: np.ndarray, target_median: int) -> np.ndarray: + mean_val = np.median(image) + gamma = np.log(target_median / 255.0) / np.log(mean_val / 255.0) + return np.array(255 * (image / 255.0) ** gamma, dtype='uint8') + + def logarithmic_transformation(image: np.ndarray) -> np.ndarray: + c = 255 / np.log(1 + np.max(image)) + return np.array(c * np.log(1 + image), dtype='uint8') + + def mtf_transformation(image: np.ndarray, target_median: int) -> np.ndarray: + mean_val = np.median(image) + mtf = target_median / mean_val + return np.array(image * mtf, dtype='uint8') + + def apply_clahe_method(image: np.ndarray, clip_limit: float, tile_grid_size: Tuple[int, int]) -> np.ndarray: + clahe = cv2.createCLAHE(clipLimit=clip_limit, + tileGridSize=tile_grid_size) + if len(image.shape) == 2: + return clahe.apply(image) + else: + return cv2.merge([clahe.apply(channel) for channel in cv2.split(image)]) + + def noise_reduction(image: np.ndarray, method: str) -> np.ndarray: + if method == 'median': + return cv2.medianBlur(image, 3) + elif method == 'gaussian': + return cv2.GaussianBlur(image, (3, 3), 0) + else: + raise ValueError("Invalid noise reduction method specified.") + + def sharpen_image(image: np.ndarray, strength: float) -> np.ndarray: + kernel = np.array([[-1, -1, -1], [-1, 9 + strength, -1], [-1, -1, -1]]) + return cv2.filter2D(image, -1, kernel) + + def process_single_image(image: np.ndarray) -> np.ndarray: + if apply_noise_reduction: + image = noise_reduction(image, noise_reduction_method) + + image = histogram_clipping(image, clip_shadow, clip_highlight) + + if method == 'gamma': + image = gamma_transformation(image, target_median) + elif method == 'logarithmic': + image = logarithmic_transformation(image) + elif method == 'mtf': + image = mtf_transformation(image, target_median) + else: + raise ValueError("Invalid method specified.") + + if apply_clahe: + image = apply_clahe_method( + image, clahe_clip_limit, clahe_tile_grid_size) + + if apply_sharpening: + image = sharpen_image(image, sharpening_strength) + + return image + + if batch_process: + if file_list is None: + raise ValueError( + "File list cannot be None 
when batch processing is enabled.") + processed_images = [] + for file_path in file_list: + image = load_image(file_path, method != 'mtf') + processed_image = process_single_image(image) + processed_images.append(processed_image) + save_image(f'processed_{file_path}', processed_image) + return processed_images + else: + return process_single_image(image) diff --git a/pysrc/image/auto_histogram/processing.py b/pysrc/image/auto_histogram/processing.py new file mode 100644 index 00000000..3e64425d --- /dev/null +++ b/pysrc/image/auto_histogram/processing.py @@ -0,0 +1,28 @@ +from .histogram import auto_histogram +from .utils import save_image, load_image +import os +from typing import List + + +def process_directory(directory: str, output_directory: str, method: str = 'gamma', **kwargs): + """ + Process all images in a directory using the auto_histogram function. + + :param directory: Input directory containing images to process. + :param output_directory: Directory to save processed images. + :param method: Histogram stretching method ('gamma', 'logarithmic', 'mtf'). + :param kwargs: Additional parameters for auto_histogram. + """ + if not os.path.exists(output_directory): + os.makedirs(output_directory) + + file_list = [os.path.join(directory, file) for file in os.listdir( + directory) if file.endswith(('.jpg', '.png'))] + + processed_images = auto_histogram( + None, method=method, batch_process=True, file_list=file_list, **kwargs) + + for file, image in zip(file_list, processed_images): + output_path = os.path.join( + output_directory, f'processed_{os.path.basename(file)}') + save_image(output_path, image) diff --git a/pysrc/image/auto_histogram/utils.py b/pysrc/image/auto_histogram/utils.py new file mode 100644 index 00000000..f62a21c4 --- /dev/null +++ b/pysrc/image/auto_histogram/utils.py @@ -0,0 +1,23 @@ +import cv2 +import numpy as np +from typing import Tuple, Union + +def save_image(filepath: str, image: np.ndarray): + """ + Save an image to the specified filepath. + + :param filepath: Path to save the image. + :param image: Image data. + """ + cv2.imwrite(filepath, image) + +def load_image(filepath: str, grayscale: bool = False) -> np.ndarray: + """ + Load an image from the specified filepath. + + :param filepath: Path to load the image from. + :param grayscale: Load image as grayscale if True. + :return: Loaded image. 
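In batch mode, auto_histogram saves each result as f'processed_{file_path}', which mangles inputs that contain directory components (e.g. 'lights/m31.png' becomes 'processed_lights/m31.png', pointing at a folder that may not exist). process_directory already builds output names from the basename; the helper below is a minimal sketch, not part of the module, showing the same convention for other callers.

import os

def make_output_path(input_path: str, output_dir: str = ".") -> str:
    """Build 'processed_<basename>' inside output_dir instead of prefixing the whole path."""
    base = os.path.basename(input_path)          # 'lights/m31.png' -> 'm31.png'
    return os.path.join(output_dir, f"processed_{base}")

# Example: make_output_path('lights/m31.png', 'out') -> 'out/processed_m31.png'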
+ """ + flags = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR + return cv2.imread(filepath, flags) diff --git a/pysrc/image/channel/__init__.py b/pysrc/image/channel/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pysrc/image/channel/combination.py b/pysrc/image/channel/combination.py new file mode 100644 index 00000000..a072018c --- /dev/null +++ b/pysrc/image/channel/combination.py @@ -0,0 +1,67 @@ +from PIL import Image +import numpy as np +from skimage import color +import cv2 + +def resize_to_match(image, target_size): + return image.resize(target_size, Image.ANTIALIAS) + +def load_image_as_gray(path): + return Image.open(path).convert('L') + +def combine_channels(channels, color_space='RGB'): + match color_space: + case 'RGB': + return Image.merge("RGB", channels) + + case 'LAB': + lab_image = Image.merge("LAB", channels) + return lab_image.convert('RGB') + + case 'HSV': + hsv_image = Image.merge("HSV", channels) + return hsv_image.convert('RGB') + + case 'HSI': + hsi_image = np.dstack([np.array(ch) / 255.0 for ch in channels]) + rgb_image = color.hsv2rgb(hsi_image) # Scikit-image doesn't have direct HSI support; using HSV as proxy + return Image.fromarray((rgb_image * 255).astype(np.uint8)) + + case _: + raise ValueError(f"Unsupported color space: {color_space}") + +def channel_combination(src1_path, src2_path, src3_path, color_space='RGB'): + # Load and resize images to match + channel_1 = load_image_as_gray(src1_path) + channel_2 = load_image_as_gray(src2_path) + channel_3 = load_image_as_gray(src3_path) + + # Automatically resize images to match the size of the first image + size = channel_1.size + channel_2 = resize_to_match(channel_2, size) + channel_3 = resize_to_match(channel_3, size) + + # Combine the channels + combined_image = combine_channels([channel_1, channel_2, channel_3], color_space=color_space) + + return combined_image + +# 示例用法 +if __name__ == "__main__": + # 指定通道对应的图像路径 + src1_path = 'channel_R.png' # 对应于RGB空间的R通道或Lab空间的L通道 + src2_path = 'channel_G.png' # 对应于RGB空间的G通道或Lab空间的a*通道 + src3_path = 'channel_B.png' # 对应于RGB空间的B通道或Lab空间的b*通道 + + # 执行通道组合,保存结果 + combined_rgb = channel_combination(src1_path, src2_path, src3_path, color_space='RGB') + combined_rgb.save('combined_rgb.png') + + combined_lab = channel_combination(src1_path, src2_path, src3_path, color_space='LAB') + combined_lab.save('combined_lab.png') + + combined_hsv = channel_combination(src1_path, src2_path, src3_path, color_space='HSV') + combined_hsv.save('combined_hsv.png') + + combined_hsi = channel_combination(src1_path, src2_path, src3_path, color_space='HSI') + combined_hsi.save('combined_hsi.png') diff --git a/pysrc/image/channel/extraction.py b/pysrc/image/channel/extraction.py new file mode 100644 index 00000000..fd85d782 --- /dev/null +++ b/pysrc/image/channel/extraction.py @@ -0,0 +1,138 @@ +import cv2 +import numpy as np +import os +from matplotlib import pyplot as plt + + +def extract_channels(image, color_space='RGB'): + channels = {} + + if color_space == 'RGB': + channels['R'], channels['G'], channels['B'] = cv2.split(image) + + elif color_space == 'XYZ': + xyz_image = cv2.cvtColor(image, cv2.COLOR_BGR2XYZ) + channels['X'], channels['Y'], channels['Z'] = cv2.split(xyz_image) + + elif color_space == 'Lab': + lab_image = cv2.cvtColor(image, cv2.COLOR_BGR2Lab) + channels['L*'], channels['a*'], channels['b*'] = cv2.split(lab_image) + + elif color_space == 'LCh': + lab_image = cv2.cvtColor(image, cv2.COLOR_BGR2Lab) + L, a, b = cv2.split(lab_image) + h, c = 
cv2.cartToPolar(a, b) + channels['L*'] = L + channels['c*'] = c + channels['h*'] = h + + elif color_space == 'HSV': + hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) + channels['H'], channels['S'], channels['V'] = cv2.split(hsv_image) + + elif color_space == 'HSI': + hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) + H, S, V = cv2.split(hsv_image) + I = V.copy() + channels['H'] = H + channels['Si'] = S + channels['I'] = I + + elif color_space == 'YUV': + yuv_image = cv2.cvtColor(image, cv2.COLOR_BGR2YUV) + channels['Y'], channels['U'], channels['V'] = cv2.split(yuv_image) + + elif color_space == 'YCbCr': + ycbcr_image = cv2.cvtColor(image, cv2.COLOR_BGR2YCrCb) + channels['Y'], channels['Cb'], channels['Cr'] = cv2.split(ycbcr_image) + + elif color_space == 'HSL': + hsl_image = cv2.cvtColor(image, cv2.COLOR_BGR2HLS) + channels['H'], channels['S'], channels['L'] = cv2.split(hsl_image) + + elif color_space == 'CMYK': + # Trick: Convert to CMY via XYZ + cmyk_image = cv2.cvtColor(image, cv2.COLOR_BGR2XYZ) + C, M, Y = cv2.split(255 - cmyk_image) + K = np.minimum(C, np.minimum(M, Y)) + channels['C'] = C + channels['M'] = M + channels['Y'] = Y + channels['K'] = K + + return channels + + +def show_histogram(channel_data, title='Channel Histogram'): + plt.figure() + plt.title(title) + plt.xlabel('Pixel Value') + plt.ylabel('Frequency') + plt.hist(channel_data.ravel(), bins=256, range=[0, 256]) + plt.show() + + +def merge_channels(channels): + merged_image = None + channel_list = list(channels.values()) + if len(channel_list) >= 3: + merged_image = cv2.merge(channel_list[:3]) + elif len(channel_list) == 2: + merged_image = cv2.merge( + [channel_list[0], channel_list[1], np.zeros_like(channel_list[0])]) + elif len(channel_list) == 1: + merged_image = channel_list[0] + return merged_image + + +def process_directory(input_dir, output_dir, color_space='RGB'): + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + for filename in os.listdir(input_dir): + if filename.endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tiff')): + image_path = os.path.join(input_dir, filename) + image = cv2.imread(image_path) + base_name = os.path.splitext(filename)[0] + extracted_channels = extract_channels(image, color_space) + for channel_name, channel_data in extracted_channels.items(): + save_path = os.path.join( + output_dir, f"{base_name}_{channel_name}.png") + cv2.imwrite(save_path, channel_data) + show_histogram( + channel_data, title=f"{base_name} - {channel_name}") + print(f"Saved {save_path}") + + +def save_channels(channels, base_name='output'): + for channel_name, channel_data in channels.items(): + filename = f"{base_name}_{channel_name}.png" + cv2.imwrite(filename, channel_data) + print(f"Saved {filename}") + + +def display_image(title, image): + cv2.imshow(title, image) + cv2.waitKey(0) + cv2.destroyAllWindows() + + +# Example usage: +image = cv2.imread('input_image.png') +extracted_channels = extract_channels(image, color_space='Lab') + +# Show histograms +for name, channel in extracted_channels.items(): + show_histogram(channel, title=f"{name} Histogram") + +# Save channels +save_channels(extracted_channels, base_name='output_image') + +# Merge channels +merged_image = merge_channels(extracted_channels) +if merged_image is not None: + display_image('Merged Image', merged_image) + cv2.imwrite('merged_image.png', merged_image) + +# Process directory +process_directory('input_images', 'output_images', color_space='HSV') diff --git a/pysrc/image/color_calibration/__init__.py 
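A note on the LCh branch in extraction.py: cv2.cartToPolar returns magnitude first and angle second, and it requires floating-point input, so unpacking the result as (h, c) on raw 8-bit Lab planes inverts the channels and will not run as written. The sketch below is a self-contained variant under those OpenCV conventions; treating the 8-bit a*/b* planes as offset by 128 is an assumption about the intended input depth.

import cv2
import numpy as np

def extract_lch(image_bgr: np.ndarray) -> dict:
    lab = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2Lab)
    L, a, b = cv2.split(lab)
    a = a.astype(np.float32) - 128.0   # recenter a*, b* around zero for 8-bit Lab
    b = b.astype(np.float32) - 128.0
    C, h = cv2.cartToPolar(a, b, angleInDegrees=True)  # magnitude (C*) first, angle (h) second
    return {"L*": L, "C*": C, "h": h}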
b/pysrc/image/color_calibration/__init__.py new file mode 100644 index 00000000..0a89bce9 --- /dev/null +++ b/pysrc/image/color_calibration/__init__.py @@ -0,0 +1,12 @@ +from .processing import ImageProcessing +from .calibration import ColorCalibration +from .utils import select_roi +from .io import read_image, save_image + +__all__ = [ + "ImageProcessing", + "ColorCalibration", + "select_roi", + "read_image", + "save_image", +] diff --git a/pysrc/image/color_calibration/calibration.py b/pysrc/image/color_calibration/calibration.py new file mode 100644 index 00000000..6ca4cb2c --- /dev/null +++ b/pysrc/image/color_calibration/calibration.py @@ -0,0 +1,65 @@ +import cv2 +import numpy as np +from typing import Tuple +from dataclasses import dataclass + +@dataclass +class ColorCalibration: + """Class to handle color calibration tasks for astronomical images.""" + image: np.ndarray + + def calculate_color_factors(self, white_reference: np.ndarray) -> np.ndarray: + """ + Calculate the color calibration factors based on a white reference image. + + Parameters: + white_reference (np.ndarray): The white reference region. + + Returns: + np.ndarray: The calibration factors for the RGB channels. + """ + mean_values = np.mean(white_reference, axis=(0, 1)) + factors = 1.0 / mean_values + return factors + + def apply_color_calibration(self, factors: np.ndarray) -> np.ndarray: + """ + Apply color calibration to the image using provided factors. + + Parameters: + factors (np.ndarray): The RGB calibration factors. + + Returns: + np.ndarray: Color calibrated image. + """ + calibrated_image = np.zeros_like(self.image) + for i in range(3): + calibrated_image[:, :, i] = self.image[:, :, i] * factors[i] + return calibrated_image + + def automatic_white_balance(self) -> np.ndarray: + """ + Perform automatic white balance using the Gray World algorithm. + + Returns: + np.ndarray: White balanced image. + """ + mean_values = np.mean(self.image, axis=(0, 1)) + gray_value = np.mean(mean_values) + factors = gray_value / mean_values + return self.apply_color_calibration(factors) + + def match_histograms(self, reference_image: np.ndarray) -> np.ndarray: + """ + Match the color histogram of the image to a reference image. + + Parameters: + reference_image (np.ndarray): The reference image whose histogram is to be matched. + + Returns: + np.ndarray: Histogram matched image. + """ + matched_image = np.zeros_like(self.image) + for i in range(3): + matched_image[:, :, i] = exposure.match_histograms(self.image[:, :, i], reference_image[:, :, i], multichannel=False) + return matched_image diff --git a/pysrc/image/color_calibration/io.py b/pysrc/image/color_calibration/io.py new file mode 100644 index 00000000..8127f13c --- /dev/null +++ b/pysrc/image/color_calibration/io.py @@ -0,0 +1,28 @@ +import cv2 +import numpy as np + +def read_image(image_path: str) -> np.ndarray: + """ + Read an image from the given file path. + + Parameters: + image_path (str): Path to the image file. + + Returns: + np.ndarray: The image in a floating point format normalized to [0, 1]. + """ + image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 65535.0 + return image + +def save_image(image: np.ndarray, output_path: str) -> None: + """ + Save the image to the given file path. + + Parameters: + image (np.ndarray): The image to save. + output_path (str): Path to the output file. 
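match_histograms in calibration.py comes from skimage.exposure, which that file does not import, and recent scikit-image releases replaced the multichannel flag with channel_axis. A minimal self-contained sketch of the same per-channel matching, assuming scikit-image >= 0.19:

import numpy as np
from skimage import exposure

def match_rgb_histograms(image: np.ndarray, reference: np.ndarray) -> np.ndarray:
    # channel_axis=-1 matches each color channel independently (the old multichannel=True).
    return exposure.match_histograms(image, reference, channel_axis=-1)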
+ + Returns: + None + """ + cv2.imwrite(output_path, (image * 65535).astype(np.uint16)) diff --git a/pysrc/image/color_calibration/processing.py b/pysrc/image/color_calibration/processing.py new file mode 100644 index 00000000..27ca9243 --- /dev/null +++ b/pysrc/image/color_calibration/processing.py @@ -0,0 +1,64 @@ +import cv2 +import numpy as np +from skimage import exposure, filters, feature +from dataclasses import dataclass + +@dataclass +class ImageProcessing: + """Class to handle image processing tasks for astronomical images.""" + image: np.ndarray + + def apply_gamma_correction(self, gamma: float) -> np.ndarray: + """ + Apply gamma correction to the image to adjust the contrast. + + Parameters: + gamma (float): The gamma value to apply. < 1 makes the image darker, > 1 makes it brighter. + + Returns: + np.ndarray: Gamma-corrected image. + """ + inv_gamma = 1.0 / gamma + table = np.array([((i / 255.0) ** inv_gamma) * 255 for i in np.arange(0, 256)]).astype("float32") + corrected_image = cv2.LUT(self.image, table) + return corrected_image + + def histogram_equalization(self) -> np.ndarray: + """ + Apply histogram equalization to improve the contrast of the image. + + Returns: + np.ndarray: Histogram equalized image. + """ + equalized_image = np.zeros_like(self.image) + for i in range(3): + equalized_image[:, :, i] = exposure.equalize_hist(self.image[:, :, i]) + return equalized_image + + def denoise_image(self, h: float = 10.0) -> np.ndarray: + """ + Apply Non-Local Means Denoising to the image. + + Parameters: + h (float): Filter strength. Higher h removes noise more aggressively. + + Returns: + np.ndarray: Denoised image. + """ + denoised_image = cv2.fastNlMeansDenoisingColored(self.image, None, h, h, 7, 21) + return denoised_image + + def detect_stars(self, min_distance: int = 10, threshold_rel: float = 0.2) -> np.ndarray: + """ + Detect stars in the image using peak local max method. + + Parameters: + min_distance (int): Minimum number of pixels separating peaks in a region of 2 * min_distance + 1. + threshold_rel (float): Minimum intensity of peaks relative to the highest peak. + + Returns: + np.ndarray: Coordinates of the detected stars. + """ + gray_image = cv2.cvtColor((self.image * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY) + coordinates = feature.peak_local_max(gray_image, min_distance=min_distance, threshold_rel=threshold_rel) + return coordinates diff --git a/pysrc/image/color_calibration/utils.py b/pysrc/image/color_calibration/utils.py new file mode 100644 index 00000000..f10d9582 --- /dev/null +++ b/pysrc/image/color_calibration/utils.py @@ -0,0 +1,20 @@ +import cv2 +import numpy as np +from typing import Tuple + +def select_roi(image: np.ndarray, x: int, y: int, width: int, height: int) -> np.ndarray: + """ + Select a Region of Interest (ROI) from the image. + + Parameters: + image (np.ndarray): The input image. + x (int): X-coordinate of the top-left corner of the ROI. + y (int): Y-coordinate of the top-left corner of the ROI. + width (int): Width of the ROI. + height (int): Height of the ROI. + + Returns: + np.ndarray: The selected ROI. 
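apply_gamma_correction builds a lookup table for cv2.LUT, which expects an 8-bit source image, while read_image in this package returns float images normalized to [0, 1]. A hedged sketch of a gamma correction that works directly in the float domain and avoids the LUT:

import numpy as np

def gamma_correct_float(image: np.ndarray, gamma: float) -> np.ndarray:
    """Gamma-correct a float image in [0, 1]; gamma < 1 darkens, gamma > 1 brightens."""
    if gamma <= 0:
        raise ValueError("gamma must be positive")
    return np.clip(image, 0.0, 1.0) ** (1.0 / gamma)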
+ """ + roi = image[y:y+height, x:x+width] + return roi diff --git a/pysrc/image/debayer/__init__.py b/pysrc/image/debayer/__init__.py new file mode 100644 index 00000000..e21f4ef2 --- /dev/null +++ b/pysrc/image/debayer/__init__.py @@ -0,0 +1,2 @@ +from .debayer import Debayer +from .metrics import calculate_image_quality, evaluate_image_quality diff --git a/pysrc/image/debayer/debayer.py b/pysrc/image/debayer/debayer.py new file mode 100644 index 00000000..bde2fcce --- /dev/null +++ b/pysrc/image/debayer/debayer.py @@ -0,0 +1,247 @@ +import numpy as np +import cv2 +from concurrent.futures import ThreadPoolExecutor +from typing import Optional, Tuple +import time + + +class Debayer: + def __init__(self, method: str = 'bilinear', pattern: Optional[str] = None): + """ + Initialize Debayer object with method and optional Bayer pattern. + + :param method: Debayering method ('superpixel', 'bilinear', 'vng', 'ahd', 'laplacian') + :param pattern: Bayer pattern ('BGGR', 'RGGB', 'GBRG', 'GRBG'), None for auto-detection + """ + self.method = method + self.pattern = pattern + + def detect_bayer_pattern(self, image: np.ndarray) -> str: + """ + Automatically detect Bayer pattern from the CFA image. + """ + height, width = image.shape + + # 初始化统计信息 + patterns = {'BGGR': 0, 'RGGB': 0, 'GBRG': 0, 'GRBG': 0} + + # 检测边缘并增强检测精度 + edges = cv2.Canny(image, 50, 150) + + for i in range(0, height - 1, 2): + for j in range(0, width - 1, 2): + # BGGR + patterns['BGGR'] += (image[i, j] + image[i+1, j+1]) + \ + (edges[i, j] + edges[i+1, j+1]) + # RGGB + patterns['RGGB'] += (image[i+1, j] + image[i, j+1]) + \ + (edges[i+1, j] + edges[i, j+1]) + # GBRG + patterns['GBRG'] += (image[i, j+1] + image[i+1, j]) + \ + (edges[i, j+1] + edges[i+1, j]) + # GRBG + patterns['GRBG'] += (image[i, j] + image[i+1, j+1]) + \ + (edges[i, j] + edges[i+1, j+1]) + + # 分析颜色通道的分布,进一步强化检测 + color_sums = {'BGGR': 0, 'RGGB': 0, 'GBRG': 0, 'GRBG': 0} + + # 遍历整个图像并分析色彩通道的强度分布 + for i in range(0, height - 1, 2): + for j in range(0, width - 1, 2): + block = image[i:i+2, j:j+2] + color_sums['BGGR'] += block[0, 0] + block[1, 1] # 蓝-绿对 + color_sums['RGGB'] += block[1, 0] + block[0, 1] # 红-绿对 + color_sums['GBRG'] += block[0, 1] + block[1, 0] # 绿-蓝对 + color_sums['GRBG'] += block[0, 0] + block[1, 1] # 绿-红对 + + # 综合所有信息,选择最有可能的Bayer模式 + for pattern in patterns.keys(): + patterns[pattern] += color_sums[pattern] + + detected_pattern = max(patterns, key=patterns.get) + + return detected_pattern + + def debayer_image(self, cfa_image: np.ndarray) -> np.ndarray: + """ + Perform the debayering process using the specified method. 
+ """ + if self.pattern is None: + self.pattern = self.detect_bayer_pattern(cfa_image) + + cfa_image = self.extend_image_edges(cfa_image, pad_width=2) + print(f"Using Bayer pattern: {self.pattern}") + + if self.method == 'superpixel': + return self.debayer_superpixel(cfa_image) + elif self.method == 'bilinear': + return self.debayer_bilinear(cfa_image) + elif self.method == 'vng': + return self.debayer_vng(cfa_image) + elif self.method == 'ahd': + return self.parallel_debayer_ahd(cfa_image) + elif self.method == 'laplacian': + return self.debayer_laplacian_harmonization(cfa_image) + else: + raise ValueError(f"Unknown debayer method: {self.method}") + + def debayer_superpixel(self, cfa_image: np.ndarray) -> np.ndarray: + red = cfa_image[0::2, 0::2] + green = (cfa_image[0::2, 1::2] + cfa_image[1::2, 0::2]) / 2 + blue = cfa_image[1::2, 1::2] + + rgb_image = np.stack((red, green, blue), axis=-1) + return rgb_image + + def debayer_bilinear(self, cfa_image, pattern='BGGR'): + """ + 使用双线性插值法进行去拜耳处理。 + + :param cfa_image: 输入的CFA图像 + :param pattern: Bayer模式 ('BGGR', 'RGGB', 'GBRG', 'GRBG') + :return: 去拜耳处理后的RGB图像 + """ + if pattern == 'BGGR': + return cv2.cvtColor(cfa_image, cv2.COLOR_BayerBG2BGR) + elif pattern == 'RGGB': + return cv2.cvtColor(cfa_image, cv2.COLOR_BayerRG2BGR) + elif pattern == 'GBRG': + return cv2.cvtColor(cfa_image, cv2.COLOR_BayerGB2BGR) + elif pattern == 'GRBG': + return cv2.cvtColor(cfa_image, cv2.COLOR_BayerGR2BGR) + else: + raise ValueError(f"Unsupported Bayer pattern: {pattern}") + + def debayer_vng(self, cfa_image: np.ndarray) -> np.ndarray: + code = cv2.COLOR_BayerBG2BGR_VNG if pattern == 'BGGR' else cv2.COLOR_BayerRG2BGR_VNG + rgb_image = cv2.cvtColor(cfa_image, code) + return rgb_image + + def parallel_debayer_ahd(self, cfa_image: np.ndarray, num_threads: int = 4) -> np.ndarray: + height, width = cfa_image.shape + chunk_size = height // num_threads + + # 用于存储每个线程处理的部分图像 + results = [None] * num_threads + + def process_chunk(start_row, end_row, index): + chunk = cfa_image[start_row:end_row, :] + gradient_x, gradient_y = calculate_gradients(chunk) + green_channel = interpolate_green_channel( + chunk, gradient_x, gradient_y) + red_channel, blue_channel = interpolate_red_blue_channel( + chunk, green_channel, pattern) + rgb_chunk = np.stack( + (red_channel, green_channel, blue_channel), axis=-1) + results[index] = np.clip(rgb_chunk, 0, 255).astype(np.uint8) + + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [] + for i in range(num_threads): + start_row = i * chunk_size + end_row = (i + 1) * chunk_size if i < num_threads - \ + 1 else height + futures.append(executor.submit( + process_chunk, start_row, end_row, i)) + + concurrent.futures.wait(futures) + + # 合并处理后的块 + rgb_image = np.vstack(results) + return rgb_image + + def calculate_laplacian(self, image): + """ + 计算图像的拉普拉斯算子,用于增强边缘检测。 + + :param image: 输入的图像(灰度图像) + :return: 拉普拉斯图像 + """ + laplacian = cv2.Laplacian(image, cv2.CV_64F) + return laplacian + + def harmonize_edges(self, original, interpolated, laplacian): + """ + 使用拉普拉斯算子结果来调整插值后的图像,增强边缘细节。 + + :param original: 原始CFA图像 + :param interpolated: 双线性插值后的图像 + :param laplacian: 计算的拉普拉斯图像 + :return: 经过拉普拉斯调和的图像 + """ + return np.clip(interpolated + 0.2 * laplacian, 0, 255).astype(np.uint8) + + def debayer_laplacian_harmonization(self, cfa_image, pattern='BGGR'): + """ + 使用简化的拉普拉斯调和方法进行去拜耳处理,以增强边缘处理。 + + :param cfa_image: 输入的CFA图像 + :param pattern: Bayer模式 ('BGGR', 'RGGB', 'GBRG', 'GRBG') + :return: 去拜耳处理后的RGB图像 + """ + # 
Step 1: 双线性插值 + interpolated_image = self.debayer_bilinear(cfa_image, pattern) + + # Step 2: 计算每个通道的拉普拉斯图像 + laplacian_b = self.calculate_laplacian(interpolated_image[:, :, 0]) + laplacian_g = self.calculate_laplacian(interpolated_image[:, :, 1]) + laplacian_r = self.calculate_laplacian(interpolated_image[:, :, 2]) + + # Step 3: 使用拉普拉斯结果调和插值后的图像 + harmonized_b = self.harmonize_edges(cfa_image, interpolated_image[:, :, 0], laplacian_b) + harmonized_g = self.harmonize_edges(cfa_image, interpolated_image[:, :, 1], laplacian_g) + harmonized_r = self.harmonize_edges(cfa_image, interpolated_image[:, :, 2], laplacian_r) + + # Step 4: 合并调和后的通道 + harmonized_image = np.stack((harmonized_b, harmonized_g, harmonized_r), axis=-1) + + return harmonized_image + + def extend_image_edges(self, image: np.ndarray, pad_width: int) -> np.ndarray: + """ + Extend image edges using mirror padding to handle boundary issues during interpolation. + """ + return np.pad(image, pad_width, mode='reflect') + + def visualize_intermediate_steps(self, cfa_image: np.ndarray): + """ + Visualize intermediate steps in the debayering process. + """ + gradient_x, gradient_y = calculate_gradients(cfa_image) + green_channel = interpolate_green_channel( + cfa_image, gradient_x, gradient_y) + red_channel, blue_channel = interpolate_red_blue_channel( + cfa_image, green_channel, pattern) + + # 显示梯度和各通道图像 + cv2.imshow("Gradient X", gradient_x) + cv2.imshow("Gradient Y", gradient_y) + cv2.imshow("Green Channel", green_channel) + cv2.imshow("Red Channel", red_channel) + cv2.imshow("Blue Channel", blue_channel) + cv2.waitKey(0) + cv2.destroyAllWindows() + + def process_batch(self, image_paths: list, num_threads: int = 4): + """ + Batch processing for multiple CFA images using multithreading. + """ + start_time = time.time() + with ThreadPoolExecutor(max_workers=num_threads) as executor: + results = executor.map( + lambda path: self.process_single_image(path), image_paths) + + elapsed_time = time.time() - start_time + print(f"Batch processing completed in {elapsed_time:.2f} seconds.") + + def process_single_image(self, path: str) -> np.ndarray: + """ + Helper function for processing a single image. + """ + cfa_image = cv2.imread(path, cv2.IMREAD_GRAYSCALE) + rgb_image = self.debayer_image(cfa_image) + output_path = path.replace('.png', f'_{self.method}.png') + cv2.imwrite(output_path, rgb_image) + print(f"Processed {path} -> {output_path}") + return rgb_image diff --git a/pysrc/image/debayer/metrics.py b/pysrc/image/debayer/metrics.py new file mode 100644 index 00000000..0f27e42c --- /dev/null +++ b/pysrc/image/debayer/metrics.py @@ -0,0 +1,27 @@ +import cv2 +import numpy as np +from skimage.metrics import peak_signal_noise_ratio as psnr, structural_similarity as ssim + +def evaluate_image_quality(rgb_image: np.ndarray) -> dict: + """ + Evaluate the quality of the debayered image. + """ + gray_image = cv2.cvtColor(rgb_image, cv2.COLOR_BGR2GRAY) + laplacian_var = cv2.Laplacian(gray_image, cv2.CV_64F).var() + + mean_colors = cv2.mean(rgb_image)[:3] + + return { + "sharpness": laplacian_var, + "mean_red": mean_colors[2], + "mean_green": mean_colors[1], + "mean_blue": mean_colors[0] + } + +def calculate_image_quality(original: np.ndarray, processed: np.ndarray) -> Tuple[float, float]: + """ + Calculate PSNR and SSIM between the original and processed images. 
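process_batch never iterates the Executor.map result, and concurrent.futures only raises a worker's exception when its result is retrieved, so a failing process_single_image call is silently dropped. A hedged sketch of a batch runner that surfaces per-image failures using submit/as_completed; the reporting style is an assumption, not part of the class.

from concurrent.futures import ThreadPoolExecutor, as_completed

def run_batch(process_one, image_paths, num_threads: int = 4):
    """Run process_one over image_paths in a thread pool, reporting per-image failures."""
    results = {}
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        futures = {executor.submit(process_one, p): p for p in image_paths}
        for fut in as_completed(futures):
            path = futures[fut]
            try:
                results[path] = fut.result()
            except Exception as exc:          # surface, rather than hide, the failure
                print(f"Failed to process {path}: {exc}")
    return results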
+ """ + psnr_value = psnr(original, processed, data_range=processed.max() - processed.min()) + ssim_value = ssim(original, processed, multichannel=True) + return psnr_value, ssim_value diff --git a/pysrc/image/debayer/utils.py b/pysrc/image/debayer/utils.py new file mode 100644 index 00000000..4ef5e83a --- /dev/null +++ b/pysrc/image/debayer/utils.py @@ -0,0 +1,117 @@ +from typing import Tuple +import numpy as np + +def calculate_gradients(cfa_image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """ + Calculate the gradients of a CFA image. + """ + gradient_x = np.abs(np.diff(cfa_image, axis=1)) + gradient_y = np.abs(np.diff(cfa_image, axis=0)) + + gradient_x = np.pad(gradient_x, ((0, 0), (0, 1)), 'constant') + gradient_y = np.pad(gradient_y, ((0, 1), (0, 0)), 'constant') + + return gradient_x, gradient_y + +def interpolate_green_channel(cfa_image: np.ndarray, gradient_x: np.ndarray, gradient_y: np.ndarray) -> np.ndarray: + """ + Interpolate the green channel of the CFA image. + """ + height, width = cfa_image.shape + green_channel = np.zeros((height, width)) + + for i in range(1, height - 1): + for j in range(1, width - 1): + if (i % 2 == 0 and j % 2 == 1) or (i % 2 == 1 and j % 2 == 0): + # 当前点是绿色通道点,直接赋值 + green_channel[i, j] = cfa_image[i, j] + else: + # 当前点不是绿色通道点,需要插值 + if gradient_x[i, j] < gradient_y[i, j]: + green_channel[i, j] = 0.5 * \ + (cfa_image[i, j-1] + cfa_image[i, j+1]) + else: + green_channel[i, j] = 0.5 * \ + (cfa_image[i-1, j] + cfa_image[i+1, j]) + + return green_channel + +def interpolate_red_blue_channel(cfa_image: np.ndarray, green_channel: np.ndarray, pattern: str) -> Tuple[np.ndarray, np.ndarray]: + """ + Interpolate the red and blue channels of the CFA image. + """ + height, width = cfa_image.shape + red_channel = np.zeros((height, width)) + blue_channel = np.zeros((height, width)) + + if pattern == 'BGGR': + for i in range(0, height, 2): + for j in range(0, width, 2): + blue_channel[i, j] = cfa_image[i, j] + red_channel[i+1, j+1] = cfa_image[i+1, j+1] + + green_r = 0.5 * (green_channel[i+1, j] + green_channel[i, j+1]) + green_b = 0.5 * (green_channel[i, j] + green_channel[i+1, j+1]) + + blue_channel[i+1, j] = cfa_image[i+1, j] - \ + green_b + green_channel[i+1, j] + blue_channel[i, j+1] = cfa_image[i, j+1] - \ + green_b + green_channel[i, j+1] + red_channel[i, j] = cfa_image[i, j] - \ + green_r + green_channel[i, j] + red_channel[i+1, j+1] = cfa_image[i+1, j+1] - \ + green_r + green_channel[i+1, j+1] + + elif pattern == 'RGGB': + for i in range(0, height, 2): + for j in range(0, width, 2): + red_channel[i, j] = cfa_image[i, j] + blue_channel[i+1, j+1] = cfa_image[i+1, j+1] + + green_r = 0.5 * (green_channel[i, j+1] + green_channel[i+1, j]) + green_b = 0.5 * (green_channel[i+1, j] + green_channel[i, j+1]) + + red_channel[i+1, j] = cfa_image[i+1, j] - \ + green_r + green_channel[i+1, j] + red_channel[i, j+1] = cfa_image[i, j+1] - \ + green_r + green_channel[i, j+1] + blue_channel[i, j] = cfa_image[i, j] - \ + green_b + green_channel[i, j] + blue_channel[i+1, j+1] = cfa_image[i+1, j+1] - \ + green_b + green_channel[i+1, j+1] + + elif pattern == 'GBRG': + for i in range(0, height, 2): + for j in range(0, width, 2): + green_channel[i, j+1] = cfa_image[i, j+1] + blue_channel[i+1, j] = cfa_image[i+1, j] + + green_r = 0.5 * (green_channel[i, j] + green_channel[i+1, j+1]) + green_b = 0.5 * (green_channel[i+1, j] + green_channel[i, j+1]) + + red_channel[i, j] = cfa_image[i, j] - \ + green_r + green_channel[i, j] + red_channel[i+1, j+1] = cfa_image[i+1, j+1] - \ + 
green_r + green_channel[i+1, j+1] + blue_channel[i, j] = cfa_image[i, j] - \ + green_b + green_channel[i, j] + blue_channel[i+1, j+1] = cfa_image[i+1, j+1] - \ + green_b + green_channel[i+1, j+1] + + elif pattern == 'GRBG': + for i in range(0, height, 2): + for j in range(0, width, 2): + green_channel[i, j] = cfa_image[i, j] + red_channel[i+1, j] = cfa_image[i+1, j] + + green_r = 0.5 * (green_channel[i, j] + green_channel[i+1, j+1]) + green_b = 0.5 * (green_channel[i+1, j] + green_channel[i, j+1]) + + red_channel[i, j] = cfa_image[i, j] - \ + green_r + green_channel[i, j] + red_channel[i+1, j+1] = cfa_image[i+1, j+1] - \ + green_r + green_channel[i+1, j+1] + blue_channel[i+1, j] = cfa_image[i+1, j] - \ + green_b + green_channel[i+1, j] + blue_channel[i, j+1] = cfa_image[i, j+1] - \ + green_b + green_channel[i, j+1] diff --git a/pysrc/image/defect_map/__init__.py b/pysrc/image/defect_map/__init__.py new file mode 100644 index 00000000..2e190333 --- /dev/null +++ b/pysrc/image/defect_map/__init__.py @@ -0,0 +1 @@ +from .defect_correction import defect_map_enhanced, parallel_defect_map diff --git a/pysrc/image/defect_map/defect_correction.py b/pysrc/image/defect_map/defect_correction.py new file mode 100644 index 00000000..e138f708 --- /dev/null +++ b/pysrc/image/defect_map/defect_correction.py @@ -0,0 +1,148 @@ +import numpy as np +from scipy.ndimage import generic_filter, median_filter, minimum_filter, maximum_filter, gaussian_filter +from skimage.filters import sobel +from skimage import img_as_float +from .interpolation import interpolate_defects +import multiprocessing +from typing import Optional + +def defect_map_enhanced(image: np.ndarray, defect_map: np.ndarray, operation: str = 'mean', structure: str = 'square', + radius: int = 1, is_cfa: bool = False, protect_edges: bool = False, + adaptive_structure: bool = False) -> np.ndarray: + """ + Enhanced Defect Map function to repair defective pixels in an image. + + Parameters: + - image: np.ndarray, Input image to be repaired. + - defect_map: np.ndarray, Defect map where defective pixels are marked as 0 (black). + - operation: str, Operation to perform ('mean', 'gaussian', 'minimum', 'maximum', 'median', 'bilinear', 'bicubic'). + - structure: str, Neighborhood structure ('square', 'circular', 'horizontal', 'vertical'). + - radius: int, Radius or size of the neighborhood. + - is_cfa: bool, If True, process the image as a CFA (Color Filter Array) image. + - protect_edges: bool, Whether to protect the edges of the image. + - adaptive_structure: bool, Whether to adaptively adjust the neighborhood based on defect density. + + Returns: + - corrected_image: np.ndarray, The repaired image. 
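defect_map_enhanced expects a map in which defective pixels are marked as 0. One common way to build such a map, shown here as an illustrative sketch and not part of this module, is to flag statistical outliers (hot pixels) in a master dark frame; the robust threshold k is an assumption.

import numpy as np

def defect_map_from_dark(dark: np.ndarray, k: float = 5.0) -> np.ndarray:
    """Flag pixels whose robust z-score in a dark frame exceeds k; 0 marks a defect."""
    med = np.median(dark)
    mad = np.median(np.abs(dark - med)) + 1e-12
    hot = np.abs(dark - med) > k * 1.4826 * mad
    defect_map = np.ones_like(dark, dtype=np.uint8)
    defect_map[hot] = 0
    return defect_map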
+ """ + if structure == 'square': + footprint = np.ones((2 * radius + 1, 2 * radius + 1)) + elif structure == 'circular': + y, x = np.ogrid[-radius: radius + 1, -radius: radius + 1] + footprint = x**2 + y**2 <= radius**2 + elif structure == 'horizontal': + footprint = np.zeros((1, 2 * radius + 1)) + footprint[0, :] = 1 + elif structure == 'vertical': + footprint = np.zeros((2 * radius + 1, 1)) + footprint[:, 0] = 1 + else: + raise ValueError("Invalid structure type.") + + mask = defect_map == 0 + + if protect_edges: + edges = sobel(img_as_float(image)) + mask = np.logical_and(mask, edges < np.mean(edges)) + + corrected_image = np.copy(image) + + if is_cfa: + for c in range(3): # Assuming an RGB image + corrected_image[:, :, c] = correct_channel( + image[:, :, c], mask[:, :, c], operation, footprint, adaptive_structure) + else: + corrected_image = correct_channel( + image, mask, operation, footprint, adaptive_structure) + + return corrected_image + +def correct_channel(channel: np.ndarray, mask: np.ndarray, operation: str, footprint: np.ndarray, + adaptive_structure: bool) -> np.ndarray: + """ + Helper function to repair a single channel of an image. + + Parameters: + - channel: np.ndarray, Image channel to be repaired. + - mask: np.ndarray, Defect mask indicating defective pixels. + - operation: str, Operation to perform. + - footprint: np.ndarray, Neighborhood structure. + - adaptive_structure: bool, Whether to adaptively adjust the neighborhood size. + + Returns: + - channel: np.ndarray, The repaired image channel. + """ + if adaptive_structure: + density = np.sum(mask) / mask.size + radius = int(3 / density) if density > 0 else 1 + footprint = np.ones((2 * radius + 1, 2 * radius + 1)) + + if operation == 'mean': + channel_corrected = generic_filter( + channel, np.mean, footprint=footprint, mode='constant', cval=np.nan) + elif operation == 'gaussian': + channel_corrected = gaussian_filter(channel, sigma=1) + elif operation == 'minimum': + channel_corrected = minimum_filter(channel, footprint=footprint) + elif operation == 'maximum': + channel_corrected = maximum_filter(channel, footprint=footprint) + elif operation == 'median': + channel_corrected = median_filter(channel, footprint=footprint) + elif operation == 'bilinear': + channel_corrected = interpolate_defects(channel, mask, method='linear') + elif operation == 'bicubic': + channel_corrected = interpolate_defects(channel, mask, method='cubic') + else: + raise ValueError("Invalid operation type.") + + channel[mask] = channel_corrected[mask] + + return channel + +def parallel_defect_map(image: np.ndarray, defect_map: np.ndarray, **kwargs) -> np.ndarray: + """ + Parallel processing for defect map repair. + + Parameters: + - image: np.ndarray, Input image. + - defect_map: np.ndarray, Defect map indicating defective pixels. + - kwargs: Additional keyword arguments for defect_map_enhanced. + + Returns: + - corrected_image: np.ndarray, The repaired image. 
+ """ + if image.ndim == 2: + return defect_map_enhanced(image, defect_map, **kwargs) + + pool = multiprocessing.Pool() + channels = [image[:, :, i] for i in range(image.shape[2])] + + task_args = [(ch, defect_map, kwargs['operation'], kwargs['structure'], kwargs['radius'], + kwargs['is_cfa'], kwargs['protect_edges'], kwargs['adaptive_structure']) for ch in channels] + + results = pool.starmap(defect_map_enhanced_single_channel, task_args) + pool.close() + pool.join() + + return np.stack(results, axis=-1) + +def defect_map_enhanced_single_channel(channel: np.ndarray, defect_map: np.ndarray, operation: str, + structure: str, radius: int, is_cfa: bool, protect_edges: bool, + adaptive_structure: bool) -> np.ndarray: + """ + Single-channel version of defect_map_enhanced for multiprocessing. + + Parameters: + - channel: np.ndarray, Single image channel. + - defect_map: np.ndarray, Defect map indicating defective pixels. + - operation: str, Operation to perform. + - structure: str, Neighborhood structure. + - radius: int, Radius or size of the neighborhood. + - is_cfa: bool, If True, process as a CFA image. + - protect_edges: bool, Whether to protect the edges of the image. + - adaptive_structure: bool, Whether to adaptively adjust the neighborhood size. + + Returns: + - channel: np.ndarray, The repaired image channel. + """ + return defect_map_enhanced(channel, defect_map, operation, structure, radius, is_cfa, protect_edges, adaptive_structure) diff --git a/pysrc/image/defect_map/interpolation.py b/pysrc/image/defect_map/interpolation.py new file mode 100644 index 00000000..1838ea80 --- /dev/null +++ b/pysrc/image/defect_map/interpolation.py @@ -0,0 +1,19 @@ +import numpy as np +from scipy.interpolate import griddata + +def interpolate_defects(image: np.ndarray, mask: np.ndarray, method: str = 'linear') -> np.ndarray: + """ + Interpolate defective pixels in an image. + + Parameters: + - image: np.ndarray, Image data. + - mask: np.ndarray, Defect mask indicating defective pixels. + - method: str, Interpolation method ('linear', 'cubic'). + + Returns: + - interpolated_image: np.ndarray, Image with interpolated defective pixels. + """ + x, y = np.indices(image.shape) + points = np.column_stack((x[~mask], y[~mask])) + values = image[~mask] + return griddata(points, values, (x, y), method=method) diff --git a/pysrc/image/defect_map/utils.py b/pysrc/image/defect_map/utils.py new file mode 100644 index 00000000..783c0933 --- /dev/null +++ b/pysrc/image/defect_map/utils.py @@ -0,0 +1,24 @@ +import cv2 +import numpy as np + +def save_image(filepath: str, image: np.ndarray) -> None: + """ + Save an image to a specified file path. + + Parameters: + - filepath: str, The file path where the image should be saved. + - image: np.ndarray, The image data to be saved. + """ + cv2.imwrite(filepath, image) + +def display_image(window_name: str, image: np.ndarray) -> None: + """ + Display an image in a new window. + + Parameters: + - window_name: str, The name of the window where the image will be displayed. + - image: np.ndarray, The image data to be displayed. 
+ """ + cv2.imshow(window_name, image) + cv2.waitKey(0) + cv2.destroyAllWindows() diff --git a/pysrc/image/fluxcalibration/__init__.py b/pysrc/image/fluxcalibration/__init__.py new file mode 100644 index 00000000..3eafd815 --- /dev/null +++ b/pysrc/image/fluxcalibration/__init__.py @@ -0,0 +1,6 @@ +from .core import CalibrationParams +from .calibration import flux_calibration, save_to_fits +from .utils import read_fits_header, instrument_response_correction, background_noise_correction + +__all__ = ['CalibrationParams', 'flux_calibration', 'save_to_fits', + 'read_fits_header', 'instrument_response_correction', 'background_noise_correction'] diff --git a/pysrc/image/fluxcalibration/calibration.py b/pysrc/image/fluxcalibration/calibration.py new file mode 100644 index 00000000..b3d3b02b --- /dev/null +++ b/pysrc/image/fluxcalibration/calibration.py @@ -0,0 +1,68 @@ +import numpy as np +from astropy.io import fits +from .core import CalibrationParams +from .utils import instrument_response_correction, background_noise_correction + +def compute_flx2dn(params: CalibrationParams) -> float: + """ + Compute the flux conversion factor (FLX2DN). + :param params: Calibration parameters. + :return: Flux conversion factor. + """ + c = 3.0e8 # Speed of light in m/s + h = 6.626e-34 # Planck's constant in J·s + wavelength_m = params.wavelength * 1e-9 # Convert nm to meters + + aperture_area = np.pi * (params.aperture**2 - params.obstruction**2) / 4 + FLX2DN = (params.exposure_time * aperture_area * params.filter_width * + params.transmissivity * params.gain * params.quantum_efficiency * + (1 - params.extinction) * (wavelength_m / (c * h))) + return FLX2DN + +def flux_calibration(image: np.ndarray, params: CalibrationParams, + response_function: Optional[np.ndarray] = None) -> np.ndarray: + """ + Perform flux calibration on an astronomical image. + :param image: Input image (numpy array). + :param params: Calibration parameters. + :param response_function: Optional instrument response function (numpy array). + :return: Flux-calibrated and rescaled image. + """ + if response_function is not None: + image = instrument_response_correction(image, response_function) + + FLX2DN = compute_flx2dn(params) + calibrated_image = image / FLX2DN + calibrated_image = background_noise_correction(calibrated_image) + + # Rescale the image to the range [0, 1] + min_val = np.min(calibrated_image) + max_val = np.max(calibrated_image) + FLXRANGE = max_val - min_val + FLXMIN = min_val + rescaled_image = (calibrated_image - min_val) / FLXRANGE + + return rescaled_image, FLXMIN, FLXRANGE, FLX2DN + +def save_to_fits(image: np.ndarray, filename: str, FLXMIN: float, FLXRANGE: float, + FLX2DN: float, header_info: dict = {}) -> None: + """ + Save the calibrated image to a FITS file with necessary header information. + :param image: Calibrated image (numpy array). + :param filename: Output FITS file name. + :param FLXMIN: Minimum flux value used for rescaling. + :param FLXRANGE: Range of flux values used for rescaling. + :param FLX2DN: Flux conversion factor. + :param header_info: Dictionary of additional header information to include. 
+ """ + hdu = fits.PrimaryHDU(image) + hdr = hdu.header + hdr['FLXMIN'] = FLXMIN + hdr['FLXRANGE'] = FLXRANGE + hdr['FLX2DN'] = FLX2DN + + # Add additional header information + for key, value in header_info.items(): + hdr[key] = value + + hdu.writeto(filename, overwrite=True) diff --git a/pysrc/image/fluxcalibration/core.py b/pysrc/image/fluxcalibration/core.py new file mode 100644 index 00000000..af7773d5 --- /dev/null +++ b/pysrc/image/fluxcalibration/core.py @@ -0,0 +1,15 @@ +from dataclasses import dataclass +import numpy as np +from typing import Optional + +@dataclass +class CalibrationParams: + wavelength: float # Effective filter wavelength in nm + transmissivity: float # Filter transmissivity in the range (0,1) + filter_width: float # Filter bandwidth in nm + aperture: float # Telescope aperture diameter in mm + obstruction: float # Telescope central obstruction diameter in mm + exposure_time: float # Exposure time in seconds + extinction: float # Atmospheric extinction in the range [0,1) + gain: float # Sensor gain in e-/ADU + quantum_efficiency: float # Sensor quantum efficiency in the range (0,1) diff --git a/pysrc/image/fluxcalibration/utils.py b/pysrc/image/fluxcalibration/utils.py new file mode 100644 index 00000000..b3c3739b --- /dev/null +++ b/pysrc/image/fluxcalibration/utils.py @@ -0,0 +1,41 @@ +import numpy as np +from astropy.stats import sigma_clipped_stats + +def instrument_response_correction(image: np.ndarray, response_function: np.ndarray) -> np.ndarray: + """ + Apply instrument response correction to the image. + :param image: Input image (numpy array). + :param response_function: Instrument response function (numpy array of the same shape as the image). + :return: Corrected image. + """ + return image / response_function + +def background_noise_correction(image: np.ndarray) -> np.ndarray: + """ + Estimate and subtract the background noise from the image. + :param image: Input image (numpy array). + :return: Image with background noise subtracted. + """ + _, median, _ = sigma_clipped_stats(image, sigma=3.0) # Estimate background + return image - median + +def read_fits_header(file_path: str) -> dict: + """ + Reads the FITS header and returns the necessary calibration parameters. + :param file_path: Path to the FITS file. + :return: Dictionary containing calibration parameters. + """ + with fits.open(file_path) as hdul: + header = hdul[0].header + params = { + 'wavelength': header.get('WAVELEN', 550), # nm + 'transmissivity': header.get('TRANSMIS', 0.8), + 'filter_width': header.get('FILTWDTH', 100), # nm + 'aperture': header.get('APERTURE', 200), # mm + 'obstruction': header.get('OBSTRUCT', 50), # mm + 'exposure_time': header.get('EXPTIME', 60), # seconds + 'extinction': header.get('EXTINCT', 0.1), + 'gain': header.get('GAIN', 1.5), # e-/ADU + 'quantum_efficiency': header.get('QUANTEFF', 0.9), + } + return params diff --git a/pysrc/image/image_io/__init__.py b/pysrc/image/image_io/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pysrc/image/image_io/io.py b/pysrc/image/image_io/io.py new file mode 100644 index 00000000..0b857551 --- /dev/null +++ b/pysrc/image/image_io/io.py @@ -0,0 +1,86 @@ +import cv2 +import numpy as np +from dataclasses import dataclass +from typing import Tuple + + +@dataclass +class ImageProcessor: + """ + A class that provides various image processing functionalities. + """ + + @staticmethod + def save_image(filepath: str, image: np.ndarray) -> None: + """ + Save an image to a specified file path. 
+ + Parameters: + - filepath: str, The file path where the image should be saved. + - image: np.ndarray, The image data to be saved. + """ + cv2.imwrite(filepath, image) + + @staticmethod + def display_image(window_name: str, image: np.ndarray) -> None: + """ + Display an image in a new window. + + Parameters: + - window_name: str, The name of the window where the image will be displayed. + - image: np.ndarray, The image data to be displayed. + """ + cv2.imshow(window_name, image) + cv2.waitKey(0) + cv2.destroyAllWindows() + + @staticmethod + def load_image(filepath: str, color_mode: int = cv2.IMREAD_COLOR) -> np.ndarray: + """ + Load an image from a specified file path. + + Parameters: + - filepath: str, The file path from where the image will be loaded. + - color_mode: int, The color mode in which to load the image (default: cv2.IMREAD_COLOR). + + Returns: + - np.ndarray, The loaded image data. + """ + return cv2.imread(filepath, color_mode) + + @staticmethod + def resize_image(image: np.ndarray, size: Tuple[int, int]) -> np.ndarray: + """ + Resize an image to a new size. + + Parameters: + - image: np.ndarray, The image data to be resized. + - size: Tuple[int, int], The new size as (width, height). + + Returns: + - np.ndarray, The resized image data. + """ + return cv2.resize(image, size) + + @staticmethod + def crop_image(image: np.ndarray, start: Tuple[int, int], size: Tuple[int, int]) -> np.ndarray: + """ + Crop a region from an image. + + Parameters: + - image: np.ndarray, The image data to be cropped. + - start: Tuple[int, int], The top-left corner of the crop region as (x, y). + - size: Tuple[int, int], The size of the crop region as (width, height). + + Returns: + - np.ndarray, The cropped image data. + """ + x, y = start + w, h = size + return image[y:y+h, x:x+w] + + +# Defining the public API of the module +__all__ = [ + 'ImageProcessor' +] diff --git a/pysrc/image/raw/__init__.py b/pysrc/image/raw/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pysrc/image/raw/raw.py b/pysrc/image/raw/raw.py new file mode 100644 index 00000000..eb159edc --- /dev/null +++ b/pysrc/image/raw/raw.py @@ -0,0 +1,90 @@ +import rawpy +import cv2 +import numpy as np + +class RawImageProcessor: + def __init__(self, raw_path): + """初始化并读取RAW图像""" + self.raw_path = raw_path + self.raw = rawpy.imread(raw_path) + self.rgb_image = self.raw.postprocess( + gamma=(1.0, 1.0), + no_auto_bright=True, + use_camera_wb=True, + output_bps=8 + ) + # 转换为OpenCV使用的BGR格式 + self.bgr_image = cv2.cvtColor(self.rgb_image, cv2.COLOR_RGB2BGR) + + def adjust_contrast(self, alpha=1.0): + """调整图像对比度""" + self.bgr_image = cv2.convertScaleAbs(self.bgr_image, alpha=alpha) + + def adjust_brightness(self, beta=0): + """调整图像亮度""" + self.bgr_image = cv2.convertScaleAbs(self.bgr_image, beta=beta) + + def apply_sharpening(self): + """应用图像锐化""" + kernel = np.array([[0, -1, 0], + [-1, 5,-1], + [0, -1, 0]]) + self.bgr_image = cv2.filter2D(self.bgr_image, -1, kernel) + + def apply_gamma_correction(self, gamma=1.0): + """应用Gamma校正""" + inv_gamma = 1.0 / gamma + table = np.array([((i / 255.0) ** inv_gamma) * 255 + for i in np.arange(0, 256)]).astype("uint8") + self.bgr_image = cv2.LUT(self.bgr_image, table) + + def save_image(self, output_path, file_format="png", jpeg_quality=90): + """保存图像为指定格式""" + if file_format.lower() == "jpg" or file_format.lower() == "jpeg": + cv2.imwrite(output_path, self.bgr_image, [cv2.IMWRITE_JPEG_QUALITY, jpeg_quality]) + else: + cv2.imwrite(output_path, self.bgr_image) + + def show_image(self, 
window_name="Image"): + """显示处理后的图像""" + cv2.imshow(window_name, self.bgr_image) + cv2.waitKey(0) + cv2.destroyAllWindows() + + def get_bgr_image(self): + """返回处理后的BGR图像""" + return self.bgr_image + + def reset(self): + """重置图像到最初的状态""" + self.bgr_image = cv2.cvtColor(self.rgb_image, cv2.COLOR_RGB2BGR) + +# 使用示例 +if __name__ == "__main__": + # 初始化RAW图像处理器 + processor = RawImageProcessor('path_to_your_image.raw') + + # 调整对比度 + processor.adjust_contrast(alpha=1.3) + + # 调整亮度 + processor.adjust_brightness(beta=20) + + # 应用锐化 + processor.apply_sharpening() + + # 应用Gamma校正 + processor.apply_gamma_correction(gamma=1.2) + + # 显示处理后的图像 + processor.show_image() + + # 保存处理后的图像 + processor.save_image('output_image.png') + + # 重置图像 + processor.reset() + + # 进行其他处理并保存为JPEG + processor.adjust_contrast(alpha=1.1) + processor.save_image('output_image.jpg', file_format="jpg", jpeg_quality=85) diff --git a/pysrc/image/resample/__init__.py b/pysrc/image/resample/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pysrc/image/resample/resample.py b/pysrc/image/resample/resample.py new file mode 100644 index 00000000..95481218 --- /dev/null +++ b/pysrc/image/resample/resample.py @@ -0,0 +1,185 @@ +import cv2 +import numpy as np +from PIL import Image +import os +from typing import Optional, Tuple, Literal + + +def resample_image(input_image_path: str, + output_image_path: str, + width: Optional[int] = None, + height: Optional[int] = None, + scale: Optional[float] = None, + resolution: Optional[Tuple[int, int]] = None, + interpolation: int = cv2.INTER_LINEAR, + preserve_aspect_ratio: bool = True, + crop_area: Optional[Tuple[int, int, int, int]] = None, + edge_detection: bool = False, + color_space: Literal['BGR', 'GRAY', 'HSV', 'RGB'] = 'BGR', + add_watermark: bool = False, + watermark_text: str = '', + watermark_position: Tuple[int, int] = (0, 0), + watermark_opacity: float = 0.5, + batch_mode: bool = False, + output_format: str = 'jpg', + brightness: float = 1.0, + contrast: float = 1.0, + sharpen: bool = False, + rotate_angle: Optional[float] = None): + """ + Resamples an image with given dimensions, scale, resolution, and additional processing options. 
+ + :param input_image_path: Path to the input image (or directory in batch mode) + :param output_image_path: Path to save the resampled image(s) + :param width: Desired width in pixels + :param height: Desired height in pixels + :param scale: Scale factor for resizing (e.g., 0.5 for half size, 2.0 for double size) + :param resolution: Tuple of horizontal and vertical resolution (in dpi) + :param interpolation: Interpolation method (e.g., cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_NEAREST) + :param preserve_aspect_ratio: Whether to preserve the aspect ratio of the original image + :param crop_area: Tuple defining the crop area (x, y, w, h) + :param edge_detection: Whether to apply edge detection before resizing + :param color_space: Color space to convert the image to (e.g., 'GRAY', 'HSV', 'RGB') + :param add_watermark: Whether to add a watermark to the image + :param watermark_text: The text to be used as a watermark + :param watermark_position: Position tuple (x, y) for the watermark + :param watermark_opacity: Opacity level for the watermark (0.0 to 1.0) + :param batch_mode: Whether to process multiple images in a directory + :param output_format: Output image format (e.g., 'jpg', 'png', 'tiff') + :param brightness: Brightness adjustment factor (1.0 = no change) + :param contrast: Contrast adjustment factor (1.0 = no change) + :param sharpen: Whether to apply sharpening to the image + :param rotate_angle: Angle to rotate the image (in degrees) + """ + + def adjust_brightness_contrast(image: np.ndarray, brightness: float = 1.0, contrast: float = 1.0) -> np.ndarray: + # Clip the pixel values to be in the valid range after adjustment + adjusted = cv2.convertScaleAbs( + image, alpha=contrast, beta=(brightness - 1.0) * 255) + return adjusted + + def add_text_watermark(image: np.ndarray, text: str, position: Tuple[int, int], opacity: float) -> np.ndarray: + overlay = image.copy() + output = image.copy() + font = cv2.FONT_HERSHEY_SIMPLEX + font_scale = 1 + thickness = 2 + text_size = cv2.getTextSize(text, font, font_scale, thickness)[0] + text_x = position[0] + text_y = position[1] + text_size[1] + cv2.putText(overlay, text, (text_x, text_y), font, + font_scale, (255, 255, 255), thickness) + # Combine original image with overlay + cv2.addWeighted(overlay, opacity, output, 1 - opacity, 0, output) + return output + + def sharpen_image(image: np.ndarray) -> np.ndarray: + kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]]) + sharpened = cv2.filter2D(image, -1, kernel) + return sharpened + + def process_image(image_path: str, output_path: str): + img = cv2.imread(image_path) + if img is None: + raise ValueError(f"Cannot load image from {image_path}") + + original_height, original_width = img.shape[:2] + + # Crop if needed + if crop_area: + x, y, w, h = crop_area + img = img[y:y+h, x:x+w] + + # Edge detection + if edge_detection: + img = cv2.Canny(img, 100, 200) + + # Convert color space if needed + # Only convert if image is not already grayscale + if color_space == 'GRAY' and len(img.shape) == 3: + img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + # Only convert if image has 3 channels + elif color_space == 'HSV' and len(img.shape) == 3: + img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV) + # Convert to RGB if required + elif color_space == 'RGB' and len(img.shape) == 3: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + + # Calculate new dimensions + if scale: + new_width = int(original_width * scale) + new_height = int(original_height * scale) + else: + if width and height: + new_width = width + 
new_height = height + elif width: + new_width = width + new_height = int( + (width / original_width) * original_height) if preserve_aspect_ratio else height + elif height: + new_height = height + new_width = int((height / original_height) * + original_width) if preserve_aspect_ratio else width + else: + new_width, new_height = original_width, original_height + + # Perform resizing + resized_img = cv2.resize( + img, (new_width, new_height), interpolation=interpolation) + + # Adjust brightness and contrast + resized_img = adjust_brightness_contrast( + resized_img, brightness, contrast) + + # Apply sharpening if needed + if sharpen: + resized_img = sharpen_image(resized_img) + + # Rotate the image if needed + if rotate_angle: + center = (new_width // 2, new_height // 2) + rotation_matrix = cv2.getRotationMatrix2D( + center, rotate_angle, 1.0) + resized_img = cv2.warpAffine( + resized_img, rotation_matrix, (new_width, new_height)) + + # Add watermark if needed + if add_watermark: + resized_img = add_text_watermark( + resized_img, watermark_text, watermark_position, watermark_opacity) + + # Save the image + if resolution: + dpi_x, dpi_y = resolution + pil_img = Image.fromarray(cv2.cvtColor(resized_img, cv2.COLOR_BGR2RGB) if len( + resized_img.shape) == 3 else resized_img) + pil_img.save(output_path, dpi=(dpi_x, dpi_y), format=output_format) + else: + cv2.imwrite(output_path, resized_img) + + # Batch processing mode + if batch_mode: + if not os.path.isdir(input_image_path): + raise ValueError( + "In batch mode, input_image_path must be a directory") + if not os.path.exists(output_image_path): + os.makedirs(output_image_path) + for filename in os.listdir(input_image_path): + file_path = os.path.join(input_image_path, filename) + output_file_path = os.path.join( + output_image_path, f"{os.path.splitext(filename)[0]}.{output_format}") + process_image(file_path, output_file_path) + else: + process_image(input_image_path, output_image_path) + + +# Example usage: +input_path = 'star_image.png' +output_path = './output/star_image_processed.png' + +# Resize image, apply edge detection, adjust brightness/contrast, add watermark, sharpen, and rotate +resample_image(input_path, output_path, width=800, height=600, interpolation=cv2.INTER_CUBIC, resolution=(300, 300), + crop_area=(100, 100, 400, 400), edge_detection=True, color_space='GRAY', add_watermark=True, + watermark_text='Sample Watermark', watermark_position=(10, 10), watermark_opacity=0.7, + brightness=1.2, contrast=1.5, output_format='png', sharpen=True, rotate_angle=45) diff --git a/pysrc/image/star_detection/__init__.py b/pysrc/image/star_detection/__init__.py new file mode 100644 index 00000000..b498b982 --- /dev/null +++ b/pysrc/image/star_detection/__init__.py @@ -0,0 +1,3 @@ +from .detection import StarDetectionConfig, multiscale_detect_stars +from .clustering import cluster_stars +from .preprocessing import load_image diff --git a/pysrc/image/star_detection/clustering.py b/pysrc/image/star_detection/clustering.py new file mode 100644 index 00000000..365b9db3 --- /dev/null +++ b/pysrc/image/star_detection/clustering.py @@ -0,0 +1,32 @@ +import numpy as np +from sklearn.cluster import DBSCAN +from typing import List, Tuple + +def cluster_stars(stars: List[Tuple[int, int]], dbscan_eps: float, dbscan_min_samples: int) -> List[Tuple[int, int]]: + """ + Cluster stars using the DBSCAN algorithm. + + Parameters: + - stars: List of star positions as (x, y) tuples. 
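A tiny self-contained illustration of the DBSCAN call that cluster_stars wraps: the estimator takes an (N, 2) array of positions, eps is a pixel distance here, label -1 marks noise, and the example coordinates are illustrative only.

import numpy as np
from sklearn.cluster import DBSCAN

stars = [(10, 12), (11, 13), (250, 240), (251, 241), (500, 10)]
labels = DBSCAN(eps=5, min_samples=2).fit(np.asarray(stars)).labels_
# Each non-negative label is one merged star group; -1 means an unclustered detection.
print(dict(zip(stars, labels)))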
+ - dbscan_eps: The maximum distance between two stars for them to be considered in the same neighborhood. + - dbscan_min_samples: The number of stars in a neighborhood for a point to be considered a core point. + + Returns: + - List of clustered star positions as (x, y) tuples. + """ + if len(stars) == 0: + return [] + + clustering = DBSCAN(eps=dbscan_eps, min_samples=dbscan_min_samples).fit(stars) + labels = clustering.labels_ + + unique_labels = set(labels) + clustered_stars = [] + for label in unique_labels: + if label == -1: # -1 indicates noise + continue + class_members = [stars[i] for i in range(len(stars)) if labels[i] == label] + centroid = np.mean(class_members, axis=0).astype(int) + clustered_stars.append(tuple(centroid)) + + return clustered_stars diff --git a/pysrc/image/star_detection/detection.py b/pysrc/image/star_detection/detection.py new file mode 100644 index 00000000..e4a84298 --- /dev/null +++ b/pysrc/image/star_detection/detection.py @@ -0,0 +1,84 @@ +import cv2 +import numpy as np +from .preprocessing import apply_median_filter, wavelet_transform, inverse_wavelet_transform, binarize, detect_stars, background_subtraction +from typing import List, Tuple + +class StarDetectionConfig: + """ + Configuration class for star detection settings. + """ + def __init__(self, + median_filter_size: int = 3, + wavelet_levels: int = 4, + binarization_threshold: int = 30, + min_star_size: int = 10, + min_star_brightness: int = 20, + min_circularity: float = 0.7, + max_circularity: float = 1.3, + scales: List[float] = [1.0, 0.75, 0.5], + dbscan_eps: float = 10, + dbscan_min_samples: int = 2): + self.median_filter_size = median_filter_size + self.wavelet_levels = wavelet_levels + self.binarization_threshold = binarization_threshold + self.min_star_size = min_star_size + self.min_star_brightness = min_star_brightness + self.min_circularity = min_circularity + self.max_circularity = max_circularity + self.scales = scales + self.dbscan_eps = dbscan_eps + self.dbscan_min_samples = dbscan_min_samples + + +def multiscale_detect_stars(image: np.ndarray, config: StarDetectionConfig) -> List[Tuple[int, int]]: + """ + Detect stars in an image using multiscale analysis. + + Parameters: + - image: Grayscale input image as a numpy array. + - config: Configuration object containing detection settings. + + Returns: + - List of detected star positions as (x, y) tuples. + """ + all_stars = [] + for scale in config.scales: + resized_image = cv2.resize(image, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR) + filtered_image = apply_median_filter(resized_image, config) + pyramid = wavelet_transform(filtered_image, config.wavelet_levels) + background = pyramid[-1] + subtracted_image = background_subtraction(filtered_image, background) + pyramid = pyramid[:2] + processed_image = inverse_wavelet_transform(pyramid) + binary_image = binarize(processed_image, config) + _, star_properties = detect_stars(binary_image) + filtered_stars = filter_stars(star_properties, binary_image, config) + # Adjust star positions back to original scale + filtered_stars = [(int(x / scale), int(y / scale)) for (x, y) in filtered_stars] + all_stars.extend(filtered_stars) + return all_stars + +def filter_stars(star_properties: List[Tuple[Tuple[int, int], float, float]], binary_image: np.ndarray, config: StarDetectionConfig) -> List[Tuple[int, int]]: + """ + Filter detected stars based on shape, size, and brightness. + + Parameters: + - star_properties: List of tuples containing star properties (center, area, perimeter). 
+ - binary_image: Binary image used for star detection. + - config: Configuration object containing filter settings. + + Returns: + - List of filtered star positions as (x, y) tuples. + """ + filtered_stars = [] + for (center, area, perimeter) in star_properties: + circularity = (4 * np.pi * area) / (perimeter ** 2) + mask = np.zeros_like(binary_image) + cv2.circle(mask, center, 5, 255, -1) + star_pixels = cv2.countNonZero(mask) + brightness = np.mean(binary_image[mask == 255]) + if (star_pixels > config.min_star_size and + brightness > config.min_star_brightness and + config.min_circularity <= circularity <= config.max_circularity): + filtered_stars.append(center) + return filtered_stars diff --git a/pysrc/image/star_detection/preprocessing.py b/pysrc/image/star_detection/preprocessing.py new file mode 100644 index 00000000..16753dbe --- /dev/null +++ b/pysrc/image/star_detection/preprocessing.py @@ -0,0 +1,179 @@ +import cv2 +import numpy as np +from astropy.io import fits +from typing import List, Tuple + +def load_fits_image(file_path: str) -> np.ndarray: + """ + Load a FITS image from the specified file path. + + Parameters: + - file_path: Path to the FITS file. + + Returns: + - Image data as a numpy array. + """ + with fits.open(file_path) as hdul: + image_data = hdul[0].data + return image_data + +def preprocess_fits_image(image_data: np.ndarray) -> np.ndarray: + """ + Preprocess FITS image by normalizing to the 0-255 range. + + Parameters: + - image_data: Raw image data from the FITS file. + + Returns: + - Preprocessed image data as a numpy array. + """ + image_data = np.nan_to_num(image_data) + image_data = image_data.astype(np.float64) + image_data -= np.min(image_data) + image_data /= np.max(image_data) + image_data *= 255 + return image_data.astype(np.uint8) + +def load_image(file_path: str) -> np.ndarray: + """ + Load an image from the specified file path. Supports FITS and standard image formats. + + Parameters: + - file_path: Path to the image file. + + Returns: + - Loaded image as a numpy array. + """ + if file_path.endswith('.fits'): + image_data = load_fits_image(file_path) + if image_data.ndim == 2: + return preprocess_fits_image(image_data) + elif image_data.ndim == 3: + channels = [preprocess_fits_image(image_data[..., i]) for i in range(image_data.shape[2])] + return cv2.merge(channels) + else: + image = cv2.imread(file_path, cv2.IMREAD_UNCHANGED) + if image is None: + raise ValueError("Unable to load image file: {}".format(file_path)) + return image + +def apply_median_filter(image: np.ndarray, config) -> np.ndarray: + """ + Apply median filtering to the image. + + Parameters: + - image: Input image. + - config: Configuration object containing median filter settings. + + Returns: + - Filtered image. + """ + return cv2.medianBlur(image, config.median_filter_size) + +def wavelet_transform(image: np.ndarray, levels: int) -> List[np.ndarray]: + """ + Perform wavelet transform using a Laplacian pyramid. + + Parameters: + - image: Input image. + - levels: Number of levels in the wavelet transform. + + Returns: + - List of wavelet transformed images at each level. 
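+      The last element of the returned list is the low-resolution residual left
+      by the final pyrDown; the preceding elements are the detail (difference)
+      layers, ordered from fine to coarse.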
+ """ + pyramid = [] + current_image = image.copy() + for _ in range(levels): + down = cv2.pyrDown(current_image) + up = cv2.pyrUp(down, current_image.shape[:2]) + + # Resize up to match the original image size + up = cv2.resize(up, (current_image.shape[1], current_image.shape[0])) + + # Calculate the difference to get the detail layer + layer = cv2.subtract(current_image, up) + pyramid.append(layer) + current_image = down + + pyramid.append(current_image) # Add the final low-resolution image + return pyramid + +def inverse_wavelet_transform(pyramid: List[np.ndarray]) -> np.ndarray: + """ + Reconstruct the image from its wavelet pyramid representation. + + Parameters: + - pyramid: List of wavelet transformed images at each level. + + Returns: + - Reconstructed image. + """ + image = pyramid.pop() + while pyramid: + up = cv2.pyrUp(image, pyramid[-1].shape[:2]) + + # Resize up to match the size of the current level + up = cv2.resize(up, (pyramid[-1].shape[1], pyramid[-1].shape[0])) + + # Add the detail layer to reconstruct the image + image = cv2.add(up, pyramid.pop()) + + return image + +def background_subtraction(image: np.ndarray, background: np.ndarray) -> np.ndarray: + """ + Subtract the background from the image using the provided background image. + + Parameters: + - image: Original image. + - background: Background image to subtract. + + Returns: + - Image with background subtracted. + """ + # Resize the background to match the original image size + background_resized = cv2.resize(background, (image.shape[1], image.shape[0])) + + # Subtract background and ensure no negative values + result = cv2.subtract(image, background_resized) + result[result < 0] = 0 + return result + +def binarize(image: np.ndarray, config) -> np.ndarray: + """ + Binarize the image using a fixed threshold. + + Parameters: + - image: Input image. + - config: Configuration object containing binarization settings. + + Returns: + - Binarized image. + """ + _, binary_image = cv2.threshold(image, config.binarization_threshold, 255, cv2.THRESH_BINARY) + return binary_image + +def detect_stars(binary_image: np.ndarray) -> Tuple[List[Tuple[int, int]], List[Tuple[Tuple[int, int], float, float]]]: + """ + Detect stars in a binary image by finding contours. + + Parameters: + - binary_image: Binarized image. + + Returns: + - Tuple containing a list of star centers and a list of star properties (center, area, perimeter). + """ + contours, _ = cv2.findContours(binary_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + star_centers = [] + star_properties = [] + + for contour in contours: + M = cv2.moments(contour) + if M['m00'] != 0: + center = (int(M['m10'] / M['m00']), int(M['m01'] / M['m00'])) + star_centers.append(center) + area = cv2.contourArea(contour) + perimeter = cv2.arcLength(contour, True) + star_properties.append((center, area, perimeter)) + + return star_centers, star_properties diff --git a/pysrc/image/star_detection/utils.py b/pysrc/image/star_detection/utils.py new file mode 100644 index 00000000..8c7fa420 --- /dev/null +++ b/pysrc/image/star_detection/utils.py @@ -0,0 +1,24 @@ +import cv2 +import numpy as np + +def save_image(filepath: str, image: np.ndarray) -> None: + """ + Save an image to a specified file path. + + Parameters: + - filepath: The file path where the image should be saved. + - image: The image data to be saved. + """ + cv2.imwrite(filepath, image) + +def display_image(window_name: str, image: np.ndarray) -> None: + """ + Display an image in a new window. 
+ + Parameters: + - window_name: The name of the window where the image will be displayed. + - image: The image data to be displayed. + """ + cv2.imshow(window_name, image) + cv2.waitKey(0) + cv2.destroyAllWindows() diff --git a/pysrc/image/transformation/__init__.py b/pysrc/image/transformation/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pysrc/image/transformation/curve.py b/pysrc/image/transformation/curve.py new file mode 100644 index 00000000..8275158a --- /dev/null +++ b/pysrc/image/transformation/curve.py @@ -0,0 +1,154 @@ +import numpy as np +import matplotlib.pyplot as plt +from scipy.interpolate import CubicSpline, Akima1DInterpolator + + +class CurvesTransformation: + def __init__(self, interpolation='akima'): + self.points = [] + self.interpolation = interpolation + self.curve = None + self.stored_curve = None + + def add_point(self, x, y): + self.points.append((x, y)) + self.points.sort() # Sort points by x value + self._update_curve() + + def remove_point(self, index): + if 0 <= index < len(self.points): + self.points.pop(index) + self._update_curve() + + def _update_curve(self): + if len(self.points) < 2: + self.curve = None + return + + x, y = zip(*self.points) + + if self.interpolation == 'cubic': + self.curve = CubicSpline(x, y) + elif self.interpolation == 'akima': + self.curve = Akima1DInterpolator(x, y) + elif self.interpolation == 'linear': + self.curve = lambda x_new: np.interp(x_new, x, y) + else: + raise ValueError("Unsupported interpolation method") + + def transform(self, image, channel=None): + if self.curve is None: + raise ValueError("No valid curve defined") + + if len(image.shape) == 2: # Grayscale image + transformed_image = self.curve(image) + elif len(image.shape) == 3: # RGB image + if channel is None: + raise ValueError("Channel must be specified for color images") + transformed_image = image.copy() + transformed_image[:, :, channel] = self.curve(image[:, :, channel]) + else: + raise ValueError("Unsupported image format") + + transformed_image = np.clip(transformed_image, 0, 1) + return transformed_image + + def plot_curve(self): + if self.curve is None: + print("No curve to plot") + return + + x_vals = np.linspace(0, 1, 100) + y_vals = self.curve(x_vals) + + plt.plot(x_vals, y_vals, label=f'Interpolation: {self.interpolation}') + plt.scatter(*zip(*self.points), color='red') + plt.title('Curves Transformation') + plt.xlabel('Input') + plt.ylabel('Output') + plt.grid(True) + plt.legend() + plt.show() + + def store_curve(self): + self.stored_curve = self.points.copy() + print("Curve stored.") + + def restore_curve(self): + if self.stored_curve: + self.points = self.stored_curve.copy() + self._update_curve() + print("Curve restored.") + else: + print("No stored curve to restore.") + + def invert_curve(self): + if self.curve is None: + print("No curve to invert") + return + + self.points = [(x, 1 - y) for x, y in self.points] + self._update_curve() + print("Curve inverted.") + + def reset_curve(self): + self.points = [(0, 0), (1, 1)] + self._update_curve() + print("Curve reset to default.") + + def pixel_readout(self, x): + if self.curve is None: + print("No curve defined") + return None + return self.curve(x) + + +# Example Usage +if __name__ == "__main__": + # Create a CurvesTransformation object + curve_transform = CurvesTransformation(interpolation='akima') + + # Add points to the curve + curve_transform.add_point(0.0, 0.0) + curve_transform.add_point(0.3, 0.5) + curve_transform.add_point(0.7, 0.8) + 
curve_transform.add_point(1.0, 1.0)
+
+    # Plot the curve
+    curve_transform.plot_curve()
+
+    # Store the curve
+    curve_transform.store_curve()
+
+    # Invert the curve
+    curve_transform.invert_curve()
+    curve_transform.plot_curve()
+
+    # Restore the original curve
+    curve_transform.restore_curve()
+    curve_transform.plot_curve()
+
+    # Reset the curve to default
+    curve_transform.reset_curve()
+    curve_transform.plot_curve()
+
+    # Generate a test image
+    test_image = np.linspace(0, 1, 256).reshape(16, 16)
+
+    # Apply the transformation
+    transformed_image = curve_transform.transform(test_image)
+
+    # Plot original and transformed images
+    plt.figure(figsize=(8, 4))
+    plt.subplot(1, 2, 1)
+    plt.title("Original Image")
+    plt.imshow(test_image, cmap='gray')
+
+    plt.subplot(1, 2, 2)
+    plt.title("Transformed Image")
+    plt.imshow(transformed_image, cmap='gray')
+    plt.show()
+
+    # Pixel readout
+    readout_value = curve_transform.pixel_readout(0.5)
+    print(f"Pixel readout at x=0.5: {readout_value}")
diff --git a/pysrc/image/transformation/histogram.py b/pysrc/image/transformation/histogram.py
new file mode 100644
index 00000000..80f8f390
--- /dev/null
+++ b/pysrc/image/transformation/histogram.py
@@ -0,0 +1,121 @@
+import cv2
+import numpy as np
+import matplotlib.pyplot as plt
+
+# 1. Histogram calculation
+
+
+def calculate_histogram(image, channel=0):
+    histogram = cv2.calcHist([image], [channel], None, [256], [0, 256])
+    return histogram
+
+# 2. Display the histogram
+
+
+def display_histogram(histogram, title="Histogram"):
+    plt.plot(histogram)
+    plt.title(title)
+    plt.xlabel('Pixel Intensity')
+    plt.ylabel('Frequency')
+    plt.show()
+
+# 3. Histogram transformation
+
+
+def apply_histogram_transformation(image, shadows_clip=0.0, highlights_clip=1.0, midtones_balance=0.5, lower_bound=-1.0, upper_bound=2.0):
+    # Normalize to [0, 1]
+    normalized_image = image.astype(np.float32) / 255.0
+
+    # Clip shadows and highlights
+    clipped_image = np.clip(
+        (normalized_image - shadows_clip) / (highlights_clip - shadows_clip), 0, 1)
+
+    # Midtones balance
+    def mtf(x): return (x**midtones_balance) / \
+        ((x**midtones_balance + (1-x)**midtones_balance)**(1/midtones_balance))
+    transformed_image = mtf(clipped_image)
+
+    # Dynamic range expansion
+    expanded_image = np.clip(
+        (transformed_image - lower_bound) / (upper_bound - lower_bound), 0, 1)
+
+    # Rescale to [0, 255]
+    output_image = (expanded_image * 255).astype(np.uint8)
+    return output_image
+
+# 4. Automatic clipping
+
+
+def auto_clip(image, clip_percent=0.01):
+    # Compute the cumulative distribution function (CDF)
+    hist, bins = np.histogram(image.flatten(), 256, [0, 256])
+    cdf = hist.cumsum()
+
+    # Compute the clipping points
+    total_pixels = image.size
+    lower_clip = np.searchsorted(cdf, total_pixels * clip_percent)
+    upper_clip = np.searchsorted(cdf, total_pixels * (1 - clip_percent))
+
+    # Apply the clipping
+    auto_clipped_image = apply_histogram_transformation(
+        image, shadows_clip=lower_clip/255.0, highlights_clip=upper_clip/255.0)
+
+    return auto_clipped_image
+
+# 5. Display the original RGB histogram
+
+
+def display_rgb_histogram(image):
+    color = ('b', 'g', 'r')
+    for i, col in enumerate(color):
+        hist = calculate_histogram(image, channel=i)
+        plt.plot(hist, color=col)
+    plt.title('RGB Histogram')
+    plt.xlabel('Pixel Intensity')
+    plt.ylabel('Frequency')
+    plt.show()
+
+
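The midtones-balance step in apply_histogram_transformation above is easiest to
see in isolation. A minimal, self-contained sketch of the same transfer function
(illustration only, not part of this patch):

    import numpy as np

    def mtf(x, midtones_balance=0.5):
        # Same formula as the inline helper above: 0 and 1 are fixed points,
        # and midtones_balance controls how intermediate values are remapped.
        m = midtones_balance
        return (x ** m) / ((x ** m + (1 - x) ** m) ** (1 / m))

    samples = np.linspace(0.0, 1.0, 5)   # normalized intensities
    print(np.round(mtf(samples), 3))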
+# 6. Real-time preview (simple simulation)
+
+
+def real_time_preview(image, transformation_function, **kwargs):
+    preview_image = transformation_function(image, **kwargs)
+    cv2.imshow('Real-Time Preview', preview_image)
+
+
+# Main entry point
+if __name__ == "__main__":
+    # Load the image
+    image = cv2.imread('image.jpg')
+
+    # Convert to grayscale
+    grayscale_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+
+    # Show the original image and its histogram
+    cv2.imshow('Original Image', image)
+    histogram = calculate_histogram(grayscale_image)
+    display_histogram(histogram, title="Original Grayscale Histogram")
+
+    # Show the RGB histogram
+    display_rgb_histogram(image)
+
+    # Apply the histogram transformation
+    transformed_image = apply_histogram_transformation(
+        grayscale_image, shadows_clip=0.1, highlights_clip=0.9, midtones_balance=0.4, lower_bound=-0.5, upper_bound=1.5)
+    cv2.imshow('Transformed Image', transformed_image)
+
+    # Show the transformed histogram
+    transformed_histogram = calculate_histogram(transformed_image)
+    display_histogram(transformed_histogram,
+                      title="Transformed Grayscale Histogram")
+
+    # Apply automatic clipping
+    auto_clipped_image = auto_clip(grayscale_image, clip_percent=0.01)
+    cv2.imshow('Auto Clipped Image', auto_clipped_image)
+
+    # Simulated real-time preview
+    real_time_preview(grayscale_image, apply_histogram_transformation,
+                      shadows_clip=0.05, highlights_clip=0.95, midtones_balance=0.5)
+
+    cv2.waitKey(0)
+    cv2.destroyAllWindows()
diff --git a/pysrc/main.py b/pysrc/main.py
index ff25169d..8398c8ef 100644
--- a/pysrc/main.py
+++ b/pysrc/main.py
@@ -9,6 +9,8 @@ from app.connection_manager import ConnectionManager
 from app.plugin_manager import load_plugins, start_plugin_watcher, stop_plugin_watcher, update_plugin, install_plugin, get_plugin_info, check_plugin_dependencies
 
+from router import websocket
+
 # Configure the loguru logging system
 logger.add("server.log", level="DEBUG",
            format="{time} {level} {message}", rotation="10 MB")
@@ -71,28 +73,6 @@ def get_current_username(credentials: HTTPBasicCredentials = Depends(security)):
     logger.info(f"Authenticated user: {credentials.username}")
     return credentials.username
 
-@app.websocket("/ws")
-async def websocket_endpoint(websocket: WebSocket, username: str = Depends(get_current_username)):
-    """
-    WebSocket endpoint for handling client connections.
- """ - client_id = await manager.connect(websocket) - await manager.broadcast(f'{{"type": "Server_msg", "message": "Client {client_id} connected"}}') - - try: - async with websocket: - while True: - data = await websocket.receive_text() - logger.info(f"Received message from {client_id}: {data}") - await manager.broadcast(data) - except WebSocketDisconnect: - manager.disconnect(client_id) - await manager.broadcast(f'{{"type": "Server_msg", "message": "Client {client_id} disconnected"}}') - except Exception as e: - logger.error(f"Unexpected error with client {client_id}: {e}") - manager.disconnect(client_id) - await manager.broadcast(f'{{"type": "Server_msg", "message": "Client {client_id} disconnected due to error"}}') - # Heartbeat function to check if clients are still connected async def ping(): """ diff --git a/pysrc/router/__init__.py b/pysrc/router/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pysrc/router/websocket.py b/pysrc/router/websocket.py new file mode 100644 index 00000000..172adc4a --- /dev/null +++ b/pysrc/router/websocket.py @@ -0,0 +1,67 @@ +from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends +from app.connection_manager import ConnectionManager +from app.dependence import get_current_username +from app.command_dispatcher import CommandDispatcher +from loguru import logger +import json +import asyncio +from typing import Dict, Any + +router = APIRouter() +manager = ConnectionManager() +command_dispatcher = CommandDispatcher() + +@router.websocket("/ws") +async def websocket_endpoint(websocket: WebSocket, username: str = Depends(get_current_username)): + client_id = await manager.connect(websocket) + await manager.broadcast(json.dumps({"type": "Server_msg", "message": f"Client {client_id} connected"})) + + try: + # Start a background task for handling heartbeats + heartbeat_task = asyncio.create_task(handle_heartbeat(client_id, websocket)) + + while True: + data = await websocket.receive_text() + await process_message(client_id, data) + except WebSocketDisconnect: + await handle_disconnect(client_id, "Client disconnected") + except Exception as e: + logger.error(f"Unexpected error with client {client_id}: {e}") + await handle_disconnect(client_id, f"Client disconnected due to error: {str(e)}") + finally: + heartbeat_task.cancel() + +async def process_message(client_id: str, data: str): + logger.info(f"Received message from {client_id}: {data}") + try: + message = json.loads(data) + if isinstance(message, dict) and "command" in message: + response = await command_dispatcher.dispatch(message["command"], message.get("params", {})) + await manager.send_personal_message(json.dumps(response), client_id) + else: + await manager.broadcast(data) + except json.JSONDecodeError: + logger.warning(f"Received invalid JSON from client {client_id}") + await manager.send_personal_message(json.dumps({"error": "Invalid JSON"}), client_id) + +async def handle_disconnect(client_id: str, message: str): + manager.disconnect(client_id) + await manager.broadcast(json.dumps({"type": "Server_msg", "message": message})) + +async def handle_heartbeat(client_id: str, websocket: WebSocket): + while True: + try: + await asyncio.sleep(30) # Send heartbeat every 30 seconds + await websocket.send_text(json.dumps({"type": "heartbeat"})) + except Exception as e: + logger.error(f"Heartbeat failed for client {client_id}: {e}") + break + +# Register commands +@command_dispatcher.register("echo") +async def echo_command(params: Dict[str, Any]) -> Dict[str, Any]: + 
return {"result": params.get("message", "No message provided")} + +@command_dispatcher.register("get_active_clients") +async def get_active_clients_command(params: Dict[str, Any]) -> Dict[str, Any]: + return {"result": len(manager.active_connections)} diff --git a/scripts/docker.py b/scripts/docker.py new file mode 100644 index 00000000..de22a67a --- /dev/null +++ b/scripts/docker.py @@ -0,0 +1,406 @@ +import docker +import argparse +import sys +import json +from typing import Literal, Union +from dataclasses import dataclass, asdict +from prettytable import PrettyTable +import os +import tempfile +import tarfile +import yaml + + +@dataclass +class ContainerInfo: + id: str + name: str + status: str + image: str + ports: dict + cpu_usage: float + memory_usage: float + + +class DockerManager: + def __init__(self): + self.client = docker.from_env() + + def list_containers(self, all: bool = False) -> list[ContainerInfo]: + containers = self.client.containers.list(all=all) + container_info = [] + for c in containers: + stats = c.stats(stream=False) + cpu_usage = self._calculate_cpu_percent(stats) + memory_usage = self._calculate_memory_usage(stats) + container_info.append(ContainerInfo( + c.id, c.name, c.status, + c.image.tags[0] if c.image.tags else "None", + c.ports, cpu_usage, memory_usage + )) + return container_info + + def _calculate_cpu_percent(self, stats): + cpu_delta = stats["cpu_stats"]["cpu_usage"]["total_usage"] - \ + stats["precpu_stats"]["cpu_usage"]["total_usage"] + system_delta = stats["cpu_stats"]["system_cpu_usage"] - \ + stats["precpu_stats"]["system_cpu_usage"] + if system_delta > 0.0: + return (cpu_delta / system_delta) * 100.0 + return 0.0 + + def _calculate_memory_usage(self, stats): + return stats["memory_stats"]["usage"] / stats["memory_stats"]["limit"] * 100.0 + + def manage_container(self, action: Literal["start", "stop", "restart", "remove", "pause", "unpause"], container_id: str) -> str: + try: + container = self.client.containers.get(container_id) + getattr(container, action)() + return f"Container {container_id} {action}ed" + except docker.errors.NotFound: + return f"Container {container_id} not found" + + def create_container(self, image: str, name: str, ports: dict = None, volumes: list = None, environment: dict = None) -> Union[str, ContainerInfo]: + try: + container = self.client.containers.run( + image, name=name, detach=True, + ports=ports, volumes=volumes, + environment=environment + ) + stats = container.stats(stream=False) + cpu_usage = self._calculate_cpu_percent(stats) + memory_usage = self._calculate_memory_usage(stats) + return ContainerInfo( + container.id, container.name, container.status, + container.image.tags[0] if container.image.tags else "None", + container.ports, cpu_usage, memory_usage + ) + except docker.errors.ImageNotFound: + return f"Image {image} not found" + + def pull_image(self, image: str) -> str: + try: + self.client.images.pull(image) + return f"Image {image} pulled successfully" + except docker.errors.ImageNotFound: + return f"Image {image} not found" + + def list_images(self) -> list[str]: + return [image.tags[0] for image in self.client.images.list() if image.tags] + + def get_container_logs(self, container_id: str, lines: int = 50) -> str: + try: + container = self.client.containers.get(container_id) + return container.logs(tail=lines).decode('utf-8') + except docker.errors.NotFound: + return f"Container {container_id} not found" + + def exec_command(self, container_id: str, cmd: str) -> str: + try: + container = 
self.client.containers.get(container_id) + exit_code, output = container.exec_run(cmd) + return f"Exit Code: {exit_code}\nOutput:\n{output.decode('utf-8')}" + except docker.errors.NotFound: + return f"Container {container_id} not found" + + def get_container_stats(self, container_id: str) -> Union[str, dict]: + try: + container = self.client.containers.get(container_id) + stats = container.stats(stream=False) + return { + "CPU Usage": f"{self._calculate_cpu_percent(stats):.2f}%", + "Memory Usage": f"{self._calculate_memory_usage(stats):.2f}%", + "Network I/O": f"In: {stats['networks']['eth0']['rx_bytes']/1024/1024:.2f}MB, Out: {stats['networks']['eth0']['tx_bytes']/1024/1024:.2f}MB", + "Block I/O": f"In: {stats['blkio_stats']['io_service_bytes_recursive'][0]['value']/1024/1024:.2f}MB, Out: {stats['blkio_stats']['io_service_bytes_recursive'][1]['value']/1024/1024:.2f}MB" + } + except docker.errors.NotFound: + return f"Container {container_id} not found" + + def copy_to_container(self, container_id: str, src: str, dest: str) -> str: + try: + container = self.client.containers.get(container_id) + with tempfile.NamedTemporaryFile() as tmp: + with tarfile.open(tmp.name, "w:gz") as tar: + tar.add(src, arcname=os.path.basename(src)) + container.put_archive(os.path.dirname(dest), tmp.read()) + return f"File {src} copied to container {container_id} at {dest}" + except docker.errors.NotFound: + return f"Container {container_id} not found" + except FileNotFoundError: + return f"Source file {src} not found" + + def copy_from_container(self, container_id: str, src: str, dest: str) -> str: + try: + container = self.client.containers.get(container_id) + bits, stat = container.get_archive(src) + with tempfile.NamedTemporaryFile() as tmp: + for chunk in bits: + tmp.write(chunk) + tmp.seek(0) + with tarfile.open(fileobj=tmp) as tar: + tar.extractall(path=dest) + return f"File {src} copied from container {container_id} to {dest}" + except docker.errors.NotFound: + return f"Container {container_id} not found" + except KeyError: + return f"Source file {src} not found in container" + + def export_container(self, container_id: str, output_path: str) -> str: + try: + container = self.client.containers.get(container_id) + with open(output_path, 'wb') as f: + for chunk in container.export(): + f.write(chunk) + return f"Container {container_id} exported to {output_path}" + except docker.errors.NotFound: + return f"Container {container_id} not found" + + def import_image(self, image_path: str, repository: str, tag: str) -> str: + try: + with open(image_path, 'rb') as f: + image = self.client.images.import_image( + f, repository=repository, tag=tag) + return f"Image imported as {repository}:{tag}" + except FileNotFoundError: + return f"Image file {image_path} not found" + + def build_image(self, dockerfile_path: str, tag: str) -> str: + try: + image, logs = self.client.images.build(path=os.path.dirname( + dockerfile_path), dockerfile=os.path.basename(dockerfile_path), tag=tag) + return f"Image built successfully with tag {tag}" + except docker.errors.BuildError as e: + return f"Error building image: {str(e)}" + + def compose_up(self, compose_file: str) -> str: + try: + project = self.client.compose.project.from_config( + project_name="myproject", config_files=[compose_file]) + project.up() + return f"Docker Compose services started from {compose_file}" + except docker.errors.APIError as e: + return f"Error starting Docker Compose services: {str(e)}" + + def compose_down(self, compose_file: str) -> str: + try: + 
project = self.client.compose.project.from_config( + project_name="myproject", config_files=[compose_file]) + project.down() + return f"Docker Compose services stopped from {compose_file}" + except docker.errors.APIError as e: + return f"Error stopping Docker Compose services: {str(e)}" + + +def print_table(data, headers): + table = PrettyTable() + table.field_names = headers + for row in data: + table.add_row(row) + print(table) + + +def parse_key_value_pairs(s: str) -> dict: + if not s: + return {} + return dict(item.split("=") for item in s.split(",")) + + +def main(): + parser = argparse.ArgumentParser( + description="Comprehensive Docker Management CLI Tool") + subparsers = parser.add_subparsers( + dest="command", help="Available commands") + + # List containers + list_parser = subparsers.add_parser("list", help="List containers") + list_parser.add_argument("--all", action="store_true", + help="Show all containers (default shows just running)") + list_parser.add_argument( + "--format", choices=["table", "json"], default="table", help="Output format") + + # Manage container + manage_parser = subparsers.add_parser("manage", help="Manage container") + manage_parser.add_argument("action", choices=[ + "start", "stop", "restart", "remove", "pause", "unpause"], help="Action to perform") + manage_parser.add_argument("container_id", help="Container ID") + + # Create container + create_parser = subparsers.add_parser( + "create", help="Create a new container") + create_parser.add_argument("image", help="Image name") + create_parser.add_argument("name", help="Container name") + create_parser.add_argument( + "--ports", help="Port mappings (e.g., '8080:80,9000:9000')") + create_parser.add_argument( + "--volumes", help="Volume mappings (e.g., '/host/path:/container/path')") + create_parser.add_argument( + "--env", help="Environment variables (e.g., 'KEY1=VALUE1,KEY2=VALUE2')") + + # Pull image + pull_parser = subparsers.add_parser("pull", help="Pull an image") + pull_parser.add_argument("image", help="Image name") + + # List images + subparsers.add_parser("images", help="List images") + + # View container logs + logs_parser = subparsers.add_parser("logs", help="View container logs") + logs_parser.add_argument("container_id", help="Container ID") + logs_parser.add_argument( + "--lines", type=int, default=50, help="Number of lines to display") + + # Execute command in container + exec_parser = subparsers.add_parser( + "exec", help="Execute command in container") + exec_parser.add_argument("container_id", help="Container ID") + exec_parser.add_argument("command", help="Command to execute") + + # View container statistics + stats_parser = subparsers.add_parser( + "stats", help="View container statistics") + stats_parser.add_argument("container_id", help="Container ID") + + # Copy file to container + cp_to_parser = subparsers.add_parser( + "cp-to", help="Copy file to container") + cp_to_parser.add_argument("container_id", help="Container ID") + cp_to_parser.add_argument("src", help="Source file path") + cp_to_parser.add_argument("dest", help="Destination path in container") + + # Copy file from container + cp_from_parser = subparsers.add_parser( + "cp-from", help="Copy file from container") + cp_from_parser.add_argument("container_id", help="Container ID") + cp_from_parser.add_argument("src", help="Source file path in container") + cp_from_parser.add_argument("dest", help="Destination path on host") + + # Export container + export_parser = subparsers.add_parser( + "export", help="Export container to a 
tar archive") + export_parser.add_argument("container_id", help="Container ID") + export_parser.add_argument("output", help="Output file path") + + # Import image + import_parser = subparsers.add_parser( + "import", help="Import an image from a tar archive") + import_parser.add_argument("image_path", help="Path to the tar archive") + import_parser.add_argument( + "repository", help="Repository name for the image") + import_parser.add_argument("tag", help="Tag for the image") + + # Build image + build_parser = subparsers.add_parser( + "build", help="Build an image from a Dockerfile") + build_parser.add_argument("dockerfile", help="Path to the Dockerfile") + build_parser.add_argument("tag", help="Tag for the image") + + # Docker Compose up + compose_up_parser = subparsers.add_parser( + "compose-up", help="Start services defined in a Docker Compose file") + compose_up_parser.add_argument( + "compose_file", help="Path to the Docker Compose file") + + # Docker Compose down + compose_down_parser = subparsers.add_parser( + "compose-down", help="Stop services defined in a Docker Compose file") + compose_down_parser.add_argument( + "compose_file", help="Path to the Docker Compose file") + + args = parser.parse_args() + + manager = DockerManager() + + try: + if args.command == "list": + containers = manager.list_containers(all=args.all) + if args.format == "json": + print(json.dumps([asdict(c) for c in containers], indent=2)) + else: + data = [[c.id[:12], c.name, c.status, c.image, + f"{c.cpu_usage:.2f}%", f"{c.memory_usage:.2f}%"] for c in containers] + headers = ["ID", "Name", "Status", + "Image", "CPU Usage", "Memory Usage"] + print_table(data, headers) + + elif args.command == "manage": + result = manager.manage_container(args.action, args.container_id) + print(result) + + elif args.command == "create": + ports = parse_key_value_pairs(args.ports) + volumes = args.volumes.split(",") if args.volumes else None + env = parse_key_value_pairs(args.env) + result = manager.create_container( + args.image, args.name, ports, volumes, env) + if isinstance(result, ContainerInfo): + print( + f"Container created - ID: {result.id}, Name: {result.name}, Status: {result.status}") + else: + print(result) + + elif args.command == "pull": + result = manager.pull_image(args.image) + print(result) + + elif args.command == "images": + images = manager.list_images() + print("Available images:") + for image in images: + print(image) + + elif args.command == "logs": + logs = manager.get_container_logs(args.container_id, args.lines) + print(f"Logs for container {args.container_id}:") + print(logs) + + elif args.command == "exec": + result = manager.exec_command(args.container_id, args.command) + print(result) + + elif args.command == "stats": + stats = manager.get_container_stats(args.container_id) + if isinstance(stats, dict): + for key, value in stats.items(): + print(f"{key}: {value}") + else: + print(stats) + + elif args.command == "cp-to": + result = manager.copy_to_container( + args.container_id, args.src, args.dest) + print(result) + + elif args.command == "cp-from": + result = manager.copy_from_container( + args.container_id, args.src, args.dest) + print(result) + + elif args.command == "export": + result = manager.export_container(args.container_id, args.output) + print(result) + + elif args.command == "import": + result = manager.import_image( + args.image_path, args.repository, args.tag) + print(result) + + elif args.command == "build": + result = manager.build_image(args.dockerfile, args.tag) + print(result) + 
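+        # Illustrative invocations of the subcommands handled above
+        # (not part of the original script):
+        #   python scripts/docker.py list --all --format json
+        #   python scripts/docker.py manage restart <container_id>
+        #   python scripts/docker.py logs <container_id> --lines 100
+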
+ elif args.command == "compose-up": + result = manager.compose_up(args.compose_file) + print(result) + + elif args.command == "compose-down": + result = manager.compose_down(args.compose_file) + print(result) + + except docker.errors.APIError as e: + print(f"Docker API Error: {str(e)}") + except Exception as e: + print(f"An unexpected error occurred: {str(e)}") + + +if __name__ == "__main__": + main() diff --git a/scripts/pip.sh b/scripts/pip.sh old mode 100644 new mode 100755 index dd905c45..0a2159a8 --- a/scripts/pip.sh +++ b/scripts/pip.sh @@ -30,6 +30,9 @@ update_python_packages() { echo "Updated packages: ${updated_packages[*]}" echo "Failed packages: ${failed_packages[*]}" + + # Return updated and failed packages as a string + echo "${updated_packages[*]}|${failed_packages[*]}" } # Function to generate update report @@ -94,9 +97,9 @@ done update_result=$(update_python_packages "${EXCLUDED_PACKAGES[@]}") # Extract updated and failed packages from the result -IFS=$'\n' read -rd '' -a lines <<< "$update_result" -updated_packages=($(echo "${lines[0]}" | cut -d ':' -f2)) -failed_packages=($(echo "${lines[1]}" | cut -d ':' -f2)) +IFS='|' read -r updated_str failed_str <<< "$update_result" +IFS=' ' read -r -a updated_packages <<< "$updated_str" +IFS=' ' read -r -a failed_packages <<< "$failed_str" echo "Update completed." echo "Total packages updated: ${#updated_packages[@]}" diff --git a/src/LithiumApp.cpp b/src/LithiumApp.cpp index 4eeece38..f4d486df 100644 --- a/src/LithiumApp.cpp +++ b/src/LithiumApp.cpp @@ -284,8 +284,8 @@ void initLithiumApp(int argc, char **argv) { // AddPtr("ScriptManager", // ScriptManager::createShared(GetPtr("MessageBus"))); - AddPtr(constants::LITHIUM_DEVICE_MANAGER, DeviceManager::createShared()); - AddPtr(constants::LITHIUM_DEVICE_LOADER, ModuleLoader::createShared("./drivers")); + AddPtr(Constants::LITHIUM_DEVICE_MANAGER, DeviceManager::createShared()); + AddPtr(Constants::LITHIUM_DEVICE_LOADER, ModuleLoader::createShared("./drivers")); AddPtr("lithium.error.stack", std::make_shared()); diff --git a/src/addon/CMakeLists.txt b/src/addon/CMakeLists.txt index 28d17d6f..5d3a28af 100644 --- a/src/addon/CMakeLists.txt +++ b/src/addon/CMakeLists.txt @@ -13,7 +13,9 @@ project(lithium-addons VERSION 1.0.0 LANGUAGES C CXX) # Author: Max Qian # License: GPL3 +if (NOT MINGW OR NOT WIN32) find_package(Seccomp REQUIRED) +endif() # Project sources set(PROJECT_SOURCES diff --git a/src/addon/analysts.cpp b/src/addon/analysts.cpp index 2f3bcc9a..c35f1bf1 100644 --- a/src/addon/analysts.cpp +++ b/src/addon/analysts.cpp @@ -1,15 +1,165 @@ #include "analysts.hpp" +#include #include +#include +#include +#include #include #include "atom/error/exception.hpp" +#include "atom/log/loguru.hpp" #include "atom/type/json.hpp" namespace lithium { -void CompilerOutputParser::parseLine(const std::string& line) { + +/** + * @class HtmlBuilder + * @brief A simple utility class for building HTML documents. + * + * HtmlBuilder provides an interface to construct an HTML document by appending + * elements like headers, paragraphs, lists, etc. The HTML is stored as a string + * and can be retrieved using the `str()` method. 
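+ *
+ * Within this translation unit the builder is used by
+ * CompilerOutputParser::Impl::generateHtmlReport() to assemble the report.
+ * A minimal illustrative usage:
+ * @code
+ * HtmlBuilder builder;
+ * builder.addTitle("Report");
+ * builder.addHeader("Summary", 2);
+ * builder.addParagraph("Everything is fine.");
+ * std::string html = builder.str();
+ * @endcode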
+ */
+class HtmlBuilder {
+public:
+    HtmlBuilder();
+
+    // Add various HTML elements
+    void addTitle(const std::string& title);
+    void addHeader(const std::string& header, int level = 1);
+    void addParagraph(const std::string& text);
+    void addList(const std::vector<std::string>& items);
+
+    // Start and end unordered list
+    void startUnorderedList();
+    void endUnorderedList();
+
+    // Add list item to an unordered list
+    void addListItem(const std::string& item);
+
+    // Get the final HTML document as a string
+    auto str() const -> std::string;
+
+private:
+    std::ostringstream html_;  ///< String stream to build the HTML document.
+    bool inList_{};            ///< Tracks if currently inside a list.
+};
+
+HtmlBuilder::HtmlBuilder() {
+    html_ << "<html><body>\n";
+    LOG_F(INFO, "HtmlBuilder created with initial HTML structure.");
+}
+
+void HtmlBuilder::addTitle(const std::string& title) {
+    html_ << "<title>" << title << "</title>\n";
+    LOG_F(INFO, "Title added: {}", title);
+}
+
+void HtmlBuilder::addHeader(const std::string& header, int level) {
+    if (level < 1 || level > 6) {
+        level = 1;  // Default to <h1> if level is invalid
+    }
+    html_ << "<h" << level << ">" << header << "</h" << level << ">\n";
+    LOG_F(INFO, "Header added: {} at level %d", header, level);
+}
+
+void HtmlBuilder::addParagraph(const std::string& text) {
+    html_ << "<p>" << text << "</p>\n";
+    LOG_F(INFO, "Paragraph added: {}", text);
+}
+
+void HtmlBuilder::addList(const std::vector<std::string>& items) {
+    html_ << "<ul>\n";
+    for (const auto& item : items) {
+        html_ << "<li>" << item << "</li>\n";
+        LOG_F(INFO, "List item added: {}", item);
+    }
+    html_ << "</ul>\n";
+}
+
+void HtmlBuilder::startUnorderedList() {
+    if (!inList_) {
+        html_ << "<ul>\n";
+        inList_ = true;
+        LOG_F(INFO, "Unordered list started.");
+    }
+}
+
+void HtmlBuilder::endUnorderedList() {
+    if (inList_) {
+        html_ << "</ul>\n";
+        inList_ = false;
+        LOG_F(INFO, "Unordered list ended.");
+    }
+}
+
+void HtmlBuilder::addListItem(const std::string& item) {
+    if (inList_) {
+        html_ << "<li>" << item << "</li>\n";
+        LOG_F(INFO, "List item added inside unordered list: {}", item);
+    }
+}
+
+std::string HtmlBuilder::str() const {
+    std::ostringstream finalHtml;
+    finalHtml << html_.str() << "</body></html>\n";
+    LOG_F(INFO, "Final HTML document generated.");
+    return finalHtml.str();
+}
+
+// Message constructor
+Message::Message(MessageType t, std::string f, int l, int c, std::string code,
+                 std::string func, std::string msg, std::string ctx)
+    : type(t),
+      file(std::move(f)),
+      line(l),
+      column(c),
+      errorCode(std::move(code)),
+      functionName(std::move(func)),
+      message(std::move(msg)),
+      context(std::move(ctx)) {}
+
+/**
+ * @class CompilerOutputParser::Impl
+ * @brief The private implementation for CompilerOutputParser (PIMPL idiom).
+ */
+class CompilerOutputParser::Impl {
+public:
+    Impl() {
+        initRegexPatterns();
+        LOG_F(INFO, "CompilerOutputParser::Impl initialized.");
+    }
+
+    void parseLine(const std::string& line);
+    void parseFile(const std::string& filename);
+    void parseFileMultiThreaded(const std::string& filename, int numThreads);
+    auto getReport(bool detailed) const -> std::string;
+    void generateHtmlReport(const std::string& outputFilename) const;
+    auto generateJsonReport() -> json;
+    void setCustomRegexPattern(const std::string& compiler,
+                               const std::string& pattern);
+
+private:
+    std::vector<Message> messages_;
+    std::unordered_map<MessageType, std::atomic<int>> counts_;
+    mutable std::unordered_map<std::string, std::regex> regexPatterns_;
+    mutable std::mutex mutex_;
+    std::string currentContext_;
+    std::regex includePattern_{R"((.*):(\d+):(\d+):)"};
+    std::smatch includeMatch_;
+
+    void initRegexPatterns();
+    auto determineType(const std::string& typeStr) const -> MessageType;
+    auto toString(MessageType type) const -> std::string;
+};
+
+void CompilerOutputParser::Impl::parseLine(const std::string& line) {
+    LOG_F(INFO, "Parsing line: {}", line);
+
     if (std::regex_search(line, includeMatch_, includePattern_)) {
         currentContext_ = includeMatch_.str(1);
+        LOG_F(INFO, "Context updated: {}", currentContext_);
         return;
     }
 
@@ -25,6 +175,11 @@ void CompilerOutputParser::parseLine(const std::string& line) {
 
         std::string message = match.size() > 7 ?
match.str(7) : match.str(5); + LOG_F(INFO, + "Parsed message - File: {}, Line: %d, Column: %d, ErrorCode: " + "{}, FunctionName: {}, Message: {}", + file, lineNum, column, errorCode, functionName, message); + std::lock_guard lock(mutex_); messages_.emplace_back(type, file, lineNum, column, errorCode, functionName, message, currentContext_); @@ -37,11 +192,15 @@ void CompilerOutputParser::parseLine(const std::string& line) { messages_.emplace_back(MessageType::UNKNOWN, "", 0, 0, "", "", line, currentContext_); counts_[MessageType::UNKNOWN]++; + LOG_F(WARNING, "Unknown message parsed: {}", line); } -void CompilerOutputParser::parseFile(const std::string& filename) { +void CompilerOutputParser::Impl::parseFile(const std::string& filename) { + LOG_F(INFO, "Parsing file: {}", filename); + std::ifstream inputFile(filename); if (!inputFile.is_open()) { + LOG_F(ERROR, "Failed to open file: {}", filename); THROW_FAIL_TO_OPEN_FILE("Failed to open file: " + filename); } @@ -49,12 +208,18 @@ void CompilerOutputParser::parseFile(const std::string& filename) { while (std::getline(inputFile, line)) { parseLine(line); } + + LOG_F(INFO, "Completed parsing file: {}", filename); } -void CompilerOutputParser::parseFileMultiThreaded(const std::string& filename, - int numThreads) { +void CompilerOutputParser::Impl::parseFileMultiThreaded( + const std::string& filename, int numThreads) { + LOG_F(INFO, "Parsing file multithreaded: {} with %d threads", filename, + numThreads); + std::ifstream inputFile(filename); if (!inputFile.is_open()) { + LOG_F(ERROR, "Failed to open file: {}", filename); THROW_FAIL_TO_OPEN_FILE("Failed to open file: " + filename); } @@ -64,27 +229,36 @@ void CompilerOutputParser::parseFileMultiThreaded(const std::string& filename, lines.push_back(line); } - std::vector threads; - auto worker = [this](const std::vector& lines, int start, - int end) { - for (int i = start; i < end; ++i) { - parseLine(lines[i]); + std::vector threads; + auto worker = [this](std::span lines) { + for (const auto& line : lines) { + parseLine(line); } }; int blockSize = lines.size() / numThreads; for (int i = 0; i < numThreads; ++i) { - int start = i * blockSize; - int end = (i == numThreads - 1) ? lines.size() : (i + 1) * blockSize; - threads.emplace_back(worker, std::cref(lines), start, end); + auto start = lines.begin() + i * blockSize; + auto end = (i == numThreads - 1) ? lines.end() : start + blockSize; + threads.emplace_back(worker, std::span(start, end)); + LOG_F(INFO, "Thread %d started processing lines [{} - {}]", i, + start - lines.begin(), end - lines.begin()); } + // Join the threads once processing is complete for (auto& thread : threads) { - thread.join(); + if (thread.joinable()) { + thread.join(); + } } + + LOG_F(INFO, "Multithreaded file parsing completed for file: {}", filename); } -auto CompilerOutputParser::getReport(bool detailed) const -> std::string { +auto CompilerOutputParser::Impl::getReport(bool detailed) const -> std::string { + LOG_F(INFO, "Generating report with detailed: {}", + detailed ? 
"true" : "false"); + std::ostringstream report; report << "Compiler Messages Report:\n"; report << "Errors: " << counts_.at(MessageType::ERROR) << "\n"; @@ -116,60 +290,69 @@ auto CompilerOutputParser::getReport(bool detailed) const -> std::string { } } + LOG_F(INFO, "Report generation completed."); return report.str(); } -void CompilerOutputParser::generateHtmlReport( +void CompilerOutputParser::Impl::generateHtmlReport( const std::string& outputFilename) const { - std::ofstream outputFile(outputFilename); - if (!outputFile.is_open()) { - THROW_FAIL_TO_OPEN_FILE("Failed to open output file: " + - outputFilename); - } + LOG_F(INFO, "Generating HTML report: {}", outputFilename); - outputFile << "\n"; - outputFile << "

<h1>Compiler Messages Report</h1>\n";
-    outputFile << "<ul>\n";
-    outputFile << "<li>Errors: " << counts_.at(MessageType::ERROR) << "</li>\n";
-    outputFile << "<li>Warnings: " << counts_.at(MessageType::WARNING)
-               << "</li>\n";
-    outputFile << "<li>Notes: " << counts_.at(MessageType::NOTE) << "</li>\n";
-    outputFile << "<li>Unknown: " << counts_.at(MessageType::UNKNOWN)
-               << "</li>\n";
-    outputFile << "</ul>\n";
+    HtmlBuilder builder;
+    builder.addTitle("Compiler Messages Report");
+    builder.addHeader("Compiler Messages Report", 1);
 
-    outputFile << "<h2>Details</h2>\n";
-    outputFile << "<ul>\n";
+    builder.addHeader("Summary", 2);
+    builder.startUnorderedList();
+    builder.addListItem("Errors: " +
+                        std::to_string(counts_.at(MessageType::ERROR)));
+    builder.addListItem("Warnings: " +
+                        std::to_string(counts_.at(MessageType::WARNING)));
+    builder.addListItem("Notes: " +
+                        std::to_string(counts_.at(MessageType::NOTE)));
+    builder.addListItem("Unknown: " +
+                        std::to_string(counts_.at(MessageType::UNKNOWN)));
+    builder.endUnorderedList();
+
+    builder.addHeader("Details", 2);
+    builder.startUnorderedList();
     for (const auto& msg : messages_) {
-        outputFile << "<li>[" << toString(msg.type) << "] ";
+        std::string messageStr = "[" + toString(msg.type) + "] ";
         if (!msg.file.empty()) {
-            outputFile << msg.file << ":" << msg.line << ":" << msg.column
-                       << ": ";
+            messageStr += msg.file + ":" + std::to_string(msg.line) + ":" +
+                          std::to_string(msg.column) + ": ";
         }
         if (!msg.errorCode.empty()) {
-            outputFile << msg.errorCode << " ";
+            messageStr += msg.errorCode + " ";
         }
         if (!msg.functionName.empty()) {
-            outputFile << msg.functionName << " ";
-        }
-        outputFile << msg.message << "</li>\n";
-        if (!msg.context.empty()) {
-            outputFile << "<li>Context: " << msg.context << "</li>\n";
-        }
-        for (const auto& note : msg.relatedNotes) {
-            outputFile << "<li>Note: " << note << "</li>\n";
+            messageStr += msg.functionName + " ";
         }
+        messageStr += msg.message;
+
+        builder.addListItem(messageStr);
     }
-    outputFile << "</ul>
    \n"; - outputFile << "\n"; + builder.endUnorderedList(); + + std::ofstream outputFile(outputFilename); + if (!outputFile.is_open()) { + LOG_F(ERROR, "Failed to open output file: {}", outputFilename); + THROW_FAIL_TO_OPEN_FILE("Failed to open output file: " + + outputFilename); + } + + outputFile << builder.str(); + LOG_F(INFO, "HTML report generated and saved to: {}", outputFilename); } -auto CompilerOutputParser::generateJsonReport() -> json { +auto CompilerOutputParser::Impl::generateJsonReport() -> json { + LOG_F(INFO, "Generating JSON report."); + json root; - root["Errors"] = counts_[MessageType::ERROR]; - root["Warnings"] = counts_[MessageType::WARNING]; - root["Notes"] = counts_[MessageType::NOTE]; - root["Unknown"] = counts_[MessageType::UNKNOWN]; + root["Errors"] = counts_.at(MessageType::ERROR).load(); + root["Warnings"] = counts_.at(MessageType::WARNING).load(); + root["Notes"] = counts_.at(MessageType::NOTE).load(); + root["Unknown"] = counts_.at(MessageType::UNKNOWN).load(); for (const auto& msg : messages_) { json entry; @@ -187,16 +370,20 @@ auto CompilerOutputParser::generateJsonReport() -> json { root["Details"].push_back(entry); } + LOG_F(INFO, "JSON report generation completed."); return root; } -void CompilerOutputParser::setCustomRegexPattern(const std::string& compiler, - const std::string& pattern) { +void CompilerOutputParser::Impl::setCustomRegexPattern( + const std::string& compiler, const std::string& pattern) { + LOG_F(INFO, "Setting custom regex pattern for compiler: {}", compiler); std::lock_guard lock(mutex_); regexPatterns_[compiler] = std::regex(pattern); } -void CompilerOutputParser::initRegexPatterns() { +void CompilerOutputParser::Impl::initRegexPatterns() { + LOG_F(INFO, "Initializing regex patterns for supported compilers."); + regexPatterns_["gcc_clang"] = std::regex(R"((.*):(\d+):(\d+): (error|warning|note): (.*))"); regexPatterns_["msvc"] = @@ -205,7 +392,7 @@ void CompilerOutputParser::initRegexPatterns() { std::regex(R"((.*)\((\d+)\): (error|remark|warning|note): (.*))"); } -auto CompilerOutputParser::determineType(const std::string& typeStr) const +auto CompilerOutputParser::Impl::determineType(const std::string& typeStr) const -> MessageType { if (typeStr == "error") { return MessageType::ERROR; @@ -219,7 +406,8 @@ auto CompilerOutputParser::determineType(const std::string& typeStr) const return MessageType::UNKNOWN; } -auto CompilerOutputParser::toString(MessageType type) const -> std::string { +auto CompilerOutputParser::Impl::toString(MessageType type) const + -> std::string { switch (type) { case MessageType::ERROR: return "Error"; @@ -232,4 +420,55 @@ auto CompilerOutputParser::toString(MessageType type) const -> std::string { } } +CompilerOutputParser::CompilerOutputParser() : pImpl(std::make_unique()) { + LOG_F(INFO, "CompilerOutputParser created."); +} + +CompilerOutputParser::~CompilerOutputParser() = default; + +CompilerOutputParser::CompilerOutputParser(CompilerOutputParser&&) noexcept = + default; +CompilerOutputParser& CompilerOutputParser::operator=( + CompilerOutputParser&&) noexcept = default; + +void CompilerOutputParser::parseLine(const std::string& line) { + LOG_F(INFO, "Parsing single line."); + pImpl->parseLine(line); +} + +void CompilerOutputParser::parseFile(const std::string& filename) { + LOG_F(INFO, "Parsing file: {}", filename); + pImpl->parseFile(filename); +} + +void CompilerOutputParser::parseFileMultiThreaded(const std::string& filename, + int numThreads) { + LOG_F(INFO, "Parsing file {} with multithreading (%d 
threads)", filename, + numThreads); + pImpl->parseFileMultiThreaded(filename, numThreads); +} + +auto CompilerOutputParser::getReport(bool detailed) const -> std::string { + LOG_F(INFO, "Requesting report with detailed option: {}", + detailed ? "true" : "false"); + return pImpl->getReport(detailed); +} + +void CompilerOutputParser::generateHtmlReport( + const std::string& outputFilename) const { + LOG_F(INFO, "Generating HTML report to file: {}", outputFilename); + pImpl->generateHtmlReport(outputFilename); +} + +auto CompilerOutputParser::generateJsonReport() -> json { + LOG_F(INFO, "Generating JSON report."); + return pImpl->generateJsonReport(); +} + +void CompilerOutputParser::setCustomRegexPattern(const std::string& compiler, + const std::string& pattern) { + LOG_F(INFO, "Setting custom regex pattern for compiler: {}", compiler); + pImpl->setCustomRegexPattern(compiler, pattern); +} + } // namespace lithium diff --git a/src/addon/analysts.hpp b/src/addon/analysts.hpp index 3cfb0304..94a23e1f 100644 --- a/src/addon/analysts.hpp +++ b/src/addon/analysts.hpp @@ -1,20 +1,27 @@ #ifndef LITHIUM_ADDON_COMPILER_ANALYSIS_HPP #define LITHIUM_ADDON_COMPILER_ANALYSIS_HPP -#include -#include +#include #include -#include #include #include "atom/type/json_fwd.hpp" +using json = nlohmann::json; -namespace lithium { +#include "macro.hpp" -using json = nlohmann::json; +namespace lithium { +/** + * @enum MessageType + * @brief Represents the type of a compiler message. + */ enum class MessageType { ERROR, WARNING, NOTE, UNKNOWN }; +/** + * @struct Message + * @brief Holds information about a single compiler message. + */ struct Message { MessageType type; std::string file; @@ -28,11 +35,27 @@ struct Message { Message(MessageType t, std::string f, int l, int c, std::string code, std::string func, std::string msg, std::string ctx); -}; +} ATOM_ALIGNAS(128); +/** + * @class CompilerOutputParser + * @brief Parses compiler output and generates reports. + * + * Uses regular expressions to parse compiler messages from various compilers, + * supports both single-threaded and multi-threaded parsing. 
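+ *
+ * Parsing state is held in a private Impl object (PIMPL) and guarded by a
+ * mutex, so parseFileMultiThreaded() can safely feed lines to parseLine()
+ * from several worker threads.
+ *
+ * Minimal illustrative usage:
+ * @code
+ * lithium::CompilerOutputParser parser;
+ * parser.parseFile("build.log");
+ * std::string report = parser.getReport(true);
+ * @endcode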
+ */ class CompilerOutputParser { public: CompilerOutputParser(); + ~CompilerOutputParser(); + + // Delete copy constructor and copy assignment to avoid copying state + CompilerOutputParser(const CompilerOutputParser&) = delete; + CompilerOutputParser& operator=(const CompilerOutputParser&) = delete; + + // Enable move semantics + CompilerOutputParser(CompilerOutputParser&&) noexcept; + CompilerOutputParser& operator=(CompilerOutputParser&&) noexcept; void parseLine(const std::string& line); void parseFile(const std::string& filename); @@ -44,19 +67,10 @@ class CompilerOutputParser { const std::string& pattern); private: - std::vector messages_; - std::unordered_map counts_; - mutable std::unordered_map regexPatterns_; - mutable std::mutex mutex_; - std::string currentContext_; - std::regex includePattern_; - std::smatch includeMatch_; - - void initRegexPatterns(); - MessageType determineType(const std::string& typeStr) const; - std::string toString(MessageType type) const; + class Impl; // Forward declaration of implementation class + std::unique_ptr pImpl; // PIMPL idiom, pointer to implementation }; } // namespace lithium -#endif +#endif // LITHIUM_ADDON_COMPILER_ANALYSIS_HPP diff --git a/src/addon/builder.cpp b/src/addon/builder.cpp index 47568a37..22d3981f 100644 --- a/src/addon/builder.cpp +++ b/src/addon/builder.cpp @@ -6,6 +6,7 @@ #include "atom/error/exception.hpp" namespace lithium { + BuildManager::BuildManager(BuildSystemType type) { switch (type) { case BuildSystemType::CMake: @@ -20,35 +21,69 @@ BuildManager::BuildManager(BuildSystemType type) { } auto BuildManager::configureProject( - const std::string &sourceDir, const std::string &buildDir, - const std::string &buildType, - const std::vector &options) -> bool { + const std::filesystem::path& sourceDir, + const std::filesystem::path& buildDir, BuildType buildType, + const std::vector& options) -> BuildResult { return builder_->configureProject(sourceDir, buildDir, buildType, options); } -auto BuildManager::buildProject(const std::string &buildDir, int jobs) -> bool { +auto BuildManager::buildProject(const std::filesystem::path& buildDir, + std::optional jobs) -> BuildResult { return builder_->buildProject(buildDir, jobs); } -auto BuildManager::cleanProject(const std::string &buildDir) -> bool { +auto BuildManager::cleanProject(const std::filesystem::path& buildDir) + -> BuildResult { return builder_->cleanProject(buildDir); } -auto BuildManager::installProject(const std::string &buildDir, - const std::string &installDir) -> bool { +auto BuildManager::installProject(const std::filesystem::path& buildDir, + const std::filesystem::path& installDir) + -> BuildResult { return builder_->installProject(buildDir, installDir); } -auto BuildManager::runTests(const std::string &buildDir) -> bool { - return builder_->runTests(buildDir); +auto BuildManager::runTests(const std::filesystem::path& buildDir, + const std::vector& testNames) + -> BuildResult { + return builder_->runTests(buildDir, testNames); } -auto BuildManager::generateDocs(const std::string &buildDir) -> bool { - return builder_->generateDocs(buildDir); +auto BuildManager::generateDocs(const std::filesystem::path& buildDir, + const std::filesystem::path& outputDir) + -> BuildResult { + return builder_->generateDocs(buildDir, outputDir); } -auto BuildManager::loadConfig(const std::string &configPath) -> bool { +auto BuildManager::loadConfig(const std::filesystem::path& configPath) -> bool { return builder_->loadConfig(configPath); } +auto BuildManager::setLogCallback( + 
std::function callback) -> void { + builder_->setLogCallback(std::move(callback)); +} + +auto BuildManager::getAvailableTargets(const std::filesystem::path& buildDir) + -> std::vector { + return builder_->getAvailableTargets(buildDir); +} + +auto BuildManager::buildTarget(const std::filesystem::path& buildDir, + const std::string& target, + std::optional jobs) -> BuildResult { + return builder_->buildTarget(buildDir, target, jobs); +} + +auto BuildManager::getCacheVariables(const std::filesystem::path& buildDir) + -> std::vector> { + return builder_->getCacheVariables(buildDir); +} + +auto BuildManager::setCacheVariable(const std::filesystem::path& buildDir, + const std::string& name, + const std::string& value) -> bool { + return builder_->setCacheVariable(buildDir, name, value); +} + } // namespace lithium diff --git a/src/addon/builder.hpp b/src/addon/builder.hpp index 844162a0..bf74beb1 100644 --- a/src/addon/builder.hpp +++ b/src/addon/builder.hpp @@ -1,31 +1,64 @@ #ifndef LITHIUM_ADDON_BUILDER_HPP #define LITHIUM_ADDON_BUILDER_HPP +#include +#include #include +#include #include "platform/base.hpp" namespace lithium { + class BuildManager { public: enum class BuildSystemType { CMake, Meson }; BuildManager(BuildSystemType type); - auto configureProject(const std::string &sourceDir, - const std::string &buildDir, - const std::string &buildType, - const std::vector &options) -> bool; - auto buildProject(const std::string &buildDir, int jobs) -> bool; - auto cleanProject(const std::string &buildDir) -> bool; - auto installProject(const std::string &buildDir, - const std::string &installDir) -> bool; - auto runTests(const std::string &buildDir) -> bool; - auto generateDocs(const std::string &buildDir) -> bool; - auto loadConfig(const std::string &configPath) -> bool; + + auto configureProject( + const std::filesystem::path& sourceDir, + const std::filesystem::path& buildDir, BuildType buildType, + const std::vector& options) -> BuildResult; + + auto buildProject(const std::filesystem::path& buildDir, + std::optional jobs = std::nullopt) -> BuildResult; + + auto cleanProject(const std::filesystem::path& buildDir) -> BuildResult; + + auto installProject(const std::filesystem::path& buildDir, + const std::filesystem::path& installDir) -> BuildResult; + + auto runTests(const std::filesystem::path& buildDir, + const std::vector& testNames = {}) + -> BuildResult; + + auto generateDocs(const std::filesystem::path& buildDir, + const std::filesystem::path& outputDir) -> BuildResult; + + auto loadConfig(const std::filesystem::path& configPath) -> bool; + + auto setLogCallback(std::function callback) + -> void; + + auto getAvailableTargets(const std::filesystem::path& buildDir) + -> std::vector; + + auto buildTarget(const std::filesystem::path& buildDir, + const std::string& target, + std::optional jobs = std::nullopt) -> BuildResult; + + auto getCacheVariables(const std::filesystem::path& buildDir) + -> std::vector>; + + auto setCacheVariable(const std::filesystem::path& buildDir, + const std::string& name, + const std::string& value) -> bool; private: std::unique_ptr builder_; }; + } // namespace lithium -#endif +#endif // LITHIUM_ADDON_BUILDER_HPP diff --git a/src/addon/command.cpp b/src/addon/command.cpp index 6889d4a9..ac79d6c0 100644 --- a/src/addon/command.cpp +++ b/src/addon/command.cpp @@ -1,14 +1,19 @@ #include "command.hpp" +#include #include #include -#include +#include #include -#include -#include +#include +#include "atom/log/loguru.hpp" +#include "atom/type/json.hpp" using json = 
nlohmann::json; +#include "macro.hpp" + +namespace lithium { struct CompileCommand { std::string directory; std::string command; @@ -18,7 +23,13 @@ struct CompileCommand { return json{ {"directory", directory}, {"command", command}, {"file", file}}; } -}; + + void fromJson(const json& j) { + directory = j["directory"].get(); + command = j["command"].get(); + file = j["file"].get(); + } +} ATOM_ALIGNAS(128); struct CompileCommandGenerator::Impl { std::string sourceDir = "./src"; @@ -31,8 +42,12 @@ struct CompileCommandGenerator::Impl { std::string projectName = "MyProject"; std::string projectVersion = "1.0.0"; std::mutex output_mutex; + std::atomic commandCounter{0}; - std::vector getSourceFiles() { + // 获取源文件 + auto getSourceFiles() -> std::vector { + LOG_F(INFO, "Scanning source directory: {}", + sourceDir); // Log scanning process std::vector sourceFiles; for (const auto& entry : std::filesystem::directory_iterator(sourceDir)) { @@ -41,30 +56,42 @@ struct CompileCommandGenerator::Impl { for (const auto& ext : extensions) { if (path.extension() == ext) { sourceFiles.push_back(path.string()); + LOG_F(INFO, "Found source file: {}", + path.string()); // Log found file } } } } + LOG_F(INFO, "Total source files found: {}", sourceFiles.size()); return sourceFiles; } - std::vector parseExistingCommands() { + [[nodiscard]] auto parseExistingCommands() const + -> std::vector { std::vector commands; - std::ifstream ifs(existingCommandsPath); + if (existingCommandsPath.empty() || + !std::filesystem::exists(existingCommandsPath)) { + LOG_F(WARNING, "No existing compile commands found at {}", + existingCommandsPath); + return commands; + } + LOG_F(INFO, "Parsing existing compile commands from {}", + existingCommandsPath); + std::ifstream ifs(existingCommandsPath); if (ifs.is_open()) { json j; ifs >> j; for (const auto& cmd : j["commands"]) { - commands.emplace_back(cmd["directory"].get(), - cmd["command"].get(), - cmd["file"].get()); + CompileCommand c; + c.fromJson(cmd); + commands.push_back(c); } ifs.close(); + LOG_F(INFO, "Parsed {} existing compile commands", commands.size()); } else { - std::cerr << "Could not open " << existingCommandsPath << std::endl; + LOG_F(ERROR, "Failed to open {}", existingCommandsPath); } - return commands; } @@ -73,94 +100,114 @@ struct CompileCommandGenerator::Impl { compiler + " " + includeFlag + " " + outputFlag + " " + file; CompileCommand cmd{sourceDir, command, file}; + LOG_F(INFO, "Generating compile command for file: {}", file); std::lock_guard lock(output_mutex); j_commands.push_back(cmd.toJson()); + int currentCount = + commandCounter.fetch_add(1, std::memory_order_relaxed) + 1; + LOG_F(INFO, "Total commands generated so far: {}", currentCount); } void saveCommandsToFile(const json& j) { + LOG_F(INFO, "Saving compile commands to file: {}", outputPath); std::ofstream ofs(outputPath); if (ofs.is_open()) { ofs << j.dump(4); ofs.close(); - std::cout << "compile_commands.json generated successfully at " - << outputPath << "." << std::endl; + LOG_F(INFO, + "compile_commands.json generated successfully with {} " + "commands at {}.", + commandCounter.load(std::memory_order_relaxed), + outputPath); // Log success } else { - std::cerr << "Failed to create compile_commands.json." 
<< std::endl; + LOG_F(ERROR, "Failed to open {} for writing.", outputPath); } } -}; +} ATOM_ALIGNAS(128); -CompileCommandGenerator::CompileCommandGenerator() : pImpl(new Impl) {} +CompileCommandGenerator::CompileCommandGenerator() + : impl_(std::make_unique()) {} -CompileCommandGenerator::~CompileCommandGenerator() { delete pImpl; } +CompileCommandGenerator::~CompileCommandGenerator() = default; void CompileCommandGenerator::setSourceDir(const std::string& dir) { - pImpl->sourceDir = dir; + LOG_F(INFO, "Setting source directory to {}", + dir); // Log set source directory + impl_->sourceDir = dir; } void CompileCommandGenerator::setCompiler(const std::string& compiler) { - pImpl->compiler = compiler; + LOG_F(INFO, "Setting compiler to {}", compiler); + impl_->compiler = compiler; } void CompileCommandGenerator::setIncludeFlag(const std::string& flag) { - pImpl->includeFlag = flag; + LOG_F(INFO, "Setting include flag to {}", flag); + impl_->includeFlag = flag; } void CompileCommandGenerator::setOutputFlag(const std::string& flag) { - pImpl->outputFlag = flag; + LOG_F(INFO, "Setting output flag to {}", flag); + impl_->outputFlag = flag; } void CompileCommandGenerator::setProjectName(const std::string& name) { - pImpl->projectName = name; + LOG_F(INFO, "Setting project name to {}", name); + impl_->projectName = name; } void CompileCommandGenerator::setProjectVersion(const std::string& version) { - pImpl->projectVersion = version; + LOG_F(INFO, "Setting project version to {}", + version); // Log set project version + impl_->projectVersion = version; } void CompileCommandGenerator::addExtension(const std::string& ext) { - pImpl->extensions.push_back(ext); + LOG_F(INFO, "Adding file extension: {}", ext); + impl_->extensions.push_back(ext); } void CompileCommandGenerator::setOutputPath(const std::string& path) { - pImpl->outputPath = path; + LOG_F(INFO, "Setting output path to {}", path); + impl_->outputPath = path; } void CompileCommandGenerator::setExistingCommandsPath(const std::string& path) { - pImpl->existingCommandsPath = path; + LOG_F(INFO, "Setting existing commands path to {}", + path); // Log set existing commands path + impl_->existingCommandsPath = path; } void CompileCommandGenerator::generate() { + LOG_F(INFO, "Starting compile command generation"); std::vector commands; // 解析现有的 compile_commands.json - if (!pImpl->existingCommandsPath.empty()) { - auto existing_commands = pImpl->parseExistingCommands(); - commands.insert(commands.end(), existing_commands.begin(), - existing_commands.end()); + if (!impl_->existingCommandsPath.empty()) { + auto existingCommands = impl_->parseExistingCommands(); + commands.insert(commands.end(), existingCommands.begin(), + existingCommands.end()); } - auto source_files = pImpl->getSourceFiles(); - json j_commands = json::array(); + auto sourceFiles = impl_->getSourceFiles(); + json jCommands = json::array(); - // 多线程处理构建命令 - std::vector threads; - for (const auto& file : source_files) { - threads.emplace_back(&Impl::generateCompileCommand, pImpl, file, - std::ref(j_commands)); - } - - // 等待所有线程完成 - for (auto& thread : threads) { - thread.join(); - } + LOG_F(INFO, "Generating compile commands for {} source files", + sourceFiles.size()); // Log number of files + std::ranges::for_each(sourceFiles.begin(), sourceFiles.end(), + [&](const std::string& file) { + impl_->generateCompileCommand(file, jCommands); + }); // 构建最终的 JSON json j = {{"version", 4}, - {"project_name", pImpl->projectName}, - {"project_version", pImpl->projectVersion}, - {"commands", 
j_commands}}; + {"project_name", impl_->projectName}, + {"project_version", impl_->projectVersion}, + {"commands", jCommands}}; // 保存到文件 - pImpl->saveCommandsToFile(j); + impl_->saveCommandsToFile(j); + LOG_F(INFO, "Compile command generation complete"); } + +} // namespace lithium diff --git a/src/addon/command.hpp b/src/addon/command.hpp index 2e9deacd..54ec1598 100644 --- a/src/addon/command.hpp +++ b/src/addon/command.hpp @@ -1,14 +1,19 @@ #ifndef COMPILE_COMMAND_GENERATOR_H #define COMPILE_COMMAND_GENERATOR_H +#include #include -#include +#include "atom/type/json.hpp" +using json = nlohmann::json; + +namespace lithium { class CompileCommandGenerator { public: CompileCommandGenerator(); ~CompileCommandGenerator(); + void setSourceDir(const std::string& dir); void setCompiler(const std::string& compiler); void setIncludeFlag(const std::string& flag); @@ -22,8 +27,9 @@ class CompileCommandGenerator { void generate(); private: - struct Impl; // Forward declaration of the implementation class - Impl* pImpl; // Pointer to the implementation + struct Impl; + std::unique_ptr impl_; }; +} // namespace lithium #endif // COMPILE_COMMAND_GENERATOR_H diff --git a/src/addon/compiler.cpp b/src/addon/compiler.cpp index 39cb68e2..40b2f147 100644 --- a/src/addon/compiler.cpp +++ b/src/addon/compiler.cpp @@ -1,11 +1,11 @@ #include "compiler.hpp" +#include "command.hpp" #include "io/io.hpp" #include "toolchain.hpp" #include "utils/constant.hpp" #include -#include #include #include #include "atom/log/loguru.hpp" @@ -30,11 +30,14 @@ class CompilerImpl { auto getAvailableCompilers() const -> std::vector; + void generateCompileCommands(const std::string &sourceDir); + private: void createOutputDirectory(const fs::path &outputDir); - auto syntaxCheck(std::string_view code, std::string_view compiler) -> bool; - auto compileCode(std::string_view code, std::string_view compiler, - std::string_view compileOptions, + auto syntaxCheck(std::string_view code, + const std::string &compiler) -> bool; + auto compileCode(std::string_view code, const std::string &compiler, + const std::string &compileOptions, const fs::path &output) -> bool; auto findAvailableCompilers() -> std::vector; @@ -42,6 +45,7 @@ class CompilerImpl { std::unordered_map cache_; std::string customCompileOptions_; ToolchainManager toolchainManager_; + std::unique_ptr compileCommandGenerator_; }; Compiler::Compiler() : impl_(std::make_unique()) {} @@ -64,19 +68,42 @@ auto Compiler::getAvailableCompilers() const -> std::vector { return impl_->getAvailableCompilers(); } -CompilerImpl::CompilerImpl() { toolchainManager_.scanForToolchains(); } +CompilerImpl::CompilerImpl() + : compileCommandGenerator_(std::make_unique()) { + LOG_F(INFO, "Initializing CompilerImpl..."); + toolchainManager_.scanForToolchains(); + LOG_F(INFO, "Toolchains scanned."); + compileCommandGenerator_->setCompiler(Constants::COMPILER); + compileCommandGenerator_->setIncludeFlag("-I./include"); + compileCommandGenerator_->setOutputFlag("-o output"); + compileCommandGenerator_->addExtension(".cpp"); + compileCommandGenerator_->addExtension(".c"); + LOG_F(INFO, "CompileCommandGenerator initialized with default settings."); +} + +void CompilerImpl::generateCompileCommands(const std::string &sourceDir) { + LOG_F(INFO, "Generating compile commands for source directory: {}", + sourceDir); + compileCommandGenerator_->setSourceDir(sourceDir); + compileCommandGenerator_->setOutputPath("compile_commands.json"); + compileCommandGenerator_->generate(); + LOG_F(INFO, "Compile commands generation 
complete for directory: {}", + sourceDir); +} auto CompilerImpl::compileToSharedLibrary( std::string_view code, std::string_view moduleName, std::string_view functionName, std::string_view optionsFile) -> bool { - LOG_F(INFO, "Compiling module {}::{}...", moduleName, functionName); + LOG_F(INFO, "Compiling module {}::{} with options file: {}", moduleName, + functionName, optionsFile); if (code.empty() || moduleName.empty() || functionName.empty()) { - LOG_F(ERROR, "Invalid parameters."); + LOG_F( + ERROR, + "Invalid parameters: code, moduleName, or functionName is empty."); return false; } - // 检查模块是否已编译并缓存 std::string cacheKey = std::format("{}::{}", moduleName, functionName); if (cache_.find(cacheKey) != cache_.end()) { LOG_F(WARNING, "Module {} is already compiled, using cached result.", @@ -84,11 +111,9 @@ auto CompilerImpl::compileToSharedLibrary( return true; } - // 创建输出目录 const fs::path OUTPUT_DIR = "atom/global"; createOutputDirectory(OUTPUT_DIR); - // Max: 检查可用的编译器,然后再和指定的进行比对 auto availableCompilers = findAvailableCompilers(); if (availableCompilers.empty()) { LOG_F(ERROR, "No available compilers found."); @@ -97,63 +122,101 @@ auto CompilerImpl::compileToSharedLibrary( LOG_F(INFO, "Available compilers: {}", atom::utils::toString(availableCompilers)); - // 读取编译选项 + // Read compile options std::ifstream optionsStream(optionsFile.data()); - std::string compileOptions = [&] -> std::string { - if (!optionsStream) { - LOG_F( - WARNING, - "Failed to open compile options file, using default options."); - return "-O2 -std=c++20 -Wall -shared -fPIC"; - } - + json optionsJson; + if (!optionsStream) { + LOG_F(WARNING, + "Failed to open compile options file {}, using default options.", + optionsFile); + optionsJson = {{"compiler", Constants::COMPILER}, + {"optimization_level", "-O2"}, + {"cplus_version", "-std=c++20"}, + {"warnings", "-Wall"}}; + } else { try { - json optionsJson; optionsStream >> optionsJson; - - auto compiler = optionsJson.value("compiler", constants::COMPILER); - if (std::find(availableCompilers.begin(), availableCompilers.end(), - compiler) == availableCompilers.end()) { - LOG_F(WARNING, "Compiler {} is not available, using default.", - compiler); - compiler = constants::COMPILER; - } - - auto cmd = std::format( - "{} {} {} {} {}", compiler, - optionsJson.value("optimization_level", "-O2"), - optionsJson.value("cplus_version", "-std=c++20"), - optionsJson.value("warnings", "-Wall"), customCompileOptions_); - - LOG_F(INFO, "Compile options: {}", cmd); - return cmd; - + LOG_F(INFO, "Compile options file {} successfully parsed.", + optionsFile); } catch (const json::parse_error &e) { - LOG_F(ERROR, "Failed to parse compile options file: {}", e.what()); - } catch (const std::exception &e) { - LOG_F(ERROR, "Failed to parse compile options file: {}", e.what()); + LOG_F(ERROR, "Failed to parse compile options file {}: {}", + optionsFile, e.what()); + return false; } + } - return constants::COMPILER + - std::string{"-O2 -std=c++20 -Wall -shared -fPIC"}; - }(); + std::string compiler = optionsJson.value("compiler", Constants::COMPILER); + if (std::find(availableCompilers.begin(), availableCompilers.end(), + compiler) == availableCompilers.end()) { + LOG_F(WARNING, + "Compiler {} is not available, using default compiler {}.", + compiler, Constants::COMPILER); + compiler = Constants::COMPILER; + } - // 语法检查 - if (!syntaxCheck(code, constants::COMPILER)) { - return false; + // Use CompileCommandGenerator to generate compile command + compileCommandGenerator_->setCompiler(compiler); 
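For reference, the options file consumed here only needs the keys read via optionsJson.value(): "compiler", "optimization_level", "cplus_version" and "warnings", plus the "include_flag" and "output_flag" keys read just below. A minimal sketch of producing such a file with the same nlohmann::json API this patch already uses (the concrete values are placeholders, not taken from the patch):

    #include <fstream>
    #include "atom/type/json.hpp"   // nlohmann::json, as used elsewhere in this patch
    using json = nlohmann::json;

    // Hypothetical compile_options.json payload; "g++" is only an example value.
    const json exampleOptions = {{"compiler", "g++"},
                                 {"optimization_level", "-O2"},
                                 {"cplus_version", "-std=c++20"},
                                 {"warnings", "-Wall"},
                                 {"include_flag", "-I./include"},
                                 {"output_flag", "-o output"}};
    std::ofstream("compile_options.json") << exampleOptions.dump(4);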
+ compileCommandGenerator_->setIncludeFlag( + optionsJson.value("include_flag", "-I./include")); + compileCommandGenerator_->setOutputFlag( + optionsJson.value("output_flag", "-o output")); + compileCommandGenerator_->setProjectName(std::string(moduleName)); + compileCommandGenerator_->setProjectVersion("1.0.0"); + + // Temporarily create a file to store the code + fs::path tempSourceFile = fs::temp_directory_path() / "temp_code.cpp"; + { + std::ofstream tempFile(tempSourceFile); + tempFile << code; } + LOG_F(INFO, "Temporary source file created at: {}", + tempSourceFile.string()); - // 编译代码 - fs::path outputPath = - OUTPUT_DIR / std::format("{}{}{}", constants::LIB_EXTENSION, moduleName, - constants::LIB_EXTENSION); - if (!compileCode(code, constants::COMPILER, compileOptions, outputPath)) { - return false; + compileCommandGenerator_->setSourceDir( + tempSourceFile.parent_path().string()); + compileCommandGenerator_->setOutputPath(OUTPUT_DIR / + "compile_commands.json"); + compileCommandGenerator_->generate(); + + // Read generated compile_commands.json + json compileCommands; + { + std::ifstream commandsFile(OUTPUT_DIR / "compile_commands.json"); + commandsFile >> compileCommands; } + LOG_F(INFO, "Compile commands file read from: {}", + (OUTPUT_DIR / "compile_commands.json").string()); + + // Use the generated command to compile + if (!compileCommands["commands"].empty()) { + auto command = + compileCommands["commands"][0]["command"].get(); + command += " " + customCompileOptions_; + + fs::path outputPath = + OUTPUT_DIR / std::format("{}{}{}", Constants::LIB_EXTENSION, + moduleName, Constants::LIB_EXTENSION); + command += " -o " + outputPath.string(); + + LOG_F(INFO, "Executing compilation command: {}", command); + std::string compilationOutput = atom::system::executeCommand(command); + if (!compilationOutput.empty()) { + LOG_F(ERROR, "Compilation failed:\n{}", compilationOutput); + fs::remove(tempSourceFile); + return false; + } - // 缓存编译结果 - cache_[cacheKey] = outputPath; - return true; + // Cache the compilation result + cache_[cacheKey] = outputPath; + LOG_F(INFO, "Compilation successful, result cached with key: {}", + cacheKey); + fs::remove(tempSourceFile); + return true; + } + + LOG_F(ERROR, "Failed to generate compile command."); + fs::remove(tempSourceFile); + return false; } void CompilerImpl::createOutputDirectory(const fs::path &outputDir) { @@ -161,78 +224,147 @@ void CompilerImpl::createOutputDirectory(const fs::path &outputDir) { LOG_F(WARNING, "Output directory {} does not exist, creating it.", outputDir.string()); fs::create_directories(outputDir); + LOG_F(INFO, "Output directory {} created.", outputDir.string()); + } else { + LOG_F(INFO, "Output directory {} already exists.", outputDir.string()); } } auto CompilerImpl::syntaxCheck(std::string_view code, - std::string_view compiler) -> bool { - if (atom::io::isFileExists("temp_code.cpp")) { - if (!atom::io::removeFile("temp_code.cpp")) { - LOG_F(ERROR, "Failed to remove temp_code.cpp"); - return false; - } + const std::string &compiler) -> bool { + LOG_F(INFO, "Starting syntax check using compiler: {}", compiler); + compileCommandGenerator_->setCompiler(compiler); + compileCommandGenerator_->setIncludeFlag("-fsyntax-only"); + compileCommandGenerator_->setOutputFlag(""); + + fs::path tempSourceFile = fs::temp_directory_path() / "syntax_check.cpp"; + { + std::ofstream tempFile(tempSourceFile); + tempFile << code; } - // Create a temporary file to store the code - std::string tempFileName = "temp_code.cpp"; + LOG_F(INFO, 
"Temporary file for syntax check created at: {}", + tempSourceFile.string()); + + compileCommandGenerator_->setSourceDir( + tempSourceFile.parent_path().string()); + compileCommandGenerator_->setOutputPath(fs::temp_directory_path() / + "syntax_check_commands.json"); + compileCommandGenerator_->generate(); + + json syntaxCheckCommands; { - std::ofstream tempFile(tempFileName, std::ios::trunc); - if (!tempFile) { - LOG_F(ERROR, "Failed to create temporary file for code"); + std::ifstream commandsFile(fs::temp_directory_path() / + "syntax_check_commands.json"); + commandsFile >> syntaxCheckCommands; + } + LOG_F(INFO, "Syntax check commands file read."); + + if (!syntaxCheckCommands["commands"].empty()) { + auto command = + syntaxCheckCommands["commands"][0]["command"].get(); + LOG_F(INFO, "Executing syntax check command: {}", command); + std::string output = atom::system::executeCommand(command); + + fs::remove(tempSourceFile); + fs::remove(fs::temp_directory_path() / "syntax_check_commands.json"); + + if (!output.empty()) { + LOG_F(ERROR, "Syntax check failed:\n{}", output); return false; } - tempFile << code; + LOG_F(INFO, "Syntax check passed."); + return true; } + LOG_F(ERROR, "Failed to generate syntax check command."); + fs::remove(tempSourceFile); + fs::remove(fs::temp_directory_path() / "syntax_check_commands.json"); + return false; +} - // Format the command to invoke the compiler with syntax-only check - std::string command = - std::format("{} -fsyntax-only -x c++ {}", compiler, tempFileName); - std::string output; +auto CompilerImpl::compileCode(std::string_view code, + const std::string &compiler, + const std::string &compileOptions, + const fs::path &output) -> bool { + LOG_F(INFO, + "Starting compilation with compiler: {}, options: {}, output: {}", + compiler, compileOptions, output.string()); - // Execute the command and process output - output = atom::system::executeCommand( - command, - false, // No need for a shell - [&](const std::string &line) { output += line + "\n"; }); + // Use CompileCommandGenerator to generate the compile command + compileCommandGenerator_->setCompiler(compiler); + compileCommandGenerator_->setIncludeFlag(compileOptions); + compileCommandGenerator_->setOutputFlag("-o " + output.string()); - // Clean up temporary file - if (!atom::io::removeFile("temp_code.cpp")) { - LOG_F(ERROR, "Failed to remove temp_code.cpp"); - return false; + fs::path tempSourceFile = fs::temp_directory_path() / "compile_code.cpp"; + { + std::ofstream tempFile(tempSourceFile); + tempFile << code; } + LOG_F(INFO, "Temporary file for compilation created at: {}", + tempSourceFile.string()); - // Check the output for errors - if (!output.empty()) { - LOG_F(ERROR, "Syntax check failed:\n{}", output); - return false; + compileCommandGenerator_->setSourceDir( + tempSourceFile.parent_path().string()); + compileCommandGenerator_->setOutputPath(fs::temp_directory_path() / + "compile_code_commands.json"); + compileCommandGenerator_->generate(); + + json compileCodeCommands; + { + std::ifstream commandsFile(fs::temp_directory_path() / + "compile_code_commands.json"); + commandsFile >> compileCodeCommands; } - return true; -} + LOG_F(INFO, "Compile commands file read."); -auto CompilerImpl::compileCode(std::string_view code, std::string_view compiler, - std::string_view compileOptions, - const fs::path &output) -> bool { - std::string command = std::format("{} {} -xc++ - -o {}", compiler, - compileOptions, output.string()); - std::string compilationOutput; - compilationOutput = 
atom::system::executeCommandWithInput( - command, std::string(code), - [&](const std::string &line) { compilationOutput += line + "\n"; }); - if (!compilationOutput.empty()) { - LOG_F(ERROR, "Compilation failed:\n{}", compilationOutput); - return false; + if (!compileCodeCommands["commands"].empty()) { + auto command = + compileCodeCommands["commands"][0]["command"].get(); + LOG_F(INFO, "Executing compilation command: {}", command); + std::string compilationOutput = atom::system::executeCommand(command); + + fs::remove(tempSourceFile); + fs::remove(fs::temp_directory_path() / "compile_code_commands.json"); + + if (!compilationOutput.empty()) { + LOG_F(ERROR, "Compilation failed:\n{}", compilationOutput); + return false; + } + LOG_F(INFO, "Compilation successful, output file: {}", output.string()); + return true; } - return true; + + LOG_F(ERROR, "Failed to generate compile command."); + fs::remove(tempSourceFile); + fs::remove(fs::temp_directory_path() / "compile_code_commands.json"); + return false; } auto CompilerImpl::findAvailableCompilers() -> std::vector { - return toolchainManager_.getAvailableCompilers(); + LOG_F(INFO, "Finding available compilers..."); + auto compilers = toolchainManager_.getAvailableCompilers(); + if (compilers.empty()) { + LOG_F(WARNING, "No compilers found."); + } else { + LOG_F(INFO, "Found compilers: {}", atom::utils::toString(compilers)); + } + return compilers; } void CompilerImpl::addCompileOptions(const std::string &options) { + LOG_F(INFO, "Adding custom compile options: {}", options); customCompileOptions_ = options; } auto CompilerImpl::getAvailableCompilers() const -> std::vector { + LOG_F(INFO, "Retrieving available compilers..."); return toolchainManager_.getAvailableCompilers(); } + +void Compiler::generateCompileCommands(const std::string &sourceDir) { + LOG_F(INFO, + "Generating compile commands in Compiler for source directory: {}", + sourceDir); + impl_->generateCompileCommands(sourceDir); +} + } // namespace lithium diff --git a/src/addon/compiler.hpp b/src/addon/compiler.hpp index 618c831d..eda9cbe0 100644 --- a/src/addon/compiler.hpp +++ b/src/addon/compiler.hpp @@ -25,19 +25,20 @@ using json = nlohmann::json; namespace lithium { class CompilerImpl; - +class CompileCommandGenerator; class Compiler { public: Compiler(); ~Compiler(); /** - * 编译 C++ 代码为共享库,并加载到内存中 - * @param code 要编译的代码 - * @param moduleName 模块名 - * @param functionName 入口函数名 - * @param optionsFile 编译选项文件路径,默认为 "compile_options.json" - * @return 编译是否成功 + * Compile C++ code into a shared library and load it into memory + * @param code Code to compile + * @param moduleName Module name + * @param functionName Entry function name + * @param optionsFile Compilation options file path, default is + * "compile_options.json" + * @return Whether compilation was successful */ ATOM_NODISCARD auto compileToSharedLibrary( std::string_view code, std::string_view moduleName, @@ -45,18 +46,24 @@ class Compiler { std::string_view optionsFile = "compile_options.json") -> bool; /** - * 添加自定义编译选项 - * @param options 编译选项 + * Add custom compilation options + * @param options Compilation options */ - void addCompileOptions(const std::string &options); + void addCompileOptions(const std::string& options); /** - * 获取可用编译器列表 - * @return 编译器列表 + * Get list of available compilers + * @return List of compilers */ ATOM_NODISCARD auto getAvailableCompilers() const -> std::vector; + /** + * Generate compile commands for a given source directory + * @param sourceDir Source directory path + */ + void 
generateCompileCommands(const std::string& sourceDir); + private: std::unique_ptr impl_; }; diff --git a/src/addon/debug/elf.cpp b/src/addon/debug/elf.cpp index e5b21e8f..52613941 100644 --- a/src/addon/debug/elf.cpp +++ b/src/addon/debug/elf.cpp @@ -1,336 +1,236 @@ #include "elf.hpp" -#include -#include -#include -#include -#include +#include +#include +#include +#include -#include "atom/log/loguru.hpp" -#include "atom/macro.hpp" +#include "atom/error/exception.hpp" namespace lithium { -// Define your structures -struct ElfHeader { - uint16_t type; - uint16_t machine; - uint32_t version; - uint64_t entry; - uint64_t phoff; - uint64_t shoff; - uint32_t flags; - uint16_t ehsize; - uint16_t phentsize; - uint16_t phnum; - uint16_t shentsize; - uint16_t shnum; - uint16_t shstrndx; -} ATOM_ALIGNAS(64); - -struct ProgramHeader { - uint32_t type; - uint64_t offset; - uint64_t vaddr; - uint64_t paddr; - uint64_t filesz; - uint64_t memsz; - uint32_t flags; - uint64_t align; -} ATOM_ALIGNAS(64); - -struct SectionHeader { - std::string name; - uint32_t type{}; - uint64_t flags{}; - uint64_t addr{}; - uint64_t offset{}; - uint64_t size{}; - uint32_t link{}; - uint32_t info{}; - uint64_t addralign{}; - uint64_t entsize{}; -} ATOM_ALIGNAS(128); - -struct Symbol { - std::string name; - uint64_t value{}; - uint64_t size{}; - unsigned char bind{}; - unsigned char type{}; - uint16_t shndx{}; -} ATOM_ALIGNAS(64); - -struct DynamicEntry { - uint64_t tag; - union { - uint64_t val; - uint64_t ptr; - } dUn; -} ATOM_ALIGNAS(16); - -struct RelocationEntry { - uint64_t offset; - uint64_t info; - int64_t addend; -} ATOM_ALIGNAS(32); - -// Forward declaration of the implementation class + class ElfParser::Impl { public: - explicit Impl(const char* file); - ~Impl(); - - auto initialize() -> bool; - void cleanup(); - auto parse() -> bool; - [[nodiscard]] auto getElfHeader() const -> ElfHeader; - [[nodiscard]] auto getProgramHeaders() const -> std::vector; - [[nodiscard]] auto getSectionHeaders() const -> std::vector; - [[nodiscard]] auto getSymbolTable() const -> std::vector; - [[nodiscard]] auto getDynamicEntries() const -> std::vector; - [[nodiscard]] auto getRelocationEntries() const -> std::vector; + explicit Impl(std::string_view file) : filePath_(file) {} -private: - const char* filePath_; - int fd_{}; - Elf* elf_{}; - GElf_Ehdr ehdr_; -}; + auto parse() -> bool { + std::ifstream file(filePath_, std::ios::binary); + if (!file) { + return false; + } -// Implementation of the `Impl` class methods -ElfParser::Impl::Impl(const char* file) : filePath_(file) {} + file.seekg(0, std::ios::end); + fileSize_ = file.tellg(); + file.seekg(0, std::ios::beg); -ElfParser::Impl::~Impl() { cleanup(); } + fileContent_.resize(fileSize_); + file.read(reinterpret_cast(fileContent_.data()), fileSize_); -auto ElfParser::Impl::initialize() -> bool { - if (elf_version(EV_CURRENT) == EV_NONE) { - LOG_F(ERROR, "ELF library initialization failed: {}", elf_errmsg(-1)); - return false; + return parseElfHeader() && parseProgramHeaders() && + parseSectionHeaders() && parseSymbolTable(); } - fd_ = open(filePath_, O_RDONLY, 0); - if (fd_ < 0) { - LOG_F(ERROR, "Failed to open ELF file: {}", filePath_); - return false; - } + [[nodiscard]] auto getElfHeader() const -> std::optional { return elfHeader_; } - elf_ = elf_begin(fd_, ELF_C_READ, nullptr); - if (elf_ == nullptr) { - LOG_F(ERROR, "elf_begin() failed: {}", elf_errmsg(-1)); - close(fd_); - fd_ = -1; - return false; + [[nodiscard]] auto getProgramHeaders() const -> std::span { + return 
programHeaders_; } - if (gelf_getehdr(elf_, &ehdr_) == nullptr) { - LOG_F(ERROR, "gelf_getehdr() failed: {}", elf_errmsg(-1)); - elf_end(elf_); - elf_ = nullptr; - close(fd_); - fd_ = -1; - return false; + [[nodiscard]] auto getSectionHeaders() const -> std::span { + return sectionHeaders_; } - return true; -} + [[nodiscard]] auto getSymbolTable() const -> std::span { return symbolTable_; } -void ElfParser::Impl::cleanup() { - if (elf_ != nullptr) { - elf_end(elf_); - elf_ = nullptr; + [[nodiscard]] auto findSymbolByName(std::string_view name) const -> std::optional { + auto it = std::ranges::find_if(symbolTable_, [name](const auto& symbol) { + return symbol.name == name; + }); + if (it != symbolTable_.end()) { + return *it; + } + return std::nullopt; } - if (fd_ >= 0) { - close(fd_); - fd_ = -1; + + [[nodiscard]] auto findSymbolByAddress(uint64_t address) const -> std::optional { + auto it = std::ranges::find_if( + symbolTable_, + [address](const auto& symbol) { return symbol.value == address; }); + if (it != symbolTable_.end()) { + return *it; + } + return std::nullopt; } -} -auto ElfParser::Impl::parse() -> bool { return initialize(); } - -auto ElfParser::Impl::getElfHeader() const -> ElfHeader { - ElfHeader header{}; - header.type = ehdr_.e_type; - header.machine = ehdr_.e_machine; - header.version = ehdr_.e_version; - header.entry = ehdr_.e_entry; - header.phoff = ehdr_.e_phoff; - header.shoff = ehdr_.e_shoff; - header.flags = ehdr_.e_flags; - header.ehsize = ehdr_.e_ehsize; - header.phentsize = ehdr_.e_phentsize; - header.phnum = ehdr_.e_phnum; - header.shentsize = ehdr_.e_shentsize; - header.shnum = ehdr_.e_shnum; - header.shstrndx = ehdr_.e_shstrndx; - return header; -} + [[nodiscard]] auto findSection(std::string_view name) const -> std::optional { + auto it = std::ranges::find_if( + sectionHeaders_, + [name](const auto& section) { return section.name == name; }); + if (it != sectionHeaders_.end()) { + return *it; + } + return std::nullopt; + } -auto ElfParser::Impl::getProgramHeaders() const -> std::vector { - std::vector headers; - for (size_t i = 0; i < ehdr_.e_phnum; ++i) { - GElf_Phdr phdr; - if (gelf_getphdr(elf_, i, &phdr) != &phdr) { - LOG_F(ERROR, "gelf_getphdr() failed: {}", elf_errmsg(-1)); - continue; + [[nodiscard]] auto getSectionData(const SectionHeader& section) const -> std::vector { + if (section.offset + section.size > fileSize_) { + THROW_OUT_OF_RANGE("Section data out of bounds"); } - ProgramHeader header; - header.type = phdr.p_type; - header.offset = phdr.p_offset; - header.vaddr = phdr.p_vaddr; - header.paddr = phdr.p_paddr; - header.filesz = phdr.p_filesz; - header.memsz = phdr.p_memsz; - header.flags = phdr.p_flags; - header.align = phdr.p_align; - headers.push_back(header); + return {fileContent_.begin() + section.offset, + fileContent_.begin() + section.offset + section.size}; } - return headers; -} -auto ElfParser::Impl::getSectionHeaders() const -> std::vector { - std::vector headers; - for (size_t i = 0; i < ehdr_.e_shnum; ++i) { - GElf_Shdr shdr; - if (gelf_getshdr(elf_getscn(elf_, i), &shdr) != &shdr) { - LOG_F(ERROR, "gelf_getshdr() failed: {}", elf_errmsg(-1)); - continue; +private: + std::string filePath_; + std::vector fileContent_; + size_t fileSize_{}; + + std::optional elfHeader_; + std::vector programHeaders_; + std::vector sectionHeaders_; + std::vector symbolTable_; + + auto parseElfHeader() -> bool { + if (fileSize_ < sizeof(Elf64_Ehdr)) { + return false; } - SectionHeader header; - header.name = elf_strptr(elf_, ehdr_.e_shstrndx, 
shdr.sh_name); - header.type = shdr.sh_type; - header.flags = shdr.sh_flags; - header.addr = shdr.sh_addr; - header.offset = shdr.sh_offset; - header.size = shdr.sh_size; - header.link = shdr.sh_link; - header.info = shdr.sh_info; - header.addralign = shdr.sh_addralign; - header.entsize = shdr.sh_entsize; - headers.push_back(header); + + const auto* ehdr = + reinterpret_cast(fileContent_.data()); + elfHeader_ = ElfHeader{.type = ehdr->e_type, + .machine = ehdr->e_machine, + .version = ehdr->e_version, + .entry = ehdr->e_entry, + .phoff = ehdr->e_phoff, + .shoff = ehdr->e_shoff, + .flags = ehdr->e_flags, + .ehsize = ehdr->e_ehsize, + .phentsize = ehdr->e_phentsize, + .phnum = ehdr->e_phnum, + .shentsize = ehdr->e_shentsize, + .shnum = ehdr->e_shnum, + .shstrndx = ehdr->e_shstrndx}; + + return true; } - return headers; -} -std::vector ElfParser::Impl::getSymbolTable() const { - std::vector symbols; - size_t shnum; - elf_getshdrnum(elf_, &shnum); - - for (size_t i = 0; i < shnum; ++i) { - Elf_Scn* scn = elf_getscn(elf_, i); - GElf_Shdr shdr; - gelf_getshdr(scn, &shdr); - - if (shdr.sh_type == SHT_SYMTAB || shdr.sh_type == SHT_DYNSYM) { - Elf_Data* data = elf_getdata(scn, nullptr); - size_t symbolCount = shdr.sh_size / shdr.sh_entsize; - - for (size_t j = 0; j < symbolCount; ++j) { - GElf_Sym sym; - gelf_getsym(data, j, &sym); - Symbol symbol; - symbol.name = elf_strptr(elf_, shdr.sh_link, sym.st_name); - symbol.value = sym.st_value; - symbol.size = sym.st_size; - symbol.bind = GELF_ST_BIND(sym.st_info); - symbol.type = GELF_ST_TYPE(sym.st_info); - symbol.shndx = sym.st_shndx; - symbols.push_back(symbol); - } + auto parseProgramHeaders() -> bool { + if (!elfHeader_) { + return false; + } + + const auto* phdr = reinterpret_cast( + fileContent_.data() + elfHeader_->phoff); + for (uint16_t i = 0; i < elfHeader_->phnum; ++i) { + programHeaders_.push_back(ProgramHeader{.type = phdr[i].p_type, + .offset = phdr[i].p_offset, + .vaddr = phdr[i].p_vaddr, + .paddr = phdr[i].p_paddr, + .filesz = phdr[i].p_filesz, + .memsz = phdr[i].p_memsz, + .flags = phdr[i].p_flags, + .align = phdr[i].p_align}); } + + return true; } - return symbols; -} -auto ElfParser::Impl::getDynamicEntries() const -> std::vector { - std::vector entries; - size_t shnum; - elf_getshdrnum(elf_, &shnum); - - for (size_t i = 0; i < shnum; ++i) { - Elf_Scn* scn = elf_getscn(elf_, i); - GElf_Shdr shdr; - gelf_getshdr(scn, &shdr); - - if (shdr.sh_type == SHT_DYNAMIC) { - Elf_Data* data = elf_getdata(scn, nullptr); - size_t entryCount = shdr.sh_size / shdr.sh_entsize; - - for (size_t j = 0; j < entryCount; ++j) { - GElf_Dyn dyn; - gelf_getdyn(data, j, &dyn); - DynamicEntry entry; - entry.tag = dyn.d_tag; - entry.dUn.val = dyn.d_un.d_val; - entries.push_back(entry); - } + auto parseSectionHeaders() -> bool { + if (!elfHeader_) { + return false; + } + + const auto* shdr = reinterpret_cast( + fileContent_.data() + elfHeader_->shoff); + const auto* strtab = reinterpret_cast( + fileContent_.data() + shdr[elfHeader_->shstrndx].sh_offset); + + for (uint16_t i = 0; i < elfHeader_->shnum; ++i) { + sectionHeaders_.push_back( + SectionHeader{.name = std::string(strtab + shdr[i].sh_name), + .type = shdr[i].sh_type, + .flags = shdr[i].sh_flags, + .addr = shdr[i].sh_addr, + .offset = shdr[i].sh_offset, + .size = shdr[i].sh_size, + .link = shdr[i].sh_link, + .info = shdr[i].sh_info, + .addralign = shdr[i].sh_addralign, + .entsize = shdr[i].sh_entsize}); } + + return true; } - return entries; -} -std::vector ElfParser::Impl::getRelocationEntries() const { - 
std::vector entries; - size_t shnum; - elf_getshdrnum(elf_, &shnum); - - for (size_t i = 0; i < shnum; ++i) { - Elf_Scn* scn = elf_getscn(elf_, i); - GElf_Shdr shdr; - gelf_getshdr(scn, &shdr); - - if (shdr.sh_type == SHT_RELA || shdr.sh_type == SHT_REL) { - Elf_Data* data = elf_getdata(scn, nullptr); - size_t entryCount = shdr.sh_size / shdr.sh_entsize; - - for (size_t j = 0; j < entryCount; ++j) { - RelocationEntry entry; - if (shdr.sh_type == SHT_RELA) { - GElf_Rela rela; - gelf_getrela(data, j, &rela); - entry.offset = rela.r_offset; - entry.info = rela.r_info; - entry.addend = rela.r_addend; - } else { - GElf_Rel rel; - gelf_getrel(data, j, &rel); - entry.offset = rel.r_offset; - entry.info = rel.r_info; - entry.addend = 0; // REL doesn't have addend - } - entries.push_back(entry); - } + auto parseSymbolTable() -> bool { + auto symtabSection = std::ranges::find_if( + sectionHeaders_, + [](const auto& section) { return section.type == SHT_SYMTAB; }); + + if (symtabSection == sectionHeaders_.end()) { + return true; // No symbol table, but not an error + } + + const auto* symtab = reinterpret_cast( + fileContent_.data() + symtabSection->offset); + size_t numSymbols = symtabSection->size / sizeof(Elf64_Sym); + + const auto* strtab = reinterpret_cast( + fileContent_.data() + sectionHeaders_[symtabSection->link].offset); + + for (size_t i = 0; i < numSymbols; ++i) { + symbolTable_.push_back( + Symbol{.name = std::string(strtab + symtab[i].st_name), + .value = symtab[i].st_value, + .size = symtab[i].st_size, + .bind = ELF64_ST_BIND(symtab[i].st_info), + .type = ELF64_ST_TYPE(symtab[i].st_info), + .shndx = symtab[i].st_shndx}); } + + return true; } - return entries; -} +}; -// Define the public `ElfParser` class methods -ElfParser::ElfParser(const char* file) - : pImpl(std::make_unique(file)) {} +// ElfParser method implementations +ElfParser::ElfParser(std::string_view file) + : pImpl(std::make_unique(file)) {} -bool ElfParser::parse() { return pImpl->parse(); } +ElfParser::~ElfParser() = default; -ElfHeader ElfParser::getElfHeader() const { return pImpl->getElfHeader(); } +auto ElfParser::parse() -> bool { return pImpl->parse(); } -std::vector ElfParser::getProgramHeaders() const { +auto ElfParser::getElfHeader() const -> std::optional { + return pImpl->getElfHeader(); +} + +auto ElfParser::getProgramHeaders() const -> std::span { return pImpl->getProgramHeaders(); } -std::vector ElfParser::getSectionHeaders() const { +auto ElfParser::getSectionHeaders() const -> std::span { return pImpl->getSectionHeaders(); } -std::vector ElfParser::getSymbolTable() const { +auto ElfParser::getSymbolTable() const -> std::span { return pImpl->getSymbolTable(); } -std::vector ElfParser::getDynamicEntries() const { - return pImpl->getDynamicEntries(); +auto ElfParser::findSymbolByName(std::string_view name) const -> std::optional { + return pImpl->findSymbolByName(name); +} + +auto ElfParser::findSymbolByAddress(uint64_t address) const -> std::optional { + return pImpl->findSymbolByAddress(address); } -std::vector ElfParser::getRelocationEntries() const { - return pImpl->getRelocationEntries(); +auto ElfParser::findSection( + std::string_view name) const -> std::optional { + return pImpl->findSection(name); } +auto ElfParser::getSectionData( + const SectionHeader& section) const -> std::vector { + return pImpl->getSectionData(section); +} } // namespace lithium diff --git a/src/addon/debug/elf.hpp b/src/addon/debug/elf.hpp index 2521c8d1..20e285b4 100644 --- a/src/addon/debug/elf.hpp +++ 
b/src/addon/debug/elf.hpp @@ -1,33 +1,122 @@ +// elf.hpp #ifndef LITHIUM_ADDON_ELF_HPP #define LITHIUM_ADDON_ELF_HPP +#include #include +#include +#include +#include #include +#include "macro.hpp" + namespace lithium { -struct ElfHeader; -struct ProgramHeader; -struct SectionHeader; -struct Symbol; -struct DynamicEntry; -struct RelocationEntry; + +struct ElfHeader { + uint16_t type; + uint16_t machine; + uint32_t version; + uint64_t entry; + uint64_t phoff; + uint64_t shoff; + uint32_t flags; + uint16_t ehsize; + uint16_t phentsize; + uint16_t phnum; + uint16_t shentsize; + uint16_t shnum; + uint16_t shstrndx; +} ATOM_ALIGNAS(64); + +struct ProgramHeader { + uint32_t type; + uint64_t offset; + uint64_t vaddr; + uint64_t paddr; + uint64_t filesz; + uint64_t memsz; + uint32_t flags; + uint64_t align; +} ATOM_ALIGNAS(64); + +struct SectionHeader { + std::string name; + uint32_t type; + uint64_t flags; + uint64_t addr; + uint64_t offset; + uint64_t size; + uint32_t link; + uint32_t info; + uint64_t addralign; + uint64_t entsize; +} ATOM_ALIGNAS(128); + +struct Symbol { + std::string name; + uint64_t value; + uint64_t size; + unsigned char bind; + unsigned char type; + uint16_t shndx; +} ATOM_ALIGNAS(64); + +struct DynamicEntry { + uint64_t tag; + union { + uint64_t val; + uint64_t ptr; + } d_un; +} ATOM_ALIGNAS(16); + +struct RelocationEntry { + uint64_t offset; + uint64_t info; + int64_t addend; +} ATOM_ALIGNAS(32); class ElfParser { public: - explicit ElfParser(const char* file); + explicit ElfParser(std::string_view file); + ~ElfParser(); + + [[nodiscard]] bool parse(); + [[nodiscard]] auto getElfHeader() const -> std::optional; + [[nodiscard]] std::span getProgramHeaders() const; + [[nodiscard]] auto getSectionHeaders() const + -> std::span; + [[nodiscard]] auto getSymbolTable() const -> std::span; + [[nodiscard]] std::span getDynamicEntries() const; + [[nodiscard]] std::span getRelocationEntries() const; + + template Predicate> + [[nodiscard]] std::optional findSymbol(Predicate&& pred) const; - bool parse(); - ElfHeader getElfHeader() const; - std::vector getProgramHeaders() const; - std::vector getSectionHeaders() const; - std::vector getSymbolTable() const; - std::vector getDynamicEntries() const; - std::vector getRelocationEntries() const; + [[nodiscard]] std::optional findSymbolByName( + std::string_view name) const; + [[nodiscard]] std::optional findSymbolByAddress( + uint64_t address) const; + + [[nodiscard]] std::optional findSection( + std::string_view name) const; + [[nodiscard]] std::vector getSectionData( + const SectionHeader& section) const; private: class Impl; std::unique_ptr pImpl; }; + +template Predicate> +std::optional ElfParser::findSymbol(Predicate&& pred) const { + auto symbols = getSymbolTable(); + auto it = std::ranges::find_if(symbols, std::forward(pred)); + if (it != symbols.end()) { + return *it; + } + return std::nullopt; +} } // namespace lithium #endif // LITHIUM_ADDON_ELF_HPP diff --git a/src/addon/debug/pdb.cpp b/src/addon/debug/pdb.cpp index 299b8734..1fc5404a 100644 --- a/src/addon/debug/pdb.cpp +++ b/src/addon/debug/pdb.cpp @@ -1,187 +1,291 @@ #ifdef _WIN32 #include "pdb.hpp" +#include +#include +#include +#include +#include -#include "atom/log/loguru.hpp" +#pragma comment(lib, "dbghelp.lib") namespace lithium { + class PdbParser::Impl { public: - Impl(const std::string& pdbFile) - : pdbFilePath(pdbFile), - hProcess(GetCurrentProcess()), - symInitialized(FALSE), - baseAddress(0) {} - ~Impl() { unloadPdb(); } - - bool initialize() { return loadPdb(); } - 
- std::vector getSymbols() const; - std::vector getTypes() const; - std::vector getGlobalVariables() const; - std::vector getFunctions() const; - SymbolInfo findSymbolByName(const std::string& name) const; - SymbolInfo findSymbolByAddress(DWORD64 address) const; + explicit Impl(std::string_view pdbFile) + : pdbFilePath(pdbFile), hProcess(GetCurrentProcess()), baseAddress(0) {} -private: - std::string pdbFilePath; - HANDLE hProcess; - BOOL symInitialized; - DWORD64 baseAddress; + ~Impl() { cleanup(); } - bool loadPdb(); - void unloadPdb(); -}; + bool initialize() { + if (!SymInitialize(hProcess, nullptr, FALSE)) { + throw std::runtime_error("Failed to initialize DbgHelp"); + } -bool PdbParser::Impl::loadPdb() { - symInitialized = SymInitialize(hProcess, nullptr, FALSE); - if (!symInitialized) { - LOG_F(ERROR, "Failed to initialize DbgHelp: {}", GetLastError()); - return false; + baseAddress = SymLoadModuleEx(hProcess, nullptr, pdbFilePath.c_str(), + nullptr, 0, 0, nullptr, 0); + if (baseAddress == 0) { + SymCleanup(hProcess); + throw std::runtime_error("Failed to load PDB file"); + } + + return true; } - baseAddress = SymLoadModuleEx(hProcess, nullptr, pdbFilePath.c_str(), - nullptr, 0, 0, nullptr, 0); - if (baseAddress == 0) { - LOG_F(ERROR, "Failed to load PDB file: {}", GetLastError()); - return false; + std::span getSymbols() { + if (symbols.empty()) { + SymEnumSymbols( + hProcess, baseAddress, nullptr, + [](PSYMBOL_INFO pSymInfo, ULONG SymbolSize, + PVOID UserContext) -> BOOL { + auto* pSymbols = + static_cast*>(UserContext); + pSymbols->push_back(SymbolInfo{ + .name = std::string(pSymInfo->Name, pSymInfo->NameLen), + .address = pSymInfo->Address, + .size = pSymInfo->Size, + .flags = pSymInfo->Flags}); + return TRUE; + }, + &symbols); + } + return symbols; } - return true; -} + std::span getTypes() { + if (types.empty()) { + SymEnumTypes( + hProcess, baseAddress, + [](PSYMBOL_INFO pSymInfo, ULONG SymbolSize, + PVOID UserContext) -> BOOL { + auto* pTypes = + static_cast*>(UserContext); + pTypes->push_back(TypeInfo{ + .name = std::string(pSymInfo->Name, pSymInfo->NameLen), + .typeId = pSymInfo->TypeIndex, + .size = pSymInfo->Size, + .typeIndex = pSymInfo->TypeIndex}); + return TRUE; + }, + &types); + } + return types; + } -void PdbParser::Impl::unloadPdb() { - if (baseAddress != 0) { - SymUnloadModule64(hProcess, baseAddress); - baseAddress = 0; + std::span getGlobalVariables() { + if (globalVariables.empty()) { + SymEnumSymbols( + hProcess, baseAddress, nullptr, + [](PSYMBOL_INFO pSymInfo, ULONG SymbolSize, + PVOID UserContext) -> BOOL { + auto* pVariables = + static_cast*>(UserContext); + if (pSymInfo->Flags & SYMFLAG_VALUEPRESENT) { + pVariables->push_back(VariableInfo{ + .name = + std::string(pSymInfo->Name, pSymInfo->NameLen), + .address = pSymInfo->Address, + .size = pSymInfo->Size, + .type = getTypeInfo(pSymInfo->TypeIndex)}); + } + return TRUE; + }, + &globalVariables); + } + return globalVariables; } - if (symInitialized) { - SymCleanup(hProcess); - symInitialized = FALSE; + std::span getFunctions() { + if (functions.empty()) { + SymEnumSymbols( + hProcess, baseAddress, nullptr, + [](PSYMBOL_INFO pSymInfo, ULONG SymbolSize, + PVOID UserContext) -> BOOL { + auto* pFunctions = + static_cast*>(UserContext); + if (pSymInfo->Tag == SymTagFunction) { + pFunctions->push_back(FunctionInfo{ + .name = + std::string(pSymInfo->Name, pSymInfo->NameLen), + .address = pSymInfo->Address, + .size = pSymInfo->Size, + .typeIndex = pSymInfo->TypeIndex, + .parameters = + 
getFunctionParameters(pSymInfo->TypeIndex), + .returnType = getTypeInfo(pSymInfo->TypeIndex)}); + } + return TRUE; + }, + &functions); + } + return functions; + } + + std::optional findSymbolByName(std::string_view name) const { + SYMBOL_INFO_PACKAGE sip = {}; + sip.si.SizeOfStruct = sizeof(SYMBOL_INFO); + sip.si.MaxNameLen = MAX_SYM_NAME; + + if (SymFromName(hProcess, name.data(), &sip.si)) { + return SymbolInfo{.name = std::string(sip.si.Name, sip.si.NameLen), + .address = sip.si.Address, + .size = sip.si.Size, + .flags = sip.si.Flags}; + } + return std::nullopt; + } + + std::optional findSymbolByAddress(uint64_t address) const { + SYMBOL_INFO_PACKAGE sip = {}; + sip.si.SizeOfStruct = sizeof(SYMBOL_INFO); + sip.si.MaxNameLen = MAX_SYM_NAME; + DWORD64 displacement = 0; + + if (SymFromAddr(hProcess, address, &displacement, &sip.si)) { + return SymbolInfo{.name = std::string(sip.si.Name, sip.si.NameLen), + .address = sip.si.Address, + .size = sip.si.Size, + .flags = sip.si.Flags}; + } + return std::nullopt; + } + + std::optional findTypeByName(std::string_view name) const { + for (const auto& type : types) { + if (type.name == name) { + return type; + } + } + return std::nullopt; + } + + std::optional findFunctionByName( + std::string_view name) const { + for (const auto& func : functions) { + if (func.name == name) { + return func; + } + } + return std::nullopt; + } + + std::vector getSourceLinesForAddress( + uint64_t address) const { + IMAGEHLP_LINE64 line = {}; + line.SizeOfStruct = sizeof(IMAGEHLP_LINE64); + DWORD displacement = 0; + std::vector sourceLines; + + if (SymGetLineFromAddr64(hProcess, address, &displacement, &line)) { + sourceLines.push_back(SourceLineInfo{.fileName = line.FileName, + .lineNumber = line.LineNumber, + .address = line.Address}); + } + + return sourceLines; + } + + std::string demangleName(std::string_view name) const { + char undecorated[MAX_SYM_NAME] = {}; + if (UnDecorateSymbolName(name.data(), undecorated, MAX_SYM_NAME, + UNDNAME_COMPLETE)) { + return undecorated; + } + return std::string(name); } -} -std::vector PdbParser::Impl::getSymbols() const { +private: + std::string pdbFilePath; + HANDLE hProcess; + DWORD64 baseAddress; std::vector symbols; + std::vector types; + std::vector globalVariables; + std::vector functions; + + void cleanup() { + if (baseAddress != 0) { + SymUnloadModule64(hProcess, baseAddress); + } + SymCleanup(hProcess); + } + + static std::optional getTypeInfo(DWORD typeIndex) { + // This is a placeholder. In a real implementation, you would use + // SymGetTypeInfo to get detailed type information. + return std::nullopt; + } + + static std::vector getFunctionParameters(DWORD typeIndex) { + // This is a placeholder. In a real implementation, you would use + // SymGetTypeInfo to get function parameter information. 
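One possible direction for the getFunctionParameters placeholder above (not part of this patch, and it would need access to the hProcess/baseAddress members rather than remaining a static helper): DbgHelp can enumerate a function type's children via SymGetTypeInfo with TI_GET_CHILDRENCOUNT and TI_FINDCHILDREN, roughly:

    DWORD childCount = 0;
    if (SymGetTypeInfo(hProcess, baseAddress, typeIndex,
                       TI_GET_CHILDRENCOUNT, &childCount) &&
        childCount > 0) {
        std::vector<BYTE> buffer(sizeof(TI_FINDCHILDREN_PARAMS) +
                                 childCount * sizeof(ULONG));
        auto* children =
            reinterpret_cast<TI_FINDCHILDREN_PARAMS*>(buffer.data());
        children->Count = childCount;
        children->Start = 0;
        if (SymGetTypeInfo(hProcess, baseAddress, typeIndex, TI_FINDCHILDREN,
                           children)) {
            // children->ChildId[i] are the parameter type indices; each can be
            // resolved to a name with SymGetTypeInfo(..., TI_GET_SYMNAME, ...).
        }
    }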
+ return {}; + } +}; + +// PdbParser method implementations +PdbParser::PdbParser(std::string_view pdbFile) + : pImpl(std::make_unique(pdbFile)) {} - SymEnumSymbols( - hProcess, baseAddress, nullptr, - [](PSYMBOL_INFO pSymInfo, ULONG SymbolSize, PVOID UserContext) -> BOOL { - auto* pSymbols = static_cast*>(UserContext); - SymbolInfo symbol; - symbol.name = std::string(pSymInfo->Name, pSymInfo->NameLen); - symbol.address = pSymInfo->Address; - symbol.size = pSymInfo->Size; - symbol.flags = pSymInfo->Flags; - pSymbols->emplace_back(symbol); - return TRUE; - }, - &symbols); - - return symbols; +PdbParser::~PdbParser() = default; + +bool PdbParser::initialize() { return pImpl->initialize(); } + +std::span PdbParser::getSymbols() const { + return pImpl->getSymbols(); } -std::vector PdbParser::Impl::getTypes() const { - std::vector types; +std::span PdbParser::getTypes() const { + return pImpl->getTypes(); +} - SymEnumTypes( - hProcess, baseAddress, - [](PSYMBOL_INFO pSymInfo, ULONG SymbolSize, PVOID UserContext) -> BOOL { - auto* pTypes = static_cast*>(UserContext); - TypeInfo type; - type.name = std::string(pSymInfo->Name, pSymInfo->NameLen); - type.typeId = pSymInfo->TypeIndex; - type.size = pSymInfo->Size; - type.typeIndex = pSymInfo->TypeIndex; - pTypes->emplace_back(type); - return TRUE; - }, - &types); - - return types; +std::span PdbParser::getGlobalVariables() const { + return pImpl->getGlobalVariables(); } -std::vector PdbParser::Impl::getGlobalVariables() const { - std::vector variables; - - SymEnumSymbols( - hProcess, baseAddress, nullptr, - [](PSYMBOL_INFO pSymInfo, ULONG SymbolSize, PVOID UserContext) -> BOOL { - auto* pVariables = - static_cast*>(UserContext); - if (pSymInfo->Flags & SYMFLAG_GLOBAL) { - VariableInfo var; - var.name = std::string(pSymInfo->Name, pSymInfo->NameLen); - var.address = pSymInfo->Address; - var.size = pSymInfo->Size; - pVariables->emplace_back(var); - } - return TRUE; - }, - &variables); +std::span PdbParser::getFunctions() const { + return pImpl->getFunctions(); +} - return variables; +std::optional PdbParser::findSymbolByName( + std::string_view name) const { + return pImpl->findSymbolByName(name); } -std::vector PdbParser::Impl::getFunctions() const { - std::vector functions; +std::optional PdbParser::findSymbolByAddress( + uint64_t address) const { + return pImpl->findSymbolByAddress(address); +} - SymEnumSymbols( - hProcess, baseAddress, nullptr, - [](PSYMBOL_INFO pSymInfo, ULONG SymbolSize, PVOID UserContext) -> BOOL { - auto* pFunctions = - static_cast*>(UserContext); - if (pSymInfo->Tag == SymTagFunction) { - FunctionInfo func; - func.name = std::string(pSymInfo->Name, pSymInfo->NameLen); - func.address = pSymInfo->Address; - func.size = pSymInfo->Size; - func.typeIndex = pSymInfo->TypeIndex; - pFunctions->emplace_back(func); - } - return TRUE; - }, - &functions); +std::optional PdbParser::findTypeByName(std::string_view name) const { + return pImpl->findTypeByName(name); +} - return functions; +std::optional PdbParser::findFunctionByName( + std::string_view name) const { + return pImpl->findFunctionByName(name); } -SymbolInfo PdbParser::Impl::findSymbolByName(const std::string& name) const { - SYMBOL_INFO_PACKAGE sip; - ZeroMemory(&sip, sizeof(sip)); - sip.si.MaxNameLen = MAX_SYM_NAME; - sip.si.SizeOfStruct = sizeof(SYMBOL_INFO); - - if (SymFromName(hProcess, name.c_str(), &sip.si)) { - SymbolInfo symbol; - symbol.name = sip.si.Name; - symbol.address = sip.si.Address; - symbol.size = sip.si.Size; - symbol.flags = sip.si.Flags; - return symbol; - 
} +std::vector PdbParser::getSourceLinesForAddress( + uint64_t address) const { + return pImpl->getSourceLinesForAddress(address); +} - return {"", 0, 0, 0}; +std::string PdbParser::demangleName(std::string_view name) const { + return pImpl->demangleName(name); } -SymbolInfo PdbParser::Impl::findSymbolByAddress(DWORD64 address) const { - SYMBOL_INFO_PACKAGE sip; - ZeroMemory(&sip, sizeof(sip)); - sip.si.MaxNameLen = MAX_SYM_NAME; - sip.si.SizeOfStruct = sizeof(SYMBOL_INFO); - - if (SymFromAddr(hProcess, address, nullptr, &sip.si)) { - SymbolInfo symbol; - symbol.name = sip.si.Name; - symbol.address = sip.si.Address; - symbol.size = sip.si.Size; - symbol.flags = sip.si.Flags; - return symbol; +template Predicate> +std::optional PdbParser::findSymbol(Predicate&& pred) const { + auto symbols = getSymbols(); + auto it = std::ranges::find_if(symbols, std::forward(pred)); + if (it != symbols.end()) { + return *it; } - - return {"", 0, 0, 0}; + return std::nullopt; } + } // namespace lithium -#endif +#endif // _WIN32 diff --git a/src/addon/debug/pdb.hpp b/src/addon/debug/pdb.hpp index 0d675f3e..81908321 100644 --- a/src/addon/debug/pdb.hpp +++ b/src/addon/debug/pdb.hpp @@ -1,60 +1,86 @@ #ifndef LITHIUM_ADDON_PDB_HPP #define LITHIUM_ADDON_PDB_HPP -#ifdef _WIN32 -#include -#include +#include #include +#include +#include #include #include +#include "macro.hpp" + namespace lithium { + struct SymbolInfo { std::string name; - DWORD64 address; - DWORD size; - DWORD flags; -}; + uint64_t address; + uint32_t size; + uint32_t flags; +} ATOM_ALIGNAS(64); struct TypeInfo { std::string name; - DWORD typeId; - DWORD size; - DWORD typeIndex; -}; + uint32_t typeId; + uint32_t size; + uint32_t typeIndex; +} ATOM_ALIGNAS(64); struct VariableInfo { std::string name; - DWORD64 address; - DWORD size; -}; + uint64_t address; + uint32_t size; + std::optional type; +} ATOM_ALIGNAS(128); struct FunctionInfo { std::string name; - DWORD64 address; - DWORD size; - DWORD typeIndex; -}; + uint64_t address; + uint32_t size; + uint32_t typeIndex; + std::vector parameters; + std::optional returnType; +} ATOM_ALIGNAS(128); + +struct SourceLineInfo { + std::string fileName; + uint32_t lineNumber; + uint64_t address; +} ATOM_ALIGNAS(64); class PdbParser { public: - explicit PdbParser(const std::string& pdbFile); + explicit PdbParser(std::string_view pdbFile); ~PdbParser(); - bool initialize(); - std::vector getSymbols() const; - std::vector getTypes() const; - std::vector getGlobalVariables() const; - std::vector getFunctions() const; - SymbolInfo findSymbolByName(const std::string& name) const; - SymbolInfo findSymbolByAddress(DWORD64 address) const; + [[nodiscard]] bool initialize(); + [[nodiscard]] std::span getSymbols() const; + [[nodiscard]] std::span getTypes() const; + [[nodiscard]] std::span getGlobalVariables() const; + [[nodiscard]] std::span getFunctions() const; + + template Predicate> + [[nodiscard]] std::optional findSymbol(Predicate&& pred) const; + + [[nodiscard]] std::optional findSymbolByName( + std::string_view name) const; + [[nodiscard]] std::optional findSymbolByAddress( + uint64_t address) const; + + [[nodiscard]] std::optional findTypeByName( + std::string_view name) const; + [[nodiscard]] std::optional findFunctionByName( + std::string_view name) const; + [[nodiscard]] std::vector getSourceLinesForAddress( + uint64_t address) const; + + [[nodiscard]] std::string demangleName(std::string_view name) const; private: class Impl; - std::unique_ptr pImpl; // Pointer to implementation + std::unique_ptr pImpl; }; 
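A short usage sketch of the interface declared above (illustrative only; the .pdb path and the symbol name are placeholders):

    lithium::PdbParser parser("module.pdb");
    if (parser.initialize()) {
        if (auto sym = parser.findSymbolByName("main")) {
            auto lines = parser.getSourceLinesForAddress(sym->address);
            auto pretty = parser.demangleName(sym->name);
        }
    }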
-} // namespace lithium -#endif +} // namespace lithium -#endif +#endif // LITHIUM_ADDON_PDB_HPP diff --git a/src/addon/dependency.cpp b/src/addon/dependency.cpp index 3a9aea5a..b3c4d676 100644 --- a/src/addon/dependency.cpp +++ b/src/addon/dependency.cpp @@ -1,11 +1,9 @@ #include "dependency.hpp" #include "version.hpp" -#include #include -#include -#include #include +#include #include "atom/error/exception.hpp" #include "atom/log/loguru.hpp" @@ -13,82 +11,87 @@ namespace lithium { void DependencyGraph::addNode(const Node& node, const Version& version) { + LOG_F(INFO, "Adding node: {} with version: {}", node, version.toString()); adjList_.try_emplace(node); incomingEdges_.try_emplace(node); nodeVersions_[node] = version; + LOG_F(INFO, "Node {} added successfully.", node); } void DependencyGraph::addDependency(const Node& from, const Node& to, const Version& requiredVersion) { - if (nodeVersions_.find(to) != nodeVersions_.end() && - nodeVersions_[to] < requiredVersion) { + LOG_F(INFO, "Adding dependency from {} to {} with required version: {}", + from, to, requiredVersion.toString()); + + if (nodeVersions_.contains(to) && nodeVersions_[to] < requiredVersion) { + LOG_F(ERROR, + "Version requirement not satisfied for dependency {} -> {}", from, + to); THROW_INVALID_ARGUMENT( "Version requirement not satisfied for dependency " + from + " -> " + to); } + adjList_[from].insert(to); incomingEdges_[to].insert(from); + LOG_F(INFO, "Dependency from {} to {} added successfully.", from, to); } void DependencyGraph::removeNode(const Node& node) { + LOG_F(INFO, "Removing node: {}", node); + adjList_.erase(node); incomingEdges_.erase(node); + for (auto& [key, neighbors] : adjList_) { neighbors.erase(node); } for (auto& [key, sources] : incomingEdges_) { sources.erase(node); } + + LOG_F(INFO, "Node {} removed successfully.", node); } void DependencyGraph::removeDependency(const Node& from, const Node& to) { - if (adjList_.find(from) != adjList_.end()) { + LOG_F(INFO, "Removing dependency from {} to {}", from, to); + + if (adjList_.contains(from)) { adjList_[from].erase(to); } - if (incomingEdges_.find(to) != incomingEdges_.end()) { + if (incomingEdges_.contains(to)) { incomingEdges_[to].erase(from); } -} - -auto DependencyGraph::getDependencies(const Node& node) const - -> std::vector { - if (adjList_.find(node) != adjList_.end()) { - return {adjList_.at(node).begin(), adjList_.at(node).end()}; - } - return {}; -} -auto DependencyGraph::getDependents(const Node& node) const - -> std::vector { - std::vector dependents; - for (const auto& [key, neighbors] : adjList_) { - if (neighbors.contains(node)) { - dependents.push_back(key); - } - } - return dependents; + LOG_F(INFO, "Dependency from {} to {} removed successfully.", from, to); } auto DependencyGraph::hasCycle() const -> bool { + LOG_F(INFO, "Checking for cycles in the dependency graph."); std::unordered_set visited; std::unordered_set recStack; for (const auto& [node, _] : adjList_) { if (hasCycleUtil(node, visited, recStack)) { + LOG_F(ERROR, "Cycle detected in the graph."); return true; } } + LOG_F(INFO, "No cycles detected."); return false; } auto DependencyGraph::topologicalSort() const - -> std::optional> { + -> std::optional> { + LOG_F(INFO, "Performing topological sort."); std::unordered_set visited; std::stack stack; + for (const auto& [node, _] : adjList_) { if (!visited.contains(node)) { if (!topologicalSortUtil(node, visited, stack)) { - return std::nullopt; // Cycle detected + LOG_F(ERROR, "Cycle detected during topological sort."); + 
return std::nullopt; } } } @@ -98,157 +101,43 @@ auto DependencyGraph::topologicalSort() const sortedNodes.push_back(stack.top()); stack.pop(); } - return sortedNodes; -} - -auto DependencyGraph::getAllDependencies(const Node& node) const - -> std::unordered_set { - std::unordered_set allDependencies; - getAllDependenciesUtil(node, allDependencies); - return allDependencies; -} - -void DependencyGraph::loadNodesInParallel( - std::function loadFunction) const { - std::queue readyQueue; - std::mutex mtx; - std::condition_variable cv; - std::unordered_map inDegree; - std::unordered_set loadedNodes; - std::vector threads; - bool done = false; - - // Initialize in-degree and ready queue - for (const auto& [node, deps] : adjList_) { - inDegree[node] = - incomingEdges_.contains(node) ? incomingEdges_.at(node).size() : 0; - if (inDegree[node] == 0) { - readyQueue.push(node); - } - } - - auto worker = [&]() { - while (true) { - Node node; - { - std::unique_lock lock(mtx); - cv.wait(lock, [&] { return !readyQueue.empty() || done; }); - - if (done && readyQueue.empty()) { - return; - } - - node = readyQueue.front(); - readyQueue.pop(); - } - - loadFunction(node); - - { - std::unique_lock lock(mtx); - loadedNodes.insert(node); - - for (const auto& dep : adjList_.at(node)) { - inDegree[dep]--; - if (inDegree[dep] == 0) { - readyQueue.push(dep); - } - } - - if (readyQueue.empty() && - loadedNodes.size() == adjList_.size()) { - done = true; - cv.notify_all(); - } else { - cv.notify_all(); - } - } - } - }; - - int numThreads = std::thread::hardware_concurrency(); - threads.reserve(numThreads); - for (int i = 0; i < numThreads; ++i) { - threads.emplace_back(worker); - } -} - -auto DependencyGraph::hasCycleUtil( - const Node& node, std::unordered_set& visited, - std::unordered_set& recStack) const -> bool { - if (!visited.contains(node)) { - visited.insert(node); - recStack.insert(node); - - for (const auto& neighbour : adjList_.at(node)) { - if (!visited.contains(neighbour) && - hasCycleUtil(neighbour, visited, recStack)) { - return true; - } - if (recStack.contains(neighbour)) { - return true; - } - } - } - recStack.erase(node); - return false; -} - -auto DependencyGraph::topologicalSortUtil( - const Node& node, std::unordered_set& visited, - std::stack& stack) const -> bool { - visited.insert(node); - for (const auto& neighbour : adjList_.at(node)) { - if (!visited.contains(neighbour)) { - if (!topologicalSortUtil(neighbour, visited, stack)) { - return false; // Cycle detected - } - } - } - stack.push(node); - return true; -} - -void DependencyGraph::getAllDependenciesUtil( - const Node& node, std::unordered_set& allDependencies) const { - if (adjList_.contains(node)) { - for (const auto& neighbour : adjList_.at(node)) { - if (!allDependencies.contains(neighbour)) { - allDependencies.insert(neighbour); - getAllDependenciesUtil(neighbour, allDependencies); - } - } - } -} - -auto DependencyGraph::removeDuplicates(const std::vector& input) - -> std::vector { - std::unordered_set seen; - std::vector result; - - for (const auto& element : input) { - if (seen.insert(element).second) { - result.push_back(element); - } - } - return result; + LOG_F(INFO, "Topological sort completed successfully."); + return sortedNodes; } auto DependencyGraph::resolveDependencies( const std::vector& directories) -> std::vector { + LOG_F(INFO, "Resolving dependencies for directories."); DependencyGraph graph; for (const auto& dir : directories) { - std::string packagePath = dir + "/package.json"; - auto [package_name, deps] = 
parsePackageJson(packagePath); + std::string packageJsonPath = dir + "/package.json"; + std::string packageXmlPath = dir + "/package.xml"; + + if (std::filesystem::exists(packageJsonPath)) { + LOG_F(INFO, "Parsing package.json in directory: {}", dir); + auto [package_name, deps] = parsePackageJson(packageJsonPath); + graph.addNode(package_name, deps.at(package_name)); + + for (const auto& dep : deps) { + if (dep.first != package_name) { + graph.addNode(dep.first, dep.second); + graph.addDependency(package_name, dep.first, dep.second); + } + } + } - graph.addNode(package_name, deps.at(package_name)); + if (std::filesystem::exists(packageXmlPath)) { + LOG_F(INFO, "Parsing package.xml in directory: {}", dir); + auto [package_name, deps] = parsePackageXml(packageXmlPath); + graph.addNode(package_name, deps.at(package_name)); - for (const auto& dep : deps) { - if (dep.first != package_name) { - graph.addNode(dep.first, dep.second); - graph.addDependency(package_name, dep.first, dep.second); + for (const auto& dep : deps) { + if (dep.first != package_name) { + graph.addNode(dep.first, dep.second); + graph.addDependency(package_name, dep.first, dep.second); + } } } } @@ -264,6 +153,7 @@ auto DependencyGraph::resolveDependencies( return {}; } + LOG_F(INFO, "Dependencies resolved successfully."); return removeDuplicates(sortedPackagesOpt.value()); } @@ -298,4 +188,37 @@ auto DependencyGraph::parsePackageJson(const std::string& path) file.close(); return {packageName, deps}; } + +auto DependencyGraph::parsePackageXml(const std::string& path) + -> std::pair> { + XMLDocument doc; + if (doc.LoadFile(path.c_str()) != XML_SUCCESS) { + THROW_FAIL_TO_OPEN_FILE("Failed to open " + path); + } + + XMLElement* root = doc.FirstChildElement("package"); + if (root == nullptr) { + THROW_MISSING_ARGUMENT("Missing root element in " + path); + } + + const char* packageName = root->FirstChildElement("name")->GetText(); + if (packageName == nullptr) { + THROW_MISSING_ARGUMENT("Missing package name in " + path); + } + + std::unordered_map deps; + + XMLElement* dependElement = root->FirstChildElement("depend"); + while (dependElement != nullptr) { + const char* depName = dependElement->GetText(); + if (depName != nullptr) { + deps[depName] = Version{}; // Assuming no version info in XML, + // could extend if needed. + } + dependElement = dependElement->NextSiblingElement("depend"); + } + + return {packageName, deps}; +} + } // namespace lithium diff --git a/src/addon/dependency.hpp b/src/addon/dependency.hpp index 48b6cfdb..c0022652 100644 --- a/src/addon/dependency.hpp +++ b/src/addon/dependency.hpp @@ -9,10 +9,13 @@ #include #include -#include "version.hpp" +#include "tinyxml2/tinyxml2.h" #include "atom/type/json_fwd.hpp" +#include "version.hpp" + using json = nlohmann::json; +using namespace tinyxml2; namespace lithium { /** @@ -132,57 +135,24 @@ class DependencyGraph { std::unordered_map nodeVersions_; ///< Map to track node versions. - /** - * @brief Utility function to check for cycles in the graph using DFS. - * - * @param node The current node being visited. - * @param visited Set of visited nodes. - * @param recStack Set of nodes currently in the recursion stack. - * @return True if a cycle is detected, false otherwise. - */ auto hasCycleUtil(const Node& node, std::unordered_set& visited, std::unordered_set& recStack) const -> bool; - /** - * @brief Utility function to perform DFS for topological sorting. - * - * @param node The current node being visited. - * @param visited Set of visited nodes. 
- * @param stack The stack to hold the topological order. - * @return True if successful, false otherwise. - */ auto topologicalSortUtil(const Node& node, std::unordered_set& visited, std::stack& stack) const -> bool; - /** - * @brief Utility function to gather all dependencies of a node. - * - * @param node The node for which to collect dependencies. - * @param allDependencies Set to hold all found dependencies. - */ void getAllDependenciesUtil( const Node& node, std::unordered_set& allDependencies) const; - /** - * @brief Removes duplicate entries from a vector of strings. - * - * @param input The input vector potentially containing duplicates. - * @return A vector containing unique entries from the input. - */ static auto removeDuplicates(const std::vector& input) -> std::vector; - /** - * @brief Parses a package.json file to extract package name and - * dependencies. - * - * @param path The path to the package.json file. - * @return A pair containing the package name and its dependencies. - */ static auto parsePackageJson(const Node& path) -> std::pair>; -}; + static auto parsePackageXml(const Node& path) + -> std::pair>; +}; } // namespace lithium #endif // LITHIUM_ADDON_DEPENDENCY_HPP diff --git a/src/addon/generator.cpp b/src/addon/generator.cpp index 27db2df6..ae24d8b1 100644 --- a/src/addon/generator.cpp +++ b/src/addon/generator.cpp @@ -1,5 +1,7 @@ #include "generator.hpp" +#include + namespace lithium { void CppMemberGenerator::generate(const JsonType auto &j, std::ostream &os) { for (const auto &member : j) { diff --git a/src/addon/loader.cpp b/src/addon/loader.cpp index cc9c6789..d1b9b5f0 100644 --- a/src/addon/loader.cpp +++ b/src/addon/loader.cpp @@ -21,9 +21,10 @@ Description: C++20 and Modules Loader #include "function/ffi.hpp" #ifdef _WIN32 -#include #include +#include #else +#include #include #include #endif @@ -50,10 +51,12 @@ ModuleLoader::ModuleLoader(std::string dirName) { ModuleLoader::~ModuleLoader() { try { if (!modules_.empty()) { + LOG_F(INFO, "Unloading all modules..."); if (!unloadAllModules()) { LOG_F(ERROR, "Failed to unload all modules"); } modules_.clear(); + LOG_F(INFO, "All modules unloaded successfully."); } } catch (const std::exception& ex) { LOG_F(ERROR, "Exception during module unloading: {}", ex.what()); @@ -61,11 +64,14 @@ ModuleLoader::~ModuleLoader() { } auto ModuleLoader::createShared() -> std::shared_ptr { + LOG_F(INFO, "Creating shared ModuleLoader instance."); return std::make_shared("modules"); } auto ModuleLoader::createShared(std::string dirName) -> std::shared_ptr { + LOG_F(INFO, "Creating shared ModuleLoader instance with directory: {}", + dirName); return std::make_shared(std::move(dirName)); } @@ -81,10 +87,13 @@ auto ModuleLoader::loadModule(const std::string& path, LOG_F(ERROR, "Module {} does not exist", name); return false; } + LOG_F(INFO, "Loading module: {} from {}", name, path); + auto modInfo = std::make_shared(); try { modInfo->mLibrary = std::make_shared(path); + LOG_F(INFO, "Library loaded for module {}", name); } catch (const std::exception& ex) { LOG_F(ERROR, "Failed to load module {}: {}", name, ex.what()); return false; @@ -94,6 +103,7 @@ auto ModuleLoader::loadModule(const std::string& path, auto moduleDumpPath = internal::replaceFilename(path, "module_dump.json"); if (!atom::io::isFileExists(moduleDumpPath)) { + LOG_F(INFO, "Dumping module functions to {}", moduleDumpPath); std::ofstream out(moduleDumpPath); json dump; for (auto& func : modInfo->functions) { @@ -104,11 +114,13 @@ auto ModuleLoader::loadModule(const 
std::string& path,
             dump.push_back(j);
         }
         out << dump.dump(4);
-        LOG_F(INFO, "Dumped module functions to {}", moduleDumpPath);
+        LOG_F(INFO, "Module functions dumped to {}", moduleDumpPath);
     } else {
-        LOG_F(WARNING, "Module dump file already exists, skipping");
+        LOG_F(WARNING, "Module dump file {} already exists, skipping",
+              moduleDumpPath);
     }
-    DLOG_F(INFO, "Loaded module: {}", name);
+
+    DLOG_F(INFO, "Module {} loaded successfully.", name);
     return true;
 }
 
@@ -116,9 +128,11 @@ auto ModuleLoader::loadModuleFunctions(const std::string& name)
     -> std::vector<std::unique_ptr<FunctionInfo>> {
     std::vector<std::unique_ptr<FunctionInfo>> funcs;
 
+    LOG_F(INFO, "Loading functions for module: {}", name);
+
     auto it = modules_.find(name);
     if (it == modules_.end()) {
-        LOG_F(ERROR, "Module not found: {}", name);
+        LOG_F(ERROR, "Module not found: {}", name);
         return funcs;
     }
 
@@ -134,6 +148,7 @@ auto ModuleLoader::loadModuleFunctions(const std::string& name)
     loadFunctionsUnix(handle, funcs);
 #endif
 
+    LOG_F(INFO, "Loaded {} functions for module: {}", funcs.size(), name);
     return funcs;
 }
 
@@ -167,11 +182,13 @@ void ModuleLoader::loadFunctionsWindows(
 }
 #else
 void ModuleLoader::loadFunctionsUnix(
-    void* /*handle*/, std::vector<std::unique_ptr<FunctionInfo>>& funcs) {
+    void* /*handle*/, std::vector<std::unique_ptr<FunctionInfo>>& funcs) {
     // UNIX-specific implementation using dl_iterate_phdr
+    LOG_F(INFO, "Loading functions using Unix-specific implementation");
     dl_iterate_phdr(
         [](struct dl_phdr_info* info, size_t, void* data) {
-            auto* funcs = static_cast<std::vector<std::unique_ptr<FunctionInfo>>*>(data);
+            auto* funcs =
+                static_cast<std::vector<std::unique_ptr<FunctionInfo>>*>(data);
 
             for (int i = 0; i < info->dlpi_phnum; ++i) {
                 const ElfW(Phdr)* phdr = &info->dlpi_phdr[i];
@@ -179,7 +196,8 @@ void ModuleLoader::loadFunctionsUnix(
                     continue;
                 }
 
-                auto* dyn = reinterpret_cast<ElfW(Dyn)*>(info->dlpi_addr + phdr->p_vaddr);
+                auto* dyn = reinterpret_cast<ElfW(Dyn)*>(info->dlpi_addr +
+                                                         phdr->p_vaddr);
                 ElfW(Sym)* symtab = nullptr;
                 char* strtab = nullptr;
                 size_t numSymbols = 0;
@@ -188,14 +206,17 @@ void ModuleLoader::loadFunctionsUnix(
                 for (; dyn->d_tag != DT_NULL; ++dyn) {
                     switch (dyn->d_tag) {
                         case DT_SYMTAB:
-                            symtab = reinterpret_cast<ElfW(Sym)*>(info->dlpi_addr + dyn->d_un.d_ptr);
+                            symtab = reinterpret_cast<ElfW(Sym)*>(
+                                info->dlpi_addr + dyn->d_un.d_ptr);
                             break;
                         case DT_STRTAB:
-                            strtab = reinterpret_cast<char*>(info->dlpi_addr + dyn->d_un.d_ptr);
+                            strtab = reinterpret_cast<char*>(info->dlpi_addr +
+                                                             dyn->d_un.d_ptr);
                             break;
                         case DT_HASH:
                             if (dyn->d_un.d_ptr) {
-                                auto* hash = reinterpret_cast<ElfW(Word)*>(info->dlpi_addr + dyn->d_un.d_ptr);
+                                auto* hash = reinterpret_cast<ElfW(Word)*>(
+                                    info->dlpi_addr + dyn->d_un.d_ptr);
                                 numSymbols = hash[1];
                             }
                             break;
@@ -205,18 +226,23 @@ void ModuleLoader::loadFunctionsUnix(
                     }
                 }
 
-                if (!symtab || !strtab || numSymbols == 0 || symEntrySize == 0) {
+                if ((symtab == nullptr) || (strtab == nullptr) ||
+                    numSymbols == 0 || symEntrySize == 0) {
                     continue;
                 }
 
                 for (size_t j = 0; j < numSymbols; ++j) {
-                    auto* sym = reinterpret_cast<ElfW(Sym)*>(reinterpret_cast<char*>(symtab) + j * symEntrySize);
+                    auto* sym = reinterpret_cast<ElfW(Sym)*>(
+                        reinterpret_cast<char*>(symtab) + j * symEntrySize);
 
-                    if (ELF32_ST_TYPE(sym->st_info) == STT_FUNC && sym->st_name != 0) {
+                    if (ELF32_ST_TYPE(sym->st_info) == STT_FUNC &&
+                        sym->st_name != 0) {
                         auto func = std::make_unique<FunctionInfo>();
                         func->name = &strtab[sym->st_name];
-                        func->address = reinterpret_cast<void*>(info->dlpi_addr + sym->st_value);
+                        func->address = reinterpret_cast<void*>(
+                            info->dlpi_addr + sym->st_value);
+                        LOG_F(INFO, "Loaded function: {}", func->name);
                         funcs->push_back(std::move(func));
                     }
                 }
             }
@@ -228,9 +254,11 @@ void ModuleLoader::loadFunctionsUnix(
 auto ModuleLoader::unloadModule(const std::string& name) -> bool {
     std::unique_lock lock(sharedMutex_);
+    LOG_F(INFO,
"Unloading module: {}", name); + if (auto it = modules_.find(name); it != modules_.end()) { modules_.erase(it); - LOG_F(INFO, "Unloaded module: {}", name); + LOG_F(INFO, "Module {} unloaded successfully.", name); return true; } LOG_F(ERROR, "Module {} is not loaded", name); @@ -239,93 +267,159 @@ auto ModuleLoader::unloadModule(const std::string& name) -> bool { auto ModuleLoader::unloadAllModules() -> bool { std::unique_lock lock(sharedMutex_); - for (auto& [name, module] : modules_) { + LOG_F(INFO, "Unloading all loaded modules."); + + for (const auto& [name, module] : modules_) { if (!unloadModule(name)) { LOG_F(ERROR, "Failed to unload module {}", name); } } + modules_.clear(); + LOG_F(INFO, "All modules have been unloaded."); return true; } auto ModuleLoader::checkModuleExists(const std::string& name) const -> bool { + LOG_F(INFO, "Checking if module {} exists.", name); + if (void* handle = LOAD_LIBRARY(name.c_str()); handle) { UNLOAD_LIBRARY(handle); + LOG_F(INFO, "Module {} exists.", name); return true; } + + LOG_F(WARNING, "Module {} does not exist.", name); return false; } auto ModuleLoader::getModule(const std::string& name) const -> std::shared_ptr { std::shared_lock lock(sharedMutex_); + LOG_F(INFO, "Fetching module info for {}", name); + if (auto it = modules_.find(name); it != modules_.end()) { + LOG_F(INFO, "Module {} found.", name); return it->second; } + + LOG_F(ERROR, "Module {} not found.", name); return nullptr; } auto ModuleLoader::getHandle(const std::string& name) const -> std::shared_ptr { std::shared_lock lock(sharedMutex_); + LOG_F(INFO, "Fetching dynamic library handle for module {}", name); + if (auto it = modules_.find(name); it != modules_.end()) { + LOG_F(INFO, "Handle for module {} retrieved.", name); return it->second->mLibrary; } + + LOG_F(ERROR, "Module {} not found.", name); return nullptr; } auto ModuleLoader::hasModule(const std::string& name) const -> bool { std::shared_lock lock(sharedMutex_); - return modules_.contains(name); + LOG_F(INFO, "Checking if module {} is loaded.", name); + + bool exists = modules_.contains(name); + if (exists) { + LOG_F(INFO, "Module {} is currently loaded.", name); + } else { + LOG_F(WARNING, "Module {} is not loaded.", name); + } + + return exists; } auto ModuleLoader::enableModule(const std::string& name) -> bool { std::unique_lock lock(sharedMutex_); + LOG_F(INFO, "Enabling module {}.", name); + if (auto mod = getModule(name); mod && !mod->m_enabled.load()) { mod->m_enabled.store(true); - LOG_F(INFO, "Module {} enabled", name); + LOG_F(INFO, "Module {} enabled.", name); return true; } - LOG_F(ERROR, "Failed to enable module {}", name); + + LOG_F(ERROR, + "Failed to enable module {}. Either the module is already enabled or " + "not found.", + name); return false; } auto ModuleLoader::disableModule(const std::string& name) -> bool { std::unique_lock lock(sharedMutex_); + LOG_F(INFO, "Disabling module {}.", name); + if (auto mod = getModule(name); mod && mod->m_enabled.load()) { mod->m_enabled.store(false); - LOG_F(INFO, "Module {} disabled", name); + LOG_F(INFO, "Module {} disabled.", name); return true; } - LOG_F(ERROR, "Failed to disable module {}", name); + + LOG_F(ERROR, + "Failed to disable module {}. 
Either the module is already disabled " + "or not found.", + name); return false; } auto ModuleLoader::isModuleEnabled(const std::string& name) const -> bool { std::shared_lock lock(sharedMutex_); + LOG_F(INFO, "Checking if module {} is enabled.", name); + if (auto mod = getModule(name); mod) { - return mod->m_enabled.load(); + bool enabled = mod->m_enabled.load(); + if (enabled) { + LOG_F(INFO, "Module {} is enabled.", name); + } else { + LOG_F(WARNING, "Module {} is disabled.", name); + } + return enabled; } + + LOG_F(ERROR, "Module {} not found.", name); return false; } auto ModuleLoader::getAllExistedModules() const -> std::vector { std::shared_lock lock(sharedMutex_); + LOG_F(INFO, "Retrieving all loaded modules."); + std::vector moduleNames; moduleNames.reserve(modules_.size()); + for (const auto& [name, _] : modules_) { moduleNames.push_back(name); + LOG_F(INFO, "Module {} is currently loaded.", name); } + return moduleNames; } auto ModuleLoader::hasFunction(const std::string& name, const std::string& functionName) -> bool { std::shared_lock lock(sharedMutex_); + LOG_F(INFO, "Checking if function {} exists in module {}.", functionName, + name); + if (auto it = modules_.find(name); it != modules_.end()) { - return it->second->mLibrary->hasFunction(functionName); + bool exists = it->second->mLibrary->hasFunction(functionName); + if (exists) { + LOG_F(INFO, "Function {} found in module {}.", functionName, name); + } else { + LOG_F(ERROR, "Function {} not found in module {}.", functionName, + name); + } + return exists; } - LOG_F(ERROR, "Failed to find module {}", name); + + LOG_F(ERROR, "Module {} not found.", name); return false; } diff --git a/src/addon/manager.cpp b/src/addon/manager.cpp index 093e0d59..1bfdfa38 100644 --- a/src/addon/manager.cpp +++ b/src/addon/manager.cpp @@ -70,20 +70,17 @@ class ComponentManagerImpl { std::string modulePath; DependencyGraph dependencyGraph; std::mutex mutex; - - ComponentManagerImpl() = default; - ~ComponentManagerImpl() = default; }; ComponentManager::ComponentManager() : impl_(std::make_unique()) { impl_->moduleLoader = - GetWeakPtr(constants::LITHIUM_MODULE_LOADER); - impl_->env = GetWeakPtr(constants::LITHIUM_UTILS_ENV); + GetWeakPtr(Constants::LITHIUM_MODULE_LOADER); + impl_->env = GetWeakPtr(Constants::LITHIUM_UTILS_ENV); impl_->addonManager = - GetWeakPtr(constants::LITHIUM_ADDON_MANAGER); + GetWeakPtr(Constants::LITHIUM_ADDON_MANAGER); impl_->processManager = GetWeakPtr( - constants::LITHIUM_PROCESS_MANAGER); + Constants::LITHIUM_PROCESS_MANAGER); impl_->sandbox = std::make_unique(); impl_->compiler = std::make_unique(); @@ -117,8 +114,8 @@ auto ComponentManager::initialize() -> bool { return false; } - impl_->modulePath = envLock->getEnv(constants::ENV_VAR_MODULE_PATH, - constants::MODULE_FOLDER); + impl_->modulePath = envLock->getEnv(Constants::ENV_VAR_MODULE_PATH, + Constants::MODULE_FOLDER); for (const auto& dir : getQualifiedSubDirs(impl_->modulePath)) { LOG_F(INFO, "Found module: {}", dir); @@ -182,7 +179,7 @@ auto ComponentManager::initialize() -> bool { auto dependencies = componentInfo.value("dependencies", std::vector{}); auto modulePath = - path / (componentName + std::string(constants::LIB_EXTENSION)); + path / (componentName + std::string(Constants::LIB_EXTENSION)); std::string componentFullName; componentFullName.reserve(addonName.length() + componentName.length() + @@ -263,15 +260,15 @@ auto ComponentManager::getQualifiedSubDirs(const std::string& path) LOG_F(INFO, "Checking directory: {}", entry.path().string()); auto 
files = getFilesInDir(entry.path().string()); for (const auto& fileName : files) { - if (fileName == constants::PACKAGE_NAME) { + if (fileName == Constants::PACKAGE_NAME) { hasJson = true; } else if (fileName.size() > 4 && #ifdef _WIN32 fileName.substr(fileName.size() - 4) == - constants::LIB_EXTENSION + Constants::LIB_EXTENSION #else fileName.substr(fileName.size() - 3) == - constants::LIB_EXTENSION + Constants::LIB_EXTENSION #endif ) { hasLib = true; @@ -376,8 +373,8 @@ auto ComponentManager::checkComponent(const std::string& module_name, return false; } - if (!std::filesystem::exists(module_path + constants::PATH_SEPARATOR + - constants::PACKAGE_NAME)) { + if (!std::filesystem::exists(module_path + Constants::PATH_SEPARATOR + + Constants::PACKAGE_NAME)) { LOG_F(ERROR, "Component path {} does not contain package.json", module_path); return false; @@ -388,7 +385,7 @@ auto ComponentManager::checkComponent(const std::string& module_name, [&](const std::string& fileName) { return fileName.size() > 4 && fileName.substr(fileName.size() - 4) == - constants::LIB_EXTENSION; + Constants::LIB_EXTENSION; }); if (it == files.end()) { @@ -399,7 +396,7 @@ auto ComponentManager::checkComponent(const std::string& module_name, } if (moduleLoader && - !moduleLoader->loadModule(module_path + constants::PATH_SEPARATOR + *it, + !moduleLoader->loadModule(module_path + Constants::PATH_SEPARATOR + *it, module_name)) { LOG_F(ERROR, "Failed to load module: {}'s library {}", module_name, module_path); @@ -411,7 +408,7 @@ auto ComponentManager::checkComponent(const std::string& module_name, auto ComponentManager::loadComponentInfo( const std::string& module_path, const std::string& component_name) -> bool { std::string filePath = - module_path + constants::PATH_SEPARATOR + constants::PACKAGE_NAME; + module_path + Constants::PATH_SEPARATOR + Constants::PACKAGE_NAME; if (!std::filesystem::exists(filePath)) { LOG_F(ERROR, "Component path {} does not contain package.json", @@ -601,7 +598,7 @@ auto ComponentManager::loadSharedComponent( #else atom::utils::replaceString(module_path, "\\", "/") #endif - + constants::PATH_SEPARATOR + component_name + constants::LIB_EXTENSION; + + Constants::PATH_SEPARATOR + component_name + Constants::LIB_EXTENSION; auto moduleLoader = impl_->moduleLoader.lock(); if (!moduleLoader) { @@ -762,8 +759,8 @@ auto ComponentManager::loadStandaloneComponent( return false; } } - auto componentFullPath = module_path + constants::PATH_SEPARATOR + - component_name + constants::EXECUTABLE_EXTENSION; + auto componentFullPath = module_path + Constants::PATH_SEPARATOR + + component_name + Constants::EXECUTABLE_EXTENSION; auto standaloneComponent = std::make_shared(component_name); standaloneComponent->startLocalDriver(componentFullPath); diff --git a/src/addon/platform/base.hpp b/src/addon/platform/base.hpp index 04cf9ef3..3f7471da 100644 --- a/src/addon/platform/base.hpp +++ b/src/addon/platform/base.hpp @@ -1,27 +1,71 @@ #ifndef LITHIUM_ADDON_BUILDBASE_HPP #define LITHIUM_ADDON_BUILDBASE_HPP +#include +#include +#include #include #include namespace lithium { + +enum class BuildType { Debug, Release, RelWithDebInfo, MinSizeRel }; + +struct BuildResult { + bool success; + std::string output; + std::string error; +}; + class BuildSystem { public: virtual ~BuildSystem() = default; virtual auto configureProject( - const std::string &sourceDir, const std::string &buildDir, - const std::string &buildType, - const std::vector &options) -> bool = 0; - virtual auto buildProject(const std::string &buildDir, - int 
jobs) -> bool = 0; - virtual auto cleanProject(const std::string &buildDir) -> bool = 0; - virtual auto installProject(const std::string &buildDir, - const std::string &installDir) -> bool = 0; - virtual auto runTests(const std::string &buildDir) -> bool = 0; - virtual auto generateDocs(const std::string &buildDir) -> bool = 0; - virtual auto loadConfig(const std::string &configPath) -> bool = 0; + const std::filesystem::path& sourceDir, + const std::filesystem::path& buildDir, BuildType buildType, + const std::vector& options) -> BuildResult = 0; + + virtual auto buildProject(const std::filesystem::path& buildDir, + std::optional jobs = std::nullopt) + -> BuildResult = 0; + + virtual auto cleanProject(const std::filesystem::path& buildDir) + -> BuildResult = 0; + + virtual auto installProject(const std::filesystem::path& buildDir, + const std::filesystem::path& installDir) + -> BuildResult = 0; + + virtual auto runTests(const std::filesystem::path& buildDir, + const std::vector& testNames = {}) + -> BuildResult = 0; + + virtual auto generateDocs(const std::filesystem::path& buildDir, + const std::filesystem::path& outputDir) + -> BuildResult = 0; + + virtual auto loadConfig(const std::filesystem::path& configPath) + -> bool = 0; + + virtual auto setLogCallback( + std::function callback) -> void = 0; + + virtual auto getAvailableTargets(const std::filesystem::path& buildDir) + -> std::vector = 0; + + virtual auto buildTarget( + const std::filesystem::path& buildDir, const std::string& target, + std::optional jobs = std::nullopt) -> BuildResult = 0; + + virtual auto getCacheVariables(const std::filesystem::path& buildDir) + -> std::vector> = 0; + + virtual auto setCacheVariable(const std::filesystem::path& buildDir, + const std::string& name, + const std::string& value) -> bool = 0; }; + } // namespace lithium #endif // LITHIUM_ADDON_BUILDBASE_HPP diff --git a/src/addon/platform/cmake.cpp b/src/addon/platform/cmake.cpp index 39eea089..9ed408a2 100644 --- a/src/addon/platform/cmake.cpp +++ b/src/addon/platform/cmake.cpp @@ -3,35 +3,40 @@ #include #include #include -#include +#include +#include "addon/platform/base.hpp" +#include "atom/log/loguru.hpp" +#include "atom/system/command.hpp" #include "atom/type/json.hpp" namespace fs = std::filesystem; using json = nlohmann::json; -#include "atom/log/loguru.hpp" -#include "atom/system/command.hpp" - namespace lithium { + class CMakeBuilderImpl { public: std::unique_ptr configOptions = std::make_unique(); + std::vector preBuildScripts; + std::vector postBuildScripts; + std::vector environmentVariables; std::vector dependencies; + std::function logCallback; }; CMakeBuilder::CMakeBuilder() : pImpl_(std::make_unique()) {} CMakeBuilder::~CMakeBuilder() = default; -bool CMakeBuilder::checkAndInstallDependencies() { - for (const auto &dep : pImpl_->dependencies) { +auto CMakeBuilder::checkAndInstallDependencies() -> bool { + for (const auto& dep : pImpl_->dependencies) { std::string checkCommand = "pkg-config --exists " + dep; if (!atom::system::executeCommandSimple(checkCommand)) { - LOG_F(INFO, "Dependency {} not found, attempting to install...", - dep); + pImpl_->logCallback("Dependency " + dep + + " not found, attempting to install..."); std::string installCommand = "sudo apt-get install -y " + dep; if (!atom::system::executeCommandSimple(installCommand)) { - LOG_F(ERROR, "Failed to install dependency: {}", dep); + pImpl_->logCallback("Failed to install dependency: " + dep); return false; } } @@ -40,84 +45,259 @@ bool 
CMakeBuilder::checkAndInstallDependencies() { } auto CMakeBuilder::configureProject( - const std::string &sourceDir, const std::string &buildDir, - const std::string &buildType, - const std::vector &options) -> bool { + const std::filesystem::path& sourceDir, + const std::filesystem::path& buildDir, BuildType buildType, + const std::vector& options) -> BuildResult { if (!fs::exists(buildDir)) { fs::create_directories(buildDir); } if (!checkAndInstallDependencies()) { - return false; + return {false, "", "Failed to install dependencies"}; + } + + std::string buildTypeStr; + switch (buildType) { + case BuildType::Debug: + buildTypeStr = "Debug"; + break; + case BuildType::Release: + buildTypeStr = "Release"; + break; + case BuildType::RelWithDebInfo: + buildTypeStr = "RelWithDebInfo"; + break; + case BuildType::MinSizeRel: + buildTypeStr = "MinSizeRel"; + break; } - std::string opts; - for (const auto &opt : options) { - opts += " " + opt; + std::string command = "cmake -S " + sourceDir.string() + " -B " + + buildDir.string() + + " -DCMAKE_BUILD_TYPE=" + buildTypeStr; + + for (const auto& opt : options) { + command += " " + opt; } - std::string command = "cmake -S " + sourceDir + " -B " + buildDir + - " -DCMAKE_BUILD_TYPE=" + buildType + opts; - return atom::system::executeCommandSimple(command); + auto [output, status] = atom::system::executeCommandWithStatus(command); + BuildResult res; + res.success = status == 0; + if (status == 0) { + res.output = output; + } else { + res.error = "Failed to configure project."; + } + return res; } -auto CMakeBuilder::buildProject(const std::string &buildDir, int jobs) -> bool { - std::string command = "cmake --build " + buildDir; - if (jobs > 0) { - command += " -j" + std::to_string(jobs); +auto CMakeBuilder::buildProject(const std::filesystem::path& buildDir, + std::optional jobs) -> BuildResult { + std::string command = "cmake --build " + buildDir.string(); + if (jobs.has_value()) { + command += " -j" + std::to_string(*jobs); } - return atom::system::executeCommandSimple(command); + + auto [output, status] = atom::system::executeCommandWithStatus(command); + BuildResult res; + res.success = status == 0; + if (status == 0) { + res.output = output; + } else { + res.error = "Failed to build project."; + } + return res; } -auto CMakeBuilder::cleanProject(const std::string &buildDir) -> bool { +auto CMakeBuilder::cleanProject(const std::filesystem::path& buildDir) + -> BuildResult { if (!fs::exists(buildDir)) { - LOG_F(ERROR, "Build directory does not exist: {}", buildDir); - return false; + return {false, "", + "Build directory does not exist: " + buildDir.string()}; + } + + std::string command = + "cmake --build " + buildDir.string() + " --target clean"; + auto [output, status] = atom::system::executeCommandWithStatus(command); + BuildResult res; + res.success = status == 0; + if (status == 0) { + res.output = output; + } else { + res.error = "Failed to clean project."; + } + return res; +} + +auto CMakeBuilder::installProject(const std::filesystem::path& buildDir, + const std::filesystem::path& installDir) + -> BuildResult { + std::string command = "cmake --install " + buildDir.string() + + " --prefix " + installDir.string(); + + auto [output, status] = atom::system::executeCommandWithStatus(command); + BuildResult res; + res.success = status == 0; + if (status == 0) { + res.output = output; + } else { + res.error = "Failed to install project."; + } + return res; +} + +auto CMakeBuilder::generateDocs(const std::filesystem::path& buildDir, + const 
std::filesystem::path& outputDir) + -> BuildResult { + std::string command = + "cmake --build " + buildDir.string() + " --target docs"; + auto [output, status] = atom::system::executeCommandWithStatus(command); + BuildResult res; + res.success = status == 0; + if (status == 0) { + fs::path docsDir = buildDir / "docs"; + if (fs::exists(docsDir)) { + fs::create_directories(outputDir); + fs::copy(docsDir, outputDir, + fs::copy_options::recursive | + fs::copy_options::update_existing); + res.output = + "Documentation generated and copied to " + outputDir.string(); + } else { + res.error = "Documentation directory not found after generation."; + } + } else { + res.error = "Failed to generate documentation."; } - std::string command = "cmake --build " + buildDir + " --target clean"; - return atom::system::executeCommandSimple(command); + return res; } -auto CMakeBuilder::installProject(const std::string &buildDir, - const std::string &installDir) -> bool { - std::string command = "cmake --build " + buildDir + - " --target install --prefix " + installDir; - return atom::system::executeCommandSimple(command); +auto CMakeBuilder::buildTarget(const std::filesystem::path& buildDir, + const std::string& target, + std::optional jobs) -> BuildResult { + std::string command = + "cmake --build " + buildDir.string() + " --target " + target; + if (jobs.has_value()) { + command += " -j" + std::to_string(*jobs); + } + + auto [output, status] = atom::system::executeCommandWithStatus(command); + BuildResult res; + res.success = status == 0; + if (status == 0) { + res.output = output; + } else { + res.error = "Failed to build target: " + target; + } + return res; } -auto CMakeBuilder::runTests(const std::string &buildDir) -> bool { - std::string command = "ctest --test-dir " + buildDir; - return atom::system::executeCommandSimple(command); +auto CMakeBuilder::setCacheVariable(const std::filesystem::path& buildDir, + const std::string& name, + const std::string& value) -> bool { + std::string command = + "cmake -D" + name + "=" + value + " " + buildDir.string(); + + auto [output, status] = atom::system::executeCommandWithStatus(command); + return status == 0; } -auto CMakeBuilder::generateDocs(const std::string &buildDir) -> bool { - std::string command = "cmake --build " + buildDir + " --target docs"; - return atom::system::executeCommandSimple(command); +auto CMakeBuilder::runTests(const std::filesystem::path& buildDir, + const std::vector& testNames) + -> BuildResult { + std::string command = "ctest --test-dir " + buildDir.string(); + if (!testNames.empty()) { + command += + " -R \"" + + std::accumulate(testNames.begin(), testNames.end(), std::string(), + [](const std::string& a, const std::string& b) { + return a + (a.empty() ? 
"" : "|") + b; + }) + + "\""; + } + + auto [output, status] = atom::system::executeCommandWithStatus(command); + BuildResult res; + res.success = status == 0; + if (status == 0) { + res.output = output; + } else { + res.error = "Failed to run tests."; + } + return res; } -auto CMakeBuilder::loadConfig(const std::string &configPath) -> bool { +auto CMakeBuilder::loadConfig(const std::filesystem::path& configPath) -> bool { std::ifstream configFile(configPath); if (!configFile.is_open()) { - LOG_F(ERROR, "Failed to open config file: {}", configPath); + pImpl_->logCallback("Failed to open config file: " + + configPath.string()); return false; } try { configFile >> *(pImpl_->configOptions); - pImpl_->dependencies = pImpl_->configOptions->at("dependencies") - .get>(); - } catch (const json::parse_error &e) { - LOG_F(ERROR, "Failed to parse config file: {} with {}", configPath, - e.what()); - return false; - } catch (const json::type_error &e) { - LOG_F(ERROR, "Failed to parse config file: {} with {}", configPath, - e.what()); + pImpl_->preBuildScripts = pImpl_->configOptions->value( + "preBuildScripts", std::vector{}); + pImpl_->postBuildScripts = pImpl_->configOptions->value( + "postBuildScripts", std::vector{}); + pImpl_->environmentVariables = pImpl_->configOptions->value( + "environmentVariables", std::vector{}); + pImpl_->dependencies = pImpl_->configOptions->value( + "dependencies", std::vector{}); + } catch (const json::exception& e) { + pImpl_->logCallback("Failed to parse config file: " + + configPath.string() + " with " + e.what()); return false; } - configFile.close(); return true; } +auto CMakeBuilder::setLogCallback( + std::function callback) -> void { + pImpl_->logCallback = std::move(callback); +} + +auto CMakeBuilder::getAvailableTargets(const std::filesystem::path& buildDir) + -> std::vector { + std::string command = + "cmake --build " + buildDir.string() + " --target help"; + auto [output, status] = atom::system::executeCommandWithStatus(command); + + std::vector targets; + if (status == 0 && !output.empty()) { + std::istringstream iss(output); + std::string line; + std::regex targetRegex(R"(^\.\.\.\s+(.+))"); + while (std::getline(iss, line)) { + std::smatch match; + if (std::regex_search(line, match, targetRegex)) { + targets.push_back(match[1]); + } + } + } + return targets; +} + +auto CMakeBuilder::getCacheVariables(const std::filesystem::path& buildDir) + -> std::vector> { + std::string command = "cmake -LA -N " + buildDir.string(); + auto [output, status] = atom::system::executeCommandWithStatus(command); + + std::vector> variables; + if (status == 0 && !output.empty()) { + std::istringstream iss(output); + std::string line; + std::regex varRegex(R"(^([^:]+):([^=]+)=(.+)$)"); + while (std::getline(iss, line)) { + std::smatch match; + if (std::regex_search(line, match, varRegex)) { + variables.emplace_back(match[1], match[3]); + } + } + } + return variables; +} + } // namespace lithium diff --git a/src/addon/platform/cmake.hpp b/src/addon/platform/cmake.hpp index fd5fae3e..212fc581 100644 --- a/src/addon/platform/cmake.hpp +++ b/src/addon/platform/cmake.hpp @@ -1,36 +1,71 @@ #ifndef LITHIUM_ADDON_CMAKEBUILDER_HPP #define LITHIUM_ADDON_CMAKEBUILDER_HPP +#include #include +#include #include #include - #include "base.hpp" namespace lithium { + class CMakeBuilderImpl; + class CMakeBuilder : public BuildSystem { public: CMakeBuilder(); ~CMakeBuilder() override; auto configureProject( - const std::string &sourceDir, const std::string &buildDir, - const std::string &buildType, - 
const std::vector &options) -> bool override; - auto buildProject(const std::string &buildDir, int jobs) -> bool override; - auto cleanProject(const std::string &buildDir) -> bool override; - auto installProject(const std::string &buildDir, - const std::string &installDir) -> bool override; - auto runTests(const std::string &buildDir) -> bool override; - auto generateDocs(const std::string &buildDir) -> bool override; - auto loadConfig(const std::string &configPath) -> bool override; + const std::filesystem::path& sourceDir, + const std::filesystem::path& buildDir, BuildType buildType, + const std::vector& options) -> BuildResult override; + + auto buildProject(const std::filesystem::path& buildDir, + std::optional jobs = std::nullopt) + -> BuildResult override; + + auto cleanProject(const std::filesystem::path& buildDir) + -> BuildResult override; + + auto installProject(const std::filesystem::path& buildDir, + const std::filesystem::path& installDir) + -> BuildResult override; + + auto runTests(const std::filesystem::path& buildDir, + const std::vector& testNames = {}) + -> BuildResult override; + + auto generateDocs(const std::filesystem::path& buildDir, + const std::filesystem::path& outputDir) + -> BuildResult override; + + auto loadConfig(const std::filesystem::path& configPath) -> bool override; + + auto setLogCallback(std::function callback) + -> void override; + + auto getAvailableTargets(const std::filesystem::path& buildDir) + -> std::vector override; + + auto buildTarget( + const std::filesystem::path& buildDir, const std::string& target, + std::optional jobs = std::nullopt) -> BuildResult override; + + auto getCacheVariables(const std::filesystem::path& buildDir) + -> std::vector> override; + + auto setCacheVariable(const std::filesystem::path& buildDir, + const std::string& name, + const std::string& value) -> bool override; private: std::unique_ptr pImpl_; auto checkAndInstallDependencies() -> bool; }; + } // namespace lithium #endif // LITHIUM_ADDON_CMAKEBUILDER_HPP diff --git a/src/addon/platform/meson.cpp b/src/addon/platform/meson.cpp index 768843fa..38abe84f 100644 --- a/src/addon/platform/meson.cpp +++ b/src/addon/platform/meson.cpp @@ -3,14 +3,18 @@ #include #include #include +#include +#include #include "atom/log/loguru.hpp" #include "atom/system/command.hpp" #include "atom/type/json.hpp" + namespace fs = std::filesystem; using json = nlohmann::json; namespace lithium { + class MesonBuilderImpl { public: std::unique_ptr configOptions = std::make_unique(); @@ -18,20 +22,21 @@ class MesonBuilderImpl { std::vector postBuildScripts; std::vector environmentVariables; std::vector dependencies; + std::function logCallback; }; MesonBuilder::MesonBuilder() : pImpl_(std::make_unique()) {} MesonBuilder::~MesonBuilder() = default; auto MesonBuilder::checkAndInstallDependencies() -> bool { - for (const auto &dep : pImpl_->dependencies) { + for (const auto& dep : pImpl_->dependencies) { std::string checkCommand = "pkg-config --exists " + dep; if (!atom::system::executeCommandSimple(checkCommand)) { - LOG_F(INFO, "Dependency {} not found, attempting to install...", - dep); + pImpl_->logCallback("Dependency " + dep + + " not found, attempting to install..."); std::string installCommand = "sudo apt-get install -y " + dep; if (!atom::system::executeCommandSimple(installCommand)) { - LOG_F(ERROR, "Failed to install dependency: {}", dep); + pImpl_->logCallback("Failed to install dependency: " + dep); return false; } } @@ -40,91 +45,244 @@ auto MesonBuilder::checkAndInstallDependencies() 
-> bool { } auto MesonBuilder::configureProject( - const std::string &sourceDir, const std::string &buildDir, - const std::string &buildType, - const std::vector &options) -> bool { + const std::filesystem::path& sourceDir, + const std::filesystem::path& buildDir, BuildType buildType, + const std::vector& options) -> BuildResult { if (!fs::exists(buildDir)) { fs::create_directories(buildDir); } if (!checkAndInstallDependencies()) { - return false; + return {false, "", "Failed to install dependencies"}; } - std::string opts; - for (const auto &opt : options) { - opts += " " + opt; + std::string buildTypeStr; + switch (buildType) { + case BuildType::Debug: + buildTypeStr = "debug"; + break; + case BuildType::Release: + buildTypeStr = "release"; + break; + case BuildType::RelWithDebInfo: + buildTypeStr = "debugoptimized"; + break; + case BuildType::MinSizeRel: + buildTypeStr = "minsize"; + break; + } + + std::string command = "meson setup " + buildDir.string() + " " + + sourceDir.string() + " --buildtype=" + buildTypeStr; + + for (const auto& opt : options) { + command += " " + opt; } - std::string command = "meson setup " + buildDir + " " + sourceDir + - " --buildtype=" + buildType + opts; - return atom::system::executeCommandSimple(command); + auto [output, status] = atom::system::executeCommandWithStatus(command); + BuildResult res; + res.success = status == 0; + if (status == 0) { + res.output = output; + } else { + res.error = "Failed to configure project."; + } + return res; } -auto MesonBuilder::buildProject(const std::string &buildDir, int jobs) -> bool { - std::string command = "ninja -C " + buildDir; - if (jobs > 0) { - command += " -j" + std::to_string(jobs); +auto MesonBuilder::buildProject(const std::filesystem::path& buildDir, + std::optional jobs) -> BuildResult { + std::string command = "ninja -C " + buildDir.string(); + if (jobs.has_value()) { + command += " -j" + std::to_string(*jobs); + } + + auto [output, status] = atom::system::executeCommandWithStatus(command); + BuildResult res; + res.success = status == 0; + if (status == 0) { + res.output = output; + } else { + res.error = "Failed to build project."; } - return atom::system::executeCommandSimple(command); + return res; } -auto MesonBuilder::cleanProject(const std::string &buildDir) -> bool { +auto MesonBuilder::cleanProject(const std::filesystem::path& buildDir) + -> BuildResult { if (!fs::exists(buildDir)) { - LOG_F(ERROR, "Build directory does not exist: {}", buildDir); - return false; + return {false, "", + "Build directory does not exist: " + buildDir.string()}; } - std::string command = "ninja -C " + buildDir + " clean"; - return atom::system::executeCommandSimple(command); + std::string command = "ninja -C " + buildDir.string() + " clean"; + + auto [output, status] = atom::system::executeCommandWithStatus(command); + BuildResult res; + res.success = status == 0; + if (status == 0) { + res.output = output; + } else { + res.error = "Failed to clean project."; + } + return res; } -auto MesonBuilder::installProject(const std::string &buildDir, - const std::string &installDir) -> bool { - std::string command = - "ninja -C " + buildDir + " install --prefix " + installDir; - return atom::system::executeCommandSimple(command); +auto MesonBuilder::installProject(const std::filesystem::path& buildDir, + const std::filesystem::path& installDir) + -> BuildResult { + std::string command = "meson install -C " + buildDir.string() + + " --destdir " + installDir.string(); + + auto [output, status] = 
atom::system::executeCommandWithStatus(command); + BuildResult res; + res.success = status == 0; + if (status == 0) { + res.output = output; + } else { + res.error = "Failed to install project."; + } + return res; } -auto MesonBuilder::runTests(const std::string &buildDir) -> bool { - std::string command = "meson test -C " + buildDir; - return atom::system::executeCommandSimple(command); +auto MesonBuilder::runTests(const std::filesystem::path& buildDir, + const std::vector& testNames) + -> BuildResult { + std::string command = "meson test -C " + buildDir.string(); + if (!testNames.empty()) { + command += " " + std::accumulate( + testNames.begin(), testNames.end(), std::string(), + [](const std::string& a, const std::string& b) { + return a + (a.empty() ? "" : " ") + b; + }); + } + + auto [output, status] = atom::system::executeCommandWithStatus(command); + BuildResult res; + res.success = status == 0; + if (status == 0) { + res.output = output; + } else { + res.error = "Failed to run tests."; + } + return res; } -auto MesonBuilder::generateDocs(const std::string &buildDir) -> bool { - std::string command = "ninja -C " + buildDir + " docs"; - return atom::system::executeCommandSimple(command); +auto MesonBuilder::generateDocs(const std::filesystem::path& buildDir, + const std::filesystem::path& outputDir) + -> BuildResult { + std::string command = "ninja -C " + buildDir.string() + " docs"; + auto [output, status] = atom::system::executeCommandWithStatus(command); + BuildResult res; + res.success = status == 0; + + if (status == 0) { + fs::path docsDir = buildDir / "docs"; + if (fs::exists(docsDir)) { + fs::create_directories(outputDir); + fs::copy(docsDir, outputDir, + fs::copy_options::recursive | + fs::copy_options::update_existing); + res.output = + "Documentation generated and copied to " + outputDir.string(); + } else { + res.error = "Documentation directory not found after generation."; + } + } else { + res.error = "Failed to generate documentation."; + } + return res; } -auto MesonBuilder::loadConfig(const std::string &configPath) -> bool { +auto MesonBuilder::loadConfig(const std::filesystem::path& configPath) -> bool { std::ifstream configFile(configPath); if (!configFile.is_open()) { - LOG_F(ERROR, "Failed to open config file: {}", configPath); + pImpl_->logCallback("Failed to open config file: " + + configPath.string()); return false; } try { configFile >> *(pImpl_->configOptions); - pImpl_->preBuildScripts = pImpl_->configOptions->at("preBuildScripts") - .get>(); - pImpl_->postBuildScripts = pImpl_->configOptions->at("postBuildScripts") - .get>(); - pImpl_->environmentVariables = - pImpl_->configOptions->at("environmentVariables") - .get>(); - pImpl_->dependencies = pImpl_->configOptions->at("dependencies") - .get>(); - } catch (const json::parse_error &e) { - LOG_F(ERROR, "Failed to parse config file: {} with {}", configPath, - e.what()); - return false; - } catch (const json::type_error &e) { - LOG_F(ERROR, "Failed to parse config file: {} with {}", configPath, - e.what()); + pImpl_->preBuildScripts = pImpl_->configOptions->value( + "preBuildScripts", std::vector{}); + pImpl_->postBuildScripts = pImpl_->configOptions->value( + "postBuildScripts", std::vector{}); + pImpl_->environmentVariables = pImpl_->configOptions->value( + "environmentVariables", std::vector{}); + pImpl_->dependencies = pImpl_->configOptions->value( + "dependencies", std::vector{}); + } catch (const json::exception& e) { + pImpl_->logCallback("Failed to parse config file: " + + configPath.string() + " with 
" + e.what()); return false; } - configFile.close(); return true; } +auto MesonBuilder::setLogCallback( + std::function callback) -> void { + pImpl_->logCallback = std::move(callback); +} + +auto MesonBuilder::getAvailableTargets(const std::filesystem::path& buildDir) + -> std::vector { + std::string command = "ninja -C " + buildDir.string() + " -t targets all"; + auto [output, status] = atom::system::executeCommandWithStatus(command); + + std::vector targets; + if (status == 0 && !output.empty()) { + std::istringstream iss(output); + std::string line; + std::regex targetRegex(R"(^([^:]+):)"); + while (std::getline(iss, line)) { + std::smatch match; + if (std::regex_search(line, match, targetRegex)) { + targets.push_back(match[1]); + } + } + } + + return targets; +} + +auto MesonBuilder::buildTarget(const std::filesystem::path& buildDir, + const std::string& target, + std::optional jobs) -> BuildResult { + std::string command = "ninja -C " + buildDir.string() + " " + target; + if (jobs.has_value()) { + command += " -j" + std::to_string(*jobs); + } + + auto [output, status] = atom::system::executeCommandWithStatus(command); + BuildResult res; + res.success = status == 0; + if (status == 0) { + res.output = output; + } else { + res.error = "Failed to build target: " + target; + } + return res; +} + +auto MesonBuilder::getCacheVariables(const std::filesystem::path& buildDir) + -> std::vector> { + std::string command = "meson configure " + buildDir.string(); + auto [output, status] = atom::system::executeCommandWithStatus(command); + + std::vector> variables; + if (status == 0 && !output.empty()) { + std::istringstream iss(output); + std::string line; + std::regex varRegex(R"(^([^:]+):\s*(.+)$)"); + while (std::getline(iss, line)) { + std::smatch match; + if (std::regex_search(line, match, varRegex)) { + variables.emplace_back(match[1], match[2]); + } + } + } + return variables; +} + } // namespace lithium diff --git a/src/addon/platform/meson.hpp b/src/addon/platform/meson.hpp index 3b7f0247..94230cdb 100644 --- a/src/addon/platform/meson.hpp +++ b/src/addon/platform/meson.hpp @@ -1,36 +1,71 @@ #ifndef LITHIUM_ADDON_MESONBUILDER_HPP #define LITHIUM_ADDON_MESONBUILDER_HPP +#include #include +#include #include #include - #include "base.hpp" namespace lithium { + class MesonBuilderImpl; + class MesonBuilder : public BuildSystem { public: MesonBuilder(); ~MesonBuilder() override; auto configureProject( - const std::string &sourceDir, const std::string &buildDir, - const std::string &buildType, - const std::vector &options) -> bool override; - auto buildProject(const std::string &buildDir, int jobs) -> bool override; - auto cleanProject(const std::string &buildDir) -> bool override; - auto installProject(const std::string &buildDir, - const std::string &installDir) -> bool override; - auto runTests(const std::string &buildDir) -> bool override; - auto generateDocs(const std::string &buildDir) -> bool override; - auto loadConfig(const std::string &configPath) -> bool override; + const std::filesystem::path& sourceDir, + const std::filesystem::path& buildDir, BuildType buildType, + const std::vector& options) -> BuildResult override; + + auto buildProject(const std::filesystem::path& buildDir, + std::optional jobs = std::nullopt) + -> BuildResult override; + + auto cleanProject(const std::filesystem::path& buildDir) + -> BuildResult override; + + auto installProject(const std::filesystem::path& buildDir, + const std::filesystem::path& installDir) + -> BuildResult override; + + auto runTests(const 
std::filesystem::path& buildDir, + const std::vector& testNames = {}) + -> BuildResult override; + + auto generateDocs(const std::filesystem::path& buildDir, + const std::filesystem::path& outputDir) + -> BuildResult override; + + auto loadConfig(const std::filesystem::path& configPath) -> bool override; + + auto setLogCallback(std::function callback) + -> void override; + + auto getAvailableTargets(const std::filesystem::path& buildDir) + -> std::vector override; + + auto buildTarget( + const std::filesystem::path& buildDir, const std::string& target, + std::optional jobs = std::nullopt) -> BuildResult override; + + auto getCacheVariables(const std::filesystem::path& buildDir) + -> std::vector> override; + + auto setCacheVariable(const std::filesystem::path& buildDir, + const std::string& name, + const std::string& value) -> bool override; private: std::unique_ptr pImpl_; auto checkAndInstallDependencies() -> bool; }; + } // namespace lithium #endif // LITHIUM_ADDON_MESONBUILDER_HPP diff --git a/src/addon/platform/xmake.cpp b/src/addon/platform/xmake.cpp index 8e776a6c..5288f727 100644 --- a/src/addon/platform/xmake.cpp +++ b/src/addon/platform/xmake.cpp @@ -3,16 +3,17 @@ #include #include #include +#include +#include "atom/log/loguru.hpp" +#include "atom/system/command.hpp" #include "atom/type/json.hpp" namespace fs = std::filesystem; using json = nlohmann::json; -#include "atom/log/loguru.hpp" -#include "atom/system/command.hpp" - namespace lithium { + class XMakeBuilderImpl { public: std::unique_ptr configOptions = std::make_unique(); @@ -23,7 +24,7 @@ XMakeBuilder::XMakeBuilder() : pImpl_(std::make_unique()) {} XMakeBuilder::~XMakeBuilder() = default; auto XMakeBuilder::checkAndInstallDependencies() -> bool { - for (const auto &dep : pImpl_->dependencies) { + for (const auto& dep : pImpl_->dependencies) { std::string checkCommand = "pkg-config --exists " + dep; if (!atom::system::executeCommandSimple(checkCommand)) { LOG_F(INFO, "Dependency {} not found, attempting to install...", @@ -39,64 +40,139 @@ auto XMakeBuilder::checkAndInstallDependencies() -> bool { } auto XMakeBuilder::configureProject( - const std::string &sourceDir, const std::string &buildDir, - const std::string &buildType, - const std::vector &options) -> bool { + const fs::path& sourceDir, const fs::path& buildDir, BuildType buildType, + const std::vector& options) -> BuildResult { + BuildResult result; if (!fs::exists(buildDir)) { fs::create_directories(buildDir); } if (!checkAndInstallDependencies()) { - return false; + result.success = false; + result.error = "Failed to install dependencies."; + return result; + } + + std::string buildTypeStr; + switch (buildType) { + case BuildType::Debug: + buildTypeStr = "debug"; + break; + case BuildType::Release: + buildTypeStr = "release"; + break; + case BuildType::RelWithDebInfo: + buildTypeStr = "reldebug"; + break; + case BuildType::MinSizeRel: + buildTypeStr = "minsizerel"; + break; } std::string opts; - for (const auto &opt : options) { + for (const auto& opt : options) { opts += " " + opt; } std::string command = - "xmake f -p " + buildType + " -o " + buildDir + " " + opts; - return atom::system::executeCommandSimple(command); + "xmake f -p " + buildTypeStr + " -o " + buildDir.string() + " " + opts; + if (atom::system::executeCommandSimple(command)) { + result.success = true; + result.output = "Configured successfully."; + } else { + result.success = false; + result.error = "Configuration failed."; + } + return result; } -auto XMakeBuilder::buildProject(const 
std::string &buildDir, int jobs) -> bool { - std::string command = "xmake -C " + buildDir; - if (jobs > 0) { - command += " -j" + std::to_string(jobs); +auto XMakeBuilder::buildProject(const fs::path& buildDir, + std::optional jobs) -> BuildResult { + BuildResult result; + std::string command = "xmake -C " + buildDir.string(); + if (jobs && *jobs > 0) { + command += " -j" + std::to_string(*jobs); } - return atom::system::executeCommandSimple(command); + if (atom::system::executeCommandSimple(command)) { + result.success = true; + result.output = "Build succeeded."; + } else { + result.success = false; + result.error = "Build failed."; + } + return result; } -auto XMakeBuilder::cleanProject(const std::string &buildDir) -> bool { +auto XMakeBuilder::cleanProject(const fs::path& buildDir) -> BuildResult { + BuildResult result; if (!fs::exists(buildDir)) { - LOG_F(ERROR, "Build directory does not exist: {}", buildDir); - return false; + result.success = false; + result.error = "Build directory does not exist: " + buildDir.string(); + return result; + } + std::string command = "xmake clean -C " + buildDir.string(); + if (atom::system::executeCommandSimple(command)) { + result.success = true; + result.output = "Clean succeeded."; + } else { + result.success = false; + result.error = "Clean failed."; } - std::string command = "xmake clean -C " + buildDir; - return atom::system::executeCommandSimple(command); + return result; } -auto XMakeBuilder::installProject(const std::string &buildDir, - const std::string &installDir) -> bool { - std::string command = "xmake install -o " + installDir + " -C " + buildDir; - return atom::system::executeCommandSimple(command); +auto XMakeBuilder::installProject(const fs::path& buildDir, + const fs::path& installDir) -> BuildResult { + BuildResult result; + std::string command = + "xmake install -o " + installDir.string() + " -C " + buildDir.string(); + if (atom::system::executeCommandSimple(command)) { + result.success = true; + result.output = "Install succeeded."; + } else { + result.success = false; + result.error = "Install failed."; + } + return result; } -auto XMakeBuilder::runTests(const std::string &buildDir) -> bool { - std::string command = "xmake run test -C " + buildDir; - return atom::system::executeCommandSimple(command); +auto XMakeBuilder::runTests(const fs::path& buildDir, + const std::vector& testNames) + -> BuildResult { + BuildResult result; + std::string command = "xmake run test -C " + buildDir.string(); + for (const auto& testName : testNames) { + command += " " + testName; + } + if (atom::system::executeCommandSimple(command)) { + result.success = true; + result.output = "Tests ran successfully."; + } else { + result.success = false; + result.error = "Tests failed."; + } + return result; } -auto XMakeBuilder::generateDocs(const std::string &buildDir) -> bool { - std::string command = "xmake doc -C " + buildDir; - return atom::system::executeCommandSimple(command); +auto XMakeBuilder::generateDocs(const fs::path& buildDir, + const fs::path& outputDir) -> BuildResult { + BuildResult result; + std::string command = + "xmake doc -C " + buildDir.string() + " -o " + outputDir.string(); + if (atom::system::executeCommandSimple(command)) { + result.success = true; + result.output = "Documentation generated successfully."; + } else { + result.success = false; + result.error = "Documentation generation failed."; + } + return result; } -auto XMakeBuilder::loadConfig(const std::string &configPath) -> bool { +auto XMakeBuilder::loadConfig(const 
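// Illustrative example, not from this diff, of the JSON file consumed by the
// loadConfig() implementation here: the only key the code actually reads is
// "dependencies", which is stored in pImpl_->dependencies and later verified
// with pkg-config by checkAndInstallDependencies(). The package names below
// are placeholders:
//
//     {
//         "dependencies": ["openssl", "zlib"]
//     }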
fs::path& configPath) -> bool { std::ifstream configFile(configPath); if (!configFile.is_open()) { - LOG_F(ERROR, "Failed to open config file: {}", configPath); + LOG_F(ERROR, "Failed to open config file: {}", configPath.string()); return false; } @@ -104,13 +180,13 @@ auto XMakeBuilder::loadConfig(const std::string &configPath) -> bool { configFile >> *(pImpl_->configOptions); pImpl_->dependencies = pImpl_->configOptions->at("dependencies") .get>(); - } catch (const json::parse_error &e) { - LOG_F(ERROR, "Failed to parse config file: {} with {}", configPath, - e.what()); + } catch (const json::parse_error& e) { + LOG_F(ERROR, "Failed to parse config file: {} with {}", + configPath.string(), e.what()); return false; - } catch (const json::type_error &e) { - LOG_F(ERROR, "Failed to parse config file: {} with {}", configPath, - e.what()); + } catch (const json::type_error& e) { + LOG_F(ERROR, "Failed to parse config file: {} with {}", + configPath.string(), e.what()); return false; } @@ -118,4 +194,89 @@ auto XMakeBuilder::loadConfig(const std::string &configPath) -> bool { return true; } +auto XMakeBuilder::setLogCallback( + std::function callback) -> void { + // loguru::add_callback("XMakeBuilder", callback, nullptr, + // loguru::Verbosity_INFO); +} + +auto XMakeBuilder::getAvailableTargets(const fs::path& buildDir) + -> std::vector { + std::vector targets; + std::string command = "xmake show -C " + buildDir.string(); + std::string output; + if (const auto [output, status] = + atom::system::executeCommandWithStatus(command); + status == 0) { + // Assume that the output contains targets in a specific format, e.g., + // each target on a new line + std::istringstream stream(output); + std::string line; + while (std::getline(stream, line)) { + targets.push_back(line); + } + } else { + LOG_F(ERROR, "Failed to retrieve available targets."); + } + return targets; +} + +auto XMakeBuilder::buildTarget(const fs::path& buildDir, + const std::string& target, + std::optional jobs) -> BuildResult { + BuildResult result; + std::string command = "xmake build " + target + " -C " + buildDir.string(); + if (jobs && *jobs > 0) { + command += " -j" + std::to_string(*jobs); + } + if (atom::system::executeCommandSimple(command)) { + result.success = true; + result.output = "Target " + target + " built successfully."; + } else { + result.success = false; + result.error = "Failed to build target: " + target; + } + return result; +} + +auto XMakeBuilder::getCacheVariables(const fs::path& buildDir) + -> std::vector> { + std::vector> cacheVariables; + std::string command = "xmake config --show -C " + buildDir.string(); + std::string output; + if (const auto& [output, status] = + atom::system::executeCommandWithStatus(command); + status == 0) { + // Parse the output to extract cache variables, assuming a key=value + // format + std::istringstream stream(output); + std::string line; + while (std::getline(stream, line)) { + auto pos = line.find('='); + if (pos != std::string::npos) { + std::string key = line.substr(0, pos); + std::string value = line.substr(pos + 1); + cacheVariables.emplace_back(key, value); + } + } + } else { + LOG_F(ERROR, "Failed to retrieve cache variables."); + } + return cacheVariables; +} + +auto XMakeBuilder::setCacheVariable(const fs::path& buildDir, + const std::string& name, + const std::string& value) -> bool { + std::string command = + "xmake config " + name + "=" + value + " -C " + buildDir.string(); + if (atom::system::executeCommandSimple(command)) { + LOG_F(INFO, "Cache variable {} set 
to {}.", name, value); + return true; + } else { + LOG_F(ERROR, "Failed to set cache variable: {}", name); + return false; + } +} + } // namespace lithium diff --git a/src/addon/platform/xmake.hpp b/src/addon/platform/xmake.hpp index 3d2170e0..fc3400b1 100644 --- a/src/addon/platform/xmake.hpp +++ b/src/addon/platform/xmake.hpp @@ -1,36 +1,72 @@ #ifndef LITHIUM_ADDON_XMAKEBUILDER_HPP #define LITHIUM_ADDON_XMAKEBUILDER_HPP +#include +#include #include +#include #include #include #include "base.hpp" namespace lithium { + class XMakeBuilderImpl; + class XMakeBuilder : public BuildSystem { public: XMakeBuilder(); ~XMakeBuilder() override; auto configureProject( - const std::string &sourceDir, const std::string &buildDir, - const std::string &buildType, - const std::vector &options) -> bool override; - auto buildProject(const std::string &buildDir, int jobs) -> bool override; - auto cleanProject(const std::string &buildDir) -> bool override; - auto installProject(const std::string &buildDir, - const std::string &installDir) -> bool override; - auto runTests(const std::string &buildDir) -> bool override; - auto generateDocs(const std::string &buildDir) -> bool override; - auto loadConfig(const std::string &configPath) -> bool override; + const std::filesystem::path& sourceDir, + const std::filesystem::path& buildDir, BuildType buildType, + const std::vector& options) -> BuildResult override; + + auto buildProject(const std::filesystem::path& buildDir, + std::optional jobs = std::nullopt) + -> BuildResult override; + + auto cleanProject(const std::filesystem::path& buildDir) + -> BuildResult override; + + auto installProject(const std::filesystem::path& buildDir, + const std::filesystem::path& installDir) + -> BuildResult override; + + auto runTests(const std::filesystem::path& buildDir, + const std::vector& testNames = {}) + -> BuildResult override; + + auto generateDocs(const std::filesystem::path& buildDir, + const std::filesystem::path& outputDir) + -> BuildResult override; + + auto loadConfig(const std::filesystem::path& configPath) -> bool override; + + auto setLogCallback(std::function callback) + -> void override; + + auto getAvailableTargets(const std::filesystem::path& buildDir) + -> std::vector override; + + auto buildTarget( + const std::filesystem::path& buildDir, const std::string& target, + std::optional jobs = std::nullopt) -> BuildResult override; + + auto getCacheVariables(const std::filesystem::path& buildDir) + -> std::vector> override; + + auto setCacheVariable(const std::filesystem::path& buildDir, + const std::string& name, + const std::string& value) -> bool override; private: std::unique_ptr pImpl_; - auto checkAndInstallDependencies() -> bool; }; + } // namespace lithium #endif // LITHIUM_ADDON_XMAKEBUILDER_HPP diff --git a/src/addon/project/base.hpp b/src/addon/project/base.hpp index 8dc55595..645fca45 100644 --- a/src/addon/project/base.hpp +++ b/src/addon/project/base.hpp @@ -1,10 +1,20 @@ #ifndef LITHIUM_ADDON_VCSMANAGER_HPP #define LITHIUM_ADDON_VCSMANAGER_HPP +#include +#include #include #include namespace lithium { + +struct CommitInfo { + std::string id; + std::string author; + std::string message; + std::chrono::system_clock::time_point timestamp; +}; + class VcsManager { public: virtual ~VcsManager() = default; @@ -20,7 +30,14 @@ class VcsManager { const std::string& branchName) -> bool = 0; virtual auto push(const std::string& remoteName, const std::string& branchName) -> bool = 0; - virtual auto getLog(int limit = 10) -> std::vector = 0; + virtual auto 
getLog(int limit = 10) -> std::vector = 0; + virtual auto getCurrentBranch() -> std::optional = 0; + virtual auto getBranches() -> std::vector = 0; + virtual auto getStatus() + -> std::vector> = 0; + virtual auto revertCommit(const std::string& commitId) -> bool = 0; + virtual auto createTag(const std::string& tagName, + const std::string& message) -> bool = 0; }; } // namespace lithium diff --git a/src/addon/project/git_impl.cpp b/src/addon/project/git_impl.cpp index 45a57bd0..98ad839b 100644 --- a/src/addon/project/git_impl.cpp +++ b/src/addon/project/git_impl.cpp @@ -7,14 +7,19 @@ namespace lithium { GitManager::Impl::Impl(const std::string& repoPath) : repoPath(repoPath), repo(nullptr) { + LOG_F(INFO, "Initializing GitManager for repository path: {}", repoPath); git_libgit2_init(); + LOG_F(INFO, "libgit2 initialized."); } GitManager::Impl::~Impl() { + LOG_F(INFO, "Shutting down GitManager."); if (repo != nullptr) { git_repository_free(repo); + LOG_F(INFO, "Repository freed."); } git_libgit2_shutdown(); + LOG_F(INFO, "libgit2 shutdown completed."); } void GitManager::Impl::printError(int error) { @@ -27,24 +32,29 @@ void GitManager::Impl::printError(int error) { } auto GitManager::Impl::initRepository() -> bool { - int error = git_repository_init(&repo, repoPath.c_str(), 0); - if (error < 0) { + LOG_F(INFO, "Initializing repository at: {}", repoPath); + if (int error = git_repository_init(&repo, repoPath.c_str(), 0); + error < 0) { printError(error); return false; } + LOG_F(INFO, "Repository successfully initialized."); return true; } auto GitManager::Impl::cloneRepository(const std::string& url) -> bool { - int error = git_clone(&repo, url.c_str(), repoPath.c_str(), nullptr); - if (error < 0) { + LOG_F(INFO, "Cloning repository from URL: {} to path: {}", url, repoPath); + if (int error = git_clone(&repo, url.c_str(), repoPath.c_str(), nullptr); + error < 0) { printError(error); return false; } + LOG_F(INFO, "Repository successfully cloned."); return true; } auto GitManager::Impl::createBranch(const std::string& branchName) -> bool { + LOG_F(INFO, "Creating new branch: {}", branchName); git_reference* newBranchRef = nullptr; git_oid commitOid; @@ -68,10 +78,12 @@ auto GitManager::Impl::createBranch(const std::string& branchName) -> bool { printError(error); return false; } + LOG_F(INFO, "Branch {} successfully created.", branchName); return true; } auto GitManager::Impl::checkoutBranch(const std::string& branchName) -> bool { + LOG_F(INFO, "Checking out branch: {}", branchName); git_object* treeish = nullptr; int error = git_revparse_single(&treeish, repo, ("refs/heads/" + branchName).c_str()); @@ -92,6 +104,7 @@ auto GitManager::Impl::checkoutBranch(const std::string& branchName) -> bool { printError(error); return false; } + LOG_F(INFO, "Branch {} checked out.", branchName); return true; } @@ -151,6 +164,7 @@ auto GitManager::Impl::mergeBranch(const std::string& branchName) -> bool { return false; } } else { + LOG_F(INFO, "Performing non-fast-forward merge."); // Perform a non-fast-forward merge git_index* index; error = git_merge(repo, (const git_annotated_commit**)&branchCommit, 1, @@ -243,7 +257,7 @@ auto GitManager::Impl::mergeBranch(const std::string& branchName) -> bool { return false; } } - + LOG_F(INFO, "Merge of branch {} completed successfully.", branchName); return true; } diff --git a/src/addon/project/svn.cpp b/src/addon/project/svn.cpp index 9ded56c3..722063a6 100644 --- a/src/addon/project/svn.cpp +++ b/src/addon/project/svn.cpp @@ -1,14 +1,29 @@ #include "svn.hpp" 
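// Illustrative note, not part of this diff: with SvnManager now deriving from
// VcsManager (see svn.hpp below), callers can hold either backend behind the
// common interface. A hypothetical factory, assuming GitManager implements the
// same VcsManager interface (its header is not shown in this hunk), might read:
//
//     std::unique_ptr<lithium::VcsManager> makeVcs(const std::string& kind,
//                                                  const std::string& repoPath) {
//         if (kind == "svn") {
//             return std::make_unique<lithium::SvnManager>(repoPath);
//         }
//         return std::make_unique<lithium::GitManager>(repoPath);
//     }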
-#include #include "svn_impl.hpp" namespace lithium { + SvnManager::SvnManager(const std::string& repoPath) : impl_(std::make_unique(repoPath)) {} + SvnManager::~SvnManager() = default; -bool SvnManager::checkout(const std::string& url, const std::string& revision) { - return impl_->checkout(url, revision); +bool SvnManager::initRepository() { return impl_->initRepository(); } + +bool SvnManager::cloneRepository(const std::string& url) { + return impl_->cloneRepository(url); +} + +bool SvnManager::createBranch(const std::string& branchName) { + return impl_->createBranch(branchName); +} + +bool SvnManager::checkoutBranch(const std::string& branchName) { + return impl_->checkoutBranch(branchName); +} + +bool SvnManager::mergeBranch(const std::string& branchName) { + return impl_->mergeBranch(branchName); } bool SvnManager::addFile(const std::string& filePath) { @@ -19,18 +34,39 @@ bool SvnManager::commitChanges(const std::string& message) { return impl_->commitChanges(message); } -bool SvnManager::update() { return impl_->update(); } - -bool SvnManager::createBranch(const std::string& branchName) { - return impl_->createBranch(branchName); +bool SvnManager::pull(const std::string& remoteName, + const std::string& branchName) { + return impl_->pull(remoteName, branchName); } -bool SvnManager::mergeBranch(const std::string& branchName) { - return impl_->mergeBranch(branchName); +bool SvnManager::push(const std::string& remoteName, + const std::string& branchName) { + return impl_->push(remoteName, branchName); } -std::vector SvnManager::getLog(int limit) { +std::vector SvnManager::getLog(int limit) { return impl_->getLog(limit); } +std::optional SvnManager::getCurrentBranch() { + return impl_->getCurrentBranch(); +} + +std::vector SvnManager::getBranches() { + return impl_->getBranches(); +} + +std::vector> SvnManager::getStatus() { + return impl_->getStatus(); +} + +bool SvnManager::revertCommit(const std::string& commitId) { + return impl_->revertCommit(commitId); +} + +bool SvnManager::createTag(const std::string& tagName, + const std::string& message) { + return impl_->createTag(tagName, message); +} + } // namespace lithium diff --git a/src/addon/project/svn.hpp b/src/addon/project/svn.hpp index 83dc6140..f7c95c1a 100644 --- a/src/addon/project/svn.hpp +++ b/src/addon/project/svn.hpp @@ -4,20 +4,32 @@ #include #include #include +#include "base.hpp" namespace lithium { -class SvnManager { +class SvnManager : public VcsManager { public: SvnManager(const std::string& repoPath); - ~SvnManager(); + ~SvnManager() override; - auto checkout(const std::string& url, const std::string& revision) -> bool; - auto addFile(const std::string& filePath) -> bool; - auto commitChanges(const std::string& message) -> bool; - auto update() -> bool; - auto createBranch(const std::string& branchName) -> bool; - auto mergeBranch(const std::string& branchName) -> bool; - auto getLog(int limit = 10) -> std::vector; + bool initRepository() override; + bool cloneRepository(const std::string& url) override; + bool createBranch(const std::string& branchName) override; + bool checkoutBranch(const std::string& branchName) override; + bool mergeBranch(const std::string& branchName) override; + bool addFile(const std::string& filePath) override; + bool commitChanges(const std::string& message) override; + bool pull(const std::string& remoteName, + const std::string& branchName) override; + bool push(const std::string& remoteName, + const std::string& branchName) override; + std::vector getLog(int limit = 10) override; + 
std::optional getCurrentBranch() override; + std::vector getBranches() override; + std::vector> getStatus() override; + bool revertCommit(const std::string& commitId) override; + bool createTag(const std::string& tagName, + const std::string& message) override; private: class Impl; diff --git a/src/addon/project/svn_impl.cpp b/src/addon/project/svn_impl.cpp index a31fd6d6..b16b80e0 100644 --- a/src/addon/project/svn_impl.cpp +++ b/src/addon/project/svn_impl.cpp @@ -21,6 +21,19 @@ SvnManager::Impl::~Impl() { apr_terminate(); } +bool SvnManager::Impl::initRepository() { + svn_error_t* err = svn_client_create_context2(&ctx, nullptr, pool); + if (err) { + printError(err); + return false; + } + return true; +} + +bool SvnManager::Impl::cloneRepository(const std::string& url) { + return checkout(url, "HEAD"); +} + void SvnManager::Impl::printError(svn_error_t* err) { if (err) { char buf[1024]; @@ -154,4 +167,197 @@ std::vector SvnManager::Impl::getLog(int limit) { return logMessages; } +bool SvnManager::Impl::checkoutBranch(const std::string& branchName) { + std::string branchUrl = repoPath + "/branches/" + branchName; + return checkout(branchUrl, "HEAD"); +} + +bool SvnManager::Impl::pull(const std::string& remoteName, + const std::string& branchName) { + // SVN doesn't have a direct equivalent to Git's pull, so we'll just update + return update(); +} + +bool SvnManager::Impl::push(const std::string& remoteName, + const std::string& branchName) { + // SVN doesn't have a direct equivalent to Git's push, as changes are + // immediately visible after commit + return true; +} + +std::vector SvnManager::Impl::getLog(int limit) { + std::vector logEntries; + + struct LogReceiverBaton { + std::vector* logEntries; + int remaining; + }; + + LogReceiverBaton baton = {&logEntries, limit}; + + auto log_receiver = [](void* baton, svn_log_entry_t* log_entry, + apr_pool_t* pool) -> svn_error_t* { + auto* b = static_cast(baton); + if (b->remaining-- <= 0) + return SVN_NO_ERROR; + + CommitInfo info; + info.id = std::to_string(log_entry->revision); + info.author = svn_prop_get_value(log_entry->revprops, "svn:author") + ? svn_string_value(svn_prop_get_value( + log_entry->revprops, "svn:author")) + : ""; + info.message = svn_prop_get_value(log_entry->revprops, "svn:log") + ? 
svn_string_value(svn_prop_get_value( + log_entry->revprops, "svn:log")) + : ""; + + apr_time_t commit_time = 0; + if (svn_prop_get_value(log_entry->revprops, "svn:date")) { + svn_time_from_cstring(&commit_time, + svn_string_value(svn_prop_get_value( + log_entry->revprops, "svn:date")), + pool); + } + info.timestamp = + std::chrono::system_clock::from_time_t(apr_time_sec(commit_time)); + + b->logEntries->push_back(info); + return SVN_NO_ERROR; + }; + + svn_opt_revision_t start = {svn_opt_revision_head}; + svn_opt_revision_t end = {svn_opt_revision_number, 0}; + + svn_error_t* err = svn_client_log5( + repoPath.c_str(), nullptr, &start, &end, limit, TRUE, TRUE, FALSE, + apr_array_make(pool, 1, sizeof(const char*)), log_receiver, &baton, ctx, + pool); + + if (err) { + printError(err); + return {}; + } + + return logEntries; +} + +std::optional SvnManager::Impl::getCurrentBranch() { + // SVN doesn't have a direct equivalent to Git's current branch concept + // We can return the current working copy root URL instead + const char* url; + svn_error_t* err = + svn_client_get_repos_root(&url, nullptr, repoPath.c_str(), ctx, pool); + if (err) { + printError(err); + return std::nullopt; + } + return std::string(url); +} + +std::vector SvnManager::Impl::getBranches() { + std::vector branches; + const char* url; + svn_error_t* err = + svn_client_get_repos_root(&url, nullptr, repoPath.c_str(), ctx, pool); + if (err) { + printError(err); + return branches; + } + + std::string branchesUrl = std::string(url) + "/branches"; + svn_opt_revision_t revision = {svn_opt_revision_head}; + + auto list_receiver = [](void* baton, const char* path, + const svn_dirent_t* dirent, const svn_lock_t* lock, + const char* abs_path, + apr_pool_t* pool) -> svn_error_t* { + auto* branches = static_cast*>(baton); + branches->push_back(svn_path_basename(path, pool)); + return SVN_NO_ERROR; + }; + + err = svn_client_list3(branchesUrl.c_str(), &revision, &revision, + svn_depth_immediates, SVN_DIRENT_ALL, FALSE, + list_receiver, &branches, ctx, pool); + if (err) { + printError(err); + } + + return branches; +} + +std::vector> SvnManager::Impl::getStatus() { + std::vector> status; + + auto status_receiver = [](void* baton, const char* path, + const svn_client_status_t* status, + apr_pool_t* pool) -> svn_error_t* { + auto* statusVec = + static_cast>*>( + baton); + std::string statusStr; + switch (status->node_status) { + case svn_wc_status_added: + statusStr = "Added"; + break; + case svn_wc_status_deleted: + statusStr = "Deleted"; + break; + case svn_wc_status_modified: + statusStr = "Modified"; + break; + case svn_wc_status_replaced: + statusStr = "Replaced"; + break; + case svn_wc_status_unversioned: + statusStr = "Unversioned"; + break; + default: + statusStr = "Unknown"; + } + statusVec->emplace_back(path, statusStr); + return SVN_NO_ERROR; + }; + + svn_opt_revision_t revision = {svn_opt_revision_working}; + svn_error_t* err = svn_client_status6( + nullptr, ctx, repoPath.c_str(), &revision, svn_depth_infinity, TRUE, + TRUE, TRUE, TRUE, TRUE, TRUE, nullptr, status_receiver, &status, pool); + if (err) { + printError(err); + } + + return status; +} + +bool SvnManager::Impl::revertCommit(const std::string& commitId) { + svn_revnum_t revision = std::stol(commitId); + svn_opt_revision_t rev = {svn_opt_revision_number, {revision}}; + svn_error_t* err = + svn_client_merge_peg5(repoPath.c_str(), nullptr, &rev, &rev, + repoPath.c_str(), svn_depth_infinity, TRUE, FALSE, + FALSE, FALSE, FALSE, nullptr, ctx, pool); + if (err) { + 
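// Hypothetical caller-side sketch, not part of this diff, for the getStatus()
// mapping implemented above, which yields (path, state) pairs such as
// ("src/main.cpp", "Modified"); the repository contents are placeholders:
//
//     for (const auto& [path, state] : svn.getStatus()) {
//         LOG_F(INFO, "{} {}", state, path);
//     }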
printError(err); + return false; + } + return true; +} + +bool SvnManager::Impl::createTag(const std::string& tagName, + const std::string& message) { + std::string tagUrl = repoPath + "/tags/" + tagName; + svn_opt_revision_t revision = {svn_opt_revision_head}; + + svn_error_t* err = svn_client_copy6( + nullptr, repoPath.c_str(), &revision, tagUrl.c_str(), FALSE, TRUE, + FALSE, nullptr, nullptr, nullptr, message.c_str(), ctx, pool); + if (err) { + printError(err); + return false; + } + return true; +} + } // namespace lithium diff --git a/src/addon/project/svn_impl.hpp b/src/addon/project/svn_impl.hpp index 36bbdd91..bb654cb3 100644 --- a/src/addon/project/svn_impl.hpp +++ b/src/addon/project/svn_impl.hpp @@ -1,11 +1,7 @@ #ifndef LITHIUM_ADDON_SVNMANAGERIMPL_HPP #define LITHIUM_ADDON_SVNMANAGERIMPL_HPP -#include -#include - #include "svn.hpp" - #include namespace lithium { @@ -14,13 +10,21 @@ class SvnManager::Impl { Impl(const std::string& repoPath); ~Impl(); - bool checkout(const std::string& url, const std::string& revision); - bool addFile(const std::string& filePath); - bool commitChanges(const std::string& message); - bool update(); + bool initRepository(); + bool cloneRepository(const std::string& url); bool createBranch(const std::string& branchName); + bool checkoutBranch(const std::string& branchName); bool mergeBranch(const std::string& branchName); - std::vector getLog(int limit); + bool addFile(const std::string& filePath); + bool commitChanges(const std::string& message); + bool pull(const std::string& remoteName, const std::string& branchName); + bool push(const std::string& remoteName, const std::string& branchName); + std::vector getLog(int limit); + std::optional getCurrentBranch(); + std::vector getBranches(); + std::vector> getStatus(); + bool revertCommit(const std::string& commitId); + bool createTag(const std::string& tagName, const std::string& message); private: std::string repoPath; diff --git a/src/addon/template/remote.cpp b/src/addon/template/remote.cpp index 0fb18328..45ab37bc 100644 --- a/src/addon/template/remote.cpp +++ b/src/addon/template/remote.cpp @@ -1,7 +1,9 @@ #include "remote.hpp" #include - +#include +#include +#include #include #include #include @@ -9,23 +11,32 @@ #include "atom/log/loguru.hpp" using asio::ip::tcp; +using asio::ip::udp; class RemoteStandAloneComponentImpl { public: std::string driverName; std::atomic shouldExit{false}; std::jthread driverThread; - std::optional socket; - std::optional endpoint; + std::variant, std::optional> socket; + std::optional> tcpEndpoint; + std::optional> udpEndpoint; asio::io_context ioContext; asio::steady_timer heartbeatTimer{ioContext}; std::atomic isListening{false}; std::function onMessageReceived; std::function onDisconnected; std::function onConnected; - int heartbeatInterval{0}; + std::chrono::milliseconds heartbeatInterval{0}; std::string heartbeatMessage; std::atomic heartbeatEnabled{false}; + ProtocolType protocol{ProtocolType::TCP}; + + // Reconnection strategy + std::chrono::milliseconds initialReconnectDelay{1000}; + std::chrono::milliseconds maxReconnectDelay{30000}; + int maxReconnectAttempts{5}; + int currentReconnectAttempts{0}; void handleDriverOutput(std::string_view buffer) { if (onMessageReceived) { @@ -40,15 +51,17 @@ RemoteStandAloneComponent::RemoteStandAloneComponent(std::string name) : Component(std::move(name)), impl_(std::make_unique()) { doc("A remote standalone component that can connect to a remote driver via " - "TCP"); + "TCP or UDP"); def("connect", 
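// Hypothetical usage sketch, not part of this diff, for the component API
// registered here; the component name, address, port and intervals are
// placeholders:
//
//     RemoteStandAloneComponent comp("remote_driver");
//     comp.setReconnectionStrategy(std::chrono::milliseconds(500),
//                                  std::chrono::seconds(10), /*maxAttempts=*/3);
//     comp.connectToRemoteDriver("192.168.1.50", 5555, ProtocolType::UDP,
//                                std::chrono::seconds(2));
//     comp.enableHeartbeat(std::chrono::seconds(5), "ping");
//     comp.sendMessageToDriver(std::string("status?"));
//
// With that strategy, attemptReconnection() backs off exponentially:
// min(500 ms * 2^n, 10 s) gives 500 ms, 1 s, 2 s before giving up after the
// third attempt.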
&RemoteStandAloneComponent::connectToRemoteDriver); def("disconnect", &RemoteStandAloneComponent::disconnectRemoteDriver); - def("send", &RemoteStandAloneComponent::sendMessageToDriver); - def("send_async", &RemoteStandAloneComponent::sendMessageAsync); + def("send", &RemoteStandAloneComponent::sendMessageToDriver); + def("send_async", + &RemoteStandAloneComponent::sendMessageAsync); def("listen", &RemoteStandAloneComponent::toggleDriverListening); def("print", &RemoteStandAloneComponent::printDriver); def("heartbeat_on", &RemoteStandAloneComponent::enableHeartbeat); def("heartbeat_off", &RemoteStandAloneComponent::disableHeartbeat); + def("execute", &RemoteStandAloneComponent::executeCommand); } RemoteStandAloneComponent::~RemoteStandAloneComponent() { @@ -61,43 +74,65 @@ RemoteStandAloneComponent::~RemoteStandAloneComponent() { } void RemoteStandAloneComponent::connectToRemoteDriver( - const std::string& address, uint16_t port, std::optional timeout) { + const std::string& address, uint16_t port, ProtocolType protocol, + std::chrono::milliseconds timeout) { + impl_->protocol = protocol; try { - tcp::resolver resolver(impl_->ioContext); - auto endpoints = resolver.resolve(address, std::to_string(port)); - impl_->socket.emplace(impl_->ioContext); - - if (timeout) { - asio::steady_timer timer(impl_->ioContext); - timer.expires_after(std::chrono::milliseconds(*timeout)); - asio::error_code ec = asio::error::operation_aborted; - - impl_->socket->async_connect(*endpoints, - [&](const asio::error_code& error) { - ec = error; - timer.cancel(); - }); - - timer.async_wait([&](const asio::error_code& error) { - if (!error) { - impl_->socket->cancel(); - } - }); + switch (protocol) { + case ProtocolType::TCP: { + tcp::resolver resolver(impl_->ioContext); + auto endpoints = + resolver.resolve(address, std::to_string(port)); + impl_->socket.emplace>( + impl_->ioContext); + auto& tcpSocket = + std::get>(impl_->socket).value(); + + asio::steady_timer timer(impl_->ioContext); + timer.expires_after(timeout); + asio::error_code ec = asio::error::operation_aborted; + + tcpSocket.async_connect(*endpoints, + [&](const asio::error_code& error) { + ec = error; + timer.cancel(); + }); + + timer.async_wait([&](const asio::error_code& error) { + if (!error) { + tcpSocket.cancel(); + } + }); + + impl_->ioContext.run_one(); - impl_->ioContext.run_one(); + if (ec) { + throw asio::system_error(ec); + } - if (ec) { - throw asio::system_error(ec); + impl_->tcpEndpoint = *endpoints; + break; + } + case ProtocolType::UDP: { + udp::resolver resolver(impl_->ioContext); + auto endpoint = + *resolver.resolve(udp::v4(), address, std::to_string(port)) + .begin(); + impl_->socket.emplace>( + impl_->ioContext, udp::endpoint(udp::v4(), 0)); + auto& udpSocket = + std::get>(impl_->socket).value(); + udpSocket.connect(endpoint); + impl_->udpEndpoint = endpoint; + break; } - } else { - asio::connect(*impl_->socket, endpoints); } - impl_->endpoint = *endpoints; if (impl_->onConnected) impl_->onConnected(); - LOG_F(INFO, "Connected to remote driver at {}:{}", address, port); + LOG_F(INFO, "Connected to remote driver at {}:{} using {}", address, + port, protocol == ProtocolType::TCP ? 
"TCP" : "UDP"); impl_->driverThread = std::jthread( &RemoteStandAloneComponent::backgroundProcessing, this); @@ -109,37 +144,60 @@ void RemoteStandAloneComponent::connectToRemoteDriver( } void RemoteStandAloneComponent::disconnectRemoteDriver() { - if (impl_->socket && impl_->socket->is_open()) { - asio::error_code ec; - impl_->socket->shutdown(tcp::socket::shutdown_both, ec); - impl_->socket->close(ec); - if (ec) { - LOG_F(ERROR, "Error closing connection: {}", ec.message()); - } else { - LOG_F(INFO, "Disconnected from remote driver"); - if (impl_->onDisconnected) - impl_->onDisconnected(); - } - } + std::visit( + [](auto&& socket) { + if (socket && socket->is_open()) { + asio::error_code ec; + socket->shutdown(std::decay_t::shutdown_both, + ec); + socket->close(ec); + } + }, + impl_->socket); + + LOG_F(INFO, "Disconnected from remote driver"); + if (impl_->onDisconnected) + impl_->onDisconnected(); + impl_->shouldExit = true; } -void RemoteStandAloneComponent::sendMessageToDriver(std::string_view message) { - if (impl_->socket && impl_->socket->is_open()) { - asio::write(*impl_->socket, asio::buffer(message)); - } else { - LOG_F(ERROR, "No active connection to send message"); - } +template +void RemoteStandAloneComponent::sendMessageToDriver(T&& message) { + std::visit( + [&](auto&& socket) { + if (socket && socket->is_open()) { + asio::write(*socket, asio::buffer(std::forward(message))); + } else { + LOG_F(ERROR, "No active connection to send message"); + } + }, + impl_->socket); } -void RemoteStandAloneComponent::sendMessageAsync( - std::string_view message, - std::function callback) { - if (impl_->socket && impl_->socket->is_open()) { - asio::async_write(*impl_->socket, asio::buffer(message), callback); - } else { - LOG_F(ERROR, "No active connection to send message"); - } +template +std::future> +RemoteStandAloneComponent::sendMessageAsync(T&& message) { + auto promise = std::make_shared< + std::promise>>(); + auto future = promise->get_future(); + + std::visit( + [&](auto&& socket) { + if (socket && socket->is_open()) { + asio::async_write( + *socket, asio::buffer(std::forward(message)), + [promise](const asio::error_code& ec, + std::size_t bytes_transferred) { + promise->set_value({ec, bytes_transferred}); + }); + } else { + promise->set_value({asio::error::not_connected, 0}); + } + }, + impl_->socket); + + return future; } void RemoteStandAloneComponent::setOnMessageReceivedCallback( @@ -157,9 +215,9 @@ void RemoteStandAloneComponent::setOnConnectedCallback( impl_->onConnected = std::move(callback); } -void RemoteStandAloneComponent::enableHeartbeat(int interval_ms, - std::string_view pingMessage) { - impl_->heartbeatInterval = interval_ms; +void RemoteStandAloneComponent::enableHeartbeat( + std::chrono::milliseconds interval, std::string_view pingMessage) { + impl_->heartbeatInterval = interval; impl_->heartbeatMessage = pingMessage; impl_->heartbeatEnabled = true; startHeartbeat(); @@ -171,9 +229,14 @@ void RemoteStandAloneComponent::disableHeartbeat() { } void RemoteStandAloneComponent::printDriver() const { - if (impl_->endpoint) { - LOG_F(INFO, "Remote Driver: {}:{}", - impl_->endpoint->address().to_string(), impl_->endpoint->port()); + if (impl_->tcpEndpoint) { + LOG_F(INFO, "Remote Driver (TCP): {}:{}", + impl_->tcpEndpoint->address().to_string(), + impl_->tcpEndpoint->port()); + } else if (impl_->udpEndpoint) { + LOG_F(INFO, "Remote Driver (UDP): {}:{}", + impl_->udpEndpoint->address().to_string(), + impl_->udpEndpoint->port()); } else { LOG_F(INFO, "No remote driver 
connected"); } @@ -185,23 +248,65 @@ void RemoteStandAloneComponent::toggleDriverListening() { impl_->isListening ? "ON" : "OFF"); } -void RemoteStandAloneComponent::executeCommand( - std::string_view command, std::function callback) { - sendMessageAsync(command, [this, callback](std::error_code ec, - std::size_t) { - if (!ec) { - std::array buffer; - asio::error_code error; - size_t len = impl_->socket->read_some(asio::buffer(buffer), error); - if (!error) { - callback(std::string_view(buffer.data(), len)); +template +std::future RemoteStandAloneComponent::executeCommand( + T&& command) { + auto promise = std::make_shared>(); + auto future = promise->get_future(); + + sendMessageAsync(std::forward(command)) + .then([this, promise](auto&& result) { + auto [ec, _] = result.get(); + if (!ec) { + std::array buffer; + std::size_t len = 0; + + std::visit( + [&](auto&& socket) { + if constexpr (std::is_same_v< + std::decay_t, + std::optional>) { + asio::error_code error; + len = + socket->read_some(asio::buffer(buffer), error); + if (error) { + promise->set_exception(std::make_exception_ptr( + std::runtime_error(error.message()))); + return; + } + } else if constexpr (std::is_same_v< + std::decay_t, + std::optional>) { + udp::endpoint sender_endpoint; + asio::error_code error; + len = + socket->receive_from(asio::buffer(buffer), + sender_endpoint, 0, error); + if (error) { + promise->set_exception(std::make_exception_ptr( + std::runtime_error(error.message()))); + return; + } + } + }, + impl_->socket); + + promise->set_value(std::string(buffer.data(), len)); } else { - LOG_F(ERROR, "Command execution failed: {}", error.message()); + promise->set_exception( + std::make_exception_ptr(std::runtime_error(ec.message()))); } - } else { - LOG_F(ERROR, "Failed to send command: {}", ec.message()); - } - }); + }); + + return future; +} + +void RemoteStandAloneComponent::setReconnectionStrategy( + std::chrono::milliseconds initialDelay, std::chrono::milliseconds maxDelay, + int maxAttempts) { + impl_->initialReconnectDelay = initialDelay; + impl_->maxReconnectDelay = maxDelay; + impl_->maxReconnectAttempts = maxAttempts; } void RemoteStandAloneComponent::backgroundProcessing() { @@ -213,50 +318,80 @@ void RemoteStandAloneComponent::backgroundProcessing() { } void RemoteStandAloneComponent::monitorConnection() { - if (impl_->socket && impl_->socket->is_open()) { - // 可以在这里添加更多逻辑来监控连接状态,比如检查连接的活跃性。 - } else if (!impl_->shouldExit) { + bool isConnected = + std::visit([](auto&& socket) { return socket && socket->is_open(); }, + impl_->socket); + + if (!isConnected && !impl_->shouldExit) { LOG_F(INFO, "Connection lost. 
Attempting to reconnect..."); - disconnectRemoteDriver(); - // 根据需要添加自动重连的逻辑。 + attemptReconnection(); } } void RemoteStandAloneComponent::processMessages() { - if (impl_->socket && impl_->socket->is_open() && impl_->isListening) { - std::array buffer; - asio::error_code error; - size_t len = impl_->socket->read_some(asio::buffer(buffer), error); - - if (error == asio::error::eof) { - LOG_F(INFO, "Connection closed by remote driver"); - disconnectRemoteDriver(); - if (impl_->onDisconnected) { - impl_->onDisconnected(); + if (!impl_->isListening) + return; + + std::visit( + [this](auto&& socket) { + if constexpr (std::is_same_v, + std::optional>) { + if (socket && socket->is_open()) { + std::array buffer; + socket->async_read_some( + asio::buffer(buffer), + [this, buffer](const asio::error_code& error, + std::size_t bytes_transferred) { + if (!error) { + impl_->handleDriverOutput(std::string_view( + buffer.data(), bytes_transferred)); + } else if (error == asio::error::eof) { + LOG_F(INFO, + "Connection closed by remote driver"); + disconnectRemoteDriver(); + } else { + LOG_F(ERROR, "Read error: {}", error.message()); + } + }); + } + } else if constexpr (std::is_same_v, + std::optional>) { + if (socket && socket->is_open()) { + std::array buffer; + udp::endpoint sender_endpoint; + socket->async_receive_from( + asio::buffer(buffer), sender_endpoint, + [this, buffer](const asio::error_code& error, + std::size_t bytes_transferred) { + if (!error) { + impl_->handleDriverOutput(std::string_view( + buffer.data(), bytes_transferred)); + } else { + LOG_F(ERROR, "Read error: {}", error.message()); + } + }); + } } - } else if (error) { - LOG_F(ERROR, "Read error: {}", error.message()); - } else { - impl_->handleDriverOutput(std::string_view(buffer.data(), len)); - } - } + }, + impl_->socket); } void RemoteStandAloneComponent::startHeartbeat() { if (!impl_->heartbeatEnabled) return; - impl_->heartbeatTimer.expires_after( - std::chrono::milliseconds(impl_->heartbeatInterval)); + impl_->heartbeatTimer.expires_after(impl_->heartbeatInterval); impl_->heartbeatTimer.async_wait([this](const asio::error_code& error) { if (!error && impl_->heartbeatEnabled) { - sendMessageAsync(impl_->heartbeatMessage, [](std::error_code ec, - std::size_t) { + sendMessageAsync(impl_->heartbeatMessage).then([this](auto future) { + auto [ec, _] = future.get(); if (ec) { LOG_F(ERROR, "Failed to send heartbeat: {}", ec.message()); + attemptReconnection(); + } else { + startHeartbeat(); } }); - startHeartbeat(); } }); } @@ -265,3 +400,50 @@ void RemoteStandAloneComponent::stopHeartbeat() { impl_->heartbeatEnabled = false; impl_->heartbeatTimer.cancel(); } + +void RemoteStandAloneComponent::attemptReconnection() { + if (impl_->currentReconnectAttempts >= impl_->maxReconnectAttempts) { + LOG_F(ERROR, "Max reconnection attempts reached. 
Giving up."); + return; + } + + std::chrono::milliseconds delay = std::min( + impl_->initialReconnectDelay * (1 << impl_->currentReconnectAttempts), + impl_->maxReconnectDelay); + + LOG_F(INFO, "Attempting to reconnect in {} ms", delay.count()); + + std::this_thread::sleep_for(delay); + + std::visit( + [this](auto&& socket) { + using SocketType = std::decay_t; + if constexpr (std::is_same_v) { + connectToRemoteDriver(impl_->tcpEndpoint->address().to_string(), + impl_->tcpEndpoint->port(), + ProtocolType::TCP); + } else if constexpr (std::is_same_v) { + connectToRemoteDriver(impl_->udpEndpoint->address().to_string(), + impl_->udpEndpoint->port(), + ProtocolType::UDP); + } + }, + impl_->socket); + + impl_->currentReconnectAttempts++; +} + +// Explicit template instantiations +template void RemoteStandAloneComponent::sendMessageToDriver( + std::string&&); +template void RemoteStandAloneComponent::sendMessageToDriver( + std::string_view&&); +template std::future> +RemoteStandAloneComponent::sendMessageAsync(std::string&&); +template std::future> +RemoteStandAloneComponent::sendMessageAsync( + std::string_view&&); +template std::future +RemoteStandAloneComponent::executeCommand(std::string&&); +template std::future +RemoteStandAloneComponent::executeCommand(std::string_view&&); diff --git a/src/addon/template/remote.hpp b/src/addon/template/remote.hpp index 8541106e..33ed89bd 100644 --- a/src/addon/template/remote.hpp +++ b/src/addon/template/remote.hpp @@ -1,13 +1,21 @@ #ifndef LITHIUM_ADDON_REMOTE_STANDALONE_HPP #define LITHIUM_ADDON_REMOTE_STANDALONE_HPP +#include +#include #include +#include #include #include #include #include "atom/components/component.hpp" +enum class ProtocolType { TCP, UDP }; + +template +concept Stringlike = std::is_convertible_v; + class RemoteStandAloneComponentImpl; class RemoteStandAloneComponent : public Component { @@ -15,16 +23,19 @@ class RemoteStandAloneComponent : public Component { explicit RemoteStandAloneComponent(std::string name); ~RemoteStandAloneComponent() override; - void connectToRemoteDriver(const std::string& address, uint16_t port, - std::optional timeout = std::nullopt); + void connectToRemoteDriver( + const std::string& address, uint16_t port, + ProtocolType protocol = ProtocolType::TCP, + std::chrono::milliseconds timeout = std::chrono::seconds(5)); void disconnectRemoteDriver(); - void sendMessageToDriver(std::string_view message); + template + void sendMessageToDriver(T&& message); - void sendMessageAsync( - std::string_view message, - std::function callback); + template + std::future> sendMessageAsync( + T&& message); void setOnMessageReceivedCallback( std::function callback); @@ -33,7 +44,8 @@ class RemoteStandAloneComponent : public Component { void setOnConnectedCallback(std::function callback); - void enableHeartbeat(int interval_ms, std::string_view pingMessage); + void enableHeartbeat(std::chrono::milliseconds interval, + std::string_view pingMessage); void disableHeartbeat(); @@ -41,19 +53,20 @@ class RemoteStandAloneComponent : public Component { void toggleDriverListening(); - void executeCommand(std::string_view command, - std::function callback); + template + std::future executeCommand(T&& command); + + void setReconnectionStrategy(std::chrono::milliseconds initialDelay, + std::chrono::milliseconds maxDelay, + int maxAttempts); private: void backgroundProcessing(); - void monitorConnection(); - void processMessages(); - void startHeartbeat(); - void stopHeartbeat(); + void attemptReconnection(); std::unique_ptr impl_; }; diff 
--git a/src/addon/template/standalone.cpp b/src/addon/template/standalone.cpp index 86ea7944..6bf70655 100644 --- a/src/addon/template/standalone.cpp +++ b/src/addon/template/standalone.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include #include @@ -20,25 +21,22 @@ #include #include #include +#include #include #include #endif struct LocalDriver { int processHandle{}; - int stdinFd{}; - int stdoutFd{}; + std::variant, std::pair> io; std::string name; bool isListening{}; + InteractionMethod method; }; -#if defined(_WIN32) || defined(_WIN64) -constexpr char SEM_NAME[] = "driver_semaphore"; -constexpr char SHM_NAME[] = "driver_shm"; -#else constexpr char SEM_NAME[] = "/driver_semaphore"; constexpr char SHM_NAME[] = "/driver_shm"; -#endif +constexpr char FIFO_NAME[] = "/tmp/driver_fifo"; class StandAloneComponentImpl { public: @@ -51,20 +49,6 @@ class StandAloneComponentImpl { LOG_F(INFO, "Output from driver {}: {}", driver_name, std::string_view(buffer.data(), buffer.size())); } - - void closePipes(int stdinPipe[2], int stdoutPipe[2]) { -#if defined(_WIN32) || defined(_WIN64) - _close(stdinPipe[0]); - _close(stdinPipe[1]); - _close(stdoutPipe[0]); - _close(stdoutPipe[1]); -#else - close(stdinPipe[0]); - close(stdinPipe[1]); - close(stdoutPipe[0]); - close(stdoutPipe[1]); -#endif - } }; StandAloneComponent::StandAloneComponent(std::string name) @@ -87,19 +71,40 @@ StandAloneComponent::~StandAloneComponent() { } } -void StandAloneComponent::startLocalDriver(const std::string& driver_name) { - int stdinPipe[2]; - int stdoutPipe[2]; - - if (!createPipes(stdinPipe, stdoutPipe)) { - return; +void StandAloneComponent::startLocalDriver(const std::string& driver_name, + InteractionMethod method) { + std::variant, std::pair> io; + + switch (method) { + case InteractionMethod::Pipe: + if (auto pipes = createPipes()) { + io = *pipes; + } else { + return; + } + break; + case InteractionMethod::FIFO: + if (auto fifo = createFIFO()) { + io = *fifo; + } else { + return; + } + break; + case InteractionMethod::SharedMemory: + if (auto shm = createSharedMemory()) { + io = *shm; + } else { + return; + } + break; } #if defined(_WIN32) || defined(_WIN64) - startWindowsProcess(driver_name, stdinPipe, stdoutPipe); + startWindowsProcess(driver_name, io); #else - startUnixProcess(driver_name, stdinPipe, stdoutPipe); + startUnixProcess(driver_name, io); #endif + impl_->driver.method = method; impl_->driverThread = std::jthread(&StandAloneComponent::backgroundProcessing, this); } @@ -112,20 +117,88 @@ void StandAloneComponent::backgroundProcessing() { } } -auto StandAloneComponent::createPipes(int stdinPipe[2], - int stdoutPipe[2]) -> bool { +auto StandAloneComponent::createPipes() -> std::optional> { + int stdinPipe[2], stdoutPipe[2]; + +#if defined(_WIN32) || defined(_WIN64) + if (_pipe(stdinPipe, 4096, _O_BINARY | _O_NOINHERIT) == 0 && + _pipe(stdoutPipe, 4096, _O_BINARY | _O_NOINHERIT) == 0) { + return std::make_pair(stdinPipe[1], stdoutPipe[0]); + } +#else + if (pipe(stdinPipe) == 0 && pipe(stdoutPipe) == 0) { + return std::make_pair(stdinPipe[1], stdoutPipe[0]); + } +#endif + LOG_F(ERROR, "Failed to create pipes"); + return std::nullopt; +} + +auto StandAloneComponent::createFIFO() -> std::optional> { #if defined(_WIN32) || defined(_WIN64) - return _pipe(stdinPipe, 4096, _O_BINARY | _O_NOINHERIT) == 0 && - _pipe(stdoutPipe, 4096, _O_BINARY | _O_NOINHERIT) == 0; + LOG_F(ERROR, "FIFO not supported on Windows"); + return std::nullopt; #else - return pipe(stdinPipe) == 0 && pipe(stdoutPipe) == 0; + if 
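// Hypothetical usage sketch, not part of this diff, for the updated
// startLocalDriver() above; the driver path and message are placeholders.
// Note that the FIFO and SharedMemory methods are POSIX-only in this
// implementation (both return std::nullopt on Windows):
//
//     StandAloneComponent comp("local_driver");
//     comp.startLocalDriver("/usr/local/bin/example_driver",
//                           InteractionMethod::Pipe);
//     comp.toggleDriverListening();
//     comp.sendMessageToDriver("get_properties");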
(mkfifo(FIFO_NAME, 0666) == -1 && errno != EEXIST) { + LOG_F(ERROR, "Failed to create FIFO"); + return std::nullopt; + } + + int readFd = open(FIFO_NAME, O_RDONLY | O_NONBLOCK); + int writeFd = open(FIFO_NAME, O_WRONLY); + + if (readFd == -1 || writeFd == -1) { + LOG_F(ERROR, "Failed to open FIFO"); + if (readFd != -1) + close(readFd); + if (writeFd != -1) + close(writeFd); + return std::nullopt; + } + + return std::make_pair(writeFd, readFd); +#endif +} + +auto StandAloneComponent::createSharedMemory() + -> std::optional> { +#if defined(_WIN32) || defined(_WIN64) + LOG_F(ERROR, "Shared memory not implemented for Windows"); + return std::nullopt; +#else + int shm_fd = shm_open(SHM_NAME, O_CREAT | O_RDWR, 0666); + if (shm_fd == -1) { + LOG_F(ERROR, "Failed to create shared memory"); + return std::nullopt; + } + + if (ftruncate(shm_fd, sizeof(int)) == -1) { + LOG_F(ERROR, "Failed to set size of shared memory"); + close(shm_fd); + shm_unlink(SHM_NAME); + return std::nullopt; + } + + int* shm_ptr = static_cast(mmap( + nullptr, sizeof(int), PROT_READ | PROT_WRITE, MAP_SHARED, shm_fd, 0)); + if (shm_ptr == MAP_FAILED) { + LOG_F(ERROR, "Failed to map shared memory"); + close(shm_fd); + shm_unlink(SHM_NAME); + return std::nullopt; + } + + *shm_ptr = 0; // Initialize shared memory to 0 + return std::make_pair(shm_fd, shm_ptr); #endif } #if defined(_WIN32) || defined(_WIN64) -void StandAloneComponent::startWindowsProcess(const std::string& driver_name, - int stdinPipe[2], - int stdoutPipe[2]) { +void StandAloneComponent::startWindowsProcess( + const std::string& driver_name, + std::variant, std::pair> io) { + auto [inHandle, outHandle] = std::get>(io); + SECURITY_ATTRIBUTES sa = {sizeof(SECURITY_ATTRIBUTES), nullptr, TRUE}; HANDLE hStdinRead, hStdinWrite, hStdoutRead, hStdoutWrite; @@ -150,7 +223,6 @@ void StandAloneComponent::startWindowsProcess(const std::string& driver_name, if (!CreateProcess(nullptr, cmd.data(), nullptr, nullptr, TRUE, CREATE_NO_WINDOW, nullptr, nullptr, &si, &pi)) { LOG_F(ERROR, "Failed to start process"); - impl_->closePipes(stdinPipe, stdoutPipe); return; } @@ -159,117 +231,58 @@ void StandAloneComponent::startWindowsProcess(const std::string& driver_name, impl_->driver.processHandle = static_cast(reinterpret_cast(pi.hProcess)); - - impl_->driver.stdinFd = - _open_osfhandle(reinterpret_cast(hStdinWrite), 0); - impl_->driver.stdoutFd = - _open_osfhandle(reinterpret_cast(hStdoutRead), 0); + impl_->driver.io = std::make_pair( + _open_osfhandle(reinterpret_cast(hStdinWrite), 0), + _open_osfhandle(reinterpret_cast(hStdoutRead), 0)); impl_->driver.name = driver_name; } #else -void StandAloneComponent::startUnixProcess(const std::string& driver_name, - int stdinPipe[2], - int stdoutPipe[2]) { - auto [shmFd, shmPtr] = - createSharedMemory().value_or(std::pair{-1, nullptr}); - if (shmFd == -1 || shmPtr == nullptr) { - impl_->closePipes(stdinPipe, stdoutPipe); - return; - } - +void StandAloneComponent::startUnixProcess( + const std::string& driver_name, + std::variant, std::pair> io) { auto sem = createSemaphore().value_or(nullptr); if (sem == nullptr) { - impl_->closePipes(stdinPipe, stdoutPipe); - closeSharedMemory(shmFd, shmPtr); return; } pid_t pid = fork(); if (pid == 0) { - handleChildProcess(driver_name, stdinPipe, stdoutPipe, shmPtr, sem, - shmFd); + handleChildProcess(driver_name, io, nullptr, sem, -1); } else if (pid > 0) { - handleParentProcess(pid, stdinPipe, stdoutPipe, shmPtr, sem, shmFd); + handleParentProcess(pid, io, nullptr, sem, -1); } else { LOG_F(ERROR, "Failed to 
fork driver process"); - impl_->closePipes(stdinPipe, stdoutPipe); - closeSharedMemory(shmFd, shmPtr); sem_close(sem); } } -auto StandAloneComponent::createSharedMemory() - -> std::optional > { - int shm_fd = shm_open(SHM_NAME, O_CREAT | O_RDWR, 0666); - if (shm_fd == -1) { - LOG_F(ERROR, "Failed to create shared memory"); - return std::nullopt; - } - - if (ftruncate(shm_fd, sizeof(int)) == -1) { - LOG_F(ERROR, "Failed to set size of shared memory"); - close(shm_fd); - shm_unlink(SHM_NAME); - return std::nullopt; - } - - int* shm_ptr = static_cast(mmap( - nullptr, sizeof(int), PROT_READ | PROT_WRITE, MAP_SHARED, shm_fd, 0)); - if (shm_ptr == MAP_FAILED) { - LOG_F(ERROR, "Failed to map shared memory"); - close(shm_fd); - shm_unlink(SHM_NAME); - return std::nullopt; - } - - *shm_ptr = 0; // Initialize shared memory to 0 - return std::make_pair(shm_fd, shm_ptr); -} - -void StandAloneComponent::closeSharedMemory(int shm_fd, int* shm_ptr) { - munmap(shm_ptr, sizeof(int)); - close(shm_fd); - shm_unlink(SHM_NAME); -} - -auto StandAloneComponent::createSemaphore() -> std::optional { - sem_t* sem = sem_open(SEM_NAME, O_CREAT | O_EXCL, 0644, 0); - if (sem == SEM_FAILED) { - LOG_F(ERROR, "Failed to create semaphore"); - return std::nullopt; +void StandAloneComponent::handleChildProcess( + const std::string& driver_name, + std::variant, std::pair> io, int* shm_ptr, + sem_t* sem, int shm_fd) { + if (std::holds_alternative>(io)) { + auto [inFd, outFd] = std::get>(io); + dup2(inFd, STDIN_FILENO); + dup2(outFd, STDOUT_FILENO); } - sem_unlink(SEM_NAME); // Ensure the semaphore is removed once it's no - // longer needed - return sem; -} - -void StandAloneComponent::handleChildProcess(const std::string& driver_name, - int stdinPipe[2], - int stdoutPipe[2], int* shm_ptr, - sem_t* sem, int shm_fd) { - close(stdinPipe[1]); - close(stdoutPipe[0]); - - dup2(stdinPipe[0], STDIN_FILENO); - dup2(stdoutPipe[1], STDOUT_FILENO); execlp(driver_name.data(), driver_name.data(), nullptr); - *shm_ptr = -1; + if (shm_ptr) + *shm_ptr = -1; sem_post(sem); LOG_F(ERROR, "Failed to exec driver process"); - close(shm_fd); - munmap(shm_ptr, sizeof(int)); + if (shm_fd != -1) { + close(shm_fd); + munmap(shm_ptr, sizeof(int)); + } sem_close(sem); std::exit(1); } -void StandAloneComponent::handleParentProcess(pid_t pid, int stdinPipe[2], - int stdoutPipe[2], int* shm_ptr, - sem_t* sem, int shm_fd) { - close(stdinPipe[0]); - close(stdoutPipe[1]); - +void StandAloneComponent::handleParentProcess( + pid_t pid, std::variant, std::pair> io, + int* shm_ptr, sem_t* sem, int shm_fd) { struct timespec ts; clock_gettime(CLOCK_REALTIME, &ts); ts.tv_sec += 1; @@ -277,32 +290,60 @@ void StandAloneComponent::handleParentProcess(pid_t pid, int stdinPipe[2], if (sem_timedwait(sem, &ts) == -1) { LOG_F(ERROR, errno == ETIMEDOUT ? 
"Driver process start timed out" : "Failed to wait on semaphore"); - close(stdinPipe[1]); - close(stdoutPipe[0]); kill(pid, SIGKILL); waitpid(pid, nullptr, 0); - } else if (*shm_ptr == -1) { + } else if ((shm_ptr != nullptr) && *shm_ptr == -1) { LOG_F(ERROR, "Driver process failed to start"); - close(stdinPipe[1]); - close(stdoutPipe[0]); waitpid(pid, nullptr, 0); } else { impl_->driver.processHandle = pid; - impl_->driver.stdinFd = stdinPipe[1]; - impl_->driver.stdoutFd = stdoutPipe[0]; - // impl_->driver.name = driver_name; + impl_->driver.io = io; } - closeSharedMemory(shm_fd, shm_ptr); + if (shm_fd != -1) { + closeSharedMemory(shm_fd, shm_ptr); + } sem_close(sem); } #endif +auto StandAloneComponent::createSemaphore() -> std::optional { +#if defined(_WIN32) || defined(_WIN64) + LOG_F(ERROR, "Semaphore not implemented for Windows"); + return std::nullopt; +#else + sem_t* sem = sem_open(SEM_NAME, O_CREAT | O_EXCL, 0644, 0); + if (sem == SEM_FAILED) { + LOG_F(ERROR, "Failed to create semaphore"); + return std::nullopt; + } + sem_unlink(SEM_NAME); // Ensure the semaphore is removed once it's no + // longer needed + return sem; +#endif +} + +void StandAloneComponent::closeSharedMemory(int shm_fd, int* shm_ptr) { +#if !defined(_WIN32) && !defined(_WIN64) + munmap(shm_ptr, sizeof(int)); + close(shm_fd); + shm_unlink(SHM_NAME); +#endif +} + void StandAloneComponent::stopLocalDriver() { - if (impl_->driver.stdinFd != -1) - close(impl_->driver.stdinFd); - if (impl_->driver.stdoutFd != -1) - close(impl_->driver.stdoutFd); + std::visit( + [](auto&& arg) { + using T = std::decay_t; + if constexpr (std::is_same_v>) { + close(arg.first); + close(arg.second); + } else if constexpr (std::is_same_v>) { + close(arg.first); + munmap(arg.second, sizeof(int)); + } + }, + impl_->driver.io); #if defined(_WIN32) || defined(_WIN64) TerminateProcess(reinterpret_cast(impl_->driver.processHandle), 0); @@ -315,6 +356,11 @@ void StandAloneComponent::stopLocalDriver() { if (impl_->driverThread.joinable()) { impl_->driverThread.join(); } + + // Clean up FIFO if used + if (impl_->driver.method == InteractionMethod::FIFO) { + unlink(FIFO_NAME); + } } void StandAloneComponent::monitorDrivers() { @@ -336,21 +382,38 @@ void StandAloneComponent::monitorDrivers() { } #endif LOG_F(INFO, "Driver {} exited, restarting...", impl_->driver.name); - startLocalDriver(impl_->driver.name); + startLocalDriver(impl_->driver.name, impl_->driver.method); } void StandAloneComponent::processMessages() { std::array buffer; if (impl_->driver.isListening) { + int bytesRead = 0; + std::visit( + [&](auto&& arg) { + using T = std::decay_t; + if constexpr (std::is_same_v>) { #if defined(_WIN32) || defined(_WIN64) - int bytesRead = _read(impl_->driver.stdoutFd, buffer.data(), - static_cast(buffer.size())); + bytesRead = _read(arg.second, buffer.data(), + static_cast(buffer.size())); #else - int flags = fcntl(impl_->driver.stdoutFd, F_GETFL, 0); - fcntl(impl_->driver.stdoutFd, F_SETFL, flags | O_NONBLOCK); - int bytesRead = - read(impl_->driver.stdoutFd, buffer.data(), buffer.size()); + int flags = fcntl(arg.second, F_GETFL, 0); + fcntl(arg.second, F_SETFL, flags | O_NONBLOCK); + bytesRead = read(arg.second, buffer.data(), buffer.size()); #endif + } else if constexpr (std::is_same_v>) { + // For shared memory, we need to implement a custom protocol + // This is a simple example, you might want to implement a + // more robust solution + if (*arg.second != 0) { + bytesRead = snprintf(buffer.data(), buffer.size(), "%d", + *arg.second); + *arg.second = 
0; // Reset the shared memory + } + } + }, + impl_->driver.io); + if (bytesRead > 0) { impl_->handleDriverOutput(impl_->driver.name, std::span(buffer.data(), bytesRead)); @@ -359,18 +422,43 @@ void StandAloneComponent::processMessages() { } void StandAloneComponent::sendMessageToDriver(std::string_view message) { + std::visit( + [&](auto&& arg) { + using T = std::decay_t; + if constexpr (std::is_same_v>) { #if defined(_WIN32) || defined(_WIN64) - _write(impl_->driver.stdinFd, message.data(), - static_cast(message.size())); + _write(arg.first, message.data(), + static_cast(message.size())); #else - write(impl_->driver.stdinFd, message.data(), message.size()); + write(arg.first, message.data(), message.size()); #endif + } else if constexpr (std::is_same_v>) { + // For shared memory, we need to implement a custom protocol + // This is a simple example, you might want to implement a more + // robust solution + *arg.second = std::atoi(message.data()); + } + }, + impl_->driver.io); } void StandAloneComponent::printDriver() const { - LOG_F(INFO, "{} (PID: {}) {}", impl_->driver.name, + std::string interactionMethod; + switch (impl_->driver.method) { + case InteractionMethod::Pipe: + interactionMethod = "Pipe"; + break; + case InteractionMethod::FIFO: + interactionMethod = "FIFO"; + break; + case InteractionMethod::SharedMemory: + interactionMethod = "Shared Memory"; + break; + } + + LOG_F(INFO, "{} (PID: {}) {} [{}]", impl_->driver.name, impl_->driver.processHandle, - impl_->driver.isListening ? "[Listening]" : ""); + impl_->driver.isListening ? "[Listening]" : "", interactionMethod); } void StandAloneComponent::toggleDriverListening() { diff --git a/src/addon/template/standalone.hpp b/src/addon/template/standalone.hpp index fea1adce..22c50850 100644 --- a/src/addon/template/standalone.hpp +++ b/src/addon/template/standalone.hpp @@ -4,9 +4,13 @@ #include #include #include +#include +#include #include "atom/components/component.hpp" +enum class InteractionMethod { Pipe, FIFO, SharedMemory }; + class StandAloneComponentImpl; class StandAloneComponent : public Component { @@ -14,41 +18,41 @@ class StandAloneComponent : public Component { explicit StandAloneComponent(std::string name); ~StandAloneComponent() override; - void startLocalDriver(const std::string& driver_name); - + void startLocalDriver(const std::string& driver_name, + InteractionMethod method); void stopLocalDriver(); - void monitorDrivers(); - void processMessages(); - void sendMessageToDriver(std::string_view message); - void printDriver() const; - void toggleDriverListening(); private: - auto createPipes(int stdinPipe[2], int stdoutPipe[2]) -> bool; + auto createPipes() -> std::optional>; + auto createFIFO() -> std::optional>; + auto createSharedMemory() -> std::optional>; void backgroundProcessing(); #if defined(_WIN32) || defined(_WIN64) - void startWindowsProcess(const std::string& driver_name, int stdinPipe[2], - int stdoutPipe[2]); + void startWindowsProcess( + const std::string& driver_name, + std::variant, std::pair> io); #else - void startUnixProcess(const std::string& driver_name, int stdinPipe[2], - int stdoutPipe[2]); + void startUnixProcess( + const std::string& driver_name, + std::variant, std::pair> io); + void handleChildProcess( + const std::string& driver_name, + std::variant, std::pair> io, + int* shm_ptr, sem_t* sem, int shm_fd); + void handleParentProcess( + pid_t pid, std::variant, std::pair> io, + int* shm_ptr, sem_t* sem, int shm_fd); +#endif - auto createSharedMemory() -> std::optional >; - void 
closeSharedMemory(int shm_fd, int* shm_ptr); auto createSemaphore() -> std::optional; - void handleChildProcess(const std::string& driver_name, int stdinPipe[2], - int stdoutPipe[2], int* shm_ptr, sem_t* sem, - int shm_fd); - void handleParentProcess(pid_t pid, int stdinPipe[2], int stdoutPipe[2], - int* shm_ptr, sem_t* sem, int shm_fd); -#endif + void closeSharedMemory(int shm_fd, int* shm_ptr); std::unique_ptr impl_; }; diff --git a/src/addon/toolchain.cpp b/src/addon/toolchain.cpp index 6926f48c..73283ed7 100644 --- a/src/addon/toolchain.cpp +++ b/src/addon/toolchain.cpp @@ -1,45 +1,119 @@ +// toolchain.cpp #include "toolchain.hpp" +#include #include #include #include -#include +#include +#include +#include #include "atom/log/loguru.hpp" -#include "macro.hpp" +#include "error/exception.hpp" +#include "utils/constant.hpp" + +// Toolchain implementation +class Toolchain::Impl { +public: + std::string name; + std::string compiler; + std::string buildTool; + std::string version; + std::string path; + Type type; + + Impl(std::string name, std::string compiler, std::string buildTool, + std::string version, std::string path, Type type) + : name(std::move(name)), + compiler(std::move(compiler)), + buildTool(std::move(buildTool)), + version(std::move(version)), + path(std::move(path)), + type(type) {} +}; Toolchain::Toolchain(std::string name, std::string compiler, std::string buildTool, std::string version, - std::string path) - : name_(std::move(name)), - compiler_(std::move(compiler)), - buildTool_(std::move(buildTool)), - version_(std::move(version)), - path_(std::move(path)) {} + std::string path, Type type) + : impl_(std::make_unique(std::move(name), std::move(compiler), + std::move(buildTool), std::move(version), + std::move(path), type)) {} + +Toolchain::~Toolchain() = default; +Toolchain::Toolchain(const Toolchain& other) + : impl_(std::make_unique(*other.impl_)) {} +Toolchain::Toolchain(Toolchain&& other) noexcept = default; +auto Toolchain::operator=(const Toolchain& other) -> Toolchain& { + if (this != &other) { + impl_ = std::make_unique(*other.impl_); + } + return *this; +} +auto Toolchain::operator=(Toolchain&& other) noexcept -> Toolchain& = default; void Toolchain::displayInfo() const { - LOG_F(INFO, "Toolchain Information for {}", name_); - LOG_F(INFO, "Compiler: {}", compiler_); - LOG_F(INFO, "Build Tool: {}", buildTool_); - LOG_F(INFO, "Version: {}", version_); - LOG_F(INFO, "Path: {}", path_); + LOG_F(INFO, "Toolchain Information for {}", impl_->name); + LOG_F(INFO, "Compiler: {}", impl_->compiler); + LOG_F(INFO, "Build Tool: {}", impl_->buildTool); + LOG_F(INFO, "Version: {}", impl_->version); + LOG_F(INFO, "Path: {}", impl_->path); + LOG_F(INFO, "Type: {}", + impl_->type == Type::Compiler + ? "Compiler" + : (impl_->type == Type::BuildTool ? 
"Build Tool" : "Unknown")); } -auto Toolchain::getName() const -> const std::string& { return name_; } +auto Toolchain::getName() const -> const std::string& { return impl_->name; } +auto Toolchain::getCompiler() const -> const std::string& { + return impl_->compiler; +} +auto Toolchain::getBuildTool() const -> const std::string& { + return impl_->buildTool; +} +auto Toolchain::getVersion() const -> const std::string& { + return impl_->version; +} +auto Toolchain::getPath() const -> const std::string& { return impl_->path; } +auto Toolchain::getType() const -> Type { return impl_->type; } -void ToolchainManager::scanForToolchains() { +void Toolchain::setVersion(const std::string& version) { + impl_->version = version; +} +void Toolchain::setPath(const std::string& path) { impl_->path = path; } +void Toolchain::setType(Type type) { impl_->type = type; } + +auto Toolchain::isCompatibleWith(const Toolchain& other) const -> bool { + // Implement compatibility logic here + // For example, check if compiler versions are compatible + return true; // Placeholder implementation +} + +// ToolchainManager implementation +class ToolchainManager::Impl { +public: + std::vector toolchains; std::vector searchPaths; + std::unordered_map toolchainAliases; + std::optional defaultToolchain; -#if defined(_WIN32) || defined(_WIN64) - searchPaths = {"C:\\Program Files", "C:\\Program Files (x86)", - "C:\\MinGW\\bin", "C:\\LLVM\\bin", - "C:\\msys64\\mingw64\\bin", "C:\\msys64\\mingw32\\bin", - "C:\\msys64\\clang64\\bin", "C:\\msys64\\clang32\\bin"}; -#else - searchPaths = {"/usr/bin", "/usr/local/bin"}; -#endif + static auto getCompilerVersion(const std::string& path) -> std::string; + void scanBuildTools(); + static auto executeCommand(const std::string& command) -> std::string; + void initializeDefaultSearchPaths(); +}; - for (const auto& path : searchPaths) { +ToolchainManager::ToolchainManager() : impl_(std::make_unique()) { + impl_->initializeDefaultSearchPaths(); +} + +ToolchainManager::~ToolchainManager() = default; +ToolchainManager::ToolchainManager(ToolchainManager&&) noexcept = default; +auto ToolchainManager::operator=(ToolchainManager&&) noexcept + -> ToolchainManager& = default; + +void ToolchainManager::scanForToolchains() { + for (const auto& path : impl_->searchPaths) { if (std::filesystem::exists(path)) { for (const auto& entry : std::filesystem::directory_iterator(path)) { @@ -50,92 +124,307 @@ void ToolchainManager::scanForToolchains() { filename.starts_with("clang") || filename.starts_with("clang++")) { std::string version = - getCompilerVersion(entry.path().string()); - toolchains_.emplace_back(filename, filename, filename, - version, - entry.path().string()); + Impl::getCompilerVersion(entry.path().string()); + addToolchain(Toolchain(filename, filename, "", version, + entry.path().string(), + Toolchain::Type::Compiler)); } } } } } - scanBuildTools(); + impl_->scanBuildTools(); } -void ToolchainManager::scanBuildTools() { +void ToolchainManager::Impl::scanBuildTools() { std::vector buildTools = {"make", "ninja", "cmake", "gmake", "msbuild"}; for (const auto& tool : buildTools) { - if (std::filesystem::exists(tool)) { - std::string version = getCompilerVersion(tool); - toolchains_.emplace_back(tool, tool, tool, version, tool); + std::string toolPath = tool + Constants::EXECUTABLE_EXTENSION; + if (std::filesystem::exists(toolPath)) { + std::string version = getCompilerVersion(toolPath); + toolchains.emplace_back(tool, "", tool, version, toolPath, + Toolchain::Type::BuildTool); } } } void 
ToolchainManager::listToolchains() const { LOG_F(INFO, "Available Toolchains:"); - for (const auto& tc : toolchains_) { - LOG_F(INFO, "- {}", tc.getName()); + for (const auto& tc : impl_->toolchains) { + LOG_F(INFO, "- {} ({}) [{}]", tc.getName(), tc.getVersion(), + tc.getType() == Toolchain::Type::Compiler + ? "Compiler" + : (tc.getType() == Toolchain::Type::BuildTool ? "Build Tool" + : "Unknown")); } } -bool ToolchainManager::selectToolchain(const std::string& name) const { - for (const auto& tc : toolchains_) { - if (tc.getName() == name) { - tc.displayInfo(); - return true; - } +auto ToolchainManager::selectToolchain(const std::string& name) const + -> std::optional { + auto it = std::find_if( + impl_->toolchains.begin(), impl_->toolchains.end(), + [&name](const Toolchain& tc) { return tc.getName() == name; }); + if (it != impl_->toolchains.end()) { + it->displayInfo(); + return *it; } - return false; + return std::nullopt; } -void ToolchainManager::saveConfig(const std::string& filename) { +void ToolchainManager::saveConfig(const std::string& filename) const { std::ofstream file(filename); - for (const auto& tc : toolchains_) { - file << tc.getName() << "\n"; + if (!file) { + THROW_FAIL_TO_CLOSE_FILE("Unable to open file for writing: " + + filename); + } + for (const auto& tc : impl_->toolchains) { + file << tc.getName() << "," << tc.getCompiler() << "," + << tc.getBuildTool() << "," << tc.getVersion() << "," + << tc.getPath() << "," << static_cast(tc.getType()) << "\n"; + } + // Save aliases + file << "--- Aliases ---\n"; + for (const auto& [alias, toolchainName] : impl_->toolchainAliases) { + file << alias << "," << toolchainName << "\n"; + } + // Save default toolchain + file << "--- Default ---\n"; + if (impl_->defaultToolchain) { + file << *impl_->defaultToolchain << "\n"; } - file.close(); LOG_F(INFO, "Configuration saved to {}", filename); } void ToolchainManager::loadConfig(const std::string& filename) { std::ifstream file(filename); - std::string toolchainName; - while (std::getline(file, toolchainName)) { - ATOM_UNUSED_RESULT(selectToolchain(toolchainName)); + if (!file) { + throw std::runtime_error("Unable to open file for reading: " + + filename); } - file.close(); + impl_->toolchains.clear(); + impl_->toolchainAliases.clear(); + impl_->defaultToolchain.reset(); + + std::string line; + bool readingAliases = false; + bool readingDefault = false; + while (std::getline(file, line)) { + if (line == "--- Aliases ---") { + readingAliases = true; + readingDefault = false; + continue; + } else if (line == "--- Default ---") { + readingAliases = false; + readingDefault = true; + continue; + } + + if (readingAliases) { + std::istringstream iss(line); + std::string alias; + std::string toolchainName; + if (std::getline(iss, alias, ',') && + std::getline(iss, toolchainName)) { + setToolchainAlias(alias, toolchainName); + } + } else if (readingDefault) { + setDefaultToolchain(line); + } else { + std::vector parts; + std::istringstream iss(line); + std::string part; + while (std::getline(iss, part, ',')) { + parts.push_back(part); + } + if (parts.size() == 6) { + addToolchain(Toolchain( + parts[0], parts[1], parts[2], parts[3], parts[4], + static_cast(std::stoi(parts[5])))); + } + } + } + LOG_F(INFO, "Configuration loaded from {}", filename); } -std::string ToolchainManager::getCompilerVersion(const std::string& path) { - std::string command = path + " --version"; -#if defined(_WIN32) || defined(_WIN64) - command = "\"" + path + "\"" + " --version"; -#endif +auto 
ToolchainManager::getToolchains() const -> const std::vector& { + return impl_->toolchains; +} +auto ToolchainManager::getAvailableCompilers() const + -> std::vector { + std::vector compilers; + for (const auto& tc : impl_->toolchains) { + if (tc.getType() == Toolchain::Type::Compiler) { + compilers.push_back(tc.getName()); + } + } + return compilers; +} + +auto ToolchainManager::getAvailableBuildTools() const + -> std::vector { + std::vector buildTools; + for (const auto& tc : impl_->toolchains) { + if (tc.getType() == Toolchain::Type::BuildTool) { + buildTools.push_back(tc.getName()); + } + } + return buildTools; +} + +void ToolchainManager::addToolchain(const Toolchain& toolchain) { + auto it = std::find_if(impl_->toolchains.begin(), impl_->toolchains.end(), + [&](const Toolchain& tc) { + return tc.getName() == toolchain.getName(); + }); + if (it == impl_->toolchains.end()) { + impl_->toolchains.push_back(toolchain); + } else { + *it = toolchain; + } +} + +void ToolchainManager::removeToolchain(const std::string& name) { + impl_->toolchains.erase( + std::remove_if( + impl_->toolchains.begin(), impl_->toolchains.end(), + [&](const Toolchain& tc) { return tc.getName() == name; }), + impl_->toolchains.end()); +} + +void ToolchainManager::updateToolchain(const std::string& name, + const Toolchain& updatedToolchain) { + auto it = + std::find_if(impl_->toolchains.begin(), impl_->toolchains.end(), + [&](const Toolchain& tc) { return tc.getName() == name; }); + if (it != impl_->toolchains.end()) { + *it = updatedToolchain; + } +} + +auto ToolchainManager::findToolchain(const std::string& name) const + -> std::optional { + auto it = + std::find_if(impl_->toolchains.begin(), impl_->toolchains.end(), + [&](const Toolchain& tc) { return tc.getName() == name; }); + if (it != impl_->toolchains.end()) { + return *it; + } + return std::nullopt; +} + +auto ToolchainManager::findToolchains(const ToolchainFilter& filter) const + -> std::vector { + std::vector result; + std::copy_if(impl_->toolchains.begin(), impl_->toolchains.end(), + std::back_inserter(result), filter); + return result; +} + +auto ToolchainManager::suggestCompatibleToolchains(const Toolchain& base) const + -> std::vector { + return findToolchains( + [&base](const Toolchain& tc) { return base.isCompatibleWith(tc); }); +} + +void ToolchainManager::registerCustomToolchain(const std::string& name, + const std::string& path) { + if (!std::filesystem::exists(path)) { + throw std::runtime_error("Custom toolchain path does not exist: " + + path); + } + std::string version = Impl::getCompilerVersion(path); + Toolchain::Type type = path.find("make") != std::string::npos || + path.find("ninja") != std::string::npos + ? 
Toolchain::Type::BuildTool + : Toolchain::Type::Compiler; + addToolchain(Toolchain(name, name, "", version, path, type)); +} + +void ToolchainManager::setDefaultToolchain(const std::string& name) { + if (findToolchain(name)) { + impl_->defaultToolchain = name; + } else { + throw std::runtime_error("Toolchain not found: " + name); + } +} + +auto ToolchainManager::getDefaultToolchain() const -> std::optional { + if (impl_->defaultToolchain) { + return findToolchain(*impl_->defaultToolchain); + } + return std::nullopt; +} + +void ToolchainManager::addSearchPath(const std::string& path) { + if (std::find(impl_->searchPaths.begin(), impl_->searchPaths.end(), path) == + impl_->searchPaths.end()) { + impl_->searchPaths.push_back(path); + } +} + +void ToolchainManager::removeSearchPath(const std::string& path) { + impl_->searchPaths.erase( + std::remove(impl_->searchPaths.begin(), impl_->searchPaths.end(), path), + impl_->searchPaths.end()); +} + +auto ToolchainManager::getSearchPaths() const + -> const std::vector& { + return impl_->searchPaths; +} + +void ToolchainManager::setToolchainAlias(const std::string& alias, + const std::string& toolchainName) { + if (findToolchain(toolchainName)) { + impl_->toolchainAliases[alias] = toolchainName; + } else { + throw std::runtime_error("Toolchain not found: " + toolchainName); + } +} + +auto ToolchainManager::getToolchainByAlias(const std::string& alias) const + -> std::optional { + auto it = impl_->toolchainAliases.find(alias); + if (it != impl_->toolchainAliases.end()) { + return findToolchain(it->second); + } + return std::nullopt; +} + +auto ToolchainManager::Impl::getCompilerVersion(const std::string& path) + -> std::string { + std::string command = "\"" + path + "\" --version"; + std::string result = executeCommand(command); + return result.empty() ? "Unknown version" + : result.substr(0, result.find('\n')); +} + +auto ToolchainManager::Impl::executeCommand(const std::string& command) + -> std::string { std::array buffer; std::string result; - std::unique_ptr pipe(popen(command.c_str(), "r"), pclose); if (!pipe) { - return "Unknown version"; + throw std::runtime_error("popen() failed!"); } - while (fgets(buffer.data(), buffer.size(), pipe.get()) != nullptr) { result += buffer.data(); } - - return result.empty() ? 
"Unknown version" : result; + return result; } -std::vector ToolchainManager::getAvailableCompilers() const { - std::vector compilers; - for (const auto& toolchain : toolchains_) { - compilers.push_back(toolchain.getName()); // 收集每个可用工具链的名称 - } - return compilers; +void ToolchainManager::Impl::initializeDefaultSearchPaths() { +#if defined(_WIN32) || defined(_WIN64) + searchPaths = {"C:\\Program Files", "C:\\Program Files (x86)", + "C:\\MinGW\\bin", "C:\\LLVM\\bin", + "C:\\msys64\\mingw64\\bin", "C:\\msys64\\mingw32\\bin", + "C:\\msys64\\clang64\\bin", "C:\\msys64\\clang32\\bin"}; +#else + searchPaths = {"/usr/bin", "/usr/local/bin", "/opt/local/bin"}; +#endif } diff --git a/src/addon/toolchain.hpp b/src/addon/toolchain.hpp index 0e90c009..4f218f4c 100644 --- a/src/addon/toolchain.hpp +++ b/src/addon/toolchain.hpp @@ -1,53 +1,91 @@ #ifndef LITHIUM_ADDON_TOOLCHAIN_HPP #define LITHIUM_ADDON_TOOLCHAIN_HPP -#include +#include +#include +#include #include #include -#if defined(_WIN32) || defined(_WIN64) -#include -#define CMD_EXTENSION ".exe" -#else -#define CMD_EXTENSION "" -#endif - class Toolchain { public: + enum class Type { Compiler, BuildTool, Unknown }; + Toolchain(std::string name, std::string compiler, std::string buildTool, - std::string version, std::string path); + std::string version, std::string path, Type type = Type::Unknown); + ~Toolchain(); + Toolchain(const Toolchain& other); + Toolchain(Toolchain&& other) noexcept; + auto operator=(const Toolchain& other) -> Toolchain&; + auto operator=(Toolchain&& other) noexcept -> Toolchain&; void displayInfo() const; - [[nodiscard]] auto getName() const -> const std::string&; + [[nodiscard]] auto getCompiler() const -> const std::string&; + [[nodiscard]] auto getBuildTool() const -> const std::string&; + [[nodiscard]] auto getVersion() const -> const std::string&; + [[nodiscard]] auto getPath() const -> const std::string&; + [[nodiscard]] auto getType() const -> Type; + + void setVersion(const std::string& version); + void setPath(const std::string& path); + void setType(Type type); + + [[nodiscard]] auto isCompatibleWith(const Toolchain& other) const -> bool; private: - std::string name_; - std::string compiler_; - std::string buildTool_; - std::string version_; - std::string path_; + class Impl; + std::unique_ptr impl_; }; class ToolchainManager { public: + using ToolchainFilter = std::function; + + ToolchainManager(); + ~ToolchainManager(); + ToolchainManager(const ToolchainManager&) = delete; + ToolchainManager& operator=(const ToolchainManager&) = delete; + ToolchainManager(ToolchainManager&&) noexcept; + ToolchainManager& operator=(ToolchainManager&&) noexcept; + void scanForToolchains(); void listToolchains() const; - [[nodiscard]] auto selectToolchain(const std::string& name) const -> bool; - void saveConfig(const std::string& filename); + [[nodiscard]] auto selectToolchain(const std::string& name) const + -> std::optional; + void saveConfig(const std::string& filename) const; void loadConfig(const std::string& filename); - - [[nodiscard]] auto getToolchains() const -> const std::vector& { - return toolchains_; - } - - auto getAvailableCompilers() const -> std::vector; + [[nodiscard]] auto getToolchains() const -> const std::vector&; + [[nodiscard]] auto getAvailableCompilers() const + -> std::vector; + [[nodiscard]] auto getAvailableBuildTools() const + -> std::vector; + void addToolchain(const Toolchain& toolchain); + void removeToolchain(const std::string& name); + void updateToolchain(const std::string& name, + const 
Toolchain& updatedToolchain); + [[nodiscard]] auto findToolchain(const std::string& name) const + -> std::optional; + [[nodiscard]] auto findToolchains(const ToolchainFilter& filter) const + -> std::vector; + [[nodiscard]] auto suggestCompatibleToolchains(const Toolchain& base) const + -> std::vector; + void registerCustomToolchain(const std::string& name, + const std::string& path); + void setDefaultToolchain(const std::string& name); + [[nodiscard]] auto getDefaultToolchain() const -> std::optional; + void addSearchPath(const std::string& path); + void removeSearchPath(const std::string& path); + [[nodiscard]] auto getSearchPaths() const + -> const std::vector&; + void setToolchainAlias(const std::string& alias, + const std::string& toolchainName); + [[nodiscard]] auto getToolchainByAlias(const std::string& alias) const + -> std::optional; private: - std::vector toolchains_; - - auto getCompilerVersion(const std::string& path) -> std::string; - void scanBuildTools(); + class Impl; + std::unique_ptr impl_; }; #endif // LITHIUM_ADDON_TOOLCHAIN_HPP diff --git a/src/addon/tracker.cpp b/src/addon/tracker.cpp index 217758b6..23e53031 100644 --- a/src/addon/tracker.cpp +++ b/src/addon/tracker.cpp @@ -3,20 +3,18 @@ #include #include -#include +#include #include #include #include #include #include -#include -#include +#include #include "atom/type/json.hpp" #include "atom/utils/aes.hpp" -namespace fs = std::filesystem; +#include "atom/utils/time.hpp" -// Implementation of the FileTracker class struct FileTracker::Impl { std::string directory; std::string jsonFilePath; @@ -26,167 +24,143 @@ struct FileTracker::Impl { json oldJson; json differences; std::mutex mtx; + std::optional encryptionKey; - // Constructor - Impl(std::string dir, std::string jFilePath, - const std::vector& types, bool rec) - : directory(std::move(dir)), - jsonFilePath(std::move(jFilePath)), + Impl(std::string_view dir, std::string_view jFilePath, + std::span types, bool rec) + : directory(dir), + jsonFilePath(jFilePath), recursive(rec), - fileTypes(types) {} - - // Get last write time - static auto getLastWriteTime(const fs::path& filePath) -> std::string; - - // Save JSON to a file - static void saveJSON(const json& j, const std::string& filePath); - - // Load JSON from a file - static auto loadJSON(const std::string& filePath) -> json; - - // Generate the JSON data for files in the specified directory - void generateJSON(); - - // Compare the old and new JSON data - auto compareJSON() -> json; - - // Recover files from the JSON state - void recoverFiles(); -}; - -// Get the last write time of a file -auto FileTracker::Impl::getLastWriteTime(const fs::path& filePath) - -> std::string { - auto ftime = fs::last_write_time(filePath); - auto sctp = - std::chrono::time_point_cast( - ftime - fs::file_time_type::clock::now() + - std::chrono::system_clock::now()); - std::time_t cftime = std::chrono::system_clock::to_time_t(sctp); - - return std::asctime(std::localtime(&cftime)); -} - -// Save JSON object to a file -void FileTracker::Impl::saveJSON(const json& j, const std::string& filePath) { - std::ofstream outFile(filePath); - outFile << std::setw(4) << j << std::endl; -} + fileTypes(types.begin(), types.end()) {} + + static void saveJSON(const json& j, const std::string& filePath, + const std::optional& key) { + std::ofstream outFile(filePath, std::ios::binary); + if (key) { + std::vector iv; + std::vector tag; + std::string encrypted = + atom::utils::encryptAES(j.dump(), *key, iv, tag); + outFile.write(encrypted.data(), 
encrypted.size()); + } else { + outFile << std::setw(4) << j << std::endl; + } + } -// Load JSON object from a file -auto FileTracker::Impl::loadJSON(const std::string& filePath) -> json { - std::ifstream inFile(filePath); - if (!inFile.is_open()) { - return json(); + static auto loadJSON(const std::string& filePath, + const std::optional& key) -> json { + std::ifstream inFile(filePath, std::ios::binary); + if (!inFile.is_open()) { + return {}; + } + if (key) { + std::string encrypted((std::istreambuf_iterator(inFile)), + std::istreambuf_iterator()); + std::vector iv; + std::vector tag; + std::string decrypted = + atom::utils::decryptAES(encrypted, *key, iv, tag); + return json::parse(decrypted); + } else { + json j; + inFile >> j; + return j; + } } - json j; - inFile >> j; - return j; -} -// Generate JSON for all tracked files in the specified directory -void FileTracker::Impl::generateJSON() { - std::vector threads; + void generateJSON() { + using DirIterVariant = std::variant; + + DirIterVariant fileRange = + recursive + ? DirIterVariant(fs::recursive_directory_iterator(directory)) + : DirIterVariant(fs::directory_iterator(directory)); + + std::visit( + [&](auto&& iter) { + for (const auto& entry : iter) { + if (std::ranges::find(fileTypes, + entry.path().extension().string()) != + fileTypes.end()) { + processFile(entry.path()); + } + } + }, + fileRange); + + saveJSON(newJson, jsonFilePath, encryptionKey); + } - auto process = [&](const fs::path& entry) { + void processFile(const fs::path& entry) { std::string hash = atom::utils::calculateSha256(entry.string()); - std::string lastWriteTime = getLastWriteTime(entry); - std::lock_guard lock(mtx); + std::string lastWriteTime = atom::utils::getChinaTimestampString(); + std::lock_guard lock(mtx); newJson[entry.string()] = {{"last_write_time", lastWriteTime}, {"hash", hash}, {"size", fs::file_size(entry)}, {"type", entry.extension().string()}}; - }; - - if (recursive) { - for (const auto& entry : fs::recursive_directory_iterator(directory)) { - if (std::find(fileTypes.begin(), fileTypes.end(), - entry.path().extension().string()) != - fileTypes.end()) { - threads.emplace_back(process, entry.path()); - } - } - } else { - for (const auto& entry : fs::directory_iterator(directory)) { - if (std::find(fileTypes.begin(), fileTypes.end(), - entry.path().extension().string()) != - fileTypes.end()) { - threads.emplace_back(process, entry.path()); - } - } - } - - for (auto& th : threads) { - if (th.joinable()) { - th.join(); - } } - saveJSON(newJson, jsonFilePath); -} - -// Compare new and old JSON objects for differences -auto FileTracker::Impl::compareJSON() -> json { - json diff; - for (const auto& [filePath, newFileInfo] : newJson.items()) { - if (oldJson.contains(filePath)) { - if (oldJson[filePath]["hash"] != newFileInfo["hash"]) { - diff[filePath] = {{"status", "modified"}, - {"new", newFileInfo}, - {"old", oldJson[filePath]}}; + auto compareJSON() -> json { + json diff; + for (const auto& [filePath, newFileInfo] : newJson.items()) { + if (oldJson.contains(filePath)) { + if (oldJson[filePath]["hash"] != newFileInfo["hash"]) { + diff[filePath] = {{"status", "modified"}, + {"new", newFileInfo}, + {"old", oldJson[filePath]}}; + } + } else { + diff[filePath] = {{"status", "new"}}; } - } else { - diff[filePath] = {{"status", "new"}}; } - } - for (const auto& [filePath, oldFileInfo] : oldJson.items()) { - if (!newJson.contains(filePath)) { - diff[filePath] = {{"status", "deleted"}}; + for (const auto& [filePath, oldFileInfo] : oldJson.items()) { + if 
(!newJson.contains(filePath)) { + diff[filePath] = {{"status", "deleted"}}; + } } + return diff; } - return diff; -} -// Recover files from the JSON state -void FileTracker::Impl::recoverFiles() { - for (const auto& [filePath, fileInfo] : oldJson.items()) { - if (!fs::exists(filePath)) { - std::ofstream outFile(filePath); - if (outFile.is_open()) { - outFile << "This file was recovered based on version: " - << fileInfo["last_write_time"] << std::endl; - outFile.close(); + void recoverFiles() { + for (const auto& [filePath, fileInfo] : oldJson.items()) { + if (!fs::exists(filePath)) { + std::ofstream outFile(filePath); + if (outFile.is_open()) { + outFile << "This file was recovered based on version: " + << fileInfo["last_write_time"] << std::endl; + outFile.close(); + } } } } -} +}; -// Constructor -FileTracker::FileTracker(const std::string& directory, - const std::string& jsonFilePath, - const std::vector& fileTypes, - bool recursive) +FileTracker::FileTracker(std::string_view directory, + std::string_view jsonFilePath, + std::span fileTypes, bool recursive) : pImpl(std::make_unique(directory, jsonFilePath, fileTypes, recursive)) {} -// Destructor FileTracker::~FileTracker() = default; -// Scan for files and generate JSON +FileTracker::FileTracker(FileTracker&&) noexcept = default; +auto FileTracker::operator=(FileTracker&&) noexcept -> FileTracker& = default; + void FileTracker::scan() { if (fs::exists(pImpl->jsonFilePath)) { - pImpl->oldJson = pImpl->loadJSON(pImpl->jsonFilePath); + pImpl->oldJson = + pImpl->loadJSON(pImpl->jsonFilePath, pImpl->encryptionKey); } pImpl->generateJSON(); } -// Compare the loaded JSON data void FileTracker::compare() { pImpl->differences = pImpl->compareJSON(); } -// Log the differences found to a log file -void FileTracker::logDifferences(const std::string& logFilePath) { - std::ofstream logFile(logFilePath, std::ios_base::app); +void FileTracker::logDifferences(std::string_view logFilePath) const { + std::ofstream logFile(logFilePath.data(), std::ios_base::app); if (logFile.is_open()) { for (const auto& [filePath, info] : pImpl->differences.items()) { logFile << "File: " << filePath << ", Status: " << info["status"] @@ -195,19 +169,73 @@ void FileTracker::logDifferences(const std::string& logFilePath) { } } -// Recover the files based on the saved JSON data -void FileTracker::recover(const std::string& jsonFilePath) { - pImpl->oldJson = pImpl->loadJSON(jsonFilePath); +void FileTracker::recover(std::string_view jsonFilePath) { + pImpl->oldJson = pImpl->loadJSON(jsonFilePath.data(), pImpl->encryptionKey); pImpl->recoverFiles(); } -// Get the differences in JSON format -auto FileTracker::getDifferences() const -> const json& { +auto FileTracker::asyncScan() -> std::future { + return std::async(std::launch::async, [this] { scan(); }); +} + +auto FileTracker::asyncCompare() -> std::future { + return std::async(std::launch::async, [this] { compare(); }); +} + +auto FileTracker::getDifferences() const noexcept -> const json& { return pImpl->differences; } -// Get the list of tracked file types -auto FileTracker::getTrackedFileTypes() const - -> const std::vector& { +auto FileTracker::getTrackedFileTypes() + const noexcept -> const std::vector& { return pImpl->fileTypes; } + +template Func> +void FileTracker::forEachFile(Func&& func) const { + using DirIterVariant = + std::variant; + + DirIterVariant fileRange = + pImpl->recursive + ? 
DirIterVariant(fs::recursive_directory_iterator(pImpl->directory)) + : DirIterVariant(fs::directory_iterator(pImpl->directory)); + + std::visit( + [&](auto&& iter) { + for (const auto& entry : iter) { + if (std::ranges::find(pImpl->fileTypes, + entry.path().extension().string()) != + pImpl->fileTypes.end()) { + func(entry.path()); + } + } + }, + fileRange); +} + +auto FileTracker::getFileInfo(const fs::path& filePath) const -> std::optional { + if (auto it = pImpl->newJson.find(filePath.string()); + it != pImpl->newJson.end()) { + return *it; + } + return std::nullopt; +} + +void FileTracker::addFileType(std::string_view fileType) { + pImpl->fileTypes.emplace_back(fileType); +} + +void FileTracker::removeFileType(std::string_view fileType) { + pImpl->fileTypes.erase( + std::remove(pImpl->fileTypes.begin(), pImpl->fileTypes.end(), fileType), + pImpl->fileTypes.end()); +} + +void FileTracker::setEncryptionKey(std::string_view key) { + pImpl->encryptionKey = std::string(key); +} + +// Explicitly instantiate the template function to avoid linker errors +template void FileTracker::forEachFile>( + std::function&&) const; diff --git a/src/addon/tracker.hpp b/src/addon/tracker.hpp index b9861cb1..0ba30f6f 100644 --- a/src/addon/tracker.hpp +++ b/src/addon/tracker.hpp @@ -1,116 +1,57 @@ #ifndef LITHIUM_ADDON_TRACKER_HPP #define LITHIUM_ADDON_TRACKER_HPP +#include +#include +#include #include +#include +#include #include #include #include "atom/type/json_fwd.hpp" using json = nlohmann::json; -/** - * @class FileTracker - * @brief A class that tracks files in a directory, compares their states, logs - * differences, and can recover files from JSON data. - * - * This class provides functionality to scan a directory for files, compare the - * current state of the files with a previously saved state (in JSON format), - * log any differences, and recover files based on saved JSON data. It supports - * both recursive and non-recursive directory scanning and allows tracking - * specific file types. - */ +namespace fs = std::filesystem; + class FileTracker { public: - /** - * @brief Constructs a FileTracker instance. - * - * @param directory The directory to be tracked. - * @param jsonFilePath The path to the JSON file where the state of files - * will be saved. - * @param fileTypes A vector of file extensions (types) to track. - * @param recursive A boolean indicating whether to scan directories - * recursively. Default is false. - */ - FileTracker(const std::string& directory, const std::string& jsonFilePath, - const std::vector& fileTypes, - bool recursive = false); + FileTracker(std::string_view directory, std::string_view jsonFilePath, + std::span fileTypes, bool recursive = false); - /** - * @brief Destructs the FileTracker instance. - */ ~FileTracker(); - /** - * @brief Scans the specified directory for files and generates a JSON file - * representing the state of the files. - * - * This method will create or update the JSON file with the current state of - * the files in the tracked directory. The state includes details about file - * names, types, and last modified timestamps. - */ - void scan(); + FileTracker(const FileTracker&) = delete; + FileTracker& operator=(const FileTracker&) = delete; + FileTracker(FileTracker&&) noexcept; + FileTracker& operator=(FileTracker&&) noexcept; - /** - * @brief Compares the new JSON data with the previous state to find - * differences. - * - * This method will read the previously saved JSON file and compare it with - * the current state of the files. 
The differences will be stored internally - * and can be retrieved using the `getDifferences()` method. - */ + void scan(); void compare(); + void logDifferences(std::string_view logFilePath) const; + void recover(std::string_view jsonFilePath); - /** - * @brief Logs the differences between the current and previous file states - * to a specified log file. - * - * @param logFilePath The path to the log file where differences will be - * written. - * - * This method generates a log file detailing the differences detected by - * the `compare()` method. - */ - void logDifferences(const std::string& logFilePath); + [[nodiscard]] std::future asyncScan(); + [[nodiscard]] std::future asyncCompare(); - /** - * @brief Recovers files based on the saved JSON data. - * - * @param jsonFilePath The path to the JSON file containing the saved state - * of files. - * - * This method will use the provided JSON data to restore files to their - * state as described in the JSON file. It is useful for recovering lost or - * changed files based on previous snapshots. - */ - void recover(const std::string& jsonFilePath); + [[nodiscard]] const json& getDifferences() const noexcept; + [[nodiscard]] const std::vector& getTrackedFileTypes() + const noexcept; - /** - * @brief Retrieves the differences found during the comparison in JSON - * format. - * - * @return A constant reference to a JSON object containing the differences. - * - * This method provides access to the JSON representation of differences - * detected during the `compare()` method. - */ - [[nodiscard]] auto getDifferences() const -> const json&; + template Func> + void forEachFile(Func&& func) const; - /** - * @brief Retrieves the list of tracked file types. - * - * @return A constant reference to a vector of file types being tracked. - * - * This method returns the list of file extensions that the `FileTracker` is - * currently monitoring. 
- */ - [[nodiscard]] auto getTrackedFileTypes() const - -> const std::vector&; + [[nodiscard]] std::optional getFileInfo( + const fs::path& filePath) const; + + void addFileType(std::string_view fileType); + void removeFileType(std::string_view fileType); + + void setEncryptionKey(std::string_view key); private: - // Forward declaration of the implementation class struct Impl; - - // Unique pointer to hold the implementation std::unique_ptr pImpl; }; diff --git a/src/addon/version.cpp b/src/addon/version.cpp index e4b6a23b..585f1c8f 100644 --- a/src/addon/version.cpp +++ b/src/addon/version.cpp @@ -1,6 +1,7 @@ #include "version.hpp" #include +#include #include "atom/error/exception.hpp" @@ -52,6 +53,18 @@ constexpr auto Version::parse(std::string_view versionStr) -> Version { return {major, minor, patch, prerelease, build}; } +auto Version::toString() const -> std::string { + auto result = std::format("{}.{}.{}", major, minor, patch); + if (!prerelease.empty()) { + result += "-" + prerelease; + } + if (!build.empty()) { + result += "+" + build; + } + + return result; +} + constexpr auto Version::operator<(const Version& other) const -> bool { if (major != other.major) { return major < other.major; diff --git a/src/addon/version.hpp b/src/addon/version.hpp index b8972f84..0f688824 100644 --- a/src/addon/version.hpp +++ b/src/addon/version.hpp @@ -28,6 +28,7 @@ struct Version { build(std::move(bld)) {} static constexpr auto parse(std::string_view versionStr) -> Version; + [[nodiscard]] auto toString() const -> std::string; constexpr auto operator<(const Version& other) const -> bool; constexpr auto operator>(const Version& other) const -> bool; diff --git a/src/app/app.cpp b/src/app/app.cpp index 4c369df4..df63adde 100644 --- a/src/app/app.cpp +++ b/src/app/app.cpp @@ -1,29 +1,130 @@ #include "app.hpp" -#include "command.hpp" -#include "token.hpp" - -#include "atom/async/message_bus.hpp" -#include "atom/error/exception.hpp" #include "atom/function/global_ptr.hpp" -#include "atom/io/io.hpp" +#include "exception.hpp" + +#include "utils/constant.hpp" #include "atom/type/json.hpp" -using json = nlohmann::json; -namespace lithium { -class LithiumAppImpl { -public: - std::weak_ptr bus_; +namespace lithium::app { +ServerCore::ServerCore(size_t num_threads) + : asyncExecutor(std::make_unique(num_threads)), + commandDispatcher(std::make_unique(eventLoop)), + messageBus(atom::async::MessageBus::createShared()) { + // Create EventLoop + + GET_OR_CREATE_PTR(eventLoop, EventLoop, Constants::EVENTLOOP); + GET_OR_CREATE_PTR(componentManager, ComponentManager, + Constants::COMPONENT_MANAGER); + GET_OR_CREATE_PTR(commandDispatcher, CommandDispatcher, Constants::DISPATCHER); + + initializeSystemEvents(); + + +} + +ServerCore::~ServerCore() { stop(); } + +void ServerCore::start() { + publish("system.status", std::string("Server starting")); - // Max:: This config is for the core component. 
-    json config_;
-};
+    // Load all components
+    auto components = componentManager->getComponentList();
+    for (const auto& component : components) {
+        json params;  // You can customize parameters if needed
+        componentManager->loadComponent(params);
+    }
+
+    publish("system.status", std::string("Server started"));
-LithiumApp::LithiumApp() {
-    if (!atom::io::isFileExists("config/base.json")) {
-        LOG_F(WARNING, "Failed to find config/base.json");
+}
+
+void ServerCore::stop() {
+    publish("system.status", std::string("Server stopping"));
+
+    // Unload all components
+    auto components = componentManager->getComponentList();
+    for (const auto& component : components) {
+        json params;  // You can customize parameters if needed
+        componentManager->unloadComponent(params);
     }
-    GetWeakPtr("");
+
+    componentManager->destroy();  // Destroy the component manager
+
+    asyncExecutor->shutdown();
+    eventLoop->stop();
+    messageBus->clearAllSubscribers();
+    publish("system.status", std::string("Server stopped"));
+}
+
+template
+void ServerCore::registerCommand(
+    const std::string& commandName,
+    std::function handler) {
+    commandDispatcher->registerCommand(
+        commandName, [this, handler, commandName](const CommandType& cmd) {
+            publish("system.command.executed", commandName);
+            handler(cmd);
+        });
 }
-}  // namespace lithium
+
+template
+void ServerCore::executeCommand(const std::string& commandName,
+                                const CommandType& command) {
+    publish("system.command.executing", commandName);
+    commandDispatcher->dispatch(commandName, command);
+}
+
+template
+void ServerCore::subscribe(const std::string& topic,
+                           std::function handler) {
+    messageBus->subscribe(topic, handler);
+}
+
+template
+void ServerCore::publish(const std::string& topic, const MessageType& message) {
+    messageBus->publish(topic, message);
+}
+
+void ServerCore::scheduleTask(std::function task,
+                              std::chrono::milliseconds delay) {
+    if (delay.count() == 0) {
+        asyncExecutor->submit(task);
+    } else {
+        eventLoop->postDelayed(delay, std::move(task));
+    }
+}
+
+AsyncExecutor& ServerCore::getAsyncExecutor() { return *asyncExecutor; }
+
+EventLoop& ServerCore::getEventLoop() { return *eventLoop; }
+
+atom::async::MessageBus& ServerCore::getMessageBus() { return *messageBus; }
+
+void ServerCore::initializeSystemEvents() {
+    subscribe("system.status", [](const std::string& status) {
+        std::cout << "System status: " << status << std::endl;
+    });
+
+    subscribe(
+        "system.command.executing", [](const std::string& commandName) {
+            std::cout << "Executing command: " << commandName << std::endl;
+        });
+
+    subscribe(
+        "system.command.executed", [](const std::string& commandName) {
+            std::cout << "Command executed: " << commandName << std::endl;
+        });
+}
+
+// Explicit template instantiations for common types
+template void ServerCore::registerCommand(
+    const std::string&, std::function);
+template void ServerCore::executeCommand(const std::string&,
+                                         const std::string&);
+template void ServerCore::subscribe(
+    const std::string&, std::function);
+template void ServerCore::publish(const std::string&,
+                                  const std::string&);
+
+}  // namespace lithium::app
diff --git a/src/app/app.hpp b/src/app/app.hpp
index 91e559b2..fa83b597 100644
--- a/src/app/app.hpp
+++ b/src/app/app.hpp
@@ -1,19 +1,63 @@
-#ifndef LITHIUM_APP_APP_HPP
-#define LITHIUM_APP_APP_HPP
+#ifndef SERVER_CORE_HPP
+#define SERVER_CORE_HPP
 #include
+#include
+#include "atom/async/message_bus.hpp"
+#include "command.hpp"
+#include "eventloop.hpp"
+#include "executor.hpp"
-namespace lithium {
+#include "addon/manager.hpp"
+#include "config/configor.hpp"
-class LithiumAppImpl; -class LithiumApp { +namespace lithium::app { +class ServerCore { public: - LithiumApp(); + explicit ServerCore( + size_t num_threads = std::jthread::hardware_concurrency()); + ~ServerCore(); + + void start(); + void stop(); + + template + void registerCommand(const std::string& commandName, + std::function handler); + + template + void executeCommand(const std::string& commandName, + const CommandType& command); + + template + void subscribe(const std::string& topic, + std::function handler); + + template + void publish(const std::string& topic, const MessageType& message); + + void scheduleTask( + std::function task, + std::chrono::milliseconds delay = std::chrono::milliseconds(0)); + + AsyncExecutor& getAsyncExecutor(); + EventLoop& getEventLoop(); + atom::async::MessageBus& getMessageBus(); + + void loadComponent(const json& params); + void unloadComponent(const json& params); + void reloadComponent(const json& params); + std::vector getComponentList() const; private: - std::unique_ptr impl_; -}; + std::unique_ptr asyncExecutor; + std::shared_ptr eventLoop; + std::shared_ptr commandDispatcher; + std::shared_ptr messageBus; + std::shared_ptr componentManager; -} // namespace lithium + void initializeSystemEvents(); +}; +} // namespace lithium::app -#endif +#endif // SERVER_CORE_HPP diff --git a/src/app/command.cpp b/src/app/command.cpp index d2db95d6..cca05aaa 100644 --- a/src/app/command.cpp +++ b/src/app/command.cpp @@ -1,19 +1,62 @@ #include "command.hpp" +#include "eventloop.hpp" + +namespace lithium::app { +CommandDispatcher::CommandDispatcher(std::shared_ptr eventLoop) + : eventLoop_(std::move(eventLoop)) {} -namespace lithium { void CommandDispatcher::unregisterCommand(const CommandID& id) { std::unique_lock lock(mutex_); handlers_.erase(id); undoHandlers_.erase(id); } +void CommandDispatcher::recordHistory(const CommandID& id, + const std::any& command) { + auto& commandHistory = history_[id]; + commandHistory.push_back(command); + if (commandHistory.size() > maxHistorySize_) { + commandHistory.erase(commandHistory.begin()); + } +} + +void CommandDispatcher::notifySubscribers(const CommandID& id, + const std::any& command) { + auto it = subscribers_.find(id); + if (it != subscribers_.end()) { + for (auto& [_, callback] : it->second) { + callback(id, command); + } + } +} + +int CommandDispatcher::subscribe(const CommandID& id, EventCallback callback) { + std::unique_lock lock(mutex_); + int token = nextSubscriberId_++; + subscribers_[id][token] = std::move(callback); + return token; +} + +void CommandDispatcher::unsubscribe(const CommandID& id, int token) { + std::unique_lock lock(mutex_); + auto& callbacks = subscribers_[id]; + callbacks.erase(token); + if (callbacks.empty()) { + subscribers_.erase(id); + } +} + void CommandDispatcher::clearHistory() { std::unique_lock lock(mutex_); history_.clear(); } -std::vector CommandDispatcher::getActiveCommands() - const { +void CommandDispatcher::clearCommandHistory(const CommandID& id) { + std::unique_lock lock(mutex_); + history_.erase(id); +} + +auto CommandDispatcher::getActiveCommands() const -> std::vector { std::shared_lock lock(mutex_); std::vector activeCommands; activeCommands.reserve(handlers_.size()); @@ -23,4 +66,94 @@ std::vector CommandDispatcher::getActiveCommands() return activeCommands; } -} // namespace lithium +// Template implementations +template +void CommandDispatcher::registerCommand( + const CommandID& id, std::function handler, + std::optional> undoHandler) { + std::unique_lock 
lock(mutex_); + handlers_[id] = [handler](const std::any& cmd) { + handler(std::any_cast(cmd)); + }; + if (undoHandler) { + undoHandlers_[id] = [undoHandler](const std::any& cmd) { + (*undoHandler)(std::any_cast(cmd)); + }; + } +} + +template +auto CommandDispatcher::dispatch( + const CommandID& id, const CommandType& command, int priority, + std::optional delay, + CommandCallback callback) -> std::future { + auto task = [this, id, command, callback]() -> ResultType { + try { + std::shared_lock lock(mutex_); + auto it = handlers_.find(id); + if (it != handlers_.end()) { + it->second(command); + recordHistory(id, command); + notifySubscribers(id, command); + ResultType result = command; + if (callback) + callback(id, result); + return result; + } else { + throw std::runtime_error("Command not found: " + id); + } + } catch (...) { + auto ex = std::current_exception(); + if (callback) { + callback(id, ex); + } + return ex; + } + }; + + if (delay) { + return eventLoop_->postDelayed(*delay, priority, std::move(task)); + } else { + return eventLoop_->post(priority, std::move(task)); + } +} + +template +auto CommandDispatcher::getResult(std::future& resultFuture) + -> CommandType { + auto result = resultFuture.get(); + if (std::holds_alternative(result)) { + return std::any_cast(std::get(result)); + } else { + std::rethrow_exception(std::get(result)); + } +} + +template +void CommandDispatcher::undo(const CommandID& id, const CommandType& command) { + std::unique_lock lock(mutex_); + auto it = undoHandlers_.find(id); + if (it != undoHandlers_.end()) { + it->second(command); + } +} + +template +void CommandDispatcher::redo(const CommandID& id, const CommandType& command) { + dispatch(id, command, 0, std::nullopt).get(); +} + +template +auto CommandDispatcher::getCommandHistory(const CommandID& id) + -> std::vector { + std::shared_lock lock(mutex_); + std::vector history; + if (auto it = history_.find(id); it != history_.end()) { + for (const auto& cmd : it->second) { + history.push_back(std::any_cast(cmd)); + } + } + return history; +} + +} // namespace lithium::app diff --git a/src/app/command.hpp b/src/app/command.hpp index 1c77c83e..c54ffcd2 100644 --- a/src/app/command.hpp +++ b/src/app/command.hpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -11,7 +12,9 @@ #include #include -namespace lithium { +namespace lithium::app { +class EventLoop; + class CommandDispatcher { public: using CommandID = std::string; @@ -19,130 +22,58 @@ class CommandDispatcher { using ResultType = std::variant; using CommandCallback = std::function; + using EventCallback = + std::function; + + explicit CommandDispatcher(std::shared_ptr eventLoop); - // 注册命令 template - void registerCommand( - const CommandID& id, CommandHandler handler, - std::optional undoHandler = std::nullopt) { - std::unique_lock lock(mutex_); - handlers_[id] = [handler](const std::any& cmd) { - handler(std::any_cast(cmd)); - }; - if (undoHandler) { - undoHandlers_[id] = [undoHandler](const std::any& cmd) { - (*undoHandler)(std::any_cast(cmd)); - }; - } - } + void registerCommand(const CommandID& id, + std::function handler, + std::optional> + undoHandler = std::nullopt); void unregisterCommand(const CommandID& id); template auto dispatch( - const CommandID& id, const CommandType& command, bool async = true, + const CommandID& id, const CommandType& command, int priority = 0, std::optional delay = std::nullopt, - CommandCallback callback = nullptr) -> std::future { - auto task = [this, id, command, callback]() -> 
ResultType { - try { - std::shared_lock lock(mutex_); - auto it = handlers_.find(id); - if (it != handlers_.end()) { - it->second(command); - recordHistory(id, command); - ResultType result = command; - if (callback) - callback(id, result); - return result; - } else { - throw std::runtime_error("Command not found: " + id); - } - } catch (...) { - auto ex = std::current_exception(); - if (callback) { - callback(id, ex); - } - return ex; - } - }; - - if (delay) { - return std::async(std::launch::async, [task, delay]() { - std::this_thread::sleep_for(*delay); - return task(); - }); - } - if (async) { - return std::async(std::launch::async, task); - } - return std::async(std::launch::deferred, task); - } + CommandCallback callback = nullptr) -> std::future; template - auto getResult(std::future& resultFuture) -> CommandType { - auto result = resultFuture.get(); - if (std::holds_alternative(result)) { - return std::any_cast(std::get(result)); - } else { - std::rethrow_exception(std::get(result)); - } - } + auto getResult(std::future& resultFuture) -> CommandType; template - void undo(const CommandID& id, const CommandType& command) { - std::unique_lock lock(mutex_); - auto it = undoHandlers_.find(id); - if (it != undoHandlers_.end()) { - it->second(command); - } - } + void undo(const CommandID& id, const CommandType& command); template - void redo(const CommandID& id, const CommandType& command) { - dispatch(id, command, false).get(); - } + void redo(const CommandID& id, const CommandType& command); + + int subscribe(const CommandID& id, EventCallback callback); + void unsubscribe(const CommandID& id, int token); - // 获取命令执行历史 template - auto getCommandHistory(const CommandID& id) -> std::vector { - std::shared_lock lock(mutex_); - std::vector history; - if (auto it = history_.find(id); it != history_.end()) { - for (const auto& cmd : it->second) { - history.push_back(std::any_cast(cmd)); - } - } - return history; - } + auto getCommandHistory(const CommandID& id) -> std::vector; void clearHistory(); - - template - void registerCommandChain(const std::vector& ids, - std::tuple commands) { - std::apply( - [this, &ids, &commands](auto&&... 
cmds) { - (..., dispatch(ids[&cmds - &std::get<0>(commands)], cmds)); - }, - commands); - } - + void clearCommandHistory(const CommandID& id); auto getActiveCommands() const -> std::vector; private: - void recordHistory(const CommandID& id, const std::any& command) { - history_[id].push_back(command); - if (history_[id].size() > maxHistorySize_) { - history_[id].erase(history_[id].begin()); - } - } + void recordHistory(const CommandID& id, const std::any& command); + void notifySubscribers(const CommandID& id, const std::any& command); std::unordered_map handlers_; std::unordered_map undoHandlers_; std::unordered_map> history_; - mutable std::timed_mutex mutex_; + std::unordered_map> + subscribers_; + mutable std::shared_mutex mutex_; std::size_t maxHistorySize_ = 100; + std::shared_ptr eventLoop_; + int nextSubscriberId_ = 0; }; -} // namespace lithium +} // namespace lithium::app -#endif +#endif // LITHIUM_APP_COMMAND_HPP diff --git a/src/app/cotask.cpp b/src/app/cotask.cpp new file mode 100644 index 00000000..e69de29b diff --git a/src/app/cotask.hpp b/src/app/cotask.hpp new file mode 100644 index 00000000..e69de29b diff --git a/src/app/counter.cpp b/src/app/counter.cpp new file mode 100644 index 00000000..931e27f1 --- /dev/null +++ b/src/app/counter.cpp @@ -0,0 +1,170 @@ +#include "counter.hpp" + +#include +#include + +#include "atom/log/loguru.hpp" + +void FunctionCounter::start_timing(const std::source_location location) { + std::unique_lock lock(mutex); + auto& stats = counts[location.function_name()]; + stats.call_count++; + if (!time_stack.empty()) { + stats.callers.push_back(time_stack.back().first); + } + time_stack.push_back( + {location.function_name(), std::chrono::high_resolution_clock::now()}); +} + +void FunctionCounter::end_timing() { + std::unique_lock lock(mutex); + if (time_stack.empty()) + return; + + auto end_time = std::chrono::high_resolution_clock::now(); + auto [func_name, start_time] = time_stack.back(); + time_stack.pop_back(); + + auto duration = std::chrono::duration_cast( + end_time - start_time); + auto& stats = counts[func_name]; + stats.total_time += duration; + stats.min_time = std::min(stats.min_time, duration); + stats.max_time = std::max(stats.max_time, duration); + + if (duration > performance_threshold) { + LOG_F(WARNING, "Performance Alert: Function {} took {}", func_name, + format_duration(duration)); + } +} + +void FunctionCounter::print_stats(size_t top_n) { + std::shared_lock lock(mutex); + std::vector> sorted_stats( + counts.begin(), counts.end()); + std::sort(sorted_stats.begin(), sorted_stats.end(), + [](const auto& a, const auto& b) { + return a.second.call_count > b.second.call_count; + }); + + if (top_n > 0 && top_n < sorted_stats.size()) { + sorted_stats.resize(top_n); + } + + print_stats_header(); + + for (const auto& [func, stats] : sorted_stats) { + print_function_stats(func, stats); + } +} + +void FunctionCounter::reset_stats() { + std::unique_lock lock(mutex); + counts.clear(); + time_stack.clear(); +} + +void FunctionCounter::save_stats(const std::string& filename) { + std::shared_lock lock(mutex); + std::ofstream file(filename); + if (!file) { + LOG_F(ERROR, "Failed to open file for writing: {}", filename); + return; + } + + for (const auto& [func, stats] : counts) { + file << func << "," << stats.call_count << "," + << stats.total_time.count() << "," << stats.min_time.count() << "," + << stats.max_time.count(); + for (const auto& caller : stats.callers) { + file << "," << caller; + } + file << "\n"; + } +} + +void 
FunctionCounter::load_stats(const std::string& filename) { + std::unique_lock lock(mutex); + std::ifstream file(filename); + if (!file) { + LOG_F(ERROR, "Failed to open file for reading: {}", filename); + return; + } + + counts.clear(); + std::string line; + while (std::getline(file, line)) { + std::istringstream iss(line); + std::string func_name; + FunctionStats stats; + long long total_time, min_time, max_time; + + if (std::getline(iss, func_name, ',') && iss >> stats.call_count && + iss.ignore() && iss >> total_time && iss.ignore() && + iss >> min_time && iss.ignore() && iss >> max_time) { + stats.total_time = std::chrono::nanoseconds(total_time); + stats.min_time = std::chrono::nanoseconds(min_time); + stats.max_time = std::chrono::nanoseconds(max_time); + + std::string caller; + while (std::getline(iss, caller, ',')) { + stats.callers.push_back(caller); + } + + counts[func_name] = stats; + } + } +} + +void FunctionCounter::set_performance_threshold( + std::chrono::nanoseconds threshold) { + std::unique_lock lock(mutex); + performance_threshold = threshold; +} + +void FunctionCounter::print_call_graph() { + std::shared_lock lock(mutex); + LOG_F(INFO, "Printing Call Graph"); + for (const auto& [func, stats] : counts) { + LOG_F(INFO, "Function: {}", func); + for (const auto& caller : stats.callers) { + LOG_F(INFO, " Caller: {}", caller); + } + } +} + +void FunctionCounter::print_stats_header() { + LOG_F(INFO, + std::format("{:<30}{:>10}{:>15}{:>15}{:>15}{:>15}\n", "Function Name", + "Calls", "Total Time", "Avg Time", "Min Time", "Max Time") + .c_str()); +} + +void FunctionCounter::print_function_stats(std::string_view func, + const FunctionStats& stats) { + std::chrono::nanoseconds avg_time{0}; + if (stats.call_count > 0) { + avg_time = std::chrono::nanoseconds(stats.total_time.count() / + stats.call_count); + } + + LOG_F(INFO, std::format("{:<30}{:>10}{:>15}{:>15}{:>15}{:>15}\n", func, + stats.call_count, format_duration(stats.total_time), + format_duration(avg_time), + format_duration(stats.min_time), + format_duration(stats.max_time)) + .c_str()); +} + +std::string FunctionCounter::format_duration(std::chrono::nanoseconds ns) { + auto us = std::chrono::duration_cast(ns); + if (us.count() < 1000) { + return std::to_string(us.count()) + "µs"; + } + auto ms = std::chrono::duration_cast(ns); + if (ms.count() < 1000) { + return std::to_string(ms.count()) + "ms"; + } + auto s = std::chrono::duration_cast(ns); + return std::to_string(s.count()) + "s"; +} diff --git a/src/app/counter.hpp b/src/app/counter.hpp new file mode 100644 index 00000000..01b4dbc8 --- /dev/null +++ b/src/app/counter.hpp @@ -0,0 +1,62 @@ +#ifndef FUNCTION_COUNTER_HPP +#define FUNCTION_COUNTER_HPP + +#include +#include +#include +#include +#include +#include + +#define COUNT_AND_TIME_CALL \ + FunctionCounter::start_timing(); \ + auto counter_guard = std::make_unique(0); \ + (void)counter_guard; \ + auto timer_end = [](void*) { FunctionCounter::end_timing(); }; \ + std::unique_ptr timer_guard( \ + counter_guard.get(), timer_end); + +class FunctionCounter { +public: + struct FunctionStats { + size_t call_count = 0; + std::chrono::nanoseconds total_time{0}; + std::chrono::nanoseconds min_time{std::chrono::nanoseconds::max()}; + std::chrono::nanoseconds max_time{std::chrono::nanoseconds::min()}; + std::vector callers; + }; + + static void start_timing(const std::source_location location = std::source_location::current()); + static void end_timing(); + static void print_stats(size_t top_n = 0); + static void reset_stats(); + 
static void save_stats(const std::string& filename); + static void load_stats(const std::string& filename); + static void set_performance_threshold(std::chrono::nanoseconds threshold); + static void print_call_graph(); + + template + static void conditional_count(bool condition, Func&& func); + +private: + static inline std::map counts; + static inline std::vector> time_stack; + static inline std::shared_mutex mutex; + static inline std::chrono::nanoseconds performance_threshold; + + static void print_stats_header(); + static void print_function_stats(std::string_view func, const FunctionStats& stats); + static std::string format_duration(std::chrono::nanoseconds ns); +}; + +template +void FunctionCounter::conditional_count(bool condition, Func&& func) { + if (condition) { + COUNT_AND_TIME_CALL; + func(); + } else { + func(); + } +} + +#endif // FUNCTION_COUNTER_HPP diff --git a/src/app/eventloop.cpp b/src/app/eventloop.cpp new file mode 100644 index 00000000..d4fe2fb7 --- /dev/null +++ b/src/app/eventloop.cpp @@ -0,0 +1,276 @@ +#include "eventloop.hpp" + +#ifdef __linux__ +#include +#include +#elif _WIN32 +#include +#pragma comment(lib, "ws2_32.lib") +#endif + +#include "atom/log/loguru.hpp" + +namespace lithium::app { +EventLoop::EventLoop(int num_threads) : stop_flag_(false) { +#ifdef __linux__ + epoll_fd_ = epoll_create1(0); + if (epoll_fd_ == -1) { + ABORT_F("Failed to create epoll file descriptor"); + exit(EXIT_FAILURE); + } + epoll_events_.resize(10); +#elif _WIN32 + FD_ZERO(&read_fds); + WSADATA wsaData; + WSAStartup(MAKEWORD(2, 2), &wsaData); +#endif + +#ifdef USE_ASIO + // Boost.Asio 初始化 + for (int i = 0; i < num_threads; ++i) { + thread_pool_.emplace_back([this] { io_context_.run(); }); + } +#else + // 初始化线程池 + for (int i = 0; i < num_threads; ++i) { + thread_pool_.emplace_back(&EventLoop::workerThread, this); + } +#endif +} + +EventLoop::~EventLoop() { + stop(); + for (auto& thread : thread_pool_) { + if (thread.joinable()) { + thread.join(); + } + } +#ifdef __linux__ + close(epoll_fd_); + if (signal_fd_ != -1) { + close(signal_fd_); + } +#elif _WIN32 + WSACleanup(); +#endif +} + +void EventLoop::run() { + stop_flag_.store(false); +#ifndef USE_ASIO + workerThread(); +#endif +} + +void EventLoop::workerThread() { + while (!stop_flag_.load()) { + std::function task; + { + std::unique_lock lock(queue_mutex_); + if (!tasks_.empty()) { + auto currentTime = std::chrono::steady_clock::now(); + if (tasks_.top().execTime <= currentTime) { + task = tasks_.top().func; + tasks_.pop(); + } + } + } + + if (task) { + task(); + } else { +#ifdef __linux__ + int nfds = epoll_wait(epoll_fd_, epoll_events_.data(), + epoll_events_.size(), 10); + if (nfds == -1) { + ABORT_F("Epoll wait failed"); + } else if (nfds > 0) { + for (int i = 0; i < nfds; ++i) { + int fd = epoll_events_[i].data.fd; + if (fd == signal_fd_) { + // 处理信号事件 + uint64_t sigVal; + read(signal_fd_, &sigVal, sizeof(uint64_t)); + auto it = signal_handlers_.find(sigVal); + if (it != signal_handlers_.end()) { + it->second(); + } + } else { + // 处理文件描述符事件 + } + } + } +#elif _WIN32 + timeval timeout; + timeout.tv_sec = 0; + timeout.tv_usec = 10000; // 10ms + fd_set tmp_fds = read_fds; + int result = select(0, &tmp_fds, nullptr, nullptr, &timeout); + if (result > 0) { + for (u_int i = 0; i < tmp_fds.fd_count; ++i) { + SOCKET fd = tmp_fds.fd_array[i]; + // 处理 socket 事件 + } + } +#endif + std::this_thread::sleep_for( + std::chrono::milliseconds(10)); // Idle time + } + } +} + +void EventLoop::stop() { + stop_flag_.store(true); + wakeup(); +} 
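A note on the EventLoop code added above: the constructor, workerThread(), and stop() define the loop's lifecycle. Worker threads are created up front, poll the priority queue and epoll/select roughly every 10 ms, and exit once stop() raises the flag. The following is a minimal usage sketch and is not part of the patch; the main() scaffolding and the sample lambdas are illustrative assumptions, and the sketch presumes that the template definitions of post()/postDelayed() which follow in eventloop.cpp are visible to the calling code (they are defined in the .cpp rather than the header, so callers in other translation units would otherwise fail to link).

    #include <chrono>
    #include "eventloop.hpp"

    int main() {
        lithium::app::EventLoop loop(2);  // two worker threads begin polling immediately

        // Immediate task with priority 1; post() returns a std::future<int> here.
        auto sum = loop.post(1, [](int a, int b) { return a + b; }, 2, 3);

        // Fire-and-forget task scheduled roughly 50 ms from now.
        loop.postDelayed(std::chrono::milliseconds(50), [] { /* delayed work */ });

        int value = sum.get();  // blocks until a worker has executed the task
        loop.stop();            // raise the stop flag; the destructor joins the workers
        return value == 5 ? 0 : 1;
    }

Because the workers poll rather than block on condition_, a newly posted task may wait up to one polling interval (about 10 to 20 ms as written) before it starts; that latency is a property of the code above, not of the sketch.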
+
+template <typename F, typename... Args>
+auto EventLoop::post(int priority, F&& f, Args&&... args)
+    -> std::future<std::invoke_result_t<F, Args...>> {
+    using return_type = std::invoke_result_t<F, Args...>;
+    auto task = std::make_shared<std::packaged_task<return_type()>>(
+        std::bind(std::forward<F>(f), std::forward<Args>(args)...));
+    std::future<return_type> result = task->get_future();
+    {
+        std::unique_lock lock(queue_mutex_);
+        // Wrap the packaged task so it runs when a worker dequeues it,
+        // not at enqueue time.
+        tasks_.emplace(Task{[task]() { (*task)(); }, priority,
+                            std::chrono::steady_clock::now(),
+                            next_task_id_++});
+    }
+#ifdef USE_ASIO
+    io_context_.post([task]() { (*task)(); });
+#else
+    condition_.notify_one();  // Notify a waiting worker thread
+#endif
+    return result;
+}
+
+template <typename F, typename... Args>
+auto EventLoop::postDelayed(std::chrono::milliseconds delay, int priority,
+                            F&& f, Args&&... args)
+    -> std::future<std::invoke_result_t<F, Args...>> {
+    using return_type = std::invoke_result_t<F, Args...>;
+    auto task = std::make_shared<std::packaged_task<return_type()>>(
+        std::bind(std::forward<F>(f), std::forward<Args>(args)...));
+    std::future<return_type> result = task->get_future();
+    auto execTime = std::chrono::steady_clock::now() + delay;
+#ifdef USE_ASIO
+    auto timer =
+        std::make_unique<boost::asio::steady_timer>(io_context_, delay);
+    timers_.emplace_back(std::move(timer));
+    timers_.back()->async_wait([task](const boost::system::error_code& ec) {
+        if (!ec) {
+            (*task)();
+        }
+    });
+#else
+    {
+        std::unique_lock lock(queue_mutex_);
+        tasks_.emplace(
+            Task{[task]() { (*task)(); }, priority, execTime, next_task_id_++});
+    }
+    condition_.notify_one();
+#endif
+    return result;
+}
+
+template <typename F, typename... Args>
+auto EventLoop::postDelayed(std::chrono::milliseconds delay, F&& f,
+                            Args&&... args)
+    -> std::future<std::invoke_result_t<F, Args...>> {
+    return postDelayed(delay, 0, std::forward<F>(f),
+                       std::forward<Args>(args)...);
+}
+
+auto EventLoop::adjustTaskPriority(int task_id, int new_priority) -> bool {
+    std::unique_lock lock(queue_mutex_);
+    std::priority_queue<Task> newQueue;
+
+    bool found = false;
+    while (!tasks_.empty()) {
+        Task task = std::move(const_cast<Task&>(tasks_.top()));
+        tasks_.pop();
+        if (task.taskId == task_id) {
+            task.priority = new_priority;
+            found = true;
+        }
+        newQueue.push(std::move(task));
+    }
+
+    tasks_ = std::move(newQueue);
+    return found;
+}
+
+template <typename F, typename G>
+void EventLoop::postWithDependency(F&& f, G&& dependency_task) {
+    std::future dependency = dependency_task.get_future();
+    std::thread([this, f = std::forward<F>(f),
+                 dependency = std::move(dependency)]() mutable {
+        dependency.wait();   // Wait for the dependency to finish
+        post(std::move(f));  // Then post the dependent task
+    }).detach();
+}
+
+void EventLoop::subscribeEvent(const std::string& event_name,
+                               const EventCallback& callback) {
+    std::unique_lock lock(queue_mutex_);
+    event_subscribers_[event_name].push_back(callback);
+}
+
+void EventLoop::emitEvent(const std::string& event_name) {
+    std::vector<EventCallback> callbacks;
+    {
+        std::unique_lock lock(queue_mutex_);
+        auto it = event_subscribers_.find(event_name);
+        if (it != event_subscribers_.end()) {
+            callbacks = it->second;
+        }
+    }
+    // Post outside the lock: post() acquires queue_mutex_ itself, so
+    // posting while still holding it would deadlock.
+    for (const auto& callback : callbacks) {
+        post(callback);
+    }
+}
+
+#ifdef __linux__
+void EventLoop::addSignalHandler(int signal, std::function<void()> handler) {
+    std::unique_lock lock(queue_mutex_);
+    signal_handlers_[signal] = std::move(handler);
+
+    sigset_t mask;
+    sigemptyset(&mask);
+    sigaddset(&mask, signal);
+    signal_fd_ = signalfd(-1, &mask, SFD_NONBLOCK);
+
+    epoll_event ev;
+    ev.events = EPOLLIN;
+    ev.data.fd = signal_fd_;
+    epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, signal_fd_, &ev);
+}
+#endif
+
+void EventLoop::wakeup() {
+#ifdef __linux__
+    // Linux: use an eventfd to wake up epoll
+    int eventFd = eventfd(0, EFD_NONBLOCK);
+    epoll_event ev = {0};
+    ev.events = EPOLLIN;
+    ev.data.fd = eventFd;
+    epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, eventFd, &ev);
+    uint64_t one = 1;
+    write(eventFd, &one, sizeof(uint64_t));  // trigger the event
+    close(eventFd);
+#elif _WIN32
+    // Windows: emulate a wakeup via a socket registered with select
+
SOCKET sock = socket(AF_INET, SOCK_STREAM, 0); + FD_SET(sock, &read_fds); +#endif +} + +#ifdef __linux__ +void EventLoop::addEpollFd(int fd) const { + epoll_event ev; + ev.events = EPOLLIN; + ev.data.fd = fd; + if (epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, fd, &ev) == -1) { + ABORT_F("Failed to add fd to epoll"); + } +} +#elif _WIN32 +void EventLoop::add_socket_fd(SOCKET fd) { FD_SET(fd, &read_fds); } +#endif + +} // namespace lithium::app diff --git a/src/app/eventloop.hpp b/src/app/eventloop.hpp new file mode 100644 index 00000000..35caf8be --- /dev/null +++ b/src/app/eventloop.hpp @@ -0,0 +1,136 @@ +#ifndef EVENT_LOOP_HPP +#define EVENT_LOOP_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef USE_ASIO +#include +#endif + +#ifdef _WIN32 +#include +#elif __linux__ +#include +#include +#endif + +namespace lithium::app { + +class EventLoop { +public: + explicit EventLoop(int num_threads = 1); + ~EventLoop(); + + void run(); + void stop(); + + // 提交任务 (支持优先级) + template + auto post(int priority, F&& f, + Args&&... args) -> std::future>; + + // 无优先级任务提交 + template + auto post(F&& f, + Args&&... args) -> std::future>; + + // 延迟任务提交 + template + auto postDelayed(std::chrono::milliseconds delay, int priority, F&& f, + Args&&... args) + -> std::future>; + + template + auto postDelayed(std::chrono::milliseconds delay, F&& f, Args&&... args) + -> std::future>; + + // 动态调整任务优先级 + auto adjustTaskPriority(int task_id, int new_priority) -> bool; + + // 任务依赖 + template + void postWithDependency(F&& f, G&& dependency_task); + + // 定时任务 + void schedulePeriodic(std::chrono::milliseconds interval, int priority, + std::function func); + + // 增加任务取消支持 + template + auto postCancelable(F&& f, + std::atomic& cancel_flag) -> std::future; + + // 定时器接口 + void setTimeout(std::function func, + std::chrono::milliseconds delay); + void setInterval(std::function func, + std::chrono::milliseconds interval); + + // 事件订阅/发布 + using EventCallback = std::function; + void subscribeEvent(const std::string& event_name, + const EventCallback& callback); + void emitEvent(const std::string& event_name); + + // 文件描述符的支持(epoll 或 select) +#ifdef __linux__ + void addEpollFd(int fd) const; + void addSignalHandler(int signal, + std::function handler); // 增加信号处理 +#elif _WIN32 + void add_socket_fd(SOCKET fd); +#endif + +private: + struct Task { + std::function func; + int priority; + std::chrono::steady_clock::time_point execTime; + int taskId; + + auto operator<(const Task& other) const -> bool; + }; + + std::priority_queue tasks_; + std::mutex queue_mutex_; + std::condition_variable condition_; + std::atomic stop_flag_; + std::vector thread_pool_; // 线程池支持 + + std::unordered_map> + event_subscribers_; // 事件订阅者 + std::unordered_map> + signal_handlers_; // 信号处理 + int next_task_id_ = 0; + +#ifdef USE_ASIO + boost::asio::io_context io_context_; + std::vector> + timers_; // 定时器列表 +#endif + +#ifdef __linux__ + int epoll_fd_; + std::vector epoll_events_; + int signal_fd_; // 用于监听系统信号 +#elif _WIN32 + fd_set read_fds; +#endif + + void workerThread(); // 用于线程池中的工作线程 + void wakeup(); // 用于唤醒 `epoll` 或 `select` 处理 +}; + +} // namespace lithium::app + +#endif // EVENT_LOOP_HPP diff --git a/src/app/executor.cpp b/src/app/executor.cpp new file mode 100644 index 00000000..a4731647 --- /dev/null +++ b/src/app/executor.cpp @@ -0,0 +1,116 @@ +#include "executor.hpp" + +#include + +AsyncExecutor::AsyncExecutor(size_t thread_count) + : stop_flag(false), active_threads(0), task_counter(0) { 
+ resize(thread_count); +} + +AsyncExecutor::~AsyncExecutor() { + shutdown(); +} + +bool AsyncExecutor::cancel_task(size_t task_id) { + std::unique_lock lock(queue_mutex); + auto it = std::find_if(tasks.begin(), tasks.end(), + [task_id](const Task& task) { return task.task_id == task_id; }); + + if (it != tasks.end()) { + it->is_cancelled = true; + return true; + } + + return false; +} + +void AsyncExecutor::resize(size_t new_thread_count) { + std::unique_lock lock(queue_mutex); + size_t current_count = workers.size(); + if (new_thread_count == current_count) + return; + + if (new_thread_count < current_count) { + stop_flag = true; + condition.notify_all(); + for (size_t i = 0; i < current_count - new_thread_count; ++i) { + if (workers[i].joinable()) { + workers[i].join(); + } + } + workers.resize(new_thread_count); + stop_flag = false; + } else { + for (size_t i = current_count; i < new_thread_count; ++i) { + workers.emplace_back([this] { + while (true) { + Task task; + { + std::unique_lock lock(queue_mutex); + condition.wait(lock, [this] { + return stop_flag || !tasks.empty(); + }); + + if (stop_flag && tasks.empty()) { + return; + } + + ++active_threads; + + task = std::move(tasks.front()); + std::pop_heap(tasks.begin(), tasks.end()); + tasks.pop_back(); + } + if (!task.is_cancelled) { + task(); + } + --active_threads; + } + }); + } + } +} + +void AsyncExecutor::shutdown(bool force) { + { + std::unique_lock lock(queue_mutex); + stop_flag = true; + } + + if (!force) { + std::unique_lock lock(queue_mutex); + condition.wait(lock, [this] { return tasks.empty(); }); + } + + condition.notify_all(); + for (std::thread& worker : workers) { + if (worker.joinable()) { + worker.join(); + } + } + workers.clear(); +} + +void AsyncExecutor::shutdown_delayed(std::chrono::milliseconds delay) { + std::thread([this, delay] { + std::this_thread::sleep_for(delay); + shutdown(); + }).detach(); +} + +size_t AsyncExecutor::get_active_threads() const { + return active_threads; +} + +size_t AsyncExecutor::get_task_queue_size() const { + std::unique_lock lock(queue_mutex); + return tasks.size(); +} + +bool AsyncExecutor::Task::operator<(const Task& other) const { + return priority < other.priority; +} + +void AsyncExecutor::Task::operator()() { + func(); +} diff --git a/src/app/executor.hpp b/src/app/executor.hpp new file mode 100644 index 00000000..1b0878a3 --- /dev/null +++ b/src/app/executor.hpp @@ -0,0 +1,77 @@ +#ifndef ASYNC_EXECUTOR_HPP +#define ASYNC_EXECUTOR_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +class AsyncExecutor { +public: + explicit AsyncExecutor( + size_t thread_count = std::thread::hardware_concurrency()); + ~AsyncExecutor(); + + template + auto submit(int priority, F&& f, Args&&... args) + -> std::future> { + using return_type = std::invoke_result_t; + + auto task = std::make_shared>( + std::bind(std::forward(f), std::forward(args)...)); + + std::future result = task->get_future(); + { + std::unique_lock lock(queue_mutex); + if (stop_flag) { + throw std::runtime_error("Submit on stopped AsyncExecutor"); + } + + tasks.push_back( + Task{[task]() { (*task)(); }, priority, false, task_counter++}); + std::push_heap(tasks.begin(), tasks.end()); + } + + condition.notify_one(); + return result; + } + + template + auto submit(F&& f, Args&&... 
args) + -> std::future> { + return submit(0, std::forward(f), std::forward(args)...); + } + + bool cancel_task(size_t task_id); + void resize(size_t new_thread_count); + void shutdown(bool force = false); + void shutdown_delayed(std::chrono::milliseconds delay); + size_t get_active_threads() const; + size_t get_task_queue_size() const; + +private: + struct Task { + std::function func; + int priority; + bool is_cancelled; + size_t task_id; + + bool operator<(const Task& other) const; + void operator()(); + }; + + std::vector workers; + std::vector tasks; + mutable std::mutex queue_mutex; + std::condition_variable condition; + std::atomic stop_flag; + std::atomic active_threads; + std::atomic task_counter; +}; + +#endif // ASYNC_EXECUTOR_HPP diff --git a/src/atom/algorithm/error_calibration.hpp b/src/atom/algorithm/error_calibration.hpp new file mode 100644 index 00000000..3bc26037 --- /dev/null +++ b/src/atom/algorithm/error_calibration.hpp @@ -0,0 +1,369 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "atom/error/exception.hpp" +#include "atom/log/loguru.hpp" + +template +class AdvancedErrorCalibration { +private: + T slope_ = 1.0; + T intercept_ = 0.0; + std::optional r_squared_; + std::vector residuals_; + T mse_ = 0.0; // Mean Squared Error + T mae_ = 0.0; // Mean Absolute Error + + void calculate_metrics(const std::vector& measured, + const std::vector& actual) { + T sumSquaredError = 0.0; + T sumAbsoluteError = 0.0; + T meanActual = + std::accumulate(actual.begin(), actual.end(), T(0)) / actual.size(); + T ssTotal = 0; + T ssResidual = 0; + + residuals_.clear(); + for (size_t i = 0; i < actual.size(); ++i) { + T predicted = apply(measured[i]); + T error = actual[i] - predicted; + residuals_.push_back(error); + + sumSquaredError += error * error; + sumAbsoluteError += std::abs(error); + ssTotal += std::pow(actual[i] - meanActual, 2); + ssResidual += std::pow(error, 2); + } + + mse_ = sumSquaredError / actual.size(); + mae_ = sumAbsoluteError / actual.size(); + r_squared_ = 1 - (ssResidual / ssTotal); + } + + // 非线性校准函数类型 + using NonlinearFunction = std::function&)>; + + // Levenberg-Marquardt算法用于非线性拟合 + auto levenbergMarquardt(const std::vector& x, const std::vector& y, + NonlinearFunction func, + std::vector initial_params, + int max_iterations = 100, T lambda = 0.01, + T epsilon = 1e-8) -> std::vector { + int n = x.size(); + int m = initial_params.size(); + std::vector params = initial_params; + std::vector prevParams(m); + std::vector> jacobian(n, std::vector(m)); + + for (int iteration = 0; iteration < max_iterations; ++iteration) { + std::vector residuals(n); + for (int i = 0; i < n; ++i) { + residuals[i] = y[i] - func(x[i], params); + for (int j = 0; j < m; ++j) { + T h = std::max(1e-6, std::abs(params[j]) * 1e-6); + std::vector paramsPlusH = params; + paramsPlusH[j] += h; + jacobian[i][j] = + (func(x[i], paramsPlusH) - func(x[i], params)) / h; + } + } + + // 计算 J^T * J 和 J^T * r + std::vector> JTJ(m, std::vector(m)); + std::vector jTr(m); + for (int i = 0; i < m; ++i) { + for (int j = 0; j < m; ++j) { + JTJ[i][j] = 0; + for (int k = 0; k < n; ++k) { + JTJ[i][j] += jacobian[k][i] * jacobian[k][j]; + } + if (i == j) + JTJ[i][j] += lambda; + } + jTr[i] = 0; + for (int k = 0; k < n; ++k) { + jTr[i] += jacobian[k][i] * residuals[k]; + } + } + + // 解线性方程组 + std::vector delta = solveLinearSystem(JTJ, jTr); + + // 更新参数 + prevParams = params; + for (int i = 0; i < m; ++i) { + params[i] += delta[i]; + } + + // 检查收敛 + T 
diff = 0; + for (int i = 0; i < m; ++i) { + diff += std::abs(params[i] - prevParams[i]); + } + if (diff < epsilon) { + break; + } + } + + return params; + } + auto solveLinearSystem(const std::vector>& A, + const std::vector& b) -> std::vector { + int n = A.size(); + std::vector> augmented(n, std::vector(n + 1)); + for (int i = 0; i < n; ++i) { + for (int j = 0; j < n; ++j) { + augmented[i][j] = A[i][j]; + } + augmented[i][n] = b[i]; + } + + for (int i = 0; i < n; ++i) { + int maxRow = i; + for (int k = i + 1; k < n; ++k) { + if (std::abs(augmented[k][i]) > + std::abs(augmented[maxRow][i])) { + maxRow = k; + } + } + std::swap(augmented[i], augmented[maxRow]); + + for (int k = i + 1; k < n; ++k) { + T factor = augmented[k][i] / augmented[i][i]; + for (int j = i; j <= n; ++j) { + augmented[k][j] -= factor * augmented[i][j]; + } + } + } + + std::vector x(n); + for (int i = n - 1; i >= 0; --i) { + x[i] = augmented[i][n]; + for (int j = i + 1; j < n; ++j) { + x[i] -= augmented[i][j] * x[j]; + } + x[i] /= augmented[i][i]; + } + + return x; + } + +public: + void linearCalibrate(const std::vector& measured, + const std::vector& actual) { + if (measured.size() != actual.size() || measured.empty()) { + THROW_INVALID_ARGUMENT( + "Input vectors must be non-empty and of equal size"); + } + + T sumX = std::accumulate(measured.begin(), measured.end(), T(0)); + T sumY = std::accumulate(actual.begin(), actual.end(), T(0)); + T sumXy = std::inner_product(measured.begin(), measured.end(), + actual.begin(), T(0)); + T sumXx = std::inner_product(measured.begin(), measured.end(), + measured.begin(), T(0)); + + T n = static_cast(measured.size()); + slope_ = (n * sumXy - sumX * sumY) / (n * sumXx - sumX * sumX); + intercept_ = (sumY - slope_ * sumX) / n; + + calculate_metrics(measured, actual); + } + + void polynomialCalibrate(const std::vector& measured, + const std::vector& actual, int degree) { + if (measured.size() != actual.size() || measured.empty()) { + THROW_INVALID_ARGUMENT( + "Input vectors must be non-empty and of equal size"); + } + + auto polyFunc = [degree](T x, const std::vector& params) { + T result = 0; + for (int i = 0; i <= degree; ++i) { + result += params[i] * std::pow(x, i); + } + return result; + }; + + std::vector initialParams(degree + 1, 1.0); + auto params = + levenbergMarquardt(measured, actual, polyFunc, initialParams); + + // 更新校准参数 + slope_ = params[1]; // 一阶系数作为斜率 + intercept_ = params[0]; // 常数项作为截距 + + calculate_metrics(measured, actual); + } + + [[nodiscard]] auto apply(T value) const -> T { + return slope_ * value + intercept_; + } + + void printParameters() const { + LOG_F(INFO, "Calibration parameters: slope = {}, intercept = {}", + slope_, intercept_); + if (r_squared_.has_value()) { + LOG_F(INFO, "R-squared = {}", r_squared_.value()); + } + LOG_F(INFO, "MSE = {}, MAE = {}", mse_, mae_); + } + + [[nodiscard]] auto getResiduals() const -> std::vector { + return residuals_; + } + + void plotResiduals(const std::string& filename) const { + std::ofstream file(filename); + if (!file.is_open()) { + THROW_FAIL_TO_OPEN_FILE("Failed to open file: " + filename); + } + + file << "Index,Residual\n"; + for (size_t i = 0; i < residuals_.size(); ++i) { + file << i << "," << residuals_[i] << "\n"; + } + } + + auto bootstrapConfidenceInterval( + const std::vector& measured, const std::vector& actual, + int n_iterations = 1000, + double confidence_level = 0.95) -> std::pair { + std::vector bootstrapSlopes; + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution<> 
dis(0, measured.size() - 1); + + for (int i = 0; i < n_iterations; ++i) { + std::vector bootMeasured; + std::vector bootActual; + for (size_t j = 0; j < measured.size(); ++j) { + int idx = dis(gen); + bootMeasured.push_back(measured[idx]); + bootActual.push_back(actual[idx]); + } + + AdvancedErrorCalibration bootCalibrator; + bootCalibrator.linearCalibrate(bootMeasured, bootActual); + bootstrapSlopes.push_back(bootCalibrator.getSlope()); + } + + std::sort(bootstrapSlopes.begin(), bootstrapSlopes.end()); + int lowerIdx = + static_cast((1 - confidence_level) / 2 * n_iterations); + int upperIdx = + static_cast((1 + confidence_level) / 2 * n_iterations); + + return {bootstrapSlopes[lowerIdx], bootstrapSlopes[upperIdx]}; + } + + void outlierDetection(const std::vector& measured, + const std::vector& actual, T threshold = 2.0) { + if (residuals_.empty()) { + THROW_RUNTIME_ERROR("Please call calculate_metrics() first"); + } + + T meanResidual = + std::accumulate(residuals_.begin(), residuals_.end(), T(0)) / + residuals_.size(); + T std_dev = std::sqrt( + std::accumulate(residuals_.begin(), residuals_.end(), T(0), + [meanResidual](T acc, T val) { + return acc + std::pow(val - meanResidual, 2); + }) / + residuals_.size()); + + /* + std::cout << "检测到的异常值:" << std::endl; + for (size_t i = 0; i < residuals_.size(); ++i) { + if (std::abs(residuals_[i] - meanResidual) > threshold * std_dev) { + std::cout << "索引: " << i << ", 测量值: " << measured[i] + << ", 实际值: " << actual[i] + << ", 残差: " << residuals_[i] << std::endl; + } + } + */ + } + + void crossValidation(const std::vector& measured, + const std::vector& actual, int k = 5) { + if (measured.size() != actual.size() || measured.size() < k) { + THROW_INVALID_ARGUMENT( + "Input vectors must be non-empty and of size"); + } + + std::vector mseValues; + std::vector maeValues; + std::vector rSquaredValues; + + for (int i = 0; i < k; ++i) { + std::vector trainMeasured; + std::vector trainActual; + std::vector testMeasured; + std::vector testActual; + for (size_t j = 0; j < measured.size(); ++j) { + if (j % k == i) { + testMeasured.push_back(measured[j]); + testActual.push_back(actual[j]); + } else { + trainMeasured.push_back(measured[j]); + trainActual.push_back(actual[j]); + } + } + + AdvancedErrorCalibration cvCalibrator; + cvCalibrator.linearCalibrate(trainMeasured, trainActual); + + T foldMse = 0; + T foldMae = 0; + T foldSsTotal = 0; + T foldSsResidual = 0; + T meanTestActual = + std::accumulate(testActual.begin(), testActual.end(), T(0)) / + testActual.size(); + for (size_t j = 0; j < testMeasured.size(); ++j) { + T predicted = cvCalibrator.apply(testMeasured[j]); + T error = testActual[j] - predicted; + foldMse += error * error; + foldMae += std::abs(error); + foldSsTotal += std::pow(testActual[j] - meanTestActual, 2); + foldSsResidual += std::pow(error, 2); + } + + mseValues.push_back(foldMse / testMeasured.size()); + maeValues.push_back(foldMae / testMeasured.size()); + rSquaredValues.push_back(1 - (foldSsResidual / foldSsTotal)); + } + + T avgMse = + std::accumulate(mseValues.begin(), mseValues.end(), T(0)) / k; + T avgMae = + std::accumulate(maeValues.begin(), maeValues.end(), T(0)) / k; + T avgRSquared = std::accumulate(rSquaredValues.begin(), + rSquaredValues.end(), T(0)) / + k; + + /* + std::cout << "K-fold 交叉验证结果 (k = " << k << "):" << std::endl; + std::cout << "平均 MSE: " << avgMse << std::endl; + std::cout << "平均 MAE: " << avgMae << std::endl; + std::cout << "平均 R-squared: " << avgRSquared << std::endl; + */ + } + + [[nodiscard]] auto 
getSlope() const -> T { return slope_; } + [[nodiscard]] auto getIntercept() const -> T { return intercept_; } + [[nodiscard]] auto getRSquared() const -> std::optional { + return r_squared_; + } + [[nodiscard]] auto getMse() const -> T { return mse_; } + [[nodiscard]] auto getMae() const -> T { return mae_; } +}; diff --git a/src/atom/algorithm/matrix.hpp b/src/atom/algorithm/matrix.hpp new file mode 100644 index 00000000..f79976ba --- /dev/null +++ b/src/atom/algorithm/matrix.hpp @@ -0,0 +1,369 @@ +#ifndef ATOM_ALGORITHM_MATRIX_HPP +#define ATOM_ALGORITHM_MATRIX_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "atom/error/exception.hpp" + +namespace atom::algorithm { +template +class Matrix; + +template +constexpr Matrix identity(); + +// 矩阵模板类,支持编译期矩阵计算 +template +class Matrix { +private: + std::array data_{}; + +public: + // 构造函数 + constexpr Matrix() = default; + constexpr explicit Matrix(const std::array& arr) + : data_(arr) {} + + // 访问矩阵元素 + constexpr auto operator()(std::size_t row, std::size_t col) -> T& { + return data_[row * Cols + col]; + } + + constexpr auto operator()(std::size_t row, + std::size_t col) const -> const T& { + return data_[row * Cols + col]; + } + + // 数据访问器 + auto getData() const -> const std::array& { return data_; } + + auto getData() -> std::array& { return data_; } + + // 打印矩阵 + void print(int width = 8, int precision = 2) const { + for (std::size_t i = 0; i < Rows; ++i) { + for (std::size_t j = 0; j < Cols; ++j) { + std::cout << std::setw(width) << std::fixed + << std::setprecision(precision) << (*this)(i, j) + << ' '; + } + std::cout << '\n'; + } + } + + // 矩阵的迹(对角线元素之和) + constexpr auto trace() const -> T { + static_assert(Rows == Cols, + "Trace is only defined for square matrices"); + T result = T{}; + for (std::size_t i = 0; i < Rows; ++i) { + result += (*this)(i, i); + } + return result; + } + + // Frobenius范数 + auto freseniusNorm() const -> T { + T sum = T{}; + for (const auto& elem : data_) { + sum += std::norm(elem); + } + return std::sqrt(sum); + } + + // 矩阵的最大元素 + auto maxElement() const -> T { + return *std::max_element( + data_.begin(), data_.end(), + [](const T& a, const T& b) { return std::abs(a) < std::abs(b); }); + } + + // 矩阵的最小元素 + auto minElement() const -> T { + return *std::min_element( + data_.begin(), data_.end(), + [](const T& a, const T& b) { return std::abs(a) < std::abs(b); }); + } + + // 判断矩阵是否为对称矩阵 + [[nodiscard]] auto isSymmetric() const -> bool { + static_assert(Rows == Cols, + "Symmetry is only defined for square matrices"); + for (std::size_t i = 0; i < Rows; ++i) { + for (std::size_t j = i + 1; j < Cols; ++j) { + if ((*this)(i, j) != (*this)(j, i)) { + return false; + } + } + } + return true; + } + + // 矩阵的幂运算 + auto pow(unsigned int n) const -> Matrix { + static_assert(Rows == Cols, + "Matrix power is only defined for square matrices"); + if (n == 0) { + return identity(); + } + if (n == 1) { + return *this; + } + Matrix result = *this; + for (unsigned int i = 1; i < n; ++i) { + result = result * (*this); + } + return result; + } + + // 矩阵的行列式(使用LU分解) + auto determinant() const -> T { + static_assert(Rows == Cols, + "Determinant is only defined for square matrices"); + auto [L, U] = lu_decomposition(*this); + T det = T{1}; + for (std::size_t i = 0; i < Rows; ++i) { + det *= U(i, i); + } + return det; + } + + // 矩阵的秩(使用高斯消元) + [[nodiscard]] auto rank() const -> std::size_t { + Matrix temp = *this; + std::size_t rank = 0; + for (std::size_t i = 0; i < Rows && i < 
Cols; ++i) { + // 找主元 + std::size_t pivot = i; + for (std::size_t j = i + 1; j < Rows; ++j) { + if (std::abs(temp(j, i)) > std::abs(temp(pivot, i))) { + pivot = j; + } + } + if (std::abs(temp(pivot, i)) < 1e-10) { + continue; + } + // 交换行 + if (pivot != i) { + for (std::size_t j = i; j < Cols; ++j) { + std::swap(temp(i, j), temp(pivot, j)); + } + } + // 消元 + for (std::size_t j = i + 1; j < Rows; ++j) { + T factor = temp(j, i) / temp(i, i); + for (std::size_t k = i; k < Cols; ++k) { + temp(j, k) -= factor * temp(i, k); + } + } + ++rank; + } + return rank; + } + + // 矩阵的条件数(使用2范数) + auto conditionNumber() const -> T { + static_assert(Rows == Cols, + "Condition number is only defined for square matrices"); + auto svd = singular_value_decomposition(*this); + return svd[0] / svd[svd.size() - 1]; + } +}; + +// 矩阵加法 +template +constexpr auto operator+(const Matrix& a, + const Matrix& b) + -> Matrix { + Matrix result{}; + for (std::size_t i = 0; i < Rows * Cols; ++i) { + result.get_data()[i] = a.get_data()[i] + b.get_data()[i]; + } + return result; +} + +// 矩阵减法 +template +constexpr auto operator-(const Matrix& a, + const Matrix& b) + -> Matrix { + Matrix result{}; + for (std::size_t i = 0; i < Rows * Cols; ++i) { + result.get_data()[i] = a.get_data()[i] - b.get_data()[i]; + } + return result; +} + +// 矩阵乘法 +template +auto operator*(const Matrix& a, + const Matrix& b) + -> Matrix { + Matrix result{}; + for (std::size_t i = 0; i < RowsA; ++i) { + for (std::size_t j = 0; j < ColsB; ++j) { + for (std::size_t k = 0; k < ColsA_RowsB; ++k) { + result(i, j) += a(i, k) * b(k, j); + } + } + } + return result; +} + +// 标量乘法(左乘和右乘) +template +constexpr auto operator*(const Matrix& m, U scalar) { + Matrix result; + for (std::size_t i = 0; i < Rows * Cols; ++i) { + result.get_data()[i] = m.get_data()[i] * scalar; + } + return result; +} + +template +constexpr auto operator*(U scalar, const Matrix& m) { + return m * scalar; +} + +// 矩阵逐元素乘法(Hadamard积) +template +constexpr auto hadamardProduct(const Matrix& a, + const Matrix& b) + -> Matrix { + Matrix result{}; + for (std::size_t i = 0; i < Rows * Cols; ++i) { + result.get_data()[i] = a.get_data()[i] * b.get_data()[i]; + } + return result; +} + +// 矩阵转置 +template +constexpr auto transpose(const Matrix& m) + -> Matrix { + Matrix result{}; + for (std::size_t i = 0; i < Rows; ++i) { + for (std::size_t j = 0; j < Cols; ++j) { + result(j, i) = m(i, j); + } + } + return result; +} + +// 创建单位矩阵 +template +constexpr auto identity() -> Matrix { + Matrix result{}; + for (std::size_t i = 0; i < Size; ++i) { + result(i, i) = T{1}; + } + return result; +} + +// 矩阵的LU分解 +template +auto luDecomposition(const Matrix& m) + -> std::pair, Matrix> { + Matrix L = identity(); + Matrix U = m; + + for (std::size_t k = 0; k < Size - 1; ++k) { + for (std::size_t i = k + 1; i < Size; ++i) { + if (std::abs(U(k, k)) < 1e-10) { + THROW_RUNTIME_ERROR( + "LU decomposition failed: division by zero"); + } + T factor = U(i, k) / U(k, k); + L(i, k) = factor; + for (std::size_t j = k; j < Size; ++j) { + U(i, j) -= factor * U(k, j); + } + } + } + + return {L, U}; +} + +// 矩阵的奇异值分解(仅返回奇异值) +template +auto singularValueDecomposition(const Matrix& m) + -> std::vector { + const std::size_t n = std::min(Rows, Cols); + Matrix mt = transpose(m); + Matrix mtm = mt * m; + + // 使用幂法计算最大特征值和对应的特征向量 + auto powerIteration = [&mtm](std::size_t max_iter = 100, T tol = 1e-10) { + std::vector v(Cols); + std::generate(v.begin(), v.end(), + []() { return static_cast(rand()) / RAND_MAX; }); + T lambdaOld = 0; + for 
(std::size_t iter = 0; iter < max_iter; ++iter) { + std::vector vNew(Cols); + for (std::size_t i = 0; i < Cols; ++i) { + for (std::size_t j = 0; j < Cols; ++j) { + vNew[i] += mtm(i, j) * v[j]; + } + } + T lambda = 0; + for (std::size_t i = 0; i < Cols; ++i) { + lambda += vNew[i] * v[i]; + } + T norm = std::sqrt(std::inner_product(vNew.begin(), vNew.end(), + vNew.begin(), T(0))); + for (auto& x : vNew) { + x /= norm; + } + if (std::abs(lambda - lambdaOld) < tol) { + return std::sqrt(lambda); + } + lambdaOld = lambda; + v = vNew; + } + THROW_RUNTIME_ERROR("Power iteration did not converge"); + }; + + std::vector singularValues; + for (std::size_t i = 0; i < n; ++i) { + T sigma = powerIteration(); + singularValues.push_back(sigma); + // Deflate the matrix + Matrix vvt; + for (std::size_t j = 0; j < Cols; ++j) { + for (std::size_t k = 0; k < Cols; ++k) { + vvt(j, k) = mtm(j, k) / (sigma * sigma); + } + } + mtm = mtm - vvt; + } + + std::sort(singularValues.begin(), singularValues.end(), std::greater()); + return singularValues; +} + +// 生成随机矩阵 +template +auto randomMatrix(T min = 0, T max = 1) -> Matrix { + static std::random_device rd; + static std::mt19937 gen(rd()); + std::uniform_real_distribution<> dis(min, max); + + Matrix result; + for (auto& elem : result.get_data()) { + elem = dis(gen); + } + return result; +} + +} // namespace atom::algorithm + +#endif diff --git a/src/atom/async/async.hpp b/src/atom/async/async.hpp index d1c3903f..53baa89f 100644 --- a/src/atom/async/async.hpp +++ b/src/atom/async/async.hpp @@ -413,5 +413,286 @@ auto getWithTimeout(std::future &future, } THROW_TIMEOUT_EXCEPTION("Timeout occurred while waiting for future result"); } + +// Helper function to get a future for a range of futures +template +auto whenAll(InputIt first, InputIt last, + std::optional timeout = std::nullopt) + -> std::future< + std::vector::value_type>> { + using FutureType = typename std::iterator_traits::value_type; + using ResultType = std::vector; + + std::promise promise; + std::future resultFuture = promise.get_future(); + + // Launch an async task to wait for all the futures + auto asyncTask = std::async([promise = std::move(promise), first, last, + timeout]() mutable { + ResultType results; + try { + for (auto it = first; it != last; ++it) { + if (timeout) { + // Check each future with timeout (if specified) + if (it->wait_for(*timeout) == std::future_status::timeout) { + THROW_INVALID_ARGUMENT( + "Timeout while waiting for a future."); + } + } + results.push_back(std::move(*it)); + } + promise.set_value(std::move(results)); + } catch (const std::exception &e) { + promise.set_exception( + std::current_exception()); // Pass the exception to the future + } + }); + + // Optionally, store the future or use it if needed + asyncTask.wait(); // Wait for the async task to finish + + return resultFuture; +} + +// Helper to get the return type of a future +template +using future_value_t = decltype(std::declval().get()); + +// Helper function for a variadic template version (when_all for futures as +// arguments) +template +auto whenAll(Futures &&...futures) + -> std::future...>> { + std::promise...>> promise; + std::future...>> resultFuture = + promise.get_future(); + + // Use async to wait for all futures and gather results + auto asyncTask = + std::async([promise = std::move(promise), + futures = std::make_tuple( + std::forward(futures)...)]() mutable { + try { + auto results = std::apply( + [](auto &&...fs) { + return std::make_tuple( + fs.get()...); // Wait for each future and collect + // 
the results + }, + futures); + promise.set_value(std::move(results)); + } catch (const std::exception &e) { + promise.set_exception(std::current_exception()); + } + }); + + asyncTask.wait(); // Wait for the async task to finish + + return resultFuture; +} + +// Primary template for EnhancedFuture +template +class EnhancedFuture { +public: + explicit EnhancedFuture(std::shared_future &&fut) + : future_(std::move(fut)), cancelled_(false) {} + + // Chaining: call another operation after the future is done + template + auto then(F &&func) { + using ResultType = std::invoke_result_t; + return EnhancedFuture( + std::async(std::launch::async, [fut = future_, + func = std::forward( + func)]() mutable { + if (fut.valid()) { + return func(fut.get()); + } + throw std::runtime_error("Future is invalid or cancelled"); + }).share()); + } + + // Wait with timeout and auto cancel + auto waitFor(std::chrono::milliseconds timeout) -> std::optional { + if (future_.wait_for(timeout) == std::future_status::ready && + !cancelled_) { + return future_.get(); + } + cancel(); + return std::nullopt; + } + + // Check if the future is done + [[nodiscard]] auto isDone() const -> bool { + return future_.wait_for(std::chrono::milliseconds(0)) == + std::future_status::ready; + } + + // Set a completion callback, allows multiple callbacks + template + void onComplete(F &&func) { + if (!cancelled_) { + callbacks_.emplace_back(std::forward(func)); + std::async(std::launch::async, [this]() { + try { + if (future_.valid()) { + auto result = future_.get(); + for (auto &callback : callbacks_) { + callback(result); + } + } + } catch (const std::exception &e) { + } + }).get(); + } + } + + // Synchronous wait + auto wait() -> T { + if (cancelled_) { + THROW_OBJ_NOT_EXIST("Future has been cancelled"); + } + return future_.get(); + } + + // Support cancellation + void cancel() { cancelled_ = true; } + + // Check if cancelled + [[nodiscard]] auto isCancelled() const -> bool { return cancelled_; } + + // Exception handling + auto getException() -> std::exception_ptr { + try { + future_.get(); + } catch (...) 
{ + return std::current_exception(); + } + return nullptr; + } + + // Retry mechanism + template + auto retry(F &&func, int max_retries) { + using ResultType = std::invoke_result_t; + return EnhancedFuture( + std::async(std::launch::async, [fut = future_, + func = std::forward(func), + max_retries]() mutable { + for (int attempt = 0; attempt < max_retries; ++attempt) { + if (fut.valid()) { + try { + return func(fut.get()); + } catch (const std::exception &e) { + if (attempt == max_retries - 1) { + throw; + } + } + } else { + THROW_UNLAWFUL_OPERATION( + "Future is invalid or cancelled"); + } + } + }).share()); + } + +protected: + std::shared_future future_; + std::vector> callbacks_; + std::atomic cancelled_; +}; + +// Specialization for void type +template <> +class EnhancedFuture { +public: + explicit EnhancedFuture(std::shared_future &&fut) + : future_(std::move(fut)), cancelled_(false) {} + + template + auto then(F &&func) { + using ResultType = std::invoke_result_t; + return EnhancedFuture( + std::async(std::launch::async, [fut = future_, + func = std::forward( + func)]() mutable { + if (fut.valid()) { + fut.get(); + return func(); + } + THROW_UNLAWFUL_OPERATION("Future is invalid or cancelled"); + }).share()); + } + + auto waitFor(std::chrono::milliseconds timeout) -> bool { + if (future_.wait_for(timeout) == std::future_status::ready && + !cancelled_) { + future_.get(); + return true; + } + cancel(); + return false; + } + + [[nodiscard]] auto isDone() const -> bool { + return future_.wait_for(std::chrono::milliseconds(0)) == + std::future_status::ready; + } + + template + void onComplete(F &&func) { + if (!cancelled_) { + callbacks_.emplace_back(std::forward(func)); + std::async(std::launch::async, [this]() { + try { + if (future_.valid()) { + future_.get(); + for (auto &callback : callbacks_) { + callback(); + } + } + } catch (const std::exception &e) { + } + }).get(); + } + } + + void wait() { + if (cancelled_) { + THROW_OBJ_NOT_EXIST("Future has been cancelled"); + } + future_.get(); + } + + void cancel() { cancelled_ = true; } + + [[nodiscard]] auto isCancelled() const -> bool { return cancelled_; } + + auto getException() -> std::exception_ptr { + try { + future_.get(); + } catch (...) { + return std::current_exception(); + } + return nullptr; + } + +protected: + std::shared_future future_; + std::vector> callbacks_; + std::atomic cancelled_; +}; + +// Helper function to create EnhancedFuture +template +auto makeEnhancedFuture(F &&f, Args &&...args) { + using result_type = std::invoke_result_t; + return EnhancedFuture(std::async(std::launch::async, + std::forward(f), + std::forward(args)...) + .share()); +} + } // namespace atom::async #endif diff --git a/src/atom/async/slot.hpp b/src/atom/async/slot.hpp index c9c12aac..8fd29c6f 100644 --- a/src/atom/async/slot.hpp +++ b/src/atom/async/slot.hpp @@ -10,16 +10,33 @@ #include namespace atom::async { + +/** + * @brief A signal class that allows connecting, disconnecting, and emitting + * slots. + * + * @tparam Args The argument types for the slots. + */ template class Signal { public: using SlotType = std::function; + /** + * @brief Connect a slot to the signal. + * + * @param slot The slot to connect. + */ void connect(SlotType slot) { std::lock_guard lock(mutex_); slots_.push_back(std::move(slot)); } + /** + * @brief Disconnect a slot from the signal. + * + * @param slot The slot to disconnect. 
+ */ void disconnect(const SlotType& slot) { std::lock_guard lock(mutex_); slots_.erase(std::remove_if(slots_.begin(), slots_.end(), @@ -30,6 +47,11 @@ class Signal { slots_.end()); } + /** + * @brief Emit the signal, calling all connected slots. + * + * @param args The arguments to pass to the slots. + */ void emit(Args... args) { std::lock_guard lock(mutex_); for (const auto& slot : slots_) { @@ -42,16 +64,31 @@ class Signal { std::mutex mutex_; }; +/** + * @brief A signal class that allows asynchronous slot execution. + * + * @tparam Args The argument types for the slots. + */ template class AsyncSignal { public: using SlotType = std::function; + /** + * @brief Connect a slot to the signal. + * + * @param slot The slot to connect. + */ void connect(SlotType slot) { std::lock_guard lock(mutex_); slots_.push_back(std::move(slot)); } + /** + * @brief Disconnect a slot from the signal. + * + * @param slot The slot to disconnect. + */ void disconnect(const SlotType& slot) { std::lock_guard lock(mutex_); slots_.erase(std::remove_if(slots_.begin(), slots_.end(), @@ -62,6 +99,11 @@ class AsyncSignal { slots_.end()); } + /** + * @brief Emit the signal asynchronously, calling all connected slots. + * + * @param args The arguments to pass to the slots. + */ void emit(Args... args) { std::vector> futures; { @@ -81,11 +123,22 @@ class AsyncSignal { std::mutex mutex_; }; +/** + * @brief A signal class that allows automatic disconnection of slots. + * + * @tparam Args The argument types for the slots. + */ template class AutoDisconnectSignal { public: using SlotType = std::function; + /** + * @brief Connect a slot to the signal and return its unique ID. + * + * @param slot The slot to connect. + * @return int The unique ID of the connected slot. + */ auto connect(SlotType slot) -> int { std::lock_guard lock(mutex_); auto id = nextId_++; @@ -93,11 +146,21 @@ class AutoDisconnectSignal { return id; } + /** + * @brief Disconnect a slot from the signal using its unique ID. + * + * @param id The unique ID of the slot to disconnect. + */ void disconnect(int id) { std::lock_guard lock(mutex_); slots_.erase(id); } + /** + * @brief Emit the signal, calling all connected slots. + * + * @param args The arguments to pass to the slots. + */ void emit(Args... args) { std::lock_guard lock(mutex_); for (const auto& [id, slot] : slots_) { @@ -111,21 +174,41 @@ class AutoDisconnectSignal { int nextId_ = 0; }; +/** + * @brief A signal class that allows chaining of signals. + * + * @tparam Args The argument types for the slots. + */ template class ChainedSignal { public: using SlotType = std::function; + /** + * @brief Connect a slot to the signal. + * + * @param slot The slot to connect. + */ void connect(SlotType slot) { std::lock_guard lock(mutex_); slots_.push_back(std::move(slot)); } + /** + * @brief Add a chained signal to be emitted after this signal. + * + * @param nextSignal The next signal to chain. + */ void addChain(ChainedSignal& nextSignal) { std::lock_guard lock(mutex_); chains_.push_back(&nextSignal); } + /** + * @brief Emit the signal, calling all connected slots and chained signals. + * + * @param args The arguments to pass to the slots. + */ void emit(Args... args) { std::lock_guard lock(mutex_); for (const auto& slot : slots_) { @@ -142,16 +225,32 @@ class ChainedSignal { std::mutex mutex_; }; +/** + * @brief A signal class that allows connecting, disconnecting, and emitting + * slots. + * + * @tparam Args The argument types for the slots. 
+ */ template class TemplateSignal { public: using SlotType = std::function; + /** + * @brief Connect a slot to the signal. + * + * @param slot The slot to connect. + */ void connect(SlotType slot) { std::lock_guard lock(mutex_); slots_.push_back(std::move(slot)); } + /** + * @brief Disconnect a slot from the signal. + * + * @param slot The slot to disconnect. + */ void disconnect(const SlotType& slot) { std::lock_guard lock(mutex_); slots_.erase(std::remove_if(slots_.begin(), slots_.end(), @@ -162,6 +261,11 @@ class TemplateSignal { slots_.end()); } + /** + * @brief Emit the signal, calling all connected slots. + * + * @param args The arguments to pass to the slots. + */ void emit(Args... args) { std::lock_guard lock(mutex_); for (const auto& slot : slots_) { @@ -174,16 +278,31 @@ class TemplateSignal { std::mutex mutex_; }; +/** + * @brief A signal class that ensures thread-safe slot execution. + * + * @tparam Args The argument types for the slots. + */ template class ThreadSafeSignal { public: using SlotType = std::function; + /** + * @brief Connect a slot to the signal. + * + * @param slot The slot to connect. + */ void connect(SlotType slot) { std::lock_guard lock(mutex_); slots_.push_back(std::move(slot)); } + /** + * @brief Disconnect a slot from the signal. + * + * @param slot The slot to disconnect. + */ void disconnect(const SlotType& slot) { std::lock_guard lock(mutex_); slots_.erase(std::remove_if(slots_.begin(), slots_.end(), @@ -194,6 +313,12 @@ class ThreadSafeSignal { slots_.end()); } + /** + * @brief Emit the signal, calling all connected slots in a thread-safe + * manner. + * + * @param args The arguments to pass to the slots. + */ void emit(Args... args) { std::vector> tasks; { @@ -212,16 +337,31 @@ class ThreadSafeSignal { std::mutex mutex_; }; +/** + * @brief A signal class that allows broadcasting to chained signals. + * + * @tparam Args The argument types for the slots. + */ template class BroadcastSignal { public: using SlotType = std::function; + /** + * @brief Connect a slot to the signal. + * + * @param slot The slot to connect. + */ void connect(SlotType slot) { std::lock_guard lock(mutex_); slots_.push_back(std::move(slot)); } + /** + * @brief Disconnect a slot from the signal. + * + * @param slot The slot to disconnect. + */ void disconnect(const SlotType& slot) { std::lock_guard lock(mutex_); slots_.erase(std::remove_if(slots_.begin(), slots_.end(), @@ -232,6 +372,11 @@ class BroadcastSignal { slots_.end()); } + /** + * @brief Emit the signal, calling all connected slots and chained signals. + * + * @param args The arguments to pass to the slots. + */ void emit(Args... args) { std::lock_guard lock(mutex_); for (const auto& slot : slots_) { @@ -242,6 +387,11 @@ class BroadcastSignal { } } + /** + * @brief Add a chained signal to be emitted after this signal. + * + * @param signal The next signal to chain. + */ void addChain(BroadcastSignal& signal) { std::lock_guard lock(mutex_); chainedSignals_.push_back(&signal); @@ -253,18 +403,38 @@ class BroadcastSignal { std::mutex mutex_; }; +/** + * @brief A signal class that limits the number of times it can be emitted. + * + * @tparam Args The argument types for the slots. + */ template class LimitedSignal { public: using SlotType = std::function; + /** + * @brief Construct a new Limited Signal object. + * + * @param maxCalls The maximum number of times the signal can be emitted. + */ explicit LimitedSignal(size_t maxCalls) : maxCalls_(maxCalls) {} + /** + * @brief Connect a slot to the signal. 
+ * + * @param slot The slot to connect. + */ void connect(SlotType slot) { std::lock_guard lock(mutex_); slots_.push_back(std::move(slot)); } + /** + * @brief Disconnect a slot from the signal. + * + * @param slot The slot to disconnect. + */ void disconnect(const SlotType& slot) { std::lock_guard lock(mutex_); slots_.erase(std::remove_if(slots_.begin(), slots_.end(), @@ -275,6 +445,12 @@ class LimitedSignal { slots_.end()); } + /** + * @brief Emit the signal, calling all connected slots up to the maximum + * number of calls. + * + * @param args The arguments to pass to the slots. + */ void emit(Args... args) { std::lock_guard lock(mutex_); if (callCount_ >= maxCalls_) { @@ -293,16 +469,31 @@ class LimitedSignal { std::mutex mutex_; }; +/** + * @brief A signal class that allows dynamic slot management. + * + * @tparam Args The argument types for the slots. + */ template class DynamicSignal { public: using SlotType = std::function; + /** + * @brief Connect a slot to the signal. + * + * @param slot The slot to connect. + */ void connect(SlotType slot) { std::lock_guard lock(mutex_); slots_.push_back(std::move(slot)); } + /** + * @brief Disconnect a slot from the signal. + * + * @param slot The slot to disconnect. + */ void disconnect(const SlotType& slot) { std::lock_guard lock(mutex_); slots_.erase(std::remove_if(slots_.begin(), slots_.end(), @@ -313,6 +504,11 @@ class DynamicSignal { slots_.end()); } + /** + * @brief Emit the signal, calling all connected slots. + * + * @param args The arguments to pass to the slots. + */ void emit(Args... args) { std::lock_guard lock(mutex_); for (const auto& slot : slots_) { @@ -325,16 +521,31 @@ class DynamicSignal { std::mutex mutex_; }; +/** + * @brief A signal class that allows scoped slot management. + * + * @tparam Args The argument types for the slots. + */ template class ScopedSignal { public: using SlotType = std::function; + /** + * @brief Connect a slot to the signal using a shared pointer. + * + * @param slotPtr The shared pointer to the slot to connect. + */ void connect(std::shared_ptr slotPtr) { std::lock_guard lock(mutex_); slots_.push_back(slotPtr); } + /** + * @brief Emit the signal, calling all connected slots. + * + * @param args The arguments to pass to the slots. + */ void emit(Args... 
args) { std::lock_guard lock(mutex_); auto it = slots_.begin(); diff --git a/src/atom/components/component.cpp b/src/atom/components/component.cpp new file mode 100644 index 00000000..e69de29b diff --git a/src/atom/components/component.hpp b/src/atom/components/component.hpp index 36bc1482..4243ee0b 100644 --- a/src/atom/components/component.hpp +++ b/src/atom/components/component.hpp @@ -59,7 +59,11 @@ class Component : public std::enable_shared_from_this { // Inject methods // ------------------------------------------------------------------- - std::weak_ptr getInstance() const; + auto getInstance() const -> std::weak_ptr; + + auto getSharedInstance() -> std::shared_ptr { + return shared_from_this(); + } // ------------------------------------------------------------------- // Common methods diff --git a/src/atom/error/error_code.hpp b/src/atom/error/error_code.hpp index ba3190d4..8a9e012f 100644 --- a/src/atom/error/error_code.hpp +++ b/src/atom/error/error_code.hpp @@ -41,6 +41,9 @@ enum class FileError : int { UnLoadError = 113, // 动态卸载错误 LockError = 114, // 文件锁错误 FormatError = 115, // 文件格式错误 + PathTooLong = 116, // 路径过长 + FileCorrupted = 117, // 文件损坏 + UnsupportedFormat = 118, // 不支持的文件格式 }; // 设备错误 @@ -68,32 +71,125 @@ enum class DeviceError : int { ParkedError = 223, HomeError = 224, - InitializationError = 230, // 初始化错误 - ResourceExhausted = 231, // 资源耗尽 + InitializationError = 230, // 初始化错误 + ResourceExhausted = 231, // 资源耗尽 + FirmwareUpdateFailed = 232, // 固件更新失败 + CalibrationError = 233, // 校准错误 + Overheating = 234, // 设备过热 + PowerFailure = 235, // 电源故障 }; -// 设备警告 -// (保持现有结构,根据需要添加更多警告类型) +// 网络错误 +enum class NetworkError : int { + None = static_cast(ErrorCodeBase::Success), + ConnectionLost = 400, // 网络连接丢失 + ConnectionRefused = 401, // 连接被拒绝 + DNSLookupFailed = 402, // DNS查询失败 + ProtocolError = 403, // 协议错误 + SSLHandshakeFailed = 404, // SSL握手失败 + AddressInUse = 405, // 地址已在使用 + AddressNotAvailable = 406, // 地址不可用 + NetworkDown = 407, // 网络已关闭 + HostUnreachable = 408, // 主机不可达 + MessageTooLarge = 409, // 消息过大 + BufferOverflow = 410, // 缓冲区溢出 + TimeoutError = 411, // 网络超时 + BandwidthExceeded = 412, // 带宽超限 + NetworkCongested = 413, // 网络拥塞 +}; -// 服务器错误 -enum class ServerError : int { +// 数据库错误 +enum class DatabaseError : int { + None = static_cast(ErrorCodeBase::Success), + ConnectionFailed = 500, // 数据库连接失败 + QueryFailed = 501, // 查询失败 + TransactionFailed = 502, // 事务失败 + IntegrityConstraintViolation = 503, // 违反完整性约束 + NoSuchTable = 504, // 表不存在 + DuplicateEntry = 505, // 重复条目 + DataTooLong = 506, // 数据过长 + DataTruncated = 507, // 数据被截断 + Deadlock = 508, // 死锁 + LockTimeout = 509, // 锁超时 + IndexOutOfBounds = 510, // 索引越界 + ConnectionTimeout = 511, // 连接超时 + InvalidQuery = 512, // 无效查询 +}; + +// 内存管理错误 +enum class MemoryError : int { + None = static_cast(ErrorCodeBase::Success), + AllocationFailed = 600, // 内存分配失败 + OutOfMemory = 601, // 内存不足 + AccessViolation = 602, // 内存访问违例 + BufferOverflow = 603, // 缓冲区溢出 + DoubleFree = 604, // 双重释放 + InvalidPointer = 605, // 无效指针 + MemoryLeak = 606, // 内存泄漏 + StackOverflow = 607, // 栈溢出 + CorruptedHeap = 608, // 堆损坏 +}; + +// 用户输入错误 +enum class UserInputError : int { None = static_cast(ErrorCodeBase::Success), - InvalidParameters = 300, - InvalidFormat = 301, - MissingParameters = 302, + InvalidInput = 700, // 无效输入 + OutOfRange = 701, // 输入值超出范围 + MissingInput = 702, // 缺少输入 + FormatError = 703, // 输入格式错误 + UnsupportedType = 704, // 不支持的输入类型 + InputTooLong = 705, // 输入过长 + InputTooShort = 706, // 输入过短 + InvalidCharacter = 707, // 
无效字符 +}; - RunFailed = 303, +// 配置错误 +enum class ConfigError : int { + None = static_cast(ErrorCodeBase::Success), + MissingConfig = 800, // 缺少配置文件 + InvalidConfig = 801, // 无效的配置 + ConfigParseError = 802, // 配置解析错误 + UnsupportedConfig = 803, // 不支持的配置 + ConfigConflict = 804, // 配置冲突 + InvalidOption = 805, // 无效选项 + ConfigNotSaved = 806, // 配置未保存 + ConfigLocked = 807, // 配置被锁定 +}; - UnknownError = 310, - UnknownCommand = 311, - UnknownDevice = 312, - UnknownDeviceType = 313, - UnknownDeviceName = 314, - UnknownDeviceID = 315, +// 进程和线程错误 +enum class ProcessError : int { + None = static_cast(ErrorCodeBase::Success), + ProcessNotFound = 900, // 进程未找到 + ProcessFailed = 901, // 进程失败 + ThreadCreationFailed = 902, // 线程创建失败 + ThreadJoinFailed = 903, // 线程合并失败 + ThreadTimeout = 904, // 线程超时 + DeadlockDetected = 905, // 检测到死锁 + ProcessTerminated = 906, // 进程被终止 + InvalidProcessState = 907, // 无效的进程状态 + InsufficientResources = 908, // 资源不足 + InvalidThreadPriority = 909, // 无效的线程优先级 +}; +// 服务器错误 +enum class ServerError : int { + None = static_cast(ErrorCodeBase::Success), + InvalidParameters = 300, // 无效参数 + InvalidFormat = 301, // 无效格式 + MissingParameters = 302, // 缺少参数 + RunFailed = 303, // 运行失败 + UnknownError = 310, // 未知错误 + UnknownCommand = 311, // 未知命令 + UnknownDevice = 312, // 未知设备 + UnknownDeviceType = 313, // 未知设备类型 + UnknownDeviceName = 314, // 未知设备名称 + UnknownDeviceID = 315, // 未知设备ID NetworkError = 320, // 网络错误 TimeoutError = 321, // 请求超时 AuthenticationError = 322, // 认证失败 + PermissionDenied = 323, // 权限被拒绝 + ServerOverload = 324, // 服务器过载 + MaintenanceMode = 325, // 维护模式 }; -#endif // ATOM_ERROR_CODE_HPP +#endif diff --git a/src/atom/error/exception.hpp b/src/atom/error/exception.hpp index 42f8a5d1..db72775f 100644 --- a/src/atom/error/exception.hpp +++ b/src/atom/error/exception.hpp @@ -107,10 +107,9 @@ class Exception : public std::exception { throw atom::error::Exception(ATOM_FILE_NAME, ATOM_FILE_LINE, \ ATOM_FUNC_NAME, __VA_ARGS__) -#define THROW_NESTED_EXCEPTION(...) \ - atom::error::Exception::rethrowNested(__FILE__, __LINE__, __FUNCTION__, \ - __VA_ARGS__) -// Special Exception +#define THROW_NESTED_EXCEPTION(...) \ + atom::error::Exception::rethrowNested(ATOM_FILE_NAME, ATOM_FILE_LINE, \ + ATOM_FUNC_NAME, __VA_ARGS__) // ------------------------------------------------------------------- // Common @@ -125,9 +124,18 @@ class RuntimeError : public Exception { throw atom::error::RuntimeError(ATOM_FILE_NAME, ATOM_FILE_LINE, \ ATOM_FUNC_NAME, __VA_ARGS__) -#define THROW_NESTED_RUNTIME_ERROR(...) \ - atom::error::RuntimeError::rethrowNested(__FILE__, __LINE__, __FUNCTION__, \ - __VA_ARGS__) +#define THROW_NESTED_RUNTIME_ERROR(...) \ + atom::error::RuntimeError::rethrowNested(ATOM_FILE_NAME, ATOM_FILE_LINE, \ + ATOM_FUNC_NAME, __VA_ARGS__) + +class LogicError : public Exception { +public: + using Exception::Exception; +}; + +#define THROW_LOGIC_ERROR(...) \ + throw atom::error::LogicError(ATOM_FILE_NAME, ATOM_FILE_LINE, \ + ATOM_FUNC_NAME, __VA_ARGS__) class UnlawfulOperation : public Exception { public: @@ -321,15 +329,69 @@ class FailToCloseFile : public Exception { throw atom::error::FailToCloseFile(ATOM_FILE_NAME, ATOM_FILE_LINE, \ ATOM_FUNC_NAME, __VA_ARGS__) -class FailToLoadDll : public Exception { +class FailToCreateFile : public Exception { +public: + using Exception::Exception; +}; + +#define THROW_FAIL_TO_CREATE_FILE(...) 
\ + throw atom::error::FailToCreateFile(ATOM_FILE_NAME, ATOM_FILE_LINE, \ + ATOM_FUNC_NAME, __VA_ARGS__) + +class FailToDeleteFile : public Exception { +public: + using Exception::Exception; +}; + +#define THROW_FAIL_TO_DELETE_FILE(...) \ + throw atom::error::FailToDeleteFile(ATOM_FILE_NAME, ATOM_FILE_LINE, \ + ATOM_FUNC_NAME, __VA_ARGS__) + +class FailToCopyFile : public Exception { +public: + using Exception::Exception; +}; + +#define THROW_FAIL_TO_COPY_FILE(...) \ + throw atom::error::FailToCopyFile(ATOM_FILE_NAME, ATOM_FILE_LINE, \ + ATOM_FUNC_NAME, __VA_ARGS__) + +class FailToMoveFile : public Exception { +public: + using Exception::Exception; +}; + +#define THROW_FAIL_TO_MOVE_FILE(...) \ + throw atom::error::FailToMoveFile(ATOM_FILE_NAME, ATOM_FILE_LINE, \ + ATOM_FUNC_NAME, __VA_ARGS__) + +class FailToReadFile : public Exception { +public: + using Exception::Exception; +}; + +#define THROW_FAIL_TO_READ_FILE(...) \ + throw atom::error::FailToReadFile(ATOM_FILE_NAME, ATOM_FILE_LINE, \ + ATOM_FUNC_NAME, __VA_ARGS__) + +class FailToWriteFile : public Exception { public: using Exception::Exception; }; +#define THROW_FAIL_TO_WRITE_FILE(...) \ + throw atom::error::FailToWriteFile(ATOM_FILE_NAME, ATOM_FILE_LINE, \ + ATOM_FUNC_NAME, __VA_ARGS__) + // ------------------------------------------------------------------- // Dynamic Library // ------------------------------------------------------------------- +class FailToLoadDll : public Exception { +public: + using Exception::Exception; +}; + #define THROW_FAIL_TO_LOAD_DLL(...) \ throw atom::error::FailToLoadDll(ATOM_FILE_NAME, ATOM_FILE_LINE, \ ATOM_FUNC_NAME, __VA_ARGS__) diff --git a/src/atom/error/stacktrace.cpp b/src/atom/error/stacktrace.cpp index ff514248..45f9939b 100644 --- a/src/atom/error/stacktrace.cpp +++ b/src/atom/error/stacktrace.cpp @@ -8,7 +8,7 @@ Date: 2023-11-10 -Description: StackTrace +Description: Enhanced StackTrace with more details **************************************************/ @@ -37,24 +37,6 @@ Description: StackTrace namespace atom::error { namespace { - -/** - * @brief Perform platform-specific symbol demangling. - * - * @param input The mangled symbol name. - * @return A demangled symbol name if possible, otherwise the original name. - */ -auto demangleSymbol(const std::string& input) -> std::string { -#if defined(__linux__) || defined(__APPLE__) - int status = 0; - std::unique_ptr demangled( - abi::__cxa_demangle(input.c_str(), nullptr, nullptr, &status), free); - return (status == 0 && demangled) ? 
demangled.get() : input; -#else - return input; -#endif -} - #if defined(__linux__) || defined(__APPLE__) auto processString(const std::string& input) -> std::string { size_t startIndex = input.find("_Z"); @@ -66,7 +48,7 @@ auto processString(const std::string& input) -> std::string { return input; } std::string abiName = input.substr(startIndex, endIndex - startIndex); - abiName = demangleSymbol(abiName); + abiName = meta::DemangleHelper::demangle(abiName); std::string result = input; result.replace(startIndex, endIndex - startIndex, abiName); return result; @@ -100,26 +82,33 @@ auto StackTrace::toString() const -> std::string { std::ostringstream oss; #ifdef _WIN32 - SYMBOL_INFO* symbol = reinterpret_cast( + auto* symbol = reinterpret_cast( calloc(sizeof(SYMBOL_INFO) + 256 * sizeof(char), 1)); symbol->MaxNameLen = 255; symbol->SizeOfStruct = sizeof(SYMBOL_INFO); for (void* frame : frames_) { - SymFromAddr(GetCurrentProcess(), reinterpret_cast(frame), 0, - symbol); - std::string symbol_name = symbol->Name; - if (!symbol_name.empty()) { - oss << "\t\t" << demangleSymbol("_" + symbol_name) << " - 0x" - << std::hex << symbol->Address << "\n"; + DWORD64 displacement = 0; + if (SymFromAddr(GetCurrentProcess(), reinterpret_cast(frame), + &displacement, symbol) != 0) { + std::string symbolName = symbol->Name; + oss << "\t\t" << meta::DemangleHelper::demangle("_" + symbolName) + << " - 0x" << std::hex << symbol->Address << "\n"; } } free(symbol); #elif defined(__APPLE__) || defined(__linux__) for (int i = 0; i < num_frames_; ++i) { - std::string_view symbol(symbols_.get()[i]); - oss << "\t\t" << processString(std::string(symbol)) << "\n"; + Dl_info info; + if (dladdr(frames_[i], &info) && info.dli_sname) { + std::string symbol_name = + meta::DemangleHelper::demangle(info.dli_sname); + oss << "\t\t" << symbol_name << " (" << info.dli_fname << ")\n"; + } else { + std::string_view symbol(symbols_.get()[i]); + oss << "\t\t" << processString(std::string(symbol)) << "\n"; + } } #else @@ -135,12 +124,12 @@ void StackTrace::capture() { frames_.resize(max_frames); SymInitialize(GetCurrentProcess(), nullptr, TRUE); - std::array frame_ptrs; - WORD captured_frames = - CaptureStackBackTrace(0, max_frames, frame_ptrs.data(), nullptr); + void* framePtrs[max_frames]; + WORD capturedFrames = + CaptureStackBackTrace(0, max_frames, framePtrs, nullptr); - frames_.resize(captured_frames); - std::copy_n(frame_ptrs.begin(), captured_frames, frames_.begin()); + frames_.resize(capturedFrames); + std::copy_n(framePtrs, capturedFrames, frames_.begin()); #elif defined(__APPLE__) || defined(__linux__) constexpr int MAX_FRAMES = 64; @@ -148,6 +137,7 @@ void StackTrace::capture() { num_frames_ = backtrace(framePtrs, MAX_FRAMES); symbols_.reset(backtrace_symbols(framePtrs, num_frames_)); + frames_.assign(framePtrs, framePtrs + num_frames_); #else num_frames_ = 0; diff --git a/src/atom/error/stacktrace.hpp b/src/atom/error/stacktrace.hpp index 21a83dab..493a018a 100644 --- a/src/atom/error/stacktrace.hpp +++ b/src/atom/error/stacktrace.hpp @@ -8,25 +8,25 @@ Date: 2023-11-10 -Description: StackTrace +Description: Enhanced StackTrace with more details **************************************************/ #ifndef ATOM_ERROR_STACKTRACE_HPP #define ATOM_ERROR_STACKTRACE_HPP +#include #include #include -#include namespace atom::error { /** * @brief Class for capturing and representing a stack trace. * - * This class provides functionality to capture the stack trace of the current - * execution context and represent it as a string. 
It supports different - * implementations for different operating systems. + * This class captures the stack trace of the current + * execution context and represents it as a string, including + * file names, line numbers, and symbols if available. */ class StackTrace { public: @@ -56,7 +56,11 @@ class StackTrace { #ifdef _WIN32 std::vector frames_; /**< Vector to store stack frames on Windows. */ #elif defined(__APPLE__) || defined(__linux__) - std::unique_ptr symbols_{nullptr, &free}; /**< Pointer to store stack symbols on macOS or Linux. */ + std::unique_ptr symbols_{ + nullptr, + &free}; /**< Pointer to store stack symbols on macOS or Linux. */ + std::vector + frames_; /**< Vector to store raw stack frame pointers. */ int num_frames_ = 0; /**< Number of stack frames captured. */ #endif }; diff --git a/src/atom/extra/beast/http.cpp b/src/atom/extra/beast/http.cpp new file mode 100644 index 00000000..f74d7c94 --- /dev/null +++ b/src/atom/extra/beast/http.cpp @@ -0,0 +1,64 @@ +#include "http.hpp" + +#include + +HttpClient::HttpClient(net::io_context& ioc) + : resolver_(net::make_strand(ioc)), stream_(net::make_strand(ioc)) {} + +void HttpClient::setDefaultHeader(const std::string& key, + const std::string& value) { + default_headers_[key] = value; +} + +void HttpClient::setTimeout(std::chrono::seconds timeout) { + timeout_ = timeout; +} + +auto HttpClient::uploadFile( + const std::string& host, const std::string& port, const std::string& target, + const std::string& filepath, + const std::string& field_name) -> http::response { + std::ifstream file(filepath, std::ios::binary); + if (!file) { + throw std::runtime_error("Failed to open file: " + filepath); + } + std::string fileContent((std::istreambuf_iterator(file)), + std::istreambuf_iterator()); + + std::string boundary = + "-------------------------" + std::to_string(std::time(nullptr)); + + std::string body = "--" + boundary + "\r\n"; + body += "Content-Disposition: form-data; name=\"" + field_name + + "\"; filename=\"" + + std::filesystem::path(filepath).filename().string() + "\"\r\n"; + body += "Content-Type: application/octet-stream\r\n\r\n"; + body += fileContent + "\r\n"; + body += "--" + boundary + "--\r\n"; + + std::string contentType = "multipart/form-data; boundary=" + boundary; + + return request(http::verb::post, host, port, target, 11, contentType, body); +} + +void HttpClient::downloadFile(const std::string& host, const std::string& port, + const std::string& target, + const std::string& filepath) { + auto res = request(http::verb::get, host, port, target); + std::ofstream outFile(filepath, std::ios::binary); + outFile << res.body(); +} + +void HttpClient::runWithThreadPool(size_t num_threads) { + net::thread_pool pool(num_threads); + + for (size_t i = 0; i < num_threads; ++i) { + net::post(pool, [this] { + // Example task: send a request in a thread from the pool + auto res = request(http::verb::get, "example.com", "80", "/"); + std::cout << "Response in thread pool: " << res << std::endl; + }); + } + + pool.join(); // Wait for all threads to finish +} diff --git a/src/atom/extra/beast/http.hpp b/src/atom/extra/beast/http.hpp new file mode 100644 index 00000000..4832551d --- /dev/null +++ b/src/atom/extra/beast/http.hpp @@ -0,0 +1,357 @@ +#ifndef HTTP_CLIENT_HPP +#define HTTP_CLIENT_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace beast = boost::beast; +namespace http = beast::http; 
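
A minimal usage sketch for the new HttpClient (hypothetical host and paths; it assumes the synchronous request() declared just below defaults to string bodies, as the thread-pool example above already relies on, and that "http.hpp" is on the include path):

    #include <iostream>
    #include "http.hpp"

    int main() {
        net::io_context ioc;
        HttpClient client(ioc);
        client.setDefaultHeader("Accept", "text/html");
        client.setTimeout(std::chrono::seconds(10));

        // Blocking GET; template arguments are left at their defaults.
        auto res = client.request(http::verb::get, "example.com", "80", "/");
        std::cout << res.result_int() << "\n" << res.body() << "\n";

        // Multipart upload assembled by uploadFile() above.
        auto up = client.uploadFile("example.com", "80", "/upload", "./data.bin");
        std::cout << "upload status: " << up.result_int() << "\n";
        return 0;
    }
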
+namespace net = boost::asio; +using tcp = boost::asio::ip::tcp; +using json = nlohmann::json; + +class HttpClient { +public: + explicit HttpClient(net::io_context& ioc); + + void setDefaultHeader(const std::string& key, const std::string& value); + void setTimeout(std::chrono::seconds timeout); + + // Synchronous request + template + auto request(http::verb method, const std::string& host, + const std::string& port, const std::string& target, + int version = 11, const std::string& content_type = "", + const std::string& body = "", + const std::unordered_map& headers = + {}) -> http::response; + + // Asynchronous request + template + void asyncRequest( + http::verb method, const std::string& host, const std::string& port, + const std::string& target, ResponseHandler&& handler, int version = 11, + const std::string& content_type = "", const std::string& body = "", + const std::unordered_map& headers = {}); + + auto jsonRequest(http::verb method, const std::string& host, + const std::string& port, const std::string& target, + const json& json_body = {}, + const std::unordered_map& + headers = {}) -> json; + + template + void asyncJsonRequest( + http::verb method, const std::string& host, const std::string& port, + const std::string& target, ResponseHandler&& handler, + const json& json_body = {}, + const std::unordered_map& headers = {}); + + auto uploadFile(const std::string& host, const std::string& port, + const std::string& target, const std::string& filepath, + const std::string& field_name = "file") + -> http::response; + + void downloadFile(const std::string& host, const std::string& port, + const std::string& target, const std::string& filepath); + + template + auto requestWithRetry( + http::verb method, const std::string& host, const std::string& port, + const std::string& target, int retry_count = 3, int version = 11, + const std::string& content_type = "", const std::string& body = "", + const std::unordered_map& headers = {}) + -> http::response; + + template + std::vector> batchRequest( + const std::vector>& requests, + const std::unordered_map& headers = {}); + + template + void asyncBatchRequest( + const std::vector>& requests, + ResponseHandler&& handler, + const std::unordered_map& headers = {}); + + void runWithThreadPool(size_t num_threads); + + template + void asyncDownloadFile(const std::string& host, const std::string& port, + const std::string& target, + const std::string& filepath, + ResponseHandler&& handler); + +private: + tcp::resolver resolver_; + beast::tcp_stream stream_; + std::unordered_map default_headers_; + std::chrono::seconds timeout_{30}; +}; + +template +auto HttpClient::request(http::verb method, const std::string& host, + const std::string& port, const std::string& target, + int version, const std::string& content_type, + const std::string& body, + const std::unordered_map& + headers) -> http::response { + http::request req{method, target, version}; + req.set(http::field::host, host); + req.set(http::field::user_agent, BOOST_BEAST_VERSION_STRING); + + for (const auto& [key, value] : default_headers_) { + req.set(key, value); + } + + for (const auto& [key, value] : headers) { + req.set(key, value); + } + + if (!content_type.empty()) { + req.set(http::field::content_type, content_type); + } + + if (!body.empty()) { + req.body() = body; + req.prepare_payload(); + } + + auto const results = resolver_.resolve(host, port); + stream_.connect(results); + + stream_.expires_after(timeout_); + + http::write(stream_, req); + + beast::flat_buffer buffer; + 
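+    // Read the response into res below, using buffer as transient storage;
+    // the shutdown error code is captured but not treated as fatal.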
http::response res; + http::read(stream_, buffer, res); + + beast::error_code ec; + stream_.socket().shutdown(tcp::socket::shutdown_both, ec); + + return res; +} + +template +void HttpClient::asyncRequest( + http::verb method, const std::string& host, const std::string& port, + const std::string& target, ResponseHandler&& handler, int version, + const std::string& content_type, const std::string& body, + const std::unordered_map& headers) { + auto req = std::make_shared>( + method, target, version); + req->set(http::field::host, host); + req->set(http::field::user_agent, BOOST_BEAST_VERSION_STRING); + + for (const auto& [key, value] : default_headers_) { + req->set(key, value); + } + + for (const auto& [key, value] : headers) { + req->set(key, value); + } + + if (!content_type.empty()) { + req->set(http::field::content_type, content_type); + } + + if (!body.empty()) { + req->body() = body; + req->prepare_payload(); + } + + resolver_.async_resolve( + host, port, + [this, req, handler = std::forward(handler)]( + beast::error_code ec, tcp::resolver::results_type results) { + if (ec) { + return handler(ec, {}); + } + + stream_.async_connect( + results, [this, req, handler = std::move(handler)]( + beast::error_code ec, + tcp::resolver::results_type::endpoint_type) { + if (ec) { + return handler(ec, {}); + } + + stream_.expires_after(timeout_); + + http::async_write( + stream_, *req, + [this, req, handler = std::move(handler)]( + beast::error_code ec, std::size_t) { + if (ec) { + return handler(ec, {}); + } + + auto res = std::make_shared>(); + auto buffer = + std::make_shared(); + + http::async_read( + stream_, *buffer, *res, + [this, res, buffer, + handler = std::move(handler)]( + beast::error_code ec, std::size_t) { + stream_.socket().shutdown( + tcp::socket::shutdown_both, ec); + handler(ec, std::move(*res)); + }); + }); + }); + }); +} + +template +void HttpClient::asyncJsonRequest( + http::verb method, const std::string& host, const std::string& port, + const std::string& target, ResponseHandler&& handler, const json& json_body, + const std::unordered_map& headers) { + asyncRequest( + method, host, port, target, + [handler = std::forward(handler)]( + beast::error_code ec, http::response res) { + if (ec) { + handler(ec, {}); + } else { + try { + auto jv = json::parse(res.body()); + handler({}, std::move(jv)); + } catch (const json::parse_error& e) { + handler(beast::error_code{e.id, beast::generic_category()}, + {}); + } + } + }, + 11, "application/json", json_body.empty() ? 
"" : json_body.dump(), + headers); +} + +template +auto HttpClient::requestWithRetry( + http::verb method, const std::string& host, const std::string& port, + const std::string& target, int retry_count, int version, + const std::string& content_type, const std::string& body, + const std::unordered_map& headers) + -> http::response { + beast::error_code ec; + http::response response; + for (int attempt = 0; attempt < retry_count; ++attempt) { + try { + response = request(method, host, port, target, version, + content_type, body, headers); + // If no exception was thrown, return the response + return response; + } catch (const beast::system_error& e) { + ec = e.code(); + std::cerr << "Request attempt " << (attempt + 1) + << " failed: " << ec.message() << std::endl; + if (attempt + 1 == retry_count) { + throw; // Throw the exception if this was the last retry + } + } + } + return response; +} + +template +std::vector> HttpClient::batchRequest( + const std::vector>& requests, + const std::unordered_map& headers) { + std::vector> responses; + for (const auto& [method, host, port, target] : requests) { + try { + responses.push_back( + request(method, host, port, target, 11, "", "", headers)); + } catch (const std::exception& e) { + std::cerr << "Batch request failed for " << target << ": " + << e.what() << std::endl; + // Push an empty response if an exception occurs (or handle as + // needed) + responses.emplace_back(); + } + } + return responses; +} + +template +void HttpClient::asyncBatchRequest( + const std::vector>& requests, + ResponseHandler&& handler, + const std::unordered_map& headers) { + auto responses = + std::make_shared>>(); + auto remaining = std::make_shared>(requests.size()); + + for (const auto& [method, host, port, target] : requests) { + asyncRequest( + method, host, port, target, + [handler, responses, remaining]( + beast::error_code ec, http::response res) { + if (ec) { + std::cerr << "Error during batch request: " << ec.message() + << std::endl; + responses + ->emplace_back(); // Empty response in case of error + } else { + responses->emplace_back(std::move(res)); + } + + if (--(*remaining) == 0) { + handler(*responses); + } + }, + 11, "", "", headers); + } +} + +template +void HttpClient::asyncDownloadFile(const std::string& host, + const std::string& port, + const std::string& target, + const std::string& filepath, + ResponseHandler&& handler) { + asyncRequest( + http::verb::get, host, port, target, + [filepath, handler = std::forward(handler)]( + beast::error_code ec, http::response res) { + if (ec) { + handler(ec, false); + } else { + std::ofstream outFile(filepath, std::ios::binary); + if (!outFile) { + std::cerr << "Failed to open file for writing: " << filepath + << std::endl; + handler(beast::error_code{}, false); + return; + } + outFile << res.body(); + handler({}, true); // Download successful + } + }); +} + +#endif // HTTP_CLIENT_HPP diff --git a/src/atom/extra/beast/ws.cpp b/src/atom/extra/beast/ws.cpp new file mode 100644 index 00000000..550cba98 --- /dev/null +++ b/src/atom/extra/beast/ws.cpp @@ -0,0 +1,79 @@ +#include "ws.hpp" + +#if __has_include("atom/log/loguru.hpp") +#include "atom/log/loguru.hpp" +#else +#include +#endif + +WSClient::WSClient(net::io_context& ioc) + : resolver_(net::make_strand(ioc)), + ws_(net::make_strand(ioc)), + ping_timer_(ioc) {} + +void WSClient::setTimeout(std::chrono::seconds timeout) { timeout_ = timeout; } + +void WSClient::setReconnectOptions(int retries, std::chrono::seconds interval) { + max_retries_ = retries; + 
reconnect_interval_ = interval; +} + +void WSClient::setPingInterval(std::chrono::seconds interval) { + ping_interval_ = interval; +} + +void WSClient::connect(const std::string& host, const std::string& port) { + auto const results = resolver_.resolve(host, port); + beast::get_lowest_layer(ws_).connect(results->endpoint()); + ws_.handshake(host, "/"); + startPing(); +} + +void WSClient::send(const std::string& message) { + ws_.write(net::buffer(message)); +} + +std::string WSClient::receive() { + beast::flat_buffer buffer; + ws_.read(buffer); + return beast::buffers_to_string(buffer.data()); +} + +void WSClient::close() { ws_.close(websocket::close_code::normal); } + +void WSClient::startPing() { + if (ping_interval_.count() > 0) { + ping_timer_.expires_after(ping_interval_); + ping_timer_.async_wait([this](beast::error_code ec) { + if (!ec) { + ws_.async_ping({}, [this](beast::error_code ec) { + if (!ec) { + startPing(); + } + }); + } + }); + } +} + +template +void WSClient::handleConnectError(beast::error_code ec, + ConnectHandler&& handler) { + if (retry_count_ < max_retries_) { + ++retry_count_; + LOG_F(ERROR, "Failed to connect: {}. Retrying in {} seconds...", + ec.message(), reconnect_interval_.count()); + ws_.next_layer().close(); + ping_timer_.expires_after(reconnect_interval_); + ping_timer_.async_wait([this, handler = std::forward( + handler)](beast::error_code ec) { + if (!ec) { + asyncConnect("example.com", "80", + std::forward(handler)); + } + }); + } else { + LOG_F(ERROR, "Failed to connect: {}. Giving up.", ec.message()); + handler(ec); + } +} diff --git a/src/atom/extra/beast/ws.hpp b/src/atom/extra/beast/ws.hpp new file mode 100644 index 00000000..b0656dfe --- /dev/null +++ b/src/atom/extra/beast/ws.hpp @@ -0,0 +1,151 @@ +#ifndef WS_CLIENT_HPP +#define WS_CLIENT_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace beast = boost::beast; +namespace net = boost::asio; +namespace websocket = beast::websocket; +using tcp = boost::asio::ip::tcp; +using json = nlohmann::json; + +class WSClient { +public: + explicit WSClient(net::io_context& ioc); + + void setTimeout(std::chrono::seconds timeout); + void setReconnectOptions(int retries, std::chrono::seconds interval); + void setPingInterval(std::chrono::seconds interval); + + void connect(const std::string& host, const std::string& port); + void send(const std::string& message); + std::string receive(); + void close(); + + template + void asyncConnect(const std::string& host, const std::string& port, + ConnectHandler&& handler); + + template + void asyncSend(const std::string& message, WriteHandler&& handler); + + template + void asyncReceive(ReadHandler&& handler); + + template + void asyncClose(CloseHandler&& handler); + + void asyncSendJson( + const json& jdata, + std::function handler); + + template + void asyncReceiveJson(JsonHandler&& handler); + +private: + void startPing(); + template + void handleConnectError(beast::error_code ec, ConnectHandler&& handler); + + tcp::resolver resolver_; + websocket::stream ws_; + net::steady_timer ping_timer_; + std::chrono::seconds timeout_{30}; + std::chrono::seconds ping_interval_{10}; + std::chrono::seconds reconnect_interval_{5}; + int max_retries_ = 3; + int retry_count_ = 0; +}; + +template +void WSClient::asyncConnect(const std::string& host, const std::string& port, + ConnectHandler&& handler) { + retry_count_ = 0; + resolver_.async_resolve( + host, port, + [this, handler = std::forward(handler)]( + 
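
The blocking WSClient API shown above can be exercised end to end (hypothetical echo endpoint; "ws.hpp" assumed to be on the include path; the ping timer is armed but never fires here because the io_context is not run):

    #include <iostream>
    #include "ws.hpp"

    int main() {
        net::io_context ioc;
        WSClient ws(ioc);
        ws.setPingInterval(std::chrono::seconds(15));
        ws.setReconnectOptions(3, std::chrono::seconds(5));

        ws.connect("echo.example.org", "80");  // resolve + TCP connect + WS handshake
        ws.send("hello");
        std::cout << ws.receive() << "\n";     // blocks until one message arrives
        ws.close();
        return 0;
    }
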
beast::error_code ec, tcp::resolver::results_type results) { + if (ec) { + handleConnectError(ec, handler); + return; + } + + beast::get_lowest_layer(ws_).async_connect( + results, [this, handler = std::move(handler), results]( + beast::error_code ec, + tcp::resolver::results_type::endpoint_type) { + if (ec) { + handleConnectError(ec, handler); + return; + } + + ws_.async_handshake(results->host_name(), "/", + [this, handler = std::move(handler)]( + beast::error_code ec) { + if (!ec) { + startPing(); + } + handler(ec); + }); + }); + }); +} + +template +void WSClient::asyncSend(const std::string& message, WriteHandler&& handler) { + ws_.async_write(net::buffer(message), + [handler = std::forward(handler)]( + beast::error_code ec, std::size_t bytes_transferred) { + handler(ec, bytes_transferred); + }); +} + +template +void WSClient::asyncReceive(ReadHandler&& handler) { + auto buffer = std::make_shared(); + ws_.async_read( + *buffer, [buffer, handler = std::forward(handler)]( + beast::error_code ec, std::size_t bytes_transferred) { + if (ec) { + handler(ec, ""); + } else { + handler(ec, beast::buffers_to_string(buffer->data())); + } + }); +} + +template +void WSClient::asyncClose(CloseHandler&& handler) { + ws_.async_close(websocket::close_code::normal, + [handler = std::forward(handler)]( + beast::error_code ec) { handler(ec); }); +} + +template +void WSClient::asyncReceiveJson(JsonHandler&& handler) { + asyncReceive([handler = std::forward(handler)]( + beast::error_code ec, const std::string& message) { + if (ec) { + handler(ec, {}); + } else { + try { + auto jdata = json::parse(message); + handler(ec, jdata); + } catch (const json::parse_error&) { + handler(beast::error_code{}, {}); + } + } + }); +} + +#endif // WS_CLIENT_HPP diff --git a/src/atom/extra/inicpp/common.hpp b/src/atom/extra/inicpp/common.hpp new file mode 100644 index 00000000..82833125 --- /dev/null +++ b/src/atom/extra/inicpp/common.hpp @@ -0,0 +1,66 @@ +#ifndef ATOM_EXTRA_INICPP_COMMON_HPP +#define ATOM_EXTRA_INICPP_COMMON_HPP + +#include +#include +#include +#include +#include + +#include "macro.hpp" + +namespace inicpp { + +ATOM_CONSTEXPR auto whitespaces() -> std::string_view { return " \t\n\r\f\v"; } +ATOM_CONSTEXPR auto indents() -> std::string_view { return " \t"; } + +ATOM_INLINE void trim(std::string &str) { + auto first = str.find_first_not_of(whitespaces()); + auto last = str.find_last_not_of(whitespaces()); + + if (first == std::string::npos || last == std::string::npos) { + str.clear(); + } else { + str = str.substr(first, last - first + 1); + } +} + +ATOM_INLINE auto strToLong(std::string_view value) -> std::optional { + long result; + auto [ptr, ec] = + std::from_chars(value.data(), value.data() + value.size(), result); + if (ec == std::errc()) { + return result; + } + return std::nullopt; +} + +ATOM_INLINE auto strToULong(std::string_view value) + -> std::optional { + unsigned long result; + auto [ptr, ec] = + std::from_chars(value.data(), value.data() + value.size(), result); + if (ec == std::errc()) { + return result; + } + return std::nullopt; +} + +struct StringInsensitiveLess { + auto operator()(std::string_view lhs, std::string_view rhs) const -> bool { + auto tolower = [](unsigned char ctx) { return std::tolower(ctx); }; + + auto lhsRange = std::ranges::subrange(lhs.begin(), lhs.end()); + auto rhsRange = std::ranges::subrange(rhs.begin(), rhs.end()); + + return std::ranges::lexicographical_compare( + lhsRange, rhsRange, + [tolower](unsigned char first, unsigned char second) { + return tolower(first) 
< tolower(second); + }); + } +}; + +} // namespace inicpp + +#endif // ATOM_EXTRA_INICPP_COMMON_HPP diff --git a/src/atom/extra/inicpp/convert.hpp b/src/atom/extra/inicpp/convert.hpp new file mode 100644 index 00000000..cbef9869 --- /dev/null +++ b/src/atom/extra/inicpp/convert.hpp @@ -0,0 +1,210 @@ +#ifndef ATOM_EXTRA_INICPP_CONVERT_HPP +#define ATOM_EXTRA_INICPP_CONVERT_HPP + +#include +#include +#include "common.hpp" + +namespace inicpp { + +template +struct Convert {}; + +template <> +struct Convert { + void decode(std::string_view value, bool &result) { + std::string str(value); + std::ranges::transform(str, str.begin(), [](char c) { + return static_cast(::toupper(c)); + }); + + if (str == "TRUE") + result = true; + else if (str == "FALSE") + result = false; + else + throw std::invalid_argument("field is not a bool"); + } + + void encode(const bool value, std::string &result) { + result = value ? "true" : "false"; + } +}; + +template <> +struct Convert { + void decode(std::string_view value, char &result) { + if (value.empty()) + throw std::invalid_argument("field is empty"); + result = value.front(); + } + + void encode(const char value, std::string &result) { result = value; } +}; + +template <> +struct Convert { + void decode(std::string_view value, unsigned char &result) { + if (value.empty()) + throw std::invalid_argument("field is empty"); + result = value.front(); + } + + void encode(const unsigned char value, std::string &result) { + result = value; + } +}; + +template <> +struct Convert { + void decode(std::string_view value, short &result) { + if (auto tmp = strToLong(value); tmp.has_value()) + result = static_cast(tmp.value()); + else + throw std::invalid_argument("field is not a short"); + } + + void encode(const short value, std::string &result) { + result = std::to_string(value); + } +}; + +template <> +struct Convert { + void decode(std::string_view value, unsigned short &result) { + if (auto tmp = strToULong(value); tmp.has_value()) + result = static_cast(tmp.value()); + else + throw std::invalid_argument("field is not an unsigned short"); + } + + void encode(const unsigned short value, std::string &result) { + result = std::to_string(value); + } +}; + +template <> +struct Convert { + void decode(std::string_view value, int &result) { + if (auto tmp = strToLong(value); tmp.has_value()) + result = static_cast(tmp.value()); + else + throw std::invalid_argument("field is not an int"); + } + + void encode(const int value, std::string &result) { + result = std::to_string(value); + } +}; + +template <> +struct Convert { + void decode(std::string_view value, unsigned int &result) { + if (auto tmp = strToULong(value); tmp.has_value()) + result = static_cast(tmp.value()); + else + throw std::invalid_argument("field is not an unsigned int"); + } + + void encode(const unsigned int value, std::string &result) { + result = std::to_string(value); + } +}; + +template <> +struct Convert { + void decode(std::string_view value, long &result) { + if (auto tmp = strToLong(value); tmp.has_value()) + result = tmp.value(); + else + throw std::invalid_argument("field is not a long"); + } + + void encode(const long value, std::string &result) { + result = std::to_string(value); + } +}; + +template <> +struct Convert { + void decode(std::string_view value, unsigned long &result) { + if (auto tmp = strToULong(value); tmp.has_value()) + result = tmp.value(); + else + throw std::invalid_argument("field is not an unsigned long"); + } + + void encode(const unsigned long value, std::string &result) { 
+ result = std::to_string(value); + } +}; + +template <> +struct Convert { + void decode(std::string_view value, double &result) { + result = std::stod(std::string(value)); + } + + void encode(const double value, std::string &result) { + result = std::to_string(value); + } +}; + +template <> +struct Convert { + void decode(std::string_view value, float &result) { + result = std::stof(std::string(value)); + } + + void encode(const float value, std::string &result) { + result = std::to_string(value); + } +}; + +template <> +struct Convert { + void decode(std::string_view value, std::string &result) { result = value; } + + void encode(const std::string &value, std::string &result) { + result = value; + } +}; + +#ifdef __cpp_lib_string_view +template <> +struct Convert { + void decode(std::string_view value, std::string_view &result) { + result = value; + } + + void encode(std::string_view value, std::string &result) { result = value; } +}; +#endif + +template <> +struct Convert { + void encode(const char *const &value, std::string &result) { + result = value; + } + + void decode(std::string_view value, const char *&result) { + result = value.data(); + } +}; + +// 对 char[n] 进行模板特化的支持 +template +struct Convert { + void decode(const std::string &value, char (&result)[N]) { + if (value.size() >= N) + throw std::invalid_argument( + "field value is too large for the char array"); + std::copy(value.begin(), value.end(), result); + result[value.size()] = '\0'; // Null-terminate the char array + } + + void encode(const char (&value)[N], std::string &result) { result = value; } +}; +} // namespace inicpp + +#endif // ATOM_EXTRA_INICPP_CONVERT_HPP diff --git a/src/atom/extra/inicpp/field.hpp b/src/atom/extra/inicpp/field.hpp new file mode 100644 index 00000000..f3148b70 --- /dev/null +++ b/src/atom/extra/inicpp/field.hpp @@ -0,0 +1,41 @@ +#ifndef ATOM_EXTRA_INICPP_INIFIELD_HPP +#define ATOM_EXTRA_INICPP_INIFIELD_HPP + +#include "convert.hpp" + +#include +#include + +namespace inicpp { + +class IniField { +private: + std::string value_; + +public: + IniField() = default; + explicit IniField(std::string value) : value_(std::move(value)) {} + IniField(const IniField &field) = default; + ~IniField() = default; + + template + T as() const { + Convert conv; + T result; + conv.decode(value_, result); + return result; + } + + template + IniField &operator=(const T &value) { + Convert conv; + conv.encode(value, value_); + return *this; + } + + IniField &operator=(const IniField &field) = default; +}; + +} // namespace inicpp + +#endif // ATOM_EXTRA_INICPP_INIFIELD_HPP diff --git a/src/atom/extra/inicpp/file.hpp b/src/atom/extra/inicpp/file.hpp new file mode 100644 index 00000000..96975bb7 --- /dev/null +++ b/src/atom/extra/inicpp/file.hpp @@ -0,0 +1,195 @@ +#ifndef ATOM_EXTRA_INICPP_INIFILE_HPP +#define ATOM_EXTRA_INICPP_INIFILE_HPP + +#include +#include +#include +#include "section.hpp" + +#include "atom/error/exception.hpp" + +namespace inicpp { + +template +class IniFileBase + : public std::map, Comparator> { +private: + char fieldSep_ = '='; + char esc_ = '\\'; + std::vector commentPrefixes_ = {"#", ";"}; + bool multiLineValues_ = false; + bool overwriteDuplicateFields_ = true; + + void eraseComment(std::string &str, std::string::size_type startpos = 0) { + for (const auto &commentPrefix : commentPrefixes_) { + auto pos = str.find(commentPrefix, startpos); + if (pos != std::string::npos) { + // Check for escaped comment + if (pos > 0 && str[pos - 1] == esc_) { + str.erase(pos - 1, 1); + continue; + } + 
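
A short sketch tying the inicpp pieces together: IniField converts on access through the Convert specializations above, and the IniFile alias defined at the end of file.hpp below maps section names to fields (include path assumed):

    #include <iostream>
    #include "inicpp.hpp"

    int main() {
        inicpp::IniFile ini;
        ini.decode("[server]\nhost = localhost\nport = 8080\nverbose = true\n");

        auto host = ini["server"]["host"].as<std::string>();
        int port = ini["server"]["port"].as<int>();     // strToLong under the hood
        bool verbose = ini["server"]["verbose"].as<bool>();

        ini["server"]["port"] = 9090;                   // IniField::operator= re-encodes
        std::cout << host << ":" << port << " verbose=" << verbose << "\n"
                  << ini.encode();
        return 0;
    }
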
str.erase(pos); + } + } + } + + void writeEscaped(std::ostream &oss, const std::string &str) const { + for (size_t i = 0; i < str.length(); ++i) { + auto prefixpos = std::ranges::find_if( + commentPrefixes_, [&](const std::string &prefix) { + return str.find(prefix, i) == i; + }); + + if (prefixpos != commentPrefixes_.end()) { + oss.put(esc_); + oss.write(prefixpos->c_str(), prefixpos->size()); + i += prefixpos->size() - 1; + } else if (multiLineValues_ && str[i] == '\n') { + oss.write("\n\t", 2); + } else { + oss.put(str[i]); + } + } + } + +public: + IniFileBase() = default; + + explicit IniFileBase(const std::string &filename) { load(filename); } + + explicit IniFileBase(std::istream &iss) { decode(iss); } + + ~IniFileBase() = default; + + void setFieldSep(char sep) { fieldSep_ = sep; } + + void setCommentPrefixes(const std::vector &commentPrefixes) { + commentPrefixes_ = commentPrefixes; + } + + void setEscapeChar(char esc) { esc_ = esc; } + + void setMultiLineValues(bool enable) { multiLineValues_ = enable; } + + void allowOverwriteDuplicateFields(bool allowed) { + overwriteDuplicateFields_ = allowed; + } + + /** Decodes a ini file from input stream. */ + void decode(std::istream &iss) { + this->clear(); + std::string line; + IniSectionBase *currentSection = nullptr; + std::string multiLineValueFieldName; + + int lineNo = 0; + while (std::getline(iss, line)) { + ++lineNo; + eraseComment(line); + bool hasIndent = line.find_first_not_of(indents()) != 0; + trim(line); + + if (line.empty()) { + continue; + } + + if (line.front() == '[') { + // Section line + auto pos = line.find(']'); + if (pos == std::string::npos) { + THROW_LOGIC_ERROR("Section not closed at line " + + std::to_string(lineNo)); + } + if (pos == 1) { + THROW_LOGIC_ERROR("Empty section name at line " + + std::to_string(lineNo)); + } + + std::string secName = line.substr(1, pos - 1); + currentSection = &(*this)[secName]; + multiLineValueFieldName.clear(); + } else { + if (!currentSection) + THROW_LOGIC_ERROR("Field without section at line " + + std::to_string(lineNo)); + + auto pos = line.find(fieldSep_); + if (multiLineValues_ && hasIndent && + !multiLineValueFieldName.empty()) { + (*currentSection)[multiLineValueFieldName] = + (*currentSection)[multiLineValueFieldName] + .template as() + + "\n" + line; + } else if (pos == std::string::npos) { + THROW_LOGIC_ERROR("Field separator missing at line " + + std::to_string(lineNo)); + } else { + std::string name = line.substr(0, pos); + trim(name); + + if (!overwriteDuplicateFields_ && + currentSection->count(name)) { + THROW_LOGIC_ERROR("Duplicate field at line " + + std::to_string(lineNo)); + } + + std::string value = line.substr(pos + 1); + trim(value); + (*currentSection)[name] = value; + + multiLineValueFieldName = name; + } + } + } + } + + /** Decodes an ini file from a string. */ + void decode(const std::string &content) { + std::istringstream ss(content); + decode(ss); + } + + /** Loads and decodes an ini file from a file path. */ + void load(const std::string &fileName) { + std::ifstream iss(fileName); + if (!iss.is_open()) { + THROW_FAIL_TO_OPEN_FILE("Unable to open file " + fileName); + } + decode(iss); + } + + /** Encodes the ini file to the output stream. 
*/ + void encode(std::ostream &oss) const { + for (const auto §ionPair : *this) { + oss << '[' << sectionPair.first << "]\n"; + for (const auto &fieldPair : sectionPair.second) { + oss << fieldPair.first << fieldSep_ + << fieldPair.second.template as() << "\n"; + } + } + } + + /** Encodes the ini file to a string and returns it. */ + [[nodiscard]] auto encode() const -> std::string { + std::ostringstream sss; + encode(sss); + return sss.str(); + } + + /** Saves the ini file to a given file path. */ + void save(const std::string &fileName) const { + std::ofstream oss(fileName); + if (!oss.is_open()) { + THROW_FAIL_TO_OPEN_FILE("Unable to open file " + fileName); + } + encode(oss); + } +}; + +using IniFile = IniFileBase>; +using IniFileCaseInsensitive = IniFileBase; + +} // namespace inicpp + +#endif // ATOM_EXTRA_INICPP_INIFILE_HPP diff --git a/src/atom/extra/inicpp/inicpp.hpp b/src/atom/extra/inicpp/inicpp.hpp new file mode 100644 index 00000000..539ea15c --- /dev/null +++ b/src/atom/extra/inicpp/inicpp.hpp @@ -0,0 +1,10 @@ +#ifndef ATOM_EXTRA_INICPP_HPP +#define ATOM_EXTRA_INICPP_HPP + +#include "common.hpp" +#include "convert.hpp" +#include "field.hpp" +#include "file.hpp" +#include "section.hpp" + +#endif // ATOM_EXTRA_INICPP_HPP diff --git a/src/atom/extra/inicpp/section.hpp b/src/atom/extra/inicpp/section.hpp new file mode 100644 index 00000000..4a548f45 --- /dev/null +++ b/src/atom/extra/inicpp/section.hpp @@ -0,0 +1,23 @@ +#ifndef ATOM_EXTRA_INICPP_INISECTION_HPP +#define ATOM_EXTRA_INICPP_INISECTION_HPP + +#include +#include + +#include "field.hpp" + +namespace inicpp { + +template +class IniSectionBase : public std::map { +public: + IniSectionBase() = default; + ~IniSectionBase() = default; +}; + +using IniSection = IniSectionBase>; +using IniSectionCaseInsensitive = IniSectionBase; + +} // namespace inicpp + +#endif // ATOM_EXTRA_INICPP_INISECTION_HPP diff --git a/src/atom/extra/injection/all.hpp b/src/atom/extra/injection/all.hpp new file mode 100644 index 00000000..83a941b3 --- /dev/null +++ b/src/atom/extra/injection/all.hpp @@ -0,0 +1,7 @@ +#pragma once + +#include "common.hpp" +#include "inject.hpp" +#include "resolver.hpp" +#include "binding.hpp" +#include "container.hpp" diff --git a/src/atom/extra/injection/binding.hpp b/src/atom/extra/injection/binding.hpp new file mode 100644 index 00000000..ec1af662 --- /dev/null +++ b/src/atom/extra/injection/binding.hpp @@ -0,0 +1,80 @@ +#pragma once + +#include "common.hpp" +#include "resolver.hpp" + +namespace atom::extra { + +template +class BindingScope { +public: + void inTransientScope() { + lifecycle_ = Lifecycle::Transient; + } + + void inSingletonScope() { + lifecycle_ = Lifecycle::Singleton; + resolver_ = std::make_shared>(resolver_); + } + + void inRequestScope() { + lifecycle_ = Lifecycle::Request; + } + +protected: + ResolverPtr resolver_; + Lifecycle lifecycle_ = Lifecycle::Transient; +}; + +template +class BindingTo : public BindingScope { +public: + void toConstantValue(T&& value) { + this->resolver_ = std::make_shared>(std::forward(value)); + } + + BindingScope& toDynamicValue(Factory&& factory) { + this->resolver_ = std::make_shared>(std::move(factory)); + return *this; + } + + template + BindingScope& to() { + this->resolver_ = std::make_shared>(); + return *this; + } +}; + +template +class Binding : public BindingTo { +public: + typename T::value resolve(const Context& context) { + if (!this->resolver_) { + throw exceptions::ResolutionException("atom::extra::Resolver not found. 
Malformed binding."); + } + return this->resolver_->resolve(context); + } + + void when(const Tag& tag) { + tags_.push_back(tag); + } + + void whenTargetNamed(const std::string& name) { + targetName_ = name; + } + + bool matchesTag(const Tag& tag) const { + return std::find_if(tags_.begin(), tags_.end(), + [&](const Tag& t) { return t.name == tag.name; }) != tags_.end(); + } + + bool matchesTargetName(const std::string& name) const { + return targetName_ == name; + } + +private: + std::vector tags_; + std::string targetName_; +}; + +} // namespace atom::extra diff --git a/src/atom/extra/injection/common.hpp b/src/atom/extra/injection/common.hpp new file mode 100644 index 00000000..a363bc10 --- /dev/null +++ b/src/atom/extra/injection/common.hpp @@ -0,0 +1,91 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace atom::extra { + +// Forward declarations +template +class Container; + +template +struct Context { + Container& container; +}; + +// Concepts +template +concept Symbolic = requires { + typename T::value; +}; + +template +concept Injectable = requires { + { T::template resolve(std::declval&>()) } -> std::convertible_to>; +}; + +// Symbol +template +struct Symbol { + static_assert(!std::is_abstract_v, + "atom::extra::Container cannot bind/get abstract class value " + "(use a smart pointer instead)."); + using value = Interface; +}; + +// Factory +template +using Factory = std::function&)>; + +// Exceptions +namespace exceptions { + struct ResolutionException : public std::runtime_error { + using std::runtime_error::runtime_error; + }; +} + +// Lifecycle +enum class Lifecycle { + Transient, + Singleton, + Request +}; + +// Tag +struct Tag { + std::string name; + explicit Tag(std::string tag_name) : name(std::move(tag_name)) {} +}; + +// Named +template +struct Named { + std::string name; + using value = T; + explicit Named(std::string binding_name) : name(std::move(binding_name)) {} +}; + +// Multi +template +struct Multi { + using value = std::vector; +}; + +// Lazy +template +class Lazy { +public: + explicit Lazy(std::function factory) : factory_(std::move(factory)) {} + T get() const { return factory_(); } + +private: + std::function factory_; +}; + +} // namespace atom::extra diff --git a/src/atom/extra/injection/container.hpp b/src/atom/extra/injection/container.hpp new file mode 100644 index 00000000..b790075d --- /dev/null +++ b/src/atom/extra/injection/container.hpp @@ -0,0 +1,80 @@ +#pragma once + +#include "common.hpp" +#include "binding.hpp" +#include + +namespace atom::extra { + +template +class Container { +public: + using BindingMap = std::tuple...>; + + template + BindingTo& bind() { + static_assert((std::is_same_v || ...), + "atom::extra::Container symbol not registered"); + return std::get>(bindings_); + } + + template + typename T::value get() { + return get(Tag("")); + } + + template + typename T::value get(const Tag& tag) { + static_assert((std::is_same_v || ...), + "atom::extra::Container symbol not registered"); + auto& binding = std::get>(bindings_); + if (binding.matchesTag(tag)) { + return binding.resolve(context_); + } + throw exceptions::ResolutionException("No matching binding found for the given tag."); + } + + template + typename T::value getNamed(const std::string& name) { + static_assert((std::is_same_v || ...), + "atom::extra::Container symbol not registered"); + auto& binding = std::get>(bindings_); + if (binding.matchesTargetName(name)) { + return binding.resolve(context_); + } + throw 
exceptions::ResolutionException("No matching binding found for the given name."); + } + + template + std::vector getAll() { + static_assert((std::is_same_v || ...), + "atom::extra::Container symbol not registered"); + std::vector result; + auto& binding = std::get>(bindings_); + result.push_back(binding.resolve(context_)); + return result; + } + + template + bool hasBinding() const { + return std::get>(bindings_).resolver_ != nullptr; + } + + template + void unbind() { + std::get>(bindings_).resolver_.reset(); + } + + std::unique_ptr createChildContainer() { + auto child = std::make_unique(); + child->parent_ = this; + return child; + } + +private: + BindingMap bindings_; + Context context_{*this}; + Container* parent_ = nullptr; +}; + +} // namespace atom::extra diff --git a/src/atom/extra/injection/inject.hpp b/src/atom/extra/injection/inject.hpp new file mode 100644 index 00000000..13c6236b --- /dev/null +++ b/src/atom/extra/injection/inject.hpp @@ -0,0 +1,18 @@ +#pragma once + +#include "common.hpp" + +namespace atom::extra { + +template +struct Inject { + template + static auto resolve(const Context& context) { + return std::make_tuple(context.container.template get()...); + } +}; + +template > +struct InjectableA : Inject {}; + +} // namespace atom::extra diff --git a/src/atom/extra/injection/resolver.hpp b/src/atom/extra/injection/resolver.hpp new file mode 100644 index 00000000..33d44100 --- /dev/null +++ b/src/atom/extra/injection/resolver.hpp @@ -0,0 +1,102 @@ +#pragma once + +#include +#include "common.hpp" +#include "inject.hpp" + +namespace atom::extra { + +template +class Resolver { +public: + virtual ~Resolver() = default; + virtual T resolve(const Context&) = 0; +}; + +template +using ResolverPtr = std::shared_ptr>; + +template +class ConstantResolver : public Resolver { +public: + explicit ConstantResolver(T value) : value_(std::move(value)) {} + T resolve(const Context&) override { return value_; } + +private: + T value_; +}; + +template +class DynamicResolver : public Resolver { +public: + explicit DynamicResolver(Factory factory) + : factory_(std::move(factory)) {} + T resolve(const Context& context) override { + return factory_(context); + } + +private: + Factory factory_; +}; + +template +class AutoResolver : public Resolver { +public: + T resolve(const Context& context) override { + return std::make_from_tuple( + InjectableA::template resolve(context)); + } +}; + +template +class AutoResolver, U, SymbolTypes...> + : public Resolver, SymbolTypes...> { +public: + std::unique_ptr resolve( + const Context& context) override { + return std::apply( + [](auto&&... deps) { + return std::make_unique( + std::forward(deps)...); + }, + InjectableA::template resolve(context)); + } +}; + +template +class AutoResolver, U, SymbolTypes...> + : public Resolver, SymbolTypes...> { +public: + std::shared_ptr resolve( + const Context& context) override { + return std::apply( + [](auto&&... deps) { + return std::make_shared( + std::forward(deps)...); + }, + InjectableA::template resolve(context)); + } +}; + +template +class CachedResolver : public Resolver { + static_assert(std::is_copy_constructible_v, + "atom::extra::CachedResolver requires a copy constructor. 
Are " + "you caching a unique_ptr?"); + +public: + explicit CachedResolver(ResolverPtr parent) + : parent_(std::move(parent)) {} + T resolve(const Context& context) override { + if (!cached_.has_value()) { + cached_ = parent_->resolve(context); + } + return cached_.value(); + } + +private: + std::optional cached_; + ResolverPtr parent_; +}; + +} // namespace atom::extra diff --git a/src/atom/function/concept.hpp b/src/atom/function/concept.hpp index c2ff95a8..9829a8cf 100644 --- a/src/atom/function/concept.hpp +++ b/src/atom/function/concept.hpp @@ -9,6 +9,9 @@ #ifndef ATOM_META_CONCEPT_HPP #define ATOM_META_CONCEPT_HPP +#include +#include +#include #if __cplusplus < 202002L #error "C++20 is required for this library" #endif @@ -193,13 +196,34 @@ template concept AnyChar = Char || WChar || Char16 || Char32; template -concept String = requires(T x) { +concept NotSequenceContainer = + !std::is_same_v> && + !std::is_same_v> && + !std::is_same_v>; + +template +concept NotAssociativeOrSequenceContainer = + !std::is_same_v> && + !std::is_same_v< + T, std::unordered_map> && + !std::is_same_v< + T, std::multimap> && + !std::is_same_v> && + !NotSequenceContainer; + +template +concept String = NotAssociativeOrSequenceContainer && requires(T x) { { x.size() } -> std::convertible_to; { x.empty() } -> std::convertible_to; { x.begin() } -> std::convertible_to; { x.end() } -> std::convertible_to; }; +template +concept IsBuiltIn = std::is_fundamental_v || String; + // Checks if a type is an enum template concept Enum = std::is_enum_v; @@ -226,6 +250,9 @@ concept WeakPointer = requires(T x) { requires std::is_same_v>; }; +template +concept SmartPointer = UniquePointer || SharedPointer || WeakPointer; + // Checks if a type is a reference template concept Reference = std::is_reference_v; @@ -271,6 +298,20 @@ concept Container = requires(T x) { requires Iterable; }; +template +concept StringContainer = requires(T t) { + typename T::value_type; + String || Char; + { t.push_back(std::declval()) }; +}; + +template +concept NumberContainer = requires(T t) { + typename T::value_type; + Number; + { t.push_back(std::declval()) }; +}; + // Checks if a type is an associative container like map or set template concept AssociativeContainer = requires(T x) { diff --git a/src/atom/function/constructor.hpp b/src/atom/function/constructor.hpp index 140870e4..52c2ed66 100644 --- a/src/atom/function/constructor.hpp +++ b/src/atom/function/constructor.hpp @@ -9,13 +9,21 @@ #ifndef ATOM_META_CONSTRUCTOR_HPP #define ATOM_META_CONSTRUCTOR_HPP +#include #include -#include "func_traits.hpp" - #include "atom/error/exception.hpp" +#include "func_traits.hpp" namespace atom::meta { + +/*! + * \brief Binds a member function to an object. + * \tparam MemberFunc Type of the member function. + * \tparam ClassType Type of the class. + * \param member_func Pointer to the member function. + * \return A lambda that binds the member function to an object. + */ template auto bindMemberFunction(MemberFunc ClassType::*member_func) { return [member_func](ClassType &obj, auto &&...params) { @@ -29,11 +37,24 @@ auto bindMemberFunction(MemberFunc ClassType::*member_func) { }; } +/*! + * \brief Binds a static function. + * \tparam Func Type of the function. + * \param func The static function. + * \return The static function itself. + */ template auto bindStaticFunction(Func func) { return func; } +/*! + * \brief Binds a member variable to an object. + * \tparam MemberType Type of the member variable. + * \tparam ClassType Type of the class. 
+ * \param member_var Pointer to the member variable. + * \return A lambda that binds the member variable to an object. + */ template auto bindMemberVariable(MemberType ClassType::*member_var) { return [member_var](ClassType &instance) -> MemberType & { @@ -41,6 +62,13 @@ auto bindMemberVariable(MemberType ClassType::*member_var) { }; } +/*! + * \brief Builds a shared constructor for a class. + * \tparam Class Type of the class. + * \tparam Params Types of the constructor parameters. + * \param unused Unused parameter to deduce types. + * \return A lambda that constructs a shared pointer to the class. + */ template auto buildSharedConstructor(Class (* /*unused*/)(Params...)) { return [](auto &&...params) { @@ -49,6 +77,13 @@ auto buildSharedConstructor(Class (* /*unused*/)(Params...)) { }; } +/*! + * \brief Builds a copy constructor for a class. + * \tparam Class Type of the class. + * \tparam Params Types of the constructor parameters. + * \param unused Unused parameter to deduce types. + * \return A lambda that constructs an instance of the class. + */ template auto buildCopyConstructor(Class (* /*unused*/)(Params...)) { return [](auto &&...params) { @@ -56,6 +91,13 @@ auto buildCopyConstructor(Class (* /*unused*/)(Params...)) { }; } +/*! + * \brief Builds a plain constructor for a class. + * \tparam Class Type of the class. + * \tparam Params Types of the constructor parameters. + * \param unused Unused parameter to deduce types. + * \return A lambda that constructs an instance of the class. + */ template auto buildPlainConstructor(Class (* /*unused*/)(Params...)) { return [](auto &&...params) { @@ -63,6 +105,12 @@ auto buildPlainConstructor(Class (* /*unused*/)(Params...)) { }; } +/*! + * \brief Builds a constructor for a class with specified arguments. + * \tparam Class Type of the class. + * \tparam Args Types of the constructor arguments. + * \return A lambda that constructs a shared pointer to the class. + */ template auto buildConstructor() { return [](Args... args) -> std::shared_ptr { @@ -70,28 +118,50 @@ auto buildConstructor() { }; } +/*! + * \brief Builds a default constructor for a class. + * \tparam Class Type of the class. + * \return A lambda that constructs an instance of the class. + */ template auto buildDefaultConstructor() { return []() { return Class(); }; } +/*! + * \brief Constructs an instance of a class based on its traits. + * \tparam T Type of the function. + * \return A lambda that constructs an instance of the class. + */ template auto constructor() { T *func = nullptr; using ClassType = typename FunctionTraits::class_type; if constexpr (!std::is_copy_constructible_v) { - return build_shared_constructor_(func); + return buildSharedConstructor(func); } else { - return build_copy_constructor_(func); + return buildCopyConstructor(func); } } +/*! + * \brief Constructs an instance of a class with specified arguments. + * \tparam Class Type of the class. + * \tparam Args Types of the constructor arguments. + * \return A lambda that constructs a shared pointer to the class. + */ template auto constructor() { return buildConstructor(); } +/*! + * \brief Constructs an instance of a class using the default constructor. + * \tparam Class Type of the class. + * \return A lambda that constructs an instance of the class. + * \throws Exception if the class is not default constructible. 
+ */ template auto defaultConstructor() { if constexpr (std::is_default_constructible_v) { @@ -100,5 +170,30 @@ auto defaultConstructor() { THROW_NOT_FOUND("Class is not default constructible"); } } + +/*! + * \brief Constructs an instance of a class using a move constructor. + * \tparam Class Type of the class. + * \return A lambda that constructs an instance of the class using a move + * constructor. + */ +template +auto buildMoveConstructor() { + return [](Class &&instance) { return Class(std::move(instance)); }; +} + +/*! + * \brief Constructs an instance of a class using an initializer list. + * \tparam Class Type of the class. + * \tparam T Type of the elements in the initializer list. + * \return A lambda that constructs an instance of the class using an + * initializer list. + */ +template +auto buildInitializerListConstructor() { + return [](std::initializer_list init_list) { return Class(init_list); }; +} + } // namespace atom::meta + #endif // ATOM_META_CONSTRUCTOR_HPP diff --git a/src/atom/function/enum.hpp b/src/atom/function/enum.hpp index 1f9946b8..02c94d17 100644 --- a/src/atom/function/enum.hpp +++ b/src/atom/function/enum.hpp @@ -15,11 +15,18 @@ #include #include -// EnumTraits 模板结构体需要为每个枚举类型特化。 +/*! + * \brief Template struct for EnumTraits, needs to be specialized for each enum + * type. \tparam T Enum type. + */ template struct EnumTraits; -// 辅助函数:从编译器生成的函数签名中提取枚举名称 +/*! + * \brief Helper function to extract enum name from compiler-generated function + * signature. \tparam T Enum type. \param func_sig Function signature. \return + * Extracted enum name. + */ template constexpr std::string_view extract_enum_name(const char* func_sig) { std::string_view name(func_sig); @@ -36,7 +43,12 @@ constexpr std::string_view extract_enum_name(const char* func_sig) { return name.substr(prefixPos, suffixPos - prefixPos); } -// 生成枚举值的字符串名称 +/*! + * \brief Generate string name for enum value. + * \tparam T Enum type. + * \tparam Value Enum value. + * \return String name of the enum value. + */ template constexpr std::string_view enum_name() noexcept { #if defined(__clang__) || defined(__GNUC__) @@ -48,7 +60,12 @@ constexpr std::string_view enum_name() noexcept { #endif } -// 枚举值转字符串 +/*! + * \brief Convert enum value to string. + * \tparam T Enum type. + * \param value Enum value. + * \return String name of the enum value. + */ template constexpr auto enum_name(T value) noexcept -> std::string_view { constexpr auto VALUES = EnumTraits::values; @@ -62,7 +79,12 @@ constexpr auto enum_name(T value) noexcept -> std::string_view { return {}; } -// 字符串转枚举值 +/*! + * \brief Convert string to enum value. + * \tparam T Enum type. + * \param name String name of the enum value. + * \return Optional enum value. + */ template constexpr auto enum_cast(std::string_view name) noexcept -> std::optional { constexpr auto VALUES = EnumTraits::values; @@ -76,13 +98,23 @@ constexpr auto enum_cast(std::string_view name) noexcept -> std::optional { return std::nullopt; } -// 枚举值转整数 +/*! + * \brief Convert enum value to integer. + * \tparam T Enum type. + * \param value Enum value. + * \return Integer representation of the enum value. + */ template constexpr auto enum_to_integer(T value) noexcept { return static_cast>(value); } -// 整数转枚举值 +/*! + * \brief Convert integer to enum value. + * \tparam T Enum type. + * \param value Integer value. + * \return Optional enum value. 
+ */ template constexpr auto integer_to_enum(std::underlying_type_t value) noexcept -> std::optional { @@ -96,7 +128,12 @@ constexpr auto integer_to_enum(std::underlying_type_t value) noexcept return std::nullopt; } -// 检查枚举值是否有效 +/*! + * \brief Check if enum value is valid. + * \tparam T Enum type. + * \param value Enum value. + * \return True if valid, false otherwise. + */ template constexpr auto enum_contains(T value) noexcept -> bool { constexpr auto VALUES = EnumTraits::values; @@ -108,7 +145,11 @@ constexpr auto enum_contains(T value) noexcept -> bool { return false; } -// 获取所有枚举值和名称 +/*! + * \brief Get all enum values and names. + * \tparam T Enum type. + * \return Array of pairs of enum values and their names. + */ template constexpr auto enum_entries() noexcept { constexpr auto VALUES = EnumTraits::values; @@ -122,7 +163,13 @@ constexpr auto enum_entries() noexcept { return entries; } -// 支持标志枚举(位运算) +/*! + * \brief Support for flag enums (bitwise operations). + * \tparam T Enum type. + * \param lhs Left-hand side enum value. + * \param rhs Right-hand side enum value. + * \return Result of bitwise OR operation. + */ template , int> = 0> constexpr auto operator|(T lhs, T rhs) noexcept -> T { using UT = std::underlying_type_t; @@ -162,13 +209,21 @@ constexpr auto operator~(T rhs) noexcept -> T { return static_cast(~static_cast(rhs)); } -// 获取枚举的默认值 +/*! + * \brief Get the default value of an enum. + * \tparam T Enum type. + * \return Default enum value. + */ template constexpr auto enum_default() noexcept -> T { return EnumTraits::values[0]; } -// 根据名字排序枚举值 +/*! + * \brief Sort enum values by their names. + * \tparam T Enum type. + * \return Sorted array of pairs of enum values and their names. + */ template constexpr auto enum_sorted_by_name() noexcept { auto entries = enum_entries(); @@ -177,7 +232,11 @@ constexpr auto enum_sorted_by_name() noexcept { return entries; } -// 根据整数值排序枚举值 +/*! + * \brief Sort enum values by their integer values. + * \tparam T Enum type. + * \return Sorted array of pairs of enum values and their names. + */ template constexpr auto enum_sorted_by_value() noexcept { auto entries = enum_entries(); @@ -187,7 +246,12 @@ constexpr auto enum_sorted_by_value() noexcept { return entries; } -// 模糊匹配字符串并转换为枚举值 +/*! + * \brief Fuzzy match string and convert to enum value. + * \tparam T Enum type. + * \param name String name of the enum value. + * \return Optional enum value. + */ template auto enum_cast_fuzzy(std::string_view name) -> std::optional { constexpr auto names = EnumTraits::names; @@ -200,7 +264,12 @@ auto enum_cast_fuzzy(std::string_view name) -> std::optional { return std::nullopt; } -// 检查整数值是否在枚举范围内 +/*! + * \brief Check if integer value is within enum range. + * \tparam T Enum type. + * \param value Integer value. + * \return True if within range, false otherwise. + */ template constexpr auto integer_in_enum_range(std::underlying_type_t value) noexcept -> bool { @@ -209,12 +278,21 @@ constexpr auto integer_in_enum_range(std::underlying_type_t value) noexcept [value](T e) { return enum_to_integer(e) == value; }); } -// 添加枚举别名支持 +/*! + * \brief Support for enum aliases. + * \tparam T Enum type. + */ template struct EnumAliasTraits { static constexpr std::array ALIASES = {}; }; +/*! + * \brief Convert string to enum value with alias support. + * \tparam T Enum type. + * \param name String name of the enum value. + * \return Optional enum value. 
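+ *
+ * \par Example
+ * A call-site sketch; `Color` is a hypothetical enum whose EnumTraits (and,
+ * optionally, EnumAliasTraits) specializations follow this header's
+ * conventions:
+ * \code
+ * if (auto parsed = enum_cast_with_alias<Color>("Red")) {
+ *     // *parsed == Color::Red
+ * }
+ * \endcode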
+ */ template constexpr auto enum_cast_with_alias(std::string_view name) noexcept -> std::optional { diff --git a/src/atom/function/field_count.hpp b/src/atom/function/field_count.hpp index 9a70ee04..ff4f3d0a 100644 --- a/src/atom/function/field_count.hpp +++ b/src/atom/function/field_count.hpp @@ -14,17 +14,38 @@ namespace atom::meta { +/*! + * \brief A struct that can be converted to any type. + */ struct Any { + /*! + * \brief Constexpr conversion operator to any type. + * \tparam T The type to convert to. + * \return An instance of type T. + */ template consteval operator T() const noexcept; }; +/*! + * \brief Checks if a type T is constructible with braces. + * \tparam T The type to check. + * \tparam I The index sequence. + * \param[in] std::index_sequence The index sequence. + * \return True if T is constructible with braces, false otherwise. + */ template consteval auto isBracesConstructible(std::index_sequence) noexcept -> bool { return requires { T{((void)I, std::declval())...}; }; } +/*! + * \brief Recursively counts the number of fields in a type T. + * \tparam T The type to count fields in. + * \tparam N The current count of fields. + * \return The number of fields in type T. + */ template consteval auto fieldCount() noexcept -> std::size_t { if constexpr (!isBracesConstructible( @@ -35,9 +56,18 @@ consteval auto fieldCount() noexcept -> std::size_t { } } +/*! + * \brief A template struct to hold type information. + * \tparam T The type to hold information for. + */ template struct TypeInfo; +/*! + * \brief Gets the number of fields in a type T. + * \tparam T The type to get the field count for. + * \return The number of fields in type T. + */ template consteval auto fieldCountOf() noexcept -> std::size_t { if constexpr (std::is_aggregate_v) { @@ -51,7 +81,12 @@ consteval auto fieldCountOf() noexcept -> std::size_t { } } -// Overload for arrays +/*! + * \brief Gets the number of elements in an array. + * \tparam T The type of the array elements. + * \tparam N The number of elements in the array. + * \return The number of elements in the array. + */ template consteval auto fieldCountOf() noexcept -> std::size_t { return N; diff --git a/src/atom/function/global_ptr.hpp b/src/atom/function/global_ptr.hpp index 1c97484c..2a83b15a 100644 --- a/src/atom/function/global_ptr.hpp +++ b/src/atom/function/global_ptr.hpp @@ -31,6 +31,14 @@ #define GetPtrOrCreate \ GlobalSharedPtrManager::getInstance().getOrCreateSharedPtr +#define GET_OR_CREATE_PTR(variable, type, constant, ...) \ + if (auto ptr = GetPtrOrCreate( \ + constant, [] { return std::make_shared(__VA_ARGS__); })) { \ + variable = ptr; \ + } else { \ + THROW_UNLAWFUL_OPERATION("Failed to create " #type "."); \ + } + /** * @brief The GlobalSharedPtrManager class manages a collection of shared * pointers and weak pointers. It provides functions to add, remove, and diff --git a/src/atom/function/god.hpp b/src/atom/function/god.hpp index 8ec4e370..f4af7bc7 100644 --- a/src/atom/function/god.hpp +++ b/src/atom/function/god.hpp @@ -16,65 +16,146 @@ #include "atom/macro.hpp" namespace atom::meta { - -// No-op function for blessing with no bugs +/*! + * \brief No-op function for blessing with no bugs. + */ ATOM_INLINE void blessNoBugs() {} +/*! + * \brief Casts a value from one type to another. + * \tparam To The type to cast to. + * \tparam From The type to cast from. + * \param f The value to cast. + * \return The casted value. + */ template constexpr auto cast(From&& f) -> To { return static_cast(std::forward(f)); } +/*! 
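+ * \par Example
+ * A compile-time sketch of the power-of-two alignment helpers:
+ * \code
+ * static_assert(atom::meta::alignUp<16>(13) == 16);
+ * static_assert(atom::meta::alignDown<16>(13) == 0);
+ * \endcode
+ *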
+ * \brief Aligns a value up to the nearest multiple of A. + * \tparam A The alignment value, must be a power of 2. + * \tparam X The type of the value to align. + * \param x The value to align. + * \return The aligned value. + */ template constexpr auto alignUp(X x) -> X { static_assert((A & (A - 1)) == 0, "A must be power of 2"); return (x + static_cast(A - 1)) & ~static_cast(A - 1); } +/*! + * \brief Aligns a pointer up to the nearest multiple of A. + * \tparam A The alignment value, must be a power of 2. + * \tparam X The type of the pointer to align. + * \param x The pointer to align. + * \return The aligned pointer. + */ template constexpr auto alignUp(X* x) -> X* { return reinterpret_cast(alignUp(reinterpret_cast(x))); } +/*! + * \brief Aligns a value up to the nearest multiple of a. + * \tparam X The type of the value to align. + * \tparam A The type of the alignment value, must be integral. + * \param x The value to align. + * \param a The alignment value. + * \return The aligned value. + */ template constexpr auto alignUp(X x, A a) -> X { static_assert(std::is_integral::value, "A must be integral type"); return (x + static_cast(a - 1)) & ~static_cast(a - 1); } +/*! + * \brief Aligns a pointer up to the nearest multiple of a. + * \tparam X The type of the pointer to align. + * \tparam A The type of the alignment value, must be integral. + * \param x The pointer to align. + * \param a The alignment value. + * \return The aligned pointer. + */ template constexpr auto alignUp(X* x, A a) -> X* { return reinterpret_cast(alignUp(reinterpret_cast(x), a)); } +/*! + * \brief Aligns a value down to the nearest multiple of A. + * \tparam A The alignment value, must be a power of 2. + * \tparam X The type of the value to align. + * \param x The value to align. + * \return The aligned value. + */ template constexpr auto alignDown(X x) -> X { static_assert((A & (A - 1)) == 0, "A must be power of 2"); return x & ~static_cast(A - 1); } +/*! + * \brief Aligns a pointer down to the nearest multiple of A. + * \tparam A The alignment value, must be a power of 2. + * \tparam X The type of the pointer to align. + * \param x The pointer to align. + * \return The aligned pointer. + */ template constexpr auto alignDown(X* x) -> X* { return reinterpret_cast(alignDown(reinterpret_cast(x))); } +/*! + * \brief Aligns a value down to the nearest multiple of a. + * \tparam X The type of the value to align. + * \tparam A The type of the alignment value, must be integral. + * \param x The value to align. + * \param a The alignment value. + * \return The aligned value. + */ template constexpr auto alignDown(X x, A a) -> X { static_assert(std::is_integral::value, "A must be integral type"); return x & ~static_cast(a - 1); } +/*! + * \brief Aligns a pointer down to the nearest multiple of a. + * \tparam X The type of the pointer to align. + * \tparam A The type of the alignment value, must be integral. + * \param x The pointer to align. + * \param a The alignment value. + * \return The aligned pointer. + */ template constexpr auto alignDown(X* x, A a) -> X* { return reinterpret_cast(alignDown(reinterpret_cast(x), a)); } +/*! + * \brief Computes the base-2 logarithm of an integral value. + * \tparam T The type of the value, must be integral. + * \param x The value to compute the logarithm of. + * \return The base-2 logarithm of the value. + */ template constexpr auto log2(T x) -> T { static_assert(std::is_integral::value, "T must be integral type"); return x <= 1 ? 0 : 1 + atom::meta::log2(x >> 1); } +/*! 
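+ * \par Example
+ * A compile-time sketch: twenty bytes occupy three 8-byte blocks.
+ * \code
+ * static_assert(atom::meta::nb<8>(20) == 3);
+ * \endcode
+ *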
+ * \brief Computes the number of blocks of size N needed to cover a value. + * \tparam N The block size, must be a power of 2. + * \tparam X The type of the value. + * \param x The value to compute the number of blocks for. + * \return The number of blocks needed to cover the value. + */ template constexpr auto nb(X x) -> X { static_assert((N & (N - 1)) == 0, "N must be power of 2"); @@ -82,11 +163,24 @@ constexpr auto nb(X x) -> X { !!(x & static_cast(N - 1)); } +/*! + * \brief Compares two values for equality. + * \tparam T The type of the values. + * \param p Pointer to the first value. + * \param q Pointer to the second value. + * \return True if the values are equal, false otherwise. + */ template ATOM_INLINE auto eq(const void* p, const void* q) -> bool { return *reinterpret_cast(p) == *reinterpret_cast(q); } +/*! + * \brief Copies N bytes from src to dst. + * \tparam N The number of bytes to copy. + * \param dst Pointer to the destination. + * \param src Pointer to the source. + */ template ATOM_INLINE void copy(void* dst, const void* src) { if constexpr (N > 0) { @@ -94,9 +188,20 @@ ATOM_INLINE void copy(void* dst, const void* src) { } } +/*! + * \brief Specialization of copy for N = 0, does nothing. + */ template <> ATOM_INLINE void copy<0>(void*, const void*) {} +/*! + * \brief Swaps the value pointed to by p with v. + * \tparam T The type of the value pointed to by p. + * \tparam V The type of the value v. + * \param p Pointer to the value to swap. + * \param v The value to swap with. + * \return The original value pointed to by p. + */ template ATOM_INLINE auto swap(T* p, V v) -> T { T x = *p; @@ -104,6 +209,14 @@ ATOM_INLINE auto swap(T* p, V v) -> T { return x; } +/*! + * \brief Adds v to the value pointed to by p and returns the original value. + * \tparam T The type of the value pointed to by p. + * \tparam V The type of the value v. + * \param p Pointer to the value to add to. + * \param v The value to add. + * \return The original value pointed to by p. + */ template ATOM_INLINE auto fetchAdd(T* p, V v) -> T { T x = *p; @@ -111,6 +224,12 @@ ATOM_INLINE auto fetchAdd(T* p, V v) -> T { return x; } +/*! + * \brief Subtracts v from the value pointed to by p and returns the original + * value. \tparam T The type of the value pointed to by p. \tparam V The type of + * the value v. \param p Pointer to the value to subtract from. \param v The + * value to subtract. \return The original value pointed to by p. + */ template ATOM_INLINE auto fetchSub(T* p, V v) -> T { T x = *p; @@ -118,6 +237,14 @@ ATOM_INLINE auto fetchSub(T* p, V v) -> T { return x; } +/*! + * \brief Performs a bitwise AND between the value pointed to by p and v, and + * returns the original value. \tparam T The type of the value pointed to by p. + * \tparam V The type of the value v. + * \param p Pointer to the value to AND. + * \param v The value to AND with. + * \return The original value pointed to by p. + */ template ATOM_INLINE auto fetchAnd(T* p, V v) -> T { T x = *p; @@ -125,6 +252,14 @@ ATOM_INLINE auto fetchAnd(T* p, V v) -> T { return x; } +/*! + * \brief Performs a bitwise OR between the value pointed to by p and v, and + * returns the original value. \tparam T The type of the value pointed to by p. + * \tparam V The type of the value v. + * \param p Pointer to the value to OR. + * \param v The value to OR with. + * \return The original value pointed to by p. 
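+ *
+ * \par Example
+ * A runtime sketch of these (non-atomic) fetch-and-modify helpers:
+ * \code
+ * int flags = 0b0001;
+ * int old = atom::meta::fetchOr(&flags, 0b0100);  // old == 0b0001, flags == 0b0101
+ * \endcode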
+ */ template ATOM_INLINE auto fetchOr(T* p, V v) -> T { T x = *p; @@ -132,6 +267,14 @@ ATOM_INLINE auto fetchOr(T* p, V v) -> T { return x; } +/*! + * \brief Performs a bitwise XOR between the value pointed to by p and v, and + * returns the original value. \tparam T The type of the value pointed to by p. + * \tparam V The type of the value v. + * \param p Pointer to the value to XOR. + * \param v The value to XOR with. + * \return The original value pointed to by p. + */ template ATOM_INLINE auto fetchXor(T* p, V v) -> T { T x = *p; @@ -139,80 +282,169 @@ ATOM_INLINE auto fetchXor(T* p, V v) -> T { return x; } +/*! + * \brief Alias for std::enable_if_t. + * \tparam C The condition. + * \tparam T The type to enable if the condition is true. + */ template using if_t = std::enable_if_t; +/*! + * \brief Alias for std::remove_reference_t. + * \tparam T The type to remove reference from. + */ template using rmRefT = std::remove_reference_t; +/*! + * \brief Alias for std::remove_cv_t. + * \tparam T The type to remove const and volatile qualifiers from. + */ template using rmCvT = std::remove_cv_t; +/*! + * \brief Alias for removing both const, volatile qualifiers and reference. + * \tparam T The type to remove const, volatile qualifiers and reference from. + */ template using rmCvRefT = rmCvT>; +/*! + * \brief Alias for std::remove_extent_t. + * \tparam T The type to remove extent from. + */ template using rmArrT = std::remove_extent_t; +/*! + * \brief Alias for std::add_const_t. + * \tparam T The type to add const qualifier to. + */ template using constT = std::add_const_t; +/*! + * \brief Alias for adding const qualifier and lvalue reference. + * \tparam T The type to add const qualifier and lvalue reference to. + */ template using constRefT = std::add_lvalue_reference_t>>; namespace detail { + +/*! + * \brief Helper struct to check if all types are the same. + * \tparam T The types to check. + */ template struct isSame { static constexpr bool value = false; }; +/*! + * \brief Specialization of isSame for two or more types. + * \tparam T The first type. + * \tparam U The second type. + * \tparam X The remaining types. + */ template struct isSame { static constexpr bool value = std::is_same_v || isSame::value; }; + } // namespace detail +/*! + * \brief Checks if all types are the same. + * \tparam T The first type. + * \tparam U The second type. + * \tparam X The remaining types. + * \return True if all types are the same, false otherwise. + */ template constexpr auto isSame() -> bool { return detail::isSame::value; } +/*! + * \brief Checks if a type is a reference. + * \tparam T The type to check. + * \return True if the type is a reference, false otherwise. + */ template constexpr auto isRef() -> bool { return std::is_reference_v; } +/*! + * \brief Checks if a type is an array. + * \tparam T The type to check. + * \return True if the type is an array, false otherwise. + */ template constexpr auto isArray() -> bool { return std::is_array_v; } +/*! + * \brief Checks if a type is a class. + * \tparam T The type to check. + * \return True if the type is a class, false otherwise. + */ template constexpr auto isClass() -> bool { return std::is_class_v; } +/*! + * \brief Checks if a type is a scalar. + * \tparam T The type to check. + * \return True if the type is a scalar, false otherwise. + */ template constexpr auto isScalar() -> bool { return std::is_scalar_v; } +/*! + * \brief Checks if a type is trivially copyable. + * \tparam T The type to check. 
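+ * \par Example
+ * A compile-time sketch (types chosen purely for illustration):
+ * \code
+ * static_assert(atom::meta::isTriviallyCopyable<int>());
+ * static_assert(!atom::meta::isTriviallyCopyable<std::string>());
+ * \endcode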
+ * \return True if the type is trivially copyable, false otherwise. + */ template constexpr auto isTriviallyCopyable() -> bool { return std::is_trivially_copyable_v; } +/*! + * \brief Checks if a type is trivially destructible. + * \tparam T The type to check. + * \return True if the type is trivially destructible, false otherwise. + */ template constexpr auto isTriviallyDestructible() -> bool { return std::is_trivially_destructible_v; } +/*! + * \brief Checks if a type is a base of another type. + * \tparam B The base type. + * \tparam D The derived type. + * \return True if B is a base of D, false otherwise. + */ template constexpr auto isBaseOf() -> bool { return std::is_base_of_v; } +/*! + * \brief Checks if a type has a virtual destructor. + * \tparam T The type to check. + * \return True if the type has a virtual destructor, false otherwise. + */ template constexpr auto hasVirtualDestructor() -> bool { return std::has_virtual_destructor_v; diff --git a/src/atom/function/invoke.hpp b/src/atom/function/invoke.hpp index 91125945..86d1d27c 100644 --- a/src/atom/function/invoke.hpp +++ b/src/atom/function/invoke.hpp @@ -17,9 +17,23 @@ #include "atom/error/exception.hpp" +/*! + * \brief Concept to check if a function is invocable with given arguments. + * \tparam F The function type. + * \tparam Args The argument types. + */ template concept Invocable = std::is_invocable_v, std::decay_t...>; +/*! + * \brief Delays the invocation of a function with given arguments. + * \tparam F The function type. + * \tparam Args The argument types. + * \param func The function to be invoked. + * \param args The arguments to be passed to the function. + * \return A lambda that, when called, invokes the function with the given + * arguments. + */ template requires Invocable auto delayInvoke(F &&func, Args &&...args) { @@ -29,6 +43,16 @@ auto delayInvoke(F &&func, Args &&...args) { }; } +/*! + * \brief Delays the invocation of a member function with given arguments. + * \tparam R The return type of the member function. + * \tparam T The class type of the member function. + * \tparam Args The argument types. + * \param func The member function to be invoked. + * \param obj The object on which the member function will be invoked. + * \return A lambda that, when called, invokes the member function with the + * given arguments. + */ template auto delayMemInvoke(R (T::*func)(Args...), T *obj) { return [func, obj](Args... args) { @@ -36,6 +60,16 @@ auto delayMemInvoke(R (T::*func)(Args...), T *obj) { }; } +/*! + * \brief Delays the invocation of a const member function with given arguments. + * \tparam R The return type of the member function. + * \tparam T The class type of the member function. + * \tparam Args The argument types. + * \param func The const member function to be invoked. + * \param obj The object on which the member function will be invoked. + * \return A lambda that, when called, invokes the const member function with + * the given arguments. + */ template auto delayMemInvoke(R (T::*func)(Args...) const, const T *obj) { return [func, obj](Args... args) { @@ -43,6 +77,14 @@ auto delayMemInvoke(R (T::*func)(Args...) const, const T *obj) { }; } +/*! + * \brief Delays the invocation of a static member function with given + * arguments. \tparam R The return type of the static member function. \tparam T + * The class type of the static member function. \tparam Args The argument + * types. \param func The static member function to be invoked. 
\param obj The + * object (not used in static member functions). \return A lambda that, when + * called, invokes the static member function with the given arguments. + */ template auto delayStaticMemInvoke(R (*func)(Args...), T *obj) { return [func, obj](Args... args) { @@ -51,11 +93,28 @@ auto delayStaticMemInvoke(R (*func)(Args...), T *obj) { }; } +/*! + * \brief Delays the invocation of a member variable. + * \tparam T The class type of the member variable. + * \tparam M The type of the member variable. + * \param m The member variable to be accessed. + * \param obj The object on which the member variable will be accessed. + * \return A lambda that, when called, returns the member variable. + */ template auto delayMemberVarInvoke(M T::*m, T *obj) { return [m, obj]() -> decltype(auto) { return (obj->*m); }; } +/*! + * \brief Safely calls a function with given arguments, catching any exceptions. + * \tparam Func The function type. + * \tparam Args The argument types. + * \param func The function to be called. + * \param args The arguments to be passed to the function. + * \return The result of the function call, or a default-constructed value if an + * exception occurs. + */ template requires Invocable auto safeCall(Func &&func, Args &&...args) { @@ -73,6 +132,14 @@ auto safeCall(Func &&func, Args &&...args) { } } +/*! + * \brief Safely tries to call a function with given arguments, catching any + * exceptions. \tparam F The function type. \tparam Args The argument types. + * \param func The function to be called. + * \param args The arguments to be passed to the function. + * \return A variant containing either the result of the function call or an + * exception pointer. + */ template requires std::is_invocable_v, std::decay_t...> auto safeTryCatch(F &&func, Args &&...args) { @@ -93,6 +160,14 @@ auto safeTryCatch(F &&func, Args &&...args) { } } +/*! + * \brief Safely tries to call a function with given arguments, returning a + * default value if an exception occurs. \tparam Func The function type. \tparam + * Args The argument types. \param func The function to be called. \param + * default_value The default value to return if an exception occurs. \param args + * The arguments to be passed to the function. \return The result of the + * function call, or the default value if an exception occurs. + */ template requires Invocable auto safeTryCatchOrDefault( @@ -108,6 +183,15 @@ auto safeTryCatchOrDefault( } } +/*! + * \brief Safely tries to call a function with given arguments, using a custom + * handler if an exception occurs. \tparam Func The function type. \tparam Args + * The argument types. \param func The function to be called. \param handler The + * custom handler to be called if an exception occurs. \param args The arguments + * to be passed to the function. \return The result of the function call, or a + * default-constructed value if an exception occurs and the handler does not + * rethrow. + */ template requires Invocable auto safeTryCatchWithCustomHandler( diff --git a/src/atom/function/overload.hpp b/src/atom/function/overload.hpp index 9ca5bf74..9ee53bcd 100644 --- a/src/atom/function/overload.hpp +++ b/src/atom/function/overload.hpp @@ -9,65 +9,118 @@ #ifndef ATOM_META_OVERLOAD_HPP #define ATOM_META_OVERLOAD_HPP -namespace atom::meta { +#include +#include -// Simplified OverloadCast with improved type deduction and usage +namespace atom::meta { +/// @brief A utility to simplify the casting of overloaded member functions and +/// free functions. 
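+/// @par Example
+/// A usage sketch; `Printer` is a hypothetical class with overloaded `print`
+/// member functions:
+/// @code
+/// struct Printer {
+///     void print(int);
+///     void print(const std::string&);
+/// };
+/// auto byInt = atom::meta::overload_cast<int>(&Printer::print);
+/// auto byStr = atom::meta::overload_cast<const std::string&>(&Printer::print);
+/// @endcode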
+/// @tparam Args The argument types of the function to be cast. template struct OverloadCast { + /// @brief Casts a non-const member function. + /// @tparam ReturnType The return type of the member function. + /// @tparam ClassType The class type of the member function. + /// @param func The member function pointer. + /// @return The casted member function pointer. template constexpr auto operator()( ReturnType (ClassType::*func)(Args...)) const noexcept { return func; } + /// @brief Casts a const member function. + /// @tparam ReturnType The return type of the member function. + /// @tparam ClassType The class type of the member function. + /// @param func The const member function pointer. + /// @return The casted const member function pointer. template constexpr auto operator()(ReturnType (ClassType::*func)(Args...) const) const noexcept { return func; } + /// @brief Casts a volatile member function. + /// @tparam ReturnType The return type of the member function. + /// @tparam ClassType The class type of the member function. + /// @param func The volatile member function pointer. + /// @return The casted volatile member function pointer. template constexpr auto operator()( ReturnType (ClassType::*func)(Args...) volatile) const noexcept { return func; } + /// @brief Casts a const volatile member function. + /// @tparam ReturnType The return type of the member function. + /// @tparam ClassType The class type of the member function. + /// @param func The const volatile member function pointer. + /// @return The casted const volatile member function pointer. template constexpr auto operator()(ReturnType (ClassType::*func)(Args...) const volatile) const noexcept { return func; } + /// @brief Casts a free function. + /// @tparam ReturnType The return type of the free function. + /// @param func The free function pointer. + /// @return The casted free function pointer. template constexpr auto operator()(ReturnType (*func)(Args...)) const noexcept { return func; } // Added noexcept overloads + + /// @brief Casts a non-const noexcept member function. + /// @tparam ReturnType The return type of the member function. + /// @tparam ClassType The class type of the member function. + /// @param func The noexcept member function pointer. + /// @return The casted noexcept member function pointer. template constexpr auto operator()( ReturnType (ClassType::*func)(Args...) noexcept) const noexcept { return func; } + /// @brief Casts a const noexcept member function. + /// @tparam ReturnType The return type of the member function. + /// @tparam ClassType The class type of the member function. + /// @param func The const noexcept member function pointer. + /// @return The casted const noexcept member function pointer. template constexpr auto operator()(ReturnType (ClassType::*func)(Args...) const noexcept) const noexcept { return func; } + /// @brief Casts a volatile noexcept member function. + /// @tparam ReturnType The return type of the member function. + /// @tparam ClassType The class type of the member function. + /// @param func The volatile noexcept member function pointer. + /// @return The casted volatile noexcept member function pointer. template constexpr auto operator()(ReturnType (ClassType::*func)( Args...) volatile noexcept) const noexcept { return func; } + /// @brief Casts a const volatile noexcept member function. + /// @tparam ReturnType The return type of the member function. + /// @tparam ClassType The class type of the member function. 
+ /// @param func The const volatile noexcept member function pointer. + /// @return The casted const volatile noexcept member function pointer. template constexpr auto operator()(ReturnType (ClassType::*func)(Args...) const volatile noexcept) const noexcept { return func; } + /// @brief Casts a noexcept free function. + /// @tparam ReturnType The return type of the free function. + /// @param func The noexcept free function pointer. + /// @return The casted noexcept free function pointer. template constexpr auto operator()( ReturnType (*func)(Args...) noexcept) const noexcept { @@ -75,10 +128,19 @@ struct OverloadCast { } }; -// Helper function to instantiate OverloadCast, simplified to improve usability +/// @brief Helper function to instantiate OverloadCast, simplified to improve +/// usability. +/// @tparam Args The argument types of the function to be cast. +/// @return An instance of OverloadCast with the specified argument types. template constexpr auto overload_cast = OverloadCast{}; +template +constexpr std::decay_t decay_copy(T&& value) noexcept( + std::is_nothrow_convertible_v>) { + return std::forward(value); // 将值转发并转换为衰减类型 +} + } // namespace atom::meta #endif // ATOM_META_OVERLOAD_HPP diff --git a/src/atom/image/CMakeLists.txt b/src/atom/image/CMakeLists.txt new file mode 100644 index 00000000..e69de29b diff --git a/src/atom/image/fits_data.cpp b/src/atom/image/fits_data.cpp new file mode 100644 index 00000000..ee9b1b8f --- /dev/null +++ b/src/atom/image/fits_data.cpp @@ -0,0 +1,62 @@ +#include "fits_data.hpp" +#include + +template +void TypedFITSData::readData(std::ifstream& file, int64_t dataSize) { + data.resize(dataSize / sizeof(T)); + file.read(reinterpret_cast(data.data()), dataSize); + + if (std::endian::native == std::endian::little) { + for (auto& value : data) { + swapEndian(value); + } + } +} + +template +void TypedFITSData::writeData(std::ofstream& file) const { + std::vector tempData = data; + if (std::endian::native == std::endian::little) { + for (auto& value : tempData) { + swapEndian(value); + } + } + + file.write(reinterpret_cast(tempData.data()), tempData.size() * sizeof(T)); + + // Pad the data to a multiple of 2880 bytes + size_t padding = (2880 - (tempData.size() * sizeof(T)) % 2880) % 2880; + std::vector paddingData(padding, 0); + file.write(paddingData.data(), padding); +} + +template +DataType TypedFITSData::getDataType() const { + if constexpr (std::is_same_v) return DataType::BYTE; + else if constexpr (std::is_same_v) return DataType::SHORT; + else if constexpr (std::is_same_v) return DataType::INT; + else if constexpr (std::is_same_v) return DataType::LONG; + else if constexpr (std::is_same_v) return DataType::FLOAT; + else if constexpr (std::is_same_v) return DataType::DOUBLE; + else throw std::runtime_error("Unsupported data type"); +} + +template +size_t TypedFITSData::getElementCount() const { + return data.size(); +} + +template +template +void TypedFITSData::swapEndian(U& value) { + uint8_t* bytes = reinterpret_cast(&value); + std::reverse(bytes, bytes + sizeof(U)); +} + +// Explicit template instantiations +template class TypedFITSData; +template class TypedFITSData; +template class TypedFITSData; +template class TypedFITSData; +template class TypedFITSData; +template class TypedFITSData; diff --git a/src/atom/image/fits_data.hpp b/src/atom/image/fits_data.hpp new file mode 100644 index 00000000..725b17b9 --- /dev/null +++ b/src/atom/image/fits_data.hpp @@ -0,0 +1,39 @@ +#pragma once + +#include +#include +#include +#include +#include 
+#include + +enum class DataType { + BYTE, SHORT, INT, LONG, FLOAT, DOUBLE +}; + +class FITSData { +public: + virtual ~FITSData() = default; + virtual void readData(std::ifstream& file, int64_t dataSize) = 0; + virtual void writeData(std::ofstream& file) const = 0; + virtual DataType getDataType() const = 0; + virtual size_t getElementCount() const = 0; +}; + +template +class TypedFITSData : public FITSData { +public: + void readData(std::ifstream& file, int64_t dataSize) override; + void writeData(std::ofstream& file) const override; + DataType getDataType() const override; + size_t getElementCount() const override; + + const std::vector& getData() const { return data; } + std::vector& getData() { return data; } + +private: + std::vector data; + + template + static void swapEndian(U& value); +}; diff --git a/src/atom/image/fits_example.cpp b/src/atom/image/fits_example.cpp new file mode 100644 index 00000000..3a69342a --- /dev/null +++ b/src/atom/image/fits_example.cpp @@ -0,0 +1,98 @@ +#include +#include +#include "fits_file.hpp" + +int main() { + try { + FITSFile fitsFile; + + // 创建一个简单的 10x10 彩色图像 + auto imageHDU = std::make_unique(); + imageHDU->setImageSize(10, 10, 3); // 3 channels for RGB + imageHDU->setHeaderKeyword("SIMPLE", "T"); + imageHDU->setHeaderKeyword("BITPIX", "16"); + imageHDU->setHeaderKeyword("NAXIS", "3"); + imageHDU->setHeaderKeyword("EXTEND", "T"); + + // 用渐变填充图像 + for (int y = 0; y < 10; ++y) { + for (int x = 0; x < 10; ++x) { + imageHDU->setPixel(x, y, + static_cast(x * 1000 / 9), + 0); // Red channel + imageHDU->setPixel(x, y, + static_cast(y * 1000 / 9), + 1); // Green channel + imageHDU->setPixel( + x, y, static_cast((x + y) * 500 / 9), + 2); // Blue channel + } + } + + fitsFile.addHDU(std::move(imageHDU)); + + // 写入文件 + fitsFile.writeFITS("test_color.fits"); + + // 读取文件 + FITSFile readFile; + readFile.readFITS("test_color.fits"); + + // 验证图像内容 + const auto& readHDU = dynamic_cast(readFile.getHDU(0)); + auto [width, height, channels] = readHDU.getImageSize(); + std::cout << "Image size: " << width << "x" << height << "x" << channels + << std::endl; + + // 显示每个通道的第一行 + for (int c = 0; c < channels; ++c) { + std::cout << "Channel " << c << ", first row:" << std::endl; + for (int x = 0; x < width; ++x) { + std::cout << std::setw(5) << readHDU.getPixel(x, 0, c) + << " "; + } + std::cout << std::endl; + } + + // 计算每个通道的图像统计信息 + for (int c = 0; c < channels; ++c) { + auto stats = readHDU.computeImageStats(c); + std::cout << "\nImage statistics for channel " << c << ":" + << std::endl; + std::cout << "Min: " << stats.min << std::endl; + std::cout << "Max: " << stats.max << std::endl; + std::cout << "Mean: " << stats.mean << std::endl; + std::cout << "StdDev: " << stats.stddev << std::endl; + } + + // 应用高斯模糊滤波器到绿色通道 + std::vector> gaussianKernel = { + {1 / 16.0, 1 / 8.0, 1 / 16.0}, + {1 / 8.0, 1 / 4.0, 1 / 8.0}, + {1 / 16.0, 1 / 8.0, 1 / 16.0}}; + + auto& editableHDU = dynamic_cast(readFile.getHDU(0)); + editableHDU.applyFilter(gaussianKernel, + 1); // Apply to green channel only + + std::cout << "\nAfter applying Gaussian blur to green channel:" + << std::endl; + for (int c = 0; c < channels; ++c) { + std::cout << "Channel " << c << ", first row:" << std::endl; + for (int x = 0; x < width; ++x) { + std::cout << std::setw(5) + << editableHDU.getPixel(x, 0, c) << " "; + } + std::cout << std::endl; + } + + // 将修改后的图像保存到新文件 + readFile.writeFITS("test_color_blurred.fits"); + + } catch (const std::exception& e) { + std::cerr << "Error: " << e.what() << std::endl; + return 
1; + } + + return 0; +} diff --git a/src/atom/image/fits_file.cpp b/src/atom/image/fits_file.cpp new file mode 100644 index 00000000..99d6b9dc --- /dev/null +++ b/src/atom/image/fits_file.cpp @@ -0,0 +1,48 @@ +#include "fits_file.hpp" +#include +#include + +void FITSFile::readFITS(const std::string& filename) { + std::ifstream file(filename, std::ios::binary); + if (!file) { + throw std::runtime_error("Cannot open file: " + filename); + } + + hdus.clear(); + while (file.peek() != EOF) { + auto hdu = std::make_unique(); + hdu->readHDU(file); + hdus.push_back(std::move(hdu)); + } +} + +void FITSFile::writeFITS(const std::string& filename) const { + std::ofstream file(filename, std::ios::binary); + if (!file) { + throw std::runtime_error("Cannot create file: " + filename); + } + + for (const auto& hdu : hdus) { + hdu->writeHDU(file); + } +} + +size_t FITSFile::getHDUCount() const { return hdus.size(); } + +const HDU& FITSFile::getHDU(size_t index) const { + if (index >= hdus.size()) { + throw std::out_of_range("HDU index out of range"); + } + return *hdus[index]; +} + +HDU& FITSFile::getHDU(size_t index) { + if (index >= hdus.size()) { + throw std::out_of_range("HDU index out of range"); + } + return *hdus[index]; +} + +void FITSFile::addHDU(std::unique_ptr hdu) { + hdus.push_back(std::move(hdu)); +} diff --git a/src/atom/image/fits_file.hpp b/src/atom/image/fits_file.hpp new file mode 100644 index 00000000..8d40c0f7 --- /dev/null +++ b/src/atom/image/fits_file.hpp @@ -0,0 +1,19 @@ +#pragma once + +#include "hdu.hpp" +#include +#include + +class FITSFile { +public: + void readFITS(const std::string& filename); + void writeFITS(const std::string& filename) const; + + size_t getHDUCount() const; + const HDU& getHDU(size_t index) const; + HDU& getHDU(size_t index); + void addHDU(std::unique_ptr hdu); + +private: + std::vector> hdus; +}; diff --git a/src/atom/image/fits_header.cpp b/src/atom/image/fits_header.cpp new file mode 100644 index 00000000..e69de29b diff --git a/src/atom/image/fits_header.hpp b/src/atom/image/fits_header.hpp new file mode 100644 index 00000000..911a10bf --- /dev/null +++ b/src/atom/image/fits_header.hpp @@ -0,0 +1,25 @@ +#pragma once + +#include +#include +#include +#include + +class FITSHeader { +public: + static constexpr int FITS_HEADER_UNIT_SIZE = 2880; + static constexpr int FITS_HEADER_CARD_SIZE = 80; + + struct KeywordRecord { + std::array keyword; + std::array value; + }; + + void addKeyword(const std::string& keyword, const std::string& value); + std::string getKeywordValue(const std::string& keyword) const; + std::vector serialize() const; + void deserialize(const std::vector& data); + +private: + std::vector records; +}; diff --git a/src/atom/image/hdu.cpp b/src/atom/image/hdu.cpp new file mode 100644 index 00000000..f531e583 --- /dev/null +++ b/src/atom/image/hdu.cpp @@ -0,0 +1,216 @@ +#include "hdu.hpp" +#include +#include +#include + +void HDU::setHeaderKeyword(const std::string& keyword, + const std::string& value) { + header.addKeyword(keyword, value); +} + +std::string HDU::getHeaderKeyword(const std::string& keyword) const { + return header.getKeywordValue(keyword); +} + +void ImageHDU::readHDU(std::ifstream& file) { + std::vector headerData(FITSHeader::FITS_HEADER_UNIT_SIZE); + file.read(headerData.data(), headerData.size()); + header.deserialize(headerData); + + width = std::stoi(header.getKeywordValue("NAXIS1")); + height = std::stoi(header.getKeywordValue("NAXIS2")); + channels = std::stoi(header.getKeywordValue("NAXIS3")); + int bitpix = 
std::stoi(header.getKeywordValue("BITPIX")); + + switch (bitpix) { + case 8: + initializeData(); + break; + case 16: + initializeData(); + break; + case 32: + initializeData(); + break; + case 64: + initializeData(); + break; + case -32: + initializeData(); + break; + case -64: + initializeData(); + break; + default: + throw std::runtime_error("Unsupported BITPIX value"); + } + + int64_t dataSize = + static_cast(width) * height * channels * std::abs(bitpix) / 8; + data->readData(file, dataSize); +} + +void ImageHDU::writeHDU(std::ofstream& file) const { + auto headerData = header.serialize(); + file.write(headerData.data(), headerData.size()); + data->writeData(file); +} + +void ImageHDU::setImageSize(int w, int h, int c) { + width = w; + height = h; + channels = c; + header.addKeyword("NAXIS1", std::to_string(width)); + header.addKeyword("NAXIS2", std::to_string(height)); + if (channels > 1) { + header.addKeyword("NAXIS", "3"); + header.addKeyword("NAXIS3", std::to_string(channels)); + } else { + header.addKeyword("NAXIS", "2"); + } +} + +std::tuple ImageHDU::getImageSize() const { + return {width, height, channels}; +} + +template +void ImageHDU::setPixel(int x, int y, T value, int channel) { + if (x < 0 || x >= width || y < 0 || y >= height || channel < 0 || + channel >= channels) { + throw std::out_of_range("Pixel coordinates or channel out of range"); + } + auto& typedData = static_cast&>(*data); + typedData.getData()[(y * width + x) * channels + channel] = value; +} + +template +T ImageHDU::getPixel(int x, int y, int channel) const { + if (x < 0 || x >= width || y < 0 || y >= height || channel < 0 || + channel >= channels) { + throw std::out_of_range("Pixel coordinates or channel out of range"); + } + const auto& typedData = static_cast&>(*data); + return typedData.getData()[(y * width + x) * channels + channel]; +} + +template +typename ImageHDU::template ImageStats ImageHDU::computeImageStats( + int channel) const { + const auto& typedData = static_cast&>(*data); + const auto& pixelData = typedData.getData(); + + T min = std::numeric_limits::max(); + T max = std::numeric_limits::lowest(); + double sum = 0.0; + + for (int i = channel; i < pixelData.size(); i += channels) { + T pixel = pixelData[i]; + min = std::min(min, pixel); + max = std::max(max, pixel); + sum += static_cast(pixel); + } + + size_t pixelCount = width * height; + double mean = sum / pixelCount; + double variance = 0.0; + + for (int i = channel; i < pixelData.size(); i += channels) { + double diff = static_cast(pixelData[i]) - mean; + variance += diff * diff; + } + + variance /= pixelCount; + double stddev = std::sqrt(variance); + + return {min, max, mean, stddev}; +} + +template +void ImageHDU::applyFilter(const std::vector>& kernel, + int channel) { + auto& typedData = static_cast&>(*data); + auto& pixelData = typedData.getData(); + + int kernelHeight = kernel.size(); + int kernelWidth = kernel[0].size(); + int kernelCenterY = kernelHeight / 2; + int kernelCenterX = kernelWidth / 2; + + std::vector newPixelData(pixelData.size()); + + for (int c = 0; c < channels; ++c) { + if (channel != -1 && c != channel) + continue; + + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + double sum = 0.0; + for (int ky = 0; ky < kernelHeight; ++ky) { + for (int kx = 0; kx < kernelWidth; ++kx) { + int imgY = y + ky - kernelCenterY; + int imgX = x + kx - kernelCenterX; + if (imgY >= 0 && imgY < height && imgX >= 0 && + imgX < width) { + sum += + kernel[ky][kx] * + pixelData[(imgY * width + imgX) * channels + 
c]; + } + } + } + newPixelData[(y * width + x) * channels + c] = + static_cast(sum); + } + } + } + + pixelData = std::move(newPixelData); +} + +template +void ImageHDU::initializeData() { + data = std::make_unique>(); + auto& typedData = static_cast&>(*data); + typedData.getData().resize(width * height * channels); +} + +// Explicit template instantiations +template void ImageHDU::setPixel(int, int, uint8_t, int); +template void ImageHDU::setPixel(int, int, int16_t, int); +template void ImageHDU::setPixel(int, int, int32_t, int); +template void ImageHDU::setPixel(int, int, int64_t, int); +template void ImageHDU::setPixel(int, int, float, int); +template void ImageHDU::setPixel(int, int, double, int); + +template uint8_t ImageHDU::getPixel(int, int, int) const; +template int16_t ImageHDU::getPixel(int, int, int) const; +template int32_t ImageHDU::getPixel(int, int, int) const; +template int64_t ImageHDU::getPixel(int, int, int) const; +template float ImageHDU::getPixel(int, int, int) const; +template double ImageHDU::getPixel(int, int, int) const; + +template ImageHDU::ImageStats ImageHDU::computeImageStats( + int) const; +template ImageHDU::ImageStats ImageHDU::computeImageStats( + int) const; +template ImageHDU::ImageStats ImageHDU::computeImageStats( + int) const; +template ImageHDU::ImageStats ImageHDU::computeImageStats( + int) const; +template ImageHDU::ImageStats ImageHDU::computeImageStats( + int) const; +template ImageHDU::ImageStats ImageHDU::computeImageStats( + int) const; + +template void ImageHDU::applyFilter( + const std::vector>&, int); +template void ImageHDU::applyFilter( + const std::vector>&, int); +template void ImageHDU::applyFilter( + const std::vector>&, int); +template void ImageHDU::applyFilter( + const std::vector>&, int); +template void ImageHDU::applyFilter( + const std::vector>&, int); +template void ImageHDU::applyFilter( + const std::vector>&, int); diff --git a/src/atom/image/hdu.hpp b/src/atom/image/hdu.hpp new file mode 100644 index 00000000..6f316788 --- /dev/null +++ b/src/atom/image/hdu.hpp @@ -0,0 +1,66 @@ +#pragma once + +#include +#include +#include +#include "fits_data.hpp" +#include "fits_header.hpp" + +class HDU { +public: + virtual ~HDU() = default; + virtual void readHDU(std::ifstream& file) = 0; + virtual void writeHDU(std::ofstream& file) const = 0; + + const FITSHeader& getHeader() const { return header; } + FITSHeader& getHeader() { return header; } + + void setHeaderKeyword(const std::string& keyword, const std::string& value); + std::string getHeaderKeyword(const std::string& keyword) const; + +protected: + FITSHeader header; + std::unique_ptr data; +}; + +class ImageHDU : public HDU { +public: + void readHDU(std::ifstream& file) override; + void writeHDU(std::ofstream& file) const override; + + void setImageSize(int w, int h, int c = 1); + std::tuple getImageSize() const; + + template + void setPixel(int x, int y, T value, int channel = 0); + + template + T getPixel(int x, int y, int channel = 0) const; + + template + struct ImageStats { + T min; + T max; + double mean; + double stddev; + }; + + template + ImageStats computeImageStats(int channel = 0) const; + + template + void applyFilter(const std::vector>& kernel, + int channel = -1); + + // New methods for color image support + bool isColor() const { return channels > 1; } + int getChannelCount() const { return channels; } + +private: + int width = 0; + int height = 0; + int channels = 1; + + template + void initializeData(); +}; diff --git a/src/atom/image/image_blob.hpp 
b/src/atom/image/image_blob.hpp new file mode 100644 index 00000000..42050281 --- /dev/null +++ b/src/atom/image/image_blob.hpp @@ -0,0 +1,326 @@ +#ifndef ATOM_IMAGE_BLOB_HPP +#define ATOM_IMAGE_BLOB_HPP + +#include +#include +#include +#include +#include +#include + +#include "atom/error/exception.hpp" + +// Max: Here we need to include opencv2/core.hpp and opencv2/imgproc.hpp to use cv::Mat. +#if __has_include() +#include +#include +#endif + +namespace atom::image { + +template +concept BlobType = + std::same_as || std::same_as; + +template +concept BlobValueType = std::is_trivially_copyable_v; + +enum class BlobMode { NORMAL, FAST }; + +template +class Blob { +private: + std::conditional_t, std::vector> + storage_; + int rows_ = 0; + int cols_ = 0; + int channels_ = 1; +#if __has_include() + int depth = CV_8U; +#else + int depth_ = 8; +#endif +public: + Blob() noexcept = default; + Blob(const Blob&) = default; + Blob(Blob&&) noexcept = default; + auto operator=(const Blob&) -> Blob& = default; + auto operator=(Blob&&) noexcept -> Blob& = default; + + template + requires(std::is_const_v && std::same_as) + explicit Blob(const Blob& that) noexcept + : storage_(that.storage), + rows_(that.rows), + cols_(that.cols), + channels_(that.channels), + depth_(that.depth) {} + + template + requires(std::is_const_v && std::same_as) + explicit Blob(Blob&& that) noexcept + : storage_(std::move(that.storage)), + rows_(that.rows), + cols_(that.cols), + channels_(that.channels), + depth_(that.depth) {} + + Blob(void* ptr, size_t n) noexcept + requires(Mode == BlobMode::FAST) + : storage_(reinterpret_cast(ptr), n) {} + + template + explicit Blob(U& var) noexcept + requires(Mode == BlobMode::FAST) + : storage_(reinterpret_cast(&var), sizeof(U)) {} + + template + Blob(U* ptr, size_t n) + : storage_(Mode == BlobMode::FAST + ? std::span(reinterpret_cast(ptr), n * sizeof(U)) + : std::vector( + reinterpret_cast(ptr), + reinterpret_cast(ptr) + n * sizeof(U))) {} + + template + explicit Blob(U (&arr)[N]) + : storage_(Mode == BlobMode::FAST + ? 
std::span(reinterpret_cast(arr), sizeof(U) * N) + : std::vector( + reinterpret_cast(arr), + reinterpret_cast(arr) + sizeof(U) * N)) {} + +#if __has_include() + explicit blob_(const cv::Mat& mat) + : rows(mat.rows), + cols(mat.cols), + channels(mat.channels()), + depth(mat.depth()) { + if (mat.isContinuous()) { + if constexpr (Mode == BlobMode::Fast) { + storage = std::span(reinterpret_cast(mat.data), + mat.total() * mat.elemSize()); + } else { + storage.assign(reinterpret_cast(mat.data), + reinterpret_cast(mat.data) + + mat.total() * mat.elemSize()); + } + } else { + if constexpr (Mode == BlobMode::Fast) { + THROW_RUNTIME_ERROR( + "Cannot create fast blob from non-continuous cv::Mat"); + } else { + storage.reserve(mat.total() * mat.elemSize()); + for (int i = 0; i < mat.rows; ++i) { + storage.insert(storage.end(), mat.ptr(i), + mat.ptr(i) + mat.cols * mat.elemSize()); + } + } + } + } +#endif + + auto operator[](size_t idx) -> T& { return storage_[idx]; } + auto operator[](size_t idx) const -> const T& { return storage_[idx]; } + + auto begin() { return storage_.begin(); } + auto end() { return storage_.end(); } + auto begin() const { return storage_.begin(); } + auto end() const { return storage_.end(); } + + [[nodiscard]] auto size() const -> size_t { return storage_.size(); } + + auto slice(size_t offset, size_t length) const -> Blob { + if (offset + length > size()) { + THROW_OUT_OF_RANGE("Slice range out of bounds"); + } + Blob result; + result.rows_ = 1; + result.cols_ = length; + result.channels_ = channels_; + result.depth_ = depth_; + if constexpr (Mode == BlobMode::FAST) { + result.storage_ = std::span(storage_.data() + offset, length); + } else { + result.storage_.assign(storage_.begin() + offset, + storage_.begin() + offset + length); + } + return result; + } + + auto operator==(const Blob& other) const -> bool { + return std::ranges::equal(storage_, other.storage_) && + rows_ == other.rows_ && cols_ == other.cols_ && + channels_ == other.channels_ && depth_ == other.depth_; + } + + void fill(T value) { std::ranges::fill(storage_, value); } + + void append(const Blob& other) { + if constexpr (Mode == BlobMode::FAST) { + THROW_RUNTIME_ERROR("Cannot append in Fast mode"); + } else { + storage_.insert(storage_.end(), other.storage_.begin(), + other.storage_.end()); + // Update properties (assuming appending along rows) + rows_ += other.rows_; + } + } + + void append(const void* ptr, size_t n) { + if constexpr (Mode == BlobMode::FAST) { + THROW_RUNTIME_ERROR("Cannot append in Fast mode"); + } else { + const auto* bytePtr = reinterpret_cast(ptr); + storage_.insert(storage_.end(), bytePtr, bytePtr + n); + // Update properties (assuming appending along rows) + rows_ += n / (cols_ * channels_); + } + } + + void allocate(size_t size) { + if constexpr (Mode == BlobMode::FAST) { + THROW_RUNTIME_ERROR("Cannot allocate in Fast mode"); + } else { + storage_.resize(size); + } + } + + void deallocate() { + if constexpr (Mode == BlobMode::FAST) { + THROW_RUNTIME_ERROR("Cannot deallocate in Fast mode"); + } else { + storage_.clear(); + storage_.shrink_to_fit(); + } + } + + void xorWith(const Blob& other) { + if (size() != other.size()) { + THROW_RUNTIME_ERROR( + "Blobs must be of the same size for XOR operation"); + } + std::transform(begin(), end(), other.begin(), begin(), std::bit_xor()); + } + + auto compress() const -> Blob { + Blob compressed; + for (size_t i = 0; i < size(); ++i) { + T current = storage_[i]; + size_t count = 1; + while (i + 1 < size() && storage_[i + 1] == current) { + 
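+                // run-length encode: extend the current run across
+                // consecutive equal elements before emitting (value, count)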
++count; + ++i; + } + compressed.append(¤t, sizeof(T)); + compressed.append(&count, sizeof(size_t)); + } + return compressed; + } + + auto decompress() const -> Blob { + Blob decompressed; + for (size_t i = 0; i < size(); i += sizeof(T) + sizeof(size_t)) { + T value; + size_t count; + std::memcpy(&value, &storage_[i], sizeof(T)); + std::memcpy(&count, &storage_[i + sizeof(T)], sizeof(size_t)); + for (size_t j = 0; j < count; ++j) { + decompressed.append(&value, sizeof(T)); + } + } + return decompressed; + } + + auto serialize() const -> std::vector { + std::vector serialized; + size_t size = storage_.size(); + serialized.insert( + serialized.end(), reinterpret_cast(&size), + reinterpret_cast(&size) + sizeof(size_t)); + serialized.insert(serialized.end(), + reinterpret_cast(storage_.data()), + reinterpret_cast(storage_.data()) + + storage_.size()); + return serialized; + } + + static auto deserialize(const std::vector& data) -> Blob { + if (data.size() < sizeof(size_t)) { + THROW_RUNTIME_ERROR("Invalid serialized data"); + } + size_t size; + std::memcpy(&size, data.data(), sizeof(size_t)); + if (data.size() != sizeof(size_t) + size) { + THROW_RUNTIME_ERROR("Invalid serialized data size"); + } + return Blob(reinterpret_cast(data.data() + sizeof(size_t)), + size); + } + +#if __has_include() + cv::Mat to_mat() const { + int type = CV_MAKETYPE(depth, channels); + cv::Mat mat(rows, cols, type); + std::memcpy(mat.data, storage.data(), size()); + return mat; + } + + void apply_filter(const cv::Mat& kernel) { + cv::Mat src = to_mat(); + cv::Mat dst; + cv::filter2D(src, dst, -1, kernel); + *this = blob_(dst); + } + + void resize(int new_rows, int new_cols) { + cv::Mat src = to_mat(); + cv::Mat dst; + cv::resize(src, dst, cv::Size(new_cols, new_rows)); + *this = blob_(dst); + } + + void convert_color(int code) { + cv::Mat src = to_mat(); + cv::Mat dst; + cv::cvtColor(src, dst, code); + *this = blob_(dst); + } + + std::vector split_channels() const { + cv::Mat src = to_mat(); + std::vector channels; + cv::split(src, channels); + std::vector channel_blobs; + for (const auto& channel : channels) { + channel_blobs.emplace_back(channel); + } + return channel_blobs; + } + + static blob_ merge_channels(const std::vector& channel_blobs) { + std::vector channels; + for (const auto& blob : channel_blobs) { + channels.push_back(blob.to_mat()); + } + cv::Mat merged; + cv::merge(channels, merged); + return blob_(merged); + } +#endif + + [[nodiscard]] auto get_rows() const -> int { return rows_; } + [[nodiscard]] auto get_cols() const -> int { return cols_; } + [[nodiscard]] auto get_channels() const -> int { return channels_; } + [[nodiscard]] auto get_depth() const -> int { return depth_; } +}; + +using cblob = Blob; +using blob = Blob; + +using fast_cblob = Blob; +using fast_blob = Blob; + +} // namespace atom::image + +#endif // ATOM_IMAGE_BLOB_HPP diff --git a/src/atom/io/asyncio.cpp b/src/atom/io/asyncio.cpp new file mode 100644 index 00000000..4fabf845 --- /dev/null +++ b/src/atom/io/asyncio.cpp @@ -0,0 +1,231 @@ +#include "asyncio.hpp" + +#include + +#include "atom/error/exception.hpp" +#include "atom/log/loguru.hpp" + +namespace atom::io { +FileWriter::FileWriter(handle_type h) : coro_handle(h) {} +FileWriter::~FileWriter() { + if (coro_handle) { + coro_handle.destroy(); + } +} + +auto FileWriter::promise_type::get_return_object() { + return FileWriter{handle_type::from_promise(*this)}; +} +auto FileWriter::promise_type::initial_suspend() noexcept { + return std::suspend_never{}; +} +auto 
FileWriter::promise_type::final_suspend() noexcept { + return std::suspend_never{}; +} +void FileWriter::promise_type::return_void() {} +void FileWriter::promise_type::unhandled_exception() { std::terminate(); } + +bool FileWriter::await_ready() const noexcept { return false; } +void FileWriter::await_suspend(std::coroutine_handle<> h) { + coro_handle = handle_type::from_address(h.address()); +} +void FileWriter::await_resume() const {} + +FileWriter async_write(const std::string& filename, const std::string& data) { +#ifdef USE_ASIO + asio::io_context io_context; + asio::posix::stream_descriptor file( + io_context, ::open(filename.c_str(), O_WRONLY | O_CREAT, 0644)); + if (!file.is_open()) { + THROW_FAIL_TO_OPEN_FILE(errno, std::generic_category(), + "Failed to open file"); + } + + asio::async_write( + file, asio::buffer(data), [](std::error_code ec, std::size_t) { + if (ec) { + LOG_F(ERROR, "async_write failed with error: {}", ec.message()); + } + }); + + io_context.run(); +#else + std::ofstream file(filename, std::ios::binary | std::ios::out); + if (!file.is_open()) { + THROW_FAIL_TO_OPEN_FILE("Failed to open file"); + } + +#ifdef _WIN32 + OVERLAPPED overlapped = {0}; + HANDLE hFile = CreateFile(filename.c_str(), GENERIC_WRITE, 0, NULL, + CREATE_ALWAYS, FILE_FLAG_OVERLAPPED, NULL); + if (hFile == INVALID_HANDLE_VALUE) { + THROW_FAIL_TO_CREATE_FILE(GetLastError(), std::system_category(), + "Failed to create file"); + } + + WriteFileEx(hFile, data.c_str(), data.size(), &overlapped, + [](DWORD dwErrorCode, DWORD /*dwNumberOfBytesTransfered*/, + LPOVERLAPPED /*lpOverlapped*/) { + if (dwErrorCode != 0) { + LOG_F(ERROR, "WriteFileEx failed with error: {}", + dwErrorCode); + } + }); + + SleepEx(INFINITE, TRUE); + CloseHandle(hFile); +#else + aiocb cb = {0}; + cb.aio_fildes = open(filename.c_str(), O_WRONLY | O_CREAT, 0644); + if (cb.aio_fildes == -1) { + THROW_FAIL_TO_OPEN_FILE(errno, std::generic_category(), + "Failed to open file"); + } + + cb.aio_buf = data.c_str(); + cb.aio_nbytes = data.size(); + cb.aio_offset = 0; + + if (aio_write(&cb) == -1) { + THROW_FAIL_TO_WRITE_FILE(errno, std::generic_category(), + "Failed to initiate aio_write"); + } + + while (aio_error(&cb) == EINPROGRESS) { + // Wait for the write to complete + } + + if (aio_return(&cb) == -1) { + THROW_FAIL_TO_WRITE_FILE(errno, std::generic_category(), + "aio_write failed"); + } + + close(cb.aio_fildes); +#endif +#endif + + co_return; +} + +FileWriter async_read(const std::string& filename, std::string& data, + std::size_t size) { +#ifdef USE_ASIO + asio::io_context io_context; + asio::posix::stream_descriptor file(io_context, + ::open(filename.c_str(), O_RDONLY)); + if (!file.is_open()) { + THROW_FAIL_TO_OPEN_FILE(errno, std::generic_category(), + "Failed to open file"); + } + + data.resize(size); + asio::async_read( + file, asio::buffer(data), [](std::error_code ec, std::size_t) { + if (ec) { + LOG_F(ERROR, "async_read failed with error: {}", ec.message()); + } + }); + + io_context.run(); +#else +#ifdef _WIN32 + OVERLAPPED overlapped = {0}; + HANDLE hFile = CreateFile(filename.c_str(), GENERIC_READ, 0, NULL, + OPEN_EXISTING, FILE_FLAG_OVERLAPPED, NULL); + if (hFile == INVALID_HANDLE_VALUE) { + THROW_FAIL_TO_OPEN_FILE(GetLastError(), std::system_category(), + "Failed to open file"); + } + + data.resize(size); + ReadFileEx(hFile, &data[0], size, &overlapped, + [](DWORD dwErrorCode, DWORD /*dwNumberOfBytesTransfered*/, + LPOVERLAPPED /*lpOverlapped*/) { + if (dwErrorCode != 0) { + LOG_F(ERROR, "ReadFileEx failed with error: 
{}", + dwErrorCode); + } + }); + + SleepEx(INFINITE, TRUE); + CloseHandle(hFile); +#else + aiocb cb = {0}; + cb.aio_fildes = open(filename.c_str(), O_RDONLY); + if (cb.aio_fildes == -1) { + THROW_FAIL_TO_OPEN_FILE(errno, std::generic_category(), + "Failed to open file"); + } + + data.resize(size); + cb.aio_buf = &data[0]; + cb.aio_nbytes = size; + cb.aio_offset = 0; + + if (aio_read(&cb) == -1) { + THROW_FAIL_TO_WRITE_FILE(errno, std::generic_category(), + "Failed to initiate aio_read"); + } + + while (aio_error(&cb) == EINPROGRESS) { + // Wait for the read to complete + } + + if (aio_return(&cb) == -1) { + THROW_FAIL_TO_WRITE_FILE(errno, std::generic_category(), + "aio_read failed"); + } + + close(cb.aio_fildes); +#endif +#endif + + co_return; +} + +FileWriter async_delete(const std::string& filename) { +#ifdef USE_ASIO + asio::io_context io_context; + io_context.post([filename]() { + if (std::remove(filename.c_str()) != 0) { + LOG_F(ERROR, "Failed to delete file: {}", filename); + } + }); + io_context.run(); +#else + if (std::remove(filename.c_str()) != 0) { + THROW_FAIL_TO_DELETE_FILE(errno, std::generic_category(), + "Failed to delete file"); + } +#endif + + co_return; +} + +FileWriter async_copy(const std::string& src_filename, + const std::string& dest_filename) { +#ifdef USE_ASIO + asio::io_context io_context; + io_context.post([src_filename, dest_filename]() { + std::ifstream src(src_filename, std::ios::binary); + std::ofstream dest(dest_filename, std::ios::binary); + if (!src.is_open() || !dest.is_open()) { + LOG_F(ERROR, "Failed to open source or destination file"); + return; + } + dest << src.rdbuf(); + }); + io_context.run(); +#else + std::ifstream src(src_filename, std::ios::binary); + std::ofstream dest(dest_filename, std::ios::binary); + if (!src.is_open() || !dest.is_open()) { + THROW_FAIL_TO_OPEN_FILE("Failed to open source or destination file"); + } + dest << src.rdbuf(); +#endif + + co_return; +} +} // namespace atom::io diff --git a/src/atom/io/asyncio.hpp b/src/atom/io/asyncio.hpp new file mode 100644 index 00000000..50607c21 --- /dev/null +++ b/src/atom/io/asyncio.hpp @@ -0,0 +1,51 @@ +#ifndef ATOM_IO_ASYNC_IO_HPP +#define ATOM_IO_ASYNC_IO_HPP + +#ifdef _WIN32 +#include +#else +#include +#include +#include +#include +#endif + +#ifdef USE_ASIO +#include +#endif + +#include +#include + +namespace atom::io { +struct FileWriter { + struct promise_type; + using handle_type = std::coroutine_handle; + + handle_type coro_handle; + + FileWriter(handle_type h); + ~FileWriter(); + + struct promise_type { + auto get_return_object(); + auto initial_suspend() noexcept; + auto final_suspend() noexcept; + void return_void(); + void unhandled_exception(); + }; + + bool await_ready() const noexcept; + void await_suspend(std::coroutine_handle<> h); + void await_resume() const; +}; + +FileWriter async_write(const std::string& filename, const std::string& data); +FileWriter async_read(const std::string& filename, std::string& data, + std::size_t size); +FileWriter async_delete(const std::string& filename); +FileWriter async_copy(const std::string& src_filename, + const std::string& dest_filename); +} // namespace atom::io + +#endif // ATOM_IO_ASYNC_IO_HPP diff --git a/src/atom/sysinfo/cpu.hpp b/src/atom/sysinfo/cpu.hpp index 993e79d0..3130ddda 100644 --- a/src/atom/sysinfo/cpu.hpp +++ b/src/atom/sysinfo/cpu.hpp @@ -20,84 +20,119 @@ Description: System Information Module - CPU #include "macro.hpp" namespace atom::system { +/** + * @brief A structure to hold the sizes of the CPU caches. 
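// Editor's note: a minimal usage sketch for the coroutine-based file helpers declared
// in asyncio.hpp above. It assumes the eager semantics implied by initial_suspend()
// returning std::suspend_never (each call finishes its I/O before returning); the file
// names and payload are illustrative, not part of the patch. Because final_suspend()
// also never suspends, the coroutine frame is already gone when the returned FileWriter
// is destroyed, so callers should treat the return value as opaque.
#include <iostream>
#include <string>

#include "atom/io/asyncio.hpp"

int main() {
    using namespace atom::io;

    async_write("demo.txt", "hello, async world\n");  // runs to completion eagerly

    std::string contents;
    async_read("demo.txt", contents, 19);             // 19 bytes were written above
    std::cout << contents;

    async_copy("demo.txt", "demo_copy.txt");
    async_delete("demo.txt");
    return 0;
}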
+ * + * This structure contains the sizes of the L1 data cache, L1 instruction cache, + * L2 cache, and L3 cache of the CPU. These values are typically provided in + * bytes. The cache size information is important for performance tuning and + * understanding the CPU's capabilities. + */ struct CacheSizes { - size_t l1d; - size_t l1i; - size_t l2; - size_t l3; -} ATOM_ALIGNAS(32); + size_t l1d; ///< The size of the L1 data cache in bytes. + size_t l1i; ///< The size of the L1 instruction cache in bytes. + size_t l2; ///< The size of the L2 cache in bytes. + size_t l3; ///< The size of the L3 cache in bytes. +} ATOM_ALIGNAS(32); ///< Ensure the structure is aligned to a 32-byte boundary. /** - * @brief Get the CPU usage percentage. - * 获取 CPU 使用率百分比 + * @brief Retrieves the current CPU usage percentage. * - * @return The CPU usage percentage. - * CPU 使用率百分比 + * This function calculates and returns the current percentage of CPU usage. + * It typically measures how much time the CPU spends in active processing + * compared to idle time. The CPU usage percentage can be useful for + * monitoring system performance and detecting high load conditions. + * + * @return A float representing the current CPU usage as a percentage (0.0 to + * 100.0). */ -auto getCurrentCpuUsage() -> float; +[[nodiscard]] auto getCurrentCpuUsage() -> float; /** - * @brief Get the CPU temperature. - * 获取 CPU 温度 + * @brief Retrieves the current CPU temperature. + * + * This function returns the current temperature of the CPU in degrees Celsius. + * Monitoring the CPU temperature is important for preventing overheating and + * ensuring optimal performance. If the temperature is too high, it could + * indicate cooling issues or high load on the system. * - * @return The CPU temperature. - * CPU 温度 + * @return A float representing the CPU temperature in degrees Celsius. */ -auto getCurrentCpuTemperature() -> float; +[[nodiscard]] auto getCurrentCpuTemperature() -> float; /** - * @brief Get the CPU model. - * 获取 CPU 型号 + * @brief Retrieves the CPU model name. + * + * This function returns a string that contains the model name of the CPU. + * The CPU model provides information about the manufacturer and specific model + * (e.g., "Intel Core i7-10700K", "AMD Ryzen 9 5900X"). This information can be + * useful for system diagnostics and performance evaluations. * - * @return The CPU model. - * CPU 型号 + * @return A string representing the CPU model name. */ -auto getCPUModel() -> std::string; +[[nodiscard]] auto getCPUModel() -> std::string; /** - * @brief Get the CPU identifier. - * 获取 CPU 标识 + * @brief Retrieves the CPU identifier. * - * @return The CPU identifier. - * CPU 标识 + * This function returns a unique string identifier for the CPU. This identifier + * typically includes details about the CPU architecture, model, stepping, and + * other low-level characteristics. It can be useful for identifying specific + * processor versions in a system. + * + * @return A string representing the CPU identifier. */ -auto getProcessorIdentifier() -> std::string; +[[nodiscard]] auto getProcessorIdentifier() -> std::string; /** - * @brief Get the CPU frequency. - * 获取 CPU 频率 + * @brief Retrieves the current CPU frequency. + * + * This function returns the current operating frequency of the CPU in GHz. + * The frequency may vary depending on system load, power settings, and + * the capabilities of the CPU (e.g., turbo boost, power-saving modes). 
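// Editor's note: a short sketch combining the accessors documented in this header
// (atom/sysinfo/cpu.hpp) into a one-shot report; only functions declared here are
// used, and the output formatting is illustrative.
#include <cstdio>

#include "atom/sysinfo/cpu.hpp"

void printCpuSummary() {
    using namespace atom::system;

    std::printf("model       : %s\n", getCPUModel().c_str());
    std::printf("identifier  : %s\n", getProcessorIdentifier().c_str());
    std::printf("usage       : %.1f %%\n", getCurrentCpuUsage());
    std::printf("temperature : %.1f degC\n", getCurrentCpuTemperature());
    std::printf("frequency   : %.2f GHz\n", getProcessorFrequency());
    std::printf("packages    : %d, logical cores: %d\n",
                getNumberOfPhysicalPackages(), getNumberOfPhysicalCPUs());

    CacheSizes caches = getCacheSizes();
    std::printf("L1d/L1i/L2/L3: %zu / %zu / %zu / %zu bytes\n", caches.l1d,
                caches.l1i, caches.l2, caches.l3);
}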
+ * Understanding the CPU frequency is important for performance tuning and + * optimizing application performance. * - * @return The CPU frequency. - * CPU 频率 + * @return A double representing the CPU frequency in gigahertz (GHz). */ -auto getProcessorFrequency() -> double; +[[nodiscard]] auto getProcessorFrequency() -> double; /** - * @brief Get the number of physical CPUs. - * 获取物理 CPU 数量 + * @brief Retrieves the number of physical CPU packages. + * + * This function returns the number of physical CPU packages (sockets) installed + * in the system. A system may have multiple CPU packages, especially in + * server configurations. Each physical package may contain multiple cores. * - * @return The number of physical CPUs. - * 物理 CPU 数量 + * @return An integer representing the number of physical CPU packages. */ -auto getNumberOfPhysicalPackages() -> int; +[[nodiscard]] auto getNumberOfPhysicalPackages() -> int; /** - * @brief Get the number of logical CPUs. - * 获取逻辑 CPU 数量 + * @brief Retrieves the number of logical CPUs (cores). * - * @return The number of logical CPUs. - * 逻辑 CPU 数量 + * This function returns the number of logical CPUs (cores) available in the + * system. Logical CPUs include both physical cores and additional virtual cores + * created by technologies like Hyper-Threading (on Intel processors) or SMT + * (on AMD processors). This value represents the total number of logical + * processors that the operating system can use. + * + * @return An integer representing the total number of logical CPUs (cores). */ -auto getNumberOfPhysicalCPUs() -> int; +[[nodiscard]] auto getNumberOfPhysicalCPUs() -> int; /** - * @brief Get the cache sizes. - * 获取缓存大小 + * @brief Retrieves the sizes of the CPU caches (L1, L2, L3). + * + * This function returns a `CacheSizes` structure that contains the sizes of + * the L1 data cache (L1D), L1 instruction cache (L1I), L2 cache, and L3 cache. + * Cache sizes play an important role in determining CPU performance, as + * larger caches can improve data locality and reduce memory latency. * - * @return The cache sizes. - * 缓存大小 + * @return A `CacheSizes` structure containing the sizes of the L1, L2, and L3 + * caches in bytes. */ -auto getCacheSizes() -> CacheSizes; +[[nodiscard]] auto getCacheSizes() -> CacheSizes; } // namespace atom::system diff --git a/src/atom/sysinfo/disk.cpp b/src/atom/sysinfo/disk.cpp index 2babba7b..c9fb5fb4 100644 --- a/src/atom/sysinfo/disk.cpp +++ b/src/atom/sysinfo/disk.cpp @@ -14,8 +14,6 @@ Description: System Information Module - Disk #include "atom/sysinfo/disk.hpp" -#include "atom/log/loguru.hpp" - #include #include #include @@ -43,6 +41,8 @@ Description: System Information Module - Disk #include #endif +#include "atom/log/loguru.hpp" + namespace fs = std::filesystem; namespace atom::system { @@ -278,7 +278,7 @@ bool checkForMaliciousFiles(const std::string& path) { } bool isDeviceInWhiteList(const std::string& deviceID) { - if (whiteList.find(deviceID) != whiteList.end()) { + if (whiteList.contains(deviceID)) { LOG_F(INFO, "Device {} is in the whitelist. 
Access granted.", deviceID);
         return true;
     }
@@ -286,6 +286,97 @@ bool isDeviceInWhiteList(const std::string& deviceID) {
     return false;
 }
 
+auto getFileSystemType(const std::string& path) -> std::string {
+#ifdef _WIN32
+    char fileSystemNameBuffer[MAX_PATH] = {0};
+    if (!GetVolumeInformationA(path.c_str(),          // Root directory path
+                               NULL,                  // Volume name
+                               0,                     // Volume name buffer size
+                               NULL,                  // Volume serial number
+                               NULL,                  // Maximum component length
+                               NULL,                  // File system flags
+                               fileSystemNameBuffer,  // File system name
+                               sizeof(fileSystemNameBuffer))) {
+        LOG_F(ERROR, "Error retrieving filesystem information for: {}", path);
+        return "Unknown";
+    }
+    return std::string(fileSystemNameBuffer);
+
+#elif __linux__ || __ANDROID__
+    struct statfs buffer;
+    if (statfs(path.c_str(), &buffer) != 0) {
+        LOG_F(ERROR, "Error retrieving filesystem information for: {}", path);
+        return "Unknown";
+    }
+
+    // File system type
+    switch (buffer.f_type) {
+        case 0xEF53:
+            return "ext4";
+        case 0x6969:
+            return "nfs";
+        case 0xFF534D42:
+            return "cifs";
+        case 0x4d44:
+            return "vfat";
+        default:
+            return "Unknown";
+    }
+
+#elif __APPLE__
+    struct statfs buffer;
+    if (statfs(path.c_str(), &buffer) != 0) {
+        LOG_F(ERROR, "Error retrieving filesystem information for: {}", path);
+        return "Unknown";
+    }
+
+    if (strcmp(buffer.f_fstypename, "hfs") == 0) {
+        return "HFS";
+    } else if (strcmp(buffer.f_fstypename, "apfs") == 0) {
+        return "APFS";
+    } else if (strcmp(buffer.f_fstypename, "msdos") == 0) {
+        return "FAT32";
+    } else if (strcmp(buffer.f_fstypename, "exfat") == 0) {
+        return "ExFAT";
+    } else if (strcmp(buffer.f_fstypename, "nfs") == 0) {
+        return "NFS";
+    } else {
+        return "Unknown";
+    }
+
+#elif __FreeBSD__ || __NetBSD__ || __OpenBSD__
+    struct statfs buffer;
+    if (statfs(path.c_str(), &buffer) != 0) {
+        LOG_F(ERROR, "Error retrieving filesystem information for: {}", path);
+        return "Unknown";
+    }
+
+    if (strcmp(buffer.f_fstypename, "ufs") == 0) {
+        return "UFS";
+    } else if (strcmp(buffer.f_fstypename, "zfs") == 0) {
+        return "ZFS";
+    } else if (strcmp(buffer.f_fstypename, "msdosfs") == 0) {
+        return "FAT32";
+    } else if (strcmp(buffer.f_fstypename, "nfs") == 0) {
+        return "NFS";
+    } else {
+        return "Unknown";
+    }
+
+#else
+    // Other Unix systems use statvfs
+    struct statvfs buffer;
+    if (statvfs(path.c_str(), &buffer) != 0) {
+        std::cerr << "Error retrieving filesystem information for: " << path
+                  << std::endl;
+        return "Unknown";
+    }
+
+    // No specific file system types are mapped for these systems
+    return "Unknown";
+#endif
+}
+
 #ifdef _WIN32
 bool setReadOnlyWindows(const std::string& driveLetter) {
diff --git a/src/atom/sysinfo/disk.hpp b/src/atom/sysinfo/disk.hpp
index 66a9b4fc..0b60dfd2 100644
--- a/src/atom/sysinfo/disk.hpp
+++ b/src/atom/sysinfo/disk.hpp
@@ -19,53 +19,97 @@ Description: System Information Module - Disk
 #include
 
 namespace atom::system {
+
+/**
+ * @brief Retrieves the disk usage information for all available disks.
+ *
+ * This function scans the system for all connected disks and calculates the
+ * disk usage for each one. It returns a vector of pairs where the first element
+ * is the name of the disk (e.g., "/dev/sda1") and the second element is the
+ * disk's usage percentage. The usage percentage is calculated based on the
+ * total space and available space on the disk.
+ *
+ * @return A vector of pairs where each pair consists of:
+ *         - A string representing the disk name.
+ *         - A float representing the usage percentage of the disk.
+ */
+[[nodiscard]] auto getDiskUsage() -> std::vector<std::pair<std::string, float>>;
+
 /**
- * @brief Get the disk usage for all disks.
- * 获取所有磁盘的使用情况
+ * @brief Retrieves the model of a specified drive.
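// Editor's note: a sketch tying together getDiskUsage() and the newly added
// getFileSystemType(); it assumes getDiskUsage() reports (disk name, usage %) pairs as
// documented above, and the formatting is illustrative.
#include <cstdio>

#include "atom/sysinfo/disk.hpp"

void printDiskReport() {
    using namespace atom::system;

    for (const auto& [disk, usedPercent] : getDiskUsage()) {
        std::printf("%-24s %6.2f %% used, filesystem: %s\n", disk.c_str(),
                    usedPercent, getFileSystemType(disk).c_str());
    }
}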
* - * @return A vector of pairs containing the disk name and its usage percentage. - * 包含磁盘名称和使用率百分比的一对对的向量 + * Given the path to a drive (e.g., "/dev/sda"), this function returns the model + * information of the drive. This information is typically extracted from system + * files or using system-level APIs that provide details about the storage + * hardware. + * + * @param drivePath A string representing the path of the drive. + * For example, "/dev/sda" or "C:\\" on Windows. + * @return A string containing the model name of the drive. */ -[[nodiscard]] std::vector> getDiskUsage(); +[[nodiscard]] auto getDriveModel(const std::string& drivePath) -> std::string; /** - * @brief Get the drive model. - * 获取驱动器型号 + * @brief Retrieves the models of all connected storage devices. + * + * This function queries the system for all connected storage devices (e.g., + * hard drives, SSDs) and returns a list of their names along with their + * respective models. Each element in the returned vector is a pair, where the + * first element is the name of the storage device (e.g., "/dev/sda" or "C:\\" + * on Windows) and the second element is the device's model name. * - * @param drivePath The path of the drive. 驱动器路径 - * @return The drive model. - * 驱动器型号 + * @return A vector of pairs where each pair consists of: + * - A string representing the storage device name. + * - A string representing the model name of the storage device. */ -[[nodiscard]] std::string getDriveModel(const std::string &drivePath); +[[nodiscard]] auto getStorageDeviceModels() + -> std::vector>; /** - * @brief Get the storage device models. - * 获取存储设备型号 + * @brief Retrieves a list of all available drives on the system. + * + * This function returns a list of all available drives currently recognized by + * the system. The drives are represented by their mount points or device paths. + * For example, on Linux, the returned list may contain paths such as + * "/dev/sda1", while on Windows it may contain drive letters like "C:\\". * - * @return A vector of pairs containing the storage device name and its model. - * 包含存储设备名称和型号的一对对的向量 + * @return A vector of strings where each string represents an available drive. */ -[[nodiscard]] std::vector> getStorageDeviceModels(); +[[nodiscard]] auto getAvailableDrives() -> std::vector; /** - * @brief Get the available drives. - * 获取可用驱动器 + * @brief Calculates the disk usage percentage. * - * @return A vector of available drives. - * 可用驱动器的向量 + * Given the total and free space on a disk, this function computes the disk + * usage percentage. The calculation is performed using the formula: + * \f$ \text{Usage Percentage} = \left( \frac{\text{Total Space} - \text{Free + * Space}}{\text{Total Space}} \right) \times 100 \f$ This percentage represents + * how much of the disk space is currently used. + * + * @param totalSpace The total space on the disk, in bytes. + * @param freeSpace The free (available) space on the disk, in bytes. + * @return A double representing the disk usage percentage. The result is + * a value between 0.0 and 100.0. */ -[[nodiscard]] std::vector getAvailableDrives(); +[[nodiscard]] auto calculateDiskUsagePercentage( + unsigned long totalSpace, unsigned long freeSpace) -> double; /** - * @brief Calculate the disk usage percentage. - * 计算磁盘使用率 + * @brief Retrieves the file system type for a specified path. + * + * This function determines the type of the file system used by the disk at the + * specified path. 
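// Editor's note: a worked example of the documented formula
// usage % = (totalSpace - freeSpace) / totalSpace * 100; the byte counts are made up.
#include <cassert>

#include "atom/sysinfo/disk.hpp"

void diskUsageFormulaExample() {
    using namespace atom::system;

    const unsigned long totalSpace = 1'000'000;  // bytes
    const unsigned long freeSpace = 250'000;     // bytes

    // (1'000'000 - 250'000) / 1'000'000 * 100 = 75.0
    const double usedPercent = calculateDiskUsagePercentage(totalSpace, freeSpace);
    assert(usedPercent > 74.9 && usedPercent < 75.1);
}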
The file system type could be, for example, "ext4" for Linux + * systems, "NTFS" for Windows, or "APFS" for macOS. The function queries the + * system to retrieve this information and returns the file system type as a + * string. * - * @param totalSpace The total space of the disk. 磁盘的总空间 - * @param freeSpace The free space of the disk. 磁盘的可用空间 - * @return The disk usage percentage. 磁盘使用率 + * @param path A string representing the path to the disk or mount point. + * For example, "/dev/sda1" or "C:\\". + * @return A string containing the file system type (e.g., "ext4", "NTFS", + * "APFS"). */ -[[nodiscard]] double calculateDiskUsagePercentage(unsigned long totalSpace, - unsigned long freeSpace); +[[nodiscard]] auto getFileSystemType(const std::string& path) -> std::string; } // namespace atom::system + #endif diff --git a/src/atom/type/CMakeLists.txt b/src/atom/type/CMakeLists.txt index 96fa101c..7e63f34c 100644 --- a/src/atom/type/CMakeLists.txt +++ b/src/atom/type/CMakeLists.txt @@ -11,9 +11,7 @@ project(atom-type C CXX) # Sources set(${PROJECT_NAME}_SOURCES - ini.cpp message.cpp - string.cpp ) # Headers @@ -21,8 +19,6 @@ set(${PROJECT_NAME}_HEADERS args.hpp argsview.hpp flatset.hpp - ini.inl - ini.hpp json.hpp message.hpp pointer.hpp diff --git a/src/atom/type/ini.cpp b/src/atom/type/ini.cpp deleted file mode 100644 index 301d78ee..00000000 --- a/src/atom/type/ini.cpp +++ /dev/null @@ -1,193 +0,0 @@ -/* - * ini.cpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2023-6-17 - -Description: INI File Read/Write Library - -**************************************************/ - -#include "ini.hpp" - -#include -#include - -#include "atom/error/exception.hpp" -#include "atom/utils/string.hpp" - -namespace atom::type { - -auto INIFile::has(const std::string §ion, - const std::string &key) const -> bool { - std::shared_lock lock(m_sharedMutex_); - if (auto it = data_.find(section); it != data_.end()) { - return it->second.contains(key); - } - return false; -} - -auto INIFile::hasSection(const std::string §ion) const -> bool { - std::shared_lock lock(m_sharedMutex_); - return data_.contains(section); -} - -auto INIFile::sections() const -> std::vector { - std::shared_lock lock(m_sharedMutex_); - std::vector result; - result.reserve(data_.size()); - for (const auto &[section, _] : data_) { - result.emplace_back(section); - } - return result; -} - -auto INIFile::keys(const std::string §ion) const - -> std::vector { - std::shared_lock lock(m_sharedMutex_); - if (auto it = data_.find(section); it != data_.end()) { - std::vector result; - result.reserve(it->second.size()); - for (const auto &[key, _] : it->second) { - result.emplace_back(key); - } - return result; - } - return {}; -} - -void INIFile::load(const std::string &filename) { - std::unique_lock lock(m_sharedMutex_); - std::ifstream file(filename); - if (!file) { - THROW_EXCEPTION("Failed to open file: " + filename); - } - - std::string currentSection; - for (std::string line; std::getline(file, line);) { - parseLine(line, currentSection); - } -} - -void INIFile::save(const std::string &filename) const { - std::shared_lock lock(m_sharedMutex_); - std::ofstream file(filename); - if (!file) { - THROW_FILE_NOT_WRITABLE("Failed to create file: ", filename); - } - - for (const auto &[section, entries] : data_) { - file << "[" << section << "]\n"; - for (const auto &[key, value] : entries) { - file << key << "="; - if (value.type() == typeid(int)) { - file << std::any_cast(value); - } else 
if (value.type() == typeid(float)) { - file << std::any_cast(value); - } else if (value.type() == typeid(double)) { - file << std::any_cast(value); - } else if (value.type() == typeid(std::string)) { - file << std::any_cast(value); - } else if (value.type() == typeid(const char *)) { - file << std::any_cast(value); - } else if (value.type() == typeid(bool)) { - file << std::boolalpha << std::any_cast(value); - } else { - THROW_INVALID_ARGUMENT("Unsupported type"); - } - file << "\n"; - } - file << "\n"; - } -} - -void INIFile::parseLine(std::string_view line, std::string ¤tSection) { - if (line.empty() || line.front() == ';') { - return; - } - if (line.front() == '[') { - auto pos = line.find(']'); - if (pos != std::string_view::npos) { - currentSection = atom::utils::trim(line.substr(1, pos - 1)); - } - } else { - auto pos = line.find('='); - if (pos != std::string_view::npos) { - auto key = atom::utils::trim(line.substr(0, pos)); - auto value = atom::utils::trim(line.substr(pos + 1)); - data_[currentSection][key] = value; - } - } -} - -auto INIFile::toJson() const -> std::string { - std::shared_lock lock(m_sharedMutex_); - std::ostringstream oss; - oss << "{"; - for (const auto &[section, entries] : data_) { - oss << "\"" << section << "\": {"; - for (const auto &[key, value] : entries) { - oss << "\"" << key << "\": "; - if (value.type() == typeid(int)) { - oss << std::any_cast(value); - } else if (value.type() == typeid(float)) { - oss << std::any_cast(value); - } else if (value.type() == typeid(double)) { - oss << std::any_cast(value); - } else if (value.type() == typeid(std::string)) { - oss << "\"" << std::any_cast(value) << "\""; - } else if (value.type() == typeid(const char *)) { - oss << "\"" << std::any_cast(value) << "\""; - } else if (value.type() == typeid(bool)) { - oss << std::boolalpha << std::any_cast(value); - } else { - THROW_INVALID_ARGUMENT("Unsupported type"); - } - oss << ","; - } - oss.seekp(-1, std::ios_base::end); - oss << "},"; - } - oss.seekp(-1, std::ios_base::end); - oss << "}"; - return oss.str(); -} - -auto INIFile::toXml() const -> std::string { - std::shared_lock lock(m_sharedMutex_); - std::ostringstream oss; - oss << "\n"; - oss << "\n"; - for (const auto &[section, entries] : data_) { - oss << "
    \n"; - for (const auto &[key, value] : entries) { - oss << " " << std::any_cast(value); - } else if (value.type() == typeid(float)) { - oss << "float\">" << std::any_cast(value); - } else if (value.type() == typeid(double)) { - oss << "double\">" << std::any_cast(value); - } else if (value.type() == typeid(std::string)) { - oss << "string\">" << std::any_cast(value); - } else if (value.type() == typeid(const char *)) { - oss << "string\">" << std::any_cast(value); - } else if (value.type() == typeid(bool)) { - oss << "bool\">" << std::boolalpha - << std::any_cast(value); - } else { - THROW_INVALID_ARGUMENT("Unsupported type"); - } - oss << "\n"; - } - oss << "
    \n"; - } - oss << "
    \n"; - return oss.str(); -} - -} // namespace atom::type diff --git a/src/atom/type/ini.hpp b/src/atom/type/ini.hpp deleted file mode 100644 index 0326876e..00000000 --- a/src/atom/type/ini.hpp +++ /dev/null @@ -1,130 +0,0 @@ -/* - * ini.hpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2023-6-17 - -Description: INI File Read/Write Library - -**************************************************/ - -#ifndef ATOM_TYPE_INI_HPP -#define ATOM_TYPE_INI_HPP - -#include -#include -#include -#include -#include -#include - -namespace atom::type { -class INIFile { -public: - /** - * @brief 加载INI文件 - * @param filename 文件名 - */ - void load(const std::string &filename); - - /** - * @brief 保存INI文件 - * @param filename 文件名 - */ - void save(const std::string &filename) const; - - /** - * @brief 设置INI文件中的值 - * @tparam T 类型 - * @param section 部分名 - * @param key 键 - * @param value 值 - */ - template - void set(const std::string §ion, const std::string &key, - const T &value); - - /** - * @brief 获取INI文件中的值 - * @tparam T 类型 - * @param section 部分名 - * @param key 键 - * @return 值,如果不存在则返回std::nullopt - */ - template - [[nodiscard]] auto get(const std::string §ion, - const std::string &key) const -> std::optional; - - /** - * @brief 判断INI文件中是否存在指定键 - * @param section 部分名 - * @param key 键 - * @return 存在返回true,否则返回false - */ - [[nodiscard]] auto has(const std::string §ion, - const std::string &key) const -> bool; - - /** - * @brief 判断INI文件中是否存在指定部分 - * @param section 部分名 - * @return 存在返回true,否则返回false - */ - [[nodiscard]] auto hasSection(const std::string §ion) const -> bool; - - /** - * @brief 获取INI文件中所有的section - * @return section列表 - */ - [[nodiscard]] auto sections() const -> std::vector; - - /** - * @brief 获取指定section下的所有key - * @param section 部分名 - * @return key列表 - */ - [[nodiscard]] std::vector keys( - const std::string §ion) const; - - /** - * @brief 获取INI文件中的部分 - * @param section 部分名 - * @return 部分内容 - */ - auto operator[]( - const std::string §ion) -> std::unordered_map & { - return data_[section]; - } - - /** - * @brief 将INI文件中的数据转换为JSON字符串 - * @return JSON字符串 - */ - [[nodiscard]] std::string toJson() const; - - /** - * @brief 将INI文件中的数据转换为XML字符串 - * @return XML字符串 - */ - [[nodiscard]] std::string toXml() const; - -private: - std::unordered_map> - data_; // 存储数据的映射表 - mutable std::shared_mutex m_sharedMutex_; // 共享互斥锁,用于线程安全 - - /** - * @brief 解析INI文件的一行,并更新当前部分 - * @param line 行内容 - * @param currentSection 当前部分 - */ - void parseLine(std::string_view line, std::string ¤tSection); -}; -} // namespace atom::type - -#include "ini.inl" - -#endif diff --git a/src/atom/type/ini.inl b/src/atom/type/ini.inl deleted file mode 100644 index f8a30fd8..00000000 --- a/src/atom/type/ini.inl +++ /dev/null @@ -1,49 +0,0 @@ -/* - * ini.inl - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2023-6-17 - -Description: INI File Read/Write Library - -**************************************************/ - -#ifndef ATOM_TYPE_INI_INL -#define ATOM_TYPE_INI_INL - -#include "ini.hpp" - -#include - -namespace atom::type { - -template -void INIFile::set(const std::string §ion, const std::string &key, - const T &value) { - std::unique_lock lock(m_sharedMutex_); - data_[section][key] = value; -} - -template -auto INIFile::get(const std::string §ion, - const std::string &key) const -> std::optional { - std::shared_lock lock(m_sharedMutex_); - if (auto it = data_.find(section); it != data_.end()) { - if (auto entryIt = 
it->second.find(key); entryIt != it->second.end()) { - if constexpr (std::is_same_v) { - return entryIt->second; - } else { - return std::any_cast(entryIt->second); - } - } - } - return std::nullopt; -} - -} // namespace atom::type - -#endif diff --git a/src/atom/type/iter.hpp b/src/atom/type/iter.hpp index 27c184d6..491c3f2e 100644 --- a/src/atom/type/iter.hpp +++ b/src/atom/type/iter.hpp @@ -42,9 +42,9 @@ class PointerIterator { explicit PointerIterator(IteratorT it) : it_(std::move(it)) {} - value_type operator*() const { return &*it_; } + auto operator*() const -> value_type { return &*it_; } - PointerIterator& operator++() { + auto operator++() -> PointerIterator& { ++it_; return *this; } diff --git a/src/atom/type/pod_vector.hpp b/src/atom/type/pod_vector.hpp index fec79e45..b74aac5b 100644 --- a/src/atom/type/pod_vector.hpp +++ b/src/atom/type/pod_vector.hpp @@ -2,12 +2,11 @@ #define ATOM_TYPE_POD_VECTOR_HPP #include -#include #include #include +#include #include #include -#include #include "atom/macro.hpp" @@ -15,66 +14,62 @@ namespace atom::type { template concept PodType = std::is_trivial_v && std::is_standard_layout_v; -ATOM_INLINE auto pool64Alloc(std::size_t size) -> void* { - return std::malloc(size); -} - -ATOM_INLINE void pool64Dealloc(void* ptr) { std::free(ptr); } +template +concept ValueType = requires(T t) { + { std::is_copy_constructible_v }; + { std::is_move_constructible_v }; +}; template -struct PodVector { - static constexpr int SIZE_T = sizeof(T); - static constexpr int N = 64 / SIZE_T; +class PodVector { + static ATOM_CONSTEXPR int SIZE_T = sizeof(T); + static ATOM_CONSTEXPR int N = 64 / SIZE_T; static_assert(N >= 4, "Element size_ too large"); private: - int size_; - int capacity_; - T* data_; + int size_ = 0; + int capacity_ = N; + std::allocator allocator_; + T* data_ = allocator_.allocate(capacity_); public: using size_type = int; - PodVector() : size_(0), capacity_(N) { - data_ = static_cast( - pool64Alloc(static_cast(capacity_ * SIZE_T))); - } + ATOM_CONSTEXPR PodVector() ATOM_NOEXCEPT = default; - PodVector(std::initializer_list il) - : size_(il.size()), capacity_(std::max(N, size_)) { - data_ = static_cast( - pool64Alloc(static_cast(capacity_ * SIZE_T))); - std::copy(il.begin(), il.end(), data_); + ATOM_CONSTEXPR PodVector(std::initializer_list il) + : size_(static_cast(il.size())), + capacity_(std::max(N, size_)), + data_(allocator_.allocate(capacity_)) { + std::ranges::copy(il, data_); } - explicit PodVector(int size_) - : size_(size_), capacity_(std::max(N, size_)) { - data_ = static_cast( - pool64Alloc(static_cast(capacity_ * SIZE_T))); - } + explicit ATOM_CONSTEXPR PodVector(int size_) + : size_(size_), + capacity_(std::max(N, size_)), + data_(allocator_.allocate(capacity_)) {} PodVector(const PodVector& other) - : size_(other.size_), capacity_(other.capacity_) { - data_ = static_cast( - pool64Alloc(static_cast(capacity_ * SIZE_T))); + : size_(other.size_), + capacity_(other.capacity_), + data_(allocator_.allocate(capacity_)) { std::memcpy(data_, other.data_, SIZE_T * size_); } - PodVector(PodVector&& other) noexcept - : size_(other.size_), capacity_(other.capacity_), data_(other.data_) { - other.data_ = nullptr; - } + PodVector(PodVector&& other) ATOM_NOEXCEPT + : size_(other.size_), + capacity_(other.capacity_), + data_(std::exchange(other.data_, nullptr)) {} - auto operator=(PodVector&& other) noexcept -> PodVector& { + auto operator=(PodVector&& other) ATOM_NOEXCEPT -> PodVector& { if (this != &other) { if (data_ != nullptr) { - 
pool64Dealloc(data_); + allocator_.deallocate(data_, capacity_); } size_ = other.size_; capacity_ = other.capacity_; - data_ = other.data_; - other.data_ = nullptr; + data_ = std::exchange(other.data_, nullptr); } return *this; } @@ -83,7 +78,7 @@ struct PodVector { template void pushBack(ValueT&& t) { - if (size_ == capacity_) { + if (size_ == capacity_) [[unlikely]] { reserve(capacity_ * Growth); } data_[size_++] = std::forward(t); @@ -91,32 +86,28 @@ struct PodVector { template void emplaceBack(Args&&... args) { - if (size_ == capacity_) { + if (size_ == capacity_) [[unlikely]] { reserve(capacity_ * Growth); } new (&data_[size_++]) T(std::forward(args)...); } - void reserve(int cap) { - if (cap <= capacity_) { + ATOM_CONSTEXPR void reserve(int cap) { + if (cap <= capacity_) [[likely]] { return; } - capacity_ = cap; - T* oldData = data_; - data_ = static_cast( - pool64Alloc(static_cast(capacity_ * SIZE_T))); - if (oldData != nullptr) { - std::memcpy(data_, oldData, SIZE_T * size_); - pool64Dealloc(oldData); + T* newData = allocator_.allocate(cap); + if (data_ != nullptr) { + std::memcpy(newData, data_, SIZE_T * size_); + allocator_.deallocate(data_, capacity_); } + data_ = newData; + capacity_ = cap; } - void popBack() { size_--; } - auto popxBack() -> T { - T t = std::move(data_[size_ - 1]); - size_--; - return t; - } + ATOM_CONSTEXPR void popBack() ATOM_NOEXCEPT { size_--; } + + ATOM_CONSTEXPR auto popxBack() -> T { return std::move(data_[--size_]); } void extend(const PodVector& other) { for (const auto& elem : other) { @@ -130,21 +121,25 @@ struct PodVector { } } - auto operator[](int index) -> T& { return data_[index]; } - auto operator[](int index) const -> const T& { return data_[index]; } + ATOM_CONSTEXPR auto operator[](int index) -> T& { return data_[index]; } + ATOM_CONSTEXPR auto operator[](int index) const -> const T& { + return data_[index]; + } - auto begin() -> T* { return data_; } - auto end() -> T* { return data_ + size_; } - auto begin() const -> const T* { return data_; } - auto end() const -> const T* { return data_ + size_; } - auto back() -> T& { return data_[size_ - 1]; } - auto back() const -> const T& { return data_[size_ - 1]; } + ATOM_CONSTEXPR auto begin() ATOM_NOEXCEPT -> T* { return data_; } + ATOM_CONSTEXPR auto end() ATOM_NOEXCEPT -> T* { return data_ + size_; } + ATOM_CONSTEXPR auto begin() const ATOM_NOEXCEPT -> const T* { return data_; } + ATOM_CONSTEXPR auto end() const ATOM_NOEXCEPT -> const T* { return data_ + size_; } + ATOM_CONSTEXPR auto back() -> T& { return data_[size_ - 1]; } + ATOM_CONSTEXPR auto back() const -> const T& { return data_[size_ - 1]; } - [[nodiscard]] auto empty() const -> bool { return size_ == 0; } - [[nodiscard]] auto size() const -> int { return size_; } - auto data() -> T* { return data_; } - auto data() const -> const T* { return data_; } - void clear() { size_ = 0; } + ATOM_NODISCARD ATOM_CONSTEXPR auto empty() const ATOM_NOEXCEPT -> bool { + return size_ == 0; + } + ATOM_NODISCARD ATOM_CONSTEXPR auto size() const ATOM_NOEXCEPT -> int { return size_; } + ATOM_CONSTEXPR auto data() ATOM_NOEXCEPT -> T* { return data_; } + ATOM_CONSTEXPR auto data() const ATOM_NOEXCEPT -> const T* { return data_; } + ATOM_CONSTEXPR void clear() ATOM_NOEXCEPT { size_ = 0; } template void insert(int i, ValueT&& val) { @@ -158,23 +153,21 @@ struct PodVector { size_++; } - void erase(int i) { - for (int j = i; j < size_ - 1; j++) { - data_[j] = data_[j + 1]; - } + ATOM_CONSTEXPR void erase(int i) { + std::ranges::copy(data_ + i + 1, data_ + 
size_, data_ + i); size_--; } - void reverse() { std::reverse(data_, data_ + size_); } + ATOM_CONSTEXPR void reverse() { std::ranges::reverse(data_, data_ + size_); } - void resize(int size_) { + ATOM_CONSTEXPR void resize(int size_) { if (size_ > capacity_) { reserve(size_); } - size_ = size_; + this->size_ = size_; } - auto detach() noexcept -> std::pair { + auto detach() ATOM_NOEXCEPT -> std::pair { T* p = data_; int size = size_; data_ = nullptr; @@ -184,38 +177,13 @@ struct PodVector { ~PodVector() { if (data_ != nullptr) { - pool64Dealloc(data_); + allocator_.deallocate(data_, capacity_); } } - [[nodiscard]] auto capacity() const -> size_t { return capacity_; } -}; - -template > -class Stack { - Container vec_; - -public: - void push(const T& t) { vec_.push_back(t); } - void push(T&& t) { vec_.push_back(std::move(t)); } - template - void emplace(Args&&... args) { - vec_.emplace_back(std::forward(args)...); - } - void pop() { vec_.pop_back(); } - void clear() { vec_.clear(); } - [[nodiscard]] auto empty() const -> bool { return vec_.empty(); } - auto size() const -> typename Container::size_type { return vec_.size(); } - auto top() -> T& { return vec_.back(); } - auto top() const -> const T& { return vec_.back(); } - auto popx() -> T { - T t = std::move(vec_.back()); - vec_.pop_back(); - return t; - } - void reserve(int n) { vec_.reserve(n); } - auto container() -> Container& { return vec_; } - auto container() const -> const Container& { return vec_; } + ATOM_NODISCARD ATOM_CONSTEXPR auto capacity() const ATOM_NOEXCEPT -> int { + return capacity_; + } }; } // namespace atom::type diff --git a/src/atom/type/rtype.hpp b/src/atom/type/rtype.hpp index 9341997b..5c3f40bf 100644 --- a/src/atom/type/rtype.hpp +++ b/src/atom/type/rtype.hpp @@ -1,8 +1,11 @@ +#ifndef ATOM_TYPE_RTYPE_HPP +#define ATOM_TYPE_RTYPE_HPP + #include #include -#include "atom/function/concept.hpp" #include "atom/error/exception.hpp" +#include "atom/function/concept.hpp" #include "rjson.hpp" #include "ryaml.hpp" @@ -58,39 +61,25 @@ struct Reflectable { typename std::decay_t::member_type; if (it != j.end()) { - if constexpr (std::is_same_v) { + if constexpr (String || Char) { obj.*(field.member) = it->second.as_string(); - } else if constexpr (std::is_same_v) { + } else if constexpr (Number) { obj.*(field.member) = static_cast(it->second.as_number()); - } else if constexpr (std::is_same_v) { - obj.*(field.member) = it->second.as_number(); } else if constexpr (std::is_same_v) { obj.*(field.member) = it->second.as_bool(); - } else if constexpr (std::is_same_v< - MemberType, - std::vector>) { + } else if constexpr (StringContainer) { for (const auto& item : it->second.as_array()) { (obj.*(field.member)) .push_back(item.as_string()); } - } else if constexpr (std::is_same_v< - MemberType, - std::vector>) { + } else if constexpr (NumberContainer) { for (const auto& item : it->second.as_array()) { (obj.*(field.member)) .push_back( - static_cast(item.as_number())); - } - } else if constexpr (std::is_same_v< - MemberType, - std::vector>) { - for (const auto& item : it->second.as_array()) { - (obj.*(field.member)) - .push_back(item.as_number()); + static_cast( + item.as_number())); } } else if constexpr (std::is_class_v) { // 处理复杂对象反射 @@ -113,9 +102,8 @@ struct Reflectable { if (!field.required) { obj.*(field.member) = field.default_value; } else { - THROW_INVALID_ARGUMENT( - "Missing required field: " + - std::string(field.name)); + THROW_INVALID_ARGUMENT("Missing required field: " + + std::string(field.name)); } } }()), @@ 
-132,30 +120,24 @@ struct Reflectable { (([&] { using MemberType = typename std::decay_t::member_type; - if constexpr (std::is_same_v) { + if constexpr (String || Char) { j[field.name] = JsonValue(obj.*(field.member)); - } else if constexpr (std::is_same_v || - std::is_same_v) { + } else if constexpr (Number) { j[field.name] = JsonValue( - static_cast(obj.*(field.member))); + static_cast(obj.*(field.member))); } else if constexpr (std::is_same_v) { j[field.name] = JsonValue(obj.*(field.member)); - } else if constexpr (std::is_same_v< - MemberType, - std::vector>) { + } else if constexpr (StringContainer) { JsonArray arr; for (const auto& item : obj.*(field.member)) { arr.push_back(JsonValue(item)); } j[field.name] = JsonValue(arr); - } else if constexpr (std::is_same_v> || - std::is_same_v>) { + } else if constexpr (NumberContainer) { JsonArray arr; for (const auto& item : obj.*(field.member)) { - arr.push_back( - JsonValue(static_cast(item))); + arr.push_back(JsonValue( + static_cast(item))); } j[field.name] = JsonValue(arr); } else if constexpr (std::is_class_v) { @@ -163,9 +145,8 @@ struct Reflectable { j[field.name] = JsonValue( field.reflect_type.to_json(obj.*(field.member))); } else { - THROW_INVALID_ARGUMENT( - "Unsupported type for field: " + - std::string(field.name)); + THROW_INVALID_ARGUMENT("Unsupported type for field: " + + std::string(field.name)); } }()), ...); @@ -184,8 +165,7 @@ struct Reflectable { typename std::decay_t::member_type; if (it != y.end()) { - if constexpr (std::is_same_v) { + if constexpr (String || Char) { obj.*(field.member) = it->second.as_string(); } else if constexpr (Number) { obj.*(field.member) = @@ -193,7 +173,7 @@ struct Reflectable { } else if constexpr (std::is_same_v) { obj.*(field.member) = it->second.as_bool(); - } else if constexpr (String) { + } else if constexpr (StringContainer) { for (const auto& item : it->second.as_array()) { (obj.*(field.member)) .push_back(item.as_string()); @@ -264,9 +244,8 @@ struct Reflectable { if (!field.required) { obj.*(field.member) = field.default_value; } else { - THROW_INVALID_ARGUMENT( - "Missing required field: " + - std::string(field.name)); + THROW_INVALID_ARGUMENT("Missing required field: " + + std::string(field.name)); } } }()), @@ -313,9 +292,8 @@ struct Reflectable { y[field.name] = YamlValue( field.reflect_type.to_yaml(obj.*(field.member))); } else { - THROW_INVALID_ARGUMENT( - "Unsupported type for field: " + - std::string(field.name)); + THROW_INVALID_ARGUMENT("Unsupported type for field: " + + std::string(field.name)); } }()), ...); @@ -345,3 +323,5 @@ auto make_field(const char* name, const char* description, reflect_type); } } // namespace atom::type + +#endif diff --git a/src/atom/type/string.cpp b/src/atom/type/string.cpp deleted file mode 100644 index b8023a9a..00000000 --- a/src/atom/type/string.cpp +++ /dev/null @@ -1,210 +0,0 @@ -/* - * string.cpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2024-2-10 - -Description: A super enhanced string class. 
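// Editor's note: the Reflectable/make_field machinery in rtype.hpp above maps struct
// members to JSON/YAML keys through compile-time field descriptors expanded with fold
// expressions. The template signatures are not fully visible in this hunk, so the
// snippet below is an independent miniature of the same technique rather than the
// library's API; Field, configFields and forEachField are hypothetical names.
#include <iostream>
#include <string>
#include <tuple>

// Hypothetical field descriptor: a display name plus a pointer to member.
template <typename Class, typename Member>
struct Field {
    const char* name;
    Member Class::*ptr;
};

struct Config {
    std::string host;
    int port;
};

// A field list analogous to what make_field() would build for this struct.
auto configFields() {
    return std::make_tuple(Field<Config, std::string>{"host", &Config::host},
                           Field<Config, int>{"port", &Config::port});
}

// Visit every (name, value) pair with a fold expression, as Reflectable does.
template <typename Fn>
void forEachField(Config& obj, Fn&& fn) {
    std::apply([&](auto... field) { (fn(field.name, obj.*(field.ptr)), ...); },
               configFields());
}

int main() {
    Config cfg{"localhost", 8080};
    forEachField(cfg, [](const char* name, const auto& value) {
        std::cout << name << " = " << value << '\n';
    });
}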
- -**************************************************/ - -#include "string.hpp" - -#include - -String::String(const char *str) : m_data_(str) {} - -String::String(std::string_view str) : m_data_(str) {} - -String::String(const std::string &str) : m_data_(str) {} - -auto String::operator==(const String &other) const -> bool { - return m_data_ == other.m_data_; -} - -auto String::operator!=(const String &other) const -> bool { - return m_data_ != other.m_data_; -} - -auto String::empty() const -> bool { return m_data_.empty(); } - -auto String::operator<(const String &other) const -> bool { - return m_data_ < other.m_data_; -} - -auto String::operator>(const String &other) const -> bool { - return m_data_ > other.m_data_; -} - -auto String::operator<=(const String &other) const -> bool { - return m_data_ <= other.m_data_; -} - -auto String::operator>=(const String &other) const -> bool { - return m_data_ >= other.m_data_; -} - -auto String::operator+=(const String &other) -> String & { - m_data_ += other.m_data_; - return *this; -} - -auto String::operator+=(const char *str) -> String & { - m_data_ += str; - return *this; -} - -auto String::operator+=(char c) -> String & { - m_data_ += c; - return *this; -} - -auto String::cStr() const -> const char * { return m_data_.c_str(); } - -auto String::length() const -> size_t { return m_data_.length(); } - -auto String::substr(size_t pos, size_t count) const -> String { - return m_data_.substr(pos, count); -} - -auto String::find(const String &str, size_t pos) const -> size_t { - return m_data_.find(str.m_data_, pos); -} - -auto String::replace(const String &oldStr, const String &newStr) -> bool { - size_t pos = m_data_.find(oldStr.m_data_); - if (pos != std::string::npos) { - m_data_.replace(pos, oldStr.length(), newStr.m_data_); - return true; - } - return false; -} - -auto String::replaceAll(const String &oldStr, const String &newStr) -> size_t { - size_t count = 0; - size_t pos = 0; - - while ((pos = m_data_.find(oldStr.m_data_, pos)) != std::string::npos) { - m_data_.replace(pos, oldStr.length(), newStr.m_data_); - pos += newStr.length(); - ++count; - } - - return count; -} - -auto String::toUpper() const -> String { - String result; - std::transform(m_data_.begin(), m_data_.end(), - std::back_inserter(result.m_data_), - [](unsigned char c) { return std::toupper(c); }); - return result; -} - -auto String::toLower() const -> String { - String result; - std::transform(m_data_.begin(), m_data_.end(), - std::back_inserter(result.m_data_), - [](unsigned char c) { return std::tolower(c); }); - return result; -} - -auto String::split(const String &delimiter) const -> std::vector { - std::vector tokens; - size_t start = 0; - size_t end = m_data_.find(delimiter.m_data_); - - while (end != std::string::npos) { - tokens.emplace_back(substr(start, end - start)); - start = end + delimiter.length(); - end = m_data_.find(delimiter.m_data_, start); - } - - tokens.emplace_back(substr(start)); - - return tokens; -} - -auto String::join(const std::vector &strings, - const String &separator) -> String { - String result; - for (size_t i = 0; i < strings.size(); ++i) { - if (i > 0) { - result += separator; - } - result += strings[i]; - } - return result; -} - -void String::insert(size_t pos, char c) { m_data_.insert(pos, 1, c); } - -void String::erase(size_t pos, size_t count) { m_data_.erase(pos, count); } - -auto String::reverse() const -> String { - String result(m_data_); - std::reverse(result.m_data_.begin(), result.m_data_.end()); - return result; -} - -auto 
String::equalsIgnoreCase(const String &other) const -> bool { - return std::equal(m_data_.begin(), m_data_.end(), other.m_data_.begin(), - other.m_data_.end(), [](char a, char b) { - return std::tolower(a) == std::tolower(b); - }); -} - -auto String::startsWith(const String &prefix) const -> bool { - return m_data_.find(prefix.m_data_) == 0; -} - -auto String::endsWith(const String &suffix) const -> bool { - if (suffix.length() > m_data_.length()) { - return false; - } - return std::equal(suffix.m_data_.rbegin(), suffix.m_data_.rend(), - m_data_.rbegin()); -} - -void String::trim() { - auto start = - std::find_if_not(m_data_.begin(), m_data_.end(), - [](unsigned char c) { return std::isspace(c); }); - auto end = - std::find_if_not(m_data_.rbegin(), m_data_.rend(), [](unsigned char c) { - return std::isspace(c); - }).base(); - m_data_ = std::string(start, end); -} - -void String::ltrim() { - auto start = - std::find_if_not(m_data_.begin(), m_data_.end(), - [](unsigned char c) { return std::isspace(c); }); - m_data_.erase(m_data_.begin(), start); -} - -void String::rtrim() { - auto end = - std::find_if_not(m_data_.rbegin(), m_data_.rend(), [](unsigned char c) { - return std::isspace(c); - }).base(); - m_data_.erase(end, m_data_.end()); -} - -auto String::data() const -> std::string { return m_data_; } - -auto operator+(const String &lhs, const String &rhs) -> String { - String result(lhs); - result += rhs; - return result; -} - -auto operator<<(std::ostream &os, const String &str) -> std::ostream & { - os << str.data(); - return os; -} diff --git a/src/atom/type/string.hpp b/src/atom/type/string.hpp index 521afbaa..40abc212 100644 --- a/src/atom/type/string.hpp +++ b/src/atom/type/string.hpp @@ -12,14 +12,21 @@ Description: A super enhanced string class. **************************************************/ -#ifndef ATOM_EXPERIMENT_STRING_HPP -#define ATOM_EXPERIMENT_STRING_HPP +#ifndef ATOM_TYPE_STRING_HPP +#define ATOM_TYPE_STRING_HPP +#include #include +#include +#include +#include +#include #include #include #include +#include "macro.hpp" + /** * @brief A super enhanced string class. */ @@ -31,269 +38,412 @@ class String { String() = default; /** - * @brief Constructor. - * @param str - C-style string. + * @brief Constructor from C-style string. */ - String(const char *str); + String(const char *str) : m_data_(str) {} /** - * @brief Constructor. - * @param str - std::string_view. + * @brief Constructor from std::string_view. */ - String(std::string_view str); + String(std::string_view str) : m_data_(str) {} /** - * @brief Constructor. - * @param str - std::string. + * @brief Constructor from std::string. */ - String(const std::string &str); + String(std::string str) : m_data_(std::move(str)) {} /** * @brief Copy constructor. - * @param other - other String. */ String(const String &other) = default; /** * @brief Move constructor. - * @param other - other String. */ - String(String &&other) noexcept = default; + String(String &&other) ATOM_NOEXCEPT = default; /** * @brief Copy assignment. - * @param other - other String. */ auto operator=(const String &other) -> String & = default; /** * @brief Move assignment. - * @param other - other String. */ - auto operator=(String &&other) noexcept -> String & = default; + auto operator=(String &&other) ATOM_NOEXCEPT->String & = default; /** - * @brief Equality. - * @param other - other String. + * @brief Equality comparison. 
*/ - auto operator==(const String &other) const -> bool; + auto operator==(const String &other) const -> bool = default; /** - * @brief Inequality. - * @param other - other String. + * @brief Three-way comparison (C++20). */ - auto operator!=(const String &other) const -> bool; + auto operator<=>(const String &other) const = default; /** - * @brief Check if the String is empty. + * @brief Concatenation with another String. */ - [[nodiscard]] auto empty() const -> bool; + auto operator+=(const String &other) -> String & { + m_data_ += other.m_data_; + return *this; + } /** - * @brief Less than. - * @param other - other String. + * @brief Concatenation with C-style string. */ - auto operator<(const String &other) const -> bool; + auto operator+=(const char *str) -> String & { + m_data_ += str; + return *this; + } /** - * @brief Greater than. - * @param other - other String. + * @brief Concatenation with a single character. */ - auto operator>(const String &other) const -> bool; + auto operator+=(char c) -> String & { + m_data_ += c; + return *this; + } /** - * @brief Less than or equal. - * @param other - other String. + * @brief Get C-style string. */ - auto operator<=(const String &other) const -> bool; + ATOM_NODISCARD auto cStr() const -> const char * { return m_data_.c_str(); } /** - * @brief Greater than or equal. - * @param other - other String. + * @brief Get length of the string. */ - auto operator>=(const String &other) const -> bool; + ATOM_NODISCARD auto length() const -> size_t { return m_data_.length(); } /** - * @brief Concatenation. - * @param other - other String. + * @brief Get substring. */ - auto operator+=(const String &other) -> String &; + ATOM_NODISCARD auto substr( + size_t pos, size_t count = std::string::npos) const -> String { + return m_data_.substr(pos, count); + } /** - * @brief Concatenation. - * @param str - C-style string. + * @brief Find a substring. */ - auto operator+=(const char *str) -> String &; + ATOM_NODISCARD auto find(const String &str, + size_t pos = 0) const -> size_t { + return m_data_.find(str.m_data_, pos); + } /** - * @brief Concatenation. - * @param c - char. + * @brief Replace first occurrence of oldStr with newStr. */ - auto operator+=(char c) -> String &; + auto replace(const String &oldStr, const String &newStr) -> bool { + if (size_t pos = m_data_.find(oldStr.m_data_); + pos != std::string::npos) { + m_data_.replace(pos, oldStr.length(), newStr.m_data_); + return true; + } + return false; + } /** - * @brief Get C-style string. + * @brief Replace all occurrences of oldStr with newStr. */ - [[nodiscard]] auto cStr() const -> const char *; + auto replaceAll(const String &oldStr, const String &newStr) -> size_t { + size_t count = 0; + size_t pos = 0; + + while ((pos = m_data_.find(oldStr.m_data_, pos)) != std::string::npos) { + m_data_.replace(pos, oldStr.length(), newStr.m_data_); + pos += newStr.length(); + ++count; + } + + return count; + } /** - * @brief Get length. + * @brief Convert string to uppercase. */ - [[nodiscard]] auto length() const -> size_t; + ATOM_NODISCARD auto toUpper() const -> String { + String result; + std::ranges::transform(m_data_, std::back_inserter(result.m_data_), + [](unsigned char c) { return std::toupper(c); }); + return result; + } /** - * @brief Get substring. - * @param pos - start position. - * @param count - length. + * @brief Convert string to lowercase. 
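// Editor's note: a small usage sketch for the replace/replaceAll/toUpper members shown
// above; the sample text is illustrative.
#include <iostream>

#include "atom/type/string.hpp"

int main() {
    String s("foo bar foo");

    s.replace(String("bar"), String("baz"));                // "foo baz foo"
    size_t n = s.replaceAll(String("foo"), String("qux"));  // n == 2, "qux baz qux"

    std::cout << s.toUpper() << " (" << n << " replacements)\n";
    return 0;
}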
*/ - [[nodiscard]] auto substr(size_t pos, - size_t count = std::string::npos) const -> String; + ATOM_NODISCARD auto toLower() const -> String { + String result; + std::ranges::transform(m_data_, std::back_inserter(result.m_data_), + [](unsigned char c) { return std::tolower(c); }); + return result; + } /** - * @brief Find. - * @param str - string to find. - * @param pos - start position. + * @brief Split the string by a delimiter. */ - [[nodiscard]] auto find(const String &str, size_t pos = 0) const -> size_t; + ATOM_NODISCARD auto split(const String &delimiter) const + -> std::vector { + if (delimiter.empty()) { + return {*this}; + } + if (m_data_.empty()) { + return {}; + } + std::vector tokens; + size_t start = 0; + size_t end = m_data_.find(delimiter.m_data_); + + while (end != std::string::npos) { + tokens.emplace_back(substr(start, end - start)); + start = end + delimiter.length(); + end = m_data_.find(delimiter.m_data_, start); + } + + tokens.emplace_back(substr(start)); + + return tokens; + } /** - * @brief Replace. - * @param oldStr - old string. - * @param newStr - new string. + * @brief Join a vector of strings with a separator. */ - auto replace(const String &oldStr, const String &newStr) -> bool; + static auto join(const std::vector &strings, + const String &separator) -> String { + String result; + for (size_t i = 0; i < strings.size(); ++i) { + if (i > 0) { + result += separator; + } + result += strings[i]; + } + return result; + } /** - * @brief Replace all. - * @param oldStr - old string. - * @param newStr - new string. + * @brief Insert a character at a position. */ - auto replaceAll(const String &oldStr, const String &newStr) -> size_t; + void insert(size_t pos, char c) { m_data_.insert(pos, 1, c); } /** - * @brief To uppercase. + * @brief Erase a portion of the string. */ - [[nodiscard]] auto toUpper() const -> String; + void erase(size_t pos = 0, size_t count = std::string::npos) { + m_data_.erase(pos, count); + } /** - * @brief To lowercase. + * @brief Reverse the string. */ - [[nodiscard]] auto toLower() const -> String; + ATOM_NODISCARD auto reverse() const -> String { + String result(m_data_); + std::ranges::reverse(result.m_data_); + return result; + } /** - * @brief Split. - * @param delimiter - delimiter. + * @brief Case-insensitive comparison. */ - [[nodiscard]] auto split(const String &delimiter) const - -> std::vector; + ATOM_NODISCARD auto equalsIgnoreCase(const String &other) const -> bool { + return std::ranges::equal(m_data_, other.m_data_, [](char a, char b) { + return std::tolower(a) == std::tolower(b); + }); + } /** - * @brief Join. - * @param strings - strings. - * @param separator - separator. + * @brief Check if string starts with a prefix. */ - static auto join(const std::vector &strings, - const String &separator) -> String; + ATOM_NODISCARD auto startsWith(const String &prefix) const -> bool { + return m_data_.starts_with(prefix.m_data_); + } /** - * @brief Insert char. - * @param pos - position. - * @param c - char. + * @brief Check if string ends with a suffix. */ - void insert(size_t pos, char c); + ATOM_NODISCARD auto endsWith(const String &suffix) const -> bool { + return m_data_.ends_with(suffix.m_data_); + } /** - * @brief Erase char. - * @param pos - position. - * */ - void erase(size_t pos = 0, size_t count = std::string::npos); + * @brief Trim whitespace from both ends. + */ + void trim() { + ltrim(); + rtrim(); + } /** - * @brief Reverse. + * @brief Left trim. 
+ */ + void ltrim() { + m_data_.erase(m_data_.begin(), + std::ranges::find_if_not(m_data_, [](unsigned char c) { + return std::isspace(c); + })); + } + + /** + * @brief Right trim. */ - [[nodiscard]] auto reverse() const -> String; + void rtrim() { + m_data_.erase(std::ranges::find_if_not( + m_data_.rbegin(), m_data_.rend(), + [](unsigned char c) { return std::isspace(c); }) + .base(), + m_data_.end()); + } /** - * @brief Equals ignore case. - * @param other - other String. + * @brief Get the underlying data as a std::string. */ - [[nodiscard]] auto equalsIgnoreCase(const String &other) const -> bool; + ATOM_NODISCARD auto data() const -> std::string { return m_data_; } + + ATOM_NODISCARD auto empty() const -> bool { return m_data_.empty(); } + + auto replace(char oldChar, char newChar) -> size_t { + size_t count = 0; + for (auto &c : m_data_) { + if (c == oldChar) { + c = newChar; + ++count; + } + } + return count; + } + + auto remove(char ch) -> size_t { + size_t count = std::erase(m_data_, ch); + return count; + } /** - * @brief Starts with. - * @param prefix - prefix. + * @brief Pad the string from the left with a specific character. */ - [[nodiscard]] auto startsWith(const String &prefix) const -> bool; + auto padLeft(size_t totalLength, char paddingChar = ' ') -> String & { + if (m_data_.length() < totalLength) { + m_data_.insert(m_data_.begin(), totalLength - m_data_.length(), + paddingChar); + } + return *this; + } /** - * @brief Ends with. - * @param suffix - suffix. + * @brief Pad the string from the right with a specific character. */ - [[nodiscard]] auto endsWith(const String &suffix) const -> bool; + auto padRight(size_t totalLength, char paddingChar = ' ') -> String & { + if (m_data_.length() < totalLength) { + m_data_.append(totalLength - m_data_.length(), paddingChar); + } + return *this; + } /** - * @brief Trim. + * @brief Remove a specific prefix from the string. */ - void trim(); + auto removePrefix(const String &prefix) -> bool { + if (startsWith(prefix)) { + m_data_.erase(0, prefix.length()); + return true; + } + return false; + } /** - * @brief Left trim. + * @brief Remove a specific suffix from the string. */ - void ltrim(); + auto removeSuffix(const String &suffix) -> bool { + if (endsWith(suffix)) { + m_data_.erase(m_data_.length() - suffix.length()); + return true; + } + return false; + } /** - * @brief Right trim. + * @brief Check if the string contains a substring. */ - void rtrim(); + ATOM_NODISCARD auto contains(const String &str) const -> bool { + return m_data_.find(str.m_data_) != std::string::npos; + } /** - * @brief Format. - * @param format - format string. - * @param ... - arguments. + * @brief Check if the string contains a specific character. */ + ATOM_NODISCARD auto contains(char c) const -> bool { + return m_data_.find(c) != std::string::npos; + } + + /** + * @brief Compress multiple consecutive spaces into a single space. + */ + void compressSpaces() { + auto newEnd = + std::unique(m_data_.begin(), m_data_.end(), [](char lhs, char rhs) { + return std::isspace(lhs) && std::isspace(rhs); + }); + m_data_.erase(newEnd, m_data_.end()); + } + + /** + * @brief Reverse the order of words in the string. 
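// Editor's note: a sketch for the newly added padding and cleanup members (padLeft,
// removePrefix, compressSpaces); values are illustrative.
#include <iostream>

#include "atom/type/string.hpp"

int main() {
    String id("42");
    id.padLeft(5, '0');                  // "00042"

    String path("src/atom/type/string.hpp");
    path.removePrefix(String("src/"));   // "atom/type/string.hpp"

    String text("too   many    spaces");
    text.compressSpaces();               // "too many spaces"

    std::cout << id << '\n' << path << '\n' << text << '\n';
    return 0;
}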
+     */
+    ATOM_NODISCARD auto reverseWords() const -> String {
+        auto words = split(" ");
+        std::ranges::reverse(words);
+        return join(words, " ");
+    }
+
+    auto replaceRegex(const std::string &pattern,
+                      const std::string &replacement) -> String {
+        std::regex re(pattern);
+        return std::regex_replace(m_data_, re, replacement);
+    }
+
+    /**
+     * @brief Format a string.
+     */
+    template <typename... Args>
-    static String format(const char *format, Args &&...args) {
-        int size =
-            std::snprintf(nullptr, 0, format, std::forward<Args>(args)...);
-        String result;
-        result.m_data_.resize(size + 1);
-        std::snprintf(result.m_data_.data(), size + 1, format,
-                      std::forward<Args>(args)...);
-        result.m_data_.pop_back();
-        return result;
+    static auto format(const std::string &format_str,
+                       Args &&...args) -> std::string {
+        // std::format requires a compile-time format string, so a runtime
+        // std::string must go through std::vformat instead.
+        return std::vformat(format_str, std::make_format_args(args...));
     }
 
     static constexpr size_t NPOS = std::string::npos;
 
-    [[nodiscard]] auto data() const -> std::string;
-
 private:
     std::string m_data_;
 };
 
 /**
- * @brief Concatenation.
- * @param lhs - left operand.
- * @param rhs - right operand.
- * @return - result.
+ * @brief Concatenation operator for String class.
  */
-auto operator+(const String &lhs, const String &rhs) -> String;
+ATOM_INLINE auto operator+(const String &lhs, const String &rhs) -> String {
+    String result(lhs);
+    result += rhs;
+    return result;
+}
 
 /**
- * @brief Output stream operator.
- * @param os - output stream.
- * @param str - string.
- * @return - output stream.
+ * @brief Output stream operator for String class.
  */
-auto operator<<(std::ostream &os, const String &str) -> std::ostream &;
+ATOM_INLINE auto operator<<(std::ostream &os,
+                            const String &str) -> std::ostream & {
+    os << str.data();
+    return os;
+}
 
 namespace std {
+/**
+ * @brief Specialization of std::hash for String class.
+ */ template <> struct hash { - auto operator()(const String &str) const -> size_t { + auto operator()(const String &str) const ATOM_NOEXCEPT->size_t { return std::hash()(str.data()); } }; } // namespace std -#endif +#endif // ATOM_TYPE_STRING_HPP diff --git a/src/atom/utils/anyutils.hpp b/src/atom/utils/anyutils.hpp index 2ff13690..e7cd5575 100644 --- a/src/atom/utils/anyutils.hpp +++ b/src/atom/utils/anyutils.hpp @@ -21,6 +21,8 @@ Description: A collection of useful functions with std::any Or Any #include #include +#include "atom/function/concept.hpp" + template concept CanBeStringified = requires(T t) { { toString(t) } -> std::convertible_to; @@ -31,22 +33,10 @@ concept CanBeStringifiedToJson = requires(T t) { { toJson(t) } -> std::convertible_to; }; -template -concept IsBuiltIn = - std::is_fundamental_v || std::is_same_v || - std::is_same_v || std::is_same_v || - std::is_same_v; - -template -concept ContainerLike = requires(const Container &c) { - { c.begin() } -> std::input_iterator; - { c.end() } -> std::input_iterator; -}; - template [[nodiscard]] std::string toString(const T &value, bool prettyPrint = false); -template +template [[nodiscard]] std::string toString(const Container &container, bool prettyPrint = false) { std::string result = "["; @@ -65,8 +55,8 @@ template } template -[[nodiscard]] std::string toString(const std::unordered_map &map, - bool prettyPrint = false) { +[[nodiscard]] auto toString(const std::unordered_map &map, + bool prettyPrint = false) -> std::string { std::string result = "{"; for (const auto &pair : map) { result += toString(pair.first, prettyPrint) + ": " + @@ -110,7 +100,7 @@ template template [[nodiscard]] std::string toJson(const T &value, bool prettyPrint = false); -template +template [[nodiscard]] std::string toJson(const Container &container, bool prettyPrint = false) { std::string result = "["; @@ -168,7 +158,7 @@ template template [[nodiscard]] std::string toXml(const T &value, const std::string &tagName); -template +template [[nodiscard]] std::string toXml(const Container &container, const std::string &tagName) { std::string result; @@ -221,7 +211,7 @@ template template [[nodiscard]] std::string toYaml(const T &value, const std::string &key); -template +template [[nodiscard]] std::string toYaml(const Container &container, const std::string &key) { std::string result = key.empty() ? "" : key + ":\n"; @@ -290,7 +280,7 @@ template template [[nodiscard]] std::string toToml(const T &value, const std::string &key); -template +template [[nodiscard]] std::string toToml(const Container &container, const std::string &key) { std::string result = key + " = [\n"; diff --git a/src/atom/utils/to_string.hpp b/src/atom/utils/to_string.hpp index bbfe3ca1..a87a8eb5 100644 --- a/src/atom/utils/to_string.hpp +++ b/src/atom/utils/to_string.hpp @@ -4,20 +4,9 @@ * Copyright (C) 2023-2024 Max Qian */ -/************************************************* +#ifndef ATOM_UTILS_TO_STRING_HPP +#define ATOM_UTILS_TO_STRING_HPP -Date: 2023-4-5 - -Description: Implementation of command line generator. - -**************************************************/ - -#ifndef ATOM_EXPERIMENT_STRINGUTILS_HPP -#define ATOM_EXPERIMENT_STRINGUTILS_HPP - -#include -#include -#include #include #include #include @@ -27,154 +16,137 @@ Description: Implementation of command line generator. 
#include #include -namespace atom::utils { -template -constexpr bool is_string_type = - std::is_same_v || std::is_same_v || - std::is_same_v; +#include "atom/function/concept.hpp" -template -struct is_container : std::false_type {}; +namespace atom::utils { -template -struct is_container_helper {}; +// ----------------------------------------------------------------------------- +// Concepts +// ----------------------------------------------------------------------------- template -struct is_container< - T, - std::conditional_t().begin()), - decltype(std::declval().end()), - decltype(std::declval().size()), - typename T::value_type>, - void>> - : public std::integral_constant> {}; +concept StringType = String || Char || std::is_same_v; template -struct is_map : std::false_type {}; - -template -struct is_map> : std::true_type {}; - -template -struct is_map> : std::true_type {}; +concept Container = requires(T t) { + t.begin(); + t.end(); + t.size(); +}; -#if __cplusplus >= 202002L +template +concept MapType = requires(T t) { + typename T::key_type; + typename T::mapped_type; + requires Container; +}; template -concept BasicType = std::is_arithmetic_v; +concept PointerType = std::is_pointer_v; -/** - * @brief Check if a type is a basic type. - * @tparam T The type to check. - * @return True if the type is a basic type, false otherwise. - */ template -concept StringType = requires(T a) { - std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v; -}; +concept EnumType = std::is_enum_v; + +// ----------------------------------------------------------------------------- +// toString Implementation +// ----------------------------------------------------------------------------- /** - * @brief Check if a type is a basic type. - * @tparam T The type to check. - * @return True if the type is a basic type, false otherwise. + * @brief Convert a non-container, non-map, non-pointer, non-enum type to a + * string. + * @tparam T The type of the value. + * @param value The value to convert. + * @return A string representation of the value. */ template -concept SequenceContainer = requires(T a) { - typename T::value_type; - { a.begin() } -> std::forward_iterator; - { a.end() } -> std::forward_iterator; -}; + requires(!Container && !MapType && !StringType && + !PointerType && !EnumType && !SmartPointer) +auto toString(const T& value) -> std::string { + std::ostringstream oss; + oss << value; + return oss.str(); +} /** - * @brief Check if a type is an associative container. - * @tparam T The type to check. - * @return True if the type is an associative container, false otherwise. + * @brief Convert a string type to a string. + * @tparam T The type of the string (e.g., std::string, const char*). + * @param value The string to convert. + * @return The string itself. */ -template -concept AssociativeContainer = requires(T a) { - typename T::key_type; - typename T::mapped_type; - { a.begin() } -> std::forward_iterator; - { a.end() } -> std::forward_iterator; -}; +template + requires (!Container) +auto toString(const T& value) -> std::string { + return std::string(value); +} /** - * @brief Check if a type is a smart pointer. - * @tparam T The type to check. - * @return True if the type is a smart pointer, false otherwise. + * @brief Convert an enum type to a string by casting to its underlying type. + * @tparam Enum The enum type. + * @param value The enum value to convert. + * @return A string representation of the enum's underlying type. 
*/ -template -concept SmartPointer = requires(T a) { - { *a } -> std::convertible_to; -}; - -#endif +template +auto toString(const Enum& value) -> std::string { + return std::to_string(static_cast>(value)); +} /** - * @brief Convert a value to a string representation. - * @tparam T The type of the value. - * @param value The value to be converted. - * @return A string representation of the value. + * @brief Convert a pointer type to a string. + * @tparam T The pointer type. + * @param ptr The pointer to convert. + * @return A string representation of the pointer address or value. */ -template -auto toString(const T &value) - -> std::enable_if_t::value && !is_container::value, - std::string> { - if constexpr (is_string_type) { - return std::string(value); - } else { - std::ostringstream oss; - oss << value; - return oss.str(); +template +std::string toString(T ptr) { + if (ptr) { + return "Pointer(" + toString(*ptr) + ")"; } + return "nullptr"; } /** - * @brief Join a pair of key-value pairs into a string representation. - * @tparam Key The type of the key. - * @tparam Value The type of the value. - * @param keyValue The pair of key-value pairs. - * @return A string representation of the pair. + * @brief Convert a smart pointer type (unique_ptr/shared_ptr) to a string. + * @tparam SmartPtr The smart pointer type. + * @param ptr The smart pointer to convert. + * @return A string representation of the smart pointer. */ -template -auto toString(const std::pair &keyValue) { - return "(" + toString(keyValue.first) + ", " + toString(keyValue.second) + - ")"; +template +std::string toString(const SmartPtr& ptr) { + if (ptr) { + return "SmartPointer(" + toString(*ptr) + ")"; + } + return "nullptr"; } /** - * @brief Join a pair of key-value pairs into a string representation. + * @brief Convert a key-value pair to a string. * @tparam Key The type of the key. * @tparam Value The type of the value. - * @param keyValue The pair of key-value pairs. - * @param separator The separator to use between the key and value. - * @return A string representation of the pair. + * @param keyValue The key-value pair to convert. + * @return A string representation of the key-value pair. */ template -auto toString(const std::pair &keyValue, - const std::string &separator) { - return toString(keyValue.first) + separator + toString(keyValue.second); +auto toString(const std::pair& keyValue) -> std::string { + return "(" + toString(keyValue.first) + ", " + toString(keyValue.second) + + ")"; } /** - * @brief Join a map of key-value pairs into a string representation. - * @tparam Container The type of the map. - * @param container The map of key-value pairs. + * @brief Convert a map (or unordered_map) to a string representation. + * @tparam Map The map type. + * @param map The map to convert. * @return A string representation of the map. */ -template -std::enable_if_t::value, std::string> toString( - const Container &container) { +template +auto toString(const Map& map) -> std::string { std::ostringstream oss; oss << "{"; bool first = true; - for (const auto &elem : container) { + for (const auto& [key, value] : map) { if (!first) { oss << ", "; } - oss << toString(elem.first) << ": " << toString(elem.second); + oss << toString(key) << ": " << toString(value); first = false; } oss << "}"; @@ -182,43 +154,39 @@ std::enable_if_t::value, std::string> toString( } /** - * @brief Join a container of values into a string representation. - * @tparam Container The type of the container. - * @param container The container of values. 
+ * @brief Convert a container (e.g., vector) to a string representation. + * @tparam Container The container type. + * @param container The container to convert. * @return A string representation of the container. */ -template -auto toString(const Container &container) - -> std::enable_if_t::value && - !is_map::value && - !is_string_type, - std::string> { +template + requires(!MapType && !StringType) +auto toString(const Container& container) -> std::string { std::ostringstream oss; oss << "["; auto it = container.begin(); while (it != container.end()) { oss << toString(*it); ++it; - if (it != container.end()) + if (it != container.end()) { oss << ", "; + } } oss << "]"; return oss.str(); } -/** - * @brief Join a vector of values into a string representation. - * @tparam T The type of the values in the vector. - * @param value The vector of values. - * @return A string representation of the vector. - */ -template -std::string toString(const std::vector &value) { +template + requires(Container && + StringType) +auto toString(const ContainerT& container) -> std::string { std::ostringstream oss; oss << "["; - for (size_t i = 0; i < value.size(); ++i) { - oss << toString(value[i]); - if (i != value.size() - 1) { + auto it = container.begin(); + while (it != container.end()) { + oss << toString(*it); + ++it; + if (it != container.end()) { oss << ", "; } } @@ -227,55 +195,27 @@ std::string toString(const std::vector &value) { } /** - * @brief Join a key-value pair into a string representation. - * @tparam T The type of the value. - * @param key The key. - * @param value The value. - * @param separator The separator to use between the key and value. - * @return A string representation of the key-value pair. - */ -template -std::enable_if_t, std::string> joinKeyValuePair( - const std::string &key, const T &value, const std::string &separator = "") { - return key + separator + std::string(value); -} - -/** - * @brief Join a key-value pair into a string representation. - * @tparam Key The type of the key. - * @tparam Value The type of the value. - * @param keyValue The key-value pair to join. - * @param separator The separator to use between the key and value. - * @return A string representation of the key-value pair. - */ -template -std::string joinKeyValuePair(const std::pair &keyValue, - const std::string &separator = "") { - return joinKeyValuePair(keyValue.first, keyValue.second, separator); -} - -/** - * @brief Join command line arguments into a single string. - * @tparam Args The types of the command line arguments. - * @param args The command line arguments. + * @brief Join multiple values into a single command line string. + * @tparam Args The types of the arguments. + * @param args The arguments to join. * @return A string representation of the command line arguments. */ template -[[nodiscard]] std::string joinCommandLine(const Args &...args) { +auto joinCommandLine(const Args&... args) -> std::string { std::ostringstream oss; - bool firstArg = true; - ((oss << (firstArg ? (firstArg = false, "") : " ") << toString(args)), ...); + bool first = true; + ((oss << (first ? (first = false, "") : " ") << toString(args)), ...); return oss.str(); } /** - * @brief Convert a vector of elements to a string representation. - * @tparam T The type of elements in the vector. + * @brief Convert a vector to a space-separated string representation. + * @tparam T The type of the elements in the vector. * @param array The vector to convert. - * @return A string representation of the vector. 
+ * @return A space-separated string representation of the vector. */ -template -auto toStringArray(const std::vector &array) -> std::string { +template +auto toStringArray(const T& array) -> std::string { std::ostringstream oss; for (size_t i = 0; i < array.size(); ++i) { oss << toString(array[i]); @@ -287,123 +227,39 @@ auto toStringArray(const std::vector &array) -> std::string { } /** - * @brief Concept to check if a type has begin() and end() member functions. - * @tparam T The type to check. - */ -template -concept HasIterator = requires(T t) { - t.begin(); /**< The begin() member function. */ - t.end(); /**< The end() member function. */ -}; - -/** - * @brief Implementation of string equality comparison for types supporting - * iterators. - * @tparam LHS Type of the left-hand side. - * @tparam RHS Type of the right-hand side. - * @param t_lhs The left-hand side operand. - * @param t_rhs The right-hand side operand. - * @return True if the strings are equal, false otherwise. + * @brief Convert an iterator range to a string representation. + * @tparam Iterator The type of the iterator. + * @param begin The beginning of the range. + * @param end The end of the range. + * @return A string representation of the range. */ -template -[[nodiscard]] constexpr bool str_equal_impl(const LHS &t_lhs, - const RHS &t_rhs) noexcept { - return std::equal(t_lhs.begin(), t_lhs.end(), t_rhs.begin(), t_rhs.end()); -} - -/** - * @brief Functor for string equality comparison. - */ -struct str_equal { - /** - * @brief Compares two std::string objects for equality. - * @param t_lhs The left-hand side string. - * @param t_rhs The right-hand side string. - * @return True if the strings are equal, false otherwise. - */ - [[nodiscard]] bool operator()(const std::string &t_lhs, - const std::string &t_rhs) const noexcept { - return t_lhs == t_rhs; - } - - /** - * @brief Compares two objects for equality using iterators. - * @tparam LHS Type of the left-hand side. - * @tparam RHS Type of the right-hand side. - * @param t_lhs The left-hand side operand. - * @param t_rhs The right-hand side operand. - * @return True if the strings are equal, false otherwise. - */ - template - [[nodiscard]] constexpr bool operator()(const LHS &t_lhs, - const RHS &t_rhs) const noexcept { - return str_equal_impl(t_lhs, t_rhs); +template +auto toStringRange(Iterator begin, Iterator end) -> std::string { + std::ostringstream oss; + oss << "["; + while (begin != end) { + oss << toString(*begin); + ++begin; + if (begin != end) { + oss << ", "; + } } - - struct is_transparent {}; /**< Enables transparent comparison. */ -}; - -/** - * @brief Implementation of string less-than comparison for types supporting - * iterators. - * @tparam T Type of the operands. - * @param t_lhs The left-hand side operand. - * @param t_rhs The right-hand side operand. - * @return True if t_lhs is less than t_rhs, false otherwise. - */ -template -[[nodiscard]] constexpr bool str_less_impl(const T &t_lhs, - const T &t_rhs) noexcept { - return t_lhs < t_rhs; + oss << "]"; + return oss.str(); } /** - * @brief Implementation of string less-than comparison for types supporting - * iterators. - * @tparam LHS Type of the left-hand side. - * @tparam RHS Type of the right-hand side. - * @param t_lhs The left-hand side operand. - * @param t_rhs The right-hand side operand. - * @return True if t_lhs is less than t_rhs, false otherwise. + * @brief Convert an array to a string. + * @tparam T The type of the array elements. + * @tparam N The size of the array. 
+ * @param array The array to convert. + * @return A string representation of the array. */ -template -[[nodiscard]] constexpr bool str_less_impl(const LHS &t_lhs, - const RHS &t_rhs) noexcept { - return std::lexicographical_compare(t_lhs.begin(), t_lhs.end(), - t_rhs.begin(), t_rhs.end()); +template +auto toString(const T (&array)[N]) -> std::string { + return toStringRange(std::begin(array), std::end(array)); } -/** - * @brief Functor for string less-than comparison. - */ -struct str_less { - /** - * @brief Compares two std::string objects. - * @param t_lhs The left-hand side string. - * @param t_rhs The right-hand side string. - * @return True if t_lhs is less than t_rhs, false otherwise. - */ - [[nodiscard]] bool operator()(const std::string &t_lhs, - const std::string &t_rhs) const noexcept { - return t_lhs < t_rhs; - } - - /** - * @brief Compares two objects using iterators. - * @tparam LHS Type of the left-hand side. - * @tparam RHS Type of the right-hand side. - * @param t_lhs The left-hand side operand. - * @param t_rhs The right-hand side operand. - * @return True if t_lhs is less than t_rhs, false otherwise. - */ - template - [[nodiscard]] constexpr bool operator()(const LHS &t_lhs, - const RHS &t_rhs) const noexcept { - return str_less_impl(t_lhs, t_rhs); - } - - struct is_transparent {}; /**< Enables transparent comparison. */ -}; } // namespace atom::utils -#endif // ATOM_STRINGUTILS_HPP +#endif // ATOM_UTILS_TO_STRING_HPP diff --git a/src/atom/utils/uuid.cpp b/src/atom/utils/uuid.cpp index ab701570..570b6b70 100644 --- a/src/atom/utils/uuid.cpp +++ b/src/atom/utils/uuid.cpp @@ -104,6 +104,26 @@ auto UUID::generateV3(const UUID& namespace_uuid, return generateNameBased(namespace_uuid, name, 3); } +auto UUID::generateV4() -> UUID { + // Generate a random UUID (version 4) + UUID uuid; + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution dist(0, 255); + + for (auto& byte : uuid.data_) { + byte = dist(gen); + } + + // Set version to 4 (randomly generated UUID) + uuid.data_[6] = (uuid.data_[6] & 0x0F) | 0x40; + + // Set the variant (RFC 4122 variant) + uuid.data_[8] = (uuid.data_[8] & 0x3F) | 0x80; + + return uuid; +} + auto UUID::generateV5(const UUID& namespace_uuid, const std::string& name) -> UUID { return generateNameBased(namespace_uuid, name, 5); diff --git a/src/atom/utils/uuid.hpp b/src/atom/utils/uuid.hpp index e301c85b..ed4b6302 100644 --- a/src/atom/utils/uuid.hpp +++ b/src/atom/utils/uuid.hpp @@ -136,6 +136,12 @@ class UUID { */ static auto generateV1() -> UUID; + /** + * @brief Generates a version 4, random UUID. + * @return A version 4 UUID. + */ + static auto generateV4() -> UUID; + private: /** * @brief Generates a random UUID. 
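Note on the new UUID::generateV4(): the masking it applies follows the RFC 4122 byte layout, where the high nibble of byte 6 carries the version and the top two bits of byte 8 carry the variant. A minimal standalone sketch of that logic is below; the 16-byte array type and the helper name are illustrative assumptions, not part of this change.

#include <array>
#include <cstdint>
#include <random>

// Standalone illustration of the byte masking used by UUID::generateV4 above.
auto makeRandomUuidBytes() -> std::array<std::uint8_t, 16> {
    std::array<std::uint8_t, 16> bytes{};
    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_int_distribution<int> dist(0, 255);
    for (auto& b : bytes) {
        b = static_cast<std::uint8_t>(dist(gen));
    }
    bytes[6] = (bytes[6] & 0x0F) | 0x40;  // version 4 in the high nibble of byte 6
    bytes[8] = (bytes[8] & 0x3F) | 0x80;  // RFC 4122 variant (binary 10) in byte 8
    return bytes;
}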
diff --git a/src/atom/web/CMakeLists.txt b/src/atom/web/CMakeLists.txt index 3dab0c5b..0e531069 100644 --- a/src/atom/web/CMakeLists.txt +++ b/src/atom/web/CMakeLists.txt @@ -13,7 +13,6 @@ project(atom-web C CXX) set(${PROJECT_NAME}_SOURCES address.cpp downloader.cpp - httplite.cpp httpparser.cpp utils.cpp time.cpp @@ -23,7 +22,6 @@ set(${PROJECT_NAME}_SOURCES set(${PROJECT_NAME}_HEADERS address.hpp downloader.hpp - httplite.hpp httpparser.hpp utils.hpp time.hpp diff --git a/src/atom/web/downloader.cpp b/src/atom/web/downloader.cpp index 5f40d1c3..c63ba170 100644 --- a/src/atom/web/downloader.cpp +++ b/src/atom/web/downloader.cpp @@ -4,68 +4,120 @@ * Copyright (C) 2023-2024 Max Qian */ -/************************************************* - -Date: 2023-4-9 - -Description: Downloader - -**************************************************/ - #include "downloader.hpp" +#include +#include #include #include +#include +#include #include +#include +#include #include "atom/error/exception.hpp" #include "atom/log/loguru.hpp" -#include "cpp_httplib/httplib.h" +#include "macro.hpp" namespace atom::web { -DownloadManager::DownloadManager(const std::string &task_file) - : task_file_(task_file) { - try { - std::ifstream infile(task_file_); - if (!infile) { - LOG_F(ERROR, "Failed to open task file {}", task_file_); - THROW_EXCEPTION("Failed to open task file."); - } - while (infile >> std::ws && !infile.eof()) { - std::string url; - std::string filepath; - infile >> url >> filepath; - if (!url.empty() && !filepath.empty()) { - tasks_.push_back({url, filepath}); - } + +class DownloadManager::Impl { +public: + explicit Impl(std::string task_file); + ~Impl(); + + void addTask(const std::string& url, const std::string& filepath, + int priority); + auto removeTask(size_t index) -> bool; + void start(size_t thread_count, size_t download_speed); + void pauseTask(size_t index); + void resumeTask(size_t index); + void cancelTask(size_t index); + auto getDownloadedBytes(size_t index) -> size_t; + void setThreadCount(size_t thread_count); + void setMaxRetries(size_t retries); + void onDownloadComplete(const std::function& callback); + void onProgressUpdate(const std::function& callback); + + struct DownloadTask { + std::string url; + std::string filepath; + bool completed{false}; + bool paused{false}; + bool cancelled{false}; + size_t downloadedBytes{0}; + int priority{0}; + size_t retries{0}; + } ATOM_ALIGNAS(128); + +private: + auto getNextTaskIndex() -> std::optional; + auto getNextTask() -> std::optional; + void downloadTask(DownloadTask& task, size_t download_speed); + void run(size_t download_speed); + void saveTaskListToFile(); + + std::string task_file_; + std::vector tasks_; + std::priority_queue task_queue_; + std::mutex mutex_; + std::atomic running_{false}; + std::chrono::system_clock::time_point start_time_; + size_t max_retries_{3}; + size_t thread_count_; + std::function on_complete_; + std::function on_progress_; +}; + +// 写入回调函数,带进度监控 +auto writeDataWithProgress(void* buffer, size_t size, size_t nmemb, + void* userp) -> size_t { + auto* taskInfo = static_cast(userp); + size_t bytesWritten = size * nmemb; + taskInfo->downloadedBytes += bytesWritten; + + return bytesWritten; +} + +DownloadManager::Impl::Impl(std::string task_file) + : task_file_(std::move(task_file)), + thread_count_(std::thread::hardware_concurrency()) { + curl_global_init(CURL_GLOBAL_DEFAULT); + + // 读取任务列表 + std::ifstream infile(task_file_); + if (!infile) { + LOG_F(ERROR, "Failed to open task file {}", task_file_); + 
THROW_EXCEPTION("Failed to open task file."); + } + while (infile >> std::ws && !infile.eof()) { + std::string url; + std::string filepath; + infile >> url >> filepath; + if (!url.empty() && !filepath.empty()) { + tasks_.emplace_back(url, filepath); } - infile.close(); - } catch (const std::exception &e) { - LOG_F(ERROR, "Error: {}", e.what()); - THROW_EXCEPTION("Error: ", e.what()); } } -DownloadManager::~DownloadManager() { save_task_list_to_file(); } +DownloadManager::Impl::~Impl() { + saveTaskListToFile(); + curl_global_cleanup(); +} -void DownloadManager::add_task(const std::string &url, - const std::string &filepath, int priority) { - try { - std::ofstream outfile(task_file_, std::ios_base::app); - if (!outfile) { - LOG_F(ERROR, "Failed to open task file {}", task_file_); - THROW_EXCEPTION("Failed to open task file."); - } - outfile << url << " " << filepath << std::endl; - outfile.close(); - } catch (const std::exception &e) { - LOG_F(ERROR, "Error: {}", e.what()); - THROW_EXCEPTION("Error: ", e.what()); +void DownloadManager::Impl::addTask(const std::string& url, + const std::string& filepath, int priority) { + std::ofstream outfile(task_file_, std::ios_base::app); + if (!outfile) { + LOG_F(ERROR, "Failed to open task file {}", task_file_); + THROW_FAIL_TO_OPEN_FILE("Failed to open task file."); } - tasks_.push_back({url, filepath, false, false, 0, priority}); + outfile << url << " " << filepath << std::endl; + tasks_.push_back({url, filepath, false, false, false, 0, priority}); } -bool DownloadManager::remove_task(size_t index) { +auto DownloadManager::Impl::removeTask(size_t index) -> bool { if (index >= tasks_.size()) { return false; } @@ -73,182 +125,225 @@ bool DownloadManager::remove_task(size_t index) { return true; } -void DownloadManager::start(size_t thread_count, size_t download_speed) { +void DownloadManager::Impl::cancelTask(size_t index) { + if (index >= tasks_.size()) { + return; + } + tasks_[index].cancelled = true; + tasks_[index].paused = true; // 停止该任务的下载 +} + +void DownloadManager::Impl::start(size_t thread_count, size_t download_speed) { running_ = true; -#if _cplusplus >= 202002L std::vector threads; -#else - std::vector threads; -#endif - for (size_t i = 0; i < thread_count; ++i) { - threads.emplace_back(&DownloadManager::run, this, download_speed); + thread_count_ = thread_count; + for (size_t i = 0; i < thread_count_; ++i) { + threads.emplace_back(&DownloadManager::Impl::run, this, download_speed); } - for (auto &thread : threads) { - thread.join(); + + for (auto& thread : threads) { + thread.join(); // 等待所有线程完成 } } -void DownloadManager::pause_task(size_t index) { +void DownloadManager::Impl::pauseTask(size_t index) { if (index >= tasks_.size()) { - LOG_F(ERROR, "Index out of bounds!"); + LOG_F(ERROR, "Index out of bounds for pause_task."); return; } - tasks_[index].paused = true; - DLOG_F(INFO, "Paused task {} - {}", tasks_[index].url, - tasks_[index].filepath); } -void DownloadManager::resume_task(size_t index) { +void DownloadManager::Impl::resumeTask(size_t index) { if (index >= tasks_.size()) { - LOG_F(ERROR, "Index out of bounds!"); + LOG_F(ERROR, "Index out of bounds for resume_task."); return; } - tasks_[index].paused = false; - - // 如果任务未完成,则重新下载 - if (!tasks_[index].completed) { - DLOG_F(INFO, "Resumed task {} - {}", tasks_[index].url, - tasks_[index].filepath); - } } -size_t DownloadManager::get_downloaded_bytes(size_t index) { +auto DownloadManager::Impl::getDownloadedBytes(size_t index) -> size_t { if (index >= tasks_.size()) { - LOG_F(ERROR, 
"Index out of bounds!"); + LOG_F(ERROR, "Index out of bounds for get_downloaded_bytes."); return 0; } + return tasks_[index].downloadedBytes; +} + +void DownloadManager::Impl::setThreadCount(size_t thread_count) { + thread_count_ = thread_count; +} - return tasks_[index].downloaded_bytes; +void DownloadManager::Impl::setMaxRetries(size_t retries) { + max_retries_ = retries; } -std::optional DownloadManager::get_next_task_index() { - std::unique_lock lock(mutex_); +void DownloadManager::Impl::onDownloadComplete( + const std::function& callback) { + on_complete_ = callback; +} + +void DownloadManager::Impl::onProgressUpdate( + const std::function& callback) { + on_progress_ = callback; +} + +auto DownloadManager::Impl::getNextTaskIndex() -> std::optional { + std::unique_lock lock(mutex_); for (size_t i = 0; i < tasks_.size(); ++i) { - if (!tasks_[i].completed && !tasks_[i].paused) { + if (!tasks_[i].completed && !tasks_[i].paused && !tasks_[i].cancelled) { return i; } } return std::nullopt; } -std::optional DownloadManager::get_next_task() { - std::unique_lock lock(mutex_); +auto +DownloadManager::Impl::getNextTask() -> std::optional { + std::unique_lock lock(mutex_); if (!task_queue_.empty()) { auto task = task_queue_.top(); task_queue_.pop(); return task; } - auto index = get_next_task_index(); - if (index) { + if (auto index = getNextTaskIndex(); index) { return tasks_[*index]; } return std::nullopt; } -void DownloadManager::run(size_t download_speed) { +void DownloadManager::Impl::run(size_t download_speed) { while (running_) { - auto task = get_next_task(); - if (!task) { - break; + auto taskOpt = getNextTask(); + if (!taskOpt) { + break; // 没有任务可执行时退出 } + auto& task = *taskOpt; - if (task->completed || task->paused) { + if (task.completed || task.paused || task.cancelled) { continue; } + + // 记录任务开始的时间 start_time_ = std::chrono::system_clock::now(); - download_task(*task, download_speed); + + // 处理下载任务 + downloadTask(task, download_speed); + + if (task.completed && on_complete_) { + on_complete_(task.priority); // 下载完成后触发回调 + } } } -void DownloadManager::download_task(DownloadTask &task, size_t download_speed) { - httplib::Client cli(task.url); - auto res = cli.Get("/"); - if (!res || res->status != 200) { - LOG_F(ERROR, "Failed to download {}", task.url); +void DownloadManager::Impl::downloadTask(DownloadTask& task, + size_t download_speed) { + CURL* curl = curl_easy_init(); + if (!curl) { + LOG_F(ERROR, "Failed to initialize curl for {}", task.url); return; } - std::ofstream outfile(task.filepath, - std::ofstream::binary | std::ofstream::app); + std::ofstream outfile(task.filepath, std::ios::binary | std::ios::app); if (!outfile) { LOG_F(ERROR, "Failed to open file {}", task.filepath); + curl_easy_cleanup(curl); return; } - // 断点续传:从已经下载的字节数开始写入文件 - outfile.seekp(task.downloaded_bytes); + // 设置 URL + curl_easy_setopt(curl, CURLOPT_URL, task.url.c_str()); - constexpr size_t kBufferSize = 1024 * 1024; // 1MB 缓存 - size_t buffer_size = kBufferSize; - if (download_speed > 0) { - // 如果有下载速度限制,根据时间计算出当前应该下载的字节数 - auto bytes_per_ms = static_cast(download_speed) / 1000.0; - auto elapsed_ms = std::chrono::duration_cast( - std::chrono::system_clock::now() - start_time_) - .count(); - auto bytes_to_download = - static_cast(elapsed_ms * bytes_per_ms) - - task.downloaded_bytes; - if (bytes_to_download < buffer_size) { - buffer_size = bytes_to_download; - } - } + // 设置写入数据回调函数 + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, writeDataWithProgress); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, &task); 
- size_t total_bytes_read = 0; + // 设置断点续传 + curl_easy_setopt(curl, CURLOPT_RESUME_FROM, task.downloadedBytes); - while (!res->body.empty() && !task.completed && !task.paused) { - size_t bytes_to_write = res->body.size() - task.downloaded_bytes; - if (bytes_to_write > buffer_size) { - bytes_to_write = buffer_size; - } + // 设置下载速度限制 + if (download_speed > 0) { + curl_easy_setopt(curl, CURLOPT_MAX_RECV_SPEED_LARGE, + static_cast(download_speed)); + } - outfile.write(res->body.c_str() + task.downloaded_bytes, - bytes_to_write); - auto bytes_written = - static_cast(outfile.tellp()) - task.downloaded_bytes; - task.downloaded_bytes += bytes_written; - total_bytes_read += bytes_written; - - // 下载速度控制:根据需要下载的字节数和已经下载的字节数计算出需要等待的时间 - if (download_speed > 0) { - auto bytes_per_ms = static_cast(download_speed) / 1000.0; - auto elapsed_ms = - static_cast(total_bytes_read) / bytes_per_ms; - std::this_thread::sleep_for( - std::chrono::milliseconds(static_cast(elapsed_ms))); + CURLcode res = curl_easy_perform(curl); + if (res != CURLE_OK) { + LOG_F(ERROR, "Download failed: {}", curl_easy_strerror(res)); + + // 错误处理,重试机制 + if (task.retries < max_retries_) { + LOG_F(INFO, "Retrying task {} ({} retries left)", task.url, + max_retries_ - task.retries); + task.retries++; + downloadTask(task, download_speed); // 重试下载任务 + } else { + LOG_F(ERROR, "Max retries reached for task {}", task.url); } - - // 如果下载完成,则标记任务为已完成 - if (task.downloaded_bytes >= res->body.size()) { - task.completed = true; + } else { + double totalSize; + curl_easy_getinfo(curl, CURLINFO_SIZE_DOWNLOAD, &totalSize); + task.downloadedBytes += static_cast(totalSize); + task.completed = true; + + // 下载进度更新回调 + if (on_progress_) { + double progress = + static_cast(task.downloadedBytes) / totalSize * 100.0; + on_progress_(task.priority, progress); // 传递任务索引和进度百分比 } } - outfile.close(); + curl_easy_cleanup(curl); +} - if (task.completed) { - DLOG_F(INFO, "Downloaded file {}", task.filepath); +void DownloadManager::Impl::saveTaskListToFile() { + std::ofstream outfile(task_file_); + if (!outfile) { + LOG_F(ERROR, "Failed to open task file {}", task_file_); + return; } -} -void DownloadManager::save_task_list_to_file() { - try { - std::ofstream outfile(task_file_); - if (!outfile) { - DLOG_F(INFO, "Failed to open task file {}", task_file_); - throw std::runtime_error("Failed to open task file."); - } - for (const auto &task : tasks_) { - outfile << task.url << " " << task.filepath << std::endl; - } - outfile.close(); - } catch (const std::exception &e) { - LOG_F(ERROR, "Error: {}", e.what()); - THROW_EXCEPTION("Error: ", e.what()); + for (const auto& task : tasks_) { + outfile << task.url << " " << task.filepath << std::endl; } } + +// DownloadManager public functions +DownloadManager::DownloadManager(const std::string& task_file) + : impl_(std::make_unique(task_file)) {} +DownloadManager::~DownloadManager() = default; +void DownloadManager::add_task(const std::string& url, + const std::string& filepath, int priority) { + impl_->addTask(url, filepath, priority); +} +auto DownloadManager::remove_task(size_t index) -> bool { + return impl_->removeTask(index); +} +void DownloadManager::start(size_t thread_count, size_t download_speed) { + impl_->start(thread_count, download_speed); +} +void DownloadManager::pause_task(size_t index) { impl_->pauseTask(index); } +void DownloadManager::resume_task(size_t index) { impl_->resumeTask(index); } +auto DownloadManager::get_downloaded_bytes(size_t index) -> size_t { + return impl_->getDownloadedBytes(index); +} 
+void DownloadManager::cancel_task(size_t index) { impl_->cancelTask(index); } +void DownloadManager::set_thread_count(size_t thread_count) { + impl_->setThreadCount(thread_count); +} +void DownloadManager::set_max_retries(size_t retries) { + impl_->setMaxRetries(retries); +} +void DownloadManager::on_download_complete( + const std::function& callback) { + impl_->onDownloadComplete(callback); +} +void DownloadManager::on_progress_update( + const std::function& callback) { + impl_->onProgressUpdate(callback); +} + } // namespace atom::web diff --git a/src/atom/web/downloader.hpp b/src/atom/web/downloader.hpp index 8a943a9f..c3102066 100644 --- a/src/atom/web/downloader.hpp +++ b/src/atom/web/downloader.hpp @@ -4,43 +4,16 @@ * Copyright (C) 2023-2024 Max Qian */ -/************************************************* - -Date: 2023-6-21 - -Description: Downloader - -**************************************************/ - #pragma once -#define DOWNLOAD_MANAGER_H - -#include -#include -#include -#include +#include +#include #include #include -#include namespace atom::web { - -struct DownloadTask { - std::string url; - std::string filepath; - bool completed{false}; - bool paused{false}; - size_t downloaded_bytes{0}; - int priority{0}; -}; - -inline bool operator<(const DownloadTask &lhs, const DownloadTask &rhs) { - return lhs.priority < rhs.priority; -} - /** - * @brief DownloadManager 类,用于管理下载任务 + * @brief DownloadManager 类,用于管理下载任务,使用 Pimpl 模式隐藏实现细节 */ class DownloadManager { public: @@ -48,7 +21,7 @@ class DownloadManager { * @brief 构造函数 * @param task_file 保存下载任务列表的文件路径 */ - DownloadManager(const std::string &task_file); + explicit DownloadManager(const std::string& task_file); /** * @brief 析构函数,用于释放资源 @@ -61,7 +34,7 @@ class DownloadManager { * @param filepath 本地保存文件路径 * @param priority 下载任务优先级,数字越大优先级越高 */ - void add_task(const std::string &url, const std::string &filepath, + void add_task(const std::string& url, const std::string& filepath, int priority = 0); /** @@ -98,148 +71,41 @@ class DownloadManager { */ size_t get_downloaded_bytes(size_t index); -private: /** - * @brief 获取下一个要下载的任务的索引 - * @return 下一个要下载的任务的索引,如果任务队列为空,则返回空 + * @brief 取消下载任务 + * @param index 要取消的任务在任务列表中的索引 */ - std::optional get_next_task_index(); + void cancel_task(size_t index); /** - * @brief 获取下一个要下载的任务 - * @return 下一个要下载的任务,如果任务队列为空,则返回空 + * @brief 动态调整下载线程数 + * @param thread_count 新的下载线程数 */ - std::optional get_next_task(); + void set_thread_count(size_t thread_count); /** - * @brief 启动下载线程 - * @param download_speed 下载速度限制,单位为字节/秒,0 表示不限制下载速度 + * @brief 设置下载错误重试次数 + * @param retries 每个任务失败时的最大重试次数 */ - void run(size_t download_speed); + void set_max_retries(size_t retries); /** - * @brief 下载指定的任务 - * @param task 要下载的任务 - * @param download_speed 下载速度限制,单位为字节/秒,0 表示不限制下载速度 + * @brief 注册下载完成回调函数 + * @param callback 下载完成时的回调函数,参数为任务索引 */ - void download_task(DownloadTask &task, size_t download_speed); + void on_download_complete(const std::function& callback); /** - * @brief 保存下载任务列表到文件中 + * @brief 注册下载进度更新回调函数 + * @param callback 下载进度更新时的回调函数,参数为任务索引和下载百分比 */ - void save_task_list_to_file(); + void on_progress_update( + const std::function& callback); + class Impl; private: - std::string task_file_; ///< 下载任务列表文件路径 - std::vector tasks_; ///< 下载任务列表 - std::priority_queue - task_queue_; ///< 任务队列,按照优先级排序 - std::mutex mutex_; ///< 互斥量,用于保护任务列表和任务队列 - std::atomic running_{false}; ///< 是否正在下载中 - std::chrono::system_clock::time_point start_time_; + + std::unique_ptr impl_; ///< 使用 Pimpl 隐藏实现细节 }; -} // namespace atom::web 
-/* -#include -#include -#include "downloader.hpp" - -int main() -{ - DownloadManager download_manager("tasks.txt"); - - while (true) - { - std::cout << "1. Add task" << std::endl; - std::cout << "2. Pause task" << std::endl; - std::cout << "3. Resume task" << std::endl; - std::cout << "4. Remove task" << std::endl; - std::cout << "5. Start downloading" << std::endl; - std::cout << "6. Exit" << std::endl; - - int choice; - std::cout << "Please select an option: "; - std::cin >> choice; - - if (choice == 1) - { - std::string url, filepath; - int priority; - - std::cout << "URL: "; - std::cin >> url; - - std::cout << "Filepath: "; - std::cin >> filepath; - - std::cout << "Priority (-1 for default): "; - std::cin >> priority; - - download_manager.add_task(url, filepath, priority); - std::cout << "Task added." << std::endl; - } - else if (choice == 2) - { - int index; - std::cout << "Task index: "; - std::cin >> index; - - if (download_manager.pause_task(index)) - { - std::cout << "Task paused." << std::endl; - } - else - { - std::cout << "Failed to pause task." << std::endl; - } - } - else if (choice == 3) - { - int index; - std::cout << "Task index: "; - std::cin >> index; - - if (download_manager.resume_task(index)) - { - std::cout << "Task resumed." << std::endl; - } - else - { - std::cout << "Failed to resume task." << std::endl; - } - } - else if (choice == 4) - { - int index; - std::cout << "Task index: "; - std::cin >> index; - - if (download_manager.remove_task(index)) - { - std::cout << "Task removed." << std::endl; - } - else - { - std::cout << "Failed to remove task." << std::endl; - } - } - else if (choice == 5) - { - download_manager.start(); - break; - } - else if (choice == 6) - { - break; - } - else - { - std::cout << "Invalid choice, please try again." 
<< std::endl; - } - } - - return 0; -} - -*/ +} // namespace atom::web diff --git a/src/atom/web/httplite.cpp b/src/atom/web/httplite.cpp deleted file mode 100644 index 28cf4d65..00000000 --- a/src/atom/web/httplite.cpp +++ /dev/null @@ -1,234 +0,0 @@ -/* - * httplite.hpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2023-11-3 - -Description: Simple Http Client - -**************************************************/ - -#include "httplite.hpp" - -#include -#include - -#include "atom/log/loguru.hpp" - -namespace atom::web { - -HttpClient::HttpClient() : socketfd(0) {} - -HttpClient::~HttpClient() { closeSocket(); } - -void HttpClient::setErrorHandler( - std::function errorHandler) { - this->errorHandler = errorHandler; -} - -bool HttpClient::initialize() { -#ifdef _WIN32 - WSADATA wsaData; - if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) { - errorHandler("Failed to initialize Winsock"); - return false; - } -#endif - socketfd = socket(AF_INET, SOCK_STREAM, 0); - if (socketfd == -1) { - errorHandler("Failed to create socket"); - return false; - } - return true; -} - -bool HttpClient::connectToServer(const std::string &host, int port, - bool useHttps) { - struct sockaddr_in serverAddr {}; - serverAddr.sin_family = AF_INET; - serverAddr.sin_port = htons(port); - -#ifdef _WIN32 - if (InetPton(AF_INET, host.c_str(), &serverAddr.sin_addr) <= 0) -#else - if (inet_pton(AF_INET, host.c_str(), &(serverAddr.sin_addr)) <= 0) -#endif - { - errorHandler("Failed to parse server address"); - closeSocket(); - return false; - } - - if (connect(socketfd, (struct sockaddr *)&serverAddr, sizeof(serverAddr)) < - 0) { - errorHandler("Failed to connect to server"); - closeSocket(); - return false; - } - - if (useHttps) { - SSL_load_error_strings(); - SSL_library_init(); - SSL_CTX *ctx = SSL_CTX_new(SSLv23_client_method()); - SSL *ssl = SSL_new(ctx); - SSL_set_fd(ssl, socketfd); - if (SSL_connect(ssl) != 1) { - errorHandler("Failed to establish SSL connection"); - closeSocket(); - return false; - } - } - - return true; -} - -bool HttpClient::sendRequest(const std::string &request) { - if (send(socketfd, request.c_str(), request.length(), 0) < 0) { - errorHandler("Failed to send request"); - closeSocket(); - return false; - } - return true; -} - -HttpResponse HttpClient::receiveResponse() { - HttpResponse response; - - char buffer[4096]; - while (true) { - memset(buffer, 0, sizeof(buffer)); - int bytesRead = recv(socketfd, buffer, sizeof(buffer) - 1, 0); - if (bytesRead <= 0) { - break; - } - response.body += buffer; - } - - closeSocket(); - return response; -} - -void HttpClient::closeSocket() { - if (socketfd != 0) { -#ifdef _WIN32 - closesocket(socketfd); - WSACleanup(); -#else - close(socketfd); -#endif - socketfd = 0; - } -} - -HttpRequestBuilder::HttpRequestBuilder(HttpMethod method, - const std::string &url) - : method(method), url(url), timeout(10) {} - -HttpRequestBuilder &HttpRequestBuilder::setBody(const std::string &bodyText) { - body = bodyText; - return *this; -} - -HttpRequestBuilder &HttpRequestBuilder::setContentType( - const std::string &contentTypeValue) { - contentType = contentTypeValue; - return *this; -} - -HttpRequestBuilder &HttpRequestBuilder::setTimeout( - std::chrono::seconds timeoutValue) { - timeout = timeoutValue; - return *this; -} - -HttpRequestBuilder &HttpRequestBuilder::addHeader(const std::string &key, - const std::string &value) { - headers[key] = value; - return *this; -} - -HttpResponse HttpRequestBuilder::send() { - 
HttpClient client; - client.setErrorHandler([](const std::string &error) { - std::cerr << "Error: " << error << std::endl; - }); - - if (!client.initialize()) { - return HttpResponse{}; - } - - std::string host; - std::string path; - bool useHttps = false; - size_t pos = url.find("://"); - if (pos != std::string::npos) { - std::string protocol = url.substr(0, pos); - if (protocol == "https") { - useHttps = true; - } - pos += 3; - size_t slashPos = url.find('/', pos); - if (slashPos != std::string::npos) { - host = url.substr(pos, slashPos - pos); - path = url.substr(slashPos); - } else { - std::cerr << "Invalid URL" << std::endl; - return HttpResponse{}; - } - } else { - std::cerr << "Invalid URL" << std::endl; - return HttpResponse{}; - } - - if (!client.connectToServer(host, useHttps ? 443 : 80, useHttps)) { - return HttpResponse{}; - } - - std::string request = buildRequestString(host, path); - if (!client.sendRequest(request)) { - return HttpResponse{}; - } - - return client.receiveResponse(); -} - -std::string HttpRequestBuilder::buildRequestString(const std::string &host, - const std::string &path) { - std::string request; - switch (method) { - case HttpMethod::GET: - request = "GET "; - break; - case HttpMethod::POST: - request = "POST "; - break; - case HttpMethod::PUT: - request = "PUT "; - break; - case HttpMethod::DELETE: - request = "DELETE "; - break; - } - request += path + " HTTP/1.1\r\n"; - request += "Host: " + host + "\r\n"; - for (const auto &header : headers) { - request += header.first + ": " + header.second + "\r\n"; - } - if (!contentType.empty()) { - request += "Content-Type: " + contentType + "\r\n"; - } - if (!body.empty()) { - request += "Content-Length: " + std::to_string(body.length()) + "\r\n"; - } - request += "Connection: close\r\n\r\n"; - if (!body.empty()) { - request += body; - } - - return request; -} -} // namespace atom::web diff --git a/src/atom/web/httplite.hpp b/src/atom/web/httplite.hpp deleted file mode 100644 index 30170d17..00000000 --- a/src/atom/web/httplite.hpp +++ /dev/null @@ -1,175 +0,0 @@ -/* - * httplite.hpp - * - * Copyright (C) 2023-2024 Max Qian - */ - -/************************************************* - -Date: 2023-11-3 - -Description: Simple Http Client - -**************************************************/ - -#ifndef ATOM_WEB_HTTPLITE_HPP -#define ATOM_WEB_HTTPLITE_HPP - -#include -#include -#include -#include -#include -#include - -#ifdef _WIN32 -#include -#include -#pragma comment(lib, "ws2_32.lib") -#undef DELETE -#else -#include -#include -#include -#define SOCKET int -#endif - -namespace atom::web { - -enum class HttpMethod { GET, POST, PUT, DELETE }; - -struct HttpResponse { - std::string body; - int statusCode; - std::string statusMessage; -}; - -/** - * @brief HttpClient 类用于与服务器建立连接并发送请求接收响应。 - */ -class HttpClient { -public: - /** - * @brief HttpClient 构造函数 - */ - HttpClient(); - - /** - * @brief HttpClient 析构函数 - */ - ~HttpClient(); - - /** - * @brief 设置错误处理函数 - * @param errorHandler 错误处理函数 - */ - void setErrorHandler(std::function errorHandler); - - /** - * @brief 初始化 HttpClient - * @return 初始化是否成功 - */ - bool initialize(); - - /** - * @brief 连接到服务器 - * @param host 服务器主机名 - * @param port 服务器端口号 - * @param useHttps 是否使用 HTTPS - * @return 连接是否成功 - */ - bool connectToServer(const std::string &host, int port, bool useHttps); - - /** - * @brief 发送请求到服务器 - * @param request 请求内容 - * @return 是否成功发送请求 - */ - bool sendRequest(const std::string &request); - - /** - * @brief 接收服务器响应 - * @return 服务器响应 - */ - HttpResponse 
receiveResponse(); - -private: - /** - * @brief 关闭 socket 连接 - */ - void closeSocket(); - -private: - SOCKET socketfd; /**< socket 文件描述符 */ - std::function errorHandler; /**< 错误处理函数 */ -}; - -/** - * @brief HttpRequestBuilder 类用于构建 HTTP 请求。 - */ -class HttpRequestBuilder { -public: - /** - * @brief HttpRequestBuilder 构造函数 - * @param method HTTP 请求方法 - * @param url 请求 URL - */ - HttpRequestBuilder(HttpMethod method, const std::string &url); - - /** - * @brief 设置请求体 - * @param bodyText 请求体内容 - * @return 当前 HttpRequestBuilder 实例 - */ - HttpRequestBuilder &setBody(const std::string &bodyText); - - /** - * @brief 设置请求内容类型 - * @param contentTypeValue 内容类型值 - * @return 当前 HttpRequestBuilder 实例 - */ - HttpRequestBuilder &setContentType(const std::string &contentTypeValue); - - /** - * @brief 设置超时时间 - * @param timeoutValue 超时时间 - * @return 当前 HttpRequestBuilder 实例 - */ - HttpRequestBuilder &setTimeout(std::chrono::seconds timeoutValue); - - /** - * @brief 添加请求头 - * @param key 请求头键 - * @param value 请求头值 - * @return 当前 HttpRequestBuilder 实例 - */ - HttpRequestBuilder &addHeader(const std::string &key, - const std::string &value); - - /** - * @brief 发送 HTTP 请求 - * @return HTTP 响应 - */ - HttpResponse send(); - -private: - /** - * @brief 构建请求字符串 - * @param host 主机名 - * @param path 路径 - * @return 构建的请求字符串 - */ - std::string buildRequestString(const std::string &host, - const std::string &path); - -private: - HttpMethod method; /**< HTTP 请求方法 */ - std::string url; /**< 请求 URL */ - std::string body; /**< 请求体 */ - std::string contentType; /**< 内容类型 */ - std::chrono::seconds timeout; /**< 超时时间 */ - std::map headers; /**< 请求头映射 */ -}; -} // namespace atom::web - -#endif diff --git a/src/atom/web/httpparser.cpp b/src/atom/web/httpparser.cpp index f6d68867..34b67554 100644 --- a/src/atom/web/httpparser.cpp +++ b/src/atom/web/httpparser.cpp @@ -8,22 +8,26 @@ Date: 2023-11-3 -Description: Http Header Parser +Description: Http Header Parser with C++20 features **************************************************/ #include "httpparser.hpp" +#include // C++20 algorithms #include +#include // C++20 ranges #include namespace atom::web { -class HttpHeaderParserImpl { + +class HttpHeaderParser::HttpHeaderParserImpl { public: std::map> headers_; }; -HttpHeaderParser::HttpHeaderParser() : m_pImpl(std::make_unique()) {} +HttpHeaderParser::HttpHeaderParser() + : m_pImpl(std::make_unique()) {} void HttpHeaderParser::parseHeaders(const std::string &rawHeaders) { m_pImpl->headers_.clear(); @@ -54,37 +58,30 @@ void HttpHeaderParser::setHeaders( m_pImpl->headers_ = headers; } -std::vector HttpHeaderParser::getHeaderValues( +void HttpHeaderParser::addHeaderValue(const std::string &key, + const std::string &value) { + m_pImpl->headers_[key].push_back(value); +} + +std::optional> HttpHeaderParser::getHeaderValues( const std::string &key) const { - auto it = m_pImpl->headers_.find(key); - if (it != m_pImpl->headers_.end()) { + if (auto it = m_pImpl->headers_.find(key); it != m_pImpl->headers_.end()) { return it->second; - } else { - return {}; } + return std::nullopt; // Use optional to represent missing values } void HttpHeaderParser::removeHeader(const std::string &key) { m_pImpl->headers_.erase(key); } -void HttpHeaderParser::printHeaders() const { - for (const auto &entry : m_pImpl->headers_) { - std::cout << entry.first << ": "; - for (const auto &value : entry.second) { - std::cout << value << ", "; - } - std::cout << std::endl; - } -} - std::map> HttpHeaderParser::getAllHeaders() const { return m_pImpl->headers_; } bool 
HttpHeaderParser::hasHeader(const std::string &key) const { - return m_pImpl->headers_.find(key) != m_pImpl->headers_.end(); + return m_pImpl->headers_.contains(key); // Use C++20 contains method } void HttpHeaderParser::clearHeaders() { m_pImpl->headers_.clear(); } diff --git a/src/atom/web/httpparser.hpp b/src/atom/web/httpparser.hpp index f5989166..553f98d9 100644 --- a/src/atom/web/httpparser.hpp +++ b/src/atom/web/httpparser.hpp @@ -8,7 +8,7 @@ Date: 2023-11-3 -Description: Http Header Parser +Description: Http Header Parser with C++20 features **************************************************/ @@ -17,13 +17,11 @@ Description: Http Header Parser #include #include +#include #include #include - namespace atom::web { -class HttpHeaderParserImpl; - /** * @brief The HttpHeaderParser class is responsible for parsing and manipulating * HTTP headers. @@ -55,12 +53,20 @@ class HttpHeaderParser { void setHeaders( const std::map> &headers); + /** + * @brief Adds a new value to an existing header field. + * @param key The key of the header field. + * @param value The value to add. + */ + void addHeaderValue(const std::string &key, const std::string &value); + /** * @brief Retrieves the values of a specific header field. * @param key The key of the header field. * @return A vector containing the values of the header field. */ - std::vector getHeaderValues(const std::string &key) const; + [[nodiscard]] auto getHeaderValues(const std::string &key) const + -> std::optional>; /** * @brief Removes a specific header field. @@ -68,23 +74,19 @@ class HttpHeaderParser { */ void removeHeader(const std::string &key); - /** - * @brief Prints all the parsed headers to the console. - */ - void printHeaders() const; - /** * @brief Retrieves all the parsed headers. * @return A map containing all the parsed headers. */ - std::map> getAllHeaders() const; + [[nodiscard]] auto getAllHeaders() const + -> std::map>; /** * @brief Checks if a specific header field exists. * @param key The key of the header field to check. * @return True if the header field exists, false otherwise. */ - bool hasHeader(const std::string &key) const; + [[nodiscard]] auto hasHeader(const std::string &key) const -> bool; /** * @brief Clears all the parsed headers. 
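Because getHeaderValues() now returns std::optional, callers can distinguish a missing header from one that is present with an empty value list. A minimal usage sketch follows; the include path is illustrative, and it assumes parseHeaders() accepts standard "Key: Value" lines.

#include <iostream>
// Include path is illustrative; adjust to the project's layout.
#include "atom/web/httpparser.hpp"

int main() {
    atom::web::HttpHeaderParser parser;
    parser.parseHeaders("Content-Type: text/html\r\nSet-Cookie: a=1\r\n");
    parser.addHeaderValue("Set-Cookie", "b=2");

    // std::optional makes the "header not present" case explicit.
    if (auto values = parser.getHeaderValues("Set-Cookie")) {
        for (const auto& value : *values) {
            std::cout << "Set-Cookie: " << value << '\n';
        }
    }
    return parser.hasHeader("Content-Type") ? 0 : 1;
}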
@@ -92,6 +94,7 @@ class HttpHeaderParser { void clearHeaders(); private: + class HttpHeaderParserImpl; std::unique_ptr m_pImpl; // Pointer to implementation }; } // namespace atom::web diff --git a/src/client/alpaca/covercalibrator.cpp b/src/client/alpaca/covercalibrator.cpp new file mode 100644 index 00000000..e9a66ef9 --- /dev/null +++ b/src/client/alpaca/covercalibrator.cpp @@ -0,0 +1,77 @@ +#include "covercalibrator.hpp" +#include +#include + +AlpacaCoverCalibrator::AlpacaCoverCalibrator(const std::string& address, + int device_number, + const std::string& protocol) + : AlpacaDevice(address, "covercalibrator", device_number, protocol) {} + +int AlpacaCoverCalibrator::GetBrightness() { + return GetNumericProperty("brightness"); +} + +AlpacaCoverCalibrator::CalibratorStatus +AlpacaCoverCalibrator::GetCalibratorState() { + return static_cast( + GetNumericProperty("calibratorstate")); +} + +AlpacaCoverCalibrator::CoverStatus AlpacaCoverCalibrator::GetCoverState() { + return static_cast(GetNumericProperty("coverstate")); +} + +int AlpacaCoverCalibrator::GetMaxBrightness() { + return GetNumericProperty("maxbrightness"); +} + +template +std::future AlpacaCoverCalibrator::AsyncOperation( + Func&& func, const std::string& operationName) { + return std::async( + std::launch::async, + [this, func = std::forward(func), operationName]() { + if (m_current_operation.valid() && + m_current_operation.wait_for(std::chrono::seconds(0)) != + std::future_status::ready) { + throw std::runtime_error("Another operation is in progress"); + } + + func(); + + // Poll the device state until the operation is complete + while (true) { + auto state = operationName == "calibrator" + ? GetCalibratorState() + : GetCoverState(); + if (state != CalibratorStatus::NotReady && + state != CoverStatus::Moving) { + break; + } + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + }); +} + +std::future AlpacaCoverCalibrator::CalibratorOff() { + return AsyncOperation([this]() { Put("calibratoroff"); }, "calibrator"); +} + +std::future AlpacaCoverCalibrator::CalibratorOn(int BrightnessVal) { + return AsyncOperation( + [this, BrightnessVal]() { + Put("calibratoron", + {{"Brightness", std::to_string(BrightnessVal)}}); + }, + "calibrator"); +} + +std::future AlpacaCoverCalibrator::CloseCover() { + return AsyncOperation([this]() { Put("closecover"); }, "cover"); +} + +void AlpacaCoverCalibrator::HaltCover() { Put("haltcover"); } + +std::future AlpacaCoverCalibrator::OpenCover() { + return AsyncOperation([this]() { Put("opencover"); }, "cover"); +} diff --git a/src/client/alpaca/covercalibrator.hpp b/src/client/alpaca/covercalibrator.hpp new file mode 100644 index 00000000..8b2e2751 --- /dev/null +++ b/src/client/alpaca/covercalibrator.hpp @@ -0,0 +1,52 @@ +#pragma once + +#include +#include +#include + +#include "device.hpp" + +class AlpacaCoverCalibrator : public AlpacaDevice { +public: + enum class CalibratorStatus { + NotPresent = 0, + Off = 1, + NotReady = 2, + Ready = 3, + Unknown = 4, + Error = 5 + }; + + enum class CoverStatus { + NotPresent = 0, + Closed = 1, + Moving = 2, + Open = 3, + Unknown = 4, + Error = 5 + }; + + AlpacaCoverCalibrator(const std::string& address, int device_number, + const std::string& protocol = "http"); + virtual ~AlpacaCoverCalibrator() = default; + + // Properties + int GetBrightness(); + CalibratorStatus GetCalibratorState(); + CoverStatus GetCoverState(); + int GetMaxBrightness(); + + // Methods + std::future CalibratorOff(); + std::future CalibratorOn(int BrightnessVal); + 
std::future CloseCover(); + void HaltCover(); + std::future OpenCover(); + +private: + template + std::future AsyncOperation(Func&& func, + const std::string& operationName); + + std::future m_current_operation; +}; diff --git a/src/client/alpaca/device.cpp b/src/client/alpaca/device.cpp new file mode 100644 index 00000000..362fcbc9 --- /dev/null +++ b/src/client/alpaca/device.cpp @@ -0,0 +1,209 @@ +#include "device.hpp" +#include +#include +#include + +int AlpacaDevice::s_client_id = std::random_device{}() % 65536; +int AlpacaDevice::s_client_trans_id = 1; +std::mutex AlpacaDevice::s_ctid_mutex; + +AlpacaDevice::AlpacaDevice(const std::string& address, + const std::string& device_type, int device_number, + const std::string& protocol) + : m_address(address), + m_device_type(device_type), + m_device_number(device_number), + m_api_version(1) { + m_base_url = std::format("{}://{}/api/v{}/{}/{}", protocol, address, + m_api_version, device_type, device_number); + + m_curl = std::unique_ptr>( + curl_easy_init(), [](CURL* curl) { curl_easy_cleanup(curl); }); + if (!m_curl) { + throw std::runtime_error("Failed to initialize CURL"); + } +} + +AlpacaDevice::~AlpacaDevice() = default; + +std::string AlpacaDevice::Action(const std::string& ActionName, + const std::vector& Parameters) { + nlohmann::json params = {{"Action", ActionName}, + {"Parameters", Parameters}}; + return Put("action", params)["Value"]; +} + +void AlpacaDevice::CommandBlind(const std::string& Command, bool Raw) { + Put("commandblind", + {{"Command", Command}, {"Raw", Raw ? "true" : "false"}}); +} + +bool AlpacaDevice::CommandBool(const std::string& Command, bool Raw) { + return Put("commandbool", {{"Command", Command}, + {"Raw", Raw ? "true" : "false"}})["Value"]; +} + +std::string AlpacaDevice::CommandString(const std::string& Command, bool Raw) { + return Put("commandstring", {{"Command", Command}, + {"Raw", Raw ? "true" : "false"}})["Value"]; +} + +bool AlpacaDevice::GetConnected() { return Get("connected"); } + +void AlpacaDevice::SetConnected(bool ConnectedState) { + Put("connected", {{"Connected", ConnectedState ? "true" : "false"}}); +} + +std::string AlpacaDevice::GetDescription() { return Get("description"); } + +std::vector AlpacaDevice::GetDriverInfo() { + std::string info = Get("driverinfo"); + std::vector result; + std::istringstream iss(info); + std::string item; + while (std::getline(iss, item, ',')) { + result.push_back(item); + } + return result; +} + +std::string AlpacaDevice::GetDriverVersion() { return Get("driverversion"); } + +int AlpacaDevice::GetInterfaceVersion() { + return std::stoi(Get("interfaceversion").get()); +} + +std::string AlpacaDevice::GetName() { return Get("name"); } + +std::vector AlpacaDevice::GetSupportedActions() { + return Get("supportedactions").get>(); +} + +nlohmann::json AlpacaDevice::Get( + const std::string& attribute, + const std::map& params) { + std::string url = m_base_url + "/" + attribute; + + std::string query_string; + for (const auto& [key, value] : params) { + if (!query_string.empty()) + query_string += "&"; + query_string += key + "=" + value; + } + + { + std::lock_guard lock(s_ctid_mutex); + if (!query_string.empty()) + query_string += "&"; + query_string += std::format("ClientTransactionID={}&ClientID={}", + s_client_trans_id++, s_client_id); + } + + if (!query_string.empty()) { + url += "?" 
+ query_string; + } + + std::string response_string; + curl_easy_setopt(m_curl.get(), CURLOPT_URL, url.c_str()); + curl_easy_setopt(m_curl.get(), CURLOPT_WRITEFUNCTION, WriteCallback); + curl_easy_setopt(m_curl.get(), CURLOPT_WRITEDATA, &response_string); + + CURLcode res = curl_easy_perform(m_curl.get()); + if (res != CURLE_OK) { + throw std::runtime_error( + std::format("CURL error: {}", curl_easy_strerror(res))); + } + + nlohmann::json response = nlohmann::json::parse(response_string); + CheckError(response); + return response["Value"]; +} + +nlohmann::json AlpacaDevice::Put( + const std::string& attribute, + const std::map& data) { + std::string url = m_base_url + "/" + attribute; + + std::string post_fields; + for (const auto& [key, value] : data) { + if (!post_fields.empty()) + post_fields += "&"; + post_fields += key + "=" + value; + } + + { + std::lock_guard lock(s_ctid_mutex); + if (!post_fields.empty()) + post_fields += "&"; + post_fields += std::format("ClientTransactionID={}&ClientID={}", + s_client_trans_id++, s_client_id); + } + + std::string response_string; + curl_easy_setopt(m_curl.get(), CURLOPT_URL, url.c_str()); + curl_easy_setopt(m_curl.get(), CURLOPT_POSTFIELDS, post_fields.c_str()); + curl_easy_setopt(m_curl.get(), CURLOPT_WRITEFUNCTION, WriteCallback); + curl_easy_setopt(m_curl.get(), CURLOPT_WRITEDATA, &response_string); + + CURLcode res = curl_easy_perform(m_curl.get()); + if (res != CURLE_OK) { + throw std::runtime_error( + std::format("CURL error: {}", curl_easy_strerror(res))); + } + + nlohmann::json response = nlohmann::json::parse(response_string); + CheckError(response); + return response; +} + +size_t AlpacaDevice::WriteCallback(void* contents, size_t size, size_t nmemb, + std::string* s) { + size_t newLength = size * nmemb; + try { + s->append((char*)contents, newLength); + } catch (std::bad_alloc& e) { + return 0; + } + return newLength; +} + +void AlpacaDevice::CheckError(const nlohmann::json& response) { + int errorNumber = response["ErrorNumber"]; + std::string errorMessage = response["ErrorMessage"]; + + if (errorNumber != 0) { + switch (errorNumber) { + case 0x0400: + throw std::runtime_error("NotImplementedException: " + + errorMessage); + case 0x0401: + throw std::invalid_argument("InvalidValueException: " + + errorMessage); + case 0x0402: + throw std::runtime_error("ValueNotSetException: " + + errorMessage); + case 0x0407: + throw std::runtime_error("NotConnectedException: " + + errorMessage); + case 0x0408: + throw std::runtime_error("ParkedException: " + errorMessage); + case 0x0409: + throw std::runtime_error("SlavedException: " + errorMessage); + case 0x040B: + throw std::runtime_error("InvalidOperationException: " + + errorMessage); + case 0x040C: + throw std::runtime_error("ActionNotImplementedException: " + + errorMessage); + default: + if (errorNumber >= 0x500 && errorNumber <= 0xFFF) { + throw std::runtime_error(std::format( + "DriverException: ({}) {}", errorNumber, errorMessage)); + } else { + throw std::runtime_error( + std::format("UnknownException: ({}) {}", errorNumber, + errorMessage)); + } + } + } +} diff --git a/src/client/alpaca/device.hpp b/src/client/alpaca/device.hpp new file mode 100644 index 00000000..97aa1410 --- /dev/null +++ b/src/client/alpaca/device.hpp @@ -0,0 +1,77 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "atom/function/concept.hpp" + +template +concept JsonCompatible = requires(nlohmann::json j, T t) { + { j.get() } -> std::convertible_to; +}; + 
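+// Editorial note (illustrative, not part of the original patch): JsonCompatible is
+// satisfied by any T that nlohmann::json can convert to via j.get<T>() -- e.g. int,
+// double, bool and std::string all qualify -- and it could be used to constrain the
+// numeric/array helpers of AlpacaDevice below, for instance:
+//
+//     template <JsonCompatible T>
+//     T GetNumericProperty(const std::string& property_name);
+//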
+class AlpacaDevice { +public: + AlpacaDevice(const std::string& address, const std::string& device_type, + int device_number, const std::string& protocol); + virtual ~AlpacaDevice(); + + // Common interface methods + std::string Action(const std::string& ActionName, + const std::vector& Parameters = {}); + void CommandBlind(const std::string& Command, bool Raw); + bool CommandBool(const std::string& Command, bool Raw); + std::string CommandString(const std::string& Command, bool Raw); + + bool GetConnected(); + void SetConnected(bool ConnectedState); + std::string GetDescription(); + std::vector GetDriverInfo(); + std::string GetDriverVersion(); + int GetInterfaceVersion(); + std::string GetName(); + std::vector GetSupportedActions(); + + template + T GetNumericProperty(const std::string& property_name) { + return std::get(Get(property_name)); + } + + template + std::vector GetArrayProperty( + const std::string& property, + const std::map& parameters = {}) { + nlohmann::json response = Get(property, parameters); + return response["Value"].get>(); + } + +protected: + // HTTP methods + nlohmann::json Get(const std::string& attribute, + const std::map& params = {}); + nlohmann::json Put(const std::string& attribute, + const std::map& data = {}); + +private: + static size_t WriteCallback(void* contents, size_t size, size_t nmemb, + std::string* s); + void CheckError(const nlohmann::json& response); + + std::string m_address; + std::string m_device_type; + int m_device_number; + int m_api_version; + std::string m_base_url; + + static int s_client_id; + static int s_client_trans_id; + static std::mutex s_ctid_mutex; + + std::unique_ptr> m_curl; +}; diff --git a/src/client/alpaca/discovery.cpp b/src/client/alpaca/discovery.cpp new file mode 100644 index 00000000..3183f87d --- /dev/null +++ b/src/client/alpaca/discovery.cpp @@ -0,0 +1,226 @@ +#include "discovery.hpp" + +#ifdef _WIN32 +#include +#include +#include +#pragma comment(lib, "ws2_32.lib") +#pragma comment(lib, "iphlpapi.lib") +#else +#include +#include +#include +#include +#include +#include +#include +#endif + +#include +#include +#include +#include +#include +#include + +using json = nlohmann::json; + +class AlpacaDiscovery::Impl { +public: + static constexpr int PORT = 32227; + static constexpr const char* ALPACA_DISCOVERY = "alpacadiscovery1"; + static constexpr const char* ALPACA_RESPONSE = "AlpacaPort"; + + std::vector searchIPv4(int numQuery, int timeout); + +private: + std::vector getInterfaces(); + void sendBroadcast(int sock, const std::string& interface); + void receiveResponses(int sock, std::vector& addresses); +}; + +AlpacaDiscovery::AlpacaDiscovery() : pImpl(std::make_unique()) {} + +AlpacaDiscovery::~AlpacaDiscovery() = default; + +std::vector AlpacaDiscovery::searchIPv4(int numQuery, + int timeout) { + return pImpl->searchIPv4(numQuery, timeout); +} + +std::vector AlpacaDiscovery::Impl::searchIPv4(int numQuery, + int timeout) { + std::vector addresses; + +#ifdef _WIN32 + WSADATA wsaData; + if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) { + throw std::runtime_error("Failed to initialize Winsock"); + } +#endif + + int sock = socket(AF_INET, SOCK_DGRAM, 0); + if (sock == -1) { + throw std::runtime_error("Failed to create socket"); + } + + int broadcastEnable = 1; + if (setsockopt(sock, SOL_SOCKET, SO_BROADCAST, + reinterpret_cast(&broadcastEnable), + sizeof(broadcastEnable)) < 0) { + throw std::runtime_error("Failed to set socket option"); + } + + struct timeval tv; + tv.tv_sec = timeout; + tv.tv_usec = 0; + if 
(setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast(&tv), + sizeof(tv)) < 0) { + throw std::runtime_error("Failed to set socket timeout"); + } + + std::vector interfaces = getInterfaces(); + + for (int i = 0; i < numQuery; ++i) { + for (const auto& interface : interfaces) { + sendBroadcast(sock, interface); + } + + receiveResponses(sock, addresses); + } + +#ifdef _WIN32 + closesocket(sock); + WSACleanup(); +#else + close(sock); +#endif + + return addresses; +} + +std::vector AlpacaDiscovery::Impl::getInterfaces() { + std::vector interfaces; + +#ifdef _WIN32 + ULONG bufferSize = 15000; + PIP_ADAPTER_ADDRESSES pAddresses = nullptr; + DWORD retVal = 0; + + do { + pAddresses = (IP_ADAPTER_ADDRESSES*)malloc(bufferSize); + if (pAddresses == nullptr) { + throw std::runtime_error( + "Memory allocation failed for IP_ADAPTER_ADDRESSES struct"); + } + + retVal = GetAdaptersAddresses(AF_INET, GAA_FLAG_INCLUDE_PREFIX, nullptr, + pAddresses, &bufferSize); + + if (retVal == ERROR_BUFFER_OVERFLOW) { + free(pAddresses); + pAddresses = nullptr; + } + } while (retVal == ERROR_BUFFER_OVERFLOW); + + if (retVal == NO_ERROR) { + for (PIP_ADAPTER_ADDRESSES pCurrAddresses = pAddresses; + pCurrAddresses != nullptr; pCurrAddresses = pCurrAddresses->Next) { + PIP_ADAPTER_UNICAST_ADDRESS pUnicast = + pCurrAddresses->FirstUnicastAddress; + if (pUnicast != nullptr) { + for (; pUnicast != nullptr; pUnicast = pUnicast->Next) { + sockaddr_in* pAddr = reinterpret_cast( + pUnicast->Address.lpSockaddr); + char ip[INET_ADDRSTRLEN]; + inet_ntop(AF_INET, &(pAddr->sin_addr), ip, INET_ADDRSTRLEN); + interfaces.push_back(ip); + } + } + } + } + + if (pAddresses) { + free(pAddresses); + } +#else + struct ifaddrs* ifAddrStruct = nullptr; + struct ifaddrs* ifa = nullptr; + + if (getifaddrs(&ifAddrStruct) == -1) { + throw std::runtime_error("Failed to get network interfaces"); + } + + for (ifa = ifAddrStruct; ifa != nullptr; ifa = ifa->ifa_next) { + if (ifa->ifa_addr == nullptr) + continue; + + if (ifa->ifa_addr->sa_family == AF_INET) { + void* tmpAddrPtr = &((struct sockaddr_in*)ifa->ifa_addr)->sin_addr; + char addressBuffer[INET_ADDRSTRLEN]; + inet_ntop(AF_INET, tmpAddrPtr, addressBuffer, INET_ADDRSTRLEN); + interfaces.push_back(addressBuffer); + } + } + + if (ifAddrStruct != nullptr) + freeifaddrs(ifAddrStruct); +#endif + + return interfaces; +} + +void AlpacaDiscovery::Impl::sendBroadcast(int sock, + const std::string& interface) { + struct sockaddr_in broadcastAddr; + memset(&broadcastAddr, 0, sizeof(broadcastAddr)); + broadcastAddr.sin_family = AF_INET; + broadcastAddr.sin_port = htons(PORT); + + if (interface == "127.0.0.1") { + broadcastAddr.sin_addr.s_addr = inet_addr("127.255.255.255"); + } else { + broadcastAddr.sin_addr.s_addr = inet_addr("255.255.255.255"); + } + + if (sendto(sock, ALPACA_DISCOVERY, strlen(ALPACA_DISCOVERY), 0, + (struct sockaddr*)&broadcastAddr, sizeof(broadcastAddr)) == -1) { + std::cerr << "Failed to send broadcast on interface " + << interface << std::endl; + } +} + +void AlpacaDiscovery::Impl::receiveResponses( + int sock, std::vector& addresses) { + char buffer[1024]; + struct sockaddr_in senderAddr; + socklen_t senderAddrLen = sizeof(senderAddr); + + while (true) { + int bytesReceived = + recvfrom(sock, buffer, sizeof(buffer) - 1, 0, + (struct sockaddr*)&senderAddr, &senderAddrLen); + if (bytesReceived == -1) { + break; // Timeout or error + } + + buffer[bytesReceived] = '\0'; + + try { + json response = json::parse(buffer); + int alpacaPort = response[ALPACA_RESPONSE]; + char 
senderIP[INET_ADDRSTRLEN]; + inet_ntop(AF_INET, &(senderAddr.sin_addr), senderIP, + INET_ADDRSTRLEN); + + std::string addressPort = + std::format("{}:{}", senderIP, alpacaPort); + if (std::find(addresses.begin(), addresses.end(), addressPort) == + addresses.end()) { + addresses.push_back(addressPort); + } + } catch (const std::exception& e) { + std::cerr << "Failed to parse response: " << e.what() << std::endl; + } + } +} diff --git a/src/client/alpaca/discovery.hpp b/src/client/alpaca/discovery.hpp new file mode 100644 index 00000000..6429c882 --- /dev/null +++ b/src/client/alpaca/discovery.hpp @@ -0,0 +1,45 @@ +#ifndef ALPACADISCOVERY_H +#define ALPACADISCOVERY_H + +#include +#include +#include + +/** + * @class AlpacaDiscovery + * @brief This class handles the discovery of Alpaca servers on the local + * network via UDP broadcast. + * + * The AlpacaDiscovery class searches for Alpaca servers by broadcasting a + * discovery message on the local network. It listens for responses that contain + * the server's IP address and port. + */ +class AlpacaDiscovery { +public: + /** + * @brief Constructor + */ + AlpacaDiscovery(); + + /** + * @brief Destructor + */ + ~AlpacaDiscovery(); + + /** + * @brief Searches for Alpaca servers on the local network. + * + * @param numQuery The number of broadcast queries to send. + * @param timeout The timeout in seconds for waiting on responses. + * @return A vector of strings containing the IP addresses and ports of + * discovered servers. + * @throw std::runtime_error If socket operations fail. + */ + std::vector searchIPv4(int numQuery = 2, int timeout = 2); + +private: + class Impl; + std::unique_ptr pImpl; /**< Pointer to the implementation class */ +}; + +#endif // ALPACADISCOVERY_H diff --git a/src/client/alpaca/dome.cpp b/src/client/alpaca/dome.cpp new file mode 100644 index 00000000..128752d0 --- /dev/null +++ b/src/client/alpaca/dome.cpp @@ -0,0 +1,94 @@ +#include "dome.hpp" +#include +#include +#include + +AlpacaDome::AlpacaDome(std::string_view address, int device_number, + std::string_view protocol) + : AlpacaDevice(std::string(address), "dome", device_number, + std::string(protocol)) {} + +double AlpacaDome::GetAltitude() { return GetProperty("altitude"); } +bool AlpacaDome::GetAtHome() { return GetProperty("athome"); } +bool AlpacaDome::GetAtPark() { return GetProperty("atpark"); } +double AlpacaDome::GetAzimuth() { return GetProperty("azimuth"); } +bool AlpacaDome::GetCanFindHome() { return GetProperty("canfindhome"); } +bool AlpacaDome::GetCanPark() { return GetProperty("canpark"); } +bool AlpacaDome::GetCanSetAltitude() { + return GetProperty("cansetaltitude"); +} +bool AlpacaDome::GetCanSetAzimuth() { + return GetProperty("cansetazimuth"); +} +bool AlpacaDome::GetCanSetPark() { return GetProperty("cansetpark"); } +bool AlpacaDome::GetCanSetShutter() { + return GetProperty("cansetshutter"); +} +bool AlpacaDome::GetCanSlave() { return GetProperty("canslave"); } +bool AlpacaDome::GetCanSyncAzimuth() { + return GetProperty("cansyncazimuth"); +} +AlpacaDome::ShutterState AlpacaDome::GetShutterStatus() { + return static_cast(GetProperty("shutterstatus")); +} +bool AlpacaDome::GetSlaved() { return GetProperty("slaved"); } +void AlpacaDome::SetSlaved(bool SlavedState) { + Put("slaved", {{"Slaved", SlavedState ? 
"true" : "false"}}); +} +bool AlpacaDome::GetSlewing() { return GetProperty("slewing"); } + +void AlpacaDome::AbortSlew() { Put("abortslew"); } + +template +std::future AlpacaDome::AsyncOperation(Func&& func, + const std::string& operationName) { + return std::async( + std::launch::async, + [this, func = std::forward(func), operationName]() { + func(); + while (GetSlewing() || + (operationName == "shutter" && + (GetShutterStatus() == ShutterState::ShutterOpening || + GetShutterStatus() == ShutterState::ShutterClosing))) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + }); +} + +std::future AlpacaDome::CloseShutter() { + return AsyncOperation([this]() { Put("closeshutter"); }, "shutter"); +} + +std::future AlpacaDome::FindHome() { + return AsyncOperation([this]() { Put("findhome"); }, "home"); +} + +std::future AlpacaDome::OpenShutter() { + return AsyncOperation([this]() { Put("openshutter"); }, "shutter"); +} + +std::future AlpacaDome::Park() { + return AsyncOperation([this]() { Put("park"); }, "park"); +} + +void AlpacaDome::SetPark() { Put("setpark"); } + +std::future AlpacaDome::SlewToAltitude(double Altitude) { + return AsyncOperation( + [this, Altitude]() { + Put("slewtoaltitude", {{"Altitude", std::to_string(Altitude)}}); + }, + "slew"); +} + +std::future AlpacaDome::SlewToAzimuth(double Azimuth) { + return AsyncOperation( + [this, Azimuth]() { + Put("slewtoazimuth", {{"Azimuth", std::to_string(Azimuth)}}); + }, + "slew"); +} + +void AlpacaDome::SyncToAzimuth(double Azimuth) { + Put("synctoazimuth", {{"Azimuth", std::to_string(Azimuth)}}); +} diff --git a/src/client/alpaca/dome.hpp b/src/client/alpaca/dome.hpp new file mode 100644 index 00000000..cb91472a --- /dev/null +++ b/src/client/alpaca/dome.hpp @@ -0,0 +1,60 @@ +#pragma once + +#include +#include +#include +#include "device.hpp" + +class AlpacaDome : public AlpacaDevice { +public: + enum class ShutterState { + ShutterOpen = 0, + ShutterClosed = 1, + ShutterOpening = 2, + ShutterClosing = 3, + ShutterError = 4 + }; + + AlpacaDome(std::string_view address, int device_number, + std::string_view protocol = "http"); + virtual ~AlpacaDome() = default; + + // Properties + double GetAltitude(); + bool GetAtHome(); + bool GetAtPark(); + double GetAzimuth(); + bool GetCanFindHome(); + bool GetCanPark(); + bool GetCanSetAltitude(); + bool GetCanSetAzimuth(); + bool GetCanSetPark(); + bool GetCanSetShutter(); + bool GetCanSlave(); + bool GetCanSyncAzimuth(); + ShutterState GetShutterStatus(); + bool GetSlaved(); + void SetSlaved(bool SlavedState); + bool GetSlewing(); + + // Methods + void AbortSlew(); + std::future CloseShutter(); + std::future FindHome(); + std::future OpenShutter(); + std::future Park(); + void SetPark(); + std::future SlewToAltitude(double Altitude); + std::future SlewToAzimuth(double Azimuth); + void SyncToAzimuth(double Azimuth); + +private: + template + T GetProperty(const std::string& property) const { + return GetNumericProperty(property); + } + + template + std::future AsyncOperation(Func&& func, + const std::string& operationName); +}; diff --git a/src/client/alpaca/filterwheel.cpp b/src/client/alpaca/filterwheel.cpp new file mode 100644 index 00000000..d06e66cf --- /dev/null +++ b/src/client/alpaca/filterwheel.cpp @@ -0,0 +1,36 @@ +#include +#include +#include + +#include "filterwheel.hpp" + +AlpacaFilterWheel::AlpacaFilterWheel(std::string_view address, + int device_number, + std::string_view protocol) + : AlpacaDevice(std::string(address), "filterwheel", device_number, + 
std::string(protocol)) {} + +std::vector AlpacaFilterWheel::GetFocusOffsets() { + return GetArrayProperty("focusoffsets"); +} + +std::vector AlpacaFilterWheel::GetNames() { + return GetArrayProperty("names"); +} + +int AlpacaFilterWheel::GetPosition() { + return GetNumericProperty("position"); +} + +std::future AlpacaFilterWheel::SetPosition(int Position) { + Put("position", {{"Position", std::to_string(Position)}}); + + return std::async(std::launch::async, + [this, Position]() { WaitForFilterChange(); }); +} + +void AlpacaFilterWheel::WaitForFilterChange() { + while (GetPosition() == FILTER_MOVING) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } +} diff --git a/src/client/alpaca/filterwheel.hpp b/src/client/alpaca/filterwheel.hpp new file mode 100644 index 00000000..99b66435 --- /dev/null +++ b/src/client/alpaca/filterwheel.hpp @@ -0,0 +1,26 @@ +#pragma once + +#include +#include +#include +#include "device.hpp" + +class AlpacaFilterWheel : public AlpacaDevice { +public: + AlpacaFilterWheel(std::string_view address, int device_number, + std::string_view protocol = "http"); + virtual ~AlpacaFilterWheel() = default; + + // Properties + std::vector GetFocusOffsets(); + std::vector GetNames(); + int GetPosition(); + std::future SetPosition(int Position); + +private: + static constexpr int FILTER_MOVING = -1; + + std::future m_position_change_future; + + void WaitForFilterChange(); +}; diff --git a/src/client/alpaca/focuser.cpp b/src/client/alpaca/focuser.cpp new file mode 100644 index 00000000..dbe228e3 --- /dev/null +++ b/src/client/alpaca/focuser.cpp @@ -0,0 +1,93 @@ +#include "focuser.hpp" + +#include +#include +#include +#include + +AlpacaFocuser::AlpacaFocuser(const std::string& address, int device_number, + const std::string& protocol) + : AlpacaDevice(address, "focuser", device_number, protocol) {} + +AlpacaFocuser::~AlpacaFocuser() { + if (m_move_thread.joinable()) { + m_move_thread.join(); + } +} + +bool AlpacaFocuser::GetAbsolute() { + return GetNumericProperty("absolute"); +} + +bool AlpacaFocuser::GetIsMoving() { return m_is_moving.load(); } + +int AlpacaFocuser::GetMaxIncrement() { + return GetNumericProperty("maxincrement"); +} + +int AlpacaFocuser::GetMaxStep() { return GetNumericProperty("maxstep"); } + +int AlpacaFocuser::GetPosition() { return GetNumericProperty("position"); } + +float AlpacaFocuser::GetStepSize() { + return GetNumericProperty("stepsize"); +} + +bool AlpacaFocuser::GetTempComp() { + return GetNumericProperty("tempcomp"); +} + +void AlpacaFocuser::SetTempComp(bool TempCompState) { + Put("tempcomp", {{"TempComp", TempCompState ? 
"true" : "false"}}); +} + +bool AlpacaFocuser::GetTempCompAvailable() { + return GetNumericProperty("tempcompavailable"); +} + +std::optional AlpacaFocuser::GetTemperature() { + try { + return GetNumericProperty("temperature"); + } catch (const std::runtime_error& e) { + // If temperature is not implemented, return nullopt + return std::nullopt; + } +} + +void AlpacaFocuser::Halt() { + Put("halt"); + m_is_moving.store(false); +} + +void AlpacaFocuser::StartMove(int Position) { + Put("move", {{"Position", std::to_string(Position)}}); +} + +void AlpacaFocuser::MoveThread(int Position) { + m_is_moving.store(true); + StartMove(Position); + + // Poll the IsMoving property until the move is complete + while (GetNumericProperty("ismoving")) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + + m_is_moving.store(false); +} + +std::future AlpacaFocuser::Move(int Position) { + // If a move is already in progress, wait for it to complete + if (m_move_thread.joinable()) { + m_move_thread.join(); + } + + // Start a new move thread + m_move_thread = std::thread(&AlpacaFocuser::MoveThread, this, Position); + + // Return a future that will be ready when the move is complete + return std::async(std::launch::deferred, [this]() { + if (m_move_thread.joinable()) { + m_move_thread.join(); + } + }); +} diff --git a/src/client/alpaca/focuser.hpp b/src/client/alpaca/focuser.hpp new file mode 100644 index 00000000..d16522df --- /dev/null +++ b/src/client/alpaca/focuser.hpp @@ -0,0 +1,38 @@ +#pragma once + +#include +#include +#include +#include "device.hpp" + +class AlpacaFocuser : public AlpacaDevice { +public: + AlpacaFocuser(const std::string& address, int device_number, + const std::string& protocol = "http"); + virtual ~AlpacaFocuser(); + + // Properties + bool GetAbsolute(); + bool GetIsMoving(); + int GetMaxIncrement(); + int GetMaxStep(); + int GetPosition(); + float GetStepSize(); + bool GetTempComp(); + void SetTempComp(bool TempCompState); + bool GetTempCompAvailable(); + std::optional GetTemperature(); + + // Methods + void Halt(); + std::future Move(int Position); + + // Template method for numeric properties + +private: + void StartMove(int Position); + void MoveThread(int Position); + + std::atomic m_is_moving{false}; + std::thread m_move_thread; +}; diff --git a/src/client/alpaca/observingconditions.cpp b/src/client/alpaca/observingconditions.cpp new file mode 100644 index 00000000..ad309427 --- /dev/null +++ b/src/client/alpaca/observingconditions.cpp @@ -0,0 +1,80 @@ +#include "observingconditions.hpp" +#include + +AlpacaObservingConditions::AlpacaObservingConditions(std::string_view address, + int device_number, + std::string_view protocol) + : AlpacaDevice(std::string(address), "observingconditions", device_number, + std::string(protocol)) {} + +double AlpacaObservingConditions::GetAveragePeriod() { + return GetNumericProperty("averageperiod"); +} + +void AlpacaObservingConditions::SetAveragePeriod(double period) { + Put("averageperiod", {{"AveragePeriod", std::to_string(period)}}); +} + +std::optional AlpacaObservingConditions::GetCloudCover() { + return GetOptionalProperty("cloudcover"); +} + +std::optional AlpacaObservingConditions::GetDewPoint() { + return GetOptionalProperty("dewpoint"); +} + +std::optional AlpacaObservingConditions::GetHumidity() { + return GetOptionalProperty("humidity"); +} + +std::optional AlpacaObservingConditions::GetPressure() { + return GetOptionalProperty("pressure"); +} + +std::optional AlpacaObservingConditions::GetRainRate() { + return 
GetOptionalProperty("rainrate"); +} + +std::optional AlpacaObservingConditions::GetSkyBrightness() { + return GetOptionalProperty("skybrightness"); +} + +std::optional AlpacaObservingConditions::GetSkyQuality() { + return GetOptionalProperty("skyquality"); +} + +std::optional AlpacaObservingConditions::GetSkyTemperature() { + return GetOptionalProperty("skytemperature"); +} + +std::optional AlpacaObservingConditions::GetStarFWHM() { + return GetOptionalProperty("starfwhm"); +} + +std::optional AlpacaObservingConditions::GetTemperature() { + return GetOptionalProperty("temperature"); +} + +std::optional AlpacaObservingConditions::GetWindDirection() { + return GetOptionalProperty("winddirection"); +} + +std::optional AlpacaObservingConditions::GetWindGust() { + return GetOptionalProperty("windgust"); +} + +std::optional AlpacaObservingConditions::GetWindSpeed() { + return GetOptionalProperty("windspeed"); +} + +void AlpacaObservingConditions::Refresh() { Put("refresh"); } + +std::string AlpacaObservingConditions::SensorDescription( + std::string_view SensorName) { + return Get("sensordescription", {{"SensorName", std::string(SensorName)}}); +} + +double AlpacaObservingConditions::TimeSinceLastUpdate( + std::string_view SensorName) { + return GetNumericProperty("timesincelastupdate"); +} diff --git a/src/client/alpaca/observingconditions.hpp b/src/client/alpaca/observingconditions.hpp new file mode 100644 index 00000000..5fb59abf --- /dev/null +++ b/src/client/alpaca/observingconditions.hpp @@ -0,0 +1,45 @@ +#pragma once + +#include +#include +#include +#include "device.hpp" + +class AlpacaObservingConditions : public AlpacaDevice { +public: + AlpacaObservingConditions(std::string_view address, int device_number, + std::string_view protocol = "http"); + virtual ~AlpacaObservingConditions() = default; + + // Properties + double GetAveragePeriod(); + void SetAveragePeriod(double period); + std::optional GetCloudCover(); + std::optional GetDewPoint(); + std::optional GetHumidity(); + std::optional GetPressure(); + std::optional GetRainRate(); + std::optional GetSkyBrightness(); + std::optional GetSkyQuality(); + std::optional GetSkyTemperature(); + std::optional GetStarFWHM(); + std::optional GetTemperature(); + std::optional GetWindDirection(); + std::optional GetWindGust(); + std::optional GetWindSpeed(); + + // Methods + void Refresh(); + std::string SensorDescription(std::string_view SensorName); + double TimeSinceLastUpdate(std::string_view SensorName); + +private: + template + std::optional GetOptionalProperty(const std::string& property) const { + try { + return GetNumericProperty(property); + } catch (const std::exception&) { + return std::nullopt; + } + } +}; diff --git a/src/client/alpaca/rotator.cpp b/src/client/alpaca/rotator.cpp new file mode 100644 index 00000000..725e452e --- /dev/null +++ b/src/client/alpaca/rotator.cpp @@ -0,0 +1,58 @@ +#include "rotator.hpp" +#include + +AlpacaRotator::AlpacaRotator(const std::string& address, int device_number, + const std::string& protocol) + : AlpacaDevice(address, "rotator", device_number, protocol) {} + +bool AlpacaRotator::GetCanReverse() { + return GetNumericProperty("canreverse"); +} + +bool AlpacaRotator::GetIsMoving() { + return GetNumericProperty("ismoving"); +} + +double AlpacaRotator::GetMechanicalPosition() { + return GetNumericProperty("mechanicalposition"); +} + +double AlpacaRotator::GetPosition() { + return GetNumericProperty("position"); +} + +bool AlpacaRotator::GetReverse() { return GetNumericProperty("reverse"); } + +void 
AlpacaRotator::SetReverse(bool ReverseState) { + Put("reverse", {{"Reverse", ReverseState ? "true" : "false"}}); +} + +std::optional AlpacaRotator::GetStepSize() { + try { + return GetNumericProperty("stepsize"); + } catch (const std::runtime_error& e) { + return std::nullopt; + } +} + +double AlpacaRotator::GetTargetPosition() { + return GetNumericProperty("targetposition"); +} + +void AlpacaRotator::Halt() { Put("halt"); } + +std::future AlpacaRotator::Move(double Position) { + return AsyncMove("move", Position); +} + +std::future AlpacaRotator::MoveAbsolute(double Position) { + return AsyncMove("moveabsolute", Position); +} + +std::future AlpacaRotator::MoveMechanical(double Position) { + return AsyncMove("movemechanical", Position); +} + +void AlpacaRotator::Sync(double Position) { + Put("sync", {{"Position", std::to_string(Position)}}); +} diff --git a/src/client/alpaca/rotator.hpp b/src/client/alpaca/rotator.hpp new file mode 100644 index 00000000..ba543121 --- /dev/null +++ b/src/client/alpaca/rotator.hpp @@ -0,0 +1,45 @@ +#pragma once + +#include +#include +#include +#include +#include "device.hpp" + +template +concept Numeric = std::is_arithmetic_v; + +class AlpacaRotator : public AlpacaDevice { +public: + AlpacaRotator(const std::string& address, int device_number, + const std::string& protocol = "http"); + virtual ~AlpacaRotator() = default; + + // Properties + bool GetCanReverse(); + bool GetIsMoving(); + double GetMechanicalPosition(); + double GetPosition(); + bool GetReverse(); + void SetReverse(bool ReverseState); + std::optional GetStepSize(); + double GetTargetPosition(); + + // Methods + void Halt(); + std::future Move(double Position); + std::future MoveAbsolute(double Position); + std::future MoveMechanical(double Position); + void Sync(double Position); + +private: + template + std::future AsyncMove(const std::string& method, T Position) { + return std::async(std::launch::async, [this, method, Position]() { + Put(method, {{"Position", std::to_string(Position)}}); + while (GetIsMoving()) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + }); + } +}; diff --git a/src/client/alpaca/switch.cpp b/src/client/alpaca/switch.cpp new file mode 100644 index 00000000..aad190c1 --- /dev/null +++ b/src/client/alpaca/switch.cpp @@ -0,0 +1,54 @@ +#include "switch.hpp" +#include + +AlpacaSwitch::AlpacaSwitch(const std::string& address, int device_number, + const std::string& protocol) + : AlpacaDevice(address, "switch", device_number, protocol) {} + +int AlpacaSwitch::GetMaxSwitch() { + return GetNumericProperty("maxswitch"); +} + +bool AlpacaSwitch::CanWrite(int Id) { + return GetSwitchProperty("canwrite", Id); +} + +bool AlpacaSwitch::GetSwitch(int Id) { + return GetSwitchProperty("getswitch", Id); +} + +std::string AlpacaSwitch::GetSwitchDescription(int Id) { + return Get("getswitchdescription", {{"Id", std::to_string(Id)}}); +} + +std::string AlpacaSwitch::GetSwitchName(int Id) { + return Get("getswitchname", {{"Id", std::to_string(Id)}}); +} + +double AlpacaSwitch::GetSwitchValue(int Id) { + return GetSwitchProperty("getswitchvalue", Id); +} + +double AlpacaSwitch::MaxSwitchValue(int Id) { + return GetSwitchProperty("maxswitchvalue", Id); +} + +double AlpacaSwitch::MinSwitchValue(int Id) { + return GetSwitchProperty("minswitchvalue", Id); +} + +void AlpacaSwitch::SetSwitch(int Id, bool State) { + SetSwitchProperty("setswitch", Id, State); +} + +void AlpacaSwitch::SetSwitchName(int Id, const std::string& Name) { + SetSwitchProperty("setswitchname", Id, Name); 
+} + +void AlpacaSwitch::SetSwitchValue(int Id, double Value) { + SetSwitchProperty("setswitchvalue", Id, Value); +} + +double AlpacaSwitch::SwitchStep(int Id) { + return GetSwitchProperty("switchstep", Id); +} diff --git a/src/client/alpaca/switch.hpp b/src/client/alpaca/switch.hpp new file mode 100644 index 00000000..bc16bb0d --- /dev/null +++ b/src/client/alpaca/switch.hpp @@ -0,0 +1,53 @@ +#pragma once + +#include +#include +#include +#include +#include "device.hpp" + +template +concept Numeric = std::is_arithmetic_v; + +class AlpacaSwitch : public AlpacaDevice { +public: + AlpacaSwitch(const std::string& address, int device_number, + const std::string& protocol = "http"); + virtual ~AlpacaSwitch() = default; + + // Properties + int GetMaxSwitch(); + + // Methods + bool CanWrite(int Id); + bool GetSwitch(int Id); + std::string GetSwitchDescription(int Id); + std::string GetSwitchName(int Id); + double GetSwitchValue(int Id); + double MaxSwitchValue(int Id); + double MinSwitchValue(int Id); + void SetSwitch(int Id, bool State); + void SetSwitchName(int Id, const std::string& Name); + void SetSwitchValue(int Id, double Value); + double SwitchStep(int Id); + +private: + template + T GetSwitchProperty(const std::string& property, int Id) const { + return GetNumericProperty(property, {{"Id", std::to_string(Id)}}); + } + + template + void SetSwitchProperty(const std::string& property, int Id, + const T& value) { + if constexpr (std::is_same_v) { + Put(property, {{"Id", std::to_string(Id)}, + {"State", value ? "true" : "false"}}); + } else if constexpr (std::is_arithmetic_v) { + Put(property, + {{"Id", std::to_string(Id)}, {"Value", std::to_string(value)}}); + } else if constexpr (std::is_same_v) { + Put(property, {{"Id", std::to_string(Id)}, {"Name", value}}); + } + } +}; diff --git a/src/client/alpaca/telescope.cpp b/src/client/alpaca/telescope.cpp new file mode 100644 index 00000000..932dc20f --- /dev/null +++ b/src/client/alpaca/telescope.cpp @@ -0,0 +1,348 @@ +#include "telescope.hpp" +#include +#include +#include + +AlpacaTelescope::AlpacaTelescope(std::string_view address, int device_number, + std::string_view protocol) + : AlpacaDevice(std::string(address), "telescope", device_number, + std::string(protocol)) {} + +AlpacaTelescope::AlignmentModes AlpacaTelescope::GetAlignmentMode() { + return static_cast(GetProperty("alignmentmode")); +} + +double AlpacaTelescope::GetAltitude() { + return GetProperty("altitude"); +} + +double AlpacaTelescope::GetApertureArea() { + return GetProperty("aperturearea"); +} + +double AlpacaTelescope::GetApertureDiameter() { + return GetProperty("aperturediameter"); +} + +bool AlpacaTelescope::GetAtHome() { return GetProperty("athome"); } + +bool AlpacaTelescope::GetAtPark() { return GetProperty("atpark"); } + +double AlpacaTelescope::GetAzimuth() { return GetProperty("azimuth"); } + +bool AlpacaTelescope::GetCanFindHome() { + return GetProperty("canfindhome"); +} + +bool AlpacaTelescope::GetCanPark() { return GetProperty("canpark"); } + +bool AlpacaTelescope::GetCanPulseGuide() { + return GetProperty("canpulseguide"); +} + +bool AlpacaTelescope::GetCanSetDeclinationRate() { + return GetProperty("cansetdeclinationrate"); +} + +bool AlpacaTelescope::GetCanSetGuideRates() { + return GetProperty("cansetguiderates"); +} + +bool AlpacaTelescope::GetCanSetPark() { + return GetProperty("cansetpark"); +} + +bool AlpacaTelescope::GetCanSetPierSide() { + return GetProperty("cansetpierside"); +} + +bool AlpacaTelescope::GetCanSetRightAscensionRate() { + return 
GetProperty("cansetrightascensionrate"); +} + +bool AlpacaTelescope::GetCanSetTracking() { + return GetProperty("cansettracking"); +} + +bool AlpacaTelescope::GetCanSlew() { return GetProperty("canslew"); } + +bool AlpacaTelescope::GetCanSlewAsync() { + return GetProperty("canslewasync"); +} + +bool AlpacaTelescope::GetCanSlewAltAz() { + return GetProperty("canslewaltaz"); +} + +bool AlpacaTelescope::GetCanSlewAltAzAsync() { + return GetProperty("canslewaltazasync"); +} + +bool AlpacaTelescope::GetCanSync() { return GetProperty("cansync"); } + +bool AlpacaTelescope::GetCanSyncAltAz() { + return GetProperty("cansyncaltaz"); +} + +bool AlpacaTelescope::GetCanUnpark() { return GetProperty("canunpark"); } + +double AlpacaTelescope::GetDeclination() { + return GetProperty("declination"); +} + +double AlpacaTelescope::GetDeclinationRate() { + return GetProperty("declinationrate"); +} + +void AlpacaTelescope::SetDeclinationRate(double DeclinationRate) { + Put("declinationrate", + {{"DeclinationRate", std::to_string(DeclinationRate)}}); +} + +bool AlpacaTelescope::GetDoesRefraction() { + return GetProperty("doesrefraction"); +} + +void AlpacaTelescope::SetDoesRefraction(bool DoesRefraction) { + Put("doesrefraction", + {{"DoesRefraction", DoesRefraction ? "true" : "false"}}); +} + +AlpacaTelescope::EquatorialCoordinateType +AlpacaTelescope::GetEquatorialSystem() { + return static_cast( + GetProperty("equatorialsystem")); +} + +double AlpacaTelescope::GetFocalLength() { + return GetProperty("focallength"); +} + +// ... (其他属性的getter和setter方法) + +template +std::future AlpacaTelescope::AsyncOperation( + Func&& func, const std::string& operationName) { + return std::async( + std::launch::async, + [this, func = std::forward(func), operationName]() { + func(); + while (GetSlewing()) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + }); +} + +void AlpacaTelescope::AbortSlew() { Put("abortslew"); } + +std::future AlpacaTelescope::FindHome() { + return AsyncOperation([this]() { Put("findhome"); }, "findhome"); +} + +void AlpacaTelescope::MoveAxis(TelescopeAxes Axis, double Rate) { + Put("moveaxis", {{"Axis", std::to_string(static_cast(Axis))}, + {"Rate", std::to_string(Rate)}}); +} + +std::future AlpacaTelescope::Park() { + return AsyncOperation([this]() { Put("park"); }, "park"); +} + +std::future AlpacaTelescope::PulseGuide(GuideDirections Direction, + int Duration) { + return AsyncOperation( + [this, Direction, Duration]() { + Put("pulseguide", + {{"Direction", std::to_string(static_cast(Direction))}, + {"Duration", std::to_string(Duration)}}); + }, + "pulseguide"); +} + +void AlpacaTelescope::SetPark() { Put("setpark"); } + +double AlpacaTelescope::GetGuideRateDeclination() { + return GetProperty("guideratedeclination"); +} + +void AlpacaTelescope::SetGuideRateDeclination(double GuideRateDeclination) { + Put("guideratedeclination", {{"GuideRateDeclination", std::to_string(GuideRateDeclination)}}); +} + +double AlpacaTelescope::GetGuideRateRightAscension() { + return GetProperty("guideraterightascension"); +} + +void AlpacaTelescope::SetGuideRateRightAscension(double GuideRateRightAscension) { + Put("guideraterightascension", {{"GuideRateRightAscension", std::to_string(GuideRateRightAscension)}}); +} + +bool AlpacaTelescope::GetIsPulseGuiding() { + return GetProperty("ispulseguiding"); +} + +double AlpacaTelescope::GetRightAscension() { + return GetProperty("rightascension"); +} + +double AlpacaTelescope::GetRightAscensionRate() { + return GetProperty("rightascensionrate"); +} + +void 
AlpacaTelescope::SetRightAscensionRate(double RightAscensionRate) { + Put("rightascensionrate", {{"RightAscensionRate", std::to_string(RightAscensionRate)}}); +} + +AlpacaTelescope::PierSide AlpacaTelescope::GetSideOfPier() { + return static_cast(GetProperty("sideofpier")); +} + +void AlpacaTelescope::SetSideOfPier(PierSide SideOfPier) { + Put("sideofpier", {{"SideOfPier", std::to_string(static_cast(SideOfPier))}}); +} + +double AlpacaTelescope::GetSiderealTime() { + return GetProperty("siderealtime"); +} + +double AlpacaTelescope::GetSiteElevation() { + return GetProperty("siteelevation"); +} + +void AlpacaTelescope::SetSiteElevation(double SiteElevation) { + Put("siteelevation", {{"SiteElevation", std::to_string(SiteElevation)}}); +} + +double AlpacaTelescope::GetSiteLatitude() { + return GetProperty("sitelatitude"); +} + +void AlpacaTelescope::SetSiteLatitude(double SiteLatitude) { + Put("sitelatitude", {{"SiteLatitude", std::to_string(SiteLatitude)}}); +} + +double AlpacaTelescope::GetSiteLongitude() { + return GetProperty("sitelongitude"); +} + +void AlpacaTelescope::SetSiteLongitude(double SiteLongitude) { + Put("sitelongitude", {{"SiteLongitude", std::to_string(SiteLongitude)}}); +} + +bool AlpacaTelescope::GetSlewing() { + return GetProperty("slewing"); +} + +int AlpacaTelescope::GetSlewSettleTime() { + return GetProperty("slewsettletime"); +} + +void AlpacaTelescope::SetSlewSettleTime(int SlewSettleTime) { + Put("slewsettletime", {{"SlewSettleTime", std::to_string(SlewSettleTime)}}); +} + +double AlpacaTelescope::GetTargetDeclination() { + return GetProperty("targetdeclination"); +} + +void AlpacaTelescope::SetTargetDeclination(double TargetDeclination) { + Put("targetdeclination", {{"TargetDeclination", std::to_string(TargetDeclination)}}); +} + +double AlpacaTelescope::GetTargetRightAscension() { + return GetProperty("targetrightascension"); +} + +void AlpacaTelescope::SetTargetRightAscension(double TargetRightAscension) { + Put("targetrightascension", {{"TargetRightAscension", std::to_string(TargetRightAscension)}}); +} + +bool AlpacaTelescope::GetTracking() { + return GetProperty("tracking"); +} + +void AlpacaTelescope::SetTracking(bool Tracking) { + Put("tracking", {{"Tracking", Tracking ? 
"true" : "false"}}); +} + +AlpacaTelescope::DriveRates AlpacaTelescope::GetTrackingRate() { + return static_cast(GetProperty("trackingrate")); +} + +void AlpacaTelescope::SetTrackingRate(DriveRates TrackingRate) { + Put("trackingrate", {{"TrackingRate", std::to_string(static_cast(TrackingRate))}}); +} + +std::vector AlpacaTelescope::GetTrackingRates() { + auto rates = GetArrayProperty("trackingrates"); + std::vector result; + for (auto rate : rates) { + result.push_back(static_cast(rate)); + } + return result; +} + +std::chrono::system_clock::time_point AlpacaTelescope::GetUTCDate() { + std::string dateStr = Get("utcdate"); + std::tm tm = {}; + std::istringstream ss(dateStr); + ss >> std::get_time(&tm, "%Y-%m-%dT%H:%M:%S"); + return std::chrono::system_clock::from_time_t(std::mktime(&tm)); +} + +void AlpacaTelescope::SetUTCDate(const std::chrono::system_clock::time_point& UTCDate) { + auto time = std::chrono::system_clock::to_time_t(UTCDate); + std::stringstream ss; + ss << std::put_time(std::gmtime(&time), "%Y-%m-%dT%H:%M:%S"); + Put("utcdate", {{"UTCDate", ss.str()}}); +} + +std::vector AlpacaTelescope::AxisRates(TelescopeAxes Axis) { + auto rates = GetArrayProperty("axisrates", {{"Axis", std::to_string(static_cast(Axis))}}); + return rates; +} + +bool AlpacaTelescope::CanMoveAxis(TelescopeAxes Axis) { + return GetProperty("canmoveaxis", {{"Axis", std::to_string(static_cast(Axis))}}); +} + +AlpacaTelescope::PierSide AlpacaTelescope::DestinationSideOfPier(double RightAscension, double Declination) { + return static_cast(GetProperty("destinationsideofpier", { + {"RightAscension", std::to_string(RightAscension)}, + {"Declination", std::to_string(Declination)} + })); +} + +std::future AlpacaTelescope::SlewToAltAzAsync(double Azimuth, double Altitude) { + return AsyncOperation([this, Azimuth, Altitude]() { + Put("slewtoaltazasync", {{"Azimuth", std::to_string(Azimuth)}, {"Altitude", std::to_string(Altitude)}}); + }, "slewtoaltaz"); +} + +std::future AlpacaTelescope::SlewToCoordinatesAsync(double RightAscension, double Declination) { + return AsyncOperation([this, RightAscension, Declination]() { + Put("slewtocoordinatesasync", {{"RightAscension", std::to_string(RightAscension)}, {"Declination", std::to_string(Declination)}}); + }, "slewtocoordinates"); +} + +std::future AlpacaTelescope::SlewToTargetAsync() { + return AsyncOperation([this]() { + Put("slewtotargetasync"); + }, "slewtotarget"); +} + +void AlpacaTelescope::SyncToAltAz(double Azimuth, double Altitude) { + Put("synctoaltaz", {{"Azimuth", std::to_string(Azimuth)}, {"Altitude", std::to_string(Altitude)}}); +} + +void AlpacaTelescope::SyncToCoordinates(double RightAscension, double Declination) { + Put("synctocoordinates", {{"RightAscension", std::to_string(RightAscension)}, {"Declination", std::to_string(Declination)}}); +} + +void AlpacaTelescope::SyncToTarget() { + Put("synctotarget"); +} + +void AlpacaTelescope::Unpark() { + Put("unpark"); +} diff --git a/src/client/alpaca/telescope.hpp b/src/client/alpaca/telescope.hpp new file mode 100644 index 00000000..ca3bce08 --- /dev/null +++ b/src/client/alpaca/telescope.hpp @@ -0,0 +1,150 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include "device.hpp" + +class AlpacaTelescope : public AlpacaDevice { +public: + enum class AlignmentModes { + algAltAz = 0, + algPolar = 1, + algGermanPolar = 2 + }; + + enum class DriveRates { + driveSidereal = 0, + driveLunar = 1, + driveSolar = 2, + driveKing = 3 + }; + + enum class EquatorialCoordinateType { + equOther 
= 0, + equTopocentric = 1, + equJ2000 = 2, + equJ2050 = 3, + equB1950 = 4 + }; + + enum class GuideDirections { + guideNorth = 0, + guideSouth = 1, + guideEast = 2, + guideWest = 3 + }; + + enum class PierSide { pierEast = 0, pierWest = 1, pierUnknown = -1 }; + + enum class TelescopeAxes { + axisPrimary = 0, + axisSecondary = 1, + axisTertiary = 2 + }; + + struct Rate { + double Maximum; + double Minimum; + }; + + AlpacaTelescope(std::string_view address, int device_number, + std::string_view protocol = "http"); + virtual ~AlpacaTelescope() = default; + + // Properties + AlignmentModes GetAlignmentMode(); + double GetAltitude(); + double GetApertureArea(); + double GetApertureDiameter(); + bool GetAtHome(); + bool GetAtPark(); + double GetAzimuth(); + bool GetCanFindHome(); + bool GetCanPark(); + bool GetCanPulseGuide(); + bool GetCanSetDeclinationRate(); + bool GetCanSetGuideRates(); + bool GetCanSetPark(); + bool GetCanSetPierSide(); + bool GetCanSetRightAscensionRate(); + bool GetCanSetTracking(); + bool GetCanSlew(); + bool GetCanSlewAsync(); + bool GetCanSlewAltAz(); + bool GetCanSlewAltAzAsync(); + bool GetCanSync(); + bool GetCanSyncAltAz(); + bool GetCanUnpark(); + double GetDeclination(); + double GetDeclinationRate(); + void SetDeclinationRate(double DeclinationRate); + bool GetDoesRefraction(); + void SetDoesRefraction(bool DoesRefraction); + EquatorialCoordinateType GetEquatorialSystem(); + double GetFocalLength(); + double GetGuideRateDeclination(); + void SetGuideRateDeclination(double GuideRateDeclination); + double GetGuideRateRightAscension(); + void SetGuideRateRightAscension(double GuideRateRightAscension); + bool GetIsPulseGuiding(); + double GetRightAscension(); + double GetRightAscensionRate(); + void SetRightAscensionRate(double RightAscensionRate); + PierSide GetSideOfPier(); + void SetSideOfPier(PierSide SideOfPier); + double GetSiderealTime(); + double GetSiteElevation(); + void SetSiteElevation(double SiteElevation); + double GetSiteLatitude(); + void SetSiteLatitude(double SiteLatitude); + double GetSiteLongitude(); + void SetSiteLongitude(double SiteLongitude); + bool GetSlewing(); + int GetSlewSettleTime(); + void SetSlewSettleTime(int SlewSettleTime); + double GetTargetDeclination(); + void SetTargetDeclination(double TargetDeclination); + double GetTargetRightAscension(); + void SetTargetRightAscension(double TargetRightAscension); + bool GetTracking(); + void SetTracking(bool Tracking); + DriveRates GetTrackingRate(); + void SetTrackingRate(DriveRates TrackingRate); + std::vector GetTrackingRates(); + std::chrono::system_clock::time_point GetUTCDate(); + void SetUTCDate(const std::chrono::system_clock::time_point& UTCDate); + + // Methods + std::vector AxisRates(TelescopeAxes Axis); + bool CanMoveAxis(TelescopeAxes Axis); + PierSide DestinationSideOfPier(double RightAscension, + double Declination); + void AbortSlew(); + std::future FindHome(); + void MoveAxis(TelescopeAxes Axis, double Rate); + std::future Park(); + std::future PulseGuide(GuideDirections Direction, int Duration); + void SetPark(); + std::future SlewToAltAzAsync(double Azimuth, double Altitude); + std::future SlewToCoordinatesAsync(double RightAscension, + double Declination); + std::future SlewToTargetAsync(); + void SyncToAltAz(double Azimuth, double Altitude); + void SyncToCoordinates(double RightAscension, double Declination); + void SyncToTarget(); + void Unpark(); + +private: + template + T GetProperty(const std::string& property) const { + return GetNumericProperty(property); + } + + 
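+    // Editorial sketch (not part of the original patch): AsyncOperation is expected
+    // to wrap a Put() call in std::async and poll GetSlewing() until the mount
+    // reports that motion has stopped, so a caller can block on completion, e.g.
+    //
+    //     AlpacaTelescope scope("localhost:11111", 0);   // hypothetical device
+    //     scope.SlewToTargetAsync().get();               // returns once slewing ends
+    //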
template + std::future AsyncOperation(Func&& func, + const std::string& operationName); +}; diff --git a/src/client/gps/gpsd.cpp b/src/client/gps/gpsd.cpp new file mode 100644 index 00000000..da33be88 --- /dev/null +++ b/src/client/gps/gpsd.cpp @@ -0,0 +1,86 @@ +#include "gpsd.hpp" +#include +#include +#include +#include + +GPSD::GPSD() = default; + +GPSD::~GPSD() { disconnect(); } + +bool GPSD::connect(const std::string& host, const std::string& port) { + gps = std::make_unique(host.c_str(), port.c_str()); + if (gps->stream(WATCH_ENABLE | WATCH_JSON) == nullptr) { + std::cerr << "No GPSD running." << std::endl; + return false; + } + return true; +} + +bool GPSD::disconnect() { + gps.reset(); + std::cout << "GPS disconnected successfully." << std::endl; + return true; +} + +std::optional GPSD::updateGPS() { + if (!gps || !gps->waiting(1000)) { + return std::nullopt; + } + + struct gps_data_t* gpsData; + while (gps->waiting(0)) { + gpsData = gps->read(); + if (!gpsData) { + std::cerr << "GPSD read error." << std::endl; + return std::nullopt; + } + } + + if (gpsData->fix.mode < MODE_2D) { + latestData.fixStatus = "NO FIX"; + return std::nullopt; + } + + latestData.fixStatus = (gpsData->fix.mode == MODE_3D) ? "3D FIX" : "2D FIX"; + latestData.latitude = gpsData->fix.latitude; + latestData.longitude = gpsData->fix.longitude; + if (latestData.longitude.value() < 0) { + latestData.longitude = latestData.longitude.value() + 360; + } + latestData.altitude = + (gpsData->fix.mode == MODE_3D) ? gpsData->fix.altitude : 0; + latestData.time = + std::chrono::system_clock::from_time_t(gpsData->fix.time.tv_sec); + latestData.polarisHourAngle = calculatePolarisHourAngle(gpsData); + + return latestData; +} + +double GPSD::calculatePolarisHourAngle(const gps_data_t* gpsData) { + double jd = ln_get_julian_from_timet( + reinterpret_cast<__time_t*>(&gpsData->fix.time.tv_sec)); + + double lst = ln_get_apparent_sidereal_time(jd); + return std::fmod(lst - 2.529722222 + (gpsData->fix.longitude / 15.0), 24.0); +} + +std::optional GPSD::getLatitude() const { return latestData.latitude; } + +std::optional GPSD::getLongitude() const { + return latestData.longitude; +} + +std::optional GPSD::getAltitude() const { return latestData.altitude; } + +std::optional GPSD::getTime() const { + return latestData.time; +} + +std::optional GPSD::getFixStatus() const { + return latestData.fixStatus; +} + +std::optional GPSD::getPolarisHourAngle() const { + return latestData.polarisHourAngle; +} diff --git a/src/client/gps/gpsd.hpp b/src/client/gps/gpsd.hpp new file mode 100644 index 00000000..c9b0db43 --- /dev/null +++ b/src/client/gps/gpsd.hpp @@ -0,0 +1,41 @@ +#pragma once + +#include +#include +#include +#include +#include + +struct GPSData { + std::optional latitude; + std::optional longitude; + std::optional altitude; + std::optional time; + std::optional fixStatus; + std::optional polarisHourAngle; +}; + +class GPSD { +public: + GPSD(); + ~GPSD(); + + bool connect(const std::string& host = "localhost", const std::string& port = DEFAULT_GPSD_PORT); + bool disconnect(); + + std::optional updateGPS(); + + // Getter methods for GPS data + std::optional getLatitude() const; + std::optional getLongitude() const; + std::optional getAltitude() const; + std::optional getTime() const; + std::optional getFixStatus() const; + std::optional getPolarisHourAngle() const; + +private: + std::unique_ptr gps; + GPSData latestData; + + double calculatePolarisHourAngle(const gps_data_t* gpsData); +}; diff --git a/src/client/phd2/profile.cpp 
b/src/client/phd2/profile.cpp new file mode 100644 index 00000000..c0046416 --- /dev/null +++ b/src/client/phd2/profile.cpp @@ -0,0 +1,289 @@ +#include "profile.hpp" +#include +#include +#include +#include + +namespace fs = std::filesystem; +using json = nlohmann::json; + +struct ServerConfigData { + static inline const fs::path PHD2_HIDDEN_CONFIG_FILE = + "./phd2_hidden_config.json"; + static inline const fs::path DEFAULT_PHD2_CONFIG_FILE = + "./default_phd2_config.json"; +}; + +class PHD2ProfileSettingHandler::Impl { +public: + std::optional loaded_config_status; + const fs::path phd2_profile_save_path = "./server/data/phd2"; + + static void replace_double_marker(const fs::path& file_path) { + std::ifstream input_file(file_path); + std::string content((std::istreambuf_iterator(input_file)), + std::istreambuf_iterator()); + input_file.close(); + + size_t pos = content.find("\"\"#"); + while (pos != std::string::npos) { + content.replace(pos, 3, "#"); + pos = content.find("\"\"#", pos + 1); + } + + std::ofstream output_file(file_path); + output_file << content; + output_file.close(); + } + + json load_json_file(const fs::path& file_path) const { + std::ifstream file(file_path); + json config; + file >> config; + return config; + } + + void save_json_file(const fs::path& file_path, const json& config) const { + std::ofstream file(file_path); + file << config.dump(4); + file.close(); + replace_double_marker(file_path); + } +}; + +PHD2ProfileSettingHandler::PHD2ProfileSettingHandler() + : pimpl(std::make_unique()) {} +PHD2ProfileSettingHandler::~PHD2ProfileSettingHandler() = default; +PHD2ProfileSettingHandler::PHD2ProfileSettingHandler( + PHD2ProfileSettingHandler&&) noexcept = default; +PHD2ProfileSettingHandler& PHD2ProfileSettingHandler::operator=( + PHD2ProfileSettingHandler&&) noexcept = default; + +std::optional +PHD2ProfileSettingHandler::load_profile_file() { + if (!fs::exists(ServerConfigData::PHD2_HIDDEN_CONFIG_FILE)) { + fs::copy_file(ServerConfigData::DEFAULT_PHD2_CONFIG_FILE, + ServerConfigData::PHD2_HIDDEN_CONFIG_FILE, + fs::copy_options::overwrite_existing); + } + + try { + json phd2_config = + pimpl->load_json_file(ServerConfigData::PHD2_HIDDEN_CONFIG_FILE); + pimpl->loaded_config_status = InterfacePHD2Profile{ + .name = phd2_config["profile"]["1"]["name"], + .camera = phd2_config["profile"]["1"]["indi"]["INDIcam"], + .camera_ccd = phd2_config["profile"]["1"]["indi"]["INDIcam_ccd"], + .pixel_size = phd2_config["profile"]["1"]["camera"]["pixelsize"], + .telescope = phd2_config["profile"]["1"]["indi"]["INDImount"], + .focal_length = phd2_config["profile"]["1"]["frame"]["focalLength"], + .mass_change_threshold = + phd2_config["profile"]["1"]["guider"]["onestar"] + ["MassChangeThreshold"], + .mass_change_flag = phd2_config["profile"]["1"]["guider"]["onestar"] + ["MassChangeThresholdEnabled"], + .calibration_distance = + phd2_config["profile"]["1"]["scope"]["CalibrationDistance"], + .calibration_duration = + phd2_config["profile"]["1"]["scope"]["CalibrationDuration"]}; + } catch (const json::exception& e) { + std::cerr << "JSON parsing error: " << e.what() << std::endl; + fs::remove(ServerConfigData::PHD2_HIDDEN_CONFIG_FILE); + fs::copy_file(ServerConfigData::DEFAULT_PHD2_CONFIG_FILE, + ServerConfigData::PHD2_HIDDEN_CONFIG_FILE, + fs::copy_options::overwrite_existing); + return load_profile_file(); // Recursive call with default config + } + + return pimpl->loaded_config_status; +} + +bool PHD2ProfileSettingHandler::load_profile(const std::string& profile_name) { + fs::path 
profile_file = + pimpl->phd2_profile_save_path / (profile_name + ".json"); + + if (fs::exists(profile_file)) { + fs::copy_file(profile_file, ServerConfigData::PHD2_HIDDEN_CONFIG_FILE, + fs::copy_options::overwrite_existing); + load_profile_file(); + return true; + } else { + fs::copy_file(ServerConfigData::DEFAULT_PHD2_CONFIG_FILE, + ServerConfigData::PHD2_HIDDEN_CONFIG_FILE, + fs::copy_options::overwrite_existing); + load_profile_file(); + return false; + } +} + +bool PHD2ProfileSettingHandler::new_profile_setting( + const std::string& new_profile_name) { + fs::path new_profile_file = + pimpl->phd2_profile_save_path / (new_profile_name + ".json"); + + if (fs::exists(new_profile_file)) { + restore_profile(new_profile_name); + return false; + } else { + fs::copy_file(ServerConfigData::DEFAULT_PHD2_CONFIG_FILE, + ServerConfigData::PHD2_HIDDEN_CONFIG_FILE, + fs::copy_options::overwrite_existing); + load_profile_file(); + return true; + } +} + +bool PHD2ProfileSettingHandler::update_profile( + const InterfacePHD2Profile& phd2_profile_setting) { + json phd2_config = + pimpl->load_json_file(ServerConfigData::PHD2_HIDDEN_CONFIG_FILE); + + phd2_config["profile"]["1"]["name"] = phd2_profile_setting.name; + phd2_config["profile"]["1"]["indi"]["INDIcam"] = + phd2_profile_setting.camera; + phd2_config["profile"]["1"]["indi"]["INDIcam_ccd"] = + phd2_profile_setting.camera_ccd; + phd2_config["profile"]["1"]["camera"]["pixelsize"] = + phd2_profile_setting.pixel_size; + phd2_config["profile"]["1"]["indi"]["INDImount"] = + phd2_profile_setting.telescope; + phd2_config["profile"]["1"]["frame"]["focalLength"] = + phd2_profile_setting.focal_length; + phd2_config["profile"]["1"]["guider"]["onestar"]["MassChangeThreshold"] = + phd2_profile_setting.mass_change_threshold; + phd2_config["profile"]["1"]["guider"]["onestar"] + ["MassChangeThresholdEnabled"] = + phd2_profile_setting.mass_change_flag; + phd2_config["profile"]["1"]["scope"]["CalibrationDistance"] = + phd2_profile_setting.calibration_distance; + phd2_config["profile"]["1"]["scope"]["CalibrationDuration"] = + phd2_profile_setting.calibration_duration; + + pimpl->save_json_file(ServerConfigData::PHD2_HIDDEN_CONFIG_FILE, + phd2_config); + return true; +} + +bool PHD2ProfileSettingHandler::delete_profile( + const std::string& to_delete_profile) { + fs::path to_delete_profile_file = + pimpl->phd2_profile_save_path / (to_delete_profile + ".json"); + if (fs::exists(to_delete_profile_file)) { + fs::remove(to_delete_profile_file); + return true; + } + return false; +} + +void PHD2ProfileSettingHandler::save_profile(const std::string& profile_name) { + fs::path profile_file = + pimpl->phd2_profile_save_path / (profile_name + ".json"); + if (fs::exists(ServerConfigData::PHD2_HIDDEN_CONFIG_FILE)) { + if (fs::exists(profile_file)) { + fs::remove(profile_file); + } + fs::copy_file(ServerConfigData::PHD2_HIDDEN_CONFIG_FILE, profile_file, + fs::copy_options::overwrite_existing); + } +} + +bool PHD2ProfileSettingHandler::restore_profile( + const std::string& to_restore_profile) { + fs::path to_restore_file = + pimpl->phd2_profile_save_path / (to_restore_profile + ".json"); + if (fs::exists(to_restore_file)) { + fs::copy_file(to_restore_file, + ServerConfigData::PHD2_HIDDEN_CONFIG_FILE, + fs::copy_options::overwrite_existing); + load_profile_file(); + return true; + } else { + new_profile_setting(to_restore_profile); + return false; + } +} + +// New functionality implementations + +std::vector PHD2ProfileSettingHandler::list_profiles() const { + std::vector profiles; + 
for (const auto& entry : + fs::directory_iterator(pimpl->phd2_profile_save_path)) { + if (entry.path().extension() == ".json") { + profiles.push_back(entry.path().stem().string()); + } + } + return profiles; +} + +bool PHD2ProfileSettingHandler::export_profile( + const std::string& profile_name, const fs::path& export_path) const { + fs::path source_file = + pimpl->phd2_profile_save_path / (profile_name + ".json"); + if (fs::exists(source_file)) { + fs::copy_file(source_file, export_path, + fs::copy_options::overwrite_existing); + return true; + } + return false; +} + +bool PHD2ProfileSettingHandler::import_profile( + const fs::path& import_path, const std::string& new_profile_name) { + if (fs::exists(import_path)) { + fs::path destination_file = + pimpl->phd2_profile_save_path / (new_profile_name + ".json"); + fs::copy_file(import_path, destination_file, + fs::copy_options::overwrite_existing); + return true; + } + return false; +} + +bool PHD2ProfileSettingHandler::compare_profiles( + const std::string& profile1, const std::string& profile2) const { + fs::path file1 = pimpl->phd2_profile_save_path / (profile1 + ".json"); + fs::path file2 = pimpl->phd2_profile_save_path / (profile2 + ".json"); + + if (!fs::exists(file1) || !fs::exists(file2)) { + return false; + } + + json config1 = pimpl->load_json_file(file1); + json config2 = pimpl->load_json_file(file2); + + std::cout << "Comparing profiles: " << profile1 << " and " << profile2 + << std::endl; + std::cout << "Differences:" << std::endl; + + for (auto it = config1.begin(); it != config1.end(); ++it) { + if (config2.find(it.key()) == config2.end() || + config2[it.key()] != it.value()) { + std::cout << it.key() << ": " << it.value() << " vs " + << config2[it.key()] << std::endl; + } + } + + for (auto it = config2.begin(); it != config2.end(); ++it) { + if (config1.find(it.key()) == config1.end()) { + std::cout << it.key() << ": missing in " << profile1 << std::endl; + } + } + + return true; +} + +void PHD2ProfileSettingHandler::print_profile_details( + const std::string& profile_name) const { + fs::path profile_file = + pimpl->phd2_profile_save_path / (profile_name + ".json"); + if (fs::exists(profile_file)) { + json config = pimpl->load_json_file(profile_file); + std::cout << "Profile: " << profile_name << std::endl; + std::cout << "Details:" << std::endl; + std::cout << config.dump(4) << std::endl; + } else { + std::cout << "Profile " << profile_name << " does not exist." 
+ << std::endl; + } +} diff --git a/src/client/phd2/profile.hpp b/src/client/phd2/profile.hpp new file mode 100644 index 00000000..b38c7c9a --- /dev/null +++ b/src/client/phd2/profile.hpp @@ -0,0 +1,57 @@ +#pragma once + +#include +#include +#include +#include +#include + +struct InterfacePHD2Profile { + std::string name; + std::string camera; + std::string camera_ccd; + double pixel_size; + std::string telescope; + double focal_length; + double mass_change_threshold; + bool mass_change_flag; + double calibration_distance; + double calibration_duration; +}; + +class PHD2ProfileSettingHandler { +public: + PHD2ProfileSettingHandler(); + ~PHD2ProfileSettingHandler(); + + // Disable copy operations + PHD2ProfileSettingHandler(const PHD2ProfileSettingHandler&) = delete; + PHD2ProfileSettingHandler& operator=(const PHD2ProfileSettingHandler&) = + delete; + + // Enable move operations + PHD2ProfileSettingHandler(PHD2ProfileSettingHandler&&) noexcept; + PHD2ProfileSettingHandler& operator=(PHD2ProfileSettingHandler&&) noexcept; + + std::optional load_profile_file(); + bool load_profile(const std::string& profile_name); + bool new_profile_setting(const std::string& new_profile_name); + bool update_profile(const InterfacePHD2Profile& phd2_profile_setting); + bool delete_profile(const std::string& to_delete_profile); + void save_profile(const std::string& profile_name); + bool restore_profile(const std::string& to_restore_profile); + + // New functionality + std::vector list_profiles() const; + bool export_profile(const std::string& profile_name, + const std::filesystem::path& export_path) const; + bool import_profile(const std::filesystem::path& import_path, + const std::string& new_profile_name); + bool compare_profiles(const std::string& profile1, + const std::string& profile2) const; + void print_profile_details(const std::string& profile_name) const; + +private: + class Impl; + std::unique_ptr pimpl; +}; diff --git a/src/client/phd2/shared.cpp b/src/client/phd2/shared.cpp index 9c04ddd0..3a7371a3 100644 --- a/src/client/phd2/shared.cpp +++ b/src/client/phd2/shared.cpp @@ -1,791 +1,419 @@ -void MainWindow::InitPHD2() -{ - isGuideCapture = true; - - cmdPHD2 = new QProcess(); - cmdPHD2->start("pkill phd2"); - cmdPHD2->waitForStarted(); - cmdPHD2->waitForFinished(); - - key_phd = ftok("../", 2015); - key_phd = 0x90; - - if (key_phd == -1) - { - qDebug("ftok_phd"); - } - - // build the shared memory - system("ipcs -m"); // 查看共享内存 - shmid_phd = shmget(key_phd, BUFSZ_PHD, IPC_CREAT | 0666); - if (shmid_phd < 0) - { - qDebug("main.cpp | main | shared memory phd shmget ERROR"); - exit(-1); - } - - // 映射 - sharedmemory_phd = (char *)shmat(shmid_phd, NULL, 0); - if (sharedmemory_phd == NULL) - { - qDebug("main.cpp | main | shared memor phd map ERROR"); - exit(-1); +#include "shared.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace std::literals; + +class PHDSharedMemoryClient::Impl { +public: + Impl() { initSharedMemory(); } + + ~Impl() { + if (sharedMemory_) { + shmdt(sharedMemory_); + } } - // 读共享内存区数据 - qDebug("data_phd = [%s]\n", sharedmemory_phd); - - cmdPHD2->start("phd2"); - - QElapsedTimer t; - t.start(); - while (t.elapsed() < 10000) - { - usleep(10000); - qApp->processEvents(); - if (connectPHD() == true) - break; - } -} + void initSharedMemory() { + key_t key = ftok("../", 2015); + if (key == -1) { + throw std::runtime_error("Failed to create key for shared memory"); + } -bool MainWindow::connectPHD(void) -{ - QString versionName = ""; - 
call_phd_GetVersion(versionName); + shmid_ = shmget(key, 4096, IPC_CREAT | 0666); + if (shmid_ < 0) { + throw std::runtime_error("Failed to get shared memory"); + } - qDebug() << "QSCOPE|connectPHD|version:" << versionName; - if (versionName != "") - { - // init stellarium operation - return true; - } - else - { - qDebug() << "QSCOPE|connectPHD|error:there is no openPHD2 running"; - return false; + sharedMemory_ = static_cast(shmat(shmid_, nullptr, 0)); + if (sharedMemory_ == nullptr) { + throw std::runtime_error("Failed to attach shared memory"); + } } -} -bool MainWindow::call_phd_GetVersion(QString &versionName) -{ - unsigned int baseAddress; - unsigned int vendcommand; - bzero(sharedmemory_phd, 1024); // 共享内存清空 + bool sendCommand(unsigned int vendCommand) { + constexpr unsigned int baseAddress = 0x03; - baseAddress = 0x03; - vendcommand = 0x01; + std::memset(sharedMemory_, 0, 1024); - sharedmemory_phd[1] = Tools::MSB(vendcommand); - sharedmemory_phd[2] = Tools::LSB(vendcommand); + sharedMemory_[1] = msb(vendCommand); + sharedMemory_[2] = lsb(vendCommand); + sharedMemory_[0] = 0x01; // enable command - sharedmemory_phd[0] = 0x01; // enable command - - QElapsedTimer t; - t.start(); - - while (sharedmemory_phd[0] == 0x01 && t.elapsed() < 500) - { - // QCoreApplication::processEvents(); + return waitForResponse(500ms); } - if (t.elapsed() >= 500) - { - versionName = ""; - return false; - } - else - { - unsigned char addr = 0; - uint16_t length; - memcpy(&length, sharedmemory_phd + baseAddress + addr, sizeof(uint16_t)); - addr = addr + sizeof(uint16_t); - // qDebug()< 0 && length < 1024) - { - for (int i = 0; i < length; i++) - { - versionName.append(sharedmemory_phd[baseAddress + addr + i]); + bool waitForResponse(std::chrono::milliseconds timeout) { + auto start = std::chrono::steady_clock::now(); + while (sharedMemory_[0] == 0x01) { + if (std::chrono::steady_clock::now() - start > timeout) { + return false; } - return true; - // qDebug()<((value >> 8) & 0xFF); } - if (t.elapsed() >= 500) - return false; // timeout - else - return true; -} -uint32_t MainWindow::call_phd_StopLooping(void) -{ - unsigned int vendcommand; - unsigned int baseAddress; + static unsigned char lsb(unsigned int value) { + return static_cast(value & 0xFF); + } - bzero(sharedmemory_phd, 1024); // 共享内存清空 + int shmid_; + char* sharedMemory_; + std::mutex mutex_; - baseAddress = 0x03; - vendcommand = 0x04; + // Additional member variables for new functions + double starX_ = 0.0; + double starY_ = 0.0; + double rmsError_ = 0.0; +}; - sharedmemory_phd[1] = Tools::MSB(vendcommand); - sharedmemory_phd[2] = Tools::LSB(vendcommand); +PHDSharedMemoryClient::PHDSharedMemoryClient() + : pImpl(std::make_unique()) {} - sharedmemory_phd[0] = 0x01; // enable command +PHDSharedMemoryClient::~PHDSharedMemoryClient() = default; - QElapsedTimer t; - t.start(); +PHDSharedMemoryClient::PHDSharedMemoryClient(PHDSharedMemoryClient&&) noexcept = + default; +PHDSharedMemoryClient& PHDSharedMemoryClient::operator=( + PHDSharedMemoryClient&&) noexcept = default; - while (sharedmemory_phd[0] == 0x01 && t.elapsed() < 500) - { - // QCoreApplication::processEvents(); - } - if (t.elapsed() >= 500) - return false; // timeout - else +bool PHDSharedMemoryClient::connectPHD() { + std::string versionName; + if (call_phd_GetVersion(versionName)) { + std::cout << "QSCOPE|connectPHD|version: " << versionName << std::endl; return true; + } else { + std::cout << "QSCOPE|connectPHD|error: there is no openPHD2 running" + << std::endl; + return false; + } } 
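Reviewer note: since the Qt-era `MainWindow` glue is replaced wholesale by `PHDSharedMemoryClient`, a minimal usage sketch may help when reviewing the intended call sequence. Everything below is illustrative and not part of the patch: the `main()` wrapper, the polling loop, and the one-second cadence are assumptions, while the class, its methods, and the command codes in the comments come from this diff (see `shared.hpp` further down).

```cpp
// Minimal sketch: drive the refactored shared-memory client end to end.
// Assumes a PHD2 instance is already running and exporting the same
// System V shared-memory segment that Impl::initSharedMemory() attaches to.
#include <chrono>
#include <exception>
#include <iostream>
#include <thread>

#include "shared.hpp"

int main() {
    try {
        PHDSharedMemoryClient client;           // ctor attaches the shared segment, throws on failure
        if (!client.connectPHD()) {             // probes PHD2 by asking for its version string
            std::cerr << "PHD2 is not reachable over shared memory\n";
            return 1;
        }

        client.call_phd_setExposureTime(1000);  // command 0x0b; the value is passed through to PHD2 as-is
        client.call_phd_StartLooping();         // command 0x03
        client.call_phd_AutoFindStar();         // command 0x05
        client.call_phd_StartGuiding();         // command 0x06

        for (int i = 0; i < 10; ++i) {          // poll guide data for a few seconds
            client.showPHDData();               // consumes a pending frame/error block, if any
            std::cout << "RMS error: " << client.getGuideRMSError() << '\n';
            std::this_thread::sleep_for(std::chrono::seconds(1));
        }
    } catch (const std::exception& e) {         // shared-memory attach failures surface here
        std::cerr << e.what() << '\n';
        return 1;
    }
    return 0;
}
```

Note that every `call_phd_*` helper shares the same 500 ms timeout through `Impl::waitForResponse()`, so callers only get a boolean back and have to poll `call_phd_checkStatus()` if they need progress information.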
-uint32_t MainWindow::call_phd_AutoFindStar(void) -{ - unsigned int vendcommand; - unsigned int baseAddress; +bool PHDSharedMemoryClient::call_phd_GetVersion(std::string& versionName) { + constexpr unsigned int baseAddress = 0x03; + constexpr unsigned int vendCommand = 0x01; - bzero(sharedmemory_phd, 1024); // 共享内存清空 + std::memset(pImpl->sharedMemory_, 0, 1024); - baseAddress = 0x03; - vendcommand = 0x05; + pImpl->sharedMemory_[1] = Impl::msb(vendCommand); + pImpl->sharedMemory_[2] = Impl::lsb(vendCommand); + pImpl->sharedMemory_[0] = 0x01; // enable command - sharedmemory_phd[1] = Tools::MSB(vendcommand); - sharedmemory_phd[2] = Tools::LSB(vendcommand); - - sharedmemory_phd[0] = 0x01; // enable command + if (!pImpl->waitForResponse(500ms)) { + versionName.clear(); + return false; + } - QElapsedTimer t; - t.start(); + unsigned char addr = 0; + uint16_t length; + std::memcpy(&length, pImpl->sharedMemory_ + baseAddress + addr, + sizeof(uint16_t)); + addr += sizeof(uint16_t); - while (sharedmemory_phd[0] == 0x01 && t.elapsed() < 500) - { - // QCoreApplication::processEvents(); - } - if (t.elapsed() >= 500) - return false; // timeout - else + if (length > 0 && length < 1024) { + versionName.assign(pImpl->sharedMemory_ + baseAddress + addr, length); return true; + } else { + versionName.clear(); + return false; + } } -uint32_t MainWindow::call_phd_StartGuiding(void) -{ - unsigned int vendcommand; - unsigned int baseAddress; - - bzero(sharedmemory_phd, 1024); // 共享内存清空 - - baseAddress = 0x03; - vendcommand = 0x06; - - sharedmemory_phd[1] = Tools::MSB(vendcommand); - sharedmemory_phd[2] = Tools::LSB(vendcommand); - - sharedmemory_phd[0] = 0x01; // enable command - - QElapsedTimer t; - t.start(); - - while (sharedmemory_phd[0] == 0x01 && t.elapsed() < 500) - { - // QCoreApplication::processEvents(); - } - if (t.elapsed() >= 500) - return false; // timeout - else - return true; +bool PHDSharedMemoryClient::call_phd_StartLooping() { + return pImpl->sendCommand(0x03); } -uint32_t MainWindow::call_phd_checkStatus(unsigned char &status) -{ - unsigned int vendcommand; - unsigned int baseAddress; +bool PHDSharedMemoryClient::call_phd_StopLooping() { + return pImpl->sendCommand(0x04); +} - bzero(sharedmemory_phd, 1024); // 共享内存清空 +bool PHDSharedMemoryClient::call_phd_AutoFindStar() { + return pImpl->sendCommand(0x05); +} - baseAddress = 0x03; - vendcommand = 0x07; +bool PHDSharedMemoryClient::call_phd_StartGuiding() { + return pImpl->sendCommand(0x06); +} - sharedmemory_phd[1] = Tools::MSB(vendcommand); - sharedmemory_phd[2] = Tools::LSB(vendcommand); +bool PHDSharedMemoryClient::call_phd_checkStatus(unsigned char& status) { + constexpr unsigned int baseAddress = 0x03; + constexpr unsigned int vendCommand = 0x07; - sharedmemory_phd[0] = 0x01; // enable command + std::memset(pImpl->sharedMemory_, 0, 1024); - // wait stellarium finished this task - QElapsedTimer t; - t.start(); - while (sharedmemory_phd[0] == 0x01 && t.elapsed() < 500) - { - // QCoreApplication::processEvents(); - } // wait stellarium run end + pImpl->sharedMemory_[1] = Impl::msb(vendCommand); + pImpl->sharedMemory_[2] = Impl::lsb(vendCommand); + pImpl->sharedMemory_[0] = 0x01; // enable command - if (t.elapsed() >= 500) - { - // timeout + if (!pImpl->waitForResponse(500ms)) { status = 0; return false; } - else - { - status = sharedmemory_phd[3]; - return true; - } + status = pImpl->sharedMemory_[3]; + return true; } -uint32_t MainWindow::call_phd_setExposureTime(unsigned int expTime) -{ - unsigned int vendcommand; - unsigned int 
baseAddress; - qDebug() << "call_phd_setExposureTime" << expTime; - bzero(sharedmemory_phd, 1024); // 共享内存清空 - - baseAddress = 0x03; - vendcommand = 0x0b; +bool PHDSharedMemoryClient::call_phd_setExposureTime(unsigned int expTime) { + constexpr unsigned int baseAddress = 0x03; + constexpr unsigned int vendCommand = 0x0b; - sharedmemory_phd[1] = Tools::MSB(vendcommand); - sharedmemory_phd[2] = Tools::LSB(vendcommand); - - unsigned char addr = 0; - memcpy(sharedmemory_phd + baseAddress + addr, &expTime, sizeof(unsigned int)); - addr = addr + sizeof(unsigned int); + std::memset(pImpl->sharedMemory_, 0, 1024); - sharedmemory_phd[0] = 0x01; // enable command + pImpl->sharedMemory_[1] = Impl::msb(vendCommand); + pImpl->sharedMemory_[2] = Impl::lsb(vendCommand); - // wait stellarium finished this task - QElapsedTimer t; - t.start(); + std::memcpy(pImpl->sharedMemory_ + baseAddress, &expTime, + sizeof(unsigned int)); - while (sharedmemory_phd[0] == 0x01 && t.elapsed() < 500) - { - // QCoreApplication::processEvents(); - } // wait stellarium run end + pImpl->sharedMemory_[0] = 0x01; // enable command - if (t.elapsed() >= 500) - return QHYCCD_ERROR; // timeout - else - return QHYCCD_SUCCESS; + return pImpl->waitForResponse(500ms); } -uint32_t MainWindow::call_phd_whichCamera(std::string Camera) -{ - unsigned int vendcommand; - unsigned int baseAddress; +bool PHDSharedMemoryClient::call_phd_whichCamera(const std::string& camera) { + constexpr unsigned int baseAddress = 0x03; + constexpr unsigned int vendCommand = 0x0d; - bzero(sharedmemory_phd, 1024); // 共享内存清空 + std::memset(pImpl->sharedMemory_, 0, 1024); - baseAddress = 0x03; - vendcommand = 0x0d; - - sharedmemory_phd[1] = Tools::MSB(vendcommand); - sharedmemory_phd[2] = Tools::LSB(vendcommand); - - sharedmemory_phd[0] = 0x01; // enable command - - int length = Camera.length() + 1; + pImpl->sharedMemory_[1] = Impl::msb(vendCommand); + pImpl->sharedMemory_[2] = Impl::lsb(vendCommand); + int length = camera.length() + 1; unsigned char addr = 0; - // memcpy(sharedmemory_phd + baseAddress + addr, &index, sizeof(int)); - // addr = addr + sizeof(int); - memcpy(sharedmemory_phd + baseAddress + addr, &length, sizeof(int)); - addr = addr + sizeof(int); - memcpy(sharedmemory_phd + baseAddress + addr, Camera.c_str(), length); - addr = addr + length; - - // wait stellarium finished this task - QElapsedTimer t; - t.start(); - - while (sharedmemory_phd[0] == 0x01 && t.elapsed() < 500) - { - // QCoreApplication::processEvents(); - } // wait stellarium run end - - if (t.elapsed() >= 500) - return QHYCCD_ERROR; // timeout - else - return QHYCCD_SUCCESS; -} -uint32_t MainWindow::call_phd_ChackControlStatus(int sdk_num) -{ - unsigned int vendcommand; - unsigned int baseAddress; + std::memcpy(pImpl->sharedMemory_ + baseAddress + addr, &length, + sizeof(int)); + addr += sizeof(int); + std::memcpy(pImpl->sharedMemory_ + baseAddress + addr, camera.c_str(), + length); - bzero(sharedmemory_phd, 1024); // 共享内存清空 + pImpl->sharedMemory_[0] = 0x01; // enable command - baseAddress = 0x03; - vendcommand = 0x0e; - - sharedmemory_phd[1] = Tools::MSB(vendcommand); - sharedmemory_phd[2] = Tools::LSB(vendcommand); - - sharedmemory_phd[0] = 0x01; // enable command - - unsigned char addr = 0; - memcpy(sharedmemory_phd + baseAddress + addr, &sdk_num, sizeof(int)); - addr = addr + sizeof(int); - - QElapsedTimer t; - t.start(); - - while (sharedmemory_phd[0] == 0x01 && t.elapsed() < 500) - { - // QCoreApplication::processEvents(); - } - if (t.elapsed() >= 500) - return false; // timeout - 
else - return true; + return pImpl->waitForResponse(500ms); } -uint32_t MainWindow::call_phd_ClearCalibration(void) -{ - unsigned int vendcommand; - unsigned int baseAddress; +bool PHDSharedMemoryClient::call_phd_ChackControlStatus(int sdk_num) { + constexpr unsigned int baseAddress = 0x03; + constexpr unsigned int vendCommand = 0x0e; - bzero(sharedmemory_phd, 1024); // 共享内存清空 + std::memset(pImpl->sharedMemory_, 0, 1024); - baseAddress = 0x03; - vendcommand = 0x02; + pImpl->sharedMemory_[1] = Impl::msb(vendCommand); + pImpl->sharedMemory_[2] = Impl::lsb(vendCommand); - sharedmemory_phd[1] = Tools::MSB(vendcommand); - sharedmemory_phd[2] = Tools::LSB(vendcommand); + std::memcpy(pImpl->sharedMemory_ + baseAddress, &sdk_num, sizeof(int)); - sharedmemory_phd[0] = 0x01; // enable command + pImpl->sharedMemory_[0] = 0x01; // enable command - QElapsedTimer t; - t.start(); + return pImpl->waitForResponse(500ms); +} - while (sharedmemory_phd[0] == 0x01 && t.elapsed() < 500) - { - // QCoreApplication::processEvents(); - } - if (t.elapsed() >= 500) - return false; // timeout - else - return true; +bool PHDSharedMemoryClient::call_phd_ClearCalibration() { + return pImpl->sendCommand(0x02); } -void MainWindow::ShowPHDdata() -{ - unsigned int currentPHDSizeX = 1; - unsigned int currentPHDSizeY = 1; - unsigned int bitDepth = 1; +void PHDSharedMemoryClient::showPHDData() { + std::lock_guard lock(pImpl->mutex_); - unsigned char guideDataIndicator; - unsigned int guideDataIndicatorAddress; - double dRa, dDec, SNR, MASS, RMSErrorX, RMSErrorY, RMSErrorTotal, PixelRatio; + if (pImpl->sharedMemory_[2047] != 0x02) + return; + + unsigned int currentPHDSizeX, currentPHDSizeY; + unsigned char bitDepth; + double dRa, dDec, SNR, MASS, RMSErrorX, RMSErrorY, RMSErrorTotal, + PixelRatio; int RADUR, DECDUR; char RADIR, DECDIR; - unsigned char LossAlert; - - double StarX; - double StarY; - bool isSelected; - - bool showLockedCross; - double LockedPositionX; - double LockedPositionY; - - unsigned char MultiStarNumber; - unsigned short MultiStarX[32]; - unsigned short MultiStarY[32]; - - unsigned int mem_offset; - int sdk_direction = 0; - int sdk_duration = 0; - int sdk_num; - int zero = 0; - - bool StarLostAlert = false; - - if (sharedmemory_phd[2047] != 0x02) - return; // if there is no image comes, return - - mem_offset = 1024; - // guide image dimention data - memcpy(¤tPHDSizeX, sharedmemory_phd + mem_offset, sizeof(unsigned int)); - mem_offset = mem_offset + sizeof(unsigned int); - memcpy(¤tPHDSizeY, sharedmemory_phd + mem_offset, sizeof(unsigned int)); - mem_offset = mem_offset + sizeof(unsigned int); - memcpy(&bitDepth, sharedmemory_phd + mem_offset, sizeof(unsigned char)); - mem_offset = mem_offset + sizeof(unsigned char); - - mem_offset = mem_offset + sizeof(int); // &sdk_num - mem_offset = mem_offset + sizeof(int); // &sdk_direction - mem_offset = mem_offset + sizeof(int); // &sdk_duration - - guideDataIndicatorAddress = mem_offset; - - // guide error data - guideDataIndicator = sharedmemory_phd[guideDataIndicatorAddress]; - - mem_offset = mem_offset + sizeof(unsigned char); - memcpy(&dRa, sharedmemory_phd + mem_offset, sizeof(double)); - mem_offset = mem_offset + sizeof(double); - memcpy(&dDec, sharedmemory_phd + mem_offset, sizeof(double)); - mem_offset = mem_offset + sizeof(double); - memcpy(&SNR, sharedmemory_phd + mem_offset, sizeof(double)); - mem_offset = mem_offset + sizeof(double); - memcpy(&MASS, sharedmemory_phd + mem_offset, sizeof(double)); - mem_offset = mem_offset + sizeof(double); - - 
memcpy(&RADUR, sharedmemory_phd + mem_offset, sizeof(int)); - mem_offset = mem_offset + sizeof(int); - memcpy(&DECDUR, sharedmemory_phd + mem_offset, sizeof(int)); - mem_offset = mem_offset + sizeof(int); - - memcpy(&RADIR, sharedmemory_phd + mem_offset, sizeof(char)); - mem_offset = mem_offset + sizeof(char); - memcpy(&DECDIR, sharedmemory_phd + mem_offset, sizeof(char)); - mem_offset = mem_offset + sizeof(char); - - memcpy(&RMSErrorX, sharedmemory_phd + mem_offset, sizeof(double)); - mem_offset = mem_offset + sizeof(double); - memcpy(&RMSErrorY, sharedmemory_phd + mem_offset, sizeof(double)); - mem_offset = mem_offset + sizeof(double); - memcpy(&RMSErrorTotal, sharedmemory_phd + mem_offset, sizeof(double)); - mem_offset = mem_offset + sizeof(double); - memcpy(&PixelRatio, sharedmemory_phd + mem_offset, sizeof(double)); - mem_offset = mem_offset + sizeof(double); - memcpy(&StarLostAlert, sharedmemory_phd + mem_offset, sizeof(bool)); - mem_offset = mem_offset + sizeof(bool); - memcpy(&InGuiding, sharedmemory_phd + mem_offset, sizeof(bool)); - mem_offset = mem_offset + sizeof(bool); - - mem_offset = 1024 + 200; - memcpy(&isSelected, sharedmemory_phd + mem_offset, sizeof(bool)); - mem_offset = mem_offset + sizeof(bool); - memcpy(&StarX, sharedmemory_phd + mem_offset, sizeof(double)); - mem_offset = mem_offset + sizeof(double); - memcpy(&StarY, sharedmemory_phd + mem_offset, sizeof(double)); - mem_offset = mem_offset + sizeof(double); - memcpy(&showLockedCross, sharedmemory_phd + mem_offset, sizeof(bool)); - mem_offset = mem_offset + sizeof(bool); - memcpy(&LockedPositionX, sharedmemory_phd + mem_offset, sizeof(double)); - mem_offset = mem_offset + sizeof(double); - memcpy(&LockedPositionY, sharedmemory_phd + mem_offset, sizeof(double)); - mem_offset = mem_offset + sizeof(double); - memcpy(&MultiStarNumber, sharedmemory_phd + mem_offset, sizeof(unsigned char)); - mem_offset = mem_offset + sizeof(unsigned char); - memcpy(MultiStarX, sharedmemory_phd + mem_offset, sizeof(MultiStarX)); - mem_offset = mem_offset + sizeof(MultiStarX); - memcpy(MultiStarY, sharedmemory_phd + mem_offset, sizeof(MultiStarY)); - mem_offset = mem_offset + sizeof(MultiStarY); - - sharedmemory_phd[guideDataIndicatorAddress] = 0x00; // have been read back - - glPHD_isSelected = isSelected; - glPHD_StarX = StarX; - glPHD_StarY = StarY; - glPHD_CurrentImageSizeX = currentPHDSizeX; - glPHD_CurrentImageSizeY = currentPHDSizeY; - glPHD_LockPositionX = LockedPositionX; - glPHD_LockPositionY = LockedPositionY; - glPHD_ShowLockCross = showLockedCross; - - glPHD_Stars.clear(); - for (int i = 0; i < MultiStarNumber; i++) - { - if (i > 30) - break; - QPoint p; - p.setX(MultiStarX[i]); - p.setY(MultiStarY[i]); - glPHD_Stars.push_back(p); - } - - if (glPHD_StarX != 0 && glPHD_StarY != 0) - glPHD_StartGuide = true; - - unsigned int byteCount; - byteCount = currentPHDSizeX * currentPHDSizeY * (bitDepth / 8); - - mem_offset = 2048; - - unsigned char m = sharedmemory_phd[2047]; - - if (sharedmemory_phd[2047] == 0x02 && bitDepth > 0 && currentPHDSizeX > 0 && currentPHDSizeY > 0) - { - // 导星过程中的数据 - // qDebug() << guideDataIndicator << "dRa:" << dRa << "dDec:" << dDec - // << "rmsX:" << RMSErrorX << "rmsY:" << RMSErrorY - // << "rmsTotal:" << RMSErrorTotal << "SNR:" << SNR; - unsigned char phdstatu; - call_phd_checkStatus(phdstatu); - - if (dRa != 0 && dDec != 0) - { - QPointF tmp; - tmp.setX(-dRa * PixelRatio); - tmp.setY(dDec * PixelRatio); - glPHD_rmsdate.append(tmp); - // m_pToolbarWidget->guiderLabel->Series_err->append(-dRa * 
PixelRatio, -dDec * PixelRatio); - emit wsThread->sendMessageToClient("AddScatterChartData:" + QString::number(-dRa * PixelRatio) + ":" + QString::number(-dDec * PixelRatio)); - - // 曲线的数值 - // qDebug() << "Ra|Dec: " << -dRa * PixelRatio << "," << dDec * PixelRatio; - - // 图像中的小绿框 - if (InGuiding == true) - { - // m_pToolbarWidget->LabelMainStarBox->setStyleSheet("QLabel{border:2px solid rgb(0,255,0);border-radius:3px;background-color:transparent;}"); - // m_pToolbarWidget->LabelCrossX->setStyleSheet("QLabel{border:1px solid rgb(0,255,0);border-radius:3px;background-color:transparent;}"); - // m_pToolbarWidget->LabelCrossY->setStyleSheet("QLabel{border:1px solid rgb(0,255,0);border-radius:3px;background-color:transparent;}"); - emit wsThread->sendMessageToClient("InGuiding"); - } - else - { - // m_pToolbarWidget->LabelMainStarBox->setStyleSheet("QLabel{border:2px solid rgb(255,255,0);border-radius:3px;background-color:transparent;}"); - // m_pToolbarWidget->LabelCrossX->setStyleSheet("QLabel{border:1px solid rgb(255,255,0);border-radius:3px;background-color:transparent;}"); - // m_pToolbarWidget->LabelCrossY->setStyleSheet("QLabel{border:1px solid rgb(255,255,0);border-radius:3px;background-color:transparent;}"); - emit wsThread->sendMessageToClient("InCalibration"); - } - - if (StarLostAlert == true) - { - // m_pToolbarWidget->LabelMainStarBox->setStyleSheet("QLabel{border:2px solid rgb(255,0,0);border-radius:3px;background-color:transparent;}"); - // m_pToolbarWidget->LabelCrossX->setStyleSheet("QLabel{border:1px solid rgb(255,0,0);border-radius:3px;background-color:transparent;}"); - // m_pToolbarWidget->LabelCrossY->setStyleSheet("QLabel{border:1px solid rgb(255,0,0);border-radius:3px;background-color:transparent;}"); - emit wsThread->sendMessageToClient("StarLostAlert"); - } - - emit wsThread->sendMessageToClient("AddRMSErrorData:" + QString::number(RMSErrorX, 'f', 3) + ":" + QString::number(RMSErrorX, 'f', 3)); - } - // m_pToolbarWidget->guiderLabel->RMSErrorX_value->setPlainText(QString::number(RMSErrorX, 'f', 3)); - // m_pToolbarWidget->guiderLabel->RMSErrorY_value->setPlainText(QString::number(RMSErrorY, 'f', 3)); - - // m_pToolbarWidget->guiderLabel->GuiderDataRA->clear(); - // m_pToolbarWidget->guiderLabel->GuiderDataDEC->clear(); - - for (int i = 0; i < glPHD_rmsdate.size(); i++) - { - // m_pToolbarWidget->guiderLabel->GuiderDataRA ->append(i, glPHD_rmsdate[i].x()); - // m_pToolbarWidget->guiderLabel->GuiderDataDEC->append(i, glPHD_rmsdate[i].y()); - if (i == glPHD_rmsdate.size() - 1) - { - emit wsThread->sendMessageToClient("AddLineChartData:" + QString::number(i) + ":" + QString::number(glPHD_rmsdate[i].x()) + ":" + QString::number(glPHD_rmsdate[i].y())); - if (i > 50) - { - // m_pToolbarWidget->guiderLabel->AxisX_Graph->setRange(i-100,i); - emit wsThread->sendMessageToClient("SetLineChartRange:" + QString::number(i - 50) + ":" + QString::number(i)); - } - } - } - - unsigned char *srcData = new unsigned char[byteCount]; - mem_offset = 2048; - - memcpy(srcData, sharedmemory_phd + mem_offset, byteCount); - sharedmemory_phd[2047] = 0x00; // 0x00= image has been read - - cv::Mat img8; - cv::Mat PHDImg; - - img8.create(currentPHDSizeY, currentPHDSizeX, CV_8UC1); + bool StarLostAlert, InGuiding; + + unsigned int mem_offset = 1024; + std::memcpy(¤tPHDSizeX, pImpl->sharedMemory_ + mem_offset, + sizeof(unsigned int)); + mem_offset += sizeof(unsigned int); + std::memcpy(¤tPHDSizeY, pImpl->sharedMemory_ + mem_offset, + sizeof(unsigned int)); + mem_offset += sizeof(unsigned int); + 
std::memcpy(&bitDepth, pImpl->sharedMemory_ + mem_offset, + sizeof(unsigned char)); + mem_offset += sizeof(unsigned char); + + // Skip sdk_num, sdk_direction, sdk_duration + mem_offset += 3 * sizeof(int); + + // Guide error data + mem_offset += sizeof(unsigned char); // guideDataIndicator + std::memcpy(&dRa, pImpl->sharedMemory_ + mem_offset, sizeof(double)); + mem_offset += sizeof(double); + std::memcpy(&dDec, pImpl->sharedMemory_ + mem_offset, sizeof(double)); + mem_offset += sizeof(double); + std::memcpy(&SNR, pImpl->sharedMemory_ + mem_offset, sizeof(double)); + mem_offset += sizeof(double); + std::memcpy(&MASS, pImpl->sharedMemory_ + mem_offset, sizeof(double)); + mem_offset += sizeof(double); + std::memcpy(&RADUR, pImpl->sharedMemory_ + mem_offset, sizeof(int)); + mem_offset += sizeof(int); + std::memcpy(&DECDUR, pImpl->sharedMemory_ + mem_offset, sizeof(int)); + mem_offset += sizeof(int); + std::memcpy(&RADIR, pImpl->sharedMemory_ + mem_offset, sizeof(char)); + mem_offset += sizeof(char); + std::memcpy(&DECDIR, pImpl->sharedMemory_ + mem_offset, sizeof(char)); + mem_offset += sizeof(char); + std::memcpy(&RMSErrorX, pImpl->sharedMemory_ + mem_offset, sizeof(double)); + mem_offset += sizeof(double); + std::memcpy(&RMSErrorY, pImpl->sharedMemory_ + mem_offset, sizeof(double)); + mem_offset += sizeof(double); + std::memcpy(&RMSErrorTotal, pImpl->sharedMemory_ + mem_offset, + sizeof(double)); + mem_offset += sizeof(double); + std::memcpy(&PixelRatio, pImpl->sharedMemory_ + mem_offset, sizeof(double)); + mem_offset += sizeof(double); + std::memcpy(&StarLostAlert, pImpl->sharedMemory_ + mem_offset, + sizeof(bool)); + mem_offset += sizeof(bool); + std::memcpy(&InGuiding, pImpl->sharedMemory_ + mem_offset, sizeof(bool)); + + // Update member variables + pImpl->starX_ = dRa; + pImpl->starY_ = dDec; + pImpl->rmsError_ = RMSErrorTotal; + + // Process and use the data as needed + std::cout << std::format("RMSErrorX: {:.3f}, RMSErrorY: {:.3f}", RMSErrorX, + RMSErrorY) + << std::endl; + + // Clear the data indicator + pImpl->sharedMemory_[2047] = 0x00; +} - if (bitDepth == 16) - PHDImg.create(currentPHDSizeY, currentPHDSizeX, CV_16UC1); - else - PHDImg.create(currentPHDSizeY, currentPHDSizeX, CV_8UC1); +void PHDSharedMemoryClient::controlGuide(int direction, int duration) { + // Implement the guide control logic here + // This might involve sending commands to the mount or updating shared + // memory + std::cout << std::format("ControlGuide: Direction={}, Duration={}", + direction, duration) + << std::endl; - PHDImg.data = srcData; + // Example implementation (you may need to adjust this based on your + // specific requirements): + constexpr unsigned int baseAddress = 0x03; + constexpr unsigned int vendCommand = + 0x0F; // Assuming 0x0F is the command for guide control - uint16_t B = 0; - uint16_t W = 65535; + std::memset(pImpl->sharedMemory_, 0, 1024); - cv::Mat image_raw8; - image_raw8.create(PHDImg.rows, PHDImg.cols, CV_8UC1); + pImpl->sharedMemory_[1] = Impl::msb(vendCommand); + pImpl->sharedMemory_[2] = Impl::lsb(vendCommand); - if (AutoStretch == true) - { - Tools::GetAutoStretch(PHDImg, 0, B, W); - } - else - { - B = 0; - W = 65535; - } + std::memcpy(pImpl->sharedMemory_ + baseAddress, &direction, sizeof(int)); + std::memcpy(pImpl->sharedMemory_ + baseAddress + sizeof(int), &duration, + sizeof(int)); - Tools::Bit16To8_Stretch(PHDImg, image_raw8, B, W); + pImpl->sharedMemory_[0] = 0x01; // enable command - saveGuiderImageAsJPG(image_raw8); + pImpl->waitForResponse(500ms); +} - // 
saveGuiderImageAsJPG(PHDImg); +void PHDSharedMemoryClient::getPHD2ControlInstruct() { + std::lock_guard lock(pImpl->mutex_); - // refreshGuideImage(PHDImg, "MONO"); + unsigned int mem_offset = + 1024 + 2 * sizeof(unsigned int) + sizeof(unsigned char); - int centerX = glPHD_StarX; // Replace with your X coordinate - int centerY = glPHD_StarY; // Replace with your Y coordinate + int controlInstruct = 0; + std::memcpy(&controlInstruct, pImpl->sharedMemory_ + mem_offset, + sizeof(int)); - int cropSize = 20; // Size of the cropped region + int sdk_num = (controlInstruct >> 24) & 0xFFF; + int sdk_direction = (controlInstruct >> 12) & 0xFFF; + int sdk_duration = controlInstruct & 0xFFF; - // Calculate crop region - int startX = std::max(0, centerX - cropSize / 2); - int startY = std::max(0, centerY - cropSize / 2); - int endX = std::min(PHDImg.cols - 1, centerX + cropSize / 2); - int endY = std::min(PHDImg.rows - 1, centerY + cropSize / 2); + if (sdk_num != 0) { + std::cout + << std::format( + "PHD2ControlTelescope: num={}, direction={}, duration={}", + sdk_num, sdk_direction, sdk_duration) + << std::endl; + } - // Crop the image using OpenCV's ROI (Region of Interest) functionality - cv::Rect cropRegion(startX, startY, endX - startX + 1, endY - startY + 1); - cv::Mat croppedImage = PHDImg(cropRegion).clone(); + if (sdk_duration != 0) { + controlGuide(sdk_direction, sdk_duration); - // strechShowImage(croppedImage, m_pToolbarWidget->guiderLabel->ImageLable,m_pToolbarWidget->histogramLabel->hisLabel,"MONO",false,false,0,0,65535,1.0,1.7,100,false); - // m_pToolbarWidget->guiderLabel->ImageLable->setScaledContents(true); + int zero = 0; + std::memcpy(pImpl->sharedMemory_ + mem_offset, &zero, sizeof(int)); - delete[] srcData; - img8.release(); - PHDImg.release(); + call_phd_ChackControlStatus(sdk_num); } } -void MainWindow::ControlGuide(int Direction, int Duration) -{ - qDebug() << "\033[32m" - << "ControlGuide: " - << "\033[0m" << Direction << "," << Duration; - switch (Direction) - { - case 1: - { - if (dpMount != NULL) - { - indi_Client->setTelescopeGuideNS(dpMount, Direction, Duration); - } - break; - } - case 0: - { - if (dpMount != NULL) - { - indi_Client->setTelescopeGuideNS(dpMount, Direction, Duration); - } - break; - } - case 2: - { - if (dpMount != NULL) - { - indi_Client->setTelescopeGuideWE(dpMount, Direction, Duration); - } - break; - } - case 3: - { - if (dpMount != NULL) - { - indi_Client->setTelescopeGuideWE(dpMount, Direction, Duration); - } - break; - } - default: - break; // - } +// New functions implementation + +bool PHDSharedMemoryClient::startCalibration() { + constexpr unsigned int vendCommand = + 0x10; // Assuming 0x10 is the command for starting calibration + return pImpl->sendCommand(vendCommand); } -void MainWindow::getTimeNow(int index) -{ - // 获取当前时间点 - auto now = std::chrono::system_clock::now(); +bool PHDSharedMemoryClient::abortCalibration() { + constexpr unsigned int vendCommand = + 0x11; // Assuming 0x11 is the command for aborting calibration + return pImpl->sendCommand(vendCommand); +} - // 将当前时间点转换为毫秒 - auto ms = std::chrono::duration_cast(now.time_since_epoch()).count(); +bool PHDSharedMemoryClient::dither(double pixels) { + constexpr unsigned int baseAddress = 0x03; + constexpr unsigned int vendCommand = + 0x12; // Assuming 0x12 is the command for dithering - // 将毫秒时间戳转换为时间类型(std::time_t) - std::time_t time_now = ms / 1000; // 将毫秒转换为秒 + std::memset(pImpl->sharedMemory_, 0, 1024); - // 使用 std::strftime 函数将时间格式化为字符串 - char buffer[80]; - std::strftime(buffer, 
sizeof(buffer), "%Y-%m-%d %H:%M:%S", - std::localtime(&time_now)); + pImpl->sharedMemory_[1] = Impl::msb(vendCommand); + pImpl->sharedMemory_[2] = Impl::lsb(vendCommand); - // 添加毫秒部分 - std::string formatted_time = buffer + std::to_string(ms % 1000); + std::memcpy(pImpl->sharedMemory_ + baseAddress, &pixels, sizeof(double)); - // 打印带有当前时间的输出 - // std::cout << "TIME(ms): " << formatted_time << "," << index << std::endl; -} + pImpl->sharedMemory_[0] = 0x01; // enable command -void MainWindow::onPHDControlGuideTimeout() -{ - GetPHD2ControlInstruct(); + return pImpl->waitForResponse(500ms); } -void MainWindow::GetPHD2ControlInstruct() -{ - std::lock_guard lock(receiveMutex); - - unsigned int mem_offset; +bool PHDSharedMemoryClient::setLockPosition(double x, double y) { + constexpr unsigned int baseAddress = 0x03; + constexpr unsigned int vendCommand = + 0x13; // Assuming 0x13 is the command for setting lock position - int sdk_direction = 0; - int sdk_duration = 0; - int sdk_num = 0; - int zero = 0; - mem_offset = 1024; + std::memset(pImpl->sharedMemory_, 0, 1024); - mem_offset = mem_offset + sizeof(unsigned int); - mem_offset = mem_offset + sizeof(unsigned int); - mem_offset = mem_offset + sizeof(unsigned char); + pImpl->sharedMemory_[1] = Impl::msb(vendCommand); + pImpl->sharedMemory_[2] = Impl::lsb(vendCommand); - int ControlInstruct = 0; + std::memcpy(pImpl->sharedMemory_ + baseAddress, &x, sizeof(double)); + std::memcpy(pImpl->sharedMemory_ + baseAddress + sizeof(double), &y, + sizeof(double)); - memcpy(&ControlInstruct, sharedmemory_phd + mem_offset, sizeof(int)); - int mem_offset_sdk_num = mem_offset; - mem_offset = mem_offset + sizeof(int); + pImpl->sharedMemory_[0] = 0x01; // enable command - sdk_num = (ControlInstruct >> 24) & 0xFFF; // 取前12位 - sdk_direction = (ControlInstruct >> 12) & 0xFFF; // 取中间12位 - sdk_duration = ControlInstruct & 0xFFF; // 取后12位 - - if (sdk_num != 0) - { - getTimeNow(sdk_num); - std::cout << "\033[31m" - << "PHD2ControlTelescope: " - << "\033[0m" << sdk_num << "," << sdk_direction << "," - << sdk_duration << std::endl; - } - if (sdk_duration != 0) - { - MainWindow::ControlGuide(sdk_direction, sdk_duration); + return pImpl->waitForResponse(500ms); +} - memcpy(sharedmemory_phd + mem_offset_sdk_num, &zero, sizeof(int)); +std::pair PHDSharedMemoryClient::getStarPosition() const { + return {pImpl->starX_, pImpl->starY_}; +} - call_phd_ChackControlStatus(sdk_num); // set pFrame->ControlStatus = 0; - } +double PHDSharedMemoryClient::getGuideRMSError() const { + return pImpl->rmsError_; } diff --git a/src/client/phd2/shared.hpp b/src/client/phd2/shared.hpp index e69de29b..d04793cf 100644 --- a/src/client/phd2/shared.hpp +++ b/src/client/phd2/shared.hpp @@ -0,0 +1,47 @@ +#pragma once + +#include +#include +#include + +class PHDSharedMemoryClient { +public: + PHDSharedMemoryClient(); + ~PHDSharedMemoryClient(); + + // Disable copy operations + PHDSharedMemoryClient(const PHDSharedMemoryClient&) = delete; + PHDSharedMemoryClient& operator=(const PHDSharedMemoryClient&) = delete; + + // Enable move operations + PHDSharedMemoryClient(PHDSharedMemoryClient&&) noexcept; + PHDSharedMemoryClient& operator=(PHDSharedMemoryClient&&) noexcept; + + bool connectPHD(); + bool call_phd_GetVersion(std::string& versionName); + bool call_phd_StartLooping(); + bool call_phd_StopLooping(); + bool call_phd_AutoFindStar(); + bool call_phd_StartGuiding(); + bool call_phd_checkStatus(unsigned char& status); + bool call_phd_setExposureTime(unsigned int expTime); + bool call_phd_whichCamera(const 
std::string& camera); + bool call_phd_ChackControlStatus(int sdk_num); + bool call_phd_ClearCalibration(); + + void showPHDData(); + void controlGuide(int direction, int duration); + void getPHD2ControlInstruct(); + + // New functions + bool startCalibration(); + bool abortCalibration(); + bool dither(double pixels); + bool setLockPosition(double x, double y); + std::pair getStarPosition() const; + double getGuideRMSError() const; + +private: + class Impl; + std::unique_ptr pImpl; +}; diff --git a/src/client/planetarium/skysafari.cpp b/src/client/planetarium/skysafari.cpp new file mode 100644 index 00000000..5f1e1584 --- /dev/null +++ b/src/client/planetarium/skysafari.cpp @@ -0,0 +1,264 @@ +#include "skysafari.hpp" + +#include +#include +#include + +class SkySafariController::Impl { +public: + bool initialize(const std::string& host, int port) { + // Actual connection code would go here + m_connected = true; + return m_connected; + } + + std::string processCommand(std::string_view command) { + if (command.starts_with("GR")) + return getRightAscension(); + else if (command.starts_with("GD")) + return getDeclination(); + else if (command.starts_with("Sr")) + return setRightAscension(command.substr(2)); + else if (command.starts_with("Sd")) + return setDeclination(command.substr(2)); + else if (command == "MS") + return goTo(); + else if (command == "CM") + return sync(); + else if (command == "Q") + return abort(); + else if (command == "GG") + return getUTCOffset(); + else if (command.starts_with("SG")) + return setUTCOffset(command.substr(2)); + else if (command.starts_with("St")) + return setLatitude(command.substr(2)); + else if (command.starts_with("Sg")) + return setLongitude(command.substr(2)); + else if (command == "MP") + return park() ? "1" : "0"; + else if (command == "MU") + return unpark() ? 
"1" : "0"; + // Add other command handlers here + + return "1"; // Default response for unrecognized commands + } + + void setTargetCoordinates(const Coordinates& coords) { + m_targetCoords = coords; + } + void setGeographicCoordinates(const GeographicCoordinates& coords) { + m_geoCoords = coords; + } + void setDateTime(const DateTime& dt) { m_dateTime = dt; } + void setSlewRate(SlewRate rate) { m_slewRate = rate; } + + bool startSlew(Direction direction) { + m_slewingDirection = direction; + // Actual slew start code would go here + return true; + } + + bool stopSlew(Direction direction) { + if (m_slewingDirection == direction) { + m_slewingDirection = std::nullopt; + // Actual slew stop code would go here + return true; + } + return false; + } + + bool park() { + // Actual park code would go here + m_parked = true; + return true; + } + + bool unpark() { + // Actual unpark code would go here + m_parked = false; + return true; + } + + Coordinates getCurrentCoordinates() const { + return m_currentCoords.value_or(Coordinates{0, 0}); + } + GeographicCoordinates getGeographicCoordinates() const { + return m_geoCoords.value_or(GeographicCoordinates{0, 0}); + } + DateTime getDateTime() const { return m_dateTime.value_or(DateTime{}); } + SlewRate getSlewRate() const { return m_slewRate; } + bool isConnected() const { return m_connected; } + bool isParked() const { return m_parked; } + +private: + bool m_connected = false; + bool m_parked = false; + std::optional m_currentCoords; + std::optional m_targetCoords; + std::optional m_geoCoords; + std::optional m_dateTime; + SlewRate m_slewRate = SlewRate::CENTERING; + std::optional m_slewingDirection; + + static std::string hoursToSexagesimal(double hours) { + int h = static_cast(hours); + int m = static_cast((hours - h) * 60); + int s = static_cast(((hours - h) * 60 - m) * 60); + return std::format("{:02d}:{:02d}:{:02d}", h, m, s); + } + + static std::string degreesToSexagesimal(double degrees) { + char sign = degrees >= 0 ? 
'+' : '-'; + degrees = std::abs(degrees); + int d = static_cast(degrees); + int m = static_cast((degrees - d) * 60); + int s = static_cast(((degrees - d) * 60 - m) * 60); + return std::format("{}{:02d}:{:02d}:{:02d}", sign, d, m, s); + } + + std::string getRightAscension() { + if (!m_currentCoords) + return "Error"; + return hoursToSexagesimal(m_currentCoords->ra) + '#'; + } + + std::string getDeclination() { + if (!m_currentCoords) + return "Error"; + return degreesToSexagesimal(m_currentCoords->dec) + '#'; + } + + std::string setRightAscension(std::string_view ra) { + if (!m_targetCoords) + m_targetCoords = Coordinates{}; + m_targetCoords->ra = std::stod(std::string(ra)); + return "1"; + } + + std::string setDeclination(std::string_view dec) { + if (!m_targetCoords) + m_targetCoords = Coordinates{}; + m_targetCoords->dec = std::stod(std::string(dec)); + return "1"; + } + + std::string goTo() { + if (!m_targetCoords) + return "2#"; + + // Actual GOTO implementation would go here + m_currentCoords = m_targetCoords; + return "0"; + } + + std::string sync() { + if (!m_targetCoords) + return "Error"; + + m_currentCoords = m_targetCoords; + return " M31 EX GAL MAG 3.5 SZ178.0'#"; + } + + std::string abort() { + m_slewingDirection = std::nullopt; + // Actual abort implementation would go here + return "1"; + } + + std::string getUTCOffset() { + if (!m_dateTime) + return "Error"; + return std::format("{:.1f}", m_dateTime->utcOffset); + } + + std::string setUTCOffset(std::string_view offset) { + if (!m_dateTime) + m_dateTime = DateTime{}; + m_dateTime->utcOffset = std::stod(std::string(offset)); + return "1"; + } + + std::string setLatitude(std::string_view lat) { + if (!m_geoCoords) + m_geoCoords = GeographicCoordinates{}; + m_geoCoords->latitude = std::stod(std::string(lat)); + return "1"; + } + + std::string setLongitude(std::string_view lon) { + if (!m_geoCoords) + m_geoCoords = GeographicCoordinates{}; + m_geoCoords->longitude = std::stod(std::string(lon)); + return "1"; + } +}; + +// SkySafariController implementation + +SkySafariController::SkySafariController() : pImpl(std::make_unique()) {} +SkySafariController::~SkySafariController() = default; + +SkySafariController::SkySafariController(SkySafariController&&) noexcept = + default; +SkySafariController& SkySafariController::operator=( + SkySafariController&&) noexcept = default; + +bool SkySafariController::initialize(const std::string& host, int port) { + return pImpl->initialize(host, port); +} + +std::string SkySafariController::processCommand(std::string_view command) { + return pImpl->processCommand(command); +} + +void SkySafariController::setTargetCoordinates(const Coordinates& coords) { + pImpl->setTargetCoordinates(coords); +} + +void SkySafariController::setGeographicCoordinates( + const GeographicCoordinates& coords) { + pImpl->setGeographicCoordinates(coords); +} + +void SkySafariController::setDateTime(const DateTime& dt) { + pImpl->setDateTime(dt); +} + +void SkySafariController::setSlewRate(SlewRate rate) { + pImpl->setSlewRate(rate); +} + +bool SkySafariController::startSlew(Direction direction) { + return pImpl->startSlew(direction); +} + +bool SkySafariController::stopSlew(Direction direction) { + return pImpl->stopSlew(direction); +} + +bool SkySafariController::park() { return pImpl->park(); } + +bool SkySafariController::unpark() { return pImpl->unpark(); } + +SkySafariController::Coordinates SkySafariController::getCurrentCoordinates() + const { + return pImpl->getCurrentCoordinates(); +} + 
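Reviewer note: `processCommand()` implements a small LX200-style command parser, so a throwaway driver makes the round-trips easier to eyeball. The sketch below is not part of the patch; the host/port values and the sample coordinates are placeholders, while the command strings and the controller API are the ones introduced above.

```cpp
// Illustrative driver: feed the controller a few commands the way a
// SkySafari client would, and print the raw responses.
#include <iostream>
#include <string_view>
#include <vector>

#include "skysafari.hpp"

int main() {
    SkySafariController controller;
    if (!controller.initialize("127.0.0.1", 4030)) {   // host and port are placeholders
        std::cerr << "failed to initialize controller\n";
        return 1;
    }

    // Seed a target and sync so that GR/GD have a current position to report.
    controller.setTargetCoordinates({5.5, -5.4});       // RA in hours, Dec in degrees (roughly M42)
    controller.processCommand("CM");                     // sync: current := target

    const std::vector<std::string_view> commands = {"GR", "GD", "MS", "Q", "MP"};
    for (auto cmd : commands) {
        std::cout << cmd << " -> " << controller.processCommand(cmd) << '\n';
    }
    return 0;
}
```

As written, `initialize()` only flips an internal flag, so this exercises the parser rather than a real mount; the "Actual ... code would go here" comments in `Impl` mark where the transport and slewing logic would hook in.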
+SkySafariController::GeographicCoordinates +SkySafariController::getGeographicCoordinates() const { + return pImpl->getGeographicCoordinates(); +} + +SkySafariController::DateTime SkySafariController::getDateTime() const { + return pImpl->getDateTime(); +} + +SkySafariController::SlewRate SkySafariController::getSlewRate() const { + return pImpl->getSlewRate(); +} + +bool SkySafariController::isConnected() const { return pImpl->isConnected(); } + +bool SkySafariController::isParked() const { return pImpl->isParked(); } diff --git a/src/client/planetarium/skysafari.hpp b/src/client/planetarium/skysafari.hpp new file mode 100644 index 00000000..cc2e0f6c --- /dev/null +++ b/src/client/planetarium/skysafari.hpp @@ -0,0 +1,61 @@ +#pragma once + +#include +#include +#include + +class SkySafariController { +public: + struct Coordinates { + double ra; // Right Ascension in hours + double dec; // Declination in degrees + }; + + struct GeographicCoordinates { + double latitude; + double longitude; + }; + + struct DateTime { + int year, month, day, hour, minute, second; + double utcOffset; + }; + + enum class SlewRate { GUIDE, CENTERING, FIND, MAX }; + enum class Direction { NORTH, SOUTH, EAST, WEST }; + + SkySafariController(); + ~SkySafariController(); + + // Prevent copying + SkySafariController(const SkySafariController&) = delete; + SkySafariController& operator=(const SkySafariController&) = delete; + + // Allow moving + SkySafariController(SkySafariController&&) noexcept; + SkySafariController& operator=(SkySafariController&&) noexcept; + + bool initialize(const std::string& host, int port); + std::string processCommand(std::string_view command); + + void setTargetCoordinates(const Coordinates& coords); + void setGeographicCoordinates(const GeographicCoordinates& coords); + void setDateTime(const DateTime& dt); + void setSlewRate(SlewRate rate); + + bool startSlew(Direction direction); + bool stopSlew(Direction direction); + bool park(); + bool unpark(); + + Coordinates getCurrentCoordinates() const; + GeographicCoordinates getGeographicCoordinates() const; + DateTime getDateTime() const; + SlewRate getSlewRate() const; + bool isConnected() const; + bool isParked() const; + +private: + class Impl; + std::unique_ptr pImpl; +}; diff --git a/src/client/platesolver/platesolver2.cpp b/src/client/platesolver/platesolver2.cpp new file mode 100644 index 00000000..f2e64084 --- /dev/null +++ b/src/client/platesolver/platesolver2.cpp @@ -0,0 +1,93 @@ +#include "platesolver2.hpp" +#include +#include +#include +#include +#include "device/template/solver.hpp" + +namespace fs = std::filesystem; + +Platesolve2Solver::Platesolve2Solver(std::string executableLocation) + : m_executableLocation(std::move(executableLocation)) , AtomSolver(executableLocation) {} + +PlateSolveResult Platesolve2Solver::solve( + const std::string& imageFilePath, + const std::optional& initialCoordinates, double fovW, + double fovH, int imageWidth, int imageHeight) { + int regions = 100; // Default value, could be made configurable + std::string outputFilePath = getOutputPath(imageFilePath); + std::string arguments = + getArguments(imageFilePath, initialCoordinates, fovW, fovH, regions); + + if (int result = executeCommand(m_executableLocation, arguments); + result != 0) { + std::cerr << "Error executing Platesolve2" << std::endl; + return PlateSolveResult{false}; + } + + return readResult(outputFilePath, imageWidth, imageHeight); +} + +std::string Platesolve2Solver::getOutputPath( + const std::string& imageFilePath) const { + 
fs::path path(imageFilePath); + return (path.parent_path() / path.stem()).string() + ".apm"; +} + +std::string Platesolve2Solver::getArguments( + const std::string& imageFilePath, + const std::optional& initialCoordinates, double fovW, + double fovH, int regions) const { + std::ostringstream oss; + if (initialCoordinates) { + oss << initialCoordinates->ra << "," << initialCoordinates->dec << ","; + } else { + oss << "0,0,"; + } + oss << fovW << "," << fovH << "," << regions << "," << imageFilePath + << ",0"; + return oss.str(); +} + +PlateSolveResult Platesolve2Solver::readResult( + const std::string& outputFilePath, int imageWidth, int imageHeight) const { + PlateSolveResult result{false}; + std::ifstream file(outputFilePath); + if (!file.is_open()) { + return result; + } + + std::string line; + int lineNum = 0; + while (std::getline(file, line)) { + std::vector tokens; + std::istringstream iss(line); + std::string token; + while (std::getline(iss, token, ',')) { + tokens.push_back(token); + } + + if (lineNum == 0 && tokens.size() > 2) { + result.success = (std::stoi(tokens[2]) == 1); + if (result.success) { + result.coordinates.ra = std::stod(tokens[0]); + result.coordinates.dec = std::stod(tokens[1]); + } + } else if (lineNum == 1 && tokens.size() > 2) { + result.pixscale = std::stod(tokens[0]); + result.positionAngle = 360 - std::stod(tokens[1]); + result.flipped = (std::stod(tokens[2]) >= 0); + if (*result.flipped) { + result.positionAngle += 180; + } + if (!std::isnan(result.pixscale)) { + double diagonalPixels = std::hypot(imageWidth, imageHeight); + result.radius = (diagonalPixels * result.pixscale) / + (2 * 3600); // Convert to degrees + } + } + lineNum++; + } + + return result; +} diff --git a/src/client/platesolver/platesolver2.hpp b/src/client/platesolver/platesolver2.hpp new file mode 100644 index 00000000..7d09b750 --- /dev/null +++ b/src/client/platesolver/platesolver2.hpp @@ -0,0 +1,27 @@ +#pragma once + +#include "device/template/solver.hpp" + +class Platesolve2Solver : public AtomSolver { +public: + Platesolve2Solver(std::string executableLocation); + + PlateSolveResult solve(const std::string& imageFilePath, + const std::optional& initialCoordinates, + double fovW, double fovH, int imageWidth, + int imageHeight) override; + +protected: + std::string getOutputPath(const std::string& imageFilePath) const override; + +private: + std::string m_executableLocation; + + std::string getArguments( + const std::string& imageFilePath, + const std::optional& initialCoordinates, double fovW, + double fovH, int regions) const; + + PlateSolveResult readResult(const std::string& outputFilePath, + int imageWidth, int imageHeight) const; +}; diff --git a/src/client/platesolver/platesolver3.cpp b/src/client/platesolver/platesolver3.cpp new file mode 100644 index 00000000..b28e53be --- /dev/null +++ b/src/client/platesolver/platesolver3.cpp @@ -0,0 +1,90 @@ +#include "platesolver3.hpp" +#include +#include +#include +#include + +namespace fs = std::filesystem; + +Platesolve3Solver::Platesolve3Solver(std::string executableLocation) + : m_executableLocation(std::move(executableLocation)) {} + +PlateSolveResult Platesolve3Solver::solve( + const std::string& imageFilePath, + const std::optional& initialCoordinates, double fovW, + double fovH, int imageWidth, int imageHeight) { + std::string outputFilePath = getOutputPath(imageFilePath); + std::string arguments = + getArguments(imageFilePath, initialCoordinates, fovW, fovH); + + if (int result = executeCommand(m_executableLocation, arguments); + 
result != 0) { + std::cerr << "Error executing Platesolve3" << std::endl; + return PlateSolveResult{false}; + } + + return readResult(outputFilePath, imageWidth, imageHeight); +} + +std::string Platesolve3Solver::getOutputPath( + const std::string& imageFilePath) const { + fs::path path(imageFilePath); + return (path.parent_path() / path.stem()).string() + "_PS3.txt"; +} + +std::string Platesolve3Solver::getArguments( + const std::string& imageFilePath, + const std::optional& initialCoordinates, double fovW, + double fovH) const { + std::ostringstream oss; + oss << "\"" << imageFilePath << "\" "; + if (initialCoordinates) { + oss << toRadians(initialCoordinates->ra) << " " + << toRadians(initialCoordinates->dec) << " "; + } else { + oss << "0 0 "; + } + oss << toRadians(fovW) << " " << toRadians(fovH); + return oss.str(); +} + +PlateSolveResult Platesolve3Solver::readResult( + const std::string& outputFilePath, int imageWidth, int imageHeight) const { + PlateSolveResult result{false}; + std::ifstream file(outputFilePath); + if (!file.is_open()) { + return result; + } + + std::string line; + int lineNum = 0; + while (std::getline(file, line)) { + std::vector tokens; + std::istringstream iss(line); + std::string token; + while (std::getline(iss, token, ',')) { + tokens.push_back(token); + } + + if (lineNum == 0) { + result.success = (line == "True"); + if (!result.success) + return result; + } else if (lineNum == 1 && tokens.size() >= 2) { + result.coordinates.ra = toDegrees(std::stod(tokens[0])); + result.coordinates.dec = toDegrees(std::stod(tokens[1])); + } else if (lineNum == 2 && tokens.size() >= 2) { + result.pixscale = 206264.8 / std::stod(tokens[0]); + if (!std::isnan(result.pixscale)) { + result.radius = + arcsecToDegree(std::hypot(imageWidth * result.pixscale, + imageHeight * result.pixscale) / + 2.0); + } + result.positionAngle = std::stod(tokens[1]); + } + lineNum++; + } + + return result; +} diff --git a/src/client/platesolver/platesolver3.hpp b/src/client/platesolver/platesolver3.hpp new file mode 100644 index 00000000..4c1b6368 --- /dev/null +++ b/src/client/platesolver/platesolver3.hpp @@ -0,0 +1,27 @@ +#pragma once + +#include "device/template/solver.hpp" + +class Platesolve3Solver : public AtomSolver { +public: + Platesolve3Solver(std::string executableLocation); + + PlateSolveResult solve(const std::string& imageFilePath, + const std::optional& initialCoordinates, + double fovW, double fovH, int imageWidth, + int imageHeight) override; + +protected: + std::string getOutputPath(const std::string& imageFilePath) const override; + +private: + std::string m_executableLocation; + + std::string getArguments( + const std::string& imageFilePath, + const std::optional& initialCoordinates, double fovW, + double fovH) const; + + PlateSolveResult readResult(const std::string& outputFilePath, + int imageWidth, int imageHeight) const; +}; diff --git a/src/config/configor.hpp b/src/config/configor.hpp index 6f14c89f..6d1a12e4 100644 --- a/src/config/configor.hpp +++ b/src/config/configor.hpp @@ -19,45 +19,98 @@ Description: Configor #include #include #include + +#include "atom/error/exception.hpp" #include "atom/type/json_fwd.hpp" +#include "utils/constant.hpp" + namespace fs = std::filesystem; +using json = nlohmann::json; -#define GetIntConfig(path) \ - GetPtr("lithium.config") \ - .value() \ - ->getValue(path) \ - .value() \ +#define GetIntConfig(path) \ + GetPtr(Constatns::CONFIG_MANAGER) \ + .value() \ + ->getValue(path) \ + .value() \ .get() -#define GetFloatConfig(path) \ - 
GetPtr("lithium.config") \ - .value() \ - ->getValue(path) \ - .value() \ +#define GetFloatConfig(path) \ + GetPtr(Constatns::CONFIG_MANAGER) \ + .value() \ + ->getValue(path) \ + .value() \ .get() -#define GetBoolConfig(path) \ - GetPtr("lithium.config") \ - .value() \ - ->getValue(path) \ - .value() \ +#define GetBoolConfig(path) \ + GetPtr(Constatns::CONFIG_MANAGER) \ + .value() \ + ->getValue(path) \ + .value() \ .get() -#define GetDoubleConfig(path) \ - GetPtr("lithium.config") \ - .value() \ - ->getValue(path) \ - .value() \ +#define GetDoubleConfig(path) \ + GetPtr(Constatns::CONFIG_MANAGER) \ + .value() \ + ->getValue(path) \ + .value() \ .get() -#define GetStringConfig(path) \ - GetPtr("lithium.config") \ - .value() \ - ->getValue(path) \ - .value() \ +#define GetStringConfig(path) \ + GetPtr(Constatns::CONFIG_MANAGER) \ + .value() \ + ->getValue(path) \ + .value() \ .get() +#define GET_CONFIG_VALUE(configManager, path, type, outputVar) \ + type outputVar; \ + do { \ + auto opt = (configManager)->getValue(path); \ + if (opt.has_value()) { \ + try { \ + (outputVar) = opt.value().get(); \ + } catch (const std::bad_optional_access& e) { \ + LOG_F(ERROR, "Bad access to config value for {}: {}", path, \ + e.what()); \ + THROW_BAD_CONFIG_EXCEPTION(e.what()); \ + } catch (const json::exception& e) { \ + LOG_F(ERROR, "Invalid config value for {}: {}", path, \ + e.what()); \ + THROW_INVALID_CONFIG_EXCEPTION(e.what()); \ + } catch (const std::exception& e) { \ + THROW_UNKOWN(e.what()); \ + } \ + } else { \ + LOG_F(WARNING, "Config value for {} not found", path); \ + THROW_OBJ_NOT_EXIST("Config value for", path); \ + } \ + } while (0) + +class BadConfigException : public atom::error::Exception { + using atom::error::Exception::Exception; +}; + +#define THROW_BAD_CONFIG_EXCEPTION(...) \ + throw BadConfigException(ATOM_FILE_NAME, ATOM_FILE_LINE, ATOM_FUNC_NAME, \ + __VA_ARGS__) + +#define THROW_NESTED_BAD_CONFIG_EXCEPTION(...) \ + BadConfigException::rethrowNested(ATOM_FILE_NAME, ATOM_FILE_LINE, \ + ATOM_FUNC_NAME, __VA_ARGS__) + +class InvalidConfigException : public BadConfigException { + using BadConfigException::BadConfigException; +}; + +#define THROW_INVALID_CONFIG_EXCEPTION(...) \ + throw InvalidConfigException(ATOM_FILE_NAME, ATOM_FILE_LINE, \ + ATOM_FUNC_NAME, __VA_ARGS__) + +#define THROW_NESTED_INVALID_CONFIG_EXCEPTION(...) \ + InvalidConfigException::rethrowNested(ATOM_FILE_NAME, ATOM_FILE_LINE, \ + ATOM_FUNC_NAME, __VA_ARGS__) + namespace lithium { class ConfigManagerImpl; /** @@ -96,10 +149,10 @@ class ConfigManager { /** * @brief Retrieves the value associated with the given key path. * @param key_path The path to the configuration value. - * @return std::optional The optional JSON value if found. + * @return std::optional The optional JSON value if found. */ [[nodiscard]] auto getValue(const std::string& key_path) const - -> std::optional; + -> std::optional; /** * @brief Sets the value for the specified key path. @@ -107,10 +160,9 @@ class ConfigManager { * @param value The JSON value to set. * @return bool True if the value was successfully set, false otherwise. */ - auto setValue(const std::string& key_path, - const nlohmann::json& value) -> bool; + auto setValue(const std::string& key_path, const json& value) -> bool; - auto appendValue(const std::string& key_path, const nlohmann::json& value) -> bool; + auto appendValue(const std::string& key_path, const json& value) -> bool; /** * @brief Deletes the value associated with the given key path. 
@@ -166,13 +218,13 @@ class ConfigManager { * @brief Merges the current configuration with the provided JSON data. * @param src The JSON object to merge into the current configuration. */ - void mergeConfig(const nlohmann::json& src); + void mergeConfig(const json& src); private: std::unique_ptr m_impl_; ///< Implementation-specific pointer. - void mergeConfig(const nlohmann::json& src, nlohmann::json& target); + void mergeConfig(const json& src, json& target); }; } // namespace lithium diff --git a/src/debug/command.cpp b/src/debug/command.cpp index 1533a6d0..42c345f7 100644 --- a/src/debug/command.cpp +++ b/src/debug/command.cpp @@ -22,7 +22,7 @@ void loadSharedCompoennt(const std::string &compoennt_name, return; } auto manager = GetWeakPtr( - constants::LITHIUM_COMPONENT_MANAGER); + Constants::LITHIUM_COMPONENT_MANAGER); if (manager.expired()) { std::cout << "Component manager not found" << '\n'; return; @@ -31,7 +31,7 @@ void loadSharedCompoennt(const std::string &compoennt_name, {{"component_name", compoennt_name}, {"module_name", module_name}, {"module_path", atom::system::getCurrentWorkingDirectory() + - constants::MODULE_FOLDER}})) { + Constants::MODULE_FOLDER}})) { std::cout << "Failed to load component" << '\n'; return; } @@ -44,7 +44,7 @@ void unloadSharedCompoennt(const std::string &compoennt_name) { return; } if (!GetWeakPtr( - constants::LITHIUM_COMPONENT_MANAGER) + Constants::LITHIUM_COMPONENT_MANAGER) .lock() ->unloadComponent({{"component_name", compoennt_name}})) { std::cout << "Failed to unload component" << '\n'; @@ -59,7 +59,7 @@ void reloadSharedCompoennt(const std::string &compoennt_name) { return; } if (!GetWeakPtr( - constants::LITHIUM_COMPONENT_MANAGER) + Constants::LITHIUM_COMPONENT_MANAGER) .lock() ->reloadComponent({{"component_name", compoennt_name}})) { std::cout << "Failed to reload component" << '\n'; @@ -70,7 +70,7 @@ void reloadSharedCompoennt(const std::string &compoennt_name) { void reloadAllComponents() { if (!GetWeakPtr( - constants::LITHIUM_COMPONENT_MANAGER) + Constants::LITHIUM_COMPONENT_MANAGER) .lock() ->reloadAllComponents()) { std::cout << "Failed to reload all components" << '\n'; @@ -85,7 +85,7 @@ void scanComponents(const std::string &path) { return; } if (auto vec = GetWeakPtr( - constants::LITHIUM_COMPONENT_MANAGER) + Constants::LITHIUM_COMPONENT_MANAGER) .lock() ->scanComponents(path); vec.empty()) { @@ -105,7 +105,7 @@ void getComponentInfo(const std::string &name) { return; } auto manager = GetWeakPtr( - constants::LITHIUM_COMPONENT_MANAGER); + Constants::LITHIUM_COMPONENT_MANAGER); if (manager.expired()) { std::cout << "Component manager not found" << '\n'; return; @@ -121,7 +121,7 @@ void getComponentInfo(const std::string &name) { void getComponentList() { auto manager = GetWeakPtr( - constants::LITHIUM_COMPONENT_MANAGER); + Constants::LITHIUM_COMPONENT_MANAGER); if (manager.expired()) { std::cout << "Component manager not found" << std::endl; return; diff --git a/src/device/template/solver.hpp b/src/device/template/solver.hpp index 140cd8cc..9283c149 100644 --- a/src/device/template/solver.hpp +++ b/src/device/template/solver.hpp @@ -16,11 +16,50 @@ Description: AtomSolver Simulator and Basic Definition #include "device.hpp" +#include #include +#include "macro.hpp" + +template +concept CoordinateType = requires(T t) { + { t.ra } -> std::convertible_to; + { t.dec } -> std::convertible_to; +}; + +struct Coordinates { + double ra; // in degrees + double dec; // in degrees +} ATOM_ALIGNAS(16); + +struct PlateSolveResult { + bool 
success{}; + Coordinates coordinates{}; + double pixscale{}; + double positionAngle{}; + std::optional flipped; + double radius{}; +} ATOM_ALIGNAS(64); + class AtomSolver : public AtomDriver { public: - explicit AtomSolver(std::string name) : AtomDriver(name) {} + explicit AtomSolver(std::string name) : AtomDriver(std::move(name)) {} + + virtual PlateSolveResult solve( + const std::string& imageFilePath, + const std::optional& initialCoordinates, double fovW, + double fovH, int imageWidth, int imageHeight) = 0; + + std::future async_solve( + const std::string& imageFilePath, + const std::optional& initialCoordinates, double fovW, + double fovH, int imageWidth, int imageHeight); +protected: + static double toRadians(double degrees); + static double toDegrees(double radians); + static double arcsecToDegree(double arcsec); + virtual std::string getOutputPath( + const std::string& imageFilePath) const = 0; }; diff --git a/src/task/CMakeLists.txt b/src/task/CMakeLists.txt index 4b1649c4..8cfd04f2 100644 --- a/src/task/CMakeLists.txt +++ b/src/task/CMakeLists.txt @@ -15,7 +15,6 @@ set(PROJECT_SOURCES generator.cpp loader.cpp manager.cpp - sequencer.cpp singlepool.cpp task.cpp ) @@ -26,7 +25,6 @@ set(PROJECT_HEADERS generator.hpp loader.hpp manager.hpp - sequencer.hpp singlepool.hpp task.hpp ) diff --git a/src/task/builtin.cpp b/src/task/builtin.cpp new file mode 100644 index 00000000..6af9ef7b --- /dev/null +++ b/src/task/builtin.cpp @@ -0,0 +1,56 @@ +#include "builtin.hpp" +#include +#include +#include + +#include "atom/type/json.hpp" + +namespace lithium { + +BuiltinFunctions::BuiltinFunctions() { + registerMathFunctions(); + registerStringFunctions(); + registerArrayFunctions(); +} + +auto BuiltinFunctions::executeFunction(const std::string& name, const nlohmann::json& args) -> nlohmann::json { + if (functions_.find(name) == functions_.end()) { + LITHIUM_THROW(std::runtime_error, "Unknown builtin function: {}", name); + } + return functions_[name](args); +} + +void BuiltinFunctions::registerMathFunctions() { + functions_["math_sin"] = [](const nlohmann::json& args) -> json { + return std::sin(args[0].get()); + }; + functions_["math_cos"] = [](const nlohmann::json& args) { + return std::cos(args[0].get()); + }; + functions_["math_tan"] = [](const nlohmann::json& args) { + return std::tan(args[0].get()); + }; + functions_["math_pow"] = [](const nlohmann::json& args) { + return std::pow(args[0].get(), args[1].get()); + }; + // Add more math functions as needed +} + +void BuiltinFunctions::registerStringFunctions() { + functions_["string_length"] = [](const nlohmann::json& args) { + return args[0].get().length(); + }; + functions_["string_to_upper"] = [](const nlohmann::json& args) { + std::string result = args[0].get(); + std::transform(result.begin(), result.end(), result.begin(), ::toupper); + return result; + }; + functions_["string_to_lower"] = [](const nlohmann::json& args) { + std::string result = args[0].get(); + std::transform(result.begin(), result.end(), result.begin(), ::tolower); + return result; + }; + // Add more string functions as needed +} + +} // namespace lithium diff --git a/src/task/builtin.hpp b/src/task/builtin.hpp new file mode 100644 index 00000000..5677bd5c --- /dev/null +++ b/src/task/builtin.hpp @@ -0,0 +1,29 @@ +#ifndef LITHIUM_TASK_INTERPRETER_BUILTINS_HPP +#define LITHIUM_TASK_INTERPRETER_BUILTINS_HPP + +#include +#include +#include "atom/type/json_fwd.hpp" + +using json = nlohmann::json; + +namespace lithium { + +class BuiltinFunctions { +public: + 
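+    // Registers the math, string and array built-ins into functions_ at construction.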
BuiltinFunctions(); + + auto executeFunction(const std::string& name, const json& args) -> json; + +private: + std::unordered_map> + functions_; + + void registerMathFunctions(); + void registerStringFunctions(); + void registerArrayFunctions(); +}; + +} // namespace lithium + +#endif // LITHIUM_TASK_INTERPRETER_BUILTINS_HPP diff --git a/src/task/custom/autofocus/autofocus.cpp b/src/task/custom/autofocus/autofocus.cpp new file mode 100644 index 00000000..ef0daab3 --- /dev/null +++ b/src/task/custom/autofocus/autofocus.cpp @@ -0,0 +1,149 @@ +#include "autofocus.hpp" +#include +#include +#include "curve.hpp" +#include "detector.cpp" +#include "utils.hpp" + +const double AstroAutoFocus::HFR_THRESHOLD = 0.1; +const double AstroAutoFocus::TEMPERATURE_COEFFICIENT = 0.001; + +AstroAutoFocus::AstroAutoFocus() + : currentPosition(0), + bestPosition(0), + bestHFR(std::numeric_limits::max()), + currentTemperature(20.0), + starDetector(new StarDetector()), + curveFitter(new FocusCurveFitter()) {} + +void AstroAutoFocus::focus(const std::vector& images, + const std::vector& positions, + double temperature) { + hfrScores.clear(); + focusPositions = positions; + bestHFR = std::numeric_limits::max(); + currentTemperature = temperature; + + int step = std::abs(positions[1] - positions[0]); + + for (int i = 0; + i < std::min(static_cast(images.size()), MAX_FOCUS_STEPS); ++i) { + double hfr = calculateHFR(images[i]); + if (!Utils::isOutlier(hfr, hfrScores)) { + hfrScores.push_back(hfr); + + if (hfr < bestHFR) { + bestHFR = hfr; + bestPosition = focusPositions[i]; + } + + if (isPeak(i) && !isFalsePeak(i)) { + currentPosition = focusPositions[i]; + break; + } + + if (i > 0) { + step = calculateAdaptiveStep(step, hfrScores[i - 1], hfr); + } + + if (i > WINDOW_SIZE && + std::all_of(hfrScores.end() - WINDOW_SIZE, hfrScores.end(), + [&](double h) { return h > bestHFR; })) { + break; + } + } + } + + std::vector smoothedHFR = Utils::applyNoiseReduction(hfrScores); + + auto [fitPosition, fitHFR] = + curveFitter->fitCurve(focusPositions, smoothedHFR); + if (fitHFR < bestHFR) { + bestPosition = fitPosition; + bestHFR = fitHFR; + } + + currentPosition = getTemperatureCompensatedPosition(bestPosition); + updateFocusHistory(currentPosition, bestHFR); +} + +double AstroAutoFocus::calculateHFR(const cv::Mat& image) { + std::vector stars = starDetector->detectStars(image); + if (stars.empty()) { + return std::numeric_limits::max(); + } + double totalHFR = 0; + for (const auto& star : stars) { + totalHFR += star.hfr; + } + return totalHFR / stars.size(); +} + +bool AstroAutoFocus::isPeak(int index) { + if (index <= 0 || index >= hfrScores.size() - 1) + return false; + return hfrScores[index] < hfrScores[index - 1] && + hfrScores[index] < hfrScores[index + 1]; +} + +bool AstroAutoFocus::isFalsePeak(int index) { + if (index <= WINDOW_SIZE / 2 || index >= hfrScores.size() - WINDOW_SIZE / 2) + return false; + + double localMean = 0; + for (int i = index - WINDOW_SIZE / 2; i <= index + WINDOW_SIZE / 2; ++i) { + localMean += hfrScores[i]; + } + localMean /= WINDOW_SIZE; + + double localStdDev = 0; + for (int i = index - WINDOW_SIZE / 2; i <= index + WINDOW_SIZE / 2; ++i) { + localStdDev += std::pow(hfrScores[i] - localMean, 2); + } + localStdDev = std::sqrt(localStdDev / WINDOW_SIZE); + + return localStdDev < HFR_THRESHOLD; +} + +int AstroAutoFocus::calculateAdaptiveStep(int currentStep, double previousHFR, + double currentHFR) { + double hfrChange = std::abs(currentHFR - previousHFR); + if (hfrChange < HFR_THRESHOLD / 2) { + 
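+        // The HFR changed very little between steps, so double the step size.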
return currentStep * 2; + } else if (hfrChange > HFR_THRESHOLD * 2) { + return std::max(1, currentStep / 2); + } + return currentStep; +} + +void AstroAutoFocus::updateFocusHistory(int position, double hfr) { + focusHistory.push_front({currentTemperature, position}); + if (focusHistory.size() > HISTORY_SIZE) { + focusHistory.pop_back(); + } +} + +int AstroAutoFocus::getTemperatureCompensatedPosition(int position) { + if (focusHistory.empty()) { + return position; + } + double tempDiff = currentTemperature - focusHistory.front().first; + int compensation = static_cast(tempDiff * TEMPERATURE_COEFFICIENT); + return position + compensation; +} + +int AstroAutoFocus::getFocusPosition() const { return currentPosition; } + +double AstroAutoFocus::getBestHFR() const { return bestHFR; } + +std::vector> AstroAutoFocus::getFocusCurve() const { + std::vector> curve; + for (size_t i = 0; i < focusPositions.size(); ++i) { + curve.emplace_back(focusPositions[i], hfrScores[i]); + } + return curve; +} + +void AstroAutoFocus::setTemperature(double temperature) { + currentTemperature = temperature; +} diff --git a/src/task/custom/autofocus/autofocus.hpp b/src/task/custom/autofocus/autofocus.hpp new file mode 100644 index 00000000..1f041062 --- /dev/null +++ b/src/task/custom/autofocus/autofocus.hpp @@ -0,0 +1,45 @@ +#pragma once + +#include +#include +#include + +class StarDetector; +class FocusCurveFitter; + +class AstroAutoFocus { +public: + AstroAutoFocus(); + void focus(const std::vector& images, + const std::vector& positions, double temperature); + int getFocusPosition() const; + double getBestHFR() const; + std::vector> getFocusCurve() const; + void setTemperature(double temperature); + +private: + static const int MAX_FOCUS_STEPS = 100; + static const double HFR_THRESHOLD; + static const int WINDOW_SIZE = 5; + static const double TEMPERATURE_COEFFICIENT; + static const int HISTORY_SIZE = 10; + + std::vector hfrScores; + std::vector focusPositions; + int currentPosition; + int bestPosition; + double bestHFR; + double currentTemperature; + std::deque> focusHistory; + + StarDetector* starDetector; + FocusCurveFitter* curveFitter; + + double calculateHFR(const cv::Mat& image); + bool isPeak(int index); + bool isFalsePeak(int index); + int calculateAdaptiveStep(int currentStep, double previousHFR, + double currentHFR); + void updateFocusHistory(int position, double hfr); + int getTemperatureCompensatedPosition(int position); +}; diff --git a/src/task/custom/autofocus/curve.cpp b/src/task/custom/autofocus/curve.cpp new file mode 100644 index 00000000..150268b6 --- /dev/null +++ b/src/task/custom/autofocus/curve.cpp @@ -0,0 +1,462 @@ +#include "curve.hpp" + +#include +#include +#include +#include +#include +#include + +#include "atom/log/loguru.hpp" +#include "atom/system/command.hpp" + +class FocusCurveFitter::Impl { +public: + std::vector data_; + int polynomial_degree_ = 2; + ModelType current_model_ = ModelType::POLYNOMIAL; + + void addDataPoint(double position, double sharpness) { + data_.push_back({position, sharpness}); + } + + std::vector fitCurve() { + switch (current_model_) { + case ModelType::POLYNOMIAL: + return fitPolynomialCurve(); + case ModelType::GAUSSIAN: + return fitGaussianCurve(); + case ModelType::LORENTZIAN: + return fitLorentzianCurve(); + } + return {}; + } + + std::vector fitPolynomialCurve() { + int n = data_.size(); + int degree = polynomial_degree_; + + std::vector> X(n, std::vector(degree + 1)); + std::vector y(n); + + for (int i = 0; i < n; ++i) { + for (int j = 0; j <= 
degree; ++j) { + X[i][j] = std::pow(data_[i].position, j); + } + y[i] = data_[i].sharpness; + } + + auto xt = transpose(X); + auto xtX = matrixMultiply(xt, X); + auto xty = matrixVectorMultiply(xt, y); + return solveLinearSystem(xtX, xty); + } + + std::vector fitGaussianCurve() { + auto [min_it, max_it] = + std::minmax_element(data_.begin(), data_.end(), + [](const DataPoint& a, const DataPoint& b) { + return a.sharpness < b.sharpness; + }); + + std::vector initialGuess = { + max_it->sharpness - min_it->sharpness, max_it->position, 1.0, + min_it->sharpness}; + + return levenbergMarquardt( + initialGuess, [](double x, const std::vector& params) { + double A = params[0], mu = params[1], sigma = params[2], + C = params[3]; + return A * std::exp(-std::pow(x - mu, 2) / + (2 * std::pow(sigma, 2))) + + C; + }); + } + + std::vector fitLorentzianCurve() { + auto [min_it, max_it] = + std::minmax_element(data_.begin(), data_.end(), + [](const DataPoint& a, const DataPoint& b) { + return a.sharpness < b.sharpness; + }); + + std::vector initialGuess = { + max_it->sharpness - min_it->sharpness, max_it->position, 1.0, + min_it->sharpness}; + + return levenbergMarquardt( + initialGuess, [](double x, const std::vector& params) { + double A = params[0], x0 = params[1], gamma = params[2], + C = params[3]; + return A / (1 + std::pow((x - x0) / gamma, 2)) + C; + }); + } + + void autoSelectModel() { + std::vector models = { + ModelType::POLYNOMIAL, ModelType::GAUSSIAN, ModelType::LORENTZIAN}; + double bestAic = std::numeric_limits::infinity(); + ModelType bestModel = ModelType::POLYNOMIAL; + + for (const auto& model : models) { + current_model_ = model; + auto coeffs = fitCurve(); + double aic = calculateAIC(coeffs); + if (aic < bestAic) { + bestAic = aic; + bestModel = model; + } + } + + current_model_ = bestModel; + LOG_F(INFO, "Selected model: {}", getModelName(current_model_)); + } + + std::vector> calculateConfidenceIntervals( + double confidence_level = 0.95) { + auto coeffs = fitCurve(); + int n = data_.size(); + int p = coeffs.size(); + double tValue = calculateTValue(n - p, confidence_level); + + std::vector> intervals; + for (int i = 0; i < p; ++i) { + double se = calculateStandardError(coeffs, i); + intervals.emplace_back(coeffs[i] - tValue * se, + coeffs[i] + tValue * se); + } + return intervals; + } + + void visualize(const std::string& filename = "focus_curve.png") { + std::ofstream gnuplotScript("plot_script.gp"); + gnuplotScript << "set terminal png\n"; + gnuplotScript << "set output '" << filename << "'\n"; + gnuplotScript << "set title 'Focus Position Curve'\n"; + gnuplotScript << "set xlabel 'Position'\n"; + gnuplotScript << "set ylabel 'Sharpness'\n"; + gnuplotScript << "plot '-' with points title 'Data', '-' with lines " + "title 'Fitted Curve'\n"; + + for (const auto& point : data_) { + gnuplotScript << point.position << " " << point.sharpness << "\n"; + } + gnuplotScript << "e\n"; + + auto coeffs = fitCurve(); + double minPos = data_.front().position; + double maxPos = data_.back().position; + for (double pos = minPos; pos <= maxPos; + pos += (maxPos - minPos) / 100) { + gnuplotScript << pos << " " << evaluateCurve(coeffs, pos) << "\n"; + } + gnuplotScript << "e\n"; + gnuplotScript.close(); + + auto res = + atom::system::executeCommandWithStatus("gnuplot plot_script.gp"); + if (res.second != 0) { + LOG_F(ERROR, "Failed to execute gnuplot script: {}", res.first); + return; + } + LOG_F(INFO, "Curve visualization saved as {}", filename); + } + + void preprocessData() { + 
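+        // Sort samples by position, remove duplicate positions, then normalize positions and sharpness to [0, 1].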
std::sort(data_.begin(), data_.end(), + [](const DataPoint& a, const DataPoint& b) { + return a.position < b.position; + }); + + data_.erase(std::unique(data_.begin(), data_.end(), + [](const DataPoint& a, const DataPoint& b) { + return a.position == b.position; + }), + data_.end()); + + double minPos = data_.front().position; + double maxPos = data_.back().position; + double minSharpness = std::numeric_limits::infinity(); + double maxSharpness = -std::numeric_limits::infinity(); + + for (const auto& point : data_) { + minSharpness = std::min(minSharpness, point.sharpness); + maxSharpness = std::max(maxSharpness, point.sharpness); + } + + for (auto& point : data_) { + point.position = (point.position - minPos) / (maxPos - minPos); + point.sharpness = (point.sharpness - minSharpness) / + (maxSharpness - minSharpness); + } + } + + void realTimeFitAndPredict(double new_position) { + addDataPoint(new_position, 0); + preprocessData(); + auto coeffs = fitCurve(); + double predictedSharpness = evaluateCurve(coeffs, new_position); + LOG_F(INFO, "Predicted sharpness at position {}: {}", new_position, + predictedSharpness); + } + + void parallelFitting() { + int numThreads = std::thread::hardware_concurrency(); + std::vector>> futures; + + for (int i = 0; i < numThreads; ++i) { + futures.push_back(std::async(std::launch::async, + [this]() { return fitCurve(); })); + } + + std::vector> results; + results.reserve(futures.size()); + for (auto& future : futures) { + results.push_back(future.get()); + } + + // Choose the best fit based on MSE + auto bestFit = + *std::min_element(results.begin(), results.end(), + [this](const auto& a, const auto& b) { + return calculateMSE(a) < calculateMSE(b); + }); + + LOG_F(INFO, "Best parallel fit MSE: {}", calculateMSE(bestFit)); + } + +private: + // Helper functions + + static auto matrixVectorMultiply(const std::vector>& A, + const std::vector& v) + -> std::vector { + int m = A.size(); + int n = A[0].size(); + std::vector result(m, 0.0); + + for (int i = 0; i < m; ++i) { + for (int j = 0; j < n; ++j) { + result[i] += A[i][j] * v[j]; + } + } + return result; + } + + static auto matrixMultiply(const std::vector>& A, + const std::vector>& B) + -> std::vector> { + int m = A.size(); + int n = A[0].size(); + int p = B[0].size(); + + std::vector> C(m, std::vector(p, 0.0)); + + for (int i = 0; i < m; ++i) { + for (int j = 0; j < p; ++j) { + for (int k = 0; k < n; ++k) { + C[i][j] += A[i][k] * B[k][j]; + } + } + } + return C; + } + + template + auto levenbergMarquardt(const std::vector& initial_guess, + Func model) -> std::vector { + const int MAX_ITERATIONS = 100; + const double TOLERANCE = 1e-6; + double lambda = 0.001; + + std::vector params = initial_guess; + int n = data_.size(); + int p = initial_guess.size(); + + for (int iter = 0; iter < MAX_ITERATIONS; ++iter) { + std::vector> J( + n, std::vector(p)); // Jacobian matrix + std::vector residuals(n); + + for (int i = 0; i < n; ++i) { + double x = data_[i].position; + double y = data_[i].sharpness; + double modelValue = model(x, params); + residuals[i] = y - modelValue; + + for (int j = 0; j < p; ++j) { + std::vector paramsDelta = params; + paramsDelta[j] += TOLERANCE; + double modelDelta = model(x, paramsDelta); + J[i][j] = (modelDelta - modelValue) / TOLERANCE; + } + } + + auto jt = transpose(J); + auto jtJ = matrixMultiply(jt, J); + for (int i = 0; i < p; ++i) { + jtJ[i][i] += lambda; + } + auto jtr = matrixVectorMultiply(jt, residuals); + auto deltaParams = solveLinearSystem(jtJ, jtr); + + for (int i = 0; i < p; 
++i) { + params[i] += deltaParams[i]; + } + + if (std::inner_product(deltaParams.begin(), deltaParams.end(), + deltaParams.begin(), 0.0) < TOLERANCE) { + break; + } + } + + return params; + } + + auto calculateAIC(const std::vector& coeffs) -> double { + int n = data_.size(); + int p = coeffs.size(); + double mse = calculateMSE(coeffs); + double aic = n * std::log(mse) + 2 * p; + return aic; + } + + static auto calculateTValue(int /*degrees_of_freedom*/, + double confidence_level) -> double { + if (confidence_level == 0.95) { + return 1.96; + } + return 1.0; + } + + auto calculateStandardError(const std::vector& coeffs, + int /*index*/) -> double { + double mse = calculateMSE(coeffs); + return std::sqrt(mse); + } + + static auto transpose(const std::vector>& A) + -> std::vector> { + int m = A.size(); + int n = A[0].size(); + std::vector> At(n, std::vector(m)); + for (int i = 0; i < m; ++i) { + for (int j = 0; j < n; ++j) { + At[j][i] = A[i][j]; + } + } + return At; + } + + auto calculateMSE(const std::vector& coeffs) -> double { + double mse = 0.0; + for (const auto& point : data_) { + double predicted = evaluateCurve(coeffs, point.position); + mse += std::pow(predicted - point.sharpness, 2); + } + return mse / data_.size(); + } + + static auto solveLinearSystem(std::vector> A, + std::vector b) + -> std::vector { + int n = A.size(); + for (int i = 0; i < n; ++i) { + int maxRow = i; + for (int j = i + 1; j < n; ++j) { + if (std::abs(A[j][i]) > std::abs(A[maxRow][i])) { + maxRow = j; + } + } + std::swap(A[i], A[maxRow]); + std::swap(b[i], b[maxRow]); + + for (int j = i + 1; j < n; ++j) { + double factor = A[j][i] / A[i][i]; + for (int k = i; k < n; ++k) { + A[j][k] -= factor * A[i][k]; + } + b[j] -= factor * b[i]; + } + } + + std::vector x(n); + for (int i = n - 1; i >= 0; --i) { + x[i] = b[i]; + for (int j = i + 1; j < n; ++j) { + x[i] -= A[i][j] * x[j]; + } + x[i] /= A[i][i]; + } + return x; + } + + auto evaluateCurve(const std::vector& coeffs, double x) -> double { + switch (current_model_) { + case ModelType::POLYNOMIAL: + return evaluatePolynomial(coeffs, x); + case ModelType::GAUSSIAN: + return coeffs[0] * std::exp(-std::pow(x - coeffs[1], 2) / + (2 * std::pow(coeffs[2], 2))) + + coeffs[3]; + case ModelType::LORENTZIAN: + return coeffs[0] / + (1 + std::pow((x - coeffs[1]) / coeffs[2], 2)) + + coeffs[3]; + } + return 0; + } + + static auto evaluatePolynomial(const std::vector& coeffs, + double x) -> double { + double result = 0.0; + for (int i = 0; i < coeffs.size(); ++i) { + result += coeffs[i] * std::pow(x, i); + } + return result; + } + + static auto getModelName(ModelType model) -> std::string { + switch (model) { + case ModelType::POLYNOMIAL: + return "Polynomial"; + case ModelType::GAUSSIAN: + return "Gaussian"; + case ModelType::LORENTZIAN: + return "Lorentzian"; + } + return "Unknown"; + } +}; + +// Constructor and Destructor for Pimpl pattern +FocusCurveFitter::FocusCurveFitter() : impl_(new Impl()) {} +FocusCurveFitter::~FocusCurveFitter() { delete impl_; } + +// Public interface forwarding to the implementation +void FocusCurveFitter::addDataPoint(double position, double sharpness) { + impl_->addDataPoint(position, sharpness); +} + +auto FocusCurveFitter::fitCurve() -> std::vector { + return impl_->fitCurve(); +} + +void FocusCurveFitter::autoSelectModel() { impl_->autoSelectModel(); } + +auto FocusCurveFitter::calculateConfidenceIntervals(double confidence_level) + -> std::vector> { + return impl_->calculateConfidenceIntervals(confidence_level); +} + +void 
FocusCurveFitter::visualize(const std::string& filename) { + impl_->visualize(filename); +} + +void FocusCurveFitter::preprocessData() { impl_->preprocessData(); } + +void FocusCurveFitter::realTimeFitAndPredict(double new_position) { + impl_->realTimeFitAndPredict(new_position); +} + +void FocusCurveFitter::parallelFitting() { impl_->parallelFitting(); } diff --git a/src/task/custom/autofocus/curve.hpp b/src/task/custom/autofocus/curve.hpp new file mode 100644 index 00000000..6ddc6ee0 --- /dev/null +++ b/src/task/custom/autofocus/curve.hpp @@ -0,0 +1,35 @@ +#ifndef FOCUS_CURVE_FITTER_H +#define FOCUS_CURVE_FITTER_H + +#include +#include +#include + +enum class ModelType { POLYNOMIAL, GAUSSIAN, LORENTZIAN }; + +struct DataPoint { + double position; + double sharpness; +}; + +class FocusCurveFitter { +public: + FocusCurveFitter(); + ~FocusCurveFitter(); + + void addDataPoint(double position, double sharpness); + std::vector fitCurve(); + void autoSelectModel(); + std::vector> calculateConfidenceIntervals( + double confidence_level = 0.95); + void visualize(const std::string& filename = "focus_curve.png"); + void preprocessData(); + void realTimeFitAndPredict(double new_position); + void parallelFitting(); + +private: + class Impl; // Forward declaration of the implementation class + Impl* impl_; // Pointer to implementation (Pimpl idiom) +}; + +#endif // FOCUS_CURVE_FITTER_H diff --git a/src/task/custom/autofocus/detector.cpp b/src/task/custom/autofocus/detector.cpp new file mode 100644 index 00000000..885cb78c --- /dev/null +++ b/src/task/custom/autofocus/detector.cpp @@ -0,0 +1,44 @@ +#include "detector.hpp" + +#include + +StarDetector::StarDetector(int maxStars) : maxStars(maxStars) {} + +std::vector StarDetector::detectStars( + const cv::Mat& image) { + cv::Mat gray; + if (image.channels() > 1) { + cv::cvtColor(image, gray, cv::COLOR_BGR2GRAY); + } else { + gray = image.clone(); + } + + cv::Mat binary; + cv::adaptiveThreshold(gray, binary, 255, cv::ADAPTIVE_THRESH_GAUSSIAN_C, + cv::THRESH_BINARY, 11, 2); + + std::vector> contours; + cv::findContours(binary, contours, cv::RETR_EXTERNAL, + cv::CHAIN_APPROX_SIMPLE); + + std::vector stars; + for (const auto& contour : contours) { + if (contour.size() >= 5) { + cv::RotatedRect ellipse = cv::fitEllipse(contour); + Star star; + star.center = ellipse.center; + star.hfr = (ellipse.size.width + ellipse.size.height) / 4.0; + stars.push_back(star); + } + } + + std::sort(stars.begin(), stars.end(), + [&gray](const Star& a, const Star& b) { + return gray.at(a.center) > gray.at(b.center); + }); + if (stars.size() > maxStars) { + stars.resize(maxStars); + } + + return stars; +} diff --git a/src/task/custom/autofocus/detector.hpp b/src/task/custom/autofocus/detector.hpp new file mode 100644 index 00000000..e66d6dd9 --- /dev/null +++ b/src/task/custom/autofocus/detector.hpp @@ -0,0 +1,18 @@ +#pragma once + +#include +#include + +class StarDetector { +public: + struct Star { + cv::Point2f center; + double hfr; + }; + + StarDetector(int maxStars = 10); + std::vector detectStars(const cv::Mat& image); + +private: + int maxStars; +}; diff --git a/src/task/custom/autofocus/utils.cpp b/src/task/custom/autofocus/utils.cpp new file mode 100644 index 00000000..254193e2 --- /dev/null +++ b/src/task/custom/autofocus/utils.cpp @@ -0,0 +1,25 @@ +#include "utils.hpp" +#include +#include + +namespace Utils { + +std::vector applyNoiseReduction(const std::vector& data) { + std::vector smoothed = data; + for (size_t i = 1; i < smoothed.size() - 1; ++i) { + smoothed[i] = 
(data[i-1] + data[i] + data[i+1]) / 3.0; + } + return smoothed; +} + +bool isOutlier(double value, const std::vector& data) { + if (data.size() < 2) return false; + + double mean = std::accumulate(data.begin(), data.end(), 0.0) / data.size(); + double sq_sum = std::inner_product(data.begin(), data.end(), data.begin(), 0.0); + double stdev = std::sqrt(sq_sum / data.size() - mean * mean); + + return std::abs(value - mean) > 3 * stdev; +} + +} // namespace Utils diff --git a/src/task/custom/autofocus/utils.hpp b/src/task/custom/autofocus/utils.hpp new file mode 100644 index 00000000..5b341510 --- /dev/null +++ b/src/task/custom/autofocus/utils.hpp @@ -0,0 +1,8 @@ +#pragma once + +#include + +namespace Utils { + std::vector applyNoiseReduction(const std::vector& data); + bool isOutlier(double value, const std::vector& data); +} diff --git a/src/task/custom/camera/take_many_exposure.cpp b/src/task/custom/camera/take_many_exposure.cpp new file mode 100644 index 00000000..69850780 --- /dev/null +++ b/src/task/custom/camera/take_many_exposure.cpp @@ -0,0 +1,179 @@ +#include "take_many_exposure.hpp" +#include "task/utils/macro.hpp" + +#include +#include +#include +#include + +#include "config/configor.hpp" + +#include "atom/function/global_ptr.hpp" +#include "atom/log/loguru.hpp" +#include "atom/type/json.hpp" + +#include "error/exception.hpp" +#include "utils/constant.hpp" + +namespace lithium { +class TakeManyExposure::Impl { +public: + std::string camera_name_; + double exposure_time_; + + int gain; + int max_gain_; + int min_gain_; + + int offset; + int max_offset_; + int min_offset_; + + std::shared_ptr config_manager_; + std::shared_ptr task_scheduler_; +}; + +TakeManyExposure::TakeManyExposure(const json& params) { + GET_OR_CREATE_PTR(impl_->config_manager_, ConfigManager, + Constants::CONFIG_MANAGER); + GET_OR_CREATE_PTR(impl_->task_scheduler_, TaskScheduler, + Constants::TASK_SCHEDULER); + + GET_PARAM_OR_THROW(params, "camera_name", impl_->camera_name_) + GET_PARAM_OR_THROW(params, "exposure_time", impl_->exposure_time_) + GET_PARAM_OR_THROW(params, "gain", impl_->gain) + GET_PARAM_OR_THROW(params, "offset", impl_->offset) +} + +TaskScheduler::Task TakeManyExposure::validateExposure() { + if (impl_->exposure_time_ < 0 || impl_->exposure_time_ > 3600) { + LOG_F(ERROR, "Invalid exposure time: {}", impl_->exposure_time_); + THROW_INVALID_ARGUMENT("Exposure failed due to long exposure time: ", + impl_->exposure_time_); + } + + GET_CONFIG_VALUE(impl_->config_manager_, + std::format("/camera/{}/gain/max", impl_->camera_name_), + int, max_gain_); + GET_CONFIG_VALUE(impl_->config_manager_, + std::format("/camera/{}/gain/min", impl_->camera_name_), + int, min_gain_); + GET_CONFIG_VALUE( + impl_->config_manager_, + std::format("/camera/{}/gain/default", impl_->camera_name_), int, + default_gain_); + if (impl_->gain < min_gain_ || impl_->gain > max_gain_) { + LOG_F(ERROR, "Invalid gain: {}", impl_->gain); + impl_->gain = default_gain_; + //THROW_INVALID_ARGUMENT("Exposure failed due to invalid gain: ", + // impl_->gain); + co_return std::format("Exposure failed due to invalid gain: {}", + impl_->gain); + } + + co_yield std::format("Validated exposure time for camera {}: {}", + impl_->camera_name_, impl_->exposure_time_); + co_return "Validation successful for camera " + impl_->camera_name_; +} + +TaskScheduler::Task TakeManyExposure::takeExposure() { + try { + LOG_F(INFO, "Taking exposure for camera {} with {} seconds.", + camera_name_, exposure_time_); + std::this_thread::sleep_for( + 
std::chrono::duration(exposure_time_)); + + std::string result = "Exposure result for camera " + camera_name_ + + " with " + std::to_string(exposure_time_) + + " seconds."; + co_yield "Exposure completed: " + result; + co_return result; + } catch (const std::exception& e) { + THROW_UNLAWFUL_OPERATION("Exposure failed for camera " + camera_name_ + + ": " + e.what()); + } +} + +TaskScheduler::Task TakeManyExposure::handleExposureError() { + int exposure_time = exposure_time_; + + GET_CONFIG_VALUE(config_manager_, "/camera/retry_attempts", int, + retryAttempts); + GET_CONFIG_VALUE(config_manager_, "/camera/retry_delay", int, retryDelay); + GET_CONFIG_VALUE(config_manager_, "/camera/quality_threshold", double, + qualityThreshold); + + for (int i = 0; i < retryAttempts; ++i) { + try { + co_yield "Attempting exposure for camera " + camera_name_ + + " after " + std::to_string(i + 1) + " retry(ies)."; + + auto exposure_task = takeExposure(); + co_await exposure_task; + + if (auto result = task_scheduler_->getResult(exposure_task)) { + double quality = evaluateExposureQuality(*result); + LOG_F(INFO, "Exposure quality for camera {}: {}", camera_name_, + quality); + + if (quality >= qualityThreshold) { + co_return *result; + } else { + exposure_time = adjustExposureTime(exposure_time, quality); + LOG_F(INFO, "Adjusted exposure time for camera {}: {}", + camera_name_, exposure_time); + } + } else { + LOG_F(ERROR, + "Exposure task completed but no result produced for {}", + camera_name_); + THROW_UNLAWFUL_OPERATION( + "Exposure task completed but no result produced for " + "camera " + + camera_name_); + } + } catch (const atom::error::UnlawfulOperation& e) { + LOG_F(ERROR, "Exposure attempt {} failed for camera {}: {}", i + 1, + camera_name_, e.what()); + } + } + co_return std::format("Exposure failed for camera {} after {} retries.", + camera_name_, retryAttempts); +} + +TaskScheduler::Task TakeManyExposure::run() { + auto validate_task = + std::make_shared(validateExposure()); + task_scheduler_->schedule("validate_exposure_" + camera_name_, + validate_task); + + auto exposure_task = + std::make_shared(handleExposureError()); + exposure_task->dependencies.push_back("validate_exposure_" + camera_name_); + task_scheduler_->schedule("exposure_task_" + camera_name_, exposure_task); + + co_await *exposure_task; + + co_return "Exposure sequence completed for camera " + camera_name_; +} + +double TakeManyExposure::evaluateExposureQuality( + const std::string& exposure_result) { + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_real_distribution<> dis(0.0, 1.0); + return dis(gen); +} + +int TakeManyExposure::adjustExposureTime(double current_time, double quality) { + if (quality < 0.3) { + return current_time + 2; + } else if (quality < 0.7) { + return current_time + 1; + } else if (quality > 0.9) { + return std::max(1.0, current_time - 1); + } + return current_time; +} + +} // namespace lithium diff --git a/src/task/custom/camera/take_many_exposure.hpp b/src/task/custom/camera/take_many_exposure.hpp new file mode 100644 index 00000000..a9f16d5b --- /dev/null +++ b/src/task/custom/camera/take_many_exposure.hpp @@ -0,0 +1,41 @@ +#ifndef LITHIUM_TASK_CUSTOM_CAMERA_TAKE_MANY_EXPOSURE_HPP +#define LITHIUM_TASK_CUSTOM_CAMERA_TAKE_MANY_EXPOSURE_HPP + +#include +#include +#include "atom/type/json_fwd.hpp" +#include "task/custom/cotask.hpp" +#include "task/interface/task.hpp" + +using json = nlohmann::json; + +namespace lithium { + +class TakeManyExposure : public ITask { +public: + 
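+    // Reads camera_name, exposure_time, gain and offset from params and resolves the shared ConfigManager and TaskScheduler.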
TakeManyExposure(const json& params); + + // Function to run the full exposure sequence + auto run() -> TaskScheduler::Task override; + +private: + // Function to validate the exposure settings + auto validateExposure() -> TaskScheduler::Task; + + // Function to execute an exposure + auto takeExposure() -> TaskScheduler::Task; + + // Function to handle errors and retry logic + auto handleExposureError() -> TaskScheduler::Task; + + // Utility functions + auto evaluateExposureQuality(const std::string& exposure_result) -> double; + auto adjustExposureTime(double current_time, double quality) -> int; + + class Impl; + std::unique_ptr impl_; +}; + +} // namespace lithium + +#endif diff --git a/src/task/custom/cotask.cpp b/src/task/custom/cotask.cpp new file mode 100644 index 00000000..10d0325d --- /dev/null +++ b/src/task/custom/cotask.cpp @@ -0,0 +1,143 @@ +#include "cotask.hpp" +#include + +#include "atom/log/loguru.hpp" + +auto TaskScheduler::Task::promise_type::get_return_object() -> Task { + return Task{handle_type::from_promise(*this)}; +} + +auto TaskScheduler::Task::promise_type::initial_suspend() + -> std::suspend_never { + return {}; +} + +auto TaskScheduler::Task::promise_type::final_suspend() noexcept + -> std::suspend_always { + return {}; +} + +void TaskScheduler::Task::promise_type::return_value(std::string value) { + result = std::move(value); +} + +auto TaskScheduler::Task::promise_type::yield_value(std::string value) + -> std::suspend_always { + result = std::move(value); + return {}; +} + +void TaskScheduler::Task::promise_type::unhandled_exception() { + result = std::current_exception(); + if (exceptionHandler) { + try { + std::rethrow_exception(std::get(result)); + } catch (const std::exception& e) { + exceptionHandler(e); + } + } +} + +// Task constructors, destructor, and move operations + +TaskScheduler::Task::Task(handle_type h) : handle(h) {} + +TaskScheduler::Task::~Task() { + if (handle) + handle.destroy(); +} + +TaskScheduler::Task::Task(Task&& other) noexcept : handle(other.handle) { + other.handle = nullptr; +} + +auto TaskScheduler::Task::operator=(Task&& other) noexcept -> Task& { + if (this != &other) { + if (handle) { + handle.destroy(); + } + handle = other.handle; + other.handle = nullptr; + } + return *this; +} + +// Task::await methods + +auto TaskScheduler::Task::await_ready() const noexcept -> bool { + return handle.done(); +} + +void TaskScheduler::Task::await_suspend(std::coroutine_handle<> h) const { + handle.resume(); +} + +void TaskScheduler::Task::await_resume() const {} + +// TaskScheduler methods + +void TaskScheduler::schedule(std::string id, std::shared_ptr task) { + tasks_[id] = std::move(task); + LOG_F(INFO, "Scheduling task: {}", id); +} + +void TaskScheduler::setGlobalExceptionHandler( + std::function handler) { + global_exception_handler_ = std::move(handler); +} + +void TaskScheduler::run() { + while (!tasks_.empty()) { + for (auto it = tasks_.begin(); it != tasks_.end();) { + if (areDependenciesMet(it->second->dependencies)) { + try { + if (it->second->handle) { + it->second->handle.resume(); + if (it->second->handle.done()) { + completed_tasks_.insert(it->first); + it = tasks_.erase(it); + continue; + } + } + } catch (const std::exception& e) { + handleException(e, it->second); + it = tasks_.erase(it); + continue; + } + } + ++it; + } + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } +} + +auto TaskScheduler::getResult(const Task& task) -> std::optional { + if (auto* result = + 
std::get_if(&task.handle.promise().result)) { + return *result; + } + if (auto* exPtr = + std::get_if(&task.handle.promise().result)) { + std::rethrow_exception(*exPtr); + } + return std::nullopt; +} + +auto TaskScheduler::areDependenciesMet( + const std::vector& dependencies) -> bool { + return std::all_of(dependencies.begin(), dependencies.end(), + [this](const std::string& dep) { + return completed_tasks_.contains(dep); + }); +} + +void TaskScheduler::handleException(const std::exception& e, + const std::shared_ptr& task) { + if (task->handle.promise().exceptionHandler) { + task->handle.promise().exceptionHandler(e); + } else if (global_exception_handler_) { + global_exception_handler_(e); + } else { + LOG_F(ERROR, "Unhandled task exception: {}", e.what()); + } +} diff --git a/src/task/custom/cotask.hpp b/src/task/custom/cotask.hpp new file mode 100644 index 00000000..4f2a4cc6 --- /dev/null +++ b/src/task/custom/cotask.hpp @@ -0,0 +1,71 @@ +#ifndef TASK_SCHEDULER_H +#define TASK_SCHEDULER_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +class TaskScheduler { +public: + struct Task { + struct promise_type; + using handle_type = std::coroutine_handle; + + struct promise_type { + std::variant + result; + std::function exceptionHandler; + + auto get_return_object() -> Task; + auto initial_suspend() -> std::suspend_never; + auto final_suspend() noexcept -> std::suspend_always; + void return_value(std::string value); + auto yield_value(std::string value) -> std::suspend_always; + void unhandled_exception(); + }; + + handle_type handle; + std::vector dependencies; + + Task(handle_type h); + ~Task(); + Task(const Task&) = delete; + Task(Task&& other) noexcept; + auto operator=(Task&& other) noexcept -> Task&; + + void set_exception_handler( + std::function handler); + + auto await_ready() const noexcept -> bool; + void await_suspend(std::coroutine_handle<> h) const; + void await_resume() const; + }; + + void schedule(std::string id, std::shared_ptr task); + + void setGlobalExceptionHandler( + std::function handler); + + void run(); + + static auto getResult(const Task& task) -> std::optional; + +private: + std::unordered_map> tasks_; + std::unordered_set completed_tasks_; + std::function global_exception_handler_; + + auto areDependenciesMet(const std::vector& dependencies) + -> bool; + void handleException(const std::exception& e, + const std::shared_ptr& task); +}; + +#endif // TASK_SCHEDULER_H diff --git a/src/task/custom/guider/dither.hpp b/src/task/custom/guider/dither.hpp new file mode 100644 index 00000000..e69de29b diff --git a/src/task/imagepath.cpp b/src/task/imagepath.cpp index 4a7894e8..c9f2739d 100644 --- a/src/task/imagepath.cpp +++ b/src/task/imagepath.cpp @@ -1,16 +1,18 @@ #include "imagepath.hpp" +#include #include #include -#include -#include +#include #include "atom/log/loguru.hpp" #include "atom/type/json.hpp" #include "atom/utils/string.hpp" + namespace fs = std::filesystem; namespace lithium { + auto ImageInfo::toJson() const -> json { return {{"path", path}, {"dateTime", dateTime.value_or("")}, @@ -32,44 +34,160 @@ auto ImageInfo::fromJson(const json& j) -> ImageInfo { info.exposureTime = j.value("exposureTime", ""); info.frameNr = j.value("frameNr", ""); } catch (const std::exception& e) { - LOG_F(ERROR, "Error deserializing ImageInfo from JSON: {}", - e.what()); + LOG_F(ERROR, "Error deserializing ImageInfo from JSON: {}", e.what()); } return info; } -ImagePatternParser::ImagePatternParser(const std::string& 
pattern, - char delimiter) - : delimiter_(delimiter) { - parsePattern(pattern); -} -auto ImagePatternParser::parseFilename(const std::string& filename) const - -> std::optional { - ImageInfo info; - info.path = fs::absolute(fs::path(filename)).string(); +class ImagePatternParser::Impl { +public: + explicit Impl(const std::string& pattern, char delimiter) + : delimiter_(delimiter) { + parsePattern(pattern); + } - // Remove file extension - std::string name = filename.substr(0, filename.find_last_of('.')); + [[nodiscard]] auto parseFilename(const std::string& filename) const + -> std::optional { + ImageInfo info; + info.path = fs::absolute(fs::path(filename)).string(); - // Split into parts - std::vector parts = atom::utils::splitString(name, delimiter_); - if (parts.size() < patterns_.size()) { - LOG_F(ERROR, "Filename does not match the pattern: {}", name); - return std::nullopt; + // Remove file extension + auto name = filename.substr(0, filename.find_last_of('.')); + + // Split into parts + auto parts = atom::utils::splitString(name, delimiter_); + if (parts.size() < patterns_.size()) { + LOG_F(ERROR, "Filename does not match the pattern: {}", name); + return std::nullopt; + } + + // Assign parts dynamically based on the pattern + for (size_t i = 0; i < patterns_.size(); ++i) { + const auto& key = patterns_[i]; + const auto& value = parts[i]; + if (auto it = parsers_.find(key); it != parsers_.end()) { + it->second(info, value); + } else { + LOG_F(ERROR, "No parser for key: {}", key); + } + } + + return info; + } + + void addCustomParser(const std::string& key, FieldParser parser) { + parsers_[key] = std::move(parser); + } + + void setOptionalField(const std::string& key, + const std::string& defaultValue) { + optionalFields_[key] = defaultValue; + } + + [[nodiscard]] auto getPatterns() const -> const std::vector& { + return patterns_; } - // Assign parts dynamically based on the pattern - for (size_t i = 0; i < patterns_.size(); ++i) { - const auto& key = patterns_[i]; - const auto& value = parts[i]; - if (parsers_.find(key) != parsers_.end()) { - parsers_.at(key)(info, value); - } else { - LOG_F(ERROR, "No parser for key: {}", key); + [[nodiscard]] auto getDelimiter() const -> char { return delimiter_; } + +private: + std::vector patterns_; + std::unordered_map parsers_; + std::unordered_map optionalFields_; + char delimiter_; + + void parsePattern(const std::string& pattern) { + std::string temp; + bool inVar = false; + for (char ch : pattern) { + if (ch == '$') { + if (inVar) { + patterns_.push_back(temp); + temp.clear(); + } + inVar = !inVar; + } else if (inVar) { + temp += ch; + } } + + initializeParsers(); } - return info; + void initializeParsers() { + parsers_["DATETIME"] = [](ImageInfo& info, const std::string& value) { + info.dateTime = validateDateTime(value) + ? std::optional(value) + : std::nullopt; + }; + parsers_["IMAGETYPE"] = [](ImageInfo& info, const std::string& value) { + info.imageType = !value.empty() ? std::optional(value) + : std::nullopt; + }; + parsers_["FILTER"] = [](ImageInfo& info, const std::string& value) { + info.filter = !value.empty() ? 
std::optional(value) + : std::nullopt; + }; + parsers_["SENSORTEMP"] = [](ImageInfo& info, const std::string& value) { + info.sensorTemp = formatTemperature(value); + }; + parsers_["EXPOSURETIME"] = [](ImageInfo& info, + const std::string& value) { + if (auto pos = value.find('s'); pos != std::string::npos) { + info.exposureTime = value.substr(0, pos); + } + }; + parsers_["FRAMENR"] = [](ImageInfo& info, const std::string& value) { + info.frameNr = !value.empty() ? std::optional(value) + : std::nullopt; + }; + + // Set default values for optional fields + for (const auto& [key, value] : optionalFields_) { + if (parsers_.find(key) == parsers_.end()) { + parsers_[key] = [value](ImageInfo&, const std::string&) { + // No-op: Assign default value if key is not present + }; + } + } + } + + static auto validateDateTime(const std::string& dateTime) -> bool { + static const std::regex dateTimePattern( + R"(^\d{4}-\d{2}-\d{2}-\d{2}-\d{2}-\d{2}$)"); + return std::regex_match(dateTime, dateTimePattern); + } + + static auto formatTemperature(const std::string& temp) -> std::string { + // Assume temperature is in the format -10.0, format to 1 decimal place + float t; + auto [ptr, ec] = + std::from_chars(temp.data(), temp.data() + temp.size(), t); + if (ec == std::errc()) { + char buffer[16]; + auto [ptr2, ec2] = std::to_chars(buffer, buffer + sizeof(buffer), t, + std::chars_format::fixed, 1); + if (ec2 == std::errc()) { + return std::string(buffer, ptr2 - buffer); + } + } + return temp; // Return as is if parsing fails + } +}; + +ImagePatternParser::ImagePatternParser(const std::string& pattern, + char delimiter) + : pImpl(std::make_unique(pattern, delimiter)) {} + +ImagePatternParser::~ImagePatternParser() = default; + +ImagePatternParser::ImagePatternParser(ImagePatternParser&&) noexcept = default; +ImagePatternParser& ImagePatternParser::operator=( + ImagePatternParser&&) noexcept = default; + +auto ImagePatternParser::parseFilename(const std::string& filename) const + -> std::optional { + return pImpl->parseFilename(filename); } auto ImagePatternParser::serializeToJson(const ImageInfo& info) -> json { @@ -80,89 +198,22 @@ auto ImagePatternParser::deserializeFromJson(const json& j) -> ImageInfo { return ImageInfo::fromJson(j); } -// Allow adding custom parsers for new elements void ImagePatternParser::addCustomParser(const std::string& key, FieldParser parser) { - parsers_[key] = std::move(parser); + pImpl->addCustomParser(key, std::move(parser)); } -// Allow optional fields void ImagePatternParser::setOptionalField(const std::string& key, const std::string& defaultValue) { - optionalFields_[key] = defaultValue; + pImpl->setOptionalField(key, defaultValue); } -void ImagePatternParser::parsePattern(const std::string& pattern) { - std::string temp; - bool inVar = false; - for (char ch : pattern) { - if (ch == '$') { - if (inVar) { - patterns_.push_back(temp); - temp.clear(); - } - inVar = !inVar; - } else if (inVar) { - temp += ch; - } - } - - // Initialize parsers - parsers_["DATETIME"] = [](ImageInfo& info, const std::string& value) { - info.dateTime = validateDateTime(value) - ? std::optional(value) - : std::nullopt; - }; - parsers_["IMAGETYPE"] = [](ImageInfo& info, const std::string& value) { - info.imageType = - value.empty() ? std::nullopt : std::optional(value); - }; - parsers_["FILTER"] = [](ImageInfo& info, const std::string& value) { - info.filter = - value.empty() ? 
std::nullopt : std::optional(value); - }; - parsers_["SENSORTEMP"] = [](ImageInfo& info, const std::string& value) { - info.sensorTemp = formatTemperature(value); - }; - parsers_["EXPOSURETIME"] = [](ImageInfo& info, const std::string& value) { - size_t pos = value.find('s'); - if (pos != std::string::npos) { - info.exposureTime = value.substr(0, pos); - } - }; - parsers_["FRAMENR"] = [](ImageInfo& info, const std::string& value) { - info.frameNr = - value.empty() ? std::nullopt : std::optional(value); - }; - - // Set default values for optional fields - for (const auto& [key, value] : optionalFields_) { - if (parsers_.find(key) == parsers_.end()) { - parsers_[key] = [value]([[maybe_unused]] ImageInfo& info, - const std::string&) { - // No-op: Assign default value if key is not present - }; - } - } +auto ImagePatternParser::getPatterns() const -> std::vector { + return pImpl->getPatterns(); } -auto ImagePatternParser::validateDateTime(const std::string& dateTime) -> bool { - // Simple regex validation for DateTime: YYYY-MM-DD-HH-MM-SS - std::regex dateTimePattern(R"(^\d{4}-\d{2}-\d{2}-\d{2}-\d{2}-\d{2}$)"); - return std::regex_match(dateTime, dateTimePattern); +auto ImagePatternParser::getDelimiter() const -> char { + return pImpl->getDelimiter(); } -auto ImagePatternParser::formatTemperature(const std::string& temp) - -> std::string { - // Assume temperature is in the format -10.0, format to 1 decimal place - std::ostringstream oss; - try { - float t = std::stof(temp); - oss.precision(1); - oss << std::fixed << t; - } catch (const std::exception&) { - return temp; // Return as is if parsing fails - } - return oss.str(); -} } // namespace lithium diff --git a/src/task/imagepath.hpp b/src/task/imagepath.hpp index 62d51b26..58d20b5e 100644 --- a/src/task/imagepath.hpp +++ b/src/task/imagepath.hpp @@ -2,9 +2,9 @@ #define LITHIUM_TASK_IMAGEPATH_HPP #include +#include #include #include -#include #include #include "macro.hpp" @@ -13,6 +13,7 @@ using json = nlohmann::json; namespace lithium { + struct ImageInfo { std::string path; std::optional dateTime; @@ -23,8 +24,9 @@ struct ImageInfo { std::optional frameNr; [[nodiscard]] auto toJson() const -> json; - static auto fromJson(const json& j) -> ImageInfo; + + bool operator==(const ImageInfo&) const = default; } ATOM_ALIGNAS(128); class ImagePatternParser { @@ -33,33 +35,43 @@ class ImagePatternParser { explicit ImagePatternParser(const std::string& pattern, char delimiter = '_'); + ~ImagePatternParser(); - auto parseFilename(const std::string& filename) const + // Disable copy operations + ImagePatternParser(const ImagePatternParser&) = delete; + ImagePatternParser& operator=(const ImagePatternParser&) = delete; + + // Enable move operations + ImagePatternParser(ImagePatternParser&&) noexcept; + ImagePatternParser& operator=(ImagePatternParser&&) noexcept; + + [[nodiscard]] auto parseFilename(const std::string& filename) const -> std::optional; static auto serializeToJson(const ImageInfo& info) -> json; - static auto deserializeFromJson(const json& j) -> ImageInfo; - // Allow adding custom parsers for new elements void addCustomParser(const std::string& key, FieldParser parser); - - // Allow optional fields void setOptionalField(const std::string& key, const std::string& defaultValue); -private: - std::vector patterns_; - std::unordered_map parsers_; - std::unordered_map optionalFields_; - char delimiter_; + // New methods + [[nodiscard]] auto getPatterns() const -> std::vector; + [[nodiscard]] auto getDelimiter() const -> char; - void 
parsePattern(const std::string& pattern); +private: + class Impl; + std::unique_ptr pImpl; +}; - static auto validateDateTime(const std::string& dateTime) -> bool; +// Template function to parse multiple filenames +template +[[nodiscard]] auto parseMultipleFilenames(const ImagePatternParser& parser, + const Filenames&... filenames) + -> std::vector> { + return {parser.parseFilename(filenames)...}; +} - static auto formatTemperature(const std::string& temp) -> std::string; -}; } // namespace lithium #endif diff --git a/src/task/interface/task.hpp b/src/task/interface/task.hpp new file mode 100644 index 00000000..928e23fd --- /dev/null +++ b/src/task/interface/task.hpp @@ -0,0 +1,11 @@ +#ifndef LITHIUM_TASK_INTERFACE_TASK_HPP +#define LITHIUM_TASK_INTERFACE_TASK_HPP + +#include "task/custom/cotask.hpp" + +class ITask { +public: + virtual auto run() -> TaskScheduler::Task = 0; +}; + +#endif diff --git a/src/task/manager.cpp b/src/task/manager.cpp index c004301f..ad9a4370 100644 --- a/src/task/manager.cpp +++ b/src/task/manager.cpp @@ -31,6 +31,7 @@ #include #include #include +#include #include #include #include @@ -105,6 +106,9 @@ class TaskInterpreterImpl { std::shared_ptr taskGenerator_; std::shared_ptr> threadPool_; + + std::unordered_map> coroutines; + std::vector> transactionRollbackActions; }; TaskInterpreter::TaskInterpreter() @@ -138,9 +142,8 @@ auto TaskInterpreter::createShared() -> std::shared_ptr { } void TaskInterpreter::loadScript(const std::string& name, const json& script) { -#if ENABLE_DEBUG LOG_F(INFO, "Loading script: {} with {}", name, script.dump()); -#endif + std::unique_lock lock(impl_->mtx_); impl_->scripts_[name] = script.contains("steps") ? script["steps"] : script; lock.unlock(); @@ -157,7 +160,7 @@ void TaskInterpreter::loadScript(const std::string& name, const json& script) { header.contains("author") ? 
header["author"].get() : "unknown"); - // 存储头部信息 + impl_->scriptHeaders_[name] = header; if (header.contains("auto_execute") && header["auto_execute"].is_boolean() && @@ -178,12 +181,13 @@ void TaskInterpreter::unloadScript(const std::string& name) { impl_->scripts_.erase(name); } -auto TaskInterpreter::hasScript(const std::string& name) const -> bool { +auto TaskInterpreter::hasScript(const std::string& name) const noexcept + -> bool { std::shared_lock lock(impl_->mtx_); return impl_->scripts_.contains(name); } -auto TaskInterpreter::getScript(const std::string& name) const +auto TaskInterpreter::getScript(const std::string& name) const noexcept -> std::optional { std::shared_lock lock(impl_->mtx_); if (impl_->scripts_.contains(name)) { @@ -267,11 +271,13 @@ auto TaskInterpreter::getVariable(const std::string& name) const -> json { void TaskInterpreter::parseLabels(const json& script) { std::unique_lock lock(impl_->mtx_); LOG_F(INFO, "Parsing labels..."); - for (size_t i = 0; i < script.size(); ++i) { - if (script[i].contains("label")) { - impl_->labels_[script[i]["label"]] = i; - } - } + std::for_each(script.begin(), script.end(), + [this, index = 0](const auto& item) mutable { + if (item.contains("label")) { + impl_->labels_[item["label"]] = index; + } + ++index; + }); } void TaskInterpreter::execute(const std::string& scriptName) { @@ -295,10 +301,19 @@ void TaskInterpreter::execute(const std::string& scriptName) { size_t i = 0; while (i < script.size() && !impl_->stopRequested_) { - if (!executeStep(script[i], i, script)) { + const auto& step = script[i]; + if (step.contains("type") && step["type"] == "coroutine") { + if (!step.contains("name") || !step["name"].is_string()) { + throw std::runtime_error( + "Coroutine step must have a 'name' field"); + } + std::string coroutineName = step["name"]; + auto handle = executeCoroutine(step).handle(); + impl_->coroutines[coroutineName] = handle; + } else if (!executeStep(step, i, script)) { break; } - i++; + ++i; } } catch (...) 
{ exPtr = std::current_exception(); @@ -949,8 +964,8 @@ void TaskInterpreter::executeImport(const json& step) { std::condition_variable cv; bool callbackCalled = false; - std::string fullPath = constants::TASK_FOLDER + scriptName + - constants::PATH_SEPARATOR + ".json"; + std::string fullPath = Constants::TASK_FOLDER + scriptName + + Constants::PATH_SEPARATOR + ".json"; LOG_F(INFO, "Importing script from file: {}", fullPath); // Asynchronously read the script file @@ -1182,13 +1197,14 @@ void TaskInterpreter::executeContinue(const json& /*step*/, size_t& idx) { idx = std::numeric_limits::max() - 1; // Skip to the next iteration } -void TaskInterpreter::executeSteps(const json& steps, size_t& idx, - const json& script) { - for (const auto& step : steps) { - if (!executeStep(step, idx, script)) { - break; - } - } +void TaskInterpreter::executeSteps(const nlohmann::json& steps, size_t& idx, + const nlohmann::json& script) { + auto stepView = + steps | std::views::take_while([this, &idx, &script](const auto& step) { + return !impl_->stopRequested_ && executeStep(step, idx, script); + }); + + std::ranges::for_each(stepView, [](const auto&) {}); } void TaskInterpreter::executeMessage(const json& step) { @@ -1405,25 +1421,126 @@ void TaskInterpreter::executeRetry(const json& step, size_t& idx, } } +void TaskInterpreter::executeTransaction(const json& step, size_t& idx, + const json& script) { + impl_->transactionRollbackActions.clear(); + try { + executeSteps(step["steps"], idx, script); + executeCommit(step); + } catch (...) { + executeRollback(step); + throw; + } +} + +void TaskInterpreter::executeRollback(const json& step) { + for (auto& transactionRollbackAction : + std::ranges::reverse_view(impl_->transactionRollbackActions)) { + transactionRollbackAction(); + } + impl_->transactionRollbackActions.clear(); +} + +void TaskInterpreter::executeCommit(const json& step) { + impl_->transactionRollbackActions.clear(); +} + +void TaskInterpreter::executeAtomicOperation(const json& step) { + std::atomic_flag lock = ATOMIC_FLAG_INIT; + while (lock.test_and_set(std::memory_order_acquire)) { + std::this_thread::yield(); + } + try { + size_t idx = 0; + executeSteps(step["steps"], idx, step); + } catch (...) 
{ + lock.clear(std::memory_order_release); + throw; + } + lock.clear(std::memory_order_release); +} + +auto TaskInterpreter::executeCoroutine(const json& step) -> TaskCoroutine { + if (!step.contains("steps") || !step["steps"].is_array()) { + THROW_MISSING_ARGUMENT("Coroutine step must contain a 'steps' array"); + } + + for (const auto& subStep : step["steps"]) { + if (subStep.contains("type")) { + std::string stepType = subStep["type"]; + + if (stepType == "async") { + // Execute async step + auto future = + std::async(std::launch::async, [this, &subStep]() { + size_t idx = 0; + executeStep(subStep, idx, subStep); + }); + + // Yield control back to the caller + co_await std::suspend_always{}; + + // Wait for the async operation to complete + future.wait(); + } else if (stepType == "delay") { + if (!subStep.contains("duration") || + !subStep["duration"].is_number()) { + THROW_MISSING_ARGUMENT( + "Delay step must contain a 'duration' number"); + } + + int duration = subStep["duration"].get(); + + // Start the delay + auto start = std::chrono::steady_clock::now(); + + // Yield control back to the caller + co_await std::suspend_always{}; + + // Resume and check if the delay has passed + while (std::chrono::steady_clock::now() - start < + std::chrono::milliseconds(duration)) { + co_await std::suspend_always{}; + } + } else { + // Execute regular step + size_t idx = 0; + executeStep(subStep, idx, subStep); + } + } + } + + co_return; +} + +// Helper method to resume a coroutine +void TaskInterpreter::resumeCoroutine(const std::string& coroutineName) { + auto it = impl_->coroutines.find(coroutineName); + if (it != impl_->coroutines.end() && !it->second.done()) { + it->second.resume(); + } +} + auto TaskInterpreter::evaluate(const json& value) -> json { if (value.is_string()) { std::string valStr = value.get(); - // 检查是否是变量 - if (impl_->variables_.contains(valStr)) { + if (impl_->variables_.contains(std::string(valStr))) { std::shared_lock lock(impl_->mtx_); - return impl_->variables_.at(valStr).second; + return impl_->variables_.at(std::string(valStr)).second; } - // 检查是否是表达式 - if (valStr.find_first_of("+-*/%^!&|<=>") != std::string::npos) { - return evaluateExpression(valStr); // 解析并评估表达式 + if (std::ranges::any_of(std::array{'+', '-', '*', '/', '%', '^', '!', + '&', '|', '<', '=', '>'}, + [&valStr](char op) { + return valStr.find(op) != + std::string_view::npos; + })) { + return evaluateExpression(valStr); } - // 检查是否是 $ 开头的运算符表达式 - if (valStr.starts_with("$")) { - return evaluateExpression( - valStr.substr(1)); // 去掉 $ 前缀后评估表达式 + if (valStr.starts_with('$')) { + return evaluateExpression(valStr.substr(1)); } } @@ -1571,10 +1688,33 @@ auto TaskInterpreter::evaluate(const json& value) -> json { } auto TaskInterpreter::evaluateExpression(const std::string& expr) -> json { - std::istringstream iss(expr); + std::vector tokens; std::stack operators; std::stack operands; - std::string token; + + // Tokenize the expression + size_t start = 0; + for (size_t i = 0; i < expr.size(); ++i) { + if (std::isspace(expr[i])) { + if (start != i) { + tokens.push_back(expr.substr(start, i - start)); + } + start = i + 1; + } else if (expr[i] == '(' || expr[i] == ')' || expr[i] == '+' || + expr[i] == '-' || expr[i] == '*' || expr[i] == '/' || + expr[i] == '%' || expr[i] == '^' || expr[i] == '<' || + expr[i] == '>' || expr[i] == '=' || expr[i] == '!' 
|| + expr[i] == '&' || expr[i] == '|') { + if (start != i) { + tokens.push_back(expr.substr(start, i - start)); + } + tokens.push_back(expr.substr(i, 1)); + start = i + 1; + } + } + if (start < expr.size()) { + tokens.push_back(expr.substr(start)); + } auto applyOperator = [](char op, double a, double b) -> double { switch (op) { @@ -1585,17 +1725,13 @@ auto TaskInterpreter::evaluateExpression(const std::string& expr) -> json { case '*': return a * b; case '/': - if (b != 0) { - return a / b; - } else { - THROW_INVALID_ARGUMENT("Division by zero."); - } + if (b == 0) + throw std::runtime_error("Division by zero"); + return a / b; case '%': - if (b != 0) { - return std::fmod(a, b); - } else { - THROW_INVALID_ARGUMENT("Modulo by zero."); - } + if (b == 0) + throw std::runtime_error("Modulo by zero"); + return std::fmod(a, b); case '^': return std::pow(a, b); case '<': @@ -1613,30 +1749,21 @@ auto TaskInterpreter::evaluateExpression(const std::string& expr) -> json { return static_cast(static_cast(a) || static_cast(b)); default: - THROW_RUNTIME_ERROR("Unknown operator."); + throw std::runtime_error("Unknown operator"); } }; - while (iss >> token) { - // 处理数字和变量 - if ((std::isdigit(token[0]) != 0) || token[0] == '.') { - operands.push(std::stod(token)); - } else if (impl_->variables_.contains(token)) { - operands.push(impl_->variables_.at(token).second.get()); - } else if (token == "+" || token == "-" || token == "*" || - token == "/" || token == "%" || token == "^" || - token == "<" || token == ">" || token == "==" || - token == "!=" || token == "&&" || token == "||") { - // 处理操作符 + for (const auto& token : tokens) { + if (token.size() == 1 && + std::string("+-*/%^<>=!&|").find(token[0]) != std::string::npos) { while (!operators.empty() && precedence(operators.top()) >= precedence(token[0])) { double b = operands.top(); operands.pop(); double a = operands.top(); operands.pop(); - char op = operators.top(); + operands.push(applyOperator(operators.top(), a, b)); operators.pop(); - operands.push(applyOperator(op, a, b)); } operators.push(token[0]); } else if (token == "(") { @@ -1647,38 +1774,57 @@ auto TaskInterpreter::evaluateExpression(const std::string& expr) -> json { operands.pop(); double a = operands.top(); operands.pop(); - char op = operators.top(); + operands.push(applyOperator(operators.top(), a, b)); operators.pop(); - operands.push(applyOperator(op, a, b)); } - if (operators.empty() || operators.top() != '(') { - THROW_INVALID_ARGUMENT("Mismatched parentheses."); + if (operators.empty()) { + throw std::runtime_error("Mismatched parentheses"); } - operators.pop(); // 移除左括号 + operators.pop(); // Remove '(' } else { - THROW_INVALID_ARGUMENT("Invalid token in expression: " + token); + // Parse number or variable + if (token[0] == '$') { + // Variable + std::string varName(token.substr(1)); + std::shared_lock lock(impl_->mtx_); + if (impl_->variables_.contains(varName)) { + operands.push( + impl_->variables_.at(varName).second.get()); + } else { + throw std::runtime_error("Undefined variable: " + varName); + } + } else { + // Number + double value; + auto [ptr, ec] = std::from_chars( + token.data(), token.data() + token.size(), value); + if (ec == std::errc()) { + operands.push(value); + } else { + throw std::runtime_error("Invalid token: " + + std::string(token)); + } + } } } - // 处理剩余操作符 while (!operators.empty()) { double b = operands.top(); operands.pop(); double a = operands.top(); operands.pop(); - char op = operators.top(); + operands.push(applyOperator(operators.top(), a, b)); 
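// Review note (illustrative, not part of the patch): the loop around this comment
// reduces the operator stack whenever an incoming operator does not outrank the one on
// top, assuming precedence() below ranks '*' and '/' above '+' and '-'. For example, the
// tokenizer above turns "3 + 4 * 2" into {"3", "+", "4", "*", "2"}; '*' outranks '+', so
// nothing is reduced at this point and the final drain loop further down applies 4 * 2
// and then 3 + 8, returning 11. Note that two-character operators such as "==" or "&&"
// are emitted by this tokenizer as two separate one-character tokens.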
operators.pop(); - operands.push(applyOperator(op, a, b)); } if (operands.size() != 1) { - THROW_INVALID_ARGUMENT("Invalid expression: " + expr); + throw std::runtime_error("Invalid expression"); } return operands.top(); } -auto TaskInterpreter::precedence(char op) -> int { +auto TaskInterpreter::precedence(char op) noexcept -> int { switch (op) { case '+': case '-': diff --git a/src/task/manager.hpp b/src/task/manager.hpp index cf34c65f..ac508ab3 100644 --- a/src/task/manager.hpp +++ b/src/task/manager.hpp @@ -21,6 +21,7 @@ #ifndef LITHIUM_TASK_INTERPRETER_HPP #define LITHIUM_TASK_INTERPRETER_HPP +#include #include #include #include @@ -36,6 +37,56 @@ enum class VariableType { NUMBER, STRING, BOOLEAN, JSON, UNKNOWN }; auto determineType(const json& value) -> VariableType; +class TaskCoroutine { +public: + struct promise_type; + using handle_type = std::coroutine_handle; + + TaskCoroutine(handle_type h) : coro(h) {} + TaskCoroutine(const TaskCoroutine&) = delete; + TaskCoroutine& operator=(const TaskCoroutine&) = delete; + TaskCoroutine(TaskCoroutine&& other) noexcept : coro(other.coro) { + other.coro = nullptr; + } + TaskCoroutine& operator=(TaskCoroutine&& other) noexcept { + if (this != &other) { + if (coro) + coro.destroy(); + coro = other.coro; + other.coro = nullptr; + } + return *this; + } + ~TaskCoroutine() { + if (coro) + coro.destroy(); + } + + bool resume() { + if (!coro || coro.done()) + return false; + coro.resume(); + return !coro.done(); + } + + bool done() const { return !coro || coro.done(); } + + handle_type handle() const { return coro; } + + struct promise_type { + TaskCoroutine get_return_object() { + return TaskCoroutine(handle_type::from_promise(*this)); + } + std::suspend_never initial_suspend() { return {}; } + std::suspend_never final_suspend() noexcept { return {}; } + void return_void() {} + void unhandled_exception() { std::terminate(); } + }; + +private: + handle_type coro; +}; + class TaskInterpreterImpl; class TaskInterpreter { @@ -43,13 +94,18 @@ class TaskInterpreter { TaskInterpreter(); ~TaskInterpreter(); + TaskInterpreter(const TaskInterpreter&) = delete; + auto operator=(const TaskInterpreter&) -> TaskInterpreter& = delete; + TaskInterpreter(TaskInterpreter&&) noexcept = default; + auto operator=(TaskInterpreter&&) noexcept -> TaskInterpreter& = default; + static auto createShared() -> std::shared_ptr; void loadScript(const std::string& name, const json& script); void unloadScript(const std::string& name); - [[nodiscard]] auto hasScript(const std::string& name) const -> bool; - [[nodiscard]] auto getScript(const std::string& name) const + [[nodiscard]] auto hasScript(const std::string& name) const noexcept -> bool; + [[nodiscard]] auto getScript(const std::string& name) const noexcept -> std::optional; void registerFunction(const std::string& name, @@ -116,15 +172,36 @@ class TaskInterpreter { void executeBroadcastEvent(const json& step); void executeListenEvent(const json& step, size_t& idx); + auto executeCoroutine(const json& step) -> TaskCoroutine; + void resumeCoroutine(const std::string& coroutineName); + void executeTransaction(const json& step, size_t& idx, const json& script); + void executeRollback(const json& step); + void executeCommit(const json& step); + void executeAtomicOperation(const json& step); + auto evaluate(const json& value) -> json; auto evaluateExpression(const std::string& expr) -> json; - auto precedence(char op) -> int; + auto precedence(char op) noexcept -> int; void throwCustomError(const std::string& name); void 
handleException(const std::string& scriptName, const std::exception& e); std::unique_ptr impl_; + + template + auto getAtomicPtr(std::atomic>& atomic_ptr) const { + return std::atomic_load(&atomic_ptr); + } + + template + void updateAtomicPtr(std::atomic>& atomic_ptr, + const std::function& update_func) { + auto currentPtr = getAtomicPtr(atomic_ptr); + auto newPtr = std::make_shared(*currentPtr); + update_func(*newPtr); + std::atomic_store(&atomic_ptr, newPtr); + } }; } // namespace lithium diff --git a/src/task/sequencer.cpp b/src/task/sequencer.cpp deleted file mode 100644 index 58cefbea..00000000 --- a/src/task/sequencer.cpp +++ /dev/null @@ -1,201 +0,0 @@ -/** - * @file sequencer.cpp - * @brief Definition of classes for managing and executing task sequences. - * - * This file defines the `Target` and `ExposureSequence` classes for managing - * and executing sequences of tasks. The `Target` class represents a unit that - * can hold and execute tasks with a configurable delay and priority. The - * `ExposureSequence` class manages a collection of `Target` objects and - * coordinates their execution, allowing for task sequences to be executed in - * parallel or serially with options for pausing, resuming, and stopping. - * - * Key features: - * - `Target` class: Manages individual tasks, delay after execution, and - * priority. Supports enabling/disabling and task execution. - * - `ExposureSequence` class: Manages multiple `Target` instances, supports - * adding, removing, and modifying targets, and coordinates their execution - * in a controlled manner. - * - * @date 2023-04-03 - * @author Max Qian - * @copyright Copyright (C) 2023-2024 Max Qian - */ - -#include "sequencer.hpp" -#include "task.hpp" - -#include "atom/log/loguru.hpp" -#include "type/expected.hpp" - -Target::Target(std::string name, std::chrono::seconds delayAfterTarget, - int priority) - : name_(std::move(name)), - delayAfterTarget_(delayAfterTarget), - priority_(priority), - enabled_(true) {} - -void Target::addTask(std::shared_ptr task) { - tasks_.emplace_back(std::move(task)); -} - -void Target::setDelayAfterTarget(std::chrono::seconds delay) { - delayAfterTarget_ = delay; -} - -void Target::setPriority(int p) { priority_ = p; } - -int Target::getPriority() const { return priority_; } - -void Target::enable() { enabled_ = true; } - -void Target::disable() { enabled_ = false; } - -auto Target::isEnabled() const -> bool { return enabled_; } - -auto Target::execute(std::stop_token stopToken, std::atomic_flag& pauseFlag, - std::condition_variable_any& cv, std::shared_mutex& mtx) - -> atom::type::expected { - if (!enabled_) { - LOG_F(WARNING, "Target {} is disabled.", name_); - return atom::type::make_unexpected("Target is disabled"); - } - - LOG_F(INFO, "Starting target: {}", name_); - for (const auto& task : tasks_) { - if (stopToken.stop_requested()) { - return {}; - } - - std::shared_lock lock(mtx); - cv.wait(lock, [&] { - return !pauseFlag.test() || stopToken.stop_requested(); - }); - - if (stopToken.stop_requested()) { - return {}; - } - - try { - task->start(); - task->run(); - if (task->isTimeout()) { - LOG_F(ERROR, "Task {} timed out.", task->getName()); - task->fail(std::runtime_error("Timeout")); - } - } catch (const std::exception& ex) { - LOG_F(ERROR, "Task {} failed with exception: {}", task->getName(), - ex.what()); - return atom::type::make_unexpected("Task execution failed"); - } - } - LOG_F(INFO, "Completed target: {}", name_); - std::this_thread::sleep_for(delayAfterTarget_); - return {}; -} - -std::string 
Target::getName() const { return name_; } - -ExposureSequence::ExposureSequence() : sequenceThread_(nullptr) {} - -ExposureSequence::~ExposureSequence() { - stop(); - if (sequenceThread_ && sequenceThread_->joinable()) { - sequenceThread_->join(); - } -} - -void ExposureSequence::addTarget(Target target) { - std::unique_lock lock(mutex_); - targets_.emplace_back(std::make_shared(std::move(target))); - LOG_F(INFO, "Added target: {}", target.getName()); -} - -void ExposureSequence::removeTarget(size_t index) { - std::unique_lock lock(mutex_); - if (index < targets_.size()) { - targets_.erase(targets_.begin() + index); - LOG_F(INFO, "Removed target at index {}", index); - } else { - LOG_F(ERROR, "Target index out of range for removal."); - } -} - -void ExposureSequence::modifyTarget( - size_t index, std::optional newDelay, - std::optional newPriority) { - std::unique_lock lock(mutex_); - if (index < targets_.size()) { - if (newDelay) { - targets_[index]->setDelayAfterTarget(*newDelay); - } - if (newPriority) { - targets_[index]->setPriority(*newPriority); - } - LOG_F(INFO, "Modified target at index {}", index); - } else { - LOG_F(ERROR, "Target index out of range for modification."); - } -} - -void ExposureSequence::enableTarget(size_t index) { - std::unique_lock lock(mutex_); - if (index < targets_.size()) { - targets_[index]->enable(); - LOG_F(INFO, "Enabled target at index {}", index); - } else { - LOG_F(ERROR, "Target index out of range for enabling."); - } -} - -void ExposureSequence::disableTarget(size_t index) { - std::unique_lock lock(mutex_); - if (index < targets_.size()) { - targets_[index]->disable(); - LOG_F(INFO, "Disabled target at index {}", index); - } else { - LOG_F(ERROR, "Target index out of range for disabling."); - } -} - -void ExposureSequence::executeAll() { - stopFlag_.clear(); - pauseFlag_.clear(); - if (sequenceThread_ && sequenceThread_->joinable()) { - sequenceThread_->join(); - } - sequenceThread_ = std::make_unique( - &ExposureSequence::executeSequence, this); -} - -void ExposureSequence::stop() { - if (sequenceThread_) { - sequenceThread_->request_stop(); - cv_.notify_all(); - LOG_F(INFO, "Stopping all tasks."); - } -} - -void ExposureSequence::pause() { - pauseFlag_.test_and_set(); - LOG_F(INFO, "Pausing all tasks."); -} - -void ExposureSequence::resume() { - pauseFlag_.clear(); - cv_.notify_all(); - LOG_F(INFO, "Resuming all tasks."); -} - -void ExposureSequence::executeSequence(std::stop_token stopToken) { - for (const auto& target : targets_) { - if (stopToken.stop_requested()) { - return; - } - if (target->isEnabled()) { - target->execute(stopToken, pauseFlag_, cv_, mutex_) - .value_or([&](std::string err) { - LOG_F(ERROR, "Failed to execute target: {}", err); - }); - } - } -} diff --git a/src/task/sequencer.hpp b/src/task/sequencer.hpp deleted file mode 100644 index 6758b48f..00000000 --- a/src/task/sequencer.hpp +++ /dev/null @@ -1,116 +0,0 @@ -/** - * @file sequencer.hpp - * @brief Definition of classes for managing and executing task sequences. - * - * This file defines the `Target` and `ExposureSequence` classes for managing - * and executing sequences of tasks. The `Target` class represents a unit that - * can hold and execute tasks with a configurable delay and priority. The - * `ExposureSequence` class manages a collection of `Target` objects and - * coordinates their execution, allowing for task sequences to be executed in - * parallel or serially with options for pausing, resuming, and stopping. 
- * - * Key features: - * - `Target` class: Manages individual tasks, delay after execution, and - * priority. Supports enabling/disabling and task execution. - * - `ExposureSequence` class: Manages multiple `Target` instances, supports - * adding, removing, and modifying targets, and coordinates their execution - * in a controlled manner. - * - * @date 2023-04-03 - * @author Max Qian - * @copyright Copyright (C) 2023-2024 Max Qian - */ - -#ifndef LITHIUM_TASK_SEQUENCER_HPP -#define LITHIUM_TASK_SEQUENCER_HPP - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "atom/type/expected.hpp" - -class Task; - -class Target { -public: - explicit Target( - std::string name, - std::chrono::seconds delayAfterTarget = std::chrono::seconds{0}, - int priority = 0); - - void addTask(std::shared_ptr task); - - void setDelayAfterTarget(std::chrono::seconds delay); - - void setPriority(int p); - - [[nodiscard]] int getPriority() const; - - void enable(); - - void disable(); - - [[nodiscard]] bool isEnabled() const; - - atom::type::expected execute(std::stop_token stopToken, - std::atomic_flag& pauseFlag, - std::condition_variable_any& cv, - std::shared_mutex& mtx); - - [[nodiscard]] std::string getName() const; - -private: - std::string name_; - std::vector> tasks_; - std::chrono::seconds delayAfterTarget_; - int priority_; - bool enabled_; -}; - -class ExposureSequence { -public: - ExposureSequence(); - ~ExposureSequence(); - - void addTarget(Target target); - - void removeTarget(size_t index); - - void modifyTarget( - size_t index, - std::optional newDelay = std::nullopt, - std::optional newPriority = std::nullopt); - - void enableTarget(size_t index); - - void disableTarget(size_t index); - - void executeAll(); - - void stop(); - - void pause(); - - void resume(); - -private: - mutable std::shared_mutex mutex_; - std::condition_variable_any cv_; - std::vector> targets_; - std::atomic_flag stopFlag_ = ATOMIC_FLAG_INIT; - std::atomic_flag pauseFlag_ = ATOMIC_FLAG_INIT; - std::unique_ptr sequenceThread_; - - void executeSequence(std::stop_token stopToken); -}; - -#endif // LITHIUM_TASK_SEQUENCER_HPP diff --git a/src/task/simple/sequencer.cpp b/src/task/simple/sequencer.cpp new file mode 100644 index 00000000..6ec9a8a8 --- /dev/null +++ b/src/task/simple/sequencer.cpp @@ -0,0 +1,115 @@ +#include "sequencer.hpp" +#include +#include +#include + +namespace lithium::sequencer { + +ExposureSequence::ExposureSequence() = default; + +ExposureSequence::~ExposureSequence() { stop(); } + +void ExposureSequence::addTarget(std::unique_ptr target) { + targets_.push_back(std::move(target)); +} + +void ExposureSequence::removeTarget(const std::string& name) { + targets_.erase(std::remove_if(targets_.begin(), targets_.end(), + [&name](const auto& target) { + return target->getName() == name; + }), + targets_.end()); +} + +void ExposureSequence::modifyTarget(const std::string& name, + const TargetModifier& modifier) { + auto it = std::find_if( + targets_.begin(), targets_.end(), + [&name](const auto& target) { return target->getName() == name; }); + if (it != targets_.end()) { + modifier(**it); + } +} + +void ExposureSequence::executeAll() { + if (state_.exchange(SequenceState::Running) != SequenceState::Idle) { + return; + } + sequenceThread_ = std::jthread([this] { executeSequence(); }); +} + +void ExposureSequence::stop() { + state_.store(SequenceState::Stopping); + if (sequenceThread_.joinable()) { + sequenceThread_.request_stop(); + 
sequenceThread_.join(); + } + state_.store(SequenceState::Idle); +} + +void ExposureSequence::pause() { + SequenceState expected = SequenceState::Running; + state_.compare_exchange_strong(expected, SequenceState::Paused); +} + +void ExposureSequence::resume() { + SequenceState expected = SequenceState::Paused; + state_.compare_exchange_strong(expected, SequenceState::Running); +} + +void ExposureSequence::saveSequence(const std::string& filename) const { + nlohmann::json j; + for (const auto& target : targets_) { + j["targets"].push_back({ + {"name", target->getName()}, {"enabled", target->isEnabled()}, + // Add more target properties as needed + }); + } + std::ofstream file(filename); + file << j.dump(4); +} + +void ExposureSequence::loadSequence(const std::string& filename) { + std::ifstream file(filename); + nlohmann::json j; + file >> j; + + targets_.clear(); + for (const auto& targetJson : j["targets"]) { + auto target = std::make_unique(targetJson["name"]); + target->setEnabled(targetJson["enabled"]); + // Load more target properties as needed + targets_.push_back(std::move(target)); + } +} + +std::vector ExposureSequence::getTargetNames() const { + std::vector names; + std::transform(targets_.begin(), targets_.end(), std::back_inserter(names), + [](const auto& target) { return target->getName(); }); + return names; +} + +TargetStatus ExposureSequence::getTargetStatus(const std::string& name) const { + auto it = std::find_if( + targets_.begin(), targets_.end(), + [&name](const auto& target) { return target->getName() == name; }); + return it != targets_.end() ? (*it)->getStatus() : TargetStatus::Skipped; +} + +void ExposureSequence::executeSequence() { + for (auto& target : targets_) { + if (state_.load() == SequenceState::Stopping) { + break; + } + while (state_.load() == SequenceState::Paused) { + std::this_thread::yield(); + } + if (target->isEnabled()) { + target->execute(); + } + } + state_.store(SequenceState::Idle); +} + +} // namespace lithium::sequencer diff --git a/src/task/simple/sequencer.hpp b/src/task/simple/sequencer.hpp new file mode 100644 index 00000000..dc33e0cc --- /dev/null +++ b/src/task/simple/sequencer.hpp @@ -0,0 +1,49 @@ +#ifndef LITHIUM_TASK_SEQUENCER_HPP +#define LITHIUM_TASK_SEQUENCER_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "./task.hpp" +#include "target.hpp" + +namespace lithium::sequencer { +enum class SequenceState { Idle, Running, Paused, Stopped, Stopping }; +class ExposureSequence { +public: + ExposureSequence(); + ~ExposureSequence(); + + void addTarget(std::unique_ptr target); + void removeTarget(const std::string& name); + void modifyTarget(const std::string& name, const TargetModifier& modifier); + + void executeAll(); + void stop(); + void pause(); + void resume(); + + // New methods + void saveSequence(const std::string& filename) const; + void loadSequence(const std::string& filename); + std::vector getTargetNames() const; + TargetStatus getTargetStatus(const std::string& name) const; + +private: + std::vector> targets_; + std::atomic state_{SequenceState::Idle}; + std::jthread sequenceThread_; + + void executeSequence(); +}; + +} // namespace lithium::sequencer + +#endif // LITHIUM_TASK_SEQUENCER_HPP diff --git a/src/task/simple/target.cpp b/src/task/simple/target.cpp new file mode 100644 index 00000000..f4f43ab3 --- /dev/null +++ b/src/task/simple/target.cpp @@ -0,0 +1,44 @@ +#include "target.hpp" +#include + +namespace lithium::sequencer { + +Target::Target(std::string 
name, std::chrono::seconds cooldown) + : name_(std::move(name)), cooldown_(cooldown) {} + +void Target::addTask(std::unique_ptr task) { + tasks_.push_back(std::move(task)); +} + +void Target::setCooldown(std::chrono::seconds cooldown) { + cooldown_ = cooldown; +} + +void Target::setEnabled(bool enabled) { enabled_ = enabled; } + +const std::string& Target::getName() const { return name_; } + +TargetStatus Target::getStatus() const { return status_; } + +bool Target::isEnabled() const { return enabled_; } + +void Target::execute() { + if (!enabled_) { + status_ = TargetStatus::Skipped; + return; + } + + status_ = TargetStatus::InProgress; + for (auto& task : tasks_) { + task->execute(); + if (task->getStatus() == TaskStatus::Failed) { + status_ = TargetStatus::Failed; + return; + } + } + status_ = TargetStatus::Completed; + + std::this_thread::sleep_for(cooldown_); +} + +} // namespace lithium::sequencer diff --git a/src/task/simple/target.hpp b/src/task/simple/target.hpp new file mode 100644 index 00000000..cb83713e --- /dev/null +++ b/src/task/simple/target.hpp @@ -0,0 +1,43 @@ +#ifndef LITHIUM_TARGET_HPP +#define LITHIUM_TARGET_HPP + +#include +#include +#include +#include +#include + +#include "task.hpp" + +namespace lithium::sequencer { + +enum class TargetStatus { Pending, InProgress, Completed, Failed, Skipped }; + +class Target { +public: + Target(std::string name, + std::chrono::seconds cooldown = std::chrono::seconds{0}); + + void addTask(std::unique_ptr task); + void setCooldown(std::chrono::seconds cooldown); + void setEnabled(bool enabled); + + [[nodiscard]] const std::string& getName() const; + [[nodiscard]] TargetStatus getStatus() const; + [[nodiscard]] bool isEnabled() const; + + void execute(); + +private: + std::string name_; + std::vector> tasks_; + std::chrono::seconds cooldown_; + bool enabled_{true}; + TargetStatus status_{TargetStatus::Pending}; +}; + +using TargetModifier = std::function; + +} // namespace lithium::sequencer + +#endif // LITHIUM_TARGET_HPP diff --git a/src/task/simple/task.cpp b/src/task/simple/task.cpp new file mode 100644 index 00000000..09e3eed6 --- /dev/null +++ b/src/task/simple/task.cpp @@ -0,0 +1,37 @@ +#include "task.hpp" +#include + +namespace lithium::sequencer { + +Task::Task(std::string name, std::function action) + : name_(std::move(name)), action_(std::move(action)) {} + +void Task::execute() { + status_ = TaskStatus::InProgress; + error_.reset(); + + try { + if (timeout_ > std::chrono::seconds{0}) { + auto future = std::async(std::launch::async, action_); + if (future.wait_for(timeout_) == std::future_status::timeout) { + throw std::runtime_error("Task timed out"); + } + } else { + action_(); + } + status_ = TaskStatus::Completed; + } catch (const std::exception& e) { + status_ = TaskStatus::Failed; + error_ = e.what(); + } +} + +void Task::setTimeout(std::chrono::seconds timeout) { timeout_ = timeout; } + +const std::string& Task::getName() const { return name_; } + +TaskStatus Task::getStatus() const { return status_; } + +std::optional Task::getError() const { return error_; } + +} // namespace lithium::sequencer diff --git a/src/task/simple/task.hpp b/src/task/simple/task.hpp new file mode 100644 index 00000000..c6009794 --- /dev/null +++ b/src/task/simple/task.hpp @@ -0,0 +1,34 @@ +#ifndef LITHIUM_TASK_HPP +#define LITHIUM_TASK_HPP + +#include +#include +#include +#include + +namespace lithium::sequencer { + +enum class TaskStatus { Pending, InProgress, Completed, Failed }; + +class Task { +public: + Task(std::string name, 
std::function action); + + void execute(); + void setTimeout(std::chrono::seconds timeout); + + [[nodiscard]] const std::string& getName() const; + [[nodiscard]] TaskStatus getStatus() const; + [[nodiscard]] std::optional getError() const; + +private: + std::string name_; + std::function action_; + std::chrono::seconds timeout_{0}; + TaskStatus status_{TaskStatus::Pending}; + std::optional error_; +}; + +} // namespace lithium::sequencer + +#endif // LITHIUM_TASK_HPP diff --git a/src/task/utils/macro.hpp b/src/task/utils/macro.hpp new file mode 100644 index 00000000..1162ba7e --- /dev/null +++ b/src/task/utils/macro.hpp @@ -0,0 +1,6 @@ +#define GET_PARAM_OR_THROW(params, key, var) \ + if (!params.contains(key)) { \ + THROW_MISSING_ARGUMENT(std::string(key) + " is missing"); \ + } \ + var = params[key]; \ + LOG_F(INFO, "{}: {}", key, var); diff --git a/src/tools/croods.hpp b/src/tools/croods.hpp index 98bff482..85caa28c 100644 --- a/src/tools/croods.hpp +++ b/src/tools/croods.hpp @@ -1,7 +1,10 @@ #ifndef LITHIUM_SEARCH_CROODS_HPP #define LITHIUM_SEARCH_CROODS_HPP +#include +#include #include +#include #include #include "macro.hpp" @@ -64,6 +67,416 @@ auto convertToSphericalCoordinates(const CartesianCoordinates& cartesianPoint) auto calculateFOV(int focalLength, double cameraSizeWidth, double cameraSizeHeight) -> MinMaxFOV; + +constexpr double EARTHRADIUSEQUATORIAL = 6378137.0; +constexpr double EARTHRADIUSPOLAR = 6356752.0; +constexpr double ASTRONOMICALUNIT = 1.495978707e11; +constexpr double LIGHTSPEED = 299792458.0; +constexpr double AIRY = 1.21966; +constexpr double SOLARMASS = 1.98847e30; +constexpr double SOLARRADIUS = 6.957e8; +constexpr double PARSEC = 3.0857e16; + +template +constexpr T LUMEN(T wavelength) { + return 1.464128843e-3 / (wavelength * wavelength); +} + +template +constexpr T REDSHIFT(T observed, T rest) { + return (observed - rest) / rest; +} + +template +constexpr T DOPPLER(T redshift, T speed) { + return redshift * speed; +} + +template +T rangeHA(T r) { + while (r < -12.0) + r += 24.0; + while (r >= 12.0) + r -= 24.0; + return r; +} + +template +T range24(T r) { + while (r < 0.0) + r += 24.0; + while (r > 24.0) + r -= 24.0; + return r; +} + +template +T range360(T r) { + while (r < 0.0) + r += 360.0; + while (r > 360.0) + r -= 360.0; + return r; +} + +template +T rangeDec(T decdegrees) { + if ((decdegrees >= 270.0) && (decdegrees <= 360.0)) + return (decdegrees - 360.0); + if ((decdegrees >= 180.0) && (decdegrees < 270.0)) + return (180.0 - decdegrees); + if ((decdegrees >= 90.0) && (decdegrees < 180.0)) + return (180.0 - decdegrees); + return decdegrees; +} + +template +T get_local_hour_angle(T sideral_time, T ra) { + T HA = sideral_time - ra; + return rangeHA(HA); +} + +template +std::pair get_alt_az_coordinates(T Ha, T Dec, T Lat) { + using namespace std::numbers; + Ha *= pi_v / 180.0; + Dec *= pi_v / 180.0; + Lat *= pi_v / 180.0; + T alt = std::asin(std::sin(Dec) * std::sin(Lat) + + std::cos(Dec) * std::cos(Lat) * std::cos(Ha)); + T az = std::acos((std::sin(Dec) - std::sin(alt) * std::sin(Lat)) / + (std::cos(alt) * std::cos(Lat))); + alt *= 180.0 / pi_v; + az *= 180.0 / pi_v; + if (std::sin(Ha) >= 0.0) + az = 360 - az; + return {alt, az}; +} + +template +T estimate_geocentric_elevation(T Lat, T El) { + using namespace std::numbers; + Lat *= pi_v / 180.0; + Lat = std::sin(Lat); + El += Lat * (EARTHRADIUSPOLAR - EARTHRADIUSEQUATORIAL); + return El; +} + +template +T estimate_field_rotation_rate(T Alt, T Az, T Lat) { + using namespace std::numbers; + Alt *= pi_v / 
180.0; + Az *= pi_v / 180.0; + Lat *= pi_v / 180.0; + T ret = std::cos(Lat) * std::cos(Az) / std::cos(Alt); + ret *= 180.0 / pi_v; + return ret; +} + +template +T estimate_field_rotation(T HA, T rate) { + HA *= rate; + while (HA >= 360.0) + HA -= 360.0; + while (HA < 0) + HA += 360.0; + return HA; +} + +template +constexpr T as2rad(T as) { + using namespace std::numbers; + return as * pi_v / (60.0 * 60.0 * 12.0); +} + +template +constexpr T rad2as(T rad) { + using namespace std::numbers; + return rad * (60.0 * 60.0 * 12.0) / pi_v; +} + +template +T estimate_distance(T parsecs, T parallax_radius) { + return parallax_radius / std::sin(as2rad(parsecs)); +} + +template +constexpr T m2au(T m) { + return m / ASTRONOMICALUNIT; +} + +template +T calc_delta_magnitude(T mag_ratio, std::span spectrum, + std::span ref_spectrum) { + T delta_mag = 0; + for (size_t l = 0; l < spectrum.size(); l++) { + delta_mag += spectrum[l] * mag_ratio * ref_spectrum[l] / spectrum[l]; + } + delta_mag /= spectrum.size(); + return delta_mag; +} + +template +T calc_star_mass(T delta_mag, T ref_size) { + return delta_mag * ref_size; +} + +template +T estimate_orbit_radius(T obs_lambda, T ref_lambda, T period) { + using namespace std::numbers; + return pi_v * 2 * DOPPLER(REDSHIFT(obs_lambda, ref_lambda), LIGHTSPEED) / + period; +} + +template +T estimate_secondary_mass(T star_mass, T star_drift, T orbit_radius) { + return orbit_radius * std::pow(star_drift * orbit_radius, 3) * 3 * + star_mass; +} + +template +T estimate_secondary_size(T star_size, T dropoff_ratio) { + return std::pow(dropoff_ratio * std::pow(star_size, 2), 0.5); +} + +template +T calc_photon_flux(T rel_magnitude, T filter_bandwidth, T wavelength, + T steradian) { + return std::pow(10, rel_magnitude * -0.4) * + (LUMEN(wavelength) * steradian * filter_bandwidth); +} + +template +T calc_rel_magnitude(T photon_flux, T filter_bandwidth, T wavelength, + T steradian) { + return std::pow(10, 1.0 / (photon_flux / (LUMEN(wavelength) * steradian * + filter_bandwidth))) / + -0.4; +} + +template +T estimate_absolute_magnitude(T delta_dist, T delta_mag) { + return std::sqrt(delta_dist) * delta_mag; +} + +template +std::array baseline_2d_projection(T alt, T az, + const std::array& baseline, + T wavelength) { + using namespace std::numbers; + az *= pi_v / 180.0; + alt *= pi_v / 180.0; + std::array uvresult; + uvresult[0] = (baseline[0] * std::sin(az) + baseline[1] * std::cos(az)); + uvresult[1] = (baseline[1] * std::sin(alt) * std::sin(az) - + baseline[0] * std::sin(alt) * std::cos(az) + + baseline[2] * std::cos(alt)); + uvresult[0] *= AIRY / wavelength; + uvresult[1] *= AIRY / wavelength; + return uvresult; +} + +template +T baseline_delay(T alt, T az, const std::array& baseline) { + using namespace std::numbers; + az *= pi_v / 180.0; + alt *= pi_v / 180.0; + return std::cos(az) * baseline[1] * std::cos(alt) - + baseline[0] * std::sin(az) * std::cos(alt) + + std::sin(alt) * baseline[2]; +} + +// A struct representing celestial (equatorial) coordinates +template +struct CelestialCoords { + T ra; // right ascension (hours) + T dec; // declination (degrees) +}; + +// A struct representing geographic coordinates +template +struct GeographicCoords { + T latitude; + T longitude; +}; + +// A date-time struct +struct DateTime { + int year; + int month; + int day; + int hour; + int minute; + double second; +}; + +// A function to compute the Julian date +template +T calculate_julian_date(const DateTime& dt) { + int a = (14 - dt.month) / 12; + int y = dt.year + 4800 - a; + int m = dt.month + 12 * a - 3; + + T jd = dt.day + (153 * m + 2) / 5 + 365 * y + y / 4 - y / 100 + y / 400 - + 32045 + (dt.hour - 12) / 24.0 + dt.minute / 1440.0 + + dt.second / 86400.0; + + return jd; +}
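// Review note (illustrative, not part of the patch): a quick sanity check of
// calculate_julian_date above, assuming the stripped template parameter is the usual
// <typename T>. For DateTime{2000, 1, 1, 12, 0, 0.0} the helper gives a = 1, y = 6799,
// m = 10, the integer terms sum to 2451545 and the fractional terms are zero, so
// calculate_julian_date<double>(dt) == 2451545.0, the J2000.0 epoch. The
// (153 * m + 2) / 5 and calendar terms rely on integer division, so they must stay int.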
+ +// A function to compute the local sidereal time +template +T calculate_sidereal_time(const DateTime& dt, T longitude) { + T jd = calculate_julian_date(dt); + T t = (jd - 2451545.0) / 36525.0; + T theta = 280.46061837 + 360.98564736629 * (jd - 2451545.0) + + 0.000387933 * t * t - t * t * t / 38710000.0; + + theta = range360(theta); + theta += longitude; + + return theta / 15.0; // convert to hours +} + +// A function to compute atmospheric refraction +template +T calculate_refraction(T altitude, T temperature = 10.0, T pressure = 1010.0) { + if (altitude < -0.5) + return 0.0; // object below the horizon; refraction is not applied + + T R; + if (altitude > 15.0) { + R = 0.00452 * pressure / + ((273 + temperature) * + std::tan(altitude * std::numbers::pi / 180.0)); + } else { + T a = altitude; + T b = altitude + 7.31 / (altitude + 4.4); + R = 0.1594 + 0.0196 * a + 0.00002 * a * a; + R *= pressure * (1 - 0.00012 * (temperature - 10)) / 1010.0; + R /= 60.0; + } + + return R; +} + +// A function to apply parallax +template +CelestialCoords apply_parallax(const CelestialCoords& coords, + const GeographicCoords& observer, + T distance, const DateTime& dt) { + T lst = calculate_sidereal_time(dt, observer.longitude); + T ha = lst - coords.ra; + + T sinLat = std::sin(observer.latitude * std::numbers::pi / 180.0); + T cosLat = std::cos(observer.latitude * std::numbers::pi / 180.0); + T sinDec = std::sin(coords.dec * std::numbers::pi / 180.0); + T cosDec = std::cos(coords.dec * std::numbers::pi / 180.0); + T sinHA = std::sin(ha * std::numbers::pi / 12.0); + T cosHA = std::cos(ha * std::numbers::pi / 12.0); + + T rho = EARTHRADIUSEQUATORIAL / (PARSEC * distance); + + T A = cosLat * sinHA; + T B = sinLat * cosDec - cosLat * sinDec * cosHA; + T C = sinLat * sinDec + cosLat * cosDec * cosHA; + + T newRA = coords.ra - std::atan2(A, C - rho) * 12.0 / std::numbers::pi; + T newDec = std::atan2((B * (C - rho) + A * A * sinDec / cosDec) / + ((C - rho) * (C - rho) + A * A), + cosDec) * + 180.0 / std::numbers::pi; + + return {range24(newRA), rangeDec(newDec)}; +} + +// A function to convert equatorial to ecliptic coordinates +template +std::pair equatorial_to_ecliptic(const CelestialCoords& coords, + T obliquity) { + T sinDec = std::sin(coords.dec * std::numbers::pi / 180.0); + T cosDec = std::cos(coords.dec * std::numbers::pi / 180.0); + T sinRA = std::sin(coords.ra * std::numbers::pi / 12.0); + T cosRA = std::cos(coords.ra * std::numbers::pi / 12.0); + T sinObl = std::sin(obliquity * std::numbers::pi / 180.0); + T cosObl = std::cos(obliquity * std::numbers::pi / 180.0); + + T latitude = std::asin(sinDec * cosObl - cosDec * sinObl * sinRA) * 180.0 / + std::numbers::pi; + T longitude = + std::atan2(sinRA * cosDec * cosObl + sinDec * sinObl, cosDec * cosRA) * + 180.0 / std::numbers::pi; + + return {range360(longitude), latitude}; +}
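// Review note (illustrative, not part of the patch): equatorial_to_ecliptic above takes
// ra in hours (hence the pi / 12 factor) and dec plus obliquity in degrees, and returns
// {ecliptic longitude, ecliptic latitude} in degrees. A minimal spot check: for the
// vernal equinox direction, ra = 0 and dec = 0, every sine term vanishes, so the
// function returns {0, 0} for any obliquity, as expected.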
+ +// A function to compute precession +template +T calculate_precession(const CelestialCoords& coords, const DateTime& from, + const DateTime& to) { + auto jd1 = calculate_julian_date(from); + auto jd2 = calculate_julian_date(to); + + T T1 = (jd1 - 2451545.0) / 36525.0; + T t = (jd2 - jd1) / 36525.0; + + T zeta = (2306.2181 + 1.39656 * T1 - 0.000139 * T1 * T1) * t + + (0.30188 - 0.000344 * T1) * t * t + 0.017998 * t * t * t; + T z = (2306.2181 + 1.39656 * T1 - 0.000139 * T1 * T1) * t + + (1.09468 + 0.000066 * T1) * t * t + 0.018203 * t * t * t; + T theta = (2004.3109 - 0.85330 * T1 - 0.000217 * T1 * T1) * t - + (0.42665 + 0.000217 * T1) * t * t - 0.041833 * t * t * t; + + zeta /= 3600.0; + z /= 3600.0; + theta /= 3600.0; + + T A = std::cos(coords.dec * std::numbers::pi / 180.0) * + std::sin(coords.ra * std::numbers::pi / 12.0 + + zeta * std::numbers::pi / 180.0); + T B = std::cos(theta * std::numbers::pi / 180.0) * + std::cos(coords.dec * std::numbers::pi / 180.0) * + std::cos(coords.ra * std::numbers::pi / 12.0 + + zeta * std::numbers::pi / 180.0) - + std::sin(theta * std::numbers::pi / 180.0) * + std::sin(coords.dec * std::numbers::pi / 180.0); + T C = std::sin(theta * std::numbers::pi / 180.0) * + std::cos(coords.dec * std::numbers::pi / 180.0) * + std::cos(coords.ra * std::numbers::pi / 12.0 + + zeta * std::numbers::pi / 180.0) + + std::cos(theta * std::numbers::pi / 180.0) * + std::sin(coords.dec * std::numbers::pi / 180.0); + + T newRA = std::atan2(A, B) * 12.0 / std::numbers::pi + z / 15.0; + T newDec = std::asin(C) * 180.0 / std::numbers::pi; + + return std::sqrt(std::pow(newRA - coords.ra, 2) + + std::pow(newDec - coords.dec, 2)); +} + +// A function to format right ascension +template +std::string format_ra(T ra) { + int hours = static_cast(ra); + int minutes = static_cast((ra - hours) * 60); + double seconds = ((ra - hours) * 60 - minutes) * 60; + + return std::format("{:02d}h {:02d}m {:.2f}s", hours, minutes, seconds); +} + +// A function to format declination +template +std::string format_dec(T dec) { + char sign = (dec >= 0) ? '+' : '-'; + dec = std::abs(dec); + int degrees = static_cast(dec); + int minutes = static_cast((dec - degrees) * 60); + double seconds = ((dec - degrees) * 60 - minutes) * 60; + + return std::format("{}{:02d}° {:02d}' {:.2f}\"", sign, degrees, minutes, + seconds); +} + } // namespace lithium::tools #endif diff --git a/src/tools/libastro.cpp b/src/tools/libastro.cpp new file mode 100644 index 00000000..72bcad29 --- /dev/null +++ b/src/tools/libastro.cpp @@ -0,0 +1,180 @@ +#include "libastro.hpp" +#include +#include + +namespace lithium { + +namespace { + +// Utility functions for internal use +double getObliquity(double jd) { + double t = (jd - JD2000) / 36525.0; + return 23.439291 - 0.0130042 * t - 1.64e-7 * t * t + 5.04e-7 * t * t * t; +} + +} // anonymous namespace + +std::tuple getNutation(double jd) { + double t = (jd - JD2000) / 36525.0; + double omega = + 125.04452 - 1934.136261 * t + 0.0020708 * t * t + t * t * t / 450000; + double L = 280.4665 + 36000.7698 * t; + double Ls = 218.3165 + 481267.8813 * t; + + double nutation_lon = -17.2 * std::sin(degToRad(omega)) - + 1.32 * std::sin(2 * degToRad(L)) - + 0.23 * std::sin(2 * degToRad(Ls)) + + 0.21 * std::sin(2 * degToRad(omega)); + double nutation_obl = + 9.2 * std::cos(degToRad(omega)) + 0.57 * std::cos(2 * degToRad(L)) + + 0.1 * std::cos(2 * degToRad(Ls)) - 0.09 * std::cos(2 * degToRad(omega)); + + return {nutation_lon / 3600.0, nutation_obl / 3600.0}; +} + +EquatorialCoordinates applyNutation(const EquatorialCoordinates& position, + double jd, bool reverse) { + auto [nutation_lon, nutation_obl] = getNutation(jd); + double obliquity = degToRad(getObliquity(jd)); + + double ra = degToRad(position.rightAscension * 15); + double dec = degToRad(position.declination); + + double sign = reverse ? 
-1 : 1; + + double delta_ra = (std::cos(obliquity) + + std::sin(obliquity) * std::sin(ra) * std::tan(dec)) * + nutation_lon - + (std::cos(ra) * std::tan(dec)) * nutation_obl; + double delta_dec = (std::sin(obliquity) * std::cos(ra)) * nutation_lon + + std::sin(ra) * nutation_obl; + + return {radToDeg(ra + sign * degToRad(delta_ra)) / 15.0, + radToDeg(dec + sign * degToRad(delta_dec))}; +} + +EquatorialCoordinates applyAberration(const EquatorialCoordinates& position, + double jd) { + double t = (jd - JD2000) / 36525.0; + double e = 0.016708634 - 0.000042037 * t - 0.0000001267 * t * t; + double pi = 102.93735 + 1.71946 * t + 0.00046 * t * t; + double lon = 280.46646 + 36000.77983 * t + 0.0003032 * t * t; + + double ra = degToRad(position.rightAscension * 15); + double dec = degToRad(position.declination); + + double k = 20.49552 / 3600.0; // Constant of aberration + + double delta_ra = + -k * + (std::cos(ra) * std::cos(degToRad(lon)) * std::cos(degToRad(pi)) + + std::sin(ra) * std::sin(degToRad(lon))) / + std::cos(dec); + double delta_dec = + -k * (std::sin(degToRad(pi)) * + (std::sin(dec) * std::cos(degToRad(lon)) - + std::cos(dec) * std::sin(ra) * std::sin(degToRad(lon)))); + + return {radToDeg(ra + degToRad(delta_ra)) / 15.0, + radToDeg(dec + degToRad(delta_dec))}; +} + +EquatorialCoordinates applyPrecession(const EquatorialCoordinates& position, + double fromJD, double toJD) { + double t = (fromJD - JD2000) / 36525.0; + double T = (toJD - fromJD) / 36525.0; + + double zeta = (2306.2181 + 1.39656 * t - 0.000139 * t * t) * T + + (0.30188 - 0.000344 * t) * T * T + 0.017998 * T * T * T; + double z = (2306.2181 + 1.39656 * t - 0.000139 * t * t) * T + + (1.09468 + 0.000066 * t) * T * T + 0.018203 * T * T * T; + double theta = (2004.3109 - 0.85330 * t - 0.000217 * t * t) * T - + (0.42665 + 0.000217 * t) * T * T - 0.041833 * T * T * T; + + zeta = degToRad(zeta / 3600.0); + z = degToRad(z / 3600.0); + theta = degToRad(theta / 3600.0); + + double ra = degToRad(position.rightAscension * 15); + double dec = degToRad(position.declination); + + double A = std::cos(dec) * std::sin(ra + zeta); + double B = std::cos(theta) * std::cos(dec) * std::cos(ra + zeta) - + std::sin(theta) * std::sin(dec); + double C = std::sin(theta) * std::cos(dec) * std::cos(ra + zeta) + + std::cos(theta) * std::sin(dec); + + double ra_new = std::atan2(A, B) + z; + double dec_new = std::asin(C); + + return {radToDeg(ra_new) / 15.0, radToDeg(dec_new)}; +} + +EquatorialCoordinates observedToJ2000(const EquatorialCoordinates& observed, + double jd) { + auto temp = applyAberration(observed, jd); + temp = applyNutation(temp, jd, true); + return applyPrecession(temp, jd, JD2000); +} + +EquatorialCoordinates j2000ToObserved(const EquatorialCoordinates& j2000, + double jd) { + auto temp = applyPrecession(j2000, JD2000, jd); + temp = applyNutation(temp, jd); + return applyAberration(temp, jd); +} + +HorizontalCoordinates equatorialToHorizontal( + const EquatorialCoordinates& object, const GeographicCoordinates& observer, + double jd) { + double lst = range360(280.46061837 + 360.98564736629 * (jd - 2451545.0) + + observer.longitude); + double ha = range360(lst - object.rightAscension * 15); + + double sin_alt = std::sin(degToRad(object.declination)) * + std::sin(degToRad(observer.latitude)) + + std::cos(degToRad(object.declination)) * + std::cos(degToRad(observer.latitude)) * + std::cos(degToRad(ha)); + double alt = radToDeg(std::asin(sin_alt)); + + double cos_az = + (std::sin(degToRad(object.declination)) - + std::sin(degToRad(alt)) * 
std::sin(degToRad(observer.latitude))) / + (std::cos(degToRad(alt)) * std::cos(degToRad(observer.latitude))); + double az = radToDeg(std::acos(cos_az)); + + if (std::sin(degToRad(ha)) > 0) { + az = 360 - az; + } + + return {range360(az + 180), alt}; +} + +EquatorialCoordinates horizontalToEquatorial( + const HorizontalCoordinates& object, const GeographicCoordinates& observer, + double jd) { + double alt = degToRad(object.altitude); + double az = degToRad(range360(object.azimuth + 180)); + double lat = degToRad(observer.latitude); + + double sin_dec = std::sin(alt) * std::sin(lat) + + std::cos(alt) * std::cos(lat) * std::cos(az); + double dec = radToDeg(std::asin(sin_dec)); + + double cos_ha = (std::sin(alt) - std::sin(lat) * std::sin(degToRad(dec))) / + (std::cos(lat) * std::cos(degToRad(dec))); + double ha = radToDeg(std::acos(cos_ha)); + + if (std::sin(az) > 0) { + ha = 360 - ha; + } + + double lst = range360(280.46061837 + 360.98564736629 * (jd - 2451545.0) + + observer.longitude); + double ra = range360(lst - ha) / 15.0; + + return {ra, dec}; +} + +} // namespace lithium diff --git a/src/tools/libastro.hpp b/src/tools/libastro.hpp new file mode 100644 index 00000000..b03cab6b --- /dev/null +++ b/src/tools/libastro.hpp @@ -0,0 +1,60 @@ +#pragma once + +#include +#include +#include + +#include "macro.hpp" + +namespace lithium { + +constexpr double JD2000 = 2451545.0; + +struct EquatorialCoordinates { + double rightAscension; // in hours + double declination; // in degrees +} ATOM_ALIGNAS(16); + +struct HorizontalCoordinates { + double azimuth; // in degrees + double altitude; // in degrees +} ATOM_ALIGNAS(16); + +struct GeographicCoordinates { + double longitude; // in degrees + double latitude; // in degrees + double elevation; // in meters +} ATOM_ALIGNAS(32); + +// Convert degrees to radians +constexpr double degToRad(double deg) { return deg * std::numbers::pi / 180.0; } + +// Convert radians to degrees +constexpr double radToDeg(double rad) { return rad * 180.0 / std::numbers::pi; } + +// Range 0 to 360 +constexpr double range360(double angle) { + return std::fmod(angle, 360.0) + (angle < 0 ? 
360.0 : 0.0); +} + +EquatorialCoordinates observedToJ2000(const EquatorialCoordinates& observed, + double jd); +EquatorialCoordinates j2000ToObserved(const EquatorialCoordinates& j2000, + double jd); +HorizontalCoordinates equatorialToHorizontal( + const EquatorialCoordinates& object, const GeographicCoordinates& observer, + double jd); +EquatorialCoordinates horizontalToEquatorial( + const HorizontalCoordinates& object, const GeographicCoordinates& observer, + double jd); + +// Additional utility functions +std::tuple getNutation(double jd); +EquatorialCoordinates applyNutation(const EquatorialCoordinates& position, + double jd, bool reverse = false); +EquatorialCoordinates applyAberration(const EquatorialCoordinates& position, + double jd); +EquatorialCoordinates applyPrecession(const EquatorialCoordinates& position, + double fromJD, double toJD); + +} // namespace lithium diff --git a/src/utils/constant.cpp b/src/utils/constant.cpp index aca21441..847a6166 100644 --- a/src/utils/constant.cpp +++ b/src/utils/constant.cpp @@ -1,9 +1,9 @@ #include "constant.hpp" #ifdef _WIN32 -std::vector constants::COMMON_COMPILERS = {"cl.exe", "g++.exe", +std::vector Constants::COMMON_COMPILERS = {"cl.exe", "g++.exe", "clang++.exe"}; -std::vector constants::COMPILER_PATHS = { +std::vector Constants::COMPILER_PATHS = { "C:\\Program Files (x86)\\Microsoft Visual " "Studio\\2019\\Community\\VC\\Tools\\MSVC\\14.29." "30133\\bin\\Hostx64\\x64", @@ -13,11 +13,11 @@ std::vector constants::COMPILER_PATHS = { "C:\\msys64\\mingw64\\bin", "C:\\MinGW\\bin", "C:\\Program Files\\LLVM\\bin"}; #elif __APPLE__ -std::vector constants::COMMON_COMPILERS = {"clang++", "g++"}; -std::vector constants::COMPILER_PATHS = { +std::vector Constants::COMMON_COMPILERS = {"clang++", "g++"}; +std::vector Constants::COMPILER_PATHS = { "/usr/bin", "/usr/local/bin", "/opt/local/bin"}; #elif __linux__ -std::vector constants::COMMON_COMPILERS = {"g++", "clang++"}; -std::vector constants::COMPILER_PATHS = {"/usr/bin", +std::vector Constants::COMMON_COMPILERS = {"g++", "clang++"}; +std::vector Constants::COMPILER_PATHS = {"/usr/bin", "/usr/local/bin"}; #endif diff --git a/src/utils/constant.hpp b/src/utils/constant.hpp index 3bdc30ce..770735e1 100644 --- a/src/utils/constant.hpp +++ b/src/utils/constant.hpp @@ -18,86 +18,73 @@ Description: Constants for Lithium #include #include -class constants { +#include "atom/algorithm/hash.hpp" + +#define DEFINE_CONSTANT(name, value) static constexpr const char* name = value; + +#define DEFINE_LITHIUM_CONSTANT(name) \ + static constexpr const char* name = "lithium." 
#name; \ + static constexpr unsigned int name##_hash = hash(name); + +class Constants { public: #ifdef _WIN32 #if defined(__MINGW32__) || defined(__MINGW64__) - static constexpr const char* PATH_SEPARATOR = "/"; + DEFINE_CONSTANT(PATH_SEPARATOR, "/"); #else - static constexpr const char* PATH_SEPARATOR = "\\"; + DEFINE_CONSTANT(PATH_SEPARATOR, "\\"); #endif - static constexpr const char* LIB_EXTENSION = ".dll"; - static constexpr const char* EXECUTABLE_EXTENSION = ".exe"; + DEFINE_CONSTANT(LIB_EXTENSION, ".dll") + DEFINE_CONSTANT(EXECUTABLE_EXTENSION, ".exe") #elif defined(__APPLE__) - static constexpr const char* PATH_SEPARATOR = "/"; - static constexpr const char* LIB_EXTENSION = ".dylib"; - static constexpr const char* EXECUTABLE_EXTENSION = ""; + DEFINE_CONSTANT(PATH_SEPARATOR, "/") + DEFINE_CONSTANT(LIB_EXTENSION, ".dylib") + DEFINE_CONSTANT(EXECUTABLE_EXTENSION, "") #else - static constexpr const char* PATH_SEPARATOR = "/"; - static constexpr const char* LIB_EXTENSION = ".so"; - static constexpr const char* EXECUTABLE_EXTENSION = ""; + DEFINE_CONSTANT(PATH_SEPARATOR, "/") + DEFINE_CONSTANT(LIB_EXTENSION, ".so") + DEFINE_CONSTANT(EXECUTABLE_EXTENSION, "") #endif // Package info - static constexpr const char* PACKAGE_NAME = "package.json"; - static constexpr const char* PACKAGE_NAME_SHORT = "lithium"; - static constexpr const char* PACKAGE_VERSION = "0.1.0"; - - // Module info -#ifdef _WIN32 -#if defined(__MINGW32__) || defined(__MINGW64__) - static constexpr const char* MODULE_FOLDER = "./modules"; - static constexpr const char* COMPILER = "g++"; - static constexpr const char* TASK_FOLDER = "./tasks"; -#else - static constexpr const char* MODULE_FOLDER = ".\\modules"; - static constexpr const char* COMPILER = "cl.exe"; - static constexpr const char* TASK_FOLDER = ".\\tasks"; -#endif -#elif defined(__APPLE__) - static constexpr const char* MODULE_FOLDER = "./modules"; - static constexpr const char* COMPILER = "clang++"; - static constexpr const char* TASK_FOLDER = "./tasks"; -#else - static constexpr const char* MODULE_FOLDER = "./modules"; - static constexpr const char* COMPILER = "g++"; - static constexpr const char* TASK_FOLDER = "./tasks"; -#endif + DEFINE_CONSTANT(PACKAGE_NAME, "package.json") + DEFINE_CONSTANT(PACKAGE_NAME_SHORT, "lithium") + DEFINE_CONSTANT(PACKAGE_AUTHOR, "Max Qian") + DEFINE_CONSTANT(PACKAGE_AUTHOR_EMAIL, "astro_air@126.com") + DEFINE_CONSTANT(PACKAGE_LICENSE, "AGPL-3") + DEFINE_CONSTANT(PACKAGE_VERSION, "0.1.0") static std::vector COMMON_COMPILERS; static std::vector COMPILER_PATHS; // Env info - static constexpr const char* ENV_VAR_MODULE_PATH = "LITHIUM_MODULE_PATH"; + DEFINE_CONSTANT(ENV_VAR_MODULE_PATH, "LITHIUM_MODULE_PATH") - // Inside Module Identifiers - static constexpr const char* LITHIUM_COMPONENT_MANAGER = - "lithium.addon.manager"; - static constexpr const char* LITHIUM_MODULE_LOADER = "lithium.addon.loader"; - static constexpr const char* LITHIUM_ADDON_MANAGER = "lithium.addon.addon"; - static constexpr const char* LITHIUM_UTILS_ENV = "lithium.utils.env"; + DEFINE_LITHIUM_CONSTANT(CONFIG_MANAGER) - static constexpr const char* LITHIUM_PROCESS_MANAGER = - "lithium.system.process"; + DEFINE_LITHIUM_CONSTANT(COMPONENT_MANAGER) + DEFINE_LITHIUM_CONSTANT(MODULE_LOADER) + DEFINE_LITHIUM_CONSTANT(ADDON_MANAGER) + DEFINE_LITHIUM_CONSTANT(ENVIRONMENT) - static constexpr const char* LITHIUM_DEVICE_LOADER = - "lithium.device.loader"; - static constexpr const char* LITHIUM_DEVICE_MANAGER = - "lithium.device.manager"; + 
DEFINE_LITHIUM_CONSTANT(PROCESS_MANAGER) + DEFINE_LITHIUM_CONSTANT(DEVICE_LOADER) + DEFINE_LITHIUM_CONSTANT(DEVICE_MANAGER) - static std::vector LITHIUM_RESOURCES; - static std::vector LITHIUM_RESOURCES_SHA256; + DEFINE_LITHIUM_CONSTANT(TASK_CONTAINER) + DEFINE_LITHIUM_CONSTANT(TASK_SCHEDULER) + DEFINE_LITHIUM_CONSTANT(TASK_POOL) + DEFINE_LITHIUM_CONSTANT(TASK_LIST) + DEFINE_LITHIUM_CONSTANT(TASK_GENERATOR) + DEFINE_LITHIUM_CONSTANT(TASK_MANAGER) - // Task - static constexpr const char* LITIHUM_TASK_MANAGER = "lithium.task.manager"; - static constexpr const char* LITHIUM_TASK_CONTAINER = - "lithium.task.container"; - static constexpr const char* LITHIUM_TASK_POOL = "lithium.task.pool"; - static constexpr const char* LITHIUM_TASK_LIST = "lithium.task.list"; - static constexpr const char* LITHIUM_TASK_GENERATOR = - "lithium.task.generator"; + DEFINE_LITHIUM_CONSTANT(APP) + DEFINE_LITHIUM_CONSTANT(EVENTLOOP) + DEFINE_LITHIUM_CONSTANT(DISPATCHER) + DEFINE_LITHIUM_CONSTANT(EXECUTOR) - static constexpr const char* LITHIUM_COMMAND = "lithium.command"; + static std::vector LITHIUM_RESOURCES; + static std::vector LITHIUM_RESOURCES_SHA256; }; #endif // LITHIUM_UTILS_CONSTANTS_HPP diff --git a/tests/atom/async/async.cpp b/tests/atom/async/async.cpp index dfd3e024..c3207a86 100644 --- a/tests/atom/async/async.cpp +++ b/tests/atom/async/async.cpp @@ -223,3 +223,197 @@ TEST_F(AsyncWorkerManagerTest, Cancel_ValidWorker_CancelsWorker) { // Assert EXPECT_FALSE(worker->isActive()); } + +// Helper function to create a short delay +void delayMs(int ms) { + std::this_thread::sleep_for(std::chrono::milliseconds(ms)); +} + +// Test: Simple asynchronous task execution +TEST(EnhancedFutureTest, AsyncTaskExecution) { + auto future = atom::async::makeEnhancedFuture([]() -> int { + delayMs(500); + return 42; + }); + + // Test that the task is completed + EXPECT_EQ(future.wait(), 42); +} + +// Test: Void asynchronous task with then() chaining +TEST(EnhancedFutureTest, ThenChaining) { + std::atomic thenExecuted{false}; + + auto future = atom::async::makeEnhancedFuture([]() { + delayMs(500); + std::cout << "Task Done" << std::endl; + }); + + // Chain then to check if it executes after task completion + future.then([&]() { + thenExecuted = true; + std::cout << "Then executed" << std::endl; + }); + + // Verify that the then() function executed successfully + future.wait(); + EXPECT_TRUE(thenExecuted); +} + +// Test: onComplete() callback invocation +TEST(EnhancedFutureTest, OnCompleteCallback) { + std::atomic callbackExecuted{false}; + + auto future = atom::async::makeEnhancedFuture([]() -> int { + delayMs(500); + return 42; + }); + + // Set a completion callback to be executed when the task is done + future.onComplete([&](int result) { + EXPECT_EQ(result, 42); + callbackExecuted = true; + }); + + // Wait for completion and check if the callback was invoked + future.wait(); + EXPECT_TRUE(callbackExecuted); +} + +// Test: WaitFor with timeout and automatic cancellation +TEST(EnhancedFutureTest, WaitForTimeoutAndCancel) { + auto future = atom::async::makeEnhancedFuture([]() -> int { + delayMs(2000); // Simulate a long task (2 seconds) + return 42; + }); + + // Wait for 1 second, which is less than the task duration + auto result = future.waitFor(std::chrono::milliseconds(1000)); + + // Since the task is not done within the timeout, it should return nullopt + EXPECT_FALSE(result.has_value()); + EXPECT_TRUE(future.isCancelled()); +} + +// Test: Retry mechanism with successful retry +TEST(EnhancedFutureTest, RetryWithSuccess) 
{ + auto future = atom::async::makeEnhancedFuture([]() -> int { + static int attempts = 0; + ++attempts; + delayMs(500); + if (attempts < 3) { + throw std::runtime_error("Simulated failure"); + } + return 42; + }); + + // Retry the task with a maximum of 5 retries + auto retryFuture = future.retry([](int result) { return result; }, 5); + + // Wait for completion and check the final result after retries + int result = retryFuture.wait(); + EXPECT_EQ(result, 42); +} + +// Test: Retry mechanism failure after max retries +TEST(EnhancedFutureTest, RetryWithFailure) { + auto future = atom::async::makeEnhancedFuture([]() -> int { + delayMs(500); + throw std::runtime_error("Simulated failure"); + }); + + // Retry the task with only 2 retries (which should fail) + auto retryFuture = future.retry([](int result) { return result; }, 2); + + // Expect that an exception is thrown after retries are exhausted + EXPECT_THROW(retryFuture.wait(), std::runtime_error); +} + +// Test: Cancel functionality +TEST(EnhancedFutureTest, CancelFunctionality) { + auto future = atom::async::makeEnhancedFuture([]() -> int { + delayMs(2000); // Simulate a long task (2 seconds) + return 42; + }); + + // Cancel the future + future.cancel(); + + // Trying to wait should throw a cancellation exception + EXPECT_THROW(future.wait(), std::runtime_error); + EXPECT_TRUE(future.isCancelled()); +} + +// Test: Exception handling +TEST(EnhancedFutureTest, ExceptionHandling) { + auto future = atom::async::makeEnhancedFuture( + []() -> int { throw std::runtime_error("Test Exception"); }); + + // Expect that waiting for the future throws the exception + EXPECT_THROW(future.wait(), std::runtime_error); + + // Check that the exception is captured + auto exception = future.getException(); + EXPECT_NE(exception, nullptr); +} + +// Test: Void task with onComplete callback +TEST(EnhancedFutureTest, VoidTaskWithOnCompleteCallback) { + std::atomic callbackExecuted{false}; + + auto future = atom::async::makeEnhancedFuture([]() { + delayMs(500); + std::cout << "Void Task Done" << std::endl; + }); + + // Set a completion callback + future.onComplete([&]() { + callbackExecuted = true; + std::cout << "Callback executed" << std::endl; + }); + + // Wait for the future and check if the callback was invoked + future.wait(); + EXPECT_TRUE(callbackExecuted); +} + +// Test: WaitFor on void future with timeout +TEST(EnhancedFutureTest, VoidWaitForTimeout) { + auto future = atom::async::makeEnhancedFuture([]() { + delayMs(2000); // Simulate a long task (2 seconds) + }); + + // Wait for 1 second, less than task duration + bool result = future.waitFor(std::chrono::milliseconds(1000)); + + // Expect that the future was not completed and was cancelled + EXPECT_FALSE(result); + EXPECT_TRUE(future.isCancelled()); +} + +// Test: Multiple callbacks onComplete +TEST(EnhancedFutureTest, MultipleOnCompleteCallbacks) { + std::atomic callbackCount{0}; + + auto future = atom::async::makeEnhancedFuture([]() -> int { + delayMs(500); + return 42; + }); + + // Set two callbacks + future.onComplete([&](int result) { + EXPECT_EQ(result, 42); + ++callbackCount; + }); + + future.onComplete([&](int result) { + EXPECT_EQ(result, 42); + ++callbackCount; + }); + + // Wait for the future to complete + future.wait(); + + // Expect both callbacks to have been called + EXPECT_EQ(callbackCount.load(), 2); +} diff --git a/tests/atom/type/pod_vector.cpp b/tests/atom/type/pod_vector.cpp index 3b82df6e..4064f6ac 100644 --- a/tests/atom/type/pod_vector.cpp +++ b/tests/atom/type/pod_vector.cpp 
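Reviewer note on the EnhancedFutureTest cases added above: they exercise a fluent future wrapper (makeEnhancedFuture, then, onComplete, waitFor, cancel, retry, isCancelled). Below is a minimal usage sketch of that surface as the tests imply it; the include path and template details are assumptions, since this hunk only shows the test side.

```cpp
// Minimal sketch only. The include path is an assumption; the tests in this
// patch pull the wrapper in through the project's own includes.
#include "atom/async/future.hpp"

#include <chrono>
#include <iostream>

int main() {
    using namespace std::chrono_literals;

    // Launch a task, as in AsyncTaskExecution.
    auto future = atom::async::makeEnhancedFuture([]() -> int { return 42; });

    // Completion callback, as in OnCompleteCallback.
    future.onComplete([](int value) { std::cout << "got " << value << '\n'; });

    // Bounded wait; per WaitForTimeoutAndCancel, a timeout also cancels the task.
    if (auto result = future.waitFor(500ms); result.has_value()) {
        std::cout << "completed: " << *result << '\n';
    } else if (future.isCancelled()) {
        std::cout << "timed out and was cancelled\n";
    }
    return 0;
}
```

One design point the tests pin down: waitFor cancels the underlying task on timeout rather than leaving it running, which callers of this API should be aware of.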
@@ -174,34 +174,3 @@ TEST(PodVectorTest, Detach) { EXPECT_EQ(data[3], 4); std::free(data); } - -// Test case for Stack class -TEST(StackTest, BasicOperations) { - Stack stack; - EXPECT_TRUE(stack.empty()); - - stack.push(1); - stack.push(2); - stack.push(3); - EXPECT_EQ(stack.size(), 3); - EXPECT_EQ(stack.top(), 3); - - stack.pop(); - EXPECT_EQ(stack.size(), 2); - EXPECT_EQ(stack.top(), 2); - - stack.clear(); - EXPECT_TRUE(stack.empty()); -} - -// Test case for Stack class with popx method -TEST(StackTest, Popx) { - Stack stack; - stack.push(1); - stack.push(2); - stack.push(3); - - int last = stack.popx(); - EXPECT_EQ(last, 3); - EXPECT_EQ(stack.size(), 2); -} diff --git a/tests/atom/utils/to_string.cpp b/tests/atom/utils/to_string.cpp index a539e273..6178becd 100644 --- a/tests/atom/utils/to_string.cpp +++ b/tests/atom/utils/to_string.cpp @@ -1,36 +1,96 @@ #include "atom/utils/to_string.hpp" #include +#include +#include +#include + using namespace atom::utils; -TEST(StringUtilsTest, ToStringTest) { - // Basic type - int num = 123; - EXPECT_EQ(toString(num), "123"); +// 基本类型测试 +TEST(ToStringTest, BasicTypes) { + EXPECT_EQ(toString(42), "42"); + EXPECT_EQ(toString(3.14), "3.14"); + EXPECT_EQ(toString('A'), "A"); + EXPECT_EQ(toString(true), "1"); +} + +// 字符串类型测试 +TEST(ToStringTest, StringTypes) { + EXPECT_EQ(toString(std::string("hello")), "hello"); + EXPECT_EQ(toString("world"), "world"); +} + +// 枚举类型测试 +enum class MyEnum { Value1 = 1, Value2 = 2 }; - // String type - std::string str = "hello"; - EXPECT_EQ(toString(str), "hello"); +TEST(ToStringTest, EnumType) { + EXPECT_EQ(toString(MyEnum::Value1), "1"); + EXPECT_EQ(toString(MyEnum::Value2), "2"); +} - // Container type +// 容器类型测试 +TEST(ToStringTest, ContainerTypes) { std::vector vec = {1, 2, 3}; EXPECT_EQ(toString(vec), "[1, 2, 3]"); + + std::vector strVec = {"one", "two", "three"}; + EXPECT_EQ(toString(strVec), "[one, two, three]"); +} + +// 指针类型测试 +TEST(ToStringTest, PointerTypes) { + int val = 42; + int* ptr = &val; + EXPECT_EQ(toString(ptr), "Pointer(42)"); + + int* nullPtr = nullptr; + EXPECT_EQ(toString(nullPtr), "nullptr"); +} + +// 智能指针类型测试 +TEST(ToStringTest, SmartPointerTypes) { + auto smartPtr = std::make_unique(42); + EXPECT_EQ(toString(smartPtr), "SmartPointer(42)"); + + std::unique_ptr nullSmartPtr = nullptr; + EXPECT_EQ(toString(nullSmartPtr), "nullptr"); +} + +// 映射类型测试 +TEST(ToStringTest, MapTypes) { + std::map map = {{1, "one"}, {2, "two"}}; + EXPECT_EQ(toString(map), "{1: one, 2: two}"); + + std::unordered_map unorderedMap = {{"one", 1}, + {"two", 2}}; + EXPECT_EQ(toString(unorderedMap), "{one: 1, two: 2}"); } -TEST(StringUtilsTest, JoinKeyValuePairTest) { - // String type key-value pair - std::string key = "name"; - std::string value = "Max"; - EXPECT_EQ(joinKeyValuePair(key, value), "nameMax"); +// 键值对测试 +TEST(ToStringTest, PairType) { + std::pair pair = {1, "one"}; + EXPECT_EQ(toString(pair), "(1, one)"); } -TEST(StringUtilsTest, JoinCommandLineTest) { - std::string arg1 = "arg1"; - std::string arg2 = "arg2"; - std::string arg3 = "arg3"; - EXPECT_EQ(joinCommandLine(arg1, arg2, arg3), "arg1 arg2 arg3"); +// 数组测试 +TEST(ToStringTest, ArrayType) { + int arr[] = {1, 2, 3}; + EXPECT_EQ(toString(arr), "[1, 2, 3]"); } -TEST(StringUtilsTest, ToStringArrayTest) { - std::vector array = {1, 2, 3, 4, 5}; - EXPECT_EQ(toStringArray(array), "1 2 3 4 5"); +// 命令行参数测试 +TEST(ToStringTest, JoinCommandLine) { + EXPECT_EQ(joinCommandLine(1, "two", 3.14), "1 two 3.14"); +} + +// toStringArray 测试 +TEST(ToStringTest, ToStringArray) { + 
std::vector vec = {1, 2, 3}; + EXPECT_EQ(toStringArray(vec), "1 2 3"); +} + +// 范围测试 +TEST(ToStringTest, ToStringRange) { + std::vector vec = {1, 2, 3}; + EXPECT_EQ(toStringRange(vec.begin(), vec.end()), "[1, 2, 3]"); } diff --git a/tests/atom/web/downloader.cpp b/tests/atom/web/downloader.cpp new file mode 100644 index 00000000..b351f023 --- /dev/null +++ b/tests/atom/web/downloader.cpp @@ -0,0 +1,164 @@ +#include "atom/web/downloader.hpp" + +#include +#include + +// 模拟下载 URL 和文件路径 +const std::string mock_url = "https://example.com/testfile"; +const std::string mock_file = "testfile.txt"; + +// 继承测试类来组织不同的测试 +class DownloadManagerTest : public ::testing::Test { +protected: + // 每次测试前初始化 + void SetUp() override { + // 初始化 DownloadManager 实例 + download_manager = + std::make_unique("tasks.txt"); + } + + // 每次测试后清理 + void TearDown() override { + // 清理模拟下载的文件 + std::remove("tasks.txt"); + std::remove("testfile.txt"); + } + + std::unique_ptr download_manager; +}; + +// 测试添加任务 +TEST_F(DownloadManagerTest, AddTask) { + // 添加一个任务 + download_manager->add_task(mock_url, mock_file); + + // 检查任务是否正确添加 + ASSERT_EQ(download_manager->get_downloaded_bytes(0), 0); +} + +// 测试删除任务 +TEST_F(DownloadManagerTest, RemoveTask) { + download_manager->add_task(mock_url, mock_file); + + // 检查任务是否添加 + ASSERT_EQ(download_manager->get_downloaded_bytes(0), 0); + + // 删除任务 + bool removed = download_manager->remove_task(0); + + // 检查任务是否删除成功 + ASSERT_TRUE(removed); +} + +// 测试暂停和恢复任务 +TEST_F(DownloadManagerTest, PauseResumeTask) { + download_manager->add_task(mock_url, mock_file); + + // 暂停任务 + download_manager->pause_task(0); + + // 模拟任务已暂停(通过检查文件字节数不变) + ASSERT_EQ(download_manager->get_downloaded_bytes(0), 0); + + // 恢复任务 + download_manager->resume_task(0); + + // 恢复后,任务应可以继续下载 + ASSERT_EQ(download_manager->get_downloaded_bytes(0), + 0); // 检查任务仍未开始(字节为0) +} + +// 测试取消任务 +TEST_F(DownloadManagerTest, CancelTask) { + download_manager->add_task(mock_url, mock_file); + + // 取消任务 + download_manager->cancel_task(0); + + // 模拟任务被取消后,下载应不会进行 + ASSERT_EQ(download_manager->get_downloaded_bytes(0), 0); +} + +// 测试任务的下载进度更新 +TEST_F(DownloadManagerTest, ProgressUpdate) { + download_manager->add_task(mock_url, mock_file); + + // 设置进度更新回调函数 + bool progress_updated = false; + download_manager->on_progress_update([&](size_t index, double progress) { + progress_updated = true; + ASSERT_EQ(index, 0); // 检查任务索引是否正确 + ASSERT_GE(progress, 0.0); // 进度应大于等于 0 + }); + + // 启动下载任务 + download_manager->start(1); + + // 检查进度是否已更新 + ASSERT_TRUE(progress_updated); +} + +// 测试任务完成后的通知 +TEST_F(DownloadManagerTest, DownloadCompleteNotification) { + download_manager->add_task(mock_url, mock_file); + + // 设置下载完成回调函数 + bool task_completed = false; + download_manager->on_download_complete([&](size_t index) { + task_completed = true; + ASSERT_EQ(index, 0); // 检查任务索引 + }); + + // 启动下载任务 + download_manager->start(1); + + // 检查是否收到任务完成通知 + ASSERT_TRUE(task_completed); +} + +// 测试多任务并发 +TEST_F(DownloadManagerTest, ConcurrentTasks) { + download_manager->add_task(mock_url, "file1.txt"); + download_manager->add_task(mock_url, "file2.txt"); + + // 使用多个线程启动下载 + download_manager->start(2); + + // 确保两个任务都添加了并下载完成 + ASSERT_EQ(download_manager->get_downloaded_bytes(0), 0); // 模拟已下载 + ASSERT_EQ(download_manager->get_downloaded_bytes(1), 0); // 模拟已下载 + + // 清理生成的文件 + std::remove("file1.txt"); + std::remove("file2.txt"); +} + +// 测试设置最大重试次数 +TEST_F(DownloadManagerTest, MaxRetries) { + download_manager->add_task(mock_url, mock_file); + + // 设置重试次数 + 
download_manager->set_max_retries(3); + + // 启动任务 + download_manager->start(1); + + // 检查重试次数 + // 模拟失败后重试3次 + ASSERT_EQ(download_manager->get_downloaded_bytes(0), + 0); // 重试后仍未成功,模拟失败 +} + +// 测试线程数动态调整 +TEST_F(DownloadManagerTest, SetThreadCount) { + download_manager->add_task(mock_url, mock_file); + + // 动态设置下载线程数 + download_manager->set_thread_count(4); + + // 启动下载任务 + download_manager->start(); + + // 确保任务添加后没有下载错误 + ASSERT_EQ(download_manager->get_downloaded_bytes(0), 0); +} diff --git a/tests/atom/web/httplite.cpp b/tests/atom/web/httplite.cpp deleted file mode 100644 index c661ee47..00000000 --- a/tests/atom/web/httplite.cpp +++ /dev/null @@ -1,95 +0,0 @@ -#include -#include "atom/web/httplite.hpp" - -TEST(HttpClientTest, ConnectToServer) -{ - HttpClient client; - EXPECT_TRUE(client.initialize()); - EXPECT_TRUE(client.connectToServes://www.baidu.com", 80, false)); -} - -TEST(HttpClientTest, SendRequest) -{ - HttpClient client; - client.initialize(); - client.connectToServes://www.baidu.com", 80, false); - EXPECT_TRUE(client.sendRequest("GET / HTTP/1.1\r\nHoss://www.baidu.com\r\n\r\n")); -} - -TEST(HttpClientTest, ReceiveResponse) -{ - HttpClient client; - client.initialize(); - client.connectToServes://www.baidu.com", 80, false); - client.sendRequest("GET / HTTP/1.1\r\nHoss://www.baidu.com\r\n\r\n"); - HttpResponse response = client.receiveResponse(); - EXPECT_FALSE(response.body.empty()); - EXPECT_EQ(response.statusCode, 200); -} - -TEST(HttpRequestBuilderTest, BuildRequestString) -{ - HttpRequestBuilder builder(HttpMethod::GET, "https://www.baidu.com"); - builder.setBody("test body"); - builder.setContentType("text/plain"); - builder.setTimeout(std::chrono::seconds(30)); - builder.addHeader("Authorization", "Bearer token"); - - std::string requestString = builder.buildRequestStrins://www.baidu.com", "/"); - - // Add your specific assertion here based on the expected request string -} - -TEST(HttpClientTest, Initialize) -{ - HttpClient client; - EXPECT_TRUE(client.initialize()); -} - -TEST(HttpClientTest, ErrorHandler) -{ - HttpClient client; - bool errorHandled = false; - client.setErrorHandler([&errorHandled](const std::string &errorMsg) - { - // Custom error handling logic - errorHandled = true; }); - - // Simulate an error - client.connectToServer("invalidhost", 80, false); - EXPECT_TRUE(errorHandled); -} - -TEST(HttpRequestBuilderTest, SetBody) -{ - HttpRequestBuilder builder(HttpMethod::POST, "https://www.baidu.com"); - builder.setBody("test body"); - // Add your specific assertion here based on the expected request body -} - -TEST(HttpRequestBuilderTest, SetContentType) -{ - HttpRequestBuilder builder(HttpMethod::GET, "https://www.baidu.com"); - builder.setContentType("application/json"); - // Add your specific assertion here based on the expected content type -} - -TEST(HttpRequestBuilderTest, SetTimeout) -{ - HttpRequestBuilder builder(HttpMethod::GET, "https://www.baidu.com"); - builder.setTimeout(std::chrono::seconds(60)); - // Add your specific assertion here based on the expected timeout value -} - -TEST(HttpRequestBuilderTest, AddHeader) -{ - HttpRequestBuilder builder(HttpMethod::GET, "https://www.baidu.com"); - builder.addHeader("Authorization", "Bearer token"); - // Add your specific assertion here based on the expected headers -} - -int main(int argc, char **argv) -{ - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/tests/atom/web/httpparser.cpp b/tests/atom/web/httpparser.cpp new file mode 100644 index 00000000..fa1e7562 --- 
/dev/null +++ b/tests/atom/web/httpparser.cpp @@ -0,0 +1,124 @@ +/* + * test_httpparser.cpp + * + * Copyright (C) 2023-2024 Max Qian + */ + +#include "atom/web/httpparser.hpp" +#include +#include + +using namespace atom::web; + +// Test fixture class for setting up and tearing down parser tests +class HttpHeaderParserTest : public ::testing::Test { +protected: + HttpHeaderParser parser; + + void SetUp() override { + // Set up test environment, if needed + } + + void TearDown() override { + // Clean up after test, if needed + } +}; + +// Test for parsing simple headers +TEST_F(HttpHeaderParserTest, ParseHeaders_Simple) { + std::string rawHeaders = "Host: example.com\nUser-Agent: test-agent\n"; + parser.parseHeaders(rawHeaders); + + auto host = parser.getHeaderValues("Host"); + auto userAgent = parser.getHeaderValues("User-Agent"); + + ASSERT_TRUE(host.has_value()); + ASSERT_EQ(host->size(), 1); + EXPECT_EQ(host->at(0), "example.com"); + + ASSERT_TRUE(userAgent.has_value()); + ASSERT_EQ(userAgent->size(), 1); + EXPECT_EQ(userAgent->at(0), "test-agent"); +} + +// Test for setting a single header value +TEST_F(HttpHeaderParserTest, SetHeaderValue) { + parser.setHeaderValue("Content-Type", "text/html"); + + auto contentType = parser.getHeaderValues("Content-Type"); + + ASSERT_TRUE(contentType.has_value()); + ASSERT_EQ(contentType->size(), 1); + EXPECT_EQ(contentType->at(0), "text/html"); +} + +// Test for adding multiple values to a header +TEST_F(HttpHeaderParserTest, AddHeaderValue) { + parser.setHeaderValue("Set-Cookie", "cookie1=value1"); + parser.addHeaderValue("Set-Cookie", "cookie2=value2"); + + auto cookies = parser.getHeaderValues("Set-Cookie"); + + ASSERT_TRUE(cookies.has_value()); + ASSERT_EQ(cookies->size(), 2); + EXPECT_EQ(cookies->at(0), "cookie1=value1"); + EXPECT_EQ(cookies->at(1), "cookie2=value2"); +} + +// Test for checking header existence +TEST_F(HttpHeaderParserTest, HasHeader) { + parser.setHeaderValue("Authorization", "Bearer token"); + + EXPECT_TRUE(parser.hasHeader("Authorization")); + EXPECT_FALSE(parser.hasHeader("Non-Existent-Header")); +} + +// Test for removing a header +TEST_F(HttpHeaderParserTest, RemoveHeader) { + parser.setHeaderValue("Connection", "keep-alive"); + EXPECT_TRUE(parser.hasHeader("Connection")); + + parser.removeHeader("Connection"); + EXPECT_FALSE(parser.hasHeader("Connection")); +} + +// Test for clearing all headers +TEST_F(HttpHeaderParserTest, ClearHeaders) { + parser.setHeaderValue("Accept", "text/html"); + EXPECT_TRUE(parser.hasHeader("Accept")); + + parser.clearHeaders(); + EXPECT_FALSE(parser.hasHeader("Accept")); +} + +// Test for setting multiple headers at once +TEST_F(HttpHeaderParserTest, SetHeaders) { + std::map> headers = { + {"Accept-Encoding", {"gzip", "deflate"}}, + {"User-Agent", {"gtest-agent"}}}; + parser.setHeaders(headers); + + auto encoding = parser.getHeaderValues("Accept-Encoding"); + auto userAgent = parser.getHeaderValues("User-Agent"); + + ASSERT_TRUE(encoding.has_value()); + ASSERT_EQ(encoding->size(), 2); + EXPECT_EQ(encoding->at(0), "gzip"); + EXPECT_EQ(encoding->at(1), "deflate"); + + ASSERT_TRUE(userAgent.has_value()); + ASSERT_EQ(userAgent->size(), 1); + EXPECT_EQ(userAgent->at(0), "gtest-agent"); +} + +// Test for getting all headers +TEST_F(HttpHeaderParserTest, GetAllHeaders) { + parser.setHeaderValue("Accept", "text/html"); + parser.setHeaderValue("Content-Type", "application/json"); + + auto allHeaders = parser.getAllHeaders(); + + ASSERT_EQ(allHeaders.size(), 2); + EXPECT_EQ(allHeaders["Accept"].at(0), 
"text/html"); + EXPECT_EQ(allHeaders["Content-Type"].at(0), "application/json"); +} diff --git a/tests/atom_test.hpp b/tests/atom_test.hpp new file mode 100644 index 00000000..ddf9d4a1 --- /dev/null +++ b/tests/atom_test.hpp @@ -0,0 +1,506 @@ +#ifndef ATOM_TEST_HPP +#define ATOM_TEST_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if __has_include() +#include +#else +#include "atom/type/json.hpp" +#endif + +#include "macro.hpp" + +// 一个简单的测试框架 +namespace atom::test { + +struct TestCase { + std::string name; + std::function func; + bool skip = false; // 是否跳过测试 + bool async = false; // 是否异步运行 + double timeLimit = 0.0; // 测试时间阈值 + std::vector dependencies; // 依赖的测试 +} ATOM_ALIGNAS(128); + +struct TestResult { + std::string name; + bool passed; + bool skipped; + std::string message; + double duration; + bool timedOut; +} ATOM_ALIGNAS(128); + +struct TestSuite { + std::string name; + std::vector testCases; +} ATOM_ALIGNAS(64); + +ATOM_INLINE auto getTestSuites() -> std::vector& { + static std::vector testSuites; + return testSuites; +} + +ATOM_INLINE auto getTestMutex() -> std::mutex& { + static std::mutex testMutex; + return testMutex; +} + +ATOM_INLINE void registerTest(const std::string& name, + std::function func, bool async = false, + double time_limit = 0.0, bool skip = false, + std::vector dependencies = {}) { + getTestSuites().push_back({"", + {TestCase{name, std::move(func), skip, async, + time_limit, dependencies}}}); +} + +ATOM_INLINE void registerSuite(const std::string& suite_name, + std::vector cases) { + getTestSuites().push_back({suite_name, std::move(cases)}); +} + +ATOM_INLINE auto operator""_test(const char* name, std::size_t) { + return [name](std::function func, bool async = false, + double time_limit = 0.0, bool skip = false, + std::vector const& dependencies = {}) { + return TestCase{name, std::move(func), skip, + async, time_limit, dependencies}; + }; +} + +struct TestStats { + int totalTests = 0; + int totalAsserts = 0; + int passedAsserts = 0; + int failedAsserts = 0; + int skippedTests = 0; + std::vector results; +} ATOM_ALIGNAS(64); + +ATOM_INLINE auto getTestStats() -> TestStats& { + static TestStats stats; + return stats; +} + +using Hook = std::function; + +struct Hooks { + Hook beforeEach; + Hook afterEach; + Hook beforeAll; + Hook afterAll; +} ATOM_ALIGNAS(128); + +ATOM_INLINE auto getHooks() -> Hooks& { + static Hooks hooks; + return hooks; +} + +ATOM_INLINE void printColored(const std::string& text, + const std::string& color_code) { + std::cout << "\033[" << color_code << "m" << text << "\033[0m"; +} + +struct Timer { + std::chrono::high_resolution_clock::time_point startTime; + + Timer() { reset(); } + + void reset() { startTime = std::chrono::high_resolution_clock::now(); } + + [[nodiscard]] auto elapsed() const -> double { + return std::chrono::duration( + std::chrono::high_resolution_clock::now() - startTime) + .count(); + } +}; + +// 支持多种格式的结果导出(JSON, XML, HTML) +ATOM_INLINE void exportResults(const std::string& filename, + const std::string& format) { + auto& stats = getTestStats(); + nlohmann::json jsonReport; + + jsonReport["total_tests"] = stats.totalTests; + jsonReport["total_asserts"] = stats.totalAsserts; + jsonReport["passed_asserts"] = stats.passedAsserts; + jsonReport["failed_asserts"] = stats.failedAsserts; + jsonReport["skipped_tests"] = stats.skippedTests; + jsonReport["test_results"] = nlohmann::json::array(); + + for 
(const auto& result : stats.results) {
+        nlohmann::json jsonResult;
+        jsonResult["name"] = result.name;
+        jsonResult["passed"] = result.passed;
+        jsonResult["skipped"] = result.skipped;
+        jsonResult["message"] = result.message;
+        jsonResult["duration"] = result.duration;
+        jsonResult["timed_out"] = result.timedOut;
+        jsonReport["test_results"].push_back(jsonResult);
+    }
+
+    if (format == "json") {
+        std::ofstream file(filename + ".json");
+        file << jsonReport.dump(4);
+        file.close();
+        std::cout << "Test report saved to " << filename << ".json\n";
+    } else if (format == "xml") {
+        std::ofstream file(filename + ".xml");
+        file << "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<testsuite>\n";
+        file << "  <total_tests>" << stats.totalTests << "</total_tests>\n";
+        file << "  <passed_asserts>" << stats.passedAsserts
+             << "</passed_asserts>\n";
+        file << "  <failed_asserts>" << stats.failedAsserts
+             << "</failed_asserts>\n";
+        file << "  <skipped_tests>" << stats.skippedTests
+             << "</skipped_tests>\n";
+        for (const auto& result : stats.results) {
+            file << "  <testcase name=\"" << result.name << "\">\n";
+            file << "    <passed>" << (result.passed ? "true" : "false")
+                 << "</passed>\n";
+            file << "    <message>" << result.message << "</message>\n";
+            file << "    <duration>" << result.duration << "</duration>\n";
+            file << "    <timed_out>" << (result.timedOut ? "true" : "false")
+                 << "</timed_out>\n";
+            file << "  </testcase>\n";
+        }
+        file << "</testsuite>\n";
+        file.close();
+        std::cout << "Test report saved to " << filename << ".xml\n";
+    } else if (format == "html") {
+        std::ofstream file(filename + ".html");
+        file << "<html><head><title>Test Report</title>"
+                "</head><body>\n";
+        file << "<h1>Test Report</h1>\n";
+        file << "<p>Total Tests: " << stats.totalTests << "</p>\n";
+        file << "<p>Passed Asserts: " << stats.passedAsserts << "</p>\n";
+        file << "<p>Failed Asserts: " << stats.failedAsserts << "</p>\n";
+        file << "<p>Skipped Tests: " << stats.skippedTests << "</p>\n";
+        file << "<ul>\n";
+        for (const auto& result : stats.results) {
+            file << "<li>" << result.name << ": "
+                 << (result.passed ? "PASSED" : "FAILED")
+                 << " (" << result.duration << " ms)</li>\n";
+        }
+        file << "</ul>\n";
+        file << "</body></html>
    \n"; + file << ""; + file.close(); + std::cout << "Test report saved to " << filename << ".html\n"; + } +} + +// 执行单个测试用例 +ATOM_INLINE void runTestCase(const TestCase& test, int retry_count = 0) { + auto& stats = getTestStats(); + Timer timer; + + if (test.skip) { + printColored("SKIPPED\n", "1;33"); + std::lock_guard lock(getTestMutex()); + stats.skippedTests++; + stats.totalTests++; + stats.results.push_back( + {std::string(test.name), false, true, "Test Skipped", 0.0, false}); + return; + } + + std::string resultMessage; + bool passed = false; + bool timedOut = false; + + try { + timer.reset(); + if (test.async) { + auto future = std::async(std::launch::async, test.func); + if (test.timeLimit > 0 && future.wait_for(std::chrono::milliseconds( + static_cast(test.timeLimit))) == + std::future_status::timeout) { + timedOut = true; + throw std::runtime_error("Test timed out"); + } + future.get(); + } else { + test.func(); + } + passed = true; + resultMessage = "PASSED"; + } catch (const std::exception& e) { + resultMessage = e.what(); + if (retry_count > 0) { + printColored("Retrying test...\n", "1;33"); + runTestCase(test, retry_count - 1); + return; + } + } + + std::lock_guard lock(getTestMutex()); + stats.totalTests++; + stats.results.push_back({std::string(test.name), passed, false, + resultMessage, timer.elapsed(), timedOut}); + + if (timedOut) { + printColored(resultMessage + " (TIMEOUT)", "1;31"); + } else { + printColored(resultMessage, passed ? "1;32" : "1;31"); + } + std::cout << " (" << timer.elapsed() << " ms)\n"; +} + +// 支持并行执行测试 +ATOM_INLINE void runTestsInParallel(const std::vector& tests, + int num_threads = 4) { + std::vector threads; + threads.reserve(num_threads); + for (int i = 0; i < num_threads; ++i) { + threads.emplace_back([i, &tests, num_threads]() { + for (size_t j = i; j < tests.size(); j += num_threads) { + runTestCase(tests[j]); + } + }); + } + + for (auto& t : threads) { + t.join(); + } +} + +ATOM_INLINE void runAllTests(int retry_count = 0, bool parallel = false, + int num_threads = 4); + +ATOM_INLINE void runTests(int argc, char* argv[]) { + int retry_count = 0; + bool parallel = false; + int num_threads = 4; + std::string export_format; + std::string export_filename; + + for (int i = 1; i < argc; ++i) { + std::string arg = argv[i]; + if (arg == "--retry" && i + 1 < argc) { + retry_count = std::stoi(argv[++i]); + } else if (arg == "--parallel" && i + 1 < argc) { + parallel = true; + num_threads = std::stoi(argv[++i]); + } else if (arg == "--export" && i + 2 < argc) { + export_format = argv[++i]; + export_filename = argv[++i]; + } + } + + runAllTests(retry_count, parallel, num_threads); + + if (!export_format.empty() && !export_filename.empty()) { + exportResults(export_filename, export_format); + } +} + +// 过滤测试用例 +ATOM_INLINE auto filterTests(const std::regex& pattern) + -> std::vector { + std::vector filtered; + for (const auto& suite : getTestSuites()) { + for (const auto& test : suite.testCases) { + if (std::regex_search(test.name.begin(), test.name.end(), + pattern)) { + filtered.push_back(test); + } + } + } + return filtered; +} + +// 根据依赖关系排序测试 +ATOM_INLINE auto sortTestsByDependencies(const std::vector& tests) + -> std::vector { + std::map testMap; + std::vector sortedTests; + std::set processed; + + for (const auto& test : tests) { + testMap[test.name] = test; + } + + std::function resolveDependencies; + resolveDependencies = [&](const TestCase& test) { + if (!processed.contains(std::string(test.name))) { + for (const auto& dep : 
test.dependencies) { + if (testMap.find(dep) != testMap.end()) { + resolveDependencies(testMap[dep]); + } + } + processed.insert(std::string(test.name)); + sortedTests.push_back(test); + } + }; + + for (const auto& test : tests) { + resolveDependencies(test); + } + + return sortedTests; +} + +// 运行所有测试 +ATOM_INLINE void runAllTests(int retry_count, bool parallel, int num_threads) { + auto& stats = getTestStats(); + Timer globalTimer; + + std::vector allTests; + for (const auto& suite : getTestSuites()) { + allTests.insert(allTests.end(), suite.testCases.begin(), + suite.testCases.end()); + } + + // 按依赖关系排序测试 + allTests = sortTestsByDependencies(allTests); + + if (parallel) { + runTestsInParallel(allTests, num_threads); + } else { + for (const auto& test : allTests) { + runTestCase(test, retry_count); + } + } + + std::cout << "=============================================================" + "==================\n"; + std::cout << "Total tests: " << stats.totalTests << "\n"; + std::cout << "Total asserts: " << stats.totalAsserts << " | " + << stats.passedAsserts << " passed | " << stats.failedAsserts + << " failed | " << stats.skippedTests << " skipped\n"; + std::cout << "Total time: " << globalTimer.elapsed() << " ms\n"; +} + +// 测试断言 +struct expect { + bool result; + const char* file; + int line; + std::string message; + + expect(bool result, const char* file, int line, std::string msg) + : result(result), file(file), line(line), message(msg) { + auto& stats = getTestStats(); + stats.totalAsserts++; + if (!result) { + stats.failedAsserts++; + throw std::runtime_error(std::string(file) + ":" + + std::to_string(line) + ": FAILED - " + + std::string(msg)); + } + stats.passedAsserts++; + } +}; + +// 其他断言类型 +ATOM_INLINE auto expect_approx(double lhs, double rhs, double epsilon, + const char* file, int line) -> expect { + bool result = std::abs(lhs - rhs) <= epsilon; + return {result, file, line, + "Expected " + std::to_string(lhs) + " approx equal to " + + std::to_string(rhs)}; +} + +template +auto expect_eq(const T& lhs, const U& rhs, const char* file, + int line) -> expect { + return expect(lhs == rhs, file, line, + std::string("Expected ") + std::to_string(lhs) + + " == " + std::to_string(rhs)); +} + +template +auto expect_ne(const T& lhs, const U& rhs, const char* file, + int line) -> expect { + return expect(lhs != rhs, file, line, + std::string("Expected ") + std::to_string(lhs) + + " != " + std::to_string(rhs)); +} + +template +auto expect_gt(const T& lhs, const U& rhs, const char* file, + int line) -> expect { + return expect(lhs > rhs, file, line, + std::string("Expected ") + std::to_string(lhs) + " > " + + std::to_string(rhs)); +} + +// 字符串包含断言 +ATOM_INLINE auto expect_contains(const std::string& str, + const std::string& substr, const char* file, + int line) -> expect { + bool result = str.find(substr) != std::string::npos; + return {result, file, line, + "Expected \"" + str + "\" to contain \"" + substr + "\""}; +} + +// 集合相等断言 +template +ATOM_INLINE auto expect_set_eq(const std::vector& lhs, + const std::vector& rhs, const char* file, + int line) -> expect { + std::set lhsSet(lhs.begin(), lhs.end()); + std::set rhsSet(rhs.begin(), rhs.end()); + bool result = lhsSet == rhsSet; + return {result, file, line, "Expected sets to be equal"}; +} + +// 新增的断言类型 +template +auto expect_lt(const T& lhs, const U& rhs, const char* file, + int line) -> expect { + return expect(lhs < rhs, file, line, + std::string("Expected ") + std::to_string(lhs) + " < " + + std::to_string(rhs)); +} + +template 
+auto expect_ge(const T& lhs, const U& rhs, const char* file, + int line) -> expect { + return expect(lhs >= rhs, file, line, + std::string("Expected ") + std::to_string(lhs) + + " >= " + std::to_string(rhs)); +} + +template +auto expect_le(const T& lhs, const U& rhs, const char* file, + int line) -> expect { + return expect(lhs <= rhs, file, line, + std::string("Expected ") + std::to_string(lhs) + + " <= " + std::to_string(rhs)); +} + +} // namespace atom::test + +#define expect(expr) atom::test::expect(expr, __FILE__, __LINE__, #expr) +#define expect_eq(lhs, rhs) atom::test::expect_eq(lhs, rhs, __FILE__, __LINE__) +#define expect_ne(lhs, rhs) atom::test::expect_ne(lhs, rhs, __FILE__, __LINE__) +#define expect_gt(lhs, rhs) atom::test::expect_gt(lhs, rhs, __FILE__, __LINE__) +#define expect_lt(lhs, rhs) atom::test::expect_lt(lhs, rhs, __FILE__, __LINE__) +#define expect_ge(lhs, rhs) atom::test::expect_ge(lhs, rhs, __FILE__, __LINE__) +#define expect_le(lhs, rhs) atom::test::expect_le(lhs, rhs, __FILE__, __LINE__) +#define expect_approx(lhs, rhs, eps) \ + atom::test::expect_approx(lhs, rhs, eps, __FILE__, __LINE__) +#define expect_contains(str, substr) \ + atom::test::expect_contains(str, substr, __FILE__, __LINE__) +#define expect_set_eq(lhs, rhs) ut::expect_set_eq(lhs, rhs, __FILE__, __LINE__) + +#endif diff --git a/tests/components/meta/anymeta.cpp b/tests/components/meta/anymeta.cpp index 45ca12f1..7bb59b5c 100644 --- a/tests/components/meta/anymeta.cpp +++ b/tests/components/meta/anymeta.cpp @@ -16,13 +16,17 @@ class TestClass { int getValue() const { return value; } void setValue(int v) { - std::cout << "Setting value to " << v << std::endl; + // std::cout << "Setting value to " << v << std::endl; value = v; } - void printValue() const { std::cout << "Value: " << value << std::endl; } + void printValue() const { + // std::cout << "Value: " << value << std::endl; + } - static void staticPrint() { std::cout << "Static print" << std::endl; } + static void staticPrint() { + // std::cout << "Static print" << std::endl;// + } }; // 注册类型信息 @@ -50,7 +54,7 @@ class TestClassRegistrar { "getValue", [](std::vector args) -> BoxedValue { auto& obj = args[0]; auto value = obj.tryCast()->getValue(); - std::cout << "Value: " << value << std::endl; + // std::cout << "Value: " << value << std::endl; return BoxedValue(value); }); @@ -77,11 +81,13 @@ class TestClassRegistrar { return BoxedValue(obj.tryCast()->getValue()); }, [](BoxedValue& obj, const BoxedValue& value) { - std::cout << "Setting value to " << value.getTypeInfo().name() - << ": " << value.tryCast().value() << std::endl; + // std::cout << "Setting value to " << + // value.getTypeInfo().name() + // << ": " << value.tryCast().value() << + // std::endl; if (auto v = value.tryCast(); v.has_value()) { obj.tryCast()->setValue(*v); - std::cout << "Value set to " << *v << std::endl; + // std::cout << "Value set to " << *v << std::endl; } else { THROW_INVALID_ARGUMENT("Invalid type for value property"); } diff --git a/tests/components/meta/refl_json.cpp b/tests/components/meta/refl_json.cpp index d09a962f..dae34016 100644 --- a/tests/components/meta/refl_json.cpp +++ b/tests/components/meta/refl_json.cpp @@ -45,7 +45,7 @@ TEST_F(ReflectableTest, FromJsonSuccess) { TEST_F(ReflectableTest, FromJsonMissingRequiredField) { json j = R"({"age": 25})"_json; // Missing "name" field - EXPECT_THROW(reflectable.from_json(j), atom::error::InvalidArgument); + EXPECT_THROW(reflectable.from_json(j), atom::error::MissingArgument); } // Test for validation failure 
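Reviewer note on tests/atom_test.hpp (added above): it is a self-contained registration/assertion layer rather than a GoogleTest extension, so a short driver sketch may help. This uses only the registerTest/runTests entry points and expect_* macros defined in the new header; the test names and bodies are illustrative placeholders, not part of the patch.

```cpp
// Sketch of driving the new framework; names and bodies are placeholders.
#include "atom_test.hpp"

int main(int argc, char* argv[]) {
    // A plain synchronous test.
    atom::test::registerTest("math.addition", [] {
        expect_eq(2 + 2, 4);
        expect_gt(10, 3);
    });

    // An asynchronous test with a 100 ms time limit.
    atom::test::registerTest(
        "async.quick", [] { expect(1 + 1 == 2); },
        /*async=*/true, /*time_limit=*/100.0);

    // Parses --retry, --parallel and --export, then runs everything.
    atom::test::runTests(argc, argv);
    return 0;
}
```

One caveat for users of these macros: the expect_set_eq convenience macro still expands to ut::expect_set_eq, so as the header stands only the fully qualified atom::test::expect_set_eq form will compile.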
diff --git a/utils.py b/utils.py new file mode 100644 index 00000000..e69de29b
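Closing note on the src/utils/constant.hpp hunk earlier in this patch: DEFINE_LITHIUM_CONSTANT both stringizes the identifier and emits a companion compile-time hash. Below is a manual expansion plus a hypothetical consumer; it assumes hash() from atom/algorithm/hash.hpp is constexpr over C strings, and the lookup helper is illustration only, not part of this diff.

```cpp
// Illustration of what DEFINE_LITHIUM_CONSTANT(CONFIG_MANAGER) expands to,
// plus a hypothetical lookup helper; hash() being constexpr is assumed.
#include "atom/algorithm/hash.hpp"

class Example {
public:
    // Macro expansion: the identifier is stringized, so the literal is
    // "lithium.CONFIG_MANAGER".
    static constexpr const char* CONFIG_MANAGER = "lithium." "CONFIG_MANAGER";
    static constexpr unsigned int CONFIG_MANAGER_hash = hash(CONFIG_MANAGER);
};

// Hypothetical consumer: compare a precomputed hash instead of strcmp.
inline bool isConfigManager(const char* component) {
    return hash(component) == Example::CONFIG_MANAGER_hash;
}
```

Note the design consequence: the old constants used lowercase dotted names (e.g. "lithium.addon.manager", "lithium.task.pool"), while the stringized form produces "lithium.COMPONENT_MANAGER"-style literals, so any code that matches on the literal strings must use the new spellings.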