CLI rewrite #2

Merged: 7 commits, Feb 1, 2021
README.md: 56 changes (52 additions, 4 deletions)
@@ -1,13 +1,61 @@
# DHDRNet

This work accompanies my thesis: "**Exposure Fusion with Learned Image Selection**".
This work accompanies my thesis: "**Learned Exposure Selection for High Dynamic Range Image Synthesis**".

DHDRNet, or **D**ual-photo **HDR** **Net**work, lets you create HDR images from only two captures instead of the five commonly required.
DHDRNet, or **D**ual-photo **HDR** **Net**work, lets you create HDR images with 60% fewer captures (two instead of the usual five) compared with standard [Exposure Fusion](https://en.wikipedia.org/wiki/Exposure_fusion) [1].

It introduces a convolutional network that predicts the two optimal photos needed to generate a high-quality HDR image with Exposure Fusion. It also introduces a data generation system that creates synthetically exposed images and accounts for metadata corruption in order to generate a high-quality training set.

This system intelligently cuts the requirements of common Exposure Fusion HDR systems from five source images to just two.
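To make the selection task concrete, here is a minimal sketch of framing the choice of the second exposure as classification over discrete EV candidates. This is an illustration under assumptions, not the repository's actual model code; the candidate range matches the data generator's defaults, and the data generation code always pairs EV 0 with one candidate.

```python
import torch
from torchvision import models

# Candidate EVs matching the data generator's defaults: -3 to +6 in 0.25-EV steps.
candidate_evs = [-3 + 0.25 * i for i in range(37)]

# Hypothetical classifier: given the EV-0 capture, score each candidate second EV.
net = models.mobilenet_v2(num_classes=len(candidate_evs))

def predict_second_ev(ev0_image: torch.Tensor) -> float:
    """Pick the EV whose fusion with the EV-0 shot is predicted to score best."""
    logits = net(ev0_image.unsqueeze(0))  # shape (1, 37)
    return candidate_evs[int(logits.argmax())]
```

The EV-0 capture and the predicted second exposure are then the two inputs handed to Exposure Fusion.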
## Usage

There are two main components of this system. Firstly, the data generation system is used to take a set of HDR DNG files (like those found in the HDR+ [2] dataset from Google) and generate synthetic exposures, statistics, and exposure-fused images for training purposes.

The image reconstructions are generated ahead of training rather than on demand, because generating Exposure Fusion images at training time is too costly (the Exposure Fusion algorithm is CPU-bound).
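For reference, the precomputation boils down to OpenCV's Mertens merge, mirroring the `get_fused` logic in `dhdrnet/data_generator.py` further down this PR; the file names here are hypothetical:

```python
import cv2 as cv
import numpy as np

# A hypothetical exposure bracket; the data generator renders these
# synthetically from a single DNG.
paths = ["shot_ev-2.png", "shot_ev0.png", "shot_ev2.png"]
images = [cv.imread(p).astype("float32") for p in paths]

# Mertens exposure fusion runs on the CPU, which is why fused images
# are generated ahead of training rather than on the fly.
fused = cv.createMergeMertens().process(images)
fused = np.clip(fused * 255, 0, 255).astype("uint8")
cv.imwrite("fused.png", fused)
```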

This project is managed with [Poetry](https://python-poetry.org), so it is recommended to create the virtual environment with:

```sh
git clone https://github.com/smsegal/DHDRNet.git
cd DHDRNet
poetry install
```

### Data Generation

Download the HDR+ DNG files from Google with:
```sh
python -m dhdrnet.data_prep download --out=./foo
```
This utilizes the [Google Cloud SDK](https://cloud.google.com/sdk/docs/quickstart), so make sure you've signed in and authenticated.

After the DNG files are downloaded, the synthetic exposures and fused images can be generated with:

```sh
python -m dhdrnet.data_prep generate-data --download-dir=./foo --out=./bar
```

See the `python -m dhdrnet.data_prep --help` output for further command-line arguments that customize the data generation process.

The generated data will be stored in the specified directories for use in later training.

### Training

Several network backbones were explored and tested for efficacy and efficiency. MobileNetV2 struck the best balance between performance and model size, but other architectures are included for comparison: ResNet, SqueezeNet, MobileNetV1, and EfficientNet. See my thesis for more details.
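As a rough way to compare backbone sizes, here is a sketch using torchvision's stock versions of some of the models named above (EfficientNet is omitted since it is not bundled with older torchvision releases, and the project's adapted variants may differ):

```python
import torchvision.models as models

# Parameter counts for stock torchvision backbones, as a size comparison.
for name, ctor in {
    "mobilenet_v2": models.mobilenet_v2,
    "resnet18": models.resnet18,
    "squeezenet1_1": models.squeezenet1_1,
}.items():
    n_params = sum(p.numel() for p in ctor().parameters())
    print(f"{name}: {n_params / 1e6:.1f}M parameters")
```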

To train a network from scratch:

```sh
python -m dhdrnet.train --model=mobilenetv2 --data-dir=./bar --logdir=./baz
```

TensorBoard logs are written to the directory passed as `--logdir` (by default `logs/`).

See the output of `python -m dhdrnet.train --help` for more options that can customize training.

Code is written with PyTorch and PyTorch Lightning.

This code hasn't been fully cleaned up for public view; in-progress research notes and experiments remain in the notebooks and some code comments :)

[1] T. Mertens, J. Kautz, and F. Van Reeth, “Exposure Fusion: A Simple and Practical Alternative to High Dynamic Range Photography,” Computer Graphics Forum, vol. 28, no. 1, pp. 161–171, Mar. 2009, doi: 10.1111/j.1467-8659.2008.01171.x.

[2] S. W. Hasinoff et al., “Burst photography for high dynamic range and low-light imaging on mobile cameras,” ACM Trans. Graph., vol. 35, no. 6, pp. 1–12, Nov. 2016, doi: 10.1145/2980179.2980254.
dhdrnet/cv_fuse.py: 52 changes (0 additions, 52 deletions)

This file was deleted.

dhdrnet/gen_pairs.py → dhdrnet/data_generator.py: 130 changes (65 additions, 65 deletions)
@@ -1,21 +1,21 @@
import argparse
import operator as op
import sys
from collections import defaultdict
from functools import partial, reduce
from itertools import product, repeat
from itertools import product
from pathlib import Path
from typing import Callable, Collection, List, Optional
from typing import Callable, Collection, Dict, List, Optional

import cv2 as cv
import exifread
import numpy as np
import pandas as pd
import rawpy
import torch
from lpips import LPIPS, im2tensor
from more_itertools import flatten
from pandas.core.frame import DataFrame
from lpips import LPIPS as LPIPS_orig
from lpips import im2tensor
from more_itertools import collapse, flatten, one
from pandas import DataFrame
from skimage.metrics import (
mean_squared_error,
peak_signal_noise_ratio,
@@ -29,14 +29,14 @@


def main(args):
generator = GenAllPairs(
generator = DataGenerator(
raw_path=Path(args.raw_path),
out_path=Path(args.out_path),
store_path=Path(args.store_path),
exp_max=args.exp_max,
exp_min=args.exp_min,
exp_step=args.exp_step,
single_threaded=args.single_thread,
multithreaded=not args.single_thread,
image_names=args.image_names,
)
if args.updown:
@@ -50,7 +50,7 @@ def main(args):
data.to_csv(args.store_path)


class GenAllPairs:
class DataGenerator:
def __init__(
self,
raw_path: Path,
@@ -60,8 +60,9 @@ def __init__(
exp_min: float = -3,
exp_max: float = 6,
exp_step: float = 0.25,
single_threaded: bool = False,
multithreaded: bool = True,
image_names=None,
metrics: List[str] = ["rmse", "psnr", "ssim", "lpips"],
):
self.exposures: np.ndarray = np.linspace(
exp_min, exp_max, int((exp_max - exp_min) / exp_step + 1)
@@ -76,17 +77,23 @@ def __init__(

if image_names is None:
self.image_names = [p.stem for p in (DATA_DIR / "dngs").iterdir()]
elif isinstance(image_names, List):
self.image_names = image_names
else:
self.image_names = list(flatten(pd.read_csv(image_names).to_numpy()))

if compute_scores:
self.metricfuncs = {
all_metric_fns = {
"rmse": rmse,
"psnr": peak_signal_noise_ratio,
"ssim": partial(structural_similarity, multichannel=True),
"perceptual": PerceptualMetric(),
"lpips": LPIPS(),
}
self.metrics = list(self.metricfuncs.keys())
self.metric_fns: Dict[str, Callable] = {
k: v for k, v in all_metric_fns.items() if k in metrics
}

self.metrics = metrics

self.exp_out_path.mkdir(parents=True, exist_ok=True)
self.reconstructed_out_path.mkdir(parents=True, exist_ok=True)
@@ -95,36 +102,30 @@ def __init__(
self.fused_out_path.mkdir(parents=True, exist_ok=True)

if store_path:
self.store_path = store_path
self.store_path: Path = store_path
self.store: DataFrame
if store_path.is_file():
self.store = pd.read_csv(
store_path, usecols=["name", "metric", "ev1", "ev2", "score"]
)
else:
self.store = pd.DataFrame(
self.store = DataFrame(
data=None, columns=["name", "metric", "ev1", "ev2", "score"]
)
self.store_path.parent.mkdir(parents=True, exist_ok=True)

self.store.to_csv(self.store_path, index=False)

self.single_threaded = single_threaded

self.written_store = False

self._ff: Callable[
[List[np.ndarray]], np.ndarray
] = cv.createMergeMertens().process
self.multithreaded = multithreaded

def __call__(self):
if self.single_threaded:
stats = self.stats_dispatch()
else:
if self.multithreaded:
stats = self.stats_dispatch_parallel()
else:
stats = self.stats_dispatch()

stats_df = pd.DataFrame.from_dict(stats)
print(f"computed all stats, saved in {self.store}")
stats_df = DataFrame.from_dict(stats)
print(f"computed all stats, saved in {self.store_path}")
return stats_df

def stats_dispatch(self):
Expand All @@ -150,7 +151,7 @@ def compute_updown(self, image_names):
updown_img = self.get_updown(name, ev)
ground_truth = self.get_ground_truth(name)
for metric in self.metrics:
score = self.metricfuncs[metric](ground_truth, updown_img)
score = self.metric_fns[metric](ground_truth, updown_img)
records.append((name, metric, ev, score))

stats = pd.DataFrame.from_records(
@@ -173,9 +174,7 @@ def compute_stats(self, img_name):
stats["metric"].append(metric)
stats["ev1"].append(0.0)
stats["ev2"].append(ev)
stats["score"].append(
self.metricfuncs[metric](ground_truth, reconstruction)
)
stats["score"].append(self.metric_fns[metric](ground_truth, reconstruction))

df = pd.DataFrame.from_dict(stats)
df.to_csv(self.store_path, mode="a", header=None, index=False)
@@ -213,6 +212,9 @@ def get_reconstruction(self, name, ev1, ev2):
name, ev_list=[ev1, ev2], out_path=self.reconstructed_out_path
)

def fuse_fn(self, images) -> np.ndarray:
return cv.createMergeMertens().process(images)

def get_fused(
self,
name: str,
@@ -234,7 +236,7 @@ def get_fused(
else:
# print(f"generating fused file: {fused_path}", sys.stderr)
images = self.get_exposures(name, ev_list)
fused_im = self._ff([im.astype("float32") for im in images])
fused_im = self.fuse_fn([im.astype("float32") for im in images])
fused_im = np.clip(fused_im * 255, 0, 255).astype("uint8")

cv.imwrite(str(fused_path), fused_im)
@@ -309,48 +311,46 @@ def nested_dict_merge(d1, d2):
return merged


def PerceptualMetric(net: str = "alex") -> Callable:
from more_itertools import collapse, one
class LPIPS:
def __init__(self, net: str = "alex") -> None:
self.net: str = net
self.usegpu: bool = torch.cuda.is_available()
self.model: nn.Module = LPIPS_orig(net=net, spatial=False)
if self.usegpu:
self.model = self.model.cuda()

model: nn.Module = LPIPS(net=net, spatial=False)
usegpu = torch.cuda.is_available()
if usegpu:
model = model.cuda()

def perceptual_loss_metric(ima: torch.Tensor, imb: torch.Tensor) -> torch.Tensor:
def __call__(self, ima: torch.Tensor, imb: torch.Tensor) -> torch.Tensor:
ima_t, imb_t = map(im2tensor, [ima, imb])
if usegpu:
if self.usegpu:
ima_t = ima_t.cuda()
imb_t = imb_t.cuda()

dist = one(collapse(model.forward(ima_t, imb_t).data.cpu().numpy()))
dist = one(collapse(self.model.forward(ima_t, imb_t).data.cpu().numpy()))
return dist

return perceptual_loss_metric

# if __name__ == "__main__":
# parser = argparse.ArgumentParser(description="Generate stats and images")

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Generate stats and images")
# parser.add_argument("--out-path", "-o", help="where to save the processed files")
# parser.add_argument("--raw-path", help="location of raw files")
# parser.add_argument(
# "--store-path",
# help="file to store data in (created if does not exist)",
# default="store",
# )

parser.add_argument("--out-path", "-o", help="where to save the processed files")
parser.add_argument("--raw-path", help="location of raw files")
parser.add_argument(
"--store-path",
help="file to store data in (created if does not exist)",
default="store",
)
# parser.add_argument("--image-names", default=None)

parser.add_argument("--image-names", default=None)

parser.add_argument("--exp-min", default=-3)
parser.add_argument("--exp-max", default=6)
parser.add_argument("--exp-step", default=0.25)
parser.add_argument(
"--single-thread", help="single threaded mode", action="store_true"
)
parser.add_argument(
"--updown", help="compute the updown strategy", action="store_true"
)
# parser.add_argument("--exp-min", default=-3)
# parser.add_argument("--exp-max", default=6)
# parser.add_argument("--exp-step", default=0.25)
# parser.add_argument(
# "--multithread", help="single threaded mode", action="store_true"
# )
# parser.add_argument(
# "--updown", help="compute the updown strategy", action="store_true"
# )

args = parser.parse_args()
main(args)
# args = parser.parse_args()
# main(args)
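Based on the constructor and `main()` shown in this diff, the renamed `DataGenerator` might be driven directly like this. This is a sketch: the paths are hypothetical, and the `dhdrnet.data_prep` CLI described in the README is the supported entry point.

```python
from pathlib import Path

from dhdrnet.data_generator import DataGenerator

generator = DataGenerator(
    raw_path=Path("./foo"),              # downloaded HDR+ DNGs
    out_path=Path("./bar"),              # synthetic exposures and fused images
    store_path=Path("./bar/store.csv"),  # per-image score records
    exp_min=-3,
    exp_max=6,
    exp_step=0.25,
    multithreaded=True,
    metrics=["rmse", "psnr", "ssim", "lpips"],
)
stats_df = generator()  # DataFrame with columns: name, metric, ev1, ev2, score
```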
dhdrnet/data_module.py: 2 changes (1 addition, 1 deletion)
@@ -3,7 +3,7 @@
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from dhdrnet.util import DATA_DIR, ROOT_DIR
from dhdrnet.Dataset import LUTDataset
from dhdrnet.dataset import LUTDataset
from pathlib import Path
from typing import List, Optional, Union
