Skip to content

Commit 49afe9e

Browse files
committed
bug fixes and improvements in tensor parser routines
1 parent 036392c commit 49afe9e

File tree

5 files changed

+15
-16
lines changed

5 files changed

+15
-16
lines changed

src/indexnotation/analyzers.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ function decomposegeneraltensor(ex)
4242
elseif isa(ex, Expr) && ex.head == :call && ex.args[1] == :conj && length(ex.args) == 2 # conjugation: flip conjugation flag and conjugate scalar factor
4343
(object, leftind, rightind, α, conj) = decomposegeneraltensor(ex.args[2])
4444
return (object, leftind, rightind, Expr(:call, :conj, α), !conj)
45-
elseif ex.head == :call && ex.args[1] == :* && length(ex.args) == 3 # scalar multiplication: muliply scalar factors
45+
elseif ex.head == :call && ex.args[1] == :* && length(ex.args) == 3 # scalar multiplication: multiply scalar factors
4646
if isscalarexpr(ex.args[2]) && isgeneraltensor(ex.args[3])
4747
(object, leftind, rightind, α, conj) = decomposegeneraltensor(ex.args[3])
4848
return (object, leftind, rightind, Expr(:call, :*, ex.args[2], α), conj)
@@ -98,7 +98,9 @@ end
9898
# get all the existing tensor objects which are inputs (i.e. appear in the rhs of assignments and definitions)
9999
function getinputtensorobjects(ex)
100100
list = Any[]
101-
if isdefinition(ex)
101+
if istensorexpr(ex)
102+
append!(list, gettensorobjects(ex))
103+
elseif isdefinition(ex)
102104
append!(list, gettensorobjects(getrhs(ex)))
103105
elseif isassignment(ex)
104106
if ex.head == :(+=) || ex.head == :(-=)

src/indexnotation/instantiators.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11

22
function instantiate_eltype(ex::Expr)
33
if istensor(ex)
4-
obj,_,_ = decomposetensor(ex)
5-
return Expr(:call, :eltype, obj)
4+
return Expr(:call, :eltype, gettensorobject(ex))
65
elseif ex.head == :call && (ex.args[1] == :+ || ex.args[1] == :- || ex.args[1] == :* || ex.args[1] == :/)
76
if length(ex.args) > 2
87
return Expr(:call, :promote_type, map(instantiate_eltype, ex.args[2:end])...)
@@ -18,7 +17,7 @@ function instantiate_eltype(ex::Expr)
1817
throw(ArgumentError("unable to determine eltype"))
1918
end
2019
end
21-
instantiate_eltype(ex) = Expr(:call,:typeof, ex)
20+
instantiate_eltype(ex) = Expr(:call, :typeof, ex)
2221

2322
function instantiate_scalar(ex::Expr)
2423
if ex.head == :call && ex.args[1] == :scalar

src/indexnotation/parser.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ mutable struct TensorParser
2020
end
2121

2222
function (parser::TensorParser)(ex::Expr)
23-
if ex isa Expr && ex.head == :function
23+
if ex.head == :function
2424
return Expr(:function, ex.args[1], parser(ex.args[2]))
2525
end
2626
for p in parser.preprocessors

src/indexnotation/preprocessors.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,9 @@ normalizeindices(ex::Expr) = replaceindices(normalizeindex, ex)
4444

4545
# replace all tensor objects by a function of that object
4646
function replacetensorobjects(f, ex::Expr)
47-
# first try to replace ex completely
47+
# first try to replace ex completely:
48+
# this needed if `ex` is a tensor object that appears outside an actual tensor expression
49+
# in a 'regular' block of code
4850
ex2 = f(ex, nothing, nothing)
4951
ex2 !== ex && return ex2
5052
if istensor(ex)
@@ -54,7 +56,7 @@ function replacetensorobjects(f, ex::Expr)
5456
return Expr(ex.head, (replacetensorobjects(f, e) for e in ex.args)...)
5557
end
5658
end
57-
replacetensorobjects(f, ex) = f(ex, nothing, nothing)
59+
replacetensorobjects(f, ex) = f(ex, nothing, nothing) # same reason as lines 48-52.
5860

5961
# expandconj: conjugate individual terms or factors instead of a whole expression
6062
function expandconj(ex::Expr)

src/indexnotation/verifiers.jl

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -76,14 +76,10 @@ function isgeneraltensor(ex::Expr)
7676
return count == 1
7777
elseif ex.head == :call && ex.args[1] == :/ && length(ex.args) == 3
7878
# scalar multiplication
79-
if isscalarexpr(ex.args[3]) && isgeneraltensor(ex.args[2])
80-
return true
81-
end
79+
return (isscalarexpr(ex.args[3]) && isgeneraltensor(ex.args[2]))
8280
elseif ex.head == :call && ex.args[1] == :\ && length(ex.args) == 3
8381
# scalar multiplication
84-
if isscalarexpr(ex.args[2]) && isgeneraltensor(ex.args[3])
85-
return true
86-
end
82+
return (isscalarexpr(ex.args[2]) && isgeneraltensor(ex.args[3]))
8783
end
8884
return false
8985
end
@@ -149,5 +145,5 @@ end
149145
# test for assignment (copy into existing tensor) or definition (create new tensor)
150146
isassignment(ex) = false
151147
isdefinition(ex) = false
152-
isassignment(ex::Expr) = ex.head == :(=) || ex.head == :(+=) || ex.head == :(-=)
153-
isdefinition(ex::Expr) = (ex.head == :(:=) || ex.head == :(≔)) && istensor(ex.args[1])
148+
isassignment(ex::Expr) = (ex.head == :(=) || ex.head == :(+=) || ex.head == :(-=))
149+
isdefinition(ex::Expr) = (ex.head == :(:=) || ex.head == :(≔))

0 commit comments

Comments
 (0)