package com.wcohen.ss.abbvGapsHmm;

import com.wcohen.ss.abbvGapsHmm.AbbvGapsHMM;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.eclipse.swt.custom.StyledTextPrintOptions;

/* loaded from: input_file:lib/com.wcohen.secondstring-0.1.jar:com/wcohen/ss/abbvGapsHmm/AlignmentPredictionModel.class */
public class AlignmentPredictionModel {
    /** Separator between "key\tvalue" entries on one line of the labels file. */
    public static final String SEPARATOR = "#_#";

    /**
     * Separator between a key and its value inside one labels-file entry: a tab character.
     * (The decompiled code referenced SWT's {@code StyledTextPrintOptions.SEPARATOR}, which
     * has the same value; a plain literal avoids depending on SWT here.)
     */
    private static final String KEY_VALUE_SEPARATOR = "\t";

    // Candidate-extraction patterns, compiled once instead of on every call.
    private static final Pattern HEAD_NOUN_3_PARTS = Pattern.compile("([a-zA-Z0-9\\-]{1,20}) ([a-zA-Z0-9\\-]{1,20}) \\(([^\\(]*?)\\),? ([a-zA-Z0-9\\-]{1,20}) \\(([^\\(]*?)\\),? and ([a-zA-Z0-9\\-]{1,20}) \\(([^\\(]*?)\\)");
    private static final Pattern HEAD_NOUN_2_PARTS = Pattern.compile("([a-zA-Z0-9\\-]{1,20}) ([a-zA-Z0-9\\-]{1,20}) \\(([^\\(]*?)\\),? and ([a-zA-Z0-9\\-]{1,20}) \\(([^\\(]*?)\\)");
    private static final Pattern TRAILING_NOUN_3_PARTS = Pattern.compile("(.{1,20}?) \\(([^\\(]*?)\\),? (.{1,20}?) \\(([^\\(]*?)\\),? and (.{1,20}?) \\(([^\\(]*?)\\) ([a-zA-Z0-9\\-]{1,20})");
    private static final Pattern TRAILING_NOUN_2_PARTS = Pattern.compile("(.{1,20}?) \\(([^\\(]*?)\\),? and (.{1,20}?) \\(([^\\(]*?)\\) ([a-zA-Z0-9\\-]{1,20})");

    // Short-form validity checks (whole-string matches, as Pattern.matches did).
    private static final Pattern STARTS_WITH_ALNUM = Pattern.compile("^[a-zA-Z0-9].*");
    private static final Pattern CONTAINS_LETTER = Pattern.compile(".*[a-zA-Z].*");

    // NOTE(review): these paths are static but are written by the instance method
    // setTrainingDataDir (including from the constructor), so every instance shares them.
    public static String _trainingDataDir;
    public static String _trueLabelsFile;
    public static String _trainingCorpusFile;

    /** The underlying HMM; all training and decoding is delegated to it. */
    private final AbbvGapsHMM _abbvHmm;

    /**
     * Creates a model backed by a fresh {@link AbbvGapsHMM} and points the (static)
     * training-data paths at the {@code "train/"} directory.
     *
     * @throws IOException if constructing the HMM fails
     */
    public AlignmentPredictionModel() throws IOException {
        this._abbvHmm = new AbbvGapsHMM();
        setTrainingDataDir("train/");
    }

    /**
     * Sets the training-data directory and derives the labels and corpus file paths from it.
     * The directory string is used as a plain prefix, so it should end with a path separator.
     *
     * @param str directory prefix, e.g. {@code "train/"}
     */
    public void setTrainingDataDir(String str) {
        _trainingDataDir = str;
        _trueLabelsFile = _trainingDataDir + "abbvAlign_pairs.txt";
        _trainingCorpusFile = _trainingDataDir + "abbvAlign_corpus.txt";
    }

    /**
     * Forwards TF-IDF data to the HMM.
     *
     * @param str value passed straight to {@code AbbvGapsHMM.setTfIdfData}
     * @throws IOException if the HMM fails to load the data
     */
    public void setTfIdfData(String str) throws IOException {
        this._abbvHmm.setTfIdfData(str);
    }

    /**
     * Sets the file the HMM reads/writes its parameters from.
     *
     * @param str parameter file path, passed to {@code AbbvGapsHMM.setParamFile}
     */
    public void setModelParamsFile(String str) {
        this._abbvHmm.setParamFile(str);
    }

    /** Sets the HMM parameter file to the default {@code "hmmModelParams.txt"}. */
    public void setModelParamsFile() {
        setModelParamsFile("hmmModelParams.txt");
    }

