Skip to content

Commit 6946597

Browse files
armishraclaude
andcommitted
feat(cli): add --config option to specify custom truss config path
Add support for specifying a custom config file path via the --config option in the push and watch commands. This allows users to use alternative config files instead of the default config.yaml. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent ccb97f3 commit 6946597

File tree

5 files changed

+208
-16
lines changed

5 files changed

+208
-16
lines changed

truss/base/truss_spec.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,25 @@ class TrussSpec:
1414
Helper class for easy access to information in a Truss.
1515
"""
1616

17-
def __init__(self, truss_dir: Path) -> None:
17+
def __init__(self, truss_dir: Path, config_path: Optional[Path] = None) -> None:
1818
self._truss_dir = truss_dir
19-
self._config = truss_config.TrussConfig.from_yaml(
20-
truss_dir / constants.CONFIG_FILE
21-
)
19+
self._config_path = self._resolve_config_path(truss_dir, config_path)
20+
self._config = truss_config.TrussConfig.from_yaml(self._config_path)
21+
22+
@staticmethod
23+
def _resolve_config_path(truss_dir: Path, config_path: Optional[Path]) -> Path:
24+
"""Return custom config path if provided, otherwise default to config.yaml."""
25+
if config_path:
26+
return config_path
27+
return truss_dir / constants.CONFIG_FILE
2228

2329
@property
2430
def truss_dir(self) -> Path:
2531
return self._truss_dir
2632

2733
@property
2834
def config_path(self) -> Path:
29-
return self._truss_dir / constants.CONFIG_FILE
35+
return self._config_path
3036

3137
@property
3238
def data_dir(self) -> Path:

truss/cli/cli.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,17 +73,20 @@
7373
}
7474

7575

76-
def _get_truss_from_directory(target_directory: Optional[str] = None):
76+
def _get_truss_from_directory(
77+
target_directory: Optional[str] = None, config: Optional[str] = None
78+
):
7779
"""Gets Truss from directory. If none, use the current directory"""
7880
if target_directory is None:
7981
target_directory = os.getcwd()
82+
config_path = Path(config) if config else None
8083
if not os.path.isfile(target_directory):
81-
return load(target_directory)
84+
return load(target_directory, config_path=config_path)
8285
# These imports are delayed, to handle pydantic v1 envs gracefully.
8386
from truss_chains.deployment import code_gen
8487

8588
truss_dir = code_gen.gen_truss_model_from_source(Path(target_directory))
86-
return load(truss_dir)
89+
return load(truss_dir, config_path=config_path)
8790

8891

8992
### Top-level & utility commands. ######################################################
@@ -408,6 +411,12 @@ def run_python(script, target_directory):
408411

409412
@truss_cli.command()
410413
@click.argument("target_directory", required=False, default=os.getcwd())
414+
@click.option(
415+
"--config",
416+
type=click.Path(exists=True),
417+
required=False,
418+
help="Path to a custom config file (default: config.yaml in truss directory)",
419+
)
411420
@click.option(
412421
"--remote",
413422
type=str,
@@ -532,6 +541,7 @@ def run_python(script, target_directory):
532541
@common.common_options()
533542
def push(
534543
target_directory: str,
544+
config: Optional[str],
535545
remote: str,
536546
model_name: str,
537547
publish: bool = False,
@@ -555,7 +565,7 @@ def push(
555565
TARGET_DIRECTORY: A Truss directory. If none, use current directory.
556566
557567
"""
558-
tr = _get_truss_from_directory(target_directory=target_directory)
568+
tr = _get_truss_from_directory(target_directory=target_directory, config=config)
559569

