Argo Pipeline Integration (LineaLabs#855)
* opencv package annotation

* skip tests if cv2 module not installed

* remove prettify from test_opencv.py

* Add dask package annotation

* update dask annotations and tests

* remove to_feather/records

* argo pipeline integration

* update the argo dag template and tasks configuration to suit argo

* fix extra import argo_writer

* change argo dag flavors, adapt to new task graph and output variables

* change the post call blocks for argo and kubeflow

* remove tests for argo

* move all kubernetes operations to _dag file for argo

* add test snapshots

* remove argo specific templates from tasks

* update for runner

* pickle communication enabled with folder creation

* update snapshots for airflow

* remove stepperartifact flavor

* tests failing again

* update runner

* remove step per artifact test from ambr
lazargugleta authored Dec 23, 2022
1 parent e738637 commit b750c5c
Showing 15 changed files with 671 additions and 136 deletions.
7 changes: 5 additions & 2 deletions lineapy/api/api.py
@@ -389,7 +389,9 @@ def to_pipeline(
Names of artifacts to be included in the pipeline.
framework: str
"AIRFLOW" or "SCRIPT". Defaults to "SCRIPT" if not specified.
Name of the framework to be used.
Defined by the enum PipelineType in lineapy/data/types.py.
Defaults to "SCRIPT" if not specified.
pipeline_name: Optional[str]
Name of the pipeline.
@@ -432,7 +434,8 @@ def to_pipeline(
A dictionary of parameters to configure DAG file to be generated.
Not applicable for "SCRIPT" framework as it does not generate a separate
DAG file. For "AIRFLOW" framework, Airflow-native config params such as
"retries" and "schedule_interval" can be passed in.
"retries" and "schedule_interval" can be passed in. For "ARGO" framework,
Argo-native config params such as "namespace" and "service_account_name".
Returns
-------
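
For context, a minimal usage sketch of the updated API (the artifact names, pipeline name, and dag_config values below are hypothetical, and the call assumes the function is exposed as lineapy.to_pipeline with the parameters documented above):

import lineapy

# Hypothetical artifact and pipeline names; the dag_config keys mirror the
# Argo-native options described above and the ARGODAGConfig fields added in
# this commit.
lineapy.to_pipeline(
    artifacts=["cleaned_data", "trained_model"],
    framework="ARGO",
    pipeline_name="demo_pipeline",
    dag_config={
        "namespace": "argo",
        "service_account": "argo",
        "image": "argo_pipeline:latest",
    },
)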
4 changes: 2 additions & 2 deletions lineapy/data/types.py
@@ -508,14 +508,14 @@ class PipelineType(Enum):
- SCRIPT : the pipeline is wrapped as a python script
- AIRFLOW : the pipeline is wrapped as an airflow dag
- DVC : the pipeline is wrapped as a DVC
- ARGO: the pipeline is wrapped as an Argo workflow dag
- KUBEFLOW : the pipeline is defined using Kubeflow's python SDK
"""

SCRIPT = 1
AIRFLOW = 2
DVC = 3
# ARGO = 4
ARGO = 4
KUBEFLOW = 5


192 changes: 192 additions & 0 deletions lineapy/plugins/argo_pipeline_writer.py
@@ -0,0 +1,192 @@
import logging
import os
from enum import Enum
from typing import Any, Dict, List, Tuple

from typing_extensions import TypedDict

from lineapy.plugins.base_pipeline_writer import BasePipelineWriter
from lineapy.plugins.task import (
DagTaskBreakdown,
TaskDefinition,
TaskSerializer,
render_task_io_serialize_blocks,
)
from lineapy.plugins.taskgen import (
get_noop_setup_task_definition,
get_noop_teardown_task_definition,
get_task_graph,
)
from lineapy.plugins.utils import load_plugin_template
from lineapy.utils.logging_config import configure_logging
from lineapy.utils.utils import prettify

logger = logging.getLogger(__name__)
configure_logging()


class ARGODagFlavor(Enum):
StepPerSession = 1


ARGODAGConfig = TypedDict(
"ARGODAGConfig",
{
"namespace": str,
"host": str,
"verify_ssl": str,
"image": str,
"image_pull_policy": str,
"token": str,
"workflow_name": str,
"service_account": int,
"kube_config": str,
"dag_flavor": str,
},
total=False,
)


class ARGOPipelineWriter(BasePipelineWriter):
"""
Class for pipeline file writer. Corresponds to "ARGO" framework.
"""

@property
def docker_template_name(self) -> str:
return "argo_dockerfile.jinja"

def _write_dag(self) -> None:

# Check if the given DAG flavor is a supported/valid one
dag_flavor_name = self.dag_config.get("dag_flavor", "StepPerSession")
try:
    dag_flavor = ARGODagFlavor[dag_flavor_name]
except KeyError:
    raise ValueError(f'"{dag_flavor_name}" is an invalid ARGO dag flavor.')

# Construct DAG text for the given flavor
full_code = self._write_operators(dag_flavor)

# Write out file
file = self.output_dir / f"{self.pipeline_name}_dag.py"
file.write_text(prettify(full_code))
logger.info(f"Generated DAG file: {file}")

def _write_operators(
self,
dag_flavor: ARGODagFlavor,
) -> str:

DAG_TEMPLATE = load_plugin_template("argo_dag.jinja")

if dag_flavor == ARGODagFlavor.StepPerSession:
task_breakdown = DagTaskBreakdown.TaskPerSession

# Get task definitions based on dag_flavor
task_defs, task_graph = get_task_graph(
self.artifact_collection,
pipeline_name=self.pipeline_name,
task_breakdown=task_breakdown,
)

task_defs["setup"] = get_noop_setup_task_definition(self.pipeline_name)
task_defs["teardown"] = get_noop_teardown_task_definition(
self.pipeline_name
)
# insert setup and teardown so that setup runs first and teardown runs last
task_graph.insert_setup_task("setup")
task_graph.insert_teardown_task("teardown")

task_names = list(task_defs.keys())

task_defs = {tn: task_defs[tn] for tn in task_names}

(
rendered_task_defs,
task_loading_blocks,
) = self.get_rendered_task_definitions(task_defs)

# Handle dependencies
task_dependencies = reversed(
[f"{task0} >> {task1}" for task0, task1 in task_graph.graph.edges]
)

# Get DAG parameters for an ARGO pipeline
input_parameters_dict: Dict[str, Any] = {}
for parameter_name, input_spec in super().get_pipeline_args().items():
input_parameters_dict[parameter_name] = input_spec.value

full_code = DAG_TEMPLATE.render(
DAG_NAME=self.pipeline_name,
MODULE_NAME=self.pipeline_name + "_module",
NAMESPACE=self.dag_config.get("namespace", "argo"),
HOST=self.dag_config.get("host", "https://localhost:2746"),
VERIFY_SSL=self.dag_config.get("verify_ssl", "False"),
WORFLOW_NAME=self.dag_config.get(
"workflow_name", self.pipeline_name.replace("_", "-")
),
IMAGE=self.dag_config.get("image", "argo_pipeline:latest"),
IMAGE_PULL_POLICY=self.dag_config.get(
"image_pull_policy", "Never"
),
SERVICE_ACCOUNT=self.dag_config.get("service_account", "argo"),
KUBE_CONFIG=self.dag_config.get(
"kube_config", os.path.expanduser("~/.kube/config")
),
TOKEN=self.dag_config.get("token", "None"),
dag_params=input_parameters_dict,
task_definitions=rendered_task_defs,
tasks=task_defs,
task_loading_blocks=task_loading_blocks,
task_dependencies=task_dependencies,
)

return full_code

def get_rendered_task_definitions(
self,
task_defs: Dict[str, TaskDefinition],
) -> Tuple[List[str], Dict[str, str]]:
"""
Returns rendered tasks for the pipeline tasks along with a dictionary to look up
previous task outputs.
The returned dictionary is used by the DAG to connect the right input files to
output files for inter-task communication.
This method originates from:
https://github.com/argoproj-labs/hera-workflows/blob/4efddc85bfce62455db758f4be47e3acc0342b4f/examples/k8s_sa.py#L12
"""
TASK_FUNCTION_TEMPLATE = load_plugin_template(
"task/task_function.jinja"
)
rendered_task_defs: List[str] = []
task_loading_blocks: Dict[str, str] = {}

for task_name, task_def in task_defs.items():
loading_blocks, dumping_blocks = render_task_io_serialize_blocks(
task_def, TaskSerializer.LocalPickle
)

input_vars = task_def.user_input_variables

# this task will output variables to a file that other tasks can access

for return_variable in task_def.return_vars:
task_loading_blocks[return_variable] = return_variable

task_def_rendered = TASK_FUNCTION_TEMPLATE.render(
MODULE_NAME=self.pipeline_name + "_module",
function_name=task_name,
user_input_variables=", ".join(input_vars),
typing_blocks=task_def.typing_blocks,
loading_blocks=loading_blocks,
pre_call_block=task_def.pre_call_block or "",
call_block=task_def.call_block,
post_call_block=task_def.post_call_block or "",
dumping_blocks=dumping_blocks,
include_imports_locally=True,
)
rendered_task_defs.append(task_def_rendered)

return rendered_task_defs, task_loading_blocks
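
As a reference, the fallback values read via self.dag_config.get(...) in _write_operators above amount to the default configuration sketched below; every key can be overridden through the dag_config argument of to_pipeline:

# Defaults assumed by ARGOPipelineWriter when a key is missing from dag_config
# (taken from the .get(...) fallbacks above); the workflow_name placeholder is
# the pipeline name with underscores replaced by dashes.
default_dag_config = {
    "dag_flavor": "StepPerSession",
    "namespace": "argo",
    "host": "https://localhost:2746",
    "verify_ssl": "False",
    "workflow_name": "<pipeline-name-with-dashes>",
    "image": "argo_pipeline:latest",
    "image_pull_policy": "Never",
    "service_account": "argo",
    "kube_config": "~/.kube/config",  # expanded with os.path.expanduser
    "token": "None",
}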
93 changes: 93 additions & 0 deletions lineapy/plugins/jinja_templates/argo_dag.jinja
@@ -0,0 +1,93 @@
from hera import Artifact, ImagePullPolicy, set_global_task_image
from hera.task import Task
from hera.workflow import Workflow
from hera.workflow_service import WorkflowService

from kubernetes import client, config
from typing import Optional
import base64, errno, os

def get_sa_token(
service_account: str,
namespace: str = "argo",
config_file: Optional[str] = None,
):
"""
Configures the kubernetes client and returns the service account token for the
specified service account in the specified namespace.
This is used in the case where a local kubeconfig exists.
"""
if config_file is not None and not os.path.isfile(config_file):
raise FileNotFoundError(
errno.ENOENT, os.strerror(errno.ENOENT), config_file
)

config.load_kube_config(config_file=config_file)
v1 = client.CoreV1Api()
if (
v1.read_namespaced_service_account(
service_account, namespace
).secrets
is None
):
print("No secrets found in namespace: %s" % namespace)
return "None"

secret_name = (
v1.read_namespaced_service_account(service_account, namespace)
.secrets[0]
.name
)

sec = v1.read_namespaced_secret(secret_name, namespace).data
return base64.b64decode(sec["token"]).decode()

{% for task_def in task_definitions %}
{{ task_def }}
{% endfor %}

ws = WorkflowService(
host = "{{ HOST }}",
verify_ssl = {{ VERIFY_SSL }},
token = get_sa_token("{{SERVICE_ACCOUNT}}", "{{NAMESPACE}}", "{{KUBE_CONFIG}}"),
namespace = "{{ NAMESPACE }}",
)

with Workflow("{{ WORFLOW_NAME }}", service = ws) as w:

set_global_task_image("{{ IMAGE }}")

{% for task_name, task_def in tasks.items() %}
{{ task_name }} = Task("{{- task_name | replace("_", "-") -}}",
task_{{ task_name }},
{% if task_def.user_input_variables|length > 0 and task_name not in ["setup", "teardown"] %}
[
{
{% for var in task_def.user_input_variables %}
"{{ var }}": "{{ dag_params[var] }}",
{% endfor %}
}
],
{% endif %}
image_pull_policy = ImagePullPolicy.{{ IMAGE_PULL_POLICY }},
{% if task_def.return_vars|length > 0 and task_name not in ["setup", "teardown"] %}
outputs = [
{% for output in task_def.return_vars %}
Artifact(
"{{ output }}",
"/tmp/{{- WORFLOW_NAME | replace("-", "_") -}}/variable_{{output}}.pickle",
),
{% endfor %}
],
{% endif %}
)
{% endfor %}


{% if task_dependencies is not none %}
{% for TASK_DEPENDENCIES in task_dependencies %}
{{TASK_DEPENDENCIES}}
{% endfor %}
{% endif %}

w.create()
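
For a hypothetical pipeline with one session task, the tail of the generated DAG file renders roughly as the sketch below (names are invented): each dependency string emitted by the writer relies on Hera's >> operator between Task objects, and w.create() submits the workflow to the server configured in ws.

# Illustrative rendering only; "demo-pipeline" and its single session task are hypothetical.
with Workflow("demo-pipeline", service=ws) as w:

    set_global_task_image("argo_pipeline:latest")

    setup = Task("setup", task_setup, image_pull_policy=ImagePullPolicy.Never)
    run_session = Task(
        "run-session",
        task_run_session,
        image_pull_policy=ImagePullPolicy.Never,
        outputs=[Artifact("result", "/tmp/demo_pipeline/variable_result.pickle")],
    )
    teardown = Task("teardown", task_teardown, image_pull_policy=ImagePullPolicy.Never)

# Dependency strings like "setup >> run_session" use Hera's Task.__rshift__
setup >> run_session
run_session >> teardown

w.create()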
23 changes: 23 additions & 0 deletions lineapy/plugins/jinja_templates/argo_dockerfile.jinja
@@ -0,0 +1,23 @@
FROM python:{{ python_version }}

RUN mkdir /tmp/installers
WORKDIR /tmp/installers

# Copy all the requirements to run current DAG
COPY ./{{ pipeline_name }}_requirements.txt ./

# Install required libs
RUN pip install -r ./{{ pipeline_name }}_requirements.txt

WORKDIR /opt/argo/dags

# Install git and argo
RUN apt update
RUN apt install -y git
RUN pip install argo-workflows
RUN pip install hera-workflows

COPY ./{{ pipeline_name }}_module.py ./
COPY ./{{ pipeline_name }}_dag.py ./

ENTRYPOINT [ "argo", "repro", "run_all_sessions"]
@@ -1 +1,2 @@
if not pathlib.Path('/tmp').joinpath('{{pipeline_name}}').exists(): pathlib.Path('/tmp').joinpath('{{pipeline_name}}').mkdir()
pickle.dump({{return_variable}}, open('/tmp/{{pipeline_name}}/variable_{{return_variable}}.pickle','wb'))
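
Rendered for a hypothetical pipeline named demo_pipeline with a return variable result, the template above produces roughly the following two lines; together with the folder creation this is what enables pickle-based inter-task communication:

# Illustrative rendering of the serializer template; names are hypothetical.
if not pathlib.Path('/tmp').joinpath('demo_pipeline').exists(): pathlib.Path('/tmp').joinpath('demo_pipeline').mkdir()
pickle.dump(result, open('/tmp/demo_pipeline/variable_result.pickle','wb'))

By default the workflow name is the pipeline name with dashes, so the Artifact path declared in the generated DAG (/tmp/<workflow name with underscores>/variable_<var>.pickle) points at the same file.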
2 changes: 1 addition & 1 deletion lineapy/plugins/jinja_templates/task/task_function.jinja
@@ -1,7 +1,7 @@
def task_{{function_name}}({{user_input_variables}}):
{%- if include_imports_locally %}
import {{ MODULE_NAME }}
import pickle
import pickle, pathlib
{%- endif %}
{% for typing_block in typing_blocks %}
{{typing_block | indent(4, True) }}
2 changes: 1 addition & 1 deletion lineapy/plugins/kubeflow_pipeline_writer.py
@@ -175,7 +175,7 @@ def get_rendered_task_definitions(
loading_blocks=loading_blocks,
pre_call_block="",
call_block=task_def.call_block,
post_call_block="",
post_call_block=task_def.post_call_block,
dumping_blocks=dumping_blocks,
include_imports_locally=True,
)
3 changes: 3 additions & 0 deletions lineapy/plugins/pipeline_writer_factory.py
@@ -1,5 +1,6 @@
from lineapy.data.types import PipelineType
from lineapy.plugins.airflow_pipeline_writer import AirflowPipelineWriter
from lineapy.plugins.argo_pipeline_writer import ARGOPipelineWriter
from lineapy.plugins.base_pipeline_writer import BasePipelineWriter
from lineapy.plugins.dvc_pipeline_writer import DVCPipelineWriter
from lineapy.plugins.kubeflow_pipeline_writer import KubeflowPipelineWriter
@@ -17,6 +18,8 @@ def get(
return AirflowPipelineWriter(*args, **kwargs)
elif pipeline_type == PipelineType.DVC:
return DVCPipelineWriter(*args, **kwargs)
elif pipeline_type == PipelineType.ARGO:
return ARGOPipelineWriter(*args, **kwargs)
elif pipeline_type == PipelineType.KUBEFLOW:
return KubeflowPipelineWriter(*args, **kwargs)
else:
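
A minimal sketch of selecting the new writer through the factory; it assumes the factory class is named PipelineWriterFactory with the get method shown above, and the placeholder kwargs stand in for whatever constructor arguments BasePipelineWriter expects (artifact collection, pipeline name, output directory, dag_config, ...):

from lineapy.data.types import PipelineType
from lineapy.plugins.pipeline_writer_factory import PipelineWriterFactory

# Hypothetical kwargs; everything besides pipeline_type is forwarded
# unchanged to ARGOPipelineWriter(*args, **kwargs).
writer_kwargs: dict = {}
writer = PipelineWriterFactory.get(pipeline_type=PipelineType.ARGO, **writer_kwargs)
# The returned ARGOPipelineWriter renders <pipeline_name>_dag.py via the
# _write_dag() method shown earlier in this commit.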
1 change: 0 additions & 1 deletion lineapy/plugins/task.py
@@ -212,7 +212,6 @@ def render_task_io_serialize_blocks(
DESERIALIZER_TEMPLATE = load_plugin_template(
"task/parameterizedpickle/task_parameterized_deser.jinja"
)

# Add more renderable task serializers here

for loaded_input_variable in taskdef.loaded_input_variables: