Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mutating arithmetic for SRows #1659

Merged
merged 11 commits into from
Nov 13, 2024
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ GAPExt = "GAP"
PolymakeExt = "Polymake"

[compat]
AbstractAlgebra = "^0.43.1"
AbstractAlgebra = "^0.43.10"
Dates = "1.6"
Distributed = "1.6"
GAP = "0.9.6, 0.10, 0.11, 0.12"
Expand Down
184 changes: 165 additions & 19 deletions src/Sparse/Row.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,14 @@
return A
end

function Base.empty(A::SRow)
return sparse_row(base_ring(A))
end

function zero(A::SRow)
return empty(A)
end

function swap!(A::SRow, B::SRow)
A.pos, B.pos = B.pos, A.pos
A.values, B.values = B.values, A.values
Expand Down Expand Up @@ -447,15 +455,17 @@
# Inplace scaling
#
################################################################################

@doc raw"""
scale_row!(a::SRow, b::NCRingElem) -> SRow

Returns the (left) product of $b \times a$ and reassigns the value of $a$ to this product.
For rows, the standard multiplication is from the left.
"""
function scale_row!(a::SRow{T}, b::T) where T
@assert !iszero(b)
if isone(b)
if iszero(b)
return empty!(a)
elseif isone(b)
return a
end
i = 1
Expand All @@ -465,20 +475,23 @@
deleteat!(a.values, i)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Orthogonal to this PR, but: Since you now handle b==0 at the start, the iszero(a.values[i]) check above the fold can only return true if the coefficient ring is not a domain. So it could be strengthened to something like !is_domain_type(T) && iszero(a.values[i]).

Since is_domain_type is a trait depending only on T, the compiler can eliminate the is_domain_type(T) check -- if it returns true it can elide the if block, and if it is false we the same code we have currently.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But this can easily wait for a follow-up PR. (Also scale_row! and scale_row_right! could be merged)

deleteat!(a.pos, i)
else
i += 1
i += 1
end
end
return a
end

scale_row!(a::SRow, b) = scale_row!(a, base_ring(a)(b))

@doc raw"""
scale_row_right!(a::SRow, b::NCRingElem) -> SRow

Returns the (right) product of $a \times b$ and modifies $a$ to this product.
"""
function scale_row_right!(a::SRow{T}, b::T) where T
@assert !iszero(b)
if isone(b)
if iszero(b)
return empty!(a)
elseif isone(b)
return a
end
i = 1
Expand All @@ -488,16 +501,20 @@
deleteat!(a.values, i)
deleteat!(a.pos, i)
else
i += 1
i += 1
end
end
return a
end

scale_row_right!(a::SRow, b) = scale_row_right!(a, base_ring(a)(b))

function scale_row_left!(a::SRow{T}, b::T) where T
return scale_row!(a,b)
end

scale_row_left!(a::SRow, b) = scale_row_left!(a, base_ring(a)(b))

################################################################################
#
# Addition
Expand All @@ -506,22 +523,22 @@

function +(A::SRow{T}, B::SRow{T}) where T
if length(A.values) == 0
return B
return deepcopy(B)
elseif length(B.values) == 0
return A
return deepcopy(A)
end
return add_scaled_row(A, B, one(base_ring(A)))
end

function -(A::SRow{T}, B::SRow{T}) where T
if length(A) == 0
if length(B) == 0
return A
return deepcopy(A)
else
return add_scaled_row(B, A, base_ring(B)(-1))
return add_scaled_row(B, A, -1)
end
end
return add_scaled_row(B, A, base_ring(A)(-1))
return add_scaled_row(B, A, -1)
end

function -(A::SRow{T}) where {T}
Expand Down Expand Up @@ -683,10 +700,10 @@

Returns the row $c A + B$.
"""
add_scaled_row(a::SRow{T}, b::SRow{T}, c::T) where {T} = add_scaled_row!(a, deepcopy(b), c)
add_scaled_row(a::SRow{T}, b::SRow{T}, c) where {T} = add_scaled_row!(a, deepcopy(b), c)

add_left_scaled_row(a::SRow{T}, b::SRow{T}, c::T) where {T} = add_left_scaled_row!(a, deepcopy(b), c)
add_right_scaled_row(a::SRow{T}, b::SRow{T}, c::T) where {T} = add_right_scaled_row!(a, deepcopy(b), c)
add_left_scaled_row(a::SRow{T}, b::SRow{T}, c) where {T} = add_left_scaled_row!(a, deepcopy(b), c)
add_right_scaled_row(a::SRow{T}, b::SRow{T}, c) where {T} = add_right_scaled_row!(a, deepcopy(b), c)

Check warning on line 706 in src/Sparse/Row.jl

View check run for this annotation

Codecov / codecov/patch

src/Sparse/Row.jl#L705-L706

Added lines #L705 - L706 were not covered by tests



Expand All @@ -696,7 +713,9 @@
Adds the left scaled row $c A$ to $B$.
"""
function add_scaled_row!(a::SRow{T}, b::SRow{T}, c::T, ::Val{left_side} = Val(true)) where {T, left_side}
@assert a !== b
if a === b
a = deepcopy(a)
end
i = 1
j = 1
t = base_ring(a)()
Expand Down Expand Up @@ -735,17 +754,144 @@
return b
end

