diff --git a/CMakeLists.txt b/CMakeLists.txt
index 5ffc8e14..8213984a 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -60,6 +60,51 @@ LIST(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake_modules/")
LIST(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/../cmake_modules/")
include(cmake_modules/compiler_options.cmake)
+if (USE_CONAN)
+include(${CMAKE_BINARY_DIR}/conanbuildinfo.cmake OPTIONAL) # Optional inclusion
+
+# Check whether Conan is installed
+find_program(CONAN_CMD conan)
+if(NOT CONAN_CMD)
+ message(FATAL_ERROR "Conan is not installed. Please install Conan (pip install conan).")
+endif()
+
+# Check whether the default Conan profile exists
+execute_process(
+ COMMAND ${CONAN_CMD} config home
+ OUTPUT_VARIABLE CONAN_HOME
+ OUTPUT_STRIP_TRAILING_WHITESPACE
+)
+
+set(CONAN_DEFAULT_PROFILE "${CONAN_HOME}/profiles/default")
+
+if(NOT EXISTS "${CONAN_DEFAULT_PROFILE}")
+ message(STATUS "Conan default profile not found. Creating a new profile based on platform.")
+  # Create a default profile based on the operating system
+ if(WIN32)
+ execute_process(COMMAND ${CONAN_CMD} profile detect --force)
+ elseif(UNIX)
+ execute_process(COMMAND ${CONAN_CMD} profile detect --force)
+ else()
+ message(FATAL_ERROR "Unsupported platform for Conan profile detection.")
+ endif()
+endif()
+
+# If conanbuildinfo.cmake does not exist, run conan install
+if(NOT EXISTS "${CMAKE_BINARY_DIR}/conanbuildinfo.cmake")
+ message(STATUS "Running Conan install...")
+ execute_process(
+ COMMAND ${CONAN_CMD} install ${CMAKE_SOURCE_DIR} --build=missing
+ RESULT_VARIABLE result
+ )
+ if(result)
+ message(FATAL_ERROR "Conan install failed with error code: ${result}")
+ endif()
+ include(${CMAKE_BINARY_DIR}/conanbuildinfo.cmake)
+ conan_basic_setup()
+endif()
+endif()
+
# Include directories
include_directories(${CMAKE_CURRENT_BINARY_DIR})
include_directories(${CMAKE_SOURCE_DIR}/libs/)
diff --git a/LICENSE b/LICENSE
index f288702d..0ad25db4 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,5 +1,5 @@
- GNU GENERAL PUBLIC LICENSE
- Version 3, 29 June 2007
+ GNU AFFERO GENERAL PUBLIC LICENSE
+ Version 3, 19 November 2007
 Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
Everyone is permitted to copy and distribute verbatim copies
@@ -7,17 +7,15 @@
Preamble
- The GNU General Public License is a free, copyleft license for
-software and other kinds of works.
+ The GNU Affero General Public License is a free, copyleft license for
+software and other kinds of works, specifically designed to ensure
+cooperation with the community in the case of network server software.
The licenses for most software and other practical works are designed
to take away your freedom to share and change the works. By contrast,
-the GNU General Public License is intended to guarantee your freedom to
+our General Public Licenses are intended to guarantee your freedom to
share and change all versions of a program--to make sure it remains free
-software for all its users. We, the Free Software Foundation, use the
-GNU General Public License for most of our software; it applies also to
-any other work released this way by its authors. You can apply it to
-your programs, too.
+software for all its users.
When we speak of free software, we are referring to freedom, not
price. Our General Public Licenses are designed to make sure that you
@@ -26,44 +24,34 @@ them if you wish), that you receive source code or can get it if you
want it, that you can change the software or use pieces of it in new
free programs, and that you know you can do these things.
- To protect your rights, we need to prevent others from denying you
-these rights or asking you to surrender the rights. Therefore, you have
-certain responsibilities if you distribute copies of the software, or if
-you modify it: responsibilities to respect the freedom of others.
-
- For example, if you distribute copies of such a program, whether
-gratis or for a fee, you must pass on to the recipients the same
-freedoms that you received. You must make sure that they, too, receive
-or can get the source code. And you must show them these terms so they
-know their rights.
-
- Developers that use the GNU GPL protect your rights with two steps:
-(1) assert copyright on the software, and (2) offer you this License
-giving you legal permission to copy, distribute and/or modify it.
-
- For the developers' and authors' protection, the GPL clearly explains
-that there is no warranty for this free software. For both users' and
-authors' sake, the GPL requires that modified versions be marked as
-changed, so that their problems will not be attributed erroneously to
-authors of previous versions.
-
- Some devices are designed to deny users access to install or run
-modified versions of the software inside them, although the manufacturer
-can do so. This is fundamentally incompatible with the aim of
-protecting users' freedom to change the software. The systematic
-pattern of such abuse occurs in the area of products for individuals to
-use, which is precisely where it is most unacceptable. Therefore, we
-have designed this version of the GPL to prohibit the practice for those
-products. If such problems arise substantially in other domains, we
-stand ready to extend this provision to those domains in future versions
-of the GPL, as needed to protect the freedom of users.
-
- Finally, every program is threatened constantly by software patents.
-States should not allow patents to restrict development and use of
-software on general-purpose computers, but in those that do, we wish to
-avoid the special danger that patents applied to a free program could
-make it effectively proprietary. To prevent this, the GPL assures that
-patents cannot be used to render the program non-free.
+ Developers that use our General Public Licenses protect your rights
+with two steps: (1) assert copyright on the software, and (2) offer
+you this License which gives you legal permission to copy, distribute
+and/or modify the software.
+
+ A secondary benefit of defending all users' freedom is that
+improvements made in alternate versions of the program, if they
+receive widespread use, become available for other developers to
+incorporate. Many developers of free software are heartened and
+encouraged by the resulting cooperation. However, in the case of
+software used on network servers, this result may fail to come about.
+The GNU General Public License permits making a modified version and
+letting the public access it on a server without ever releasing its
+source code to the public.
+
+ The GNU Affero General Public License is designed specifically to
+ensure that, in such cases, the modified source code becomes available
+to the community. It requires the operator of a network server to
+provide the source code of the modified version running there to the
+users of that server. Therefore, public use of a modified version, on
+a publicly accessible server, gives the public access to the source
+code of the modified version.
+
+ An older license, called the Affero General Public License and
+published by Affero, was designed to accomplish similar goals. This is
+a different license, not a version of the Affero GPL, but Affero has
+released a new version of the Affero GPL which permits relicensing under
+this license.
The precise terms and conditions for copying, distribution and
modification follow.
@@ -72,7 +60,7 @@ modification follow.
0. Definitions.
- "This License" refers to version 3 of the GNU General Public License.
+ "This License" refers to version 3 of the GNU Affero General Public License.
"Copyright" also means copyright-like laws that apply to other kinds of
works, such as semiconductor masks.
@@ -549,35 +537,45 @@ to collect a royalty for further conveying from those to whom you convey
the Program, the only way you could satisfy both those terms and this
License would be to refrain entirely from conveying the Program.
- 13. Use with the GNU Affero General Public License.
+ 13. Remote Network Interaction; Use with the GNU General Public License.
+
+ Notwithstanding any other provision of this License, if you modify the
+Program, your modified version must prominently offer all users
+interacting with it remotely through a computer network (if your version
+supports such interaction) an opportunity to receive the Corresponding
+Source of your version by providing access to the Corresponding Source
+from a network server at no charge, through some standard or customary
+means of facilitating copying of software. This Corresponding Source
+shall include the Corresponding Source for any work covered by version 3
+of the GNU General Public License that is incorporated pursuant to the
+following paragraph.
Notwithstanding any other provision of this License, you have
permission to link or combine any covered work with a work licensed
-under version 3 of the GNU Affero General Public License into a single
+under version 3 of the GNU General Public License into a single
combined work, and to convey the resulting work. The terms of this
License will continue to apply to the part which is the covered work,
-but the special requirements of the GNU Affero General Public License,
-section 13, concerning interaction through a network will apply to the
-combination as such.
+but the work with which it is combined will remain governed by version
+3 of the GNU General Public License.
14. Revised Versions of this License.
The Free Software Foundation may publish revised and/or new versions of
-the GNU General Public License from time to time. Such new versions will
-be similar in spirit to the present version, but may differ in detail to
+the GNU Affero General Public License from time to time. Such new versions
+will be similar in spirit to the present version, but may differ in detail to
address new problems or concerns.
Each version is given a distinguishing version number. If the
-Program specifies that a certain numbered version of the GNU General
+Program specifies that a certain numbered version of the GNU Affero General
Public License "or any later version" applies to it, you have the
option of following the terms and conditions either of that numbered
version or of any later version published by the Free Software
Foundation. If the Program does not specify a version number of the
-GNU General Public License, you may choose any version ever published
+GNU Affero General Public License, you may choose any version ever published
by the Free Software Foundation.
If the Program specifies that a proxy can decide which future
-versions of the GNU General Public License can be used, that proxy's
+versions of the GNU Affero General Public License can be used, that proxy's
public statement of acceptance of a version permanently authorizes you
to choose that version for the Program.
@@ -635,40 +633,29 @@ the "copyright" line and a pointer to where the full notice is found.
 Copyright (C) <year>  <name of author>
This program is free software: you can redistribute it and/or modify
- it under the terms of the GNU General Public License as published by
- the Free Software Foundation, either version 3 of the License, or
+ it under the terms of the GNU Affero General Public License as published
+ by the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- GNU General Public License for more details.
+ GNU Affero General Public License for more details.
- You should have received a copy of the GNU General Public License
+ You should have received a copy of the GNU Affero General Public License
 along with this program.  If not, see <https://www.gnu.org/licenses/>.
Also add information on how to contact you by electronic and paper mail.
- If the program does terminal interaction, make it output a short
-notice like this when it starts in an interactive mode:
-
- <program>  Copyright (C) <year>  <name of author>
- This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
- This is free software, and you are welcome to redistribute it
- under certain conditions; type `show c' for details.
-
-The hypothetical commands `show w' and `show c' should show the appropriate
-parts of the General Public License. Of course, your program's commands
-might be different; for a GUI interface, you would use an "about box".
+ If your software can interact with users remotely through a computer
+network, you should also make sure that it provides a way for users to
+get its source. For example, if your program is a web application, its
+interface could display a "Source" link that leads users to an archive
+of the code. There are many ways you could offer source, and different
+solutions will be better for different programs; see section 13 for the
+specific requirements.
You should also get your employer (if you work as a programmer) or school,
if any, to sign a "copyright disclaimer" for the program, if necessary.
-For more information on this, and how to apply and follow the GNU GPL, see
+For more information on this, and how to apply and follow the GNU AGPL, see
 <https://www.gnu.org/licenses/>.
-
- The GNU General Public License does not permit incorporating your program
-into proprietary programs. If your program is a subroutine library, you
-may consider it more useful to permit linking proprietary applications with
-the library. If this is what you want to do, use the GNU Lesser General
-Public License instead of this License. But first, please read
-<https://www.gnu.org/licenses/why-not-lgpl.html>.
diff --git a/conanfile.py b/conanfile.py
deleted file mode 100644
index a2462b84..00000000
--- a/conanfile.py
+++ /dev/null
@@ -1,52 +0,0 @@
-from conan import ConanFile
-from conan.tools.cmake import CMake, CMakeToolchain, CMakeDeps, cmake_layout
-
-
-class LithiumConan(ConanFile):  # named so it does not shadow the imported ConanFile
- name = "lithium"
- version = "1.0.0"
-
- # Binary configuration
- settings = "os", "compiler", "build_type", "arch"
-
- # Sources are located in the same place as this recipe, copy them to the recipe
- exports_sources = "CMakeLists.txt", "src/*"
-
- requires = [
- "argparse/3.0",
- "cpp-httplib/0.15.3",
- "openssl/3.2.1",
- "zlib/1.3.1",
- "oatpp/1.3.0",
- "oatpp-websocket/1.3.0",
- "oatpp-openssl/1.3.0",
- "oatpp-swagger/1.3.0",
- "loguru/cci.20230406",
- "cfitsio/4.3.1",
- "tinyxml2/10.0.0",
- "cpython/3.12.2",
- "fmt/10.2.1",
- "opencv/4.9.0"
- ]
-
- def layout(self):
- cmake_layout(self)
-
- def generate(self):
- tc = CMakeToolchain(self)
- tc.generate()
-
- deps = CMakeDeps(self)
- deps.generate()
-
- def build(self):
- cmake = CMake(self)
- cmake.configure()
- cmake.build()
-
- def package(self):
- cmake = CMake(self)
- cmake.install()
-
- def package_info(self):
- self.cpp_info.libs = ["my_project"]
diff --git a/conanfile.txt b/conanfile.txt
index 1d181bcc..3b38b6ab 100644
--- a/conanfile.txt
+++ b/conanfile.txt
@@ -1,18 +1,9 @@
[requires]
-argparse/3.0
cfitsio/4.3.1
cpython/3.12.2
-cpp-httplib/0.15.3
-fmt/10.2.1
-loguru/cci.20230406
-oatpp/1.3.0
-oatpp-websocket/1.3.0
-oatpp-openssl/1.3.0
-oatpp-swagger/1.3.0
opencv/4.9.0
openssl/3.2.1
pybind11/2.12.0
-pybind11_json/0.2.13
tinyxml2/10.0.0
zlib/1.3.1
diff --git a/config/addon/compiler.json b/config/addon/compiler.json
new file mode 100644
index 00000000..36fc3db3
--- /dev/null
+++ b/config/addon/compiler.json
@@ -0,0 +1,24 @@
+{
+ "compiler": "clang++",
+ "optimization_level": "-O3",
+ "cplus_version": "-std=c++20",
+ "warnings": "-Wall -Wextra -Wpedantic",
+ "include_flag": "-I./include -I./external/include",
+ "output_flag": "-o output",
+ "debug_info": "-g",
+ "sanitizers": "-fsanitize=address,undefined",
+ "extra_flags": [
+ "-fno-exceptions",
+ "-fno-rtti",
+ "-march=native"
+ ],
+ "defines": [
+ "-DNDEBUG",
+ "-DUSE_OPENMP"
+ ],
+ "libraries": [
+ "-lpthread",
+ "-lm",
+ "-lrt"
+ ]
+}
diff --git a/doc/extra/day01.md b/doc/extra/day01.md
deleted file mode 100644
index edc6caf4..00000000
--- a/doc/extra/day01.md
+++ /dev/null
@@ -1,209 +0,0 @@
-# C From Zero: Day 1 - So Even "Hello, World!" Is This Complicated
-
-Hello everyone, and welcome to the world of C! As a beginner you are about to set out on a fun programming journey, and the first thing we will do is write the most classic, simplest program of all: "Hello, World!". Ready? Bring your brain and your computer (a phone works too if you are a superhero), and off we go!
-
-## The Hardest Problem: Hello World
-
-Surveys of programmers show that "Hello, World" scares away plenty of newcomers. If you can clear this hurdle, you have already beaten 50% of the beginners.
-
-```c
-#include <stdio.h>
-
-int main() {
-    printf("Hello World!\n");
-    return 0;
-}
-```
-
-## The Adventure Begins
-
-### Introducing the Protagonist: `#include <stdio.h>`
-
-Every programming adventure starts by introducing its protagonist. `#include <stdio.h>` is like opening the door and inviting in C's librarian. This librarian is called stdio.h, full name "standard input/output header", and it handles all of your input and output operations.
-
-- Question: why invite this librarian?
-- Answer: because you are about to use its treasure, the printf function, to print "Hello, World!" to the screen. Without it, printf has no home and cannot be used at all.
-
-### The Host Appears: `int main()`
-
-Next, our host enters: int main(). This is your program's "big boss", the starting point of all the code; every program must begin execution here. The host supervises the whole program: no matter how much code you have, it all starts running from here.
-
-- Question: why the int type?
-- Answer: because in C, main has to report back to the operating system whether the program finished successfully or ran into trouble. The int return type means it returns an integer: 0 means "all good, no problems", and any other number means "hmm, something went slightly wrong".
-
-### The Print Operation: `printf("Hello, World!\n");`
-
-Now for the highlight of this adventure! `printf("Hello, World!\n");`. The `printf` function is like a giant loudspeaker that announces: "Hello, World!".
-
-- Question: what is that \n for?
-- Answer: the mysterious \n is a control character meaning "new line". It is like writing on paper: you finish one sentence, then move to the next line. Without it, your output runs together on one line and looks messy.
-
-### A Polite Farewell: `return 0;`
-
-Finally, the host main politely tells the operating system: "Thanks for coming, everything went smoothly!" With return 0; you tell the OS the program finished without errors and can safely exit. It is a good habit: the program may run without it, but manners matter, right?
-
-### Compiling and Running: Time to Show Real Strength
-
-Writing the code does not mean the work is done! You still have to turn this text into "machine language" the computer understands, and that is the compiler's job.
-
-#### Compile Command
-
-##### GCC
-
-```bash
-gcc helloworld.c -o helloworld
-```
-
-This tells the compiler: "compile the file helloworld.c into an executable named helloworld."
-
-#### Running the Program
-
-```bash
-./helloworld
-```
-
-##### MSVC
-
-Just click the compile and run buttons.
-
-At this point the computer will obediently display "Hello, World!" on the screen.
-
-## Unpacking
-
-### Header Files
-
-#### What is a header file?
-
-Header files are collections of already-written C code, usually ending in .h. Think of them as "instruction manuals": when you need a feature, the header tells the compiler how to find and use it.
-
-For example:
-
-- `stdio.h`: standard input/output functions such as `printf` and `scanf`.
-- `stdlib.h`: memory allocation, program control, random numbers, and more.
-- `math.h`: math functions such as `sin`, `cos`, and `sqrt`.
-
-#### How do you include a header?
-
-We use the preprocessor directive `#include` to pull in a header. For example:
-
-```c
-#include <stdio.h>
-```
-
-This directive tells the compiler: before compiling this code, include the contents of stdio.h so the program can use its features.
-
-#### System headers and custom headers
-
-- System headers go inside angle brackets < >; the compiler looks for them in the system directories. For example:
-
-```c
-#include <stdio.h>
-```
-
-- Custom headers you write yourself are included with double quotes " "; the compiler looks for them in your project directory. For example:
-
-```c
-#include "myheader.h"
-```
-
-### The main Function
-
-```c
-int main(int argc, char *argv[]) {
-    // code that uses the command-line arguments
-    return 0;
-}
-```
-
-- Parameters: argc is the number of command-line arguments; argv[] is an array of strings holding the arguments passed on the command line (see the sketch after this list). **You can usually leave them out.**
-- Return type int: main() returns an integer that tells the operating system how the program ended. Returning 0 means success; a non-zero value signals a problem.
-- Function body: everything inside the braces {} is your program's logic; all the code to execute goes here.
-- `return 0;`: the return statement tells the operating system the program is done; 0 means success. Small programs can omit it, but writing it is a good habit.
-
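-A minimal sketch of how `argc` and `argv` can be used, simply echoing the arguments back (the file name `args.c` is hypothetical):
-
-```c
-#include <stdio.h>
-
-int main(int argc, char *argv[]) {
-    // argv[0] is the program name; argv[1..argc-1] are the arguments
-    for (int i = 0; i < argc; i++) {
-        printf("argv[%d] = %s\n", i, argv[i]);
-    }
-    return 0;
-}
-```
-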
-**C++ users take note: in C++, `main` implicitly returns 0, so some style guides consider an explicit `return 0;` unnecessary there!**
-
-### Output
-
-The `printf()` function is the most commonly used output function in C; it prints data to standard output (usually the screen). `printf()` is provided by the `stdio.h` header.
-
-#### Basic usage of printf()
-
-```c
-printf("Hello, World!\n");
-```
-
-This prints the string between the double quotes to the screen. printf() can output data of different types; the exact format is controlled by format specifiers.
-
-#### Common format specifiers
-
-- %d: print an integer (int).
-- %f: print a floating-point number (float, double).
-- %c: print a single character.
-- %s: print a string (character array).
-
-```c
-int age = 25;
-float height = 5.9;
-printf("I am %d years old and %.1f feet tall.\n", age, height);
-```
-
-```text
-I am 25 years old and 5.9 feet tall.
-```
-
-#### Escape characters
-
-Escape characters represent special characters such as newlines and tabs.
-
-- \n: newline.
-- \t: tab, like inserting a "jump" in the text.
-- \\: prints a backslash \ itself.
-
-```c
-printf("Hello,\nWorld!\tTabbed!\n");
-```
-
-```text
-Hello,
-World! Tabbed!
-```
-
-#### About puts (another way to print)
-
-```c
-puts("This line is printed using puts(), and it automatically adds a newline.");
-```
-
-**Note that puts is not recommended here; it is simply less capable than printf. The reasons:**
-
-- It always appends a newline: a distinctive feature of `puts()`. Convenient in simple cases, but often not what you want in more complex programs.
-- No formatted output: `puts()` can only print a single string; it cannot interpolate variables or control precision.
-- Limited error reporting: `puts()` returns a non-negative value on success and EOF on error, but provides no detailed error information.
-
-## The Complete Code
-
-[Test the code in an online compiler](https://godbolt.org/z/39ao5df4q)
-
-```c
-#include <stdio.h>
-
-int main() {
-    // 1. Print to the screen with printf
-    printf("Hello, World! This is printed using printf().\n");
-
-    // 2. Print to the screen with puts
-    puts("This line is printed using puts(), and it automatically adds a newline.");
-
-    // 3. Print a single character with putchar
-    putchar('A');
-    putchar('\n');
-    return 0;
-}
-```
-
-## Summary
-
-"Hello, World!" is only a tiny program, but behind it sit many of C's basic concepts: header files, functions, statements, control characters, return values, compiling and running... Simple as it looks, it is every programmer's key to the door.
-So congratulations! You have taken the first step into the programming world. Just remember: every complex program grows out of a little printf, the way every great essay starts from a single word!
-Your programming road is now officially open; later you will meet more magical things such as conditionals, loops, and pointers. Don't rush, everything will become clear. Happy coding: `few bugs and beautiful code!`
diff --git a/doc/extra/day02.md b/doc/extra/day02.md
deleted file mode 100644
index 098ef854..00000000
--- a/doc/extra/day02.md
+++ /dev/null
@@ -1,412 +0,0 @@
-# C From Zero: Day 2 - Different Choices, the Same Destination
-
-Branch statements are a C program's navigation system: they tell the code where to go under different conditions, the way a driver picks a different road to dodge a traffic jam. With branches, a program can choose the right block of code for a given input or state, which makes it far more flexible and adaptable. C's branch statements are a bit like ordering at a restaurant: "if" you like this dish, or "if-else" you prefer another one; the "else-if" chain is flipping through the menu for more options, and "switch" is picking whatever you fancy at the buffet. This article walks through the syntax, usage, and pitfalls of each branch statement in a relaxed way, so you never get lost on your programming journey.
-~~Yet another strange metaphor~~
-
-## The `if` Statement
-
-The if statement is C's most basic conditional control structure. It tests whether a condition is true and, if so, runs a specific block of code. The condition is usually a logical expression whose value is true or false.
-
-### Syntax
-
-```c
-if (condition) {
-    // code that runs when the condition is true
-}
-```
-
-- Condition: usually a boolean expression. If it holds (evaluates to true), the block in braces runs; otherwise it is skipped.
-- Code block: the code inside the braces {} runs only when the condition is true.
-
-### A Simple Example
-
-```c
-#include <stdio.h>
-
-int main() {
-    int x = 5;
-
-    if (x > 0) {
-        printf("x is positive.\n");
-    }
-
-    return 0;
-}
-```
-
-`if (x > 0)` tests whether x is greater than 0. Since x is 5, the condition holds and the program prints `x is positive.`.
-
-## The `if-else` Statement
-
-if-else extends if so the program can run another block when the condition fails: when the condition is true, the if branch runs; when it is false, the else branch runs.
-
-### Syntax
-
-```c
-if (condition) {
-    // code that runs when the condition is true
-} else {
-    // code that runs when the condition is false
-}
-```
-
-### Another Example
-
-```c
-#include <stdio.h>
-
-int main() {
-    int x = -3;
-
-    if (x > 0) {
-        printf("x is positive.\n");
-    } else {
-        printf("x is not positive.\n");
-    }
-
-    return 0;
-}
-```
-
-In this example x is -3, so `if (x > 0)` fails, the else branch runs, and the program prints `x is not positive.`.
-
-## The `else if` Chain
-
-The else if chain handles several conditions in turn. Building on if-else, it adds more tests until one of them holds. When a condition succeeds, its block runs and the remaining branches are skipped.
-
-### Syntax
-
-```c
-if (condition1) {
-    // runs when condition 1 is true
-} else if (condition2) {
-    // runs when condition 2 is true
-} else {
-    // runs when no condition is true
-}
-```
-
-### One More Example
-
-```c
-#include <stdio.h>
-
-int main() {
-    int x = 0;
-    if (x > 0) {
-        printf("x is positive.\n");
-    } else if (x < 0) {
-        printf("x is negative.\n");
-    } else {
-        printf("x is zero.\n");
-    }
-    return 0;
-}
-```
-
-The program first tests x > 0 and runs the first branch if it holds. Otherwise it tests x < 0, and if that also fails, the else branch runs. Here x is 0, so the output is `x is zero.`.
-
-## The `switch` Statement
-
-switch is a multi-way selection structure, typically used to test discrete values. It matches the value of an expression and runs the corresponding block. switch often replaces a chain of if-else if tests and keeps the code concise and readable.
-
-### Syntax
-
-```c
-switch (expression) {
-case constant1:
-    // runs when the expression equals constant 1
-    break;
-case constant2:
-    // runs when the expression equals constant 2
-    break;
-// any number of case branches may follow
-default:
-    // runs when no case matches
-}
-```
-
-- Expression: usually an integer or character expression.
-- case: matches the expression's value; when they are equal, that branch runs.
-- break: exits the switch. Without break, execution continues into the next case's code even though it did not match; this is called "fall-through".
-- default: runs when no case branch matches.
-
-### Yet Another Example
-
-```c
-#include <stdio.h>
-
-int main() {
-    int day = 3;
-    switch (day) {
-    case 1:
-        printf("Monday\n");
-        break;
-    case 2:
-        printf("Tuesday\n");
-        break;
-    case 3:
-        printf("Wednesday\n");
-        break;
-    case 4:
-        printf("Thursday\n");
-        break;
-    case 5:
-        printf("Friday\n");
-        break;
-    default:
-        printf("Invalid day\n");
-    }
-    return 0;
-}
-```
-
-Here day is 3, so the case 3 code runs and prints "Wednesday". Without the break statements, execution would continue into the following cases until it hit a break or the end of the switch.
-
-## Nested Branches
-
-Branch statements can be nested to express more complex decisions. A common pattern is an if inside another if, or a switch inside a switch.
-
-### An if Inside an if
-
-```c
-#include <stdio.h>
-
-int main() {
-    int x = 10;
-
-    if (x > 0) {
-        if (x % 2 == 0) {
-            printf("x is positive and even.\n");
-        } else {
-            printf("x is positive and odd.\n");
-        }
-    } else {
-        printf("x is not positive.\n");
-    }
-
-    return 0;
-}
-```
-
-The outer if tests whether x is positive; the inner if then tests whether it is even or odd. Here x is 10, so the program prints `x is positive and even.`.
-
-### A switch Inside a switch
-
-```c
-#include <stdio.h>
-
-int main() {
-    int category = 1;
-    int type = 2;
-
-    switch (category) {
-    case 1:
-        switch (type) {
-        case 1:
-            printf("Category 1, Type 1\n");
-            break;
-        case 2:
-            printf("Category 1, Type 2\n");
-            break;
-        default:
-            printf("Unknown type in category 1\n");
-        }
-        break;
-    case 2:
-        printf("Category 2\n");
-        break;
-    default:
-        printf("Unknown category\n");
-    }
-    return 0;
-}
-```
-
-Here two nested switch statements handle a category and the types within it. category is 1 and type is 2, so the program prints `Category 1, Type 2`.
-
-## Pitfalls
-
-### if Without Braces
-
-With if, it is best to always wrap the body in braces {}, even when it is a single statement. This prevents logic errors caused by misleading indentation and similar slips.
-
-A classic: the trap of omitting braces
-
-```c
-#include <stdio.h>
-
-int main() {
-    int x = 5;
-
-    if (x > 0)
-        printf("x is positive.\n");
-    printf("This is outside the if.\n"); // actually always executes
-
-    return 0;
-}
-```
-
-```txt
-x is positive.
-This is outside the if.
-```
-
-Because there are no braces, the second printf does not belong to the if at all: it always runs, which may not be what you intended.
-
-Improved version:
-
-```c
-#include <stdio.h>
-
-int main() {
-    int x = 5;
-
-    if (x > 0) {
-        printf("x is positive.\n");
-        printf("This is inside the if.\n"); // runs only when the condition holds
-    }
-
-    return 0;
-}
-```
-
-```txt
-x is positive.
-This is inside the if.
-```
-
-### "Fall-through" in switch
-
-In a switch, leaving out break causes "fall-through": execution continues into the next case's block even though the expression does not match that case.
-
-```c
-#include <stdio.h>
-
-int main() {
-    int x = 2;
-
-    switch (x) {
-    case 1:
-        printf("One\n");
-    case 2:
-        printf("Two\n"); // after this case, execution falls through into case 3
-    case 3:
-        printf("Three\n");
-        break;
-    default:
-        printf("Unknown\n");
-    }
-
-    return 0;
-}
-```
-
-```txt
-Two
-Three
-```
-
-Improved version:
-
-```c
-#include <stdio.h>
-
-int main() {
-    int x = 2;
-
-    switch (x) {
-    case 1:
-        printf("One\n");
-        break;
-    case 2:
-        printf("Two\n");
-        break;
-    case 3:
-        printf("Three\n");
-        break;
-    default:
-        printf("Unknown\n");
-    }
-
-    return 0;
-}
-```
-
-```txt
-Two
-```
-
-## A Couple of Advanced Tricks
-
-### Simplifying Conditionals with the Ternary Operator
-
-C provides the ternary operator ?: for compact conditionals. Its syntax is:
-
-```c
-condition ? expression1 : expression2;
-```
-
-When the condition is true the result is expression 1; otherwise it is expression 2. It often replaces a simple if-else.
-
-```c
-#include <stdio.h>
-
-int main() {
-    int x = 5;
-    const char *result = (x > 0) ? "positive" : "non-positive";
-
-    printf("x is %s.\n", result);
-    return 0;
-}
-```
-
-### Combining enum with switch
-
-When handling a set of related constants, an enum combined with a switch makes the code clearer and reduces the risk of magic numbers.
-Example: enum plus switch
-
-```c
-#include <stdio.h>
-
-enum Days { MONDAY, TUESDAY, WEDNESDAY, THURSDAY, FRIDAY };
-
-int main() {
-    enum Days today = WEDNESDAY;
-
-    switch (today) {
-    case MONDAY:
-        printf("Today is Monday.\n");
-        break;
-    case TUESDAY:
-        printf("Today is Tuesday.\n");
-        break;
-    case WEDNESDAY:
-        printf("Today is Wednesday.\n");
-        break;
-    case THURSDAY:
-        printf("Today is Thursday.\n");
-        break;
-    case FRIDAY:
-        printf("Today is Friday.\n");
-        break;
-    default:
-        printf("Unknown day.\n");
-    }
-
-    return 0;
-}
-```
-
-### Summary
-
-Used well, branch statements let a program adjust its execution path to its inputs and conditions and take on more complex tasks. As you write, keep readability and maintainability in mind and keep the logic and control flow as simple as you can.
-
-**What matters most in life is not what we choose, but how we treat what we have chosen.**
diff --git a/doc/extra/day03.md b/doc/extra/day03.md
deleted file mode 100644
index 9aacf056..00000000
--- a/doc/extra/day03.md
+++ /dev/null
@@ -1,232 +0,0 @@
-# C From Zero: Day 3 - Say Repetitive Things Only Once
-
-Today's material gives you a clear starting point for loop statements, one of programming's core concepts. Loops let us repeat tasks efficiently; they are both a fundamental skill and the key to writing flexible, dynamic programs.
-Ready? Let's step into the world of C, lift the veil on loop statements, and begin this adventure of logic and creativity. **C, start your engines!**
-
-## The `for` Loop
-
-The for loop is the most common loop of all. Its classic shape looks like this:
-
-```c
-for (initialization; condition; update) {
-    // loop body
-}
-```
-
-### Execution Steps
-
-- Initialization: runs once when the loop starts.
-- Condition: checked before each iteration; if it is true (non-zero) the body runs, if false (zero) the loop exits.
-- Update: runs once after each iteration, typically to update the loop variable.
-
-### An Example
-
-```c
-#include <stdio.h>
-
-int main() {
-    int i;
-    for (i = 0; i < 5; i++) {
-        printf("i = %d\n", i);
-    }
-    return 0;
-}
-```
-
-In this example i starts at 0 and increases by 1 each time around, until i is no longer less than 5. The output is i = 0 through i = 4.
-
-### Notes
-
-A for loop suits cases where the number of iterations is known in advance. You may omit parts of the for header, such as the initialization or the update, but the semicolons must stay; a small sketch follows below.
-
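-A minimal sketch of a `for` header with the initialization and update moved out (such a loop behaves like the equivalent while loop):
-
-```c
-#include <stdio.h>
-
-int main() {
-    int i = 0;           // initialization moved before the loop
-    for (; i < 3;) {     // only the condition remains; both semicolons stay
-        printf("i = %d\n", i);
-        i++;             // update moved into the loop body
-    }
-    return 0;
-}
-```
-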
-### The Special `for(;;)`
-
-for (;;) is a special for construct: it denotes an infinite loop.
-
-#### Breaking It Down
-
-Following the earlier analysis of for, it is easy to see that all three parts of for (;;) are omitted:
-
-- Initialization: no variable is initialized.
-- Condition: no explicit condition; C/C++ treats it as true, producing an infinite loop.
-- Iteration: no update statement.
-
-With no condition to stop it, the program executes the body of for (;;) forever, until something interrupts it, such as break, return, or outside intervention (killing the program, for example).
-
-#### An Example
-
-```c
-#include <stdio.h>
-
-int main() {
-    int count = 0;
-
-    for (;;) { // infinite loop
-        printf("Count: %d\n", count);
-        count++;
-
-        if (count >= 5) {
-            break; // exit the loop once count reaches 5
-        }
-    }
-
-    return 0;
-}
-```
-
-## The while Loop
-
-A while loop repeats as long as its condition is true.
-
-```c
-while (condition) {
-    // loop body
-}
-```
-
-The condition is checked before each iteration. If it is true (non-zero) the body runs; otherwise the loop exits.
-
-### Still Need an Example
-
-```c
-#include <stdio.h>
-
-int main() {
-    int i = 0;
-    while (i < 5) {
-        printf("i = %d\n", i);
-        i++;
-    }
-    return 0;
-}
-```
-
-This mirrors the for example: i starts at 0 and increments each iteration until i is no longer less than 5.
-
-### Notes
-
-If the condition is never changed inside the loop, you can get an endless loop (one that never terminates). Sometimes we create one on purpose: in the form below the condition is always true, so unless the body exits or the process is killed, it runs forever.
-
-```c
-while (true) {
-    // loop body
-}
-
-```
-
-## The do-while Loop
-
-do-while is similar to while, but it executes the body at least once before checking the condition.
-
-```c
-do {
-    // loop body
-} while (condition);
-```
-
-The body runs once first; then the condition is checked. If it is true the loop continues, otherwise it exits.
-
-### Example
-
-```c
-#include <stdio.h>
-
-int main() {
-    int i = 0;
-    do {
-        printf("i = %d\n", i);
-        i++;
-    } while (i < 5);
-    return 0;
-}
-```
-
-Here i again starts at 0; the body runs once before the condition is checked.
-
-### Notes
-
-Because do-while tests the condition after running the body, the body executes once even if the condition is false from the start.
-
-## Classic Loop Applications
-
-### Summation
-
-Compute the sum of all the integers from 1 to 100.
-
-```c
-#include <stdio.h>
-
-int main() {
-    int sum = 0;
-    int i;
-
-    for (i = 1; i <= 100; i++) {
-        sum += i;
-    }
-
-    printf("The sum of 1 to 100 is: %d\n", sum);
-    return 0;
-}
-```
-
-### Factorial
-
-Compute the factorial of a number, e.g. 5! = 5 × 4 × 3 × 2 × 1.
-
-```c
-#include <stdio.h>
-
-int main() {
-    int n = 5; // the number whose factorial we compute
-    int factorial = 1;
-    int i;
-
-    for (i = 1; i <= n; i++) {
-        factorial *= i;
-    }
-
-    printf("The factorial of %d is: %d\n", n, factorial);
-    return 0;
-}
-```
-
-### Fibonacci Sequence
-
-Generate the first 10 Fibonacci numbers.
-
-```c
-#include <stdio.h>
-
-int main() {
-    int n = 10;
-    int first = 0, second = 1, next;
-    int i;
-
-    printf("Fibonacci sequence: ");
-
-    for (i = 0; i < n; i++) {
-        if (i <= 1) {
-            next = i;
-        } else {
-            next = first + second;
-            first = second;
-            second = next;
-        }
-        printf("%d ", next);
-    }
-
-    printf("\n");
-    return 0;
-}
-```
-
-## Summary
-
-- for loop: for a known number of iterations.
-- while loop: condition-driven repetition.
-- do-while loop: a conditional loop that runs at least once.
-
-**Life is a journey of constant exploration; do not retrace the roads already walked, but bravely explore directions unknown.**
diff --git a/doc/extra/day04.md b/doc/extra/day04.md
deleted file mode 100644
index 3444ed59..00000000
--- a/doc/extra/day04.md
+++ /dev/null
@@ -1,833 +0,0 @@
-# C From Zero: Day 4 - What Kind of Thing Are You?
-
-## Introduction
-
-When learning C, understanding the type system and type conversions is essential: they govern how different kinds of data are stored and manipulated. In what follows I will walk through C's type system in a simple, lively way, with examples to deepen understanding.
-
-## The Type System
-
-C is a statically typed language: every variable must have a definite type before it is used. The data type determines what kind of data the variable can hold and how much memory it occupies.
-
-### Basic Data Types
-
-#### Integer Types
-
-- int: the standard integer, usually 4 bytes.
-- short: short integer, usually 2 bytes.
-- long: long integer, usually 4 or 8 bytes.
-- long long: an even longer integer, usually 8 bytes.
-
-```c
-int age = 25;
-short year = 2022;
-long population = 8000000L;
-long long distance = 12345678901234LL;
-```
-
-#### Character Type
-
-char: stores a single character in 1 byte of memory.
-
-```c
-char letter = 'A';
-```
-
-#### Floating Point Types
-
-- float: single precision, usually 4 bytes.
-- double: double precision, usually 8 bytes.
-- long double: extended precision, usually 12 or 16 bytes.
-
-```c
-float pi = 3.14f;
-double e = 2.718281828459045;
-```
-
-#### Enumerated Types
-
-An enumeration is a user-defined type that names a set of integer constants. It exists mainly to describe data better and simplify code.
-
-```c
-enum Day { SUNDAY, MONDAY, TUESDAY, WEDNESDAY, THURSDAY, FRIDAY, SATURDAY };
-enum Day today = WEDNESDAY;
-```
-
-#### The void Type
-
-void means "no type"; it is typically used as a function return type to say the function returns nothing. **Note: some tutorials claim the return type can simply be left off; relying on that is not good practice!**
-
-```c
-void sayHello() {
-    printf("Hello, World!\n");
-}
-```
-
-#### Pointer Types
-
-Pointers are a crucial C type used to store memory addresses, and one of the nightmares of learning C. A pointer's type determines the type of the variable it points to.
-
-```c
-int x = 10;
-int *ptr = &x; // ptr is a pointer to int and stores the address of x
-```
-
-### Type Conversions
-
-Sometimes data of one type must be converted to another. Conversions come in two kinds: `implicit conversions` and `explicit conversions`.
-
-#### Implicit Conversion
-
-When an expression mixes variables of different types, the compiler automatically converts them to a common type to avoid losing data. Put simply, the compiler inserts the conversion code for you, so you do not have to worry about it.
-
-```c
-int a = 5;
-double b = 3.2;
-double result = a + b; // a is implicitly converted to double
-```
-
-Note, however, that not every pair of related types converts implicitly! The compiler generally converts the "narrow" type (such as int) to the "wide" one (such as double) to preserve precision.
-
-#### Explicit Conversion (Casting)
-
-Sometimes implicit conversion is not enough and you must convert a value by hand; this is an explicit conversion, or cast. The syntax is:
-
-```c
-new_type variable = (new_type)expression;
-```
-
-```c
-double pi = 3.14159;
-int truncated_pi = (int)pi; // converts double to int, yielding 3
-```
-
-You may have noticed that pi's fractional part was wiped out: casts can lose data, so use them carefully and make sure you understand how the data may change.
-
-### Some Examples
-
-- Integer to floating point
-
-```c
-#include <stdio.h>
-
-int main() {
-    int apples = 10;
-    double price = 1.5;
-    double total_cost = apples * price; // implicit conversion: apples becomes double
-    printf("Total cost: $%.2f\n", total_cost);
-    return 0;
-}
-```
-
-Here `apples` is an int and `price` a double. When computing `total_cost`, `apples` is implicitly converted to **double** for the floating-point multiplication.
-
-- Casting
-
-```c
-#include <stdio.h>
-
-int main() {
-    double average = 85.6;
-    int rounded_average = (int)average; // cast to int, yielding 85
-    printf("Rounded average: %d\n", rounded_average);
-    return 0;
-}
-```
-
-`average` is a `double`. Casting it to `int` **truncates the fractional part**, so `rounded_average` becomes 85.
-
-- Pointer casts
-
-```c
-#include <stdio.h>
-
-int main() {
-    int a = 42;
-    void *ptr = &a;            // store an int pointer in a void pointer
-    int *int_ptr = (int *)ptr; // cast back to an int pointer
-    printf("Value of a: %d\n", *int_ptr);
-    return 0;
-}
-```
-
-Here `ptr` is a `void` pointer, which can point to any type. We cast it back to an int pointer to access the value of a correctly.
-
-### Traps and Cautions
-
-- Overflow: when converting from a larger type to a smaller one, beware of overflow. Casting a long to short, for example, can produce a wrong result if the long value does not fit in a short.
-
-```c
-long big_number = 1234567890L;
-short small_number = (short)big_number; // may overflow; the result is unpredictable
-```
-
-- Truncation: converting floating point to integer discards the fractional part. Make sure that truncation is what you want.
-
-```c
-double pi = 3.14159;
-int truncated_pi = (int)pi; // yields 3; the fraction is truncated
-```
-
-- Pointer conversions: converting between pointer types demands special care; reinterpreting a pointer as the wrong type can lead to undefined behavior.
-
-```c
-int a = 10;
-double *ptr = (double *)&a; // may access memory incorrectly
-```
-
-## Arrays
-
-So far we have dealt with single variables. What if you need to store a group of variables of the same type? Surely not by defining n separate variables. This is where arrays come in!
-
-Arrays are an essential C concept for handling a collection of same-typed data. They are stored contiguously in memory, which allows large amounts of data to be managed and accessed efficiently.
-
-### One-Dimensional Arrays
-
-#### Definition and Initialization
-
-A one-dimensional array is the basic form; it stores a sequence of values of the same type.
-
-```c
-int arr[5]; // defines an array of 5 ints; the number in brackets is the size
-```
-
-You can initialize an array when defining it:
-
-```c
-int arr[5] = {1, 2, 3, 4, 5}; // initialize an array of 5 elements
-```
-
-If you omit the size, the compiler deduces it from the initializer list:
-
-```c
-int arr[] = {1, 2, 3}; // size deduced as 3
-```
-
-Note that with partial initialization, the remaining elements are set to 0:
-
-```c
-int arr[5] = {1, 2}; // remaining elements are initialized to 0
-```
-
-#### Accessing Elements
-
-Elements are read and written through indices. Here lies a big trap: **indices start at 0**; beginners habitually assume they start at 1 and get it wrong.
-
-```c
-int x = arr[2]; // the third element, value 3
-arr[0] = 10;    // set the first element to 10
-```
-
-### Multi-Dimensional Arrays
-
-Multi-dimensional arrays store tables or matrices; two-dimensional arrays are the most common.
-
-#### Definition & Initialization
-
-A 2D array can be seen as an "array of arrays": each element is itself an array. A simple example:
-
-```c
-int matrix[3][4]; // a 2D array with 3 rows and 4 columns
-```
-
-Likewise, a 2D array is initialized with nested braces:
-
-```c
-int matrix[3][4] = {
-    {1, 2, 3, 4},
-    {5, 6, 7, 8},
-    {9, 10, 11, 12}
-};
-```
-
-You may also omit the inner braces; the compiler assigns values to rows and columns in order, like this:
-
-```c
-int matrix[3][4] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
-```
-
-#### Access
-
-Just as a 1D array needs one index, a 2D array needs two: row and column.
-
-```c
-int value = matrix[1][2]; // row 2, column 3; value 7
-matrix[2][3] = 15;        // set row 3, column 4 to 15
-```
-
-The elements of a multi-dimensional array are stored contiguously in **row-major** order: all of row one, then row two, and so on; a small sketch of this follows.
-
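-A quick sketch that makes the row-major layout visible by printing element addresses (illustrative only; the actual addresses differ from run to run):
-
-```c
-#include <stdio.h>
-
-int main() {
-    int m[2][3] = {{1, 2, 3}, {4, 5, 6}};
-    // Within a row, addresses step by sizeof(int); rows are laid out end to end
-    for (int i = 0; i < 2; i++) {
-        for (int j = 0; j < 3; j++) {
-            printf("&m[%d][%d] = %p\n", i, j, (void *)&m[i][j]);
-        }
-    }
-    return 0;
-}
-```
-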
-### Pointer Arrays and Array Pointers
-
-#### Pointer Arrays
-
-A pointer array is an array whose elements are all pointers. It is often used for arrays of strings or other dynamically allocated data.
-
-```c
-char *strArr[3]; // an array of char pointers
-strArr[0] = "Hello";
-strArr[1] = "World";
-strArr[2] = "C";
-```
-
-You can access the strings by index:
-
-```c
-printf("%s\n", strArr[1]); // prints "World"
-```
-
-#### Array Pointers
-
-An array pointer is a pointer to an entire array; it lets us manipulate the array through the pointer.
-
-```c
-int arr[5] = {1, 2, 3, 4, 5};
-int (*p)[5] = &arr; // p points to an array of 5 ints
-```
-
-Elements can then be reached through the pointer:
-
-```c
-int x = (*p)[2]; // the third element, value 3
-```
-
-#### Pointer Arithmetic
-
-Because an array name converts to a pointer to its first element, pointer arithmetic works:
-
-```c
-int arr[5] = {1, 2, 3, 4, 5};
-int *p = arr;
-
-printf("%d\n", *(p + 2)); // prints 3, i.e. arr[2]
-```
-
-p + 2 moves the pointer p two elements forward; dereferencing with * then reads that element's value.
-
-### Two-Level Arrays (Arrays of Pointers)
-
-A two-level array usually means an array of pointers, especially when working with 2D arrays or string tables.
-
-#### Definition and Usage
-
-The simplest two-level structure is a pointer to a pointer (char ** or int **). It can be used to **dynamically allocate** 2D arrays or to handle arrays of string pointers.
-
-```c
-char *lines[] = {"line1", "line2", "line3"};
-char **p = lines; // p points to the first char pointer
-```
-
-Through the double pointer you can read or modify the pointer array's contents:
-
-```c
-printf("%s\n", p[1]); // prints "line2"
-p[2] = "new line3";   // repoint the third pointer to a different string
-```
-
-#### Dynamic Allocation
-
-Double pointers are well suited to dynamically allocated 2D arrays. A typical example:
-
-```c
-#include <stdio.h>
-#include <stdlib.h>
-
-int main() {
-    int rows = 3; // number of rows
-    int cols = 4; // number of columns
-
-    // allocate an array of pointers, one per row
-    int **matrix = (int **)malloc(rows * sizeof(int *));
-
-    // allocate memory for each row
-    for (int i = 0; i < rows; i++) {
-        matrix[i] = (int *)malloc(cols * sizeof(int));
-    }
-
-    // initialize the matrix
-    for (int i = 0; i < rows; i++) {
-        for (int j = 0; j < cols; j++) {
-            matrix[i][j] = i * cols + j; // assign each element its linear index
-        }
-    }
-
-    // print the matrix
-    printf("Matrix contents:\n");
-    for (int i = 0; i < rows; i++) {
-        for (int j = 0; j < cols; j++) {
-            printf("%d ", matrix[i][j]); // print each element
-        }
-        printf("\n"); // end the row
-    }
-
-    // release the memory
-    for (int i = 0; i < rows; i++) {
-        free(matrix[i]); // free each row
-    }
-    free(matrix); // free the array of row pointers
-
-    return 0; // success
-}
-```
-
-```txt
-0 1 2 3
-4 5 6 7
-8 9 10 11
-```
-
-### Flexible Array Members
-
-Flexible array members, an advanced feature introduced in C99, allow a structure to end with an array of unspecified length. The array has no fixed size; you size it through dynamic memory allocation.
-
-#### Definition & Use
-
-```c
-#include <stdio.h>
-#include <stdlib.h>
-
-struct FlexibleArray {
-    int length;
-    int array[]; // flexible array member
-};
-
-int main() {
-    int n = 5;
-
-    // allocate the struct plus room for the array in one block
-    struct FlexibleArray *fa = malloc(sizeof(struct FlexibleArray) + n * sizeof(int));
-
-    fa->length = n;
-    for (int i = 0; i < n; i++) {
-        fa->array[i] = i * 2; // initialize the array
-    }
-
-    // print the array
-    for (int i = 0; i < fa->length; i++) {
-        printf("%d ", fa->array[i]);
-    }
-    printf("\n");
-
-    free(fa); // release the memory
-
-    return 0;
-}
-```
-
-```txt
-0 2 4 6 8
-```
-
-Note carefully that the flexible member array[] is declared without a size; the actual size is decided by how much memory you allocate!
-
-### Static and Dynamic Arrays
-
-Both concepts have already appeared in many of the examples above; here is the explicit treatment.
-
-#### Static Arrays
-
-A static array gets its memory at **compile time**: its size is fixed when the program is built and never changes during the program's lifetime. Static arrays typically live on the stack and their size is **fixed**.
-
-```c
-int static_array[10]; // an array of fixed size 10
-```
-
-Advantages:
-
-- Fast access, since they live in stack memory.
-- No manual memory management; the compiler handles allocation and release.
-
-Disadvantages:
-
-- The size is fixed; once defined it cannot change at run time.
-- Large arrays can overflow the stack, especially with deeply nested calls.
-
-#### Dynamic Arrays
-
-A dynamic array is allocated at run time, as needed: its size can be chosen while the program runs, and memory can be allocated and released at different stages. Dynamic arrays live on the heap.
-
-```c
-#include <stdio.h>
-#include <stdlib.h>
-
-int main() {
-    int n = 10;
-    int *dynamic_array = (int *)malloc(n * sizeof(int)); // allocate an array of size n
-
-    if (dynamic_array == NULL) {
-        printf("Memory allocation failed!\n");
-        return 1;
-    }
-
-    // use the array
-    for (int i = 0; i < n; i++) {
-        dynamic_array[i] = i * 2;
-    }
-
-    // print the array
-    for (int i = 0; i < n; i++) {
-        printf("%d ", dynamic_array[i]);
-    }
-    printf("\n");
-
-    // release the memory
-    free(dynamic_array);
-
-    return 0;
-}
-```
-
-Advantages:
-
-- The array's size can be chosen and adjusted at run time.
-- Suited to large data sets whose size is only known while running.
-
-Disadvantages:
-
-- Memory must be managed by hand: you must free() what you allocate, or you leak memory.
-- Allocation is somewhat slower than static allocation because it involves the allocator.
-
-### Array Parameters in Functions
-
-When an array is passed to a function, what is actually passed is a pointer to the array, so any modification inside the function affects the original array.
-
-```c
-void modifyArray(int *arr, int size) {
-    for (int i = 0; i < size; i++) {
-        arr[i] = arr[i] * 2;
-    }
-}
-
-int main() {
-    int arr[5] = {1, 2, 3, 4, 5};
-    modifyArray(arr, 5);
-
-    for (int i = 0; i < 5; i++) {
-        printf("%d ", arr[i]); // prints 2 4 6 8 10
-    }
-    return 0;
-}
-```
-
-### A Slightly More Advanced Feature
-
-#### Resizing Arrays with realloc
-
-A dynamic array can be resized at run time with the `realloc` function, letting you grow or shrink the array as needed while the program runs.
-
-```c
-#include <stdio.h>
-#include <stdlib.h>
-
-int main() {
-    int *arr = (int *)malloc(5 * sizeof(int));
-
-    for (int i = 0; i < 5; i++) {
-        arr[i] = i;
-    }
-
-    // grow the array (assigning straight back like this loses the old
-    // pointer if realloc fails; see the safer pattern below)
-    arr = (int *)realloc(arr, 10 * sizeof(int));
-
-    for (int i = 5; i < 10; i++) {
-        arr[i] = i * 2;
-    }
-
-    // print the array
-    for (int i = 0; i < 10; i++) {
-        printf("%d ", arr[i]);
-    }
-    printf("\n");
-
-    free(arr);
-    return 0;
-}
-```
-
-Note that realloc may move the array to a new memory location, so the returned pointer can differ from the original one. If realloc fails it returns NULL and the old block is left untouched.
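-
-Because of that, a common defensive pattern is to keep the old pointer until realloc succeeds. A minimal sketch:
-
-```c
-#include <stdlib.h>
-
-int main(void) {
-    int *arr = malloc(5 * sizeof(int));
-    if (arr == NULL) return 1;
-
-    int *tmp = realloc(arr, 10 * sizeof(int)); // try to grow the block
-    if (tmp == NULL) {
-        free(arr);  // failure: the old block is still valid and must be freed
-        return 1;
-    }
-    arr = tmp;      // success: adopt the (possibly moved) block
-
-    free(arr);
-    return 0;
-}
-```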
-
-### Array Pitfalls
-
-- Out-of-bounds access
-
-Going out of bounds is a common and dangerous C error: accessing memory beyond the array may cause undefined behavior or crash the program. Usually you will be greeted by the notorious Segmentation Fault.
-
-```c
-int arr[5] = {1, 2, 3, 4, 5};
-int value = arr[5]; // error: out of bounds, the classic misremembered-index mistake
-```
-
-- The relationship between pointers and arrays
-
-In most expressions an array name converts to a pointer to its first element. For example:
-
-```c
-int arr[5] = {1, 2, 3, 4, 5};
-int *p = arr; // p points to arr[0]
-```
-
-However, the array name is not a pointer object: it merely converts to one, and it cannot be assigned to; a sketch of the difference follows.
-
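-One way to see the difference: `sizeof` reports the whole array for an array name but only the pointer size for a pointer (a small sketch; exact sizes depend on the platform):
-
-```c
-#include <stdio.h>
-
-int main(void) {
-    int arr[5] = {1, 2, 3, 4, 5};
-    int *p = arr;
-    printf("sizeof(arr) = %zu\n", sizeof(arr)); // typically 20: the whole array
-    printf("sizeof(p)   = %zu\n", sizeof(p));   // typically 4 or 8: just a pointer
-    return 0;
-}
-```
-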
-## Strings
-
-Often we need a variable that can store a whole sentence, and that calls for strings. C's string machinery is an important topic because it covers how text is stored, processed, and manipulated. In C, a string is not an independent data type as in higher-level languages; it is a character array.
-
-### Representing Strings
-
-In C a string is represented as a character array terminated by the **null character (\0)**. The null character marks the end of the string, **so its length is one more than the number of visible characters**.
-
-### Definition and Initialization
-
-```c
-char str1[] = "Hello, World!";
-```
-
-We did not write a \0 here, yet this is still correct: the compiler appends it automatically, so str1 actually holds 14 characters (including \0). This form is the most common and the safest, because it **handles the string terminator for you**.
-
-Of course, if you do not mind the extra work, you can initialize a string character by character, silly as that is:
-
-```c
-char str2[] = {'H', 'e', 'l', 'l', 'o', '\0'};
-```
-
-This form is more flexible, but convenience and flexibility rarely come together: it is easy to get wrong, especially when you forget the \0 by hand, at which point things blow up.
-
-You can also define a pointer to a string:
-
-```c
-char *str3 = "Hello, World!";
-```
-
-Here str3 is a pointer to a string constant. Note that a string defined this way usually lives in read-only memory, so its contents must not be modified; adding the const qualifier makes it explicit that this is a constant.
-
-### Common Operations
-
-The C standard library provides a set of functions for working with strings, mostly declared in the `<string.h>` header. Be aware that many of the string functions C provides are, in fact, unsafe.
-
-#### `strlen`: String Length
-
-strlen returns the length of a string **(not counting the trailing \0)**.
-
-```c
-#include <stdio.h>
-#include <string.h>
-
-int main() {
-    char str[] = "Hello, World!";
-    int length = strlen(str);
-    printf("Length of the string: %d\n", length);
-    return 0;
-}
-```
-
-```txt
-Length of the string: 13
-```
-
-#### `strcpy`: Copying Strings
-
-The strcpy function copies one string into another.
-
-```c
-#include <stdio.h>
-#include <string.h>
-
-int main() {
-    char src[] = "Hello";
-    char dest[20]; // make sure the destination array is big enough, or it will overflow
-    strcpy(dest, src);
-    printf("Copied string: %s\n", dest);
-    return 0;
-}
-```
-
-```txt
-Copied string: Hello
-```
-
-#### `strcat`: Concatenating Strings
-
-The strcat function joins two strings together.
-
-```c
-#include <stdio.h>
-#include <string.h>
-
-int main() {
-    char str1[20] = "Hello";
-    char str2[] = ", World!";
-    strcat(str1, str2); // make sure the destination has room for the joined string, including the \0
-    printf("Concatenated string: %s\n", str1);
-    return 0;
-}
-```
-
-```txt
-Concatenated string: Hello, World!
-```
-
-#### `strcmp`: Comparing Strings
-
-The strcmp function compares two strings lexicographically.
-
-```c
-#include <stdio.h>
-#include <string.h>
-
-int main() {
-    char str1[] = "Apple";
-    char str2[] = "Banana";
-    int result = strcmp(str1, str2);
-
-    if (result < 0) {
-        printf("str1 is less than str2\n");
-    } else if (result > 0) {
-        printf("str1 is greater than str2\n");
-    } else {
-        printf("str1 is equal to str2\n");
-    }
-    return 0;
-}
-```
-
-```txt
-str1 is less than str2
-```
-
-strcmp returns an integer that is negative, zero, or positive depending on whether the first string is lexicographically less than, equal to, or greater than the second.
-
-#### `strchr` and `strstr`: Finding Characters and Substrings
-
-strchr finds the first occurrence of a character in a string.
-strstr finds the first occurrence of a substring in a string.
-
-```c
-#include <stdio.h>
-#include <string.h>
-
-int main() {
-    char str[] = "Hello, World!";
-    char *pos = strchr(str, 'W');
-
-    if (pos != NULL) {
-        printf("Character found at position: %ld\n", pos - str);
-    } else {
-        printf("Character not found.\n");
-    }
-
-    char *substr = strstr(str, "World");
-
-    if (substr != NULL) {
-        printf("Substring found: %s\n", substr);
-    } else {
-        printf("Substring not found.\n");
-    }
-
-    return 0;
-}
-```
-
-```txt
-Character found at position: 7
-Substring found: World!
-```
-
-### String Pitfalls
-
-Much of this has appeared above, but it bears repeating!
-
-#### Character Array Size
-
-When defining a character array, always make sure it is big enough for the string plus the null character \0. For example:
-
-```c
-char str[5] = "Hello"; // error: no room for the terminating \0
-```
-
-This example leads to a buffer overflow, because "Hello" needs space for 6 characters (including \0).
-
-#### String Constants and Mutability
-
-String constants (such as "Hello") usually live in read-only memory, so they cannot be modified through a pointer. Trying to modify a string constant causes a runtime error:
-
-```c
-char *str = "Hello";
-str[0] = 'h'; // error: may crash the program
-```
-
-#### Buffer Overflows
-
-Buffer overflows during string operations are among C's most common and dangerous errors. Always make sure the destination buffer is large enough. For example:
-
-```c
-char str1[10] = "Hello";
-char str2[] = "World!";
-strcat(str1, str2); // error: str1 overflows
-```
-
-This kind of error is undefined behavior and can even crash the program.
-
-### A Few Classic Examples
-
-#### Reversing a String
-
-```c
-#include <stdio.h>
-#include <string.h>
-
-void reverse(char str[]) {
-    int n = strlen(str);
-    for (int i = 0; i < n / 2; i++) {
-        char temp = str[i];
-        str[i] = str[n - i - 1];
-        str[n - i - 1] = temp;
-    }
-}
-
-int main() {
-    char str[] = "Hello, World!";
-    reverse(str);
-    printf("Reversed string: %s\n", str);
-    return 0;
-}
-```
-
-```txt
-Reversed string: !dlroW ,olleH
-```
-
-#### Checking for a Palindrome
-
-```c
-#include <stdio.h>
-#include <string.h>
-
-int isPalindrome(char str[]) {
-    int n = strlen(str);
-    for (int i = 0; i < n / 2; i++) {
-        if (str[i] != str[n - i - 1]) {
-            return 0;
-        }
-    }
-    return 1;
-}
-
-int main() {
-    char str[] = "madam";
-    if (isPalindrome(str)) {
-        printf("The string is a palindrome.\n");
-    } else {
-        printf("The string is not a palindrome.\n");
-    }
-    return 0;
-}
-```
-
-```txt
-The string is a palindrome.
-```
-
-## Summary
-
-Today covered a lot; digest it well, and I hope you can understand it and apply it in your own programming projects.
-
-**Life is a journey of constant exploration, full of choices and change, of joys and challenges, that finally becomes our own unique story.**
diff --git a/doc/extra/day05.md b/doc/extra/day05.md
deleted file mode 100644
index 88ce30f7..00000000
--- a/doc/extra/day05.md
+++ /dev/null
@@ -1,512 +0,0 @@
-# C From Zero: Day 5 - Standard Input
-
-## Introduction
-
-In C, input is the process of taking data from an external source, such as the user or a file, and storing it in program variables. C offers several ways to read input: standard input functions, file input functions, and low-level system input functions.
-
-## Standard Input Functions
-
-### `scanf`
-
-`scanf` is C's most widely used standard input function. It reads formatted data from standard input (usually the keyboard) and stores it into variables.
-
-```c
-int scanf(const char *format, ...);
-```
-
-- format: a format string describing the types of data to read (for example "%d" for an integer, "%f" for a float).
-- Return value: the number of items successfully read; on failure it returns EOF.
-
-An example:
-
-```c
-#include <stdio.h>
-
-int main() {
-    int num;
-    float f;
-
-    printf("Enter an integer and a float: ");
-    scanf("%d %f", &num, &f);
-
-    printf("You entered: %d and %.2f\n", num, f);
-    return 0;
-}
-```
-
-Points to note:
-
-- scanf needs the addresses of the variables (using &) because it must modify their values.
-- scanf skips spaces, newlines, and tabs in the input, and separators such as spaces let it read several items at once.
-- If the input does not match the specified format, the read can fail or produce bad data; the sketch below shows how to check for that.
-
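-Since the return value reports how many items were actually converted, robust code checks it before trusting the variables. A minimal sketch:
-
-```c
-#include <stdio.h>
-
-int main(void) {
-    int num;
-    float f;
-
-    printf("Enter an integer and a float: ");
-    if (scanf("%d %f", &num, &f) != 2) { // fewer than 2 conversions: bad input
-        fprintf(stderr, "Invalid input.\n");
-        return 1;
-    }
-    printf("You entered: %d and %.2f\n", num, f);
-    return 0;
-}
-```
-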
-### `fscanf`
-
-fscanf is like scanf, but it reads formatted data from a file stream.
-
-```c
-int fscanf(FILE *stream, const char *format, ...);
-```
-
-- stream: the file stream pointer identifying the file to read from.
-- The other parameters match scanf.
-
-A basic example:
-
-```c
-#include <stdio.h>
-
-int main() {
-    FILE *file = fopen("input.txt", "r");
-    if (file == NULL) {
-        printf("Error opening file!\n");
-        return 1;
-    }
-
-    int num;
-    fscanf(file, "%d", &num);
-    printf("Number from file: %d\n", num);
-
-    fclose(file);
-    return 0;
-}
-```
-
-Remember: when using fscanf, make sure the file opened successfully and **close it properly** when you are done.
-
-### `sscanf`
-
-As the name suggests, sscanf reads formatted data from a string rather than from standard input or a file.
-
-```c
-int sscanf(const char *str, const char *format, ...);
-```
-
-- str: the string to read data from.
-- The other parameters match scanf.
-
-```c
-#include <stdio.h>
-
-int main() {
-    char input[] = "42 3.14";
-    int num;
-    float f;
-
-    sscanf(input, "%d %f", &num, &f);
-    printf("Parsed values: %d and %.2f\n", num, f);
-
-    return 0;
-}
-```
-
-## Character Input Functions
-
-### `getchar`
-
-getchar reads one character from standard input. It is a simple character input function, typically used to read input character by character.
-
-```c
-int getchar(void);
-```
-
-- Return value: the character read (as an int); at end of input (EOF) it returns EOF.
-
-```c
-#include <stdio.h>
-
-int main() {
-    char c;
-
-    printf("Enter a character: ");
-    c = getchar();
-
-    printf("You entered: %c\n", c);
-    return 0;
-}
-```
-
-Pay special attention: getchar does not skip spaces or newlines; it reads every input character one by one. Its return type is int so that EOF can be distinguished from valid characters, so store the result in an int whenever you compare against EOF.
-
-### `fgetc`
-
-fgetc is like getchar, but it reads one character from a file.
-
-```c
-int fgetc(FILE *stream);
-```
-
-- stream: the file stream to read the character from.
-- Return value: the character read (as an int); at end of file or on error it returns EOF.
-
-```c
-#include <stdio.h>
-
-int main() {
-    FILE *file = fopen("input.txt", "r");
-    if (file == NULL) {
-        printf("Error opening file!\n");
-        return 1;
-    }
-
-    int c; // int, not char, so the EOF comparison below is reliable
-    while ((c = fgetc(file)) != EOF) {
-        putchar(c); // print the character just read
-    }
-
-    fclose(file);
-    return 0;
-}
-```
-
-Notes:
-
-- fgetc reads a file stream character by character, suitable for line-by-line or character-by-character processing of file contents.
-- Each read advances the file position to the next character.
-
-### `getc`
-
-getc behaves the same as fgetc, though it may be implemented differently. getc is typically used to read characters from a file stream.
-
-```c
-int getc(FILE *stream);
-```
-
-- stream: the file stream pointer.
-- Return value: the character read (as an int), or EOF at end of file.
-
-```c
-#include <stdio.h>
-
-int main() {
-    FILE *file = fopen("input.txt", "r");
-    if (file == NULL) {
-        printf("Error opening file!\n");
-        return 1;
-    }
-
-    int c; // int, not char, so the EOF comparison below is reliable
-    while ((c = getc(file)) != EOF) {
-        putchar(c); // print the character just read
-    }
-
-    fclose(file);
-    return 0;
-}
-```
-
-### `ungetc`
-
-ungetc pushes a character back onto an input stream, making it the next character to be read. This is very useful in some parsing scenarios.
-
-```c
-int ungetc(int c, FILE *stream);
-```
-
-- c: the character to push back.
-- stream: the file stream pointer.
-- Return value: the character pushed back, or EOF on failure.
-
-```c
-#include <stdio.h>
-
-int main() {
-    int c;
-    FILE *file = fopen("input.txt", "r");
-
-    if (file == NULL) {
-        printf("Error opening file!\n");
-        return 1;
-    }
-
-    c = fgetc(file);
-    if (c != EOF) {
-        ungetc(c, file); // push the character back onto the stream
-        c = fgetc(file); // read the same character again
-        printf("Character read again: %c\n", c);
-    }
-
-    fclose(file);
-    return 0;
-}
-```
-
-ungetc is only guaranteed to push one character back onto the stream, which can then be read again.
-Before using ungetc, **make sure the pushed-back character is consistent with the state of the stream**.
-
-## Line Input Functions
-
-### `gets` (Not Recommended)
-
-gets reads a line of text from standard input, up to a newline or end of file. Because of its buffer overflow risk, **gets was removed by the C11 standard**.
-
-```c
-char *gets(char *str);
-```
-
-- str: a pointer to the buffer that receives the input string.
-- Return value: the input string pointer (str), or NULL on EOF.
-
-Outdated!!!
-
-```c
-#include <stdio.h>
-
-int main() {
-    char str[100];
-
-    printf("Enter a string: ");
-    gets(str);
-
-    printf("You entered: %s\n", str);
-    return 0;
-}
-```
-
-Because gets does not check whether the input exceeds the buffer size, it can overflow the buffer and create security problems. It is strongly recommended not to use gets; use the safer fgets instead.
-
-### `fgets`
-
-fgets is the recommended way to read a line of input; it avoids buffer overflow problems.
-
-```c
-char *fgets(char *str, int n, FILE *stream);
-```
-
-- str: a pointer to the buffer that receives the input string.
-- n: the maximum number of characters to read (including the terminator).
-- stream: the file stream pointer; stdin means standard input.
-- Return value: str, or NULL on EOF or error.
-
-```c
-#include <stdio.h>
-
-int main() {
-    char str[100];
-
-    printf("Enter a string: ");
-    if (fgets(str, 100, stdin) != NULL) {
-        printf("You entered: %s", str);
-    } else {
-        printf("Error reading input.\n");
-    }
-
-    return 0;
-}
-```
-
-Note that fgets keeps the trailing newline \n. If you do not want it, strip it manually:
-
-```c
-str[strcspn(str, "\n")] = '\0';
-```
-
-fgets is **safe** and well suited to reading input lines or lines of text from files.
-
-## Low-Level Input Functions
-
-C also provides some low-level input functions that interact directly with the operating system, typically for lower-level input/output work.
-
-### `read`
-
-read is a UNIX system call that reads raw data from a file descriptor. It gives finer control over low-level file operations and is commonly used in systems programming.
-
-```c
-ssize_t read(int fd, void *buf, size_t count);
-```
-
-- fd: the file descriptor identifying the file or input source to read from.
-- buf: a pointer to the buffer that receives the data.
-- count: the maximum number of bytes to read.
-- Return value: the number of bytes read; 0 at EOF; -1 on error.
-
-```c
-#include <stdio.h>
-#include <unistd.h>
-#include <fcntl.h>
-
-int main() {
-    char buffer[100];
-    int fd = open("input.txt", O_RDONLY);
-
-    if (fd == -1) {
-        perror("Error opening file");
-        return 1;
-    }
-
-    ssize_t bytesRead = read(fd, buffer, sizeof(buffer) - 1);
-
-    if (bytesRead == -1) {
-        perror("Error reading file");
-        close(fd);
-        return 1;
-    }
-
-    buffer[bytesRead] = '\0'; // terminate the buffer so it is a valid string
-    printf("Read from file: %s\n", buffer);
-
-    close(fd);
-    return 0;
-}
-```
-
-Being close to the operating system has a price: read does no line handling or buffer management for you, so you manage the input data yourself.
-
-### `getch` and `getche`
-
-These two functions read a single character from standard input without waiting for Enter. They are typically used for keyboard handling, especially in console applications.
-
-- getch: reads one character without echoing it to the screen.
-- getche: reads one character and echoes it to the screen.
-
-They are not part of the standard C library; on Windows they are provided by `conio.h`.
-
-```c
-int getch(void);
-int getche(void);
-```
-
-```c
-#include <stdio.h>
-#include <conio.h>
-
-int main() {
-    char c;
-
-    printf("Press a key: ");
-    c = getch(); // read a character without echoing it
-    printf("\nYou pressed: %c\n", c);
-
-    printf("Press another key: ");
-    c = getche(); // read a character and echo it
-    printf("\nYou pressed: %c\n", c);
-
-    return 0;
-}
-```
-
-**Windows only!**
-
-## File Input Functions
-
-C also provides functions for reading data from files.
-
-### `fread`
-
-fread reads raw blocks of data from a file; it is typically used for reading binary files.
-
-```c
-size_t fread(void *ptr, size_t size, size_t nmemb, FILE *stream);
-```
-
-- ptr: a pointer to the buffer that receives the data.
-- size: the size of each block in bytes.
-- nmemb: the number of blocks to read.
-- stream: the file stream pointer.
-- Return value: the number of complete blocks successfully read.
-
-```c
-#include <stdio.h>
-
-int main() {
-    FILE *file = fopen("data.bin", "rb");
-    if (file == NULL) {
-        printf("Error opening file!\n");
-        return 1;
-    }
-
-    int data[5];
-    size_t itemsRead = fread(data, sizeof(int), 5, file); // number of ints read, not bytes
-
-    for (size_t i = 0; i < itemsRead; i++) {
-        printf("data[%zu] = %d\n", i, data[i]);
-    }
-
-    fclose(file);
-    return 0;
-}
-```
-
-- fread reads raw data blocks rather than formatted data; it suits binary files and custom file formats.
-- Make sure the buffer is large enough for the data, and check the return value to confirm the read succeeded.
-
-### `getline`
-
-getline reads an entire line, including the newline, from a file or standard input. It dynamically allocates its buffer to fit the size of the line being read.
-
-```c
-ssize_t getline(char **lineptr, size_t *n, FILE *stream);
-```
-
-- lineptr: a pointer to the buffer pointer; if it points to NULL, getline allocates the buffer itself.
-- n: a pointer to the buffer size.
-- stream: the file stream pointer.
-- Return value: the number of characters read (including the newline), or -1 at EOF.
-
-```c
-#include <stdio.h>
-#include <stdlib.h>
-
-int main() {
-    char *line = NULL;
-    size_t len = 0;
-    ssize_t read;
-
-    FILE *file = fopen("input.txt", "r");
-    if (file == NULL) {
-        printf("Error opening file!\n");
-        return 1;
-    }
-
-    while ((read = getline(&line, &len, file)) != -1) {
-        printf("Retrieved line of length %zd: %s", read, line);
-    }
-
-    free(line);
-    fclose(file);
-    return 0;
-}
-```
-
-**getline is a POSIX extension, not part of standard C.**
-
-## Safer Input (Using `scanf_s` as an Example)
-
-scanf_s is the safer variant of scanf. By adding explicit control over input lengths, it prevents input data from exceeding the target buffer and so reduces the risk of buffer overflows.
-
-```c
-int scanf_s(const char *format, ...);
-```
-
-- format: the format string, similar to scanf's, specifying the input data types.
-- ...: the variable arguments giving the destination addresses. For some conversions (such as strings or character arrays), an extra argument giving the target buffer size is required.
-
-```c
-#include <stdio.h>
-
-int main() {
-    char buffer[10];
-    int num;
-
-    printf("Enter a number and a string: ");
-    scanf_s("%d %9s", &num, buffer, (unsigned)sizeof(buffer));
-
-    printf("You entered: %d and %s\n", num, buffer);
-    return 0;
-}
-```
-
-In this example, %d reads an integer and %9s reads a string of at most 9 characters. The size argument after the string gives the buffer capacity, ensuring no more than it can hold is stored.
-When reading strings, scanf_s **requires** this extra argument specifying the destination buffer size.
-
-## Summary
-
-C provides many kinds of input to meet different needs, from simple standard input functions to more involved file input functions, each with its own use cases and caveats:
-
-- Formatted input: scanf, fscanf, sscanf read formatted data, suitable for structured input from standard input or files.
-- Character input: getchar, fgetc, getc read one character at a time, suitable for per-character processing of input or file contents.
-- Line input: fgets and getline read whole lines; fgets is safe, and getline handles lines of dynamic length.
-- Low-level input: read, getch, getche serve systems programming and console input where you talk directly to the OS or hardware.
-- File input: fread reads binary data blocks, good for binary or custom-format files.
-
-Master these input functions and you can handle all kinds of input data flexibly in your C programs. AwA
diff --git a/doc/extra/rrii.md b/doc/extra/rrii.md
deleted file mode 100644
index 37e88f86..00000000
--- a/doc/extra/rrii.md
+++ /dev/null
@@ -1,263 +0,0 @@
-# 临时释放锁:Antilock 模式 (RRII)
-
-## 前言
-
-在多线程编程中,正确管理线程同步是确保程序稳定性和性能的关键。C++ 提供了多种工具来帮助开发者实现线程同步,例如 `std::mutex`、`std::lock_guard` 和 `std::unique_lock` 等。这些工具虽然强大,但在某些复杂场景下,可能需要更灵活的锁管理方式。比如说有的时候我在加锁后,在局部代码中需要释放锁,然后后续运行又需要加锁,这个时候我们虽然可以通过`unlock`和`lock`组合完成,但是代码变长后难免会出现遗忘的情况,从而产生错误。那么,本文将介绍一种名为“Antilock”的模式,它能够在需要时暂时释放锁,并在操作完成后自动重新获取锁,从而避免潜在的死锁和遗忘问题。
-
-本文的理念来自 [Antilock 模式](https://devblogs.microsoft.com/oldnewthing/20240814-00/?p=110129)
-
-## 概念
-
-### 互斥锁与 RAII 模式
-
-在 C++ 中,互斥锁(mutex)用于保护共享资源,防止多个线程同时访问而导致数据不一致。为了简化锁的管理,C++ 标准库引入了 RAII(Resource Acquisition Is Initialization)模式。RAII 模式通过在对象的构造函数中获取资源,在析构函数中释放资源,确保资源管理的安全性。
-例如,`std::lock_guard` 就是一个常用的 RAII 类型,它会在构造时自动锁定互斥锁,并在析构时自动解锁:
-
-```cpp
-std::mutex mtx;
-{
- std::lock_guard guard(mtx);
- // 这里的代码块在锁的保护下执行
-} // 离开作用域时,锁自动释放
-```
-
-### Antilock 模式
-
-Antilock 模式是一种反作用的锁管理策略,它的作用是暂时释放一个已经锁定的互斥锁,并在特定操作完成后重新获取该锁。这种模式特别适合需要在锁的保护下执行部分操作,但又需要在某些时候释放锁以避免死锁的场景。
-
-那么这种机制是否可以称为`RRII`呢?即`Resource Release Is Initialization`!
-
-## 实现
-
-### 基本实现
-
-我们首先定义一个简单的 Antilock 模板类,它可以接受一个互斥锁对象,在构造时解锁该互斥锁,并在析构时重新锁定它:
-
-```cpp
-template
-class Antilock {
-public:
- Antilock() = default;
-
- explicit Antilock(Mutex& mutex)
- : m_mutex(std::addressof(mutex)) {
- if (m_mutex) { // 这里做一个检查如果互斥锁有效
- m_mutex->unlock();
- std::cout << "[Antilock] Lock released.\n";
- }
- }
-
- ~Antilock() {
- if (m_mutex) {
- m_mutex->lock();
- std::cout << "[Antilock] Lock reacquired.\n";
- }
- }
-
-private:
- Mutex* m_mutex = nullptr; // 指向互斥锁的指针
-};
-```
-
-这个简单的 `Antilock` 类在构造时解锁传入的 `Mutex` 对象,并在析构时重新锁定它。通过这种方式,可以安全地在一个作用域内暂时释放锁,执行其他操作。
-
-### 支持 Guard 的扩展
-
-在某些情况下,我们可能不希望 `Antilock` 直接操作互斥锁,而是通过一个管理锁的 `Guard` 对象来间接操作锁。例如,我们可以使用 `std::unique_lock` 作为 Guard,这样可以更加灵活。
-
-下面是一个简单的例子:
-
-```cpp
-template
-class Antilock {
-public:
-Antilock() = default;
-
- explicit Antilock(Guard& guard)
- : m_mutex(guard.mutex()) { // 使用 Guard 的 mutex 方法获取互斥锁
- if (m_mutex) {
- m_mutex->unlock();
- std::cout << "[Antilock] Lock released.\n";
- }
- }
-
- ~Antilock() {
- if (m_mutex) {
- m_mutex->lock();
- std::cout << "[Antilock] Lock reacquired.\n";
- }
- }
-
-private:
- typename Guard::mutex_type* m_mutex = nullptr;
-};
-```
-
-在这个版本中,`Antilock` 使用 `Guard` 对象的 `mutex()` 方法来获取互斥锁的指针,从而实现对锁的间接管理。
-
-## 使用场景
-
-### 调用外部操作
-
-假设我们有一个共享资源需要在多线程环境下访问,同时我们需要在访问资源的过程中调用一个外部操作。由于外部操作可能会导致长时间等待或死锁,我们希望在执行外部操作时释放锁,操作完成后重新获取锁。
-以下是一个使用 `Antilock` 模式的示例:
-
-```cpp
-class Resource {
-public:
- void DoSomething() {
- std::unique_lock lock(m_mutex);
- std::cout << "[Resource] Doing something under lock.\n";
- // 临时释放锁,执行外部操作
- Antilock> antilock(lock);
- m_data = ExternalOperation();
- std::cout << "[Resource] Finished doing something.\n";
- }
-
-private:
- int ExternalOperation() {
- std::this_thread::sleep_for(std::chrono::seconds(1)); // 模拟耗时操作
- std::cout << "[External] External operation completed.\n";
- return 42; // 返回一个结果
- }
-
- std::mutex m_mutex; // 保护共享数据的互斥锁
- int m_data = 0; // 共享数据
-};
-```
-
-在这个示例中,`DoSomething` 方法使用 `std::unique_lock` 锁定互斥锁,然后使用 `Antilock` 在执行外部操作时释放锁。外部操作完成后,锁会被自动重新获取。
-
-### 多线程
-
-为了更好地展示 `Antilock` 模式的优势,我们可以创建多个线程同时访问共享资源
-
-```cpp
-int main() {
-Resource resource;
-
- // 创建多个线程来访问共享资源
- std::vector threads;
- for (int i = 0; i < 5; ++i) {
- threads.emplace_back(&Resource::DoSomething, &resource);
- }
-
- // 等待所有线程完成
- for (auto& thread : threads) {
- thread.join();
- }
-
- return 0;
-
-}
-```
-
-在这个例子中,我们创建了多个线程,每个线程都会调用 Resource 对象的 DoSomething 方法。Antilock 模式确保了在外部操作期间,锁会被暂时释放,从而避免了可能的死锁。
-
-## 完整例子
-
-[CE链接](https://godbolt.org/z/7nEqT8n5M)
-
-```cpp
-#include <atomic>
-#include <chrono>
-#include <condition_variable>
-#include <exception>
-#include <iostream>
-#include <mutex>
-#include <thread>
-#include <vector>
-
-// Antilock template class with Guard support
-template <typename Guard>
-class Antilock {
-public:
-    Antilock() = default;
-
-    explicit Antilock(Guard& guard)
-        : m_mutex(guard.mutex()) {  // obtain the mutex via the guard's mutex() method
-        if (m_mutex) {
-            m_mutex->unlock();
-            std::cout << "[Antilock] Lock released.\n";
-        }
-    }
-
-    ~Antilock() {
-        if (m_mutex) {
-            m_mutex->lock();
-            std::cout << "[Antilock] Lock reacquired.\n";
-        }
-    }
-
-private:
-    typename Guard::mutex_type* m_mutex = nullptr;  // pointer to the mutex
-};
-
-// A resource class whose shared data is protected and coordinated
-// by a mutex and a condition variable
-class Resource {
-public:
-    void DoSomething() {
-        std::unique_lock<std::mutex> lock(m_mutex);
-
-        // Wait until no other thread is running the external operation
-        m_condVar.wait(lock, [this] { return !m_busy; });
-        m_busy = true;
-
-        std::cout << "[Resource] Doing something under lock.\n";
-
-        try {
-            // Temporarily release the lock while the external operation runs
-            Antilock<std::unique_lock<std::mutex>> antilock(lock);
-            m_data.store(ExternalOperation());
-        } catch (const std::exception& ex) {
-            std::cerr << "[Error] Exception caught: " << ex.what() << "\n";
-            // Even when an exception is thrown, the Antilock destructor has
-            // already reacquired the lock, keeping its state consistent
-        }
-
-        m_busy = false;
-        std::cout << "[Resource] Finished doing something.\n";
-        m_condVar.notify_all();  // wake up other waiting threads
-    }
-
-private:
-    int ExternalOperation() {
-        std::this_thread::sleep_for(std::chrono::seconds(1));  // simulate a slow operation
-        std::cout << "[External] External operation completed.\n";
-        return 42;  // return a result
-    }
-
-    std::mutex m_mutex;                 // mutex protecting the shared state
-    std::condition_variable m_condVar;  // condition variable for thread coordination
-    std::atomic<int> m_data{0};         // thread-safe shared data
-    bool m_busy = false;                // true while the external operation runs (guarded by m_mutex)
-};
-
-// Main function
-int main() {
-    Resource resource;
-
-    // Create several threads that access the shared resource
-    std::vector<std::thread> threads;
-    for (int i = 0; i < 5; ++i) {
-        threads.emplace_back(&Resource::DoSomething, &resource);
-    }
-
-    // Wait for all threads to finish
-    for (auto& thread : threads) {
-        thread.join();
-    }
-
-    return 0;
-}
-```
-
-## Summary
-
-### Advantages
-
-- Flexibility: the Antilock pattern lets you temporarily release a lock in order to perform operations that might otherwise deadlock.
-- Automatic management: thanks to RAII, Antilock reacquires the lock when the scope ends, with no manual bookkeeping of lock state.
-
-### Caveats
-
-- Lock-state complexity: Antilock can make the lock state harder to reason about, especially with nested or recursive locking.
-- Exception handling: when implementing the Antilock pattern, make sure the lock is managed correctly if an exception is thrown, to avoid ending up in an inconsistent locking state.
-
-Hopefully this article has given you a deeper understanding of the Antilock pattern and how to apply it to lock-management problems in real projects. If you spot any problems, please point them out in the comments; this is my first post, so go easy on me!
diff --git a/doc/loading.drawio b/doc/loading.drawio
index 5bc0f8c9..9a3e311f 100644
--- a/doc/loading.drawio
+++ b/doc/loading.drawio
@@ -1,6 +1,6 @@
 (XML content of the modified .drawio diagram was stripped during extraction and is not shown)
diff --git a/doc/task/camera/smart_epxosure.md b/doc/task/camera/smart_epxosure.md
index e974dfc8..c5581f4c 100644
--- a/doc/task/camera/smart_epxosure.md
+++ b/doc/task/camera/smart_epxosure.md
@@ -1,43 +1,97 @@
-# SmartExposure Class Overview
+# `SmartExposure` 类概述
+
+`SmartExposure` 类是 N.I.N.A.(Nighttime Imaging 'N' Astronomy)软件中的关键组件,专门用于处理成像序列。以下是其逻辑、事件处理、验证、克隆等方面的详细描述。
The `SmartExposure` class is a key component within the N.I.N.A. (Nighttime Imaging 'N' Astronomy) software, designed to handle imaging sequences. Below is a detailed breakdown of its logic, event handling, validation, cloning, and more.
+---
+
+## 关键功能
+
## Key Functionalities
-- **Deserialization Initialization**
+### **反序列化初始化**
+
+### **Deserialization Initialization**
+
+- 在反序列化过程中清除所有序列项、条件和触发器,以确保一个干净的开始。
+- Clears all sequence items, conditions, and triggers during the deserialization process to ensure a clean start.
+
+---
+
+### **构造函数**
+
+### **Constructor**
+
+- 接受依赖项,如 `profileService`、`cameraMediator` 和 `filterWheelMediator`。
+- Initializes with dependencies like `profileService`, `cameraMediator`, and `filterWheelMediator`.
+
+- 初始化核心组件,包括 `SwitchFilter`、`TakeExposure`、`LoopCondition` 和 `DitherAfterExposures`。
+- Initializes core components including `SwitchFilter`, `TakeExposure`, `LoopCondition`, and `DitherAfterExposures`.
+
+---
+
+### **事件处理**
+
+### **Event Handling**
+
+- 监控 `SwitchFilter` 属性的变化。如果滤镜在序列创建或执行过程中发生变化,则进度会被重置。
+- Monitors changes to the `SwitchFilter` property. If the filter changes during sequence creation or execution, the progress is reset.
- - Clears all sequence items, conditions, and triggers during the deserialization process to ensure a clean start.
+---
-- **Constructor**
+### **错误处理**
- - Accepts dependencies like `profileService`, `cameraMediator`, and `filterWheelMediator`.
- - Initializes core components including `SwitchFilter`, `TakeExposure`, `LoopCondition`, and `DitherAfterExposures`.
+### **Error Handling**
-- **Event Handling**
+- 包含管理错误行为的属性 (`ErrorBehavior`) 和重试尝试次数 (`Attempts`),确保在执行过程中进行稳健的错误管理。
+- Incorporates properties for managing error behavior (`ErrorBehavior`) and retry attempts (`Attempts`), ensuring robust error management during execution.
- - Monitors changes to the `SwitchFilter` property. If the filter changes during sequence creation or execution, the progress is reset.
+---
-- **Error Handling**
+### **验证**
- - Incorporates properties for managing error behavior (`ErrorBehavior`) and retry attempts (`Attempts`), ensuring robust error management during execution.
+### **Validation**
-- **Validation**
+- `Validate` 方法检查所有内部组件的配置,并返回一个布尔值,指示该序列是否可以执行。
+- The `Validate` method checks the configuration of all internal components and returns a boolean indicating whether the sequence is valid for execution.
- - The `Validate` method checks the configuration of all internal components and returns a boolean indicating whether the sequence is valid for execution.
+---
-- **Cloning**
+### **克隆**
- - Provides a `Clone` method to create a deep copy of the `SmartExposure` instance, including all its associated components.
+### **Cloning**
-- **Duration Estimation**
+- 提供 `Clone` 方法来创建 `SmartExposure` 实例的深度拷贝,包括所有关联的组件。
+- Provides a `Clone` method to create a deep copy of the `SmartExposure` instance, including all its associated components.
- - Calculates the estimated duration for the sequence based on the exposure settings.
+---
-- **Interrupt Handling**
- - If an interruption occurs, it is rerouted to the parent sequence for consistent behavior.
+### **持续时间估算**
+
+### **Duration Estimation**
+
+- 根据曝光设置计算序列的预计持续时间。
+- Calculates the estimated duration for the sequence based on the exposure settings.
+
+---
+
+### **中断处理**
+
+### **Interrupt Handling**
+
+- 如果发生中断,它会被重新引导到父序列以保持一致的行为。
+- If an interruption occurs, it is rerouted to the parent sequence for consistent behavior.
+
+---
+
+## 流程图
## Flowchart
+以下是 `SmartExposure` 类执行过程中关键步骤的流程图。
+Below is a flowchart outlining the key steps in the execution process of the `SmartExposure` class.
+
```mermaid
graph TD;
A[Start] --> B{OnDeserializing};
@@ -69,3 +123,17 @@ graph TD;
T --> U[Interrupt Handling];
U --> V[End];
```
+
+---
+
+### 解释
+
+### Explanation
+
+1. **OnDeserializing**: 在反序列化过程中,清除所有项目、条件和触发器,确保干净的状态。
+2. **Constructor Called**: 构造函数初始化各个核心组件。
+3. **Event Handling**: 监听属性变化,如果 `SwitchFilter` 改变,进度将被重置。
+4. **Validate**: 验证序列的有效性,确保其可以执行。
+5. **Cloning**: 创建 `SmartExposure` 的深度克隆副本。
+6. **Estimate Duration**: 计算执行过程的预计时间。
+7. **Interrupt Handling**: 如果发生中断,将其引导至父序列。
+
+1. **OnDeserializing**: During deserialization, all items, conditions, and triggers are cleared to guarantee a clean state.
+2. **Constructor Called**: The constructor initializes the core components.
+3. **Event Handling**: Property changes are monitored; if `SwitchFilter` changes, the progress is reset.
+4. **Validate**: Validates the sequence to ensure it can be executed.
+5. **Cloning**: Creates a deep clone of the `SmartExposure` instance.
+6. **Estimate Duration**: Calculates the estimated duration of the execution process.
+7. **Interrupt Handling**: If an interruption occurs, it is rerouted to the parent sequence.
diff --git a/doc/task/guider/dither.md b/doc/task/guider/dither.md
index 278a14be..981c5919 100644
--- a/doc/task/guider/dither.md
+++ b/doc/task/guider/dither.md
@@ -1,12 +1,30 @@
-# Dither
+# `Dither` 类
-The `Dither` class in the N.I.N.A. application is designed to execute the dithering process during an astronomical imaging session. Dithering is a technique used to slightly move the telescope between exposures to reduce the impact of fixed pattern noise in the final stacked image. This class is essential for achieving higher-quality images by ensuring that noise patterns do not align across exposures.
+`Dither` 类在 N.I.N.A.(Nighttime Imaging 'N' Astronomy)应用程序中用于在天文摄影会话期间执行抖动过程。抖动是一种技术,通过在曝光之间轻微移动望远镜来减少固定模式噪声对最终堆叠图像的影响。该类确保抖动过程的执行,从而提高图像质量,防止噪声模式在多个曝光中对齐。
+
+The `Dither` class in the N.I.N.A. (Nighttime Imaging 'N' Astronomy) application is designed to execute the dithering process during an astronomical imaging session. Dithering is a technique used to slightly move the telescope between exposures to reduce the impact of fixed pattern noise in the final stacked image. This class is essential for achieving higher-quality images by ensuring that noise patterns do not align across exposures.
+
+---
+
+## 类概述
## Class Overview
+### 命名空间
+
### Namespace
+- **命名空间:** `NINA.Sequencer.SequenceItem.Guider`
- **Namespace:** `NINA.Sequencer.SequenceItem.Guider`
+
+- **依赖项:**
+
+ - `NINA.Core.Model`
+ - `NINA.Sequencer.Validations`
+ - `NINA.Equipment.Interfaces.Mediator`
+ - `NINA.Core.Locale`
+ - `NINA.Profile.Interfaces`
+
- **Dependencies:**
- `NINA.Core.Model`
- `NINA.Sequencer.Validations`
@@ -14,6 +32,10 @@ The `Dither` class in the N.I.N.A. application is designed to execute the dither
- `NINA.Core.Locale`
- `NINA.Profile.Interfaces`
+---
+
+### 类声明
+
### Class Declaration
```csharp
@@ -26,14 +48,29 @@ The `Dither` class in the N.I.N.A. application is designed to execute the dither
public class Dither : SequenceItem, IValidatable
```
+---
+
+### 类属性
+
### Class Properties
+- **guiderMediator**: 负责与引导硬件的通信,以执行抖动命令。
- **guiderMediator**: Handles communication with the guider hardware to execute the dithering commands.
+
+- **profileService**: 提供对活动配置文件设置的访问,包括引导器配置。
- **profileService**: Provides access to the active profile settings, including guider configuration.
+
+- **Issues**: 在类验证过程中发现的问题列表,尤其是与引导器连接状态相关的问题。
- **Issues**: A list of issues found during the validation of the class, particularly related to the connection status of the guider.
+---
+
+### 构造函数
+
### Constructor
+构造函数初始化 `Dither` 类,并依赖于引导器中介器和配置文件服务,确保在执行抖动过程中所需的组件可用。
+
The constructor initializes the `Dither` class with dependencies on the guider mediator and profile service, ensuring the necessary components are available for executing the dithering process.
```csharp
@@ -41,17 +78,35 @@ The constructor initializes the `Dither` class with dependencies on the guider m
public Dither(IGuiderMediator guiderMediator, IProfileService profileService)
```
+---
+
+### 关键方法
+
### Key Methods
+- **Execute(IProgress<ApplicationStatus> progress, CancellationToken token)**: 该方法通过向引导器发出抖动命令来执行抖动过程。
- **Execute(IProgress<ApplicationStatus> progress, CancellationToken token)**: This method executes the dithering process by issuing the dither command to the guider.
+
+- **Validate()**: 在尝试抖动之前,检查引导器是否连接并正常工作。
- **Validate()**: Checks if the guider is connected and operational before attempting to dither.
+
+- **AfterParentChanged()**: 当父级序列项发生更改时,验证引导器的连接状态。
- **AfterParentChanged()**: Validates the connection status of the guider when the parent sequence item changes.
+
+- **GetEstimatedDuration()**: 返回抖动过程的预计持续时间,基于配置文件中的引导器设置。
- **GetEstimatedDuration()**: Returns an estimated duration for the dithering process, based on the profile's guider settings.
+
+- **Clone()**: 创建 `Dither` 对象的深拷贝。
- **Clone()**: Creates a deep copy of the `Dither` object.
+---
+
+### 流程图:执行过程
+
### Flowchart: Execution Process
-Below is a flowchart that outlines the key steps in the `Execute` method of the `Dither` class.
+以下是 `Dither` 类中 `Execute` 方法的关键步骤流程图。
+Below is a flowchart outlining the key steps in the `Execute` method of the `Dither` class.
```mermaid
flowchart TD
@@ -63,29 +118,70 @@ flowchart TD
F --> G[End Execution]
```
+---
+
+### 流程图解释
+
### Flowchart Explanation
+1. **引导器是否连接?**:首先检查引导器是否已连接且准备执行命令。
+
+ - **否:** 如果引导器未连接,会记录一个连接问题,并中止执行。
+ - **是:** 如果引导器已连接,继续执行下一步。
+
1. **Is Guider Connected?:** The process begins by checking whether the guider is connected and ready to execute commands.
+
- **No:** If the guider is not connected, an issue is logged indicating the connection problem, and the execution is aborted.
- **Yes:** If the guider is connected, the process continues.
-2. **Issue Dither Command to Guider:** The `Dither` class sends a command to the guider to perform the dithering operation.
-3. **Await Dither Completion:** The system waits for the guider to complete the dithering process.
-4. **End Execution:** The dithering process concludes, and control is returned to the sequence executor.
+
+1. **向引导器发出抖动命令**:`Dither` 类向引导器发送命令,执行抖动操作。
+1. **等待抖动完成**:系统等待引导器完成抖动过程。
+1. **结束执行**:抖动过程结束后,控制返回给序列执行器。
+
+1. **Issue Dither Command to Guider:** The `Dither` class sends a command to the guider to perform the dithering operation.
+1. **Await Dither Completion:** The system waits for the guider to complete the dithering process.
+1. **End Execution:** The dithering process concludes, and control is returned to the sequence executor.
+
+---
+
+### 方法详细描述
### Detailed Method Descriptions
+#### `Execute` 方法
+
#### `Execute` Method
+`Execute` 方法负责向引导器发出抖动命令。它依赖于 `guiderMediator` 来与引导硬件通信。该方法确保在尝试抖动之前引导器已连接并准备就绪,从而防止运行时错误。
+
The `Execute` method is responsible for issuing the dither command to the guider. It relies on the `guiderMediator` to communicate with the guider hardware. The method ensures that the guider is connected and ready before attempting to dither, thus preventing runtime errors.
+---
+
+#### `Validate` 方法
+
#### `Validate` Method
+`Validate` 方法检查引导器的连接状态。如果引导器未连接,方法会在 `Issues` 列表中添加问题,用户可以查看并解决问题。此验证步骤对于确保抖动过程能够无误执行至关重要。
+
The `Validate` method checks the connection status of the guider. If the guider is not connected, the method adds an issue to the `Issues` list, which can be reviewed by the user to troubleshoot the problem. This validation step is critical for ensuring that the dithering process can be executed without errors.
+---
+
+#### `AfterParentChanged` 方法
+
#### `AfterParentChanged` Method
+每当父级序列项发生更改时,都会调用 `AfterParentChanged` 方法。此方法会重新验证 `Dither` 项,以确保考虑到序列上下文中的任何变化,尤其是设备连接性方面的变化。
+
The `AfterParentChanged` method is called whenever the parent sequence item is changed. It triggers a re-validation of the `Dither` item to ensure that any changes in the sequence context are taken into account, particularly in terms of equipment connectivity.
+---
+
+#### `GetEstimatedDuration` 方法
+
#### `GetEstimatedDuration` Method
+`GetEstimatedDuration` 方法返回完成抖动过程的预计时间。该估算基于活动配置文件中引导器设置的 `SettleTimeout` 值,提供抖动所需时间的大致时间线。
+
The `GetEstimatedDuration` method returns the estimated time required to complete the dithering process. This estimate is based on the `SettleTimeout` value specified in the active profile's guider settings, providing a rough timeline for how long the dithering will take.
diff --git a/doc/task/guider/start_guiding.md b/doc/task/guider/start_guiding.md
index e404da0a..374ae30f 100644
--- a/doc/task/guider/start_guiding.md
+++ b/doc/task/guider/start_guiding.md
@@ -1,18 +1,39 @@
-# StartGuiding
+# `StartGuiding` 类
-The `StartGuiding` class in the N.I.N.A. application is designed to initiate the guiding process during an astronomical imaging session. Guiding is a critical process in astrophotography, where a separate guider camera or system is used to keep the telescope precisely aligned with the target object. This class ensures that the guiding process starts correctly and optionally forces a new calibration.
+`StartGuiding` 类在 N.I.N.A.(Nighttime Imaging 'N' Astronomy)应用程序中用于启动引导过程,这是天文摄影会话中的关键步骤。在天文摄影中,使用一个独立的引导相机或系统来保持望远镜与目标物体的精确对齐。该类确保引导过程正确启动,并且可以选择强制进行新的校准。
+
+The `StartGuiding` class in the N.I.N.A. (Nighttime Imaging 'N' Astronomy) application is used to initiate the guiding process, which is a critical step in an astronomical imaging session. In astrophotography, a separate guiding camera or system is used to keep the telescope precisely aligned with the target object. This class ensures that the guiding process starts correctly and optionally forces a new calibration.
+
+---
+
+## 类概述
## Class Overview
+### 命名空间
+
### Namespace
+- **命名空间:** `NINA.Sequencer.SequenceItem.Guider`
- **Namespace:** `NINA.Sequencer.SequenceItem.Guider`
+
+- **依赖项:**
+
+ - `NINA.Core.Model`
+ - `NINA.Sequencer.Validations`
+ - `NINA.Equipment.Interfaces.Mediator`
+ - `NINA.Core.Locale`
+
- **Dependencies:**
- `NINA.Core.Model`
- `NINA.Sequencer.Validations`
- `NINA.Equipment.Interfaces.Mediator`
- `NINA.Core.Locale`
+---
+
+### 类声明
+
### Class Declaration
```csharp
@@ -25,31 +46,62 @@ The `StartGuiding` class in the N.I.N.A. application is designed to initiate the
public class StartGuiding : SequenceItem, IValidatable
```
+---
+
+### 类属性
+
### Class Properties
-- **guiderMediator**: Manages communication with the guider hardware, handling the start of the guiding process.
-- **ForceCalibration**: A boolean flag indicating whether the guiding process should force a new calibration before starting.
-- **Issues**: A list of issues identified during validation, particularly related to the guider’s connection status and calibration capability.
+- **guiderMediator**: 管理与引导硬件的通信,处理引导过程的启动。
+- **guiderMediator**: Manages communication with the guiding hardware, handling the start of the guiding process.
+
+- **ForceCalibration**: 一个布尔值标志,指示是否应在启动引导之前强制进行新的校准。
+- **ForceCalibration**: A boolean flag indicating whether a new calibration should be forced before starting the guiding process.
+
+- **Issues**: 在验证过程中发现的问题列表,尤其与引导器的连接状态和校准能力相关。
+- **Issues**: A list of issues identified during validation, particularly related to the guider's connection status and calibration capability.
+
+---
+
+### 构造函数
### Constructor
-The constructor initializes the `StartGuiding` class by setting up the connection with the guider mediator. This ensures that the class can interact with the guider system when executing the guiding process.
+构造函数初始化 `StartGuiding` 类,并通过 `guiderMediator` 与引导设备建立连接。这确保了该类能够在执行引导过程中与引导系统交互。
+
+The constructor initializes the `StartGuiding` class and establishes a connection with the guider hardware via the `guiderMediator`. This ensures that the class can interact with the guider system during the guiding process.
```csharp
[ImportingConstructor]
public StartGuiding(IGuiderMediator guiderMediator)
```
+---
+
+### 关键方法
+
### Key Methods
+- **Execute(IProgress<ApplicationStatus> progress, CancellationToken token)**: 启动引导过程,可选择强制执行新的校准。如果过程失败,将抛出异常。
- **Execute(IProgress<ApplicationStatus> progress, CancellationToken token)**: Starts the guiding process, optionally forcing a new calibration. If the process fails, an exception is thrown.
-- **Validate()**: Validates the connection to the guider and checks if forced calibration is possible, updating the `Issues` list if any problems are detected.
+
+- **Validate()**: 验证引导器的连接,并检查是否可以强制校准,若发现问题则更新 `Issues` 列表。
+- **Validate()**: Validates the guider connection and checks whether forced calibration is possible. It updates the `Issues` list if any problems are detected.
+
+- **AfterParentChanged()**: 每当父级序列项发生变化时,重新验证引导器的连接和能力。
- **AfterParentChanged()**: Re-validates the guider connection and capabilities whenever the parent sequence item changes.
+
+- **Clone()**: 创建 `StartGuiding` 对象的副本,保留其属性和元数据。
- **Clone()**: Creates a copy of the `StartGuiding` object, preserving its properties and metadata.
+---
+
+### 流程图:执行过程
+
### Flowchart: Execution Process
-Below is a flowchart that outlines the key steps in the `Execute` method of the `StartGuiding` class.
+以下是 `StartGuiding` 类中 `Execute` 方法的关键步骤流程图。
+Below is a flowchart outlining the key steps in the `Execute` method of the `StartGuiding` class.
```mermaid
flowchart TD
@@ -67,37 +119,75 @@ flowchart TD
J -->|Yes| L[End Execution]
```
+---
+
+### 流程图解释
+
### Flowchart Explanation
-1. **Is Guider Connected?**: The process begins by verifying that the guider is connected and ready.
+1. **引导器是否连接?**:首先检查引导器是否已连接且准备就绪。
+
+ - **否:** 如果引导器未连接,抛出异常并终止过程。
+ - **是:** 如果已连接,继续执行下一步。
+
+1. **Is Guider Connected?**: The process begins by verifying whether the guider is connected and ready.
+
- **No:** If the guider is not connected, an exception is thrown, aborting the process.
- **Yes:** If connected, the process continues to the next step.
-2. **Force Calibration?**: Checks if the guiding process should force a new calibration.
- - **Yes:** If calibration is required, the system checks if the guider can clear its current calibration.
+
+1. **是否强制校准?**:检查引导过程中是否需要强制执行新校准。
+
+ - **是:** 如果需要校准,系统将检查引导器是否可以清除当前校准数据。
+ - **否:** 如果不强制校准,则直接开始引导。
+
+1. **Force Calibration?**: Checks whether a new calibration should be forced during the guiding process.
+
+ - **Yes:** If calibration is required, the system checks if the guider can clear its current calibration data.
- **No:** If no calibration is forced, guiding begins without recalibration.
-3. **Can Guider Clear Calibration?**: If calibration is forced, this step checks whether the guider can clear its existing calibration.
+
+1. **引导器能否清除校准?**:如果需要校准,此步骤检查引导器是否能清除现有校准数据。
+
+ - **否:** 如果不能清除校准,抛出异常。
+ - **是:** 如果可以,系统开始带有校准的引导。
+
+1. **Can Guider Clear Calibration?**: If calibration is required, this step checks whether the guider can clear existing calibration data.
+
- **No:** If the guider cannot clear calibration, an exception is thrown.
- - **Yes:** If it can, the system proceeds to start guiding with calibration.
-4. **Start Guiding**: The guiding process is initiated, either with or without calibration based on the previous steps.
-5. **Await Guiding Completion**: The system waits for the guiding process to either succeed or fail.
-6. **Guiding Started Successfully?**: A final check to confirm that guiding has started successfully.
+ - **Yes:** If possible, the system proceeds to start guiding with calibration.
+
+1. **开始引导**:根据前面的步骤,启动引导过程,无论是否带有校准。
+1. **等待引导完成**:系统等待引导过程的成功或失败。
+1. **引导是否成功启动?**:最后检查引导是否成功启动。
+
+ - **否:** 如果引导失败,抛出异常。
+ - **是:** 如果成功,则执行完成。
+
+1. **Start Guiding**: The guiding process is initiated, either with or without calibration based on previous steps.
+1. **Await Guiding Completion**: The system waits for the guiding process to either succeed or fail.
+1. **Guiding Started Successfully?**: A final check is made to confirm whether guiding has started successfully.
- **No:** If guiding fails, an exception is thrown.
- **Yes:** If successful, the process completes.
+---
+
+### 方法详细描述
+
### Detailed Method Descriptions
+#### `Execute` 方法
+
#### `Execute` Method
-The `Execute` method is the core of the `StartGuiding` class. It handles starting the guiding process, including the optional step of forcing a new calibration. The method uses the `guiderMediator` to interact with the guider hardware, ensuring the process is executed correctly. If any issues arise—such as a failure to start guiding or the inability to clear calibration—an exception is thrown to halt the sequence.
+`Execute` 方法是 `StartGuiding` 类的核心,它负责启动引导过程,包括选择性地强制执行新校准。该方法通过 `guiderMediator` 与引导硬件交互,确保过程正确执行。如果在引导启动或清除校准时出现问题,会抛出异常以中断序列。
-#### `Validate` Method
+The `Execute` method is the core of the `StartGuiding` class. It is responsible for starting the guiding process, including optionally forcing a new calibration. The method interacts with the guider hardware through `guiderMediator` to ensure the process is executed correctly. If any issues arise—such as failing to start guiding or being unable to clear calibration—an exception is thrown to halt the sequence.
-The `Validate` method checks the readiness of the guiding system. It ensures that the guider is connected and, if necessary, verifies that it can clear calibration data before starting a new calibration. The results of this validation are stored in the `Issues` list, which provides a way to identify and troubleshoot potential problems before executing the guiding process.
+---
-#### `AfterParentChanged` Method
+#### `Validate` 方法
-The `AfterParentChanged` method is invoked whenever the parent sequence item changes. This triggers a re-validation of the `StartGuiding` class to ensure that any contextual changes—such as different equipment or settings—are taken into account. It helps maintain the integrity of the sequence by ensuring that guiding can be started successfully in the new context.
+#### `Validate` Method
-#### `Clone` Method
+`Validate` 方法检查引导系统的就绪状态。它确保引导器已连接,并在必要时验证引导器能否清除校准数据。验证结果存储在 `Issues` 列表中,以便在执行引导过程之前识别和解决潜在问题。
-The `Clone` method creates a new instance of the `StartGuiding` class with the same properties and metadata as the original. This allows the guiding process to be reused or repeated in different parts of a sequence without needing to manually reconfigure each instance.
+The `Validate` method checks the readiness of the guiding system. It ensures that the guider is connected and, if necessary, verifies that it can clear calibration data before starting a new calibration. The results of this validation are stored in the `Issues` list, which provides a way to identify and troubleshoot potential problems before executing the guiding process.
diff --git a/doc/task/guider/stop_guiding.md b/doc/task/guider/stop_guiding.md
index 5092c7bd..323f5619 100644
--- a/doc/task/guider/stop_guiding.md
+++ b/doc/task/guider/stop_guiding.md
@@ -1,18 +1,39 @@
-# StopGuiding
+# `StopGuiding` 类
-The `StopGuiding` class in the N.I.N.A. application is responsible for halting the guiding process during an astronomical imaging session. Guiding is a critical aspect of astrophotography, where a guiding camera or system keeps the telescope aligned with the target object. This class ensures that the guiding process stops correctly and validates that the guiding system is properly connected before attempting to stop guiding.
+`StopGuiding` 类在 N.I.N.A.(Nighttime Imaging 'N' Astronomy)应用程序中负责在天文摄影会话期间停止引导过程。引导是天文摄影的关键步骤,其中引导相机或系统用于保持望远镜与目标物体的对齐。该类确保引导过程正确停止,并在尝试停止引导之前验证引导系统是否正确连接。
+
+The `StopGuiding` class in the N.I.N.A. (Nighttime Imaging 'N' Astronomy) application is responsible for halting the guiding process during an astronomical imaging session. Guiding is a critical aspect of astrophotography, where a guiding camera or system keeps the telescope aligned with the target object. This class ensures that the guiding process stops correctly and validates that the guiding system is properly connected before attempting to stop guiding.
+
+---
+
+## 类概述
## Class Overview
+### 命名空间
+
### Namespace
+- **命名空间:** `NINA.Sequencer.SequenceItem.Guider`
- **Namespace:** `NINA.Sequencer.SequenceItem.Guider`
+
+- **依赖项:**
+
+ - `NINA.Core.Model`
+ - `NINA.Sequencer.Validations`
+ - `NINA.Equipment.Interfaces.Mediator`
+ - `NINA.Core.Locale`
+
- **Dependencies:**
- `NINA.Core.Model`
- `NINA.Sequencer.Validations`
- `NINA.Equipment.Interfaces.Mediator`
- `NINA.Core.Locale`
+---
+
+### 类声明
+
### Class Declaration
```csharp
@@ -25,13 +46,26 @@ The `StopGuiding` class in the N.I.N.A. application is responsible for halting t
public class StopGuiding : SequenceItem, IValidatable
```
+---
+
+### 类属性
+
### Class Properties
+- **guiderMediator**: 管理与引导硬件的通信,专门处理停止引导过程。
- **guiderMediator**: An interface that handles communication with the guider hardware, specifically managing the stop guiding process.
+
+- **Issues**: 在验证引导器连接状态时识别的问题列表。
- **Issues**: A list of issues that are identified during the validation of the guider's connection status.
+---
+
+### 构造函数
+
### Constructor
+构造函数初始化 `StopGuiding` 类,并通过 `guiderMediator` 与引导设备建立连接,确保类能够与引导系统交互以停止引导过程。
+
The constructor initializes the `StopGuiding` class by setting up the connection with the guider mediator, ensuring that the class can interact with the guider system to stop the guiding process.
```csharp
@@ -39,16 +73,32 @@ The constructor initializes the `StopGuiding` class by setting up the connection
public StopGuiding(IGuiderMediator guiderMediator)
```
+---
+
+### 关键方法
+
### Key Methods
+- **Execute(IProgress<ApplicationStatus> progress, CancellationToken token)**: 使用 `guiderMediator` 停止引导过程。如果出现任何问题,可能会抛出异常。
- **Execute(IProgress<ApplicationStatus> progress, CancellationToken token)**: Stops the guiding process using the `guiderMediator`. If any issues occur, an exception may be thrown.
+
+- **Validate()**: 在尝试停止引导之前,确保引导系统已连接。如果发现问题,更新 `Issues` 列表。
- **Validate()**: Ensures that the guiding system is connected before attempting to stop guiding. Updates the `Issues` list if any problems are found.
+
+- **AfterParentChanged()**: 每当父级序列项发生变化时,重新验证引导器的连接状态。
- **AfterParentChanged()**: Re-validates the guider connection whenever the parent sequence item changes.
+
+- **Clone()**: 创建 `StopGuiding` 对象的新实例,保留其属性和元数据。
- **Clone()**: Creates a new instance of the `StopGuiding` object, preserving its properties and metadata.
+---
+
+### 流程图:执行过程
+
### Flowchart: Execution Process
-Below is a flowchart that outlines the key steps in the `Execute` method of the `StopGuiding` class.
+以下是 `StopGuiding` 类中 `Execute` 方法的关键步骤流程图。
+Below is a flowchart outlining the key steps in the `Execute` method of the `StopGuiding` class.
```mermaid
flowchart TD
@@ -59,29 +109,70 @@ flowchart TD
E --> F[End Execution]
```
+---
+
+### 流程图解释
+
### Flowchart Explanation
+1. **引导器是否连接?**:首先检查引导器是否已连接且准备就绪。
+
+ - **否:** 如果引导器未连接,抛出异常并终止过程。
+ - **是:** 如果已连接,继续执行停止引导过程。
+
1. **Is Guider Connected?**: The process begins by verifying that the guider is connected and ready.
+
- **No:** If the guider is not connected, an exception is thrown, aborting the process.
- **Yes:** If connected, the process continues to stop the guiding process.
-2. **Stop Guiding Process**: The guider is instructed to stop the guiding process.
-3. **Await Stopping Completion**: The system waits for the guiding process to stop completely.
-4. **End Execution**: The method completes execution after successfully stopping the guider.
+
+1. **停止引导过程**:引导器被指示停止引导过程。
+1. **等待停止完成**:系统等待引导过程完全停止。
+1. **执行结束**:引导成功停止后,方法执行完成。
+
+1. **Stop Guiding Process**: The guider is instructed to stop the guiding process.
+1. **Await Stopping Completion**: The system waits for the guiding process to stop completely.
+1. **End Execution**: The method completes execution after successfully stopping the guider.
+
+---
+
+### 方法详细描述
### Detailed Method Descriptions
+#### `Execute` 方法
+
#### `Execute` Method
+`Execute` 方法是 `StopGuiding` 类的主要功能。它使用 `guiderMediator` 与引导硬件交互,发送停止引导的命令。该方法确保引导过程成功停止,如果出现任何问题,会适当地处理,可能会抛出异常。
+
The `Execute` method is the primary function of the `StopGuiding` class. It uses the `guiderMediator` to interact with the guider hardware, sending the command to stop guiding. The method ensures that the guiding process halts successfully, and if any issues arise, it handles them appropriately, potentially throwing an exception.
+---
+
+#### `Validate` 方法
+
#### `Validate` Method
+`Validate` 方法检查引导系统是否正确连接,以便允许停止引导过程。它会更新 `Issues` 列表,记录遇到的任何问题,例如引导器断开连接。此验证对于在执行过程中防止错误至关重要。
+
The `Validate` method checks that the guiding system is properly connected before allowing the guiding process to be stopped. It updates the `Issues` list with any problems that it encounters, such as the guider being disconnected. This validation is crucial to prevent errors during execution.
+---
+
+#### `AfterParentChanged` 方法
+
#### `AfterParentChanged` Method
+每当父级序列项发生变化时,都会调用 `AfterParentChanged` 方法。这会触发对 `StopGuiding` 类的重新验证,以确保任何上下文变化(如不同的设备或设置)都被考虑到。这有助于确保序列的可靠性,确认在新上下文中停止引导仍然合适。
+
The `AfterParentChanged` method is called whenever the parent sequence item changes. This triggers a re-validation of the `StopGuiding` class to ensure that any contextual changes—such as different equipment or settings—are accounted for. This helps maintain the reliability of the sequence by confirming that stopping guiding is still appropriate in the new context.
+---
+
+#### `Clone` 方法
+
#### `Clone` Method
+`Clone` 方法创建 `StopGuiding` 类的新实例,保留所有属性和元数据。这在序列的不同部分重复停止引导过程时非常有用,而无需手动配置每个实例。
+
The `Clone` method creates a new instance of the `StopGuiding` class, preserving all properties and metadata from the original instance. This is useful for repeating the stop guiding process in different parts of a sequence without manually configuring each instance.
diff --git a/driverlibs b/driverlibs
new file mode 160000
index 00000000..339f65c3
--- /dev/null
+++ b/driverlibs
@@ -0,0 +1 @@
+Subproject commit 339f65c35c2398b0cf2a7153123c5e5b2fa9c66e
diff --git a/libs b/libs
index f2b60c27..f1082244 160000
--- a/libs
+++ b/libs
@@ -1 +1 @@
-Subproject commit f2b60c276ef6fa16ce1b820e5062063c5f2416be
+Subproject commit f1082244ea563500599fcc1726aadfb3e138c387
diff --git a/modules/lithium.pytools/tools/cbuilder.py b/modules/lithium.pytools/tools/cbuilder.py
index e68654a1..2805c10c 100644
--- a/modules/lithium.pytools/tools/cbuilder.py
+++ b/modules/lithium.pytools/tools/cbuilder.py
@@ -25,61 +25,132 @@
import argparse
import subprocess
import sys
+import os
+import shutil
from pathlib import Path
from typing import Literal, List, Optional
-class CMakeBuilder:
+
+class BuildHelperBase:
"""
- CMakeBuilder is a utility class to handle building projects using CMake.
+ Base class for build helpers providing shared functionality.
Args:
source_dir (Path): Path to the source directory.
build_dir (Path): Path to the build directory.
- generator (Literal["Ninja", "Unix Makefiles"]): CMake generator to use.
- build_type (Literal["Debug", "Release"]): Build type (Debug or Release).
- install_prefix (Path): Installation prefix directory.
- cmake_options (Optional[List[str]]): Additional options for CMake.
+ install_prefix (Path): Directory prefix where the project will be installed.
+ options (Optional[List[str]]): Additional options for the build system.
+ env_vars (Optional[dict]): Environment variables to set for the build process.
+ verbose (bool): Flag to enable verbose output during command execution.
Methods:
- run_command: Helper function to run shell commands.
- configure: Configure the CMake build system.
- build: Build the project.
- install: Install the project.
- clean: Clean the build directory.
- test: Run CTest for the project.
- generate_docs: Generate documentation if documentation target is available.
+ run_command: Executes shell commands with optional environment variables and verbosity.
+ clean: Cleans the build directory by removing all files and subdirectories.
"""
+
def __init__(
self,
source_dir: Path,
build_dir: Path,
- generator: Literal["Ninja", "Unix Makefiles"] = "Ninja",
- build_type: Literal["Debug", "Release"] = "Debug",
- install_prefix: Path = None, # type: ignore
- cmake_options: Optional[List[str]] = None,
+ install_prefix: Path = None, # type: ignore
+ options: Optional[List[str]] = None,
+ env_vars: Optional[dict] = None,
+ verbose: bool = False,
):
self.source_dir = source_dir
self.build_dir = build_dir
- self.generator = generator
- self.build_type = build_type
self.install_prefix = install_prefix or build_dir / "install"
- self.cmake_options = cmake_options or []
+ self.options = options or []
+ self.env_vars = env_vars or {}
+ self.verbose = verbose
def run_command(self, *cmd: str):
"""
- Helper function to run shell commands.
+ Helper function to run shell commands with optional environment variables and verbosity.
Args:
- cmd (str): The command and its arguments to run.
+ cmd (str): The command and its arguments to run as separate strings.
+
+ Raises:
+ SystemExit: Exits with the command's return code if it fails.
"""
print(f"Running: {' '.join(cmd)}")
- result = subprocess.run(cmd, check=True, capture_output=True, text=True)
- print(result.stdout)
- if result.stderr:
- print(result.stderr, file=sys.stderr)
+ env = os.environ.copy()
+ env.update(self.env_vars)
+        try:
+            result = subprocess.run(
+                cmd, check=True, capture_output=True, text=True, env=env)
+            if self.verbose:
+                print(result.stdout)
+                if result.stderr:
+                    print(result.stderr, file=sys.stderr)
+        except subprocess.CalledProcessError as e:
+            # check=True raised, so surface the captured output before exiting
+            print(f"Error running command: {e}", file=sys.stderr)
+            if e.stdout:
+                print(e.stdout)
+            if e.stderr:
+                print(e.stderr, file=sys.stderr)
+            sys.exit(e.returncode)
+
+ def clean(self):
+ """
+ Cleans the build directory by removing all files and subdirectories.
+
+ This function ensures that the build directory is completely cleaned, making it ready
+ for a fresh build by removing all existing files and directories inside the build path.
+ """
+        if self.build_dir.exists():
+            for item in self.build_dir.iterdir():
+                if item.is_dir():
+                    shutil.rmtree(item)  # portable, unlike shelling out to `rm -rf`
+                else:
+                    item.unlink()
+            print(f"Cleaned: {self.build_dir}")
+
+
+class CMakeBuilder(BuildHelperBase):
+ """
+ CMakeBuilder is a utility class to handle building projects using CMake.
+
+ Args:
+ source_dir (Path): Path to the source directory.
+ build_dir (Path): Path to the build directory.
+ generator (Literal["Ninja", "Unix Makefiles"]): The CMake generator to use (e.g., Ninja or Unix Makefiles).
+ build_type (Literal["Debug", "Release"]): Type of build (Debug or Release).
+ install_prefix (Path): Directory prefix where the project will be installed.
+ cmake_options (Optional[List[str]]): Additional options for CMake.
+ env_vars (Optional[dict]): Environment variables to set for the build process.
+ verbose (bool): Flag to enable verbose output during command execution.
+ parallel (int): Number of parallel jobs to use for building.
+
+ Methods:
+ configure: Configures the CMake build system by generating build files.
+ build: Builds the project, optionally specifying a target.
+ install: Installs the project to the specified prefix.
+ test: Runs tests using CTest with detailed output on failure.
+ generate_docs: Generates documentation using the specified documentation target.
+ """
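+    # Minimal usage sketch (hypothetical paths; main() below drives exactly
+    # this sequence from the command line):
+    #
+    #   builder = CMakeBuilder(source_dir=Path("."), build_dir=Path("build"),
+    #                          build_type="Release", parallel=8)
+    #   builder.configure()
+    #   builder.build()
+    #   builder.install()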
+
+ def __init__(
+ self,
+ source_dir: Path,
+ build_dir: Path,
+ generator: Literal["Ninja", "Unix Makefiles"] = "Ninja",
+ build_type: Literal["Debug", "Release"] = "Debug",
+ install_prefix: Path = None, # type: ignore
+ cmake_options: Optional[List[str]] = None,
+ env_vars: Optional[dict] = None,
+ verbose: bool = False,
+ parallel: int = 4,
+ ):
+ super().__init__(source_dir, build_dir,
+ install_prefix, cmake_options, env_vars, verbose)
+ self.generator = generator
+ self.build_type = build_type
+ self.parallel = parallel
def configure(self):
- """Configure the CMake build system."""
+ """
+ Configures the CMake build system.
+
+ This function generates the necessary build files using CMake based on the
+ specified generator, build type, and additional CMake options provided.
+ """
self.build_dir.mkdir(parents=True, exist_ok=True)
cmake_args = [
"cmake",
@@ -87,96 +158,100 @@ def configure(self):
f"-DCMAKE_BUILD_TYPE={self.build_type}",
f"-DCMAKE_INSTALL_PREFIX={self.install_prefix}",
str(self.source_dir),
- ] + self.cmake_options
+ ] + self.options
self.run_command(*cmake_args)
def build(self, target: str = ""):
"""
- Build the project.
+ Builds the project using CMake.
Args:
- target (str): Specific build target to build.
+ target (str): Specific build target to build (optional). If not specified, the default target is built.
+
+ This function uses the CMake build command, supporting parallel jobs to speed up the build process.
"""
- build_cmd = ["cmake", "--build", str(self.build_dir)]
+ build_cmd = ["cmake", "--build",
+ str(self.build_dir), "--parallel", str(self.parallel)]
if target:
build_cmd += ["--target", target]
self.run_command(*build_cmd)
def install(self):
- """Install the project."""
- self.run_command("cmake", "--install", str(self.build_dir))
+ """
+ Installs the project to the specified prefix.
- def clean(self):
- """Clean the build directory."""
- if self.build_dir.exists():
- for item in self.build_dir.iterdir():
- if item.is_dir():
- self.run_command("rm", "-rf", str(item))
- else:
- item.unlink()
+ This function runs the CMake install command, which installs the built
+ artifacts to the directory specified by the install prefix.
+ """
+ self.run_command("cmake", "--install", str(self.build_dir))
def test(self):
- """Run CTest for the project."""
- self.run_command("ctest", "--output-on-failure", "-C", self.build_type, "-S", str(self.build_dir))
+ """
+ Runs tests using CTest with detailed output on failure.
+
+ This function runs CTest to execute the project's tests, providing detailed
+ output if any tests fail, making it easier to diagnose issues.
+ """
+        # ctest's -S flag expects a CTest script; --test-dir (CMake >= 3.20)
+        # is the flag that points ctest at the build tree
+        self.run_command("ctest", "--output-on-failure", "-C",
+                         self.build_type, "--test-dir", str(self.build_dir))
def generate_docs(self, doc_target: str = "doc"):
"""
- Generate documentation if documentation target is available.
+ Generates documentation if the specified documentation target is available.
Args:
- doc_target (str): Documentation target to build.
+ doc_target (str): The documentation target to build (default is 'doc').
+
+ This function builds the specified documentation target using the CMake build command.
"""
self.build(doc_target)
-class MesonBuilder:
+
+class MesonBuilder(BuildHelperBase):
"""
MesonBuilder is a utility class to handle building projects using Meson.
Args:
source_dir (Path): Path to the source directory.
build_dir (Path): Path to the build directory.
- build_type (Literal["debug", "release"]): Build type (debug or release).
- install_prefix (Path): Installation prefix directory.
+ build_type (Literal["debug", "release"]): Type of build (debug or release).
+ install_prefix (Path): Directory prefix where the project will be installed.
meson_options (Optional[List[str]]): Additional options for Meson.
+ env_vars (Optional[dict]): Environment variables to set for the build process.
+ verbose (bool): Flag to enable verbose output during command execution.
+ parallel (int): Number of parallel jobs to use for building.
Methods:
- run_command: Helper function to run shell commands.
- configure: Configure the Meson build system.
- build: Build the project.
- install: Install the project.
- clean: Clean the build directory.
- test: Run Meson tests for the project.
- generate_docs: Generate documentation if documentation target is available.
+ configure: Configures the Meson build system by generating build files.
+ build: Builds the project, optionally specifying a target.
+ install: Installs the project to the specified prefix.
+ test: Runs tests using Meson, with error logs printed on failures.
+ generate_docs: Generates documentation using the specified documentation target.
"""
+
def __init__(
self,
source_dir: Path,
build_dir: Path,
build_type: Literal["debug", "release"] = "debug",
- install_prefix: Path = None, # type: ignore
+ install_prefix: Path = None, # type: ignore
meson_options: Optional[List[str]] = None,
+ env_vars: Optional[dict] = None,
+ verbose: bool = False,
+ parallel: int = 4,
):
- self.source_dir = source_dir
- self.build_dir = build_dir
+ super().__init__(source_dir, build_dir,
+ install_prefix, meson_options, env_vars, verbose)
self.build_type = build_type
- self.install_prefix = install_prefix or build_dir / "install"
- self.meson_options = meson_options or []
+ self.parallel = parallel
- def run_command(self, *cmd: str):
+ def configure(self):
"""
- Helper function to run shell commands.
+ Configures the Meson build system.
- Args:
- cmd (str): The command and its arguments to run.
+ This function sets up the Meson build system, generating necessary build files based on
+ the specified build type and additional Meson options provided.
"""
- print(f"Running: {' '.join(cmd)}")
- result = subprocess.run(cmd, check=True, capture_output=True, text=True)
- print(result.stdout)
- if result.stderr:
- print(result.stderr, file=sys.stderr)
-
- def configure(self):
- """Configure the Meson build system。"""
self.build_dir.mkdir(parents=True, exist_ok=True)
meson_args = [
"meson",
@@ -185,84 +260,117 @@ def configure(self):
str(self.source_dir),
f"--buildtype={self.build_type}",
f"--prefix={self.install_prefix}",
- ] + self.meson_options
+ ] + self.options
self.run_command(*meson_args)
def build(self, target: str = ""):
"""
- Build the project.
+ Builds the project using Meson.
Args:
- target (str): Specific target to build.
+ target (str): Specific target to build (optional). If not specified, the default target is built.
+
+ This function compiles the project using Meson's compile command, with support for parallel jobs
+ to speed up the build process.
"""
- build_cmd = ["meson", "compile", "-C", str(self.build_dir)]
+ build_cmd = ["meson", "compile", "-C",
+ str(self.build_dir), f"-j{self.parallel}"]
if target:
build_cmd += ["--target", target]
self.run_command(*build_cmd)
def install(self):
- """Install the project。"""
- self.run_command("meson", "install", "-C", str(self.build_dir))
+ """
+ Installs the project to the specified prefix.
- def clean(self):
- """Clean the build directory。"""
- if self.build_dir.exists():
- for item in self.build_dir.iterdir():
- if item.is_dir():
- self.run_command("rm", "-rf", str(item))
- else:
- item.unlink()
+ This function runs the Meson install command, which installs the built
+ artifacts to the directory specified by the install prefix.
+ """
+ self.run_command("meson", "install", "-C", str(self.build_dir))
def test(self):
- """Run Meson tests for the project。"""
- self.run_command("meson", "test", "-C", str(self.build_dir), "--print-errorlogs")
+ """
+ Runs tests using Meson, with error logs printed on failures.
+
+ This function runs Meson tests, displaying error logs for any failed tests
+ to provide detailed feedback and aid in debugging.
+ """
+ self.run_command("meson", "test", "-C",
+ str(self.build_dir), "--print-errorlogs")
def generate_docs(self, doc_target: str = "doc"):
"""
- Generate documentation if documentation target is available.
+ Generates documentation if the specified documentation target is available.
Args:
- doc_target (str): Documentation target to build.
+ doc_target (str): The documentation target to build (default is 'doc').
+
+ This function builds the specified documentation target using the Meson build system.
"""
self.build(doc_target)
+
def main():
"""
Main function to run the build system helper.
+
+ This function parses command-line arguments to determine the build system (CMake or Meson),
+ source and build directories, build options, and actions (clean, build, install, test, generate docs).
+ It then initializes the appropriate builder class and performs the requested operations.
+
+ Command-line Arguments:
+ --source_dir: Specifies the source directory of the project.
+ --build_dir: Specifies the build directory where build files and artifacts will be generated.
+ --builder: Specifies the build system to use ('cmake' or 'meson').
+ --generator: Specifies the CMake generator (e.g., Ninja, Unix Makefiles) if using CMake.
+ --build_type: Specifies the build type ('Debug', 'Release', 'debug', 'release').
+ --target: Specifies a specific build target to build.
+ --install: Flag to indicate that the project should be installed after building.
+ --clean: Flag to indicate that the build directory should be cleaned before building.
+ --test: Flag to indicate that tests should be run after building.
+ --cmake_options: Additional options for CMake.
+ --meson_options: Additional options for Meson.
+ --generate_docs: Flag to indicate that documentation should be generated.
+ --env: Environment variables to set during the build process.
+ --verbose: Enables verbose output during command execution.
+ --parallel: Number of parallel jobs to use for building.
"""
parser = argparse.ArgumentParser(description="Build System Python Builder")
+ parser.add_argument("--source_dir", type=Path,
+ default=Path(".").resolve(), help="Source directory")
+ parser.add_argument("--build_dir", type=Path,
+ default=Path("build").resolve(), help="Build directory")
parser.add_argument(
- "--source_dir", type=Path, default=Path(".").resolve(), help="Source directory"
- )
+ "--builder", choices=["cmake", "meson"], required=True, help="Choose the build system")
parser.add_argument(
- "--build_dir",
- type=Path,
- default=Path("build").resolve(),
- help="Build directory",
- )
- parser.add_argument("--builder", choices=["cmake", "meson"], required=True, help="Choose the build system")
- parser.add_argument("--generator", choices=["Ninja", "Unix Makefiles"], default="Ninja")
- parser.add_argument("--build_type", choices=["Debug", "Release", "debug", "release"], default="Debug")
- parser.add_argument("--target", default="")
- parser.add_argument("--install", action="store_true", help="Install the project")
- parser.add_argument("--clean", action="store_true", help="Clean the build directory")
+ "--generator", choices=["Ninja", "Unix Makefiles"], default="Ninja", help="CMake generator to use")
+ parser.add_argument("--build_type", choices=[
+ "Debug", "Release", "debug", "release"], default="Debug", help="Build type")
+ parser.add_argument("--target", default="", help="Specify a build target")
+ parser.add_argument("--install", action="store_true",
+ help="Install the project")
+ parser.add_argument("--clean", action="store_true",
+ help="Clean the build directory")
parser.add_argument("--test", action="store_true", help="Run the tests")
- parser.add_argument(
- "--cmake_options",
- nargs="*",
- default=[],
- help="Custom CMake options (e.g. -DVAR=VALUE)",
- )
- parser.add_argument(
- "--meson_options",
- nargs="*",
- default=[],
- help="Custom Meson options (e.g. -Dvar=value)",
- )
- parser.add_argument("--generate_docs", action="store_true", help="Generate documentation")
+ parser.add_argument("--cmake_options", nargs="*", default=[],
+ help="Custom CMake options (e.g. -DVAR=VALUE)")
+ parser.add_argument("--meson_options", nargs="*", default=[],
+ help="Custom Meson options (e.g. -Dvar=value)")
+ parser.add_argument("--generate_docs", action="store_true",
+ help="Generate documentation")
+ parser.add_argument("--env", nargs="*", default=[],
+ help="Set environment variables (e.g. VAR=value)")
+ parser.add_argument("--verbose", action="store_true",
+ help="Enable verbose output")
+ parser.add_argument("--parallel", type=int, default=4,
+ help="Number of parallel jobs for building")
args = parser.parse_args()
+ # Parse environment variables from the command line
+    # Split on the first '=' only, so values may themselves contain '='
+    env_vars = dict(var.split("=", 1) for var in args.env)
+
+ # Initialize the appropriate builder based on the specified build system
if args.builder == "cmake":
builder = CMakeBuilder(
source_dir=args.source_dir,
@@ -270,6 +378,9 @@ def main():
generator=args.generator,
build_type=args.build_type,
cmake_options=args.cmake_options,
+ env_vars=env_vars,
+ verbose=args.verbose,
+ parallel=args.parallel,
)
elif args.builder == "meson":
builder = MesonBuilder(
@@ -277,22 +388,33 @@ def main():
build_dir=args.build_dir,
build_type=args.build_type,
meson_options=args.meson_options,
+ env_vars=env_vars,
+ verbose=args.verbose,
+ parallel=args.parallel,
)
+ # Perform cleaning if requested
if args.clean:
builder.clean()
+ # Configure the build system
builder.configure()
+
+ # Build the project with the specified target
builder.build(args.target)
+ # Install the project if the install flag is set
if args.install:
builder.install()
+ # Run tests if the test flag is set
if args.test:
builder.test()
+ # Generate documentation if the generate_docs flag is set
if args.generate_docs:
builder.generate_docs()
+
if __name__ == "__main__":
main()
diff --git a/modules/lithium.pytools/tools/cmake_generator.py b/modules/lithium.pytools/tools/cmake_generator.py
new file mode 100644
index 00000000..2fdc7c3b
--- /dev/null
+++ b/modules/lithium.pytools/tools/cmake_generator.py
@@ -0,0 +1,276 @@
+#!/usr/bin/env python3
+"""
+This script automates the generation of CMake build configuration files for C++ projects. It supports multi-directory
+project structures, the creation of custom FindXXX.cmake files for third-party libraries, and the use of JSON configuration
+files for specifying project settings.
+
+Key features:
+1. Multi-directory support with separate CMakeLists.txt for each subdirectory.
+2. Custom FindXXX.cmake generation for locating third-party libraries.
+3. JSON-based project configuration to streamline the CMakeLists.txt generation process.
+4. Customizable compiler flags, linker flags, and dependencies.
+"""
+import argparse
+import json
+import platform
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Optional
+
+
+@dataclass
+class ProjectConfig:
+ """
+ Dataclass to hold the project configuration.
+
+ Attributes:
+ project_name (str): The name of the project.
+ version (str): The version of the project.
+ cpp_standard (str): The C++ standard to use (default is C++11).
+ executable (bool): Whether to generate an executable target.
+ static_library (bool): Whether to generate a static library.
+ shared_library (bool): Whether to generate a shared library.
+ enable_testing (bool): Whether to enable testing with CMake's `enable_testing()`.
+ include_dirs (list): List of directories to include in the project.
+ sources (str): Glob pattern to specify source files (default is `src/*.cpp`).
+ compiler_flags (list): List of compiler flags (e.g., `-O3`, `-Wall`).
+ linker_flags (list): List of linker flags (e.g., `-lpthread`).
+ dependencies (list): List of third-party dependencies.
+ subdirs (list): List of subdirectories for multi-directory project structure.
+ install_path (str): Custom installation path (default is `bin`).
+ test_framework (str): The test framework to be used (optional, e.g., `GoogleTest`).
+ """
+ project_name: str
+ version: str = "1.0"
+ cpp_standard: str = "11"
+ executable: bool = True
+ static_library: bool = False
+ shared_library: bool = False
+ enable_testing: bool = False
+ include_dirs: list = field(default_factory=list)
+ sources: str = "src/*.cpp"
+ compiler_flags: list = field(default_factory=list)
+ linker_flags: list = field(default_factory=list)
+ dependencies: list = field(default_factory=list)
+ subdirs: list = field(default_factory=list)
+ install_path: str = "bin"
+    test_framework: Optional[str] = None  # Optional: e.g., GoogleTest
+
+
+def detect_os() -> str:
+ """
+ Detects the current operating system and returns a suitable CMake system name.
+
+ Returns:
+ str: The appropriate CMake system name for the current OS (Windows, Darwin for macOS, Linux).
+ """
+ current_os = platform.system()
+ if current_os == "Windows":
+ return "set(CMAKE_SYSTEM_NAME Windows)\n"
+ elif current_os == "Darwin":
+ return "set(CMAKE_SYSTEM_NAME Darwin)\n"
+ elif current_os == "Linux":
+ return "set(CMAKE_SYSTEM_NAME Linux)\n"
+ return ""
+
+
+def generate_cmake(config: ProjectConfig) -> str:
+ """
+ Generates the content of the main CMakeLists.txt based on the provided project configuration.
+
+ Args:
+ config (ProjectConfig): The project configuration object containing all settings.
+
+ Returns:
+ str: The content of the generated CMakeLists.txt file.
+ """
+ cmake_template = f"""cmake_minimum_required(VERSION 3.15)
+
+# Project name and version
+project({config.project_name} VERSION {config.version})
+
+# Set C++ standard
+set(CMAKE_CXX_STANDARD {config.cpp_standard})
+set(CMAKE_CXX_STANDARD_REQUIRED True)
+
+# OS-specific settings
+{detect_os()}
+
+# Source files
+file(GLOB_RECURSE SOURCES "{config.sources}")
+
+# Include directories
+"""
+ if config.include_dirs:
+ for include_dir in config.include_dirs:
+ cmake_template += f'include_directories("{include_dir}")\n'
+
+ # Compiler flags
+ if config.compiler_flags:
+ cmake_template += "add_compile_options("
+ cmake_template += " ".join(config.compiler_flags) + ")\n"
+
+ # Linker flags
+ if config.linker_flags:
+ cmake_template += "add_link_options("
+ cmake_template += " ".join(config.linker_flags) + ")\n"
+
+ # Dependencies (find_package)
+ if config.dependencies:
+ for dep in config.dependencies:
+ cmake_template += f'find_package({dep} REQUIRED)\n'
+
+ # Subdirectory handling for multi-directory support
+ if config.subdirs:
+ for subdir in config.subdirs:
+ cmake_template += f'add_subdirectory({subdir})\n'
+
+ # Create targets: executable or library
+    # Create targets: executable or library
+    # (in an f-string, '{{' collapses to '{', so ${{SOURCES}} renders as ${SOURCES})
+    if config.executable:
+        cmake_template += f'add_executable({config.project_name} ${{SOURCES}})\n'
+    elif config.static_library:
+        cmake_template += f'add_library({config.project_name} STATIC ${{SOURCES}})\n'
+    elif config.shared_library:
+        cmake_template += f'add_library({config.project_name} SHARED ${{SOURCES}})\n'
+
+ # Testing support
+ if config.enable_testing:
+ cmake_template += """
+# Enable testing
+enable_testing()
+add_subdirectory(tests)
+"""
+ # Custom install path
+ cmake_template += f"""
+# Installation rule
+install(TARGETS {config.project_name} DESTINATION {config.install_path})
+"""
+
+ return cmake_template
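+
+# For a config with project_name "demo", cpp_standard "17" and an executable
+# target, the generated file begins roughly like (a sketch of the template
+# above, not verbatim output):
+#
+#   cmake_minimum_required(VERSION 3.15)
+#   project(demo VERSION 1.0)
+#   set(CMAKE_CXX_STANDARD 17)
+#   file(GLOB_RECURSE SOURCES "src/*.cpp")
+#   add_executable(demo ${SOURCES})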
+
+
+def generate_find_cmake(dependency_name: str) -> str:
+ """
+ Generates a FindXXX.cmake template for a specified third-party dependency.
+
+ Args:
+ dependency_name (str): The name of the third-party library (e.g., Boost, OpenCV).
+
+ Returns:
+ str: The content of the FindXXX.cmake file to locate the library.
+ """
+ return f"""# Find{dependency_name}.cmake - Find {dependency_name} library
+
+# Locate the {dependency_name} library and headers
+find_path({dependency_name}_INCLUDE_DIR
+ NAMES {dependency_name}.h
+ PATHS /usr/local/include /usr/include
+)
+
+find_library({dependency_name}_LIBRARY
+ NAMES {dependency_name}
+ PATHS /usr/local/lib /usr/lib
+)
+
+if({dependency_name}_INCLUDE_DIR AND {dependency_name}_LIBRARY)
+ set({dependency_name}_FOUND TRUE)
+ message(STATUS "{dependency_name} found")
+else()
+ set({dependency_name}_FOUND FALSE)
+ message(FATAL_ERROR "{dependency_name} not found")
+endif()
+
+# Mark variables as advanced
+mark_as_advanced({dependency_name}_INCLUDE_DIR {dependency_name}_LIBRARY)
+"""
+
+
+def save_file(content: str, directory: str = ".", filename: str = "CMakeLists.txt"):
+ """
+ Saves the provided content to a file.
+
+ Args:
+ content (str): The content to save to the file.
+ directory (str): The directory where the file should be saved.
+ filename (str): The name of the file (default is CMakeLists.txt).
+ """
+ directory_path = Path(directory)
+ directory_path.mkdir(parents=True, exist_ok=True)
+
+ file_path = directory_path / filename
+ file_path.write_text(content, encoding='utf-8')
+
+ print(f"{filename} generated in {file_path}")
+
+
+def generate_from_json(json_file: str) -> ProjectConfig:
+ """
+ Reads project configuration from a JSON file and converts it into a ProjectConfig object.
+
+ Args:
+ json_file (str): Path to the JSON configuration file.
+
+ Returns:
+ ProjectConfig: A project configuration object with settings parsed from the JSON file.
+ """
+ with open(json_file, "r", encoding="utf-8") as file:
+ data = json.load(file)
+
+ return ProjectConfig(
+ project_name=data.get("project_name", "MyProject"),
+ version=data.get("version", "1.0"),
+ cpp_standard=data.get("cpp_standard", "11"),
+ executable=data.get("executable", True),
+ static_library=data.get("static_library", False),
+ shared_library=data.get("shared_library", False),
+ enable_testing=data.get("enable_testing", False),
+ include_dirs=data.get("include_dirs", []),
+ sources=data.get("sources", "src/*.cpp"),
+ compiler_flags=data.get("compiler_flags", []),
+ linker_flags=data.get("linker_flags", []),
+ dependencies=data.get("dependencies", []),
+ subdirs=data.get("subdirs", []),
+ install_path=data.get("install_path", "bin"),
+ test_framework=data.get("test_framework", None)
+ )
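+
+# A JSON file such as the following (hypothetical project) covers the fields
+# read above; save it as e.g. config.json and pass it via --json:
+#
+# {
+#     "project_name": "demo",
+#     "cpp_standard": "17",
+#     "executable": true,
+#     "include_dirs": ["include"],
+#     "compiler_flags": ["-O2", "-Wall"],
+#     "dependencies": ["OpenCV"],
+#     "subdirs": ["src"]
+# }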
+
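+# A minimal JSON config covering the keys read above might look like (illustrative):
+#
+#   {
+#       "project_name": "Demo",
+#       "cpp_standard": "17",
+#       "executable": true,
+#       "sources": "src/*.cpp",
+#       "dependencies": ["ZLIB"],
+#       "install_path": "bin"
+#   }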
+
+def parse_arguments():
+ """
+ Parses command-line arguments to either generate CMakeLists.txt files, FindXXX.cmake files, or handle JSON input.
+
+ Returns:
+ argparse.Namespace: Parsed command-line arguments.
+ """
+ parser = argparse.ArgumentParser(
+ description="Generate a CMake template for C++ projects.")
+ parser.add_argument(
+ "--json", type=str, help="Path to JSON config file to generate CMakeLists.txt.")
+ parser.add_argument("--find-package", type=str,
+ help="Generate a FindXXX.cmake file for a specified library.")
+
+ return parser.parse_args()
+
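+# Example invocations (illustrative; the script file name is assumed):
+#
+#   python cmake_generator.py --json project.json
+#   python cmake_generator.py --find-package OpenCV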
+
+if __name__ == "__main__":
+ args = parse_arguments()
+
+ if args.json:
+ # Generate CMakeLists.txt from JSON configuration
+ config = generate_from_json(args.json)
+ cmake_content = generate_cmake(config)
+ save_file(cmake_content)
+
+ # Generate FindXXX.cmake files for each dependency specified in the JSON file
+ for dep in config.dependencies:
+ find_cmake_content = generate_find_cmake(dep)
+ save_file(find_cmake_content, directory="cmake",
+ filename=f"Find{dep}.cmake")
+
+ if args.find_package:
+ # Generate a single FindXXX.cmake file for the specified dependency
+ find_cmake_content = generate_find_cmake(args.find_package)
+ save_file(find_cmake_content, directory="cmake",
+ filename=f"Find{args.find_package}.cmake")
diff --git a/modules/lithium.pytools/tools/compiler_parser.py b/modules/lithium.pytools/tools/compiler_parser.py
index de8efc53..b246a30a 100644
--- a/modules/lithium.pytools/tools/compiler_parser.py
+++ b/modules/lithium.pytools/tools/compiler_parser.py
@@ -1,12 +1,16 @@
"""
This module contains functions for parsing compiler output and converting it to JSON, CSV, or XML format.
"""
+
+from pathlib import Path
import re
import json
import csv
import argparse
+import os
import xml.etree.ElementTree as ET
-from concurrent.futures import ThreadPoolExecutor
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from termcolor import colored
def parse_gcc_clang_output(output):
@@ -32,7 +36,8 @@ def parse_gcc_clang_output(output):
"file": match[0],
"line": int(match[1]),
"column": int(match[2]),
- "message": match[4].strip()
+ "message": match[4].strip(),
+ "severity": match[3].lower(),
}
if match[3].lower() == 'error':
results["errors"].append(entry)
@@ -70,7 +75,8 @@ def parse_msvc_output(output):
"file": match[0],
"line": int(match[1]),
"code": match[3],
- "message": match[4].strip()
+ "message": match[4].strip(),
+ "severity": match[2].lower(),
}
if match[2].lower() == 'error':
results["errors"].append(entry)
@@ -107,7 +113,8 @@ def parse_cmake_output(output):
entry = {
"file": match[0],
"line": int(match[1]),
- "message": match[3].strip()
+ "message": match[3].strip(),
+ "severity": match[2].lower(),
}
if match[2].lower() == 'error':
results["errors"].append(entry)
@@ -152,7 +159,8 @@ def write_to_csv(data, output_path):
output_path (str): The path to the output CSV file.
"""
with open(output_path, 'w', newline='', encoding="utf-8") as csvfile:
- fieldnames = ['file', 'line', 'column', 'type', 'code', 'message']
+ fieldnames = ['file', 'line', 'column',
+ 'type', 'code', 'message', 'severity']
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
writer.writeheader()
for entry in data:
@@ -198,6 +206,25 @@ def process_file(compiler, file_path):
}
+def colorize_output(entries):
+ """
+ Prints compiler results with colorized output in the console.
+
+ Args:
+ entries (list): A list of parsed compiler entries.
+ """
+    for entry in entries:
+        if entry['type'] == 'errors':
+            print(colored(f"Error in {entry['file']}:{entry['line']} - {entry['message']}", 'red'))
+        elif entry['type'] == 'warnings':
+            print(colored(f"Warning in {entry['file']}:{entry['line']} - {entry['message']}", 'yellow'))
+        else:
+            print(colored(f"Info in {entry['file']}:{entry['line']} - {entry['message']}", 'blue'))
+
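+# colorize_output() expects entries shaped like the flattened results assembled
+# in main(), e.g. (illustrative):
+#   {"file": "a.cpp", "line": 3, "column": 5, "severity": "error",
+#    "type": "errors", "message": "expected ';' before '}' token"}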
+
def main():
"""
Main function to parse compiler output and convert to JSON, CSV, or XML format.
@@ -213,50 +240,78 @@ def main():
parser.add_argument(
'--output-file', default='output.json', help="Output file name.")
parser.add_argument(
- '--filter', choices=['error', 'warning', 'info'], help="Filter by message type.")
+ '--output-dir', default='.', help="Output directory.")
+    parser.add_argument(
+        '--filter', nargs='*', choices=['errors', 'warnings', 'info'], help="Filter by message types.")
+ parser.add_argument('--stats', action='store_true',
+ help="Include statistics in the output.")
parser.add_argument(
- '--stats', action='store_true', help="Include statistics in the output.")
+ '--concurrency', type=int, default=4, help="Number of concurrent threads for processing files.")
args = parser.parse_args()
+ # Prepare the output directory
+ output_dir = Path(args.output_dir).resolve()
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ # Initialize results list
all_results = []
- with ThreadPoolExecutor() as executor:
- futures = [executor.submit(process_file, args.compiler, file_path)
- for file_path in args.file_paths]
- for future in futures:
- all_results.append(future.result())
+ # Use ThreadPoolExecutor for concurrent processing of files
+ with ThreadPoolExecutor(max_workers=args.concurrency) as executor:
+ futures = {executor.submit(
+ process_file, args.compiler, file_path): file_path for file_path in args.file_paths}
+
+ for future in as_completed(futures):
+ try:
+ result = future.result()
+ all_results.append(result)
+            except Exception as e:
+                print(colored(f"Error processing {futures[future]}: {e}", 'red'))
+
+ # Flatten results for output processing
flattened_results = []
for result in all_results:
- for key, entries in result['results'].items():
+ for severity, entries in result['results'].items():
for entry in entries:
- entry['type'] = key
+ entry['type'] = severity
entry['file'] = result['file']
flattened_results.append(entry)
+ # Apply filtering if specified
if args.filter:
flattened_results = [
- entry for entry in flattened_results if entry['type'] == args.filter]
+ entry for entry in flattened_results if entry['type'] in args.filter]
+ # Calculate statistics if requested
if args.stats:
stats = {
+ "total": len(flattened_results),
"errors": sum(1 for entry in flattened_results if entry['type'] == 'errors'),
"warnings": sum(1 for entry in flattened_results if entry['type'] == 'warnings'),
"info": sum(1 for entry in flattened_results if entry['type'] == 'info'),
}
- print(f"Statistics: {json.dumps(stats, indent=4)}")
+ print(f"Statistics:\n{json.dumps(stats, indent=4)}")
+
+ # Output results to the specified format
+ output_file_path = output_dir / args.output_file
if args.output_format == 'json':
json_output = json.dumps(flattened_results, indent=4)
- with open(args.output_file, 'w', encoding="utf-8") as json_file:
+ with open(output_file_path, 'w', encoding="utf-8") as json_file:
json_file.write(json_output)
- print(f"JSON output saved to {args.output_file}")
+ print(f"JSON output saved to {output_file_path}")
elif args.output_format == 'csv':
- write_to_csv(flattened_results, args.output_file)
- print(f"CSV output saved to {args.output_file}")
+ write_to_csv(flattened_results, output_file_path)
+ print(f"CSV output saved to {output_file_path}")
elif args.output_format == 'xml':
- write_to_xml(flattened_results, args.output_file)
- print(f"XML output saved to {args.output_file}")
+ write_to_xml(flattened_results, output_file_path)
+ print(f"XML output saved to {output_file_path}")
+
+ # Optional: Print colorized output to the console
+ print("\nColorized Output:")
+ colorize_output(flattened_results)
if __name__ == "__main__":
diff --git a/modules/lithium.pytools/tools/video_editor.py b/modules/lithium.pytools/tools/video_editor.py
new file mode 100644
index 00000000..3ff25808
--- /dev/null
+++ b/modules/lithium.pytools/tools/video_editor.py
@@ -0,0 +1,264 @@
+import argparse
+import os
+import numpy as np
+from moviepy.editor import VideoFileClip, AudioFileClip, CompositeAudioClip
+from moviepy.audio.AudioClip import AudioArrayClip
+from moviepy.audio.fx.all import volumex, audio_fadein, audio_fadeout
+from pydub import AudioSegment
+import matplotlib.pyplot as plt
+from scipy.io import wavfile
+from scipy.signal import spectrogram, wiener
+
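+# Example invocations (illustrative; flags are defined in main() below):
+#
+#   python video_editor.py input.mp4 out.mp3 --speed 1.25 --fade-in 2
+#   python video_editor.py videos/ out/ --batch --format wav --normalize
+#   python video_editor.py song.wav waveform.png --visualize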
+
+def process_audio(video_path, audio_output_path, output_format='mp3', speed=1.0, volume=1.0,
+ start_time=None, end_time=None, fade_in=0, fade_out=0, reverse=False, normalize=False, noise_reduce=False):
+ """
+ Processes audio from a video file: extracts, modifies speed, volume, applies fade effects, reverses, and saves the audio.
+
+ Parameters:
+ video_path (str): Path to the input video file.
+ audio_output_path (str): Path to save the processed audio file.
+ output_format (str): Format of the output audio file ('mp3' or 'wav').
+ speed (float): Speed factor to adjust the audio playback.
+ volume (float): Volume multiplier for the audio.
+ start_time (float): Start time (in seconds) to trim the audio.
+ end_time (float): End time (in seconds) to trim the audio.
+ fade_in (float): Duration (in seconds) for the fade-in effect.
+ fade_out (float): Duration (in seconds) for the fade-out effect.
+ reverse (bool): Whether to reverse the audio.
+ normalize (bool): Normalize the audio volume to -1 to 1 range.
+        noise_reduce (bool): Apply basic noise reduction using a Wiener filter (applied to WAV output only).
+ """
+ # Load the video file
+ video = VideoFileClip(video_path)
+
+ # Extract audio from the video
+ audio = video.audio
+
+ # Adjust audio speed
+ if speed != 1.0:
+ audio = audio.speedx(speed)
+
+ # Adjust audio volume
+ if volume != 1.0:
+ audio = audio.fx(volumex, volume)
+
+ # Trim audio (if start and/or end times are specified)
+ if start_time is not None or end_time is not None:
+ audio = audio.subclip(start_time, end_time)
+
+ # Apply fade-in and fade-out effects
+ if fade_in > 0:
+ audio = audio.fx(audio_fadein, duration=fade_in)
+ if fade_out > 0:
+ audio = audio.fx(audio_fadeout, duration=fade_out)
+
+    # Reverse the audio (keep_duration so the resulting clip still knows its length)
+    if reverse:
+        audio = audio.fx(lambda clip: clip.fl_time(
+            lambda t: clip.duration - t, keep_duration=True))
+
+    # Normalize the audio to the [-1, 1] range
+    if normalize:
+        audio_array = np.array(audio.to_soundarray())
+        max_amplitude = np.max(np.abs(audio_array))
+        if max_amplitude > 0:
+            audio_array = audio_array / max_amplitude
+        # AudioClip has no set_audio_array(); rebuild the clip from the array instead
+        audio = AudioArrayClip(audio_array, fps=audio.fps)
+
+    # Save the processed audio
+    audio.write_audiofile(
+        audio_output_path, codec='libmp3lame' if output_format == 'mp3' else 'pcm_s16le')
+
+    # Apply noise reduction to the written file. This must run after the file is
+    # written, and scipy.io.wavfile only understands WAV, so it is limited to the
+    # 'wav' output format.
+    if noise_reduce and output_format == 'wav':
+        sample_rate, data = wavfile.read(audio_output_path)
+        data = wiener(data)
+        wavfile.write(audio_output_path, sample_rate, data.astype(np.float32))
+
+ # Close the video file
+ video.close()
+
+
+def batch_process(input_dir, output_dir, **kwargs):
+ """
+ Processes all video files in a given directory in batch mode.
+
+ Parameters:
+ input_dir (str): Directory containing the input video files.
+ output_dir (str): Directory where the processed audio files will be saved.
+ kwargs: Additional arguments passed to the audio processing function.
+ """
+ if not os.path.exists(output_dir):
+ os.makedirs(output_dir)
+
+ for filename in os.listdir(input_dir):
+ if filename.endswith(('.mp4', '.avi', '.mov', '.mkv')):
+ input_path = os.path.join(input_dir, filename)
+ output_path = os.path.join(output_dir, os.path.splitext(
+ filename)[0] + '.' + kwargs.get('output_format', 'mp3'))
+ try:
+ process_audio(input_path, output_path, **kwargs)
+ print(f"Processed: {filename}")
+ except Exception as e:
+ print(f"Error processing {filename}: {e}")
+
+
+def mix_audio(audio_files, output_file, volumes=None):
+ """
+ Mixes multiple audio files into one.
+
+ Parameters:
+ audio_files (list): List of paths to audio files to be mixed.
+ output_file (str): Path to save the mixed audio file.
+ volumes (list): List of volume multipliers for each audio file.
+ """
+ if volumes is None:
+ volumes = [1.0] * len(audio_files)
+
+ audio_clips = [AudioFileClip(f).volumex(v)
+ for f, v in zip(audio_files, volumes)]
+ final_audio = CompositeAudioClip(audio_clips)
+ final_audio.write_audiofile(output_file)
+
+
+def split_audio(input_file, output_prefix, segment_length):
+ """
+ Splits an audio file into segments of specified length.
+
+ Parameters:
+ input_file (str): Path to the input audio file.
+ output_prefix (str): Prefix for the output segment files.
+ segment_length (float): Length of each segment in seconds.
+ """
+ audio = AudioSegment.from_file(input_file)
+ length_ms = len(audio)
+    segment_length_ms = int(segment_length * 1000)
+
+ for i, start in enumerate(range(0, length_ms, segment_length_ms)):
+ end = start + segment_length_ms
+ segment = audio[start:end]
+ segment.export(f"{output_prefix}_{i+1}.mp3", format="mp3")
+
+
+def convert_audio_format(input_file, output_file):
+ """
+ Converts an audio file to a different format.
+
+ Parameters:
+ input_file (str): Path to the input audio file.
+ output_file (str): Path to the output file with the desired format.
+ """
+ audio = AudioSegment.from_file(input_file)
+ export_format = os.path.splitext(output_file)[1][1:]
+ audio.export(output_file, format=export_format)
+
+
+def visualize_audio(input_file, output_file):
+ """
+ Visualizes an audio file as a waveform and spectrogram.
+
+ Parameters:
+ input_file (str): Path to the input audio file (WAV format).
+ output_file (str): Path to save the visualization image.
+ """
+ sample_rate, data = wavfile.read(input_file)
+
+ plt.figure(figsize=(12, 8))
+
+ # Plot waveform
+ plt.subplot(2, 1, 1)
+ plt.plot(np.arange(len(data)) / sample_rate, data)
+ plt.title('Audio Waveform')
+ plt.xlabel('Time (seconds)')
+ plt.ylabel('Amplitude')
+
+ # Plot spectrogram
+ plt.subplot(2, 1, 2)
+ frequencies, times, Sxx = spectrogram(data, sample_rate)
+ plt.pcolormesh(times, frequencies, 10 * np.log10(Sxx))
+ plt.title('Spectrogram')
+ plt.xlabel('Time (seconds)')
+ plt.ylabel('Frequency (Hz)')
+ plt.colorbar(label='Intensity (dB)')
+
+ plt.tight_layout()
+ plt.savefig(output_file)
+ plt.close()
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="Full-featured audio processing tool.")
+ parser.add_argument(
+ "input", help="Path to the input video/audio file or directory containing files.")
+ parser.add_argument(
+ "output", help="Path to save the output audio file or directory.")
+ parser.add_argument(
+ "--format", choices=['mp3', 'wav'], default='mp3', help="Output audio format (default: mp3)")
+ parser.add_argument("--speed", type=float, default=1.0,
+ help="Adjust audio speed (default: 1.0)")
+ parser.add_argument("--volume", type=float, default=1.0,
+ help="Adjust audio volume (default: 1.0)")
+ parser.add_argument("--start", type=float,
+ help="Start time of the audio (seconds)")
+ parser.add_argument("--end", type=float,
+ help="End time of the audio (seconds)")
+ parser.add_argument("--fade-in", type=float, default=0,
+ help="Fade-in duration (seconds)")
+ parser.add_argument("--fade-out", type=float, default=0,
+ help="Fade-out duration (seconds)")
+ parser.add_argument("--reverse", action="store_true",
+ help="Reverse the audio")
+ parser.add_argument("--normalize", action="store_true",
+ help="Normalize the audio volume")
+ parser.add_argument("--noise-reduce", action="store_true",
+ help="Apply basic noise reduction")
+ parser.add_argument("--batch", action="store_true",
+ help="Batch process all videos in a directory")
+ parser.add_argument("--mix", nargs='+', help="Mix multiple audio files")
+ parser.add_argument("--mix-volumes", nargs='+', type=float,
+ help="Volumes for each audio file during mix")
+ parser.add_argument("--split", type=float,
+ help="Split audio into segments of specified length (seconds)")
+ parser.add_argument(
+ "--convert", help="Convert audio to a different format")
+ parser.add_argument("--visualize", action="store_true",
+ help="Generate a visualization of the audio (waveform and spectrogram)")
+
+ args = parser.parse_args()
+
+ # Mix multiple audio files
+ if args.mix:
+ mix_audio(args.mix, args.output, args.mix_volumes)
+ # Split audio into segments
+ elif args.split:
+ split_audio(args.input, args.output, args.split)
+ # Convert audio format
+ elif args.convert:
+ convert_audio_format(args.input, args.output)
+ # Visualize the audio as waveform and spectrogram
+ elif args.visualize:
+ visualize_audio(args.input, args.output)
+ # Batch process all video files in a directory
+ elif args.batch:
+ batch_process(
+ args.input, args.output,
+ output_format=args.format, speed=args.speed, volume=args.volume,
+ start_time=args.start, end_time=args.end,
+ fade_in=args.fade_in, fade_out=args.fade_out, reverse=args.reverse,
+ normalize=args.normalize, noise_reduce=args.noise_reduce
+ )
+ # Process a single video or audio file
+ else:
+ process_audio(
+ args.input, args.output,
+ output_format=args.format, speed=args.speed, volume=args.volume,
+ start_time=args.start, end_time=args.end,
+ fade_in=args.fade_in, fade_out=args.fade_out, reverse=args.reverse,
+ normalize=args.normalize, noise_reduce=args.noise_reduce
+ )
+
+ print("Processing completed successfully.")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/package/package.sh b/package/package.sh
new file mode 100644
index 00000000..9fd28ccf
--- /dev/null
+++ b/package/package.sh
@@ -0,0 +1,164 @@
+#!/bin/bash
+
+# 脚本:CMake MinGW项目打包脚本
+# Script: CMake MinGW Project Packaging Script
+# 描述:这个脚本用于构建、安装和打包基于CMake的MinGW项目。
+# Description: This script is used to build, install, and package CMake-based MinGW projects.
+# 作者:Max Qian
+# Author: Max Qian
+# 版本:1.2
+# Version: 1.2
+# 使用方法:./package.sh [clean|build|package]
+# Usage: ./package.sh [clean|build|package]
+
+# ===== 变量设置 / Variable Settings =====
+# 项目名称 / Project name
+PROJECT_NAME="MyProject"
+# 构建目录 / Build directory
+BUILD_DIR="build"
+# 安装目录 / Install directory
+INSTALL_DIR="install"
+# 打包目录 / Package directory
+PACKAGE_DIR="package"
+# 版本号(从git获取,如果失败则使用默认值)
+# Version number (obtained from git, use default if failed)
+VERSION=$(git describe --tags --always --dirty 2>/dev/null || echo "1.0.0")
+# 最终生成的压缩包名称 / Name of the final generated archive
+ARCHIVE_NAME="${PROJECT_NAME}-${VERSION}-win64.zip"
+
+# ===== 颜色代码 / Color Codes =====
+# 用于美化控制台输出 / Used to beautify console output
+RED='\033[0;31m'
+GREEN='\033[0;32m'
+YELLOW='\033[1;33m'
+NC='\033[0m' # No Color / 无颜色
+
+# ===== 辅助函数 / Helper Functions =====
+# 输出信息日志 / Output information log
+log_info() {
+ echo -e "${GREEN}[INFO] $1${NC}"
+}
+
+# 输出警告日志 / Output warning log
+log_warn() {
+ echo -e "${YELLOW}[WARN] $1${NC}"
+}
+
+# 输出错误日志 / Output error log
+log_error() {
+ echo -e "${RED}[ERROR] $1${NC}"
+}
+
+# 检查命令是否存在 / Check if a command exists
+check_command() {
+    if ! command -v "$1" &> /dev/null; then
+ log_error "$1 could not be found. Please install it and try again."
+ log_error "$1 未找到。请安装后重试。"
+ exit 1
+ fi
+}
+
+# ===== 主要功能函数 / Main Function =====
+# 清理函数:删除之前的构建产物
+# Clean function: Remove previous build artifacts
+clean() {
+ log_info "Cleaning previous build artifacts..."
+ log_info "清理之前的构建产物..."
+ rm -rf $BUILD_DIR $INSTALL_DIR $PACKAGE_DIR
+}
+
+# 构建函数:配置CMake,编译项目,并安装
+# Build function: Configure CMake, compile the project, and install
+build() {
+ log_info "Configuring project with CMake..."
+ log_info "使用CMake配置项目..."
+ mkdir -p $BUILD_DIR
+ cd $BUILD_DIR
+ # 使用MinGW Makefiles生成器,设置安装前缀和构建类型
+ # Use MinGW Makefiles generator, set install prefix and build type
+ cmake -G "MinGW Makefiles" -DCMAKE_INSTALL_PREFIX=../$INSTALL_DIR -DCMAKE_BUILD_TYPE=Release ..
+ if [ $? -ne 0 ]; then
+ log_error "CMake configuration failed."
+ log_error "CMake配置失败。"
+ exit 1
+ fi
+
+ log_info "Building project..."
+ log_info "构建项目..."
+ # 使用CMake构建项目 / Build the project using CMake
+ cmake --build . --config Release
+ if [ $? -ne 0 ]; then
+ log_error "Build failed."
+ log_error "构建失败。"
+ exit 1
+ fi
+
+ log_info "Installing project..."
+ log_info "安装项目..."
+ # 安装项目到指定目录 / Install the project to the specified directory
+ cmake --install .
+ if [ $? -ne 0 ]; then
+ log_error "Installation failed."
+ log_error "安装失败。"
+ exit 1
+ fi
+
+ cd ..
+}
+
+# 打包函数:复制文件,处理依赖,创建压缩包
+# Package function: Copy files, handle dependencies, create archive
+package() {
+ log_info "Packaging project..."
+ log_info "打包项目..."
+ mkdir -p $PACKAGE_DIR
+ # 复制安装的文件到打包目录 / Copy installed files to the package directory
+ cp -R $INSTALL_DIR/* $PACKAGE_DIR/
+
+ log_info "Copying dependencies..."
+ log_info "复制依赖项..."
+ # 复制所有可执行文件的MinGW依赖 / Copy MinGW dependencies for all executables
+ for file in $PACKAGE_DIR/bin/*.exe; do
+ ldd "$file" | grep mingw | awk '{print $3}' | xargs -I '{}' cp '{}' $PACKAGE_DIR/bin/
+ done
+
+ log_info "Creating archive..."
+ log_info "创建压缩包..."
+ # 使用7-Zip创建ZIP压缩包 / Create ZIP archive using 7-Zip
+ 7z a -tzip $ARCHIVE_NAME $PACKAGE_DIR/*
+
+ log_info "Packaging complete. Archive created: $ARCHIVE_NAME"
+ log_info "打包完成。创建的压缩包:$ARCHIVE_NAME"
+}
+
+# 主函数:执行完整的构建和打包流程
+# Main function: Execute the complete build and package process
+main() {
+ clean
+ build
+ package
+}
+
+# ===== 脚本入口 / Script Entry =====
+# 检查必要的命令是否存在 / Check if necessary commands exist
+check_command cmake
+check_command mingw32-make
+check_command git
+check_command 7z
+
+# 根据命令行参数执行相应的功能
+# Execute corresponding functionality based on command line arguments
+case "$1" in
+ clean)
+ clean
+ ;;
+ build)
+ build
+ ;;
+ package)
+ package
+ ;;
+ *)
+ main
+ ;;
+esac
diff --git a/pysrc/app/command_dispatcher.py b/pysrc/app/command_dispatcher.py
new file mode 100644
index 00000000..248f5f2b
--- /dev/null
+++ b/pysrc/app/command_dispatcher.py
@@ -0,0 +1,131 @@
+# app/command_dispatcher.py
+from typing import Dict, Any, Callable, Awaitable, List, Optional
+from loguru import logger
+import inspect
+import asyncio
+from datetime import datetime, timedelta
+
+
+class Command:
+ def __init__(self, func: Callable, name: str, description: str, rate_limit: Optional[int] = None):
+ self.func = func
+ self.name = name
+ self.description = description
+ self.rate_limit = rate_limit
+ self.last_called: Dict[str, datetime] = {}
+
+
+class CommandDispatcher:
+ def __init__(self):
+ self.commands: Dict[str, Command] = {}
+ self.middlewares: List[Callable] = []
+ self.default_command: Optional[Command] = None
+
+ def register(self, name: str, description: str = "", rate_limit: Optional[int] = None):
+ def decorator(f: Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]]]):
+ self.commands[name] = Command(f, name, description, rate_limit)
+ return f
+ return decorator
+
+    def set_default(self, f: Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]]]):
+        self.default_command = Command(f, "default", "Default command handler")
+        # Return the function so this method works as a decorator without
+        # rebinding the decorated name to None
+        return f
+
+ def add_middleware(self, middleware: Callable):
+ self.middlewares.append(middleware)
+
+ async def dispatch(self, command_name: str, params: Dict[str, Any], user_id: str) -> Dict[str, Any]:
+ command = self.commands.get(command_name, self.default_command)
+
+ if command is None:
+ return {"error": f"Unknown command: {command_name}"}
+
+ # Apply rate limiting
+ if command.rate_limit:
+ last_called = command.last_called.get(user_id)
+ if last_called and datetime.now() - last_called < timedelta(seconds=command.rate_limit):
+ return {"error": f"Rate limit exceeded for command: {command_name}"}
+ command.last_called[user_id] = datetime.now()
+
+ # Apply middlewares
+ for middleware in self.middlewares:
+ params = await middleware(command_name, params, user_id)
+
+ try:
+ # Check if the command function is asynchronous
+ if inspect.iscoroutinefunction(command.func):
+ result = await command.func(params)
+ else:
+ # If it's not async, run it in an executor
+                loop = asyncio.get_running_loop()
+ result = await loop.run_in_executor(None, command.func, params)
+
+ return result
+ except Exception as e:
+ logger.error(f"Error executing command {command_name}: {e}")
+ return {"error": f"Command execution failed: {str(e)}"}
+
+ def get_command_list(self) -> List[Dict[str, Any]]:
+ return [{"name": cmd.name, "description": cmd.description} for cmd in self.commands.values()]
+
+ async def batch_dispatch(self, commands: List[Dict[str, Any]], user_id: str) -> List[Dict[str, Any]]:
+ results = []
+ for cmd in commands:
+ result = await self.dispatch(cmd.get("command"), cmd.get("params", {}), user_id)
+ results.append(result)
+ return results
+
+# Example usage:
+
+
+dispatcher = CommandDispatcher()
+
+
+@dispatcher.register("echo", "Echoes the input message", rate_limit=5)
+async def echo_command(params: Dict[str, Any]) -> Dict[str, Any]:
+ return {"result": params.get("message", "No message provided")}
+
+
+@dispatcher.register("add", "Adds two numbers")
+def add_command(params: Dict[str, Any]) -> Dict[str, Any]:
+ a = params.get("a", 0)
+ b = params.get("b", 0)
+ return {"result": a + b}
+
+
+@dispatcher.set_default
+async def default_command(params: Dict[str, Any]) -> Dict[str, Any]:
+ return {"result": "Unknown command. Use 'help' to see available commands."}
+
+
+async def log_middleware(command: str, params: Dict[str, Any], user_id: str) -> Dict[str, Any]:
+ logger.info(f"User {user_id} is executing command: {command}")
+ return params
+
+dispatcher.add_middleware(log_middleware)
+
+# Usage example:
+
+
+async def main():
+ result = await dispatcher.dispatch("echo", {"message": "Hello, World!"}, "user1")
+ print(result) # {"result": "Hello, World!"}
+
+ result = await dispatcher.dispatch("add", {"a": 5, "b": 3}, "user1")
+ print(result) # {"result": 8}
+
+ result = await dispatcher.dispatch("unknown", {}, "user1")
+ # {"result": "Unknown command. Use 'help' to see available commands."}
+ print(result)
+
+ command_list = dispatcher.get_command_list()
+ # [{"name": "echo", "description": "Echoes the input message"}, {"name": "add", "description": "Adds two numbers"}]
+ print(command_list)
+
+ batch_results = await dispatcher.batch_dispatch([
+ {"command": "echo", "params": {"message": "First command"}},
+ {"command": "add", "params": {"a": 10, "b": 20}}
+ ], "user1")
+ print(batch_results) # [{"result": "First command"}, {"result": 30}]
+
+if __name__ == "__main__":
+ asyncio.run(main())
diff --git a/pysrc/app/dependence.py b/pysrc/app/dependence.py
new file mode 100644
index 00000000..81e3e67d
--- /dev/null
+++ b/pysrc/app/dependence.py
@@ -0,0 +1,16 @@
+import secrets
+
+from fastapi import Depends, HTTPException, status
+from fastapi.security import HTTPBasic, HTTPBasicCredentials
+from config.config import config
+
+security = HTTPBasic()
+
+def get_current_username(credentials: HTTPBasicCredentials = Depends(security)):
+    # Constant-time comparison avoids leaking credential contents through timing
+    correct_username = secrets.compare_digest(credentials.username, config.auth_username)
+    correct_password = secrets.compare_digest(credentials.password, config.auth_password)
+ if not (correct_username and correct_password):
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail="Incorrect username or password",
+ headers={"WWW-Authenticate": "Basic"},
+ )
+ return credentials.username
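+
+# Illustrative wiring into a FastAPI route (assumes an existing `app` instance):
+#
+#   @app.get("/users/me")
+#   def read_current_user(username: str = Depends(get_current_username)):
+#       return {"username": username}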
diff --git a/pysrc/image/adaptive_stretch/__init__.py b/pysrc/image/adaptive_stretch/__init__.py
new file mode 100644
index 00000000..6c543b23
--- /dev/null
+++ b/pysrc/image/adaptive_stretch/__init__.py
@@ -0,0 +1,2 @@
+from .stretch import AdaptiveStretch
+from .preview import apply_real_time_preview
diff --git a/pysrc/image/adaptive_stretch/preview.py b/pysrc/image/adaptive_stretch/preview.py
new file mode 100644
index 00000000..522a740e
--- /dev/null
+++ b/pysrc/image/adaptive_stretch/preview.py
@@ -0,0 +1,26 @@
+import matplotlib.pyplot as plt
+import cv2
+import numpy as np
+from .stretch import AdaptiveStretch
+from typing import Optional, Tuple
+
+def apply_real_time_preview(image: np.ndarray, noise_threshold: float = 1e-4, contrast_protection: Optional[float] = None, max_curve_points: int = 106, roi: Optional[Tuple[int, int, int, int]] = None):
+ """
+ Simulate real-time preview by iteratively applying the adaptive stretch
+ transformation and displaying the result.
+
+ :param image: Input image as a numpy array (grayscale or color).
+ :param noise_threshold: Threshold for treating brightness differences as noise.
+ :param contrast_protection: Optional contrast protection parameter.
+ :param max_curve_points: Maximum points for the transformation curve.
+ :param roi: Tuple (x, y, width, height) defining the region of interest.
+ """
+ adaptive_stretch = AdaptiveStretch(noise_threshold, contrast_protection, max_curve_points)
+ preview_image = adaptive_stretch.stretch(image, roi)
+
+ if len(preview_image.shape) == 3:
+ preview_image = cv2.cvtColor(preview_image, cv2.COLOR_BGR2RGB)
+
+ plt.imshow(preview_image, cmap='gray' if len(preview_image.shape) == 2 else None)
+ plt.title(f"Noise Threshold: {noise_threshold}, Contrast Protection: {contrast_protection}")
+ plt.show()
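+
+# Illustrative call (assumes `img` was loaded with cv2.imread; path is hypothetical):
+#
+#   img = cv2.imread("m31.png", cv2.IMREAD_GRAYSCALE)
+#   apply_real_time_preview(img, noise_threshold=5e-4, roi=(100, 100, 256, 256))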
diff --git a/pysrc/image/adaptive_stretch/stretch.py b/pysrc/image/adaptive_stretch/stretch.py
new file mode 100644
index 00000000..c8483be6
--- /dev/null
+++ b/pysrc/image/adaptive_stretch/stretch.py
@@ -0,0 +1,97 @@
+import numpy as np
+import cv2
+from typing import Optional, Tuple
+
+
+class AdaptiveStretch:
+ def __init__(self, noise_threshold: float = 1e-4, contrast_protection: Optional[float] = None, max_curve_points: int = 106):
+ """
+ Initialize the AdaptiveStretch object with specific parameters.
+
+ :param noise_threshold: Threshold for treating brightness differences as noise.
+ :param contrast_protection: Optional contrast protection parameter.
+ :param max_curve_points: Maximum points for the transformation curve.
+ """
+ self.noise_threshold = noise_threshold
+ self.contrast_protection = contrast_protection
+ self.max_curve_points = max_curve_points
+
+ def compute_brightness_diff(self, image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+ """
+ Compute brightness differences between adjacent pixels.
+ Returns matrices of differences along the x and y axes.
+
+ :param image: Input image as a numpy array (grayscale).
+ :return: Tuple of differences along x and y axes.
+ """
+ diff_x = np.diff(image, axis=1) # differences between columns
+ diff_y = np.diff(image, axis=0) # differences between rows
+
+ # Pad the differences to match the original image size
+ diff_x = np.pad(diff_x, ((0, 0), (0, 1)), mode='constant')
+ diff_y = np.pad(diff_y, ((0, 1), (0, 0)), mode='constant')
+
+ return diff_x, diff_y
+
+ def stretch(self, image: np.ndarray, roi: Optional[Tuple[int, int, int, int]] = None) -> np.ndarray:
+ """
+ Apply the AdaptiveStretch transformation to the image.
+
+ :param image: Input image as a numpy array (grayscale or color).
+ :param roi: Tuple (x, y, width, height) defining the region of interest.
+ :return: Stretched image.
+ """
+ if len(image.shape) == 3:
+ channels = cv2.split(image)
+ else:
+ channels = [image]
+
+ stretched_channels = []
+
+ for channel in channels:
+ # Normalize the channel to the range [0, 1]
+ channel = channel.astype(np.float32) / 255.0
+
+ if roi is not None:
+ x, y, w, h = roi
+ channel_roi = channel[y:y+h, x:x+w]
+ else:
+ channel_roi = channel
+
+ diff_x, diff_y = self.compute_brightness_diff(channel_roi)
+
+ positive_forces = np.maximum(diff_x, 0) + np.maximum(diff_y, 0)
+ negative_forces = np.minimum(diff_x, 0) + np.minimum(diff_y, 0)
+
+ positive_forces[positive_forces < self.noise_threshold] = 0
+ negative_forces[negative_forces > -self.noise_threshold] = 0
+
+ transformation_curve = positive_forces + negative_forces
+
+ if self.contrast_protection is not None:
+ transformation_curve = np.clip(
+ transformation_curve, -self.contrast_protection, self.contrast_protection)
+
+ resampled_curve = cv2.resize(
+ transformation_curve, (self.max_curve_points, 1), interpolation=cv2.INTER_LINEAR)
+
+ interpolated_curve = cv2.resize(
+ resampled_curve, (channel_roi.shape[1], channel_roi.shape[0]), interpolation=cv2.INTER_LINEAR)
+
+ stretched_channel = channel_roi + interpolated_curve
+
+ stretched_channel = np.clip(
+ stretched_channel * 255, 0, 255).astype(np.uint8)
+
+            if roi is not None:
+                # The untouched region is still float in [0, 1]; bring it back to
+                # 8-bit before inserting the stretched ROI
+                channel = np.clip(channel * 255, 0, 255).astype(np.uint8)
+                channel[y:y+h, x:x+w] = stretched_channel
+                stretched_channel = channel
+
+ stretched_channels.append(stretched_channel)
+
+ if len(stretched_channels) > 1:
+ stretched_image = cv2.merge(stretched_channels)
+ else:
+ stretched_image = stretched_channels[0]
+
+ return stretched_image
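+
+
+# Illustrative usage on an 8-bit BGR image (hypothetical file name):
+#
+#   img = cv2.imread("galaxy.png")
+#   stretched = AdaptiveStretch(noise_threshold=1e-3).stretch(img)
+#   roi_only = AdaptiveStretch().stretch(img, roi=(50, 50, 200, 200))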
diff --git a/pysrc/image/api/strecth_count.py b/pysrc/image/api/strecth_count.py
new file mode 100644
index 00000000..0e05b936
--- /dev/null
+++ b/pysrc/image/api/strecth_count.py
@@ -0,0 +1,200 @@
+from pathlib import Path
+from typing import Tuple, Dict, Optional, List
+import json
+import numpy as np
+import cv2
+from astropy.io import fits
+from scipy import ndimage
+
+
+def debayer_image(img: np.ndarray, bayer_pattern: Optional[str] = None) -> np.ndarray:
+ bayer_patterns = {
+ "RGGB": cv2.COLOR_BAYER_RGGB2BGR,
+ "GBRG": cv2.COLOR_BAYER_GBRG2BGR,
+ "BGGR": cv2.COLOR_BAYER_BGGR2BGR,
+ "GRBG": cv2.COLOR_BAYER_GRBG2BGR
+ }
+ return cv2.cvtColor(img, bayer_patterns.get(bayer_pattern, cv2.COLOR_BAYER_RGGB2BGR))
+
+
+def resize_image(img: np.ndarray, target_size: int) -> np.ndarray:
+ scale = min(target_size / max(img.shape[:2]), 1)
+ if scale < 1:
+ return cv2.resize(img, None, fx=scale, fy=scale, interpolation=cv2.INTER_AREA)
+ return img
+
+
+def normalize_image(img: np.ndarray) -> np.ndarray:
+ if not np.allclose(img, img.astype(np.uint8)):
+ img = cv2.normalize(img, None, alpha=0, beta=255,
+ norm_type=cv2.NORM_MINMAX)
+ return img.astype(np.uint8)
+
+
+def stretch_image(img: np.ndarray, is_color: bool) -> np.ndarray:
+ if is_color:
+ return ComputeAndStretch_ThreeChannels(img, True)
+ return ComputeStretch_OneChannels(img, True)
+
+
+def detect_stars(img: np.ndarray, remove_hotpixel: bool, remove_noise: bool, do_star_mark: bool, mark_img: Optional[np.ndarray] = None) -> Tuple[np.ndarray, float, float, Dict[str, float]]:
+ return StarDetectAndHfr(img, remove_hotpixel, remove_noise, do_star_mark, down_sample_mean_std=True, mark_img=mark_img)
+
+# New functions for enhanced image processing
+
+
+def apply_gaussian_blur(img: np.ndarray, kernel_size: int = 5) -> np.ndarray:
+ """Apply Gaussian blur to reduce noise."""
+ return cv2.GaussianBlur(img, (kernel_size, kernel_size), 0)
+
+
+def apply_unsharp_mask(img: np.ndarray, kernel_size: int = 5, amount: float = 1.5) -> np.ndarray:
+ """Apply unsharp mask to enhance image details."""
+ blurred = cv2.GaussianBlur(img, (kernel_size, kernel_size), 0)
+ return cv2.addWeighted(img, amount + 1, blurred, -amount, 0)
+
+
+def equalize_histogram(img: np.ndarray) -> np.ndarray:
+ """Apply histogram equalization to improve contrast."""
+ if len(img.shape) == 2:
+ return cv2.equalizeHist(img)
+ else:
+ ycrcb = cv2.cvtColor(img, cv2.COLOR_BGR2YCrCb)
+ ycrcb[:, :, 0] = cv2.equalizeHist(ycrcb[:, :, 0])
+ return cv2.cvtColor(ycrcb, cv2.COLOR_YCrCb2BGR)
+
+
+def remove_hot_pixels(img: np.ndarray, threshold: float = 3.0) -> np.ndarray:
+ """Remove hot pixels using median filter and thresholding."""
+    median = ndimage.median_filter(img, size=3)
+    # Use a signed dtype so uint8 subtraction cannot wrap around
+    diff = np.abs(img.astype(np.int16) - median.astype(np.int16))
+ mask = diff > (threshold * np.std(diff))
+ img[mask] = median[mask]
+ return img
+
+
+def adjust_gamma(img: np.ndarray, gamma: float = 1.0) -> np.ndarray:
+ """Adjust image gamma."""
+ inv_gamma = 1.0 / gamma
+ table = np.array([((i / 255.0) ** inv_gamma) *
+ 255 for i in np.arange(0, 256)]).astype("uint8")
+ return cv2.LUT(img, table)
+
+
+def apply_clahe(img: np.ndarray, clip_limit: float = 2.0, tile_grid_size: Tuple[int, int] = (8, 8)) -> np.ndarray:
+ """Apply Contrast Limited Adaptive Histogram Equalization (CLAHE)."""
+ if len(img.shape) == 2:
+ clahe = cv2.createCLAHE(clipLimit=clip_limit,
+ tileGridSize=tile_grid_size)
+ return clahe.apply(img)
+ else:
+ lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
+ l, a, b = cv2.split(lab)
+ clahe = cv2.createCLAHE(clipLimit=clip_limit,
+ tileGridSize=tile_grid_size)
+ cl = clahe.apply(l)
+ limg = cv2.merge((cl, a, b))
+ return cv2.cvtColor(limg, cv2.COLOR_LAB2BGR)
+
+
+def denoise_image(img: np.ndarray, h: float = 10) -> np.ndarray:
+    """Apply Non-Local Means Denoising (grayscale or color)."""
+    if len(img.shape) == 2:
+        return cv2.fastNlMeansDenoising(img, None, h, 7, 21)
+    return cv2.fastNlMeansDenoisingColored(img, None, h, h, 7, 21)
+
+
+def enhance_image(img: np.ndarray, config: Dict[str, bool]) -> np.ndarray:
+ """Apply a series of image enhancements based on the configuration."""
+ if config.get('remove_hot_pixels', False):
+ img = remove_hot_pixels(img)
+ if config.get('denoise', False):
+ img = denoise_image(img)
+ if config.get('equalize_histogram', False):
+ img = equalize_histogram(img)
+ if config.get('apply_clahe', False):
+ img = apply_clahe(img)
+ if config.get('unsharp_mask', False):
+ img = apply_unsharp_mask(img)
+ if config.get('adjust_gamma', False):
+ img = adjust_gamma(img, gamma=config.get('gamma', 1.0))
+ return img
+
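+# An illustrative `config` covering the keys consumed above and in process_image():
+#
+#   config = {"remove_hot_pixels": True, "denoise": False, "equalize_histogram": True,
+#             "apply_clahe": False, "unsharp_mask": True, "adjust_gamma": True,
+#             "gamma": 0.8, "do_stretch": True, "do_star_count": False,
+#             "remove_hotpixel": False, "remove_noise": False, "do_star_mark": False}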
+
+def process_image(filepath: Path, config: Dict[str, bool], resize_size: int = 2048) -> Tuple[Optional[np.ndarray], Dict[str, float]]:
+ try:
+ img, header = fits.getdata(filepath, header=True)
+ except Exception:
+ return None, {"star_count": -1, "average_hfr": -1, "max_star": -1, "min_star": -1, "average_star": -1}
+
+ is_color = 'BAYERPAT' in header
+ if is_color:
+ img = debayer_image(img, header['BAYERPAT'])
+
+ img = resize_image(img, resize_size)
+ img = normalize_image(img)
+
+ # Apply image enhancements
+ img = enhance_image(img, config)
+
+ if config['do_stretch']:
+ img = stretch_image(img, is_color)
+
+ if config['do_star_count']:
+ img, star_count, avg_hfr, area_range = detect_stars(
+ img, config['remove_hotpixel'], config['remove_noise'], config['do_star_mark'])
+ return img, {
+ "star_count": float(star_count),
+ "average_hfr": avg_hfr,
+ "max_star": area_range['max'],
+ "min_star": area_range['min'],
+ "average_star": area_range['average']
+ }
+
+ return img, {"star_count": -1, "average_hfr": -1, "max_star": -1, "min_star": -1, "average_star": -1}
+
+
+def ImageStretchAndStarCount_Optim(filepath: Path, config: Dict[str, bool], resize_size: int = 2048,
+ jpg_file: Optional[Path] = None, star_file: Optional[Path] = None) -> Tuple[Optional[np.ndarray], Dict[str, float]]:
+ img, result = process_image(filepath, config, resize_size)
+
+ if jpg_file and img is not None:
+ cv2.imwrite(str(jpg_file), img)
+ if star_file:
+ with star_file.open('w') as f:
+ json.dump(result, f)
+
+ return img, result
+
+
+def StreamingDebayerAndStretch(fits_data: bytearray, width: int, height: int, config: Dict[str, bool],
+ resize_size: int = 2048, bayer_type: Optional[str] = None) -> Optional[np.ndarray]:
+ img = np.frombuffer(fits_data, dtype=np.uint16).reshape(height, width)
+
+ if bayer_type:
+ img = debayer_image(img, bayer_type)
+
+ img = resize_image(img, resize_size)
+ img = normalize_image(img)
+
+ # Apply image enhancements
+ img = enhance_image(img, config)
+
+ if config.get('do_stretch', True):
+ img = stretch_image(img, len(img.shape) == 3)
+
+ return img
+
+
+def StreamingDebayer(fits_data: bytearray, width: int, height: int, config: Dict[str, bool],
+ resize_size: int = 2048, bayer_type: Optional[str] = None) -> Optional[np.ndarray]:
+ img = np.frombuffer(fits_data, dtype=np.uint16).reshape(height, width)
+
+ if bayer_type:
+ img = debayer_image(img, bayer_type)
+
+ img = resize_image(img, resize_size)
+ img = normalize_image(img)
+
+ # Apply image enhancements
+ img = enhance_image(img, config)
+
+ return img
diff --git a/pysrc/image/astroalign/astroalign.py b/pysrc/image/astroalign/astroalign.py
new file mode 100644
index 00000000..2e52a688
--- /dev/null
+++ b/pysrc/image/astroalign/astroalign.py
@@ -0,0 +1,553 @@
+"""
+Module for starfield image registration using triangle-based invariants.
+
+This module provides functions to estimate and apply geometric transformations between
+two sets of starfield images. The core technique relies on computing invariant features
+(triangle invariants) from sets of nearest neighbors of stars in both images. The
+transformation between the two images is then estimated using these invariant features.
+
+The key functionalities include:
+- Estimating transformation between two images
+- Applying the estimated transformation to align images
+- Extracting source positions (stars) from images
+- RANSAC algorithm for robust model estimation
+
+Author: Max Qian
+Version: 2.6.0 (lithium)
+"""
+
+__version__ = "2.6.0"
+
+__all__ = [
+ "MIN_MATCHES_FRACTION",
+ "MaxIterError",
+ "NUM_NEAREST_NEIGHBORS",
+ "PIXEL_TOL",
+ "apply_transform",
+ "estimate_transform",
+ "find_transform",
+ "matrix_transform",
+ "register",
+]
+
+from collections import Counter
+from functools import partial
+from itertools import combinations
+from typing import Any, Tuple, Union
+
+import sep_pjw as sep
+import numpy as np
+from numpy.typing import NDArray
+from scipy.spatial import KDTree
+from skimage.transform import estimate_transform, matrix_transform, warp
+
+try:
+ import bottleneck as bn
+except ImportError:
+ HAS_BOTTLENECK = False
+else:
+ HAS_BOTTLENECK = True
+
+PIXEL_TOL = 2
+"""The pixel distance tolerance to assume two invariant points are the same.
+
+Default: 2
+"""
+
+MIN_MATCHES_FRACTION = 0.8
+"""The minimum fraction of triangle matches to accept a transformation.
+
+If the minimum fraction yields more than 10 triangles, 10 is used instead.
+
+Default: 0.8
+"""
+
+NUM_NEAREST_NEIGHBORS = 5
+"""
+The number of nearest neighbors of a given star (including itself) to construct
+the triangle invariants.
+
+Default: 5
+"""
+
+_default_median = bn.nanmedian if HAS_BOTTLENECK else np.nanmedian # pragma: no cover
+"""
+Default median function when/if optional bottleneck is available
+"""
+
+_default_average = bn.nanmean if HAS_BOTTLENECK else np.nanmean # pragma: no cover
+"""
+Default mean function when/if optional bottleneck is available
+"""
+
+
+def _invariantfeatures(x1: NDArray, x2: NDArray, x3: NDArray) -> list[float]:
+ """
+ Given 3 points x1, x2, x3, return the invariant features for the set.
+
+ Invariant features are ratios of side lengths of the triangles formed by these points,
+ sorted by size. These features are scale-invariant and can be used to compare star
+ positions between images.
+
+ Args:
+ x1, x2, x3: 2D coordinates of points in the source image.
+
+ Returns:
+ List containing two invariant feature values derived from the triangle side ratios.
+ """
+ sides = np.sort(
+ [np.linalg.norm(x1 - x2), np.linalg.norm(x2 - x3), np.linalg.norm(x1 - x3)])
+ return [sides[2] / sides[1], sides[1] / sides[0]]
+
+
+def _arrangetriplet(sources: NDArray, vertex_indices: tuple[int, int, int]) -> NDArray:
+ """
+ Reorder the given triplet of vertex indices according to the length of their sides.
+
+ This function returns the indices in a consistent order based on the triangle's side
+ lengths. It ensures that the triangle invariants are consistently computed.
+
+ Args:
+ sources: Array of source star positions.
+ vertex_indices: Indices of the three vertices that form the triangle.
+
+ Returns:
+ Reordered array of vertex indices based on side lengths.
+ """
+ ind1, ind2, ind3 = vertex_indices
+ x1, x2, x3 = sources[vertex_indices]
+
+ side_ind = np.array([(ind1, ind2), (ind2, ind3), (ind3, ind1)])
+ side_lengths = [np.linalg.norm(
+ x1 - x2), np.linalg.norm(x2 - x3), np.linalg.norm(x3 - x1)]
+ l1_ind, l2_ind, l3_ind = np.argsort(side_lengths)
+
+ count = Counter(side_ind[[l1_ind, l2_ind]].flatten())
+ a = count.most_common(1)[0][0]
+ count = Counter(side_ind[[l2_ind, l3_ind]].flatten())
+ b = count.most_common(1)[0][0]
+ count = Counter(side_ind[[l3_ind, l1_ind]].flatten())
+ c = count.most_common(1)[0][0]
+
+ return np.array([a, b, c])
+
+
+def _generate_invariants(sources: NDArray) -> Tuple[NDArray, NDArray]:
+ """
+ Generate invariant features and the corresponding triangles from a set of source points.
+
+ This function constructs triangles from the nearest neighbors of each source point and
+ calculates their invariant features. The invariants are used for matching between images.
+
+ Args:
+ sources: Array of source star positions.
+
+ Returns:
+ A tuple containing the unique invariant features and the corresponding triangle vertices.
+ """
+ arrange = partial(_arrangetriplet, sources=sources)
+
+ inv = []
+ triang_vrtx = []
+ coordtree = KDTree(sources)
+ knn = min(len(sources), NUM_NEAREST_NEIGHBORS)
+ for asrc in sources:
+ _, indx = coordtree.query(asrc, knn)
+
+ all_asterism_triang = [arrange(vertex_indices=list(cmb))
+ for cmb in combinations(indx, 3)]
+ triang_vrtx.extend(all_asterism_triang)
+
+ inv.extend([_invariantfeatures(*sources[triplet])
+ for triplet in all_asterism_triang])
+
+ uniq_ind = [pos for (pos, elem) in enumerate(inv)
+ if elem not in inv[pos + 1:]]
+ inv_uniq = np.array(inv)[uniq_ind]
+ triang_vrtx_uniq = np.array(triang_vrtx)[uniq_ind]
+
+ return inv_uniq, triang_vrtx_uniq
+
+
+class _MatchTransform:
+ """
+ A class to manage the fitting of a geometric transformation using matched invariant points.
+
+ This class handles the estimation of the 2D similarity transformation between
+ two sets of points, and computes errors between estimated and actual points.
+ """
+
+ def __init__(self, source: NDArray, target: NDArray):
+ """
+ Initialize the transformation model with source and target control points.
+
+ Args:
+ source: Source control points.
+ target: Target control points.
+ """
+ self.source = source
+ self.target = target
+
+ def fit(self, data: NDArray) -> Any:
+ """
+ Estimate the best 2D similarity transformation from the matched points in data.
+
+ Args:
+ data: Matched point pairs from source and target.
+
+ Returns:
+ A similarity transform object.
+ """
+ d1, d2, d3 = data.shape
+ s, d = data.reshape(d1 * d2, d3).T
+ return estimate_transform("similarity", self.source[s], self.target[d])
+
+ def get_error(self, data: NDArray, approx_t: Any) -> NDArray:
+ """
+ Calculate the maximum residual error for the matched points given the estimated transform.
+
+ Args:
+ data: Matched point pairs.
+ approx_t: Estimated transformation model.
+
+ Returns:
+ Maximum residual error for each set of matched points.
+ """
+ d1, d2, d3 = data.shape
+ s, d = data.reshape(d1 * d2, d3).T
+ resid = approx_t.residuals(
+ self.source[s], self.target[d]).reshape(d1, d2)
+ return resid.max(axis=1)
+
+
+def _data(image: Union[NDArray, Any]) -> NDArray:
+ """
+ Retrieve the underlying 2D pixel data from the image.
+
+ Args:
+ image: The input image.
+
+ Returns:
+ The pixel data as a 2D NumPy array.
+ """
+ if hasattr(image, "data") and isinstance(image.data, np.ndarray):
+ return image.data
+ return np.asarray(image)
+
+
+def _mask(image: Union[NDArray, Any]) -> Union[NDArray, None]:
+ """
+ Retrieve the mask from the image, if available.
+
+ Args:
+ image: The input image.
+
+ Returns:
+ The mask as a 2D NumPy array, or None if no mask is present.
+ """
+ if hasattr(image, "mask"):
+ thenp_mask = np.asarray(image.mask)
+ return thenp_mask if thenp_mask.ndim == 2 else np.logical_or.reduce(thenp_mask, axis=-1)
+ return None
+
+
+def _bw(image: NDArray) -> NDArray:
+ """
+ Convert the input image to a 2D grayscale image.
+
+ Args:
+ image: Input image, possibly with multiple channels.
+
+ Returns:
+ Grayscale 2D NumPy array.
+ """
+ return image if image.ndim == 2 else _default_average(image, axis=-1)
+
+
+def _shape(image: NDArray) -> tuple[int, int]:
+ """
+ Get the shape of the image, ignoring channels.
+
+ Args:
+ image: Input image.
+
+ Returns:
+ A tuple representing the 2D shape (height, width) of the image.
+ """
+ return image.shape if image.ndim == 2 else image.shape[:2]
+
+
+def find_transform(
+ source: Union[NDArray, Any],
+ target: Union[NDArray, Any],
+ max_control_points: int = 50,
+ detection_sigma: int = 5,
+ min_area: int = 5
+) -> Tuple[Any, Tuple[NDArray, NDArray]]:
+ """
+ Estimate the geometric transformation between the source and target images.
+
+ This function identifies control points (stars) in both the source and target images,
+ computes their triangle-based invariant features, and finds the best transformation
+ to align the source to the target using a RANSAC-based method.
+
+ Args:
+ source: The source image to be transformed.
+ target: The target image to match the source to.
+ max_control_points: Maximum number of control points to use for transformation.
+ detection_sigma: Sigma threshold for detecting control points.
+ min_area: Minimum area for detecting sources in the image.
+
+ Returns:
+ A tuple containing the estimated transformation object and a tuple of matched control points
+ from the source and target images.
+
+ Raises:
+ ValueError: If fewer than 3 control points are found in either image.
+ TypeError: If the input type of source or target is unsupported.
+ """
+ try:
+ source_controlp = (np.array(source)[:max_control_points] if len(_data(source)[0]) == 2
+ else _find_sources(_bw(_data(source)), detection_sigma, min_area, _mask(source))[:max_control_points])
+ except Exception:
+ raise TypeError("Input type for source not supported.")
+
+ try:
+ target_controlp = (np.array(target)[:max_control_points] if len(_data(target)[0]) == 2
+ else _find_sources(_bw(_data(target)), detection_sigma, min_area, _mask(target))[:max_control_points])
+ except Exception:
+ raise TypeError("Input type for target not supported.")
+
+ if len(source_controlp) < 3 or len(target_controlp) < 3:
+ raise ValueError(
+ "Reference stars in source or target image are less than the minimum value (3).")
+
+ source_invariants, source_asterisms = _generate_invariants(source_controlp)
+ target_invariants, target_asterisms = _generate_invariants(target_controlp)
+
+ source_tree = KDTree(source_invariants)
+ target_tree = KDTree(target_invariants)
+
+ matches_list = source_tree.query_ball_tree(target_tree, r=0.1)
+
+ matches = [list(zip(t1, t2)) for t1, t2_list in zip(
+ source_asterisms, matches_list) for t2 in target_asterisms[t2_list]]
+ matches = np.array(matches)
+
+ inv_model = _MatchTransform(source_controlp, target_controlp)
+ n_invariants = len(matches)
+ min_matches = max(1, min(10, int(n_invariants * MIN_MATCHES_FRACTION)))
+
+ if (len(source_controlp) == 3 or len(target_controlp) == 3) and len(matches) == 1:
+ best_t = inv_model.fit(matches)
+ inlier_ind = np.arange(len(matches))
+ else:
+ best_t, inlier_ind = _ransac(
+ matches, inv_model, PIXEL_TOL, min_matches)
+
+ triangle_inliers = matches[inlier_ind]
+ inl_arr = triangle_inliers.reshape(-1, 2)
+ inl_unique = set(map(tuple, inl_arr))
+
+ inl_dict = {}
+ for s_i, t_i in inl_unique:
+ s_vertex = source_controlp[s_i]
+ t_vertex = target_controlp[t_i]
+ t_vertex_pred = matrix_transform(s_vertex, best_t.params)
+ error = np.linalg.norm(t_vertex_pred - t_vertex)
+
+ if s_i not in inl_dict or error < inl_dict[s_i][1]:
+ inl_dict[s_i] = (t_i, error)
+
+ inl_arr_unique = np.array([[s_i, t_i]
+ for s_i, (t_i, _) in inl_dict.items()])
+ s, d = inl_arr_unique.T
+
+ return best_t, (source_controlp[s], target_controlp[d])
+
+
+def apply_transform(
+ transform: Any,
+ source: Union[NDArray, Any],
+ target: Union[NDArray, Any],
+ fill_value: Union[float, None] = None,
+ propagate_mask: bool = False
+) -> Tuple[NDArray, NDArray]:
+ """
+ Apply the estimated transformation to align the source image to the target image.
+
+ The transformation is applied to the source image, and an optional mask is propagated
+ if requested. The function returns the aligned source image and a binary footprint
+ of the transformed region.
+
+ Args:
+ transform: The transformation to apply.
+ source: The source image to be transformed.
+ target: The target image to align the source to.
+ fill_value: Value to fill in regions outside the source image after transformation.
+ propagate_mask: Whether to propagate the source mask after transformation.
+
+ Returns:
+ A tuple containing the aligned source image and the transformation footprint.
+ """
+ source_data = _data(source)
+ target_shape = _data(target).shape
+
+ aligned_image = warp(
+ source_data,
+ inverse_map=transform.inverse,
+ output_shape=target_shape,
+ order=3,
+ mode="constant",
+ cval=_default_median(source_data),
+ clip=True,
+ preserve_range=True,
+ )
+
+ footprint = warp(
+ np.zeros(_shape(source_data), dtype="float32"),
+ inverse_map=transform.inverse,
+ output_shape=target_shape,
+ cval=1.0,
+ )
+ footprint = footprint > 0.4
+
+ source_mask = _mask(source)
+ if source_mask is not None and propagate_mask:
+ if source_mask.shape == source_data.shape:
+ source_mask_rot = warp(
+ source_mask.astype("float32"),
+ inverse_map=transform.inverse,
+ output_shape=target_shape,
+ cval=1.0,
+ )
+ source_mask_rot = source_mask_rot > 0.4
+ footprint |= source_mask_rot
+
+ if fill_value is not None:
+ aligned_image[footprint] = fill_value
+
+ return aligned_image, footprint
+
+
+def register(
+ source: Union[NDArray, Any],
+ target: Union[NDArray, Any],
+ fill_value: Union[float, None] = None,
+ propagate_mask: bool = False,
+ max_control_points: int = 50,
+ detection_sigma: int = 5,
+ min_area: int = 5
+) -> Tuple[NDArray, NDArray]:
+ """
+ Register and align the source image to the target image using triangle invariants.
+
+ This function estimates the transformation between the source and target images,
+ applies the transformation, and returns the aligned source image along with
+ the transformation footprint.
+
+ Args:
+ source: The source image to be aligned.
+ target: The target image for alignment.
+ fill_value: Value to fill in regions outside the source image after transformation.
+ propagate_mask: Whether to propagate the source mask after transformation.
+ max_control_points: Maximum number of control points to use for transformation.
+ detection_sigma: Sigma threshold for detecting control points.
+ min_area: Minimum area for detecting sources in the image.
+
+ Returns:
+ A tuple containing the aligned source image and the transformation footprint.
+ """
+ t, _ = find_transform(
+ source=source,
+ target=target,
+ max_control_points=max_control_points,
+ detection_sigma=detection_sigma,
+ min_area=min_area,
+ )
+ return apply_transform(t, source, target, fill_value, propagate_mask)
+
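+# Illustrative end-to-end call (array names are hypothetical):
+#
+#   aligned, footprint = register(source_img, target_img, fill_value=0)
+#   t, (s_pts, t_pts) = find_transform(source_img, target_img)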
+
+def _find_sources(img: NDArray, detection_sigma: int = 5, min_area: int = 5, mask: Union[NDArray, None] = None) -> NDArray:
+ """
+ Detect bright sources (e.g., stars) in the image using SEP (Source Extractor).
+
+ This function returns the coordinates of sources sorted by brightness.
+
+ Args:
+ img: The input image in which to detect sources.
+ detection_sigma: Sigma threshold for source detection.
+ min_area: Minimum area for detecting sources.
+ mask: Optional mask for ignoring certain parts of the image.
+
+ Returns:
+ A NumPy array of detected source coordinates (x, y), sorted by brightness.
+ """
+ image = img.astype("float32")
+ bkg = sep.Background(image, mask=mask)
+ thresh = detection_sigma * bkg.globalrms
+ sources = sep.extract(image - bkg.back(), thresh,
+ minarea=min_area, mask=mask)
+ sources.sort(order="flux")
+ return np.array([[asrc["x"], asrc["y"]] for asrc in sources[::-1]])
+
+
+class MaxIterError(RuntimeError):
+ """
+ Custom error raised if the maximum number of iterations is reached during the RANSAC process.
+
+ This exception indicates that the RANSAC algorithm has exhausted all possible
+ matching triangles without finding an acceptable transformation.
+ """
+ pass
+
+
+def _ransac(data: NDArray, model: Any, thresh: float, min_matches: int) -> Tuple[Any, NDArray]:
+ """
+ Fit a model to data using the RANSAC (Random Sample Consensus) algorithm.
+
+ This robust method estimates the transformation model by iteratively fitting
+ to subsets of data and discarding outliers.
+
+ Args:
+ data: Matched point pairs.
+ model: The transformation model to fit.
+ thresh: Error threshold to consider a data point as an inlier.
+ min_matches: Minimum number of inliers required to accept the model.
+
+ Returns:
+ A tuple containing the best-fit model and the indices of inliers.
+
+ Raises:
+ MaxIterError: If the maximum number of iterations is reached without finding
+ an acceptable transformation.
+ """
+ n_data = data.shape[0]
+ all_idxs = np.arange(n_data)
+ np.random.default_rng().shuffle(all_idxs)
+
+ for iter_i in range(n_data):
+ maybe_idxs = all_idxs[iter_i:iter_i + 1]
+ test_idxs = np.concatenate([all_idxs[:iter_i], all_idxs[iter_i + 1:]])
+ maybeinliers = data[maybe_idxs, :]
+ test_points = data[test_idxs, :]
+ maybemodel = model.fit(maybeinliers)
+ test_err = model.get_error(test_points, maybemodel)
+ also_idxs = test_idxs[test_err < thresh]
+ alsoinliers = data[also_idxs, :]
+ if len(alsoinliers) >= min_matches:
+ good_data = np.concatenate((maybeinliers, alsoinliers))
+ good_fit = model.fit(good_data)
+ break
+ else:
+ raise MaxIterError(
+ "List of matching triangles exhausted before an acceptable transformation was found")
+
+ better_fit = good_fit
+ for _ in range(3):
+ test_err = model.get_error(data, better_fit)
+ better_inlier_idxs = np.arange(n_data)[test_err < thresh]
+ better_data = data[better_inlier_idxs]
+ better_fit = model.fit(better_data)
+
+ return better_fit, better_inlier_idxs
diff --git a/pysrc/image/auto_histogram/__init__.py b/pysrc/image/auto_histogram/__init__.py
new file mode 100644
index 00000000..96186445
--- /dev/null
+++ b/pysrc/image/auto_histogram/__init__.py
@@ -0,0 +1,2 @@
+from .histogram import auto_histogram
+from .utils import save_image, load_image
diff --git a/pysrc/image/auto_histogram/histogram.py b/pysrc/image/auto_histogram/histogram.py
new file mode 100644
index 00000000..a402b7ff
--- /dev/null
+++ b/pysrc/image/auto_histogram/histogram.py
@@ -0,0 +1,111 @@
+import os
+import cv2
+import numpy as np
+from typing import Optional, List, Tuple
+from .utils import save_image, load_image
+
+
+def auto_histogram(image: Optional[np.ndarray] = None, clip_shadow: float = 0.01, clip_highlight: float = 0.01,
+ target_median: int = 128, method: str = 'gamma', apply_clahe: bool = False,
+ clahe_clip_limit: float = 2.0, clahe_tile_grid_size: Tuple[int, int] = (8, 8),
+ apply_noise_reduction: bool = False, noise_reduction_method: str = 'median',
+ apply_sharpening: bool = False, sharpening_strength: float = 1.0,
+ batch_process: bool = False, file_list: Optional[List[str]] = None) -> Optional[List[np.ndarray]]:
+ """
+ Apply automated histogram transformations and enhancements to an image or a batch of images.
+
+ Parameters:
+ - image: Input image, grayscale or RGB.
+ - clip_shadow: Percentage of shadow pixels to clip.
+ - clip_highlight: Percentage of highlight pixels to clip.
+ - target_median: Target median value for histogram stretching.
+ - method: Stretching method ('gamma', 'logarithmic', 'mtf').
+ - apply_clahe: Apply CLAHE (Contrast Limited Adaptive Histogram Equalization).
+ - clahe_clip_limit: CLAHE clip limit.
+ - clahe_tile_grid_size: CLAHE grid size.
+ - apply_noise_reduction: Apply noise reduction.
+ - noise_reduction_method: Noise reduction method ('median', 'gaussian').
+ - apply_sharpening: Apply image sharpening.
+ - sharpening_strength: Strength of sharpening.
+ - batch_process: Enable batch processing mode.
+ - file_list: List of file paths for batch processing.
+
+ Returns:
+ - Processed image or list of processed images.
+ """
+ def histogram_clipping(image: np.ndarray, clip_shadow: float, clip_highlight: float) -> np.ndarray:
+ flat = image.flatten()
+ low_val = np.percentile(flat, clip_shadow * 100)
+ high_val = np.percentile(flat, 100 - clip_highlight * 100)
+ return np.clip(image, low_val, high_val)
+
+    def gamma_transformation(image: np.ndarray, target_median: int) -> np.ndarray:
+        median_val = np.median(image)
+        gamma = np.log(target_median / 255.0) / np.log(median_val / 255.0)
+        return np.array(255 * (image / 255.0) ** gamma, dtype='uint8')
+
+    def logarithmic_transformation(image: np.ndarray) -> np.ndarray:
+        image_f = image.astype(np.float64)  # avoid uint8 overflow in 1 + image
+        c = 255 / np.log(1 + np.max(image_f))
+        return np.array(c * np.log(1 + image_f), dtype='uint8')
+
+    def mtf_transformation(image: np.ndarray, target_median: int) -> np.ndarray:
+        median_val = np.median(image)
+        mtf = target_median / median_val
+        return np.clip(image * mtf, 0, 255).astype('uint8')
+
+ def apply_clahe_method(image: np.ndarray, clip_limit: float, tile_grid_size: Tuple[int, int]) -> np.ndarray:
+ clahe = cv2.createCLAHE(clipLimit=clip_limit,
+ tileGridSize=tile_grid_size)
+ if len(image.shape) == 2:
+ return clahe.apply(image)
+ else:
+ return cv2.merge([clahe.apply(channel) for channel in cv2.split(image)])
+
+ def noise_reduction(image: np.ndarray, method: str) -> np.ndarray:
+ if method == 'median':
+ return cv2.medianBlur(image, 3)
+ elif method == 'gaussian':
+ return cv2.GaussianBlur(image, (3, 3), 0)
+ else:
+ raise ValueError("Invalid noise reduction method specified.")
+
+ def sharpen_image(image: np.ndarray, strength: float) -> np.ndarray:
+ kernel = np.array([[-1, -1, -1], [-1, 9 + strength, -1], [-1, -1, -1]])
+ return cv2.filter2D(image, -1, kernel)
+
+ def process_single_image(image: np.ndarray) -> np.ndarray:
+ if apply_noise_reduction:
+ image = noise_reduction(image, noise_reduction_method)
+
+ image = histogram_clipping(image, clip_shadow, clip_highlight)
+
+ if method == 'gamma':
+ image = gamma_transformation(image, target_median)
+ elif method == 'logarithmic':
+ image = logarithmic_transformation(image)
+ elif method == 'mtf':
+ image = mtf_transformation(image, target_median)
+ else:
+ raise ValueError("Invalid method specified.")
+
+ if apply_clahe:
+ image = apply_clahe_method(
+ image, clahe_clip_limit, clahe_tile_grid_size)
+
+ if apply_sharpening:
+ image = sharpen_image(image, sharpening_strength)
+
+ return image
+
+ if batch_process:
+ if file_list is None:
+ raise ValueError(
+ "File list cannot be None when batch processing is enabled.")
+ processed_images = []
+        for file_path in file_list:
+            image = load_image(file_path, grayscale=(method != 'mtf'))
+            processed_image = process_single_image(image)
+            processed_images.append(processed_image)
+            # Prefix only the file name; prefixing the whole path would point at a non-existent directory
+            save_image(os.path.join(os.path.dirname(file_path),
+                                    f'processed_{os.path.basename(file_path)}'), processed_image)
+        return processed_images
+ else:
+ return process_single_image(image)
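+
+
+# Usage sketch: the file names are placeholders. Single-image mode expects an
+# already-loaded array; batch mode reads the files in `file_list` itself.
+if __name__ == "__main__":
+    img = load_image('input.jpg', grayscale=True)
+    stretched = auto_histogram(img, method='gamma', target_median=128,
+                               apply_clahe=True, apply_sharpening=True)
+    save_image('stretched.jpg', stretched)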
diff --git a/pysrc/image/auto_histogram/processing.py b/pysrc/image/auto_histogram/processing.py
new file mode 100644
index 00000000..3e64425d
--- /dev/null
+++ b/pysrc/image/auto_histogram/processing.py
@@ -0,0 +1,28 @@
+from .histogram import auto_histogram
+from .utils import save_image, load_image
+import os
+from typing import List
+
+
+def process_directory(directory: str, output_directory: str, method: str = 'gamma', **kwargs):
+ """
+ Process all images in a directory using the auto_histogram function.
+
+ :param directory: Input directory containing images to process.
+ :param output_directory: Directory to save processed images.
+ :param method: Histogram stretching method ('gamma', 'logarithmic', 'mtf').
+ :param kwargs: Additional parameters for auto_histogram.
+ """
+ if not os.path.exists(output_directory):
+ os.makedirs(output_directory)
+
+ file_list = [os.path.join(directory, file) for file in os.listdir(
+ directory) if file.endswith(('.jpg', '.png'))]
+
+ processed_images = auto_histogram(
+ None, method=method, batch_process=True, file_list=file_list, **kwargs)
+
+ for file, image in zip(file_list, processed_images):
+ output_path = os.path.join(
+ output_directory, f'processed_{os.path.basename(file)}')
+ save_image(output_path, image)
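+
+
+# Usage sketch: directory names are placeholders; extra keyword arguments are
+# forwarded to auto_histogram.
+if __name__ == "__main__":
+    process_directory('raw_frames', 'stretched_frames', method='gamma',
+                      apply_clahe=True, target_median=140)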
diff --git a/pysrc/image/auto_histogram/utils.py b/pysrc/image/auto_histogram/utils.py
new file mode 100644
index 00000000..f62a21c4
--- /dev/null
+++ b/pysrc/image/auto_histogram/utils.py
@@ -0,0 +1,23 @@
+import cv2
+import numpy as np
+from typing import Tuple, Union
+
+def save_image(filepath: str, image: np.ndarray):
+ """
+ Save an image to the specified filepath.
+
+ :param filepath: Path to save the image.
+ :param image: Image data.
+ """
+ cv2.imwrite(filepath, image)
+
+def load_image(filepath: str, grayscale: bool = False) -> np.ndarray:
+ """
+ Load an image from the specified filepath.
+
+ :param filepath: Path to load the image from.
+ :param grayscale: Load image as grayscale if True.
+ :return: Loaded image.
+ """
+ flags = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR
+ return cv2.imread(filepath, flags)
diff --git a/pysrc/image/channel/__init__.py b/pysrc/image/channel/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/pysrc/image/channel/combination.py b/pysrc/image/channel/combination.py
new file mode 100644
index 00000000..a072018c
--- /dev/null
+++ b/pysrc/image/channel/combination.py
@@ -0,0 +1,67 @@
+from PIL import Image
+import numpy as np
+from skimage import color
+import cv2
+
+def resize_to_match(image, target_size):
+    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the equivalent filter
+    return image.resize(target_size, Image.LANCZOS)
+
+def load_image_as_gray(path):
+ return Image.open(path).convert('L')
+
+def combine_channels(channels, color_space='RGB'):
+ match color_space:
+ case 'RGB':
+ return Image.merge("RGB", channels)
+
+ case 'LAB':
+ lab_image = Image.merge("LAB", channels)
+ return lab_image.convert('RGB')
+
+ case 'HSV':
+ hsv_image = Image.merge("HSV", channels)
+ return hsv_image.convert('RGB')
+
+ case 'HSI':
+ hsi_image = np.dstack([np.array(ch) / 255.0 for ch in channels])
+ rgb_image = color.hsv2rgb(hsi_image) # Scikit-image doesn't have direct HSI support; using HSV as proxy
+ return Image.fromarray((rgb_image * 255).astype(np.uint8))
+
+ case _:
+ raise ValueError(f"Unsupported color space: {color_space}")
+
+def channel_combination(src1_path, src2_path, src3_path, color_space='RGB'):
+ # Load and resize images to match
+ channel_1 = load_image_as_gray(src1_path)
+ channel_2 = load_image_as_gray(src2_path)
+ channel_3 = load_image_as_gray(src3_path)
+
+ # Automatically resize images to match the size of the first image
+ size = channel_1.size
+ channel_2 = resize_to_match(channel_2, size)
+ channel_3 = resize_to_match(channel_3, size)
+
+ # Combine the channels
+ combined_image = combine_channels([channel_1, channel_2, channel_3], color_space=color_space)
+
+ return combined_image
+
+# Example usage
+if __name__ == "__main__":
+    # Paths of the images that supply each channel
+    src1_path = 'channel_R.png'  # R channel in RGB, or L channel in Lab
+    src2_path = 'channel_G.png'  # G channel in RGB, or a* channel in Lab
+    src3_path = 'channel_B.png'  # B channel in RGB, or b* channel in Lab
+
+    # Combine the channels and save the results
+ combined_rgb = channel_combination(src1_path, src2_path, src3_path, color_space='RGB')
+ combined_rgb.save('combined_rgb.png')
+
+ combined_lab = channel_combination(src1_path, src2_path, src3_path, color_space='LAB')
+ combined_lab.save('combined_lab.png')
+
+ combined_hsv = channel_combination(src1_path, src2_path, src3_path, color_space='HSV')
+ combined_hsv.save('combined_hsv.png')
+
+ combined_hsi = channel_combination(src1_path, src2_path, src3_path, color_space='HSI')
+ combined_hsi.save('combined_hsi.png')
diff --git a/pysrc/image/channel/extraction.py b/pysrc/image/channel/extraction.py
new file mode 100644
index 00000000..fd85d782
--- /dev/null
+++ b/pysrc/image/channel/extraction.py
@@ -0,0 +1,138 @@
+import cv2
+import numpy as np
+import os
+from matplotlib import pyplot as plt
+
+
+def extract_channels(image, color_space='RGB'):
+ channels = {}
+
+    if color_space == 'RGB':
+        # OpenCV loads images as BGR, so split in that order
+        channels['B'], channels['G'], channels['R'] = cv2.split(image)
+
+ elif color_space == 'XYZ':
+ xyz_image = cv2.cvtColor(image, cv2.COLOR_BGR2XYZ)
+ channels['X'], channels['Y'], channels['Z'] = cv2.split(xyz_image)
+
+ elif color_space == 'Lab':
+ lab_image = cv2.cvtColor(image, cv2.COLOR_BGR2Lab)
+ channels['L*'], channels['a*'], channels['b*'] = cv2.split(lab_image)
+
+    elif color_space == 'LCh':
+        lab_image = cv2.cvtColor(image, cv2.COLOR_BGR2Lab)
+        L, a, b = cv2.split(lab_image)
+        # cartToPolar needs float input and returns (magnitude, angle);
+        # 8-bit Lab stores a*/b* with a +128 offset, so recentre first
+        c, h = cv2.cartToPolar(a.astype(np.float32) - 128, b.astype(np.float32) - 128)
+        channels['L*'] = L
+        channels['c*'] = c
+        channels['h*'] = h
+
+ elif color_space == 'HSV':
+ hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
+ channels['H'], channels['S'], channels['V'] = cv2.split(hsv_image)
+
+ elif color_space == 'HSI':
+ hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
+ H, S, V = cv2.split(hsv_image)
+ I = V.copy()
+ channels['H'] = H
+ channels['Si'] = S
+ channels['I'] = I
+
+ elif color_space == 'YUV':
+ yuv_image = cv2.cvtColor(image, cv2.COLOR_BGR2YUV)
+ channels['Y'], channels['U'], channels['V'] = cv2.split(yuv_image)
+
+    elif color_space == 'YCbCr':
+        ycbcr_image = cv2.cvtColor(image, cv2.COLOR_BGR2YCrCb)
+        # OpenCV's YCrCb ordering: Y, Cr, Cb
+        channels['Y'], channels['Cr'], channels['Cb'] = cv2.split(ycbcr_image)
+
+    elif color_space == 'HSL':
+        hsl_image = cv2.cvtColor(image, cv2.COLOR_BGR2HLS)
+        # OpenCV's HLS ordering: H, L, S
+        channels['H'], channels['L'], channels['S'] = cv2.split(hsl_image)
+
+    elif color_space == 'CMYK':
+        # Approximate CMY as the complement of the RGB channels, then derive K
+        B, G, R = cv2.split(image)
+        C, M, Y = 255 - R, 255 - G, 255 - B
+        K = np.minimum(C, np.minimum(M, Y))
+ channels['C'] = C
+ channels['M'] = M
+ channels['Y'] = Y
+ channels['K'] = K
+
+ return channels
+
+
+def show_histogram(channel_data, title='Channel Histogram'):
+ plt.figure()
+ plt.title(title)
+ plt.xlabel('Pixel Value')
+ plt.ylabel('Frequency')
+ plt.hist(channel_data.ravel(), bins=256, range=[0, 256])
+ plt.show()
+
+
+def merge_channels(channels):
+ merged_image = None
+ channel_list = list(channels.values())
+ if len(channel_list) >= 3:
+ merged_image = cv2.merge(channel_list[:3])
+ elif len(channel_list) == 2:
+ merged_image = cv2.merge(
+ [channel_list[0], channel_list[1], np.zeros_like(channel_list[0])])
+ elif len(channel_list) == 1:
+ merged_image = channel_list[0]
+ return merged_image
+
+
+def process_directory(input_dir, output_dir, color_space='RGB'):
+ if not os.path.exists(output_dir):
+ os.makedirs(output_dir)
+
+ for filename in os.listdir(input_dir):
+ if filename.endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tiff')):
+ image_path = os.path.join(input_dir, filename)
+ image = cv2.imread(image_path)
+ base_name = os.path.splitext(filename)[0]
+ extracted_channels = extract_channels(image, color_space)
+ for channel_name, channel_data in extracted_channels.items():
+ save_path = os.path.join(
+ output_dir, f"{base_name}_{channel_name}.png")
+ cv2.imwrite(save_path, channel_data)
+ show_histogram(
+ channel_data, title=f"{base_name} - {channel_name}")
+ print(f"Saved {save_path}")
+
+
+def save_channels(channels, base_name='output'):
+ for channel_name, channel_data in channels.items():
+ filename = f"{base_name}_{channel_name}.png"
+ cv2.imwrite(filename, channel_data)
+ print(f"Saved {filename}")
+
+
+def display_image(title, image):
+ cv2.imshow(title, image)
+ cv2.waitKey(0)
+ cv2.destroyAllWindows()
+
+
+# Example usage:
+if __name__ == "__main__":
+    image = cv2.imread('input_image.png')
+    extracted_channels = extract_channels(image, color_space='Lab')
+
+    # Show histograms
+    for name, channel in extracted_channels.items():
+        show_histogram(channel, title=f"{name} Histogram")
+
+    # Save channels
+    save_channels(extracted_channels, base_name='output_image')
+
+    # Merge channels
+    merged_image = merge_channels(extracted_channels)
+    if merged_image is not None:
+        display_image('Merged Image', merged_image)
+        cv2.imwrite('merged_image.png', merged_image)
+
+    # Process directory
+    process_directory('input_images', 'output_images', color_space='HSV')
diff --git a/pysrc/image/color_calibration/__init__.py b/pysrc/image/color_calibration/__init__.py
new file mode 100644
index 00000000..0a89bce9
--- /dev/null
+++ b/pysrc/image/color_calibration/__init__.py
@@ -0,0 +1,12 @@
+from .processing import ImageProcessing
+from .calibration import ColorCalibration
+from .utils import select_roi
+from .io import read_image, save_image
+
+__all__ = [
+ "ImageProcessing",
+ "ColorCalibration",
+ "select_roi",
+ "read_image",
+ "save_image",
+]
diff --git a/pysrc/image/color_calibration/calibration.py b/pysrc/image/color_calibration/calibration.py
new file mode 100644
index 00000000..6ca4cb2c
--- /dev/null
+++ b/pysrc/image/color_calibration/calibration.py
@@ -0,0 +1,65 @@
+import cv2
+import numpy as np
+from skimage import exposure
+from typing import Tuple
+from dataclasses import dataclass
+
+@dataclass
+class ColorCalibration:
+ """Class to handle color calibration tasks for astronomical images."""
+ image: np.ndarray
+
+ def calculate_color_factors(self, white_reference: np.ndarray) -> np.ndarray:
+ """
+ Calculate the color calibration factors based on a white reference image.
+
+ Parameters:
+ white_reference (np.ndarray): The white reference region.
+
+ Returns:
+ np.ndarray: The calibration factors for the RGB channels.
+ """
+ mean_values = np.mean(white_reference, axis=(0, 1))
+ factors = 1.0 / mean_values
+ return factors
+
+ def apply_color_calibration(self, factors: np.ndarray) -> np.ndarray:
+ """
+ Apply color calibration to the image using provided factors.
+
+ Parameters:
+ factors (np.ndarray): The RGB calibration factors.
+
+ Returns:
+ np.ndarray: Color calibrated image.
+ """
+ calibrated_image = np.zeros_like(self.image)
+ for i in range(3):
+ calibrated_image[:, :, i] = self.image[:, :, i] * factors[i]
+ return calibrated_image
+
+ def automatic_white_balance(self) -> np.ndarray:
+ """
+ Perform automatic white balance using the Gray World algorithm.
+
+ Returns:
+ np.ndarray: White balanced image.
+ """
+ mean_values = np.mean(self.image, axis=(0, 1))
+ gray_value = np.mean(mean_values)
+ factors = gray_value / mean_values
+ return self.apply_color_calibration(factors)
+
+ def match_histograms(self, reference_image: np.ndarray) -> np.ndarray:
+ """
+ Match the color histogram of the image to a reference image.
+
+ Parameters:
+ reference_image (np.ndarray): The reference image whose histogram is to be matched.
+
+ Returns:
+ np.ndarray: Histogram matched image.
+ """
+ matched_image = np.zeros_like(self.image)
+ for i in range(3):
+            matched_image[:, :, i] = exposure.match_histograms(
+                self.image[:, :, i], reference_image[:, :, i])
+ return matched_image
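+
+
+# Usage sketch (hypothetical file name and ROI): calibrate against a patch
+# assumed to be neutral, or fall back to Gray World when no reference region
+# is available.
+if __name__ == "__main__":
+    img = cv2.imread('frame.png').astype(np.float32) / 255.0
+    cal = ColorCalibration(image=img)
+    white_patch = img[100:150, 200:250]  # assumed neutral region
+    balanced = cal.apply_color_calibration(cal.calculate_color_factors(white_patch))
+    auto_balanced = cal.automatic_white_balance()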
diff --git a/pysrc/image/color_calibration/io.py b/pysrc/image/color_calibration/io.py
new file mode 100644
index 00000000..8127f13c
--- /dev/null
+++ b/pysrc/image/color_calibration/io.py
@@ -0,0 +1,28 @@
+import cv2
+import numpy as np
+
+def read_image(image_path: str) -> np.ndarray:
+ """
+ Read an image from the given file path.
+
+ Parameters:
+ image_path (str): Path to the image file.
+
+ Returns:
+        np.ndarray: The image as float32 normalized to [0, 1] (assumes a 16-bit source).
+ """
+ image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 65535.0
+ return image
+
+def save_image(image: np.ndarray, output_path: str) -> None:
+ """
+ Save the image to the given file path.
+
+ Parameters:
+ image (np.ndarray): The image to save.
+ output_path (str): Path to the output file.
+
+ Returns:
+ None
+ """
+ cv2.imwrite(output_path, (image * 65535).astype(np.uint16))
diff --git a/pysrc/image/color_calibration/processing.py b/pysrc/image/color_calibration/processing.py
new file mode 100644
index 00000000..27ca9243
--- /dev/null
+++ b/pysrc/image/color_calibration/processing.py
@@ -0,0 +1,64 @@
+import cv2
+import numpy as np
+from skimage import exposure, filters, feature
+from dataclasses import dataclass
+
+@dataclass
+class ImageProcessing:
+ """Class to handle image processing tasks for astronomical images."""
+ image: np.ndarray
+
+ def apply_gamma_correction(self, gamma: float) -> np.ndarray:
+ """
+ Apply gamma correction to the image to adjust the contrast.
+
+ Parameters:
+ gamma (float): The gamma value to apply. < 1 makes the image darker, > 1 makes it brighter.
+
+ Returns:
+ np.ndarray: Gamma-corrected image.
+ """
+        inv_gamma = 1.0 / gamma
+        # cv2.LUT needs an 8-bit image and a 256-entry uint8 table
+        table = np.array([((i / 255.0) ** inv_gamma) * 255
+                          for i in np.arange(0, 256)]).astype("uint8")
+        corrected_image = cv2.LUT(self.image, table)
+        return corrected_image
+
+ def histogram_equalization(self) -> np.ndarray:
+ """
+ Apply histogram equalization to improve the contrast of the image.
+
+ Returns:
+ np.ndarray: Histogram equalized image.
+ """
+ equalized_image = np.zeros_like(self.image)
+ for i in range(3):
+ equalized_image[:, :, i] = exposure.equalize_hist(self.image[:, :, i])
+ return equalized_image
+
+ def denoise_image(self, h: float = 10.0) -> np.ndarray:
+ """
+ Apply Non-Local Means Denoising to the image.
+
+ Parameters:
+ h (float): Filter strength. Higher h removes noise more aggressively.
+
+ Returns:
+ np.ndarray: Denoised image.
+ """
+        # Note: OpenCV requires an 8-bit, 3-channel image here
+        denoised_image = cv2.fastNlMeansDenoisingColored(self.image, None, h, h, 7, 21)
+ return denoised_image
+
+ def detect_stars(self, min_distance: int = 10, threshold_rel: float = 0.2) -> np.ndarray:
+ """
+ Detect stars in the image using peak local max method.
+
+ Parameters:
+ min_distance (int): Minimum number of pixels separating peaks in a region of 2 * min_distance + 1.
+ threshold_rel (float): Minimum intensity of peaks relative to the highest peak.
+
+ Returns:
+ np.ndarray: Coordinates of the detected stars.
+ """
+ gray_image = cv2.cvtColor((self.image * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY)
+ coordinates = feature.peak_local_max(gray_image, min_distance=min_distance, threshold_rel=threshold_rel)
+ return coordinates
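+
+
+# Usage sketch: chain the steps on an 8-bit frame (the file name is a
+# placeholder). Gamma correction and denoising operate on uint8 input;
+# detect_stars expects a float image in [0, 1].
+if __name__ == "__main__":
+    frame = cv2.imread('frame.png')
+    denoised = ImageProcessing(image=frame).denoise_image(h=8.0)
+    brightened = ImageProcessing(image=denoised).apply_gamma_correction(gamma=1.4)
+    stars = ImageProcessing(image=brightened.astype(np.float32) / 255.0).detect_stars()
+    print(f"Detected {len(stars)} star candidates")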
diff --git a/pysrc/image/color_calibration/utils.py b/pysrc/image/color_calibration/utils.py
new file mode 100644
index 00000000..f10d9582
--- /dev/null
+++ b/pysrc/image/color_calibration/utils.py
@@ -0,0 +1,20 @@
+import cv2
+import numpy as np
+from typing import Tuple
+
+def select_roi(image: np.ndarray, x: int, y: int, width: int, height: int) -> np.ndarray:
+ """
+ Select a Region of Interest (ROI) from the image.
+
+ Parameters:
+ image (np.ndarray): The input image.
+ x (int): X-coordinate of the top-left corner of the ROI.
+ y (int): Y-coordinate of the top-left corner of the ROI.
+ width (int): Width of the ROI.
+ height (int): Height of the ROI.
+
+ Returns:
+ np.ndarray: The selected ROI.
+ """
+ roi = image[y:y+height, x:x+width]
+ return roi
diff --git a/pysrc/image/debayer/__init__.py b/pysrc/image/debayer/__init__.py
new file mode 100644
index 00000000..e21f4ef2
--- /dev/null
+++ b/pysrc/image/debayer/__init__.py
@@ -0,0 +1,2 @@
+from .debayer import Debayer
+from .metrics import calculate_image_quality, evaluate_image_quality
diff --git a/pysrc/image/debayer/debayer.py b/pysrc/image/debayer/debayer.py
new file mode 100644
index 00000000..bde2fcce
--- /dev/null
+++ b/pysrc/image/debayer/debayer.py
@@ -0,0 +1,247 @@
+import numpy as np
+import cv2
+import concurrent.futures
+from concurrent.futures import ThreadPoolExecutor
+from typing import Optional, Tuple
+import time
+
+from .utils import calculate_gradients, interpolate_green_channel, interpolate_red_blue_channel
+
+
+class Debayer:
+ def __init__(self, method: str = 'bilinear', pattern: Optional[str] = None):
+ """
+ Initialize Debayer object with method and optional Bayer pattern.
+
+ :param method: Debayering method ('superpixel', 'bilinear', 'vng', 'ahd', 'laplacian')
+ :param pattern: Bayer pattern ('BGGR', 'RGGB', 'GBRG', 'GRBG'), None for auto-detection
+ """
+ self.method = method
+ self.pattern = pattern
+
+ def detect_bayer_pattern(self, image: np.ndarray) -> str:
+ """
+ Automatically detect Bayer pattern from the CFA image.
+ """
+ height, width = image.shape
+
+        # Initialize per-pattern scores
+        patterns = {'BGGR': 0, 'RGGB': 0, 'GBRG': 0, 'GRBG': 0}
+
+        # Use an edge map to sharpen the per-pattern statistics
+        edges = cv2.Canny(image, 50, 150)
+
+ for i in range(0, height - 1, 2):
+ for j in range(0, width - 1, 2):
+ # BGGR
+ patterns['BGGR'] += (image[i, j] + image[i+1, j+1]) + \
+ (edges[i, j] + edges[i+1, j+1])
+ # RGGB
+ patterns['RGGB'] += (image[i+1, j] + image[i, j+1]) + \
+ (edges[i+1, j] + edges[i, j+1])
+ # GBRG
+ patterns['GBRG'] += (image[i, j+1] + image[i+1, j]) + \
+ (edges[i, j+1] + edges[i+1, j])
+ # GRBG
+ patterns['GRBG'] += (image[i, j] + image[i+1, j+1]) + \
+ (edges[i, j] + edges[i+1, j+1])
+
+        # Analyse the colour-channel distribution to reinforce the detection
+        color_sums = {'BGGR': 0, 'RGGB': 0, 'GBRG': 0, 'GRBG': 0}
+
+        # Walk the image and accumulate channel intensities per 2x2 block
+        for i in range(0, height - 1, 2):
+            for j in range(0, width - 1, 2):
+                block = image[i:i+2, j:j+2]
+                color_sums['BGGR'] += block[0, 0] + block[1, 1]  # blue-green pair
+                color_sums['RGGB'] += block[1, 0] + block[0, 1]  # red-green pair
+                color_sums['GBRG'] += block[0, 1] + block[1, 0]  # green-blue pair
+                color_sums['GRBG'] += block[0, 0] + block[1, 1]  # green-red pair
+
+        # Combine all evidence and pick the most likely Bayer pattern
+        for pattern in patterns.keys():
+            patterns[pattern] += color_sums[pattern]
+
+ detected_pattern = max(patterns, key=patterns.get)
+
+ return detected_pattern
+
+ def debayer_image(self, cfa_image: np.ndarray) -> np.ndarray:
+ """
+ Perform the debayering process using the specified method.
+ """
+ if self.pattern is None:
+ self.pattern = self.detect_bayer_pattern(cfa_image)
+
+ cfa_image = self.extend_image_edges(cfa_image, pad_width=2)
+ print(f"Using Bayer pattern: {self.pattern}")
+
+        if self.method == 'superpixel':
+            return self.debayer_superpixel(cfa_image)
+        elif self.method == 'bilinear':
+            return self.debayer_bilinear(cfa_image, self.pattern)
+        elif self.method == 'vng':
+            return self.debayer_vng(cfa_image)
+        elif self.method == 'ahd':
+            return self.parallel_debayer_ahd(cfa_image)
+        elif self.method == 'laplacian':
+            return self.debayer_laplacian_harmonization(cfa_image, self.pattern)
+        else:
+            raise ValueError(f"Unknown debayer method: {self.method}")
+
+    def debayer_superpixel(self, cfa_image: np.ndarray) -> np.ndarray:
+        # Half-resolution binning; the slicing below assumes an RGGB-style layout
+        red = cfa_image[0::2, 0::2]
+ green = (cfa_image[0::2, 1::2] + cfa_image[1::2, 0::2]) / 2
+ blue = cfa_image[1::2, 1::2]
+
+ rgb_image = np.stack((red, green, blue), axis=-1)
+ return rgb_image
+
+    def debayer_bilinear(self, cfa_image, pattern='BGGR'):
+        """
+        Debayer using bilinear interpolation.
+
+        :param cfa_image: Input CFA image
+        :param pattern: Bayer pattern ('BGGR', 'RGGB', 'GBRG', 'GRBG')
+        :return: Debayered RGB image
+        """
+ if pattern == 'BGGR':
+ return cv2.cvtColor(cfa_image, cv2.COLOR_BayerBG2BGR)
+ elif pattern == 'RGGB':
+ return cv2.cvtColor(cfa_image, cv2.COLOR_BayerRG2BGR)
+ elif pattern == 'GBRG':
+ return cv2.cvtColor(cfa_image, cv2.COLOR_BayerGB2BGR)
+ elif pattern == 'GRBG':
+ return cv2.cvtColor(cfa_image, cv2.COLOR_BayerGR2BGR)
+ else:
+ raise ValueError(f"Unsupported Bayer pattern: {pattern}")
+
+    def debayer_vng(self, cfa_image: np.ndarray) -> np.ndarray:
+        pattern = self.pattern or 'BGGR'
+        code = cv2.COLOR_BayerBG2BGR_VNG if pattern == 'BGGR' else cv2.COLOR_BayerRG2BGR_VNG
+        rgb_image = cv2.cvtColor(cfa_image, code)
+        return rgb_image
+
+ def parallel_debayer_ahd(self, cfa_image: np.ndarray, num_threads: int = 4) -> np.ndarray:
+ height, width = cfa_image.shape
+ chunk_size = height // num_threads
+
+        # Holds the chunk processed by each thread
+        results = [None] * num_threads
+
+ def process_chunk(start_row, end_row, index):
+ chunk = cfa_image[start_row:end_row, :]
+ gradient_x, gradient_y = calculate_gradients(chunk)
+ green_channel = interpolate_green_channel(
+ chunk, gradient_x, gradient_y)
+            red_channel, blue_channel = interpolate_red_blue_channel(
+                chunk, green_channel, self.pattern)
+ rgb_chunk = np.stack(
+ (red_channel, green_channel, blue_channel), axis=-1)
+ results[index] = np.clip(rgb_chunk, 0, 255).astype(np.uint8)
+
+ with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
+ futures = []
+ for i in range(num_threads):
+ start_row = i * chunk_size
+ end_row = (i + 1) * chunk_size if i < num_threads - \
+ 1 else height
+ futures.append(executor.submit(
+ process_chunk, start_row, end_row, i))
+
+ concurrent.futures.wait(futures)
+
+        # Stitch the processed chunks back together
+        rgb_image = np.vstack(results)
+ return rgb_image
+
+    def calculate_laplacian(self, image):
+        """
+        Compute the Laplacian of the image for edge enhancement.
+
+        :param image: Input (single-channel) image
+        :return: Laplacian image
+        """
+        laplacian = cv2.Laplacian(image, cv2.CV_64F)
+        return laplacian
+
+    def harmonize_edges(self, original, interpolated, laplacian):
+        """
+        Blend the Laplacian back into the interpolated image to enhance edge detail.
+
+        :param original: Original CFA image
+        :param interpolated: Bilinearly interpolated channel
+        :param laplacian: Laplacian of the interpolated channel
+        :return: Edge-harmonized channel
+        """
+        return np.clip(interpolated + 0.2 * laplacian, 0, 255).astype(np.uint8)
+
+    def debayer_laplacian_harmonization(self, cfa_image, pattern='BGGR'):
+        """
+        Debayer with a simplified Laplacian-harmonization step to enhance edges.
+
+        :param cfa_image: Input CFA image
+        :param pattern: Bayer pattern ('BGGR', 'RGGB', 'GBRG', 'GRBG')
+        :return: Debayered RGB image
+        """
+        # Step 1: bilinear interpolation
+        interpolated_image = self.debayer_bilinear(cfa_image, pattern)
+
+        # Step 2: Laplacian of each channel
+        laplacian_b = self.calculate_laplacian(interpolated_image[:, :, 0])
+        laplacian_g = self.calculate_laplacian(interpolated_image[:, :, 1])
+        laplacian_r = self.calculate_laplacian(interpolated_image[:, :, 2])
+
+        # Step 3: harmonize each interpolated channel with its Laplacian
+        harmonized_b = self.harmonize_edges(cfa_image, interpolated_image[:, :, 0], laplacian_b)
+        harmonized_g = self.harmonize_edges(cfa_image, interpolated_image[:, :, 1], laplacian_g)
+        harmonized_r = self.harmonize_edges(cfa_image, interpolated_image[:, :, 2], laplacian_r)
+
+        # Step 4: merge the harmonized channels
+        harmonized_image = np.stack((harmonized_b, harmonized_g, harmonized_r), axis=-1)
+
+ return harmonized_image
+
+ def extend_image_edges(self, image: np.ndarray, pad_width: int) -> np.ndarray:
+ """
+ Extend image edges using mirror padding to handle boundary issues during interpolation.
+ """
+ return np.pad(image, pad_width, mode='reflect')
+
+ def visualize_intermediate_steps(self, cfa_image: np.ndarray):
+ """
+ Visualize intermediate steps in the debayering process.
+ """
+ gradient_x, gradient_y = calculate_gradients(cfa_image)
+ green_channel = interpolate_green_channel(
+ cfa_image, gradient_x, gradient_y)
+        red_channel, blue_channel = interpolate_red_blue_channel(
+            cfa_image, green_channel, self.pattern)
+
+        # Show the gradients and each channel image
+ cv2.imshow("Gradient X", gradient_x)
+ cv2.imshow("Gradient Y", gradient_y)
+ cv2.imshow("Green Channel", green_channel)
+ cv2.imshow("Red Channel", red_channel)
+ cv2.imshow("Blue Channel", blue_channel)
+ cv2.waitKey(0)
+ cv2.destroyAllWindows()
+
+ def process_batch(self, image_paths: list, num_threads: int = 4):
+ """
+ Batch processing for multiple CFA images using multithreading.
+ """
+ start_time = time.time()
+ with ThreadPoolExecutor(max_workers=num_threads) as executor:
+ results = executor.map(
+ lambda path: self.process_single_image(path), image_paths)
+
+ elapsed_time = time.time() - start_time
+ print(f"Batch processing completed in {elapsed_time:.2f} seconds.")
+
+ def process_single_image(self, path: str) -> np.ndarray:
+ """
+ Helper function for processing a single image.
+ """
+ cfa_image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
+ rgb_image = self.debayer_image(cfa_image)
+ output_path = path.replace('.png', f'_{self.method}.png')
+ cv2.imwrite(output_path, rgb_image)
+ print(f"Processed {path} -> {output_path}")
+ return rgb_image
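+
+
+# Usage sketch: file paths are placeholders. Pattern detection above is a
+# heuristic, so passing the known pattern explicitly is the safer route.
+if __name__ == "__main__":
+    debayer = Debayer(method='bilinear', pattern='RGGB')
+    raw = cv2.imread('capture_cfa.png', cv2.IMREAD_GRAYSCALE)
+    rgb = debayer.debayer_image(raw)
+    cv2.imwrite('capture_rgb.png', rgb)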
diff --git a/pysrc/image/debayer/metrics.py b/pysrc/image/debayer/metrics.py
new file mode 100644
index 00000000..0f27e42c
--- /dev/null
+++ b/pysrc/image/debayer/metrics.py
@@ -0,0 +1,27 @@
+import cv2
+import numpy as np
+from typing import Tuple
+from skimage.metrics import peak_signal_noise_ratio as psnr, structural_similarity as ssim
+
+def evaluate_image_quality(rgb_image: np.ndarray) -> dict:
+ """
+ Evaluate the quality of the debayered image.
+ """
+ gray_image = cv2.cvtColor(rgb_image, cv2.COLOR_BGR2GRAY)
+ laplacian_var = cv2.Laplacian(gray_image, cv2.CV_64F).var()
+
+ mean_colors = cv2.mean(rgb_image)[:3]
+
+ return {
+ "sharpness": laplacian_var,
+ "mean_red": mean_colors[2],
+ "mean_green": mean_colors[1],
+ "mean_blue": mean_colors[0]
+ }
+
+def calculate_image_quality(original: np.ndarray, processed: np.ndarray) -> Tuple[float, float]:
+ """
+ Calculate PSNR and SSIM between the original and processed images.
+ """
+ psnr_value = psnr(original, processed, data_range=processed.max() - processed.min())
+    # channel_axis replaces the deprecated multichannel flag in recent scikit-image
+    ssim_value = ssim(original, processed, channel_axis=-1)
+ return psnr_value, ssim_value
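+
+
+# Usage sketch comparing a debayered frame against a reference rendering;
+# both file names are placeholders and the two arrays must share a shape.
+if __name__ == "__main__":
+    reference = cv2.imread('reference.png')
+    result = cv2.imread('debayered.png')
+    psnr_value, ssim_value = calculate_image_quality(reference, result)
+    print(f"PSNR: {psnr_value:.2f} dB, SSIM: {ssim_value:.4f}")
+    print(evaluate_image_quality(result))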
diff --git a/pysrc/image/debayer/utils.py b/pysrc/image/debayer/utils.py
new file mode 100644
index 00000000..4ef5e83a
--- /dev/null
+++ b/pysrc/image/debayer/utils.py
@@ -0,0 +1,117 @@
+from typing import Tuple
+import numpy as np
+
+def calculate_gradients(cfa_image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+ """
+ Calculate the gradients of a CFA image.
+ """
+ gradient_x = np.abs(np.diff(cfa_image, axis=1))
+ gradient_y = np.abs(np.diff(cfa_image, axis=0))
+
+ gradient_x = np.pad(gradient_x, ((0, 0), (0, 1)), 'constant')
+ gradient_y = np.pad(gradient_y, ((0, 1), (0, 0)), 'constant')
+
+ return gradient_x, gradient_y
+
+def interpolate_green_channel(cfa_image: np.ndarray, gradient_x: np.ndarray, gradient_y: np.ndarray) -> np.ndarray:
+ """
+ Interpolate the green channel of the CFA image.
+ """
+ height, width = cfa_image.shape
+ green_channel = np.zeros((height, width))
+
+ for i in range(1, height - 1):
+ for j in range(1, width - 1):
+            if (i % 2 == 0 and j % 2 == 1) or (i % 2 == 1 and j % 2 == 0):
+                # This position already holds a green sample; copy it through
+                green_channel[i, j] = cfa_image[i, j]
+            else:
+                # Not a green site: interpolate along the direction of smaller gradient
+ if gradient_x[i, j] < gradient_y[i, j]:
+ green_channel[i, j] = 0.5 * \
+ (cfa_image[i, j-1] + cfa_image[i, j+1])
+ else:
+ green_channel[i, j] = 0.5 * \
+ (cfa_image[i-1, j] + cfa_image[i+1, j])
+
+ return green_channel
+
+def interpolate_red_blue_channel(cfa_image: np.ndarray, green_channel: np.ndarray, pattern: str) -> Tuple[np.ndarray, np.ndarray]:
+ """
+ Interpolate the red and blue channels of the CFA image.
+ """
+ height, width = cfa_image.shape
+ red_channel = np.zeros((height, width))
+ blue_channel = np.zeros((height, width))
+
+ if pattern == 'BGGR':
+ for i in range(0, height, 2):
+ for j in range(0, width, 2):
+ blue_channel[i, j] = cfa_image[i, j]
+ red_channel[i+1, j+1] = cfa_image[i+1, j+1]
+
+ green_r = 0.5 * (green_channel[i+1, j] + green_channel[i, j+1])
+ green_b = 0.5 * (green_channel[i, j] + green_channel[i+1, j+1])
+
+ blue_channel[i+1, j] = cfa_image[i+1, j] - \
+ green_b + green_channel[i+1, j]
+ blue_channel[i, j+1] = cfa_image[i, j+1] - \
+ green_b + green_channel[i, j+1]
+ red_channel[i, j] = cfa_image[i, j] - \
+ green_r + green_channel[i, j]
+ red_channel[i+1, j+1] = cfa_image[i+1, j+1] - \
+ green_r + green_channel[i+1, j+1]
+
+ elif pattern == 'RGGB':
+ for i in range(0, height, 2):
+ for j in range(0, width, 2):
+ red_channel[i, j] = cfa_image[i, j]
+ blue_channel[i+1, j+1] = cfa_image[i+1, j+1]
+
+ green_r = 0.5 * (green_channel[i, j+1] + green_channel[i+1, j])
+ green_b = 0.5 * (green_channel[i+1, j] + green_channel[i, j+1])
+
+ red_channel[i+1, j] = cfa_image[i+1, j] - \
+ green_r + green_channel[i+1, j]
+ red_channel[i, j+1] = cfa_image[i, j+1] - \
+ green_r + green_channel[i, j+1]
+ blue_channel[i, j] = cfa_image[i, j] - \
+ green_b + green_channel[i, j]
+ blue_channel[i+1, j+1] = cfa_image[i+1, j+1] - \
+ green_b + green_channel[i+1, j+1]
+
+ elif pattern == 'GBRG':
+ for i in range(0, height, 2):
+ for j in range(0, width, 2):
+ green_channel[i, j+1] = cfa_image[i, j+1]
+ blue_channel[i+1, j] = cfa_image[i+1, j]
+
+ green_r = 0.5 * (green_channel[i, j] + green_channel[i+1, j+1])
+ green_b = 0.5 * (green_channel[i+1, j] + green_channel[i, j+1])
+
+ red_channel[i, j] = cfa_image[i, j] - \
+ green_r + green_channel[i, j]
+ red_channel[i+1, j+1] = cfa_image[i+1, j+1] - \
+ green_r + green_channel[i+1, j+1]
+ blue_channel[i, j] = cfa_image[i, j] - \
+ green_b + green_channel[i, j]
+ blue_channel[i+1, j+1] = cfa_image[i+1, j+1] - \
+ green_b + green_channel[i+1, j+1]
+
+ elif pattern == 'GRBG':
+ for i in range(0, height, 2):
+ for j in range(0, width, 2):
+ green_channel[i, j] = cfa_image[i, j]
+ red_channel[i+1, j] = cfa_image[i+1, j]
+
+ green_r = 0.5 * (green_channel[i, j] + green_channel[i+1, j+1])
+ green_b = 0.5 * (green_channel[i+1, j] + green_channel[i, j+1])
+
+ red_channel[i, j] = cfa_image[i, j] - \
+ green_r + green_channel[i, j]
+ red_channel[i+1, j+1] = cfa_image[i+1, j+1] - \
+ green_r + green_channel[i+1, j+1]
+ blue_channel[i+1, j] = cfa_image[i+1, j] - \
+ green_b + green_channel[i+1, j]
+ blue_channel[i, j+1] = cfa_image[i, j+1] - \
+ green_b + green_channel[i, j+1]
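+
+
+# Usage sketch: wire the helpers together on a synthetic RGGB mosaic. The
+# random data stands in for a real CFA frame.
+if __name__ == "__main__":
+    rng = np.random.default_rng(1)
+    cfa = rng.integers(0, 255, (32, 32)).astype(np.float64)
+    gx, gy = calculate_gradients(cfa)
+    green = interpolate_green_channel(cfa, gx, gy)
+    red, blue = interpolate_red_blue_channel(cfa, green, 'RGGB')
+    rgb = np.dstack((red, green, blue)).clip(0, 255).astype(np.uint8)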
diff --git a/pysrc/image/defect_map/__init__.py b/pysrc/image/defect_map/__init__.py
new file mode 100644
index 00000000..2e190333
--- /dev/null
+++ b/pysrc/image/defect_map/__init__.py
@@ -0,0 +1 @@
+from .defect_correction import defect_map_enhanced, parallel_defect_map
diff --git a/pysrc/image/defect_map/defect_correction.py b/pysrc/image/defect_map/defect_correction.py
new file mode 100644
index 00000000..e138f708
--- /dev/null
+++ b/pysrc/image/defect_map/defect_correction.py
@@ -0,0 +1,148 @@
+import numpy as np
+from scipy.ndimage import generic_filter, median_filter, minimum_filter, maximum_filter, gaussian_filter
+from skimage.filters import sobel
+from skimage import img_as_float
+from .interpolation import interpolate_defects
+import multiprocessing
+from typing import Optional
+
+def defect_map_enhanced(image: np.ndarray, defect_map: np.ndarray, operation: str = 'mean', structure: str = 'square',
+ radius: int = 1, is_cfa: bool = False, protect_edges: bool = False,
+ adaptive_structure: bool = False) -> np.ndarray:
+ """
+ Enhanced Defect Map function to repair defective pixels in an image.
+
+ Parameters:
+ - image: np.ndarray, Input image to be repaired.
+ - defect_map: np.ndarray, Defect map where defective pixels are marked as 0 (black).
+ - operation: str, Operation to perform ('mean', 'gaussian', 'minimum', 'maximum', 'median', 'bilinear', 'bicubic').
+ - structure: str, Neighborhood structure ('square', 'circular', 'horizontal', 'vertical').
+ - radius: int, Radius or size of the neighborhood.
+ - is_cfa: bool, If True, process the image as a CFA (Color Filter Array) image.
+ - protect_edges: bool, Whether to protect the edges of the image.
+ - adaptive_structure: bool, Whether to adaptively adjust the neighborhood based on defect density.
+
+ Returns:
+ - corrected_image: np.ndarray, The repaired image.
+ """
+ if structure == 'square':
+ footprint = np.ones((2 * radius + 1, 2 * radius + 1))
+ elif structure == 'circular':
+ y, x = np.ogrid[-radius: radius + 1, -radius: radius + 1]
+ footprint = x**2 + y**2 <= radius**2
+ elif structure == 'horizontal':
+ footprint = np.zeros((1, 2 * radius + 1))
+ footprint[0, :] = 1
+ elif structure == 'vertical':
+ footprint = np.zeros((2 * radius + 1, 1))
+ footprint[:, 0] = 1
+ else:
+ raise ValueError("Invalid structure type.")
+
+ mask = defect_map == 0
+
+ if protect_edges:
+ edges = sobel(img_as_float(image))
+ mask = np.logical_and(mask, edges < np.mean(edges))
+
+ corrected_image = np.copy(image)
+
+ if is_cfa:
+ for c in range(3): # Assuming an RGB image
+ corrected_image[:, :, c] = correct_channel(
+ image[:, :, c], mask[:, :, c], operation, footprint, adaptive_structure)
+ else:
+ corrected_image = correct_channel(
+ image, mask, operation, footprint, adaptive_structure)
+
+ return corrected_image
+
+def correct_channel(channel: np.ndarray, mask: np.ndarray, operation: str, footprint: np.ndarray,
+ adaptive_structure: bool) -> np.ndarray:
+ """
+ Helper function to repair a single channel of an image.
+
+ Parameters:
+ - channel: np.ndarray, Image channel to be repaired.
+ - mask: np.ndarray, Defect mask indicating defective pixels.
+ - operation: str, Operation to perform.
+ - footprint: np.ndarray, Neighborhood structure.
+ - adaptive_structure: bool, Whether to adaptively adjust the neighborhood size.
+
+ Returns:
+ - channel: np.ndarray, The repaired image channel.
+ """
+ if adaptive_structure:
+ density = np.sum(mask) / mask.size
+ radius = int(3 / density) if density > 0 else 1
+ footprint = np.ones((2 * radius + 1, 2 * radius + 1))
+
+    if operation == 'mean':
+        # nanmean ignores the NaN border padding that np.mean would propagate
+        channel_corrected = generic_filter(
+            channel, np.nanmean, footprint=footprint, mode='constant', cval=np.nan)
+ elif operation == 'gaussian':
+ channel_corrected = gaussian_filter(channel, sigma=1)
+ elif operation == 'minimum':
+ channel_corrected = minimum_filter(channel, footprint=footprint)
+ elif operation == 'maximum':
+ channel_corrected = maximum_filter(channel, footprint=footprint)
+ elif operation == 'median':
+ channel_corrected = median_filter(channel, footprint=footprint)
+ elif operation == 'bilinear':
+ channel_corrected = interpolate_defects(channel, mask, method='linear')
+ elif operation == 'bicubic':
+ channel_corrected = interpolate_defects(channel, mask, method='cubic')
+ else:
+ raise ValueError("Invalid operation type.")
+
+ channel[mask] = channel_corrected[mask]
+
+ return channel
+
+def parallel_defect_map(image: np.ndarray, defect_map: np.ndarray, **kwargs) -> np.ndarray:
+ """
+ Parallel processing for defect map repair.
+
+ Parameters:
+ - image: np.ndarray, Input image.
+ - defect_map: np.ndarray, Defect map indicating defective pixels.
+ - kwargs: Additional keyword arguments for defect_map_enhanced.
+
+ Returns:
+ - corrected_image: np.ndarray, The repaired image.
+ """
+ if image.ndim == 2:
+ return defect_map_enhanced(image, defect_map, **kwargs)
+
+ pool = multiprocessing.Pool()
+ channels = [image[:, :, i] for i in range(image.shape[2])]
+
+    # Fall back to the defaults of defect_map_enhanced for any option not supplied
+    task_args = [(ch, defect_map, kwargs.get('operation', 'mean'), kwargs.get('structure', 'square'),
+                  kwargs.get('radius', 1), kwargs.get('is_cfa', False), kwargs.get('protect_edges', False),
+                  kwargs.get('adaptive_structure', False)) for ch in channels]
+
+ results = pool.starmap(defect_map_enhanced_single_channel, task_args)
+ pool.close()
+ pool.join()
+
+ return np.stack(results, axis=-1)
+
+def defect_map_enhanced_single_channel(channel: np.ndarray, defect_map: np.ndarray, operation: str,
+ structure: str, radius: int, is_cfa: bool, protect_edges: bool,
+ adaptive_structure: bool) -> np.ndarray:
+ """
+ Single-channel version of defect_map_enhanced for multiprocessing.
+
+ Parameters:
+ - channel: np.ndarray, Single image channel.
+ - defect_map: np.ndarray, Defect map indicating defective pixels.
+ - operation: str, Operation to perform.
+ - structure: str, Neighborhood structure.
+ - radius: int, Radius or size of the neighborhood.
+ - is_cfa: bool, If True, process as a CFA image.
+ - protect_edges: bool, Whether to protect the edges of the image.
+ - adaptive_structure: bool, Whether to adaptively adjust the neighborhood size.
+
+ Returns:
+ - channel: np.ndarray, The repaired image channel.
+ """
+ return defect_map_enhanced(channel, defect_map, operation, structure, radius, is_cfa, protect_edges, adaptive_structure)
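+
+
+# Usage sketch: build a defect map (0 marks bad pixels) on synthetic data and
+# repair with a median over a circular neighborhood.
+if __name__ == "__main__":
+    rng = np.random.default_rng(0)
+    frame = rng.integers(0, 255, (64, 64)).astype(np.float64)
+    dmap = np.ones_like(frame)
+    dmap[10, 10] = dmap[30, 40] = 0  # two dead pixels
+    fixed = defect_map_enhanced(frame, dmap, operation='median',
+                                structure='circular', radius=2)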
diff --git a/pysrc/image/defect_map/interpolation.py b/pysrc/image/defect_map/interpolation.py
new file mode 100644
index 00000000..1838ea80
--- /dev/null
+++ b/pysrc/image/defect_map/interpolation.py
@@ -0,0 +1,19 @@
+import numpy as np
+from scipy.interpolate import griddata
+
+def interpolate_defects(image: np.ndarray, mask: np.ndarray, method: str = 'linear') -> np.ndarray:
+ """
+ Interpolate defective pixels in an image.
+
+ Parameters:
+ - image: np.ndarray, Image data.
+ - mask: np.ndarray, Defect mask indicating defective pixels.
+ - method: str, Interpolation method ('linear', 'cubic').
+
+ Returns:
+ - interpolated_image: np.ndarray, Image with interpolated defective pixels.
+ """
+ x, y = np.indices(image.shape)
+ points = np.column_stack((x[~mask], y[~mask]))
+ values = image[~mask]
+ return griddata(points, values, (x, y), method=method)
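+
+
+# Usage sketch: repair one synthetic bad pixel by linear interpolation.
+if __name__ == "__main__":
+    img = np.arange(25, dtype=float).reshape(5, 5)
+    bad = np.zeros_like(img, dtype=bool)
+    bad[2, 2] = True
+    print(interpolate_defects(img, bad, method='linear'))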
diff --git a/pysrc/image/defect_map/utils.py b/pysrc/image/defect_map/utils.py
new file mode 100644
index 00000000..783c0933
--- /dev/null
+++ b/pysrc/image/defect_map/utils.py
@@ -0,0 +1,24 @@
+import cv2
+import numpy as np
+
+def save_image(filepath: str, image: np.ndarray) -> None:
+ """
+ Save an image to a specified file path.
+
+ Parameters:
+ - filepath: str, The file path where the image should be saved.
+ - image: np.ndarray, The image data to be saved.
+ """
+ cv2.imwrite(filepath, image)
+
+def display_image(window_name: str, image: np.ndarray) -> None:
+ """
+ Display an image in a new window.
+
+ Parameters:
+ - window_name: str, The name of the window where the image will be displayed.
+ - image: np.ndarray, The image data to be displayed.
+ """
+ cv2.imshow(window_name, image)
+ cv2.waitKey(0)
+ cv2.destroyAllWindows()
diff --git a/pysrc/image/fluxcalibration/__init__.py b/pysrc/image/fluxcalibration/__init__.py
new file mode 100644
index 00000000..3eafd815
--- /dev/null
+++ b/pysrc/image/fluxcalibration/__init__.py
@@ -0,0 +1,6 @@
+from .core import CalibrationParams
+from .calibration import flux_calibration, save_to_fits
+from .utils import read_fits_header, instrument_response_correction, background_noise_correction
+
+__all__ = ['CalibrationParams', 'flux_calibration', 'save_to_fits',
+ 'read_fits_header', 'instrument_response_correction', 'background_noise_correction']
diff --git a/pysrc/image/fluxcalibration/calibration.py b/pysrc/image/fluxcalibration/calibration.py
new file mode 100644
index 00000000..b3d3b02b
--- /dev/null
+++ b/pysrc/image/fluxcalibration/calibration.py
@@ -0,0 +1,68 @@
+import numpy as np
+from astropy.io import fits
+from typing import Optional
+from .core import CalibrationParams
+from .utils import instrument_response_correction, background_noise_correction
+
+def compute_flx2dn(params: CalibrationParams) -> float:
+ """
+ Compute the flux conversion factor (FLX2DN).
+ :param params: Calibration parameters.
+ :return: Flux conversion factor.
+ """
+ c = 3.0e8 # Speed of light in m/s
+ h = 6.626e-34 # Planck's constant in J·s
+ wavelength_m = params.wavelength * 1e-9 # Convert nm to meters
+
+ aperture_area = np.pi * (params.aperture**2 - params.obstruction**2) / 4
+ FLX2DN = (params.exposure_time * aperture_area * params.filter_width *
+ params.transmissivity * params.gain * params.quantum_efficiency *
+ (1 - params.extinction) * (wavelength_m / (c * h)))
+ return FLX2DN
+
+def flux_calibration(image: np.ndarray, params: CalibrationParams,
+ response_function: Optional[np.ndarray] = None) -> np.ndarray:
+ """
+ Perform flux calibration on an astronomical image.
+ :param image: Input image (numpy array).
+ :param params: Calibration parameters.
+ :param response_function: Optional instrument response function (numpy array).
+    :return: Tuple of (rescaled image, FLXMIN, FLXRANGE, FLX2DN).
+ """
+ if response_function is not None:
+ image = instrument_response_correction(image, response_function)
+
+ FLX2DN = compute_flx2dn(params)
+ calibrated_image = image / FLX2DN
+ calibrated_image = background_noise_correction(calibrated_image)
+
+ # Rescale the image to the range [0, 1]
+ min_val = np.min(calibrated_image)
+ max_val = np.max(calibrated_image)
+ FLXRANGE = max_val - min_val
+ FLXMIN = min_val
+ rescaled_image = (calibrated_image - min_val) / FLXRANGE
+
+ return rescaled_image, FLXMIN, FLXRANGE, FLX2DN
+
+def save_to_fits(image: np.ndarray, filename: str, FLXMIN: float, FLXRANGE: float,
+                 FLX2DN: float, header_info: Optional[dict] = None) -> None:
+ """
+ Save the calibrated image to a FITS file with necessary header information.
+ :param image: Calibrated image (numpy array).
+ :param filename: Output FITS file name.
+ :param FLXMIN: Minimum flux value used for rescaling.
+ :param FLXRANGE: Range of flux values used for rescaling.
+ :param FLX2DN: Flux conversion factor.
+ :param header_info: Dictionary of additional header information to include.
+ """
+ hdu = fits.PrimaryHDU(image)
+ hdr = hdu.header
+ hdr['FLXMIN'] = FLXMIN
+ hdr['FLXRANGE'] = FLXRANGE
+ hdr['FLX2DN'] = FLX2DN
+
+    # Add additional header information (a mutable {} default is avoided above)
+    for key, value in (header_info or {}).items():
+ hdr[key] = value
+
+ hdu.writeto(filename, overwrite=True)
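+
+
+# Usage sketch: the parameter values are illustrative defaults, not
+# measurements, and the image is random noise.
+if __name__ == "__main__":
+    params = CalibrationParams(wavelength=550, transmissivity=0.8,
+                               filter_width=100, aperture=200, obstruction=50,
+                               exposure_time=60, extinction=0.1, gain=1.5,
+                               quantum_efficiency=0.9)
+    image = np.random.random((128, 128)).astype(np.float32)
+    rescaled, flxmin, flxrange, flx2dn = flux_calibration(image, params)
+    save_to_fits(rescaled, 'calibrated.fits', flxmin, flxrange, flx2dn)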
diff --git a/pysrc/image/fluxcalibration/core.py b/pysrc/image/fluxcalibration/core.py
new file mode 100644
index 00000000..af7773d5
--- /dev/null
+++ b/pysrc/image/fluxcalibration/core.py
@@ -0,0 +1,15 @@
+from dataclasses import dataclass
+import numpy as np
+from typing import Optional
+
+@dataclass
+class CalibrationParams:
+ wavelength: float # Effective filter wavelength in nm
+ transmissivity: float # Filter transmissivity in the range (0,1)
+ filter_width: float # Filter bandwidth in nm
+ aperture: float # Telescope aperture diameter in mm
+ obstruction: float # Telescope central obstruction diameter in mm
+ exposure_time: float # Exposure time in seconds
+ extinction: float # Atmospheric extinction in the range [0,1)
+ gain: float # Sensor gain in e-/ADU
+ quantum_efficiency: float # Sensor quantum efficiency in the range (0,1)
diff --git a/pysrc/image/fluxcalibration/utils.py b/pysrc/image/fluxcalibration/utils.py
new file mode 100644
index 00000000..b3c3739b
--- /dev/null
+++ b/pysrc/image/fluxcalibration/utils.py
@@ -0,0 +1,41 @@
+import numpy as np
+from astropy.io import fits
+from astropy.stats import sigma_clipped_stats
+
+def instrument_response_correction(image: np.ndarray, response_function: np.ndarray) -> np.ndarray:
+ """
+ Apply instrument response correction to the image.
+ :param image: Input image (numpy array).
+ :param response_function: Instrument response function (numpy array of the same shape as the image).
+ :return: Corrected image.
+ """
+ return image / response_function
+
+def background_noise_correction(image: np.ndarray) -> np.ndarray:
+ """
+ Estimate and subtract the background noise from the image.
+ :param image: Input image (numpy array).
+ :return: Image with background noise subtracted.
+ """
+ _, median, _ = sigma_clipped_stats(image, sigma=3.0) # Estimate background
+ return image - median
+
+def read_fits_header(file_path: str) -> dict:
+ """
+ Reads the FITS header and returns the necessary calibration parameters.
+ :param file_path: Path to the FITS file.
+ :return: Dictionary containing calibration parameters.
+ """
+ with fits.open(file_path) as hdul:
+ header = hdul[0].header
+ params = {
+ 'wavelength': header.get('WAVELEN', 550), # nm
+ 'transmissivity': header.get('TRANSMIS', 0.8),
+ 'filter_width': header.get('FILTWDTH', 100), # nm
+ 'aperture': header.get('APERTURE', 200), # mm
+ 'obstruction': header.get('OBSTRUCT', 50), # mm
+ 'exposure_time': header.get('EXPTIME', 60), # seconds
+ 'extinction': header.get('EXTINCT', 0.1),
+ 'gain': header.get('GAIN', 1.5), # e-/ADU
+ 'quantum_efficiency': header.get('QUANTEFF', 0.9),
+ }
+ return params
diff --git a/pysrc/image/image_io/__init__.py b/pysrc/image/image_io/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/pysrc/image/image_io/io.py b/pysrc/image/image_io/io.py
new file mode 100644
index 00000000..0b857551
--- /dev/null
+++ b/pysrc/image/image_io/io.py
@@ -0,0 +1,86 @@
+import cv2
+import numpy as np
+from dataclasses import dataclass
+from typing import Tuple
+
+
+@dataclass
+class ImageProcessor:
+ """
+ A class that provides various image processing functionalities.
+ """
+
+ @staticmethod
+ def save_image(filepath: str, image: np.ndarray) -> None:
+ """
+ Save an image to a specified file path.
+
+ Parameters:
+ - filepath: str, The file path where the image should be saved.
+ - image: np.ndarray, The image data to be saved.
+ """
+ cv2.imwrite(filepath, image)
+
+ @staticmethod
+ def display_image(window_name: str, image: np.ndarray) -> None:
+ """
+ Display an image in a new window.
+
+ Parameters:
+ - window_name: str, The name of the window where the image will be displayed.
+ - image: np.ndarray, The image data to be displayed.
+ """
+ cv2.imshow(window_name, image)
+ cv2.waitKey(0)
+ cv2.destroyAllWindows()
+
+ @staticmethod
+ def load_image(filepath: str, color_mode: int = cv2.IMREAD_COLOR) -> np.ndarray:
+ """
+ Load an image from a specified file path.
+
+ Parameters:
+ - filepath: str, The file path from where the image will be loaded.
+ - color_mode: int, The color mode in which to load the image (default: cv2.IMREAD_COLOR).
+
+ Returns:
+ - np.ndarray, The loaded image data.
+ """
+ return cv2.imread(filepath, color_mode)
+
+ @staticmethod
+ def resize_image(image: np.ndarray, size: Tuple[int, int]) -> np.ndarray:
+ """
+ Resize an image to a new size.
+
+ Parameters:
+ - image: np.ndarray, The image data to be resized.
+ - size: Tuple[int, int], The new size as (width, height).
+
+ Returns:
+ - np.ndarray, The resized image data.
+ """
+ return cv2.resize(image, size)
+
+ @staticmethod
+ def crop_image(image: np.ndarray, start: Tuple[int, int], size: Tuple[int, int]) -> np.ndarray:
+ """
+ Crop a region from an image.
+
+ Parameters:
+ - image: np.ndarray, The image data to be cropped.
+ - start: Tuple[int, int], The top-left corner of the crop region as (x, y).
+ - size: Tuple[int, int], The size of the crop region as (width, height).
+
+ Returns:
+ - np.ndarray, The cropped image data.
+ """
+ x, y = start
+ w, h = size
+ return image[y:y+h, x:x+w]
+
+
+# Defining the public API of the module
+__all__ = [
+ 'ImageProcessor'
+]
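+
+
+# Usage sketch: load, crop, resize and save in one pass; the paths are
+# placeholders.
+if __name__ == "__main__":
+    img = ImageProcessor.load_image('input.png')
+    thumb = ImageProcessor.resize_image(
+        ImageProcessor.crop_image(img, start=(10, 10), size=(200, 200)),
+        size=(64, 64))
+    ImageProcessor.save_image('thumb.png', thumb)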
diff --git a/pysrc/image/raw/__init__.py b/pysrc/image/raw/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/pysrc/image/raw/raw.py b/pysrc/image/raw/raw.py
new file mode 100644
index 00000000..eb159edc
--- /dev/null
+++ b/pysrc/image/raw/raw.py
@@ -0,0 +1,90 @@
+import rawpy
+import cv2
+import numpy as np
+
+class RawImageProcessor:
+    def __init__(self, raw_path):
+        """Initialize the processor and load the RAW image."""
+        self.raw_path = raw_path
+        self.raw = rawpy.imread(raw_path)
+        self.rgb_image = self.raw.postprocess(
+            gamma=(1.0, 1.0),
+            no_auto_bright=True,
+            use_camera_wb=True,
+            output_bps=8
+        )
+        # Convert to the BGR layout OpenCV expects
+        self.bgr_image = cv2.cvtColor(self.rgb_image, cv2.COLOR_RGB2BGR)
+
+    def adjust_contrast(self, alpha=1.0):
+        """Adjust image contrast."""
+        self.bgr_image = cv2.convertScaleAbs(self.bgr_image, alpha=alpha)
+
+    def adjust_brightness(self, beta=0):
+        """Adjust image brightness."""
+        self.bgr_image = cv2.convertScaleAbs(self.bgr_image, beta=beta)
+
+    def apply_sharpening(self):
+        """Sharpen the image with a 3x3 kernel."""
+        kernel = np.array([[0, -1, 0],
+                           [-1, 5, -1],
+                           [0, -1, 0]])
+        self.bgr_image = cv2.filter2D(self.bgr_image, -1, kernel)
+
+    def apply_gamma_correction(self, gamma=1.0):
+        """Apply gamma correction."""
+        inv_gamma = 1.0 / gamma
+        table = np.array([((i / 255.0) ** inv_gamma) * 255
+                          for i in np.arange(0, 256)]).astype("uint8")
+        self.bgr_image = cv2.LUT(self.bgr_image, table)
+
+    def save_image(self, output_path, file_format="png", jpeg_quality=90):
+        """Save the image in the requested format."""
+        if file_format.lower() in ("jpg", "jpeg"):
+            cv2.imwrite(output_path, self.bgr_image, [cv2.IMWRITE_JPEG_QUALITY, jpeg_quality])
+        else:
+            cv2.imwrite(output_path, self.bgr_image)
+
+    def show_image(self, window_name="Image"):
+        """Display the processed image."""
+        cv2.imshow(window_name, self.bgr_image)
+        cv2.waitKey(0)
+        cv2.destroyAllWindows()
+
+    def get_bgr_image(self):
+        """Return the processed BGR image."""
+        return self.bgr_image
+
+    def reset(self):
+        """Reset the image to its original state."""
+        self.bgr_image = cv2.cvtColor(self.rgb_image, cv2.COLOR_RGB2BGR)
+
+# Example usage
+if __name__ == "__main__":
+    # Initialize the RAW image processor
+    processor = RawImageProcessor('path_to_your_image.raw')
+
+    # Adjust contrast
+    processor.adjust_contrast(alpha=1.3)
+
+    # Adjust brightness
+    processor.adjust_brightness(beta=20)
+
+    # Apply sharpening
+    processor.apply_sharpening()
+
+    # Apply gamma correction
+    processor.apply_gamma_correction(gamma=1.2)
+
+    # Display the processed image
+    processor.show_image()
+
+    # Save the processed image
+    processor.save_image('output_image.png')
+
+    # Reset the image
+    processor.reset()
+
+    # Apply further processing and save as JPEG
+    processor.adjust_contrast(alpha=1.1)
+    processor.save_image('output_image.jpg', file_format="jpg", jpeg_quality=85)
diff --git a/pysrc/image/resample/__init__.py b/pysrc/image/resample/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/pysrc/image/resample/resample.py b/pysrc/image/resample/resample.py
new file mode 100644
index 00000000..95481218
--- /dev/null
+++ b/pysrc/image/resample/resample.py
@@ -0,0 +1,185 @@
+import cv2
+import numpy as np
+from PIL import Image
+import os
+from typing import Optional, Tuple, Literal
+
+
+def resample_image(input_image_path: str,
+ output_image_path: str,
+ width: Optional[int] = None,
+ height: Optional[int] = None,
+ scale: Optional[float] = None,
+ resolution: Optional[Tuple[int, int]] = None,
+ interpolation: int = cv2.INTER_LINEAR,
+ preserve_aspect_ratio: bool = True,
+ crop_area: Optional[Tuple[int, int, int, int]] = None,
+ edge_detection: bool = False,
+ color_space: Literal['BGR', 'GRAY', 'HSV', 'RGB'] = 'BGR',
+ add_watermark: bool = False,
+ watermark_text: str = '',
+ watermark_position: Tuple[int, int] = (0, 0),
+ watermark_opacity: float = 0.5,
+ batch_mode: bool = False,
+ output_format: str = 'jpg',
+ brightness: float = 1.0,
+ contrast: float = 1.0,
+ sharpen: bool = False,
+ rotate_angle: Optional[float] = None):
+ """
+ Resamples an image with given dimensions, scale, resolution, and additional processing options.
+
+ :param input_image_path: Path to the input image (or directory in batch mode)
+ :param output_image_path: Path to save the resampled image(s)
+ :param width: Desired width in pixels
+ :param height: Desired height in pixels
+ :param scale: Scale factor for resizing (e.g., 0.5 for half size, 2.0 for double size)
+ :param resolution: Tuple of horizontal and vertical resolution (in dpi)
+ :param interpolation: Interpolation method (e.g., cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_NEAREST)
+ :param preserve_aspect_ratio: Whether to preserve the aspect ratio of the original image
+ :param crop_area: Tuple defining the crop area (x, y, w, h)
+ :param edge_detection: Whether to apply edge detection before resizing
+ :param color_space: Color space to convert the image to (e.g., 'GRAY', 'HSV', 'RGB')
+ :param add_watermark: Whether to add a watermark to the image
+ :param watermark_text: The text to be used as a watermark
+ :param watermark_position: Position tuple (x, y) for the watermark
+ :param watermark_opacity: Opacity level for the watermark (0.0 to 1.0)
+ :param batch_mode: Whether to process multiple images in a directory
+ :param output_format: Output image format (e.g., 'jpg', 'png', 'tiff')
+ :param brightness: Brightness adjustment factor (1.0 = no change)
+ :param contrast: Contrast adjustment factor (1.0 = no change)
+ :param sharpen: Whether to apply sharpening to the image
+ :param rotate_angle: Angle to rotate the image (in degrees)
+ """
+
+ def adjust_brightness_contrast(image: np.ndarray, brightness: float = 1.0, contrast: float = 1.0) -> np.ndarray:
+ # Clip the pixel values to be in the valid range after adjustment
+ adjusted = cv2.convertScaleAbs(
+ image, alpha=contrast, beta=(brightness - 1.0) * 255)
+ return adjusted
+
+ def add_text_watermark(image: np.ndarray, text: str, position: Tuple[int, int], opacity: float) -> np.ndarray:
+ overlay = image.copy()
+ output = image.copy()
+ font = cv2.FONT_HERSHEY_SIMPLEX
+ font_scale = 1
+ thickness = 2
+ text_size = cv2.getTextSize(text, font, font_scale, thickness)[0]
+ text_x = position[0]
+ text_y = position[1] + text_size[1]
+ cv2.putText(overlay, text, (text_x, text_y), font,
+ font_scale, (255, 255, 255), thickness)
+ # Combine original image with overlay
+ cv2.addWeighted(overlay, opacity, output, 1 - opacity, 0, output)
+ return output
+
+ def sharpen_image(image: np.ndarray) -> np.ndarray:
+ kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])
+ sharpened = cv2.filter2D(image, -1, kernel)
+ return sharpened
+
+ def process_image(image_path: str, output_path: str):
+ img = cv2.imread(image_path)
+ if img is None:
+ raise ValueError(f"Cannot load image from {image_path}")
+
+        # Crop first so that scaling and resizing use the cropped size
+        if crop_area:
+            x, y, w, h = crop_area
+            img = img[y:y+h, x:x+w]
+
+        original_height, original_width = img.shape[:2]
+
+ # Edge detection
+ if edge_detection:
+ img = cv2.Canny(img, 100, 200)
+
+ # Convert color space if needed
+ # Only convert if image is not already grayscale
+ if color_space == 'GRAY' and len(img.shape) == 3:
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+ # Only convert if image has 3 channels
+ elif color_space == 'HSV' and len(img.shape) == 3:
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
+ # Convert to RGB if required
+ elif color_space == 'RGB' and len(img.shape) == 3:
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+ # Calculate new dimensions
+ if scale:
+ new_width = int(original_width * scale)
+ new_height = int(original_height * scale)
+        else:
+            if width and height:
+                new_width = width
+                new_height = height
+            elif width:
+                new_width = width
+                # Fall back to the original height when no height is given and
+                # the aspect ratio is not preserved
+                new_height = int(
+                    (width / original_width) * original_height) if preserve_aspect_ratio else original_height
+            elif height:
+                new_height = height
+                new_width = int(
+                    (height / original_height) * original_width) if preserve_aspect_ratio else original_width
+            else:
+                new_width, new_height = original_width, original_height
+
+ # Perform resizing
+ resized_img = cv2.resize(
+ img, (new_width, new_height), interpolation=interpolation)
+
+ # Adjust brightness and contrast
+ resized_img = adjust_brightness_contrast(
+ resized_img, brightness, contrast)
+
+ # Apply sharpening if needed
+ if sharpen:
+ resized_img = sharpen_image(resized_img)
+
+ # Rotate the image if needed
+ if rotate_angle:
+ center = (new_width // 2, new_height // 2)
+ rotation_matrix = cv2.getRotationMatrix2D(
+ center, rotate_angle, 1.0)
+ resized_img = cv2.warpAffine(
+ resized_img, rotation_matrix, (new_width, new_height))
+
+ # Add watermark if needed
+ if add_watermark:
+ resized_img = add_text_watermark(
+ resized_img, watermark_text, watermark_position, watermark_opacity)
+
+ # Save the image
+        # Save the image
+        if resolution:
+            dpi_x, dpi_y = resolution
+            pil_img = Image.fromarray(cv2.cvtColor(resized_img, cv2.COLOR_BGR2RGB) if len(
+                resized_img.shape) == 3 else resized_img)
+            # PIL expects canonical format names such as 'JPEG', not 'jpg'
+            pil_format = 'JPEG' if output_format.lower() in ('jpg', 'jpeg') else output_format.upper()
+            pil_img.save(output_path, dpi=(dpi_x, dpi_y), format=pil_format)
+        else:
+            cv2.imwrite(output_path, resized_img)
+
+ # Batch processing mode
+ if batch_mode:
+ if not os.path.isdir(input_image_path):
+ raise ValueError(
+ "In batch mode, input_image_path must be a directory")
+ if not os.path.exists(output_image_path):
+ os.makedirs(output_image_path)
+        for filename in os.listdir(input_image_path):
+            file_path = os.path.join(input_image_path, filename)
+            if not os.path.isfile(file_path):
+                continue  # Skip subdirectories and other non-file entries
+            output_file_path = os.path.join(
+                output_image_path, f"{os.path.splitext(filename)[0]}.{output_format}")
+            process_image(file_path, output_file_path)
+ else:
+ process_image(input_image_path, output_image_path)
+
+
+# Example usage:
+if __name__ == "__main__":
+    input_path = 'star_image.png'
+    output_path = './output/star_image_processed.png'
+    os.makedirs('./output', exist_ok=True)
+
+    # Resize, apply edge detection, adjust brightness/contrast, add a
+    # watermark, sharpen, and rotate in a single call
+    resample_image(input_path, output_path, width=800, height=600, interpolation=cv2.INTER_CUBIC,
+                   resolution=(300, 300), crop_area=(100, 100, 400, 400), edge_detection=True,
+                   color_space='GRAY', add_watermark=True, watermark_text='Sample Watermark',
+                   watermark_position=(10, 10), watermark_opacity=0.7, brightness=1.2,
+                   contrast=1.5, output_format='png', sharpen=True, rotate_angle=45)
diff --git a/pysrc/image/star_detection/__init__.py b/pysrc/image/star_detection/__init__.py
new file mode 100644
index 00000000..b498b982
--- /dev/null
+++ b/pysrc/image/star_detection/__init__.py
@@ -0,0 +1,3 @@
+from .detection import StarDetectionConfig, multiscale_detect_stars
+from .clustering import cluster_stars
+from .preprocessing import load_image
diff --git a/pysrc/image/star_detection/clustering.py b/pysrc/image/star_detection/clustering.py
new file mode 100644
index 00000000..365b9db3
--- /dev/null
+++ b/pysrc/image/star_detection/clustering.py
@@ -0,0 +1,32 @@
+import numpy as np
+from sklearn.cluster import DBSCAN
+from typing import List, Tuple
+
+def cluster_stars(stars: List[Tuple[int, int]], dbscan_eps: float, dbscan_min_samples: int) -> List[Tuple[int, int]]:
+ """
+ Cluster stars using the DBSCAN algorithm.
+
+ Parameters:
+ - stars: List of star positions as (x, y) tuples.
+ - dbscan_eps: The maximum distance between two stars for them to be considered in the same neighborhood.
+ - dbscan_min_samples: The number of stars in a neighborhood for a point to be considered a core point.
+
+ Returns:
+ - List of clustered star positions as (x, y) tuples.
+ """
+ if len(stars) == 0:
+ return []
+
+ clustering = DBSCAN(eps=dbscan_eps, min_samples=dbscan_min_samples).fit(stars)
+ labels = clustering.labels_
+
+ unique_labels = set(labels)
+ clustered_stars = []
+ for label in unique_labels:
+ if label == -1: # -1 indicates noise
+ continue
+ class_members = [stars[i] for i in range(len(stars)) if labels[i] == label]
+ centroid = np.mean(class_members, axis=0).astype(int)
+ clustered_stars.append(tuple(centroid))
+
+ return clustered_stars
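+
+# Usage sketch (illustrative): three nearby detections collapse to a single
+# centroid, while the distant point is discarded as DBSCAN noise.
+# >>> cluster_stars([(10, 10), (11, 11), (12, 10), (100, 100)], dbscan_eps=5.0, dbscan_min_samples=2)
+# [(11, 10)]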
diff --git a/pysrc/image/star_detection/detection.py b/pysrc/image/star_detection/detection.py
new file mode 100644
index 00000000..e4a84298
--- /dev/null
+++ b/pysrc/image/star_detection/detection.py
@@ -0,0 +1,84 @@
+import cv2
+import numpy as np
+from .preprocessing import apply_median_filter, wavelet_transform, inverse_wavelet_transform, binarize, detect_stars, background_subtraction
+from typing import List, Tuple
+
+class StarDetectionConfig:
+ """
+ Configuration class for star detection settings.
+ """
+ def __init__(self,
+ median_filter_size: int = 3,
+ wavelet_levels: int = 4,
+ binarization_threshold: int = 30,
+ min_star_size: int = 10,
+ min_star_brightness: int = 20,
+ min_circularity: float = 0.7,
+ max_circularity: float = 1.3,
+ scales: List[float] = [1.0, 0.75, 0.5],
+ dbscan_eps: float = 10,
+ dbscan_min_samples: int = 2):
+ self.median_filter_size = median_filter_size
+ self.wavelet_levels = wavelet_levels
+ self.binarization_threshold = binarization_threshold
+ self.min_star_size = min_star_size
+ self.min_star_brightness = min_star_brightness
+ self.min_circularity = min_circularity
+ self.max_circularity = max_circularity
+ self.scales = scales
+ self.dbscan_eps = dbscan_eps
+ self.dbscan_min_samples = dbscan_min_samples
+
+
+def multiscale_detect_stars(image: np.ndarray, config: StarDetectionConfig) -> List[Tuple[int, int]]:
+ """
+ Detect stars in an image using multiscale analysis.
+
+ Parameters:
+ - image: Grayscale input image as a numpy array.
+ - config: Configuration object containing detection settings.
+
+ Returns:
+ - List of detected star positions as (x, y) tuples.
+ """
+ all_stars = []
+ for scale in config.scales:
+ resized_image = cv2.resize(image, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
+ filtered_image = apply_median_filter(resized_image, config)
+        pyramid = wavelet_transform(filtered_image, config.wavelet_levels)
+        # The coarsest pyramid level approximates the smooth sky background;
+        # subtract it, then reconstruct from the two finest detail layers
+        background = pyramid[-1]
+        subtracted_image = background_subtraction(filtered_image, background)
+        detail_pyramid = wavelet_transform(subtracted_image, config.wavelet_levels)[:2]
+        processed_image = inverse_wavelet_transform(detail_pyramid)
+        binary_image = binarize(processed_image, config)
+ _, star_properties = detect_stars(binary_image)
+ filtered_stars = filter_stars(star_properties, binary_image, config)
+ # Adjust star positions back to original scale
+ filtered_stars = [(int(x / scale), int(y / scale)) for (x, y) in filtered_stars]
+ all_stars.extend(filtered_stars)
+ return all_stars
+
+def filter_stars(star_properties: List[Tuple[Tuple[int, int], float, float]], binary_image: np.ndarray, config: StarDetectionConfig) -> List[Tuple[int, int]]:
+ """
+ Filter detected stars based on shape, size, and brightness.
+
+ Parameters:
+ - star_properties: List of tuples containing star properties (center, area, perimeter).
+ - binary_image: Binary image used for star detection.
+ - config: Configuration object containing filter settings.
+
+ Returns:
+ - List of filtered star positions as (x, y) tuples.
+ """
+ filtered_stars = []
+    for (center, area, perimeter) in star_properties:
+        if perimeter == 0:
+            continue  # Degenerate contour: circularity is undefined
+        circularity = (4 * np.pi * area) / (perimeter ** 2)
+ mask = np.zeros_like(binary_image)
+ cv2.circle(mask, center, 5, 255, -1)
+ star_pixels = cv2.countNonZero(mask)
+ brightness = np.mean(binary_image[mask == 255])
+ if (star_pixels > config.min_star_size and
+ brightness > config.min_star_brightness and
+ config.min_circularity <= circularity <= config.max_circularity):
+ filtered_stars.append(center)
+ return filtered_stars
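+
+# Usage sketch (illustrative): run the multiscale detector, then merge the
+# duplicate detections produced at different scales with DBSCAN clustering.
+# >>> from .preprocessing import load_image
+# >>> from .clustering import cluster_stars
+# >>> config = StarDetectionConfig()
+# >>> image = load_image("field.fits")  # hypothetical input file
+# >>> raw = multiscale_detect_stars(image, config)
+# >>> stars = cluster_stars(raw, config.dbscan_eps, config.dbscan_min_samples)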
diff --git a/pysrc/image/star_detection/preprocessing.py b/pysrc/image/star_detection/preprocessing.py
new file mode 100644
index 00000000..16753dbe
--- /dev/null
+++ b/pysrc/image/star_detection/preprocessing.py
@@ -0,0 +1,179 @@
+import cv2
+import numpy as np
+from astropy.io import fits
+from typing import List, Tuple
+
+def load_fits_image(file_path: str) -> np.ndarray:
+ """
+ Load a FITS image from the specified file path.
+
+ Parameters:
+ - file_path: Path to the FITS file.
+
+ Returns:
+ - Image data as a numpy array.
+ """
+ with fits.open(file_path) as hdul:
+ image_data = hdul[0].data
+ return image_data
+
+def preprocess_fits_image(image_data: np.ndarray) -> np.ndarray:
+ """
+ Preprocess FITS image by normalizing to the 0-255 range.
+
+ Parameters:
+ - image_data: Raw image data from the FITS file.
+
+ Returns:
+ - Preprocessed image data as a numpy array.
+ """
+    image_data = np.nan_to_num(image_data)
+    image_data = image_data.astype(np.float64)
+    image_data -= np.min(image_data)
+    max_val = np.max(image_data)
+    if max_val > 0:  # Guard against division by zero on constant frames
+        image_data /= max_val
+    image_data *= 255
+    return image_data.astype(np.uint8)
+
+def load_image(file_path: str) -> np.ndarray:
+ """
+ Load an image from the specified file path. Supports FITS and standard image formats.
+
+ Parameters:
+ - file_path: Path to the image file.
+
+ Returns:
+ - Loaded image as a numpy array.
+ """
+    if file_path.endswith('.fits'):
+        image_data = load_fits_image(file_path)
+        if image_data.ndim == 2:
+            return preprocess_fits_image(image_data)
+        elif image_data.ndim == 3:
+            channels = [preprocess_fits_image(image_data[..., i]) for i in range(image_data.shape[2])]
+            return cv2.merge(channels)
+        else:
+            raise ValueError(f"Unsupported FITS image dimensionality: {image_data.ndim}")
+ else:
+ image = cv2.imread(file_path, cv2.IMREAD_UNCHANGED)
+ if image is None:
+ raise ValueError("Unable to load image file: {}".format(file_path))
+ return image
+
+def apply_median_filter(image: np.ndarray, config) -> np.ndarray:
+ """
+ Apply median filtering to the image.
+
+ Parameters:
+ - image: Input image.
+ - config: Configuration object containing median filter settings.
+
+ Returns:
+ - Filtered image.
+ """
+ return cv2.medianBlur(image, config.median_filter_size)
+
+def wavelet_transform(image: np.ndarray, levels: int) -> List[np.ndarray]:
+ """
+ Perform wavelet transform using a Laplacian pyramid.
+
+ Parameters:
+ - image: Input image.
+ - levels: Number of levels in the wavelet transform.
+
+ Returns:
+ - List of wavelet transformed images at each level.
+ """
+ pyramid = []
+ current_image = image.copy()
+ for _ in range(levels):
+        down = cv2.pyrDown(current_image)
+        up = cv2.pyrUp(down)
+
+        # Resize up to match the original image size; pyrUp can be off by one
+        # pixel when a dimension is odd
+        up = cv2.resize(up, (current_image.shape[1], current_image.shape[0]))
+
+ # Calculate the difference to get the detail layer
+ layer = cv2.subtract(current_image, up)
+ pyramid.append(layer)
+ current_image = down
+
+ pyramid.append(current_image) # Add the final low-resolution image
+ return pyramid
+
+def inverse_wavelet_transform(pyramid: List[np.ndarray]) -> np.ndarray:
+ """
+ Reconstruct the image from its wavelet pyramid representation.
+
+ Parameters:
+ - pyramid: List of wavelet transformed images at each level.
+
+ Returns:
+ - Reconstructed image.
+ """
+ image = pyramid.pop()
+ while pyramid:
+        up = cv2.pyrUp(image)
+
+        # Resize up to match the size of the current level
+        up = cv2.resize(up, (pyramid[-1].shape[1], pyramid[-1].shape[0]))
+
+ # Add the detail layer to reconstruct the image
+ image = cv2.add(up, pyramid.pop())
+
+ return image
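+
+# Round-trip sketch (illustrative): reconstructing an N-level pyramid should
+# recover the input up to pyrDown/pyrUp rounding and uint8 saturation error.
+# >>> img = (np.random.rand(64, 64) * 255).astype(np.uint8)
+# >>> restored = inverse_wavelet_transform(wavelet_transform(img, 3))
+# >>> restored.shape == img.shape
+# True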
+
+def background_subtraction(image: np.ndarray, background: np.ndarray) -> np.ndarray:
+ """
+ Subtract the background from the image using the provided background image.
+
+ Parameters:
+ - image: Original image.
+ - background: Background image to subtract.
+
+ Returns:
+ - Image with background subtracted.
+ """
+ # Resize the background to match the original image size
+ background_resized = cv2.resize(background, (image.shape[1], image.shape[0]))
+
+    # cv2.subtract saturates at zero for unsigned types, so no negative
+    # values can remain
+    result = cv2.subtract(image, background_resized)
+    return result
+
+def binarize(image: np.ndarray, config) -> np.ndarray:
+ """
+ Binarize the image using a fixed threshold.
+
+ Parameters:
+ - image: Input image.
+ - config: Configuration object containing binarization settings.
+
+ Returns:
+ - Binarized image.
+ """
+ _, binary_image = cv2.threshold(image, config.binarization_threshold, 255, cv2.THRESH_BINARY)
+ return binary_image
+
+def detect_stars(binary_image: np.ndarray) -> Tuple[List[Tuple[int, int]], List[Tuple[Tuple[int, int], float, float]]]:
+ """
+ Detect stars in a binary image by finding contours.
+
+ Parameters:
+ - binary_image: Binarized image.
+
+ Returns:
+ - Tuple containing a list of star centers and a list of star properties (center, area, perimeter).
+ """
+ contours, _ = cv2.findContours(binary_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+ star_centers = []
+ star_properties = []
+
+ for contour in contours:
+ M = cv2.moments(contour)
+ if M['m00'] != 0:
+ center = (int(M['m10'] / M['m00']), int(M['m01'] / M['m00']))
+ star_centers.append(center)
+ area = cv2.contourArea(contour)
+ perimeter = cv2.arcLength(contour, True)
+ star_properties.append((center, area, perimeter))
+
+ return star_centers, star_properties
diff --git a/pysrc/image/star_detection/utils.py b/pysrc/image/star_detection/utils.py
new file mode 100644
index 00000000..8c7fa420
--- /dev/null
+++ b/pysrc/image/star_detection/utils.py
@@ -0,0 +1,24 @@
+import cv2
+import numpy as np
+
+def save_image(filepath: str, image: np.ndarray) -> None:
+ """
+ Save an image to a specified file path.
+
+ Parameters:
+ - filepath: The file path where the image should be saved.
+ - image: The image data to be saved.
+ """
+ cv2.imwrite(filepath, image)
+
+def display_image(window_name: str, image: np.ndarray) -> None:
+ """
+ Display an image in a new window.
+
+ Parameters:
+ - window_name: The name of the window where the image will be displayed.
+ - image: The image data to be displayed.
+ """
+ cv2.imshow(window_name, image)
+ cv2.waitKey(0)
+ cv2.destroyAllWindows()
diff --git a/pysrc/image/transformation/__init__.py b/pysrc/image/transformation/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/pysrc/image/transformation/curve.py b/pysrc/image/transformation/curve.py
new file mode 100644
index 00000000..8275158a
--- /dev/null
+++ b/pysrc/image/transformation/curve.py
@@ -0,0 +1,154 @@
+import numpy as np
+import matplotlib.pyplot as plt
+from scipy.interpolate import CubicSpline, Akima1DInterpolator
+
+
+class CurvesTransformation:
+ def __init__(self, interpolation='akima'):
+ self.points = []
+ self.interpolation = interpolation
+ self.curve = None
+ self.stored_curve = None
+
+ def add_point(self, x, y):
+ self.points.append((x, y))
+ self.points.sort() # Sort points by x value
+ self._update_curve()
+
+ def remove_point(self, index):
+ if 0 <= index < len(self.points):
+ self.points.pop(index)
+ self._update_curve()
+
+ def _update_curve(self):
+ if len(self.points) < 2:
+ self.curve = None
+ return
+
+ x, y = zip(*self.points)
+
+ if self.interpolation == 'cubic':
+ self.curve = CubicSpline(x, y)
+ elif self.interpolation == 'akima':
+ self.curve = Akima1DInterpolator(x, y)
+ elif self.interpolation == 'linear':
+ self.curve = lambda x_new: np.interp(x_new, x, y)
+ else:
+ raise ValueError("Unsupported interpolation method")
+
+ def transform(self, image, channel=None):
+ if self.curve is None:
+ raise ValueError("No valid curve defined")
+
+ if len(image.shape) == 2: # Grayscale image
+ transformed_image = self.curve(image)
+ elif len(image.shape) == 3: # RGB image
+ if channel is None:
+ raise ValueError("Channel must be specified for color images")
+ transformed_image = image.copy()
+ transformed_image[:, :, channel] = self.curve(image[:, :, channel])
+ else:
+ raise ValueError("Unsupported image format")
+
+ transformed_image = np.clip(transformed_image, 0, 1)
+ return transformed_image
+
+ def plot_curve(self):
+ if self.curve is None:
+ print("No curve to plot")
+ return
+
+ x_vals = np.linspace(0, 1, 100)
+ y_vals = self.curve(x_vals)
+
+ plt.plot(x_vals, y_vals, label=f'Interpolation: {self.interpolation}')
+ plt.scatter(*zip(*self.points), color='red')
+ plt.title('Curves Transformation')
+ plt.xlabel('Input')
+ plt.ylabel('Output')
+ plt.grid(True)
+ plt.legend()
+ plt.show()
+
+ def store_curve(self):
+ self.stored_curve = self.points.copy()
+ print("Curve stored.")
+
+ def restore_curve(self):
+ if self.stored_curve:
+ self.points = self.stored_curve.copy()
+ self._update_curve()
+ print("Curve restored.")
+ else:
+ print("No stored curve to restore.")
+
+ def invert_curve(self):
+ if self.curve is None:
+ print("No curve to invert")
+ return
+
+ self.points = [(x, 1 - y) for x, y in self.points]
+ self._update_curve()
+ print("Curve inverted.")
+
+ def reset_curve(self):
+ self.points = [(0, 0), (1, 1)]
+ self._update_curve()
+ print("Curve reset to default.")
+
+ def pixel_readout(self, x):
+ if self.curve is None:
+ print("No curve defined")
+ return None
+ return self.curve(x)
+
+
+# Example Usage
+if __name__ == "__main__":
+ # Create a CurvesTransformation object
+ curve_transform = CurvesTransformation(interpolation='akima')
+
+ # Add points to the curve
+ curve_transform.add_point(0.0, 0.0)
+ curve_transform.add_point(0.3, 0.5)
+ curve_transform.add_point(0.7, 0.8)
+ curve_transform.add_point(1.0, 1.0)
+
+ # Plot the curve
+ curve_transform.plot_curve()
+
+ # Store the curve
+ curve_transform.store_curve()
+
+ # Invert the curve
+ curve_transform.invert_curve()
+ curve_transform.plot_curve()
+
+ # Restore the original curve
+ curve_transform.restore_curve()
+ curve_transform.plot_curve()
+
+ # Reset the curve to default
+ curve_transform.reset_curve()
+ curve_transform.plot_curve()
+
+ # Generate a test image
+ test_image = np.linspace(0, 1, 256).reshape(16, 16)
+
+ # Apply the transformation
+ transformed_image = curve_transform.transform(test_image)
+
+ # Plot original and transformed images
+ plt.figure(figsize=(8, 4))
+ plt.subplot(1, 2, 1)
+ plt.title("Original Image")
+ plt.imshow(test_image, cmap='gray')
+
+ plt.subplot(1, 2, 2)
+ plt.title("Transformed Image")
+ plt.imshow(transformed_image, cmap='gray')
+ plt.show()
+
+ # Pixel readout
+ readout_value = curve_transform.pixel_readout(0.5)
+ print(f"Pixel readout at x=0.5: {readout_value}")
diff --git a/pysrc/image/transformation/histogram.py b/pysrc/image/transformation/histogram.py
new file mode 100644
index 00000000..80f8f390
--- /dev/null
+++ b/pysrc/image/transformation/histogram.py
@@ -0,0 +1,121 @@
+import cv2
+import numpy as np
+import matplotlib.pyplot as plt
+
+# 1. Histogram calculation
+
+
+def calculate_histogram(image, channel=0):
+ histogram = cv2.calcHist([image], [channel], None, [256], [0, 256])
+ return histogram
+
+# 2. Display a histogram
+
+
+def display_histogram(histogram, title="Histogram"):
+ plt.plot(histogram)
+ plt.title(title)
+ plt.xlabel('Pixel Intensity')
+ plt.ylabel('Frequency')
+ plt.show()
+
+# 3. Histogram transformation
+
+
+def apply_histogram_transformation(image, shadows_clip=0.0, highlights_clip=1.0, midtones_balance=0.5, lower_bound=-1.0, upper_bound=2.0):
+    # Normalize to [0, 1]
+ normalized_image = image.astype(np.float32) / 255.0
+
+    # Clip shadows and highlights
+ clipped_image = np.clip(
+ (normalized_image - shadows_clip) / (highlights_clip - shadows_clip), 0, 1)
+
+    # Midtones balance via a midtones transfer function (MTF)
+    def mtf(x):
+        return (x**midtones_balance) / (
+            (x**midtones_balance + (1 - x)**midtones_balance)**(1 / midtones_balance))
+    transformed_image = mtf(clipped_image)
+
+    # Expand the dynamic range
+ expanded_image = np.clip(
+ (transformed_image - lower_bound) / (upper_bound - lower_bound), 0, 1)
+
+    # Rescale to [0, 255]
+ output_image = (expanded_image * 255).astype(np.uint8)
+ return output_image
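+
+# Worked example (illustrative): midtones_balance=1.0 is the identity
+# (mtf(x) = x / (x + (1 - x)) = x), while midtones_balance=0.5 darkens the
+# midtones: mtf(0.5) = sqrt(0.5) / (sqrt(0.5) + sqrt(0.5))**2 ≈ 0.354.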
+
+# 4. Automatic clipping
+
+
+def auto_clip(image, clip_percent=0.01):
+    # Compute the cumulative distribution function (CDF)
+ hist, bins = np.histogram(image.flatten(), 256, [0, 256])
+ cdf = hist.cumsum()
+
+    # Find the clip points
+ total_pixels = image.size
+ lower_clip = np.searchsorted(cdf, total_pixels * clip_percent)
+ upper_clip = np.searchsorted(cdf, total_pixels * (1 - clip_percent))
+
+    # Apply the clip
+ auto_clipped_image = apply_histogram_transformation(
+ image, shadows_clip=lower_clip/255.0, highlights_clip=upper_clip/255.0)
+
+ return auto_clipped_image
+
+# 5. Display the RGB histogram
+
+
+def display_rgb_histogram(image):
+ color = ('b', 'g', 'r')
+ for i, col in enumerate(color):
+ hist = calculate_histogram(image, channel=i)
+ plt.plot(hist, color=col)
+ plt.title('RGB Histogram')
+ plt.xlabel('Pixel Intensity')
+ plt.ylabel('Frequency')
+ plt.show()
+
+# 6. Real-time preview (simple simulation)
+
+
+def real_time_preview(image, transformation_function, **kwargs):
+ preview_image = transformation_function(image, **kwargs)
+ cv2.imshow('Real-Time Preview', preview_image)
+
+
+# Main entry point
+if __name__ == "__main__":
+    # Load the image
+    image = cv2.imread('image.jpg')
+    if image is None:
+        raise FileNotFoundError("Could not read image.jpg")
+
+    # Convert to grayscale
+ grayscale_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+
+    # Show the original image and its histogram
+ cv2.imshow('Original Image', image)
+ histogram = calculate_histogram(grayscale_image)
+ display_histogram(histogram, title="Original Grayscale Histogram")
+
+    # Show the RGB histogram
+ display_rgb_histogram(image)
+
+    # Apply the histogram transformation
+ transformed_image = apply_histogram_transformation(
+ grayscale_image, shadows_clip=0.1, highlights_clip=0.9, midtones_balance=0.4, lower_bound=-0.5, upper_bound=1.5)
+ cv2.imshow('Transformed Image', transformed_image)
+
+    # Show the transformed histogram
+ transformed_histogram = calculate_histogram(transformed_image)
+ display_histogram(transformed_histogram,
+ title="Transformed Grayscale Histogram")
+
+    # Apply automatic clipping
+ auto_clipped_image = auto_clip(grayscale_image, clip_percent=0.01)
+ cv2.imshow('Auto Clipped Image', auto_clipped_image)
+
+    # Simulated real-time preview
+ real_time_preview(grayscale_image, apply_histogram_transformation,
+ shadows_clip=0.05, highlights_clip=0.95, midtones_balance=0.5)
+
+ cv2.waitKey(0)
+ cv2.destroyAllWindows()
diff --git a/pysrc/main.py b/pysrc/main.py
index ff25169d..8398c8ef 100644
--- a/pysrc/main.py
+++ b/pysrc/main.py
@@ -9,6 +9,8 @@
from app.connection_manager import ConnectionManager
from app.plugin_manager import load_plugins, start_plugin_watcher, stop_plugin_watcher, update_plugin, install_plugin, get_plugin_info, check_plugin_dependencies
+from router import websocket
+
# Configure the loguru logging system
logger.add("server.log", level="DEBUG", format="{time} {level} {message}", rotation="10 MB")
@@ -71,28 +73,6 @@ def get_current_username(credentials: HTTPBasicCredentials = Depends(security)):
logger.info(f"Authenticated user: {credentials.username}")
return credentials.username
-@app.websocket("/ws")
-async def websocket_endpoint(websocket: WebSocket, username: str = Depends(get_current_username)):
- """
- WebSocket endpoint for handling client connections.
- """
- client_id = await manager.connect(websocket)
- await manager.broadcast(f'{{"type": "Server_msg", "message": "Client {client_id} connected"}}')
-
- try:
- async with websocket:
- while True:
- data = await websocket.receive_text()
- logger.info(f"Received message from {client_id}: {data}")
- await manager.broadcast(data)
- except WebSocketDisconnect:
- manager.disconnect(client_id)
- await manager.broadcast(f'{{"type": "Server_msg", "message": "Client {client_id} disconnected"}}')
- except Exception as e:
- logger.error(f"Unexpected error with client {client_id}: {e}")
- manager.disconnect(client_id)
- await manager.broadcast(f'{{"type": "Server_msg", "message": "Client {client_id} disconnected due to error"}}')
-
# Heartbeat function to check if clients are still connected
async def ping():
"""
diff --git a/pysrc/router/__init__.py b/pysrc/router/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/pysrc/router/websocket.py b/pysrc/router/websocket.py
new file mode 100644
index 00000000..172adc4a
--- /dev/null
+++ b/pysrc/router/websocket.py
@@ -0,0 +1,67 @@
+from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends
+from app.connection_manager import ConnectionManager
+from app.dependence import get_current_username
+from app.command_dispatcher import CommandDispatcher
+from loguru import logger
+import json
+import asyncio
+from typing import Dict, Any
+
+router = APIRouter()
+manager = ConnectionManager()
+command_dispatcher = CommandDispatcher()
+
+@router.websocket("/ws")
+async def websocket_endpoint(websocket: WebSocket, username: str = Depends(get_current_username)):
+ client_id = await manager.connect(websocket)
+ await manager.broadcast(json.dumps({"type": "Server_msg", "message": f"Client {client_id} connected"}))
+
+ try:
+ # Start a background task for handling heartbeats
+ heartbeat_task = asyncio.create_task(handle_heartbeat(client_id, websocket))
+
+ while True:
+ data = await websocket.receive_text()
+ await process_message(client_id, data)
+    except WebSocketDisconnect:
+        await handle_disconnect(client_id, f"Client {client_id} disconnected")
+    except Exception as e:
+        logger.error(f"Unexpected error with client {client_id}: {e}")
+        await handle_disconnect(client_id, f"Client {client_id} disconnected due to error: {str(e)}")
+ finally:
+ heartbeat_task.cancel()
+
+async def process_message(client_id: str, data: str):
+ logger.info(f"Received message from {client_id}: {data}")
+ try:
+ message = json.loads(data)
+ if isinstance(message, dict) and "command" in message:
+ response = await command_dispatcher.dispatch(message["command"], message.get("params", {}))
+ await manager.send_personal_message(json.dumps(response), client_id)
+ else:
+ await manager.broadcast(data)
+ except json.JSONDecodeError:
+ logger.warning(f"Received invalid JSON from client {client_id}")
+ await manager.send_personal_message(json.dumps({"error": "Invalid JSON"}), client_id)
+
+async def handle_disconnect(client_id: str, message: str):
+ manager.disconnect(client_id)
+ await manager.broadcast(json.dumps({"type": "Server_msg", "message": message}))
+
+async def handle_heartbeat(client_id: str, websocket: WebSocket):
+ while True:
+ try:
+ await asyncio.sleep(30) # Send heartbeat every 30 seconds
+ await websocket.send_text(json.dumps({"type": "heartbeat"}))
+ except Exception as e:
+ logger.error(f"Heartbeat failed for client {client_id}: {e}")
+ break
+
+# Register commands
+@command_dispatcher.register("echo")
+async def echo_command(params: Dict[str, Any]) -> Dict[str, Any]:
+ return {"result": params.get("message", "No message provided")}
+
+@command_dispatcher.register("get_active_clients")
+async def get_active_clients_command(params: Dict[str, Any]) -> Dict[str, Any]:
+ return {"result": len(manager.active_connections)}
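+
+# Example exchange (illustrative): a client sends a command envelope and gets
+# a personal reply; any other text is broadcast to every connected client.
+#   client -> {"command": "echo", "params": {"message": "hi"}}
+#   server -> {"result": "hi"}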
diff --git a/scripts/docker.py b/scripts/docker.py
new file mode 100644
index 00000000..de22a67a
--- /dev/null
+++ b/scripts/docker.py
@@ -0,0 +1,406 @@
+import docker
+import argparse
+import sys
+import json
+from typing import Literal, Union
+from dataclasses import dataclass, asdict
+from prettytable import PrettyTable
+import os
+import tempfile
+import tarfile
+import subprocess
+
+
+@dataclass
+class ContainerInfo:
+ id: str
+ name: str
+ status: str
+ image: str
+ ports: dict
+ cpu_usage: float
+ memory_usage: float
+
+
+class DockerManager:
+ def __init__(self):
+ self.client = docker.from_env()
+
+ def list_containers(self, all: bool = False) -> list[ContainerInfo]:
+ containers = self.client.containers.list(all=all)
+ container_info = []
+ for c in containers:
+ stats = c.stats(stream=False)
+ cpu_usage = self._calculate_cpu_percent(stats)
+ memory_usage = self._calculate_memory_usage(stats)
+ container_info.append(ContainerInfo(
+ c.id, c.name, c.status,
+ c.image.tags[0] if c.image.tags else "None",
+ c.ports, cpu_usage, memory_usage
+ ))
+ return container_info
+
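+    # Worked example for the formula below (illustrative): cpu_delta = 2e8 ns
+    # and system_delta = 4e9 ns over one sampling window give
+    # (2e8 / 4e9) * 100 = 5.0% CPU.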
+ def _calculate_cpu_percent(self, stats):
+ cpu_delta = stats["cpu_stats"]["cpu_usage"]["total_usage"] - \
+ stats["precpu_stats"]["cpu_usage"]["total_usage"]
+ system_delta = stats["cpu_stats"]["system_cpu_usage"] - \
+ stats["precpu_stats"]["system_cpu_usage"]
+ if system_delta > 0.0:
+ return (cpu_delta / system_delta) * 100.0
+ return 0.0
+
+ def _calculate_memory_usage(self, stats):
+ return stats["memory_stats"]["usage"] / stats["memory_stats"]["limit"] * 100.0
+
+ def manage_container(self, action: Literal["start", "stop", "restart", "remove", "pause", "unpause"], container_id: str) -> str:
+ try:
+ container = self.client.containers.get(container_id)
+ getattr(container, action)()
+            return f"Container {container_id}: {action} completed"
+ except docker.errors.NotFound:
+ return f"Container {container_id} not found"
+
+ def create_container(self, image: str, name: str, ports: dict = None, volumes: list = None, environment: dict = None) -> Union[str, ContainerInfo]:
+ try:
+ container = self.client.containers.run(
+ image, name=name, detach=True,
+ ports=ports, volumes=volumes,
+ environment=environment
+ )
+ stats = container.stats(stream=False)
+ cpu_usage = self._calculate_cpu_percent(stats)
+ memory_usage = self._calculate_memory_usage(stats)
+ return ContainerInfo(
+ container.id, container.name, container.status,
+ container.image.tags[0] if container.image.tags else "None",
+ container.ports, cpu_usage, memory_usage
+ )
+ except docker.errors.ImageNotFound:
+ return f"Image {image} not found"
+
+ def pull_image(self, image: str) -> str:
+ try:
+ self.client.images.pull(image)
+ return f"Image {image} pulled successfully"
+ except docker.errors.ImageNotFound:
+ return f"Image {image} not found"
+
+ def list_images(self) -> list[str]:
+ return [image.tags[0] for image in self.client.images.list() if image.tags]
+
+ def get_container_logs(self, container_id: str, lines: int = 50) -> str:
+ try:
+ container = self.client.containers.get(container_id)
+ return container.logs(tail=lines).decode('utf-8')
+ except docker.errors.NotFound:
+ return f"Container {container_id} not found"
+
+ def exec_command(self, container_id: str, cmd: str) -> str:
+ try:
+ container = self.client.containers.get(container_id)
+ exit_code, output = container.exec_run(cmd)
+ return f"Exit Code: {exit_code}\nOutput:\n{output.decode('utf-8')}"
+ except docker.errors.NotFound:
+ return f"Container {container_id} not found"
+
+ def get_container_stats(self, container_id: str) -> Union[str, dict]:
+ try:
+ container = self.client.containers.get(container_id)
+ stats = container.stats(stream=False)
+ return {
+ "CPU Usage": f"{self._calculate_cpu_percent(stats):.2f}%",
+ "Memory Usage": f"{self._calculate_memory_usage(stats):.2f}%",
+ "Network I/O": f"In: {stats['networks']['eth0']['rx_bytes']/1024/1024:.2f}MB, Out: {stats['networks']['eth0']['tx_bytes']/1024/1024:.2f}MB",
+ "Block I/O": f"In: {stats['blkio_stats']['io_service_bytes_recursive'][0]['value']/1024/1024:.2f}MB, Out: {stats['blkio_stats']['io_service_bytes_recursive'][1]['value']/1024/1024:.2f}MB"
+ }
+ except docker.errors.NotFound:
+ return f"Container {container_id} not found"
+
+ def copy_to_container(self, container_id: str, src: str, dest: str) -> str:
+ try:
+ container = self.client.containers.get(container_id)
+            with tempfile.NamedTemporaryFile() as tmp:
+                with tarfile.open(tmp.name, "w:gz") as tar:
+                    tar.add(src, arcname=os.path.basename(src))
+                tmp.seek(0)  # Rewind before reading the freshly written archive
+                container.put_archive(os.path.dirname(dest), tmp.read())
+ return f"File {src} copied to container {container_id} at {dest}"
+ except docker.errors.NotFound:
+ return f"Container {container_id} not found"
+ except FileNotFoundError:
+ return f"Source file {src} not found"
+
+ def copy_from_container(self, container_id: str, src: str, dest: str) -> str:
+ try:
+ container = self.client.containers.get(container_id)
+ bits, stat = container.get_archive(src)
+ with tempfile.NamedTemporaryFile() as tmp:
+ for chunk in bits:
+ tmp.write(chunk)
+ tmp.seek(0)
+ with tarfile.open(fileobj=tmp) as tar:
+ tar.extractall(path=dest)
+ return f"File {src} copied from container {container_id} to {dest}"
+ except docker.errors.NotFound:
+ return f"Container {container_id} not found"
+ except KeyError:
+ return f"Source file {src} not found in container"
+
+ def export_container(self, container_id: str, output_path: str) -> str:
+ try:
+ container = self.client.containers.get(container_id)
+ with open(output_path, 'wb') as f:
+ for chunk in container.export():
+ f.write(chunk)
+ return f"Container {container_id} exported to {output_path}"
+ except docker.errors.NotFound:
+ return f"Container {container_id} not found"
+
+    def import_image(self, image_path: str, repository: str, tag: str) -> str:
+        try:
+            # The high-level ImageCollection has no import method; use the
+            # low-level API client instead
+            with open(image_path, 'rb') as f:
+                self.client.api.import_image(src=f, repository=repository, tag=tag)
+            return f"Image imported as {repository}:{tag}"
+        except FileNotFoundError:
+            return f"Image file {image_path} not found"
+
+ def build_image(self, dockerfile_path: str, tag: str) -> str:
+ try:
+ image, logs = self.client.images.build(path=os.path.dirname(
+ dockerfile_path), dockerfile=os.path.basename(dockerfile_path), tag=tag)
+ return f"Image built successfully with tag {tag}"
+ except docker.errors.BuildError as e:
+ return f"Error building image: {str(e)}"
+
+    def compose_up(self, compose_file: str) -> str:
+        # docker-py has no Compose API; shell out to the docker compose CLI
+        # (assumes the compose plugin is installed)
+        result = subprocess.run(
+            ["docker", "compose", "-f", compose_file, "up", "-d"],
+            capture_output=True, text=True)
+        if result.returncode != 0:
+            return f"Error starting Docker Compose services: {result.stderr.strip()}"
+        return f"Docker Compose services started from {compose_file}"
+
+    def compose_down(self, compose_file: str) -> str:
+        result = subprocess.run(
+            ["docker", "compose", "-f", compose_file, "down"],
+            capture_output=True, text=True)
+        if result.returncode != 0:
+            return f"Error stopping Docker Compose services: {result.stderr.strip()}"
+        return f"Docker Compose services stopped from {compose_file}"
+
+
+def print_table(data, headers):
+ table = PrettyTable()
+ table.field_names = headers
+ for row in data:
+ table.add_row(row)
+ print(table)
+
+
+def parse_key_value_pairs(s: str) -> dict:
+    if not s:
+        return {}
+    return dict(item.split("=") for item in s.split(","))
+
+
+def parse_port_mappings(s: str) -> dict:
+    # "8080:80,9000:9000" -> {"80/tcp": 8080, "9000/tcp": 9000}, the
+    # container-port -> host-port mapping that docker-py expects
+    if not s:
+        return {}
+    mappings = {}
+    for item in s.split(","):
+        host_port, container_port = item.split(":")
+        mappings[f"{container_port}/tcp"] = int(host_port)
+    return mappings
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="Comprehensive Docker Management CLI Tool")
+ subparsers = parser.add_subparsers(
+ dest="command", help="Available commands")
+
+ # List containers
+ list_parser = subparsers.add_parser("list", help="List containers")
+ list_parser.add_argument("--all", action="store_true",
+ help="Show all containers (default shows just running)")
+ list_parser.add_argument(
+ "--format", choices=["table", "json"], default="table", help="Output format")
+
+ # Manage container
+ manage_parser = subparsers.add_parser("manage", help="Manage container")
+ manage_parser.add_argument("action", choices=[
+ "start", "stop", "restart", "remove", "pause", "unpause"], help="Action to perform")
+ manage_parser.add_argument("container_id", help="Container ID")
+
+ # Create container
+ create_parser = subparsers.add_parser(
+ "create", help="Create a new container")
+ create_parser.add_argument("image", help="Image name")
+ create_parser.add_argument("name", help="Container name")
+ create_parser.add_argument(
+ "--ports", help="Port mappings (e.g., '8080:80,9000:9000')")
+ create_parser.add_argument(
+ "--volumes", help="Volume mappings (e.g., '/host/path:/container/path')")
+ create_parser.add_argument(
+ "--env", help="Environment variables (e.g., 'KEY1=VALUE1,KEY2=VALUE2')")
+
+ # Pull image
+ pull_parser = subparsers.add_parser("pull", help="Pull an image")
+ pull_parser.add_argument("image", help="Image name")
+
+ # List images
+ subparsers.add_parser("images", help="List images")
+
+ # View container logs
+ logs_parser = subparsers.add_parser("logs", help="View container logs")
+ logs_parser.add_argument("container_id", help="Container ID")
+ logs_parser.add_argument(
+ "--lines", type=int, default=50, help="Number of lines to display")
+
+ # Execute command in container
+ exec_parser = subparsers.add_parser(
+ "exec", help="Execute command in container")
+ exec_parser.add_argument("container_id", help="Container ID")
+ exec_parser.add_argument("command", help="Command to execute")
+
+ # View container statistics
+ stats_parser = subparsers.add_parser(
+ "stats", help="View container statistics")
+ stats_parser.add_argument("container_id", help="Container ID")
+
+ # Copy file to container
+ cp_to_parser = subparsers.add_parser(
+ "cp-to", help="Copy file to container")
+ cp_to_parser.add_argument("container_id", help="Container ID")
+ cp_to_parser.add_argument("src", help="Source file path")
+ cp_to_parser.add_argument("dest", help="Destination path in container")
+
+ # Copy file from container
+ cp_from_parser = subparsers.add_parser(
+ "cp-from", help="Copy file from container")
+ cp_from_parser.add_argument("container_id", help="Container ID")
+ cp_from_parser.add_argument("src", help="Source file path in container")
+ cp_from_parser.add_argument("dest", help="Destination path on host")
+
+ # Export container
+ export_parser = subparsers.add_parser(
+ "export", help="Export container to a tar archive")
+ export_parser.add_argument("container_id", help="Container ID")
+ export_parser.add_argument("output", help="Output file path")
+
+ # Import image
+ import_parser = subparsers.add_parser(
+ "import", help="Import an image from a tar archive")
+ import_parser.add_argument("image_path", help="Path to the tar archive")
+ import_parser.add_argument(
+ "repository", help="Repository name for the image")
+ import_parser.add_argument("tag", help="Tag for the image")
+
+ # Build image
+ build_parser = subparsers.add_parser(
+ "build", help="Build an image from a Dockerfile")
+ build_parser.add_argument("dockerfile", help="Path to the Dockerfile")
+ build_parser.add_argument("tag", help="Tag for the image")
+
+ # Docker Compose up
+ compose_up_parser = subparsers.add_parser(
+ "compose-up", help="Start services defined in a Docker Compose file")
+ compose_up_parser.add_argument(
+ "compose_file", help="Path to the Docker Compose file")
+
+ # Docker Compose down
+ compose_down_parser = subparsers.add_parser(
+ "compose-down", help="Stop services defined in a Docker Compose file")
+ compose_down_parser.add_argument(
+ "compose_file", help="Path to the Docker Compose file")
+
+ args = parser.parse_args()
+
+ manager = DockerManager()
+
+ try:
+ if args.command == "list":
+ containers = manager.list_containers(all=args.all)
+ if args.format == "json":
+ print(json.dumps([asdict(c) for c in containers], indent=2))
+ else:
+ data = [[c.id[:12], c.name, c.status, c.image,
+ f"{c.cpu_usage:.2f}%", f"{c.memory_usage:.2f}%"] for c in containers]
+ headers = ["ID", "Name", "Status",
+ "Image", "CPU Usage", "Memory Usage"]
+ print_table(data, headers)
+
+ elif args.command == "manage":
+ result = manager.manage_container(args.action, args.container_id)
+ print(result)
+
+ elif args.command == "create":
+            ports = parse_port_mappings(args.ports)
+ volumes = args.volumes.split(",") if args.volumes else None
+ env = parse_key_value_pairs(args.env)
+ result = manager.create_container(
+ args.image, args.name, ports, volumes, env)
+ if isinstance(result, ContainerInfo):
+ print(
+ f"Container created - ID: {result.id}, Name: {result.name}, Status: {result.status}")
+ else:
+ print(result)
+
+ elif args.command == "pull":
+ result = manager.pull_image(args.image)
+ print(result)
+
+ elif args.command == "images":
+ images = manager.list_images()
+ print("Available images:")
+ for image in images:
+ print(image)
+
+ elif args.command == "logs":
+ logs = manager.get_container_logs(args.container_id, args.lines)
+ print(f"Logs for container {args.container_id}:")
+ print(logs)
+
+ elif args.command == "exec":
+ result = manager.exec_command(args.container_id, args.command)
+ print(result)
+
+ elif args.command == "stats":
+ stats = manager.get_container_stats(args.container_id)
+ if isinstance(stats, dict):
+ for key, value in stats.items():
+ print(f"{key}: {value}")
+ else:
+ print(stats)
+
+ elif args.command == "cp-to":
+ result = manager.copy_to_container(
+ args.container_id, args.src, args.dest)
+ print(result)
+
+ elif args.command == "cp-from":
+ result = manager.copy_from_container(
+ args.container_id, args.src, args.dest)
+ print(result)
+
+ elif args.command == "export":
+ result = manager.export_container(args.container_id, args.output)
+ print(result)
+
+ elif args.command == "import":
+ result = manager.import_image(
+ args.image_path, args.repository, args.tag)
+ print(result)
+
+ elif args.command == "build":
+ result = manager.build_image(args.dockerfile, args.tag)
+ print(result)
+
+ elif args.command == "compose-up":
+ result = manager.compose_up(args.compose_file)
+ print(result)
+
+ elif args.command == "compose-down":
+ result = manager.compose_down(args.compose_file)
+ print(result)
+
+ except docker.errors.APIError as e:
+ print(f"Docker API Error: {str(e)}")
+ except Exception as e:
+ print(f"An unexpected error occurred: {str(e)}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/pip.sh b/scripts/pip.sh
old mode 100644
new mode 100755
index dd905c45..0a2159a8
--- a/scripts/pip.sh
+++ b/scripts/pip.sh
@@ -30,6 +30,9 @@ update_python_packages() {
echo "Updated packages: ${updated_packages[*]}"
echo "Failed packages: ${failed_packages[*]}"
+
+ # Return updated and failed packages as a string
+ echo "${updated_packages[*]}|${failed_packages[*]}"
}
# Function to generate update report
@@ -94,9 +97,9 @@ done
update_result=$(update_python_packages "${EXCLUDED_PACKAGES[@]}")
# Extract updated and failed packages from the result
-IFS=$'\n' read -rd '' -a lines <<< "$update_result"
-updated_packages=($(echo "${lines[0]}" | cut -d ':' -f2))
-failed_packages=($(echo "${lines[1]}" | cut -d ':' -f2))
+# The function prints progress lines first; the machine-readable summary is
+# the final "updated|failed" line, so keep only the last line before parsing
+summary_line="${update_result##*$'\n'}"
+IFS='|' read -r updated_str failed_str <<< "$summary_line"
+IFS=' ' read -r -a updated_packages <<< "$updated_str"
+IFS=' ' read -r -a failed_packages <<< "$failed_str"
echo "Update completed."
echo "Total packages updated: ${#updated_packages[@]}"
diff --git a/src/LithiumApp.cpp b/src/LithiumApp.cpp
index 4eeece38..f4d486df 100644
--- a/src/LithiumApp.cpp
+++ b/src/LithiumApp.cpp
@@ -284,8 +284,8 @@ void initLithiumApp(int argc, char **argv) {
// AddPtr("ScriptManager",
// ScriptManager::createShared(GetPtr("MessageBus")));
- AddPtr(constants::LITHIUM_DEVICE_MANAGER, DeviceManager::createShared());
- AddPtr(constants::LITHIUM_DEVICE_LOADER, ModuleLoader::createShared("./drivers"));
+ AddPtr(Constants::LITHIUM_DEVICE_MANAGER, DeviceManager::createShared());
+ AddPtr(Constants::LITHIUM_DEVICE_LOADER, ModuleLoader::createShared("./drivers"));
AddPtr("lithium.error.stack", std::make_shared());
diff --git a/src/addon/CMakeLists.txt b/src/addon/CMakeLists.txt
index 28d17d6f..5d3a28af 100644
--- a/src/addon/CMakeLists.txt
+++ b/src/addon/CMakeLists.txt
@@ -13,7 +13,9 @@ project(lithium-addons VERSION 1.0.0 LANGUAGES C CXX)
# Author: Max Qian
# License: GPL3
+# Seccomp is Linux-only, so skip the lookup on Windows toolchains
+if (NOT WIN32)
find_package(Seccomp REQUIRED)
+endif()
# Project sources
set(PROJECT_SOURCES
diff --git a/src/addon/analysts.cpp b/src/addon/analysts.cpp
index 2f3bcc9a..c35f1bf1 100644
--- a/src/addon/analysts.cpp
+++ b/src/addon/analysts.cpp
@@ -1,15 +1,165 @@
#include "analysts.hpp"
+#include <atomic>
#include <fstream>
+#include <mutex>
+#include <regex>
+#include <sstream>
#include <thread>
#include "atom/error/exception.hpp"
+#include "atom/log/loguru.hpp"
#include "atom/type/json.hpp"
namespace lithium {
-void CompilerOutputParser::parseLine(const std::string& line) {
+
+/**
+ * @class HtmlBuilder
+ * @brief A simple utility class for building HTML documents.
+ *
+ * HtmlBuilder provides an interface to construct an HTML document by appending
+ * elements like headers, paragraphs, lists, etc. The HTML is stored as a string
+ * and can be retrieved using the `str()` method.
+ */
+class HtmlBuilder {
+public:
+ HtmlBuilder();
+
+ // Add various HTML elements
+ void addTitle(const std::string& title);
+ void addHeader(const std::string& header, int level = 1);
+ void addParagraph(const std::string& text);
+    void addList(const std::vector<std::string>& items);
+
+ // Start and end unordered list
+ void startUnorderedList();
+ void endUnorderedList();
+
+ // Add list item to an unordered list
+ void addListItem(const std::string& item);
+
+ // Get the final HTML document as a string
+ auto str() const -> std::string;
+
+private:
+ std::ostringstream html_; ///< String stream to build the HTML document.
+ bool inList_{}; ///< Tracks if currently inside a list.
+};
+
+HtmlBuilder::HtmlBuilder() {
+    html_ << "<html>\n<body>\n";
+ LOG_F(INFO, "HtmlBuilder created with initial HTML structure.");
+}
+
+void HtmlBuilder::addTitle(const std::string& title) {
+    html_ << "<title>" << title << "</title>\n";
+ LOG_F(INFO, "Title added: {}", title);
+}
+
+void HtmlBuilder::addHeader(const std::string& header, int level) {
+    if (level < 1 || level > 6) {
+        level = 1;  // Default to <h1> if the level is invalid
+    }
+    html_ << "<h" << level << ">" << header << "</h" << level << ">\n";
+    LOG_F(INFO, "Header added: {} at level {}", header, level);
+}
+
+void HtmlBuilder::addParagraph(const std::string& text) {
+    html_ << "<p>" << text << "</p>\n";
+ LOG_F(INFO, "Paragraph added: {}", text);
+}
+
+void HtmlBuilder::addList(const std::vector<std::string>& items) {
+    html_ << "<ul>\n";
+    for (const auto& item : items) {
+        html_ << "<li>" << item << "</li>\n";
+        LOG_F(INFO, "List item added: {}", item);
+    }
+    html_ << "</ul>\n";
+}
+
+void HtmlBuilder::startUnorderedList() {
+ if (!inList_) {
+        html_ << "<ul>\n";
+ inList_ = true;
+ LOG_F(INFO, "Unordered list started.");
+ }
+}
+
+void HtmlBuilder::endUnorderedList() {
+ if (inList_) {
+        html_ << "</ul>\n";
+ inList_ = false;
+ LOG_F(INFO, "Unordered list ended.");
+ }
+}
+
+void HtmlBuilder::addListItem(const std::string& item) {
+ if (inList_) {
+        html_ << "<li>" << item << "</li>\n";
+ LOG_F(INFO, "List item added inside unordered list: {}", item);
+ }
+}
+
+std::string HtmlBuilder::str() const {
+ std::ostringstream finalHtml;
+    finalHtml << html_.str() << "</body>\n</html>\n";
+ LOG_F(INFO, "Final HTML document generated.");
+ return finalHtml.str();
+}
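+
+// Usage sketch (illustrative):
+//   HtmlBuilder builder;
+//   builder.addTitle("Report");
+//   builder.addHeader("Summary", 2);
+//   builder.startUnorderedList();
+//   builder.addListItem("Errors: 0");
+//   builder.endUnorderedList();
+//   std::ofstream("report.html") << builder.str();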
+
+// Message constructor
+Message::Message(MessageType t, std::string f, int l, int c, std::string code,
+ std::string func, std::string msg, std::string ctx)
+ : type(t),
+ file(std::move(f)),
+ line(l),
+ column(c),
+ errorCode(std::move(code)),
+ functionName(std::move(func)),
+ message(std::move(msg)),
+ context(std::move(ctx)) {}
+
+/**
+ * @class CompilerOutputParser::Impl
+ * @brief The private implementation for CompilerOutputParser (PIMPL idiom).
+ */
+class CompilerOutputParser::Impl {
+public:
+    Impl() {
+        initRegexPatterns();
+        // Pre-populate the counters so getReport() and generateJsonReport()
+        // can safely use at()
+        counts_[MessageType::ERROR] = 0;
+        counts_[MessageType::WARNING] = 0;
+        counts_[MessageType::NOTE] = 0;
+        counts_[MessageType::UNKNOWN] = 0;
+        LOG_F(INFO, "CompilerOutputParser::Impl initialized.");
+    }
+
+ void parseLine(const std::string& line);
+ void parseFile(const std::string& filename);
+ void parseFileMultiThreaded(const std::string& filename, int numThreads);
+ auto getReport(bool detailed) const -> std::string;
+ void generateHtmlReport(const std::string& outputFilename) const;
+ auto generateJsonReport() -> json;
+ void setCustomRegexPattern(const std::string& compiler,
+ const std::string& pattern);
+
+private:
+    std::vector<Message> messages_;
+    std::unordered_map<MessageType, std::atomic<int>> counts_;
+    mutable std::unordered_map<std::string, std::regex> regexPatterns_;
+ mutable std::mutex mutex_;
+ std::string currentContext_;
+ std::regex includePattern_{R"((.*):(\d+):(\d+):)"};
+ std::smatch includeMatch_;
+
+ void initRegexPatterns();
+ auto determineType(const std::string& typeStr) const -> MessageType;
+ auto toString(MessageType type) const -> std::string;
+};
+
+void CompilerOutputParser::Impl::parseLine(const std::string& line) {
+ LOG_F(INFO, "Parsing line: {}", line);
+
if (std::regex_search(line, includeMatch_, includePattern_)) {
currentContext_ = includeMatch_.str(1);
+ LOG_F(INFO, "Context updated: {}", currentContext_);
return;
}
@@ -25,6 +175,11 @@ void CompilerOutputParser::parseLine(const std::string& line) {
std::string message =
match.size() > 7 ? match.str(7) : match.str(5);
+ LOG_F(INFO,
+          "Parsed message - File: {}, Line: {}, Column: {}, ErrorCode: "
+ "{}, FunctionName: {}, Message: {}",
+ file, lineNum, column, errorCode, functionName, message);
+
std::lock_guard lock(mutex_);
messages_.emplace_back(type, file, lineNum, column, errorCode,
functionName, message, currentContext_);
@@ -37,11 +192,15 @@ void CompilerOutputParser::parseLine(const std::string& line) {
messages_.emplace_back(MessageType::UNKNOWN, "", 0, 0, "", "", line,
currentContext_);
counts_[MessageType::UNKNOWN]++;
+ LOG_F(WARNING, "Unknown message parsed: {}", line);
}
-void CompilerOutputParser::parseFile(const std::string& filename) {
+void CompilerOutputParser::Impl::parseFile(const std::string& filename) {
+ LOG_F(INFO, "Parsing file: {}", filename);
+
std::ifstream inputFile(filename);
if (!inputFile.is_open()) {
+ LOG_F(ERROR, "Failed to open file: {}", filename);
THROW_FAIL_TO_OPEN_FILE("Failed to open file: " + filename);
}
@@ -49,12 +208,18 @@ void CompilerOutputParser::parseFile(const std::string& filename) {
while (std::getline(inputFile, line)) {
parseLine(line);
}
+
+ LOG_F(INFO, "Completed parsing file: {}", filename);
}
-void CompilerOutputParser::parseFileMultiThreaded(const std::string& filename,
- int numThreads) {
+void CompilerOutputParser::Impl::parseFileMultiThreaded(
+ const std::string& filename, int numThreads) {
+    LOG_F(INFO, "Parsing file multithreaded: {} with {} threads", filename,
+ numThreads);
+
std::ifstream inputFile(filename);
if (!inputFile.is_open()) {
+ LOG_F(ERROR, "Failed to open file: {}", filename);
THROW_FAIL_TO_OPEN_FILE("Failed to open file: " + filename);
}
@@ -64,27 +229,36 @@ void CompilerOutputParser::parseFileMultiThreaded(const std::string& filename,
lines.push_back(line);
}
-    std::vector<std::thread> threads;
-    auto worker = [this](const std::vector<std::string>& lines, int start,
-                         int end) {
-        for (int i = start; i < end; ++i) {
-            parseLine(lines[i]);
+    std::vector<std::thread> threads;
+    auto worker = [this](std::span<const std::string> lines) {
+ for (const auto& line : lines) {
+ parseLine(line);
}
};
int blockSize = lines.size() / numThreads;
for (int i = 0; i < numThreads; ++i) {
- int start = i * blockSize;
- int end = (i == numThreads - 1) ? lines.size() : (i + 1) * blockSize;
- threads.emplace_back(worker, std::cref(lines), start, end);
+ auto start = lines.begin() + i * blockSize;
+ auto end = (i == numThreads - 1) ? lines.end() : start + blockSize;
+ threads.emplace_back(worker, std::span(start, end));
+        LOG_F(INFO, "Thread {} started processing lines [{} - {}]", i,
+ start - lines.begin(), end - lines.begin());
}
+ // Join the threads once processing is complete
for (auto& thread : threads) {
- thread.join();
+ if (thread.joinable()) {
+ thread.join();
+ }
}
+
+ LOG_F(INFO, "Multithreaded file parsing completed for file: {}", filename);
}
-auto CompilerOutputParser::getReport(bool detailed) const -> std::string {
+auto CompilerOutputParser::Impl::getReport(bool detailed) const -> std::string {
+ LOG_F(INFO, "Generating report with detailed: {}",
+ detailed ? "true" : "false");
+
std::ostringstream report;
report << "Compiler Messages Report:\n";
report << "Errors: " << counts_.at(MessageType::ERROR) << "\n";
@@ -116,60 +290,69 @@ auto CompilerOutputParser::getReport(bool detailed) const -> std::string {
}
}
+ LOG_F(INFO, "Report generation completed.");
return report.str();
}
-void CompilerOutputParser::generateHtmlReport(
+void CompilerOutputParser::Impl::generateHtmlReport(
const std::string& outputFilename) const {
- std::ofstream outputFile(outputFilename);
- if (!outputFile.is_open()) {
- THROW_FAIL_TO_OPEN_FILE("Failed to open output file: " +
- outputFilename);
- }
+ LOG_F(INFO, "Generating HTML report: {}", outputFilename);
-    outputFile << "<html>\n<body>\n";
-    outputFile << "<h1>Compiler Messages Report</h1>\n";
-    outputFile << "<ul>\n";
-    outputFile << "<li>Errors: " << counts_.at(MessageType::ERROR) << "</li>\n";
-    outputFile << "<li>Warnings: " << counts_.at(MessageType::WARNING)
-               << "</li>\n";
-    outputFile << "<li>Notes: " << counts_.at(MessageType::NOTE) << "</li>\n";
-    outputFile << "<li>Unknown: " << counts_.at(MessageType::UNKNOWN)
-               << "</li>\n";
-    outputFile << "</ul>\n";
+ HtmlBuilder builder;
+ builder.addTitle("Compiler Messages Report");
+ builder.addHeader("Compiler Messages Report", 1);
-    outputFile << "<h2>Details</h2>\n";
-    outputFile << "<ul>\n";
+ builder.addHeader("Summary", 2);
+ builder.startUnorderedList();
+ builder.addListItem("Errors: " +
+ std::to_string(counts_.at(MessageType::ERROR)));
+ builder.addListItem("Warnings: " +
+ std::to_string(counts_.at(MessageType::WARNING)));
+ builder.addListItem("Notes: " +
+ std::to_string(counts_.at(MessageType::NOTE)));
+ builder.addListItem("Unknown: " +
+ std::to_string(counts_.at(MessageType::UNKNOWN)));
+ builder.endUnorderedList();
+
+ builder.addHeader("Details", 2);
+ builder.startUnorderedList();
for (const auto& msg : messages_) {
-        outputFile << "<li>[" << toString(msg.type) << "] ";
+ std::string messageStr = "[" + toString(msg.type) + "] ";
if (!msg.file.empty()) {
- outputFile << msg.file << ":" << msg.line << ":" << msg.column
- << ": ";
+ messageStr += msg.file + ":" + std::to_string(msg.line) + ":" +
+ std::to_string(msg.column) + ": ";
}
if (!msg.errorCode.empty()) {
- outputFile << msg.errorCode << " ";
+ messageStr += msg.errorCode + " ";
}
if (!msg.functionName.empty()) {
- outputFile << msg.functionName << " ";
- }
-        outputFile << msg.message << "</li>\n";
-        if (!msg.context.empty()) {
-            outputFile << "<li>Context: " << msg.context << "</li>\n";
-        }
-        for (const auto& note : msg.relatedNotes) {
-            outputFile << "<li>Note: " << note << "</li>\n";
+ messageStr += msg.functionName + " ";
}
+ messageStr += msg.message;
+
+ builder.addListItem(messageStr);
}
-    outputFile << "</ul>\n";
-    outputFile << "</body>\n</html>\n";
+ builder.endUnorderedList();
+
+ std::ofstream outputFile(outputFilename);
+ if (!outputFile.is_open()) {
+ LOG_F(ERROR, "Failed to open output file: {}", outputFilename);
+ THROW_FAIL_TO_OPEN_FILE("Failed to open output file: " +
+ outputFilename);
+ }
+
+ outputFile << builder.str();
+ LOG_F(INFO, "HTML report generated and saved to: {}", outputFilename);
}
-auto CompilerOutputParser::generateJsonReport() -> json {
+auto CompilerOutputParser::Impl::generateJsonReport() -> json {
+ LOG_F(INFO, "Generating JSON report.");
+
json root;
- root["Errors"] = counts_[MessageType::ERROR];
- root["Warnings"] = counts_[MessageType::WARNING];
- root["Notes"] = counts_[MessageType::NOTE];
- root["Unknown"] = counts_[MessageType::UNKNOWN];
+ root["Errors"] = counts_.at(MessageType::ERROR).load();
+ root["Warnings"] = counts_.at(MessageType::WARNING).load();
+ root["Notes"] = counts_.at(MessageType::NOTE).load();
+ root["Unknown"] = counts_.at(MessageType::UNKNOWN).load();
for (const auto& msg : messages_) {
json entry;
@@ -187,16 +370,20 @@ auto CompilerOutputParser::generateJsonReport() -> json {
root["Details"].push_back(entry);
}
+ LOG_F(INFO, "JSON report generation completed.");
return root;
}
-void CompilerOutputParser::setCustomRegexPattern(const std::string& compiler,
- const std::string& pattern) {
+void CompilerOutputParser::Impl::setCustomRegexPattern(
+ const std::string& compiler, const std::string& pattern) {
+ LOG_F(INFO, "Setting custom regex pattern for compiler: {}", compiler);
std::lock_guard lock(mutex_);
regexPatterns_[compiler] = std::regex(pattern);
}
-void CompilerOutputParser::initRegexPatterns() {
+void CompilerOutputParser::Impl::initRegexPatterns() {
+ LOG_F(INFO, "Initializing regex patterns for supported compilers.");
+
regexPatterns_["gcc_clang"] =
std::regex(R"((.*):(\d+):(\d+): (error|warning|note): (.*))");
regexPatterns_["msvc"] =
@@ -205,7 +392,7 @@ void CompilerOutputParser::initRegexPatterns() {
std::regex(R"((.*)\((\d+)\): (error|remark|warning|note): (.*))");
}
-auto CompilerOutputParser::determineType(const std::string& typeStr) const
+auto CompilerOutputParser::Impl::determineType(const std::string& typeStr) const
-> MessageType {
if (typeStr == "error") {
return MessageType::ERROR;
@@ -219,7 +406,8 @@ auto CompilerOutputParser::determineType(const std::string& typeStr) const
return MessageType::UNKNOWN;
}
-auto CompilerOutputParser::toString(MessageType type) const -> std::string {
+auto CompilerOutputParser::Impl::toString(MessageType type) const
+ -> std::string {
switch (type) {
case MessageType::ERROR:
return "Error";
@@ -232,4 +420,55 @@ auto CompilerOutputParser::toString(MessageType type) const -> std::string {
}
}
+CompilerOutputParser::CompilerOutputParser() : pImpl(std::make_unique<Impl>()) {
+ LOG_F(INFO, "CompilerOutputParser created.");
+}
+
+CompilerOutputParser::~CompilerOutputParser() = default;
+
+CompilerOutputParser::CompilerOutputParser(CompilerOutputParser&&) noexcept =
+ default;
+CompilerOutputParser& CompilerOutputParser::operator=(
+ CompilerOutputParser&&) noexcept = default;
+
+void CompilerOutputParser::parseLine(const std::string& line) {
+ LOG_F(INFO, "Parsing single line.");
+ pImpl->parseLine(line);
+}
+
+void CompilerOutputParser::parseFile(const std::string& filename) {
+ LOG_F(INFO, "Parsing file: {}", filename);
+ pImpl->parseFile(filename);
+}
+
+void CompilerOutputParser::parseFileMultiThreaded(const std::string& filename,
+ int numThreads) {
+ LOG_F(INFO, "Parsing file {} with multithreading (%d threads)", filename,
+ numThreads);
+ pImpl->parseFileMultiThreaded(filename, numThreads);
+}
+
+auto CompilerOutputParser::getReport(bool detailed) const -> std::string {
+ LOG_F(INFO, "Requesting report with detailed option: {}",
+ detailed ? "true" : "false");
+ return pImpl->getReport(detailed);
+}
+
+void CompilerOutputParser::generateHtmlReport(
+ const std::string& outputFilename) const {
+ LOG_F(INFO, "Generating HTML report to file: {}", outputFilename);
+ pImpl->generateHtmlReport(outputFilename);
+}
+
+auto CompilerOutputParser::generateJsonReport() -> json {
+ LOG_F(INFO, "Generating JSON report.");
+ return pImpl->generateJsonReport();
+}
+
+void CompilerOutputParser::setCustomRegexPattern(const std::string& compiler,
+ const std::string& pattern) {
+ LOG_F(INFO, "Setting custom regex pattern for compiler: {}", compiler);
+ pImpl->setCustomRegexPattern(compiler, pattern);
+}
+
} // namespace lithium
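
A minimal usage sketch for the refactored parser, assuming only the public API visible in this diff (`parseLine`, `getReport`, `generateHtmlReport`); the include path and the sample input lines are hypothetical, chosen to match the `gcc_clang` pattern registered by `initRegexPatterns()`:

```cpp
#include <iostream>

#include "addon/analysts.hpp"  // header path assumed from this diff

int main() {
    lithium::CompilerOutputParser parser;

    // Lines of the form "file:line:col: severity: text" match the
    // gcc_clang regex shown in the diff above.
    parser.parseLine("main.cpp:42:13: error: expected ';' after expression");
    parser.parseLine("util.cpp:7:1: warning: unused variable 'x'");

    std::cout << parser.getReport(/*detailed=*/true) << '\n';
    parser.generateHtmlReport("report.html");
    return 0;
}
```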
diff --git a/src/addon/analysts.hpp b/src/addon/analysts.hpp
index 3cfb0304..94a23e1f 100644
--- a/src/addon/analysts.hpp
+++ b/src/addon/analysts.hpp
@@ -1,20 +1,27 @@
#ifndef LITHIUM_ADDON_COMPILER_ANALYSIS_HPP
#define LITHIUM_ADDON_COMPILER_ANALYSIS_HPP
-#include <mutex>
-#include <regex>
+#include <memory>
#include <string>
-#include <unordered_map>
#include <vector>
#include "atom/type/json_fwd.hpp"
+using json = nlohmann::json;
-namespace lithium {
+#include "macro.hpp"
-using json = nlohmann::json;
+namespace lithium {
+/**
+ * @enum MessageType
+ * @brief Represents the type of a compiler message.
+ */
enum class MessageType { ERROR, WARNING, NOTE, UNKNOWN };
+/**
+ * @struct Message
+ * @brief Holds information about a single compiler message.
+ */
struct Message {
MessageType type;
std::string file;
@@ -28,11 +35,27 @@ struct Message {
Message(MessageType t, std::string f, int l, int c, std::string code,
std::string func, std::string msg, std::string ctx);
-};
+} ATOM_ALIGNAS(128);
+/**
+ * @class CompilerOutputParser
+ * @brief Parses compiler output and generates reports.
+ *
+ * Uses regular expressions to parse compiler messages from various compilers,
+ * supports both single-threaded and multi-threaded parsing.
+ */
class CompilerOutputParser {
public:
CompilerOutputParser();
+ ~CompilerOutputParser();
+
+ // Delete copy constructor and copy assignment to avoid copying state
+ CompilerOutputParser(const CompilerOutputParser&) = delete;
+ CompilerOutputParser& operator=(const CompilerOutputParser&) = delete;
+
+ // Enable move semantics
+ CompilerOutputParser(CompilerOutputParser&&) noexcept;
+ CompilerOutputParser& operator=(CompilerOutputParser&&) noexcept;
void parseLine(const std::string& line);
void parseFile(const std::string& filename);
@@ -44,19 +67,10 @@ class CompilerOutputParser {
const std::string& pattern);
private:
-    std::vector<Message> messages_;
-    std::unordered_map<MessageType, int> counts_;
-    mutable std::unordered_map<std::string, std::regex> regexPatterns_;
- mutable std::mutex mutex_;
- std::string currentContext_;
- std::regex includePattern_;
- std::smatch includeMatch_;
-
- void initRegexPatterns();
- MessageType determineType(const std::string& typeStr) const;
- std::string toString(MessageType type) const;
+ class Impl; // Forward declaration of implementation class
+    std::unique_ptr<Impl> pImpl;  // PIMPL idiom, pointer to implementation
};
} // namespace lithium
-#endif
+#endif // LITHIUM_ADDON_COMPILER_ANALYSIS_HPP
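
The header above pairs deleted copy operations with move operations that are declared in the header but defaulted in the `.cpp`. A stripped-down illustration of that PIMPL pattern (a generic `Widget`, not project code) showing why the special members must be defined where `Impl` is complete:

```cpp
// widget.hpp -- declarations only; Impl stays incomplete here
#include <memory>

class Widget {
public:
    Widget();
    ~Widget();                             // defaulted in the .cpp, where Impl is complete
    Widget(const Widget&) = delete;        // unique_ptr makes the class move-only
    Widget& operator=(const Widget&) = delete;
    Widget(Widget&&) noexcept;             // declared here, defaulted in the .cpp
    Widget& operator=(Widget&&) noexcept;

private:
    struct Impl;
    std::unique_ptr<Impl> pImpl;
};

// widget.cpp -- Impl is complete, so the defaulted members can be generated
struct Widget::Impl {
    int state = 0;
};

Widget::Widget() : pImpl(std::make_unique<Impl>()) {}
Widget::~Widget() = default;
Widget::Widget(Widget&&) noexcept = default;
Widget& Widget::operator=(Widget&&) noexcept = default;
```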
diff --git a/src/addon/builder.cpp b/src/addon/builder.cpp
index 47568a37..22d3981f 100644
--- a/src/addon/builder.cpp
+++ b/src/addon/builder.cpp
@@ -6,6 +6,7 @@
#include "atom/error/exception.hpp"
namespace lithium {
+
BuildManager::BuildManager(BuildSystemType type) {
switch (type) {
case BuildSystemType::CMake:
@@ -20,35 +21,69 @@ BuildManager::BuildManager(BuildSystemType type) {
}
auto BuildManager::configureProject(
- const std::string &sourceDir, const std::string &buildDir,
- const std::string &buildType,
-    const std::string &buildType,
-    const std::vector<std::string> &options) -> bool {
+ const std::filesystem::path& sourceDir,
+ const std::filesystem::path& buildDir, BuildType buildType,
+    const std::vector<std::string>& options) -> BuildResult {
return builder_->configureProject(sourceDir, buildDir, buildType, options);
}
-auto BuildManager::buildProject(const std::string &buildDir, int jobs) -> bool {
+auto BuildManager::buildProject(const std::filesystem::path& buildDir,
+                                std::optional<int> jobs) -> BuildResult {
return builder_->buildProject(buildDir, jobs);
}
-auto BuildManager::cleanProject(const std::string &buildDir) -> bool {
+auto BuildManager::cleanProject(const std::filesystem::path& buildDir)
+ -> BuildResult {
return builder_->cleanProject(buildDir);
}
-auto BuildManager::installProject(const std::string &buildDir,
- const std::string &installDir) -> bool {
+auto BuildManager::installProject(const std::filesystem::path& buildDir,
+ const std::filesystem::path& installDir)
+ -> BuildResult {
return builder_->installProject(buildDir, installDir);
}
-auto BuildManager::runTests(const std::string &buildDir) -> bool {
- return builder_->runTests(buildDir);
+auto BuildManager::runTests(const std::filesystem::path& buildDir,
+                            const std::vector<std::string>& testNames)
+ -> BuildResult {
+ return builder_->runTests(buildDir, testNames);
}
-auto BuildManager::generateDocs(const std::string &buildDir) -> bool {
- return builder_->generateDocs(buildDir);
+auto BuildManager::generateDocs(const std::filesystem::path& buildDir,
+ const std::filesystem::path& outputDir)
+ -> BuildResult {
+ return builder_->generateDocs(buildDir, outputDir);
}
-auto BuildManager::loadConfig(const std::string &configPath) -> bool {
+auto BuildManager::loadConfig(const std::filesystem::path& configPath) -> bool {
return builder_->loadConfig(configPath);
}
+auto BuildManager::setLogCallback(
+    std::function<void(const std::string&)> callback) -> void {
+ builder_->setLogCallback(std::move(callback));
+}
+
+auto BuildManager::getAvailableTargets(const std::filesystem::path& buildDir)
+    -> std::vector<std::string> {
+ return builder_->getAvailableTargets(buildDir);
+}
+
+auto BuildManager::buildTarget(const std::filesystem::path& buildDir,
+ const std::string& target,
+                               std::optional<int> jobs) -> BuildResult {
+ return builder_->buildTarget(buildDir, target, jobs);
+}
+
+auto BuildManager::getCacheVariables(const std::filesystem::path& buildDir)
+    -> std::vector<std::pair<std::string, std::string>> {
+ return builder_->getCacheVariables(buildDir);
+}
+
+auto BuildManager::setCacheVariable(const std::filesystem::path& buildDir,
+ const std::string& name,
+ const std::string& value) -> bool {
+ return builder_->setCacheVariable(buildDir, name, value);
+}
+
} // namespace lithium
diff --git a/src/addon/builder.hpp b/src/addon/builder.hpp
index 844162a0..bf74beb1 100644
--- a/src/addon/builder.hpp
+++ b/src/addon/builder.hpp
@@ -1,31 +1,64 @@
#ifndef LITHIUM_ADDON_BUILDER_HPP
#define LITHIUM_ADDON_BUILDER_HPP
+#include <filesystem>
+#include <functional>
#include <memory>
+#include <optional>
#include "platform/base.hpp"
namespace lithium {
+
class BuildManager {
public:
enum class BuildSystemType { CMake, Meson };
BuildManager(BuildSystemType type);
- auto configureProject(const std::string &sourceDir,
- const std::string &buildDir,
- const std::string &buildType,
-    const std::vector<std::string> &options) -> bool;
- auto buildProject(const std::string &buildDir, int jobs) -> bool;
- auto cleanProject(const std::string &buildDir) -> bool;
- auto installProject(const std::string &buildDir,
- const std::string &installDir) -> bool;
- auto runTests(const std::string &buildDir) -> bool;
- auto generateDocs(const std::string &buildDir) -> bool;
- auto loadConfig(const std::string &configPath) -> bool;
+
+ auto configureProject(
+ const std::filesystem::path& sourceDir,
+ const std::filesystem::path& buildDir, BuildType buildType,
+        const std::vector<std::string>& options) -> BuildResult;
+
+ auto buildProject(const std::filesystem::path& buildDir,
+                      std::optional<int> jobs = std::nullopt) -> BuildResult;
+
+ auto cleanProject(const std::filesystem::path& buildDir) -> BuildResult;
+
+ auto installProject(const std::filesystem::path& buildDir,
+ const std::filesystem::path& installDir) -> BuildResult;
+
+ auto runTests(const std::filesystem::path& buildDir,
+                  const std::vector<std::string>& testNames = {})
+ -> BuildResult;
+
+ auto generateDocs(const std::filesystem::path& buildDir,
+ const std::filesystem::path& outputDir) -> BuildResult;
+
+ auto loadConfig(const std::filesystem::path& configPath) -> bool;
+
+    auto setLogCallback(std::function<void(const std::string&)> callback)
+ -> void;
+
+ auto getAvailableTargets(const std::filesystem::path& buildDir)
+        -> std::vector<std::string>;
+
+ auto buildTarget(const std::filesystem::path& buildDir,
+ const std::string& target,
+                     std::optional<int> jobs = std::nullopt) -> BuildResult;
+
+ auto getCacheVariables(const std::filesystem::path& buildDir)
+        -> std::vector<std::pair<std::string, std::string>>;
+
+ auto setCacheVariable(const std::filesystem::path& buildDir,
+ const std::string& name,
+ const std::string& value) -> bool;
private:
std::unique_ptr builder_;
};
+
} // namespace lithium
-#endif
+#endif // LITHIUM_ADDON_BUILDER_HPP
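
A usage sketch of the new path-based interface. `BuildType` and `BuildResult` come from `platform/base.hpp`, so the `lithium::BuildType::Debug` enumerator is an assumption here, as is the log-callback signature reconstructed above; treat this as illustrative only:

```cpp
#include <iostream>
#include <string>

#include "addon/builder.hpp"

int main() {
    using lithium::BuildManager;

    BuildManager manager(BuildManager::BuildSystemType::CMake);
    manager.setLogCallback(
        [](const std::string& line) { std::cout << "[build] " << line << '\n'; });

    // BuildType::Debug is an assumed enumerator from platform/base.hpp.
    auto configured = manager.configureProject("demo/src", "demo/build",
                                               lithium::BuildType::Debug,
                                               {"-DENABLE_TESTS=ON"});
    auto built = manager.buildProject("demo/build", /*jobs=*/4);
    auto tested = manager.runTests("demo/build", {"unit_core"});
    (void)configured;
    (void)built;
    (void)tested;
    return 0;
}
```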
diff --git a/src/addon/command.cpp b/src/addon/command.cpp
index 6889d4a9..ac79d6c0 100644
--- a/src/addon/command.cpp
+++ b/src/addon/command.cpp
@@ -1,14 +1,19 @@
#include "command.hpp"
+#include
#include
#include
-#include
+#include
#include
-#include
-#include
+#include
+#include "atom/log/loguru.hpp"
+#include "atom/type/json.hpp"
using json = nlohmann::json;
+#include "macro.hpp"
+
+namespace lithium {
struct CompileCommand {
std::string directory;
std::string command;
@@ -18,7 +23,13 @@ struct CompileCommand {
return json{
{"directory", directory}, {"command", command}, {"file", file}};
}
-};
+
+ void fromJson(const json& j) {
+ directory = j["directory"].get();
+ command = j["command"].get();
+ file = j["file"].get();
+ }
+} ATOM_ALIGNAS(128);
struct CompileCommandGenerator::Impl {
std::string sourceDir = "./src";
@@ -31,8 +42,12 @@ struct CompileCommandGenerator::Impl {
std::string projectName = "MyProject";
std::string projectVersion = "1.0.0";
std::mutex output_mutex;
+    std::atomic<int> commandCounter{0};
-    std::vector<std::string> getSourceFiles() {
+    // Get the list of source files
+    auto getSourceFiles() -> std::vector<std::string> {
+ LOG_F(INFO, "Scanning source directory: {}",
+ sourceDir); // Log scanning process
std::vector<std::string> sourceFiles;
for (const auto& entry :
std::filesystem::directory_iterator(sourceDir)) {
@@ -41,30 +56,42 @@ struct CompileCommandGenerator::Impl {
for (const auto& ext : extensions) {
if (path.extension() == ext) {
sourceFiles.push_back(path.string());
+ LOG_F(INFO, "Found source file: {}",
+ path.string()); // Log found file
}
}
}
}
+ LOG_F(INFO, "Total source files found: {}", sourceFiles.size());
return sourceFiles;
}
-    std::vector<CompileCommand> parseExistingCommands() {
+    [[nodiscard]] auto parseExistingCommands() const
+        -> std::vector<CompileCommand> {
        std::vector<CompileCommand> commands;
- std::ifstream ifs(existingCommandsPath);
+ if (existingCommandsPath.empty() ||
+ !std::filesystem::exists(existingCommandsPath)) {
+ LOG_F(WARNING, "No existing compile commands found at {}",
+ existingCommandsPath);
+ return commands;
+ }
+ LOG_F(INFO, "Parsing existing compile commands from {}",
+ existingCommandsPath);
+ std::ifstream ifs(existingCommandsPath);
if (ifs.is_open()) {
json j;
ifs >> j;
for (const auto& cmd : j["commands"]) {
- commands.emplace_back(cmd["directory"].get(),
- cmd["command"].get(),
- cmd["file"].get());
+ CompileCommand c;
+ c.fromJson(cmd);
+ commands.push_back(c);
}
ifs.close();
+ LOG_F(INFO, "Parsed {} existing compile commands", commands.size());
} else {
- std::cerr << "Could not open " << existingCommandsPath << std::endl;
+ LOG_F(ERROR, "Failed to open {}", existingCommandsPath);
}
-
return commands;
}
@@ -73,94 +100,114 @@ struct CompileCommandGenerator::Impl {
compiler + " " + includeFlag + " " + outputFlag + " " + file;
CompileCommand cmd{sourceDir, command, file};
+ LOG_F(INFO, "Generating compile command for file: {}", file);
std::lock_guard lock(output_mutex);
j_commands.push_back(cmd.toJson());
+ int currentCount =
+ commandCounter.fetch_add(1, std::memory_order_relaxed) + 1;
+ LOG_F(INFO, "Total commands generated so far: {}", currentCount);
}
void saveCommandsToFile(const json& j) {
+ LOG_F(INFO, "Saving compile commands to file: {}", outputPath);
std::ofstream ofs(outputPath);
if (ofs.is_open()) {
ofs << j.dump(4);
ofs.close();
- std::cout << "compile_commands.json generated successfully at "
- << outputPath << "." << std::endl;
+ LOG_F(INFO,
+ "compile_commands.json generated successfully with {} "
+ "commands at {}.",
+ commandCounter.load(std::memory_order_relaxed),
+ outputPath); // Log success
} else {
- std::cerr << "Failed to create compile_commands.json." << std::endl;
+ LOG_F(ERROR, "Failed to open {} for writing.", outputPath);
}
}
-};
+} ATOM_ALIGNAS(128);
-CompileCommandGenerator::CompileCommandGenerator() : pImpl(new Impl) {}
+CompileCommandGenerator::CompileCommandGenerator()
+    : impl_(std::make_unique<Impl>()) {}
-CompileCommandGenerator::~CompileCommandGenerator() { delete pImpl; }
+CompileCommandGenerator::~CompileCommandGenerator() = default;
void CompileCommandGenerator::setSourceDir(const std::string& dir) {
- pImpl->sourceDir = dir;
+ LOG_F(INFO, "Setting source directory to {}",
+ dir); // Log set source directory
+ impl_->sourceDir = dir;
}
void CompileCommandGenerator::setCompiler(const std::string& compiler) {
- pImpl->compiler = compiler;
+ LOG_F(INFO, "Setting compiler to {}", compiler);
+ impl_->compiler = compiler;
}
void CompileCommandGenerator::setIncludeFlag(const std::string& flag) {
- pImpl->includeFlag = flag;
+ LOG_F(INFO, "Setting include flag to {}", flag);
+ impl_->includeFlag = flag;
}
void CompileCommandGenerator::setOutputFlag(const std::string& flag) {
- pImpl->outputFlag = flag;
+ LOG_F(INFO, "Setting output flag to {}", flag);
+ impl_->outputFlag = flag;
}
void CompileCommandGenerator::setProjectName(const std::string& name) {
- pImpl->projectName = name;
+ LOG_F(INFO, "Setting project name to {}", name);
+ impl_->projectName = name;
}
void CompileCommandGenerator::setProjectVersion(const std::string& version) {
- pImpl->projectVersion = version;
+ LOG_F(INFO, "Setting project version to {}",
+ version); // Log set project version
+ impl_->projectVersion = version;
}
void CompileCommandGenerator::addExtension(const std::string& ext) {
- pImpl->extensions.push_back(ext);
+ LOG_F(INFO, "Adding file extension: {}", ext);
+ impl_->extensions.push_back(ext);
}
void CompileCommandGenerator::setOutputPath(const std::string& path) {
- pImpl->outputPath = path;
+ LOG_F(INFO, "Setting output path to {}", path);
+ impl_->outputPath = path;
}
void CompileCommandGenerator::setExistingCommandsPath(const std::string& path) {
- pImpl->existingCommandsPath = path;
+ LOG_F(INFO, "Setting existing commands path to {}",
+ path); // Log set existing commands path
+ impl_->existingCommandsPath = path;
}
void CompileCommandGenerator::generate() {
+ LOG_F(INFO, "Starting compile command generation");
std::vector<CompileCommand> commands;
// Parse the existing compile_commands.json
- if (!pImpl->existingCommandsPath.empty()) {
- auto existing_commands = pImpl->parseExistingCommands();
- commands.insert(commands.end(), existing_commands.begin(),
- existing_commands.end());
+ if (!impl_->existingCommandsPath.empty()) {
+ auto existingCommands = impl_->parseExistingCommands();
+ commands.insert(commands.end(), existingCommands.begin(),
+ existingCommands.end());
}
- auto source_files = pImpl->getSourceFiles();
- json j_commands = json::array();
+ auto sourceFiles = impl_->getSourceFiles();
+ json jCommands = json::array();
- // 多线程处理构建命令
-    std::vector<std::thread> threads;
- for (const auto& file : source_files) {
- threads.emplace_back(&Impl::generateCompileCommand, pImpl, file,
- std::ref(j_commands));
- }
-
- // 等待所有线程完成
- for (auto& thread : threads) {
- thread.join();
- }
+ LOG_F(INFO, "Generating compile commands for {} source files",
+ sourceFiles.size()); // Log number of files
+ std::ranges::for_each(sourceFiles.begin(), sourceFiles.end(),
+ [&](const std::string& file) {
+ impl_->generateCompileCommand(file, jCommands);
+ });
// Build the final JSON
json j = {{"version", 4},
- {"project_name", pImpl->projectName},
- {"project_version", pImpl->projectVersion},
- {"commands", j_commands}};
+ {"project_name", impl_->projectName},
+ {"project_version", impl_->projectVersion},
+ {"commands", jCommands}};
// Save to file
- pImpl->saveCommandsToFile(j);
+ impl_->saveCommandsToFile(j);
+ LOG_F(INFO, "Compile command generation complete");
}
+
+} // namespace lithium
diff --git a/src/addon/command.hpp b/src/addon/command.hpp
index 2e9deacd..54ec1598 100644
--- a/src/addon/command.hpp
+++ b/src/addon/command.hpp
@@ -1,14 +1,19 @@
#ifndef COMPILE_COMMAND_GENERATOR_H
#define COMPILE_COMMAND_GENERATOR_H
+#include <memory>
#include <string>
-#include
+#include "atom/type/json.hpp"
+using json = nlohmann::json;
+
+namespace lithium {
class CompileCommandGenerator {
public:
CompileCommandGenerator();
~CompileCommandGenerator();
+
void setSourceDir(const std::string& dir);
void setCompiler(const std::string& compiler);
void setIncludeFlag(const std::string& flag);
@@ -22,8 +27,9 @@ class CompileCommandGenerator {
void generate();
private:
- struct Impl; // Forward declaration of the implementation class
- Impl* pImpl; // Pointer to the implementation
+ struct Impl;
+    std::unique_ptr<Impl> impl_;
};
+} // namespace lithium
#endif // COMPILE_COMMAND_GENERATOR_H
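
A short end-to-end sketch of the generator, using only the setters declared above; the directory layout and compiler name are hypothetical:

```cpp
#include "addon/command.hpp"  // header path assumed from this diff

int main() {
    lithium::CompileCommandGenerator generator;
    generator.setSourceDir("./src");
    generator.setCompiler("g++");
    generator.setIncludeFlag("-I./include");
    generator.setOutputFlag("-o output");
    generator.addExtension(".cpp");
    generator.addExtension(".c");
    generator.setOutputPath("compile_commands.json");
    // Merge any commands that already exist before regenerating.
    generator.setExistingCommandsPath("compile_commands.json");
    generator.generate();
    return 0;
}
```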
diff --git a/src/addon/compiler.cpp b/src/addon/compiler.cpp
index 39cb68e2..40b2f147 100644
--- a/src/addon/compiler.cpp
+++ b/src/addon/compiler.cpp
@@ -1,11 +1,11 @@
#include "compiler.hpp"
+#include "command.hpp"
#include "io/io.hpp"
#include "toolchain.hpp"
#include "utils/constant.hpp"
#include
-#include
#include
#include
#include "atom/log/loguru.hpp"
@@ -30,11 +30,14 @@ class CompilerImpl {
auto getAvailableCompilers() const -> std::vector<std::string>;
+ void generateCompileCommands(const std::string &sourceDir);
+
private:
void createOutputDirectory(const fs::path &outputDir);
- auto syntaxCheck(std::string_view code, std::string_view compiler) -> bool;
- auto compileCode(std::string_view code, std::string_view compiler,
- std::string_view compileOptions,
+ auto syntaxCheck(std::string_view code,
+ const std::string &compiler) -> bool;
+ auto compileCode(std::string_view code, const std::string &compiler,
+ const std::string &compileOptions,
const fs::path &output) -> bool;
auto findAvailableCompilers() -> std::vector<std::string>;
@@ -42,6 +45,7 @@ class CompilerImpl {
std::unordered_map<std::string, fs::path> cache_;
std::string customCompileOptions_;
ToolchainManager toolchainManager_;
+    std::unique_ptr<CompileCommandGenerator> compileCommandGenerator_;
};
Compiler::Compiler() : impl_(std::make_unique()) {}
@@ -64,19 +68,42 @@ auto Compiler::getAvailableCompilers() const -> std::vector<std::string> {
return impl_->getAvailableCompilers();
}
-CompilerImpl::CompilerImpl() { toolchainManager_.scanForToolchains(); }
+CompilerImpl::CompilerImpl()
+    : compileCommandGenerator_(std::make_unique<CompileCommandGenerator>()) {
+ LOG_F(INFO, "Initializing CompilerImpl...");
+ toolchainManager_.scanForToolchains();
+ LOG_F(INFO, "Toolchains scanned.");
+ compileCommandGenerator_->setCompiler(Constants::COMPILER);
+ compileCommandGenerator_->setIncludeFlag("-I./include");
+ compileCommandGenerator_->setOutputFlag("-o output");
+ compileCommandGenerator_->addExtension(".cpp");
+ compileCommandGenerator_->addExtension(".c");
+ LOG_F(INFO, "CompileCommandGenerator initialized with default settings.");
+}
+
+void CompilerImpl::generateCompileCommands(const std::string &sourceDir) {
+ LOG_F(INFO, "Generating compile commands for source directory: {}",
+ sourceDir);
+ compileCommandGenerator_->setSourceDir(sourceDir);
+ compileCommandGenerator_->setOutputPath("compile_commands.json");
+ compileCommandGenerator_->generate();
+ LOG_F(INFO, "Compile commands generation complete for directory: {}",
+ sourceDir);
+}
auto CompilerImpl::compileToSharedLibrary(
std::string_view code, std::string_view moduleName,
std::string_view functionName, std::string_view optionsFile) -> bool {
- LOG_F(INFO, "Compiling module {}::{}...", moduleName, functionName);
+ LOG_F(INFO, "Compiling module {}::{} with options file: {}", moduleName,
+ functionName, optionsFile);
if (code.empty() || moduleName.empty() || functionName.empty()) {
- LOG_F(ERROR, "Invalid parameters.");
+ LOG_F(
+ ERROR,
+ "Invalid parameters: code, moduleName, or functionName is empty.");
return false;
}
- // 检查模块是否已编译并缓存
std::string cacheKey = std::format("{}::{}", moduleName, functionName);
if (cache_.find(cacheKey) != cache_.end()) {
LOG_F(WARNING, "Module {} is already compiled, using cached result.",
@@ -84,11 +111,9 @@ auto CompilerImpl::compileToSharedLibrary(
return true;
}
- // 创建输出目录
const fs::path OUTPUT_DIR = "atom/global";
createOutputDirectory(OUTPUT_DIR);
- // Max: 检查可用的编译器,然后再和指定的进行比对
auto availableCompilers = findAvailableCompilers();
if (availableCompilers.empty()) {
LOG_F(ERROR, "No available compilers found.");
@@ -97,63 +122,101 @@ auto CompilerImpl::compileToSharedLibrary(
LOG_F(INFO, "Available compilers: {}",
atom::utils::toString(availableCompilers));
- // 读取编译选项
+ // Read compile options
std::ifstream optionsStream(optionsFile.data());
- std::string compileOptions = [&] -> std::string {
- if (!optionsStream) {
- LOG_F(
- WARNING,
- "Failed to open compile options file, using default options.");
- return "-O2 -std=c++20 -Wall -shared -fPIC";
- }
-
+ json optionsJson;
+ if (!optionsStream) {
+ LOG_F(WARNING,
+ "Failed to open compile options file {}, using default options.",
+ optionsFile);
+ optionsJson = {{"compiler", Constants::COMPILER},
+ {"optimization_level", "-O2"},
+ {"cplus_version", "-std=c++20"},
+ {"warnings", "-Wall"}};
+ } else {
try {
- json optionsJson;
optionsStream >> optionsJson;
-
- auto compiler = optionsJson.value("compiler", constants::COMPILER);
- if (std::find(availableCompilers.begin(), availableCompilers.end(),
- compiler) == availableCompilers.end()) {
- LOG_F(WARNING, "Compiler {} is not available, using default.",
- compiler);
- compiler = constants::COMPILER;
- }
-
- auto cmd = std::format(
- "{} {} {} {} {}", compiler,
- optionsJson.value("optimization_level", "-O2"),
- optionsJson.value("cplus_version", "-std=c++20"),
- optionsJson.value("warnings", "-Wall"), customCompileOptions_);
-
- LOG_F(INFO, "Compile options: {}", cmd);
- return cmd;
-
+ LOG_F(INFO, "Compile options file {} successfully parsed.",
+ optionsFile);
} catch (const json::parse_error &e) {
- LOG_F(ERROR, "Failed to parse compile options file: {}", e.what());
- } catch (const std::exception &e) {
- LOG_F(ERROR, "Failed to parse compile options file: {}", e.what());
+ LOG_F(ERROR, "Failed to parse compile options file {}: {}",
+ optionsFile, e.what());
+ return false;
}
+ }
- return constants::COMPILER +
- std::string{"-O2 -std=c++20 -Wall -shared -fPIC"};
- }();
+ std::string compiler = optionsJson.value("compiler", Constants::COMPILER);
+ if (std::find(availableCompilers.begin(), availableCompilers.end(),
+ compiler) == availableCompilers.end()) {
+ LOG_F(WARNING,
+ "Compiler {} is not available, using default compiler {}.",
+ compiler, Constants::COMPILER);
+ compiler = Constants::COMPILER;
+ }
- // 语法检查
- if (!syntaxCheck(code, constants::COMPILER)) {
- return false;
+ // Use CompileCommandGenerator to generate compile command
+ compileCommandGenerator_->setCompiler(compiler);
+ compileCommandGenerator_->setIncludeFlag(
+ optionsJson.value("include_flag", "-I./include"));
+ compileCommandGenerator_->setOutputFlag(
+ optionsJson.value("output_flag", "-o output"));
+ compileCommandGenerator_->setProjectName(std::string(moduleName));
+ compileCommandGenerator_->setProjectVersion("1.0.0");
+
+ // Temporarily create a file to store the code
+ fs::path tempSourceFile = fs::temp_directory_path() / "temp_code.cpp";
+ {
+ std::ofstream tempFile(tempSourceFile);
+ tempFile << code;
}
+ LOG_F(INFO, "Temporary source file created at: {}",
+ tempSourceFile.string());
- // 编译代码
- fs::path outputPath =
- OUTPUT_DIR / std::format("{}{}{}", constants::LIB_EXTENSION, moduleName,
- constants::LIB_EXTENSION);
- if (!compileCode(code, constants::COMPILER, compileOptions, outputPath)) {
- return false;
+ compileCommandGenerator_->setSourceDir(
+ tempSourceFile.parent_path().string());
+ compileCommandGenerator_->setOutputPath(OUTPUT_DIR /
+ "compile_commands.json");
+ compileCommandGenerator_->generate();
+
+ // Read generated compile_commands.json
+ json compileCommands;
+ {
+ std::ifstream commandsFile(OUTPUT_DIR / "compile_commands.json");
+ commandsFile >> compileCommands;
}
+ LOG_F(INFO, "Compile commands file read from: {}",
+ (OUTPUT_DIR / "compile_commands.json").string());
+
+ // Use the generated command to compile
+ if (!compileCommands["commands"].empty()) {
+ auto command =
+ compileCommands["commands"][0]["command"].get();
+ command += " " + customCompileOptions_;
+
+ fs::path outputPath =
+ OUTPUT_DIR / std::format("{}{}{}", Constants::LIB_EXTENSION,
+ moduleName, Constants::LIB_EXTENSION);
+ command += " -o " + outputPath.string();
+
+ LOG_F(INFO, "Executing compilation command: {}", command);
+ std::string compilationOutput = atom::system::executeCommand(command);
+ if (!compilationOutput.empty()) {
+ LOG_F(ERROR, "Compilation failed:\n{}", compilationOutput);
+ fs::remove(tempSourceFile);
+ return false;
+ }
- // 缓存编译结果
- cache_[cacheKey] = outputPath;
- return true;
+ // Cache the compilation result
+ cache_[cacheKey] = outputPath;
+ LOG_F(INFO, "Compilation successful, result cached with key: {}",
+ cacheKey);
+ fs::remove(tempSourceFile);
+ return true;
+ }
+
+ LOG_F(ERROR, "Failed to generate compile command.");
+ fs::remove(tempSourceFile);
+ return false;
}
void CompilerImpl::createOutputDirectory(const fs::path &outputDir) {
@@ -161,78 +224,147 @@ void CompilerImpl::createOutputDirectory(const fs::path &outputDir) {
LOG_F(WARNING, "Output directory {} does not exist, creating it.",
outputDir.string());
fs::create_directories(outputDir);
+ LOG_F(INFO, "Output directory {} created.", outputDir.string());
+ } else {
+ LOG_F(INFO, "Output directory {} already exists.", outputDir.string());
}
}
auto CompilerImpl::syntaxCheck(std::string_view code,
- std::string_view compiler) -> bool {
- if (atom::io::isFileExists("temp_code.cpp")) {
- if (!atom::io::removeFile("temp_code.cpp")) {
- LOG_F(ERROR, "Failed to remove temp_code.cpp");
- return false;
- }
+ const std::string &compiler) -> bool {
+ LOG_F(INFO, "Starting syntax check using compiler: {}", compiler);
+ compileCommandGenerator_->setCompiler(compiler);
+ compileCommandGenerator_->setIncludeFlag("-fsyntax-only");
+ compileCommandGenerator_->setOutputFlag("");
+
+ fs::path tempSourceFile = fs::temp_directory_path() / "syntax_check.cpp";
+ {
+ std::ofstream tempFile(tempSourceFile);
+ tempFile << code;
}
- // Create a temporary file to store the code
- std::string tempFileName = "temp_code.cpp";
+ LOG_F(INFO, "Temporary file for syntax check created at: {}",
+ tempSourceFile.string());
+
+ compileCommandGenerator_->setSourceDir(
+ tempSourceFile.parent_path().string());
+ compileCommandGenerator_->setOutputPath(fs::temp_directory_path() /
+ "syntax_check_commands.json");
+ compileCommandGenerator_->generate();
+
+ json syntaxCheckCommands;
{
- std::ofstream tempFile(tempFileName, std::ios::trunc);
- if (!tempFile) {
- LOG_F(ERROR, "Failed to create temporary file for code");
+ std::ifstream commandsFile(fs::temp_directory_path() /
+ "syntax_check_commands.json");
+ commandsFile >> syntaxCheckCommands;
+ }
+ LOG_F(INFO, "Syntax check commands file read.");
+
+ if (!syntaxCheckCommands["commands"].empty()) {
+ auto command =
+ syntaxCheckCommands["commands"][0]["command"].get();
+ LOG_F(INFO, "Executing syntax check command: {}", command);
+ std::string output = atom::system::executeCommand(command);
+
+ fs::remove(tempSourceFile);
+ fs::remove(fs::temp_directory_path() / "syntax_check_commands.json");
+
+ if (!output.empty()) {
+ LOG_F(ERROR, "Syntax check failed:\n{}", output);
return false;
}
- tempFile << code;
+ LOG_F(INFO, "Syntax check passed.");
+ return true;
}
+ LOG_F(ERROR, "Failed to generate syntax check command.");
+ fs::remove(tempSourceFile);
+ fs::remove(fs::temp_directory_path() / "syntax_check_commands.json");
+ return false;
+}
- // Format the command to invoke the compiler with syntax-only check
- std::string command =
- std::format("{} -fsyntax-only -x c++ {}", compiler, tempFileName);
- std::string output;
+auto CompilerImpl::compileCode(std::string_view code,
+ const std::string &compiler,
+ const std::string &compileOptions,
+ const fs::path &output) -> bool {
+ LOG_F(INFO,
+ "Starting compilation with compiler: {}, options: {}, output: {}",
+ compiler, compileOptions, output.string());
- // Execute the command and process output
- output = atom::system::executeCommand(
- command,
- false, // No need for a shell
- [&](const std::string &line) { output += line + "\n"; });
+ // Use CompileCommandGenerator to generate the compile command
+ compileCommandGenerator_->setCompiler(compiler);
+ compileCommandGenerator_->setIncludeFlag(compileOptions);
+ compileCommandGenerator_->setOutputFlag("-o " + output.string());
- // Clean up temporary file
- if (!atom::io::removeFile("temp_code.cpp")) {
- LOG_F(ERROR, "Failed to remove temp_code.cpp");
- return false;
+ fs::path tempSourceFile = fs::temp_directory_path() / "compile_code.cpp";
+ {
+ std::ofstream tempFile(tempSourceFile);
+ tempFile << code;
}
+ LOG_F(INFO, "Temporary file for compilation created at: {}",
+ tempSourceFile.string());
- // Check the output for errors
- if (!output.empty()) {
- LOG_F(ERROR, "Syntax check failed:\n{}", output);
- return false;
+ compileCommandGenerator_->setSourceDir(
+ tempSourceFile.parent_path().string());
+ compileCommandGenerator_->setOutputPath(fs::temp_directory_path() /
+ "compile_code_commands.json");
+ compileCommandGenerator_->generate();
+
+ json compileCodeCommands;
+ {
+ std::ifstream commandsFile(fs::temp_directory_path() /
+ "compile_code_commands.json");
+ commandsFile >> compileCodeCommands;
}
- return true;
-}
+ LOG_F(INFO, "Compile commands file read.");
-auto CompilerImpl::compileCode(std::string_view code, std::string_view compiler,
- std::string_view compileOptions,
- const fs::path &output) -> bool {
- std::string command = std::format("{} {} -xc++ - -o {}", compiler,
- compileOptions, output.string());
- std::string compilationOutput;
- compilationOutput = atom::system::executeCommandWithInput(
- command, std::string(code),
- [&](const std::string &line) { compilationOutput += line + "\n"; });
- if (!compilationOutput.empty()) {
- LOG_F(ERROR, "Compilation failed:\n{}", compilationOutput);
- return false;
+ if (!compileCodeCommands["commands"].empty()) {
+ auto command =
+ compileCodeCommands["commands"][0]["command"].get();
+ LOG_F(INFO, "Executing compilation command: {}", command);
+ std::string compilationOutput = atom::system::executeCommand(command);
+
+ fs::remove(tempSourceFile);
+ fs::remove(fs::temp_directory_path() / "compile_code_commands.json");
+
+ if (!compilationOutput.empty()) {
+ LOG_F(ERROR, "Compilation failed:\n{}", compilationOutput);
+ return false;
+ }
+ LOG_F(INFO, "Compilation successful, output file: {}", output.string());
+ return true;
}
- return true;
+
+ LOG_F(ERROR, "Failed to generate compile command.");
+ fs::remove(tempSourceFile);
+ fs::remove(fs::temp_directory_path() / "compile_code_commands.json");
+ return false;
}
auto CompilerImpl::findAvailableCompilers() -> std::vector {
- return toolchainManager_.getAvailableCompilers();
+ LOG_F(INFO, "Finding available compilers...");
+ auto compilers = toolchainManager_.getAvailableCompilers();
+ if (compilers.empty()) {
+ LOG_F(WARNING, "No compilers found.");
+ } else {
+ LOG_F(INFO, "Found compilers: {}", atom::utils::toString(compilers));
+ }
+ return compilers;
}
void CompilerImpl::addCompileOptions(const std::string &options) {
+ LOG_F(INFO, "Adding custom compile options: {}", options);
customCompileOptions_ = options;
}
auto CompilerImpl::getAvailableCompilers() const -> std::vector {
+ LOG_F(INFO, "Retrieving available compilers...");
return toolchainManager_.getAvailableCompilers();
}
+
+void Compiler::generateCompileCommands(const std::string &sourceDir) {
+ LOG_F(INFO,
+ "Generating compile commands in Compiler for source directory: {}",
+ sourceDir);
+ impl_->generateCompileCommands(sourceDir);
+}
+
} // namespace lithium
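
For reference, the option keys consumed above (`compiler`, `optimization_level`, `cplus_version`, `warnings`, plus the optional `include_flag`/`output_flag`) can be produced by a small helper that writes a default `compile_options.json`. This is a sketch, not project code, and it assumes `atom/type/json.hpp` exposes `nlohmann::json` as the rest of this diff does:

```cpp
#include <fstream>
#include <string>

#include "atom/type/json.hpp"  // nlohmann::json wrapper used by this repo

// Writes a compile_options.json with the keys read by
// CompilerImpl::compileToSharedLibrary in the diff above.
inline void writeDefaultCompileOptions(const std::string& path) {
    nlohmann::json options = {{"compiler", "g++"},  // must be an available compiler
                              {"optimization_level", "-O2"},
                              {"cplus_version", "-std=c++20"},
                              {"warnings", "-Wall"},
                              {"include_flag", "-I./include"},
                              {"output_flag", "-o output"}};
    std::ofstream(path) << options.dump(4);
}
```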
diff --git a/src/addon/compiler.hpp b/src/addon/compiler.hpp
index 618c831d..eda9cbe0 100644
--- a/src/addon/compiler.hpp
+++ b/src/addon/compiler.hpp
@@ -25,19 +25,20 @@ using json = nlohmann::json;
namespace lithium {
class CompilerImpl;
-
+class CompileCommandGenerator;
class Compiler {
public:
Compiler();
~Compiler();
/**
- * 编译 C++ 代码为共享库,并加载到内存中
- * @param code 要编译的代码
- * @param moduleName 模块名
- * @param functionName 入口函数名
- * @param optionsFile 编译选项文件路径,默认为 "compile_options.json"
- * @return 编译是否成功
+ * Compile C++ code into a shared library and load it into memory
+ * @param code Code to compile
+ * @param moduleName Module name
+ * @param functionName Entry function name
+ * @param optionsFile Compilation options file path, default is
+ * "compile_options.json"
+ * @return Whether compilation was successful
*/
ATOM_NODISCARD auto compileToSharedLibrary(
std::string_view code, std::string_view moduleName,
@@ -45,18 +46,24 @@ class Compiler {
std::string_view optionsFile = "compile_options.json") -> bool;
/**
- * 添加自定义编译选项
- * @param options 编译选项
+ * Add custom compilation options
+ * @param options Compilation options
*/
- void addCompileOptions(const std::string &options);
+ void addCompileOptions(const std::string& options);
/**
- * 获取可用编译器列表
- * @return 编译器列表
+ * Get list of available compilers
+ * @return List of compilers
*/
ATOM_NODISCARD auto getAvailableCompilers() const
-> std::vector<std::string>;
+ /**
+ * Generate compile commands for a given source directory
+ * @param sourceDir Source directory path
+ */
+ void generateCompileCommands(const std::string& sourceDir);
+
private:
std::unique_ptr<CompilerImpl> impl_;
};
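
Putting the public surface together, a hedged example that compiles an in-memory snippet and then emits a compilation database; the module name, entry-point name, and source directory are placeholders:

```cpp
#include <string_view>

#include "addon/compiler.hpp"
#include "atom/log/loguru.hpp"

int main() {
    lithium::Compiler compiler;

    for (const auto& name : compiler.getAvailableCompilers()) {
        LOG_F(INFO, "Available compiler: {}", name);
    }

    constexpr std::string_view CODE = R"(extern "C" int entry() { return 42; })";
    if (!compiler.compileToSharedLibrary(CODE, "demo_module", "entry")) {
        LOG_F(ERROR, "Compilation of demo_module failed.");
        return 1;
    }
    compiler.generateCompileCommands("./src");
    return 0;
}
```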
diff --git a/src/addon/debug/elf.cpp b/src/addon/debug/elf.cpp
index e5b21e8f..52613941 100644
--- a/src/addon/debug/elf.cpp
+++ b/src/addon/debug/elf.cpp
@@ -1,336 +1,236 @@
#include "elf.hpp"
-#include <cstdint>
-#include <fcntl.h>
-#include <gelf.h>
-#include <libelf.h>
-#include <unistd.h>
+#include <algorithm>
+#include <elf.h>
+#include <fstream>
+#include <optional>
-#include "atom/log/loguru.hpp"
-#include "atom/macro.hpp"
+#include "atom/error/exception.hpp"
namespace lithium {
-// Define your structures
-struct ElfHeader {
- uint16_t type;
- uint16_t machine;
- uint32_t version;
- uint64_t entry;
- uint64_t phoff;
- uint64_t shoff;
- uint32_t flags;
- uint16_t ehsize;
- uint16_t phentsize;
- uint16_t phnum;
- uint16_t shentsize;
- uint16_t shnum;
- uint16_t shstrndx;
-} ATOM_ALIGNAS(64);
-
-struct ProgramHeader {
- uint32_t type;
- uint64_t offset;
- uint64_t vaddr;
- uint64_t paddr;
- uint64_t filesz;
- uint64_t memsz;
- uint32_t flags;
- uint64_t align;
-} ATOM_ALIGNAS(64);
-
-struct SectionHeader {
- std::string name;
- uint32_t type{};
- uint64_t flags{};
- uint64_t addr{};
- uint64_t offset{};
- uint64_t size{};
- uint32_t link{};
- uint32_t info{};
- uint64_t addralign{};
- uint64_t entsize{};
-} ATOM_ALIGNAS(128);
-
-struct Symbol {
- std::string name;
- uint64_t value{};
- uint64_t size{};
- unsigned char bind{};
- unsigned char type{};
- uint16_t shndx{};
-} ATOM_ALIGNAS(64);
-
-struct DynamicEntry {
- uint64_t tag;
- union {
- uint64_t val;
- uint64_t ptr;
- } dUn;
-} ATOM_ALIGNAS(16);
-
-struct RelocationEntry {
- uint64_t offset;
- uint64_t info;
- int64_t addend;
-} ATOM_ALIGNAS(32);
-
-// Forward declaration of the implementation class
+
class ElfParser::Impl {
public:
- explicit Impl(const char* file);
- ~Impl();
-
- auto initialize() -> bool;
- void cleanup();
- auto parse() -> bool;
- [[nodiscard]] auto getElfHeader() const -> ElfHeader;
-    [[nodiscard]] auto getProgramHeaders() const -> std::vector<ProgramHeader>;
-    [[nodiscard]] auto getSectionHeaders() const -> std::vector<SectionHeader>;
-    [[nodiscard]] auto getSymbolTable() const -> std::vector<Symbol>;
-    [[nodiscard]] auto getDynamicEntries() const -> std::vector<DynamicEntry>;
-    [[nodiscard]] auto getRelocationEntries() const -> std::vector<RelocationEntry>;
+ explicit Impl(std::string_view file) : filePath_(file) {}
-private:
- const char* filePath_;
- int fd_{};
- Elf* elf_{};
- GElf_Ehdr ehdr_;
-};
+ auto parse() -> bool {
+ std::ifstream file(filePath_, std::ios::binary);
+ if (!file) {
+ return false;
+ }
-// Implementation of the `Impl` class methods
-ElfParser::Impl::Impl(const char* file) : filePath_(file) {}
+ file.seekg(0, std::ios::end);
+ fileSize_ = file.tellg();
+ file.seekg(0, std::ios::beg);
-ElfParser::Impl::~Impl() { cleanup(); }
+ fileContent_.resize(fileSize_);
+        file.read(reinterpret_cast<char*>(fileContent_.data()), fileSize_);
-auto ElfParser::Impl::initialize() -> bool {
- if (elf_version(EV_CURRENT) == EV_NONE) {
- LOG_F(ERROR, "ELF library initialization failed: {}", elf_errmsg(-1));
- return false;
+ return parseElfHeader() && parseProgramHeaders() &&
+ parseSectionHeaders() && parseSymbolTable();
}
- fd_ = open(filePath_, O_RDONLY, 0);
- if (fd_ < 0) {
- LOG_F(ERROR, "Failed to open ELF file: {}", filePath_);
- return false;
- }
+    [[nodiscard]] auto getElfHeader() const -> std::optional<ElfHeader> { return elfHeader_; }
- elf_ = elf_begin(fd_, ELF_C_READ, nullptr);
- if (elf_ == nullptr) {
- LOG_F(ERROR, "elf_begin() failed: {}", elf_errmsg(-1));
- close(fd_);
- fd_ = -1;
- return false;
+    [[nodiscard]] auto getProgramHeaders() const -> std::span<const ProgramHeader> {
+ return programHeaders_;
}
- if (gelf_getehdr(elf_, &ehdr_) == nullptr) {
- LOG_F(ERROR, "gelf_getehdr() failed: {}", elf_errmsg(-1));
- elf_end(elf_);
- elf_ = nullptr;
- close(fd_);
- fd_ = -1;
- return false;
+    [[nodiscard]] auto getSectionHeaders() const -> std::span<const SectionHeader> {
+ return sectionHeaders_;
}
- return true;
-}
+    [[nodiscard]] auto getSymbolTable() const -> std::span<const Symbol> { return symbolTable_; }
-void ElfParser::Impl::cleanup() {
- if (elf_ != nullptr) {
- elf_end(elf_);
- elf_ = nullptr;
+    [[nodiscard]] auto findSymbolByName(std::string_view name) const -> std::optional<Symbol> {
+ auto it = std::ranges::find_if(symbolTable_, [name](const auto& symbol) {
+ return symbol.name == name;
+ });
+ if (it != symbolTable_.end()) {
+ return *it;
+ }
+ return std::nullopt;
}
- if (fd_ >= 0) {
- close(fd_);
- fd_ = -1;
+
+    [[nodiscard]] auto findSymbolByAddress(uint64_t address) const -> std::optional<Symbol> {
+ auto it = std::ranges::find_if(
+ symbolTable_,
+ [address](const auto& symbol) { return symbol.value == address; });
+ if (it != symbolTable_.end()) {
+ return *it;
+ }
+ return std::nullopt;
}
-}
-auto ElfParser::Impl::parse() -> bool { return initialize(); }
-
-auto ElfParser::Impl::getElfHeader() const -> ElfHeader {
- ElfHeader header{};
- header.type = ehdr_.e_type;
- header.machine = ehdr_.e_machine;
- header.version = ehdr_.e_version;
- header.entry = ehdr_.e_entry;
- header.phoff = ehdr_.e_phoff;
- header.shoff = ehdr_.e_shoff;
- header.flags = ehdr_.e_flags;
- header.ehsize = ehdr_.e_ehsize;
- header.phentsize = ehdr_.e_phentsize;
- header.phnum = ehdr_.e_phnum;
- header.shentsize = ehdr_.e_shentsize;
- header.shnum = ehdr_.e_shnum;
- header.shstrndx = ehdr_.e_shstrndx;
- return header;
-}
+    [[nodiscard]] auto findSection(std::string_view name) const -> std::optional<SectionHeader> {
+ auto it = std::ranges::find_if(
+ sectionHeaders_,
+ [name](const auto& section) { return section.name == name; });
+ if (it != sectionHeaders_.end()) {
+ return *it;
+ }
+ return std::nullopt;
+ }
-auto ElfParser::Impl::getProgramHeaders() const -> std::vector<ProgramHeader> {
-    std::vector<ProgramHeader> headers;
- for (size_t i = 0; i < ehdr_.e_phnum; ++i) {
- GElf_Phdr phdr;
- if (gelf_getphdr(elf_, i, &phdr) != &phdr) {
- LOG_F(ERROR, "gelf_getphdr() failed: {}", elf_errmsg(-1));
- continue;
+    [[nodiscard]] auto getSectionData(const SectionHeader& section) const -> std::vector<uint8_t> {
+ if (section.offset + section.size > fileSize_) {
+ THROW_OUT_OF_RANGE("Section data out of bounds");
}
- ProgramHeader header;
- header.type = phdr.p_type;
- header.offset = phdr.p_offset;
- header.vaddr = phdr.p_vaddr;
- header.paddr = phdr.p_paddr;
- header.filesz = phdr.p_filesz;
- header.memsz = phdr.p_memsz;
- header.flags = phdr.p_flags;
- header.align = phdr.p_align;
- headers.push_back(header);
+ return {fileContent_.begin() + section.offset,
+ fileContent_.begin() + section.offset + section.size};
}
- return headers;
-}
-auto ElfParser::Impl::getSectionHeaders() const -> std::vector<SectionHeader> {
-    std::vector<SectionHeader> headers;
- for (size_t i = 0; i < ehdr_.e_shnum; ++i) {
- GElf_Shdr shdr;
- if (gelf_getshdr(elf_getscn(elf_, i), &shdr) != &shdr) {
- LOG_F(ERROR, "gelf_getshdr() failed: {}", elf_errmsg(-1));
- continue;
+private:
+ std::string filePath_;
+    std::vector<uint8_t> fileContent_;
+ size_t fileSize_{};
+
+    std::optional<ElfHeader> elfHeader_;
+    std::vector<ProgramHeader> programHeaders_;
+    std::vector<SectionHeader> sectionHeaders_;
+    std::vector<Symbol> symbolTable_;
+
+ auto parseElfHeader() -> bool {
+ if (fileSize_ < sizeof(Elf64_Ehdr)) {
+ return false;
}
- SectionHeader header;
- header.name = elf_strptr(elf_, ehdr_.e_shstrndx, shdr.sh_name);
- header.type = shdr.sh_type;
- header.flags = shdr.sh_flags;
- header.addr = shdr.sh_addr;
- header.offset = shdr.sh_offset;
- header.size = shdr.sh_size;
- header.link = shdr.sh_link;
- header.info = shdr.sh_info;
- header.addralign = shdr.sh_addralign;
- header.entsize = shdr.sh_entsize;
- headers.push_back(header);
+
+ const auto* ehdr =
+            reinterpret_cast<const Elf64_Ehdr*>(fileContent_.data());
+ elfHeader_ = ElfHeader{.type = ehdr->e_type,
+ .machine = ehdr->e_machine,
+ .version = ehdr->e_version,
+ .entry = ehdr->e_entry,
+ .phoff = ehdr->e_phoff,
+ .shoff = ehdr->e_shoff,
+ .flags = ehdr->e_flags,
+ .ehsize = ehdr->e_ehsize,
+ .phentsize = ehdr->e_phentsize,
+ .phnum = ehdr->e_phnum,
+ .shentsize = ehdr->e_shentsize,
+ .shnum = ehdr->e_shnum,
+ .shstrndx = ehdr->e_shstrndx};
+
+ return true;
}
- return headers;
-}
-std::vector<Symbol> ElfParser::Impl::getSymbolTable() const {
-    std::vector<Symbol> symbols;
- size_t shnum;
- elf_getshdrnum(elf_, &shnum);
-
- for (size_t i = 0; i < shnum; ++i) {
- Elf_Scn* scn = elf_getscn(elf_, i);
- GElf_Shdr shdr;
- gelf_getshdr(scn, &shdr);
-
- if (shdr.sh_type == SHT_SYMTAB || shdr.sh_type == SHT_DYNSYM) {
- Elf_Data* data = elf_getdata(scn, nullptr);
- size_t symbolCount = shdr.sh_size / shdr.sh_entsize;
-
- for (size_t j = 0; j < symbolCount; ++j) {
- GElf_Sym sym;
- gelf_getsym(data, j, &sym);
- Symbol symbol;
- symbol.name = elf_strptr(elf_, shdr.sh_link, sym.st_name);
- symbol.value = sym.st_value;
- symbol.size = sym.st_size;
- symbol.bind = GELF_ST_BIND(sym.st_info);
- symbol.type = GELF_ST_TYPE(sym.st_info);
- symbol.shndx = sym.st_shndx;
- symbols.push_back(symbol);
- }
+ auto parseProgramHeaders() -> bool {
+ if (!elfHeader_) {
+ return false;
+ }
+
+        const auto* phdr = reinterpret_cast<const Elf64_Phdr*>(
+ fileContent_.data() + elfHeader_->phoff);
+ for (uint16_t i = 0; i < elfHeader_->phnum; ++i) {
+ programHeaders_.push_back(ProgramHeader{.type = phdr[i].p_type,
+ .offset = phdr[i].p_offset,
+ .vaddr = phdr[i].p_vaddr,
+ .paddr = phdr[i].p_paddr,
+ .filesz = phdr[i].p_filesz,
+ .memsz = phdr[i].p_memsz,
+ .flags = phdr[i].p_flags,
+ .align = phdr[i].p_align});
}
+
+ return true;
}
- return symbols;
-}
-auto ElfParser::Impl::getDynamicEntries() const -> std::vector<DynamicEntry> {
-    std::vector<DynamicEntry> entries;
- size_t shnum;
- elf_getshdrnum(elf_, &shnum);
-
- for (size_t i = 0; i < shnum; ++i) {
- Elf_Scn* scn = elf_getscn(elf_, i);
- GElf_Shdr shdr;
- gelf_getshdr(scn, &shdr);
-
- if (shdr.sh_type == SHT_DYNAMIC) {
- Elf_Data* data = elf_getdata(scn, nullptr);
- size_t entryCount = shdr.sh_size / shdr.sh_entsize;
-
- for (size_t j = 0; j < entryCount; ++j) {
- GElf_Dyn dyn;
- gelf_getdyn(data, j, &dyn);
- DynamicEntry entry;
- entry.tag = dyn.d_tag;
- entry.dUn.val = dyn.d_un.d_val;
- entries.push_back(entry);
- }
+ auto parseSectionHeaders() -> bool {
+ if (!elfHeader_) {
+ return false;
+ }
+
+        const auto* shdr = reinterpret_cast<const Elf64_Shdr*>(
+            fileContent_.data() + elfHeader_->shoff);
+        const auto* strtab = reinterpret_cast<const char*>(
+            fileContent_.data() + shdr[elfHeader_->shstrndx].sh_offset);
+
+ for (uint16_t i = 0; i < elfHeader_->shnum; ++i) {
+ sectionHeaders_.push_back(
+ SectionHeader{.name = std::string(strtab + shdr[i].sh_name),
+ .type = shdr[i].sh_type,
+ .flags = shdr[i].sh_flags,
+ .addr = shdr[i].sh_addr,
+ .offset = shdr[i].sh_offset,
+ .size = shdr[i].sh_size,
+ .link = shdr[i].sh_link,
+ .info = shdr[i].sh_info,
+ .addralign = shdr[i].sh_addralign,
+ .entsize = shdr[i].sh_entsize});
}
+
+ return true;
}
- return entries;
-}
-std::vector<RelocationEntry> ElfParser::Impl::getRelocationEntries() const {
-    std::vector<RelocationEntry> entries;
- size_t shnum;
- elf_getshdrnum(elf_, &shnum);
-
- for (size_t i = 0; i < shnum; ++i) {
- Elf_Scn* scn = elf_getscn(elf_, i);
- GElf_Shdr shdr;
- gelf_getshdr(scn, &shdr);
-
- if (shdr.sh_type == SHT_RELA || shdr.sh_type == SHT_REL) {
- Elf_Data* data = elf_getdata(scn, nullptr);
- size_t entryCount = shdr.sh_size / shdr.sh_entsize;
-
- for (size_t j = 0; j < entryCount; ++j) {
- RelocationEntry entry;
- if (shdr.sh_type == SHT_RELA) {
- GElf_Rela rela;
- gelf_getrela(data, j, &rela);
- entry.offset = rela.r_offset;
- entry.info = rela.r_info;
- entry.addend = rela.r_addend;
- } else {
- GElf_Rel rel;
- gelf_getrel(data, j, &rel);
- entry.offset = rel.r_offset;
- entry.info = rel.r_info;
- entry.addend = 0; // REL doesn't have addend
- }
- entries.push_back(entry);
- }
+ auto parseSymbolTable() -> bool {
+ auto symtabSection = std::ranges::find_if(
+ sectionHeaders_,
+ [](const auto& section) { return section.type == SHT_SYMTAB; });
+
+ if (symtabSection == sectionHeaders_.end()) {
+ return true; // No symbol table, but not an error
+ }
+
+        const auto* symtab = reinterpret_cast<const Elf64_Sym*>(
+ fileContent_.data() + symtabSection->offset);
+ size_t numSymbols = symtabSection->size / sizeof(Elf64_Sym);
+
+        const auto* strtab = reinterpret_cast<const char*>(
+ fileContent_.data() + sectionHeaders_[symtabSection->link].offset);
+
+ for (size_t i = 0; i < numSymbols; ++i) {
+ symbolTable_.push_back(
+ Symbol{.name = std::string(strtab + symtab[i].st_name),
+ .value = symtab[i].st_value,
+ .size = symtab[i].st_size,
+ .bind = ELF64_ST_BIND(symtab[i].st_info),
+ .type = ELF64_ST_TYPE(symtab[i].st_info),
+ .shndx = symtab[i].st_shndx});
}
+
+ return true;
}
- return entries;
-}
+};
-// Define the public `ElfParser` class methods
-ElfParser::ElfParser(const char* file)
-    : pImpl(std::make_unique<Impl>(file)) {}
+// ElfParser method implementations
+ElfParser::ElfParser(std::string_view file)
+    : pImpl(std::make_unique<Impl>(file)) {}
-bool ElfParser::parse() { return pImpl->parse(); }
+ElfParser::~ElfParser() = default;
-ElfHeader ElfParser::getElfHeader() const { return pImpl->getElfHeader(); }
+auto ElfParser::parse() -> bool { return pImpl->parse(); }
-std::vector<ProgramHeader> ElfParser::getProgramHeaders() const {
+auto ElfParser::getElfHeader() const -> std::optional {
+ return pImpl->getElfHeader();
+}
+
+auto ElfParser::getProgramHeaders() const -> std::span<const ProgramHeader> {
return pImpl->getProgramHeaders();
}
-std::vector<SectionHeader> ElfParser::getSectionHeaders() const {
+auto ElfParser::getSectionHeaders() const -> std::span<const SectionHeader> {
return pImpl->getSectionHeaders();
}
-std::vector<Symbol> ElfParser::getSymbolTable() const {
+auto ElfParser::getSymbolTable() const -> std::span<const Symbol> {
return pImpl->getSymbolTable();
}
-std::vector<DynamicEntry> ElfParser::getDynamicEntries() const {
- return pImpl->getDynamicEntries();
+auto ElfParser::findSymbolByName(std::string_view name) const -> std::optional<Symbol> {
+ return pImpl->findSymbolByName(name);
+}
+
+auto ElfParser::findSymbolByAddress(uint64_t address) const -> std::optional<Symbol> {
+ return pImpl->findSymbolByAddress(address);
}
-std::vector