Skip to content

Commit

Permalink
Progress
Browse files Browse the repository at this point in the history
  • Loading branch information
ejsmith committed Mar 13, 2024
1 parent dbe9842 commit 39da8f6
Show file tree
Hide file tree
Showing 10 changed files with 572 additions and 333 deletions.
80 changes: 80 additions & 0 deletions src/Foundatio.Parsers.SqlQueries/Extensions/QueryableExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Dynamic.Core;
using Foundatio.Parsers.SqlQueries.Extensions;
using Foundatio.Parsers.SqlQueries.Visitors;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;

namespace Foundatio.Parsers.SqlQueries;

public static class QueryableExtensions
{
private static readonly SqlQueryParser _parser = new();
private static readonly ConcurrentDictionary<IEntityType, List<FieldInfo>> _entityFieldCache = new();

public static IQueryable<T> LuceneWhere<T>(this IQueryable<T> source, string query) where T : class
{
if (source is not DbSet<T> dbSet)
throw new ArgumentException("source must be a DbSet<T>", nameof(source));

var serviceProvider = ((IInfrastructure<IServiceProvider>)dbSet).Instance;
var loggerFactory = serviceProvider.GetService<ILoggerFactory>();
var logger = loggerFactory.CreateLogger<SqlQueryParser>();

// use service provider to get global settings that say how to discover and handle custom fields
// support field aliases

var fields = _entityFieldCache.GetOrAdd(dbSet.EntityType, entityType =>
{
var fields = new List<FieldInfo>();
AddFields(fields, entityType);
return fields;
});

var context = new SqlQueryVisitorContext { Fields = fields };
var node = _parser.Parse(query, context);
string sql = GenerateSqlVisitor.Run(node, context);
return source.Where(sql);
}

private static void AddFields(List<FieldInfo> fields, IEntityType entityType, List<IEntityType> visited = null)
{
visited ??= [];
if (visited.Contains(entityType))
return;

visited.Add(entityType);

foreach (var property in entityType.GetProperties())
{
if (property.IsIndex() || property.IsKey())
fields.Add(new FieldInfo
{
Field = property.Name,
IsNumber = property.ClrType.UnwrapNullable().IsNumeric(),
IsDate = property.ClrType.UnwrapNullable().IsDateTime(),
IsBoolean = property.ClrType.UnwrapNullable().IsBoolean()
});
}

foreach (var nav in entityType.GetNavigations())
{
if (visited.Contains(nav.TargetEntityType))
continue;

var field = new FieldInfo
{
Field = nav.Name,
Children = new List<FieldInfo>()
};
fields.Add(field);
AddFields(field.Children, nav.TargetEntityType, visited);
}
}
}
174 changes: 174 additions & 0 deletions src/Foundatio.Parsers.SqlQueries/Extensions/SqlNodeExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
using System;
using System.Linq;
using System.Text;
using Foundatio.Parsers.LuceneQueries.Nodes;
using Foundatio.Parsers.SqlQueries.Visitors;

namespace Foundatio.Parsers.SqlQueries.Extensions;

