Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/develop' into feature/variadic-i…
Browse files Browse the repository at this point in the history
…ntegrate-1d
  • Loading branch information
bbbales2 committed Mar 30, 2021
2 parents 77164cc + 40803f8 commit a2e8e57
Show file tree
Hide file tree
Showing 6 changed files with 532 additions and 0 deletions.
2 changes: 2 additions & 0 deletions Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,8 @@ pipeline {
sh "python ./test/code_generator_test.py"
sh "python ./test/signature_parser_test.py"
sh "python ./test/statement_types_test.py"
sh "python ./test/varmat_compatibility_summary_test.py"
sh "python ./test/varmat_compatibility_test.py"
withEnv(['PATH+TBB=./lib/tbb']) {
sh "python ./test/expressions/test_expression_testing_framework.py"
}
Expand Down
1 change: 1 addition & 0 deletions test/sig_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,7 @@ def reference_vector_argument(arg):
overload_scalar = {
"Prim": "double",
"Rev": "stan::math::var",
"RevVarmat": "stan::math::var",
"Fwd": "stan::math::fvar<double>",
"Mix": "stan::math::fvar<stan::math::var>",
}
267 changes: 267 additions & 0 deletions test/varmat_compatibility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,267 @@
#!/usr/bin/python

import itertools
import json
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
import os
import Queue
import subprocess
import re
import sys
import tempfile
import threading

from sig_utils import make, handle_function_list, get_signatures
from signature_parser import SignatureParser
from code_generator import CodeGenerator

HERE = os.path.dirname(os.path.realpath(__file__))
TEST_FOLDER = os.path.abspath(os.path.join(HERE, "..", "test"))
sys.path.append(TEST_FOLDER)
WORKING_FOLDER = "test/varmat-compatibility"

TEST_TEMPLATE = """
static void {test_name}() {{
{code}
}}
"""

def run_command(command):
"""
Runs given command and waits until it finishes executing.
:param command: command to execute
"""
proc = subprocess.Popen(command, stdout = subprocess.PIPE, stderr = subprocess.PIPE)
stdout, stderr = proc.communicate()

if proc.poll() == 0:
return (True, stdout, stderr)
else:
return (False, stdout, stderr)

def build_signature(prefix, cpp_code, debug):
"""
Try to build the given cpp code
Return true if the code was successfully built
:param prefix: Prefix to give file names so easier to debug
:param cpp_code: Code to build
:param debug: If true, don't delete temporary files
"""
f = tempfile.NamedTemporaryFile("w", dir = WORKING_FOLDER, prefix = prefix + "_", suffix = "_test.cpp", delete = False)
f.write("#include <test/expressions/expression_test_helpers.hpp>\n\n")
f.write(cpp_code)
f.close()

cpp_path = os.path.join(WORKING_FOLDER, os.path.basename(f.name))

object_path = cpp_path.replace(".cpp", ".o")
dependency_path = cpp_path.replace(".cpp", ".d")
stdout_path = cpp_path.replace(".cpp", ".stdout")
stderr_path = cpp_path.replace(".cpp", ".stderr")

successful, stdout, stderr = run_command([make, object_path])

if successful or not debug:
try:
os.remove(cpp_path)
except OSError:
pass

try:
os.remove(dependency_path)
except OSError:
pass

try:
os.remove(object_path)
except OSError:
pass
else:
if debug:
with open(stdout_path, "w") as stdout_f:
stdout_f.write(stdout.decode("utf-8"))

with open(stderr_path, "w") as stderr_f:
stderr_f.write(stderr.decode("utf-8"))

return successful

def main(functions_or_sigs, results_file, cores, debug):
"""
Attempt to build all the signatures in functions_or_sigs, or all the signatures
associated with all the functions in functions_or_sigs, or if functions_or_sigs
is empty every signature the stanc3 compiler exposes.
Results are written to a results json file. Individual signatures are classified
as either compatible, incompatible, or irrelevant.
Compatible signatures can be compiled with varmat types in every argument that
could possibly be a varmat (the matrix-like ones).
Incompatible signatures cannot all be built, and for irrelevant signatures it does
not make sense to try to build them (there are no matrix arguments, or the function
does not support reverse mode autodiff, etc).
Compilation is done in parallel using the number of specified cores.
:param functions_or_sigs: List of function names and/or signatures to benchmark
:param results_file: File to use as a results cache
:param cores: Number of cores to use for compiling
:param debug: If true, don't delete temporary files
"""
all_signatures = get_signatures()
functions, signatures = handle_function_list(functions_or_sigs)

requested_functions = set(functions)

compatible_signatures = set()
incompatible_signatures = set()
irrelevant_signatures = set()

# Read the arguments and figure out the exact list of signatures to test
signatures_to_check = set()
for signature in all_signatures:
sp = SignatureParser(signature)

if len(requested_functions) > 0 and sp.function_name not in requested_functions:
continue

signatures_to_check.add(signature)

work_queue = Queue.Queue()

# For each signature, generate cpp code to test
for signature in signatures_to_check:
sp = SignatureParser(signature)

if sp.is_high_order():
work_queue.put((n, signature, None))
continue

cpp_code = ""
any_overload_uses_varmat = False

for m, overloads in enumerate(itertools.product(("Prim", "Rev", "RevVarmat"), repeat = sp.number_arguments())):
cg = CodeGenerator()

arg_list_base = cg.build_arguments(sp, overloads, size = 1)

arg_list = []
for overload, arg in zip(overloads, arg_list_base):
if arg.is_reverse_mode() and arg.is_varmat_compatible() and overload.endswith("Varmat"):
any_overload_uses_varmat = True
arg = cg.to_var_value(arg)

arg_list.append(arg)

cg.function_call_assign("stan::math::" + sp.function_name, *arg_list)

cpp_code += TEST_TEMPLATE.format(
test_name = sp.function_name + repr(m),
code=cg.cpp(),
)

if any_overload_uses_varmat:
work_queue.put((work_queue.qsize(), signature, cpp_code))
else:
print("{0} ... Irrelevant".format(signature.strip()))
irrelevant_signatures.add(signature)

output_lock = threading.Lock()

if not os.path.exists(WORKING_FOLDER):
os.mkdir(WORKING_FOLDER)

work_queue_original_length = work_queue.qsize()

# Test if each cpp file builds and update the output file
# This part is done in parallel
def worker():
while True:
try:
n, signature, cpp_code = work_queue.get(False)
except Queue.Empty:
return # If queue is empty, worker quits

# Use signature as filename prefix to make it easier to find
prefix = re.sub('[^0-9a-zA-Z]+', '_', signature.strip())

# Test the signature
successful = build_signature(prefix, cpp_code, debug)

# Acquire a lock to do I/O
with output_lock:
if successful:
result_string = "Success"
compatible_signatures.add(signature)
else:
result_string = "Fail"
incompatible_signatures.add(signature)

print("Results of test {0} / {1}, {2} ... ".format(n, work_queue_original_length, signature.strip()) + result_string)

work_queue.task_done()

for i in range(cores):
threading.Thread(target = worker).start()

work_queue.join()

with open(results_file, "w") as f:
json.dump({ "compatible_signatures" : list(compatible_signatures),
"incompatible_signatures" : list(incompatible_signatures),
"irrelevant_signatures" : list(irrelevant_signatures)
}, f, indent = 4, sort_keys = True)


class FullErrorMsgParser(ArgumentParser):
"""
Modified ArgumentParser that prints full error message on any error.
"""

def error(self, message):
sys.stderr.write("error: %s\n" % message)
self.print_help()
sys.exit(2)


def processCLIArgs():
"""
Define and process the command line interface to the benchmark.py script.
"""
parser = FullErrorMsgParser(
description="Generate and run_command benchmarks.",
formatter_class=ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--functions",
nargs="+",
type=str,
default=[],
help="Signatures and/or function names to benchmark.",
)
parser.add_argument(
"-j",
type=int,
default=1,
help="Number of parallel cores to use.",
)
parser.add_argument(
"--debug",
action="store_true",
help="Keep cpp, stdout, and stderr for incompatible functions.",
)
parser.add_argument(
"results_file",
type=str,
default=None,
help="File to save results in.",
)
args = parser.parse_args()

main(functions_or_sigs=args.functions, results_file = args.results_file, cores = args.j, debug = args.debug)

if __name__ == "__main__":
processCLIArgs()
Loading

0 comments on commit a2e8e57

Please sign in to comment.