From 50cd82963fdedecec91195e1ae931aa6a619daf0 Mon Sep 17 00:00:00 2001
From: Mandana Vaziri
Date: Tue, 25 Feb 2025 18:36:57 -0500
Subject: [PATCH] support for granite-io backend in interpreter

Signed-off-by: Mandana Vaziri
---
 examples/hello/hello-graniteio.pdl |  5 +++++
 src/pdl/pdl_ast.py                 |  1 +
 src/pdl/pdl_llms.py                | 17 +++++++++++++++--
 3 files changed, 21 insertions(+), 2 deletions(-)
 create mode 100644 examples/hello/hello-graniteio.pdl

diff --git a/examples/hello/hello-graniteio.pdl b/examples/hello/hello-graniteio.pdl
new file mode 100644
index 00000000..bfc016ea
--- /dev/null
+++ b/examples/hello/hello-graniteio.pdl
@@ -0,0 +1,5 @@
+text:
+- "Hello!\n"
+- model: ibm-granite/granite-3.2-8b-instruct-preview
+  backend:
+    transformers: cpu
\ No newline at end of file
diff --git a/src/pdl/pdl_ast.py b/src/pdl/pdl_ast.py
index 35f3da4a..756a05c2 100644
--- a/src/pdl/pdl_ast.py
+++ b/src/pdl/pdl_ast.py
@@ -378,6 +378,7 @@ class GraniteioModelBlock(ModelBlock):
     model: ExpressionType[object]
     platform: Literal[ModelPlatform.GRANITEIO] = ModelPlatform.GRANITEIO
     intrinsics: ExpressionType[list[GraniteioIntrinsicType]] = []
+    backend: ExpressionType[dict[str, Any]]


 class CodeBlock(Block):
diff --git a/src/pdl/pdl_llms.py b/src/pdl/pdl_llms.py
index c8f010aa..3d739cbd 100644
--- a/src/pdl/pdl_llms.py
+++ b/src/pdl/pdl_llms.py
@@ -1,9 +1,15 @@
 import asyncio
 import os
+import json
+import aconfig
 import threading
 from concurrent.futures import Future
 from typing import Any, Callable, Generator, TypeVar

+from granite_io.io.granite_3_2 import Granite3Point2InputOutputProcessor
+from granite_io.backend.transformers import TransformersBackend
+from granite_io.io.base import ChatCompletionInputs
+
 import httpx
 import litellm
 from dotenv import load_dotenv
@@ -141,8 +147,15 @@ async def async_generate_text(
         messages: ModelInput,
     ) -> tuple[dict[str, Any], Any]:
         try:
-            outputs = block.model.process(messages)  # type: ignore # TODO
-            return outputs.response, outputs
+            if "transformers" in block.backend:
+                input_json_str = json.dumps({"messages": messages})
+                inputs = ChatCompletionInputs.model_validate_json(input_json_str)
+                io_processor = Granite3Point2InputOutputProcessor(
+                    TransformersBackend(aconfig.Config({"model_name":block.model, "device":block.backend["transformers"]})),
+                )
+
+                result = io_processor.create_chat_completion(inputs)
+                return result.next_message.model_dump(), result.next_message.model_dump()
         except Exception as exc:
             message = f"Error during '{block.model}' model call: {repr(exc)}"
             loc = block.location