package kylm.model;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Vector;
import java.util.regex.Pattern;
import kylm.util.KylmMathUtils;
import kylm.util.KylmTextUtils;
import kylm.util.SymbolSet;

/**
 * Abstract base class for language models.
 *
 * <p>It owns the vocabulary ({@link SymbolSet}), the start/terminal/unknown
 * symbols, the optional unknown-word sub-models and the per-sentence entropy
 * buffers. Concrete subclasses supply training, reporting and the per-word
 * entropy computation.
 *
 * <p>Vocabulary id layout, as used by the methods of this class: ids below 2
 * are the ids produced for the start and terminal symbols by
 * {@link #initializeVocab()}, ids {@code 2 .. ukModelCount + 1} are reserved
 * for the unknown-word symbol(s), and anything above that is a regular word.
 *
 * <p>The entropy accessors ({@link #getUnknownEntropies()} and friends) return
 * the buffers filled by the most recent entropy computation, so an instance
 * keeps per-call state.
 */
public abstract class LanguageModel implements Serializable {
    private static final long serialVersionUID = 4882315303447910570L;

    /** Optional class map; only stored, compared and serialized by this class. */
    protected ClassMap classMap;
    /** Verbosity level; values above 0 enable progress messages on stderr. */
    protected int debug = 0;
    /** Symbol representing this model when it is used as an unknown-word model. */
    protected String symbol = null;
    /** Name of the model, used in debug output. */
    protected String name = null;
    /** Pattern used to decide whether this model handles a given unknown word. */
    protected Pattern regex = null;
    /** When true, sentence id arrays get a trailing slot in addition to the leading one. */
    protected boolean countTerminals = true;
    /** Longest sentence length (plus 2) seen by {@link #importVocabulary(Iterable)}. */
    protected int maxLength = 0;
    /** When true, encountering an unknown word is an error. */
    protected boolean closed = false;
    /** The vocabulary; created by {@link #initializeVocab()} (lazily in {@link #getVocab()}). */
    protected SymbolSet vocab = null;
    /** Words must occur strictly more often than this to be imported into the vocabulary. */
    protected int vocabFrequency = 1;
    /** Assumed total vocabulary size used to score unknown words when no unknown models exist. */
    protected int vocabLimit = 0;
    protected String startSymbol = "<s>";
    protected String terminalSymbol = "</s>";
    protected String ukSymbol = "<unk>";
    /** Sub-models scoring unknown words, or null to use the single {@link #ukSymbol}. */
    protected LanguageModel[] ukModels = null;
    /** Number of vocabulary ids reserved for unknown words (1 when {@link #ukModels} is null). */
    protected int ukModelCount = 1;
    protected boolean modelAllWords = true;
    /** Entropy buffers filled by the most recent entropy computation. */
    protected float[] wordEnts = null;
    protected float[] simpleEnts = null;
    protected float[] unkEnts = null;
    protected float[] classEnts = null;

    /** Null-safe {@code equals}. */
    private static boolean eq(Object a, Object b) {
        return a == null ? b == null : a.equals(b);
    }

    /** Null-safe {@code hashCode}. */
    private static int hash(Object o) {
        return o == null ? 0 : o.hashCode();
    }

    /**
     * Compares two patterns by source text and flags.
     * {@link Pattern} does not override {@code equals}, so a plain comparison
     * would be by identity and could never match a deserialized copy.
     */
    private static boolean samePattern(Pattern a, Pattern b) {
        if (a == null || b == null) {
            return a == b;
        }
        return a.flags() == b.flags() && a.pattern().equals(b.pattern());
    }

