125 lines
3.3 KiB
Go
125 lines
3.3 KiB
Go
package util
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/coreos/go-systemd/v22/dbus"
|
|
"github.com/coreos/go-systemd/v22/unit"
|
|
"github.com/go-errors/errors"
|
|
|
|
"gensokyo.cafe/xmnt/msg"
|
|
)
|
|
|
|
type SystemdConnection struct {
|
|
conn *dbus.Conn
|
|
}
|
|
|
|
func NewSystemdConnection() (*SystemdConnection, error) {
|
|
conn, err := dbus.NewSystemConnectionContext(context.Background())
|
|
if err != nil {
|
|
return nil, errors.WrapPrefix(err, "failed to establish dbus connection with systemd", 0)
|
|
}
|
|
return &SystemdConnection{conn: conn}, nil
|
|
}
|
|
|
|
func (c *SystemdConnection) Close() {
|
|
c.conn.Close()
|
|
}
|
|
|
|
func (c *SystemdConnection) startOrStopUnit(name string, isStart bool) (err error) {
|
|
rCh := make(chan string, 1)
|
|
defer close(rCh)
|
|
|
|
var (
|
|
actionText = []string{"start", "starting"}
|
|
actionFn = c.conn.StartUnitContext
|
|
)
|
|
if !isStart {
|
|
actionText = []string{"stop", "stopping"}
|
|
actionFn = c.conn.StopUnitContext
|
|
}
|
|
msg.Infof("%s systemd unit %q", Title(actionText[1]), name)
|
|
if _, err := actionFn(context.Background(), name, "replace", rCh); err != nil {
|
|
return errors.WrapPrefix(err, fmt.Sprintf("failed to %s systemd unit %q", actionText[0], name), 0)
|
|
}
|
|
|
|
timeout := time.After(5 * time.Second)
|
|
select {
|
|
case result := <-rCh:
|
|
if result != "done" {
|
|
return errors.Errorf("failed to %s systemd unit %q: result is %q", actionText[0], name, result)
|
|
}
|
|
case <-timeout:
|
|
// FIXME: The go-systemd library does not write anything to the channel when the job is
|
|
// failed due to dependency not met. For now we use a timeout to deal with this case.
|
|
// Might cause panic because we also close the channel.
|
|
return errors.Errorf(
|
|
"failed to %s systemd unit %q: timed out waiting for response from systemd",
|
|
actionText[0], name,
|
|
)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *SystemdConnection) StartUnit(name string) error {
|
|
return c.startOrStopUnit(name, true)
|
|
}
|
|
|
|
func (c *SystemdConnection) StopUnit(name string) error {
|
|
return c.startOrStopUnit(name, false)
|
|
}
|
|
|
|
func (c *SystemdConnection) FindMount(mountPoint string) (string, error) {
|
|
unitName := unit.UnitNamePathEscape(mountPoint) + ".mount"
|
|
|
|
units, err := c.conn.ListUnitsByNamesContext(context.Background(), []string{unitName})
|
|
if err != nil {
|
|
return "", errors.WrapPrefix(err, "cannot retrieve list of systemd units", 0)
|
|
}
|
|
for _, u := range units {
|
|
if u.LoadState == "loaded" {
|
|
return u.Name, nil
|
|
}
|
|
}
|
|
return "", nil
|
|
}
|
|
|
|
// Sentinel errors that systemdMountOrUnmount (and thus SystemdMount /
// SystemdUnmount) returns wrapped. Match them with errors.Is, or use
// ShouldSkipSdMount.
var (
	// ErrSdNotAvailable signals that no D-Bus connection to systemd could be opened.
	ErrSdNotAvailable = fmt.Errorf("systemd is not available")

	// ErrSdMountNotFound signals that systemd has no loaded mount unit for the path.
	ErrSdMountNotFound = fmt.Errorf("systemd mount not found")
)
|
|
|
|
func systemdMountOrUnmount(mountPoint string, isMount bool) error {
|
|
sd, err := NewSystemdConnection()
|
|
if err != nil {
|
|
return errors.WrapPrefix(ErrSdNotAvailable, err.Error(), 0)
|
|
}
|
|
defer sd.Close()
|
|
|
|
sdUnit, err := sd.FindMount(mountPoint)
|
|
if err != nil {
|
|
return errors.WrapPrefix(err, "failed to find systemd mount", 0)
|
|
}
|
|
if sdUnit == "" {
|
|
return errors.Wrap(ErrSdMountNotFound, 0)
|
|
}
|
|
if err := sd.startOrStopUnit(sdUnit, isMount); err != nil {
|
|
return errors.Wrap(err, 0)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func SystemdMount(mountPoint string) error {
|
|
return systemdMountOrUnmount(mountPoint, true)
|
|
}
|
|
|
|
func SystemdUnmount(mountPoint string) error {
|
|
return systemdMountOrUnmount(mountPoint, false)
|
|
}
|
|
|
|
func ShouldSkipSdMount(err error) bool {
|
|
return errors.Is(err, ErrSdNotAvailable) || errors.Is(err, ErrSdMountNotFound)
|
|
}
|