    /**
     * Loads gold labels, one map per line.
     *
     * <p>Each line holds entries separated by {@link #SEPARATOR}; each entry is
     * {@code key<TAB>value}. Keys and values are trimmed. Malformed entries are reported on
     * stdout and skipped.
     *
     * @param str path of the labels file; may be {@code null}
     * @return one map per line, or {@code null} when {@code str} is {@code null}
     * @throws IllegalStateException if the file cannot be opened or read (previously this
     *     terminated the JVM via {@code System.exit(1)})
     */
    public static ArrayList<Map<String, String>> loadLabels(String str) {
        if (str == null) {
            return null;
        }
        ArrayList<Map<String, String>> labels = new ArrayList<>();
        // try-with-resources: the reader is closed even when reading fails.
        try (BufferedReader reader = new BufferedReader(new FileReader(str))) {
            String line;
            while ((line = reader.readLine()) != null) {
                Map<String, String> fields = new HashMap<>();
                for (String entry : line.split(SEPARATOR)) {
                    if (entry.isEmpty()) {
                        continue;
                    }
                    String[] keyValue = entry.split(KEY_VALUE_SEPARATOR);
                    if (keyValue.length != 2) {
                        System.out.println("BAD FORMAT in " + str + ": " + entry);
                    } else {
                        fields.put(keyValue[0].trim(), keyValue[1].trim());
                    }
                }
                labels.add(fields);
            }
        } catch (IOException e) {
            throw new IllegalStateException("Failed to read labels file: " + str, e);
        }
        return labels;
    }

    /**
     * Loads the training corpus, one entry per line, with lines kept verbatim.
     *
     * @param str path of the corpus file
     * @return the lines of the file, in order
     * @throws IllegalStateException if the file cannot be opened or read (previously this
     *     terminated the JVM via {@code System.exit(1)})
     */
    public static List<String> loadTrainingCorpus(String str) {
        List<String> lines = new ArrayList<>();
        try (BufferedReader reader = new BufferedReader(new FileReader(str))) {
            String line;
            while ((line = reader.readLine()) != null) {
                lines.add(line);
            }
        } catch (IOException e) {
            throw new IllegalStateException("Failed to read training corpus: " + str, e);
        }
        return lines;
    }

    /**
     * Trains the HMM on every corpus line, using the gold labels file.
     * Line {@code i} of the corpus is paired with line {@code i} of the labels file.
     *
     * @return the result of {@code AbbvGapsHMM.train}
     */
    public boolean trainOnAll() {
        ArrayList<Map<String, String>> trueLabels = loadLabels(_trueLabelsFile);
        List<String> corpus = loadTrainingCorpus(_trainingCorpusFile);
        List<List<Acronym>> candidates = new ArrayList<>();
        List<Map<String, String>> labels = new ArrayList<>();
        for (int i = 0; i < corpus.size(); i++) {
            candidates.add(extractCandidatePairs(corpus.get(i)));
            labels.add(trueLabels.get(i));
        }
        return this._abbvHmm.train(candidates, labels, true);
    }

    /**
     * Trains the HMM on candidate pairs from every corpus line without gold labels
     * ({@code null} is passed as the labels argument).
     *
     * @return the result of {@code AbbvGapsHMM.train}
     */
    public boolean trainOnCandidates() {
        List<String> corpus = loadTrainingCorpus(_trainingCorpusFile);
        List<List<Acronym>> candidates = new ArrayList<>();
        for (String sentence : corpus) {
            candidates.add(extractCandidatePairs(sentence));
        }
        return this._abbvHmm.train(candidates, null, true);
    }

    /**
     * Trains the HMM on a caller-supplied corpus.
     *
     * @param list the corpus sentences
     * @param list2 indices into {@code list}/{@code list3} to train on, or {@code null} to use
     *     every sentence of {@code list}
     * @param list3 gold labels aligned with {@code list}, or {@code null} to train on
     *     candidates only (as {@link #trainOnCandidates()} does)
     * @return the result of {@code AbbvGapsHMM.train}
     */
    public boolean train(List<String> list, List<Integer> list2, List<Map<String, String>> list3) {
        List<List<Acronym>> candidates = new ArrayList<>();
        // Labels stay null when none are supplied, instead of failing with an NPE on list3.get.
        List<Map<String, String>> labels = (list3 == null) ? null : new ArrayList<Map<String, String>>();
        if (list2 != null) {
            for (Integer index : list2) {
                int i = index.intValue();
                candidates.add(extractCandidatePairs(list.get(i)));
                if (labels != null) {
                    labels.add(list3.get(i));
                }
            }
        } else {
            for (int i = 0; i < list.size(); i++) {
                candidates.add(extractCandidatePairs(list.get(i)));
                if (labels != null) {
                    labels.add(list3.get(i));
                }
            }
        }
        return this._abbvHmm.train(candidates, labels, true);
    }

