Skip to content

Commit

Permalink
update dataset to train 1139 classes
Browse files Browse the repository at this point in the history
  • Loading branch information
ngxbac committed Jul 28, 2019
1 parent e606c19 commit 5a53415
Showing 1 changed file with 46 additions and 0 deletions.
46 changes: 46 additions & 0 deletions src/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import cv2
import pandas as pd
from torch.utils.data import Dataset
from tqdm import tqdm

NUM_CLASSES = 1108

Expand Down Expand Up @@ -35,6 +36,13 @@
}


public_experiments = [
"HEPG2-08",
"HUVEC-17",
"RPE-08",
"U2OS-04",
]


def load_image(path):
image = cv2.imread(path, 0)
Expand Down Expand Up @@ -175,6 +183,37 @@ def normalize(img, mean, std, max_pixel_value=255.0):
return img


def _load_dataset(base_path, dataset, include_controls=True):
df = pd.read_csv(os.path.join(base_path, dataset + '.csv'))
if include_controls:
controls = pd.read_csv(
os.path.join(base_path, dataset + '_controls.csv'))
df['well_type'] = 'treatment'
df = pd.concat([controls, df], sort=True)
df['cell_type'] = df.experiment.str.split("-").apply(lambda a: a[0])
df['dataset'] = dataset
dfs = []
for site in (1, 2):
df = df.copy()
df['site'] = site
dfs.append(df)
res = pd.concat(dfs).sort_values(
by=['id_code', 'site']).set_index('id_code')
return res


def combine_metadata(base_path=None,
include_controls=True):
df = pd.concat(
[
_load_dataset(
base_path, dataset, include_controls=include_controls)
for dataset in ['test', 'train']
],
sort=True)
return df


class RecursionCellularSite(Dataset):

def __init__(self,
Expand All @@ -189,6 +228,13 @@ def __init__(self,
print("sites ", sites)
print(csv_file)
df = pd.read_csv(csv_file, nrows=None)
# if "train" in csv_file:
# md = combine_metadata(base_path=root)
# md = md[(md.dataset == mode) & (md.site == 1)]
# md = md[md.experiment.isin(df.experiment)]
# md.sirna = md.sirna.astype(np.int)
# df = md
#
self.pixel_stat = pd.read_csv(os.path.join(root, "pixel_stats.csv"))
self.stat_dict = {}
for experiment, plate, well, site, channel, mean, std in zip(self.pixel_stat.experiment,
Expand Down

0 comments on commit 5a53415

Please sign in to comment.