001// --- BEGIN LICENSE BLOCK ---
002/* 
003 * Copyright (c) 2009, Mikio L. Braun
004 * All rights reserved.
005 * 
006 * Redistribution and use in source and binary forms, with or without
007 * modification, are permitted provided that the following conditions are
008 * met:
009 * 
010 *     * Redistributions of source code must retain the above copyright
011 *       notice, this list of conditions and the following disclaimer.
012 * 
013 *     * Redistributions in binary form must reproduce the above
014 *       copyright notice, this list of conditions and the following
015 *       disclaimer in the documentation and/or other materials provided
016 *       with the distribution.
017 * 
018 *     * Neither the name of the Technische Universität Berlin nor the
019 *       names of its contributors may be used to endorse or promote
020 *       products derived from this software without specific prior
021 *       written permission.
022 * 
023 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
024 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
025 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
026 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
027 * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
028 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
029 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
030 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
031 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
032 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
033 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
034 */
035// --- END LICENSE BLOCK ---
036
037/*
038 * To change this template, choose Tools | Templates
039 * and open the template in the editor.
040 */
041package org.jblas;
042
043import org.jblas.exceptions.LapackException;
044
045/**
046 * <p>Implementation of some Blas functions, mostly those which require linear runtime
047 * in the number of matrix elements. Because of the copying overhead when passing
048 * primitive arrays to native code, it doesn't make sense for these functions
049 * to be implemented in native code. The Java code is about as fast.</p>
050 * 
051 * <p>The same conventions were used as in the native code, that is, for each array
052 * you also pass an index pointing to the starting index.</p>
053 * 
054 * <p>These methods are mostly optimized for the case where the starting index is 0
055 * and the increment is 1.</p>
056 */
057public class JavaBlas {
058
059    /** Exchange two vectors. */
060    public static void rswap(int n, double[] dx, int dxIdx, int incx, double[] dy, int dyIdx, int incy) {
061        if (incx == 1 && incy == 1 && dxIdx == 0 && dyIdx == 0) {
062            double z;
063            for (int i = 0; i < n; i++) {
064                z = dx[i];
065                dx[i] = dy[i];
066                dy[i] = z;
067            }
068        } else {
069            double z;
070            for (int c = 0, xi = dxIdx, yi = dyIdx; c < n; xi += incx, yi += incy, c++) {
071                z = dx[xi];
072                dx[xi] = dy[yi];
073                dy[yi] = z;
074            }
075        }
076    }
077
078    /** Copy dx to dy. */
079    public static void rcopy(int n, double[] dx, int dxIdx, int incx, double[] dy, int dyIdx, int incy) {
080        if (dxIdx < 0 || dxIdx + (n - 1) * incx >= dx.length) {
081            throw new LapackException("Java.raxpy", "Parameters for x aren't valid! (n = " + n + ", dx.length = " + dx.length + ", dxIdx = " + dxIdx + ", incx = " + incx + ")");
082        }
083        if (dyIdx < 0 || dyIdx + (n - 1) * incy >= dy.length) {
084            throw new LapackException("Java.raxpy", "Parameters for y aren't valid! (n = " + n + ", dy.length = " + dy.length + ", dyIdx = " + dyIdx + ", incy = " + incy + ")");
085        }
086        if (incx == 1 && incy == 1) {
087            System.arraycopy(dx, dxIdx, dy, dyIdx, n);
088        } else {
089            for (int c = 0, xi = dxIdx, yi = dyIdx; c < n; xi += incx, yi += incy, c++) {
090                dy[yi] = dx[xi];
091            }
092        }
093    }
094
095    /** Compute dy <- da * dx + dy. */
096    public static void raxpy(int n, double da, double[] dx, int dxIdx, int incx, double[] dy, int dyIdx, int incy) {
097        if (dxIdx < 0 || dxIdx + (n - 1) * incx >= dx.length) {
098            throw new LapackException("Java.raxpy", "Parameters for x aren't valid! (n = " + n + ", dx.length = " + dx.length + ", dxIdx = " + dxIdx + ", incx = " + incx + ")");
099        }
100        if (dyIdx < 0 || dyIdx + (n - 1) * incy >= dy.length) {
101            throw new LapackException("Java.raxpy", "Parameters for y aren't valid! (n = " + n + ", dy.length = " + dy.length + ", dyIdx = " + dyIdx + ", incy = " + incy + ")");
102        }
103        
104        if (incx == 1 && incy == 1 && dxIdx == 0 && dyIdx == 0) {
105            if (da == 1.0) {
106                for (int i = 0; i < n; i++) {
107                    dy[i] += dx[i];
108                }
109            } else {
110                for (int i = 0; i < n; i++) {
111                    dy[i] += da * dx[i];
112                }
113            }
114        } else {
115            if (da == 1.0) {
116                for (int c = 0, xi = dxIdx, yi = dyIdx; c < n; c++, xi += incx, yi += incy) {
117                    dy[yi] += dx[xi];
118                }
119
120            } else {
121                for (int c = 0, xi = dxIdx, yi = dyIdx; c < n; c++, xi += incx, yi += incy) {
122                    dy[yi] += da * dx[xi];
123                }
124            }
125        }
126    }
127
128    /** Computes dz <- dx + dy */
129    public static void rzaxpy(int n, double[] dz, int dzIdx, int incz, double da, double[] dx, int dxIdx, int incx, double[] dy, int dyIdx, int incy) {
130        if (dxIdx == 0 && incx == 1 && dyIdx == 0 && incy == 1 && dzIdx == 0 && incz == 1) {
131            if (da == 1.0) {
132                for (int c = 0; c < n; c++)
133                    dz[c] = dx[c] + dy[c];
134            } else {
135                for (int c = 0; c < n; c++)
136                    dz[c] = da*dx[c] + dy[c];
137            }
138        } else {
139            if (da == 1.0) {
140                for (int c = 0, xi = dxIdx, yi = dyIdx, zi = dzIdx; c < n; c++, xi += incx, yi += incy, zi += incz) {
141                    dz[zi] = dx[xi] + dy[yi];
142                }
143            } else {
144                for (int c = 0, xi = dxIdx, yi = dyIdx, zi = dzIdx; c < n; c++, xi += incx, yi += incy, zi += incz) {
145                    dz[zi] = da*dx[xi] + dy[yi];
146                }
147            }
148        }
149    }
150
151    public static void rzgxpy(int n, double[] dz, double[] dx, double[] dy) {
152        for (int c = 0; c < n; c++)
153            dz[c] = dx[c] + dy[c];       
154    }
155
156    /** Compute scalar product between dx and dy. */
157    public static double rdot(int n, double[] dx, int dxIdx, int incx, double[] dy, int dyIdx, int incy) {
158        double s = 0.0;
159        if (incx == 1 && incy == 1 && dxIdx == 0 && dyIdx == 0) {
160            for (int i = 0; i < n; i++)
161                s += dx[i] * dy[i];
162        }
163        else {
164            for (int c = 0, xi = dxIdx, yi = dyIdx; c < n; c++, xi += incx, yi += incy) {
165                s += dx[xi] * dy[yi];
166            }
167        }
168        return s;
169    }
170//BEGIN
171  // The code below has been automatically generated.
172  // DO NOT EDIT!
173
174    /** Exchange two vectors. */
175    public static void rswap(int n, float[] dx, int dxIdx, int incx, float[] dy, int dyIdx, int incy) {
176        if (incx == 1 && incy == 1 && dxIdx == 0 && dyIdx == 0) {
177            float z;
178            for (int i = 0; i < n; i++) {
179                z = dx[i];
180                dx[i] = dy[i];
181                dy[i] = z;
182            }
183        } else {
184            float z;
185            for (int c = 0, xi = dxIdx, yi = dyIdx; c < n; xi += incx, yi += incy, c++) {
186                z = dx[xi];
187                dx[xi] = dy[yi];
188                dy[yi] = z;
189            }
190        }
191    }
192
193    /** Copy dx to dy. */
194    public static void rcopy(int n, float[] dx, int dxIdx, int incx, float[] dy, int dyIdx, int incy) {
195        if (dxIdx < 0 || dxIdx + (n - 1) * incx >= dx.length) {
196            throw new LapackException("Java.raxpy", "Parameters for x aren't valid! (n = " + n + ", dx.length = " + dx.length + ", dxIdx = " + dxIdx + ", incx = " + incx + ")");
197        }
198        if (dyIdx < 0 || dyIdx + (n - 1) * incy >= dy.length) {
199            throw new LapackException("Java.raxpy", "Parameters for y aren't valid! (n = " + n + ", dy.length = " + dy.length + ", dyIdx = " + dyIdx + ", incy = " + incy + ")");
200        }
201        if (incx == 1 && incy == 1) {
202            System.arraycopy(dx, dxIdx, dy, dyIdx, n);
203        } else {
204            for (int c = 0, xi = dxIdx, yi = dyIdx; c < n; xi += incx, yi += incy, c++) {
205                dy[yi] = dx[xi];
206            }
207        }
208    }
209
210    /** Compute dy <- da * dx + dy. */
211    public static void raxpy(int n, float da, float[] dx, int dxIdx, int incx, float[] dy, int dyIdx, int incy) {
212        if (dxIdx < 0 || dxIdx + (n - 1) * incx >= dx.length) {
213            throw new LapackException("Java.raxpy", "Parameters for x aren't valid! (n = " + n + ", dx.length = " + dx.length + ", dxIdx = " + dxIdx + ", incx = " + incx + ")");
214        }
215        if (dyIdx < 0 || dyIdx + (n - 1) * incy >= dy.length) {
216            throw new LapackException("Java.raxpy", "Parameters for y aren't valid! (n = " + n + ", dy.length = " + dy.length + ", dyIdx = " + dyIdx + ", incy = " + incy + ")");
217        }
218        
219        if (incx == 1 && incy == 1 && dxIdx == 0 && dyIdx == 0) {
220            if (da == 1.0f) {
221                for (int i = 0; i < n; i++) {
222                    dy[i] += dx[i];
223                }
224            } else {
225                for (int i = 0; i < n; i++) {
226                    dy[i] += da * dx[i];
227                }
228            }
229        } else {
230            if (da == 1.0f) {
231                for (int c = 0, xi = dxIdx, yi = dyIdx; c < n; c++, xi += incx, yi += incy) {
232                    dy[yi] += dx[xi];
233                }
234
235            } else {
236                for (int c = 0, xi = dxIdx, yi = dyIdx; c < n; c++, xi += incx, yi += incy) {
237                    dy[yi] += da * dx[xi];
238                }
239            }
240        }
241    }
242
243    /** Computes dz <- dx + dy */
244    public static void rzaxpy(int n, float[] dz, int dzIdx, int incz, float da, float[] dx, int dxIdx, int incx, float[] dy, int dyIdx, int incy) {
245        if (dxIdx == 0 && incx == 1 && dyIdx == 0 && incy == 1 && dzIdx == 0 && incz == 1) {
246            if (da == 1.0f) {
247                for (int c = 0; c < n; c++)
248                    dz[c] = dx[c] + dy[c];
249            } else {
250                for (int c = 0; c < n; c++)
251                    dz[c] = da*dx[c] + dy[c];
252            }
253        } else {
254            if (da == 1.0f) {
255                for (int c = 0, xi = dxIdx, yi = dyIdx, zi = dzIdx; c < n; c++, xi += incx, yi += incy, zi += incz) {
256                    dz[zi] = dx[xi] + dy[yi];
257                }
258            } else {
259                for (int c = 0, xi = dxIdx, yi = dyIdx, zi = dzIdx; c < n; c++, xi += incx, yi += incy, zi += incz) {
260                    dz[zi] = da*dx[xi] + dy[yi];
261                }
262            }
263        }
264    }
265
266    public static void rzgxpy(int n, float[] dz, float[] dx, float[] dy) {
267        for (int c = 0; c < n; c++)
268            dz[c] = dx[c] + dy[c];       
269    }
270
271    /** Compute scalar product between dx and dy. */
272    public static float rdot(int n, float[] dx, int dxIdx, int incx, float[] dy, int dyIdx, int incy) {
273        float s = 0.0f;
274        if (incx == 1 && incy == 1 && dxIdx == 0 && dyIdx == 0) {
275            for (int i = 0; i < n; i++)
276                s += dx[i] * dy[i];
277        }
278        else {
279            for (int c = 0, xi = dxIdx, yi = dyIdx; c < n; c++, xi += incx, yi += incy) {
280                s += dx[xi] * dy[yi];
281            }
282        }
283        return s;
284    }
285//END
286}