Skip to content

Commit

Permalink
style: formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
boasvdp committed Jul 26, 2024
1 parent 70b5895 commit 8c00bc4
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 27 deletions.
13 changes: 9 additions & 4 deletions workflow/rules/clustering.smk
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
if config["clustering_preset"] == "mycobacterium_tuberculosis":

rule copy_or_touch_list_excluded_samples:
output:
temp(OUT + "/previous_list_excluded_samples.tsv"),
params:
previous_list = PREVIOUS_CLUSTERING + "/list_excluded_samples.tsv"
previous_list=PREVIOUS_CLUSTERING + "/list_excluded_samples.tsv",
shell:
"""
if [ -f {params.previous_list} ]
Expand All @@ -13,11 +14,13 @@ else
touch {output}
fi
"""

rule list_excluded_samples:
input:
seq_exp_json = expand(INPUT + "/mtb_typing/seq_exp_json/{sample}.json", sample=SAMPLES),
exclude_list = OUT + "/previous_list_excluded_samples.tsv",
seq_exp_json=expand(
INPUT + "/mtb_typing/seq_exp_json/{sample}.json", sample=SAMPLES
),
exclude_list=OUT + "/previous_list_excluded_samples.tsv",
output:
OUT + "/list_excluded_samples.tsv",
log:
Expand Down Expand Up @@ -45,7 +48,9 @@ python workflow/scripts/list_excluded_samples.py \
--coverage-threshold {params.coverage_threshold} \
2>&1> {log}
"""

else:

rule touch_list_excluded_samples:
output:
temp(OUT + "/list_excluded_samples.tsv"),
Expand Down
29 changes: 21 additions & 8 deletions workflow/scripts/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def read_data(distances, previous_clustering):
)
return df_distances, df_previous_clustering


@timing
def clean_sample_columns(df, cols, fixed_string):
"""
Expand Down Expand Up @@ -119,6 +120,7 @@ def exclude_samples(df_distances, exclude_list):
]
return df_distances


@timing
def emit_and_save_critical_warning(message, output_path):
"""
Expand All @@ -140,6 +142,7 @@ def emit_and_save_critical_warning(message, output_path):
with open(output_path, "a") as f:
f.write(message + "\n")


@timing
def get_df_nodes(df_distances, df_previous_clustering):
"""
Expand Down Expand Up @@ -369,7 +372,10 @@ def infer_clusters(graph, merged_cluster_separator, warnings_path):
set_curated_clusters = enlist_clusters(subgraph, "curated_cluster")
set_final_clusters = enlist_clusters(subgraph, "final_cluster")
if len(set_curated_clusters) > 1:
emit_and_save_critical_warning(f"WARNING: Curated clusters {set_curated_clusters} have merged!", warnings_path)
emit_and_save_critical_warning(
f"WARNING: Curated clusters {set_curated_clusters} have merged!",
warnings_path,
)
inferred_cluster = construct_merged_cluster_name(
set_curated_clusters, merged_cluster_separator
)
Expand All @@ -379,7 +385,10 @@ def infer_clusters(graph, merged_cluster_separator, warnings_path):
f"Cluster {inferred_cluster} is curated and not merged with others"
)
elif len(set_final_clusters) > 1:
emit_and_save_critical_warning(f"WARNING: Final clusters {set_final_clusters} have merged!", warnings_path)
emit_and_save_critical_warning(
f"WARNING: Final clusters {set_final_clusters} have merged!",
warnings_path,
)
inferred_cluster = construct_merged_cluster_name(
set_final_clusters, merged_cluster_separator
)
Expand Down Expand Up @@ -452,7 +461,9 @@ def main(args):
args.distances, args.previous_clustering
)

df_distances = clean_sample_columns(df_distances, ["sample1", "sample2"], "_contig1")
df_distances = clean_sample_columns(
df_distances, ["sample1", "sample2"], "_contig1"
)

if args.exclude_list:
df_distances = exclude_samples(df_distances, args.exclude_list)
Expand All @@ -463,7 +474,9 @@ def main(args):

G = create_graph(df_distances_filtered, df_nodes)

inferred_cluster_dict = infer_clusters(G, args.merged_cluster_separator, args.warnings_path)
inferred_cluster_dict = infer_clusters(
G, args.merged_cluster_separator, args.warnings_path
)

create_output(inferred_cluster_dict, df_previous_clustering, args.output)

Expand Down Expand Up @@ -493,14 +506,14 @@ def main(args):
default="|",
)
parser.add_argument(
"--exclude-list", type=Path, help="Path to list of samples to exclude from clustering"
"--exclude-list",
type=Path,
help="Path to list of samples to exclude from clustering",
)
parser.add_argument(
"--log", type=Path, help="Path to log file", default="cluster.log"
)
parser.add_argument(
"--warnings-path", type=Path, help="Path to warnings file"
)
parser.add_argument("--warnings-path", type=Path, help="Path to warnings file")
parser.add_argument(
"-v", "--verbose", action="count", default=0, help="Verbosity level"
)
Expand Down
41 changes: 26 additions & 15 deletions workflow/scripts/list_excluded_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,53 +5,64 @@
import datetime
import logging


def read_input_data(input_files):
    """Load per-sample QC JSON files into a single DataFrame.

    Each file is parsed as JSON and keyed by its stem (filename without
    extension), which becomes the ``sample`` column. The ``mean_coverage``
    field is coerced to float so numeric threshold comparisons work even
    when the JSON stores it as a string.

    :param input_files: iterable of pathlib.Path objects pointing to
        JSON files, one per sample (each must contain a ``mean_coverage`` key).
    :return: DataFrame with one row per sample, including a ``sample``
        column and a float ``mean_coverage`` column.
    """
    data = {}
    for file in input_files:
        with open(file) as f:
            data[file.stem] = json.load(f)
    df = pd.DataFrame.from_dict(data, orient="index").reset_index(names="sample")
    df["mean_coverage"] = df["mean_coverage"].astype(float)
    return df


def exclude_on_coverage(df, threshold):
    """Return the rows that should be EXCLUDED for low coverage.

    Note the inverted sense: the returned frame contains the failing
    samples (mean_coverage strictly below ``threshold``), not the passing
    ones — callers collect these into an exclusion list.

    :param df: DataFrame with a float ``mean_coverage`` column.
    :param threshold: minimum acceptable mean coverage.
    :return: copy of the rows with ``mean_coverage < threshold``.
    """
    df_copy = df.copy()
    return df_copy[df_copy["mean_coverage"] < threshold]


def exclude_on_pattern(df, pattern):
    """Return the rows whose sample name does NOT match ``pattern``.

    Inverted sense, like exclude_on_coverage: samples matching the
    inclusion pattern are kept by the pipeline, so this returns the
    non-matching rows destined for the exclusion list.

    :param df: DataFrame with a string ``sample`` column.
    :param pattern: inclusion pattern; treated as a REGEX by
        ``str.contains`` (regex metacharacters in it are interpreted).
    :return: copy of the rows where ``sample`` does not contain ``pattern``.
    """
    df_copy = df.copy()
    return df_copy[~df_copy["sample"].str.contains(pattern)]


def read_previous_exclude_list(file):
    """Read the exclusion list from a previous run, tolerating an empty file.

    A completely empty file (e.g. the touched placeholder produced on the
    first pipeline run) has no header, so pandas would raise on it;
    instead return an empty frame with the expected schema.

    :param file: path to a tab-separated file with columns
        ``sample``, ``reason``, ``date`` (or empty).
    :return: DataFrame with those three columns (possibly zero rows).
    """
    with open(file) as f:
        lines = f.readlines()
    if not lines:
        return pd.DataFrame(columns=["sample", "reason", "date"])
    return pd.read_csv(file, sep="\t")


def main(args):
    """Build the cumulative exclusion list and write it as TSV.

    Combines two exclusion criteria — mean coverage below the threshold
    and sample names not matching the inclusion pattern — stamps the new
    entries with the current local time, appends them to the exclusion
    list carried over from the previous run, and writes the result to
    ``args.output``.

    NOTE(review): a sample failing both criteria appears twice (once per
    reason), and entries are re-appended on every run without
    deduplication — presumably intentional as an audit trail; confirm.
    """
    df = read_input_data(args.input)
    df_coverage_excluded = exclude_on_coverage(df, args.coverage_threshold)
    df_coverage_excluded["reason"] = "low_coverage"
    df_pattern_excluded = exclude_on_pattern(df, args.inclusion_pattern)
    # "not_NLA": sample name did not match the inclusion pattern
    # (presumably non-Dutch/NLA samples — TODO confirm naming convention).
    df_pattern_excluded["reason"] = "not_NLA"
    df_excluded = pd.concat(
        [
            df_coverage_excluded[["sample", "reason"]],
            df_pattern_excluded[["sample", "reason"]],
        ]
    )
    # Timezone-naive local timestamp; used only as a human-readable audit field.
    df_excluded["date"] = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    df_previous_excluded = read_previous_exclude_list(args.previous_exclude_list)
    df_final = pd.concat([df_previous_excluded, df_excluded])
    df_final.to_csv(args.output, sep="\t", index=False)


if __name__ == "__main__":
    # CLI entry point: all arguments are required; paths are parsed as
    # pathlib.Path so downstream code can use .stem etc.
    parser = argparse.ArgumentParser()

    parser.add_argument("--input", type=Path, required=True, nargs="+")
    parser.add_argument("--previous-exclude-list", type=Path, required=True)
    parser.add_argument("--output", type=Path, required=True)
    parser.add_argument("--inclusion-pattern", type=str, required=True)
    parser.add_argument("--coverage-threshold", type=float, required=True)

    args = parser.parse_args()

    main(args)

0 comments on commit 8c00bc4

Please sign in to comment.