diff --git a/src/dataset.py b/src/dataset.py index b3e223e..50c5c69 100644 --- a/src/dataset.py +++ b/src/dataset.py @@ -3,6 +3,7 @@ import cv2 import pandas as pd from torch.utils.data import Dataset +from tqdm import tqdm NUM_CLASSES = 1108 @@ -35,6 +36,13 @@ } +public_experiments = [ + "HEPG2-08", + "HUVEC-17", + "RPE-08", + "U2OS-04", +] + def load_image(path): image = cv2.imread(path, 0) @@ -175,6 +183,37 @@ def normalize(img, mean, std, max_pixel_value=255.0): return img +def _load_dataset(base_path, dataset, include_controls=True): + df = pd.read_csv(os.path.join(base_path, dataset + '.csv')) + if include_controls: + controls = pd.read_csv( + os.path.join(base_path, dataset + '_controls.csv')) + df['well_type'] = 'treatment' + df = pd.concat([controls, df], sort=True) + df['cell_type'] = df.experiment.str.split("-").apply(lambda a: a[0]) + df['dataset'] = dataset + dfs = [] + for site in (1, 2): + df = df.copy() + df['site'] = site + dfs.append(df) + res = pd.concat(dfs).sort_values( + by=['id_code', 'site']).set_index('id_code') + return res + + +def combine_metadata(base_path=None, + include_controls=True): + df = pd.concat( + [ + _load_dataset( + base_path, dataset, include_controls=include_controls) + for dataset in ['test', 'train'] + ], + sort=True) + return df + + class RecursionCellularSite(Dataset): def __init__(self, @@ -189,6 +228,13 @@ def __init__(self, print("sites ", sites) print(csv_file) df = pd.read_csv(csv_file, nrows=None) + # if "train" in csv_file: + # md = combine_metadata(base_path=root) + # md = md[(md.dataset == mode) & (md.site == 1)] + # md = md[md.experiment.isin(df.experiment)] + # md.sirna = md.sirna.astype(np.int) + # df = md + # self.pixel_stat = pd.read_csv(os.path.join(root, "pixel_stats.csv")) self.stat_dict = {} for experiment, plate, well, site, channel, mean, std in zip(self.pixel_stat.experiment,