/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ctakes.relationextractor.eval;

import com.google.common.base.Function;
import com.google.common.base.Functions;
import com.google.common.collect.Lists;
import com.google.common.collect.Ordering;
import com.lexicalscope.jewel.cli.Option;
import java.io.File;
import java.io.FileOutputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.regex.Pattern;
import org.apache.ctakes.core.ae.SHARPKnowtatorXMLReader;
import org.apache.ctakes.core.pipeline.PipeBitInfo;
import org.apache.ctakes.core.util.DocumentIDAnnotationUtil;
import org.apache.ctakes.relationextractor.eval.ParameterSettings;
import org.apache.ctakes.relationextractor.eval.XMIReader;
import org.apache.ctakes.typesystem.type.structured.DocumentID;
import org.apache.uima.UIMAFramework;
import org.apache.uima.analysis_engine.AnalysisEngine;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.cas.CASException;
import org.apache.uima.cas.impl.XmiCasSerializer;
import org.apache.uima.collection.CollectionReader;
import org.apache.uima.fit.component.JCasAnnotator_ImplBase;
import org.apache.uima.fit.component.ViewCreatorAnnotator;
import org.apache.uima.fit.factory.AggregateBuilder;
import org.apache.uima.fit.factory.AnalysisEngineFactory;
import org.apache.uima.fit.factory.CollectionReaderFactory;
import org.apache.uima.fit.factory.TypeSystemDescriptionFactory;
import org.apache.uima.fit.pipeline.JCasIterator;
import org.apache.uima.jcas.JCas;
import org.apache.uima.resource.metadata.TypeSystemDescription;
import org.apache.uima.util.XMLInputSource;
import org.apache.uima.util.XMLParser;
import org.apache.uima.util.XMLSerializer;
import org.cleartk.eval.AnnotationStatistics;
import org.cleartk.util.ViewUriUtil;
import org.cleartk.util.ae.UriToDocumentTextAnnotator;
import org.cleartk.util.cr.UriCollectionReader;
import org.xml.sax.ContentHandler;

public class SHARPXMI {
    /**
     * Name of the CAS view used for gold annotations. {@code generateXMI} creates a view
     * with this name, copies the document text into it, and adds the document-ID and
     * Knowtator reader components with a view mapping from {@code _InitialView} to it.
     */
    public static final String GOLD_VIEW_NAME = "GoldView";

    /**
     * Collects the text files of the batches that form the training split.
     *
     * <p>A batch directory belongs to the split when its whole name matches the pattern
     * below, e.g. {@code ss1_batch02} or {@code ss3_batch11}.
     *
     * @param batchesDirectory directory whose sub-directories are the batch directories
     * @return the text files found under {@code Knowtator/text} of each selected batch
     */
    public static List<File> getTrainTextFiles(File batchesDirectory) {
        Pattern trainBatchNames = Pattern.compile("^(ss[1234]_batch0[2-9]|ss[1234]_batch1[56]|ss[1234]_batch1[89]|ss[123]_batch01|ss[12]_batch1[34]|ss[34]_batch1[12])$");
        return SHARPXMI.getTextFilesFor(batchesDirectory, trainBatchNames);
    }

    /**
     * Collects the text files of the batches that form the development split, i.e. the
     * directories named {@code ssN_batch10} or {@code ssN_batch17} for N in 1..4.
     *
     * @param batchesDirectory directory whose sub-directories are the batch directories
     * @return the text files found under {@code Knowtator/text} of each selected batch
     */
    public static List<File> getDevTextFiles(File batchesDirectory) {
        Pattern devBatchNames = Pattern.compile("^(ss[1234]_batch1[07])$");
        return SHARPXMI.getTextFilesFor(batchesDirectory, devBatchNames);
    }

