001/*- 002 ******************************************************************************* 003 * Copyright (c) 2011, 2016 Diamond Light Source Ltd. 004 * All rights reserved. This program and the accompanying materials 005 * are made available under the terms of the Eclipse Public License v1.0 006 * which accompanies this distribution, and is available at 007 * http://www.eclipse.org/legal/epl-v10.html 008 * 009 * Contributors: 010 * Peter Chang - initial API and implementation and/or initial documentation 011 *******************************************************************************/ 012 013package org.eclipse.january.dataset; 014 015import java.util.Arrays; 016import java.util.List; 017 018/** 019 * Class to run over a single dataset with NumPy broadcasting to promote shapes 020 * which have lower rank and outputs to a second dataset 021 */ 022public class SingleInputBroadcastIterator extends IndexIterator { 023 private int[] maxShape; 024 private int[] aShape; 025 private final Dataset aDataset; 026 private final Dataset oDataset; 027 private int[] aStride; 028 private int[] oStride; 029 030 final private int endrank; 031 032 /** 033 * position in dataset 034 */ 035 private final int[] pos; 036 private final int[] aDelta; 037 private final int[] oDelta; // this being non-null means output is different from inputs 038 private final int aStep, oStep; 039 private int aMax; 040 private int aStart, oStart; 041 private final boolean outputA; 042 043 /** 044 * Index in array 045 */ 046 public int aIndex, oIndex; 047 048 /** 049 * Current value in array 050 */ 051 public double aDouble; 052 053 /** 054 * Current value in array 055 */ 056 public long aLong; 057 058 private boolean asDouble = true; 059 060 /** 061 * @param a 062 * @param o (can be null for new dataset, or a) 063 */ 064 public SingleInputBroadcastIterator(Dataset a, Dataset o) { 065 this(a, o, false); 066 } 067 068 /** 069 * @param a 070 * @param o (can be null for new dataset, or a) 071 * @param createIfNull (by default, can create float or complex datasets) 072 */ 073 public SingleInputBroadcastIterator(Dataset a, Dataset o, boolean createIfNull) { 074 this(a, o, createIfNull, false, true); 075 } 076 077 /** 078 * @param a 079 * @param o (can be null for new dataset, or a) 080 * @param createIfNull 081 * @param allowInteger if true, can create integer datasets 082 * @param allowComplex if true, can create complex datasets 083 */ 084 @SuppressWarnings("deprecation") 085 public SingleInputBroadcastIterator(Dataset a, Dataset o, boolean createIfNull, boolean allowInteger, boolean allowComplex) { 086 List<int[]> fullShapes = BroadcastUtils.broadcastShapes(a.getShapeRef(), o == null ? null : o.getShapeRef()); 087 088 BroadcastUtils.checkItemSize(a, o); 089 090 maxShape = fullShapes.remove(0); 091 092 oStride = null; 093 if (o != null) { 094 if (!Arrays.equals(maxShape, o.getShapeRef())) { 095 throw new IllegalArgumentException("Output does not match broadcasted shape"); 096 } 097 o.setDirty(); 098 } 099 100 aShape = fullShapes.remove(0); 101 102 int rank = maxShape.length; 103 endrank = rank - 1; 104 105 aDataset = a.reshape(aShape); 106 aStride = BroadcastUtils.createBroadcastStrides(aDataset, maxShape); 107 outputA = o == a; 108 if (outputA) { 109 oStride = aStride; 110 oDelta = null; 111 oStep = 0; 112 oDataset = aDataset; 113 } else if (o != null) { 114 oStride = BroadcastUtils.createBroadcastStrides(o, maxShape); 115 oDelta = new int[rank]; 116 oStep = o.getElementsPerItem(); 117 oDataset = o; 118 } else if (createIfNull) { 119 int is = aDataset.getElementsPerItem(); 120 int dt = aDataset.getDType(); 121 if (aDataset.isComplex() && !allowComplex) { 122 is = 1; 123 dt = DTypeUtils.getBestFloatDType(dt); 124 } else if (!aDataset.hasFloatingPointElements() && !allowInteger) { 125 dt = DTypeUtils.getBestFloatDType(dt); 126 } 127 oDataset = DatasetFactory.zeros(is, maxShape, dt); 128 oStride = BroadcastUtils.createBroadcastStrides(oDataset, maxShape); 129 oDelta = new int[rank]; 130 oStep = oDataset.getElementsPerItem(); 131 } else { 132 oDelta = null; 133 oStep = 0; 134 oDataset = o; 135 } 136 137 pos = new int[rank]; 138 aDelta = new int[rank]; 139 aStep = aDataset.getElementsPerItem(); 140 for (int j = endrank; j >= 0; j--) { 141 aDelta[j] = aStride[j] * aShape[j]; 142 if (oDelta != null) { 143 oDelta[j] = oStride[j] * maxShape[j]; 144 } 145 } 146 aStart = aDataset.getOffset(); 147 aMax = endrank < 0 ? aStep + aStart: Integer.MIN_VALUE; 148 oStart = oDelta == null ? 0 : oDataset.getOffset(); 149 asDouble = aDataset.hasFloatingPointElements(); 150 reset(); 151 } 152 153 /** 154 * @return true if output from iterator is double 155 */ 156 public boolean isOutputDouble() { 157 return asDouble; 158 } 159 160 /** 161 * Set to output doubles 162 * @param asDouble 163 */ 164 public void setOutputDouble(boolean asDouble) { 165 if (this.asDouble != asDouble) { 166 this.asDouble = asDouble; 167 storeCurrentValues(); 168 } 169 } 170 171 @Override 172 public int[] getShape() { 173 return maxShape; 174 } 175 176 @Override 177 public boolean hasNext() { 178 int j = endrank; 179 int oldA = aIndex; 180 for (; j >= 0; j--) { 181 pos[j]++; 182 aIndex += aStride[j]; 183 if (oDelta != null) { 184 oIndex += oStride[j]; 185 } 186 if (pos[j] >= maxShape[j]) { 187 pos[j] = 0; 188 aIndex -= aDelta[j]; // reset these dimensions 189 if (oDelta != null) { 190 oIndex -= oDelta[j]; 191 } 192 } else { 193 break; 194 } 195 } 196 if (j == -1) { 197 if (endrank >= 0) { 198 return false; 199 } 200 aIndex += aStep; 201 if (oDelta != null) { 202 oIndex += oStep; 203 } 204 } 205 if (outputA) { 206 oIndex = aIndex; 207 } 208 209 if (aIndex == aMax) { 210 return false; // used for zero-rank datasets 211 } 212 213 if (oldA != aIndex) { 214 if (asDouble) { 215 aDouble = aDataset.getElementDoubleAbs(aIndex); 216 } else { 217 aLong = aDataset.getElementLongAbs(aIndex); 218 } 219 } 220 221 return true; 222 } 223 224 /** 225 * @return output dataset (can be null) 226 */ 227 public Dataset getOutput() { 228 return oDataset; 229 } 230 231 @Override 232 public int[] getPos() { 233 return pos; 234 } 235 236 @Override 237 public void reset() { 238 for (int i = 0; i <= endrank; i++) { 239 pos[i] = 0; 240 } 241 242 if (endrank >= 0) { 243 pos[endrank] = -1; 244 aIndex = aStart - aStride[endrank]; 245 oIndex = oStart - (oStride == null ? 0 : oStride[endrank]); 246 } else { 247 aIndex = -aStep; 248 oIndex = -oStep; 249 } 250 251 // for zero-ranked datasets 252 if (aIndex == 0) { 253 storeCurrentValues(); 254 } 255 } 256 257 private void storeCurrentValues() { 258 if (aIndex >= 0) { 259 if (asDouble) { 260 aDouble = aDataset.getElementDoubleAbs(aIndex); 261 } else { 262 aLong = aDataset.getElementLongAbs(aIndex); 263 } 264 } 265 } 266}