Skip to content

Commit

Permalink
Start updating
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Apr 26, 2024
1 parent 294e540 commit f6cbd13
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 22 deletions.
11 changes: 5 additions & 6 deletions src/abstractitensornetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ using ITensors:
using ITensors.ITensorMPS: ITensorMPS, add, linkdim, linkinds, siteinds
using .ITensorsExtensions: ITensorsExtensions, indtype, promote_indtype
using LinearAlgebra: LinearAlgebra, factorize
using NamedGraphs:
NamedGraphs, NamedGraph, not_implemented, ordinal_vertex_to_vertex, vertex_to_ordinal_vertex
using NamedGraphs: NamedGraphs, NamedGraph, not_implemented
using NamedGraphs.GraphsExtensions:
, directed_graph, incident_edges, rename_vertices, vertextype
using NDTensors: NDTensors, dim
Expand Down Expand Up @@ -94,11 +93,11 @@ function DataGraphs.edge_data(graph::AbstractITensorNetwork, args...)
end

DataGraphs.underlying_graph(tn::AbstractITensorNetwork) = underlying_graph(data_graph(tn))
function NamedGraphs.vertex_to_ordinal_vertex(tn::AbstractITensorNetwork, vertex)
return vertex_to_ordinal_vertex(underlying_graph(tn), vertex)
function NamedGraphs.vertex_positions(tn::AbstractITensorNetwork)
return NamedGraphs.vertex_positions(underlying_graph(tn))
end
function NamedGraphs.ordinal_vertex_to_vertex(tn::AbstractITensorNetwork, ordinal_vertex)
return ordinal_vertex_to_vertex(underlying_graph(tn), ordinal_vertex)
function NamedGraphs.ordered_vertices(tn::AbstractITensorNetwork)
return NamedGraphs.ordered_vertices(underlying_graph(tn))
end

#
Expand Down
9 changes: 4 additions & 5 deletions src/contract.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
using NamedGraphs: vertex_to_ordinal_vertex
using ITensors: ITensor, scalar
using ITensors.ContractionSequenceOptimization: deepmap
using ITensors.NDTensors: NDTensors, Algorithm, @Algorithm_str, contract
using LinearAlgebra: normalize!
using NamedGraphs: NamedGraphs
using NamedGraphs.OrdinalIndexing: th

function NDTensors.contract(tn::AbstractITensorNetwork; alg="exact", kwargs...)
return contract(Algorithm(alg), tn; kwargs...)
Expand All @@ -15,10 +16,8 @@ function NDTensors.contract(
sequence=contraction_sequence(tn; contraction_sequence_kwargs...),
kwargs...,
)
# TODO: Use `vertex`.
sequence_linear_index = deepmap(v -> vertex_to_ordinal_vertex(tn, v), sequence)
# TODO: Use `tokenized_vertex`.
ts = map(pv -> tn[ordinal_vertex_to_vertex(tn, pv)], 1:nv(tn))
sequence_linear_index = deepmap(v -> NamedGraphs.vertex_positions(tn)[v], sequence)
ts = map(v -> tn[v], (1:nv(tn))th)
return contract(ts; sequence=sequence_linear_index, kwargs...)
end

Expand Down
6 changes: 3 additions & 3 deletions src/contraction_sequences.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,19 @@ using Graphs: vertices
using ITensors: ITensor, contract
using ITensors.ContractionSequenceOptimization: deepmap, optimal_contraction_sequence
using ITensors.NDTensors: Algorithm, @Algorithm_str
using NamedGraphs: ordinal_vertex_to_vertex
using NamedGraphs.Keys: Key
using NamedGraphs.OrdinalIndexing: th

function contraction_sequence(tn::Vector{ITensor}; alg="optimal", kwargs...)
return contraction_sequence(Algorithm(alg), tn; kwargs...)
end

function contraction_sequence(tn::AbstractITensorNetwork; kwargs...)
# TODO: Use `token_vertex` and/or `token_vertices` here.
ts = map(pv -> tn[ordinal_vertex_to_vertex(tn, pv)], 1:nv(tn))
ts = map(v -> tn[v], (1:nv(tn))th)
seq_linear_index = contraction_sequence(ts; kwargs...)
# TODO: Use `Functors.fmap` or `StructWalk`?
return deepmap(n -> Key(ordinal_vertex_to_vertex(tn, n)), seq_linear_index)
return deepmap(n -> Key(vertices(tn)[n * th]), seq_linear_index)
end

function contraction_sequence(::Algorithm"optimal", tn::Vector{ITensor})
Expand Down
2 changes: 1 addition & 1 deletion src/treetensornetworks/treetensornetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ struct TreeTensorNetwork{V} <: AbstractTreeTensorNetwork{V}
end
end

function _TreeTensorNetwork(tensornetwork::ITensorNetwork, ortho_region::Vector)
function _TreeTensorNetwork(tensornetwork::ITensorNetwork, ortho_region)
return _TreeTensorNetwork(tensornetwork, Indices(ortho_region))
end

Expand Down
19 changes: 12 additions & 7 deletions src/visualize.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
# TODO: Move to `ITensorNetworksITensors.ITensorVisualizationCoreExt`.
using DataGraphs: AbstractDataGraph, underlying_graph
# TODO: Move to `NamedGraphsITensorVisualizationCoreExt`.
using Graphs: vertices
using ITensors.ITensorVisualizationCore: ITensorVisualizationCore, visualize
using NamedGraphs: AbstractNamedGraph, ordinal_graph

using NamedGraphs: NamedGraphs, AbstractNamedGraph
using ITensors.ITensorVisualizationCore: ITensorVisualizationCore
function ITensorVisualizationCore.visualize(
graph::AbstractNamedGraph,
args...;
Expand All @@ -15,9 +13,16 @@ function ITensorVisualizationCore.visualize(
vertex_labels = [vertex_labels_prefix * string(v) for v in vertices(graph)]
end
#edge_labels = [string(e) for e in edges(graph)]
return visualize(ordinal_graph(graph), args...; vertex_labels, kwargs...)
return ITensorVisualizationCore.visualize(
NamedGraphs.position_graph(graph), args...; vertex_labels, kwargs...
)
end

# TODO: Move to `DataGraphsITensorVisualizationCoreExt`.
using DataGraphs: DataGraphs, AbstractDataGraph
using ITensors.ITensorVisualizationCore: ITensorVisualizationCore
function ITensorVisualizationCore.visualize(graph::AbstractDataGraph, args...; kwargs...)
return visualize(underlying_graph(graph), args...; kwargs...)
return ITensorVisualizationCore.visualize(
DataGraphs.underlying_graph(graph), args...; kwargs...
)
end

0 comments on commit f6cbd13

Please sign in to comment.