diff --git a/tfjs-converter/src/operations/operation_mapper.ts b/tfjs-converter/src/operations/operation_mapper.ts index a0e2e0787b1..a9eccdff00a 100644 --- a/tfjs-converter/src/operations/operation_mapper.ts +++ b/tfjs-converter/src/operations/operation_mapper.ts @@ -499,6 +499,9 @@ export function parseDtypeParam(value: string|tensorflow.DataType): DataType { return 'float32'; case tensorflow.DataType.DT_STRING: return 'string'; + case tensorflow.DataType.DT_COMPLEX64: + case tensorflow.DataType.DT_COMPLEX128: + return 'complex64'; default: // Unknown dtype error will happen at runtime (instead of parse time), // since these nodes might not be used by the actual subgraph execution. diff --git a/tfjs-converter/src/operations/operation_mapper_test.ts b/tfjs-converter/src/operations/operation_mapper_test.ts index 522d95382b3..527094b1b6c 100644 --- a/tfjs-converter/src/operations/operation_mapper_test.ts +++ b/tfjs-converter/src/operations/operation_mapper_test.ts @@ -155,6 +155,12 @@ const SIMPLE_MODEL: tensorflow.IGraphDef = { input: ['BiasAdd'], attr: {DstT: {type: tensorflow.DataType.DT_HALF}} }, + { + name: 'Cast4', + op: 'Cast', + input: ['BiasAdd'], + attr: {DstT: {type: tensorflow.DataType.DT_COMPLEX64}} + } ], library: { function: [ @@ -310,7 +316,7 @@ describe('operationMapper without signature', () => { it('should find the graph output nodes', () => { expect(convertedGraph.outputs.map(node => node.name)).toEqual([ 'Fill', 'Squeeze', 'Squeeze2', 'Split', 'LogicalNot', - 'FusedBatchNorm', 'Cast2', 'Cast3' + 'FusedBatchNorm', 'Cast2', 'Cast3', 'Cast4' ]); }); @@ -324,7 +330,7 @@ describe('operationMapper without signature', () => { expect(Object.keys(convertedGraph.nodes)).toEqual([ 'image_placeholder', 'Const', 'Shape', 'Value', 'Fill', 'Conv2D', 'BiasAdd', 'Cast', 'Squeeze', 'Squeeze2', 'Split', 'LogicalNot', - 'FusedBatchNorm', 'Cast2', 'Cast3' + 'FusedBatchNorm', 'Cast2', 'Cast3', 'Cast4' ]); }); }); @@ -447,6 +453,10 @@ describe('operationMapper without signature', () => { expect(convertedGraph.nodes['Cast'].attrParams['dtype'].value) .toEqual('int32'); }); + it('should map params with complex64 dtype', () => { + expect(convertedGraph.nodes['Cast4'].attrParams['dtype'].value) + .toEqual('complex64'); + }); }); }); }); @@ -486,7 +496,7 @@ describe('operationMapper with signature', () => { expect(Object.keys(convertedGraph.nodes)).toEqual([ 'image_placeholder', 'Const', 'Shape', 'Value', 'Fill', 'Conv2D', 'BiasAdd', 'Cast', 'Squeeze', 'Squeeze2', 'Split', 'LogicalNot', - 'FusedBatchNorm', 'Cast2', 'Cast3' + 'FusedBatchNorm', 'Cast2', 'Cast3', 'Cast4' ]); }); }); @@ -552,6 +562,10 @@ describe('operationMapper with signature', () => { expect(convertedGraph.nodes['Cast3'].attrParams['dtype'].value) .toEqual('float32'); }); + it('should map params with complex64 dtype', () => { + expect(convertedGraph.nodes['Cast4'].attrParams['dtype'].value) + .toEqual('complex64'); + }); }); }); });