
PRNGKey error #25076

Open
hangita101 opened this issue Nov 24, 2024 · 0 comments
Labels
bug Something isn't working

Comments


hangita101 commented Nov 24, 2024

Description

Calling

key = jax.random.PRNGKey(0)

raises a ValueError: std::bad_cast.

All I ran was:

import jax
from jax import value_and_grad, jit, vmap, grad
import jax.numpy as jnp
import optax

followed by:

key = jax.random.PRNGKey(0)

which produced this error:

{
	"name": "ValueError",
	"message": "std::bad_cast",
	"stack": "---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[5], line 1
----> 1 key = jax.random.PRNGKey(11)

File ~/miniconda3/envs/ML/lib/python3.9/site-packages/jax/_src/random.py:233, in PRNGKey(seed, impl)
    216 def PRNGKey(seed: int | ArrayLike, *,
    217             impl: PRNGSpecDesc | None = None) -> KeyArray:
    218   \"\"\"Create a pseudo-random number generator (PRNG) key given an integer seed.
    219 
    220   The resulting key carries the default PRNG implementation, as
   (...)
    231     and ``fold_in``.
    232   \"\"\"
--> 233   return _return_prng_keys(True, _key('PRNGKey', seed, impl))

File ~/miniconda3/envs/ML/lib/python3.9/site-packages/jax/_src/random.py:195, in _key(ctor_name, seed, impl_spec)
    191 if np.ndim(seed):
    192   raise TypeError(
    193       f\"{ctor_name} accepts a scalar seed, but was given an array of \"
    194       f\"shape {np.shape(seed)} != (). Use jax.vmap for batching\")
--> 195 return prng.random_seed(seed, impl=impl)

File ~/miniconda3/envs/ML/lib/python3.9/site-packages/jax/_src/prng.py:533, in random_seed(seeds, impl)
    528 def random_seed(seeds: int | typing.ArrayLike, impl: PRNGImpl) -> PRNGKeyArray:
    529   # Avoid overflow error in X32 mode by first converting ints to int64.
    530   # This breaks JIT invariance for large ints, but supports the common
    531   # use-case of instantiating with Python hashes in X32 mode.
    532   if isinstance(seeds, int):
--> 533     seeds_arr = jnp.asarray(np.int64(seeds))
    534   else:
    535     seeds_arr = jnp.asarray(seeds)

File ~/miniconda3/envs/ML/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py:3289, in asarray(a, dtype, order, copy)
   3287 if dtype is not None:
   3288   dtype = dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True)  # type: ignore[assignment]
-> 3289 return array(a, dtype=dtype, copy=bool(copy), order=order)

File ~/miniconda3/envs/ML/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py:3214, in array(object, dtype, copy, order, ndmin)
   3211 else:
   3212   raise TypeError(f\"Unexpected input type for array: {type(object)}\")
-> 3214 out_array: Array = lax_internal._convert_element_type(
   3215     out, dtype, weak_type=weak_type)
   3216 if ndmin > ndim(out_array):
   3217   out_array = lax.expand_dims(out_array, range(ndmin - ndim(out_array)))

File ~/miniconda3/envs/ML/lib/python3.9/site-packages/jax/_src/lax/lax.py:559, in _convert_element_type(operand, new_dtype, weak_type)
    557   return type_cast(Array, operand)
    558 else:
--> 559   return convert_element_type_p.bind(operand, new_dtype=new_dtype,
    560                                      weak_type=bool(weak_type))

File ~/miniconda3/envs/ML/lib/python3.9/site-packages/jax/_src/core.py:416, in Primitive.bind(self, *args, **params)
    413 def bind(self, *args, **params):
    414   assert (not config.enable_checks.value or
    415           all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
--> 416   return self.bind_with_trace(find_top_trace(args), args, params)

File ~/miniconda3/envs/ML/lib/python3.9/site-packages/jax/_src/core.py:420, in Primitive.bind_with_trace(self, trace, args, params)
    418 def bind_with_trace(self, trace, args, params):
    419   with pop_level(trace.level):
--> 420     out = trace.process_primitive(self, map(trace.full_raise, args), params)
    421   return map(full_lower, out) if self.multiple_results else full_lower(out)

File ~/miniconda3/envs/ML/lib/python3.9/site-packages/jax/_src/core.py:921, in EvalTrace.process_primitive(self, primitive, tracers, params)
    919   return call_impl_with_key_reuse_checks(primitive, primitive.impl, *tracers, **params)
    920 else:
--> 921   return primitive.impl(*tracers, **params)

File ~/miniconda3/envs/ML/lib/python3.9/site-packages/jax/_src/dispatch.py:87, in apply_primitive(prim, *args, **params)
     85 prev = lib.jax_jit.swap_thread_local_state_disable_jit(False)
     86 try:
---> 87   outs = fun(*args)
     88 finally:
     89   lib.jax_jit.swap_thread_local_state_disable_jit(prev)

ValueError: std::bad_cast"
}

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.4.30
jaxlib: 0.4.30
numpy: 1.26.4
python: 3.9.20 | packaged by conda-forge | (main, Sep 30 2024, 17:49:10) [GCC 13.3.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='fool', release='5.15.167.4-microsoft-standard-WSL2', version='#1 SMP Tue Nov 5 00:21:55 UTC 2024', machine='x86_64')

$ nvidia-smi
Sun Nov 24 15:57:09 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.03 Driver Version: 561.09 CUDA Version: 12.6 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA GeForce RTX 4050 ... On | 00000000:01:00.0 Off | N/A |
| N/A 49C P3 14W / 64W | 79MiB / 6141MiB | 3% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| 0 N/A N/A 45987 C /python3.9 N/A |
+-----------------------------------------------------------------------------------------+

EDIT:
I also tried the new-style constructor and got the same error:
key = jax.random.key(0)
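A minimal reproduction, consolidating the steps above (optax and the other imports are not needed to trigger the error; on a working install this snippet runs cleanly, while on the affected setup both calls raise ValueError: std::bad_cast):

```python
import jax

# Legacy key constructor: returns a uint32 array of shape (2,)
# on a working install; raises ValueError: std::bad_cast here.
key = jax.random.PRNGKey(0)
print(key.shape)  # (2,) on a working install

# New-style typed key (jax >= 0.4.16): returns a scalar key array
# of shape (); fails with the same error on the affected setup.
new_key = jax.random.key(0)
print(new_key.shape)  # () on a working install
```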

@hangita101 hangita101 added the bug Something isn't working label Nov 24, 2024