public static class SqlNodeExtensions
{
public static string ToSqlString(this GroupNode node, ISqlQueryVisitorContext context)
{
if (node.Left == null && node.Right == null)
return String.Empty;

var defaultOperator = context.DefaultOperator;

var builder = new StringBuilder();
var op = node.Operator != GroupOperator.Default ? node.Operator : defaultOperator;

if (node.IsNegated.HasValue && node.IsNegated.Value)
builder.Append("NOT ");

builder.Append(node.Prefix);

if (!String.IsNullOrEmpty(node.Field))
builder.Append(node.Field).Append(':');

if (node.HasParens)
builder.Append("(");

if (node.Left != null)
builder.Append(node.Left is GroupNode groupNode ? groupNode.ToSqlString(context) : node.Left.ToSqlString(context));

if (node.Left != null && node.Right != null)
{
if (op == GroupOperator.Or || (op == GroupOperator.Default && defaultOperator == GroupOperator.Or))
builder.Append(" OR ");
else if (node.Right != null)
builder.Append(" AND ");
}

if (node.Right != null)
builder.Append(node.Right is GroupNode groupNode ? groupNode.ToSqlString(context) : node.Right.ToSqlString(context));

if (node.HasParens)
builder.Append(")");

if (node.Proximity != null)
builder.Append("~" + node.Proximity);

if (node.Boost != null)
builder.Append("^" + node.Boost);

return builder.ToString();
}

public static string ToSqlString(this ExistsNode node, ISqlQueryVisitorContext context)
{
if (String.IsNullOrEmpty(node.Field))
throw new ArgumentException("Field is required for exists node queries.");

var builder = new StringBuilder();

if (node.IsNegated.HasValue && node.IsNegated.Value)
builder.Append("NOT ");

builder.Append(node.Field);
builder.Append(" IS NOT NULL");

return builder.ToString();
}

public static string ToSqlString(this MissingNode node, ISqlQueryVisitorContext context)
{
if (String.IsNullOrEmpty(node.Field))
throw new ArgumentException("Field is required for missing node queries.");

if (!String.IsNullOrEmpty(node.Prefix))
throw new ArgumentException("Prefix is not supported for term range queries.");

var builder = new StringBuilder();

if (node.IsNegated.HasValue && node.IsNegated.Value)
builder.Append("NOT ");

builder.Append(node.Field);
builder.Append(" IS NULL");

return builder.ToString();
}

public static string ToSqlString(this TermNode node, ISqlQueryVisitorContext context)
{
if (String.IsNullOrEmpty(node.Field))
throw new ArgumentException("Field is required for term node queries.");

if (!String.IsNullOrEmpty(node.Prefix))
throw new ArgumentException("Prefix is not supported for term range queries.");

var builder = new StringBuilder();

if (node.IsNegated.HasValue && node.IsNegated.Value)
builder.Append("NOT ");

builder.Append(node.Field);
if (node.IsNegated.HasValue && node.IsNegated.Value)
builder.Append(" != ");
else
builder.Append(" = ");

// TODO: This needs to resolve the field recursively
var field = context.Fields.FirstOrDefault(f => f.Field.Equals(node.Field, StringComparison.OrdinalIgnoreCase));
if (field != null && (field.IsNumber || field.IsBoolean))
builder.Append(node.Term);
else
builder.Append("\"" + node.Term + "\"");

return builder.ToString();
}

public static string ToSqlString(this TermRangeNode node, ISqlQueryVisitorContext context)
{
if (String.IsNullOrEmpty(node.Field))
throw new ArgumentException("Field is required for term range queries.");
if (!String.IsNullOrEmpty(node.Boost))
throw new ArgumentException("Boost is not supported for term range queries.");
if (!String.IsNullOrEmpty(node.Proximity))
throw new ArgumentException("Proximity is not supported for term range queries.");

var builder = new StringBuilder();

if (node.IsNegated.HasValue && node.IsNegated.Value)
builder.Append("NOT ");

if (node.Min != null && node.Max != null)
builder.Append("(");

if (node.Min != null)
{
builder.Append(node.Field);
builder.Append(node.MinInclusive == true ? " >= " : " > ");
builder.Append(node.Min);
}

if (node.Min != null && node.Max != null)
builder.Append(" AND ");

if (node.Max != null)
{
builder.Append(node.Field);
builder.Append(node.MaxInclusive == true ? " <= " : " < ");
builder.Append(node.Max);
}

if (node.Min != null && node.Max != null)
builder.Append(")");

return builder.ToString();
}

public static string ToSqlString(this IQueryNode node, ISqlQueryVisitorContext context)
{
return node switch
{
GroupNode groupNode => groupNode.ToSqlString(context),
ExistsNode existsNode => existsNode.ToSqlString(context),
MissingNode missingNode => missingNode.ToSqlString(context),
TermNode termNode => termNode.ToSqlString(context),
TermRangeNode termRangeNode => termRangeNode.ToSqlString(context),
_ => throw new NotSupportedException($"Node type {node.GetType().Name} is not supported.")
};
}
}
42 changes: 42 additions & 0 deletions src/Foundatio.Parsers.SqlQueries/Extensions/TypeExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
using System;
using System.Collections.Generic;

namespace Foundatio.Parsers.SqlQueries.Extensions;

public static class TypeExtensions
{
private static readonly IList<Type> _integerTypes = new List<Type>()
{
typeof (byte),
typeof (short),
typeof (int),
typeof (long),
typeof (sbyte),
typeof (ushort),
typeof (uint),
typeof (ulong),
typeof (byte?),
typeof (short?),
typeof (int?),
typeof (long?),
typeof (sbyte?),
typeof (ushort?),
typeof (uint?),
typeof (ulong?)
};

public static Type UnwrapNullable(this Type type)
{
if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>))
return Nullable.GetUnderlyingType(type);

