feat: xmnt

This commit is contained in:
Yiyang Kang 2022-10-07 01:19:04 +08:00
parent 00b3e4b24f
commit 5a998b713c
16 changed files with 1481 additions and 0 deletions

39
util/command.go Normal file
View file

@ -0,0 +1,39 @@
package util
import (
"fmt"
"io"
"os"
"os/exec"
"github.com/go-errors/errors"
"gensokyo.cafe/xmnt/msg"
)
func RunCommand(cmd string, input io.Reader, args ...string) ([]byte, error) {
command := exec.Command(cmd, args...)
if input != nil {
command.Stdin = input
} else {
command.Stdin = os.Stdin
}
command.Stderr = os.Stderr
byt, err := command.Output()
if err != nil {
if e, ok := err.(*exec.ExitError); ok {
return nil, errors.New(
fmt.Sprintf("command %q %q failed with exit status %d", cmd, args, e.ExitCode()),
)
}
return nil, errors.New(err)
}
return byt, nil
}
func RunPrivilegedCommand(cmd string, input io.Reader, args ...string) ([]byte, error) {
realArgs := append([]string{cmd}, args...)
msg.Infof("Running command with sudo: %q", realArgs)
return RunCommand("sudo", input, realArgs...)
}

16
util/mount.go Normal file
View file

@ -0,0 +1,16 @@
package util
import (
"path/filepath"
"regexp"
)
var mountPointPattern = regexp.MustCompile(`^/[\w/.-]*$`)
func IsValidMountPoint(mp string) bool {
if !mountPointPattern.MatchString(mp) {
return false
}
return filepath.Clean(mp) == mp
}

124
util/systemd.go Normal file
View file

@ -0,0 +1,124 @@
package util
import (
"context"
"fmt"
"time"
"github.com/coreos/go-systemd/v22/dbus"
"github.com/coreos/go-systemd/v22/unit"
"github.com/go-errors/errors"
"gensokyo.cafe/xmnt/msg"
)
type SystemdConnection struct {
conn *dbus.Conn
}
func NewSystemdConnection() (*SystemdConnection, error) {
conn, err := dbus.NewSystemConnectionContext(context.Background())
if err != nil {
return nil, errors.WrapPrefix(err, "failed to establish dbus connection with systemd", 0)
}
return &SystemdConnection{conn: conn}, nil
}
func (c *SystemdConnection) Close() {
c.conn.Close()
}
func (c *SystemdConnection) startOrStopUnit(name string, isStart bool) (err error) {
rCh := make(chan string, 1)
defer close(rCh)
var (
actionText = []string{"start", "starting"}
actionFn = c.conn.StartUnitContext
)
if !isStart {
actionText = []string{"stop", "stopping"}
actionFn = c.conn.StopUnitContext
}
msg.Infof("%s systemd unit %q", Title(actionText[1]), name)
if _, err := actionFn(context.Background(), name, "replace", rCh); err != nil {
return errors.WrapPrefix(err, fmt.Sprintf("failed to %s systemd unit %q", actionText[0], name), 0)
}
timeout := time.After(5 * time.Second)
select {
case result := <-rCh:
if result != "done" {
return errors.Errorf("failed to %s systemd unit %q: result is %q", actionText[0], name, result)
}
case <-timeout:
// FIXME: The go-systemd library does not write anything to the channel when the job is
// failed due to dependency not met. For now we use a timeout to deal with this case.
// Might cause panic because we also close the channel.
return errors.Errorf(
"failed to %s systemd unit %q: timed out waiting for response from systemd",
actionText[0], name,
)
}
return nil
}
func (c *SystemdConnection) StartUnit(name string) error {
return c.startOrStopUnit(name, true)
}
func (c *SystemdConnection) StopUnit(name string) error {
return c.startOrStopUnit(name, false)
}
func (c *SystemdConnection) FindMount(mountPoint string) (string, error) {
unitName := unit.UnitNamePathEscape(mountPoint) + ".mount"
units, err := c.conn.ListUnitsByNamesContext(context.Background(), []string{unitName})
if err != nil {
return "", errors.WrapPrefix(err, "cannot retrieve list of systemd units", 0)
}
for _, u := range units {
if u.LoadState == "loaded" {
return u.Name, nil
}
}
return "", nil
}
var (
ErrSdNotAvailable = fmt.Errorf("systemd is not available")
ErrSdMountNotFound = fmt.Errorf("systemd mount not found")
)
func systemdMountOrUnmount(mountPoint string, isMount bool) error {
sd, err := NewSystemdConnection()
if err != nil {
return errors.WrapPrefix(ErrSdNotAvailable, err.Error(), 0)
}
defer sd.Close()
sdUnit, err := sd.FindMount(mountPoint)
if err != nil {
return errors.WrapPrefix(err, "failed to find systemd mount", 0)
}
if sdUnit == "" {
return errors.Wrap(ErrSdMountNotFound, 0)
}
if err := sd.startOrStopUnit(sdUnit, isMount); err != nil {
return errors.Wrap(err, 0)
}
return nil
}
func SystemdMount(mountPoint string) error {
return systemdMountOrUnmount(mountPoint, true)
}
func SystemdUnmount(mountPoint string) error {
return systemdMountOrUnmount(mountPoint, false)
}
func ShouldSkipSdMount(err error) bool {
return errors.Is(err, ErrSdNotAvailable) || errors.Is(err, ErrSdMountNotFound)
}

126
util/util.go Normal file
View file

@ -0,0 +1,126 @@
package util
import (
"fmt"
"os"
"path/filepath"
"github.com/bitfield/script"
"github.com/go-errors/errors"
"golang.org/x/text/cases"
"golang.org/x/text/language"
"gensokyo.cafe/xmnt/msg"
)
// FileExists checks if a given path exists. It returns error when the target
// exists but is not a regular file.
func FileExists(path string) (bool, error) {
fi, err := os.Stat(path)
if os.IsNotExist(err) {
return false, nil
}
if err != nil {
return false, err
}
if !fi.Mode().IsRegular() {
err = errors.Errorf("%q is not a regular file", path)
}
return true, err
}
// DirExists checks if a given path exists. It returns error when the target
// exists but is not a directory.
func DirExists(path string) (bool, error) {
fi, err := os.Stat(path)
if os.IsNotExist(err) {
return false, nil
}
if err != nil {
return false, err
}
if !fi.Mode().IsDir() {
err = errors.Errorf("%q is not a directory", path)
}
return true, err
}
func findCredentialFileFrom(dir, uuid string) (string, error) {
keyName := uuid + ".key"
keyPath := ""
if err := filepath.WalkDir(dir, func(path string, d os.DirEntry, err error) error {
if err != nil {
return err
}
if keyPath != "" {
return filepath.SkipDir
}
if d.Type().IsRegular() && d.Name() == keyName {
keyPath = path
}
return nil
}); err != nil {
return "", errors.Wrap(err, 0)
}
return keyPath, nil
}
func ReadCredentialFile(uuid string, dirs []string) (string, error) {
var (
keyPath string
err error
keyText []byte
)
for _, dir := range dirs {
if keyPath != "" {
break
}
keyPath, err = findCredentialFileFrom(dir, uuid)
if err != nil {
return "", errors.WrapPrefix(err, "cannot read credential file", 0)
}
}
if keyPath == "" {
return "", errors.Errorf("cannot find credential file for uuid %q", uuid)
}
isPgp, err := isPgpEncrypted(keyPath)
if err != nil {
msg.Errorf("Cannot determine type of credential file %q: %v", keyPath, err)
isPgp = false
}
if isPgp {
msg.Infof("Reading PGP encrypted credential from %q", keyPath)
keyText, err = readPgpEncryptedFile(keyPath)
} else {
msg.Infof("Reading credential from %q", keyPath)
keyText, err = os.ReadFile(keyPath)
}
if err != nil {
return "", errors.WrapPrefix(err, "cannot read credential file", 0)
}
return string(keyText), nil
}
func isPgpEncrypted(path string) (bool, error) {
output, err := RunCommand("file", nil, path)
if err != nil {
return false, errors.Wrap(err, 0)
}
l, err := script.Echo(string(output)).Match("PGP").Match("encrypted").CountLines()
return l > 0, nil
}
func readPgpEncryptedFile(path string) ([]byte, error) {
output, err := RunCommand("gpg", nil, "--decrypt", path)
if err != nil {
return nil, errors.WrapPrefix(err, fmt.Sprintf("failed to read pgp encrypted file %q", path), 0)
}
return output, nil
}
func Title(s string) string {
return cases.Title(language.Und).String(s)
}