diff --git a/configs/config_ds.yml b/configs/config_ds.yml index ad443d6..381ef78 100644 --- a/configs/config_ds.yml +++ b/configs/config_ds.yml @@ -32,6 +32,7 @@ stages: valid_csv: "./csv/kfold5/valid_0.csv" root: "/raid/data/kaggle/recursion-cellular-image-classification/" sites: [1] + site_mode: 'random' channels: [1, 2, 3, 4, 5, 6] # stage0: diff --git a/src/dataset.py b/src/dataset.py index ca696e0..df68f4c 100644 --- a/src/dataset.py +++ b/src/dataset.py @@ -223,25 +223,22 @@ def __init__(self, sites=[1], mode='train', channels=[1, 2, 3, 4, 5, 6], + site_mode='random' ): print("Channels ", channels) print("sites ", sites) print(csv_file) + assert site_mode in ['random', 'two', 'one'] df = pd.read_csv(csv_file, nrows=None) - if mode == 'train': + if mode == 'train' and site_mode == 'two': df["site"] = 1 df_copy = df.copy() df_copy["site"] = 2 df = pd.concat([df, df_copy], axis=0).reset_index(drop=True) + if not 'sirna' in df.columns: df['sirna'] = 0 - # if "train" in csv_file: - # md = combine_metadata(base_path=root) - # md = md[(md.dataset == mode) & (md.site == 1)] - # md = md[md.experiment.isin(df.experiment)] - # md.sirna = md.sirna.astype(np.int) - # df = md - # + self.pixel_stat = pd.read_csv(os.path.join(root, "pixel_stats.csv")) self.stat_dict = {} for experiment, plate, well, site, channel, mean, std in zip(self.pixel_stat.experiment, @@ -269,13 +266,13 @@ def __init__(self, self.stat_dict[experiment][plate][well][site][channel]["mean"] = mean / 255 self.stat_dict[experiment][plate][well][site][channel]["std"] = std / 255 - self.transform = transform self.mode = mode self.channels = channels if mode == 'train': - self.sites = df["site"].values - else: + if site_mode == 'two': + self.sites = df["site"].values + else: # Test only use one by one site self.sites = sites self.experiments = df['experiment'].values @@ -284,6 +281,7 @@ def __init__(self, self.labels = df['sirna'].values self.root = root + self.site_mode = site_mode def __len__(self): return len(self.experiments) @@ -297,17 +295,19 @@ def __getitem__(self, idx): channel_paths = [] if self.mode == 'train': - # if np.random.rand() < 0.5: - # sites = [1] - # else: - # sites = [2] - sites = self.sites[idx] - sites = [sites] - else: + if self.site_mode == 'random': + if np.random.rand() < 0.5: + sites = [1] + else: + sites = [2] + elif self.site_mode == 'two': + sites = self.sites[idx] + sites = [sites] + else: # Only one site + sites = self.sites + else: # Only one site for test sites = self.sites - # sites = self.sites - for site in sites: for channel in self.channels: path = image_path( @@ -332,8 +332,6 @@ def __getitem__(self, idx): mean_arr.append(mean) image = load_images_as_tensor(channel_paths, dtype=np.float32) - # image = convert_tensor_to_rgb(image) - # image = image / 255 if self.transform: image = self.transform(image=image)['image'] diff --git a/src/experiment.py b/src/experiment.py index 0b57214..1e20155 100644 --- a/src/experiment.py +++ b/src/experiment.py @@ -56,6 +56,7 @@ def get_datasets(self, stage: str, **kwargs): pseudo_csv = kwargs.get('pseudo_csv', None) sites = kwargs.get('sites', [1]) channels = kwargs.get('channels', [1, 2, 3, 4, 5, 6]) + site_mode = kwargs.get('site_mode', 'random') root = kwargs.get('root', None) if train_csv: @@ -66,7 +67,8 @@ def get_datasets(self, stage: str, **kwargs): transform=transform, mode='train', sites=sites, - channels=channels + channels=channels, + site_mode=site_mode ) if pseudo_csv: @@ -76,7 +78,8 @@ def get_datasets(self, stage: str, **kwargs): transform=transform, mode='test', sites=sites, - channels=channels + channels=channels, + site_mode=site_mode ) train_set = ConcatDataset([