README.md (+86, -59)
@@ -47,15 +47,12 @@ Everything you want to know about Google Cloud TPU
 * [7. JAX Best Practices](#7-jax-best-practices)
     * [7.1. Import convention](#71-import-convention)
     * [7.2. Manage random keys in JAX](#72-manage-random-keys-in-jax)
-    * [7.3. Serialize model parameters](#73-serialize-model-parameters)
-    * [7.4. Conversion between NumPy arrays and JAX arrays](#74-conversion-between-numpy-arrays-and-jax-arrays)
-    * [7.5. Conversion between PyTorch tensors and JAX arrays](#75-conversion-between-pytorch-tensors-and-jax-arrays)
-    * [7.6. Type annotation](#76-type-annotation)
-    * [7.7. Check if an array is either a NumPy array or a JAX array](#77-check-if-an-array-is-either-a-numpy-array-or-a-jax-array)
-    * [7.8. Get the shapes of all parameters in a nested dictionary](#78-get-the-shapes-of-all-parameters-in-a-nested-dictionary)
-    * [7.9. The correct way to generate random numbers on CPU](#79-the-correct-way-to-generate-random-numbers-on-cpu)
-    * [7.10. Use optimizers from Optax](#710-use-optimizers-from-optax)
-    * [7.11. Use the cross-entropy loss implementation from Optax](#711-use-the-cross-entropy-loss-implementation-from-optax)
+    * [7.3. Conversion between NumPy arrays and JAX arrays](#73-conversion-between-numpy-arrays-and-jax-arrays)
+    * [7.4. Conversion between PyTorch tensors and JAX arrays](#74-conversion-between-pytorch-tensors-and-jax-arrays)
+    * [7.5. Get the shapes of all parameters in a nested dictionary](#75-get-the-shapes-of-all-parameters-in-a-nested-dictionary)
+    * [7.6. The correct way to generate random numbers on CPU](#76-the-correct-way-to-generate-random-numbers-on-cpu)
+    * [7.7. Use optimizers from Optax](#77-use-optimizers-from-optax)
+    * [7.8. Use the cross-entropy loss implementation from Optax](#78-use-the-cross-entropy-loss-implementation-from-optax)
 * [8. How Can I...](#8-how-can-i)
     * [8.1. Share files across multiple TPU VM instances](#81-share-files-across-multiple-tpu-vm-instances)
     * [8.2. Monitor TPU usage](#82-monitor-tpu-usage)
@@ -334,7 +331,7 @@ nano ~/.ssh/config
 Add the following content:
 
 ```
-Host 172.21.12.*
+Host 172.21.12.* 127.0.0.1
     StrictHostKeyChecking no
     UserKnownHostsFile /dev/null
     LogLevel ERROR
@@ -352,7 +349,19 @@ chmod 600 ~/.ssh/config
 
 ### 5.6. Add the SSH public key of Host 0 to all hosts
 
-First, follow the above steps to generate a key pair on Host 0. Then add the generated public key to Google Cloud's SSH keys, and this public key will be automatically propagated to all hosts.
+Generate a key pair on Host 0:
+
+```sh
+ssh-keygen -t rsa -f ~/.ssh/id_rsa -N ""
+```
+
+View the generated SSH public key:
+
+```sh
+cat ~/.ssh/id_rsa.pub
+```
+
+Add this public key to the SSH keys in Google Cloud. This key will be automatically propagated to all hosts.
@@ … @@
 ./podrun -- sudo mount 172.21.12.2:/nfs_share /nfs_share
-./podrun -- ln -sf /nfs_share ~/nfs_share
+./podrun -i -- ln -sf /nfs_share ~/nfs_share
 
-cd ~/nfs_share
-touch meow
-./podrun -iw -- ls ~/nfs_share/meow
+touch ~/nfs_share/meow
+./podrun -i -- ls -la ~/nfs_share/meow
 ```
 
+Replace `172.21.12.2` with the actual internal IP address of Host 0.
+
 ### 5.9. Setting up the development environment in TPU Pod
 
-TODO: Refer to the steps in setting up the development environment in the TPU VM above, but each command should use `podrun -iw --` to run on all hosts.
@@ … @@
 If the output contains `TpuDevice`, this means JAX is working as expected.
@@ -482,27 +539,7 @@ print(subkey[1])
 print(subkey[2])
 ```
 
-### 7.3. Serialize model parameters
-
-Normally, the model parameters are represented by a nested dictionary like this:
-
-```python
-{
-    "embedding": DeviceArray,
-    "ff1": {
-        "kernel": DeviceArray,
-        "bias": DeviceArray
-    },
-    "ff2": {
-        "kernel": DeviceArray,
-        "bias": DeviceArray
-    }
-}
-```
-
-You can use [`flax.serialization.msgpack_serialize`](https://flax.readthedocs.io/en/latest/flax.serialization.html#flax.serialization.msgpack_serialize) to serialize the parameters into bytes, and use [`flax.serialization.msgpack_restore`](https://flax.readthedocs.io/en/latest/flax.serialization.html#flax.serialization.msgpack_restore) to convert them back.
-
-### 7.4. Conversion between NumPy arrays and JAX arrays
+### 7.3. Conversion between NumPy arrays and JAX arrays
 
 Use [`np.asarray`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.asarray.html) and [`onp.asarray`](https://numpy.org/doc/stable/reference/generated/numpy.asarray.html).
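A minimal sketch of both conversions, following the repo's import convention (`np` for `jax.numpy`, `onp` for original NumPy); assumes JAX is installed:

```python
import jax
import jax.numpy as np
import numpy as onp

a = onp.array([1.0, 2.0, 3.0])  # a NumPy array on the host
b = np.asarray(a)               # NumPy -> JAX array, placed on the default device
c = onp.asarray(b)              # JAX -> NumPy array, copied back to the host

assert isinstance(b, jax.Array)
assert isinstance(c, onp.ndarray)
```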