Skip to content

Commit

Permalink
Merge pull request #2520 from bstatcomp/bugfix_bench_varmat_opencl
Browse files Browse the repository at this point in the history
Fix benchmarking script
  • Loading branch information
t4c1 committed Jun 29, 2021
2 parents 00405fb + 6d39b25 commit c2edd60
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions benchmarks/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,13 +424,16 @@ def benchmark(
special_arg_values[function_name][n], numbers.Number
):
value = special_arg_values[function_name][n]
if scalar == "double":
if not is_argument_autodiff or (
not is_argument_scalar and (opencl == "base" or varmat == "base")
):
arg_type_prim = cpp_arg_template.replace("SCALAR", "double");
setup += (
" {} {} = stan::test::{}<{}>({}, state.range(0));\n".format(
arg_type,
arg_type_prim,
var_name,
make_arg_function,
arg_type,
arg_type_prim,
value,
)
)
Expand All @@ -440,9 +443,15 @@ def benchmark(
var_name + "_cl", var_name
)
var_name += "_cl"
if is_argument_autodiff:
var_conversions += (
" stan::math::var_value<stan::math::matrix_cl<double>> {}({});\n".format(
var_name + "_var", var_name)
)
var_name += "_var"
elif varmat == "base" and arg_overload == "Rev":
setup += " auto {} = stan::math::to_var_value({});\n".format(
var_name + "_varmat", var_name
var_conversions += " stan::math::var_value<{}> {}({});\n".format(
arg_type_prim, var_name + "_varmat", var_name
)
var_name += "_varmat"
else:
Expand Down

0 comments on commit c2edd60

Please sign in to comment.