    /**
     * Compares the configuration shared by all language models: the closed and
     * count-terminals flags, vocabulary limit, unknown-model count, symbol,
     * name, regex, vocabulary symbols, start/terminal/unknown symbols, class
     * map and unknown-word models.
     *
     * @param obj the object to compare with
     * @return true if {@code obj} is a {@code LanguageModel} with the same
     *         configuration; false otherwise, including when any nested
     *         comparison throws
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof LanguageModel)) {
            return false;
        }
        LanguageModel other = (LanguageModel) obj;
        try {
            if (this.closed != other.closed
                    || this.countTerminals != other.countTerminals
                    || this.vocabLimit != other.vocabLimit
                    || this.ukModelCount != other.ukModelCount) {
                return false;
            }
            if (!eq(this.symbol, other.symbol)
                    || !eq(this.name, other.name)
                    || !samePattern(this.regex, other.regex)) {
                return false;
            }
            if (this.vocab == null) {
                if (other.vocab != null) {
                    return false;
                }
            } else if (other.vocab == null || !this.vocab.syms.equals(other.vocab.syms)) {
                return false;
            }
            return eq(this.startSymbol, other.startSymbol)
                    && eq(this.terminalSymbol, other.terminalSymbol)
                    && eq(this.ukSymbol, other.ukSymbol)
                    && eq(this.classMap, other.classMap)
                    && Arrays.equals(this.ukModels, other.ukModels);
        } catch (RuntimeException ignored) {
            // The compiled original treated any failure during comparison
            // as "not equal"; that contract is kept.
            return false;
        }
    }

    /**
     * Hash code consistent with {@link #equals(Object)}. Only the cheap scalar
     * and string fields are mixed in; the vocabulary and the unknown-word
     * models are deliberately left out so hashing stays independent of
     * vocabulary size.
     */
    @Override
    public int hashCode() {
        int h = this.closed ? 1231 : 1237;
        h = 31 * h + (this.countTerminals ? 1231 : 1237);
        h = 31 * h + this.vocabLimit;
        h = 31 * h + this.ukModelCount;
        h = 31 * h + hash(this.symbol);
        h = 31 * h + hash(this.name);
        h = 31 * h + hash(this.startSymbol);
        h = 31 * h + hash(this.terminalSymbol);
        h = 31 * h + hash(this.ukSymbol);
        return h;
    }

    /**
     * Returns the entropy of the word at position {@code pos} of {@code ids}.
     *
     * @param ids sentence as vocabulary ids (see {@link #getSentenceIds(String[])})
     * @param pos index into {@code ids} of the word to score
     */
    public abstract float getWordEntropy(int[] ids, int pos);

    /**
     * Scores every position of {@code ids} except the first.
     *
     * @param ids sentence as vocabulary ids; index 0 is not scored
     * @return array of length {@code ids.length - 1}, also kept as the
     *         current word-entropy buffer
     */
    public float[] getWordEntropies(int[] ids) {
        float[] ents = new float[ids.length - 1];
        this.wordEnts = ents;
        for (int i = 0; i < ents.length; i++) {
            ents[i] = getWordEntropy(ids, i + 1);
        }
        return ents;
    }

    /**
     * Scores a sentence given as words. Each out-of-vocabulary word is
     * additionally charged the entropy of spelling it out: through the
     * matching unknown-word model when such models are set, otherwise through
     * a uniform distribution over the {@code vocabLimit - vocabSize} unseen
     * words when a vocabulary limit is set.
     *
     * @param words the sentence, without start/terminal symbols
     * @return per-word entropies (same array as the word-entropy buffer)
     * @throws IllegalArgumentException if a vocabulary limit is set and the
     *         vocabulary is already at least that large
     */
    public float[] getWordEntropies(String[] words) {
        int[] ids = getSentenceIds(words);
        this.wordEnts = getWordEntropies(ids);
        this.unkEnts = new float[words.length + 1];
        if (this.ukModels != null) {
            for (int i = 0; i < words.length; i++) {
                if (isInVocab(words[i])) {
                    continue;
                }
                // Unknown ids start at 2, so (id - 2) selects the unknown model.
                LanguageModel ukModel = this.ukModels[ids[i + 1] - 2];
                this.unkEnts[i] = ukModel.getSentenceEntropy(KylmTextUtils.splitChars(words[i]));
                this.wordEnts[i] += this.unkEnts[i];
            }
        } else if (this.vocabLimit > 0) {
            int unseen = this.vocabLimit - this.vocab.getSize();
            if (unseen <= 0) {
                throw new IllegalArgumentException("vocab size has exceeded the vocab size limit");
            }
            float uniform = (float) Math.log10(1.0d / unseen);
            for (int i = 0; i < words.length; i++) {
                if (!isInVocab(words[i])) {
                    this.unkEnts[i] += uniform;
                    this.wordEnts[i] += uniform;
                }
            }
        }
        return this.wordEnts;
    }

    /** Returns the class-entropy buffer (may be null). */
    public float[] getClassEntropies() {
        return this.classEnts;
    }

