Skip to content

Commit

Permalink
Support RayTransform custom backends (#1540)
Browse files Browse the repository at this point in the history
* Simplify adjoint method

* Make RayTransform operators independent from their implementations

* move __future__ import to beginning

* Keep interface of RayTransform.impl the same, fix lowercasing

* Share cache with adjoint

* Fix self.__impl initialization

* Style changes, change ValueError into TypeError

* Remove RayTransformImplBase

* Remove RayTransformBase; RayBackProjection becomes Operator

* Move _call logic into backend for potential optimization

* Remove backends from __all__, update imports

* Bring `RayBackProjection` class inline in `RayTransform.adjoint`

* Fix linear=True kwarg in `RayBackProjection`

* Decorate RayTransform backend calls with `_add_default_complex_impl`

* Fix import

* Use `impl_type.__name__` as a return value for custom types

* Add `geometry` with @Property

* Make `_check_impl` static

* Change class names of implementations

* Change `reco_space` into `vol_space` and formatting

* Change `reco_space` to `vol_space`

* Fix `self` in function call

* Fix complex spaces, and fix `out` argument

* Add properties for `vol_space` and `proj_space` to implementation classes

* Add docstrings

* Update test to include complex adjoint of `RayTransform`

* Make `_IMPL_STR2TYPE` public as `RAY_TRAFO_IMPLS`

* Do not reassign to `out` when `out` is not None

* Some formatting updates

- Opening and closing parens on separate lines if line
  needs split
- No more backslash line continuation
- Sorted imports
- Removal of remaining references to `DiscreteLp`

* Update of README.md in tomo subpackage

* Copyright notice and functool.wraps in backend utils

* Fix failing import in skimage_radon.py

* Change of words

* Docstring extended with use case

* Removed _ALL_IMPLS, enhanced exception messages, allow duck-typing `impl` string.

* Import sorting order corrected

* Renamed `_check_impl` and `create_impl`. Simplified docstrings.

* Whitespace changes

* Turn MD syntax in docstring to rst

Co-authored-by: Holger Kohr <ho.kohr@zoho.com>
  • Loading branch information
adriaangraas and kohr-h committed Apr 16, 2020
1 parent 282ef23 commit 6fba0ca
Show file tree
Hide file tree
Showing 9 changed files with 743 additions and 670 deletions.
20 changes: 10 additions & 10 deletions odl/test/tomo/backends/astra_cuda_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
import pytest

import odl
from odl.tomo.backends.astra_cuda import (
AstraCudaBackProjectorImpl, AstraCudaProjectorImpl)
from odl.tomo.backends.astra_cuda import AstraCudaImpl
from odl.tomo.util.testutils import skip_if_no_astra_cuda


Expand Down Expand Up @@ -85,24 +84,25 @@ def test_astra_cuda_projector(space_and_geometry):
"""Test ASTRA CUDA projector."""

# Create reco space and a phantom
reco_space, geom = space_and_geometry
phantom = odl.phantom.cuboid(reco_space)
vol_space, geom = space_and_geometry
phantom = odl.phantom.cuboid(vol_space)

# Make projection space
proj_space = odl.uniform_discr_frompartition(geom.partition,
dtype=reco_space.dtype)
dtype=vol_space.dtype)

# create RayTransform implementation
astra_cuda = AstraCudaImpl(geom, vol_space, proj_space)

# Forward evaluation
projector = AstraCudaProjectorImpl(geom, reco_space, proj_space)
proj_data = projector.call_forward(phantom)
proj_data = astra_cuda.call_forward(phantom)
assert proj_data in proj_space
assert proj_data.norm() > 0
assert np.all(proj_data.asarray() >= 0)

# Backward evaluation
back_projector = AstraCudaBackProjectorImpl(geom, reco_space, proj_space)
backproj = back_projector.call_backward(proj_data)
assert backproj in reco_space
backproj = astra_cuda.call_backward(proj_data)
assert backproj in vol_space
assert backproj.norm() > 0
assert np.all(proj_data.asarray() >= 0)

Expand Down
12 changes: 10 additions & 2 deletions odl/test/tomo/operators/ray_trafo_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
skip_if_no_astra, skip_if_no_astra_cuda, skip_if_no_skimage)
from odl.util.testutils import all_almost_equal, simple_fixture


# --- pytest fixtures --- #


Expand Down Expand Up @@ -106,7 +105,6 @@ def geometry(request):
'par2d skimage half_uniform'])
)


