// Source file: xmnt/blk/blk.go

// Package blk lists block devices through lsblk and registers the "blk"
// mounter and matcher with the mnt package.
package blk
import (
"bytes"
"encoding/json"
"regexp"
"strings"
"github.com/go-errors/errors"
"golang.org/x/exp/slices"
"gensokyo.cafe/xmnt/cfg"
"gensokyo.cafe/xmnt/mnt"
"gensokyo.cafe/xmnt/msg"
"gensokyo.cafe/xmnt/util"
)
// FSType is a filesystem type name, decoded from the "fstype" field of the
// lsblk JSON output.
type FSType string
// FSTypeLuks is the FSType value that marks a device as a LUKS container.
const FSTypeLuks = FSType("crypto_LUKS")
// DevNum is the value of the "maj:min" field of the lsblk JSON output.
type DevNum string
// supportedFsTypes lists the filesystem types accepted by
// BlkDev.IsSupportedType.
var supportedFsTypes = []FSType{
"ext2", "ext3", "ext4", "xfs", "vfat", "ntfs", FSTypeLuks,
}
// BlkDev is one block device entry decoded from the JSON output of lsblk.
// Scalar fields are pointers; a field that is null or absent in the JSON is
// left nil, so callers must check before dereferencing.
type BlkDev struct {
Name *string `json:"name"`
PKName *string `json:"pkname"` // name of parent device
Path *string `json:"path"`
FSType *FSType `json:"fstype"`
MountPoint *string `json:"mountpoint"`
UUID *string `json:"uuid"`
DevNum *DevNum `json:"maj:min"`
// Children holds the devices nested under this one in the lsblk tree.
Children []*BlkDev `json:"children"`
}
// IsSupportedType reports whether the device's filesystem type appears in
// supportedFsTypes. A device without a filesystem type is never supported.
func (b *BlkDev) IsSupportedType() bool {
	fsType := b.FSType
	return fsType != nil && slices.Contains(supportedFsTypes, *fsType)
}
// NeedsDecryption reports whether the device has filesystem type FSTypeLuks
// and no child devices. Any other filesystem type, or a missing one, yields
// false.
func (b *BlkDev) NeedsDecryption() bool {
	isLuks := b.FSType != nil && *b.FSType == FSTypeLuks
	return isLuks && len(b.Children) == 0
}
// IsMounted reports whether the device, or any device nested under it, has
// a mount point that starts with "/". Mount points that are nil, empty, or
// do not start with "/" do not count as mounted.
func (b *BlkDev) IsMounted() bool {
	// HasPrefix is safe on an empty string, unlike slicing with [:1].
	if b.MountPoint != nil && strings.HasPrefix(*b.MountPoint, "/") {
		return true
	}
	for _, child := range b.Children {
		if child != nil && child.IsMounted() {
			return true
		}
	}
	return false
}
// flattenBlkDevs turns a device tree into a flat slice in pre-order: each
// device is followed by all of its descendants before its next sibling.
// It returns nil when devs is empty.
func flattenBlkDevs(devs []*BlkDev) []*BlkDev {
	var flat []*BlkDev
	var walk func(nodes []*BlkDev)
	walk = func(nodes []*BlkDev) {
		for i := range nodes {
			flat = append(flat, nodes[i])
			walk(nodes[i].Children)
		}
	}
	walk(devs)
	return flat
}
// List runs lsblk with JSON tree output and returns every reported block
// device as a flat slice (each device followed by its descendants).
//
// If path is non-empty it is passed to lsblk as the device to query;
// otherwise lsblk is run without a device argument.
//
// An error is returned when the lsblk command fails or when its output
// cannot be decoded as JSON.
func List(path string) ([]*BlkDev, error) {
	// The column list must cover every JSON field BlkDev declares; "maj:min"
	// is included so that BlkDev.DevNum is actually populated.
	lsblkArgs := []string{"-JT", "-o", "name,pkname,path,fstype,mountpoint,uuid,maj:min"}
	if path != "" {
		lsblkArgs = append(lsblkArgs, path)
	}

	output, err := util.RunCommand("lsblk", nil, lsblkArgs...)
	if err != nil {
		return nil, errors.WrapPrefix(err, "cannot obtain list of block devices", 0)
	}

	resp := struct {
		BlockDevices []*BlkDev `json:"blockdevices"`
	}{}
	if err := json.Unmarshal(output, &resp); err != nil {
		return nil, errors.WrapPrefix(err, "cannot parse lsblk output", 0)
	}
	return flattenBlkDevs(resp.BlockDevices), nil
}
// Mounter mounts and unmounts the block device described by a preset.
type Mounter struct {
// dev is the device entry found by the most recent refresh.
dev *BlkDev
// preset is the mounter's own copy of the preset it was created from.
preset *mnt.Preset
}
func NewMounterFromPreset(p *mnt.Preset) (mnt.Mounter, error) {
2022-10-07 03:09:35 +09:00
preset := *p
2023-03-04 14:10:57 +09:00
if preset.Path == "" && preset.UUID == "" {
return nil, errors.New("preset path and uuid cannot be both empty")
}
if preset.UUID != "" && !IsValidUUID(preset.UUID) {
return nil, errors.New("invalid UUID")
2022-10-07 02:19:04 +09:00
}
if !util.IsValidMountPoint(preset.MountPoint) {
return nil, errors.Errorf("invalid mount point %q", preset.MountPoint)
}
m := &Mounter{
2022-10-07 03:09:35 +09:00
preset: &preset,
2022-10-07 02:19:04 +09:00
}
if err := m.refresh(); err != nil {
return nil, errors.Wrap(err, 0)
}
return m, nil
}
func (m *Mounter) refresh() error {
2023-03-04 14:10:57 +09:00
queryPath := m.preset.Path
if m.preset.UUID != "" {
queryPath = "/dev/disk/by-uuid/" + m.preset.UUID
}
devs, err := List(queryPath)
2022-10-07 02:19:04 +09:00
if err != nil {
return errors.Wrap(err, 0)
}
if len(devs) < 1 {
return errors.Errorf("block device %q not found", m.preset.Path)
}
m.dev = devs[0]
return nil
}
// dmNameSeparatorRun matches a run of one or more characters that are either
// non-word characters or underscores. It is compiled once at package scope
// instead of on every escapeDmName call.
var dmNameSeparatorRun = regexp.MustCompile(`[\W_]+`)

