
Commit 86bf7ef: Update

1 parent 0318e5f

2 files changed (+176 -120 lines)

README.md (+86 -59)

````diff
@@ -47,15 +47,12 @@ Everything you want to know about Google Cloud TPU
 * [7. JAX Best Practices](#7-jax-best-practices)
 * [7.1. Import convention](#71-import-convention)
 * [7.2. Manage random keys in JAX](#72-manage-random-keys-in-jax)
-* [7.3. Serialize model parameters](#73-serialize-model-parameters)
-* [7.4. Conversion between NumPy arrays and JAX arrays](#74-conversion-between-numpy-arrays-and-jax-arrays)
-* [7.5. Conversion between PyTorch tensors and JAX arrays](#75-conversion-between-pytorch-tensors-and-jax-arrays)
-* [7.6. Type annotation](#76-type-annotation)
-* [7.7. Check if an array is either a NumPy array or a JAX array](#77-check-if-an-array-is-either-a-numpy-array-or-a-jax-array)
-* [7.8. Get the shapes of all parameters in a nested dictionary](#78-get-the-shapes-of-all-parameters-in-a-nested-dictionary)
-* [7.9. The correct way to generate random numbers on CPU](#79-the-correct-way-to-generate-random-numbers-on-cpu)
-* [7.10. Use optimizers from Optax](#710-use-optimizers-from-optax)
-* [7.11. Use the cross-entropy loss implementation from Optax](#711-use-the-cross-entropy-loss-implementation-from-optax)
+* [7.3. Conversion between NumPy arrays and JAX arrays](#73-conversion-between-numpy-arrays-and-jax-arrays)
+* [7.4. Conversion between PyTorch tensors and JAX arrays](#74-conversion-between-pytorch-tensors-and-jax-arrays)
+* [7.5. Get the shapes of all parameters in a nested dictionary](#75-get-the-shapes-of-all-parameters-in-a-nested-dictionary)
+* [7.6. The correct way to generate random numbers on CPU](#76-the-correct-way-to-generate-random-numbers-on-cpu)
+* [7.7. Use optimizers from Optax](#77-use-optimizers-from-optax)
+* [7.8. Use the cross-entropy loss implementation from Optax](#78-use-the-cross-entropy-loss-implementation-from-optax)
 * [8. How Can I...](#8-how-can-i)
 * [8.1. Share files across multiple TPU VM instances](#81-share-files-across-multiple-tpu-vm-instances)
 * [8.2. Monitor TPU usage](#82-monitor-tpu-usage)
````

````diff
@@ -334,7 +331,7 @@ nano ~/.ssh/config
 Add the following content:
 
 ```
-Host 172.21.12.*
+Host 172.21.12.* 127.0.0.1
 StrictHostKeyChecking no
 UserKnownHostsFile /dev/null
 LogLevel ERROR
````

````diff
@@ -352,7 +349,19 @@ chmod 600 ~/.ssh/config
 
 ### 5.6. Add the SSH public key of Host 0 to all hosts
 
-First, follow the above steps to generate a key pair on Host 0. Then add the generated public key to Google Cloud's SSH keys, and this public key will be automatically propagated to all hosts.
+Generate a key pair on Host 0:
+
+```sh
+ssh-keygen -t rsa -f ~/.ssh/id_rsa -N ""
+```
+
+View the generated SSH public key:
+
+```sh
+cat ~/.ssh/id_rsa.pub
+```
+
+Add this public key to the SSH keys in Google Cloud. This key will be automatically propagated to all hosts.
 
 ### 5.7. Configure the `podrun` command
 
````

````diff
@@ -365,17 +374,30 @@ wget https://raw.githubusercontent.com/ayaka14732/llama-2-jax/d8220b8c95789b14fe
 chmod +x podrun
 ```
 
-Save the internal IP addresses of the other hosts in `~/podips.txt` (one per line). To edit `~/podips.txt`, use the following command:
+After downloading, edit this file with nano and replace the `python` on the first line with `python3`.
+
+TODO: Update the source.
+
+Edit `~/podips.txt` using:
 
 ```sh
 nano ~/podips.txt
 ```
 
-Enter venv and install Paramiko:
+Save the internal IP addresses of the other hosts in `~/podips.txt`, one per line. For example:
 
 ```sh
-. ~/venv/bin/activate
-pip install paramiko
+172.21.12.86
+172.21.12.87
+172.21.12.83
+```
+
+A TPU v3-32 includes 4 hosts. Excluding Host 0, there are 3 more hosts. Hence, the `~/podips.txt` for TPU v3-32 should contain 3 IP addresses.
+
+Install Fabric using the system pip3:
+
+```sh
+pip3 install fabric
 ```
 
 Use `podrun` to make all hosts purr like a kitty:
````

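
The commit swaps Paramiko for Fabric as the dependency behind `podrun`. As a rough illustration of what such a helper does, here is a minimal, hypothetical sketch (not the actual `podrun` script downloaded above, which also supports flags such as `-i` and `-c` that this sketch omits): it reads `~/podips.txt` and runs one shell command on every listed host through Fabric.

```python
# minipodrun.py: illustrative sketch only, not the real podrun script.
# Runs a single shell command on every host listed in ~/podips.txt.
import os
import sys

from fabric import ThreadingGroup  # installed above with `pip3 install fabric`


def run_on_pod(command: str) -> None:
    with open(os.path.expanduser('~/podips.txt')) as f:
        hosts = [line.strip() for line in f if line.strip()]
    # ThreadingGroup opens SSH connections to all hosts and runs the command concurrently.
    ThreadingGroup(*hosts).run(command)


if __name__ == '__main__':
    run_on_pod(' '.join(sys.argv[1:]))
```

For example, `python3 minipodrun.py uname -a` would print the kernel version of every other host, assuming the SSH keys from section 5.6 are in place.
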
````diff
@@ -389,8 +411,10 @@ Use `podrun` to make all hosts purr like a kitty:
 Install the NFS server and client:
 
 ```sh
-./podrun -- DEBIAN_FRONTEND=noninteractive sudo apt-get install -y -qq nfs-common
-sudo apt install -y -qq nfs-kernel-server
+./podrun -i -- sudo apt-get update -y -qq
+./podrun -i -- sudo apt-get upgrade -y -qq
+./podrun -- sudo apt-get install -y -qq nfs-common
+sudo apt-get install -y -qq nfs-kernel-server
 sudo mkdir -p /nfs_share
 sudo chown -R nobody:nogroup /nfs_share
 sudo chmod 777 /nfs_share
````

````diff
@@ -416,21 +440,54 @@ sudo systemctl restart nfs-kernel-server
 
 ./podrun -- sudo mkdir -p /nfs_share
 ./podrun -- sudo mount 172.21.12.2:/nfs_share /nfs_share
-./podrun -- ln -sf /nfs_share ~/nfs_share
+./podrun -i -- ln -sf /nfs_share ~/nfs_share
 
-cd ~/nfs_share
-touch meow
-./podrun -iw -- ls ~/nfs_share/meow
+touch ~/nfs_share/meow
+./podrun -i -- ls -la ~/nfs_share/meow
 ```
 
+Replace `172.21.12.2` with the actual internal IP address of Host 0.
+
 ### 5.9. Setting up the development environment in TPU Pod
 
-TODO: Refer to the steps in setting up the development environment in the TPU VM above, but each command should use `podrun -iw --` to run on all hosts.
+Save to `~/nfs_share/setup.sh`:
+
+```sh
+#!/bin/bash
+
+export DEBIAN_FRONTEND=noninteractive
+
+sudo apt-get update -y -qq
+sudo apt-get upgrade -y -qq
+sudo apt-get install -y -qq golang neofetch zsh byobu
+
+sudo apt-get install -y -qq software-properties-common
+sudo add-apt-repository -y ppa:deadsnakes/ppa
+sudo apt-get install -y -qq python3.11-full python3.11-dev
+
+sh -c "$(curl -fsSL https://raw.githubusercontent.com/ohmyzsh/ohmyzsh/master/tools/install.sh)" "" --unattended
+sudo chsh $USER -s /usr/bin/zsh
+
+python3.11 -m venv ~/venv
+
+. ~/venv/bin/activate
+
+pip install -U pip
+pip install -U wheel
+pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
+```
+
+Then execute:
+
+```sh
+chmod +x ~/nfs_share/setup.sh
+./podrun -i ~/nfs_share/setup.sh
+```
 
 ### 5.10. Verify JAX is working properly
 
 ```sh
-~/podrun -icw -- ~/venv/bin/python -c 'import jax; jax.distributed.initialize(); jax.process_index() == 0 and print(jax.devices())'
+./podrun -ic -- ~/venv/bin/python -c 'import jax; jax.distributed.initialize(); jax.process_index() == 0 and print(jax.devices())'
 ```
 
 If the output contains `TpuDevice`, this means JAX is working as expected.
````

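
The verification one-liner in 5.10 prints the device list from process 0. As a slightly more detailed check, here is a hedged sketch (the file name and expected counts are assumptions for a TPU v3-32 pod, which has 4 hosts with 8 TPU cores each) that could be saved to the NFS share and launched the same way:

```python
# check_tpu.py: illustrative verification sketch, not part of the original README.
import jax

# Must be called on every host of the pod before any other JAX operation.
jax.distributed.initialize()

if jax.process_index() == 0:
    print('process count:', jax.process_count())       # expected: 4 on a v3-32
    print('local devices:', jax.local_device_count())  # expected: 8 per host
    print('global devices:', jax.device_count())       # expected: 32 in total
    print(jax.devices())                               # should list TpuDevice entries
```

It could be launched with something like `./podrun -ic -- ~/venv/bin/python ~/nfs_share/check_tpu.py`.
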
````diff
@@ -482,27 +539,7 @@ print(subkey[1])
 print(subkey[2])
 ```
 
-### 7.3. Serialize model parameters
-
-Normally, the model parameters are represented by a nested dictionary like this:
-
-```python
-{
-    "embedding": DeviceArray,
-    "ff1": {
-        "kernel": DeviceArray,
-        "bias": DeviceArray
-    },
-    "ff2": {
-        "kernel": DeviceArray,
-        "bias": DeviceArray
-    }
-}
-```
-
-You can use [`flax.serialization.msgpack_serialize`](https://flax.readthedocs.io/en/latest/flax.serialization.html#flax.serialization.msgpack_serialize) to serialize the parameters into bytes, and use [`flax.serialization.msgpack_restore`](https://flax.readthedocs.io/en/latest/flax.serialization.html#flax.serialization.msgpack_serialize) to convert them back.
-
-### 7.4. Conversion between NumPy arrays and JAX arrays
+### 7.3. Conversion between NumPy arrays and JAX arrays
 
 Use [`np.asarray`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.asarray.html) and [`onp.asarray`](https://numpy.org/doc/stable/reference/generated/numpy.asarray.html).
 
````

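
The renumbered section 7.3 relies on the import convention from section 7.1 (`np` is `jax.numpy`, `onp` is the original NumPy). For convenience, here is a self-contained restatement of the README's round-trip snippet with the imports spelled out:

```python
import jax.numpy as np   # import convention from section 7.1
import numpy as onp

a = np.array([1, 2, 3])   # JAX array
b = onp.asarray(a)        # converted to a NumPy array

c = onp.array([1, 2, 3])  # NumPy array
d = np.asarray(c)         # converted to a JAX array

print(type(a), type(b), type(c), type(d))
```
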
````diff
@@ -517,7 +554,7 @@ c = onp.array([1, 2, 3]) # NumPy array
 d = np.asarray(c) # converted to JAX array
 ```
 
-### 7.5. Conversion between PyTorch tensors and JAX arrays
+### 7.4. Conversion between PyTorch tensors and JAX arrays
 
 Convert a PyTorch tensor to a JAX array:
 
````

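
The renumbered section 7.4 describes converting through NumPy in both directions. A hedged sketch of that pattern, with illustrative variable names and shapes, looks like this:

```python
import jax.numpy as np
import numpy as onp
import torch

# PyTorch tensor -> JAX array, going through NumPy
t = torch.rand(2, 3)
a = np.asarray(onp.asarray(t))

# JAX array -> PyTorch tensor; onp.array (rather than onp.asarray) makes a
# writable copy, which avoids the UserWarning discussed in the next hunk
b = torch.from_numpy(onp.array(a))
```
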
````diff
@@ -548,23 +585,13 @@ UserWarning: The given NumPy array is not writable, and PyTorch does not support
 
 If you need writable tensors, you can use `onp.array` instead of `onp.asarray` to make a copy of the original array.
 
-### 7.6. Type annotation
-
-[google/jaxtyping](https://github.com/google/jaxtyping)
-
-### 7.7. Check if an array is either a NumPy array or a JAX array
-
-```python
-isinstance(a, (np.ndarray, onp.ndarray))
-```
-
-### 7.8. Get the shapes of all parameters in a nested dictionary
+### 7.5. Get the shapes of all parameters in a nested dictionary
 
 ```python
 jax.tree_map(lambda x: x.shape, params)
 ```
 
-### 7.9. The correct way to generate random numbers on CPU
+### 7.6. The correct way to generate random numbers on CPU
 
 Use the [jax.default_device()](https://jax.readthedocs.io/en/latest/_autosummary/jax.default_device.html) context manager:
 
````

````diff
@@ -581,9 +608,9 @@ with jax.default_device(device_cpu):
 
 See <https://github.com/google/jax/discussions/9691#discussioncomment-3650311>.
 
-### 7.10. Use optimizers from Optax
+### 7.7. Use optimizers from Optax
 
-### 7.11. Use the cross-entropy loss implementation from Optax
+### 7.8. Use the cross-entropy loss implementation from Optax
 
 `optax.softmax_cross_entropy_with_integer_labels`
 
````

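
The renumbered section 7.8 only names the Optax function, so here is a minimal, hedged usage sketch (shapes and values are made up): the logits have shape `(batch, num_classes)` and the labels are plain integer class indices, with no one-hot encoding needed.

```python
import jax
import jax.numpy as np
import optax

logits = jax.random.normal(jax.random.PRNGKey(42), (4, 10))  # (batch, num_classes)
labels = np.array([1, 3, 0, 7])                              # integer class labels

# Returns one loss value per example; average them for a scalar training loss.
loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
print(loss)
```
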