    /**
     * Aligns a single (short form, long form) pair.
     *
     * @param str short form, the first {@code Acronym} constructor argument
     * @param str2 long form, the second {@code Acronym} constructor argument
     * @return the HMM's Viterbi alignment, possibly {@code null}
     */
    public AbbreviationAlignmentContainer<AbbvGapsHMM.Emissions, AbbvGapsHMM.States> predict(String str, String str2) {
        return predictAlignment(new Acronym(str, str2));
    }

    /**
     * Returns the HMM's Viterbi alignment for an acronym candidate.
     *
     * @param acronym candidate pair to align
     * @return the alignment returned by {@code AbbvGapsHMM.viterbi}, possibly {@code null}
     */
    public AbbreviationAlignmentContainer<AbbvGapsHMM.Emissions, AbbvGapsHMM.States> predictAlignment(Acronym acronym) {
        return this._abbvHmm.viterbi(acronym);
    }

    /**
     * Predicts the acronym implied by the best alignment of a candidate.
     *
     * <p>The acronym extracted from the alignment is re-aligned to obtain its own probability,
     * and the original alignment is attached to it. Any exception while doing so is printed
     * and the (possibly partially populated) acronym is still returned.
     *
     * @param acronym candidate pair
     * @return the predicted acronym, or {@code null} when there is no alignment or the
     *     alignment yields no acronym
     */
    public Acronym predict(Acronym acronym) {
        AbbreviationAlignmentContainer<AbbvGapsHMM.Emissions, AbbvGapsHMM.States> alignment = predictAlignment(acronym);
        if (alignment == null) {
            return null;
        }
        Acronym predicted = null;
        try {
            predicted = alignment.getAcronym();
            if (predicted != null) {
                predicted._probability = predictAlignment(predicted).getProbability();
                predicted._alignment = alignment;
            }
        } catch (Exception e) {
            // Best-effort: keep whatever was computed so far.
            e.printStackTrace();
        }
        return predicted;
    }

    /**
     * Indexes acronyms by short form.
     *
     * <p>For duplicate short forms, a later acronym replaces the stored one only when both
     * probabilities are non-null and the later one is strictly greater. Otherwise the first
     * one seen is kept.
     *
     * @param collection acronyms to index
     * @return map from short form to the selected acronym
     */
    public Map<String, Acronym> acronymsArrayToMap(Collection<Acronym> collection) {
        Map<String, Acronym> byShortForm = new HashMap<>();
        for (Acronym acronym : collection) {
            Acronym current = byShortForm.get(acronym._shortForm);
            if (current == null) {
                byShortForm.put(acronym._shortForm, acronym);
            } else if (acronym._probability != null && current._probability != null
                    && acronym._probability.compareTo(current._probability) > 0) {
                byShortForm.put(acronym._shortForm, acronym);
            }
        }
        return byShortForm;
    }

    /**
     * Extracts candidate pairs from a sentence and predicts an acronym for each.
     *
     * @param str input sentence
     * @return the non-null predictions, in candidate order
     */
    public Collection<Acronym> predict(String str) {
        List<Acronym> predictions = new ArrayList<>();
        for (Acronym candidate : extractCandidatePairs(str)) {
            Acronym predicted = predict(candidate);
            if (predicted != null) {
                predictions.add(predicted);
            }
        }
        return predictions;
    }

    /**
     * Loads stored HMM parameters if possible; otherwise trains on unlabeled candidates.
     *
     * @return {@code true} if parameters were loaded, else the result of
     *     {@link #trainOnCandidates()}
     */
    public boolean trainIfNeeded() {
        if (this._abbvHmm.loadModelParams()) {
            return true;
        }
        return trainOnCandidates();
    }

    /**
     * Extracts all candidate (short form, long form) pairs from a sentence: parenthesised
     * single acronyms first, then coordinated-pattern acronyms.
     *
     * @param str input sentence
     * @return candidate pairs (may contain duplicates across extractors)
     */
    public List<Acronym> extractCandidatePairs(String str) {
        List<Acronym> candidates = new ArrayList<>();
        candidates.addAll(extractSingleAcronyms(str));
        candidates.addAll(extractPatternAcronyms(str));
        return candidates;
    }

