Skip to content

Commit 8a99071

Browse files
author
Diego Ernst
authored
Minor refactor and doc changes (#68)
1 parent c06c8e1 commit 8a99071

26 files changed

+183
-159
lines changed

Bender.xcodeproj/project.pbxproj

+4
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
/* Begin PBXBuildFile section */
1010
28F828811C494B2C00330CF4 /* Bender.h in Headers */ = {isa = PBXBuildFile; fileRef = 28F828801C494B2C00330CF4 /* Bender.h */; settings = {ATTRIBUTES = (Public, ); }; };
1111
46354BC41EEF320700B083EF /* DependencyListBuilder.swift in Sources */ = {isa = PBXBuildFile; fileRef = 46354BC31EEF320700B083EF /* DependencyListBuilder.swift */; };
12+
4675C5561F4209DF00895175 /* Device.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4675C5551F4209DF00895175 /* Device.swift */; };
1213
4697C2191ED31F8D003ED2E9 /* local_response_norm.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4697C2181ED31F8D003ED2E9 /* local_response_norm.metal */; };
1314
4697C21B1ED35E50003ED2E9 /* LocalResponseNorm.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4697C21A1ED35E50003ED2E9 /* LocalResponseNorm.swift */; };
1415
46FAFCB91F029B09008754E1 /* Concat.swift in Sources */ = {isa = PBXBuildFile; fileRef = 46FAFCB81F029B09008754E1 /* Concat.swift */; };
@@ -91,6 +92,7 @@
9192
28F828801C494B2C00330CF4 /* Bender.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = Bender.h; sourceTree = "<group>"; };
9293
28F828821C494B2C00330CF4 /* Info.plist */ = {isa = PBXFileReference; lastKnownFileType = text.plist.xml; path = Info.plist; sourceTree = "<group>"; };
9394
46354BC31EEF320700B083EF /* DependencyListBuilder.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = DependencyListBuilder.swift; sourceTree = "<group>"; };
95+
4675C5551F4209DF00895175 /* Device.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Device.swift; sourceTree = "<group>"; };
9496
4697C2181ED31F8D003ED2E9 /* local_response_norm.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = local_response_norm.metal; sourceTree = "<group>"; };
9597
4697C21A1ED35E50003ED2E9 /* LocalResponseNorm.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = LocalResponseNorm.swift; sourceTree = "<group>"; };
9698
46FAFCB81F029B09008754E1 /* Concat.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Concat.swift; sourceTree = "<group>"; };
@@ -269,6 +271,7 @@
269271
8FA984001EBA70F900586D1B /* NetworkLayer.swift */,
270272
8FA8E32F1EC0FE2700E8BAD8 /* Operators.swift */,
271273
8F60A31C1EC4E6E400296167 /* PaddingType.swift */,
274+
4675C5551F4209DF00895175 /* Device.swift */,
272275
);
273276
path = Core;
274277
sourceTree = "<group>";
@@ -484,6 +487,7 @@
484487
8FAB77221ECC86170050AB16 /* attr_value.pb.swift in Sources */,
485488
8FFF4AD71EDEFDB600141E9D /* Shape.swift in Sources */,
486489
8FA8E3301EC0FE2700E8BAD8 /* Operators.swift in Sources */,
490+
4675C5561F4209DF00895175 /* Device.swift in Sources */,
487491
8F79B8031EC23563002D02C8 /* ResidualLayer.swift in Sources */,
488492
8FAB77291ECC86170050AB16 /* versions.pb.swift in Sources */,
489493
8FA8E3211EC0B8B500E8BAD8 /* BGRAtoRGBA.swift in Sources */,

Documentation/API.md

+6-9
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@ This document explains the basic API in __Bender__.
1616
To create a network model you can create it from scratch or [import](Importing.md) it from a TensorFlow graph. We will explain how to create a network from scratch:
1717

