Skip to content

Commit

Permalink
live template reload (#77)
Browse files Browse the repository at this point in the history
  • Loading branch information
samhza authored Nov 18, 2022
1 parent 75904e3 commit 40fb063
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 41 deletions.
43 changes: 30 additions & 13 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"flag"
"fmt"
"html/template"
"io"
"io/fs"
"log"
"net/http"
Expand All @@ -22,15 +23,15 @@ import (

//go:embed resources
var embedfs embed.FS
var tmpl *template.Template

type config struct {
BotToken string
ListenAddr string
Resources string
SiteURL string
ServiceName string
ServerHostedIn string
BotToken string
ListenAddr string
Resources string
SiteURL string
ServiceName string
ServerHostedIn string
ReloadTemplates bool
}

func main() {
Expand All @@ -49,16 +50,31 @@ func main() {
if config.Resources != "" {
fsys = os.DirFS(config.Resources)
} else {
config.ReloadTemplates = false
if fsys, err = fs.Sub(embedfs, "resources"); err != nil {
log.Fatalln("Error while using embedded resources:")
}
}

tmpl = template.New("")
tmpl.Funcs(funcMap)
_, err = tmpl.ParseFS(fsys, "templates/*")
if err != nil {
log.Fatalln("Error parsing templates:", err)
var tmplfn ExecuteTemplateFunc
if config.ReloadTemplates {
tmplfn = func(wr io.Writer, name string, data interface{}) error {
tmpl := template.New("")
tmpl.Funcs(funcMap)
_, err = tmpl.ParseFS(fsys, "templates/*")
if err != nil {
return err
}
tmpl.Funcs(funcMap)
return tmpl.ExecuteTemplate(wr, name, data)
}
} else {
tmpl := template.New("")
tmpl.Funcs(funcMap)
_, err = tmpl.ParseFS(fsys, "templates/*")
if err != nil {
log.Fatalln("Error parsing templates:", err)
}
tmplfn = tmpl.ExecuteTemplate
}

ctx, done := signal.NotifyContext(context.Background(), os.Interrupt)
Expand All @@ -84,6 +100,7 @@ func main() {
fmt.Println(err)
return
}
server.executeTemplateFn = tmplfn
httpserver := &http.Server{
Addr: config.ListenAddr,
Handler: server,
Expand Down
60 changes: 32 additions & 28 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"hash/crc32"
"io"
"io/fs"
"net/http"
"sort"
Expand Down Expand Up @@ -40,13 +41,16 @@ type server struct {
sitemapMu sync.Mutex

// configuration options
URL string
ServiceName string
ServerHostedIn string
URL string
ServiceName string
ServerHostedIn string
executeTemplateFn ExecuteTemplateFunc

buffers *sync.Pool
}

type ExecuteTemplateFunc func(w io.Writer, name string, data interface{}) error

func newServer(st *state.State, fsys fs.FS, config config) (*server, error) {
srv := &server{
fetchedInactive: make(map[discord.ChannelID]struct{}),
Expand Down Expand Up @@ -87,7 +91,7 @@ func newServer(st *state.State, fsys fs.FS, config config) (*server, error) {
getHead(r, "/static/*", http.FileServer(http.FS(fsys)).ServeHTTP)

r.NotFound(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
displayErr(w, http.StatusNotFound, nil)
srv.displayErr(w, http.StatusNotFound, nil)
}))
return srv, nil
}
Expand All @@ -105,26 +109,26 @@ func (s *server) executeTemplate(w http.ResponseWriter, r *http.Request,
name string, ctx any) {
w.Header().Set("Content-Type", "text/html; charset=utf-8")
buf := s.buffers.Get().(*bytes.Buffer)
if err := tmpl.ExecuteTemplate(buf, name, ctx); err == nil {
if err := s.executeTemplateFn(buf, name, ctx); err == nil {
checksum := crc32.ChecksumIEEE(buf.Bytes())
w.Header().Set("ETag", fmt.Sprintf("\"%x\"", checksum))
rdr := bytes.NewReader(buf.Bytes())
http.ServeContent(w, r, name, time.Time{}, rdr)
} else {
displayErr(w, http.StatusInternalServerError, err)
s.displayErr(w, http.StatusInternalServerError, err)
}
buf.Reset()
s.buffers.Put(buf)
}

func displayErr(w http.ResponseWriter, status int, err error) {
func (s *server) displayErr(w http.ResponseWriter, status int, err error) {
ctx := struct {
Error error
StatusText string
StatusCode int
}{err, http.StatusText(status), status}
w.WriteHeader(status)
tmpl.ExecuteTemplate(w, "error.gohtml", ctx)
s.executeTemplateFn(w, "error.gohtml", ctx)
}

func discordStatusIs(err error, status int) bool {
Expand Down Expand Up @@ -189,7 +193,7 @@ func (s *server) getGuildSync(w http.ResponseWriter, r *http.Request) {
}{Guild: guild, URL: s.URL}
channels, err := s.discord.Cabinet.Channels(guild.ID)
if err != nil {
displayErr(w, http.StatusInternalServerError,
s.displayErr(w, http.StatusInternalServerError,
fmt.Errorf("fetching guild channels: %s", err))
return
}
Expand All @@ -201,7 +205,7 @@ func (s *server) getGuildSync(w http.ResponseWriter, r *http.Request) {
me, _ := s.discord.Cabinet.Me()
perms, err := s.discord.Permissions(channel.ID, me.ID)
if err != nil {
displayErr(w, http.StatusInternalServerError,
s.displayErr(w, http.StatusInternalServerError,
fmt.Errorf("fetching channel permissions: %s", err))
return
}
Expand All @@ -212,15 +216,15 @@ func (s *server) getGuildSync(w http.ResponseWriter, r *http.Request) {
}
err = s.ensureArchivedThreads(channel.ID)
if err != nil {
displayErr(w, http.StatusInternalServerError,
s.displayErr(w, http.StatusInternalServerError,
fmt.Errorf("fetching archived threads: %s", err))
return
}
forums = append(forums, channel)
}
channels, err = s.discord.Cabinet.Channels(guild.ID)
if err != nil {
displayErr(w, http.StatusInternalServerError,
s.displayErr(w, http.StatusInternalServerError,
fmt.Errorf("fetching guild channels: %s", err))
return
}
Expand Down Expand Up @@ -283,7 +287,7 @@ func (s *server) getForumSync(w http.ResponseWriter, r *http.Request) {
}
err := s.ensureArchivedThreads(forum.ID)
if err != nil {
displayErr(w, http.StatusInternalServerError,
s.displayErr(w, http.StatusInternalServerError,
fmt.Errorf("fetching archived threads: %s", err))
return
}
Expand All @@ -296,7 +300,7 @@ func (s *server) getForumSync(w http.ResponseWriter, r *http.Request) {
if tagFilterString != "" {
tagFilter, err = strconv.Atoi(tagFilterString)
if err != nil {
displayErr(w, http.StatusInternalServerError,
s.displayErr(w, http.StatusInternalServerError,
fmt.Errorf("parsing tag filter: %s", err))
return
}
Expand All @@ -306,7 +310,7 @@ func (s *server) getForumSync(w http.ResponseWriter, r *http.Request) {
if afterString != "" {
after_, err := discord.ParseSnowflake(afterString)
if err != nil {
displayErr(w, http.StatusInternalServerError,
s.displayErr(w, http.StatusInternalServerError,
fmt.Errorf("parsing after number: %s", err))
return
}
Expand All @@ -329,7 +333,7 @@ func (s *server) getForumSync(w http.ResponseWriter, r *http.Request) {
TagFilter: tagFilter}
channels, err := s.discord.Cabinet.Channels(guild.ID)
if err != nil {
displayErr(w, http.StatusInternalServerError,
s.displayErr(w, http.StatusInternalServerError,
fmt.Errorf("fetching guild threads: %w", err))
return
}
Expand Down Expand Up @@ -461,14 +465,14 @@ func (s *server) getPostSync(w http.ResponseWriter, r *http.Request) {
msgs, err, ctx.PrevID, ctx.NextID = s.messageCache.MessagesBefore(post.ID, uint(cur), paginationLimit)
}
if err != nil {
displayErr(w, http.StatusInternalServerError,
s.displayErr(w, http.StatusInternalServerError,
fmt.Errorf("fetching post's messages: %w", err))
return
}

err = s.ensureMembers(r.Context(), *post, msgs)
if err != nil {
displayErr(w, http.StatusInternalServerError,
s.displayErr(w, http.StatusInternalServerError,
fmt.Errorf("fetching post's members: %w", err))
return
}
Expand All @@ -493,16 +497,16 @@ func (s *server) getPostSync(w http.ResponseWriter, r *http.Request) {
func (s *server) guildFromReq(w http.ResponseWriter, r *http.Request) (*discord.Guild, bool) {
guildIDsf, err := discord.ParseSnowflake(chi.URLParam(r, "guildID"))
if err != nil {
displayErr(w, http.StatusBadRequest, err)
s.displayErr(w, http.StatusBadRequest, err)
return nil, false
}
guildID := discord.GuildID(guildIDsf)
guild, err := s.discord.Cabinet.Guild(guildID)
if err != nil {
if discordStatusIs(err, http.StatusNotFound) {
displayErr(w, http.StatusNotFound, nil)
s.displayErr(w, http.StatusNotFound, nil)
} else {
displayErr(w, http.StatusInternalServerError,
s.displayErr(w, http.StatusInternalServerError,
fmt.Errorf("fetching guild: %w", err))
}
return nil, false
Expand All @@ -513,22 +517,22 @@ func (s *server) guildFromReq(w http.ResponseWriter, r *http.Request) (*discord.
func (s *server) forumFromReq(w http.ResponseWriter, r *http.Request) (*discord.Channel, bool) {
forumIDsf, err := discord.ParseSnowflake(chi.URLParam(r, "forumID"))
if err != nil {
displayErr(w, http.StatusBadRequest, err)
s.displayErr(w, http.StatusBadRequest, err)
return nil, false
}
forumID := discord.ChannelID(forumIDsf)
forum, err := s.discord.Cabinet.Channel(forumID)
if err != nil {
if discordStatusIs(err, http.StatusNotFound) {
displayErr(w, http.StatusNotFound, nil)
s.displayErr(w, http.StatusNotFound, nil)
} else {
displayErr(w, http.StatusInternalServerError,
s.displayErr(w, http.StatusInternalServerError,
fmt.Errorf("fetching forum: %w", err))
}
return nil, false
}
if forum.NSFW {
displayErr(w, http.StatusForbidden,
s.displayErr(w, http.StatusForbidden,
errors.New("NSFW content is not served"))
return nil, false
}
Expand All @@ -538,16 +542,16 @@ func (s *server) forumFromReq(w http.ResponseWriter, r *http.Request) (*discord.
func (s *server) postFromReq(w http.ResponseWriter, r *http.Request) (*discord.Channel, bool) {
postIDsf, err := discord.ParseSnowflake(chi.URLParam(r, "postID"))
if err != nil {
displayErr(w, http.StatusBadRequest, err)
s.displayErr(w, http.StatusBadRequest, err)
return nil, false
}
postID := discord.ChannelID(postIDsf)
post, err := s.discord.Channel(postID)
if err != nil {
if discordStatusIs(err, http.StatusNotFound) {
displayErr(w, http.StatusNotFound, nil)
s.displayErr(w, http.StatusNotFound, nil)
} else {
displayErr(w, http.StatusInternalServerError,
s.displayErr(w, http.StatusInternalServerError,
fmt.Errorf("fetching post: %w", err))
}
return nil, false
Expand Down

0 comments on commit 40fb063

Please sign in to comment.