package kylm.model.ngram;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Vector;

import kylm.model.LanguageModel;
import kylm.model.ngram.smoother.NgramSmoother;
import kylm.reader.TextArraySentenceReader;
import kylm.util.KylmTextUtils;

/**
 * Back-off n-gram language model stored as a trie of {@link NgramNode}s hanging off {@link #root}.
 *
 * <p>{@link #countNgrams(Iterable)} fills the trie with counts and an optional {@link NgramSmoother}
 * turns them into scores. During evaluation, node scores and back-off scores are summed, so they
 * are presumably log10 values (NOTE(review): confirm against the smoother implementations).
 * When a class map is set, the trie is keyed by class ids instead of word ids.
 */
public class NgramLM extends LanguageModel implements Serializable {
    private static final long serialVersionUID = 8531298547172592816L;

    /** Maximum n-gram order; a value of -1 leaves {@link #counts} and {@link #hits} unallocated. */
    protected int n;
    /** Number of distinct n-grams observed per order (index 0 holds unigrams). */
    protected int[] counts;
    /** Root of the n-gram trie. */
    protected BranchNode root;
    /** Smoother applied after counting in {@link #trainModel(Iterable)}; may be null. */
    protected NgramSmoother smoother;
    /**
     * Evaluation statistics from {@link #getWordEntropies(int[])}: {@code hits[k]} counts words
     * predicted by a k-gram, and {@code hits[0]} counts words that are not in the vocabulary.
     */
    protected int[] hits;
    /** Number of sentences evaluated by {@link #getWordEntropies(int[])}. */
    protected int sentHits;
    /** Words collected for the unknown-word models while counting; null when there are none. */
    protected HashSet<String> ukWords;

    /**
     * Creates a model of order {@code i} with no smoother.
     *
     * @param i the n-gram order, or -1 to defer allocation until {@link #setN(int)}
     */
    public NgramLM(int i) {
        this(i, null);
    }

    /**
     * Creates a model of order {@code i} that is smoothed with {@code ngramSmoother}.
     *
     * @param i the n-gram order, or -1 to defer allocation until {@link #setN(int)}
     * @param ngramSmoother the smoother to run after counting; may be null
     */
    public NgramLM(int i, NgramSmoother ngramSmoother) {
        this.n = i;
        this.counts = (i != -1) ? new int[i] : null;
        this.hits = (i != -1) ? new int[i + 1] : null;
        this.root = new BranchNode(-1, null);
        this.smoother = ngramSmoother;
        this.sentHits = 0;
        this.ukWords = null;
    }

