/*-
 *******************************************************************************
 * Copyright (c) 2011, 2016 Diamond Light Source Ltd.
 * All rights reserved. This program and the accompanying materials
 * are made available under the terms of the Eclipse Public License v1.0
 * which accompanies this distribution, and is available at
 * http://www.eclipse.org/legal/epl-v10.html
 *
 * Contributors:
 *    Peter Chang - initial API and implementation and/or initial documentation
 *******************************************************************************/

package org.eclipse.january.dataset;

import java.util.Arrays;
import java.util.List;

/**
 * Class to run over a single dataset with NumPy broadcasting to promote shapes
 * which have lower rank and outputs to a second dataset.
 * <p>
 * Each successful call to {@link #hasNext()} advances {@link #aIndex} (and
 * {@link #oIndex}) to the next absolute index and refreshes either
 * {@link #aDouble} or {@link #aLong}, depending on {@link #isOutputDouble()}.
 */
public class SingleInputBroadcastIterator extends IndexIterator {
	/** Broadcast shape that the iteration runs over */
	private int[] maxShape;
	/** Shape of the input after being promoted to the broadcast rank */
	private int[] aShape;
	private final Dataset aDataset;
	private final Dataset oDataset;
	/** Per-dimension index increments used for the input when stepping over maxShape */
	private int[] aStride;
	/** Per-dimension index increments used for the output (null if there is no output) */
	private int[] oStride;

	/** Index of last dimension; -1 for a zero-rank iteration */
	private final int endrank;

	/**
	 * position in dataset
	 */
	private final int[] pos;
	/** Amount to subtract from aIndex when a dimension wraps back to zero */
	private final int[] aDelta;
	private final int[] oDelta; // this being non-null means output is different from inputs
	private final int aStep, oStep;
	/** Stop sentinel for zero-rank iteration; Integer.MIN_VALUE otherwise */
	private int aMax;
	private int aStart, oStart;
	/** True when the output dataset is the very same object as the input */
	private final boolean outputA;

	/**
	 * Index in array
	 */
	public int aIndex, oIndex;

	/**
	 * Current value in array
	 */
	public double aDouble;

	/**
	 * Current value in array
	 */
	public long aLong;

	private boolean asDouble = true;

	/**
	 * Create an iterator that does not allocate an output when none is given.
	 *
	 * @param a input dataset
	 * @param o output dataset (can be null for new dataset, or a)
	 */
	public SingleInputBroadcastIterator(Dataset a, Dataset o) {
		this(a, o, false);
	}

	/**
	 * Create an iterator that may allocate an output dataset.
	 *
	 * @param a input dataset
	 * @param o output dataset (can be null for new dataset, or a)
	 * @param createIfNull if true and o is null then create an output dataset
	 *        (by default, can create float or complex datasets)
	 */
	public SingleInputBroadcastIterator(Dataset a, Dataset o, boolean createIfNull) {
		this(a, o, createIfNull, false, true);
	}

	/**
	 * Create an iterator with full control over the type of any allocated output.
	 *
	 * @param a input dataset
	 * @param o output dataset (can be null for new dataset, or a)
	 * @param createIfNull if true and o is null then create an output dataset
	 * @param allowInteger if true, can create integer datasets
	 * @param allowComplex if true, can create complex datasets
	 * @throws IllegalArgumentException if o is given and its shape is not the broadcast shape
	 */
	public SingleInputBroadcastIterator(Dataset a, Dataset o, boolean createIfNull, boolean allowInteger, boolean allowComplex) {
		// first entry of the returned list is taken as the broadcast shape,
		// the second as the (rank-promoted) shape of a
		List<int[]> fullShapes = BroadcastUtils.broadcastShapes(a.getShapeRef(), o == null ? null : o.getShapeRef());

		BroadcastUtils.checkItemSize(a, o);

		maxShape = fullShapes.remove(0);

		oStride = null;
		if (o != null) {
			if (!Arrays.equals(maxShape, o.getShapeRef())) {
				throw new IllegalArgumentException("Output does not match broadcasted shape");
			}
			o.setDirty();
		}

		aShape = fullShapes.remove(0);

		int rank = maxShape.length;
		endrank = rank - 1;

		aDataset = a.reshape(aShape);
		aStride = BroadcastUtils.createBroadcastStrides(aDataset, maxShape);
		outputA = o == a;
		if (outputA) {
			// writing in place: output index simply mirrors the input index
			oStride = aStride;
			oDelta = null;
			oStep = 0;
			oDataset = aDataset;
		} else if (o != null) {
			oStride = BroadcastUtils.createBroadcastStrides(o, maxShape);
			oDelta = new int[rank];
			oStep = o.getElementsPerItem();
			oDataset = o;
		} else if (createIfNull) {
			int is = aDataset.getElementsPerItem();
			Class<? extends Dataset> dc = aDataset.getClass();
			if (aDataset.isComplex() && !allowComplex) {
				is = 1;
				dc = InterfaceUtils.getBestFloatInterface(dc);
			} else if (!aDataset.hasFloatingPointElements() && !allowInteger) {
				dc = InterfaceUtils.getBestFloatInterface(dc);
			}
			oDataset = DatasetFactory.zeros(is, dc, maxShape);
			oStride = BroadcastUtils.createBroadcastStrides(oDataset, maxShape);
			oDelta = new int[rank];
			oStep = oDataset.getElementsPerItem();
		} else {
			// no output at all: o is null here
			oDelta = null;
			oStep = 0;
			oDataset = o;
		}

		pos = new int[rank];
		aDelta = new int[rank];
		aStep = aDataset.getElementsPerItem();
		for (int j = endrank; j >= 0; j--) {
			aDelta[j] = aStride[j] * aShape[j];
			if (oDelta != null) {
				oDelta[j] = oStride[j] * maxShape[j];
			}
		}
		aStart = aDataset.getOffset();
		aMax = endrank < 0 ? aStep + aStart: Integer.MIN_VALUE;
		oStart = oDelta == null ? 0 : oDataset.getOffset();
		asDouble = aDataset.hasFloatingPointElements();
		reset();
	}

