diff --git a/docs/zh/api/loss.md b/docs/zh/api/loss.md
index 1a593f921..caa04f6df 100644
--- a/docs/zh/api/loss.md
+++ b/docs/zh/api/loss.md
@@ -9,6 +9,7 @@
         - L1Loss
         - L2Loss
         - L2RelLoss
+        - MAELoss
         - MSELoss
         - MSELossWithL2Decay
         - IntegralLoss
diff --git a/docs/zh/examples/epnn.md b/docs/zh/examples/epnn.md
new file mode 100644
index 000000000..74c851988
--- /dev/null
+++ b/docs/zh/examples/epnn.md
@@ -0,0 +1,210 @@
+# EPNN
+
+=== "模型训练命令"
+
+    ``` sh
+    # linux
+    wget https://paddle-org.bj.bcebos.com/paddlescience/datasets/dstate-16-plas.dat -O datasets/dstate-16-plas.dat
+    wget https://paddle-org.bj.bcebos.com/paddlescience/datasets/dstress-16-plas.dat -O datasets/dstress-16-plas.dat
+    # windows
+    # curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/dstate-16-plas.dat --output datasets/dstate-16-plas.dat
+    # curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/dstress-16-plas.dat --output datasets/dstress-16-plas.dat
+    python epnn.py
+    ```
+
+=== "模型评估命令"
+
+    ``` sh
+    # linux
+    wget https://paddle-org.bj.bcebos.com/paddlescience/datasets/dstate-16-plas.dat -O datasets/dstate-16-plas.dat
+    wget https://paddle-org.bj.bcebos.com/paddlescience/datasets/dstress-16-plas.dat -O datasets/dstress-16-plas.dat
+    # windows
+    # curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/dstate-16-plas.dat --output datasets/dstate-16-plas.dat
+    # curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/dstress-16-plas.dat --output datasets/dstress-16-plas.dat
+    python epnn.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/epnn/epnn_pretrained.pdparams
+    ```
+
+## 1. 背景简介
+
+本案例复现 Elasto-Plastic Neural Network (EPNN),这是一种 Physics-Informed Neural Network (PINN) 代理模型。将弹塑性物理规律嵌入神经网络架构后,网络可以用更少的数据完成更有效的训练,并增强对训练数据之外加载工况的外推能力。EPNN 的架构与具体本构模型和材料无关,可以适配包括地质材料和金属在内的多种弹塑性材料,并且实验数据可以直接用于训练网络。为了证明该架构的稳健性,原论文将其一般框架应用于砂土的弹塑性行为:在预测不同初始密度砂土的未观测应变控制加载路径方面,EPNN 优于常规神经网络架构。
+
+## 2. 问题定义
+
+在神经网络中,信息沿相互连接的神经元流动,每条连接的“强度”由可训练的权重决定:
+
+$$
+z_l^{\mathrm{i}}=W_{k l}^{\mathrm{i}-1, \mathrm{i}} a_k^{\mathrm{i}-1}+b^{\mathrm{i}-1}, \quad k=1: N^{\mathrm{i}-1} \quad \text { or } \quad \mathbf{z}^{\mathrm{i}}=\mathbf{a}^{\mathrm{i}-1} \mathbf{W}^{\mathrm{i}-1, \mathrm{i}}+b^{\mathrm{i}-1} \mathbf{I}
+$$
+
+其中 $b$ 是偏置项;$N$ 为各层的神经元数量;$\mathbf{I}$ 指所有元素均为 1 的向量。
+
+## 3. 问题求解
+
+接下来开始讲解如何将问题一步一步地转化为 PaddleScience 代码,用深度学习的方法求解该问题。
+为了快速理解 PaddleScience,接下来仅对模型构建、方程构建、计算域构建等关键步骤进行阐述,而其余细节请参考 [API文档](../api/arch.md)。
+
+### 3.1 模型构建
+
+在 EPNN 问题中,使用如下代码建立网络:
+
+``` py linenums="370"
+--8<--
+examples/epnn/functions.py:370:390
+--8<--
+```
+
+Epnn 的参数 input_keys 是输入字段名,output_keys 是输出字段名,node_sizes 是各层节点数列表,activations 是激活函数名称列表,drop_p 是节点丢弃概率。
+
+### 3.2 数据生成
+
+本案例涉及读取数据生成,如下所示
+
+``` py linenums="36"
+--8<--
+examples/epnn/epnn.py:36:41
+--8<--
+```
+
+``` py linenums="305"
+--8<--
+examples/epnn/functions.py:305:320
+--8<--
+```
+
+这里先使用 Data 读取文件并构造数据类,再调用 get_shuffled_data 打乱数据;随后计算每组需要获取的数据量 itrain,最后使用 get 获取 10 组、每组 itrain 条的数据。
+
+### 3.3 约束构建
+
+设置训练数据集、损失计算函数和返回字段,代码如下所示:
+
+``` py linenums="63"
+--8<--
+examples/epnn/epnn.py:63:86
+--8<--
+```
+
+`SupervisedConstraint` 的第一个参数是监督约束的读取配置,配置中 `"dataset"` 字段表示使用的训练数据集信息,其各个字段分别表示:
+
+1. `name`: 数据集类型,此处 `"NamedArrayDataset"` 表示顺序读取的数据集;
+2. `input`: 输入数据集;
+3. `label`: 标签数据集;
+
+第二个参数是损失函数,此处使用自定义函数 `train_loss_func`。
+
+第三个参数是方程表达式,用于描述如何计算约束目标,计算后的值将会按照指定名称存入输出列表中,从而保证 loss 计算时可以使用这些值。
+
+第四个参数是约束条件的名字,我们需要给每一个约束条件命名,方便后续对其索引。
+
+在约束构建完毕之后,以我们刚才的命名为关键字,封装到一个字典中,方便后续访问。
+
+### 3.4 评估器构建
+
+与约束同理,本问题使用 `ppsci.validate.SupervisedValidator` 构建评估器,参数含义也与[约束构建](#33)类似,唯一的区别是评价指标 `metric`。代码如下所示:
+
+``` py linenums="88"
+--8<--
+examples/epnn/epnn.py:88:103
+--8<--
+```
+
+### 3.5 超参数设定
+
+接下来需要指定训练轮数。此处按实验经验将其设置为 10000,并将 iters_per_epoch 设置为 1。
+
+``` yaml linenums="40"
+--8<--
+examples/epnn/conf/epnn.yaml:40:41
+--8<--
+```
+
+### 3.6 优化器构建
+
+训练过程会调用优化器来更新模型参数,此处选择较为常用的 `Adam` 优化器,并配合使用 ExponentialDecay 学习率衰减策略。
+
+由于本案例包含多个模型,因此需要设置多个优化器。首先为 Epnn 网络部分设置 `Adam` 优化器。
+
+``` py linenums="395"
+--8<--
+examples/epnn/functions.py:395:403
+--8<--
+```
+
+然后为额外引入的可训练参数 gkratio 单独设置优化器。
+
+``` py linenums="405"
+--8<--
+examples/epnn/functions.py:405:412
+--8<--
+```
+
+这些优化器将按顺序依次执行优化,代码汇总为:
+
+``` py linenums="395"
+--8<--
+examples/epnn/functions.py:395:413
+--8<--
+```
+
+### 3.7 自定义 loss
+
+由于本问题包含无监督学习,数据中不存在标签数据,loss 需要根据模型返回的数据计算得到,因此需要自定义 loss。方法为先定义相关函数,再将函数名作为参数传给 `FunctionalLoss` 和 `FunctionalMetric`。
+
+需要注意,自定义 loss 函数的输入输出需要与 PaddleScience 中 `MSELoss` 等内置 loss 保持一致:输入为模型输出 `output_dict` 等字典变量,输出为 loss 值 `paddle.Tensor`。
+
+相关的自定义 loss 函数内部使用 `MAELoss` 计算误差(结果展示一节之后附有一个可独立运行的最小示例),代码为
+
+``` py linenums="113"
+--8<--
+examples/epnn/functions.py:113:125
+--8<--
+```
+
+### 3.8 模型训练与评估
+
+完成上述设置之后,只需要将上述实例化的对象按顺序传递给 `ppsci.solver.Solver`。
+
+``` py linenums="106"
+--8<--
+examples/epnn/epnn.py:106:118
+--8<--
+```
+
+训练时将 eval_during_train 设置为 True,即每轮训练结束后都会进行一次评估。
+
+``` yaml linenums="43"
+--8<--
+examples/epnn/conf/epnn.yaml:43:43
+--8<--
+```
+
+最后启动训练即可:
+
+``` py linenums="121"
+--8<--
+examples/epnn/epnn.py:121:121
+--8<--
+```
+
+## 4. 完整代码
+
+``` py linenums="1" title="epnn.py"
+--8<--
+examples/epnn/epnn.py
+--8<--
+```
+
+## 5. 结果展示
+
+EPNN 案例在 epoch=10000 的参数配置下进行了实验,最终训练 Loss 为 0.00471。
+
+下图展示了 Loss、Training error 与 Cross validation error 随 epoch 的变化曲线:
+
+<figure markdown>
+  ![loss_trend](epnn_images/loss_trend.png){ loading=lazy }
+  <figcaption>训练 loss 图形</figcaption>
+</figure>
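+
+作为补充,下面给出一个可独立运行的最小示例,演示 3.7 节中“自定义函数 + `FunctionalLoss`”的写法。示例中的字段名 `u`、`u_label` 为演示用的假设名称,并非本案例的真实字段:
+
+``` py
+import paddle
+
+import ppsci
+
+
+def demo_loss_func(output_dict, *args) -> paddle.Tensor:
+    # 自定义 loss:输入为字典变量,输出为 paddle.Tensor,
+    # 内部复用 ppsci.loss.MAELoss 计算平均绝对误差
+    pred = {"u": output_dict["u"]}
+    label = {"u": output_dict["u_label"]}
+    return ppsci.loss.MAELoss("mean")(pred, label)
+
+
+# 包装后即可像内置 loss 一样传入约束或评估器
+loss = ppsci.loss.FunctionalLoss(demo_loss_func)
+
+# 数值验证:|1.0-1.5|+|2.0-2.0|+|3.0-2.0|+|4.0-4.0| = 1.5,取均值得 0.375
+output_dict = {
+    "u": paddle.to_tensor([[1.0, 2.0], [3.0, 4.0]]),
+    "u_label": paddle.to_tensor([[1.5, 2.0], [2.0, 4.0]]),
+}
+print(float(demo_loss_func(output_dict)))  # 0.375
+```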
+
+## 6. 参考资料
+
+- [A physics-informed deep neural network for surrogate modeling in classical elasto-plasticity](https://arxiv.org/abs/2204.12088)
diff --git a/docs/zh/examples/epnn_images/loss_trend.png b/docs/zh/examples/epnn_images/loss_trend.png
new file mode 100644
index 000000000..a6ad4fbbe
Binary files /dev/null and b/docs/zh/examples/epnn_images/loss_trend.png differ
diff --git a/examples/epnn/conf/epnn.yaml b/examples/epnn/conf/epnn.yaml
new file mode 100644
index 000000000..4d22bffda
--- /dev/null
+++ b/examples/epnn/conf/epnn.yaml
@@ -0,0 +1,56 @@
+hydra:
+  run:
+    # dynamic output directory according to running time and override name
+    dir: outputs_epnn/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname}
+  job:
+    name: ${mode} # name of logfile
+    chdir: false # keep current working directory unchanged
+    config:
+      override_dirname:
+        exclude_keys:
+          - TRAIN.checkpoint_path
+          - TRAIN.pretrained_model_path
+          - EVAL.pretrained_model_path
+          - mode
+          - output_dir
+          - log_freq
+  sweep:
+    # output directory for multirun
+    dir: ${hydra.run.dir}
+    subdir: ./
+
+# general settings
+mode: train # running mode: train/eval
+seed: 42
+output_dir: ${hydra:run.dir}
+log_freq: 20
+
+# set working condition
+DATASET_STATE: datasets/dstate-16-plas.dat
+DATASET_STRESS: datasets/dstress-16-plas.dat
+NTRAIN_SIZE: 40
+
+# model settings
+MODEL:
+  ihlayers: 3
+  ineurons: 60
+
+# training settings
+TRAIN:
+  epochs: 10000
+  iters_per_epoch: 1
+  save_freq: 50
+  eval_during_train: true
+  eval_with_no_grad: true
+  lr_scheduler:
+    epochs: ${TRAIN.epochs}
+    iters_per_epoch: ${TRAIN.iters_per_epoch}
+    gamma: 0.97
+    decay_steps: 1
+  pretrained_model_path: null
+  checkpoint_path: null
+
+# evaluation settings
+EVAL:
+  pretrained_model_path: null
+  eval_with_no_grad: true
diff --git a/examples/epnn/epnn.py b/examples/epnn/epnn.py
new file mode 100755
index 000000000..8b28e89ce
--- /dev/null
+++ b/examples/epnn/epnn.py
@@ -0,0 +1,205 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +""" +Reference: https://github.com/meghbali/ANNElastoplasticity +""" + +from os import path as osp + +import functions +import hydra +from omegaconf import DictConfig + +import ppsci +from ppsci.utils import logger + + +def train(cfg: DictConfig): + # set random seed for reproducibility + ppsci.utils.misc.set_random_seed(cfg.seed) + + # initialize logger + logger.init_logger("ppsci", osp.join(cfg.output_dir, f"{cfg.mode}.log"), "info") + + ( + input_dict_train, + label_dict_train, + input_dict_val, + label_dict_val, + ) = functions.get_data(cfg.DATASET_STATE, cfg.DATASET_STRESS, cfg.NTRAIN_SIZE) + model_list = functions.get_model_list( + cfg.MODEL.ihlayers, + cfg.MODEL.ineurons, + input_dict_train["state_x"][0].shape[1], + input_dict_train["state_y"][0].shape[1], + input_dict_train["stress_x"][0].shape[1], + ) + optimizer_list = functions.get_optimizer_list(model_list, cfg) + model_state_elasto, model_state_plastic, model_stress = model_list + model_list_obj = ppsci.arch.ModelList(model_list) + + def _transform_in_stress(_in): + return functions.transform_in_stress( + _in, model_state_elasto, "out_state_elasto" + ) + + model_state_elasto.register_input_transform(functions.transform_in) + model_state_plastic.register_input_transform(functions.transform_in) + model_stress.register_input_transform(_transform_in_stress) + model_stress.register_output_transform(functions.transform_out) + + output_keys = [ + "state_x", + "state_y", + "stress_x", + "stress_y", + "out_state_elasto", + "out_state_plastic", + "out_stress", + ] + sup_constraint_pde = ppsci.constraint.SupervisedConstraint( + { + "dataset": { + "name": "NamedArrayDataset", + "input": input_dict_train, + "label": label_dict_train, + }, + "batch_size": 1, + "num_workers": 0, + }, + ppsci.loss.FunctionalLoss(functions.train_loss_func), + {key: (lambda out, k=key: out[k]) for key in output_keys}, + name="sup_train", + ) + constraint_pde = {sup_constraint_pde.name: sup_constraint_pde} + + sup_validator_pde = ppsci.validate.SupervisedValidator( + { + "dataset": { + "name": "NamedArrayDataset", + "input": input_dict_val, + "label": label_dict_val, + }, + "batch_size": 1, + "num_workers": 0, + }, + ppsci.loss.FunctionalLoss(functions.eval_loss_func), + {key: (lambda out, k=key: out[k]) for key in output_keys}, + metric={"metric": ppsci.metric.FunctionalMetric(functions.metric_expr)}, + name="sup_valid", + ) + validator_pde = {sup_validator_pde.name: sup_validator_pde} + + # initialize solver + solver = ppsci.solver.Solver( + model_list_obj, + constraint_pde, + cfg.output_dir, + optimizer_list, + None, + cfg.TRAIN.epochs, + cfg.TRAIN.iters_per_epoch, + save_freq=cfg.TRAIN.save_freq, + eval_during_train=cfg.TRAIN.eval_during_train, + validator=validator_pde, + eval_with_no_grad=cfg.TRAIN.eval_with_no_grad, + ) + + # train model + solver.train() + functions.plotting(cfg.output_dir) + + +def evaluate(cfg: DictConfig): + # set random seed for reproducibility + ppsci.utils.misc.set_random_seed(cfg.seed) + # initialize logger + logger.init_logger("ppsci", osp.join(cfg.output_dir, f"{cfg.mode}.log"), "info") + + ( + input_dict_train, + _, + input_dict_val, + label_dict_val, + ) = functions.get_data(cfg.DATASET_STATE, cfg.DATASET_STRESS, cfg.NTRAIN_SIZE) + model_list = functions.get_model_list( + cfg.MODEL.ihlayers, + cfg.MODEL.ineurons, + input_dict_train["state_x"][0].shape[1], + input_dict_train["state_y"][0].shape[1], + input_dict_train["stress_x"][0].shape[1], + ) + model_state_elasto, model_state_plastic, model_stress = model_list + 
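+    # wrap the three sub-networks into a single ModelList so that
+    # they can be passed to the solver as one model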
model_list_obj = ppsci.arch.ModelList(model_list) + + def transform_f_stress(_in): + return functions.transform_f(_in, model_state_elasto, "out_state_elasto") + + model_state_elasto.register_input_transform(functions.transform_in) + model_state_plastic.register_input_transform(functions.transform_in) + model_stress.register_input_transform(transform_f_stress) + model_stress.register_output_transform(functions.transform_out) + + output_keys = [ + "state_x", + "state_y", + "stress_x", + "stress_y", + "out_state_elasto", + "out_state_plastic", + "out_stress", + ] + sup_validator_pde = ppsci.validate.SupervisedValidator( + { + "dataset": { + "name": "NamedArrayDataset", + "input": input_dict_val, + "label": label_dict_val, + }, + "batch_size": 1, + "num_workers": 0, + }, + ppsci.loss.FunctionalLoss(functions.eval_loss_func), + {key: (lambda out, k=key: out[k]) for key in output_keys}, + metric={"metric": ppsci.metric.FunctionalMetric(functions.metric_expr)}, + name="sup_valid", + ) + validator_pde = {sup_validator_pde.name: sup_validator_pde} + functions.OUTPUT_DIR = cfg.output_dir + + # initialize solver + solver = ppsci.solver.Solver( + model_list_obj, + output_dir=cfg.output_dir, + validator=validator_pde, + pretrained_model_path=cfg.EVAL.pretrained_model_path, + eval_with_no_grad=cfg.EVAL.eval_with_no_grad, + ) + # evaluate after finished training + solver.eval() + + +@hydra.main(version_base=None, config_path="./conf", config_name="epnn.yaml") +def main(cfg: DictConfig): + if cfg.mode == "train": + train(cfg) + elif cfg.mode == "eval": + evaluate(cfg) + else: + raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'") + + +if __name__ == "__main__": + main() diff --git a/examples/epnn/functions.py b/examples/epnn/functions.py new file mode 100644 index 000000000..04fd8865b --- /dev/null +++ b/examples/epnn/functions.py @@ -0,0 +1,424 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" Elasto-Plastic Neural Network (EPNN) + +DEVELOPED AT: + COMPUTATIONAL GEOMECHANICS LABORATORY + DEPARTMENT OF CIVIL ENGINEERING + UNIVERSITY OF CALGARY, AB, CANADA + DIRECTOR: Prof. Richard Wan + +DEVELOPED BY: + MAHDAD EGHBALIAN + +MIT License + +Copyright (c) 2022 Mahdad Eghbalian + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" + +import math +from typing import Dict + +import numpy as np +import paddle + +import ppsci +from ppsci.utils import logger + +# log for loss(total, state_elasto, state_plastic, stress), eval error(total, state_elasto, state_plastic, stress) +loss_log = {} # for plotting +eval_log = {} +plot_keys = {"total", "state_elasto", "state_plastic", "stress"} +for key in plot_keys: + loss_log[key] = [] + eval_log[key] = [] + +# transform +def transform_in(input): + input_transformed = {} + for key in input: + input_transformed[key] = paddle.squeeze(input[key], axis=0) + return input_transformed + + +def transform_out(input, out): + # Add transformed input for computing loss + out.update(input) + return out + + +def transform_in_stress(input, model, out_key): + input_elasto = model(input)[out_key] + input_elasto = input_elasto.detach().clone() + input_transformed = {} + for key in input: + input_transformed[key] = paddle.squeeze(input[key], axis=0) + input_state_m = paddle.concat( + x=( + input_elasto, + paddle.index_select( + input_transformed["state_x"], + paddle.to_tensor([0, 1, 2, 3, 7, 8, 9, 10, 11, 12]), + axis=1, + ), + ), + axis=1, + ) + input_transformed["state_x_f"] = input_state_m + return input_transformed + + +common_param = [] +gkratio = paddle.to_tensor( + data=[[0.45]], dtype=paddle.get_default_dtype(), stop_gradient=False +) + + +def val_loss_criterion(x, y): + return 100.0 * ( + paddle.linalg.norm(x=x["input"] - y["input"]) / paddle.linalg.norm(x=y["input"]) + ) + + +def train_loss_func(output_dict, *args) -> paddle.Tensor: + """For model calculation of loss in model.train(). + + Args: + output_dict (Dict[str, paddle.Tensor]): The output dict. + + Returns: + paddle.Tensor: Loss value. + """ + # Use ppsci.loss.MAELoss to replace paddle.nn.L1Loss + loss, loss_elasto, loss_plastic, loss_stress = loss_func( + output_dict, ppsci.loss.MAELoss() + ) + loss_log["total"].append(float(loss)) + loss_log["state_elasto"].append(float(loss_elasto)) + loss_log["state_plastic"].append(float(loss_plastic)) + loss_log["stress"].append(float(loss_stress)) + return loss + + +def eval_loss_func(output_dict, *args) -> paddle.Tensor: + """For model calculation of loss in model.eval(). + + Args: + output_dict (Dict[str, paddle.Tensor]): The output dict. + + Returns: + paddle.Tensor: Loss value. 
+ """ + error, error_elasto, error_plastic, error_stress = loss_func( + output_dict, val_loss_criterion + ) + eval_log["total"].append(float(error)) + eval_log["state_elasto"].append(float(error_elasto)) + eval_log["state_plastic"].append(float(error_plastic)) + eval_log["stress"].append(float(error_stress)) + logger.message( + f"Error: {float(error)},{float(error_elasto)},{float(error_plastic)},{float(error_stress)}" + ) + return error + + +def metric_expr(output_dict, *args) -> Dict[str, paddle.Tensor]: + return {"dummy_loss": paddle.to_tensor(0.0)} + + +def loss_func(output_dict, criterion) -> paddle.Tensor: + ( + min_elasto, + min_plastic, + range_elasto, + range_plastic, + min_stress, + range_stress, + ) = common_param + + coeff1 = 2.0 + coeff2 = 1.0 + input_elasto = output_dict["out_state_elasto"] + input_plastic = output_dict["out_state_plastic"] + input_stress = output_dict["out_stress"] + target_elasto = output_dict["state_y"][:, 0:1] + target_plastic = output_dict["state_y"][:, 1:4] + loss_elasto = criterion({"input": input_elasto}, {"input": target_elasto}) + loss_plastic = criterion({"input": input_plastic}, {"input": target_plastic}) + oneten_state = paddle.ones(shape=[3, 1], dtype=paddle.get_default_dtype()) + oneten_stress = paddle.ones( + shape=[output_dict["stress_y"].shape[0], output_dict["stress_y"].shape[1]], + dtype=paddle.get_default_dtype(), + ) + dstrain = output_dict["state_x"][:, 10:] + dstrain_real = ( + paddle.multiply(x=dstrain + coeff2, y=paddle.to_tensor(range_stress)) / coeff1 + + min_stress + ) + # predict label + dstrainpl = target_plastic + dstrainpl_real = ( + paddle.multiply(x=dstrainpl + coeff2, y=paddle.to_tensor(range_elasto[1:4])) + / coeff1 + + min_elasto[1:4] + ) + # evaluate label + dstrainel = dstrain_real - dstrainpl_real + mu = paddle.multiply(x=gkratio, y=paddle.to_tensor(input_stress[:, 0:1])) + mu_dstrainel = 2.0 * paddle.multiply(x=mu, y=paddle.to_tensor(dstrainel)) + stress_dstrainel = paddle.multiply( + x=input_stress[:, 0:1] - 2.0 / 3.0 * mu, + y=paddle.to_tensor( + paddle.multiply( + x=paddle.matmul(x=dstrainel, y=oneten_state), + y=paddle.to_tensor(oneten_stress), + ) + ), + ) + input_stress = ( + coeff1 + * paddle.divide( + x=mu_dstrainel + stress_dstrainel - min_plastic, + y=paddle.to_tensor(range_plastic), + ) + - coeff2 + ) + target_stress = output_dict["stress_y"] + loss_stress = criterion({"input": input_stress}, {"input": target_stress}) + loss = loss_elasto + loss_plastic + loss_stress + return loss, loss_elasto, loss_plastic, loss_stress + + +class Dataset: + def __init__(self, data_state, data_stress, itrain): + self.data_state = data_state + self.data_stress = data_stress + self.itrain = itrain + + def get(self, epochs=1): + # Slow if using BatchSampler to obtain data + input_dict_train = { + "state_x": [], + "state_y": [], + "stress_x": [], + "stress_y": [], + } + input_dict_val = { + "state_x": [], + "state_y": [], + "stress_x": [], + "stress_y": [], + } + label_dict_train = {"dummy_loss": []} + label_dict_val = {"dummy_loss": []} + for i in range(epochs): + shuffled_indices = paddle.randperm(n=self.data_state.x_train.shape[0]) + input_dict_train["state_x"].append( + self.data_state.x_train[shuffled_indices[0 : self.itrain]] + ) + input_dict_train["state_y"].append( + self.data_state.y_train[shuffled_indices[0 : self.itrain]] + ) + input_dict_train["stress_x"].append( + self.data_stress.x_train[shuffled_indices[0 : self.itrain]] + ) + input_dict_train["stress_y"].append( + self.data_stress.y_train[shuffled_indices[0 : 
self.itrain]] + ) + label_dict_train["dummy_loss"].append(paddle.to_tensor(0.0)) + + shuffled_indices = paddle.randperm(n=self.data_state.x_valid.shape[0]) + input_dict_val["state_x"].append( + self.data_state.x_valid[shuffled_indices[0 : self.itrain]] + ) + input_dict_val["state_y"].append( + self.data_state.y_valid[shuffled_indices[0 : self.itrain]] + ) + input_dict_val["stress_x"].append( + self.data_stress.x_valid[shuffled_indices[0 : self.itrain]] + ) + input_dict_val["stress_y"].append( + self.data_stress.y_valid[shuffled_indices[0 : self.itrain]] + ) + label_dict_val["dummy_loss"].append(paddle.to_tensor(0.0)) + return input_dict_train, label_dict_train, input_dict_val, label_dict_val + + +class Data: + def __init__(self, dataset_path, train_p=0.6, cross_valid_p=0.2, test_p=0.2): + data = ppsci.utils.reader.load_dat_file(dataset_path) + self.x = data["X"] + self.y = data["y"] + self.train_p = train_p + self.cross_valid_p = cross_valid_p + self.test_p = test_p + + def get_shuffled_data(self): + # Need to set the seed, otherwise the loss will not match the precision + ppsci.utils.misc.set_random_seed(seed=10) + shuffled_indices = paddle.randperm(n=self.x.shape[0]) + n_train = math.floor(self.train_p * self.x.shape[0]) + n_cross_valid = math.floor(self.cross_valid_p * self.x.shape[0]) + n_test = math.floor(self.test_p * self.x.shape[0]) + self.x_train = self.x[shuffled_indices[0:n_train]] + self.y_train = self.y[shuffled_indices[0:n_train]] + self.x_valid = self.x[shuffled_indices[n_train : n_train + n_cross_valid]] + self.y_valid = self.y[shuffled_indices[n_train : n_train + n_cross_valid]] + self.x_test = self.x[ + shuffled_indices[n_train + n_cross_valid : n_train + n_cross_valid + n_test] + ] + self.y_test = self.y[ + shuffled_indices[n_train + n_cross_valid : n_train + n_cross_valid + n_test] + ] + + +def get_data(dataset_state, dataset_stress, ntrain_size): + set_common_param(dataset_state, dataset_stress) + + data_state = Data(dataset_state) + data_stress = Data(dataset_stress) + data_state.get_shuffled_data() + data_stress.get_shuffled_data() + + train_size_log10 = np.linspace( + 1, np.log10(data_state.x_train.shape[0]), num=ntrain_size + ) + train_size_float = 10**train_size_log10 + train_size = train_size_float.astype(int) + itrain = train_size[ntrain_size - 1] + + return Dataset(data_state, data_stress, itrain).get(10) + + +def set_common_param(dataset_state, dataset_stress): + get_data = ppsci.utils.reader.load_dat_file(dataset_state) + min_state = paddle.to_tensor(data=get_data["miny"]) + range_state = paddle.to_tensor(data=get_data["rangey"]) + min_dstrain = paddle.to_tensor(data=get_data["minx"][10:]) + range_dstrain = paddle.to_tensor(data=get_data["rangex"][10:]) + get_data = ppsci.utils.reader.load_dat_file(dataset_stress) + min_stress = paddle.to_tensor(data=get_data["miny"]) + range_stress = paddle.to_tensor(data=get_data["rangey"]) + common_param.extend( + [ + min_state, + min_stress, + range_state, + range_stress, + min_dstrain, + range_dstrain, + ] + ) + + +def get_model_list( + nhlayers, nneurons, state_x_output_size, state_y_output_size, stress_x_output_size +): + NHLAYERS_PLASTIC = 4 + NNEURONS_PLASTIC = 75 + hl_nodes_elasto = [nneurons] * nhlayers + hl_nodes_plastic = [NNEURONS_PLASTIC] * NHLAYERS_PLASTIC + node_sizes_state_elasto = [state_x_output_size] + node_sizes_state_plastic = [state_x_output_size] + node_sizes_stress = [stress_x_output_size + state_y_output_size - 6] + node_sizes_state_elasto.extend(hl_nodes_elasto) + 
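+    # the plastic network gets its own hidden widths, while the stress
+    # network reuses the elastic hidden widths (hl_nodes_elasto)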
node_sizes_state_plastic.extend(hl_nodes_plastic) + node_sizes_stress.extend(hl_nodes_elasto) + node_sizes_state_elasto.extend([state_y_output_size - 3]) + node_sizes_state_plastic.extend([state_y_output_size - 1]) + node_sizes_stress.extend([1]) + + activation_elasto = "leaky_relu" + activation_plastic = "leaky_relu" + activations_elasto = [activation_elasto] + activations_plastic = [activation_plastic] + activations_elasto.extend([activation_elasto for ii in range(nhlayers)]) + activations_plastic.extend([activation_plastic for ii in range(NHLAYERS_PLASTIC)]) + activations_elasto.extend([activation_elasto]) + activations_plastic.extend([activation_plastic]) + drop_p = 0.0 + n_state_elasto = ppsci.arch.Epnn( + ("state_x",), + ("out_state_elasto",), + tuple(node_sizes_state_elasto), + tuple(activations_elasto), + drop_p, + ) + n_state_plastic = ppsci.arch.Epnn( + ("state_x",), + ("out_state_plastic",), + tuple(node_sizes_state_plastic), + tuple(activations_plastic), + drop_p, + ) + n_stress = ppsci.arch.Epnn( + ("state_x_f",), + ("out_stress",), + tuple(node_sizes_stress), + tuple(activations_elasto), + drop_p, + ) + return (n_state_elasto, n_state_plastic, n_stress) + + +def get_optimizer_list(model_list, cfg): + optimizer_list = [] + lr_list = [0.001, 0.001, 0.01] + for i, model in enumerate(model_list): + scheduler = ppsci.optimizer.lr_scheduler.ExponentialDecay( + **cfg.TRAIN.lr_scheduler, learning_rate=lr_list[i] + )() + optimizer_list.append( + ppsci.optimizer.Adam(learning_rate=scheduler, weight_decay=0.0)(model) + ) + + scheduler_ratio = ppsci.optimizer.lr_scheduler.ExponentialDecay( + **cfg.TRAIN.lr_scheduler, learning_rate=0.001 + )() + optimizer_list.append( + paddle.optimizer.Adam( + parameters=[gkratio], learning_rate=scheduler_ratio, weight_decay=0.0 + ) + ) + return ppsci.optimizer.OptimizerList(optimizer_list) + + +def plotting(output_dir): + ppsci.utils.misc.plot_curve( + data=eval_log, + xlabel="Epoch", + ylabel="Training Eval", + output_dir=output_dir, + smooth_step=1, + use_semilogy=True, + ) diff --git a/mkdocs.yml b/mkdocs.yml index 9802993af..ba2b6afd1 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -65,6 +65,7 @@ nav: - Phy-LSTM: zh/examples/phylstm.md - 材料科学(AI for Material): - hPINNs: zh/examples/hpinns.md + - EPNN: zh/examples/epnn.md - 地球科学(AI for Earth Science): - FourCastNet: zh/examples/fourcastnet.md - API文档: diff --git a/ppsci/arch/__init__.py b/ppsci/arch/__init__.py index f418e225f..baa15a892 100644 --- a/ppsci/arch/__init__.py +++ b/ppsci/arch/__init__.py @@ -31,6 +31,7 @@ from ppsci.arch.afno import AFNONet # isort:skip from ppsci.arch.afno import PrecipNet # isort:skip from ppsci.arch.unetex import UNetEx # isort:skip +from ppsci.arch.epnn import Epnn # isort:skip from ppsci.utils import logger # isort:skip @@ -50,6 +51,7 @@ "AFNONet", "PrecipNet", "UNetEx", + "Epnn", "build_model", ] diff --git a/ppsci/arch/epnn.py b/ppsci/arch/epnn.py new file mode 100644 index 000000000..b158415d8 --- /dev/null +++ b/ppsci/arch/epnn.py @@ -0,0 +1,122 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +""" Elasto-Plastic Neural Network (EPNN) + +DEVELOPED AT: + COMPUTATIONAL GEOMECHANICS LABORATORY + DEPARTMENT OF CIVIL ENGINEERING + UNIVERSITY OF CALGARY, AB, CANADA + DIRECTOR: Prof. Richard Wan + +DEVELOPED BY: + MAHDAD EGHBALIAN + +MIT License + +Copyright (c) 2022 Mahdad Eghbalian + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" + +from typing import Tuple + +import paddle.nn as nn + +from ppsci.arch import activation as act_mod +from ppsci.arch import base + + +class Epnn(base.Arch): + """Builds a feedforward network with arbitrary layers. + + Args: + input_keys (Tuple[str, ...]): Name of input keys, such as ("x", "y", "z"). + output_keys (Tuple[str, ...]): Name of output keys, such as ("u", "v", "w"). + node_sizes (Tuple[int, ...]): The tuple of node size. + activations (Tuple[str, ...]): Name of activation functions. + drop_p (float): The parameter p of nn.Dropout. 
+ + Examples: + >>> import ppsci + >>> ann_node_sizes_state = [1] + >>> model = ppsci.arch.Epnn(("x",), ("y",), node_sizes=node_sizes_state, + activations=("leaky_relu"), + drop_p=0.0 + ) + """ + + def __init__( + self, + input_keys: Tuple[str, ...], + output_keys: Tuple[str, ...], + node_sizes: Tuple[int, ...], + activations: Tuple[str, ...], + drop_p: float, + ): + super().__init__() + self.active_func = [act_mod.get_activation(i) for i in activations] + self.node_sizes = node_sizes + self.drop_p = drop_p + self.layers = [] + self.layers.append( + nn.Linear(in_features=node_sizes[0], out_features=node_sizes[1]) + ) + layer_sizes = zip(node_sizes[1:-2], node_sizes[2:-1]) + self.layers.extend( + [nn.Linear(in_features=h1, out_features=h2) for h1, h2 in layer_sizes] + ) + self.layers.append( + nn.Linear( + in_features=node_sizes[-2], out_features=node_sizes[-1], bias_attr=False + ) + ) + + self.layers = nn.LayerList(self.layers) + self.dropout = nn.Dropout(p=drop_p) + self.input_keys = input_keys + self.output_keys = output_keys + + def forward(self, x): + if self._input_transform is not None: + x = self._input_transform(x) + + y = x[self.input_keys[0]] + for ilayer in range(len(self.layers)): + y = self.layers[ilayer](y) + if ilayer != len(self.layers) - 1: + y = self.active_func[ilayer + 1](y) + if ilayer != len(self.layers) - 1: + y = self.dropout(y) + y = self.split_to_dict(y, self.output_keys, axis=-1) + + if self._output_transform is not None: + y = self._output_transform(x, y) + return y diff --git a/ppsci/loss/__init__.py b/ppsci/loss/__init__.py index ee83ca5e9..4e7d6242c 100644 --- a/ppsci/loss/__init__.py +++ b/ppsci/loss/__init__.py @@ -22,6 +22,7 @@ from ppsci.loss.l2 import L2Loss from ppsci.loss.l2 import L2RelLoss from ppsci.loss.l2 import PeriodicL2Loss +from ppsci.loss.mae import MAELoss from ppsci.loss.mse import MSELoss from ppsci.loss.mse import MSELossWithL2Decay from ppsci.loss.mse import PeriodicMSELoss @@ -35,6 +36,7 @@ "L2Loss", "L2RelLoss", "PeriodicL2Loss", + "MAELoss", "MSELoss", "MSELossWithL2Decay", "PeriodicMSELoss", diff --git a/ppsci/loss/mae.py b/ppsci/loss/mae.py new file mode 100644 index 000000000..75b39fe5c --- /dev/null +++ b/ppsci/loss/mae.py @@ -0,0 +1,82 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Dict +from typing import Optional +from typing import Union + +import paddle.nn.functional as F +from typing_extensions import Literal + +from ppsci.loss import base + + +class MAELoss(base.Loss): + r"""Class for mean absolute error loss. + + $$ + L = + \begin{cases} + \dfrac{1}{N} \Vert {\mathbf{x}-\mathbf{y}} \Vert_1, & \text{if reduction='mean'} \\ + \Vert {\mathbf{x}-\mathbf{y}} \Vert_1, & \text{if reduction='sum'} + \end{cases} + $$ + + $$ + \mathbf{x}, \mathbf{y} \in \mathcal{R}^{N} + $$ + + Args: + reduction (Literal["mean", "sum"], optional): Reduction method. Defaults to "mean". 
+ weight (Optional[Union[float, Dict[str, float]]]): Weight for loss. Defaults to None. + + Examples: + >>> import ppsci + >>> loss = ppsci.loss.MAELoss("mean") + """ + + def __init__( + self, + reduction: Literal["mean", "sum"] = "mean", + weight: Optional[Union[float, Dict[str, float]]] = None, + ): + if reduction not in ["mean", "sum"]: + raise ValueError( + f"reduction should be 'mean' or 'sum', but got {reduction}" + ) + super().__init__(reduction, weight) + + def forward(self, output_dict, label_dict, weight_dict=None): + losses = 0.0 + for key in label_dict: + loss = F.l1_loss(output_dict[key], label_dict[key], "none") + if weight_dict: + loss *= weight_dict[key] + + if "area" in output_dict: + loss *= output_dict["area"] + + if self.reduction == "sum": + loss = loss.sum() + elif self.reduction == "mean": + loss = loss.mean() + if isinstance(self.weight, (float, int)): + loss *= self.weight + elif isinstance(self.weight, dict) and key in self.weight: + loss *= self.weight[key] + + losses += loss + return losses diff --git a/ppsci/optimizer/lr_scheduler.py b/ppsci/optimizer/lr_scheduler.py index 4c7cc44fe..697aac48c 100644 --- a/ppsci/optimizer/lr_scheduler.py +++ b/ppsci/optimizer/lr_scheduler.py @@ -16,6 +16,7 @@ import abc import math +from typing import List from typing import Tuple from typing import Union @@ -732,3 +733,43 @@ def __call__(self): setattr(learning_rate, "by_epoch", self.by_epoch) return learning_rate + + +class SchedulerList: + """SchedulerList which wrap more than one scheduler. + Args: + scheduler_list (Tuple[lr.LRScheduler, ...]): Schedulers listed in a tuple. + by_epoch (bool, optional): Learning rate decays by epoch when by_epoch is True, else by iter. Defaults to False. + Examples: + >>> import ppsci + >>> sch1 = ppsci.optimizer.lr_scheduler.Linear(10, 2, 0.001)() + >>> sch2 = ppsci.optimizer.lr_scheduler.ExponentialDecay(10, 2, 1e-3, 0.95, 3)() + >>> sch = ppsci.optimizer.lr_scheduler.SchedulerList((sch1, sch2)) + """ + + def __init__( + self, scheduler_list: Tuple[lr.LRScheduler, ...], by_epoch: bool = False + ): + super().__init__() + self._sch_list = scheduler_list + self.by_epoch = by_epoch + + def step(self): + for sch in self._sch_list: + sch.step() + + def get_lr(self) -> float: + """Return learning rate of first scheduler""" + return self._sch_list[0].get_lr() + + def _state_keys(self) -> List[str]: + return ["last_epoch", "last_lr"] + + def __len__(self) -> int: + return len(self._sch_list) + + def __getitem__(self, idx): + return self._sch_list[idx] + + def __setitem__(self, idx, sch): + raise NotImplementedError("Can not modify any item in SchedulerList.") diff --git a/ppsci/utils/reader.py b/ppsci/utils/reader.py index f6738dec6..3236c3173 100644 --- a/ppsci/utils/reader.py +++ b/ppsci/utils/reader.py @@ -16,6 +16,7 @@ import collections import csv +import pickle from typing import Dict from typing import Optional from typing import Tuple @@ -31,6 +32,7 @@ "load_npz_file", "load_vtk_file", "load_vtk_with_time_file", + "load_dat_file", ] @@ -179,14 +181,18 @@ def load_vtk_file( i = 0 for key in input_dict: if key == "t": - input_dict[key].append(np.full((n, 1), index * time_step, "float32")) + input_dict[key].append( + np.full((n, 1), index * time_step, paddle.get_default_dtype()) + ) else: input_dict[key].append( - mesh.points[:, i].reshape(n, 1).astype("float32") + mesh.points[:, i].reshape(n, 1).astype(paddle.get_default_dtype()) ) i += 1 for i, key in enumerate(label_dict): - label_dict[key].append(np.array(mesh.point_data[key], "float32")) 
+            label_dict[key].append(
+                np.array(mesh.point_data[key], paddle.get_default_dtype())
+            )
     for key in input_dict:
         input_dict[key] = np.concatenate(input_dict[key])
     for key in label_dict:
@@ -212,3 +218,42 @@ def load_vtk_with_time_file(file: str) -> Dict[str, np.ndarray]:
     z = mesh.points[:, 2].reshape(n, 1)
     input_dict = {"t": t, "x": x, "y": y, "z": z}
     return input_dict
+
+
+def load_dat_file(
+    file_path: str,
+    keys: Optional[Tuple[str, ...]] = None,
+    alias_dict: Optional[Dict[str, str]] = None,
+) -> Dict[str, np.ndarray]:
+    """Load *.dat file and fetch data as given keys.
+
+    Args:
+        file_path (str): Dat file path.
+        keys (Optional[Tuple[str, ...]]): Keys to fetch; all keys are fetched
+            if set to None. Defaults to None.
+        alias_dict (Optional[Dict[str, str]]): Alias for keys,
+            i.e. {desired_key: key_in_raw_data}. Defaults to None.
+
+    Returns:
+        Dict[str, np.ndarray]: Loaded data in dict.
+    """
+
+    if alias_dict is None:
+        alias_dict = {}
+
+    # read all data from the pickled .dat file, closing the file afterwards
+    with open(file_path, "rb") as f:
+        raw_data = pickle.load(f)
+
+    # convert fetched entries to numpy arrays with the default dtype
+    data_dict = {}
+    if keys is None:
+        keys = raw_data.keys()
+    for key in keys:
+        fetch_key = alias_dict[key] if key in alias_dict else key
+        if fetch_key not in raw_data:
+            raise KeyError(f"fetch_key({fetch_key}) does not exist in raw_data.")
+        data_dict[key] = np.asarray(raw_data[fetch_key], paddle.get_default_dtype())
+
+    return data_dict