diff --git a/impls/buffer.cpp b/impls/buffer.cpp index c5cf50c..3bdd404 100644 --- a/impls/buffer.cpp +++ b/impls/buffer.cpp @@ -245,4 +245,14 @@ void ShaderStorageBuffer::Allocate(Ulong const &size, mSize = size; this->Unbind(); } + +void ShaderStorageBuffer::Bind(Shader &shader, Str const &name) const { + shader.BindStorageBlock(name, (Uint)mSSBO); + glBindBufferBase(GL_SHADER_STORAGE_BUFFER, (Uint)mSSBO, mSSBO); +} + +void ShaderStorageBuffer::Bind(ComputeKernel &shader, Str const &name) const { + shader.BindStorageBlock(name, (Uint)mSSBO); + glBindBufferBase(GL_SHADER_STORAGE_BUFFER, (Uint)mSSBO, mSSBO); +} } // namespace TerreateGraphics::Core diff --git a/impls/compute.cpp b/impls/compute.cpp index c50d385..27d4311 100644 --- a/impls/compute.cpp +++ b/impls/compute.cpp @@ -151,25 +151,6 @@ void ComputeKernel::SetMat4(Str const &name, mat4 const &value) const { glUseProgram(0); } -void ComputeKernel::AddStorage(ShaderStorageBuffer const &ssbo, - Str const &name) { - if (!mLinked) { - throw Exceptions::ShaderError("Kernel is not linked"); - return; - } - - Uint index = this->GetStorageBlockIndex(name); - - if (mSSBOBindingMap.find(index) == mSSBOBindingMap.end()) { - Uint newID = mSSBOBindingMap.size(); - glShaderStorageBlockBinding(mKernelID, index, newID); - mSSBOBindingMap.insert({index, newID}); - mSSBOMap.insert({index, ssbo}); - } - - ssbo.BindBase(mSSBOBindingMap.at(index)); -} - void ComputeKernel::Compile() { if (mKernelSource == "") { throw Exceptions::ShaderError("Compute kernel source is empty"); diff --git a/includes/buffer.hpp b/includes/buffer.hpp index fd8f982..71118bc 100644 --- a/includes/buffer.hpp +++ b/includes/buffer.hpp @@ -1,6 +1,7 @@ #ifndef __TERREATE_GRAPHICS_BUFFER_HPP__ #define __TERREATE_GRAPHICS_BUFFER_HPP__ +#include "compute.hpp" #include "defines.hpp" #include "exceptions.hpp" #include "globj.hpp" @@ -8,6 +9,7 @@ namespace TerreateGraphics::Core { using namespace TerreateGraphics::Defines; +using namespace TerreateGraphics::Compute; using namespace TerreateGraphics::GL; struct AttributeData { @@ -356,9 +358,8 @@ class ShaderStorageBuffer : public TerreateObjectBase { void Allocate(Ulong const &size, BufferUsage const &usage = BufferUsage::STATIC_DRAW); - void BindBase(Uint const &index) const { - glBindBufferBase(GL_SHADER_STORAGE_BUFFER, index, mSSBO); - } + void Bind(Shader &shader, Str const &name) const; + void Bind(ComputeKernel &kernel, Str const &name) const; }; } // namespace TerreateGraphics::Core diff --git a/includes/compute.hpp b/includes/compute.hpp index 4a8cf48..493e62f 100644 --- a/includes/compute.hpp +++ b/includes/compute.hpp @@ -1,13 +1,13 @@ #ifndef __TERREATE_GRAPHICS_KERNEL_HPP__ #define __TERREATE_GRAPHICS_KERNEL_HPP__ -#include "buffer.hpp" +// #include "buffer.hpp" #include "defines.hpp" #include "globj.hpp" namespace TerreateGraphics::Compute { using namespace TerreateGraphics::Defines; -using namespace TerreateGraphics::Core; +// using namespace TerreateGraphics::Core; using namespace TerreateGraphics::GL; using namespace TerreateCore::Math; @@ -17,8 +17,6 @@ class ComputeKernel final : public TerreateObjectBase { Bool mLinked = false; GLObject mKernelID = GLObject(); Str mKernelSource = ""; - Map mSSBOMap; - Map mSSBOBindingMap; public: /* @@ -143,18 +141,22 @@ class ComputeKernel final : public TerreateObjectBase { */ void SetMat4(Str const &name, mat4 const &value) const; - /* - * @brief: Add shader storage block. - * @param: name: name of storage block - * @param: binding: binding point of storage block - */ - void AddStorage(ShaderStorageBuffer const &ssbo, Str const &name); /* * @brief: Add kernel source. * @param: source: source code to add */ void AddKernelSource(Str const &source) { mKernelSource += source; } + /* + * @brief: This function binds storage block index to binding point. + * @param: name: name of storage block + * @param: binding: binding point + */ + void BindStorageBlock(Str const &name, Uint const &binding) const { + glShaderStorageBlockBinding(mKernelID, this->GetStorageBlockIndex(name), + binding); + } + /* * @brief: Compile shader. */ diff --git a/includes/shader.hpp b/includes/shader.hpp index 9e5f66d..0e06c62 100644 --- a/includes/shader.hpp +++ b/includes/shader.hpp @@ -70,6 +70,15 @@ class Shader final : public TerreateObjectBase { unsigned GetUniformBlockIndex(Str const &name) const { return glGetUniformBlockIndex(mShaderID, name.c_str()); } + /* + * @brief: Getter for shader storage block index. + * @param: name: name of storage block + * @return: storage block index + */ + unsigned GetStorageBlockIndex(Str const &name) const { + return glGetProgramResourceIndex(mShaderID, GL_SHADER_STORAGE_BLOCK, + name.c_str()); + } /* * @brief: Setter for shader Bool uniform. @@ -249,6 +258,15 @@ class Shader final : public TerreateObjectBase { void BindUniformBlock(Str const &name, Uint const &binding) const { glUniformBlockBinding(mShaderID, this->GetUniformBlockIndex(name), binding); } + /* + * @brief: This function binds storage block index to binding point. + * @param: name: name of storage block + * @param: binding: binding point + */ + void BindStorageBlock(Str const &name, Uint const &binding) const { + glShaderStorageBlockBinding(mShaderID, this->GetStorageBlockIndex(name), + binding); + } /* * @brief: This function swiches blending on or off. * @param: value: true to turn on, false to turn off diff --git a/tests/TGTest.cpp b/tests/TGTest.cpp index e103370..56b1045 100644 --- a/tests/TGTest.cpp +++ b/tests/TGTest.cpp @@ -317,8 +317,8 @@ void TestCompute() { kernel.Compile(); kernel.Link(); - kernel.AddStorage(input, "InputBuffer"); - kernel.AddStorage(output, "OutputBuffer"); + input.Bind(kernel, "InputBuffer"); + output.Bind(kernel, "OutputBuffer"); kernel.SetFloat("scaleFactor", 2.0f); kernel.Dispatch(10, 1, 1); @@ -329,8 +329,8 @@ void TestCompute() { kernel2.Compile(); kernel2.Link(); - kernel2.AddStorage(input2, "InputBuffer"); - kernel2.AddStorage(output2, "OutputBuffer"); + input2.Bind(kernel2, "InputBuffer"); + output2.Bind(kernel2, "OutputBuffer"); kernel2.SetFloat("scaleFactor", 3.0f); kernel2.Dispatch(10, 1, 1); @@ -364,6 +364,8 @@ int main() { app.CharCallback(window, codepoint); }); + // TestCompute(); + while (window) { window.Frame([&app](Window *window) { app.OnFrame(window); }); }