Skip to content

Commit

Permalink
Implement strict set operations
Browse files Browse the repository at this point in the history
  • Loading branch information
stephenamar-db committed Nov 7, 2024
1 parent 759cea7 commit 1afcb09
Show file tree
Hide file tree
Showing 10 changed files with 122 additions and 173 deletions.
2 changes: 1 addition & 1 deletion bench/src/main/scala/sjsonnet/ProfilingEvaluator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ class ProfilingEvaluator(resolver: CachedResolver,

def builtins(): Seq[BuiltinBox] = {
val names = new util.IdentityHashMap[Val.Func, String]()
new Std().functions.foreachEntry((n, f) => names.put(f, n))
new Std(settings).functions.foreachEntry((n, f) => names.put(f, n))
val m = new mutable.HashMap[String, BuiltinBox]
def add(b: ExprBox, func: Val.Builtin): Unit = {
val n = names.getOrDefault(func, func.getClass.getName)
Expand Down
5 changes: 5 additions & 0 deletions sjsonnet/src-jvm-native/sjsonnet/Config.scala
Original file line number Diff line number Diff line change
Expand Up @@ -138,4 +138,9 @@ case class Config(
doc = """Properly handle assertions defined in a Jsonnet dictionary that is extended more than once"""
)
strictInheritedAssertions: Flag = Flag(),
@arg(
name = "strict-set-operations",
doc = """Strict set operations"""
)
strictSetOperations: Flag = Flag(),
)
27 changes: 13 additions & 14 deletions sjsonnet/src-jvm-native/sjsonnet/SjsonnetMain.scala
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,7 @@ object SjsonnetMain {
stderr: PrintStream,
wd: os.Path,
allowedInputs: Option[Set[os.Path]] = None,
importer: Option[(Path, String) => Option[os.Path]] = None,
std: Val.Obj = new Std().Std): Int = {
importer: Option[(Path, String) => Option[os.Path]] = None): Int = {

var hasWarnings = false
def warn(msg: String): Unit = {
Expand All @@ -73,7 +72,7 @@ object SjsonnetMain {
Left("error: -i/--interactive must be passed in as the first argument")
}else Right(config.file)
}
outputStr <- mainConfigured(file, config, parseCache, wd, allowedInputs, importer, warn, std)
outputStr <- mainConfigured(file, config, parseCache, wd, allowedInputs, importer, warn)
res <- {
if(hasWarnings && config.fatalWarnings.value) Left("")
else Right(outputStr)
Expand Down Expand Up @@ -175,8 +174,7 @@ object SjsonnetMain {
wd: os.Path,
allowedInputs: Option[Set[os.Path]] = None,
importer: Option[(Path, String) => Option[os.Path]] = None,
warnLogger: String => Unit = null,
std: Val.Obj = new Std().Std): Either[String, String] = {
warnLogger: String => Unit = null): Either[String, String] = {

val (jsonnetCode, path) =
if (config.exec.value) (file, wd / "\uFE64exec\uFE65")
Expand All @@ -198,6 +196,15 @@ object SjsonnetMain {
)

var currentPos: Position = null
val settings = new Settings(
preserveOrder = config.preserveOrder.value,
strict = config.strict.value,
noStaticErrors = config.noStaticErrors.value,
noDuplicateKeysInComprehension = config.noDuplicateKeysInComprehension.value,
strictImportSyntax = config.strictImportSyntax.value,
strictInheritedAssertions = config.strictInheritedAssertions.value,
strictSetOperations = config.strictSetOperations.value
)
val interp = new Interpreter(
extBinding,
tlaBinding,
Expand All @@ -213,17 +220,9 @@ object SjsonnetMain {
case None => resolveImport(config.jpaths.map(os.Path(_, wd)).map(OsPath(_)), allowedInputs)
},
parseCache,
settings = new Settings(
preserveOrder = config.preserveOrder.value,
strict = config.strict.value,
noStaticErrors = config.noStaticErrors.value,
noDuplicateKeysInComprehension = config.noDuplicateKeysInComprehension.value,
strictImportSyntax = config.strictImportSyntax.value,
strictInheritedAssertions = config.strictInheritedAssertions.value
),
settings = settings,
storePos = if (config.yamlDebug.value) currentPos = _ else null,
warnLogger = warnLogger,
std = std
)

(config.multi, config.yamlStream.value) match {
Expand Down
2 changes: 1 addition & 1 deletion sjsonnet/src/sjsonnet/Interpreter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@ class Interpreter(extVars: Map[String, String],
settings: Settings = Settings.default,
storePos: Position => Unit = null,
warnLogger: (String => Unit) = null,
std: Val.Obj = new Std().Std
) { self =>

private val internedStrings = new mutable.HashMap[String, String]

private val internedStaticFieldSets = new mutable.HashMap[Val.StaticObjectFieldSet, java.util.LinkedHashMap[String, java.lang.Boolean]]
private val std = new Std(settings).Std

val resolver = new CachedResolver(importer, parseCache, settings.strictImportSyntax, internedStrings, internedStaticFieldSets) {
override def process(expr: Expr, fs: FileScope): Either[Error, (Expr, FileScope)] =
Expand Down
1 change: 1 addition & 0 deletions sjsonnet/src/sjsonnet/Settings.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ class Settings(
val noDuplicateKeysInComprehension: Boolean = false,
val strictImportSyntax: Boolean = false,
val strictInheritedAssertions: Boolean = false,
val strictSetOperations: Boolean = false,
)

object Settings {
Expand Down
198 changes: 70 additions & 128 deletions sjsonnet/src/sjsonnet/Std.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import java.util.regex.Pattern
import sjsonnet.Expr.Member.Visibility
import sjsonnet.Expr.BinaryOp

import scala.collection.Searching.Found
import scala.collection.mutable
import scala.util.matching.Regex

Expand All @@ -16,7 +17,7 @@ import scala.util.matching.Regex
* in Scala code. Uses `builtin` and other helpers to handle the common wrapper
* logic automatically
*/
class Std {
class Std(settings: Settings) {
private val dummyPos: Position = new Position(null, 0)
private val emptyLazyArray = new Array[Lazy](0)

Expand Down Expand Up @@ -637,92 +638,6 @@ class Std {
}
}

private object SetInter extends Val.Builtin3("a", "b", "keyF", Array(null, null, Val.False(dummyPos))) {
private def isStr(a: Val.Arr) = a.forall(_.isInstanceOf[Val.Str])

override def specialize(args: Array[Expr]): (Val.Builtin, Array[Expr]) = args match {
case Array(a: Val.Arr, b) if isStr(a) => (new Spec1Str(a), Array(b))
case Array(a, b: Val.Arr) if isStr(b) => (new Spec1Str(b), Array(a))
case args if args.length == 2 => (Spec2, args)
case _ => null
}

def asArray(a: Val): Array[Lazy] = a match {
case arr: Val.Arr => arr.asLazyArray
case str: Val.Str => stringChars(pos, str.value).asLazyArray
case _ => Error.fail("Arguments must be either arrays or strings")
}

def evalRhs(_a: Val, _b: Val, _keyF: Val, ev: EvalScope, pos: Position): Val = {
if(_keyF.isInstanceOf[Val.False]) Spec2.evalRhs(_a, _b, ev, pos)
else {
val a = asArray(_a)
val b = asArray(_b)
val keyFFunc = _keyF.asInstanceOf[Val.Func]
val out = new mutable.ArrayBuffer[Lazy]
for (v <- a) {
val appliedX = keyFFunc.apply1(v, pos.noOffset)(ev)
if (b.exists(value => {
val appliedValue = keyFFunc.apply1(value, pos.noOffset)(ev)
ev.equal(appliedValue, appliedX)
}) && !out.exists(value => {
val mValue = keyFFunc.apply1(value, pos.noOffset)(ev)
ev.equal(mValue, appliedX)
})) {
out.append(v)
}
}
sortArr(pos, ev, new Val.Arr(pos, out.toArray), keyFFunc)
}
}

private object Spec2 extends Val.Builtin2("a", "b") {
def evalRhs(_a: Val, _b: Val, ev: EvalScope, pos: Position): Val = {
val a = asArray(_a)
val b = asArray(_b)
val out = new mutable.ArrayBuffer[Lazy](a.length)
for (v <- a) {
val vf = v.force
if (b.exists(value => {
ev.equal(value.force, vf)
}) && !out.exists(value => {
ev.equal(value.force, vf)
})) {
out.append(v)
}
}
sortArr(pos, ev, new Val.Arr(pos, out.toArray), null)
}
}

private class Spec1Str(_a: Val.Arr) extends Val.Builtin1("b") {
private[this] val a =
ArrayOps.sortInPlaceBy(ArrayOps.distinctBy(_a.asLazyArray)(_.asInstanceOf[Val.Str].value))(_.asInstanceOf[Val.Str].value)
// 2.13+: _a.asLazyArray.distinctBy(_.asInstanceOf[Val.Str].value).sortInPlaceBy(_.asInstanceOf[Val.Str].value)

def evalRhs(_b: Val, ev: EvalScope, pos: Position): Val = {
val b = asArray(_b)
val bs = new mutable.HashSet[String]
var i = 0
while(i < b.length) {
b(i).force match {
case s: Val.Str => bs.add(s.value)
case _ =>
}
i += 1
}
val out = new mutable.ArrayBuilder.ofRef[Lazy]
i = 0
while(i < a.length) {
val s = a(i).asInstanceOf[Val.Str]
if(bs.contains(s.value)) out.+=(s)
i += 1
}
new Val.Arr(pos, out.result())
}
}
}

val functions: Map[String, Val.Func] = Map(
"assertEqual" -> AssertEqual,
"toString" -> ToString,
Expand Down Expand Up @@ -1141,68 +1056,64 @@ class Std {
val concat = new Val.Arr(pos, a ++ b)
uniqArr(pos, ev, sortArr(pos, ev, concat, args(2)), args(2))
},
"setInter" -> SetInter,
builtinWithDefaults("setDiff", "a" -> null, "b" -> null, "keyF" -> Val.False(dummyPos)) { (args, pos, ev) =>

builtinWithDefaults("setInter", "a" -> null, "b" -> null, "keyF" -> Val.False(dummyPos)) { (args, pos, ev) =>
val a = args(0) match {
case arr: Val.Arr => arr.asLazyArray
case str: Val.Str => stringChars(pos, str.value).asLazyArray
case str: Val.Str if !settings.strictSetOperations => stringChars(pos, str.value).asLazyArray
case _ => Error.fail("Arguments must be either arrays or strings")
}
val b = args(1) match {
case arr: Val.Arr => arr.asLazyArray
case str: Val.Str => stringChars(pos, str.value).asLazyArray
case str: Val.Str if !settings.strictSetOperations => stringChars(pos, str.value).asLazyArray
case _ => Error.fail("Arguments must be either arrays or strings")
}

val keyF = args(2)
val out = new mutable.ArrayBuffer[Lazy]

for (v <- a) {
if (keyF.isInstanceOf[Val.False]) {
val vf = v.force
if (!b.exists(value => {
ev.equal(value.force, vf)
}) && !out.exists(value => {
ev.equal(value.force, vf)
})) {
out.append(v)
}
} else {
val keyFFunc = keyF.asInstanceOf[Val.Func]
val appliedX = keyFFunc.apply1(v, pos.noOffset)(ev)

if (!b.exists(value => {
val appliedValue = keyFFunc.apply1(value, pos.noOffset)(ev)
ev.equal(appliedValue, appliedX)
}) && !out.exists(value => {
val mValue = keyFFunc.apply1(value, pos.noOffset)(ev)
ev.equal(mValue, appliedX)
})) {
out.append(v)
}
if (existsInSet(ev, pos, keyF, b, v.force) && !existsInSet(ev, pos, keyF, out.toArray, v.force)) {
out.append(v)
}
}

sortArr(pos, ev, new Val.Arr(pos, out.toArray), keyF)
if (settings.strictSetOperations) {
new Val.Arr(pos, out.toArray)
} else {
sortArr(pos, ev, new Val.Arr(pos, out.toArray), keyF)
}
},
builtinWithDefaults("setMember", "x" -> null, "arr" -> null, "keyF" -> Val.False(dummyPos)) { (args, pos, ev) =>
builtinWithDefaults("setDiff", "a" -> null, "b" -> null, "keyF" -> Val.False(dummyPos)) { (args, pos, ev) =>
val a = args(0) match {
case arr: Val.Arr => arr.asLazyArray
case str: Val.Str if !settings.strictSetOperations => stringChars(pos, str.value).asLazyArray
case _ => Error.fail("Arguments must be either arrays or strings")
}
val b = args(1) match {
case arr: Val.Arr => arr.asLazyArray
case str: Val.Str if !settings.strictSetOperations => stringChars(pos, str.value).asLazyArray
case _ => Error.fail("Arguments must be either arrays or strings")
}

val keyF = args(2)
val out = new mutable.ArrayBuffer[Lazy]

for (v <- a) {
if (!existsInSet(ev, pos, keyF, b, v.force) && !existsInSet(ev, pos, keyF, out.toArray, v.force)) {
out.append(v)
}
}

if (keyF.isInstanceOf[Val.False]) {
val ujson.Arr(mArr) = Materializer(args(1))(ev)
val mx = Materializer(args(0))(ev)
mArr.contains(mx)
if (settings.strictSetOperations) {
new Val.Arr(pos, out.toArray)
} else {
val arr = args(1).asInstanceOf[Val.Arr].asLazyArray
val keyFFunc = keyF.asInstanceOf[Val.Func]
val appliedX = keyFFunc.apply1(args(0), pos.noOffset)(ev)
arr.exists(value => {
val appliedValue = keyFFunc.apply1(value, pos.noOffset)(ev)
ev.equal(appliedValue, appliedX)
})
sortArr(pos, ev, new Val.Arr(pos, out.toArray), keyF)
}
},
builtinWithDefaults("setMember", "x" -> null, "arr" -> null, "keyF" -> Val.False(dummyPos)) { (args, pos, ev) =>
val keyF = args(2)
val arr = args(1).asInstanceOf[Val.Arr].asLazyArray
existsInSet(ev, pos, keyF, arr, args(0))
},

"split" -> Split,
"splitLimit" -> SplitLimit,
Expand Down Expand Up @@ -1281,6 +1192,37 @@ class Std {
Platform.sha3(str)
},
)

private def existsInSet(ev: EvalScope, pos: Position, keyF: Val, arr: Array[Lazy], toFind: Val): Boolean = {
val appliedX = keyF match {
case keyFFunc: Val.Func => keyFFunc.apply1(toFind, pos.noOffset)(ev)
case _ => toFind
}
System.out.println(appliedX.force)
if (settings.strictSetOperations) {
arr.search(appliedX.force)((toFind: Lazy, value: Lazy) => {
val appliedValue = keyF match {
case keyFFunc: Val.Func => keyFFunc.apply1(value, pos.noOffset)(ev).force
case _ => value.force
}
toFind.force match {
case s: Val.Str if appliedValue.isInstanceOf[Val.Str] => Ordering.String.compare(s.asString, appliedValue.force.asString)
case n: Val.Num if appliedValue.isInstanceOf[Val.Num] => Ordering.Double.TotalOrdering.compare(n.asDouble, appliedValue.force.asDouble)
case t: Val.Bool if appliedValue.isInstanceOf[Val.Bool] => Ordering.Boolean.compare(t.asBoolean, appliedValue.force.asBoolean)
case _ => Error.fail("Cannot call setMember on " + toFind.force.prettyName + " and " + appliedValue.force.prettyName)
}
}).isInstanceOf[Found]
} else {
arr.exists(value => {
val appliedValue = keyF match {
case keyFFunc: Val.Func => keyFFunc.apply1(value, pos.noOffset)(ev)
case _ => value
}
ev.equal(appliedValue.force, appliedX.force)
})
}
}

val Std: Val.Obj = Val.Obj.mk(
null,
functions.toSeq
Expand Down
3 changes: 1 addition & 2 deletions sjsonnet/test/src-jvm/sjsonnet/Example.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@ public void example(){
System.err,
os.package$.MODULE$.pwd(),
scala.None$.empty(),
scala.None$.empty(),
new sjsonnet.Std().Std()
scala.None$.empty()
);
}
}
4 changes: 1 addition & 3 deletions sjsonnet/test/src/sjsonnet/PreserveOrderTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -309,9 +309,7 @@ object PreserveOrderTests extends TestSuite {
}

test("preserveOrderPreservesSetMembership") {
eval("""std.setMember({a: 1, b: 2}, [{b: 2, a: 1}])""", true).toString ==> "true"

eval("""std.setMember({q: {a: 1, b: 2}}, [{q: {b: 2, a: 1}}], keyF=function(v) v.q)""", true).toString ==> "true"
eval("""std.setMember({q: {a: 1, b: 2}}, [{q: {b: 2, a: 1}}], keyF=function(v) v.q.a)""", true).toString ==> "true"
}

test("preserveOrderSetIntersection") {
Expand Down
Loading

0 comments on commit 1afcb09

Please sign in to comment.