/*
 * Decompiled with CFR 0.152.
 */
package org.bouncycastle.pqc.crypto.lms;

import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Map;
import java.util.WeakHashMap;
import org.bouncycastle.crypto.Digest;
import org.bouncycastle.pqc.crypto.ExhaustedPrivateKeyException;
import org.bouncycastle.pqc.crypto.lms.Composer;
import org.bouncycastle.pqc.crypto.lms.DigestUtil;
import org.bouncycastle.pqc.crypto.lms.LMOtsParameters;
import org.bouncycastle.pqc.crypto.lms.LMOtsPrivateKey;
import org.bouncycastle.pqc.crypto.lms.LMS;
import org.bouncycastle.pqc.crypto.lms.LMSContext;
import org.bouncycastle.pqc.crypto.lms.LMSContextBasedSigner;
import org.bouncycastle.pqc.crypto.lms.LMSKeyParameters;
import org.bouncycastle.pqc.crypto.lms.LMSPublicKeyParameters;
import org.bouncycastle.pqc.crypto.lms.LMSigParameters;
import org.bouncycastle.pqc.crypto.lms.LM_OTS;
import org.bouncycastle.pqc.crypto.lms.LmsUtils;
import org.bouncycastle.util.Arrays;
import org.bouncycastle.util.io.Streams;

/**
 * Stateful private key for the Leighton-Micali hash-based signature scheme (LMS).
 * <p>
 * Each signature consumes one leaf of the Merkle tree. {@code q} is the index of the next
 * unused leaf and {@code maxQ} is the exclusive upper bound on leaves this key may use.
 * Reads and updates of {@code q} happen while holding this object's monitor.
 * {@link #extractKeyShard(int)} hands a disjoint range of leaves to a new key object that
 * shares this key's tree-node cache.
 */