    /**
     * Extracts candidates from coordinated constructions such as
     * "X Y (A) and Z (B)" or "X (A) and Y (B) NOUN".
     *
     * @param str input sentence
     * @return candidates from all four pattern extractors
     */
    protected List<Acronym> extractPatternAcronyms(String str) {
        List<Acronym> candidates = new ArrayList<>();
        candidates.addAll(extractHeadNounPattern_2Parts(str));
        candidates.addAll(extractHeadNounPattern_3Parts(str));
        candidates.addAll(extractTrailingNounPattern_2Parts(str));
        candidates.addAll(extractTrailingNounPattern_3Parts(str));
        return candidates;
    }

    /**
     * Parses a raw (long-side, parenthesised-side) pair and appends it to {@code list}
     * when it yields a non-empty short form.
     *
     * @param list destination list
     * @param str text outside the parentheses
     * @param str2 text inside the parentheses
     */
    protected void addCandidatePair(List<Acronym> list, String str, String str2) {
        Acronym candidate = parseCandidate(str, str2);
        if (candidate != null && !candidate._shortForm.isEmpty()) {
            list.add(candidate);
        }
    }

    /**
     * Handles "HEAD A (a), B (b) and C (c)": each of A, B, C is prefixed with the shared
     * head word HEAD.
     */
    protected List<Acronym> extractHeadNounPattern_3Parts(String str) {
        List<Acronym> candidates = new ArrayList<>();
        Matcher matcher = HEAD_NOUN_3_PARTS.matcher(str);
        // find() continues after the previous match. The old code restarted at
        // regionEnd() + 1, i.e. past the end of the input, so only the first match was used.
        while (matcher.find()) {
            String head = matcher.group(1);
            addCandidatePair(candidates, head + " " + matcher.group(2), matcher.group(3));
            addCandidatePair(candidates, head + " " + matcher.group(4), matcher.group(5));
            addCandidatePair(candidates, head + " " + matcher.group(6), matcher.group(7));
        }
        return candidates;
    }

    /** Handles "HEAD A (a) and B (b)": A and B are each prefixed with the head word HEAD. */
    protected List<Acronym> extractHeadNounPattern_2Parts(String str) {
        List<Acronym> candidates = new ArrayList<>();
        Matcher matcher = HEAD_NOUN_2_PARTS.matcher(str);
        while (matcher.find()) {
            String head = matcher.group(1);
            addCandidatePair(candidates, head + " " + matcher.group(2), matcher.group(3));
            addCandidatePair(candidates, head + " " + matcher.group(4), matcher.group(5));
        }
        return candidates;
    }

    /**
     * Handles "A (a), B (b) and C (c) NOUN": each of A, B, C is suffixed with the shared
     * trailing word NOUN.
     */
    protected List<Acronym> extractTrailingNounPattern_3Parts(String str) {
        List<Acronym> candidates = new ArrayList<>();
        Matcher matcher = TRAILING_NOUN_3_PARTS.matcher(str);
        while (matcher.find()) {
            String tail = matcher.group(7);
            addCandidatePair(candidates, matcher.group(1) + " " + tail, matcher.group(2));
            addCandidatePair(candidates, matcher.group(3) + " " + tail, matcher.group(4));
            addCandidatePair(candidates, matcher.group(5) + " " + tail, matcher.group(6));
        }
        return candidates;
    }

    /** Handles "A (a) and B (b) NOUN": A and B are each suffixed with the trailing word NOUN. */
    protected List<Acronym> extractTrailingNounPattern_2Parts(String str) {
        List<Acronym> candidates = new ArrayList<>();
        Matcher matcher = TRAILING_NOUN_2_PARTS.matcher(str);
        while (matcher.find()) {
            String tail = matcher.group(5);
            addCandidatePair(candidates, matcher.group(1) + " " + tail, matcher.group(2));
            addCandidatePair(candidates, matcher.group(3) + " " + tail, matcher.group(4));
        }
        return candidates;
    }

    /**
     * For every '(' in the sentence that has a matching ')', pairs the trimmed text before
     * the '(' with the text between the parentheses. Nested parentheses are balanced, and
     * inner '(' characters also start their own candidates.
     *
     * @param str input sentence
     * @return parsed candidates
     */
    protected List<Acronym> extractSingleAcronyms(String str) {
        List<Acronym> candidates = new ArrayList<>();
        for (int open = str.indexOf("("); open != -1; open = str.indexOf("(", open + 1)) {
            int close = findClosingParen(str, open);
            if (close != -1) {
                addCandidatePair(candidates, str.substring(0, open).trim(), str.substring(open + 1, close));
            }
        }
        return candidates;
    }