    /**
     * Collects the text files of the batches that form the test split: batches 11 and 12
     * of {@code ss1}/{@code ss2}, and batches 13 and 14 of {@code ss3}/{@code ss4}.
     *
     * @param batchesDirectory directory whose sub-directories are the batch directories
     * @return the text files found under {@code Knowtator/text} of each selected batch
     */
    public static List<File> getTestTextFiles(File batchesDirectory) {
        Pattern testBatchNames = Pattern.compile("^(ss[12]_batch1[12]|ss[34]_batch1[34])$");
        return SHARPXMI.getTextFilesFor(batchesDirectory, testBatchNames);
    }

    /**
     * Collects the text files of every batch directory, regardless of its name.
     *
     * @param batchesDirectory directory whose sub-directories are the batch directories
     * @return the text files found under {@code Knowtator/text} of each batch
     */
    public static List<File> getAllTextFiles(File batchesDirectory) {
        // The empty pattern is found in any string, so no directory is filtered out by name.
        Pattern anyBatchName = Pattern.compile("");
        return SHARPXMI.getTextFilesFor(batchesDirectory, anyBatchName);
    }

    /**
     * Gathers the text files of every batch directory whose name matches a pattern.
     *
     * <p>Only non-hidden sub-directories of {@code batchesDirectory} are considered, and a
     * sub-directory is selected when {@code pattern} is found in its name. For each
     * selected batch, the non-hidden regular files directly inside its
     * {@code Knowtator/text} directory are collected.
     *
     * @param batchesDirectory directory containing the batch directories
     * @param pattern regular expression searched for in each batch directory name
     * @return a new, mutable list of the matching text files, in the order in which the
     *         file system listed them
     * @throws IllegalArgumentException if {@code batchesDirectory}, or the
     *         {@code Knowtator/text} directory of a selected batch, cannot be listed
     */
    private static List<File> getTextFilesFor(File batchesDirectory, Pattern pattern) {
        // File.listFiles() signals "not a directory" / "I/O error" by returning null, so
        // check explicitly and report the offending path instead of failing with an NPE.
        File[] batchDirs = batchesDirectory.listFiles();
        if (batchDirs == null) {
            throw new IllegalArgumentException("Cannot list batches directory: " + batchesDirectory);
        }
        List<File> files = new ArrayList<File>();
        for (File batchDir : batchDirs) {
            if (!batchDir.isDirectory() || batchDir.isHidden() || !pattern.matcher(batchDir.getName()).find()) {
                continue;
            }
            File textDirectory = new File(batchDir, "Knowtator/text");
            File[] textFiles = textDirectory.listFiles();
            if (textFiles == null) {
                throw new IllegalArgumentException("Cannot list text directory of batch: " + textDirectory);
            }
            for (File textFile : textFiles) {
                if (textFile.isFile() && !textFile.isHidden()) {
                    files.add(textFile);
                }
            }
        }
        return files;
    }

    /**
     * Maps each text file to the XMI file that holds (or will hold) its serialized CAS.
     *
     * @param options supplies the XMI directory
     * @param textFiles the text files to map
     * @return a new, mutable list with one XMI file per text file, in the same order
     */
    public static List<File> toXMIFiles(Options options, List<File> textFiles) {
        List<File> mapped = new ArrayList<File>(textFiles.size());
        for (File source : textFiles) {
            File xmi = SHARPXMI.toXMIFile(options, source);
            mapped.add(xmi);
        }
        return mapped;
    }

    /**
     * Builds the XMI file location for a text file: the file's simple name with
     * {@code .xmi} appended, placed in the XMI directory from {@code options}.
     *
     * @param options supplies the XMI directory
     * @param textFile the text file (only its name is used)
     * @return the corresponding file inside the XMI directory
     */
    private static File toXMIFile(Options options, File textFile) {
        File xmiDirectory = options.getXMIDirectory();
        String xmiName = textFile.getName() + ".xmi";
        return new File(xmiDirectory, xmiName);
    }

