Skip to content

Commit

Permalink
BUG make sure to exclude stuff in conda envs
Browse files Browse the repository at this point in the history
  • Loading branch information
beckermr committed May 31, 2024
1 parent 9de20cc commit 19d20aa
Showing 1 changed file with 36 additions and 16 deletions.
52 changes: 36 additions & 16 deletions conda_forge_conda_plugins/hooks.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,25 @@
import os
import shutil
import subprocess

from conda.base import context
from conda.core.envs_manager import list_all_known_prefixes
from conda.plugins import CondaVirtualPackage, hookimpl


def _find_command_not_in_a_conda_prefix(cmd):
pth = shutil.which(cmd)

if (
pth is None
or pth.startswith(context.root_prefix)
or pth.startswith(tuple(list_all_known_prefixes()))
):
return None

return pth


@hookimpl
def conda_virtual_packages():
# openmpi virtual package
Expand All @@ -12,14 +28,16 @@ def conda_virtual_packages():
openmpi_version = os.environ["CONDA_OVERRIDE_OPENMPI"]
else:
try:
ret = subprocess.run(
["ompi_info", "--parsable"], capture_output=True, text=True
)
if ret.returncode == 0:
for line in ret.stdout.splitlines():
if line.startswith("ompi:version:full:"):
openmpi_version = line.strip().split(":")[3].strip()
break
ompi_info_pth = _find_command_not_in_a_conda_prefix("ompi_info")
if ompi_info_pth is not None:
ret = subprocess.run(
[ompi_info_pth, "--parsable"], capture_output=True, text=True
)
if ret.returncode == 0:
for line in ret.stdout.splitlines():
if line.startswith("ompi:version:full:"):
openmpi_version = line.strip().split(":")[3].strip()
break
except Exception:
pass

Expand All @@ -36,14 +54,16 @@ def conda_virtual_packages():
mpich_version = os.environ["CONDA_OVERRIDE_MPICH"]
else:
try:
ret = subprocess.run(
["mpichversion", "--version"], capture_output=True, text=True
)
if ret.returncode == 0:
for line in ret.stdout.splitlines():
if line.startswith("MPICH Version:"):
mpich_version = line.strip().split(":")[1].strip()
break
mpichversion_pth = _find_command_not_in_a_conda_prefix("mpichversion")
if mpichversion_pth is not None:
ret = subprocess.run(
[mpichversion_pth, "--version"], capture_output=True, text=True
)
if ret.returncode == 0:
for line in ret.stdout.splitlines():
if line.startswith("MPICH Version:"):
mpich_version = line.strip().split(":")[1].strip()
break
except Exception:
pass

Expand Down

0 comments on commit 19d20aa

Please sign in to comment.