diff --git a/bundle/config/resources.go b/bundle/config/resources.go
index c239b510b9..48621e443f 100644
--- a/bundle/config/resources.go
+++ b/bundle/config/resources.go
@@ -141,3 +141,14 @@ func (r *Resources) MergeJobClusters() error {
 	}
 	return nil
 }
+
+// MergeTasks iterates over all jobs and merges their tasks.
+// This is called after applying the target overrides.
+func (r *Resources) MergeTasks() error {
+	for _, job := range r.Jobs {
+		if err := job.MergeTasks(); err != nil {
+			return err
+		}
+	}
+	return nil
+}
diff --git a/bundle/config/resources/job.go b/bundle/config/resources/job.go
index 66705afb2f..7fc5b7610b 100644
--- a/bundle/config/resources/job.go
+++ b/bundle/config/resources/job.go
@@ -47,3 +47,36 @@ func (j *Job) MergeJobClusters() error {
 	j.JobClusters = output
 	return nil
 }
+
+// MergeTasks merges tasks with the same key.
+// The tasks field is a slice, and as such, overrides are appended to it.
+// We can identify a task by its task key, however, so we can use this key
+// to figure out which definitions are actually overrides and merge them.
+func (j *Job) MergeTasks() error {
+	keys := make(map[string]*jobs.Task)
+	tasks := make([]jobs.Task, 0, len(j.Tasks))
+
+	// Target overrides are always appended, so we can iterate in natural order to
+	// first find the base definition, and merge instances we encounter later.
+	for i := range j.Tasks {
+		key := j.Tasks[i].TaskKey
+
+		// Register the task with key if not yet seen before.
+		ref, ok := keys[key]
+		if !ok {
+			tasks = append(tasks, j.Tasks[i])
+			keys[key] = &j.Tasks[i]
+			continue
+		}
+
+		// Merge this instance into the reference.
+		err := mergo.Merge(ref, &j.Tasks[i], mergo.WithOverride, mergo.WithAppendSlice)
+		if err != nil {
+			return err
+		}
+	}
+
+	// Overwrite resulting slice.
+	j.Tasks = tasks
+	return nil
+}
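A note on the merge semantics, separate from the diff itself: MergeTasks leans entirely on mergo's WithOverride and WithAppendSlice options. The following is a minimal standalone sketch of that behavior with illustrative values; it assumes the github.com/imdario/mergo import path already used by this repository, which is an assumption on my part.

package main

import (
	"fmt"

	"github.com/databricks/databricks-sdk-go/service/compute"
	"github.com/databricks/databricks-sdk-go/service/jobs"
	"github.com/imdario/mergo" // assumed to match the module the repo already depends on
)

func main() {
	// Base definition of a task, as written in the bundle root.
	base := jobs.Task{
		TaskKey: "foo",
		NewCluster: &compute.ClusterSpec{
			SparkVersion: "13.3.x-scala2.12",
			NodeTypeId:   "i3.xlarge",
			NumWorkers:   2,
		},
	}

	// Target override: only the fields that change are set.
	override := jobs.Task{
		TaskKey: "foo",
		NewCluster: &compute.ClusterSpec{
			NodeTypeId: "i3.2xlarge",
			NumWorkers: 4,
		},
	}

	// WithOverride lets non-zero fields in the override win;
	// WithAppendSlice concatenates slice fields instead of replacing them.
	if err := mergo.Merge(&base, &override, mergo.WithOverride, mergo.WithAppendSlice); err != nil {
		panic(err)
	}

	// SparkVersion survives from the base; NodeTypeId and NumWorkers are
	// taken from the override.
	fmt.Println(base.NewCluster.SparkVersion, base.NewCluster.NodeTypeId, base.NewCluster.NumWorkers)
	// 13.3.x-scala2.12 i3.2xlarge 4
}

One consequence worth flagging: because mergo treats the zero value as "unset", a target override can add or replace settings but cannot reset a field back to its zero value (e.g. num_workers: 0 in a target would be ignored).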
diff --git a/bundle/config/resources/job_test.go b/bundle/config/resources/job_test.go
index 2ff3205e0d..818d2ac216 100644
--- a/bundle/config/resources/job_test.go
+++ b/bundle/config/resources/job_test.go
@@ -55,3 +55,50 @@ func TestJobMergeJobClusters(t *testing.T) {
 	jc1 := j.JobClusters[1].NewCluster
 	assert.Equal(t, "10.4.x-scala2.12", jc1.SparkVersion)
 }
+
+func TestJobMergeTasks(t *testing.T) {
+	j := &Job{
+		JobSettings: &jobs.JobSettings{
+			Tasks: []jobs.Task{
+				{
+					TaskKey: "foo",
+					NewCluster: &compute.ClusterSpec{
+						SparkVersion: "13.3.x-scala2.12",
+						NodeTypeId:   "i3.xlarge",
+						NumWorkers:   2,
+					},
+				},
+				{
+					TaskKey: "bar",
+					NewCluster: &compute.ClusterSpec{
+						SparkVersion: "10.4.x-scala2.12",
+					},
+				},
+				{
+					TaskKey: "foo",
+					NewCluster: &compute.ClusterSpec{
+						NodeTypeId: "i3.2xlarge",
+						NumWorkers: 4,
+					},
+				},
+			},
+		},
+	}
+
+	err := j.MergeTasks()
+	require.NoError(t, err)
+
+	assert.Len(t, j.Tasks, 2)
+	assert.Equal(t, "foo", j.Tasks[0].TaskKey)
+	assert.Equal(t, "bar", j.Tasks[1].TaskKey)
+
+	// This task was merged with a subsequent one.
+	task0 := j.Tasks[0].NewCluster
+	assert.Equal(t, "13.3.x-scala2.12", task0.SparkVersion)
+	assert.Equal(t, "i3.2xlarge", task0.NodeTypeId)
+	assert.Equal(t, 4, task0.NumWorkers)
+
+	// This task was left untouched.
+	task1 := j.Tasks[1].NewCluster
+	assert.Equal(t, "10.4.x-scala2.12", task1.SparkVersion)
+}
diff --git a/bundle/config/root.go b/bundle/config/root.go
index 465d8a62e1..32883c746e 100644
--- a/bundle/config/root.go
+++ b/bundle/config/root.go
@@ -242,6 +242,11 @@ func (r *Root) MergeTargetOverrides(target *Target) error {
 		if err != nil {
 			return err
 		}
+
+		err = r.Resources.MergeTasks()
+		if err != nil {
+			return err
+		}
 	}
 
 	if target.Variables != nil {
diff --git a/bundle/tests/override_job_tasks/databricks.yml b/bundle/tests/override_job_tasks/databricks.yml
new file mode 100644
index 0000000000..ddee287935
--- /dev/null
+++ b/bundle/tests/override_job_tasks/databricks.yml
@@ -0,0 +1,44 @@
+bundle:
+  name: override_job_tasks
+
+workspace:
+  host: https://acme.cloud.databricks.com/
+
+resources:
+  jobs:
+    foo:
+      name: job
+      tasks:
+        - task_key: key1
+          new_cluster:
+            spark_version: 13.3.x-scala2.12
+          spark_python_task:
+            python_file: ./test1.py
+        - task_key: key2
+          new_cluster:
+            spark_version: 13.3.x-scala2.12
+          spark_python_task:
+            python_file: ./test2.py
+
+targets:
+  development:
+    resources:
+      jobs:
+        foo:
+          tasks:
+            - task_key: key1
+              new_cluster:
+                node_type_id: i3.xlarge
+                num_workers: 1
+
+  staging:
+    resources:
+      jobs:
+        foo:
+          tasks:
+            - task_key: key2
+              new_cluster:
+                node_type_id: i3.2xlarge
+                num_workers: 4
+              spark_python_task:
+                python_file: ./test3.py
diff --git a/bundle/tests/override_job_tasks_test.go b/bundle/tests/override_job_tasks_test.go
new file mode 100644
index 0000000000..82da04da26
--- /dev/null
+++ b/bundle/tests/override_job_tasks_test.go
@@ -0,0 +1,39 @@
+package config_tests
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestOverrideTasksDev(t *testing.T) {
+	b := loadTarget(t, "./override_job_tasks", "development")
+	assert.Equal(t, "job", b.Config.Resources.Jobs["foo"].Name)
+	assert.Len(t, b.Config.Resources.Jobs["foo"].Tasks, 2)
+
+	tasks := b.Config.Resources.Jobs["foo"].Tasks
+	assert.Equal(t, tasks[0].TaskKey, "key1")
+	assert.Equal(t, tasks[0].NewCluster.NodeTypeId, "i3.xlarge")
+	assert.Equal(t, tasks[0].NewCluster.NumWorkers, 1)
+	assert.Equal(t, tasks[0].SparkPythonTask.PythonFile, "./test1.py")
+
+	assert.Equal(t, tasks[1].TaskKey, "key2")
+	assert.Equal(t, tasks[1].NewCluster.SparkVersion, "13.3.x-scala2.12")
+	assert.Equal(t, tasks[1].SparkPythonTask.PythonFile, "./test2.py")
+}
+
+func TestOverrideTasksStaging(t *testing.T) {
+	b := loadTarget(t, "./override_job_tasks", "staging")
+	assert.Equal(t, "job", b.Config.Resources.Jobs["foo"].Name)
+	assert.Len(t, b.Config.Resources.Jobs["foo"].Tasks, 2)
+
+	tasks := b.Config.Resources.Jobs["foo"].Tasks
+	assert.Equal(t, tasks[0].TaskKey, "key1")
+	assert.Equal(t, tasks[0].NewCluster.SparkVersion, "13.3.x-scala2.12")
+	assert.Equal(t, tasks[0].SparkPythonTask.PythonFile, "./test1.py")
+
+	assert.Equal(t, tasks[1].TaskKey, "key2")
+	assert.Equal(t, tasks[1].NewCluster.NodeTypeId, "i3.2xlarge")
+	assert.Equal(t, tasks[1].NewCluster.NumWorkers, 4)
+	assert.Equal(t, tasks[1].SparkPythonTask.PythonFile, "./test3.py")
+}
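For completeness, here is a hypothetical companion test, not part of the diff, covering the one path the new tests leave implicit: a task key introduced only by a target has no base definition to merge into, so MergeTasks keeps it as a new task in first-seen order. The test name and values are mine; it would live alongside TestJobMergeTasks in bundle/config/resources.

package resources

import (
	"testing"

	"github.com/databricks/databricks-sdk-go/service/jobs"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// A task key that only appears in a target override is simply appended
// rather than merged or dropped.
func TestJobMergeTasksNewKey(t *testing.T) {
	j := &Job{
		JobSettings: &jobs.JobSettings{
			Tasks: []jobs.Task{
				{TaskKey: "base"},
				{TaskKey: "extra"}, // appended by a target, no base counterpart
			},
		},
	}

	require.NoError(t, j.MergeTasks())

	// Both tasks survive, in the order they were first seen.
	assert.Len(t, j.Tasks, 2)
	assert.Equal(t, "base", j.Tasks[0].TaskKey)
	assert.Equal(t, "extra", j.Tasks[1].TaskKey)
}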