    /**
     * Computes per-word entropies for a sentence and updates the coverage statistics.
     *
     * <p>Side effects:
     * <ul>
     *   <li>fills {@code wordEnts}, {@code simpleEnts} and {@code classEnts} (one entry per
     *       predicted word, i.e. {@code iArr.length - 1});</li>
     *   <li>increments {@link #sentHits} and {@link #hits};</li>
     *   <li>rewrites entries of {@code iArr} that are not children of the root with the id
     *       returned by {@code findUnknownId}.</li>
     * </ul>
     *
     * @param iArr word ids of the sentence; {@code iArr[0]} is the context start and is not predicted
     * @return the {@code wordEnts} array
     * @throws IllegalArgumentException if a word cannot be found even after backing off to the root
     */
    @Override // kylm.model.LanguageModel
    public float[] getWordEntropies(int[] iArr) {
        // Replace ids with no unigram node by their unknown-word id (mutates the caller's array).
        for (int i = 0; i < iArr.length; i++) {
            if (this.root.getChild(iArr[i]) == null) {
                iArr[i] = findUnknownId(this.vocab.getSymbol(iArr[i]));
            }
        }
        int predicted = iArr.length - 1;
        this.wordEnts = new float[predicted];
        this.simpleEnts = new float[predicted];
        this.classEnts = new float[predicted];

        int[] modelIds;
        if (this.classMap != null) {
            // Class-based model: add P(word | class) now and look up the class sequence in the trie.
            modelIds = new int[iArr.length];
            modelIds[0] = this.classMap.getWordClass(iArr[0]);
            for (int k = 0; k < predicted; k++) {
                int word = iArr[k + 1];
                this.classEnts[k] = this.classMap.getWordProb(word);
                this.wordEnts[k] += this.classEnts[k];
                modelIds[k + 1] = this.classMap.getWordClass(word);
            }
        } else {
            modelIds = iArr;
        }

        this.sentHits++;
        NgramNode node = this.root.getChild(0);
        // Order of the n-gram that a match below the current node would represent.
        int order = 2;
        for (int k = 0; k < predicted; k++) {
            int id = modelIds[k + 1];
            // Nodes without children cannot serve as context; shorten the context without a penalty.
            while (!node.hasChildren()) {
                order--;
                node = node.getFallback();
            }
            while (true) {
                NgramNode next = node.getChild(id);
                if (next != null) {
                    this.hits[isInVocab(id) ? order : 0]++;
                    order++;
                    this.simpleEnts[k] += next.score;
                    this.wordEnts[k] += this.simpleEnts[k];
                    node = next;
                    break;
                }
                // Miss: pay this context's back-off score and retry with the shorter context.
                order--;
                this.simpleEnts[k] += node.getBackoffScore();
                node = node.getFallback();
                if (node == null) {
                    throw new IllegalArgumentException("Could not find word in unigram vocabulary.");
                }
            }
        }
        return this.wordEnts;
    }

    /**
     * Returns the score of {@code iArr[i]} given up to {@code n - 1} preceding words.
     *
     * <p>The longest stored context is used first. If the word does not follow that context, the
     * context's back-off score is added and the shorter fallback context is tried, down to the root.
     * Ids are mapped through the class map when one is set.
     *
     * @param iArr word ids of the sentence
     * @param i index of the word to score
     * @return the accumulated back-off scores plus the matching n-gram's score
     * @throws IllegalArgumentException if the word is not found even at the root
     */
    @Override // kylm.model.LanguageModel
    public float getWordEntropy(int[] iArr, int i) {
        int word = toModelId(iArr[i]);
        NgramNode node = findLongestContext(iArr, i);
        float backoff = 0.0f;
        while (node != null) {
            NgramNode child = node.getChild(word);
            if (child != null) {
                return backoff + child.score;
            }
            if (node == this.root) {
                // Nothing shorter than the empty context; guards against a self-referential fallback.
                break;
            }
            backoff += node.getBackoffScore();
            node = node.getFallback();
        }
        throw new IllegalArgumentException("could not find n-gram");
    }

    /**
     * Finds the trie node for the longest stored context ending just before position {@code i}.
     * Returns the root if no non-empty context is stored.
     */
    private NgramNode findLongestContext(int[] ids, int i) {
        for (int start = Math.max(0, i - this.n + 1); start < i; start++) {
            NgramNode node = this.root;
            for (int k = start; k < i && node != null; k++) {
                node = node.getChild(toModelId(ids[k]));
            }
            if (node != null) {
                return node;
            }
        }
        return this.root;
    }

    /** Maps a word id to the id used as a trie key (its class id when a class map is set). */
    private int toModelId(int wordId) {
        return this.classMap == null ? wordId : this.classMap.getWordClass(wordId);
    }

    /**
     * Returns {@link #getSentenceProb(String[])} divided by {@code -(strArr.length + 2)}.
     *
     * <p>NOTE(review): despite the name, no exponentiation is applied. The result is the negated
     * average score per token. The "+ 2" presumably counts the sentence start/end symbols.
     *
     * @param strArr the words of the sentence
     * @return the negated per-token average score
     */
    public float getSentencePerplexity(String[] strArr) {
        return getSentenceProb(strArr) / (-(strArr.length + 2));
    }

