Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(config): Move device & amp args to PreTrainedConfig #812

Merged
merged 11 commits into from
Mar 6, 2025

Conversation

imstevenpmwork
Copy link
Collaborator

@imstevenpmwork imstevenpmwork commented Mar 4, 2025

What this does

Related to: #683

This PR moves the device and use_amp configuration arguments from TrainPipelineConfig class to the PreTrainedConfig class. This change affects some other classes and functions in the code path, including:

  • The check of such arguments is now executed in the __post_init__ method of the PreTrainedConfig class
  • make_policy: No longer needs device input arg, as this is now in cfg
  • from_pretrained: No longer needs device input arg, as this is now in config
  • All data classes owning a PreTrainedConfig object or inheriting from its class no longer need the device and use_amp attributes
  • Pretty much every time we fetch these values, we do so now via the PreTrainedConfig class.

In order to fully integrate and test this PR (in the CI), we will need to merge these sisters PRs in our HF Hub repository for the public models:

Note: This PR is meant to be squash merged

How it was tested

  • uv run pytest didn't have any failing test

How to checkout & try? (for the reviewer)

  • Checkout branch and run: uv run pytest

@imstevenpmwork imstevenpmwork self-assigned this Mar 4, 2025
@imstevenpmwork imstevenpmwork added bug Something isn’t working correctly refactor Code cleanup or restructuring without changing behavior configuration Problems with configuration files or settings labels Mar 4, 2025
Copy link
Collaborator

@Cadene Cadene left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Really nice! I left a few questions

…ntrol_loop + ensure proper use of device str or torch.device

fix(configs/policies): Ensure try_device is a str

refactor(ci): update config changes to makefile

fix(scripts/control_robot): Take arg from config instead of policy

refactor(robot_devices/control_utils): Remove device and use_amps args from the record_episode and control_loop functions

fix(config): ensure consistency when using device as str or as a torch.device
@imstevenpmwork imstevenpmwork force-pushed the refactor/pretrainedconfig_device_amps_args branch from 5bd03d3 to d2046d4 Compare March 4, 2025 23:50
Copy link
Collaborator

@aliberts aliberts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall LGTM (modulo my comment over device from pretrained) but I realize there's currently one big caveat for this:
Using --policy.path with some --policy overrides currently doesn't work. We should probably do something about it before merging. WDYT?

@aliberts aliberts merged commit 5e94738 into main Mar 6, 2025
7 checks passed
@aliberts aliberts deleted the refactor/pretrainedconfig_device_amps_args branch March 6, 2025 16:59
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn’t working correctly configuration Problems with configuration files or settings refactor Code cleanup or restructuring without changing behavior
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants