Skip to content

Commit

Permalink
Refactor the SystemManager across shortfin apps
Browse files Browse the repository at this point in the history
  • Loading branch information
KyleHerndon committed Feb 27, 2025
1 parent 205d40d commit f2feef2
Show file tree
Hide file tree
Showing 10 changed files with 105 additions and 141 deletions.
48 changes: 9 additions & 39 deletions shortfin/python/shortfin_apps/flux/components/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,45 +4,15 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import logging
import threading
from shortfin_apps.utils import SystemManager

import shortfin as sf
from shortfin.interop.support.device_setup import get_selected_devices

logger = logging.getLogger("shortfin-flux.manager")


class SystemManager:
class FluxSystemManager(SystemManager):
def __init__(self, device="local-task", device_ids=None, async_allocs=True):
if any(x in device for x in ["local-task", "cpu"]):
self.ls = sf.host.CPUSystemBuilder().create_system()
elif any(x in device for x in ["hip", "amdgpu"]):
sb = sf.SystemBuilder(
system_type="amdgpu", amdgpu_async_allocations=async_allocs
)
if device_ids:
sb.visible_devices = sb.available_devices
sb.visible_devices = get_selected_devices(sb, device_ids)
self.ls = sb.create_system()
logger.info(f"Created local system with {self.ls.device_names} devices")
# TODO: Come up with an easier bootstrap thing than manually
# running a thread.
self.t = threading.Thread(target=lambda: self.ls.run(self.run()))
self.command_queue = self.ls.create_queue("command")
self.command_writer = self.command_queue.writer()

def start(self):
logger.info("Starting system manager")
self.t.start()

def shutdown(self):
logger.info("Shutting down system manager")
self.command_queue.close()
self.ls.shutdown()

async def run(self):
reader = self.command_queue.reader()
while command := await reader():
...
logger.info("System manager command processor stopped")
super().__init__(
device=device,
device_ids=device_ids,
async_allocs=async_allocs,
logger_name="shortfin-flux.manager",
shutdown_system=True,
)
4 changes: 2 additions & 2 deletions shortfin/python/shortfin_apps/flux/components/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from ...utils import GenerateService, BatcherProcess

