Skip to content

Commit

Permalink
Enable environment overrides for job tasks (#779)
Browse files Browse the repository at this point in the history
## Changes
Follow up for #658

When a job definition has multiple job tasks using the same key, it's
considered invalid. Instead we should combine those definitions with the
same key into one. This is consistent with environment overrides. This
way, the override ends up in the original job tasks, and we've got a
clear way to put them all together.

## Tests
Added unit tests
  • Loading branch information
andrewnester authored Sep 18, 2023
1 parent b3b00fd commit 43e2eef
Show file tree
Hide file tree
Showing 6 changed files with 179 additions and 0 deletions.
11 changes: 11 additions & 0 deletions bundle/config/resources.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,14 @@ func (r *Resources) MergeJobClusters() error {
}
return nil
}

// MergeTasks consolidates duplicate task definitions for every job in this
// resource collection. It must run after target overrides have been applied,
// because overrides append task entries rather than replacing them.
func (r *Resources) MergeTasks() error {
	for _, j := range r.Jobs {
		if err := j.MergeTasks(); err != nil {
			return err
		}
	}
	return nil
}
33 changes: 33 additions & 0 deletions bundle/config/resources/job.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,36 @@ func (j *Job) MergeJobClusters() error {
j.JobClusters = output
return nil
}

// MergeTasks merges tasks with the same task key into a single definition.
// The tasks field is a slice, so target overrides are appended to it rather
// than merged element-wise. A task is uniquely identified by its task key,
// so entries whose key was already seen are treated as overrides and merged
// into the first (base) definition.
func (j *Job) MergeTasks() error {
	keys := make(map[string]*jobs.Task)
	tasks := make([]jobs.Task, 0, len(j.Tasks))

	// Target overrides are always appended, so we can iterate in natural order to
	// first find the base definition, and merge instances we encounter later.
	for i := range j.Tasks {
		key := j.Tasks[i].TaskKey

		// Register the task with key if not yet seen before.
		ref, ok := keys[key]
		if !ok {
			tasks = append(tasks, j.Tasks[i])
			// Point at the element in the output slice, not at j.Tasks[i]:
			// the append above made a copy, and merging into the original
			// element would be lost for scalar fields once j.Tasks is
			// replaced by the output slice below. The pointer stays valid
			// because tasks is pre-sized to len(j.Tasks), so subsequent
			// appends never reallocate the backing array.
			keys[key] = &tasks[len(tasks)-1]
			continue
		}

		// Merge this instance into the base definition registered above.
		err := mergo.Merge(ref, &j.Tasks[i], mergo.WithOverride, mergo.WithAppendSlice)
		if err != nil {
			return err
		}
	}

	// Overwrite the tasks with the merged result.
	j.Tasks = tasks
	return nil
}
47 changes: 47 additions & 0 deletions bundle/config/resources/job_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,50 @@ func TestJobMergeJobClusters(t *testing.T) {
jc1 := j.JobClusters[1].NewCluster
assert.Equal(t, "10.4.x-scala2.12", jc1.SparkVersion)
}

// TestJobMergeTasks verifies that two task definitions sharing the key "foo"
// collapse into one merged definition, while the unrelated "bar" task is
// passed through unchanged.
func TestJobMergeTasks(t *testing.T) {
	j := &Job{
		JobSettings: &jobs.JobSettings{
			Tasks: []jobs.Task{
				{
					TaskKey: "foo",
					NewCluster: &compute.ClusterSpec{
						SparkVersion: "13.3.x-scala2.12",
						NodeTypeId:   "i3.xlarge",
						NumWorkers:   2,
					},
				},
				{
					TaskKey: "bar",
					NewCluster: &compute.ClusterSpec{
						SparkVersion: "10.4.x-scala2.12",
					},
				},
				{
					TaskKey: "foo",
					NewCluster: &compute.ClusterSpec{
						NodeTypeId: "i3.2xlarge",
						NumWorkers: 4,
					},
				},
			},
		},
	}

	require.NoError(t, j.MergeTasks())
	assert.Len(t, j.Tasks, 2)

	foo, bar := j.Tasks[0], j.Tasks[1]
	assert.Equal(t, "foo", foo.TaskKey)
	assert.Equal(t, "bar", bar.TaskKey)

	// The override wins for node type and worker count; the base
	// definition's spark version is retained.
	assert.Equal(t, "13.3.x-scala2.12", foo.NewCluster.SparkVersion)
	assert.Equal(t, "i3.2xlarge", foo.NewCluster.NodeTypeId)
	assert.Equal(t, 4, foo.NewCluster.NumWorkers)

	// The "bar" task had no override and keeps its original cluster spec.
	assert.Equal(t, "10.4.x-scala2.12", bar.NewCluster.SparkVersion)
}
5 changes: 5 additions & 0 deletions bundle/config/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,11 @@ func (r *Root) MergeTargetOverrides(target *Target) error {
if err != nil {
return err
}

err = r.Resources.MergeTasks()
if err != nil {
return err
}
}

