Skip to content

Commit

Permalink
Add abstraction to element definition
Browse files Browse the repository at this point in the history
Declar new abstract type AbstractElement and modify functions so that
they take AbstractElement as input argument. This change allows to
define new elements with different properties, in principle. However
most of the functions are heavily related to existing structure of the
Element.
  • Loading branch information
ahojukka5 committed Nov 21, 2019
1 parent f00c5f3 commit 75dfdf2
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 47 deletions.
8 changes: 4 additions & 4 deletions src/assembly.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ Assemble elements for problem.
This should be overridden with own `assemble_elements!`-implementation.
"""
function assemble_elements!(problem::Problem, assembly::Assembly,
elements::Vector{Element{E}}, time) where E
elements::Vector{T}, time) where T<:AbstractElement{E} where E
elements2 = convert(Vector{Element}, elements)
assemble!(assembly, problem, elements2, time)
end
Expand Down Expand Up @@ -87,7 +87,7 @@ function assemble_mass_matrix!(problem::Problem, time::Float64)
return nothing
end

function assemble_mass_matrix!(problem::Problem, elements::Vector{Element{Basis}}, time) where Basis
function assemble_mass_matrix!(problem::Problem, elements::Vector{E}, time) where E<:AbstractElement{Basis} where Basis
nnodes = length(first(elements))
dim = get_unknown_field_dimension(problem)
M = zeros(nnodes, nnodes)
Expand Down Expand Up @@ -124,7 +124,7 @@ end
Assemble Tet10 mass matrices using special method. If Tet10 has constant metric
if can be integrated analytically to gain performance.
"""
function assemble_mass_matrix!(problem::Problem, elements::Vector{Element{FEMBasis.Tet10}}, time)
function assemble_mass_matrix!(problem::Problem, elements::Vector{E}, time) where E<:AbstractElement{FEMBasis.Tet10}
nnodes = length(Tet10)
dim = get_unknown_field_dimension(problem)
M = zeros(nnodes, nnodes)
Expand All @@ -144,7 +144,7 @@ function assemble_mass_matrix!(problem::Problem, elements::Vector{Element{FEMBas
-6 -4 -6 -4 16 16 8 16 32 16
-6 -6 -4 -4 8 16 16 16 16 32]

function is_CM(::Element{FEMBasis.Tet10}, X; rtol=1.0e-6)
function is_CM(::AbstractElement{FEMBasis.Tet10}, X; rtol=1.0e-6)
isapprox(X[5], 1/2*(X[1]+X[2]); rtol=rtol) || return false
isapprox(X[6], 1/2*(X[2]+X[3]); rtol=rtol) || return false
isapprox(X[7], 1/2*(X[3]+X[1]); rtol=rtol) || return false
Expand Down
18 changes: 9 additions & 9 deletions src/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
Return the length of basis (number of nodes).
"""
function length(element::Element)
function length(element::AbstractElement)
return length(element.properties)
end

Expand All @@ -15,28 +15,28 @@ end
Return the size of basis (dim, nnodes).
"""
function size(element::Element)
function size(element::AbstractElement)
return size(element.properties)
end

function getindex(element::Element, field_name::String)
function getindex(element::AbstractElement, field_name::String)
return element.fields[field_name]
end

function setindex!(element::Element, data::T, field_name) where T<:AbstractField
function setindex!(element::AbstractElement, data::T, field_name) where T<:AbstractField
element.fields[field_name] = data
end

function setindex!(element::Element, data::Function, field_name)
if hasmethod(data, Tuple{Element, Vector, Float64})
function setindex!(element::AbstractElement, data::Function, field_name)
if hasmethod(data, Tuple{AbstractElement, Vector, Float64})
# create enclosure to pass element as argument
element.fields[field_name] = field((ip,time) -> data(element,ip,time))
else
element.fields[field_name] = field(data)
end
end

function setindex!(element::Element, data, field_name)
function setindex!(element::AbstractElement, data, field_name)
element.fields[field_name] = field(data)
end

Expand All @@ -52,12 +52,12 @@ function (element::Element)(field_name::String)
return element[field_name]
end

function size(element::Element, dim)
function size(element::AbstractElement, dim)
return size(element)[dim]
end

""" Check existence of field. """
function haskey(element::Element, field_name)
function haskey(element::AbstractElement, field_name)
return haskey(element.fields, field_name)
end

Expand Down
58 changes: 30 additions & 28 deletions src/elements.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
# This file is a part of JuliaFEM.
# License is MIT: see https://github.com/JuliaFEM/FEMBase.jl/blob/master/LICENSE

mutable struct Element{E<:FEMBasis.AbstractBasis}
abstract type AbstractElement{T<:FEMBasis.AbstractBasis} end

mutable struct Element{T} <: AbstractElement{T}
id :: Int
connectivity :: Vector{Int}
integration_points :: Vector{IP}
dfields :: Dict{String, AbstractField}
properties :: E
fields :: Dict{String, AbstractField}
properties :: T
end

"""
Expand Down Expand Up @@ -58,34 +60,34 @@ function Element(::Type{T}, connectivity::Vector{Int}) where T<:FEMBasis.Abstrac
return Element(T, (connectivity...,))
end

function get_element_type(::Element{E}) where E
function get_element_type(::AbstractElement{E}) where E
return E
end

function get_element_id(element::Element{E}) where E
function get_element_id(element::AbstractElement{E}) where E
return element.id
end

function is_element_type(::Element{E}, element_type) where E
function is_element_type(::AbstractElement{E}, element_type) where E
return E === element_type
end

function filter_by_element_type(element_type, elements)
return Iterators.filter(element -> is_element_type(element, element_type), elements)
end

function get_connectivity(element::Element)
function get_connectivity(element::AbstractElement)
return element.connectivity
end

"""
group_by_element_type(elements::Vector{Element})
group_by_element_type(elements)
Given a vector of elements, group elements by element type to several vectors.
Returns a dictionary, where key is the element type and value is a vector
containing all elements of type `element_type`.
"""
function group_by_element_type(elements::Vector{Element})
function group_by_element_type(elements)
results = Dict{DataType, Any}()
basis_types = map(element -> typeof(element.properties), elements)
for basis in unique(basis_types)
Expand Down Expand Up @@ -116,7 +118,7 @@ interpolate(element, "my field", 0.5)
```
"""
function interpolate(element::Element, field_name::String, time::Float64)
function interpolate(element::AbstractElement, field_name::String, time::Float64)
field = element[field_name]
result = interpolate(field, time)
if isa(result, Dict)
Expand All @@ -127,14 +129,14 @@ function interpolate(element::Element, field_name::String, time::Float64)
end
end

function get_basis(element::Element{B}, ip, ::Any) where B
function get_basis(element::AbstractElement{B}, ip, ::Any) where B
T = typeof(first(ip))
N = zeros(T, 1, length(element))
FEMBasis.eval_basis!(B, N, tuple(ip...))
return N
end

function get_dbasis(element::Element{B}, ip, ::Any) where B
function get_dbasis(element::AbstractElement{B}, ip, ::Any) where B
T = typeof(first(ip))
dN = zeros(T, size(element)...)
FEMBasis.eval_dbasis!(B, dN, tuple(ip...))
Expand Down Expand Up @@ -197,65 +199,65 @@ function (element::Element)(field_name::String, ip, time::Float64, ::Type{Val{:G
return grad(element.properties, u, X, ip)
end

function interpolate(element::Element, field_name, ip, time)
function interpolate(element::AbstractElement, field_name, ip, time)
field = element[field_name]
return interpolate_field(element, field, ip, time)
end

function interpolate_field(::Element, field::T, ::Any, time) where T<:Union{DCTI,DCTV}
function interpolate_field(::AbstractElement, field::T, ::Any, time) where T<:Union{DCTI,DCTV}
return interpolate_field(field, time)
end

function interpolate_field(element::Element, field::T, ip, time) where T<:Union{DVTV,DVTI}
function interpolate_field(element::AbstractElement, field::T, ip, time) where T<:Union{DVTV,DVTI}
data = interpolate_field(field, time)
basis = get_basis(element, ip, time)
N = length(basis)
return sum(data[i]*basis[i] for i=1:N)
end

function interpolate_field(element::Element, field::T, ip, time) where T<:Union{DVTVd,DVTId}
function interpolate_field(element::AbstractElement, field::T, ip, time) where T<:Union{DVTVd,DVTId}
data = interpolate_field(field, time)
basis = element(ip, time)
N = length(element)
c = get_connectivity(element)
return sum(data[c[i]]*basis[i] for i=1:N)
end

function interpolate_field(::Element, field::CVTV, ip, time)
function interpolate_field(::AbstractElement, field::CVTV, ip, time)
return field(ip, time)
end

function update!(::Element, field::F, data) where F<:AbstractField
function update!(::AbstractElement, field::F, data) where F<:AbstractField
update!(field, data)
end

function update!(element::Element, field::F, data::Dict{T,V}) where {F<:DVTI,T,V}
function update!(element::AbstractElement, field::F, data::Dict{T,V}) where {F<:DVTI,T,V}
connectivity = get_connectivity(element)
picked_data = ((data[nid] for nid in connectivity)...,)
update!(field, picked_data)
end

function update!(element::Element, field::F,
function update!(element::AbstractElement, field::F,
data::Pair{Float64, Dict{T,V}}) where {F<:DVTV,T,V}
time, data2 = data
connectivity = get_connectivity(element)
picked_data = ((data2[nid] for nid in connectivity)...,)
update!(field, time => picked_data)
end

function update!(element::Element, field_name, data)
function update!(element::AbstractElement, field_name, data)
if haskey(element.fields, field_name)
update!(element, element.fields[field_name], data)
else
element.fields[field_name] = field(data)
end
end

function update!(element::Element, field_name, field_::F) where F<:AbstractField
function update!(element::AbstractElement, field_name, field_::F) where F<:AbstractField
element.fields[field_name] = field_
end

function update!(element::Element, field_name, data::Function)
function update!(element::AbstractElement, field_name, data::Function)
if hasmethod(data, Tuple{Element, Any, Any})
element.fields[field_name] = field((ip, time) -> data(element, ip, time))
else
Expand Down Expand Up @@ -310,7 +312,7 @@ function update!(elements, field_name, data)
end
end

function get_integration_points(element::Element{E}) where E
function get_integration_points(element::AbstractElement{E}) where E
# first time initialize default integration points
if length(element.integration_points) == 0
ips = get_integration_points(element.properties)
Expand All @@ -322,13 +324,13 @@ end
""" This is a special case, temporarily change order
of integration scheme mainly for mass matrix.
"""
function get_integration_points(element::Element{E}, change_order::Int) where E
function get_integration_points(element::AbstractElement{E}, change_order::Int) where E
ips = get_integration_points(element.properties, Val{change_order})
return [IP(i, w, xi) for (i, (w, xi)) in enumerate(ips)]
end

""" Find inverse isoparametric mapping of element. """
function get_local_coordinates(element::Element, X::Vector, time::Float64; max_iterations=10, tolerance=1.0e-6)
function get_local_coordinates(element::AbstractElement, X::Vector, time::Float64; max_iterations=10, tolerance=1.0e-6)
haskey(element, "geometry") || error("element geometry not defined, cannot calculate inverse isoparametric mapping")
dim = size(element, 1)
dim == length(X) || error("manifolds not supported.")
Expand All @@ -345,7 +347,7 @@ function get_local_coordinates(element::Element, X::Vector, time::Float64; max_i
end

""" Test is X inside element. """
function inside(element::Element{E}, X, time) where E
function inside(element::AbstractElement{E}, X, time) where E
xi = get_local_coordinates(element, X, time)
return inside(E, xi)
end
Expand All @@ -362,7 +364,7 @@ function (element::Element)(field_name::String, ip, time::Float64)
return interpolate(element, field_name, ip, time)
end

function element_info!(bi::FEMBasis.BasisInfo{E,T}, element::Element{E}, ip, time) where {E,T}
function element_info!(bi::FEMBasis.BasisInfo{E,T}, element::AbstractElement{E}, ip, time) where {E,T}
X = interpolate(element, "geometry", time)
eval_basis!(bi, X, ip)
return bi.J, bi.detJ, bi.N, bi.grad
Expand Down
8 changes: 4 additions & 4 deletions src/elements_lagrange.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@

struct Poi1 <: FEMBasis.AbstractBasis end

function get_basis(::Element{Poi1}, ::Any, ::Any)
function get_basis(::E, ::Any, ::Any) where E<:AbstractElement{Poi1}
return [1]
end

function get_dbasis(::Element{Poi1}, ::Any, ::Any)
function get_dbasis(::E, ::Any, ::Any) where E<:AbstractElement{Poi1}
return [0]
end

function (::Element{Poi1})(::Any, ::Float64, ::Type{Val{:detJ}})
function (::Element{Poi1})(::Any, ::Float64, ::Type{Val{:detJ}})
return 1.0
end

Expand Down Expand Up @@ -47,7 +47,7 @@ function inside(::Union{Type{FEMBasis.Tri3}, Type{FEMBasis.Tri6}, Type{FEMBasis.
return all(xi .>= 0.0) && (sum(xi) <= 1.0)
end

function get_reference_coordinates(::Element{B}) where B
function get_reference_coordinates(::E) where E<:AbstractElement{B} where B
return get_reference_element_coordinates(B)
end

4 changes: 2 additions & 2 deletions src/problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ of the problem (Elasticity, Dirichlet, etc.), `element_name` is the name of a
constructed element (see Element(element_type, connectivity_vector)) and `time`
is the starting time of the initializing process.
"""
function initialize!(problem::Problem, element::Element, time::Float64)
function initialize!(problem::Problem, element::AbstractElement, time::Float64)
field_name = get_unknown_field_name(problem)
field_dim = get_unknown_field_dimension(problem)
nnodes = length(element)
Expand Down Expand Up @@ -469,7 +469,7 @@ where `nid` is node id and `dim` is the dimension of problem. This formula
arranges dofs so that first comes all dofs of node 1, then node 2 and so on:
(u11, u12, u13, u21, u22, u23, ..., un1, un2, un3) for 3 dofs/node setting.
"""
function get_gdofs(problem::Problem, element::Element)
function get_gdofs(problem::Problem, element::AbstractElement)
if haskey(problem.dofmap, element)
return problem.dofmap[element]
end
Expand Down

0 comments on commit 75dfdf2

Please sign in to comment.