1818
```swift
19-
let network = Network(device: device,
20-
inputSize: inputSize,
19+
let network = Network(inputSize: inputSize,
2120
parameterLoader: loader)
2221

2322
network.start
@@ -33,7 +32,7 @@ network.start
3332
network.initialize()
3433
```
3534

36-
First, we have to create the `network` which receives the MTLDevice (GPU), an inputSize and a parameter loader. The network comes with a `start` node which is the starting point of the network. The `inputSize` is the size expected by the first layer in the network. If the images you pass the network to be processed are not of the expected size then the `start` node will resize them accordingly.
35+
First, we have to create the `network` which receives the inputSize and a parameter loader. The network comes with a `start` node which is the starting point of the network. The `inputSize` is the size expected by the first layer in the network. If the images you pass the network to be processed are not of the expected size then the `start` node will resize them accordingly.
3736

3837
The `parameterLoader` is responsible for loading the weights for each layer. It will be explained in detail further below.
3938

@@ -54,16 +53,14 @@ After you finish adding layers to your network, you must call `network.initializ
5453

5554
## Running a network
5655

57-
To run a network call `run(...)`:
56+
To run a network call `run(/* ... */)`:
5857

5958
```swift
60-
let commandQueue: MTLCommandQueue = ...
61-
6259
// get image from somewhere
63-
let image = MPSImage(...)
60+
let image = MPSImage(/* ... */)
6461