if target.Variables != nil {
Expand Down
44 changes: 44 additions & 0 deletions bundle/tests/override_job_tasks/databricks.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Fixture for the task-override tests: a job with two base tasks, plus two
# targets that each override the cluster of exactly one task by task_key.
bundle:
  name: override_job_tasks

workspace:
  host: https://acme.cloud.databricks.com/

resources:
  jobs:
    foo:
      name: job
      # Base task definitions shared by all targets.
      tasks:
        - task_key: key1
          new_cluster:
            spark_version: 13.3.x-scala2.12
          spark_python_task:
            python_file: ./test1.py
        - task_key: key2
          new_cluster:
            spark_version: 13.3.x-scala2.12
          spark_python_task:
            python_file: ./test2.py

targets:
  # development overrides only key1's cluster; key2 keeps its base definition.
  development:
    resources:
      jobs:
        foo:
          tasks:
            - task_key: key1
              new_cluster:
                node_type_id: i3.xlarge
                num_workers: 1

  # staging overrides key2's cluster and also replaces its python file.
  staging:
    resources:
      jobs:
        foo:
          tasks:
            - task_key: key2
              new_cluster:
                node_type_id: i3.2xlarge
                num_workers: 4
              spark_python_task:
                python_file: ./test3.py
39 changes: 39 additions & 0 deletions bundle/tests/override_job_tasks_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package config_tests

import (
"testing"

"github.com/stretchr/testify/assert"
)

// TestOverrideTasksDev verifies that the development target's override for
// task key1 is merged into the base definition, while key2 is untouched.
func TestOverrideTasksDev(t *testing.T) {
	b := loadTarget(t, "./override_job_tasks", "development")

	job := b.Config.Resources.Jobs["foo"]
	assert.Equal(t, "job", job.Name)
	assert.Len(t, job.Tasks, 2)

	// key1 received the development cluster override on top of its base.
	first := job.Tasks[0]
	assert.Equal(t, "key1", first.TaskKey)
	assert.Equal(t, "i3.xlarge", first.NewCluster.NodeTypeId)
	assert.Equal(t, 1, first.NewCluster.NumWorkers)
	assert.Equal(t, "./test1.py", first.SparkPythonTask.PythonFile)

	// key2 has no override in this target and keeps its base definition.
	second := job.Tasks[1]
	assert.Equal(t, "key2", second.TaskKey)
	assert.Equal(t, "13.3.x-scala2.12", second.NewCluster.SparkVersion)
	assert.Equal(t, "./test2.py", second.SparkPythonTask.PythonFile)
}

// TestOverrideTasksStaging verifies that the staging target's override for
// task key2 (cluster and python file) is merged, while key1 is untouched.
func TestOverrideTasksStaging(t *testing.T) {
	b := loadTarget(t, "./override_job_tasks", "staging")

	job := b.Config.Resources.Jobs["foo"]
	assert.Equal(t, "job", job.Name)
	assert.Len(t, job.Tasks, 2)

	// key1 has no override in this target and keeps its base definition.
	first := job.Tasks[0]
	assert.Equal(t, "key1", first.TaskKey)
	assert.Equal(t, "13.3.x-scala2.12", first.NewCluster.SparkVersion)
	assert.Equal(t, "./test1.py", first.SparkPythonTask.PythonFile)

	// key2 received both the cluster override and the new python file.
	second := job.Tasks[1]
	assert.Equal(t, "key2", second.TaskKey)
	assert.Equal(t, "i3.2xlarge", second.NewCluster.NodeTypeId)
	assert.Equal(t, 4, second.NewCluster.NumWorkers)
	assert.Equal(t, "./test3.py", second.SparkPythonTask.PythonFile)
}

0 comments on commit 43e2eef

Please sign in to comment.