Skip to content

Commit

Permalink
.Net: Add JsonElement String to Primitive Implicit Conversion Support…
Browse files Browse the repository at this point in the history
… (SLM Function Calling) (#9784)

### Motivation and Context

In the original logic the conversion was giving priority for existence
of converters when the parameter value was a
JsonElement/JsonDocument/JsonNode, this change checks if the argument is
one of those types first and use the proper JSON conversion.

This change also bring some resiliency when the `JsonElement` provided
is a `string` for primitive types like `boolean` and C# numeric types.

This change improves function calling experience when using local models
that send JSON string argument values ("1" or "true") instead of the
expected JSON type (1, true) for calling functions.
i.e: `Llama 3.1, Llama 3.2`

Added Unit Tests covering the added JsonElement arguments support.

- Fixes #9711 
- Extra (Remove of Warning for ONNX connectors SYSLIB1222)

---------

Co-authored-by: Dmytro Struk <[email protected]>
  • Loading branch information
RogerBarreto and dmytrostruk authored Nov 22, 2024
1 parent 4113a10 commit ef8251c
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 19 deletions.
14 changes: 8 additions & 6 deletions dotnet/samples/Demos/OllamaFunctionCalling/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@
var chatCompletionService = kernel.GetRequiredService<IChatCompletionService>();
var settings = new OllamaPromptExecutionSettings { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() };

Console.WriteLine("Ask questions or give instructions to the copilot such as:\n" +
"- Change the alarm to 8\n" +
"- What is the current alarm set?\n" +
"- Is the light on?\n" +
"- Turn the light off please.\n" +
"- Set an alarm for 6:00 am.\n");
Console.WriteLine("""
Ask questions or give instructions to the copilot such as:
- Change the alarm to 8
- What is the current alarm set?
- Is the light on?
- Turn the light off please.
- Set an alarm for 6:00 am.
""");

Console.Write("> ");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,22 @@ namespace Microsoft.SemanticKernel;
[DebuggerDisplay("{DebuggerDisplay,nq}")]
internal sealed partial class KernelFunctionFromMethod : KernelFunction
{
private static readonly Dictionary<Type, Func<string, object>> s_jsonStringParsers = new(12)
{
{ typeof(bool), s => bool.Parse(s) },
{ typeof(int), s => int.Parse(s) },
{ typeof(uint), s => uint.Parse(s) },
{ typeof(long), s => long.Parse(s) },
{ typeof(ulong), s => ulong.Parse(s) },
{ typeof(float), s => float.Parse(s) },
{ typeof(double), s => double.Parse(s) },
{ typeof(decimal), s => decimal.Parse(s) },
{ typeof(short), s => short.Parse(s) },
{ typeof(ushort), s => ushort.Parse(s) },
{ typeof(byte), s => byte.Parse(s) },
{ typeof(sbyte), s => sbyte.Parse(s) }
};

/// <summary>
/// Creates a <see cref="KernelFunction"/> instance for a method, specified via an <see cref="MethodInfo"/> instance
/// and an optional target object if the method is an instance method.
Expand Down Expand Up @@ -710,26 +726,34 @@ private static (Func<KernelFunction, Kernel, KernelArguments, CancellationToken,

object? Process(object? value)
{
if (!type.IsAssignableFrom(value?.GetType()))
if (type.IsAssignableFrom(value?.GetType()))
{
return value;
}

if (converter is not null && value is not JsonElement or JsonDocument or JsonNode)
{
if (converter is not null)
try
{
try
{
return converter(value, kernel.Culture);
}
catch (Exception e) when (!e.IsCriticalException())
{
throw new ArgumentOutOfRangeException(name, value, e.Message);
}
return converter(value, kernel.Culture);
}

if (value is not null && TryToDeserializeValue(value, type, jsonSerializerOptions, out var deserializedValue))
catch (Exception e) when (!e.IsCriticalException())
{
return deserializedValue;
throw new ArgumentOutOfRangeException(name, value, e.Message);
}
}

if (value is JsonElement element && element.ValueKind == JsonValueKind.String
&& s_jsonStringParsers.TryGetValue(type, out var jsonStringParser))
{
return jsonStringParser(element.GetString()!);
}

if (value is not null && TryToDeserializeValue(value, type, jsonSerializerOptions, out var deserializedValue))
{
return deserializedValue;
}

return value;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -883,6 +883,66 @@ public async Task ItSupportsArgumentsImplicitConversionAsync()
await function.InvokeAsync(this._kernel, arguments);
}

[Fact]
public async Task ItSupportsJsonElementArgumentsImplicitConversionAsync()
{
//Arrange
var arguments = new KernelArguments()
{
["l"] = JsonSerializer.Deserialize<JsonElement>((long)1), //Passed to parameter of type long
["i"] = JsonSerializer.Deserialize<JsonElement>((byte)1), //Passed to parameter of type int
["d"] = JsonSerializer.Deserialize<JsonElement>((float)1.0), //Passed to parameter of type double
["f"] = JsonSerializer.Deserialize<JsonElement>((uint)1.0), //Passed to parameter of type float
["g"] = JsonSerializer.Deserialize<JsonElement>(JsonSerializer.Serialize(new Guid("35626209-b0ab-458c-bfc4-43e6c7bd13dc"))), //Passed to parameter of type string
["dof"] = JsonSerializer.Deserialize<JsonElement>(JsonSerializer.Serialize(DayOfWeek.Thursday)), //Passed to parameter of type int
["b"] = JsonSerializer.Deserialize<JsonElement>(JsonSerializer.Serialize("true")), //Passed to parameter of type bool
};

var function = KernelFunctionFactory.CreateFromMethod((long l, int i, double d, float f, string g, int dof, bool b) =>
{
Assert.Equal(1, l);
Assert.Equal(1, i);
Assert.Equal(1.0, d);
Assert.Equal("35626209-b0ab-458c-bfc4-43e6c7bd13dc", g);
Assert.Equal(4, dof);
Assert.True(b);
},
functionName: "Test");

await function.InvokeAsync(this._kernel, arguments);
await function.AsAIFunction().InvokeAsync(arguments);
}

[Fact]
public async Task ItSupportsStringJsonElementArgumentsImplicitConversionAsync()
{
//Arrange
var arguments = new KernelArguments()
{
["l"] = JsonSerializer.Deserialize<JsonElement>(JsonSerializer.Serialize("1")), //Passed to parameter of type long
["i"] = JsonSerializer.Deserialize<JsonElement>(JsonSerializer.Serialize("1")), //Passed to parameter of type int
["d"] = JsonSerializer.Deserialize<JsonElement>(JsonSerializer.Serialize("1.0")), //Passed to parameter of type double
["f"] = JsonSerializer.Deserialize<JsonElement>(JsonSerializer.Serialize("1.0")), //Passed to parameter of type float
["g"] = JsonSerializer.Deserialize<JsonElement>(JsonSerializer.Serialize("35626209-b0ab-458c-bfc4-43e6c7bd13dc")), //Passed to parameter of type Guid
["dof"] = JsonSerializer.Deserialize<JsonElement>(JsonSerializer.Serialize("4")), //Passed to parameter of type int
["b"] = JsonSerializer.Deserialize<JsonElement>(JsonSerializer.Serialize("false")), //Passed to parameter of type bool
};

var function = KernelFunctionFactory.CreateFromMethod((long l, int i, double d, float f, Guid g, int dof, bool b) =>
{
Assert.Equal(1, l);
Assert.Equal(1, i);
Assert.Equal(1.0, d);
Assert.Equal(new Guid("35626209-b0ab-458c-bfc4-43e6c7bd13dc"), g);
Assert.Equal(4, dof);
Assert.False(b);
},
functionName: "Test");

await function.InvokeAsync(this._kernel, arguments);
await function.AsAIFunction().InvokeAsync(arguments);
}

[Fact]
public async Task ItSupportsParametersWithDefaultValuesAsync()
{
Expand Down

0 comments on commit ef8251c

Please sign in to comment.