Skip to content

Commit 57474e5

Browse files
committed
feat: added more starter examples
1 parent 96e8b94 commit 57474e5

File tree

5 files changed

+73
-3
lines changed

5 files changed

+73
-3
lines changed

cmd/project/starter_examples/LLM/src/handler.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
from transformers import T5Tokenizer, T5ForConditionalGeneration
99

1010
# Initialize the tokenizer and model
11-
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
12-
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-base", device_map="auto").to("cuda")
11+
tokenizer = T5Tokenizer.from_pretrained("<<MODEL_NAME>>")
12+
model = T5ForConditionalGeneration.from_pretrained("<<MODEL_NAME>>", device_map="auto").to("cuda")
1313

1414

1515
def handler(job: Dict[str, any]) -> str:

cmd/project/starter_examples/Stable Diffusion/src/handler.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
# Initialize the pipeline
1212
pipe = AutoPipelineForText2Image.from_pretrained(
13-
"stabilityai/sdxl-turbo", # model name
13+
"<<MODEL_NAME>>", # model name
1414
torch_dtype=torch.float16, variant="fp16"
1515
).to("cuda")
1616

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Similar to .gitignore
2+
# Matching files are not synced to the Project Pod and do not trigger a Pod reload.
3+
4+
Dockerfile
5+
__pycache__/
6+
*.pyc
7+
.*.swp
8+
.git/
9+
*.tmp
10+
*.log
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Required Python packages get listed here, one per line.
2+
# Pin each package to a specific version to avoid unexpected changes.
3+
4+
# You can also install packages from a git repository, e.g.:
5+
# git+https://github.com/runpod/runpod-python.git
6+
# To learn more, see https://pip.pypa.io/en/stable/reference/requirements-file-format/
7+
8+
<<RUNPOD>>
9+
hf_transfer
10+
11+
torch
12+
transformers
13+
scipy
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
''' A starter handler file using RunPod and transformers for audio generation. '''
2+
3+
import io
4+
import base64
5+
from typing import Dict
6+
7+
import scipy.io.wavfile
8+
from transformers import pipeline
9+
10+
import runpod
11+
12+
13+
# Initialize the text-to-audio pipeline once at module import so the model is
# loaded before the first job arrives.  <<MODEL_NAME>> is presumably a template
# placeholder substituted at project creation (the other starter files use the
# same marker) — TODO confirm.  device=0 selects the first accelerator device.
synthesizer = pipeline("text-to-audio", "<<MODEL_NAME>>", device=0)
15+
16+
17+
def handler(job: Dict[str, any]) -> str:
    """
    Processes a text prompt to generate music, returning the result as a base64-encoded WAV audio.

    Args:
        job (dict): Contains 'input' with a 'prompt' key for the music generation text prompt.

    Returns:
        str: The generated audio as a base64-encoded ``data:audio/wav;base64,...`` string.

    Raises:
        KeyError: If the job payload is missing 'input' or 'prompt'.
    """
    prompt = job['input']['prompt']
    print(f"Received prompt: {prompt}")

    # do_sample=True makes generation stochastic; max_new_tokens bounds the
    # length of the generated audio.
    result = synthesizer(prompt, forward_params={"do_sample": True, "max_new_tokens": 300})

    audio_data = result['audio']
    sample_rate = result['sampling_rate']

    # Write the samples to an in-memory WAV file rather than touching disk.
    audio_bytes = io.BytesIO()
    scipy.io.wavfile.write(audio_bytes, sample_rate, audio_data)
    audio_bytes.seek(0)

    # Base64-encode the WAV bytes so the result is JSON-serializable.
    base64_audio = base64.b64encode(audio_bytes.read()).decode('utf-8')

    # Return the base64 encoded audio with the appropriate data URI scheme
    return f"data:audio/wav;base64,{base64_audio}"
45+
46+
47+
# Register the handler and start the RunPod serverless worker.
runpod.serverless.start({"handler": handler})

0 commit comments

Comments
 (0)