	/**
	 * @return true if output from iterator is double
	 */
	public boolean isOutputDouble() {
		return asDouble;
	}

	/**
	 * Set to output doubles
	 * @param asDouble if true, {@link #aDouble} is kept up to date, otherwise {@link #aLong} is
	 */
	public void setOutputDouble(boolean asDouble) {
		if (this.asDouble != asDouble) {
			this.asDouble = asDouble;
			// refresh the field that has just become the active one
			storeCurrentValues();
		}
	}

	/**
	 * @return broadcast shape being iterated over (internal array, not a copy)
	 */
	@Override
	public int[] getShape() {
		return maxShape;
	}

	/**
	 * Advance to the next position, updating {@link #aIndex}, {@link #oIndex} and
	 * the current value field.
	 *
	 * @return false once the iteration is exhausted
	 */
	@Override
	public boolean hasNext() {
		int j = endrank;
		int oldA = aIndex;
		// increment position from the last dimension, carrying into earlier
		// dimensions whenever one overflows
		for (; j >= 0; j--) {
			pos[j]++;
			aIndex += aStride[j];
			if (oDelta != null) {
				oIndex += oStride[j];
			}
			if (pos[j] >= maxShape[j]) {
				pos[j] = 0;
				aIndex -= aDelta[j]; // reset these dimensions
				if (oDelta != null) {
					oIndex -= oDelta[j];
				}
			} else {
				break;
			}
		}
		if (j == -1) {
			if (endrank >= 0) {
				// every dimension wrapped round so iteration is complete
				return false;
			}
			// zero-rank: step over the single item
			aIndex += aStep;
			if (oDelta != null) {
				oIndex += oStep;
			}
		}
		if (outputA) {
			oIndex = aIndex;
		}

		if (aIndex == aMax) {
			return false; // used for zero-rank datasets
		}

		// only fetch when the index moved; a broadcast dimension leaves it unchanged
		if (oldA != aIndex) {
			if (asDouble) {
				aDouble = aDataset.getElementDoubleAbs(aIndex);
			} else {
				aLong = aDataset.getElementLongAbs(aIndex);
			}
		}

		return true;
	}

	/**
	 * @return output dataset (can be null)
	 */
	public Dataset getOutput() {
		return oDataset;
	}

	/**
	 * @return current position (internal array, not a copy)
	 */
	@Override
	public int[] getPos() {
		return pos;
	}

	/**
	 * Rewind the iterator so that the next call to {@link #hasNext()} moves to
	 * the first item.
	 */
	@Override
	public void reset() {
		for (int i = 0; i <= endrank; i++) {
			pos[i] = 0;
		}

		final boolean prime;
		if (endrank >= 0) {
			// park one step before the first item in the last dimension
			pos[endrank] = -1;
			aIndex = aStart - aStride[endrank];
			oIndex = oStart - (oStride == null ? 0 : oStride[endrank]);
			// with a zero stride in the last dimension, the first hasNext()
			// leaves aIndex unchanged and so would not fetch a value
			prime = aStride[endrank] == 0;
		} else {
			// zero-rank: start one item before the offsets so that the first
			// hasNext() lands on them and the second one reaches aMax
			aIndex = aStart - aStep;
			oIndex = oStart - oStep;
			prime = false;
		}

		if (prime) {
			storeCurrentValues();
		}
	}

	/**
	 * Fetch the value at the current input index into the active value field;
	 * does nothing while the index is negative.
	 */
	private void storeCurrentValues() {
		if (aIndex >= 0) {
			if (asDouble) {
				aDouble = aDataset.getElementDoubleAbs(aIndex);
			} else {
				aLong = aDataset.getElementLongAbs(aIndex);
			}
		}
	}
}