    /**
     * Runs the preprocessing pipeline over the train, dev and test text files and writes
     * each resulting CAS as an XMI file into the XMI directory.
     *
     * <p>Does nothing unless {@code options.getGenerateXMI()} is set. The pipeline loads
     * the document text, applies the preprocessor described by
     * {@code desc/analysis_engine/RelationExtractorPreprocessor.xml}, creates the gold
     * view, copies the text into it, and then runs the document-ID annotator and the
     * Knowtator XML reader with {@code _InitialView} mapped to the gold view.
     *
     * @param options batches directory, XMI directory and the generate-xmi switch
     * @throws IllegalStateException if the XMI directory does not exist and cannot be created
     * @throws IllegalArgumentException if a processed CAS carries no document ID in its gold view
     * @throws Exception if building or running the pipeline, or writing an XMI file, fails
     */
    public static void generateXMI(Options options) throws Exception {
        if (!options.getGenerateXMI()) {
            return;
        }

        File xmiDirectory = options.getXMIDirectory();
        // mkdirs() reports failure only through its return value; re-check isDirectory()
        // so that a directory created concurrently is not treated as an error.
        if (!xmiDirectory.exists() && !xmiDirectory.mkdirs() && !xmiDirectory.isDirectory()) {
            throw new IllegalStateException("Cannot create XMI directory: " + xmiDirectory);
        }

        File batchesDirectory = options.getBatchesDirectory();
        List<File> files = new ArrayList<File>();
        files.addAll(SHARPXMI.getTrainTextFiles(batchesDirectory));
        files.addAll(SHARPXMI.getDevTextFiles(batchesDirectory));
        files.addAll(SHARPXMI.getTestTextFiles(batchesDirectory));
        CollectionReader reader = UriCollectionReader.getCollectionReaderFromFiles(files);

        AggregateBuilder builder = new AggregateBuilder();
        builder.add(UriToDocumentTextAnnotator.getDescription());
        File preprocessDescFile = new File("desc/analysis_engine/RelationExtractorPreprocessor.xml");
        XMLParser parser = UIMAFramework.getXMLParser();
        XMLInputSource source = new XMLInputSource(preprocessDescFile);
        builder.add(parser.parseAnalysisEngineDescription(source));
        builder.add(AnalysisEngineFactory.createEngineDescription(ViewCreatorAnnotator.class, "viewName", GOLD_VIEW_NAME));
        builder.add(AnalysisEngineFactory.createEngineDescription(CopyDocumentTextToGoldView.class));
        // The trailing names are a view mapping: these components see the gold view as
        // their default view.
        builder.add(AnalysisEngineFactory.createEngineDescription(DocumentIDAnnotator.class), "_InitialView", GOLD_VIEW_NAME);
        builder.add(AnalysisEngineFactory.createEngineDescription(SHARPKnowtatorXMLReader.class, "SetDefaults", true), "_InitialView", GOLD_VIEW_NAME);

        JCasIterator casIter = new JCasIterator(reader, builder.createAggregate());
        while (casIter.hasNext()) {
            JCas jCas = casIter.next();
            JCas goldView = jCas.getView(GOLD_VIEW_NAME);
            String documentID = DocumentIDAnnotationUtil.getDocumentID(goldView);
            if (documentID == null) {
                throw new IllegalArgumentException("No documentID for CAS:\n" + jCas);
            }
            // The document ID is the source file path; only its name is used for the XMI file.
            File outFile = SHARPXMI.toXMIFile(options, new File(documentID));
            OutputStream stream = new FileOutputStream(outFile);
            try {
                ContentHandler handler = new XMLSerializer(stream).getContentHandler();
                new XmiCasSerializer(jCas.getTypeSystem()).serialize(jCas.getCas(), handler);
            } finally {
                // Close even when serialization fails so the file handle is not leaked.
                stream.close();
            }
        }
    }

