Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an immutable version of CompGraph #104

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft

Conversation

DrChainsaw
Copy link
Owner

@DrChainsaw DrChainsaw commented Jul 31, 2023

Basically an alternative to #102 which is a bit less attractive since it requires creating a completely new structure.

Current version does not have tests yet but is tested with a couple of random networks from NaiveGAflux. Unfortunately (or perhaps fortunately) the performance with Zygote seems to be way worse than #102.

This is a bit surprising however since the generated function completely unrolls the graph, so maybe it is worth revisiting. For example:

## gg is a CompGraph generated by NaiveGAflux
gs = ImmutableCompGraph(gg)
compgraphexpr(typeof(gs), :g) = quote
    v15 = g.outputs
    (v14,) = v15.inputs
    (v13,) = v14.inputs
    (v12,) = v13.inputs
    (v11,) = v12.inputs
    (v7, v10) = v11.inputs
    (v6,) = v7.inputs
    (v5,) = v6.inputs
    (v4,) = v5.inputs
    (v0, v3) = v4.inputs
    (v2,) = v3.inputs
    (v1,) = v2.inputs
    v1_out = v1(v0_in)
    v2_out = v2(v1_out)
    v3_out = v3(v2_out)
    v4_out = v4(v0_in, v3_out)
    v5_out = v5(v4_out)
    v6_out = v6(v5_out)
    v7_out = v7(v6_out)
    (v9,) = v10.inputs
    (v8,) = v9.inputs
    v8_out = v8(v5_out)
    v9_out = v9(v8_out)
    v10_out = v10(v9_out)
    v11_out = v11(v7_out, v10_out)
    v12_out = v12(v11_out)
    v13_out = v13(v12_out)
    v14_out = v14(v13_out)
    v15_out = v15(v14_out)
end

Main "trick" to make it possible to do the above in type domain is to add an isbits identifier in the vertex type. Current version uses an Int which is just a running number of how many ImmutableCompVertex has been created (per call to ImmutableCompGraph).

Other options tried but which didn't give sufficient performance was to index the inputs instead of destructuring, e.g v7 = v11.inputs[1] and v10 = v11.inputs[2] instead of (v7, v10) = v11.inputs.

Also tried:

  • RuntimeGeneratedFunctions
  • Storing a tuple of vertices in topological order and use that to extract the v0, v1 etc variables. A bit worse performance iirc.
  • Adding a function f(g, v0_in) expression on top of the above and just evaling it on the top level. The fact that this doesn't improve performance is probably an indication that a different type of expression is needed. Perhaps Zygote doesn't like to compile long functions?
  • All of the above except the expression is changed to work on CompGraph instead of ImmutableCompGraph. This was significantly worse for all versions.

@codecov
Copy link

codecov bot commented Jul 31, 2023

Codecov Report

Attention: Patch coverage is 0% with 64 lines in your changes missing coverage. Please review.

Project coverage is 89.58%. Comparing base (df65628) to head (0a2b072).
Report is 46 commits behind head on master.

Files with missing lines Patch % Lines
src/staticgraph.jl 0.00% 64 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master     #104      +/-   ##
==========================================
- Coverage   94.16%   89.58%   -4.59%     
==========================================
  Files          14       15       +1     
  Lines        1251     1315      +64     
==========================================
  Hits         1178     1178              
- Misses         73      137      +64     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant