Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 47 additions & 4 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@ package main

import (
"context"
"errors"
"flag"
"fmt"
"log"
"math/rand"
"net/url"
"os"
"os/exec"
"os/signal"
"syscall"
"time"
Expand All @@ -21,6 +23,14 @@ type opts struct {
checkFrequency string
destinationDir string
runOnce bool
plugin string
pluginArgs []string
}

type plugin struct {
binPath string
name string
args []string
}

const (
Expand All @@ -33,15 +43,22 @@ func main() {

log.Print("gokrazy's selfupdate service starting up..")

gokrazy.WaitForClock()

var o opts

flag.StringVar(&o.gusServer, "gus_server", "", "the HTTP/S endpoint of the GUS (gokrazy Update System) server (required)")
flag.StringVar(&o.checkFrequency, "check_frequency", "1h", "the time frequency for checks to the update service. The very first check is done on startup. default: 1h")
flag.StringVar(&o.destinationDir, "destination_dir", "/tmp/selfupdate", "the destination directory for the fetched update file. default: /tmp/selfupdate")
flag.BoolVar(&o.runOnce, "run_once", false, "exits right after the initial update attempt. default: false")
flag.StringVar(&o.plugin, "plugin", "", "name of the desired plugin to be loaded (this will be used when needed). default: ''")

flag.Parse()

// Gather args after flag parsing termination "--".
// They will be directly passed to the plugin binary.
o.pluginArgs = flag.Args()

if err := logic(ctx, o); err != nil {
log.Fatal(err)
}
Expand Down Expand Up @@ -79,11 +96,16 @@ func logic(ctx context.Context, o opts) error {
return fmt.Errorf("error joining gus server url: %w", err)
}

plugins := make(map[string]plugin)
if err := loadPlugin(plugins, o.plugin, o.pluginArgs); err != nil {
return fmt.Errorf("error loading plugin %s: %w", o.plugin, err)
}

gusCfg := gusapi.NewConfiguration()
gusCfg.BasePath = gusBasePath
gusCli := gusapi.NewAPIClient(gusCfg)

if err := updateProcess(ctx, gusCli, machineID, o.gusServer, sbomHash, o.destinationDir, httpPassword, httpPort); err != nil {
if err := updateProcess(ctx, gusCli, plugins, machineID, o.gusServer, sbomHash, o.destinationDir, httpPassword, httpPort); err != nil {
// If the updateProcess fails we exit with an error
// so that gokrazy supervisor will restart the process.
return fmt.Errorf("error performing updateProcess: %w", err)
Expand All @@ -107,15 +129,15 @@ func logic(ctx context.Context, o opts) error {
case <-ticker.C:
jitter := time.Duration(rand.Int63n(250)) * time.Second
time.Sleep(jitter)
if err := updateProcess(ctx, gusCli, machineID, o.gusServer, sbomHash, o.destinationDir, httpPassword, httpPort); err != nil {
if err := updateProcess(ctx, gusCli, plugins, machineID, o.gusServer, sbomHash, o.destinationDir, httpPassword, httpPort); err != nil {
log.Printf("error performing updateProcess: %v", err)
continue
}
}
}
}

func updateProcess(ctx context.Context, gusCli *gusapi.APIClient, machineID, gusServer, sbomHash, destinationDir, httpPassword, httpPort string) error {
func updateProcess(ctx context.Context, gusCli *gusapi.APIClient, plugins map[string]plugin, machineID, gusServer, sbomHash, destinationDir, httpPassword, httpPort string) error {
response, err := checkForUpdates(ctx, gusCli, machineID)
if err != nil {
return fmt.Errorf("unable to check for updates: %w", err)
Expand All @@ -128,7 +150,7 @@ func updateProcess(ctx context.Context, gusCli *gusapi.APIClient, machineID, gus
}

// The SBOMHash differs, start the selfupdate procedure.
if err := selfupdate(ctx, gusCli, gusServer, machineID, destinationDir, response, httpPassword, httpPort); err != nil {
if err := selfupdate(ctx, gusCli, plugins, gusServer, machineID, destinationDir, response, httpPassword, httpPort); err != nil {
return fmt.Errorf("unable to perform the selfupdate procedure: %w", err)
}

Expand All @@ -140,3 +162,24 @@ func updateProcess(ctx context.Context, gusCli *gusapi.APIClient, machineID, gus

return nil
}

func loadPlugin(plugins map[string]plugin, pluginName string, pluginArgs []string) error {
var binPath string

// Try to find the plugin binary in PATH.
fullPluginName := fmt.Sprintf("gokplugin-%s", pluginName)
if p, err := exec.LookPath(fullPluginName); err == nil {
binPath = p
} else {
// The binary can't be found in PATH.
// Fall back to checking in the well known gokrazy's /user/ path.
fallbackPath := fmt.Sprintf("/user/%s", fullPluginName)
if _, err := os.Stat(fallbackPath); errors.Is(err, os.ErrNotExist) {
return fmt.Errorf("unable to find %s", fullPluginName)
}
binPath = fallbackPath
}
plugins[pluginName] = plugin{binPath: binPath, name: pluginName, args: pluginArgs}

return nil
}
16 changes: 12 additions & 4 deletions selfupdate.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
"context"
"errors"
"fmt"
"log"
"net/http"
Expand Down Expand Up @@ -35,7 +36,7 @@ func shouldUpdate(response gusapi.UpdateResponse, sbomHash string) bool {
return true
}

func selfupdate(ctx context.Context, gusCli *gusapi.APIClient, gusServer, machineID, destinationDir string, response gusapi.UpdateResponse, httpPassword, httpPort string) error {
func selfupdate(ctx context.Context, gusCli *gusapi.APIClient, plugins map[string]plugin, gusServer, machineID, destinationDir string, response gusapi.UpdateResponse, httpPassword, httpPort string) error {
log.Print("starting self-update procedure")

if _, _, err := gusCli.UpdateApi.Attempt(ctx, &gusapi.UpdateApiAttemptOpts{
Expand All @@ -52,12 +53,19 @@ func selfupdate(ctx context.Context, gusCli *gusapi.APIClient, gusServer, machin

switch response.RegistryType {
case "http", "localdisk":
readClosers, err = httpFetcher(response, gusServer, destinationDir)
readClosers, err = httpUpdateFetch(response, gusServer, destinationDir)
if err != nil {
return fmt.Errorf("error fetching %q update from link %q: %w", response.RegistryType, response.DownloadLink, err)
}
default:
return fmt.Errorf("unrecognized registry type %q", response.RegistryType)
if _, ok := plugins[response.RegistryType]; !ok {
return fmt.Errorf("error %q is not a loaded plugin", response.RegistryType)
}

readClosers, err = pluginFetchUpdate(ctx, plugins[response.RegistryType], destinationDir, response.DownloadLink)
if err != nil {
return fmt.Errorf("error fetching %q update from link %q: %w", response.RegistryType, response.DownloadLink, err)
}
}

uri := fmt.Sprintf("http://gokrazy:%s@localhost:%s/", httpPassword, httpPort)
Expand Down Expand Up @@ -99,7 +107,7 @@ func selfupdate(ctx context.Context, gusCli *gusapi.APIClient, gusServer, machin
}

log.Print("reboot")
if err := target.Reboot(ctx); err != nil {
if err := target.Reboot(ctx); err != nil && !errors.Is(err, context.Canceled) {
return fmt.Errorf("reboot: %v", err)
}

Expand Down
63 changes: 59 additions & 4 deletions update_handlers.go → update_fetchers.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@ package main

import (
"archive/zip"
"context"
"fmt"
"io"
"log"
"net/http"
"net/url"
"os"
"os/exec"
"path/filepath"
"strconv"
"syscall"
Expand All @@ -28,8 +30,8 @@ type rcs struct {
root io.ReadCloser
}

// httpFetcher handles a http update link.
func httpFetcher(response gusapi.UpdateResponse, gusServer, destinationDir string) (rcs, error) {
// httpUpdateFetch fetches the update payload via HTTP.
func httpUpdateFetch(response gusapi.UpdateResponse, gusServer, destinationDir string) (rcs, error) {
// The link may be a relative url if the server's backend registry is its local disk.
// Ensure we have an absolute url by adding the base (gusServer) url
// when necessary.
Expand All @@ -43,7 +45,7 @@ func httpFetcher(response gusapi.UpdateResponse, gusServer, destinationDir strin
return rcs{}, fmt.Errorf("error ensuring destination directory exists: %w", err)
}

log.Printf("downloading update file from registry %q with url: %s", response.RegistryType, link)
log.Printf("downloading update from %q registry with url: %s", response.RegistryType, link)

filePath := filepath.Join(destinationDir, "disk.gaf")
if err := httpDownloadFile(destinationDir, filePath, link); err != nil {
Expand Down Expand Up @@ -129,6 +131,7 @@ func httpDownloadFile(destinationDir, filePath string, url string) error {
return nil
}

// ensureAbsoluteHTTPLink ensures an HTTP link is absolute.
func ensureAbsoluteHTTPLink(baseURL string, link string) (string, error) {
base, err := url.Parse(baseURL)
if err != nil {
Expand All @@ -143,7 +146,7 @@ func ensureAbsoluteHTTPLink(baseURL string, link string) (string, error) {
return u.String(), nil
}

// Function to get available disk space for path.
// diskSpaceAvailable gets available disk space for path.
func diskSpaceAvailable(path string) (uint64, error) {
fs := syscall.Statfs_t{}
err := syscall.Statfs(path, &fs)
Expand All @@ -152,3 +155,55 @@ func diskSpaceAvailable(path string) (uint64, error) {
}
return fs.Bfree * uint64(fs.Bsize), nil
}

// pluginFetchUpdate fetches the update payload via plugin.
func pluginFetchUpdate(ctx context.Context, p plugin, destinationDir string, link string) (rcs, error) {
if err := os.MkdirAll(destinationDir, 0755); err != nil {
return rcs{}, fmt.Errorf("error ensuring destination directory exists: %w", err)
}

log.Printf("downloading update from %q registry with url: %q", p.name, link)

filePath := filepath.Join(destinationDir, "disk.gaf")

// TODO: add update size gather + check against available disk space.

args := append(p.args, []string{"--url", link, "--output", destinationDir}...)

cmd := exec.CommandContext(ctx, p.binPath, args...)

if err := cmd.Run(); err != nil {
return rcs{}, fmt.Errorf("error running plugin command %q: %w", p.binPath, err)
}

log.Print("finished downloading update file")
log.Print("loading disk partitions from update file")

r, err := zip.OpenReader(filePath)
if err != nil {
return rcs{}, fmt.Errorf("error opening downloaded file %q: %w", filePath, err)
}

var mbrReader, bootReader, rootReader io.ReadCloser
for _, f := range r.File {
switch f.Name {
case mbrPartitionName:
mbrReader, err = f.Open()
if err != nil {
return rcs{}, fmt.Errorf("error reading %s within update file: %w", mbrPartitionName, err)
}
case bootPartitionName:
bootReader, err = f.Open()
if err != nil {
return rcs{}, fmt.Errorf("error reading %s within update file: %w", bootPartitionName, err)
}
case rootPartitionName:
rootReader, err = f.Open()
if err != nil {
return rcs{}, fmt.Errorf("error reading %s within update file: %w", rootPartitionName, err)
}
}
}

return rcs{r, mbrReader, bootReader, rootReader}, nil
}