/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.utils;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.StreamSupport;
import lombok.Generated;
import org.apache.commons.lang3.StringUtils;
import org.opensearch.common.collect.Tuple;
import org.opensearch.ml.common.dataframe.ColumnMeta;
import org.opensearch.ml.common.dataframe.ColumnValue;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataframe.Row;
import org.opensearch.ml.engine.contants.TribuoOutputType;
import org.tribuo.DataSource;
import org.tribuo.MutableDataset;
import org.tribuo.Output;
import org.tribuo.OutputFactory;
import org.tribuo.anomaly.Event;
import org.tribuo.classification.Label;
import org.tribuo.clustering.ClusterID;
import org.tribuo.datasource.ListDataSource;
import org.tribuo.impl.ArrayExample;
import org.tribuo.provenance.DataSourceProvenance;
import org.tribuo.provenance.SimpleDataSourceProvenance;
import org.tribuo.regression.Regressor;

/**
 * Static helpers that turn a {@link DataFrame} into feature-name / feature-value arrays and into
 * Tribuo {@link MutableDataset} instances.
 *
 * <p>All methods read the frame through {@code columnMetas()}, {@code size()}, its iterator and
 * {@code getColumnIndex(String)}; none of them modify it.
 */
public final class TribuoUtil {

    /**
     * Extracts every column name and every row (as {@code double} values) from a data frame.
     *
     * @param dataFrame the frame to read
     * @return a tuple whose first element holds the column names in column order and whose second
     *         element holds one {@code double[]} per row, in iteration order
     */
    public static Tuple<String[], double[][]> transformDataFrame(DataFrame dataFrame) {
        String[] featureNames = Arrays.stream(dataFrame.columnMetas()).map(ColumnMeta::getName).toArray(String[]::new);
        double[][] featureValues = new double[dataFrame.size()][];
        int rowIndex = 0;
        for (Row row : dataFrame) {
            featureValues[rowIndex++] = StreamSupport.stream(row.spliterator(), false)
                .mapToDouble(ColumnValue::doubleValue)
                .toArray();
        }
        return new Tuple<>(featureNames, featureValues);
    }

    /**
     * Same as {@link #transformDataFrame(DataFrame)} but narrows every value to {@code float}.
     *
     * @param dataFrame the frame to read
     * @return a tuple of the column names and one {@code float[]} per row, in iteration order
     */
    public static Tuple<String[], float[][]> transformDataFrameFloat(DataFrame dataFrame) {
        String[] featureNames = Arrays.stream(dataFrame.columnMetas()).map(ColumnMeta::getName).toArray(String[]::new);
        float[][] featureValues = new float[dataFrame.size()][];
        int rowIndex = 0;
        for (Row row : dataFrame) {
            int width = row.size();
            float[] narrowed = new float[width];
            for (int col = 0; col < width; ++col) {
                narrowed[col] = (float) row.getValue(col).doubleValue();
            }
            featureValues[rowIndex++] = narrowed;
        }
        return new Tuple<>(featureNames, featureValues);
    }

    /**
     * Extracts the feature columns of a frame, leaving out the target column.
     *
     * <p>Values are skipped by the index reported by {@code dataFrame.getColumnIndex(target)},
     * while the name is dropped with {@code List.remove(Object)}, i.e. the first name equal to
     * {@code target}.
     *
     * @param dataFrame the frame to read
     * @param target    name of the column to exclude
     * @return a tuple of the remaining column names and, per row, the remaining values
     */
    public static Tuple<String[], double[][]> transformClassificationDataFrame(DataFrame dataFrame, String target) {
        List<String> featureNames = Arrays.stream(dataFrame.columnMetas()).map(ColumnMeta::getName).collect(Collectors.toList());
        int targetIndex = dataFrame.getColumnIndex(target);
        int columnCount = featureNames.size();
        double[][] featureValues = new double[dataFrame.size()][columnCount - 1];
        int rowIndex = 0;
        for (Row row : dataFrame) {
            double[] destination = featureValues[rowIndex++];
            int written = 0;
            for (int col = 0; col < columnCount; ++col) {
                if (col != targetIndex) {
                    destination[written++] = row.getValue(col).doubleValue();
                }
            }
        }
        featureNames.remove(target);
        return new Tuple<>(featureNames.toArray(new String[featureNames.size()]), featureValues);
    }

