From 19a5e6a3832800d3042033f53260b4993c5af8e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=A9o=20de=20Souza?= Date: Wed, 11 Oct 2023 15:24:18 +0200 Subject: [PATCH 1/2] Fix MultilineChartContent not defined --- src/Loggers/LogCustomScalar.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/Loggers/LogCustomScalar.jl b/src/Loggers/LogCustomScalar.jl index 4103a35..c2b42f5 100644 --- a/src/Loggers/LogCustomScalar.jl +++ b/src/Loggers/LogCustomScalar.jl @@ -1,3 +1,6 @@ +using .tensorboard_plugin_custom_scalar +using .tensorboard_plugin_custom_scalar: var"MarginChartContent.Series" as MarginChartContent_Series + # possible chart types @enum tb_chart_type tb_multiline=1 tb_margin=2 From fdafa8f11f7187b6eb922864c203486df895602c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=A9o=20de=20Souza?= Date: Wed, 11 Oct 2023 22:10:45 +0200 Subject: [PATCH 2/2] Replace implicit kwargs with explicit defaults --- src/Loggers/LogCustomScalar.jl | 49 +++++++++++++++++++++------------- 1 file changed, 30 insertions(+), 19 deletions(-) diff --git a/src/Loggers/LogCustomScalar.jl b/src/Loggers/LogCustomScalar.jl index c2b42f5..6b55252 100644 --- a/src/Loggers/LogCustomScalar.jl +++ b/src/Loggers/LogCustomScalar.jl @@ -28,14 +28,12 @@ function chart(name::String, metadata::Tuple{tb_chart_type, AbstractArray}) chart_type, tags = metadata if chart_type == tb_multiline - content = MultilineChartContent(tag = tags) - return Chart(title = name, multiline = content) + content = MultilineChartContent(tags) + return Chart(name, OneOf(:multiline, content)) elseif chart_type == tb_margin - @assert length(tags) == 3 - args = Dict(k => v for (k, v) in zip([:value, :lower, :upper], tags)) - content = MarginChartContent( - series = [MarginChartContent_Series(; args...)]) - return Chart(title = name, margin = content) + @assert length(tags) == 3 # value, lower, upper + content = MarginChartContent([MarginChartContent_Series(tags...)]) + return Chart(name, OneOf(:margin, content)) else @error "The chart type must be `tb_multiline` or `tb_margin`" end @@ -47,16 +45,29 @@ end function custom_scalar_summary(layout) cat_spec = zip(keys(layout), values(layout)) - categories = [Category(title = k, chart = charts(c)) for (k, c) in cat_spec] - - layout = Layout(category = categories) - plugin_data = SummaryMetadata_PluginData(plugin_name = "custom_scalars") - smd = SummaryMetadata(plugin_data = plugin_data) - cs_tensor = TensorProto(dtype = _DataType.DT_STRING, - string_val = [serialize_proto(layout)], - tensor_shape = TensorShapeProto()) - - Summary_Value(tag = "custom_scalars__config__", - tensor = cs_tensor, - metadata = smd) + categories = [Category(title, charts(c), false) for (title, c) in cat_spec] + + layout = Layout(zero(Int32), categories) + plugin_data = SummaryMetadata_PluginData("custom_scalars", UInt8[]) + smd = SummaryMetadata(plugin_data, "", "", DataClass.DATA_CLASS_UNKNOWN) + cs_tensor = TensorProto(_DataType.DT_STRING, + nothing, + zero(Int32), + UInt8[], + Int32[], + Float32[], + Float64[], + Int32[], + [serialize_proto(layout)], + Float32[], + Int64[], + Bool[], + Float64[], + ResourceHandleProto[], + VariantTensorDataProto[], + UInt32[], + UInt64[], + UInt8[]) + + Summary_Value("", "custom_scalars__config__", smd, OneOf(:tensor, cs_tensor)) end