Skip to content

Commit

Permalink
Merge pull request #89 from chrishavlin/preprocessor_dirs_update
Browse files Browse the repository at this point in the history
preprocessor directives followup
  • Loading branch information
chrishavlin authored Dec 9, 2024
2 parents e59ce2a + c7ee934 commit 525c46f
Show file tree
Hide file tree
Showing 6 changed files with 197 additions and 29 deletions.
63 changes: 53 additions & 10 deletions yt_idv/scene_components/base_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
)
from yt_idv.scene_data.base_data import SceneData
from yt_idv.shader_objects import (
PreprocessorDefinitionState,
ShaderProgram,
ShaderTrait,
component_shaders,
Expand Down Expand Up @@ -58,6 +59,8 @@ class SceneComponent(traitlets.HasTraits):
colormap = traitlets.Instance(ColormapTexture)
_program1 = traitlets.Instance(ShaderProgram, allow_none=True)
_program2 = traitlets.Instance(ShaderProgram, allow_none=True)
_program1_pp_defs = traitlets.Instance(PreprocessorDefinitionState, allow_none=True)
_program2_pp_defs = traitlets.Instance(PreprocessorDefinitionState, allow_none=True)
_program1_invalid = True
_program2_invalid = True
_cmap_bounds_invalid = True
Expand Down Expand Up @@ -144,15 +147,38 @@ def _default_display_name(self):
def _default_render_method(self):
return default_shader_combos[self.name]

@traitlets.default("_program1_pp_defs")
def _default_program1_pp_defs(self):
return PreprocessorDefinitionState()

@traitlets.default("_program2_pp_defs")
def _default_program2_pp_defs(self):
return PreprocessorDefinitionState()

@traitlets.observe("render_method")
def _change_render_method(self, change):
new_combo = component_shaders[self.name][change["new"]]
with self.hold_trait_notifications():
self.vertex_shader = new_combo["first_vertex"]
self.fragment_shader = new_combo["first_fragment"]
self.geometry_shader = new_combo.get("first_geometry", None)
self.colormap_vertex = new_combo["second_vertex"]
self.colormap_fragment = new_combo["second_fragment"]
self.vertex_shader = (
new_combo["first_vertex"],
self._program1_pp_defs["vertex"],
)
self.fragment_shader = (
new_combo["first_fragment"],
self._program1_pp_defs["fragment"],
)
self.geometry_shader = (
new_combo.get("first_geometry", None),
self._program1_pp_defs["geometry"],
)
self.colormap_vertex = (
new_combo["second_vertex"],
self._program2_pp_defs["vertex"],
)
self.colormap_fragment = (
new_combo["second_fragment"],
self._program2_pp_defs["fragment"],
)

@traitlets.observe("render_method")
def _add_initial_isolayer(self, change):
Expand Down Expand Up @@ -191,10 +217,23 @@ def _change_colormap_fragment(self, change):
self._program2_invalid = True

@traitlets.observe("use_db")
def _initialize_db(self, changed):
# invaldiate the colormap when the depth buffer selection changes
def _toggle_depth_buffer(self, changed):
# invalidate the colormap when the depth buffer selection changes
self._cmap_bounds_invalid = True

# update the preprocessor state: USE_DB only present in the second
# program, only update that one.
if changed["new"]:
self._program2_pp_defs.add_definition("fragment", ("USE_DB", ""))
else:
self._program2_pp_defs.clear_definition("fragment", ("USE_DB", ""))

# update the colormap fragment with current render method
current_combo = component_shaders[self.name][self.render_method]
pp_defs = self._program2_pp_defs["fragment"]
self.colormap_fragment = current_combo["second_fragment"], pp_defs
self._recompile_shader()

@traitlets.default("colormap")
def _default_colormap(self):
cm = ColormapTexture()
Expand Down Expand Up @@ -241,7 +280,9 @@ def program1(self):
self._program1.delete_program()
self._fragment_shader_default()
self._program1 = ShaderProgram(
self.vertex_shader, self.fragment_shader, self.geometry_shader
self.vertex_shader,
self.fragment_shader,
self.geometry_shader,
)
self._program1_invalid = False
return self._program1
Expand All @@ -254,7 +295,10 @@ def program2(self):
# The vertex shader will always be the same.
# The fragment shader will change based on whether we are
# colormapping or not.
self._program2 = ShaderProgram(self.colormap_vertex, self.colormap_fragment)
self._program2 = ShaderProgram(
self.colormap_vertex,
self.colormap_fragment,
)
self._program2_invalid = False
return self._program2

Expand Down Expand Up @@ -296,7 +340,6 @@ def run_program(self, scene):
p2._set_uniform("cmap", 0)
p2._set_uniform("fb_tex", 1)
p2._set_uniform("db_tex", 2)
p2._set_uniform("use_db", self.use_db)
# Note that we use cmap_min/cmap_max, not
# self.cmap_min/self.cmap_max.
p2._set_uniform("cmap_min", self.cmap_min)
Expand Down
116 changes: 105 additions & 11 deletions yt_idv/shader_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import ctypes
import os
from collections import OrderedDict
from typing import List, Optional, Tuple

import traitlets
import yaml
Expand Down Expand Up @@ -79,28 +80,62 @@ class ShaderProgram:
geometry_shader : string
or :class:`yt_idv.shader_objects.GeometryShader`
The geometry shader used in the pipeline; optional.
preprocessor_defs : PreprocessorDefinitionState
a PreprocessorDefinitionState instance defining any preprocessor
definitions if used; optional.
"""

def __init__(self, vertex_shader=None, fragment_shader=None, geometry_shader=None):
def __init__(
self,
vertex_shader=None,
fragment_shader=None,
geometry_shader=None,
preprocessor_defs=None,
):
# Don't allow just one. Either neither or both.
if vertex_shader is None and fragment_shader is None:
pass
elif None not in (vertex_shader, fragment_shader):
# Geometry is optional
self.link(vertex_shader, fragment_shader, geometry_shader)
self.link(
vertex_shader,
fragment_shader,
geometry_shader,
preprocessor_defs=preprocessor_defs,
)
else:
raise RuntimeError
self._uniform_funcs = OrderedDict()

def link(self, vertex_shader, fragment_shader, geometry_shader=None):
def link(
self,
vertex_shader,
fragment_shader,
geometry_shader=None,
preprocessor_defs=None,
):
if preprocessor_defs is None:
preprocessor_defs = PreprocessorDefinitionState()

# We allow an optional geometry shader, but not tesselation (yet?)
self.program = GL.glCreateProgram()
if not isinstance(vertex_shader, Shader):
vertex_shader = Shader(source=vertex_shader)
vertex_shader = Shader(
source=vertex_shader,
preprocessor_defs=preprocessor_defs.get_shader_defs("vertex"),
)
if not isinstance(fragment_shader, Shader):
fragment_shader = Shader(source=fragment_shader)
fragment_shader = Shader(
source=fragment_shader,
preprocessor_defs=preprocessor_defs.get_shader_defs("fragment"),
)
if geometry_shader is not None and not isinstance(geometry_shader, Shader):
geometry_shader = Shader(source=geometry_shader)
geometry_shader = Shader(
source=geometry_shader,
preprocessor_defs=preprocessor_defs.get_shader_defs("geometry"),
)
self.vertex_shader = vertex_shader
self.fragment_shader = fragment_shader
self.geometry_shader = geometry_shader
Expand Down Expand Up @@ -277,6 +312,7 @@ def _get_source(self, source):
if ";" in source:
# This is probably safe, right? Enh, probably.
return source

# What this does is concatenate multiple (if available) source files.
# This gets around GLSL's composition issues, which means we can have
# functions that get called at each step in a ray tracing process, for
Expand Down Expand Up @@ -371,31 +407,89 @@ def __del__(self):
self.delete_shader()


def _validate_shader(shader_type, value, allow_null=True):
def _validate_shader(shader_type, value, allow_null=True, preprocessor_defs=None):
shader_info = known_shaders[shader_type][value]
shader_info.setdefault("shader_type", shader_type)
shader_info["use_separate_blend"] = bool("blend_func_separate" in shader_info)
shader_info.setdefault("shader_name", value)
shader = Shader(allow_null=allow_null, **shader_info)
return shader
if preprocessor_defs is not None:
shader_info["preprocessor_defs"] = preprocessor_defs
return Shader(allow_null=allow_null, **shader_info)


class ShaderTrait(traitlets.TraitType):
default_value = None
info_text = "A shader (vertex, fragment or geometry)"

def validate(self, obj, value):
if isinstance(value, str):
if isinstance(value, str) or isinstance(value, tuple):
try:
shader_type = self.metadata.get("shader_type", "vertex")
return _validate_shader(shader_type, value)
if isinstance(value, tuple):
preprocessor_defs = value[1]
value = value[0]
else:
preprocessor_defs = None
return _validate_shader(
shader_type, value, preprocessor_defs=preprocessor_defs
)
except KeyError:
self.error(obj, value)
elif isinstance(value, Shader):
return value
self.error(obj, value)


class PreprocessorDefinitionState:

_valid_shader_types = ("vertex", "geometry", "fragment")

def __init__(self):
self.vertex = {}
self.geometry = {}
self.fragment = {}

def _get_dict(self, shader_type: str) -> dict:
"""return the dict of definitions for specifed shader_type"""
return getattr(self, shader_type)

def add_definition(self, shader_type: str, value: Tuple[str, str]):
"""add a definition for specified shader_type, will overwrite
existing definitions.
"""
self._validate_shader_type(shader_type)
self._get_dict(shader_type)[value[0]] = value[1]

def clear_definition(self, shader_type: str, value: Tuple[str, str]):
"""remove the definition of value for specified shader_type"""
self._validate_shader_type(shader_type)
self._get_dict(shader_type).pop(value[0])

def get_shader_defs(self, shader_type: str) -> List[Tuple[str, str]]:
"""return the preprocessor definition list for specified shader_type"""
self._validate_shader_type(shader_type)
return list(self._get_dict(shader_type).items())

def _validate_shader_type(self, shader_type: str):
if shader_type not in self._valid_shader_types:
raise ValueError(
f"shader_type must be one of {self._valid_shader_types}, "
f"but found {shader_type}"
)

def __getitem__(self, item: str) -> List[Tuple[str, str]]:
return self.get_shader_defs(item)

def reset(self, shader_type: Optional[str] = None):
if shader_type is None:
self.vertex = {}
self.geometry = {}
self.fragment = {}
else:
self._validate_shader_type(shader_type)
setattr(self, shader_type, {})


known_shaders = {}
component_shaders = {}
default_shader_combos = {}
Expand Down
10 changes: 5 additions & 5 deletions yt_idv/shaders/apply_colormap.frag.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ out vec4 color;

void main(){
float scaled = 0;
if (use_db) {
scaled = texture(db_tex, UV).x;
} else {
scaled = texture(fb_tex, UV).x;
}
#ifdef USE_DB
scaled = texture(db_tex, UV).x;
#else
scaled = texture(fb_tex, UV).x;
#endif
float alpha = texture(fb_tex, UV).a; // the incoming framebuffer alpha
if (alpha == 0.0) discard;
float cm = cmap_min;
Expand Down
3 changes: 0 additions & 3 deletions yt_idv/shaders/known_uniforms.inc.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,6 @@ uniform sampler3D ds_tex[6];
// ray tracing control
uniform float sample_factor;

// depth buffer control
uniform bool use_db;

// curve drawing control
uniform vec4 curve_rgba;

Expand Down
29 changes: 29 additions & 0 deletions yt_idv/tests/test_preprocessor_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import pytest

from yt_idv.shader_objects import PreprocessorDefinitionState


def test_preprocessor_definition_state():

pds = PreprocessorDefinitionState()

pds.add_definition("fragment", ("USE_DB", ""))
assert ("USE_DB", "") in pds["fragment"]
pds.add_definition("vertex", ("placeholder", ""))
assert ("placeholder", "") in pds["vertex"]

with pytest.raises(ValueError, match="shader_type must be"):
pds.add_definition("not_a_shader_type", ("any_str", ""))

pds.clear_definition("fragment", ("USE_DB", ""))
assert ("USE_DB", "") not in pds["fragment"]

pds.reset("vertex")
assert len(pds.vertex) == 0

pds.add_definition("fragment", ("USE_DB", ""))
pds.add_definition("geometry", ("placeholder", ""))
pds.add_definition("vertex", ("placeholder", ""))
pds.reset()
for shadertype in pds._valid_shader_types:
assert len(pds._get_dict(shadertype)) == 0
5 changes: 5 additions & 0 deletions yt_idv/tests/test_yt_idv.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@ def test_snapshots(osmesa_fake_amr, image_store):
image_store(osmesa_fake_amr)


def test_depth_buffer_toggle(osmesa_fake_amr, image_store):
osmesa_fake_amr.scene.components[0].use_db = True
image_store(osmesa_fake_amr)


def test_slice(osmesa_fake_amr, image_store):
osmesa_fake_amr.scene.components[0].render_method = "slice"
osmesa_fake_amr.scene.components[0].slice_position = (0.5, 0.5, 0.5)
Expand Down

0 comments on commit 525c46f

Please sign in to comment.