return type;
}

public static bool IsString(this Type type) => type == typeof(string);
public static bool IsDateTime(this Type typeToCheck) => typeToCheck == typeof(DateTime) || typeToCheck == typeof(DateTime?);
public static bool IsBoolean(this Type typeToCheck) => typeToCheck == typeof(bool) || typeToCheck == typeof(bool?);
public static bool IsNumeric(this Type type) => type.IsFloatingPoint() || type.IsIntegerBased();
public static bool IsIntegerBased(this Type type) => _integerTypes.Contains(type);
public static bool IsFloatingPoint(this Type type) => type == typeof(decimal) || type == typeof(float) || type == typeof(double);
}
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
<Project Sdk="Microsoft.NET.Sdk">
<Import Project="..\..\build\common.props" />
<PropertyGroup>
<TargetFrameworks>net8.0;</TargetFrameworks>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Microsoft.EntityFrameworkCore" Version="8.0" />
<PackageReference Include="Exceptionless.DateTimeExtensions" Version="3.4.3" />
<PackageReference Include="System.Text.Json" Version="6.0" />
<PackageReference Include="System.Text.Json" Version="8.0" />
<PackageReference Include="System.Linq.Dynamic.Core" Version="1.3.9" />
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="6.0" />
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="8.0" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\..\src\Foundatio.Parsers.LuceneQueries\Foundatio.Parsers.LuceneQueries.csproj" />
Expand Down
26 changes: 21 additions & 5 deletions src/Foundatio.Parsers.SqlQueries/Visitors/GenerateSqlVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using System.Threading.Tasks;
using Foundatio.Parsers.LuceneQueries.Nodes;
using Foundatio.Parsers.LuceneQueries.Visitors;
using Foundatio.Parsers.SqlQueries.Extensions;

namespace Foundatio.Parsers.SqlQueries.Visitors;

Expand All @@ -12,29 +13,44 @@ public class GenerateSqlVisitor : QueryNodeVisitorWithResultBase<string>

public override Task VisitAsync(GroupNode node, IQueryVisitorContext context)
{
_builder.Append(node.ToSqlString(context?.DefaultOperator ?? GroupOperator.Default));
if (context is not ISqlQueryVisitorContext sqlContext)
throw new InvalidOperationException("The context must be an ISqlQueryVisitorContext.");

_builder.Append(node.ToSqlString(sqlContext));

return Task.CompletedTask;
}

public override void Visit(TermNode node, IQueryVisitorContext context)
{
_builder.Append(node.ToSqlString());
if (context is not ISqlQueryVisitorContext sqlContext)
throw new InvalidOperationException("The context must be an ISqlQueryVisitorContext.");

_builder.Append(node.ToSqlString(sqlContext));
}

public override void Visit(TermRangeNode node, IQueryVisitorContext context)
{
_builder.Append(node.ToSqlString());
if (context is not ISqlQueryVisitorContext sqlContext)
throw new InvalidOperationException("The context must be an ISqlQueryVisitorContext.");

_builder.Append(node.ToSqlString(sqlContext));
}

public override void Visit(ExistsNode node, IQueryVisitorContext context)
{
_builder.Append(node.ToSqlString());
if (context is not ISqlQueryVisitorContext sqlContext)
throw new InvalidOperationException("The context must be an ISqlQueryVisitorContext.");

_builder.Append(node.ToSqlString(sqlContext));
}

public override void Visit(MissingNode node, IQueryVisitorContext context)
{
_builder.Append(node.ToSqlString());
if (context is not ISqlQueryVisitorContext sqlContext)
throw new InvalidOperationException("The context must be an ISqlQueryVisitorContext.");

_builder.Append(node.ToSqlString(sqlContext));
}

public override async Task<string> AcceptAsync(IQueryNode node, IQueryVisitorContext context)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
using System;
using System.Threading.Tasks;
using System.Collections.Generic;
using Foundatio.Parsers.LuceneQueries.Visitors;

namespace Foundatio.Parsers.SqlQueries.Visitors {
public interface ISqlQueryVisitorContext : IQueryVisitorContext {
Func<Task<string>> DefaultTimeZone { get; set; }
bool UseScoring { get; set; }
//ElasticMappingResolver MappingResolver { get; set; }
List<FieldInfo> Fields { get; set; }
}
}
Loading

0 comments on commit 39da8f6

Please sign in to comment.