Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
mitchphillipson committed May 9, 2024
2 parents 0aeac4b + 6bf547b commit b184f85
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 35 deletions.
56 changes: 21 additions & 35 deletions src/dense_sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,17 @@ description(P::DenseSparseArray) = ""



@inline _convert_idx(x::Symbol,GU::GamsUniverse,d::Symbol) = (x==d || xGU[d].aliases || xGU[d]) ? (x == d || xGU[d].aliases ? [[i] for iGU[d]] : [[x]]) : throw(DomainError(x, "Symbol $x is neither a set nor an element of the set $d."))
@inline _convert_idx(x::Symbol,GU::GamsUniverse,d::Symbol) = (x==d || xGU[d].aliases || xGU[d]) ? (x == d || xGU[d].aliases ? [i for iGU[d]] : [x]) : throw(DomainError(x, "Symbol $x is neither a set nor an element of the set $d."))


function _convert_idx(x::Vector,GU::GamsUniverse,d::Symbol)
@assert (all(iGU[d] for ix)) "At least one element of $x is not in set $d"
return [[i] for ix]
return x
end

function _convert_idx(x::GamsSet,GU::GamsUniverse,d::Symbol)
@assert x==GU[d] "The set\n\n$x\ndoes not match the domain set $d"
return [[i] for ix]
return [i for ix]
end


Expand All @@ -49,59 +49,45 @@ function Base.getindex(P::DenseSparseArray{T,N},idx::CartesianIndex{N}) where {T
end


"""

if v = [:a,:b,:c,:d,:e] and p = (2,2,1)
we expect output of
[[:a,:b],[:c,:d],[:e]]
"""
function partition(v,p)
ind = ((sum(p[i] for i1:n;init=1),sum(p[i] for i1:(n+1);init=1)-1) for n0:(length(p)-1))
return [[v[i] for ia:b] for (a,b)ind]
end

function dimension(x)
return 1
#Assume input has no masks
function Base.getindex(P::DenseSparseArray{T,N},idx::Vararg{Any,N}) where {T,N}
return _getindex(P,idx...)
end

function Base.getindex(P::DenseSparseArray{T,N},idx::Vararg{Any}) where {T,N}
domain_match = partition(domain(P),dimension.(idx)) #Used for masking

@assert sum(length.(domain_match)) == dimension(P) "Not enough inputs, or something. Get a better error message"

function _getindex(P::DenseSparseArray{T,N},idx::Vararg{Any,N}) where {T,N}
GU = universe(P)
idx = map((x,d)->_convert_idx(x,GU,d...),idx,domain_match)
d = domain(P)

X = Tuple.(Iterators.flatten.(Iterators.product(idx...)))
idx = map((x,d) -> _convert_idx(x,GU,d), idx, d) |>
x -> collect(Iterators.product(x...)) |>
x -> dropdims(x,dims=tuple(findall(size(x).==1)...))

data_dict = data(P)
length(X) == 1 ? get(data_dict,X[1],0) : get.(Ref(data_dict),X,0)
length(idx) == 1 ? get(data_dict,idx[1],zero(T)) : get.(Ref(data_dict),idx,zero(T))

end


################
### setindex ###
################


function Base.setindex!(P::DenseSparseArray{T,N}, value, idx::Vararg{Any,N}) where {T,N}
domain_match = partition(domain(P),dimension.(idx))
@assert sum(length.(domain_match)) == dimension(P) "Not enough inputs, or something. Get a better error message"
#domain_match = partition(domain(P),dimension.(idx))
#@assert sum(length.(domain_match)) == dimension(P) "Not enough inputs, or something. Get a better error message"

GU = universe(P)
idx = map((x,d)->_convert_idx(x,GU,d...),idx,domain_match)

X = Tuple.(Iterators.flatten.(Iterators.product(idx...)))


d = domain(P)
idx = map((x,d) -> _convert_idx(x,GU,d), idx, d)

idx = collect(Iterators.product(idx...))

if length(X) == 1
_setindex!(P,value,X[1])
if length(idx) == 1
_setindex!(P,value,idx[1])
else
_setindex!.(Ref(P), value, X)
_setindex!.(Ref(P), value, idx)
end

end
Expand Down
55 changes: 55 additions & 0 deletions src/parameter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,61 @@ function _convert_idx(x::Mask,GU::GamsUniverse,d::Vararg{Symbol})
return keys(data(x))
end

"""
if v = [:a,:b,:c,:d,:e] and p = (2,2,1)
we expect output of
[[:a,:b],[:c,:d],[:e]]
"""
function partition(v,p)
ind = ((sum(p[i] for i1:n;init=1),sum(p[i] for i1:(n+1);init=1)-1) for n0:(length(p)-1))
return [[v[i] for ia:b] for (a,b)ind]
end

function dimension(x)
return 1
end

function Base.getindex(P::Parameter{T,N},idx::Vararg{Any}) where {T,N}
if any(isa(x,Mask) for xidx)
_getindex_mask(P,idx...)
else

GamsStructure._getindex(P,idx...)
end
end

"""
Code from TupleTools.jl package
"""
flatten(x::Any) = (x,)
flatten(t::Tuple{}) = ()
flatten(t::Tuple) = (flatten(t[1])..., flatten(Base.tail(t))...)
flatten(x, r...) = (flatten(x)..., flatten(r)...)

function _getindex_mask(P::Parameter{T,N},idx...) where {T,N}

GU = GamsStructure.universe(P)
d = GamsStructure.domain(P)

domain_match = GamsStructure.partition(d,GamsStructure.dimension.(idx))

idx = map((x,d) -> GamsStructure._convert_idx(x,GU,d...), idx, domain_match) |>
x -> GamsStructure.collect(Iterators.product(x...)) |>
x -> GamsStructure.dropdims(x,dims=tuple(findall(size(x).==1)...)) |>
x -> flatten.(x)


#return idx
data_dict = GamsStructure.data(P)
length(idx) == 1 ? get(data_dict,idx[1],zero(T)) : get.(Ref(data_dict),idx,zero(T))

end




"""
@create_parameters(GU,block)
Expand Down

0 comments on commit b184f85

Please sign in to comment.