From 23e2fd71be8058b045a2e7eabf5189566a30bb3b Mon Sep 17 00:00:00 2001
From: Yandu Oppacher
Date: Thu, 13 Nov 2014 16:02:33 -0500
Subject: [PATCH] Refactor profiler code

Code review fixes

Refactor of profiler, and moved tests around

Fix doc

Added a profile collector to accumulate the profilers per stage

Remove unused method
---
 docs/configuration.md      |   3 +
 python/pyspark/__init__.py |   2 +
 python/pyspark/context.py  |  50 ++++----------
 python/pyspark/profiler.py | 136 +++++++++++++++++++++++++++++++++++++
 python/pyspark/rdd.py      |  15 ++--
 python/pyspark/tests.py    |  38 ++++++++---
 python/pyspark/worker.py   |  13 ++--
 python/run-tests           |   1 +
 8 files changed, 198 insertions(+), 60 deletions(-)
 create mode 100644 python/pyspark/profiler.py

diff --git a/docs/configuration.md b/docs/configuration.md
index 8839162c3a13e..bbe898111efb9 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -243,6 +243,9 @@ Apart from these, the following properties are also available, and may be useful
     or it will be displayed before the driver exiting. It also can be dumped into disk by
     `sc.dump_profiles(path)`. If some of the profile results had been displayed maually,
     they will not be displayed automatically before driver exiting.
+
+    By default the `pyspark.profiler.BasicProfiler` will be used, but this can be overridden by
+    passing a profiler class in as a parameter to the `SparkContext` constructor.
   </td>
 </tr>
 <tr>
diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py
index 9556e4718e585..6149e90aa9f3e 100644
--- a/python/pyspark/__init__.py
+++ b/python/pyspark/__init__.py
@@ -45,6 +45,7 @@
 from pyspark.accumulators import Accumulator, AccumulatorParam
 from pyspark.broadcast import Broadcast
 from pyspark.serializers import MarshalSerializer, PickleSerializer
+from pyspark.profiler import BasicProfiler
 
 # for back compatibility
 from pyspark.sql import SQLContext, HiveContext, SchemaRDD, Row
@@ -52,4 +53,5 @@
 __all__ = [
     "SparkConf", "SparkContext", "SparkFiles", "RDD", "StorageLevel", "Broadcast",
     "Accumulator", "AccumulatorParam", "MarshalSerializer", "PickleSerializer",
+    "BasicProfiler",
 ]
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index ec67ec8d0f824..15072a4c5ffe2 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -20,7 +20,6 @@
 import sys
 from threading import Lock
 from tempfile import NamedTemporaryFile
-import atexit
 
 from pyspark import accumulators
 from pyspark.accumulators import Accumulator
@@ -33,6 +32,7 @@
 from pyspark.storagelevel import StorageLevel
 from pyspark.rdd import RDD
 from pyspark.traceback_utils import CallSite, first_spark_call
+from pyspark.profiler import ProfilerCollector
 
 from py4j.java_collections import ListConverter
 
@@ -66,7 +66,7 @@ class SparkContext(object):
 
     def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
                  environment=None, batchSize=0, serializer=PickleSerializer(), conf=None,
-                 gateway=None, jsc=None):
+                 gateway=None, jsc=None, profiler=None):
         """
         Create a new SparkContext. At least the master and app name should be set,
         either through the named parameters here or through C{conf}.
@@ -102,14 +102,14 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
         SparkContext._ensure_initialized(self, gateway=gateway)
         try:
             self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
-                          conf, jsc)
+                          conf, jsc, profiler)
         except:
             # If an error occurs, clean up in order to allow future SparkContext creation:
             self.stop()
             raise
 
     def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
-                 conf, jsc):
+                 conf, jsc, profiler):
         self.environment = environment or {}
         self._conf = conf or SparkConf(_jvm=self._jvm)
         self._batchSize = batchSize  # -1 represents an unlimited batch size
@@ -192,7 +192,11 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,
             self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir).getAbsolutePath()
 
         # profiling stats collected for each PythonRDD
-        self._profile_stats = []
+        if self._conf.get("spark.python.profile", "false") == "true":
+            self.profiler_collector = ProfilerCollector(profiler)
+            self.profiler_collector.profile_dump_path = self._conf.get("spark.python.profile.dump", None)
+        else:
+            self.profiler_collector = None
 
     def _initialize_context(self, jconf):
         """
@@ -826,39 +830,11 @@ def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False):
         it = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal)
         return list(mappedRDD._collect_iterator_through_file(it))
 
-    def _add_profile(self, id, profileAcc):
-        if not self._profile_stats:
-            dump_path = self._conf.get("spark.python.profile.dump")
-            if dump_path:
-                atexit.register(self.dump_profiles, dump_path)
-            else:
-                atexit.register(self.show_profiles)
-
-        self._profile_stats.append([id, profileAcc, False])
-
     def show_profiles(self):
-        """ Print the profile stats to stdout """
-        for i, (id, acc, showed) in enumerate(self._profile_stats):
-            stats = acc.value
-            if not showed and stats:
-                print "=" * 60
-                print "Profile of RDD<id=%d>" % id
-                print "=" * 60
-                stats.sort_stats("time", "cumulative").print_stats()
-                # mark it as showed
-                self._profile_stats[i][2] = True
-
-    def dump_profiles(self, path):
-        """ Dump the profile stats into directory `path`
-        """
-        if not os.path.exists(path):
-            os.makedirs(path)
-        for id, acc, _ in self._profile_stats:
-            stats = acc.value
-            if stats:
-                p = os.path.join(path, "rdd_%d.pstats" % id)
-                stats.dump_stats(p)
-        self._profile_stats = []
+        self.profiler_collector.show_profiles()
+
+    def dump_profiles(self, path):
+        self.profiler_collector.dump_profiles(path)
 
 
 def _test():
diff --git a/python/pyspark/profiler.py b/python/pyspark/profiler.py
new file mode 100644
index 0000000000000..c78a693026442
--- /dev/null
+++ b/python/pyspark/profiler.py
@@ -0,0 +1,136 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import cProfile
+import pstats
+import os
+import atexit
+from pyspark.accumulators import PStatsParam
+
+
+class ProfilerCollector(object):
+    """
+    This class keeps track of different profilers on a per-stage
+    basis, and is also used to create new profilers for the
+    different stages.
+    """
+
+    def __init__(self, profiler):
+        self.profilers = []
+        self.profile_dump_path = None
+        self.profiler = profiler if profiler else BasicProfiler
+
+    def add_profiler(self, id, profiler):
+        if not self.profilers:
+            if self.profile_dump_path:
+                atexit.register(self.dump_profiles, self.profile_dump_path)
+            else:
+                atexit.register(self.show_profiles)
+
+        self.profilers.append([id, profiler, False])
+
+    def dump_profiles(self, path):
+        for id, profiler, _ in self.profilers:
+            profiler.dump(id, path)
+        self.profilers = []
+
+    def show_profiles(self):
+        """ Print the profile stats to stdout """
+        for i, (id, profiler, showed) in enumerate(self.profilers):
+            if not showed and profiler:
+                profiler.show(id)
+                # mark it as showed
+                self.profilers[i][2] = True
+
+    def new_profiler(self, ctx):
+        return self.profiler(ctx)
+
+
+class BasicProfiler(object):
+    """
+
+    :: DeveloperApi ::
+
+    PySpark supports custom profilers; this allows different profiling
+    behaviour to be plugged in, as well as output formats other than
+    the one provided by the BasicProfiler.
+
+    A custom profiler has to define or inherit the following methods:
+        profile - profiles a function call and returns the collected stats
+        show - shows collected profiles for this profiler in a readable format
+        dump - dumps the profiles to a path
+        add - adds a profile to the existing accumulated profile
+
+    The profiler class is chosen when creating a SparkContext.
+
+    >>> from pyspark.context import SparkContext
+    >>> from pyspark.conf import SparkConf
+    >>> from pyspark.profiler import BasicProfiler
+    >>> class MyCustomProfiler(BasicProfiler):
+    ...     def show(self, id):
+    ...         print "My custom profiles for RDD:%s" % id
+    ...
+    >>> conf = SparkConf().set("spark.python.profile", "true")
+    >>> sc = SparkContext('local', 'test', conf=conf, profiler=MyCustomProfiler)
+    >>> sc.parallelize(list(range(1000))).map(lambda x: 2 * x).take(10)
+    [0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
+    >>> sc.show_profiles()
+    My custom profiles for RDD:1
+    My custom profiles for RDD:2
+    >>> sc.stop()
+    """
+
+    def __init__(self, ctx):
+        self._new_profile_accumulator(ctx)
+
+    def profile(self, to_profile):
+        """ Runs and profiles the method to_profile passed in. A profile object is returned.
""" + pr = cProfile.Profile() + pr.runcall(to_profile) + st = pstats.Stats(pr) + st.stream = None # make it picklable + st.strip_dirs() + return st + + def show(self, id): + """ Print the profile stats to stdout, id is the RDD id """ + stats = self._accumulator.value + if stats: + print "=" * 60 + print "Profile of RDD" % id + print "=" * 60 + stats.sort_stats("time", "cumulative").print_stats() + + def dump(self, id, path): + """ Dump the profile into path, id is the RDD id """ + if not os.path.exists(path): + os.makedirs(path) + stats = self._accumulator.value + if stats: + p = os.path.join(path, "rdd_%d.pstats" % id) + stats.dump_stats(p) + + def _new_profile_accumulator(self, ctx): + """ + Creates a new accumulator for combining the profiles of different + partitions of a stage + """ + self._accumulator = ctx.accumulator(None, PStatsParam) + + def add(self, accum_value): + """ Adds a new profile to the existing accumulated value """ + self._accumulator.add(accum_value) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 6d9594f6a55ac..45f1a00fbba4d 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -31,7 +31,6 @@ import random from math import sqrt, log, isinf, isnan -from pyspark.accumulators import PStatsParam from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ BatchedSerializer, CloudPickleSerializer, PairDeserializer, \ PickleSerializer, pack_long, AutoBatchedSerializer @@ -2104,9 +2103,13 @@ def _jrdd(self): return self._jrdd_val if self._bypass_serializer: self._jrdd_deserializer = NoOpSerializer() - enable_profile = self.ctx._conf.get("spark.python.profile", "false") == "true" - profileStats = self.ctx.accumulator(None, PStatsParam) if enable_profile else None - command = (self.func, profileStats, self._prev_jrdd_deserializer, + + if self.ctx.profiler_collector: + profiler = self.ctx.profiler_collector.new_profiler(self.ctx) + else: + profiler = None + + command = (self.func, profiler, self._prev_jrdd_deserializer, self._jrdd_deserializer) # the serialized command will be compressed by broadcast ser = CloudPickleSerializer() @@ -2129,9 +2132,9 @@ def _jrdd(self): broadcast_vars, self.ctx._javaAccumulator) self._jrdd_val = python_rdd.asJavaRDD() - if enable_profile: + if profiler: self._id = self._jrdd_val.id() - self.ctx._add_profile(self._id, profileStats) + self.ctx.profiler_collector.add_profiler(self._id, profiler) return self._jrdd_val def id(self): diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index a01bd8d415787..0719c57dc5fb3 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -53,6 +53,7 @@ from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \ UserDefinedType, DoubleType from pyspark import shuffle +from pyspark.profiler import BasicProfiler _have_scipy = False _have_numpy = False @@ -724,16 +725,13 @@ def setUp(self): conf = SparkConf().set("spark.python.profile", "true") self.sc = SparkContext('local[4]', class_name, conf=conf) + def test_profiler(self): + self.do_computation() - def heavy_foo(x): - for i in range(1 << 20): - x = 1 - rdd = self.sc.parallelize(range(100)) - rdd.foreach(heavy_foo) - profiles = self.sc._profile_stats - self.assertEqual(1, len(profiles)) - id, acc, _ = profiles[0] + profilers = self.sc.profiler_collector.profilers + self.assertEqual(1, len(profilers)) + id, acc, _ = profilers[0] stats = acc.value self.assertTrue(stats is not None) width, stat_list = stats.get_print_list([]) @@ -745,6 +743,30 @@ def heavy_foo(x): 
         self.sc.dump_profiles(d)
         self.assertTrue("rdd_%d.pstats" % id in os.listdir(d))
 
+    def test_custom_profiler(self):
+        class TestCustomProfiler(BasicProfiler):
+            def show(self, id):
+                self.result = "Custom formatting"
+
+        self.sc.profiler_collector.profiler = TestCustomProfiler
+
+        self.do_computation()
+
+        profilers = self.sc.profiler_collector.profilers
+        self.assertEqual(1, len(profilers))
+        id, profiler, _ = profilers[0]
+        self.assertTrue(isinstance(profiler, TestCustomProfiler))
+        self.sc.show_profiles()
+        self.assertEqual("Custom formatting", profiler.result)
+
+    def do_computation(self):
+        def heavy_foo(x):
+            for i in range(1 << 20):
+                x = 1
+
+        rdd = self.sc.parallelize(range(100))
+        rdd.foreach(heavy_foo)
+
 
 class ExamplePointUDT(UserDefinedType):
     """
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index e1552a0b0b4ff..10620b0016720 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -23,8 +23,6 @@
 import time
 import socket
 import traceback
-import cProfile
-import pstats
 
 from pyspark.accumulators import _accumulatorRegistry
 from pyspark.broadcast import Broadcast, _broadcastRegistry
@@ -94,19 +92,16 @@ def main(infile, outfile):
         command = pickleSer._read_with_length(infile)
         if isinstance(command, Broadcast):
             command = pickleSer.loads(command.value)
-        (func, stats, deserializer, serializer) = command
+        (func, profiler, deserializer, serializer) = command
         init_time = time.time()
 
         def process():
             iterator = deserializer.load_stream(infile)
             serializer.dump_stream(func(split_index, iterator), outfile)
 
-        if stats:
-            p = cProfile.Profile()
-            p.runcall(process)
-            st = pstats.Stats(p)
-            st.stream = None  # make it picklable
-            stats.add(st.strip_dirs())
+        if profiler:
+            st = profiler.profile(process)
+            profiler.add(st)
         else:
             process()
     except Exception:
diff --git a/python/run-tests b/python/run-tests
index 9ee19ed6e6b26..787a05b6c2d14 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -57,6 +57,7 @@ function run_core_tests() {
     PYSPARK_DOC_TEST=1 run_test "pyspark/broadcast.py"
     PYSPARK_DOC_TEST=1 run_test "pyspark/accumulators.py"
    run_test "pyspark/serializers.py"
+    PYSPARK_DOC_TEST=1 run_test "pyspark/profiler.py"
     run_test "pyspark/shuffle.py"
     run_test "pyspark/tests.py"
 }
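
Note (not part of the diff): a minimal usage sketch of the API introduced above, for reviewers who want to try it locally. The `local[4]` master, app name, the `/tmp/pyspark-profile` dump directory, and the `TopFunctionsProfiler` class are placeholder choices for illustration; only the `spark.python.profile*` settings, the `profiler=` constructor argument, and the `BasicProfiler` methods come from the patch itself.

    from pyspark import SparkConf, SparkContext
    from pyspark.profiler import BasicProfiler


    class TopFunctionsProfiler(BasicProfiler):
        """Hypothetical custom profiler: same cProfile-based collection as
        BasicProfiler, different rendering when show_profiles() is called."""

        def show(self, id):
            stats = self._accumulator.value
            if stats:
                print "Top 5 functions for RDD %d" % id
                stats.sort_stats("cumulative").print_stats(5)


    if __name__ == "__main__":
        conf = (SparkConf()
                .set("spark.python.profile", "true")
                .set("spark.python.profile.dump", "/tmp/pyspark-profile"))
        sc = SparkContext("local[4]", "profiler-demo", conf=conf,
                          profiler=TopFunctionsProfiler)

        # Each PythonRDD gets its own profiler; workers merge their pstats into
        # the profiler's accumulator, and the ProfilerCollector tracks one
        # profiler per RDD id on the driver.
        sc.parallelize(range(1000)).map(lambda x: x * 2).count()

        sc.show_profiles()                        # prints via TopFunctionsProfiler.show
        sc.dump_profiles("/tmp/pyspark-profile")  # writes rdd_*.pstats files
        sc.stop()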