Skip to content

Commit

Permalink
Rename Shape.size(int) to get, add toListOrNull
Browse files Browse the repository at this point in the history
Signed-off-by: Ryan Nett <[email protected]>
  • Loading branch information
rnett committed Apr 23, 2021
1 parent f12824b commit 08c8511
Show file tree
Hide file tree
Showing 20 changed files with 86 additions and 91 deletions.
27 changes: 24 additions & 3 deletions ndarray/src/main/java/org/tensorflow/ndarray/Shape.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@

package org.tensorflow.ndarray;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
* The shape of a Tensor or {@link NdArray}.
Expand Down Expand Up @@ -114,7 +116,7 @@ public long size() {
* @return The size of the dimension with the given index if known, {@link Shape#UNKNOWN_SIZE}
* otherwise.
*/
public long size(int i) {
public long get(int i) {
if (dimensionSizes == null) {
return UNKNOWN_SIZE;
} else if (i >= 0) {
Expand Down Expand Up @@ -166,7 +168,7 @@ public boolean isUnknown() {
}

/**
* Returns a defensive copy of the this Shape's axes. Changes to the returned array to not change
* Returns a defensive copy of the this Shape's axes. Changes to the returned array do not change
* this Shape's state. Returns null if {@link Shape#isUnknown()} is true.
*/
public long[] asArray() {
Expand All @@ -177,6 +179,25 @@ public long[] asArray() {
}
}


/**
* Returns a defensive copy of the this Shape's axes. Changes to the returned list do not change
* this Shape's state. Returns null if {@link Shape#isUnknown()} is true.
*/
public List<Long> toListOrNull() {
long[] array = asArray();
if (array == null){
return null;
}

List<Long> list = new ArrayList<>(array.length);
for(long l : array) {
list.add(l);
}

return list;
}

@Override
public int hashCode() {
return dimensionSizes != null ? Arrays.hashCode(dimensionSizes) : super.hashCode();
Expand Down Expand Up @@ -423,7 +444,7 @@ public boolean isCompatibleWith(Shape shape) {
return false;
}
for (int i = 0; i < numDimensions(); i++) {
if (!isCompatible(size(i), shape.size(i))) {
if (!isCompatible(get(i), shape.get(i))) {
return false;
}
}
Expand Down
4 changes: 2 additions & 2 deletions ndarray/src/main/java/org/tensorflow/ndarray/StdArrays.java
Original file line number Diff line number Diff line change
Expand Up @@ -3798,9 +3798,9 @@ private static int[] computeArrayDims(NdArray<?> ndArray, int expectedRank) {
}
int[] arrayShape = new int[expectedRank];
for (int i = 0; i < expectedRank; ++i) {
long dimSize = shape.size(i);
long dimSize = shape.get(i);
if (dimSize > Integer.MAX_VALUE) {
throw new IllegalArgumentException("Dimension " + i + " is too large to fit in a standard array (" + shape.size(i) + ")");
throw new IllegalArgumentException("Dimension " + i + " is too large to fit in a standard array (" + shape.get(i) + ")");
}
arrayShape[i] = (int)dimSize;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public static DimensionalSpace create(Shape shape) {

// Start from the last dimension, where all elements are continuous
for (int i = dimensions.length - 1, elementSize = 1; i >= 0; --i) {
dimensions[i] = new Axis(shape.size(i), elementSize);
dimensions[i] = new Axis(shape.get(i), elementSize);
elementSize *= dimensions[i].numElements();
}
return new DimensionalSpace(dimensions, shape);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,12 +132,12 @@ public void iterateElements() {
long value = 0L;
for (NdArray<T> matrix : matrix3d.elements(0)) {
assertEquals(2L, matrix.shape().numDimensions());
assertEquals(4L, matrix.shape().size(0));
assertEquals(5L, matrix.shape().size(1));
assertEquals(4L, matrix.shape().get(0));
assertEquals(5L, matrix.shape().get(1));

for (NdArray<T> vector : matrix.elements(0)) {
assertEquals(1L, vector.shape().numDimensions()) ;
assertEquals(5L, vector.shape().size(0));
assertEquals(5L, vector.shape().get(0));

for (NdArray<T> scalar : vector.scalars()) {
assertEquals(0L, scalar.shape().numDimensions()) ;
Expand Down
18 changes: 9 additions & 9 deletions ndarray/src/test/java/org/tensorflow/ndarray/ShapeTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,22 +26,22 @@ public class ShapeTest {
public void allKnownDimensions() {
Shape shape = Shape.of(5, 4, 5);
assertEquals(3, shape.numDimensions());
assertEquals(5, shape.size(0));
assertEquals(4, shape.size(1));
assertEquals(5, shape.size(2));
assertEquals(5, shape.get(0));
assertEquals(4, shape.get(1));
assertEquals(5, shape.get(2));
assertEquals(100, shape.size());
assertArrayEquals(new long[] {5, 4, 5}, shape.asArray());
try {
shape.size(3);
shape.get(3);
fail();
} catch (IndexOutOfBoundsException e) {
// as expected
}
assertEquals(5, shape.size(-1));
assertEquals(4, shape.size(-2));
assertEquals(5, shape.size(-3));
assertEquals(5, shape.get(-1));
assertEquals(4, shape.get(-2));
assertEquals(5, shape.get(-3));
try {
shape.size(-4);
shape.get(-4);
fail();
} catch (IndexOutOfBoundsException e) {
// as expected
Expand Down Expand Up @@ -133,7 +133,7 @@ public void testShapeModification() {
long[] internalShape = one.asArray();
assertNotNull(internalShape);
internalShape[0] = 42L;
assertEquals(2L, one.size(0));
assertEquals(2L, one.get(0));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ private static TensorInfo toTensorInfo(Output<?> operand) {
Shape shape = operand.shape();
TensorShapeProto.Builder tensorShapeBuilder = TensorShapeProto.newBuilder();
for (int i = 0; i < shape.numDimensions(); ++i) {
tensorShapeBuilder.addDim(Dim.newBuilder().setSize(shape.size(i)));
tensorShapeBuilder.addDim(Dim.newBuilder().setSize(shape.get(i)));
}
return TensorInfo.newBuilder()
.setDtype(operand.dataType())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ public static <T extends TNumber> Operand<T> sigmoidCrossEntropyWithLogits(
private static boolean isCompatible(Shape shape, Shape other) {
if (shape.numDimensions() != other.numDimensions()) return false;
for (int i = 0; i < shape.numDimensions(); i++) {
long aShapeDim = shape.size(i);
long bShapeDim = other.size(i);
long aShapeDim = shape.get(i);
long bShapeDim = other.get(i);
if (aShapeDim == bShapeDim
|| (aShapeDim == Shape.UNKNOWN_SIZE || bShapeDim == Shape.UNKNOWN_SIZE)) {
continue;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TInt64;
import org.tensorflow.types.family.TNumber;
import org.tensorflow.types.family.TType;

import java.util.Arrays;
import java.util.List;
Expand Down Expand Up @@ -124,10 +123,10 @@ public static <T extends TNumber, U extends TNumber> Operand<T> softmaxCrossEntr
axis = shape.numDimensions() + axis;
}
for (int i = 0; i < axis; i++) {
newArray[i] = shape.size(i);
newArray[i] = shape.get(i);
}
for (int i = axis + 1; i < shape.numDimensions(); i++) {
newArray[i - 1] = shape.size(i);
newArray[i - 1] = shape.get(i);
}
cost = Reshape.create(scope, cost, Constant.vectorOf(scope, newArray));
}
Expand All @@ -152,15 +151,15 @@ private static <T extends TNumber> Operand<T> flattenOuterDims(Scope scope, Oper
long product = 1L;
boolean productValid = true;
for (int i = ndims - 2; i >= 0; i--) {
long d = shape.size(i);
long d = shape.get(i);
if (d == org.tensorflow.ndarray.Shape.UNKNOWN_SIZE) {
productValid = false;
break;
}
product *= d;
}
if (productValid) {
return Reshape.create(scope, logits, Constant.arrayOf(scope, product, shape.size(-1)));
return Reshape.create(scope, logits, Constant.arrayOf(scope, product, shape.get(-1)));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ public static <T extends TNumber, U extends TNumber> Operand sparseSoftmaxCrossE
}

// Reshape logits to 2 dims, labels to 1 dim.
long numClassses = logitsShape.size(-1);
long numClassses = logitsShape.get(-1);

preciseLogits = Reshape.create(scope, preciseLogits, Constant.arrayOf(scope, -1L, numClassses));
labels = Reshape.create(scope, labels, Constant.scalarOf(scope, -1));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ public void outputDataTypeAndShape() {
.setAttr("value", t)
.build();
assertEquals(DataType.DT_INT32, op.dtype(0));
assertEquals(2, op.shape(0).size(0));
assertEquals(3, op.shape(0).size(1));
assertEquals(2, op.shape(0).get(0));
assertEquals(3, op.shape(0).get(1));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,8 @@ public void setAttrShape() {
.build()
.output(0);
assertEquals(2, n.shape().numDimensions());
assertEquals(-1, n.shape().size(0));
assertEquals(784, n.shape().size(1));
assertEquals(-1, n.shape().get(0));
assertEquals(784, n.shape().get(1));
assertEquals(DataType.DT_FLOAT, n.dataType());
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ public void exportFunctionWithVariables() throws IOException {
assertNotNull(inputInfo);
assertEquals(xyShape.numDimensions(), inputInfo.getTensorShape().getDimCount());
for (int i = 0; i < xyShape.numDimensions(); ++i) {
assertEquals(xyShape.size(i), inputInfo.getTensorShape().getDim(i).getSize());
assertEquals(xyShape.get(i), inputInfo.getTensorShape().getDim(i).getSize());
}

TensorInfo outputInfo = signatureDef.getOutputsMap().get("reducedSum");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ public void nDimensional() {
assertEquals(TFloat64.class, t.type());
assertEquals(DataType.DT_DOUBLE, t.dataType());
assertEquals(1, t.shape().numDimensions());
assertEquals(3, t.shape().size(0));
assertEquals(3, t.shape().get(0));
assertEquals(vector, t);
}

Expand All @@ -334,8 +334,8 @@ public void nDimensional() {
assertEquals(TInt32.class, t.type());
assertEquals(DataType.DT_INT32, t.dataType());
assertEquals(2, t.shape().numDimensions());
assertEquals(2, t.shape().size(0));
assertEquals(3, t.shape().size(1));
assertEquals(2, t.shape().get(0));
assertEquals(3, t.shape().get(1));
assertEquals(matrix, t);
}

Expand All @@ -346,9 +346,9 @@ public void nDimensional() {
assertEquals(TInt64.class, t.type());
assertEquals(DataType.DT_INT64, t.dataType());
assertEquals(3, t.shape().numDimensions());
assertEquals(2, t.shape().size(0));
assertEquals(5, t.shape().size(1));
assertEquals(1, t.shape().size(2));
assertEquals(2, t.shape().get(0));
assertEquals(5, t.shape().get(1));
assertEquals(1, t.shape().get(2));
assertEquals(threeD, t);
}

Expand All @@ -361,10 +361,10 @@ public void nDimensional() {
assertEquals(TBool.class, t.type());
assertEquals(DataType.DT_BOOL, t.dataType());
assertEquals(4, t.shape().numDimensions());
assertEquals(3, t.shape().size(0));
assertEquals(1, t.shape().size(1));
assertEquals(2, t.shape().size(2));
assertEquals(4, t.shape().size(3));
assertEquals(3, t.shape().get(0));
assertEquals(1, t.shape().get(1));
assertEquals(2, t.shape().get(2));
assertEquals(4, t.shape().get(3));
assertEquals(fourD, t);
}
}
Expand All @@ -381,8 +381,8 @@ public void testNDimensionalStringTensor() {
assertEquals(TString.class, t.type());
assertEquals(DataType.DT_STRING, t.dataType());
assertEquals(2, t.shape().numDimensions());
assertEquals(4, t.shape().size(0));
assertEquals(3, t.shape().size(1));
assertEquals(4, t.shape().get(0));
assertEquals(3, t.shape().get(1));
assertEquals(matrix, t);
}

Expand All @@ -392,8 +392,8 @@ public void testNDimensionalStringTensor() {
assertEquals(TString.class, t.type());
assertEquals(DataType.DT_STRING, t.dataType());
assertEquals(2, t.shape().numDimensions());
assertEquals(4, t.shape().size(0));
assertEquals(3, t.shape().size(1));
assertEquals(4, t.shape().get(0));
assertEquals(3, t.shape().get(1));
assertEquals(byteMatrix, t.asBytes());
assertEquals(matrix, t);
}
Expand All @@ -406,7 +406,7 @@ public void testUint8TensorFromArray() {
assertEquals(TUint8.class, t.type());
assertEquals(DataType.DT_UINT8, t.dataType());
assertEquals(1, t.shape().numDimensions());
assertEquals(4, t.shape().size(0));
assertEquals(4, t.shape().get(0));

byte[] got = new byte[4];
t.read(DataBuffers.of(got));
Expand All @@ -421,7 +421,7 @@ public void testCreateFromArrayOfBoxed() {
assertEquals(TInt32.class, t.type());
assertEquals(DataType.DT_INT32, t.dataType());
assertEquals(1, t.shape().numDimensions());
assertEquals(4, t.shape().size(0));
assertEquals(4, t.shape().get(0));

Integer[] got = new Integer[4];
t.read(DataBuffers.ofObjects(got));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ public Operand<T> call(Operand<TInt64> dims, Class<T> type) {
if (shape.numDimensions() != 2) {
throw new IllegalArgumentException("2D matrix required, got " + shape.numDimensions());
}
boolean isSquare = shape.size(0) == shape.size(1);
long diagSize = Math.min(shape.size(0), shape.size(1));
boolean isSquare = shape.get(0) == shape.get(1);
long diagSize = Math.min(shape.get(0), shape.get(1));
Shape diagShape = Shape.of(diagSize);

Operand<T> op;
Expand All @@ -83,8 +83,8 @@ public Operand<T> call(Operand<TInt64> dims, Class<T> type) {
tf.linalg.matrixDiag(
diagOnes,
tf.constant(0), // don't cast here, expecting TInt32
tf.constant((int) shape.size(0)),
tf.constant((int) shape.size(1)),
tf.constant((int) shape.get(0)),
tf.constant((int) shape.get(1)),
zero);
} else {
Operand<T> zeroMatrix = tf.zeros(dims, type);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ public Operand<T> call(Operand<TInt64> dims, Class<T> type) {
}
long numRows = 1;
int i = 0;
for (; i < dimsShape.numDimensions() - 1; i++) numRows *= dimsShape.size(i);
long numCols = dimsShape.size(i);
for (; i < dimsShape.numDimensions() - 1; i++) numRows *= dimsShape.get(i);
long numCols = dimsShape.get(i);
Shape flatShape = Shape.of(Math.max(numRows, numCols), Math.min(numRows, numCols));
long[] seeds = {seed, 0};
Operand<T> op =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,7 @@ public static <T extends TNumber> Operand<T> sparseCategoricalCrossentropy(
tf.reshape(
predictions,
tf.constant(
new long[] {-1L, predictionsShape.size(predictionsShape.numDimensions() - 1)}));
new long[] {-1L, predictionsShape.get(predictionsShape.numDimensions() - 1)}));
}

@SuppressWarnings("unchecked")
Expand Down Expand Up @@ -643,7 +643,7 @@ private static <T extends TNumber> Operand<T> smoothCategoricalLabels(
Operand<T> smoothing = cast(tf, tf.constant(labelSmoothing), labelType);
Shape labelsShape = labels.shape();
int numDims = labelsShape.numDimensions();
Operand<T> numClasses = cast(tf, tf.constant(labelsShape.size(numDims - 1)), labelType);
Operand<T> numClasses = cast(tf, tf.constant(labelsShape.get(numDims - 1)), labelType);
Operand<T> oneMinusSmoothing = cast(tf, tf.constant(1.f - labelSmoothing), labelType);
return tf.math.add(tf.math.mul(labels, oneMinusSmoothing), tf.math.div(smoothing, numClasses));
}
Expand Down
Loading

0 comments on commit 08c8511

Please sign in to comment.