Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added transformation mutator for Python wheel task for them to work on DBR <13.1 #635

Merged
merged 10 commits into from
Aug 30, 2023
33 changes: 33 additions & 0 deletions bundle/artifacts/artifacts.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,3 +165,36 @@ func getUploadBasePath(b *bundle.Bundle) (string, error) {

return path.Join(artifactPath, ".internal"), nil
}

func UploadNotebook(ctx context.Context, notebook string, b *bundle.Bundle) (string, error) {
raw, err := os.ReadFile(notebook)
if err != nil {
return "", fmt.Errorf("unable to read %s: %w", notebook, errors.Unwrap(err))
}
andrewnester marked this conversation as resolved.
Show resolved Hide resolved

uploadPath, err := getUploadBasePath(b)
if err != nil {
return "", err
}

remotePath := path.Join(uploadPath, path.Base(notebook))
// Make sure target directory exists.
err = b.WorkspaceClient().Workspace.MkdirsByPath(ctx, path.Dir(remotePath))
if err != nil {
return "", fmt.Errorf("unable to create directory for %s: %w", remotePath, err)
}

// Import to workspace.
err = b.WorkspaceClient().Workspace.Import(ctx, workspace.Import{
Path: remotePath,
Overwrite: true,
Format: workspace.ImportFormatSource,
Content: base64.StdEncoding.EncodeToString(raw),
Language: workspace.LanguagePython,
})
if err != nil {
return "", fmt.Errorf("unable to import %s: %w", remotePath, err)
}

return remotePath, nil
}
44 changes: 33 additions & 11 deletions bundle/libraries/libraries.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,24 +24,46 @@ func (a *match) Name() string {
}

func (a *match) Apply(ctx context.Context, b *bundle.Bundle) error {
tasks := findAllTasks(b)
for _, task := range tasks {
if isMissingRequiredLibraries(task) {
return fmt.Errorf("task '%s' is missing required libraries. Please include your package code in task libraries block", task.TaskKey)
}
for j := range task.Libraries {
lib := &task.Libraries[j]
err := findArtifactsAndMarkForUpload(ctx, lib, b)
if err != nil {
return err
}
}
}
return nil
}

func findAllTasks(b *bundle.Bundle) []*jobs.Task {
r := b.Config.Resources
result := make([]*jobs.Task, 0)
for k := range b.Config.Resources.Jobs {
tasks := r.Jobs[k].JobSettings.Tasks
for i := range tasks {
task := &tasks[i]
if isMissingRequiredLibraries(task) {
return fmt.Errorf("task '%s' is missing required libraries. Please include your package code in task libraries block", task.TaskKey)
}
for j := range task.Libraries {
lib := &task.Libraries[j]
err := findArtifactsAndMarkForUpload(ctx, lib, b)
if err != nil {
return err
}
}
result = append(result, task)
}
}
return nil

return result
}

func FindAllWheelTasks(b *bundle.Bundle) []*jobs.Task {
tasks := findAllTasks(b)
wheelTasks := make([]*jobs.Task, 0)
for _, task := range tasks {
if task.PythonWheelTask != nil {
wheelTasks = append(wheelTasks, task)
}
}

return wheelTasks
}

func isMissingRequiredLibraries(task *jobs.Task) bool {
Expand Down
2 changes: 2 additions & 0 deletions bundle/phases/deploy.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/databricks/cli/bundle/deploy/lock"
"github.com/databricks/cli/bundle/deploy/terraform"
"github.com/databricks/cli/bundle/libraries"
"github.com/databricks/cli/bundle/python"
)

