Skip to content

Commit

Permalink
More progress
Browse files Browse the repository at this point in the history
  • Loading branch information
ejsmith committed Mar 14, 2024
1 parent 1de92ac commit 3aeac03
Show file tree
Hide file tree
Showing 9 changed files with 164 additions and 101 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,5 @@ _NCrunch_*
# Rider auto-generates .iml files, and contentModel.xml
**/.idea/**/*.iml
**/.idea/**/contentModel.xml
**/.idea/**/modules.xml
**/.idea/**/modules.xml
**/.idea/copilot/chatSessions/
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public override void Visit(TermNode node, IQueryVisitorContext context)
AddOperation(validationResult, node.GetOperationType(), node.Field);

var validationOptions = context.GetValidationOptions();
if (validationOptions != null && !validationOptions.AllowLeadingWildcards && node.Term != null && (node.Term.StartsWith("*") || node.Term.StartsWith("?")))
if (validationOptions is { AllowLeadingWildcards: false } && node.Term != null && (node.Term.StartsWith("*") || node.Term.StartsWith("?")))
context.AddValidationError("Terms must not start with a wildcard: " + node.Term);
}

Expand Down Expand Up @@ -72,7 +72,7 @@ private void AddField(QueryValidationResult validationResult, IFieldQueryNode no
}
else
{
var fields = node.GetDefaultFields(context.DefaultFields);
string[] fields = node.GetDefaultFields(context.DefaultFields);
if (fields == null || fields.Length == 0)
validationResult.ReferencedFields.Add("");
else
Expand Down Expand Up @@ -120,11 +120,11 @@ internal async Task ApplyQueryRestrictions(IQueryVisitorContext context)

if (options.AllowedFields.Count > 0 && result.ReferencedFields.Count > 0)
{
var nonAllowedFields = result.ReferencedFields.Where(f => !String.IsNullOrEmpty(f)).Distinct().ToList();
foreach (var field in options.AllowedFields)
var nonAllowedFields = new List<string>();
foreach (var field in result.ReferencedFields)
{
if (nonAllowedFields.Any(f => !String.IsNullOrEmpty(f) && field.Equals(f)))
nonAllowedFields.Remove(field);
if (!options.AllowedFields.Contains(field, StringComparer.OrdinalIgnoreCase))
nonAllowedFields.Add(field);
}
if (nonAllowedFields.Count > 0)
context.AddValidationError($"Query uses field(s) ({String.Join(",", nonAllowedFields)}) that are not allowed to be used.");
Expand Down
46 changes: 20 additions & 26 deletions src/Foundatio.Parsers.SqlQueries/Extensions/QueryableExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,71 +1,71 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.ComponentModel.DataAnnotations;
using System.Linq;
using System.Linq.Dynamic.Core;
using Foundatio.Parsers.LuceneQueries;
using Foundatio.Parsers.LuceneQueries.Visitors;
using Foundatio.Parsers.SqlQueries.Extensions;
using Foundatio.Parsers.SqlQueries.Visitors;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;

namespace Foundatio.Parsers.SqlQueries;

public static class QueryableExtensions
{
private static readonly SqlQueryParser _parser = new();
private static readonly ConcurrentDictionary<IEntityType, List<FieldInfo>> _entityFieldCache = new();

public static IQueryable<T> LuceneWhere<T>(this IQueryable<T> source, string query) where T : class
{
if (source is not DbSet<T> dbSet)
throw new ArgumentException("source must be a DbSet<T>", nameof(source));

var serviceProvider = ((IInfrastructure<IServiceProvider>)dbSet).Instance;
var loggerFactory = serviceProvider.GetService<ILoggerFactory>();
var logger = loggerFactory.CreateLogger<SqlQueryParser>();

// use service provider to get global settings that say how to discover and handle custom fields
// support field aliases
var parser = dbSet.GetService<SqlQueryParser>();

var fields = _entityFieldCache.GetOrAdd(dbSet.EntityType, entityType =>
{
var fields = new List<FieldInfo>();
AddFields(fields, entityType);

// lookup and add custom fields
fields.Add(new FieldInfo
{
Field = "age",
Data = {{ "DataDefinitionId", 1 }},
IsNumber = true
});
var dynamicFields = parser.Configuration.EntityTypeDynamicFieldResolver?.Invoke(entityType) ?? [];
fields.AddRange(dynamicFields);

return fields;
});
var validationOptions = new QueryValidationOptions();
foreach (string field in fields.Select(f => f.Field))
validationOptions.AllowedFields.Add(field);

parser.Configuration.SetValidationOptions(validationOptions);
var context = new SqlQueryVisitorContext { Fields = fields };
var node = _parser.Parse(query, context);
var node = parser.Parse(query, context);
var result = ValidationVisitor.Run(node, context);
if (!result.IsValid)
throw new ValidationException("Invalid query: " + result.Message);

string sql = GenerateSqlVisitor.Run(node, context);
return source.Where(sql);
}

private static void AddFields(List<FieldInfo> fields, IEntityType entityType, List<IEntityType> visited = null)
private static void AddFields(List<FieldInfo> fields, IEntityType entityType, List<IEntityType> visited = null, string prefix = null)
{
visited ??= [];
if (visited.Contains(entityType))
return;

prefix ??= "";

visited.Add(entityType);

foreach (var property in entityType.GetProperties())
{
if (property.IsIndex() || property.IsKey())
fields.Add(new FieldInfo
{
Field = property.Name,
Field = prefix + property.Name,
IsNumber = property.ClrType.UnwrapNullable().IsNumeric(),
IsDate = property.ClrType.UnwrapNullable().IsDateTime(),
IsBoolean = property.ClrType.UnwrapNullable().IsBoolean()
Expand All @@ -77,13 +77,7 @@ private static void AddFields(List<FieldInfo> fields, IEntityType entityType, Li
if (visited.Contains(nav.TargetEntityType))
continue;

var field = new FieldInfo
{
Field = nav.Name,
Children = new List<FieldInfo>()
};
fields.Add(field);
AddFields(field.Children, nav.TargetEntityType, visited);
AddFields(fields, nav.TargetEntityType, visited, prefix + nav.Name + ".");
}
}
}
65 changes: 21 additions & 44 deletions src/Foundatio.Parsers.SqlQueries/Extensions/SqlNodeExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Foundatio.Parsers.LuceneQueries.Extensions;
using Foundatio.Parsers.LuceneQueries.Nodes;
using Foundatio.Parsers.SqlQueries.Visitors;

Expand Down Expand Up @@ -62,7 +64,7 @@ public static string ToSqlString(this GroupNode node, ISqlQueryVisitorContext co
public static string ToSqlString(this ExistsNode node, ISqlQueryVisitorContext context)
{
if (String.IsNullOrEmpty(node.Field))
throw new ArgumentException("Field is required for exists node queries.");
context.AddValidationError("Field is required for exists node queries.");

// support overriding the generated query
if (node.TryGetQuery(out string query))
Expand All @@ -82,10 +84,10 @@ public static string ToSqlString(this ExistsNode node, ISqlQueryVisitorContext c
public static string ToSqlString(this MissingNode node, ISqlQueryVisitorContext context)
{
if (String.IsNullOrEmpty(node.Field))
throw new ArgumentException("Field is required for missing node queries.");
context.AddValidationError("Field is required for missing node queries.");

if (!String.IsNullOrEmpty(node.Prefix))
throw new ArgumentException("Prefix is not supported for term range queries.");
context.AddValidationError("Prefix is not supported for term range queries.");

// support overriding the generated query
if (node.TryGetQuery(out string query))
Expand All @@ -102,54 +104,25 @@ public static string ToSqlString(this MissingNode node, ISqlQueryVisitorContext
return builder.ToString();
}

public static FieldInfo GetFieldInfo(List<FieldInfo> fields, string field)
{
return fields.FirstOrDefault(f => f.Field.Equals(field, StringComparison.OrdinalIgnoreCase));
}

public static string ToSqlString(this TermNode node, ISqlQueryVisitorContext context)
{
if (String.IsNullOrEmpty(node.Field))
throw new ArgumentException("Field is required for term node queries.");
context.AddValidationError("Field is required for term node queries.");

if (!String.IsNullOrEmpty(node.Prefix))
throw new ArgumentException("Prefix is not supported for term range queries.");

// TODO: This needs to resolve the field recursively
var field = context.Fields.FirstOrDefault(f => f.Field.Equals(node.Field, StringComparison.OrdinalIgnoreCase));

// TODO: Remove this hard coded
if (field != null && field.Data.TryGetValue("DataDefinitionId", out object value) && value is int dataDefinitionId)
{
var customFieldBuilder = new StringBuilder();

customFieldBuilder.Append("DataValues.Any(DataDefinitionId = ");
customFieldBuilder.Append(dataDefinitionId);
customFieldBuilder.Append(" AND ");
if (field is { IsNumber: true })
customFieldBuilder.Append("NumberValue");
else if (field is { IsBoolean: true })
customFieldBuilder.Append("BooleanValue");
else if (field is { IsDate: true })
customFieldBuilder.Append("DateValue");
else
customFieldBuilder.Append("StringValue");

customFieldBuilder.Append(" = ");
if (field is { IsNumber: true } or { IsBoolean: true })
{
customFieldBuilder.Append(node.Term);
}
else
{
customFieldBuilder.Append("\"");
customFieldBuilder.Append(node.Term);
customFieldBuilder.Append("\"");
}
customFieldBuilder.Append(")");

node.SetQuery(customFieldBuilder.ToString());
}
context.AddValidationError("Prefix is not supported for term range queries.");

// support overriding the generated query
if (node.TryGetQuery(out string query))
return query;

var field = GetFieldInfo(context.Fields, node.Field);

var builder = new StringBuilder();

if (node.IsNegated.HasValue && node.IsNegated.Value)
Expand All @@ -172,16 +145,20 @@ public static string ToSqlString(this TermNode node, ISqlQueryVisitorContext con
public static string ToSqlString(this TermRangeNode node, ISqlQueryVisitorContext context)
{
if (String.IsNullOrEmpty(node.Field))
throw new ArgumentException("Field is required for term range queries.");
context.AddValidationError("Field is required for term range queries.");
if (!String.IsNullOrEmpty(node.Boost))
throw new ArgumentException("Boost is not supported for term range queries.");
context.AddValidationError("Boost is not supported for term range queries.");
if (!String.IsNullOrEmpty(node.Proximity))
throw new ArgumentException("Proximity is not supported for term range queries.");
context.AddValidationError("Proximity is not supported for term range queries.");

// support overriding the generated query
if (node.TryGetQuery(out string query))
return query;

var field = GetFieldInfo(context.Fields, node.Field);
if (!field.IsNumber && !field.IsDate)
context.AddValidationError("Field must be a number or date for term range queries.");

var builder = new StringBuilder();

if (node.IsNegated.HasValue && node.IsNegated.Value)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
<PackageReference Include="Microsoft.EntityFrameworkCore" Version="8.0" />
<PackageReference Include="Exceptionless.DateTimeExtensions" Version="3.4.3" />
<PackageReference Include="System.Text.Json" Version="8.0" />
<PackageReference Include="System.Linq.Dynamic.Core" Version="1.3.9" />
<PackageReference Include="System.Linq.Dynamic.Core" Version="1.3.10" />
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="8.0" />
</ItemGroup>
<ItemGroup>
Expand Down
20 changes: 14 additions & 6 deletions src/Foundatio.Parsers.SqlQueries/SqlQueryParserConfiguration.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
using System.Threading.Tasks;
using Foundatio.Parsers.LuceneQueries;
using Foundatio.Parsers.LuceneQueries.Visitors;
using Foundatio.Parsers.SqlQueries.Visitors;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;

Expand All @@ -12,23 +14,22 @@ public class SqlQueryParserConfiguration {
private ILogger _logger = NullLogger.Instance;

public SqlQueryParserConfiguration() {
//AddQueryVisitor(new CombineQueriesVisitor(), 10000);
AddSortVisitor(new TermToFieldVisitor(), 0);
AddAggregationVisitor(new AssignOperationTypeVisitor(), 0);
//AddAggregationVisitor(new CombineAggregationsVisitor(), 10000);
AddVisitor(new FieldResolverQueryVisitor((field, context) => FieldResolver != null ? FieldResolver(field, context) : Task.FromResult<string>(null)), 10);
AddVisitor(new ValidationVisitor(), 30);
}

public ILoggerFactory LoggerFactory { get; private set; } = NullLoggerFactory.Instance;
public string[] DefaultFields { get; private set; }

public QueryFieldResolver FieldResolver { get; private set; }
public EntityTypeDynamicFieldsResolver EntityTypeDynamicFieldResolver { get; private set; }
public IncludeResolver IncludeResolver { get; private set; }
//public ElasticMappingResolver MappingResolver { get; private set; }
public QueryValidationOptions ValidationOptions { get; private set; }
public ChainedQueryVisitor SortVisitor { get; } = new ChainedQueryVisitor();
public ChainedQueryVisitor QueryVisitor { get; } = new ChainedQueryVisitor();
public ChainedQueryVisitor AggregationVisitor { get; } = new ChainedQueryVisitor();
public ChainedQueryVisitor SortVisitor { get; } = new();
public ChainedQueryVisitor QueryVisitor { get; } = new();
public ChainedQueryVisitor AggregationVisitor { get; } = new();

public SqlQueryParserConfiguration SetLoggerFactory(ILoggerFactory loggerFactory) {
LoggerFactory = loggerFactory ?? NullLoggerFactory.Instance;
Expand All @@ -42,6 +43,11 @@ public SqlQueryParserConfiguration SetDefaultFields(string[] fields) {
return this;
}

public SqlQueryParserConfiguration UseEntityTypeDynamicFieldResolver(EntityTypeDynamicFieldsResolver resolver) {
EntityTypeDynamicFieldResolver = resolver;
return this;
}

public SqlQueryParserConfiguration UseFieldResolver(QueryFieldResolver resolver, int priority = 10) {
FieldResolver = resolver;
ReplaceVisitor<FieldResolverQueryVisitor>(new FieldResolverQueryVisitor(resolver), priority);
Expand Down Expand Up @@ -221,3 +227,5 @@ public SqlQueryParserConfiguration AddAggregationVisitorAfter<T>(IChainableQuery

#endregion
}

public delegate List<FieldInfo> EntityTypeDynamicFieldsResolver(IEntityType entityType);
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,4 @@ public class FieldInfo
public bool IsDate { get; set; }
public bool IsBoolean { get; set; }
public IDictionary<string, object> Data { get; set; } = new Dictionary<string, object>();
public List<FieldInfo> Children { get; set; }
}
2 changes: 1 addition & 1 deletion tests/Foundatio.Parsers.SqlQueries.Tests/SampleContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ protected override void OnModelCreating(ModelBuilder modelBuilder)
modelBuilder.Entity<Employee>().HasIndex(e => new { e.FullName, e.Title });

// Company
modelBuilder.Entity<Company>().HasIndex(e => new { e.Name, e.Description });
modelBuilder.Entity<Company>().HasIndex(e => new { e.Name });

// DataDefinition
modelBuilder.Entity<DataDefinition>().Property(c => c.DataType).IsRequired();
Expand Down
Loading

0 comments on commit 3aeac03

Please sign in to comment.