
Input type of "conv" does not allow u8 (f32 ok). #63

Closed
ndrean opened this issue Dec 7, 2024 · 8 comments · Fixed by #65
Labels
bug Something isn't working good first issue Good for newcomers

Comments

@ndrean

ndrean commented Dec 7, 2024

(edited for clarity)

MLX does not support integer types in matrix multiplication or convolutions.

Since Nx always expects a floating-point/complex output from Nx.conv, we can upcast the inputs beforehand.

Complex conv might need the same trick used in Torchx.Backend

@polvalente
Collaborator

@ndrean thanks for the report. I've edited the issue to explain the problem and a suggested solution. Pull requests are welcome!

@polvalente polvalente added bug Something isn't working good first issue Good for newcomers labels Dec 7, 2024
@ndrean
Author

ndrean commented Dec 7, 2024

The "hack" is easy: Nx.as_type(:f32)

@polvalente
Collaborator

I suggest Nx.Type.to_floating instead, as that will avoid upcasts from f16/f8
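For illustration, the promotion rule here can be sketched in plain Elixir. This is a hypothetical standalone version (not the actual Nx source): float and complex types pass through untouched, while integer types are upcast to f32, so small float inputs are never needlessly widened.

```elixir
# Hypothetical sketch of a to_floating-style promotion rule (not the Nx source).
defmodule TypeSketch do
  # Floating, brain-float, and complex types are already floating: keep them.
  def to_floating({:f, size}), do: {:f, size}
  def to_floating({:bf, size}), do: {:bf, size}
  def to_floating({:c, size}), do: {:c, size}
  # Any integer type (signed or unsigned, any width) is promoted to f32.
  def to_floating({_int, _size}), do: {:f, 32}
end

# TypeSketch.to_floating({:u, 8})  → {:f, 32}
# TypeSketch.to_floating({:f, 16}) → {:f, 16}  (no upcast)
```

A blanket Nx.as_type(:f32) would instead force f16 inputs up to f32, which is exactly the upcast this avoids.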

@ndrean
Author

ndrean commented Dec 7, 2024

Hmm, I tested Axon.conv and it works with the conversion.
That is fine because the coefficients are learned (training will fill them in).
But if you supply the coefficients yourself and do a convolution, you need to reverse the second input, which fails with float types.

Nx.reverse raises:

no function clause matching in Nx.Type.sort/2    
    
    The following arguments were given to Nx.Type.sort/2:
    
        # 1
        #Nx.Tensor<
          s32[3]
          
          Nx.Defn.Expr
          a = stack [1, 1, 1], 0   s32[3]
        >
    
        # 2
        {:f, 32}

What I mean is that the "real" (signal-processing) implementation should probably live in NxSignal.conv instead, if Nx.conv is "reserved" for Axon. I don't know if I am being clear.

@polvalente
Collaborator

This is probably not a relevant conversation for this issue. I'm also not sure why the type would affect the ability to reverse an input.

@polvalente
Collaborator

Ah, right. This is most likely just a bug in EMLX.Backend.reverse.

@polvalente
Collaborator

Also, there is ongoing work on NxSignal.Convolutions :)

@ndrean
Author

ndrean commented Dec 7, 2024

Ok, it was because I used Nx.Type.to_floating incorrectly. This works:

defmodule NxPoly do
  import Nx.Defn

  defn tprod(t1, t2) do
    t1 = Nx.stack(t1) |> Nx.as_type(:f32)
    t1 = Nx.reshape(t1, {1, 1, Nx.size(t1)})
    t2 = Nx.stack(t2) |> Nx.as_type(:f32)
    t2 = Nx.reshape(t2, {1, 1, Nx.size(t2)})

    Nx.conv(t1, Nx.reverse(t2), padding: [{2, 2}])
  end
end
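For reference, since the conv op does not flip the kernel itself (hence the explicit Nx.reverse), with padding [{2, 2}] tprod computes the full convolution of the two coefficient lists, i.e. the polynomial product. A plain-Elixir sketch of that computation (module and function names are illustrative, no Nx required):

```elixir
# Full (polynomial) convolution of two coefficient lists in plain Elixir.
# Coefficient k of the product is the sum over i of a[i] * b[k - i].
defmodule PolySketch do
  def prod(a, b) do
    n = length(a) + length(b) - 1

    for k <- 0..(n - 1) do
      a
      |> Enum.with_index()
      |> Enum.reduce(0, fn {ai, i}, acc ->
        j = k - i
        # Only accumulate terms where b[j] exists.
        if j >= 0 and j < length(b), do: acc + ai * Enum.at(b, j), else: acc
      end)
    end
  end
end

# PolySketch.prod([1, 1, 1], [1, 1, 1]) → [1, 2, 3, 2, 1]
```

That result matches the five output values tprod produces for two length-3 inputs with {2, 2} padding.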