    /** Returns the simple-entropy buffer (may be null). */
    public float[] getSimpleEntropies() {
        return this.simpleEnts;
    }

    /** Returns the unknown-word entropy buffer of the last string-based scoring call (may be null). */
    public float[] getUnknownEntropies() {
        return this.unkEnts;
    }

    /** Returns the sum of {@link #getWordEntropies(String[])} for the sentence. */
    public float getSentenceEntropy(String[] words) {
        return KylmMathUtils.sum(getWordEntropies(words));
    }

    /** Returns the sum of the simple-entropy buffer. */
    public float getSentenceSimpleEntropy() {
        return KylmMathUtils.sum(this.simpleEnts);
    }

    /** Returns the sum of the unknown-word entropy buffer. */
    public float getSentenceUnknownEntropy() {
        return KylmMathUtils.sum(this.unkEnts);
    }

    /** Returns the sum of the class-entropy buffer. */
    public float getSentenceClassEntropy() {
        return KylmMathUtils.sum(this.classEnts);
    }

    /** Trains the model on the given sentences. */
    public abstract void trainModel(Iterable<String[]> sentences) throws IOException;

    /** Returns a textual report about the model. */
    public abstract String printReport();

    /**
     * Replaces the vocabulary with the reserved symbols followed by
     * {@code words}, in the given order.
     */
    public void setVocabulary(String[] words) {
        initializeVocab();
        for (String word : words) {
            this.vocab.addSymbol(word);
        }
    }

    /**
     * Creates a fresh vocabulary holding only the reserved symbols: start and
     * terminal first, then either the single unknown symbol or one symbol per
     * unknown-word model.
     *
     * @throws IllegalArgumentException if an unknown-word model has no symbol
     *         or its symbol collides with the start/terminal symbol
     */
    private void initializeVocab() {
        if (this.vocab != null && this.debug > 0) {
            System.err.println("WARNING: Resetting vocabulary, could cause alignment problems");
        }
        // The exact call sequence on SymbolSet is kept from the original; the
        // id assignment depends on it.
        SymbolSet fresh = new SymbolSet();
        this.vocab = fresh;
        fresh.addSymbol(this.startSymbol);
        fresh.addSymbol(this.terminalSymbol);
        fresh.addAlias(this.terminalSymbol, fresh.getId(this.startSymbol).intValue());
        if (fresh.getSize() == 1) {
            fresh.pushSymbol(this.terminalSymbol);
        }
        if (this.ukModels == null) {
            fresh.addSymbol(this.ukSymbol);
            return;
        }
        for (LanguageModel ukModel : this.ukModels) {
            String ukModelSymbol = ukModel.getSymbol();
            if (ukModelSymbol == null) {
                throw new IllegalArgumentException("A symbol must be specified for each unknown word model");
            }
            if (ukModelSymbol.equals(this.terminalSymbol) || ukModelSymbol.equals(this.startSymbol)) {
                throw new IllegalArgumentException("Unknown model and terminal/start symbols cannot be equal");
            }
            fresh.addSymbol(ukModel.getSymbol());
        }
    }

    /**
     * Returns true if {@code word} has a vocabulary id and that id is not one
     * of the unknown-word ids.
     */
    public boolean isInVocab(String word) {
        Integer id = this.vocab.getId(word);
        return id != null && isInVocab(id.intValue());
    }

    /** Returns true unless {@code id} lies in the unknown range {@code 2 .. ukModelCount + 1}. */
    public boolean isInVocab(int id) {
        return id < 2 || id > this.ukModelCount + 1;
    }

    /**
     * Converts a sentence to vocabulary ids. Slot 0 is left at 0 and the words
     * occupy slots {@code 1 .. words.length}; when terminals are counted one
     * further trailing slot, also 0, is appended. Unknown words are resolved
     * through {@link #getId(String)}.
     */
    public int[] getSentenceIds(String[] words) {
        int padding = this.countTerminals ? 2 : 1;
        int[] ids = new int[words.length + padding];
        for (int i = 0; i < words.length; i++) {
            ids[i + 1] = getId(words[i]);
        }
        return ids;
    }

    /** Returns the symbols of the current vocabulary. */
    public String[] getVocabulary() {
        return this.vocab.getSymbols();
    }

