Commit 57474e5 1 parent 96e8b94 commit 57474e5 Copy full SHA for 57474e5
File tree 5 files changed +73
-3
lines changed
cmd/project/starter_examples
5 files changed +73
-3
lines changed Original file line number Diff line number Diff line change 8
8
from transformers import T5Tokenizer , T5ForConditionalGeneration
9
9
10
10
# Initialize the tokenizer and model
11
- tokenizer = T5Tokenizer .from_pretrained ("google/flan-t5-base " )
12
- model = T5ForConditionalGeneration .from_pretrained ("google/flan-t5-base " , device_map = "auto" ).to ("cuda" )
11
+ tokenizer = T5Tokenizer .from_pretrained ("<<MODEL_NAME>> " )
12
+ model = T5ForConditionalGeneration .from_pretrained ("<<MODEL_NAME>> " , device_map = "auto" ).to ("cuda" )
13
13
14
14
15
15
def handler (job : Dict [str , any ]) -> str :
Original file line number Diff line number Diff line change 10
10
11
11
# Initialize the pipeline
12
12
pipe = AutoPipelineForText2Image .from_pretrained (
13
- "stabilityai/sdxl-turbo " , # model name
13
+ "<<MODEL_NAME>> " , # model name
14
14
torch_dtype = torch .float16 , variant = "fp16"
15
15
).to ("cuda" )
16
16
Original file line number Diff line number Diff line change
1
+ # Similar to .gitignore
2
# Matched files are not synced to the Project Pod and do not cause the Pod to reload.
3
+
4
+ Dockerfile
5
+ __pycache__/
6
+ *.pyc
7
+ .*.swp
8
+ .git/
9
+ *.tmp
10
+ *.log
Original file line number Diff line number Diff line change
1
+ # Required Python packages get listed here, one per line.
2
+ # Lock the version number to avoid unexpected changes.
3
+
4
+ # You can also install packages from a git repository, e.g.:
5
+ # git+https://github.com/runpod/runpod-python.git
6
+ # To learn more, see https://pip.pypa.io/en/stable/reference/requirements-file-format/
7
+
8
<<RUNPOD>>
9
+ hf_transfer
10
+
11
+ torch
12
+ transformers
13
+ scipy
Original file line number Diff line number Diff line change
1
+ ''' A starter handler file using RunPod and transformers for audio generation. '''
2
+
3
+ import io
4
+ import base64
5
+ from typing import Dict
6
+
7
+ import scipy .io .wavfile
8
+ from transformers import pipeline
9
+
10
+ import runpod
11
+
12
+
13
# Build the text-to-audio pipeline once at import time so every job reuses it.
# <<MODEL_NAME>> is substituted by the project template; device=0 pins the first GPU.
synthesizer = pipeline("text-to-audio", "<<MODEL_NAME>>", device=0)
15
+
16
+
17
def handler(job: Dict[str, any]) -> str:
    """
    Generate audio from a text prompt and return it as a WAV data URI.

    Args:
        job (dict): RunPod job payload; ``job['input']['prompt']`` holds the
            text prompt for the audio model.

    Returns:
        str: A ``data:audio/wav;base64,...`` data URI containing the
            generated audio as a base64-encoded WAV file.
    """
    prompt = job['input']['prompt']
    print(f"Received prompt: {prompt}")

    # Sampling keeps generations varied; cap new tokens so a job can't run unbounded.
    result = synthesizer(prompt, forward_params={"do_sample": True, "max_new_tokens": 300})

    audio_data = result['audio']
    sample_rate = result['sampling_rate']

    # Serialize to WAV entirely in memory — no temp files on the worker.
    audio_bytes = io.BytesIO()
    scipy.io.wavfile.write(audio_bytes, sample_rate, audio_data)
    audio_bytes.seek(0)

    # Base64-encode so the audio can travel inside a JSON response.
    base64_audio = base64.b64encode(audio_bytes.read()).decode('utf-8')

    # Data-URI scheme lets clients play the audio directly from the response.
    return f"data:audio/wav;base64,{base64_audio}"
45
# Hand the handler to the RunPod serverless worker loop; blocks and serves jobs.
runpod.serverless.start({"handler": handler})
You can’t perform that action at this time.
0 commit comments