Commit 18d97f0

Adding MM eval tests / attention bugfixes (pytorch#1989)
1 parent 51b31c8 commit 18d97f0

File tree: 8 files changed, +248 -34 lines changed


tests/cache_artifacts.sh

+3

@@ -18,6 +18,9 @@ SMALL_MODEL_URLS=(
     "https://ossci-datasets.s3.amazonaws.com/torchtune/small-ckpt-hf-03082024.pt"
     "https://ossci-datasets.s3.amazonaws.com/torchtune/small-ckpt-tune-llama3-05052024.pt"
     "https://ossci-datasets.s3.amazonaws.com/torchtune/small-ckpt-hf-reward-07122024.pt"
+    "https://ossci-datasets.s3.amazonaws.com/torchtune/small-ckpt-meta-vision-10172024.pt"
+    "https://ossci-datasets.s3.amazonaws.com/torchtune/small-ckpt-hf-vision-10172024.pt"
+
 )
 FULL_MODEL_URL=("s3://pytorch-multimodal/llama2-7b-torchtune.pt")
 TOKENIZER_URLS=(

tests/recipes/test_eleuther_eval.py

+117 -18

@@ -13,11 +13,40 @@
 import pytest
 
 from tests.common import TUNE_PATH
-from tests.recipes.utils import llama2_test_config, write_hf_ckpt_config
-from tests.test_utils import CKPT_MODEL_PATHS
+from tests.recipes.utils import (
+    llama2_test_config,
+    llama3_2_vision_test_config,
+    write_hf_ckpt_config,
+    write_hf_vision_ckpt_config,
+)
+from tests.test_utils import CKPT_MODEL_PATHS, gpu_test
 
 
 class TestEleutherEval:
+    @pytest.fixture
+    def hide_correct_version_number(self, monkeypatch):
+        import importlib.metadata
+
+        import_orig = importlib.metadata.version
+
+        def mocked_import(name, *args, **kwargs):
+            if name == "lm-eval":
+                return "0.4.4"  # Hardcode wrong version number
+            return import_orig(name, *args, **kwargs)
+
+        monkeypatch.setattr(importlib.metadata, "version", mocked_import)
+
+    @pytest.fixture
+    def expected_vision_acc(self):
+        return {
+            "Science": 0.35,
+            "Biology": 0.25,
+            "Chemistry": 0.25,
+            "Geography": 0.5,
+            "Math": 0.0,
+            "Physics": 0.75,
+        }
+
     @pytest.mark.parametrize(
         "eval_name, expected_acc, bsz",
         [
@@ -74,22 +103,9 @@ def test_torchtune_checkpoint_eval_results(
         acc_result = float(search_results.group(1))
         assert math.isclose(acc_result, expected_acc, abs_tol=0.05)
 
-    @pytest.fixture
-    def hide_correct_version_number(self, monkeypatch):
-        import importlib.metadata
-
-        import_orig = importlib.metadata.version
-
-        def mocked_import(name, *args, **kwargs):
-            if name == "lm-eval":
-                return "0.4.4"  # Hardcode wrong version number
-            return import_orig(name, *args, **kwargs)
-
-        monkeypatch.setattr(importlib.metadata, "version", mocked_import)
-
     @pytest.mark.integration_test
     @pytest.mark.usefixtures("hide_correct_version_number")
-    def test_eval_recipe_errors_without_lm_eval(self, capsys, monkeypatch, tmpdir):
+    def test_eval_recipe_errors_without_lm_eval(self, monkeypatch, tmpdir):
         ckpt = "llama2_tune"
         ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
         ckpt_dir = ckpt_path.parent
@@ -123,7 +139,7 @@ def test_eval_recipe_errors_without_lm_eval(self, capsys, monkeypatch, tmpdir):
 
     @pytest.mark.integration_test
     def test_eval_recipe_errors_with_quantization_hf_checkpointer(
-        self, capsys, monkeypatch, tmpdir
+        self, monkeypatch, tmpdir
     ):
         ckpt = "llama2_hf"
         ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
@@ -162,7 +178,7 @@ def test_eval_recipe_errors_with_quantization_hf_checkpointer(
             runpy.run_path(TUNE_PATH, run_name="__main__")
 
     @pytest.mark.integration_test
-    def test_eval_recipe_errors_with_qat_quantizer(self, capsys, monkeypatch, tmpdir):
+    def test_eval_recipe_errors_with_qat_quantizer(self, monkeypatch, tmpdir):
         ckpt = "llama2_tune"
         ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
         ckpt_dir = ckpt_path.parent
@@ -194,3 +210,86 @@ def test_eval_recipe_errors_with_qat_quantizer(self, capsys, monkeypatch, tmpdir
             match="QAT quantizers should only be used during quantization aware training",
         ):
             runpy.run_path(TUNE_PATH, run_name="__main__")
+
+    @pytest.mark.integration_test
+    @gpu_test(gpu_count=1)
+    def test_meta_eval_vision(self, caplog, monkeypatch, tmpdir, expected_vision_acc):
+        ckpt = "llama3_2_vision_meta"
+        ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
+        ckpt_dir = ckpt_path.parent
+
+        cmd = f"""
+        tune run eleuther_eval \
+            --config llama3_2_vision/11B_evaluation \
+            output_dir={tmpdir} \
+            checkpointer=torchtune.training.FullModelMetaCheckpointer \
+            checkpointer.checkpoint_dir='{ckpt_dir}' \
+            checkpointer.checkpoint_files=[{ckpt_path}] \
+            ~checkpointer.checkpoint_files.filename_format \
+            ~checkpointer.checkpoint_files.max_filename \
+            checkpointer.output_dir={tmpdir} \
+            checkpointer.model_type=LLAMA3_VISION \
+            tokenizer.path=/tmp/test-artifacts/tokenizer_llama3.model \
+            tokenizer.prompt_template=null \
+            limit=4 \
+            dtype=bf16 \
+            device=cuda \
+        """.split()
+
+        model_config = llama3_2_vision_test_config()
+        cmd = cmd + model_config
+
+        monkeypatch.setattr(sys, "argv", cmd)
+        with pytest.raises(SystemExit, match=""):
+            runpy.run_path(TUNE_PATH, run_name="__main__")
+
+        out = caplog.text
+
+        pattern = r"^\|\s*(?:-\s*)?([^\|]+?)\s*\|\s*(\d+)\s*\|.*?\|.*?\|acc\s*\|\s*↑\s*\|\s*([\d.]+)"
+
+        matches = re.findall(pattern, out, re.MULTILINE)
+        for task_name, _, accuracy in matches:
+            assert math.isclose(float(accuracy), expected_vision_acc[task_name])
+
+    @pytest.mark.integration_test
+    @gpu_test(gpu_count=1)
+    def test_hf_eval_vision(self, caplog, monkeypatch, tmpdir, expected_vision_acc):
+        ckpt = "llama3_2_vision_hf"
+        ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
+        ckpt_dir = ckpt_path.parent
+
+        # Config file needed for model conversion.
+        write_hf_vision_ckpt_config(ckpt_dir)
+
+        cmd = f"""
+        tune run eleuther_eval \
+            --config llama3_2_vision/11B_evaluation \
+            output_dir={tmpdir} \
+            checkpointer=torchtune.training.FullModelHFCheckpointer \
+            checkpointer.checkpoint_dir='{ckpt_dir}' \
+            checkpointer.checkpoint_files=[{ckpt_path}]\
+            ~checkpointer.checkpoint_files.filename_format \
+            ~checkpointer.checkpoint_files.max_filename \
+            checkpointer.output_dir={tmpdir} \
+            checkpointer.model_type=LLAMA3_VISION \
+            tokenizer.path=/tmp/test-artifacts/tokenizer_llama3.model \
+            tokenizer.prompt_template=null \
+            limit=4 \
+            dtype=bf16 \
+            device=cuda \
+        """.split()
+
+        model_config = llama3_2_vision_test_config()
+        cmd = cmd + model_config
+
+        monkeypatch.setattr(sys, "argv", cmd)
+        with pytest.raises(SystemExit, match=""):
+            runpy.run_path(TUNE_PATH, run_name="__main__")
+
+        out = caplog.text
+
+        pattern = r"^\|\s*(?:-\s*)?([^\|]+?)\s*\|\s*(\d+)\s*\|.*?\|.*?\|acc\s*\|\s*↑\s*\|\s*([\d.]+)"
+
+        matches = re.findall(pattern, out, re.MULTILINE)
+        for task_name, _, accuracy in matches:
+            assert math.isclose(float(accuracy), expected_vision_acc[task_name])
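Both vision tests above recover per-subtask accuracy by scraping the results table that lm-eval logs at the end of a run. Below is a minimal sketch of that parsing step; the table excerpt is made up for illustration, and real lm-eval output may differ in column widths and task names.

import re

# Hypothetical excerpt of an lm-eval results table (illustrative only).
sample_output = """
|  Tasks    |Version|Filter|n-shot|Metric|   |Value|   |Stderr|
|-----------|------:|------|-----:|------|---|----:|---|-----:|
|mmmu_val   |      0|none  |     0|acc   |↑  | 0.35|±  |0.0367|
| - Science |      0|none  |     0|acc   |↑  | 0.35|±  |0.0476|
| - Biology |      0|none  |     0|acc   |↑  | 0.25|±  |0.1250|
"""

# Same pattern as in the tests: capture the task name (optionally prefixed
# with "- " for subtasks), the version column, and the value after "acc".
pattern = r"^\|\s*(?:-\s*)?([^\|]+?)\s*\|\s*(\d+)\s*\|.*?\|.*?\|acc\s*\|\s*↑\s*\|\s*([\d.]+)"

results = {
    task: float(acc)
    for task, _version, acc in re.findall(pattern, sample_output, re.MULTILINE)
}
print(results)  # {'mmmu_val': 0.35, 'Science': 0.35, 'Biology': 0.25}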

tests/recipes/utils.py

+73

@@ -128,6 +128,58 @@ def llama3_test_config() -> List[str]:
     ]
 
 
+def llama3_2_vision_test_config() -> List[str]:
+    return [
+        "model=tests.recipes.utils.dummy_vision_model",
+        "tokenizer._component_=torchtune.models.llama3_2_vision._transform.Llama3VisionTransform",
+        "tokenizer.patch_size=9",
+        "tokenizer.max_num_tiles=2",
+        "tokenizer.tile_size=18",
+        "tokenizer.max_seq_len=4096",
+    ]
+
+
+def dummy_vision_model():
+    from torchtune.models.llama3_2_vision._component_builders import (
+        llama3_2_vision_decoder,
+        llama3_2_vision_encoder,
+    )
+    from torchtune.modules.model_fusion import DeepFusionModel
+
+    vision_encoder = llama3_2_vision_encoder(
+        clip_embed_dim=128,
+        clip_num_layers=4,
+        num_heads=4,
+        tile_size=18,
+        patch_size=9,
+        max_num_tiles=2,
+        in_channels=3,
+        clip_hidden_states=[0, 1],
+        num_layers_projection=2,
+        decoder_embed_dim=128,
+    )
+    vision_decoder = llama3_2_vision_decoder(
+        vocab_size=128256,
+        num_layers=4,
+        fusion_interval=2,
+        num_special_tokens=2,
+        num_heads=8,
+        num_kv_heads=4,
+        embed_dim=128,
+        max_seq_len=4096,
+        encoder_max_seq_len=4096,
+    )
+
+    model = DeepFusionModel(
+        encoder=vision_encoder,
+        decoder=vision_decoder,
+        encoder_trainable=False,
+        decoder_trainable=False,
+        fusion_trainable=False,
+    )
+    return model
+
+
 def lora_llama2_test_config(
     lora_attn_modules,
     apply_lora_to_mlp: bool = False,
@@ -199,6 +251,27 @@ def write_hf_ckpt_config(ckpt_dir: str):
         json.dump(config, f)
 
 
+def write_hf_vision_ckpt_config(ckpt_dir: str):
+    config = {
+        "text_config": {
+            "num_attention_heads": 8,
+            "num_key_value_heads": 4,
+            "hidden_size": 128,
+            "vocab_size": 128256,
+            "cross_attention_layers": [1, 4],
+        },
+        "vision_config": {
+            "hidden_size": 128,
+            "image_size": 18,
+            "max_num_tiles": 2,
+            "supported_aspect_ratios": [[1, 1], [1, 2], [2, 1]],
+        },
+    }
+    config_file = Path.joinpath(Path(ckpt_dir), "config.json")
+    with config_file.open("w") as f:
+        json.dump(config, f)
+
+
 MODEL_TEST_CONFIGS = {
     "llama2": llama2_test_config(),
     "llama3": llama3_test_config(),

tests/test_utils.py

+2

@@ -34,6 +34,8 @@
     "llama2_reward_hf": "/tmp/test-artifacts/small-ckpt-hf-reward-07122024.pt",
     "llama3_tune": "/tmp/test-artifacts/small-ckpt-tune-llama3-05052024.pt",
     "llama2_7b": "/tmp/test-artifacts/llama2-7b-torchtune.pt",
+    "llama3_2_vision_hf": "/tmp/test-artifacts/small-ckpt-hf-vision-10172024.pt",
+    "llama3_2_vision_meta": "/tmp/test-artifacts/small-ckpt-meta-vision-10172024.pt",
 }
 
 TOKENIZER_PATHS = {

tests/torchtune/modules/test_transformer_decoder.py

+38 -1

@@ -183,11 +183,48 @@ def transformer_layer(
         transformer_layer.eval()
         return transformer_layer
 
+    @mps_ignored_test()
+    def test_forward_kv_cache(
+        self,
+        input: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+        transformer_layer: TransformerCrossAttentionLayer,
+        input_params: Tuple[int, int, int, int],
+    ):
+
+        b, _, encoder_seq_len, _ = input_params
+        transformer_layer.setup_caches(
+            batch_size=b,
+            dtype=torch.float32,
+            encoder_max_seq_len=encoder_seq_len,
+            decoder_max_seq_len=None,
+        )
+        input_x, input_y, mask = input
+        with torch.no_grad():
+            # make an initial forward pass which should fill the encoder cache
+            first_output = transformer_layer(
+                input_x,
+                encoder_input=input_y,
+                encoder_mask=mask,
+            )
+            # the second pass should just retrieve from the kv-cache and produce
+            # identical outputs
+            output = transformer_layer(
+                input_x,
+                encoder_input=None,
+                encoder_mask=mask,
+            )
+
+        assert_expected(output.mean(), torch.tensor(1.7762), atol=1e-8, rtol=1e-3)
+        assert_expected(output.shape, input_x.shape)
+
+        assert_expected(first_output.shape, output.shape)
+        assert_expected(first_output.mean(), output.mean())
+
     @mps_ignored_test()
     def test_forward(
         self,
         input: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
-        transformer_layer: TransformerSelfAttentionLayer,
+        transformer_layer: TransformerCrossAttentionLayer,
     ) -> None:
         input_x, input_y, mask = input
         with torch.no_grad():
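The new test_forward_kv_cache exercises the cross-attention caching contract: the first forward pass (with encoder_input) fills the encoder KV-cache, and a second pass with encoder_input=None must read the same keys/values back and reproduce the output. A toy, framework-free sketch of that contract follows; the class and its "projections" are illustrative stand-ins, not torchtune APIs.

import torch

class ToyCrossAttentionCache:
    """Illustrative cache: store encoder K/V on the first call, reuse afterwards."""

    def __init__(self):
        self.k = None
        self.v = None

    def forward(self, x, encoder_input=None):
        if encoder_input is not None:
            # First pass: "project" and cache encoder keys/values.
            self.k = encoder_input * 0.5
            self.v = encoder_input * 2.0
        elif self.k is None:
            raise ValueError("Must provide encoder_input or fill the cache first")
        # Attend using whichever k/v we have (freshly computed or cached).
        attn = torch.softmax(x @ self.k.transpose(-2, -1), dim=-1)
        return attn @ self.v

layer = ToyCrossAttentionCache()
x = torch.randn(2, 4, 8)
enc = torch.randn(2, 6, 8)
first = layer.forward(x, encoder_input=enc)    # fills the cache
second = layer.forward(x, encoder_input=None)  # reuses cached k/v
assert torch.allclose(first, second)           # same output, as the test asserts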

torchtune/models/gemma2/_attention.py

+1 -1

@@ -240,7 +240,7 @@ def forward(
             q = self.q_norm(q)
 
         if y is None:
-            if self.kv_cache is None:
+            if self.kv_cache is None or not self.cache_enabled:
                 raise ValueError(
                     "Must provide y input or use kv_cache to enable streaming decoding"
                 )

torchtune/modules/attention.py

+13 -13

@@ -195,7 +195,7 @@ def forward(
                 and before the softmax. Either:
 
                 A boolean tensor with shape ``[b x s x s]``, ``[b x s x self.encoder_max_cache_seq_len]``,
-                or ``[b x s x self.encoder_max_cache_seq_len]`` if using KV-cacheing with encoder/decoder layers.
+                or ``[b x s x self.decoder_max_cache_seq_len]`` if using KV-cacheing with encoder/decoder layers.
                 A value of True in row ``i`` and column ``j`` means token ``i`` attends to token ``j``. A value of False means
                 token ``i`` does not attend to token ``j``. If no mask is specified, a causal mask
                 is used by default.
@@ -249,7 +249,7 @@ def forward(
             q = self.q_norm(q)
 
         if y is None:
-            if self.kv_cache is None:
+            if self.kv_cache is None or not self.cache_enabled:
                 raise ValueError(
                     "Must provide y input or use kv_cache to enable streaming decoding"
                 )
@@ -273,21 +273,21 @@ def forward(
             k = k.transpose(1, 2)
             v = v.transpose(1, 2)
 
+            # Normalize k
+            if self.k_norm is not None:
+                k = self.k_norm(k)
+
             # Update key-value cache
             if self.kv_cache is not None and self.cache_enabled:
                 k, v = self.kv_cache.update(k, v)
 
-            # If needed, expand the key and value tensors to have the same shape
-            # as the query tensor by copying values across the relevant dim
-            # k,v shape: [b, n_h, s, h_d]
-            if self.num_heads != self.num_kv_heads:
-                expand_shape = (-1, -1, q_per_kv, -1, -1)
-                k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)
-                v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2)
-
-            # Normalize k
-            if self.k_norm is not None:
-                k = self.k_norm(k)
+        # If needed, expand the key and value tensors to have the same shape
+        # as the query tensor by copying values across the relevant dim
+        # k,v shape: [b, n_kv, s, h_d] -> [b, n_h, s, h_d]
+        if self.num_heads != self.num_kv_heads:
+            expand_shape = (b, self.num_kv_heads, q_per_kv, -1, self.head_dim)
+            k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)
+            v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2)
 
         output = self._attention_call(
             q,
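The reordered block in attention.py applies k_norm before the key/value cache is written and then broadcasts the kv heads up to the number of query heads using explicit sizes. A standalone sketch of that grouped-query (GQA) expansion, with shapes chosen arbitrarily for illustration:

import torch

b, num_heads, num_kv_heads, seq_len, head_dim = 2, 8, 4, 16, 32
q_per_kv = num_heads // num_kv_heads  # query heads sharing each kv head

# k/v as they come out of the projections (or out of the kv-cache):
# [b, n_kv, s, h_d]
k = torch.randn(b, num_kv_heads, seq_len, head_dim)
v = torch.randn(b, num_kv_heads, seq_len, head_dim)

# Insert a group dim, broadcast it, then fold it into the head dim:
# [b, n_kv, s, h_d] -> [b, n_kv, 1, s, h_d] -> [b, n_kv, q_per_kv, s, h_d] -> [b, n_h, s, h_d]
expand_shape = (b, num_kv_heads, q_per_kv, -1, head_dim)
k_expanded = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)
v_expanded = v.unsqueeze(2).expand(expand_shape).flatten(1, 2)

assert k_expanded.shape == (b, num_heads, seq_len, head_dim)
# Every query head in a group sees a copy of the kv head it was expanded from.
assert torch.equal(k_expanded[:, 0], k_expanded[:, q_per_kv - 1])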

torchtune/modules/transformer.py

+1 -1

@@ -781,7 +781,7 @@ def setup_caches(
             isinstance(l, TransformerCrossAttentionLayer) for l in self.modules()
         )
         has_decoder_layers = any(
-            isinstance(l, TransformerSelfAttentionLayer) for l in self.layers
+            isinstance(l, TransformerSelfAttentionLayer) for l in self.modules()
        )
         if has_encoder_layers:
             if encoder_max_seq_len is not None:
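The setup_caches change detects decoder (self-attention) layers by walking self.modules() instead of only the top-level self.layers list; nn.Module.modules() recurses into children, so layers nested inside wrapper modules (as in fusion models) are still found. A small generic sketch of the difference; the Wrapper class below is a stand-in, not torchtune's fusion layer:

import torch.nn as nn

class SelfAttentionLayer(nn.Module):
    def forward(self, x):
        return x

class Wrapper(nn.Module):
    """Stand-in for a fusion-style layer that wraps an inner layer."""
    def __init__(self, inner: nn.Module):
        super().__init__()
        self.inner = inner
    def forward(self, x):
        return self.inner(x)

class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        # Every decoder layer is hidden behind a wrapper.
        self.layers = nn.ModuleList([Wrapper(SelfAttentionLayer()) for _ in range(2)])

model = Decoder()

# Iterating self.layers only sees the wrappers, so this check misses the layers.
print(any(isinstance(l, SelfAttentionLayer) for l in model.layers))    # False
# modules() recurses through all children, so the wrapped layers are found.
print(any(isinstance(l, SelfAttentionLayer) for l in model.modules()))  # True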
