Skip to content

Allow passing float for scalar values #314

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 6 commits into
base: master
Choose a base branch
from

Conversation

Hespe
Copy link
Member

@Hespe Hespe commented Dec 20, 2024

Description

Motivation and Context

Types of changes

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)

Checklist

  • I have updated the changelog accordingly (required).
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.
  • I have reformatted the code and checked that formatting passes (required).
  • I have have fixed all issues found by flake8 (required).
  • I have ensured that all pytest tests pass (required).
  • I have run pytest on a machine with a CUDA GPU and made sure all tests pass (required).
  • I have checked that the documentation builds (required).

Note: We are using a maximum length of 88 characters per line.

@Hespe Hespe added the enhancement New feature or request label Dec 20, 2024
@Hespe Hespe self-assigned this Dec 20, 2024
@jank324
Copy link
Member

jank324 commented Apr 9, 2025

Yeah, I'm not sure that we will merge this PR though. I also like the idea of only having to pass a float, but there are some arguments against doing this:

  • For consistency, if you can do
quad = cheetah.Quadrupole(k1=4.2)

you should also be able to do

quad.k1 = 4.2

This would lead to a lot of boiler plate code inside of Cheetah that wouldn't be nice and could also slow it down a little.

  • Doing this would also result in an incosistency because
my_k1 = 4.2
quad.k1 = my_k1
quad_k1 == my_k1   # Would be False even though you would expect True
  • Overall, this feels a little like it would go against the Zen of Python "Explicit is better than implicit"

That said, we don't endorse it, but in most cases I believe you can just pass floats to inits in the current version of Cheetah 🤔

@Hespe
Copy link
Member Author

Hespe commented Apr 10, 2025

Just to point that out, passing floats is currently prevented by verify_device_and_dtype. So an option might actually be to reduce this PR to adjusting verify_device_and_dtype for float inputs.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Allow passing Python float instead of torch.Tensor for scalar arguments
2 participants