From 1de068b1c12f9443b7e04e4afb726a416bd820f9 Mon Sep 17 00:00:00 2001 From: Tom White Date: Sun, 25 Aug 2024 15:59:04 +0100 Subject: [PATCH] Use sgkit.distarray for gwas_linear_regression --- .github/workflows/cubed.yml | 2 +- sgkit/distarray.py | 3 +++ sgkit/stats/association.py | 17 ++++++++++------- sgkit/stats/utils.py | 2 +- sgkit/tests/test_association.py | 4 ++-- 5 files changed, 17 insertions(+), 11 deletions(-) diff --git a/.github/workflows/cubed.yml b/.github/workflows/cubed.yml index bdcf3f242..687b8d519 100644 --- a/.github/workflows/cubed.yml +++ b/.github/workflows/cubed.yml @@ -30,4 +30,4 @@ jobs: - name: Test with pytest run: | - pytest -v sgkit/tests/test_{aggregation,hwe}.py -k 'test_count_call_alleles or test_hwep or test_sample_stats or (test_count_variant_alleles and not test_count_variant_alleles__chunked[call_genotype]) or (test_variant_stats and not test_variant_stats__chunks[chunks2-False])' --use-cubed \ No newline at end of file + pytest -v sgkit/tests/test_{aggregation,association,hwe}.py -k 'test_count_call_alleles or test_gwas_linear_regression or test_hwep or test_sample_stats or (test_count_variant_alleles and not test_count_variant_alleles__chunked[call_genotype]) or (test_variant_stats and not test_variant_stats__chunks[chunks2-False])' --use-cubed \ No newline at end of file diff --git a/sgkit/distarray.py b/sgkit/distarray.py index e91059ba4..35ba5d5a0 100644 --- a/sgkit/distarray.py +++ b/sgkit/distarray.py @@ -14,3 +14,6 @@ def astype(x, dtype, /, *, copy=True): # pragma: no cover if not copy and dtype == x.dtype: return x return x.astype(dtype=dtype, copy=copy) + + # dask doesn't have concat required by the array API + concat = concatenate # noqa: F405 diff --git a/sgkit/stats/association.py b/sgkit/stats/association.py index 82cd47cdb..0825223ed 100644 --- a/sgkit/stats/association.py +++ b/sgkit/stats/association.py @@ -1,11 +1,13 @@ from dataclasses import dataclass from typing import Hashable, Optional, Sequence, Union -import dask.array as da import numpy as np -from dask.array import Array, stats +from scipy import stats from xarray import Dataset, concat +import sgkit.distarray as da +from sgkit.distarray import Array + from .. import variables from ..typing import ArrayLike from ..utils import conditional_merge_datasets, create_dataset @@ -78,7 +80,7 @@ def linear_regression( # from projection require no extra terms in variance # estimate for loop covariates (columns of G), which is # only true when an intercept is present. - XLPS = (XLP**2).sum(axis=0, keepdims=True).T + XLPS = da.sum(XLP**2, axis=0, keepdims=True).T assert XLPS.shape == (n_loop_covar, 1) B = (XLP.T @ YP) / XLPS assert B.shape == (n_loop_covar, n_outcome) @@ -86,10 +88,10 @@ def linear_regression( # Compute residuals for each loop covariate and outcome separately YR = YP[:, np.newaxis, :] - XLP[..., np.newaxis] * B[np.newaxis, ...] assert YR.shape == (n_obs, n_loop_covar, n_outcome) - RSS = (YR**2).sum(axis=0) + RSS = da.sum(YR**2, axis=0) assert RSS.shape == (n_loop_covar, n_outcome) # Get t-statistics for coefficient estimates - T = B / np.sqrt(RSS / dof / XLPS) + T = B / da.sqrt(RSS / dof / XLPS) assert T.shape == (n_loop_covar, n_outcome) # Match to p-values @@ -102,7 +104,8 @@ def linear_regression( dtype="float64", ) assert P.shape == (n_loop_covar, n_outcome) - P = np.asarray(P, like=T) + if hasattr(T, "__array_function__"): + P = np.asarray(P, like=T) return LinearRegressionResult(beta=B, t_value=T, p_value=P) @@ -216,7 +219,7 @@ def gwas_linear_regression( else: X = da.asarray(concat_2d(ds[list(covariates)], dims=("samples", "covariates"))) if add_intercept: - X = da.concatenate([da.ones((X.shape[0], 1), dtype=X.dtype), X], axis=1) + X = da.concat([da.ones((X.shape[0], 1), dtype=X.dtype), X], axis=1) # Note: dask qr decomp (used by lstsq) requires no chunking in one # dimension, and because dim 0 will be far greater than the number # of covariates for the large majority of use cases, chunking diff --git a/sgkit/stats/utils.py b/sgkit/stats/utils.py index 0b736efbc..8320a75c5 100644 --- a/sgkit/stats/utils.py +++ b/sgkit/stats/utils.py @@ -104,7 +104,7 @@ def assert_array_shape(x: ArrayLike, *args: int) -> None: def map_blocks_asnumpy(x: Array) -> Array: - if da.utils.is_cupy_type(x._meta): # pragma: no cover + if hasattr(x, "_meta") and da.utils.is_cupy_type(x._meta): # pragma: no cover import cupy as cp # type: ignore[import] x = x.map_blocks(cp.asnumpy) diff --git a/sgkit/tests/test_association.py b/sgkit/tests/test_association.py index a9116fd58..f73c34ccb 100644 --- a/sgkit/tests/test_association.py +++ b/sgkit/tests/test_association.py @@ -2,7 +2,6 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Sequence, Tuple -import dask.array as da import numpy as np import pandas as pd import pytest @@ -11,6 +10,7 @@ from pandas import DataFrame from xarray import Dataset +import sgkit.distarray as da from sgkit.stats.association import ( gwas_linear_regression, linear_regression, @@ -263,7 +263,7 @@ def test_gwas_linear_regression__scalar_vars(ds: xr.Dataset) -> None: res_list = gwas_linear_regression( ds, dosage="dosage", covariates=["covar_0"], traits=["trait_0"] ) - xr.testing.assert_equal(res_scalar, res_list) + xr.testing.assert_allclose(res_scalar, res_list) def test_gwas_linear_regression__raise_on_no_intercept_and_empty_covariates():