Skip to content

Commit

Permalink
customize the site mode
Browse files Browse the repository at this point in the history
  • Loading branch information
ngxbac committed Aug 1, 2019
1 parent df2b2f6 commit 0cf6263
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 24 deletions.
1 change: 1 addition & 0 deletions configs/config_ds.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ stages:
valid_csv: "./csv/kfold5/valid_0.csv"
root: "/raid/data/kaggle/recursion-cellular-image-classification/"
sites: [1]
site_mode: 'random'
channels: [1, 2, 3, 4, 5, 6]

# stage0:
Expand Down
42 changes: 20 additions & 22 deletions src/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,25 +223,22 @@ def __init__(self,
sites=[1],
mode='train',
channels=[1, 2, 3, 4, 5, 6],
site_mode='random'
):
print("Channels ", channels)
print("sites ", sites)
print(csv_file)
assert site_mode in ['random', 'two', 'one']
df = pd.read_csv(csv_file, nrows=None)
if mode == 'train':
if mode == 'train' and site_mode == 'two':
df["site"] = 1
df_copy = df.copy()
df_copy["site"] = 2
df = pd.concat([df, df_copy], axis=0).reset_index(drop=True)

if not 'sirna' in df.columns:
df['sirna'] = 0
# if "train" in csv_file:
# md = combine_metadata(base_path=root)
# md = md[(md.dataset == mode) & (md.site == 1)]
# md = md[md.experiment.isin(df.experiment)]
# md.sirna = md.sirna.astype(np.int)
# df = md
#

self.pixel_stat = pd.read_csv(os.path.join(root, "pixel_stats.csv"))
self.stat_dict = {}
for experiment, plate, well, site, channel, mean, std in zip(self.pixel_stat.experiment,
Expand Down Expand Up @@ -269,13 +266,13 @@ def __init__(self,
self.stat_dict[experiment][plate][well][site][channel]["mean"] = mean / 255
self.stat_dict[experiment][plate][well][site][channel]["std"] = std / 255


self.transform = transform
self.mode = mode
self.channels = channels
if mode == 'train':
self.sites = df["site"].values
else:
if site_mode == 'two':
self.sites = df["site"].values
else: # Test only use one by one site
self.sites = sites

self.experiments = df['experiment'].values
Expand All @@ -284,6 +281,7 @@ def __init__(self,
self.labels = df['sirna'].values

self.root = root
self.site_mode = site_mode

def __len__(self):
return len(self.experiments)
Expand All @@ -297,17 +295,19 @@ def __getitem__(self, idx):
channel_paths = []

if self.mode == 'train':
# if np.random.rand() < 0.5:
# sites = [1]
# else:
# sites = [2]
sites = self.sites[idx]
sites = [sites]
else:
if self.site_mode == 'random':
if np.random.rand() < 0.5:
sites = [1]
else:
sites = [2]
elif self.site_mode == 'two':
sites = self.sites[idx]
sites = [sites]
else: # Only one site
sites = self.sites
else: # Only one site for test
sites = self.sites

# sites = self.sites

for site in sites:
for channel in self.channels:
path = image_path(
Expand All @@ -332,8 +332,6 @@ def __getitem__(self, idx):
mean_arr.append(mean)

image = load_images_as_tensor(channel_paths, dtype=np.float32)
# image = convert_tensor_to_rgb(image)
# image = image / 255
if self.transform:
image = self.transform(image=image)['image']

Expand Down
7 changes: 5 additions & 2 deletions src/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def get_datasets(self, stage: str, **kwargs):
pseudo_csv = kwargs.get('pseudo_csv', None)
sites = kwargs.get('sites', [1])
channels = kwargs.get('channels', [1, 2, 3, 4, 5, 6])
site_mode = kwargs.get('site_mode', 'random')
root = kwargs.get('root', None)

if train_csv:
Expand All @@ -66,7 +67,8 @@ def get_datasets(self, stage: str, **kwargs):
transform=transform,
mode='train',
sites=sites,
channels=channels
channels=channels,
site_mode=site_mode
)

if pseudo_csv:
Expand All @@ -76,7 +78,8 @@ def get_datasets(self, stage: str, **kwargs):
transform=transform,
mode='test',
sites=sites,
channels=channels
channels=channels,
site_mode=site_mode
)

train_set = ConcatDataset([
Expand Down

0 comments on commit 0cf6263

Please sign in to comment.