feat: add @config.default #1213

Open
zilto opened this issue Nov 1, 2024 · 1 comment
Labels
enhancement New feature or request

Comments

@zilto
Collaborator

zilto commented Nov 1, 2024

This came out of a discussion with a user about the backwards compatibility of a dataflow.

Currently, @config offers 4 options:

  • @config.when(key="foo"): select this implementation when the config value for key equals "foo"
  • @config.when_not(key="foo"): select this implementation when the config value for key does not equal "foo"
  • @config.when_in(key=["foo", "bar"]): select this implementation when the config value for key is in the list
  • @config.when_not_in(key=["foo", "bar"]): select this implementation when the config value for key is not in the list

This covers a lot of cases, but there's no way to specify a default.
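To make the four options concrete, here is a minimal pure-Python model that reads each one as a predicate over the config dict. It assumes a key missing from the config behaves like None (i.e., `config.get(key)`); the functions below are hypothetical stand-ins, not Hamilton's implementation.

```python
from typing import Any, Callable, Dict, Iterable

Config = Dict[str, Any]
Predicate = Callable[[Config], bool]

# Hypothetical stand-ins for the four @config options.
# Assumption: a missing key is read as None via cfg.get(key).


def when(**pairs: Any) -> Predicate:
    # True only if every key equals its expected value
    return lambda cfg: all(cfg.get(k) == v for k, v in pairs.items())


def when_not(**pairs: Any) -> Predicate:
    # True only if no key equals its given value
    return lambda cfg: all(cfg.get(k) != v for k, v in pairs.items())


def when_in(**groups: Iterable[Any]) -> Predicate:
    # True only if every key's value is in its list
    return lambda cfg: all(cfg.get(k) in vs for k, vs in groups.items())


def when_not_in(**groups: Iterable[Any]) -> Predicate:
    # True only if no key's value is in its list
    return lambda cfg: all(cfg.get(k) not in vs for k, vs in groups.items())


print(when(key="foo")({"key": "foo"}))                  # True
print(when(key="foo")({}))                              # False: missing key is None
print(when_not(key="foo")({}))                          # True: None != "foo"
print(when_in(key=["foo", "bar"])({"key": "bar"}))      # True
print(when_not_in(key=["foo", "bar"])({"key": "baz"}))  # True
```

None of these predicates can express "none of the other implementations matched", which is the gap this issue is about.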

Example 1

Here's a simple illustration of the limitations around backwards compatibility.

This is version 1:

# dataflow.py
def foo() -> int:
    return 1

# run.py
import dataflow
from hamilton import driver

dr = driver.Builder().with_modules(dataflow).build()
dr.execute(["foo"])

Example 2

Now I'm adding a version 2, and I want version 1 to remain the default.

Problem

If you use @config.when(version="1") and @config.when(version="2"), this can break downstream Drivers, because the DAG will have no node foo when .with_config() isn't set.

# dataflow.py
from hamilton.function_modifiers import config

@config.when(version="1")
def foo__v1() -> int:
    return 1

@config.when(version="2")
def foo__v2() -> int:
    return 2

# run.py
import dataflow
from hamilton import driver

# breaks because `.with_config()` didn't set `version="1"` or `version="2"`
dr = driver.Builder().with_modules(dataflow).build()
dr.execute(["foo"])

Solution

The best workaround is to annotate foo__v1 with when_not(version="2") so it catches all other configurations (including empty ones, i.e., when .with_config() is not present).

# dataflow.py
from hamilton.function_modifiers import config

@config.when_not(version="2")
def foo__v1() -> int:
    return 1

@config.when(version="2")
def foo__v2() -> int:
    return 2

# run.py
import dataflow
from hamilton import driver

dr = driver.Builder().with_modules(dataflow).build()
dr.execute(["foo"])
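The difference between the two decorator setups in this example can be checked by reducing each implementation to its condition on the config. This is a simplified model (a missing key reads as None), with a hypothetical `resolved` helper, not Hamilton's resolution code:

```python
from typing import Any, Callable, Dict, List, Tuple

Config = Dict[str, Any]
# (implementation name, condition) pairs; a simplified model, not Hamilton's API
Variants = List[Tuple[str, Callable[[Config], bool]]]

# Problem setup: when(version="1") and when(version="2")
when_when: Variants = [
    ("foo__v1", lambda cfg: cfg.get("version") == "1"),
    ("foo__v2", lambda cfg: cfg.get("version") == "2"),
]

# Workaround setup: when_not(version="2") and when(version="2")
when_not_when: Variants = [
    ("foo__v1", lambda cfg: cfg.get("version") != "2"),
    ("foo__v2", lambda cfg: cfg.get("version") == "2"),
]


def resolved(variants: Variants, cfg: Config) -> List[str]:
    """Names of the implementations whose condition holds for cfg."""
    return [name for name, condition in variants if condition(cfg)]


for cfg in ({}, {"version": "1"}, {"version": "2"}):
    print(cfg, resolved(when_when, cfg), resolved(when_not_when, cfg))
# {} [] ['foo__v1']
# {'version': '1'} ['foo__v1'] ['foo__v1']
# {'version': '2'} ['foo__v2'] ['foo__v2']
```

Only the empty config differs: with two when() branches nothing resolves, so foo is missing from the DAG.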

Example 3

Now I'm adding a third implementation, version 3.

Problem

If I keep my previous code and add @config.when(version="3"), it will never be hit, because the existing when_not(version="2") on foo__v1 also catches this configuration.

# dataflow.py
from hamilton.function_modifiers import config

@config.when_not(version="2")
def foo__v1() -> int:
    return 1

@config.when(version="2")
def foo__v2() -> int:
    return 2

@config.when(version="3")
def foo__v3() -> int:
    return 3

# run.py
import dataflow
from hamilton import driver

# no error is raised, but `foo__v1` is actually used
dr = driver.Builder().with_config({"version": "3"}).with_modules(dataflow).build()
dr.execute(["foo"])
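Reducing the decorators to plain predicates (the same simplified model, where a missing key reads as None) shows the overlap: the catch-all condition on foo__v1 also holds for version="3", so the new implementation is never the only candidate for foo. This is an illustration of the conditions only, not of how Hamilton breaks the tie.

```python
from typing import Any, Callable, Dict, List, Tuple

Config = Dict[str, Any]

# Simplified model: each implementation reduced to its condition on the config
variants: List[Tuple[str, Callable[[Config], bool]]] = [
    ("foo__v1", lambda cfg: cfg.get("version") != "2"),  # when_not(version="2")
    ("foo__v2", lambda cfg: cfg.get("version") == "2"),  # when(version="2")
    ("foo__v3", lambda cfg: cfg.get("version") == "3"),  # when(version="3")
]

cfg: Config = {"version": "3"}
matches = [name for name, condition in variants if condition(cfg)]
print(matches)  # ['foo__v1', 'foo__v3']
```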

Solution

The user has to change the decorator on foo__v1() to when_not_in(version=["2", "3"]) so it catches all remaining configurations.

The next problem is that whenever an implementation is added, you need to remember to add it to this list; otherwise, foo__v1 will silently catch the new version="4".

# dataflow.py
from hamilton.function_modifiers import config

@config.when_not_in(version=["2", "3"])
def foo__v1() -> int:
    return 1

@config.when(version="2")
def foo__v2() -> int:
    return 2

@config.when(version="3")
def foo__v3() -> int:
    return 3

# run.py
import dataflow
from hamilton import driver

dr = driver.Builder().with_modules(dataflow).build()
dr.execute(["foo"])
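A small model of the drift shows both the symptom and a cheap check. Plain Python sets stand in for the decorator arguments (the names `implemented` and `excluded` are hypothetical): once foo__v4 is added but the when_not_in list is not updated, the default's condition still holds for version="4".

```python
from typing import List, Set

# Versions with a dedicated @config.when(version=...) implementation
implemented: Set[str] = {"2", "3", "4"}  # "4" was just added
# Versions listed in the default's @config.when_not_in(version=[...])
excluded: Set[str] = {"2", "3"}  # ...but this list was not updated

version = "4"
matches: List[str] = []
if version not in excluded:  # the default's condition still holds
    matches.append("foo__v1")
matches += [f"foo__v{v}" for v in sorted(implemented) if v == version]
print(matches)  # ['foo__v1', 'foo__v4']

# Cheap consistency check: versions the default would still catch
print(implemented - excluded)  # {'4'}
```

Keeping the explicit versions in one module-level constant and passing it to both the when() branches and the when_not_in() list would reduce, but not remove, this bookkeeping.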

Consequences

The main issue is backwards compatibility. When refactoring from a single implementation to two, users have to carefully combine .when() and .when_not(); otherwise, they will break Drivers that don't set a config. Then, when moving from 2 to 3+ implementations, they have to use when_not_in() and manually maintain a list. It is also not obvious from the code that when_not_in() means "default implementation".

Currently, using .when(version="1") and .when(version="2") implicitly raises an error on invalid configurations (e.g., version=-1), because the node foo would be missing, which will likely break a key path. If the missing node didn't cause an error (e.g., nothing requested depends on it), then whether the config was valid or not didn't matter.

This relates to a broader task of defining the space of valid configurations.
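One way to make the space of valid configurations explicit, instead of relying on a missing node to fail, is to check the config before building the Driver. The allowed-value table and `validate_config` below are hypothetical, not a Hamilton feature:

```python
from typing import Any, Collection, Dict, Mapping

# Hypothetical table of allowed values per config key
VALID_CONFIG: Mapping[str, Collection[Any]] = {"version": ("1", "2", "3")}


def validate_config(cfg: Dict[str, Any]) -> None:
    """Raise ValueError when a known key is set to an unsupported value."""
    for key, allowed in VALID_CONFIG.items():
        if key in cfg and cfg[key] not in allowed:
            raise ValueError(f"{key}={cfg[key]!r} is not one of {list(allowed)}")


validate_config({})                # OK: a missing key is left to the default
validate_config({"version": "2"})  # OK
try:
    validate_config({"version": -1})
except ValueError as err:
    print(err)  # version=-1 is not one of ['1', '2', '3']
```

It would be called on the same dict passed to `driver.Builder().with_config(...)`, so an invalid value fails early regardless of which nodes are requested.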

Solution

We should have a @config.default to ensure a node foo is always present in the DAG. Its name is also self-explanatory. When moving from 1 implementation to 2+, you face a clear design decision: do I want config.when for both v1 and v2, or a default and v2?

Using @config.default would mean "select this implementation if no other config is resolved". This condition needs to be resolved last, and a given node name can have at most one @config.default implementation.

# dataflow.py
from hamilton.function_modifiers import config

@config.default
def foo__v1() -> int:
    return 1

@config.when(version="2")
def foo__v2() -> int:
    return 2

@config.when(version="3")
def foo__v3() -> int:
    return 3

# run.py
import dataflow
from hamilton import driver

# passing no config means the `@config.default` implementation (foo__v1) is used
dr = driver.Builder().with_modules(dataflow).build()
dr.execute(["foo"])
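The proposed selection rule can be modeled for a single node in plain Python. `Variant` and `select` are hypothetical names, a None condition stands in for @config.default, and rejecting several simultaneous matches is an assumption of this sketch rather than part of the proposal:

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional

Config = Dict[str, Any]


@dataclass
class Variant:
    name: str
    # None stands in for @config.default; otherwise a predicate over the config
    condition: Optional[Callable[[Config], bool]] = None


def select(variants: List[Variant], cfg: Config) -> Optional[str]:
    """Pick one node's implementation under the proposed @config.default rule."""
    defaults = [v for v in variants if v.condition is None]
    if len(defaults) > 1:
        raise ValueError("at most one @config.default per node")
    matched = [v.name for v in variants if v.condition is not None and v.condition(cfg)]
    if len(matched) > 1:  # assumption: overlapping conditions are an error
        raise ValueError(f"ambiguous config, several matches: {matched}")
    if matched:
        return matched[0]
    # The default is resolved last: only when no other condition matched
    return defaults[0].name if defaults else None


foo = [
    Variant("foo__v1"),  # @config.default
    Variant("foo__v2", lambda cfg: cfg.get("version") == "2"),
    Variant("foo__v3", lambda cfg: cfg.get("version") == "3"),
]
print(select(foo, {}))                # foo__v1
print(select(foo, {"version": "3"}))  # foo__v3
print(select(foo, {"version": "4"}))  # foo__v1
```

Under this rule an unrecognized value such as version="4" also falls back to the default instead of leaving foo missing, which is the trade-off raised earlier about invalid configurations.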
@elijahbenizzy
Collaborator

elijahbenizzy commented Nov 4, 2024

This is a really good feature (I thought we had an issue for it a while back?). The hard part is that we need to store global state, largely due to the way we manage decorators internally.

  1. Each decorator creates 0+ nodes from a function (this is how it currently works)
  2. The first decorator (the default) wouldn't know whether to create a node or not, because that depends on the state of the others
  3. We'd have to store some state, e.g., a fallback-type mechanism that knows whether another implementation was already hit
  4. This would have to run in a second pass (unless we hack around it), e.g., we don't know when we've hit the last one...
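The second-pass idea from point 4 can be sketched as follows: conditional decorators are evaluated first, and defaults only fill node names that are still empty afterwards. This is a simplified model with hypothetical names (`build_nodes`, `base_name`), not how Hamilton's decorator machinery is structured:

```python
from typing import Any, Callable, Dict, List, Optional, Tuple

Config = Dict[str, Any]
# (function name, condition); a None condition marks a hypothetical @config.default
Decorated = List[Tuple[str, Optional[Callable[[Config], bool]]]]


def base_name(fn_name: str) -> str:
    """Hypothetical helper: foo__v1 -> foo, as in the examples above."""
    return fn_name.split("__", 1)[0]


def build_nodes(functions: Decorated, cfg: Config) -> Dict[str, str]:
    """Map each node name to the function that implements it."""
    nodes: Dict[str, str] = {}
    # Pass 1: every conditional decorator decides on its own (0 or 1 node)
    for fn_name, condition in functions:
        if condition is not None and condition(cfg):
            nodes[base_name(fn_name)] = fn_name
    # Pass 2: a default only fills a node name that pass 1 left empty
    for fn_name, condition in functions:
        if condition is None:
            nodes.setdefault(base_name(fn_name), fn_name)
    return nodes


functions: Decorated = [
    ("foo__v1", None),  # @config.default
    ("foo__v2", lambda cfg: cfg.get("version") == "2"),
    ("foo__v3", lambda cfg: cfg.get("version") == "3"),
]
print(build_nodes(functions, {}))                # {'foo': 'foo__v1'}
print(build_nodes(functions, {"version": "3"}))  # {'foo': 'foo__v3'}
```

The second loop is exactly the extra state and ordering the comment describes: defaults cannot be decided until every conditional decorator has run.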

So it might be possible to hack in, but it's a fundamental limitation. It would be easy enough if we knew which decorator was last, or if we added a second pass at some point. Alternatively, we might be able to do something like this:

# dataflow.py
from hamilton.function_modifiers import config

def foo__v1() -> int:
    return 1

@config.when(version="2")
def foo__v2() -> int:
    return 2

@config.when(version="3", otherwise=foo__v1)  # proposed: we know it's last -- if this doesn't resolve, fall back to foo__v1
def foo__v3() -> int:
    return 3

# run.py
import dataflow
from hamilton import driver

# passing no config means the `otherwise` fallback (foo__v1) is used
dr = driver.Builder().with_modules(dataflow).build()
dr.execute(["foo"])

I think we process in the right order so this should work, but we'd still need to keep some state...
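The `otherwise=` variant avoids a second pass but, as noted, still needs state carried across decorators. A minimal single-pass sketch, assuming functions are processed in declaration order and that `otherwise` only appears on the last variant of a node; the tuple layout and `build_in_order` are hypothetical:

```python
from typing import Any, Callable, Dict, List, Optional, Set, Tuple

Config = Dict[str, Any]
# (function name, condition, fallback function name or None)
Decorated = List[Tuple[str, Callable[[Config], bool], Optional[str]]]


def build_in_order(functions: Decorated, cfg: Config) -> Dict[str, str]:
    """Single pass in declaration order; `otherwise` sits on the last variant."""
    resolved: Set[str] = set()  # state carried across decorators
    nodes: Dict[str, str] = {}
    for fn_name, condition, otherwise in functions:
        name = fn_name.split("__", 1)[0]  # hypothetical: foo__v3 -> foo
        if condition(cfg):
            nodes[name] = fn_name
            resolved.add(name)
        elif otherwise is not None and name not in resolved:
            nodes[name] = otherwise  # nothing earlier matched: use the fallback
            resolved.add(name)
    return nodes


functions: Decorated = [
    ("foo__v2", lambda cfg: cfg.get("version") == "2", None),
    ("foo__v3", lambda cfg: cfg.get("version") == "3", "foo__v1"),  # otherwise=foo__v1
]
print(build_in_order(functions, {}))                # {'foo': 'foo__v1'}
print(build_in_order(functions, {"version": "2"}))  # {'foo': 'foo__v2'}
```

The `resolved` set is the state in question; the sketch is only correct if the variant carrying `otherwise` really is processed last.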

@zilto zilto added the enhancement New feature or request label Nov 8, 2024