Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def __init__(
name=name,
interface=NativeInterface({k: (v, None) for k, v in inputs.items()} if inputs else {}, outputs or {}),
task_type=self._TASK_TYPE,
image=None,
**kwargs,
)
self.output_dataframe_type = output_dataframe_type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,8 @@ def _execute_query():
schema=schema,
warehouse=warehouse,
query_id=query_id,
has_output=task_template.interface.outputs is not None and len(task_template.interface.outputs.variables) > 0,
has_output=task_template.interface.outputs is not None
and len(task_template.interface.outputs.variables) > 0,
connection_kwargs=connection_kwargs,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def __init__(
outputs or {},
),
task_type=self._TASK_TYPE,
image=None,
**kwargs,
)
self.output_dataframe_type = output_dataframe_type
Expand Down
36 changes: 10 additions & 26 deletions plugins/wandb/src/flyteplugins/wandb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,11 +134,11 @@ async def my_task():
import os
from typing import Optional

import wandb

import flyte
from flyte.io import Dir

import wandb

from ._context import (
get_wandb_context,
get_wandb_sweep_context,
Expand Down Expand Up @@ -317,9 +317,7 @@ def download_wandb_run_dir(
f"Error: {e}"
) from e
except Exception as e:
raise RuntimeError(
f"Unexpected error fetching wandb run '{run_path}': {e}"
) from e
raise RuntimeError(f"Unexpected error fetching wandb run '{run_path}': {e}") from e

try:
for file in api_run.files():
Expand All @@ -334,13 +332,9 @@ def download_wandb_run_dir(
with open(os.path.join(path, "summary.json"), "w") as f:
json.dump(summary_data, f, indent=2, default=str)
except (IOError, OSError) as e:
raise RuntimeError(
f"Failed to write summary.json for run '{run_id}': {e}"
) from e
raise RuntimeError(f"Failed to write summary.json for run '{run_id}': {e}") from e
except Exception as e:
raise RuntimeError(
f"Failed to export summary data for run '{run_id}': {e}"
) from e
raise RuntimeError(f"Failed to export summary data for run '{run_id}': {e}") from e

# Export metrics history to JSON
if include_history:
Expand All @@ -350,13 +344,9 @@ def download_wandb_run_dir(
with open(os.path.join(path, "metrics_history.json"), "w") as f:
json.dump(history, f, indent=2, default=str)
except (IOError, OSError) as e:
raise RuntimeError(
f"Failed to write metrics_history.json for run '{run_id}': {e}"
) from e
raise RuntimeError(f"Failed to write metrics_history.json for run '{run_id}': {e}") from e
except Exception as e:
raise RuntimeError(
f"Failed to export history data for run '{run_id}': {e}"
) from e
raise RuntimeError(f"Failed to export history data for run '{run_id}': {e}") from e

return path

Expand Down Expand Up @@ -403,9 +393,7 @@ def download_wandb_sweep_dirs(
project = wandb_ctx.project if wandb_ctx else None

if not entity or not project:
raise RuntimeError(
"Cannot query sweep without entity and project. Set them via wandb_config()."
)
raise RuntimeError("Cannot query sweep without entity and project. Set them via wandb_config().")

# Query sweep runs via wandb API
try:
Expand All @@ -425,9 +413,7 @@ def download_wandb_sweep_dirs(
f"Error: {e}"
) from e
except Exception as e:
raise RuntimeError(
f"Unexpected error fetching wandb sweep '{entity}/{project}/{sweep_id}': {e}"
) from e
raise RuntimeError(f"Unexpected error fetching wandb sweep '{entity}/{project}/{sweep_id}': {e}") from e

# Download each run's data
downloaded_paths = []
Expand All @@ -436,9 +422,7 @@ def download_wandb_sweep_dirs(
for run_id in run_ids:
path = f"{base_path or '/tmp/wandb_runs'}/{run_id}"
try:
download_wandb_run_dir(
run_id=run_id, path=path, include_history=include_history
)
download_wandb_run_dir(run_id=run_id, path=path, include_history=include_history)
downloaded_paths.append(path)
except Exception as e:
# Log failure but continue with other runs
Expand Down
4 changes: 1 addition & 3 deletions plugins/wandb/src/flyteplugins/wandb/_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,9 +333,7 @@ def get_wandb_sweep_context() -> Optional[_WandBSweepConfig]:
if ctx is None or not ctx.custom_context:
return None

has_wandb_sweep_keys = any(
k.startswith("wandb_sweep_") for k in ctx.custom_context.keys()
)
has_wandb_sweep_keys = any(k.startswith("wandb_sweep_") for k in ctx.custom_context.keys())
if not has_wandb_sweep_keys:
return None

Expand Down
34 changes: 9 additions & 25 deletions plugins/wandb/src/flyteplugins/wandb/_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
from inspect import iscoroutinefunction
from typing import Any, Callable, Optional, TypeVar, cast

import wandb

import flyte
from flyte._task import AsyncFunctionTaskTemplate

import wandb

from ._context import RunMode, get_wandb_context, get_wandb_sweep_context
from ._link import Wandb, WandbSweep

Expand Down Expand Up @@ -124,9 +124,7 @@ def _wandb_run(
if not is_primary:
shared_config["x_update_finish_state"] = False

init_kwargs["settings"] = wandb.Settings(
**{**existing_settings, **shared_config}
)
init_kwargs["settings"] = wandb.Settings(**{**existing_settings, **shared_config})

# Initialize wandb
run = wandb.init(**init_kwargs)
Expand All @@ -141,9 +139,7 @@ def _wandb_run(
yield run
finally:
# Determine if this is a primary run
is_primary_run = run_mode == "new" or (
run_mode == "auto" and saved_run_id is None
)
is_primary_run = run_mode == "new" or (run_mode == "auto" and saved_run_id is None)

if run:
# Different cleanup logic for local vs remote mode
Expand Down Expand Up @@ -293,9 +289,7 @@ def sync_wrapper(*args, **wrapper_kwargs):


@contextmanager
def _create_sweep(
project: Optional[str] = None, entity: Optional[str] = None, **decorator_kwargs
):
def _create_sweep(project: Optional[str] = None, entity: Optional[str] = None, **decorator_kwargs):
"""Context manager for wandb sweep creation."""
ctx = flyte.ctx()

Expand All @@ -317,14 +311,8 @@ def _create_sweep(
wandb_config = get_wandb_context()

# Priority: decorator kwargs > sweep config > wandb config
project = (
project
or sweep_config.project
or (wandb_config.project if wandb_config else None)
)
entity = (
entity or sweep_config.entity or (wandb_config.entity if wandb_config else None)
)
project = project or sweep_config.project or (wandb_config.project if wandb_config else None)
entity = entity or sweep_config.entity or (wandb_config.entity if wandb_config else None)
prior_runs = sweep_config.prior_runs or []

# Get sweep config dict
Expand Down Expand Up @@ -401,19 +389,15 @@ def decorator(func: F) -> F:
original_execute = func.execute

async def wrapped_execute(*args, **exec_kwargs):
with _create_sweep(
project=project, entity=entity, **kwargs
) as sweep_id:
with _create_sweep(project=project, entity=entity, **kwargs) as sweep_id:
result = await original_execute(*args, **exec_kwargs)

# After sweep finishes, optionally download logs
should_download = download_logs
if should_download is None:
# Check context config
sweep_config = get_wandb_sweep_context()
should_download = (
sweep_config.download_logs if sweep_config else False
)
should_download = sweep_config.download_logs if sweep_config else False

if should_download and sweep_id:
from . import download_wandb_sweep_logs
Expand Down
8 changes: 2 additions & 6 deletions plugins/wandb/tests/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,12 +410,8 @@ def test_wandb_sweep_config_to_dict(self):
result = config.to_dict()

assert result["wandb_sweep_method"] == "random"
assert result["wandb_sweep_metric"] == json.dumps(
{"name": "loss", "goal": "minimize"}
)
assert result["wandb_sweep_parameters"] == json.dumps(
{"lr": {"min": 0.001, "max": 0.1}}
)
assert result["wandb_sweep_metric"] == json.dumps({"name": "loss", "goal": "minimize"})
assert result["wandb_sweep_parameters"] == json.dumps({"lr": {"min": 0.001, "max": 0.1}})

def test_wandb_sweep_config_from_dict(self):
"""Test creating WandBSweepConfig from dict."""
Expand Down
Loading
Loading