    /** Returns the index of the ')' balancing the '(' at {@code open}, or -1 if unbalanced. */
    private static int findClosingParen(String text, int open) {
        int depth = 0;
        for (int i = open + 1; i < text.length(); i++) {
            char c = text.charAt(i);
            if (c == '(') {
                depth++;
            } else if (c == ')') {
                if (depth == 0) {
                    return i;
                }
                depth--;
            }
        }
        return -1;
    }

    /**
     * Turns raw outside/inside-parenthesis text into a candidate acronym.
     *
     * <p>Processing steps:
     * <ol>
     *   <li>Cut the inside text at its first ';'.</li>
     *   <li>Keep the outside text after its first ';'.</li>
     *   <li>If the inside text has at most 3 space-separated tokens, it is the short form and
     *       the outside text is the long form. Otherwise the inside text is the long form and
     *       the last word of the outside text is the short form.</li>
     *   <li>Truncate the long form to its last {@code min(len+5, 2*len)} words, where
     *       {@code len} is the short form's length.</li>
     * </ol>
     *
     * @param str text outside the parentheses
     * @param str2 text inside the parentheses
     * @return the candidate, or {@code null} if invalid or if short and long forms are equal
     *     ignoring case
     */
    protected Acronym parseCandidate(String str, String str2) {
        int semicolon = str2.indexOf(";");
        if (semicolon != -1) {
            str2 = str2.substring(0, semicolon);
        }
        semicolon = str.indexOf(";");
        if (semicolon != -1) {
            str = str.substring(semicolon + 1);
        }
        String shortForm = str2.trim();
        String longForm = str.trim();
        if (!isShortForm(shortForm)) {
            // Reversed order "long form (SF)" -> "SF (long form)": swap roles.
            longForm = str2.trim();
            String[] outsideWords = str.trim().split(" ");
            shortForm = outsideWords[outsideWords.length - 1];
        }
        if (!isValidShortForm(shortForm) || !isValidExpression(shortForm) || !isValidExpression(longForm)) {
            return null;
        }
        String[] longWords = longForm.split(" ");
        int sfLength = shortForm.length();
        int keep = Math.min(Math.min(sfLength + 5, sfLength * 2), longWords.length);
        StringBuilder tail = new StringBuilder();
        for (int k = longWords.length - keep; k < longWords.length; k++) {
            tail.append(longWords[k]).append(' ');
        }
        String truncatedLongForm = tail.toString().trim();
        if (shortForm.equalsIgnoreCase(truncatedLongForm)) {
            return null;
        }
        return new Acronym(shortForm, truncatedLongForm);
    }

    /**
     * Returns the suffix of {@code str} that starts at the {@code i}-th word start counted
     * from the end. A word start is position 0 or any position preceded by a
     * non-letter/digit.
     *
     * @param str text to chunk
     * @param i number of trailing words to keep
     * @return the trailing chunk (the whole string if it has fewer word starts)
     */
    protected String chunkLongForm(String str, int i) {
        int wordsSeen = 0;
        int pos = str.length() - 1;
        while (pos >= 0 && wordsSeen < i) {
            if (pos == 0 || !Character.isLetterOrDigit(str.charAt(pos - 1))) {
                wordsSeen++;
            }
            pos--;
        }
        return str.substring(pos + 1, str.length());
    }

    /** Returns {@code true} iff {@code str} is non-null and non-empty. */
    protected boolean isValidExpression(String str) {
        return str != null && !str.isEmpty();
    }

    /** Returns {@code true} iff {@code str} has at most 3 space-separated tokens. */
    protected boolean isShortForm(String str) {
        return str.split(" ").length <= 3;
    }

    /**
     * A valid short form is 1-15 characters long, starts with an ASCII letter or digit, and
     * contains at least one ASCII letter.
     */
    protected boolean isValidShortForm(String str) {
        return str.length() <= 15 && str.length() >= 1
                && STARTS_WITH_ALNUM.matcher(str).matches()
                && CONTAINS_LETTER.matcher(str).matches();
    }

    /** Returns the HMM's emission parameters (method name kept for compatibility). */
    public List<Double> getEmmisions() {
        return this._abbvHmm.getEmmisionParams();
    }

    /** Returns the HMM's transition parameters. */
    public List<Double> getTransitions() {
        return this._abbvHmm.getTransitionParams();
    }

    /**
     * Sets the HMM's starting parameters.
     *
     * @param list first parameter list passed to {@code AbbvGapsHMM.setStartingParams}
     *     (presumably emissions, matching {@link #getEmmisions()}; confirm)
     * @param list2 second parameter list (presumably transitions; confirm)
     */
    public void setStartingParams(List<Double> list, List<Double> list2) {
        this._abbvHmm.setStartingParams(list, list2);
    }
}
