Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fix] benchmark shows 0.00MB consumed #1018

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 30 additions & 21 deletions xformers/benchmarks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,31 +530,37 @@ def benchmark_run_and_compare(

name = None
try:
for benchmark_object in benchmarks_generator:
is_optimized = (
benchmark_object._task_spec.description not in BASELINE_DESCRIPTIONS
)
metadata = {}
if is_optimized:
metadata[META_ALGORITHM] = benchmark_object._task_spec.description
benchmark_object._task_spec = replace(
benchmark_object._task_spec, description=optimized_label
)
elif (
omit_baselines
or (
benchmark_object._task_spec.sub_label,
benchmark_object._task_spec.num_threads,
)
in skip_vanilla_tasks
):
continue
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
mem_begin = torch.cuda.max_memory_allocated() / 2**20

for benchmark_object in benchmarks_generator:
memory = math.inf
try:

is_optimized = (
benchmark_object._task_spec.description
not in BASELINE_DESCRIPTIONS
)
metadata = {}
if is_optimized:
metadata[
META_ALGORITHM
] = benchmark_object._task_spec.description
benchmark_object._task_spec = replace(
benchmark_object._task_spec, description=optimized_label
)
elif (
omit_baselines
or (
benchmark_object._task_spec.sub_label,
benchmark_object._task_spec.num_threads,
)
in skip_vanilla_tasks
):
continue

torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
mem_begin = torch.cuda.max_memory_allocated() / 2**20
benchmark_object._task_spec = replace(
benchmark_object._task_spec, env=env
)
Expand All @@ -566,6 +572,9 @@ def benchmark_run_and_compare(
name = measurement.task_spec.description
memory = torch.cuda.max_memory_allocated() / 2**20 - mem_begin
measurement.mem_use = memory

torch.cuda.reset_peak_memory_stats()
mem_begin = torch.cuda.max_memory_allocated() / 2**20
except RuntimeError as e:
if not _is_oom_error(e):
raise
Expand Down