    /**
     * Rejects option combinations that must not be run: a grid search while evaluating
     * on the test set.
     *
     * @param options the parsed evaluation options
     * @throws IllegalArgumentException if grid search is requested together with
     *         {@code EvaluateOn.TEST}
     */
    public static void validate(EvaluationOptions options) throws Exception {
        if (options.getEvaluteOn().equals(EvaluateOn.TEST)) {
            if (options.getGridSearch()) {
                throw new IllegalArgumentException("grid search can only be run on the train or dev sets");
            }
        }
    }

    /**
     * Runs an evaluation for each candidate parameter setting and reports the results.
     *
     * <p>The candidates are {@code gridOfSettings} when grid search is enabled, otherwise
     * only {@code bestSettings}. For each candidate an evaluation is obtained from
     * {@code getEvaluation} and run on the data selected by
     * {@code options.getEvaluteOn()}:
     * <ul>
     * <li>{@code TRAIN}: two-fold cross-validation over the training files;</li>
     * <li>{@code DEV}: train on the training files, test on the development files;</li>
     * <li>{@code TEST}: train on training + development files, test on the test files;</li>
     * <li>{@code OTHER}: train on training + development files plus the entries of the
     *     train-xmi-dir (if given), test on the entries of the test-xmi-dir.</li>
     * </ul>
     * The resulting statistics are stored in {@code params.stats}. Afterwards a summary
     * sorted by ascending F1 is printed to standard error (only when more than one setting
     * was scored), followed by the details of the highest-scoring setting.
     *
     * @param options evaluation options (data split, grid search, directories)
     * @param bestSettings the single setting used when grid search is disabled
     * @param gridOfSettings the settings tried when grid search is enabled
     * @param getEvaluation creates the evaluation for a given setting
     * @throws IllegalArgumentException if the evaluate-on value is not handled, or if a
     *         train/test XMI directory needed for {@code OTHER} cannot be listed
     * @throws Exception if training or testing fails
     */
    public static <T extends Evaluation_ImplBase> void evaluate(EvaluationOptions options, ParameterSettings bestSettings, List<ParameterSettings> gridOfSettings, Function<ParameterSettings, T> getEvaluation) throws Exception {
        List<ParameterSettings> possibleParams = options.getGridSearch()
                ? gridOfSettings
                : Collections.singletonList(bestSettings);

        // Run one evaluation per candidate setting and remember its F1.
        HashMap<ParameterSettings, Double> scoredParams = new HashMap<ParameterSettings, Double>();
        for (ParameterSettings params : possibleParams) {
            T evaluation = getEvaluation.apply(params);
            File batchesDirectory = options.getBatchesDirectory();
            switch (options.getEvaluteOn()) {
                case TRAIN: {
                    List<File> trainFiles = SHARPXMI.toXMIFiles(options, SHARPXMI.getTrainTextFiles(batchesDirectory));
                    List<AnnotationStatistics<String>> foldStats = evaluation.crossValidation(trainFiles, 2);
                    params.stats = AnnotationStatistics.addAll(foldStats);
                    break;
                }
                case DEV: {
                    List<File> trainFiles = SHARPXMI.toXMIFiles(options, SHARPXMI.getTrainTextFiles(batchesDirectory));
                    List<File> devFiles = SHARPXMI.toXMIFiles(options, SHARPXMI.getDevTextFiles(batchesDirectory));
                    params.stats = evaluation.trainAndTest(trainFiles, devFiles);
                    break;
                }
                case TEST: {
                    List<File> trainAndDevText = new ArrayList<File>();
                    trainAndDevText.addAll(SHARPXMI.getTrainTextFiles(batchesDirectory));
                    trainAndDevText.addAll(SHARPXMI.getDevTextFiles(batchesDirectory));
                    List<File> allTrainFiles = SHARPXMI.toXMIFiles(options, trainAndDevText);
                    List<File> testFiles = SHARPXMI.toXMIFiles(options, SHARPXMI.getTestTextFiles(batchesDirectory));
                    params.stats = evaluation.trainAndTest(allTrainFiles, testFiles);
                    break;
                }
                case OTHER: {
                    List<File> trainAndDevText = new ArrayList<File>();
                    trainAndDevText.addAll(SHARPXMI.getTrainTextFiles(batchesDirectory));
                    trainAndDevText.addAll(SHARPXMI.getDevTextFiles(batchesDirectory));
                    List<File> trainAndDevFiles = SHARPXMI.toXMIFiles(options, trainAndDevText);
                    // Extra training XMI files are optional (the option defaults to null).
                    File trainXmiDir = options.getTrainXmiDir();
                    if (trainXmiDir != null) {
                        Collections.addAll(trainAndDevFiles, SHARPXMI.listXmiDirectory(trainXmiDir, "train-xmi-dir"));
                    }
                    List<File> testXmiFiles = new ArrayList<File>();
                    Collections.addAll(testXmiFiles, SHARPXMI.listXmiDirectory(options.getTestXmiDir(), "test-xmi-dir"));
                    params.stats = evaluation.trainAndTest(trainAndDevFiles, testXmiFiles);
                    break;
                }
                default: {
                    throw new IllegalArgumentException("Invalid EvaluateOn: " + options.getEvaluteOn());
                }
            }
            scoredParams.put(params, params.stats.f1());
        }

        // Sort ascending by F1 so that the best setting ends up last.
        List<ParameterSettings> ranked = new ArrayList<ParameterSettings>(scoredParams.keySet());
        Function<ParameterSettings, Double> f1Of = Functions.forMap(scoredParams);
        Collections.sort(ranked, Ordering.<Double>natural().onResultOf(f1Of));

        if (ranked.size() > 1) {
            System.err.println("Summary");
            for (ParameterSettings params : ranked) {
                System.err.printf("F1=%.3f P=%.3f R=%.3f %s\n", params.stats.f1(), params.stats.precision(), params.stats.recall(), params);
            }
            System.err.println();
        }
        if (!ranked.isEmpty()) {
            ParameterSettings lastParams = ranked.get(ranked.size() - 1);
            System.err.println("Best model:");
            System.err.print(lastParams.stats);
            System.err.println(lastParams);
            System.err.println(lastParams.stats.confusions());
            System.err.println();
        }
    }

