Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Hackathon 6th Code Camp No.15] support neuraloperator #867

Merged
merged 28 commits into from
May 27, 2024

Conversation

Yang-Changhui
Copy link
Contributor

PR types

New features

PR changes

Others

Describe

增加neuraloperator模型

Copy link

paddle-bot bot commented Apr 25, 2024

Thanks for your contribution!

except KeyError:
padding = [round(p * r) for (p, r) in zip(self.domain_padding, resolution)]

print(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

建议使用ppsci内置的log

# of the "padding" list i.e. the last axis of tensor 'x' will be
# padded by the amount mention at the first position of the
# 'padding' vector. The details about F.pad can be found here:
# https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

确认注释正确性

padded = F.pad(x, padding, mode="constant")

output_shape = padded.shape[2:]

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

删除多余空格

return nn.Identity()
else:
raise ValueError(
f"Got skip-connection type = {type}, expected one of {'soft-gating', 'linear', 'id'}."
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

id --> identity

arg is only checked when `implementation=reconstructed`. Defaults to False.

Raises:
ValueError: _description_
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

删除raises


if implementation == "reconstructed":
if separable:
print("SEPARABLE")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

删除print

# elif weight.name.lower() == 'complexcp':
# return _contract_cp
# else:
# raise ValueError(f'Got unexpected factorized weight type {weight.name}')
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

删除无意义的注释

fixed_rank_modes=False,
joint_factorization=False,
init_std="auto",
fft_norm="backward",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

初始化添加类型提示,其他的 类 也注意一下

)

def forward(
self, x: paddle.Tensor, indices=0, output_shape: Optional[Tuple[int]] = None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

变量 indices 添加类型提示


This is provided for reference only,
see :class:`neuralop.layers.SpectraConv` for the preferred, general implementation
"""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

更正注释,ppsci没有neuralop.layers.SpectraConv这个路径,另外确认该类是否在代码中有用到

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这几个类都没有用到,但是是可选的



def _contract_dense_trick(x, weight_real, weight_imag, separable=False, dhconv=True):
# the same as above function, but do the complex multiplication manually to avoid the einsum bug in paddle
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

“the same as above function” ?这个函数上面没有其他函数呀?


if implementation == "reconstructed":
if separable:
print("SEPARABLE")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

确认是否有必要print

elif implementation == "factorized":
if isinstance(weight, paddle.Tensor):
return _contract_dense_trick
# TODO: FactorizedTensor not supported yet
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个注释可以删掉吧



class SphericalConv(nn.Layer):
"""Spherical Convolution, base class for the SFNO [1]_
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

注释最后用“.“更合适吧?

"""Spherical Convolution, base class for the SFNO [1].
.. [1] Spherical Fourier Neural Operators: Learning Stable Dynamics on the Sphere,
Boris Bonev, Thorsten Kurth, Christian Hundt, Jaideep Pathak, Maximilian Baust, Karthik Kashinath, Anima Anandkumar,
ICML 2023.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

references重复了

Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
n_modes (int): Number of modes to use for contraction in Fourier domain during
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

注释中的变量类型与初始化函数中的保持一致

"""Applies domain padding scaled automatically to the input's resolution

Args:
domain_padding (float): typically, between zero and one, percentage of padding to use.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

类型与初始化代码保持一致, Union[float, List[float]]


Returns
-------
torch.tensor
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

函数注释修改下,建议注释风格修改为PaddleScience风格


# x is in shape of batch*n or T*batch*n
# x = (x.view(self.sample_shape) * std) + mean
# x = x.view(-1, *self.sample_shape)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不需要的注释删除吧

return x


def count_params(model):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

确认此函数是否有使用

50 samples at resolution 32x32.

Args:
test_resolutions (List[int,...]): The resolutions to test dataset. Default is [16, 32].
Copy link
Collaborator

@zhiminzhang0830 zhiminzhang0830 May 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

完善docstring,与初始化函数的入参保持一致, 缺少train_resolution,记得补充完整。test_resolutions参数的docstring的默认值也需要与初始化入参保持一致

50 samples at resolution 64x128.

Args:
test_resolutions (List[str,...]): The resolutions to test dataset. Default is ["34x64","64x128"].
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

完善docstring

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

添加 SphericalSWEDataset

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

因为这个数据集是我使用pytorch原代码生成的数据,直接保存的npy文件,这样还需要写进去吗

pretrained_model_path: ./outputs_sfno_pretrain/checkpoints/best_model.pdparams
export_path: ./inference/uno/uno_darcyflow
pdmodel_path: ${INFER.export_path}.pdmodel
pdpiparams_path: ${INFER.export_path}.pdiparams
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pdpiparams_path --> pdiparams_path 全部替换一下

encoding (str): The type of encoding. Default is 'channel-wise'.
channel_dim (int): The location of unsqueeze. Default is 1.
where to put the channel dimension, defaults size is batch, channel, height, width
training (str): Wether to use training or test dataset. Default is 'train'.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

其他数据集中使用 training 参数表示该数据集是否用于训练,使用bool类型。在此数据集中,training 参数还用于表示测试时的数据选择,与变量名称”training“不太相符,建议修改为其他更有意义的变量名称

Defaults to None.
test_resolutions (Tuple[str, ...], optional): The resolutions to test dataset. Defaults to ["34x64", "64x128"].
train_resolution (str, optional): The resolutions to train dataset. Defaults to "34x64".
training (str, optional): Wether to use training or test dataset. Defaults to "train".
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

与本PR无关,记得删除

input_data = data["x"][0].reshape(1, *data["x"].shape[1:]).astype("float32")
label = data["y"][0][0, ...].astype("float32")

model = ppsci.arch.SFNONet(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好像还是使用动态图的形式进行预测的?并没有使用export的静态图模型进行预测?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是动态图形式进行预测的,因为导出静态图模型进行预测时,结果和动态图不一样,可能是因为使用虚数的缘故,也可能是独有的API导致的导出模型的网络权重有问题,跟动态图的网络权重不一样,这个之前给你说过

@Yang-Changhui
Copy link
Contributor Author

paddle使用dev版本

sup_validator_32.name: sup_validator_32,
}

model = ppsci.arch.TFNO2dNet(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个地方应该是UNO的模型吧

scheduler_T_max: 500
wd: 1.0e-4
batch_size: 16
pretrained_model_path: ./pretrainmodel/tfno_pretrain.pdparams
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

直接设置成null吧,另外两个也改下

n_layers: int = 1,
max_n_modes: int = None,
use_mlp: bool = False,
mlp: Optional[dict[float, float]] = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

此处dict应该为Dict, 并在文件开头添加:from typing import Dict,其他文件也有类似问题,建议统一检查

ax.set_title("Model prediction")
plt.xticks([], [])
plt.yticks([], [])
plt.savefig(cfg.output_dir + "123.png")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

此处123.png,作为文件名称有什么含义么

solver.export(input_spec, cfg.INFER.export_path)


# 使用静态图推理出错
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

现在推理还有问题么

"""
if batch_size != 1:
raise ValueError(
f"FNOPredictor only support batch_size=1, but got {batch_size}"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FNOPredictor --> SFNOPredictor

Copy link
Collaborator

@zhiminzhang0830 zhiminzhang0830 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, TODO:add doc

@zhiminzhang0830 zhiminzhang0830 merged commit aa963b8 into PaddlePaddle:develop May 27, 2024
3 of 4 checks passed
huohuohuohuohuo123 pushed a commit to huohuohuohuohuo123/PaddleScience that referenced this pull request Aug 12, 2024
)

* add-neuraloperator

* add-neuraloperator

* add-neuraloperator

* add-neuraloperator

* add-neuraloperator

* add-neuraloperator

* add-neuraloperator

* add-neuraloperator

* add-neuraloperator

* add-neuraloperator

* add-neuraloperator

* add-neuraloperator
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants