Skip to content

Commit

Permalink
show a more helpful error message when a restart is required for %view
Browse files Browse the repository at this point in the history
  • Loading branch information
seeM committed Nov 27, 2024
1 parent f53a9d1 commit e7d8d84
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
)

import comm
from IPython.core.error import UsageError

from .access_keys import decode_access_key
from .data_explorer_comm import (
Expand Down Expand Up @@ -96,7 +97,14 @@
TextSearchType,
)
from .positron_comm import CommMessage, PositronComm
from .third_party import np_, pd_, pl_
from .third_party import (
RestartRequiredError,
import_pandas,
import_polars,
np_,
pd_,
pl_,
)
from .utils import BackgroundJobQueue, guid

if TYPE_CHECKING:
Expand Down Expand Up @@ -312,6 +320,7 @@ def _match_text_search(params: FilterTextSearch):

def matches(x):
return term in x.lower()

else:

def matches(x):
Expand Down Expand Up @@ -2581,11 +2590,31 @@ class PyArrowView(DataExplorerTableView):


def _is_pandas(table):
return pd_ is not None and isinstance(table, (pd_.DataFrame, pd_.Series))
pandas = import_pandas()
if pandas is not None and isinstance(table, (pandas.DataFrame, pandas.Series)):
# If pandas was installed after the kernel was started, pd_ will still be None.
# Raise an error to inform the user to restart the kernel.
if pd_ is None:
raise RestartRequiredError(
"Pandas was installed after the session started. Please restart the session to "
+ "view the table in the Data Explorer."
)
return True
return False


def _is_polars(table):
return pl_ is not None and isinstance(table, (pl_.DataFrame, pl_.Series))
polars = import_polars()
if polars is not None and isinstance(table, (polars.DataFrame, polars.Series)):
# If polars was installed after the kernel was started, pl_ will still be None.
# Raise an error to inform the user to restart the kernel.
if pl_ is None:
raise RestartRequiredError(
"Polars was installed after the session started. Please restart the session to "
+ "view the table."
)
return True
return False


def _get_table_view(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from .patch.holoviews import set_holoviews_extension
from .plots import PlotsService
from .session_mode import SessionMode
from .third_party import RestartRequiredError
from .ui import UiService
from .utils import BackgroundJobQueue, JsonRecord, get_qualname
from .variables import VariablesService
Expand Down Expand Up @@ -168,6 +169,8 @@ def view(self, line: str) -> None:
)
except TypeError:
raise UsageError(f"cannot view object of type '{get_qualname(obj)}'")
except RestartRequiredError as error:
raise UsageError(*error.args)

@magic_arguments.magic_arguments()
@magic_arguments.argument(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,15 @@
from datetime import datetime
from decimal import Decimal
from io import StringIO
from typing import Any, Dict, List, Optional, Type, cast
from typing import Any, Dict, List, Optional, Type, Union, cast

import numpy as np
import pandas as pd
import polars as pl
import pytest
import pytz

from .. import data_explorer
from .._vendor.pydantic import BaseModel
from ..access_keys import encode_access_key
from ..data_explorer import (
Expand Down Expand Up @@ -49,6 +50,7 @@
RowFilterTypeSupportStatus,
SupportStatus,
)
from ..third_party import RestartRequiredError
from ..utils import guid
from .conftest import DummyComm, PositronShell
from .test_variables import BIG_ARRAY_LENGTH
Expand Down Expand Up @@ -295,6 +297,30 @@ def test_register_table_with_variable_path(de_service: DataExplorerService):
assert table_view.state.name == title


@pytest.mark.parametrize(
("table", "import_name", "title"),
[(pd.DataFrame({}), "pd_", "Pandas"), (pl.DataFrame({}), "pl_", "Polars")],
)
def test_register_table_after_installing_dependency(
table: Union[pd.DataFrame, pl.DataFrame],
import_name: str,
title: str,
de_service: DataExplorerService,
monkeypatch,
):
# Patch the module (e.g. third_party.pd_) to None. Since these packages are really is installed
# during tests, this simulates the case where the user installs the package after the kernel
# starts, therefore the third_party attribute (e.g. pd_) is None but the corresponding import
# function (third_party.import_pandas()) returns the module.
# See https://github.com/posit-dev/positron/issues/5535.
monkeypatch.setattr(data_explorer, import_name, None)

with pytest.raises(
RestartRequiredError, match=f"^{title} was installed after the session started."
):
de_service.register_table(table, "test_table")


def test_shutdown(de_service: DataExplorerService):
df = pd.DataFrame({"a": [1, 2, 3, 4, 5]})
de_service.register_table(df, "t1", comm_id=guid())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,47 +9,53 @@
# checking.


def _get_numpy():
class RestartRequiredError(Exception):
"""Raised when a restart is required to load a third party package."""

pass


def import_numpy():
try:
import numpy
except ImportError:
numpy = None
return numpy


def _get_pandas():
def import_pandas():
try:
import pandas
except ImportError:
pandas = None
return pandas


def _get_polars():
def import_polars():
try:
import polars
except ImportError:
polars = None
return polars


def _get_torch():
def import_torch():
try:
import torch # type: ignore [reportMissingImports] for 3.12
except ImportError:
torch = None
return torch


def _get_pyarrow():
def import_pyarrow():
try:
import pyarrow # type: ignore [reportMissingImports] for 3.12
except ImportError:
pyarrow = None
return pyarrow


def _get_sqlalchemy():
def import_sqlalchemy():
try:
import sqlalchemy
except ImportError:
Expand All @@ -59,11 +65,12 @@ def _get_sqlalchemy():

# Currently, pyright only correctly infers the types below as `Optional` if we set their values
# using function calls.
np_ = _get_numpy()
pa_ = _get_pyarrow()
pd_ = _get_pandas()
pl_ = _get_polars()
torch_ = _get_torch()
sqlalchemy_ = _get_sqlalchemy()
np_ = import_numpy()
pa_ = import_pyarrow()
pd_ = import_pandas()
pl_ = import_polars()
torch_ = import_torch()
sqlalchemy_ = import_sqlalchemy()


__all__ = ["np_", "pa_", "pd_", "pl_", "torch_", "sqlalchemy_"]

0 comments on commit e7d8d84

Please sign in to comment.