swift-tfp's Introduction

Tensors Fitting Perfectly

"It’s the relief of finding ease where you expected struggle." ~The Atlantic

There are moments when the planets align and different objects fit perfectly with each other. While this amazing phenomenon has been observed in the physical world numerous times, few people have stopped to think how to make writing numerical programs feel just as good. We have, and this is how Tensors Fitting Perfectly got started.

TFP is a static analyzer for Swift programs that tries to detect tensor shape mismatches before you even attempt to run them. Note that TFP is not a type system, meaning that it does not try to prove that your program is correct. It only tries to prove that your program is incorrect, and will report errors only if it's sure that shape errors are guaranteed to occur.

This project is highly experimental and may unexpectedly change at any time.

How does it work?

Good question! TFP invokes the Swift compiler to lower your Swift code down to SIL (Swift Intermediate Language), and uses it to scan your code for assertions that pertain to shapes of tensors. Note that this step is a form of abstract interpretation and is not guaranteed to actually recover all of those; it is very much an approximation. Each assertion that it manages to understand gets added to a system of logical constraints that have to be satisfied if your program is to be correct. Those constraints are propagated through e.g. function calls, so invariants discovered in called functions will be considered invariants of their caller too. Finally, TFP carefully queries an SMT solver to check whether the program looks correct, or whether there is an execution path that causes a shape failure.
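
To make "a system of constraints that has to be satisfied" concrete, here is a minimal sketch that checks a conjunction of equalities between shape dimensions and integer literals. It uses union-find instead of an SMT solver, so it only understands `==`; it is an illustration, not TFP's encoding, and the names EqualityChecker and Term are hypothetical. The usage follows the example.swift walkthrough under "How to use": x has shape [2, 3] and is multiplied with itself.

```swift
// Hypothetical illustration: union-find over equality constraints (not TFP's Z3 encoding).
// A term is either a named shape dimension or an integer literal.
enum Term: Hashable {
  case variable(String)
  case literal(Int)
}

struct EqualityChecker {
  private var parent: [Term: Term] = [:]
  // Literal value bound to each equivalence-class representative, if any.
  private var value: [Term: Int] = [:]

  private mutating func find(_ t: Term) -> Term {
    if parent[t] == nil {
      parent[t] = t
      if case .literal(let n) = t { value[t] = n }
    }
    let p = parent[t]!
    if p == t { return t }
    let root = find(p)
    parent[t] = root  // path compression
    return root
  }

  /// Records `a == b`. Returns the two conflicting literal values if this
  /// makes the constraint system unsatisfiable, or nil otherwise.
  mutating func assertEqual(_ a: Term, _ b: Term) -> (Int, Int)? {
    let ra = find(a), rb = find(b)
    if ra == rb { return nil }
    switch (value[ra], value[rb]) {
    case let (x?, y?) where x != y:
      return (x, y)
    default:
      parent[ra] = rb
      if value[rb] == nil { value[rb] = value[ra] }
      return nil
    }
  }
}

var checker = EqualityChecker()
// x = randn([2, 3]); y is the same tensor, since f() calls matmul(x, x).
_ = checker.assertEqual(.variable("x.shape[0]"), .literal(2))
_ = checker.assertEqual(.variable("x.shape[1]"), .literal(3))
_ = checker.assertEqual(.variable("y.shape[0]"), .variable("x.shape[0]"))
// assert(x.shape[1] == y.shape[0]) inside matmul:
if let (lhs, rhs) = checker.assertEqual(.variable("x.shape[1]"), .variable("y.shape[0]")) {
  print("Something doesn't fit: \(lhs) = \(rhs)")  // prints "Something doesn't fit: 3 = 2"
}
```

A real solver is needed because the recognized constraints also involve arithmetic, inequalities and broadcasting, none of which union-find can express.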

The general idea is that the standard library should contain a number of assertions that both establish the shape semantics of the code, as well as verify some of the preconditions that need to be satisfied. Take matrix multiplication as an example:

func matmul(_ x: Tensor<Float>, _ y: Tensor<Float>) -> Tensor<Float> {
  let (n, mx) = x.shape2d
  let (my, k) = y.shape2d
  assert(mx == my)
  let r = TensorFlow.matmul(x, y)
  assert(r.shape == [n, k])
  return r
}

Once you use matmul (and similar library functions) in your program, TFP will be able to recover the relations that connect shapes of tensor values at different points and will try to verify them. Adding more assertions to your code (at a level higher than libraries) is beneficial, because:

  1. It will let TFP verify that what you believe is consistent with what the lower layer has specified in the form of assertions.
  2. It will improve the quality of verification in case parts of the program could not be understood.

tl;dr Instead of encoding your shape contracts in comments or printing shapes to figure out what's happening, encode your thoughts and beliefs as assertions. Those have the benefit of being machine-checked documentation (in debug mode only!), and (more importantly in this context) they also make it more likely for the tool to find issues in your programs.
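
Here is a sketch of that advice that runs without the TensorFlow module. ShapeOnlyTensor is a hypothetical stand-in that only tracks a shape, and its shape2d property is modeled on the shape2d helper used in the matmul example; the asserts spell out the same contract that TFP would try to check statically.

```swift
/// Hypothetical stand-in for a tensor that only tracks its shape.
struct ShapeOnlyTensor {
  let shape: [Int]
  var rank: Int { shape.count }

  /// Destructures a rank-2 shape (modeled on the `shape2d` helper in the matmul example).
  var shape2d: (Int, Int) {
    precondition(rank == 2, "expected a rank-2 tensor, got rank \(rank)")
    return (shape[0], shape[1])
  }
}

/// Shape-level matmul: the asserts document the contract as machine-checked code.
func matmul(_ x: ShapeOnlyTensor, _ y: ShapeOnlyTensor) -> ShapeOnlyTensor {
  let (n, mx) = x.shape2d
  let (my, k) = y.shape2d
  assert(mx == my, "inner dimensions differ: \(mx) vs \(my)")
  let r = ShapeOnlyTensor(shape: [n, k])
  assert(r.shape == [n, k])
  return r
}

let a = ShapeOnlyTensor(shape: [2, 3])
let b = ShapeOnlyTensor(shape: [3, 4])
print(matmul(a, b).shape)  // [2, 4]
```

At runtime these asserts only fire in debug builds (calling matmul(a, a) would trap, since 3 != 2); the point of TFP is to report such a mismatch without running the program.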

Notable limitations

Most of those will be lifted at some point in the future, but they will require extra work.

  • Currently only the Tensor type from the TensorFlow module is recognized as a multidimensional array.
  • Analysis is limited to a single file (in particular, there's no support for verification across modules).

Recognized constraints

Here are a few examples of expressions that you can assert and that TFP will recognize:

x.rank == 2
x.rank == y.rank
x.shape == y.shape
x.shape[0] == y.shape[1]
x.shape[0] == 5
x.shape[0] == (y.shape[1] - z.shape[2] + 1) / 2
x.shape == [y.shape[0], 4]

Note that the full expression does not have to appear within the assert call. The following three asserts are equivalent from the point of view of TFP:

// 1.
assert(x.shape[0] == y.shape[1] + 2)

// 2.
let newShape = y.shape[1] + 2
assert(x.shape[0] == newShape)

// 3.
func getNewShape<T>(_ y: Tensor<T>) -> Int {
    return y.shape[1] + 2
}
let cond = x.shape[0] == getNewShape(y)
assert(cond)

(Semi-)Formal grammar of supported expressions

ShapeExpr ::= <variable>
            | [IntExpr, ..., IntExpr]
            // This is supported, but requires some hacky workarounds for now.
            | broadcast(ShapeExpr, ShapeExpr)

IntExpr   ::= <variable>
            | <literal>
            | ShapeExpr.rank
            | ShapeExpr[<constant>]
            | IntExpr + IntExpr
            | IntExpr - IntExpr
            | IntExpr * IntExpr
            | IntExpr / IntExpr

BoolExpr  ::= true
            | false
            | <variable>
            | IntExpr == IntExpr
            | IntExpr > IntExpr
            | IntExpr >= IntExpr
            | IntExpr < IntExpr
            | IntExpr <= IntExpr
            | ShapeExpr == ShapeExpr
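
One way to picture this grammar is as a set of Swift enums. The sketch below is a hypothetical encoding, not TFP's internal representation; it adds an evaluator that checks a BoolExpr against concrete shapes, and omits variables, broadcast and most comparison operators to stay short. Integer division follows Swift's Int `/`, which truncates toward zero.

```swift
// Hypothetical encoding of a subset of the grammar above.
indirect enum ShapeExpr {
  case tensor(String)              // shape of a named tensor, e.g. x.shape
  case list([IntExpr])             // [IntExpr, ..., IntExpr]
}

indirect enum IntExpr {
  case literal(Int)
  case rank(ShapeExpr)             // ShapeExpr.rank
  case element(Int, of: ShapeExpr) // ShapeExpr[<constant>]
  case add(IntExpr, IntExpr), sub(IntExpr, IntExpr)
  case mul(IntExpr, IntExpr), div(IntExpr, IntExpr)
}

enum BoolExpr {
  case intEq(IntExpr, IntExpr)
  case lessThan(IntExpr, IntExpr)
  case shapeEq(ShapeExpr, ShapeExpr)
}

/// Concrete shapes for named tensors; assumes every referenced tensor is present.
struct Env {
  let shapes: [String: [Int]]

  func eval(_ s: ShapeExpr) -> [Int] {
    switch s {
    case .tensor(let name): return shapes[name]!
    case .list(let dims): return dims.map { eval($0) }
    }
  }

  func eval(_ e: IntExpr) -> Int {
    switch e {
    case .literal(let n): return n
    case .rank(let s): return eval(s).count
    case .element(let i, of: let s): return eval(s)[i]
    case .add(let l, let r): return eval(l) + eval(r)
    case .sub(let l, let r): return eval(l) - eval(r)
    case .mul(let l, let r): return eval(l) * eval(r)
    case .div(let l, let r): return eval(l) / eval(r)
    }
  }

  func eval(_ b: BoolExpr) -> Bool {
    switch b {
    case .intEq(let l, let r): return eval(l) == eval(r)
    case .lessThan(let l, let r): return eval(l) < eval(r)
    case .shapeEq(let l, let r): return eval(l) == eval(r)
    }
  }
}

let env = Env(shapes: ["x": [3, 8], "y": [1, 7], "z": [2, 2, 2]])
// x.shape[0] == (y.shape[1] - z.shape[2] + 1) / 2
let cond = BoolExpr.intEq(
  .element(0, of: .tensor("x")),
  .div(.add(.sub(.element(1, of: .tensor("y")), .element(2, of: .tensor("z"))), .literal(1)),
       .literal(2)))
print(env.eval(cond))  // true: (7 - 2 + 1) / 2 == 3
```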

How to use

Note that the tool requires you to install the Z3 SMT solver before you try to run it. It can be obtained from brew (as z3) or from apt (libz3-dev).

To analyze a file example.swift, execute swift run doesitfit example.swift. You can find some examples to play with in the Examples/ directory.

We understand if you don't feel like doing it just yet, so we'll also walk you through a basic case. Assume that example.swift contains the following:

import TensorFlow

func randn(_ shape: TensorShape) -> Tensor<Float> {
  let result = Tensor<Float>(randomNormal: shape)
  assert(result.shape == shape)
  return result
}

func matmul(_ x: Tensor<Float>, _ y: Tensor<Float>) -> Tensor<Float> {
  assert(x.rank == 2)
  assert(y.rank == 2)
  assert(x.shape[1] == y.shape[0])
  let r = TensorFlow.matmul(x, y)
  assert(r.shape == [x.shape[0], y.shape[1]])
  return r
}

func f() -> Tensor<Float> {
  let x = randn([2, 3])
  return matmul(x, x)
}

Running the tool on it, you'll see output similar to this:

In $s4main1f10TensorFlow0B0VySfGyF:
❌ Something doesn't fit!
  - 3 = 2
      Asserted at small.swift:12
            |   assert(y.rank == 2)
         12 |   assert(x.shape[1] == y.shape[0])
            |   let r = TensorFlow.matmul(x, y)

In each report, the identifier after "In" (starting with "$s") is the mangled name of a Swift function in your module, so e.g. $s4main1f10TensorFlow0B0VySfGyF really means main.f() -> TensorFlow.Tensor<Swift.Float>. In the future those will get demangled before we display them, but for now you can try piping the output through swift-demangle (if you have it installed). What follows is a message that either tells you that TFP doesn't see any issue (assuming that this function gets executed), or a list of assertions showing that any attempt to execute it will cause a shape mismatch.

If the assert is actually in a function invoked from the analyzed one, it might be helpful to use the --stacks flag to see where the assert originates from. If you want a very detailed view, you can try adding the --signatures flag to the invocation, but the output will usually get extremely verbose and hard to read, even for very simple examples.

swift-tfp's People

Contributors

apaszke

swift-tfp's Issues

Simplify through shape indexing

Right now when one plays with the CNN example, they can get failures that look like this:

In $s4mainAAy10TensorFlow0B0VySfGAEF:
❌ Something doesn't fit!
  - s0[2] = 28
      Asserted at Examples/cnn.swift:235
            |   assert(iH == 28)
        235 |   assert(iW == 28)
            | 
  - s0[1] = 28
      Asserted at Examples/cnn.swift:234
            |   assert(iC == 1)
        234 |   assert(iH == 28)
            |   assert(iW == 28)
  - ((((((((s0[1] - 5) + 1) - 2) / 2 + 1) - 5) + 1) - 2) / 2 + 1) * ((((((((s0[2] - 5) + 1) - 2) / 2 + 1) - 5) + 1) - 2) / 2 + 1) * 64 = 2048
      Asserted at Examples/cnn.swift:138
            |   let (my, k) = y.shape2d
        138 |   assert(mx == my)
            |   let r = TensorFlow.matmul(x, y)

This is because our substitution mechanisms can only replace variables, but s0[0] is really .element(0, of: .var(s0)). We should improve our simplification pass to handle this common case as well.
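
A minimal sketch of the proposed fix, assuming a trimmed-down expression type: substitution is keyed on whole IntExpr terms, so an .element(i, of: s) node can be replaced just like a variable, after which constant folding collapses the surrounding arithmetic. The types are hypothetical (the case is called .variable rather than .var to avoid the keyword), and the expression is the first factor of the failing constraint in the report above, with s0[1] = 28.

```swift
// Hypothetical, trimmed-down expression types (not TFP's actual ones).
enum ShapeExpr: Hashable {
  case variable(String)
}

indirect enum IntExpr: Hashable {
  case literal(Int)
  case element(Int, of: ShapeExpr)  // s[i]
  case add(IntExpr, IntExpr)
  case sub(IntExpr, IntExpr)
  case div(IntExpr, IntExpr)
}

/// Folds `op` when both operands are literals and `op` succeeds; otherwise rebuilds the node.
func fold(_ a: IntExpr, _ b: IntExpr,
          rebuild: (IntExpr, IntExpr) -> IntExpr,
          op: (Int, Int) -> Int?) -> IntExpr {
  if case let .literal(x) = a, case let .literal(y) = b, let v = op(x, y) {
    return .literal(v)
  }
  return rebuild(a, b)
}

/// Substitutes any subterm found in `known` (including `.element(i, of: s)` terms,
/// not just variables), then constant-folds bottom-up.
func simplify(_ e: IntExpr, known: [IntExpr: Int]) -> IntExpr {
  if let v = known[e] { return .literal(v) }
  switch e {
  case .literal, .element:
    return e
  case let .add(a, b):
    return fold(simplify(a, known: known), simplify(b, known: known),
                rebuild: IntExpr.add, op: { $0 + $1 })
  case let .sub(a, b):
    return fold(simplify(a, known: known), simplify(b, known: known),
                rebuild: IntExpr.sub, op: { $0 - $1 })
  case let .div(a, b):
    // Swift integer division; division by zero is left unfolded.
    return fold(simplify(a, known: known), simplify(b, known: known),
                rebuild: IntExpr.div, op: { $1 == 0 ? nil : $0 / $1 })
  }
}

let s0 = ShapeExpr.variable("s0")
// ((((s0[1] - 5) + 1) - 2) / 2 + 1)
let h: IntExpr = .add(.div(.sub(.add(.sub(.element(1, of: s0), .literal(5)), .literal(1)),
                                .literal(2)), .literal(2)), .literal(1))
let known: [IntExpr: Int] = [.element(1, of: s0): 28]
let folded = simplify(h, known: known)  // folds to .literal(12)
```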

Implement a mem2reg pass

Right now the tool cannot reason about variables declared as var, because the SIL will access those through a stack pointer. We should have a mem2reg pass that will lower stack pointers to regular SSA dataflow.

Improve broadcast encoding or simplification

Broadcasts are somewhat expensive to encode in Z3 because they require universal quantifiers. For example, an expression like broadcast([1, nil], broadcast(s0, [nil, 2])) is equivalent to broadcast(s0, [1, 2]). We could do a much better job at simplifying such patterns (hint: broadcasting is monoidal over lists!).
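
The monoid structure is easiest to see on concrete shapes. This sketch assumes NumPy-style broadcasting on fully known dimensions (it does not model the nil entries used above) and folds a list of shapes with the scalar shape [] as the identity element. Because the operation is associative, nested broadcasts can be flattened and regrouped freely, which is the kind of rewrite the hint suggests.

```swift
/// NumPy-style broadcasting of two concrete shapes: right-aligned, and a missing
/// or size-1 dimension stretches to match. Returns nil if the shapes are incompatible.
func broadcast(_ a: [Int], _ b: [Int]) -> [Int]? {
  let n = max(a.count, b.count)
  // Left-pad the shorter shape with 1s so both have rank n.
  let pa = Array(repeating: 1, count: n - a.count) + a
  let pb = Array(repeating: 1, count: n - b.count) + b
  var result: [Int] = []
  result.reserveCapacity(n)
  for (x, y) in zip(pa, pb) {
    if x == y || y == 1 {
      result.append(x)
    } else if x == 1 {
      result.append(y)
    } else {
      return nil
    }
  }
  return result
}

/// Folds a whole list of shapes; [] is the identity, and associativity means
/// any grouping of the same list gives the same result.
func broadcastAll(_ shapes: [[Int]]) -> [Int]? {
  let identity: [Int]? = []
  return shapes.reduce(identity) { acc, s in acc.flatMap { broadcast($0, s) } }
}

let stretched = broadcast([5, 1, 3], [4, 1])  // [5, 4, 3]
let mismatch = broadcast([2, 3], [3, 2])      // nil: 2 vs 3 in the first position
```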

Make the `struct` body analyzer less restrictive

Right now we have a very limited whitelist of AST nodes that can appear inside a struct definition if we want it to parse correctly. I have mostly reverse-engineered this list from a simple example, but that doesn't mean that more complicated things cannot appear there. This list likely needs to be bigger, or TFP will refuse to analyze more complex struct declarations.

Make the verification more efficient

Right now we run verification on every single function without inspecting the call chain. This is unnecessary, because if we know that the verification of a caller succeeds, then there's no need to verify the callee.

We should build up the call graph and only attempt verification from functions that don't have any callers.
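
A sketch of the root-finding step, assuming the call graph is available as a map from each function to the functions it calls (a hypothetical representation): verification would then start only from the returned functions.

```swift
/// Returns the functions that no *other* function calls, sorted for determinism.
/// A function that only calls itself still counts as a root.
func verificationRoots(_ calls: [String: [String]]) -> [String] {
  var called = Set<String>()
  for (caller, callees) in calls {
    for callee in callees where callee != caller {
      called.insert(callee)
    }
  }
  return calls.keys.filter { !called.contains($0) }.sorted()
}

let roots = verificationRoots(["f": ["matmul", "randn"], "matmul": [], "randn": []])  // ["f"]
```

One caveat: a cycle such as f calling g and g calling f, with no outside caller, yields no root at all, so a real implementation would probably need to treat strongly connected components as units.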

Handle unreachable blocks better

There are two problems with the way we handle those now. Firstly, we still try to verify constraints that appear on unreachable paths, which we shouldn't do. Secondly, consider this situation:

bb0:
  cond = ...
  cond_br ..., bb1(), bb2()
bb1:
  ...
  return ...
bb2:
  unreachable

There's no reason why we should consider the body of bb1 to have an extra cond path assumption. We should really treat that as if the program contained assert(cond). Nota bene: this is the pattern that assert statements compile to if you don't disable mandatory inlining (as we hackily do), so handling it would also let us stop doing that.
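
A toy model of the proposed treatment, using hypothetical types rather than TFP's SIL representation: when exactly one successor of a conditional branch is an unreachable block, the condition leading to the live successor is recorded as an assertion instead of a path condition.

```swift
/// Hypothetical block terminators, just enough to express the pattern above.
enum Terminator {
  case condBr(cond: String, ifTrue: String, ifFalse: String)
  case ret
  case unreachable
}

/// Collects the assertions implied by branches that have exactly one unreachable successor.
func impliedAssertions(_ blocks: [String: Terminator]) -> [String] {
  func isUnreachable(_ label: String) -> Bool {
    if case .some(.unreachable) = blocks[label] { return true }
    return false
  }
  var result: [String] = []
  for label in blocks.keys.sorted() {
    guard case let .condBr(cond, ifTrue, ifFalse) = blocks[label]! else { continue }
    switch (isUnreachable(ifTrue), isUnreachable(ifFalse)) {
    case (false, true): result.append("assert(\(cond))")   // only the true edge is live
    case (true, false): result.append("assert(!\(cond))")  // only the false edge is live
    default: break
    }
  }
  return result
}

let cfg: [String: Terminator] = [
  "bb0": .condBr(cond: "cond", ifTrue: "bb1", ifFalse: "bb2"),
  "bb1": .ret,
  "bb2": .unreachable,
]
print(impliedAssertions(cfg))  // ["assert(cond)"]
```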

Support custom `init` in struct definitions

Right now TFP doesn't support explicit init inside a struct, and forces you to write out a helper function that uses the default constructor (see the CNN example). While this doesn't lower the expressive power of the tool, it's simply annoying and we should fix this.

The reason why this restriction exists is that the default constructor contains a single struct instruction that simply puts all of its arguments into a new struct instance. If you do this instead:

struct X {
  let x: Int
  let y: Int
  init(same: Int) {
    self.x = same
    self.y = same
  }
}

then the emitted SIL will actually allocate X on the stack, and use pointers-to-members to fill in each field separately. Since TFP cannot handle pointers in its analysis we should detect such patterns and replace them with struct instructions.

Have a cheaper sorting key than `constraint.description`

Right now we sometimes sort constraints because we want the program to be deterministic. However, because they don't have < defined we always compare them by their .description which is dumb and slow. We should have a better comparison function.

Improve simplification for .and and .or

For example we should collapse .and([b1, .and([b2, b3])]) into .and([b1, b2, b3]). Also, things like .or([b1, .not(b1)]) should simplify to .true. This last one is especially important because this is a pattern that appears in path conditions all the time.
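
A minimal sketch of both rewrites on a hypothetical BoolExpr type, where .literal(true) and .literal(false) stand in for .true and .false. Nested conjunctions and disjunctions are flattened, neutral elements are dropped, and a junction that contains both b and .not(b) collapses to a constant.

```swift
// Hypothetical boolean expression type (not TFP's actual one).
indirect enum BoolExpr: Hashable {
  case literal(Bool)
  case atom(String)
  case not(BoolExpr)
  case and([BoolExpr])
  case or([BoolExpr])
}

func simplify(_ e: BoolExpr) -> BoolExpr {
  switch e {
  case .literal, .atom:
    return e
  case .not(let inner):
    switch simplify(inner) {
    case .literal(let v): return .literal(!v)
    case .not(let x): return x  // double negation
    case let s: return .not(s)
    }
  case .and(let parts): return simplifyJunction(parts, isAnd: true)
  case .or(let parts): return simplifyJunction(parts, isAnd: false)
  }
}

/// Shared logic for .and (isAnd == true) and .or (isAnd == false).
func simplifyJunction(_ parts: [BoolExpr], isAnd: Bool) -> BoolExpr {
  var flat: [BoolExpr] = []
  for p in parts.map(simplify) {
    switch p {
    case .and(let nested) where isAnd: flat += nested   // .and([b1, .and([b2, b3])]) -> .and([b1, b2, b3])
    case .or(let nested) where !isAnd: flat += nested
    case .literal(let v) where v == isAnd: continue     // neutral element: true for and, false for or
    case .literal: return .literal(!isAnd)              // absorbing element: false for and, true for or
    default: flat.append(p)
    }
  }
  // b && !b is false, and b || !b is true (the pattern common in path conditions).
  let present = Set(flat)
  if flat.contains(where: { present.contains(.not($0)) }) { return .literal(!isAnd) }
  if flat.isEmpty { return .literal(isAnd) }
  return flat.count == 1 ? flat[0] : (isAnd ? .and(flat) : .or(flat))
}

let b1 = BoolExpr.atom("b1"), b2 = BoolExpr.atom("b2"), b3 = BoolExpr.atom("b3")
let flattened = simplify(.and([b1, .and([b2, b3])]))  // .and([b1, b2, b3])
let tautology = simplify(.or([b1, .not(b1)]))          // .literal(true)
```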

Build Failure on Linux

Ubuntu 18.04. I installed z3 4.4.1 with sudo apt install libz3-dev. swift build produces these errors:

/home/xander/dev/swift-tfp/Sources/LibTFP/Solvers/Z3/Z3.swift:151:11: error: type 'Z3_bool' (aka 'Int32') cannot be used as a
boolean; test for '!= 0' instead
    guard Z3_get_numeral_int64(ctx.ctx, interpretation, &result) else {
          ^
          (                                                      != 0)

Swift for Tensorflow 0.12.

Construct models only when necessary

Right now we always run with model construction enabled, even if we don't have any holes (which is quite common). We should disable it in this case.

Improve handling of `switch_enum` terminators in the frontend

Right now a block that joins all paths outgoing from a switch_enum instruction will not have a path condition that's equivalent to true. That's bad, but there's an easy solution to this! unloop already desugars switch_enum to a sequence of cond_br instructions, so we should separate that into a new pass and apply it irrespective of whether the CFG contains loops or not.
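
A sketch of what a standalone desugaring could look like, using hypothetical block and terminator types (this is not the existing unloop code): every case except the last becomes a cond_br that tests the enum against that case, and the last case becomes the final else-branch, assuming the switch is exhaustive.

```swift
/// Hypothetical terminators and blocks, just enough for the rewrite.
enum Terminator: Equatable {
  case condBr(cond: String, ifTrue: String, ifFalse: String)
  case br(String)
}

struct Block: Equatable {
  var label: String
  var terminator: Terminator
}

/// Rewrites `switch_enum value, case c1: t1, ..., case cn: tn` (placed in block `label`)
/// into a chain of cond_br blocks. New intermediate blocks get hypothetical names.
func desugarSwitchEnum(_ label: String, value: String,
                       cases: [String], targets: [String]) -> [Block] {
  precondition(!cases.isEmpty && cases.count == targets.count)
  if cases.count == 1 { return [Block(label: label, terminator: .br(targets[0]))] }
  var blocks: [Block] = []
  var current = label
  for i in 0..<(cases.count - 1) {
    // The last test falls through directly to the last case's target.
    let next = i == cases.count - 2 ? targets[i + 1] : "\(label)_case\(i + 1)"
    blocks.append(Block(label: current,
                        terminator: .condBr(cond: "\(value) is \(cases[i])",
                                            ifTrue: targets[i], ifFalse: next)))
    current = next
  }
  return blocks
}

// Produces bb0: cond_br (v is a), bb1, bb0_case1 and bb0_case1: cond_br (v is b), bb2, bb3.
let chain = desugarSwitchEnum("bb0", value: "v", cases: ["a", "b", "c"],
                              targets: ["bb1", "bb2", "bb3"])
```

With this shape, the join block's path condition becomes a disjunction of cond_br outcomes such as (v is a) || (!(v is a) && (v is b)) || (!(v is a) && !(v is b)), which is a tautology.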

Build failure on macOS

macOS 11.1. I installed z3 4.8.9 with brew install z3 and swift build produces these errors:

/Users/xander/dev/swift-tfp/Sources/LibTFP/Solvers/Z3/Z3.swift:20:12: error: cannot find type 'Z3_context' in scope
  var ctx: Z3_context
           ^~~~~~~~~~
/Users/xander/dev/swift-tfp/Sources/LibTFP/Solvers/Z3/Z3.swift:21:16: error: cannot find type 'Z3_sort' in scope
  let intSort: Z3_sort
               ^~~~~~~
/Users/xander/dev/swift-tfp/Sources/LibTFP/Solvers/Z3/Z3.swift:22:17: error: cannot find type 'Z3_sort' in scope
  let boolSort: Z3_sort
                ^~~~~~~
/Users/xander/dev/swift-tfp/Sources/LibTFP/Solvers/Z3/Z3.swift:23:48: error: cannot find 'Z3_mk_true' in scope
  lazy var `true`: Z3Expr<Bool> = Z3Expr(self, Z3_mk_true(ctx))
                                               ^~~~~~~~~~
/Users/xander/dev/swift-tfp/Sources/LibTFP/Solvers/Z3/Z3.swift:24:49: error: cannot find 'Z3_mk_false' in scope
  lazy var `false`: Z3Expr<Bool> = Z3Expr(self, Z3_mk_false(ctx))
                                                ^~~~~~~~~~~
/Users/xander/dev/swift-tfp/Sources/LibTFP/Solvers/Z3/Z3.swift:66:15: error: cannot find type 'Z3_solver' in scope
  var solver: Z3_solver
              ^~~~~~~~~
/Users/xander/dev/swift-tfp/Sources/LibTFP/Solvers/Z3/Z3.swift:134:14: error: cannot find type 'Z3_model' in scope
  var model: Z3_model
             ^~~~~~~~
/Users/xander/dev/swift-tfp/Sources/LibTFP/Solvers/Z3/Z3.swift:140:35: error: cannot find type 'Z3_model' in scope
  init(_ ctx: Z3Context, _ model: Z3_model) {
                                  ^~~~~~~~
/Users/xander/dev/swift-tfp/Sources/LibTFP/Solvers/Z3/Z3.swift:165:12: error: cannot find type 'Z3_ast' in scope
  var ast: Z3_ast
           ^~~~~~
/Users/xander/dev/swift-tfp/Sources/LibTFP/Solvers/Z3/Z3.swift:172:33: error: cannot find type 'Z3_ast' in scope
  init(_ ctx: Z3Context, _ ast: Z3_ast) {
                                ^~~~~~
/Users/xander/dev/swift-tfp/Sources/LibTFP/Solvers/Z3/Z3.swift:185:33: error: cannot find type 'Z3_context' in scope
                       _ cstr: (Z3_context?, UInt32, UnsafePointer<Z3_ast?>?) -> Z3_ast?) -> Z3Expr<C> {
                                ^~~~~~~~~~
/Users/xander/dev/swift-tfp/Sources/LibTFP/Solvers/Z3/Z3.swift:185:68: error: cannot find type 'Z3_ast' in scope
                       _ cstr: (Z3_context?, UInt32, UnsafePointer<Z3_ast?>?) -> Z3_ast?) -> Z3Expr<C> {
                                                                   ^~~~~~
/Users/xander/dev/swift-tfp/Sources/LibTFP/Solvers/Z3/Z3.swift:185:82: error: cannot find type 'Z3_ast' in scope
                       _ cstr: (Z3_context?, UInt32, UnsafePointer<Z3_ast?>?) -> Z3_ast?) -> Z3Expr<C> {
                                                                                 ^~~~~~
/Users/xander/dev/swift-tfp/Sources/LibTFP/Solvers/Z3/Z3.swift:191:33: error: cannot find type 'Z3_context' in scope
                       _ cstr: (Z3_context?, Z3_ast?, Z3_ast?) -> Z3_ast?) -> Z3Expr<C> {
                                ^~~~~~~~~~
/Users/xander/dev/swift-tfp/Sources/LibTFP/Solvers/Z3/Z3.swift:191:46: error: cannot find type 'Z3_ast' in scope
                       _ cstr: (Z3_context?, Z3_ast?, Z3_ast?) -> Z3_ast?) -> Z3Expr<C> {
                                             ^~~~~~
/Users/xander/dev/swift-tfp/Sources/LibTFP/Solvers/Z3/Z3.swift:191:55: error: cannot find type 'Z3_ast' in scope
                       _ cstr: (Z3_context?, Z3_ast?, Z3_ast?) -> Z3_ast?) -> Z3Expr<C> {
                                                      ^~~~~~
/Users/xander/dev/swift-tfp/Sources/LibTFP/Solvers/Z3/Z3.swift:191:67: error: cannot find type 'Z3_ast' in scope
                       _ cstr: (Z3_context?, Z3_ast?, Z3_ast?) -> Z3_ast?) -> Z3Expr<C> {
                                                                  ^~~~~~
/Users/xander/dev/swift-tfp/Sources/LibTFP/Solvers/Z3/Z3.swift:27:17: error: cannot find type 'Z3_config' in scope
    var config: Z3_config = Z3_mk_config()
                ^~~~~~~~~

Swift for Tensorflow 0.12.