    /**
     * Reads the target column of every row as a string.
     *
     * @param dataFrame the frame to read
     * @param target    name of the target column
     * @return one string per row, in iteration order
     */
    public static String[] transformTargetValuesDataFrames(DataFrame dataFrame, String target) {
        int targetIndex = dataFrame.getColumnIndex(target);
        String[] targetValues = new String[dataFrame.size()];
        int rowIndex = 0;
        for (Row row : dataFrame) {
            targetValues[rowIndex++] = row.getValue(targetIndex).stringValue();
        }
        return targetValues;
    }

    /**
     * Builds a Tribuo dataset in which every row of the frame becomes one example that uses all
     * columns as features and carries a placeholder output chosen by {@code outputType}.
     *
     * @param dataFrame     the frame to convert
     * @param outputFactory factory handed to the data source and provenance; also supplies the
     *                      unknown output for {@code LABEL}
     * @param desc          description stored in the data source provenance
     * @param outputType    selects which placeholder output each example receives
     * @param <T>           Tribuo output type of the resulting dataset
     * @return a mutable dataset with one example per row
     * @throws IllegalArgumentException if the frame has at least one row and {@code outputType}
     *                                  is not one of the supported types
     */
    @SuppressWarnings({"rawtypes", "unchecked"}) // concrete Output class is picked at runtime from outputType, not from T
    public static <T extends Output<T>> MutableDataset<T> generateDataset(DataFrame dataFrame, OutputFactory<T> outputFactory, String desc, TribuoOutputType outputType) {
        List<ArrayExample> examples = new ArrayList<>();
        Tuple<String[], double[][]> featureNamesValues = transformDataFrame(dataFrame);
        String[] featureNames = featureNamesValues.v1();
        double[][] featureValues = featureNamesValues.v2();
        int rowCount = dataFrame.size();
        for (int row = 0; row < rowCount; ++row) {
            // A fresh placeholder is created per row; the output type is only inspected when a row exists.
            Output<?> placeholder = placeholderOutput(outputType, outputFactory);
            examples.add(new ArrayExample(placeholder, featureNames, featureValues[row]));
        }
        return toDataset(examples, outputFactory, desc);
    }

    /**
     * Builds a Tribuo dataset in which the {@code target} column supplies each example's output
     * and the remaining columns supply its features.
     *
     * <p>The frame is transformed once per call rather than once per row.
     *
     * @param dataFrame     the frame to convert
     * @param outputFactory factory handed to the data source and provenance
     * @param desc          description stored in the data source provenance
     * @param outputType    {@code REGRESSOR} (target read as a double) or {@code LABEL} (target
     *                      read as a string)
     * @param target        name of the target column; must not be empty
     * @param <T>           Tribuo output type of the resulting dataset
     * @return a mutable dataset with one example per row
     * @throws IllegalArgumentException if {@code target} is empty, or if the frame has at least
     *                                  one row and either {@code outputType} is unsupported or
     *                                  (for {@code REGRESSOR}) no column is named {@code target}
     */
    @SuppressWarnings("rawtypes") // examples are collected through the raw ArrayExample type
    public static <T extends Output<T>> MutableDataset<T> generateDatasetWithTarget(DataFrame dataFrame, OutputFactory<T> outputFactory, String desc, TribuoOutputType outputType, String target) {
        if (StringUtils.isEmpty(target)) {
            throw new IllegalArgumentException("Empty target when generating dataset from data frame.");
        }
        List<ArrayExample> examples = new ArrayList<>();
        // A frame without rows has always yielded an empty dataset without validating the output
        // type or the target column, so the transformation is skipped entirely in that case.
        if (dataFrame.size() > 0) {
            switch (outputType) {
                case REGRESSOR: {
                    examples = regressionExamples(dataFrame, target);
                    break;
                }
                case LABEL: {
                    examples = classificationExamples(dataFrame, target);
                    break;
                }
                default: {
                    throw new IllegalArgumentException("unknown type:" + outputType);
                }
            }
        }
        return toDataset(examples, outputFactory, desc);
    }

