Commit b37284e

🚀 Add XPU accelerator (open-edge-platform#2530)
* Add XPU accelerator
* Update changelog
* precommit
* Add documentation

Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com>
1 parent 3e64f09 commit b37284e

File tree

9 files changed: +218 -2 lines changed

CHANGELOG.md

+2
@@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- 🚀 Add XPU accelerator and strategy by @ashwinvaidya17 in https://github.com/openvinotoolkit/anomalib/pull/2530
+
 ### Removed
 
 ### Changed

README.md

+28
@@ -180,6 +180,34 @@ anomalib predict --model anomalib.models.Patchcore \
 
 > 📘 **Note:** For advanced inference options including Gradio and OpenVINO, check our [Inference Documentation](https://anomalib.readthedocs.io).
 
+# Training on Intel GPUs
+
+> [!Note]
+> Currently, only single GPU training is supported on Intel GPUs.
+> These commands were tested on Arc 750 and Arc 770.
+
+Ensure that you have PyTorch with XPU support installed. For more information, please refer to the [PyTorch XPU documentation](https://pytorch.org/docs/stable/notes/get_start_xpu.html).
+
+## 🔌 API
+
+```python
+from anomalib.data import MVTec
+from anomalib.engine import Engine, SingleXPUStrategy, XPUAccelerator
+from anomalib.models import Stfpm
+
+engine = Engine(
+    strategy=SingleXPUStrategy(),
+    accelerator=XPUAccelerator(),
+)
+engine.train(Stfpm(), datamodule=MVTec())
+```
+
+## ⌨️ CLI
+
+```bash
+anomalib train --model Padim --data MVTec --trainer.accelerator xpu --trainer.strategy xpu_single
+```
+
 # ⚙️ Hyperparameter Optimization
 
 Anomalib supports hyperparameter optimization (HPO) using [Weights & Biases](https://wandb.ai/) and [Comet.ml](https://www.comet.com/).

docs/source/markdown/guides/how_to/index.md

+8
@@ -72,6 +72,13 @@ Learn more about anomalib's deployment capabilities
 Learn more about anomalib hpo, sweep and benchmarking pipelines
 :::
 
+:::{grid-item-card} {octicon}`cpu` Training on Intel GPUs
+:link: ./training_on_intel_gpus/index
+:link-type: doc
+
+Learn more about training on Intel GPUs
+:::
+
 ::::
 
 ```{toctree}
@@ -83,4 +90,5 @@ Learn more about anomalib hpo, sweep and benchmarking pipelines
 ./models/index
 ./pipelines/index
 ./visualization/index
+./training_on_intel_gpus/index
 ```
docs/source/markdown/guides/how_to/training_on_intel_gpus/index.md

+52
@@ -0,0 +1,52 @@
# Training on Intel GPUs

This tutorial demonstrates how to train a model on Intel GPUs using anomalib.
Anomalib comes with an XPU accelerator and strategy for PyTorch Lightning. This allows you to train your models on Intel GPUs.

> [!Note]
> Currently, only single GPU training is supported on Intel GPUs.
> These commands were tested on Arc 750 and Arc 770.

## Installing Drivers

First, check if you have the correct drivers installed. If you are on Ubuntu, you can refer to the [following guide](https://dgpu-docs.intel.com/driver/client/overview.html).

Another recommended tool is `xpu-smi`, which can be installed from the [releases](https://github.com/intel/xpumanager) page.

If everything is installed correctly, you should be able to see your card using the following command:

```bash
xpu-smi discovery
```

## Installing PyTorch

Then, ensure that you have PyTorch with XPU support installed. For more information, please refer to the [PyTorch XPU documentation](https://pytorch.org/docs/stable/notes/get_start_xpu.html).

To check that your PyTorch installation supports XPU, you can run the following command:

```bash
python -c "import torch; print(torch.xpu.is_available())"
```

If the command returns `True`, then your PyTorch installation supports XPU.
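
For a slightly more detailed check, a minimal sketch along these lines can also list the XPUs that PyTorch sees (this assumes a PyTorch build where `torch.xpu` exposes `device_count()` and `get_device_name()`):

```python
# Sketch: enumerate the XPU devices visible to PyTorch.
# Assumes a PyTorch build with the torch.xpu backend available.
import torch

if hasattr(torch, "xpu") and torch.xpu.is_available():
    for idx in range(torch.xpu.device_count()):
        print(f"xpu:{idx} -> {torch.xpu.get_device_name(idx)}")
else:
    print("No XPU device detected by this PyTorch build.")
```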

## 🔌 API

```python
from anomalib.data import MVTec
from anomalib.engine import Engine, SingleXPUStrategy, XPUAccelerator
from anomalib.models import Stfpm

engine = Engine(
    strategy=SingleXPUStrategy(),
    accelerator=XPUAccelerator(),
)
engine.train(Stfpm(), datamodule=MVTec())
```

## ⌨️ CLI

```bash
anomalib train --model Padim --data MVTec --trainer.accelerator xpu --trainer.strategy xpu_single
```

src/anomalib/engine/__init__.py

+4-2
@@ -23,9 +23,11 @@
 >>> engine = Engine(config=config)  # doctest: +SKIP
 """
 
-# Copyright (C) 2024 Intel Corporation
+# Copyright (C) 2024-2025 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
+from .accelerator import XPUAccelerator
 from .engine import Engine
+from .strategy import SingleXPUStrategy
 
-__all__ = ["Engine"]
+__all__ = ["Engine", "SingleXPUStrategy", "XPUAccelerator"]
src/anomalib/engine/accelerator/__init__.py

+8
@@ -0,0 +1,8 @@
"""Accelerator for Lightning Trainer."""

# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from .xpu import XPUAccelerator

__all__ = ["XPUAccelerator"]
src/anomalib/engine/accelerator/xpu.py

+65
@@ -0,0 +1,65 @@
"""XPU Accelerator."""

# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from typing import Any

import torch
from lightning.pytorch.accelerators import Accelerator, AcceleratorRegistry


class XPUAccelerator(Accelerator):
    """Support for an XPU, optimized for large-scale machine learning."""

    accelerator_name = "xpu"

    @staticmethod
    def setup_device(device: torch.device) -> None:
        """Sets up the specified device."""
        if device.type != "xpu":
            msg = f"Device should be xpu, got {device} instead"
            raise RuntimeError(msg)

        torch.xpu.set_device(device)

    @staticmethod
    def parse_devices(devices: str | list | torch.device) -> list:
        """Parses devices for multi-GPU training."""
        if isinstance(devices, list):
            return devices
        return [devices]

    @staticmethod
    def get_parallel_devices(devices: list) -> list[torch.device]:
        """Generates a list of parallel devices."""
        return [torch.device("xpu", idx) for idx in devices]

    @staticmethod
    def auto_device_count() -> int:
        """Returns number of XPU devices available."""
        return torch.xpu.device_count()

    @staticmethod
    def is_available() -> bool:
        """Checks if XPU is available."""
        return hasattr(torch, "xpu") and torch.xpu.is_available()

    @staticmethod
    def get_device_stats(device: str | torch.device) -> dict[str, Any]:
        """Returns XPU device stats."""
        del device  # Unused
        return {}

    def teardown(self) -> None:
        """Teardown the XPU accelerator.

        This method is empty as it needs to be overridden otherwise the base class will throw an error.
        """


AcceleratorRegistry.register(
    XPUAccelerator.accelerator_name,
    XPUAccelerator,
    description="Accelerator supports XPU devices",
)
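
Because the accelerator is registered with Lightning's `AcceleratorRegistry` under the name `xpu` (and the strategy below under `xpu_single`), both should also be selectable by their registered names, mirroring the CLI flags shown in the README. A minimal sketch, assuming that importing `anomalib.engine` is what triggers the registration:

```python
# Sketch: select the registered accelerator/strategy by name in a plain Lightning Trainer.
from lightning.pytorch import Trainer

import anomalib.engine  # noqa: F401  # importing anomalib.engine registers "xpu" and "xpu_single"

# Roughly equivalent to `--trainer.accelerator xpu --trainer.strategy xpu_single` on the CLI.
trainer = Trainer(accelerator="xpu", strategy="xpu_single", devices=1)
```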
src/anomalib/engine/strategy/__init__.py

+8
@@ -0,0 +1,8 @@
"""Strategy for Lightning Trainer."""

# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from .xpu_single import SingleXPUStrategy

__all__ = ["SingleXPUStrategy"]
src/anomalib/engine/strategy/xpu_single.py

+43
@@ -0,0 +1,43 @@
"""Lightning strategy for single XPU device."""

# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import lightning.pytorch as pl
import torch
from lightning.pytorch.strategies import SingleDeviceStrategy, StrategyRegistry
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning_fabric.plugins import CheckpointIO
from lightning_fabric.plugins.precision import Precision
from lightning_fabric.utilities.types import _DEVICE


class SingleXPUStrategy(SingleDeviceStrategy):
    """Strategy for training on a single XPU device."""

    strategy_name = "xpu_single"

    def __init__(
        self,
        device: _DEVICE = "xpu:0",
        accelerator: pl.accelerators.Accelerator | None = None,
        checkpoint_io: CheckpointIO | None = None,
        precision_plugin: Precision | None = None,
    ) -> None:
        if not (hasattr(torch, "xpu") and torch.xpu.is_available()):
            msg = "`SingleXPUStrategy` requires XPU devices to run"
            raise MisconfigurationException(msg)

        super().__init__(
            accelerator=accelerator,
            device=device,
            checkpoint_io=checkpoint_io,
            precision_plugin=precision_plugin,
        )


StrategyRegistry.register(
    SingleXPUStrategy.strategy_name,
    SingleXPUStrategy,
    description="Strategy that enables training on single XPU",
)
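
Since the constructor raises `MisconfigurationException` when no XPU is present, code that should also run on machines without an Intel GPU can guard the instantiation. A small sketch, assuming the default `Engine` configuration is an acceptable fallback:

```python
# Sketch: fall back to Engine defaults when no XPU is available,
# since SingleXPUStrategy() raises MisconfigurationException otherwise.
import torch

from anomalib.engine import Engine, SingleXPUStrategy, XPUAccelerator

if hasattr(torch, "xpu") and torch.xpu.is_available():
    engine = Engine(strategy=SingleXPUStrategy(), accelerator=XPUAccelerator())
else:
    engine = Engine()
```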