// The deploy phase deploys artifacts and resources.
Expand All @@ -21,6 +22,7 @@ func Deploy() bundle.Mutator {
libraries.MatchWithArtifacts(),
artifacts.CleanUp(),
artifacts.UploadAll(),
python.TransformWheelTask(),
terraform.Interpolate(),
terraform.Write(),
terraform.StatePull(),
Expand Down
106 changes: 106 additions & 0 deletions bundle/python/transform.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
package python

import (
"bytes"
"context"
"fmt"
"os"
"path/filepath"
"strings"

"github.com/databricks/cli/bundle"
"github.com/databricks/cli/bundle/artifacts"
"github.com/databricks/cli/bundle/libraries"
"github.com/databricks/databricks-sdk-go/service/compute"
"github.com/databricks/databricks-sdk-go/service/jobs"
)

// This mutator takes the wheel task and trasnforms it into notebook
andrewnester marked this conversation as resolved.
Show resolved Hide resolved
// which installs uploaded wheels using %pip and then calling corresponding
// entry point.
func TransformWheelTask() bundle.Mutator {
return &transform{}
}

type transform struct {
}

func (m *transform) Name() string {
return "python.TransformWheelTask"
}

const INSTALL_WHEEL_CODE = `%%pip install --force-reinstall %s`

const NOTEBOOK_CODE = `
%%python
%s

from contextlib import redirect_stdout
import io
import sys
sys.argv = [%s]

import pkg_resources
_func = pkg_resources.load_entry_point("%s", "console_scripts", "%s")
andrewnester marked this conversation as resolved.
Show resolved Hide resolved

f = io.StringIO()
with redirect_stdout(f):
_func()
s = f.getvalue()
dbutils.notebook.exit(s)
`
pietern marked this conversation as resolved.
Show resolved Hide resolved

func (m *transform) Apply(ctx context.Context, b *bundle.Bundle) error {
// TODO: do the transformaton only for DBR < 13.1 and (maybe?) existing clusters
wheelTasks := libraries.FindAllWheelTasks(b)
for _, wheelTask := range wheelTasks {
andrewnester marked this conversation as resolved.
Show resolved Hide resolved
taskDefinition := wheelTask.PythonWheelTask
libraries := wheelTask.Libraries

wheelTask.PythonWheelTask = nil
wheelTask.Libraries = nil

path, err := generateNotebookWrapper(taskDefinition, libraries)
if err != nil {
return err
}

remotePath, err := artifacts.UploadNotebook(context.Background(), path, b)
if err != nil {
return err
}

os.Remove(path)

wheelTask.NotebookTask = &jobs.NotebookTask{
NotebookPath: remotePath,
}
}
return nil
}
andrewnester marked this conversation as resolved.
Show resolved Hide resolved

func generateNotebookWrapper(task *jobs.PythonWheelTask, libraries []compute.Library) (string, error) {
pipInstall := ""
for _, lib := range libraries {
pipInstall = pipInstall + "\n" + fmt.Sprintf(INSTALL_WHEEL_CODE, lib.Whl)
}
content := fmt.Sprintf(NOTEBOOK_CODE, pipInstall, generateParameters(task), task.PackageName, task.EntryPoint)
andrewnester marked this conversation as resolved.
Show resolved Hide resolved

tmpDir := os.TempDir()
andrewnester marked this conversation as resolved.
Show resolved Hide resolved
filename := fmt.Sprintf("notebook_%s_%s.ipynb", task.PackageName, task.EntryPoint)
andrewnester marked this conversation as resolved.
Show resolved Hide resolved
path := filepath.Join(tmpDir, filename)

err := os.WriteFile(path, bytes.NewBufferString(content).Bytes(), 0644)
return path, err
}

func generateParameters(task *jobs.PythonWheelTask) string {
params := append([]string{"python"}, task.Parameters...)
for k, v := range task.NamedParameters {
params = append(params, fmt.Sprintf("%s=%s", k, v))
andrewnester marked this conversation as resolved.
Show resolved Hide resolved
}
for i := range params {
params[i] = `"` + params[i] + `"`
}
return strings.Join(params, ", ")
}
andrewnester marked this conversation as resolved.
Show resolved Hide resolved