from .config_struct import ModelParams
from .manager import SystemManager
from .manager import FluxSystemManager
from .messages import FluxInferenceExecRequest, InferencePhase
from .tokenizer import Tokenizer
from .metrics import measure
Expand Down Expand Up @@ -70,7 +70,7 @@ def __init__(
self,
*,
name: str,
sysman: SystemManager,
sysman: FluxSystemManager,
clip_tokenizers: list[Tokenizer],
t5xxl_tokenizers: list[Tokenizer],
model_params: ModelParams,
Expand Down
10 changes: 6 additions & 4 deletions shortfin/python/shortfin_apps/flux/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from .components.generate import ClientGenerateBatchProcess
from .components.config_struct import ModelParams
from .components.io_struct import GenerateReqInput
from .components.manager import SystemManager
from .components.manager import FluxSystemManager
from .components.service import FluxGenerateService
from .components.tokenizer import Tokenizer

Expand Down Expand Up @@ -83,7 +83,7 @@ async def lifespan(app: FastAPI):
sysman.shutdown()


sysman: SystemManager
sysman: FluxSystemManager
services: dict[str, Any] = {}
app = FastAPI(lifespan=lifespan)

Expand All @@ -105,10 +105,12 @@ async def generate_request(gen_req: GenerateReqInput, request: Request):
app.put("/generate")(generate_request)


def configure_sys(args) -> SystemManager:
def configure_sys(args) -> FluxSystemManager:
# Setup system (configure devices, etc).
model_config, topology_config, flagfile, tuning_spec, args = get_configs(args)
sysman = SystemManager(args.device, args.device_ids, args.amdgpu_async_allocations)
sysman = FluxSystemManager(
args.device, args.device_ids, args.amdgpu_async_allocations
)
return sysman, model_config, flagfile, tuning_spec


Expand Down
4 changes: 2 additions & 2 deletions shortfin/python/shortfin_apps/llm/components/lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def lifecycle(app: FastApi):


from .config_struct import ModelParams, ServerParams
from .manager import SystemManager
from .manager import LlmSystemManager
from .service import LlmGenerateService
from .tokenizer import Tokenizer
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -55,7 +55,7 @@ def __init__(self, args):
server_params.update_from_args(args)

# Setup system (configure devices, etc).
sysman = SystemManager(
sysman = LlmSystemManager(
device=args.device,
device_ids=server_params.device_ids,
async_allocs=server_params.amdgpu_async_allocations,
Expand Down
56 changes: 10 additions & 46 deletions shortfin/python/shortfin_apps/llm/components/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,58 +4,22 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import logging
import threading
from shortfin_apps.utils import SystemManager

import shortfin as sf
from shortfin.interop.support.device_setup import get_selected_devices

logger = logging.getLogger(__name__)


class SystemManager:
class LlmSystemManager(SystemManager):
def __init__(
self,
device="local-task",
device_ids=None,
async_allocs=True,
amdgpu_allocators=None,
):
if any(x in device for x in ["local-task", "cpu"]):
self.ls = sf.host.CPUSystemBuilder().create_system()
elif any(x in device for x in ["hip", "amdgpu"]):
if amdgpu_allocators is None:
sb = sf.SystemBuilder(
system_type="amdgpu",
amdgpu_async_allocations=async_allocs,
)
else:
sb = sf.SystemBuilder(
system_type="amdgpu",
amdgpu_async_allocations=async_allocs,
amdgpu_allocators=amdgpu_allocators,
)
if device_ids:
sb.visible_devices = sb.available_devices
sb.visible_devices = get_selected_devices(sb, device_ids)
self.ls = sb.create_system()
logger.info(f"Created local system with {self.ls.device_names} devices")
# TODO: Come up with an easier bootstrap thing than manually
# running a thread.
self.t = threading.Thread(target=lambda: self.ls.run(self.run()))
self.command_queue = self.ls.create_queue("command")
self.command_writer = self.command_queue.writer()

def start(self):
logger.info("Starting system manager")
self.t.start()

def shutdown(self):
logger.info("Shutting down system manager")
self.command_queue.close()

async def run(self):
reader = self.command_queue.reader()
while command := await reader():
...
logging.info("System manager command processor stopped")
super().__init__(
device=device,
device_ids=device_ids,
async_allocs=async_allocs,
amdgpu_allocators=amdgpu_allocators,
logger_name=__name__,
shutdown_system=False,
)
4 changes: 2 additions & 2 deletions shortfin/python/shortfin_apps/llm/components/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from .kvcache.trie_attention_cache import TriePagedAttentionCache
from .kvcache.page_pool import PagePoolConfig, PagePool, PageInfo
from .config_struct import ModelParams, ServerParams
from .manager import SystemManager
from .manager import LlmSystemManager
from .messages import LlmInferenceExecRequest, InferencePhase
from .tokenizer import Tokenizer
from .service_debug_dumper import SERVICE_DEBUG_DUMPER
Expand All @@ -42,7 +42,7 @@ def __init__(
self,
*,
name: str,
sysman: SystemManager,
sysman: LlmSystemManager,
tokenizer: Tokenizer,
model_params: ModelParams,
server_params: "ServerParams",
Expand Down
48 changes: 9 additions & 39 deletions shortfin/python/shortfin_apps/sd/components/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,45 +4,15 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import logging
import threading
from shortfin_apps.utils import SystemManager

import shortfin as sf
from shortfin.interop.support.device_setup import get_selected_devices

logger = logging.getLogger("shortfin-sd.manager")


class SystemManager:
class SDXLSystemManager(SystemManager):
def __init__(self, device="local-task", device_ids=None, async_allocs=True):
if any(x in device for x in ["local-task", "cpu"]):
self.ls = sf.host.CPUSystemBuilder().create_system()
elif any(x in device for x in ["hip", "amdgpu"]):
sb = sf.SystemBuilder(
system_type="amdgpu", amdgpu_async_allocations=async_allocs
)
if device_ids:
sb.visible_devices = sb.available_devices
sb.visible_devices = get_selected_devices(sb, device_ids)
self.ls = sb.create_system()
logger.info(f"Created local system with {self.ls.device_names} devices")
# TODO: Come up with an easier bootstrap thing than manually
# running a thread.
self.t = threading.Thread(target=lambda: self.ls.run(self.run()))
self.command_queue = self.ls.create_queue("command")
self.command_writer = self.command_queue.writer()

def start(self):
logger.info("Starting system manager")
self.t.start()

def shutdown(self):
logger.info("Shutting down system manager")
self.command_queue.close()
self.ls.shutdown()

async def run(self):
reader = self.command_queue.reader()
while command := await reader():
...
logger.info("System manager command processor stopped")
super().__init__(
device=device,
device_ids=device_ids,
async_allocs=async_allocs,
logger_name="shortfin-sd.manager",
shutdown_system=True,
)
4 changes: 2 additions & 2 deletions shortfin/python/shortfin_apps/sd/components/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from ...utils import GenerateService, BatcherProcess

from .config_struct import ModelParams
from .manager import SystemManager
from .manager import SDXLSystemManager
from .messages import InferenceExecRequest, InferencePhase
from .tokenizer import Tokenizer
from .metrics import measure, log_duration_str
Expand All @@ -38,7 +38,7 @@ def __init__(
self,
*,
name: str,
sysman: SystemManager,
sysman: SDXLSystemManager,
tokenizers: list[Tokenizer],
model_params: ModelParams,
fibers_per_device: int,
Expand Down
8 changes: 4 additions & 4 deletions shortfin/python/shortfin_apps/sd/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from .components.generate import ClientGenerateBatchProcess
from .components.config_struct import ModelParams
from .components.io_struct import GenerateReqInput
from .components.manager import SystemManager
from .components.manager import SDXLSystemManager
from .components.service import SDXLGenerateService
from .components.tokenizer import Tokenizer

Expand Down Expand Up @@ -84,7 +84,7 @@ async def lifespan(app: FastAPI):
sysman.shutdown()


sysman: SystemManager
sysman: SDXLSystemManager
services: dict[str, Any] = {}
app = FastAPI(lifespan=lifespan)

Expand Down Expand Up @@ -115,10 +115,10 @@ async def generate_request(gen_req: GenerateReqInput, request: Request):
)


def configure_sys(args) -> SystemManager:
def configure_sys(args) -> SDXLSystemManager:
# Setup system (configure devices, etc).
model_config, topology_config, flagfile, tuning_spec, args = get_configs(args)
sysman = SystemManager(args.device, args.device_ids, args.amdgpu_async_allocations)
sysman = SDXLSystemManager(args.device, args.device_ids, args.amdgpu_async_allocations)
return sysman, model_config, flagfile, tuning_spec


Expand Down
60 changes: 59 additions & 1 deletion shortfin/python/shortfin_apps/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,69 @@
import urllib
import logging
import asyncio
import threading
from pathlib import Path

import shortfin.array as sfnp
import shortfin as sf
from shortfin_apps.flux.components.manager import SystemManager
from shortfin.interop.support.device_setup import get_selected_devices


class SystemManager:
def __init__(
self,
device="local-task",
device_ids=None,
async_allocs=True,
amdgpu_allocators=None,
logger_name=__name__,
shutdown_system=True,
):
self.logger = logging.getLogger(logger_name)

self.shutdown_system = shutdown_system

if any(x in device for x in ["local-task", "cpu"]):
self.ls = sf.host.CPUSystemBuilder().create_system()
elif any(x in device for x in ["hip", "amdgpu"]):
if amdgpu_allocators is None:
sb = sf.SystemBuilder(
system_type="amdgpu",
amdgpu_async_allocations=async_allocs,
)
else:
sb = sf.SystemBuilder(
system_type="amdgpu",
amdgpu_async_allocations=async_allocs,
amdgpu_allocators=amdgpu_allocators,
)
if device_ids:
sb.visible_devices = sb.available_devices
sb.visible_devices = get_selected_devices(sb, device_ids)
self.ls = sb.create_system()

self.logger.info(f"Created local system with {self.ls.device_names} devices")
# TODO: Come up with an easier bootstrap thing than manually
# running a thread.
self.t = threading.Thread(target=lambda: self.ls.run(self.run()))
self.command_queue = self.ls.create_queue("command")
self.command_writer = self.command_queue.writer()

def start(self):
self.logger.info("Starting system manager")
self.t.start()

def shutdown(self):
self.logger.info("Shutting down system manager")
self.command_queue.close()
if self.shutdown_system:
self.ls.shutdown()

async def run(self):
reader = self.command_queue.reader()
while command := await reader():
...
self.logger.info("System manager command processor stopped")


dtype_to_filetag = {
Expand Down

0 comments on commit f2feef2

Please sign in to comment.