Skip to content

Commit

Permalink
Rename discount op to affine.
Browse files Browse the repository at this point in the history
  • Loading branch information
cryscan committed Jan 15, 2025
1 parent 25b2b70 commit ee005da
Show file tree
Hide file tree
Showing 6 changed files with 13 additions and 13 deletions.
2 changes: 1 addition & 1 deletion src/runtime/v4.rs
Original file line number Diff line number Diff line change
Expand Up @@ -840,7 +840,7 @@ fn dispatch_layer<F: Float>(
]);

if (index + 1) % rescale == 0 {
ops.push(TensorOp::discount(&buffer.x, 0.5, 0.0)?);
ops.push(TensorOp::affine(&buffer.x, 0.5, 0.0)?);
}

Ok(TensorOp::List(ops))
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/v5.rs
Original file line number Diff line number Diff line change
Expand Up @@ -933,7 +933,7 @@ fn dispatch_layer<F: Float>(
]);

if (index + 1) % rescale == 0 {
ops.push(TensorOp::discount(&buffer.x, 0.5, 0.0)?);
ops.push(TensorOp::affine(&buffer.x, 0.5, 0.0)?);
}

Ok(TensorOp::List(ops))
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/v6.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1006,7 +1006,7 @@ fn dispatch_layer<F: Float>(
]);

if (index + 1) % rescale == 0 {
ops.push(TensorOp::discount(&buffer.x, 0.5, 0.0)?);
ops.push(TensorOp::affine(&buffer.x, 0.5, 0.0)?);
}

Ok(TensorOp::List(ops))
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/v7.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1055,7 +1055,7 @@ fn dispatch_layer<F: Float>(
]);

if (index + 1) % rescale == 0 {
ops.push(TensorOp::discount(&buffer.x, 0.5, 0.0)?);
ops.push(TensorOp::affine(&buffer.x, 0.5, 0.0)?);
}

Ok(TensorOp::List(ops))
Expand Down
6 changes: 3 additions & 3 deletions src/shaders/discount.wgsl → src/shaders/affine.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ fn unpack4x16float(x: vec2<u32>) -> vec4<f32> {
}

@compute @workgroup_size(BLOCK_SIZE, 1, 1)
fn discount(@builtin(global_invocation_id) invocation_id: vec3<u32>) {
fn affine(@builtin(global_invocation_id) invocation_id: vec3<u32>) {
let stride = shape[0] / 4u;
let index = invocation_id.x;
let token = invocation_id.y;
Expand All @@ -23,9 +23,9 @@ fn discount(@builtin(global_invocation_id) invocation_id: vec3<u32>) {
if index < stride {
let bti = (batch * shape[1] + token) * stride + index;
#ifdef FP16
x[bti] = pack4x16float(FACTOR * unpack4x16float(x[bti]) + BIAS);
x[bti] = pack4x16float(SCALE * unpack4x16float(x[bti]) + BIAS);
#else
x[bti] = FACTOR * x[bti] + BIAS;
x[bti] = SCALE * x[bti] + BIAS;
#endif
}
}
12 changes: 6 additions & 6 deletions src/tensor/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2451,9 +2451,9 @@ impl TensorOp {
})
}

pub fn discount(
pub fn affine(
x: &TensorGpu<impl Float, ReadWrite>,
factor: f32,
scale: f32,
bias: f32,
) -> Result<Self, TensorError> {
const BLOCK_SIZE: u32 = 128;
Expand All @@ -2462,17 +2462,17 @@ impl TensorOp {
let shape = x.shape();

let key = PipelineKey::new(
"discount",
"discount",
"affine",
"affine",
Macros::new()
.u32("BLOCK_SIZE", BLOCK_SIZE)
.tensor(x, None)
.f32("FACTOR", factor)
.f32("SCALE", scale)
.f32("BIAS", bias),
);
let pipeline = context.checkout_pipeline(
&key,
include_str!("../shaders/discount.wgsl"),
include_str!("../shaders/affine.wgsl"),
&[x.meta_layout(0), x.layout(1, false)],
);

Expand Down

0 comments on commit ee005da

Please sign in to comment.