    /**
     * Alias of {@link #getSentencePerplexity(String[])}.
     *
     * @param strArr the words of the sentence
     * @return the same value as {@link #getSentencePerplexity(String[])}
     */
    public float getSentenceProbNormalized(String[] strArr) {
        return getSentencePerplexity(strArr);
    }

    /**
     * Sums the model scores of every word of a sentence after its first id, including back-off
     * scores.
     *
     * <p>The sentence is converted with {@code getSentenceIds}. Ids without a unigram node are
     * replaced by their unknown-word id. Unlike {@link #getWordEntropies(int[])}, this method does
     * not consult the class map and does not update the coverage statistics.
     *
     * @param strArr the words of the sentence
     * @return the summed score of the sentence
     * @throws IllegalArgumentException if a word is not found even at the root
     */
    public float getSentenceProb(String[] strArr) {
        float total = 0.0f;
        int[] ids = getSentenceIds(strArr);
        for (int i = 0; i < ids.length; i++) {
            if (this.root.getChild(ids[i]) == null) {
                ids[i] = findUnknownId(this.vocab.getSymbol(ids[i]));
            }
        }
        NgramNode node = this.root.getChild(0);
        for (int k = 0; k < ids.length - 1; k++) {
            int word = ids[k + 1];
            // Nodes without children cannot serve as context; shorten it without a penalty.
            while (!node.hasChildren()) {
                node = node.getFallback();
            }
            while (true) {
                NgramNode next = node.getChild(word);
                if (next != null) {
                    total += next.score;
                    node = next;
                    break;
                }
                total += node.getBackoffScore();
                node = node.getFallback();
                if (node == null) {
                    throw new IllegalArgumentException("Could not find word in unigram vocabulary.");
                }
            }
        }
        return total;
    }

    /**
     * Counts n-grams, applies the smoother, determines whether the vocabulary is closed and trains
     * the unknown-word models.
     *
     * @param iterable the training sentences
     * @throws IOException if reading the training data fails
     */
    @Override // kylm.model.LanguageModel
    public void trainModel(Iterable<String[]> iterable) throws IOException {
        if (this.debug > 0) {
            System.err.println("NgramLM.trainModel(): Started for " + this.name);
        }
        countNgrams(iterable);
        if (this.smoother != null) {
            this.smoother.smooth(this);
        }
        if (!this.closed) {
            // Open vocabulary iff some id in 2..ukModelCount+1 (the range findUnknownId maps into,
            // per the bucket indexing below) received a unigram node.
            this.closed = true;
            for (int i = 2; i <= this.ukModelCount + 1; i++) {
                if (this.root.getChild(i) != null) {
                    this.closed = false;
                }
            }
        }
        if (this.ukModels != null && this.ukWords != null) {
            // Each unknown-word model is trained on its words, spelled out as space-separated characters.
            List<List<String>> buckets = new ArrayList<>(this.ukModelCount);
            for (int i = 0; i < this.ukModelCount; i++) {
                buckets.add(new ArrayList<String>());
            }
            for (String word : this.ukWords) {
                buckets.get(findUnknownId(word) - 2)
                        .add(KylmTextUtils.join(" ", KylmTextUtils.splitChars(word)));
            }
            for (int i = 0; i < this.ukModels.length; i++) {
                String[] sentences = buckets.get(i).toArray(new String[0]);
                this.ukModels[i].trainModel(new TextArraySentenceReader(sentences));
            }
        }
        if (this.debug > 0) {
            System.err.println("NgramLM.trainModel(): Finished for " + this.name);
        }
    }

