Skip to content

Commit 6193fc5

Browse files
committed
👕 style: black formatting
1 parent 552d623 commit 6193fc5

File tree

6 files changed

+150
-77
lines changed

6 files changed

+150
-77
lines changed

app/crud/geostore.py

+52-16
Original file line numberDiff line numberDiff line change
@@ -192,17 +192,27 @@ async def get_first_row(sql: Select):
192192

193193

194194
async def get_gadm_geostore_id(
195-
admin_provider: str,
196-
admin_version: str,
197-
adm_level: int,
198-
country_id: str,
199-
region_id: str | None = None,
200-
subregion_id: str | None = None,
195+
admin_provider: str,
196+
admin_version: str,
197+
adm_level: int,
198+
country_id: str,
199+
region_id: str | None = None,
200+
subregion_id: str | None = None,
201201
) -> str:
202202
src_table = await get_versioned_dataset(admin_provider, admin_version)
203-
columns_etc: List[Column | Label] = [db.column("gfw_geostore_id"),]
204-
row = await _find_first_geostore(adm_level, admin_provider, admin_version, columns_etc, country_id, region_id,
205-
src_table, subregion_id)
203+
columns_etc: List[Column | Label] = [
204+
db.column("gfw_geostore_id"),
205+
]
206+
row = await _find_first_geostore(
207+
adm_level,
208+
admin_provider,
209+
admin_version,
210+
columns_etc,
211+
country_id,
212+
region_id,
213+
src_table,
214+
subregion_id,
215+
)
206216
return await row.gfw_geostore_id
207217

208218

@@ -240,8 +250,16 @@ async def build_gadm_geostore(
240250
)
241251
)
242252

243-
row = await _find_first_geostore(adm_level, admin_provider, admin_version, columns_etc, country_id, region_id,
244-
src_table, subregion_id)
253+
row = await _find_first_geostore(
254+
adm_level,
255+
admin_provider,
256+
admin_version,
257+
columns_etc,
258+
country_id,
259+
region_id,
260+
src_table,
261+
subregion_id,
262+
)
245263

