Skip to content

Commit 3ee4a72

Browse files
committed
feat(AreaOfInterest): add discriminator param for better docs
And improve validation for `AdminAreaOfInterest`.
1 parent 23423b9 commit 3ee4a72

File tree

3 files changed

+52
-24
lines changed

3 files changed

+52
-24
lines changed

app/models/pydantic/datamart.py

+22-3
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from enum import Enum
2-
from typing import Dict, Optional, Union
2+
from typing import Dict, Literal, Optional, Union
33
from uuid import UUID
44
from abc import ABC, abstractmethod
55

6-
from pydantic import Field
6+
from pydantic import Field, root_validator, validator
77

88
from app.models.pydantic.responses import Response
99

@@ -19,13 +19,15 @@ async def get_geostore_id(self) -> UUID:
1919

2020

2121
class GeostoreAreaOfInterest(AreaOfInterest):
22+
type: Literal['geostore'] = 'geostore'
2223
geostore_id:UUID = Field(..., title="Geostore ID")
2324

2425
async def get_geostore_id(self) -> UUID:
2526
return self.geostore_id
2627

2728

2829
class AdminAreaOfInterest(AreaOfInterest):
30+
type: Literal['admin'] = 'admin'
2931
country: str = Field(..., title="ISO Country Code")
3032
region: Optional[str] = Field(None, title="Region")
3133
subregion: Optional[str] = Field(None, title="Subregion")
@@ -45,6 +47,23 @@ async def get_geostore_id(self) -> UUID:
4547
)
4648
return UUID(geostore_id)
4749

50+
@root_validator
51+
def check_region_subregion(cls, values):
52+
region = values.get("region")
53+
subregion = values.get("subregion")
54+
if subregion is not None and region is None:
55+
raise ValueError("region must be specified if subregion is provided")
56+
return values
57+
58+
@validator('provider', pre=True, always=True)
59+
def set_provider_default(cls, v):
60+
return v or 'gadm'
61+
62+
@validator('version', pre=True, always=True)
63+
def set_version_default(cls, v):
64+
return v or '4.1'
65+
66+
4867

4968
class AnalysisStatus(str, Enum):
5069
saved = "saved"
@@ -80,7 +99,7 @@ class DataMartResourceLinkResponse(Response):
8099

81100

82101
class TreeCoverLossByDriverIn(StrictBaseModel):
83-
aoi: Union[GeostoreAreaOfInterest, AdminAreaOfInterest]
102+
aoi: Union[GeostoreAreaOfInterest, AdminAreaOfInterest] = Field(..., discriminator='type')
84103
canopy_cover: int = 30
85104
dataset_version: Dict[str, str] = {}
86105

app/routes/datamart/land.py

+21-15
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
)
1818
from fastapi.openapi.models import APIKey
1919
from fastapi.responses import ORJSONResponse
20+
from pydantic import ValidationError
2021

2122
from app.crud import datamart as datamart_crud
2223
from app.errors import RecordNotFoundError
@@ -66,21 +67,26 @@ def _parse_dataset_versions(request: Request) -> Dict[str, str]:
6667

6768

6869
def _parse_area_of_interest(request: Request) -> AreaOfInterest:
69-
if 'aoi[geostore_id]' in request.query_params:
70-
return GeostoreAreaOfInterest(geostore_id=request.query_params['aoi[geostore_id]'])
71-
72-
# Otherwise, check if the request contains admin area information
73-
if 'aoi[country]' in request.query_params:
74-
return AdminAreaOfInterest(
75-
country=request.query_params['aoi[country]'],
76-
region=request.query_params.get('aoi[region]'),
77-
subregion=request.query_params.get('aoi[subregion]'),
78-
provider=request.query_params.get('aoi[provider]'),
79-
version=request.query_params.get('aoi[version]'),
80-
)
70+
params = request.query_params
71+
aoi_type = params.get('aoi[type]')
72+
try:
73+
if aoi_type == 'geostore':
74+
return GeostoreAreaOfInterest(geostore_id=params.get('aoi[geostore_id]', None))
75+
76+
# Otherwise, check if the request contains admin area information
77+
if aoi_type == 'admin':
78+
return AdminAreaOfInterest(
79+
country=params.get('aoi[country]', None),
80+
region=params.get('aoi[region]', None),
81+
subregion=params.get('aoi[subregion]', None),
82+
provider=params.get('aoi[provider]', None),
83+
version=params.get('aoi[version]', None),
84+
)
8185

