Skip to content

Commit

Permalink
fix : regression embeddings sizes
Browse files Browse the repository at this point in the history
  • Loading branch information
Optimox authored and pathoumieu committed Apr 14, 2020
1 parent f83ffad commit fd5eb31
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 7 deletions.
2 changes: 1 addition & 1 deletion census_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
"source": [
"url = \"https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data\"\n",
"dataset_name = 'census-income'\n",
- "out = Path(os.getcwd().rsplit(\"/\", 1)[0]+'/data/'+dataset_name+'.csv')"
+ "out = Path(os.getcwd()+'/data/'+dataset_name+'.csv')"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion forest_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
"url = \"https://archive.ics.uci.edu/ml/machine-learning-databases/covtype/covtype.data.gz\"\n",
"dataset_name = 'forest-cover-type'\n",
"tmp_out = Path(os.getcwd().rsplit(\"/\", 1)[0]+'/data/'+dataset_name+'.gz')\n",
- "out = Path(os.getcwd().rsplit(\"/\", 1)[0]+'/data/'+dataset_name+'.csv')"
+ "out = Path(os.getcwd()+'/data/'+dataset_name+'.csv')"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion pytorch_tabnet/tab_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,7 +718,7 @@ def train_epoch(self, train_loader):
y_preds = []
ys = []
total_loss = 0
- feature_importances_ = np.zeros((self.input_dim))
+ feature_importances_ = np.zeros((self.network.post_embed_dim))

for data, targets in train_loader:
batch_outs = self.train_batch(data, targets)
Expand Down
12 changes: 8 additions & 4 deletions regression_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
"source": [
"url = \"https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data\"\n",
"dataset_name = 'census-income'\n",
- "out = Path(os.getcwd().rsplit(\"/\", 1)[0]+'/data/'+dataset_name+'.csv')"
+ "out = Path(os.getcwd()+'/data/'+dataset_name+'.csv')"
]
},
{
Expand Down Expand Up @@ -125,7 +125,10 @@
"\n",
"cat_idxs = [ i for i, f in enumerate(features) if f in categorical_columns]\n",
"\n",
- "cat_dims = [ categorical_dims[f] for i, f in enumerate(features) if f in categorical_columns]\n"
+ "cat_dims = [ categorical_dims[f] for i, f in enumerate(features) if f in categorical_columns]\n",
+ "\n",
+ "# define your embedding sizes : here just a random choice\n",
+ "cat_emb_dim = [5, 4, 3, 6, 2, 2, 1, 10]"
]
},
{
Expand All @@ -141,7 +144,7 @@
"metadata": {},
"outputs": [],
"source": [
- "clf = TabNetRegressor()"
+ "clf = TabNetRegressor(cat_dims=cat_dims, cat_emb_dim=cat_emb_dim, cat_idxs=cat_idxs)"
]
},
{
Expand Down Expand Up @@ -178,7 +181,8 @@
"clf.fit(\n",
" X_train=X_train, y_train=y_train,\n",
" X_valid=X_valid, y_valid=y_valid,\n",
- " max_epochs=1000, patience=50,\n",
+ " max_epochs=1000,\n",
+ " patience=50,\n",
" batch_size=1024, virtual_batch_size=128\n",
") "
]
Expand Down

0 comments on commit fd5eb31

Please sign in to comment.