Skip to content

Commit

Permalink
Add support for RGBA & float dtype like matplotlib (#8)
Browse files Browse the repository at this point in the history
* Add support for RGBA & float dtype like matplotlib
  • Loading branch information
mthiboust committed Nov 15, 2023
1 parent 150a75d commit aefa25c
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 73 deletions.
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,14 @@ some_2d_vectors = np.random.rand(5, 2)

colormap2d.pinwheel(some_2d_vectors)

# RGBA float values between 0 and 1 (default)
# array([[0.69019608, 0.6627451 , 0.24705882, 1. ],
# [0.25490196, 0.74509804, 0.82352941, 1. ],
# [0.41960784, 0.2 , 0.79215686, 1. ],
# [0.18823529, 0.29803922, 0.20392157, 1. ],
# [0.24705882, 0.24705882, 0.66666667, 1. ]])

colormap2d.pinwheel(some_2d_vectors, mode="RGB", dtype=np.uint8)
# RGB integers between 0 and 255:
# array([[166, 179, 50],
# [ 50, 66, 94],
Expand Down
74 changes: 12 additions & 62 deletions notebooks/colormap.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
"source": [
"# Draw the colormaps\n",
"\n",
"We use `matplotlib` to generate a RGB image of a 2D grid on the [0:1]x[0:1] subspace. The 2D vectors are converted to their RGB equivalent thanks to the 2D colormaps."
"We use `matplotlib` to generate a RGB image of a 2D grid on the [0:1]x[0:1] subspace. The 2D vectors are converted to their RGB equivalent thanks to the 2D colormaps.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Import statements:"
"Import statements:\n"
]
},
{
Expand All @@ -23,7 +23,6 @@
"outputs": [],
"source": [
"\"\"\"Draw the colormaps.\"\"\"\n",
"\n",
"import colormap2d\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np"
Expand All @@ -33,7 +32,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Generation of the [0:1]x[0:1] 2D grid with 51x51 samples (the colormap data has also been sampled at 51x51):"
"Generation of the [0:1]x[0:1] 2D grid with 51x51 samples (the colormap data has also been sampled at 51x51):\n"
]
},
{
Expand Down Expand Up @@ -63,21 +62,21 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### `pinwheel` colormap:"
"### `pinwheel` colormap:\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x7fae50c74b50>"
"<matplotlib.image.AxesImage at 0x7ff6adf90e10>"
]
},
"execution_count": 6,
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
},
Expand All @@ -102,21 +101,21 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### `cyclic_pinwheel` colormap:"
"### `cyclic_pinwheel` colormap:\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x7fae50cd9950>"
"<matplotlib.image.AxesImage at 0x7ff6f326d0d0>"
]
},
"execution_count": 7,
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
},
Expand All @@ -136,55 +135,6 @@
"plt.title(\"Cyclic pinwheel colormap\")\n",
"plt.imshow(im, origin=\"lower\", extent=[0, 1, 0, 1])"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0.91270668, 0.60020465],\n",
" [0.51569033, 0.79642031],\n",
" [0.36533928, 0.31441287],\n",
" [0.41346207, 0.45163162],\n",
" [0.99135696, 0.05691322]])"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"array = np.random.rand(5, 2)\n",
"array"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[166, 179, 50],\n",
" [ 50, 66, 94],\n",
" [ 63, 98, 212],\n",
" [ 66, 66, 196],\n",
" [222, 199, 169]], dtype=uint8)"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"colormap2d.pinwheel(array)"
]
}
],
"metadata": {
Expand All @@ -203,7 +153,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
"version": "3.11.6"
}
},
"nbformat": 4,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "colormap2d"
version = "0.1.3"
version = "0.2.0"
description = "Colormap for 2D vectors"
authors = [
{ name = "Matthieu Thiboust", email = "14574229+mthiboust@users.noreply.github.com" },
Expand Down
40 changes: 32 additions & 8 deletions src/colormap2d/cm.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,32 +20,56 @@ def _load_npy(relative_path):
return np.load(path)


def _apply_colormap(arr, colormap):
def _apply_colormap(arr, colormap, mode="RGBA", dtype=np.float64):
if not isinstance(arr, np.ndarray):
raise TypeError(f"Parameter must be a numpy array, not {type(arr)}")
if arr.shape[-1] != 2:
raise ValueError(f"Last dimension of array shape {arr.shape} must be 2.")
if np.min(arr) < 0 or np.max(arr) > 1:
raise ValueError("Array values must be in the range [0:1].")
if mode not in ["RGB", "RGBA"]:
raise ValueError(f"Mode must be either 'RGB' or 'RGBA', not '{mode}'.")
if dtype not in [np.float64, np.uint8]:
raise ValueError(
f"Dtype must be either 'np.float64' for float values between 0 and 1, "
f"or 'np.uint8' for integer between 0 and 255, not '{dtype}'."
)

arr = np.round(arr * N).astype(np.int32)
colormap_data = _load_npy(colormap + ".npy")
return colormap_data[arr[..., 0], arr[..., 1]]
out = colormap_data[arr[..., 0], arr[..., 1]]

# Add an alpha channel filled with ones
if mode == "RGBA":
alpha = np.ones(arr.shape[:-1] + (1,), dtype=np.uint8) * 255
out = np.concatenate((out, alpha), axis=-1)

def pinwheel(arr):
# Convert back the values in the [0:1] range like `matplotlib.cm`` functions
if dtype == np.float64:
out = out / 255

return out


def pinwheel(arr, **kwargs):
"""Converts 2D coordinates into 3D RGB values.
Args:
arr: Numpy array whose last dimension is 2 and whose values belong to [0,1]
arr: Numpy array whose last dimension is 2 and whose values belong to [0,1].
**kwargs:
mode: 'RGB' or 'RGBA'. Default to 'RGBA'.
dtype: np.uint8 or np.float64. Default to np.float64.
"""
return _apply_colormap(arr, "pinwheel")
return _apply_colormap(arr, "pinwheel", **kwargs)


def cyclic_pinwheel(arr):
def cyclic_pinwheel(arr, **kwargs):
"""Converts 2D coordinates into 3D RGB values which are the same at each XY border.
Args:
arr: Numpy array whose last dimension is 2 and whose values belong to [0,1]
arr: Numpy array whose last dimension is 2 and whose values belong to [0,1].
**kwargs:
mode: 'RGB' or 'RGBA'. Default to 'RGBA'.
dtype: np.uint8 or np.float64. Default to np.float64.
"""
return _apply_colormap(arr, "cyclic_pinwheel")
return _apply_colormap(arr, "cyclic_pinwheel", **kwargs)
39 changes: 37 additions & 2 deletions tests/test_cm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,53 @@
def test_pinwheel_values():
"""Checks if colormap2d is outputing the expected values."""
v = np.array([[0.91270668, 0.60020465], [0.51569033, 0.79642031]])
rgb = colormap2d.pinwheel(v)

rgb = colormap2d.pinwheel(v, mode="RGB", dtype=np.uint8)
np.testing.assert_allclose(
rgb, np.array([[166, 179, 50], [50, 66, 94]], dtype=np.uint8)
)

rgba = colormap2d.pinwheel(v, mode="RGBA", dtype=np.uint8)
np.testing.assert_allclose(
rgba, np.array([[166, 179, 50, 255], [50, 66, 94, 255]], dtype=np.uint8)
)

rgb = colormap2d.pinwheel(v, mode="RGB", dtype=np.float64)
np.testing.assert_allclose(
rgb,
np.array(
[
[0.65098039, 0.70196078, 0.19607843],
[0.19607843, 0.25882353, 0.36862745],
],
dtype=np.float64,
),
)

rgba = colormap2d.pinwheel(v, mode="RGBA", dtype=np.float64)
np.testing.assert_allclose(
rgba,
np.array(
[
[0.65098039, 0.70196078, 0.19607843, 1.0],
[0.19607843, 0.25882353, 0.36862745, 1.0],
],
dtype=np.float64,
),
)


def test_pinwheel_multidim():
"""Checks if any shape is any provided the last dimension is 2."""
v = np.random.rand(5, 5, 5, 2)
rgb = colormap2d.pinwheel(v)

rgba = colormap2d.pinwheel(v)
assert rgba.shape == v.shape[:-1] + (4,)

rgba = colormap2d.pinwheel(v, mode="RGBA")
assert rgba.shape == v.shape[:-1] + (4,)

rgb = colormap2d.pinwheel(v, mode="RGB")
assert rgb.shape == v.shape[:-1] + (3,)


Expand Down

0 comments on commit aefa25c

Please sign in to comment.