    /**
     * Builds the n-gram trie from raw counts.
     *
     * <p>Side effects:
     * <ul>
     *   <li>imports the vocabulary if none is set;</li>
     *   <li>updates node counts and {@link #counts};</li>
     *   <li>collects {@link #ukWords} when unknown-word models exist;</li>
     *   <li>sets the word-in-class probabilities when a class map is set.</li>
     * </ul>
     *
     * @param iterable the training sentences
     * @throws IOException if reading the training data fails
     */
    public void countNgrams(Iterable<String[]> iterable) throws IOException {
        if (this.debug > 0) {
            System.err.println("NgramLM.countNgrams(): Started for " + this.name);
        }
        if (this.vocab == null) {
            importVocabulary(iterable);
        }
        this.root.setChildrenSize(this.classMap == null ? this.vocab.getSize() : this.classMap.getClassSize());
        int[] wordCounts = this.classMap == null ? null : new int[this.vocab.getSize()];
        int[] classCounts = this.classMap == null ? null : new int[this.classMap.getClassSize()];
        // ids[0] is never written, so it stays 0 and acts as the sentence-start id.
        int[] ids = new int[this.maxLength];
        if (this.ukModels != null) {
            this.ukWords = new HashSet<>();
        }
        int sentences = 0;
        for (String[] strArr : iterable) {
            if (this.debug > 0) {
                sentences++;
                if (sentences % 10000 == 0) {
                    System.err.print(sentences % 1000000 == 0 ? Integer.valueOf(sentences) : ".");
                }
            }
            if (strArr.length == 0) {
                continue;
            }
            if (strArr.length + 2 > this.maxLength) {
                this.maxLength = strArr.length + 2;
                ids = new int[this.maxLength];
            }
            int end = 1;
            for (; end <= strArr.length; end++) {
                String word = strArr[end - 1];
                ids[end] = getId(word);
                if (this.ukModels != null && (this.modelAllWords || !isInVocab(word))) {
                    this.ukWords.add(word);
                }
                if (this.classMap != null) {
                    int wordClass = this.classMap.getWordClass(ids[end]);
                    wordCounts[ids[end]]++;
                    classCounts[wordClass]++;
                    ids[end] = wordClass;
                }
            }
            // Skip the implicit leading 0 if the sentence already starts with id 0.
            int first = ids[1] == 0 ? 1 : 0;
            // Append id 0 as the terminator unless the last token already is 0.
            if (ids[strArr.length] != 0) {
                ids[end++] = 0;
            }
            this.root.count += end - first - 1;
            for (int start = first; start < end - 1; start++) {
                NgramNode node = this.root;
                for (int depth = 0; depth < this.n && start + depth < end; depth++) {
                    // Nodes at the maximum order are created with type 1, shorter ones with type 2.
                    node = node.getChild(ids[start + depth], depth == this.n - 1 ? 1 : 2);
                    if (node.count == 0) {
                        this.counts[depth]++;
                    }
                    node.count++;
                }
            }
        }
        // NOTE(review): presumably reserves a unigram slot for id 1 when it never occurred and the
        // terminal symbol differs from the start symbol; confirm the intent.
        if (this.root.getChild(1) == null && !this.terminalSymbol.equals(this.startSymbol)) {
            this.counts[0]++;
        }
        if (this.classMap != null) {
            for (int w = 0; w < wordCounts.length; w++) {
                if (wordCounts[w] > 0) {
                    // Floating-point division: integer division truncated the ratio to 0 or 1.
                    double ratio = (double) wordCounts[w] / classCounts[this.classMap.getWordClass(w)];
                    this.classMap.setWordProb(w, (float) Math.log10(ratio));
                }
            }
        }
        if (this.debug > 0) {
            System.err.println();
            System.err.println("NgramLM.countNgrams(): Finished for " + this.name);
        }
    }

    /** @return the root of the n-gram trie */
    public BranchNode getRoot() {
        return this.root;
    }

    /** @return the maximum n-gram order */
    public int getN() {
        return this.n;
    }

