Fixup no_trainer examples scripts and add more tests #16765

Merged (5 commits) on Apr 13, 2022
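
This PR makes the same set of fixes across the `no_trainer` example scripts and their tests: `--with_tracking` becomes a real boolean flag (`action="store_true"` instead of `required=False`), the `Accelerator` is given a `logging_dir`, `init_trackers` receives a plain dict built from `vars(args)` (with the Enum-valued `lr_scheduler_type` converted to its raw value) under each script's own name instead of a copy-pasted `"clm_no_trainer"`, and the step is folded into the logged metrics dict. Below is a minimal sketch of the resulting tracking flow, not code from the diff; the project name and the extra arguments are placeholders.

```python
import argparse

from accelerate import Accelerator

parser = argparse.ArgumentParser()
parser.add_argument(
    "--with_tracking",
    action="store_true",  # bare flag: defaults to False, no value expected
    help="Whether to load in all available experiment trackers from the environment and use them for logging.",
)
parser.add_argument("--output_dir", type=str, default=None)
args = parser.parse_args()

# logging_dir tells the trackers (e.g. TensorBoard) where to write their files.
accelerator = Accelerator(log_with="all", logging_dir=args.output_dir) if args.with_tracking else Accelerator()

if args.with_tracking:
    # Trackers want a plain dict, not an argparse.Namespace.
    experiment_config = vars(args)
    # TensorBoard cannot log Enums; an Enum-valued field such as lr_scheduler_type needs its raw value:
    # experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
    accelerator.init_trackers("my_project", experiment_config)  # "my_project" is a placeholder name

# During training, the step is logged as part of the metrics dict.
if args.with_tracking:
    accelerator.log({"train_loss": 0.0, "epoch": 0, "step": 0})
```
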
16 changes: 7 additions & 9 deletions examples/pytorch/language-modeling/run_clm_no_trainer.py
@@ -200,7 +200,7 @@ def parse_args():
)
parser.add_argument(
"--with_tracking",
required=False,
action="store_true",
help="Whether to load in all available experiment trackers from the environment and use them for logging.",
)
args = parser.parse_args()
@@ -227,7 +227,7 @@ def main():

# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
# If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
accelerator = Accelerator(log_with="all") if args.with_tracking else Accelerator()
accelerator = Accelerator(log_with="all", logging_dir=args.output_dir) if args.with_tracking else Accelerator()
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -485,7 +485,10 @@ def group_texts(examples):

# We need to initialize the trackers we use, and also store our configuration
if args.with_tracking:
accelerator.init_trackers("clm_no_trainer", args)
experiment_config = vars(args)
# TensorBoard cannot log Enums, need the raw value
experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
accelerator.init_trackers("clm_no_trainer", experiment_config)

# Train!
total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
@@ -571,12 +574,7 @@ def group_texts(examples):

if args.with_tracking:
accelerator.log(
{
"perplexity": perplexity,
"train_loss": total_loss,
"epoch": epoch,
},
step=completed_steps,
{"perplexity": perplexity, "train_loss": total_loss, "epoch": epoch, "step": completed_steps},
)

if args.push_to_hub and epoch < args.num_train_epochs - 1:
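
A note on the argparse change repeated in every script above and below: with only `required=False`, `--with_tracking` is still a valued option, so passing the bare flag fails with "expected one argument", and any value you do pass (even the string "false") is truthy. `action="store_true"` turns it into a proper boolean flag. A small standalone illustration, not taken from the diff:

```python
import argparse

broken = argparse.ArgumentParser()
broken.add_argument("--with_tracking", required=False)  # old definition: still expects a value
# broken.parse_args(["--with_tracking"])  # error: argument --with_tracking: expected one argument

fixed = argparse.ArgumentParser()
fixed.add_argument("--with_tracking", action="store_true")  # new definition: boolean flag
print(fixed.parse_args([]).with_tracking)                    # False
print(fixed.parse_args(["--with_tracking"]).with_tracking)   # True
```
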
16 changes: 7 additions & 9 deletions examples/pytorch/language-modeling/run_mlm_no_trainer.py
@@ -209,7 +209,7 @@ def parse_args():
)
parser.add_argument(
"--with_tracking",
required=False,
action="store_true",
help="Whether to load in all available experiment trackers from the environment and use them for logging.",
)
args = parser.parse_args()
@@ -238,7 +238,7 @@ def main():

# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
# If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
accelerator = Accelerator(log_with="all") if args.with_tracking else Accelerator()
accelerator = Accelerator(log_with="all", logging_dir=args.output_dir) if args.with_tracking else Accelerator()
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -531,7 +531,10 @@ def group_texts(examples):

# We need to initialize the trackers we use, and also store our configuration
if args.with_tracking:
accelerator.init_trackers("clm_no_trainer", args)
experiment_config = vars(args)
# TensorBoard cannot log Enums, need the raw value
experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
accelerator.init_trackers("mlm_no_trainer", experiment_config)

# Train!
total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
@@ -618,12 +621,7 @@ def group_texts(examples):

if args.with_tracking:
accelerator.log(
{
"perplexity": perplexity,
"train_loss": total_loss,
"epoch": epoch,
},
step=completed_steps,
{"perplexity": perplexity, "train_loss": total_loss, "epoch": epoch, "step": completed_steps},
)

if args.push_to_hub and epoch < args.num_train_epochs - 1:
16 changes: 7 additions & 9 deletions examples/pytorch/multiple-choice/run_swag_no_trainer.py
@@ -192,7 +192,7 @@ def parse_args():
)
parser.add_argument(
"--with_tracking",
required=False,
action="store_true",
help="Whether to load in all available experiment trackers from the environment and use them for logging.",
)
args = parser.parse_args()
@@ -265,7 +265,7 @@ def main():

# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
# If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
accelerator = Accelerator(log_with="all") if args.with_tracking else Accelerator()
accelerator = Accelerator(log_with="all", logging_dir=args.output_dir) if args.with_tracking else Accelerator()
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -485,7 +485,10 @@ def preprocess_function(examples):

# We need to initialize the trackers we use, and also store our configuration
if args.with_tracking:
accelerator.init_trackers("clm_no_trainer", args)
experiment_config = vars(args)
# TensorBoard cannot log Enums, need the raw value
experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
accelerator.init_trackers("swag_no_trainer", experiment_config)

# Metrics
metric = load_metric("accuracy")
@@ -570,12 +573,7 @@ def preprocess_function(examples):

if args.with_tracking:
accelerator.log(
{
"accuracy": eval_metric,
"train_loss": total_loss,
"epoch": epoch,
},
step=completed_steps,
{"accuracy": eval_metric, "train_loss": total_loss, "epoch": epoch, "step": completed_steps},
)

if args.push_to_hub and epoch < args.num_train_epochs - 1:
examples/pytorch/question-answering/run_qa_beam_search_no_trainer.py
@@ -224,7 +224,7 @@ def parse_args():
)
parser.add_argument(
"--with_tracking",
required=False,
action="store_true",
help="Whether to load in all available experiment trackers from the environment and use them for logging.",
)
args = parser.parse_args()
@@ -259,7 +259,7 @@ def main():

# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
# If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
accelerator = Accelerator(log_with="all") if args.with_tracking else Accelerator()
accelerator = Accelerator(log_with="all", logging_dir=args.output_dir) if args.with_tracking else Accelerator()
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -723,7 +723,10 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len):

# We need to initialize the trackers we use, and also store our configuration
if args.with_tracking:
accelerator.init_trackers("clm_no_trainer", args)
experiment_config = vars(args)
# TensorBoard cannot log Enums, need the raw value
experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
accelerator.init_trackers("qa_beam_search_no_trainer", experiment_config)

# Train!
total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
@@ -916,11 +919,12 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
"squad_v2" if args.version_2_with_negative else "squad": eval_metric,
"train_loss": total_loss,
"epoch": epoch,
"step": completed_steps,
}
if args.do_predict:
log["squad_v2_predict" if args.version_2_with_negative else "squad_predict"] = predict_metric

accelerator.log(log, step=completed_steps)
accelerator.log(log)

if args.checkpointing_steps == "epoch":
accelerator.save_state(f"epoch_{epoch}")
12 changes: 8 additions & 4 deletions examples/pytorch/question-answering/run_qa_no_trainer.py
@@ -254,7 +254,7 @@ def parse_args():
)
parser.add_argument(
"--with_tracking",
required=False,
action="store_true",
help="Whether to load in all available experiment trackers from the environment and use them for logging.",
)
args = parser.parse_args()
@@ -289,7 +289,7 @@ def main():

# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
# If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
accelerator = Accelerator(log_with="all") if args.with_tracking else Accelerator()
accelerator = Accelerator(log_with="all", logging_dir=args.output_dir) if args.with_tracking else Accelerator()
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -730,7 +730,10 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len):

# We need to initialize the trackers we use, and also store our configuration
if args.with_tracking:
accelerator.init_trackers("clm_no_trainer", args)
experiment_config = vars(args)
# TensorBoard cannot log Enums, need the raw value
experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
accelerator.init_trackers("qa_no_trainer", experiment_config)

# Train!
total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
@@ -889,11 +892,12 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
"squad_v2" if args.version_2_with_negative else "squad": eval_metric,
"train_loss": total_loss,
"epoch": epoch,
"step": completed_steps,
}
if args.do_predict:
log["squad_v2_predict" if args.version_2_with_negative else "squad_predict"] = predict_metric

accelerator.log(log, step=completed_steps)
accelerator.log(log)

if args.output_dir is not None:
accelerator.wait_for_everyone()
12 changes: 8 additions & 4 deletions examples/pytorch/summarization/run_summarization_no_trainer.py
@@ -277,7 +277,7 @@ def parse_args():
)
parser.add_argument(
"--with_tracking",
required=False,
action="store_true",
help="Whether to load in all available experiment trackers from the environment and use them for logging.",
)
args = parser.parse_args()
@@ -315,7 +315,7 @@ def main():
)
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
# If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
accelerator = Accelerator(log_with="all") if args.with_tracking else Accelerator()
accelerator = Accelerator(log_with="all", logging_dir=args.output_dir) if args.with_tracking else Accelerator()
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -548,7 +548,10 @@ def postprocess_text(preds, labels):

# We need to initialize the trackers we use, and also store our configuration
if args.with_tracking:
accelerator.init_trackers("summarization_no_trainer", args)
experiment_config = vars(args)
# TensorBoard cannot log Enums, need the raw value
experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
accelerator.init_trackers("summarization_no_trainer", experiment_config)

# Metric
metric = load_metric("rouge")
@@ -666,7 +669,8 @@ def postprocess_text(preds, labels):
if args.with_tracking:
result["train_loss"] = total_loss
result["epoch"] = epoch
accelerator.log(result, step=completed_steps)
result["step"] = completed_steps
accelerator.log(result)

if args.push_to_hub and epoch < args.num_train_epochs - 1:
accelerator.wait_for_everyone()
20 changes: 18 additions & 2 deletions examples/pytorch/test_accelerate_examples.py
@@ -104,7 +104,8 @@ def test_run_glue_no_trainer(self):
--learning_rate=1e-4
--seed=42
--checkpointing_steps epoch
""".split()
--with_tracking
""".split()

if is_cuda_and_apex_available():
testargs.append("--fp16")
@@ -114,6 +115,7 @@ def test_run_glue_no_trainer(self):
result = get_results(tmp_dir)
self.assertGreaterEqual(result["eval_accuracy"], 0.75)
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "glue_no_trainer")))

def test_run_clm_no_trainer(self):
tmp_dir = self.get_auto_remove_tmp_dir()
@@ -128,7 +130,8 @@ def test_run_clm_no_trainer(self):
--num_train_epochs 2
--output_dir {tmp_dir}
--checkpointing_steps epoch
""".split()
--with_tracking
""".split()

if torch.cuda.device_count() > 1:
# Skipping because there are not enough batches to train the model + would need a drop_last to work.
@@ -139,6 +142,7 @@ def test_run_clm_no_trainer(self):
result = get_results(tmp_dir)
self.assertLess(result["perplexity"], 100)
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "clm_no_trainer")))

def test_run_mlm_no_trainer(self):
tmp_dir = self.get_auto_remove_tmp_dir()
@@ -150,13 +154,15 @@ def test_run_mlm_no_trainer(self):
--output_dir {tmp_dir}
--num_train_epochs=1
--checkpointing_steps epoch
--with_tracking
""".split()

with patch.object(sys, "argv", testargs):
run_mlm_no_trainer.main()
result = get_results(tmp_dir)
self.assertLess(result["perplexity"], 42)
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "mlm_no_trainer")))

def test_run_ner_no_trainer(self):
# with so little data distributed training needs more epochs to get the score on par with 0/1 gpu
@@ -175,6 +181,7 @@ def test_run_ner_no_trainer(self):
--num_train_epochs={epochs}
--seed 7
--checkpointing_steps epoch
--with_tracking
""".split()

with patch.object(sys, "argv", testargs):
@@ -183,6 +190,7 @@ def test_run_ner_no_trainer(self):
self.assertGreaterEqual(result["eval_accuracy"], 0.75)
self.assertLess(result["train_loss"], 0.5)
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "ner_no_trainer")))

def test_run_squad_no_trainer(self):
tmp_dir = self.get_auto_remove_tmp_dir()
@@ -199,6 +207,7 @@ def test_run_squad_no_trainer(self):
--per_device_train_batch_size=2
--per_device_eval_batch_size=1
--checkpointing_steps epoch
--with_tracking
""".split()

with patch.object(sys, "argv", testargs):
@@ -207,6 +216,7 @@ def test_run_squad_no_trainer(self):
self.assertGreaterEqual(result["eval_f1"], 30)
self.assertGreaterEqual(result["eval_exact"], 30)
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "qa_no_trainer")))

def test_run_swag_no_trainer(self):
tmp_dir = self.get_auto_remove_tmp_dir()
@@ -221,12 +231,14 @@ def test_run_swag_no_trainer(self):
--learning_rate=2e-4
--per_device_train_batch_size=2
--per_device_eval_batch_size=1
--with_tracking
""".split()

with patch.object(sys, "argv", testargs):
run_swag_no_trainer.main()
result = get_results(tmp_dir)
self.assertGreaterEqual(result["eval_accuracy"], 0.8)
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "swag_no_trainer")))

@slow
def test_run_summarization_no_trainer(self):
@@ -243,6 +255,7 @@ def test_run_summarization_no_trainer(self):
--per_device_train_batch_size=2
--per_device_eval_batch_size=1
--checkpointing_steps epoch
--with_tracking
""".split()

with patch.object(sys, "argv", testargs):
@@ -253,6 +266,7 @@ def test_run_summarization_no_trainer(self):
self.assertGreaterEqual(result["eval_rougeL"], 7)
self.assertGreaterEqual(result["eval_rougeLsum"], 7)
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "summarization_no_trainer")))

@slow
def test_run_translation_no_trainer(self):
@@ -273,10 +287,12 @@ def test_run_translation_no_trainer(self):
--source_lang en_XX
--target_lang ro_RO
--checkpointing_steps epoch
--with_tracking
""".split()

with patch.object(sys, "argv", testargs):
run_translation_no_trainer.main()
result = get_results(tmp_dir)
self.assertGreaterEqual(result["eval_bleu"], 30)
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "translation_no_trainer")))