From 40fb063c1c4909d3231621f85ccc2b90e90c873b Mon Sep 17 00:00:00 2001 From: samhza Date: Fri, 18 Nov 2022 17:17:14 -0500 Subject: [PATCH] live template reload (#77) --- main.go | 43 +++++++++++++++++++++++++++------------ server.go | 60 +++++++++++++++++++++++++++++-------------------------- 2 files changed, 62 insertions(+), 41 deletions(-) diff --git a/main.go b/main.go index d23d408..a2e0aa4 100644 --- a/main.go +++ b/main.go @@ -6,6 +6,7 @@ import ( "flag" "fmt" "html/template" + "io" "io/fs" "log" "net/http" @@ -22,15 +23,15 @@ import ( //go:embed resources var embedfs embed.FS -var tmpl *template.Template type config struct { - BotToken string - ListenAddr string - Resources string - SiteURL string - ServiceName string - ServerHostedIn string + BotToken string + ListenAddr string + Resources string + SiteURL string + ServiceName string + ServerHostedIn string + ReloadTemplates bool } func main() { @@ -49,16 +50,31 @@ func main() { if config.Resources != "" { fsys = os.DirFS(config.Resources) } else { + config.ReloadTemplates = false if fsys, err = fs.Sub(embedfs, "resources"); err != nil { log.Fatalln("Error while using embedded resources:") } } - - tmpl = template.New("") - tmpl.Funcs(funcMap) - _, err = tmpl.ParseFS(fsys, "templates/*") - if err != nil { - log.Fatalln("Error parsing templates:", err) + var tmplfn ExecuteTemplateFunc + if config.ReloadTemplates { + tmplfn = func(wr io.Writer, name string, data interface{}) error { + tmpl := template.New("") + tmpl.Funcs(funcMap) + _, err = tmpl.ParseFS(fsys, "templates/*") + if err != nil { + return err + } + tmpl.Funcs(funcMap) + return tmpl.ExecuteTemplate(wr, name, data) + } + } else { + tmpl := template.New("") + tmpl.Funcs(funcMap) + _, err = tmpl.ParseFS(fsys, "templates/*") + if err != nil { + log.Fatalln("Error parsing templates:", err) + } + tmplfn = tmpl.ExecuteTemplate } ctx, done := signal.NotifyContext(context.Background(), os.Interrupt) @@ -84,6 +100,7 @@ func main() { fmt.Println(err) return } + server.executeTemplateFn = tmplfn httpserver := &http.Server{ Addr: config.ListenAddr, Handler: server, diff --git a/server.go b/server.go index fc4660c..87d17db 100644 --- a/server.go +++ b/server.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "hash/crc32" + "io" "io/fs" "net/http" "sort" @@ -40,13 +41,16 @@ type server struct { sitemapMu sync.Mutex // configuration options - URL string - ServiceName string - ServerHostedIn string + URL string + ServiceName string + ServerHostedIn string + executeTemplateFn ExecuteTemplateFunc buffers *sync.Pool } +type ExecuteTemplateFunc func(w io.Writer, name string, data interface{}) error + func newServer(st *state.State, fsys fs.FS, config config) (*server, error) { srv := &server{ fetchedInactive: make(map[discord.ChannelID]struct{}), @@ -87,7 +91,7 @@ func newServer(st *state.State, fsys fs.FS, config config) (*server, error) { getHead(r, "/static/*", http.FileServer(http.FS(fsys)).ServeHTTP) r.NotFound(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - displayErr(w, http.StatusNotFound, nil) + srv.displayErr(w, http.StatusNotFound, nil) })) return srv, nil } @@ -105,26 +109,26 @@ func (s *server) executeTemplate(w http.ResponseWriter, r *http.Request, name string, ctx any) { w.Header().Set("Content-Type", "text/html; charset=utf-8") buf := s.buffers.Get().(*bytes.Buffer) - if err := tmpl.ExecuteTemplate(buf, name, ctx); err == nil { + if err := s.executeTemplateFn(buf, name, ctx); err == nil { checksum := crc32.ChecksumIEEE(buf.Bytes()) w.Header().Set("ETag", fmt.Sprintf("\"%x\"", checksum)) rdr := bytes.NewReader(buf.Bytes()) http.ServeContent(w, r, name, time.Time{}, rdr) } else { - displayErr(w, http.StatusInternalServerError, err) + s.displayErr(w, http.StatusInternalServerError, err) } buf.Reset() s.buffers.Put(buf) } -func displayErr(w http.ResponseWriter, status int, err error) { +func (s *server) displayErr(w http.ResponseWriter, status int, err error) { ctx := struct { Error error StatusText string StatusCode int }{err, http.StatusText(status), status} w.WriteHeader(status) - tmpl.ExecuteTemplate(w, "error.gohtml", ctx) + s.executeTemplateFn(w, "error.gohtml", ctx) } func discordStatusIs(err error, status int) bool { @@ -189,7 +193,7 @@ func (s *server) getGuildSync(w http.ResponseWriter, r *http.Request) { }{Guild: guild, URL: s.URL} channels, err := s.discord.Cabinet.Channels(guild.ID) if err != nil { - displayErr(w, http.StatusInternalServerError, + s.displayErr(w, http.StatusInternalServerError, fmt.Errorf("fetching guild channels: %s", err)) return } @@ -201,7 +205,7 @@ func (s *server) getGuildSync(w http.ResponseWriter, r *http.Request) { me, _ := s.discord.Cabinet.Me() perms, err := s.discord.Permissions(channel.ID, me.ID) if err != nil { - displayErr(w, http.StatusInternalServerError, + s.displayErr(w, http.StatusInternalServerError, fmt.Errorf("fetching channel permissions: %s", err)) return } @@ -212,7 +216,7 @@ func (s *server) getGuildSync(w http.ResponseWriter, r *http.Request) { } err = s.ensureArchivedThreads(channel.ID) if err != nil { - displayErr(w, http.StatusInternalServerError, + s.displayErr(w, http.StatusInternalServerError, fmt.Errorf("fetching archived threads: %s", err)) return } @@ -220,7 +224,7 @@ func (s *server) getGuildSync(w http.ResponseWriter, r *http.Request) { } channels, err = s.discord.Cabinet.Channels(guild.ID) if err != nil { - displayErr(w, http.StatusInternalServerError, + s.displayErr(w, http.StatusInternalServerError, fmt.Errorf("fetching guild channels: %s", err)) return } @@ -283,7 +287,7 @@ func (s *server) getForumSync(w http.ResponseWriter, r *http.Request) { } err := s.ensureArchivedThreads(forum.ID) if err != nil { - displayErr(w, http.StatusInternalServerError, + s.displayErr(w, http.StatusInternalServerError, fmt.Errorf("fetching archived threads: %s", err)) return } @@ -296,7 +300,7 @@ func (s *server) getForumSync(w http.ResponseWriter, r *http.Request) { if tagFilterString != "" { tagFilter, err = strconv.Atoi(tagFilterString) if err != nil { - displayErr(w, http.StatusInternalServerError, + s.displayErr(w, http.StatusInternalServerError, fmt.Errorf("parsing tag filter: %s", err)) return } @@ -306,7 +310,7 @@ func (s *server) getForumSync(w http.ResponseWriter, r *http.Request) { if afterString != "" { after_, err := discord.ParseSnowflake(afterString) if err != nil { - displayErr(w, http.StatusInternalServerError, + s.displayErr(w, http.StatusInternalServerError, fmt.Errorf("parsing after number: %s", err)) return } @@ -329,7 +333,7 @@ func (s *server) getForumSync(w http.ResponseWriter, r *http.Request) { TagFilter: tagFilter} channels, err := s.discord.Cabinet.Channels(guild.ID) if err != nil { - displayErr(w, http.StatusInternalServerError, + s.displayErr(w, http.StatusInternalServerError, fmt.Errorf("fetching guild threads: %w", err)) return } @@ -461,14 +465,14 @@ func (s *server) getPostSync(w http.ResponseWriter, r *http.Request) { msgs, err, ctx.PrevID, ctx.NextID = s.messageCache.MessagesBefore(post.ID, uint(cur), paginationLimit) } if err != nil { - displayErr(w, http.StatusInternalServerError, + s.displayErr(w, http.StatusInternalServerError, fmt.Errorf("fetching post's messages: %w", err)) return } err = s.ensureMembers(r.Context(), *post, msgs) if err != nil { - displayErr(w, http.StatusInternalServerError, + s.displayErr(w, http.StatusInternalServerError, fmt.Errorf("fetching post's members: %w", err)) return } @@ -493,16 +497,16 @@ func (s *server) getPostSync(w http.ResponseWriter, r *http.Request) { func (s *server) guildFromReq(w http.ResponseWriter, r *http.Request) (*discord.Guild, bool) { guildIDsf, err := discord.ParseSnowflake(chi.URLParam(r, "guildID")) if err != nil { - displayErr(w, http.StatusBadRequest, err) + s.displayErr(w, http.StatusBadRequest, err) return nil, false } guildID := discord.GuildID(guildIDsf) guild, err := s.discord.Cabinet.Guild(guildID) if err != nil { if discordStatusIs(err, http.StatusNotFound) { - displayErr(w, http.StatusNotFound, nil) + s.displayErr(w, http.StatusNotFound, nil) } else { - displayErr(w, http.StatusInternalServerError, + s.displayErr(w, http.StatusInternalServerError, fmt.Errorf("fetching guild: %w", err)) } return nil, false @@ -513,22 +517,22 @@ func (s *server) guildFromReq(w http.ResponseWriter, r *http.Request) (*discord. func (s *server) forumFromReq(w http.ResponseWriter, r *http.Request) (*discord.Channel, bool) { forumIDsf, err := discord.ParseSnowflake(chi.URLParam(r, "forumID")) if err != nil { - displayErr(w, http.StatusBadRequest, err) + s.displayErr(w, http.StatusBadRequest, err) return nil, false } forumID := discord.ChannelID(forumIDsf) forum, err := s.discord.Cabinet.Channel(forumID) if err != nil { if discordStatusIs(err, http.StatusNotFound) { - displayErr(w, http.StatusNotFound, nil) + s.displayErr(w, http.StatusNotFound, nil) } else { - displayErr(w, http.StatusInternalServerError, + s.displayErr(w, http.StatusInternalServerError, fmt.Errorf("fetching forum: %w", err)) } return nil, false } if forum.NSFW { - displayErr(w, http.StatusForbidden, + s.displayErr(w, http.StatusForbidden, errors.New("NSFW content is not served")) return nil, false } @@ -538,16 +542,16 @@ func (s *server) forumFromReq(w http.ResponseWriter, r *http.Request) (*discord. func (s *server) postFromReq(w http.ResponseWriter, r *http.Request) (*discord.Channel, bool) { postIDsf, err := discord.ParseSnowflake(chi.URLParam(r, "postID")) if err != nil { - displayErr(w, http.StatusBadRequest, err) + s.displayErr(w, http.StatusBadRequest, err) return nil, false } postID := discord.ChannelID(postIDsf) post, err := s.discord.Channel(postID) if err != nil { if discordStatusIs(err, http.StatusNotFound) { - displayErr(w, http.StatusNotFound, nil) + s.displayErr(w, http.StatusNotFound, nil) } else { - displayErr(w, http.StatusInternalServerError, + s.displayErr(w, http.StatusInternalServerError, fmt.Errorf("fetching post: %w", err)) } return nil, false