forked from MIND-Lab/OCTIS
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_datasets.py
95 lines (68 loc) · 2.8 KB
/
test_datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
#!/usr/bin/env python
"""Tests for `octis` package."""
import pytest
from click.testing import CliRunner
from octis.evaluation_metrics.classification_metrics import F1Score
from octis.evaluation_metrics.coherence_metrics import *
from octis.dataset.dataset import Dataset
import os
from octis.preprocessing.preprocessing import Preprocessing
from octis.dataset.downloader import get_data_home, _pkl_filepath
@pytest.fixture
def root_dir():
    """Absolute path of the directory containing this test module."""
    this_file = os.path.abspath(__file__)
    return os.path.dirname(this_file)
@pytest.fixture
def data_dir(root_dir):
    """Path (with trailing slash) to the repo's preprocessed_datasets folder.

    Kept as plain string concatenation on purpose: downstream tests append
    further path fragments with '+', relying on the exact string shape.
    """
    return f"{root_dir}/../preprocessed_datasets/"
def test_preprocessing_custom_stops(data_dir):
    """Preprocess raw docs with a custom stopword list, then save and reload.

    Passes if the whole pipeline (preprocess -> save -> load) completes
    without raising; no value-level assertions are made.
    """
    source = data_dir + "/sample_texts/unprepr_docs.txt"
    preprocessor = Preprocessing(
        vocabulary=None,
        max_features=None,
        remove_punctuation=True,
        punctuation=".,?:",
        lemmatize=False,
        stopword_list=['am', 'are', 'this', 'that'],
        min_chars=2,
        min_words_docs=5,
        min_df=0.0001,
    )
    result = preprocessor.preprocess_dataset(documents_path=source)
    # Round-trip through disk: save the processed corpus, then load it back.
    result.save(data_dir + "/sample_texts/")
    result.load_custom_dataset_from_folder(data_dir + "/sample_texts")
def test_preprocessing_english_stops_split(data_dir):
    """Preprocess with the built-in English stopword list and split=False.

    Passes if preprocessing, saving, and reloading all complete without
    raising; no value-level assertions are made.
    """
    source = data_dir + "/sample_texts/unprepr_docs.txt"
    preprocessor = Preprocessing(
        vocabulary=None,
        max_features=None,
        remove_punctuation=True,
        lemmatize=False,
        stopword_list='english',
        split=False,
        min_chars=2,
        min_words_docs=1,
    )
    result = preprocessor.preprocess_dataset(documents_path=source)
    # Round-trip through disk: save the processed corpus, then load it back.
    result.save(data_dir + "/sample_texts/")
    result.load_custom_dataset_from_folder(data_dir + "/sample_texts")
def test_load_20ng():
    """Fetch 20NewsGroup twice: the first call populates the on-disk cache,
    the second reads it back; both must yield the full 16309-document corpus.
    """
    cache_file = _pkl_filepath(get_data_home(data_home=None), "20NewsGroup" + ".pkz")
    # Start from a clean slate so the first fetch cannot hit a stale cache.
    if os.path.exists(cache_file):
        os.remove(cache_file)

    fresh = Dataset()
    fresh.fetch_dataset("20NewsGroup")
    assert len(fresh.get_corpus()) == 16309
    assert len(fresh.get_labels()) == 16309
    # The fetch is expected to have written the pickle cache as a side effect.
    assert os.path.exists(cache_file)

    cached = Dataset()
    cached.fetch_dataset("20NewsGroup")
    assert len(cached.get_corpus()) == 16309
def test_load_M10():
    """The fetched M10 dataset must expose exactly 10 distinct labels."""
    m10 = Dataset()
    m10.fetch_dataset("M10")
    distinct_labels = set(m10.get_labels())
    assert len(distinct_labels) == 10
def test_partitions_fetch():
    """Fetched M10 must split into the expected first-two partition sizes."""
    m10 = Dataset()
    m10.fetch_dataset("M10")
    parts = m10.get_partitioned_corpus()
    # Only the first two partitions are pinned here (same as the custom-load test).
    expected_sizes = {0: 5847, 1: 1254}
    for index, size in expected_sizes.items():
        assert len(parts[index]) == size
def test_partitions_custom(data_dir):
    """Loading M10 from a local folder yields the same partition sizes as fetching."""
    local = Dataset()
    local.load_custom_dataset_from_folder(data_dir + "M10")
    parts = local.get_partitioned_corpus()
    # Sizes mirror test_partitions_fetch: local copy and remote fetch must agree.
    expected_sizes = {0: 5847, 1: 1254}
    for index, size in expected_sizes.items():
        assert len(parts[index]) == size
def test_fetch_encoding():
dataset = Dataset()
dataset.fetch_dataset('DBPedia_IT')