Refactor BatcherProcess and GenerateService between shortfin apps (#1009)
KyleHerndon authored Mar 3, 2025
Commit 894f4b9, parent a692e4b
Showing 11 changed files with 257 additions and 318 deletions.
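
Several hunks below import ServiceBase, BatcherProcessBase, InferenceExecRequest, and prog_isolations from a shared shortfin_apps/utils.py module that is among the 11 changed files but is not rendered on this page. As orientation, here is a rough sketch of the surface those imports imply, reassembled from code deleted from the flux app further down; only the names are confirmed by this page, every body is an assumption rather than the actual contents of utils.py. A fuller sketch of BatcherProcessBase follows the flux service.py diff below.

    # Hypothetical sketch of shortfin_apps/utils.py, inferred from the imports in the
    # diffs below; the real module is changed in this commit but not shown on this page.
    import shortfin as sf

    prog_isolations = {
        # Deleted verbatim from flux/components/service.py below and re-imported from here.
        "none": sf.ProgramIsolation.NONE,
        "per_fiber": sf.ProgramIsolation.PER_FIBER,
        "per_call": sf.ProgramIsolation.PER_CALL,
    }


    class InferenceExecRequest(sf.Message):
        """Shared base for FluxInferenceExecRequest and LlmInferenceExecRequest."""


    class StrobeMessage(sf.Message):
        """Sent to strobe a queue with fake activity (generate a wakeup)."""
        # Deleted from both apps' messages.py below; assumed to now live here.
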
shortfin/python/shortfin_apps/flux/components/generate.py: 4 changes (2 additions, 2 deletions)
@@ -14,7 +14,7 @@
 from shortfin.interop.fastapi import FastAPIResponder
 
 from .io_struct import GenerateReqInput
-from .messages import InferenceExecRequest
+from .messages import FluxInferenceExecRequest
 from .service import GenerateService
 from .metrics import measure
 
@@ -44,7 +44,7 @@ def __init__(
         self.result_image = None
 
     async def run(self):
-        exec = InferenceExecRequest.from_batch(self.gen_req, self.index)
+        exec = FluxInferenceExecRequest.from_batch(self.gen_req, self.index)
         self.client.batcher.submit(exec)
         await exec.done
         self.result_image = exec.result_image
shortfin/python/shortfin_apps/flux/components/messages.py: 13 changes (4 additions, 9 deletions)
@@ -12,6 +12,7 @@
 import shortfin.array as sfnp
 
 from .io_struct import GenerateReqInput
+from ...utils import InferenceExecRequest
 
 logger = logging.getLogger("shortfin-sd.messages")
 
@@ -29,7 +30,7 @@ class InferencePhase(Enum):
     POSTPROCESS = 5
 
 
-class InferenceExecRequest(sf.Message):
+class FluxInferenceExecRequest(InferenceExecRequest):
     """
     Generalized request passed for an individual phase of image generation.
@@ -112,7 +113,7 @@ def __init__(
         self.post_init()
 
     @staticmethod
-    def from_batch(gen_req: GenerateReqInput, index: int) -> "InferenceExecRequest":
+    def from_batch(gen_req: GenerateReqInput, index: int) -> "FluxInferenceExecRequest":
         gen_inputs = [
             "prompt",
             "neg_prompt",
@@ -138,7 +139,7 @@ def from_batch(gen_req: GenerateReqInput, index: int) -> "InferenceExecRequest":
             else:
                 rec_input = received
             rec_inputs[item] = rec_input
-        return InferenceExecRequest(**rec_inputs)
+        return FluxInferenceExecRequest(**rec_inputs)
 
     def post_init(self):
         """Determines necessary inference phases and tags them with static program parameters."""
@@ -184,9 +185,3 @@ def reset(self, phase: InferencePhase):
         self.phases = None
         self.done = sf.VoidFuture()
         self.return_host_array = True
-
-
-class StrobeMessage(sf.Message):
-    """Sent to strobe a queue with fake activity (generate a wakeup)."""
-
-    ...
shortfin/python/shortfin_apps/flux/components/service.py: 120 changes (13 additions, 107 deletions)
@@ -18,20 +18,16 @@
 import shortfin as sf
 import shortfin.array as sfnp
 
+from ...utils import ServiceBase, BatcherProcessBase, prog_isolations
+
 from .config_struct import ModelParams
 from .manager import SystemManager
-from .messages import InferenceExecRequest, InferencePhase, StrobeMessage
+from .messages import FluxInferenceExecRequest, InferencePhase
 from .tokenizer import Tokenizer
 from .metrics import measure
 
 logger = logging.getLogger("shortfin-flux.service")
 
-prog_isolations = {
-    "none": sf.ProgramIsolation.NONE,
-    "per_fiber": sf.ProgramIsolation.PER_FIBER,
-    "per_call": sf.ProgramIsolation.PER_CALL,
-}
-
 
 def time_shift(mu: float, sigma: float, t: torch.Tensor):
     return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
@@ -64,7 +60,7 @@ def get_schedule(
     return timesteps.tolist()
 
 
-class GenerateService:
+class GenerateService(ServiceBase):
     """Top level service interface for image generation."""
 
     inference_programs: dict[str, sf.Program]
@@ -85,15 +81,13 @@ def __init__(
         show_progress: bool = False,
         trace_execution: bool = False,
     ):
+        super().__init__(sysman)
         self.name = name
 
         # Application objects.
-        self.sysman = sysman
         self.clip_tokenizers = clip_tokenizers
         self.t5xxl_tokenizers = t5xxl_tokenizers
         self.model_params = model_params
-        self.inference_parameters: dict[str, list[sf.BaseProgramParameters]] = {}
-        self.inference_modules: dict[str, sf.ProgramModule] = {}
         self.inference_functions: dict[str, dict[str, sf.ProgramFunction]] = {}
         self.inference_programs: dict[int, dict[str, sf.Program]] = {}
         self.trace_execution = trace_execution
@@ -133,7 +127,7 @@ def __init__(
             "decode": {},
         }
         # Scope dependent objects.
-        self.batcher = BatcherProcess(self)
+        self.batcher = FluxBatcherProcess(self)
 
     def get_worker_index(self, fiber):
         if fiber not in self.fibers:
@@ -144,28 +138,6 @@ def get_worker_index(self, fiber):
         )
         return worker_idx
 
-    def load_inference_module(self, vmfb_path: Path, component: str = None):
-        if not self.inference_modules.get(component):
-            self.inference_modules[component] = []
-        self.inference_modules[component].append(
-            sf.ProgramModule.load(self.sysman.ls, vmfb_path)
-        )
-
-    def load_inference_parameters(
-        self,
-        *paths: Path,
-        parameter_scope: str,
-        format: str = "",
-        component: str = None,
-    ):
-        p = sf.StaticProgramParameters(self.sysman.ls, parameter_scope=parameter_scope)
-        for path in paths:
-            logger.info("Loading parameter fiber '%s' from: %s", parameter_scope, path)
-            p.load(path, format=format)
-        if not self.inference_parameters.get(component):
-            self.inference_parameters[component] = []
-        self.inference_parameters[component].append(p)
-
     def start(self):
         # Initialize programs.
         for component in self.inference_modules:
@@ -254,58 +226,23 @@ def __repr__(self):
 ########################################################################################
 
 
-class BatcherProcess(sf.Process):
-    """The batcher is a persistent process responsible for flighting incoming work
-    into batches.
-    """
-
+class FluxBatcherProcess(BatcherProcessBase):
     STROBE_SHORT_DELAY = 0.5
     STROBE_LONG_DELAY = 1
 
     def __init__(self, service: GenerateService):
         super().__init__(fiber=service.fibers[0])
         self.service = service
-        self.batcher_infeed = self.system.create_queue()
-        self.pending_requests: set[InferenceExecRequest] = set()
-        self.strobe_enabled = True
-        self.strobes: int = 0
         self.ideal_batch_size: int = max(service.model_params.max_batch_size)
         self.num_fibers = len(service.fibers)
 
-    def shutdown(self):
-        self.batcher_infeed.close()
+    def handle_inference_request(self, request):
+        self.pending_requests.add(request)
 
-    def submit(self, request: StrobeMessage | InferenceExecRequest):
-        self.batcher_infeed.write_nodelay(request)
+    async def process_batches(self):
+        await self.board_flights()
 
-    async def _background_strober(self):
-        while not self.batcher_infeed.closed:
-            await asyncio.sleep(
-                BatcherProcess.STROBE_SHORT_DELAY
-                if len(self.pending_requests) > 0
-                else BatcherProcess.STROBE_LONG_DELAY
-            )
-            if self.strobe_enabled:
-                self.submit(StrobeMessage())
-
-    async def run(self):
-        strober_task = asyncio.create_task(self._background_strober())
-        reader = self.batcher_infeed.reader()
-        while item := await reader():
-            self.strobe_enabled = False
-            if isinstance(item, InferenceExecRequest):
-                self.pending_requests.add(item)
-            elif isinstance(item, StrobeMessage):
-                self.strobes += 1
-            else:
-                logger.error("Illegal message received by batcher: %r", item)
-
-            self.board_flights()
-
-        self.strobe_enabled = True
-        await strober_task
-
-    def board_flights(self):
+    async def board_flights(self):
         waiting_count = len(self.pending_requests)
         if waiting_count == 0:
             return
@@ -326,37 +263,6 @@ def board_flights(self):
         if self.service.prog_isolation != sf.ProgramIsolation.PER_FIBER:
             self.service.idle_fibers.add(fiber)
 
-    def sort_batches(self):
-        """Files pending requests into sorted batches suitable for program invocations."""
-        reqs = self.pending_requests
-        next_key = 0
-        batches = {}
-        for req in reqs:
-            is_sorted = False
-            req_metas = [req.phases[phase]["metadata"] for phase in req.phases.keys()]
-
-            for idx_key, data in batches.items():
-                if not isinstance(data, dict):
-                    logger.error(
-                        "Expected to find a dictionary containing a list of requests and their shared metadatas."
-                    )
-                if len(batches[idx_key]["reqs"]) >= self.ideal_batch_size:
-                    # Batch is full
-                    next_key = idx_key + 1
-                    continue
-                elif data["meta"] == req_metas:
-                    batches[idx_key]["reqs"].extend([req])
-                    is_sorted = True
-                    break
-                else:
-                    next_key = idx_key + 1
-            if not is_sorted:
-                batches[next_key] = {
-                    "reqs": [req],
-                    "meta": req_metas,
-                }
-        return batches
-
     def board(self, request_bundle, fiber):
         pending = request_bundle
         if len(pending) == 0:
@@ -388,7 +294,7 @@ def __init__(
         super().__init__(fiber=fiber)
         self.service = service
         self.worker_index = self.service.get_worker_index(fiber)
-        self.exec_requests: list[InferenceExecRequest] = []
+        self.exec_requests: list[FluxInferenceExecRequest] = []
 
     @measure(type="exec", task="inference process")
     async def run(self):
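
The methods removed from the flux batcher above (the infeed queue, submit, shutdown, the background strober, the run loop, and sort_batches) do not reappear in any hunk rendered on this page, so they presumably moved into BatcherProcessBase in the shared utils module, leaving FluxBatcherProcess with only the handle_inference_request/process_batches hooks and its boarding logic. A hedged sketch of what that base class could look like, reassembled from the deleted lines: the delays, queue plumbing, and loop body come from the diff itself, while the dispatch into the two hook methods and the attribute placement are assumptions.

    # Sketch only: reassembled from the BatcherProcess code deleted above, not from the
    # real shortfin_apps/utils.py (which this page does not render). Assumes the
    # InferenceExecRequest and StrobeMessage classes sketched near the top of this page
    # are defined in this same module.
    import asyncio
    import logging

    import shortfin as sf

    logger = logging.getLogger(__name__)


    class BatcherProcessBase(sf.Process):
        # FluxBatcherProcess still declares the same two delays; defaults here are a guess.
        STROBE_SHORT_DELAY = 0.5
        STROBE_LONG_DELAY = 1

        def __init__(self, fiber):
            super().__init__(fiber=fiber)
            self.batcher_infeed = self.system.create_queue()
            self.pending_requests: set[InferenceExecRequest] = set()
            self.strobe_enabled = True
            self.strobes: int = 0

        def handle_inference_request(self, request):
            """Hook: subclasses file an incoming request (e.g. add it to pending_requests)."""

        async def process_batches(self):
            """Hook: subclasses board pending requests onto fibers."""

        def shutdown(self):
            self.batcher_infeed.close()

        def submit(self, request: StrobeMessage | InferenceExecRequest):
            self.batcher_infeed.write_nodelay(request)

        async def _background_strober(self):
            # Wake the run loop periodically so partially filled batches still launch.
            # The deleted code referenced the class attribute directly; self.* is used
            # here so a subclass's own delay values apply.
            while not self.batcher_infeed.closed:
                await asyncio.sleep(
                    self.STROBE_SHORT_DELAY
                    if len(self.pending_requests) > 0
                    else self.STROBE_LONG_DELAY
                )
                if self.strobe_enabled:
                    self.submit(StrobeMessage())

        async def run(self):
            strober_task = asyncio.create_task(self._background_strober())
            reader = self.batcher_infeed.reader()
            while item := await reader():
                self.strobe_enabled = False
                if isinstance(item, InferenceExecRequest):
                    # The deleted loop added to pending_requests inline; delegating to the
                    # hook lets each app decide how to file the request.
                    self.handle_inference_request(item)
                elif isinstance(item, StrobeMessage):
                    self.strobes += 1
                else:
                    logger.error("Illegal message received by batcher: %r", item)
                await self.process_batches()
            self.strobe_enabled = True
            await strober_task

The deleted load_inference_module and load_inference_parameters methods, together with the removed self.sysman, self.inference_modules, and self.inference_parameters attributes, have likely moved into ServiceBase in the same way, which is what the new super().__init__(sysman) call in GenerateService suggests.
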
shortfin/python/shortfin_apps/llm/components/generate.py: 4 changes (2 additions, 2 deletions)
@@ -16,7 +16,7 @@
 from shortfin.interop.fastapi import FastAPIResponder
 
 from .io_struct import GenerateReqInput
-from .messages import InferenceExecRequest, InferencePhase
+from .messages import LlmInferenceExecRequest, InferencePhase
 from .service import GenerateService
 from .tokenizer import Encoding
 
@@ -52,7 +52,7 @@ def __init__(
         self.streamed_tokens_index = 0
 
     async def run(self):
-        exec = InferenceExecRequest(
+        exec = LlmInferenceExecRequest(
             phase=InferencePhase.PREFILL,
             input_token_ids=self.input_token_ids,
             rid=self.gen_req.rid,
shortfin/python/shortfin_apps/llm/components/messages.py: 13 changes (4 additions, 9 deletions)
@@ -11,14 +11,15 @@
 
 from .kvcache.base_attention_cache import BasePagedAttentionCache, PageAllocation
 from .kvcache.page_pool import PageInfo
+from ...utils import InferenceExecRequest
 
 
 class InferencePhase(Enum):
     PREFILL = 1
     DECODE = 2
 
 
-class InferenceExecRequest(sf.Message):
+class LlmInferenceExecRequest(InferenceExecRequest):
     """Performs a prefill operation."""
 
     def __init__(self, phase: InferencePhase, input_token_ids: list[int], rid=None):
@@ -75,7 +76,7 @@ def __repr__(self) -> str:
         """
         String representation for logging purposes. It looks like this:
-        InferenceExecRequest[phase=P,pos=0,rid=test123,flags=host,input_token_ids=[1, 2, 3, 4]]
+        LlmInferenceExecRequest[phase=P,pos=0,rid=test123,flags=host,input_token_ids=[1, 2, 3, 4]]
         Use
         `logging.debug("Request: %r", request)`
@@ -90,10 +91,4 @@ def __repr__(self) -> str:
         if self.return_host_array:
             flags.append("host")
         flags_str = ",".join(flags)
-        return f"InferenceExecRequest[phase={phase_char},pos={self.start_position},rid={self.rid},flags={flags_str},input_token_ids={self.input_token_ids}]"
-
-
-class StrobeMessage(sf.Message):
-    """Sent to strobe a queue with fake activity (generate a wakeup)."""
-
-    ...
+        return f"LlmInferenceExecRequest[phase={phase_char},pos={self.start_position},rid={self.rid},flags={flags_str},input_token_ids={self.input_token_ids}]"