560570
if tr.spec.config.resources.instance_type:
561571
console.print(
@@ -777,6 +787,12 @@ def model_logs(
777787

778788
@truss_cli.command()
779789
@click.argument("target_directory", required=False, default=os.getcwd())
790+
@click.option(
791+
"--config",
792+
type=click.Path(exists=True),
793+
required=False,
794+
help="Path to a custom config file (default: config.yaml in truss directory)",
795+
)
780796
@click.option(
781797
"--remote",
782798
type=str,
@@ -792,7 +808,10 @@ def model_logs(
792808
)
793809
@common.common_options()
794810
def watch(
795-
target_directory: str, remote: str, provided_team_name: Optional[str] = None
811+
target_directory: str,
812+
config: Optional[str],
813+
remote: str,
814+
provided_team_name: Optional[str] = None,
796815
) -> None:
797816
"""
798817
Seamless remote development with truss
@@ -805,7 +824,7 @@ def watch(
805824

806825
remote_provider = cast(BasetenRemote, RemoteFactory.create(remote=remote))
807826

808-
tr = _get_truss_from_directory(target_directory=target_directory)
827+
tr = _get_truss_from_directory(target_directory=target_directory, config=config)
809828
model_name = tr.spec.config.model_name
810829
if not model_name:
811830
console.print(
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
"""Tests for custom config path functionality.
2+
3+
This module tests the --config option that allows specifying a custom config file
4+
instead of the default config.yaml in the truss directory.
5+
"""
6+
7+
import shutil
8+
import tempfile
9+
from pathlib import Path
10+
11+
import pytest
12+
import yaml
13+
14+
from truss.base.truss_spec import TrussSpec
15+
from truss.truss_handle.build import load
16+
from truss.truss_handle.truss_handle import TrussHandle
17+
18+
19+
@pytest.fixture
20+
def truss_dir_with_multiple_configs(test_data_path: Path):
21+
"""Create a truss directory with multiple config files for testing."""
22+
source_dir = test_data_path / "test_basic_truss"
23+
with tempfile.TemporaryDirectory() as tmp_dir:
24+
tmp_path = Path(tmp_dir)
25+
truss_path = tmp_path / "test_truss"
26+
shutil.copytree(source_dir, truss_path)
27+
28+
# Read the original config
29+
original_config_path = truss_path / "config.yaml"
30+
with open(original_config_path) as f:
31+
original_config = yaml.safe_load(f)
32+
33+
# Create an alternate config with different model_name
34+
alt_config = original_config.copy()
35+
alt_config["model_name"] = "alternate-model"
36+
alt_config_path = truss_path / "config.dev.yaml"
37+
with open(alt_config_path, "w") as f:
38+
yaml.dump(alt_config, f)
39+
40+
# Create another alternate config outside the truss directory
41+
external_config = original_config.copy()
42+
external_config["model_name"] = "external-model"
43+
external_config_path = tmp_path / "external_config.yaml"
44+
with open(external_config_path, "w") as f:
45+
yaml.dump(external_config, f)
46+
47+
yield {
48+
"truss_dir": truss_path,
49+
"default_config": original_config_path,
50+
"alt_config": alt_config_path,
51+
"external_config": external_config_path,
52+
"original_model_name": original_config.get("model_name"),
53+
}
54+
55+
56+
class TestTrussSpecCustomConfigPath:
57+
"""Tests for TrussSpec with custom config path."""
58+
59+
def test_default_config_path_backwards_compatibility(
60+
self, truss_dir_with_multiple_configs
61+
):
62+
"""Test that TrussSpec without config_path uses default config.yaml."""
63+
data = truss_dir_with_multiple_configs
64+
spec = TrussSpec(data["truss_dir"])
65+
66+
# Should load from default config.yaml
67+
assert spec.config_path == data["default_config"]
68+
assert spec.config.model_name == data["original_model_name"]
69+
70+
def test_custom_config_path_in_truss_dir(self, truss_dir_with_multiple_configs):
71+
"""Test that TrussSpec with config_path loads from that file."""
72+
data = truss_dir_with_multiple_configs
73+
spec = TrussSpec(data["truss_dir"], config_path=data["alt_config"])
74+
75+
# Should load from alternate config
76+
assert spec.config_path == data["alt_config"]
77+
assert spec.config.model_name == "alternate-model"
78+
79+
def test_custom_config_path_external(self, truss_dir_with_multiple_configs):
80+
"""Test that TrussSpec with external config_path works."""
81+
data = truss_dir_with_multiple_configs
82+
spec = TrussSpec(data["truss_dir"], config_path=data["external_config"])
83+
84+
# Should load from external config
85+
assert spec.config_path == data["external_config"]
86+
assert spec.config.model_name == "external-model"
87+
88+
89+
class TestTrussHandleCustomConfigPath:
90+
"""Tests for TrussHandle with custom config path."""
91+
92+
def test_default_config_path_backwards_compatibility(
93+
self, truss_dir_with_multiple_configs
94+
):
95+
"""Test that TrussHandle without config_path uses default config.yaml."""
96+
data = truss_dir_with_multiple_configs
97+
handle = TrussHandle(data["truss_dir"])
98+
99+
# Should load from default config.yaml
100+
assert handle.spec.config_path == data["default_config"]
101+
assert handle.spec.config.model_name == data["original_model_name"]
102+
103+
def test_custom_config_path(self, truss_dir_with_multiple_configs):
104+
"""Test that TrussHandle with config_path loads from that file."""
105+
data = truss_dir_with_multiple_configs
106+
handle = TrussHandle(data["truss_dir"], config_path=data["alt_config"])
107+
108+
# Should load from alternate config
109+
assert handle.spec.config_path == data["alt_config"]
110+
assert handle.spec.config.model_name == "alternate-model"
111+
112+
def test_update_config_preserves_custom_config_path(
113+
self, truss_dir_with_multiple_configs
114+
):
115+
"""Test that _update_config preserves the custom config path after reload."""
116+
data = truss_dir_with_multiple_configs
117+
handle = TrussHandle(data["truss_dir"], config_path=data["alt_config"])
118+
119+
# Update the config (this triggers a reload)
120+
handle._update_config(description="Test description")
121+
122+
# After reload, should still use the custom config path
123+
assert handle.spec.config_path == data["alt_config"]
124+
assert handle.spec.config.model_name == "alternate-model"
125+
assert handle.spec.config.description == "Test description"
126+
127+
128+
class TestLoadFunctionCustomConfigPath:
129+
"""Tests for the load() function with custom config path."""
130+
131+
def test_load_default_config_path_backwards_compatibility(
132+
self, truss_dir_with_multiple_configs
133+
):
134+
"""Test that load() without config_path uses default config.yaml."""
135+
data = truss_dir_with_multiple_configs
136+
handle = load(data["truss_dir"])
137+
138+
# Should load from default config.yaml
139+
assert handle.spec.config_path == data["default_config"]
140+
assert handle.spec.config.model_name == data["original_model_name"]
141+
142+
def test_load_with_custom_config_path(self, truss_dir_with_multiple_configs):
143+
"""Test that load() with config_path loads from that file."""
144+
data = truss_dir_with_multiple_configs
145+
handle = load(data["truss_dir"], config_path=data["alt_config"])
146+
147+
# Should load from alternate config
148+
assert handle.spec.config_path == data["alt_config"]
149+
assert handle.spec.config.model_name == "alternate-model"
150+
151+
def test_load_with_external_config_path(self, truss_dir_with_multiple_configs):
152+
"""Test that load() with external config_path works."""
153+
data = truss_dir_with_multiple_configs
154+
handle = load(data["truss_dir"], config_path=data["external_config"])
155+
156+
# Should load from external config
157+
assert handle.spec.config_path == data["external_config"]
158+
assert handle.spec.config.model_name == "external-model"

truss/truss_handle/build.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,16 +98,20 @@ def init_directory(
9898
return target_directory_path
9999

100100

101-
def load(truss_directory: Union[str, Path]) -> TrussHandle:
101+
def load(
102+
truss_directory: Union[str, Path], config_path: Optional[Path] = None
103+
) -> TrussHandle:
102104
"""Get a handle to a Truss. A Truss is a build context designed to be built
103105
as a container locally or uploaded into a model serving environment.
104106
105107
Args:
106108
truss_directory (str | Path): The local directory of an existing Truss
109+
config_path (Path | None): Optional path to a config file. If not provided,
110+
defaults to config.yaml in the truss directory.
107111
Returns:
108112
TrussHandle
109113
"""
110-
return TrussHandle(Path(truss_directory))
114+
return TrussHandle(Path(truss_directory), config_path=config_path)
111115

112116

113117
def cleanup() -> None:

truss/truss_handle/truss_handle.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,9 +120,12 @@ def __init__(self, base_url):
120120

121121

122122
class TrussHandle:
123-
def __init__(self, truss_dir: Path, validate: bool = True) -> None:
123+
def __init__(
124+
self, truss_dir: Path, validate: bool = True, config_path: Optional[Path] = None
125+
) -> None:
124126
self._truss_dir = truss_dir
125-
self._spec = TrussSpec(self._truss_dir)
127+
self._config_path = config_path
128+
self._spec = TrussSpec(self._truss_dir, config_path=config_path)
126129
self._hash_for_mod_time: Optional[Tuple[float, str]] = None
127130
if validate:
128131
self.validate()
@@ -939,7 +942,9 @@ def _build_image(
939942
def _update_config(self, **fields_to_update):
940943
config = self._spec.config.model_copy(update=fields_to_update)
941944
config.write_to_yaml_file(self._spec.config_path)
942-
self._spec = TrussSpec(self._truss_dir) # Reload.
945+
self._spec = TrussSpec(
946+
self._truss_dir, config_path=self._config_path
947+
) # Reload.
943948

944949
def _try_patch(self) -> Optional["Container"]:
945950
if not self.is_control_truss:

0 commit comments

Comments
 (0)