    /**
     * Lists the entries of a directory given through a command-line option, failing with
     * a message that names the option when the path cannot be listed.
     */
    private static File[] listXmiDirectory(File xmiDirectory, String optionName) {
        // File.listFiles() returns null when the path is not a directory or is unreadable.
        File[] entries = xmiDirectory.listFiles();
        if (entries == null) {
            throw new IllegalArgumentException("--" + optionName + " is not a readable directory: " + xmiDirectory);
        }
        return entries;
    }

    /**
     * Annotator that sets the document text of the gold view to the document text of the
     * view it is run on.
     */
    @PipeBitInfo(name = "Text to Gold Copier", description = "Copies Text from the System view to the Gold view.", role = PipeBitInfo.Role.SPECIAL)
    public static class CopyDocumentTextToGoldView
    extends JCasAnnotator_ImplBase {
        /**
         * Copies the text of {@code jCas} into the view named by
         * {@code SHARPXMI.GOLD_VIEW_NAME}.
         *
         * @throws AnalysisEngineProcessException wrapping the {@code CASException} raised
         *         when the gold view cannot be obtained
         */
        public void process(JCas jCas) throws AnalysisEngineProcessException {
            try {
                String systemText = jCas.getDocumentText();
                jCas.getView(SHARPXMI.GOLD_VIEW_NAME).setDocumentText(systemText);
            }
            catch (CASException e) {
                throw new AnalysisEngineProcessException(e);
            }
        }
    }

