Skip to content

Commit aa995e8

Browse files
authored
Merge pull request #21 from Oisin-M/feat/gfn
Feat/gfn
2 parents daa8fa3 + 009f0b4 commit aa995e8

File tree

1 file changed

+211
-0
lines changed

1 file changed

+211
-0
lines changed

tutorials/pooling/gfn.ipynb

+211
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"\n",
8+
"\n",
9+
"# GFN-ROM is a resolution-invariant method for MOR suitable for multifidelity applications.\n",
10+
"\n",
11+
"# For further details, see the [GFN repo](https://github.com/Oisin-M/GFN) and [GFN paper](https://arxiv.org/abs/2406.03569)."
12+
]
13+
},
14+
{
15+
"cell_type": "code",
16+
"execution_count": null,
17+
"metadata": {
18+
"id": "CtOVuaD1oJGf"
19+
},
20+
"outputs": [],
21+
"source": [
22+
"# Install PyTorch\n",
23+
"try:\n",
24+
" import torch\n",
25+
" from torch import nn\n",
26+
"except ImportError:\n",
27+
" !pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118\n",
28+
" import torch\n",
29+
" from torch import nn"
30+
]
31+
},
32+
{
33+
"cell_type": "code",
34+
"execution_count": null,
35+
"metadata": {
36+
"id": "rXVVqRjT0CHr"
37+
},
38+
"outputs": [],
39+
"source": [
40+
"# Clone and import gfn-rom\n",
41+
"import sys\n",
42+
"try:\n",
43+
" from gfn_rom import pde, defaults, preprocessing, initialisation, gfn_rom, train, test, plotting\n",
44+
"except ImportError:\n",
45+
" try:\n",
46+
" sys.path.append('GFN')\n",
47+
" from gfn_rom import pde, defaults, preprocessing, initialisation, gfn_rom, train, test, plotting\n",
48+
" except ImportError:\n",
49+
" !git clone https://github.com/Oisin-M/GFN.git\n",
50+
" from gfn_rom import pde, defaults, preprocessing, initialisation, gfn_rom, train, test, plotting\n",
51+
"\n",
52+
"import numpy as np\n",
53+
"import scipy"
54+
]
55+
},
56+
{
57+
"cell_type": "code",
58+
"execution_count": null,
59+
"metadata": {},
60+
"outputs": [],
61+
"source": [
62+
"pname = 'advection'\n",
63+
"\n",
64+
"# training and test fidelities\n",
65+
"train_fidelities = ['3967']\n",
66+
"test_fidelities = ['3967']\n",
67+
"\n",
68+
"# Naming convention for saving the model\n",
69+
"save_name = ''.join(train_fidelities)"
70+
]
71+
},
72+
{
73+
"cell_type": "code",
74+
"execution_count": null,
75+
"metadata": {},
76+
"outputs": [],
77+
"source": [
78+
"dev = initialisation.set_device()\n",
79+
"initialisation.set_precision(defaults.precision)\n",
80+
"initialisation.create_directories()\n",
81+
"params = torch.tensor(pde.params(pname)).to(dev)\n",
82+
"np.random.seed(defaults.split_seed)\n",
83+
"train_trajs, test_trajs = preprocessing.train_test_split(params, len(train_fidelities), defaults.rate)\n",
84+
"\n",
85+
"xs=scipy.io.loadmat(\"../../dataset/advection_unstructured.mat\")['xx'][:,0]\n",
86+
"ys=scipy.io.loadmat(\"../../dataset/advection_unstructured.mat\")['yy'][:,0]\n",
87+
"meshes_train = [np.vstack([xs,ys]).T]\n",
88+
"meshes_test = meshes_train\n",
89+
"def get_scaled_data(fname=\"../../dataset/advection_unstructured.mat\"):\n",
90+
" U = scipy.io.loadmat(fname)['U']\n",
91+
" U_orig = torch.tensor(U)\n",
92+
" scale, U_sc = preprocessing.scaling(U_orig)\n",
93+
" print('reconstruction error', ((U_orig - preprocessing.undo_scaling(U_sc, scale))**2).sum())\n",
94+
" return scale, U_sc\n",
95+
"sols_train = [get_scaled_data()[1]]\n",
96+
"sols_test = [get_scaled_data()]\n",
97+
"\n",
98+
"sols_train = [x.to(dev) for x in sols_train]\n",
99+
"initialisation.set_seed(defaults.seed)\n",
100+
"start_mesh = sorted(meshes_train, key=lambda x: x.shape[0])[-1]\n",
101+
"update_master = defaults.mode == 'adapt'"
102+
]
103+
},
104+
{
105+
"cell_type": "code",
106+
"execution_count": null,
107+
"metadata": {},
108+
"outputs": [],
109+
"source": [
110+
"model = gfn_rom.GFN_ROM(start_mesh, defaults.N_basis_factor, params.shape[1], defaults.act, defaults.ae_sizes, defaults.mapper_sizes).to(dev)\n",
111+
"print(model.GFN.mesh_m.shape)\n",
112+
"\n",
113+
"# We do all of the possible expansions apriori in the preadaptive case\n",
114+
"# This is a preprocessing step so we don't do any speedup steps here\n",
115+
"if defaults.mode=='preadapt':\n",
116+
" count = np.inf\n",
117+
" while count!=0:\n",
118+
" count = 0\n",
119+
" for mesh_n in meshes_train:\n",
120+
" n_exp, n_agg = model.GFN.reshape_weights(mesh_n, update_master=True)\n",
121+
" count += n_exp\n",
122+
" print(model.GFN.mesh_m.shape)"
123+
]
124+
},
125+
{
126+
"cell_type": "code",
127+
"execution_count": null,
128+
"metadata": {},
129+
"outputs": [],
130+
"source": [
131+
"if not update_master:\n",
132+
" opt = torch.optim.Adam(model.parameters(), lr=defaults.lr, weight_decay=defaults.lambda_)\n",
133+
"else:\n",
134+
" # Cannot update GFN parameters using Adam anymore since we use adaptive method\n",
135+
" # and weights can change shape at each iteration\n",
136+
" # Similarly, cannot use momentum\n",
137+
" opt = torch.optim.SGD(model.parameters(), lr=defaults.lr, weight_decay=defaults.lambda_)"
138+
]
139+
},
140+
{
141+
"cell_type": "code",
142+
"execution_count": null,
143+
"metadata": {},
144+
"outputs": [],
145+
"source": [
146+
"try:\n",
147+
" model.load_state_dict(torch.load(\"models/best_model_\"+save_name+\".pt\"))\n",
148+
" print(\"Loading saved network\")\n",
149+
"except FileNotFoundError:\n",
150+
" print(\"Training network\")\n",
151+
" train_losses, test_losses = train.train(model, opt, meshes_train, sols_train, params, train_trajs, test_trajs, update_master, defaults.epochs, defaults.mapper_weight, save_name)\n",
152+
" model.load_state_dict(torch.load(\"models/best_model_\"+save_name+\".pt\"))\n",
153+
" plotting.plot_losses(train_losses, test_losses, save_name)"
154+
]
155+
},
156+
{
157+
"cell_type": "code",
158+
"execution_count": null,
159+
"metadata": {},
160+
"outputs": [],
161+
"source": [
162+
"for i in range(len(test_fidelities)):\n",
163+
" \n",
164+
" print('-'*40)\n",
165+
" print(f'TEST MESH: {test_fidelities[i]}')\n",
166+
" \n",
167+
" scale, U = sols_test[i]\n",
168+
" U = U.to('cpu')\n",
169+
" mesh = meshes_test[i]\n",
170+
"\n",
171+
" model.eval()\n",
172+
" model.to('cpu')\n",
173+
" \n",
174+
" Z, Z_net, x_enc, x_map = test.evaluate_results(model, mesh, U, scale, params.to('cpu'))\n",
175+
" error = abs(Z - Z_net)\n",
176+
" error, rel_error = test.print_results(Z, Z_net, x_enc, x_map)\n",
177+
"\n",
178+
" np.savetxt('errors/relative_errors_train'+save_name+'_test'+test_fidelities[i]+'.txt', [max(rel_error), sum(rel_error)/len(rel_error), min(rel_error)])\n",
179+
" print()"
180+
]
181+
}
182+
],
183+
"metadata": {
184+
"accelerator": "GPU",
185+
"colab": {
186+
"authorship_tag": "ABX9TyPxps/Yo6EhPBLUacBJicyu",
187+
"gpuType": "T4",
188+
"include_colab_link": true,
189+
"provenance": []
190+
},
191+
"kernelspec": {
192+
"display_name": "Python 3 (ipykernel)",
193+
"language": "python",
194+
"name": "python3"
195+
},
196+
"language_info": {
197+
"codemirror_mode": {
198+
"name": "ipython",
199+
"version": 3
200+
},
201+
"file_extension": ".py",
202+
"mimetype": "text/x-python",
203+
"name": "python",
204+
"nbconvert_exporter": "python",
205+
"pygments_lexer": "ipython3",
206+
"version": "3.10.14"
207+
}
208+
},
209+
"nbformat": 4,
210+
"nbformat_minor": 4
211+
}

0 commit comments

Comments
 (0)