001/*-
002 * Copyright 2015, 2016 Diamond Light Source Ltd.
003 *
004 * All rights reserved. This program and the accompanying materials
005 * are made available under the terms of the Eclipse Public License v1.0
006 * which accompanies this distribution, and is available at
007 * http://www.eclipse.org/legal/epl-v10.html
008 */
009
010package org.eclipse.january.dataset;
011
012import org.eclipse.january.DatasetException;
013
014/**
015 * A running mean class
016 */
017public class RunningAverage {
018        
019        private DoubleDataset average;
020        private DoubleDataset sqAveError;
021        private int count = 1;
022
023        /**
024         * @param dataset
025         */
026        public RunningAverage(IDataset dataset) {
027                average = (DoubleDataset) (dataset.getElementClass().equals(Double.class) ? DatasetUtils.convertToDataset(dataset).clone()
028                                : DatasetUtils.cast(dataset, Dataset.FLOAT64));
029
030                sqAveError = null;
031                Dataset eb = average.getErrorBuffer();
032                if (eb != null) {
033                        sqAveError = eb.getDType() != Dataset.FLOAT64 ? (DoubleDataset) DatasetUtils.cast(eb, Dataset.FLOAT64) :
034                                (DoubleDataset) eb;
035                }
036        }
037
038        /**
039         * Update average
040         * @param dataset
041         */
042        public void update(IDataset dataset) {
043                count++;
044                IndexIterator it = average.getIterator(true);
045                int[] pos = it.getPos();
046                double f = 1. / count;
047                if (sqAveError == null) {
048                        while (it.hasNext()) {
049                                double m = average.getAbs(it.index);
050                                double v = f * (dataset.getDouble(pos) - m);
051                                average.setAbs(it.index, m + v);
052                        }
053                } else {
054                        double fs = f * f;
055                        double gs = 2 * count - 1;
056                        if (dataset instanceof Dataset) {
057                                final Dataset d = (Dataset) dataset;
058                                final Dataset e = d.getErrorBuffer();
059                                while (it.hasNext()) {
060                                        double m = average.getAbs(it.index);
061                                        double v = f * (d.getDouble(pos) - m);
062                                        average.setAbs(it.index, m + v);
063
064                                        if (e != null) {
065                                                m = sqAveError.getDouble(pos);
066                                                v = fs * (e.getDouble(pos) - gs * m);
067                                                sqAveError.setItem(m + v, pos);
068                                        }
069                                }
070                        } else { // only linear error available
071                                ILazyDataset le = dataset.getErrors();
072                                IDataset e = null;
073                                if (le instanceof IDataset) {
074                                        e = (IDataset) le;
075                                } else if (le != null) {
076                                        try {
077                                                e = le.getSlice();
078                                        } catch (DatasetException e1) {
079                                        }
080                                }
081                                while (it.hasNext()) {
082                                        double m = average.getAbs(it.index);
083                                        double v = f * (dataset.getDouble(pos) - m);
084                                        average.setAbs(it.index, m + v);
085
086                                        if (e != null) {
087                                                m = sqAveError.getDouble(pos);
088                                                v = e.getDouble(pos);
089                                                v = fs * (v * v - gs * m);
090                                                sqAveError.setItem(m + v, pos);
091                                        }
092                                }
093                        }
094                }
095        }
096
097        /**
098         * @return count
099         */
100        public int getCount() {
101                return count;
102        }
103
104        /**
105         * @return current average
106         */
107        public Dataset getCurrentAverage() {
108                if (sqAveError != null) {
109                        Dataset e = sqAveError.clone();
110                        DatasetUtils.makeFinite(e);
111                        average.setErrorBuffer(e);
112                }
113
114                return average;
115        }
116}