Skip to content

Commit 4464d47

Browse files
authored
Fix DateOnly.Add{Days,Months,Years} (#2964)
Fixes #2888
1 parent fa5ba2e commit 4464d47

File tree

3 files changed

+148
-78
lines changed

3 files changed

+148
-78
lines changed

src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlDateTimeMethodTranslator.cs

+74-52
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,7 @@ public class NpgsqlDateTimeMethodTranslator : IMethodCallTranslator
2929
{ typeof(DateTimeOffset).GetRuntimeMethod(nameof(DateTimeOffset.AddSeconds), new[] { typeof(double) })!, "secs" },
3030
//{ typeof(DateTimeOffset).GetRuntimeMethod(nameof(DateTimeOffset.AddMilliseconds), new[] { typeof(double) })!, "milliseconds" }
3131

32-
{ typeof(DateOnly).GetRuntimeMethod(nameof(DateOnly.AddYears), new[] { typeof(int) })!, "years" },
33-
{ typeof(DateOnly).GetRuntimeMethod(nameof(DateOnly.AddMonths), new[] { typeof(int) })!, "months" },
34-
{ typeof(DateOnly).GetRuntimeMethod(nameof(DateOnly.AddDays), new[] { typeof(int) })!, "days" },
32+
// DateOnly.AddDays, AddMonths and AddYears have a specialized translation, see below
3533
{ typeof(TimeOnly).GetRuntimeMethod(nameof(TimeOnly.AddHours), new[] { typeof(int) })!, "hours" },
3634
{ typeof(TimeOnly).GetRuntimeMethod(nameof(TimeOnly.AddMinutes), new[] { typeof(int) })!, "mins" },
3735
};
@@ -60,6 +58,15 @@ private static readonly MethodInfo DateOnly_Distance
6058
= typeof(NpgsqlDbFunctionsExtensions).GetRuntimeMethod(
6159
nameof(NpgsqlDbFunctionsExtensions.Distance), new[] { typeof(DbFunctions), typeof(DateOnly), typeof(DateOnly) })!;
6260

61+
private static readonly MethodInfo DateOnly_AddDays
62+
= typeof(DateOnly).GetRuntimeMethod(nameof(DateOnly.AddDays), new[] { typeof(int) })!;
63+
64+
private static readonly MethodInfo DateOnly_AddMonths
65+
= typeof(DateOnly).GetRuntimeMethod(nameof(DateOnly.AddMonths), new[] { typeof(int) })!;
66+
67+
private static readonly MethodInfo DateOnly_AddYears
68+
= typeof(DateOnly).GetRuntimeMethod(nameof(DateOnly.AddYears), new[] { typeof(int) })!;
69+
6370
private static readonly MethodInfo TimeOnly_FromDateTime
6471
= typeof(TimeOnly).GetRuntimeMethod(nameof(TimeOnly.FromDateTime), new[] { typeof(DateTime) })!;
6572

@@ -118,60 +125,21 @@ public NpgsqlDateTimeMethodTranslator(
118125
MethodInfo method,
119126
IReadOnlyList<SqlExpression> arguments,
120127
IDiagnosticsLogger<DbLoggerCategory.Query> logger)
121-
=> TranslateDatePart(instance, method, arguments)
122-
?? TranslateDateTime(instance, method, arguments)
123-
?? TranslateDateOnly(instance, method, arguments)
124-
?? TranslateTimeOnly(instance, method, arguments)
125-
?? TranslateTimeZoneInfo(method, arguments);
128+
=> TranslateDateTime(instance, method, arguments)
129+
?? TranslateDateOnly(instance, method, arguments)
130+
?? TranslateTimeOnly(instance, method, arguments)
131+
?? TranslateTimeZoneInfo(method, arguments)
132+
?? TranslateDatePart(instance, method, arguments);
126133

