Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use sgkit.distarray for gwas_linear_regression #1262

Merged
merged 1 commit into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/cubed.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,4 @@ jobs:

- name: Test with pytest
run: |
pytest -v sgkit/tests/test_{aggregation,hwe}.py -k 'test_count_call_alleles or test_hwep or test_sample_stats or (test_count_variant_alleles and not test_count_variant_alleles__chunked[call_genotype]) or (test_variant_stats and not test_variant_stats__chunks[chunks2-False])' --use-cubed
pytest -v sgkit/tests/test_{aggregation,association,hwe}.py -k 'test_count_call_alleles or test_gwas_linear_regression or test_hwep or test_sample_stats or (test_count_variant_alleles and not test_count_variant_alleles__chunked[call_genotype]) or (test_variant_stats and not test_variant_stats__chunks[chunks2-False])' --use-cubed
3 changes: 3 additions & 0 deletions sgkit/distarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,6 @@ def astype(x, dtype, /, *, copy=True): # pragma: no cover
if not copy and dtype == x.dtype:
return x
return x.astype(dtype=dtype, copy=copy)

# dask doesn't have concat required by the array API
concat = concatenate # noqa: F405
17 changes: 10 additions & 7 deletions sgkit/stats/association.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from dataclasses import dataclass
from typing import Hashable, Optional, Sequence, Union

import dask.array as da
import numpy as np
from dask.array import Array, stats
from scipy import stats
from xarray import Dataset, concat

import sgkit.distarray as da
from sgkit.distarray import Array

from .. import variables
from ..typing import ArrayLike
from ..utils import conditional_merge_datasets, create_dataset
Expand Down Expand Up @@ -78,18 +80,18 @@ def linear_regression(
# from projection require no extra terms in variance
# estimate for loop covariates (columns of G), which is
# only true when an intercept is present.
XLPS = (XLP**2).sum(axis=0, keepdims=True).T
XLPS = da.sum(XLP**2, axis=0, keepdims=True).T
assert XLPS.shape == (n_loop_covar, 1)
B = (XLP.T @ YP) / XLPS
assert B.shape == (n_loop_covar, n_outcome)

# Compute residuals for each loop covariate and outcome separately
YR = YP[:, np.newaxis, :] - XLP[..., np.newaxis] * B[np.newaxis, ...]
assert YR.shape == (n_obs, n_loop_covar, n_outcome)
RSS = (YR**2).sum(axis=0)
RSS = da.sum(YR**2, axis=0)
assert RSS.shape == (n_loop_covar, n_outcome)
# Get t-statistics for coefficient estimates
T = B / np.sqrt(RSS / dof / XLPS)
T = B / da.sqrt(RSS / dof / XLPS)
assert T.shape == (n_loop_covar, n_outcome)

# Match to p-values
Expand All @@ -102,7 +104,8 @@ def linear_regression(
dtype="float64",
)
assert P.shape == (n_loop_covar, n_outcome)
P = np.asarray(P, like=T)
if hasattr(T, "__array_function__"):
P = np.asarray(P, like=T)
return LinearRegressionResult(beta=B, t_value=T, p_value=P)


Expand Down Expand Up @@ -216,7 +219,7 @@ def gwas_linear_regression(
else:
X = da.asarray(concat_2d(ds[list(covariates)], dims=("samples", "covariates")))
if add_intercept:
X = da.concatenate([da.ones((X.shape[0], 1), dtype=X.dtype), X], axis=1)
X = da.concat([da.ones((X.shape[0], 1), dtype=X.dtype), X], axis=1)
# Note: dask qr decomp (used by lstsq) requires no chunking in one
# dimension, and because dim 0 will be far greater than the number
# of covariates for the large majority of use cases, chunking
Expand Down
2 changes: 1 addition & 1 deletion sgkit/stats/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def assert_array_shape(x: ArrayLike, *args: int) -> None:


def map_blocks_asnumpy(x: Array) -> Array:
if da.utils.is_cupy_type(x._meta): # pragma: no cover
if hasattr(x, "_meta") and da.utils.is_cupy_type(x._meta): # pragma: no cover
import cupy as cp # type: ignore[import]

x = x.map_blocks(cp.asnumpy)
Expand Down
4 changes: 2 additions & 2 deletions sgkit/tests/test_association.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple

import dask.array as da
import numpy as np
import pandas as pd
import pytest
Expand All @@ -11,6 +10,7 @@
from pandas import DataFrame
from xarray import Dataset

import sgkit.distarray as da
from sgkit.stats.association import (
gwas_linear_regression,
linear_regression,
Expand Down Expand Up @@ -263,7 +263,7 @@ def test_gwas_linear_regression__scalar_vars(ds: xr.Dataset) -> None:
res_list = gwas_linear_regression(
ds, dosage="dosage", covariates=["covar_0"], traits=["trait_0"]
)
xr.testing.assert_equal(res_scalar, res_list)
xr.testing.assert_allclose(res_scalar, res_list)


def test_gwas_linear_regression__raise_on_no_intercept_and_empty_covariates():
Expand Down
Loading