Commit ab16003

add early stopping of benchmarks
mrwyattii committed May 17, 2024
1 parent 5fa1893 commit ab16003
Showing 4 changed files with 56 additions and 9 deletions.
benchmarks/inference/llm-bench/run_example.sh (2 changes: 1 addition & 1 deletion)

@@ -1 +1 @@
-python -m src.benchmark_runner --model "facebook/opt-125m" --api dummy --config_files ./configs/*
+python -m src.llm_bench.benchmark_runner --model "facebook/opt-125m" --api dummy --config_files ./configs/*
benchmarks/inference/llm-bench/src/llm_bench/benchmark_runner.py (36 changes: 29 additions & 7 deletions)

@@ -39,6 +39,7 @@ class BenchmarkConfig(BaseConfigModel):
    max_new_tokens: int = 60
    max_new_tokens_var: float = 0.3
    streaming: bool = False
+    early_stop_latency: float = 10.0


class ClientLauncher:
@@ -189,6 +190,7 @@ def __init__(
            use_threading=self.config.use_threading,
            prompt_generator=self.prompt_generator,
        )
+        self.all_responses = []

    # TODO: fix type hint
    def _benchmark_settings(self) -> Iterable[Tuple[int, PromptConfig]]:
@@ -260,7 +262,9 @@ def _get_output_path(self, prompt_config: PromptConfig, num_clients: int) -> Path:
        output_file = f"prompt{prompt_config.prompt_length}_gen{prompt_config.max_new_tokens}_clients{num_clients}.json"
        return output_dir / output_file

-    def _save_results(self, prompt_config: PromptConfig, num_clients: int) -> None:
+    def _process_responses(
+        self, prompt_config: PromptConfig, num_clients: int
+    ) -> List[Response]:
        output_path = self._get_output_path(
            prompt_config=prompt_config, num_clients=num_clients
        )
@@ -270,25 +274,38 @@ def _save_results(self, prompt_config: PromptConfig, num_clients: int) -> None:
        all_responses = []
        while True:
            try:
-                response = self.client_launcher.get_response()
-                all_responses.append(response.to_dict())
+                all_responses.append(self.client_launcher.get_response())
            except queue.Empty:
                break

        os.makedirs(output_path.parent, exist_ok=True)
        with open(output_path, "w") as fh:
-            json.dump(all_responses, fh, indent=2)
+            json.dump([r.to_dict() for r in all_responses], fh, indent=2)

        logger.info(f"Saved {len(all_responses)} responses to {output_path}")

+        return all_responses
+
+    def _check_early_stop(self, all_responses: List[Response]) -> bool:
+        mean_latency = sum([r.request_time for r in all_responses]) / len(all_responses)
+        if mean_latency >= self.config.early_stop_latency:
+            logger.info(
+                f"Mean latency of {mean_latency:.2f} exceeds early stopping threshold of {self.config.early_stop_latency}. Stopping early."
+            )
+            return True
+        return False
+
    def run(self) -> None:
        # Start the client service
        self.client_launcher.start_service()

        # Generate all benchmark settings from user config(s)
        for num_clients_list, prompt_config in self._benchmark_settings():
-            for num_clients in num_clients_list:
-                # TODO: implement early stopping based on response latency
+            early_stop = False
+            for num_clients in sorted(num_clients_list):
+                if early_stop:
+                    break
+
                logger.info(
                    f"Running benchmark with {num_clients} client(s) and prompt config: {prompt_config}"
                )
@@ -302,7 +319,12 @@ def run(self) -> None:
                self.client_launcher.run_parallel_clients(num_clients=num_clients)

                # Process raw responses and save results to file
-                self._save_results(prompt_config=prompt_config, num_clients=num_clients)
+                all_responses = self._process_responses(
+                    prompt_config=prompt_config, num_clients=num_clients
+                )
+
+                # Check early stopping condition
+                early_stop = self._check_early_stop(all_responses)

        # Stop the client service
        self.client_launcher.stop_service()
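Taken together, the benchmark_runner.py changes stop iterating over client counts for a given prompt config once the mean request latency of the previous run reaches early_stop_latency; results for the run that crossed the threshold are still saved, only the remaining larger client counts are skipped. A minimal standalone sketch of that control flow, using a stand-in Response dataclass and a hypothetical run_one() step rather than the repository's ClientLauncher:

from dataclasses import dataclass
from typing import Iterable, List


@dataclass
class Response:
    # Stand-in for the repo's Response type; only request_time matters here.
    request_time: float


def check_early_stop(responses: List[Response], early_stop_latency: float) -> bool:
    # Same idea as BenchmarkRunner._check_early_stop: compare mean latency to the threshold.
    mean_latency = sum(r.request_time for r in responses) / len(responses)
    return mean_latency >= early_stop_latency


def run_one(num_clients: int) -> List[Response]:
    # Hypothetical benchmark step: pretend latency grows with concurrency.
    return [Response(request_time=2.0 * num_clients) for _ in range(4)]


def run(client_counts: Iterable[int], early_stop_latency: float = 10.0) -> None:
    early_stop = False
    for num_clients in sorted(client_counts):
        if early_stop:
            break
        responses = run_one(num_clients)
        mean = sum(r.request_time for r in responses) / len(responses)
        print(f"{num_clients} clients: mean latency {mean:.1f}s")
        early_stop = check_early_stop(responses, early_stop_latency)


if __name__ == "__main__":
    # With 2 s per client, the mean reaches the 10 s threshold at 8 clients,
    # so the 16- and 32-client runs are skipped.
    run(client_counts=(1, 4, 8, 16, 32))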
4 changes: 3 additions & 1 deletion (dummy client module)

@@ -14,12 +14,14 @@

class DummyClientConfig(BaseConfigModel):
    model: str
+    dummy_client_latency_time: float = 1.0


class DummyClient(BaseClient):
    def __init__(self, config: DummyClientConfig) -> None:
        super().__init__(config)
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.model)
+        self.latency_time = config.dummy_client_latency_time

    def start_service(self) -> Status:
        return Status("OK")
@@ -31,7 +33,7 @@ def prepare_request(self, prompt: Prompt) -> Dict[str, Any]:
        return {"input_text": prompt.text, "max_new_tokens": prompt.max_new_tokens}

    def send_request(self, request_kwargs: Dict[str, Any]) -> Any:
-        time.sleep(random.uniform(0.1, 0.2))
+        time.sleep(random.uniform(self.latency_time - 0.1, self.latency_time + 0.2))
        output_text = self.tokenizer.decode(
            random.choices(
                self.tokenizer.encode(request_kwargs["input_text"]),
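With the fixed 0.1-0.2 s sleep replaced by a window around dummy_client_latency_time, a test can dial the simulated latency high enough to trip the new threshold. A rough sketch of the sampling behaviour using only the standard library (not the repository's client classes):

import random
import statistics


def simulated_latency(latency_time: float) -> float:
    # Mirrors the updated sleep in DummyClient.send_request:
    # uniform noise around the configured dummy_client_latency_time.
    return random.uniform(latency_time - 0.1, latency_time + 0.2)


# With dummy_client_latency_time=2.0 (the value used by the new test), the mean
# simulated latency is roughly 2.05 s, well above an early_stop_latency of 1.0.
samples = [simulated_latency(2.0) for _ in range(1000)]
print(f"mean ~ {statistics.mean(samples):.2f}s, range [{min(samples):.2f}, {max(samples):.2f}]")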
benchmarks/inference/llm-bench/tests/test_early_stop.py (23 changes: 23 additions & 0 deletions)

@@ -0,0 +1,23 @@
+import pytest
+
+from llm_bench import parse_args_to_configs, BenchmarkRunner
+
+
+@pytest.mark.parametrize("num_clients", [(1, 2, 4)], indirect=True)
+def test_early_stop(benchmark_args):
+    benchmark_args += [
+        "--early_stop_latency",
+        "1",
+        "--dummy_client_latency_time",
+        "2.0",
+    ]
+    print(benchmark_args)
+    benchmark_config, client_config = parse_args_to_configs(benchmark_args)
+    benchmark_runner = BenchmarkRunner(benchmark_config, client_config)
+    benchmark_runner.run()
+
+    expected_results = 1
+    actual_results = len(list(benchmark_runner._get_output_dir().glob("*.json")))
+    assert (
+        expected_results == actual_results
+    ), f"Number of result files ({actual_results}) does not match expected number ({expected_results})."