82-
# If neither type is provided, raise an error
83-
raise HTTPException(status_code=422, detail="Invalid Area of Interest parameters")
86+
# If neither type is provided, raise an error
87+
raise HTTPException(status_code=422, detail="Invalid Area of Interest parameters")
88+
except ValidationError as e:
89+
raise HTTPException(status_code=422, detail=e.errors())
8490

8591

8692
@router.get(
@@ -200,7 +206,7 @@ async def tree_cover_loss_by_driver_post(
200206
except HTTPException:
201207
raise HTTPException(
202208
status_code=422,
203-
detail=f"Geostore {data.geostore_id} can't be found or is not valid.",
209+
detail=f"Geostore {geostore_id} can't be found or is not valid.",
204210
)
205211

206212
# create initial Job item as pending

tests_v2/unit/app/routes/datamart/test_land.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ async def test_get_tree_cover_loss_by_drivers_not_found(
3232
origin = payload["domains"][0]
3333

3434
headers = {"origin": origin}
35-
params = {"x-api-key": api_key, "aoi[geostore_id]": geostore, "canopy_cover": 30}
35+
params = {"x-api-key": api_key, "aoi[type]": "geostore", "aoi[geostore_id]": geostore, "canopy_cover": 30}
3636

3737
response = await async_client.get(
3838
"/v0/land/tree_cover_loss_by_driver", headers=headers, params=params
@@ -59,7 +59,7 @@ async def test_get_tree_cover_loss_by_drivers_found(
5959
origin = payload["domains"][0]
6060

6161
headers = {"origin": origin}
62-
params = {"x-api-key": api_key, "aoi[geostore_id]": geostore, "canopy_cover": 30}
62+
params = {"x-api-key": api_key, "aoi[type]": "geostore", "aoi[geostore_id]": geostore, "canopy_cover": 30}
6363
resource_id = _get_resource_id(
6464
"tree_cover_loss_by_driver", geostore, 30, DEFAULT_LAND_DATASET_VERSIONS
6565
)
@@ -94,7 +94,7 @@ async def test_get_tree_cover_loss_by_drivers_with_overrides(
9494
origin = payload["domains"][0]
9595

9696
headers = {"origin": origin}
97-
params = {"x-api-key": api_key, "aoi[geostore_id]": geostore, "canopy_cover": 30}
97+
params = {"x-api-key": api_key, "aoi[type]": "geostore", "aoi[geostore_id]": geostore, "canopy_cover": 30}
9898
resource_id = _get_resource_id(
9999
"tree_cover_loss_by_driver",
100100
geostore,
@@ -140,10 +140,10 @@ async def test_get_tree_cover_loss_by_drivers_with_malformed_overrides(
140140
origin = payload["domains"][0]
141141

142142
headers = {"origin": origin}
143-
params = {"x-api-key": api_key, "geostore_id": geostore, "canopy_cover": 30}
143+
params = {"x-api-key": api_key, "aoi[type]": "geostore", "aoi[geostore_id]": geostore, "canopy_cover": 30}
144144

145145
response = await async_client.get(
146-
f"/v0/land/tree_cover_loss_by_driver?x-api-key={api_key}&aoi[geostore_id]={geostore}&canopy_cover=30&dataset_version[umd_tree_cover_loss]]=v1.8&dataset_version[umd_tree_cover_density_2000]=v1.6",
146+
f"/v0/land/tree_cover_loss_by_driver?dataset_version[umd_tree_cover_loss]]=v1.8&dataset_version[umd_tree_cover_density_2000]=v1.6",
147147
headers=headers,
148148
params=params,
149149
)
@@ -166,7 +166,10 @@ async def test_post_tree_cover_loss_by_drivers(
166166

167167
headers = {"origin": origin, "x-api-key": api_key}
168168
payload = {
169-
"aoi": {"geostore_id": geostore},
169+
"aoi": {
170+
"type": "geostore",
171+
"geostore_id": geostore,
172+
},
170173
"canopy_cover": 30,
171174
"dataset_version": {"umd_tree_cover_loss": "v1.8"},
172175
}

0 commit comments

Comments
 (0)