projector_ids = [
" geom='{}' - impl='{}' - angles='{}' ".format(*p.values[0].split())
for p in projectors
Expand Down Expand Up @@ -339,6 +337,16 @@ def test_complex(impl):
assert all_almost_equal(data.real, true_data_re)
assert all_almost_equal(data.imag, true_data_im)

# test adjoint for complex data
backproj_r = ray_trafo_r.adjoint
backproj_c = ray_trafo_c.adjoint
true_vol_re = backproj_r(data.real)
true_vol_im = backproj_r(data.imag)
backproj_vol = backproj_c(data)

assert all_almost_equal(backproj_vol.real, true_vol_re)
assert all_almost_equal(backproj_vol.imag, true_vol_im)


def test_anisotropic_voxels(geometry):
"""Test projection and backprojection with anisotropic voxels."""
Expand Down
2 changes: 1 addition & 1 deletion odl/tomo/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ This directory contains all of the source code related tomographic reconstructio
* [analytic](analytic) Analytic reconstruction methods such as filtered back-projection. Also contains various utilities like `parker_weighting`.
* [backends](backends) Bindings to external libraries.
* [geometry](geometry) Definitions of projection geometries.
* [operators](operators) Defines the `RayTransform` operator and its adjoint ("back-projection").
* [operators](operators) Defines the `RayTransform` operator.
* [util](util) Utilities used internally.
2 changes: 2 additions & 0 deletions odl/tomo/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@
from .astra_cuda import *
from .astra_setup import *
from .skimage_radon import *
from .util import *

__all__ = ()
__all__ += astra_cpu.__all__
__all__ += astra_cuda.__all__
__all__ += astra_setup.__all__
__all__ += util.__all__
__all__ += skimage_radon.__all__
167 changes: 132 additions & 35 deletions odl/tomo/backends/astra_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,15 @@

from __future__ import absolute_import, division, print_function

import warnings

import numpy as np

from odl.discr import DiscretizedSpace, DiscretizedSpaceElement
from odl.tomo.backends.astra_setup import (
astra_algorithm, astra_data, astra_projection_geometry, astra_projector,
astra_volume_geometry)
from odl.tomo.backends.util import _add_default_complex_impl
from odl.tomo.geometry import (
DivergentBeamGeometry, Geometry, ParallelBeamGeometry)
from odl.util import writable_array
Expand Down Expand Up @@ -94,30 +97,42 @@ def astra_cpu_forward_projector(vol_data, geometry, proj_space, out=None,
If ``out`` was provided, the returned object is a reference to it.
"""
if not isinstance(vol_data, DiscretizedSpaceElement):
raise TypeError('volume data {!r} is not a `DiscretizedSpaceElement` '
'instance.'.format(vol_data))
raise TypeError(
'volume data {!r} is not a `DiscretizedSpaceElement` instance'
''.format(vol_data)
)
if vol_data.space.impl != 'numpy':
raise TypeError("`vol_data.space.impl` must be 'numpy', got {!r}"
"".format(vol_data.space.impl))
raise TypeError(
"`vol_data.space.impl` must be 'numpy', got {!r}"
"".format(vol_data.space.impl)
)
if not isinstance(geometry, Geometry):
raise TypeError('geometry {!r} is not a Geometry instance'
''.format(geometry))
raise TypeError(
'geometry {!r} is not a Geometry instance'.format(geometry)
)
if not isinstance(proj_space, DiscretizedSpace):
raise TypeError('`proj_space` {!r} is not a DiscretizedSpace '
'instance.'.format(proj_space))
raise TypeError(
'`proj_space` {!r} is not a DiscretizedSpace instance.'
''.format(proj_space)
)
if proj_space.impl != 'numpy':
raise TypeError("`proj_space.impl` must be 'numpy', got {!r}"
"".format(proj_space.impl))
raise TypeError(
"`proj_space.impl` must be 'numpy', got {!r}"
"".format(proj_space.impl)
)
if vol_data.ndim != geometry.ndim:
raise ValueError('dimensions {} of volume data and {} of geometry '
'do not match'
''.format(vol_data.ndim, geometry.ndim))
raise ValueError(
'dimensions {} of volume data and {} of geometry do not match'
''.format(vol_data.ndim, geometry.ndim)
)
if out is None:
out = proj_space.element()
else:
if out not in proj_space:
raise TypeError('`out` {} is neither None nor a '
'DiscretizedSpaceElement instance'.format(out))
raise TypeError(
'`out` {} is neither None nor a `DiscretizedSpaceElement` '
'instance'.format(out)
)

ndim = vol_data.ndim

Expand Down Expand Up @@ -188,28 +203,37 @@ def astra_cpu_back_projector(proj_data, geometry, vol_space, out=None,
'instance'.format(proj_data)
)
if proj_data.space.impl != 'numpy':
raise TypeError('`proj_data` must be a `numpy.ndarray` based, '
"container got `impl` {!r}"
"".format(proj_data.space.impl))
raise TypeError(
'`proj_data` must be a `numpy.ndarray` based, container, '
"got `impl` {!r}".format(proj_data.space.impl)
)
if not isinstance(geometry, Geometry):
raise TypeError('geometry {!r} is not a Geometry instance'
''.format(geometry))
raise TypeError(
'geometry {!r} is not a Geometry instance'.format(geometry)
)
if not isinstance(vol_space, DiscretizedSpace):
raise TypeError('volume space {!r} is not a DiscretizedSpace '
'instance'.format(vol_space))
raise TypeError(
'volume space {!r} is not a DiscretizedSpace instance'
''.format(vol_space)
)
if vol_space.impl != 'numpy':
raise TypeError("`vol_space.impl` must be 'numpy', got {!r}"
"".format(vol_space.impl))
raise TypeError(
"`vol_space.impl` must be 'numpy', got {!r}".format(vol_space.impl)
)
if vol_space.ndim != geometry.ndim:
raise ValueError('dimensions {} of reconstruction space and {} of '
'geometry do not match'.format(
vol_space.ndim, geometry.ndim))
raise ValueError(
'dimensions {} of reconstruction space and {} of geometry '
'do not match'
''.format(vol_space.ndim, geometry.ndim)
)
if out is None:
out = vol_space.element()
else:
if out not in vol_space:
raise TypeError('`out` {} is neither None nor a '
'DiscretizedSpaceElement instance'.format(out))
raise TypeError(
'`out` {} is neither None nor a `DiscretizedSpaceElement` '
'instance'.format(out)
)

ndim = proj_data.ndim

Expand All @@ -218,8 +242,9 @@ def astra_cpu_back_projector(proj_data, geometry, vol_space, out=None,
proj_geom = astra_projection_geometry(geometry)

# Create ASTRA data structure
sino_id = astra_data(proj_geom, datatype='projection', data=proj_data,
allow_copy=True)
sino_id = astra_data(
proj_geom, datatype='projection', data=proj_data, allow_copy=True
)

# Create projector
if astra_proj_type is None:
Expand All @@ -228,11 +253,13 @@ def astra_cpu_back_projector(proj_data, geometry, vol_space, out=None,

# Convert out to correct dtype and order if needed.
with writable_array(out, dtype='float32', order='C') as out_arr:
vol_id = astra_data(vol_geom, datatype='volume', data=out_arr,
ndim=vol_space.ndim)
vol_id = astra_data(
vol_geom, datatype='volume', data=out_arr, ndim=vol_space.ndim
)
# Create algorithm
algo_id = astra_algorithm('backward', ndim, vol_id, sino_id, proj_id,
impl='cpu')
algo_id = astra_algorithm(
'backward', ndim, vol_id, sino_id, proj_id, impl='cpu'
)

# Run algorithm
astra.algorithm.run(algo_id)
Expand All @@ -251,6 +278,76 @@ def astra_cpu_back_projector(proj_data, geometry, vol_space, out=None,
return out


class AstraCpuImpl:
"""Thin wrapper implementing ASTRA CPU for `RayTransform`."""

def __init__(self, geometry, vol_space, proj_space):
"""Initialize a new instance.
Parameters
----------
geometry : `Geometry`
Geometry defining the tomographic setup.
vol_space : `DiscreteLp`
Reconstruction space, the space of the images to be forward
projected.
proj_space : `DiscreteLp`
Projection space, the space of the result.
"""
if not isinstance(geometry, Geometry):
raise TypeError(
'`geometry` must be a `Geometry` instance, got {!r}'
''.format(geometry)
)
if not isinstance(vol_space, DiscretizedSpace):
raise TypeError(
'`vol_space` must be a `DiscretizedSpace` instance, got {!r}'
''.format(vol_space)
)
if not isinstance(proj_space, DiscretizedSpace):
raise TypeError(
'`proj_space` must be a `DiscretizedSpace` instance, got {!r}'
''.format(proj_space)
)
if geometry.ndim > 2:
raise ValueError(
'`impl` {!r} only works for 2d'.format(self.__name__)
)

if vol_space.size >= 512 ** 2:
warnings.warn(
"The 'astra_cpu' backend may be too slow for volumes of this "
"size. Consider using 'astra_cuda' if your machine has an "
"Nvidia GPU.",
RuntimeWarning,
)

self.geometry = geometry
self._vol_space = vol_space
self._proj_space = proj_space

@property
def vol_space(self):
return self._vol_space

@property
def proj_space(self):
return self._proj_space

@_add_default_complex_impl
def call_backward(self, x, out, **kwargs):
return astra_cpu_back_projector(
x, self.geometry, self.vol_space.real_space, out, **kwargs
)

@_add_default_complex_impl
def call_forward(self, x, out, **kwargs):
return astra_cpu_forward_projector(
x, self.geometry, self.proj_space.real_space, out, **kwargs
)


if __name__ == '__main__':
from odl.util.testutils import run_doctests

run_doctests()
Loading

0 comments on commit 6fba0ca

Please sign in to comment.