From e51e62b51884c73f9210421bb1a98ae1f52855c1 Mon Sep 17 00:00:00 2001 From: Fabian Jakobs Date: Wed, 18 Sep 2024 17:42:22 +0200 Subject: [PATCH] DB Connect Progress: Make sure we always end up at 100% (#1363) ## Changes DB Connect Progress: Make sure we always end up at 100% ## Tests --- .../resources/python/00-databricks-init.py | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/packages/databricks-vscode/resources/python/00-databricks-init.py b/packages/databricks-vscode/resources/python/00-databricks-init.py index 3292b43f5..4cf9bb42f 100644 --- a/packages/databricks-vscode/resources/python/00-databricks-init.py +++ b/packages/databricks-vscode/resources/python/00-databricks-init.py @@ -3,11 +3,14 @@ import json from typing import Any, Union, List import os +import sys import time import shlex import warnings import tempfile +# prevent sum from pyskaprk.sql.functions from shadowing the builtin sum +builtinSum = sys.modules['builtins'].sum def logError(function_name: str, e: Union[str, Exception]): import sys @@ -403,14 +406,20 @@ def init_ui(self): def update_ticks( self, stages, - inflight_tasks: int + inflight_tasks: int, + done: bool ) -> None: - total_tasks = sum(map(lambda x: x.num_tasks, stages)) - completed_tasks = sum(map(lambda x: x.num_completed_tasks, stages)) + total_tasks = builtinSum(map(lambda x: x.num_tasks, stages)) + completed_tasks = builtinSum(map(lambda x: x.num_completed_tasks, stages)) if total_tasks > 0: self._ticks = total_tasks self._tick = completed_tasks - self._bytes_read = sum(map(lambda x: x.num_bytes_read, stages)) + self._bytes_read = builtinSum(map(lambda x: x.num_bytes_read, stages)) + + if done: + self._tick = self._ticks + self._running = 0 + if self._tick is not None and self._tick >= 0: self.output() self._running = inflight_tasks @@ -432,7 +441,6 @@ def _bytes_to_string(size: int) -> str: i += 1 result = float(size) / Progress.SI_BYTE_SIZES[i] return f"{result:.1f} {Progress.SI_BYTE_SUFFIXES[i]}" - class ProgressHandler: def __init__(self): @@ -454,7 +462,7 @@ def __call__(self, self.op_id = operation_id self.reset() - self.p.update_ticks(stages, inflight_tasks) + self.p.update_ticks(stages, inflight_tasks, done) spark.clearProgressHandlers() spark.registerProgressHandler(ProgressHandler())