
Fix torch.full scalar type #7010

Merged
4 commits merged into master on May 1, 2024
Conversation

JackCaoG
Collaborator

@JackCaoG JackCaoG commented May 1, 2024

This should fix #6991.

torch.full takes a scalar, and dtype is an optional parameter. When dtype is not specified, we should respect the scalar's dtype.

Without this change, torch.full((2, 2), False) returns a tensor with dtype float32 instead of bool.
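The inference rule described above can be sketched in plain Python. This is a hypothetical helper for illustration only, not PyTorch/XLA's actual implementation; the dtype names mirror PyTorch's defaults (bool for Python bools, int64 for ints, float32 for floats via the default dtype):

```python
def infer_full_dtype(fill_value, dtype=None):
    """Pick the result dtype for a torch.full-style call.

    When dtype is given explicitly, use it; otherwise respect the
    scalar's own type instead of falling back to float32.
    (Illustrative sketch, not PyTorch's real code.)
    """
    if dtype is not None:
        return dtype
    # bool must be checked before int: in Python, bool is a subclass of int
    if isinstance(fill_value, bool):
        return "bool"
    if isinstance(fill_value, int):
        return "int64"
    if isinstance(fill_value, float):
        return "float32"
    raise TypeError(f"unsupported fill value: {fill_value!r}")
```

Under this rule, a False fill value yields a bool tensor, matching the behavior this PR restores.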

@JackCaoG JackCaoG requested review from wonjoolee95 and lsy323 May 1, 2024 21:03
@JackCaoG JackCaoG merged commit 0a54b2b into master May 1, 2024
20 of 21 checks passed
Development

Successfully merging this pull request may close these issues.

GPT2 CasualLM Inference crashes when using transformers v4.39.0
2 participants