Argo Pipeline Integration #855

Merged · 32 commits · Dec 23, 2022

Commits
3cbe261
opencv package annotation
lazargugleta Nov 2, 2022
5776880
Merge branch 'LineaLabs:main' into main
lazargugleta Nov 2, 2022
1f56d22
Merge branch 'LineaLabs:main' into main
lazargugleta Nov 3, 2022
7e021e5
skip tests if cv2 module not installed
lazargugleta Nov 3, 2022
18e2460
Merge remote-tracking branch 'origin/main' into annotating-packages
lazargugleta Nov 3, 2022
da52bf8
remove prettify from test_opencv.py
lazargugleta Nov 3, 2022
abac606
Merge branch 'LineaLabs:main' into main
lazargugleta Nov 3, 2022
67dbf56
Add dask package annotation
lazargugleta Nov 7, 2022
b033f11
Merge branch 'LineaLabs:main' into main
lazargugleta Nov 7, 2022
bc0980c
update dask annotations and tests
lazargugleta Nov 16, 2022
1f429eb
Merge remote-tracking branch 'upstream/main'
lazargugleta Dec 5, 2022
c45ec21
remove to_feather/records
lazargugleta Dec 5, 2022
cdabb36
Merge remote-tracking branch 'upstream/main'
lazargugleta Dec 7, 2022
f085ca4
argo pipeline integration
lazargugleta Dec 7, 2022
74f9242
Merge remote-tracking branch 'upstream/main'
lazargugleta Dec 16, 2022
9b54c50
update to argo dag template and tasks configuration to suite argo
lazargugleta Dec 16, 2022
c22d0f3
Merge remote-tracking branch 'upstream/main'
lazargugleta Dec 19, 2022
3840f47
fix extra import argo_writer
lazargugleta Dec 19, 2022
c1807f9
merge changes
lazargugleta Dec 20, 2022
7f8a32e
change argo dag flavors, adapt to new task graph and output variables
lazargugleta Dec 20, 2022
f34a212
change the post call blocks for argo and kubeflow
lazargugleta Dec 20, 2022
af3200a
remove tests for argo
lazargugleta Dec 20, 2022
8284e32
move all kubernetes operations to _dag file for argo
lazargugleta Dec 21, 2022
61d33c5
add test snapshots
lazargugleta Dec 21, 2022
57b8d68
remove argo specific templates from tasks
lazargugleta Dec 21, 2022
7a3f774
update for runner
lazargugleta Dec 21, 2022
7185ff4
pickle communication enabled with folder creation
lazargugleta Dec 21, 2022
c2d3e2c
update snapshots for airflow
lazargugleta Dec 21, 2022
df881c7
remove stepperartifact flavor
lazargugleta Dec 23, 2022
1b3b0a2
tests failing again
lazargugleta Dec 23, 2022
3e46a5f
update runner
lazargugleta Dec 23, 2022
790bc28
remove step per artifact test from ambr
lazargugleta Dec 23, 2022
Files changed
7 changes: 5 additions & 2 deletions lineapy/api/api.py
@@ -389,7 +389,9 @@ def to_pipeline(
        Names of artifacts to be included in the pipeline.

    framework: str
-        "AIRFLOW" or "SCRIPT". Defaults to "SCRIPT" if not specified.
+        Name of the framework to be used.
+        Defined by enum PipelineType in lineapy/data/types.py.
+        Defaults to "SCRIPT" if not specified.

    pipeline_name: Optional[str]
        Name of the pipeline.
@@ -432,7 +434,8 @@ def to_pipeline(
        A dictionary of parameters to configure DAG file to be generated.
        Not applicable for "SCRIPT" framework as it does not generate a separate
        DAG file. For "AIRFLOW" framework, Airflow-native config params such as
-        "retries" and "schedule_interval" can be passed in.
+        "retries" and "schedule_interval" can be passed in. For "ARGO" framework,
+        Argo-native config params such as "namespace" and "service_account_name"
+        can be passed in.

    Returns
    -------
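For orientation, a call that exercises the new option might look like the sketch below. This is hedged: the artifact names, output directory, and config values are illustrative, and pipeline_dag_config is assumed to be the parameter the docstring above describes.

import lineapy

# Illustrative sketch, not from the PR: artifact names, output_dir, and
# config values are assumptions.
lineapy.to_pipeline(
    artifacts=["cleaned data", "trained model"],
    framework="ARGO",  # enabled by this PR
    pipeline_name="my_pipeline",
    output_dir="./argo_out",
    pipeline_dag_config={
        "namespace": "argo",
        "service_account": "argo",  # key read by ARGOPipelineWriter below
    },
)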
4 changes: 2 additions & 2 deletions lineapy/data/types.py
@@ -508,14 +508,14 @@ class PipelineType(Enum):
    - SCRIPT : the pipeline is wrapped as a python script
    - AIRFLOW : the pipeline is wrapped as an airflow dag
    - DVC : the pipeline is wrapped as a DVC
+    - ARGO : the pipeline is wrapped as an Argo workflow dag
    - KUBEFLOW : the pipeline is defined using Kubeflow's python SDK
    """

    SCRIPT = 1
    AIRFLOW = 2
    DVC = 3
-    # ARGO = 4
+    ARGO = 4
    KUBEFLOW = 5


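The framework string a user passes is resolved against this enum, so uncommenting the member is what makes "ARGO" a valid choice. A minimal sketch (the name-based lookup shown here is an assumption about how callers resolve the string):

from lineapy.data.types import PipelineType

# Name-based lookup: the user-facing string maps to the enum member.
assert PipelineType["ARGO"] is PipelineType.ARGO
assert PipelineType.ARGO.value == 4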
192 changes: 192 additions & 0 deletions lineapy/plugins/argo_pipeline_writer.py
@@ -0,0 +1,192 @@
import logging
import os
from enum import Enum
from typing import Any, Dict, List, Tuple

from typing_extensions import TypedDict

from lineapy.plugins.base_pipeline_writer import BasePipelineWriter
from lineapy.plugins.task import (
    DagTaskBreakdown,
    TaskDefinition,
    TaskSerializer,
    render_task_io_serialize_blocks,
)
from lineapy.plugins.taskgen import (
    get_noop_setup_task_definition,
    get_noop_teardown_task_definition,
    get_task_graph,
)
from lineapy.plugins.utils import load_plugin_template
from lineapy.utils.logging_config import configure_logging
from lineapy.utils.utils import prettify

logger = logging.getLogger(__name__)
configure_logging()


class ARGODagFlavor(Enum):
    StepPerSession = 1


ARGODAGConfig = TypedDict(
    "ARGODAGConfig",
    {
        "namespace": str,
        "host": str,
        "verify_ssl": str,
        "image": str,
        "image_pull_policy": str,
        "token": str,
        "workflow_name": str,
        "service_account": str,
        "kube_config": str,
        "dag_flavor": str,
    },
    total=False,
)


class ARGOPipelineWriter(BasePipelineWriter):
    """
    Class for pipeline file writer. Corresponds to "ARGO" framework.
    """

    @property
    def docker_template_name(self) -> str:
        return "argo_dockerfile.jinja"

    def _write_dag(self) -> None:

        # Check if the given DAG flavor is a supported/valid one
        flavor_name = self.dag_config.get("dag_flavor", "StepPerSession")
        try:
            dag_flavor = ARGODagFlavor[flavor_name]
        except KeyError:
            raise ValueError(
                f'"{flavor_name}" is an invalid ARGO dag flavor.'
            )

        # Construct DAG text for the given flavor
        full_code = self._write_operators(dag_flavor)

        # Write out file
        file = self.output_dir / f"{self.pipeline_name}_dag.py"
        file.write_text(prettify(full_code))
        logger.info(f"Generated DAG file: {file}")

    def _write_operators(
        self,
        dag_flavor: ARGODagFlavor,
    ) -> str:

        DAG_TEMPLATE = load_plugin_template("argo_dag.jinja")

        if dag_flavor == ARGODagFlavor.StepPerSession:
            task_breakdown = DagTaskBreakdown.TaskPerSession

        # Get task definitions based on dag_flavor
        task_defs, task_graph = get_task_graph(
            self.artifact_collection,
            pipeline_name=self.pipeline_name,
            task_breakdown=task_breakdown,
        )

        task_defs["setup"] = get_noop_setup_task_definition(self.pipeline_name)
        task_defs["teardown"] = get_noop_teardown_task_definition(
            self.pipeline_name
        )
        # insert into the task order so that setup runs first and teardown runs last
        task_graph.insert_setup_task("setup")
        task_graph.insert_teardown_task("teardown")

        task_names = list(task_defs.keys())

        task_defs = {tn: task_defs[tn] for tn in task_names}

        (
            rendered_task_defs,
            task_loading_blocks,
        ) = self.get_rendered_task_definitions(task_defs)

        # Handle dependencies
        task_dependencies = reversed(
            [f"{task0} >> {task1}" for task0, task1 in task_graph.graph.edges]
        )

        # Get DAG parameters for an ARGO pipeline
        input_parameters_dict: Dict[str, Any] = {}
        for parameter_name, input_spec in super().get_pipeline_args().items():
            input_parameters_dict[parameter_name] = input_spec.value

        full_code = DAG_TEMPLATE.render(
            DAG_NAME=self.pipeline_name,
            MODULE_NAME=self.pipeline_name + "_module",
            NAMESPACE=self.dag_config.get("namespace", "argo"),
            HOST=self.dag_config.get("host", "https://localhost:2746"),
            VERIFY_SSL=self.dag_config.get("verify_ssl", "False"),
            WORKFLOW_NAME=self.dag_config.get(
                "workflow_name", self.pipeline_name.replace("_", "-")
            ),
            IMAGE=self.dag_config.get("image", "argo_pipeline:latest"),
            IMAGE_PULL_POLICY=self.dag_config.get(
                "image_pull_policy", "Never"
            ),
            SERVICE_ACCOUNT=self.dag_config.get("service_account", "argo"),
            KUBE_CONFIG=self.dag_config.get(
                "kube_config", os.path.expanduser("~/.kube/config")
            ),
            TOKEN=self.dag_config.get("token", "None"),
            dag_params=input_parameters_dict,
            task_definitions=rendered_task_defs,
            tasks=task_defs,
            task_loading_blocks=task_loading_blocks,
            task_dependencies=task_dependencies,
        )

        return full_code

    def get_rendered_task_definitions(
        self,
        task_defs: Dict[str, TaskDefinition],
    ) -> Tuple[List[str], Dict[str, str]]:
        """
        Returns rendered tasks for the pipeline tasks along with a dictionary to look up
        previous task outputs.
        The returned dictionary is used by the DAG to connect the right input files to
        output files for inter-task communication.
        This method originates from:
        https://github.com/argoproj-labs/hera-workflows/blob/4efddc85bfce62455db758f4be47e3acc0342b4f/examples/k8s_sa.py#L12
        """
        TASK_FUNCTION_TEMPLATE = load_plugin_template(
            "task/task_function.jinja"
        )
        rendered_task_defs: List[str] = []
        task_loading_blocks: Dict[str, str] = {}

        for task_name, task_def in task_defs.items():
            loading_blocks, dumping_blocks = render_task_io_serialize_blocks(
                task_def, TaskSerializer.LocalPickle
            )

            input_vars = task_def.user_input_variables

            # this task will output variables to a file that other tasks can access

            for return_variable in task_def.return_vars:
                task_loading_blocks[return_variable] = return_variable

            task_def_rendered = TASK_FUNCTION_TEMPLATE.render(
                MODULE_NAME=self.pipeline_name + "_module",
                function_name=task_name,
                user_input_variables=", ".join(input_vars),
                typing_blocks=task_def.typing_blocks,
                loading_blocks=loading_blocks,
                pre_call_block=task_def.pre_call_block or "",
                call_block=task_def.call_block,
                post_call_block=task_def.post_call_block or "",
                dumping_blocks=dumping_blocks,
                include_imports_locally=True,
            )
            rendered_task_defs.append(task_def_rendered)

        return rendered_task_defs, task_loading_blocks
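Since ARGODAGConfig above is declared with total=False, a user-supplied dag_config may set any subset of its keys. An illustrative config (the values shown simply repeat the writer's defaults):

from lineapy.plugins.argo_pipeline_writer import ARGODAGConfig

# Any subset of keys is valid because the TypedDict is total=False.
dag_config: ARGODAGConfig = {
    "namespace": "argo",
    "host": "https://localhost:2746",
    "image": "argo_pipeline:latest",
    "dag_flavor": "StepPerSession",
}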
93 changes: 93 additions & 0 deletions lineapy/plugins/jinja_templates/argo_dag.jinja
@@ -0,0 +1,93 @@
from hera import Artifact, ImagePullPolicy, set_global_task_image
from hera.task import Task
from hera.workflow import Workflow
from hera.workflow_service import WorkflowService

from kubernetes import client, config
from typing import Optional
import base64, errno, os

def get_sa_token(
    service_account: str,
    namespace: str = "argo",
    config_file: Optional[str] = None,
):
    """
    Configures the kubernetes client and returns the service account token for the
    specified service account in the specified namespace.
    This is used in the case the local kubeconfig exists.
    """
    if config_file is not None and not os.path.isfile(config_file):
        raise FileNotFoundError(
            errno.ENOENT, os.strerror(errno.ENOENT), config_file
        )

    config.load_kube_config(config_file=config_file)
    v1 = client.CoreV1Api()
    if (
        v1.read_namespaced_service_account(
            service_account, namespace
        ).secrets
        is None
    ):
        print("No secrets found in namespace: %s" % namespace)
        return "None"

    secret_name = (
        v1.read_namespaced_service_account(service_account, namespace)
        .secrets[0]
        .name
    )

    sec = v1.read_namespaced_secret(secret_name, namespace).data
    return base64.b64decode(sec["token"]).decode()

{% for task_def in task_definitions %}
{{ task_def }}
{% endfor %}

ws = WorkflowService(
    host = "{{ HOST }}",
    verify_ssl = {{ VERIFY_SSL }},
    token = get_sa_token("{{SERVICE_ACCOUNT}}", "{{NAMESPACE}}", "{{KUBE_CONFIG}}"),
    namespace = "{{ NAMESPACE }}",
)

with Workflow("{{ WORKFLOW_NAME }}", service = ws) as w:

    set_global_task_image("{{ IMAGE }}")

{% for task_name, task_def in tasks.items() %}
    {{ task_name }} = Task("{{- task_name | replace("_", "-") -}}",
        task_{{ task_name }},
{% if task_def.user_input_variables|length > 0 and task_name not in ["setup", "teardown"] %}
        [
            {
{% for var in task_def.user_input_variables %}
                "{{ var }}": "{{ dag_params[var] }}",
{% endfor %}
            }
        ],
{% endif %}
        image_pull_policy = ImagePullPolicy.{{ IMAGE_PULL_POLICY }},
{% if task_def.return_vars|length > 0 and task_name not in ["setup", "teardown"] %}
        outputs = [
{% for output in task_def.return_vars %}
            Artifact(
                "{{ output }}",
                "/tmp/{{- WORKFLOW_NAME | replace("-", "_") -}}/variable_{{output}}.pickle",
            ),
{% endfor %}
        ],
{% endif %}
    )
{% endfor %}


{% if task_dependencies is not none %}
{% for TASK_DEPENDENCIES in task_dependencies %}
{{TASK_DEPENDENCIES}}
{% endfor %}
{% endif %}

w.create()
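To make the template concrete: for a single-session pipeline named my_pipeline rendered with all defaults, the generated DAG file would look roughly like the sketch below. The session task name and its output variable are illustrative, and the rendered task function bodies are omitted.

ws = WorkflowService(
    host = "https://localhost:2746",
    verify_ssl = False,
    token = get_sa_token("argo", "argo", "/home/user/.kube/config"),
    namespace = "argo",
)

with Workflow("my-pipeline", service = ws) as w:

    set_global_task_image("argo_pipeline:latest")

    setup = Task("setup", task_setup, image_pull_policy = ImagePullPolicy.Never)
    run_session = Task("run-session",  # illustrative task name
        task_run_session,
        image_pull_policy = ImagePullPolicy.Never,
        outputs = [
            Artifact(
                "model",
                "/tmp/my_pipeline/variable_model.pickle",
            ),
        ],
    )
    teardown = Task("teardown", task_teardown, image_pull_policy = ImagePullPolicy.Never)

setup >> run_session
run_session >> teardown

w.create()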
23 changes: 23 additions & 0 deletions lineapy/plugins/jinja_templates/argo_dockerfile.jinja
@@ -0,0 +1,23 @@
FROM python:{{ python_version }}

RUN mkdir /tmp/installers
WORKDIR /tmp/installers

# Copy all the requirements to run current DAG
COPY ./{{ pipeline_name }}_requirements.txt ./

# Install required libs
RUN pip install -r ./{{ pipeline_name }}_requirements.txt

WORKDIR /opt/argo/dags

# Install git and argo
RUN apt update
RUN apt install -y git
RUN pip install argo-workflows
RUN pip install hera-workflows

COPY ./{{ pipeline_name }}_module.py ./
COPY ./{{ pipeline_name }}_dag.py ./

ENTRYPOINT [ "argo", "repro", "run_all_sessions"]
@@ -1 +1,2 @@
+if not pathlib.Path('/tmp').joinpath('{{pipeline_name}}').exists(): pathlib.Path('/tmp').joinpath('{{pipeline_name}}').mkdir()
 pickle.dump({{return_variable}}, open('/tmp/{{pipeline_name}}/variable_{{return_variable}}.pickle','wb'))
2 changes: 1 addition & 1 deletion lineapy/plugins/jinja_templates/task/task_function.jinja
@@ -1,7 +1,7 @@
def task_{{function_name}}({{user_input_variables}}):
{%- if include_imports_locally %}
    import {{ MODULE_NAME }}
-    import pickle
+    import pickle, pathlib
{%- endif %}
{% for typing_block in typing_blocks %}
{{typing_block | indent(4, True) }}
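Putting the two template changes together, a rendered task function for a hypothetical session task taking an input variable n and returning a variable model would look roughly like this (module and function names are assumptions, not writer output):

def task_run_session(n):
    import my_pipeline_module
    import pickle, pathlib

    n = int(n)  # typing block

    # call block: run the session module and pick out the returned variable
    artifacts = my_pipeline_module.run_session_including_model(n)
    model = artifacts["model"]

    # dumping block rendered from the serialization template above
    if not pathlib.Path('/tmp').joinpath('my_pipeline').exists(): pathlib.Path('/tmp').joinpath('my_pipeline').mkdir()
    pickle.dump(model, open('/tmp/my_pipeline/variable_model.pickle','wb'))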
2 changes: 1 addition & 1 deletion lineapy/plugins/kubeflow_pipeline_writer.py
@@ -175,7 +175,7 @@ def get_rendered_task_definitions(
                loading_blocks=loading_blocks,
                pre_call_block="",
                call_block=task_def.call_block,
-                post_call_block="",
+                post_call_block=task_def.post_call_block,
Review comment (Contributor, resolved): same thing for pre_call_block makes sense?

Review comment (Contributor): also not sure how jinja tackles [None]. Does it render as the string "None" or skip it? If it's the former, we should use task_def.post_call_block or "" here.

                dumping_blocks=dumping_blocks,
                include_imports_locally=True,
            )
3 changes: 3 additions & 0 deletions lineapy/plugins/pipeline_writer_factory.py
@@ -1,5 +1,6 @@
from lineapy.data.types import PipelineType
from lineapy.plugins.airflow_pipeline_writer import AirflowPipelineWriter
+from lineapy.plugins.argo_pipeline_writer import ARGOPipelineWriter
from lineapy.plugins.base_pipeline_writer import BasePipelineWriter
from lineapy.plugins.dvc_pipeline_writer import DVCPipelineWriter
from lineapy.plugins.kubeflow_pipeline_writer import KubeflowPipelineWriter
@@ -17,6 +18,8 @@ def get(
            return AirflowPipelineWriter(*args, **kwargs)
        elif pipeline_type == PipelineType.DVC:
            return DVCPipelineWriter(*args, **kwargs)
+        elif pipeline_type == PipelineType.ARGO:
+            return ARGOPipelineWriter(*args, **kwargs)
        elif pipeline_type == PipelineType.KUBEFLOW:
            return KubeflowPipelineWriter(*args, **kwargs)
        else:
else:
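A sketch of the dispatch this adds (assuming the enclosing factory class is named PipelineWriterFactory; a real call also passes the writer's constructor arguments, such as the artifact collection and output directory, through *args/**kwargs):

from lineapy.data.types import PipelineType
from lineapy.plugins.pipeline_writer_factory import PipelineWriterFactory

# Hypothetical: shown bare to illustrate dispatch only; the writer's
# required constructor arguments are omitted here.
writer = PipelineWriterFactory.get(pipeline_type=PipelineType.ARGO)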
1 change: 0 additions & 1 deletion lineapy/plugins/task.py
@@ -212,7 +212,6 @@ def render_task_io_serialize_blocks(
    DESERIALIZER_TEMPLATE = load_plugin_template(
        "task/parameterizedpickle/task_parameterized_deser.jinja"
    )
-
    # Add more renderable task serializers here

    for loaded_input_variable in taskdef.loaded_input_variables: