Skip to content

Commit

Permalink
Better filenames when using the high level API (#256)
Browse files Browse the repository at this point in the history
* better filenames and nested dir structure

* changelog

* fix tests
  • Loading branch information
samvanstroud committed Mar 7, 2024
1 parent 5865133 commit 4adc5c0
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 116 deletions.
1 change: 1 addition & 0 deletions changelog.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Changelog

### [Latest]
- Better filenames when using HLAPI [!256](https://github.com/umami-hep/puma/pull/256)
- Switch to MkDocs for documentation [!254](https://github.com/umami-hep/puma/pull/254)
- Fixed minor bug in HF vertex merging [!255](https://github.com/umami-hep/puma/pull/255)
- Fixed removal of reconstructed PVs and modified which vertexing plots are produced given a jet flavour [!253](https://github.com/umami-hep/puma/pull/253)
Expand Down
131 changes: 74 additions & 57 deletions puma/hlplots/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy as np
from ftag import Cuts, Flavour, Flavours
from ftag.hdf5 import H5Reader
from matplotlib.figure import Figure

from puma import (
Histogram,
Expand Down Expand Up @@ -60,6 +61,7 @@ def __post_init__(self):
"peff": self.plot_var_perf,
"scan": self.plot_fraction_scans,
}
self.saved_plots = []

def set_signal(self, signal: Flavour):
if isinstance(signal, str):
Expand All @@ -76,6 +78,14 @@ def set_signal(self, signal: Flavour):
else:
raise ValueError(f"Unsupported signal class {self.signal}.")

@property
def sig_str(self):
    """Return the signal name with a trailing ``jets`` removed.

    The signal is converted with ``str``; when that text ends in ``jets``
    the ending is dropped (e.g. ``bjets`` becomes ``b``), otherwise the
    text is returned unchanged.

    Returns
    -------
    str
        Short signal label used when building output file names.
    """
    suffix = "jets"
    sig = str(self.signal)
    if sig.endswith(suffix):
        # Slice by the suffix length so only the trailing match is removed.
        return sig[: -len(suffix)]
    return sig

@property
def flavours(self):
"""Return a list of all flavours.
Expand Down Expand Up @@ -241,25 +251,34 @@ def __getitem__(self, tagger_name: str):
"""
return self.taggers[tagger_name]

def get_filename(self, plot_name: str, suffix: str | None = None):
"""Get output name.
def save(
self, plot: Figure, plot_type: str, base: str | None = None, suffix: str | None = None
):
"""Save the plot and record its output file path.
Parameters
----------
plot_name : str
plot name
plot : Figure
Matplotlib figure to save.
plot_type : str
Plots of the same type are saved in the same directory.
base : str, optional
Base filename, modified by this function. Falls back to plot_type when not given.
suffix : str, optional
suffix to add to output name, by default None
Returns
-------
str
output name
Suffix to add to the filename, by default None
"""
base = f"{self.sample}_{self.signal}_{plot_name}"
tag_str = f"{self.sig_str}tag"
out_dir = self.output_dir / tag_str / plot_type
out_dir.mkdir(parents=True, exist_ok=True)
if not base:
base = plot_type
fname = f"{self.sample}_{tag_str}"
fname += f"_{base}"
if suffix:
base += f"_{suffix}"
return Path(self.output_dir / base).with_suffix(f".{self.extension}")
fname += f"_{suffix}"
fpath = out_dir / f"{fname}.{self.extension}"
plot.savefig(fpath)
self.saved_plots.append(fpath)

def plot_probs(
self,
Expand All @@ -280,7 +299,7 @@ def plot_probs(

# group by output probability
for flav_prob in flavours:
histo = HistogramPlot(
hist = HistogramPlot(
n_ratio_panels=1,
xlabel=flav_prob.px,
ylabel="Normalised number of jets",
Expand All @@ -294,7 +313,7 @@ def plot_probs(
for i, tagger in enumerate(self.taggers.values()):
tagger_labels.append(tagger.label if tagger.label else tagger.name)
for flav_class in flavours:
histo.add(
hist.add(
Histogram(
tagger.probs(flav_prob, flav_class),
ratio_group=flav_class,
Expand All @@ -305,17 +324,17 @@ def plot_probs(
reference=tagger.reference,
)

histo.draw()
histo.make_linestyle_legend(
hist.draw()
hist.make_linestyle_legend(
linestyles=line_styles,
labels=tagger_labels,
bbox_to_anchor=(0.55, 1),
)
histo.savefig(self.get_filename(f"probs_{flav_prob.px}", suffix))
self.save(hist, "prob", flav_prob.px, suffix)

# group by flavour
for flav_class in flavours:
histo = HistogramPlot(
hist = HistogramPlot(
n_ratio_panels=1,
xlabel=flav_class.label,
ylabel="Normalised number of jets",
Expand All @@ -329,7 +348,7 @@ def plot_probs(
for i, tagger in enumerate(self.taggers.values()):
tagger_labels.append(tagger.label if tagger.label else tagger.name)
for flav_prob in flavours:
histo.add(
hist.add(
Histogram(
tagger.probs(flav_prob, flav_class),
ratio_group=flav_prob,
Expand All @@ -340,13 +359,13 @@ def plot_probs(
reference=tagger.reference,
)

histo.draw()
histo.make_linestyle_legend(
hist.draw()
hist.make_linestyle_legend(
linestyles=line_styles,
labels=tagger_labels,
bbox_to_anchor=(0.55, 1),
)
histo.savefig(self.get_filename(f"probs_{flav_class}", suffix))
self.save(hist, "prob", flav_class, suffix)

def plot_discs(
self,
Expand Down Expand Up @@ -387,7 +406,7 @@ def plot_discs(
}
if kwargs is not None:
hist_defaults.update(kwargs)
histo = HistogramPlot(**hist_defaults)
hist = HistogramPlot(**hist_defaults)

tagger_labels = []
for i, tagger in enumerate(self.taggers.values()):
Expand All @@ -403,9 +422,9 @@ def plot_discs(
wp_cuts.append(cut)
wp_labels.append(label)

histo.draw_vlines(wp_cuts, labels=wp_labels, linestyle=line_styles[i])
hist.draw_vlines(wp_cuts, labels=wp_labels, linestyle=line_styles[i])
for flav in self.flavours:
histo.add(
hist.add(
Histogram(
discs[tagger.is_flav(flav)],
ratio_group=flav,
Expand All @@ -416,13 +435,13 @@ def plot_discs(
reference=tagger.reference,
)
tagger_labels.append(tagger.label if tagger.label else tagger.name)
histo.draw()
histo.make_linestyle_legend(
hist.draw()
hist.make_linestyle_legend(
linestyles=line_styles,
labels=tagger_labels,
bbox_to_anchor=(0.55, 1),
)
histo.savefig(self.get_filename("disc", suffix))
self.save(hist, "disc", suffix=suffix)

def plot_rocs(
self,
Expand Down Expand Up @@ -457,7 +476,7 @@ def plot_rocs(
}
if roc_kwargs is not None:
roc_plot_args.update(roc_kwargs)
plot_roc = RocPlot(**roc_plot_args)
roc = RocPlot(**roc_plot_args)

for tagger in self.taggers.values():
discs = tagger.discriminant(self.signal)
Expand All @@ -467,7 +486,7 @@ def plot_rocs(
discs[tagger.is_flav(background)],
sig_effs,
)
plot_roc.add_roc(
roc.add_roc(
Roc(
sig_effs,
rej,
Expand All @@ -482,11 +501,10 @@ def plot_rocs(

# setting which flavour rejection ratio is drawn in which ratio panel
for i, background in enumerate(self.backgrounds):
plot_roc.set_ratio_class(i + 1, background)
roc.set_ratio_class(i + 1, background)

plot_roc.draw()
plot_name = self.get_filename("roc", suffix)
plot_roc.savefig(plot_name)
roc.draw()
self.save(roc, "roc", suffix=suffix)

def plot_var_perf( # pylint: disable=too-many-locals
self,
Expand Down Expand Up @@ -617,14 +635,14 @@ def plot_var_perf( # pylint: disable=too-many-locals
".", "p"
)

plot_details = f"{self.signal}_eff_vs_{perf_var}_{plot_base}_{wp_disc}"
plot_suffix = f"{suffix}_" if suffix else ""
plot_sig_eff.savefig(self.get_filename(plot_details, plot_suffix))
fname = f"{self.sig_str}eff_vs_{perf_var}_{plot_base}_{wp_disc}"
suffix = f"{suffix}_" if suffix else ""
self.save(plot_sig_eff, "profile", fname, suffix)

for i, background in enumerate(self.backgrounds):
for i, bkg in enumerate(self.backgrounds):
plot_bkg[i].draw()
plot_details = f"{background}_rej_vs_{perf_var}_{plot_base}_{wp_disc}"
plot_bkg[i].savefig(self.get_filename(plot_details, plot_suffix))
fname = f"{str(bkg)[0]}rej_vs_{perf_var}_{plot_base}_{wp_disc}"
self.save(plot_bkg[i], "profile", fname, suffix)

def plot_flat_rej_var_perf(
self,
Expand Down Expand Up @@ -654,18 +672,18 @@ def plot_flat_rej_var_perf(
**kwargs : kwargs
key word arguments for `puma.VarVsEff`
"""
assert all(
b.name in fixed_rejections for b in self.backgrounds
), "Not all backgrounds have a fixed rejection"
if inv_bkg := set(fixed_rejections.keys()) - {str(b) for b in self.backgrounds}:
raise ValueError(f"Invalid background flavours: {inv_bkg}")
if "disc_cut" in kwargs:
raise ValueError("disc_cut should not be set for this plot")
if "working_point" in kwargs:
raise ValueError("working_point should not be set for this plot")
backgrounds = [Flavours[b] for b in fixed_rejections]
plot_bkg = []
for background in self.backgrounds:
for bkg in backgrounds:
modified_second_tag = (
f"{self.atlas_second_tag}\nFlat {background.rej_str} of"
f" {fixed_rejections[background.name]} per bin"
f"{self.atlas_second_tag}\nFlat {bkg.rej_str} of"
f" {fixed_rejections[bkg.name]} per bin"
)
plot_bkg.append(
VarVsEffPlot(
Expand All @@ -682,8 +700,8 @@ def plot_flat_rej_var_perf(
for tagger in self.taggers.values():
discs = tagger.discriminant(self.signal)
is_signal = tagger.is_flav(self.signal)
for i, background in enumerate(self.backgrounds):
is_bkg = tagger.is_flav(background)
for i, bkg in enumerate(backgrounds):
is_bkg = tagger.is_flav(bkg)

assert perf_var in tagger.perf_vars, f"{perf_var} not in tagger {tagger.name} data!"

Expand All @@ -701,21 +719,21 @@ def plot_flat_rej_var_perf(
disc_bkg=discs[is_signal],
label=tagger.label,
colour=tagger.colour,
working_point=1 / fixed_rejections[background.name],
working_point=1 / fixed_rejections[bkg.name],
flat_per_bin=True,
**kwargs,
),
reference=tagger.reference,
)

plot_suffix = f"_{suffix}" if suffix else ""
for i, background in enumerate(self.backgrounds):
suffix = f"_{suffix}" if suffix else ""
for i, bkg in enumerate(backgrounds):
plot_bkg[i].draw()
if h_line:
plot_bkg[i].draw_hline(h_line)
plot_details = f"{self.signal}_eff_vs_{perf_var}_"
plot_base = f"{background}_rej_flat_{int(fixed_rejections[background.name])}"
plot_bkg[i].savefig(self.get_filename(plot_details + plot_base, plot_suffix))
details = f"{self.sig_str}eff_vs_{perf_var}_"
base = f"{str(bkg)[0]}rej_flat_{int(fixed_rejections[bkg.name])}"
self.save(plot_bkg[i], "profile", details + base, suffix)

def plot_fraction_scans(
self,
Expand Down Expand Up @@ -751,7 +769,6 @@ def plot_fraction_scans(
raise ValueError("Only two background flavours are supported")

frac = "fc" if self.signal == Flavours.bjets else "fb"

back_str = "_".join([f.name for f in backgrounds])
plot_name = f"{frac}_scan"
suffix = combine_suffixes([f"{back_str}_eff{int(efficiency * 100)}", suffix])
Expand Down Expand Up @@ -830,7 +847,7 @@ def plot_fraction_scans(
plot.ylabel = backgrounds[1].rej_str
# Draw and save the plot
plot.draw()
plot.savefig(self.get_filename(plot_name, suffix))
self.save(plot, "scan", plot_name, suffix)

def make_plot(self, plot_type, kwargs):
"""Make a plot.
Expand Down
5 changes: 1 addition & 4 deletions puma/hlplots/yuma.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import argparse
import os
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
Expand Down Expand Up @@ -91,7 +90,7 @@ def get_results(self):
sample_path = self.base_path / sample_path

# Instantiate the results object
results = Results(**kwargs)
results = Results(**kwargs, output_dir=self.plot_dir_final)

# Add taggers to results, then bulk load
for key, t in self.taggers_config.items():
Expand Down Expand Up @@ -175,8 +174,6 @@ def main(args=None):
logger.info(f"Plotting signal {signal}")
yuma.signal = signal
yuma.results.set_signal(signal)
yuma.results.output_dir = yuma.plot_dir_final / f"{signal[0]}tagging"
os.makedirs(yuma.results.output_dir, exist_ok=True)
yuma.make_plots(plots)


Expand Down
Loading

0 comments on commit 4adc5c0

Please sign in to comment.