    /**
     * Sets the n-gram order.
     *
     * <p>Reallocates {@link #counts} and {@link #hits} and resets {@link #sentHits}.
     *
     * @param i the new order
     */
    public void setN(int i) {
        this.n = i;
        this.counts = new int[i];
        this.hits = new int[i + 1];
        this.sentHits = 0;
    }

    /**
     * Gives every id that has no unigram node a unigram node with an equal share of the
     * unknown-word symbol's score.
     *
     * <p>The shared score is the unknown symbol's score minus log10 of the number of ids sharing it
     * (the unknown symbol plus every missing id). Those nodes are (re)assigned that score.
     *
     * <p>NOTE(review): only gaps below the largest stored child id are detected. Ids above it are
     * not expanded; confirm whether that is intended.
     */
    public void expandUnknowns() {
        List<Integer> unknownIds = new ArrayList<>();
        unknownIds.add(this.vocab.getId(this.ukSymbol));
        int nextId = 0;
        for (Iterator<NgramNode> it = this.root.iterator(); it.hasNext(); ) {
            NgramNode child = it.next();
            while (nextId < child.id) {
                unknownIds.add(nextId++);
            }
            nextId++;
        }
        float ukScore = this.root.getChild(unknownIds.get(0).intValue()).getScore();
        float score = (float) (ukScore - Math.log10(unknownIds.size()));
        if (this.debug > 0) {
            System.err.println("Expanding " + unknownIds.size() + " unknown words");
        }
        int nodeType = this.n == 1 ? 1 : 2;
        for (int id : unknownIds) {
            this.root.getChild(id, nodeType).setScore(score);
        }
    }

    /**
     * Custom serialization: n, counts, smoother, vocabulary size, then the trie in pre-order.
     * The record order must stay in sync with {@link #readObject(ObjectInputStream)}.
     */
    private void writeObject(ObjectOutputStream objectOutputStream) throws IOException {
        objectOutputStream.writeInt(this.n);
        objectOutputStream.writeObject(this.counts);
        objectOutputStream.writeObject(this.smoother);
        objectOutputStream.writeInt(this.vocab.getSize());
        writeNgrams(objectOutputStream, this.root, 0);
    }

    /**
     * Writes one node record and recurses into its children.
     *
     * <p>Each record is {@code id, score}. Below the maximum depth it is followed by
     * {@code childCount}, and, when the count is non-zero, by {@code backoff} and the children.
     */
    private void writeNgrams(ObjectOutputStream objectOutputStream, NgramNode ngramNode, int i) throws IOException {
        objectOutputStream.writeInt(ngramNode.getId());
        objectOutputStream.writeFloat(ngramNode.getScore());
        if (i == this.n) {
            return;
        }
        if (!ngramNode.hasChildren()) {
            objectOutputStream.writeInt(0);
            return;
        }
        objectOutputStream.writeInt(ngramNode.getChildCount());
        objectOutputStream.writeFloat(ngramNode.getBackoffScore());
        Iterator<NgramNode> it = ngramNode.iterator();
        while (it.hasNext()) {
            writeNgrams(objectOutputStream, it.next(), i + 1);
        }
    }

    /** Mirror of {@link #writeObject(ObjectOutputStream)}; rebuilds the trie from the stream. */
    private void readObject(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        setN(objectInputStream.readInt());
        this.counts = (int[]) objectInputStream.readObject();
        this.smoother = (NgramSmoother) objectInputStream.readObject();
        int vocabSize = objectInputStream.readInt();
        this.root = new BranchNode(-1, null);
        this.root.setChildrenSize(this.classMap == null ? vocabSize : this.classMap.getClassSize());
        // A node's id is read by its parent (here: by this method for the root) before recursing.
        this.root.setId(objectInputStream.readInt());
        readNgrams(objectInputStream, this.root, 0);
    }