65-
network.run(inputImage: image, queue: commandQueue) { outputImage in
66-
...
62+
network.run(input: image) { output in
63+
// ...
6764
}
6865
```
6966

Documentation/Importing.md

+5-10
Original file line numberDiff line numberDiff line change
@@ -28,20 +28,16 @@ benderthon tf-freeze checkpoint_path.ckpt graph_with_weights.pb Output_Node_Name
2828
To import a model saved in a [Protobuf](https://developers.google.com/protocol-buffers/) file you must add it to your Xcode project and load it like this:
2929

3030
```swift
31-
// Define network
32-
network = Network(device: device, inputSize: inputSize, parameterLoader: nil)
3331

34-
// Load graph file
32+
// Set an url pointing to your model
3533
let url = Bundle.main.url(forResource: "myGraph", withExtension: "pb")!
3634

37-
// Create converter
35+
// Create the converter
3836
let converter = TFConverter.default()
3937

40-
// Convert graph
41-
network.convert(converter: converter, url: url, type: .binary)
38+
// Load it
39+
let network = Network.load(url: url, converter: converter, inputSize: LayerSize(h: 256, w: 256, f: 3))
4240

43-
// Initialize network
44-
network.initialize()
4541
```
4642

4743
`TFConverter` is the class responsible for converting a TF model to Bender. It will try to map nodes or groups of nodes in the TF graph to Bender layers. If it encounters unknown nodes then it will ignore them. This means that a graph might be disconnected if your TF model uses functions that are not implemented in Bender.
@@ -79,8 +75,7 @@ func optimize(graph: TFGraph)
7975
After you create the optimizer, you have to add it to your `TFConverter` like this:
8076

8177
```swift
82-
let converter = TFConverter.default()
83-
converter.optimizers.append(MyTFOptimizer())
78+
let converter = TFConverter.default(additionalOptimizers: [MyTFOptimizer()])
8479
```
8580

8681
#### Removing nodes

Example/Example/GrayScale.swift

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class GrayScale: NetworkLayer {
2727
assert(getIncoming().count == 1, "GrayScale must have one input")
2828
let incoming = getIncoming()[0]
2929
assert(incoming.outputSize.f <= 4, "GrayScale input must have at most 4 feature channels")
30-
outputSize = LayerSize(f: outputChannels, w: incoming.outputSize.w, h: incoming.outputSize.h)
30+
outputSize = LayerSize(h: incoming.outputSize.h, w: incoming.outputSize.w, f: outputChannels)
3131
outputImage = MPSImage(device: device, imageDescriptor: MPSImageDescriptor(layerSize: outputSize))
3232
}
3333

Example/Example/StyleTransferViewController.swift

+10-13
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,15 @@ class StyleTransferViewController: UIViewController, ExampleViewController {
1515

1616
var styleNet: Network!
1717
var styleNet2: Network!
18-
var device: MTLDevice!
1918
var commandQueue: MTLCommandQueue!
20-
let inputSize = LayerSize(f: 3, w: 256)
19+
let inputSize = LayerSize(h: 256, w: 256, f: 3)
2120

2221
var pixelBufferPool: CVPixelBufferPool?
2322
@IBOutlet weak var imageView: UIImageView!
2423

2524
override func viewDidLoad() {
2625
super.viewDidLoad()
27-
self.device = MTLCreateSystemDefaultDevice()!
28-
self.commandQueue = device.makeCommandQueue()
26+
self.commandQueue = Device.shared.makeCommandQueue()
2927
setupStyleNet()
3028
var me = self
3129
me.setPixelBufferPool()
@@ -39,25 +37,24 @@ class StyleTransferViewController: UIViewController, ExampleViewController {
3937
func setupStyleNet() {
4038
measure("Set up takes:") {
4139

42-
styleNet = Network(device: device, inputSize: inputSize, parameterLoader: nil)
43-
4440
let url = Bundle.main.url(forResource: "g_and_w2", withExtension: "pb")!
45-
let converter = TFConverter.default()
46-
converter.optimizers.append(TFInstanceNormOptimizer())
41+
let converter = TFConverter.default(additionalOptimizers: [TFInstanceNormOptimizer()])
4742

48-
styleNet.convert(converter: converter, url: url, type: .binary)
43+
styleNet = Network.load(url: url, converter: converter, inputSize: inputSize, performInitialize: false)
4944
styleNet.addPostProcessing(layers: [ImageLinearTransform()])
5045

46+
// after adding all our layers we are able to initialize the network
47+
5148
styleNet.initialize()
5249
}
5350
}
5451

5552
@IBAction func runNetwork(_ sender: Any) {
5653
let buffer = commandQueue.makeCommandBuffer()
57-
let image = loadTestImage(device: device, commandBuffer: buffer)
54+
let image = loadTestImage(commandBuffer: buffer)
5855
buffer.commit()
5956
buffer.waitUntilCompleted()
60-
styleNet.run(inputImage: image, queue: commandQueue) { [weak self] imageA in
57+
styleNet.run(input: image, queue: commandQueue) { [weak self] imageA in
6158
if let buffer = self?.getPixelBuffer(from: imageA.texture, bufferPool: self!.pixelBufferPool!) {
6259
let ciImage = CIImage(cvImageBuffer: buffer)
6360
let context = CIContext()
@@ -70,9 +67,9 @@ class StyleTransferViewController: UIViewController, ExampleViewController {
7067
}
7168
}
7269

73-
func loadTestImage(device: MTLDevice, commandBuffer: MTLCommandBuffer) -> MPSImage{
70+
func loadTestImage(commandBuffer: MTLCommandBuffer) -> MPSImage{
7471
// INPUT IMAGE
75-
let textureLoader = MTKTextureLoader(device: device)
72+
let textureLoader = MTKTextureLoader(device: Device.shared)
7673
let inputTexture = try! textureLoader.newTexture(withContentsOf: Bundle.main.url(forResource: "wall-e", withExtension: "png")!, options: [MTKTextureLoaderOptionSRGB : NSNumber(value: false)])
7774
return MPSImage(texture: inputTexture, featureChannels: 3)
7875
}

Example/Example/Tests/BenderTestRunner.swift

-6
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,6 @@ import MetalKit
1010

1111
public class BenderTest {
1212

13-
var device: MTLDevice!
14-
15-
init() {
16-
device = MTLCreateSystemDefaultDevice()
17-
}
18-
1913
func run(completion: @escaping (Void) -> ()) {
2014
completion()
2115
}

Example/Example/Tests/ConcatTest.swift

+12-12
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ struct ConcatDataSet {
3333
[Float].init(repeating: 3, count: texture1Depth), [Float].init(repeating: 4, count: texture1Depth),
3434
[Float].init(repeating: 5, count: texture1Depth), [Float].init(repeating: 6, count: texture1Depth),
3535
],
36-
size: LayerSize(f: texture1Depth, w: 2, h: 3)
36+
size: LayerSize(h: 3, w: 2, f: texture1Depth)
3737
)
3838
let texture2Depth = depth
3939
let texture2 = Texture(
@@ -48,7 +48,7 @@ struct ConcatDataSet {
4848
[Float].init(repeating: 7, count: texture2Depth), [Float].init(repeating: 8, count: texture2Depth),
4949
[Float].init(repeating: 8, count: texture2Depth),
5050
],
51-
size: LayerSize(f: texture2Depth, w: 3, h: 3)
51+
size: LayerSize(h: 3, w: 3, f: texture2Depth)
5252
)
5353
let expectedDepth = depth
5454
let expectedData: [[Float]] = [
@@ -68,7 +68,7 @@ struct ConcatDataSet {
6868

6969
let expected: Texture = Texture(
7070
data: expectedData,
71-
size: LayerSize(f: expectedDepth, w: 5, h: 3)
71+
size: LayerSize(h: 3, w: 5, f: expectedDepth)
7272
)
7373
return (inputTextures: [texture1, texture2], axis: .w, expected: expected)
7474
}
@@ -81,7 +81,7 @@ struct ConcatDataSet {
8181
[Float].init(repeating: 3, count: texture1Depth), [Float].init(repeating: 4, count: texture1Depth),
8282
[Float].init(repeating: 5, count: texture1Depth), [Float].init(repeating: 6, count: texture1Depth),
8383
],
84-
size: LayerSize(f: texture1Depth, w: 2, h: 3)
84+
size: LayerSize(h: 3, w: 2, f: texture1Depth)
8585
)
8686
let texture2Depth = depth
8787
let texture2 = Texture(
@@ -92,7 +92,7 @@ struct ConcatDataSet {
9292
[Float].init(repeating: 9, count: texture1Depth), [Float].init(repeating: 7, count: texture1Depth),
9393
[Float].init(repeating: 3, count: texture1Depth), [Float].init(repeating: 3, count: texture1Depth),
9494
],
95-
size: LayerSize(f: texture2Depth, w: 2, h: 5)
95+
size: LayerSize(h: 5, w: 2, f: texture2Depth)
9696
)
9797
let expectedDepth = depth
9898
let expectedData: [[Float]] = [
@@ -108,7 +108,7 @@ struct ConcatDataSet {
108108

109109
let expected: Texture = Texture(
110110
data: expectedData,
111-
size: LayerSize(f: expectedDepth, w: 2, h: 8)
111+
size: LayerSize(h: 8, w: 2, f: expectedDepth)
112112
)
113113
return (inputTextures: [texture1, texture2], axis: .h, expected: expected)
114114
}
@@ -121,7 +121,7 @@ struct ConcatDataSet {
121121
[Float].init(repeating: 3, count: texture1Depth), [Float].init(repeating: 4, count: texture1Depth),
122122
[Float].init(repeating: 5, count: texture1Depth), [Float].init(repeating: 6, count: texture1Depth),
123123
],
124-
size: LayerSize(f: texture1Depth, w: 2, h: 3)
124+
size: LayerSize(h: 3, w: 2, f: texture1Depth)
125125
)
126126
let texture2Depth = depth2
127127
let texture2 = Texture(
@@ -130,7 +130,7 @@ struct ConcatDataSet {
130130
[Float].init(repeating: 9, count: texture2Depth), [Float].init(repeating: 10, count: texture2Depth),
131131
[Float].init(repeating: 11, count: texture2Depth), [Float].init(repeating: 12, count: texture2Depth),
132132
],
133-
size: LayerSize(f: texture2Depth, w: 2, h: 3)
133+
size: LayerSize(h: 3, w: 2, f: texture2Depth)
134134
)
135135
let expectedDepth = depth1 + depth2
136136
let expectedData: [[Float]] = [
@@ -147,7 +147,7 @@ struct ConcatDataSet {
147147

148148
let expected: Texture = Texture(
149149
data: expectedData,
150-
size: LayerSize(f: expectedDepth, w: 2, h: 3)
150+
size: LayerSize(h: 3, w: 2, f: expectedDepth)
151151
)
152152
return (inputTextures: [texture1, texture2], axis: .f, expected: expected)
153153
}
@@ -167,15 +167,15 @@ class ConcatTest: BenderTest {
167167
}
168168

169169
func test(inputTextures: [Texture], axis: LayerSizeAxis, expectedOutput: Texture, completion: @escaping (Void) -> ()) {
170-
let styleNet = Network(device: device, inputSize: inputTextures[0].size, parameterLoader: SingleBinaryLoader(checkpoint: "lala"))
170+
let styleNet = Network(inputSize: inputTextures[0].size)
171171

172172
styleNet.start
173173
->> inputTextures.map { Constant(outputTexture: $0) }
174174
->> Concat(axis: axis)
175175

176176
styleNet.initialize()
177-
let metalTexture = inputTextures[0].metalTexture(with: device)
178-
styleNet.run(inputImage: MPSImage(texture: metalTexture, featureChannels: inputTextures[0].depth), queue: device.makeCommandQueue()) { image in
177+
let metalTexture = inputTextures[0].metalTexture(with: Device.shared)
178+
styleNet.run(input: MPSImage(texture: metalTexture, featureChannels: inputTextures[0].depth)) { image in
179179
let textureFromGpu = Texture(metalTexture: image.texture, size: expectedOutput.size)
180180
assert(textureFromGpu.isEqual(to: expectedOutput))
181181
completion()

Example/Example/Tests/Helpers/TestData.swift

+5-5
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ struct TestData {
1919
[17.0, 18, 19, 20, 21, 22, 23, 24], [25, 26, 27, 28, 29, 30, 31, 32],
2020
[33, 34, 35, 36, 37, 38, 39, 40], [41, 42, 43, 44, 45, 46, 47, 48]
2121
],
22-
size: LayerSize(f: 8, w: 2, h: 3)
22+
size: LayerSize(h: 3, w: 2, f: 8)
2323
))
2424

2525
textures.append(Texture(
@@ -28,7 +28,7 @@ struct TestData {
2828
[17.0, 18, 19, 20, 21, 22], [25, 26, 27, 28, 29, 30],
2929
[33, 34, 35, 36, 37, 38], [41, 42, 43, 44, 45, 46]
3030
],
31-
size: LayerSize(f: 6, w: 2, h: 3)
31+
size: LayerSize(h: 3, w: 2, f: 6)
3232
))
3333

3434
textures.append(Texture(
@@ -37,7 +37,7 @@ struct TestData {
3737
[17.0, 18, 19, 20], [25, 26, 27, 28],
3838
[33, 34, 35, 36], [41, 42, 43, 44]
3939
],
40-
size: LayerSize(f: 4, w: 2, h: 3)
40+
size: LayerSize(h: 3, w: 2, f: 4)
4141
))
4242

4343
textures.append(Texture(
@@ -46,7 +46,7 @@ struct TestData {
4646
[17.0, 18, 19], [25, 26, 27],
4747
[33, 34, 35], [41, 42, 43]
4848
],
49-
size: LayerSize(f: 3, w: 2, h: 3)
49+
size: LayerSize(h: 3, w: 2, f: 3)
5050
))
5151

5252
textures.append(Texture(
@@ -55,7 +55,7 @@ struct TestData {
5555
[1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0], [1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0],
5656
[1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0], [1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0]
5757
],
58-
size: LayerSize(f: 9, w: 2, h: 3)
58+
size: LayerSize(h: 3, w: 2, f: 9)
5959
))
6060

6161
return textures

Example/Example/Tests/InstanceNormTests.swift

+3-3
Original file line numberDiff line numberDiff line change
@@ -24,16 +24,16 @@ class InstanceNormTest: BenderTest {
2424
}
2525

2626
func test(texture: Texture, completion: @escaping (Void) -> ()) {
27-
let styleNet = Network(device: device, inputSize: texture.size, parameterLoader: SingleBinaryLoader(checkpoint: "lala"))
27+
let styleNet = Network(inputSize: texture.size)
2828
let weights = [Float].init(repeating: Float(arc4random()) / Float(UINT32_MAX), count: texture.depth)
2929
let bias = [Float].init(repeating: Float(arc4random()) / Float(UINT32_MAX), count: texture.depth)
3030
let scale = Data.init(bytes: weights, count: texture.totalCount * MemoryLayout<Float>.stride)
3131
let shift = Data.init(bytes: bias, count: texture.totalCount * MemoryLayout<Float>.stride)
3232
styleNet.start ->> InstanceNorm(scale: scale, shift: shift)
3333
styleNet.initialize()
34-
let metalTexture = texture.metalTexture(with: device)
34+
let metalTexture = texture.metalTexture(with: Device.shared)
3535
let cpuComputed = cpuInstanceNorm(input: texture, weights: weights, bias: bias)
36-
styleNet.run(inputImage: MPSImage(texture: metalTexture, featureChannels: texture.depth), queue: device.makeCommandQueue()) { image in
36+
styleNet.run(input: MPSImage(texture: metalTexture, featureChannels: texture.depth)) { image in
3737
let textureFromGpu = Texture(metalTexture: image.texture, size: texture.size)
3838
assert(textureFromGpu.isEqual(to: cpuComputed, threshold: 0.002))
3939
completion()

Example/Example/Tests/LocalResponseNormTest.swift

+3-3
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,12 @@ class LocalResponseNormTest: BenderTest {
3232
}
3333

3434
func test(texture: Texture, parameters: LocalResponseNorm.Parameters, completion: @escaping (Void) -> ()) {
35-
let styleNet = Network(device: device, inputSize: texture.size, parameterLoader: SingleBinaryLoader(checkpoint: "lala"))
35+
let styleNet = Network(inputSize: texture.size)
3636
styleNet.start ->> LocalResponseNorm(parameters: parameters, id: nil)
3737
styleNet.initialize()
38-
let metalTexture = texture.metalTexture(with: device)
38+
let metalTexture = texture.metalTexture(with: Device.shared)
3939
let cpuComputed = cpuLocalResponseNorm(input: texture, parameters: parameters)
40-
styleNet.run(inputImage: MPSImage(texture: metalTexture, featureChannels: texture.depth), queue: device.makeCommandQueue()) { image in
40+
styleNet.run(input: MPSImage(texture: metalTexture, featureChannels: texture.depth)) { image in
4141
let textureFromGpu = Texture(metalTexture: image.texture, size: texture.size)
4242
assert(textureFromGpu.isEqual(to: cpuComputed, threshold: 0.001))
4343
completion()

Example/Example/Tests/TextureConversionTest.swift

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class TextureConversionTest: BenderTest {
1616
}
1717

1818
func test(texture: Texture) {
19-
let metalTexture = texture.metalTexture(with: device)
19+
let metalTexture = texture.metalTexture(with: Device.shared)
2020
assert(Texture(metalTexture: metalTexture, size: texture.size).isEqual(to: texture))
2121
}
2222

0 commit comments

Comments
 (0)