package util

import (
	"context"
	"fmt"
	"time"

	"github.com/coreos/go-systemd/v22/dbus"
	"github.com/coreos/go-systemd/v22/unit"
	"github.com/go-errors/errors"

	"gensokyo.cafe/xmnt/msg"
)

// SystemdConnection wraps a D-Bus connection to the system systemd instance.
// Create it with NewSystemdConnection and release it with Close.
type SystemdConnection struct {
	conn *dbus.Conn
}

// NewSystemdConnection opens a D-Bus connection to the system systemd
// instance. The caller is responsible for calling Close on the result.
func NewSystemdConnection() (*SystemdConnection, error) {
	conn, err := dbus.NewSystemConnectionContext(context.Background())
	if err != nil {
		return nil, errors.WrapPrefix(err, "failed to establish dbus connection with systemd", 0)
	}
	return &SystemdConnection{conn: conn}, nil
}

// Close closes the underlying D-Bus connection.
func (c *SystemdConnection) Close() {
	c.conn.Close()
}

// startOrStopUnit asks systemd to start (isStart == true) or stop the named
// unit using job mode "replace", then waits for the job result.
//
// It returns an error if the job cannot be submitted, if systemd reports a
// result other than "done", or if no result arrives within jobTimeout (which
// also bounds the submission call itself).
func (c *SystemdConnection) startOrStopUnit(name string, isStart bool) error {
	// FIXME: The go-systemd library does not write anything to the channel when
	// the job fails due to a dependency not being met, so a timeout is used to
	// avoid waiting forever in that case.
	const jobTimeout = 5 * time.Second

	verb, gerund := "start", "starting"
	actionFn := c.conn.StartUnitContext
	if !isStart {
		verb, gerund = "stop", "stopping"
		actionFn = c.conn.StopUnitContext
	}

	// go-systemd is the sender on this channel, so it must NOT be closed here:
	// if we give up after a timeout, a late result would otherwise be sent on a
	// closed channel and panic. The capacity-1 buffer lets that late result be
	// delivered even though nobody is receiving any more; an unclosed channel is
	// simply garbage-collected once unreferenced.
	resultCh := make(chan string, 1)

	// One deadline covers both submitting the job over D-Bus and waiting for
	// its result, so neither step can block indefinitely.
	ctx, cancel := context.WithTimeout(context.Background(), jobTimeout)
	defer cancel()

	msg.Infof("%s systemd unit %q", Title(gerund), name)
	if _, err := actionFn(ctx, name, "replace", resultCh); err != nil {
		return errors.WrapPrefix(err, fmt.Sprintf("failed to %s systemd unit %q", verb, name), 0)
	}

	select {
	case result := <-resultCh:
		if result != "done" {
			return errors.Errorf("failed to %s systemd unit %q: result is %q", verb, name, result)
		}
		return nil
	case <-ctx.Done():
		return errors.Errorf(
			"failed to %s systemd unit %q: timed out waiting for response from systemd",
			verb, name,
		)
	}
}

// StartUnit starts the named systemd unit and waits for the job to finish.
func (c *SystemdConnection) StartUnit(name string) error {
	return c.startOrStopUnit(name, true)
}

// StopUnit stops the named systemd unit and waits for the job to finish.
func (c *SystemdConnection) StopUnit(name string) error {
	return c.startOrStopUnit(name, false)
}

// FindMount returns the name of the systemd ".mount" unit corresponding to
// mountPoint, provided systemd reports it with LoadState "loaded". It returns
// an empty name and a nil error when no such loaded unit exists.
func (c *SystemdConnection) FindMount(mountPoint string) (string, error) {
	// systemd names mount units after the path-escaped mount point.
	unitName := unit.UnitNamePathEscape(mountPoint) + ".mount"
	units, err := c.conn.ListUnitsByNamesContext(context.Background(), []string{unitName})
	if err != nil {
		return "", errors.WrapPrefix(err, "cannot retrieve list of systemd units", 0)
	}
	for _, u := range units {
		if u.LoadState == "loaded" {
			return u.Name, nil
		}
	}
	return "", nil
}

var (
	// ErrSdNotAvailable is returned when no D-Bus connection to systemd could
	// be established.
	ErrSdNotAvailable = fmt.Errorf("systemd is not available")
	// ErrSdMountNotFound is returned when systemd has no loaded mount unit for
	// the requested mount point.
	ErrSdMountNotFound = fmt.Errorf("systemd mount not found")
)

// systemdMountOrUnmount opens a systemd connection, looks up the mount unit
// for mountPoint and starts it (isMount == true) or stops it.
//
// Errors wrap ErrSdNotAvailable when systemd cannot be reached and
// ErrSdMountNotFound when no loaded mount unit exists; see ShouldSkipSdMount.
func systemdMountOrUnmount(mountPoint string, isMount bool) error {
	sd, err := NewSystemdConnection()
	if err != nil {
		// Wrap the sentinel (not err) so errors.Is(…, ErrSdNotAvailable) holds;
		// the connection error survives only as message text.
		return errors.WrapPrefix(ErrSdNotAvailable, err.Error(), 0)
	}
	defer sd.Close()

	sdUnit, err := sd.FindMount(mountPoint)
	if err != nil {
		return errors.WrapPrefix(err, "failed to find systemd mount", 0)
	}
	if sdUnit == "" {
		return errors.Wrap(ErrSdMountNotFound, 0)
	}

	if err := sd.startOrStopUnit(sdUnit, isMount); err != nil {
		return errors.Wrap(err, 0)
	}
	return nil
}

// SystemdMount mounts mountPoint by starting its systemd mount unit.
func SystemdMount(mountPoint string) error {
	return systemdMountOrUnmount(mountPoint, true)
}

// SystemdUnmount unmounts mountPoint by stopping its systemd mount unit.
func SystemdUnmount(mountPoint string) error {
	return systemdMountOrUnmount(mountPoint, false)
}

// ShouldSkipSdMount reports whether err indicates that systemd could not be
// used for the mount at all: either systemd was unreachable
// (ErrSdNotAvailable) or it has no loaded mount unit for the mount point
// (ErrSdMountNotFound).
func ShouldSkipSdMount(err error) bool {
	return errors.Is(err, ErrSdNotAvailable) || errors.Is(err, ErrSdMountNotFound)
}