add_scaled_row!(a::SRow{T}, b::SRow{T}, c) where {T} = add_scaled_row!(a, b, base_ring(a)(c))

add_scaled_row!(a::SRow{T}, b::SRow{T}, c, side::Val) where {T} = add_scaled_row!(a, b, base_ring(a)(c), side)

# ignore tmp argument
add_scaled_row!(a::SRow{T}, b::SRow{T}, c::T, tmp::SRow{T}) where T = add_scaled_row!(a, b, c)
add_scaled_row!(a::SRow{T}, b::SRow{T}, c, tmp::SRow{T}) where T = add_scaled_row!(a, b, c)

Check warning on line 762 in src/Sparse/Row.jl

View check run for this annotation

Codecov / codecov/patch

src/Sparse/Row.jl#L762

Added line #L762 was not covered by tests

add_left_scaled_row!(a::SRow{T}, b::SRow{T}, c::T) where T = add_scaled_row!(a, b, c)
add_left_scaled_row!(a::SRow{T}, b::SRow{T}, c) where T = add_scaled_row!(a, b, c)

@doc raw"""
add_right_scaled_row!(A::SRow{T}, B::SRow{T}, c::T) -> SRow{T}

Return the right scaled row $c A$ to $B$ by changing $B$ in place.
Return the right scaled row $A c$ to $B$ by changing $B$ in place.
"""
add_right_scaled_row!(a::SRow{T}, b::SRow{T}, c::T) where T = add_scaled_row!(a, b, c, Val(false))
add_right_scaled_row!(a::SRow{T}, b::SRow{T}, c) where T = add_scaled_row!(a, b, c, Val(false))


################################################################################
#
# Mutating arithmetics
#
################################################################################

function zero!(z::SRow)
return empty!(z)
end

function neg!(z::SRow{T}, x::SRow{T}) where T
if z === x
return neg!(x)

Check warning on line 786 in src/Sparse/Row.jl

View check run for this annotation

Codecov / codecov/patch

src/Sparse/Row.jl#L786

Added line #L786 was not covered by tests
end
swap!(z, -x)
return z
end

function neg!(z::SRow)
for i in 1:length(z)
z.values[i] = neg!(z.values[i])
end
return z
end

function add!(z::SRow{T}, x::SRow{T}, y::SRow{T}) where T
if z === x
return add!(x, y)
elseif z === y
return add!(y, x)
end
swap!(z, x + y)
return z
end

function add!(z::SRow{T}, x::SRow{T}) where T
if z === x
return scale_row!(z, 2)
end
return add_scaled_row!(x, z, one(base_ring(x)))
end

function sub!(z::SRow{T}, x::SRow{T}, y::SRow{T}) where T
if z === x
return sub!(x, y)
elseif z === y
return neg!(sub!(y, x))
end
swap!(z, x - y)
return z
end

function sub!(z::SRow{T}, x::SRow{T}) where T
if z === x
return empty!(z)
end
return add_scaled_row!(x, z, -1)
end

function mul!(z::SRow{T}, x::SRow{T}, y::SRow{T}) where T
error("Not implemented")

Check warning on line 834 in src/Sparse/Row.jl

View check run for this annotation

Codecov / codecov/patch

src/Sparse/Row.jl#L833-L834

Added lines #L833 - L834 were not covered by tests
end

function mul!(z::SRow{T}, x::SRow{T}, c) where T
if z === x
return scale_row_right!(x, c)
end
swap!(z, x * c)
return z
end

function mul!(z::SRow{T}, c, y::SRow{T}) where T
if z === y
return scale_row_left!(y, c)
end
swap!(z, c * y)
return z
end

function addmul!(z::SRow{T}, x::SRow{T}, y::SRow{T}) where T
error("Not implemented")

Check warning on line 854 in src/Sparse/Row.jl

View check run for this annotation

Codecov / codecov/patch

src/Sparse/Row.jl#L853-L854

Added lines #L853 - L854 were not covered by tests
end

function addmul!(z::SRow{T}, x::SRow{T}, y) where T
if z === x
return scale_row_right!(x, y+1)
end
return add_right_scaled_row!(x, z, y)
end

function addmul!(z::SRow{T}, x, y::SRow{T}) where T
if z === x
return scale_row_left!(y, x+1)

