Credits to the pyvene library.
DAS can be applied to any type of neural network. It gives an intuition of which causal mechanisms best describe the relationships between a neural network's inputs.
The authors of DAS provide tutorials on how their library works. The tutorial on aligning a causal model with an MLP was the main source of inspiration for the code of this thesis. The MLP was trained on a hierarchical equality task (Premack 1983). The input is two pairs of objects, and the output is True if both pairs contain two identical objects or if both pairs contain two different objects, and False otherwise. For example, AABB and ABCD are both labeled True, while ABCC and BBCD are both labeled False. To reproduce their results, one can simply run:
python3 reproduce_das_experiment.py
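The labelling rule of the hierarchical equality task can be sketched in a few lines. This is only an illustration of the rule as described above; the function name hierarchical_equality is hypothetical and is not taken from the tutorial or from reproduce_das_experiment.py.

```python
def hierarchical_equality(a, b, c, d):
    """Label for the pair-of-pairs task.

    True when both pairs agree on whether their two objects are equal.
    """
    left_equal = a == b    # does the first pair hold the same object twice?
    right_equal = c == d   # does the second pair hold the same object twice?
    return left_equal == right_equal


for example in ["AABB", "ABCD", "ABCC", "BBCD"]:
    print(example, hierarchical_equality(*example))
```

The two intermediate booleans are the natural candidates for the high-level variables a causal model of this task would expose.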
To further challenge their hypothesis and check that their claims hold, we experiment with a different task. In this case, the pattern ABAB is the one that yields True: the first and third inputs should be the same, and likewise the second and fourth. The notebook can be found here, or one can run:
python3 mlp/pattern_matching_das.py
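A minimal sketch of such a pattern label is given below. The helper matches_pattern is hypothetical and not part of mlp/pattern_matching_das.py. It assumes that positions sharing a letter must hold equal inputs, and that positions with different letters are not required to differ; the actual script may treat the second point differently.

```python
def matches_pattern(inputs, pattern="ABAB"):
    """Return True when positions sharing a letter in `pattern` hold equal inputs.

    Assumption: positions with different letters are NOT required to differ,
    so ("x", "x", "x", "x") also matches "ABAB".
    """
    if len(inputs) != len(pattern):
        raise ValueError("inputs and pattern must have the same length")
    first_seen = {}
    for value, letter in zip(inputs, pattern):
        # Remember the first value seen for each letter, then compare against it.
        if first_seen.setdefault(letter, value) != value:
            return False
    return True


print(matches_pattern("xyxy"))          # first == third and second == fourth
print(matches_pattern("xxyy"))          # violates ABAB
print(matches_pattern("xxyy", "AABB"))  # same inputs under a different pattern
```

Passing the pattern as an argument makes it easy to label the same inputs under ABAB, AABB or ABBA.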
What if DAS is trained on counterfactual data generated from a causal model of ABAB, but is then tested on data generated from causal models of AABB or ABBA? Intuitively, the IIA obtained on data from ABAB should be much higher than on the other datasets. This also serves as a sanity check that DAS actually learns to align a specific causal model with a neural network.
To run this experiment:
python3 mlp/wrong_causal_graph_ablation_study.py
DAS can even be applied to LLMs. Imagine having a simple causal model that gives a clear overview of the causal mechanisms relating the inputs in a prompt. For this thesis, we experiment on GPT-2. One requirement for DAS to find alignments between a causal model and the neural representations is that the neural network (GPT-2 in our case) performs well on the task it tries to solve.
We finetune GPT2ForSequenceClassification to perform well on summing three numbers between 1 and 10 (inclusive). The prompts are of the form X+Y+Z=, where X, Y, Z are inputs randomly generated from 1 to 10. There are 28 possible outputs (the sums 3 to 30) and 1000 possible arrangements of the numbers. The task is to classify the sequence generated by GPT2 after = into one of the 28 possible classes. To finetune GPT2 on this task, run:
python3 train_gpt2.py
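The counts above can be checked by enumerating the task. The sketch below is independent of train_gpt2.py: make_example is a hypothetical helper, and mapping the smallest sum to class index 0 is an assumption about the label encoding.

```python
from itertools import product

LOW, HIGH = 1, 10  # inclusive range of each input


def make_example(x, y, z):
    """Build one prompt and its class index."""
    prompt = f"{x}+{y}+{z}="
    label = (x + y + z) - 3 * LOW  # assumption: class 0 is the smallest sum, 3
    return prompt, label


examples = [
    make_example(x, y, z)
    for x, y, z in product(range(LOW, HIGH + 1), repeat=3)
]
num_classes = len({label for _, label in examples})

print(len(examples), num_classes)
```

The sums range from 3 to 30, which gives the 28 classes, and 10 choices for each of the three inputs give the 1000 prompts.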
After finetuning the model on the arithmetic task, I made it publicly available on HuggingFace. To use this model, pass mara589/arithmetic-gpt2 as the model_path argument.
We define three causal graphs that could abstract the finetuned GPT2. They are shown below, where P is the intermediate variable that sums one pair of input variables (a different pair in each graph). We refer to these graphs as the arithmetic ones.
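To make the role of P concrete, here is a small sketch of the three arithmetic graphs and of the counterfactual label produced by an interchange intervention on P. It assumes the graphs are P=X+Y, P=X+Z and P=Y+Z, each followed by adding the remaining input. All names are hypothetical and are not taken from src/run_das.py.

```python
# Which two inputs feed the intermediate variable P in each graph (assumed).
ARITHMETIC_GRAPHS = {
    "P=X+Y": ("X", "Y"),
    "P=X+Z": ("X", "Z"),
    "P=Y+Z": ("Y", "Z"),
}


def run_causal_model(graph, inputs, p_override=None):
    """Evaluate one graph on `inputs` (a dict with keys X, Y, Z).

    If `p_override` is given, P is set to that value instead of being computed.
    """
    pair = ARITHMETIC_GRAPHS[graph]
    (remaining,) = set("XYZ") - set(pair)  # the input that bypasses P
    if p_override is None:
        p = inputs[pair[0]] + inputs[pair[1]]
    else:
        p = p_override
    return {"P": p, "O": p + inputs[remaining]}


def counterfactual_output(graph, base, source):
    """Output on `base` after P is replaced by the value it takes on `source`."""
    p_source = run_causal_model(graph, source)["P"]
    return run_causal_model(graph, base, p_override=p_source)["O"]


base = {"X": 1, "Y": 2, "Z": 3}
source = {"X": 10, "Y": 5, "Z": 7}
for graph in ARITHMETIC_GRAPHS:
    print(graph, counterfactual_output(graph, base, source))
```

The three graphs agree on every ordinary input, since addition is associative, and only differ in their counterfactual outputs. This is why counterfactual data is needed to tell them apart.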
We also define a simple group of causal graphs, where each graph copies one input variable in turn to an intervenable variable.
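For the simple graphs, an interchange intervention amounts to swapping one input of the base prompt for the corresponding input of the source prompt. A minimal sketch, with the hypothetical helper simple_counterfactual and the assumption that the output is still the sum of the three (patched) inputs:

```python
def simple_counterfactual(copied, base, source):
    """Sum after replacing the copied input of `base` with its value in `source`.

    `copied` is "X", "Y" or "Z": the input the intervenable variable copies.
    """
    patched = dict(base)                # leave the base example untouched
    patched[copied] = source[copied]    # interchange only the copied variable
    return sum(patched.values())


base = {"X": 1, "Y": 2, "Z": 3}
source = {"X": 10, "Y": 5, "Z": 7}
for variable in "XYZ":
    print(variable, simple_counterfactual(variable, base, source))
```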
An intervenable model is trained for each of the 12 layers of the LLM and for each targeted subspace size in [64, 128, 256, 768, 4608], referred to as the low rank dimension. All tokens are targeted.
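As a rough sketch of the resulting training grid, the 12 layers and five low rank dimensions give 60 (layer, dimension) settings per causal model. The names below are illustrative and do not come from src/run_das.py, and numbering layers from 0 is an assumption.

```python
from itertools import product

NUM_LAYERS = 12                                   # layers of the finetuned GPT-2
LOW_RANK_DIMENSIONS = [64, 128, 256, 768, 4608]   # subspace sizes listed above

# One intervenable model per (layer, low rank dimension) pair.
configs = list(product(range(NUM_LAYERS), LOW_RANK_DIMENSIONS))

print(len(configs))
print(configs[:3])
```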
To train each intervenable model using the arithmetic causal models, run:
python3 src/run_das.py --model_path mara589/arithmetic-gpt2 --causal_model_type arithmetic --n_training 256000 --n_testing 256 --batch_size 1280 --epochs 4
To train each intervenable model using the simple causal models, run:
python3 src/run_das.py --model_path mara589/arithmetic-gpt2 --causal_model_type simple --n_training 256000 --n_testing 256 --batch_size 1280 --epochs 4
Interchange Intervention Accuracy (IIA) is used as the evaluation metric for this project. To evaluate the trained intervenable models, run:
python3 evaluate_das.py --model_path mara589/arithmetic-gpt2 --results_path results/ --n_testing 25600 --batch_size 256 --causal_model_type simple
python3 evaluate_das.py --model_path mara589/arithmetic-gpt2 --results_path results/ --n_testing 25600 --batch_size 256 --causal_model_type arithmetic
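IIA is the fraction of interchange interventions for which the network's output under intervention equals the counterfactual output of the causal model. A minimal sketch of the metric itself, separate from evaluate_das.py (the function name and list-based interface are assumptions):

```python
def interchange_intervention_accuracy(intervened_predictions, counterfactual_labels):
    """Fraction of interventions where the network matches the causal model.

    intervened_predictions: network outputs after the interchange intervention.
    counterfactual_labels: outputs of the causal model under the same intervention.
    """
    if len(intervened_predictions) != len(counterfactual_labels):
        raise ValueError("both sequences must have the same length")
    if not counterfactual_labels:
        raise ValueError("at least one intervention is required")
    matches = sum(
        prediction == label
        for prediction, label in zip(intervened_predictions, counterfactual_labels)
    )
    return matches / len(counterfactual_labels)


print(interchange_intervention_accuracy([12, 15, 7], [12, 14, 7]))
```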
After training the intervenable models for each layer and low rank dimension listed in the previous section, run a sanity check experiment similar to the one in the MLP section. It is sufficient to run the sanity check on the arithmetic causal models.
python3 visualizations.py --model_path mara589/arithmetic-gpt2 --causal_model_type arithmetic --results_path results/ --experiment sanity_check
We want to check where each variable lives when we use intervenable variables which are only copies of the input variables. After training the intervenable models when aligning the simple causal models with the LLM, one can check the IIA per layer and low rank dimension. To reproduce our plots, run this command:
python3 visualizations.py --model_path mara589/arithmetic-gpt2 --causal_model_type simple --results_path results/ --experiment empirical
Run the following command to obtain 36 graphs weighted by the IIA between any two data points:
python3 disentangle_causal_models.py --model_path mara589/arithmetic-gpt2 --results_path results/ --causal_model_type arithmetic
Run the following command to analyse the cliques in the graphs obtained previously:
python3 analyse_graphs.py --results_path results/
You can switch the clique finder with one defined in clique_finders.py. The available options are:
ExhaustiveCliqueFinder - Algorithm 457: finding all cliques of an undirected graph, returning all maximal cliques; this finder filters by maximum clique length to obtain all maximum cliques
DegreeHeuristic - classical removal heuristic based on node degree
RemovalHeuristic - Approximating maximum independent sets by excluding subgraphs + removal heuristic
BranchAndBoundHeuristic - Solution of Maximum Clique Problem by Using Branch and Bound Method
MaxCliqueHeuristic - Approximating maximum independent sets by excluding subgraphs
MaxCliqueHeuristic_v2 - Listing All Maximal Cliques in Large Sparse Real-World Graphs
The clique finder used in our project is RemovalHeuristic.
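To give a feel for what a removal heuristic does, the sketch below repeatedly deletes the node with the smallest degree until the remaining nodes are pairwise adjacent. It illustrates the general degree-based idea only. It is not the implementation in clique_finders.py, it works on an unweighted adjacency structure, and the function name is hypothetical.

```python
def greedy_clique_by_removal(adjacency):
    """Approximate a large clique by removing low-degree nodes.

    adjacency: dict mapping each node to the set of its neighbours
    (undirected, symmetric, no self loops). The result is always a clique,
    but it is not guaranteed to be a maximum clique.
    """
    remaining = {node: set(neighbours) for node, neighbours in adjacency.items()}
    while remaining:
        # In a clique, every node is adjacent to all other remaining nodes.
        if all(len(n) == len(remaining) - 1 for n in remaining.values()):
            break
        # Remove the node with the fewest neighbours (ties broken by name).
        victim = min(remaining, key=lambda node: (len(remaining[node]), str(node)))
        for neighbour in remaining.pop(victim):
            remaining[neighbour].discard(victim)
    return set(remaining)


graph = {
    "a": {"b", "c"},
    "b": {"a", "c"},
    "c": {"a", "b", "d"},
    "d": {"c"},
}
print(sorted(greedy_clique_by_removal(graph)))
```

Such heuristics trade optimality for speed, which matters because finding a maximum clique exactly is NP-hard.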
Based on the cliques found previously, we define a DecisionTree as a binary classifier to obtain rules about the specific data that is part of the cliques. Run the following command to train a classifier on the clique data obtained for each layer:
python3 classification.py --results_path results/
In this way we obtain interpretable rules about how the LLM reasons when processing different clusters of input. For example, for inputs that are permutations of {7,8,9,10}, none of the causal models defined exhibits a high abstraction level. This suggests that the LLM may use a different causal mechanism to compute the sum of three numbers when the inputs are large.