Commit 74bd338: basic BIML files
1 parent c420f8d

8 files changed (+3163 −1 lines)

README.md (+94 −1)

# Behaviorally-Informed Meta-Learning (BIML)

BIML is a meta-learning approach for guiding neural networks toward human-like inductive biases, through either high-level guidance or direct human examples. This code shows how to train and evaluate a sequence-to-sequence (seq2seq) transformer that implements BIML using a form of memory-based meta-learning.
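
At a high level, each training episode packs a support set of study examples and a query into the model's input context. As a rough illustration only, an episode can be flattened into a single source sequence like in the sketch below; the `->` and `|` delimiters and pair format are invented stand-ins, not the encoding used in datasets.py.

```python
# Illustrative sketch: flatten one episode (support pairs + query command)
# into a single source string for the seq2seq transformer. The "->" and "|"
# delimiters are invented; see datasets.py for the real encoding.
def encode_episode(support_pairs, query_command):
    """support_pairs: list of (command, output) string tuples."""
    context = " | ".join(f"{cmd} -> {out}" for cmd, out in support_pairs)
    # The decoder is then trained to produce the query's output sequence.
    return f"{query_command} | {context}"

print(encode_episode([("dax", "RED"), ("lug", "BLUE")], "dax"))
# dax | dax -> RED | lug -> BLUE
```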

### Credits

This repo borrows from the excellent [PyTorch seq2seq tutorial](https://pytorch.org/tutorials/beginner/translation_transformer.html).

### Using the code

**Training a model**

This demo covers a simple retrieval task. To train a model that simply retrieves a query's output from the support set (which contains the query command exactly), you can type:
```python
python train.py --episode_type retrieve --nepochs 10 --fn_out_model net_retrieve.tar
```
which will produce a file `out_models/net_retrieve.tar`.
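
The retrieve episodes themselves are generated in datasets.py. For intuition only, a hypothetical generator for this kind of episode might look like the following; the vocabulary and sizes are invented for illustration.

```python
import random

# Hypothetical sketch of a "retrieve" episode: the query command is guaranteed
# to appear in the support set, so the model only has to copy its output.
# The vocabulary and sizes are invented, not the values used in datasets.py.
def make_retrieve_episode(n_support=4):
    commands = random.sample(["dax", "lug", "wif", "zup", "blicket", "kiki"], n_support)
    support = [(cmd, random.choice(["RED", "BLUE", "GREEN", "YELLOW"])) for cmd in commands]
    query_command, query_output = random.choice(support)  # query is in the support set
    return support, query_command, query_output
```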

Use the `-h` option in train.py to view all arguments:
```
optional arguments:
  -h, --help            show this help message and exit
  --fn_out_model FN_OUT_MODEL
                        *REQUIRED* Filename for saving model checkpoints.
                        Typically ends in .pt
  --dir_model DIR_MODEL
                        Directory for saving model files
  --episode_type EPISODE_TYPE
                        What type of episodes do we want? See datasets.py for
                        options
  --batch_size BATCH_SIZE
                        number of episodes per batch
  --nepochs NEPOCHS     number of training epochs
  --lr LR               learning rate
  --lr_end_factor LR_END_FACTOR
                        factor X for decreasing the learning rate linearly
                        from 1.0*lr to X*lr across training
  --no_lr_warmup        Turn off learning rate warm up (by default, we use 1
                        epoch of warm up)
  --nlayers_encoder NLAYERS_ENCODER
                        number of layers for encoder
  --nlayers_decoder NLAYERS_DECODER
                        number of layers for decoder
  --emb_size EMB_SIZE   size of embedding
  --ff_mult FF_MULT     multiplier for size of the fully-connected layer in
                        transformer
  --dropout DROPOUT     dropout applied to embeddings and transformer
  --act ACT             activation function in the fully-connected layer of
                        the transformer (relu or gelu)
  --save_best           Save the "best model" according to validation loss.
  --save_best_skip SAVE_BEST_SKIP
                        Do not bother saving the "best model" for this
                        fraction of early training
  --resume              Resume training from a previous checkpoint
```
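
For reference, the warmup and decay flags above describe a schedule along the lines of the sketch below. The exact step granularity and defaults live in train.py, so treat the details here as assumptions.

```python
import torch

# Sketch of the schedule the flags describe: one epoch of linear warmup
# (unless --no_lr_warmup), then linear decay from 1.0*lr down to
# lr_end_factor*lr over the rest of training. Step granularity is assumed.
def make_scheduler(optimizer, steps_per_epoch, total_steps, lr_end_factor, warmup=True):
    warmup_steps = steps_per_epoch if warmup else 0
    def lr_lambda(step):
        if step < warmup_steps:
            return (step + 1) / warmup_steps  # ramp up to the base lr
        frac = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 1.0 + frac * (lr_end_factor - 1.0)  # 1.0 -> lr_end_factor
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
```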

**Evaluating a model**

To evaluate the accuracy of this model after training, you can type:
```python
python eval.py --episode_type retrieve --fn_out_model net_retrieve.tar --max
```
You can also evaluate the log-likelihood (`--ll`) and draw samples from the distribution over outputs (`--sample`).
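
Conceptually, `--max` performs ordinary greedy decoding. A minimal sketch, assuming a model callable as `model(src, tgt)` that returns per-position logits (the repo's actual interface may differ):

```python
import torch

# Sketch of greedy decoding: at each step, feed the tokens generated so far
# and take the argmax of the next-token logits. `model`, `sos_idx`, and
# `eos_idx` are stand-ins for the repo's own objects.
@torch.no_grad()
def greedy_decode(model, src, max_len, sos_idx, eos_idx):
    ys = torch.tensor([[sos_idx]])           # start-of-sequence token
    for _ in range(max_len - 1):
        logits = model(src, ys)              # assumed signature: (src, tgt) -> logits
        next_tok = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        ys = torch.cat([ys, next_tok], dim=1)
        if next_tok.item() == eos_idx:
            break
    return ys
```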

Use the `-h` option to view all arguments:
```
optional arguments:
  -h, --help            show this help message and exit
  --fn_out_model FN_OUT_MODEL
                        *REQUIRED*. Filename for loading the model
  --dir_model DIR_MODEL
                        Directory for loading the model file
  --max_length_eval MAX_LENGTH_EVAL
                        Maximum generated sequence length
  --episode_type EPISODE_TYPE
                        What type of episodes do we want? See datasets.py for
                        options
  --dashboard           Show loss curves during training.
  --ll                  Evaluate log-likelihood of validation (val) set
  --max                 Find best outputs for val commands (greedy decoding)
  --sample              Sample outputs for val commands
  --sample_iterative    Sample outputs for val commands iteratively
  --fit_lapse           Fit the lapse rate
  --ll_nrep LL_NREP     Evaluate each episode this many times when computing
                        log-likelihood (needed for stochastic remappings)
  --ll_p_lapse LL_P_LAPSE
                        Lapse rate when evaluating log-likelihoods
  --verbose             Inspect outputs in more detail
```
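
The lapse rate mixes the network's token distribution with a uniform distribution, so the likelihood never collapses to zero on a single mistake. A minimal per-token sketch, assuming the mixture is applied at each output position (see eval.py for the actual computation):

```python
import torch
import torch.nn.functional as F

# Sketch of a lapse model for log-likelihood scoring: with probability
# p_lapse a token is emitted uniformly at random; otherwise it comes from
# the network. How exactly eval.py applies this is an assumption here.
def token_log_likelihood(logits, targets, p_lapse):
    """logits: [seq, vocab]; targets: [seq] of token indices."""
    vocab = logits.size(-1)
    probs = F.softmax(logits, dim=-1)
    mixed = (1.0 - p_lapse) * probs + p_lapse / vocab
    return torch.log(mixed.gather(-1, targets.unsqueeze(-1))).sum()
```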

**Episode types**

See datasets.py for the full set of options. Here are a few key episode types:
- "algebraic+biases" : For training the full BIML model on few-shot grammar induction. Also has a validation set.
- "algebraic_noise" : For training BIML (algebraic only). Also has a validation set.
- "few_shot_gold" : For evaluating BIML and people on the same few-shot learning task. Validation set only.
