diff --git a/api/schemas/schemas.py b/api/schemas/schemas.py index 69391596ba..4a092fcad2 100644 --- a/api/schemas/schemas.py +++ b/api/schemas/schemas.py @@ -161,24 +161,25 @@ class _TimedSchema(BaseModel): startTimestamp: int = Field(default=None) endTimestamp: int = Field(default=None) - @model_validator(mode='before') - def transform_time(self, values): + @model_validator(mode="before") + @classmethod + def transform_time(cls, values): if values.get("startTimestamp") is None and values.get("startDate") is not None: values["startTimestamp"] = values["startDate"] if values.get("endTimestamp") is None and values.get("endDate") is not None: values["endTimestamp"] = values["endDate"] return values - @model_validator(mode='after') - def __time_validator(self, values): - if values.startTimestamp is not None: - assert 0 <= values.startTimestamp, "startTimestamp must be greater or equal to 0" - if values.endTimestamp is not None: - assert 0 <= values.endTimestamp, "endTimestamp must be greater or equal to 0" - if values.startTimestamp is not None and values.endTimestamp is not None: - assert values.startTimestamp <= values.endTimestamp, \ + @model_validator(mode="after") + def __time_validator(self): + if self.startTimestamp is not None: + assert 0 <= self.startTimestamp, "startTimestamp must be greater or equal to 0" + if self.endTimestamp is not None: + assert 0 <= self.endTimestamp, "endTimestamp must be greater or equal to 0" + if self.startTimestamp is not None and self.endTimestamp is not None: + assert self.startTimestamp <= self.endTimestamp, \ "endTimestamp must be greater or equal to startTimestamp" - return values + return self class NotificationsViewSchema(_TimedSchema): @@ -435,13 +436,13 @@ class AlertSchema(BaseModel): series_id: Optional[int] = Field(default=None, doc_hidden=True) @model_validator(mode="after") - def transform_alert(self, values): - values.series_id = None - if isinstance(values.query.left, int): - values.series_id = values.query.left - values.query.left = AlertColumn.CUSTOM + def transform_alert(self): + self.series_id = None + if isinstance(self.query.left, int): + self.series_id = self.query.left + self.query.left = AlertColumn.CUSTOM - return values + return self class SourcemapUploadPayloadSchema(BaseModel): @@ -625,33 +626,31 @@ class SessionSearchEventSchema2(BaseModel): _single_to_list_values = field_validator('value', mode='before')(single_to_list) _transform = model_validator(mode='before')(transform_old_filter_type) - @model_validator(mode='after') - def event_validator(self, values): - if isinstance(values.type, PerformanceEventType): - if values.type == PerformanceEventType.FETCH_FAILED: - return values - # assert values.get("source") is not None, "source should not be null for PerformanceEventType" - # assert isinstance(values["source"], list) and len(values["source"]) > 0, \ - # "source should not be empty for PerformanceEventType" - assert values.sourceOperator is not None, \ + @model_validator(mode="after") + def event_validator(self): + if isinstance(self.type, PerformanceEventType): + if self.type == PerformanceEventType.FETCH_FAILED: + return self + + assert self.sourceOperator is not None, \ "sourceOperator should not be null for PerformanceEventType" - assert "source" in values, f"source is required for {values.type}" - assert isinstance(values.source, list), f"source of type list is required for {values.type}" - for c in values["source"]: - assert isinstance(c, int), f"source value should be of type int for {values.type}" - elif values.type == EventType.ERROR and values.source is None: - values.source = [ErrorSource.JS_EXCEPTION] - elif values.type == EventType.REQUEST_DETAILS: - assert isinstance(values.filters, List) and len(values.filters) > 0, \ + assert self.source is not None, f"source is required for {self.type}" + assert isinstance(self.source, list), f"source of type list is required for {self.type}" + for c in self.source: + assert isinstance(c, int), f"source value should be of type int for {self.type}" + elif self.type == EventType.ERROR and self.source is None: + self.source = [ErrorSource.JS_EXCEPTION] + elif self.type == EventType.REQUEST_DETAILS: + assert isinstance(self.filters, List) and len(self.filters) > 0, \ f"filters should be defined for {EventType.REQUEST_DETAILS}" - elif values.type == EventType.GRAPHQL: - assert isinstance(values.filters, List) and len(values.filters) > 0, \ + elif self.type == EventType.GRAPHQL: + assert isinstance(self.filters, List) and len(self.filters) > 0, \ f"filters should be defined for {EventType.GRAPHQL}" - if isinstance(values.operator, ClickEventExtraOperator): - assert values.type == EventType.CLICK, \ - f"operator:{values.operator} is only available for event-type: {EventType.CLICK}" - return values + if isinstance(self.operator, ClickEventExtraOperator): + assert self.type == EventType.CLICK, \ + f"operator:{self.operator} is only available for event-type: {EventType.CLICK}" + return self class SessionSearchFilterSchema(BaseModel): @@ -665,8 +664,9 @@ class SessionSearchFilterSchema(BaseModel): _transform = model_validator(mode='before')(transform_old_filter_type) _single_to_list_values = field_validator('value', mode='before')(single_to_list) - @model_validator(mode='before') - def _transform_data(self, values): + @model_validator(mode="before") + @classmethod + def _transform_data(cls, values): if values.get("source") is not None: if isinstance(values["source"], list): if len(values["source"]) == 0: @@ -677,38 +677,38 @@ def _transform_data(self, values): raise ValueError(f"Unsupported multi-values source") return values - @model_validator(mode='after') - def filter_validator(self, values): - if values.type == FilterType.METADATA: - assert values.source is not None and len(values.source) > 0, \ + @model_validator(mode="after") + def filter_validator(self): + if self.type == FilterType.METADATA: + assert self.source is not None and len(self.source) > 0, \ "must specify a valid 'source' for metadata filter" - elif values.type == FilterType.ISSUE: - for i, v in enumerate(values.value): + elif self.type == FilterType.ISSUE: + for i, v in enumerate(self.value): if IssueType.has_value(v): - values.value[i] = IssueType(v) + self.value[i] = IssueType(v) else: - raise ValueError(f"value should be of type IssueType for {values.type} filter") - elif values.type == FilterType.PLATFORM: - for i, v in enumerate(values.value): + raise ValueError(f"value should be of type IssueType for {self.type} filter") + elif self.type == FilterType.PLATFORM: + for i, v in enumerate(self.value): if PlatformType.has_value(v): - values.value[i] = PlatformType(v) + self.value[i] = PlatformType(v) else: - raise ValueError(f"value should be of type PlatformType for {values.type} filter") - elif values.type == FilterType.EVENTS_COUNT: - if MathOperator.has_value(values.operator): - values.operator = MathOperator(values.operator) + raise ValueError(f"value should be of type PlatformType for {self.type} filter") + elif self.type == FilterType.EVENTS_COUNT: + if MathOperator.has_value(self.operator): + self.operator = MathOperator(self.operator) else: - raise ValueError(f"operator should be of type MathOperator for {values.type} filter") + raise ValueError(f"operator should be of type MathOperator for {self.type} filter") - for v in values.value: - assert isinstance(v, int), f"value should be of type int for {values.type} filter" + for v in self.value: + assert isinstance(v, int), f"value should be of type int for {self.type} filter" else: - if SearchEventOperator.has_value(values.operator): - values.operator = SearchEventOperator(values.operator) + if SearchEventOperator.has_value(self.operator): + self.operator = SearchEventOperator(self.operator) else: - raise ValueError(f"operator should be of type SearchEventOperator for {values.type} filter") + raise ValueError(f"operator should be of type SearchEventOperator for {self.type} filter") - return values + return self class _PaginatedSchema(BaseModel): @@ -744,7 +744,8 @@ class SessionsSearchPayloadSchema(_TimedSchema, _PaginatedSchema): bookmarked: bool = Field(default=False) @model_validator(mode="before") - def transform_order(self, values): + @classmethod + def transform_order(cls, values): if values.get("sort") is None: values["sort"] = "startTs" @@ -755,7 +756,8 @@ def transform_order(self, values): return values @model_validator(mode="before") - def add_missing_attributes(self, values): + @classmethod + def add_missing_attributes(cls, values): # in case isEvent is wrong: for f in values.get("filters") or []: if EventType.has_value(f["type"]) and not f.get("isEvent"): @@ -770,7 +772,8 @@ def add_missing_attributes(self, values): return values @model_validator(mode="before") - def remove_wrong_filter_values(self, values): + @classmethod + def remove_wrong_filter_values(cls, values): for f in values.get("filters", []): vals = [] for v in f.get("value", []): @@ -780,17 +783,17 @@ def remove_wrong_filter_values(self, values): return values @model_validator(mode="after") - def split_filters_events(self, values): + def split_filters_events(self): n_filters = [] n_events = [] - for v in values.filters: + for v in self.filters: if v.is_event: n_events.append(v) else: n_filters.append(v) - values.events = n_events - values.filters = n_filters - return values + self.events = n_events + self.filters = n_filters + return self @field_validator("filters", mode="after") @classmethod @@ -854,7 +857,8 @@ class PathAnalysisSubFilterSchema(BaseModel): _remove_duplicate_values = field_validator('value', mode='before')(remove_duplicate_values) @model_validator(mode="before") - def __force_is_event(self, values): + @classmethod + def __force_is_event(cls, values): values["isEvent"] = True return values @@ -1048,7 +1052,8 @@ class CardSessionsSchema(_TimedSchema, _PaginatedSchema): (force_is_event(events_enum=[EventType, PerformanceEventType])) @model_validator(mode="before") - def remove_wrong_filter_values(self, values): + @classmethod + def remove_wrong_filter_values(cls, values): for f in values.get("filters", []): vals = [] for v in f.get("value", []): @@ -1058,7 +1063,8 @@ def remove_wrong_filter_values(self, values): return values @model_validator(mode="before") - def __enforce_default(self, values): + @classmethod + def __enforce_default(cls, values): if values.get("startTimestamp") is None: values["startTimestamp"] = TimeUTC.now(-7) @@ -1068,45 +1074,44 @@ def __enforce_default(self, values): return values @model_validator(mode="after") - def __enforce_default_after(self, values): - for s in values.series: + def __enforce_default_after(self): + for s in self.series: if s.filter is not None: - s.filter.limit = values.limit - s.filter.page = values.page - s.filter.startTimestamp = values.startTimestamp - s.filter.endTimestamp = values.endTimestamp + s.filter.limit = self.limit + s.filter.page = self.page + s.filter.startTimestamp = self.startTimestamp + s.filter.endTimestamp = self.endTimestamp - return values + return self @model_validator(mode="after") - def __merge_out_filters_with_series(self, values): - if len(values.filters) > 0: - for f in values.filters: - for s in values.series: - found = False - - if f.is_event: - sub = s.filter.events - else: - sub = s.filter.filters - - for e in sub: - if f.type == e.type and f.operator == e.operator: - found = True - if f.is_event: - # If extra event: append value - for v in f.value: - if v not in e.value: - e.value.append(v) - else: - # If extra filter: override value - e.value = f.value - if not found: - sub.append(f) - - values.filters = [] + def __merge_out_filters_with_series(self): + for f in self.filters: + for s in self.series: + found = False - return values + if f.is_event: + sub = s.filter.events + else: + sub = s.filter.filters + + for e in sub: + if f.type == e.type and f.operator == e.operator: + found = True + if f.is_event: + # If extra event: append value + for v in f.value: + if v not in e.value: + e.value.append(v) + else: + # If extra filter: override value + e.value = f.value + if not found: + sub.append(f) + + self.filters = [] + + return self class CardConfigSchema(BaseModel): @@ -1141,14 +1146,15 @@ class CardTimeSeries(__CardSchema): view_type: MetricTimeseriesViewType @model_validator(mode="before") - def __enforce_default(self, values): + @classmethod + def __enforce_default(cls, values): values["metricValue"] = [] return values - @model_validator(mode="after") - def __transform(self, values): - values.metric_of = MetricOfTimeseries(values.metric_of) - return values + # @model_validator(mode="after") + # def __transform(self): + # self.metric_of = MetricOfTimeseries(self.metric_of) + # return self class CardTable(__CardSchema): @@ -1158,24 +1164,25 @@ class CardTable(__CardSchema): metric_format: MetricExtendedFormatType = Field(default=MetricExtendedFormatType.SESSION_COUNT) @model_validator(mode="before") - def __enforce_default(self, values): + @classmethod + def __enforce_default(cls, values): if values.get("metricOf") is not None and values.get("metricOf") != MetricOfTable.ISSUES: values["metricValue"] = [] return values - @model_validator(mode="after") - def __transform(self, values): - values.metric_of = MetricOfTable(values.metric_of) - return values + # @model_validator(mode="after") + # def __transform(self): + # self.metric_of = MetricOfTable(self.metric_of) + # return self @model_validator(mode="after") - def __validator(self, values): - if values.metric_of not in (MetricOfTable.ISSUES, MetricOfTable.USER_BROWSER, - MetricOfTable.USER_DEVICE, MetricOfTable.USER_COUNTRY, - MetricOfTable.VISITED_URL): - assert values.metric_format == MetricExtendedFormatType.SESSION_COUNT, \ + def __validator(self): + if self.metric_of not in (MetricOfTable.ISSUES, MetricOfTable.USER_BROWSER, + MetricOfTable.USER_DEVICE, MetricOfTable.USER_COUNTRY, + MetricOfTable.VISITED_URL): + assert self.metric_format == MetricExtendedFormatType.SESSION_COUNT, \ f'metricFormat:{MetricExtendedFormatType.USER_COUNT.value} is not supported for this metricOf' - return values + return self class CardFunnel(__CardSchema): @@ -1184,7 +1191,8 @@ class CardFunnel(__CardSchema): view_type: MetricOtherViewType = Field(...) @model_validator(mode="before") - def __enforce_default(self, values): + @classmethod + def __enforce_default(cls, values): if values.get("metricOf") and not MetricOfFunnels.has_value(values["metricOf"]): values["metricOf"] = MetricOfFunnels.SESSION_COUNT values["viewType"] = MetricOtherViewType.OTHER_CHART @@ -1192,10 +1200,10 @@ def __enforce_default(self, values): values["series"] = [values["series"][0]] return values - @model_validator(mode="after") - def __transform(self, values): - values.metric_of = MetricOfTimeseries(values.metric_of) - return values + # @model_validator(mode="after") + # def __transform(self): + # self.metric_of = MetricOfFunnels(self.metric_of) + # return self class CardErrors(__CardSchema): @@ -1204,14 +1212,15 @@ class CardErrors(__CardSchema): view_type: MetricOtherViewType = Field(...) @model_validator(mode="before") - def __enforce_default(self, values): + @classmethod + def __enforce_default(cls, values): values["series"] = [] return values - @model_validator(mode="after") - def __transform(self, values): - values.metric_of = MetricOfErrors(values.metric_of) - return values + # @model_validator(mode="after") + # def __transform(self): + # self.metric_of = MetricOfErrors(self.metric_of) + # return self class CardPerformance(__CardSchema): @@ -1220,14 +1229,15 @@ class CardPerformance(__CardSchema): view_type: MetricOtherViewType = Field(...) @model_validator(mode="before") - def __enforce_default(self, values): + @classmethod + def __enforce_default(cls, values): values["series"] = [] return values - @model_validator(mode="after") - def __transform(self, values): - values.metric_of = MetricOfPerformance(values.metric_of) - return values + # @model_validator(mode="after") + # def __transform(self): + # self.metric_of = MetricOfPerformance(self.metric_of) + # return self class CardResources(__CardSchema): @@ -1236,14 +1246,15 @@ class CardResources(__CardSchema): view_type: MetricOtherViewType = Field(...) @model_validator(mode="before") - def __enforce_default(self, values): + @classmethod + def __enforce_default(cls, values): values["series"] = [] return values - @model_validator(mode="after") - def __transform(self, values): - values.metric_of = MetricOfResources(values.metric_of) - return values + # @model_validator(mode="after") + # def __transform(self): + # self.metric_of = MetricOfResources(self.metric_of) + # return self class CardWebVital(__CardSchema): @@ -1252,14 +1263,15 @@ class CardWebVital(__CardSchema): view_type: MetricOtherViewType = Field(...) @model_validator(mode="before") - def __enforce_default(self, values): + @classmethod + def __enforce_default(cls, values): values["series"] = [] return values - @model_validator(mode="after") - def __transform(self, values): - values.metric_of = MetricOfWebVitals(values.metric_of) - return values + # @model_validator(mode="after") + # def __transform(self): + # self.metric_of = MetricOfWebVitals(self.metric_of) + # return self class CardHeatMap(__CardSchema): @@ -1268,13 +1280,14 @@ class CardHeatMap(__CardSchema): view_type: MetricOtherViewType = Field(...) @model_validator(mode="before") - def __enforce_default(self, values): + @classmethod + def __enforce_default(cls, values): return values - @model_validator(mode="after") - def __transform(self, values): - values.metric_of = MetricOfHeatMap(values.metric_of) - return values + # @model_validator(mode="after") + # def __transform(self): + # self.metric_of = MetricOfHeatMap(self.metric_of) + # return self class MetricOfInsights(str, Enum): @@ -1287,17 +1300,18 @@ class CardInsights(__CardSchema): view_type: MetricOtherViewType = Field(...) @model_validator(mode="before") - def __enforce_default(self, values): + @classmethod + def __enforce_default(cls, values): values["view_type"] = MetricOtherViewType.LIST_CHART return values - @model_validator(mode="after") - def __transform(self, values): - values.metric_of = MetricOfInsights(values.metric_of) - return values + # @model_validator(mode="after") + # def __transform(self): + # self.metric_of = MetricOfInsights(self.metric_of) + # return self - @model_validator(mode='after') - def restrictions(self, values): + @model_validator(mode="after") + def restrictions(self): raise ValueError(f"metricType:{MetricType.INSIGHTS} not supported yet.") @@ -1307,7 +1321,8 @@ class CardPathAnalysisSeriesSchema(CardSeriesSchema): density: int = Field(default=4, ge=2, le=10) @model_validator(mode="before") - def __enforce_default(self, values): + @classmethod + def __enforce_default(cls, values): if values.get("filter") is None and values.get("startTimestamp") and values.get("endTimestamp"): values["filter"] = PathAnalysisSchema(startTimestamp=values["startTimestamp"], endTimestamp=values["endTimestamp"], @@ -1329,44 +1344,45 @@ class CardPathAnalysis(__CardSchema): series: List[CardPathAnalysisSeriesSchema] = Field(default=[]) @model_validator(mode="before") - def __enforce_default(self, values): + @classmethod + def __enforce_default(cls, values): values["viewType"] = MetricOtherViewType.OTHER_CHART.value if values.get("series") is not None and len(values["series"]) > 0: values["series"] = [values["series"][0]] return values @model_validator(mode="after") - def __clean_start_point_and_enforce_metric_value(self, values): + def __clean_start_point_and_enforce_metric_value(self): start_point = [] - for s in values.start_point: + for s in self.start_point: if len(s.value) == 0: continue start_point.append(s) - values.metric_value.append(s.type) + self.metric_value.append(s.type) - values.start_point = start_point - values.metric_value = remove_duplicate_values(values.metric_value) + self.start_point = start_point + self.metric_value = remove_duplicate_values(self.metric_value) - return values + return self - @model_validator(mode='after') - def __validator(self, values): + @model_validator(mode="after") + def __validator(self): s_e_values = {} exclude_values = {} - for f in values.start_point: + for f in self.start_point: s_e_values[f.type] = s_e_values.get(f.type, []) + f.value - for f in values.excludes: + for f in self.excludes: exclude_values[f.type] = exclude_values.get(f.type, []) + f.value assert len( - values.start_point) <= 1, \ + self.start_point) <= 1, \ f"Only 1 startPoint with multiple values OR 1 endPoint with multiple values is allowed" for t in exclude_values: for v in t: assert v not in s_e_values.get(t, []), f"startPoint and endPoint cannot be excluded, value: {v}" - return values + return self # Union of cards-schemas that doesn't change between FOSS and EE @@ -1459,12 +1475,12 @@ class LiveSessionSearchFilterSchema(BaseModel): _transform = model_validator(mode='before')(transform_old_filter_type) - @model_validator(mode='after') - def __validator(self, values): - if values.type is not None and values.type == LiveFilterType.METADATA: - assert values.source is not None, "source should not be null for METADATA type" - assert len(values.source) > 0, "source should not be empty for METADATA type" - return values + @model_validator(mode="after") + def __validator(self): + if self.type is not None and self.type == LiveFilterType.METADATA: + assert self.source is not None, "source should not be null for METADATA type" + assert len(self.source) > 0, "source should not be empty for METADATA type" + return self class LiveSessionsSearchPayloadSchema(_PaginatedSchema): @@ -1473,7 +1489,8 @@ class LiveSessionsSearchPayloadSchema(_PaginatedSchema): order: SortOrderType = Field(default=SortOrderType.DESC) @model_validator(mode="before") - def __transform(self, values): + @classmethod + def __transform(cls, values): if values.get("order") is not None: values["order"] = values["order"].upper() if values.get("filters") is not None: @@ -1528,11 +1545,11 @@ class SessionUpdateNoteSchema(SessionNoteSchema): timestamp: Optional[int] = Field(default=None, ge=-1) is_public: Optional[bool] = Field(default=None) - @model_validator(mode='after') - def __validator(self, values): - assert values.message is not None or values.timestamp is not None or values.is_public is not None, \ + @model_validator(mode="after") + def __validator(self): + assert self.message is not None or self.timestamp is not None or self.is_public is not None, \ "at least 1 attribute should be provided for update" - return values + return self class WebhookType(str, Enum): @@ -1558,7 +1575,8 @@ class HeatMapSessionsSearch(SessionsSearchPayloadSchema): filters: List[Union[SessionSearchFilterSchema, _HeatMapSearchEventRaw]] = Field(default=[]) @model_validator(mode="before") - def __transform(self, values): + @classmethod + def __transform(cls, values): for f in values.get("filters", []): if f.get("type") == FilterType.DURATION: return values @@ -1601,7 +1619,8 @@ class FeatureFlagConditionFilterSchema(BaseModel): sourceOperator: Optional[Union[SearchEventOperator, MathOperator]] = Field(default=None) @model_validator(mode="before") - def __force_is_event(self, values): + @classmethod + def __force_is_event(cls, values): values["isEvent"] = False return values diff --git a/ee/api/schemas/schemas_ee.py b/ee/api/schemas/schemas_ee.py index c4ab618133..9b465c5b41 100644 --- a/ee/api/schemas/schemas_ee.py +++ b/ee/api/schemas/schemas_ee.py @@ -32,7 +32,8 @@ class CurrentContext(schemas.CurrentContext): service_account: bool = Field(default=False) @model_validator(mode="before") - def remove_unsupported_perms(self, values): + @classmethod + def remove_unsupported_perms(cls, values): if values.get("permissions") is not None: perms = [] for p in values["permissions"]: @@ -94,7 +95,8 @@ class TrailSearchPayloadSchema(schemas._PaginatedSchema): order: schemas.SortOrderType = Field(default=schemas.SortOrderType.DESC) @model_validator(mode="before") - def transform_order(self, values): + @classmethod + def transform_order(cls, values): if values.get("order") is None: values["order"] = schemas.SortOrderType.DESC else: @@ -153,9 +155,9 @@ class AssistRecordSearchPayloadSchema(schemas._PaginatedSchema, schemas._TimedSc class CardInsights(schemas.CardInsights): metric_value: List[InsightCategories] = Field(default=[]) - @model_validator(mode='after') - def restrictions(self, values): - return values + @model_validator(mode="after") + def restrictions(self): + return self CardSchema = ORUnion(Union[schemas.__cards_union_base, CardInsights], discriminator='metric_type')