127134
private SqlExpression? TranslateDatePart(
128135
SqlExpression? instance,
129136
MethodInfo method,
130137
IReadOnlyList<SqlExpression> arguments)
131-
{
132-
if (instance is null || !MethodInfoDatePartMapping.TryGetValue(method, out var datePart))
133-
{
134-
return null;
135-
}
136-
137-
if (arguments[0] is not { } interval)
138-
{
139-
return null;
140-
}
141-
142-
// Note: ideally we'd simply generate a PostgreSQL interval expression, but the .NET mapping of that is TimeSpan,
143-
// which does not work for months, years, etc. So we generate special fragments instead.
144-
if (interval is SqlConstantExpression constantExpression)
145-
{
146-
// We generate constant intervals as INTERVAL '1 days'
147-
if (constantExpression.Type == typeof(double)
148-
&& ((double)constantExpression.Value! >= int.MaxValue || (double)constantExpression.Value <= int.MinValue))
149-
{
150-
return null;
151-
}
152-
153-
interval = _sqlExpressionFactory.Fragment(FormattableString.Invariant($"INTERVAL '{constantExpression.Value} {datePart}'"));
154-
}
155-
else
156-
{
157-
// For non-constants, we can't parameterize INTERVAL '1 days'. Instead, we use CAST($1 || ' days' AS interval).
158-
// Note that a make_interval() function also exists, but accepts only int (for all fields except for
159-
// seconds), so we don't use it.
160-
// Note: we instantiate SqlBinaryExpression manually rather than via sqlExpressionFactory because
161-
// of the non-standard Add expression (concatenate int with text)
162-
interval = _sqlExpressionFactory.Convert(
163-
new SqlBinaryExpression(
164-
ExpressionType.Add,
165-
_sqlExpressionFactory.Convert(interval, typeof(string), _textMapping),
166-
_sqlExpressionFactory.Constant(' ' + datePart, _textMapping),
167-
typeof(string),
168-
_textMapping),
169-
typeof(TimeSpan),
170-
_intervalMapping);
171-
}
172-
173-
return _sqlExpressionFactory.Add(instance, interval, instance.TypeMapping);
174-
}
138+
=> instance is not null
139+
&& MethodInfoDatePartMapping.TryGetValue(method, out var datePart)
140+
&& CreateIntervalExpression(arguments[0], datePart) is SqlExpression interval
141+
? _sqlExpressionFactory.Add(instance, interval, instance.TypeMapping)
142+
: null;
175143

176144
private SqlExpression? TranslateDateTime(
177145
SqlExpression? instance,
@@ -270,6 +238,28 @@ public NpgsqlDateTimeMethodTranslator(
270238
typeof(DateTime),
271239
_timestampMapping);
272240
}
241+
242+
// In PG, date + int = date (int interpreted as days)
243+
if (method == DateOnly_AddDays)
244+
{
245+
return _sqlExpressionFactory.Add(instance, arguments[0]);
246+
}
247+
248+
// For months and years, date + interval yields a timestamp (since interval could have a time component), so we need to cast
249+
// the results back to date
250+
if (method == DateOnly_AddMonths
251+
&& CreateIntervalExpression(arguments[0], "months") is SqlExpression interval1)
252+
{
253+
return _sqlExpressionFactory.Convert(
254+
_sqlExpressionFactory.Add(instance, interval1, instance.TypeMapping), typeof(DateOnly));
255+
}
256+
257+
if (method == DateOnly_AddYears
258+
&& CreateIntervalExpression(arguments[0], "years") is SqlExpression interval2)
259+
{
260+
return _sqlExpressionFactory.Convert(
261+
_sqlExpressionFactory.Add(instance, interval2, instance.TypeMapping), typeof(DateOnly));
262+
}
273263
}
274264

275265
return null;
@@ -360,4 +350,36 @@ public NpgsqlDateTimeMethodTranslator(
360350

361351
return null;
362352
}
353+
354+
private SqlExpression? CreateIntervalExpression(SqlExpression intervalNum, string datePart)
355+
{
356+
// Note: ideally we'd simply generate a PostgreSQL interval expression, but the .NET mapping of that is TimeSpan,
357+
// which does not work for months, years, etc. So we generate special fragments instead.
358+
if (intervalNum is SqlConstantExpression constantExpression)
359+
{
360+
// We generate constant intervals as INTERVAL '1 days'
361+
if (constantExpression.Type == typeof(double)
362+
&& ((double)constantExpression.Value! >= int.MaxValue || (double)constantExpression.Value <= int.MinValue))
363+
{
364+
return null;
365+
}
366+
367+
return _sqlExpressionFactory.Fragment(FormattableString.Invariant($"INTERVAL '{constantExpression.Value} {datePart}'"));
368+
}
369+
370+
// For non-constants, we can't parameterize INTERVAL '1 days'. Instead, we use CAST($1 || ' days' AS interval).
371+
// Note that a make_interval() function also exists, but accepts only int (for all fields except for
372+
// seconds), so we don't use it.
373+
// Note: we instantiate SqlBinaryExpression manually rather than via sqlExpressionFactory because
374+
// of the non-standard Add expression (concatenate int with text)
375+
return _sqlExpressionFactory.Convert(
376+
new SqlBinaryExpression(
377+
ExpressionType.Add,
378+
_sqlExpressionFactory.Convert(intervalNum, typeof(string), _textMapping),
379+
_sqlExpressionFactory.Constant(' ' + datePart, _textMapping),
380+
typeof(string),
381+
_textMapping),
382+
typeof(TimeSpan),
383+
_intervalMapping);
384+
}
363385
}

