diff --git a/noxfile.py b/noxfile.py index 06c6c9c..ff8da4e 100644 --- a/noxfile.py +++ b/noxfile.py @@ -1,3 +1,5 @@ +"""Nox Test Sessions.""" + from __future__ import annotations import argparse @@ -8,14 +10,12 @@ DIR = Path(__file__).parent.resolve() -nox.options.sessions = ["lint", "tests"] # "pylint", +nox.options.sessions = ["lint", "pylint", "tests"] @nox.session def lint(session: nox.Session) -> None: - """ - Run the linter. - """ + """Run the linter.""" session.install("pre-commit") session.run( "pre-commit", "run", "--all-files", "--show-diff-on-failure", *session.posargs @@ -25,9 +25,7 @@ def lint(session: nox.Session) -> None: # TODO: turn this on eventually @nox.session def pylint(session: nox.Session) -> None: - """ - Run PyLint. - """ + """Run PyLint.""" # This needs to be installed into the package environment, and is slower # than a pre-commit check session.install(".", "pylint") @@ -36,19 +34,14 @@ def pylint(session: nox.Session) -> None: @nox.session def tests(session: nox.Session) -> None: - """ - Run the unit and regular tests. - """ + """Run the unit and regular tests.""" session.install(".[test]") session.run("pytest", *session.posargs) @nox.session(reuse_venv=True) def docs(session: nox.Session) -> None: - """ - Build the docs. Pass "--serve" to serve. Pass "-b linkcheck" to check links. - """ - + """Build the docs. Pass "--serve" to serve. Pass "-b linkcheck" to check links.""" parser = argparse.ArgumentParser() parser.add_argument("--serve", action="store_true", help="Serve after building") parser.add_argument( @@ -87,10 +80,7 @@ def docs(session: nox.Session) -> None: @nox.session def build_api_docs(session: nox.Session) -> None: - """ - Build (regenerate) API docs. - """ - + """Build (regenerate) API docs.""" session.install("sphinx") session.chdir("docs") session.run( @@ -106,10 +96,7 @@ def build_api_docs(session: nox.Session) -> None: @nox.session def build(session: nox.Session) -> None: - """ - Build an SDist and wheel. - """ - + """Build an SDist and wheel.""" build_path = DIR.joinpath("build") if build_path.exists(): shutil.rmtree(build_path) diff --git a/pyproject.toml b/pyproject.toml index 48dadaa..d1ffafa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,16 +29,17 @@ classifiers = [ ] dynamic = ["version"] dependencies = [ - "numpy", - "scipy>1.10", + "asdf", + "dustmaps", "astropy", - "matplotlib", + "astroquery", "gala", "galstreams", + "matplotlib", + "numba", + "numpy", "pyia", - "astroquery", - "dustmaps", - "asdf", + "scipy>1.10", "ugali", ] @@ -81,6 +82,7 @@ addopts = ["-ra", "--showlocals", "--strict-markers", "--strict-config"] xfail_strict = true filterwarnings = [ "error", + "ignore:numpy.ndarray size changed", ] log_cli_level = "INFO" testpaths = [ @@ -123,41 +125,27 @@ disallow_incomplete_defs = true src = ["src"] [tool.ruff.lint] -extend-select = [ - "B", # flake8-bugbear - "I", # isort - "ARG", # flake8-unused-arguments - "C4", # flake8-comprehensions - "EM", # flake8-errmsg - "ICN", # flake8-import-conventions - "G", # flake8-logging-format - "PGH", # pygrep-hooks - "PIE", # flake8-pie - "PL", # pylint - "PT", # flake8-pytest-style - "PTH", # flake8-use-pathlib - "RET", # flake8-return - "RUF", # Ruff-specific - "SIM", # flake8-simplify - "T20", # flake8-print - "UP", # pyupgrade - "YTT", # flake8-2020 - "EXE", # flake8-executable - "NPY", # NumPy specific rules - "PD", # pandas-vet -] +extend-select = ["ALL"] ignore = [ + "ANN101", # Missing type annotation for self in method + "ANN401", # Dynamically typed expressions (typing.Any) are disallowed in `**kwargs` + "BLE001", # Using bare `except` + "COM812", # Missing trailing comma + "D107", # Missing docstring in `__init__` + "D203", # 1 blank line required before class docstring + "D213", # Multi-line docstring summary should start at the second line + "ERA001", # Commented-out code + "FIX002", # Line contains TODO + "N8", # Naming conventions + "N999", # Invalid module name "PD", # Pandas "PLR", # Design related pylint codes + "TD002", # Missing author in TODO + "TD003", # Missing issue link on the line following this TODO # TODO: fix these and remove "ARG001", # Unused function argument `X` "ARG003", # Unused method argument `X` - "E711", # Comparison to `None` should be `cond is None` - "E731", # Do not assign a `lambda` expression, use a `def` - "EM101", # Exception must not use a string literal, assign to variable first "F821", # Undefined name `X` - "PT018", # Assertion should be broken down into multiple parts - "RET504", # Unnecessary assignment `return` statement "RET505", # Unnecessary `else` after `return` "SIM118", # Use `key in dict` instead of `key in dict.keys()` "T201", # `print` found @@ -167,8 +155,11 @@ isort.required-imports = ["from __future__ import annotations"] # typing-modules = ["cats._compat.typing"] [tool.ruff.lint.per-file-ignores] -"tests/**" = ["T20"] +"docs/conf.py" = ["A001", "D100", "INP001"] +"scripts/**" = ["INP001"] +"tests/**" = ["ANN", "D1", "INP", "S101", "T20"] "noxfile.py" = ["T20"] +"__init__.py" = ["F403"] # TODO: fix these and remove "src/cats/pawprint/tests/mwe.py" = ["F821"] @@ -181,7 +172,11 @@ similarities.ignore-imports = "yes" messages_control.disable = [ "design", "fixme", - "line-too-long", - "missing-module-docstring", "wrong-import-position", + # TODO: fix these and remove + "attribute-defined-outside-init", + "duplicate-code", + "invalid-name", + "no-member", + "unused-variable", ] diff --git a/src/cats/combine_pm_cmd.py b/scripts/combine_pm_cmd.py similarity index 79% rename from src/cats/combine_pm_cmd.py rename to scripts/combine_pm_cmd.py index 765297b..e933c96 100644 --- a/src/cats/combine_pm_cmd.py +++ b/scripts/combine_pm_cmd.py @@ -1,39 +1,26 @@ +"""Combine PM and CMD cuts.""" + from __future__ import annotations +from typing import Any + import astropy.table as at -import matplotlib as mpl +import matplotlib.pyplot as plt import pandas as pd -from cats.cmd.CMD import Isochrone -from cats.pawprint.pawprint import Footprint2D, Pawprint - -plt = mpl.pyplot - -plt.rc( - "xtick", - top=True, - direction="in", - labelsize=15, -) -plt.rc( - "ytick", - right=True, - direction="in", - labelsize=15, -) -plt.rc( - "font", - family="Arial", -) +from cats.cmd import Isochrone +from cats.pawprint import Footprint2D, Pawprint + +plt.rc("xtick", top=True, direction="in", labelsize=15) +plt.rc("ytick", right=True, direction="in", labelsize=15) +plt.rc("font", family="Arial") def generate_isochrone_vertices( - cat, - sky_poly, - pm_poly, - config, -): - """ + cat: Any, sky_poly: Any, pm_poly: Any, config: Any +) -> Any: + """Generate Isochrone Vertices. + Use the generated class to make a new polygon for the given catalog in CMD space given a sky and PM polygon. """ @@ -63,26 +50,17 @@ def generate_isochrone_vertices( return o.simpleSln(0.1, 15, mass_thresh=0.83)[0] -def generate_pm_vertices( - cat, - sky_poly, - cmd_poly, - config, -): - """ +def generate_pm_vertices() -> list[list[float]]: + """Generate Proper Motion Vertices. + Use the generated class to make a new polygon for the given catalog in PM space given a sky and CMD polygon. """ - pm_poly = [ - [-7.0, 0.0], - [-5.0, 0.0], - [-5.0, 1.6], - [-7.0, -1.6], - ] - return pm_poly + return [[-7.0, 0.0], [-5.0, 0.0], [-5.0, 1.6], [-7.0, -1.6]] -def load_sky_region(fn): +def load_sky_region(fn: Any) -> tuple[list[float], list[float]]: + """Load Sky Region.""" sky_print = [ [-5, -2], [+5, -2], @@ -94,12 +72,13 @@ def load_sky_region(fn): def main() -> int: + """Run Script.""" # load in config file, catalog from filename config = pd.read_json("config.json") cat = at.Table.read(config.streaminfo.cat_fn) # load in file with the sky footprint. - sky_poly, bg_poly = load_sky_region(config.streaminfo.sky_print) + sky_poly, _ = load_sky_region(config.streaminfo.sky_print) # have an initial selection for the PM region that is very wide # this could also be stored in a footprint diff --git a/src/cats/__init__.py b/src/cats/__init__.py index 8b6b85d..cc04876 100644 --- a/src/cats/__init__.py +++ b/src/cats/__init__.py @@ -1,12 +1,16 @@ -""" -Copyright (c) 2023 CATS. All rights reserved. +"""Copyright (c) 2023 CATS. All rights reserved. cats: Community Atlas of Tidal Streams """ - from __future__ import annotations -from ._version import version as __version__ +__all__ = [ + # Constants + "__version__", + # Modules + "data", + "frames", +] -__all__ = ["__version__"] +from ._version import version as __version__ diff --git a/src/cats/_version.pyi b/src/cats/_version.pyi index 91744f9..5bb2b22 100644 --- a/src/cats/_version.pyi +++ b/src/cats/_version.pyi @@ -1,4 +1,2 @@ -from __future__ import annotations - version: str version_tuple: tuple[int, int, int] | tuple[int, int, int, str, str] diff --git a/src/cats/cmd/CMD.py b/src/cats/cmd/CMD.py deleted file mode 100644 index 1c0be7e..0000000 --- a/src/cats/cmd/CMD.py +++ /dev/null @@ -1,633 +0,0 @@ -"""CMD functions.""" -from __future__ import annotations - -import matplotlib as mpl -import matplotlib.pyplot as plt -import numpy as np -import scipy -from isochrones.mist import MIST_Isochrone -from matplotlib.patches import PathPatch -from scipy.interpolate import InterpolatedUnivariateSpline, interp1d -from scipy.signal import correlate2d -from ugali.analysis.isochrone import factory as isochrone_factory - -from cats.inputs import stream_inputs as inputs -from cats.pawprint.pawprint import Footprint2D - -__authors__ = "Ani, Kiyan, Richard" - -plt.rc( - "xtick", - top=True, - direction="in", - labelsize=15, -) -plt.rc( - "ytick", - right=True, - direction="in", - labelsize=15, -) -plt.rc( - "font", - family="Arial", -) - - -class Isochrone: - def __init__(self, stream, cat, pawprint): - """ - Defining variables loaded into class. - - ------------------------------------------------------------------ - - Parameters: - cat = Input catalogue. - age = Input age of stream from galstreams and/or literature. - feh = Input metallicity from galstreams and/or literature. - distance = Input distance from galstreams and/or literature. - alpha = alpha/Fe - pawprint = Stream multidimensional footprint - """ - - # Pull survey from catalog? - self.stream = stream - self.cat = cat - self.age = inputs[stream]["age"] - self.feh = inputs[stream]["feh"] - self.distance = inputs[stream]["distance"] # kpc - self.alpha = inputs[stream]["alpha"] - self.dist_mod = 5 * np.log10(1000 * self.distance) - 5 - - self.pawprint = pawprint - track = self.pawprint.track.track.transform_to(self.pawprint.track.stream_frame) - - if self.stream == "GD-1": - distmod_spl = np.poly1d([2.41e-4, 2.421e-2, 15.001]) - self.dist_mod_correct = distmod_spl(self.cat["phi1"]) - self.dist_mod - else: - spline_dist = InterpolatedUnivariateSpline( - track.phi1.value, track.distance.value - ) - self.dist_mod_correct = ( - 5 * np.log10(spline_dist(self.cat["phi1"]) * 1000) - 5 - ) - self.dist_mod - - self.x_shift = 0 - self.y_shift = 0 - self.phot_survey = inputs[self.stream]["phot_survey"] - self.band1 = inputs[self.stream]["band1"] - self.band2 = inputs[self.stream]["band2"] - self.data_mag = inputs[self.stream]["mag"] - self.data_color1 = inputs[self.stream]["color1"] - self.data_color2 = inputs[self.stream]["color2"] - self.turnoff = inputs[self.stream]["turnoff"] - - self.generate_isochrone() - self.sel_sky() - self.sel_pm() - if self.pawprint.pm1print is not None: - self.sel_pm12() - - self.data_cmd() - if self.pawprint.pm1print is not None: - # Only shift isochrone is the previous cuts are clean enough - # Otherwise it will just shift to the background - self.correct_isochrone() - - def sel_sky(self): - """ - Initialising the on-sky polygon mask to return only contained sources. - """ - - on_poly_patch = mpl.patches.Polygon( - self.pawprint.skyprint["stream"].vertices[::50], - facecolor="none", - edgecolor="k", - linewidth=2, - ) - on_points = np.vstack((self.cat["phi1"], self.cat["phi2"])).T - on_mask = on_poly_patch.get_path().contains_points(on_points) - - # on_points = np.vstack((self.cat["phi1"], self.cat["phi2"])).T - # on_mask = self.pawprint.skyprint['stream'].inside_footprint(on_points) #very slow because skyprint is very large - - self.on_skymask = on_mask - - def sel_pm(self): - """ - Initialising the proper motions polygon mask to return only contained sources. - """ - - on_points = np.vstack( - (self.cat["pm_phi1_cosphi2_unrefl"], self.cat["pm_phi2_unrefl"]) - ).T - - on_mask = self.pawprint.pmprint.inside_footprint(on_points) - - self.on_pmmask = on_mask - - def sel_pm12(self): - """ - Initialising the proper motions polygon mask to return only contained sources. - """ - - on_pm1_points = np.vstack( - (self.cat["phi1"], self.cat["pm_phi1_cosphi2_unrefl"]) - ).T - on_pm2_points = np.vstack((self.cat["phi1"], self.cat["pm_phi2_unrefl"])).T - - on_pm1_mask = self.pawprint.pm1print.inside_footprint(on_pm1_points) - on_pm2_mask = self.pawprint.pm2print.inside_footprint(on_pm2_points) - - self.on_pm1mask = on_pm1_mask - self.on_pm2mask = on_pm2_mask - self.on_pm12mask = on_pm1_mask & on_pm2_mask - - def generate_isochrone(self): - """ - load an isochrone, LF model for a given metallicity, age, distance - """ - - # Convert feh to z - Y_p = 0.245 # Primordial He abundance (WMAP, 2003) - c = 1.54 # He enrichment ratio - ZX_solar = 0.0229 - z = (1 - Y_p) / ((1 + c) + (1 / ZX_solar) * 10 ** (-self.feh)) - - if self.phot_survey == "Gaia": - mist = MIST_Isochrone() - iso = mist.isochrone( - age=np.log10(1e9 * self.age), # has to be given in logAge - feh=self.feh, - eep_range=None, # get the whole isochrone, - distance=1e3 * self.distance, # given in parsecs - ) - - initial_mass, actual_mass = iso.initial_mass.values, iso.mass.values - mag = iso.G_mag.values - color_1 = iso.BP_mag.values - color_2 = iso.RP_mag.values - - # Excise the horizontal branch - turn_idx = scipy.signal.argrelextrema(iso.G_mag.values, np.less)[0][0] - initial_mass = initial_mass[0:turn_idx] - actual_mass = actual_mass[0:turn_idx] - self.masses = actual_mass - - self.mag = mag[0:turn_idx] - self.color = color_1[0:turn_idx] - color_2[0:turn_idx] - - else: - iso = isochrone_factory( - "Dotter", - survey=self.phot_survey, - age=self.age, - distance_modulus=self.dist_mod, - z=z, - band_1=self.band1, - band_2=self.band2, - ) - - iso.afe = self.alpha - - initial_mass, mass_pdf, actual_mass, mag_1, mag_2 = iso.sample( - mass_steps=4e2 - ) - mag_1 = mag_1 + iso.distance_modulus - mag_2 = mag_2 + iso.distance_modulus - - # Excise the horizontal branch - turn_idx = scipy.signal.argrelextrema(mag_1, np.less)[0][0] - initial_mass = initial_mass[0:turn_idx] - mass_pdf = mass_pdf[0:turn_idx] - actual_mass = actual_mass[0:turn_idx] - mag_1 = mag_1[0:turn_idx] - mag_2 = mag_2[0:turn_idx] - - self.mag = mag_1 - self.color = mag_1 - mag_2 - - mmag_1 = interp1d(initial_mass, mag_1, fill_value="extrapolate") - mmag_2 = interp1d(initial_mass, mag_2, fill_value="extrapolate") - mmass_pdf = interp1d(initial_mass, mass_pdf, fill_value="extrapolate") - - self.iso = iso - self.masses = actual_mass - self.mass_pdf = mass_pdf - - self.mmag_1 = mmag_1 - self.mmag_2 = mmag_2 - self.mmass_pdf = mmass_pdf - - # return iso, initial_mass, mass_pdf, actual_mass, mag_1, mag_2, mmag_1, mmag_2, \ - # mmass_pdf - - def data_cmd(self, xrange=(-0.5, 1.0), yrange=(15, 22)): - """ - Empirical CMD generated from the input catalogue, with distance gradient accounted for. - - ------------------------------------------------------------------ - - Parameters: - xrange: Set the range of color values. Default is [-0.5, 1.0]. - yrange: Set the range of magnitude values. Default is [15, 22]. - """ - tab = self.cat - x_bins = np.arange( - xrange[0], xrange[1], inputs[self.stream]["bin_sizes"][0] - ) # Used 0.03 for Jhelum - y_bins = np.arange( - yrange[0], yrange[1], inputs[self.stream]["bin_sizes"][1] - ) # Used 0.2 for Jhelum - - # if this is the second runthrough and a proper motion mask already exists, use that instead of the rough one - if self.pawprint.pm1print is not None: - data, xedges, yedges = np.histogram2d( - (tab[self.data_color1] - tab[self.data_color2])[ - self.on_pm12mask & self.on_skymask - ], - (tab[self.data_mag] - self.dist_mod_correct)[ - self.on_pm12mask & self.on_skymask - ], - bins=[x_bins, y_bins], - density=True, - ) - else: - data, xedges, yedges = np.histogram2d( - (tab[self.data_color1] - tab[self.data_color2])[ - self.on_pmmask & self.on_skymask - ], - (tab[self.data_mag] - self.dist_mod_correct)[ - self.on_pmmask & self.on_skymask - ], - bins=[x_bins, y_bins], - density=True, - ) - - self.x_edges = xedges - self.y_edges = yedges - self.CMD_data = data.T - - def correct_isochrone(self): - """ - Correlate the 2D histograms from the data and the - theoretical isochrone to find the shift in color - and magnitude necessary for the best match - """ - - signal, xedges, yedges = np.histogram2d( - self.color, - self.mag, - bins=[self.x_edges, self.y_edges], - weights=np.ones(len(self.mag)), - ) - - signal_counts, xedges, yedges = np.histogram2d( - self.color, self.mag, bins=[self.x_edges, self.y_edges] - ) - signal = signal / signal_counts - signal[np.isnan(signal)] = 0.0 - signal = signal.T - - ccor2d = correlate2d(self.CMD_data, signal) - y, x = np.unravel_index(np.argmax(ccor2d), ccor2d.shape) - self.x_shift = (x - len(ccor2d[0]) / 2.0) * (self.x_edges[1] - self.x_edges[0]) - self.y_shift = (y - len(ccor2d) / 2.0) * (self.y_edges[1] - self.y_edges[0]) - - def make_poly(self, iso_low, iso_high, maxmag=26, minmag=14): - """ - Generate the CMD polygon mask. - - ------------------------------------------------------------------ - - Parameters: - iso_low: spline function describing the "left" bound of the theorietical isochrone - iso_high: spline function describing the "right" bound of the theoretical isochrone - maxmag: faint limit of theoretical isochrone, should be deeper than all data - minmag: bright limit of theoretical isochrone, either include just MS and subgiant branch or whole isochrone - - Returns: - cmd_poly: Polygon vertices in CMD space. - cmd_mask: Boolean mask in CMD sapce. - - """ - - mag_vals = np.arange(minmag, maxmag, 0.01) - col_low_vals = iso_low(mag_vals) - col_high_vals = iso_high(mag_vals) - - cmd_poly = np.concatenate( - [ - np.array([col_low_vals, mag_vals]).T, - np.flip(np.array([col_high_vals, mag_vals]).T, axis=0), - ] - ) - cmd_footprint = Footprint2D(cmd_poly, footprint_type="cartesian") - - cmd_points = np.vstack( - ( - self.cat[self.data_color1] - self.cat[self.data_color2], - self.cat[self.data_mag] - self.dist_mod_correct, - ) - ).T - cmd_mask = cmd_footprint.inside_footprint(cmd_points) - - return cmd_footprint, cmd_mask - - def get_tolerance(self, scale_err=1, base_tol=0.075): - """ - Convolving errors to create wider selections near mag limit - Code written by Nora Shipp and adapted by Kiyan Tavangar - """ - if self.phot_survey == "PS1": - err = lambda x: 0.00363355415 + np.exp((x - 23.9127145) / 1.09685211) - elif self.phot_survey == "DES_DR2": - # from DES_DR1 in Nora's code (I think should apply here as well) - err = lambda x: 0.0010908679647672335 + np.exp( - (x - 27.091072029215375) / 1.0904624484538419 - ) - else: - # assume PS1 while I wait for Gaia photometry - err = lambda x: 0.00363355415 + np.exp((x - 23.9127145) / 1.09685211) - # err=lambda x: 0*x - - return scale_err * err(self.mag) + base_tol - - def simpleSln(self, maxmag=22, scale_err=2, mass_thresh=0.80): - """ - Select the stars that are within the CMD polygon cut - -------------------------------- - Parameters: - - maxmag: faint limit of created CMD polygon, should be deeper than all data - - mass_thresh: upper limit for the theoretical mass that dictates the bright limit of the - theoretical isochrone used for polygon - - coloff: shift in color from theoretical isochrone to data - - magoff: shift in magnitude from theoretical isochrone to data - - Returns: - - cmd_poly: vertices of the CMD polygon cut - - cmd_mask: bitmask of stars that pass the polygon cut - - iso_model: the theoretical isochrone after shifts - - iso_low: the "left" bound of the CMD polygon cut made from theoretical isochrone - - iso_high: the "right" bound of the CMD polygon cut made from theoretical isochrone - """ - - coloff = self.x_shift - magoff = self.y_shift - ind = self.masses < mass_thresh - - tol = self.get_tolerance(scale_err)[ind] - - iso_low = interp1d( - self.mag[ind] + magoff, - self.color[ind] + coloff - tol, - fill_value="extrapolate", - ) - iso_high = interp1d( - self.mag[ind] + magoff, - self.color[ind] + coloff + tol, - fill_value="extrapolate", - ) - # iso_model = interp1d( - # self.mag[ind] + magoff, self.color[ind] + coloff, fill_value="extrapolate" - # ) - - hb_print, self.hb_mask = self.make_hb_print() - - cmd_footprint, self.cmd_mask = self.make_poly( - iso_low, iso_high, maxmag, minmag=self.turnoff - ) - - # self.pawprint.cmd_filters = ... need to specify this since g vs g-r is a specific choice - # self.pawprint.add_cmd_footprint(cmd_footprint, 'g_r', 'g', 'cmdprint') - self.pawprint.cmdprint = cmd_footprint - self.pawprint.hbprint = hb_print - - # self.pawprint.save_pawprint(...) - - return cmd_footprint, self.cmd_mask, hb_print, self.hb_mask, self.pawprint - - def make_hb_print(self): - # probably want to incorporate this into cmdprint and have two discontinuous regions - if self.phot_survey == "PS1": - if self.band2 == "i": - g_i_0 = np.array([-0.9, -0.6, -0.2, 0.45, 0.6, -0.6, -0.9]) - Mg_ps1 = np.array([3.3, 3.3, 0.9, 1.25, 0.4, 0.1, 3.3]) + self.dist_mod - - hb_poly = np.vstack((g_i_0, Mg_ps1)).T - hb_footprint = Footprint2D(hb_poly, footprint_type="cartesian") - hb_points = np.vstack( - ( - self.cat[self.data_color1] - self.cat[self.data_color2], - self.cat[self.data_mag] - self.dist_mod_correct, - ) - ).T - hb_mask = hb_footprint.inside_footprint(hb_points) - - elif self.band2 == "r": - # g-r, g panstarss bands - g_r_0 = np.array([-0.5, -0.3, -0.1, 0.35, 0.45, -0.35, -0.5]) - Mg_ps1 = np.array([3.3, 3.3, 1.0, 1.2, 0.4, 0.1, 3.3]) + self.dist_mod - - hb_poly = np.vstack((g_r_0, Mg_ps1)).T - hb_footprint = Footprint2D(hb_poly, footprint_type="cartesian") - hb_points = np.vstack( - ( - self.cat[self.data_color1] - self.cat[self.data_color2], - self.cat[self.data_mag] - self.dist_mod_correct, - ) - ).T - hb_mask = hb_footprint.inside_footprint(hb_points) - - elif self.phot_survey == "Gaia": - bp_rp_0 = np.array([-0.5, -0.2, 0.15, 0.85, 0.9, -0.1, -0.5]) - Mv = np.array([3.3, 3.3, 1.05, 0.8, -0.0, 0.4, 3.3]) + self.dist_mod - - hb_poly = np.vstack( - (bp_rp_0, Mv) - ).T # doesn't take into account distance gradient - hb_footprint = Footprint2D(hb_poly, footprint_type="cartesian") - hb_points = np.vstack( - ( - self.cat[self.data_color1] - self.cat[self.data_color2], - self.cat[self.data_mag] - self.dist_mod_correct, - ) - ).T - hb_mask = hb_footprint.inside_footprint(hb_points) - - elif self.phot_survey == "des": - # don't select any points until we get the des polygon - g_r_0 = np.array([-0.5, -0.5]) - Mg_des = np.array([3.3, 3.3]) + self.dist_mod - - # g_r_0 = np.array([-0.5, -0.2, 0.15, 0.85, 0.9, -0.1, -0.5]) - # Mg_des = np.array([3.3, 3.3, 1.05, 0.8, -0.0, 0.4, 3.3]) - - hb_poly = np.vstack( - (g_r_0, Mg_des) - ).T # doesn't take into account distance gradient - hb_footprint = Footprint2D(hb_poly, footprint_type="cartesian") - hb_points = np.vstack( - ( - self.cat[self.data_color1] - self.cat[self.data_color2], - self.cat[self.data_mag] - self.dist_mod_correct, - ) - ).T - hb_mask = hb_footprint.inside_footprint(hb_points) - - return hb_footprint, hb_mask - - def plot_CMD(self, scale_err=2): - """ - Plot the shifted isochrone over a 2D histogram of the polygon-selected - data. - - Returns matplotlib Figure. - - WANT TO PLOT THE ACTUAL POLYGON USED AS WELL - """ - if self.pawprint.pm1print is not None: - cat = self.cat[self.on_pm12mask & self.on_skymask] - # cat = self.cat[self.on_pm12mask] - else: - cat = self.cat[self.on_pmmask & self.on_skymask] - # cat = self.cat[self.on_pmmask] - color = self.color + self.x_shift - mag = self.mag + self.y_shift - # mass_pdf = self.masses - bins = (np.linspace(-0.5, 1.5, 128), np.linspace(10, 22.5, 128)) - - fig = plt.figure(figsize=(8, 8)) - ax = fig.add_subplot(111) - - ax.hist2d( - cat[self.data_color1] - cat[self.data_color2], - cat[self.data_mag], - bins=bins, - norm=mpl.colors.LogNorm(), - zorder=5, - ) - ax.plot( - color, - mag, - color="k", - ls="--", - zorder=10, - ) - - ax.plot( - color - self.get_tolerance(scale_err), - mag, - color="b", - ls="--", - zorder=10, - ) - ax.plot( - color + self.get_tolerance(scale_err), - mag, - color="b", - ls="--", - zorder=10, - ) - - patch_cmd = PathPatch( - self.pawprint.cmdprint.footprint, - facecolor="none", - edgecolor="red", - linewidth=3, - zorder=10, - ) - ax.add_patch(patch_cmd) - - patch_hb = PathPatch( - self.pawprint.hbprint.footprint, - facecolor="none", - edgecolor="red", - linewidth=3, - zorder=10, - ) - ax.add_patch(patch_hb) - - ax.set_xlabel( - f"{self.band1}-{self.band2}", - fontsize=20, - ) - ax.set_ylabel( - f"{self.band1}", - fontsize=20, - ) - - ax.set_ylim(21, np.min(mag)) - ax.set_xlim(-0.5, 1.5) - - return fig - - def convolve_1d(self, probabilities, mag_err): - """ - 1D Gaussian convolution. - - ------------------------------------------------------------------ - - Parameters: - probabilities: - mag_err: Uncertainty in the magnitudes. - - """ - self.probabilities = probabilities - self.mag_err = mag_err - - sigma = mag_err / self.ybin # error in pixel units - kernel = Gaussian1DKernel(sigma) - convolved = convolve(probabilities, kernel) - - self.convolved = convolved - - def convolve_errors(self, g_errors, r_errors, intr_err=0.1): - """ - - 1D Gaussian convolution of the data with uncertainties. - - ------------------------------------------------------------------ - - Parameters: - g_errors: g magnitude uncertainties. - r_errors: r magnitude uncertainties. - intr_err: Free to set. Default is 0.1. - - """ - - for i in range(len(probabilities)): - probabilities[i] = convolve_1d( - probabilities[i], - np.sqrt( - g_errors(self.x_bins[i]) ** 2 - + r_errors(self.y_bins[i]) ** 2 - + intr_err**2 - ), - sel.fx_bins[1] - self.x_bins[0], - ) - - self.probabilities = probabilities - - def errFn(self): - """ - Generate the errors for the magnitudes? - """ - - gerrs = np.zeros(len(self.y_bins)) - rerrs = np.zeros(len(self.x_bins)) - - for i in range(len(self.y_bins)): - gerrs[i] = np.nanmedian( - self.cat["g0"][abs(self.cat["g0"] - self.y_bins[i]) < self.ybin / 2] - ) - rerrs[i] = np.nanmedian( - self.cat["r0"][abs(self.cat["g0"] - self.x_bins[i]) < self.xbin / 2] - ) - - gerrs = interp1d(self.y_bins, gerrs, fill_value="extrapolate") - rerrs = interp1d(self.x_bins, rerrs, fill_value="extrapolate") - - self.gerrs = gerrs - self.rerrs = rerrs diff --git a/src/cats/cmd/__init__.py b/src/cats/cmd/__init__.py new file mode 100644 index 0000000..8ddfe6e --- /dev/null +++ b/src/cats/cmd/__init__.py @@ -0,0 +1,7 @@ +"""CMD functions.""" + +from __future__ import annotations + +__all__ = ["Isochrone"] + +from ._core import Isochrone diff --git a/src/cats/CMD.py b/src/cats/cmd/_core.py similarity index 62% rename from src/cats/CMD.py rename to src/cats/cmd/_core.py index 52a8a8e..728357f 100644 --- a/src/cats/CMD.py +++ b/src/cats/cmd/_core.py @@ -2,87 +2,82 @@ from __future__ import annotations +from typing import TYPE_CHECKING, Any + +import astropy.units as u import matplotlib as mpl import matplotlib.pyplot as plt import numpy as np import scipy -from isochrones.mist import MIST_Isochrone +from astropy.coordinates import Distance +from isochrones.mist import MIST_Isochrone # pylint: disable=import-error from matplotlib.patches import PathPatch from scipy.interpolate import InterpolatedUnivariateSpline, interp1d from scipy.signal import correlate2d from ugali.analysis.isochrone import factory as isochrone_factory from cats.inputs import stream_inputs as inputs -from cats.pawprint.pawprint import Footprint2D - -__authors__ = "Ani, Kiyan, Richard" +from cats.pawprint._footprint import Footprint2D -plt.rc( - "xtick", - top=True, - direction="in", - labelsize=15, -) -plt.rc( - "ytick", - right=True, - direction="in", - labelsize=15, -) -plt.rc( - "font", - family="Arial", -) +if TYPE_CHECKING: + from matplotlib.figure import Figure + from cats.pawprint import Pawprint -class Isochrone: - def __init__(self, stream, cat, pawprint): - """ - Defining variables loaded into class. +__authors__ = "Ani, Kiyan, Richard" - ------------------------------------------------------------------ +plt.rc("xtick", top=True, direction="in", labelsize=15) +plt.rc("ytick", right=True, direction="in", labelsize=15) +plt.rc("font", family="Arial") - Parameters: - cat = Input catalogue. - age = Input age of stream from galstreams and/or literature. - feh = Input metallicity from galstreams and/or literature. - distance = Input distance from galstreams and/or literature. - alpha = alpha/Fe - pawprint = Stream multidimensional footprint - """ - # Pull survey from catalog? - self.stream = stream +class Isochrone: + """Isochrone class for CMD selection. + + Parameters + ---------- + name : str + The name of the stream. Parameters must be registered in ``inputs``. + cat: + Input catalogue. + pawprint: + Stream multidimensional footprint. + """ + + def __init__(self, name: str, /, cat: Any, pawprint: Pawprint) -> None: + self.stream = name self.cat = cat - self.age = inputs[stream]["age"] - self.feh = inputs[stream]["feh"] - self.distance = inputs[stream]["distance"] # kpc - self.alpha = inputs[stream]["alpha"] - self.dist_mod = 5 * np.log10(1000 * self.distance) - 5 - self.pawprint = pawprint + + params = inputs[name] + self.age = params["age"] + self.feh = params["feh"] + self.alpha = params["alpha"] + self.distance = params["distance"] + track = self.pawprint.track.track.transform_to(self.pawprint.track.stream_frame) if self.stream == "GD-1": - distmod_spl = np.poly1d([2.41e-4, 2.421e-2, 15.001]) - self.dist_mod_correct = distmod_spl(self.cat["phi1"]) - self.dist_mod + distmod_spl = np.poly1d([2.41e-4, 2.421e-2, 15.001]) # [deg] -> [mag] + self.dist_mod_correct = ( + distmod_spl(self.cat["phi1"]) - self.distance.distmod + ) else: - spline_dist = InterpolatedUnivariateSpline( - track.phi1.value, track.distance.value + spline_dist = InterpolatedUnivariateSpline( # [deg] -> [kpc] + track.phi1.to_value(u.deg), track.distance.to_value(u.kpc) ) - self.dist_mod_correct = ( - 5 * np.log10(spline_dist(self.cat["phi1"]) * 1000) - 5 - ) - self.dist_mod + delta_distmod = Distance(spline_dist(self.cat["phi1"]), unit=u.kpc).distmod + self.dist_mod_correct = delta_distmod - self.dist_mod self.x_shift = 0 self.y_shift = 0 - self.phot_survey = inputs[self.stream]["phot_survey"] - self.band1 = inputs[self.stream]["band1"] - self.band2 = inputs[self.stream]["band2"] - self.data_mag = inputs[self.stream]["mag"] - self.data_color1 = inputs[self.stream]["color1"] - self.data_color2 = inputs[self.stream]["color2"] - self.turnoff = inputs[self.stream]["turnoff"] + self.phot_survey = params["phot_survey"] + self.band1 = params["band1"] + self.band2 = params["band2"] + self.data_mag = params["mag"] + self.data_color1 = params["color1"] + self.data_color2 = params["color2"] + self.turnoff = params["turnoff"] self.generate_isochrone() self.sel_sky() @@ -92,15 +87,12 @@ def __init__(self, stream, cat, pawprint): self.data_cmd() if self.pawprint.pm1print is not None: - # Only shift isochrone is the previous cuts are clean enough - # Otherwise it will just shift to the background + # Only shift isochrone is the previous cuts are clean enough, + # otherwise it will just shift to the background self.correct_isochrone() - def sel_sky(self): - """ - Initialising the on-sky polygon mask to return only contained sources. - """ - + def sel_sky(self) -> None: + """Initialize the on-sky polygon mask.""" on_poly_patch = mpl.patches.Polygon( self.pawprint.skyprint["stream"].vertices[::50], facecolor="none", @@ -108,31 +100,17 @@ def sel_sky(self): linewidth=2, ) on_points = np.vstack((self.cat["phi1"], self.cat["phi2"])).T - on_mask = on_poly_patch.get_path().contains_points(on_points) - - # on_points = np.vstack((self.cat["phi1"], self.cat["phi2"])).T - # on_mask = self.pawprint.skyprint['stream'].inside_footprint(on_points) #very slow because skyprint is very large - - self.on_skymask = on_mask - - def sel_pm(self): - """ - Initialising the proper motions polygon mask to return only contained sources. - """ + self.on_skymask = on_poly_patch.get_path().contains_points(on_points) + def sel_pm(self) -> None: + """Initialize the proper motions polygon mask.""" on_points = np.vstack( (self.cat["pm_phi1_cosphi2_unrefl"], self.cat["pm_phi2_unrefl"]) ).T + self.on_pmmask = self.pawprint.pmprint.inside_footprint(on_points) - on_mask = self.pawprint.pmprint.inside_footprint(on_points) - - self.on_pmmask = on_mask - - def sel_pm12(self): - """ - Initialising the proper motions polygon mask to return only contained sources. - """ - + def sel_pm12(self) -> None: + """Initialize the proper motions polygon mask.""" on_pm1_points = np.vstack( (self.cat["phi1"], self.cat["pm_phi1_cosphi2_unrefl"]) ).T @@ -145,11 +123,8 @@ def sel_pm12(self): self.on_pm2mask = on_pm2_mask self.on_pm12mask = on_pm1_mask & on_pm2_mask - def generate_isochrone(self): - """ - load an isochrone, LF model for a given metallicity, age, distance - """ - + def generate_isochrone(self) -> None: + """Load an isochrone, LF model for a given metallicity, age, distance.""" # Convert feh to z Y_p = 0.245 # Primordial He abundance (WMAP, 2003) c = 1.54 # He enrichment ratio @@ -159,21 +134,20 @@ def generate_isochrone(self): if self.phot_survey == "Gaia": mist = MIST_Isochrone() iso = mist.isochrone( - age=np.log10(1e9 * self.age), # has to be given in logAge + age=9 + np.log10(self.age.to_value(u.Gyr)), # log(age [yr]) feh=self.feh, eep_range=None, # get the whole isochrone, - distance=1e3 * self.distance, # given in parsecs + distance=self.distance.to_value(u.pc), ) - initial_mass, actual_mass = iso.initial_mass.values, iso.mass.values mag = iso.G_mag.values color_1 = iso.BP_mag.values color_2 = iso.RP_mag.values # Excise the horizontal branch turn_idx = scipy.signal.argrelextrema(iso.G_mag.values, np.less)[0][0] - initial_mass = initial_mass[0:turn_idx] - actual_mass = actual_mass[0:turn_idx] + initial_mass = iso.initial_mass.values[0:turn_idx] + actual_mass = iso.mass.values[0:turn_idx] self.masses = actual_mass self.mag = mag[0:turn_idx] @@ -221,18 +195,20 @@ def generate_isochrone(self): self.mmag_2 = mmag_2 self.mmass_pdf = mmass_pdf - # return iso, initial_mass, mass_pdf, actual_mass, mag_1, mag_2, mmag_1, mmag_2, \ - # mmass_pdf + def data_cmd( + self, + xrange: tuple[float, float] = (-0.5, 1.0), + yrange: tuple[float, float] = (15, 22), + ) -> None: + """Make Empirical CMD. - def data_cmd(self, xrange=(-0.5, 1.0), yrange=(15, 22)): - """ - Empirical CMD generated from the input catalogue, with distance gradient accounted for. + Empirical CMD generated from the input catalogue, with distance gradient + accounted for. - ------------------------------------------------------------------ - - Parameters: - xrange: Set the range of color values. Default is [-0.5, 1.0]. - yrange: Set the range of magnitude values. Default is [15, 22]. + Parameters + ---------- + xrange, yrange: tuple[float, float] + Set the range of color values. """ tab = self.cat x_bins = np.arange( @@ -242,7 +218,8 @@ def data_cmd(self, xrange=(-0.5, 1.0), yrange=(15, 22)): yrange[0], yrange[1], inputs[self.stream]["bin_sizes"][1] ) # Used 0.2 for Jhelum - # if this is the second runthrough and a proper motion mask already exists, use that instead of the rough one + # if this is the second runthrough and a proper motion mask already + # exists, use that instead of the rough one if self.pawprint.pm1print is not None: data, xedges, yedges = np.histogram2d( (tab[self.data_color1] - tab[self.data_color2])[ @@ -270,21 +247,21 @@ def data_cmd(self, xrange=(-0.5, 1.0), yrange=(15, 22)): self.y_edges = yedges self.CMD_data = data.T - def correct_isochrone(self): - """ + def correct_isochrone(self) -> None: + """Correct the isochrone. + Correlate the 2D histograms from the data and the theoretical isochrone to find the shift in color and magnitude necessary for the best match """ - - signal, xedges, yedges = np.histogram2d( + signal, *_ = np.histogram2d( self.color, self.mag, bins=[self.x_edges, self.y_edges], weights=np.ones(len(self.mag)), ) - signal_counts, xedges, yedges = np.histogram2d( + signal_counts, *_ = np.histogram2d( self.color, self.mag, bins=[self.x_edges, self.y_edges] ) signal = signal / signal_counts @@ -292,28 +269,42 @@ def correct_isochrone(self): signal = signal.T ccor2d = correlate2d(self.CMD_data, signal) - y, x = np.unravel_index(np.argmax(ccor2d), ccor2d.shape) + y, x = np.unravel_index( + np.argmax(ccor2d), ccor2d.shape + ) # pylint: disable=W0632 self.x_shift = (x - len(ccor2d[0]) / 2.0) * (self.x_edges[1] - self.x_edges[0]) self.y_shift = (y - len(ccor2d) / 2.0) * (self.y_edges[1] - self.y_edges[0]) - def make_poly(self, iso_low, iso_high, maxmag=26, minmag=14): + def make_poly( + self, + iso_low: InterpolatedUnivariateSpline, + iso_high: InterpolatedUnivariateSpline, + maxmag: float = 26, + minmag: float = 14, + ) -> tuple[Any, Any]: + """Generate the CMD polygon mask. + + Parameters + ---------- + iso_low: InterpolatedUnivariateSpline + spline function describing the "left" bound of the theorietical + isochrone + iso_high: InterpolatedUnivariateSpline + spline function describing the "right" bound of the theoretical + isochrone + maxmag: float + faint limit of theoretical isochrone, should be deeper than all data + minmag: float + bright limit of theoretical isochrone, either include just MS and + subgiant branch or whole isochrone + + Returns + ------- + cmd_poly : NDArray + Polygon vertices in CMD space. + cmd_mask : NDArray[bool] + Boolean mask in CMD sapce. """ - Generate the CMD polygon mask. - - ------------------------------------------------------------------ - - Parameters: - iso_low: spline function describing the "left" bound of the theorietical isochrone - iso_high: spline function describing the "right" bound of the theoretical isochrone - maxmag: faint limit of theoretical isochrone, should be deeper than all data - minmag: bright limit of theoretical isochrone, either include just MS and subgiant branch or whole isochrone - - Returns: - cmd_poly: Polygon vertices in CMD space. - cmd_mask: Boolean mask in CMD sapce. - - """ - mag_vals = np.arange(minmag, maxmag, 0.01) col_low_vals = iso_low(mag_vals) col_high_vals = iso_high(mag_vals) @@ -336,44 +327,58 @@ def make_poly(self, iso_low, iso_high, maxmag=26, minmag=14): return cmd_footprint, cmd_mask - def get_tolerance(self, scale_err=1, base_tol=0.075): - """ - Convolving errors to create wider selections near mag limit - Code written by Nora Shipp and adapted by Kiyan Tavangar + def get_tolerance(self, scale_err: float = 1, base_tol: float = 0.075) -> float: + """Convolving errors to create wider selections near mag limit. + + .. codeauthor:: + Nora Shipp, Kiyan Tavangar """ if self.phot_survey == "PS1": - err = lambda x: 0.00363355415 + np.exp((x - 23.9127145) / 1.09685211) + offset = 0.00363355415 + mu = 23.9127145 + scale = 1.09685211 + elif self.phot_survey == "DES_DR2": # from DES_DR1 in Nora's code (I think should apply here as well) - err = lambda x: 0.0010908679647672335 + np.exp( - (x - 27.091072029215375) / 1.0904624484538419 - ) + offset = 0.0010908679647672335 + mu = 27.091072029215375 + scale = 1.0904624484538419 + else: # assume PS1 while I wait for Gaia photometry - err = lambda x: 0.00363355415 + np.exp((x - 23.9127145) / 1.09685211) - # err=lambda x: 0*x + offset = 0.00363355415 + mu = 23.9127145 + scale = 1.09685211 + + def err(x: float) -> float: + return offset + np.exp((x - mu) / scale) return scale_err * err(self.mag) + base_tol - def simpleSln(self, maxmag=22, scale_err=2, mass_thresh=0.80): - """ - Select the stars that are within the CMD polygon cut - -------------------------------- - Parameters: - - maxmag: faint limit of created CMD polygon, should be deeper than all data - - mass_thresh: upper limit for the theoretical mass that dictates the bright limit of the - theoretical isochrone used for polygon - - coloff: shift in color from theoretical isochrone to data - - magoff: shift in magnitude from theoretical isochrone to data - - Returns: + def simpleSln( + self, maxmag: float = 22, scale_err: float = 2, mass_thresh: float = 0.80 + ) -> tuple[Any, Any, Any, Any]: + """Select the stars that are within the CMD polygon cut. + + Parameters + ---------- + maxmag: float + faint limit of created CMD polygon, should be deeper than all data + scale_err : float + TODO. + mass_thresh : float + TODO. + + Returns + ------- - cmd_poly: vertices of the CMD polygon cut - cmd_mask: bitmask of stars that pass the polygon cut - iso_model: the theoretical isochrone after shifts - - iso_low: the "left" bound of the CMD polygon cut made from theoretical isochrone - - iso_high: the "right" bound of the CMD polygon cut made from theoretical isochrone + - iso_low: the "left" bound of the CMD polygon cut made from theoretical + isochrone + - iso_high: the "right" bound of the CMD polygon cut made from + theoretical isochrone """ - coloff = self.x_shift magoff = self.y_shift ind = self.masses < mass_thresh @@ -390,9 +395,6 @@ def simpleSln(self, maxmag=22, scale_err=2, mass_thresh=0.80): self.color[ind] + coloff + tol, fill_value="extrapolate", ) - # iso_model = interp1d( - # self.mag[ind] + magoff, self.color[ind] + coloff, fill_value="extrapolate" - # ) hb_print, self.hb_mask = self.make_hb_print() @@ -400,8 +402,6 @@ def simpleSln(self, maxmag=22, scale_err=2, mass_thresh=0.80): iso_low, iso_high, maxmag, minmag=self.turnoff ) - # self.pawprint.cmd_filters = ... need to specify this since g vs g-r is a specific choice - # self.pawprint.add_cmd_footprint(cmd_footprint, 'g_r', 'g', 'cmdprint') self.pawprint.cmdprint = cmd_footprint self.pawprint.hbprint = hb_print @@ -409,8 +409,10 @@ def simpleSln(self, maxmag=22, scale_err=2, mass_thresh=0.80): return cmd_footprint, self.cmd_mask, hb_print, self.hb_mask, self.pawprint - def make_hb_print(self): - # probably want to incorporate this into cmdprint and have two discontinuous regions + def make_hb_print(self) -> None: + """Make the horizontal branch polygon mask.""" + # probably want to incorporate this into cmdprint and have two + # discontinuous regions if self.phot_survey == "PS1": if self.band2 == "i": g_i_0 = np.array([-0.9, -0.6, -0.2, 0.45, 0.6, -0.6, -0.9]) @@ -479,10 +481,10 @@ def make_hb_print(self): return hb_footprint, hb_mask - def plot_CMD(self, scale_err=2): - """ - Plot the shifted isochrone over a 2D histogram of the polygon-selected - data. + def plot_CMD(self, scale_err: float = 2) -> Figure: + """Plot the shifted isochrone. + + Over a 2D histogram of the polygon-selected data. Returns matplotlib Figure. @@ -564,69 +566,67 @@ def plot_CMD(self, scale_err=2): return fig - def convolve_1d(self, probabilities, mag_err): - """ - 1D Gaussian convolution. - - ------------------------------------------------------------------ - - Parameters: - probabilities: - mag_err: Uncertainty in the magnitudes. - - """ - self.probabilities = probabilities - self.mag_err = mag_err - - sigma = mag_err / self.ybin # error in pixel units - kernel = Gaussian1DKernel(sigma) - convolved = convolve(probabilities, kernel) - - self.convolved = convolved - - def convolve_errors(self, g_errors, r_errors, intr_err=0.1): - """1D Gaussian convolution of the data with uncertainties. - - ------------------------------------------------------------------ - - Parameters: - g_errors: g magnitude uncertainties. - r_errors: r magnitude uncertainties. - intr_err: Free to set. Default is 0.1. - - """ - - for i in range(len(probabilities)): - probabilities[i] = convolve_1d( - probabilities[i], - np.sqrt( - g_errors(self.x_bins[i]) ** 2 - + r_errors(self.y_bins[i]) ** 2 - + intr_err**2 - ), - sel.fx_bins[1] - self.x_bins[0], - ) - - self.probabilities = probabilities - - def errFn(self): - """ - Generate the errors for the magnitudes? - """ - - gerrs = np.zeros(len(self.y_bins)) - rerrs = np.zeros(len(self.x_bins)) - - for i in range(len(self.y_bins)): - gerrs[i] = np.nanmedian( - self.cat["g0"][abs(self.cat["g0"] - self.y_bins[i]) < self.ybin / 2] - ) - rerrs[i] = np.nanmedian( - self.cat["r0"][abs(self.cat["g0"] - self.x_bins[i]) < self.xbin / 2] - ) - - gerrs = interp1d(self.y_bins, gerrs, fill_value="extrapolate") - rerrs = interp1d(self.x_bins, rerrs, fill_value="extrapolate") - - self.gerrs = gerrs - self.rerrs = rerrs + # def convolve_1d(self, probabilities: NDArray, mag_err: NDArray) -> NDArray: + # """1D Gaussian convolution. + + # Parameters + # ---------- + # probabilities : NDArray + # Probability of the magnitudes. + # mag_err : NDArray + # Uncertainty in the magnitudes. + # """ + # self.probabilities = probabilities + # self.mag_err = mag_err + + # sigma = mag_err / self.ybin # error in pixel units + # kernel = Gaussian1DKernel(sigma) + # convolved = convolve(probabilities, kernel) + + # self.convolved = convolved + + # def convolve_errors( + # self, + # probabilities: NDArray, + # g_errors: Callable[[NDArray], NDArray], + # r_errors: Callable[[NDArray], NDArray], + # intr_err: float = 0.1, + # ) -> None: + # """1D Gaussian convolution of the data with uncertainties. + + # Parameters + # ---------- + # probabilities : NDArray + # Probability of the magnitudes. + # g_errors, r_errors : Callable[[ndarray], ndarray] + # g, r magnitude uncertainties. + # intr_err: + # Free to set. Default is 0.1. + # """ + # for i in range(len(probabilities)): + # probabilities[i] = self.convolve_1d( + # probabilities[i], + # np.sqrt( + # g_errors(self.x_bins[i]) ** 2 + # + r_errors(self.y_bins[i]) ** 2 + # + intr_err**2 + # ), + # self.fx_bins[1] - self.x_bins[0], + # ) + + # self.probabilities = probabilities + + # TODO: remove this function? + # def errFn(self) -> None: + # """Generate the errors for the magnitudes.""" + # g0 = self.cat["g0"] + # r0 = self.cat["r0"] + # yhb = self.ybin / 2 # half bin size + # xhb = self.xbin / 2 + # gerrs = np.array([np.nanmedian(g0[abs(g0 - yb) < yhb]) for yb in self.y_bins]) + # rerrs = np.array( # TODO: are we sure this is right? + # [np.nanmedian(r0[abs(g0 - xb) < xhb]) for xb in self.x_bins] + # ) + + # self.gerrs = interp1d(self.y_bins, gerrs, fill_value="extrapolate") + # self.rerrs = interp1d(self.x_bins, rerrs, fill_value="extrapolate") diff --git a/src/cats/coords.py b/src/cats/coords.py deleted file mode 100644 index 3139133..0000000 --- a/src/cats/coords.py +++ /dev/null @@ -1,11 +0,0 @@ -""" -Assumptions and helpers for working with coordinates -""" -from __future__ import annotations - -import astropy.coordinates as coord -import astropy.units as u - -galcen_frame = coord.Galactocentric( - galcen_distance=8.275 * u.kpc, galcen_v_sun=[8.4, 251.8, 8.4] * u.km / u.s -) diff --git a/src/cats/data.py b/src/cats/data.py index 9c8967d..019b1f7 100644 --- a/src/cats/data.py +++ b/src/cats/data.py @@ -1,60 +1,84 @@ +"""Helper functions for working with data.""" + from __future__ import annotations -import astropy.table as at +__all__ = ["make_astro_photo_joined_data"] + +from typing import TYPE_CHECKING + import astropy.units as u import gala.coordinates as gc import numpy as np -import scipy.interpolate as sci +from astropy.table import hstack, join, unique +from scipy.interpolate import InterpolatedUnivariateSpline +if TYPE_CHECKING: + from astropy.table import QTable + from galstreams import Track6D + from pyia import GaiaData + + from cats.photometry._base import PhotometricSurvey + + +def make_astro_photo_joined_data( + gaia_data: GaiaData, phot_data: PhotometricSurvey, track6d: Track6D +) -> QTable: + """Join Gaia and photometry data, and transform to stream coordinates. -def make_astro_photo_joined_data(gaia_data, phot_data, track6d): - """ Parameters ---------- gaia_data : `pyia.GaiaData` + The Gaia data. phot_data : `cats.photometry.PhotometricSurvey` + The photometry data. track6d : `galstreams.Track6D` + The stream track. + Returns + ------- + joined : `astropy.table.Table` + Joined table of Gaia and photometry data, transformed to stream coordinates. """ - stream_fr = track6d.stream_frame + stream_frame = track6d.stream_frameame + + # -------------------------------------------- + # Process the track - track = track6d.track.transform_to(stream_fr) + track = track6d.track.transform_to(stream_frame) if np.all(track.distance.value == 0): - raise ValueError( - "A distance track is required: this stream has no distance information in " - "the galstreams track." - ) + msg = "A distance track is required: this stream has no distance information." + raise ValueError(msg) - # interpolator to get predicted distance from phi1 - dist_interp = sci.InterpolatedUnivariateSpline( + # Interpolator to get predicted distance from phi1 + dist_interp = InterpolatedUnivariateSpline( track.phi1.degree, track.distance.value, k=1 ) - # get stream coordinates for all stars, and reflex correct with predicted distance - _c_tmp = gaia_data.get_skycoord(distance=False).transform_to(stream_fr) + # Get stream coordinates for all stars, and reflex correct with predicted distance + _c_tmp = gaia_data.get_skycoord(distance=False).transform_to(stream_frame) c = gaia_data.get_skycoord( distance=dist_interp(_c_tmp.phi1.degree) * track.distance.unit, radial_velocity=0 * u.km / u.s, ) - c_stream = c.transform_to(stream_fr) + c_stream = c.transform_to(stream_frame) c_stream_refl = gc.reflex_correct(c_stream) - # get extinction-corrected photometry and star/galaxy mask + # -------------------------------------------- + + # Get extinction-corrected photometry and star/galaxy mask ext = phot_data.get_ext_corrected_phot() ext["star_mask"] = phot_data.get_star_mask() - # start building the final joined table + # Start building the final joined table joined = gaia_data.data.copy() - for name in ["phi1", "phi2", "pm_phi1_cosphi2", "pm_phi2"]: + for name in ("phi1", "phi2"): joined[name] = getattr(c_stream_refl, name) - if name not in ["phi1", "phi2"]: - joined[f"{name}_unrefl"] = getattr(c_stream, name) + for name in ("pm_phi1_cosphi2", "pm_phi2"): + joined[name] = getattr(c_stream_refl, name) + joined[f"{name}_unrefl"] = getattr(c_stream, name) - phot_full = at.hstack([phot_data.data, ext]) + phot_full = hstack([phot_data.data, ext]) cols = ["source_id", "star_mask"] + [b for b in ext.colnames if b.endswith("0")] phot_min = phot_full[cols] - joined = at.join(joined, phot_min, keys="source_id") - joined = at.unique(joined, keys="source_id") - - return joined + return unique(join(joined, phot_min, keys="source_id"), keys="source_id") diff --git a/src/cats/frames.py b/src/cats/frames.py new file mode 100644 index 0000000..f134cb0 --- /dev/null +++ b/src/cats/frames.py @@ -0,0 +1,13 @@ +"""Default Coordinate Frames.""" + +from __future__ import annotations + +__all__ = ["galactocentric"] + +import astropy.coordinates as coord +import astropy.units as u + +galactocentric = coord.Galactocentric( + galcen_distance=8.275 * u.kpc, + galcen_v_sun=[8.4, 251.8, 8.4] * u.km / u.s, +) diff --git a/src/cats/inputs.py b/src/cats/inputs.py index ebb1846..d0d40c2 100644 --- a/src/cats/inputs.py +++ b/src/cats/inputs.py @@ -1,111 +1,111 @@ +"""Stream Configuration Inputs.""" + from __future__ import annotations -from collections import OrderedDict +from typing import Any -# config file with all the inputs +import astropy.units as u +from astropy.coordinates import Distance -stream_inputs = OrderedDict( - { - "GD-1": { - # galstream stuff - "short_name": "GD-1", - "pawprint_id": "pricewhelan2018", - # stream stuff - "width": 2.0, # full width in degrees (add units in pawprint) - # data stuff - "phot_survey": "PS1", - "band1": "g", - "band2": "r", - "mag": "g0", - "color1": "g0", - "color2": "r0", - "minmag": 16.0, - "maxmag": 24.0, - # isochrone stuff - "age": 11.8, # Gyr - "feh": -1.5, - "distance": 8.3, # kpc - "turnoff": 17.8, # mag of MS turnoff - "alpha": 0, # don't think we actually use this - "scale_err": 2, - "base_err": 0.075, - "bin_sizes": [0.03, 0.2], # xbin and ybin width for CMD - }, - "Pal5": { - # galstream stuff - "short_name": "Pal5", - "pawprint_id": "pricewhelan2019", - # stream stuff - "width": 1.0, # degrees (add units in pawprint) - # data stuff - "phot_survey": "PS1", - "band1": "g", - "band2": "r", - "mag": "g0", - "color1": "g0", - "color2": "r0", - "minmag": 16.0, - "maxmag": 24.0, - # isochrone stuff - "age": 12, # Gyr - "feh": -1.4, - "distance": 20.9, # kpc - "turnoff": 15, # mag of MS turnoff - "alpha": 0, # don't think we actually use this - "scale_err": 2, - "base_err": 0.075, - "bin_sizes": [0.03, 0.2], - }, - "Jhelum": { - # galstream stuff - "short_name": "Jhelum-b", - "pawprint_id": "bonaca2019", - # stream stuff - "width": 2.0, # degrees (add units in pawprint) - # data stuff - "phot_survey": "des", - "band1": "g", - "band2": "r", - "mag": "g0", - "color1": "g0", - "color2": "r0", - "minmag": 16.0, - "maxmag": 24.0, - # isochrone stuff - "age": 12, # Gyr - "feh": -1.7, - "distance": 13.2, # kpc - "turnoff": 18.7, # mag of MS turnoff - "alpha": 0, # don't think we actually use this - "scale_err": 2, - "base_err": 0.075, - "bin_sizes": [0.03, 0.2], - }, - "Fjorm-M68": { - # galstream stuff - "short_name": "M68-Fjorm", - # "pawprint_id": 'ibata2021', - "pawprint_id": "palau2019", - # stream stuff - "width": 1, # TOTAL width degrees, recommend 2sigma if known - # data stuff - "phot_survey": "Gaia", - "band1": "BP", - "band2": "RP", - "mag": "G0", - "color1": "BP0", - "color2": "RP0", - "minmag": 16.0, - "maxmag": 24.0, - # isochrone stuff - "age": 11.2, # Gyr - "feh": -2.2, - "distance": 6, # kpc - "turnoff": 17.0, # mag of MS turnoff - "alpha": 0, # don't think we actually use this - "scale_err": 2, - "base_err": 0.075, - "bin_sizes": [0.03, 0.2], - }, - } -) +stream_inputs: dict[str, dict[str, Any]] = {} +stream_inputs["GD-1"] = { + # galstream stuff + "short_name": "GD-1", + "pawprint_id": "pricewhelan2018", + # stream stuff + "width": 2.0, # full width in degrees (add units in pawprint) + # data stuff + "phot_survey": "PS1", + "band1": "g", + "band2": "r", + "mag": "g0", + "color1": "g0", + "color2": "r0", + "minmag": 16.0 * u.mag, + "maxmag": 24.0 * u.mag, + # isochrone stuff + "age": 11.8 * u.Gyr, + "feh": -1.5, + "distance": Distance(8.3, u.kpc), + "turnoff": 17.8 * u.mag, # mag of MS turnoff + "alpha": 0, # don't think we actually use this + "scale_err": 2, + "base_err": 0.075, + "bin_sizes": [0.03, 0.2], # xbin and ybin width for CMD +} +stream_inputs["Pal5"] = { + # galstream stuff + "short_name": "Pal5", + "pawprint_id": "pricewhelan2019", + # stream stuff + "width": 1.0, # degrees (add units in pawprint) + # data stuff + "phot_survey": "PS1", + "band1": "g", + "band2": "r", + "mag": "g0", + "color1": "g0", + "color2": "r0", + "minmag": 16.0 * u.mag, + "maxmag": 24.0 * u.mag, + # isochrone stuff + "age": 12 * u.Gyr, + "feh": -1.4, + "distance": Distance(20.9, u.kpc), + "turnoff": 15 * u.mag, # mag of MS turnoff + "alpha": 0, # don't think we actually use this + "scale_err": 2, + "base_err": 0.075, + "bin_sizes": [0.03, 0.2], +} +stream_inputs["Jhelum"] = { + # galstream stuff + "short_name": "Jhelum-b", + "pawprint_id": "bonaca2019", + # stream stuff + "width": 2.0, # degrees (add units in pawprint) + # data stuff + "phot_survey": "des", + "band1": "g", + "band2": "r", + "mag": "g0", + "color1": "g0", + "color2": "r0", + "minmag": 16.0 * u.mag, + "maxmag": 24.0 * u.mag, + # isochrone stuff + "age": 12 * u.Gyr, + "feh": -1.7, + "distance": Distance(13.2, u.kpc), + "turnoff": 18.7 * u.mag, # mag of MS turnoff + "alpha": 0, # don't think we actually use this + "scale_err": 2, + "base_err": 0.075, + "bin_sizes": [0.03, 0.2], +} +stream_inputs["Fjorm-M68"] = { + # galstream stuff + "short_name": "M68-Fjorm", + # "pawprint_id": 'ibata2021', + "pawprint_id": "palau2019", + # stream stuff + "width": 1, # TOTAL width degrees, recommend 2sigma if known + # data stuff + "phot_survey": "Gaia", + "band1": "BP", + "band2": "RP", + "mag": "G0", + "color1": "BP0", + "color2": "RP0", + "minmag": 16.0 * u.mag, + "maxmag": 24.0 * u.mag, + # isochrone stuff + "age": 11.2 * u.Gyr, + "feh": -2.2, + "distance": Distance(6, u.kpc), + "turnoff": 17.0 * u.mag, # mag of MS turnoff + "alpha": 0, # don't think we actually use this + "scale_err": 2, + "base_err": 0.075, + "bin_sizes": [0.03, 0.2], +} diff --git a/src/cats/pawprint/__init__.py b/src/cats/pawprint/__init__.py index e69de29..9b06878 100644 --- a/src/cats/pawprint/__init__.py +++ b/src/cats/pawprint/__init__.py @@ -0,0 +1,11 @@ +"""Pawprint module.""" + +from __future__ import annotations + +from . import _core, _footprint +from ._core import * +from ._footprint import * + +__all__ = [] +__all__ += _core.__all__ +__all__ += _footprint.__all__ diff --git a/src/cats/pawprint/pawprint.py b/src/cats/pawprint/_core.py similarity index 63% rename from src/cats/pawprint/pawprint.py rename to src/cats/pawprint/_core.py index 86634f1..c6d7543 100644 --- a/src/cats/pawprint/pawprint.py +++ b/src/cats/pawprint/_core.py @@ -1,93 +1,36 @@ +"""Core classes for pawprint.""" + from __future__ import annotations +__all__ = ["Pawprint"] + import pathlib +from typing import TYPE_CHECKING, Any import asdf import astropy.table as apt import astropy.units as u import galstreams as gst -import numpy as np from astropy.coordinates import SkyCoord from gala.coordinates import GreatCircleICRSFrame -from matplotlib.path import Path as mpl_path - -# class densityClass: #TODO: how to represent densities? - - -class Footprint2D(dict): - def __init__(self, vertex_coordinates, footprint_type, stream_frame=None): - if footprint_type == "sky": - if isinstance(vertex_coordinates, SkyCoord): - vc = vertex_coordinates - else: - vc = SkyCoord(vertex_coordinates) - self.edges = vc - self.vertices = np.array( - [vc.transform_to(stream_frame).phi1, vc.transform_to(stream_frame).phi2] - ).T - - elif footprint_type == "cartesian": - self.edges = vertex_coordinates - self.vertices = vertex_coordinates - - self.stream_frame = stream_frame - self.footprint_type = footprint_type - self.footprint = mpl_path(self.vertices) - - @classmethod - def from_vertices(cls, vertex_coordinates, footprint_type): - return cls(vertex_coordinates, footprint_type) - - @classmethod - def from_box(cls, min1, max1, min2, max2, footprint_type): - vertices = cls.get_vertices_from_box(min1, max1, min2, max2) - return cls(vertices, footprint_type) - @classmethod - def from_file(cls, fname): - with apt.Table.read(fname) as t: - vertices = t["vertices"] - footprint_type = t["footprint_type"] - return cls(vertices, footprint_type) - - def get_vertices_from_box(self, min1, max1, min2, max2): - return [[min1, min2], [min1, max2], [max1, min2], [max1, max2]] - - def inside_footprint(self, data): - if isinstance(data, SkyCoord): - if self.stream_frame is None: - print("can't!") - return None - else: - pts = np.array( - [ - data.transform_to(self.stream_frame).phi1.value, - data.transform_to(self.stream_frame).phi2.value, - ] - ).T - return self.footprint.contains_points(pts) - else: - return self.footprint.contains_points(data) +from cats.pawprint._footprint import Footprint2D - def export(self): - data = {} - data["stream_frame"] = self.stream_frame - data["vertices"] = self.vertices - data["footprint_type"] = self.footprint_type - return data +if TYPE_CHECKING: + from typing_extensions import Self class Pawprint(dict): - """Dictionary class to store a "pawprint": + """Dictionary class to store a "pawprint". + polygons in multiple observational spaces that define the initial selection used for stream track modeling, membership calculation / density modeling, and background modeling. New convention: everything is in phi1 phi2 (don't cross the streams) - """ - def __init__(self, data): + def __init__(self, data: dict[str, Any]) -> None: self.stream_name = data["stream_name"] self.pawprint_ID = data["pawprint_ID"] self.stream_frame = data["stream_frame"] @@ -139,11 +82,10 @@ def __init__(self, data): self.track = data["track"] @classmethod - def from_file(cls, fname): - import asdf - + def from_file(cls: type[Self], fname: str) -> Self: + """Create a pawprint from an asdf file.""" data = {} - with asdf.open("fname") as a: + with asdf.open(fname) as a: # first transfer the stuff that goes directly data["stream_name"] = a["stream_name"] data["pawprint_ID"] = a["pawprint_ID"] @@ -175,34 +117,34 @@ def from_file(cls, fname): return cls(data) @classmethod - def pawprint_from_galstreams(cls, stream_name, pawprint_ID, width): - def _get_stream_frame_from_file(summary_file): + def pawprint_from_galstreams( + cls: type[Self], stream_name: str, pawprint_ID: Any, width: float + ) -> Self: + """Create a pawprint from galstreams data.""" + + def _get_stream_frame_from_file(summary_file: str) -> GreatCircleICRSFrame: t = apt.QTable.read(summary_file) x = {} atts = [x.replace("mid.", "") for x in t.keys() if "mid" in x] - for ( - att - ) in ( - atts - ): # we're effectively looping over skycoords defined for mid here (ra, dec, ...) - x[att] = t[f"mid.{att}"][ - 0 - ] # <- make sure to set it up as a scalar. if not, frame conversions get into trouble + # we're effectively looping over skycoords defined for mid here (ra, + # dec, ...) + for att in atts: + # Make sure to set it up as a scalar. if not, frame conversions + # get into trouble + x[att] = t[f"mid.{att}"][0] mid_point = SkyCoord(**x) x = {} atts = [x.replace("pole.", "") for x in t.keys() if "pole" in x] - for ( - att - ) in ( - atts - ): # we're effectively looping over skycoords defined for pole here (ra, dec, ...) + # we're effectively looping over skycoords defined for pole here + # (ra, dec, ...) + for att in atts: x[att] = t[f"pole.{att}"][0] - # Make sure to set the pole's distance attribute to 1 (zero causes problems, when transforming to stream frame coords) - x["distance"] = ( - 1.0 * u.kpc - ) # it shouldn't matter, but if it's zero it does crazy things + # Make sure to set the pole's distance attribute to 1 (zero causes + # problems, when transforming to stream frame coords) it shouldn't + # matter, but if it's zero it does crazy things + x["distance"] = 1.0 * u.kpc mid_pole = SkyCoord(**x) return GreatCircleICRSFrame(pole=mid_pole, ra0=mid_point.icrs.ra) @@ -225,10 +167,9 @@ def _get_stream_frame_from_file(summary_file): summary_file=summary_file, ) try: - data["width"] = ( - 2 * data["track"].track_width["width_phi2"] - ) # one standard deviation on each side (is this wide enough?) - except Exception: + # one standard deviation on each side (is this wide enough?) + data["width"] = 2 * data["track"].track_width["width_phi2"] + except KeyError: data["width"] = width data["stream_vertices"] = data["track"].create_sky_polygon_footprint_from_track( width=data["width"], phi2_offset=0.0 * u.deg @@ -246,7 +187,10 @@ def _get_stream_frame_from_file(summary_file): return cls(data) - def add_cmd_footprint(self, new_footprint, color, mag, name): + def add_cmd_footprint( + self, new_footprint: Any, color: Any, mag: Any, name: str + ) -> None: + """Add a color-magnitude diagram footprint.""" if self.cmd_filters is None: self.cmd_filters = dict((name, [color, mag])) self.cmdprint = dict((name, new_footprint)) @@ -254,15 +198,25 @@ def add_cmd_footprint(self, new_footprint, color, mag, name): self.cmd_filters[name] = [color, mag] self.cmdprint[name] = new_footprint - def add_pm_footprint(self, new_footprint, name): + def add_pm_footprint(self, new_footprint: Any, name: str) -> None: + """Add a proper motion footprint.""" if self.pmprint is None: self.pmprint = dict((name, new_footprint)) else: self.pmprint[name] = new_footprint - def save_pawprint(self): - # WARNING this doesn't save the track yet - need schema - # WARNING the stream frame doesn't save right either + def save_pawprint(self) -> None: + """Save the pawprint to an asdf file. + + .. warning:: + + This doesn't save the track yet. + + .. todo:: + + Make an ASDF schema for the track and the frame, and then the + pawprint. + """ fname = self.stream_name + self.pawprint_ID + ".asdf" tree = { "stream_name": self.stream_name, @@ -272,7 +226,6 @@ def save_pawprint(self): "width": self.width, "on_stream": {"sky": self.skyprint["stream"].export()}, "off_stream": self.skyprint["background"].export(), - # 'track':self.track #TODO } if self.cmdprint is not None: tree["on_stream"]["cmd"] = { diff --git a/src/cats/pawprint/_footprint.py b/src/cats/pawprint/_footprint.py new file mode 100644 index 0000000..838f9aa --- /dev/null +++ b/src/cats/pawprint/_footprint.py @@ -0,0 +1,110 @@ +"""Footprint class.""" + +from __future__ import annotations + +__all__ = ["Footprint2D"] + +from typing import TYPE_CHECKING, Any + +import astropy.table as apt +import numpy as np +from astropy.coordinates import SkyCoord +from matplotlib.path import Path as mpl_path + +if TYPE_CHECKING: + from astropy.coordinates import BaseCoordinateFrame + from numpy import bool_ + from numpy.typing import NDArray + from typing_extensions import Self + + +class Footprint2D(dict): + """A 2D footprint.""" + + def __init__( + self, + vertex_coordinates: Any, + footprint_type: Any, + stream_frame: BaseCoordinateFrame | None = None, + ) -> None: + if footprint_type == "sky": + if isinstance(vertex_coordinates, SkyCoord): + vc = vertex_coordinates + else: + vc = SkyCoord(vertex_coordinates) + self.edges = vc + self.vertices = np.array( + [vc.transform_to(stream_frame).phi1, vc.transform_to(stream_frame).phi2] + ).T + + elif footprint_type == "cartesian": + self.edges = vertex_coordinates + self.vertices = vertex_coordinates + + self.stream_frame = stream_frame + self.footprint_type = footprint_type + self.footprint = mpl_path(self.vertices) + + # =============================================================== + + @classmethod + def from_vertices( + cls: type[Self], vertex_coordinates: Any, footprint_type: Any + ) -> Self: + """Initialize from vertices.""" + return cls(vertex_coordinates, footprint_type) + + @classmethod + def from_box( + cls: type[Self], + min1: float, + max1: float, + min2: float, + max2: float, + footprint_type: str, + ) -> Self: + """Initialize from a box.""" + vertices = get_vertices_from_box(min1, max1, min2, max2) + return cls(vertices, footprint_type) + + @classmethod + def from_file(cls: type[Self], fname: str) -> Self: + """Initialize from a file.""" + with apt.Table.read(fname) as t: + vertices = t["vertices"] + footprint_type = t["footprint_type"] + return cls(vertices, footprint_type) + + # =============================================================== + + def inside_footprint(self, data: SkyCoord | Any) -> NDArray[bool_] | None: + """Check if a point is inside the footprint.""" + if isinstance(data, SkyCoord): + if self.stream_frame is None: + print("can't!") + return None + + pts = np.array( + [ + data.transform_to(self.stream_frame).phi1.value, + data.transform_to(self.stream_frame).phi2.value, + ] + ).T + return self.footprint.contains_points(pts) + + return self.footprint.contains_points(data) + + def export(self) -> dict[str, Any]: + """Export the footprint to a dictionary.""" + data = {} + data["stream_frame"] = self.stream_frame + data["vertices"] = self.vertices + data["footprint_type"] = self.footprint_type + return data + + +def get_vertices_from_box( + min1: float, max1: float, min2: float, max2: float +) -> list[list[float]]: + """Get vertices from a box.""" + return [[min1, min2], [min1, max2], [max1, min2], [max1, max2]] diff --git a/src/cats/photometry.py b/src/cats/photometry.py deleted file mode 100644 index efba9a5..0000000 --- a/src/cats/photometry.py +++ /dev/null @@ -1,199 +0,0 @@ -from __future__ import annotations - -__all__ = ["PS1Phot", "GaiaDR3Phot", "DESY6Phot"] - -import abc -from typing import ClassVar - -import astropy.coordinates as coord -import astropy.table as at -import astropy.units as u -import numpy as np -from dustmaps.sfd import SFDQuery -from pyia import GaiaData - - -class PhotometricSurvey(metaclass=abc.ABCMeta): - band_names: ClassVar[dict[str, str]] = {} - extinction_coeffs: ClassVar[dict[str, float]] = {} - custom_extinction: ClassVar[bool] = False - dustmaps_cls: ClassVar[type] = SFDQuery - - def __init_subclass__(cls) -> None: - if len(cls.band_names) == 0: - raise ValueError( - "You must define some photometric band names in band_names for any " - "survey-specific subclass" - ) - - for short_name in cls.band_names.values(): - if not cls.custom_extinction and short_name not in cls.extinction_coeffs: - raise ValueError( - "You must specify extinction coefficients for all photometric " - "bands in any survey-specific subclass" - ) - - def __init__(self, data) -> None: - """ - - Parameters - ---------- - data : table-like, str - Anything that can be passed into `astropy.table.Table` to construct an - astropy table instance, or a string filename (that can be read into an - astropy table instance). - """ - if isinstance(data, str): - data = at.Table.read(data) - self.data = at.Table(data) - - @abc.abstractmethod - def get_skycoord(self): - """ - Return a SkyCoord object from the data table. - """ - - @abc.abstractmethod - def get_star_mask(self): - """ - Star-galaxy separation - """ - - def get_ext_corrected_phot(self, dustmaps_cls=None): - if self.custom_extinction: - raise RuntimeError("TODO") - - if dustmaps_cls is None: - dustmaps_cls = self.dustmaps_cls - - c = self.get_skycoord() - ebv = dustmaps_cls().query(c) - - tbl = at.Table() - new_band_names = [] - for band, short_name in self.band_names.items(): - Ax = self.extinction_coeffs[short_name] * ebv - tbl[f"A_{short_name}"] = Ax - tbl[f"{short_name}0"] = self.data[band] - Ax - new_band_names.append(f"{short_name}0") - tbl.meta["band_names"] = new_band_names - tbl.meta["dustmap"] = dustmaps_cls.__class__.__name__ - - return tbl - - -class PS1Phot(PhotometricSurvey): - band_names: ClassVar[dict[str, str]] = { - "gMeanPSFMag": "g", - "rMeanPSFMag": "r", - "iMeanPSFMag": "i", - "zMeanPSFMag": "z", - "yMeanPSFMag": "y", - } - - # Schlafly+2011, Rv=3.1 - extinction_coeffs: ClassVar[dict[str, float]] = { - "g": 3.172, - "r": 2.271, - "i": 1.682, - "z": 1.322, - "y": 1.087, - } - - def get_skycoord(self): - return coord.SkyCoord(self.data["raMean"] * u.deg, self.data["decMean"] * u.deg) - - def get_star_mask(self): - """ - Star/galaxy separation for PS1 - - See: - https://outerspace.stsci.edu/display/PANSTARRS/How+to+separate+stars+and+galaxies - - Returns - ------- - star_mask : `numpy.ndarray` - True where the stars are. - """ - d_mag_mask = self.data["iMeanPSFMag"] - self.data["iMeanKronMag"] < 0.05 - return d_mag_mask - - -class GaiaDR3Phot(PhotometricSurvey): - band_names: ClassVar[dict[str, str]] = { - "phot_g_mean_mag": "G", - "phot_bp_mean_mag": "BP", - "phot_rp_mean_mag": "RP", - } - custom_extinction: ClassVar[bool] = True - - def get_skycoord(self): - return GaiaData(self.data).get_skycoord(distance=False) - - def get_star_mask(self): - """ - Star-galaxy separation: - """ - return np.ones(len(self.data), dtype=bool) - - def get_ext_corrected_phot(self, dustmaps_cls=None): - if dustmaps_cls is None: - dustmaps_cls = self.dustmaps_cls - g = GaiaData(self.data) - As = g.get_ext(dustmaps_cls=self.dustmaps_cls) - As = {"G": As[0], "BP": As[1], "RP": As[2]} # NOTE: assumption! - - tbl = at.Table() - new_band_names = [] - for band, short_name in self.band_names.items(): - Ax = As[short_name] - if hasattr(Ax, "value"): - Ax = Ax.value - tbl[f"A_{short_name}"] = Ax - tbl[f"{short_name}0"] = self.data[band] - Ax - new_band_names.append(f"{short_name}0") - tbl.meta["band_names"] = new_band_names - tbl.meta["dustmap"] = dustmaps_cls.__class__.__name__ - - return tbl - - -class DESY6Phot(PhotometricSurvey): - band_names: ClassVar[dict[str, str]] = { - "WAVG_MAG_PSF_G": "g", - "WAVG_MAG_PSF_R": "r", - } - # Schlafly+2011, Rv=3.1 - extinction_coeffs: ClassVar[dict[str, float]] = { - "g": 3.237, - "r": 2.176, - } - custom_extinction: ClassVar[bool] = True - - def get_skycoord(self): - return coord.SkyCoord(self.data["RA"] * u.deg, self.data["DEC"] * u.deg) - - def get_star_mask(self): - """ - Star-galaxy separation: - """ - return (self.data["EXT_FITVD"] >= 0) & (self.data["EXT_FITVD"] < 2) - - def get_ext_corrected_phot(self, dustmaps_cls=None): - if dustmaps_cls is None: - dustmaps_cls = self.dustmaps_cls - - c = self.get_skycoord() - ebv = dustmaps_cls().query(c) - - tbl = at.Table() - new_band_names = [] - for short_name in self.band_names.values(): - Ax = self.extinction_coeffs[short_name] * ebv - tbl[f"A_{short_name}"] = Ax - tbl[f"{short_name}0"] = self.data[f"BDF_MAG_{short_name.upper()}_CORRECTED"] - new_band_names.append(f"{short_name}0") - tbl.meta["band_names"] = new_band_names - tbl.meta["dustmap"] = dustmaps_cls.__class__.__name__ - - return tbl diff --git a/src/cats/photometry/__init__.py b/src/cats/photometry/__init__.py new file mode 100644 index 0000000..f674b31 --- /dev/null +++ b/src/cats/photometry/__init__.py @@ -0,0 +1,11 @@ +"""CMD functions.""" + +from __future__ import annotations + +from . import _base, _builtin +from ._base import * +from ._builtin import * + +__all__: list[str] = [] +__all__ += _base.__all__ +__all__ += _builtin.__all__ diff --git a/src/cats/photometry/_base.py b/src/cats/photometry/_base.py new file mode 100644 index 0000000..03b754b --- /dev/null +++ b/src/cats/photometry/_base.py @@ -0,0 +1,115 @@ +"""Base class for photometric surveys.""" + +from __future__ import annotations + +__all__ = ["AbstractPhotometricSurvey"] + +import abc +from dataclasses import dataclass +from typing import TYPE_CHECKING, ClassVar + +from astropy.table import QTable, Table +from dustmaps.sfd import SFDQuery + +if TYPE_CHECKING: + from astropy.coordinates import SkyCoord + from dustmaps.map_base import DustMap + from numpy import bool_ + from numpy.typing import NDArray + from typing_extensions import Self + + +@dataclass(frozen=True) +class AbstractPhotometricSurvey(metaclass=abc.ABCMeta): + """Photoemtric survey base class. + + Parameters + ---------- + data : :class:`~astropy.table.QTable` + Anything that can be passed into `astropy.table.Table` to construct an + astropy table instance, or a string filename (that can be read into an + astropy table instance). + """ + + band_names: ClassVar[dict[str, str]] = {} + extinction_coeffs: ClassVar[dict[str, float]] = {} + custom_extinction: ClassVar[bool] = False + dustmaps_cls: ClassVar[type[DustMap]] = SFDQuery + + data: QTable + + def __init_subclass__(cls) -> None: + if not cls.band_names: # empty dict + msg = ( + "You must define some photometric band names in band_names for any " + "survey-specific subclass" + ) + raise ValueError(msg) + + for short_name in cls.band_names.values(): + if not cls.custom_extinction and short_name not in cls.extinction_coeffs: + msg = ( + "You must specify extinction coefficients for all photometric " + "bands in any survey-specific subclass" + ) + raise ValueError(msg) + + @classmethod + def from_tablelike(cls: type[Self], data: str | Table) -> Self: + """Initialize from a table-like object.""" + if isinstance(data, str): + return cls(QTable.read(data)) + return cls(data) + + @abc.abstractmethod + def get_skycoord(self) -> SkyCoord: + """Return a SkyCoord object from the data table.""" + + @abc.abstractmethod + def get_star_mask(self) -> NDArray[bool_]: + """Star-galaxy separation.""" + + def get_ext_corrected_phot( + self, dustmaps_cls: type[DustMap] | None = None + ) -> QTable: + """Get extinction-corrected photometry. + + Parameters + ---------- + dustmaps_cls : type[:class:`~dustmaps.map_base.DustMap`], optional + Dustmap class to use for extinction correction. Default is + :class:`~dustmaps.sfd.SFDQuery`. + + Notes + ----- + This is a default implementation. Most subclasses will need to override + this method. + """ + if self.custom_extinction: + msg = "TODO" + raise RuntimeError(msg) + + if dustmaps_cls is None: + dustmaps_cls = self.dustmaps_cls + + c = self.get_skycoord() + ebv = dustmaps_cls().query(c) + + tbl = QTable() + short_band_names: list[str] = [] + for band, short_name in self.band_names.items(): + # Get extinction coefficient + Ax = self.extinction_coeffs[short_name] * ebv + + # Add to table + tbl[f"A_{short_name}"] = Ax + tbl[f"{short_name}0"] = self.data[band] - Ax + + # Record new band name + short_band_names.append(f"{short_name}0") + + # Metadata + tbl.meta["band_names"] = short_band_names + tbl.meta["dustmap"] = dustmaps_cls.__class__.__name__ + + return tbl diff --git a/src/cats/photometry/_builtin/__init__.py b/src/cats/photometry/_builtin/__init__.py new file mode 100644 index 0000000..6bf5d59 --- /dev/null +++ b/src/cats/photometry/_builtin/__init__.py @@ -0,0 +1,13 @@ +"""CMD functions.""" + +from __future__ import annotations + +from . import desi, gaia, ps1 +from .desi import * +from .gaia import * +from .ps1 import * + +__all__: list[str] = [] +__all__ += desi.__all__ +__all__ += gaia.__all__ +__all__ += ps1.__all__ diff --git a/src/cats/photometry/_builtin/desi.py b/src/cats/photometry/_builtin/desi.py new file mode 100644 index 0000000..4247304 --- /dev/null +++ b/src/cats/photometry/_builtin/desi.py @@ -0,0 +1,74 @@ +"""DES year 6 Photometric Survey.""" + +from __future__ import annotations + +__all__ = ["DESY6Phot"] + +from typing import TYPE_CHECKING, ClassVar, TypedDict + +import astropy.units as u +from astropy.coordinates import SkyCoord +from astropy.table import QTable + +from cats.photometry._base import AbstractPhotometricSurvey + +if TYPE_CHECKING: + from dustmaps.map_base import DustMap + from numpy import bool_ + from numpy.typing import NDArray + + +class DESY6BandNames(TypedDict): + """DESY6 band names.""" + + WAVG_MAG_PSF_G: str + WAVG_MAG_PSF_R: str + + +class DESY6ExtinctionCoeffs(TypedDict): + """DESY6 extinction coefficients.""" + + g: float + r: float + + +class DESY6Phot(AbstractPhotometricSurvey): + """DESY6 Photometric Survey.""" + + band_names: ClassVar[DESY6BandNames] = { + "WAVG_MAG_PSF_G": "g", + "WAVG_MAG_PSF_R": "r", + } + # Schlafly+2011, Rv=3.1 + extinction_coeffs: ClassVar[DESY6ExtinctionCoeffs] = {"g": 3.237, "r": 2.176} + custom_extinction: ClassVar[bool] = True + + def get_skycoord(self) -> SkyCoord: + return SkyCoord(self.data["RA"] * u.deg, self.data["DEC"] * u.deg) + + def get_star_mask(self) -> NDArray[bool_]: + return (self.data["EXT_FITVD"] >= 0) & (self.data["EXT_FITVD"] < 2) + + def get_ext_corrected_phot( + self, dustmaps_cls: tuple[DustMap] | None = None + ) -> QTable: + if dustmaps_cls is None: + dustmaps_cls = self.dustmaps_cls + + c = self.get_skycoord() + ebv = dustmaps_cls().query(c) + + tbl = QTable() + for short_name in self.band_names.values(): + # Compute extinction correction + Ax = self.extinction_coeffs[short_name] * ebv + + # Apply correction + tbl[f"A_{short_name}"] = Ax + tbl[f"{short_name}0"] = self.data[f"BDF_MAG_{short_name.upper()}_CORRECTED"] + + # Metadata + tbl.meta["band_names"] = [f"{k}0" for k in self.band_names.values()] + tbl.meta["dustmap"] = dustmaps_cls.__class__.__name__ + + return tbl diff --git a/src/cats/photometry/_builtin/gaia.py b/src/cats/photometry/_builtin/gaia.py new file mode 100644 index 0000000..0b8c723 --- /dev/null +++ b/src/cats/photometry/_builtin/gaia.py @@ -0,0 +1,75 @@ +"""Gaia DR3 Photometric Survey.""" + +from __future__ import annotations + +__all__ = ["GaiaDR3Phot"] + +from typing import TYPE_CHECKING, ClassVar, TypedDict + +import numpy as np +from astropy.table import QTable +from pyia import GaiaData + +from cats.photometry._base import AbstractPhotometricSurvey + +if TYPE_CHECKING: + from astropy.coordinates import SkyCoord + from dustmaps.map_base import DustMap + from numpy import bool_ + from numpy.typing import NDArray + + +class GaiaDR3BandNames(TypedDict): + """Gaia DR3 band names.""" + + phot_g_mean_mag: str + phot_bp_mean_mag: str + phot_rp_mean_mag: str + + +class GaiaDR3Phot(AbstractPhotometricSurvey): + """Gaia DR3 Photometric Survey.""" + + band_names: ClassVar[GaiaDR3BandNames] = { + "phot_g_mean_mag": "G", + "phot_bp_mean_mag": "BP", + "phot_rp_mean_mag": "RP", + } + custom_extinction: ClassVar[bool] = True + + def get_skycoord(self) -> SkyCoord: + return GaiaData(self.data).get_skycoord(distance=False) + + def get_star_mask(self) -> NDArray[bool_]: + return np.ones(len(self.data), dtype=bool) + + def get_ext_corrected_phot( + self, dustmaps_cls: type[DustMap] | None = None + ) -> QTable: + if dustmaps_cls is None: + dustmaps_cls = self.dustmaps_cls + + g = GaiaData(self.data) + As = g.get_ext(dustmaps_cls=self.dustmaps_cls) + As = {"G": As[0], "BP": As[1], "RP": As[2]} # NOTE: assumption! + + tbl = QTable() + short_band_names: list[str] = [] + for band, short_name in self.band_names.items(): + # Get extinction coefficient + Ax = As[short_name] + if hasattr(Ax, "value"): + Ax = Ax.value + + # Add to table + tbl[f"A_{short_name}"] = Ax + tbl[f"{short_name}0"] = self.data[band] - Ax + + # Record new band name + short_band_names.append(f"{short_name}0") + + # Metadata + tbl.meta["band_names"] = short_band_names + tbl.meta["dustmap"] = dustmaps_cls.__class__.__name__ + + return tbl diff --git a/src/cats/photometry/_builtin/ps1.py b/src/cats/photometry/_builtin/ps1.py new file mode 100644 index 0000000..6ba0ca6 --- /dev/null +++ b/src/cats/photometry/_builtin/ps1.py @@ -0,0 +1,77 @@ +"""Pan-STARRS1 Photometric Survey.""" + +from __future__ import annotations + +__all__ = ["PS1Phot"] + +from typing import TYPE_CHECKING, ClassVar, TypedDict + +import astropy.units as u +from astropy.coordinates import SkyCoord + +from cats.photometry._base import AbstractPhotometricSurvey + +if TYPE_CHECKING: + import numpy.typing as npt + from numpy import bool_ + + +class PS1BandNames(TypedDict): + """PS1 band names.""" + + gMeanPSFMag: str + rMeanPSFMag: str + iMeanPSFMag: str + zMeanPSFMag: str + yMeanPSFMag: str + + +class PS1ExtinctionCoeffs(TypedDict): + """PS1 extinction coefficients.""" + + g: str + r: str + i: str + z: str + y: str + + +class PS1Phot(AbstractPhotometricSurvey): + """Pan-STARRS1 Photometric Survey.""" + + band_names: ClassVar[PS1BandNames] = { + "gMeanPSFMag": "g", + "rMeanPSFMag": "r", + "iMeanPSFMag": "i", + "zMeanPSFMag": "z", + "yMeanPSFMag": "y", + } + + # Schlafly+2011, Rv=3.1 + # TODO: load from a config file + extinction_coeffs: ClassVar[PS1ExtinctionCoeffs] = { + "g": 3.172, + "r": 2.271, + "i": 1.682, + "z": 1.322, + "y": 1.087, + } + + def get_skycoord(self) -> SkyCoord: + return SkyCoord( + self.data["raMean"] << u.deg, + self.data["decMean"] << u.deg, + frame="icrs", + ) + + def get_star_mask(self) -> npt.NDArray[bool_]: + """Star/galaxy separation for PS1. + + See: https://outerspace.stsci.edu/display/PANSTARRS/How+to+separate+stars+and+galaxies + + Returns + ------- + star_mask : `numpy.ndarray` + True where the stars are. + """ + return (self.data["iMeanPSFMag"] - self.data["iMeanKronMag"]) < 0.05 diff --git a/src/cats/proper_motions.py b/src/cats/proper_motions.py index 0d61583..c893292 100644 --- a/src/cats/proper_motions.py +++ b/src/cats/proper_motions.py @@ -2,30 +2,43 @@ from __future__ import annotations +from typing import TYPE_CHECKING, Any, Callable + import matplotlib as mpl import matplotlib.pyplot as plt import numpy as np +from astropy.modeling import fitting, models +from matplotlib.colors import LogNorm from mpl_toolkits.axes_grid1 import make_axes_locatable from scipy.interpolate import InterpolatedUnivariateSpline as IUS -from scipy.spatial import ConvexHull +from scipy.spatial import ConvexHull # pylint: disable=no-name-in-module from cats.inputs import stream_inputs as inputs -from cats.pawprint.pawprint import Footprint2D +from cats.pawprint._footprint import Footprint2D + +if TYPE_CHECKING: + from matplotlib.cm import ScalarMappable + from matplotlib.colorbar import Colorbar + from matplotlib.figure import Figure + from numpy.typing import NDArray + + from cats.pawprint import Pawprint __author__ = ("Sophia", "Nora", "Nondh", "Lina", "Bruno", "Kiyan") -def rough_pm_poly(pawprint, data, buffer=2): - """ - Will return a polygon with a rough cut in proper motion space. +def rough_pm_poly( + pawprint: Pawprint, data: Any, buffer: float = 2 +) -> tuple[Footprint2D, NDArray[np.bool_]]: + """Will return a polygon with a rough cut in proper motion space. + This aims to be ~100% complete with no thoughts about purity. The goal is to use this cut in conjunction with the cmd cut in order to see the stream as a clear overdensity in (phi_1, phi_2), which will allow membership probability modeling """ - stream_fr = pawprint.track.stream_frame - track = pawprint.track.track.transform_to(stream_fr) - # track_refl = gc.reflex_correct(track) + stream_frame = pawprint.track.stream_frame + track = pawprint.track.track.transform_to(stream_frame) # use the galstream proper motion track track_pm1_min = np.min(track.pm_phi1_cosphi2.value) @@ -34,7 +47,7 @@ def rough_pm_poly(pawprint, data, buffer=2): track_pm2_max = np.max(track.pm_phi2.value) # make rectangular box around this region with an extra 2 mas/yr on each side - rough_pm_poly = np.array( + poly = np.array( [ [track_pm1_min - buffer, track_pm2_min - buffer], [track_pm1_min - buffer, track_pm2_max + buffer], @@ -43,7 +56,7 @@ def rough_pm_poly(pawprint, data, buffer=2): ] ) - pawprint.pmprint = Footprint2D(rough_pm_poly, footprint_type="cartesian") + pawprint.pmprint = Footprint2D(poly, footprint_type="cartesian") pm_points = np.vstack((data["pm_phi1_cosphi2_unrefl"], data["pm_phi2_unrefl"])).T rough_pm_mask = pawprint.pmprint.inside_footprint(pm_points) @@ -52,43 +65,54 @@ def rough_pm_poly(pawprint, data, buffer=2): class ProperMotionSelection: + """Proper Motion Selection. + + Parameters + ---------- + stream: object + galstream object that contains stream's proper motion tracks. + data: object + data dictionary. + pawprint: object + pawprint object that contains stream's sky and proper motion + polygons. + + best_pm_phi1_mean: + best initial guess for mean of pm_phi1. + best_pm_phi2_mean: + best initial guess for mean of pm_phi2. + best_pm_phi1_std: + best initial guess for pm_phi1 standard deviation. + best_pm_phi2_std: + best initial guess for pm_phi2 standard deviation. + n_dispersion_phi1: + float, default set to 1 standard deviation around phi_1. + n_dispersion_phi2: + float, default set to 1 standard deviation around phi_2. + refine_factor: + int, default set to 100, how smooth are the edges of the polygons. + cutoff: + float, in [0,1], cutoff on the height of the pdf to keep the stars + that have a probability to belong to the 2D gaussian above the + cutoff value. + """ + def __init__( self, - stream, - data, - pawprint, - # CMD_mask=True, - # spatial_mask_on=True, - # spatial_mask_off=True, - # pm_phi1_grad = None, # think we should take this from pawprint, or at least make that the default - # pm_phi2_grad = None, - best_pm_phi1_mean=None, - best_pm_phi2_mean=None, - best_pm_phi1_std=None, - best_pm_phi2_std=None, - cutoff=0.95, - n_dispersion_phi1=1, - n_dispersion_phi2=1, - refine_factor=100, - ): - """ - stream_obj: galstream object that contains stream's proper motion tracks - data: - :param: stream_obj: from galstreams so far #TODO: generalize - :param: CMD_mask: Used before the PM - :param: spatial_mask_on: - :param: spatial_mask_off: - :param: best_pm_phi1_mean: best initial guess for mean of pm_phi1 - :param: best_pm_phi2_mean: best initial guess for mean of pm_phi2 - :param: best_pm_phi1_std: best initial guess for pm_phi1 standard deviation - :param: best_pm_phi2_std: best initial guess for pm_phi2 standard deviation - :param: n_dispersion_phi1: float, default set to 1 standard deviation around phi_1 - :param: n_dispersion_phi2: float, default set to 1 standard deviation around phi_2 - :param: refine_factor: int, default set to 100, how smooth are the edges of the polygons - :param: cutoff: float, in [0,1], cutoff on the height of the pdf to keep the stars that have a probability to belong to the 2D gaussian above the cutoff value - """ - - # stream_obj starting as galstream but then should be replaced by best values that we find + stream: Any, + data: Any, + pawprint: Any, + best_pm_phi1_mean: float | None = None, + best_pm_phi2_mean: float | None = None, + best_pm_phi1_std: float | None = None, + best_pm_phi2_std: float | None = None, + cutoff: float = 0.95, + n_dispersion_phi1: int = 1, + n_dispersion_phi2: int = 1, + refine_factor: int = 100, + ) -> None: + # stream_obj starting as galstream but then should be replaced by best + # values that we find self.stream = stream self.stream_obj = pawprint.track self.data = data @@ -98,9 +122,12 @@ def __init__( self.cutoff = cutoff - assert ( - self.cutoff <= 1 and self.cutoff >= 0 - ), "the value of self.cutoff put in does not make sense! It has to be between 0 and 1" + if not (self.cutoff <= 1 and self.cutoff >= 0): + msg = ( + "the value of self.cutoff put in does not make sense! " + "It has to be between 0 and 1" + ) + raise AssertionError(msg) # Get tracks from galstreams with splines spline_phi2, spline_pm1, spline_pm2, spline_dist = self.from_galstreams() @@ -125,7 +152,8 @@ def __init__( # distmod_spl = np.poly1d([2.41e-4, 2.421e-2, 15.001]) # self.dist_mod_correct = distmod_spl(self.cat["phi1"]) - self.dist_mod - # SHOULD THE CMD CUT ALSO MAKE AN OFFSTREAM MASK? MAY BE USEFUL TO MAKE CUTS FOR SOME STREAMS + # SHOULD THE CMD CUT ALSO MAKE AN OFFSTREAM MASK? MAY BE USEFUL TO MAKE + # CUTS FOR SOME STREAMS self.initial_masks() self.pm_phi1_cosphi2 = self.data["pm_phi1_cosphi2_unrefl"][self.mask] self.pm_phi2 = self.data["pm_phi2_unrefl"][self.mask] @@ -137,13 +165,7 @@ def __init__( mid_phi1 = np.median(self.track.phi1.value) print(mid_phi1) - if best_pm_phi1_mean == None: - # TODO: generalize this later to percentile_values = [16, 50, 84] - - # if self.stream == 'Fjorm-M68': - # self.best_pm_phi1_mean = 1 - # self.best_pm_phi2_mean = 4 - # else: + if best_pm_phi1_mean is None: self.best_pm_phi1_mean = spline_pm1(mid_phi1) self.best_pm_phi2_mean = spline_pm2(mid_phi1) @@ -170,12 +192,9 @@ def __init__( self.data, x_width=3.0, y_width=3.0, draw_histograms=True ) print( - "Post-fitting (pm1_mean, pm2_mean, pm1_std, pm2_std): {} \n".format( - peak_locations - ) + "Post-fitting (pm1_mean, pm2_mean, pm1_std, pm2_std): " + f"{peak_locations} \n" ) - # except: - # print('Skipping peak pm fitting') ################################################ ## Ellipse-like proper motion cut in PM space ## @@ -192,12 +211,6 @@ def __init__( self.pm_phi1_cosphi2 = data["pm_phi1_cosphi2_unrefl"][self.mask] self.pm_phi2 = data["pm_phi2_unrefl"][self.mask] - # Plot the ellipse-like cut - # self.plot_pms_scatter(self.data, mask=True, - # n_dispersion_phi1=n_dispersion_phi1, - # n_dispersion_phi2=n_dispersion_phi2) - # self.plot_pm_hist(self.data, pms=[self.best_pm_phi1_mean, self.best_pm_phi2_mean]) - ###################################################### ## PM cut in PM space using PM gradient information ## ###################################################### @@ -233,19 +246,10 @@ def __init__( ) = self.build_pm12_polys_and_masks() self.mask = self.pm1_mask & self.pm2_mask & self.spatial_mask_on & self.CMD_mask - # Plot the cut in (phi1, pm1) and (phi1, pm2) space - # self.plot_pms_scatter(self.data, mask=True, - # n_dispersion_phi1=n_dispersion_phi1, - # n_dispersion_phi2=n_dispersion_phi2) - # self.plot_pm_hist(self.data, pms=[self.best_pm_phi1_mean, self.best_pm_phi2_mean]) - - return - - def from_galstreams(self): - stream_fr = self.stream_obj.stream_frame - self.track = self.stream_obj.track.transform_to(stream_fr) - # track_refl = gc.reflex_correct(track) - # self.track_refl = track_refl + def from_galstreams(self) -> tuple[IUS, IUS, IUS, IUS]: + """Get tracks from galstreams with splines.""" + stream_frame = self.stream_obj.stream_frame + self.track = self.stream_obj.track.transform_to(stream_frame) self.galstream_phi1 = self.track.phi1.value self.galstream_phi2 = self.track.phi2.value @@ -269,10 +273,6 @@ def from_galstreams(self): spline_pm1 = IUS(self.galstream_phi1, self.galstream_pm_phi1_cosphi2) spline_pm2 = IUS(self.galstream_phi1, self.galstream_pm_phi2) - # spline_phi2 = US(self.galstream_phi1, self.galstream_phi2, k=3, s=len(self.galstream_phi1)/1000) - # spline_pm1 = US(self.galstream_phi1, self.galstream_pm_phi1_cosphi2, k=3, s=len(self.galstream_phi1)/1000) - # spline_pm2 = US(self.galstream_phi1, self.galstream_pm_phi2, k=3, s=len(self.galstream_phi1)/1000) - if self.stream == "GD-1": spline_dist = np.poly1d( [2.41e-4, 2.421e-2, 15.001] @@ -282,10 +282,8 @@ def from_galstreams(self): return spline_phi2, spline_pm1, spline_pm2, spline_dist - def sel_sky(self): - """ - Initialising the on-sky polygon mask to return only contained sources. - """ + def sel_sky(self) -> tuple[NDArray[np.bool_], NDArray[np.bool_]]: + """Initialize the on-sky polygon mask to return only contained sources.""" on_poly_patch = mpl.patches.Polygon( self.pawprint.skyprint["stream"].vertices[::100], facecolor="none", @@ -306,11 +304,11 @@ def sel_sky(self): return on_mask, off_mask - def sel_cmd(self): - """ - Initialising the proper motions polygon mask to return only contained sources. - """ + def sel_cmd(self) -> NDArray[np.bool_]: + """Initialize the proper motions polygon mask. + Set to return only contained sources. + """ mag = inputs[self.stream]["mag"] color1 = inputs[self.stream]["color1"] color2 = inputs[self.stream]["color2"] @@ -321,22 +319,18 @@ def sel_cmd(self): self.data[mag] - self.dist_mod_correct, ) ).T - cmd_mask = self.pawprint.cmdprint.inside_footprint(cmd_points) + return self.pawprint.cmdprint.inside_footprint(cmd_points) - return cmd_mask - - def initial_masks(self): - """ - Generate the initial spatial, and CMD masks based on the input - """ + def initial_masks(self) -> None: + """Generate the initial spatial, and CMD masks based on the input.""" self.spatial_mask_on, self.spatial_mask_off = self.sel_sky() self.CMD_mask = self.sel_cmd() self.mask = self.spatial_mask_on & self.CMD_mask self.off_mask = self.spatial_mask_off & self.CMD_mask - def rough_pm(self, buffer=2): - """ - Will return a polygon with a rough cut in proper motion space. + def rough_pm(self, buffer: float = 2) -> tuple[NDArray, NDArray[np.bool_]]: + """Will return a polygon with a rough cut in proper motion space. + This aims to be ~100% complete with no thoughts about purity. The goal is to use this cut in conjunction with the cmd cut in order to see the stream as a clear overdensity in (phi_1, phi_2), which will @@ -369,28 +363,45 @@ def rough_pm(self, buffer=2): return self.rough_pm_poly, self.rough_pm_mask @staticmethod - def two_dimensional_gaussian(x, y, x0, y0, sigma_x, sigma_y): - """ - Evaluates a two dimensional gaussian distribution in x, y, with means x0, y0, and dispersions sigma_x and sigma_y - """ + def two_dimensional_gaussian( + x: float, y: float, x0: float, y0: float, sigma_x: float, sigma_y: float + ) -> float: + """Evaluate a two dimensional gaussian distribution. + In x, y, with means x0, y0, and dispersions sigma_x and sigma_y. + """ return np.exp( -((x - x0) ** 2 / (2 * sigma_x**2) + (y - y0) ** 2 / (2 * sigma_y**2)) ) def build_poly_and_mask( - self, n_dispersion_phi1=3, n_dispersion_phi2=3, refine_factor=100 - ): - """ - Builds the mask of the proper motion with n_dispersion around the mean - :param: n_dispersion_phi1: float, default set to 1 standard deviation around phi_1 - :param: n_dispersion_phi2: float, default set to 1 standard deviation around phi_2 - :param: refine_factor: int, default set to 100, how smooth are the edges of the polygons - :param: cutoff: float, in [0,1], cutoff on the height of the pdf to keep the stars that have a probability to belong to the 2D gaussian above the cutoff value - - :output: is a list of points that are the vertices of a polygon + self, + n_dispersion_phi1: float = 3, + n_dispersion_phi2: float = 3, + refine_factor: int = 100, + ) -> tuple[NDArray, NDArray[np.bool_]]: + """Build mask of the proper motion with around the mean. + + Parameters + ---------- + n_dispersion_phi1: float + default set to 1 standard deviation around phi_1 + n_dispersion_phi2: float + default set to 1 standard deviation around phi_2 + refine_factor: int + default set to 100, how smooth are the edges of the polygons + cutoff: float + in [0,1], cutoff on the height of the pdf to keep the stars that + have a probability to belong to the 2D gaussian above the cutoff + value + + Returns + ------- + NDArray + vertices of the polygon. + NDArray[np.bool_] + mask of the polygon. """ - # First generate the 2D histograms pm_phi1_min, pm_phi1_max = ( self.best_pm_phi1_mean - n_dispersion_phi1 * self.best_pm_phi1_std, @@ -438,9 +449,11 @@ def build_poly_and_mask( return self.pm_poly, self.pm_mask - def build_pm12_polys_and_masks(self): - """ - This assumes that galstreams is correct, which is maybe not a great assumption but will work for now. + def build_pm12_polys_and_masks(self) -> tuple[NDArray, NDArray, NDArray, NDArray]: + """Build the pm1 and pm2 polygons and masks. + + This assumes that galstreams is correct, which is maybe not a great + assumption but will work for now. """ self.pm1_poly = np.concatenate( [ @@ -494,14 +507,20 @@ def build_pm12_polys_and_masks(self): return self.pm1_poly, self.pm2_poly, self.pm1_mask, self.pm2_mask - def build_mask(self, data, spline_pm1, spline_pm2, pm_poly): - """ - This builds a mask (i.e. finds the data points satisfying pm constraints) - that does not use the peak fitting used elsewhere. - It relies on splines for pm_phi1_cosphi2 and pm_phi2 vs phi1 which must be given as inputs - Most of the time, these will naturally come from galstreams + def build_mask( + self, + data: dict, + spline_pm1: Callable[[NDArray], NDArray], + spline_pm2: Callable[[NDArray], NDArray], + pm_poly: NDArray, + ) -> NDArray[np.bool_]: + """Build a mask. + + Finds the data points satisfying pm constraints that does not use the + peak fitting used elsewhere. It relies on splines for pm_phi1_cosphi2 + and pm_phi2 vs phi1 which must be given as inputs Most of the time, + these will naturally come from galstreams. """ - pm1_data_corrected = data["pm_phi1_cosphi2_unrefl"] - spline_pm1(data["phi1"]) pm2_data_corrected = data["pm_phi2_unrefl"] - spline_pm2(data["phi1"]) @@ -512,26 +531,39 @@ def build_mask(self, data, spline_pm1, spline_pm2, pm_poly): ) pm_points = np.vstack((pm1_data_corrected, pm2_data_corrected)).T - pm_mask = pm_corrected_poly_patch.get_path().contains_points(pm_points) - # self.pawprint.pmprint = Footprint2D(pm_vert_corrected, footprint_type='cartesian') - # self.pm_mask = self.pawprint.pmprint.inside_footprint(pm_points) - - return pm_mask + return pm_corrected_poly_patch.get_path().contains_points(pm_points) def plot_pms_scatter( self, - data, - save=True, - mask=False, - n_dispersion_phi1=1, - n_dispersion_phi2=1, - refine_factor=100, - **kwargs, - ): - """ - Plot proper motions on stream and off stream scatter or hist2d plots - :param: save: boolean, whether or not to save the figure - :param: mask: boolean, if true, calls in the mask + data: dict, + n_dispersion_phi1: int = 1, + n_dispersion_phi2: int = 1, + refine_factor: int = 100, + *, + save: bool = True, + mask: bool = False, + **kwargs: Any, + ) -> Figure: + """Plot proper motions on stream and off stream scatter or hist2d plots. + + Parameters + ---------- + data: dict + data dictionary + n_dispersion_phi1, n_dispersion_phi2: int + Passed to :meth:`~ProperMotionSelection.build_poly_and_mask`. + refine_factor: int + Passed to :meth:`~ProperMotionSelection.build_poly_and_mask`. + save: bool, keyword-only + whether or not to save the figure + mask: bool, keyword-only + If true, calls in the mask. + **kwargs: Any + Passed to `~matplotlib.axes.Axes.scatter`. + + Returns + ------- + :class:`~matplotlib.figure.Figure` """ data_on = data[self.mask] data_off = data[self.off_mask] @@ -566,8 +598,8 @@ def plot_pms_scatter( ax[0].set_xlim(-20, 20) ax[0].set_ylim(-20, 20) - ax[0].set_xlabel("$\mu_{\phi_1}$ [mas yr$^{-1}$]") - ax[0].set_ylabel("$\mu_{\phi_2}$ [mas yr$^{-1}$]") + ax[0].set_xlabel(r"$\mu_{\phi_1}$ [mas yr$^{-1}$]") + ax[0].set_ylabel(r"$\mu_{\phi_2}$ [mas yr$^{-1}$]") ax[0].set_title("Stream", fontsize="medium") # resize and fix column name @@ -586,8 +618,8 @@ def plot_pms_scatter( ax[1].set_xlim(-20, 20) ax[1].set_ylim(-20, 20) - ax[1].set_xlabel("$\mu_{\phi_1}$ [mas yr$^{-1}$]") - ax[1].set_ylabel("$\mu_{\phi_2}$ [mas yr$^{-1}$]") + ax[1].set_xlabel(r"$\mu_{\phi_1}$ [mas yr$^{-1}$]") + ax[1].set_ylabel(r"$\mu_{\phi_2}$ [mas yr$^{-1}$]") ax[1].set_title("Off stream", fontsize="medium") fig.tight_layout() @@ -606,20 +638,24 @@ def plot_pms_scatter( def plot_pm_hist( self, - data, - dx=0.5, - norm=1, - save=0, - pms=(None, None), - match_norm=False, - stream_coords=True, - reflex_corr=True, - zero_line=True, - pm_lims=(-20, 20), - **kwargs, - ): - # Code from Nora - + data: dict, + dx: float = 0.5, + norm: float = 1, + save: float = 0, + pms: tuple[None, None] = (None, None), + pm_lims: tuple[float, float] = (-20, 20), + *, + match_norm: bool = False, + stream_coords: bool = True, + reflex_corr: bool = True, + zero_line: bool = True, + **kwargs: Any, + ) -> Figure: + """Plot proper motions on stream and off stream histograms. + + .. codeauthor:: + Nora Shipp + """ data_on = data[self.mask] data_off = data[self.off_mask] @@ -641,7 +677,8 @@ def plot_pm_hist( h1 = np.histogram2d(data_on["PMRA0"], data_on["PMDEC0"], bins)[0] h2 = np.histogram2d(data_off["PMRA0"], data_off["PMDEC0"], bins)[0] - # might need to normalise histogram for different areas of off stream mask for subtraction histogram + # might need to normalise histogram for different areas of off stream + # mask for subtraction histogram h2 *= norm # print h1.sum(), h2.sum() @@ -704,17 +741,17 @@ def plot_pm_hist( **kwargs, ) - colorbar(im1) - colorbar(im2) - colorbar(im3) + _colorbar(im1) + _colorbar(im2) + _colorbar(im3) if (pms[0] is None) or (pms[1] is None): - ax1.axvline(self.best_pm[0], ls="--", c="k", lw=1) - ax2.axvline(self.best_pm[0], ls="--", c="k", lw=1) - ax3.axvline(self.best_pm[0], ls="--", c="k", lw=1) - ax1.axhline(self.best_pm[1], ls="--", c="k", lw=1) - ax2.axhline(self.best_pm[1], ls="--", c="k", lw=1) - ax3.axhline(self.best_pm[1], ls="--", c="k", lw=1) + ax1.axvline(self.best_pm_phi1_mean, ls="--", c="k", lw=1) + ax2.axvline(self.best_pm_phi1_mean, ls="--", c="k", lw=1) + ax3.axvline(self.best_pm_phi1_mean, ls="--", c="k", lw=1) + ax1.axhline(self.best_pm_phi2_mean, ls="--", c="k", lw=1) + ax2.axhline(self.best_pm_phi2_mean, ls="--", c="k", lw=1) + ax3.axhline(self.best_pm_phi2_mean, ls="--", c="k", lw=1) else: ax1.axvline(pms[0], ls="--", c="k", lw=1) ax2.axvline(pms[0], ls="--", c="k", lw=1) @@ -770,23 +807,33 @@ def plot_pm_hist( # ========================= added Nov 3 (need checking) ======================== - def find_peak_location(self, data, x_width=3.0, y_width=3.0, draw_histograms=True): - """ - find peak location in the proper motion space - :param: data: list of the stellar parameters to get the peak pm. - :param: x_width: float, half x-size of zoomed region box, default set to 3. - :param: y_width: float, half y-size of zoomed region box, default set to 3. - :param: draw_histograms: print histograms, default set to True - - output: [pm_x_cen, pm_y_cen, x_std, y_std]: array + def find_peak_location( + self, + data: dict, + x_width: float = 3.0, + y_width: float = 3.0, + *, + draw_histograms: bool = True, + ) -> tuple[float, float, float, float]: + """Find peak location in the proper motion space. + + Parameters + ---------- + data : dict[str, np.ndarray] + Stellar parameters to get the peak pm. + x_width, y_width : float, optional + half x,y-size of zoomed region box, default set to 3. + draw_histograms: bool + print histograms, default set to True + + Returns + ------- + pm_x_cen, pm_y_cen, x_std, y_std: array pm_x_cen: peak proper motion in phi1 pm_y_cen: peak proper motion in phi2 x_std: standard deviation proper motion in phi1 y_std: standard deviation proper motion in phi2 """ - from astropy.modeling import fitting, models - from matplotlib.colors import LogNorm - x_center, y_center = self.best_pm_phi1_mean, self.best_pm_phi2_mean print(f"Pre-fitting mean PM values: {x_center}, {y_center}") xmin, xmax, ymin, ymax = ( @@ -820,7 +867,8 @@ def find_peak_location(self, data, x_width=3.0, y_width=3.0, draw_histograms=Tru hist = ( H1 - self.mask.sum() / self.off_mask.sum() * H2 ) # check this scale factor --> self.mask.sum()/self.off_mask.sum() - # Do we want to do based on counts or based on area, since we do expect more counts on stream (but maybe negligible) + # Do we want to do based on counts or based on area, since we do expect + # more counts on stream (but maybe negligible) # fitting 2D gaussian (Code from Ani) # Find overdensity @@ -839,7 +887,6 @@ def find_peak_location(self, data, x_width=3.0, y_width=3.0, draw_histograms=Tru y_stddev=0.5, ) fit_g = fitting.LevMarLSQFitter() - # x,y = np.meshgrid(x_edges[(ind[0]-6):(ind[0]+7)], y_edges[(ind[1]-6):(ind[1]+7)]) x, y = np.meshgrid(x_edges[:-1], y_edges[:-1]) # g = fit_g(g_init, x, y, hist_zoom) @@ -853,7 +900,7 @@ def find_peak_location(self, data, x_width=3.0, y_width=3.0, draw_histograms=Tru # draw proper motions of the on-stream, off-stream and residual histogram if draw_histograms: - fig, axes = plt.subplots( + _, axes = plt.subplots( 1, 3, figsize=(15, 5), sharex=True, sharey=True, constrained_layout=True ) @@ -880,8 +927,6 @@ def find_peak_location(self, data, x_width=3.0, y_width=3.0, draw_histograms=Tru axes[2].plot( pm_x_cen + x_std * np.cos(t), pm_y_cen + y_std * np.sin(t), c="green" ) - # axes[2].set_xlim(xmin,xmax) - # axes[2].set_ylim(ymin,ymax) for ax in axes[:2]: ax.plot( @@ -896,10 +941,14 @@ def find_peak_location(self, data, x_width=3.0, y_width=3.0, draw_histograms=Tru self.best_pm_phi1_std = x_std self.best_pm_phi2_std = y_std - return [pm_x_cen, pm_y_cen, x_std, y_std] + return (pm_x_cen, pm_y_cen, x_std, y_std) + + +############################################################################### -def colorbar(mappable): +def _colorbar(mappable: ScalarMappable) -> Colorbar: + """Add a colorbar to the figure.""" ax = mappable.axes fig = ax.figure divider = make_axes_locatable(ax) diff --git a/src/cats/star/__init__.py b/src/cats/star/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/cats/star/star.py b/src/cats/star/star.py deleted file mode 100644 index 4693248..0000000 --- a/src/cats/star/star.py +++ /dev/null @@ -1,132 +0,0 @@ -# import numpy as np -# from astropy.coordinates import SkyCoord, CoordFrame - -# #notes on data types -# dtypes_dict = { -# #identifiers -# 'sourceID': 'u8', #Gaia DR3 source id, unsigned integer -# 'sourceID_version': 'u2', #integer representing version of gaia catalog that the source ID refers to -# 'streamID': 's10', #10-char string? check galstreams -# 'crossmatches':dict(), #dictionary to store IDs in other surveys - -# #phase space -# 'w': {}#dict of phase space coordinates -# 'w_uncert': {}#dictionary of uncertainties per coordinate -# #todo: add function to convert w to SkyCoord -# #todo: bring in pyia to account for transformations to uncertainties - -# #precompute and store phi1 and phi2 for a desired coordinate frame -# 'phi1': 'f8' -# 'phi2': 'f8' -# 'rotation': CoordFrame() #placeholder - -# #flag for variable stars -# 'variability': 'u2', #0 if not variable, 1 if variable - -# #photometry from Gaia used for selections in pawprint -# 'phot_g_mean_mag': 'f4', -# 'phot_rp_mean_mag': 'f4', -# 'phot_g_mean_mag_error': 'f4', -# 'phot_rp_mean_mag_error': 'f4', - -# #TODO WG3: how to store extinction? - -# #consensus chemistry -# 'feh_logeps': 'f4', #sun-independent value -# 'feh': 'f4', #iron abundance as [Fe/H] -# 'feh_solar': 'f4', #solar value of [Fe/H] for this star - -# 'alpha_logeps': 'f4', #sun-independent value -# 'alpha_fe': 'f4', #alpha abundance as [alpha/Fe] -# 'alpha_solar': 'f4', #solar value of [alpha/H] for this star - - -# #references: assumes that sky position, photometry, and PM are from Gaia -# 'refs': { -# 'distance': ['s19'] #ADS bibcode (or pointer to docs) for distance measurement; can be a list -# 'rv': ['s19'] -# 'feh': ['s19'] -# 'alpha': ['s19'] -# 'variability': ['s19'] -# 'extinction': ['s19'] -# } - -# #TODO: membership likelihoods - -# } - -# def get_phase_space(): -# '''queries gaia to initialize SkyCoords for phase-space position and uncertainty''' - -# def get_gaia_photometry(): -# '''queries gaia to get photometry and uncertainties, returned as 32 bit floats''' - -# def get_abundances(): -# '''queries ancillary data tables for spectroscopy, probably given some options that perhaps user can set''' - - -# class Star(dict): -# def __init__(self, streamID): -# #starclass -# self.data = StarData(streamID) -# self.derived = StarDerived(streamID) - - -# class StarData(dict): -# '''dictionary class to store __measured__ attributes for one star in the catalog''' -# def __init__(self, streamID): - -# self.sourceID_version = np.uint(3) #default for now is DR3 -# self.streamID = np.str_(streamID) - -# self.sourceID = load_stream(self.streamID) #read from adrian's initial files - -# self.nstars = len(self.sourceID) - -# self.crossmatches = {} - -# self.w, self.w_uncert = phasespace_to_skycoords() #TODO:function to return sky coordinates and uncertainties in two skycoord objects by querying Gaia - -# #flexible magnitudes - pin down standardised naming convention -# #my proposal: [survey]_[filter] -# #uncertainties and extinctions are specified by the same tags -# self.mags = {'gaia_g':np.array(nstars,dtype='f4'),'gaia_rp':np.array(nstars,dtype='f4'), } -# self.mag_uncert = {'gaia_g':np.array(nstars,dtype='f4'),'gaia_rp':np.array(nstars,dtype='f4'), } -# self.ext = {'gaia_g':np.array(nstars,dtype='f4'),'gaia_rp':np.array(nstars,dtype='f4'), } - - -# self.variability = np.array(nstars,dtype='u2') - -# self.feh = np.masked_array(nstars,dtype='f4') -# self.feh_logeps = np.masked_array(nstars,dtype='f4') -# self.feh_solar = np.masked_array(nstars,dtype='f4') -# self.alpha_logeps = np.masked_array(nstars,dtype='f4') -# self.alpha_fe = np.masked_array(nstars,dtype='f4') -# self.alpha_solar = np.masked_array(nstars,dtype='f4') - -# self.refs = { -# 'distance': np.array(nstars,dtype='s19') #ADS bibcode (or pointer to doi) for distance measurement; can be a list -# 'rv': np.array(nstars,dtype='s19') -# 'feh': np.array(nstars,dtype='s19') -# 'alpha': np.array(nstars,dtype='s19') -# 'variability': np.array(nstars,dtype='s19') -# } - -# #some stuff could be read directly and stored -# if get_gaia: -# get_gaia_photometry(self) #load gaia photometry in from catalog -# get_abundances(self) #load abundances from detailed spectroscopoc tables - -# class StarDerived(dict): -# '''class to store derived attributes''' -# ... - - -# def _inside_poly(data, vertices): -# '''This function takes a list of points (data) and returns a boolean mask that is True for all points inside the polygon defined by vertices''' -# return mpl_path(vertices).contains_points(data) - -# def makeMask(self, pawprint, what): -# '''take in some data and return masks for stuff in the pawprint (basically by successively applying _inside_poly)''' -# #returns mask with same dimension as data -# ... diff --git a/tests_WIP/__init__.py b/tests_WIP/__init__.py new file mode 100644 index 0000000..3f9d828 --- /dev/null +++ b/tests_WIP/__init__.py @@ -0,0 +1 @@ +"""WIP tests.""" diff --git a/src/cats/cmd/tests/test_GD1.py b/tests_WIP/cmd/GD1.py similarity index 88% rename from src/cats/cmd/tests/test_GD1.py rename to tests_WIP/cmd/GD1.py index e61e48d..36517b6 100644 --- a/src/cats/cmd/tests/test_GD1.py +++ b/tests_WIP/cmd/GD1.py @@ -1,9 +1,11 @@ +"""GD-1 test script for CMD fitting.""" + from __future__ import annotations import astropy.table as at -from CMD import Isochrone -from cats.pawprint.pawprint import Footprint2D, Pawprint +from cats.cmd import Isochrone +from cats.pawprint import Footprint2D, Pawprint # Note: if already loaded, we can just write: fn = "/Users/Tavangar/CATS_workshop/cats/data/joined-GD-1.fits" diff --git a/tests_WIP/cmd/__init__.py b/tests_WIP/cmd/__init__.py new file mode 100644 index 0000000..1fd3309 --- /dev/null +++ b/tests_WIP/cmd/__init__.py @@ -0,0 +1 @@ +"""WIP cmd tests.""" diff --git a/src/cats/cmd/tests/gd1_testcmd.png b/tests_WIP/cmd/gd1_testcmd.png similarity index 100% rename from src/cats/cmd/tests/gd1_testcmd.png rename to tests_WIP/cmd/gd1_testcmd.png diff --git a/src/cats/cmd/tests/test_run_GD-1.py b/tests_WIP/cmd/run_GD-1.py similarity index 95% rename from src/cats/cmd/tests/test_run_GD-1.py rename to tests_WIP/cmd/run_GD-1.py index 5ad332b..5e501c5 100644 --- a/src/cats/cmd/tests/test_run_GD-1.py +++ b/tests_WIP/cmd/run_GD-1.py @@ -1,3 +1,5 @@ +"""GD-1 test script for CMD fitting.""" + from __future__ import annotations import astropy.table as at @@ -7,6 +9,7 @@ def main() -> int: + """GD-1 test script for CMD fitting.""" fn = "/Users/Tavangar/CATS_workshop/cats/data/joined-GD-1.fits" cat = at.Table.read(fn) diff --git a/src/cats/cmd/tests/test_run_Jhelum.py b/tests_WIP/cmd/run_Jhelum.py similarity index 95% rename from src/cats/cmd/tests/test_run_Jhelum.py rename to tests_WIP/cmd/run_Jhelum.py index cd759b0..0b5a357 100644 --- a/src/cats/cmd/tests/test_run_Jhelum.py +++ b/tests_WIP/cmd/run_Jhelum.py @@ -1,3 +1,5 @@ +"""Jhelum test script.""" + from __future__ import annotations import astropy.table as at @@ -5,6 +7,7 @@ def main() -> int: + """Jhelum test script.""" fn = "./joined-Jhelum.fits" cat = at.Table.read(fn) diff --git a/src/cats/cmd/tests/test_run_Pal5.py b/tests_WIP/cmd/run_Pal5.py similarity index 94% rename from src/cats/cmd/tests/test_run_Pal5.py rename to tests_WIP/cmd/run_Pal5.py index 27b743e..6f11345 100644 --- a/src/cats/cmd/tests/test_run_Pal5.py +++ b/tests_WIP/cmd/run_Pal5.py @@ -1,3 +1,5 @@ +"""Pal5 test script.""" + from __future__ import annotations import astropy.table as at @@ -5,6 +7,7 @@ def main() -> int: + """Pal5 test script.""" fn = "./joined-Pal5.fits" cat = at.Table.read(fn) diff --git a/tests_WIP/pawprint/__init__.py b/tests_WIP/pawprint/__init__.py new file mode 100644 index 0000000..eb52a8d --- /dev/null +++ b/tests_WIP/pawprint/__init__.py @@ -0,0 +1 @@ +"""WIP pawprint tests.""" diff --git a/src/cats/pawprint/tests/test_mwe.py b/tests_WIP/pawprint/mwe.py similarity index 89% rename from src/cats/pawprint/tests/test_mwe.py rename to tests_WIP/pawprint/mwe.py index 4aa7d83..1cc9f3b 100644 --- a/src/cats/pawprint/tests/test_mwe.py +++ b/tests_WIP/pawprint/mwe.py @@ -1,3 +1,5 @@ +"""WIP tests.""" + from __future__ import annotations import astropy.units as u @@ -38,9 +40,8 @@ on = stars.makeMask(pawprint, what="sky.stream") # is a function of starlist ax.plot(stars.ra[on], stars.dec[on], ".", ms=2.5, color="C0") -# Create a new polygon footprint off-stream, with a given offset and width, and select field points inside it -# -# off_poly = mwsts[st].create_sky_polygon_footprint_from_track(width=1.*u.deg, phi2_offset=3.5*u.deg) +# Create a new polygon footprint off-stream, with a given offset and width, and +# select field points inside it. off = stars.makeMask(pawprint, what="sky.background") # Plot the off-stream polygon footprint and points selected inside it ax.plot(