diff liboctave/CDiagMatrix.cc @ 8366:8b1a2555c4e2

implement diagonal matrix objects * * *
author Jaroslav Hajek <highegg@gmail.com>
date Wed, 03 Dec 2008 13:32:57 +0100
parents 82be108cc558
children c3f7e2549abb
line wrap: on
line diff
--- a/liboctave/CDiagMatrix.cc
+++ b/liboctave/CDiagMatrix.cc
@@ -232,6 +232,15 @@
   return *this;
 }
 
+DiagMatrix
+ComplexDiagMatrix::abs (void) const
+{
+  DiagMatrix retval (rows (), cols ());
+  for (octave_idx_type i = 0; i < rows (); i++)
+    retval(i, i) = std::abs (elem (i, i));
+  return retval;
+}
+
 ComplexDiagMatrix
 conj (const ComplexDiagMatrix& a)
 {
@@ -378,6 +387,21 @@
   return retval;
 }
 
+bool
+ComplexDiagMatrix::all_elements_are_real (void) const
+{
+  octave_idx_type len = length ();
+  for (octave_idx_type i = 0; i < len; i++)
+    {
+      double ip = std::imag (elem (i, i));
+
+      if (ip != 0.0 || lo_ieee_signbit (ip))
+        return false;
+    }
+
+  return true;
+}
+
 // diagonal matrix by diagonal matrix -> diagonal matrix operations
 
 ComplexDiagMatrix&
@@ -431,14 +455,7 @@
       Complex a_element = a.elem (i, i);
       double b_element = b.elem (i, i);
 
-      if (a_element == 0.0 || b_element == 0.0)
-        c.elem (i, i) = 0.0;
-      else if (a_element == 1.0)
-        c.elem (i, i) = b_element;
-      else if (b_element == 1.0)
-        c.elem (i, i) = a_element;
-      else
-        c.elem (i, i) = a_element * b_element;
+      c.elem (i, i) = a_element * b_element;
     }
 
   return c;
@@ -471,14 +488,40 @@
       double a_element = a.elem (i, i);
       Complex b_element = b.elem (i, i);
 
-      if (a_element == 0.0 || b_element == 0.0)
-        c.elem (i, i) = 0.0;
-      else if (a_element == 1.0)
-        c.elem (i, i) = b_element;
-      else if (b_element == 1.0)
-        c.elem (i, i) = a_element;
-      else
-        c.elem (i, i) = a_element * b_element;
+      c.elem (i, i) = a_element * b_element;
+    }
+
+  return c;
+}
+
+ComplexDiagMatrix
+operator * (const ComplexDiagMatrix& a, const ComplexDiagMatrix& b)
+{
+  octave_idx_type a_nr = a.rows ();
+  octave_idx_type a_nc = a.cols ();
+
+  octave_idx_type b_nr = b.rows ();
+  octave_idx_type b_nc = b.cols ();
+
+  if (a_nc != b_nr)
+    {
+      gripe_nonconformant ("operator *", a_nr, a_nc, b_nr, b_nc);
+      return ComplexDiagMatrix ();
+    }
+
+  if (a_nr == 0 || a_nc == 0 || b_nc == 0)
+    return ComplexDiagMatrix (a_nr, a_nc, 0.0);
+
+  ComplexDiagMatrix c (a_nr, b_nc);
+
+  octave_idx_type len = a_nr < b_nc ? a_nr : b_nc;
+
+  for (octave_idx_type i = 0; i < len; i++)
+    {
+      Complex a_element = a.elem (i, i);
+      Complex b_element = b.elem (i, i);
+
+      c.elem (i, i) = a_element * b_element;
     }
 
   return c;