package etchosts import ( "bufio" "fmt" "io" "io/ioutil" "os" "path" "strings" "github.com/pkg/errors" log "github.com/sirupsen/logrus" ) // DefaultBanner is the default magic comment used to identify entries managed by etchosts const DefaultBanner = "# ! MANAGED AUTOMATICALLY !" // DefaultPath is the default path used to write hosts entries const DefaultPath = "/etc/hosts" // EtcHosts contains the options used to write hosts entries. // The zero value can be used to write to DefaultPath using DefaultBanner as a marker. type EtcHosts struct { // Banner is the magic comment used to identify entries managed by etchosts; if not set, will use DefaultBanner. // It must start with "#" to mark it as a comment. Banner string // Path is the path to the /etc/hosts file; if not set, will use DefaultPath. Path string // Logger is an optional logrus.StdLogger interface, used for debugging. Logger log.StdLogger } // WriteEntries is used to write the hosts entries to EtcHosts.Path // Each IP address with their (potentially multiple) hostnames are written to a line marked with EtcHosts.Banner, to // avoid overwriting preexisting entries. func (eh *EtcHosts) WriteEntries(ipsToNames map[string][]string) error { hostsPath := eh.Path if hostsPath == "" { hostsPath = DefaultPath } // We do not want to create the hosts file; if it's not there, we probably have the wrong path. etcHosts, err := os.OpenFile(hostsPath, os.O_RDWR, 0644) if err != nil { return errors.Wrapf(err, "could not open %s for reading", hostsPath) } defer etcHosts.Close() // create tmpfile in same folder as tmp, err := ioutil.TempFile(path.Dir(hostsPath), "etchosts") if err != nil { return errors.Wrap(err, "could not create tempfile") } // remove tempfile; this might fail if we managed to move it, which is ok defer func(file *os.File) { file.Close() if err := os.Remove(file.Name()); err != nil && !os.IsNotExist(err) { if eh.Logger != nil { eh.Logger.Printf("unexpected error trying to remove temp file %s: %s", file.Name(), err) } } }(tmp) if err := eh.writeEntries(etcHosts, tmp, ipsToNames); err != nil { return err } return eh.movePreservePerms(tmp, etcHosts) } func (eh *EtcHosts) writeEntries(orig io.Reader, dest io.Writer, ipsToNames map[string][]string) error { banner := eh.Banner if banner == "" { banner = DefaultBanner } // go through file and update existing entries/prune nonexistent entries scanner := bufio.NewScanner(orig) for scanner.Scan() { line := scanner.Text() if strings.HasSuffix(strings.TrimSpace(line), strings.TrimSpace(banner)) { tokens := strings.Fields(line) if len(tokens) < 1 { continue // remove empty managed line } ip := tokens[0] if names, ok := ipsToNames[ip]; ok { err := eh.writeEntryWithBanner(dest, banner, ip, names) if err != nil { return err } delete(ipsToNames, ip) // otherwise we'll append it again below } } else { // keep original unmanaged line fmt.Fprintf(dest, "%s\n", line) } } if err := scanner.Err(); err != nil { return errors.Wrap(err, "error reading hosts file") } // append remaining entries to file for ip, names := range ipsToNames { if err := eh.writeEntryWithBanner(dest, banner, ip, names); err != nil { return err } } return nil } func (eh *EtcHosts) writeEntryWithBanner(tmp io.Writer, banner, ip string, names []string) error { if ip != "" && len(names) > 0 { if eh.Logger != nil { eh.Logger.Printf("writing entry for %s (%s)", ip, names) } if _, err := fmt.Fprintf(tmp, "%s\t%s\t%s\n", ip, strings.Join(names, " "), banner); err != nil { return errors.Wrapf(err, "error writing entry for %s", ip) } } return nil } func (eh *EtcHosts) movePreservePerms(src, dst *os.File) error { if err := src.Sync(); err != nil { return errors.Wrapf(err, "could not sync changes to %s", src.Name()) } etcHostsInfo, err := dst.Stat() if err != nil { return errors.Wrapf(err, "could not stat %s", dst.Name()) } if err = os.Rename(src.Name(), dst.Name()); err != nil { log.Infof("could not rename to %s; falling back to copy (%s)", dst.Name(), err) if _, err := src.Seek(0, io.SeekStart); err != nil { return err } if _, err := dst.Seek(0, io.SeekStart); err != nil { return err } if err := dst.Truncate(0); err != nil { return err } _, err = io.Copy(dst, src) return err } // ensure we're not running with some umask that might break things if err := src.Chmod(etcHostsInfo.Mode()); err != nil { return errors.Wrapf(err, "could not chmod %s", src.Name()) } // TODO: also keep user? return nil }