Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP, add pre/post copy and pre/post update tasks #1511

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions copier/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,11 +220,27 @@ def _check_unsafe(self, mode: Literal["copy", "update"]) -> None:
features.add("jinja_extensions")
if self.template.tasks:
features.add("tasks")
if self.template.pre_copy:
features.add("pre_copy")
if self.template.pre_update:
features.add("pre_update")
if self.template.post_copy:
features.add("post_copy")
if self.template.post_update:
features.add("post_update")
if mode == "update" and self.subproject.template:
if self.subproject.template.jinja_extensions:
features.add("jinja_extensions")
if self.subproject.template.tasks:
features.add("tasks")
if self.subproject.template.pre_copy:
features.add("pre_copy")
if self.subproject.template.pre_update:
features.add("pre_update")
if self.subproject.template.post_copy:
features.add("post_copy")
if self.subproject.template.post_update:
features.add("post_update")
for stage in get_args(Literal["before", "after"]):
if self.template.migration_tasks(stage, self.subproject.template):
features.add("migrations")
Expand Down Expand Up @@ -746,6 +762,7 @@ def run_copy(self) -> None:
was_existing = self.subproject.local_abspath.exists()
src_abspath = self.template_copy_root
try:
self._execute_tasks(self.template.pre_copy)
if not self.quiet:
# TODO Unify printing tools
print(
Expand All @@ -757,6 +774,7 @@ def run_copy(self) -> None:
# TODO Unify printing tools
print("") # padding space
self._execute_tasks(self.template.tasks)
self._execute_tasks(self.template.post_copy)
except Exception:
if not was_existing and self.cleanup_on_error:
rmtree(self.subproject.local_abspath)
Expand Down Expand Up @@ -818,12 +836,14 @@ def run_update(self) -> None:
# asking for confirmation
raise UserMessageError("Enable overwrite to update a subproject.")
self._print_message(self.template.message_before_update)
self._execute_tasks(self.template.pre_update)
if not self.quiet:
# TODO Unify printing tools
print(
f"Updating to template version {self.template.version}", file=sys.stderr
)
self._apply_update()
self._execute_tasks(self.template.post_update)
self._print_message(self.template.message_after_update)

def _apply_update(self):
Expand Down
44 changes: 44 additions & 0 deletions copier/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,50 @@ def tasks(self) -> Sequence[Task]:
for cmd in self.config_data.get("tasks", [])
]

@cached_property
def pre_copy(self) -> Sequence[Task]:
"""Get pre-copy tasks defined in the template.

See [pre_copy][].
"""
return [
Task(cmd=cmd, extra_env={"STAGE": "pre_copy"})
for cmd in self.config_data.get("pre_copy", [])
]

@cached_property
def post_copy(self) -> Sequence[Task]:
"""Get post-copy tasks defined in the template.

See [post_copy][].
"""
return [
Task(cmd=cmd, extra_env={"STAGE": "post_copy"})
for cmd in self.config_data.get("post_copy", [])
]

@cached_property
def pre_update(self) -> Sequence[Task]:
"""Get pre-update tasks defined in the template.

See [pre_update][].
"""
return [
Task(cmd=cmd, extra_env={"STAGE": "pre_update"})
for cmd in self.config_data.get("pre_update", [])
]

@cached_property
def post_update(self) -> Sequence[Task]:
"""Get post-update tasks defined in the template.

See [post_update][].
"""
return [
Task(cmd=cmd, extra_env={"STAGE": "post_update"})
for cmd in self.config_data.get("post_update", [])
]

@cached_property
def templates_suffix(self) -> str:
"""Get the suffix defined for templates.
Expand Down
109 changes: 109 additions & 0 deletions tests/test_pre_copy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
from pathlib import Path
from typing import Literal, Optional

import pytest

import copier

from .helpers import BRACKET_ENVOPS_JSON, SUFFIX_TMPL, build_file_tree


@pytest.fixture(scope="module")
def template_path(tmp_path_factory: pytest.TempPathFactory) -> str:
root = tmp_path_factory.mktemp("demo_pre_copy")
build_file_tree(
{
(root / "copier.yaml"): (
f"""\
_templates_suffix: {SUFFIX_TMPL}
_envops: {BRACKET_ENVOPS_JSON}

other_file: bye

# This tests two things:
# 1. That the tasks are being executed in the destination folder; and
# 2. That the tasks are being executed in order, one after another
_pre_copy:
- mkdir hello
- cd hello && touch world
- touch [[ other_file ]]
- ["[[ _copier_python ]]", "-c", "open('pyfile', 'w').close()"]
"""
)
}
)
return str(root)


def test_render_tasks(template_path: str, tmp_path: Path) -> None:
copier.run_copy(template_path, tmp_path, data={"other_file": "custom"}, unsafe=True)
assert (tmp_path / "custom").is_file()


def test_copy_tasks(template_path: str, tmp_path: Path) -> None:
copier.run_copy(
template_path, tmp_path, quiet=True, defaults=True, overwrite=True, unsafe=True
)
assert (tmp_path / "hello").exists()
assert (tmp_path / "hello").is_dir()
assert (tmp_path / "hello" / "world").exists()
assert (tmp_path / "bye").is_file()
assert (tmp_path / "pyfile").is_file()


def test_pretend_mode(tmp_path_factory: pytest.TempPathFactory) -> None:
src, dst = map(tmp_path_factory.mktemp, ("src", "dst"))
build_file_tree(
{
(src / "copier.yml"): (
"""
_pre_copy:
- touch created-by-pre-copy.txt
"""
)
}
)
copier.run_copy(str(src), dst, pretend=True, unsafe=True)
assert not (dst / "created-by-pre-copy.txt").exists()


@pytest.mark.parametrize(
"os, filename",
[
("linux", "linux.txt"),
("macos", "macos.txt"),
("windows", "windows.txt"),
(None, "unsupported.txt"),
],
)
def test_os_specific_tasks(
tmp_path_factory: pytest.TempPathFactory,
monkeypatch: pytest.MonkeyPatch,
os: Optional[Literal["linux", "macos", "windows"]],
filename: str,
) -> None:
src, dst = map(tmp_path_factory.mktemp, ("src", "dst"))
build_file_tree(
{
(src / "copier.yml"): (
"""\
_pre_copy:
- >-
{% if _copier_conf.os == 'linux' %}
touch linux.txt
{% elif _copier_conf.os == 'macos' %}
touch macos.txt
{% elif _copier_conf.os == 'windows' %}
touch windows.txt
{% elif _copier_conf.os is none %}
touch unsupported.txt
{% else %}
touch never.txt
{% endif %}
"""
)
}
)
monkeypatch.setattr("copier.main.OS", os)
copier.run_copy(str(src), dst, unsafe=True)
assert (dst / filename).exists()