Merge branch 'cli' into main
smsegal committed Feb 1, 2021
2 parents a74b0f2 + 0604d9f commit 26ac2e7
Showing 67 changed files with 1,060 additions and 5,696 deletions.
56 changes: 52 additions & 4 deletions README.md
@@ -1,13 +1,61 @@
# DHDRNet

This work accompanies my thesis: "**Exposure Fusion with Learned Image Selection**".
This work accompanies my thesis: "**Learned Exposure Selection for High Dynamic Range Image Synthesis**".

DHDRNet, or **D**ual-photo **HDR** **Net**work lets you effectively create HDR images from only two captures, instead of the usual five commonly required.
DHDRNet, or **D**ual-photo **HDR** **Net**work, lets you create HDR images with 60% fewer resources than standard [Exposure Fusion](https://en.wikipedia.org/wiki/Exposure_fusion) [1].

It introduces a convolutional network that predicts the two optimal photos needed to generate a high-quality HDR image with Exposure Fusion. It also introduces a data generation system that creates synthetically exposed images and accounts for metadata corruption in order to produce a high-quality training set.

This system effectively cuts the requirements of common Exposure Fusion HDR systems from needing 5 source images to only two, in an intelligent manner.
## Usage

This system has two main components: data generation and training. The data generation component takes a set of HDR DNG files (such as those in the HDR+ [2] dataset from Google) and generates synthetic exposures, statistics, and exposure-fused images for training.

The image reconstructions are generated ahead of training rather than on demand, since generating Exposure Fusion images at training time is computationally expensive (the Exposure Fusion algorithm is CPU-bound).
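
The fusion step itself is OpenCV's Mertens exposure fusion, as used by `fuse_fn` in `dhdrnet/data_generator.py`. A minimal standalone sketch (file names are illustrative):

```python
import cv2 as cv
import numpy as np

# Load a bracketed exposure stack (file names are illustrative).
images = [cv.imread(p) for p in ("ev_low.png", "ev_high.png")]

# Mertens exposure fusion works directly on the LDR stack;
# no camera response calibration or tonemapping is needed.
fused = cv.createMergeMertens().process(
    [im.astype("float32") for im in images]
)

# The result is float, roughly in [0, 1]; clip and rescale to save.
fused_8bit = np.clip(fused * 255, 0, 255).astype("uint8")
cv.imwrite("fused.png", fused_8bit)
```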

This project is managed with [Poetry](https://python-poetry.org), so the recommended way to set up the environment is:

```sh
git clone https://github.com/smsegal/DHDRNet.git
cd DHDRNet
poetry install
```
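
Subsequent commands can then be run inside the Poetry-managed environment, e.g. `poetry run python -m dhdrnet.data_prep --help`.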

### Data Generation

Download the HDR+ DNG files from Google with:
```sh
python -m dhdrnet.data_prep download --out=./foo
```
This uses the [Google Cloud SDK](https://cloud.google.com/sdk/docs/quickstart), so make sure you've signed in and authenticated.
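
If you haven't set up the SDK before, authentication typically looks something like the following (exact steps depend on your gcloud configuration):

```sh
gcloud auth login
gcloud auth application-default login
```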

After the DNG files are downloaded, the synthetic exposures and fused images can be generated with:

```sh
python -m dhdrnet.data_prep generate-data --download-dir=./foo --out=./bar
```

See the `python -m dhdrnet.data_prep --help` output for further command-line arguments that customize the data generation process.

The generated data is stored in the specified directories for use during training.
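
The exact layout depends on the flags above, but conceptually the generated output looks something like this (directory names are illustrative; the score columns match the CSV schema in `dhdrnet/data_generator.py`):

```
bar/
├── exposures/        # synthetic EV renderings of each DNG
├── reconstructions/  # exposure-fused image for each candidate EV pair
└── store.csv         # per-pair scores: name, metric (rmse/psnr/ssim/lpips), ev1, ev2, score
```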

### Training

Several network backbones were explored and tested for efficacy and efficiency. MobileNetV2 struck the best balance between performance and model size, but other backbones (ResNet, SqueezeNet, MobileNetV1, and EfficientNet) are included for comparison. See my thesis for more details.
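
Under the hood, exposure selection amounts to classification over candidate EV pairs. A minimal sketch of adapting a MobileNetV2 backbone for this (the pair count and head wiring are illustrative assumptions, not the exact training code):

```python
import torch
import torch.nn as nn
from torchvision import models

NUM_EV_PAIRS = 36  # illustrative: number of candidate exposure pairs

# Start from an ImageNet-pretrained MobileNetV2 backbone.
model = models.mobilenet_v2(pretrained=True)

# Swap the classifier head to score each candidate EV pair.
model.classifier[1] = nn.Linear(model.last_channel, NUM_EV_PAIRS)

# A forward pass maps an input image to scores over EV pairs.
scores = model(torch.randn(1, 3, 224, 224))
best_pair = scores.argmax(dim=1)
```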

To train a network from scratch:

```sh
python -m dhdrnet.train --model=mobilenetv2 --data-dir=./bar --logdir=./baz
```

By default, TensorBoard logs are written to `logs/`.
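
Training can then be monitored with TensorBoard, for example:

```sh
tensorboard --logdir ./baz
```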

See the output of `python -m dhdrnet.train --help` for more options that can customize training.

The code is written with PyTorch and PyTorch Lightning.

This code hasn't necessarily been cleaned up for public view: there are plenty of in-progress research notes and experiments in the notebooks, and maybe in the code comments :)

[1] T. Mertens, J. Kautz, and F. Van Reeth, “Exposure Fusion: A Simple and Practical Alternative to High Dynamic Range Photography,” Computer Graphics Forum, vol. 28, no. 1, pp. 161–171, Mar. 2009, doi: 10.1111/j.1467-8659.2008.01171.x.

[2] S. W. Hasinoff et al., “Burst photography for high dynamic range and low-light imaging on mobile cameras,” ACM Trans. Graph., vol. 35, no. 6, pp. 1–12, Nov. 2016, doi: 10.1145/2980179.2980254.
52 changes: 0 additions & 52 deletions dhdrnet/cv_fuse.py

This file was deleted.

130 changes: 65 additions & 65 deletions dhdrnet/gen_pairs.py → dhdrnet/data_generator.py
@@ -1,21 +1,21 @@
import argparse
import operator as op
import sys
from collections import defaultdict
from functools import partial, reduce
from itertools import product, repeat
from itertools import product
from pathlib import Path
from typing import Callable, Collection, List, Optional
from typing import Callable, Collection, Dict, List, Optional

import cv2 as cv
import exifread
import numpy as np
import pandas as pd
import rawpy
import torch
from lpips import LPIPS, im2tensor
from more_itertools import flatten
from pandas.core.frame import DataFrame
from lpips import LPIPS as LPIPS_orig
from lpips import im2tensor
from more_itertools import collapse, flatten, one
from pandas import DataFrame
from skimage.metrics import (
mean_squared_error,
peak_signal_noise_ratio,
@@ -29,14 +29,14 @@


def main(args):
generator = GenAllPairs(
generator = DataGenerator(
raw_path=Path(args.raw_path),
out_path=Path(args.out_path),
store_path=Path(args.store_path),
exp_max=args.exp_max,
exp_min=args.exp_min,
exp_step=args.exp_step,
single_threaded=args.single_thread,
multithreaded=not args.single_thread,
image_names=args.image_names,
)
if args.updown:
@@ -50,7 +50,7 @@ def main(args):
data.to_csv(args.store_path)


class GenAllPairs:
class DataGenerator:
def __init__(
self,
raw_path: Path,
@@ -60,8 +60,9 @@ def __init__(
exp_min: float = -3,
exp_max: float = 6,
exp_step: float = 0.25,
single_threaded: bool = False,
multithreaded: bool = True,
image_names=None,
metrics: List[str] = ["rmse", "psnr", "ssim", "lpips"],
):
self.exposures: np.ndarray = np.linspace(
exp_min, exp_max, int((exp_max - exp_min) / exp_step + 1)
@@ -76,17 +77,23 @@ def __init__(

if image_names is None:
self.image_names = [p.stem for p in (DATA_DIR / "dngs").iterdir()]
elif isinstance(image_names, List):
self.image_names = image_names
else:
self.image_names = list(flatten(pd.read_csv(image_names).to_numpy()))

if compute_scores:
self.metricfuncs = {
all_metric_fns = {
"rmse": rmse,
"psnr": peak_signal_noise_ratio,
"ssim": partial(structural_similarity, multichannel=True),
"perceptual": PerceptualMetric(),
"lpips": LPIPS(),
}
self.metrics = list(self.metricfuncs.keys())
self.metric_fns: Dict[str, Callable] = {
k: v for k, v in all_metric_fns.items() if k in metrics
}

self.metrics = metrics

self.exp_out_path.mkdir(parents=True, exist_ok=True)
self.reconstructed_out_path.mkdir(parents=True, exist_ok=True)
@@ -95,36 +102,30 @@ def __init__(
self.fused_out_path.mkdir(parents=True, exist_ok=True)

if store_path:
self.store_path = store_path
self.store_path: Path = store_path
self.store: DataFrame
if store_path.is_file():
self.store = pd.read_csv(
store_path, usecols=["name", "metric", "ev1", "ev2", "score"]
)
else:
self.store = pd.DataFrame(
self.store = DataFrame(
data=None, columns=["name", "metric", "ev1", "ev2", "score"]
)
self.store_path.parent.mkdir(parents=True, exist_ok=True)

self.store.to_csv(self.store_path, index=False)

self.single_threaded = single_threaded

self.written_store = False

self._ff: Callable[
[List[np.ndarray]], np.ndarray
] = cv.createMergeMertens().process
self.multithreaded = multithreaded

def __call__(self):
if self.single_threaded:
stats = self.stats_dispatch()
else:
if self.multithreaded:
stats = self.stats_dispatch_parallel()
else:
stats = self.stats_dispatch()

stats_df = pd.DataFrame.from_dict(stats)
print(f"computed all stats, saved in {self.store}")
stats_df = DataFrame.from_dict(stats)
print(f"computed all stats, saved in {self.store_path}")
return stats_df

def stats_dispatch(self):
@@ -150,7 +151,7 @@ def compute_updown(self, image_names):
updown_img = self.get_updown(name, ev)
ground_truth = self.get_ground_truth(name)
for metric in self.metrics:
score = self.metricfuncs[metric](ground_truth, updown_img)
score = self.metric_fns[metric](ground_truth, updown_img)
records.append((name, metric, ev, score))

stats = pd.DataFrame.from_records(
@@ -173,9 +174,7 @@ def compute_stats(self, img_name):
stats["metric"].append(metric)
stats["ev1"].append(0.0)
stats["ev2"].append(ev)
stats["score"].append(
self.metricfuncs[metric](ground_truth, reconstruction)
)
stats["score"].append(self.metric_fns[metric](ground_truth, reconstruction))

df = pd.DataFrame.from_dict(stats)
df.to_csv(self.store_path, mode="a", header=None, index=False)
@@ -213,6 +212,9 @@ def get_reconstruction(self, name, ev1, ev2):
name, ev_list=[ev1, ev2], out_path=self.reconstructed_out_path
)

def fuse_fn(self, images) -> np.ndarray:
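# Mertens exposure fusion: fuses the LDR stack directly, with no HDR merge or tonemapping step.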
return cv.createMergeMertens().process(images)

def get_fused(
self,
name: str,
@@ -234,7 +236,7 @@ def get_fused(
else:
# print(f"generating fused file: {fused_path}", sys.stderr)
images = self.get_exposures(name, ev_list)
fused_im = self._ff([im.astype("float32") for im in images])
fused_im = self.fuse_fn([im.astype("float32") for im in images])
fused_im = np.clip(fused_im * 255, 0, 255).astype("uint8")

cv.imwrite(str(fused_path), fused_im)
@@ -309,48 +311,46 @@ def nested_dict_merge(d1, d2):
return merged


def PerceptualMetric(net: str = "alex") -> Callable:
from more_itertools import collapse, one
class LPIPS:
def __init__(self, net: str = "alex") -> None:
self.net: str = net
self.usegpu: bool = torch.cuda.is_available()
self.model: nn.Module = LPIPS_orig(net=net, spatial=False)
if self.usegpu:
self.model = self.model.cuda()

model: nn.Module = LPIPS(net=net, spatial=False)
usegpu = torch.cuda.is_available()
if usegpu:
model = model.cuda()

def perceptual_loss_metric(ima: torch.Tensor, imb: torch.Tensor) -> torch.Tensor:
def __call__(self, ima: torch.Tensor, imb: torch.Tensor) -> torch.Tensor:
ima_t, imb_t = map(im2tensor, [ima, imb])
if usegpu:
if self.usegpu:
ima_t = ima_t.cuda()
imb_t = imb_t.cuda()

dist = one(collapse(model.forward(ima_t, imb_t).data.cpu().numpy()))
dist = one(collapse(self.model.forward(ima_t, imb_t).data.cpu().numpy()))
return dist

return perceptual_loss_metric

# if __name__ == "__main__":
# parser = argparse.ArgumentParser(description="Generate stats and images")

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Generate stats and images")
# parser.add_argument("--out-path", "-o", help="where to save the processed files")
# parser.add_argument("--raw-path", help="location of raw files")
# parser.add_argument(
# "--store-path",
# help="file to store data in (created if does not exist)",
# default="store",
# )

parser.add_argument("--out-path", "-o", help="where to save the processed files")
parser.add_argument("--raw-path", help="location of raw files")
parser.add_argument(
"--store-path",
help="file to store data in (created if does not exist)",
default="store",
)
# parser.add_argument("--image-names", default=None)

parser.add_argument("--image-names", default=None)

parser.add_argument("--exp-min", default=-3)
parser.add_argument("--exp-max", default=6)
parser.add_argument("--exp-step", default=0.25)
parser.add_argument(
"--single-thread", help="single threaded mode", action="store_true"
)
parser.add_argument(
"--updown", help="compute the updown strategy", action="store_true"
)
# parser.add_argument("--exp-min", default=-3)
# parser.add_argument("--exp-max", default=6)
# parser.add_argument("--exp-step", default=0.25)
# parser.add_argument(
# "--multithread", help="single threaded mode", action="store_true"
# )
# parser.add_argument(
# "--updown", help="compute the updown strategy", action="store_true"
# )

args = parser.parse_args()
main(args)
# args = parser.parse_args()
# main(args)
2 changes: 1 addition & 1 deletion dhdrnet/data_module.py
@@ -3,7 +3,7 @@
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from dhdrnet.util import DATA_DIR, ROOT_DIR
from dhdrnet.Dataset import LUTDataset
from dhdrnet.dataset import LUTDataset
from pathlib import Path
from typing import List, Optional, Union