// escapeDmName derives a device mapper name from s: every run of non-word
// characters and underscores is collapsed into a single "_", and leading and
// trailing underscores are removed. The result may be empty.
func escapeDmName(s string) string {
	b := dmNameSeparatorRun.ReplaceAllLiteral([]byte(s), []byte{'_'})
	return string(bytes.Trim(b, "_"))
}
// dmName returns the device mapper name to use for the device.
//
// It prefers the name of an existing child whose path is under
// /dev/mapper/. Otherwise it falls back to the escaped preset name, and
// finally to the escaped device path. An empty string means no name could
// be determined.
func (m *Mounter) dmName() string {
	for _, c := range m.dev.Children {
		// Path and Name may be nil; such a child cannot be identified as a
		// mapper device, so skip it instead of dereferencing nil.
		if c == nil || c.Path == nil || c.Name == nil {
			continue
		}
		if strings.HasPrefix(*c.Path, "/dev/mapper/") {
			return *c.Name
		}
	}

	// TODO get from property value of systemd mount unit

	if s := escapeDmName(m.preset.Name); s != "" {
		return s
	}
	if m.dev.Path == nil {
		return ""
	}
	return escapeDmName(*m.dev.Path)
}
// loadKey opens the device with cryptsetup luksOpen when
// m.dev.NeedsDecryption reports true; otherwise it does nothing.
//
// The credential is looked up by the device UUID in cfg.Cfg.CredentialStore
// and handed to cryptsetup through a reader (presumably its stdin; confirm
// in util.RunPrivilegedCommand). An error is returned when the device has no
// UUID, no mapper name can be determined, the credential cannot be read, or
// cryptsetup fails.
func (m *Mounter) loadKey() error {
	target := m.dev
	if !target.NeedsDecryption() {
		return nil
	}

	if target.UUID == nil {
		return errors.Errorf("device %q does not have a UUID", *target.Path)
	}

	mapperName := m.dmName()
	if mapperName == "" {
		return errors.Errorf("cannot determine device mapper name for device %q", *target.Path)
	}

	cred, err := util.ReadCredentialFile(*target.UUID, cfg.Cfg.CredentialStore)
	if err != nil {
		return errors.Wrap(err, 0)
	}

	msg.Infof("Opening device %q as %q", *target.Path, mapperName)
	cryptsetupArgs := []string{"luksOpen", *target.Path, mapperName}
	if _, err := util.RunPrivilegedCommand("cryptsetup", strings.NewReader(cred), cryptsetupArgs...); err != nil {
		return errors.WrapPrefix(err, "cannot open luks device", 0)
	}
	return nil
}
// mount mounts the device on the preset mount point. It does nothing when
// the device, or anything nested under it, is already mounted.
//
// When the device has exactly one child, that child is mounted instead of
// the device itself; more than one child is an error. Mounting is first
// attempted through util.SystemdMount; the plain mount command is used only
// when util.ShouldSkipSdMount accepts the resulting error.
func (m *Mounter) mount() error {
	if m.dev.IsMounted() {
		return nil
	}
	if !m.dev.IsSupportedType() || m.dev.NeedsDecryption() {
		// FSType is nil when lsblk reports no filesystem; dereferencing it
		// unconditionally here would panic instead of returning the error.
		var fsType FSType
		if m.dev.FSType != nil {
			fsType = *m.dev.FSType
		}
		return errors.Errorf("cannot mount %q: unsupported filesystem type %q", *m.dev.Path, fsType)
	}
	if len(m.dev.Children) > 1 {
		return errors.Errorf("cannot mount %q: can't handle multiple child devices", *m.dev.Path)
	}

	dev := m.dev
	if len(dev.Children) > 0 {
		dev = dev.Children[0]
	}

	mp := m.preset.MountPoint
	err := util.SystemdMount(mp)
	if err == nil {
		return nil
	}
	if !util.ShouldSkipSdMount(err) {
		return errors.WrapPrefix(err, "cannot mount device", 0)
	}

	// Fall back to a direct mount call.
	msg.Infof("Mounting %q on %q", *dev.Path, mp)
	mountOpts := []string{"-o", "noatime", *dev.Path, mp}
	if _, err := util.RunPrivilegedCommand("mount", nil, mountOpts...); err != nil {
		return errors.WrapPrefix(err, "failed to mount device", 0)
	}
	return nil
}
// Mount runs the mount sequence in order: loadKey, refresh, then mount. It
// stops at the first failing step and returns that step's error.
func (m *Mounter) Mount() (err error) {
	steps := []func() error{m.loadKey, m.refresh, m.mount}
	for _, step := range steps {
		if err = step(); err != nil {
			return errors.Wrap(err, 0)
		}
	}

	// TODO check
	return nil
}
func (m *Mounter) unmount() error {
if !m.dev.IsMounted() {
return nil
}
if len(m.dev.Children) > 1 {
return errors.Errorf("cannot unmount %q: can't handle multiple child devices", *m.dev.Path)
}
dev := m.dev
if len(dev.Children) > 0 {
dev = dev.Children[0]
}
2022-10-07 03:42:08 +09:00
mp := *dev.MountPoint
2022-10-07 02:19:04 +09:00
if err := util.SystemdUnmount(mp); err == nil {
return nil
} else if !util.ShouldSkipSdMount(err) {
return errors.WrapPrefix(err, "cannot unmount device", 0)
}
msg.Infof("Unmounting %q on %q", *dev.Path, mp)
if _, err := util.RunPrivilegedCommand("umount", nil, mp); err != nil {
return errors.WrapPrefix(err, "failed to unmount device", 0)
}
return nil
}
// unloadKey closes the device mapper target of a LUKS device with
// cryptsetup close. It does nothing when the device is not LUKS (including
// when it has no filesystem type at all) or when NeedsDecryption reports
// that it has no child devices.
func (m *Mounter) unloadKey() error {
	// FSType is nil when lsblk reports no filesystem; treat that as non-LUKS
	// rather than dereferencing nil.
	if m.dev.FSType == nil || *m.dev.FSType != FSTypeLuks {
		return nil
	}
	if m.dev.NeedsDecryption() {
		return nil
	}

	dmName := m.dmName()
	if dmName == "" {
		return errors.New("cannot determine device mapper name")
	}

	if _, err := util.RunPrivilegedCommand("cryptsetup", nil, "close", dmName); err != nil {
		return errors.WrapPrefix(err, "cannot close luks device", 0)
	}
	return nil
}
// Unmount runs the unmount sequence in order: unmount, then unloadKey. It
// stops at the first failing step and returns that step's error.
func (m *Mounter) Unmount() (err error) {
	steps := []func() error{m.unmount, m.unloadKey}
	for _, step := range steps {
		if err = step(); err != nil {
			return errors.Wrap(err, 0)
		}
	}

	// TODO check
	return nil
}
// match looks up block devices by the user-supplied string s and returns a
// "blk" preset for each hit.
//
// A device whose name or path equals s is an exact match and is returned
// alone. Otherwise every device whose path ends in "/"+s is returned. The
// result is nil when s is empty or nothing matches. An error is returned
// only when the device list cannot be obtained.
func match(s string) ([]*mnt.Preset, error) {
	if s == "" {
		return nil, nil
	}

	devices, err := List("")
	if err != nil {
		return nil, errors.Wrap(err, 0)
	}

	suffix := "/" + s
	var partial []*mnt.Preset
	for _, dev := range devices {
		// A preset needs a path; a device without one can never be returned,
		// even when its name matches, so skip it instead of dereferencing nil.
		if dev.Path == nil {
			continue
		}
		path := *dev.Path

		if path == s || (dev.Name != nil && *dev.Name == s) {
			return []*mnt.Preset{{
				Name: s,
				Type: "blk",
				Path: path,
			}}, nil
		}
		if strings.HasSuffix(path, suffix) {
			partial = append(partial, &mnt.Preset{
				Name: s,
				Type: "blk",
				Path: path,
			})
		}
	}
	return partial, nil
}
// init registers NewMounterFromPreset as the mounter constructor for the
// "blk" type and match as a matcher with the mnt package.
func init() {
mnt.RegisterMounter("blk", NewMounterFromPreset)
mnt.RegisterMatcher(match)
}