    /** Reads the remainder of a node record written by {@code writeNgrams}, whose id was already consumed. */
    private void readNgrams(ObjectInputStream objectInputStream, NgramNode ngramNode, int i) throws IOException {
        ngramNode.setScore(objectInputStream.readFloat());
        if (i == this.n) {
            return;
        }
        int childCount = objectInputStream.readInt();
        if (childCount == 0) {
            return;
        }
        ngramNode.setBackoffScore(objectInputStream.readFloat());
        int childType = i + 1 == this.n ? 1 : 2;
        for (int c = 0; c < childCount; c++) {
            readNgrams(objectInputStream, ngramNode.getChild(objectInputStream.readInt(), childType), i + 1);
        }
    }

    /** Null-safe equality. */
    private static boolean eq(Object obj, Object obj2) {
        return obj == null ? obj2 == null : obj.equals(obj2);
    }

    /** Returns the smoother's runtime class, or null when there is no smoother. */
    private static Class<?> smootherClass(NgramSmoother s) {
        return s == null ? null : s.getClass();
    }

    /**
     * Compares the superclass state, order, trie, smoother class and n-gram counts.
     * Two models without a smoother compare equal on that component.
     */
    @Override // kylm.model.LanguageModel
    public boolean equals(Object obj) {
        if (!(obj instanceof NgramLM)) {
            return false;
        }
        NgramLM other = (NgramLM) obj;
        return super.equals(other)
                && this.n == other.n
                && eq(this.root, other.root)
                && eq(smootherClass(this.smoother), smootherClass(other.smoother))
                && Arrays.equals(this.counts, other.counts);
    }

    /** Consistent with {@link #equals(Object)}: combines the superclass hash with the order. */
    @Override
    public int hashCode() {
        return 31 * super.hashCode() + this.n;
    }

    /** @return the number of distinct n-grams per order (live array, not a copy) */
    public int[] getNgramCounts() {
        return this.counts;
    }

    /** @return the smoother, or null if none is set */
    public NgramSmoother getSmoother() {
        return this.smoother;
    }

    /** @param ngramSmoother the smoother to use on the next training run; may be null */
    public void setSmoother(NgramSmoother ngramSmoother) {
        this.smoother = ngramSmoother;
    }

    /** @param iArr the per-order n-gram counts (stored by reference) */
    public void setNgramCounts(int[] iArr) {
        this.counts = iArr;
    }

    /**
     * Returns the space-separated vocabulary symbols on the path from the root to {@code ngramNode}.
     *
     * @param ngramNode a node of this model's trie
     * @return the symbols, or "" for a node without a parent
     */
    public String getNodeName(NgramNode ngramNode) {
        if (ngramNode.getParent() == null) {
            return "";
        }
        String prefix = getNodeName(ngramNode.getParent());
        String symbol = this.vocab.getSymbol(ngramNode.getId());
        return prefix.length() == 0 ? symbol : prefix + " " + symbol;
    }

    /**
     * Reports, for each order k, the percentage of evaluated words that were predicted by an
     * n-gram of order at least k.
     *
     * @return a one-line coverage report
     */
    @Override // kylm.model.LanguageModel
    public String printReport() {
        StringBuilder sb = new StringBuilder();
        sb.append(this.name).append(" coverage: ");
        // covered[k] = number of words predicted by an n-gram of order >= k.
        int[] covered = new int[this.n + 1];
        covered[this.n] = this.hits[this.n];
        for (int k = this.n - 1; k >= 0; k--) {
            covered[k] = this.hits[k] + covered[k + 1];
        }
        // possible[k]: the denominator for order k. One sentence count is subtracted per order
        // above 2, presumably because words near the sentence start lack that much context.
        int[] possible = new int[this.n + 1];
        possible[0] = covered[0];
        for (int k = 1; k < possible.length; k++) {
            possible[k] = possible[k - 1] - (k > 2 ? this.sentHits : 0);
        }
        for (int k = 1; k <= this.n; k++) {
            sb.append(k).append("-gram ").append((covered[k] * 100.0d) / possible[k]).append("% ");
        }
        return sb.toString();
    }
}
