Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "FillArrays"
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
version = "0.10.1"
version = "0.11.0"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
76 changes: 57 additions & 19 deletions src/fillalgebra.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
const FillVector{F,A} = Fill{F,1,A}
const FillMatrix{F,A} = Fill{F,2,A}
const OnesVector{F,A} = Ones{F,1,A}
const OnesMatrix{F,A} = Ones{F,2,A}
const ZerosVector{F,A} = Zeros{F,1,A}
const ZerosMatrix{F,A} = Zeros{F,2,A}

## vec

vec(a::Ones{T}) where T = Ones{T}(length(a))
Expand Down Expand Up @@ -77,8 +84,10 @@ end

*(a::Zeros{<:Any,1}, b::AbstractMatrix) = mult_zeros(a, b)
*(a::Zeros{<:Any,2}, b::AbstractMatrix) = mult_zeros(a, b)
*(a::Zeros{<:Any,2}, b::AbstractTriangular) = mult_zeros(a, b)
*(a::AbstractMatrix, b::Zeros{<:Any,1}) = mult_zeros(a, b)
*(a::AbstractMatrix, b::Zeros{<:Any,2}) = mult_zeros(a, b)
*(a::AbstractTriangular, b::Zeros{<:Any,2}) = mult_zeros(a, b)
*(a::Zeros{<:Any,1}, b::AbstractVector) = mult_zeros(a, b)
*(a::Zeros{<:Any,2}, b::AbstractVector) = mult_zeros(a, b)
*(a::AbstractVector, b::Zeros{<:Any,2}) = mult_zeros(a, b)
Expand All @@ -87,36 +96,65 @@ end
*(a::Zeros{<:Any,2}, b::Diagonal) = mult_zeros(a, b)
*(a::Diagonal, b::Zeros{<:Any,1}) = mult_zeros(a, b)
*(a::Diagonal, b::Zeros{<:Any,2}) = mult_zeros(a, b)
function *(a::Diagonal, b::AbstractFill{<:Any,2})

function *(a::Diagonal, b::AbstractFill{T,2}) where T
size(a,2) == size(b,1) || throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))"))
a.diag .* b # use special broadcast
end
function *(a::AbstractFill{<:Any,2}, b::Diagonal)
function *(a::AbstractFill{T,2}, b::Diagonal) where T
size(a,2) == size(b,1) || throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))"))
a .* permutedims(b.diag) # use special broadcast
end

*(a::Adjoint{T, <:StridedMatrix{T}}, b::Fill{T, 1}) where T = reshape(sum(conj.(parent(a)); dims=1) .* b.value, size(parent(a), 2))
*(a::Transpose{T, <:StridedMatrix{T}}, b::Fill{T, 1}) where T = reshape(sum(parent(a); dims=1) .* b.value, size(parent(a), 2))
*(a::StridedMatrix{T}, b::Fill{T, 1}) where T = reshape(sum(a; dims=2) .* b.value, size(a, 1))

function *(a::Adjoint{T, <:StridedMatrix{T}}, b::Fill{T, 2}) where T
fB = similar(parent(a), size(b, 1), size(b, 2))
fill!(fB, b.value)
return a*fB
function mult_sum2(x::AbstractMatrix, f::AbstractFill{T,2}) where T
axes(x, 2) ≠ axes(f, 1) &&
throw(DimensionMismatch("Incompatible matrix multiplication dimensions"))
m = size(f, 2)
repeat(sum(x, dims=2) * getindex_value(f), 1, m)
end

function *(a::Transpose{T, <:StridedMatrix{T}}, b::Fill{T, 2}) where T
fB = similar(parent(a), size(b, 1), size(b, 2))
fill!(fB, b.value)
return a*fB
function mult_sum1(f::AbstractFill{T,2}, x::AbstractMatrix) where T
axes(f, 2) ≠ axes(x, 1) &&
throw(DimensionMismatch("Incompatible matrix multiplication dimensions"))
m = size(f, 1)
repeat(sum(x, dims=1) * getindex_value(f), m, 1)
end

function *(a::StridedMatrix{T}, b::Fill{T, 2}) where T
fB = similar(a, size(b, 1), size(b, 2))
fill!(fB, b.value)
return a*fB
end
*(x::AbstractMatrix, y::AbstractFill{<:Any,2}) = mult_sum2(x, y)
*(x::AbstractTriangular, y::AbstractFill{<:Any,2}) = mult_sum2(x, y)
*(x::AbstractFill{<:Any,2}, y::AbstractMatrix) = mult_sum1(x, y)
*(x::AbstractFill{<:Any,2}, y::AbstractTriangular) = mult_sum1(x, y)


### These methods are faster for small n #############
# function *(a::Adjoint{T, <:StridedMatrix{T}}, b::Fill{T, 2}) where T
# fB = similar(parent(a), size(b, 1), size(b, 2))
# fill!(fB, b.value)
# return a*fB
# end

# function *(a::Transpose{T, <:StridedMatrix{T}}, b::Fill{T, 2}) where T
# fB = similar(parent(a), size(b, 1), size(b, 2))
# fill!(fB, b.value)
# return a*fB
# end

# function *(a::StridedMatrix{T}, b::Fill{T, 2}) where T
# fB = similar(a, size(b, 1), size(b, 2))
# fill!(fB, b.value)
# return a*fB
# end

## Matrix-Vector multiplication

*(a::Adjoint{T, <:StridedMatrix{T}}, b::Fill{T, 1}) where T =
reshape(sum(conj.(parent(a)); dims=1) .* b.value, size(parent(a), 2))
*(a::Transpose{T, <:StridedMatrix{T}}, b::Fill{T, 1}) where T =
reshape(sum(parent(a); dims=1) .* b.value, size(parent(a), 2))
*(a::StridedMatrix{T}, b::Fill{T, 1}) where T =
reshape(sum(a; dims=2) .* b.value, size(a, 1))


function _adjvec_mul_zeros(a::Adjoint{T}, b::Zeros{S, 1}) where {T, S}
la, lb = length(a), length(b)
if la ≠ lb
Expand Down
14 changes: 14 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1028,6 +1028,20 @@ end
@test E*(1:5) ≡ 1.0:5.0
@test (1:5)'E == (1.0:5)'
@test E*E ≡ E

# Adjoint / Transpose / Triangular / Symmetric / Hermitian
for x in [transpose(rand(2, 2)),
adjoint(rand(2,2)),
UpperTriangular(rand(2,2)),
Symmetric(rand(2,2)),
Hermitian(rand(2,2))]
@test x * Ones(2, 2) isa Matrix
@test Ones(2, 2) * x isa Matrix
@test x * Zeros(2, 2) isa Zeros
@test Zeros(2, 2) * x isa Zeros
@test x * Fill(1., 2, 2) isa Matrix
@test Fill(1., 2, 2) * x isa Matrix
end
end

@testset "count" begin
Expand Down