246264
if row.geojson is None:
247265
raise GeometryIsNullError(
@@ -261,10 +279,26 @@ async def build_gadm_geostore(
261279
)
262280

263281

264-
async def _find_first_geostore(adm_level, admin_provider, admin_version, columns_etc, country_id, region_id, src_table,
265-
subregion_id):
282+
async def _find_first_geostore(
283+
adm_level,
284+
admin_provider,
285+
admin_version,
286+
columns_etc,
287+
country_id,
288+
region_id,
289+
src_table,
290+
subregion_id,
291+
):
266292
sql: Select = db.select(columns_etc).select_from(src_table)
267-
sql = await add_where_clauses(adm_level, admin_provider, admin_version, country_id, region_id, sql, subregion_id)
293+
sql = await add_where_clauses(
294+
adm_level,
295+
admin_provider,
296+
admin_version,
297+
country_id,
298+
region_id,
299+
sql,
300+
subregion_id,
301+
)
268302
row = await get_first_row(sql)
269303
if row is None:
270304
raise RecordNotFoundError(
@@ -273,7 +307,9 @@ async def _find_first_geostore(adm_level, admin_provider, admin_version, columns
273307
return row
274308

275309

276-
async def add_where_clauses(adm_level, admin_provider, admin_version, country_id, region_id, sql, subregion_id):
310+
async def add_where_clauses(
311+
adm_level, admin_provider, admin_version, country_id, region_id, sql, subregion_id
312+
):
277313
where_clauses: List[TextClause] = [
278314
db.text("adm_level=:adm_level").bindparams(adm_level=str(adm_level))
279315
]
@@ -335,7 +371,7 @@ async def get_gadm_geostore(
335371
admin_version=admin_version,
336372
adm_level=adm_level,
337373
simplify=simplify,
338-
country_id=country_id,
374+
country_id=country_id,
339375
region_id=region_id,
340376
subregion_id=subregion_id,
341377
)

app/models/pydantic/datamart.py

+22-18
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
1+
from abc import ABC, abstractmethod
12
from enum import Enum
23
from typing import Dict, Literal, Optional, Union
34
from uuid import UUID
4-
from abc import ABC, abstractmethod
55

66
from pydantic import Field, root_validator, validator
77

88
from app.models.pydantic.responses import Response
99

10-
from .base import StrictBaseModel
1110
from ...crud.geostore import get_gadm_geostore_id
11+
from .base import StrictBaseModel
1212

1313

1414
class AreaOfInterest(StrictBaseModel, ABC):
@@ -19,24 +19,30 @@ async def get_geostore_id(self) -> UUID:
1919

2020

2121
class GeostoreAreaOfInterest(AreaOfInterest):
22-
type: Literal['geostore'] = 'geostore'
23-
geostore_id:UUID = Field(..., title="Geostore ID")
22+
type: Literal["geostore"] = "geostore"
23+
geostore_id: UUID = Field(..., title="Geostore ID")
2424

2525
async def get_geostore_id(self) -> UUID:
2626
return self.geostore_id
2727

2828

2929
class AdminAreaOfInterest(AreaOfInterest):
30-
type: Literal['admin'] = 'admin'
30+
type: Literal["admin"] = "admin"
3131
country: str = Field(..., title="ISO Country Code")
3232
region: Optional[str] = Field(None, title="Region")
3333
subregion: Optional[str] = Field(None, title="Subregion")
34-
provider: str = Field('gadm', title="Administrative Boundary Provider")
35-
version: str = Field('4.1', title="Administrative Boundary Version")
36-
34+
provider: str = Field("gadm", title="Administrative Boundary Provider")
35+
version: str = Field("4.1", title="Administrative Boundary Version")
3736

3837
async def get_geostore_id(self) -> UUID:
39-
admin_level = sum(1 for field in (self.country, self.region, self.subregion) if field is not None) - 1
38+
admin_level = (
39+
sum(
40+
1
41+
for field in (self.country, self.region, self.subregion)
42+
if field is not None
43+
)
44+
- 1
45+
)
4046
geostore_id = await get_gadm_geostore_id(
4147
admin_provider=self.provider,
4248
admin_version=self.version,
@@ -55,14 +61,13 @@ def check_region_subregion(cls, values):
5561
raise ValueError("region must be specified if subregion is provided")
5662
return values
5763

58-
@validator('provider', pre=True, always=True)
64+
@validator("provider", pre=True, always=True)
5965
def set_provider_default(cls, v):
60-
return v or 'gadm'
66+
return v or "gadm"
6167

62-
@validator('version', pre=True, always=True)
68+
@validator("version", pre=True, always=True)
6369
def set_version_default(cls, v):
64-
return v or '4.1'
65-
70+
return v or "4.1"
6671

6772

6873
class AnalysisStatus(str, Enum):
@@ -99,7 +104,9 @@ class DataMartResourceLinkResponse(Response):
99104

100105

101106
class TreeCoverLossByDriverIn(StrictBaseModel):
102-
aoi: Union[GeostoreAreaOfInterest, AdminAreaOfInterest] = Field(..., discriminator='type')
107+
aoi: Union[GeostoreAreaOfInterest, AdminAreaOfInterest] = Field(
108+
..., discriminator="type"
109+
)
103110
canopy_cover: int = 30
104111
dataset_version: Dict[str, str] = {}
105112

@@ -132,6 +139,3 @@ class Config:
132139

133140
class TreeCoverLossByDriverResponse(Response):
134141
data: TreeCoverLossByDriver
135-
136-
137-

app/routes/datamart/__init__.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
"value": {
1414
"type": "geostore",
1515
"geostore_id": "637d378f-93a9-4364-bfa8-95b6afd28c3a",
16-
}
16+
},
1717
},
1818
"Admin Area Of Interest": {
1919
"summary": "Admin Area Of Interest",
@@ -23,16 +23,16 @@
2323
"country": "BRA",
2424
"region": "12",
2525
"subregion": "2",
26-
}
27-
}
26+
},
27+
},
2828
},
2929
"description": "The Area of Interest",
3030
"schema": {
3131
"oneOf": [
3232
{"$ref": "#/components/schemas/GeostoreAreaOfInterest"},
3333
{"$ref": "#/components/schemas/AdminAreaOfInterest"},
3434
]
35-
}
35+
},
3636
},
3737
{
3838
"name": "dataset_version",
@@ -50,7 +50,7 @@
5050
},
5151
"description": (
5252
"Pass dataset version overrides as bracketed query parameters.",
53-
)
54-
}
53+
),
54+
},
5555
]
56-
}
56+
}

app/routes/datamart/land.py

+20-15
Original file line numberDiff line numberDiff line change
@@ -23,29 +23,30 @@
2323
from app.errors import RecordNotFoundError
2424
from app.models.enum.geostore import GeostoreOrigin
2525
from app.models.pydantic.datamart import (
26+
AdminAreaOfInterest,
2627
AnalysisStatus,
28+
AreaOfInterest,
2729
DataMartResource,
2830
DataMartResourceLink,
2931
DataMartResourceLinkResponse,
32+
GeostoreAreaOfInterest,
3033
TreeCoverLossByDriver,
3134
TreeCoverLossByDriverIn,
3235
TreeCoverLossByDriverResponse,
33-
AreaOfInterest,
34-
GeostoreAreaOfInterest,
35-
AdminAreaOfInterest,
3636
)
3737
from app.settings.globals import API_URL
3838
from app.tasks.datamart.land import (
3939
DEFAULT_LAND_DATASET_VERSIONS,
4040
compute_tree_cover_loss_by_driver,
4141
)
4242
from app.utils.geostore import get_geostore
43-
from . import OPENAPI_EXTRA
4443

4544
from ...authentication.api_keys import get_api_key
45+
from . import OPENAPI_EXTRA
4646