src/EFCore.PG/Query/NpgsqlSqlExpressionFactory.cs

+1
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,7 @@ private SqlBinaryExpression ApplyTypeMappingOnSqlBinary(SqlBinaryExpression bina
438438
case ExpressionType.Add or ExpressionType.Subtract
439439
when right.Type == typeof(TimeSpan)
440440
&& (left.Type == typeof(DateTime) || left.Type == typeof(DateTimeOffset) || left.Type == typeof(TimeOnly))
441+
|| right.Type == typeof(int) && left.Type == typeof(DateOnly)
441442
|| right.Type.FullName == "NodaTime.Period"
442443
&& left.Type.FullName is "NodaTime.LocalDateTime" or "NodaTime.LocalDate" or "NodaTime.LocalTime"
443444
|| right.Type.FullName == "NodaTime.Duration"

test/EFCore.PG.FunctionalTests/Query/GearsOfWarQueryNpgsqlTest.cs

+73-26
Original file line numberDiff line numberDiff line change
@@ -475,9 +475,7 @@ WHERE make_date(date_part('year', m."Date")::int, date_part('month', m."Date")::
475475
[ConditionalTheory(Skip = "https://github.com/npgsql/efcore.pg/issues/2039")]
476476
public override async Task Where_DateOnly_Year(bool async)
477477
{
478-
await AssertQuery(
479-
async,
480-
ss => ss.Set<Mission>().Where(m => m.Date.Year == 1990).AsTracking());
478+
await base.Where_DateOnly_Year(async);
481479

482480
AssertSql(
483481
"""
@@ -489,9 +487,7 @@ WHERE date_part('year', m."Date")::int = 1990
489487

490488
public override async Task Where_DateOnly_Month(bool async)
491489
{
492-
await AssertQuery(
493-
async,
494-
ss => ss.Set<Mission>().Where(m => m.Date.Month == 11).AsTracking());
490+
await base.Where_DateOnly_Month(async);
495491

496492
AssertSql(
497493
"""
@@ -503,9 +499,7 @@ WHERE date_part('month', m."Date")::int = 11
503499

504500
public override async Task Where_DateOnly_Day(bool async)
505501
{
506-
await AssertQuery(
507-
async,
508-
ss => ss.Set<Mission>().Where(m => m.Date.Day == 10).AsTracking());
502+
await base.Where_DateOnly_Day(async);
509503

510504
AssertSql(
511505
"""
@@ -517,9 +511,7 @@ WHERE date_part('day', m."Date")::int = 10
517511

518512
public override async Task Where_DateOnly_DayOfYear(bool async)
519513
{
520-
await AssertQuery(
521-
async,
522-
ss => ss.Set<Mission>().Where(m => m.Date.DayOfYear == 314).AsTracking());
514+
await base.Where_DateOnly_DayOfYear(async);
523515

524516
AssertSql(
525517
"""
@@ -531,9 +523,7 @@ WHERE date_part('doy', m."Date")::int = 314
531523

532524
public override async Task Where_DateOnly_DayOfWeek(bool async)
533525
{
534-
await AssertQuery(
535-
async,
536-
ss => ss.Set<Mission>().Where(m => m.Date.DayOfWeek == DayOfWeek.Saturday).AsTracking());
526+
await base.Where_DateOnly_DayOfWeek(async);
537527

538528
AssertSql(
539529
"""
@@ -545,43 +535,100 @@ WHERE floor(date_part('dow', m."Date"))::int = 6
545535

546536
public override async Task Where_DateOnly_AddYears(bool async)
547537
{
548-
await AssertQuery(
549-
async,
550-
ss => ss.Set<Mission>().Where(m => m.Date.AddYears(3) == new DateOnly(1993, 11, 10)).AsTracking());
538+
await base.Where_DateOnly_AddYears(async);
551539

552540
AssertSql(
553541
"""
554542
SELECT m."Id", m."CodeName", m."Date", m."Duration", m."Rating", m."Time", m."Timeline"
555543
FROM "Missions" AS m
556-
WHERE m."Date" + INTERVAL '3 years' = DATE '1993-11-10'
544+
WHERE CAST(m."Date" + INTERVAL '3 years' AS date) = DATE '1993-11-10'
557545
""");
558546
}
559547

560548
public override async Task Where_DateOnly_AddMonths(bool async)
561549
{
562-
await AssertQuery(
563-
async,
564-
ss => ss.Set<Mission>().Where(m => m.Date.AddMonths(3) == new DateOnly(1991, 2, 10)).AsTracking());
550+
await base.Where_DateOnly_AddMonths(async);
565551

566552
AssertSql(
567553
"""
568554
SELECT m."Id", m."CodeName", m."Date", m."Duration", m."Rating", m."Time", m."Timeline"
569555
FROM "Missions" AS m
570-
WHERE m."Date" + INTERVAL '3 months' = DATE '1991-02-10'
556+
WHERE CAST(m."Date" + INTERVAL '3 months' AS date) = DATE '1991-02-10'
571557
""");
572558
}
573559

574560
public override async Task Where_DateOnly_AddDays(bool async)
561+
{
562+
await base.Where_DateOnly_AddDays(async);
563+
564+
AssertSql(
565+
"""
566+
SELECT m."Id", m."CodeName", m."Date", m."Duration", m."Rating", m."Time", m."Timeline"
567+
FROM "Missions" AS m
568+
WHERE m."Date" + 3 = DATE '1990-11-13'
569+
""");
570+
}
571+
572+
[ConditionalTheory]
573+
[MemberData(nameof(IsAsyncData))]
574+
public virtual async Task Select_DateOnly_AddDays(bool async)
575575
{
576576
await AssertQuery(
577577
async,
578-
ss => ss.Set<Mission>().Where(m => m.Date.AddDays(3) == new DateOnly(1990, 11, 13)).AsTracking());
578+
ss => ss.Set<Mission>()
579+
// We filter out DateOnly.MinValue which maps to -infinity
580+
.Where(m => m.Date != DateOnly.MinValue)
581+
.Select(m => m.Date.AddDays(3)));
579582

580583
AssertSql(
581584
"""
582-
SELECT m."Id", m."CodeName", m."Date", m."Duration", m."Rating", m."Time", m."Timeline"
585+
@__MinValue_0='01/01/0001' (DbType = Date)
586+
587+
SELECT m."Date" + 3
588+
FROM "Missions" AS m
589+
WHERE m."Date" <> @__MinValue_0
590+
""");
591+
}
592+
593+
[ConditionalTheory]
594+
[MemberData(nameof(IsAsyncData))]
595+
public virtual async Task Select_DateOnly_AddMonths(bool async)
596+
{
597+
await AssertQuery(
598+
async,
599+
ss => ss.Set<Mission>()
600+
// We filter out DateOnly.MinValue which maps to -infinity
601+
.Where(m => m.Date != DateOnly.MinValue)
602+
.Select(m => m.Date.AddMonths(3)));
603+
604+
AssertSql(
605+
"""
606+
@__MinValue_0='01/01/0001' (DbType = Date)
607+
608+
SELECT CAST(m."Date" + INTERVAL '3 months' AS date)
609+
FROM "Missions" AS m
610+
WHERE m."Date" <> @__MinValue_0
611+
""");
612+
}
613+
614+
[ConditionalTheory]
615+
[MemberData(nameof(IsAsyncData))]
616+
public virtual async Task Select_DateOnly_AddYears(bool async)
617+
{
618+
await AssertQuery(
619+
async,
620+
ss => ss.Set<Mission>()
621+
// We filter out DateOnly.MinValue which maps to -infinity
622+
.Where(m => m.Date != DateOnly.MinValue)
623+
.Select(m => m.Date.AddYears(3)));
624+
625+
AssertSql(
626+
"""
627+
@__MinValue_0='01/01/0001' (DbType = Date)
628+
629+
SELECT CAST(m."Date" + INTERVAL '3 years' AS date)
583630
FROM "Missions" AS m
584-
WHERE m."Date" + INTERVAL '3 days' = DATE '1990-11-13'
631+
WHERE m."Date" <> @__MinValue_0
585632
""");
586633
}
587634

0 commit comments

Comments
 (0)