Skip to content

Commit

Permalink
add arg num_workers for ernie3.0
Browse files Browse the repository at this point in the history
  • Loading branch information
cyber-pioneer committed Apr 10, 2023
1 parent c2f7e31 commit dc08470
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def create_data_loader(self, args, **kwargs):
train_loader = DataLoader(
dataset=train_ds,
batch_sampler=train_batch_sampler,
num_workers=4, # when paddlepaddle<=2.4.1, if we use dynamicTostatic mode, we need set num_workeks > 0
num_workers=args.num_workers, # when paddlepaddle<=2.4.1, if we use dynamicTostatic mode, we need set num_workeks > 0
)

self.num_batch = len(train_loader)
Expand Down
6 changes: 6 additions & 0 deletions tests/test_tipc/benchmark/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,12 @@ def get_parser():
parser.add_argument("--epoch", type=int, default=10, help="Number of epochs. ")

parser.add_argument("--generated_inputs", action="store_true", help="Use generated inputs. ")
parser.add_argument(
"--num_workers",
type=int,
default=4,
help="num_workers of dataloader. When paddlepaddle<=2.4.1, if we use dynamicTostatic mode, we need set num_workeks > 0 ",
)

# For benchmark.
parser.add_argument(
Expand Down

0 comments on commit dc08470

Please sign in to comment.