
Trimmed Mean Added to make throughput numbers more stable
DEKHTIARJonathan committed May 10, 2022
1 parent fa1e35a commit 73b6db5
Showing 2 changed files with 21 additions and 5 deletions.
13 changes: 12 additions & 1 deletion tftrt/examples/benchmark_args.py
@@ -120,10 +120,21 @@ def __init__(self):
         self._parser.add_argument(
             "--num_warmup_iterations",
             type=int,
-            default=100,
+            default=200,
             help="Number of initial iterations skipped from timing."
         )

+        self._parser.add_argument(
+            "--trim_mean_percentage",
+            type=float,
+            default=0.1,
+            required=False,
+            help="Percentage used to trim step timing distribution from both "
+            "tails (fastest and slowest steps). 0.1 (default value) means that "
+            "10% of the fastest and slowest iteration will be removed for "
+            "model throughput computation."
+        )
+
         self._parser.add_argument(
             "--total_max_samples",
             type=int,
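For readers unfamiliar with the flag's semantics, here is a minimal, illustrative sketch (the array and variable names are made up, not taken from the benchmark): scipy.stats.trim_mean(x, p) sorts the samples and discards the fraction p from each tail before averaging, so the default --trim_mean_percentage=0.1 removes the 10% fastest and 10% slowest steps.

import numpy as np
import scipy.stats

# Ten synthetic step times (seconds): one slow warm-up-like outlier, one unusually fast step.
iter_times = np.array(
    [0.950, 0.098, 0.101, 0.099, 0.102, 0.100, 0.097, 0.103, 0.100, 0.050]
)

plain_mean = np.mean(iter_times)                  # dragged up by the 0.950 s outlier
trimmed = scipy.stats.trim_mean(iter_times, 0.1)  # drops one sample from each tail here

print(f"mean      = {plain_mean * 1000:.1f} ms")  # 180.0 ms
print(f"trim_mean = {trimmed * 1000:.1f} ms")     # 100.0 ms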
13 changes: 9 additions & 4 deletions tftrt/examples/benchmark_runner.py
@@ -25,6 +25,8 @@
 from dataloading_utils import get_force_data_on_gpu_fn

 import numpy as np
+import scipy as sp
+import scipy.stats
 import tensorflow as tf

 from tensorflow.python.compiler.tensorrt import trt_convert as trt

@@ -500,11 +502,14 @@ def log_step(step_idx, display_every, iter_time, memcpyHtoD_time, dequeue_time):

         metrics['Total GPU Time (s)'] = int(np.ceil(np.sum(iter_times)))
         metrics['Throughput (samples/sec)'] = (
-            self._args.batch_size / np.mean(iter_times)
-        )
+            self._args.batch_size / sp.stats.trim_mean(
+                iter_times, self._args.trim_mean_percentage))

         def timing_metrics(time_arr, log_prefix):
             data = dict()
+            data[f"{log_prefix} Trim Mean [{self._args.trim_mean_percentage * 100}%] (ms)"] = (
+                sp.stats.trim_mean(time_arr, self._args.trim_mean_percentage) * 1000
+            )
             data[f"{log_prefix} 99th_percentile (ms)"] = np.percentile(
                 time_arr, q=99, interpolation='lower'
             ) * 1000

@@ -522,9 +527,9 @@ def timing_metrics(time_arr, log_prefix):

         def log_value(key, val):
             if isinstance(val, int):
-                print(f"- {key:45s}: {val}")
+                print(f"- {key:50s}: {val}")
             else:
-                print(f"- {key:45s}: {val:.2f}")
+                print(f"- {key:50s}: {val:.2f}")

         for key, val in sorted(metrics.items()):
             if isinstance(val, dict):
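Putting the two runner changes together, the metric computation now looks roughly like the sketch below. This is a standalone approximation: the helper name, the default log_prefix value, and the metric labels are illustrative, not the file's exact code.

import numpy as np
import scipy as sp
import scipy.stats

def compute_metrics(iter_times, batch_size, trim_mean_percentage=0.1, log_prefix="GPU Latency"):
    """Hypothetical standalone version of the metric block changed above."""
    metrics = {}
    metrics['Total GPU Time (s)'] = int(np.ceil(np.sum(iter_times)))
    # Throughput now divides by the trimmed mean of step times rather than the plain mean.
    metrics['Throughput (samples/sec)'] = (
        batch_size / sp.stats.trim_mean(iter_times, trim_mean_percentage)
    )
    # Per-step latency reported as a trimmed mean and a 99th percentile, in milliseconds.
    metrics[f"{log_prefix} Trim Mean [{trim_mean_percentage * 100}%] (ms)"] = (
        sp.stats.trim_mean(iter_times, trim_mean_percentage) * 1000
    )
    metrics[f"{log_prefix} 99th_percentile (ms)"] = (
        np.percentile(iter_times, q=99, interpolation='lower') * 1000
    )
    return metrics

Called on step times collected after the warm-up iterations, this roughly reproduces the shape of the numbers the runner logs; outlier steps no longer move the headline throughput figure.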