public class LMSPrivateKeyParameters
extends LMSKeyParameters
implements LMSContextBasedSigner {
    /** Domain-separation constant for leaf node hashes, D_LEAF (RFC 8554). */
    private static final short D_LEAF = (short)0x8282;
    /** Domain-separation constant for interior node hashes, D_INTR (RFC 8554). */
    private static final short D_INTR = (short)0x8383;

    /** Interned cache key for the root node T[1]. */
    private static final CacheKey T1 = new CacheKey(1);
    /**
     * Interned cache keys for node indices 1..128. These instances stay strongly reachable
     * through this array, so their entries in the weak-keyed node cache are not evicted;
     * lookups for larger indices use throw-away keys whose entries may be garbage collected.
     */
    private static final CacheKey[] internedKeys = new CacheKey[129];

    static {
        internedKeys[1] = T1;
        for (int i = 2; i < internedKeys.length; ++i) {
            internedKeys[i] = new CacheKey(i);
        }
    }

    private final byte[] I;
    private final LMSigParameters parameters;
    private final LMOtsParameters otsParameters;
    /** Exclusive upper bound on the leaf indices this key may use. */
    private final int maxQ;
    private final byte[] masterSecret;
    /** Cache of tree node values T[r], shared with extracted shards; its monitor guards access. */
    private final Map<CacheKey, byte[]> tCache;
    /** Node indices strictly below this bound are looked up in / stored to {@link #tCache}. */
    private final int maxCacheR;
    /** Digest used to hash tree nodes; every use is performed while holding its monitor. */
    private final Digest tDigest;
    /** Index of the next unused leaf; guarded by {@code this}. */
    private int q;
    /** Public key, computed lazily by {@link #getPublicKey()} or supplied at parse time. */
    private LMSPublicKeyParameters publicKey;

    /**
     * Creates a private key.
     *
     * @param lmsParameter  the LMS tree parameter set (tree height and digest)
     * @param otsParameters the LM-OTS parameter set used for the leaves
     * @param q             index of the first unused leaf
     * @param I             key identifier; defensively copied
     * @param maxQ          exclusive upper bound on usable leaf indices
     * @param masterSecret  seed from which the one-time keys are derived; defensively copied
     */
    public LMSPrivateKeyParameters(LMSigParameters lmsParameter, LMOtsParameters otsParameters, int q, byte[] I, int maxQ, byte[] masterSecret) {
        super(true);
        this.parameters = lmsParameter;
        this.otsParameters = otsParameters;
        this.q = q;
        this.I = Arrays.clone(I);
        this.maxQ = maxQ;
        this.masterSecret = Arrays.clone(masterSecret);
        // Nodes of a height-h tree are numbered 1 .. 2^(h+1) - 1, so this bound caches all of them.
        this.maxCacheR = 1 << (this.parameters.getH() + 1);
        this.tCache = new WeakHashMap<CacheKey, byte[]>();
        this.tDigest = DigestUtil.getDigest(lmsParameter.getDigestOID());
    }

    /**
     * Creates a shard of {@code parent} restricted to leaves {@code [q, maxQ)}. The shard
     * shares the parent's identifier, secret, public key and node cache, and owns its digest.
     */
    private LMSPrivateKeyParameters(LMSPrivateKeyParameters parent, int q, int maxQ) {
        super(true);
        this.parameters = parent.parameters;
        this.otsParameters = parent.otsParameters;
        this.q = q;
        this.I = parent.I;
        this.maxQ = maxQ;
        this.masterSecret = parent.masterSecret;
        // Same bound as the public constructor: the cache is shared, so cache the same node range.
        this.maxCacheR = 1 << (this.parameters.getH() + 1);
        this.tCache = parent.tCache;
        this.tDigest = DigestUtil.getDigest(this.parameters.getDigestOID());
        this.publicKey = parent.publicKey;
    }

    /**
     * Parses a private key encoding and attaches the given public key encoding to it.
     *
     * @param privEnc encoded private key, as produced by {@link #getEncoded()}
     * @param pubEnc  encoded public key
     * @return the parsed private key
     * @throws IOException if either encoding cannot be read
     */
    public static LMSPrivateKeyParameters getInstance(byte[] privEnc, byte[] pubEnc) throws IOException {
        LMSPrivateKeyParameters pKey = LMSPrivateKeyParameters.getInstance(privEnc);
        pKey.publicKey = LMSPublicKeyParameters.getInstance(pubEnc);
        return pKey;
    }

    /**
     * Converts {@code src} into a private key.
     *
     * @param src an existing instance (returned as-is), a {@link DataInputStream} positioned at
     *            an encoding, a {@code byte[]} encoding, or any other {@link InputStream}
     *            (read fully)
     * @return the private key
     * @throws IOException              if reading the encoding fails or the declared secret
     *                                  length exceeds the available bytes
     * @throws IllegalStateException    if the version is not 0 or the secret length is negative
     * @throws IllegalArgumentException if {@code src} is of an unsupported type
     */
    public static LMSPrivateKeyParameters getInstance(Object src) throws IOException {
        if (src instanceof LMSPrivateKeyParameters) {
            return (LMSPrivateKeyParameters)src;
        }
        if (src instanceof DataInputStream) {
            return readFrom((DataInputStream)src);
        }
        if (src instanceof byte[]) {
            DataInputStream in = new DataInputStream(new ByteArrayInputStream((byte[])src));
            try {
                return readFrom(in);
            }
            finally {
                in.close();
            }
        }
        if (src instanceof InputStream) {
            return LMSPrivateKeyParameters.getInstance(Streams.readAll((InputStream)src));
        }
        throw new IllegalArgumentException("cannot parse " + src);
    }

    /** Decodes the version-0 private key layout written by {@link #getEncoded()}. */
    private static LMSPrivateKeyParameters readFrom(DataInputStream dIn) throws IOException {
        if (dIn.readInt() != 0) {
            throw new IllegalStateException("expected version 0 lms private key");
        }
        LMSigParameters parameter = LMSigParameters.getParametersForType(dIn.readInt());
        LMOtsParameters otsParameter = LMOtsParameters.getParametersForType(dIn.readInt());
        byte[] I = new byte[16];
        dIn.readFully(I);
        int q = dIn.readInt();
        int maxQ = dIn.readInt();
        int l = dIn.readInt();
        if (l < 0) {
            throw new IllegalStateException("secret length less than zero");
        }
        // Reject lengths larger than what is buffered before allocating the array.
        if (l > dIn.available()) {
            throw new IOException("secret length exceeded " + dIn.available());
        }
        byte[] masterSecret = new byte[l];
        dIn.readFully(masterSecret);
        return new LMSPrivateKeyParameters(parameter, otsParameter, q, I, maxQ, masterSecret);
    }

    /**
     * Returns the one-time key for the current leaf without consuming it.
     *
     * @throws ExhaustedPrivateKeyException if no leaves remain
     */
    synchronized LMOtsPrivateKey getCurrentOTSKey() {
        if (this.q >= this.maxQ) {
            throw new ExhaustedPrivateKeyException("ots private keys expired");
        }
        return new LMOtsPrivateKey(this.otsParameters, this.I, this.q, this.masterSecret);
    }

    /** Returns the index of the next unused leaf. */
    public synchronized int getIndex() {
        return this.q;
    }

    /** Advances the leaf index by one. */
    synchronized void incIndex() {
        ++this.q;
    }

    /**
     * Consumes the next leaf and builds a signing context that carries its authentication path.
     *
     * @throws ExhaustedPrivateKeyException if no leaves remain
     */
    @Override
    public LMSContext generateLMSContext() {
        LMSigParameters lmsParameter = this.getSigParameters();
        int h = lmsParameter.getH();
        int leaf;
        LMOtsPrivateKey otsPk;
        // Read the index and consume the OTS key as one atomic step. Otherwise a concurrent
        // signer could advance q in between, and the path below would belong to another leaf.
        synchronized (this) {
            leaf = this.q;
            otsPk = this.getNextOtsPrivateKey();
        }
        int r = (1 << h) + leaf;
        byte[][] path = new byte[h][];
        for (int i = 0; i < h; ++i) {
            // Sibling of the height-i ancestor of leaf node r.
            path[i] = this.findT((r / (1 << i)) ^ 1);
        }
        return otsPk.getSignatureContext(lmsParameter, path);
    }

    /**
     * Finishes a signature for a context produced by {@link #generateLMSContext()}.
     *
     * @throws IllegalStateException if the signature cannot be encoded
     */
    @Override
    public byte[] generateSignature(LMSContext context) {
        try {
            return LMS.generateSign(context).getEncoded();
        }
        catch (IOException e) {
            throw new IllegalStateException("unable to encode signature: " + e.getMessage(), e);
        }
    }

    /**
     * Returns the one-time key for the current leaf and advances the index.
     *
     * @throws ExhaustedPrivateKeyException if no leaves remain
     */
    synchronized LMOtsPrivateKey getNextOtsPrivateKey() {
        if (this.q >= this.maxQ) {
            throw new ExhaustedPrivateKeyException("ots private key exhausted");
        }
        LMOtsPrivateKey otsPrivateKey = new LMOtsPrivateKey(this.otsParameters, this.I, this.q, this.masterSecret);
        this.incIndex();
        return otsPrivateKey;
    }

    /**
     * Splits off the next {@code usageCount} leaves into a new key and advances this key past
     * them, so the two keys never sign with the same leaf.
     *
     * @param usageCount number of leaves to hand to the shard; {@code 0 <= usageCount <= getUsagesRemaining()}
     * @return a key restricted to leaves {@code [getIndex(), getIndex() + usageCount)}
     * @throws IllegalArgumentException if {@code usageCount} is negative or exceeds the usages remaining
     */
    public synchronized LMSPrivateKeyParameters extractKeyShard(int usageCount) {
        // A negative count would move q backwards and re-issue one-time keys.
        if (usageCount < 0) {
            throw new IllegalArgumentException("usageCount cannot be negative");
        }
        // Compare in long arithmetic so q + usageCount cannot overflow past the check.
        if ((long)usageCount > (long)this.maxQ - (long)this.q) {
            throw new IllegalArgumentException("usageCount exceeds usages remaining");
        }
        LMSPrivateKeyParameters keyParameters = new LMSPrivateKeyParameters(this, this.q, this.q + usageCount);
        this.q += usageCount;
        return keyParameters;
    }

    /** Returns the LMS tree parameter set. */
    public LMSigParameters getSigParameters() {
        return this.parameters;
    }

    /** Returns the LM-OTS parameter set. */
    public LMOtsParameters getOtsParameters() {
        return this.otsParameters;
    }

    /** Returns a copy of the key identifier. */
    public byte[] getI() {
        return Arrays.clone(this.I);
    }

    /** Returns a copy of the master secret. */
    public byte[] getMasterSecret() {
        return Arrays.clone(this.masterSecret);
    }

    /** Returns how many leaves this key may still sign with. */
    @Override
    public long getUsagesRemaining() {
        return this.maxQ - this.getIndex();
    }

    /** Returns the public key, computing the tree root on first use. */
    public synchronized LMSPublicKeyParameters getPublicKey() {
        if (this.publicKey == null) {
            this.publicKey = new LMSPublicKeyParameters(this.parameters, this.otsParameters, this.findT(T1), this.I);
        }
        return this.publicKey;
    }

    /** Returns the value of tree node {@code r} (1-based; root is 1), using the cache when in range. */
    byte[] findT(int r) {
        if (r < this.maxCacheR) {
            CacheKey key = r < internedKeys.length ? internedKeys[r] : new CacheKey(r);
            return this.findT(key);
        }
        return this.calcT(r);
    }

    /** Cache lookup for a node, computing and storing it on a miss while holding the cache lock. */
    private byte[] findT(CacheKey key) {
        synchronized (this.tCache) {
            byte[] t = this.tCache.get(key);
            if (t == null) {
                t = this.calcT(key.index);
                this.tCache.put(key, t);
            }
            return t;
        }
    }

    /**
     * Computes node {@code r}: a leaf hashes the OTS public key of leaf {@code r - 2^h} under
     * D_LEAF; an interior node hashes its two children under D_INTR.
     */
    private byte[] calcT(int r) {
        int twoToh = 1 << this.getSigParameters().getH();
        if (r >= twoToh) {
            // Derive the OTS public key before taking the digest lock so the lock is only held
            // for the short hashing step below.
            byte[] K = LM_OTS.lms_ots_generatePublicKey(this.getOtsParameters(), this.getI(), r - twoToh, this.getMasterSecret());
            synchronized (this.tDigest) {
                LmsUtils.byteArray(this.getI(), this.tDigest);
                LmsUtils.u32str(r, this.tDigest);
                LmsUtils.u16str(D_LEAF, this.tDigest);
                LmsUtils.byteArray(K, this.tDigest);
                return this.finishDigest();
            }
        }
        // Children first: this may recurse into the cache, and must not happen under the digest lock.
        byte[] t2r = this.findT(2 * r);
        byte[] t2rPlus1 = this.findT(2 * r + 1);
        synchronized (this.tDigest) {
            LmsUtils.byteArray(this.getI(), this.tDigest);
            LmsUtils.u32str(r, this.tDigest);
            LmsUtils.u16str(D_INTR, this.tDigest);
            LmsUtils.byteArray(t2r, this.tDigest);
            LmsUtils.byteArray(t2rPlus1, this.tDigest);
            return this.finishDigest();
        }
    }

    /** Completes the pending {@link #tDigest} computation; the caller must hold its monitor. */
    private byte[] finishDigest() {
        byte[] T = new byte[this.tDigest.getDigestSize()];
        this.tDigest.doFinal(T, 0);
        return T;
    }

    /**
     * Two keys are equal when index, bound, identifier, parameter sets and secret match. Public
     * keys are compared only when both are present.
     */
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        LMSPrivateKeyParameters that = (LMSPrivateKeyParameters)o;
        if (this.q != that.q) {
            return false;
        }
        if (this.maxQ != that.maxQ) {
            return false;
        }
        if (!Arrays.areEqual(this.I, that.I)) {
            return false;
        }
        if (this.parameters != null ? !this.parameters.equals(that.parameters) : that.parameters != null) {
            return false;
        }
        if (this.otsParameters != null ? !this.otsParameters.equals(that.otsParameters) : that.otsParameters != null) {
            return false;
        }
        if (!Arrays.areEqual(this.masterSecret, that.masterSecret)) {
            return false;
        }
        if (this.publicKey != null && that.publicKey != null) {
            return this.publicKey.equals(that.publicKey);
        }
        return true;
    }

    /**
     * Hashes only the fields that {@link #equals(Object)} always compares. The optional,
     * lazily-set public key is deliberately excluded: equals ignores it when either side is
     * null, so including it would let equal objects hash differently.
     */
    public int hashCode() {
        int result = this.q;
        result = 31 * result + Arrays.hashCode(this.I);
        result = 31 * result + (this.parameters != null ? this.parameters.hashCode() : 0);
        result = 31 * result + (this.otsParameters != null ? this.otsParameters.hashCode() : 0);
        result = 31 * result + this.maxQ;
        result = 31 * result + Arrays.hashCode(this.masterSecret);
        return result;
    }

    /**
     * Encodes as: version 0, LMS type, LM-OTS type, I, q, maxQ, secret length, secret.
     */
    @Override
    public byte[] getEncoded() throws IOException {
        return Composer.compose()
            .u32str(0)
            .u32str(this.parameters.getType())
            .u32str(this.otsParameters.getType())
            .bytes(this.I)
            .u32str(this.getIndex())
            .u32str(this.maxQ)
            .u32str(this.masterSecret.length)
            .bytes(this.masterSecret)
            .build();
    }

    /** Value-based key for the node cache, identified by the 1-based node index. */
    private static class CacheKey {
        private final int index;

        CacheKey(int index) {
            this.index = index;
        }

        public int hashCode() {
            return this.index;
        }

        public boolean equals(Object o) {
            if (o instanceof CacheKey) {
                return ((CacheKey)o).index == this.index;
            }
            return false;
        }
    }
}

