diff --git a/modules/lithium.cxxtools/docs/nc.md b/modules/lithium.cxxtools/docs/nc.md new file mode 100644 index 00000000..42073986 --- /dev/null +++ b/modules/lithium.cxxtools/docs/nc.md @@ -0,0 +1,112 @@ +# Network Client Application Documentation + +## Overview + +The `Network Client Application` is a tool designed to send files or messages over TCP or UDP protocols. It uses the `ASIO` library for network communication and `loguru` for logging. The application supports both TCP and UDP modes, allowing users to specify a timeout for TCP connections and optionally send a file. + +## Dependencies + +- **ASIO**: A cross-platform C++ library for network and low-level I/O programming. +- **loguru**: A logging library used for logging operations. + +## Constants + +- **MAX_LENGTH**: The maximum length of the buffer used for reading and sending data. +- **ARG_COUNT_MIN**: The minimum number of command-line arguments required. +- **ARG_COUNT_MAX**: The maximum number of command-line arguments allowed. +- **DEFAULT_TIMEOUT_SECONDS**: The default timeout in seconds for TCP connections. + +## Functions + +### `sendFileTcp(tcp::socket& socket, const std::string& filename)` + +Sends a file over a TCP connection. + +- **Parameters:** + + - `socket`: The TCP socket to send the file through. + - `filename`: The name of the file to send. + +- **Example:** + ```cpp + asio::io_context ioContext; + tcp::resolver resolver(ioContext); + auto endpoints = resolver.resolve("localhost", "12345"); + tcp::socket socket(ioContext); + asio::connect(socket, endpoints); + sendFileTcp(socket, "example.txt"); + ``` + +### `sendFileUdp(udp::socket& socket, const udp::endpoint& endpoint, const std::string& filename)` + +Sends a file over a UDP connection. + +- **Parameters:** + + - `socket`: The UDP socket to send the file through. + - `endpoint`: The endpoint to send the file to. + - `filename`: The name of the file to send. 
+ +- **Example:** + ```cpp + asio::io_context ioContext; + udp::resolver resolver(ioContext); + udp::resolver::results_type endpoints = resolver.resolve(udp::v4(), "localhost", "12345"); + udp::socket socket(ioContext); + socket.open(udp::v4()); + sendFileUdp(socket, *endpoints.begin(), "example.txt"); + ``` + +### `runTcpClient(const std::string& host, const std::string& port, int timeoutSeconds, const std::optional<std::string>& filename = std::nullopt)` + +Runs the TCP client mode. + +- **Parameters:** + + - `host`: The host to connect to. + - `port`: The port to connect to. + - `timeoutSeconds`: The timeout in seconds for the TCP connection. + - `filename`: An optional filename to send over the TCP connection. + +- **Example:** + ```cpp + runTcpClient("localhost", "12345", 10, "example.txt"); + ``` + +### `runUdpClient(const std::string& host, const std::string& port, const std::optional<std::string>& filename = std::nullopt)` + +Runs the UDP client mode. + +- **Parameters:** + + - `host`: The host to connect to. + - `port`: The port to connect to. + - `filename`: An optional filename to send over the UDP connection. + +- **Example:** + ```cpp + runUdpClient("localhost", "12345", "example.txt"); + ``` + +### `main(int argc, char* argv[]) -> int` + +The main function that initializes the application, parses command-line arguments, and starts the appropriate client mode. + +- **Parameters:** + + - `argc`: The number of command-line arguments. + - `argv`: The array of command-line arguments. + +- **Returns:** An integer representing the exit status. + +- **Example:** + ```bash + ./nc tcp localhost 12345 10 example.txt + ``` + +## Notes + +- The application supports both TCP and UDP protocols. +- Users can specify a timeout for TCP connections. +- The application can optionally send a file over the network. +- Logging is used extensively to provide detailed information about the operations being performed. 
diff --git a/modules/lithium.cxxtools/docs/proxy.md b/modules/lithium.cxxtools/docs/proxy.md new file mode 100644 index 00000000..67f1707e --- /dev/null +++ b/modules/lithium.cxxtools/docs/proxy.md @@ -0,0 +1,212 @@ +# NetworkProxy Class Documentation + +## Overview + +The `NetworkProxy` class is designed to manage network proxy settings on both Windows and Linux operating systems. It provides functionalities to set, disable, and retrieve proxy settings, install and uninstall certificates, and manage the hosts file. The class is part of the `lithium::cxxtools` namespace and leverages platform-specific implementations to handle proxy-related operations. + +## Class Methods + +### `setProxy(const std::string& proxy, NetworkProxy::ProxyMode mode, const std::string& listenIP, const std::string& dns) -> bool` + +Sets the network proxy with the specified parameters. + +- **Parameters:** + + - `proxy`: The proxy server address. + - `mode`: The proxy mode (e.g., Hosts, PAC, System). + - `listenIP`: The IP address to listen on. + - `dns`: The DNS server address. + +- **Returns:** `true` if the proxy was set successfully, `false` otherwise. + +### `disableProxy() const -> bool` + +Disables the network proxy. + +- **Returns:** `true` if the proxy was disabled successfully, `false` otherwise. + +### `getCurrentProxy() -> std::string` + +Retrieves the current proxy settings. + +- **Returns:** A string representing the current proxy server address. + +### `installCertificate(const std::string& certPath) const -> bool` + +Installs a certificate from the specified path. + +- **Parameters:** + + - `certPath`: The path to the certificate file. + +- **Returns:** `true` if the certificate was installed successfully, `false` otherwise. + +### `uninstallCertificate(const std::string& certName) const -> bool` + +Uninstalls a certificate by its name. + +- **Parameters:** + + - `certName`: The name of the certificate to uninstall. 
+ +- **Returns:** `true` if the certificate was uninstalled successfully, `false` otherwise. + +### `viewCertificateInfo(const std::string& certName) const -> std::string` + +Retrieves information about a certificate by its name. + +- **Parameters:** + + - `certName`: The name of the certificate. + +- **Returns:** A string containing the certificate information. + +### `editHostsFile(const std::vector<std::pair<std::string, std::string>>& hostsEntries)` + +Edits the hosts file with the specified entries. + +- **Parameters:** + - `hostsEntries`: A vector of pairs where each pair represents an IP address and a hostname. + +### `resetHostsFile()` + +Resets the hosts file to its default state. + +### `enableHttpToHttpsRedirect(bool enable)` + +Enables or disables HTTP to HTTPS redirection. + +- **Parameters:** + - `enable`: `true` to enable redirection, `false` to disable. + +### `setCustomDoH(const std::string& dohUrl)` + +Sets a custom DNS-over-HTTPS (DoH) URL. + +- **Parameters:** + - `dohUrl`: The DoH URL to set. + +### `getProxyModeName(ProxyMode mode) -> std::string` + +Returns the name of the specified proxy mode. + +- **Parameters:** + + - `mode`: The proxy mode. + +- **Returns:** A string representing the name of the proxy mode. + +## Platform-Specific Implementations + +### Windows + +#### `setWindowsProxy(const std::string& proxy) const -> bool` + +Sets the proxy settings on a Windows system using the Windows Registry. + +#### `disableWindowsProxy() const -> bool` + +Disables the proxy settings on a Windows system using the Windows Registry. + +#### `getWindowsCurrentProxy() const -> std::string` + +Retrieves the current proxy settings from the Windows Registry. + +#### `installWindowsCertificate(const std::string& certPath) const -> bool` + +Installs a certificate on a Windows system using the `certutil` command. + +#### `uninstallWindowsCertificate(const std::string& certName) const -> bool` + +Uninstalls a certificate on a Windows system using the `certutil` command. 
+ +#### `viewWindowsCertificateInfo(const std::string& certName) const -> std::string` + +Retrieves information about a certificate on a Windows system using the `certutil` command. + +#### `editWindowsHostsFile(const std::vector<std::pair<std::string, std::string>>& hostsEntries) const` + +Edits the Windows hosts file with the specified entries. + +#### `resetWindowsHostsFile() const` + +Resets the Windows hosts file to its default state. + +### Linux + +#### `setLinuxProxy(const std::string& proxy) const -> bool` + +Sets the proxy settings on a Linux system using environment variables. + +#### `disableLinuxProxy() -> bool` + +Disables the proxy settings on a Linux system by unsetting environment variables. + +#### `getLinuxCurrentProxy() -> std::string` + +Retrieves the current proxy settings from the Linux environment variables. + +#### `installLinuxCertificate(const std::string& certPath) -> bool` + +Installs a certificate on a Linux system using the `update-ca-certificates` command. + +#### `uninstallLinuxCertificate(const std::string& certName) -> bool` + +Uninstalls a certificate on a Linux system using the `update-ca-certificates` command. + +#### `viewLinuxCertificateInfo(const std::string& certName) -> std::string` + +Retrieves information about a certificate on a Linux system using the `openssl` command. + +#### `editLinuxHostsFile(const std::vector<std::pair<std::string, std::string>>& hostsEntries) const` + +Edits the Linux hosts file with the specified entries. + +#### `resetLinuxHostsFile() const` + +Resets the Linux hosts file to its default state. + +## Dependencies + +- **loguru**: A logging library used for logging operations. +- **atom/system/command.hpp**: A utility for executing system commands. 
+ +## Usage Example + +```cpp +#include "proxy.hpp" + +#include <string> +#include <utility> +#include <vector> + +int main() { + lithium::cxxtools::NetworkProxy proxy; + + // Set proxy + proxy.setProxy("http://proxy.example.com:8080", lithium::cxxtools::NetworkProxy::ProxyMode::System, "0.0.0.0", "8.8.8.8"); + + // Disable proxy + proxy.disableProxy(); + + // Install certificate + proxy.installCertificate("path/to/certificate.crt"); + + // Uninstall certificate + proxy.uninstallCertificate("certificate.crt"); + + // Edit hosts file + std::vector<std::pair<std::string, std::string>> hostsEntries = { + {"127.0.0.1", "localhost"}, + {"192.168.1.1", "example.com"} + }; + proxy.editHostsFile(hostsEntries); + + // Reset hosts file + proxy.resetHostsFile(); + + return 0; +} +``` + +## Notes + +- The class uses platform-specific code to handle proxy settings, certificate management, and hosts file operations. +- Logging is used extensively to provide detailed information about the operations being performed. +- The class is designed to be cross-platform, with separate implementations for Windows and Linux. diff --git a/modules/lithium.cxxtools/docs/symbol.md b/modules/lithium.cxxtools/docs/symbol.md new file mode 100644 index 00000000..93233c11 --- /dev/null +++ b/modules/lithium.cxxtools/docs/symbol.md @@ -0,0 +1,307 @@ +# Symbol Analyzer Documentation + +## Overview + +The `Symbol Analyzer` is a tool designed to analyze the symbols within a shared library or executable. It uses the `readelf` command to extract symbol information and provides functionalities to parse, filter, and export these symbols in various formats such as CSV, JSON, and YAML. The tool is designed to be highly efficient, leveraging multithreading for parsing large outputs. + +## Dependencies + +- **loguru**: A logging library used for logging operations. +- **atom/error/exception.hpp**: Custom exception handling. +- **atom/function/abi.hpp**: ABI-related functionalities. +- **atom/type/json.hpp**: JSON handling using `nlohmann::json`. +- **yaml-cpp/yaml.h**: YAML handling using `yaml-cpp`. 
+ +## Constants + +- **BUFFER_SIZE**: Size of the buffer used for reading command output. +- **MATCH_SIZE**: Number of matches expected in the symbol regex. +- **MATCH_INDEX**: Index of the symbol name in the regex match. + +## Functions + +### `exec(const std::string& cmd) -> std::string` + +Executes a system command and returns its output as a string. + +- **Parameters:** + + - `cmd`: The command to execute. + +- **Returns:** The output of the command as a string. + +- **Example:** + ```cpp + std::string output = exec("ls -l"); + std::cout << "Command output: " << output << std::endl; + ``` + +### `parseReadelfOutput(const std::string_view output) -> std::vector<Symbol>` + +Parses the output of the `readelf` command and extracts symbols. + +- **Parameters:** + + - `output`: The output of the `readelf` command. + +- **Returns:** A vector of `Symbol` objects. + +- **Example:** + ```cpp + std::string readelfOutput = exec("readelf -Ws /path/to/library"); + std::vector<Symbol> symbols = parseReadelfOutput(readelfOutput); + ``` + +### `parseSymbolsInParallel(const std::string& output, int threadCount) -> std::vector<Symbol>` + +Parses symbols in parallel using multiple threads. + +- **Parameters:** + + - `output`: The output of the `readelf` command. + - `threadCount`: The number of threads to use for parsing. + +- **Returns:** A vector of `Symbol` objects. + +- **Example:** + ```cpp + std::string readelfOutput = exec("readelf -Ws /path/to/library"); + std::vector<Symbol> symbols = parseSymbolsInParallel(readelfOutput, 4); + ``` + +### `filterSymbolsByType(const std::vector<Symbol>& symbols, const std::string& type) -> std::vector<Symbol>` + +Filters symbols by their type. + +- **Parameters:** + + - `symbols`: The vector of symbols to filter. + - `type`: The type of symbols to filter by. + +- **Returns:** A vector of filtered `Symbol` objects. 
+ +- **Example:** + ```cpp + std::vector<Symbol> functionSymbols = filterSymbolsByType(symbols, "FUNC"); + ``` + +### `filterSymbolsByVisibility(const std::vector<Symbol>& symbols, const std::string& visibility) -> std::vector<Symbol>` + +Filters symbols by their visibility. + +- **Parameters:** + + - `symbols`: The vector of symbols to filter. + - `visibility`: The visibility of symbols to filter by. + +- **Returns:** A vector of filtered `Symbol` objects. + +- **Example:** + ```cpp + std::vector<Symbol> globalSymbols = filterSymbolsByVisibility(symbols, "GLOBAL"); + ``` + +### `filterSymbolsByBind(const std::vector<Symbol>& symbols, const std::string& bind) -> std::vector<Symbol>` + +Filters symbols by their bind type. + +- **Parameters:** + + - `symbols`: The vector of symbols to filter. + - `bind`: The bind type of symbols to filter by. + +- **Returns:** A vector of filtered `Symbol` objects. + +- **Example:** + ```cpp + std::vector<Symbol> weakSymbols = filterSymbolsByBind(symbols, "WEAK"); + ``` + +### `filterSymbolsByCondition(const std::vector<Symbol>& symbols, const std::function<bool(const Symbol&)>& condition) -> std::vector<Symbol>` + +Filters symbols based on a custom condition. + +- **Parameters:** + + - `symbols`: The vector of symbols to filter. + - `condition`: A function that returns `true` for symbols that should be included. + +- **Returns:** A vector of filtered `Symbol` objects. + +- **Example:** + ```cpp + auto isGlobalFunction = [](const Symbol& symbol) { + return symbol.type == "FUNC" && symbol.visibility == "GLOBAL"; + }; + std::vector<Symbol> globalFunctions = filterSymbolsByCondition(symbols, isGlobalFunction); + ``` + +### `printSymbolStatistics(const std::vector<Symbol>& symbols)` + +Prints statistics about the symbols, such as the count of each symbol type. + +- **Parameters:** + + - `symbols`: The vector of symbols to analyze. + +- **Example:** + ```cpp + printSymbolStatistics(symbols); + ``` + +### `exportSymbolsToFile(const std::vector<Symbol>& symbols, const std::string& filename)` + +Exports symbols to a CSV file. 
+ +- **Parameters:** + + - `symbols`: The vector of symbols to export. + - `filename`: The name of the CSV file to create. + +- **Example:** + ```cpp + exportSymbolsToFile(symbols, "symbols.csv"); + ``` + +### `exportSymbolsToJson(const std::vector<Symbol>& symbols, const std::string& filename)` + +Exports symbols to a JSON file. + +- **Parameters:** + + - `symbols`: The vector of symbols to export. + - `filename`: The name of the JSON file to create. + +- **Example:** + ```cpp + exportSymbolsToJson(symbols, "symbols.json"); + ``` + +### `exportSymbolsToYaml(const std::vector<Symbol>& symbols, const std::string& filename)` + +Exports symbols to a YAML file. + +- **Parameters:** + + - `symbols`: The vector of symbols to export. + - `filename`: The name of the YAML file to create. + +- **Example:** + ```cpp + exportSymbolsToYaml(symbols, "symbols.yaml"); + ``` + +### `analyzeLibrary(const std::string& libraryPath, const std::string& outputFormat, int threadCount)` + +Analyzes the library and exports the symbols to the specified format. + +- **Parameters:** + + - `libraryPath`: The path to the library file. + - `outputFormat`: The format to export the symbols (csv, json, yaml). + - `threadCount`: The number of threads to use for parsing. + +- **Example:** + ```cpp + analyzeLibrary("/path/to/library", "json", 4); + ``` + +### `main(int argc, char* argv[]) -> int` + +The main function that initializes the application, parses command-line arguments, and starts the library analysis. + +- **Parameters:** + + - `argc`: The number of command-line arguments. + - `argv`: The array of command-line arguments. + +- **Returns:** An integer representing the exit status. 
+ +- **Example:** + ```bash + ./symbol_analyzer /path/to/library json 4 + ``` + +## Usage Example + +```cpp +#include "symbol.hpp" + +int main(int argc, char* argv[]) { + loguru::init(argc, argv); + LOG_F(INFO, "Symbol Analyzer application started."); + + if (argc < 3 || argc > 4) { + LOG_F(ERROR, "Invalid number of arguments."); + LOG_F(ERROR, + "Usage: {} <library_path> <output_format> " + "[thread_count]", + argv[0]); + std::cerr << "Usage: " << argv[0] + << " <library_path> <output_format> " + "[thread_count]" + << std::endl; + return EXIT_FAILURE; + } + + std::string libraryPath = argv[1]; + std::string outputFormat = argv[2]; + int threadCount = static_cast<int>( + std::thread::hardware_concurrency()); // Default to system's thread + // count + + if (argc == 4) { + try { + threadCount = std::stoi(argv[3]); + if (threadCount <= 0) { + LOG_F(ERROR, "Thread count must be a positive integer."); + std::cerr << "Error: Thread count must be a positive integer." + << std::endl; + return EXIT_FAILURE; + } + LOG_F(INFO, "Using user-specified thread count: {}", threadCount); + } catch (const std::invalid_argument& e) { + LOG_F(ERROR, "Invalid thread count provided: {}", argv[3]); + std::cerr + << "Error: Invalid thread count provided. Must be an integer." + << std::endl; + return EXIT_FAILURE; + } catch (const std::out_of_range& e) { + LOG_F(ERROR, "Thread count out of range: {}", argv[3]); + std::cerr << "Error: Thread count out of range." 
<< std::endl; + return EXIT_FAILURE; + } + } + + LOG_F(INFO, "Library Path: {}", libraryPath); + LOG_F(INFO, "Output Format: {}", outputFormat); + LOG_F(INFO, "Thread Count: {}", threadCount); + + try { + analyzeLibrary(libraryPath, outputFormat, threadCount); + } catch (const atom::error::Exception& e) { + LOG_F(ERROR, "Atom Exception: {}", e.what()); + std::cerr << "Atom Exception: " << e.what() << std::endl; + return EXIT_FAILURE; + } catch (const std::exception& e) { + LOG_F(ERROR, "Standard Exception: {}", e.what()); + std::cerr << "Standard Exception: " << e.what() << std::endl; + return EXIT_FAILURE; + } catch (...) { + LOG_F(ERROR, "Unknown exception occurred."); + std::cerr << "Error: Unknown exception occurred." << std::endl; + return EXIT_FAILURE; + } + + LOG_F(INFO, "Symbol Analyzer application terminated successfully."); + return EXIT_SUCCESS; +} +``` + +## Notes + +- The tool uses multithreading to improve performance, especially when dealing with large libraries. +- The `readelf` command is used to extract symbol information, which is then parsed and processed. +- Symbols can be filtered by type, visibility, bind, or a custom condition. +- The tool supports exporting symbols to CSV, JSON, and YAML formats. +- Logging is used extensively to provide detailed information about the operations being performed. 
diff --git a/modules/lithium.image/pymodule.cpp b/modules/lithium.image/pymodule.cpp index d7031770..fea665c0 100644 --- a/modules/lithium.image/pymodule.cpp +++ b/modules/lithium.image/pymodule.cpp @@ -8,8 +8,11 @@ #include "fwhm.hpp" #include "hfr.hpp" #include "hist.hpp" +#include "imgio.hpp" +#include "imgutils.hpp" #include "ndarray_converter.hpp" #include "stretch.hpp" +#include "thumbhash.hpp" #include #include @@ -939,4 +942,364 @@ PYBIND11_MODULE(base64, m) { ValueError: If blockSize is less than 1 )pbdoc", py::arg("img"), py::arg("blockSize") = 16); + + // Bind YCbCr struct + py::class_(m, "YCbCr") + .def(py::init<>()) + .def_readwrite("y", &YCbCr::y) + .def_readwrite("cb", &YCbCr::cb) + .def_readwrite("cr", &YCbCr::cr); + + // Bind dct function + m.def( + "dct", + [](const py::array_t& input, py::array_t& output) { + cv::Mat input_mat = numpyToMat(input); + cv::Mat output_mat; + dct(input_mat, output_mat); + output = matToNumpy(output_mat); + }, + R"pbdoc( + Perform Discrete Cosine Transform (DCT) on the input image + Parameters: + input (numpy.ndarray): Input image matrix + output (numpy.ndarray): Output matrix to store the DCT result + )pbdoc"); + + // Bind rgbToYCbCr function + m.def( + "rgbToYCbCr", + [](const py::array_t& rgb) { + cv::Vec rgb_vec = + *reinterpret_cast*>(rgb.data()); + return rgbToYCbCr(rgb_vec); + }, + R"pbdoc( + Convert an RGB color to YCbCr color space + Parameters: + rgb (numpy.ndarray): Input RGB color + Returns: + YCbCr: The YCbCr color space values + )pbdoc"); + + // Bind encodeThumbHash function + m.def( + "encodeThumbHash", + [](const py::array_t& image) { + cv::Mat mat = numpyToMat(image); + return encodeThumbHash(mat); + }, + R"pbdoc( + Encode an image into a ThumbHash + Parameters: + image (numpy.ndarray): Input image to be encoded + Returns: + List[float]: Encoded ThumbHash + )pbdoc"); + + // Bind decodeThumbHash function + m.def( + "decodeThumbHash", + [](const std::vector& thumbHash, int width, int height) { + cv::Mat 
result = decodeThumbHash(thumbHash, width, height); + return matToNumpy(result); + }, + R"pbdoc( + Decode a ThumbHash into an image + Parameters: + thumbHash (List[float]): Encoded ThumbHash data + width (int): Width of the output thumbnail image + height (int): Height of the output thumbnail image + Returns: + numpy.ndarray: Decoded thumbnail image + )pbdoc"); + + // Bind insideCircle function + m.def("insideCircle", &insideCircle, R"pbdoc( + Check if a point is inside a circle + Parameters: + xCoord (int): X coordinate of the point + yCoord (int): Y coordinate of the point + centerX (int): X coordinate of the circle center + centerY (int): Y coordinate of the circle center + radius (float): Radius of the circle + Returns: + bool: True if the point is inside the circle, otherwise False + )pbdoc"); + + // Bind checkElongated function + m.def("checkElongated", &checkElongated, R"pbdoc( + Check if a rectangle is elongated + Parameters: + width (int): Width of the rectangle + height (int): Height of the rectangle + Returns: + bool: True if the rectangle is elongated, otherwise False + )pbdoc"); + + // Bind checkWhitePixel function + m.def( + "checkWhitePixel", + [](const py::array_t& rect_contour, int x_coord, int y_coord) { + cv::Mat mat = numpyToMat(rect_contour); + return checkWhitePixel(mat, x_coord, y_coord); + }, + R"pbdoc( + Check if a pixel is white + Parameters: + rect_contour (numpy.ndarray): Input image + x_coord (int): X coordinate of the pixel + y_coord (int): Y coordinate of the pixel + Returns: + int: 1 if the pixel is white, otherwise 0 + )pbdoc"); + + // Bind checkEightSymmetryCircle function + m.def( + "checkEightSymmetryCircle", + [](const py::array_t& rect_contour, const cv::Point& center, + int x_p, int y_p) { + cv::Mat mat = numpyToMat(rect_contour); + return checkEightSymmetryCircle(mat, center, x_p, y_p); + }, + R"pbdoc( + Check eight symmetry of a circle + Parameters: + rect_contour (numpy.ndarray): Input image + center (cv::Point): Center of 
the circle + x_p (int): X coordinate of the point + y_p (int): Y coordinate of the point + Returns: + int: Symmetry score + )pbdoc"); + + // Bind checkFourSymmetryCircle function + m.def( + "checkFourSymmetryCircle", + [](const py::array_t& rect_contour, const cv::Point& center, + float radius) { + cv::Mat mat = numpyToMat(rect_contour); + return checkFourSymmetryCircle(mat, center, radius); + }, + R"pbdoc( + Check four symmetry of a circle + Parameters: + rect_contour (numpy.ndarray): Input image + center (cv::Point): Center of the circle + radius (float): Radius of the circle + Returns: + int: Symmetry score + )pbdoc"); + + // Bind defineNarrowRadius function + m.def("defineNarrowRadius", &defineNarrowRadius, R"pbdoc( + Define narrow radius + Parameters: + min_area (int): Minimum area + max_area (float): Maximum area + area (float): Area + scale (float): Scale + Returns: + tuple: A tuple containing the radius, a vector of radii, and a vector of scales + )pbdoc"); + + // Bind checkBresenhamCircle function + m.def( + "checkBresenhamCircle", + [](const py::array_t& rect_contour, float radius, + float pixel_ratio, bool if_debug = false) { + cv::Mat mat = numpyToMat(rect_contour); + return checkBresenhamCircle(mat, radius, pixel_ratio, if_debug); + }, + R"pbdoc( + Check Bresenham circle + Parameters: + rect_contour (numpy.ndarray): Input image + radius (float): Radius of the circle + pixel_ratio (float): Pixel ratio + if_debug (bool): Debug flag + Returns: + bool: True if the circle is valid, otherwise False + )pbdoc"); + + // Bind calculateAverageDeviation function + m.def( + "calculateAverageDeviation", + [](double mid, const py::array_t& norm_img) { + cv::Mat mat = numpyToMat(norm_img); + return calculateAverageDeviation(mid, mat); + }, + R"pbdoc( + Calculate average deviation + Parameters: + mid (float): Mid value + norm_img (numpy.ndarray): Normalized image + Returns: + float: Average deviation + )pbdoc"); + + // Bind calculateMTF function + m.def( + 
"calculateMTF", + [](double magnitude, const py::array_t& img) { + cv::Mat mat = numpyToMat(img); + cv::Mat result = calculateMTF(magnitude, mat); + return matToNumpy(result); + }, + R"pbdoc( + Calculate MTF + Parameters: + magnitude (float): Magnitude + img (numpy.ndarray): Input image + Returns: + numpy.ndarray: MTF image + )pbdoc"); + + // Bind calculateScale function + m.def( + "calculateScale", + [](const py::array_t& img, int resize_size = 1552) { + cv::Mat mat = numpyToMat(img); + return calculateScale(mat, resize_size); + }, + R"pbdoc( + Calculate scale + Parameters: + img (numpy.ndarray): Input image + resize_size (int): Resize size + Returns: + float: Scale + )pbdoc"); + + // Bind calculateMedianDeviation function + m.def( + "calculateMedianDeviation", + [](double mid, const py::array_t& img) { + cv::Mat mat = numpyToMat(img); + return calculateMedianDeviation(mid, mat); + }, + R"pbdoc( + Calculate median deviation + Parameters: + mid (float): Mid value + img (numpy.ndarray): Input image + Returns: + float: Median deviation + )pbdoc"); + + // Bind computeParamsOneChannel function + m.def( + "computeParamsOneChannel", + [](const py::array_t& img) { + cv::Mat mat = numpyToMat(img); + auto result = computeParamsOneChannel(mat); + return py::make_tuple(std::get<0>(result), std::get<1>(result), + std::get<2>(result)); + }, + R"pbdoc( + Compute parameters for one channel + Parameters: + img (numpy.ndarray): Input image + Returns: + tuple: A tuple containing the parameters + )pbdoc"); + + // Bind autoWhiteBalance function + m.def( + "autoWhiteBalance", + [](const py::array_t& img) { + cv::Mat mat = numpyToMat(img); + cv::Mat result = autoWhiteBalance(mat); + return matToNumpy(result); + }, + R"pbdoc( + Perform automatic white balance + Parameters: + img (numpy.ndarray): Input image + Returns: + numpy.ndarray: White-balanced image + )pbdoc"); + + // Bind loadImage function + m.def( + "loadImage", + [](const std::string& filename, int flags = 1) { + cv::Mat mat = 
loadImage(filename, flags); + return matToNumpy(mat); + }, + R"pbdoc( + Load a single image + Parameters: + filename (str): Path to the image file + flags (int): Flags for image loading + Returns: + numpy.ndarray: Loaded image + )pbdoc", + py::arg("filename"), py::arg("flags") = 1); + + // Bind loadImages function + m.def( + "loadImages", + [](const std::string& folder, + const std::vector& filenames = {}, int flags = 1) { + std::vector> images = + loadImages(folder, filenames, flags); + std::vector>> result; + result.reserve(images.size()); + for (const auto& [name, mat] : images) { + result.emplace_back(name, matToNumpy(mat)); + } + return result; + }, + R"pbdoc( + Load all images from a folder + Parameters: + folder (str): Path to the folder + filenames (List[str]): List of filenames to load + flags (int): Flags for image loading + Returns: + List[Tuple[str, numpy.ndarray]]: List of loaded images with their filenames + )pbdoc", + py::arg("folder"), py::arg("filenames") = std::vector{}, + py::arg("flags") = 1); + + // Bind saveImage function + m.def("saveImage", &saveImage, R"pbdoc( + Save an image to a file + Parameters: + filename (str): Path to the output file + image (numpy.ndarray): Image to save + Returns: + bool: True if the image was saved successfully, otherwise False + )pbdoc"); + + // Bind saveMatTo8BitJpg function + m.def("saveMatTo8BitJpg", &saveMatTo8BitJpg, R"pbdoc( + Save a cv::Mat image to an 8-bit JPG file + Parameters: + image (numpy.ndarray): Image to save + output_path (str): Path to the output file + Returns: + bool: True if the image was saved successfully, otherwise False + )pbdoc"); + + // Bind saveMatTo16BitPng function + m.def("saveMatTo16BitPng", &saveMatTo16BitPng, R"pbdoc( + Save a cv::Mat image to a 16-bit PNG file + Parameters: + image (numpy.ndarray): Image to save + output_path (str): Path to the output file + Returns: + bool: True if the image was saved successfully, otherwise False + )pbdoc"); + + // Bind saveMatToFits function 
+ m.def("saveMatToFits", &saveMatToFits, R"pbdoc( + Save a cv::Mat image to a FITS file + Parameters: + image (numpy.ndarray): Image to save + output_path (str): Path to the output file + Returns: + bool: True if the image was saved successfully, otherwise False + )pbdoc"); } \ No newline at end of file diff --git a/modules/lithium.pyimage/image/adaptive_stretch/stretch.py b/modules/lithium.pyimage/image/adaptive_stretch/stretch.py index 5de78bbe..b1db3119 100644 --- a/modules/lithium.pyimage/image/adaptive_stretch/stretch.py +++ b/modules/lithium.pyimage/image/adaptive_stretch/stretch.py @@ -1,8 +1,32 @@ +import concurrent.futures from dataclasses import dataclass, field -from loguru import logger -from typing import Optional, Tuple +from enum import Enum +from pathlib import Path +from typing import Optional, Tuple, Literal, List import cv2 import numpy as np +from loguru import logger +import argparse +import sys + + +# Configure Loguru logger with file rotation and different log levels +logger.remove()  # Remove the default logger +logger.add(sys.stderr, level="INFO", + format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}") +logger.add("adaptive_stretch.log", rotation="10 MB", retention="10 days", + level="DEBUG", format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}") + + +class ImageFormat(Enum): + PNG = "png" + JPEG = "jpg" + TIFF = "tiff" + BMP = "bmp" + + @staticmethod + def list(): + return [fmt.value for fmt in ImageFormat] @dataclass @@ -10,14 +34,23 @@ class AdaptiveStretch: noise_threshold: float = 1e-4 contrast_protection: Optional[float] = None max_curve_points: int = 106 + roi: Optional[Tuple[int, int, int, int]] = None  # (x, y, width, height) + save_intermediate: bool = False  # Whether to save intermediate results + intermediate_dir: Optional[Path] = None  # Directory for intermediate results def __post_init__(self): """ Initialize the logger and log the initialization parameters. 
""" - logger.add("adaptive_stretch.log", rotation="1 MB") - logger.info("AdaptiveStretch initialized with noise_threshold={}, contrast_protection={}, max_curve_points={}", - self.noise_threshold, self.contrast_protection, self.max_curve_points) + if self.save_intermediate: + if self.intermediate_dir is None: + self.intermediate_dir = Path("intermediate_results") + self.intermediate_dir.mkdir(parents=True, exist_ok=True) + logger.debug( + f"Intermediate results will be saved to {self.intermediate_dir}") + logger.info("AdaptiveStretch initialized with noise_threshold={}, contrast_protection={}, max_curve_points={}, roi={}, save_intermediate={}, intermediate_dir={}", + self.noise_threshold, self.contrast_protection, self.max_curve_points, + self.roi, self.save_intermediate, self.intermediate_dir) def compute_brightness_diff(self, image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: """ @@ -33,12 +66,11 @@ def compute_brightness_diff(self, image: np.ndarray) -> Tuple[np.ndarray, np.nda diff_y = np.pad(diff_y, ((0, 1), (0, 0)), mode='constant') return diff_x, diff_y - def stretch(self, image: np.ndarray, roi: Optional[Tuple[int, int, int, int]] = None) -> np.ndarray: + def stretch(self, image: np.ndarray) -> np.ndarray: """ Apply adaptive stretch transformation to the image. :param image: Input image as a numpy array (grayscale or color). - :param roi: Optional region of interest defined as a tuple (x, y, width, height). :return: Stretched image as a numpy array. 
""" logger.info("Starting stretch operation") @@ -53,9 +85,10 @@ def stretch(self, image: np.ndarray, roi: Optional[Tuple[int, int, int, int]] = logger.debug(f"Processing channel {idx}") channel = channel.astype(np.float32) / 255.0 - if roi: - x, y, w, h = roi + if self.roi: + x, y, w, h = self.roi channel_roi = channel[y:y+h, x:x+w] + logger.debug(f"Applying ROI: x={x}, y={y}, w={w}, h={h}") else: channel_roi = channel @@ -72,6 +105,8 @@ def stretch(self, image: np.ndarray, roi: Optional[Tuple[int, int, int, int]] = if self.contrast_protection is not None: transformation_curve = np.clip( transformation_curve, -self.contrast_protection, self.contrast_protection) + logger.debug( + f"Applied contrast protection: {self.contrast_protection}") resampled_curve = cv2.resize( transformation_curve, (self.max_curve_points, 1), interpolation=cv2.INTER_LINEAR) @@ -82,12 +117,21 @@ def stretch(self, image: np.ndarray, roi: Optional[Tuple[int, int, int, int]] = stretched_channel = np.clip( stretched_channel * 255, 0, 255).astype(np.uint8) - if roi: + if self.roi: channel[y:y+h, x:x+w] = stretched_channel stretched_channel = channel + logger.debug( + f"Replaced ROI with stretched channel for channel {idx}") stretched_channels.append(stretched_channel) + if self.save_intermediate and self.intermediate_dir: + intermediate_path = self.intermediate_dir / \ + f"channel_{idx}_stretched.png" + cv2.imwrite(str(intermediate_path), stretched_channel) + logger.debug( + f"Saved intermediate stretched channel to {intermediate_path}") + if len(stretched_channels) > 1: stretched_image = cv2.merge(stretched_channels) else: @@ -95,3 +139,163 @@ def stretch(self, image: np.ndarray, roi: Optional[Tuple[int, int, int, int]] = logger.info("Stretch operation completed") return stretched_image + + def stretch_channel(self, o: np.ndarray, s: np.ndarray, c: int, bg: float, sigma: float, median: float, mad: float) -> None: + """ + Stretch a single channel of the image. + + :param o: Original image. 
+ :param s: Stretched image. + :param c: Channel index. + :param bg: Background level. + :param sigma: Sigma value for clipping. + :param median: Median value of the channel. + :param mad: Median absolute deviation of the channel. + """ + o_channel = o[:, :, c] + s_channel = s[:, :, c] + + shadow_clipping = np.clip(median - sigma * mad, 0, 1.0) + highlight_clipping = 1.0 + + midtone = self.MTF((median - shadow_clipping) / + (highlight_clipping - shadow_clipping), bg) + + o_channel[o_channel <= shadow_clipping] = 0.0 + o_channel[o_channel >= highlight_clipping] = 1.0 + + s_channel[s_channel <= shadow_clipping] = 0.0 + s_channel[s_channel >= highlight_clipping] = 1.0 + + indx_inside_o = np.logical_and( + o_channel > shadow_clipping, o_channel < highlight_clipping) + indx_inside_s = np.logical_and( + s_channel > shadow_clipping, s_channel < highlight_clipping) + + o_channel[indx_inside_o] = ( + o_channel[indx_inside_o] - shadow_clipping) / (highlight_clipping - shadow_clipping) + s_channel[indx_inside_s] = ( + s_channel[indx_inside_s] - shadow_clipping) / (highlight_clipping - shadow_clipping) + + o_channel = self.MTF(o_channel, midtone) + s_channel = self.MTF(s_channel, midtone) + o[:, :, c] = o_channel[:, :] + s[:, :, c] = s_channel[:, :] + + def stretch_mtf(self, o: np.ndarray, s: np.ndarray, bg: float, sigma: float, median: List[float], mad: List[float]) -> Tuple[np.ndarray, np.ndarray]: + """ + Stretch the image using the specified parameters. Named stretch_mtf so + it does not redefine stretch() above. + + :param o: Original image. + :param s: Stretched image. + :param bg: Background level. + :param sigma: Sigma value for clipping. + :param median: List of median values for each channel. + :param mad: List of median absolute deviations for each channel. + :return: Tuple of stretched images.
+ """ + o_copy = np.copy(o) + s_copy = np.copy(s) + + for c in range(o_copy.shape[-1]): + self.stretch_channel(o_copy, s_copy, c, bg, + sigma, median[c], mad[c]) + + return o_copy, s_copy + + def MTF(self, data: np.ndarray, midtone: float) -> np.ndarray: + """ + Apply midtone transfer function (MTF) to the data. + + :param data: Input data. + :param midtone: Midtone value. + :return: Transformed data. + """ + if isinstance(data, np.ndarray): + data[:] = (midtone - 1) * data[:] / \ + ((2 * midtone - 1) * data[:] - midtone) + else: + data = (midtone - 1) * data / ((2 * midtone - 1) * data - midtone) + + return data + + def MTF_inverse(self, data: np.ndarray, midtone: float) -> np.ndarray: + """ + Apply inverse midtone transfer function (MTF) to the data. + + :param data: Input data. + :param midtone: Midtone value. + :return: Transformed data. + """ + if isinstance(data, np.ndarray): + data[:] = midtone * data[:] / \ + ((2 * midtone - 1) * data[:] - (midtone - 1)) + else: + data = midtone * data / ((2 * midtone - 1) * data - (midtone - 1)) + + return data + + +def parse_arguments() -> argparse.Namespace: + """ + Parse command-line arguments for the stretch script. + + :return: Parsed arguments. 
+ """ + parser = argparse.ArgumentParser( + description="Adaptive Image Stretching Tool") + parser.add_argument('--input', type=Path, required=True, + help='Path to the input image.') + parser.add_argument('--output', type=Path, required=True, + help='Path to save the stretched image.') + parser.add_argument('--noise_threshold', type=float, + default=1e-4, help='Noise threshold for stretching.') + parser.add_argument('--contrast_protection', type=float, + default=None, help='Contrast protection limit.') + parser.add_argument('--max_curve_points', type=int, + default=106, help='Maximum number of curve points.') + parser.add_argument('--roi', type=int, nargs=4, metavar=('X', 'Y', 'W', 'H'), default=None, + help='Region of interest as four integers: x y width height.') + parser.add_argument('--save_intermediate', action='store_true', + help='Save intermediate stretched channels.') + parser.add_argument('--intermediate_dir', type=Path, default=None, + help='Directory to save intermediate results.') + + return parser.parse_args() + + +def main(): + """ + Main function to parse arguments and execute the stretch operation. 
+ """ + args = parse_arguments() + + if not args.input.exists(): + logger.error(f"Input file does not exist: {args.input}") + sys.exit(1) + + image = cv2.imread(str(args.input), cv2.IMREAD_UNCHANGED) + if image is None: + logger.error(f"Failed to load image: {args.input}") + sys.exit(1) + + stretcher = AdaptiveStretch( + noise_threshold=args.noise_threshold, + contrast_protection=args.contrast_protection, + max_curve_points=args.max_curve_points, + roi=tuple(args.roi) if args.roi else None, + save_intermediate=args.save_intermediate, + intermediate_dir=args.intermediate_dir + ) + + stretched_image = stretcher.stretch(image) + + success = cv2.imwrite(str(args.output), stretched_image) + if success: + logger.info(f"Stretched image saved to: {args.output}") + else: + logger.error(f"Failed to save stretched image to: {args.output}") + + +if __name__ == "__main__": + main() diff --git a/modules/lithium.pyimage/image/api/strecth_count.py b/modules/lithium.pyimage/image/api/strecth_count.py new file mode 100644 index 00000000..3bb6e961 --- /dev/null +++ b/modules/lithium.pyimage/image/api/strecth_count.py @@ -0,0 +1,362 @@ +import argparse +import json +from dataclasses import dataclass, field +from pathlib import Path +from typing import Dict, Optional, Tuple, List + +import cv2 +import numpy as np +from astropy.io import fits +from concurrent.futures import ThreadPoolExecutor +from loguru import logger +from scipy import ndimage +import yaml + + +@dataclass +class ImageProcessingConfig: + """Configuration parameters for image processing.""" + remove_hot_pixels: bool = False + denoise: bool = False + equalize_histogram: bool = False + apply_clahe: bool = False + clahe_clip_limit: float = 2.0 + clahe_tile_grid_size: Tuple[int, int] = (8, 8) + unsharp_mask: bool = False + unsharp_amount: float = 1.5 + adjust_gamma: bool = False + gamma_value: float = 1.0 + apply_gaussian_blur: bool = False + gaussian_kernel_size: int = 5 + do_stretch: bool = False + do_star_count: bool = 
False + do_star_mark: bool = False + resize_size: int = 2048 + bayer_type: Optional[str] = None + config_file: Optional[Path] = None + save_jpg: bool = False + jpg_file: Optional[Path] = None + save_star_data: bool = False + star_file: Optional[Path] = None + + +class ImageProcessor: + """Handles image processing tasks including loading, enhancing, and saving images.""" + + def __init__(self, config: ImageProcessingConfig) -> None: + """ + Initialize the ImageProcessor with the given configuration. + + :param config: ImageProcessingConfig object containing processing settings. + """ + self.config = config + + # Configure logging with loguru + logger.remove() + logger.add( + "image_processor.log", + rotation="10 MB", + retention="10 days", + level="DEBUG", + format="{time} | {level} | {message}" + ) + logger.debug(f"ImageProcessor initialized with config: {self.config}") + + def debayer_image(self, img: np.ndarray) -> np.ndarray: + """Convert a raw image using the specified Bayer pattern to a color image.""" + logger.debug("Starting debayering process.") + bayer_patterns = { + "rggb": cv2.COLOR_BAYER_RGGB2BGR, + "gbrg": cv2.COLOR_BAYER_GBRG2BGR, + "bggr": cv2.COLOR_BAYER_BGGR2BGR, + "grbg": cv2.COLOR_BAYER_GRBG2BGR + } + pattern = self.config.bayer_type.lower() if self.config.bayer_type else 'rggb' + converted_img = cv2.cvtColor( + img, bayer_patterns.get(pattern, cv2.COLOR_BAYER_RGGB2BGR)) + logger.debug("Debayering completed.") + return converted_img + + def resize_image(self, img: np.ndarray) -> np.ndarray: + """Resize the image to the target size while maintaining aspect ratio.""" + logger.debug("Starting image resizing.") + scale = min(self.config.resize_size / max(img.shape[:2]), 1) + if scale < 1: + resized_img = cv2.resize( + img, None, fx=scale, fy=scale, interpolation=cv2.INTER_AREA) + logger.debug( + f"Image resized to {resized_img.shape[1]}x{resized_img.shape[0]}") + return resized_img + logger.debug("No resizing needed.") + return img + + def 
normalize_image(self, img: np.ndarray) -> np.ndarray: + """Normalize the image to 8-bit if it's not already.""" + logger.debug("Starting image normalization.") + if img.dtype != np.uint8: + normalized_img = cv2.normalize( + img, None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX) + normalized_img = normalized_img.astype(np.uint8) + logger.debug("Image normalization completed.") + return normalized_img + logger.debug("Image is already in 8-bit format.") + return img + + def stretch_image(self, img: np.ndarray, is_color: bool) -> np.ndarray: + """Apply stretching to the image.""" + logger.debug("Starting image stretching.") + # Implement stretching logic here + # Placeholder implementation + stretched_img = img # Replace with actual stretching function + logger.debug("Image stretching completed.") + return stretched_img + + def enhance_image(self, img: np.ndarray) -> np.ndarray: + """Apply a series of image enhancements based on the configuration.""" + logger.debug("Starting image enhancement.") + if self.config.remove_hot_pixels: + img = self.remove_hot_pixels(img) + if self.config.denoise: + img = self.denoise_image(img) + if self.config.equalize_histogram: + img = self.equalize_histogram(img) + if self.config.apply_clahe: + img = self.apply_clahe(img) + if self.config.unsharp_mask: + img = self.apply_unsharp_mask(img) + if self.config.adjust_gamma: + img = self.adjust_gamma(img) + if self.config.apply_gaussian_blur: + img = self.apply_gaussian_blur(img) + logger.debug("Image enhancement completed.") + return img + + def remove_hot_pixels(self, img: np.ndarray, threshold: float = 3.0) -> np.ndarray: + """Remove hot pixels using median filter and thresholding.""" + logger.debug("Removing hot pixels from the image.") + median = ndimage.median_filter(img, size=3) + # Cast to float before subtracting to avoid uint8 wraparound + diff = np.abs(img.astype(np.float32) - median.astype(np.float32)) + std_dev = np.std(diff) + mask = diff > (threshold * std_dev) + img[mask] = median[mask] + logger.debug("Hot pixels removed.") + return img + + def denoise_image(self, img: 
np.ndarray, h: float = 10) -> np.ndarray: + """Apply Non-Local Means Denoising.""" + logger.debug("Applying denoising to the image.") + denoised_img = cv2.fastNlMeansDenoisingColored(img, None, h, h, 7, 21) + logger.debug("Denoising completed.") + return denoised_img + + def equalize_histogram(self, img: np.ndarray) -> np.ndarray: + """Apply histogram equalization to improve contrast.""" + logger.debug("Applying histogram equalization.") + if len(img.shape) == 2: + equalized_img = cv2.equalizeHist(img) + else: + ycrcb = cv2.cvtColor(img, cv2.COLOR_BGR2YCrCb) + ycrcb[:, :, 0] = cv2.equalizeHist(ycrcb[:, :, 0]) + equalized_img = cv2.cvtColor(ycrcb, cv2.COLOR_YCrCb2BGR) + logger.debug("Histogram equalization completed.") + return equalized_img + + def apply_clahe(self, img: np.ndarray) -> np.ndarray: + """Apply CLAHE to the image.""" + logger.debug("Applying CLAHE to the image.") + clahe = cv2.createCLAHE( + clipLimit=self.config.clahe_clip_limit, + tileGridSize=self.config.clahe_tile_grid_size + ) + if len(img.shape) == 2: + clahe_img = clahe.apply(img) + else: + lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB) + channels = cv2.split(lab) + channels[0] = clahe.apply(channels[0]) + lab = cv2.merge(channels) + clahe_img = cv2.cvtColor(lab, cv2.COLOR_LAB2BGR) + logger.debug("CLAHE applied.") + return clahe_img + + def apply_unsharp_mask(self, img: np.ndarray) -> np.ndarray: + """Apply unsharp mask to enhance image details.""" + logger.debug("Applying unsharp mask.") + kernel_size = (self.config.gaussian_kernel_size, + self.config.gaussian_kernel_size) + blurred = cv2.GaussianBlur(img, kernel_size, 0) + sharpened = cv2.addWeighted( + img, 1 + self.config.unsharp_amount, blurred, -self.config.unsharp_amount, 0) + logger.debug("Unsharp mask applied.") + return sharpened + + def adjust_gamma(self, img: np.ndarray) -> np.ndarray: + """Adjust image gamma.""" + logger.debug(f"Adjusting gamma with value: {self.config.gamma_value}") + inv_gamma = 1.0 / self.config.gamma_value + table 
= np.array([((i / 255.0) ** inv_gamma) * 255 + for i in np.arange(256)]).astype("uint8") + gamma_img = cv2.LUT(img, table) + logger.debug("Gamma adjustment completed.") + return gamma_img + + def apply_gaussian_blur(self, img: np.ndarray) -> np.ndarray: + """Apply Gaussian blur to reduce noise.""" + logger.debug("Applying Gaussian blur.") + kernel_size = (self.config.gaussian_kernel_size, + self.config.gaussian_kernel_size) + blurred_img = cv2.GaussianBlur(img, kernel_size, 0) + logger.debug("Gaussian blur applied.") + return blurred_img + + def process_image(self, filepath: Path) -> Tuple[Optional[np.ndarray], Dict[str, float]]: + """Process a single image file.""" + logger.info(f"Processing image: {filepath}") + try: + img_data = fits.getdata(filepath, header=True) + if isinstance(img_data, tuple): + img, header = img_data + else: + img = img_data + header = {} + + is_color = 'BAYERPAT' in header + if is_color: + self.config.bayer_type = self.config.bayer_type or header.get( + 'BAYERPAT', 'rggb') + img = self.debayer_image(img) + + img = self.resize_image(img) + img = self.normalize_image(img) + + # Apply image enhancements + img = self.enhance_image(img) + + if self.config.do_stretch: + img = self.stretch_image(img, is_color) + + result = {"star_count": -1, "average_hfr": -1.0, + "max_star": -1.0, "min_star": -1.0, "average_star": -1.0} + if self.config.do_star_count: + img, star_count, avg_hfr, area_range = self.detect_stars(img) + result.update({ + "star_count": star_count, + "average_hfr": avg_hfr, + "max_star": area_range['max'], + "min_star": area_range['min'], + "average_star": area_range['average'] + }) + + if self.config.save_jpg and self.config.jpg_file and img is not None: + cv2.imwrite(str(self.config.jpg_file), img) + logger.info(f"Processed image saved to {self.config.jpg_file}") + + if self.config.save_star_data and self.config.star_file: + with self.config.star_file.open('w') as f: + json.dump(result, f) + logger.info( + f"Star count results saved 
to {self.config.star_file}") + + logger.info("Image processing completed successfully.") + return img, result + + except Exception as e: + logger.error(f"Error processing image {filepath}: {e}") + return None, {"star_count": -1, "average_hfr": -1.0, "max_star": -1.0, "min_star": -1.0, "average_star": -1.0} + + def detect_stars(self, img: np.ndarray) -> Tuple[np.ndarray, int, float, Dict[str, float]]: + """Detect stars in the image and calculate relevant metrics.""" + logger.debug("Starting star detection.") + # Placeholder implementation + # Implement actual star detection logic here + star_count = 0 + average_hfr = 0.0 + area_range = {"max": 0.0, "min": 0.0, "average": 0.0} + logger.debug("Star detection completed.") + return img, star_count, average_hfr, area_range + + def load_config(self) -> None: + """Load configuration from a YAML file if provided.""" + if self.config.config_file and self.config.config_file.is_file(): + logger.info( + f"Loading configuration from {self.config.config_file}") + try: + with self.config.config_file.open('r') as f: + file_config = yaml.safe_load(f) + for key, value in file_config.items(): + if hasattr(self.config, key): + setattr(self.config, key, value) + logger.debug("Configuration loaded from file.") + except Exception as e: + logger.error(f"Error loading configuration file: {e}") + else: + logger.debug( + "No configuration file provided or file does not exist.") + + def main(self, filepath: Path) -> None: + """Main processing function.""" + self.load_config() + self.process_image(filepath) + + +if __name__ == "__main__": + # Parse command-line arguments + parser = argparse.ArgumentParser( + description="Image Enhancement and Star Detection Script") + parser.add_argument("filepath", type=Path, help="Path to the FITS file") + parser.add_argument("--config", type=Path, + help="Path to the configuration YAML file") + parser.add_argument("--resize-size", type=int, default=2048, + help="Target size for resizing the image") + 
parser.add_argument("--jpg-file", type=Path, + help="Path to save the processed image as JPG") + parser.add_argument("--star-file", type=Path, + help="Path to save the star count results as JSON") + parser.add_argument("--remove-hot-pixels", action="store_true", + help="Remove hot pixels in the image") + parser.add_argument("--denoise", action="store_true", + help="Apply denoising to the image") + parser.add_argument("--equalize-histogram", + action="store_true", help="Apply histogram equalization") + parser.add_argument("--apply-clahe", action="store_true", + help="Apply CLAHE to the image") + parser.add_argument("--unsharp-mask", action="store_true", + help="Apply unsharp mask") + parser.add_argument("--adjust-gamma", action="store_true", + help="Adjust gamma of the image") + parser.add_argument("--gamma-value", type=float, + default=1.0, help="Gamma value for adjustment") + parser.add_argument("--apply-gaussian-blur", + action="store_true", help="Apply Gaussian blur") + parser.add_argument("--do-stretch", action="store_true", + help="Apply stretching to the image") + parser.add_argument("--do-star-count", action="store_true", + help="Perform star count on the image") + parser.add_argument("--do-star-mark", action="store_true", + help="Mark detected stars on the image") + args = parser.parse_args() + + # Create configuration + config = ImageProcessingConfig( + remove_hot_pixels=args.remove_hot_pixels, + denoise=args.denoise, + equalize_histogram=args.equalize_histogram, + apply_clahe=args.apply_clahe, + unsharp_mask=args.unsharp_mask, + adjust_gamma=args.adjust_gamma, + gamma_value=args.gamma_value, + apply_gaussian_blur=args.apply_gaussian_blur, + do_stretch=args.do_stretch, + do_star_count=args.do_star_count, + do_star_mark=args.do_star_mark, + resize_size=args.resize_size, + config_file=args.config, + save_jpg=bool(args.jpg_file), + jpg_file=args.jpg_file, + save_star_data=bool(args.star_file), + star_file=args.star_file + ) + + # Initialize and run the 
processor + processor = ImageProcessor(config) + processor.main(args.filepath) diff --git a/pysrc/image/auto_histogram/__init__.py b/modules/lithium.pyimage/image/auto_histogram/__init__.py similarity index 100% rename from pysrc/image/auto_histogram/__init__.py rename to modules/lithium.pyimage/image/auto_histogram/__init__.py diff --git a/modules/lithium.pyimage/image/auto_histogram/histogram.py b/modules/lithium.pyimage/image/auto_histogram/histogram.py new file mode 100644 index 00000000..44d97d75 --- /dev/null +++ b/modules/lithium.pyimage/image/auto_histogram/histogram.py @@ -0,0 +1,273 @@ +import cv2 +import numpy as np +from dataclasses import dataclass +from typing import Optional, List, Tuple, Union +from pathlib import Path +from loguru import logger +from .utils import save_image, load_image + + +@dataclass +class HistogramConfig: + """ + Configuration parameters for histogram-based image processing. + """ + clip_shadow: float = 0.01 + clip_highlight: float = 0.01 + target_median: int = 128 + method: str = 'gamma' # 'gamma', 'logarithmic', 'mtf' + apply_clahe: bool = False + clahe_clip_limit: float = 2.0 + clahe_tile_grid_size: Tuple[int, int] = (8, 8) + apply_noise_reduction: bool = False + noise_reduction_method: str = 'median' # 'median', 'gaussian' + apply_sharpening: bool = False + sharpening_strength: float = 1.0 + batch_process: bool = False + file_list: Optional[List[Union[str, Path]]] = None + output_directory: Optional[Union[str, Path]] = None + + +class HistogramProcessor: + """ + Handles automated histogram transformations and enhancements for images. + """ + + def __init__(self, config: Optional[HistogramConfig] = None) -> None: + """ + Initialize HistogramProcessor with configuration. + + :param config: HistogramConfig object containing processing settings.
+ """ + self.config = config or HistogramConfig() + logger.debug( + f"HistogramProcessor initialized with config: {self.config}") + + def histogram_clipping(self, image: np.ndarray) -> np.ndarray: + """ + Clip the histogram of the image based on shadow and highlight percentages. + + :param image: Input image. + :return: Clipped image. + """ + logger.debug("Starting histogram clipping.") + flat = image.flatten() + low_val = np.percentile(flat, self.config.clip_shadow * 100) + high_val = np.percentile(flat, 100 - self.config.clip_highlight * 100) + clipped_image = np.clip(image, low_val, high_val).astype(np.uint8) + logger.debug("Histogram clipping completed.") + return clipped_image + + def gamma_transformation(self, image: np.ndarray) -> np.ndarray: + """ + Apply gamma transformation to the image. + + :param image: Input image. + :return: Gamma transformed image. + """ + logger.debug("Starting gamma transformation.") + try: + mean_val = np.median(image) + if mean_val == 0: + logger.warning( + "Median value is zero during gamma transformation.") + return image + gamma = np.log(self.config.target_median / 255.0) / \ + np.log(mean_val / 255.0) + gamma_corrected = np.array( + 255 * (image / 255.0) ** gamma, dtype='uint8') + logger.debug("Gamma transformation completed.") + return gamma_corrected + except Exception as e: + logger.error(f"Error during gamma transformation: {e}") + raise + + def logarithmic_transformation(self, image: np.ndarray) -> np.ndarray: + """ + Apply logarithmic transformation to the image. + + :param image: Input image. + :return: Logarithmically transformed image. 
+ """ + logger.debug("Starting logarithmic transformation.") + try: + # Cast to float before the log to avoid uint8 overflow in 1 + image + c = 255 / np.log(1 + float(np.max(image))) + log_transformed = np.array( + c * np.log1p(image.astype(np.float64)), dtype='uint8') + logger.debug("Logarithmic transformation completed.") + return log_transformed + except Exception as e: + logger.error(f"Error during logarithmic transformation: {e}") + raise + + def mtf_transformation(self, image: np.ndarray) -> np.ndarray: + """ + Apply MTF transformation to the image. + + :param image: Input image. + :return: MTF transformed image. + """ + logger.debug("Starting MTF transformation.") + try: + mean_val = np.median(image) + if mean_val == 0: + logger.warning( + "Median value is zero during MTF transformation.") + return image + mtf = self.config.target_median / mean_val + mtf_transformed = np.clip(image * mtf, 0, 255).astype(np.uint8) + logger.debug("MTF transformation completed.") + return mtf_transformed + except Exception as e: + logger.error(f"Error during MTF transformation: {e}") + raise + + def apply_clahe_method(self, image: np.ndarray) -> np.ndarray: + """ + Apply CLAHE to the image. + + :param image: Input image. + :return: CLAHE applied image. + """ + logger.debug("Starting CLAHE.") + try: + clahe = cv2.createCLAHE( + clipLimit=self.config.clahe_clip_limit, + tileGridSize=self.config.clahe_tile_grid_size + ) + if len(image.shape) == 2: + clahe_applied = clahe.apply(image) + else: + channels = cv2.split(image) + clahe_applied = cv2.merge([clahe.apply(ch) for ch in channels]) + logger.debug("CLAHE completed.") + return clahe_applied + except Exception as e: + logger.error(f"Error during CLAHE application: {e}") + raise + + def noise_reduction(self, image: np.ndarray) -> np.ndarray: + """ + Apply noise reduction to the image. + + :param image: Input image. + :return: Noise reduced image.
+ """ + logger.debug("Starting noise reduction.") + try: + if self.config.noise_reduction_method == 'median': + reduced = cv2.medianBlur(image, 3) + elif self.config.noise_reduction_method == 'gaussian': + reduced = cv2.GaussianBlur(image, (3, 3), 0) + else: + logger.error( + f"Unsupported noise reduction method: {self.config.noise_reduction_method}") + raise ValueError( + f"Unsupported noise reduction method: {self.config.noise_reduction_method}") + logger.debug("Noise reduction completed.") + return reduced + except Exception as e: + logger.error(f"Error during noise reduction: {e}") + raise + + def sharpen_image(self, image: np.ndarray) -> np.ndarray: + """ + Sharpen the image. + + :param image: Input image. + :return: Sharpened image. + """ + logger.debug("Starting image sharpening.") + try: + kernel = np.array([[-1, -1, -1], + [-1, 9 + self.config.sharpening_strength, -1], + [-1, -1, -1]]) + sharpened = cv2.filter2D(image, -1, kernel) + logger.debug("Image sharpening completed.") + return sharpened + except Exception as e: + logger.error(f"Error during image sharpening: {e}") + raise + + def process_single_image(self, image: np.ndarray) -> np.ndarray: + """ + Process a single image with the configured transformations. + + :param image: Input image. + :return: Processed image. 
+ """ + logger.info("Processing a single image.") + try: + if self.config.apply_noise_reduction: + image = self.noise_reduction(image) + + image = self.histogram_clipping(image) + + if self.config.method == 'gamma': + image = self.gamma_transformation(image) + elif self.config.method == 'logarithmic': + image = self.logarithmic_transformation(image) + elif self.config.method == 'mtf': + image = self.mtf_transformation(image) + else: + logger.error(f"Invalid method specified: {self.config.method}") + raise ValueError( + f"Invalid method specified: {self.config.method}") + + if self.config.apply_clahe: + image = self.apply_clahe_method(image) + + if self.config.apply_sharpening: + image = self.sharpen_image(image) + + logger.info("Image processing completed successfully.") + return image + except Exception as e: + logger.error(f"Error during image processing: {e}") + raise + + def process_batch_images(self, file_list: List[Union[str, Path]]) -> List[np.ndarray]: + """ + Process a batch of images. + + :param file_list: List of image file paths. + :return: List of processed images. + """ + logger.info(f"Starting batch processing of {len(file_list)} images.") + processed_images = [] + for idx, file_path in enumerate(file_list, start=1): + logger.info( + f"Processing image {idx}/{len(file_list)}: {file_path}") + image = load_image(file_path, grayscale=False) + if image is None: + logger.warning( + f"Skipping image due to load failure: {file_path}") + continue + processed_image = self.process_single_image(image) + processed_images.append(processed_image) + if self.config.output_directory: + output_path = Path(self.config.output_directory) / \ + f"processed_{Path(file_path).name}" + save_image(output_path, processed_image) + logger.info("Batch processing completed.") + return processed_images + + def process(self, image: Optional[np.ndarray] = None) -> Optional[Union[np.ndarray, List[np.ndarray]]]: + """ + Process images based on the configuration. 
+ + :param image: Single image to process. Required if batch_process is False. + :return: Processed single image or list of processed images. + """ + if self.config.batch_process: + if not self.config.file_list: + logger.error( + "File list must be provided for batch processing.") + raise ValueError( + "File list must be provided for batch processing.") + return self.process_batch_images(self.config.file_list) + else: + if image is None: + logger.error("Image must be provided for single processing.") + raise ValueError( + "Image must be provided for single processing.") + return self.process_single_image(image) diff --git a/modules/lithium.pyimage/image/auto_histogram/processing.py b/modules/lithium.pyimage/image/auto_histogram/processing.py new file mode 100644 index 00000000..7f4016e9 --- /dev/null +++ b/modules/lithium.pyimage/image/auto_histogram/processing.py @@ -0,0 +1,93 @@ +import os +from pathlib import Path +from typing import List, Optional, Tuple, Union +from loguru import logger +from .histogram import HistogramProcessor, HistogramConfig +from .utils import save_image, load_image + + +def process_directory( + input_directory: Union[str, Path], + output_directory: Union[str, Path], + method: str = 'gamma', + clip_shadow: float = 0.01, + clip_highlight: float = 0.01, + target_median: int = 128, + apply_clahe: bool = False, + clahe_clip_limit: float = 2.0, + clahe_tile_grid_size: Tuple[int, int] = (8, 8), + apply_noise_reduction: bool = False, + noise_reduction_method: str = 'median', + apply_sharpening: bool = False, + sharpening_strength: float = 1.0, + recursive: bool = False +) -> None: + """ + Process all images in a directory using the HistogramProcessor. + + :param input_directory: Directory containing images to process. + :param output_directory: Directory to save processed images. + :param method: Histogram stretching method ('gamma', 'logarithmic', 'mtf'). + :param clip_shadow: Percentage of shadow pixels to clip.
+ :param clip_highlight: Percentage of highlight pixels to clip. + :param target_median: Target median value for histogram stretching. + :param apply_clahe: Apply CLAHE (Contrast Limited Adaptive Histogram Equalization). + :param clahe_clip_limit: CLAHE clip limit. + :param clahe_tile_grid_size: CLAHE grid size. + :param apply_noise_reduction: Apply noise reduction. + :param noise_reduction_method: Noise reduction method ('median', 'gaussian'). + :param apply_sharpening: Apply image sharpening. + :param sharpening_strength: Strength of sharpening. + :param recursive: Process directories recursively. + """ + try: + input_directory = Path(input_directory) + output_directory = Path(output_directory) + if not input_directory.exists(): + logger.error(f"Input directory does not exist: {input_directory}") + raise FileNotFoundError( + f"Input directory does not exist: {input_directory}") + + output_directory.mkdir(parents=True, exist_ok=True) + logger.info( + f"Processing images from {input_directory} to {output_directory}") + + # Gather image files + if recursive: + file_list = [p for p in input_directory.rglob( + '*') if p.suffix.lower() in ['.jpg', '.jpeg', '.png', '.tif', '.tiff']] + else: + file_list = [p for p in input_directory.glob( + '*') if p.suffix.lower() in ['.jpg', '.jpeg', '.png', '.tif', '.tiff']] + + logger.info(f"Found {len(file_list)} images to process.") + + if not file_list: + logger.warning("No images found to process.") + return + + # Configure HistogramProcessor + config = HistogramConfig( + clip_shadow=clip_shadow, + clip_highlight=clip_highlight, + target_median=target_median, + method=method, + apply_clahe=apply_clahe, + clahe_clip_limit=clahe_clip_limit, + clahe_tile_grid_size=clahe_tile_grid_size, + apply_noise_reduction=apply_noise_reduction, + noise_reduction_method=noise_reduction_method, + apply_sharpening=apply_sharpening, + sharpening_strength=sharpening_strength, + batch_process=True, + file_list=file_list, + 
output_directory=output_directory
+        )
+
+        processor = HistogramProcessor(config=config)
+        processor.process()
+
+        logger.info("Directory processing completed successfully.")
+    except Exception as e:
+        logger.error(f"Error during directory processing: {e}")
+        raise
diff --git a/modules/lithium.pyimage/image/auto_histogram/utils.py b/modules/lithium.pyimage/image/auto_histogram/utils.py
new file mode 100644
index 00000000..579a02c9
--- /dev/null
+++ b/modules/lithium.pyimage/image/auto_histogram/utils.py
@@ -0,0 +1,52 @@
+import cv2
+import numpy as np
+from pathlib import Path
+from typing import Optional, Union
+from loguru import logger
+
+
+def save_image(filepath: Union[str, Path], image: np.ndarray) -> None:
+    """
+    Save an image to the specified filepath.
+
+    :param filepath: Path to save the image.
+    :param image: Image data.
+    """
+    try:
+        filepath = Path(filepath)
+        filepath.parent.mkdir(parents=True, exist_ok=True)
+        success = cv2.imwrite(str(filepath), image)
+        if success:
+            logger.info(f"Image saved successfully at {filepath}")
+        else:
+            logger.error(f"Failed to save image at {filepath}")
+    except Exception as e:
+        logger.error(
+            f"Exception occurred while saving image at {filepath}: {e}")
+        raise
+
+
+def load_image(filepath: Union[str, Path], grayscale: bool = False) -> Optional[np.ndarray]:
+    """
+    Load an image from the specified filepath.
+
+    :param filepath: Path to load the image from.
+    :param grayscale: Load image as grayscale if True.
+    :return: Loaded image or None if failed.
+    """
+    try:
+        filepath = Path(filepath)
+        if not filepath.exists():
+            logger.error(f"File does not exist: {filepath}")
+            return None
+        flags = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR
+        image = cv2.imread(str(filepath), flags)
+        if image is not None:
+            logger.info(f"Image loaded successfully from {filepath}")
+        else:
+            logger.error(f"Failed to load image from {filepath}")
+        return image
+    except Exception as e:
+        logger.error(
+            f"Exception occurred while loading image from {filepath}: {e}")
+        return None
diff --git a/pysrc/image/channel/__init__.py b/modules/lithium.pyimage/image/channel/__init__.py
similarity index 100%
rename from pysrc/image/channel/__init__.py
rename to modules/lithium.pyimage/image/channel/__init__.py
diff --git a/modules/lithium.pyimage/image/channel/combination.py b/modules/lithium.pyimage/image/channel/combination.py
new file mode 100644
index 00000000..6172f1d5
--- /dev/null
+++ b/modules/lithium.pyimage/image/channel/combination.py
@@ -0,0 +1,182 @@
+from pathlib import Path
+from typing import List, Optional, Tuple
+from PIL import Image
+import numpy as np
+from skimage import color
+import cv2
+from loguru import logger
+import argparse
+import sys
+import concurrent.futures
+
+# Configure Loguru logger
+logger.remove()  # Remove the default logger
+logger.add(sys.stderr, level="INFO", format="{time} {level} {message}")
+
+
+def resize_to_match(image: Image.Image, target_size: Tuple[int, int]) -> Image.Image:
+    """Resize the image to match the target size."""
+    logger.debug(f"Resizing image from {image.size} to {target_size}")
+    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the equivalent filter.
+    return image.resize(target_size, Image.LANCZOS)
+
+
+def load_image_as_gray(path: Path) -> Image.Image:
+    """Load an image and convert it to grayscale."""
+    logger.debug(f"Loading grayscale image: {path}")
+    try:
+        return Image.open(path).convert('L')
+    except Exception as e:
+        logger.error(f"Failed to load image {path}: {e}")
+        raise
+
+
+def combine_channels(channels: List[Image.Image],
color_space: str = 'RGB') -> Image.Image: + """Combine three channels into an image with the specified color space.""" + logger.info(f"Combining channels into {color_space} color space") + color_space = color_space.upper() + if color_space == 'RGB': + return Image.merge("RGB", channels) + elif color_space == 'LAB': + lab_image = Image.merge("LAB", channels) + return lab_image.convert('RGB') + elif color_space == 'HSV': + hsv_image = Image.merge("HSV", channels) + return hsv_image.convert('RGB') + elif color_space == 'HSI': + hsi_array = np.dstack([np.array(ch) / 255.0 for ch in channels]) + rgb_array = color.hsv2rgb(hsi_array) # Approximate HSI using HSV + return Image.fromarray((rgb_array * 255).astype(np.uint8)) + elif color_space == 'YUV': + yuv_image = Image.merge("YCbCr", channels) + return yuv_image.convert('RGB') + else: + logger.error(f"Unsupported color space: {color_space}") + raise ValueError(f"Unsupported color space: {color_space}") + + +def channel_combination(src_paths: List[Path], color_space: str = 'RGB') -> Image.Image: + """Load, resize, and combine channel images.""" + if len(src_paths) != 3: + logger.error( + "Three source image paths are required to combine channels.") + raise ValueError( + "Three source image paths are required to combine channels.") + + logger.info(f"Starting channel combination into {color_space} color space") + # Load images + channels = [load_image_as_gray(path) for path in src_paths] + + # Resize to match the first image + base_size = channels[0].size + channels = [resize_to_match(ch, base_size) for ch in channels] + + # Combine channels + combined_image = combine_channels(channels, color_space=color_space) + logger.info("Channel combination completed") + return combined_image + + +def parse_args() -> argparse.Namespace: + """Parse command-line arguments.""" + parser = argparse.ArgumentParser( + description="Combine three channel images into a specified color space image.") + parser.add_argument('color_space', type=str, 
choices=['RGB', 'LAB', 'HSV', 'HSI', 'YUV'],
+                        help="Target color space.")
+    parser.add_argument('src1', type=Path,
+                        help="Path to the first channel image.")
+    parser.add_argument('src2', type=Path,
+                        help="Path to the second channel image.")
+    parser.add_argument('src3', type=Path,
+                        help="Path to the third channel image.")
+    parser.add_argument('-o', '--output', type=Path, default=Path('./combined.png'),
+                        help="Path to save the combined image.")
+    parser.add_argument('-f', '--format', type=str, choices=['PNG', 'JPEG', 'BMP', 'TIFF'],
+                        default='PNG', help="Format of the output image.")
+    parser.add_argument('-b', '--batch', action='store_true',
+                        help="Enable batch processing mode to process multiple image sets in the same directory.")
+    parser.add_argument('-m', '--mapping', nargs=3, metavar=('CH1', 'CH2', 'CH3'),
+                        help="Custom channel mapping (e.g., R G B).")
+    return parser.parse_args()
+
+
+def batch_process(directory: Path, color_space: str, output_dir: Path, img_format: str, mapping: Optional[List[str]] = None):
+    """Batch process multiple sets of channel images in a directory."""
+    logger.info(f"Starting batch processing in directory: {directory}")
+    # Assume each image set consists of three files named <basename>_<channel>.<ext>.
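+    # A hypothetical example of the expected naming convention: with the
+    # default R/G/B suffixes, the set for base name "m31" would consist of
+    # m31_R.png, m31_G.png, m31_B.png, and the combined result would be
+    # saved as m31_RGB.png in the output directory.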
+ channel_suffixes = mapping if mapping else ['R', 'G', 'B'] + r_images = sorted(directory.glob(f'*_{channel_suffixes[0]}.*')) + + def process_image_set(r_img: Path): + basename = r_img.stem[:-len(f'_{channel_suffixes[0]}')] + other_channels = [ + directory / f"{basename}_{channel}{r_img.suffix}" for channel in channel_suffixes[1:]] + if not all(ch.exists() for ch in other_channels): + logger.warning(f"Missing channel images for: {basename}") + return + try: + combined = channel_combination( + [r_img] + other_channels, color_space=color_space) + output_path = output_dir / \ + f"{basename}_{color_space}.{img_format.lower()}" + combined.save(output_path, format=img_format) + logger.info(f"Saved combined image to {output_path}") + except Exception as e: + logger.error(f"Error processing {basename}: {e}") + + output_dir.mkdir(exist_ok=True) + with concurrent.futures.ThreadPoolExecutor() as executor: + executor.map(process_image_set, r_images) + + +def main(): + args = parse_args() + + if args.batch: + # Batch processing mode + src_dir = args.src1.parent # Assume all images are in the same directory + output_dir = Path('./batch_output') + try: + batch_process(src_dir, args.color_space, + output_dir, args.format, args.mapping) + except Exception as e: + logger.exception(f"Error during batch processing: {e}") + sys.exit(1) + else: + # Single image set processing + try: + combined_image = channel_combination([args.src1, args.src2, args.src3], + color_space=args.color_space) + combined_image.save(args.output, format=args.format) + logger.info( + f"Combined image saved to {args.output}, format: {args.format}") + except Exception as e: + logger.exception(f"Error during channel combination: {e}") + sys.exit(1) + + +def library_combine_channels(src_paths: List[Path], color_space: str = 'RGB', output_path: Optional[Path] = None, img_format: str = 'PNG') -> Optional[Path]: + """ + Library function to combine channels and optionally save the image. 
+ + :param src_paths: List of three Paths to the channel images. + :param color_space: Target color space. + :param output_path: Path to save the combined image. If None, the image is not saved. + :param img_format: Format to save the image. + :return: Path to the saved image if output_path is provided, else None. + """ + try: + combined_image = channel_combination( + src_paths, color_space=color_space) + if output_path: + combined_image.save(output_path, format=img_format) + logger.info( + f"Combined image saved to {output_path}, format: {img_format}") + return output_path + return None + except Exception as e: + logger.exception(f"Error in library_combine_channels: {e}") + raise + + +if __name__ == "__main__": + main() diff --git a/modules/lithium.pyimage/image/channel/extraction.py b/modules/lithium.pyimage/image/channel/extraction.py new file mode 100644 index 00000000..0e2d4a52 --- /dev/null +++ b/modules/lithium.pyimage/image/channel/extraction.py @@ -0,0 +1,361 @@ +import cv2 +import numpy as np +from pathlib import Path +from typing import Dict, List, Optional, Tuple +from matplotlib import pyplot as plt +from loguru import logger +import argparse +import sys +import concurrent.futures + +# Configure Loguru logger +logger.remove() # Remove the default logger +logger.add(sys.stderr, level="INFO", format="{time} {level} {message}") + + +def extract_channels(image: np.ndarray, color_space: str = 'RGB') -> Dict[str, np.ndarray]: + """ + Extract channels from an image based on the specified color space. + + :param image: Input image in BGR format. + :param color_space: Target color space for channel extraction. + :return: Dictionary of channel names and their corresponding data. 
+    """
+    channels = {}
+    logger.debug(f"Extracting channels using color space: {color_space}")
+
+    try:
+        if color_space.upper() == 'RGB':
+            rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+            channels['R'], channels['G'], channels['B'] = cv2.split(rgb_image)
+
+        elif color_space.upper() == 'XYZ':
+            xyz_image = cv2.cvtColor(image, cv2.COLOR_BGR2XYZ)
+            channels['X'], channels['Y'], channels['Z'] = cv2.split(xyz_image)
+
+        elif color_space.upper() == 'LAB':
+            lab_image = cv2.cvtColor(image, cv2.COLOR_BGR2Lab)
+            channels['L*'], channels['a*'], channels['b*'] = cv2.split(
+                lab_image)
+
+        elif color_space.upper() == 'LCH':
+            lab_image = cv2.cvtColor(image, cv2.COLOR_BGR2Lab)
+            L, a, b = cv2.split(lab_image)
+            # a* and b* are stored offset by 128 in 8-bit Lab images, and
+            # cv2.cartToPolar returns (magnitude, angle) = (chroma, hue).
+            C, H = cv2.cartToPolar(a.astype(np.float32) - 128,
+                                   b.astype(np.float32) - 128)
+            channels['L*'] = L
+            channels['C*'] = C
+            channels['H*'] = H
+
+        elif color_space.upper() == 'HSV':
+            hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
+            channels['H'], channels['S'], channels['V'] = cv2.split(hsv_image)
+
+        elif color_space.upper() == 'HSI':
+            hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
+            H, S, V = cv2.split(hsv_image)
+            # Approximate intensity with the HSV value channel
+            I = V.copy()
+            channels['H'] = H
+            channels['Si'] = S
+            channels['I'] = I
+
+        elif color_space.upper() == 'YUV':
+            yuv_image = cv2.cvtColor(image, cv2.COLOR_BGR2YUV)
+            channels['Y'], channels['U'], channels['V'] = cv2.split(yuv_image)
+
+        elif color_space.upper() == 'YCBCR':
+            ycbcr_image = cv2.cvtColor(image, cv2.COLOR_BGR2YCrCb)
+            # OpenCV's YCrCb conversion yields channels in Y, Cr, Cb order
+            channels['Y'], channels['Cr'], channels['Cb'] = cv2.split(
+                ycbcr_image)
+
+        elif color_space.upper() == 'HSL':
+            hsl_image = cv2.cvtColor(image, cv2.COLOR_BGR2HLS)
+            # OpenCV's HLS conversion yields channels in H, L, S order
+            channels['H'], channels['L'], channels['S'] = cv2.split(hsl_image)
+
+        elif color_space.upper() == 'CMYK':
+            # Approximate CMYK by converting to CMY and inferring K
+            cmy_image = 255 - image
+            C, M, Y = cv2.split(cmy_image)
+            K = np.minimum(C, np.minimum(M, Y))
+            channels['C'] = C
+            channels['M'] = M
+            channels['Y'] = Y
+            channels['K'] = K
+
+        else:
+
logger.error(f"Unsupported color space: {color_space}") + raise ValueError(f"Unsupported color space: {color_space}") + + logger.info( + f"Successfully extracted channels for color space: {color_space}") + except Exception as e: + logger.error(f"Error extracting channels: {e}") + raise + + return channels + + +def show_histogram(channel_data: np.ndarray, title: str = 'Channel Histogram') -> None: + """ + Display the histogram of a single channel. + + :param channel_data: Numpy array of the channel data. + :param title: Title of the histogram plot. + """ + logger.debug(f"Displaying histogram for: {title}") + plt.figure() + plt.title(title) + plt.xlabel('Pixel Value') + plt.ylabel('Frequency') + plt.hist(channel_data.ravel(), bins=256, range=[ + 0, 256], color='gray', alpha=0.7) + plt.grid(True) + plt.show() + + +def merge_channels(channels: Dict[str, np.ndarray]) -> Optional[np.ndarray]: + """ + Merge channels back into a single image. + + :param channels: Dictionary of channel names and their data. + :return: Merged image or None if insufficient channels. + """ + logger.debug("Merging channels into a single image.") + merged_image = None + channel_list = list(channels.values()) + + try: + if len(channel_list) >= 3: + merged_image = cv2.merge(channel_list[:3]) + elif len(channel_list) == 2: + merged_image = cv2.merge([ + channel_list[0], + channel_list[1], + np.zeros_like(channel_list[0], dtype=channel_list[0].dtype) + ]) + elif len(channel_list) == 1: + merged_image = channel_list[0] + else: + logger.warning("No channels to merge.") + except Exception as e: + logger.error(f"Error merging channels: {e}") + raise + + if merged_image is not None: + logger.info("Channels successfully merged.") + else: + logger.warning("Merged image is None.") + + return merged_image + + +def process_directory(input_dir: Path, output_dir: Path, color_space: str = 'RGB') -> None: + """ + Process all images in a directory: extract channels, display histograms, and save channels. 
+ + :param input_dir: Directory containing input images. + :param output_dir: Directory to save extracted channels. + :param color_space: Color space for channel extraction. + """ + logger.info( + f"Processing directory: {input_dir} with color space: {color_space}") + + if not input_dir.exists() or not input_dir.is_dir(): + logger.error( + f"Input directory does not exist or is not a directory: {input_dir}") + raise NotADirectoryError( + f"Input directory does not exist or is not a directory: {input_dir}") + + output_dir.mkdir(parents=True, exist_ok=True) + supported_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff') + + for image_path in input_dir.iterdir(): + if image_path.suffix.lower() in supported_extensions: + logger.info(f"Processing image: {image_path.name}") + try: + image = cv2.imread(str(image_path)) + if image is None: + logger.warning(f"Failed to read image: {image_path}") + continue + + extracted_channels = extract_channels(image, color_space) + + for channel_name, channel_data in extracted_channels.items(): + save_path = output_dir / \ + f"{image_path.stem}_{channel_name}.png" + cv2.imwrite(str(save_path), channel_data) + logger.info(f"Saved channel {channel_name} to {save_path}") + + # Optionally display histogram + show_histogram( + channel_data, title=f"{image_path.stem} - {channel_name}") + + except Exception as e: + logger.error(f"Error processing image {image_path.name}: {e}") + + +def save_channels(channels: Dict[str, np.ndarray], output_dir: Path, base_name: str = 'output') -> None: + """ + Save extracted channels to the specified directory. + + :param channels: Dictionary of channel names and their data. + :param output_dir: Directory to save the channels. + :param base_name: Base name for the saved channel files. 
+    """
+    logger.debug(
+        f"Saving channels to directory: {output_dir} with base name: {base_name}")
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    for channel_name, channel_data in channels.items():
+        filename = output_dir / f"{base_name}_{channel_name}.png"
+        cv2.imwrite(str(filename), channel_data)
+        logger.info(f"Saved channel {channel_name} to {filename}")
+
+
+def display_image(title: str, image: np.ndarray) -> None:
+    """
+    Display an image using OpenCV.
+
+    :param title: Window title.
+    :param image: Image data in BGR format.
+    """
+    logger.debug(f"Displaying image: {title}")
+    cv2.imshow(title, image)
+    cv2.waitKey(0)
+    cv2.destroyAllWindows()
+
+
+def parse_args() -> argparse.Namespace:
+    """
+    Parse command-line arguments.
+
+    :return: Parsed arguments.
+    """
+    parser = argparse.ArgumentParser(
+        description="Extract and process image channels in various color spaces."
+    )
+    parser.add_argument('mode', type=str, choices=['extract', 'merge', 'process_dir'],
+                        help="Mode of operation: extract, merge, or process_dir.")
+    parser.add_argument('--color_space', type=str, choices=['RGB', 'XYZ', 'LAB', 'LCH',
+                                                            'HSV', 'HSI', 'YUV', 'YCBCR',
+                                                            'HSL', 'CMYK'],
+                        default='RGB', help="Target color space.")
+    parser.add_argument('--input', type=Path, required=True,
+                        help="Path to the input image or directory.")
+    parser.add_argument('--output', type=Path, default=Path('./output_channels'),
+                        help="Path to save the extracted channels or merged image.")
+    parser.add_argument('--base_name', type=str, default='output',
+                        help="Base name for saving merged image.")
+    parser.add_argument('--format', type=str, choices=['PNG', 'JPEG', 'BMP', 'TIFF'],
+                        default='PNG', help="Format for the output image.")
+    parser.add_argument('--show', action='store_true',
+                        help="Display the merged image.")
+    return parser.parse_args()
+
+
+def library_extract_channels(image: np.ndarray, color_space: str = 'RGB') -> Dict[str, np.ndarray]:
+    """
+    Library function to extract channels from
an image. + + :param image: Input image in BGR format. + :param color_space: Target color space. + :return: Dictionary of channel names and their data. + """ + return extract_channels(image, color_space) + + +def library_merge_channels(channels: Dict[str, np.ndarray]) -> Optional[np.ndarray]: + """ + Library function to merge channels into a single image. + + :param channels: Dictionary of channel names and their data. + :return: Merged image or None. + """ + return merge_channels(channels) + + +def main(): + args = parse_args() + + if args.mode == 'extract': + # Single image channel extraction + logger.info( + f"Extracting channels from image: {args.input} with color space: {args.color_space}") + try: + image = cv2.imread(str(args.input)) + if image is None: + logger.error(f"Failed to read image: {args.input}") + sys.exit(1) + + extracted_channels = extract_channels(image, args.color_space) + + for channel_name, channel_data in extracted_channels.items(): + save_path = args.output / \ + f"{args.input.stem}_{channel_name}.png" + args.output.mkdir(parents=True, exist_ok=True) + cv2.imwrite(str(save_path), channel_data) + logger.info(f"Saved channel {channel_name} to {save_path}") + + # Optionally display histogram + show_histogram( + channel_data, title=f"{args.input.stem} - {channel_name}") + + except Exception as e: + logger.exception(f"Error during channel extraction: {e}") + sys.exit(1) + + elif args.mode == 'merge': + # Merge channels into a single image + logger.info(f"Merging channels from directory: {args.input}") + try: + if not args.input.is_dir(): + logger.error( + f"Input path must be a directory for merging: {args.input}") + sys.exit(1) + + channels = {} + for channel_file in args.input.iterdir(): + if channel_file.is_file() and channel_file.suffix.lower() == '.png': + parts = channel_file.stem.split('_') + if len(parts) < 2: + logger.warning( + f"Skipping file with unexpected naming: {channel_file.name}") + continue + channel_name = parts[-1] + 
channels[channel_name] = cv2.imread( + str(channel_file), cv2.IMREAD_GRAYSCALE) + + merged_image = merge_channels(channels) + if merged_image is not None: + merged_image_bgr = cv2.cvtColor( + merged_image, cv2.COLOR_RGB2BGR) + output_path = args.output / \ + f"{args.base_name}_merged.{args.format.lower()}" + cv2.imwrite(str(output_path), merged_image_bgr) + logger.info(f"Merged image saved to {output_path}") + + if args.show: + display_image("Merged Image", merged_image_bgr) + else: + logger.error( + "Merged image is None. Check if sufficient channels were provided.") + + except Exception as e: + logger.exception(f"Error during channel merging: {e}") + sys.exit(1) + + elif args.mode == 'process_dir': + # Batch processing mode + logger.info( + f"Batch processing images in directory: {args.input} with color space: {args.color_space}") + try: + process_directory(args.input, args.output, args.color_space) + except Exception as e: + logger.exception(f"Error during batch processing: {e}") + sys.exit(1) + else: + logger.error(f"Unsupported mode: {args.mode}") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/pysrc/image/color_calibration/__init__.py b/modules/lithium.pyimage/image/color_calibration/__init__.py similarity index 100% rename from pysrc/image/color_calibration/__init__.py rename to modules/lithium.pyimage/image/color_calibration/__init__.py diff --git a/modules/lithium.pyimage/image/color_calibration/calibration.py b/modules/lithium.pyimage/image/color_calibration/calibration.py new file mode 100644 index 00000000..b789333d --- /dev/null +++ b/modules/lithium.pyimage/image/color_calibration/calibration.py @@ -0,0 +1,291 @@ +import cv2 +import numpy as np +from typing import Tuple, Optional, List +from dataclasses import dataclass, field +from pathlib import Path +from loguru import logger + + +@dataclass +class ColorCalibrationConfig: + """Configuration parameters for color calibration.""" + gamma: float = 1.0 # Gamma correction value + # 
'gray_world', 'white_patch', 'learning_based'
+    white_balance_method: str = 'gray_world'
+    calibration_save_path: Optional[Path] = None
+
+
+class ColorCalibration:
+    """Handles color calibration tasks for astronomical images."""
+
+    def __init__(self, image: np.ndarray, config: Optional[ColorCalibrationConfig] = None) -> None:
+        """
+        Initialize the ColorCalibration class.
+
+        :param image: Input BGR image (as loaded by cv2.imread)
+        :param config: Color calibration configuration
+        """
+        self.image = image
+        self.config = config or ColorCalibrationConfig()
+
+        # Configure logging
+        logger.remove()
+        logger.add(
+            "color_calibration.log",
+            rotation="10 MB",
+            retention="10 days",
+            level="DEBUG",
+            format="{time} | {level} | {message}"
+        )
+        logger.debug(
+            "ColorCalibration initialized with configuration: {}", self.config)
+
+    def gamma_correction(self) -> np.ndarray:
+        """
+        Apply gamma correction to the image.
+
+        :return: Gamma corrected image
+        """
+        logger.info("Starting gamma correction with gamma value: {}",
+                    self.config.gamma)
+        try:
+            inv_gamma = 1.0 / self.config.gamma
+            table = np.array([((i / 255.0) ** inv_gamma) *
+                              255 for i in np.arange(256)]).astype("uint8")
+            corrected_image = cv2.LUT(self.image, table)
+            logger.debug("Gamma correction completed.")
+            return corrected_image
+        except Exception as e:
+            logger.error("Error during gamma correction: {}", e)
+            raise
+
+    def apply_white_balance(self) -> np.ndarray:
+        """
+        Apply white balance to the image based on the configured method.
+ + :return: White balanced image + """ + method = self.config.white_balance_method.lower() + logger.info("Starting white balance with method: {}", method) + if method == 'gray_world': + return self.gray_world_white_balance() + elif method == 'white_patch': + return self.white_patch_white_balance() + elif method == 'learning_based': + return self.learning_based_white_balance() + else: + logger.error("Unsupported white balance method: {}", method) + raise ValueError(f"Unsupported white balance method: {method}") + + def gray_world_white_balance(self) -> np.ndarray: + """ + Apply gray world white balance algorithm. + + :return: White balanced image + """ + logger.debug("Applying gray world white balance.") + try: + mean_values = np.mean(self.image, axis=(0, 1)) + gray_value = np.mean(mean_values) + factors = gray_value / mean_values + balanced_image = self.apply_color_factors(factors) + logger.debug("Gray world white balance completed.") + return balanced_image + except Exception as e: + logger.error("Error during gray world white balance: {}", e) + raise + + def white_patch_white_balance(self) -> np.ndarray: + """ + Apply white patch white balance algorithm. + + :return: White balanced image + """ + logger.debug("Applying white patch white balance.") + try: + max_values = np.max(self.image, axis=(0, 1)) + factors = 255.0 / max_values + balanced_image = self.apply_color_factors(factors) + logger.debug("White patch white balance completed.") + return balanced_image + except Exception as e: + logger.error("Error during white patch white balance: {}", e) + raise + + def learning_based_white_balance(self) -> np.ndarray: + """ + Apply learning-based white balance algorithm. 
+ + :return: White balanced image + """ + logger.debug("Applying learning-based white balance.") + try: + wb = cv2.xphoto.createLearningBasedWB() + balanced_image = wb.balanceWhite(self.image) + logger.debug("Learning-based white balance completed.") + return balanced_image + except Exception as e: + logger.error("Error during learning-based white balance: {}", e) + raise + + def apply_color_factors(self, factors: np.ndarray) -> np.ndarray: + """ + Apply color correction factors. + + :param factors: Color correction factors + :return: Corrected image + """ + logger.debug("Applying color correction factors: {}", factors) + try: + calibrated_image = self.image.astype(np.float32) + for i in range(3): + calibrated_image[:, :, i] *= factors[i] + calibrated_image = np.clip( + calibrated_image, 0, 255).astype(np.uint8) + logger.debug("Color correction completed.") + return calibrated_image + except Exception as e: + logger.error("Error applying color correction factors: {}", e) + raise + + def save_calibration_parameters(self) -> None: + """ + Save calibration parameters to a file. + + """ + if self.config.calibration_save_path: + logger.info("Saving calibration parameters to file: {}", + self.config.calibration_save_path) + try: + calibration_data = { + 'gamma': self.config.gamma, + 'white_balance_method': self.config.white_balance_method + } + np.save(self.config.calibration_save_path, calibration_data) + logger.debug("Calibration parameters saved successfully.") + except Exception as e: + logger.error("Error saving calibration parameters: {}", e) + raise + + def load_calibration_parameters(self) -> None: + """ + Load calibration parameters from a file. 
+
+        """
+        if self.config.calibration_save_path and self.config.calibration_save_path.exists():
+            logger.info("Loading calibration parameters from file: {}",
+                        self.config.calibration_save_path)
+            try:
+                calibration_data = np.load(
+                    self.config.calibration_save_path, allow_pickle=True).item()
+                self.config.gamma = calibration_data.get(
+                    'gamma', self.config.gamma)
+                self.config.white_balance_method = calibration_data.get(
+                    'white_balance_method', self.config.white_balance_method)
+                logger.debug("Calibration parameters loaded successfully.")
+            except Exception as e:
+                logger.error("Error loading calibration parameters: {}", e)
+                raise
+
+    def batch_process(self, image_list: List[np.ndarray]) -> List[np.ndarray]:
+        """
+        Batch process a list of images.
+
+        :param image_list: List of images to process
+        :return: List of processed images
+        """
+        logger.info("Starting batch processing of {} images.", len(image_list))
+        processed_images = []
+        try:
+            for idx, img in enumerate(image_list):
+                logger.debug("Processing image {}.", idx + 1)
+                self.image = img
+                # gamma_correction reads self.image, so store the
+                # white-balanced result back before applying gamma.
+                self.image = self.apply_white_balance()
+                corrected_image = self.gamma_correction()
+                processed_images.append(corrected_image)
+            logger.info("Batch processing completed.")
+            return processed_images
+        except Exception as e:
+            logger.error("Error during batch processing: {}", e)
+            raise
+
+    def adjust_saturation(self, saturation_scale: float = 1.0) -> np.ndarray:
+        """
+        Adjust the saturation of the image.
+ + :param saturation_scale: Saturation adjustment factor + :return: Image with adjusted saturation + """ + logger.info("Adjusting image saturation with scale: {}", + saturation_scale) + try: + hsv_image = cv2.cvtColor( + self.image, cv2.COLOR_BGR2HSV).astype(np.float32) + hsv_image[:, :, 1] *= saturation_scale + hsv_image[:, :, 1] = np.clip(hsv_image[:, :, 1], 0, 255) + adjusted_image = cv2.cvtColor( + hsv_image.astype(np.uint8), cv2.COLOR_HSV2BGR) + logger.debug("Saturation adjustment completed.") + return adjusted_image + except Exception as e: + logger.error("Error adjusting saturation: {}", e) + raise + + def adjust_brightness(self, brightness_offset: int = 0) -> np.ndarray: + """ + Adjust the brightness of the image. + + :param brightness_offset: Brightness offset value + :return: Image with adjusted brightness + """ + logger.info("Adjusting image brightness with offset: {}", + brightness_offset) + try: + adjusted_image = cv2.convertScaleAbs( + self.image, alpha=1, beta=brightness_offset) + logger.debug("Brightness adjustment completed.") + return adjusted_image + except Exception as e: + logger.error("Error adjusting brightness: {}", e) + raise + + +# Example usage +if __name__ == "__main__": + # Load image + image_path = "path/to/image.jpg" + image = cv2.imread(image_path) + if image is None: + logger.error("Unable to load image: {}", image_path) + exit(1) + + # Configure color calibration parameters + calibration_config = ColorCalibrationConfig( + gamma=2.2, + white_balance_method='gray_world', + calibration_save_path=Path("calibration_params.npy") + ) + + # Initialize color calibration object + color_calibrator = ColorCalibration(image=image, config=calibration_config) + + # Apply white balance + balanced_image = color_calibrator.apply_white_balance() + + # Apply gamma correction + corrected_image = color_calibrator.gamma_correction() + + # Adjust saturation + adjusted_image = color_calibrator.adjust_saturation(saturation_scale=1.2) + + # Adjust 
brightness + bright_image = color_calibrator.adjust_brightness(brightness_offset=10) + + # Save calibration parameters + color_calibrator.save_calibration_parameters() + + # Display results + cv2.imshow("Original Image", image) + cv2.imshow("Corrected Image", corrected_image) + cv2.waitKey(0) + cv2.destroyAllWindows() diff --git a/pysrc/image/color_calibration/io.py b/modules/lithium.pyimage/image/color_calibration/io.py similarity index 100% rename from pysrc/image/color_calibration/io.py rename to modules/lithium.pyimage/image/color_calibration/io.py diff --git a/pysrc/image/color_calibration/processing.py b/modules/lithium.pyimage/image/color_calibration/processing.py similarity index 100% rename from pysrc/image/color_calibration/processing.py rename to modules/lithium.pyimage/image/color_calibration/processing.py diff --git a/pysrc/image/color_calibration/utils.py b/modules/lithium.pyimage/image/color_calibration/utils.py similarity index 100% rename from pysrc/image/color_calibration/utils.py rename to modules/lithium.pyimage/image/color_calibration/utils.py diff --git a/modules/lithium.pyimage/image/debayer/debayer.py b/modules/lithium.pyimage/image/debayer/debayer.py new file mode 100644 index 00000000..a4e3cbb9 --- /dev/null +++ b/modules/lithium.pyimage/image/debayer/debayer.py @@ -0,0 +1,532 @@ +import numpy as np +import cv2 +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import Optional, Tuple, List, Union +from pathlib import Path +from dataclasses import dataclass +from loguru import logger + + +@dataclass +class DebayerConfig: + """ + Configuration settings for the Debayer process. 
+ """ + method: str = 'bilinear' # 'superpixel', 'bilinear', 'vng', 'ahd', 'laplacian' + # 'BGGR', 'RGGB', 'GBRG', 'GRBG', or None for auto-detection + pattern: Optional[str] = None + num_threads: int = 4 + visualize_intermediate: bool = False + visualization_save_path: Optional[Path] = None + save_debayered_images: bool = False + debayered_save_path_template: str = "{original_name}_{method}.png" + + +class Debayer: + def __init__(self, config: Optional[DebayerConfig] = None) -> None: + """ + Initialize Debayer object with configuration. + + :param config: DebayerConfig object containing debayer settings. + """ + self.config = config or DebayerConfig() + + # Configure Loguru logger + logger.remove() # Remove default logger + logger.add( + "debayer.log", + rotation="10 MB", + retention="10 days", + level="DEBUG", + format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}" + ) + logger.debug("Initialized Debayer with configuration: {}", self.config) + + def detect_bayer_pattern(self, image: np.ndarray) -> str: + """ + Automatically detect Bayer pattern from the CFA image. + + :param image: Grayscale CFA image. + :return: Detected Bayer pattern. 
+ """ + logger.info("Starting Bayer pattern detection.") + height, width = image.shape + + # Initialize pattern scores + patterns = {'BGGR': 0, 'RGGB': 0, 'GBRG': 0, 'GRBG': 0} + + # Edge detection to enhance pattern recognition + edges = cv2.Canny(image, 50, 150) + + # Analyze each 2x2 block + for i in range(0, height - 1, 2): + for j in range(0, width - 1, 2): + block = image[i:i+2, j:j+2] + edge_block = edges[i:i+2, j:j+2] + + # Calculate scores based on intensity and edges + patterns['BGGR'] += block[0, 0] + block[1, 1] + \ + edge_block[0, 0] + edge_block[1, 1] + patterns['RGGB'] += block[1, 0] + block[0, 1] + \ + edge_block[1, 0] + edge_block[0, 1] + patterns['GBRG'] += block[0, 1] + block[1, 0] + \ + edge_block[0, 1] + edge_block[1, 0] + patterns['GRBG'] += block[0, 0] + block[1, 1] + \ + edge_block[0, 0] + edge_block[1, 1] + + detected_pattern = max(patterns, key=patterns.get) + logger.info("Detected Bayer pattern: {}", detected_pattern) + return detected_pattern + + def debayer_image(self, cfa_image: np.ndarray) -> np.ndarray: + """ + Perform the debayering process using the specified method. + + :param cfa_image: Grayscale CFA image. + :return: Debayered RGB image. 
+ """ + logger.info("Starting debayering process.") + if self.config.pattern is None: + self.config.pattern = self.detect_bayer_pattern(cfa_image) + + # Extend image edges to handle boundary conditions + cfa_image_padded = self.extend_image_edges(cfa_image, pad_width=2) + logger.debug("Image edges extended with padding.") + + logger.debug("Using Bayer pattern: {}", self.config.pattern) + + method_dispatcher = { + 'superpixel': self.debayer_superpixel, + 'bilinear': self.debayer_bilinear, + 'vng': self.debayer_vng, + 'ahd': self.parallel_debayer_ahd, + 'laplacian': self.debayer_laplacian_harmonization + } + + debayer_method = self.config.method.lower() + if debayer_method not in method_dispatcher: + logger.error("Unknown debayer method: {}", self.config.method) + raise ValueError(f"Unknown debayer method: {self.config.method}") + + rgb_image = method_dispatcher[debayer_method](cfa_image_padded) + logger.info("Debayering completed using method: {}", + self.config.method) + + if self.config.save_debayered_images: + output_path = self.generate_save_path("debayered_image.png") + cv2.imwrite(str(output_path), rgb_image) + logger.debug("Debayered image saved to {}", output_path) + + return rgb_image + + def debayer_superpixel(self, cfa_image: np.ndarray) -> np.ndarray: + """ + Debayering using superpixel method. + + :param cfa_image: Padded CFA image. + :return: Debayered RGB image. + """ + logger.info("Debayering using superpixel method.") + red = cfa_image[0::2, 0::2] + green = (cfa_image[0::2, 1::2] + cfa_image[1::2, 0::2]) / 2 + blue = cfa_image[1::2, 1::2] + + rgb_image = np.stack((red, green, blue), axis=-1) + logger.debug("Superpixel debayering completed.") + return rgb_image + + def debayer_bilinear(self, cfa_image: np.ndarray) -> np.ndarray: + """ + Debayering using bilinear interpolation. + + :param cfa_image: Padded CFA image. + :return: Debayered RGB image. 
+ """ + logger.info("Debayering using bilinear interpolation.") + pattern = self.config.pattern.upper() + pattern_codes = { + 'BGGR': cv2.COLOR_BayerBG2BGR, + 'RGGB': cv2.COLOR_BayerRG2BGR, + 'GBRG': cv2.COLOR_BayerGB2BGR, + 'GRBG': cv2.COLOR_BayerGR2BGR + } + + if pattern not in pattern_codes: + logger.error("Unsupported Bayer pattern: {}", pattern) + raise ValueError(f"Unsupported Bayer pattern: {pattern}") + + rgb_image = cv2.cvtColor(cfa_image, pattern_codes[pattern]) + logger.debug("Bilinear debayering completed.") + return rgb_image + + def debayer_vng(self, cfa_image: np.ndarray) -> np.ndarray: + """ + Debayering using VNG interpolation. + + :param cfa_image: Padded CFA image. + :return: Debayered RGB image. + """ + logger.info("Debayering using VNG interpolation.") + pattern = self.config.pattern.upper() + pattern_codes = { + 'BGGR': cv2.COLOR_BayerBG2BGR_VNG, + 'RGGB': cv2.COLOR_BayerRG2BGR_VNG, + 'GBRG': cv2.COLOR_BayerGB2BGR_VNG, + 'GRBG': cv2.COLOR_BayerGR2BGR_VNG + } + + if pattern not in pattern_codes: + logger.error("Unsupported Bayer pattern for VNG: {}", pattern) + raise ValueError(f"Unsupported Bayer pattern for VNG: {pattern}") + + rgb_image = cv2.cvtColor(cfa_image, pattern_codes[pattern]) + logger.debug("VNG debayering completed.") + return rgb_image + + def parallel_debayer_ahd(self, cfa_image: np.ndarray) -> np.ndarray: + """ + Debayering using Adaptive Homogeneity-Directed (AHD) interpolation with multithreading. + + :param cfa_image: Padded CFA image. + :return: Debayered RGB image. 
+ """ + logger.info( + "Debayering using Adaptive Homogeneity-Directed (AHD) interpolation.") + height, width = cfa_image.shape + chunk_size = height // self.config.num_threads + results: List[np.ndarray] = [None] * self.config.num_threads + + def process_chunk(start_row: int, end_row: int, index: int): + logger.debug("Processing chunk {}: rows {} to {}", + index, start_row, end_row) + chunk = cfa_image[start_row:end_row, :] + gradient_x, gradient_y = self.calculate_gradients(chunk) + green_channel = self.interpolate_green_channel( + chunk, gradient_x, gradient_y) + red_channel, blue_channel = self.interpolate_red_blue_channel( + chunk, green_channel, self.config.pattern) + rgb_chunk = np.stack( + (red_channel, green_channel, blue_channel), axis=-1) + results[index] = np.clip(rgb_chunk, 0, 255).astype(np.uint8) + logger.debug("Chunk {} processing completed.", index) + + with ThreadPoolExecutor(max_workers=self.config.num_threads) as executor: + futures = [] + for i in range(self.config.num_threads): + start_row = i * chunk_size + end_row = ( + i + 1) * chunk_size if i < self.config.num_threads - 1 else height + futures.append(executor.submit( + process_chunk, start_row, end_row, i)) + + for future in as_completed(futures): + future.result() + + rgb_image = np.vstack(results) + logger.debug("AHD debayering completed with multithreading.") + return rgb_image + + def debayer_laplacian_harmonization(self, cfa_image: np.ndarray) -> np.ndarray: + """ + Debayering using Laplacian harmonization to enhance edges. + + :param cfa_image: Padded CFA image. + :return: Debayered RGB image with harmonized edges. 
+ """ + logger.info("Debayering using Laplacian harmonization.") + interpolated_image = self.debayer_bilinear(cfa_image) + + # Calculate Laplacian for each channel + laplacian = {} + for idx, color in enumerate(['Blue', 'Green', 'Red']): + lap = self.calculate_laplacian(interpolated_image[:, :, idx]) + laplacian[color] = lap + logger.debug("Laplacian calculated for {} channel.", color) + + # Harmonize each channel + harmonized_channels = [] + for idx, color in enumerate(['Blue', 'Green', 'Red']): + harmonized = self.harmonize_edges( + interpolated_image[:, :, idx], laplacian[color]) + harmonized_channels.append(harmonized) + logger.debug("{} channel harmonized.", color) + + harmonized_image = np.stack(harmonized_channels, axis=-1) + logger.debug("Laplacian harmonization completed.") + return harmonized_image + + @staticmethod + def calculate_gradients(image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """ + Calculate the gradients of the CFA image. + + :param image: CFA image chunk. + :return: Tuple of gradient_x and gradient_y. + """ + gradient_x = np.abs(np.diff(image, axis=1)) + gradient_y = np.abs(np.diff(image, axis=0)) + + gradient_x = np.pad(gradient_x, ((0, 0), (0, 1)), 'constant') + gradient_y = np.pad(gradient_y, ((0, 1), (0, 0)), 'constant') + + return gradient_x, gradient_y + + @staticmethod + def interpolate_green_channel(cfa_image: np.ndarray, gradient_x: np.ndarray, gradient_y: np.ndarray) -> np.ndarray: + """ + Interpolate the green channel of the CFA image based on gradients. + + :param cfa_image: CFA image chunk. + :param gradient_x: Gradient in x-direction. + :param gradient_y: Gradient in y-direction. + :return: Interpolated green channel. 
+ """ + logger.debug("Interpolating green channel.") + height, width = cfa_image.shape + green_channel = np.zeros((height, width), dtype=np.float64) + + for i in range(1, height - 1): + for j in range(1, width - 1): + if (i % 2 == 0 and j % 2 == 1) or (i % 2 == 1 and j % 2 == 0): + # Green pixel + green_channel[i, j] = cfa_image[i, j] + else: + # Interpolate green + if gradient_x[i, j] < gradient_y[i, j]: + green_channel[i, j] = 0.5 * \ + (cfa_image[i, j-1] + cfa_image[i, j+1]) + else: + green_channel[i, j] = 0.5 * \ + (cfa_image[i-1, j] + cfa_image[i+1, j]) + + return green_channel + + @staticmethod + def interpolate_red_blue_channel(cfa_image: np.ndarray, green_channel: np.ndarray, pattern: Optional[str] = None) -> Tuple[np.ndarray, np.ndarray]: + """ + Interpolate the red and blue channels of the CFA image based on the green channel. + + :param cfa_image: CFA image chunk. + :param green_channel: Interpolated green channel. + :param pattern: Bayer pattern. + :return: Tuple of interpolated red and blue channels. 
+ """ + logger.debug("Interpolating red and blue channels.") + height, width = cfa_image.shape + red_channel = np.zeros((height, width), dtype=np.float64) + blue_channel = np.zeros((height, width), dtype=np.float64) + + pattern = pattern.upper() if pattern else 'BGGR' + + for i in range(0, height - 1, 2): + for j in range(0, width - 1, 2): + if pattern == 'BGGR': + blue_channel[i, j] = cfa_image[i, j] + red_channel[i+1, j+1] = cfa_image[i+1, j+1] + + green_r = 0.5 * \ + (green_channel[i+1, j] + green_channel[i, j+1]) + green_b = 0.5 * \ + (green_channel[i, j] + green_channel[i+1, j+1]) + + blue_channel[i+1, j] = cfa_image[i+1, j] - \ + green_b + green_channel[i+1, j] + blue_channel[i, j+1] = cfa_image[i, j+1] - \ + green_b + green_channel[i, j+1] + red_channel[i, j] = cfa_image[i, j] - \ + green_r + green_channel[i, j] + red_channel[i+1, j+1] = cfa_image[i+1, j+1] - \ + green_r + green_channel[i+1, j+1] + + elif pattern == 'RGGB': + red_channel[i, j] = cfa_image[i, j] + blue_channel[i+1, j+1] = cfa_image[i+1, j+1] + + green_r = 0.5 * \ + (green_channel[i, j+1] + green_channel[i+1, j]) + green_b = 0.5 * \ + (green_channel[i+1, j] + green_channel[i, j+1]) + + red_channel[i+1, j] = cfa_image[i+1, j] - \ + green_r + green_channel[i+1, j] + red_channel[i, j+1] = cfa_image[i, j+1] - \ + green_r + green_channel[i, j+1] + blue_channel[i, j] = cfa_image[i, j] - \ + green_b + green_channel[i, j] + blue_channel[i+1, j+1] = cfa_image[i+1, j+1] - \ + green_b + green_channel[i+1, j+1] + + elif pattern == 'GBRG': + green_channel[i, j+1] = cfa_image[i, j+1] + blue_channel[i+1, j] = cfa_image[i+1, j] + + green_r = 0.5 * \ + (green_channel[i, j] + green_channel[i+1, j+1]) + green_b = 0.5 * \ + (green_channel[i+1, j] + green_channel[i, j+1]) + + red_channel[i, j] = cfa_image[i, j] - \ + green_r + green_channel[i, j] + red_channel[i+1, j+1] = cfa_image[i+1, j+1] - \ + green_r + green_channel[i+1, j+1] + blue_channel[i, j] = cfa_image[i, j] - \ + green_b + green_channel[i, j] + 
+                    blue_channel[i+1, j+1] = cfa_image[i+1, j+1] - \
+                        green_b + green_channel[i+1, j+1]
+
+                elif pattern == 'GRBG':
+                    # GRBG: green at (0, 0) and (1, 1), red at (0, 1),
+                    # blue at (1, 0).
+                    green_channel[i, j] = cfa_image[i, j]
+                    red_channel[i, j+1] = cfa_image[i, j+1]
+                    blue_channel[i+1, j] = cfa_image[i+1, j]
+
+                    green_avg = 0.5 * \
+                        (green_channel[i, j] + green_channel[i+1, j+1])
+
+                    # Estimate red and blue at the two green sites.
+                    red_channel[i, j] = cfa_image[i, j] - \
+                        green_avg + green_channel[i, j]
+                    red_channel[i+1, j+1] = cfa_image[i+1, j+1] - \
+                        green_avg + green_channel[i+1, j+1]
+                    blue_channel[i, j] = cfa_image[i, j] - \
+                        green_avg + green_channel[i, j]
+                    blue_channel[i+1, j+1] = cfa_image[i+1, j+1] - \
+                        green_avg + green_channel[i+1, j+1]
+
+        return red_channel, blue_channel
+
+    @staticmethod
+    def calculate_laplacian(image: np.ndarray) -> np.ndarray:
+        """
+        Calculate the Laplacian of the image for edge enhancement.
+
+        :param image: Single-channel image.
+        :return: Laplacian image.
+        """
+        return cv2.Laplacian(image, cv2.CV_64F)
+
+    @staticmethod
+    def harmonize_edges(interpolated: np.ndarray, laplacian: np.ndarray) -> np.ndarray:
+        """
+        Harmonize edges by blending the Laplacian back into the channel.
+
+        :param interpolated: Interpolated channel image.
+        :param laplacian: Laplacian of the channel.
+        :return: Harmonized channel image.
+        """
+        harmonized = np.clip(interpolated + 0.2 *
+                             laplacian, 0, 255).astype(np.uint8)
+        return harmonized
+
+    @staticmethod
+    def extend_image_edges(image: np.ndarray, pad_width: int) -> np.ndarray:
+        """
+        Extend image edges using mirror padding to handle boundary issues during interpolation.
+
+        :param image: Input image.
+        :param pad_width: Width of padding.
+        :return: Padded image.
+        """
+        return np.pad(image, pad_width, mode='reflect')
+
+    @staticmethod
+    def visualize_intermediate_steps(image_path: Union[str, Path], debayered_image: np.ndarray, config: DebayerConfig):
+        """
+        Visualize intermediate steps in the debayering process.
+
+        :param image_path: Path to the original CFA image.
+        :param debayered_image: Debayered RGB image.
+        :param config: DebayerConfig object.
+        """
+        import matplotlib.pyplot as plt
+
+        if config.visualize_intermediate and config.visualization_save_path:
+            logger.info("Visualizing intermediate steps.")
+            # Load original image
+            cfa_image = cv2.imread(str(image_path), cv2.IMREAD_GRAYSCALE)
+            if cfa_image is None:
+                raise IOError(f"Failed to read CFA image: {image_path}")
+
+            # Detect pattern if not set
+            pattern = config.pattern or Debayer().detect_bayer_pattern(cfa_image)
+
+            # Debayer with bilinear interpolation as a reference. The pattern
+            # is supplied through a fresh config because debayer_bilinear reads
+            # it from self.config rather than taking it as an argument.
+            viz_debayer = Debayer(
+                config=DebayerConfig(method='bilinear', pattern=pattern))
+            bilinear_rgb = viz_debayer.debayer_bilinear(cfa_image)
+            logger.debug("Bilinear reference image computed with shape {}.",
+                         bilinear_rgb.shape)
+
+            # Calculate gradients
+            gradient_x, gradient_y = Debayer.calculate_gradients(cfa_image)
+            green_channel = Debayer.interpolate_green_channel(
+                cfa_image, gradient_x, gradient_y)
+            red_channel, blue_channel = Debayer.interpolate_red_blue_channel(
+                cfa_image, green_channel, pattern)
+
+            # Display images
+            plt.figure(figsize=(15, 10))
+
+            plt.subplot(2, 3, 1)
+            plt.imshow(cfa_image, cmap='gray')
+            plt.title('Original CFA Image')
+            plt.axis('off')
+
+            plt.subplot(2, 3, 2)
+            plt.imshow(gradient_x, cmap='gray')
+            plt.title('Gradient X')
+            plt.axis('off')
+
+            plt.subplot(2, 3, 3)
+            plt.imshow(gradient_y, cmap='gray')
+            plt.title('Gradient Y')
+            plt.axis('off')
+
+            plt.subplot(2, 3, 4)
+            plt.imshow(green_channel, cmap='gray')
+            plt.title('Interpolated Green Channel')
+            plt.axis('off')
+
+            plt.subplot(2, 3, 5)
+            plt.imshow(red_channel, cmap='gray')
+            plt.title('Interpolated Red Channel')
+            plt.axis('off')
+
+            plt.subplot(2, 3, 6)
+            plt.imshow(blue_channel, cmap='gray')
+            plt.title('Interpolated Blue Channel')
+            plt.axis('off')
+
+            plt.tight_layout()
+            plt.savefig(config.visualization_save_path)
+            logger.debug("Intermediate visualization saved to {}.",
+                         config.visualization_save_path)
+            plt.show()
+            logger.info("Intermediate visualization displayed successfully.")
+
+    def generate_save_path(self, original_path: str) -> Path:
+        """
+        Generate a save path for the debayered image based on
+        the original image name and method.
+
+        :param original_path: Original image file path.
+        :return: Path object for the debayered image.
+        """
+        original_name = Path(original_path).stem
+        save_path = Path(self.config.debayered_save_path_template.format(
+            original_name=original_name, method=self.config.method))
+        return save_path
+
+
+# Example usage
+if __name__ == "__main__":
+    config = DebayerConfig(
+        method='bilinear',
+        pattern=None,  # Auto-detect
+        num_threads=4,
+        visualize_intermediate=True,
+        visualization_save_path=Path("intermediate_steps.png"),
+        save_debayered_images=True,
+        debayered_save_path_template="{original_name}_{method}.png"
+    )
+    debayer = Debayer(config=config)
+    image_path = "path/to/cfa_image.png"
+
+    try:
+        cfa = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
+        if cfa is None:
+            raise IOError(f"Failed to read CFA image: {image_path}")
+        debayered_rgb = debayer.debayer_image(cfa)
+        Debayer.visualize_intermediate_steps(image_path, debayered_rgb, config)
+        logger.info("Debayering process completed successfully.")
+    except Exception as e:
+        logger.error("An error occurred during debayering: {}", e)
diff --git a/pysrc/image/defect_map/__init__.py b/modules/lithium.pyimage/image/defect_map/__init__.py
similarity index 100%
rename from pysrc/image/defect_map/__init__.py
rename to modules/lithium.pyimage/image/defect_map/__init__.py
diff --git a/pysrc/image/defect_map/defect_correction.py b/modules/lithium.pyimage/image/defect_map/defect_correction.py
similarity index 89%
rename from pysrc/image/defect_map/defect_correction.py
rename to modules/lithium.pyimage/image/defect_map/defect_correction.py
index e138f708..f0e1e741 100644
--- a/pysrc/image/defect_map/defect_correction.py
+++ b/modules/lithium.pyimage/image/defect_map/defect_correction.py
@@ -4,7 +4,13 @@ from skimage import img_as_float
 from .interpolation import interpolate_defects
 import multiprocessing
-from typing import Optional
+from typing import Optional, Tuple
+from loguru import logger
+
+# Configure Loguru logger
+logger.add("defect_correction.log", rotation="10 MB", retention="10 days", + level="DEBUG", format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}") + def defect_map_enhanced(image: np.ndarray, defect_map: np.ndarray, operation: str = 'mean', structure: str = 'square', radius: int = 1, is_cfa: bool = False, protect_edges: bool = False, @@ -25,6 +31,7 @@ def defect_map_enhanced(image: np.ndarray, defect_map: np.ndarray, operation: st Returns: - corrected_image: np.ndarray, The repaired image. """ + logger.info("Starting defect map enhancement") if structure == 'square': footprint = np.ones((2 * radius + 1, 2 * radius + 1)) elif structure == 'circular': @@ -37,6 +44,7 @@ def defect_map_enhanced(image: np.ndarray, defect_map: np.ndarray, operation: st footprint = np.zeros((2 * radius + 1, 1)) footprint[:, 0] = 1 else: + logger.error("Invalid structure type.") raise ValueError("Invalid structure type.") mask = defect_map == 0 @@ -55,8 +63,10 @@ def defect_map_enhanced(image: np.ndarray, defect_map: np.ndarray, operation: st corrected_image = correct_channel( image, mask, operation, footprint, adaptive_structure) + logger.info("Defect map enhancement completed") return corrected_image + def correct_channel(channel: np.ndarray, mask: np.ndarray, operation: str, footprint: np.ndarray, adaptive_structure: bool) -> np.ndarray: """ @@ -72,6 +82,7 @@ def correct_channel(channel: np.ndarray, mask: np.ndarray, operation: str, footp Returns: - channel: np.ndarray, The repaired image channel. 
""" + logger.debug(f"Correcting channel with operation: {operation}") if adaptive_structure: density = np.sum(mask) / mask.size radius = int(3 / density) if density > 0 else 1 @@ -93,12 +104,14 @@ def correct_channel(channel: np.ndarray, mask: np.ndarray, operation: str, footp elif operation == 'bicubic': channel_corrected = interpolate_defects(channel, mask, method='cubic') else: + logger.error("Invalid operation type.") raise ValueError("Invalid operation type.") channel[mask] = channel_corrected[mask] - + logger.debug("Channel correction completed") return channel + def parallel_defect_map(image: np.ndarray, defect_map: np.ndarray, **kwargs) -> np.ndarray: """ Parallel processing for defect map repair. @@ -111,6 +124,7 @@ def parallel_defect_map(image: np.ndarray, defect_map: np.ndarray, **kwargs) -> Returns: - corrected_image: np.ndarray, The repaired image. """ + logger.info("Starting parallel defect map processing") if image.ndim == 2: return defect_map_enhanced(image, defect_map, **kwargs) @@ -124,8 +138,10 @@ def parallel_defect_map(image: np.ndarray, defect_map: np.ndarray, **kwargs) -> pool.close() pool.join() + logger.info("Parallel defect map processing completed") return np.stack(results, axis=-1) + def defect_map_enhanced_single_channel(channel: np.ndarray, defect_map: np.ndarray, operation: str, structure: str, radius: int, is_cfa: bool, protect_edges: bool, adaptive_structure: bool) -> np.ndarray: @@ -145,4 +161,5 @@ def defect_map_enhanced_single_channel(channel: np.ndarray, defect_map: np.ndarr Returns: - channel: np.ndarray, The repaired image channel. 
""" + logger.debug("Processing single channel for defect map enhancement") return defect_map_enhanced(channel, defect_map, operation, structure, radius, is_cfa, protect_edges, adaptive_structure) diff --git a/pysrc/image/defect_map/interpolation.py b/modules/lithium.pyimage/image/defect_map/interpolation.py similarity index 100% rename from pysrc/image/defect_map/interpolation.py rename to modules/lithium.pyimage/image/defect_map/interpolation.py diff --git a/pysrc/image/defect_map/utils.py b/modules/lithium.pyimage/image/defect_map/utils.py similarity index 100% rename from pysrc/image/defect_map/utils.py rename to modules/lithium.pyimage/image/defect_map/utils.py diff --git a/modules/lithium.pyimage/image/fluxcalibration/calibration.py b/modules/lithium.pyimage/image/fluxcalibration/calibration.py index bdb91c7b..77a19ae7 100644 --- a/modules/lithium.pyimage/image/fluxcalibration/calibration.py +++ b/modules/lithium.pyimage/image/fluxcalibration/calibration.py @@ -1,29 +1,59 @@ -from .utils import instrument_response_correction, background_noise_correction -from .core import CalibrationParams -from typing import Any, Optional, Tuple, Dict +from pathlib import Path +from typing import Any, Dict, Optional, Tuple +from dataclasses import dataclass from loguru import logger from astropy.io import fits import numpy as np +import cv2 +import argparse +import sys +import concurrent.futures + + +# Configure Loguru logger +logger.remove() # Remove the default logger +logger.add(sys.stderr, level="INFO", format="{time} {level} {message}") + + +@dataclass +class CalibrationParams: + """Data class to store calibration parameters.""" + wavelength: float # Wavelength in nanometers + aperture: float # Aperture diameter in millimeters + obstruction: float # Obstruction diameter in millimeters + filter_width: float # Filter bandwidth in nanometers + transmissivity: float # Transmissivity + gain: float # Gain + quantum_efficiency: float # Quantum efficiency + extinction: float # 
Extinction coefficient + exposure_time: float # Exposure time in seconds def compute_flx2dn(params: CalibrationParams) -> float: """ - Compute the flux conversion factor (FLX2DN). + Compute the flux-to-DN conversion factor (FLX2DN). :param params: Calibration parameters. - :return: Flux conversion factor. + :return: Flux-to-DN conversion factor. """ - logger.debug("Starting computation of FLX2DN.") + logger.debug("Starting FLX2DN computation.") try: c = 3.0e8 # Speed of light in m/s h = 6.626e-34 # Planck's constant in J·s - wavelength_m = params.wavelength * 1e-9 # Convert nm to meters + wavelength_m = params.wavelength * 1e-9 # Convert nanometers to meters aperture_area = np.pi * \ - (params.aperture**2 - params.obstruction**2) / 4 - FLX2DN = (params.exposure_time * aperture_area * params.filter_width * - params.transmissivity * params.gain * params.quantum_efficiency * - (1 - params.extinction) * (wavelength_m / (c * h))) + ((params.aperture**2 - params.obstruction**2) / 4) + FLX2DN = ( + params.exposure_time * + aperture_area * + params.filter_width * + params.transmissivity * + params.gain * + params.quantum_efficiency * + (1 - params.extinction) * + (wavelength_m / (c * h)) + ) logger.info(f"Computed FLX2DN: {FLX2DN}") return FLX2DN except Exception as e: @@ -31,15 +61,18 @@ def compute_flx2dn(params: CalibrationParams) -> float: raise RuntimeError("Failed to compute FLX2DN.") -def flux_calibration(image: np.ndarray, params: CalibrationParams, - response_function: Optional[np.ndarray] = None) -> Tuple[np.ndarray, float, float, float]: +def flux_calibration( + image: np.ndarray, + params: CalibrationParams, + response_function: Optional[np.ndarray] = None +) -> Tuple[np.ndarray, float, float, float]: """ Perform flux calibration on an astronomical image. - :param image: Input image (numpy array). + :param image: Input image as a numpy array. :param params: Calibration parameters. - :param response_function: Optional instrument response function (numpy array). 
- :return: Tuple containing calibrated and rescaled image, FLXMIN, FLXRANGE, FLX2DN. + :param response_function: Optional instrument response function as a numpy array. + :return: Tuple containing the calibrated and rescaled image, FLXMIN, FLXRANGE, FLX2DN. """ logger.debug("Starting flux calibration process.") try: @@ -49,18 +82,16 @@ def flux_calibration(image: np.ndarray, params: CalibrationParams, FLX2DN = compute_flx2dn(params) calibrated_image = image / FLX2DN - logger.debug("Applied FLX2DN to image.") + logger.debug("Applied FLX2DN to the image.") calibrated_image = background_noise_correction(calibrated_image) logger.debug("Applied background noise correction.") - # Rescale the image to the range [0, 1] - min_val = np.min(calibrated_image) - max_val = np.max(calibrated_image) - FLXRANGE = max_val - min_val - FLXMIN = min_val - rescaled_image = (calibrated_image - min_val) / FLXRANGE - logger.info("Rescaled calibrated image to [0, 1].") + # Rescale image to range [0, 1] + FLXMIN = np.min(calibrated_image) + FLXRANGE = np.max(calibrated_image) - FLXMIN + rescaled_image = (calibrated_image - FLXMIN) / FLXRANGE + logger.info("Rescaled calibrated image to [0, 1] range.") return rescaled_image, FLXMIN, FLXRANGE, FLX2DN except Exception as e: @@ -68,17 +99,23 @@ def flux_calibration(image: np.ndarray, params: CalibrationParams, raise RuntimeError("Flux calibration process failed.") -def save_to_fits(image: np.ndarray, filename: str, FLXMIN: float, FLXRANGE: float, - FLX2DN: float, header_info: Optional[Dict[str, Any]] = None) -> None: +def save_to_fits( + image: np.ndarray, + filename: str, + FLXMIN: float, + FLXRANGE: float, + FLX2DN: float, + header_info: Optional[Dict[str, Any]] = None +) -> None: """ Save the calibrated image to a FITS file with necessary header information. - :param image: Calibrated image (numpy array). - :param filename: Output FITS file name. + :param image: Calibrated image as a numpy array. + :param filename: Output FITS filename. 
    :param FLXMIN: Minimum flux value used for rescaling.
-    :param FLXRANGE: Range of flux values used for rescaling.
-    :param FLX2DN: Flux conversion factor.
-    :param header_info: Dictionary of additional header information to include.
+    :param FLXRANGE: Flux range used for rescaling.
+    :param FLX2DN: Flux-to-DN conversion factor.
+    :param header_info: Optional additional header information.
     """
     logger.debug(f"Saving calibrated image to FITS file: {filename}")
     try:
@@ -94,49 +131,355 @@ def save_to_fits(image: np.ndarray, filename: str, FLXMIN: float, FLXRANGE: floa
             hdr[key] = (value, f'Additional header key {key}')
 
         hdu.writeto(filename, overwrite=True)
-        logger.info(f"Calibrated image saved successfully to {filename}")
+        logger.info(f"Calibrated image successfully saved to {filename}")
     except Exception as e:
         logger.error(f"Failed to save calibrated image to FITS: {e}")
-        raise IOError("Saving to FITS file failed.")
+        raise IOError("Failed to save to FITS file.")
 
 
 def apply_flat_field_correction(image: np.ndarray, flat_field: np.ndarray) -> np.ndarray:
     """
-    Apply flat field correction to an image.
+    Apply flat-field correction to the image.
 
-    :param image: Input image (numpy array).
-    :param flat_field: Flat field image (numpy array).
+    :param image: Input image as a numpy array.
+    :param flat_field: Flat-field image as a numpy array.
     :return: Corrected image.
""" - logger.debug("Applying flat field correction.") + logger.debug("Applying flat-field correction.") try: if image.shape != flat_field.shape: - logger.error("Image and flat field shapes do not match.") - raise ValueError("Image and flat field must have the same shape.") + logger.error("Image and flat-field image shapes do not match.") + raise ValueError( + "Image and flat-field image must have the same shape.") corrected_image = image / flat_field - logger.info("Flat field correction applied successfully.") + logger.info("Successfully applied flat-field correction.") return corrected_image except Exception as e: - logger.error(f"Flat field correction failed: {e}") - raise RuntimeError("Flat field correction failed.") + logger.error(f"Flat-field correction failed: {e}") + raise RuntimeError("Flat-field correction failed.") def apply_dark_frame_subtraction(image: np.ndarray, dark_frame: np.ndarray) -> np.ndarray: """ - Apply dark frame subtraction to an image. + Apply dark frame subtraction to the image. - :param image: Input image (numpy array). - :param dark_frame: Dark frame image (numpy array). + :param image: Input image as a numpy array. + :param dark_frame: Dark frame image as a numpy array. :return: Corrected image. 
""" logger.debug("Applying dark frame subtraction.") try: if image.shape != dark_frame.shape: - logger.error("Image and dark frame shapes do not match.") - raise ValueError("Image and dark frame must have the same shape.") + logger.error("Image and dark frame image shapes do not match.") + raise ValueError( + "Image and dark frame image must have the same shape.") corrected_image = image - dark_frame - logger.info("Dark frame subtraction applied successfully.") + logger.info("Successfully applied dark frame subtraction.") return corrected_image except Exception as e: logger.error(f"Dark frame subtraction failed: {e}") raise RuntimeError("Dark frame subtraction failed.") + + +def instrument_response_correction(image: np.ndarray, response_function: np.ndarray) -> np.ndarray: + """ + Apply instrument response correction to the image. + + :param image: Input image as a numpy array. + :param response_function: Instrument response function as a numpy array. + :return: Corrected image. + """ + logger.debug("Applying instrument response correction.") + try: + if image.shape != response_function.shape: + logger.error("Image and response function shapes do not match.") + raise ValueError( + "Image and response function must have the same shape.") + corrected_image = image * response_function + logger.info("Successfully applied instrument response correction.") + return corrected_image + except Exception as e: + logger.error(f"Instrument response correction failed: {e}") + raise RuntimeError("Instrument response correction failed.") + + +def background_noise_correction(image: np.ndarray) -> np.ndarray: + """ + Apply background noise correction to the image. + + :param image: Input image as a numpy array. + :return: Corrected image. 
+ """ + logger.debug("Applying background noise correction.") + try: + median = np.median(image) + corrected_image = image - median + logger.info("Successfully applied background noise correction.") + return corrected_image + except Exception as e: + logger.error(f"Background noise correction failed: {e}") + raise RuntimeError("Background noise correction failed.") + + +def parse_args() -> argparse.Namespace: + """ + Parse command-line arguments. + + :return: Parsed arguments namespace. + """ + parser = argparse.ArgumentParser( + description="Perform flux calibration on astronomical images." + ) + subparsers = parser.add_subparsers( + dest='command', required=True, help='Sub-command help' + ) + + # Subcommand: calibrate + parser_calibrate = subparsers.add_parser( + 'calibrate', help='Calibrate a single image.' + ) + parser_calibrate.add_argument( + '--input', type=Path, required=True, help='Path to the input image.' + ) + parser_calibrate.add_argument( + '--output', type=Path, required=True, help='Path to save the calibrated FITS file.' + ) + parser_calibrate.add_argument( + '--params', type=Path, required=True, help='Path to calibration parameters file.' + ) + parser_calibrate.add_argument( + '--response', type=Path, help='Path to the instrument response function image.' + ) + parser_calibrate.add_argument( + '--save_bg_noise', action='store_true', help='Save background noise information.' + ) + + # Subcommand: batch_calibrate + parser_batch = subparsers.add_parser( + 'batch_calibrate', help='Batch calibrate multiple images in a directory.' + ) + parser_batch.add_argument( + '--input_dir', type=Path, required=True, help='Path to the input images directory.' + ) + parser_batch.add_argument( + '--output_dir', type=Path, required=True, help='Path to save the calibrated FITS files.' + ) + parser_batch.add_argument( + '--params_dir', type=Path, required=True, help='Path to the calibration parameters directory.' 
+ ) + parser_batch.add_argument( + '--response_dir', type=Path, help='Path to the instrument response functions directory.' + ) + parser_batch.add_argument( + '--save_bg_noise', action='store_true', help='Save background noise information.' + ) + + return parser.parse_args() + + +def load_calibration_params(params_path: Path) -> CalibrationParams: + """ + Load calibration parameters from a file. + + :param params_path: Path to the calibration parameters file. + :return: CalibrationParams object. + """ + logger.debug(f"Loading calibration parameters from {params_path}") + try: + # Assuming the parameter file is a simple text file with key=value per line + params_dict = {} + with params_path.open('r') as f: + for line in f: + stripped = line.strip() + if stripped and not stripped.startswith('#'): + # Split on the first '=' only, so values containing '=' still parse + key, value = stripped.split('=', 1) + params_dict[key.strip()] = float(value.strip()) + + params = CalibrationParams( + wavelength=params_dict['wavelength'], + aperture=params_dict['aperture'], + obstruction=params_dict['obstruction'], + filter_width=params_dict['filter_width'], + transmissivity=params_dict['transmissivity'], + gain=params_dict['gain'], + quantum_efficiency=params_dict['quantum_efficiency'], + extinction=params_dict['extinction'], + exposure_time=params_dict['exposure_time'] + ) + logger.info(f"Successfully loaded calibration parameters: {params}") + return params + except Exception as e: + logger.error(f"Failed to load calibration parameters: {e}") + raise IOError("Failed to load calibration parameters.") + + +def calibrate_image( + image_path: Path, + output_path: Path, + params: CalibrationParams, + response_path: Optional[Path] = None, + save_bg_noise: bool = False +) -> None: + """ + Calibrate a single image and save the result. + + :param image_path: Path to the input image. + :param output_path: Path to save the calibrated FITS file. + :param params: CalibrationParams object. + :param response_path: Optional path to the instrument response function image.
+ :param save_bg_noise: Whether to save background noise information. + """ + logger.info(f"Starting calibration for image: {image_path}") + try: + image = cv2.imread(str(image_path), cv2.IMREAD_GRAYSCALE) + if image is None: + logger.error(f"Failed to read image: {image_path}") + raise IOError(f"Failed to read image: {image_path}") + + response_function = None + if response_path: + response_function = cv2.imread( + str(response_path), cv2.IMREAD_GRAYSCALE) + if response_function is None: + logger.error( + f"Failed to read instrument response function image: {response_path}") + raise IOError( + f"Failed to read instrument response function image: {response_path}") + + calibrated_image, FLXMIN, FLXRANGE, FLX2DN = flux_calibration( + image, params, response_function + ) + + save_to_fits( + calibrated_image, + str(output_path), + FLXMIN, + FLXRANGE, + FLX2DN + ) + + if save_bg_noise: + logger.debug( + f"Saving background noise information: FLXMIN={FLXMIN}, FLXRANGE={FLXRANGE}") + # Example: Save background noise information to a separate text file + bg_noise_path = output_path.with_suffix('.bg_noise.txt') + with bg_noise_path.open('w') as f: + f.write(f"FLXMIN={FLXMIN}\nFLXRANGE={FLXRANGE}\n") + logger.info( + f"Background noise information saved to {bg_noise_path}") + + except Exception as e: + logger.error(f"Calibration failed for image {image_path}: {e}") + raise RuntimeError(f"Calibration failed for image {image_path}.") + + +def batch_calibrate( + input_dir: Path, + output_dir: Path, + params_dir: Path, + response_dir: Optional[Path] = None, + save_bg_noise: bool = False +) -> None: + """ + Batch calibrate all images in a directory. + + :param input_dir: Path to the input images directory. + :param output_dir: Path to save the calibrated FITS files. + :param params_dir: Path to the calibration parameters directory. + :param response_dir: Optional path to the instrument response functions directory. 
+ :param save_bg_noise: Whether to save background noise information. + """ + logger.info(f"Starting batch calibration from {input_dir} to {output_dir}") + try: + if not input_dir.exists() or not input_dir.is_dir(): + logger.error( + f"Input directory does not exist or is not a directory: {input_dir}") + raise NotADirectoryError( + f"Input directory does not exist or is not a directory: {input_dir}") + + output_dir.mkdir(parents=True, exist_ok=True) + + image_files = list(input_dir.glob('*.*')) + if not image_files: + logger.warning( + f"No image files found in input directory: {input_dir}") + return + + def process_single_image(image_path: Path): + basename = image_path.stem + params_path = params_dir / f"{basename}.txt" + if not params_path.exists(): + logger.warning( + f"Calibration parameters file not found: {params_path}") + return + + output_path = output_dir / f"{basename}.fits" + response_path = None + if response_dir: + response_path = response_dir / f"{basename}_response.fits" + if not response_path.exists(): + logger.warning( + f"Instrument response function file not found: {response_path}") + response_path = None + + try: + params = load_calibration_params(params_path) + calibrate_image( + image_path=image_path, + output_path=output_path, + params=params, + response_path=response_path, + save_bg_noise=save_bg_noise + ) + except Exception as e: + logger.error(f"Error processing image {image_path}: {e}") + + with concurrent.futures.ThreadPoolExecutor() as executor: + executor.map(process_single_image, image_files) + + logger.info("Batch calibration completed successfully.") + except Exception as e: + logger.error(f"Batch calibration failed: {e}") + raise RuntimeError("Batch calibration process failed.") + + +def main(): + """ + Main function to parse command-line arguments and execute corresponding operations. 
+ """ + args = parse_args() + + if args.command == 'calibrate': + # Calibrate a single image + try: + params = load_calibration_params(args.params) + calibrate_image( + image_path=args.input, + output_path=args.output, + params=params, + response_path=args.response, + save_bg_noise=args.save_bg_noise + ) + except Exception as e: + logger.exception(f"Single image calibration failed: {e}") + sys.exit(1) + + elif args.command == 'batch_calibrate': + # Batch calibrate multiple images + try: + batch_calibrate( + input_dir=args.input_dir, + output_dir=args.output_dir, + params_dir=args.params_dir, + response_dir=args.response_dir, + save_bg_noise=args.save_bg_noise + ) + except Exception as e: + logger.exception(f"Batch calibration failed: {e}") + sys.exit(1) + else: + logger.error(f"Unknown command: {args.command}") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/modules/lithium.pyimage/image/raw/raw.py b/modules/lithium.pyimage/image/raw/raw.py index de20ea0c..9d6360c0 100644 --- a/modules/lithium.pyimage/image/raw/raw.py +++ b/modules/lithium.pyimage/image/raw/raw.py @@ -1,14 +1,38 @@ -python +from pathlib import Path +from typing import Optional, Tuple +from dataclasses import dataclass, field +from loguru import logger +from enum import Enum import rawpy import cv2 import numpy as np -from typing import Optional, Tuple -from loguru import logger -from dataclasses import dataclass, field +import argparse +import sys +import concurrent.futures + + +# Configure Loguru logger with file rotation and different log levels +logger.remove() # Remove the default logger +logger.add(sys.stderr, level="INFO", + format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}") +logger.add("raw_processor.log", rotation="10 MB", retention="10 days", + level="DEBUG", format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}") + + +class ImageFormat(Enum): + PNG = "png" + JPEG = "jpg" + TIFF = "tiff" + BMP = "bmp" + + @staticmethod + def list(): + return list(map(lambda 
c: c.value, ImageFormat)) + @dataclass class RawImageProcessor: - raw_path: str + raw_path: Path raw: rawpy.RawPy = field(init=False) rgb_image: np.ndarray = field(init=False) bgr_image: np.ndarray = field(init=False) @@ -16,32 +40,45 @@ class RawImageProcessor: def __post_init__(self): """ Initialize and read the RAW image. + """ + logger.debug( + f"Initializing RawImageProcessor with RAW file: {self.raw_path}") + self._load_raw_image() + self._postprocess_raw() - This method is called automatically after the dataclass is initialized. - It reads the RAW image from the specified path and processes it into an RGB image. + def _load_raw_image(self) -> None: + """ + Load the RAW image from the given path. """ - logger.debug(f"Initializing RawImageProcessor, RAW file path: {self.raw_path}") try: - self.raw = rawpy.imread(self.raw_path) + self.raw = rawpy.imread(str(self.raw_path)) logger.info(f"Successfully read RAW image: {self.raw_path}") except Exception as e: - logger.error(f"Failed to read RAW image: {e}") - raise IOError(f"Cannot read RAW image: {self.raw_path}") + logger.error(f"Failed to read RAW image at {self.raw_path}: {e}") + raise IOError(f"Cannot read RAW image: {self.raw_path}") from e + def _postprocess_raw(self) -> None: + """ + Post-process the RAW image to obtain an RGB image. 
+ """ try: self.rgb_image = self.raw.postprocess( gamma=(1.0, 1.0), no_auto_bright=True, use_camera_wb=True, - output_bps=16 # Use higher bit depth + output_bps=16 # Use higher bit depth for better quality ) - logger.debug("Successfully post-processed RAW image") + logger.debug("Successfully post-processed RAW image to RGB.") except Exception as e: logger.error(f"Failed to post-process RAW image: {e}") - raise ValueError("RAW image post-processing failed") + raise ValueError("RAW image post-processing failed") from e - self.bgr_image = cv2.cvtColor(self.rgb_image, cv2.COLOR_RGB2BGR) - logger.debug("Converted to OpenCV BGR format") + try: + self.bgr_image = cv2.cvtColor(self.rgb_image, cv2.COLOR_RGB2BGR) + logger.debug("Converted RGB image to BGR format for OpenCV.") + except Exception as e: + logger.error(f"Failed to convert RGB to BGR: {e}") + raise RuntimeError("Image format conversion failed") from e def adjust_contrast(self, alpha: float = 1.0) -> None: """ @@ -49,13 +86,13 @@ def adjust_contrast(self, alpha: float = 1.0) -> None: :param alpha: Contrast control (1.0-3.0). Default is 1.0. """ - logger.debug(f"Adjusting contrast, alpha={alpha}") + logger.debug(f"Adjusting contrast with alpha={alpha}") try: self.bgr_image = cv2.convertScaleAbs(self.bgr_image, alpha=alpha) - logger.info(f"Contrast adjusted successfully, alpha={alpha}") + logger.info(f"Contrast adjusted successfully with alpha={alpha}") except Exception as e: logger.error(f"Failed to adjust contrast: {e}") - raise RuntimeError("Contrast adjustment failed") + raise RuntimeError("Contrast adjustment failed") from e def adjust_brightness(self, beta: int = 0) -> None: """ @@ -63,30 +100,28 @@ def adjust_brightness(self, beta: int = 0) -> None: :param beta: Brightness control (0-100). Default is 0. 
""" - logger.debug(f"Adjusting brightness, beta={beta}") + logger.debug(f"Adjusting brightness with beta={beta}") try: self.bgr_image = cv2.convertScaleAbs(self.bgr_image, beta=beta) - logger.info(f"Brightness adjusted successfully, beta={beta}") + logger.info(f"Brightness adjusted successfully with beta={beta}") except Exception as e: logger.error(f"Failed to adjust brightness: {e}") - raise RuntimeError("Brightness adjustment failed") + raise RuntimeError("Brightness adjustment failed") from e def apply_sharpening(self) -> None: """ Apply sharpening to the image. - - This method uses a kernel to sharpen the image. """ - logger.debug("Applying sharpening") + logger.debug("Applying sharpening filter.") try: kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]]) self.bgr_image = cv2.filter2D(self.bgr_image, -1, kernel) - logger.info("Image sharpening applied successfully") + logger.info("Sharpening applied successfully.") except Exception as e: logger.error(f"Failed to apply sharpening: {e}") - raise RuntimeError("Sharpening application failed") + raise RuntimeError("Sharpening application failed") from e def apply_gamma_correction(self, gamma: float = 1.0) -> None: """ @@ -94,16 +129,17 @@ def apply_gamma_correction(self, gamma: float = 1.0) -> None: :param gamma: Gamma value (0.1-3.0). Default is 1.0. 
""" - logger.debug(f"Applying gamma correction, gamma={gamma}") + logger.debug(f"Applying gamma correction with gamma={gamma}") try: inv_gamma = 1.0 / gamma - table = np.array([((i / 255.0) ** inv_gamma) * 255 - for i in np.arange(256)]).astype("uint8") + table = np.array([(i / 255.0) ** inv_gamma * + 255 for i in np.arange(256)]).astype("uint8") self.bgr_image = cv2.LUT(self.bgr_image, table) - logger.info(f"Gamma correction applied successfully, gamma={gamma}") + logger.info( + f"Gamma correction applied successfully with gamma={gamma}") except Exception as e: logger.error(f"Failed to apply gamma correction: {e}") - raise RuntimeError("Gamma correction application failed") + raise RuntimeError("Gamma correction application failed") from e def rotate_image(self, angle: float, center: Optional[Tuple[int, int]] = None, scale: float = 1.0) -> None: """ @@ -113,17 +149,17 @@ def rotate_image(self, angle: float, center: Optional[Tuple[int, int]] = None, s :param center: Center of rotation. If None, the center of the image is used. :param scale: Scale factor. Default is 1.0. """ - logger.debug(f"Rotating image, angle={angle}, scale={scale}") + logger.debug(f"Rotating image by {angle} degrees with scale={scale}") try: (h, w) = self.bgr_image.shape[:2] if center is None: center = (w // 2, h // 2) M = cv2.getRotationMatrix2D(center, angle, scale) self.bgr_image = cv2.warpAffine(self.bgr_image, M, (w, h)) - logger.info(f"Image rotated successfully, angle={angle}, scale={scale}") + logger.info(f"Image rotated successfully by {angle} degrees.") except Exception as e: logger.error(f"Failed to rotate image: {e}") - raise RuntimeError("Image rotation failed") + raise RuntimeError("Image rotation failed") from e def resize_image(self, width: Optional[int] = None, height: Optional[int] = None, inter: int = cv2.INTER_LINEAR) -> None: """ @@ -133,25 +169,27 @@ def resize_image(self, width: Optional[int] = None, height: Optional[int] = None :param height: New height of the image. 
If None, the height is scaled proportionally. :param inter: Interpolation method. Default is cv2.INTER_LINEAR. """ - logger.debug(f"Resizing image, width={width}, height={height}") + logger.debug(f"Resizing image with width={width}, height={height}") try: (h, w) = self.bgr_image.shape[:2] if width is None and height is None: - logger.warning("Both width and height are not specified, skipping resize") + logger.warning( + "No resizing performed as both width and height are None.") return if width is None: - r = height / float(h) - dim = (int(w * r), height) + ratio = height / float(h) + dim = (int(w * ratio), height) elif height is None: - r = width / float(w) - dim = (width, int(h * r)) + ratio = width / float(w) + dim = (width, int(h * ratio)) else: dim = (width, height) - self.bgr_image = cv2.resize(self.bgr_image, dim, interpolation=inter) - logger.info(f"Image resized successfully, new dimensions={dim}") + self.bgr_image = cv2.resize( + self.bgr_image, dim, interpolation=inter) + logger.info(f"Image resized successfully to dimensions: {dim}") except Exception as e: logger.error(f"Failed to resize image: {e}") - raise RuntimeError("Image resize failed") + raise RuntimeError("Image resize failed") from e def adjust_color_balance(self, red: float = 1.0, green: float = 1.0, blue: float = 1.0) -> None: """ @@ -161,53 +199,79 @@ def adjust_color_balance(self, red: float = 1.0, green: float = 1.0, blue: float :param green: Green channel multiplier. Default is 1.0. :param blue: Blue channel multiplier. Default is 1.0. 
""" - logger.debug(f"Adjusting color balance, red={red}, green={green}, blue={blue}") + logger.debug( + f"Adjusting color balance with R={red}, G={green}, B={blue}") try: b, g, r = cv2.split(self.bgr_image) r = cv2.convertScaleAbs(r, alpha=red) g = cv2.convertScaleAbs(g, alpha=green) b = cv2.convertScaleAbs(b, alpha=blue) self.bgr_image = cv2.merge([b, g, r]) - logger.info(f"Color balance adjusted successfully, red={red}, green={green}, blue={blue}") + logger.info( + f"Color balance adjusted successfully with R={red}, G={green}, B={blue}") except Exception as e: logger.error(f"Failed to adjust color balance: {e}") - raise RuntimeError("Color balance adjustment failed") + raise RuntimeError("Color balance adjustment failed") from e - def save_image(self, output_path: str, file_format: str = "png", jpeg_quality: int = 90) -> None: + def apply_blur(self, ksize: Tuple[int, int] = (5, 5), method: str = "gaussian") -> None: """ - Save the image to the specified path. + Apply blur to the image. - :param output_path: Path to save the image. - :param file_format: Format to save the image. Default is "png". - :param jpeg_quality: Quality for JPEG format. Default is 90. + :param ksize: Kernel size for the blur. Default is (5, 5). + :param method: Blurring method ('gaussian', 'median', 'bilateral'). Default is 'gaussian'. 
""" - logger.debug(f"Saving image, path={output_path}, format={file_format}, JPEG quality={jpeg_quality}") + logger.debug(f"Applying {method} blur with kernel size={ksize}") try: - if file_format.lower() in ["jpg", "jpeg"]: - cv2.imwrite(output_path, self.bgr_image, [cv2.IMWRITE_JPEG_QUALITY, jpeg_quality]) - logger.info(f"Image saved successfully as JPEG: {output_path}, quality={jpeg_quality}") + if method.lower() == "gaussian": + self.bgr_image = cv2.GaussianBlur(self.bgr_image, ksize, 0) + elif method.lower() == "median": + self.bgr_image = cv2.medianBlur(self.bgr_image, ksize[0]) + elif method.lower() == "bilateral": + self.bgr_image = cv2.bilateralFilter( + self.bgr_image, d=9, sigmaColor=75, sigmaSpace=75) else: - cv2.imwrite(output_path, self.bgr_image) - logger.info(f"Image saved successfully as {file_format.upper()}: {output_path}") + logger.error(f"Unsupported blur method: {method}") + raise ValueError(f"Unsupported blur method: {method}") + logger.info( + f"{method.capitalize()} blur applied successfully with kernel size={ksize}") except Exception as e: - logger.error(f"Failed to save image: {e}") - raise IOError("Image save failed") + logger.error(f"Failed to apply {method} blur: {e}") + raise RuntimeError( + f"{method.capitalize()} blur application failed") from e - def show_image(self, window_name: str = "Image") -> None: + def histogram_equalization(self) -> None: """ - Display the processed image. + Apply histogram equalization to the image. + Enhances the contrast of the image using histogram equalization. 
+ """ + logger.debug("Applying histogram equalization.") + try: + if len(self.bgr_image.shape) == 2: + self.bgr_image = cv2.equalizeHist(self.bgr_image) + logger.debug( + "Applied histogram equalization on grayscale image.") + else: + ycrcb = cv2.cvtColor(self.bgr_image, cv2.COLOR_BGR2YCrCb) + ycrcb[:, :, 0] = cv2.equalizeHist(ycrcb[:, :, 0]) + self.bgr_image = cv2.cvtColor(ycrcb, cv2.COLOR_YCrCb2BGR) + logger.debug( + "Applied histogram equalization on Y channel of color image.") + logger.info("Histogram equalization applied successfully.") + except Exception as e: + logger.error(f"Failed to apply histogram equalization: {e}") + raise RuntimeError("Histogram equalization failed") from e - :param window_name: Name of the window in which to display the image. Default is "Image". + def convert_to_grayscale(self) -> None: + """ + Convert the image to grayscale. """ - logger.debug(f"Displaying image window, name={window_name}") + logger.debug("Converting image to grayscale.") try: - cv2.imshow(window_name, self.bgr_image) - cv2.waitKey(0) - cv2.destroyAllWindows() - logger.info("Image displayed successfully") + self.bgr_image = cv2.cvtColor(self.bgr_image, cv2.COLOR_BGR2GRAY) + logger.info("Image converted to grayscale successfully.") except Exception as e: - logger.error(f"Failed to display image: {e}") - raise RuntimeError("Image display failed") + logger.error(f"Failed to convert image to grayscale: {e}") + raise RuntimeError("Grayscale conversion failed") from e def get_bgr_image(self) -> np.ndarray: """ @@ -215,84 +279,82 @@ def get_bgr_image(self) -> np.ndarray: :return: Processed BGR image as a NumPy array. """ - logger.debug("Getting processed BGR image") + logger.debug("Retrieving processed BGR image.") return self.bgr_image - def reset(self) -> None: - """ - Reset the image to its original state. - - This method resets the image to the state it was in after initial processing. 
- """ - logger.debug("Resetting image to original state") - try: - self.bgr_image = cv2.cvtColor(self.rgb_image, cv2.COLOR_RGB2BGR) - logger.info("Image reset successfully") - except Exception as e: - logger.error(f"Failed to reset image: {e}") - raise RuntimeError("Image reset failed") - - def convert_to_grayscale(self) -> None: + def to_rgb_image(self) -> np.ndarray: """ - Convert the image to grayscale. + Return the processed RGB image. - This method converts the processed BGR image to a grayscale image. + :return: Processed RGB image as a NumPy array. """ - logger.debug("Converting image to grayscale") + logger.debug("Converting BGR image to RGB format.") try: - self.bgr_image = cv2.cvtColor(self.bgr_image, cv2.COLOR_BGR2GRAY) - logger.info("Image converted to grayscale successfully") + rgb_image = cv2.cvtColor(self.bgr_image, cv2.COLOR_BGR2RGB) + logger.info("Successfully converted BGR image to RGB format.") + return rgb_image except Exception as e: - logger.error(f"Failed to convert image to grayscale: {e}") - raise RuntimeError("Grayscale conversion failed") + logger.error(f"Failed to convert BGR to RGB: {e}") + raise RuntimeError("RGB conversion failed") from e - def apply_blur(self, ksize: Tuple[int, int] = (5, 5)) -> None: + def save_image(self, output_path: Path, file_format: ImageFormat = ImageFormat.PNG, jpeg_quality: int = 90) -> None: """ - Apply blur to the image. + Save the image to the specified path. - :param ksize: Kernel size for the blur. Default is (5, 5). + :param output_path: Path to save the image. + :param file_format: Format to save the image. Default is PNG. + :param jpeg_quality: Quality for JPEG format (0-100). Default is 90. 
""" - logger.debug(f"Applying blur, kernel size={ksize}") + logger.debug( + f"Saving image to {output_path} with format={file_format.value} and JPEG quality={jpeg_quality}") try: - self.bgr_image = cv2.GaussianBlur(self.bgr_image, ksize, 0) - logger.info(f"Blur applied successfully, kernel size={ksize}") + if file_format in [ImageFormat.JPEG, ImageFormat.PNG, ImageFormat.TIFF, ImageFormat.BMP]: + if file_format == ImageFormat.JPEG: + cv2.imwrite(str(output_path), self.bgr_image, [ + cv2.IMWRITE_JPEG_QUALITY, jpeg_quality]) + logger.info( + f"Image saved successfully as JPEG: {output_path} with quality={jpeg_quality}") + else: + cv2.imwrite(str(output_path), self.bgr_image) + logger.info( + f"Image saved successfully as {file_format.value.upper()}: {output_path}") + else: + logger.error(f"Unsupported file format: {file_format}") + raise ValueError(f"Unsupported file format: {file_format}") except Exception as e: - logger.error(f"Failed to apply blur: {e}") - raise RuntimeError("Blur application failed") + logger.error(f"Failed to save image: {e}") + raise IOError("Image save failed") from e - def histogram_equalization(self) -> None: + def show_image(self, window_name: str = "Image", delay: int = 0) -> None: """ - Apply histogram equalization to the image. + Display the processed image. - This method enhances the contrast of the image using histogram equalization. + :param window_name: Name of the window in which to display the image. Default is "Image". + :param delay: Delay in milliseconds. 0 means wait indefinitely. Default is 0. 
""" - logger.debug("Applying histogram equalization") + logger.debug( + f"Displaying image in window: {window_name} with delay={delay}") try: - if len(self.bgr_image.shape) == 2: - self.bgr_image = cv2.equalizeHist(self.bgr_image) - else: - ycrcb = cv2.cvtColor(self.bgr_image, cv2.COLOR_BGR2YCrCb) - ycrcb[:, :, 0] = cv2.equalizeHist(ycrcb[:, :, 0]) - self.bgr_image = cv2.cvtColor(ycrcb, cv2.COLOR_YCrCb2BGR) - logger.info("Histogram equalization applied successfully") + cv2.imshow(window_name, self.bgr_image) + logger.info("Image display window opened.") + cv2.waitKey(delay) + cv2.destroyAllWindows() + logger.debug("Image display window closed.") except Exception as e: - logger.error(f"Failed to apply histogram equalization: {e}") - raise RuntimeError("Histogram equalization application failed") + logger.error(f"Failed to display image: {e}") + raise RuntimeError("Image display failed") from e - def to_rgb_image(self) -> np.ndarray: + def reset(self) -> None: """ - Return the processed RGB image. - - :return: Processed RGB image as a NumPy array. + Reset the image to its original state after post-processing. """ - logger.debug("Converting and getting processed RGB image") + logger.debug("Resetting image to original post-processed state.") try: - rgb_image = cv2.cvtColor(self.bgr_image, cv2.COLOR_BGR2RGB) - logger.info("Successfully converted to RGB image") - return rgb_image + self.bgr_image = cv2.cvtColor(self.rgb_image, cv2.COLOR_RGB2BGR) + logger.info("Image reset successfully.") except Exception as e: - logger.error(f"Failed to convert to RGB image: {e}") - raise RuntimeError("RGB conversion failed") + logger.error(f"Failed to reset image: {e}") + raise RuntimeError("Image reset failed") from e def to_raw_image(self) -> rawpy.RawPy: """ @@ -300,50 +362,180 @@ def to_raw_image(self) -> rawpy.RawPy: :return: Original RAW image as a rawpy.RawPy object. 
""" - logger.debug("Getting original RAW image") + logger.debug("Retrieving original RAW image.") return self.raw -# Usage example -if __name__ == "__main__": - # Initialize RAW image processor - processor = RawImageProcessor('path_to_your_image.raw') - - # Adjust contrast - processor.adjust_contrast(alpha=1.3) - - # Adjust brightness - processor.adjust_brightness(beta=20) - - # Apply sharpening - processor.apply_sharpening() - # Apply gamma correction - processor.apply_gamma_correction(gamma=1.2) - - # Rotate image - processor.rotate_image(angle=45) - - # Resize image - processor.resize_image(width=800) +def parse_cli_arguments() -> argparse.Namespace: + """ + Parse command-line arguments for the RAW image processor. + + :return: Parsed arguments. + """ + parser = argparse.ArgumentParser(description="RAW Image Processor") + subparsers = parser.add_subparsers( + dest='command', required=True, help='Available commands') + + # Subparser for processing a single image + process_parser = subparsers.add_parser( + 'process', help='Process a single RAW image.') + process_parser.add_argument( + '--input', type=Path, required=True, help='Path to the input RAW image.') + process_parser.add_argument( + '--output', type=Path, required=True, help='Path to save the processed image.') + process_parser.add_argument('--format', type=ImageFormat, default=ImageFormat.PNG, + choices=ImageFormat, help='Output image format.') + process_parser.add_argument( + '--jpeg_quality', type=int, default=90, help='JPEG quality (if applicable).') + process_parser.add_argument('--actions', nargs='+', default=[], choices=[ + 'adjust_contrast', 'adjust_brightness', 'apply_sharpening', 'apply_gamma_correction', + 'rotate_image', 'resize_image', 'adjust_color_balance', 'apply_blur', + 'histogram_equalization', 'convert_to_grayscale' + ], help='Image processing actions to perform.') + + # Subparser for batch processing + batch_parser = subparsers.add_parser( + 'batch', help='Batch process RAW images in a 
directory.') + batch_parser.add_argument( + '--input_dir', type=Path, required=True, help='Directory containing RAW images.') + batch_parser.add_argument( + '--output_dir', type=Path, required=True, help='Directory to save processed images.') + batch_parser.add_argument('--format', type=ImageFormat, default=ImageFormat.PNG, + choices=ImageFormat, help='Output image format.') + batch_parser.add_argument( + '--jpeg_quality', type=int, default=90, help='JPEG quality (if applicable).') + batch_parser.add_argument('--actions', nargs='+', default=[], choices=[ + 'adjust_contrast', 'adjust_brightness', 'apply_sharpening', 'apply_gamma_correction', + 'rotate_image', 'resize_image', 'adjust_color_balance', 'apply_blur', + 'histogram_equalization', 'convert_to_grayscale' + ], help='Image processing actions to perform.') + + return parser.parse_args() + + +def process_single_image(args: argparse.Namespace) -> None: + """ + Process a single RAW image based on the provided arguments. + + :param args: Parsed command-line arguments. 
+ """ + processor = RawImageProcessor(raw_path=args.input) + + # Execute actions in the order provided + for action in args.actions: + if action == 'adjust_contrast': + processor.adjust_contrast(alpha=1.2) + elif action == 'adjust_brightness': + processor.adjust_brightness(beta=30) + elif action == 'apply_sharpening': + processor.apply_sharpening() + elif action == 'apply_gamma_correction': + processor.apply_gamma_correction(gamma=1.3) + elif action == 'rotate_image': + processor.rotate_image(angle=30) + elif action == 'resize_image': + processor.resize_image(width=1024) + elif action == 'adjust_color_balance': + processor.adjust_color_balance(red=1.1, green=1.0, blue=0.9) + elif action == 'apply_blur': + processor.apply_blur(ksize=(5, 5), method="gaussian") + elif action == 'histogram_equalization': + processor.histogram_equalization() + elif action == 'convert_to_grayscale': + processor.convert_to_grayscale() + else: + logger.warning(f"Unknown action: {action}") + + # Save the processed image + processor.save_image(output_path=args.output, + file_format=args.format, jpeg_quality=args.jpeg_quality) + logger.info(f"Processed image saved to {args.output}") + + +def batch_process_images(args: argparse.Namespace) -> None: + """ + Batch process RAW images in the specified directory. + + :param args: Parsed command-line arguments. 
+ """ + logger.info( + f"Starting batch processing from {args.input_dir} to {args.output_dir}") + if not args.input_dir.exists() or not args.input_dir.is_dir(): + logger.error( + f"Input directory does not exist or is not a directory: {args.input_dir}") + raise NotADirectoryError( + f"Input directory does not exist or is not a directory: {args.input_dir}") + + args.output_dir.mkdir(parents=True, exist_ok=True) + raw_files = list(args.input_dir.glob( + '*.raw')) + list(args.input_dir.glob('*.CR2')) + list(args.input_dir.glob('*.NEF')) + + if not raw_files: + logger.warning(f"No RAW images found in directory: {args.input_dir}") + return + + def process_image(file_path: Path): + """ + Process a single image file. + """ + try: + output_file = args.output_dir / \ + f"{file_path.stem}.{args.format.value}" + processor = RawImageProcessor(raw_path=file_path) + + # Execute actions in the order provided + for action in args.actions: + if action == 'adjust_contrast': + processor.adjust_contrast(alpha=1.2) + elif action == 'adjust_brightness': + processor.adjust_brightness(beta=30) + elif action == 'apply_sharpening': + processor.apply_sharpening() + elif action == 'apply_gamma_correction': + processor.apply_gamma_correction(gamma=1.3) + elif action == 'rotate_image': + processor.rotate_image(angle=30) + elif action == 'resize_image': + processor.resize_image(width=1024) + elif action == 'adjust_color_balance': + processor.adjust_color_balance( + red=1.1, green=1.0, blue=0.9) + elif action == 'apply_blur': + processor.apply_blur(ksize=(5, 5), method="gaussian") + elif action == 'histogram_equalization': + processor.histogram_equalization() + elif action == 'convert_to_grayscale': + processor.convert_to_grayscale() + else: + logger.warning(f"Unknown action: {action}") + + # Save the processed image + processor.save_image( + output_path=output_file, file_format=args.format, jpeg_quality=args.jpeg_quality) + logger.info(f"Processed image saved to {output_file}") + except 
Exception as e: + logger.error(f"Failed to process image {file_path}: {e}") - # Adjust color balance - processor.adjust_color_balance(red=1.1, green=1.0, blue=0.9) + with concurrent.futures.ThreadPoolExecutor() as executor: + executor.map(process_image, raw_files) - # Apply blur - processor.apply_blur(ksize=(5, 5)) + logger.info("Batch processing completed successfully.") - # Apply histogram equalization - processor.histogram_equalization() - # Display processed image - processor.show_image() +def main(): + """ + Main function to parse arguments and execute appropriate functions. + """ + args = parse_cli_arguments() - # Save processed image - processor.save_image('output_image.png') + if args.command == 'process': + process_single_image(args) + elif args.command == 'batch': + batch_process_images(args) + else: + logger.error(f"Unknown command: {args.command}") + sys.exit(1) - # Reset image - processor.reset() - # Perform other processing and save as JPEG - processor.adjust_contrast(alpha=1.1) - processor.save_image('output_image.jpg', file_format="jpg", jpeg_quality=85) \ No newline at end of file +if __name__ == "__main__": + main() diff --git a/modules/lithium.pyimage/image/resample/resample.py b/modules/lithium.pyimage/image/resample/resample.py index ca1ed000..e931e044 100644 --- a/modules/lithium.pyimage/image/resample/resample.py +++ b/modules/lithium.pyimage/image/resample/resample.py @@ -1,177 +1,460 @@ -import concurrent +import concurrent.futures +from pathlib import Path +from typing import Optional, Tuple, Literal, List +from enum import Enum import cv2 from loguru import logger -from matplotlib.path import Path -from PIL import Image -from typing import Optional, Tuple, Literal - - -def resample_image(input_image_path: str, - output_image_path: str, - width: Optional[int] = None, - height: Optional[int] = None, - scale: Optional[float] = None, - resolution: Optional[Tuple[int, int]] = None, - interpolation: Literal[cv2.INTER_LINEAR, - cv2.INTER_CUBIC, 
cv2.INTER_NEAREST] = cv2.INTER_LINEAR, - preserve_aspect_ratio: bool = True, - crop_area: Optional[Tuple[int, int, int, int]] = None, - edge_detection: bool = False, - color_space: Literal['BGR', 'GRAY', 'HSV', 'RGB'] = 'BGR', - batch_mode: bool = False, - output_format: str = 'jpg', - ) -> None: - """ - Resamples an image with given dimensions, scale, resolution, and additional processing options. - - :param input_image_path: Path to the input image (or directory in batch mode). - :param output_image_path: Path to save the resampled image(s). - :param width: Desired width in pixels. - :param height: Desired height in pixels. - :param scale: Scale factor for resizing (e.g., 0.5 for half size, 2.0 for double size). - :param resolution: Tuple of horizontal and vertical resolution (in dpi). - :param interpolation: Interpolation method (e.g., cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_NEAREST). - :param preserve_aspect_ratio: Whether to preserve the aspect ratio of the original image. - :param crop_area: Tuple defining the crop area (x, y, w, h). - :param edge_detection: Whether to apply edge detection before resizing. - :param color_space: Color space to convert the image to (e.g., 'GRAY', 'HSV', 'RGB'). - :param add_watermark: Whether to add a watermark to the image. - :param watermark_text: The text to be used as a watermark. - :param watermark_position: Position tuple (x, y) for the watermark. - :param watermark_opacity: Opacity level for the watermark (0.0 to 1.0). - :param batch_mode: Whether to process multiple images in a directory. - :param output_format: Output image format (e.g., 'jpg', 'png', 'tiff'). - :param brightness: Brightness adjustment factor (1.0 = no change). - :param contrast: Contrast adjustment factor (1.0 = no change). - :param sharpen: Whether to apply sharpening to the image. - :param rotate_angle: Angle to rotate the image (in degrees). 
- """ - logger.info(f"Starting resampling process for: {input_image_path}") +from PIL import Image, ImageDraw, ImageFont +import numpy as np +import argparse +import sys + + +# Configure Loguru logger with file rotation and different log levels +logger.remove() # Remove the default logger +logger.add(sys.stderr, level="INFO", + format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}") +logger.add("resample_processor.log", rotation="10 MB", retention="10 days", + level="DEBUG", format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}") + + +class ImageFormat(Enum): + PNG = "png" + JPEG = "jpg" + TIFF = "tiff" + BMP = "bmp" + + @staticmethod + def list(): + return list(map(lambda c: c.value, ImageFormat)) + - def process_image(image_path: Path, output_path: Path): - logger.debug(f"Processing image: {image_path}") - img = cv2.imread(str(image_path)) +@dataclass +class Resampler: + input_image_path: Path + output_image_path: Path + width: Optional[int] = None + height: Optional[int] = None + scale: Optional[float] = None + resolution: Optional[Tuple[int, int]] = None + interpolation: Literal[cv2.INTER_LINEAR, + cv2.INTER_CUBIC, cv2.INTER_NEAREST] = cv2.INTER_LINEAR + preserve_aspect_ratio: bool = True + crop_area: Optional[Tuple[int, int, int, int]] = None + edge_detection: bool = False + color_space: Literal['BGR', 'GRAY', 'HSV', 'RGB'] = 'BGR' + brightness: float = 1.0 + contrast: float = 1.0 + sharpen: bool = False + rotate_angle: float = 0.0 + blur: Optional[Tuple[int, int]] = None + blur_method: Literal['gaussian', 'median', 'bilateral'] = 'gaussian' + histogram_equalization: bool = False + grayscale: bool = False + output_format: ImageFormat = ImageFormat.JPEG + jpeg_quality: int = 90 + + def process(self) -> None: + """ + Process the image with the specified parameters. 
+ """ + logger.info(f"Processing image: {self.input_image_path}") + img = cv2.imread(str(self.input_image_path)) if img is None: - logger.error(f"Cannot load image from {image_path}") - raise ValueError(f"Cannot load image from {image_path}") + logger.error(f"Cannot load image from {self.input_image_path}") + raise ValueError(f"Cannot load image from {self.input_image_path}") original_height, original_width = img.shape[:2] logger.debug( f"Original dimensions: width={original_width}, height={original_height}") # Crop if needed - if crop_area: - x, y, w, h = crop_area + if self.crop_area: + x, y, w, h = self.crop_area logger.debug(f"Cropping image: x={x}, y={y}, w={w}, h={h}") img = img[y:y+h, x:x+w] # Edge detection - if edge_detection: + if self.edge_detection: logger.debug("Applying edge detection.") img = cv2.Canny(img, 100, 200) # Convert color space if needed - img = change_color_space(img, color_space) + img = self.change_color_space(img, self.color_space) - # Calculate new dimensions - if scale: - new_width = int(original_width * scale) - new_height = int(original_height * scale) + # Adjust brightness and contrast + if self.brightness != 1.0 or self.contrast != 1.0: logger.debug( - f"Scaling image by factor: {scale}, new dimensions: width={new_width}, height={new_height}") - else: - if width and height: - new_width = width - new_height = height - elif width: - new_width = width - new_height = int( - (width / original_width) * original_height) if preserve_aspect_ratio else height - logger.debug( - f"Setting width to {width}, calculated height: {new_height}") - elif height: - new_height = height - new_width = int((height / original_height) * - original_width) if preserve_aspect_ratio else width - logger.debug( - f"Setting height to {height}, calculated width: {new_width}") - else: - new_width, new_height = original_width, original_height - logger.debug( - f"No scaling parameters provided, keeping original dimensions.") + f"Adjusting brightness by 
{self.brightness} and contrast by {self.contrast}") + img = self.adjust_brightness_contrast( + img, self.brightness, self.contrast) - # Perform resizing + # Apply sharpening + if self.sharpen: + logger.debug("Applying sharpening.") + img = self.apply_sharpening(img) + + # Apply blur if specified + if self.blur: + logger.debug( + f"Applying {self.blur_method} blur with kernel size={self.blur}") + img = self.apply_blur(img, self.blur, self.blur_method) + + # Histogram equalization + if self.histogram_equalization: + logger.debug("Applying histogram equalization.") + img = self.histogram_equalization_func(img) + + # Convert to grayscale if needed + if self.grayscale: + logger.debug("Converting image to grayscale.") + img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + + # Calculate new dimensions + new_width, new_height = self.calculate_new_dimensions( + original_width, original_height, img) logger.debug( - f"Resizing image to width={new_width}, height={new_height} with interpolation={interpolation}") + f"Resizing image to width={new_width}, height={new_height} with interpolation={self.interpolation}") + + # Perform resizing resized_img = cv2.resize( - img, (new_width, new_height), interpolation=interpolation) + img, (new_width, new_height), interpolation=self.interpolation) + + # Rotate image if needed + if self.rotate_angle != 0.0: + logger.debug(f"Rotating image by {self.rotate_angle} degrees.") + resized_img = self.rotate_image(resized_img, self.rotate_angle) # Convert back to BGR if color_space was changed to RGB for saving with OpenCV - if color_space == 'RGB' and len(resized_img.shape) == 3: + if self.color_space == 'RGB' and len(resized_img.shape) == 3: logger.debug("Converting image from RGB back to BGR for saving.") resized_img = cv2.cvtColor(resized_img, cv2.COLOR_RGB2BGR) - elif color_space == 'HSV' and len(resized_img.shape) == 3: + elif self.color_space == 'HSV' and len(resized_img.shape) == 3: logger.debug("Converting image from HSV back to BGR for saving.") 
resized_img = cv2.cvtColor(resized_img, cv2.COLOR_HSV2BGR) # Save the image with specified resolution if provided - if resolution: - logger.debug(f"Saving image with resolution: {resolution}") - pil_img = Image.fromarray(cv2.cvtColor(resized_img, cv2.COLOR_BGR2RGB) if len( - resized_img.shape) == 3 else resized_img) - pil_img.save(str(output_path), dpi=resolution, - format=output_format.upper()) + if self.resolution: + logger.debug( + f"Saving image with resolution: {self.resolution} DPI") + pil_img = Image.fromarray( + cv2.cvtColor(resized_img, cv2.COLOR_BGR2RGB) if len(resized_img.shape) == 3 else resized_img) + pil_img.save(str(self.output_image_path), dpi=self.resolution, + format=self.output_format.name.upper()) else: logger.debug( - f"Saving image without changing resolution to format: {output_format}") - cv2.imwrite(str(output_path), resized_img) + f"Saving image without changing resolution to format: {self.output_format.value}") + if self.output_format == ImageFormat.JPEG: + cv2.imwrite(str(self.output_image_path), resized_img, [ + cv2.IMWRITE_JPEG_QUALITY, self.jpeg_quality]) + else: + cv2.imwrite(str(self.output_image_path), resized_img) + + logger.info(f"Image saved successfully to: {self.output_image_path}") + + @staticmethod + def change_color_space(img: np.ndarray, color_space: str) -> np.ndarray: + """ + Change the color space of the image. + + :param img: Input image in BGR format. + :param color_space: Target color space. + :return: Image in the target color space. + """ + if color_space == 'GRAY': + return cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + elif color_space == 'HSV': + return cv2.cvtColor(img, cv2.COLOR_BGR2HSV) + elif color_space == 'RGB': + return cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + return img + + @staticmethod + def adjust_brightness_contrast(img: np.ndarray, brightness: float, contrast: float) -> np.ndarray: + """ + Adjust the brightness and contrast of the image. + + :param img: Input image. + :param brightness: Brightness factor. 
+        :param contrast: Contrast factor.
+        :return: Adjusted image.
+        """
+        # A factor of 1.0 maps to beta 0, so defaults leave the image unchanged
+        return cv2.convertScaleAbs(img, alpha=contrast, beta=(brightness - 1.0) * 50)
+
+    @staticmethod
+    def apply_sharpening(img: np.ndarray) -> np.ndarray:
+        """
+        Apply sharpening to the image.
+
+        :param img: Input image.
+        :return: Sharpened image.
+        """
+        kernel = np.array([[0, -1, 0],
+                           [-1, 5, -1],
+                           [0, -1, 0]])
+        return cv2.filter2D(img, -1, kernel)
+
+    @staticmethod
+    def apply_blur(img: np.ndarray, ksize: Tuple[int, int], method: str) -> np.ndarray:
+        """
+        Apply blur to the image.
-    logger.info(f"Image saved successfully to: {output_path}")
+
+        :param img: Input image.
+        :param ksize: Kernel size for the blur.
+        :param method: Blurring method ('gaussian', 'median', 'bilateral').
+        :return: Blurred image.
+        """
+        if method.lower() == "gaussian":
+            return cv2.GaussianBlur(img, ksize, 0)
+        elif method.lower() == "median":
+            return cv2.medianBlur(img, ksize[0])
+        elif method.lower() == "bilateral":
+            return cv2.bilateralFilter(img, d=9, sigmaColor=75, sigmaSpace=75)
+        else:
+            logger.error(f"Unsupported blur method: {method}")
+            raise ValueError(f"Unsupported blur method: {method}")
+
+    @staticmethod
+    def histogram_equalization_func(img: np.ndarray) -> np.ndarray:
+        """
+        Apply histogram equalization to the image.
+
+        :param img: Input image.
+        :return: Image after histogram equalization.
+        """
+        if len(img.shape) == 2:
+            return cv2.equalizeHist(img)
+        else:
+            ycrcb = cv2.cvtColor(img, cv2.COLOR_BGR2YCrCb)
+            ycrcb[:, :, 0] = cv2.equalizeHist(ycrcb[:, :, 0])
+            return cv2.cvtColor(ycrcb, cv2.COLOR_YCrCb2BGR)
+
+    @staticmethod
+    def rotate_image(img: np.ndarray, angle: float) -> np.ndarray:
+        """
+        Rotate the image by the given angle.
+
+        :param img: Input image.
+        :param angle: Angle to rotate the image.
+        :return: Rotated image.
+ """ + (h, w) = img.shape[:2] + center = (w // 2, h // 2) + M = cv2.getRotationMatrix2D(center, angle, 1.0) + return cv2.warpAffine(img, M, (w, h)) + + +def parse_arguments() -> argparse.Namespace: + """ + Parse command-line arguments for the resample script. + + :return: Parsed arguments. + """ + parser = argparse.ArgumentParser( + description="Image Resampling and Processing Tool") + subparsers = parser.add_subparsers( + dest='command', required=True, help='Available commands') + + # Subparser for single image processing + process_parser = subparsers.add_parser( + 'process', help='Process a single image.') + process_parser.add_argument( + '--input', type=Path, required=True, help='Path to the input image.') + process_parser.add_argument( + '--output', type=Path, required=True, help='Path to save the processed image.') + process_parser.add_argument( + '--width', type=int, help='Desired width in pixels.') + process_parser.add_argument( + '--height', type=int, help='Desired height in pixels.') + process_parser.add_argument( + '--scale', type=float, help='Scale factor for resizing (e.g., 0.5 for half size).') + process_parser.add_argument('--resolution', type=int, nargs=2, metavar=('X_RES', 'Y_RES'), + help='Output resolution in DPI (e.g., 300 300).') + process_parser.add_argument('--interpolation', type=str, choices=['linear', 'cubic', 'nearest'], + default='linear', help='Interpolation method.') + process_parser.add_argument('--preserve_aspect_ratio', action='store_true', + help='Preserve aspect ratio when resizing.') + process_parser.add_argument('--crop_area', type=int, nargs=4, metavar=('X', 'Y', 'W', 'H'), + help='Crop area as four integers: x y width height.') + process_parser.add_argument('--edge_detection', action='store_true', + help='Apply edge detection before resizing.') + process_parser.add_argument('--color_space', type=str, choices=['BGR', 'GRAY', 'HSV', 'RGB'], + default='BGR', help='Color space to convert the image to.') + 
process_parser.add_argument('--brightness', type=float, default=1.0, + help='Brightness adjustment factor (1.0 = no change).') + process_parser.add_argument('--contrast', type=float, default=1.0, + help='Contrast adjustment factor (1.0 = no change).') + process_parser.add_argument('--sharpen', action='store_true', + help='Apply sharpening to the image.') + process_parser.add_argument('--rotate_angle', type=float, default=0.0, + help='Angle to rotate the image in degrees.') + process_parser.add_argument('--blur', type=int, nargs=2, metavar=('WIDTH', 'HEIGHT'), + help='Kernel size for blurring.') + process_parser.add_argument('--blur_method', type=str, choices=['gaussian', 'median', 'bilateral'], + default='gaussian', help='Method to use for blurring.') + process_parser.add_argument('--histogram_equalization', action='store_true', + help='Apply histogram equalization to enhance contrast.') + process_parser.add_argument('--grayscale', action='store_true', + help='Convert image to grayscale.') + process_parser.add_argument('--output_format', type=str, choices=ImageFormat.list(), + default='jpg', help='Output image format.') + process_parser.add_argument('--jpeg_quality', type=int, default=90, + help='JPEG quality (only applicable if output format is JPEG).') + + # Subparser for batch processing + batch_parser = subparsers.add_parser( + 'batch', help='Batch process multiple images.') + batch_parser.add_argument( + '--input_dir', type=Path, required=True, help='Directory containing input images.') + batch_parser.add_argument( + '--output_dir', type=Path, required=True, help='Directory to save processed images.') + batch_parser.add_argument( + '--width', type=int, help='Desired width in pixels.') + batch_parser.add_argument('--height', type=int, + help='Desired height in pixels.') + batch_parser.add_argument( + '--scale', type=float, help='Scale factor for resizing (e.g., 0.5 for half size).') + batch_parser.add_argument('--resolution', type=int, nargs=2, 
metavar=('X_RES', 'Y_RES'), + help='Output resolution in DPI (e.g., 300 300).') + batch_parser.add_argument('--interpolation', type=str, choices=['linear', 'cubic', 'nearest'], + default='linear', help='Interpolation method.') + batch_parser.add_argument('--preserve_aspect_ratio', action='store_true', + help='Preserve aspect ratio when resizing.') + batch_parser.add_argument('--crop_area', type=int, nargs=4, metavar=('X', 'Y', 'W', 'H'), + help='Crop area as four integers: x y width height.') + batch_parser.add_argument('--edge_detection', action='store_true', + help='Apply edge detection before resizing.') + batch_parser.add_argument('--color_space', type=str, choices=['BGR', 'GRAY', 'HSV', 'RGB'], + default='BGR', help='Color space to convert the image to.') + batch_parser.add_argument('--brightness', type=float, default=1.0, + help='Brightness adjustment factor (1.0 = no change).') + batch_parser.add_argument('--contrast', type=float, default=1.0, + help='Contrast adjustment factor (1.0 = no change).') + batch_parser.add_argument('--sharpen', action='store_true', + help='Apply sharpening to the images.') + batch_parser.add_argument('--rotate_angle', type=float, default=0.0, + help='Angle to rotate the images in degrees.') + batch_parser.add_argument('--blur', type=int, nargs=2, metavar=('WIDTH', 'HEIGHT'), + help='Kernel size for blurring.') + batch_parser.add_argument('--blur_method', type=str, choices=['gaussian', 'median', 'bilateral'], + default='gaussian', help='Method to use for blurring.') + batch_parser.add_argument('--histogram_equalization', action='store_true', + help='Apply histogram equalization to enhance contrast.') + batch_parser.add_argument('--grayscale', action='store_true', + help='Convert images to grayscale.') + batch_parser.add_argument('--output_format', type=str, choices=ImageFormat.list(), + default='jpg', help='Output image format.') + batch_parser.add_argument('--jpeg_quality', type=int, default=90, + help='JPEG quality (only 
applicable if output format is JPEG).') + + return parser.parse_args() + + +def process_image(resampler: Resampler) -> None: + """ + Wrapper function to process an image using the Resampler class. + + :param resampler: Resampler instance with all parameters set. + """ + try: + resampler.process() + except Exception as e: + logger.error( + f"Error processing image {resampler.input_image_path}: {e}") + + +def main(): + """ + Main function to parse arguments and execute processing. + """ + args = parse_arguments() + + if args.command == 'process': + resampler = Resampler( + input_image_path=args.input, + output_image_path=args.output, + width=args.width, + height=args.height, + scale=args.scale, + resolution=tuple(args.resolution) if args.resolution else None, + interpolation={ + 'linear': cv2.INTER_LINEAR, + 'cubic': cv2.INTER_CUBIC, + 'nearest': cv2.INTER_NEAREST + }[args.interpolation], + preserve_aspect_ratio=args.preserve_aspect_ratio, + crop_area=tuple(args.crop_area) if args.crop_area else None, + edge_detection=args.edge_detection, + color_space=args.color_space, + brightness=args.brightness, + contrast=args.contrast, + sharpen=args.sharpen, + rotate_angle=args.rotate_angle, + blur=tuple(args.blur) if args.blur else None, + blur_method=args.blur_method, + histogram_equalization=args.histogram_equalization, + grayscale=args.grayscale, + output_format=ImageFormat(args.output_format), + jpeg_quality=args.jpeg_quality + ) + process_image(resampler) - input_path = Path(input_image_path) - output_path = Path(output_image_path) + elif args.command == 'batch': + input_dir = args.input_dir + output_dir = args.output_dir - # Batch processing mode - if batch_mode: - logger.info("Batch mode enabled.") - if not input_path.is_dir(): + if not input_dir.exists() or not input_dir.is_dir(): logger.error( - "In batch mode, input_image_path must be a directory.") - raise ValueError( - "In batch mode, input_image_path must be a directory.") - if not output_path.exists(): - 
logger.debug(f"Creating output directory: {output_path}") - output_path.mkdir(parents=True, exist_ok=True) - image_files = list(input_path.glob('*')) - logger.debug(f"Found {len(image_files)} files to process.") + f"Input directory does not exist or is not a directory: {input_dir}") + sys.exit(1) + + output_dir.mkdir(parents=True, exist_ok=True) + # You might want to filter specific extensions + image_files = list(input_dir.glob('*')) + + logger.info(f"Found {len(image_files)} files to process in batch.") + + resamplers = [ + Resampler( + input_image_path=file, + output_image_path=output_dir / + f"{file.stem}.{args.output_format}", + width=args.width, + height=args.height, + scale=args.scale, + resolution=tuple(args.resolution) if args.resolution else None, + interpolation={ + 'linear': cv2.INTER_LINEAR, + 'cubic': cv2.INTER_CUBIC, + 'nearest': cv2.INTER_NEAREST + }[args.interpolation], + preserve_aspect_ratio=args.preserve_aspect_ratio, + crop_area=tuple(args.crop_area) if args.crop_area else None, + edge_detection=args.edge_detection, + color_space=args.color_space, + brightness=args.brightness, + contrast=args.contrast, + sharpen=args.sharpen, + rotate_angle=args.rotate_angle, + blur=tuple(args.blur) if args.blur else None, + blur_method=args.blur_method, + histogram_equalization=args.histogram_equalization, + grayscale=args.grayscale, + output_format=ImageFormat(args.output_format), + jpeg_quality=args.jpeg_quality + ) + for file in image_files + if file.suffix.lower() in ['.jpg', '.jpeg', '.png', '.tiff', '.bmp', '.gif'] + ] + + logger.info( + f"Starting batch processing with {len(resamplers)} images.") + with concurrent.futures.ThreadPoolExecutor() as executor: - futures = [] - for file in image_files: - if file.suffix.lower() not in ['.jpg', '.jpeg', '.png', '.tiff', '.bmp', '.gif']: - logger.warning( - f"Skipping unsupported file format: {file.name}") - continue - output_file = output_path / f"{file.stem}.{output_format}" - futures.append(executor.submit( - 
process_image, file, output_file))
-            for future in concurrent.futures.as_completed(futures):
-                try:
-                    future.result()
-                except Exception as e:
-                    logger.error(f"Error processing batch image: {e}")
-    else:
-        logger.info("Single image mode.")
-        process_image(input_path, output_path)
-
-    logger.info("Resampling process completed successfully.")
-
-
-# Example usage:
-input_path = 'star_image.png'
-output_path = './output/star_image_processed.png'
-
-# Resize image, apply edge detection, adjust brightness/contrast, add watermark, sharpen, and rotate
-resample_image(input_path, output_path, width=800, height=600, interpolation=cv2.INTER_CUBIC, resolution=(300, 300),
-               crop_area=(100, 100, 400, 400), edge_detection=True, color_space='GRAY', add_watermark=True,
-               watermark_text='Sample Watermark', watermark_position=(10, 10), watermark_opacity=0.7,
-               brightness=1.2, contrast=1.5, output_format='png', sharpen=True, rotate_angle=45)
+            executor.map(process_image, resamplers)
+
+        logger.info("Batch processing completed successfully.")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/pysrc/image/star_detection/__init__.py b/modules/lithium.pyimage/image/star_detection/__init__.py
similarity index 100%
rename from pysrc/image/star_detection/__init__.py
rename to modules/lithium.pyimage/image/star_detection/__init__.py
diff --git a/modules/lithium.pyimage/image/star_detection/clustering.py b/modules/lithium.pyimage/image/star_detection/clustering.py
new file mode 100644
index 00000000..14edbc1e
--- /dev/null
+++ b/modules/lithium.pyimage/image/star_detection/clustering.py
@@ -0,0 +1,198 @@
+import numpy as np
+from sklearn.cluster import DBSCAN
+from typing import List, Tuple, Optional, Union
+from pathlib import Path
+from loguru import logger
+import matplotlib.pyplot as plt
+import json
+
+
+class StarClustering:
+    def __init__(self, eps: float = 0.5, min_samples: int = 5, algorithm: str = 'auto') -> None:
+        """
+        Initialize the StarClustering with DBSCAN parameters.
+ + :param eps: The maximum distance between two samples for them to be considered as in the same neighborhood. + :param min_samples: The number of samples in a neighborhood for a point to be considered as a core point. + :param algorithm: The algorithm to be used by DBSCAN. Default is 'auto'. + """ + self.eps = eps + self.min_samples = min_samples + self.algorithm = algorithm + self.clusters: Optional[List[List[Tuple[int, int]]]] = None + self.centroids: Optional[List[Tuple[int, int]]] = None + + # Configure Loguru logger + logger.remove() # Remove default logger + logger.add("clustering.log", rotation="10 MB", retention="10 days", + level="DEBUG", format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}") + logger.debug( + f"Initialized StarClustering with eps={self.eps}, min_samples={self.min_samples}, algorithm='{self.algorithm}'") + + def cluster_stars(self, stars: List[Tuple[int, int]]) -> List[Tuple[int, int]]: + """ + Cluster stars using the DBSCAN algorithm and compute centroids of each cluster. + + :param stars: List of star positions as (x, y) tuples. + :return: List of clustered star centroids as (x, y) tuples. + """ + logger.info("Starting clustering process") + if not stars: + logger.warning( + "Empty star list provided. 
Returning empty centroid list.")
+            return []
+
+        try:
+            clustering = DBSCAN(
+                eps=self.eps, min_samples=self.min_samples, algorithm=self.algorithm)
+            labels = clustering.fit_predict(stars)
+            logger.debug(f"DBSCAN labels: {labels}")
+
+            unique_labels = set(labels)
+            self.clusters = []
+            self.centroids = []
+
+            for label in unique_labels:
+                if label == -1:
+                    logger.debug("Skipping noise points")
+                    continue
+                class_members = [stars[i]
+                                 for i in range(len(stars)) if labels[i] == label]
+                # Cast to plain ints so the centroids stay JSON-serializable
+                centroid = tuple(int(c) for c in np.mean(class_members, axis=0))
+                self.clusters.append(class_members)
+                self.centroids.append(centroid)
+                logger.debug(f"Cluster {label}: Centroid at {centroid}")
+
+            logger.info(
+                f"Clustering completed with {len(self.centroids)} clusters found")
+            return self.centroids
+        except Exception as e:
+            logger.error(f"Error during clustering: {e}")
+            raise
+
+    def visualize_clusters(self, stars: List[Tuple[int, int]], save_path: Optional[Path] = None) -> None:
+        """
+        Visualize the clustered stars and their centroids.
+
+        :param stars: List of star positions as (x, y) tuples.
+        :param save_path: Optional Path to save the plot image.
+        """
+        if self.centroids is None:
+            logger.error(
+                "No clusters to visualize. 
Please run cluster_stars first.") + raise ValueError("No clusters to visualize.") + + logger.info("Visualizing clusters") + plt.figure(figsize=(10, 8)) + colors = plt.cm.get_cmap('viridis', len(self.centroids)) + + for idx, centroid in enumerate(self.centroids): + cluster = self.clusters[idx] + cluster_np = np.array(cluster) + plt.scatter(cluster_np[:, 0], cluster_np[:, 1], + s=30, color=colors(idx), label=f'Cluster {idx}') + plt.scatter(centroid[0], centroid[1], + s=100, color='red', marker='X') + + # Plot noise points + noise = [star for star, label in zip(stars, DBSCAN( + eps=self.eps, min_samples=self.min_samples, algorithm=self.algorithm).fit_predict(stars)) if label == -1] + if noise: + noise_np = np.array(noise) + plt.scatter(noise_np[:, 0], noise_np[:, 1], + s=30, color='grey', label='Noise') + + plt.title('Star Clusters') + plt.xlabel('X Coordinate') + plt.ylabel('Y Coordinate') + plt.legend() + plt.grid(True) + if save_path: + plt.savefig(save_path) + logger.info(f"Cluster visualization saved to {save_path}") + plt.show() + logger.debug("Clusters visualized successfully") + + def save_clusters(self, filepath: Union[str, Path]) -> None: + """ + Save the clustered centroids to a JSON file. + + :param filepath: Path to save the JSON file. + """ + if self.centroids is None: + logger.error( + "No clusters to save. Please run cluster_stars first.") + raise ValueError("No clusters to save.") + + logger.info(f"Saving clusters to {filepath}") + try: + with open(filepath, 'w') as file: + json.dump(self.centroids, file, indent=4) + logger.debug("Clusters saved successfully") + except Exception as e: + logger.error(f"Failed to save clusters: {e}") + raise + + def load_clusters(self, filepath: Union[str, Path]) -> List[Tuple[int, int]]: + """ + Load clustered centroids from a JSON file. + + :param filepath: Path to load the JSON file from. + :return: List of clustered star centroids as (x, y) tuples. 
+ """ + logger.info(f"Loading clusters from {filepath}") + try: + with open(filepath, 'r') as file: + self.centroids = json.load(file) + logger.debug(f"Clusters loaded: {self.centroids}") + return self.centroids + except Exception as e: + logger.error(f"Failed to load clusters: {e}") + raise + + def update_parameters(self, eps: Optional[float] = None, min_samples: Optional[int] = None, algorithm: Optional[str] = None) -> None: + """ + Update DBSCAN parameters. + + :param eps: New epsilon value. + :param min_samples: New minimum samples value. + :param algorithm: New algorithm to use. + """ + if eps is not None: + self.eps = eps + logger.debug(f"Updated eps to {self.eps}") + if min_samples is not None: + self.min_samples = min_samples + logger.debug(f"Updated min_samples to {self.min_samples}") + if algorithm is not None: + self.algorithm = algorithm + logger.debug(f"Updated algorithm to {self.algorithm}") + + +# Example Usage +if __name__ == "__main__": + import sys + + # Initialize StarClustering object + star_cluster = StarClustering(eps=30, min_samples=5) + + # Sample star positions + stars = [ + (10, 10), (12, 11), (11, 13), + (50, 50), (51, 52), (49, 51), + (90, 90), (91, 89), (89, 92) + ] + + # Perform clustering + centroids = star_cluster.cluster_stars(stars) + print(f"Cluster Centroids: {centroids}") + + # Visualize clusters + star_cluster.visualize_clusters(stars, save_path=Path("clusters.png")) + + # Save clusters to JSON + star_cluster.save_clusters("clusters.json") + + # Load clusters from JSON + loaded_centroids = star_cluster.load_clusters("clusters.json") + print(f"Loaded Cluster Centroids: {loaded_centroids}") diff --git a/modules/lithium.pyimage/image/star_detection/detection.py b/modules/lithium.pyimage/image/star_detection/detection.py new file mode 100644 index 00000000..e384ba81 --- /dev/null +++ b/modules/lithium.pyimage/image/star_detection/detection.py @@ -0,0 +1,327 @@ +import cv2 +import numpy as np +from typing import List, Tuple, 
Optional, Union +from pathlib import Path +from dataclasses import dataclass, field +from loguru import logger +from sklearn.cluster import DBSCAN +import matplotlib.pyplot as plt +import json + +from .preprocessing import ( + Preprocessor, + PreprocessingConfig +) + + +@dataclass +class StarDetectionConfig: + """ + Configuration settings for star detection. + """ + median_filter_size: int = 3 + wavelet_levels: int = 4 + binarization_threshold: int = 30 + min_star_size: int = 10 + min_star_brightness: int = 20 + min_circularity: float = 0.7 + max_circularity: float = 1.3 + scales: List[float] = field(default_factory=lambda: [1.0, 0.75, 0.5]) + dbscan_eps: float = 10.0 + dbscan_min_samples: int = 2 + save_detected_stars: bool = False + detected_stars_save_path: Optional[Path] = None + visualize: bool = True + visualization_save_path: Optional[Path] = None + + +class StarDetector: + def __init__(self, config: Optional[StarDetectionConfig] = None) -> None: + """ + Initialize the StarDetector with optional configuration. + + :param config: StarDetectionConfig object containing detection settings. 
+        """
+        self.config = config or StarDetectionConfig()
+
+        # Configure Loguru logger
+        logger.remove()  # Remove default logger
+        logger.add(
+            "detection.log",
+            rotation="10 MB",
+            retention="10 days",
+            level="DEBUG",
+            format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}"
+        )
+        logger.debug(
+            "Initialized StarDetector with configuration: {}", self.config)
+
+        # Initialize Preprocessor. Note: detected_stars_save_path is a JSON
+        # path, not an image path, so it must not be reused as the
+        # preprocessed-image save target.
+        self.preprocessor = Preprocessor(
+            config=PreprocessingConfig(
+                median_filter_size=self.config.median_filter_size,
+                wavelet_levels=self.config.wavelet_levels,
+                binarization_threshold=self.config.binarization_threshold,
+                save_preprocessed=False,
+                preprocessed_save_path=None
+            )
+        )
+
+    def multiscale_detect_stars(self, image: np.ndarray) -> List[Tuple[int, int]]:
+        """
+        Detect stars in an image using multiscale analysis and clustering.
+
+        :param image: Grayscale input image as a numpy array.
+        :return: List of detected star positions as (x, y) tuples.
+        """
+        logger.info("Starting multiscale star detection.")
+        all_stars = []
+
+        for scale in self.config.scales:
+            logger.debug("Processing scale: {}", scale)
+            resized_image = cv2.resize(
+                image,
+                None,
+                fx=scale,
+                fy=scale,
+                interpolation=cv2.INTER_LINEAR
+            )
+            logger.debug("Image resized to scale {}.", scale)
+
+            filtered_image = self.preprocessor.apply_median_filter(
+                resized_image)
+            logger.debug("Median filtering applied.")
+
+            pyramid = self.preprocessor.wavelet_transform(filtered_image)
+            # Use the coarsest pyramid level as the background estimate;
+            # Preprocessor does not provide an extract_background method.
+            background = pyramid[-1]
+            subtracted_image = self.preprocessor.background_subtraction(
+                filtered_image, background)
+            logger.debug("Background subtraction completed.")
+
+            binary_image = self.preprocessor.binarize(subtracted_image)
+            logger.debug("Image binarization completed.")
+
+            star_centers, star_properties = self.preprocessor.detect_stars(
+                binary_image)
+            logger.debug(
+                "Star detection completed with {} stars detected.", len(star_centers))
+
+            filtered_stars = self.filter_stars(star_properties, binary_image)
+            logger.debug("{} stars passed the filtering criteria.",
+                         len(filtered_stars))
+
+            # Adjust star positions back to original scale
+            scaled_stars = [(int(x / scale), int(y / scale))
+                            for (x, y) in filtered_stars]
+            all_stars.extend(scaled_stars)
+            logger.debug("Star positions scaled back to original image size.")
+
+        # Remove duplicate stars using DBSCAN clustering
+        logger.info("Clustering detected stars to remove duplicates.")
+        unique_stars = self.remove_duplicates(all_stars)
+        logger.info(
+            "Duplicate stars removed. {} unique stars detected.", len(unique_stars))
+
+        if self.config.save_detected_stars and self.config.detected_stars_save_path:
+            self.save_detected_stars(
+                unique_stars, self.config.detected_stars_save_path)
+
+        if self.config.visualize:
+            self.visualize_stars(image, unique_stars,
+                                 self.config.visualization_save_path)
+
+        return unique_stars
+
+    def filter_stars(
+        self,
+        star_properties: List[Tuple[Tuple[int, int], float, float]],
+        binary_image: np.ndarray
+    ) -> List[Tuple[int, int]]:
+        """
+        Filter detected stars based on shape, size, and brightness.
+
+        :param star_properties: List of tuples containing star properties (center, area, perimeter).
+        :param binary_image: Binary image used for star detection.
+        :return: List of filtered star positions as (x, y) tuples.
+        """
+        logger.debug("Filtering stars based on defined criteria.")
+        filtered_stars = []
+        for (center, area, perimeter) in star_properties:
+            if perimeter == 0:
+                logger.debug(
+                    "Skipping star at {} due to zero perimeter.", center)
+                continue
+            circularity = (4 * np.pi * area) / (perimeter ** 2)
+            logger.debug("Star at {} has circularity {}.", center, circularity)
+
+            mask = np.zeros_like(binary_image)
+            cv2.circle(mask, center, 5, 255, -1)
+            # Count lit pixels of the binary image inside the probe circle,
+            # not the area of the circle itself.
+            star_pixels = cv2.countNonZero(cv2.bitwise_and(binary_image, mask))
+            brightness = np.mean(binary_image[mask == 255])
+
+            logger.debug(
+                "Star at {}: pixels={}, brightness={}.",
+                center, star_pixels, brightness
+            )
+
+            if (
+                star_pixels > self.config.min_star_size and
+                brightness > self.config.min_star_brightness and
+                self.config.min_circularity <= circularity <= self.config.max_circularity
+            ):
+                filtered_stars.append(center)
+                logger.debug("Star at {} passed filtering.", center)
+
+        logger.debug("Filtering completed. {} stars passed.",
+                     len(filtered_stars))
+        return filtered_stars
+
+    def remove_duplicates(self, stars: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
+        """
+        Remove duplicate stars using DBSCAN clustering.
+ + :param stars: List of star positions as (x, y) tuples. + :return: List of unique star positions. + """ + if not stars: + logger.warning("No stars to cluster for duplicate removal.") + return [] + + try: + star_array = np.array(stars) + clustering = DBSCAN(eps=self.config.dbscan_eps, + min_samples=self.config.dbscan_min_samples) + labels = clustering.fit_predict(star_array) + logger.debug("DBSCAN clustering labels: {}", labels) + + unique_stars = [] + unique_labels = set(labels) + for label in unique_labels: + if label == -1: + continue # Noise + cluster = star_array[labels == label] + centroid = tuple(cluster.mean(axis=0).astype(int)) + unique_stars.append(centroid) + logger.debug("Cluster {}: Centroid at {}.", label, centroid) + + logger.info( + "Duplicate removal completed. {} unique stars identified.", len(unique_stars)) + return unique_stars + except Exception as e: + logger.error("Error during duplicate removal: {}", e) + raise + + def save_detected_stars(self, stars: List[Tuple[int, int]], save_path: Union[str, Path]) -> None: + """ + Save the detected star positions to a JSON file. + + :param stars: List of star positions as (x, y) tuples. + :param save_path: Path to save the JSON file. + """ + logger.info("Saving detected stars to {}.", save_path) + try: + with open(save_path, 'w') as f: + json.dump(stars, f, indent=4) + logger.debug("Detected stars saved successfully.") + except Exception as e: + logger.error("Failed to save detected stars: {}", e) + raise + + def load_detected_stars(self, file_path: Union[str, Path]) -> List[Tuple[int, int]]: + """ + Load detected star positions from a JSON file. + + :param file_path: Path to the JSON file. + :return: List of star positions as (x, y) tuples. 
+ """ + logger.info("Loading detected stars from {}.", file_path) + try: + with open(file_path, 'r') as f: + stars = json.load(f) + logger.debug( + "Detected stars loaded successfully with {} entries.", len(stars)) + return stars + except Exception as e: + logger.error("Failed to load detected stars: {}", e) + raise + + def visualize_stars( + self, + original_image: np.ndarray, + stars: List[Tuple[int, int]], + save_path: Optional[Path] = None + ) -> None: + """ + Visualize the detected stars on the original image. + + :param original_image: Original grayscale image as a numpy array. + :param stars: List of detected star positions as (x, y) tuples. + :param save_path: Optional path to save the visualization image. + """ + logger.info("Visualizing detected stars.") + try: + color_image = cv2.cvtColor(original_image, cv2.COLOR_GRAY2BGR) + for (x, y) in stars: + cv2.circle(color_image, (x, y), 5, (0, 0, 255), 1) + + plt.figure(figsize=(10, 8)) + plt.imshow(cv2.cvtColor(color_image, cv2.COLOR_BGR2RGB)) + plt.title('Detected Stars') + plt.axis('off') + if save_path: + plt.savefig(save_path) + logger.debug("Star visualization saved to {}.", save_path) + plt.show() + logger.debug("Star visualization displayed successfully.") + except Exception as e: + logger.error("Error during star visualization: {}", e) + raise + + +# Example Usage +if __name__ == "__main__": + from preprocessing import PreprocessingConfig + + # Configure star detection parameters + detection_config = StarDetectionConfig( + median_filter_size=5, + wavelet_levels=3, + binarization_threshold=40, + min_star_size=15, + min_star_brightness=25, + min_circularity=0.8, + max_circularity=1.2, + scales=[1.0, 0.8, 0.6], + dbscan_eps=15.0, + dbscan_min_samples=3, + save_detected_stars=True, + detected_stars_save_path=Path("detected_stars.json"), + visualize=True, + visualization_save_path=Path("detected_stars.png") + ) + + # Initialize StarDetector + star_detector = StarDetector(config=detection_config) + + # Load 
and preprocess image + image_path = "path/to/grayscale_image.png" + try: + preprocessor = star_detector.preprocessor + preprocessed_image = preprocessor.load_image(image_path) + logger.info("Image loaded and preprocessed.") + + # Perform star detection + detected_stars = star_detector.multiscale_detect_stars( + preprocessed_image) + logger.info( + "Star detection completed with {} stars detected.", len(detected_stars)) + + # Optionally, load detected stars from file + # detected_stars = star_detector.load_detected_stars("detected_stars.json") + + except Exception as e: + logger.error("An error occurred during star detection: {}", e) diff --git a/modules/lithium.pyimage/image/star_detection/preprocessing.py b/modules/lithium.pyimage/image/star_detection/preprocessing.py new file mode 100644 index 00000000..51500305 --- /dev/null +++ b/modules/lithium.pyimage/image/star_detection/preprocessing.py @@ -0,0 +1,371 @@ +import cv2 +import numpy as np +from astropy.io import fits +from typing import List, Tuple, Optional, Union +from pathlib import Path +from dataclasses import dataclass, field +from loguru import logger +import matplotlib.pyplot as plt +import json + + +@dataclass +class PreprocessingConfig: + """ + Configuration settings for image preprocessing. + """ + median_filter_size: int = 3 + wavelet_levels: int = 3 + binarization_threshold: int = 128 + save_preprocessed: bool = False + preprocessed_save_path: Optional[Path] = None + save_wavelet: bool = False + wavelet_save_path: Optional[Path] = None + + +class Preprocessor: + def __init__(self, config: Optional[PreprocessingConfig] = None) -> None: + """ + Initialize the Preprocessor with optional configuration. + + :param config: PreprocessingConfig object containing settings. 
+        """
+        self.config = config or PreprocessingConfig()
+
+        # Configure Loguru logger
+        logger.remove()  # Remove default logger
+        logger.add("preprocessing.log", rotation="10 MB", retention="10 days",
+                   level="DEBUG", format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}")
+        logger.debug(f"Initialized Preprocessor with configuration: {self.config}")
+
+    def load_fits_image(self, file_path: Union[str, Path]) -> np.ndarray:
+        """
+        Load a FITS image from the specified file path.
+
+        :param file_path: Path to the FITS file.
+        :return: Image data as a numpy array.
+        """
+        logger.info(f"Loading FITS image from {file_path}")
+        try:
+            with fits.open(str(file_path)) as hdul:
+                image_data = hdul[0].data
+                logger.debug("FITS image loaded successfully.")
+                return image_data
+        except Exception as e:
+            logger.error(f"Failed to load FITS image: {e}")
+            raise
+
+    def preprocess_fits_image(self, image_data: np.ndarray) -> np.ndarray:
+        """
+        Preprocess FITS image by normalizing to the 0-255 range.
+
+        :param image_data: Raw image data from the FITS file.
+        :return: Preprocessed image data as a numpy array.
+        """
+        logger.info("Preprocessing FITS image.")
+        try:
+            image_data = np.nan_to_num(image_data)
+            image_data = image_data.astype(np.float64)
+            image_data -= np.min(image_data)
+            max_val = np.max(image_data)
+            if max_val > 0:  # Guard against division by zero for flat images
+                image_data /= max_val
+            image_data *= 255
+            preprocessed_image = image_data.astype(np.uint8)
+            logger.debug("FITS image preprocessed successfully.")
+            return preprocessed_image
+        except Exception as e:
+            logger.error(f"Error during FITS image preprocessing: {e}")
+            raise
+
+    def load_image(self, file_path: Union[str, Path]) -> np.ndarray:
+        """
+        Load an image from the specified file path. Supports FITS and standard image formats.
+
+        :param file_path: Path to the image file.
+        :return: Loaded image as a numpy array.
+ """ + logger.info(f"Loading image from {file_path}") + try: + if str(file_path).lower().endswith('.fits'): + image_data = self.load_fits_image(file_path) + if image_data.ndim == 2: + return self.preprocess_fits_image(image_data) + elif image_data.ndim == 3: + channels = [self.preprocess_fits_image( + image_data[..., i]) for i in range(image_data.shape[2])] + preprocessed_image = cv2.merge(channels) + logger.debug( + "Multichannel FITS image preprocessed successfully.") + return preprocessed_image + else: + logger.error("Unsupported FITS image dimensions.") + raise ValueError("Unsupported FITS image dimensions.") + else: + image = cv2.imread(str(file_path), cv2.IMREAD_UNCHANGED) + if image is None: + logger.error(f"Unable to load image file: {file_path}") + raise ValueError(f"Unable to load image file: {file_path}") + logger.debug("Standard image loaded successfully.") + return image + except Exception as e: + logger.error(f"Error loading image: {e}") + raise + + def apply_median_filter(self, image: np.ndarray) -> np.ndarray: + """ + Apply median filtering to the image. + + :param image: Input image. + :return: Filtered image. + """ + logger.info("Applying median filter.") + try: + filtered_image = cv2.medianBlur( + image, self.config.median_filter_size) + logger.debug( + f"Median filter applied with kernel size {self.config.median_filter_size}.") + return filtered_image + except Exception as e: + logger.error(f"Error applying median filter: {e}") + raise + + def wavelet_transform(self, image: np.ndarray) -> List[np.ndarray]: + """ + Perform wavelet transform using a Laplacian pyramid. + + :param image: Input image. + :return: List of wavelet transformed images at each level. 
+        """
+        logger.info("Performing wavelet transform.")
+        try:
+            pyramid = []
+            current_image = image.copy()
+            for level in range(self.config.wavelet_levels):
+                down = cv2.pyrDown(current_image)
+                # pyrUp's second positional argument is dst, not a size;
+                # upsample plainly here and resize explicitly below.
+                up = cv2.pyrUp(down)
+
+                # Resize up to match the original image size
+                up = cv2.resize(
+                    up, (current_image.shape[1], current_image.shape[0]))
+
+                # Calculate the difference to get the detail layer
+                layer = cv2.subtract(current_image, up)
+                pyramid.append(layer)
+                logger.debug(f"Wavelet level {level + 1} computed.")
+                current_image = down
+
+            pyramid.append(current_image)  # Add the final low-resolution image
+            logger.debug("Wavelet transform completed successfully.")
+            if self.config.save_wavelet and self.config.wavelet_save_path:
+                self.save_wavelet_pyramid(
+                    pyramid, self.config.wavelet_save_path)
+            return pyramid
+        except Exception as e:
+            logger.error(f"Error during wavelet transform: {e}")
+            raise
+
+    def inverse_wavelet_transform(self, pyramid: List[np.ndarray]) -> np.ndarray:
+        """
+        Reconstruct the image from its wavelet pyramid representation.
+
+        :param pyramid: List of wavelet transformed images at each level.
+        :return: Reconstructed image.
+        """
+        logger.info("Performing inverse wavelet transform.")
+        try:
+            image = pyramid.pop()
+            while pyramid:
+                # As above, avoid passing a size tuple as pyrUp's dst argument
+                up = cv2.pyrUp(image)
+
+                # Resize up to match the size of the current level
+                up = cv2.resize(
+                    up, (pyramid[-1].shape[1], pyramid[-1].shape[0]))
+
+                # Add the detail layer to reconstruct the image
+                image = cv2.add(up, pyramid.pop())
+                logger.debug("Wavelet level reconstructed.")
+            logger.debug("Inverse wavelet transform completed successfully.")
+            return image
+        except Exception as e:
+            logger.error(f"Error during inverse wavelet transform: {e}")
+            raise
+
+    def background_subtraction(self, image: np.ndarray, background: np.ndarray) -> np.ndarray:
+        """
+        Subtract the background from the image using the provided background image.
+
+        :param image: Original image.
+ :param background: Background image to subtract. + :return: Image with background subtracted. + """ + logger.info("Performing background subtraction.") + try: + background_resized = cv2.resize( + background, (image.shape[1], image.shape[0])) + result = cv2.subtract(image, background_resized) + result[result < 0] = 0 + logger.debug("Background subtracted successfully.") + return result + except Exception as e: + logger.error(f"Error during background subtraction: {e}") + raise + + def binarize(self, image: np.ndarray) -> np.ndarray: + """ + Binarize the image using a fixed threshold from configuration. + + :param image: Input image. + :return: Binarized image. + """ + logger.info("Binarizing image.") + try: + _, binary_image = cv2.threshold( + image, self.config.binarization_threshold, 255, cv2.THRESH_BINARY) + logger.debug( + f"Image binarized with threshold {self.config.binarization_threshold}.") + return binary_image + except Exception as e: + logger.error(f"Error during binarization: {e}") + raise + + def detect_stars(self, binary_image: np.ndarray) -> Tuple[List[Tuple[int, int]], List[Tuple[Tuple[int, int], float, float]]]: + """ + Detect stars in a binary image by finding contours. + + :param binary_image: Binarized image. + :return: Tuple containing a list of star centers and a list of star properties (center, area, perimeter). 
+ """ + logger.info("Detecting stars in binary image.") + try: + contours, _ = cv2.findContours( + binary_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + star_centers = [] + star_properties = [] + + for contour in contours: + M = cv2.moments(contour) + if M['m00'] != 0: + center = (int(M['m10'] / M['m00']), + int(M['m01'] / M['m00'])) + star_centers.append(center) + area = cv2.contourArea(contour) + perimeter = cv2.arcLength(contour, True) + star_properties.append((center, area, perimeter)) + logger.debug( + f"Detected star at {center} with area {area} and perimeter {perimeter}.") + + logger.info(f"Total stars detected: {len(star_centers)}.") + return star_centers, star_properties + except Exception as e: + logger.error(f"Error during star detection: {e}") + raise + + def save_preprocessed_image(self, image: np.ndarray, save_path: Union[str, Path]) -> None: + """ + Save the preprocessed image to the specified path. + + :param image: Preprocessed image. + :param save_path: Path to save the image. + """ + logger.info(f"Saving preprocessed image to {save_path}.") + try: + cv2.imwrite(str(save_path), image) + logger.debug("Preprocessed image saved successfully.") + except Exception as e: + logger.error(f"Failed to save preprocessed image: {e}") + raise + + def save_wavelet_pyramid(self, pyramid: List[np.ndarray], save_path: Union[str, Path]) -> None: + """ + Save the wavelet pyramid images to the specified directory. + + :param pyramid: List of wavelet transformed images at each level. + :param save_path: Directory path to save the wavelet images. 
+ """ + logger.info(f"Saving wavelet pyramid to {save_path}.") + try: + save_path = Path(save_path) + save_path.mkdir(parents=True, exist_ok=True) + for idx, layer in enumerate(pyramid): + layer_path = save_path / f"wavelet_level_{idx + 1}.png" + cv2.imwrite(str(layer_path), layer) + logger.debug(f"Wavelet level {idx + 1} saved to {layer_path}.") + logger.debug("All wavelet pyramid images saved successfully.") + except Exception as e: + logger.error(f"Failed to save wavelet pyramid images: {e}") + raise + + def run_preprocessing_pipeline(self, file_path: Union[str, Path]) -> Tuple[np.ndarray, List[Tuple[int, int]], List[Tuple[Tuple[int, int], float, float]]]: + """ + Execute the full preprocessing pipeline on the provided image file. + + :param file_path: Path to the image file. + :return: Tuple containing the final processed image, list of star centers, and list of star properties. + """ + logger.info("Starting preprocessing pipeline.") + try: + image = self.load_image(file_path) + logger.debug("Image loaded for preprocessing.") + + if self.config.save_preprocessed and self.config.preprocessed_save_path: + self.save_preprocessed_image( + image, self.config.preprocessed_save_path) + + filtered_image = self.apply_median_filter(image) + logger.debug("Median filter applied.") + + wavelet_pyramid = self.wavelet_transform(filtered_image) + reconstructed_image = self.inverse_wavelet_transform( + wavelet_pyramid) + logger.debug("Wavelet transform and inverse transform completed.") + + # Placeholder for actual background + background = np.zeros_like(reconstructed_image) + subtracted_image = self.background_subtraction( + reconstructed_image, background) + logger.debug("Background subtraction completed.") + + binary_image = self.binarize(subtracted_image) + logger.debug("Image binarization completed.") + + star_centers, star_properties = self.detect_stars(binary_image) + logger.debug("Star detection completed.") + + logger.info("Preprocessing pipeline completed 
successfully.") + return binary_image, star_centers, star_properties + except Exception as e: + logger.error(f"Error in preprocessing pipeline: {e}") + raise + + +# Example Usage +if __name__ == "__main__": + import sys + + # Initialize Preprocessor with custom configuration + config = PreprocessingConfig( + median_filter_size=5, + wavelet_levels=4, + binarization_threshold=100, + save_preprocessed=True, + preprocessed_save_path=Path("preprocessed_image.png"), + save_wavelet=True, + wavelet_save_path=Path("wavelet_pyramid") + ) + preprocessor = Preprocessor(config=config) + + # Path to the input image (FITS or standard format) + input_image_path = "path/to/image.fits" + + try: + # Run preprocessing pipeline + binary_img, centers, properties = preprocessor.run_preprocessing_pipeline( + input_image_path) + print(f"Detected {len(centers)} stars.") + + # Optionally, visualize binary image + plt.imshow(binary_img, cmap='gray') + plt.title('Binarized Image') + plt.show() + + except Exception as e: + print(f"An error occurred during preprocessing: {e}") diff --git a/pysrc/image/star_detection/utils.py b/modules/lithium.pyimage/image/star_detection/utils.py similarity index 100% rename from pysrc/image/star_detection/utils.py rename to modules/lithium.pyimage/image/star_detection/utils.py diff --git a/pysrc/image/image_io/__init__.py b/modules/lithium.pyimage/image/transformation/__init__.py similarity index 100% rename from pysrc/image/image_io/__init__.py rename to modules/lithium.pyimage/image/transformation/__init__.py diff --git a/modules/lithium.pyimage/image/transformation/curve.py b/modules/lithium.pyimage/image/transformation/curve.py new file mode 100644 index 00000000..ad351d31 --- /dev/null +++ b/modules/lithium.pyimage/image/transformation/curve.py @@ -0,0 +1,269 @@ +import numpy as np +import matplotlib.pyplot as plt +from scipy.interpolate import CubicSpline, Akima1DInterpolator, interp1d +from typing import Optional, Tuple, List, Callable, Union +from 
pathlib import Path +import json +from loguru import logger + + +class CurvesTransformation: + def __init__(self, interpolation: str = 'akima') -> None: + self.points: List[Tuple[float, float]] = [] + self.interpolation: str = interpolation + self.curve: Optional[Callable[[float], float]] = None + self.stored_curve: Optional[List[Tuple[float, float]]] = None + + # Configure Loguru logger + logger.add("curve_transformation.log", rotation="10 MB", retention="10 days", + level="DEBUG", format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}") + logger.debug( + f"Initialized CurvesTransformation with interpolation='{self.interpolation}'") + + def add_point(self, x: float, y: float) -> None: + logger.debug(f"Adding point ({x}, {y})") + self.points.append((x, y)) + self.points.sort(key=lambda point: point[0]) # Sort points by x value + self._update_curve() + + def remove_point(self, index: int) -> None: + if 0 <= index < len(self.points): + removed = self.points.pop(index) + logger.debug(f"Removed point at index {index}: {removed}") + self._update_curve() + else: + logger.error(f"Index out of range: {index}") + raise IndexError("Point index out of range.") + + def _update_curve(self) -> None: + if len(self.points) < 2: + self.curve = None + logger.warning("Not enough points to define a curve.") + return + + x, y = zip(*self.points) + + logger.debug( + f"Updating curve with interpolation method: {self.interpolation}") + + try: + if self.interpolation == 'cubic': + self.curve = CubicSpline(x, y) + logger.debug("CubicSpline interpolation applied.") + elif self.interpolation == 'akima': + self.curve = Akima1DInterpolator(x, y) + logger.debug("Akima1DInterpolator interpolation applied.") + elif self.interpolation == 'linear': + self.curve = interp1d( + x, y, kind='linear', fill_value="extrapolate") + logger.debug("Linear interpolation applied.") + else: + logger.error( + f"Unsupported interpolation method: {self.interpolation}") + raise ValueError("Unsupported interpolation 
method.") + except Exception as e: + logger.error(f"Failed to update curve: {e}") + self.curve = None + + def transform(self, image: np.ndarray, channel: Optional[int] = None) -> np.ndarray: + if self.curve is None: + logger.error("No valid curve defined for transformation.") + raise ValueError("No valid curve defined.") + + logger.info( + f"Applying curve transformation{' on channel ' + str(channel) if channel is not None else ''}") + + try: + transformed_image = image.astype( + np.float32) / 255.0 # Normalize to [0, 1] + + if len(image.shape) == 2: # Grayscale image + logger.debug("Transforming grayscale image.") + transformed_image = self.curve(transformed_image) + elif len(image.shape) == 3: # RGB image + if channel is None: + logger.error("Channel must be specified for color images.") + raise ValueError( + "Channel must be specified for color images.") + logger.debug(f"Transforming channel {channel} of RGB image.") + transformed_image[:, :, channel] = self.curve( + transformed_image[:, :, channel]) + else: + logger.error("Unsupported image format.") + raise ValueError("Unsupported image format.") + + transformed_image = np.clip(transformed_image, 0, 1) + transformed_image = (transformed_image * 255).astype(np.uint8) + logger.info("Curve transformation applied successfully.") + return transformed_image + except Exception as e: + logger.error(f"Error during transformation: {e}") + raise + + def plot_curve(self, title: str = "Curves Transformation") -> None: + if self.curve is None: + logger.warning("No curve to plot.") + print("No curve to plot.") + return + + logger.debug(f"Plotting curve with title: {title}") + x_vals = np.linspace(0, 1, 500) + y_vals = self.curve(x_vals) + + plt.figure(figsize=(8, 6)) + plt.plot(x_vals, y_vals, + label=f'Interpolation: {self.interpolation}', color='blue') + plt.scatter(*zip(*self.points), color='red', label='Control Points') + plt.title(title) + plt.xlabel('Input Intensity') + plt.ylabel('Output Intensity') + plt.grid(True) + 
plt.legend() + plt.tight_layout() + plt.show() + logger.debug("Curve plotted successfully.") + + def store_curve(self) -> None: + self.stored_curve = self.points.copy() + logger.info("Curve stored successfully.") + + def restore_curve(self) -> None: + if self.stored_curve: + self.points = self.stored_curve.copy() + self._update_curve() + logger.info("Curve restored successfully.") + else: + logger.warning("No stored curve to restore.") + print("No stored curve to restore.") + + def invert_curve(self) -> None: + if self.curve is None: + logger.warning("No curve to invert.") + print("No curve to invert.") + return + + logger.info("Inverting curve.") + self.points = [(x, 1 - y) for x, y in self.points] + self._update_curve() + logger.debug("Curve inverted successfully.") + + def reset_curve(self) -> None: + self.points = [(0.0, 0.0), (1.0, 1.0)] + self._update_curve() + logger.info("Curve reset to default.") + + def pixel_readout(self, x: float) -> Optional[float]: + if self.curve is None: + logger.warning("No curve defined for pixel readout.") + print("No curve defined.") + return None + try: + value = self.curve(x) + logger.debug(f"Pixel readout at x={x}: {value}") + return float(value) + except Exception as e: + logger.error(f"Error in pixel readout: {e}") + return None + + def save_curve(self, filepath: Union[str, Path]) -> None: + logger.info(f"Saving curve to {filepath}") + try: + curve_data = { + 'interpolation': self.interpolation, + 'points': self.points + } + with open(filepath, 'w') as file: + json.dump(curve_data, file, indent=4) + logger.debug("Curve saved successfully.") + except Exception as e: + logger.error(f"Failed to save curve: {e}") + raise + + def load_curve(self, filepath: Union[str, Path]) -> None: + logger.info(f"Loading curve from {filepath}") + try: + with open(filepath, 'r') as file: + curve_data = json.load(file) + self.interpolation = curve_data.get('interpolation', 'akima') + self.points = curve_data.get('points', []) + 
self._update_curve() + logger.debug("Curve loaded successfully.") + except Exception as e: + logger.error(f"Failed to load curve: {e}") + raise + + def export_curve_points(self) -> List[Tuple[float, float]]: + logger.debug("Exporting curve points.") + return self.points.copy() + + def import_curve_points(self, points: List[Tuple[float, float]]) -> None: + logger.debug(f"Importing curve points: {points}") + self.points = sorted(points, key=lambda point: point[0]) + self._update_curve() + + def get_interpolation_methods(self) -> List[str]: + methods = ['cubic', 'akima', 'linear'] + logger.debug(f"Available interpolation methods: {methods}") + return methods + + +# Example Usage +if __name__ == "__main__": + import cv2 + import sys + + # Initialize CurvesTransformation object + curve_transform = CurvesTransformation(interpolation='akima') + + # Add points to the curve + curve_transform.add_point(0.0, 0.0) + curve_transform.add_point(0.3, 0.5) + curve_transform.add_point(0.7, 0.8) + curve_transform.add_point(1.0, 1.0) + + # Plot the curve + curve_transform.plot_curve() + + # Store the curve + curve_transform.store_curve() + + # Invert the curve + curve_transform.invert_curve() + curve_transform.plot_curve() + + # Restore the original curve + curve_transform.restore_curve() + curve_transform.plot_curve() + + # Reset the curve to default + curve_transform.reset_curve() + curve_transform.plot_curve() + + # Save the curve + curve_transform.save_curve("default_curve.json") + + # Load the curve + curve_transform.load_curve("default_curve.json") + curve_transform.plot_curve() + + # Generate a test image + test_image = np.linspace(0, 1, 256).reshape(16, 16).astype(np.float32) + + # Apply the transformation + transformed_image = curve_transform.transform(test_image) + + # Plot original and transformed images + plt.figure(figsize=(12, 6)) + plt.subplot(1, 2, 1) + plt.title("Original Image") + plt.imshow(test_image, cmap='gray', vmin=0, vmax=1) + + plt.subplot(1, 2, 2) + 
plt.title("Transformed Image") + plt.imshow(transformed_image, cmap='gray', vmin=0, vmax=1) + plt.tight_layout() + plt.show() + + # Pixel readout + readout_value = curve_transform.pixel_readout(0.5) + print(f"Pixel readout at x=0.5: {readout_value}") diff --git a/modules/lithium.pyimage/image/transformation/histogram.py b/modules/lithium.pyimage/image/transformation/histogram.py new file mode 100644 index 00000000..eae096d1 --- /dev/null +++ b/modules/lithium.pyimage/image/transformation/histogram.py @@ -0,0 +1,314 @@ +import cv2 +import numpy as np +import matplotlib.pyplot as plt +from pathlib import Path +from typing import Optional, Tuple, List +import argparse +from loguru import logger + +# Configure Loguru logger +logger.remove() # Remove default logger +logger.add("histogram.log", rotation="10 MB", retention="10 days", + level="DEBUG", format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}") + + +def calculate_histogram(image: np.ndarray, channel: int = 0) -> np.ndarray: + """ + Calculate the histogram of an image channel. + + :param image: Input image as a numpy array. + :param channel: Channel index (0 for Blue, 1 for Green, 2 for Red in BGR images). + :return: Histogram as a numpy array. + """ + logger.debug(f"Calculating histogram for channel {channel}") + histogram = cv2.calcHist([image], [channel], None, [256], [0, 256]) + logger.debug(f"Histogram calculated: {histogram.flatten()}") + return histogram + + +def display_histogram(histogram: np.ndarray, title: str = "Histogram") -> None: + """ + Display the histogram using Matplotlib. + + :param histogram: Histogram data as a numpy array. + :param title: Title of the histogram plot. 
+ """ + logger.debug(f"Displaying histogram with title: {title}") + plt.figure() + plt.plot(histogram, color='black') + plt.title(title) + plt.xlabel('Pixel Intensity') + plt.ylabel('Frequency') + plt.grid(True) + plt.show() + logger.debug("Histogram displayed successfully") + + +def apply_histogram_transformation(image: np.ndarray, + shadows_clip: float = 0.0, + highlights_clip: float = 1.0, + midtones_balance: float = 0.5, + lower_bound: float = -1.0, + upper_bound: float = 2.0) -> np.ndarray: + """ + Apply histogram-based transformation to an image. + + :param image: Input grayscale image as a numpy array. + :param shadows_clip: Shadow clipping threshold (0.0 to 1.0). + :param highlights_clip: Highlight clipping threshold (0.0 to 1.0). + :param midtones_balance: Balance factor for midtones. + :param lower_bound: Lower bound for dynamic range expansion. + :param upper_bound: Upper bound for dynamic range expansion. + :return: Transformed image as a numpy array. + """ + logger.info("Applying histogram transformation") + try: + # Normalize the image to [0, 1] + normalized_image = image.astype(np.float32) / 255.0 + logger.debug("Image normalized to [0, 1]") + + # Shadows and highlights clipping + clipped_image = np.clip( + (normalized_image - shadows_clip) / (highlights_clip - shadows_clip), 0, 1) + logger.debug( + f"Clipped image with shadows_clip={shadows_clip}, highlights_clip={highlights_clip}") + + # Midtones balance using a custom transfer function + def mtf(x): return (x ** midtones_balance) / \ + ((x ** midtones_balance + (1 - x) ** midtones_balance) + ** (1 / midtones_balance)) + transformed_image = mtf(clipped_image) + logger.debug( + f"Midtones balanced with midtones_balance={midtones_balance}") + + # Dynamic range expansion + expanded_image = np.clip( + (transformed_image - lower_bound) / (upper_bound - lower_bound), 0, 1) + logger.debug( + f"Dynamic range expanded with lower_bound={lower_bound}, upper_bound={upper_bound}") + + # Rescale to [0, 255] + 
output_image = (expanded_image * 255).astype(np.uint8) + logger.debug("Image rescaled to [0, 255]") + logger.info("Histogram transformation applied successfully") + return output_image + except Exception as e: + logger.error(f"Error in histogram transformation: {e}") + raise + + +def auto_clip(image: np.ndarray, clip_percent: float = 0.01) -> np.ndarray: + """ + Automatically clip the histogram based on a percentage. + + :param image: Input grayscale image as a numpy array. + :param clip_percent: Percentage for clipping the histogram tails. + :return: Auto-clipped image as a numpy array. + """ + logger.info(f"Applying auto clipping with clip_percent={clip_percent}") + try: + # Compute the histogram + hist, bins = np.histogram(image.flatten(), 256, [0, 256]) + cdf = hist.cumsum() + total_pixels = image.size + logger.debug("Computed cumulative distribution function (CDF)") + + # Calculate clipping points + lower_clip = np.searchsorted(cdf, total_pixels * clip_percent) + upper_clip = np.searchsorted(cdf, total_pixels * (1 - clip_percent)) + logger.debug(f"Lower clip at {lower_clip}, Upper clip at {upper_clip}") + + # Apply histogram transformation + auto_clipped_image = apply_histogram_transformation( + image, + shadows_clip=lower_clip / 255.0, + highlights_clip=upper_clip / 255.0 + ) + logger.info("Auto clipping applied successfully") + return auto_clipped_image + except Exception as e: + logger.error(f"Error in auto clipping: {e}") + raise + + +def display_rgb_histogram(image: np.ndarray) -> None: + """ + Display RGB histograms for an image. + + :param image: Input image as a numpy array (BGR format). 
+ """ + logger.debug("Displaying RGB histograms") + colors = ('b', 'g', 'r') + plt.figure() + for i, col in enumerate(colors): + hist = calculate_histogram(image, channel=i) + plt.plot(hist, color=col) + plt.xlim([0, 256]) + plt.title('RGB Histogram') + plt.xlabel('Pixel Intensity') + plt.ylabel('Frequency') + plt.grid(True) + plt.show() + logger.debug("RGB histograms displayed successfully") + + +def real_time_preview(image: np.ndarray, + transformation_function, + window_name: str = "Real-Time Preview", + **kwargs) -> None: + """ + Display real-time preview of image transformations. + + :param image: Input image as a numpy array. + :param transformation_function: Function to apply transformation. + :param window_name: Name of the display window. + :param kwargs: Additional keyword arguments for the transformation function. + """ + logger.info("Starting real-time preview") + try: + preview_image = transformation_function(image, **kwargs) + cv2.imshow(window_name, preview_image) + logger.debug("Transformation applied for real-time preview") + except Exception as e: + logger.error(f"Error in real-time preview: {e}") + raise + + +def save_histogram(histogram: np.ndarray, filepath: Path, title: str = "Histogram") -> None: + """ + Save the histogram plot to a file. + + :param histogram: Histogram data as a numpy array. + :param filepath: Path to save the histogram image. + :param title: Title of the histogram plot. + """ + logger.debug(f"Saving histogram to {filepath}") + try: + plt.figure() + plt.plot(histogram, color='black') + plt.title(title) + plt.xlabel('Pixel Intensity') + plt.ylabel('Frequency') + plt.grid(True) + plt.savefig(filepath) + plt.close() + logger.info(f"Histogram saved to {filepath}") + except Exception as e: + logger.error(f"Failed to save histogram: {e}") + raise + + +def parse_arguments() -> argparse.Namespace: + """ + Parse command-line arguments for the histogram tool. + + :return: Parsed arguments namespace. 
+ """ + parser = argparse.ArgumentParser( + description="Image Histogram Processing Tool") + parser.add_argument('--input', type=Path, required=True, + help='Path to the input image.') + parser.add_argument('--output', type=Path, required=True, + help='Path to save the processed image.') + parser.add_argument('--save_histogram', type=Path, default=None, + help='Path to save the histogram plot.') + parser.add_argument('--operation', type=str, choices=['mean', 'gaussian', 'minimum', + 'maximum', 'median', 'bilinear', 'bicubic'], + default='mean', help='Histogram transformation operation.') + parser.add_argument('--structure', type=str, choices=['square', 'circular', 'horizontal', 'vertical'], + default='square', help='Neighborhood structure for transformation.') + parser.add_argument('--radius', type=int, default=1, + help='Radius for neighborhood structure.') + parser.add_argument('--clip_percent', type=float, default=0.01, + help='Clip percentage for auto clipping.') + parser.add_argument('--midtones_balance', type=float, default=0.5, + help='Balance factor for midtones.') + parser.add_argument('--lower_bound', type=float, default=-1.0, + help='Lower bound for dynamic range expansion.') + parser.add_argument('--upper_bound', type=float, default=2.0, + help='Upper bound for dynamic range expansion.') + parser.add_argument('--real_time_preview', action='store_true', + help='Enable real-time preview of transformations.') + + return parser.parse_args() + + +def main(): + """ + Main function to execute histogram processing. 
+    """
+    args = parse_arguments()
+
+    # Load image
+    logger.info(f"Loading image from {args.input}")
+    image = cv2.imread(str(args.input))
+    if image is None:
+        logger.error(f"Failed to load image: {args.input}")
+        raise SystemExit(1)  # `sys` is not imported in this module
+    logger.info(f"Image loaded successfully with shape {image.shape}")
+
+    # Convert to grayscale
+    grayscale_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+    logger.debug("Converted image to grayscale")
+
+    # Display original image and histogram
+    cv2.imshow('Original Image', image)
+    original_histogram = calculate_histogram(grayscale_image)
+    display_histogram(original_histogram, title="Original Grayscale Histogram")
+
+    # Display RGB histogram
+    display_rgb_histogram(image)
+
+    # Apply histogram transformation
+    logger.info("Applying histogram transformation")
+    transformed_image = apply_histogram_transformation(
+        grayscale_image,
+        shadows_clip=0.1,
+        highlights_clip=0.9,
+        midtones_balance=args.midtones_balance,
+        lower_bound=args.lower_bound,
+        upper_bound=args.upper_bound
+    )
+    cv2.imshow('Transformed Image', transformed_image)
+    logger.info("Histogram transformation applied")
+
+    # Display transformed histogram
+    transformed_histogram = calculate_histogram(transformed_image)
+    display_histogram(transformed_histogram,
+                      title="Transformed Grayscale Histogram")
+
+    # Apply auto clipping
+    logger.info("Applying auto clipping")
+    auto_clipped_image = auto_clip(
+        grayscale_image, clip_percent=args.clip_percent)
+    cv2.imshow('Auto Clipped Image', auto_clipped_image)
+    logger.info("Auto clipping applied")
+
+    # Save histogram if required
+    if args.save_histogram:
+        logger.info(f"Saving histogram to {args.save_histogram}")
+        # save_histogram expects histogram data, not a raw image
+        save_histogram(calculate_histogram(auto_clipped_image),
+                       args.save_histogram,
+                       title="Auto Clipped Grayscale Histogram")
+
+    # Real-time preview
+    if args.real_time_preview:
+        logger.info("Enabling real-time preview")
+        real_time_preview(
+            image=grayscale_image,
+            transformation_function=apply_histogram_transformation,
+            shadows_clip=0.05,
+            highlights_clip=0.95,
+            midtones_balance=args.midtones_balance,
+            lower_bound=args.lower_bound,
+            upper_bound=args.upper_bound
+        )
+
+    logger.info("Displaying all processed images. Press any key to exit.")
+    cv2.waitKey(0)
+    cv2.destroyAllWindows()
+    logger.info("All windows closed. Program terminated successfully.")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/modules/lithium.pyimage/image/astroalign/test_astroalign.py b/modules/lithium.pyimage/tests/test_astroalign.py
similarity index 100%
rename from modules/lithium.pyimage/image/astroalign/test_astroalign.py
rename to modules/lithium.pyimage/tests/test_astroalign.py
diff --git a/modules/lithium.pyimage/tests/test_calibration.py b/modules/lithium.pyimage/tests/test_calibration.py
new file mode 100644
index 00000000..8cb5fd22
--- /dev/null
+++ b/modules/lithium.pyimage/tests/test_calibration.py
@@ -0,0 +1,88 @@
+import pytest
+import numpy as np
+import cv2
+from pathlib import Path
+from ..calibration import batch_calibrate, CalibrationParams, save_to_fits
+
+# FILE: modules/lithium.pyimage/image/fluxcalibration/test_calibration.py
+
+
+@pytest.fixture
+def setup_test_environment(tmp_path):
+    # Create a temporary directory with test images and parameter files
+    input_dir = tmp_path / "input_images"
+    input_dir.mkdir()
+    output_dir = tmp_path / "output_images"
+    params_dir = tmp_path / "params"
+    params_dir.mkdir()
+    response_dir = tmp_path / "response"
+    response_dir.mkdir()
+
+    # Create dummy images
+    for i in range(3):
+        image = np.ones((100, 100), dtype=np.uint8) * (i * 85)
+        image_path = input_dir / f"test_image_{i}.png"
+        cv2.imwrite(str(image_path), image)
+
+        # Create corresponding parameter files
+        params_content = """
+        wavelength=500.0
+        aperture=100.0
+        obstruction=20.0
+        filter_width=10.0
+        transmissivity=0.9
+        gain=1.0
+        quantum_efficiency=0.8
+        extinction=0.1
+        exposure_time=30.0
+        """
+        params_path = params_dir / 
f"test_image_{i}.txt" + with params_path.open('w') as f: + f.write(params_content.strip()) + + # Create corresponding response function files + response_function = np.ones((100, 100), dtype=np.uint8) * 255 + response_path = response_dir / f"test_image_{i}_response.fits" + save_to_fits(response_function, str(response_path), 0, 1, 1) + + return input_dir, output_dir, params_dir, response_dir + + +def test_batch_calibrate_valid_images(setup_test_environment): + input_dir, output_dir, params_dir, response_dir = setup_test_environment + batch_calibrate(input_dir, output_dir, params_dir, response_dir) + assert len(list(output_dir.glob("*.fits"))) == 3 + + +def test_batch_calibrate_missing_params(setup_test_environment): + input_dir, output_dir, params_dir, response_dir = setup_test_environment + missing_params_path = params_dir / "test_image_1.txt" + missing_params_path.unlink() + batch_calibrate(input_dir, output_dir, params_dir, response_dir) + assert len(list(output_dir.glob("*.fits"))) == 2 + + +def test_batch_calibrate_missing_response(setup_test_environment): + input_dir, output_dir, params_dir, response_dir = setup_test_environment + missing_response_path = response_dir / "test_image_1_response.fits" + missing_response_path.unlink() + batch_calibrate(input_dir, output_dir, params_dir, response_dir) + assert len(list(output_dir.glob("*.fits"))) == 3 + + +def test_batch_calibrate_invalid_input_directory(tmp_path): + invalid_input_dir = tmp_path / "invalid_input" + output_dir = tmp_path / "output_images" + params_dir = tmp_path / "params" + response_dir = tmp_path / "response" + with pytest.raises(NotADirectoryError): + batch_calibrate(invalid_input_dir, output_dir, + params_dir, response_dir) + + +def test_batch_calibrate_invalid_output_directory(setup_test_environment): + input_dir, _, params_dir, response_dir = setup_test_environment + invalid_output_dir = Path("/invalid/output") + with pytest.raises(OSError): + batch_calibrate(input_dir, invalid_output_dir, + 
params_dir, response_dir) diff --git a/modules/lithium.pyimage/tests/test_combination.py b/modules/lithium.pyimage/tests/test_combination.py new file mode 100644 index 00000000..27e5e7fc --- /dev/null +++ b/modules/lithium.pyimage/tests/test_combination.py @@ -0,0 +1,54 @@ +import pytest +import shutil +from pathlib import Path +from PIL import Image +from ..combination import batch_process + +# FILE: modules/lithium.pyimage/image/channel/test_combination.py + + +@pytest.fixture +def setup_test_images(tmp_path): + # Create a temporary directory with test images + img_dir = tmp_path / "images" + img_dir.mkdir() + for i in range(3): + for ch in ['R', 'G', 'B']: + img = Image.new('L', (100, 100), color=i*85) + img.save(img_dir / f"test_{i}_{ch}.png") + return img_dir + + +def test_batch_process_valid_images(setup_test_images, tmp_path): + output_dir = tmp_path / "output" + batch_process(setup_test_images, "RGB", output_dir, "PNG") + assert len(list(output_dir.glob("*.png"))) == 3 + + +def test_batch_process_missing_channel_images(setup_test_images, tmp_path): + # Remove one channel image to simulate missing channel + missing_image = setup_test_images / "test_1_G.png" + missing_image.unlink() + output_dir = tmp_path / "output" + batch_process(setup_test_images, "RGB", output_dir, "PNG") + assert len(list(output_dir.glob("*.png"))) == 2 + + +def test_batch_process_different_color_spaces(setup_test_images, tmp_path): + output_dir = tmp_path / "output" + batch_process(setup_test_images, "HSV", output_dir, "PNG") + assert len(list(output_dir.glob("*.png"))) == 3 + + +def test_batch_process_custom_channel_mapping(setup_test_images, tmp_path): + output_dir = tmp_path / "output" + batch_process(setup_test_images, "RGB", output_dir, + "PNG", mapping=['B', 'G', 'R']) + assert len(list(output_dir.glob("*.png"))) == 3 + + +def test_batch_process_invalid_directory(tmp_path): + invalid_dir = tmp_path / "invalid" + output_dir = tmp_path / "output" + with 
pytest.raises(FileNotFoundError):
+        batch_process(invalid_dir, "RGB", output_dir, "PNG")
diff --git a/modules/lithium.pyimage/tests/test_curve.py b/modules/lithium.pyimage/tests/test_curve.py
new file mode 100644
index 00000000..82440a34
--- /dev/null
+++ b/modules/lithium.pyimage/tests/test_curve.py
@@ -0,0 +1,257 @@
+import json
+import pytest
+import numpy as np
+from pathlib import Path
+from ..curve import CurvesTransformation
+
+# FILE: modules/lithium.pyimage/image/transformation/test_curve.py
+
+
+@pytest.fixture
+def curve_transformation():
+    # Initialize CurvesTransformation with default interpolation
+    return CurvesTransformation()
+
+
+@pytest.fixture(params=['cubic', 'akima', 'linear'])
+def curve_transformation_with_valid_interpolation(request):
+    # Initialize CurvesTransformation with valid interpolation methods
+    return CurvesTransformation(interpolation=request.param)
+
+
+@pytest.fixture
+def curve_transformation_with_invalid_interpolation():
+    # Initialize CurvesTransformation with an invalid interpolation method
+    return CurvesTransformation(interpolation='invalid_method')
+
+
+@pytest.fixture
+def sample_image_grayscale():
+    # Create a sample 2D grayscale image
+    return np.linspace(0, 255, 256).reshape(16, 16).astype(np.uint8)
+
+
+@pytest.fixture
+def sample_image_rgb():
+    # Create a sample 3D RGB image
+    grayscale = np.linspace(0, 255, 256).reshape(16, 16).astype(np.uint8)
+    return np.stack([grayscale, grayscale, grayscale], axis=-1)
+
+
+@pytest.fixture
+def temp_json_file(tmp_path):
+    # Provide a temporary JSON file path
+    return tmp_path / "test_curve.json"
+
+
+def test_initialization_default(curve_transformation):
+    assert curve_transformation.interpolation == 'akima'
+    assert curve_transformation.curve is None
+    assert curve_transformation.points == []
+
+
+def test_initialization_with_valid_interpolation(curve_transformation_with_valid_interpolation):
+    assert curve_transformation_with_valid_interpolation.interpolation in [
+        'cubic', 'akima', 
'linear'] + + +def test_initialization_with_invalid_interpolation(curve_transformation_with_invalid_interpolation): + with pytest.raises(ValueError): + curve_transformation_with_invalid_interpolation._update_curve() + + +def test_add_point(curve_transformation): + curve_transformation.add_point(0.0, 0.0) + curve_transformation.add_point(1.0, 1.0) + assert curve_transformation.points == [(0.0, 0.0), (1.0, 1.0)] + assert curve_transformation.curve is not None + + +def test_remove_point(curve_transformation): + curve_transformation.add_point(0.0, 0.0) + curve_transformation.add_point(1.0, 1.0) + curve_transformation.remove_point(0) + assert curve_transformation.points == [(1.0, 1.0)] + assert curve_transformation.curve is None # Not enough points + + +def test_remove_point_invalid_index(curve_transformation): + curve_transformation.add_point(0.0, 0.0) + with pytest.raises(IndexError): + curve_transformation.remove_point(5) + + +def test_update_curve_with_insufficient_points(curve_transformation): + curve_transformation.add_point(0.0, 0.0) + assert curve_transformation.curve is None + + +def test_transform_grayscale_image(curve_transformation_with_valid_interpolation, sample_image_grayscale): + curve = curve_transformation_with_valid_interpolation + curve.add_point(0.0, 0.0) + curve.add_point(1.0, 255.0) + transformed_image = curve.transform(sample_image_grayscale) + assert transformed_image.shape == sample_image_grayscale.shape + assert transformed_image.dtype == np.uint8 + # Identity transformation + assert np.all(transformed_image == sample_image_grayscale) + + +def test_transform_rgb_image_channel(curve_transformation_with_valid_interpolation, sample_image_rgb): + curve = curve_transformation_with_valid_interpolation + curve.add_point(0.0, 0.0) + curve.add_point(1.0, 255.0) + transformed_image = curve.transform(sample_image_rgb, channel=1) + assert transformed_image.shape == sample_image_rgb.shape + assert transformed_image.dtype == np.uint8 + # Unchanged channels 
+ assert np.all(transformed_image[:, :, 0] == sample_image_rgb[:, :, 0]) + assert np.all(transformed_image[:, :, 1] == sample_image_rgb[:, :, 1]) + assert np.all(transformed_image[:, :, 2] == sample_image_rgb[:, :, 2]) + + +def test_transform_without_defining_curve(curve_transformation, sample_image_grayscale): + curve = curve_transformation + with pytest.raises(ValueError): + curve.transform(sample_image_grayscale) + + +def test_invert_curve(curve_transformation_with_valid_interpolation, sample_image_grayscale): + curve = curve_transformation_with_valid_interpolation + curve.add_point(0.0, 0.0) + curve.add_point(1.0, 1.0) + curve.invert_curve() + transformed_image = curve.transform(sample_image_grayscale) + expected_image = 255 - sample_image_grayscale + assert np.array_equal(transformed_image, expected_image) + + +def test_store_and_restore_curve(curve_transformation_with_valid_interpolation, sample_image_grayscale): + curve = curve_transformation_with_valid_interpolation + curve.add_point(0.0, 0.0) + curve.add_point(0.5, 128.0) + curve.add_point(1.0, 255.0) + curve.store_curve() + curve.add_point(0.25, 64.0) + assert len(curve.points) == 4 + curve.restore_curve() + assert len(curve.points) == 3 + assert curve.points == [(0.0, 0.0), (0.5, 128.0), (1.0, 255.0)] + transformed_image = curve.transform(sample_image_grayscale) + assert transformed_image.shape == sample_image_grayscale.shape + + +def test_reset_curve(curve_transformation_with_valid_interpolation, sample_image_grayscale): + curve = curve_transformation_with_valid_interpolation + curve.add_point(0.0, 0.0) + curve.add_point(0.5, 200.0) + curve.reset_curve() + assert curve.points == [(0.0, 0.0), (1.0, 1.0)] + transformed_image = curve.transform(sample_image_grayscale) + normalized = sample_image_grayscale.astype(np.float32) / 255.0 + expected = (normalized * 255).astype(np.uint8) + # Identity transformation + assert np.array_equal(transformed_image, expected) + + +def 
test_save_curve(curve_transformation_with_valid_interpolation, temp_json_file):
+    curve = curve_transformation_with_valid_interpolation
+    curve.add_point(0.0, 0.0)
+    curve.add_point(1.0, 255.0)
+    curve.save_curve(temp_json_file)
+    assert temp_json_file.exists()
+    with open(temp_json_file, 'r') as f:
+        data = json.load(f)
+    assert data['interpolation'] == curve.interpolation
+    # json round-trips tuples as lists, so compare against list pairs
+    assert data['points'] == [list(point) for point in curve.points]
+
+
+def test_load_curve(curve_transformation_with_valid_interpolation, temp_json_file):
+    # Prepare a curve file
+    data = {
+        'interpolation': 'linear',
+        'points': [(0.0, 0.0), (0.5, 128.0), (1.0, 255.0)]
+    }
+    with open(temp_json_file, 'w') as f:
+        json.dump(data, f, indent=4)
+    curve = curve_transformation_with_valid_interpolation
+    curve.load_curve(temp_json_file)
+    assert curve.interpolation == 'linear'
+    assert curve.points == [(0.0, 0.0), (0.5, 128.0), (1.0, 255.0)]
+
+
+def test_pixel_readout(curve_transformation_with_valid_interpolation):
+    curve = curve_transformation_with_valid_interpolation
+    curve.add_point(0.0, 0.0)
+    curve.add_point(1.0, 255.0)
+    value = curve.pixel_readout(0.5)
+    assert value == 127.5  # Midpoint of the straight line through (0, 0) and (1, 255)
+
+
+def test_pixel_readout_no_curve(curve_transformation):
+    value = curve_transformation.pixel_readout(0.5)
+    assert value is None
+
+
+def test_pixel_readout_invalid_input(curve_transformation_with_valid_interpolation):
+    curve = curve_transformation_with_valid_interpolation
+    curve.add_point(0.0, 0.0)
+    curve.add_point(1.0, 255.0)
+    with pytest.raises(ValueError):
+        curve.pixel_readout("invalid_input")  # Non-float input
+
+
+def test_export_curve_points(curve_transformation_with_valid_interpolation):
+    curve = curve_transformation_with_valid_interpolation
+    points = [(0.0, 0.0), (0.5, 128.0), (1.0, 255.0)]
+    for x, y in points:
+        curve.add_point(x, y)
+    exported_points = curve.export_curve_points()
+    assert exported_points == points
+
+
+def test_import_curve_points(curve_transformation_with_valid_interpolation):
curve = curve_transformation_with_valid_interpolation + points = [(0.0, 0.0), (0.25, 64.0), (0.75, 192.0), (1.0, 255.0)] + curve.import_curve_points(points) + assert curve.points == sorted(points, key=lambda point: point[0]) + transformed_image = curve.transform( + np.array([[0, 128, 255]], dtype=np.uint8)) + expected = np.array([[0, 128, 255]], dtype=np.uint8) + assert np.array_equal(transformed_image, expected) + + +def test_get_interpolation_methods(curve_transformation_with_valid_interpolation): + methods = curve_transformation_with_valid_interpolation.get_interpolation_methods() + assert methods == ['cubic', 'akima', 'linear'] + + +def test_transform_invalid_image_format(curve_transformation_with_valid_interpolation): + curve = curve_transformation_with_valid_interpolation + curve.add_point(0.0, 0.0) + curve.add_point(1.0, 255.0) + invalid_image = np.zeros((16, 16, 16), dtype=np.uint8) # Unsupported shape + with pytest.raises(ValueError): + curve.transform(invalid_image) + + +def test_transform_rgb_image_without_channel(curve_transformation_with_valid_interpolation, sample_image_rgb): + curve = curve_transformation_with_valid_interpolation + curve.add_point(0.0, 0.0) + curve.add_point(1.0, 255.0) + with pytest.raises(ValueError): + curve.transform(sample_image_rgb) # Missing channel parameter + + +def test_load_curve_invalid_file(curve_transformation_with_valid_interpolation, tmp_path): + invalid_file = tmp_path / "invalid_curve.json" + with open(invalid_file, 'w') as f: + f.write("Invalid JSON content") + with pytest.raises(json.JSONDecodeError): + curve_transformation_with_valid_interpolation.load_curve(invalid_file) + + +def test_load_curve_missing_file(curve_transformation_with_valid_interpolation): + with pytest.raises(FileNotFoundError): + curve_transformation_with_valid_interpolation.load_curve( + "non_existent_curve.json") diff --git a/modules/lithium.pyimage/tests/test_debayer.py b/modules/lithium.pyimage/tests/test_debayer.py new file mode 100644 
index 00000000..176f5715 --- /dev/null +++ b/modules/lithium.pyimage/tests/test_debayer.py @@ -0,0 +1,110 @@ +import pytest +import numpy as np +from pathlib import Path +from ..debayer import Debayer, DebayerConfig + +# FILE: modules/lithium.pyimage/image/debayer/test_debayer.py + + +@pytest.fixture +def sample_cfa_image(): + # Create a dummy CFA image (10x10) with a simple pattern + return np.tile(np.array([[0, 1], [1, 0]], dtype=np.uint8), (5, 5)) + + +@pytest.fixture +def debayer_config(): + # Provide a default DebayerConfig instance + return DebayerConfig() + + +@pytest.fixture +def debayer_instance(debayer_config): + # Provide a Debayer instance initialized with the default configuration + return Debayer(config=debayer_config) + + +def test_initialization_default(): + debayer = Debayer() + assert debayer.config.method == 'bilinear' + assert debayer.config.pattern is None + assert debayer.config.num_threads == 4 + assert not debayer.config.visualize_intermediate + assert debayer.config.visualization_save_path is None + assert not debayer.config.save_debayered_images + + +def test_initialization_custom(): + config = DebayerConfig(method='vng', pattern='RGGB', num_threads=2) + debayer = Debayer(config=config) + assert debayer.config.method == 'vng' + assert debayer.config.pattern == 'RGGB' + assert debayer.config.num_threads == 2 + + +def test_detect_bayer_pattern(debayer_instance, sample_cfa_image): + pattern = debayer_instance.detect_bayer_pattern(sample_cfa_image) + assert pattern in ['BGGR', 'RGGB', 'GBRG', 'GRBG'] + + +def test_debayer_superpixel(debayer_instance, sample_cfa_image): + debayer_instance.config.pattern = 'BGGR' + rgb_image = debayer_instance.debayer_superpixel(sample_cfa_image) + assert rgb_image.shape == (5, 5, 3) # Superpixel reduces size by half + + +def test_debayer_bilinear(debayer_instance, sample_cfa_image): + debayer_instance.config.pattern = 'BGGR' + rgb_image = debayer_instance.debayer_bilinear(sample_cfa_image) + assert 
rgb_image.shape == (10, 10, 3) + + +def test_debayer_vng(debayer_instance, sample_cfa_image): + debayer_instance.config.pattern = 'BGGR' + rgb_image = debayer_instance.debayer_vng(sample_cfa_image) + assert rgb_image.shape == (10, 10, 3) + + +def test_parallel_debayer_ahd(debayer_instance, sample_cfa_image): + debayer_instance.config.pattern = 'BGGR' + rgb_image = debayer_instance.parallel_debayer_ahd(sample_cfa_image) + assert rgb_image.shape == (10, 10, 3) + + +def test_debayer_laplacian_harmonization(debayer_instance, sample_cfa_image): + debayer_instance.config.pattern = 'BGGR' + rgb_image = debayer_instance.debayer_laplacian_harmonization( + sample_cfa_image) + assert rgb_image.shape == (10, 10, 3) + + +def test_unsupported_bayer_pattern(debayer_instance, sample_cfa_image): + debayer_instance.config.pattern = 'INVALID' + with pytest.raises(ValueError): + debayer_instance.debayer_bilinear(sample_cfa_image) + + +def test_unsupported_debayer_method(debayer_instance, sample_cfa_image): + debayer_instance.config.method = 'unsupported_method' + with pytest.raises(ValueError): + debayer_instance.debayer_image(sample_cfa_image) + + +def test_visualize_intermediate_steps(debayer_instance, sample_cfa_image, tmp_path): + debayer_instance.config.visualize_intermediate = True + debayer_instance.config.visualization_save_path = tmp_path / "intermediate_steps.png" + debayer_instance.config.pattern = 'BGGR' + debayered_image = debayer_instance.debayer_bilinear(sample_cfa_image) + Debayer.visualize_intermediate_steps( + "dummy_path", debayered_image, debayer_instance.config) + assert debayer_instance.config.visualization_save_path.exists() + + +def test_save_debayered_images(debayer_instance, sample_cfa_image, tmp_path): + debayer_instance.config.save_debayered_images = True + debayer_instance.config.debayered_save_path_template = str( + tmp_path / "{original_name}_{method}.png") + debayer_instance.config.pattern = 'BGGR' + debayer_instance.debayer_image(sample_cfa_image) + 
expected_path = tmp_path / "debayered_image_bilinear.png" + assert expected_path.exists() diff --git a/modules/lithium.pyimage/tests/test_defect_correction.py b/modules/lithium.pyimage/tests/test_defect_correction.py new file mode 100644 index 00000000..751878d3 --- /dev/null +++ b/modules/lithium.pyimage/tests/test_defect_correction.py @@ -0,0 +1,129 @@ +import pytest +import numpy as np +from ..defect_correction import defect_map_enhanced + +# FILE: modules/lithium.pyimage/image/defect_map/test_defect_correction.py + + +@pytest.fixture +def setup_test_image(): + # Create a test image and defect map + image = np.ones((10, 10), dtype=np.float32) * 100 + defect_map = np.ones((10, 10), dtype=np.float32) + defect_map[5, 5] = 0 # Introduce a defect + return image, defect_map + + +def test_defect_map_enhanced_mean(setup_test_image): + image, defect_map = setup_test_image + corrected_image = defect_map_enhanced(image, defect_map, operation='mean') + assert corrected_image[5, 5] != 100 # Defect should be corrected + + +def test_defect_map_enhanced_gaussian(setup_test_image): + image, defect_map = setup_test_image + corrected_image = defect_map_enhanced( + image, defect_map, operation='gaussian') + assert corrected_image[5, 5] != 100 # Defect should be corrected + + +def test_defect_map_enhanced_minimum(setup_test_image): + image, defect_map = setup_test_image + corrected_image = defect_map_enhanced( + image, defect_map, operation='minimum') + assert corrected_image[5, 5] != 100 # Defect should be corrected + + +def test_defect_map_enhanced_maximum(setup_test_image): + image, defect_map = setup_test_image + corrected_image = defect_map_enhanced( + image, defect_map, operation='maximum') + assert corrected_image[5, 5] != 100 # Defect should be corrected + + +def test_defect_map_enhanced_median(setup_test_image): + image, defect_map = setup_test_image + corrected_image = defect_map_enhanced( + image, defect_map, operation='median') + assert corrected_image[5, 5] != 100 # 
Defect should be corrected + + +def test_defect_map_enhanced_bilinear(setup_test_image): + image, defect_map = setup_test_image + corrected_image = defect_map_enhanced( + image, defect_map, operation='bilinear') + assert corrected_image[5, 5] != 100 # Defect should be corrected + + +def test_defect_map_enhanced_bicubic(setup_test_image): + image, defect_map = setup_test_image + corrected_image = defect_map_enhanced( + image, defect_map, operation='bicubic') + assert corrected_image[5, 5] != 100 # Defect should be corrected + + +def test_defect_map_enhanced_square_structure(setup_test_image): + image, defect_map = setup_test_image + corrected_image = defect_map_enhanced( + image, defect_map, structure='square') + assert corrected_image[5, 5] != 100 # Defect should be corrected + + +def test_defect_map_enhanced_circular_structure(setup_test_image): + image, defect_map = setup_test_image + corrected_image = defect_map_enhanced( + image, defect_map, structure='circular') + assert corrected_image[5, 5] != 100 # Defect should be corrected + + +def test_defect_map_enhanced_horizontal_structure(setup_test_image): + image, defect_map = setup_test_image + corrected_image = defect_map_enhanced( + image, defect_map, structure='horizontal') + assert corrected_image[5, 5] != 100 # Defect should be corrected + + +def test_defect_map_enhanced_vertical_structure(setup_test_image): + image, defect_map = setup_test_image + corrected_image = defect_map_enhanced( + image, defect_map, structure='vertical') + assert corrected_image[5, 5] != 100 # Defect should be corrected + + +def test_defect_map_enhanced_with_edge_protection(setup_test_image): + image, defect_map = setup_test_image + corrected_image = defect_map_enhanced( + image, defect_map, protect_edges=True) + assert corrected_image[5, 5] != 100 # Defect should be corrected + + +def test_defect_map_enhanced_without_edge_protection(setup_test_image): + image, defect_map = setup_test_image + corrected_image = defect_map_enhanced( + 
image, defect_map, protect_edges=False) + assert corrected_image[5, 5] != 100 # Defect should be corrected + + +def test_defect_map_enhanced_with_adaptive_structure(setup_test_image): + image, defect_map = setup_test_image + corrected_image = defect_map_enhanced( + image, defect_map, adaptive_structure=True) + assert corrected_image[5, 5] != 100 # Defect should be corrected + + +def test_defect_map_enhanced_without_adaptive_structure(setup_test_image): + image, defect_map = setup_test_image + corrected_image = defect_map_enhanced( + image, defect_map, adaptive_structure=False) + assert corrected_image[5, 5] != 100 # Defect should be corrected + + +def test_defect_map_enhanced_with_cfa(setup_test_image): + image, defect_map = setup_test_image + # Create a dummy RGB image + image = np.stack([image, image, image], axis=-1) + defect_map = np.stack([defect_map, defect_map, defect_map], axis=-1) + corrected_image = defect_map_enhanced(image, defect_map, is_cfa=True) + assert corrected_image[5, 5, 0] != 100 # Defect should be corrected + assert corrected_image[5, 5, 1] != 100 # Defect should be corrected + assert corrected_image[5, 5, 2] != 100 # Defect should be corrected diff --git a/modules/lithium.pyimage/tests/test_extraction.py b/modules/lithium.pyimage/tests/test_extraction.py new file mode 100644 index 00000000..dbd593ab --- /dev/null +++ b/modules/lithium.pyimage/tests/test_extraction.py @@ -0,0 +1,81 @@ +import pytest +import numpy as np +import cv2 +from pathlib import Path +from ..extraction import extract_channels, merge_channels, process_directory + +# FILE: modules/lithium.pyimage/image/channel/test_extraction.py + + +@pytest.fixture +def setup_test_image(tmp_path): + # Create a temporary directory with a test image + img_dir = tmp_path / "images" + img_dir.mkdir() + image = np.ones((100, 100, 3), dtype=np.uint8) * 128 + img_path = img_dir / "test_image.png" + cv2.imwrite(str(img_path), image) + return img_path + + +def 
test_extract_channels_rgb(setup_test_image): + image = cv2.imread(str(setup_test_image)) + channels = extract_channels(image, color_space='RGB') + assert 'R' in channels + assert 'G' in channels + assert 'B' in channels + assert channels['R'].shape == (100, 100) + assert channels['G'].shape == (100, 100) + assert channels['B'].shape == (100, 100) + + +def test_extract_channels_hsv(setup_test_image): + image = cv2.imread(str(setup_test_image)) + channels = extract_channels(image, color_space='HSV') + assert 'H' in channels + assert 'S' in channels + assert 'V' in channels + assert channels['H'].shape == (100, 100) + assert channels['S'].shape == (100, 100) + assert channels['V'].shape == (100, 100) + + +def test_merge_channels_rgb(setup_test_image): + image = cv2.imread(str(setup_test_image)) + channels = extract_channels(image, color_space='RGB') + merged_image = merge_channels(channels) + assert merged_image is not None + assert merged_image.shape == (100, 100, 3) + + +def test_merge_channels_hsv(setup_test_image): + image = cv2.imread(str(setup_test_image)) + channels = extract_channels(image, color_space='HSV') + merged_image = merge_channels(channels) + assert merged_image is not None + assert merged_image.shape == (100, 100, 3) + + +@pytest.fixture +def setup_test_directory(tmp_path): + # Create a temporary directory with test images + img_dir = tmp_path / "images" + img_dir.mkdir() + for i in range(3): + image = np.ones((100, 100, 3), dtype=np.uint8) * (i * 85) + img_path = img_dir / f"test_image_{i}.png" + cv2.imwrite(str(img_path), image) + return img_dir + + +def test_process_directory(setup_test_directory, tmp_path): + output_dir = tmp_path / "output" + process_directory(setup_test_directory, output_dir, color_space='RGB') + assert len(list(output_dir.glob("*.png"))) == 9 # 3 images * 3 channels + + +def test_process_directory_invalid_input(tmp_path): + invalid_dir = tmp_path / "invalid" + output_dir = tmp_path / "output" + with 
pytest.raises(NotADirectoryError): + process_directory(invalid_dir, output_dir, color_space='RGB') diff --git a/modules/lithium.pyimage/tests/test_histogram.py b/modules/lithium.pyimage/tests/test_histogram.py new file mode 100644 index 00000000..fcc08b7a --- /dev/null +++ b/modules/lithium.pyimage/tests/test_histogram.py @@ -0,0 +1,168 @@ +import pytest +import numpy as np +from unittest.mock import patch, MagicMock +from pathlib import Path +import sys +from ..histogram import main + +# FILE: modules/lithium.pyimage/image/transformation/test_histogram.py + + +@pytest.fixture +def sample_image(): + # Create a dummy BGR image + return np.zeros((100, 100, 3), dtype=np.uint8) + + +@pytest.fixture +def mock_args(): + args = [ + '--input', 'dummy_input.jpg', + '--output', 'dummy_output.jpg' + ] + return args + + +@pytest.fixture +def mock_cv2_imread(sample_image): + with patch('cv2.imread', return_value=sample_image) as mock_imread: + yield mock_imread + + +@pytest.fixture +def mock_cv2_imshow(): + with patch('cv2.imshow') as mock_imshow: + yield mock_imshow + + +@pytest.fixture +def mock_cv2_waitKey(): + with patch('cv2.waitKey', return_value=0) as mock_waitKey: + yield mock_waitKey + + +@pytest.fixture +def mock_cv2_destroyAllWindows(): + with patch('cv2.destroyAllWindows') as mock_destroy: + yield mock_destroy + + +@pytest.fixture +def mock_matplotlib_show(): + with patch('matplotlib.pyplot.show') as mock_show: + yield mock_show + + +@pytest.fixture +def mock_logger(): + with patch('histogram.logger') as mock_logger: + yield mock_logger + + +def test_main_success(mock_args, mock_cv2_imread, mock_cv2_imshow, + mock_cv2_waitKey, mock_cv2_destroyAllWindows, + mock_matplotlib_show, mock_logger): + with patch.object(sys, 'argv', ['histogram.py'] + mock_args): + main() + # Verify that imread was called with the correct input path + mock_cv2_imread.assert_called_with('dummy_input.jpg') + # Verify that imshow was called for original and transformed images + assert 
mock_cv2_imshow.call_count >= 4 # Original, transformed, etc. + # Verify that waitKey and destroyAllWindows were called + mock_cv2_waitKey.assert_called() + mock_cv2_destroyAllWindows.assert_called() + # Verify that matplotlib show was called + mock_matplotlib_show.assert_called() + # Verify logging calls + mock_logger.info.assert_any_call("Loading image from dummy_input.jpg") + mock_logger.info.assert_any_call( + "Image loaded successfully with shape (100, 100, 3)") + mock_logger.info.assert_any_call( + "Histogram transformation applied successfully") + mock_logger.info.assert_any_call("Applying auto clipping") + + +def test_main_invalid_input_path(mock_args, mock_cv2_imread, mock_logger): + # Configure imread to return None to simulate failed image load + mock_cv2_imread.return_value = None + with patch.object(sys, 'argv', ['histogram.py', '--input', 'non_existent.jpg', '--output', 'dummy_output.jpg']): + with pytest.raises(SystemExit) as exc_info: + main() + assert exc_info.value.code == 1 + # Verify that an error was logged + mock_logger.error.assert_called_with( + "Failed to load image: non_existent.jpg") + + +def test_main_save_histogram(mock_args, mock_cv2_imread, mock_cv2_imshow, + mock_cv2_waitKey, mock_cv2_destroyAllWindows, + mock_matplotlib_show, mock_logger): + args = [ + '--input', 'dummy_input.jpg', + '--output', 'dummy_output.jpg', + '--save_histogram', 'histogram.png' + ] + with patch.object(sys, 'argv', ['histogram.py'] + args): + with patch('histogram.save_histogram') as mock_save_hist: + main() + # Verify that save_histogram was called with correct arguments + mock_save_hist.assert_called_once() + mock_logger.info.assert_any_call( + "Saving histogram to histogram.png") + + +def test_main_real_time_preview(mock_args, mock_cv2_imread, mock_cv2_imshow, + mock_cv2_waitKey, mock_cv2_destroyAllWindows, + mock_matplotlib_show, mock_logger): + args = [ + '--input', 'dummy_input.jpg', + '--output', 'dummy_output.jpg', + '--real_time_preview' + ] + with 
patch.object(sys, 'argv', ['histogram.py'] + args): + with patch('histogram.real_time_preview') as mock_preview: + main() + # Verify that real_time_preview was called + mock_preview.assert_called_once() + + +def test_main_missing_arguments(): + # Test missing required arguments + with patch.object(sys, 'argv', ['histogram.py', '--input', 'dummy_input.jpg']): + with pytest.raises(SystemExit) as exc_info: + main() + assert exc_info.value.code != 0 # Expecting non-zero exit due to missing --output + + +def test_main_invalid_arguments(mock_args, mock_cv2_imread, mock_cv2_imshow, + mock_cv2_waitKey, mock_cv2_destroyAllWindows, + mock_matplotlib_show, mock_logger): + # Test with invalid operation argument + args = [ + '--input', 'dummy_input.jpg', + '--output', 'dummy_output.jpg', + '--operation', 'invalid_op' + ] + with patch.object(sys, 'argv', ['histogram.py'] + args): + with pytest.raises(SystemExit): + main() + # Verify that an error was logged + mock_logger.error.assert_called() + + +def test_main_save_histogram_failure(mock_args, mock_cv2_imread, mock_cv2_imshow, + mock_cv2_waitKey, mock_cv2_destroyAllWindows, + mock_matplotlib_show, mock_logger): + args = [ + '--input', 'dummy_input.jpg', + '--output', 'dummy_output.jpg', + '--save_histogram', '/invalid_path/histogram.png' + ] + with patch.object(sys, 'argv', ['histogram.py'] + args): + with patch('histogram.save_histogram', side_effect=Exception("Save failed")): + with pytest.raises(Exception) as exc_info: + main() + assert str(exc_info.value) == "Save failed" + # Verify that an error was logged + mock_logger.error.assert_called_with( + "Failed to save histogram: Save failed") diff --git a/modules/lithium.pyimage/tests/test_preview.py b/modules/lithium.pyimage/tests/test_preview.py new file mode 100644 index 00000000..cd5c40e8 --- /dev/null +++ b/modules/lithium.pyimage/tests/test_preview.py @@ -0,0 +1,42 @@ +import pytest +import numpy as np +import cv2 +from pathlib import Path +from 
..image.adaptive_stretch.preview import apply_real_time_preview
+
+# FILE: modules/lithium.pyimage/image/adaptive_stretch/test_preview.py
+
+
+def test_apply_real_time_preview_grayscale():
+    image = np.ones((100, 100), dtype=np.uint8) * 128
+    processed_image = apply_real_time_preview(image)
+    assert processed_image.shape == image.shape
+    assert processed_image.dtype == np.uint8
+
+
+def test_apply_real_time_preview_color():
+    image = np.ones((100, 100, 3), dtype=np.uint8) * 128
+    processed_image = apply_real_time_preview(image)
+    assert processed_image.shape == image.shape
+    assert processed_image.dtype == np.uint8
+
+
+def test_apply_real_time_preview_with_noise_threshold():
+    image = np.ones((100, 100), dtype=np.uint8) * 128
+    processed_image = apply_real_time_preview(image, noise_threshold=0.01)
+    assert processed_image.shape == image.shape
+    assert processed_image.dtype == np.uint8
+
+
+def test_apply_real_time_preview_with_contrast_protection():
+    image = np.ones((100, 100), dtype=np.uint8) * 128
+    processed_image = apply_real_time_preview(image, contrast_protection=0.5)
+    assert processed_image.shape == image.shape
+    assert processed_image.dtype == np.uint8
+
+
+def test_apply_real_time_preview_with_roi():
+    image = np.ones((100, 100), dtype=np.uint8) * 128
+    processed_image = apply_real_time_preview(image, roi=(10, 10, 50, 50))
+    assert processed_image.shape == image.shape
+    assert processed_image.dtype == np.uint8
diff --git a/modules/lithium.pyimage/tests/test_raw.py b/modules/lithium.pyimage/tests/test_raw.py
new file mode 100644
index 00000000..85470e35
--- /dev/null
+++ b/modules/lithium.pyimage/tests/test_raw.py
@@ -0,0 +1,108 @@
+import pytest
+import numpy as np
+import cv2
+from pathlib import Path
+from ..raw import RawImageProcessor, ImageFormat
+
+# FILE: modules/lithium.pyimage/image/raw/test_raw.py
+
+
+@pytest.fixture
+def setup_raw_image(tmp_path):
+    # Create a temporary directory with a test RAW image
+    raw_image_path = tmp_path / 
"test_image.raw" + # Create a dummy RAW image file (this should be replaced with an actual RAW file for real tests) + with open(raw_image_path, 'wb') as f: + f.write(b'\x00' * 1024) # Dummy content + return raw_image_path + + +def test_raw_image_processor_initialization(setup_raw_image): + processor = RawImageProcessor(raw_path=setup_raw_image) + assert processor.raw_path == setup_raw_image + assert processor.rgb_image is not None + assert processor.bgr_image is not None + + +def test_adjust_contrast(setup_raw_image): + processor = RawImageProcessor(raw_path=setup_raw_image) + original_image = processor.bgr_image.copy() + processor.adjust_contrast(alpha=1.5) + assert not np.array_equal(processor.bgr_image, original_image) + + +def test_adjust_brightness(setup_raw_image): + processor = RawImageProcessor(raw_path=setup_raw_image) + original_image = processor.bgr_image.copy() + processor.adjust_brightness(beta=50) + assert not np.array_equal(processor.bgr_image, original_image) + + +def test_apply_sharpening(setup_raw_image): + processor = RawImageProcessor(raw_path=setup_raw_image) + original_image = processor.bgr_image.copy() + processor.apply_sharpening() + assert not np.array_equal(processor.bgr_image, original_image) + + +def test_apply_gamma_correction(setup_raw_image): + processor = RawImageProcessor(raw_path=setup_raw_image) + original_image = processor.bgr_image.copy() + processor.apply_gamma_correction(gamma=1.5) + assert not np.array_equal(processor.bgr_image, original_image) + + +def test_rotate_image(setup_raw_image): + processor = RawImageProcessor(raw_path=setup_raw_image) + original_image = processor.bgr_image.copy() + processor.rotate_image(angle=45) + assert not np.array_equal(processor.bgr_image, original_image) + + +def test_resize_image(setup_raw_image): + processor = RawImageProcessor(raw_path=setup_raw_image) + original_image = processor.bgr_image.copy() + processor.resize_image(width=200) + assert processor.bgr_image.shape[1] == 200 + + +def 
test_adjust_color_balance(setup_raw_image): + processor = RawImageProcessor(raw_path=setup_raw_image) + original_image = processor.bgr_image.copy() + processor.adjust_color_balance(red=1.2, green=1.1, blue=0.9) + assert not np.array_equal(processor.bgr_image, original_image) + + +def test_apply_blur(setup_raw_image): + processor = RawImageProcessor(raw_path=setup_raw_image) + original_image = processor.bgr_image.copy() + processor.apply_blur(ksize=(5, 5), method="gaussian") + assert not np.array_equal(processor.bgr_image, original_image) + + +def test_histogram_equalization(setup_raw_image): + processor = RawImageProcessor(raw_path=setup_raw_image) + original_image = processor.bgr_image.copy() + processor.histogram_equalization() + assert not np.array_equal(processor.bgr_image, original_image) + + +def test_convert_to_grayscale(setup_raw_image): + processor = RawImageProcessor(raw_path=setup_raw_image) + processor.convert_to_grayscale() + assert len(processor.bgr_image.shape) == 2 + + +def test_save_image(setup_raw_image, tmp_path): + processor = RawImageProcessor(raw_path=setup_raw_image) + output_path = tmp_path / "output_image.png" + processor.save_image(output_path=output_path, file_format=ImageFormat.PNG) + assert output_path.exists() + + +def test_reset(setup_raw_image): + processor = RawImageProcessor(raw_path=setup_raw_image) + processor.adjust_contrast(alpha=1.5) + processor.reset() + assert np.array_equal(processor.bgr_image, cv2.cvtColor( + processor.rgb_image, cv2.COLOR_RGB2BGR)) diff --git a/modules/lithium.pyimage/tests/test_resample.py b/modules/lithium.pyimage/tests/test_resample.py new file mode 100644 index 00000000..9257e6ea --- /dev/null +++ b/modules/lithium.pyimage/tests/test_resample.py @@ -0,0 +1,128 @@ +import pytest +import numpy as np +import cv2 +from pathlib import Path +from ..resample import Resampler, ImageFormat + +# FILE: modules/lithium.pyimage/image/resample/test_resample.py + + +@pytest.fixture +def setup_test_image(tmp_path): + 
# Create a temporary directory with a test image + img_dir = tmp_path / "images" + img_dir.mkdir() + image = np.ones((100, 100, 3), dtype=np.uint8) * 128 + img_path = img_dir / "test_image.jpg" + cv2.imwrite(str(img_path), image) + return img_path + + +def test_resampler_initialization(setup_test_image, tmp_path): + output_path = tmp_path / "output_image.jpg" + resampler = Resampler( + input_image_path=setup_test_image, + output_image_path=output_path + ) + assert resampler.input_image_path == setup_test_image + assert resampler.output_image_path == output_path + + +def test_adjust_brightness_contrast(setup_test_image, tmp_path): + output_path = tmp_path / "output_image.jpg" + resampler = Resampler( + input_image_path=setup_test_image, + output_image_path=output_path, + brightness=1.5, + contrast=1.5 + ) + resampler.process() + processed_image = cv2.imread(str(output_path)) + assert processed_image is not None + + +def test_apply_sharpening(setup_test_image, tmp_path): + output_path = tmp_path / "output_image.jpg" + resampler = Resampler( + input_image_path=setup_test_image, + output_image_path=output_path, + sharpen=True + ) + resampler.process() + processed_image = cv2.imread(str(output_path)) + assert processed_image is not None + + +def test_apply_blur(setup_test_image, tmp_path): + output_path = tmp_path / "output_image.jpg" + resampler = Resampler( + input_image_path=setup_test_image, + output_image_path=output_path, + blur=(5, 5), + blur_method='gaussian' + ) + resampler.process() + processed_image = cv2.imread(str(output_path)) + assert processed_image is not None + + +def test_histogram_equalization(setup_test_image, tmp_path): + output_path = tmp_path / "output_image.jpg" + resampler = Resampler( + input_image_path=setup_test_image, + output_image_path=output_path, + histogram_equalization=True + ) + resampler.process() + processed_image = cv2.imread(str(output_path)) + assert processed_image is not None + + +def 
test_convert_to_grayscale(setup_test_image, tmp_path): + output_path = tmp_path / "output_image.jpg" + resampler = Resampler( + input_image_path=setup_test_image, + output_image_path=output_path, + grayscale=True + ) + resampler.process() + processed_image = cv2.imread(str(output_path), cv2.IMREAD_GRAYSCALE) + assert processed_image is not None + assert len(processed_image.shape) == 2 + + +def test_rotate_image(setup_test_image, tmp_path): + output_path = tmp_path / "output_image.jpg" + resampler = Resampler( + input_image_path=setup_test_image, + output_image_path=output_path, + rotate_angle=45 + ) + resampler.process() + processed_image = cv2.imread(str(output_path)) + assert processed_image is not None + + +def test_resize_image(setup_test_image, tmp_path): + output_path = tmp_path / "output_image.jpg" + resampler = Resampler( + input_image_path=setup_test_image, + output_image_path=output_path, + width=200, + height=200 + ) + resampler.process() + processed_image = cv2.imread(str(output_path)) + assert processed_image.shape[1] == 200 + assert processed_image.shape[0] == 200 + + +def test_process_invalid_input(tmp_path): + invalid_image_path = tmp_path / "invalid_image.jpg" + output_path = tmp_path / "output_image.jpg" + resampler = Resampler( + input_image_path=invalid_image_path, + output_image_path=output_path + ) + with pytest.raises(ValueError): + resampler.process() diff --git a/modules/lithium.pyimage/tests/test_stretch.py b/modules/lithium.pyimage/tests/test_stretch.py new file mode 100644 index 00000000..582fb4a7 --- /dev/null +++ b/modules/lithium.pyimage/tests/test_stretch.py @@ -0,0 +1,64 @@ +import pytest +import numpy as np +import cv2 +from pathlib import Path +from ..image.adaptive_stretch.stretch import AdaptiveStretch + +# FILE: modules/lithium.pyimage/image/adaptive_stretch/test_stretch.py + + +def test_adaptive_stretch_initialization(): + stretcher = AdaptiveStretch(noise_threshold=0.01, contrast_protection=0.5, max_curve_points=50, roi=( + 
10, 10, 50, 50), save_intermediate=True, intermediate_dir=Path("test_intermediate")) + assert stretcher.noise_threshold == 0.01 + assert stretcher.contrast_protection == 0.5 + assert stretcher.max_curve_points == 50 + assert stretcher.roi == (10, 10, 50, 50) + assert stretcher.save_intermediate is True + assert stretcher.intermediate_dir == Path("test_intermediate") + + +def test_compute_brightness_diff(): + stretcher = AdaptiveStretch() + image = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.uint8) + diff_x, diff_y = stretcher.compute_brightness_diff(image) + expected_diff_x = np.array( + [[1, 1, 0], [1, 1, 0], [1, 1, 0]], dtype=np.int8) + expected_diff_y = np.array( + [[3, 3, 3], [3, 3, 3], [0, 0, 0]], dtype=np.int8) + assert np.array_equal(diff_x, expected_diff_x) + assert np.array_equal(diff_y, expected_diff_y) + + +def test_stretch_grayscale_image(): + stretcher = AdaptiveStretch() + image = np.ones((100, 100), dtype=np.uint8) * 128 + stretched_image = stretcher.stretch(image) + assert stretched_image.shape == image.shape + assert stretched_image.dtype == np.uint8 + + +def test_stretch_color_image(): + stretcher = AdaptiveStretch() + image = np.ones((100, 100, 3), dtype=np.uint8) * 128 + stretched_image = stretcher.stretch(image) + assert stretched_image.shape == image.shape + assert stretched_image.dtype == np.uint8 + + +def test_stretch_with_roi(): + stretcher = AdaptiveStretch(roi=(10, 10, 50, 50)) + image = np.ones((100, 100), dtype=np.uint8) * 128 + stretched_image = stretcher.stretch(image) + assert stretched_image.shape == image.shape + assert stretched_image.dtype == np.uint8 + + +def test_save_intermediate_results(tmp_path): + intermediate_dir = tmp_path / "intermediate" + stretcher = AdaptiveStretch( + save_intermediate=True, intermediate_dir=intermediate_dir) + image = np.ones((100, 100), dtype=np.uint8) * 128 + stretcher.stretch(image) + assert intermediate_dir.exists() + assert len(list(intermediate_dir.glob("*.png"))) > 0 diff --git 
a/modules/lithium.pyimage/tests/test_utils.py b/modules/lithium.pyimage/tests/test_utils.py new file mode 100644 index 00000000..f8ed1a5d --- /dev/null +++ b/modules/lithium.pyimage/tests/test_utils.py @@ -0,0 +1,72 @@ +import pytest +import numpy as np +from pathlib import Path +from astropy.io import fits +from ..utils import read_fits_header +from ..core import CalibrationParams + +# FILE: modules/lithium.pyimage/image/fluxcalibration/test_utils.py + + +@pytest.fixture +def setup_fits_file(tmp_path): + # Create a temporary directory with a test FITS file + fits_path = tmp_path / "test_image.fits" + data = np.zeros((100, 100), dtype=np.float32) + hdu = fits.PrimaryHDU(data) + hdr = hdu.header + hdr['WAVELEN'] = 500.0 + hdr['TRANSMIS'] = 0.8 + hdr['FILTWDTH'] = 100.0 + hdr['APERTURE'] = 200.0 + hdr['OBSTRUCT'] = 50.0 + hdr['EXPTIME'] = 60.0 + hdr['EXTINCT'] = 0.1 + hdr['GAIN'] = 1.5 + hdr['QUANTEFF'] = 0.9 + hdu.writeto(fits_path) + return fits_path + + +def test_read_fits_header_valid(setup_fits_file): + params = read_fits_header(str(setup_fits_file)) + assert params.wavelength == 500.0 + assert params.transmissivity == 0.8 + assert params.filter_width == 100.0 + assert params.aperture == 200.0 + assert params.obstruction == 50.0 + assert params.exposure_time == 60.0 + assert params.extinction == 0.1 + assert params.gain == 1.5 + assert params.quantum_efficiency == 0.9 + + +def test_read_fits_header_missing_keys(tmp_path): + fits_path = tmp_path / "test_image_missing_keys.fits" + data = np.zeros((100, 100), dtype=np.float32) + hdu = fits.PrimaryHDU(data) + hdu.writeto(fits_path) + params = read_fits_header(str(fits_path)) + assert params.wavelength == 550 # Default value + assert params.transmissivity == 0.8 # Default value + assert params.filter_width == 100 # Default value + assert params.aperture == 200 # Default value + assert params.obstruction == 50 # Default value + assert params.exposure_time == 60 # Default value + assert params.extinction == 0.1 # 
Default value + assert params.gain == 1.5 # Default value + assert params.quantum_efficiency == 0.9 # Default value + + +def test_read_fits_header_invalid_file(tmp_path): + invalid_fits_path = tmp_path / "invalid_image.fits" + with open(invalid_fits_path, 'w') as f: + f.write("This is not a valid FITS file.") + with pytest.raises(OSError): + read_fits_header(str(invalid_fits_path)) + + +def test_read_fits_header_non_existent_file(): + non_existent_path = "non_existent_file.fits" + with pytest.raises(OSError): + read_fits_header(non_existent_path) diff --git a/modules/lithium.pytools/tests/test_atom_generator.py b/modules/lithium.pytools/tests/test_atom_generator.py new file mode 100644 index 00000000..7bcbf424 --- /dev/null +++ b/modules/lithium.pytools/tests/test_atom_generator.py @@ -0,0 +1,114 @@ +import pytest +from unittest.mock import patch, MagicMock +from pathlib import Path +from ..atom_generator import generate_atom_module +import tempfile + +# FILE: modules/lithium.pytools/tools/test_atom_generator.py + + +@pytest.fixture +def sample_header_file(): + # Create a temporary C++ header file + header_content = """ + namespace TestNamespace { + class TestClass { + public: + void testMethod(); + }; + void testFunction(); + } + """ + with tempfile.NamedTemporaryFile(delete=False, suffix=".h") as temp_file: + temp_file.write(header_content.encode('utf-8')) + temp_file_path = Path(temp_file.name) + yield temp_file_path + temp_file_path.unlink() # Clean up the file after the test + + +@pytest.fixture +def mock_parse_header_file(): + with patch('modules.lithium.pytools.tools.atom_generator.parse_header_file') as mock_parse: + mock_parse.return_value = ( + {'TestNamespace::TestClass': ['testMethod']}, + ['TestNamespace::testFunction'] + ) + yield mock_parse + + +@pytest.fixture +def mock_logger(): + with patch('modules.lithium.pytools.tools.atom_generator.logger') as mock_log: + yield mock_log + + +def test_generate_atom_module_single_file(sample_header_file, 
mock_parse_header_file, mock_logger): + generate_atom_module([sample_header_file], log_level="DEBUG") + mock_logger.info.assert_any_call("Generating ATOM_MODULE...\n") + mock_logger.info.assert_any_call( + "Registered method: TestNamespace::TestClass::testMethod") + mock_logger.info.assert_any_call( + "Registered global function: TestNamespace::testFunction") + mock_logger.info.assert_any_call("ATOM_MODULE generation completed.\n") + + +def test_generate_atom_module_multiple_files(sample_header_file, mock_parse_header_file, mock_logger): + generate_atom_module( + [sample_header_file, sample_header_file], log_level="DEBUG") + mock_logger.info.assert_any_call("Generating ATOM_MODULE...\n") + # Ensure multiple calls for multiple files + assert mock_logger.info.call_count >= 6 + + +def test_generate_atom_module_with_whitelist(sample_header_file, mock_parse_header_file, mock_logger): + generate_atom_module([sample_header_file], whitelist=[ + 'testMethod'], log_level="DEBUG") + mock_logger.info.assert_any_call("Generating ATOM_MODULE...\n") + mock_logger.info.assert_any_call( + "Registered method: TestNamespace::TestClass::testMethod") + mock_logger.info.assert_any_call("ATOM_MODULE generation completed.\n") + + +def test_generate_atom_module_with_blacklist(sample_header_file, mock_parse_header_file, mock_logger): + generate_atom_module([sample_header_file], blacklist=[ + 'testMethod'], log_level="DEBUG") + mock_logger.info.assert_any_call("Generating ATOM_MODULE...\n") + mock_logger.info.assert_any_call("ATOM_MODULE generation completed.\n") + assert not any( + call for call in mock_logger.info.call_args_list if "Registered method" in call[0][0]) + + +def test_generate_atom_module_output_to_file(sample_header_file, mock_parse_header_file, mock_logger): + with tempfile.NamedTemporaryFile(delete=False, suffix=".cpp") as temp_file: + output_path = Path(temp_file.name) + generate_atom_module([sample_header_file], + output_file=output_path, log_level="DEBUG") + with 
open(output_path, 'r') as f: + generated_code = f.read() + assert 'ATOM_MODULE(all_components, [](Component &component) {' in generated_code + output_path.unlink() # Clean up the file after the test + + +def test_generate_atom_module_output_to_console(sample_header_file, mock_parse_header_file, mock_logger): + with patch('builtins.print') as mock_print: + generate_atom_module([sample_header_file], log_level="DEBUG") + mock_print.assert_any_call( + 'ATOM_MODULE(all_components, [](Component &component) {', end="") + + +def test_generate_atom_module_invalid_header_file(mock_parse_header_file, mock_logger): + with pytest.raises(Exception): + generate_atom_module( + [Path('/invalid/path/to/header.h')], log_level="DEBUG") + mock_logger.error.assert_called() + + +def test_generate_atom_module_missing_methods_or_functions(sample_header_file, mock_logger): + with patch('modules.lithium.pytools.tools.atom_generator.parse_header_file', return_value=({}, [])): + generate_atom_module([sample_header_file], log_level="DEBUG") + mock_logger.info.assert_any_call("Generating ATOM_MODULE...\n") + mock_logger.info.assert_any_call("ATOM_MODULE generation completed.\n") + assert not any( + call for call in mock_logger.info.call_args_list if "Registered method" in call[0][0]) + assert not any( + call for call in mock_logger.info.call_args_list if "Registered global function" in call[0][0]) diff --git a/modules/lithium.pytools/tests/test_libclang_finder.py b/modules/lithium.pytools/tests/test_libclang_finder.py new file mode 100644 index 00000000..c246d5ab --- /dev/null +++ b/modules/lithium.pytools/tests/test_libclang_finder.py @@ -0,0 +1,141 @@ +import pytest +from unittest.mock import patch, MagicMock +from pathlib import Path +from ..libclang_finder import LibClangFinder, LibClangFinderConfig + +# FILE: modules/lithium.pytools/tools/test_libclang_finder.py + + +@pytest.fixture +def libclang_finder_config(): + # Provide a default LibClangFinderConfig instance + return 
LibClangFinderConfig()
+
+
+@pytest.fixture
+def libclang_finder_instance(libclang_finder_config):
+    # Provide a LibClangFinder instance initialized with the default configuration
+    return LibClangFinder(config=libclang_finder_config)
+
+
+def test_initialization_default():
+    config = LibClangFinderConfig()
+    finder = LibClangFinder(config)
+    # Assumed defaults: no custom libclang path and cache clearing disabled
+    # (the only LibClangFinderConfig fields this test file exercises).
+    assert finder.config.custom_path is None
+    assert finder.config.clear_cache is False
+
+
+def test_initialization_custom():
+    config = LibClangFinderConfig(custom_path=Path(
+        '/custom/path/to/libclang.so'), clear_cache=True)
+    finder = LibClangFinder(config)
+    assert finder.config.custom_path == Path('/custom/path/to/libclang.so')
+    assert finder.config.clear_cache is True
+
+
+def test_clear_cache(libclang_finder_instance):
+    with patch('pathlib.Path.unlink') as mock_unlink:
+        libclang_finder_instance.clear_cache()
+        mock_unlink.assert_called_once()
+
+
+def test_cache_libclang_path(libclang_finder_instance):
+    path = Path('/path/to/libclang.so')
+    with patch('pathlib.Path.write_text') as mock_write_text:
+        libclang_finder_instance.cache_libclang_path(path)
+        mock_write_text.assert_called_once_with(str(path))
+
+
+def test_load_cached_libclang_path(libclang_finder_instance):
+    path = Path('/path/to/libclang.so')
+    with patch('pathlib.Path.exists', return_value=True), \
+            patch('pathlib.Path.read_text', return_value=str(path)), \
+            patch('pathlib.Path.is_file', return_value=True):
+        cached_path = libclang_finder_instance.load_cached_libclang_path()
+        assert cached_path == path
+
+
+def test_find_libclang_linux(libclang_finder_instance):
+    with patch('glob.glob', return_value=['/usr/lib/llvm-18/lib/libclang.so']), \
+            patch('pathlib.Path.is_file', return_value=True):
+        path = libclang_finder_instance.find_libclang_linux()
+        assert path == 
Path('/usr/lib/llvm-18/lib/libclang.so') + + +def test_find_libclang_macos(libclang_finder_instance): + with patch('glob.glob', return_value=['/usr/local/opt/llvm/lib/libclang.dylib']), \ + patch('pathlib.Path.is_file', return_value=True): + path = libclang_finder_instance.find_libclang_macos() + assert path == Path('/usr/local/opt/llvm/lib/libclang.dylib') + + +def test_find_libclang_windows(libclang_finder_instance): + with patch('glob.glob', return_value=['C:\\Program Files\\LLVM\\bin\\libclang.dll']), \ + patch('pathlib.Path.is_file', return_value=True): + path = libclang_finder_instance.find_libclang_windows() + assert path == Path('C:\\Program Files\\LLVM\\bin\\libclang.dll') + + +def test_search_paths(libclang_finder_instance): + patterns = ['/usr/lib/llvm-*/lib/libclang.so*'] + with patch('glob.glob', return_value=['/usr/lib/llvm-18/lib/libclang.so']), \ + patch('pathlib.Path.is_file', return_value=True): + paths = libclang_finder_instance.search_paths(patterns) + assert paths == [Path('/usr/lib/llvm-18/lib/libclang.so')] + + +def test_select_libclang_path(libclang_finder_instance): + paths = [Path('/usr/lib/llvm-18/lib/libclang.so')] + selected_path = libclang_finder_instance.select_libclang_path(paths) + assert selected_path == Path('/usr/lib/llvm-18/lib/libclang.so') + + +def test_get_libclang_path_with_custom_path(libclang_finder_instance): + libclang_finder_instance.config.custom_path = Path( + '/custom/path/to/libclang.so') + with patch('pathlib.Path.is_file', return_value=True), \ + patch.object(libclang_finder_instance, 'cache_libclang_path') as mock_cache: + path = libclang_finder_instance.get_libclang_path() + assert path == Path('/custom/path/to/libclang.so') + mock_cache.assert_called_once_with(Path('/custom/path/to/libclang.so')) + + +def test_get_libclang_path_with_cached_path(libclang_finder_instance): + cached_path = Path('/cached/path/to/libclang.so') + with patch.object(libclang_finder_instance, 'load_cached_libclang_path', 
return_value=cached_path): + path = libclang_finder_instance.get_libclang_path() + assert path == cached_path + + +def test_get_libclang_path_with_detection(libclang_finder_instance): + with patch('platform.system', return_value='Linux'), \ + patch.object(libclang_finder_instance, 'find_libclang_linux', return_value=Path('/usr/lib/llvm-18/lib/libclang.so')), \ + patch.object(libclang_finder_instance, 'cache_libclang_path') as mock_cache: + path = libclang_finder_instance.get_libclang_path() + assert path == Path('/usr/lib/llvm-18/lib/libclang.so') + mock_cache.assert_called_once_with( + Path('/usr/lib/llvm-18/lib/libclang.so')) + + +def test_get_libclang_path_unsupported_os(libclang_finder_instance): + with patch('platform.system', return_value='UnsupportedOS'): + with pytest.raises(RuntimeError, match="Unsupported operating system: UnsupportedOS"): + libclang_finder_instance.get_libclang_path() + + +def test_configure_clang(libclang_finder_instance): + with patch.object(libclang_finder_instance, 'get_libclang_path', return_value=Path('/path/to/libclang.so')), \ + patch('clang.cindex.Config.set_library_file') as mock_set_library_file: + libclang_finder_instance.configure_clang() + mock_set_library_file.assert_called_once_with('/path/to/libclang.so') + + +def test_list_libclang_versions(libclang_finder_instance): + with patch('platform.system', return_value='Linux'), \ + patch.object(libclang_finder_instance, 'find_libclang_linux', return_value=[Path('/usr/lib/llvm-18/lib/libclang.so')]): + paths = libclang_finder_instance.list_libclang_versions() + assert paths == [Path('/usr/lib/llvm-18/lib/libclang.so')] diff --git a/pysrc/addon/generator.py b/modules/lithium.pytools/tools/atom_generator.py similarity index 75% rename from pysrc/addon/generator.py rename to modules/lithium.pytools/tools/atom_generator.py index bd852102..32992d9c 100644 --- a/pysrc/addon/generator.py +++ b/modules/lithium.pytools/tools/atom_generator.py @@ -23,6 +23,9 @@ import threading import 
yaml from concurrent.futures import ThreadPoolExecutor +from pathlib import Path +from typing import List, Optional, Dict, Tuple +from dataclasses import dataclass, field from .libclang_finder import get_libclang_path @@ -39,7 +42,22 @@ DEFAULT_WHITELIST = [] -def parse_args(): +@dataclass +class GeneratorConfig: + """ + Configuration parameters for the ATOM_MODULE generator. + """ + filepaths: List[Path] + output: Optional[Path] = None + log_level: str = "INFO" + blacklist: List[str] = field(default_factory=lambda: DEFAULT_BLACKLIST) + whitelist: List[str] = field(default_factory=lambda: DEFAULT_WHITELIST) + module_name: str = "all_components" + instance_prefix: str = "" + config_file: Optional[Path] = None + + +def parse_args() -> GeneratorConfig: """ Parses command-line arguments for the script. @@ -48,13 +66,13 @@ def parse_args(): instance prefix. Returns: - argparse.Namespace: The parsed command-line arguments. + GeneratorConfig: The parsed command-line arguments. """ parser = argparse.ArgumentParser( description="Generate ATOM_MODULE from C++ headers.") - parser.add_argument("filepaths", nargs='+', + parser.add_argument("filepaths", nargs='+', type=Path, help="Paths to the C++ header files.") - parser.add_argument("--output", default=None, + parser.add_argument("--output", type=Path, default=None, help="Output file for generated module code.") parser.add_argument("--log-level", default="INFO", help="Set the log level (DEBUG, INFO, WARNING, ERROR).") @@ -66,22 +84,32 @@ def parse_args(): help="Name of the generated ATOM_MODULE.") parser.add_argument("--instance-prefix", default="", help="Prefix for instance names in the module.") - parser.add_argument("--config-file", default=None, + parser.add_argument("--config-file", type=Path, default=None, help="Path to a YAML configuration file.") - return parser.parse_args() + args = parser.parse_args() + return GeneratorConfig( + filepaths=args.filepaths, + output=args.output, + log_level=args.log_level, + 
blacklist=args.blacklist, + whitelist=args.whitelist, + module_name=args.module_name, + instance_prefix=args.instance_prefix, + config_file=args.config_file + ) -def load_config(config_file): +def load_config(config_file: Optional[Path]) -> Dict: """ Loads configuration from a YAML file. Args: - config_file (str): Path to the configuration file. + config_file (Optional[Path]): Path to the configuration file. Returns: dict: The loaded configuration. """ - if config_file and os.path.isfile(config_file): + if config_file and config_file.is_file(): with open(config_file, 'r') as f: config = yaml.safe_load(f) logger.info(f"Loaded configuration from {config_file}") @@ -89,14 +117,14 @@ def load_config(config_file): return {} -def is_in_list(name, whitelist, blacklist): +def is_in_list(name: str, whitelist: List[str], blacklist: List[str]) -> bool: """ Checks if a name is in the whitelist and not in the blacklist. Args: name (str): The name to check. - whitelist (list): The list of whitelisted names. - blacklist (list): The list of blacklisted names. + whitelist (List[str]): The list of whitelisted names. + blacklist (List[str]): The list of blacklisted names. Returns: bool: True if the name is in the whitelist or if the whitelist is empty, and not in the blacklist. @@ -108,15 +136,15 @@ def is_in_list(name, whitelist, blacklist): return True -def find_classes_methods_and_functions(node, namespace="", whitelist=None, blacklist=None): +def find_classes_methods_and_functions(node: clang.cindex.Cursor, namespace: str = "", whitelist: Optional[List[str]] = None, blacklist: Optional[List[str]] = None) -> Tuple[Dict[str, List[str]], List[str]]: """ Recursively finds classes, methods, and functions within the AST. Args: node (clang.cindex.Cursor): The current AST node. namespace (str, optional): The current namespace. - whitelist (list, optional): List of whitelisted functions or methods. - blacklist (list, optional): List of blacklisted functions or methods. 
+ whitelist (List[str], optional): List of whitelisted functions or methods. + blacklist (List[str], optional): List of blacklisted functions or methods. Returns: tuple: A dictionary of classes and their methods, and a list of functions. @@ -155,34 +183,34 @@ def find_classes_methods_and_functions(node, namespace="", whitelist=None, black return classes, functions -def parse_header_file(filepath, whitelist=None, blacklist=None): +def parse_header_file(filepath: Path, whitelist: Optional[List[str]] = None, blacklist: Optional[List[str]] = None) -> Tuple[Dict[str, List[str]], List[str]]: """ Parses a C++ header file to extract classes, methods, and functions. Args: - filepath (str): Path to the C++ header file. - whitelist (list, optional): List of whitelisted functions or methods. - blacklist (list, optional): List of blacklisted functions or methods. + filepath (Path): Path to the C++ header file. + whitelist (List[str], optional): List of whitelisted functions or methods. + blacklist (List[str], optional): List of blacklisted functions or methods. Returns: tuple: A dictionary of classes and their methods, and a list of functions. """ index = clang.cindex.Index.create() - translation_unit = index.parse(filepath) + translation_unit = index.parse(str(filepath)) logger.info(f"Parsing the file: {filepath}\n") return find_classes_methods_and_functions(translation_unit.cursor, whitelist=whitelist, blacklist=blacklist) -def generate_atom_module(filepaths, output_file=None, log_level="INFO", whitelist=None, blacklist=None, module_name="all_components", instance_prefix=""): +def generate_atom_module(filepaths: List[Path], output_file: Optional[Path] = None, log_level: str = "INFO", whitelist: Optional[List[str]] = None, blacklist: Optional[List[str]] = None, module_name: str = "all_components", instance_prefix: str = ""): """ Generates ATOM_MODULE code from specified C++ header files. Args: - filepaths (list): List of paths to C++ header files. 
- output_file (str, optional): Path to the output file for generated code. If None, prints to console. + filepaths (List[Path]): List of paths to C++ header files. + output_file (Optional[Path], optional): Path to the output file for generated code. If None, prints to console. log_level (str, optional): The log level for the output (DEBUG, INFO, WARNING, ERROR). - whitelist (list, optional): List of whitelisted functions or methods. - blacklist (list, optional): List of blacklisted functions or methods. + whitelist (List[str], optional): List of whitelisted functions or methods. + blacklist (List[str], optional): List of blacklisted functions or methods. module_name (str, optional): Name of the generated ATOM_MODULE. instance_prefix (str, optional): Prefix for instance names in the module. """ @@ -249,8 +277,7 @@ def generate_atom_module(filepaths, output_file=None, log_level="INFO", whitelis print(generated_code) -# Run the script -if __name__ == "__main__": +def main(): """ Entry point for the script execution. 
@@ -268,3 +295,7 @@ def generate_atom_module(filepaths, output_file=None, log_level="INFO", whitelis args.module_name or config.get('module_name', 'all_components'), args.instance_prefix or config.get('instance_prefix', '') ) + + +if __name__ == "__main__": + main() diff --git a/modules/lithium.pytools/tools/libclang_finder.py b/modules/lithium.pytools/tools/libclang_finder.py new file mode 100644 index 00000000..3f0b2c99 --- /dev/null +++ b/modules/lithium.pytools/tools/libclang_finder.py @@ -0,0 +1,197 @@ +import platform +import glob +import argparse +from pathlib import Path +from typing import List, Optional +from dataclasses import dataclass, field + +from clang.cindex import Config +from loguru import logger + + +@dataclass +class LibClangFinderConfig: + custom_path: Optional[Path] = None + clear_cache: bool = False + search_patterns: List[str] = field(default_factory=list) + cache_file: Path = Path("libclang_path_cache.txt") + log_file: Path = Path("libclang_finder.log") + + +class LibClangFinder: + def __init__(self, config: LibClangFinderConfig): + self.config = config + self.libclang_path: Optional[Path] = None + + # Configure logging with loguru + logger.remove() + logger.add( + self.config.log_file, + rotation="1 MB", + retention="10 days", + level="DEBUG", + format="{time} | {level} | {message}" + ) + logger.debug(f"LibClangFinder initialized with config: {self.config}") + + def clear_cache(self): + if self.config.cache_file.exists(): + self.config.cache_file.unlink() + logger.info(f"Cleared cache file: {self.config.cache_file}") + + def cache_libclang_path(self, path: Path): + self.config.cache_file.write_text(str(path)) + logger.info(f"Cached libclang path: {path}") + + def load_cached_libclang_path(self) -> Optional[Path]: + if self.config.cache_file.exists(): + path = Path(self.config.cache_file.read_text().strip()) + if path.is_file(): + logger.info(f"Loaded cached libclang path: {path}") + return path + logger.debug("No valid cached libclang 
path found.") + return None + + def find_libclang_linux(self) -> Optional[Path]: + possible_patterns = [ + '/usr/lib/llvm-*/lib/libclang.so*', + '/usr/local/lib/llvm-*/lib/libclang.so*', + '/usr/lib/x86_64-linux-gnu/libclang.so*', + '/usr/local/lib/x86_64-linux-gnu/libclang.so*', + ] + logger.info("Searching for libclang on Linux...") + paths = self.search_paths(possible_patterns) + return self.select_libclang_path(paths) + + def find_libclang_macos(self) -> Optional[Path]: + possible_patterns = [ + '/usr/local/opt/llvm/lib/libclang.dylib', + '/usr/local/lib/libclang.dylib', + '/Library/Developer/CommandLineTools/usr/lib/libclang.dylib', + '/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/libclang.dylib', + ] + logger.info("Searching for libclang on macOS...") + paths = self.search_paths(possible_patterns) + return self.select_libclang_path(paths) + + def find_libclang_windows(self) -> Optional[Path]: + possible_patterns = [ + 'C:\\Program Files\\LLVM\\bin\\libclang.dll', + 'C:\\Program Files (x86)\\LLVM\\bin\\libclang.dll', + 'C:\\LLVM\\bin\\libclang.dll', + ] + logger.info("Searching for libclang on Windows...") + paths = self.search_paths(possible_patterns) + return self.select_libclang_path(paths) + + def search_paths(self, patterns: List[str]) -> List[Path]: + all_patterns = patterns + self.config.search_patterns + found_paths = [] + for pattern in all_patterns: + matches = glob.glob(pattern) + found_paths.extend(matches) + logger.debug( + f"Searching with pattern '{pattern}', found: {matches}") + unique_paths = sorted(set(Path(p) + for p in found_paths if Path(p).is_file())) + logger.debug(f"Total unique libclang paths found: {unique_paths}") + return unique_paths + + def select_libclang_path(self, paths: List[Path]) -> Optional[Path]: + if paths: + selected_path = paths[-1] + logger.info(f"Selected libclang path: {selected_path}") + return selected_path + logger.error("No libclang library found.") + return None + + def 
get_libclang_path(self) -> Path: + if self.config.clear_cache: + self.clear_cache() + + if self.config.custom_path and self.config.custom_path.is_file(): + logger.info( + f"Using custom libclang path: {self.config.custom_path}") + self.cache_libclang_path(self.config.custom_path) + return self.config.custom_path + + cached_path = self.load_cached_libclang_path() + if cached_path: + return cached_path + + system = platform.system() + logger.info(f"Detected operating system: {system}") + find_method = { + 'Linux': self.find_libclang_linux, + 'Darwin': self.find_libclang_macos, + 'Windows': self.find_libclang_windows, + }.get(system) + + if find_method: + libclang_path = find_method() + if libclang_path: + self.cache_libclang_path(libclang_path) + return libclang_path + else: + logger.error(f"Unsupported operating system: {system}") + raise RuntimeError(f"Unsupported operating system: {system}") + + raise RuntimeError("libclang library not found.") + + def configure_clang(self): + libclang_path = self.get_libclang_path() + logger.info(f"Setting libclang path to: {libclang_path}") + Config.set_library_file(str(libclang_path)) + + def list_libclang_versions(self) -> List[Path]: + system = platform.system() + logger.info(f"Listing all libclang versions on {system}") + find_method = { + 'Linux': self.find_libclang_linux, + 'Darwin': self.find_libclang_macos, + 'Windows': self.find_libclang_windows, + }.get(system) + + if find_method: + paths = find_method() + logger.info(f"Available libclang libraries: {paths}") + return paths + else: + logger.error(f"Unsupported operating system: {system}") + return [] + + +def parse_arguments() -> LibClangFinderConfig: + parser = argparse.ArgumentParser(description="libclang Path Finder") + parser.add_argument('--path', type=Path, + help="Custom path to libclang library") + parser.add_argument('--clear-cache', action='store_true', + help="Clear cached libclang path") + parser.add_argument('--search-patterns', nargs='*', default=[], + 
help="Additional glob patterns to search for libclang") + parser.add_argument('--cache-file', type=Path, default=Path( + "libclang_path_cache.txt"), help="Path to the cache file") + parser.add_argument('--log-file', type=Path, + default=Path("libclang_finder.log"), help="Path to the log file") + args = parser.parse_args() + return LibClangFinderConfig( + custom_path=args.path, + clear_cache=args.clear_cache, + search_patterns=args.search_patterns, + cache_file=args.cache_file, + log_file=args.log_file, + ) + + +def main(): + config = parse_arguments() + finder = LibClangFinder(config) + try: + finder.configure_clang() + logger.info("libclang configured successfully.") + except Exception as e: + logger.error(f"Failed to configure libclang: {e}") + + +if __name__ == "__main__": + main() diff --git a/pysrc/addon/libclang_finder.py b/pysrc/addon/libclang_finder.py deleted file mode 100644 index 99c983cd..00000000 --- a/pysrc/addon/libclang_finder.py +++ /dev/null @@ -1,159 +0,0 @@ -""" -libclang Path Finder - -This script automatically locates the path to the libclang library depending on the operating system. -It supports Linux, macOS, and Windows platforms. The script uses the `loguru` library for detailed -logging and provides comprehensive error handling. - -Dependencies: -- clang.cindex -- loguru - -Usage: -- The script automatically determines the path to `libclang` and sets it for the `clang.cindex.Config`. 
-""" - -import platform -import os -import glob -import subprocess -import argparse -from clang.cindex import Config -from loguru import logger - -CACHE_FILE = "libclang_path_cache.txt" - - -def cache_libclang_path(path: str): - """Caches the found libclang path to a file.""" - with open(CACHE_FILE, 'w') as f: - f.write(path) - logger.info(f"Cached libclang path: {path}") - - -def load_cached_libclang_path() -> str | None: - """Loads the cached libclang path from a file.""" - if os.path.exists(CACHE_FILE): - with open(CACHE_FILE, 'r') as f: - path = f.read().strip() - if os.path.isfile(path): - logger.info(f"Loaded cached libclang path: {path}") - return path - return None - - -def find_libclang_linux(): - """Searches for the libclang library on Linux systems.""" - possible_paths = [ - '/usr/lib/llvm-*/*/lib/libclang.so', - '/usr/local/lib/llvm-*/*/lib/libclang.so', - '/usr/lib/x86_64-linux-gnu/libclang.so', - '/usr/local/lib/x86_64-linux-gnu/libclang.so', - '/usr/lib/llvm-*/lib/libclang.so', - '/usr/local/lib/llvm-*/lib/libclang.so' - ] - logger.info("Searching for libclang on Linux...") - for pattern in possible_paths: - for path in glob.glob(pattern): - if os.path.isfile(path): - logger.info(f"Found libclang at {path}") - return path - logger.error("libclang not found on Linux") - raise RuntimeError("libclang not found on Linux") - - -def find_libclang_macos(): - """Searches for the libclang library on macOS systems.""" - possible_paths = [ - '/usr/local/opt/llvm/lib/libclang.dylib', - '/usr/local/lib/libclang.dylib', - '/Library/Developer/CommandLineTools/usr/lib/libclang.dylib', - '/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/libclang.dylib' - ] - logger.info("Searching for libclang on macOS...") - for path in possible_paths: - if os.path.isfile(path): - logger.info(f"Found libclang at {path}") - return path - # Fallback: try using `find` command - try: - result = subprocess.run(['find', '/', '-name', 
'libclang.dylib'], - stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) - paths = result.stdout.strip().split('\n') - if paths: - logger.info(f"Found libclang at {paths[0]}") - return paths[0] - except Exception as e: - logger.error(f"Error while searching for libclang: {e}") - logger.error("libclang not found on macOS") - raise RuntimeError("libclang not found on macOS") - - -def find_libclang_windows(): - """Searches for the libclang library on Windows systems.""" - possible_paths = [ - 'C:\\Program Files\\LLVM\\bin\\libclang.dll', - 'C:\\Program Files (x86)\\LLVM\\bin\\libclang.dll', - 'C:\\LLVM\\bin\\libclang.dll' - ] - logger.info("Searching for libclang on Windows...") - for path in possible_paths: - if os.path.isfile(path): - logger.info(f"Found libclang at {path}") - return path - # Fallback: try using `where` command - try: - result = subprocess.run( - ['where', 'libclang.dll'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) - paths = result.stdout.strip().split('\n') - if paths: - logger.info(f"Found libclang at {paths[0]}") - return paths[0] - except Exception as e: - logger.error(f"Error while searching for libclang: {e}") - logger.error("libclang not found on Windows") - raise RuntimeError("libclang not found on Windows") - - -def get_libclang_path(custom_path: str | None = None): - """Determines the appropriate libclang library path based on the operating system.""" - if custom_path and os.path.isfile(custom_path): - logger.info(f"Using custom libclang path: {custom_path}") - return custom_path - - cached_path = load_cached_libclang_path() - if cached_path: - return cached_path - - system = platform.system() - logger.info(f"Detected operating system: {system}") - if system == 'Linux': - path = find_libclang_linux() - elif system == 'Darwin': # macOS - path = find_libclang_macos() - elif system == 'Windows': - path = find_libclang_windows() - else: - logger.error("Unsupported operating system") - raise RuntimeError("Unsupported 
operating system") - - cache_libclang_path(path) - return path - - -def main(): - """Main execution block for setting up libclang path and configuring clang.cindex.""" - parser = argparse.ArgumentParser(description="libclang Path Finder") - parser.add_argument('--path', type=str, help="Custom path to libclang") - args = parser.parse_args() - - try: - libclang_path = get_libclang_path(args.path) - Config.set_library_file(libclang_path) - logger.info(f"Successfully set libclang path to: {libclang_path}") - except Exception as e: - logger.exception("Failed to set libclang path") - - -if __name__ == "__main__": - main() diff --git a/pysrc/image/api/strecth_count.py b/pysrc/image/api/strecth_count.py deleted file mode 100644 index 44df504f..00000000 --- a/pysrc/image/api/strecth_count.py +++ /dev/null @@ -1,251 +0,0 @@ -from pathlib import Path -from typing import Tuple, Dict, Optional, List -import json -import numpy as np -import cv2 -from astropy.io import fits -from scipy import ndimage -from concurrent.futures import ThreadPoolExecutor -import yaml -import logging - -# Configure logging -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -def debayer_image(img: np.ndarray, bayer_pattern: Optional[str] = None) -> np.ndarray: - bayer_patterns = { - "rggb": cv2.COLOR_BAYER_RGGB2BGR, - "gbrg": cv2.COLOR_BAYER_GBRG2BGR, - "bggr": cv2.COLOR_BAYER_BGGR2BGR, - "grbg": cv2.COLOR_BAYER_GRBG2BGR - } - return cv2.cvtColor(img, bayer_patterns.get(bayer_pattern.lower(), cv2.COLOR_BAYER_RGGB2BGR)) - - -def resize_image(img: np.ndarray, target_size: int) -> np.ndarray: - scale = min(target_size / max(img.shape[:2]), 1) - if scale < 1: - return cv2.resize(img, None, fx=scale, fy=scale, interpolation=cv2.INTER_AREA) - return img - - -def normalize_image(img: np.ndarray) -> np.ndarray: - if not np.allclose(img, img.astype(np.uint8)): - img = cv2.normalize(img, None, alpha=0, beta=255, - norm_type=cv2.NORM_MINMAX) - return img.astype(np.uint8) - - -def 
stretch_image(img: np.ndarray, is_color: bool) -> np.ndarray: - if is_color: - return compute_and_stretch_three_channels(img, True) - return compute_stretch_one_channel(img, True) - - -def detect_stars(img: np.ndarray, remove_hotpixel: bool, remove_noise: bool, do_star_mark: bool, mark_img: Optional[np.ndarray] = None) -> Tuple[np.ndarray, float, float, Dict[str, float]]: - return star_detect_and_hfr(img, remove_hotpixel, remove_noise, do_star_mark, down_sample_mean_std=True, mark_img=mark_img) - - -# New functions for enhanced image processing -def apply_gaussian_blur(img: np.ndarray, kernel_size: int = 5) -> np.ndarray: - """Apply Gaussian blur to reduce noise.""" - return cv2.GaussianBlur(img, (kernel_size, kernel_size), 0) - - -def apply_unsharp_mask(img: np.ndarray, kernel_size: int = 5, amount: float = 1.5) -> np.ndarray: - """Apply unsharp mask to enhance image details.""" - blurred = cv2.GaussianBlur(img, (kernel_size, kernel_size), 0) - return cv2.addWeighted(img, amount + 1, blurred, -amount, 0) - - -def equalize_histogram(img: np.ndarray) -> np.ndarray: - """Apply histogram equalization to improve contrast.""" - if len(img.shape) == 2: - return cv2.equalizeHist(img) - else: - ycrcb = cv2.cvtColor(img, cv2.COLOR_BGR2YCrCb) - ycrcb[:, :, 0] = cv2.equalizeHist(ycrcb[:, :, 0]) - return cv2.cvtColor(ycrcb, cv2.COLOR_YCrCb2BGR) - - -def remove_hot_pixels(img: np.ndarray, threshold: float = 3.0) -> np.ndarray: - """Remove hot pixels using median filter and thresholding.""" - median = ndimage.median_filter(img, size=3) - diff = np.abs(img - median) - mask = diff > (threshold * np.std(diff)) - img[mask] = median[mask] - return img - - -def adjust_gamma(img: np.ndarray, gamma: float = 1.0) -> np.ndarray: - """Adjust image gamma.""" - inv_gamma = 1.0 / gamma - table = np.array([((i / 255.0) ** inv_gamma) * - 255 for i in np.arange(0, 256)]).astype("uint8") - return cv2.LUT(img, table) - - -def apply_clahe(img: np.ndarray, clip_limit: float = 2.0, tile_grid_size: 
Tuple[int, int] = (8, 8)) -> np.ndarray: - """Apply Contrast Limited Adaptive Histogram Equalization (CLAHE).""" - if len(img.shape) == 2: - clahe = cv2.createCLAHE(clipLimit=clip_limit, - tileGridSize=tile_grid_size) - return clahe.apply(img) - else: - lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB) - l, a, b = cv2.split(lab) - clahe = cv2.createCLAHE(clipLimit=clip_limit, - tileGridSize=tile_grid_size) - cl = clahe.apply(l) - limg = cv2.merge((cl, a, b)) - return cv2.cvtColor(limg, cv2.COLOR_LAB2BGR) - - -def denoise_image(img: np.ndarray, h: float = 10) -> np.ndarray: - """Apply Non-Local Means Denoising.""" - return cv2.fastNlMeansDenoisingColored(img, None, h, h, 7, 21) - - -def enhance_image(img: np.ndarray, config: Dict[str, bool]) -> np.ndarray: - """Apply a series of image enhancements based on the configuration.""" - if config.get('remove_hot_pixels', False): - img = remove_hot_pixels(img) - if config.get('denoise', False): - img = denoise_image(img) - if config.get('equalize_histogram', False): - img = equalize_histogram(img) - if config.get('apply_clahe', False): - img = apply_clahe(img) - if config.get('unsharp_mask', False): - img = apply_unsharp_mask(img) - if config.get('adjust_gamma', False): - img = adjust_gamma(img, gamma=config.get('gamma', 1.0)) - if config.get('apply_gaussian_blur', False): - img = apply_gaussian_blur(img) - return img - - -def process_image(filepath: Path, config: Dict[str, bool], resize_size: int = 2048) -> Tuple[Optional[np.ndarray], Dict[str, float]]: - try: - img, header = fits.getdata(filepath, header=True) - except Exception as e: - logger.error(f"Error loading FITS file {filepath}: {e}") - return None, {"star_count": -1, "average_hfr": -1, "max_star": -1, "min_star": -1, "average_star": -1} - - is_color = 'BAYERPAT' in header - if is_color: - img = debayer_image(img, header['BAYERPAT']) - - img = resize_image(img, resize_size) - img = normalize_image(img) - - # Apply image enhancements - img = enhance_image(img, config) - - 
if config.get('do_stretch', False): - img = stretch_image(img, is_color) - - if config.get('do_star_count', False): - img, star_count, avg_hfr, area_range = detect_stars( - img, config.get('remove_hotpixel', False), config.get('remove_noise', False), config.get('do_star_mark', False)) - return img, { - "star_count": float(star_count), - "average_hfr": avg_hfr, - "max_star": area_range['max'], - "min_star": area_range['min'], - "average_star": area_range['average'] - } - - return img, {"star_count": -1, "average_hfr": -1, "max_star": -1, "min_star": -1, "average_star": -1} - - -def image_stretch_and_star_count_optim(filepath: Path, config: Dict[str, bool], resize_size: int = 2048, - jpg_file: Optional[Path] = None, star_file: Optional[Path] = None) -> Tuple[Optional[np.ndarray], Dict[str, float]]: - img, result = process_image(filepath, config, resize_size) - - if jpg_file and img is not None: - cv2.imwrite(str(jpg_file), img) - if star_file: - with star_file.open('w') as f: - json.dump(result, f) - - return img, result - - -def streaming_debayer_and_stretch(fits_data: bytearray, width: int, height: int, config: Dict[str, bool], - resize_size: int = 2048, bayer_type: Optional[str] = None) -> Optional[np.ndarray]: - img = np.frombuffer(fits_data, dtype=np.uint16).reshape(height, width) - - if bayer_type: - img = debayer_image(img, bayer_type) - - img = resize_image(img, resize_size) - img = normalize_image(img) - - # Apply image enhancements - img = enhance_image(img, config) - - if config.get('do_stretch', True): - img = stretch_image(img, len(img.shape) == 3) - - return img - - -def streaming_debayer(fits_data: bytearray, width: int, height: int, config: Dict[str, bool], - resize_size: int = 2048, bayer_type: Optional[str] = None) -> Optional[np.ndarray]: - img = np.frombuffer(fits_data, dtype=np.uint16).reshape(height, width) - - if bayer_type: - img = debayer_image(img, bayer_type) - - img = resize_image(img, resize_size) - img = normalize_image(img) - - # Apply 
image enhancements - img = enhance_image(img, config) - - return img - - -def load_config(config_file: Path) -> Dict[str, bool]: - """Load configuration from a YAML file.""" - if config_file.is_file(): - with config_file.open('r') as f: - config = yaml.safe_load(f) - logger.info(f"Loaded configuration from {config_file}") - return config - else: - logger.warning( - f"Configuration file {config_file} not found. Using default configuration.") - return {} - - -def main(): - import argparse - - parser = argparse.ArgumentParser( - description="Image Stretch and Star Count Optimization") - parser.add_argument("filepath", type=Path, help="Path to the FITS file") - parser.add_argument("--config", type=Path, default="config.yaml", - help="Path to the configuration file") - parser.add_argument("--resize-size", type=int, default=2048, - help="Target size for resizing the image") - parser.add_argument("--jpg-file", type=Path, - help="Path to save the processed image as JPG") - parser.add_argument("--star-file", type=Path, - help="Path to save the star count results as JSON") - args = parser.parse_args() - - config = load_config(args.config) - img, result = image_stretch_and_star_count_optim( - args.filepath, config, args.resize_size, args.jpg_file, args.star_file) - - if img is not None: - logger.info(f"Processed image saved to {args.jpg_file}") - logger.info(f"Star count results: {result}") - - -if __name__ == "__main__": - main() diff --git a/pysrc/image/auto_histogram/histogram.py b/pysrc/image/auto_histogram/histogram.py deleted file mode 100644 index a402b7ff..00000000 --- a/pysrc/image/auto_histogram/histogram.py +++ /dev/null @@ -1,111 +0,0 @@ -import cv2 -import numpy as np -from typing import Optional, List, Tuple -from .utils import save_image, load_image - - -def auto_histogram(image: Optional[np.ndarray] = None, clip_shadow: float = 0.01, clip_highlight: float = 0.01, - target_median: int = 128, method: str = 'gamma', apply_clahe: bool = False, - clahe_clip_limit: 
float = 2.0, clahe_tile_grid_size: Tuple[int, int] = (8, 8), - apply_noise_reduction: bool = False, noise_reduction_method: str = 'median', - apply_sharpening: bool = False, sharpening_strength: float = 1.0, - batch_process: bool = False, file_list: Optional[List[str]] = None) -> Optional[List[np.ndarray]]: - """ - Apply automated histogram transformations and enhancements to an image or a batch of images. - - Parameters: - - image: Input image, grayscale or RGB. - - clip_shadow: Percentage of shadow pixels to clip. - - clip_highlight: Percentage of highlight pixels to clip. - - target_median: Target median value for histogram stretching. - - method: Stretching method ('gamma', 'logarithmic', 'mtf'). - - apply_clahe: Apply CLAHE (Contrast Limited Adaptive Histogram Equalization). - - clahe_clip_limit: CLAHE clip limit. - - clahe_tile_grid_size: CLAHE grid size. - - apply_noise_reduction: Apply noise reduction. - - noise_reduction_method: Noise reduction method ('median', 'gaussian'). - - apply_sharpening: Apply image sharpening. - - sharpening_strength: Strength of sharpening. - - batch_process: Enable batch processing mode. - - file_list: List of file paths for batch processing. - - Returns: - - Processed image or list of processed images. 
- """ - def histogram_clipping(image: np.ndarray, clip_shadow: float, clip_highlight: float) -> np.ndarray: - flat = image.flatten() - low_val = np.percentile(flat, clip_shadow * 100) - high_val = np.percentile(flat, 100 - clip_highlight * 100) - return np.clip(image, low_val, high_val) - - def gamma_transformation(image: np.ndarray, target_median: int) -> np.ndarray: - mean_val = np.median(image) - gamma = np.log(target_median / 255.0) / np.log(mean_val / 255.0) - return np.array(255 * (image / 255.0) ** gamma, dtype='uint8') - - def logarithmic_transformation(image: np.ndarray) -> np.ndarray: - c = 255 / np.log(1 + np.max(image)) - return np.array(c * np.log(1 + image), dtype='uint8') - - def mtf_transformation(image: np.ndarray, target_median: int) -> np.ndarray: - mean_val = np.median(image) - mtf = target_median / mean_val - return np.array(image * mtf, dtype='uint8') - - def apply_clahe_method(image: np.ndarray, clip_limit: float, tile_grid_size: Tuple[int, int]) -> np.ndarray: - clahe = cv2.createCLAHE(clipLimit=clip_limit, - tileGridSize=tile_grid_size) - if len(image.shape) == 2: - return clahe.apply(image) - else: - return cv2.merge([clahe.apply(channel) for channel in cv2.split(image)]) - - def noise_reduction(image: np.ndarray, method: str) -> np.ndarray: - if method == 'median': - return cv2.medianBlur(image, 3) - elif method == 'gaussian': - return cv2.GaussianBlur(image, (3, 3), 0) - else: - raise ValueError("Invalid noise reduction method specified.") - - def sharpen_image(image: np.ndarray, strength: float) -> np.ndarray: - kernel = np.array([[-1, -1, -1], [-1, 9 + strength, -1], [-1, -1, -1]]) - return cv2.filter2D(image, -1, kernel) - - def process_single_image(image: np.ndarray) -> np.ndarray: - if apply_noise_reduction: - image = noise_reduction(image, noise_reduction_method) - - image = histogram_clipping(image, clip_shadow, clip_highlight) - - if method == 'gamma': - image = gamma_transformation(image, target_median) - elif method == 
'logarithmic': - image = logarithmic_transformation(image) - elif method == 'mtf': - image = mtf_transformation(image, target_median) - else: - raise ValueError("Invalid method specified.") - - if apply_clahe: - image = apply_clahe_method( - image, clahe_clip_limit, clahe_tile_grid_size) - - if apply_sharpening: - image = sharpen_image(image, sharpening_strength) - - return image - - if batch_process: - if file_list is None: - raise ValueError( - "File list cannot be None when batch processing is enabled.") - processed_images = [] - for file_path in file_list: - image = load_image(file_path, method != 'mtf') - processed_image = process_single_image(image) - processed_images.append(processed_image) - save_image(f'processed_{file_path}', processed_image) - return processed_images - else: - return process_single_image(image) diff --git a/pysrc/image/auto_histogram/processing.py b/pysrc/image/auto_histogram/processing.py deleted file mode 100644 index 3e64425d..00000000 --- a/pysrc/image/auto_histogram/processing.py +++ /dev/null @@ -1,28 +0,0 @@ -from .histogram import auto_histogram -from .utils import save_image, load_image -import os -from typing import List - - -def process_directory(directory: str, output_directory: str, method: str = 'gamma', **kwargs): - """ - Process all images in a directory using the auto_histogram function. - - :param directory: Input directory containing images to process. - :param output_directory: Directory to save processed images. - :param method: Histogram stretching method ('gamma', 'logarithmic', 'mtf'). - :param kwargs: Additional parameters for auto_histogram. 
- """ - if not os.path.exists(output_directory): - os.makedirs(output_directory) - - file_list = [os.path.join(directory, file) for file in os.listdir( - directory) if file.endswith(('.jpg', '.png'))] - - processed_images = auto_histogram( - None, method=method, batch_process=True, file_list=file_list, **kwargs) - - for file, image in zip(file_list, processed_images): - output_path = os.path.join( - output_directory, f'processed_{os.path.basename(file)}') - save_image(output_path, image) diff --git a/pysrc/image/auto_histogram/utils.py b/pysrc/image/auto_histogram/utils.py deleted file mode 100644 index f62a21c4..00000000 --- a/pysrc/image/auto_histogram/utils.py +++ /dev/null @@ -1,23 +0,0 @@ -import cv2 -import numpy as np -from typing import Tuple, Union - -def save_image(filepath: str, image: np.ndarray): - """ - Save an image to the specified filepath. - - :param filepath: Path to save the image. - :param image: Image data. - """ - cv2.imwrite(filepath, image) - -def load_image(filepath: str, grayscale: bool = False) -> np.ndarray: - """ - Load an image from the specified filepath. - - :param filepath: Path to load the image from. - :param grayscale: Load image as grayscale if True. - :return: Loaded image. 
-    """
-    flags = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR
-    return cv2.imread(filepath, flags)
diff --git a/pysrc/image/channel/combination.py b/pysrc/image/channel/combination.py
deleted file mode 100644
index a072018c..00000000
--- a/pysrc/image/channel/combination.py
+++ /dev/null
@@ -1,67 +0,0 @@
-from PIL import Image
-import numpy as np
-from skimage import color
-import cv2
-
-def resize_to_match(image, target_size):
-    # Image.ANTIALIAS was removed in Pillow 10; use the Resampling enum instead
-    return image.resize(target_size, Image.Resampling.LANCZOS)
-
-def load_image_as_gray(path):
-    return Image.open(path).convert('L')
-
-def combine_channels(channels, color_space='RGB'):
-    match color_space:
-        case 'RGB':
-            return Image.merge("RGB", channels)
-
-        case 'LAB':
-            lab_image = Image.merge("LAB", channels)
-            return lab_image.convert('RGB')
-
-        case 'HSV':
-            hsv_image = Image.merge("HSV", channels)
-            return hsv_image.convert('RGB')
-
-        case 'HSI':
-            hsi_image = np.dstack([np.array(ch) / 255.0 for ch in channels])
-            rgb_image = color.hsv2rgb(hsi_image)  # Scikit-image has no direct HSI support; using HSV as a proxy
-            return Image.fromarray((rgb_image * 255).astype(np.uint8))
-
-        case _:
-            raise ValueError(f"Unsupported color space: {color_space}")
-
-def channel_combination(src1_path, src2_path, src3_path, color_space='RGB'):
-    # Load and resize images to match
-    channel_1 = load_image_as_gray(src1_path)
-    channel_2 = load_image_as_gray(src2_path)
-    channel_3 = load_image_as_gray(src3_path)
-
-    # Automatically resize images to match the size of the first image
-    size = channel_1.size
-    channel_2 = resize_to_match(channel_2, size)
-    channel_3 = resize_to_match(channel_3, size)
-
-    # Combine the channels
-    combined_image = combine_channels([channel_1, channel_2, channel_3], color_space=color_space)
-
-    return combined_image
-
-# Example usage
-if __name__ == "__main__":
-    # Paths of the images holding each channel
-    src1_path = 'channel_R.png'  # R channel in RGB, or the L channel in Lab
-    src2_path = 'channel_G.png'  # G channel in RGB, or the a* channel in Lab
-    src3_path = 'channel_B.png'  # B channel in RGB, or the b* channel in Lab
-
-    # Combine the channels and save the results
-    combined_rgb = channel_combination(src1_path, src2_path, src3_path, color_space='RGB')
-    combined_rgb.save('combined_rgb.png')
-
-    combined_lab = channel_combination(src1_path, src2_path, src3_path, color_space='LAB')
-    combined_lab.save('combined_lab.png')
-
-    combined_hsv = channel_combination(src1_path, src2_path, src3_path, color_space='HSV')
-    combined_hsv.save('combined_hsv.png')
-
-    combined_hsi = channel_combination(src1_path, src2_path, src3_path, color_space='HSI')
-    combined_hsi.save('combined_hsi.png')
diff --git a/pysrc/image/channel/extraction.py b/pysrc/image/channel/extraction.py
deleted file mode 100644
index fd85d782..00000000
--- a/pysrc/image/channel/extraction.py
+++ /dev/null
@@ -1,138 +0,0 @@
-import cv2
-import numpy as np
-import os
-from matplotlib import pyplot as plt
-
-
-def extract_channels(image, color_space='RGB'):
-    channels = {}
-
-    if color_space == 'RGB':
-        # OpenCV loads images in BGR order, so split as B, G, R
-        channels['B'], channels['G'], channels['R'] = cv2.split(image)
-
-    elif color_space == 'XYZ':
-        xyz_image = cv2.cvtColor(image, cv2.COLOR_BGR2XYZ)
-        channels['X'], channels['Y'], channels['Z'] = cv2.split(xyz_image)
-
-    elif color_space == 'Lab':
-        lab_image = cv2.cvtColor(image, cv2.COLOR_BGR2Lab)
-        channels['L*'], channels['a*'], channels['b*'] = cv2.split(lab_image)
-
-    elif color_space == 'LCh':
-        lab_image = cv2.cvtColor(image, cv2.COLOR_BGR2Lab)
-        L, a, b = cv2.split(lab_image)
-        # cartToPolar needs float input and returns (magnitude, angle) = (chroma, hue)
-        c, h = cv2.cartToPolar(a.astype(np.float32), b.astype(np.float32))
-        channels['L*'] = L
-        channels['c*'] = c
-        channels['h*'] = h
-
-    elif color_space == 'HSV':
-        hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
-        channels['H'], channels['S'], channels['V'] = cv2.split(hsv_image)
-
-    elif color_space == 'HSI':
-        hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
-        H, S, V = cv2.split(hsv_image)
-        I = V.copy()
-        channels['H'] = H
-        channels['Si'] = S
-        channels['I'] = I
-
-    elif color_space == 'YUV':
-        yuv_image = cv2.cvtColor(image, cv2.COLOR_BGR2YUV)
-        channels['Y'], channels['U'],
channels['V'] = cv2.split(yuv_image) - - elif color_space == 'YCbCr': - ycbcr_image = cv2.cvtColor(image, cv2.COLOR_BGR2YCrCb) - channels['Y'], channels['Cb'], channels['Cr'] = cv2.split(ycbcr_image) - - elif color_space == 'HSL': - hsl_image = cv2.cvtColor(image, cv2.COLOR_BGR2HLS) - channels['H'], channels['S'], channels['L'] = cv2.split(hsl_image) - - elif color_space == 'CMYK': - # Trick: Convert to CMY via XYZ - cmyk_image = cv2.cvtColor(image, cv2.COLOR_BGR2XYZ) - C, M, Y = cv2.split(255 - cmyk_image) - K = np.minimum(C, np.minimum(M, Y)) - channels['C'] = C - channels['M'] = M - channels['Y'] = Y - channels['K'] = K - - return channels - - -def show_histogram(channel_data, title='Channel Histogram'): - plt.figure() - plt.title(title) - plt.xlabel('Pixel Value') - plt.ylabel('Frequency') - plt.hist(channel_data.ravel(), bins=256, range=[0, 256]) - plt.show() - - -def merge_channels(channels): - merged_image = None - channel_list = list(channels.values()) - if len(channel_list) >= 3: - merged_image = cv2.merge(channel_list[:3]) - elif len(channel_list) == 2: - merged_image = cv2.merge( - [channel_list[0], channel_list[1], np.zeros_like(channel_list[0])]) - elif len(channel_list) == 1: - merged_image = channel_list[0] - return merged_image - - -def process_directory(input_dir, output_dir, color_space='RGB'): - if not os.path.exists(output_dir): - os.makedirs(output_dir) - - for filename in os.listdir(input_dir): - if filename.endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tiff')): - image_path = os.path.join(input_dir, filename) - image = cv2.imread(image_path) - base_name = os.path.splitext(filename)[0] - extracted_channels = extract_channels(image, color_space) - for channel_name, channel_data in extracted_channels.items(): - save_path = os.path.join( - output_dir, f"{base_name}_{channel_name}.png") - cv2.imwrite(save_path, channel_data) - show_histogram( - channel_data, title=f"{base_name} - {channel_name}") - print(f"Saved {save_path}") - - -def 
save_channels(channels, base_name='output'):
-    for channel_name, channel_data in channels.items():
-        filename = f"{base_name}_{channel_name}.png"
-        cv2.imwrite(filename, channel_data)
-        print(f"Saved {filename}")
-
-
-def display_image(title, image):
-    cv2.imshow(title, image)
-    cv2.waitKey(0)
-    cv2.destroyAllWindows()
-
-
-# Example usage:
-image = cv2.imread('input_image.png')
-extracted_channels = extract_channels(image, color_space='Lab')
-
-# Show histograms
-for name, channel in extracted_channels.items():
-    show_histogram(channel, title=f"{name} Histogram")
-
-# Save channels
-save_channels(extracted_channels, base_name='output_image')
-
-# Merge channels
-merged_image = merge_channels(extracted_channels)
-if merged_image is not None:
-    display_image('Merged Image', merged_image)
-    cv2.imwrite('merged_image.png', merged_image)
-
-# Process directory
-process_directory('input_images', 'output_images', color_space='HSV')
diff --git a/pysrc/image/color_calibration/calibration.py b/pysrc/image/color_calibration/calibration.py
deleted file mode 100644
index 6ca4cb2c..00000000
--- a/pysrc/image/color_calibration/calibration.py
+++ /dev/null
@@ -1,65 +0,0 @@
-import cv2
-import numpy as np
-from skimage import exposure
-from typing import Tuple
-from dataclasses import dataclass
-
-@dataclass
-class ColorCalibration:
-    """Class to handle color calibration tasks for astronomical images."""
-    image: np.ndarray
-
-    def calculate_color_factors(self, white_reference: np.ndarray) -> np.ndarray:
-        """
-        Calculate the color calibration factors based on a white reference image.
-
-        Parameters:
-        white_reference (np.ndarray): The white reference region.
-
-        Returns:
-        np.ndarray: The calibration factors for the RGB channels.
-        """
-        mean_values = np.mean(white_reference, axis=(0, 1))
-        factors = 1.0 / mean_values
-        return factors
-
-    def apply_color_calibration(self, factors: np.ndarray) -> np.ndarray:
-        """
-        Apply color calibration to the image using provided factors.
-
-        Parameters:
-        factors (np.ndarray): The RGB calibration factors.
-
-        Returns:
-        np.ndarray: Color calibrated image.
-        """
-        calibrated_image = np.zeros_like(self.image)
-        for i in range(3):
-            calibrated_image[:, :, i] = self.image[:, :, i] * factors[i]
-        return calibrated_image
-
-    def automatic_white_balance(self) -> np.ndarray:
-        """
-        Perform automatic white balance using the Gray World algorithm.
-
-        Returns:
-        np.ndarray: White balanced image.
-        """
-        mean_values = np.mean(self.image, axis=(0, 1))
-        gray_value = np.mean(mean_values)
-        factors = gray_value / mean_values
-        return self.apply_color_calibration(factors)
-
-    def match_histograms(self, reference_image: np.ndarray) -> np.ndarray:
-        """
-        Match the color histogram of the image to a reference image.
-
-        Parameters:
-        reference_image (np.ndarray): The reference image whose histogram is to be matched.
-
-        Returns:
-        np.ndarray: Histogram matched image.
-        """
-        matched_image = np.zeros_like(self.image)
-        for i in range(3):
-            # Channels are matched one at a time, so no multichannel/channel_axis argument is needed
-            matched_image[:, :, i] = exposure.match_histograms(self.image[:, :, i], reference_image[:, :, i])
-        return matched_image
diff --git a/pysrc/image/debayer/__init__.py b/pysrc/image/debayer/__init__.py
deleted file mode 100644
index e21f4ef2..00000000
--- a/pysrc/image/debayer/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-from .debayer import Debayer
-from .metrics import calculate_image_quality, evaluate_image_quality
diff --git a/pysrc/image/debayer/debayer.py b/pysrc/image/debayer/debayer.py
deleted file mode 100644
index bde2fcce..00000000
--- a/pysrc/image/debayer/debayer.py
+++ /dev/null
@@ -1,247 +0,0 @@
-import numpy as np
-import cv2
-from concurrent.futures import ThreadPoolExecutor
-from typing import Optional, Tuple
-import time
-from .utils import calculate_gradients, interpolate_green_channel, interpolate_red_blue_channel
-
-
-class Debayer:
-    def __init__(self, method: str = 'bilinear', pattern: Optional[str] = None):
-        """
-        Initialize Debayer object with method and optional Bayer pattern.
-
-        :param method: Debayering method ('superpixel', 'bilinear', 'vng', 'ahd', 'laplacian')
-        :param pattern: Bayer pattern ('BGGR', 'RGGB', 'GBRG', 'GRBG'), None for auto-detection
-        """
-        self.method = method
-        self.pattern = pattern
-
-    def detect_bayer_pattern(self, image: np.ndarray) -> str:
-        """
-        Automatically detect Bayer pattern from the CFA image.
-        """
-        height, width = image.shape
-
-        # Initialize the per-pattern statistics
-        patterns = {'BGGR': 0, 'RGGB': 0, 'GBRG': 0, 'GRBG': 0}
-
-        # Detect edges to improve detection accuracy
-        edges = cv2.Canny(image, 50, 150)
-
-        for i in range(0, height - 1, 2):
-            for j in range(0, width - 1, 2):
-                # BGGR
-                patterns['BGGR'] += (image[i, j] + image[i+1, j+1]) + (edges[i, j] + edges[i+1, j+1])
-                # RGGB
-                patterns['RGGB'] += (image[i+1, j] + image[i, j+1]) + (edges[i+1, j] + edges[i, j+1])
-                # GBRG
-                patterns['GBRG'] += (image[i, j+1] + image[i+1, j]) + (edges[i, j+1] + edges[i+1, j])
-                # GRBG
-                patterns['GRBG'] += (image[i, j] + image[i+1, j+1]) + (edges[i, j] + edges[i+1, j+1])
-
-        # Analyze the distribution of the color channels to reinforce the detection
-        color_sums = {'BGGR': 0, 'RGGB': 0, 'GBRG': 0, 'GRBG': 0}
-
-        # Walk the whole image and accumulate per-pattern intensity sums
-        for i in range(0, height - 1, 2):
-            for j in range(0, width - 1, 2):
-                block = image[i:i+2, j:j+2]
-                color_sums['BGGR'] += block[0, 0] + block[1, 1]  # blue-green pair
-                color_sums['RGGB'] += block[1, 0] + block[0, 1]  # red-green pair
-                color_sums['GBRG'] += block[0, 1] + block[1, 0]  # green-blue pair
-                color_sums['GRBG'] += block[0, 0] + block[1, 1]  # green-red pair
-
-        # Combine all statistics and pick the most likely Bayer pattern
-        for pattern in patterns.keys():
-            patterns[pattern] += color_sums[pattern]
-
-        detected_pattern = max(patterns, key=patterns.get)
-
-        return detected_pattern
-
-    def debayer_image(self, cfa_image: np.ndarray) -> np.ndarray:
-        """
-        Perform the debayering process using the specified method.
-        """
-        if self.pattern is None:
-            self.pattern = self.detect_bayer_pattern(cfa_image)
-
-        cfa_image = self.extend_image_edges(cfa_image, pad_width=2)
-        print(f"Using Bayer pattern: {self.pattern}")
-
-        if self.method == 'superpixel':
-            return self.debayer_superpixel(cfa_image)
-        elif self.method == 'bilinear':
-            return self.debayer_bilinear(cfa_image)
-        elif self.method == 'vng':
-            return self.debayer_vng(cfa_image)
-        elif self.method == 'ahd':
-            return self.parallel_debayer_ahd(cfa_image)
-        elif self.method == 'laplacian':
-            return self.debayer_laplacian_harmonization(cfa_image)
-        else:
-            raise ValueError(f"Unknown debayer method: {self.method}")
-
-    def debayer_superpixel(self, cfa_image: np.ndarray) -> np.ndarray:
-        red = cfa_image[0::2, 0::2]
-        green = (cfa_image[0::2, 1::2] + cfa_image[1::2, 0::2]) / 2
-        blue = cfa_image[1::2, 1::2]
-
-        rgb_image = np.stack((red, green, blue), axis=-1)
-        return rgb_image
-
-    def debayer_bilinear(self, cfa_image, pattern='BGGR'):
-        """
-        Debayer using bilinear interpolation.
-
-        :param cfa_image: Input CFA image
-        :param pattern: Bayer pattern ('BGGR', 'RGGB', 'GBRG', 'GRBG')
-        :return: Debayered RGB image
-        """
-        if pattern == 'BGGR':
-            return cv2.cvtColor(cfa_image, cv2.COLOR_BayerBG2BGR)
-        elif pattern == 'RGGB':
-            return cv2.cvtColor(cfa_image, cv2.COLOR_BayerRG2BGR)
-        elif pattern == 'GBRG':
-            return cv2.cvtColor(cfa_image, cv2.COLOR_BayerGB2BGR)
-        elif pattern == 'GRBG':
-            return cv2.cvtColor(cfa_image, cv2.COLOR_BayerGR2BGR)
-        else:
-            raise ValueError(f"Unsupported Bayer pattern: {pattern}")
-
-    def debayer_vng(self, cfa_image: np.ndarray) -> np.ndarray:
-        code = cv2.COLOR_BayerBG2BGR_VNG if self.pattern == 'BGGR' else cv2.COLOR_BayerRG2BGR_VNG
-        rgb_image = cv2.cvtColor(cfa_image, code)
-        return rgb_image
-
-    def parallel_debayer_ahd(self, cfa_image: np.ndarray, num_threads: int = 4) -> np.ndarray:
-        height, width = cfa_image.shape
-        chunk_size = height // num_threads
-
-        # Holds the partial image produced by each worker thread
-        results = [None] * num_threads
-
-        def process_chunk(start_row, end_row, index):
-            chunk = cfa_image[start_row:end_row, :]
-            gradient_x, gradient_y = calculate_gradients(chunk)
-            green_channel = interpolate_green_channel(chunk, gradient_x, gradient_y)
-            red_channel, blue_channel = interpolate_red_blue_channel(chunk, green_channel, self.pattern)
-            rgb_chunk = np.stack((red_channel, green_channel, blue_channel), axis=-1)
-            results[index] = np.clip(rgb_chunk, 0, 255).astype(np.uint8)
-
-        with ThreadPoolExecutor(max_workers=num_threads) as executor:
-            futures = []
-            for i in range(num_threads):
-                start_row = i * chunk_size
-                end_row = (i + 1) * chunk_size if i < num_threads - 1 else height
-                futures.append(executor.submit(process_chunk, start_row, end_row, i))
-
-            for future in futures:
-                future.result()  # wait for completion and propagate any exceptions
-
-        # Merge the processed chunks
-        rgb_image = np.vstack(results)
-        return rgb_image
-
-    def calculate_laplacian(self, image):
-        """
-        Compute the Laplacian of the image, used to enhance edge detection.
-
-        :param image: Input (grayscale) image
-        :return: Laplacian image
-        """
-        laplacian = cv2.Laplacian(image, cv2.CV_64F)
-        return laplacian
-
-    def harmonize_edges(self, original, interpolated, laplacian):
-        """
-        Use the Laplacian result to adjust the interpolated image and enhance edge detail.
-
-        :param original: Original CFA image
-        :param interpolated: Bilinearly interpolated image
-        :param laplacian: Computed Laplacian image
-        :return: Laplacian-harmonized image
-        """
-        return np.clip(interpolated + 0.2 * laplacian, 0, 255).astype(np.uint8)
-
-    def debayer_laplacian_harmonization(self, cfa_image, pattern='BGGR'):
-        """
-        Debayer using a simplified Laplacian harmonization method to enhance edges.
-
-        :param cfa_image: Input CFA image
-        :param pattern: Bayer pattern ('BGGR', 'RGGB', 'GBRG', 'GRBG')
-        :return: Debayered RGB image
-        """
-        # Step 1: bilinear interpolation
-        interpolated_image = self.debayer_bilinear(cfa_image, pattern)
-
-        # Step 2: compute the Laplacian of each channel
-        laplacian_b = self.calculate_laplacian(interpolated_image[:, :, 0])
-        laplacian_g = self.calculate_laplacian(interpolated_image[:, :, 1])
-        laplacian_r = self.calculate_laplacian(interpolated_image[:, :, 2])
-
-        # Step 3: harmonize the interpolated channels with the Laplacian results
-        harmonized_b = self.harmonize_edges(cfa_image, interpolated_image[:, :, 0], laplacian_b)
-        harmonized_g = self.harmonize_edges(cfa_image, interpolated_image[:, :, 1], laplacian_g)
-        harmonized_r = self.harmonize_edges(cfa_image, interpolated_image[:, :, 2], laplacian_r)
-
-        # Step 4: merge the harmonized channels
-        harmonized_image = np.stack((harmonized_b, harmonized_g, harmonized_r), axis=-1)
-
-        return harmonized_image
-
-    def extend_image_edges(self, image: np.ndarray, pad_width: int) -> np.ndarray:
-        """
-        Extend image edges using mirror padding to handle boundary issues during interpolation.
-        """
-        return np.pad(image, pad_width, mode='reflect')
-
-    def visualize_intermediate_steps(self, cfa_image: np.ndarray):
-        """
-        Visualize intermediate steps in the debayering process.
-        """
-        gradient_x, gradient_y = calculate_gradients(cfa_image)
-        green_channel = interpolate_green_channel(cfa_image, gradient_x, gradient_y)
-        red_channel, blue_channel = interpolate_red_blue_channel(cfa_image, green_channel, self.pattern)
-
-        # Display the gradients and the individual channel images
-        cv2.imshow("Gradient X", gradient_x)
-        cv2.imshow("Gradient Y", gradient_y)
-        cv2.imshow("Green Channel", green_channel)
-        cv2.imshow("Red Channel", red_channel)
-        cv2.imshow("Blue Channel", blue_channel)
-        cv2.waitKey(0)
-        cv2.destroyAllWindows()
-
-    def process_batch(self, image_paths: list, num_threads: int = 4):
-        """
-        Batch processing for multiple CFA images using multithreading.
-        """
-        start_time = time.time()
-        with ThreadPoolExecutor(max_workers=num_threads) as executor:
-            results = executor.map(lambda path: self.process_single_image(path), image_paths)
-
-        elapsed_time = time.time() - start_time
-        print(f"Batch processing completed in {elapsed_time:.2f} seconds.")
-
-    def process_single_image(self, path: str) -> np.ndarray:
-        """
-        Helper function for processing a single image.
-        """
-        cfa_image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
-        rgb_image = self.debayer_image(cfa_image)
-        output_path = path.replace('.png', f'_{self.method}.png')
-        cv2.imwrite(output_path, rgb_image)
-        print(f"Processed {path} -> {output_path}")
-        return rgb_image
diff --git a/pysrc/image/debayer/metrics.py b/pysrc/image/debayer/metrics.py
deleted file mode 100644
index 0f27e42c..00000000
--- a/pysrc/image/debayer/metrics.py
+++ /dev/null
@@ -1,27 +0,0 @@
-import cv2
-import numpy as np
-from typing import Tuple
-from skimage.metrics import peak_signal_noise_ratio as psnr, structural_similarity as ssim
-
-def evaluate_image_quality(rgb_image: np.ndarray) -> dict:
-    """
-    Evaluate the quality of the debayered image.
-    """
-    gray_image = cv2.cvtColor(rgb_image, cv2.COLOR_BGR2GRAY)
-    laplacian_var = cv2.Laplacian(gray_image, cv2.CV_64F).var()
-
-    mean_colors = cv2.mean(rgb_image)[:3]
-
-    return {
-        "sharpness": laplacian_var,
-        "mean_red": mean_colors[2],
-        "mean_green": mean_colors[1],
-        "mean_blue": mean_colors[0]
-    }
-
-def calculate_image_quality(original: np.ndarray, processed: np.ndarray) -> Tuple[float, float]:
-    """
-    Calculate PSNR and SSIM between the original and processed images.
-    """
-    psnr_value = psnr(original, processed, data_range=processed.max() - processed.min())
-    # `multichannel` was removed from recent scikit-image; channel_axis replaces it
-    ssim_value = ssim(original, processed, channel_axis=-1)
-    return psnr_value, ssim_value
diff --git a/pysrc/image/debayer/utils.py b/pysrc/image/debayer/utils.py
deleted file mode 100644
index 4ef5e83a..00000000
--- a/pysrc/image/debayer/utils.py
+++ /dev/null
@@ -1,117 +0,0 @@
-from typing import Tuple
-import numpy as np
-
-def calculate_gradients(cfa_image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
-    """
-    Calculate the gradients of a CFA image.
- """ - gradient_x = np.abs(np.diff(cfa_image, axis=1)) - gradient_y = np.abs(np.diff(cfa_image, axis=0)) - - gradient_x = np.pad(gradient_x, ((0, 0), (0, 1)), 'constant') - gradient_y = np.pad(gradient_y, ((0, 1), (0, 0)), 'constant') - - return gradient_x, gradient_y - -def interpolate_green_channel(cfa_image: np.ndarray, gradient_x: np.ndarray, gradient_y: np.ndarray) -> np.ndarray: - """ - Interpolate the green channel of the CFA image. - """ - height, width = cfa_image.shape - green_channel = np.zeros((height, width)) - - for i in range(1, height - 1): - for j in range(1, width - 1): - if (i % 2 == 0 and j % 2 == 1) or (i % 2 == 1 and j % 2 == 0): - # 当前点是绿色通道点,直接赋值 - green_channel[i, j] = cfa_image[i, j] - else: - # 当前点不是绿色通道点,需要插值 - if gradient_x[i, j] < gradient_y[i, j]: - green_channel[i, j] = 0.5 * \ - (cfa_image[i, j-1] + cfa_image[i, j+1]) - else: - green_channel[i, j] = 0.5 * \ - (cfa_image[i-1, j] + cfa_image[i+1, j]) - - return green_channel - -def interpolate_red_blue_channel(cfa_image: np.ndarray, green_channel: np.ndarray, pattern: str) -> Tuple[np.ndarray, np.ndarray]: - """ - Interpolate the red and blue channels of the CFA image. 
- """ - height, width = cfa_image.shape - red_channel = np.zeros((height, width)) - blue_channel = np.zeros((height, width)) - - if pattern == 'BGGR': - for i in range(0, height, 2): - for j in range(0, width, 2): - blue_channel[i, j] = cfa_image[i, j] - red_channel[i+1, j+1] = cfa_image[i+1, j+1] - - green_r = 0.5 * (green_channel[i+1, j] + green_channel[i, j+1]) - green_b = 0.5 * (green_channel[i, j] + green_channel[i+1, j+1]) - - blue_channel[i+1, j] = cfa_image[i+1, j] - \ - green_b + green_channel[i+1, j] - blue_channel[i, j+1] = cfa_image[i, j+1] - \ - green_b + green_channel[i, j+1] - red_channel[i, j] = cfa_image[i, j] - \ - green_r + green_channel[i, j] - red_channel[i+1, j+1] = cfa_image[i+1, j+1] - \ - green_r + green_channel[i+1, j+1] - - elif pattern == 'RGGB': - for i in range(0, height, 2): - for j in range(0, width, 2): - red_channel[i, j] = cfa_image[i, j] - blue_channel[i+1, j+1] = cfa_image[i+1, j+1] - - green_r = 0.5 * (green_channel[i, j+1] + green_channel[i+1, j]) - green_b = 0.5 * (green_channel[i+1, j] + green_channel[i, j+1]) - - red_channel[i+1, j] = cfa_image[i+1, j] - \ - green_r + green_channel[i+1, j] - red_channel[i, j+1] = cfa_image[i, j+1] - \ - green_r + green_channel[i, j+1] - blue_channel[i, j] = cfa_image[i, j] - \ - green_b + green_channel[i, j] - blue_channel[i+1, j+1] = cfa_image[i+1, j+1] - \ - green_b + green_channel[i+1, j+1] - - elif pattern == 'GBRG': - for i in range(0, height, 2): - for j in range(0, width, 2): - green_channel[i, j+1] = cfa_image[i, j+1] - blue_channel[i+1, j] = cfa_image[i+1, j] - - green_r = 0.5 * (green_channel[i, j] + green_channel[i+1, j+1]) - green_b = 0.5 * (green_channel[i+1, j] + green_channel[i, j+1]) - - red_channel[i, j] = cfa_image[i, j] - \ - green_r + green_channel[i, j] - red_channel[i+1, j+1] = cfa_image[i+1, j+1] - \ - green_r + green_channel[i+1, j+1] - blue_channel[i, j] = cfa_image[i, j] - \ - green_b + green_channel[i, j] - blue_channel[i+1, j+1] = cfa_image[i+1, j+1] - \ - 
green_b + green_channel[i+1, j+1] - - elif pattern == 'GRBG': - for i in range(0, height, 2): - for j in range(0, width, 2): - green_channel[i, j] = cfa_image[i, j] - red_channel[i+1, j] = cfa_image[i+1, j] - - green_r = 0.5 * (green_channel[i, j] + green_channel[i+1, j+1]) - green_b = 0.5 * (green_channel[i+1, j] + green_channel[i, j+1]) - - red_channel[i, j] = cfa_image[i, j] - \ - green_r + green_channel[i, j] - red_channel[i+1, j+1] = cfa_image[i+1, j+1] - \ - green_r + green_channel[i+1, j+1] - blue_channel[i+1, j] = cfa_image[i+1, j] - \ - green_b + green_channel[i+1, j] - blue_channel[i, j+1] = cfa_image[i, j+1] - \ - green_b + green_channel[i, j+1] diff --git a/pysrc/image/image_io/io.py b/pysrc/image/image_io/io.py deleted file mode 100644 index 0b857551..00000000 --- a/pysrc/image/image_io/io.py +++ /dev/null @@ -1,86 +0,0 @@ -import cv2 -import numpy as np -from dataclasses import dataclass -from typing import Tuple - - -@dataclass -class ImageProcessor: - """ - A class that provides various image processing functionalities. - """ - - @staticmethod - def save_image(filepath: str, image: np.ndarray) -> None: - """ - Save an image to a specified file path. - - Parameters: - - filepath: str, The file path where the image should be saved. - - image: np.ndarray, The image data to be saved. - """ - cv2.imwrite(filepath, image) - - @staticmethod - def display_image(window_name: str, image: np.ndarray) -> None: - """ - Display an image in a new window. - - Parameters: - - window_name: str, The name of the window where the image will be displayed. - - image: np.ndarray, The image data to be displayed. - """ - cv2.imshow(window_name, image) - cv2.waitKey(0) - cv2.destroyAllWindows() - - @staticmethod - def load_image(filepath: str, color_mode: int = cv2.IMREAD_COLOR) -> np.ndarray: - """ - Load an image from a specified file path. - - Parameters: - - filepath: str, The file path from where the image will be loaded. 
- - color_mode: int, The color mode in which to load the image (default: cv2.IMREAD_COLOR). - - Returns: - - np.ndarray, The loaded image data. - """ - return cv2.imread(filepath, color_mode) - - @staticmethod - def resize_image(image: np.ndarray, size: Tuple[int, int]) -> np.ndarray: - """ - Resize an image to a new size. - - Parameters: - - image: np.ndarray, The image data to be resized. - - size: Tuple[int, int], The new size as (width, height). - - Returns: - - np.ndarray, The resized image data. - """ - return cv2.resize(image, size) - - @staticmethod - def crop_image(image: np.ndarray, start: Tuple[int, int], size: Tuple[int, int]) -> np.ndarray: - """ - Crop a region from an image. - - Parameters: - - image: np.ndarray, The image data to be cropped. - - start: Tuple[int, int], The top-left corner of the crop region as (x, y). - - size: Tuple[int, int], The size of the crop region as (width, height). - - Returns: - - np.ndarray, The cropped image data. - """ - x, y = start - w, h = size - return image[y:y+h, x:x+w] - - -# Defining the public API of the module -__all__ = [ - 'ImageProcessor' -] diff --git a/pysrc/image/star_detection/clustering.py b/pysrc/image/star_detection/clustering.py deleted file mode 100644 index 365b9db3..00000000 --- a/pysrc/image/star_detection/clustering.py +++ /dev/null @@ -1,32 +0,0 @@ -import numpy as np -from sklearn.cluster import DBSCAN -from typing import List, Tuple - -def cluster_stars(stars: List[Tuple[int, int]], dbscan_eps: float, dbscan_min_samples: int) -> List[Tuple[int, int]]: - """ - Cluster stars using the DBSCAN algorithm. - - Parameters: - - stars: List of star positions as (x, y) tuples. - - dbscan_eps: The maximum distance between two stars for them to be considered in the same neighborhood. - - dbscan_min_samples: The number of stars in a neighborhood for a point to be considered a core point. - - Returns: - - List of clustered star positions as (x, y) tuples. 
- """ - if len(stars) == 0: - return [] - - clustering = DBSCAN(eps=dbscan_eps, min_samples=dbscan_min_samples).fit(stars) - labels = clustering.labels_ - - unique_labels = set(labels) - clustered_stars = [] - for label in unique_labels: - if label == -1: # -1 indicates noise - continue - class_members = [stars[i] for i in range(len(stars)) if labels[i] == label] - centroid = np.mean(class_members, axis=0).astype(int) - clustered_stars.append(tuple(centroid)) - - return clustered_stars diff --git a/pysrc/image/star_detection/detection.py b/pysrc/image/star_detection/detection.py deleted file mode 100644 index e4a84298..00000000 --- a/pysrc/image/star_detection/detection.py +++ /dev/null @@ -1,84 +0,0 @@ -import cv2 -import numpy as np -from .preprocessing import apply_median_filter, wavelet_transform, inverse_wavelet_transform, binarize, detect_stars, background_subtraction -from typing import List, Tuple - -class StarDetectionConfig: - """ - Configuration class for star detection settings. - """ - def __init__(self, - median_filter_size: int = 3, - wavelet_levels: int = 4, - binarization_threshold: int = 30, - min_star_size: int = 10, - min_star_brightness: int = 20, - min_circularity: float = 0.7, - max_circularity: float = 1.3, - scales: List[float] = [1.0, 0.75, 0.5], - dbscan_eps: float = 10, - dbscan_min_samples: int = 2): - self.median_filter_size = median_filter_size - self.wavelet_levels = wavelet_levels - self.binarization_threshold = binarization_threshold - self.min_star_size = min_star_size - self.min_star_brightness = min_star_brightness - self.min_circularity = min_circularity - self.max_circularity = max_circularity - self.scales = scales - self.dbscan_eps = dbscan_eps - self.dbscan_min_samples = dbscan_min_samples - - -def multiscale_detect_stars(image: np.ndarray, config: StarDetectionConfig) -> List[Tuple[int, int]]: - """ - Detect stars in an image using multiscale analysis. - - Parameters: - - image: Grayscale input image as a numpy array. 
- - config: Configuration object containing detection settings. - - Returns: - - List of detected star positions as (x, y) tuples. - """ - all_stars = [] - for scale in config.scales: - resized_image = cv2.resize(image, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR) - filtered_image = apply_median_filter(resized_image, config) - pyramid = wavelet_transform(filtered_image, config.wavelet_levels) - background = pyramid[-1] - subtracted_image = background_subtraction(filtered_image, background) - pyramid = pyramid[:2] - processed_image = inverse_wavelet_transform(pyramid) - binary_image = binarize(processed_image, config) - _, star_properties = detect_stars(binary_image) - filtered_stars = filter_stars(star_properties, binary_image, config) - # Adjust star positions back to original scale - filtered_stars = [(int(x / scale), int(y / scale)) for (x, y) in filtered_stars] - all_stars.extend(filtered_stars) - return all_stars - -def filter_stars(star_properties: List[Tuple[Tuple[int, int], float, float]], binary_image: np.ndarray, config: StarDetectionConfig) -> List[Tuple[int, int]]: - """ - Filter detected stars based on shape, size, and brightness. - - Parameters: - - star_properties: List of tuples containing star properties (center, area, perimeter). - - binary_image: Binary image used for star detection. - - config: Configuration object containing filter settings. - - Returns: - - List of filtered star positions as (x, y) tuples. 
- """ - filtered_stars = [] - for (center, area, perimeter) in star_properties: - circularity = (4 * np.pi * area) / (perimeter ** 2) - mask = np.zeros_like(binary_image) - cv2.circle(mask, center, 5, 255, -1) - star_pixels = cv2.countNonZero(mask) - brightness = np.mean(binary_image[mask == 255]) - if (star_pixels > config.min_star_size and - brightness > config.min_star_brightness and - config.min_circularity <= circularity <= config.max_circularity): - filtered_stars.append(center) - return filtered_stars diff --git a/pysrc/image/star_detection/preprocessing.py b/pysrc/image/star_detection/preprocessing.py deleted file mode 100644 index 16753dbe..00000000 --- a/pysrc/image/star_detection/preprocessing.py +++ /dev/null @@ -1,179 +0,0 @@ -import cv2 -import numpy as np -from astropy.io import fits -from typing import List, Tuple - -def load_fits_image(file_path: str) -> np.ndarray: - """ - Load a FITS image from the specified file path. - - Parameters: - - file_path: Path to the FITS file. - - Returns: - - Image data as a numpy array. - """ - with fits.open(file_path) as hdul: - image_data = hdul[0].data - return image_data - -def preprocess_fits_image(image_data: np.ndarray) -> np.ndarray: - """ - Preprocess FITS image by normalizing to the 0-255 range. - - Parameters: - - image_data: Raw image data from the FITS file. - - Returns: - - Preprocessed image data as a numpy array. - """ - image_data = np.nan_to_num(image_data) - image_data = image_data.astype(np.float64) - image_data -= np.min(image_data) - image_data /= np.max(image_data) - image_data *= 255 - return image_data.astype(np.uint8) - -def load_image(file_path: str) -> np.ndarray: - """ - Load an image from the specified file path. Supports FITS and standard image formats. - - Parameters: - - file_path: Path to the image file. - - Returns: - - Loaded image as a numpy array. 
- """ - if file_path.endswith('.fits'): - image_data = load_fits_image(file_path) - if image_data.ndim == 2: - return preprocess_fits_image(image_data) - elif image_data.ndim == 3: - channels = [preprocess_fits_image(image_data[..., i]) for i in range(image_data.shape[2])] - return cv2.merge(channels) - else: - image = cv2.imread(file_path, cv2.IMREAD_UNCHANGED) - if image is None: - raise ValueError("Unable to load image file: {}".format(file_path)) - return image - -def apply_median_filter(image: np.ndarray, config) -> np.ndarray: - """ - Apply median filtering to the image. - - Parameters: - - image: Input image. - - config: Configuration object containing median filter settings. - - Returns: - - Filtered image. - """ - return cv2.medianBlur(image, config.median_filter_size) - -def wavelet_transform(image: np.ndarray, levels: int) -> List[np.ndarray]: - """ - Perform wavelet transform using a Laplacian pyramid. - - Parameters: - - image: Input image. - - levels: Number of levels in the wavelet transform. - - Returns: - - List of wavelet transformed images at each level. - """ - pyramid = [] - current_image = image.copy() - for _ in range(levels): - down = cv2.pyrDown(current_image) - up = cv2.pyrUp(down, current_image.shape[:2]) - - # Resize up to match the original image size - up = cv2.resize(up, (current_image.shape[1], current_image.shape[0])) - - # Calculate the difference to get the detail layer - layer = cv2.subtract(current_image, up) - pyramid.append(layer) - current_image = down - - pyramid.append(current_image) # Add the final low-resolution image - return pyramid - -def inverse_wavelet_transform(pyramid: List[np.ndarray]) -> np.ndarray: - """ - Reconstruct the image from its wavelet pyramid representation. - - Parameters: - - pyramid: List of wavelet transformed images at each level. - - Returns: - - Reconstructed image. 
- """ - image = pyramid.pop() - while pyramid: - up = cv2.pyrUp(image, pyramid[-1].shape[:2]) - - # Resize up to match the size of the current level - up = cv2.resize(up, (pyramid[-1].shape[1], pyramid[-1].shape[0])) - - # Add the detail layer to reconstruct the image - image = cv2.add(up, pyramid.pop()) - - return image - -def background_subtraction(image: np.ndarray, background: np.ndarray) -> np.ndarray: - """ - Subtract the background from the image using the provided background image. - - Parameters: - - image: Original image. - - background: Background image to subtract. - - Returns: - - Image with background subtracted. - """ - # Resize the background to match the original image size - background_resized = cv2.resize(background, (image.shape[1], image.shape[0])) - - # Subtract background and ensure no negative values - result = cv2.subtract(image, background_resized) - result[result < 0] = 0 - return result - -def binarize(image: np.ndarray, config) -> np.ndarray: - """ - Binarize the image using a fixed threshold. - - Parameters: - - image: Input image. - - config: Configuration object containing binarization settings. - - Returns: - - Binarized image. - """ - _, binary_image = cv2.threshold(image, config.binarization_threshold, 255, cv2.THRESH_BINARY) - return binary_image - -def detect_stars(binary_image: np.ndarray) -> Tuple[List[Tuple[int, int]], List[Tuple[Tuple[int, int], float, float]]]: - """ - Detect stars in a binary image by finding contours. - - Parameters: - - binary_image: Binarized image. - - Returns: - - Tuple containing a list of star centers and a list of star properties (center, area, perimeter). 
- """ - contours, _ = cv2.findContours(binary_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) - star_centers = [] - star_properties = [] - - for contour in contours: - M = cv2.moments(contour) - if M['m00'] != 0: - center = (int(M['m10'] / M['m00']), int(M['m01'] / M['m00'])) - star_centers.append(center) - area = cv2.contourArea(contour) - perimeter = cv2.arcLength(contour, True) - star_properties.append((center, area, perimeter)) - - return star_centers, star_properties diff --git a/pysrc/image/transformation/__init__.py b/pysrc/image/transformation/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/pysrc/image/transformation/curve.py b/pysrc/image/transformation/curve.py deleted file mode 100644 index 8275158a..00000000 --- a/pysrc/image/transformation/curve.py +++ /dev/null @@ -1,154 +0,0 @@ -import numpy as np -import matplotlib.pyplot as plt -from scipy.interpolate import CubicSpline, Akima1DInterpolator - - -class CurvesTransformation: - def __init__(self, interpolation='akima'): - self.points = [] - self.interpolation = interpolation - self.curve = None - self.stored_curve = None - - def add_point(self, x, y): - self.points.append((x, y)) - self.points.sort() # Sort points by x value - self._update_curve() - - def remove_point(self, index): - if 0 <= index < len(self.points): - self.points.pop(index) - self._update_curve() - - def _update_curve(self): - if len(self.points) < 2: - self.curve = None - return - - x, y = zip(*self.points) - - if self.interpolation == 'cubic': - self.curve = CubicSpline(x, y) - elif self.interpolation == 'akima': - self.curve = Akima1DInterpolator(x, y) - elif self.interpolation == 'linear': - self.curve = lambda x_new: np.interp(x_new, x, y) - else: - raise ValueError("Unsupported interpolation method") - - def transform(self, image, channel=None): - if self.curve is None: - raise ValueError("No valid curve defined") - - if len(image.shape) == 2: # Grayscale image - transformed_image = 
self.curve(image) - elif len(image.shape) == 3: # RGB image - if channel is None: - raise ValueError("Channel must be specified for color images") - transformed_image = image.copy() - transformed_image[:, :, channel] = self.curve(image[:, :, channel]) - else: - raise ValueError("Unsupported image format") - - transformed_image = np.clip(transformed_image, 0, 1) - return transformed_image - - def plot_curve(self): - if self.curve is None: - print("No curve to plot") - return - - x_vals = np.linspace(0, 1, 100) - y_vals = self.curve(x_vals) - - plt.plot(x_vals, y_vals, label=f'Interpolation: {self.interpolation}') - plt.scatter(*zip(*self.points), color='red') - plt.title('Curves Transformation') - plt.xlabel('Input') - plt.ylabel('Output') - plt.grid(True) - plt.legend() - plt.show() - - def store_curve(self): - self.stored_curve = self.points.copy() - print("Curve stored.") - - def restore_curve(self): - if self.stored_curve: - self.points = self.stored_curve.copy() - self._update_curve() - print("Curve restored.") - else: - print("No stored curve to restore.") - - def invert_curve(self): - if self.curve is None: - print("No curve to invert") - return - - self.points = [(x, 1 - y) for x, y in self.points] - self._update_curve() - print("Curve inverted.") - - def reset_curve(self): - self.points = [(0, 0), (1, 1)] - self._update_curve() - print("Curve reset to default.") - - def pixel_readout(self, x): - if self.curve is None: - print("No curve defined") - return None - return self.curve(x) - - -# Example Usage -if __name__ == "__main__": - # Create a CurvesTransformation object - curve_transform = CurvesTransformation(interpolation='akima') - - # Add points to the curve - curve_transform.add_point(0.0, 0.0) - curve_transform.add_point(0.3, 0.5) - curve_transform.add_point(0.7, 0.8) - curve_transform.add_point(1.0, 1.0) - - # Plot the curve - curve_transform.plot_curve() - - # Store the curve - curve_transform.store_curve() - - # Invert the curve - 
curve_transform.invert_curve()
-    curve_transform.plot_curve()
-
-    # Restore the original curve
-    curve_transform.restore_curve()
-    curve_transform.plot_curve()
-
-    # Reset the curve to default
-    curve_transform.reset_curve()
-    curve_transform.plot_curve()
-
-    # Generate a test image
-    test_image = np.linspace(0, 1, 256).reshape(16, 16)
-
-    # Apply the transformation
-    transformed_image = curve_transform.transform(test_image)
-
-    # Plot original and transformed images
-    plt.figure(figsize=(8, 4))
-    plt.subplot(1, 2, 1)
-    plt.title("Original Image")
-    plt.imshow(test_image, cmap='gray')
-
-    plt.subplot(1, 2, 2)
-    plt.title("Transformed Image")
-    plt.imshow(transformed_image, cmap='gray')
-    plt.show()
-
-    # Pixel readout
-    readout_value = curve_transform.pixel_readout(0.5)
-    print(f"Pixel readout at x=0.5: {readout_value}")
diff --git a/pysrc/image/transformation/histogram.py b/pysrc/image/transformation/histogram.py
deleted file mode 100644
index 80f8f390..00000000
--- a/pysrc/image/transformation/histogram.py
+++ /dev/null
@@ -1,121 +0,0 @@
-import cv2
-import numpy as np
-import matplotlib.pyplot as plt
-
-# 1. Histogram calculation
-
-
-def calculate_histogram(image, channel=0):
-    histogram = cv2.calcHist([image], [channel], None, [256], [0, 256])
-    return histogram
-
-# 2. Display a histogram
-
-
-def display_histogram(histogram, title="Histogram"):
-    plt.plot(histogram)
-    plt.title(title)
-    plt.xlabel('Pixel Intensity')
-    plt.ylabel('Frequency')
-    plt.show()
-
-# 3. Histogram transformation
-
-
-def apply_histogram_transformation(image, shadows_clip=0.0, highlights_clip=1.0, midtones_balance=0.5, lower_bound=-1.0, upper_bound=2.0):
-    # Normalize to [0, 1]
-    normalized_image = image.astype(np.float32) / 255.0
-
-    # Clip shadows and highlights
-    clipped_image = np.clip(
-        (normalized_image - shadows_clip) / (highlights_clip - shadows_clip), 0, 1)
-
-    # Midtones balance (midtone transfer function)
-    def mtf(x): return (x**midtones_balance) / \
-        ((x**midtones_balance + (1-x)**midtones_balance)**(1/midtones_balance))
-    transformed_image = mtf(clipped_image)
-
-    # Expand the dynamic range
-    expanded_image = np.clip(
-        (transformed_image - lower_bound) / (upper_bound - lower_bound), 0, 1)
-
-    # Rescale to [0, 255]
-    output_image = (expanded_image * 255).astype(np.uint8)
-    return output_image
-
-# 4. Automatic clipping
-
-
-def auto_clip(image, clip_percent=0.01):
-    # Compute the cumulative distribution function (CDF)
-    hist, bins = np.histogram(image.flatten(), 256, [0, 256])
-    cdf = hist.cumsum()
-
-    # Compute the clipping points
-    total_pixels = image.size
-    lower_clip = np.searchsorted(cdf, total_pixels * clip_percent)
-    upper_clip = np.searchsorted(cdf, total_pixels * (1 - clip_percent))
-
-    # Apply the clipping
-    auto_clipped_image = apply_histogram_transformation(
-        image, shadows_clip=lower_clip/255.0, highlights_clip=upper_clip/255.0)
-
-    return auto_clipped_image
-
-# 5. Display the original RGB histogram
-
-
-def display_rgb_histogram(image):
-    color = ('b', 'g', 'r')
-    for i, col in enumerate(color):
-        hist = calculate_histogram(image, channel=i)
-        plt.plot(hist, color=col)
-    plt.title('RGB Histogram')
-    plt.xlabel('Pixel Intensity')
-    plt.ylabel('Frequency')
-    plt.show()
-
-# 6. Real-time preview (simple simulation)
-
-
-def real_time_preview(image, transformation_function, **kwargs):
-    preview_image = transformation_function(image, **kwargs)
-    cv2.imshow('Real-Time Preview', preview_image)
-
-
-# Main entry point
-if __name__ == "__main__":
-    # Load the image
-    image = cv2.imread('image.jpg')
-
-    # Convert to grayscale
-    grayscale_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
-
-    # Show the original image and its histogram
-    cv2.imshow('Original Image', image)
-    histogram = calculate_histogram(grayscale_image)
-    display_histogram(histogram, title="Original Grayscale Histogram")
-
-    # Show the RGB histogram
-    display_rgb_histogram(image)
-
-    # Apply the histogram transformation
-    transformed_image = apply_histogram_transformation(
-        grayscale_image, shadows_clip=0.1, highlights_clip=0.9, midtones_balance=0.4, lower_bound=-0.5, upper_bound=1.5)
-    cv2.imshow('Transformed Image', transformed_image)
-
-    # Show the transformed histogram
-    transformed_histogram = calculate_histogram(transformed_image)
-    display_histogram(transformed_histogram,
-                      title="Transformed Grayscale Histogram")
-
-    # Apply automatic clipping
-    auto_clipped_image = auto_clip(grayscale_image, clip_percent=0.01)
-    cv2.imshow('Auto Clipped Image', auto_clipped_image)
-
-    # Real-time preview simulation
-    real_time_preview(grayscale_image, apply_histogram_transformation,
-                      shadows_clip=0.05, highlights_clip=0.95, midtones_balance=0.5)
-
-    cv2.waitKey(0)
-    cv2.destroyAllWindows()
diff --git a/src/atom b/src/atom
index 8324c44b..1da90d48 160000
--- a/src/atom
+++ b/src/atom
@@ -1 +1 @@
-Subproject commit 8324c44b6589305faf953822c60c4be4d79a916e
+Subproject commit 1da90d48f4dfbe4034665be4284b6fa12c95a7ec