    /**
     * Annotator that records where a document came from: it adds a {@code DocumentID}
     * annotation whose ID is the file path derived from the URI of the CAS.
     */
    public static class DocumentIDAnnotator
    extends JCasAnnotator_ImplBase {
        /** Creates the {@code DocumentID} for {@code jCas} and adds it to the indexes. */
        public void process(JCas jCas) throws AnalysisEngineProcessException {
            File sourceFile = new File(ViewUriUtil.getURI(jCas));
            DocumentID idAnnotation = new DocumentID(jCas);
            idAnnotation.setDocumentID(sourceFile.getPath());
            idAnnotation.addToIndexes();
        }
    }

    /**
     * Base class for evaluations whose items are XMI files and whose statistics are
     * {@code AnnotationStatistics<String>}; it supplies the collection reader for the files.
     */
    public static abstract class Evaluation_ImplBase
    extends org.cleartk.eval.Evaluation_ImplBase<File, AnnotationStatistics<String>> {
        /**
         * @param baseDirectory passed through to the ClearTK base class
         */
        public Evaluation_ImplBase(File baseDirectory) {
            super(baseDirectory);
        }

        /**
         * Creates an {@code XMIReader} configured with the given files.
         *
         * @param items the XMI files to read
         * @return a collection reader over {@code items}
         */
        public CollectionReader getCollectionReader(List<File> items) throws Exception {
            TypeSystemDescription typeSystem = TypeSystemDescriptionFactory.createTypeSystemDescription();
            Object[] readerConfig = new Object[]{"files", items};
            return CollectionReaderFactory.createReader(XMIReader.class, typeSystem, readerConfig);
        }
    }

    /**
     * Command-line options that control an evaluation run, in addition to the general
     * {@code Options}.
     */
    public interface EvaluationOptions
    extends Options {
        /**
         * Which data to evaluate on; defaults to {@code DEV}. The accessor's spelling is
         * part of the interface and is kept as is.
         */
        @Option(longName = "evaluate-on", defaultValue = "DEV", description = "perform evaluation using the training (TRAIN), development (DEV) or test (TEST) data.")
        EvaluateOn getEvaluteOn();

        /** Whether to try every setting of the grid instead of only the best-known one. */
        @Option(longName = "grid-search", description = "run a grid search to select the best parameters")
        boolean getGridSearch();

        /** Optional directory of additional training XMI files; {@code null} when not given. */
        @Option(longName = "train-xmi-dir", defaultToNull = true, description = "use these XMI files for training; they must contain the necessary preprocessing in system view and gold annotation in gold view")
        File getTrainXmiDir();

        /** Directory of XMI files to evaluate on; defaults to the empty path. */
        @Option(longName = "test-xmi-dir", defaultValue = "", description = "evaluate on these XMI files; they must contain the necessary preprocessing in system view and gold annotation in gold view")
        File getTestXmiDir();
    }

    /**
     * Selects which data {@code SHARPXMI.evaluate} trains and tests on.
     */
    public enum EvaluateOn {
        /** Two-fold cross-validation over the training files. */
        TRAIN,
        /** Train on the training files, test on the development files. */
        DEV,
        /** Train on the training and development files, test on the test files. */
        TEST,
        /**
         * Train on the training and development files plus any train-xmi-dir entries,
         * test on the test-xmi-dir entries.
         */
        OTHER
    }

    /**
     * Command-line options shared by XMI generation and evaluation.
     */
    public interface Options {
        /** Directory that contains the batch directories with the Knowtator data. */
        @Option(longName = "batches-dir", description = "directory containing ssN_batchNN directories, each of which should contain a Knowtator directory and a Knowtator_XML directory")
        File getBatchesDirectory();

        /** Directory the XMI files are written to and read from; defaults to {@code target/xmi}. */
        @Option(longName = "xmi-dir", defaultValue = "target/xmi", description = "directory to store and load XMI serialization of annotations")
        File getXMIDirectory();

        /** Whether {@code generateXMI} should actually produce the XMI files. */
        @Option(longName = "generate-xmi", description = "read in the gold annotations and serialize them as XMI")
        boolean getGenerateXMI();
    }
}