    /**
     * Builds the vocabulary from a corpus: counts every word, resets the
     * vocabulary to the reserved symbols and then adds, in sorted order, every
     * word that occurs strictly more than {@code vocabFrequency} times. Also
     * raises {@code maxLength} to the longest sentence length plus 2.
     *
     * @param sentences the corpus, one word array per sentence
     */
    public void importVocabulary(Iterable<String[]> sentences) throws IOException {
        if (this.debug > 0) {
            System.err.println("LanguageModel.importVocabulary(): Started for " + this.name);
        }
        Map<String, Integer> counts = new HashMap<String, Integer>();
        for (String[] sentence : sentences) {
            this.maxLength = Math.max(sentence.length + 2, this.maxLength);
            for (String word : sentence) {
                Integer seen = counts.get(word);
                counts.put(word, Integer.valueOf(seen == null ? 1 : seen.intValue() + 1));
            }
        }
        if (this.debug > 0) {
            System.err.println("LanguageModel.importVocabulary(): Vocab " + counts.size() + " before trimming");
        }
        initializeVocab();
        List<String> kept = new ArrayList<String>();
        for (Map.Entry<String, Integer> entry : counts.entrySet()) {
            if (entry.getValue().intValue() > this.vocabFrequency) {
                kept.add(entry.getKey());
            }
        }
        // Sorting makes the id assignment independent of hash iteration order.
        Collections.sort(kept);
        for (String word : kept) {
            this.vocab.addSymbol(word);
        }
        if (this.debug > 0) {
            System.err.println("LanguageModel.importVocabulary(): Finished with size " + this.vocab.getSize() + " for " + this.name);
        }
    }

    /**
     * Picks the unknown-word id for {@code word}: 2 when no unknown-word
     * models are set, otherwise {@code index + 2} of the first model whose
     * regex is null or matches the whole word.
     *
     * @throws IllegalArgumentException if the model is closed or no
     *         unknown-word model matches
     */
    public int findUnknownId(String word) {
        if (this.closed) {
            throw new IllegalArgumentException("Unknown word " + word + " found in closed model");
        }
        if (this.ukModels == null) {
            return 2;
        }
        for (int i = 0; i < this.ukModels.length; i++) {
            Pattern candidate = this.ukModels[i].getRegex();
            if (candidate == null || candidate.matcher(word).matches()) {
                return i + 2;
            }
        }
        throw new IllegalArgumentException("No unknown word model found to match " + word);
    }

    /**
     * Returns the vocabulary id of {@code word}. A word without an id is
     * mapped to an unknown-word id, and that mapping is registered in the
     * vocabulary as an alias, so this lookup can modify the vocabulary.
     */
    public int getId(String word) {
        Integer id = this.vocab.getId(word);
        if (id != null) {
            return id.intValue();
        }
        int unknownId = findUnknownId(word);
        this.vocab.addAlias(word, unknownId);
        return unknownId;
    }

    /**
     * Custom serialization. The field order is the on-disk format and must
     * stay in sync with {@link #readObject(ObjectInputStream)}; note that
     * {@code symbol} is written twice, which existing files rely on.
     */
    private void writeObject(ObjectOutputStream out) throws IOException {
        out.writeInt(this.debug);
        out.writeObject(this.symbol);
        out.writeObject(this.name);
        out.writeObject(this.regex);
        out.writeBoolean(this.closed);
        out.writeBoolean(this.countTerminals);
        out.writeInt(this.maxLength);
        out.writeObject(this.symbol);
        out.writeObject(this.vocab);
        out.writeInt(this.vocabFrequency);
        out.writeInt(this.vocabLimit);
        out.writeObject(this.startSymbol);
        out.writeObject(this.terminalSymbol);
        out.writeObject(this.ukSymbol);
        out.writeObject(this.ukModels);
        out.writeObject(this.classMap);
    }

    /**
     * Reads the next object from the stream. If reading it raises a
     * {@link NullPointerException}, {@code fallback} is returned instead, so
     * the target field keeps its current value as in the original code.
     */
    private static Object readOrKeep(ObjectInputStream in, Object fallback)
            throws IOException, ClassNotFoundException {
        try {
            return in.readObject();
        } catch (NullPointerException ignored) {
            return fallback;
        }
    }

