Clean up segmentation notebook to prep the code to be converted into a command line script
GiscardBiamby committed Feb 21, 2022
1 parent 48df7a3 commit 7a4954b
Showing 1 changed file with 54 additions and 155 deletions.
209 changes: 54 additions & 155 deletions notebooks/segment_vid.ipynb
@@ -2,51 +2,22 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"id": "756c6c24-1547-4e44-bef0-ba8114472865",
"metadata": {},
"outputs": [
{
"data": {
"application/javascript": [
"IPython.notebook.set_autosave_interval(60000)"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Autosaving every 60 seconds\n"
]
}
],
"outputs": [],
"source": [
"%autosave 60\n",
"%load_ext autoreload\n",
"%autoreload 2\n",
"%matplotlib inline\n",
"\n",
"import sys\n",
"from pathlib import Path"
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"id": "638d3161-e610-4acc-a4e3-da1bfb3d662c",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[1m\u001b[1mINFO \u001b[0m\u001b[1m\u001b[0m - \u001b[1mThe mmdet config folder already exists. No need to downloaded it. Path : /home/gbiamby/.icevision/mmdetection_configs/mmdetection_configs-2.20.1/configs\u001b[0m | \u001b[36micevision.models.mmdet.download_configs\u001b[0m:\u001b[36mdownload_mmdet_configs\u001b[0m:\u001b[36m17\u001b[0m\n"
]
}
],
"outputs": [],
"source": [
"import json\n",
"import logging\n",
@@ -58,30 +29,19 @@
"from datetime import datetime, timedelta\n",
"from io import BytesIO\n",
"from pathlib import Path\n",
"from types import ModuleType\n",
"from typing import Dict, List, Optional, Tuple, Union, cast\n",
"from typing import Any, Dict, List, Optional, Tuple, Union, cast\n",
"\n",
"import cv2\n",
"import matplotlib as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"import PIL\n",
"import PIL.Image as pil_img\n",
"import seaborn as sns\n",
"import sklearn as skl\n",
"from IPython.display import Image, display\n",
"from matplotlib.patches import Rectangle\n",
"from matplotlib_inline.backend_inline import set_matplotlib_formats\n",
"from termcolor import colored\n",
"from tqdm.contrib import tenumerate, tmap, tzip\n",
"from tqdm.contrib.bells import tqdm, trange\n",
"\n",
"from geoscreens.pseudolabels import reverse_point"
"from tqdm.contrib.bells import tqdm, trange"
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"id": "6bac2c9e-587a-4a4b-ac05-7bffac782d4f",
"metadata": {},
"outputs": [],
@@ -90,41 +50,7 @@
"pd.set_option(\"display.max_columns\", 15)\n",
"pd.set_option(\"display.max_rows\", 50)\n",
"# Suitable default display for floats\n",
"pd.options.display.float_format = \"{:,.2f}\".format\n",
"# matplotlib options\n",
"set_matplotlib_formats(\"pdf\", \"png\")\n",
"plt.rcParams[\"savefig.dpi\"] = 75\n",
"plt.rcParams[\"figure.autolayout\"] = False\n",
"plt.rcParams[\"figure.figsize\"] = 10, 6\n",
"plt.rcParams[\"axes.labelsize\"] = 18\n",
"plt.rcParams[\"axes.titlesize\"] = 20\n",
"plt.rcParams[\"font.size\"] = 16\n",
"plt.rcParams[\"lines.linewidth\"] = 2.0\n",
"plt.rcParams[\"lines.markersize\"] = 8\n",
"plt.rcParams[\"legend.fontsize\"] = 14\n",
"plt.rcParams[\"text.usetex\"] = True\n",
"plt.rcParams[\"font.family\"] = \"serif\"\n",
"plt.rcParams[\"font.serif\"] = \"cm\"\n",
"plt.rcParams[\"text.latex.preamble\"] = \"\\\\usepackage{subdepth}, \\\\usepackage{type1cm}\"\n",
"\n",
"# This one is optional -- change graphs to SVG only use if you don't have a\n",
"# lot of points/lines in your graphs. Can also just use ['retina'] if you\n",
"# don't want SVG.\n",
"%config InlineBackend.figure_formats = [\"retina\"]\n",
"set_matplotlib_formats(\"pdf\", \"png\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "e22a2a9e-b1a0-48da-ab67-7436969c19b6",
"metadata": {},
"outputs": [],
"source": [
"# VIDEO_PATH = Path(\"/shared/g-luo/geoguessr/videos\").resolve()\n",
"# OUT_PATH = Path(\"/shared/gbiamby/geo/screenshots/screen_samples_auto\").resolve()\n",
"# assert VIDEO_PATH.exists()\n",
"# assert OUT_PATH.exists()"
"pd.options.display.float_format = \"{:,.2f}\".format"
]
},
{
@@ -145,12 +71,13 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"id": "2d5a4ae7-cc11-4dc1-8983-46a05e0fffa8",
"metadata": {},
"outputs": [],
"source": [
"def parse_tuple(s: str):\n",
" \"\"\"Helper for load_detections_csv, to parse string column into column of Tuples.\"\"\"\n",
" if isinstance(s, str):\n",
" result = s.replace(\"(\", \"[\").replace(\")\", \"]\")\n",
" result = result.replace(\"'\", '\"').strip()\n",
@@ -164,6 +91,7 @@
"\n",
"\n",
"def parse_dict(s: str):\n",
" \"\"\"Helper for load_detections_csv, to parse string column into Dict.\"\"\"\n",
" if isinstance(s, str):\n",
" return json.loads(s.replace(\"'\", '\"'))\n",
" return s\n",
@@ -205,7 +133,7 @@
" df = pickle.load(open(dets_path, \"rb\"))\n",
"\n",
" if \"frame_time\" not in df.columns:\n",
" df[\"frame_time\"] = df.apply(lambda x: f\"{x.frame_id/4.0:04}\", axis=1)\n",
" df[\"frame_time\"] = df.apply(lambda x: f\"{x.frame_id/frame_sample_rate:04}\", axis=1)\n",
" if \"seconds\" not in df.columns:\n",
" df[\"seconds\"] = df.frame_id.apply(lambda frame_id: frame_id / frame_sample_rate)\n",
" if \"time\" not in df.columns:\n",
@@ -227,14 +155,20 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"id": "d9602177-ad52-4aca-b0bc-86ec074b84be",
"metadata": {},
"outputs": [],
"source": [
"def apply_smoothing(\n",
" df_framedets: pd.DataFrame, window_size: int = 5, direction: str = \"forward\"\n",
") -> None:\n",
" \"\"\"\n",
" Applies smoothing to the game_state column, storing the results in a new \"game_states_smoothed\"\n",
" column. Smoothing is only used at points where the game_state changes values, in which case the\n",
" new value only changes if it is the most common element in the buffer (which can be either be\n",
" look-ahead/backwards) of nearby game_states. Preferred direction is forward.\n",
" \"\"\"\n",
" smoothed = []\n",
" current_state = df_framedets.loc[0][\"game_state\"]\n",
" direction = direction.replace(\"backwards\", \"backward\").replace(\"forwards\", \"forward\")\n",
@@ -266,6 +200,10 @@
"\n",
"\n",
"def add_state_transition(state_transitions, row: pd.Series, from_state: str, to_state: str):\n",
" \"\"\"\n",
" Helper method to append transition from from one state to another to a list of game state end\n",
" points.\n",
" \"\"\"\n",
" state_transitions.extend(\n",
" [\n",
" {\n",
@@ -289,7 +227,14 @@
" )\n",
"\n",
"\n",
"def get_segments(df_framedets: pd.DataFrame, smoothing=False) -> List[Dict]:\n",
"def get_game_state_endpoints(df_framedets: pd.DataFrame, smoothing=False) -> List[Dict[str, Any]]:\n",
" \"\"\"\n",
" Given a DataDrame with the detections from a geoguessr video, returns list of dictionaries,\n",
" each representing either the start or end of a contiguous section of the video. The sections\n",
" tracked are either \"in_game\" or \"out_of_game\". Out of game can be anything such as\n",
" not_in_geoguessr, between round, end of round -- anything that isn't the user actually in the\n",
" game playing with the street view.\n",
" \"\"\"\n",
" current_state = \"out_of_game\"\n",
" state_transitions = []\n",
" state_key = \"game_state_smoothed\" if smoothing else \"game_state\"\n",
@@ -342,7 +287,11 @@
" return gt\n",
"\n",
"\n",
"def get_oog_segments(segments: List[Dict], game_state: str = None):\n",
"def endpoints_to_segments(segments: List[Dict], game_state: str = None):\n",
" \"\"\"\n",
" Collapses list of video segment endpoints into list of states. Each state in the return value\n",
" has information about the start and end, and duration of the segment.\n",
" \"\"\"\n",
" i = 1\n",
" segs = []\n",
" while i + 1 < len(segments):\n",
@@ -375,8 +324,8 @@
"\n",
"def classify_frame(dets: pd.Series) -> str:\n",
" \"\"\"\n",
" Input is a row of a pd.DataFrame. The row contains object detector output\n",
" for the geoguessr UI elements.\n",
" Input is a row of a pd.DataFrame. The row contains object detector output for the geoguessr UI\n",
" elements.\n",
" \"\"\"\n",
" label_set_base = set(dets[\"labels_set\"])\n",
" label_set = set(dets[\"labels_set\"])\n",
@@ -407,7 +356,11 @@
" return \"unknown\"\n",
"\n",
"\n",
"def compare_to_gt(segs: List[Dict], gt: Dict):\n",
"def compare_to_gt(segs: List[Dict], gt: Dict) -> None:\n",
" \"\"\"\n",
" Compares generated segments (in the DataFrame) to ground truth segments (Dict parameter), and\n",
" stores the result in the \"is_correct\" column of `segs` DataFrame.\n",
" \"\"\"\n",
" gt_oog_segs = gt[\"oog_segs\"]\n",
" for seg in segs:\n",
" if seg[\"state\"] == \"out_of_game\":\n",
@@ -434,7 +387,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"id": "470bc224-706d-4213-a0c2-1d716a878dc0",
"metadata": {
"tags": []
@@ -505,7 +458,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"id": "8389dab5-8429-4db7-9d40-9ad49f8247ab",
"metadata": {
"tags": []
@@ -704,7 +657,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": null,
"id": "93bc6865-0067-4d1f-aecc-f8665b4d80cf",
"metadata": {
"tags": []
@@ -750,8 +703,8 @@
")\n",
"df_framedets[\"game_state\"] = df_framedets.apply(classify_frame, axis=1)\n",
"apply_smoothing(df_framedets, window_size=5, direction=\"forward\")\n",
"seg = get_segments(df_framedets, smoothing=True)\n",
"segs_collapsed = get_oog_segments(seg)\n",
"end_points = get_game_state_endpoints(df_framedets, smoothing=True)\n",
"segs_collapsed = endpoints_to_segments(end_points)\n",
"#\n",
"print(\"video_id: \", video_id)\n",
"compare_to_gt(segs_collapsed, seg_gt_new[video_id])\n",
@@ -863,7 +816,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": null,
"id": "54665118-fee4-4db6-9967-b3b792b15c86",
"metadata": {
"tags": []
@@ -908,64 +861,10 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": null,
"id": "d62210d9-a27b-4e0b-8785-dd4695286bdd",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"num videos: 39\n",
"num videos: 45\n",
"video_ids missing detection files: {'Mk9x9VZpIi4', 'hzjA9gfxMeQ'}\n",
"Segmenting 43 videos...\n",
"video_id: x9mNJalP73w\n",
"video_id: jNKj2MXeah4\n",
"video_id: dY1RXh-43q4\n",
"video_id: drAmJ8r8_UI\n",
"video_id: osTwgzWluVs\n",
"video_id: Y8yW_BsZ018\n",
"video_id: NY3YDQvI1Ic\n",
"video_id: 88jRjbWTesc\n",
"video_id: NIr1XF0doag\n",
"video_id: 83m9ys4kxro\n",
"video_id: QAXV5-eUHVI\n",
"video_id: 96Th1-UjSOU\n",
"video_id: KA3r-gF1ub8\n",
"video_id: Tjxb3UzaduA\n",
"video_id: AF9uezxZDeE\n",
"video_id: hG6rJf0RBnk\n",
"video_id: 9RQUIk1OwAY\n",
"video_id: hEZVNDqid2I\n",
"video_id: qQMeHkwP8hg\n",
"video_id: oxGTI4ifaUI\n",
"video_id: zfZ6BxPne4E\n",
"video_id: ogJnHIuT8Yc\n",
"video_id: S5Ne5eoHxsY\n",
"video_id: YpqCkIfj1kQ\n",
"video_id: 8LXi_tpkpSg\n",
"video_id: o8qQAjkaXMM\n",
"video_id: 0J7cQ4FiDCc\n",
"video_id: pMgqa0mOExo\n",
"video_id: 2SzL5VBF_BI\n",
"video_id: dRG76uV8Gh8\n",
"video_id: oxQaoCK5-gw\n",
"video_id: NjriHMSM26k\n",
"video_id: mkx8bU_di1k\n",
"video_id: SB4UMgTRBe4\n",
"video_id: N57v3XC_KgU\n",
"video_id: 54c2PpV65hU\n",
"video_id: PzAXjKD4ZRg\n",
"video_id: UCQg1LJOywc\n",
"video_id: hZWt1PYH3hI\n",
"video_id: dO7TdYgtAWg\n",
"video_id: 8jWG2tLeVMw\n",
"video_id: nyHeQWnm8YA\n",
"video_id: tqny4LpSUiE\n"
]
}
],
"outputs": [],
"source": [
"segments = {}\n",
"seg_gt_new = load_gt(\"seg_ground_truth_009.json\")\n",
@@ -987,8 +886,8 @@
" df_framedets = load_detections(video_id, split=\"val\", model=model)\n",
" df_framedets[\"game_state\"] = df_framedets.apply(classify_frame, axis=1)\n",
" apply_smoothing(df_framedets)\n",
" seg = get_segments(df_framedets, smoothing=True)\n",
" segments[video_id] = get_oog_segments(seg)"
" end_points = get_game_state_endpoints(df_framedets, smoothing=True)\n",
" segments[video_id] = endpoints_to_segments(end_points)"
]
},
{
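The renamed helpers in this diff chain together per video exactly as the usage cells show: classify each frame from the detector output, smooth the per-frame game states, extract game-state endpoints, and collapse them into segments. Below is a minimal sketch of that flow, assuming the helpers have been moved into an importable module as the commit message intends; the module path geoscreens.video_seg and the wrapper name segment_video are hypothetical.

from typing import Any, Dict, List

# Hypothetical import path -- this commit preps the notebook code for a command
# line script, so the real module layout may differ.
from geoscreens.video_seg import (
    apply_smoothing,
    classify_frame,
    endpoints_to_segments,
    get_game_state_endpoints,
    load_detections,
)


def segment_video(video_id: str, model: str) -> List[Dict[str, Any]]:
    """Return in_game/out_of_game segments for one video, mirroring the notebook's usage cells."""
    # Per-frame detections -> per-frame game_state labels.
    df_framedets = load_detections(video_id, split="val", model=model)
    df_framedets["game_state"] = df_framedets.apply(classify_frame, axis=1)
    # Majority-vote smoothing, then find segment endpoints and collapse them into segments.
    apply_smoothing(df_framedets, window_size=5, direction="forward")
    end_points = get_game_state_endpoints(df_framedets, smoothing=True)
    return endpoints_to_segments(end_points)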

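The docstring added to apply_smoothing describes a majority-vote filter that only re-evaluates the state at transition points. The snippet below is a self-contained sketch of that idea for the forward direction only; smooth_game_states is a hypothetical name, and the notebook's apply_smoothing instead writes a game_state_smoothed column into the DataFrame in place and also supports a backward-looking buffer.

from collections import Counter

import pandas as pd


def smooth_game_states(states: pd.Series, window_size: int = 5) -> pd.Series:
    """Illustrative forward majority-vote smoothing of per-frame game states."""
    values = states.tolist()
    if not values:
        return pd.Series([], dtype=object, name="game_state_smoothed")
    smoothed = []
    current = values[0]
    for i, state in enumerate(values):
        if state != current:
            # Re-evaluate only at a transition: switch to the new state only if it is
            # the most common element in the look-ahead buffer of nearby frames.
            window = values[i : i + window_size]
            if Counter(window).most_common(1)[0][0] == state:
                current = state
        smoothed.append(current)
    return pd.Series(smoothed, index=states.index, name="game_state_smoothed")


# e.g. df_framedets["game_state_smoothed"] = smooth_game_states(df_framedets["game_state"])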