Mercurial > hg > octave-lyh
comparison src/data.cc @ 10716:f7f26094021b
improve cat code design in data.cc, make horzcat/vertcat more Matlab compatible
author | Jaroslav Hajek <highegg@gmail.com> |
---|---|
date | Mon, 21 Jun 2010 15:48:56 +0200 |
parents | fbd7843974fa |
children | f3892d8eea9f |
comparison
equal
deleted
inserted
replaced
10715:53253f796351 | 10716:f7f26094021b |
---|---|
1374 single_type_concat (Array<T>& result, | 1374 single_type_concat (Array<T>& result, |
1375 const octave_value_list& args, | 1375 const octave_value_list& args, |
1376 int dim) | 1376 int dim) |
1377 { | 1377 { |
1378 int n_args = args.length (); | 1378 int n_args = args.length (); |
1379 OCTAVE_LOCAL_BUFFER (Array<T>, array_list, n_args - 1); | 1379 OCTAVE_LOCAL_BUFFER (Array<T>, array_list, n_args); |
1380 | 1380 |
1381 for (int j = 1; j < n_args && ! error_state; j++) | 1381 for (int j = 0; j < n_args && ! error_state; j++) |
1382 { | 1382 { |
1383 octave_quit (); | 1383 octave_quit (); |
1384 | 1384 |
1385 array_list[j-1] = octave_value_extract<TYPE> (args(j)); | 1385 array_list[j] = octave_value_extract<TYPE> (args(j)); |
1386 } | 1386 } |
1387 | 1387 |
1388 if (! error_state) | 1388 if (! error_state) |
1389 result = Array<T>::cat (dim, n_args-1, array_list); | 1389 result = Array<T>::cat (dim, n_args, array_list); |
1390 } | 1390 } |
1391 | 1391 |
1392 template <class TYPE, class T> | 1392 template <class TYPE, class T> |
1393 static void | 1393 static void |
1394 single_type_concat (Sparse<T>& result, | 1394 single_type_concat (Sparse<T>& result, |
1395 const octave_value_list& args, | 1395 const octave_value_list& args, |
1396 int dim) | 1396 int dim) |
1397 { | 1397 { |
1398 int n_args = args.length (); | 1398 int n_args = args.length (); |
1399 OCTAVE_LOCAL_BUFFER (Sparse<T>, sparse_list, n_args-1); | 1399 OCTAVE_LOCAL_BUFFER (Sparse<T>, sparse_list, n_args); |
1400 | 1400 |
1401 for (int j = 1; j < n_args && ! error_state; j++) | 1401 for (int j = 0; j < n_args && ! error_state; j++) |
1402 { | 1402 { |
1403 octave_quit (); | 1403 octave_quit (); |
1404 | 1404 |
1405 sparse_list[j-1] = octave_value_extract<TYPE> (args(j)); | 1405 sparse_list[j] = octave_value_extract<TYPE> (args(j)); |
1406 } | 1406 } |
1407 | 1407 |
1408 if (! error_state) | 1408 if (! error_state) |
1409 result = Sparse<T>::cat (dim, n_args-1, sparse_list); | 1409 result = Sparse<T>::cat (dim, n_args, sparse_list); |
1410 } | 1410 } |
1411 | 1411 |
1412 // Dispatcher. | 1412 // Dispatcher. |
1413 template<class TYPE> | 1413 template<class TYPE> |
1414 static TYPE | 1414 static TYPE |
1420 | 1420 |
1421 return result; | 1421 return result; |
1422 } | 1422 } |
1423 | 1423 |
1424 static octave_value | 1424 static octave_value |
1425 do_cat (const octave_value_list& args, std::string fname) | 1425 do_cat (const octave_value_list& args, int dim, std::string fname) |
1426 { | 1426 { |
1427 octave_value retval; | 1427 octave_value retval; |
1428 | 1428 |
1429 int n_args = args.length (); | 1429 int n_args = args.length (); |
1430 | 1430 |
1431 if (n_args == 1) | 1431 if (n_args == 0) |
1432 retval = Matrix (); | 1432 retval = Matrix (); |
1433 else if (n_args == 2) | 1433 else if (n_args == 1) |
1434 retval = args(1); | 1434 retval = args(0); |
1435 else if (n_args > 2) | 1435 else if (n_args > 1) |
1436 { | 1436 { |
1437 octave_idx_type dim = args(0).int_value () - 1; | 1437 |
1438 | 1438 std::string result_type = args(0).class_name (); |
1439 if (error_state) | 1439 |
1440 { | 1440 bool all_sq_strings_p = args(0).is_sq_string (); |
1441 error ("cat: expecting first argument to be a integer"); | 1441 bool all_dq_strings_p = args(0).is_dq_string (); |
1442 return retval; | 1442 bool all_real_p = args(0).is_real_type (); |
1443 } | 1443 bool any_sparse_p = args(0).is_sparse_type(); |
1444 | 1444 |
1445 if (dim >= 0) | 1445 for (int i = 1; i < args.length (); i++) |
1446 { | 1446 { |
1447 | 1447 result_type = |
1448 std::string result_type = args(1).class_name (); | 1448 get_concat_class (result_type, args(i).class_name ()); |
1449 | 1449 |
1450 bool all_sq_strings_p = args(1).is_sq_string (); | 1450 if (all_sq_strings_p && ! args(i).is_sq_string ()) |
1451 bool all_dq_strings_p = args(1).is_dq_string (); | 1451 all_sq_strings_p = false; |
1452 bool all_real_p = args(1).is_real_type (); | 1452 if (all_dq_strings_p && ! args(i).is_dq_string ()) |
1453 bool any_sparse_p = args(1).is_sparse_type(); | 1453 all_dq_strings_p = false; |
1454 | 1454 if (all_real_p && ! args(i).is_real_type ()) |
1455 for (int i = 2; i < args.length (); i++) | 1455 all_real_p = false; |
1456 if (!any_sparse_p && args(i).is_sparse_type ()) | |
1457 any_sparse_p = true; | |
1458 } | |
1459 | |
1460 if (result_type == "double") | |
1461 { | |
1462 if (any_sparse_p) | |
1463 { | |
1464 if (all_real_p) | |
1465 retval = do_single_type_concat<SparseMatrix> (args, dim); | |
1466 else | |
1467 retval = do_single_type_concat<SparseComplexMatrix> (args, dim); | |
1468 } | |
1469 else | |
1456 { | 1470 { |
1457 result_type = | 1471 if (all_real_p) |
1458 get_concat_class (result_type, args(i).class_name ()); | 1472 retval = do_single_type_concat<NDArray> (args, dim); |
1459 | 1473 else |
1460 if (all_sq_strings_p && ! args(i).is_sq_string ()) | 1474 retval = do_single_type_concat<ComplexNDArray> (args, dim); |
1461 all_sq_strings_p = false; | |
1462 if (all_dq_strings_p && ! args(i).is_dq_string ()) | |
1463 all_dq_strings_p = false; | |
1464 if (all_real_p && ! args(i).is_real_type ()) | |
1465 all_real_p = false; | |
1466 if (!any_sparse_p && args(i).is_sparse_type ()) | |
1467 any_sparse_p = true; | |
1468 } | 1475 } |
1469 | 1476 } |
1470 if (result_type == "double") | 1477 else if (result_type == "single") |
1478 { | |
1479 if (all_real_p) | |
1480 retval = do_single_type_concat<FloatNDArray> (args, dim); | |
1481 else | |
1482 retval = do_single_type_concat<FloatComplexNDArray> (args, dim); | |
1483 } | |
1484 else if (result_type == "char") | |
1485 { | |
1486 char type = all_dq_strings_p ? '"' : '\''; | |
1487 | |
1488 maybe_warn_string_concat (all_dq_strings_p, all_sq_strings_p); | |
1489 | |
1490 charNDArray result = do_single_type_concat<charNDArray> (args, dim); | |
1491 | |
1492 retval = octave_value (result, type); | |
1493 } | |
1494 else if (result_type == "logical") | |
1495 { | |
1496 if (any_sparse_p) | |
1497 retval = do_single_type_concat<SparseBoolMatrix> (args, dim); | |
1498 else | |
1499 retval = do_single_type_concat<boolNDArray> (args, dim); | |
1500 } | |
1501 else if (result_type == "int8") | |
1502 retval = do_single_type_concat<int8NDArray> (args, dim); | |
1503 else if (result_type == "int16") | |
1504 retval = do_single_type_concat<int16NDArray> (args, dim); | |
1505 else if (result_type == "int32") | |
1506 retval = do_single_type_concat<int32NDArray> (args, dim); | |
1507 else if (result_type == "int64") | |
1508 retval = do_single_type_concat<int64NDArray> (args, dim); | |
1509 else if (result_type == "uint8") | |
1510 retval = do_single_type_concat<uint8NDArray> (args, dim); | |
1511 else if (result_type == "uint16") | |
1512 retval = do_single_type_concat<uint16NDArray> (args, dim); | |
1513 else if (result_type == "uint32") | |
1514 retval = do_single_type_concat<uint32NDArray> (args, dim); | |
1515 else if (result_type == "uint64") | |
1516 retval = do_single_type_concat<uint64NDArray> (args, dim); | |
1517 else | |
1518 { | |
1519 dim_vector dv = args(0).dims (); | |
1520 | |
1521 // Default concatenation. | |
1522 bool (dim_vector::*concat_rule) (const dim_vector&, int) = &dim_vector::concat; | |
1523 | |
1524 if (dim == -1 || dim == -2) | |
1471 { | 1525 { |
1472 if (any_sparse_p) | 1526 concat_rule = &dim_vector::hvcat; |
1473 { | 1527 dim = -dim - 1; |
1474 if (all_real_p) | 1528 } |
1475 retval = do_single_type_concat<SparseMatrix> (args, dim); | 1529 |
1476 else | 1530 for (int i = 1; i < args.length (); i++) |
1477 retval = do_single_type_concat<SparseComplexMatrix> (args, dim); | 1531 { |
1532 if (! (dv.*concat_rule) (args(i).dims (), dim)) | |
1533 { | |
1534 // Dimensions do not match. | |
1535 error ("cat: dimension mismatch"); | |
1536 return retval; | |
1537 } | |
1538 } | |
1539 | |
1540 // The lines below might seem crazy, since we take a copy | |
1541 // of the first argument, resize it to be empty and then resize | |
1542 // it to be full. This is done since it means that there is no | |
1543 // recopying of data, as would happen if we used a single resize. | |
1544 // It should be noted that resize operation is also significantly | |
1545 // slower than the do_cat_op function, so it makes sense to have | |
1546 // an empty matrix and copy all data. | |
1547 // | |
1548 // We might also start with a empty octave_value using | |
1549 // tmp = octave_value_typeinfo::lookup_type | |
1550 // (args(1).type_name()); | |
1551 // and then directly resize. However, for some types there might | |
1552 // be some additional setup needed, and so this should be avoided. | |
1553 | |
1554 octave_value tmp = args (0); | |
1555 tmp = tmp.resize (dim_vector (0,0)).resize (dv); | |
1556 | |
1557 if (error_state) | |
1558 return retval; | |
1559 | |
1560 int dv_len = dv.length (); | |
1561 Array<octave_idx_type> ra_idx (dv_len, 1, 0); | |
1562 | |
1563 for (int j = 0; j < n_args; j++) | |
1564 { | |
1565 // Can't fast return here to skip empty matrices as something | |
1566 // like cat(1,[],single([])) must return an empty matrix of | |
1567 // the right type. | |
1568 tmp = do_cat_op (tmp, args (j), ra_idx); | |
1569 | |
1570 if (error_state) | |
1571 return retval; | |
1572 | |
1573 dim_vector dv_tmp = args (j).dims (); | |
1574 | |
1575 if (dim >= dv_len) | |
1576 { | |
1577 if (j > 1) | |
1578 error ("%s: indexing error", fname.c_str ()); | |
1579 break; | |
1478 } | 1580 } |
1479 else | 1581 else |
1480 { | 1582 ra_idx (dim) += (dim < dv_tmp.length () ? |
1481 if (all_real_p) | 1583 dv_tmp (dim) : 1); |
1482 retval = do_single_type_concat<NDArray> (args, dim); | |
1483 else | |
1484 retval = do_single_type_concat<ComplexNDArray> (args, dim); | |
1485 } | |
1486 } | 1584 } |
1487 else if (result_type == "single") | 1585 retval = tmp; |
1488 { | 1586 } |
1489 if (all_real_p) | |
1490 retval = do_single_type_concat<FloatNDArray> (args, dim); | |
1491 else | |
1492 retval = do_single_type_concat<FloatComplexNDArray> (args, dim); | |
1493 } | |
1494 else if (result_type == "char") | |
1495 { | |
1496 char type = all_dq_strings_p ? '"' : '\''; | |
1497 | |
1498 maybe_warn_string_concat (all_dq_strings_p, all_sq_strings_p); | |
1499 | |
1500 charNDArray result = do_single_type_concat<charNDArray> (args, dim); | |
1501 | |
1502 retval = octave_value (result, type); | |
1503 } | |
1504 else if (result_type == "logical") | |
1505 { | |
1506 if (any_sparse_p) | |
1507 retval = do_single_type_concat<SparseBoolMatrix> (args, dim); | |
1508 else | |
1509 retval = do_single_type_concat<boolNDArray> (args, dim); | |
1510 } | |
1511 else if (result_type == "int8") | |
1512 retval = do_single_type_concat<int8NDArray> (args, dim); | |
1513 else if (result_type == "int16") | |
1514 retval = do_single_type_concat<int16NDArray> (args, dim); | |
1515 else if (result_type == "int32") | |
1516 retval = do_single_type_concat<int32NDArray> (args, dim); | |
1517 else if (result_type == "int64") | |
1518 retval = do_single_type_concat<int64NDArray> (args, dim); | |
1519 else if (result_type == "uint8") | |
1520 retval = do_single_type_concat<uint8NDArray> (args, dim); | |
1521 else if (result_type == "uint16") | |
1522 retval = do_single_type_concat<uint16NDArray> (args, dim); | |
1523 else if (result_type == "uint32") | |
1524 retval = do_single_type_concat<uint32NDArray> (args, dim); | |
1525 else if (result_type == "uint64") | |
1526 retval = do_single_type_concat<uint64NDArray> (args, dim); | |
1527 else | |
1528 { | |
1529 dim_vector dv = args(1).dims (); | |
1530 | |
1531 for (int i = 2; i < args.length (); i++) | |
1532 { | |
1533 if (! dv.concat (args(i).dims (), dim)) | |
1534 { | |
1535 // Dimensions do not match. | |
1536 error ("cat: dimension mismatch"); | |
1537 return retval; | |
1538 } | |
1539 } | |
1540 | |
1541 // The lines below might seem crazy, since we take a copy | |
1542 // of the first argument, resize it to be empty and then resize | |
1543 // it to be full. This is done since it means that there is no | |
1544 // recopying of data, as would happen if we used a single resize. | |
1545 // It should be noted that resize operation is also significantly | |
1546 // slower than the do_cat_op function, so it makes sense to have | |
1547 // an empty matrix and copy all data. | |
1548 // | |
1549 // We might also start with a empty octave_value using | |
1550 // tmp = octave_value_typeinfo::lookup_type | |
1551 // (args(1).type_name()); | |
1552 // and then directly resize. However, for some types there might | |
1553 // be some additional setup needed, and so this should be avoided. | |
1554 | |
1555 octave_value tmp = args (1); | |
1556 tmp = tmp.resize (dim_vector (0,0)).resize (dv); | |
1557 | |
1558 if (error_state) | |
1559 return retval; | |
1560 | |
1561 int dv_len = dv.length (); | |
1562 Array<octave_idx_type> ra_idx (dv_len, 1, 0); | |
1563 | |
1564 for (int j = 1; j < n_args; j++) | |
1565 { | |
1566 // Can't fast return here to skip empty matrices as something | |
1567 // like cat(1,[],single([])) must return an empty matrix of | |
1568 // the right type. | |
1569 tmp = do_cat_op (tmp, args (j), ra_idx); | |
1570 | |
1571 if (error_state) | |
1572 return retval; | |
1573 | |
1574 dim_vector dv_tmp = args (j).dims (); | |
1575 | |
1576 if (dim >= dv_len) | |
1577 { | |
1578 if (j > 1) | |
1579 error ("%s: indexing error", fname.c_str ()); | |
1580 break; | |
1581 } | |
1582 else | |
1583 ra_idx (dim) += (dim < dv_tmp.length () ? | |
1584 dv_tmp (dim) : 1); | |
1585 } | |
1586 retval = tmp; | |
1587 } | |
1588 } | |
1589 else | |
1590 error ("%s: invalid dimension argument", fname.c_str ()); | |
1591 } | 1587 } |
1592 else | 1588 else |
1593 print_usage (); | 1589 print_usage (); |
1594 | 1590 |
1595 return retval; | 1591 return retval; |
1601 Return the horizontal concatenation of N-d array objects, @var{array1},\n\ | 1597 Return the horizontal concatenation of N-d array objects, @var{array1},\n\ |
1602 @var{array2}, @dots{}, @var{arrayN} along dimension 2.\n\ | 1598 @var{array2}, @dots{}, @var{arrayN} along dimension 2.\n\ |
1603 @seealso{cat, vertcat}\n\ | 1599 @seealso{cat, vertcat}\n\ |
1604 @end deftypefn") | 1600 @end deftypefn") |
1605 { | 1601 { |
1606 octave_value_list args_tmp = args; | 1602 return do_cat (args, -2, "horzcat"); |
1607 | |
1608 int dim = 2; | |
1609 | |
1610 octave_value d (dim); | |
1611 | |
1612 args_tmp.prepend (d); | |
1613 | |
1614 return do_cat (args_tmp, "horzcat"); | |
1615 } | 1603 } |
1616 | 1604 |
1617 DEFUN (vertcat, args, , | 1605 DEFUN (vertcat, args, , |
1618 "-*- texinfo -*-\n\ | 1606 "-*- texinfo -*-\n\ |
1619 @deftypefn {Built-in Function} {} vertcat (@var{array1}, @var{array2}, @dots{}, @var{arrayN})\n\ | 1607 @deftypefn {Built-in Function} {} vertcat (@var{array1}, @var{array2}, @dots{}, @var{arrayN})\n\ |
1620 Return the vertical concatenation of N-d array objects, @var{array1},\n\ | 1608 Return the vertical concatenation of N-d array objects, @var{array1},\n\ |
1621 @var{array2}, @dots{}, @var{arrayN} along dimension 1.\n\ | 1609 @var{array2}, @dots{}, @var{arrayN} along dimension 1.\n\ |
1622 @seealso{cat, horzcat}\n\ | 1610 @seealso{cat, horzcat}\n\ |
1623 @end deftypefn") | 1611 @end deftypefn") |
1624 { | 1612 { |
1625 octave_value_list args_tmp = args; | 1613 return do_cat (args, -1, "vertcat"); |
1626 | |
1627 int dim = 1; | |
1628 | |
1629 octave_value d (dim); | |
1630 | |
1631 args_tmp.prepend (d); | |
1632 | |
1633 return do_cat (args_tmp, "vertcat"); | |
1634 } | 1614 } |
1635 | 1615 |
1636 DEFUN (cat, args, , | 1616 DEFUN (cat, args, , |
1637 "-*- texinfo -*-\n\ | 1617 "-*- texinfo -*-\n\ |
1638 @deftypefn {Built-in Function} {} cat (@var{dim}, @var{array1}, @var{array2}, @dots{}, @var{arrayN})\n\ | 1618 @deftypefn {Built-in Function} {} cat (@var{dim}, @var{array1}, @var{array2}, @dots{}, @var{arrayN})\n\ |
1679 @end group\n\ | 1659 @end group\n\ |
1680 @end example\n\ | 1660 @end example\n\ |
1681 @seealso{horzcat, vertcat}\n\ | 1661 @seealso{horzcat, vertcat}\n\ |
1682 @end deftypefn") | 1662 @end deftypefn") |
1683 { | 1663 { |
1684 return do_cat (args, "cat"); | 1664 octave_value retval; |
1665 | |
1666 if (args.length () > 0) | |
1667 { | |
1668 int dim = args(0).int_value () - 1; | |
1669 | |
1670 if (! error_state) | |
1671 { | |
1672 if (dim >= 0) | |
1673 retval = do_cat (args.slice (1, args.length () - 1), dim, "cat"); | |
1674 else | |
1675 error ("cat: invalid dimension specified"); | |
1676 } | |
1677 else | |
1678 error ("cat: expecting first argument to be a integer"); | |
1679 } | |
1680 else | |
1681 print_usage (); | |
1682 | |
1683 return retval; | |
1685 } | 1684 } |
1686 | 1685 |
1687 /* | 1686 /* |
1688 | 1687 |
1689 %!function ret = testcat (t1, t2, tr, cmplx) | 1688 %!function ret = testcat (t1, t2, tr, cmplx) |