001// --- BEGIN LICENSE BLOCK --- 002/* 003 * Copyright (c) 2009-2011, Mikio L. Braun 004 * 2011, Nicolas Oury 005 * All rights reserved. 006 * 007 * Redistribution and use in source and binary forms, with or without 008 * modification, are permitted provided that the following conditions are 009 * met: 010 * 011 * * Redistributions of source code must retain the above copyright 012 * notice, this list of conditions and the following disclaimer. 013 * 014 * * Redistributions in binary form must reproduce the above 015 * copyright notice, this list of conditions and the following 016 * disclaimer in the documentation and/or other materials provided 017 * with the distribution. 018 * 019 * * Neither the name of the Technische Universit?t Berlin nor the 020 * names of its contributors may be used to endorse or promote 021 * products derived from this software without specific prior 022 * written permission. 023 * 024 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 025 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 026 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 027 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 028 * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 029 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 030 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 031 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 032 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 033 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 034 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 035 */ 036// --- END LICENSE BLOCK --- 037 038package org.jblas; 039 040import org.jblas.exceptions.*; 041import org.jblas.util.Functions; 042 043import static org.jblas.util.Functions.*; 044 045//import edu.ida.core.OutputValue; 046 047/** 048 * This class provides a cleaner direct interface to the BLAS routines by 049 * extracting the parameters of the matrices from the matrices itself. 050 * <p/> 051 * For example, you can just pass the vector and do not have to pass the length, 052 * corresponding DoubleBuffer, offset and step size explicitly. 053 * <p/> 054 * Currently, all the general matrix routines are implemented. 055 */ 056public class SimpleBlas { 057 /*************************************************************************** 058 * BLAS Level 1 059 */ 060 061 /** 062 * Compute x <-> y (swap two matrices) 063 */ 064 public static DoubleMatrix swap(DoubleMatrix x, DoubleMatrix y) { 065 //NativeBlas.dswap(x.length, x.data, 0, 1, y.data, 0, 1); 066 JavaBlas.rswap(x.length, x.data, 0, 1, y.data, 0, 1); 067 return y; 068 } 069 070 /** 071 * Compute x <- alpha * x (scale a matrix) 072 */ 073 public static DoubleMatrix scal(double alpha, DoubleMatrix x) { 074 NativeBlas.dscal(x.length, alpha, x.data, 0, 1); 075 return x; 076 } 077 078 public static ComplexDoubleMatrix scal(ComplexDouble alpha, ComplexDoubleMatrix x) { 079 NativeBlas.zscal(x.length, alpha, x.data, 0, 1); 080 return x; 081 } 082 083 /** 084 * Compute y <- x (copy a matrix) 085 */ 086 public static DoubleMatrix copy(DoubleMatrix x, DoubleMatrix y) { 087 //NativeBlas.dcopy(x.length, x.data, 0, 1, y.data, 0, 1); 088 JavaBlas.rcopy(x.length, x.data, 0, 1, y.data, 0, 1); 089 return y; 090 } 091 092 public static ComplexDoubleMatrix copy(ComplexDoubleMatrix x, ComplexDoubleMatrix y) { 093 NativeBlas.zcopy(x.length, x.data, 0, 1, y.data, 0, 1); 094 return y; 095 } 096 097 /** 098 * Compute y <- alpha * x + y (elementwise addition) 099 */ 100 public static DoubleMatrix axpy(double da, DoubleMatrix dx, DoubleMatrix dy) { 101 //NativeBlas.daxpy(dx.length, da, dx.data, 0, 1, dy.data, 0, 1); 102 JavaBlas.raxpy(dx.length, da, dx.data, 0, 1, dy.data, 0, 1); 103 104 return dy; 105 } 106 107 public static ComplexDoubleMatrix axpy(ComplexDouble da, ComplexDoubleMatrix dx, ComplexDoubleMatrix dy) { 108 NativeBlas.zaxpy(dx.length, da, dx.data, 0, 1, dy.data, 0, 1); 109 return dy; 110 } 111 112 /** 113 * Compute x^T * y (dot product) 114 */ 115 public static double dot(DoubleMatrix x, DoubleMatrix y) { 116 //return NativeBlas.ddot(x.length, x.data, 0, 1, y.data, 0, 1); 117 return JavaBlas.rdot(x.length, x.data, 0, 1, y.data, 0, 1); 118 } 119 120 /** 121 * Compute x^T * y (dot product) 122 */ 123 public static ComplexDouble dotc(ComplexDoubleMatrix x, ComplexDoubleMatrix y) { 124 return NativeBlas.zdotc(x.length, x.data, 0, 1, y.data, 0, 1); 125 } 126 127 /** 128 * Compute x^T * y (dot product) 129 */ 130 public static ComplexDouble dotu(ComplexDoubleMatrix x, ComplexDoubleMatrix y) { 131 return NativeBlas.zdotu(x.length, x.data, 0, 1, y.data, 0, 1); 132 } 133 134 /** 135 * Compute || x ||_2 (2-norm) 136 */ 137 public static double nrm2(DoubleMatrix x) { 138 return NativeBlas.dnrm2(x.length, x.data, 0, 1); 139 } 140 141 public static double nrm2(ComplexDoubleMatrix x) { 142 return NativeBlas.dznrm2(x.length, x.data, 0, 1); 143 } 144 145 /** 146 * Compute || x ||_1 (1-norm, sum of absolute values) 147 */ 148 public static double asum(DoubleMatrix x) { 149 return NativeBlas.dasum(x.length, x.data, 0, 1); 150 } 151 152 public static double asum(ComplexDoubleMatrix x) { 153 return NativeBlas.dzasum(x.length, x.data, 0, 1); 154 } 155 156 /** 157 * Compute index of element with largest absolute value (index of absolute 158 * value maximum) 159 */ 160 public static int iamax(DoubleMatrix x) { 161 return NativeBlas.idamax(x.length, x.data, 0, 1) - 1; 162 } 163 164 /** 165 * Compute index of element with largest absolute value (complex version). 166 * 167 * @param x matrix 168 * @return index of element with largest absolute value. 169 */ 170 public static int iamax(ComplexDoubleMatrix x) { 171 return NativeBlas.izamax(x.length, x.data, 0, 1) - 1; 172 } 173 174 /*************************************************************************** 175 * BLAS Level 2 176 */ 177 178 /** 179 * Compute y <- alpha*op(a)*x + beta * y (general matrix vector 180 * multiplication) 181 */ 182 public static DoubleMatrix gemv(double alpha, DoubleMatrix a, 183 DoubleMatrix x, double beta, DoubleMatrix y) { 184 if (false) { 185 NativeBlas.dgemv('N', a.rows, a.columns, alpha, a.data, 0, a.rows, x.data, 0, 186 1, beta, y.data, 0, 1); 187 } else { 188 if (beta == 0.0) { 189 for (int i = 0; i < y.length; i++) 190 y.data[i] = 0.0; 191 192 for (int j = 0; j < a.columns; j++) { 193 double xj = x.get(j); 194 if (xj != 0.0) { 195 for (int i = 0; i < a.rows; i++) 196 y.data[i] += a.get(i, j) * xj; 197 } 198 } 199 } else { 200 for (int j = 0; j < a.columns; j++) { 201 double byj = beta * y.data[j]; 202 double xj = x.get(j); 203 for (int i = 0; i < a.rows; i++) 204 y.data[j] = a.get(i, j) * xj + byj; 205 } 206 } 207 } 208 return y; 209 } 210 211 /** 212 * Compute A <- alpha * x * y^T + A (general rank-1 update) 213 */ 214 public static DoubleMatrix ger(double alpha, DoubleMatrix x, 215 DoubleMatrix y, DoubleMatrix a) { 216 NativeBlas.dger(a.rows, a.columns, alpha, x.data, 0, 1, y.data, 0, 1, a.data, 217 0, a.rows); 218 return a; 219 } 220 221 /** 222 * Compute A <- alpha * x * y^T + A (general rank-1 update) 223 */ 224 public static ComplexDoubleMatrix geru(ComplexDouble alpha, ComplexDoubleMatrix x, 225 ComplexDoubleMatrix y, ComplexDoubleMatrix a) { 226 NativeBlas.zgeru(a.rows, a.columns, alpha, x.data, 0, 1, y.data, 0, 1, a.data, 227 0, a.rows); 228 return a; 229 } 230 231 /** 232 * Compute A <- alpha * x * y^H + A (general rank-1 update) 233 */ 234 public static ComplexDoubleMatrix gerc(ComplexDouble alpha, ComplexDoubleMatrix x, 235 ComplexDoubleMatrix y, ComplexDoubleMatrix a) { 236 NativeBlas.zgerc(a.rows, a.columns, alpha, x.data, 0, 1, y.data, 0, 1, a.data, 237 0, a.rows); 238 return a; 239 } 240 241 /*************************************************************************** 242 * BLAS Level 3 243 */ 244 245 /** 246 * Compute c <- a*b + beta * c (general matrix matrix 247 * multiplication) 248 */ 249 public static DoubleMatrix gemm(double alpha, DoubleMatrix a, 250 DoubleMatrix b, double beta, DoubleMatrix c) { 251 NativeBlas.dgemm('N', 'N', c.rows, c.columns, a.columns, alpha, a.data, 0, 252 a.rows, b.data, 0, b.rows, beta, c.data, 0, c.rows); 253 return c; 254 } 255 256 public static ComplexDoubleMatrix gemm(ComplexDouble alpha, ComplexDoubleMatrix a, 257 ComplexDoubleMatrix b, ComplexDouble beta, ComplexDoubleMatrix c) { 258 NativeBlas.zgemm('N', 'N', c.rows, c.columns, a.columns, alpha, a.data, 0, 259 a.rows, b.data, 0, b.rows, beta, c.data, 0, c.rows); 260 return c; 261 } 262 263 /*************************************************************************** 264 * LAPACK 265 */ 266 267 public static DoubleMatrix gesv(DoubleMatrix a, int[] ipiv, 268 DoubleMatrix b) { 269 int info = NativeBlas.dgesv(a.rows, b.columns, a.data, 0, a.rows, ipiv, 0, 270 b.data, 0, b.rows); 271 checkInfo("DGESV", info); 272 273 if (info > 0) 274 throw new LapackException("DGESV", 275 "Linear equation cannot be solved because the matrix was singular."); 276 277 return b; 278 } 279 280//STOP 281 282 private static void checkInfo(String name, int info) { 283 if (info < -1) 284 throw new LapackArgumentException(name, info); 285 } 286//START 287 288 public static DoubleMatrix sysv(char uplo, DoubleMatrix a, int[] ipiv, 289 DoubleMatrix b) { 290 int info = NativeBlas.dsysv(uplo, a.rows, b.columns, a.data, 0, a.rows, ipiv, 0, 291 b.data, 0, b.rows); 292 checkInfo("SYSV", info); 293 294 if (info > 0) 295 throw new LapackSingularityException("SYV", 296 "Linear equation cannot be solved because the matrix was singular."); 297 298 return b; 299 } 300 301 public static int syev(char jobz, char uplo, DoubleMatrix a, DoubleMatrix w) { 302 int info = NativeBlas.dsyev(jobz, uplo, a.rows, a.data, 0, a.rows, w.data, 0); 303 304 if (info > 0) 305 throw new LapackConvergenceException("SYEV", 306 "Eigenvalues could not be computed " + info 307 + " off-diagonal elements did not converge"); 308 309 return info; 310 } 311 312 public static int syevx(char jobz, char range, char uplo, DoubleMatrix a, 313 double vl, double vu, int il, int iu, double abstol, 314 DoubleMatrix w, DoubleMatrix z) { 315 int n = a.rows; 316 int[] iwork = new int[5 * n]; 317 int[] ifail = new int[n]; 318 int[] m = new int[1]; 319 int info; 320 321 info = NativeBlas.dsyevx(jobz, range, uplo, n, a.data, 0, a.rows, vl, vu, il, 322 iu, abstol, m, 0, w.data, 0, z.data, 0, z.rows, iwork, 0, ifail, 0); 323 324 if (info > 0) { 325 StringBuilder msg = new StringBuilder(); 326 msg 327 .append("Not all eigenvalues converged. Non-converging eigenvalues were: "); 328 for (int i = 0; i < info; i++) { 329 if (i > 0) 330 msg.append(", "); 331 msg.append(ifail[i]); 332 } 333 msg.append("."); 334 throw new LapackConvergenceException("SYEVX", msg.toString()); 335 } 336 337 return info; 338 } 339 340 public static int syevd(char jobz, char uplo, DoubleMatrix A, 341 DoubleMatrix w) { 342 int n = A.rows; 343 344 int info = NativeBlas.dsyevd(jobz, uplo, n, A.data, 0, A.rows, w.data, 0); 345 346 if (info > 0) 347 throw new LapackConvergenceException("SYEVD", "Not all eigenvalues converged."); 348 349 return info; 350 } 351 352 public static int syevr(char jobz, char range, char uplo, DoubleMatrix a, 353 double vl, double vu, int il, int iu, double abstol, 354 DoubleMatrix w, DoubleMatrix z, int[] isuppz) { 355 int n = a.rows; 356 int[] m = new int[1]; 357 358 int info = NativeBlas.dsyevr(jobz, range, uplo, n, a.data, 0, a.rows, vl, vu, 359 il, iu, abstol, m, 0, w.data, 0, z.data, 0, z.rows, isuppz, 0); 360 361 checkInfo("SYEVR", info); 362 363 return info; 364 } 365 366 public static void posv(char uplo, DoubleMatrix A, DoubleMatrix B) { 367 int n = A.rows; 368 int nrhs = B.columns; 369 int info = NativeBlas.dposv(uplo, n, nrhs, A.data, 0, A.rows, B.data, 0, 370 B.rows); 371 checkInfo("DPOSV", info); 372 if (info > 0) 373 throw new LapackArgumentException("DPOSV", 374 "Leading minor of order i of A is not positive definite."); 375 } 376 377 public static int geev(char jobvl, char jobvr, DoubleMatrix A, 378 DoubleMatrix WR, DoubleMatrix WI, DoubleMatrix VL, DoubleMatrix VR) { 379 int info = NativeBlas.dgeev(jobvl, jobvr, A.rows, A.data, 0, A.rows, WR.data, 0, 380 WI.data, 0, VL.data, 0, VL.rows, VR.data, 0, VR.rows); 381 if (info > 0) 382 throw new LapackConvergenceException("DGEEV", "First " + info + " eigenvalues have not converged."); 383 return info; 384 } 385 386 public static int sygvd(int itype, char jobz, char uplo, DoubleMatrix A, DoubleMatrix B, DoubleMatrix W) { 387 int info = NativeBlas.dsygvd(itype, jobz, uplo, A.rows, A.data, 0, A.rows, B.data, 0, B.rows, W.data, 0); 388 if (info == 0) 389 return 0; 390 else { 391 if (info < 0) 392 throw new LapackArgumentException("DSYGVD", -info); 393 if (info <= A.rows && jobz == 'N') 394 throw new LapackConvergenceException("DSYGVD", info + " off-diagonal elements did not converge to 0."); 395 if (info <= A.rows && jobz == 'V') 396 throw new LapackException("DSYGVD", "Failed to compute an eigenvalue while working on a sub-matrix " + info + "."); 397 else 398 throw new LapackException("DSYGVD", "The leading minor of order " + (info - A.rows) + " of B is not positive definite."); 399 } 400 } 401 402 /** 403 * Generalized Least Squares via *GELSD. 404 * 405 * Note that B must be padded to contain the solution matrix. This occurs when A has fewer rows 406 * than columns. 407 * 408 * For example: in A * X = B, A is (m,n), X is (n,k) and B is (m,k). Now if m < n, since B is overwritten to contain 409 * the solution (in classical LAPACK style), B needs to be padded to be an (n,k) matrix. 410 * 411 * Likewise, if m > n, the solution consists only of the first n rows of B. 412 * 413 * @param A an (m,n) matrix 414 * @param B an (max(m,n), k) matrix (well, at least) 415 */ 416 public static void gelsd(DoubleMatrix A, DoubleMatrix B) { 417 int m = A.rows; 418 int n = A.columns; 419 int nrhs = B.columns; 420 int minmn = min(m, n); 421 int maxmn = max(m, n); 422 423 if (B.rows < maxmn) { 424 throw new SizeException("Result matrix B must be padded to contain the solution matrix X!"); 425 } 426 427 int smlsiz = NativeBlas.ilaenv(9, "DGELSD", "", m, n, nrhs, 0); 428 int nlvl = max(0, (int) log2(minmn/ (smlsiz+1)) + 1); 429 430// System.err.printf("GELSD\n"); 431// System.err.printf("m = %d, n = %d, nrhs = %d\n", m, n, nrhs); 432// System.err.printf("smlsiz = %d, nlvl = %d\n", smlsiz, nlvl); 433// System.err.printf("iwork size = %d\n", 3 * minmn * nlvl + 11 * minmn); 434 435 int[] iwork = new int[3 * minmn * nlvl + 11 * minmn]; 436 double[] s = new double[minmn]; 437 int[] rank = new int[1]; 438 int info = NativeBlas.dgelsd(m, n, nrhs, A.data, 0, m, B.data, 0, B.rows, s, 0, -1, rank, 0, iwork, 0); 439 if (info == 0) { 440 return; 441 } else if (info < 0) { 442 throw new LapackArgumentException("DGESD", -info); 443 } else if (info > 0) { 444 throw new LapackConvergenceException("DGESD", info + " off-diagonal elements of an intermediat bidiagonal form did not converge to 0."); 445 } 446 } 447 448 public static void geqrf(DoubleMatrix A, DoubleMatrix tau) { 449 int info = NativeBlas.dgeqrf(A.rows, A.columns, A.data, 0, A.rows, tau.data, 0); 450 checkInfo("GEQRF", info); 451 } 452 453 public static void ormqr(char side, char trans, DoubleMatrix A, DoubleMatrix tau, DoubleMatrix C) { 454 int k = tau.length; 455 int info = NativeBlas.dormqr(side, trans, C.rows, C.columns, k, A.data, 0, A.rows, tau.data, 0, C.data, 0, C.rows); 456 checkInfo("ORMQR", info); 457 } 458 459//BEGIN 460 // The code below has been automatically generated. 461 // DO NOT EDIT! 462 /*************************************************************************** 463 * BLAS Level 1 464 */ 465 466 /** 467 * Compute x <-> y (swap two matrices) 468 */ 469 public static FloatMatrix swap(FloatMatrix x, FloatMatrix y) { 470 //NativeBlas.sswap(x.length, x.data, 0, 1, y.data, 0, 1); 471 JavaBlas.rswap(x.length, x.data, 0, 1, y.data, 0, 1); 472 return y; 473 } 474 475 /** 476 * Compute x <- alpha * x (scale a matrix) 477 */ 478 public static FloatMatrix scal(float alpha, FloatMatrix x) { 479 NativeBlas.sscal(x.length, alpha, x.data, 0, 1); 480 return x; 481 } 482 483 public static ComplexFloatMatrix scal(ComplexFloat alpha, ComplexFloatMatrix x) { 484 NativeBlas.cscal(x.length, alpha, x.data, 0, 1); 485 return x; 486 } 487 488 /** 489 * Compute y <- x (copy a matrix) 490 */ 491 public static FloatMatrix copy(FloatMatrix x, FloatMatrix y) { 492 //NativeBlas.scopy(x.length, x.data, 0, 1, y.data, 0, 1); 493 JavaBlas.rcopy(x.length, x.data, 0, 1, y.data, 0, 1); 494 return y; 495 } 496 497 public static ComplexFloatMatrix copy(ComplexFloatMatrix x, ComplexFloatMatrix y) { 498 NativeBlas.ccopy(x.length, x.data, 0, 1, y.data, 0, 1); 499 return y; 500 } 501 502 /** 503 * Compute y <- alpha * x + y (elementwise addition) 504 */ 505 public static FloatMatrix axpy(float da, FloatMatrix dx, FloatMatrix dy) { 506 //NativeBlas.saxpy(dx.length, da, dx.data, 0, 1, dy.data, 0, 1); 507 JavaBlas.raxpy(dx.length, da, dx.data, 0, 1, dy.data, 0, 1); 508 509 return dy; 510 } 511 512 public static ComplexFloatMatrix axpy(ComplexFloat da, ComplexFloatMatrix dx, ComplexFloatMatrix dy) { 513 NativeBlas.caxpy(dx.length, da, dx.data, 0, 1, dy.data, 0, 1); 514 return dy; 515 } 516 517 /** 518 * Compute x^T * y (dot product) 519 */ 520 public static float dot(FloatMatrix x, FloatMatrix y) { 521 //return NativeBlas.sdot(x.length, x.data, 0, 1, y.data, 0, 1); 522 return JavaBlas.rdot(x.length, x.data, 0, 1, y.data, 0, 1); 523 } 524 525 /** 526 * Compute x^T * y (dot product) 527 */ 528 public static ComplexFloat dotc(ComplexFloatMatrix x, ComplexFloatMatrix y) { 529 return NativeBlas.cdotc(x.length, x.data, 0, 1, y.data, 0, 1); 530 } 531 532 /** 533 * Compute x^T * y (dot product) 534 */ 535 public static ComplexFloat dotu(ComplexFloatMatrix x, ComplexFloatMatrix y) { 536 return NativeBlas.cdotu(x.length, x.data, 0, 1, y.data, 0, 1); 537 } 538 539 /** 540 * Compute || x ||_2 (2-norm) 541 */ 542 public static float nrm2(FloatMatrix x) { 543 return NativeBlas.snrm2(x.length, x.data, 0, 1); 544 } 545 546 public static float nrm2(ComplexFloatMatrix x) { 547 return NativeBlas.scnrm2(x.length, x.data, 0, 1); 548 } 549 550 /** 551 * Compute || x ||_1 (1-norm, sum of absolute values) 552 */ 553 public static float asum(FloatMatrix x) { 554 return NativeBlas.sasum(x.length, x.data, 0, 1); 555 } 556 557 public static float asum(ComplexFloatMatrix x) { 558 return NativeBlas.scasum(x.length, x.data, 0, 1); 559 } 560 561 /** 562 * Compute index of element with largest absolute value (index of absolute 563 * value maximum) 564 */ 565 public static int iamax(FloatMatrix x) { 566 return NativeBlas.isamax(x.length, x.data, 0, 1) - 1; 567 } 568 569 /** 570 * Compute index of element with largest absolute value (complex version). 571 * 572 * @param x matrix 573 * @return index of element with largest absolute value. 574 */ 575 public static int iamax(ComplexFloatMatrix x) { 576 return NativeBlas.icamax(x.length, x.data, 0, 1) - 1; 577 } 578 579 /*************************************************************************** 580 * BLAS Level 2 581 */ 582 583 /** 584 * Compute y <- alpha*op(a)*x + beta * y (general matrix vector 585 * multiplication) 586 */ 587 public static FloatMatrix gemv(float alpha, FloatMatrix a, 588 FloatMatrix x, float beta, FloatMatrix y) { 589 if (false) { 590 NativeBlas.sgemv('N', a.rows, a.columns, alpha, a.data, 0, a.rows, x.data, 0, 591 1, beta, y.data, 0, 1); 592 } else { 593 if (beta == 0.0f) { 594 for (int i = 0; i < y.length; i++) 595 y.data[i] = 0.0f; 596 597 for (int j = 0; j < a.columns; j++) { 598 float xj = x.get(j); 599 if (xj != 0.0f) { 600 for (int i = 0; i < a.rows; i++) 601 y.data[i] += a.get(i, j) * xj; 602 } 603 } 604 } else { 605 for (int j = 0; j < a.columns; j++) { 606 float byj = beta * y.data[j]; 607 float xj = x.get(j); 608 for (int i = 0; i < a.rows; i++) 609 y.data[j] = a.get(i, j) * xj + byj; 610 } 611 } 612 } 613 return y; 614 } 615 616 /** 617 * Compute A <- alpha * x * y^T + A (general rank-1 update) 618 */ 619 public static FloatMatrix ger(float alpha, FloatMatrix x, 620 FloatMatrix y, FloatMatrix a) { 621 NativeBlas.sger(a.rows, a.columns, alpha, x.data, 0, 1, y.data, 0, 1, a.data, 622 0, a.rows); 623 return a; 624 } 625 626 /** 627 * Compute A <- alpha * x * y^T + A (general rank-1 update) 628 */ 629 public static ComplexFloatMatrix geru(ComplexFloat alpha, ComplexFloatMatrix x, 630 ComplexFloatMatrix y, ComplexFloatMatrix a) { 631 NativeBlas.cgeru(a.rows, a.columns, alpha, x.data, 0, 1, y.data, 0, 1, a.data, 632 0, a.rows); 633 return a; 634 } 635 636 /** 637 * Compute A <- alpha * x * y^H + A (general rank-1 update) 638 */ 639 public static ComplexFloatMatrix gerc(ComplexFloat alpha, ComplexFloatMatrix x, 640 ComplexFloatMatrix y, ComplexFloatMatrix a) { 641 NativeBlas.cgerc(a.rows, a.columns, alpha, x.data, 0, 1, y.data, 0, 1, a.data, 642 0, a.rows); 643 return a; 644 } 645 646 /*************************************************************************** 647 * BLAS Level 3 648 */ 649 650 /** 651 * Compute c <- a*b + beta * c (general matrix matrix 652 * multiplication) 653 */ 654 public static FloatMatrix gemm(float alpha, FloatMatrix a, 655 FloatMatrix b, float beta, FloatMatrix c) { 656 NativeBlas.sgemm('N', 'N', c.rows, c.columns, a.columns, alpha, a.data, 0, 657 a.rows, b.data, 0, b.rows, beta, c.data, 0, c.rows); 658 return c; 659 } 660 661 public static ComplexFloatMatrix gemm(ComplexFloat alpha, ComplexFloatMatrix a, 662 ComplexFloatMatrix b, ComplexFloat beta, ComplexFloatMatrix c) { 663 NativeBlas.cgemm('N', 'N', c.rows, c.columns, a.columns, alpha, a.data, 0, 664 a.rows, b.data, 0, b.rows, beta, c.data, 0, c.rows); 665 return c; 666 } 667 668 /*************************************************************************** 669 * LAPACK 670 */ 671 672 public static FloatMatrix gesv(FloatMatrix a, int[] ipiv, 673 FloatMatrix b) { 674 int info = NativeBlas.sgesv(a.rows, b.columns, a.data, 0, a.rows, ipiv, 0, 675 b.data, 0, b.rows); 676 checkInfo("DGESV", info); 677 678 if (info > 0) 679 throw new LapackException("DGESV", 680 "Linear equation cannot be solved because the matrix was singular."); 681 682 return b; 683 } 684 685 686 public static FloatMatrix sysv(char uplo, FloatMatrix a, int[] ipiv, 687 FloatMatrix b) { 688 int info = NativeBlas.ssysv(uplo, a.rows, b.columns, a.data, 0, a.rows, ipiv, 0, 689 b.data, 0, b.rows); 690 checkInfo("SYSV", info); 691 692 if (info > 0) 693 throw new LapackSingularityException("SYV", 694 "Linear equation cannot be solved because the matrix was singular."); 695 696 return b; 697 } 698 699 public static int syev(char jobz, char uplo, FloatMatrix a, FloatMatrix w) { 700 int info = NativeBlas.ssyev(jobz, uplo, a.rows, a.data, 0, a.rows, w.data, 0); 701 702 if (info > 0) 703 throw new LapackConvergenceException("SYEV", 704 "Eigenvalues could not be computed " + info 705 + " off-diagonal elements did not converge"); 706 707 return info; 708 } 709 710 public static int syevx(char jobz, char range, char uplo, FloatMatrix a, 711 float vl, float vu, int il, int iu, float abstol, 712 FloatMatrix w, FloatMatrix z) { 713 int n = a.rows; 714 int[] iwork = new int[5 * n]; 715 int[] ifail = new int[n]; 716 int[] m = new int[1]; 717 int info; 718 719 info = NativeBlas.ssyevx(jobz, range, uplo, n, a.data, 0, a.rows, vl, vu, il, 720 iu, abstol, m, 0, w.data, 0, z.data, 0, z.rows, iwork, 0, ifail, 0); 721 722 if (info > 0) { 723 StringBuilder msg = new StringBuilder(); 724 msg 725 .append("Not all eigenvalues converged. Non-converging eigenvalues were: "); 726 for (int i = 0; i < info; i++) { 727 if (i > 0) 728 msg.append(", "); 729 msg.append(ifail[i]); 730 } 731 msg.append("."); 732 throw new LapackConvergenceException("SYEVX", msg.toString()); 733 } 734 735 return info; 736 } 737 738 public static int syevd(char jobz, char uplo, FloatMatrix A, 739 FloatMatrix w) { 740 int n = A.rows; 741 742 int info = NativeBlas.ssyevd(jobz, uplo, n, A.data, 0, A.rows, w.data, 0); 743 744 if (info > 0) 745 throw new LapackConvergenceException("SYEVD", "Not all eigenvalues converged."); 746 747 return info; 748 } 749 750 public static int syevr(char jobz, char range, char uplo, FloatMatrix a, 751 float vl, float vu, int il, int iu, float abstol, 752 FloatMatrix w, FloatMatrix z, int[] isuppz) { 753 int n = a.rows; 754 int[] m = new int[1]; 755 756 int info = NativeBlas.ssyevr(jobz, range, uplo, n, a.data, 0, a.rows, vl, vu, 757 il, iu, abstol, m, 0, w.data, 0, z.data, 0, z.rows, isuppz, 0); 758 759 checkInfo("SYEVR", info); 760 761 return info; 762 } 763 764 public static void posv(char uplo, FloatMatrix A, FloatMatrix B) { 765 int n = A.rows; 766 int nrhs = B.columns; 767 int info = NativeBlas.sposv(uplo, n, nrhs, A.data, 0, A.rows, B.data, 0, 768 B.rows); 769 checkInfo("DPOSV", info); 770 if (info > 0) 771 throw new LapackArgumentException("DPOSV", 772 "Leading minor of order i of A is not positive definite."); 773 } 774 775 public static int geev(char jobvl, char jobvr, FloatMatrix A, 776 FloatMatrix WR, FloatMatrix WI, FloatMatrix VL, FloatMatrix VR) { 777 int info = NativeBlas.sgeev(jobvl, jobvr, A.rows, A.data, 0, A.rows, WR.data, 0, 778 WI.data, 0, VL.data, 0, VL.rows, VR.data, 0, VR.rows); 779 if (info > 0) 780 throw new LapackConvergenceException("DGEEV", "First " + info + " eigenvalues have not converged."); 781 return info; 782 } 783 784 public static int sygvd(int itype, char jobz, char uplo, FloatMatrix A, FloatMatrix B, FloatMatrix W) { 785 int info = NativeBlas.ssygvd(itype, jobz, uplo, A.rows, A.data, 0, A.rows, B.data, 0, B.rows, W.data, 0); 786 if (info == 0) 787 return 0; 788 else { 789 if (info < 0) 790 throw new LapackArgumentException("DSYGVD", -info); 791 if (info <= A.rows && jobz == 'N') 792 throw new LapackConvergenceException("DSYGVD", info + " off-diagonal elements did not converge to 0."); 793 if (info <= A.rows && jobz == 'V') 794 throw new LapackException("DSYGVD", "Failed to compute an eigenvalue while working on a sub-matrix " + info + "."); 795 else 796 throw new LapackException("DSYGVD", "The leading minor of order " + (info - A.rows) + " of B is not positive definite."); 797 } 798 } 799 800 /** 801 * Generalized Least Squares via *GELSD. 802 * 803 * Note that B must be padded to contain the solution matrix. This occurs when A has fewer rows 804 * than columns. 805 * 806 * For example: in A * X = B, A is (m,n), X is (n,k) and B is (m,k). Now if m < n, since B is overwritten to contain 807 * the solution (in classical LAPACK style), B needs to be padded to be an (n,k) matrix. 808 * 809 * Likewise, if m > n, the solution consists only of the first n rows of B. 810 * 811 * @param A an (m,n) matrix 812 * @param B an (max(m,n), k) matrix (well, at least) 813 */ 814 public static void gelsd(FloatMatrix A, FloatMatrix B) { 815 int m = A.rows; 816 int n = A.columns; 817 int nrhs = B.columns; 818 int minmn = min(m, n); 819 int maxmn = max(m, n); 820 821 if (B.rows < maxmn) { 822 throw new SizeException("Result matrix B must be padded to contain the solution matrix X!"); 823 } 824 825 int smlsiz = NativeBlas.ilaenv(9, "DGELSD", "", m, n, nrhs, 0); 826 int nlvl = max(0, (int) log2(minmn/ (smlsiz+1)) + 1); 827 828// System.err.printf("GELSD\n"); 829// System.err.printf("m = %d, n = %d, nrhs = %d\n", m, n, nrhs); 830// System.err.printf("smlsiz = %d, nlvl = %d\n", smlsiz, nlvl); 831// System.err.printf("iwork size = %d\n", 3 * minmn * nlvl + 11 * minmn); 832 833 int[] iwork = new int[3 * minmn * nlvl + 11 * minmn]; 834 float[] s = new float[minmn]; 835 int[] rank = new int[1]; 836 int info = NativeBlas.sgelsd(m, n, nrhs, A.data, 0, m, B.data, 0, B.rows, s, 0, -1, rank, 0, iwork, 0); 837 if (info == 0) { 838 return; 839 } else if (info < 0) { 840 throw new LapackArgumentException("DGESD", -info); 841 } else if (info > 0) { 842 throw new LapackConvergenceException("DGESD", info + " off-diagonal elements of an intermediat bidiagonal form did not converge to 0."); 843 } 844 } 845 846 public static void geqrf(FloatMatrix A, FloatMatrix tau) { 847 int info = NativeBlas.sgeqrf(A.rows, A.columns, A.data, 0, A.rows, tau.data, 0); 848 checkInfo("GEQRF", info); 849 } 850 851 public static void ormqr(char side, char trans, FloatMatrix A, FloatMatrix tau, FloatMatrix C) { 852 int k = tau.length; 853 int info = NativeBlas.sormqr(side, trans, C.rows, C.columns, k, A.data, 0, A.rows, tau.data, 0, C.data, 0, C.rows); 854 checkInfo("ORMQR", info); 855 } 856 857//END 858}