Extending Matrix
Vecxt matrix is a higher kinded thing with no bounds. Vecxt tries to squeeze the best performance out of Double, but doesn't try to impose restrictions on what you can do with it.
Here we offer a matrix multiplication extension method based on Spires typeclasses.
//> using dep org.typelevel::spire:0.18.0
import spire._
import spire.math._
import spire.implicits._
import spire.algebra.Ring
import scala.reflect.ClassTag
import vecxt.*
import vecxt.all.*
import vecxt.BoundsCheck.BoundsCheck
import narr.*
import BoundsCheck.DoBoundsCheck.yes
// A very naive matrix multiplication implementation...
object SpireExt:
extension [A: ClassTag: Ring](m1: Matrix[A])
inline def @@@(
m2: Matrix[A]
)(using inline boundsCheck: BoundsCheck): Matrix[A] =
dimMatCheck(m1, m2)
val (r1, c1) = m1.shape
val (r2, c2) = m2.shape
val nar = NArray.ofSize[A](r1 * c2)
val res = Matrix(nar, (r1, c2))
for i <- 0 until r1 do
for j <- 0 until c2 do
res((i, j)) = (0 until c1)
.map { k =>
val i1 = m1((i: Row, k: Col))
val i2 = m2((k: Row, j: Col))
i1 * i2
}
.reduce(_ + _)
end for
res
end @@@
inline def showMat: String =
val (r, c) = m1.shape
val sb = new StringBuilder
for i <- 0 until r do
for j <- 0 until c do
sb.append(m1((i: Row, j: Col))(using BoundsCheck.DoBoundsCheck.no))
sb.append(" ")
end for
sb.append("\n")
end for
sb.toString
end showMat
end extension
end SpireExt
import SpireExt.*
// That would work with complex numbers
val mat1 = Matrix.fromRows[Complex[Double]](
NArray[Complex[Double]](Complex(1.0, -1.0), Complex(0.0, 2.0), Complex(-2.0, 1.0)),
NArray[Complex[Double]](Complex(0.0, -3.0), Complex(3.0, -2.0), Complex(-1.0, -1.0))
)
// mat1: Matrix[Complex[Double]] = vecxt.matrix$Matrix@63210722
val mat2 = Matrix.fromRows[Complex[Double]](
NArray[Complex[Double]](Complex(0.0, -2.0), Complex(1.0, -4.0)),
NArray[Complex[Double]](Complex(-1.0, 3.0), Complex(2.0, -3.0)),
NArray[Complex[Double]](Complex(-2.0, 1.0), Complex(-4.0, 1.0))
)
// mat2: Matrix[Complex[Double]] = vecxt.matrix$Matrix@64118760
println(mat1.showMat)
// (1.0 + -1.0i) (0.0 + 2.0i) (-2.0 + 1.0i)
// (0.0 + -3.0i) (3.0 + -2.0i) (-1.0 + -1.0i)
//
println(mat2.showMat)
// (0.0 + -2.0i) (1.0 + -4.0i)
// (-1.0 + 3.0i) (2.0 + -3.0i)
// (-2.0 + 1.0i) (-4.0 + 1.0i)
//
println((mat1 @@@ mat2).showMat)
// (-5.0 + -8.0i) (10.0 + -7.0i)
// (0.0 + 12.0i) (-7.0 + -13.0i)
//