Skip to content

Commit 963c9ae

Browse files
committed
Refactor _validate_data_input
1 parent 8f72e4c commit 963c9ae

File tree

2 files changed

+67
-88
lines changed

2 files changed

+67
-88
lines changed

pygmt/clib/session.py

+19-14
Original file line numberDiff line numberDiff line change
@@ -1765,7 +1765,7 @@ def virtualfile_from_stringio(
17651765
seg.header = None
17661766
seg.text = None
17671767

1768-
def virtualfile_in(
1768+
def virtualfile_in( # noqa: PLR0912
17691769
self,
17701770
check_kind=None,
17711771
data=None,
@@ -1825,23 +1825,25 @@ def virtualfile_in(
18251825
... print(fout.read().strip())
18261826
<vector memory>: N = 3 <7/9> <4/6> <1/3>
18271827
"""
1828+
# Specify either data or x/y/z.
1829+
if data is not None and any(v is not None for v in (x, y, z)):
1830+
msg = "Too much data. Use either data or x/y/z."
1831+
raise GMTInvalidInput(msg)
1832+
1833+
# Determine the kind of data.
18281834
kind = data_kind(data, required=required_data)
1829-
_validate_data_input(
1830-
data=data,
1831-
x=x,
1832-
y=y,
1833-
z=z,
1834-
required_z=required_z,
1835-
required_data=required_data,
1836-
kind=kind,
1837-
)
18381835

1836+
# Check if the kind of data is valid.
18391837
if check_kind:
18401838
valid_kinds = ("file", "arg") if required_data is False else ("file",)
1841-
if check_kind == "raster":
1842-
valid_kinds += ("grid", "image")
1843-
elif check_kind == "vector":
1844-
valid_kinds += ("empty", "matrix", "vectors", "geojson")
1839+
match check_kind:
1840+
case "raster":
1841+
valid_kinds += ("grid", "image")
1842+
case "vector":
1843+
valid_kinds += ("empty", "matrix", "vectors", "geojson")
1844+
case _:
1845+
msg = f"Invalid value for check_kind: '{check_kind}'."
1846+
raise GMTInvalidInput(msg)
18451847
if kind not in valid_kinds:
18461848
msg = f"Unrecognized data type for {check_kind}: {type(data)}."
18471849
raise GMTInvalidInput(msg)
@@ -1892,6 +1894,9 @@ def virtualfile_in(
18921894
_virtualfile_from = self.virtualfile_from_vectors
18931895
_data = data.T
18941896

1897+
# Check if _data to be passed to the virtualfile_from_ function is valid.
1898+
_validate_data_input(data=_data, kind=kind, required_z=required_z)
1899+
18951900
# Finally create the virtualfile from the data, to be passed into GMT
18961901
file_context = _virtualfile_from(_data)
18971902
return file_context

pygmt/helpers/utils.py

+48-74
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import time
1313
import webbrowser
1414
from collections.abc import Iterable, Mapping, Sequence
15-
from itertools import islice
1615
from pathlib import Path
1716
from typing import Any, Literal
1817

@@ -40,118 +39,97 @@
4039
"ISO-8859-15",
4140
"ISO-8859-16",
4241
]
42+
# Type hints for the list of possible data kinds.
43+
Kind = Literal[
44+
"arg", "empty", "file", "geojson", "grid", "image", "matrix", "stringio", "vectors"
45+
]
4346

4447

45-
def _validate_data_input( # noqa: PLR0912
46-
data=None, x=None, y=None, z=None, required_z=False, required_data=True, kind=None
47-
) -> None:
48+
def _validate_data_input(data: Any, kind: Kind, required_z: bool = False) -> None:
4849
"""
49-
Check if the combination of data/x/y/z is valid.
50+
Check if the data to be passed to the virtualfile_from_ functions is valid.
5051
5152
Examples
5253
--------
53-
>>> _validate_data_input(data="infile")
54-
>>> _validate_data_input(x=[1, 2, 3], y=[4, 5, 6])
55-
>>> _validate_data_input(x=[1, 2, 3], y=[4, 5, 6], z=[7, 8, 9])
56-
>>> _validate_data_input(data=None, required_data=False)
57-
>>> _validate_data_input()
54+
The "empty" kind means the data is given via a series of vectors like x/y/z.
55+
56+
>>> _validate_data_input(data=[[1, 2, 3], [4, 5, 6]], kind="empty")
57+
>>> _validate_data_input(data=[[1, 2, 3], [4, 5, 6], [7, 8, 9]], kind="empty")
58+
>>> _validate_data_input(data=[None, [4, 5, 6]], kind="empty")
5859
Traceback (most recent call last):
5960
...
60-
pygmt.exceptions.GMTInvalidInput: No input data provided.
61-
>>> _validate_data_input(x=[1, 2, 3])
61+
pygmt.exceptions.GMTInvalidInput: Must provide both x and y.
62+
>>> _validate_data_input(data=[[1, 2, 3], None], kind="empty")
6263
Traceback (most recent call last):
6364
...
6465
pygmt.exceptions.GMTInvalidInput: Must provide both x and y.
65-
>>> _validate_data_input(y=[4, 5, 6])
66+
>>> _validate_data_input(data=[None, None], kind="empty")
6667
Traceback (most recent call last):
6768
...
6869
pygmt.exceptions.GMTInvalidInput: Must provide both x and y.
69-
>>> _validate_data_input(x=[1, 2, 3], y=[4, 5, 6], required_z=True)
70+
>>> _validate_data_input(data=[[1, 2, 3], [4, 5, 6]], kind="empty", required_z=True)
7071
Traceback (most recent call last):
7172
...
7273
pygmt.exceptions.GMTInvalidInput: Must provide x, y, and z.
74+
75+
The "matrix" kind means the data is given via a 2-D numpy.ndarray.
76+
7377
>>> import numpy as np
7478
>>> import pandas as pd
7579
>>> import xarray as xr
7680
>>> data = np.arange(8).reshape((4, 2))
77-
>>> _validate_data_input(data=data, required_z=True, kind="matrix")
81+
>>> _validate_data_input(data=data, kind="matrix", required_z=True)
7882
Traceback (most recent call last):
7983
...
80-
pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns.
84+
pygmt.exceptions.GMTInvalidInput: Need at least 3 columns but 2 column(s) are given.
85+
86+
The "vectors" kind means the original data is either dictionary, list, tuple,
87+
pandas.DataFrame, pandas.Series, xarray.Dataset, or xarray.DataArray.
88+
8189
>>> _validate_data_input(
8290
... data=pd.DataFrame(data, columns=["x", "y"]),
83-
... required_z=True,
8491
... kind="vectors",
92+
... required_z=True,
8593
... )
8694
Traceback (most recent call last):
8795
...
88-
pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns.
96+
pygmt.exceptions.GMTInvalidInput: Need at least 3 columns but 2 column(s) are given.
8997
>>> _validate_data_input(
9098
... data=xr.Dataset(pd.DataFrame(data, columns=["x", "y"])),
91-
... required_z=True,
9299
... kind="vectors",
100+
... required_z=True,
93101
... )
94102
Traceback (most recent call last):
95103
...
96-
pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns.
97-
>>> _validate_data_input(data="infile", x=[1, 2, 3])
98-
Traceback (most recent call last):
99-
...
100-
pygmt.exceptions.GMTInvalidInput: Too much data. Use either data or x/y/z.
101-
>>> _validate_data_input(data="infile", y=[4, 5, 6])
102-
Traceback (most recent call last):
103-
...
104-
pygmt.exceptions.GMTInvalidInput: Too much data. Use either data or x/y/z.
105-
>>> _validate_data_input(data="infile", x=[1, 2, 3], y=[4, 5, 6])
106-
Traceback (most recent call last):
107-
...
108-
pygmt.exceptions.GMTInvalidInput: Too much data. Use either data or x/y/z.
109-
>>> _validate_data_input(data="infile", z=[7, 8, 9])
110-
Traceback (most recent call last):
111-
...
112-
pygmt.exceptions.GMTInvalidInput: Too much data. Use either data or x/y/z.
104+
pygmt.exceptions.GMTInvalidInput: Need at least 3 columns but 2 column(s) are given.
113105
114106
Raises
115107
------
116108
GMTInvalidInput
117109
If the data input is not valid.
118110
"""
119-
if data is None: # data is None
120-
if x is None and y is None: # both x and y are None
121-
if required_data: # data is not optional
122-
msg = "No input data provided."
111+
# Determine the required number of columns based on the required_z flag.
112+
required_cols = 3 if required_z else 1
113+
114+
match kind:
115+
case "empty": # data = [x, y], [x, y, z], [x, y, z, ...]
116+
if len(data) < 2 or any(v is None for v in data[:2]):
117+
msg = "Must provide both x and y."
123118
raise GMTInvalidInput(msg)
124-
elif x is None or y is None: # either x or y is None
125-
msg = "Must provide both x and y."
126-
raise GMTInvalidInput(msg)
127-
if required_z and z is None: # both x and y are not None, now check z
128-
msg = "Must provide x, y, and z."
129-
raise GMTInvalidInput(msg)
130-
else: # data is not None
131-
if x is not None or y is not None or z is not None:
132-
msg = "Too much data. Use either data or x/y/z."
133-
raise GMTInvalidInput(msg)
134-
# check if data has the required z column
135-
if required_z:
136-
msg = "data must provide x, y, and z columns."
137-
if kind == "matrix" and data.shape[1] < 3:
119+
if required_z and (len(data) < 3 or data[:3] is None):
120+
msg = "Must provide x, y, and z."
138121
raise GMTInvalidInput(msg)
139-
if kind == "vectors":
140-
if hasattr(data, "shape") and (
141-
(len(data.shape) == 1 and data.shape[0] < 3)
142-
or (len(data.shape) > 1 and data.shape[1] < 3)
143-
): # np.ndarray or pd.DataFrame
144-
raise GMTInvalidInput(msg)
145-
if hasattr(data, "data_vars") and len(data.data_vars) < 3: # xr.Dataset
146-
raise GMTInvalidInput(msg)
147-
if kind == "vectors" and isinstance(data, dict):
148-
# Iterator over the up-to-3 first elements.
149-
arrays = list(islice(data.values(), 3))
150-
if len(arrays) < 2 or any(v is None for v in arrays[:2]): # Check x/y
151-
msg = "Must provide x and y."
122+
case "matrix": # 2-D numpy.ndarray
123+
if (actual_cols := data.shape[1]) < required_cols:
124+
msg = f"Need at least {required_cols} columns but {actual_cols} column(s) are given."
152125
raise GMTInvalidInput(msg)
153-
if required_z and (len(arrays) < 3 or arrays[2] is None): # Check z
154-
msg = "Must provide x, y, and z."
126+
case "vectors":
127+
# "vectors" means the original data is either dictionary, list, tuple,
128+
# pandas.DataFrame, pandas.Series, xarray.Dataset, or xarray.DataArray.
129+
# The original data is converted to a list of vectors or a 2-D numpy.ndarray
130+
# in the virtualfile_in function.
131+
if (actual_cols := len(data)) < required_cols:
132+
msg = f"Need at least {required_cols} columns but {actual_cols} column(s) are given."
155133
raise GMTInvalidInput(msg)
156134

157135

@@ -271,11 +249,7 @@ def _check_encoding(argstr: str) -> Encoding:
271249
return "ISOLatin1+"
272250

273251

274-
def data_kind(
275-
data: Any, required: bool = True
276-
) -> Literal[
277-
"arg", "empty", "file", "geojson", "grid", "image", "matrix", "stringio", "vectors"
278-
]:
252+
def data_kind(data: Any, required: bool = True) -> Kind:
279253
r"""
280254
Check the kind of data that is provided to a module.
281255

0 commit comments

Comments
 (0)