001/*-
002 * Copyright 2016 Diamond Light Source Ltd.
003 *
004 * All rights reserved. This program and the accompanying materials
005 * are made available under the terms of the Eclipse Public License v1.0
006 * which accompanies this distribution, and is available at
007 * http://www.eclipse.org/legal/epl-v10.html
008 */
009
010package org.eclipse.january.dataset;
011
012import java.util.ArrayList;
013import java.util.Arrays;
014import java.util.List;
015
016public final class BroadcastUtils {
017
018        /**
019         * Calculate shapes for broadcasting
020         * @param oldShape
021         * @param size
022         * @param newShape
023         * @return broadcasted shape and full new shape or null if it cannot be done
024         */
025        public static int[][] calculateBroadcastShapes(int[] oldShape, int size, int... newShape) {
026                if (newShape == null)
027                        return null;
028        
029                int brank = newShape.length;
030                if (brank == 0) {
031                        if (size == 1)
032                                return new int[][] {oldShape, newShape};
033                        return null;
034                }
035        
036                if (Arrays.equals(oldShape, newShape))
037                        return new int[][] {oldShape, newShape};
038        
039                int offset = brank - oldShape.length;
040                if (offset < 0) { // when new shape is incomplete
041                        newShape = padShape(newShape, -offset);
042                        offset = 0;
043                }
044        
045                int[] bshape;
046                if (offset > 0) { // new shape has extra dimensions
047                        bshape = padShape(oldShape, offset);
048                } else {
049                        bshape = oldShape;
050                }
051        
052                for (int i = 0; i < brank; i++) {
053                        if (newShape[i] != bshape[i] && bshape[i] != 1 && newShape[i] != 1) {
054                                return null;
055                        }
056                }
057        
058                return new int[][] {bshape, newShape};
059        }
060
061        /**
062         * Pad shape by prefixing with ones
063         * @param shape
064         * @param padding
065         * @return new shape or old shape if padding is zero
066         */
067        public static int[] padShape(final int[] shape, final int padding) {
068                if (padding < 0)
069                        throw new IllegalArgumentException("Padding must be zero or greater");
070        
071                if (padding == 0)
072                        return shape;
073        
074                final int[] nshape = new int[shape.length + padding];
075                Arrays.fill(nshape, 1);
076                System.arraycopy(shape, 0, nshape, padding, shape.length);
077                return nshape;
078        }
079
080        /**
081         * Take in shapes and broadcast them to same rank
082         * @param shapes
083         * @return list of broadcasted shapes plus the first entry is the maximum shape
084         */
085        public static List<int[]> broadcastShapes(int[]... shapes) {
086                int maxRank = -1;
087                for (int[] s : shapes) {
088                        if (s == null)
089                                continue;
090        
091                        int r = s.length;
092                        if (r > maxRank) {
093                                maxRank = r;
094                        }
095                }
096        
097                List<int[]> newShapes = new ArrayList<int[]>();
098                for (int[] s : shapes) {
099                        if (s == null)
100                                continue;
101                        newShapes.add(padShape(s, maxRank - s.length));
102                }
103        
104                int[] maxShape = new int[maxRank];
105                for (int i = 0; i < maxRank; i++) {
106                        int m = -1;
107                        for (int[] s : newShapes) {
108                                int l = s[i];
109                                if (l > m) {
110                                        if (m > 1) {
111                                                throw new IllegalArgumentException("A shape's dimension was not one or equal to maximum");
112                                        }
113                                        m = l;
114                                }
115                        }
116                        maxShape[i] = m;
117                }
118
119                checkShapes(maxShape, newShapes);
120                newShapes.add(0, maxShape);
121                return newShapes;
122        }
123
124        /**
125         * Take in shapes and broadcast them to maximum shape
126         * @param maxShape
127         * @param shapes
128         * @return list of broadcasted shapes
129         */
130        public static List<int[]> broadcastShapesToMax(int[] maxShape, int[]... shapes) {
131                int maxRank = maxShape.length;
132                for (int[] s : shapes) {
133                        if (s == null)
134                                continue;
135        
136                        int r = s.length;
137                        if (r > maxRank) {
138                                throw new IllegalArgumentException("A shape exceeds given rank of maximum shape");
139                        }
140                }
141        
142                List<int[]> newShapes = new ArrayList<int[]>();
143                for (int[] s : shapes) {
144                        if (s == null)
145                                continue;
146                        newShapes.add(padShape(s, maxRank - s.length));
147                }
148
149                checkShapes(maxShape, newShapes);
150                return newShapes;
151        }
152
153        private static void checkShapes(int[] maxShape, List<int[]> newShapes) {
154                for (int i = 0; i < maxShape.length; i++) {
155                        int m = maxShape[i];
156                        for (int[] s : newShapes) {
157                                int l = s[i];
158                                if (l != 1 && l != m) {
159                                        throw new IllegalArgumentException("A shape's dimension was not one or equal to maximum");
160                                }
161                        }
162                }
163        }
164
165        @SuppressWarnings("deprecation")
166        static Dataset createDataset(final Dataset a, final Dataset b, final int[] shape) {
167                final int rt;
168                final int ar = a.getRank();
169                final int br = b.getRank();
170                final int tt = DTypeUtils.getBestDType(a.getDType(), b.getDType());
171                if (ar == 0 ^ br == 0) { // ignore type of zero-rank dataset unless it's floating point 
172                        if (ar == 0) {
173                                rt = a.hasFloatingPointElements() ? tt : b.getDType();
174                        } else {
175                                rt = b.hasFloatingPointElements() ? tt : a.getDType();
176                        }
177                } else {
178                        rt = tt;
179                }
180                final int ia = a.getElementsPerItem();
181                final int ib = b.getElementsPerItem();
182        
183                return DatasetFactory.zeros(ia > ib ? ia : ib, shape, rt);
184        }
185
186        /**
187         * Check if dataset item sizes are compatible
188         * <p>
189         * Dataset a is considered compatible with the output dataset if any of the
190         * conditions are true:
191         * <ul>
192         * <li>o is undefined</li>
193         * <li>a has item size equal to o's</li>
194         * <li>a has item size equal to 1</li>
195         * <li>o has item size equal to 1</li>
196         * </ul>
197         * @param a input dataset a
198         * @param o output dataset (can be null)
199         */
200        static void checkItemSize(Dataset a, Dataset o) {
201                final int isa = a.getElementsPerItem();
202                if (o != null) {
203                        final int iso = o.getElementsPerItem();
204                        if (isa != iso && isa != 1 && iso != 1) {
205                                throw new IllegalArgumentException("Can not output to dataset whose number of elements per item mismatch inputs'");
206                        }
207                }
208        }
209
210        /**
211         * Check if dataset item sizes are compatible
212         * <p>
213         * Dataset a is considered compatible with the output dataset if any of the
214         * conditions are true:
215         * <ul>
216         * <li>a has item size equal to b's</li>
217         * <li>a has item size equal to 1</li>
218         * <li>b has item size equal to 1</li>
219         * <li>a or b are single-valued</li>
220         * </ul>
221         * and, o is undefined, or any of the following are true:
222         * <ul>
223         * <li>o has item size equal to maximum of a and b's</li>
224         * <li>o has item size equal to 1</li>
225         * <li>a and b have item sizes of 1</li>
226         * </ul>
227         * @param a input dataset a
228         * @param b input dataset b
229         * @param o output dataset
230         */
231        static void checkItemSize(Dataset a, Dataset b, Dataset o) {
232                final int isa = a.getElementsPerItem();
233                final int isb = b.getElementsPerItem();
234                if (isa != isb && isa != 1 && isb != 1) {
235                        // exempt single-value dataset case too
236                        if ((isa == 1 || b.getSize() != 1) && (isb == 1 || a.getSize() != 1) ) {
237                                throw new IllegalArgumentException("Can not broadcast where number of elements per item mismatch and one does not equal another");
238                        }
239                }
240                if (o != null && o.getDType() != Dataset.BOOL) {
241                        final int ism = Math.max(isa, isb);
242                        final int iso = o.getElementsPerItem();
243                        if (iso != ism && iso != 1 && ism != 1) {
244                                throw new IllegalArgumentException("Can not output to dataset whose number of elements per item mismatch inputs'");
245                        }
246                }
247        }
248
249        /**
250         * Create a stride array from a dataset to a broadcast shape
251         * @param a dataset
252         * @param broadcastShape
253         * @return stride array
254         */
255        public static int[] createBroadcastStrides(Dataset a, final int[] broadcastShape) {
256                return createBroadcastStrides(a.getElementsPerItem(), a.getShapeRef(), a.getStrides(), broadcastShape);
257        }
258
259        /**
260         * Create a stride array from a dataset to a broadcast shape
261         * @param isize
262         * @param oShape original shape
263         * @param oStride original stride
264         * @param broadcastShape
265         * @return stride array
266         */
267        public static int[] createBroadcastStrides(final int isize, final int[] oShape, final int[] oStride, final int[] broadcastShape) {
268                int rank = oShape.length;
269                if (broadcastShape.length != rank) {
270                        throw new IllegalArgumentException("Dataset must have same rank as broadcast shape");
271                }
272        
273                int[] stride = new int[rank];
274                if (oStride == null) {
275                        int s = isize;
276                        for (int j = rank - 1; j >= 0; j--) {
277                                if (broadcastShape[j] == oShape[j]) {
278                                        stride[j] = s;
279                                        s *= oShape[j];
280                                } else {
281                                        stride[j] = 0;
282                                }
283                        }
284                } else {
285                        for (int j = 0; j < rank; j++) {
286                                if (broadcastShape[j] == oShape[j]) {
287                                        stride[j] = oStride[j];
288                                } else {
289                                        stride[j] = 0;
290                                }
291                        }
292                }
293        
294                return stride;
295        }
296
297        /**
298         * Converts and broadcast all objects as datasets of same shape
299         * @param objects
300         * @return all as broadcasted to same shape
301         */
302        public static Dataset[] convertAndBroadcast(Object... objects) {
303                final int n = objects.length;
304
305                Dataset[] datasets = new Dataset[n];
306                int[][] shapes = new int[n][];
307                for (int i = 0; i < n; i++) {
308                        Dataset d = DatasetFactory.createFromObject(objects[i]);
309                        datasets[i] = d;
310                        shapes[i] = d.getShapeRef();
311                }
312
313                List<int[]> nShapes = BroadcastUtils.broadcastShapes(shapes);
314                int[] mshape = nShapes.get(0);
315                for (int i = 0; i < n; i++) {
316                        datasets[i] = datasets[i].getBroadcastView(mshape);
317                }
318
319                return datasets;
320        }
321}