diff --git a/src/circuit/KetShaderUtil.js b/src/circuit/KetShaderUtil.js index 882e2d60..0bf27955 100644 --- a/src/circuit/KetShaderUtil.js +++ b/src/circuit/KetShaderUtil.js @@ -86,20 +86,31 @@ const ketShaderPermute = (head, body, span=null) => ketShader( span); /** - * @param {!String} head - * @param {!String} body - * @param {null|!int=null} span + * Returns a shader that multiplies each of the amplitudes in a superposition by computed phase factors. + * + * @param {!String} head Header code defining shader methods, uniforms, etc. + * @param {!String} body The body of a shader method returning the number of radians to phase by. + * @param {null|!int=null} span The number of qubits this operation applies to, if known ahead of time. * @return {!{withArgs: !function(args: ...!WglArg) : !WglConfiguredShader}} */ const ketShaderPhase = (head, body, span=null) => ketShader( - head + `vec2 _ketgen_phase_for(float out_id) { ${body} }`, - 'return cmul(amp, _ketgen_phase_for(out_id));', + `${head} + float _ketgen_phase_for(float out_id) { + ${body} + } + `, + ` + float angle = _ketgen_phase_for(out_id); + return cmul(amp, vec2(cos(angle), sin(angle))); + `, span); /** - * @param {!CircuitEvalContext} ctx - * @param {undefined|!int=undefined} span - * @param {undefined|!Array.} input_letters + * Determines some arguments to give to a shader produced by one of the ketShader methods. + * + * @param {!CircuitEvalContext} ctx The context in which the ket shader is being applied. + * @param {undefined|!int=undefined} span The number of qubits this shader applies to (if wasn't known ahead of time). + * @param {undefined|!Array.} input_letters The input gates that this shader cares about. * @returns {!Array.} */ function ketArgs(ctx, span=undefined, input_letters=[]) { diff --git a/src/gates/FourierTransformGates.js b/src/gates/FourierTransformGates.js index 3f3182b3..c7d36088 100644 --- a/src/gates/FourierTransformGates.js +++ b/src/gates/FourierTransformGates.js @@ -35,8 +35,7 @@ const CONTROLLED_PHASE_GRADIENT_SHADER = ketShaderPhase( ` float hold = floor(out_id * 2.0 / span); float step = mod(out_id, span / 2.0); - float angle = hold * step * factor * 6.2831853071795864769 / span; - return vec2(cos(angle), sin(angle)); + return hold * step * factor * 6.2831853071795864769 / span; `); const FOURIER_TRANSFORM_MATRIX_MAKER = span => diff --git a/src/gates/ParametrizedRotationGates.js b/src/gates/ParametrizedRotationGates.js index 4f4b2e1d..fabfa6b3 100644 --- a/src/gates/ParametrizedRotationGates.js +++ b/src/gates/ParametrizedRotationGates.js @@ -71,8 +71,7 @@ const Z_TO_A_SHADER = ketShaderPhase( ${ketInputGateShaderCode('A')} `, ` - float angle = read_input_A() * out_id * factor / _gen_input_span_A; - return vec2(cos(angle), sin(angle)); + return read_input_A() * out_id * factor / _gen_input_span_A; `); ParametrizedRotationGates.XToA = new GateBuilder(). diff --git a/src/gates/PhaseGradientGates.js b/src/gates/PhaseGradientGates.js index 6c1a674e..80abfc1f 100644 --- a/src/gates/PhaseGradientGates.js +++ b/src/gates/PhaseGradientGates.js @@ -37,8 +37,7 @@ const PHASE_GRADIENT_SHADER = ketShaderPhase( } `, ` - float angle = angle_mul(factor, out_id); - return vec2(cos(angle), sin(angle)); + return angle_mul(factor, out_id); `); let PhaseGradientGates = {}; diff --git a/test/circuit/KetShaderUtil.test.js b/test/circuit/KetShaderUtil.test.js index 25dda0ba..5f222ee4 100644 --- a/test/circuit/KetShaderUtil.test.js +++ b/test/circuit/KetShaderUtil.test.js @@ -49,7 +49,7 @@ suite.testUsingWebGL("ketShaderPermute", () => { suite.testUsingWebGL("ketShaderPhase", () => { let shader = ketShaderPhase( '', - 'return vec2(cos(out_id/10.0), sin(out_id/10.0));', + 'return out_id/10.0;', 3); assertThatCircuitShaderActsLikeMatrix( ctx => shader.withArgs(...ketArgs(ctx)),