    /**
     * Custom deserialization mirroring {@link #writeObject(ObjectOutputStream)}.
     * {@code ukModelCount} is not stored; it is derived from the restored
     * unknown-word models.
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        this.debug = in.readInt();
        this.symbol = (String) readOrKeep(in, this.symbol);
        this.name = (String) readOrKeep(in, this.name);
        this.regex = (Pattern) readOrKeep(in, this.regex);
        this.closed = in.readBoolean();
        this.countTerminals = in.readBoolean();
        this.maxLength = in.readInt();
        // Second copy of symbol, present in the stream format.
        this.symbol = (String) readOrKeep(in, this.symbol);
        this.vocab = (SymbolSet) readOrKeep(in, this.vocab);
        this.vocabFrequency = in.readInt();
        this.vocabLimit = in.readInt();
        this.startSymbol = (String) readOrKeep(in, this.startSymbol);
        this.terminalSymbol = (String) readOrKeep(in, this.terminalSymbol);
        this.ukSymbol = (String) readOrKeep(in, this.ukSymbol);
        this.ukModels = (LanguageModel[]) readOrKeep(in, this.ukModels);
        this.ukModelCount = this.ukModels != null ? this.ukModels.length : 1;
        this.classMap = (ClassMap) readOrKeep(in, this.classMap);
    }

    public int getDebug() {
        return this.debug;
    }

    public void setDebug(int debug) {
        this.debug = debug;
    }

    public String getSymbol() {
        return this.symbol;
    }

    public void setSymbol(String symbol) {
        this.symbol = symbol;
    }

    public String getName() {
        return this.name;
    }

    public void setName(String name) {
        this.name = name;
    }

    /** Returns the vocabulary, creating the reserved-symbol vocabulary first if none exists. */
    public SymbolSet getVocab() {
        if (this.vocab == null) {
            initializeVocab();
        }
        return this.vocab;
    }

    /**
     * Resets the vocabulary to the reserved symbols and then adds the symbols
     * of {@code symbolSet}; the passed set itself is not retained.
     */
    public void setVocab(SymbolSet symbolSet) {
        initializeVocab();
        this.vocab.addSymbols(symbolSet.getSymbols());
    }

    public String getStartSymbol() {
        return this.startSymbol;
    }

    public void setStartSymbol(String startSymbol) {
        this.startSymbol = startSymbol;
    }

    public String getTerminalSymbol() {
        return this.terminalSymbol;
    }

    public void setTerminalSymbol(String terminalSymbol) {
        this.terminalSymbol = terminalSymbol;
    }

    public String getUnknownSymbol() {
        return this.ukSymbol;
    }

    public void setUnknownSymbol(String ukSymbol) {
        this.ukSymbol = ukSymbol;
    }

    public LanguageModel[] getUnknownModels() {
        return this.ukModels;
    }

    /**
     * Sets the unknown-word models and keeps {@code ukModelCount} in step with
     * them (1 when {@code ukModels} is null), the same rule deserialization
     * applies. Without this, ids of the second and later unknown models would
     * be treated as in-vocabulary by {@link #isInVocab(int)}.
     */
    public void setUnknownModels(LanguageModel[] ukModels) {
        this.ukModels = ukModels;
        this.ukModelCount = ukModels != null ? ukModels.length : 1;
    }

    public boolean isClosed() {
        return this.closed;
    }

    public void setClosed(boolean closed) {
        this.closed = closed;
    }

    public boolean getCountTerminals() {
        return this.countTerminals;
    }

    public void setCountTerminals(boolean countTerminals) {
        this.countTerminals = countTerminals;
    }

    public int getVocabFrequency() {
        return this.vocabFrequency;
    }

    public void setVocabFrequency(int vocabFrequency) {
        this.vocabFrequency = vocabFrequency;
    }

    public Pattern getRegex() {
        return this.regex;
    }

    /** Compiles {@code regex} and stores it as this model's pattern. */
    public void setRegex(String regex) {
        this.regex = Pattern.compile(regex);
    }

    public int getVocabLimit() {
        return this.vocabLimit;
    }

    public void setVocabLimit(int vocabLimit) {
        this.vocabLimit = vocabLimit;
    }

    public int getMaxLength() {
        return this.maxLength;
    }

    public void setMaxLength(int maxLength) {
        this.maxLength = maxLength;
    }

    public ClassMap getClassMap() {
        return this.classMap;
    }

    public void setClassMap(ClassMap classMap) {
        this.classMap = classMap;
    }

    public int getUnknownModelCount() {
        return this.ukModelCount;
    }
}
