diff --git a/Project.toml b/Project.toml index d180e6de..a3683c53 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MCMCDiagnosticTools" uuid = "be115224-59cd-429b-ad48-344e309966f0" authors = ["David Widmann"] -version = "0.1.2" +version = "0.1.3" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" diff --git a/docs/Manifest.toml b/docs/Manifest.toml deleted file mode 100644 index 6cb28294..00000000 --- a/docs/Manifest.toml +++ /dev/null @@ -1,483 +0,0 @@ -# This file is machine-generated - editing it directly is not advised - -[[AbstractFFTs]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "485ee0867925449198280d4af84bdb46a2a404d0" -uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" -version = "1.0.1" - -[[ArgTools]] -uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" - -[[Artifacts]] -uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" - -[[Base64]] -uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" - -[[CategoricalArrays]] -deps = ["DataAPI", "Future", "JSON", "Missings", "Printf", "RecipesBase", "Statistics", "StructTypes", "Unicode"] -git-tree-sha1 = "1562002780515d2573a4fb0c3715e4e57481075e" -uuid = "324d7699-5711-5eae-9e2f-1d82baa6b597" -version = "0.10.0" - -[[ChainRulesCore]] -deps = ["Compat", "LinearAlgebra", "SparseArrays"] -git-tree-sha1 = "be770c08881f7bb928dfd86d1ba83798f76cf62a" -uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.10.9" - -[[ColorTypes]] -deps = ["FixedPointNumbers", "Random"] -git-tree-sha1 = "024fe24d83e4a5bf5fc80501a314ce0d1aa35597" -uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" -version = "0.11.0" - -[[Compat]] -deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] -git-tree-sha1 = "dc7dedc2c2aa9faf59a55c622760a25cbefbe941" -uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "3.31.0" - -[[CompilerSupportLibraries_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" - -[[ComputationalResources]] -git-tree-sha1 = "52cb3ec90e8a8bea0e62e275ba577ad0f74821f7" -uuid = "ed09eef8-17a6-5b46-8889-db040fac31e3" -version = "0.3.2" - -[[Crayons]] -git-tree-sha1 = "3f71217b538d7aaee0b69ab47d9b7724ca8afa0d" -uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" -version = "4.0.4" - -[[DataAPI]] -git-tree-sha1 = "ee400abb2298bd13bfc3df1c412ed228061a2385" -uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" -version = "1.7.0" - -[[DataStructures]] -deps = ["Compat", "InteractiveUtils", "OrderedCollections"] -git-tree-sha1 = "4437b64df1e0adccc3e5d1adbc3ac741095e4677" -uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.18.9" - -[[DataValueInterfaces]] -git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" -uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464" -version = "1.0.0" - -[[Dates]] -deps = ["Printf"] -uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" - -[[DelimitedFiles]] -deps = ["Mmap"] -uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" - -[[Distributed]] -deps = ["Random", "Serialization", "Sockets"] -uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" - -[[Distributions]] -deps = ["FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns"] -git-tree-sha1 = "2733323e5c02a9d7f48e7a3c4bc98d764fb704da" -uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" -version = "0.25.6" - -[[DocStringExtensions]] -deps = ["LibGit2"] -git-tree-sha1 = "a32185f5428d3986f47c2ab78b1f216d5e6cc96f" -uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" -version = "0.8.5" - -[[Documenter]] -deps = ["Base64", "Dates", "DocStringExtensions", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"] -git-tree-sha1 = "47f13b6305ab195edb73c86815962d84e31b0f48" -uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -version = "0.27.3" - -[[Downloads]] -deps = ["ArgTools", "LibCURL", "NetworkOptions"] -uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" - -[[FillArrays]] -deps = ["LinearAlgebra", "Random", "SparseArrays"] -git-tree-sha1 = "31939159aeb8ffad1d4d8ee44d07f8558273120a" -uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "0.11.7" - -[[FixedPointNumbers]] -deps = ["Statistics"] -git-tree-sha1 = "335bfdceacc84c5cdf16aadc768aa5ddfc5383cc" -uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" -version = "0.8.4" - -[[Formatting]] -deps = ["Printf"] -git-tree-sha1 = "8339d61043228fdd3eb658d86c926cb282ae72a8" -uuid = "59287772-0a20-5a39-b81b-1366585eb4c0" -version = "0.4.2" - -[[Future]] -deps = ["Random"] -uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" - -[[IOCapture]] -deps = ["Logging", "Random"] -git-tree-sha1 = "f7be53659ab06ddc986428d3a9dcc95f6fa6705a" -uuid = "b5f81e59-6552-4d32-b1f0-c071b021bf89" -version = "0.2.2" - -[[MCMCDiagnosticTools]] -deps = ["AbstractFFTs", "DataAPI", "Distributions", "LinearAlgebra", "MLJModelInterface", "Random", "SpecialFunctions", "Statistics", "StatsBase", "Tables"] -path = ".." -uuid = "be115224-59cd-429b-ad48-344e309966f0" -version = "0.1.0" - -[[InteractiveUtils]] -deps = ["Markdown"] -uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" - -[[InvertedIndices]] -deps = ["Test"] -git-tree-sha1 = "15732c475062348b0165684ffe28e85ea8396afc" -uuid = "41ab1584-1d38-5bbf-9106-f11c6c58b48f" -version = "1.0.0" - -[[IteratorInterfaceExtensions]] -git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" -uuid = "82899510-4779-5014-852e-03e436cf321d" -version = "1.0.0" - -[[JLLWrappers]] -deps = ["Preferences"] -git-tree-sha1 = "642a199af8b68253517b80bd3bfd17eb4e84df6e" -uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" -version = "1.3.0" - -[[JSON]] -deps = ["Dates", "Mmap", "Parsers", "Unicode"] -git-tree-sha1 = "81690084b6198a2e1da36fcfda16eeca9f9f24e4" -uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" -version = "0.21.1" - -[[LearnBase]] -git-tree-sha1 = "a0d90569edd490b82fdc4dc078ea54a5a800d30a" -uuid = "7f8f8fb0-2700-5f03-b4bd-41f8cfc144b6" -version = "0.4.1" - -[[LibCURL]] -deps = ["LibCURL_jll", "MozillaCACerts_jll"] -uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" - -[[LibCURL_jll]] -deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] -uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" - -[[LibGit2]] -deps = ["Base64", "NetworkOptions", "Printf", "SHA"] -uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" - -[[LibSSH2_jll]] -deps = ["Artifacts", "Libdl", "MbedTLS_jll"] -uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" - -[[Libdl]] -uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" - -[[LinearAlgebra]] -deps = ["Libdl"] -uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" - -[[LogExpFunctions]] -deps = ["DocStringExtensions", "LinearAlgebra"] -git-tree-sha1 = "1ba664552f1ef15325e68dc4c05c3ef8c2d5d885" -uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" -version = "0.2.4" - -[[Logging]] -uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" - -[[LossFunctions]] -deps = ["InteractiveUtils", "LearnBase", "Markdown", "RecipesBase", "StatsBase"] -git-tree-sha1 = "0f057f6ea90a84e73a8ef6eebb4dc7b5c330020f" -uuid = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7" -version = "0.7.2" - -[[MLJBase]] -deps = ["CategoricalArrays", "ComputationalResources", "Dates", "DelimitedFiles", "Distributed", "Distributions", "InteractiveUtils", "InvertedIndices", "LinearAlgebra", "LossFunctions", "MLJModelInterface", "Missings", "OrderedCollections", "Parameters", "PrettyTables", "ProgressMeter", "Random", "ScientificTypes", "StatisticalTraits", "Statistics", "StatsBase", "Tables"] -git-tree-sha1 = "e1996657b66ba5c3a1bdbf73835640460958712d" -uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d" -version = "0.18.13" - -[[MLJModelInterface]] -deps = ["Random", "ScientificTypesBase", "StatisticalTraits"] -git-tree-sha1 = "55c785a68d71c5fd7b64b490e0d9ab18cf13a04c" -uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" -version = "1.1.1" - -[[MLJXGBoostInterface]] -deps = ["MLJModelInterface", "Tables", "XGBoost"] -git-tree-sha1 = "3528d3ac6f5fa07885dc95fd2e890c34c0ac7725" -uuid = "54119dfa-1dab-4055-a167-80440f4f7a91" -version = "0.1.3" - -[[Markdown]] -deps = ["Base64"] -uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" - -[[MbedTLS_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" - -[[Missings]] -deps = ["DataAPI"] -git-tree-sha1 = "4ea90bd5d3985ae1f9a908bd4500ae88921c5ce7" -uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" -version = "1.0.0" - -[[Mmap]] -uuid = "a63ad114-7e13-5084-954f-fe012c677804" - -[[MozillaCACerts_jll]] -uuid = "14a3606d-f60d-562e-9121-12d972cd8159" - -[[NetworkOptions]] -uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" - -[[OpenSpecFun_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" -uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" -version = "0.5.5+0" - -[[OrderedCollections]] -git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c" -uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" -version = "1.4.1" - -[[PDMats]] -deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] -git-tree-sha1 = "4dd403333bcf0909341cfe57ec115152f937d7d8" -uuid = "90014a1f-27ba-587c-ab20-58faa44d9150" -version = "0.11.1" - -[[Parameters]] -deps = ["OrderedCollections", "UnPack"] -git-tree-sha1 = "2276ac65f1e236e0a6ea70baff3f62ad4c625345" -uuid = "d96e819e-fc66-5662-9728-84c9c7592b0a" -version = "0.12.2" - -[[Parsers]] -deps = ["Dates"] -git-tree-sha1 = "c8abc88faa3f7a3950832ac5d6e690881590d6dc" -uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" -version = "1.1.0" - -[[PersistenceDiagramsBase]] -deps = ["Compat", "Tables"] -git-tree-sha1 = "ec6eecbfae1c740621b5d903a69ec10e30f3f4bc" -uuid = "b1ad91c1-539c-4ace-90bd-ea06abc420fa" -version = "0.1.1" - -[[Pkg]] -deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] -uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" - -[[Preferences]] -deps = ["TOML"] -git-tree-sha1 = "00cfd92944ca9c760982747e9a1d0d5d86ab1e5a" -uuid = "21216c6a-2e73-6563-6e65-726566657250" -version = "1.2.2" - -[[PrettyTables]] -deps = ["Crayons", "Formatting", "Markdown", "Reexport", "Tables"] -git-tree-sha1 = "0d1245a357cc61c8cd61934c07447aa569ff22e6" -uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" -version = "1.1.0" - -[[Printf]] -deps = ["Unicode"] -uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" - -[[ProgressMeter]] -deps = ["Distributed", "Printf"] -git-tree-sha1 = "afadeba63d90ff223a6a48d2009434ecee2ec9e8" -uuid = "92933f4c-e287-5a05-a399-4b506db050ca" -version = "1.7.1" - -[[QuadGK]] -deps = ["DataStructures", "LinearAlgebra"] -git-tree-sha1 = "12fbe86da16df6679be7521dfb39fbc861e1dc7b" -uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" -version = "2.4.1" - -[[REPL]] -deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] -uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" - -[[Random]] -deps = ["Serialization"] -uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" - -[[RecipesBase]] -git-tree-sha1 = "b3fb709f3c97bfc6e948be68beeecb55a0b340ae" -uuid = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" -version = "1.1.1" - -[[Reexport]] -git-tree-sha1 = "5f6c21241f0f655da3952fd60aa18477cf96c220" -uuid = "189a3867-3050-52da-a836-e630ba90ab69" -version = "1.1.0" - -[[Rmath]] -deps = ["Random", "Rmath_jll"] -git-tree-sha1 = "bf3188feca147ce108c76ad82c2792c57abe7b1f" -uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa" -version = "0.7.0" - -[[Rmath_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "68db32dff12bb6127bac73c209881191bf0efbb7" -uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f" -version = "0.3.0+0" - -[[SHA]] -uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" - -[[ScientificTypes]] -deps = ["CategoricalArrays", "ColorTypes", "Dates", "PersistenceDiagramsBase", "PrettyTables", "ScientificTypesBase", "StatisticalTraits", "Tables"] -git-tree-sha1 = "345e33061ad7c49c6e860e42a04c62ecbea3eabf" -uuid = "321657f4-b219-11e9-178b-2701a2544e81" -version = "2.0.0" - -[[ScientificTypesBase]] -git-tree-sha1 = "3f7ddb0cf0c3a4cff06d9df6f01135fa5442c99b" -uuid = "30f210dd-8aff-4c5f-94ba-8e64358c1161" -version = "1.0.0" - -[[Serialization]] -uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" - -[[SharedArrays]] -deps = ["Distributed", "Mmap", "Random", "Serialization"] -uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" - -[[Sockets]] -uuid = "6462fe0b-24de-5631-8697-dd941f90decc" - -[[SortingAlgorithms]] -deps = ["DataStructures"] -git-tree-sha1 = "2ec1962eba973f383239da22e75218565c390a96" -uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" -version = "1.0.0" - -[[SparseArrays]] -deps = ["LinearAlgebra", "Random"] -uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" - -[[SpecialFunctions]] -deps = ["ChainRulesCore", "LogExpFunctions", "OpenSpecFun_jll"] -git-tree-sha1 = "a50550fa3164a8c46747e62063b4d774ac1bcf49" -uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "1.5.1" - -[[StatisticalTraits]] -deps = ["ScientificTypesBase"] -git-tree-sha1 = "5114841829816649ecc957f07f6a621671e4a951" -uuid = "64bff920-2084-43da-a3e6-9bb72801c0c9" -version = "2.0.0" - -[[Statistics]] -deps = ["LinearAlgebra", "SparseArrays"] -uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" - -[[StatsAPI]] -git-tree-sha1 = "1958272568dc176a1d881acb797beb909c785510" -uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" -version = "1.0.0" - -[[StatsBase]] -deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] -git-tree-sha1 = "2f6792d523d7448bbe2fec99eca9218f06cc746d" -uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -version = "0.33.8" - -[[StatsFuns]] -deps = ["LogExpFunctions", "Rmath", "SpecialFunctions"] -git-tree-sha1 = "30cd8c360c54081f806b1ee14d2eecbef3c04c49" -uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c" -version = "0.9.8" - -[[StructTypes]] -deps = ["Dates", "UUIDs"] -git-tree-sha1 = "e36adc471280e8b346ea24c5c87ba0571204be7a" -uuid = "856f2bd8-1eba-4b0a-8007-ebc267875bd4" -version = "1.7.2" - -[[SuiteSparse]] -deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] -uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" - -[[TOML]] -deps = ["Dates"] -uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" - -[[TableTraits]] -deps = ["IteratorInterfaceExtensions"] -git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39" -uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" -version = "1.0.1" - -[[Tables]] -deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "TableTraits", "Test"] -git-tree-sha1 = "8ed4a3ea724dac32670b062be3ef1c1de6773ae8" -uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" -version = "1.4.4" - -[[Tar]] -deps = ["ArgTools", "SHA"] -uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" - -[[Test]] -deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] -uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[[UUIDs]] -deps = ["Random", "SHA"] -uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" - -[[UnPack]] -git-tree-sha1 = "387c1f73762231e86e0c9c5443ce3b4a0a9a0c2b" -uuid = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" -version = "1.0.2" - -[[Unicode]] -uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" - -[[XGBoost]] -deps = ["Libdl", "Printf", "Random", "SparseArrays", "Statistics", "Test", "XGBoost_jll"] -git-tree-sha1 = "8a692f817f1a6c15ef4913a0ffefa6163117f43d" -uuid = "009559a3-9522-5dbb-924b-0b6ed2b22bb9" -version = "1.1.1" - -[[XGBoost_jll]] -deps = ["CompilerSupportLibraries_jll", "Libdl", "Pkg"] -git-tree-sha1 = "901ff35c4f56d4eba54e3533ee916bb72dcfb813" -uuid = "a5c6f535-4255-5ca2-a466-0e519f119c46" -version = "1.2.0+0" - -[[Zlib_jll]] -deps = ["Libdl"] -uuid = "83775a58-1f1d-513f-b197-d71354ab007a" - -[[nghttp2_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" - -[[p7zip_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" diff --git a/docs/Project.toml b/docs/Project.toml index 75bfe645..3e45573a 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -9,6 +9,6 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] Documenter = "0.27" MCMCDiagnosticTools = "0.1" -MLJBase = "0.18" +MLJBase = "0.19" MLJXGBoostInterface = "0.1" julia = "1.3" diff --git a/src/rstar.jl b/src/rstar.jl index f320cacd..31535a52 100644 --- a/src/rstar.jl +++ b/src/rstar.jl @@ -52,11 +52,11 @@ Deterministic classifiers can also be derived from probabilistic classifiers by predicting the mode. In MLJ this corresponds to a pipeline of models. ```jldoctest rstar -julia> @pipeline XGBoostClassifier name = XGBoostDeterministic operation = predict_mode; +julia> xgboost_deterministic = Pipeline(XGBoostClassifier(); operation=predict_mode); -julia> value = rstar(XGBoostDeterministic(), samples, chain_indices); +julia> value = rstar(xgboost_deterministic, samples, chain_indices); -julia> isapprox(value, 1; atol=0.1) +julia> isapprox(value, 1; atol=0.2) true ``` @@ -112,7 +112,7 @@ function rstar( end # R⋆ for deterministic predictions (algorithm 1) -function _rstar(predictions::AbstractVector, ytest::AbstractVector) +function _rstar(predictions::AbstractVector{T}, ytest::AbstractVector{T}) where {T} length(predictions) == length(ytest) || error("numbers of predictions and targets must be equal") mean_accuracy = Statistics.mean(p == y for (p, y) in zip(predictions, ytest)) @@ -121,11 +121,11 @@ function _rstar(predictions::AbstractVector, ytest::AbstractVector) end # R⋆ for probabilistic predictions (algorithm 2) -function _rstar( - predictions::AbstractVector{<:Distributions.UnivariateDistribution}, - ytest::AbstractVector, -) - # create Poisson binomila distribution with support `0:length(predictions)` +function _rstar(predictions::AbstractVector, ytest::AbstractVector) + length(predictions) == length(ytest) || + error("numbers of predictions and targets must be equal") + + # create Poisson binomial distribution with support `0:length(predictions)` distribution = Distributions.PoissonBinomial(map(Distributions.pdf, predictions, ytest)) # scale distribution to support in `[0, nclasses]` diff --git a/test/rstar/Project.toml b/test/rstar/Project.toml index 4eacf489..059f7f3e 100644 --- a/test/rstar/Project.toml +++ b/test/rstar/Project.toml @@ -9,7 +9,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] Distributions = "0.25" MCMCDiagnosticTools = "0.1" -MLJBase = "0.18" +MLJBase = "0.19" MLJLIBSVMInterface = "0.1" MLJXGBoostInterface = "0.1" julia = "1.3" diff --git a/test/rstar/runtests.jl b/test/rstar/runtests.jl index 41de3fb3..00869743 100644 --- a/test/rstar/runtests.jl +++ b/test/rstar/runtests.jl @@ -7,10 +7,10 @@ using MLJXGBoostInterface using Test -@pipeline XGBoostClassifier name = XGBoostDeterministic operation = predict_mode +const xgboost_deterministic = Pipeline(XGBoostClassifier(); operation=predict_mode) @testset "rstar.jl" begin - classifiers = (XGBoostClassifier(), XGBoostDeterministic(), SVC()) + classifiers = (XGBoostClassifier(), xgboost_deterministic, SVC()) N = 1_000 @testset "examples (classifier = $classifier)" for classifier in classifiers