diff --git a/README.md b/README.md index 456dbd0..399aaf2 100644 --- a/README.md +++ b/README.md @@ -30,6 +30,14 @@ some_2d_vectors = np.random.rand(5, 2) colormap2d.pinwheel(some_2d_vectors) +# RGBA float values between 0 and 1 (default) +# array([[0.69019608, 0.6627451 , 0.24705882, 1. ], +# [0.25490196, 0.74509804, 0.82352941, 1. ], +# [0.41960784, 0.2 , 0.79215686, 1. ], +# [0.18823529, 0.29803922, 0.20392157, 1. ], +# [0.24705882, 0.24705882, 0.66666667, 1. ]]) + +colormap2d.pinwheel(some_2d_vectors, mode="RGB", dtype=np.uint8) # RGB integers between 0 and 255: # array([[166, 179, 50], # [ 50, 66, 94], diff --git a/notebooks/colormap.ipynb b/notebooks/colormap.ipynb index aa98a95..83bf643 100644 --- a/notebooks/colormap.ipynb +++ b/notebooks/colormap.ipynb @@ -6,14 +6,14 @@ "source": [ "# Draw the colormaps\n", "\n", - "We use `matplotlib` to generate a RGB image of a 2D grid on the [0:1]x[0:1] subspace. The 2D vectors are converted to their RGB equivalent thanks to the 2D colormaps." + "We use `matplotlib` to generate a RGB image of a 2D grid on the [0:1]x[0:1] subspace. The 2D vectors are converted to their RGB equivalent thanks to the 2D colormaps.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Import statements:" + "Import statements:\n" ] }, { @@ -23,7 +23,6 @@ "outputs": [], "source": [ "\"\"\"Draw the colormaps.\"\"\"\n", - "\n", "import colormap2d\n", "import matplotlib.pyplot as plt\n", "import numpy as np" @@ -33,7 +32,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Generation of the [0:1]x[0:1] 2D grid with 51x51 samples (the colormap data has also been sampled at 51x51):" + "Generation of the [0:1]x[0:1] 2D grid with 51x51 samples (the colormap data has also been sampled at 51x51):\n" ] }, { @@ -63,21 +62,21 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### `pinwheel` colormap:" + "### `pinwheel` colormap:\n" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 6, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" }, @@ -102,21 +101,21 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### `cyclic_pinwheel` colormap:" + "### `cyclic_pinwheel` colormap:\n" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 7, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" }, @@ -136,55 +135,6 @@ "plt.title(\"Cyclic pinwheel colormap\")\n", "plt.imshow(im, origin=\"lower\", extent=[0, 1, 0, 1])" ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([[0.91270668, 0.60020465],\n", - " [0.51569033, 0.79642031],\n", - " [0.36533928, 0.31441287],\n", - " [0.41346207, 0.45163162],\n", - " [0.99135696, 0.05691322]])" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "array = np.random.rand(5, 2)\n", - "array" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([[166, 179, 50],\n", - " [ 50, 66, 94],\n", - " [ 63, 98, 212],\n", - " [ 66, 66, 196],\n", - " [222, 199, 169]], dtype=uint8)" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "colormap2d.pinwheel(array)" - ] } ], "metadata": { @@ -203,7 +153,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.4" + "version": "3.11.6" } }, "nbformat": 4, diff --git a/pyproject.toml b/pyproject.toml index 7fe4688..0d1d58e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "colormap2d" -version = "0.1.3" +version = "0.2.0" description = "Colormap for 2D vectors" authors = [ { name = "Matthieu Thiboust", email = "14574229+mthiboust@users.noreply.github.com" }, diff --git a/src/colormap2d/cm.py b/src/colormap2d/cm.py index 41a7cb4..7564dfb 100644 --- a/src/colormap2d/cm.py +++ b/src/colormap2d/cm.py @@ -20,32 +20,56 @@ def _load_npy(relative_path): return np.load(path) -def _apply_colormap(arr, colormap): +def _apply_colormap(arr, colormap, mode="RGBA", dtype=np.float64): if not isinstance(arr, np.ndarray): raise TypeError(f"Parameter must be a numpy array, not {type(arr)}") if arr.shape[-1] != 2: raise ValueError(f"Last dimension of array shape {arr.shape} must be 2.") if np.min(arr) < 0 or np.max(arr) > 1: raise ValueError("Array values must be in the range [0:1].") + if mode not in ["RGB", "RGBA"]: + raise ValueError(f"Mode must be either 'RGB' or 'RGBA', not '{mode}'.") + if dtype not in [np.float64, np.uint8]: + raise ValueError( + f"Dtype must be either 'np.float64' for float values between 0 and 1, " + f"or 'np.uint8' for integer between 0 and 255, not '{dtype}'." + ) arr = np.round(arr * N).astype(np.int32) colormap_data = _load_npy(colormap + ".npy") - return colormap_data[arr[..., 0], arr[..., 1]] + out = colormap_data[arr[..., 0], arr[..., 1]] + # Add an alpha channel filled with ones + if mode == "RGBA": + alpha = np.ones(arr.shape[:-1] + (1,), dtype=np.uint8) * 255 + out = np.concatenate((out, alpha), axis=-1) -def pinwheel(arr): + # Convert back the values in the [0:1] range like `matplotlib.cm`` functions + if dtype == np.float64: + out = out / 255 + + return out + + +def pinwheel(arr, **kwargs): """Converts 2D coordinates into 3D RGB values. Args: - arr: Numpy array whose last dimension is 2 and whose values belong to [0,1] + arr: Numpy array whose last dimension is 2 and whose values belong to [0,1]. + **kwargs: + mode: 'RGB' or 'RGBA'. Default to 'RGBA'. + dtype: np.uint8 or np.float64. Default to np.float64. """ - return _apply_colormap(arr, "pinwheel") + return _apply_colormap(arr, "pinwheel", **kwargs) -def cyclic_pinwheel(arr): +def cyclic_pinwheel(arr, **kwargs): """Converts 2D coordinates into 3D RGB values which are the same at each XY border. Args: - arr: Numpy array whose last dimension is 2 and whose values belong to [0,1] + arr: Numpy array whose last dimension is 2 and whose values belong to [0,1]. + **kwargs: + mode: 'RGB' or 'RGBA'. Default to 'RGBA'. + dtype: np.uint8 or np.float64. Default to np.float64. """ - return _apply_colormap(arr, "cyclic_pinwheel") + return _apply_colormap(arr, "cyclic_pinwheel", **kwargs) diff --git a/tests/test_cm.py b/tests/test_cm.py index 22b9900..cab5463 100644 --- a/tests/test_cm.py +++ b/tests/test_cm.py @@ -8,18 +8,53 @@ def test_pinwheel_values(): """Checks if colormap2d is outputing the expected values.""" v = np.array([[0.91270668, 0.60020465], [0.51569033, 0.79642031]]) - rgb = colormap2d.pinwheel(v) + rgb = colormap2d.pinwheel(v, mode="RGB", dtype=np.uint8) np.testing.assert_allclose( rgb, np.array([[166, 179, 50], [50, 66, 94]], dtype=np.uint8) ) + rgba = colormap2d.pinwheel(v, mode="RGBA", dtype=np.uint8) + np.testing.assert_allclose( + rgba, np.array([[166, 179, 50, 255], [50, 66, 94, 255]], dtype=np.uint8) + ) + + rgb = colormap2d.pinwheel(v, mode="RGB", dtype=np.float64) + np.testing.assert_allclose( + rgb, + np.array( + [ + [0.65098039, 0.70196078, 0.19607843], + [0.19607843, 0.25882353, 0.36862745], + ], + dtype=np.float64, + ), + ) + + rgba = colormap2d.pinwheel(v, mode="RGBA", dtype=np.float64) + np.testing.assert_allclose( + rgba, + np.array( + [ + [0.65098039, 0.70196078, 0.19607843, 1.0], + [0.19607843, 0.25882353, 0.36862745, 1.0], + ], + dtype=np.float64, + ), + ) + def test_pinwheel_multidim(): """Checks if any shape is any provided the last dimension is 2.""" v = np.random.rand(5, 5, 5, 2) - rgb = colormap2d.pinwheel(v) + rgba = colormap2d.pinwheel(v) + assert rgba.shape == v.shape[:-1] + (4,) + + rgba = colormap2d.pinwheel(v, mode="RGBA") + assert rgba.shape == v.shape[:-1] + (4,) + + rgb = colormap2d.pinwheel(v, mode="RGB") assert rgb.shape == v.shape[:-1] + (3,)