Skip to content

Commit

Permalink
add more APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
AchimGrolimund committed Dec 14, 2024
1 parent 8a07a58 commit 630ec1e
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 79 deletions.
34 changes: 25 additions & 9 deletions pkg/application/report_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,24 @@ import (
)

type ReportRepository interface {
Create(ctx context.Context, report domain.Report) error
GetByID(ctx context.Context, id string) (*domain.Report, error)
List(ctx context.Context) ([]domain.Report, error)
CreateReport(ctx context.Context, report *domain.Report) error
GetReport(ctx context.Context, id string) (*domain.Report, error)
ListReports(ctx context.Context) ([]domain.Report, error)
GetTopIPs(ctx context.Context) ([]TopIPResult, error)
GetTopViolatedDirectives(ctx context.Context) ([]TopDirectiveResult, error)
Close(ctx context.Context) error
}

type TopIPResult struct {
IP string `json:"ip"`
Count int `json:"count"`
}

type TopDirectiveResult struct {
Directive string `json:"directive"`
Count int `json:"count"`
}

type ReportService struct {
repo ReportRepository
}
Expand All @@ -23,18 +35,22 @@ func NewReportService(repo ReportRepository) *ReportService {
}
}

func (s *ReportService) CreateReport(ctx context.Context, report domain.Report) error {
return s.repo.Create(ctx, report)
func (s *ReportService) CreateReport(ctx context.Context, report *domain.Report) error {
return s.repo.CreateReport(ctx, report)
}

func (s *ReportService) GetReport(ctx context.Context, id string) (*domain.Report, error) {
return s.repo.GetByID(ctx, id)
return s.repo.GetReport(ctx, id)
}

func (s *ReportService) ListReports(ctx context.Context) ([]domain.Report, error) {
return s.repo.List(ctx)
return s.repo.ListReports(ctx)
}

func (s *ReportService) GetTopIPs(ctx context.Context) ([]TopIPResult, error) {
return s.repo.GetTopIPs(ctx)
}

func (s *ReportService) Close(ctx context.Context) error {
return s.repo.Close(ctx)
func (s *ReportService) GetTopViolatedDirectives(ctx context.Context) ([]TopDirectiveResult, error) {
return s.repo.GetTopViolatedDirectives(ctx)
}
112 changes: 73 additions & 39 deletions pkg/infrastructure/mongodb/repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@ package mongodb

import (
"context"
"fmt"
"time"

"github.com/AchimGrolimund/CSP-Scout-API/pkg/application"
"github.com/AchimGrolimund/CSP-Scout-API/pkg/domain"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
Expand All @@ -19,18 +18,9 @@ type MongoRepository struct {
}

func NewMongoRepository(uri, database, collection string) (*MongoRepository, error) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

client, err := mongo.Connect(ctx, options.Client().ApplyURI(uri))
if err != nil {
return nil, fmt.Errorf("failed to connect to MongoDB: %v", err)
}

// Ping the database
err = client.Ping(ctx, nil)
client, err := mongo.Connect(context.Background(), options.Client().ApplyURI(uri))
if err != nil {
return nil, fmt.Errorf("failed to ping MongoDB: %v", err)
return nil, err
}

return &MongoRepository{
Expand All @@ -40,55 +30,99 @@ func NewMongoRepository(uri, database, collection string) (*MongoRepository, err
}, nil
}

func (r *MongoRepository) Create(ctx context.Context, report domain.Report) error {
collection := r.client.Database(r.database).Collection(r.collection)

_, err := collection.InsertOne(ctx, report)
if err != nil {
return fmt.Errorf("failed to insert report: %v", err)
}

return nil
func (r *MongoRepository) CreateReport(ctx context.Context, report *domain.Report) error {
_, err := r.client.Database(r.database).Collection(r.collection).InsertOne(ctx, report)
return err
}

func (r *MongoRepository) GetByID(ctx context.Context, id string) (*domain.Report, error) {
collection := r.client.Database(r.database).Collection(r.collection)

// Convert string ID to ObjectID
func (r *MongoRepository) GetReport(ctx context.Context, id string) (*domain.Report, error) {
objectID, err := primitive.ObjectIDFromHex(id)
if err != nil {
return nil, fmt.Errorf("invalid ID format: %v", err)
return nil, err
}

var report domain.Report
err = collection.FindOne(ctx, bson.M{"_id": objectID}).Decode(&report)
err = r.client.Database(r.database).Collection(r.collection).FindOne(ctx, bson.M{"_id": objectID}).Decode(&report)
if err != nil {
if err == mongo.ErrNoDocuments {
return nil, nil
}
return nil, fmt.Errorf("failed to get report: %v", err)
return nil, err
}

return &report, nil
}

func (r *MongoRepository) List(ctx context.Context) ([]domain.Report, error) {
collection := r.client.Database(r.database).Collection(r.collection)

cursor, err := collection.Find(ctx, bson.M{})
func (r *MongoRepository) ListReports(ctx context.Context) ([]domain.Report, error) {
cursor, err := r.client.Database(r.database).Collection(r.collection).Find(ctx, bson.M{})
if err != nil {
return nil, fmt.Errorf("failed to list reports: %v", err)
return nil, err
}
defer cursor.Close(ctx)

var reports []domain.Report
if err = cursor.All(ctx, &reports); err != nil {
return nil, fmt.Errorf("failed to decode reports: %v", err)
if err := cursor.All(ctx, &reports); err != nil {
return nil, err
}

return reports, nil
}

func (r *MongoRepository) GetTopIPs(ctx context.Context) ([]application.TopIPResult, error) {
pipeline := mongo.Pipeline{
{{Key: "$group", Value: bson.D{
{Key: "_id", Value: "$report.clientip"},
{Key: "count", Value: bson.D{{Key: "$sum", Value: 1}}},
}}},
{{Key: "$sort", Value: bson.D{{Key: "count", Value: -1}}}},
{{Key: "$limit", Value: 20}},
{{Key: "$project", Value: bson.D{
{Key: "ip", Value: "$_id"},
{Key: "count", Value: 1},
{Key: "_id", Value: 0},
}}},
}

cursor, err := r.client.Database(r.database).Collection(r.collection).Aggregate(ctx, pipeline)
if err != nil {
return nil, err
}
defer cursor.Close(ctx)

var results []application.TopIPResult
if err := cursor.All(ctx, &results); err != nil {
return nil, err
}

return results, nil
}

func (r *MongoRepository) GetTopViolatedDirectives(ctx context.Context) ([]application.TopDirectiveResult, error) {
pipeline := mongo.Pipeline{
{{Key: "$group", Value: bson.D{
{Key: "_id", Value: "$report.violateddirective"},
{Key: "count", Value: bson.D{{Key: "$sum", Value: 1}}},
}}},
{{Key: "$sort", Value: bson.D{{Key: "count", Value: -1}}}},
{{Key: "$limit", Value: 10}},
{{Key: "$project", Value: bson.D{
{Key: "directive", Value: "$_id"},
{Key: "count", Value: 1},
{Key: "_id", Value: 0},
}}},
}

cursor, err := r.client.Database(r.database).Collection(r.collection).Aggregate(ctx, pipeline)
if err != nil {
return nil, err
}
defer cursor.Close(ctx)

var results []application.TopDirectiveResult
if err := cursor.All(ctx, &results); err != nil {
return nil, err
}

return results, nil
}

func (r *MongoRepository) Close(ctx context.Context) error {
return r.client.Disconnect(ctx)
}
61 changes: 30 additions & 31 deletions pkg/interfaces/http/handlers/report_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"github.com/AchimGrolimund/CSP-Scout-API/pkg/application"
"github.com/AchimGrolimund/CSP-Scout-API/pkg/domain"
"github.com/gin-gonic/gin"
"go.mongodb.org/mongo-driver/bson/primitive"
)

type ReportHandler struct {
Expand All @@ -19,20 +18,25 @@ func NewReportHandler(service *application.ReportService) *ReportHandler {
}
}

func SetupRoutes(router *gin.Engine, handler *ReportHandler) {
api := router.Group("/api")
{
api.POST("/reports", handler.CreateReport)
api.GET("/reports", handler.ListReports)
api.GET("/reports/:id", handler.GetReport)
api.GET("/reports/top-ips", handler.GetTopIPs)
api.GET("/reports/top-directives", handler.GetTopViolatedDirectives)
}
}

func (h *ReportHandler) CreateReport(c *gin.Context) {
var report domain.Report
if err := c.ShouldBindJSON(&report); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}

// Generate a new ObjectID if not provided
if report.ID.IsZero() {
report.ID = primitive.NewObjectID()
}

err := h.service.CreateReport(c.Request.Context(), report)
if err != nil {
if err := h.service.CreateReport(c.Request.Context(), &report); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
Expand All @@ -42,25 +46,9 @@ func (h *ReportHandler) CreateReport(c *gin.Context) {

func (h *ReportHandler) GetReport(c *gin.Context) {
id := c.Param("id")
if id == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "id is required"})
return
}

// Validate ID format
if !primitive.IsValidObjectID(id) {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid id format"})
return
}

report, err := h.service.GetReport(c.Request.Context(), id)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}

if report == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "report not found"})
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
return
}

Expand All @@ -77,11 +65,22 @@ func (h *ReportHandler) ListReports(c *gin.Context) {
c.JSON(http.StatusOK, reports)
}

func SetupRoutes(router *gin.Engine, handler *ReportHandler) {
api := router.Group("/api/v1")
{
api.POST("/reports", handler.CreateReport)
api.GET("/reports/:id", handler.GetReport)
api.GET("/reports", handler.ListReports)
func (h *ReportHandler) GetTopIPs(c *gin.Context) {
topIPs, err := h.service.GetTopIPs(c.Request.Context())
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}

c.JSON(http.StatusOK, topIPs)
}

func (h *ReportHandler) GetTopViolatedDirectives(c *gin.Context) {
topDirectives, err := h.service.GetTopViolatedDirectives(c.Request.Context())
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}

c.JSON(http.StatusOK, topDirectives)
}

0 comments on commit 630ec1e

Please sign in to comment.