Check warning on line 866 in src/Sparse/Row.jl

View check run for this annotation

Codecov / codecov/patch

src/Sparse/Row.jl#L866

Added line #L866 was not covered by tests
end
return add_left_scaled_row!(y, z, x)
end

function submul!(z::SRow{T}, x::SRow{T}, y::SRow{T}) where T
error("Not implemented")

Check warning on line 872 in src/Sparse/Row.jl

View check run for this annotation

Codecov / codecov/patch

src/Sparse/Row.jl#L871-L872

Added lines #L871 - L872 were not covered by tests
end

function submul!(z::SRow{T}, x::SRow{T}, y) where T
if z === x
return scale_row_right!(x, -y+1)
end
return add_right_scaled_row!(x, z, -y)
end

function submul!(z::SRow{T}, x, y::SRow{T}) where T
if z === x
return scale_row_left!(y, -x+1)

Check warning on line 884 in src/Sparse/Row.jl

View check run for this annotation

Codecov / codecov/patch

src/Sparse/Row.jl#L884

Added line #L884 was not covered by tests
end
return add_left_scaled_row!(y, z, -x)
end


# ignore temp variable
addmul!(z::SRow{T}, x::SRow{T}, y, t) where T = addmul!(z, x, y)
addmul!(z::SRow{T}, x, y::SRow{T}, t) where T = addmul!(z, x, y)
submul!(z::SRow{T}, x::SRow{T}, y, t) where T = submul!(z, x, y)
submul!(z::SRow{T}, x, y::SRow{T}, t) where T = submul!(z, x, y)


################################################################################
Expand Down
4 changes: 3 additions & 1 deletion src/Sparse/ZZRow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,6 @@ end

function add_scaled_row(Ai::SRow{ZZRingElem}, Aj::SRow{ZZRingElem}, c::ZZRingElem, sr::SRow{ZZRingElem} = sparse_row(ZZ))
empty!(sr)
@assert c != 0
n = ZZRingElem()
pi = 1
pj = 1
Expand Down Expand Up @@ -323,6 +322,9 @@ function add_scaled_row(Ai::SRow{ZZRingElem}, Aj::SRow{ZZRingElem}, c::ZZRingEle
end

function add_scaled_row!(Ai::SRow{ZZRingElem}, Aj::SRow{ZZRingElem}, c::ZZRingElem, sr::SRow{ZZRingElem} = sparse_row(ZZ))
if iszero(c)
return Aj
end
_t = sr
sr = add_scaled_row(Ai, Aj, c, sr)
@assert _t === sr
Expand Down
36 changes: 36 additions & 0 deletions test/Sparse/Row.jl
Original file line number Diff line number Diff line change
Expand Up @@ -204,4 +204,40 @@
B = sparse_row(F,[1],[y])
C = add_scaled_row(A,B,F(1))
@test C == A+B

# mutating arithmetic
randcoeff() = begin
n = rand((1,1,1,2,5,7,15))
return rand(-2^n:2^n)
end
Main.equality(A::SRow, B::SRow) = A == B
@testset "mutating arithmetic; R = $R" for R in (ZZ, QQ)
for _ in 1:10
maxind_A = rand(0:10)
inds_A = Hecke.Random.randsubseq(1:maxind_A, rand())
vals_A = elem_type(R)[R(rand((-1, 1)) * rand(1:10)) for _ in 1:length(inds_A)]
A = sparse_row(R, inds_A, vals_A)

maxind_B = rand(0:10)
inds_B = Hecke.Random.randsubseq(1:maxind_B, rand())
vals_B = elem_type(R)[R(rand((-1, 1)) * rand(1:10)) for _ in 1:length(inds_B)]
B = sparse_row(R, inds_B, vals_B)

test_mutating_op_like_zero(zero, zero!, A)

test_mutating_op_like_neg(-, neg!, A)

test_mutating_op_like_add(+, add!, A, B)
test_mutating_op_like_add(-, sub!, A, B)
test_mutating_op_like_add(*, mul!, A, randcoeff(), SRow)
test_mutating_op_like_add(*, mul!, randcoeff(), A, SRow)
test_mutating_op_like_add(*, mul!, A, ZZ(randcoeff()), SRow)
test_mutating_op_like_add(*, mul!, ZZ(randcoeff()), A, SRow)

test_mutating_op_like_addmul((a, b, c) -> a + b*c, addmul!, A, B, randcoeff(), SRow)
test_mutating_op_like_addmul((a, b, c) -> a + b*c, addmul!, A, randcoeff(), B, SRow)
test_mutating_op_like_addmul((a, b, c) -> a - b*c, submul!, A, B, randcoeff(), SRow)
test_mutating_op_like_addmul((a, b, c) -> a - b*c, submul!, A, randcoeff(), B, SRow)
end
end
end
Loading