001/*-
002 * Copyright 2016 Diamond Light Source Ltd.
003 *
004 * All rights reserved. This program and the accompanying materials
005 * are made available under the terms of the Eclipse Public License v1.0
006 * which accompanies this distribution, and is available at
007 * http://www.eclipse.org/legal/epl-v10.html
008 */
009
010package org.eclipse.january.dataset;
011
012import java.util.List;
013
014/**
015 * Class to run over a pair of datasets in parallel with NumPy broadcasting of second dataset
016 */
017public class BroadcastSingleIterator extends BroadcastSelfIterator {
018        private int[] bShape;
019        private int[] aStride;
020        private int[] bStride;
021
022        final private int endrank;
023
024        private final int[] aDelta, bDelta;
025        private final int aStep, bStep;
026        private int aMax, bMax;
027        private int aStart, bStart;
028
029        /**
030         * 
031         * @param a
032         * @param b
033         */
034        public BroadcastSingleIterator(Dataset a, Dataset b) {
035                super(a, b);
036
037                int[] aShape = a.getShapeRef();
038                maxShape = aShape;
039                List<int[]> fullShapes = BroadcastUtils.broadcastShapesToMax(maxShape, b.getShapeRef());
040                bShape = fullShapes.remove(0);
041
042                int rank = maxShape.length;
043                endrank = rank - 1;
044
045                bDataset = b.reshape(bShape);
046                int[] aOffset = new int[1];
047                aStride = AbstractDataset.createStrides(aDataset, aOffset );
048                bStride = BroadcastUtils.createBroadcastStrides(bDataset, maxShape);
049
050                pos = new int[rank];
051                aDelta = new int[rank];
052                aStep = aDataset.getElementsPerItem();
053                bDelta = new int[rank];
054                bStep = bDataset.getElementsPerItem();
055                for (int j = endrank; j >= 0; j--) {
056                        aDelta[j] = aStride[j] * aShape[j];
057                        bDelta[j] = bStride[j] * bShape[j];
058                }
059                if (endrank < 0) {
060                        aMax = aStep;
061                        bMax = bStep;
062                } else {
063                        aMax = Integer.MIN_VALUE; // use max delta
064                        bMax = Integer.MIN_VALUE;
065                        for (int j = endrank; j >= 0; j--) {
066                                if (aDelta[j] > aMax) {
067                                        aMax = aDelta[j];
068                                }
069                                if (bDelta[j] > bMax) {
070                                        bMax = bDelta[j];
071                                }
072                        }
073                }
074                aStart = aOffset[0];
075                aMax += aStart;
076                bStart = bDataset.getOffset();
077                bMax += bStart;
078                reset();
079        }
080
081        @Override
082        public boolean hasNext() {
083                int j = endrank;
084                int oldB = bIndex;
085                for (; j >= 0; j--) {
086                        pos[j]++;
087                        aIndex += aStride[j];
088                        bIndex += bStride[j];
089                        if (pos[j] >= maxShape[j]) {
090                                pos[j] = 0;
091                                aIndex -= aDelta[j]; // reset these dimensions
092                                bIndex -= bDelta[j];
093                        } else {
094                                break;
095                        }
096                }
097                if (j == -1) {
098                        if (endrank >= 0) {
099                                aIndex = aMax;
100                                bIndex = bMax;
101                                return false;
102                        }
103                        aIndex += aStep;
104                        bIndex += bStep;
105                }
106
107                if (aIndex == aMax || bIndex == bMax)
108                        return false;
109
110                if (read) {
111                        if (oldB != bIndex) {
112                                if (asDouble) {
113                                        bDouble = bDataset.getElementDoubleAbs(bIndex);
114                                } else {
115                                        bLong = bDataset.getElementLongAbs(bIndex);
116                                }
117                        }
118                }
119
120                return true;
121        }
122
123        /**
124         * @return shape of first broadcasted dataset
125         */
126        public int[] getFirstShape() {
127                return maxShape;
128        }
129
130        /**
131         * @return shape of second broadcasted dataset
132         */
133        public int[] getSecondShape() {
134                return bShape;
135        }
136
137        @Override
138        public void reset() {
139                for (int i = 0; i <= endrank; i++)
140                        pos[i] = 0;
141
142                if (endrank >= 0) {
143                        pos[endrank] = -1;
144                        aIndex = aStart - aStride[endrank];
145                        bIndex = bStart - bStride[endrank];
146                } else {
147                        aIndex = aStart - aStep;
148                        bIndex = bStart - bStep;
149                }
150
151                if (aIndex == 0 || bIndex == 0) { // for zero-ranked datasets
152                        if (read) {
153                                storeCurrentValues();
154                        }
155                        if (aMax == aIndex)
156                                aMax++;
157                        if (bMax == bIndex)
158                                bMax++;
159                }
160        }
161}