diff --git a/gat/README.md b/gat/README.md index 7bb71bc17b..d7ae967379 100644 --- a/gat/README.md +++ b/gat/README.md @@ -89,6 +89,7 @@ options: epochs to wait for print training and validation evaluation (default: 20) --no-cuda disables CUDA training --no-mps disables macOS GPU training + --no-xpu disables XPU training --dry-run quickly check a single pass --seed S random seed (default: 13) ``` diff --git a/gat/main.py b/gat/main.py index 9c143af8ec..cba703de5c 100644 --- a/gat/main.py +++ b/gat/main.py @@ -303,15 +303,17 @@ def test(model, criterion, input, target, mask): help='dimension of the hidden representation (default: 64)') parser.add_argument('--num-heads', type=int, default=8, help='number of the attention heads (default: 8)') - parser.add_argument('--concat-heads', action='store_true', default=False, + parser.add_argument('--concat-heads', action='store_true', help='whether to concatenate attention heads, or average over them (default: False)') parser.add_argument('--val-every', type=int, default=20, help='epochs to wait for print training and validation evaluation (default: 20)') - parser.add_argument('--no-cuda', action='store_true', default=False, + parser.add_argument('--no-cuda', action='store_true', help='disables CUDA training') - parser.add_argument('--no-mps', action='store_true', default=False, + parser.add_argument('--no-xpu', action='store_true', + help='disables XPU training') + parser.add_argument('--no-mps', action='store_true', help='disables macOS GPU training') - parser.add_argument('--dry-run', action='store_true', default=False, + parser.add_argument('--dry-run', action='store_true', help='quickly check a single pass') parser.add_argument('--seed', type=int, default=13, metavar='S', help='random seed (default: 13)') @@ -320,12 +322,15 @@ def test(model, criterion, input, target, mask): torch.manual_seed(args.seed) use_cuda = not args.no_cuda and torch.cuda.is_available() use_mps = not args.no_mps and 
torch.backends.mps.is_available() + use_xpu = not args.no_xpu and torch.xpu.is_available() # Set the device to run on if use_cuda: device = torch.device('cuda') elif use_mps: device = torch.device('mps') + elif use_xpu: + device = torch.device('xpu') else: device = torch.device('cpu') print(f'Using {device} device')