diff --git a/datasets/flue/README.md b/datasets/flue/README.md
new file mode 100644
index 00000000000..f308abdaab3
--- /dev/null
+++ b/datasets/flue/README.md
@@ -0,0 +1,233 @@
+---
+annotations_creators:
+- crowdsourced
+- machine-generated
+language_creators:
+- crowdsourced
+languages:
+- fr
+licenses:
+- unknown
+multilinguality:
+- monolingual
+size_categories:
+- 10K<n<100K
diff --git a/datasets/flue/flue.py b/datasets/flue/flue.py
new file mode 100644
--- /dev/null
+++ b/datasets/flue/flue.py
+    def _generate_examples(self, data_file, split):
+        if self.config.name == "CLS":
+            with open(data_file, encoding="utf-8") as f:
+                id = 0
+                for id_, line in enumerate(f):
+                    # one review per line; skip blank or truncated lines
+                    if len(line) > 9:
+                        id += 1
+                        review_text, label = self._cls_extractor(line)
+                        yield id_, {"idx": id, "text": review_text, "label": label}
+        elif self.config.name == "PAWS-X":
+            with open(data_file, encoding="utf-8") as f:
+                data = csv.reader(f, delimiter="\t")
+                next(data)  # skip header
+                id = 0
+                for id_, row in enumerate(data):
+                    if len(row) == 4:
+                        id += 1
+                        yield id_, {
+                            "idx": id,
+                            "sentence1": self._cleaner(row[1]),
+                            "sentence2": self._cleaner(row[2]),
+                            "label": int(row[3].strip()),
+                        }
+        elif self.config.name == "XNLI":
+            with open(data_file, encoding="utf-8") as f:
+                data = csv.reader(f, delimiter="\t", quoting=csv.QUOTE_NONE)
+                next(data)  # skip header
+                id = 0
+                for id_, row in enumerate(data):
+                    if split == "train":
+                        id += 1
+                        yield id_, {
+                            "idx": id,
+                            "premise": self._cleaner(row[0]),
+                            "hypo": self._cleaner(row[1]),
+                            "label": row[2].strip().replace("contradictory", "contradiction"),
+                        }
+                    else:
+                        # the dev/test TSVs cover all XNLI languages; keep only the French rows
+                        if row[0] == "fr":
+                            id += 1
+                            yield id_, {
+                                "idx": id,
+                                "premise": self._cleaner(row[6]),
+                                "hypo": self._cleaner(row[7]),
+                                "label": row[1].strip(),  # the label is already "contradiction" in the dev/test
+                            }
+        elif self.config.name == "WSD-V":
+            wsd_rdr = WSDDatasetReader()
+            for inst in wsd_rdr.read_from_data_dirs([os.path.join(data_file, split)]):
+                yield inst[0], {
+                    "idx": inst[0],
+                    "sentence": inst[1],
+                    "pos_tags": inst[2],
+                    "lemmas": inst[3],
+                    "fine_pos_tags": inst[4],
+                    "disambiguate_tokens_ids": inst[5],
+                    "disambiguate_labels": inst[6],
+                }
+
+    def _cls_extractor(self, line):
+        """
+        Extract review and label for the CLS dataset
+        from: https://github.com/getalp/Flaubert/blob/master/flue/extract_split_cls.py
+        """
+        m = re.search(r"(?<=<rating>)\d+.\d+(?=<\/rating>)", line)
+        label = "positive" if int(float(m.group(0))) > 3 else "negative"  # reviews with rating == 3 were already removed
+        # take the matched string so the comparison below works
+        category = re.search(r"(?<=<category>)\w+(?=<\/category>)", line).group(0)
+
+        if category == "dvd":
+            m = re.search(r"(?<=\/url>)(.|\n|\t|\f)+(?=\<\/title>)", line)
+        else:
+            m = re.search(r"(?<=\/url>)(.|\n|\t|\f)+(?=\<\/text>)", line)
+
+        review_text = m.group(0)
+
+        return self._cleaner(review_text), label
+
+    def _convert_to_unicode(self, text):
+        """
+        Converts `text` to Unicode (if it's not already), assuming UTF-8 input.
+        from: https://github.com/getalp/Flaubert/blob/master/tools/clean_text.py
+        """
+        # six_ensure_text is copied from https://github.com/benjaminp/six
+        def six_ensure_text(s, encoding="utf-8", errors="strict"):
+            if isinstance(s, six.binary_type):
+                return s.decode(encoding, errors)
+            elif isinstance(s, six.text_type):
+                return s
+            else:
+                raise TypeError("not expecting type '%s'" % type(s))
+
+        return six_ensure_text(text, encoding="utf-8", errors="ignore")
+
+    def _cleaner(self, text):
+        """
+        Clean up an input text
+        from: https://github.com/getalp/Flaubert/blob/master/tools/clean_text.py
+        """
+        # Convert and normalize the underlying unicode representation
+        text = self._convert_to_unicode(text)
+        text = unicodedata.normalize("NFC", text)
+
+        # Normalize whitespace characters and remove carriage returns
+        remap = {ord("\f"): " ", ord("\r"): "", ord("\n"): "", ord("\t"): ""}
+        text = text.translate(remap)
+
+        # Remove URL links and leftover markup
+        pattern = re.compile(r"(?:www|http)\S+|<\S+|\w+\/*>")
+        text = re.sub(pattern, "", text)
+
+        # Collapse multiple spaces
+        pattern = re.compile(r"( ){2,}")
+        text = re.sub(pattern, r" ", text)
+
+        return text
+
+    def _wsdv_prepare_data(self, dirpath):
+        """Split the FSE dir into train and test subdirectories"""
+        paths = {}
+
+        for f in os.listdir(dirpath):
+            # FSE-* files hold the evaluation data; everything else is training data
+            if f.startswith("FSE"):
+                data = "test"
+            else:
+                data = "train"
+
+            paths["_".join((data, f))] = os.path.join(dirpath, f)
+
+        test_dirpath = os.path.join(dirpath, "test")
+        os.makedirs(test_dirpath, exist_ok=True)
+        train_dirpath = os.path.join(dirpath, "train")
+        os.makedirs(train_dirpath, exist_ok=True)
+        # copy each file into its train/ or test/ directory
+        for k, v in paths.items():
+            data, filename = k.split("_", 1)  # split only once so filenames with underscores stay intact
+            copyfile(v, os.path.join(dirpath, data, filename))
+
+
+# The WSDDatasetReader class comes from https://github.com/getalp/Flaubert/blob/master/flue/wsd/verbs/modules/dataset.py
+class WSDDatasetReader:
+    """Class to read a WSD data directory. The directory should contain .data.xml and .gold.key.txt files"""
+
+    def get_data_paths(self, indir):
+        """Get the .data.xml and .gold.key.txt file paths from a WSD dir"""
+        xml_fpath, gold_fpath = None, None
+
+        for f in os.listdir(indir):
+            if f.endswith(".data.xml"):
+                xml_fpath = os.path.join(indir, f)
+            if f.endswith(".gold.key.txt"):
+                gold_fpath = os.path.join(indir, f)
+        return xml_fpath, gold_fpath
+
+    def read_gold(self, infile):
+        """Read .gold.key.txt and return its annotations as a dict.
+        :param infile: fpath to the .gold.key.txt file
+        :type infile: str
+        :return: data in dict format: {str(instance_id): tuple(labels)}
+        :rtype: dict
+        """
+        return {
+            line.split()[0]: tuple(line.rstrip("\n").split()[1:])
+            for line in open(infile, encoding="utf-8").readlines()
+        }
+
+    def read_from_data_dirs(self, data_dirs):
+        """Read WSD data and yield one tuple of fields per sentence"""
+        for d in data_dirs:
+            xml_fpath, gold_fpath = self.get_data_paths(d)
+
+            # read gold file
+            id2gold = self.read_gold(gold_fpath)
+
+            sentences = self.read_sentences(d)
+
+            # Parse xml
+            tree = etree.parse(xml_fpath)
+            corpus = tree.getroot()
+
+            # process data
+            # iterate over documents
+            for text in corpus:
+                # iterate over sentences
+                for sentence in text:
+                    sent_id = sentence.get("id")  # sentence id
+                    sent = next(sentences)  # get sentence
+                    pos_tags = []
+                    lemmas = []
+                    fine_pos_tags = []
+                    disambiguate_tokens_ids = []
+                    disambiguate_labels = []
+                    tok_idx = 0
+
+                    # iterate over tokens
+                    for tok in sentence:
+                        lemma, pos, fine_pos_tag = tok.get("lemma"), tok.get("pos"), tok.get("fine_pos")
+
+                        pos_tags.append(pos)
+                        lemmas.append(lemma)
+                        fine_pos_tags.append(fine_pos_tag)
+                        wf = tok.text
+                        subtokens = wf.split(" ")
+
+                        # add sense-annotated token
+                        if tok.tag == "instance":
+                            id = tok.get("id")
+
+                            target_labels = id2gold[id]
+                            target_first_label = target_labels[0]
+
+                            # We focus on the head of the target MWE instance
+                            if pos == "VERB":
+                                # the head is mostly the first token, as most MWE verb targets are phrasal verbs (e.g. lift up)
+                                tgt_idx = tok_idx
+                            else:
+                                # for other POS the head is generally the last token of the MWE (e.g. European Union)
+                                tgt_idx = tok_idx + len(subtokens) - 1
+
+                            disambiguate_tokens_ids.append(tgt_idx)
+                            disambiguate_labels.append(target_first_label)
+
+                        tok_idx += 1
+
+                    yield (
+                        sent_id,
+                        sent,
+                        pos_tags,
+                        lemmas,
+                        fine_pos_tags,
+                        disambiguate_tokens_ids,
+                        disambiguate_labels,
+                    )
+
+    def read_sentences(self, data_dir, keep_mwe=True):
+        """Read sentences from WSD data"""
+
+        xml_fpath, _ = self.get_data_paths(data_dir)
+        return self.read_sentences_from_xml(xml_fpath, keep_mwe=keep_mwe)
+
+    def read_sentences_from_xml(self, infile, keep_mwe=False):
+        """Read sentences from an xml file"""
+
+        # Parse xml
+        tree = etree.parse(infile)
+        corpus = tree.getroot()
+
+        for text in corpus:
+            for sentence in text:
+                if keep_mwe:
+                    # keep multi-word expressions as single tokens joined by "_"
+                    sent = [tok.text.replace(" ", "_") for tok in sentence]
+                else:
+                    sent = [subtok for tok in sentence for subtok in tok.text.split(" ")]
+                yield sent
+
+    def read_target_keys(self, infile):
+        """Read target keys"""
+        return [x.rstrip("\n") for x in open(infile, encoding="utf-8").readlines()]
diff --git a/setup.py b/setup.py
index 03b8f8f912e..ab75c6d8322 100644
--- a/setup.py
+++ b/setup.py
@@ -106,7 +106,8 @@
     'torch',
     'tldextract',
     'transformers',
-    'zstandard'
+    'zstandard',
+    'lxml',
 ]
 
 if os.name == "nt":  # windows
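As a quick sanity check for reviewers, here is a minimal usage sketch, assuming the script is picked up under the `flue` dataset name once merged; the configuration names come from the `config.name` branches in `_generate_examples`, and the printed fields follow the dicts yielded there:

```python
from datasets import load_dataset

# Load the French sentiment-classification configuration of FLUE.
# The other configurations defined by flue.py are "PAWS-X", "XNLI" and "WSD-V".
cls = load_dataset("flue", "CLS")

# Each example mirrors the dict yielded in _generate_examples for CLS.
print(cls["train"][0])  # {"idx": ..., "text": ..., "label": ...}
```

The `WSD-V` configuration instead yields token-level fields (`sentence`, `pos_tags`, `lemmas`, `fine_pos_tags`, `disambiguate_tokens_ids`, `disambiguate_labels`), one example per sentence, produced by `WSDDatasetReader.read_from_data_dirs`.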