    /** Creates the placeholder output that {@link #generateDataset} attaches to a row for the given type. */
    private static <T extends Output<T>> Output<?> placeholderOutput(TribuoOutputType outputType, OutputFactory<T> outputFactory) {
        switch (outputType) {
            case CLUSTERID: {
                // NOTE(review): -1 is presumably Tribuo's "unassigned" cluster id; confirm against ClusterID.
                return new ClusterID(-1);
            }
            case REGRESSOR: {
                return new Regressor("DIM-0", Double.NaN);
            }
            case ANOMALY_DETECTION_LIBSVM: {
                return new Event(Event.EventType.EXPECTED);
            }
            case LABEL: {
                return outputFactory.getUnknownOutput();
            }
            default: {
                throw new IllegalArgumentException("unknown type:" + outputType);
            }
        }
    }

    /** Builds one regression example per row: the target column becomes the output, the rest the features. */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private static List<ArrayExample> regressionExamples(DataFrame dataFrame, String target) {
        Tuple<String[], double[][]> featureNamesValues = transformDataFrame(dataFrame);
        int targetIndex = findFinalTargetIndex(featureNamesValues, target);
        String[] featureNames = createFeatureNames(featureNamesValues, targetIndex);
        List<ArrayExample> examples = new ArrayList<>();
        for (double[] rowValues : featureNamesValues.v2()) {
            double targetValue = rowValues[targetIndex];
            double[] featureValues = withoutIndex(rowValues, targetIndex);
            examples.add(new ArrayExample(new Regressor(target, targetValue), featureNames, featureValues));
        }
        return examples;
    }

    /** Builds one classification example per row: the target column's string value becomes the label. */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private static List<ArrayExample> classificationExamples(DataFrame dataFrame, String target) {
        Tuple<String[], double[][]> featureNamesValues = transformClassificationDataFrame(dataFrame, target);
        String[] targetValues = transformTargetValuesDataFrames(dataFrame, target);
        String[] featureNames = featureNamesValues.v1();
        double[][] featureValues = featureNamesValues.v2();
        List<ArrayExample> examples = new ArrayList<>();
        for (int row = 0; row < featureValues.length; ++row) {
            examples.add(new ArrayExample(new Label(targetValues[row]), featureNames, featureValues[row]));
        }
        return examples;
    }

    /** Wraps the examples in a list data source (with a simple provenance built from {@code desc}) and a mutable dataset. */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private static <T extends Output<T>> MutableDataset<T> toDataset(List<ArrayExample> examples, OutputFactory<T> outputFactory, String desc) {
        DataSourceProvenance provenance = new SimpleDataSourceProvenance(desc, outputFactory);
        DataSource<T> source = new ListDataSource(examples, outputFactory, provenance);
        return new MutableDataset<>(source);
    }

    /**
     * Finds the position of the first feature name equal to {@code target}.
     *
     * @throws IllegalArgumentException if no name matches
     */
    private static int findFinalTargetIndex(Tuple<String[], double[][]> featureNamesValues, String target) {
        String[] names = featureNamesValues.v1();
        for (int i = 0; i < names.length; ++i) {
            if (names[i].equals(target)) {
                return i;
            }
        }
        throw new IllegalArgumentException("No matched target when generating dataset from data frame.");
    }

    /** Returns the feature names with the entry at {@code finalTargetIndex} left out. */
    private static String[] createFeatureNames(Tuple<String[], double[][]> featureNamesValues, int finalTargetIndex) {
        String[] names = featureNamesValues.v1();
        return IntStream.range(0, names.length)
            .filter(i -> i != finalTargetIndex)
            .mapToObj(i -> names[i])
            .toArray(String[]::new);
    }

    /** Returns a copy of {@code values} with the element at {@code skippedIndex} left out. */
    private static double[] withoutIndex(double[] values, int skippedIndex) {
        return IntStream.range(0, values.length)
            .filter(i -> i != skippedIndex)
            .mapToDouble(i -> values[i])
            .toArray();
    }

    @Generated
    private TribuoUtil() {
        throw new UnsupportedOperationException("This is a utility class and cannot be instantiated");
    }
}

