Combining Causal Models for More Accurate Abstractions of Neural Networks

Credits to the pyvene library.

DAS on MLP

DAS can be applied to any type of neural network. It offers intuition about which causal mechanisms best describe the relationships between the inputs of a neural network.

Reproducing DAS results

The authors of DAS provide tutorials on how their library works. The tutorial on aligning a causal model with an MLP was the main source of inspiration for the code in this thesis. The MLP was trained on a hierarchical equality task (Premack, 1983). The input is two pairs of objects, and the output is True if both pairs contain the same object or if both pairs contain different objects, and False otherwise. For example, AABB and ABCD are both labeled True, while ABCC and BBCD are both labeled False. To reproduce their results, simply run:

python3 reproduce_das_experiment.py
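
For reference, the labeling rule can be written as a small Python function (label_equality is a name chosen here for illustration, not necessarily the one used in the tutorial):

def label_equality(a, b, c, d):
    # Hierarchical equality: True iff both pairs agree on (in)equality.
    return (a == b) == (c == d)

assert label_equality(*"AABB")        # both pairs equal -> True
assert label_equality(*"ABCD")        # both pairs unequal -> True
assert not label_equality(*"ABCC")    # mixed -> False
assert not label_equality(*"BBCD")    # mixed -> False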

Experimenting with another task

To further challenge their hypothesis and check that their claims hold, we experiment with a different task. Here the pattern ABAB is the one that yields True: the first and third inputs must be the same, and likewise the second and fourth. The corresponding notebook is included in the repository, or one can run:

python3 mlp/pattern_matching_das.py
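
The new labeling rule, sketched analogously (label_abab is an illustrative name):

def label_abab(a, b, c, d):
    # ABAB pattern: first equals third and second equals fourth.
    return a == c and b == d

assert label_abab(*"ABAB")        # True
assert not label_abab(*"AABB")    # False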

Ablation study: wrong causal model?

What if DAS is trained on counterfactual data generated from a causal model that encodes ABAB, but is then tested on data generated from causal models encoding AABB or ABBA? Intuitively, the interchange intervention accuracy (IIA) obtained on data from ABAB should be much higher than on the other datasets. This can also be seen as a sanity check of whether DAS actually learns to align a specific causal model with the neural network.

To reproduce this experiment, run:

python3 mlp/wrong_causal_graph_ablation_study.py
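
For reference, the three candidate labeling rules differ only in which positions must match; a minimal sketch (the script defines its own causal models):

# Labeling rules of the three candidate causal models; DAS is trained on
# counterfactuals from ABAB only, then evaluated against all three.
patterns = {
    "ABAB": lambda a, b, c, d: a == c and b == d,
    "AABB": lambda a, b, c, d: a == b and c == d,
    "ABBA": lambda a, b, c, d: a == d and b == c,
}
# Expected outcome: the IIA measured with ABAB counterfactuals is clearly
# higher than the IIA measured with AABB or ABBA counterfactuals.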

DAS on LLMs

DAS can even be applied to LLMs. Imagine having a simple causal model that gives a clear overview of the causal mechanisms entailed by the inputs in a prompt. For this thesis, we experiment with GPT-2. One requirement for DAS to find alignments between a causal model and the neural representations is that the neural network (GPT-2 in our case) performs well on the task it is trying to solve.

Finetuning GPT-2 on an arithmetic task

We finetune GPT2ForSequenceClassification to sum three numbers between 1 and 10 (inclusive). The prompts are of the form X+Y+Z=, where X, Y, and Z are integers sampled uniformly at random from 1 to 10. There are 28 possible outputs (the sums 3 through 30) and 1000 possible arrangements of the numbers. The task is to classify the prompt sequence into one of the 28 possible classes, i.e., to predict the value of the sum that follows =. To finetune GPT-2 on this task, run:

python3 train_gpt2.py
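
A minimal sketch of how such prompts and class labels could be generated (illustrative only; train_gpt2.py has its own data pipeline):

import random

def make_example():
    # Return a prompt of the form 'X+Y+Z=' and its class label.
    x, y, z = (random.randint(1, 10) for _ in range(3))
    prompt = f"{x}+{y}+{z}="
    label = x + y + z - 3  # sums range from 3 to 30, i.e. 28 classes (0..27)
    return prompt, label

print(make_example())  # e.g. ('7+2+9=', 15)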

After finetuning the model on the arithmetic task, I made it publicly available on Hugging Face. To reuse this model, pass mara589/arithmetic-gpt2 as the model_path argument; the model used will then be the one I finetuned.
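
For example, the finetuned checkpoint can be loaded directly with the transformers library; a sketch, assuming the checkpoint also ships a tokenizer and exposes the 28 labels described above:

from transformers import AutoTokenizer, GPT2ForSequenceClassification

model_path = "mara589/arithmetic-gpt2"
# If the checkpoint does not ship tokenizer files, fall back to the base "gpt2" tokenizer.
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = GPT2ForSequenceClassification.from_pretrained(model_path)

inputs = tokenizer("3+5+2=", return_tensors="pt")
logits = model(**inputs).logits       # shape (1, num_labels)
print(logits.argmax(-1).item())       # predicted class index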

Training Intervenable Models

We define three possible causal graphs that can abstract the finetuned GPT-2. They are depicted below, where P is the intermediate variable summing one pair of input variables. We refer to these graphs as the arithmetic ones.

We also define a simple group of causal graphs, where each graph in turn copies one input variable to an intervenable variable.
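
The two families of causal models can be sketched as plain Python functions over the inputs (an illustrative rendering, not the CausalModel objects used in the code):

# Arithmetic causal graphs: P sums one pair of inputs, and the output adds
# the remaining input; there is one graph per choice of pair.
def arithmetic_xy(x, y, z):
    P = x + y
    return P + z

def arithmetic_xz(x, y, z):
    P = x + z
    return P + y

def arithmetic_yz(x, y, z):
    P = y + z
    return P + x

# Simple causal graphs: each graph in turn copies one input variable into the
# intervenable variable (shown for X; the Y and Z variants are analogous).
def simple_x(x, y, z):
    V = x               # intervenable variable is a plain copy of X
    return x + y + z    # the task output is assumed here to remain the sum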

An intervenable model is trained for each of the 12 layers of the LLM, targeting a subspace whose size is given by each of the values in [64, 128, 256, 768, 4608], referred to as the low-rank dimension. All tokens are targeted.

To train each intervenable model using the arithmetic causal models, run:

python3 src/run_das.py --model_path mara589/arithmetic-gpt2 --causal_model_type arithmetic --n_training 256000 --n_testing 256 --batch_size 1280 --epochs 4

To train each intervenable model using the simple causal models, run:

python3 src/run_das.py --model_path mara589/arithmetic-gpt2 --causal_model_type simple --n_training 256000 --n_testing 256 --batch_size 1280 --epochs 4

Evaluating Trained Intervenable Models

Interchange Intervention Accuracy (IIA) is used as the evaluation metric for this project.
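
Informally, IIA is the fraction of interchange interventions for which the intervened neural network produces the same output as the intervened causal model; a minimal sketch:

def interchange_intervention_accuracy(pairs):
    # pairs: (neural_output, causal_model_output) obtained after applying the
    # same interchange intervention to the network and to the causal model.
    pairs = list(pairs)
    hits = sum(nn_out == cm_out for nn_out, cm_out in pairs)
    return hits / len(pairs)

print(interchange_intervention_accuracy([(12, 12), (17, 15), (9, 9), (20, 20)]))  # 0.75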

Evaluating on the simple causal models:

python3 evaluate_das.py --model_path mara589/arithmetic-gpt2 --results_path results/ --n_testing 25600 --batch_size 256 --causal_model_type simple

Evaluating on the arithmetic causal models:

python3 evaluate_das.py --model_path mara589/arithmetic-gpt2 --results_path results/ --n_testing 25600 --batch_size 256 --causal_model_type arithmetic

Interpretability Analysis

Evaluating with causal models other than the trained one

After training the intervenable models for each layer and low-rank dimension listed in the previous section, run a sanity-check experiment similar to the one in the MLP section. It is sufficient to run the sanity check on the arithmetic causal models.

python3 visualizations.py --model_path mara589/arithmetic-gpt2 --causal_model_type arithmetic --results_path results/ --experiment sanity_check

Where does the input/output live?

We want to check where each variable lives when the intervenable variables are simply copies of the input variables. After training the intervenable models that align the simple causal models with the LLM, one can inspect the IIA per layer and low-rank dimension. To reproduce our plots, run:

python3 visualizations.py --model_path mara589/arithmetic-gpt2 --causal_model_type simple --results_path results/ --experiment empirical

Learning Rules on the Mixture of Causal Models

Run the following command to obtain 36 graphs whose edges are weighted by the IIA between any two data points:

python3 disentangle_causal_models.py --model_path mara589/arithmetic-gpt2 --results_path results/ --causal_model_type arithmetic
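
Conceptually, each graph has the evaluated data points as nodes and edges weighted by the IIA measured when interchanging them; a sketch of such a construction with networkx (the script's actual construction may differ):

import networkx as nx

def build_iia_graph(iia_scores):
    # iia_scores: dict mapping a pair of data points to the IIA measured
    # when interchanging them, e.g. {("3+5+2=", "4+4+2="): 0.9, ...}
    G = nx.Graph()
    for (i, j), iia in iia_scores.items():
        G.add_edge(i, j, weight=iia)
    return G

G = build_iia_graph({("3+5+2=", "4+4+2="): 0.9, ("3+5+2=", "9+8+7="): 0.2})
print(G.number_of_edges())  # 2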

Clique analysis

Run the following command to analyse the cliques in the graphs obtained previously:

python3 analyse_graphs.py --results_path results/

You can switch the clique finder with any of those defined in clique_finders.py.

The clique finder used in our project is RemovalHeuristic.
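
For intuition, one standard way to enumerate cliques after thresholding the IIA-weighted edges uses networkx; this is only an illustration, not the RemovalHeuristic implemented in clique_finders.py:

import networkx as nx

def cliques_above_threshold(G, threshold=0.8):
    # Keep only edges whose IIA weight reaches the threshold, then enumerate
    # the maximal cliques of the remaining graph.
    H = nx.Graph()
    H.add_edges_from(
        (u, v, d) for u, v, d in G.edges(data=True) if d["weight"] >= threshold
    )
    return list(nx.find_cliques(H))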

Obtaining Interpretable Rules

Based on the cliques found previously, we train a DecisionTree binary classifier to obtain rules about the specific data points that belong to each clique. Run the following command to train one classifier per clique, per layer:

python3 classification.py --results_path results/
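
A minimal sketch of the idea with scikit-learn (illustrative; classification.py defines its own features and training loop):

from sklearn.tree import DecisionTreeClassifier, export_text

# Features: the three summands of each prompt; label: whether the prompt
# belongs to the clique under consideration (binary).
X = [[3, 5, 2], [4, 4, 2], [9, 8, 7], [10, 9, 8]]
y = [1, 1, 0, 0]

clf = DecisionTreeClassifier(max_depth=2).fit(X, y)
print(export_text(clf, feature_names=["x", "y", "z"]))  # human-readable rules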

In this way, we obtain interpretable rules about how the LLM reasons when processing different clusters of inputs. For example, for inputs that are permutations of {7,8,9,10}, none of the defined causal models exhibits a high abstraction level, so the LLM possibly uses a different causal mechanism to sum three numbers when the inputs are large.