4747
router = APIRouter()
4848

49+
4950
def _parse_dataset_versions(request: Request) -> Dict[str, str]:
5051
dataset_versions = {}
5152
errors = []
@@ -69,23 +70,27 @@ def _parse_dataset_versions(request: Request) -> Dict[str, str]:
6970

7071
def _parse_area_of_interest(request: Request) -> AreaOfInterest:
7172
params = request.query_params
72-
aoi_type = params.get('aoi[type]')
73+
aoi_type = params.get("aoi[type]")
7374
try:
74-
if aoi_type == 'geostore':
75-
return GeostoreAreaOfInterest(geostore_id=params.get('aoi[geostore_id]', None))
75+
if aoi_type == "geostore":
76+
return GeostoreAreaOfInterest(
77+
geostore_id=params.get("aoi[geostore_id]", None)
78+
)
7679

7780
# Otherwise, check if the request contains admin area information
78-
if aoi_type == 'admin':
81+
if aoi_type == "admin":
7982
return AdminAreaOfInterest(
80-
country=params.get('aoi[country]', None),
81-
region=params.get('aoi[region]', None),
82-
subregion=params.get('aoi[subregion]', None),
83-
provider=params.get('aoi[provider]', None),
84-
version=params.get('aoi[version]', None),
83+
country=params.get("aoi[country]", None),
84+
region=params.get("aoi[region]", None),
85+
subregion=params.get("aoi[subregion]", None),
86+
provider=params.get("aoi[provider]", None),
87+
version=params.get("aoi[version]", None),
8588
)
8689

8790
# If neither type is provided, raise an error
88-
raise HTTPException(status_code=422, detail="Invalid Area of Interest parameters")
91+
raise HTTPException(
92+
status_code=422, detail="Invalid Area of Interest parameters"
93+
)
8994
except ValidationError as e:
9095
raise HTTPException(status_code=422, detail=e.errors())
9196

@@ -96,7 +101,7 @@ def _parse_area_of_interest(request: Request) -> AreaOfInterest:
96101
response_model=DataMartResourceLinkResponse,
97102
tags=["Land"],
98103
status_code=200,
99-
openapi_extra= OPENAPI_EXTRA,
104+
openapi_extra=OPENAPI_EXTRA,
100105
)
101106
async def tree_cover_loss_by_driver_search(
102107
*,

tests_v2/unit/app/crud/test_geostore.py

+6-8
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
@pytest.mark.asyncio
1111
async def test_get_gadm_geostore_generates_correct_sql_for_country_lookup(
12-
async_client: AsyncClient
12+
async_client: AsyncClient,
1313
):
1414
provider = "gadm"
1515
version = "4.1"
@@ -45,7 +45,7 @@ async def test_get_gadm_geostore_generates_correct_sql_for_country_lookup(
4545

4646
@pytest.mark.asyncio
4747
async def test_get_gadm_geostore_generates_correct_sql_for_region_lookup(
48-
async_client: AsyncClient
48+
async_client: AsyncClient,
4949
):
5050
provider = "gadm"
5151
version = "4.1"
@@ -82,7 +82,7 @@ async def test_get_gadm_geostore_generates_correct_sql_for_region_lookup(
8282

8383
@pytest.mark.asyncio
8484
async def test_get_gadm_geostore_generates_correct_sql_for_subregion_lookup(
85-
async_client: AsyncClient
85+
async_client: AsyncClient,
8686
):
8787
provider = "gadm"
8888
version = "4.1"
@@ -121,7 +121,7 @@ async def test_get_gadm_geostore_generates_correct_sql_for_subregion_lookup(
121121
class TestGadmGeostoreIDLookup:
122122
@pytest.mark.asyncio
123123
async def test_get_gadm_geostore_id_generates_correct_sql_for_country_lookup(
124-
async_client: AsyncClient
124+
async_client: AsyncClient,
125125
):
126126
provider = "gadm"
127127
version = "4.1"
@@ -152,10 +152,9 @@ async def test_get_gadm_geostore_id_generates_correct_sql_for_country_lookup(
152152
assert mock_get_first_row.called is True
153153
assert actual_sql == expected_sql
154154

155-
156155
@pytest.mark.asyncio
157156
async def test_get_gadm_geostore_id_generates_correct_sql_for_region_lookup(
158-
async_client: AsyncClient
157+
async_client: AsyncClient,
159158
):
160159
provider = "gadm"
161160
version = "4.1"
@@ -187,10 +186,9 @@ async def test_get_gadm_geostore_id_generates_correct_sql_for_region_lookup(
187186
assert mock_get_first_row.called is True
188187
assert actual_sql == expected_sql
189188

190-
191189
@pytest.mark.asyncio
192190
async def test_get_gadm_geostore_id_generates_correct_sql_for_subregion_lookup(
193-
async_client: AsyncClient
191+
async_client: AsyncClient,
194192
):
195193
provider = "gadm"
196194
version = "4.1"

0 commit comments

Comments
 (0)