Add Llama 3 model in Together AI API #108

Open · wants to merge 3 commits into base: main
248 changes: 248 additions & 0 deletions concordia/language_model/together_ai.py
@@ -333,3 +333,251 @@ def sample_choice(
max_str = responses[idx]

return idx, max_str, {r: logprobs_np[i] for i, r in enumerate(responses)}


def _find_response_start_index_Llama3(tokens):
r"""Finds the start of the response in the prompt.

Args:
tokens: A list of strings.

Returns:
The index of the last occurrence of '<start_of_turn>' followed by 'model'
and '\n', or 1 if the sequence is not found. This corresponds to the start
of the response.
"""
# print(f' Tokens: {tokens}\n\n\n')
assert len(tokens) >= 3, "Response doesn't match expectation."
for i in range(len(tokens) - 4, -1, -1):
if (
tokens[i] == '<|eot_id|>'
and tokens[i + 1] == '<|start_header_id|>'
and tokens[i + 2] == 'assistant'
and tokens[i + 3] == '<|end_header_id|>'
):
return i + 4 # Return the index after the sequence
raise ValueError("Response doesn't match expectation.")
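# Illustrative token layout this helper expects (an assumption based on the
# Llama 3 chat template; the exact tokenization depends on the Together AI
# tokenizer for the chosen model):
#   [..., '<|eot_id|>', '<|start_header_id|>', 'assistant',
#    '<|end_header_id|>', '\n\n', 'sleep', 'ing', '.']
# Here the returned index points at '\n\n', the first token after
# '<|end_header_id|>'.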

class Llama3(language_model.LanguageModel):
"""Language Model that uses Together AI models."""

def __init__(
self,
model_name: str,
*,
api_key: str | None = None,
measurements: measurements_lib.Measurements | None = None,
channel: str = language_model.DEFAULT_STATS_CHANNEL,
):
"""Initializes the instance.

Args:
model_name: The language model to use. For more details, see
https://api.together.xyz/models.
api_key: The API key to use when accessing the Together AI API. If None,
will use the TOGETHER_AI_API_KEY environment variable.
measurements: The measurements object to log usage statistics to.
channel: The channel to write the statistics to.
"""
if api_key is None:
api_key = os.environ['TOGETHER_AI_API_KEY']
self._api_key = api_key
self._model_name = model_name
self._measurements = measurements
self._channel = channel
self._client = together.Together(api_key=self._api_key)

@override
def sample_text(
self,
prompt: str,
*,
max_tokens: int = language_model.DEFAULT_MAX_TOKENS,
terminators: Collection[str] = language_model.DEFAULT_TERMINATORS,
temperature: float = language_model.DEFAULT_TEMPERATURE,
timeout: float = language_model.DEFAULT_TIMEOUT_SECONDS,
seed: int | None = None,
) -> str:
original_prompt = prompt
prompt = _ensure_prompt_not_too_long(prompt, max_tokens)
messages = [
{
'role': 'system',
'content': (
'You always continue sentences provided '
'by the user and you never repeat what '
'the user has already said. All responses must end with a '
'period. Try not to use lists, but if you must, then '
'always delimit list items using either '
r"semicolons or single newline characters ('\n'), never "
r"delimit list items with double carriage returns ('\n\n')."
),
},
{
'role': 'user',
'content': 'Question: Is Jake a turtle?\nAnswer: Jake is ',
},
{'role': 'assistant', 'content': 'not a turtle.'},
{
'role': 'user',
'content': (
'Question: What is Priya doing right now?\nAnswer: '
+ 'Priya is currently '
),
},
{'role': 'assistant', 'content': 'sleeping.'},
{'role': 'user', 'content': prompt},
]

# gemma2 does not support `tokens` + `max_new_tokens` > 8193 and interprets
# our `max_tokens` as its `max_new_tokens`. It is unclear whether the same
# limit applies to llama3, so we apply the same cap to be safe.
max_tokens = min(max_tokens, _DEFAULT_NUM_RESPONSE_TOKENS)

result = ''
for attempts in range(_MAX_ATTEMPTS):
if attempts > 0:
seconds_to_sleep = (_SECONDS_TO_SLEEP_WHEN_RATE_LIMITED +
random.uniform(-_JITTER_SECONDS, _JITTER_SECONDS))
if attempts >= _NUM_SILENT_ATTEMPTS:
print(
f'Sleeping for {seconds_to_sleep} seconds... '
+ f'attempt: {attempts} / {_MAX_ATTEMPTS}'
)
time.sleep(seconds_to_sleep)
try:
response = self._client.chat.completions.create(
model=self._model_name,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
timeout=timeout,
stop=terminators,
seed=seed,
stream=False,
)
except (together.error.RateLimitError,
together.error.APIError,
together.error.ServiceUnavailableError) as err:
if attempts >= _NUM_SILENT_ATTEMPTS:
print(f' Exception: {err}')
print(f' Text exception prompt: {prompt}')
if isinstance(err, together.error.APIError):
# If we hit the error that arises from a prompt that is too long, then
# re-run the trimming function with a more pessimistic guess of the
# number of characters per token.
prompt = _ensure_prompt_not_too_long(original_prompt,
max_tokens,
guess_chars_per_token=1)
continue
else:
result = response.choices[0].message.content
break

if self._measurements is not None:
self._measurements.publish_datum(
self._channel,
{'raw_text_length': len(result)},
)

return result

def _sample_choice(self, prompt: str, response: str) -> float:
"""Returns the log probability of the prompt and response."""
original_prompt = prompt
augmented_prompt = _ensure_prompt_not_too_long(prompt, len(response))
attempts = 0
for attempts in range(_MAX_ATTEMPTS):
if attempts > 0:
seconds_to_sleep = (_SECONDS_TO_SLEEP_WHEN_RATE_LIMITED +
random.uniform(-_JITTER_SECONDS, _JITTER_SECONDS))
if attempts >= _NUM_SILENT_ATTEMPTS:
print(
f'Sleeping for {seconds_to_sleep} seconds... '
+ f'attempt: {attempts} / {_MAX_ATTEMPTS}'
)
time.sleep(seconds_to_sleep)
try:
messages = [
{
'role': 'system',
'content': (
'You always continue sentences provided '
+ 'by the user and you never repeat what '
+ 'the user already said.'
),
},
{
'role': 'user',
'content': 'Question: Is Jake a turtle?\nAnswer: Jake is ',
},
{'role': 'assistant', 'content': 'not a turtle.'},
{
'role': 'user',
'content': (
'Question: What is Priya doing right now?\nAnswer: '
+ 'Priya is currently '
),
},
{'role': 'assistant', 'content': 'sleeping.'},
{'role': 'user', 'content': augmented_prompt},
{'role': 'assistant', 'content': response},
]
result = self._client.chat.completions.create(
model=self._model_name,
messages=messages,
max_tokens=1,
seed=None,
logprobs=1,
stream=False,
echo=True,
)
except (together.error.RateLimitError,
together.error.APIError,
together.error.ServiceUnavailableError) as err:
if attempts >= _NUM_SILENT_ATTEMPTS:
print(f' Exception: {err}')
print(f' Choice exception prompt: {augmented_prompt}')
if isinstance(err, together.error.APIError):
# If we hit the error that arises from a prompt that is too long, then
# re-run the trimming function with a more pessimistic guess of the
# number of characters per token.
augmented_prompt = _ensure_prompt_not_too_long(
original_prompt, 1, guess_chars_per_token=1
)
continue
else:
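# The choice score is the sum of the per-token log probabilities of the
# response tokens, i.e. everything after the final assistant header in the
# echoed prompt.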
logprobs = result.prompt[0].logprobs
response_idx = _find_response_start_index_Llama3(logprobs.tokens)
response_log_probs = logprobs.token_logprobs[response_idx:]
score = sum(response_log_probs)
return score

raise language_model.InvalidResponseError(
f'Failed to get logprobs after {attempts+1} attempts.\n Exception'
f' prompt: {augmented_prompt}'
)

@override
def sample_choice(
self,
prompt: str,
responses: Sequence[str],
*,
seed: int | None = None,
) -> tuple[int, str, dict[str, float]]:

logprobs_np = np.array(
[self._sample_choice(prompt, response) for response in responses]
).reshape(-1)
idx = np.argmax(logprobs_np)

# Get the corresponding response string
max_str = responses[idx]

return idx, max_str, {r: logprobs_np[i] for i, r in enumerate(responses)}
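# Hedged usage sketch for this method (the prompt, the candidate responses,
# and the `model` instance of Llama3 are illustrative assumptions):
#   idx, choice, scores = model.sample_choice(
#       'Question: Is Jake a turtle?\nAnswer: Jake is ',
#       ['not a turtle.', 'a turtle.'],
#   )
#   # `idx` and `choice` identify the candidate whose summed log probability
#   # is highest; `scores` maps each candidate string to its score.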

5 changes: 4 additions & 1 deletion concordia/language_model/utils.py
@@ -75,7 +75,10 @@ def language_model_setup(
elif api_type == 'pytorch_gemma':
cls = pytorch_gemma_model.PyTorchGemmaLanguageModel
elif api_type == 'together_ai':
cls = together_ai.Gemma2
if 'llama' in model_name.lower():
cls = together_ai.Llama3
else:
cls = together_ai.Gemma2
else:
raise ValueError(f'Unrecognized api type: {api_type}')

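A minimal usage sketch of the new dispatch follows. Assumptions not shown in this diff: `language_model_setup` accepts `api_type` and `model_name` keyword arguments (only their use inside the function appears above), and the model name below is just an illustrative Together AI Llama 3 identifier.

from concordia.language_model import utils

# Any model name containing 'llama' (case-insensitive) now routes to
# together_ai.Llama3; other names keep the previous together_ai.Gemma2 path.
model = utils.language_model_setup(
    api_type='together_ai',
    model_name='meta-llama/Llama-3-70b-chat-hf',  # assumed example identifier
)
print(model.sample_text('Question: Is Rex a dog?\nAnswer: Rex is '))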