/* GENERATED SOURCE. DO NOT MODIFY. */
// © 2021 and later: Unicode, Inc. and others.
// License & terms of use: http://www.unicode.org/copyright.html
//
/**
 * An LSTMBreakEngine
 */
package android.icu.impl.breakiter;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.text.CharacterIterator;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import android.icu.impl.ICUData;
import android.icu.impl.ICUResourceBundle;
import android.icu.lang.UCharacter;
import android.icu.lang.UProperty;
import android.icu.lang.UScript;
import android.icu.text.BreakIterator;
import android.icu.text.UnicodeSet;
import android.icu.util.UResourceBundle;

/**
 * @hide Only a subset of ICU is exposed in Android
 * @hide draft / provisional / internal are hidden on Android
 */
public class LSTMBreakEngine extends DictionaryBreakEngine {
    /**
     * @hide Only a subset of ICU is exposed in Android
     */
    public enum EmbeddingType {
        UNKNOWN,
        CODE_POINTS,
        GRAPHEME_CLUSTER
    }

    /**
     * @hide Only a subset of ICU is exposed in Android
     */
    public enum LSTMClass {
        BEGIN,
        INSIDE,
        END,
        SINGLE,
    }

    private static float[][] make2DArray(int[] data, int start, int d1, int d2) {
        byte[] bytes = new byte[4];
        float[][] result = new float[d1][d2];
        for (int i = 0; i < d1; i++) {
            for (int j = 0; j < d2; j++) {
                int d = data[start++];
                bytes[0] = (byte) (d >> 24);
                bytes[1] = (byte) (d >> 16);
                bytes[2] = (byte) (d >> 8);
                bytes[3] = (byte) (d /*>> 0*/);
                result[i][j] = ByteBuffer.wrap(bytes).order(ByteOrder.BIG_ENDIAN).getFloat();
            }
        }
        return result;
    }

    private static float[] make1DArray(int[] data, int start, int d1) {
        byte[] bytes = new byte[4];
        float[] result = new float[d1];
        for (int i = 0; i < d1; i++) {
            int d = data[start++];
            bytes[0] = (byte) (d >> 24);
            bytes[1] = (byte) (d >> 16);
            bytes[2] = (byte) (d >> 8);
            bytes[3] = (byte) (d /*>> 0*/);
            result[i] = ByteBuffer.wrap(bytes).order(ByteOrder.BIG_ENDIAN).getFloat();
        }
        return result;
    }
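    // Note on make2DArray / make1DArray: each int in the resource's "data" vector stores the
    // raw IEEE-754 bit pattern of a 32-bit float, so the big-endian byte round-trip above is
    // equivalent to reinterpreting the bits directly (illustrative sketch):
    //
    //     float f = Float.intBitsToFloat(d);  // same value as the ByteBuffer conversion
    //
    // The explicit ByteBuffer form keeps the assumed big-endian layout visible at the call site.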
    /**
     * @hide Only a subset of ICU is exposed in Android
     * @hide draft / provisional / internal are hidden on Android
     */
    public static class LSTMData {
        private LSTMData() {
        }

        public LSTMData(UResourceBundle rb) {
            int embeddings = rb.get("embeddings").getInt();
            int hunits = rb.get("hunits").getInt();
            this.fType = EmbeddingType.UNKNOWN;
            this.fName = rb.get("model").getString();
            String typeString = rb.get("type").getString();
            if (typeString.equals("codepoints")) {
                this.fType = EmbeddingType.CODE_POINTS;
            } else if (typeString.equals("graphclust")) {
                this.fType = EmbeddingType.GRAPHEME_CLUSTER;
            }
            String[] dict = rb.get("dict").getStringArray();
            int[] data = rb.get("data").getIntVector();
            int dataLen = data.length;
            int numIndex = dict.length;
            fDict = new HashMap<String, Integer>(numIndex + 1);
            int idx = 0;
            for (String embedding : dict) {
                fDict.put(embedding, idx++);
            }
            int mat1Size = (numIndex + 1) * embeddings;
            int mat2Size = embeddings * 4 * hunits;
            int mat3Size = hunits * 4 * hunits;
            int mat4Size = 4 * hunits;
            int mat5Size = mat2Size;
            int mat6Size = mat3Size;
            int mat7Size = mat4Size;
            int mat8Size = 2 * hunits * 4;
            int mat9Size = 4;
            assert dataLen == mat1Size + mat2Size + mat3Size + mat4Size + mat5Size + mat6Size
                    + mat7Size + mat8Size + mat9Size;
            int start = 0;
            this.fEmbedding = make2DArray(data, start, (numIndex + 1), embeddings);
            start += mat1Size;
            this.fForwardW = make2DArray(data, start, embeddings, 4 * hunits);
            start += mat2Size;
            this.fForwardU = make2DArray(data, start, hunits, 4 * hunits);
            start += mat3Size;
            this.fForwardB = make1DArray(data, start, 4 * hunits);
            start += mat4Size;
            this.fBackwardW = make2DArray(data, start, embeddings, 4 * hunits);
            start += mat5Size;
            this.fBackwardU = make2DArray(data, start, hunits, 4 * hunits);
            start += mat6Size;
            this.fBackwardB = make1DArray(data, start, 4 * hunits);
            start += mat7Size;
            this.fOutputW = make2DArray(data, start, 2 * hunits, 4);
            start += mat8Size;
            this.fOutputB = make1DArray(data, start, 4);
        }

        public EmbeddingType fType;
        public String fName;
        public Map<String, Integer> fDict;
        public float fEmbedding[][];
        public float fForwardW[][];
        public float fForwardU[][];
        public float fForwardB[];
        public float fBackwardW[][];
        public float fBackwardU[][];
        public float fBackwardB[];
        public float fOutputW[][];
        public float fOutputB[];
    }

    // Minimum word size
    private static final byte MIN_WORD = 2;

    // Minimum number of characters for two words
    private static final byte MIN_WORD_SPAN = MIN_WORD * 2;

    abstract class Vectorizer {
        public Vectorizer(Map<String, Integer> dict) {
            this.fDict = dict;
        }

        abstract public void vectorize(CharacterIterator fIter, int rangeStart, int rangeEnd,
                List<Integer> offsets, List<Integer> indicies);

        protected int getIndex(String token) {
            Integer res = fDict.get(token);
            return (res == null) ? fDict.size() : res;
        }

        private Map<String, Integer> fDict;
    }

    class CodePointsVectorizer extends Vectorizer {
        public CodePointsVectorizer(Map<String, Integer> dict) {
            super(dict);
        }

        public void vectorize(CharacterIterator fIter, int rangeStart, int rangeEnd,
                List<Integer> offsets, List<Integer> indicies) {
            fIter.setIndex(rangeStart);
            for (char c = fIter.current();
                    c != CharacterIterator.DONE && fIter.getIndex() < rangeEnd;
                    c = fIter.next()) {
                offsets.add(fIter.getIndex());
                indicies.add(getIndex(String.valueOf(c)));
            }
        }
    }

    class GraphemeClusterVectorizer extends Vectorizer {
        public GraphemeClusterVectorizer(Map<String, Integer> dict) {
            super(dict);
        }

        private String substring(CharacterIterator text, int startPos, int endPos) {
            int saved = text.getIndex();
            text.setIndex(startPos);
            StringBuilder sb = new StringBuilder();
            for (char c = text.current();
                    c != CharacterIterator.DONE && text.getIndex() < endPos;
                    c = text.next()) {
                sb.append(c);
            }
            text.setIndex(saved);
            return sb.toString();
        }

        public void vectorize(CharacterIterator text, int startPos, int endPos,
                List<Integer> offsets, List<Integer> indicies) {
            BreakIterator iter = BreakIterator.getCharacterInstance();
            iter.setText(text);
            int last = iter.next(startPos);
            for (int curr = iter.next(); curr != BreakIterator.DONE && curr <= endPos; curr = iter.next()) {
                offsets.add(last);
                String segment = substring(text, last, curr);
                int index = getIndex(segment);
                indicies.add(index);
                last = curr;
            }
        }
    }
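    // Note on the vectorizers: both walk [rangeStart, rangeEnd) over the text, filling two
    // parallel lists with each token's start offset and its row index into fEmbedding.
    // Unknown tokens map to fDict.size(), i.e. the last embedding row (index numIndex), which
    // is reserved for out-of-vocabulary input. CodePointsVectorizer emits one entry per char of
    // the iterator, while GraphemeClusterVectorizer groups a base character with its combining
    // marks via BreakIterator.getCharacterInstance(). A rough sketch with a hypothetical
    // two-entry dictionary:
    //
    //     dict = {"a": 0, "b": 1}, text = "abz"
    //     offsets  -> [0, 1, 2]
    //     indicies -> [0, 1, 2]   // "z" is unknown, so it gets dict.size() == 2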
    private final LSTMData fData;
    private int fScript;
    private final Vectorizer fVectorizer;

    private Vectorizer makeVectorizer(LSTMData data) {
        switch (data.fType) {
            case CODE_POINTS:
                return new CodePointsVectorizer(data.fDict);
            case GRAPHEME_CLUSTER:
                return new GraphemeClusterVectorizer(data.fDict);
            default:
                return null;
        }
    }

    public LSTMBreakEngine(int script, UnicodeSet set, LSTMData data) {
        setCharacters(set);
        this.fScript = script;
        this.fData = data;
        this.fVectorizer = makeVectorizer(this.fData);
    }

    @Override
    public int hashCode() {
        return getClass().hashCode();
    }

    @Override
    public boolean handles(int c) {
        return fScript == UCharacter.getIntPropertyValue(c, UProperty.SCRIPT);
    }

    static private void addDotProductTo(final float[] a, final float[][] b, float[] result) {
        assert a.length == b.length;
        assert b[0].length == result.length;
        for (int i = 0; i < result.length; i++) {
            for (int j = 0; j < a.length; j++) {
                result[i] += a[j] * b[j][i];
            }
        }
    }

    static private void addTo(final float[] a, float[] result) {
        assert a.length == result.length;
        for (int i = 0; i < result.length; i++) {
            result[i] += a[i];
        }
    }

    static private void hadamardProductTo(final float[] a, float[] result) {
        assert a.length == result.length;
        for (int i = 0; i < result.length; i++) {
            result[i] *= a[i];
        }
    }

    static private void addHadamardProductTo(final float[] a, final float[] b, float[] result) {
        assert a.length == result.length;
        assert b.length == result.length;
        for (int i = 0; i < result.length; i++) {
            result[i] += a[i] * b[i];
        }
    }

    static private void sigmoid(float[] result, int start, int length) {
        assert start < result.length;
        assert start + length <= result.length;
        for (int i = start; i < start + length; i++) {
            result[i] = (float) (1.0 / (1.0 + Math.exp(-result[i])));
        }
    }

    static private void tanh(float[] result, int start, int length) {
        assert start < result.length;
        assert start + length <= result.length;
        for (int i = start; i < start + length; i++) {
            result[i] = (float) Math.tanh(result[i]);
        }
    }

    static private int maxIndex(float[] data) {
        int index = 0;
        float max = data[0];
        for (int i = 1; i < data.length; i++) {
            if (data[i] > max) {
                max = data[i];
                index = i;
            }
        }
        return index;
    }

    /*
    static private void print(float[] data) {
        for (int i = 0; i < data.length; i++) {
            System.out.format(" %e", data[i]);
            if (i % 4 == 3) {
                System.out.println();
            }
        }
        System.out.println();
    }
    */

    private float[] compute(final float[][] W, final float[][] U, final float[] B,
            final float[] x, float[] h, float[] c) {
        // ifco = x * W + h * U + b
        float[] ifco = Arrays.copyOf(B, B.length);
        addDotProductTo(x, W, ifco);
        float[] hU = new float[B.length];
        addDotProductTo(h, U, ifco);

        int hunits = B.length / 4;
        sigmoid(ifco, 0 * hunits, hunits); // i
        sigmoid(ifco, 1 * hunits, hunits); // f
        tanh(ifco, 2 * hunits, hunits);    // c_
        sigmoid(ifco, 3 * hunits, hunits); // o

        hadamardProductTo(Arrays.copyOfRange(ifco, hunits, 2 * hunits), c);
        addHadamardProductTo(Arrays.copyOf(ifco, hunits),
                Arrays.copyOfRange(ifco, 2 * hunits, 3 * hunits), c);

        h = Arrays.copyOf(c, c.length);
        tanh(h, 0, h.length);
        hadamardProductTo(Arrays.copyOfRange(ifco, 3 * hunits, 4 * hunits), h);
        // System.out.println("c");
        // print(c);
        // System.out.println("h");
        // print(h);
        return h;
    }
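    // Note: compute() above is one step of a standard LSTM cell. With W, U and B packed as
    // four hunits-wide blocks [i | f | c~ | o], it evaluates, for input x_t, previous hidden
    // state h_{t-1} and cell state c_{t-1}:
    //
    //     [i f c~ o] = x_t * W + h_{t-1} * U + B
    //     i = sigmoid(i), f = sigmoid(f), c~ = tanh(c~), o = sigmoid(o)
    //     c_t = f ⊙ c_{t-1} + i ⊙ c~     (c is updated in place)
    //     h_t = o ⊙ tanh(c_t)            (returned)
    //
    // where ⊙ denotes the element-wise (Hadamard) product.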
    @Override
    public int divideUpDictionaryRange(CharacterIterator fIter, int rangeStart, int rangeEnd,
            DequeI foundBreaks, boolean isPhraseBreaking) {
        int beginSize = foundBreaks.size();

        if ((rangeEnd - rangeStart) < MIN_WORD_SPAN) {
            return 0; // Not enough characters for word
        }

        List<Integer> offsets = new ArrayList<Integer>(rangeEnd - rangeStart);
        List<Integer> indicies = new ArrayList<Integer>(rangeEnd - rangeStart);

        fVectorizer.vectorize(fIter, rangeStart, rangeEnd, offsets, indicies);

        // To reduce memory usage, the following differs from the Python and ICU4X
        // implementations: we first run the backward LSTM and then merge the forward LSTM
        // iteration with the output layer, because the forward pass only needs to
        // remember h[t-1].
        int inputSeqLength = indicies.size();
        int hunits = this.fData.fForwardU.length;
        float c[] = new float[hunits];

        // TODO: limit size of hBackward. If input_seq_len is too big, we could
        // run out of memory.
        // Backward LSTM
        float hBackward[][] = new float[inputSeqLength][hunits];
        for (int i = inputSeqLength - 1; i >= 0; i--) {
            if (i != inputSeqLength - 1) {
                hBackward[i] = Arrays.copyOf(hBackward[i + 1], hunits);
            }
            // System.out.println("Backward LSTM " + i);
            hBackward[i] = compute(this.fData.fBackwardW, this.fData.fBackwardU,
                    this.fData.fBackwardB, this.fData.fEmbedding[indicies.get(i)],
                    hBackward[i], c);
        }

        c = new float[hunits];
        float forwardH[] = new float[hunits];
        float both[] = new float[2 * hunits];

        // The following iteration merges the forward LSTM and the output layer together.
        for (int i = 0; i < inputSeqLength; i++) {
            // Forward LSTM
            forwardH = compute(this.fData.fForwardW, this.fData.fForwardU,
                    this.fData.fForwardB, this.fData.fEmbedding[indicies.get(i)],
                    forwardH, c);

            System.arraycopy(forwardH, 0, both, 0, hunits);
            System.arraycopy(hBackward[i], 0, both, hunits, hunits);

            // System.out.println("Merged " + i);
            // print(both);

            // Output layer
            // logp = fbRow * fOutputW + fOutputB
            float logp[] = Arrays.copyOf(this.fData.fOutputB, this.fData.fOutputB.length);
            addDotProductTo(both, this.fData.fOutputW, logp);

            int current = maxIndex(logp);

            // BIES logic.
            if (current == LSTMClass.BEGIN.ordinal() || current == LSTMClass.SINGLE.ordinal()) {
                if (i != 0) {
                    foundBreaks.push(offsets.get(i));
                }
            }
        }
        return foundBreaks.size() - beginSize;
    }

    public static LSTMData createData(UResourceBundle bundle) {
        return new LSTMData(bundle);
    }

    private static String defaultLSTM(int script) {
        ICUResourceBundle rb = (ICUResourceBundle) UResourceBundle.getBundleInstance(
                ICUData.ICU_BRKITR_BASE_NAME);
        return rb.getStringWithFallback("lstm/" + UScript.getShortName(script));
    }

    public static LSTMData createData(int script) {
        if (script != UScript.KHMER && script != UScript.LAO
                && script != UScript.MYANMAR && script != UScript.THAI) {
            return null;
        }
        String name = defaultLSTM(script);
        name = name.substring(0, name.indexOf("."));
        UResourceBundle rb = UResourceBundle.getBundleInstance(
                ICUData.ICU_BRKITR_BASE_NAME, name, ICUResourceBundle.ICU_DATA_CLASS_LOADER);
        return createData(rb);
    }

    public static LSTMBreakEngine create(int script, LSTMData data) {
        String setExpr = "[[:" + UScript.getShortName(script) + ":]&[:LineBreak=SA:]]";
        UnicodeSet set = new UnicodeSet();
        set.applyPattern(setExpr);
        set.compact();
        return new LSTMBreakEngine(script, set, data);
    }
}
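// Usage sketch (illustrative only; this class is internal ICU API and hidden on Android):
//
//     LSTMBreakEngine.LSTMData data = LSTMBreakEngine.createData(UScript.THAI);
//     LSTMBreakEngine engine = LSTMBreakEngine.create(UScript.THAI, data);
//     // engine.handles(c) is now true for Thai-script code points; a dictionary-based
//     // break iterator can then hand runs of Thai text to divideUpDictionaryRange(),
//     // which pushes the predicted word-break offsets onto its DequeI.
//
// Note that createData(int) returns null for scripts other than KHMER, LAO, MYANMAR and THAI.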