diff --git a/Chatbot/Reformer_Chatbot.ipynb b/Chatbot/Reformer_Chatbot.ipynb
new file mode 100644
index 0000000..0e9b500
--- /dev/null
+++ b/Chatbot/Reformer_Chatbot.ipynb
@@ -0,0 +1,1335 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 451
+ },
+ "colab_type": "code",
+ "id": "aV4zpTnSVFIp",
+ "outputId": "e3a85dd1-e375-4636-ea62-b9b403f0952a"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "INFO:tensorflow:tokens_length=568 inputs_length=512 targets_length=114 noise_density=0.15 mean_noise_span_length=3.0 \n",
+ "trax 1.3.4\n",
+ "\u001b[33mWARNING: You are using pip version 20.1.1; however, version 20.2.3 is available.\n",
+ "You should consider upgrading via the '/opt/conda/bin/python3 -m pip install --upgrade pip' command.\u001b[0m\n"
+ ]
+ }
+ ],
+ "source": [
+ "import json\n",
+ "import random\n",
+ "import numpy as np\n",
+ "from termcolor import colored\n",
+ "\n",
+ "import trax \n",
+ "from trax import layers as tl\n",
+ "from trax.supervised import training\n",
+ "!pip list | grep trax"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# filename of the MultiWOZ dialogue dataset\n",
+ "DATA_FILE = 'data.json'\n",
+ "\n",
+ "# data directory\n",
+ "DATA_DIR = './data'\n",
+ "\n",
+ "# dictionary where we will load the dialogue dataset\n",
+ "DIALOGUE_DB = {}\n",
+ "\n",
+ "# vocabulary filename\n",
+ "VOCAB_FILE = 'en_32k.subword'\n",
+ "\n",
+ "# vocabulary file directory\n",
+ "VOCAB_DIR = 'data/vocabs'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 167
+ },
+ "colab_type": "code",
+ "id": "K58I5vFB7GlP",
+ "outputId": "3d086ea4-7898-4870-b52f-f362cb02e118"
+ },
+ "outputs": [],
+ "source": [
+ "# helper function to load a JSON file\n",
+ "def load_json(directory, file):\n",
+ " with open(f'{directory}/{file}') as file: \n",
+ " db = json.load(file)\n",
+ " return db\n",
+ "\n",
+ "# load the dialogue data set into our dictionary\n",
+ "DIALOGUE_DB = load_json(DATA_DIR, DATA_FILE)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 215
+ },
+ "colab_type": "code",
+ "id": "VGBnUfEk8p9x",
+ "outputId": "4b364506-1088-4f00-be0c-4fefa892dc4e"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "The number of dialogues is: 10438\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(f'The number of dialogues is: {len(DIALOGUE_DB)}')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The dialogues are composed of multiple files and the filenames are used as keys in our dictionary. Those with multi-domain dialogues have \"MUL\" in their filenames while single domain dialogues have either \"SNG\" or \"WOZ\"."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "['SNG01856.json', 'SNG0129.json', 'PMUL1635.json', 'MUL2168.json', 'SNG0073.json', 'SNG01445.json', 'MUL2105.json']\n"
+ ]
+ }
+ ],
+ "source": [
+ "# print 7 keys from the dataset to see the filenames\n",
+ "print(list(DIALOGUE_DB.keys())[0:7]) "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 34
+ },
+ "colab_type": "code",
+ "id": "5KYeQLnG8p96",
+ "outputId": "b22f570d-a7b0-4b92-ba68-0b7236e61051"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "dict_keys(['goal', 'log'])\n"
+ ]
+ }
+ ],
+ "source": [
+ "# get keys of the fifth file in the list above\n",
+ "print(DIALOGUE_DB['SNG0073.json'].keys())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 122
+ },
+ "colab_type": "code",
+ "id": "PPPWwQ2s8p9_",
+ "outputId": "7e8efa2d-821a-44c8-902d-2c722baf5b4c"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'taxi': {'info': {'leaveAt': '17:15',\n",
+ " 'destination': 'pizza hut fen ditton',\n",
+ " 'departure': \"saint john's college\"},\n",
+ " 'reqt': ['car type', 'phone'],\n",
+ " 'fail_info': {}},\n",
+ " 'police': {},\n",
+ " 'hospital': {},\n",
+ " 'hotel': {},\n",
+ " 'attraction': {},\n",
+ " 'train': {},\n",
+ " 'message': [\"You want to book a taxi. The taxi should go to pizza hut fen ditton and should depart from saint john's college\",\n",
+ " \"The taxi should leave after 17:15\",\n",
+ " \"Make sure you get car type and contact number\"],\n",
+ " 'restaurant': {}}"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "DIALOGUE_DB['SNG0073.json']['goal']"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "B4N8RtWu8p-C"
+ },
+ "source": [
+ "The `log` on the other hand contains the dialog. It is a list of dictionaries and each element of this list contains several descriptions as well. Let's look at an example:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'text': \"I would like a taxi from Saint John's college to Pizza Hut Fen Ditton.\",\n",
+ " 'metadata': {},\n",
+ " 'dialog_act': {'Taxi-Inform': [['Dest', 'pizza hut fen ditton'],\n",
+ " ['Depart', \"saint john 's college\"]]},\n",
+ " 'span_info': [['Taxi-Inform', 'Dest', 'pizza hut fen ditton', 11, 14],\n",
+ " ['Taxi-Inform', 'Depart', \"saint john 's college\", 6, 9]]}"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# get first element of the log list\n",
+ "DIALOGUE_DB['SNG0073.json']['log'][0]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " Person 1: I would like a taxi from Saint John's college to Pizza Hut Fen Ditton.\n",
+ " Person 2: What time do you want to leave and what time do you want to arrive by?\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(' Person 1: ', DIALOGUE_DB['SNG0073.json']['log'][0]['text'])\n",
+ "print(' Person 2: ',DIALOGUE_DB['SNG0073.json']['log'][1]['text'])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "def get_conversation(file, data_db):\n",
+ " '''\n",
+ " Args:\n",
+ " file (string): filename of the dialogue file saved as json\n",
+ " data_db (dict): dialogue database\n",
+ " \n",
+ " Returns:\n",
+ " string: A string containing the 'text' fields of data[file]['log'][x]\n",
+ " '''\n",
+ " \n",
+ " # initialize empty string\n",
+ " result = ''\n",
+ " \n",
+ " # get length of file's log list\n",
+ " len_msg_log = len(data_db[file]['log'])\n",
+ " \n",
+ " # set the delimiter strings\n",
+ " delimiter_1 = ' Person 1: '\n",
+ " delimiter_2 = ' Person 2: '\n",
+ " \n",
+ " # loop over the file's log list\n",
+ " for i in range(len_msg_log):\n",
+ " \n",
+ " \n",
+ " # get i'th element of file log list\n",
+ " cur_log = data_db[file]['log'][i]\n",
+ " \n",
+ " # check if i is even\n",
+ " if i%2 == 0: \n",
+ " # append the 1st delimiter string\n",
+ " result += delimiter_1\n",
+ " else: \n",
+ " # append the 2nd delimiter string\n",
+ " result += delimiter_2\n",
+ " \n",
+ " # append the message text from the log\n",
+ " result += cur_log['text']\n",
+ "\n",
+ " return result\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "Ugvx0noP8p-G"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " Person 1: am looking for a place to to stay that has cheap price range it should be in a type of hotel Person 2: Okay, do you have a specific area you want to stay in? Person 1: no, i just need to make sure it's cheap. oh, and i need parking Person 2: I found 1 cheap hotel for you that includes parking. Do you like me to book it? Person 1: Yes, please. 6 people 3 nights starting on tuesday. Person 2: I am sorry but I wasn't able to book that for you for Tuesday. Is there another day you would like to stay or perhaps a shorter stay? Person 1: how about only 2 nights. Person 2: Booking was successful.\n",
+ "Reference number is : 7GAWK763. Anything else I can do for you? Person 1: No, that will be all. Good bye. Person 2: Thank you for using our services.\n"
+ ]
+ }
+ ],
+ "source": [
+ "file = 'SNG01856.json'\n",
+ "conversation = get_conversation(file, DIALOGUE_DB)\n",
+ "\n",
+ "# print raw output\n",
+ "print(conversation)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\u001b[31mPerson 1: am looking for a place to to stay that has cheap price range it should be in a type of hotel \u001b[0m\n",
+ "\u001b[32mPerson 2: Okay, do you have a specific area you want to stay in? \u001b[0m\n",
+ "\u001b[31mPerson 1: no, i just need to make sure it's cheap. oh, and i need parking \u001b[0m\n",
+ "\u001b[32mPerson 2: I found 1 cheap hotel for you that includes parking. Do you like me to book it? \u001b[0m\n",
+ "\u001b[31mPerson 1: Yes, please. 6 people 3 nights starting on tuesday. \u001b[0m\n",
+ "\u001b[32mPerson 2: I am sorry but I wasn't able to book that for you for Tuesday. Is there another day you would like to stay or perhaps a shorter stay? \u001b[0m\n",
+ "\u001b[31mPerson 1: how about only 2 nights. \u001b[0m\n",
+ "\u001b[32mPerson 2: Booking was successful.\n",
+ "Reference number is : 7GAWK763. Anything else I can do for you? \u001b[0m\n",
+ "\u001b[31mPerson 1: No, that will be all. Good bye. \u001b[0m\n",
+ "\u001b[32mPerson 2: Thank you for using our services.\u001b[0m\n"
+ ]
+ }
+ ],
+ "source": [
+ "def print_conversation(conversation):\n",
+ " \n",
+ " delimiter_1 = 'Person 1: '\n",
+ " delimiter_2 = 'Person 2: '\n",
+ " \n",
+ " split_list_d1 = conversation.split(delimiter_1)\n",
+ " \n",
+ " for sublist in split_list_d1[1:]:\n",
+ " split_list_d2 = sublist.split(delimiter_2)\n",
+ " print(colored(f'Person 1: {split_list_d2[0]}', 'red'))\n",
+ " \n",
+ " if len(split_list_d2) > 1:\n",
+ " print(colored(f'Person 2: {split_list_d2[1]}', 'green'))\n",
+ "\n",
+ " \n",
+ "print_conversation(conversation)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 102
+ },
+ "colab_type": "code",
+ "id": "Rs2R8q1d8p-K",
+ "outputId": "8a2f4e3f-4516-449f-9648-a5970707cfc9"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'text': 'am looking for a place to to stay that has cheap price range it should be in a type of hotel',\n",
+ " 'metadata': {},\n",
+ " 'dialog_act': {'Hotel-Inform': [['Type', 'hotel'], ['Price', 'cheap']]},\n",
+ " 'span_info': [['Hotel-Inform', 'Type', 'hotel', 20, 20],\n",
+ " ['Hotel-Inform', 'Price', 'cheap', 10, 10]]}"
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "DIALOGUE_DB['SNG01856.json']['log'][0]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 54
+ },
+ "colab_type": "code",
+ "id": "HQmYUcsi8p-O",
+ "outputId": "5730c55f-63da-42a8-935e-6eeb17f6f791"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'address': 'pool way, whitehill road, off newmarket road', 'area': 'east', 'entrance fee': '?', 'id': '1', 'location': [52.208789, 0.154883], 'name': 'abbey pool and astroturf pitch', 'openhours': '?', 'phone': '01223902088', 'postcode': 'cb58nt', 'pricerange': '?', 'type': 'swimmingpool'}\n"
+ ]
+ }
+ ],
+ "source": [
+ "# this is an example of the attractions file\n",
+ "attraction_file = open('data/attraction_db.json')\n",
+ "attractions = json.load(attraction_file)\n",
+ "print(attractions[0])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 34
+ },
+ "colab_type": "code",
+ "id": "I5kTg4uX8p-R",
+ "outputId": "3dacc4ff-4f05-4ae6-d099-33d1b3a6fa2a"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'department': 'neurosciences critical care unit', 'id': 0, 'phone': '01223216297'}\n"
+ ]
+ }
+ ],
+ "source": [
+ "# this is an example of the hospital file\n",
+ "hospital_file = open('data/hospital_db.json')\n",
+ "hospitals = json.load(hospital_file)\n",
+ "print(hospitals[0]) # feel free to index into other indices"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 54
+ },
+ "colab_type": "code",
+ "id": "B5knaAEc8p-U",
+ "outputId": "ee0110c2-b2c2-4584-bd42-21f75109a579"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'address': '124 tenison road', 'area': 'east', 'internet': 'yes', 'parking': 'no', 'id': '0', 'location': [52.1963733, 0.1987426], 'name': 'a and b guest house', 'phone': '01223315702', 'postcode': 'cb12dp', 'price': {'double': '70', 'family': '90', 'single': '50'}, 'pricerange': 'moderate', 'stars': '4', 'takesbookings': 'yes', 'type': 'guesthouse'}\n"
+ ]
+ }
+ ],
+ "source": [
+ "# this is an example of the hotel file\n",
+ "hotel_file = open('data/hotel_db.json')\n",
+ "hotels = json.load(hotel_file)\n",
+ "print(hotels[0]) # feel free to index into other indices"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 34
+ },
+ "colab_type": "code",
+ "id": "t-Rk01Mv8p-a",
+ "outputId": "8977e17e-2fc3-4073-abb8-fcf5cef3cfaf"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'name': 'Parkside Police Station', 'address': 'Parkside, Cambridge', 'id': 0, 'phone': '01223358966'}\n"
+ ]
+ }
+ ],
+ "source": [
+ "# this is an example of the police file\n",
+ "police_file = open('data/police_db.json')\n",
+ "police = json.load(police_file)\n",
+ "print(police[0]) # feel free to index into other indices"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 54
+ },
+ "colab_type": "code",
+ "id": "u-G9pD8g8p-d",
+ "outputId": "1dba6598-b9b6-4fc8-91d2-f844b98e45fa"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'address': 'Regent Street City Centre', 'area': 'centre', 'food': 'italian', 'id': '19210', 'introduction': 'Pizza hut is a large chain with restaurants nationwide offering convenience pizzas pasta and salads to eat in or take away', 'location': [52.20103, 0.126023], 'name': 'pizza hut city centre', 'phone': '01223323737', 'postcode': 'cb21ab', 'pricerange': 'cheap', 'type': 'restaurant'}\n"
+ ]
+ }
+ ],
+ "source": [
+ "# this is an example of a restaurant file\n",
+ "restaurant_file = open('data/restaurant_db.json')\n",
+ "restaurants = json.load(restaurant_file)\n",
+ "print(restaurants[0]) # feel free to index into other indices"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 181
+ },
+ "colab_type": "code",
+ "id": "2H8pB_yI8p-g",
+ "outputId": "aa039a49-3ed3-4f4d-fa4f-c2619de3dc99"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "#####################################################\n",
+ "#####################################################\n",
+ "# Copyright Cambridge Dialogue Systems Group, 2018 #\n",
+ "#####################################################\n",
+ "#####################################################\n",
+ "\n",
+ "Dataset contains the following files:\n",
+ "1. data.json: the woz dialogue dataset, which contains the conversation users and wizards, as well as a set of coarse labels for each user turn. This file contains both system and user dialogue acts annotated at the turn level. Files with multi-domain dialogues have \"MUL\" in their names. Single domain dialogues have either \"SNG\" or \"WOZ\" in their names.\n",
+ "2. restaurant_db.json: the Cambridge restaurant database file, containing restaurants in the Cambridge UK area and a set of attributes.\n",
+ "3. attraction_db.json: the Cambridge attraction database file, contining attractions in the Cambridge UK area and a set of attributes.\n",
+ "4. hotel_db.json: the Cambridge hotel database file, containing hotels in the Cambridge UK area and a set of attributes.\n",
+ "5. train_db.json: the Cambridge train (with artificial connections) database file, containing trains in the Cambridge UK area and a set of attributes.\n",
+ "6. hospital_db.json: the Cambridge hospital database file, contatining information about departments.\n",
+ "7. police_db.json: the Cambridge police station information.\n",
+ "8. taxi_db.json: slot-value list for taxi domain.\n",
+ "9. valListFile.txt: list of dialogues for validation.\n",
+ "10. testListFile.txt: list of dialogues for testing.\n",
+ "11. system_acts.json:\n",
+ " There are 6 domains ('Booking', 'Restaurant', 'Hotel', 'Attraction', 'Taxi', 'Train') and 1 dummy domain ('general').\n",
+ " A domain-dependent dialogue act is defined as a domain token followed by a domain-independent dialogue act, e.g. 'Hotel-inform' means it is an 'inform' act in the Hotel domain.\n",
+ " Dialogue acts which cannot take slots, e.g., 'good bye', are defined under the 'general' domain.\n",
+ " A slot-value pair defined as a list with two elements. The first element is slot token and the second one is its value.\n",
+ " If a dialogue act takes no slots, e.g., dialogue act 'offer booking' for an utterance 'would you like to take a reservation?', its slot-value pair is ['none', 'none']\n",
+ " There are four types of values:\n",
+ " 1) If a slot takes a binary value, e.g., 'has Internet' or 'has park', the value is either 'yes' or 'no'.\n",
+ " 2) If a slot is under the act 'request', e.g., 'request' about 'area', the value is expressed as '?'.\n",
+ " 3) The value that appears in the utterance e.g., the name of a restaurant.\n",
+ " 4) If for some reason the turn does not have an annotation then it is labeled as \"No Annotation.\"\n",
+ "12. ontology.json: Data-based ontology containing all the values for the different slots in the domains.\n",
+ "13. slot_descriptions.json: A collection of human-written slot descriptions for each slot in the dataset. Each slot has at least two descriptions.\n",
+ "14. tokenization.md: A description of the tokenization preprocessing we had to perform to maintain consistency between the dialogue act annotations of DSTC 8 Track 1 and the existing MultiWOZ 2.0 data. \n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "with open('data/README') as file:\n",
+ " print(file.read())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 71
+ },
+ "colab_type": "code",
+ "id": "IrnQ9eNV8p-k",
+ "outputId": "2b159dae-78be-4a19-df41-b9e620216d43"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " Person 1: am looking for a place to to stay that has cheap price range it should be in a type of hotel Person 2: Okay, do you have a specific area you want to stay in? Person 1: no, i just need to make sure it's cheap. oh, and i need parking Person 2: I found 1 cheap hotel for you that includes parking. Do you like me to book it? Person 1: Yes, please. 6 people 3 nights starting on tuesday. Person 2: I am sorry but I wasn't able to book that for you for Tuesday. Is there another day you would like to stay or perhaps a shorter stay? Person 1: how about only 2 nights. Person 2: Booking was successful.\n",
+ "Reference number is : 7GAWK763. Anything else I can do for you? Person 1: No, that will be all. Good bye. Person 2: Thank you for using our services.\n"
+ ]
+ }
+ ],
+ "source": [
+ "# the keys are the file names\n",
+ "all_files = DIALOGUE_DB.keys()\n",
+ "\n",
+ "# initialize empty list\n",
+ "untokenized_data = []\n",
+ "\n",
+ "# loop over all files\n",
+ "for file in all_files:\n",
+ " # returns a string delimited by Person 1 and Person 2\n",
+ " result = get_conversation(file, DIALOGUE_DB)\n",
+ " \n",
+ " # append to the list\n",
+ " untokenized_data.append(result)\n",
+ "\n",
+ "print(untokenized_data[0])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 34
+ },
+ "colab_type": "code",
+ "id": "buE0b8bjx_p_",
+ "outputId": "cb73a95b-488b-4d1d-9c20-5e98ca71f9d5"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "number of conversations in the data set: 10438\n",
+ "number of conversations in train set: 9917\n",
+ "number of conversations in eval set: 521\n"
+ ]
+ }
+ ],
+ "source": [
+ "# shuffle the list we generated above\n",
+ "random.shuffle(untokenized_data)\n",
+ "\n",
+ "# define a cutoff \n",
+ "# convert to int because we will use it as a list index\n",
+ "cut_off = int(len(untokenized_data) * .05)\n",
+ "\n",
+ "# slice the list. the last elements after the cut_off value will be the eval set. the rest is for training. \n",
+ "train_data, eval_data = untokenized_data[:-cut_off], untokenized_data[-cut_off:]\n",
+ "\n",
+ "print(f'number of conversations in the data set: {len(untokenized_data)}')\n",
+ "print(f'number of conversations in train set: {len(train_data)}')\n",
+ "print(f'number of conversations in eval set: {len(eval_data)}')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def stream(data):\n",
+ " # loop over the entire data\n",
+ " while True:\n",
+ " # get a random element\n",
+ " d = random.choice(data)\n",
+ " \n",
+ " # yield a tuple pair of identical values \n",
+ " # (i.e. our inputs to the model will also be our targets during training)\n",
+ " yield (d, d)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "uZgK5FAAWwOu"
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "data_pipeline = trax.data.Serial(\n",
+ " # randomize the stream\n",
+ " trax.data.Shuffle(),\n",
+ " \n",
+ " # tokenize the data\n",
+ " trax.data.Tokenize(vocab_dir=VOCAB_DIR,\n",
+ " vocab_file=VOCAB_FILE),\n",
+ " \n",
+ " # filter too long sequences\n",
+ " trax.data.FilterByLength(2048),\n",
+ " \n",
+ " # bucket by length\n",
+ " trax.data.BucketByLength(boundaries=[128, 256, 512, 1024],\n",
+ " batch_sizes=[16, 8, 4, 2, 1]),\n",
+ " \n",
+ " # add loss weights but do not add it to the padding tokens (i.e. 0)\n",
+ " trax.data.AddLossWeights(id_to_mask=0)\n",
+ ")\n",
+ "\n",
+ "# apply the data pipeline to our train and eval sets\n",
+ "train_stream = data_pipeline(stream(train_data))\n",
+ "eval_stream = data_pipeline(stream(eval_data))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 88
+ },
+ "colab_type": "code",
+ "id": "9iBQEvhLYRot",
+ "outputId": "78659fd2-4633-47bc-ebe8-3ae3a6e2eab3"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "input shape: (4, 512)\n",
+ " Person 1: I need assistance finding a restaurant that is cheap and serves british food, can you help me? Person 2: I am sorry, I do not have any cheaply priced British restaurants in the city. Would you like a different type of food or maybe a different price range? Person 1: Do you have any Asian Oriental restaurants in the same price range? Person 2: Yes I found 2 cheap asian oriental places to eat. The dojo noodle bar and j restaurant. Would you like me to book a table for you? Person 1: Yes. Please book me a table at the dojo noodle bar for 7 people at 19:00 on wednesday. Person 2: I couldn't get you a table at Dojo, I can get you one at J Restaurant, if you want? Person 1: That works. Same parameters, please. I need the reference number too. Person 2: Booked! Your table will be reserved for 15 minutes. Reference number: F10EIGQ3. Person 1: Great thank you for all your help. Person 2: Do you need anything else? Person 1: That was all thank you. Person 2: Okay great. Thanks for calling and enjoy your dinner.\n"
+ ]
+ }
+ ],
+ "source": [
+ "# the stream generators will yield (input, target, weights). \n",
+ "inp, _, _ = next(train_stream)\n",
+ "\n",
+ "# print the shape. format is (batch size, token length)\n",
+ "print(\"input shape: \", inp.shape)\n",
+ "\n",
+ "# detokenize the first element\n",
+ "print(trax.data.detokenize(inp[0], vocab_dir=VOCAB_DIR, vocab_file=VOCAB_FILE))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "adX2eU762BkF"
+ },
+ "outputs": [],
+ "source": [
+ "def reversible_layer_forward(x, f, g):\n",
+ " \"\"\"\n",
+ " Args: \n",
+ " x (np.array): an input vector or matrix\n",
+ " f (function): a function which operates on a vector/matrix\n",
+ " g (function): a function which operates on a vector/matrix\n",
+ " Returns: \n",
+ " y (np.array): an output vector or matrix whose form is determined by 'x', f and g\n",
+ " \"\"\"\n",
+ " # split the input vector into two (* along the last axis because it is the depth dimension)\n",
+ " x1, x2 = np.split(x, 2, axis=-1) \n",
+ " \n",
+ " y1 = x1 + f(x2)\n",
+ " \n",
+ " y2 = x2 + g(y1)\n",
+ " \n",
+ " # concatenate y1 and y2 along the depth dimension. be sure output is of type np.ndarray\n",
+ " y = np.concatenate([y1, y2], axis=-1)\n",
+ " \n",
+ " return y"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "0xTCG9WlaiiO"
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "def reversible_layer_reverse(y, f, g):\n",
+ " \"\"\"\n",
+ " Args: \n",
+ " y (np.array): an input vector or matrix\n",
+ " f (function): a function which operates on a vector/matrix of the form of 'y'\n",
+ " g (function): a function which operates on a vector/matrix of the form of 'y'\n",
+ " Returns: \n",
+ " y (np.array): an output vector or matrix whose form is determined by 'y', f and g\n",
+ " \"\"\"\n",
+ " \n",
+ " # split the input vector into two (* along the last axis because it is the depth dimension)\n",
+ " y1, y2 = np.split(y, 2, axis=-1)\n",
+ " \n",
+ " \n",
+ " x2 = y2 - g(y1)\n",
+ " \n",
+ " x1 = y1 - f(x2)\n",
+ " \n",
+ " # concatenate x1 and x2 along the depth dimension\n",
+ " x = np.concatenate([x1, x2], axis=-1)\n",
+ " \n",
+ " return x\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "RidbAcoR6duP"
+ },
+ "outputs": [],
+ "source": [
+ "def ReformerLM(vocab_size=33000, n_layers=2, mode='train', attention_type=tl.SelfAttention):\n",
+ " \"\"\"\n",
+ " Args: \n",
+ " vocab_size (int): size of the vocabulary\n",
+ " n_layers (int): number of decoder layers\n",
+ " mode (string): setting of the model which can be 'train', 'eval', or 'predict' \n",
+ " attention_type(class): attention class to use \n",
+ " Returns: \n",
+ " model (ReformerLM): a reformer language model implemented in Trax\n",
+ " \"\"\" \n",
+ "\n",
+ " # initialize an instance of Trax's ReformerLM class\n",
+ " model = trax.models.reformer.ReformerLM( \n",
+ " # set vocab size\n",
+ " vocab_size=vocab_size,\n",
+ " # set number of layers\n",
+ " n_layers=n_layers,\n",
+ " # set mode\n",
+ " mode=mode,\n",
+ " # set attention type\n",
+ " attention_type=attention_type\n",
+ " )\n",
+ " \n",
+ " return model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 33,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Serial[\n",
+ " ShiftRight(1)\n",
+ " Embedding_train_512\n",
+ " Dropout\n",
+ " PositionalEncoding\n",
+ " Dup_out2\n",
+ " ReversibleSerial_in2_out2[\n",
+ " ReversibleHalfResidualV2_in2_out2[\n",
+ " Serial[\n",
+ " LayerNorm\n",
+ " ]\n",
+ " SelfAttention\n",
+ " ]\n",
+ " ReversibleSwap_in2_out2\n",
+ " ReversibleHalfResidualV2_in2_out2[\n",
+ " Serial[\n",
+ " LayerNorm\n",
+ " Dense_2048\n",
+ " Dropout\n",
+ " FastGelu\n",
+ " Dense_512\n",
+ " Dropout\n",
+ " ]\n",
+ " ]\n",
+ " ReversibleSwap_in2_out2\n",
+ " ReversibleHalfResidualV2_in2_out2[\n",
+ " Serial[\n",
+ " LayerNorm\n",
+ " ]\n",
+ " SelfAttention\n",
+ " ]\n",
+ " ReversibleSwap_in2_out2\n",
+ " ReversibleHalfResidualV2_in2_out2[\n",
+ " Serial[\n",
+ " LayerNorm\n",
+ " Dense_2048\n",
+ " Dropout\n",
+ " FastGelu\n",
+ " Dense_512\n",
+ " Dropout\n",
+ " ]\n",
+ " ]\n",
+ " ReversibleSwap_in2_out2\n",
+ " ]\n",
+ " Concatenate_in2\n",
+ " LayerNorm\n",
+ " Dropout\n",
+ " Dense_train\n",
+ " LogSoftmax\n",
+ "]\n"
+ ]
+ }
+ ],
+ "source": [
+ "# display the model\n",
+ "# NOTE: must pass mode as a keyword — ReformerLM('train') would bind 'train' to\n",
+ "# vocab_size (first positional parameter), as the 'Embedding_train_512' in the\n",
+ "# recorded output below shows.\n",
+ "temp_model = ReformerLM(mode='train')\n",
+ "print(str(temp_model))\n",
+ "\n",
+ "# free memory\n",
+ "del temp_model "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 35,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "tQehGhoD4Psl"
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "def training_loop(ReformerLM, train_gen, eval_gen, output_dir = \"./model/\"):\n",
+ " \"\"\"\n",
+ " Args:\n",
+ " ReformerLM: the Reformer language model you are building\n",
+ " train_gen (generator): train data generator.\n",
+ " eval_gen (generator): Validation generator. \n",
+ " output_dir (string): Path to save the model output. Defaults to './model/'.\n",
+ "\n",
+ " Returns:\n",
+ " trax.supervised.training.Loop: Training loop for the model.\n",
+ " \"\"\"\n",
+ "\n",
+ " # use the warmup_and_rsqrt_decay learning rate schedule\n",
+ " lr_schedule = trax.lr.warmup_and_rsqrt_decay(\n",
+ " n_warmup_steps=1000, max_value=0.01)\n",
+ "\n",
+ " \n",
+ " # define the train task\n",
+ " train_task = training.TrainTask( \n",
+ " # labeled data\n",
+ " labeled_data=train_gen,\n",
+ " # loss layer\n",
+ " loss_layer=tl.CrossEntropyLoss(),\n",
+ " # optimizer\n",
+ " optimizer=trax.optimizers.Adam(0.01),\n",
+ " # lr_schedule\n",
+ " lr_schedule=lr_schedule,\n",
+ " # n_steps\n",
+ " n_steps_per_checkpoint=10\n",
+ " )\n",
+ "\n",
+ " # define the eval task\n",
+ " eval_task = training.EvalTask( \n",
+ " # labeled data\n",
+ " labeled_data=eval_gen,\n",
+ " # metrics\n",
+ " metrics=[tl.CrossEntropyLoss(), tl.Accuracy()]\n",
+ " )\n",
+ "\n",
+ " loop = training.Loop(ReformerLM(mode='train'),\n",
+ " train_task,\n",
+ " eval_tasks=[eval_task],\n",
+ " output_dir=output_dir)\n",
+ " return loop"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 36,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "\n",
+ "test_loop = training_loop(ReformerLM, train_stream, eval_stream)\n",
+ "train_task = test_loop._task\n",
+ "eval_task = test_loop._eval_task\n",
+ "\n",
+ "print(train_task)\n",
+ "print(eval_task)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 39,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# define the `predict_mem_len` and `predict_drop_len` of tl.SelfAttention\n",
+ "def attention(*args, **kwargs):\n",
+ " # number of input positions to remember in a cache when doing fast inference. \n",
+ " kwargs['predict_mem_len'] = 120\n",
+ " # number of input elements to drop once the fast inference input cache fills up.\n",
+ " kwargs['predict_drop_len'] = 120\n",
+ " # return the attention layer with the parameters defined above\n",
+ " return tl.SelfAttention(*args, **kwargs)\n",
+ "\n",
+ "# define the model using the ReformerLM function\n",
+ "model = ReformerLM(\n",
+ " vocab_size=33000,\n",
+ " n_layers=6,\n",
+ " mode='predict',\n",
+ " attention_type=attention,\n",
+ ")\n",
+ "\n",
+ "# define an input signature so we can initialize our model. shape will be (1, 1) and the data type is int32.\n",
+ "shape11 = trax.shapes.ShapeDtype((1, 1), dtype=np.int32)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 40,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# initialize from file\n",
+ "model.init_from_file('chatbot_model1.pkl.gz',\n",
+ " weights_only=True, input_signature=shape11)\n",
+ "\n",
+ "# save the starting state\n",
+ "STARTING_STATE = model.state"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 41,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def tokenize(sentence, vocab_file, vocab_dir):\n",
+ " return list(trax.data.tokenize(iter([sentence]), vocab_file=vocab_file, vocab_dir=vocab_dir))[0]\n",
+ "\n",
+ "def detokenize(tokens, vocab_file, vocab_dir):\n",
+ " return trax.data.detokenize(tokens, vocab_file=vocab_file, vocab_dir=vocab_dir)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 42,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "def ReformerLM_output_gen(ReformerLM, start_sentence, vocab_file, vocab_dir, temperature):\n",
+ " \"\"\"\n",
+ " Args:\n",
+ " ReformerLM: the Reformer language model you just trained\n",
+ " start_sentence (string): starting sentence of the conversation\n",
+ " vocab_file (string): vocabulary filename\n",
+ " vocab_dir (string): directory of the vocabulary file\n",
+ " temperature (float): parameter for sampling ranging from 0.0 to 1.0.\n",
+ " 0.0: same as argmax, always pick the most probable token\n",
+ " 1.0: sampling from the distribution (can sometimes say random things)\n",
+ "\n",
+ " Returns:\n",
+ " generator: yields the next symbol generated by the model\n",
+ " \"\"\"\n",
+ "\n",
+ " \n",
+ " # Create input tokens using the the tokenize function\n",
+ " input_tokens = tokenize(start_sentence, vocab_file=vocab_file, vocab_dir=vocab_dir)\n",
+ " \n",
+ " # Add batch dimension to array. Convert from (n,) to (x, n) where \n",
+ " # x is the batch size. Default is 1. \n",
+ " input_tokens_with_batch = np.array(input_tokens)[None, :]\n",
+ " \n",
+ " # call the autoregressive_sample_stream function from trax\n",
+ " output_gen = trax.supervised.decoding.autoregressive_sample_stream( \n",
+ " # model\n",
+ " ReformerLM,\n",
+ " # inputs will be the tokens with batch dimension\n",
+ " inputs=input_tokens_with_batch,\n",
+ " # temperature\n",
+ " temperature=temperature\n",
+ " )\n",
+ "\n",
+ " \n",
+ " return output_gen"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 44,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "shape11 = trax.shapes.ShapeDtype((1, 1), dtype=np.int32)\n",
+ "\n",
+ "def attention(*args, **kwargs):\n",
+ " kwargs['predict_mem_len'] = 120 # max length for predictions\n",
+ " kwargs['predict_drop_len'] = 120 # never drop old stuff\n",
+ " return tl.SelfAttention(*args, **kwargs)\n",
+ "\n",
+ "model = ReformerLM(\n",
+ " vocab_size=33000,\n",
+ " n_layers=6,\n",
+ " mode='predict',\n",
+ " attention_type=attention,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 45,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model.init_from_file('chatbot_model1.pkl.gz',\n",
+ " weights_only=True, input_signature=shape11)\n",
+ "\n",
+ "STARTING_STATE = model.state"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 46,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def generate_dialogue(ReformerLM, model_state, start_sentence, vocab_file, vocab_dir, max_len, temperature):\n",
+ " \"\"\"\n",
+ " Args:\n",
+ " ReformerLM: the Reformer language model you just trained\n",
+ " model_state (np.array): initial state of the model before decoding\n",
+ " start_sentence (string): starting sentence of the conversation\n",
+ " vocab_file (string): vocabulary filename\n",
+ " vocab_dir (string): directory of the vocabulary file\n",
+ " max_len (int): maximum number of tokens to generate \n",
+ " temperature (float): parameter for sampling ranging from 0.0 to 1.0.\n",
+ " 0.0: same as argmax, always pick the most probable token\n",
+ " 1.0: sampling from the distribution (can sometimes say random things)\n",
+ "\n",
+ " Returns:\n",
+ " generator: yields the next symbol generated by the model\n",
+ " \"\"\" \n",
+ " \n",
+ " # define the delimiters we used during training\n",
+ " delimiter_1 = 'Person 1: ' \n",
+ " delimiter_2 = 'Person 2: '\n",
+ " \n",
+ " # initialize detokenized output\n",
+ " sentence = ''\n",
+ " \n",
+ " # token counter\n",
+ " counter = 0\n",
+ " \n",
+ " # output tokens. we insert a ': ' for formatting\n",
+ " result = [tokenize(': ', vocab_file=vocab_file, vocab_dir=vocab_dir)]\n",
+ " \n",
+ " # reset the model state when starting a new dialogue\n",
+ " ReformerLM.state = model_state\n",
+ " \n",
+ " # calls the output generator implemented earlier\n",
+ " output = ReformerLM_output_gen(ReformerLM, start_sentence, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR, temperature=temperature)\n",
+ " \n",
+ " # print the starting sentence\n",
+ " print(start_sentence.split(delimiter_2)[0].strip())\n",
+ " \n",
+ " # loop below yields the next tokens until max_len is reached. the if-elif is just for prettifying the output.\n",
+ " for o in output:\n",
+ " \n",
+ " result.append(o)\n",
+ " \n",
+ " sentence = detokenize(np.concatenate(result, axis=0), vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR)\n",
+ " \n",
+ " if sentence.endswith(delimiter_1):\n",
+ " sentence = sentence.split(delimiter_1)[0]\n",
+ " print(f'{delimiter_2}{sentence}')\n",
+ " sentence = ''\n",
+ " result.clear()\n",
+ " \n",
+ " elif sentence.endswith(delimiter_2):\n",
+ " sentence = sentence.split(delimiter_2)[0]\n",
+ " print(f'{delimiter_1}{sentence}')\n",
+ " sentence = ''\n",
+ " result.clear()\n",
+ "\n",
+ " counter += 1\n",
+ " \n",
+ " if counter > max_len:\n",
+ " break \n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 47,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Person 1: Are there theatres in town?\n",
+ "Person 2: : There are 4 theatres in town. Do you have a preference? \n",
+ "Person 1: Not really, but I would like the one in the south. \n",
+ "Person 2: I have one theatre, the Junction, and the other is Tenpin. \n",
+ "Person 1: Could I get the address and postcode? \n",
+ "Person 1: The postcode is cb17gx, and the address is Clifton Way, Cambridge Leisure Park, Clifton Way. Is there anything else I can i give for? \n"
+ ]
+ }
+ ],
+ "source": [
+ "sample_sentence = ' Person 1: Are there theatres in town? Person 2: '\n",
+ "generate_dialogue(ReformerLM=model, model_state=STARTING_STATE, start_sentence=sample_sentence, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR, max_len=120, temperature=0.2)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 48,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Person 1: Is there a hospital nearby?\n",
+ "Person 2: : Addensbrookes Hospital is located at Hills Rd, Cambridge, postcode CB20QQ. Do you need a particular department? \n",
+ "Person 1: No, I just need the phone number, please. \n",
+ "Person 2: The phone number is 01223245151. \n",
+ "Person 1: Thank you. That's all I need. \n",
+ "Person 2: Thank you for using our services.Goodbye.\n",
+ "Person 1: Goodbye. \n"
+ ]
+ }
+ ],
+ "source": [
+ "sample_sentence = ' Person 1: Is there a hospital nearby? Person 2: '\n",
+ "generate_dialogue(ReformerLM=model, model_state=STARTING_STATE, start_sentence=sample_sentence, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR, max_len=120, temperature=0.2)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 49,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Person 1: Can you book a taxi?\n",
+ "Person 2: : I sure can. When would you like to leave? \n",
+ "Person 1: I need to leave after 13:00. \n",
+ "Person 2: I'd be happy to help with your request, first I will need to know your destination. \n",
+ "Person 1: I'm going to be going to be from the city stop restaurant. \n",
+ "Person 2: Booking completed! Booked car type\t:\tgrey volkswagen\n",
+ "Contact number\t:\t07262372\n",
+ " \n",
+ "Person 2: Thank bybybybyby\n"
+ ]
+ }
+ ],
+ "source": [
+ "sample_sentence = ' Person 1: Can you book a taxi? Person 2: '\n",
+ "generate_dialogue(ReformerLM=model, model_state=STARTING_STATE, start_sentence=sample_sentence, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR, max_len=120, temperature=0.2)"
+ ]
+ }
+ ],
+ "metadata": {
+ "coursera": {
+ "schema_names": [
+ "NLPC4-4"
+ ]
+ },
+ "jupytext": {
+ "encoding": "# -*- coding: utf-8 -*-",
+ "formats": "ipynb,py:percent",
+ "text_representation": {
+ "extension": ".py",
+ "format_name": "percent",
+ "format_version": "1.3",
+ "jupytext_version": "1.5.2"
+ }
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}