Merge pull request apache#23 from Shopify/profile-refactor
Refactor profiler code
udnay committed Nov 25, 2014
2 parents f54ccf8 + 23e2fd7 commit 11dfaa6
Showing 8 changed files with 198 additions and 60 deletions.
3 changes: 3 additions & 0 deletions docs/configuration.md
@@ -243,6 +243,9 @@ Apart from these, the following properties are also available, and may be useful
or it will be displayed before the driver exits. It can also be dumped to disk by
`sc.dump_profiles(path)`. If some of the profile results have been displayed manually,
they will not be displayed automatically before the driver exits.

By default, `pyspark.profiler.BasicProfiler` is used, but this can be overridden by
passing a profiler class as a parameter to the `SparkContext` constructor.
</td>
</tr>
<tr>
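For illustration, a short usage sketch of the behavior described in the doc text above, assuming a PySpark build with this change applied; `MyCustomProfiler` is a hypothetical class modeled on the doctest in python/pyspark/profiler.py further down this diff.

    from pyspark import SparkConf, SparkContext
    from pyspark.profiler import BasicProfiler

    class MyCustomProfiler(BasicProfiler):
        def show(self, id):
            print "My custom profiles for RDD:%s" % id

    # enable profiling and pass the profiler class to the SparkContext constructor
    conf = SparkConf().set("spark.python.profile", "true")
    sc = SparkContext("local", "profile-demo", conf=conf, profiler=MyCustomProfiler)

    sc.parallelize(range(1000)).map(lambda x: 2 * x).count()
    sc.show_profiles()   # printed via MyCustomProfiler.show
    sc.stop()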
2 changes: 2 additions & 0 deletions python/pyspark/__init__.py
@@ -45,11 +45,13 @@
from pyspark.accumulators import Accumulator, AccumulatorParam
from pyspark.broadcast import Broadcast
from pyspark.serializers import MarshalSerializer, PickleSerializer
from pyspark.profiler import BasicProfiler

# for back compatibility
from pyspark.sql import SQLContext, HiveContext, SchemaRDD, Row

__all__ = [
"SparkConf", "SparkContext", "SparkFiles", "RDD", "StorageLevel", "Broadcast",
"Accumulator", "AccumulatorParam", "MarshalSerializer", "PickleSerializer",
"BasicProfiler",
]
50 changes: 13 additions & 37 deletions python/pyspark/context.py
@@ -20,7 +20,6 @@
import sys
from threading import Lock
from tempfile import NamedTemporaryFile
import atexit

from pyspark import accumulators
from pyspark.accumulators import Accumulator
@@ -33,6 +32,7 @@
from pyspark.storagelevel import StorageLevel
from pyspark.rdd import RDD
from pyspark.traceback_utils import CallSite, first_spark_call
from pyspark.profiler import ProfilerCollector

from py4j.java_collections import ListConverter

@@ -66,7 +66,7 @@ class SparkContext(object):

def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
environment=None, batchSize=0, serializer=PickleSerializer(), conf=None,
gateway=None, jsc=None):
gateway=None, jsc=None, profiler=None):
"""
Create a new SparkContext. At least the master and app name should be set,
either through the named parameters here or through C{conf}.
@@ -102,14 +102,14 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
SparkContext._ensure_initialized(self, gateway=gateway)
try:
self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
conf, jsc)
conf, jsc, profiler)
except:
# If an error occurs, clean up in order to allow future SparkContext creation:
self.stop()
raise

def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
conf, jsc):
conf, jsc, profiler):
self.environment = environment or {}
self._conf = conf or SparkConf(_jvm=self._jvm)
self._batchSize = batchSize # -1 represents an unlimited batch size
@@ -192,7 +192,11 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,
self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir).getAbsolutePath()

# profiling stats collected for each PythonRDD
self._profile_stats = []
if self._conf.get("spark.python.profile", "false") == "true":
self.profiler_collector = ProfilerCollector(profiler)
self.profiler_collector.profile_dump_path = self._conf.get("spark.python.profile.dump", None)
else:
self.profiler_collector = None

def _initialize_context(self, jconf):
"""
@@ -826,39 +830,11 @@ def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False):
it = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal)
return list(mappedRDD._collect_iterator_through_file(it))

def _add_profile(self, id, profileAcc):
if not self._profile_stats:
dump_path = self._conf.get("spark.python.profile.dump")
if dump_path:
atexit.register(self.dump_profiles, dump_path)
else:
atexit.register(self.show_profiles)

self._profile_stats.append([id, profileAcc, False])

def show_profiles(self):
""" Print the profile stats to stdout """
for i, (id, acc, showed) in enumerate(self._profile_stats):
stats = acc.value
if not showed and stats:
print "=" * 60
print "Profile of RDD<id=%d>" % id
print "=" * 60
stats.sort_stats("time", "cumulative").print_stats()
# mark it as showed
self._profile_stats[i][2] = True

def dump_profiles(self, path):
""" Dump the profile stats into directory `path`
"""
if not os.path.exists(path):
os.makedirs(path)
for id, acc, _ in self._profile_stats:
stats = acc.value
if stats:
p = os.path.join(path, "rdd_%d.pstats" % id)
stats.dump_stats(p)
self._profile_stats = []
self.profiler_collector.show_profiles()

def dump_profiles(self, path):
""" Dump the profile stats into directory `path` """
self.profiler_collector.dump_profiles(path)


def _test():
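For illustration, a minimal sketch of the resulting SparkContext surface, assuming `spark.python.profile` is set to "true" (otherwise `profiler_collector` is `None`) and that `dump_profiles` takes the target directory, as the docs and tests expect; the app name and directory are illustrative.

    import tempfile
    from pyspark import SparkConf, SparkContext

    conf = SparkConf().set("spark.python.profile", "true")
    sc = SparkContext("local", "collector-demo", conf=conf)

    sc.parallelize(range(100)).map(lambda x: x * x).count()

    sc.show_profiles()             # delegates to ProfilerCollector.show_profiles()

    dump_dir = tempfile.mkdtemp()
    sc.dump_profiles(dump_dir)     # writes one rdd_<id>.pstats file per profiled RDD
    sc.stop()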
136 changes: 136 additions & 0 deletions python/pyspark/profiler.py
@@ -0,0 +1,136 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import cProfile
import pstats
import os
import atexit
from pyspark.accumulators import PStatsParam


class ProfilerCollector(object):
"""
This class keeps track of different profilers on a per-stage basis.
It is also used to create new profilers for the different stages.
"""

def __init__(self, profiler):
self.profilers = []
self.profile_dump_path = None
self.profiler = profiler if profiler else BasicProfiler

def add_profiler(self, id, profiler):
if not self.profilers:
if self.profile_dump_path:
atexit.register(self.dump_profiles)
else:
atexit.register(self.show_profiles)

self.profilers.append([id, profiler, False])

def dump_profiles(self, path=None):
""" Dump the collected profiles into directory `path` (defaults to the configured dump path) """
path = path or self.profile_dump_path
for id, profiler, _ in self.profilers:
profiler.dump(id, path)
self.profilers = []

def show_profiles(self):
""" Print the profile stats to stdout """
for i, (id, profiler, showed) in enumerate(self.profilers):
if not showed and profiler:
profiler.show(id)
# mark it as showed
self.profilers[i][2] = True

def new_profiler(self, ctx):
return self.profiler(ctx)


class BasicProfiler(object):
"""
:: DeveloperApi ::
PySpark supports custom profilers. This allows different profilers to be used,
as well as output in formats other than what the BasicProfiler provides.
A custom profiler has to define or inherit the following methods:
profile - will produce a system profile of some sort.
show - shows collected profiles for this profiler in a readable format
dump - dumps the profiles to a path
add - adds a profile to the existing accumulated profile
The profiler class is chosen when creating a SparkContext.
>>> from pyspark.context import SparkContext
>>> from pyspark.conf import SparkConf
>>> from pyspark.profiler import BasicProfiler
>>> class MyCustomProfiler(BasicProfiler):
... def show(self, id):
... print "My custom profiles for RDD:%s" % id
...
>>> conf = SparkConf().set("spark.python.profile", "true")
>>> sc = SparkContext('local', 'test', conf=conf, profiler=MyCustomProfiler)
>>> sc.parallelize(list(range(1000))).map(lambda x: 2 * x).take(10)
[0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
>>> sc.show_profiles()
My custom profiles for RDD:1
My custom profiles for RDD:2
>>> sc.stop()
"""

def __init__(self, ctx):
self._new_profile_accumulator(ctx)

def profile(self, to_profile):
""" Runs and profiles the method to_profile passed in. A profile object is returned. """
pr = cProfile.Profile()
pr.runcall(to_profile)
st = pstats.Stats(pr)
st.stream = None # make it picklable
st.strip_dirs()
return st

def show(self, id):
""" Print the profile stats to stdout, id is the RDD id """
stats = self._accumulator.value
if stats:
print "=" * 60
print "Profile of RDD<id=%d>" % id
print "=" * 60
stats.sort_stats("time", "cumulative").print_stats()

def dump(self, id, path):
""" Dump the profile into path, id is the RDD id """
if not os.path.exists(path):
os.makedirs(path)
stats = self._accumulator.value
if stats:
p = os.path.join(path, "rdd_%d.pstats" % id)
stats.dump_stats(p)

def _new_profile_accumulator(self, ctx):
"""
Creates a new accumulator for combining the profiles of different
partitions of a stage
"""
self._accumulator = ctx.accumulator(None, PStatsParam)

def add(self, accum_value):
""" Adds a new profile to the existing accumulated value """
self._accumulator.add(accum_value)
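For illustration, a standalone sketch of the ProfilerCollector contract described in its docstring above; `DummyProfiler` is a hypothetical stand-in that skips the accumulator so no live SparkContext is needed, whereas the real BasicProfiler is driven by SparkContext and PythonRDD.

    from pyspark.profiler import ProfilerCollector, BasicProfiler

    class DummyProfiler(BasicProfiler):
        def __init__(self, ctx):
            pass                  # skip the accumulator; no SparkContext in this sketch

        def show(self, id):
            print "Profile of RDD<id=%d> (dummy)" % id

    collector = ProfilerCollector(DummyProfiler)
    profiler = collector.new_profiler(None)   # collector instantiates its profiler class
    collector.add_profiler(42, profiler)      # registered once the RDD id is known
    collector.show_profiles()                 # shows each profiler once, then marks it as shown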
15 changes: 9 additions & 6 deletions python/pyspark/rdd.py
@@ -31,7 +31,6 @@
import random
from math import sqrt, log, isinf, isnan

from pyspark.accumulators import PStatsParam
from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
BatchedSerializer, CloudPickleSerializer, PairDeserializer, \
PickleSerializer, pack_long, AutoBatchedSerializer
@@ -2104,9 +2103,13 @@ def _jrdd(self):
return self._jrdd_val
if self._bypass_serializer:
self._jrdd_deserializer = NoOpSerializer()
enable_profile = self.ctx._conf.get("spark.python.profile", "false") == "true"
profileStats = self.ctx.accumulator(None, PStatsParam) if enable_profile else None
command = (self.func, profileStats, self._prev_jrdd_deserializer,

if self.ctx.profiler_collector:
profiler = self.ctx.profiler_collector.new_profiler(self.ctx)
else:
profiler = None

command = (self.func, profiler, self._prev_jrdd_deserializer,
self._jrdd_deserializer)
# the serialized command will be compressed by broadcast
ser = CloudPickleSerializer()
@@ -2129,9 +2132,9 @@ def _jrdd(self):
broadcast_vars, self.ctx._javaAccumulator)
self._jrdd_val = python_rdd.asJavaRDD()

if enable_profile:
if profiler:
self._id = self._jrdd_val.id()
self.ctx._add_profile(self._id, profileStats)
self.ctx.profiler_collector.add_profiler(self._id, profiler)
return self._jrdd_val

def id(self):
38 changes: 30 additions & 8 deletions python/pyspark/tests.py
@@ -53,6 +53,7 @@
from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \
UserDefinedType, DoubleType
from pyspark import shuffle
from pyspark.profiler import BasicProfiler

_have_scipy = False
_have_numpy = False
@@ -724,16 +725,13 @@ def setUp(self):
conf = SparkConf().set("spark.python.profile", "true")
self.sc = SparkContext('local[4]', class_name, conf=conf)


def test_profiler(self):
self.do_computation()

def heavy_foo(x):
for i in range(1 << 20):
x = 1
rdd = self.sc.parallelize(range(100))
rdd.foreach(heavy_foo)
profiles = self.sc._profile_stats
self.assertEqual(1, len(profiles))
id, acc, _ = profiles[0]
profilers = self.sc.profiler_collector.profilers
self.assertEqual(1, len(profilers))
id, profiler, _ = profilers[0]
stats = profiler._accumulator.value
self.assertTrue(stats is not None)
width, stat_list = stats.get_print_list([])
@@ -745,6 +743,30 @@ def heavy_foo(x):
self.sc.dump_profiles(d)
self.assertTrue("rdd_%d.pstats" % id in os.listdir(d))

def test_custom_profiler(self):
class TestCustomProfiler(BasicProfiler):
def show(self, id):
self.result = "Custom formatting"

self.sc.profiler_collector.profiler = TestCustomProfiler

self.do_computation()

profilers = self.sc.profiler_collector.profilers
self.assertEqual(1, len(profilers))
id, profiler, _ = profilers[0]
self.assertTrue(isinstance(profiler, TestCustomProfiler))

self.sc.show_profiles()
self.assertEqual("Custom formatting", profiler.result)

def do_computation(self):
def heavy_foo(x):
for i in range(1 << 20):
x = 1

rdd = self.sc.parallelize(range(100))
rdd.foreach(heavy_foo)


class ExamplePointUDT(UserDefinedType):
"""
13 changes: 4 additions & 9 deletions python/pyspark/worker.py
@@ -23,8 +23,6 @@
import time
import socket
import traceback
import cProfile
import pstats

from pyspark.accumulators import _accumulatorRegistry
from pyspark.broadcast import Broadcast, _broadcastRegistry
@@ -94,19 +92,16 @@ def main(infile, outfile):
command = pickleSer._read_with_length(infile)
if isinstance(command, Broadcast):
command = pickleSer.loads(command.value)
(func, stats, deserializer, serializer) = command
(func, profiler, deserializer, serializer) = command
init_time = time.time()

def process():
iterator = deserializer.load_stream(infile)
serializer.dump_stream(func(split_index, iterator), outfile)

if stats:
p = cProfile.Profile()
p.runcall(process)
st = pstats.Stats(p)
st.stream = None # make it picklable
stats.add(st.strip_dirs())
if profiler:
st = profiler.profile(process)
profiler.add(st)
else:
process()
except Exception:
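For reference, a standalone sketch of the cProfile/pstats steps that BasicProfiler.profile performs and that the worker now delegates to the profiler object; `process` here is an arbitrary stand-in for the worker's stream-processing closure, and the accumulator `add` step is omitted since it needs a running job.

    import cProfile
    import pstats

    def process():
        return sum(x * x for x in range(100000))

    pr = cProfile.Profile()
    pr.runcall(process)                 # run the work under the profiler
    st = pstats.Stats(pr)
    st.stream = None                    # dropped so the Stats object stays picklable
    st.strip_dirs()                     # this Stats object is what profile() returns
    st.sort_stats("time", "cumulative").print_stats(5)   # roughly what show()/dump() do later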
1 change: 1 addition & 0 deletions python/run-tests
@@ -57,6 +57,7 @@ function run_core_tests() {
PYSPARK_DOC_TEST=1 run_test "pyspark/broadcast.py"
PYSPARK_DOC_TEST=1 run_test "pyspark/accumulators.py"
run_test "pyspark/serializers.py"
PYSPARK_DOC_TEST=1 run_test "pyspark/profiler.py"
run_test "pyspark/shuffle.py"
run_test "pyspark/tests.py"
}
